多項式近似(関数の最小値)

# -*- coding: UTF-8 -*-
from math import *
import numpy as np

##########################################
# 多項式近似(関数の最小値)
#      x0  : 初期値
#      d0  : 初期ステップ
#      eps : 許容誤差
#      val : 間数値
#      ind : 計算状況
#              >= 0 : 正常終了(収束回数)
#              = -1 : 収束せず
#      max : 最大試行回数
#      fun : 関数値を計算する関数の名前
#      return : 結果
#      coded by Y.Suganuma
##########################################

def approx(x0, d0, eps, val, ind, max, fun) :

	x      = np.empty(4, np.float)
	f      = np.empty(4, np.float)
	xx     = 0.0
	k      = 0
	count  = 0
	d      = d0
	x[1]   = x0
	f[1]   = fun(x0)
	ind[0] = -1

	while count < max and ind[0] < 0 :
		x[3] = x[1] + d
		f[3] = fun(x[3])
		while k < max and f[3] <= f[1] :
			k += 1
			d *= 2.0
			x[0] = x[1]
			f[0] = f[1]
			x[1] = x[3]
			f[1] = f[3]
			x[3] = x[1] + d
			f[3] = fun(x[3])
					# 初期値が不適当
		if k >= max :
			count = max
		else :
					# 3点の選択
			sw = 0
			if k > 0 :
				x[2] = x[3] - 0.5 * d
				f[2] = fun(x[2])
				min  = -1
				for i1 in range(0, 4) :
					if min < 0 or f[i1] < f[min] :
						min = i1
				if min >= 2 :
					for i1 in range(0, 3) :
						x[i1] = x[i1+1]
						f[i1] = f[i1+1]
				sw = 1
			else :
				x[0] = x[1] - d0
				f[0] = fun(x[0])
				if f[0] > f[1] :
					x[2] = x[3]
					f[2] = f[3]
					sw = 1
				else :
					x[1] = x[0]
					f[1] = f[0]
					d0   = -d0
					d    = 2.0 * d0
					k    = 1
					# 収束?
			if sw > 0 :
				count += 1
				dl     = 0.5 * d * (f[2] - f[0]) / (f[0] - 2.0 * f[1] + f[2])
				xx     = x[1] - dl
				val[0] = fun(xx)
				if abs(dl) < eps :
					ind[0] = count
				else :
					k  = 0
					d0 = 0.5 * d
					d  = d0
					if val[0] < f[1] :
						x[1] = xx
						f[1] = val[0]

	return xx

----------------------------------

# -*- coding: UTF-8 -*-
import numpy as np
import sys
from math import *
from function import approx

############################################
# 多項式近似によるy=x^4+3x^3+2x^2+1の最小値
#      coded by Y.Suganuma
############################################

###############
# 関数値の計算
###############
def snx(x) :
	return x * x * x * x + 3.0 * x * x * x + 2.0 * x * x + 1.0

			# 設定と実行
x   = -3.0
d   = 0.1
eps = 1.0e-7
max = 100
val = np.empty(1, np.float)
ind = np.empty(1, np.int)

x = approx(x, d, eps, val, ind, max, snx)

print("x " + str(x) + " val " + str(val[0]) + " ind " + str(ind[0]))