バックプロパゲーション

##########################
# back propagation model
#      coded by Y.Suganuma
##########################

######################################################
# バックプロパゲーションの制御(クラス BackControl)
######################################################

class BackControl
	
	####################################
	# クラスBackControlのコンストラクタ
	#      name : 入力データファイル名
	####################################
	
	def initialize(name)
	
		f        = open(name, "r")
		s        = f.gets().split(" ")
		@_eps    = Float(s[1])   # 許容誤差
		@_p_type = Integer(s[3])   # 出力先・方法の指定
	                               #   =0 : 誤って認識した数だけ出力
	                               #   =1 : 認識結果を出力
	                               #   =2 : 認識結果と重みを出力
	                               #        (負の時は,認識結果と重みをファイルへも出力)
		if @_p_type < 0
			@_o_file = s[5]   # 出力ファイル名
		end
		s       = f.gets().split(" ")
		@_order = Integer(s[1])   # 入力パターンの与え方(=0:順番,=1:ランダム)
		@_eata  = Float(s[3])   # 重み及びバイアス修正パラメータ
		@_alpha = Float(s[5])   # 重み及びバイアス修正パラメータ
		f.close()
		srand()
	end
end
	
######################################################
# バックプロパゲーションのデータ(クラス BackData)
######################################################

class BackData
	
	####################################
	# クラスBackDataのコンストラクタ
	#      name : 入力データファイル名
	####################################
	
	def initialize(name)
				# 入力パターン数等
		f      = open(name, "r")
		s      = f.gets().split(" ")
		@_noip = Integer(s[1])   # 入力パターンの数
		@_noiu = Integer(s[3])   # 入力ユニットの数
		@_noou = Integer(s[5])   # 出力ユニットの数
				# 領域の確保
		@_iptn = Array.new(@_noip)
		for i1 in 0 ... @_noip
			@_iptn[i1] = Array.new(@_noiu)
		end
	                      # iptn[i][j] : (i+1)番目の入力パターンの(j+1)番目の
	                      #              入力ユニットの入力値
	                      #                i=0,noip-1  j=0,noiu-1
		@_optn = Array.new(@_noip)
		for i1 in 0 ... @_noip
			@_optn[i1] = Array.new(@_noou)
		end
	                      # optn[i][j] : (i+1)番目の入力パターンに対する(j+1)
	                      #              番目の出力ユニットの目標出力値
	                      #                i=0,noip-1  j=0,noou-1
				# 入力パターン及び各入力パターンに対する出力パターンの入力
		for i1 in 0 ... @_noip
			s = f.gets().split(" ")
			for i2 in 0 ... @_noiu
				@_iptn[i1][i2] = Float(s[i2+1])
			end
			s = f.gets().split(" ")
			for i2 in 0 ... @_noou
				@_optn[i1][i2] = Float(s[i2+1])
			end
		end
		f.close()
	end

	attr_accessor("_noip", "_noiu", "_noou", "_iptn", "_optn")
end

###########################################
# バックプロパゲーション(クラス Backpr)
###########################################

