コード例 #1
0
 def f(v):
     """Return a step-weighted blend of two `GNvp` products, detached and scaled.

     Both products are always computed against the same vector `v`. When
     `step >= 1` the result is ``((step - 1) * first + second) / step``;
     otherwise only the second product is used. The blend's `.data` is then
     divided by `bias_correction2`.

     NOTE(review): `GNvp`, `c1`/`z1`/`tmp_params1`, `c2`/`z2`/`tmp_params2`,
     `step` and `bias_correction2` are free variables from the enclosing
     scope (not visible here). Presumably `GNvp` is a Gauss-Newton-vector
     product, going by its name; confirm.
     """
     first = GNvp(c1, z1, tmp_params1, v)
     second = GNvp(c2, z2, tmp_params2, v)
     # The guard keeps step == 0 from dividing by zero.
     blended = ((step - 1) * first + second) / step if step >= 1 else second
     return blended.data / bias_correction2
コード例 #2
0
 def f(p):
     """Return the MSE between the `GNvp` product of `ng` at `p` and `v2`.

     `p` is cast with `.float()` and wrapped in a `torch.nn.Parameter`
     (which requires grad by default). It is then passed to `closure` as a
     one-element list. `closure` returns the `(c, z, tmp_params)` triple
     consumed by `GNvp`.

     Returns:
         The mean-squared-error loss as a Python number.

     NOTE(review): `closure`, `GNvp`, `ng`, `v2`, `F` and `torch` are free
     variables from the enclosing scope. In particular, `v2` is not assigned
     here; an earlier in-function `v2 = weighted_fvp_fn(ng)` was commented
     out. If `v2` is not bound in the enclosing scope, this raises
     NameError. Confirm against the caller.
     """
     pvar = torch.nn.Parameter(p.float())
     c, z, tmp_params = closure([pvar])
     v1 = GNvp(c, z, tmp_params, ng)
     loss = F.mse_loss(v1, v2)
     # .item() is the documented way to extract a Python scalar; it is
     # equivalent to the former float(loss.data) for a single-element loss.
     return loss.item()
コード例 #3
0
 def f(v):
     """Return the `GNvp` product for `v`, detached via `.data` and scaled.

     The product's `.data` is divided by `bias_correction2`.

     NOTE(review): `GNvp`, `c`, `z`, `params` and `bias_correction2` are
     free variables from the enclosing scope; their meaning is not visible
     here.
     """
     return GNvp(c, z, params, v).data / bias_correction2