##########################
# 競合学習
# coded by Y.Suganuma
##########################
#############################
# Competitionクラスの定義
# coded by Y.Suganuma
#############################
class Competition
#########################
# コンストラクタ
# max : 最大学習回数
# n : 訓練例の数
# o : 出力セルの数
# p : 入力セルの数
# sig : 重み修正係数
#########################
def initialize(max_i, n_i, o_i, p_i, sig_i)
# 設定
@_max = max_i
@_n = n_i
@_o = o_i
@_p = p_i
@_sig = sig_i
# 領域の確保
@_e = Array.new(@_n) # 訓練例
for i1 in 0 ...@_n
@_e[i1] = Array.new(@_p)
end
@_w = Array.new(@_o) # 重み
for i1 in 0 ...@_o
@_w[i1] = Array.new(@_p)
end
end
#*************************/
# 学習データの読み込み */
# name : ファイル名 */
#*************************/
def Input(name)
f = open(name, "r")
for i1 in 0 ... @_n
s = f.gets().split(" ")
for i2 in 0 ... @_p
@_e[i1][i2] = Integer(s[i2])
end
end
f.close()
end
#################################
# 学習と結果の出力
# pr : =0 : 画面に出力
# =1 : ファイルに出力
# name : 出力ファイル名
#################################
def Learn(pr, name="")
# 初期設定
mx_v = 0.0
mx = 0
for i1 in 0 ... @_o
sum = 0.0
for i2 in 0 ... @_p
@_w[i1][i2] = rand(0)
sum += @_w[i1][i2]
end
sum = 1.0 / sum
for i2 in 0 ... @_p
@_w[i1][i2] *= sum
end
end
# 学習
for count in 0 ... @_max
# 訓練例の選択
k = Integer(rand(0) * @_n)
if k >= @_n
k = @_n - 1
end
# 出力の計算
for i1 in 0 ... @_o
s = 0.0
for i2 in 0 ... @_p
s += @_w[i1][i2] * @_e[k][i2]
end
if i1 == 0 or s > mx_v
mx = i1
mx_v = s
end
end
# 重みの修正
sum = 0.0
for i1 in 0 ... @_p
sum += @_e[k][i1]
end
for i1 in 0 ... @_p
@_w[mx][i1] += @_sig * (@_e[k][i1] / sum - @_w[mx][i1])
end
end
# 出力
if pr == 0
out = $stdout
else
out = open(name, "w")
end
out.print("分類結果\n")
for i1 in 0 ... @_n
for i2 in 0 ... @_p
out.print(" " + String(@_e[i1][i2]))
end
out.print(" Res ")
for i2 in 0 ... @_o
s = 0.0
for i3 in 0 ... @_p
s += @_w[i2][i3] * @_e[i1][i3]
end
if i2 == 0 or s > mx_v
mx = i2
mx_v = s
end
end
out.print(" " + String(mx+1) + "\n")
end
end
end
if ARGV[0] != nil
# 基本データの入力
s = gets().split(" ")
max = Integer(s[1])
p = Integer(s[3])
o = Integer(s[5])
n = Integer(s[7])
sig = Float(s[9])
s = gets().split(" ")
name = s[1]
# ネットワークの定義
srand()
net = Competition.new(max, n, o, p, sig)
net.Input(name)
# 学習と結果の出力
if ARGV[0] == nil
net.Learn(0)
else
net.Learn(1, ARGV[0])
end
else
print("***error 入力ファイル名を指定して下さい\n")
end
=begin
------------------------入力ファイル--------------
最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9 係数(0~1) 0.1
入力データファイル pat.dat
------------------------pat.dat-------------------
0 0 0 0 0 0 0 1 1
0 0 0 0 0 0 1 0 1
0 0 0 0 0 0 1 1 0
0 0 0 0 1 1 0 0 0
0 0 0 1 0 1 0 0 0
0 0 0 1 1 0 0 0 0
0 1 1 0 0 0 0 0 0
1 0 1 0 0 0 0 0 0
1 1 0 0 0 0 0 0 0
=end