------------------------入力ファイル--------------
最大試行回数 100 入力セルの数 2 訓練例の数 4 乱数 123
入力データファイル or.dat
------------------------or.dat--------------------
OR演算の訓練例.各行の最後のデータが目標出力値
-1 -1 -1
-1 1 1
1 -1 1
1 1 1
------------------------プログラム----------------
/***********************************/
/* パーセプトロン学習 */
/* (Pocket Algorith with Ratcet) */
/* coded by Y.Suganuma */
/***********************************/
import java.io.*;
import java.util.Random;
import java.util.StringTokenizer;
public class Test {
/****************/
/* main program */
/****************/
public static void main(String args[]) throws IOException, FileNotFoundException
{
int max, n, p, seed;
String name;
StringTokenizer str;
if (args.length > 0) {
// 基本データの入力
BufferedReader in = new BufferedReader(new FileReader(args[0]));
str = new StringTokenizer(in.readLine(), " ");
str.nextToken();
max = Integer.parseInt(str.nextToken());
str.nextToken();
p = Integer.parseInt(str.nextToken());
str.nextToken();
n = Integer.parseInt(str.nextToken());
str.nextToken();
seed = Integer.parseInt(str.nextToken());
str = new StringTokenizer(in.readLine(), " ");
str.nextToken();
name = str.nextToken();
in.close();
// ネットワークの定義
Perceptron net = new Perceptron (max, n, p, seed);
net.input(name);
// 学習と結果の出力
if (args.length == 1)
net.learn(0, "");
else
net.learn(1, args[1]);
}
// エラー
else {
System.out.print("***error 入力データファイル名を指定して下さい\n");
System.exit(1);
}
}
}
/**************************/
/* Perceptronクラスの定義 */
/**************************/
class Perceptron {
private int max; // 最大学習回数
private int n; // 訓練例の数
private int p; // 入力セルの数
private int W_p[]; // 重み(ポケット)
private int W[]; // 重み
private int E[][]; // 訓練例
private int C[]; // 各訓練例に対する正しい出力
private Random rn; // 乱数
/******************/
/* コンストラクタ */
/******************/
Perceptron(int max_i, int n_i, int p_i, int seed)
{
/*
設定
*/
max = max_i;
n = n_i;
p = p_i;
rn = new Random(seed); // 乱数の初期設定
/*
領域の確保
*/
E = new int [n][p+1];
W_p = new int [p+1];
W = new int [p+1];
C = new int [n];
}
/******************************************/
/* 訓練例の分類 */
/* return : 正しく分類した訓練例の数 */
/******************************************/
int bunrui()
{
int i1, i2, num = 0, s;
for (i1 = 0; i1 < n; i1++) {
s = 0;
for (i2 = 0; i2 <= p; i2++)
s += W[i2] * E[i1][i2];
if ((s > 0 && C[i1] > 0) || (s < 0 && C[i1] < 0))
num++;
}
return num;
}
/**************************/
/* 学習データの読み込み */
/* name : ファイル名 */
/**************************/
void input (String name) throws IOException, FileNotFoundException
{
int i1, i2;
StringTokenizer str;
BufferedReader st = new BufferedReader(new FileReader(name));
st.readLine();
for (i1 = 0; i1 < n; i1++) {
E[i1][0] = 1;
str = new StringTokenizer(st.readLine(), " ");
for (i2 = 1; i2 <= p; i2++)
E[i1][i2] = Integer.parseInt(str.nextToken());
C[i1] = Integer.parseInt(str.nextToken());
}
st.close();
}
/*********************************/
/* 学習と結果の出力 */
/* pr : =0 : 画面に出力 */
/* =1 : ファイルに出力 */
/* name : 出力ファイル名 */
/*********************************/
void learn(int pr, String name) throws FileNotFoundException
{
int i1, i2, n_tri, s;
int num[] = new int [1];
// 学習
n_tri = pocket(num);
// 結果の出力
if (pr == 0) {
System.out.print("重み\n");
for (i1 = 0; i1 <= p; i1++)
System.out.print(" " + W_p[i1]);
System.out.println();
System.out.print("分類結果\n");
for (i1 = 0; i1 < n; i1++) {
s = 0;
for (i2 = 0; i2 <= p; i2++)
s += E[i1][i2] * W_p[i2];
if (s > 0)
s = 1;
else
s = (s < 0) ? -1 : 0;
for (i2 = 1; i2 <= p; i2++)
System.out.print(" " + E[i1][i2]);
System.out.println(" Cor " + C[i1] + " Res " + s);
}
}
else {
PrintStream out = new PrintStream(new FileOutputStream(name));
out.print("重み\n");
for (i1 = 0; i1 <= p; i1++)
out.print(" " + W_p[i1]);
out.println();
out.print("分類結果\n");
for (i1 = 0; i1 < n; i1++) {
s = 0;
for (i2 = 0; i2 <= p; i2++)
s += E[i1][i2] * W_p[i2];
if (s > 0)
s = 1;
else
s = (s < 0) ? -1 : 0;
for (i2 = 1; i2 <= p; i2++)
out.print(" " + E[i1][i2]);
out.println(" Cor " + C[i1] + " Res " + s);
}
out.close();
}
if (n == num[0])
System.out.print(" !!すべてを分類(試行回数:" + n_tri + ")\n");
else
System.out.print(" !!" + num[0] + " 個を分類\n");
}
/********************************************/
/* Pocket Algorith with Ratcet */
/* num_p : 正しく分類した訓練例の数 */
/* return : =0 : 最大学習回数 */
/* >0 : すべてを分類(回数) */
/********************************************/
int pocket(int num_p[])
{
int count = 0, i1, k, num, run = 0, run_p = 0, s, sw = -1;
/*
初期設定
*/
num_p[0] = 0;
for (i1 = 0; i1 <= p; i1++)
W[i1] = 0;
/*
実行
*/
while (sw < 0) {
count++;
if (count > max)
sw = 0;
else {
// 訓練例の選択
k = (int)(rn.nextDouble() * n);
if (k >= n)
k = n - 1;
// 出力の計算
s = 0;
for (i1 = 0; i1 <= p; i1++)
s += W[i1] * E[k][i1];
// 正しい分類
if ((s > 0 && C[k] > 0) || (s < 0 && C[k] < 0)) {
run++;
if (run > run_p) {
num = bunrui();
if (num > num_p[0]) {
num_p[0] = num;
run_p = run;
for (i1 = 0; i1 <= p; i1++)
W_p[i1] = W[i1];
if (num == n)
sw = count;
}
}
}
// 誤った分類
else {
run = 0;
for (i1 = 0; i1 <= p; i1++)
W[i1] += C[k] * E[k][i1];
}
}
}
return sw;
}
}