def fmin_bfgs_f(
    f_g,
    x0,
    B0=None,
    M=2,
    gtol=1e-5,
    Delta=10.0,
    maxiter=None,
    callback=None,
    norm_ord=np.inf,
    **_kwargs
):
    """Minimize a function with BFGS and a nonmonotone Wolfe line search.

    Parameters
    ----------
    f_g : callable
        ``f_g(x)`` returns ``(f, g)``, the objective value and gradient at
        ``x``. It must also expose ``f_g.fun`` and ``f_g.grad`` (handed to
        ``line_search_wolfe2``) and ``f_g.ncall`` (reported as ``nfev`` and
        ``njev``).
    x0 : array_like
        Starting point.
    B0 : ndarray, optional
        Initial Hessian approximation; the identity matrix when omitted.
    M : int
        Size argument for ``Seq``. The line search receives
        ``f_s.get_max()`` over the objective values recorded there
        (presumably the last ``M`` values, giving a nonmonotone acceptance
        test -- confirm against ``Seq``).
    gtol : float
        Stop successfully once the gradient norm is ``<= gtol``.
    Delta : float
        Unused; kept for interface compatibility.
    maxiter : int, optional
        Maximum number of iterations; ``200 * len(x0)`` when omitted.
    callback : callable, optional
        Called as ``callback(xk)`` after every line-search step.
    norm_ord : int or float
        ``ord`` passed to ``np.linalg.norm`` for the gradient stopping test.
    **_kwargs
        Ignored.

    Returns
    -------
    OptimizeResult
        ``status``/``flag`` is 0 on convergence, 1 when the line search
        failed more than twice in total, and 2 when ``maxiter`` iterations
        ran without convergence. ``message`` is ``message_dict[status]``.
    """
    fk, gk = f_g(x0)
    Bk = np.eye(len(x0)) if B0 is None else B0
    Hk = np.linalg.inv(Bk)
    maxiter = 200 * len(x0) if maxiter is None else maxiter
    xk = x0
    # Mirrors scipy's BFGS initialisation of the "previous" function value
    # that the Wolfe line search uses for its first trial step.
    old_old_fval = fk + np.linalg.norm(gk) / 2
    old_fval = fk
    f_s = Seq(M)
    f_s.add(fk)
    flag = 0
    re_search = 0
    k = 0  # reported as ``nit`` even if the loop body never runs
    for k in range(maxiter):
        if np.linalg.norm(gk, ord=norm_ord) <= gtol:
            break
        pk = -np.dot(Hk, gk)  # quasi-Newton search direction
        old_fval = fk
        try:
            alpha_k, _fc, _gc, old_fval, old_old_fval, gfkp1 = line_search_wolfe2(
                f_g.fun,
                f_g.grad,
                xk,
                pk,
                gk,
                f_s.get_max(),
                old_fval,
                old_old_fval,
            )
        except Exception as e:  # deliberate best effort: fall back below
            print(e)
            alpha_k = None
        else:
            if alpha_k is None:
                print("alpha is None")
        if alpha_k is None:
            # Line search failed: take the unit step along pk and give up
            # after more than two such failures in total.
            xk = xk + pk
            fk, gk = f_g(xk)
            old_fval, old_old_fval = fk, old_fval
            f_s.add(fk)
            re_search += 1
            if re_search > 2:
                flag = 1
                break
            continue
        sk = alpha_k * pk
        # Reuse the value/gradient returned by the line search instead of
        # re-evaluating f_g(xk + sk).
        yk = gfkp1 - gk
        xk = xk + sk
        fk, gk = old_fval, gfkp1
        bs = np.dot(Bk, sk)
        ys = np.dot(yk, sk)
        sbs = np.dot(bs, sk)
        # Skip the BFGS update when the curvature condition fails: dividing
        # by a non-positive (or nan) y.s or s.B.s would make Bk indefinite
        # or non-finite and poison every later iteration.
        if ys > 0 and sbs > 0:
            Bk = Bk + np.outer(yk, yk) / ys - np.outer(bs, bs) / sbs
            try:
                Hk = np.linalg.inv(Bk)
            except np.linalg.LinAlgError:
                pass  # numerically singular: keep the previous inverse
        f_s.add(fk)
        if callback is not None:
            callback(xk)
    else:
        flag = 2  # maxiter exhausted without meeting gtol
    s = OptimizeResult()
    s.message = message_dict[flag]
    s.messgae = s.message  # misspelled alias kept for existing callers
    s.fun = float(fk)
    s.nit = k
    s.nfev = f_g.ncall
    s.njev = f_g.ncall
    s.status = flag
    s.x = np.array(xk)
    s.jac = np.array(gk)
    s.hess = np.array(Bk)
    s.success = flag == 0
    return s