########################## # 競合学習 # coded by Y.Suganuma ########################## ############################# # Competitionクラスの定義 # coded by Y.Suganuma ############################# class Competition ######################### # コンストラクタ # max : 最大学習回数 # n : 訓練例の数 # o : 出力セルの数 # p : 入力セルの数 # sig : 重み修正係数 ######################### def initialize(max_i, n_i, o_i, p_i, sig_i) # 設定 @_max = max_i @_n = n_i @_o = o_i @_p = p_i @_sig = sig_i # 領域の確保 @_e = Array.new(@_n) # 訓練例 for i1 in 0 ...@_n @_e[i1] = Array.new(@_p) end @_w = Array.new(@_o) # 重み for i1 in 0 ...@_o @_w[i1] = Array.new(@_p) end end #*************************/ # 学習データの読み込み */ # name : ファイル名 */ #*************************/ def Input(name) f = open(name, "r") for i1 in 0 ... @_n s = f.gets().split(" ") for i2 in 0 ... @_p @_e[i1][i2] = Integer(s[i2]) end end f.close() end ################################# # 学習と結果の出力 # pr : =0 : 画面に出力 # =1 : ファイルに出力 # name : 出力ファイル名 ################################# def Learn(pr, name="") # 初期設定 mx_v = 0.0 mx = 0 for i1 in 0 ... @_o sum = 0.0 for i2 in 0 ... @_p @_w[i1][i2] = rand(0) sum += @_w[i1][i2] end sum = 1.0 / sum for i2 in 0 ... @_p @_w[i1][i2] *= sum end end # 学習 for count in 0 ... @_max # 訓練例の選択 k = Integer(rand(0) * @_n) if k >= @_n k = @_n - 1 end # 出力の計算 for i1 in 0 ... @_o s = 0.0 for i2 in 0 ... @_p s += @_w[i1][i2] * @_e[k][i2] end if i1 == 0 or s > mx_v mx = i1 mx_v = s end end # 重みの修正 sum = 0.0 for i1 in 0 ... @_p sum += @_e[k][i1] end for i1 in 0 ... @_p @_w[mx][i1] += @_sig * (@_e[k][i1] / sum - @_w[mx][i1]) end end # 出力 if pr == 0 out = $stdout else out = open(name, "w") end out.print("分類結果\n") for i1 in 0 ... @_n for i2 in 0 ... @_p out.print(" " + String(@_e[i1][i2])) end out.print(" Res ") for i2 in 0 ... @_o s = 0.0 for i3 in 0 ... @_p s += @_w[i2][i3] * @_e[i1][i3] end if i2 == 0 or s > mx_v mx = i2 mx_v = s end end out.print(" " + String(mx+1) + "\n") end end end if ARGV[0] != nil # 基本データの入力 s = gets().split(" ") max = Integer(s[1]) p = Integer(s[3]) o = Integer(s[5]) n = Integer(s[7]) sig = Float(s[9]) s = gets().split(" ") name = s[1] # ネットワークの定義 srand() net = Competition.new(max, n, o, p, sig) net.Input(name) # 学習と結果の出力 if ARGV[0] == nil net.Learn(0) else net.Learn(1, ARGV[0]) end else print("***error 入力ファイル名を指定して下さい\n") end =begin ------------------------入力ファイル-------------- 最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9 係数(0~1) 0.1 入力データファイル pat.dat ------------------------pat.dat------------------- 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 =end