------------------------入力ファイル--------------
最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9
乱数 123 係数(0~1) 0.1
入力データファイル pat.dat
------------------------pat.dat-------------------
0 0 0 0 0 0 0 1 1
0 0 0 0 0 0 1 0 1
0 0 0 0 0 0 1 1 0
0 0 0 0 1 1 0 0 0
0 0 0 1 0 1 0 0 0
0 0 0 1 1 0 0 0 0
0 1 1 0 0 0 0 0 0
1 0 1 0 0 0 0 0 0
1 1 0 0 0 0 0 0 0
------------------------プログラム----------------
/****************************/
/* 競合学習 */
/* coded by Y.Suganuma */
/****************************/
import java.io.*;
import java.util.Random;
import java.util.StringTokenizer;
public class Test {
/****************/
/* main program */
/****************/
public static void main(String args[]) throws IOException, FileNotFoundException
{
double sig;
int max, n, o, p, seed;
String name;
StringTokenizer str;
if (args.length > 0) {
// 基本データの入力
BufferedReader in = new BufferedReader(new FileReader(args[0]));
str = new StringTokenizer(in.readLine(), " ");
str.nextToken();
max = Integer.parseInt(str.nextToken());
str.nextToken();
p = Integer.parseInt(str.nextToken());
str.nextToken();
o = Integer.parseInt(str.nextToken());
str.nextToken();
n = Integer.parseInt(str.nextToken());
str = new StringTokenizer(in.readLine(), " ");
str.nextToken();
seed= Integer.parseInt(str.nextToken());
str.nextToken();
sig = Double.parseDouble(str.nextToken());
str = new StringTokenizer(in.readLine(), " ");
str.nextToken();
name = str.nextToken();
in.close();
// ネットワークの定義
Competition net = new Competition (max, n, o, p, seed, sig);
net.input(name);
// 学習と結果の出力
if (args.length == 1)
net.learn(0, "");
else
net.learn(1, args[1]);
}
// エラー
else {
System.out.print("***error 入力データファイル名を指定して下さい\n");
System.exit(1);
}
}
}
/***************************/
/* Competitionクラスの定義 */
/***************************/
class Competition {
private int max; // 最大学習回数
private int n; // 訓練例の数
private int o; // 出力セルの数
private int p; // 入力セルの数
private int E[][]; // 訓練例
private double W[][]; // 重み
private double sig; // 重み修正係数
private Random rn; // 乱数
/******************/
/* コンストラクタ */
/******************/
Competition (int max_i, int n_i, int o_i, int p_i, int seed, double sig_i)
{
/*
設定
*/
max = max_i;
n = n_i;
o = o_i;
p = p_i;
sig = sig_i;
rn = new Random(seed); // 乱数の初期設定
/*
領域の確保
*/
E = new int [n][p];
W = new double [o][p];
}
/**************************/
/* 学習データの読み込み */
/* name : ファイル名 */
/**************************/
void input (String name) throws IOException, FileNotFoundException
{
int i1, i2;
StringTokenizer str;
BufferedReader st = new BufferedReader(new FileReader(name));
for (i1 = 0; i1 < n; i1++) {
str = new StringTokenizer(st.readLine(), " ");
for (i2 = 0; i2 < p; i2++)
E[i1][i2] = Integer.parseInt(str.nextToken());
}
st.close();
}
/*********************************/
/* 学習と結果の出力 */
/* pr : =0 : 画面に出力 */
/* =1 : ファイルに出力 */
/* name : 出力ファイル名 */
/*********************************/
void learn(int pr, String name) throws FileNotFoundException
{
double mx_v = 0.0, s, sum;
int count, i1, i2, i3, k, mx = 0;
/*
初期設定
*/
for (i1 = 0; i1 < o; i1++) {
sum = 0.0;
for (i2 = 0; i2 < p; i2++) {
W[i1][i2] = rn.nextDouble();
sum += W[i1][i2];
}
sum = 1.0 / sum;
for (i2 = 0; i2 < p; i2++)
W[i1][i2] *= sum;
}
/*
学習
*/
for (count = 0; count < max; count++) {
// 訓練例の選択
k = (int)(rn.nextDouble() * n);
if (k >= n)
k = n - 1;
// 出力の計算
for (i1 = 0; i1 < o; i1++) {
s = 0.0;
for (i2 = 0; i2 < p; i2++)
s += W[i1][i2] * E[k][i2];
if (i1 == 0 || s > mx_v) {
mx = i1;
mx_v = s;
}
}
// 重みの修正
sum = 0.0;
for (i1 = 0; i1 < p; i1++)
sum += E[k][i1];
for (i1 = 0; i1 < p; i1++)
W[mx][i1] += sig * (E[k][i1] / sum - W[mx][i1]);
}
/*
出力
*/
if (pr == 0) {
System.out.print("分類結果\n");
for (i1 = 0; i1 < n; i1++) {
for (i2 = 0; i2 < p; i2++)
System.out.print(" " + E[i1][i2]);
System.out.print(" Res ");
for (i2 = 0; i2 < o; i2++) {
s = 0.0;
for (i3 = 0; i3 < p; i3++)
s += W[i2][i3] * E[i1][i3];
if (i2 == 0 || s > mx_v) {
mx = i2;
mx_v = s;
}
}
System.out.println((mx+1));
}
}
else {
PrintStream out = new PrintStream(new FileOutputStream(name));
out.print("分類結果\n");
for (i1 = 0; i1 < n; i1++) {
for (i2 = 0; i2 < p; i2++)
out.print(" " + E[i1][i2]);
out.print(" Res ");
for (i2 = 0; i2 < o; i2++) {
s = 0.0;
for (i3 = 0; i3 < p; i3++)
s += W[i2][i3] * E[i1][i3];
if (i2 == 0 || s > mx_v) {
mx = i2;
mx_v = s;
}
}
out.println((mx+1));
}
out.close();
}
}
}