パーセプトロン学習

  以下,複数のファイル構成になっています.ファイル間の区切りを「---・・・」で示します.

------------------------入力ファイル--------------
最大試行回数 100 入力セルの数 2 訓練例の数 4 乱数 123
入力データファイル or.dat

------------------------or.dat--------------------
OR演算の訓練例.各行の最後のデータが目標出力値
-1 -1 -1
-1  1  1
 1 -1  1
 1  1  1

------------------------プログラム----------------
/***********************************/
/* パーセプトロン学習              */
/* (Pocket Algorith with Ratcet) */
/*      coded by Y.Suganuma        */
/***********************************/
import java.io.*;
import java.util.Random;
import java.util.StringTokenizer;

public class Test {
	/****************/
	/* main program */
	/****************/
	public static void main(String args[]) throws IOException, FileNotFoundException
	{
		int max, n, p, seed;
		String name;
		StringTokenizer str;

		if (args.length > 0) {
					// 基本データの入力
			BufferedReader in = new BufferedReader(new FileReader(args[0]));

			str = new StringTokenizer(in.readLine(), " ");
			str.nextToken();
			max = Integer.parseInt(str.nextToken());
			str.nextToken();
			p = Integer.parseInt(str.nextToken());
			str.nextToken();
			n = Integer.parseInt(str.nextToken());
			str.nextToken();
			seed = Integer.parseInt(str.nextToken());

			str = new StringTokenizer(in.readLine(), " ");
			str.nextToken();
			name = str.nextToken();

			in.close();
					// ネットワークの定義
			Perceptron net = new Perceptron (max, n, p, seed);
			net.input(name);
					// 学習と結果の出力
			if (args.length == 1)
				net.learn(0, "");
			else
				net.learn(1, args[1]);
		}
					// エラー
		else {
			System.out.print("***error   入力データファイル名を指定して下さい\n");
			System.exit(1);
		}
	}
}

/**************************/
/* Perceptronクラスの定義 */
/**************************/
class Perceptron {

	private int max;   // 最大学習回数
	private int n;   // 訓練例の数
	private int p;   // 入力セルの数
	private int W_p[];   // 重み(ポケット)
	private int W[];   // 重み
	private int E[][];   // 訓練例
	private int C[];   // 各訓練例に対する正しい出力
	private Random rn;   // 乱数

	/******************/
	/* コンストラクタ */
	/******************/
	Perceptron(int max_i, int n_i, int p_i, int seed)
	{
	/*
	     設定
	*/
		max = max_i;
		n   = n_i;
		p   = p_i;

		rn  = new Random(seed);   // 乱数の初期設定
	/*
	     領域の確保
	*/
		E   = new int [n][p+1];
		W_p = new int [p+1];
		W   = new int [p+1];
		C   = new int [n];
	}

	/******************************************/
	/* 訓練例の分類                           */
	/*      return : 正しく分類した訓練例の数 */
	/******************************************/
	int bunrui()
	{
		int i1, i2, num = 0, s;

		for (i1 = 0; i1 < n; i1++) {
			s = 0;
			for (i2 = 0; i2 <= p; i2++)
				s += W[i2] * E[i1][i2];
			if ((s > 0 && C[i1] > 0) || (s < 0 && C[i1] < 0))
				num++;
		}

		return num;
	}

	/**************************/
	/* 学習データの読み込み   */
	/*      name : ファイル名 */
	/**************************/
	void input (String name) throws IOException, FileNotFoundException
	{
		int i1, i2;
		StringTokenizer str;

		BufferedReader st = new BufferedReader(new FileReader(name));

		st.readLine();

		for (i1 = 0; i1 < n; i1++) {
			E[i1][0] = 1;
			str = new StringTokenizer(st.readLine(), " ");
			for (i2 = 1; i2 <= p; i2++)
				E[i1][i2] = Integer.parseInt(str.nextToken());
			C[i1] = Integer.parseInt(str.nextToken());
		}

		st.close();
	}

	/*********************************/
	/* 学習と結果の出力              */
	/*      pr : =0 : 画面に出力     */
	/*           =1 : ファイルに出力 */
	/*      name : 出力ファイル名    */
	/*********************************/
	void learn(int pr, String name) throws FileNotFoundException
	{
		int i1, i2, n_tri, s;
		int num[] = new int [1];
					// 学習
		n_tri = pocket(num);
					// 結果の出力
		if (pr == 0) {

			System.out.print("重み\n");
			for (i1 = 0; i1 <= p; i1++)
				System.out.print("  " + W_p[i1]);
			System.out.println();

			System.out.print("分類結果\n");
			for (i1 = 0; i1 < n; i1++) {
				s  = 0;
				for (i2 = 0; i2 <= p; i2++)
					s += E[i1][i2] * W_p[i2];
				if (s > 0)
					s = 1;
				else
					s = (s < 0) ? -1 : 0;
				for (i2 = 1; i2 <= p; i2++)
					System.out.print(" " + E[i1][i2]);
				System.out.println(" Cor " + C[i1] + " Res " + s);
			}
		}

		else {

			PrintStream out = new PrintStream(new FileOutputStream(name));

			out.print("重み\n");
			for (i1 = 0; i1 <= p; i1++)
				out.print("  " + W_p[i1]);
			out.println();

			out.print("分類結果\n");
			for (i1 = 0; i1 < n; i1++) {
				s  = 0;
				for (i2 = 0; i2 <= p; i2++)
					s += E[i1][i2] * W_p[i2];
				if (s > 0)
					s = 1;
				else
					s = (s < 0) ? -1 : 0;
				for (i2 = 1; i2 <= p; i2++)
					out.print(" " + E[i1][i2]);
				out.println(" Cor " + C[i1] + " Res " + s);
			}

			out.close();
		}

		if (n == num[0])
			System.out.print("  !!すべてを分類(試行回数:" + n_tri + ")\n");
		else
			System.out.print("  !!" + num[0] + " 個を分類\n");
	}

