Winner-Take-All Groups

# -*- coding: UTF-8 -*-
import sys
from math import *
from random import *
import numpy as np

#############################
# Winnerクラスの定義
#      coded by Y.Suganuma
#############################

class Winner :

	#########################
	# コンストラクタ
	#      max : 最大学習回数
	#      n : 訓練例の数
	#      o : 出力セルの数
	#      p : 入力セルの数
	#########################
	def __init__(self, max_i, n_i, o_i, p_i) :
			# 設定
		self.max = max_i
		self.n   = n_i
		self.o   = o_i
		self.p   = p_i
			# 領域の確保
		self.E   = np.empty((self.n, self.p+1), np.int)   # 訓練例
		self.C   = np.empty((self.n, self.o), np.int)   # 各訓練例に対する正しい出力
		self.W_p = np.empty((self.o, self.p+1), np.int)   # 重み(ポケット)
		self.W   = np.empty((self.o, self.p+1), np.int)   # 重み
		self.Ct  = np.empty(self.o, np.int)   # 作業領域
	
	#########################################
	# 訓練例の分類
	#      return : 正しく分類した訓練例の数
	#########################################
	def Bunrui(self) :
	
		mx   = 0
		mx_v = 0
		num  = 0
		sw   = 0
	
		for i1 in range(0, self.n) :
			cor = 0
			for i2 in range(0, self.o) :
				if self.C[i1][i2] == 1 :
					cor = i2
				s = 0
				for i3 in range(0, self.p+1) :
					s += self.W[i2][i3] * self.E[i1][i3]
				if i2 == 0 :
					mx   = 0
					mx_v = s
				else :
					if s > mx_v :
						mx   = i2
						mx_v = s
						sw   = 0
					else :
						if s == mx_v :
							sw = 1
	
			if sw == 0 and cor == mx :
				num += 1
	
		return num
	
	#*************************/
	# 学習データの読み込み   */
	#      name : ファイル名 */
	#*************************/
	
	def Input(self, name) :
	
		f = open(name, "r")
		f.readline()
	
		for i1 in range(0, self.n) :
			self.E[i1][0] = 1
			s = f.readline().split()
			for i2 in range(1, self.p+1) :
				self.E[i1][i2] = int(s[i2-1])
			for i2 in range(0, self.o) :
				self.C[i1][i2] = int(s[self.p+i2])
	
		f.close()
	
	#################################
	# 学習と結果の出力
	#      pr : =0 : 画面に出力
	#           =1 : ファイルに出力
	#      name : 出力ファイル名
	#################################

	def Learn(self, pr, name="") :

		mx   = 0
		mx_v = 0
	
		num   = np.empty(1, np.int)
		n_tri = self.Pocket(num)
	
		if pr == 0 :
			out = sys.stdout
		else :
			out = open(name, "w")
	
		out.write("重み\n")
		for i1 in range(0, self.o) :
			for i2 in range(0, self.p+1) :
				out.write(" " + str(self.W_p[i1][i2]))
			out.write("\n")
	
		out.write("分類結果\n")
		for i1 in range(0, self.n) :
			sw = 0
			for i2 in range(0, self.o) :
				s = 0
				for i3 in range(0, self.p+1) :
					s += self.W_p[i2][i3] * self.E[i1][i3]
				if i2 == 0 :
					mx_v = s
					mx   = 0
				else :
					if s > mx_v :
						sw   = 0
						mx_v = s
						mx   = i2
					else :
						if s == mx_v :
							sw = 1
			for i2 in range(1, self.p+1) :
				out.write(" " + str(self.E[i1][i2]))
			out.write(" Cor ")
			for i2 in range(0, self.o) :
				out.write(" " + str(self.C[i1][i2]))
			if sw > 0 :
				mx = -1
			out.write(" Res " + str(mx+1) + "\n")
	
		if self.n == num[0] :
			print("  !!すべてを分類(試行回数:" + str(n_tri) + ")")
		else :
			print("  !!" + str(num[0]) + " 個を分類")
	
	############################################
	# Pocket Algorith with Ratcet
	#      num_p : 正しく分類した訓練例の数
	#      return : =0 : 最大学習回数
	#               >0  : すべてを分類(回数)
	############################################
	def Pocket(self, num_p) :
			# 初期設定
		count    = 0
		mx       = 0
		run      = 0
		run_p    = 0
		sw       = -1
		num_p[0] = 0
	
		for i1 in range(0, self.o) :
			for i2 in range(0, self.p+1) :
				self.W[i1][i2] = 0
			# 実行
		while sw < 0 :
					# 終了チェック
			count += 1
			if count > self.max :
				sw = 0
	
			else :
					# 訓練例の選択
				k = int(random() * self.n)
				if k >= self.n :
					k = self.n - 1
					# 出力の計算
				sw1 = 0
				cor = -1
	
				for i1 in range(0, self.o) :
	
					if self.C[k][i1] == 1 :
						cor = i1
	
					s = 0
					for i2 in range(0, self.p+1) :
						s += self.W[i1][i2] * self.E[k][i2]
					self.Ct[i1] = s
	
					if i1 == 0 :
						mx = 0
					else :
						if s > self.Ct[mx] :
							mx = i1
							sw1 = 0
						else :
							if s == self.Ct[mx] :
								sw1 = 1
								if cor >= 0 and mx == cor :
									mx = i1
					# 正しい分類
				if sw1 == 0 and cor == mx :
					run += 1
					if run > run_p :
						num = self.Bunrui()
						if num > num_p[0] :
							num_p[0] = num
							run_p    = run
							for i1 in range(0, self.o) :
								for i2 in range(0, self.p+1) :
									self.W_p[i1][i2] = self.W[i1][i2]
							if num == self.n :
								sw = count
					# 誤った分類
				else :
					run  = 0
					for i1 in range(0, self.p+1) :
						self.W[cor][i1] += self.E[k][i1]
						self.W[mx][i1]  -= self.E[k][i1]
	
		return sw

----------------------------------

# -*- coding: UTF-8 -*-
import numpy as np
import sys
from math import *
from random import *
from function import Winner

####################################
# Winner-Take-All Groups
#      coded by Y.Suganuma
####################################

if len(sys.argv) > 1 :
				# 基本データの入力
	f    = open(sys.argv[1], "r")
	s    = f.readline().split()
	max  = int(s[1])
	p    = int(s[3])
	o    = int(s[5])
	n    = int(s[7])
	s    = f.readline().split()
	name = s[1]
	f.close()
				# ネットワークの定義
	net = Winner(max, n, o, p)
	net.Input(name)
				# 学習と結果の出力
	if len(sys.argv) == 2 :
		net.Learn(0)
	else :
		net.Learn(1, sys.argv[2])

else :
	print("***error   入力ファイル名を指定して下さい")

------------------------入力ファイル--------------
最大試行回数 100 入力セルの数 2 出力セルの数 2 訓練例の数 4
入力データファイル or.dat

------------------------or.dat--------------------
OR演算の訓練例.最後の2つのデータが目標出力値
-1 -1 -1 1
-1  1  1 -1
 1 -1  1 -1
 1  1  1 -1