------------------------入力ファイル-------------- 最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9 乱数 123 係数(0~1) 0.1 入力データファイル pat.dat ------------------------pat.dat------------------- 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 ------------------------プログラム---------------- /****************************/ /* 競合学習 */ /* coded by Y.Suganuma */ /****************************/ import java.io.*; import java.util.Random; import java.util.StringTokenizer; public class Test { /****************/ /* main program */ /****************/ public static void main(String args[]) throws IOException, FileNotFoundException { double sig; int max, n, o, p, seed; String name; StringTokenizer str; if (args.length > 0) { // 基本データの入力 BufferedReader in = new BufferedReader(new FileReader(args[0])); str = new StringTokenizer(in.readLine(), " "); str.nextToken(); max = Integer.parseInt(str.nextToken()); str.nextToken(); p = Integer.parseInt(str.nextToken()); str.nextToken(); o = Integer.parseInt(str.nextToken()); str.nextToken(); n = Integer.parseInt(str.nextToken()); str = new StringTokenizer(in.readLine(), " "); str.nextToken(); seed= Integer.parseInt(str.nextToken()); str.nextToken(); sig = Double.parseDouble(str.nextToken()); str = new StringTokenizer(in.readLine(), " "); str.nextToken(); name = str.nextToken(); in.close(); // ネットワークの定義 Competition net = new Competition (max, n, o, p, seed, sig); net.input(name); // 学習と結果の出力 if (args.length == 1) net.learn(0, ""); else net.learn(1, args[1]); } // エラー else { System.out.print("***error 入力データファイル名を指定して下さい\n"); System.exit(1); } } } /***************************/ /* Competitionクラスの定義 */ /***************************/ class Competition { private int max; // 最大学習回数 private int n; // 訓練例の数 private int o; // 出力セルの数 private int p; // 入力セルの数 private int E[][]; // 訓練例 private double W[][]; // 重み private double sig; // 重み修正係数 private Random rn; // 乱数 /******************/ /* コンストラクタ */ /******************/ Competition (int max_i, int n_i, int o_i, int p_i, int seed, double sig_i) { /* 設定 */ max = max_i; n = n_i; o = o_i; p = p_i; sig = sig_i; rn = new Random(seed); // 乱数の初期設定 /* 領域の確保 */ E = new int [n][p]; W = new double [o][p]; } /**************************/ /* 学習データの読み込み */ /* name : ファイル名 */ /**************************/ void input (String name) throws IOException, FileNotFoundException { int i1, i2; StringTokenizer str; BufferedReader st = new BufferedReader(new FileReader(name)); for (i1 = 0; i1 < n; i1++) { str = new StringTokenizer(st.readLine(), " "); for (i2 = 0; i2 < p; i2++) E[i1][i2] = Integer.parseInt(str.nextToken()); } st.close(); } /*********************************/ /* 学習と結果の出力 */ /* pr : =0 : 画面に出力 */ /* =1 : ファイルに出力 */ /* name : 出力ファイル名 */ /*********************************/ void learn(int pr, String name) throws FileNotFoundException { double mx_v = 0.0, s, sum; int count, i1, i2, i3, k, mx = 0; /* 初期設定 */ for (i1 = 0; i1 < o; i1++) { sum = 0.0; for (i2 = 0; i2 < p; i2++) { W[i1][i2] = rn.nextDouble(); sum += W[i1][i2]; } sum = 1.0 / sum; for (i2 = 0; i2 < p; i2++) W[i1][i2] *= sum; } /* 学習 */ for (count = 0; count < max; count++) { // 訓練例の選択 k = (int)(rn.nextDouble() * n); if (k >= n) k = n - 1; // 出力の計算 for (i1 = 0; i1 < o; i1++) { s = 0.0; for (i2 = 0; i2 < p; i2++) s += W[i1][i2] * E[k][i2]; if (i1 == 0 || s > mx_v) { mx = i1; mx_v = s; } } // 重みの修正 sum = 0.0; for (i1 = 0; i1 < p; i1++) sum += E[k][i1]; for (i1 = 0; i1 < p; i1++) W[mx][i1] += sig * (E[k][i1] / sum - W[mx][i1]); } /* 出力 */ if (pr == 0) { System.out.print("分類結果\n"); for (i1 = 0; i1 < n; i1++) { for (i2 = 0; i2 < p; i2++) System.out.print(" " + E[i1][i2]); System.out.print(" Res "); for (i2 = 0; i2 < o; i2++) { s = 0.0; for (i3 = 0; i3 < p; i3++) s += W[i2][i3] * E[i1][i3]; if (i2 == 0 || s > mx_v) { mx = i2; mx_v = s; } } System.out.println((mx+1)); } } else { PrintStream out = new PrintStream(new FileOutputStream(name)); out.print("分類結果\n"); for (i1 = 0; i1 < n; i1++) { for (i2 = 0; i2 < p; i2++) out.print(" " + E[i1][i2]); out.print(" Res "); for (i2 = 0; i2 < o; i2++) { s = 0.0; for (i3 = 0; i3 < p; i3++) s += W[i2][i3] * E[i1][i3]; if (i2 == 0 || s > mx_v) { mx = i2; mx_v = s; } } out.println((mx+1)); } out.close(); } } }
最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9 乱数 123 係数(0~1) 0.1 入力データファイル pat.dat
0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0
分類結果 0 0 0 0 0 0 0 1 1 Res 3 0 0 0 0 0 0 1 0 1 Res 3 0 0 0 0 0 0 1 1 0 Res 3 0 0 0 0 1 1 0 0 0 Res 1 0 0 0 1 0 1 0 0 0 Res 1 0 0 0 1 1 0 0 0 0 Res 1 0 1 1 0 0 0 0 0 0 Res 2 1 0 1 0 0 0 0 0 0 Res 2 1 1 0 0 0 0 0 0 0 Res 2