------------------------入力ファイル-------------- 最大試行回数 100 入力セルの数 2 出力セルの数 2 訓練例の数 4 乱数 123 入力データファイル or.dat ------------------------or.dat-------------------- OR演算の訓練例.各行における最後の2つのデータが目標出力値 -1 -1 -1 1 -1 1 1 -1 1 -1 1 -1 1 1 1 -1 ----------------------プログラム----------------- /****************************/ /* Winner-Take-All Groups */ /* coded by Y.Suganuma */ /****************************/ import java.io.*; import java.util.Random; import java.util.StringTokenizer; public class Test { /****************/ /* main program */ /****************/ public static void main(String args[]) throws IOException, FileNotFoundException { int max, n, o, p, seed; StringTokenizer str; String name; if (args.length > 0) { // 基本データの入力 BufferedReader in = new BufferedReader(new FileReader(args[0])); str = new StringTokenizer(in.readLine(), " "); str.nextToken(); max = Integer.parseInt(str.nextToken()); str.nextToken(); p = Integer.parseInt(str.nextToken()); str.nextToken(); o = Integer.parseInt(str.nextToken()); str.nextToken(); n = Integer.parseInt(str.nextToken()); str.nextToken(); seed = Integer.parseInt(str.nextToken()); str = new StringTokenizer(in.readLine(), " "); str.nextToken(); name = str.nextToken(); in.close(); // ネットワークの定義 Winner net = new Winner (max, n, o, p, seed); net.input(name); // 学習と結果の出力 if (args.length == 1) net.learn(0, ""); else net.learn(1, args[1]); } // エラー else { System.out.print("***error 入力データファイル名を指定して下さい\n"); System.exit(1); } } } /**********************/ /* Winnerクラスの定義 */ /**********************/ class Winner { private int max; // 最大学習回数 private int n; // 訓練例の数 private int o; // 出力セルの数 private int p; // 入力セルの数 private int W_p[][]; // 重み(ポケット) private int W[][]; // 重み private int E[][]; // 訓練例 private int C[][]; // 各訓練例に対する正しい出力 private int Ct[]; // 作業領域 private Random rn; // 乱数 /******************/ /* コンストラクタ */ /******************/ Winner (int max_i, int n_i, int o_i, int p_i, int seed) { /* 設定 */ max = max_i; n = n_i; o = o_i; p = p_i; rn = new Random(seed); // 乱数の初期設定 /* 領域の確保 */ E = new int [n][p+1]; W_p = new int [o][p+1]; W = new int [o][p+1]; C = new int [n][o]; Ct = new int [o]; } /******************************************/ /* 訓練例の分類 */ /* return : 正しく分類した訓練例の数 */ /******************************************/ int bunrui() { int cor, i1, i2, i3, mx = 0, mx_v = 0, num = 0, s, sw = 0; for (i1 = 0; i1 < n; i1++) { cor = 0; for (i2 = 0; i2 < o; i2++) { if (C[i1][i2] == 1) cor = i2; s = 0; for (i3 = 0; i3 <= p; i3++) s += W[i2][i3] * E[i1][i3]; if (i2 == 0) { mx = 0; mx_v = s; } else { if (s > mx_v) { mx = i2; mx_v = s; sw = 0; } else { if (s == mx_v) sw = 1; } } } if (sw == 0 && cor == mx) num++; } return num; } /**************************/ /* 学習データの読み込み */ /* name : ファイル名 */ /**************************/ void input (String name) throws IOException, FileNotFoundException { int i1, i2; StringTokenizer str; BufferedReader st = new BufferedReader(new FileReader(name)); str = new StringTokenizer(st.readLine(), " "); for (i1 = 0; i1 < n; i1++) { E[i1][0] = 1; str = new StringTokenizer(st.readLine(), " "); for (i2 = 1; i2 <= p; i2++) E[i1][i2] = Integer.parseInt(str.nextToken()); for (i2 = 0; i2 < o; i2++) C[i1][i2] = Integer.parseInt(str.nextToken()); } st.close(); } /*********************************/ /* 学習と結果の出力 */ /* pr : =0 : 画面に出力 */ /* =1 : ファイルに出力 */ /* name : 出力ファイル名 */ /*********************************/ void learn(int pr, String name) throws FileNotFoundException { int i1, i2, i3, mx = 0, mx_v = 0, n_tri, s, sw; int num[] = new int [1]; // 学習 n_tri = pocket(num); // 結果の出力 if (pr == 0) { System.out.print("重み\n"); for (i1 = 0; i1 < o; i1++) { for (i2 = 0; i2 <= p; i2++) System.out.print(" " + W_p[i1][i2]); System.out.println(); } System.out.print("分類結果\n"); for (i1 = 0; i1 < n; i1++) { sw = 0; for (i2 = 0; i2 < o; i2++) { s = 0; for (i3 = 0; i3 <= p; i3++) s += W_p[i2][i3] * E[i1][i3]; if (i2 == 0) { mx_v = s; mx = 0; } else { if (s > mx_v) { sw = 0; mx_v = s; mx = i2; } else { if (s == mx_v) sw = 1; } } } for (i2 = 1; i2 <= p; i2++) System.out.print(" " + E[i1][i2]); System.out.print(" Cor "); for (i2 = 0; i2 < o; i2++) System.out.print(" " + C[i1][i2]); if (sw > 0) mx = -1; System.out.println(" Res " + (mx+1)); } } else { PrintStream out = new PrintStream(new FileOutputStream(name)); out.print("重み\n"); for (i1 = 0; i1 < o; i1++) { for (i2 = 0; i2 <= p; i2++) out.print(" " + W_p[i1][i2]); out.println(); } out.print("分類結果\n"); for (i1 = 0; i1 < n; i1++) { sw = 0; for (i2 = 0; i2 < o; i2++) { s = 0; for (i3 = 0; i3 <= p; i3++) s += W_p[i2][i3] * E[i1][i3]; if (i2 == 0) { mx_v = s; mx = 0; } else { if (s > mx_v) { sw = 0; mx_v = s; mx = i2; } else { if (s == mx_v) sw = 1; } } } for (i2 = 1; i2 <= p; i2++) out.print(" " + E[i1][i2]); out.print(" Cor "); for (i2 = 0; i2 < o; i2++) out.print(" " + C[i1][i2]); if (sw > 0) mx = -1; out.println(" Res " + (mx+1)); } out.close(); } if (n == num[0]) System.out.print(" !!すべてを分類(試行回数:" + n_tri + ")\n"); else System.out.print(" !!" + num[0] + " 個を分類\n"); } /********************************************/ /* Pocket Algorith with Ratcet */ /* num_p : 正しく分類した訓練例の数 */ /* return : =0 : 最大学習回数 */ /* >0 : すべてを分類(回数) */ /********************************************/ int pocket(int num_p[]) { int cor, count = 0, i1, i2, k, mx = 0, num, run = 0, run_p = 0, s, sw = -1, sw1; /* 初期設定 */ num_p[0] = 0; for (i1 = 0; i1 < o; i1++) { for (i2 = 0; i2 <= p; i2++) W[i1][i2] = 0; } /* 実行 */ while (sw < 0) { // 終了チェック count++;; if (count > max) sw = 0; else { // 訓練例の選択 k = (int)(rn.nextDouble() * n); if (k >= n) k = n - 1; // 出力の計算 sw1 = 0; cor = -1; for (i1 = 0; i1 < o; i1++) { if (C[k][i1] == 1) cor = i1; s = 0; for (i2 = 0; i2 <= p; i2++) s += W[i1][i2] * E[k][i2]; Ct[i1] = s; if (i1 == 0) mx = 0; else { if (s > Ct[mx]) { mx = i1; sw1 = 0; } else { if (s == Ct[mx]) { sw1 = 1; if (cor >= 0 && mx == cor) mx = i1; } } } } // 正しい分類 if (sw1 == 0 && cor == mx) { run++; if (run > run_p) { num = bunrui(); if (num > num_p[0]) { num_p[0] = num; run_p = run; for (i1 = 0; i1 < o; i1++) { for (i2 = 0; i2 <= p; i2++) W_p[i1][i2] = W[i1][i2]; } if (num == n) sw = count; } } } // 誤った分類 else { run = 0; for (i1 = 0; i1 <= p; i1++) { W[cor][i1] += E[k][i1]; W[mx][i1] -= E[k][i1]; } } } } return sw; } }
最大試行回数 100 入力セルの数 2 出力セルの数 2 訓練例の数 4 乱数 123 入力データファイル or.dat
OR演算の訓練例.各行における最後の2つのデータが目標出力値 -1 -1 -1 1 -1 1 1 -1 1 -1 1 -1 1 1 1 -1
重み 1 1 1 -1 -1 -1 分類結果 -1 -1 Cor -1 1 Res 2 -1 1 Cor 1 -1 Res 1 1 -1 Cor 1 -1 Res 1 1 1 Cor 1 -1 Res 1 !!すべてを分類(試行回数:11)