競合学習

<?php

/****************************/
/* 競合学習                 */
/*      coded by Y.Suganuma */
/****************************/

/***************************/
/* Competitionクラスの定義 */
/***************************/
class Competition {
	private $max;   // 最大学習回数
	private $n;   // 訓練例の数
	private $o;   // 出力セルの数
	private $p;   // 入力セルの数
	private $E;   // 訓練例
	private $W;   // 重み
	private $sig;   // 重み修正係数

	/******************/
	/* コンストラクタ */
	/******************/
	function Competition($max_i, $n_i, $o_i, $p_i, $sig_i)
	{
	/*
	     設定
	*/
		$this->sig = $sig_i;
		$this->max = $max_i;
		$this->n   = $n_i;
		$this->o   = $o_i;
		$this->p   = $p_i;
	
		mt_srand();
	/*
	     領域の確保
	*/
		$this->E = array($this->n);
		for ($i1 = 0; $i1 < $this->n; $i1++)
			$this->E[$i1] = array($this->p);
	
		$this->W = array($this->o);
		for ($i1 = 0; $i1 < $this->o; $i1++)
			$this->W[$i1]   = array($this->p);
	}
	
	/**************************/
	/* 学習データの読み込み   */
	/*      name : ファイル名 */
	/**************************/
	function Input($name)
	{
		$st = fopen($name, "rb");
	
		for ($i1 = 0; $i1 < $this->n; $i1++) {
			$str             = trim(fgets($st));
			$this->E[$i1][0] = intval(strtok($str, " "));
			for ($i2 = 1; $i2 < $this->p; $i2++)
				$this->E[$i1][$i2] = intval(strtok(" "));
		}
	
		fclose($st);
	}
	
	/*********************************/
	/* 学習と結果の出力              */
	/*      pr : =0 : 画面に出力     */
	/*           =1 : ファイルに出力 */
	/*      name : 出力ファイル名    */
	/*********************************/
	function Learn($pr, $name = "STDOUT")
	{
		$mx   = 0;
		$mx_v = 0.0;
	/*
	     初期設定
	*/
		for ($i1 = 0; $i1 < $this->o; $i1++) {
			$sum = 0.0;
			for ($i2 = 0; $i2 < $this->p; $i2++) {
				$this->W[$i1][$i2]  = mt_rand() / mt_getrandmax();
				$sum               += $this->W[$i1][$i2];
			}
			$sum = 1.0 / $sum;
			for ($i2 = 0; $i2 < $this->p; $i2++)
				$this->W[$i1][$i2] *= $sum;
		}
	/*
	     学習
	*/
		for ($count = 0; $count < $this->max; $count++) {
						// 訓練例の選択
			$k = intval(mt_rand() / mt_getrandmax() * $this->n);
			if ($k >= $this->n)
				$k = $this->n - 1;
						// 出力の計算
			for ($i1 = 0; $i1 < $this->o; $i1++) {
				$s = 0.0;
				for ($i2 = 0; $i2 < $this->p; $i2++)
					$s += $this->W[$i1][$i2] * $this->E[$k][$i2];
				if ($i1 == 0 || $s > $mx_v) {
					$mx   = $i1;
					$mx_v = $s;
				}
			}
						// 重みの修正
			$sum = 0.0;
			for ($i1 = 0; $i1 < $this->p; $i1++)
				$sum += $this->E[$k][$i1];
			for ($i1 = 0; $i1 < $this->p; $i1++)
				$this->W[$mx][$i1] += $this->sig * ($this->E[$k][$i1] / $sum - $this->W[$mx][$i1]);
		}
	/*
	     出力
	*/
		if ($pr == 0)
			$out = STDOUT;
		else
			$out = fopen($name, "w");
	
		fwrite($out, "分類結果\n");
		for ($i1 = 0; $i1 < $this->n; $i1++) {
			$str = "";
			for ($i2 = 0; $i2 < $this->p; $i2++)
				$str = $str." ".$this->E[$i1][$i2];
			$str = $str." Res ";
			for ($i2 = 0; $i2 < $this->o; $i2++) {
				$s = 0.0;
				for ($i3 = 0; $i3 < $this->p; $i3++)
		            $s += $this->W[$i2][$i3] * $this->E[$i1][$i3];
				if ($i2 == 0 || $s > $mx_v) {
					$mx   = $i2;
					$mx_v = $s;
				}
			}
			fwrite($out, $str.($mx+1)."\n");
		}
	}
}
	
/****************/
/* main program */
/****************/
	
	if (count($argv) > 1) {
					// 基本データの入力
		$st = fopen($argv[1], "r");
		fscanf($st, "%*s %ld %*s %ld %*s %ld %*s %ld %*s %lf", $max, $p, $o, $n, $sig);
		fscanf($st, "%*s %s", $name);
		fclose($st);
					// ネットワークの定義
		$net = new Competition($max, $n, $o, $p, $sig);
		$net->Input($name);
					// 学習と結果の出力
		if (count($argv) == 2)
			$net->Learn(0);
		else
			$net->Learn(1, $argv[2]);
	}

	else
		exit("***error   入力データファイル名を指定して下さい\n");

/*
------------------------入力ファイル--------------
最大試行回数 1000 入力セルの数 9 出力セルの数 3 訓練例の数 9 係数(0~1) 0.1
入力データファイル pat.dat

------------------------pat.dat-------------------
0 0 0 0 0 0 0 1 1
0 0 0 0 0 0 1 0 1
0 0 0 0 0 0 1 1 0
0 0 0 0 1 1 0 0 0
0 0 0 1 0 1 0 0 0
0 0 0 1 1 0 0 0 0
0 1 1 0 0 0 0 0 0
1 0 1 0 0 0 0 0 0
1 1 0 0 0 0 0 0 0
*/

?>