競合学習

  以下,複数のファイル構成になっています.ファイル間の区切りを「---・・・」で示します.

------------------------入力ファイル--------------
最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9
乱数 123 係数(0~1) 0.1
入力データファイル pat.dat

------------------------pat.dat-------------------
0 0 0 0 0 0 0 1 1
0 0 0 0 0 0 1 0 1
0 0 0 0 0 0 1 1 0
0 0 0 0 1 1 0 0 0
0 0 0 1 0 1 0 0 0
0 0 0 1 1 0 0 0 0
0 1 1 0 0 0 0 0 0
1 0 1 0 0 0 0 0 0
1 1 0 0 0 0 0 0 0

------------------------プログラム----------------
/****************************/
/* 競合学習                 */
/*      coded by Y.Suganuma */
/****************************/
import java.io.*;
import java.util.Random;
import java.util.StringTokenizer;

public class Test {
	/****************/
	/* main program */
	/****************/
	public static void main(String args[]) throws IOException, FileNotFoundException
	{
		double sig;
		int max, n, o, p, seed;
		String name;
		StringTokenizer str;

		if (args.length > 0) {
					// 基本データの入力
			BufferedReader in = new BufferedReader(new FileReader(args[0]));

			str = new StringTokenizer(in.readLine(), " ");
			str.nextToken();
			max = Integer.parseInt(str.nextToken());
			str.nextToken();
			p = Integer.parseInt(str.nextToken());
			str.nextToken();
			o = Integer.parseInt(str.nextToken());
			str.nextToken();
			n = Integer.parseInt(str.nextToken());

			str = new StringTokenizer(in.readLine(), " ");
			str.nextToken();
			seed= Integer.parseInt(str.nextToken());
			str.nextToken();
			sig = Double.parseDouble(str.nextToken());

			str = new StringTokenizer(in.readLine(), " ");
			str.nextToken();
			name = str.nextToken();

			in.close();
					// ネットワークの定義
			Competition net = new Competition (max, n, o, p, seed, sig);
			net.input(name);
					// 学習と結果の出力
			if (args.length == 1)
				net.learn(0, "");
			else
				net.learn(1, args[1]);
		}
					// エラー
		else {
			System.out.print("***error   入力データファイル名を指定して下さい\n");
			System.exit(1);
		}
	}
}

/***************************/
/* Competitionクラスの定義 */
/***************************/
class Competition {

	private int max;   // 最大学習回数
	private int n;   // 訓練例の数
	private int o;   // 出力セルの数
	private int p;   // 入力セルの数
	private int E[][];   // 訓練例
	private double W[][];   // 重み
	private double sig;   // 重み修正係数
	private Random rn;   // 乱数

	/******************/
	/* コンストラクタ */
	/******************/
	Competition (int max_i, int n_i, int o_i, int p_i, int seed, double sig_i)
	{
	/*
	     設定
	*/
		max = max_i;
		n   = n_i;
		o   = o_i;
		p   = p_i;
		sig = sig_i;

		rn  = new Random(seed);   // 乱数の初期設定
	/*
	     領域の確保
	*/
		E   = new int [n][p];
		W   = new double [o][p];
	}

	/**************************/
	/* 学習データの読み込み   */
	/*      name : ファイル名 */
	/**************************/
	void input (String name) throws IOException, FileNotFoundException
	{
		int i1, i2;
		StringTokenizer str;

		BufferedReader st = new BufferedReader(new FileReader(name));

		for (i1 = 0; i1 < n; i1++) {
			str = new StringTokenizer(st.readLine(), " ");
			for (i2 = 0; i2 < p; i2++)
				E[i1][i2] = Integer.parseInt(str.nextToken());
		}

		st.close();
	}

