Hopfieldネットワーク(連想記憶)

/************************************/
/* Hopfieldネットワーク(連想記憶) */
/*     (Coded by Y.Suganuma)        */
/************************************/
import java.io.*;
import java.util.*;

/**
 * Hopfield network used as an associative memory.
 *
 * <p>Three 6x6 binary (0/1) patterns are stored with the Hebbian rule. The user
 * chooses one stored pattern and a number of units to flip; the network then
 * runs asynchronous random updates from that corrupted state until it reaches a
 * fixed point, printing the initial and final grids.
 */
public class Hopfield_d {
	/** Side length of the square display grid; each pattern has SIZE * SIZE units. */
	private static final int SIZE = 6;

	/**
	 * Number of consecutive random probes that change nothing before the whole
	 * state is checked for stability (100, as in the original loop).
	 */
	private static final int PATIENCE = 100;

	/** Stored 0/1 patterns, SIZE * SIZE units each, in row-major order. */
	private static final int[][] PATTERNS = {{0, 0, 1, 1, 0, 0,
	                                          0, 0, 1, 1, 0, 0,
	                                          0, 0, 1, 1, 0, 0,
	                                          0, 0, 1, 1, 0, 0,
	                                          0, 0, 1, 1, 0, 0,
	                                          0, 0, 1, 1, 0, 0},
	                                         {0, 0, 0, 0, 0, 0,
	                                          0, 0, 0, 0, 0, 0,
	                                          1, 1, 1, 1, 1, 1,
	                                          1, 1, 1, 1, 1, 1,
	                                          0, 0, 0, 0, 0, 0,
	                                          0, 0, 0, 0, 0, 0},
	                                         {1, 1, 0, 0, 0, 0,
	                                          1, 1, 1, 0, 0, 0,
	                                          0, 1, 1, 1, 0, 0,
	                                          0, 0, 1, 1, 1, 0,
	                                          0, 0, 0, 1, 1, 1,
	                                          0, 0, 0, 0, 1, 1}};

	/**
	 * Reads "patternIndex flipCount" from the console (or standard input when no
	 * console is attached), corrupts the chosen pattern, and runs recall.
	 *
	 * @param args unused
	 */
	public static void main(String args[])
	{
		int n = SIZE * SIZE;
		int[][] w = learn(PATTERNS);
		Random rn = new Random();

		// System.console() is null when not attached to a terminal; fall back to stdio.
		Console con = System.console();
		PrintWriter out = (con != null) ? con.writer() : new PrintWriter(System.out, true);

		String line;
		try {
			line = readLine(con, out, "パターン番号と修正ユニット数 ");
		}
		catch (IOException e) {
			out.printf("入力エラー: %s\n", e.getMessage());
			out.flush();
			return;
		}
		int[] params = parseParams(line, PATTERNS.length, n);
		if (params == null) {
			out.printf("入力エラー: パターン番号 (0 から %d) と修正ユニット数 (0 から %d) を空白区切りで入力してください\n",
			           PATTERNS.length - 1, n);
			out.flush();
			return;
		}
		int pn = params[0];
		int m = params[1];

		// Initial state: a copy of the chosen pattern with m distinct units flipped
		int[] u = PATTERNS[pn].clone();
		flipRandomUnits(u, m, rn);

		out.printf("初期状態:\n");
		printGrid(out, u, SIZE);

		// Asynchronous updates until a verified fixed point
		int[] counts = recall(w, u, rn);

		out.printf("試行回数 = %d,更新回数 = %d\n", counts[0], counts[1]);
		printGrid(out, u, SIZE);
		out.flush();
	}

	/**
	 * Computes Hopfield weights from 0/1 patterns with the Hebbian rule
	 * {@code W[i][j] = sum over patterns of (2*a_i - 1) * (2*a_j - 1)} for i != j,
	 * and {@code W[i][i] = 0}.
	 *
	 * @param a stored patterns (values 0 or 1); every row must have the same length
	 * @return symmetric n x n weight matrix, where n is the pattern length
	 *         (0 when {@code a} is empty)
	 * @throws IllegalArgumentException if the rows have different lengths
	 */
	public static int[][] learn(int[][] a)
	{
		int n = (a.length == 0) ? 0 : a[0].length;
		for (int[] pat : a) {
			if (pat.length != n)
				throw new IllegalArgumentException("pattern length " + pat.length + " != " + n);
		}
		int[][] w = new int[n][n];   // diagonal stays 0 from array initialization
		for (int i1 = 0; i1 < n - 1; i1++) {
			for (int i2 = i1 + 1; i2 < n; i2++) {
				int sum = 0;
				for (int[] pat : a)
					sum += (2 * pat[i1] - 1) * (2 * pat[i2] - 1);
				w[i1][i2] = sum;
				w[i2][i1] = sum;
			}
		}
		return w;
	}

