# -*- coding: UTF-8 -*-
import sys
from math import *
from random import *
import numpy as np
#############################
# Winnerクラスの定義
# coded by Y.Suganuma
#############################
class Winner :
#########################
# コンストラクタ
# max : 最大学習回数
# n : 訓練例の数
# o : 出力セルの数
# p : 入力セルの数
#########################
def __init__(self, max_i, n_i, o_i, p_i) :
# 設定
self.max = max_i
self.n = n_i
self.o = o_i
self.p = p_i
# 領域の確保
self.E = np.empty((self.n, self.p+1), np.int) # 訓練例
self.C = np.empty((self.n, self.o), np.int) # 各訓練例に対する正しい出力
self.W_p = np.empty((self.o, self.p+1), np.int) # 重み(ポケット)
self.W = np.empty((self.o, self.p+1), np.int) # 重み
self.Ct = np.empty(self.o, np.int) # 作業領域
#########################################
# 訓練例の分類
# return : 正しく分類した訓練例の数
#########################################
def Bunrui(self) :
mx = 0
mx_v = 0
num = 0
sw = 0
for i1 in range(0, self.n) :
cor = 0
for i2 in range(0, self.o) :
if self.C[i1][i2] == 1 :
cor = i2
s = 0
for i3 in range(0, self.p+1) :
s += self.W[i2][i3] * self.E[i1][i3]
if i2 == 0 :
mx = 0
mx_v = s
else :
if s > mx_v :
mx = i2
mx_v = s
sw = 0
else :
if s == mx_v :
sw = 1
if sw == 0 and cor == mx :
num += 1
return num
#*************************/
# 学習データの読み込み */
# name : ファイル名 */
#*************************/
def Input(self, name) :
f = open(name, "r")
f.readline()
for i1 in range(0, self.n) :
self.E[i1][0] = 1
s = f.readline().split()
for i2 in range(1, self.p+1) :
self.E[i1][i2] = int(s[i2-1])
for i2 in range(0, self.o) :
self.C[i1][i2] = int(s[self.p+i2])
f.close()
#################################
# 学習と結果の出力
# pr : =0 : 画面に出力
# =1 : ファイルに出力
# name : 出力ファイル名
#################################
def Learn(self, pr, name="") :
mx = 0
mx_v = 0
num = np.empty(1, np.int)
n_tri = self.Pocket(num)
if pr == 0 :
out = sys.stdout
else :
out = open(name, "w")
out.write("重み\n")
for i1 in range(0, self.o) :
for i2 in range(0, self.p+1) :
out.write(" " + str(self.W_p[i1][i2]))
out.write("\n")
out.write("分類結果\n")
for i1 in range(0, self.n) :
sw = 0
for i2 in range(0, self.o) :
s = 0
for i3 in range(0, self.p+1) :
s += self.W_p[i2][i3] * self.E[i1][i3]
if i2 == 0 :
mx_v = s
mx = 0
else :
if s > mx_v :
sw = 0
mx_v = s
mx = i2
else :
if s == mx_v :
sw = 1
for i2 in range(1, self.p+1) :
out.write(" " + str(self.E[i1][i2]))
out.write(" Cor ")
for i2 in range(0, self.o) :
out.write(" " + str(self.C[i1][i2]))
if sw > 0 :
mx = -1
out.write(" Res " + str(mx+1) + "\n")
if self.n == num[0] :
print(" !!すべてを分類(試行回数:" + str(n_tri) + ")")
else :
print(" !!" + str(num[0]) + " 個を分類")
############################################
# Pocket Algorith with Ratcet
# num_p : 正しく分類した訓練例の数
# return : =0 : 最大学習回数
# >0 : すべてを分類(回数)
############################################
def Pocket(self, num_p) :
# 初期設定
count = 0
mx = 0
run = 0
run_p = 0
sw = -1
num_p[0] = 0
for i1 in range(0, self.o) :
for i2 in range(0, self.p+1) :
self.W[i1][i2] = 0
# 実行
while sw < 0 :
# 終了チェック
count += 1
if count > self.max :
sw = 0
else :
# 訓練例の選択
k = int(random() * self.n)
if k >= self.n :
k = self.n - 1
# 出力の計算
sw1 = 0
cor = -1
for i1 in range(0, self.o) :
if self.C[k][i1] == 1 :
cor = i1
s = 0
for i2 in range(0, self.p+1) :
s += self.W[i1][i2] * self.E[k][i2]
self.Ct[i1] = s
if i1 == 0 :
mx = 0
else :
if s > self.Ct[mx] :
mx = i1
sw1 = 0
else :
if s == self.Ct[mx] :
sw1 = 1
if cor >= 0 and mx == cor :
mx = i1
# 正しい分類
if sw1 == 0 and cor == mx :
run += 1
if run > run_p :
num = self.Bunrui()
if num > num_p[0] :
num_p[0] = num
run_p = run
for i1 in range(0, self.o) :
for i2 in range(0, self.p+1) :
self.W_p[i1][i2] = self.W[i1][i2]
if num == self.n :
sw = count
# 誤った分類
else :
run = 0
for i1 in range(0, self.p+1) :
self.W[cor][i1] += self.E[k][i1]
self.W[mx][i1] -= self.E[k][i1]
return sw
----------------------------------
# -*- coding: UTF-8 -*-
import numpy as np
import sys
from math import *
from random import *
from function import Winner
####################################
# Winner-Take-All Groups
# coded by Y.Suganuma
####################################
if len(sys.argv) > 1 :
# 基本データの入力
f = open(sys.argv[1], "r")
s = f.readline().split()
max = int(s[1])
p = int(s[3])
o = int(s[5])
n = int(s[7])
s = f.readline().split()
name = s[1]
f.close()
# ネットワークの定義
net = Winner(max, n, o, p)
net.Input(name)
# 学習と結果の出力
if len(sys.argv) == 2 :
net.Learn(0)
else :
net.Learn(1, sys.argv[2])
else :
print("***error 入力ファイル名を指定して下さい")
------------------------入力ファイル--------------
最大試行回数 100 入力セルの数 2 出力セルの数 2 訓練例の数 4
入力データファイル or.dat
------------------------or.dat--------------------
OR演算の訓練例.最後の2つのデータが目標出力値
-1 -1 -1 1
-1 1 1 -1
1 -1 1 -1
1 1 1 -1