競合学習

# -*- coding: UTF-8 -*-
import sys
from math import *
from random import *
import numpy as np

#############################
# Competitionクラスの定義
#      coded by Y.Suganuma
#############################

class Competition :

	#########################
	# コンストラクタ
	#      max : 最大学習回数
	#      n : 訓練例の数
	#      o : 出力セルの数
	#      p : 入力セルの数
	#      sig : 重み修正係数
	#########################

	def __init__(self, max_i, n_i, o_i, p_i, sig_i) :
			# 設定
		self.max = max_i
		self.n   = n_i
		self.o   = o_i
		self.p   = p_i
		self.sig = sig_i
			# 領域の確保
		self.E = np.empty((self.n, self.p), np.int)   # 訓練例
		self.W = np.empty((self.o, self.p), np.float)   # 重み
	
	#*************************/
	# 学習データの読み込み   */
	#      name : ファイル名 */
	#*************************/
	
	def Input(self, name) :
	
		f = open(name, "r")
	
		for i1 in range(0, self.n) :
			s = f.readline().split()
			for i2 in range(0, self.p) :
				self.E[i1][i2] = int(s[i2])
	
		f.close()
	
	#################################
	# 学習と結果の出力
	#      pr : =0 : 画面に出力
	#           =1 : ファイルに出力
	#      name : 出力ファイル名
	#################################
	
	def Learn(self, pr, name="") :
			# 初期設定
		mx_v = 0.0
		mx   = 0
		for i1 in range(0, self.o) :
			sum = 0.0
			for i2 in range(0, self.p) :
				self.W[i1][i2]  = random()
				sum            += self.W[i1][i2]
			sum = 1.0 / sum
			for i2 in range(0, self.p) :
				self.W[i1][i2] *= sum
			# 学習
		for count in range(0, self.max) :
					# 訓練例の選択
			k = int(random() * self.n)
			if k >= self.n :
				k = self.n - 1
					# 出力の計算
			for i1 in range(0, self.o) :
				s = 0.0
				for i2 in range(0, self.p) :
					s += self.W[i1][i2] * self.E[k][i2]
				if i1 == 0 or s > mx_v :
					mx   = i1
					mx_v = s
					# 重みの修正
			sum = 0.0
			for i1 in range(0, self.p) :
				sum += self.E[k][i1]
			for i1 in range(0, self.p) :
				self.W[mx][i1] += self.sig * (self.E[k][i1] / sum - self.W[mx][i1])
			# 出力
		if pr == 0 :
			out = sys.stdout
		else :
			out = open(name, "w")
	
		out.write("分類結果\n")
		for i1 in range(0, self.n) :
			for i2 in range(0, self.p) :
				out.write(" " + str(self.E[i1][i2]))
			out.write(" Res ")
			for i2 in range(0, self.o) :
				s = 0.0
				for i3 in range(0, self.p) :
					s += self.W[i2][i3] * self.E[i1][i3]
				if i2 == 0 or s > mx_v :
					mx   = i2
					mx_v = s
			out.write(" " + str(mx+1) + "\n")

----------------------------------

# -*- coding: UTF-8 -*-
import numpy as np
import sys
from math import *
from random import *
from function import Competition

##########################
# 競合学習
#      coded by Y.Suganuma
##########################

if len(sys.argv) > 1 :
				# 基本データの入力
	f    = open(sys.argv[1], "r")
	s    = f.readline().split()
	max  = int(s[1])
	p    = int(s[3])
	o    = int(s[5])
	n    = int(s[7])
	sig  = float(s[9])
	s    = f.readline().split()
	name = s[1]
	f.close()
				# ネットワークの定義
	net = Competition(max, n, o, p, sig)
	net.Input(name)
				# 学習と結果の出力
	if len(sys.argv) == 2 :
		net.Learn(0)
	else :
		net.Learn(1, sys.argv[2])

else :
	print("***error   入力ファイル名を指定して下さい")

------------------------入力ファイル--------------
最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9 係数(0~1) 0.1
入力データファイル pat.dat

------------------------pat.dat-------------------
0 0 0 0 0 0 0 1 1
0 0 0 0 0 0 1 0 1
0 0 0 0 0 0 1 1 0
0 0 0 0 1 1 0 0 0
0 0 0 1 0 1 0 0 0
0 0 0 1 1 0 0 0 0
0 1 1 0 0 0 0 0 0
1 0 1 0 0 0 0 0 0
1 1 0 0 0 0 0 0 0