	/*********************************/
	/* 学習と結果の出力              */
	/*      pr : =0 : 画面に出力     */
	/*           =1 : ファイルに出力 */
	/*      name : 出力ファイル名    */
	/*********************************/
	void learn(int pr, String name) throws FileNotFoundException
	{
		double mx_v = 0.0, s, sum;
		int count, i1, i2, i3, k, mx = 0;
	/*
	     初期設定
	*/
		for (i1 = 0; i1 < o; i1++) {
			sum = 0.0;
			for (i2 = 0; i2 < p; i2++) {
				W[i1][i2]  = rn.nextDouble();
				sum       += W[i1][i2];
			}
			sum = 1.0 / sum;
			for (i2 = 0; i2 < p; i2++)
				W[i1][i2] *= sum;
		}
	/*
	     学習
	*/
		for (count = 0; count < max; count++) {
					// 訓練例の選択
			k = (int)(rn.nextDouble() * n);
			if (k >= n)
				k = n - 1;
					// 出力の計算
			for (i1 = 0; i1 < o; i1++) {
				s = 0.0;
				for (i2 = 0; i2 < p; i2++)
					s += W[i1][i2] * E[k][i2];
				if (i1 == 0 || s > mx_v) {
					mx   = i1;
					mx_v = s;
				}
			}
					// 重みの修正
			sum = 0.0;
			for (i1 = 0; i1 < p; i1++)
				sum += E[k][i1];
			for (i1 = 0; i1 < p; i1++)
				W[mx][i1] += sig * (E[k][i1] / sum - W[mx][i1]);
		}
	/*
	     出力
	*/
		if (pr == 0) {

			System.out.print("分類結果\n");
			for (i1 = 0; i1 < n; i1++) {
				for (i2 = 0; i2 < p; i2++)
					System.out.print(" " + E[i1][i2]);
				System.out.print(" Res ");
				for (i2 = 0; i2 < o; i2++) {
					s = 0.0;
					for (i3 = 0; i3 < p; i3++)
	    		        s += W[i2][i3] * E[i1][i3];
					if (i2 == 0 || s > mx_v) {
						mx   = i2;
						mx_v = s;
					}
				}
				System.out.println((mx+1));
			}
		}

		else {

			PrintStream out = new PrintStream(new FileOutputStream(name));

			out.print("分類結果\n");
			for (i1 = 0; i1 < n; i1++) {
				for (i2 = 0; i2 < p; i2++)
					out.print(" " + E[i1][i2]);
				out.print(" Res ");
				for (i2 = 0; i2 < o; i2++) {
					s = 0.0;
					for (i3 = 0; i3 < p; i3++)
	    		        s += W[i2][i3] * E[i1][i3];
					if (i2 == 0 || s > mx_v) {
						mx   = i2;
						mx_v = s;
					}
				}
				out.println((mx+1));
			}

			out.close();
		}
	}
}
		

  コンパイルした後,

java Test 入力ファイル名 出力ファイル名

と入力してやれば実行できます.出力ファイル名は,結果を出力するファイルの名前であり,省略すると画面に出力されます.また,入力ファイル名は,実行に必要なデータを記述したファイルの名前であり,たとえば以下のような形式で作成します.
最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9
乱数 123 係数(0~1) 0.1
入力データファイル pat.dat		
  日本語で記述した部分(「最大試行回数」,「入力セルの数」等)は,次に続くデータの説明ですのでどのように修正しても構いませんが,削除したり,または,複数の文(間に半角のスペースを入れる)にするようなことはしないでください.各データの意味は以下に示す通りです.

最大試行回数

  学習回数を入力します.この例では 1000 を与えています.

入力セルの数

  入力セル(入力ユニット)の数を入力します.この例では 9 となっています.

出力セルの数

  出力セル(出力ユニット)の数を入力します.この例では 3 となっています.

訓練例の数

  訓練例(分類すべきパターン)の数を入力します(この例では 9 ).訓練例は,「入力データファイル」の項に入力されたファイル(この例では,pat.dat )に記述します.ファイル pat.dat は,たとえば,以下のようになります.各行のデータが一つのパターンに相当しています.
	0 0 0 0 0 0 0 1 1
	0 0 0 0 0 0 1 0 1
	0 0 0 0 0 0 1 1 0
	0 0 0 0 1 1 0 0 0
	0 0 0 1 0 1 0 0 0
	0 0 0 1 1 0 0 0 0
	0 1 1 0 0 0 0 0 0
	1 0 1 0 0 0 0 0 0
	1 1 0 0 0 0 0 0 0			
乱数

  乱数の初期値です.

係数(0~1)

  重みを修正する際の係数です.0 と 1 の間の数値を入力してください(この例では,0.1 )

  上で説明したデータの元で実行すると,たとえば,以下のような出力が得られます.
分類結果
 0 0 0 0 0 0 0 1 1 Res 3
 0 0 0 0 0 0 1 0 1 Res 3
 0 0 0 0 0 0 1 1 0 Res 3
 0 0 0 0 1 1 0 0 0 Res 1
 0 0 0 1 0 1 0 0 0 Res 1
 0 0 0 1 1 0 0 0 0 Res 1
 0 1 1 0 0 0 0 0 0 Res 2
 1 0 1 0 0 0 0 0 0 Res 2
 1 1 0 0 0 0 0 0 0 Res 2		
  各行において,res の後ろに書かれた数値が,その左側に与えられたパターンが分類された結果です.たとえば,2 行目は,パターン「0 0 0 0 0 0 0 1 1」がグループ 3 に分類された(出力ユニットの内,3 番目のユニットの活性度が最も大きくなった)ことを意味しています.