Winner-Take-All

<?php

/****************************/
/* Winner-Take-All Groups   */
/*      coded by Y.Suganuma */
/****************************/

/**********************/
/* Winnerクラスの定義 */
/**********************/
class Winner {
	private $max;   // 最大学習回数
	private $n;   // 訓練例の数
	private $o;   // 出力セルの数
	private $p;   // 入力セルの数
	private $W_p;   // 重み(ポケット)
	private $W;   // 重み
	private $E;   // 訓練例
	private $C;   // 各訓練例に対する正しい出力
	private $Ct;   // 作業領域
	
	/******************/
	/* コンストラクタ */
	/******************/
	function Winner($max_i, $n_i, $o_i, $p_i)
	{
	/*
	     設定
	*/
		$this->max = $max_i;
		$this->n   = $n_i;
		$this->o   = $o_i;
		$this->p   = $p_i;
	
		mt_srand();
	/*
	     領域の確保
	*/
		$this->E = array($this->n);
		$this->C = array($this->n);
		for ($i1 = 0; $i1 < $this->n; $i1++) {
			$this->E[$i1] = array($this->p+1);
			$this->C[$i1] = array($this->o);
		}
	
		$this->W_p = array($this->o);
		$this->W   = array($this->o);
		for ($i1 = 0; $i1 < $this->o; $i1++) {
			$this->W_p[$i1] = array($this->p+1);
			$this->W[$i1]   = array($this->p+1);
		}
	
		$this->Ct = array($this->o);
	}
	
	/******************************************/
	/* 訓練例の分類                           */
	/*      return : 正しく分類した訓練例の数 */
	/******************************************/
	function Bunrui()
	{
		$mx   = 0;
		$mx_v = 0;
		$num  = 0;
		$sw   = 0;
	
		for ($i1 = 0; $i1 < $this->n; $i1++) {
			$cor = 0;
			for ($i2 = 0; $i2 < $this->o; $i2++) {
				if ($this->C[$i1][$i2] == 1)
					$cor = $i2;
				$s = 0;
				for ($i3 = 0; $i3 <= $this->p; $i3++)
					$s += $this->W[$i2][$i3] * $this->E[$i1][$i3];
				if ($i2 == 0) {
					$mx   = 0;
					$mx_v = $s;
				}
				else {
					if ($s > $mx_v) {
						$mx   = $i2;
						$mx_v = $s;
						$sw   = 0;
					}
					else {
						if ($s == $mx_v)
							$sw = 1;
					}
				}
			}
	
			if ($sw == 0 && $cor == $mx)
				$num++;
		}
	
		return $num;
	}
	
	/**************************/
	/* 学習データの読み込み   */
	/*      name : ファイル名 */
	/**************************/
	function Input($name)
	{
		$st = fopen($name, "rb");
	
		fgets($st);
	
		for ($i1 = 0; $i1 < $this->n; $i1++) {
			$this->E[$i1][0] = 1;
			$str             = trim(fgets($st));
			$this->E[$i1][1] = intval(strtok($str, " "));
			for ($i2 = 2; $i2 <= $this->p; $i2++)
				$this->E[$i1][$i2] = intval(strtok(" "));
			for ($i2 = 0; $i2 < $this->o; $i2++)
				$this->C[$i1][$i2] = intval(strtok(" "));
		}
	
		fclose($st);
	}
	
	/*********************************/
	/* 学習と結果の出力              */
	/*      pr : =0 : 画面に出力     */
	/*           =1 : ファイルに出力 */
	/*      name : 出力ファイル名    */
	/*********************************/
	function Learn($pr, $name = "STDOUT")
	{
		$mx   = 0;
		$mx_v = 0;
	
		$n_tri = $this->Pocket($num);
	
		if ($pr == 0)
			$out = STDOUT;
		else
			$out = fopen($name, "w");
	
		fwrite($out, "重み\n");
		for ($i1 = 0; $i1 < $this->o; $i1++) {
			$str = "";
			for ($i2 = 0; $i2 <= $this->p; $i2++)
				$str = $str." ".$this->W_p[$i1][$i2];
			fwrite($out, $str."\n");
		}
	
		fwrite($out, "分類結果\n");
		for ($i1 = 0; $i1 < $this->n; $i1++) {
			$sw = 0;
			for ($i2 = 0; $i2 < $this->o; $i2++) {
				$s = 0;
				for ($i3 = 0; $i3 <= $this->p; $i3++)
					$s += $this->W_p[$i2][$i3] * $this->E[$i1][$i3];
				if ($i2 == 0) {
					$mx_v = $s;
					$mx   = 0;
				}
				else {
					if ($s > $mx_v) {
						$sw   = 0;
						$mx_v = $s;
						$mx   = $i2;
					}
					else {
						if ($s == $mx_v)
							$sw = 1;
					}
				}
			}
			$str = "";
			for ($i2 = 1; $i2 <= $this->p; $i2++)
				$str = $str." ".$this->E[$i1][$i2];
			$str = $str." Cor ";
			for ($i2 = 0; $i2 < $this->o; $i2++)
				$str = $str." ".$this->C[$i1][$i2];
			if ($sw > 0)
				$mx = -1;
			$str = $str." Res ".($mx+1);
			fwrite($out, $str."\n");
		}
	
		if ($this->n == $num)
			printf("  !!すべてを分類(試行回数:%ld)\n", $n_tri);
		else
			printf("  !!%ld 個を分類\n", $num);
	}
	