class Backpr < BackControl
	
	################################################
	# クラスBackprのコンストラクタ
	#      name_c : 制御データ用入力ファイル名
	#      name_s : ネットワーク記述入力ファイル名
	################################################
	
	def initialize(name_c, name_s)
	
		super(name_c)   # 親のコンストラクタ
		f = open(name_s, "r")
				# 入力ユニット,出力ユニットの数,関数タイプ
		s        = f.gets().split(" ")
		@_noiu   = Integer(s[1])   # 入力ユニットの数
		@_noou   = Integer(s[3])   # 出力ユニットの数
		@_f_type = Integer(s[5])   # シグモイド関数のタイプ,0 : [0,1],1 : [-1,1]
		@_nou    = @_noiu + @_noou   # 入力ユニットと出力ユニットの和
	                                 # 各ユニットには最も上の出力ユニットから,
	                                 # 隠れ層の各ユニット,及び,入力ユニットに至る
	                                 # 一連のユニット番号が付けられる
				# 隠れユニットの階層数と各階層のユニット数
		s       = f.gets().split(" ")
		@_nolvl = Integer(s[1])   # 隠れユニットの階層数
		@_nohu  = Array.new(@_nolvl+1)
	                 # nohu[i] : レベル(i+1)の隠れ層のユニット数(隠れ層
	                 #           には下から順に番号が付けられ,出力層はレ
	                 #           ベル(nolvl+1)の隠れ層とも見做される)
	                 #             i=0,nolvl
		@_nohu[@_nolvl] = @_noou
	
		if @_nolvl > 0
			for i1 in 0 ... @_nolvl
				@_nohu[i1]  = Integer(s[i1+3])
				@_nou      += @_nohu[i1]
			end
		end
				# 領域の確保
		@_con = Array.new(@_nou)
		for i1 in 0 ... @_nou
			@_con[i1] = Array.new(@_nou)
		end
		for i1 in 0 ... @_nou
			for i2 in 0 ... @_nou
				if i1 == i2
					@_con[i1][i2] = -1
				else
					@_con[i1][i2] = 0
	              # con[i][j] : 各ユニットに対するバイアスの与え方,及び,接続方法
	              #   [i][i] : ユニット(i+1)のバイアスの与え方
	              #     =-3 : 入力値で固定
	              #     =-2 : 入力値を初期値として,学習により変更
	              #     =-1 : 乱数で初期値を設定し,学習により変更
	              #   [i][j] : ユニット(i+1)と(j+1)の接続方法(j>i)
	              #     =0 : 接続しない
	              #     =1 : 接続する(重みの初期値を乱数で設定し,学習)
	              #     =2 : 接続する(重みの初期値を入力で与え,学習)
	              #     =3 : 接続する(重みを入力値に固定)
	              #            i=0,nou-1  j=0,nou-1
				end
			end
		end
	
		@_w = Array.new(@_nou)
		for i1 in 0 ... @_nou
			@_w[i1] = Array.new(@_nou)
		end
		for i1 in 0 ... @_nou
			for i2 in 0 ... @_nou
				@_w[i1][i2] = 0.0
	              # w[i][j] : ユニット(i+1)から(j+1)への重み(j>i)
	              # w[j][i] : (i+1)から(j+1)への重みの前回修正量(j>i)
	              # w[i][i] : ユニット(i+1)のバイアスの前回修正量
	              #   i=0,nou-1  j=0,nou-1
			end
		end
	
		@_dp    = Array.new(@_nou)   # ユニット(i+1)の誤差  i=0,nou-1
		@_op    = Array.new(@_nou)   # ユニット(i+1)の出力  i=0,nou-1
		@_theta = Array.new(@_nou)   # ユニット(i+1)のバイアス  i=0,nou
		for i1 in 0 ... @_nou
			@_theta[i1] = 0.2 * rand(0) - 0.1
		end
				# 各ユニットのバイアスとユニット間の接続関係
						# バイアス
							# バイアスデータの数
		s = f.gets().split(" ")
		n = Integer(s[1])
		f.gets()
		f.gets()
		f.gets()
	
		if n > 0
							# バイアスデータの処理
			for i0 in 0 ... n
				s = f.gets().split(" ")
								# ユニット番号
				k1 = Integer(s[0])
								# 不適当なユニット番号のチェック
				if k1 < 1 or k1 > (@_nou-@_noiu)
					print("***error  ユニット番号 " + String(k1) + " が不適当" + "\n")
				end
								# バイアスの与え方
				k1            -= 1
				id             = Integer(s[1])
				@_con[k1][k1]  = id
								# バイアスの初期設定
				if @_con[k1][k1] == -1
					x1          = Float(s[2])
					x2          = Float(s[3])
					@_theta[k1] = (x2 - x1) * rand(0) + x1
				elsif @_con[k1][k1] == -2 or @_con[k1][k1] == -3
					@_theta[k1] = Float(s[2])
				else
					print("***error  バイアスの与え方が不適当\n")
				end
			end
		end
						# 接続方法
							# 接続データの数
		s = f.gets().split(" ")
		n = Integer(s[1])
		f.gets()
		f.gets()
		f.gets()
	
		if n > 0
							# 接続データの処理
			for i0 in 0 ... n
				s  = f.gets().split(" ")
								# 接続情報
				k1 = Integer(s[0])
				k2 = Integer(s[1])
				k3 = Integer(s[2])
				k4 = Integer(s[3])
								# 不適切な接続のチェック
				sw = 0
				if k1 < 1 or k2 < 1 or k3 < 1 or k4 < 1
					sw = 1
				else
					if k1 > @_nou or k2 > @_nou or k3 > @_nou or k4 > @_nou
						sw = 1
					else
						if k1 > k2 or k3 > k4
							sw = 1
						else
							if k4 >= k1
								sw = 1
							else
								l1 = -1
								k  = 0
								i1 = @_nolvl
								while i1 > -1
									k += @_nohu[i1]
									if k1 <= k
										l1 = i1
										break
									end
									i1 -= 1
								end
								l2 = -1
								k  = 0
								i1 = @_nolvl
								while i1 > -1
									k += @_nohu[i1]
									if k4 <= k
										l2 = i1
										break
									end
									i1 -= 1
								end
								if l2 <= l1
									sw = 1
								end
							end
						end
					end
				end
	
				if sw > 0
					print("***error  ユニット番号が不適当(" + String(k1) + " " + String(k2) + " " + String(k3) + " " + String(k4) + ")\n")
				end
								# 重みの初期値の与え方
				k1 -= 1
				k2 -= 1
				k3 -= 1
				k4 -= 1
	
				id = Integer(s[4])
	
				if id == 1
					x1  = Float(s[5])
					x2  = Float(s[6])
				else
					if id > 1
						x1  = Float(s[5])
					else
						if id != 0
							print("***error  接続方法が不適当\n")
						end
					end
				end
								# 重みの初期値の設定
				for i1 in k3 ... k4+1
					for i2 in k1 ... k2+1
						@_con[i1][i2] = id
						if id == 0
							@_w[i1][i2] = 0.0
						elsif id == 1
							@_w[i1][i2] = (x2 - x1) * rand(0) + x1
						elsif id == 2
							@_w[i1][i2] = x1
						elsif id == 3
							@_w[i1][i2] = x1
						end
					end
				end
			end
		end
	
		f.close()
	end
	
	##########################################
	# 誤差の計算,及び,重みとバイアスの修正
	#      ptn[i1] : 出力パターン
	##########################################
	
	def Err_back(ptn)
	
		for i1 in 0 ... @_nou-@_noiu
						# 誤差の計算
			if i1 < @_noou
				if @_f_type == 0
					@_dp[i1] = (ptn[i1] - @_op[i1]) * @_op[i1] * (1.0 - @_op[i1])
				else
					@_dp[i1] = 0.5 * (ptn[i1] - @_op[i1]) * (@_op[i1] - 1.0) * (@_op[i1] + 1.0)
				end
			else
				x1 = 0.0
				for i2 in 0 ... i1
					if @_con[i2][i1] > 0
						x1 += @_dp[i2] * @_w[i2][i1]
					end
				end
				if @_f_type == 0
					@_dp[i1] = @_op[i1] * (1.0 - @_op[i1]) * x1
				else
					@_dp[i1] = 0.5 * (@_op[i1] - 1.0) * (@_op[i1] + 1.0) * x1
				end
			end
						# 重みの修正
			for i2 in i1+1 ... @_nou
				if @_con[i1][i2] == 1 or @_con[i1][i2] == 2
					x1           = @_eata * @_dp[i1] * @_op[i2] + @_alpha * @_w[i2][i1]
					@_w[i2][i1]  = x1
					@_w[i1][i2] += x1
				end
			end
						# バイアスの修正
			if @_con[i1][i1] >= -2
				x1           = @_eata * @_dp[i1] + @_alpha * @_w[i1][i1]
				@_w[i1][i1]  = x1
				@_theta[i1] += x1
			end
		end
	end
	
	########################################################
	# 与えられた入力パターンに対する各ユニットの出力の計算
	########################################################
	
	def Forward()
	
		i1 = @_nou - @_noiu - 1
		while i1 > -1
	
			sum = -@_theta[i1]
	
			for i2 in i1+1 ... @_nou
				if @_con[i1][i2] > 0
					sum -= @_w[i1][i2] * @_op[i2]
				end
			end
	
			if @_f_type == 0
				@_op[i1] = 1.0 / (1.0 + Math.exp(sum))
			else
				@_op[i1] = 1.0 - 2.0 / (1.0 + Math.exp(sum))
			end
			i1 -= 1
		end
	end
	
	#############################
	# 学習の実行
	#      p : 認識パターン
	#      m_tri : 最大学習回数
	#############################
	
	def Learn(p, m_tri)
	
		k0 = -1
				# エラーチェック
		if @_noiu != p._noiu or @_noou != p._noou
			print("***error  入力または出力ユニットの数が違います\n")
		end
	
		for i1 in 0 ... m_tri
				# パターンを与える順番の決定
			if @_order == 0   # 順番
				k0 += 1
				if k0 >= p._noip
					k0 = 0
				end
			else   # ランダム
				k0 = Integer(rand(0) * p._noip)
				if k0 >= p._noip
					k0 = p._noip - 1
				end
			end
				# 出力ユニットの結果を計算
			k1 = @_nou - @_noiu
			for i2 in 0 ... @_noiu
				@_op[k1+i2] = p._iptn[k0][i2]
			end
	
			Forward()
				# 重みとバイアスの修正
			Err_back(p._optn[k0])
		end
	end
	
	################################################
	# 与えられた対象の認識と出力
	#      p : 認識パターン
	#      pr : =0 : 出力を行わない
	#           =1 : 出力を行う
	#           =2 : 出力を行う(未学習パターン)
	#      tri : 現在の学習回数
	#      return : 誤って認識したパターンの数
	################################################
	
	def Recog(p, pr, tri)
	
		no = 0
				# ファイルのオープン
		if @_p_type < 0 and pr > 0
			if pr == 1
				out = open(@_o_file, "w")
				out.print("***学習パターン***\n\n")
			else
				out = open(@_o_file, "a")
				out.print("\n***未学習パターン***\n\n")
			end
		end
				# 各パターンに対する出力
		for i1 in 0 ... p._noip
						# 入力パターンの設定
			k1 = @_nou - @_noiu
			for i2 in 0 ... @_noiu
				@_op[k1+i2] = p._iptn[i1][i2]
			end
						# 出力の計算
			Forward()
						# 結果の表示
			if @_p_type != 0 and pr > 0
	
				printf("入力パターン%4d    ", (i1+1))
				for i2 in 0 ... @_noiu
					printf("%5.2f", @_op[k1+i2])
					if i2 == @_noiu-1
						print("\n")
					else
						if ((i2+1) % 10) == 0
							print("\n                    ")
						end
					end
				end
	
				print("\n    出力パターン(理想)   ")
				for i2 in 0 ... @_noou
					printf("%10.3f", p._optn[i1][i2])
					if i2 == @_noou-1
						print("\n")
					else
						if ((i2+1) % 5) == 0
							print("\n                         ")
						end
					end
				end
			end
	
			sw = 0
			if @_p_type != 0 and pr > 0
				print("                (実際)   ")
			end
			for i2 in 0 ... @_noou
				if @_p_type != 0 and pr > 0
					printf("%10.3f", @_op[i2])
					if i2 == @_noou-1
						print("\n")
					else
						if ((i2+1) % 5) == 0
							print("\n                         ")
						end
					end
				end
				if (@_op[i2]-p._optn[i1][i2]).abs() > @_eps
					sw = 1
				end
			end
	
			if sw > 0
				no += 1
			end
	
			if @_p_type < 0 and pr > 0
	
				out.printf("入力パターン%4d    ", (i1+1))
				for i2 in 0 ... @_noiu
					out.printf("%5.2f", @_op[k1+i2])
					if i2 == @_noiu-1
						out.print("\n")
					else
						if ((i2+1) % 10) == 0
							out.print("\n                    ")
						end
					end
				end
	
				out.print("\n    出力パターン(理想)   ")
				for i2 in 0 ... @_noou
					out.printf("%10.3f", p._optn[i1][i2])
					if i2 == @_noou-1
						out.print("\n")
					else
						if ((i2+1) % 5) == 0
							out.print("\n                         ")
						end
					end
				end
	
				out.print("                (実際)   ")
				for i2 in 0 ... @_noou
					out.printf("%10.3f", @_op[i2])
					if i2 == @_noou-1
						out.print("\n")
					else
						if ((i2+1) % 5) == 0
							out.print("\n                         ")
						end
					end
				end
			end
	
			if @_p_type != 0 and pr > 0
				$stdin.gets()
			end
		end
				# 重みの出力
		if (@_p_type < -1 or @_p_type > 1) and pr == 1
	
			print("    重み")
			for i1 in 0 ... @_nou-@_noiu
				printf("      to%4d from   ", (i1+1))
				ln = -1
				for i2 in 0 ... @_nou
					if @_con[i1][i2] > 0
						if ln <= 0
							if ln < 0
								ln = 0
							else
								print("\n                    ")
							end
						end
						printf("%4d%11.3f", i2+1, @_w[i1][i2])
						ln += 1
						if ln == 4
							ln = 0
						end
					end
				end
	
				print("\n")
			end
	
			print("\n    バイアス   ")
			ln = 0
			for i1 in 0 ... @_nou-@_noiu
				printf("%4d%11.3f", i1+1, @_theta[i1])
				ln += 1
				if ln == 4 and i1 != @_nou-@_noiu-1
					ln = 0
					print("\n               ")
				end
			end
			print("\n")
	
			$stdin.gets()
		end
	
		if @_p_type < 0 and pr == 1
	
			out.print("    重み\n")
			for i1 in 0 ... @_nou-@_noiu
				out.printf("      to%4d from   ", (i1+1))
				ln = -1
				for i2 in 0 ... @_nou
					if @_con[i1][i2] > 0
						if ln <= 0
							if ln < 0
								ln = 0
							else
								out.print("\n                    ")
							end
						end
						out.printf("%4d%11.3f", i2+1, @_w[i1][i2])
						ln += 1
						if ln == 4
							ln = 0
						end
					end
				end
	
				out.print("\n")
			end
	
			out.print("\n    バイアス   ")
			ln = 0
			for i1 in 0 ... @_nou-@_noiu
				out.printf("%4d%11.3f", i1+1, @_theta[i1])
				ln += 1
				if ln == 4 and i1 != @_nou-@_noiu-1
					ln = 0
					out.print("\n               ")
				end
			end
	
			if ln != 0
				out.print("\n")
			end
		end
	
		if @_p_type < 0 and pr > 0
			out.close()
		end
	
		return no
	end
