最適化(多項式近似法)

/*********************************************/
/* 多項式近似によるy=x^4+3x^3+2x^2+1の最小値 */
/*      coded by Y.Suganuma                  */
/*********************************************/
#include <stdio.h>

double snx(double);
double approx(double, double, double, double *, int *, int, double (*)(double));

int main()
{
	double eps, val, x, d;
	int ind, max;

	x   = -2.0;
	d   = 0.1;
	eps = 1.0e-10;
	max = 100;

	x = approx(x, d, eps, &val, &ind, max, snx);

	printf("x %f val %f ind %d\n", x, val, ind);

	return 0;
}

/****************/
/* 関数値の計算 */
/****************/
double snx(double x)
{
	double f;

	f  = x * x * x * x + 3.0 * x * x * x + 2.0 * x * x + 1.0;

	return f;
}

/******************************************/
/* 多項式近似(関数の最小値)               */
/*      x0  : 初期値                      */
/*      d0  : 初期ステップ                */
/*      eps : 許容誤差                    */
/*      val : 間数値                      */
/*      ind : 計算状況                    */
/*              >= 0 : 正常終了(収束回数) */
/*              = -1 : 収束せず           */
/*      max : 最大試行回数                */
/*      fun : 関数値を計算する関数の名前  */
/*      return : 結果                     */
/******************************************/
#include <math.h>

double approx(double x0, double d0, double eps, double *val, int *ind, int max, double (*fun)(double))
{
	double f[4], x[4], xx = 0.0, d, dl;
	int i1, k = 0, count = 0, min, sw;

	d    = d0;
	x[1] = x0;
	f[1] = fun(x0);
	*ind = -1;

	while (count < max && *ind < 0) {
		x[3] = x[1] + d;
		f[3] = fun(x[3]);
		while (k < max && f[3] <= f[1]) {
			k++;
			d *= 2.0;
			x[0] = x[1];
			f[0] = f[1];
			x[1] = x[3];
			f[1] = f[3];
			x[3] = x[1] + d;
			f[3] = fun(x[3]);
		}
					// 初期値が不適当
		if (k >= max)
			count = max;
		else {
					// 3点の選択
			sw = 0;
			if (k > 0) {
				x[2] = x[3] - 0.5 * d;
				f[2] = fun(x[2]);
				min  = -1;
				for (i1 = 0; i1 < 4; i1++) {
					if (min < 0 || f[i1] < f[min])
						min = i1;
				}
				if (min >= 2) {
					for (i1 = 0; i1 < 3; i1++) {
						x[i1] = x[i1+1];
						f[i1] = f[i1+1];
					}
				}
				sw = 1;
			}
			else {
				x[0] = x[1] - d0;
				f[0] = fun(x[0]);
				if (f[0] > f[1]) {
					x[2] = x[3];
					f[2] = f[3];
					sw = 1;
				}
				else {
					x[1] = x[0];
					f[1] = f[0];
					d0   = -d0;
					d    = 2.0 * d0;
					k    = 1;
				}
			}
					// 収束?
			if (sw > 0) {
				count++;
				dl = 0.5 * d * (f[2] - f[0]) / (f[0] - 2.0 * f[1] + f[2]);
				xx   = x[1] - dl;
				*val = fun(xx);
				if (fabs(dl) < eps)
					*ind = count;
				else {
					k  = 0;
					d0 = 0.5 * d;
					d  = d0;
					if (*val < f[1]) {
						x[1] = xx;
						f[1] = *val;
					}
				}
			}
		}
	}

	return xx;
}