	/********************************************/
	/* Pocket Algorith with Ratcet              */
	/*      num_p : 正しく分類した訓練例の数    */
	/*      return : =0 : 最大学習回数          */
	/*               >0  : すべてを分類(回数) */
	/********************************************/
	function Pocket(&$num_p)
	{
	/*
	     初期設定
	*/
		$count = 0;
		$mx    = 0;
		$run   = 0;
		$run_p = 0;
		$sw    = -1;
		$num_p = 0;
	
		for ($i1 = 0; $i1 < $this->o; $i1++) {
			for ($i2 = 0; $i2 <= $this->p; $i2++)
				$this->W[$i1][$i2] = 0;
		}
	/*
	     実行
	*/
		while ($sw < 0) {
						// 終了チェック
			$count++;;
			if ($count > $this->max)
				$sw = 0;
	
			else {
						// 訓練例の選択
				$k = intval(mt_rand() / mt_getrandmax() * $this->n);
				if ($k >= $this->n)
					$k = $this->n - 1;
						// 出力の計算
				$sw1 = 0;
				$cor = -1;
	
				for ($i1 = 0; $i1 < $this->o; $i1++) {
	
					if ($this->C[$k][$i1] == 1)
						$cor = $i1;
	
					$s = 0;
					for ($i2 = 0; $i2 <= $this->p; $i2++)
						$s += $this->W[$i1][$i2] * $this->E[$k][$i2];
					$this->Ct[$i1] = $s;
	
					if ($i1 == 0)
						$mx = 0;
					else {
						if ($s > $this->Ct[$mx]) {
							$mx = $i1;
							$sw1 = 0;
						}
						else {
							if ($s == $this->Ct[$mx]) {
								$sw1 = 1;
								if ($cor >= 0 && $mx == $cor)
									$mx = $i1;
							}
						}
					}
				}
						// 正しい分類
				if ($sw1 == 0 && $cor == $mx) {
					$run++;
					if ($run > $run_p) {
						$num = $this->Bunrui();
						if ($num > $num_p) {
							$num_p = $num;
							$run_p = $run;
							for ($i1 = 0; $i1 < $this->o; $i1++) {
								for ($i2 = 0; $i2 <= $this->p; $i2++)
									$this->W_p[$i1][$i2] = $this->W[$i1][$i2];
							}
							if ($num == $this->n)
								$sw = $count;
						}
					}
				}
						// 誤った分類
				else {
					$run  = 0;
					for ($i1 = 0; $i1 <= $this->p; $i1++) {
						$this->W[$cor][$i1] += $this->E[$k][$i1];
						$this->W[$mx][$i1]  -= $this->E[$k][$i1];
					}
				}
			}
		}
	
		return $sw;
	}
}

/****************/
/* main program */
/****************/

	if (count($argv) > 1) {
					// 基本データの入力
		$st = fopen($argv[1], "rb");
		fscanf($st, "%*s %ld %*s %ld %*s %ld %*s %ld",$max, $p, $o, $n);
		fscanf($st, "%*s %s", $name);
		fclose($st);
					// ネットワークの定義と学習データ等の設定
		$net = new Winner($max, $n, $o, $p);
		$net->Input($name);
					// 学習と結果の出力
		if (count($argv) == 2)
			$net->Learn(0);
		else
			$net->Learn(1, $argv[2]);
	}

	else
		exit("***error   入力データファイル名を指定して下さい\n");

/*
------------------------入力ファイル--------------
最大試行回数 100 入力セルの数 2 出力セルの数 2 訓練例の数 4
入力データファイル or.dat

------------------------or.dat--------------------
OR演算の訓練例.最後の2つのデータが目標出力値
-1 -1 -1 1
-1  1  1 -1
 1 -1  1 -1
 1  1  1 -1
*/

?>