	/********************************************/
	/* Pocket Algorith with Ratcet              */
	/*      num_p : 正しく分類した訓練例の数    */
	/*      return : =0 : 最大学習回数          */
	/*               >0  : すべてを分類(回数) */
	/********************************************/
	int pocket(int num_p[])
	{
		int count = 0, i1, k, num, run = 0, run_p = 0, s, sw = -1;
	/*
	     初期設定
	*/
		num_p[0] = 0;

		for (i1 = 0; i1 <= p; i1++)
			W[i1] = 0;
	/*
	     実行
	*/
		while (sw < 0) {

			count++;
			if (count > max)
				sw = 0;

			else {
					// 訓練例の選択
				k = (int)(rn.nextDouble() * n);
				if (k >= n)
					k = n - 1;
					// 出力の計算
				s = 0;
				for (i1 = 0; i1 <= p; i1++)
					s += W[i1] * E[k][i1];
					// 正しい分類
				if ((s > 0 && C[k] > 0) || (s < 0 && C[k] < 0)) {
					run++;
					if (run > run_p) {
						num = bunrui();
						if (num > num_p[0]) {
							num_p[0] = num;
							run_p  = run;
							for (i1 = 0; i1 <= p; i1++)
								W_p[i1] = W[i1];
							if (num == n)
								sw = count;
						}
					}
				}
					// 誤った分類
				else {
					run = 0;
					for (i1 = 0; i1 <= p; i1++)
						W[i1] += C[k] * E[k][i1];
				}
			}
		}

		return sw;
	}
}
		

  コンパイルした後,

java Test 入力ファイル名 出力ファイル名

と入力してやれば実行できます.出力ファイル名は,結果を出力するファイルの名前であり,省略すると画面に出力されます.また,入力ファイル名は,実行に必要なデータを記述したファイルの名前であり,たとえば以下のような形式で作成します.
最大試行回数 100 入力セルの数 2 訓練例の数 4 乱数 123
入力データファイル or.dat		
  日本語で記述した部分(「最大試行回数」,「入力セルの数」等)は,次に続くデータの説明ですのでどのように修正しても構いませんが,削除したり,または,複数の文(間に半角のスペースを入れる)にするようなことはしないでください.各データの意味は以下に示す通りです.
最大試行回数

  学習回数を入力します.この例では 100 を与えています.

入力セルの数

  入力セル(入力ユニット)の数を入力します.この例では 2 となっています.

訓練例の数

  訓練例の数を入力します(この例では 4 ).訓練例は,「入力データファイル」の項に入力されたファイル(この例では,or.dat )に記述します.ファイル or.dat は,たとえば,以下のようになります.
	OR演算の訓練例.各行の最後のデータが目標出力値
	-1 -1 -1
	-1  1  1
	 1 -1  1
	 1  1  1			
  1 行目はこのファイルの説明であり,何を記述しても構いません.ただし,文全体を削除したり,文の途中に半角のスペースを入れるようなことはしないでください.2 行目以下が 4 つの訓練例を表しています.各訓練例において,最初の 2 つの値が各入力ユニットに入力される値であり,3 番目の値が,そのときの目標出力値になっています.

乱数

  乱数の初期値です.
  上で説明したデータの元で実行すると,たとえば,以下のような出力が得られます.出力ファイル名を指定した場合は,最後の 1 行だけがコンソールに,残りはファイルに出力されます.
重み
  1  1  1
分類結果
 -1 -1 Cor -1 Res -1
 -1 1 Cor 1 Res 1
 1 -1 Cor 1 Res 1
 1 1 Cor 1 Res 1
  !!すべてを分類(試行回数:11)		
  2 行目の 2 番目以降のデータが,各入力ユニットから出力ユニットへ向かう枝に付けられた重みです.また,1 番目のデータはバイアスです.分類結果において,Cor の次に出力された値が,入力データ(たとえば,4 行目では「-1 -1」)に対する目標出力値であり,また,Res の後の値が実際の計算(分類)結果です.