競合学習

##########################
# 競合学習
#      coded by Y.Suganuma
##########################

#############################
# Competitionクラスの定義
#      coded by Y.Suganuma
#############################

class Competition

	#########################
	# コンストラクタ
	#      max : 最大学習回数
	#      n : 訓練例の数
	#      o : 出力セルの数
	#      p : 入力セルの数
	#      sig : 重み修正係数
	#########################

	def initialize(max_i, n_i, o_i, p_i, sig_i)
			# 設定
		@_max = max_i
		@_n   = n_i
		@_o   = o_i
		@_p   = p_i
		@_sig = sig_i
			# 領域の確保
		@_e = Array.new(@_n)   # 訓練例
		for i1 in 0 ...@_n
			@_e[i1] = Array.new(@_p)
		end
		@_w = Array.new(@_o)   # 重み
		for i1 in 0 ...@_o
			@_w[i1] = Array.new(@_p)
		end
	end
	
	#*************************/
	# 学習データの読み込み   */
	#      name : ファイル名 */
	#*************************/
	
	def Input(name)
	
		f = open(name, "r")
	
		for i1 in 0 ... @_n
			s = f.gets().split(" ")
			for i2 in 0 ... @_p
				@_e[i1][i2] = Integer(s[i2])
			end
		end
	
		f.close()
	end
	
	#################################
	# 学習と結果の出力
	#      pr : =0 : 画面に出力
	#           =1 : ファイルに出力
	#      name : 出力ファイル名
	#################################
	
	def Learn(pr, name="")
			# 初期設定
		mx_v = 0.0
		mx   = 0
		for i1 in 0 ... @_o
			sum = 0.0
			for i2 in 0 ... @_p
				@_w[i1][i2]  = rand(0)
				sum         += @_w[i1][i2]
			end
			sum = 1.0 / sum
			for i2 in 0 ... @_p
				@_w[i1][i2] *= sum
			end
		end
			# 学習
		for count in 0 ... @_max
					# 訓練例の選択
			k = Integer(rand(0) * @_n)
			if k >= @_n
				k = @_n - 1
			end
					# 出力の計算
			for i1 in 0 ... @_o
				s = 0.0
				for i2 in 0 ... @_p
					s += @_w[i1][i2] * @_e[k][i2]
				end
				if i1 == 0 or s > mx_v
					mx   = i1
					mx_v = s
				end
			end
					# 重みの修正
			sum = 0.0
			for i1 in 0 ... @_p
				sum += @_e[k][i1]
			end
			for i1 in 0 ... @_p
				@_w[mx][i1] += @_sig * (@_e[k][i1] / sum - @_w[mx][i1])
			end
		end
			# 出力
		if pr == 0
			out = $stdout
		else
			out = open(name, "w")
		end
	
		out.print("分類結果\n")
		for i1 in 0 ... @_n
			for i2 in 0 ... @_p
				out.print(" " + String(@_e[i1][i2]))
			end
			out.print(" Res ")
			for i2 in 0 ... @_o
				s = 0.0
				for i3 in 0 ... @_p
					s += @_w[i2][i3] * @_e[i1][i3]
				end
				if i2 == 0 or s > mx_v
					mx   = i2
					mx_v = s
				end
			end
			out.print(" " + String(mx+1) + "\n")
		end
	end
end

if ARGV[0] != nil
				# 基本データの入力
	s    = gets().split(" ")
	max  = Integer(s[1])
	p    = Integer(s[3])
	o    = Integer(s[5])
	n    = Integer(s[7])
	sig  = Float(s[9])
	s    = gets().split(" ")
	name = s[1]
				# ネットワークの定義
	srand()
	net = Competition.new(max, n, o, p, sig)
	net.Input(name)
				# 学習と結果の出力
	if ARGV[0] == nil
		net.Learn(0)
	else
		net.Learn(1, ARGV[0])
	end

else
	print("***error   入力ファイル名を指定して下さい\n")
end

=begin
------------------------入力ファイル--------------
最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9 係数(0~1) 0.1
入力データファイル pat.dat

------------------------pat.dat-------------------
0 0 0 0 0 0 0 1 1
0 0 0 0 0 0 1 0 1
0 0 0 0 0 0 1 1 0
0 0 0 0 1 1 0 0 0
0 0 0 1 0 1 0 0 0
0 0 0 1 1 0 0 0 0
0 1 1 0 0 0 0 0 0
1 0 1 0 0 0 0 0 0
1 1 0 0 0 0 0 0 0
=end