	/**
	 * Runs asynchronous random updates on {@code u} (in place) until it is a fixed
	 * point of the threshold rule.
	 *
	 * <p>Random probing alone cannot prove convergence: with 36 units, on average
	 * about 2 are never drawn in 100 random draws. So after {@code PATIENCE}
	 * consecutive probes that change nothing, every unit is checked explicitly. If
	 * any unit would still flip, probing resumes. For symmetric weights with a zero
	 * diagonal (as returned by {@link #learn}), the standard Hopfield energy argument
	 * guarantees this terminates. For arbitrary weights it may not.
	 *
	 * @param w  n x n weight matrix
	 * @param u  0/1 state of length n; overwritten with the final state
	 * @param rn source of the random unit order
	 * @return {@code {number of probes, number of probes that changed a unit}}
	 */
	public static int[] recall(int[][] w, int[] u, Random rn)
	{
		if (u.length == 0)
			return new int[] {0, 0};
		int trials = 0, updates = 0, idle = 0;
		while (true) {
			if (idle >= PATIENCE) {
				if (isStable(w, u))
					break;
				idle = 0;   // some unit was missed by the random probes; keep going
			}
			trials++;
			if (update(w, u, rn.nextInt(u.length))) {
				updates++;
				idle = 0;
			}
			else {
				idle++;
			}
		}
		return new int[] {trials, updates};
	}

	/**
	 * Reports whether no unit of {@code u} would change under the threshold rule.
	 *
	 * @param w n x n weight matrix
	 * @param u 0/1 state of length n
	 * @return true if {@code u} is a fixed point
	 */
	public static boolean isStable(int[][] w, int[] u)
	{
		for (int k = 0; k < u.length; k++) {
			if (threshold(w, u, k) != u[k])
				return false;
		}
		return true;
	}

	/** Applies the threshold rule to unit k; returns true if u[k] changed. */
	private static boolean update(int[][] w, int[] u, int k)
	{
		int next = threshold(w, u, k);
		if (next == u[k])
			return false;
		u[k] = next;
		return true;
	}

	/** Value unit k should take: 1 if sum_j W[k][j] * u[j] >= 0, otherwise 0. */
	private static int threshold(int[][] w, int[] u, int k)
	{
		int s = 0;
		for (int j = 0; j < u.length; j++)
			s += w[k][j] * u[j];
		return (s >= 0) ? 1 : 0;
	}

	/**
	 * Flips m distinct, randomly chosen units of u (0 becomes 1, positive becomes 0).
	 * Uses a partial Fisher-Yates shuffle so no unit is flipped twice, which would
	 * cancel the flip. Requires 0 <= m <= u.length.
	 */
	private static void flipRandomUnits(int[] u, int m, Random rn)
	{
		int[] idx = new int[u.length];
		for (int i = 0; i < idx.length; i++)
			idx[i] = i;
		for (int i = 0; i < m; i++) {
			int j = i + rn.nextInt(idx.length - i);
			int t = idx[i];
			idx[i] = idx[j];
			idx[j] = t;
			int k = idx[i];
			u[k] = (u[k] > 0) ? 0 : 1;
		}
	}

	/**
	 * Parses "patternIndex flipCount" from one input line; extra tokens are ignored.
	 *
	 * @return {@code {patternIndex, flipCount}}, or null if the line is null, lacks two
	 *         integer tokens, or a value is outside [0, patterns) / [0, units]
	 */
	private static int[] parseParams(String line, int patterns, int units)
	{
		if (line == null)
			return null;
		StringTokenizer str = new StringTokenizer(line);
		if (str.countTokens() < 2)
			return null;
		try {
			int pn = Integer.parseInt(str.nextToken());
			int m = Integer.parseInt(str.nextToken());
			if (pn < 0 || pn >= patterns || m < 0 || m > units)
				return null;
			return new int[] {pn, m};
		}
		catch (NumberFormatException e) {
			return null;
		}
	}

	/**
	 * Reads one line, prompting first. Uses the console when available, otherwise
	 * standard input. Returns null at end of input.
	 */
	private static String readLine(Console con, PrintWriter out, String prompt) throws IOException
	{
		if (con != null)
			return con.readLine(prompt);
		out.print(prompt);
		out.flush();
		// Deliberately not closed: closing it would close System.in.
		return new BufferedReader(new InputStreamReader(System.in)).readLine();
	}

	/** Prints u as a grid with cols values per row, each value in a 2-char field. */
	private static void printGrid(PrintWriter out, int[] u, int cols)
	{
		for (int k = 0; k < u.length; k++) {
			out.printf("%2d", u[k]);
			if ((k + 1) % cols == 0)
				out.printf("\n");
		}
		if (u.length % cols != 0)
			out.printf("\n");
	}
}