end

ct = 0
no = 1
				# エラー
if ARGV.length != 2
	print("***error   入力データファイル名を指定して下さい\n")

else
				# ネットワークの定義
	net   = Backpr.new(ARGV[0], ARGV[1])
				# 学習パターン等の入力
	print("学習回数は? ")
	m_tri = Integer($stdin.gets())
	print("何回毎に収束を確認しますか? ")
	conv = Integer($stdin.gets())
	print("学習パターンのファイル名は? ")
	f_name = $stdin.gets().strip()
	dt1    = BackData.new(f_name)
				# 学習
	while ct < m_tri and no > 0

		if (ct + conv) < m_tri
			tri = conv
		else
			tri = m_tri - ct
		end
		ct += tri

		net.Learn(dt1, tri)   # 学習

		no = net.Recog(dt1, 0, ct)   # 学習対象の認識

		print("   回数 " + String(ct) + " 誤って認識したパターン数 " + String(no) + "\n")
	end

	no = net.Recog(dt1, 1, ct)   # 学習対象の認識と出力
				# 未学習パターンの認識
	print("未学習パターンの認識を行いますか?(=1:行う,=0:行わない) ")
	sw = Integer($stdin.gets())

	if sw > 0
		print("未学習パターンのファイル名は? ")
		f_name = $stdin.gets().strip()
		dt2    = BackData.new(f_name)
		no     = net.Recog(dt2, 2, ct)   # 未学習対象の認識と出力
	end
