# -*- coding: UTF-8 -*- import sys from math import * from random import * import numpy as np ############################# # Competitionクラスの定義 # coded by Y.Suganuma ############################# class Competition : ######################### # コンストラクタ # max : 最大学習回数 # n : 訓練例の数 # o : 出力セルの数 # p : 入力セルの数 # sig : 重み修正係数 ######################### def __init__(self, max_i, n_i, o_i, p_i, sig_i) : # 設定 self.max = max_i self.n = n_i self.o = o_i self.p = p_i self.sig = sig_i # 領域の確保 self.E = np.empty((self.n, self.p), np.int) # 訓練例 self.W = np.empty((self.o, self.p), np.float) # 重み #*************************/ # 学習データの読み込み */ # name : ファイル名 */ #*************************/ def Input(self, name) : f = open(name, "r") for i1 in range(0, self.n) : s = f.readline().split() for i2 in range(0, self.p) : self.E[i1][i2] = int(s[i2]) f.close() ################################# # 学習と結果の出力 # pr : =0 : 画面に出力 # =1 : ファイルに出力 # name : 出力ファイル名 ################################# def Learn(self, pr, name="") : # 初期設定 mx_v = 0.0 mx = 0 for i1 in range(0, self.o) : sum = 0.0 for i2 in range(0, self.p) : self.W[i1][i2] = random() sum += self.W[i1][i2] sum = 1.0 / sum for i2 in range(0, self.p) : self.W[i1][i2] *= sum # 学習 for count in range(0, self.max) : # 訓練例の選択 k = int(random() * self.n) if k >= self.n : k = self.n - 1 # 出力の計算 for i1 in range(0, self.o) : s = 0.0 for i2 in range(0, self.p) : s += self.W[i1][i2] * self.E[k][i2] if i1 == 0 or s > mx_v : mx = i1 mx_v = s # 重みの修正 sum = 0.0 for i1 in range(0, self.p) : sum += self.E[k][i1] for i1 in range(0, self.p) : self.W[mx][i1] += self.sig * (self.E[k][i1] / sum - self.W[mx][i1]) # 出力 if pr == 0 : out = sys.stdout else : out = open(name, "w") out.write("分類結果\n") for i1 in range(0, self.n) : for i2 in range(0, self.p) : out.write(" " + str(self.E[i1][i2])) out.write(" Res ") for i2 in range(0, self.o) : s = 0.0 for i3 in range(0, self.p) : s += self.W[i2][i3] * self.E[i1][i3] if i2 == 0 or s > mx_v : mx = i2 mx_v = s out.write(" " + str(mx+1) + "\n") ---------------------------------- # -*- coding: UTF-8 -*- import numpy as np import sys from math import * from random import * from function import Competition ########################## # 競合学習 # coded by Y.Suganuma ########################## if len(sys.argv) > 1 : # 基本データの入力 f = open(sys.argv[1], "r") s = f.readline().split() max = int(s[1]) p = int(s[3]) o = int(s[5]) n = int(s[7]) sig = float(s[9]) s = f.readline().split() name = s[1] f.close() # ネットワークの定義 net = Competition(max, n, o, p, sig) net.Input(name) # 学習と結果の出力 if len(sys.argv) == 2 : net.Learn(0) else : net.Learn(1, sys.argv[2]) else : print("***error 入力ファイル名を指定して下さい") ------------------------入力ファイル-------------- 最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9 係数(0~1) 0.1 入力データファイル pat.dat ------------------------pat.dat------------------- 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0