# -*- coding: UTF-8 -*-
import sys
from math import *
from random import *
import numpy as np
#############################
# Competitionクラスの定義
# coded by Y.Suganuma
#############################
class Competition :
#########################
# コンストラクタ
# max : 最大学習回数
# n : 訓練例の数
# o : 出力セルの数
# p : 入力セルの数
# sig : 重み修正係数
#########################
def __init__(self, max_i, n_i, o_i, p_i, sig_i) :
# 設定
self.max = max_i
self.n = n_i
self.o = o_i
self.p = p_i
self.sig = sig_i
# 領域の確保
self.E = np.empty((self.n, self.p), np.int) # 訓練例
self.W = np.empty((self.o, self.p), np.float) # 重み
#*************************/
# 学習データの読み込み */
# name : ファイル名 */
#*************************/
def Input(self, name) :
f = open(name, "r")
for i1 in range(0, self.n) :
s = f.readline().split()
for i2 in range(0, self.p) :
self.E[i1][i2] = int(s[i2])
f.close()
#################################
# 学習と結果の出力
# pr : =0 : 画面に出力
# =1 : ファイルに出力
# name : 出力ファイル名
#################################
def Learn(self, pr, name="") :
# 初期設定
mx_v = 0.0
mx = 0
for i1 in range(0, self.o) :
sum = 0.0
for i2 in range(0, self.p) :
self.W[i1][i2] = random()
sum += self.W[i1][i2]
sum = 1.0 / sum
for i2 in range(0, self.p) :
self.W[i1][i2] *= sum
# 学習
for count in range(0, self.max) :
# 訓練例の選択
k = int(random() * self.n)
if k >= self.n :
k = self.n - 1
# 出力の計算
for i1 in range(0, self.o) :
s = 0.0
for i2 in range(0, self.p) :
s += self.W[i1][i2] * self.E[k][i2]
if i1 == 0 or s > mx_v :
mx = i1
mx_v = s
# 重みの修正
sum = 0.0
for i1 in range(0, self.p) :
sum += self.E[k][i1]
for i1 in range(0, self.p) :
self.W[mx][i1] += self.sig * (self.E[k][i1] / sum - self.W[mx][i1])
# 出力
if pr == 0 :
out = sys.stdout
else :
out = open(name, "w")
out.write("分類結果\n")
for i1 in range(0, self.n) :
for i2 in range(0, self.p) :
out.write(" " + str(self.E[i1][i2]))
out.write(" Res ")
for i2 in range(0, self.o) :
s = 0.0
for i3 in range(0, self.p) :
s += self.W[i2][i3] * self.E[i1][i3]
if i2 == 0 or s > mx_v :
mx = i2
mx_v = s
out.write(" " + str(mx+1) + "\n")
----------------------------------
# -*- coding: UTF-8 -*-
import numpy as np
import sys
from math import *
from random import *
from function import Competition
##########################
# 競合学習
# coded by Y.Suganuma
##########################
if len(sys.argv) > 1 :
# 基本データの入力
f = open(sys.argv[1], "r")
s = f.readline().split()
max = int(s[1])
p = int(s[3])
o = int(s[5])
n = int(s[7])
sig = float(s[9])
s = f.readline().split()
name = s[1]
f.close()
# ネットワークの定義
net = Competition(max, n, o, p, sig)
net.Input(name)
# 学習と結果の出力
if len(sys.argv) == 2 :
net.Learn(0)
else :
net.Learn(1, sys.argv[2])
else :
print("***error 入力ファイル名を指定して下さい")
------------------------入力ファイル--------------
最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9 係数(0~1) 0.1
入力データファイル pat.dat
------------------------pat.dat-------------------
0 0 0 0 0 0 0 1 1
0 0 0 0 0 0 1 0 1
0 0 0 0 0 0 1 1 0
0 0 0 0 1 1 0 0 0
0 0 0 1 0 1 0 0 0
0 0 0 1 1 0 0 0 0
0 1 1 0 0 0 0 0 0
1 0 1 0 0 0 0 0 0
1 1 0 0 0 0 0 0 0