end

=begin
------------------------制御データ----------------
誤差 0.1 出力 -2 出力ファイル kekka
順番 0 η 0.5 α 0.8

------------------------構造データ----------------
入力ユニット数 2 出力ユニット数 1 関数タイプ 0
隠れ層の数 1 各隠れ層のユニット数(下から) 1
バイアス入力ユニット数 1
 ユニット番号:出力ユニットから順に番号付け
 入力方法:=-3:固定,=-2:入力後学習,=-1:乱数(default,[-0.1,0.1]))
 値:バイアス値(ー2またはー3の時)または一様乱数の範囲(下限,上限)
1 -1 -0.05 0.05
接続方法の数 2
 ユニット番号:ユニットk1からk2を,k3からk4に接続
 接続方法:=0:接続なし,=1:乱数,=2:重み入力後学習,=3:重み固定
 値:重み(2または3の時)または一様乱数の範囲(1の時:下限,上限)
3 4 1 2 1 -0.1 0.1
2 2 1 1 1 -0.1 0.1

------------------------学習データ----------------
パターンの数 4 入力ユニット数 2 出力ユニット数 1
入力1 0 0
 出力1 0
入力2 0 1
 出力2 1
入力3 1 0
 出力3 1
入力4 1 1
 出力4 0

------------------------認識データ----------------
パターンの数 4 入力ユニット数 2 出力ユニット数 1
入力1 0 0
 出力1 0
入力2 0 1
 出力2 1
入力3 1 0
 出力3 1
入力4 1 1
 出力4 0
=end