------------------------入力ファイル--------------
最大試行回数 100 入力セルの数 2 出力セルの数 2 訓練例の数 4 乱数 123
入力データファイル or.dat
------------------------or.dat--------------------
OR演算の訓練例.各行における最後の2つのデータが目標出力値
-1 -1 -1 1
-1 1 1 -1
1 -1 1 -1
1 1 1 -1
----------------------プログラム-----------------
/****************************/
/* Winner-Take-All Groups */
/* coded by Y.Suganuma */
/****************************/
import java.io.*;
import java.util.Random;
import java.util.StringTokenizer;
public class Test {
/****************/
/* main program */
/****************/
public static void main(String args[]) throws IOException, FileNotFoundException
{
int max, n, o, p, seed;
StringTokenizer str;
String name;
if (args.length > 0) {
// 基本データの入力
BufferedReader in = new BufferedReader(new FileReader(args[0]));
str = new StringTokenizer(in.readLine(), " ");
str.nextToken();
max = Integer.parseInt(str.nextToken());
str.nextToken();
p = Integer.parseInt(str.nextToken());
str.nextToken();
o = Integer.parseInt(str.nextToken());
str.nextToken();
n = Integer.parseInt(str.nextToken());
str.nextToken();
seed = Integer.parseInt(str.nextToken());
str = new StringTokenizer(in.readLine(), " ");
str.nextToken();
name = str.nextToken();
in.close();
// ネットワークの定義
Winner net = new Winner (max, n, o, p, seed);
net.input(name);
// 学習と結果の出力
if (args.length == 1)
net.learn(0, "");
else
net.learn(1, args[1]);
}
// エラー
else {
System.out.print("***error 入力データファイル名を指定して下さい\n");
System.exit(1);
}
}
}
/**********************/
/* Winnerクラスの定義 */
/**********************/
class Winner {
private int max; // 最大学習回数
private int n; // 訓練例の数
private int o; // 出力セルの数
private int p; // 入力セルの数
private int W_p[][]; // 重み(ポケット)
private int W[][]; // 重み
private int E[][]; // 訓練例
private int C[][]; // 各訓練例に対する正しい出力
private int Ct[]; // 作業領域
private Random rn; // 乱数
/******************/
/* コンストラクタ */
/******************/
Winner (int max_i, int n_i, int o_i, int p_i, int seed)
{
/*
設定
*/
max = max_i;
n = n_i;
o = o_i;
p = p_i;
rn = new Random(seed); // 乱数の初期設定
/*
領域の確保
*/
E = new int [n][p+1];
W_p = new int [o][p+1];
W = new int [o][p+1];
C = new int [n][o];
Ct = new int [o];
}
/******************************************/
/* 訓練例の分類 */
/* return : 正しく分類した訓練例の数 */
/******************************************/
int bunrui()
{
int cor, i1, i2, i3, mx = 0, mx_v = 0, num = 0, s, sw = 0;
for (i1 = 0; i1 < n; i1++) {
cor = 0;
for (i2 = 0; i2 < o; i2++) {
if (C[i1][i2] == 1)
cor = i2;
s = 0;
for (i3 = 0; i3 <= p; i3++)
s += W[i2][i3] * E[i1][i3];
if (i2 == 0) {
mx = 0;
mx_v = s;
}
else {
if (s > mx_v) {
mx = i2;
mx_v = s;
sw = 0;
}
else {
if (s == mx_v)
sw = 1;
}
}
}
if (sw == 0 && cor == mx)
num++;
}
return num;
}
/**************************/
/* 学習データの読み込み */
/* name : ファイル名 */
/**************************/
void input (String name) throws IOException, FileNotFoundException
{
int i1, i2;
StringTokenizer str;
BufferedReader st = new BufferedReader(new FileReader(name));
str = new StringTokenizer(st.readLine(), " ");
for (i1 = 0; i1 < n; i1++) {
E[i1][0] = 1;
str = new StringTokenizer(st.readLine(), " ");
for (i2 = 1; i2 <= p; i2++)
E[i1][i2] = Integer.parseInt(str.nextToken());
for (i2 = 0; i2 < o; i2++)
C[i1][i2] = Integer.parseInt(str.nextToken());
}
st.close();
}
/*********************************/
/* 学習と結果の出力 */
/* pr : =0 : 画面に出力 */
/* =1 : ファイルに出力 */
/* name : 出力ファイル名 */
/*********************************/
void learn(int pr, String name) throws FileNotFoundException
{
int i1, i2, i3, mx = 0, mx_v = 0, n_tri, s, sw;
int num[] = new int [1];
// 学習
n_tri = pocket(num);
// 結果の出力
if (pr == 0) {
System.out.print("重み\n");
for (i1 = 0; i1 < o; i1++) {
for (i2 = 0; i2 <= p; i2++)
System.out.print(" " + W_p[i1][i2]);
System.out.println();
}
System.out.print("分類結果\n");
for (i1 = 0; i1 < n; i1++) {
sw = 0;
for (i2 = 0; i2 < o; i2++) {
s = 0;
for (i3 = 0; i3 <= p; i3++)
s += W_p[i2][i3] * E[i1][i3];
if (i2 == 0) {
mx_v = s;
mx = 0;
}
else {
if (s > mx_v) {
sw = 0;
mx_v = s;
mx = i2;
}
else {
if (s == mx_v)
sw = 1;
}
}
}
for (i2 = 1; i2 <= p; i2++)
System.out.print(" " + E[i1][i2]);
System.out.print(" Cor ");
for (i2 = 0; i2 < o; i2++)
System.out.print(" " + C[i1][i2]);
if (sw > 0)
mx = -1;
System.out.println(" Res " + (mx+1));
}
}
else {
PrintStream out = new PrintStream(new FileOutputStream(name));
out.print("重み\n");
for (i1 = 0; i1 < o; i1++) {
for (i2 = 0; i2 <= p; i2++)
out.print(" " + W_p[i1][i2]);
out.println();
}
out.print("分類結果\n");
for (i1 = 0; i1 < n; i1++) {
sw = 0;
for (i2 = 0; i2 < o; i2++) {
s = 0;
for (i3 = 0; i3 <= p; i3++)
s += W_p[i2][i3] * E[i1][i3];
if (i2 == 0) {
mx_v = s;
mx = 0;
}
else {
if (s > mx_v) {
sw = 0;
mx_v = s;
mx = i2;
}
else {
if (s == mx_v)
sw = 1;
}
}
}
for (i2 = 1; i2 <= p; i2++)
out.print(" " + E[i1][i2]);
out.print(" Cor ");
for (i2 = 0; i2 < o; i2++)
out.print(" " + C[i1][i2]);
if (sw > 0)
mx = -1;
out.println(" Res " + (mx+1));
}
out.close();
}
if (n == num[0])
System.out.print(" !!すべてを分類(試行回数:" + n_tri + ")\n");
else
System.out.print(" !!" + num[0] + " 個を分類\n");
}
/********************************************/
/* Pocket Algorith with Ratcet */
/* num_p : 正しく分類した訓練例の数 */
/* return : =0 : 最大学習回数 */
/* >0 : すべてを分類(回数) */
/********************************************/
int pocket(int num_p[])
{
int cor, count = 0, i1, i2, k, mx = 0, num, run = 0, run_p = 0, s, sw = -1, sw1;
/*
初期設定
*/
num_p[0] = 0;
for (i1 = 0; i1 < o; i1++) {
for (i2 = 0; i2 <= p; i2++)
W[i1][i2] = 0;
}
/*
実行
*/
while (sw < 0) {
// 終了チェック
count++;;
if (count > max)
sw = 0;
else {
// 訓練例の選択
k = (int)(rn.nextDouble() * n);
if (k >= n)
k = n - 1;
// 出力の計算
sw1 = 0;
cor = -1;
for (i1 = 0; i1 < o; i1++) {
if (C[k][i1] == 1)
cor = i1;
s = 0;
for (i2 = 0; i2 <= p; i2++)
s += W[i1][i2] * E[k][i2];
Ct[i1] = s;
if (i1 == 0)
mx = 0;
else {
if (s > Ct[mx]) {
mx = i1;
sw1 = 0;
}
else {
if (s == Ct[mx]) {
sw1 = 1;
if (cor >= 0 && mx == cor)
mx = i1;
}
}
}
}
// 正しい分類
if (sw1 == 0 && cor == mx) {
run++;
if (run > run_p) {
num = bunrui();
if (num > num_p[0]) {
num_p[0] = num;
run_p = run;
for (i1 = 0; i1 < o; i1++) {
for (i2 = 0; i2 <= p; i2++)
W_p[i1][i2] = W[i1][i2];
}
if (num == n)
sw = count;
}
}
}
// 誤った分類
else {
run = 0;
for (i1 = 0; i1 <= p; i1++) {
W[cor][i1] += E[k][i1];
W[mx][i1] -= E[k][i1];
}
}
}
}
return sw;
}
}