# -*- coding: UTF-8 -*- import sys from math import * from random import * import numpy as np ############################# # Winnerクラスの定義 # coded by Y.Suganuma ############################# class Winner : ######################### # コンストラクタ # max : 最大学習回数 # n : 訓練例の数 # o : 出力セルの数 # p : 入力セルの数 ######################### def __init__(self, max_i, n_i, o_i, p_i) : # 設定 self.max = max_i self.n = n_i self.o = o_i self.p = p_i # 領域の確保 self.E = np.empty((self.n, self.p+1), np.int) # 訓練例 self.C = np.empty((self.n, self.o), np.int) # 各訓練例に対する正しい出力 self.W_p = np.empty((self.o, self.p+1), np.int) # 重み(ポケット) self.W = np.empty((self.o, self.p+1), np.int) # 重み self.Ct = np.empty(self.o, np.int) # 作業領域 ######################################### # 訓練例の分類 # return : 正しく分類した訓練例の数 ######################################### def Bunrui(self) : mx = 0 mx_v = 0 num = 0 sw = 0 for i1 in range(0, self.n) : cor = 0 for i2 in range(0, self.o) : if self.C[i1][i2] == 1 : cor = i2 s = 0 for i3 in range(0, self.p+1) : s += self.W[i2][i3] * self.E[i1][i3] if i2 == 0 : mx = 0 mx_v = s else : if s > mx_v : mx = i2 mx_v = s sw = 0 else : if s == mx_v : sw = 1 if sw == 0 and cor == mx : num += 1 return num #*************************/ # 学習データの読み込み */ # name : ファイル名 */ #*************************/ def Input(self, name) : f = open(name, "r") f.readline() for i1 in range(0, self.n) : self.E[i1][0] = 1 s = f.readline().split() for i2 in range(1, self.p+1) : self.E[i1][i2] = int(s[i2-1]) for i2 in range(0, self.o) : self.C[i1][i2] = int(s[self.p+i2]) f.close() ################################# # 学習と結果の出力 # pr : =0 : 画面に出力 # =1 : ファイルに出力 # name : 出力ファイル名 ################################# def Learn(self, pr, name="") : mx = 0 mx_v = 0 num = np.empty(1, np.int) n_tri = self.Pocket(num) if pr == 0 : out = sys.stdout else : out = open(name, "w") out.write("重み\n") for i1 in range(0, self.o) : for i2 in range(0, self.p+1) : out.write(" " + str(self.W_p[i1][i2])) out.write("\n") out.write("分類結果\n") for i1 in range(0, self.n) : sw = 0 for i2 in range(0, self.o) : s = 0 for i3 in range(0, self.p+1) : s += self.W_p[i2][i3] * self.E[i1][i3] if i2 == 0 : mx_v = s mx = 0 else : if s > mx_v : sw = 0 mx_v = s mx = i2 else : if s == mx_v : sw = 1 for i2 in range(1, self.p+1) : out.write(" " + str(self.E[i1][i2])) out.write(" Cor ") for i2 in range(0, self.o) : out.write(" " + str(self.C[i1][i2])) if sw > 0 : mx = -1 out.write(" Res " + str(mx+1) + "\n") if self.n == num[0] : print(" !!すべてを分類(試行回数:" + str(n_tri) + ")") else : print(" !!" + str(num[0]) + " 個を分類") ############################################ # Pocket Algorith with Ratcet # num_p : 正しく分類した訓練例の数 # return : =0 : 最大学習回数 # >0 : すべてを分類(回数) ############################################ def Pocket(self, num_p) : # 初期設定 count = 0 mx = 0 run = 0 run_p = 0 sw = -1 num_p[0] = 0 for i1 in range(0, self.o) : for i2 in range(0, self.p+1) : self.W[i1][i2] = 0 # 実行 while sw < 0 : # 終了チェック count += 1 if count > self.max : sw = 0 else : # 訓練例の選択 k = int(random() * self.n) if k >= self.n : k = self.n - 1 # 出力の計算 sw1 = 0 cor = -1 for i1 in range(0, self.o) : if self.C[k][i1] == 1 : cor = i1 s = 0 for i2 in range(0, self.p+1) : s += self.W[i1][i2] * self.E[k][i2] self.Ct[i1] = s if i1 == 0 : mx = 0 else : if s > self.Ct[mx] : mx = i1 sw1 = 0 else : if s == self.Ct[mx] : sw1 = 1 if cor >= 0 and mx == cor : mx = i1 # 正しい分類 if sw1 == 0 and cor == mx : run += 1 if run > run_p : num = self.Bunrui() if num > num_p[0] : num_p[0] = num run_p = run for i1 in range(0, self.o) : for i2 in range(0, self.p+1) : self.W_p[i1][i2] = self.W[i1][i2] if num == self.n : sw = count # 誤った分類 else : run = 0 for i1 in range(0, self.p+1) : self.W[cor][i1] += self.E[k][i1] self.W[mx][i1] -= self.E[k][i1] return sw ---------------------------------- # -*- coding: UTF-8 -*- import numpy as np import sys from math import * from random import * from function import Winner #################################### # Winner-Take-All Groups # coded by Y.Suganuma #################################### if len(sys.argv) > 1 : # 基本データの入力 f = open(sys.argv[1], "r") s = f.readline().split() max = int(s[1]) p = int(s[3]) o = int(s[5]) n = int(s[7]) s = f.readline().split() name = s[1] f.close() # ネットワークの定義 net = Winner(max, n, o, p) net.Input(name) # 学習と結果の出力 if len(sys.argv) == 2 : net.Learn(0) else : net.Learn(1, sys.argv[2]) else : print("***error 入力ファイル名を指定して下さい") ------------------------入力ファイル-------------- 最大試行回数 100 入力セルの数 2 出力セルの数 2 訓練例の数 4 入力データファイル or.dat ------------------------or.dat-------------------- OR演算の訓練例.最後の2つのデータが目標出力値 -1 -1 -1 1 -1 1 1 -1 1 -1 1 -1 1 1 1 -1