示例#1
0
    def hess_vec_prod(vec):
        """Evaluate one operator-vector product and return it as a column.

        The transposed ``vec`` is reshaped by ``unflatten_like`` against
        ``params`` and handed to ``eval_hess_vec_prod``. The return value of
        that call is ignored; the result is read back from ``net`` through
        ``gradtensor_to_tensor`` (presumably the product is left in the
        gradients of ``net`` -- confirm against ``eval_hess_vec_prod``).

        Args:
            vec: Tensor supporting ``.t()``; presumably a flat column
                vector with one entry per element of ``params`` (confirm
                against the caller).

        Returns:
            The tensor produced by ``gradtensor_to_tensor(net)`` with an
            extra dimension inserted at index 1.
        """
        # The function attribute stands in for a static call counter; it is
        # expected to be initialised by the enclosing scope before first use.
        hess_vec_prod.count += 1
        direction = unflatten_like(vec.t(), params)

        tic = time.time()
        eval_hess_vec_prod(direction, params, net, criterion, dataloader, use_cuda)
        elapsed = time.time() - tic

        if verbose:
            if rank == 0:
                report = "   Iter: %d  time: %f" % (hess_vec_prod.count, elapsed)
                print(report)

        return gradtensor_to_tensor(net).unsqueeze(1)
示例#2
0
    def obsfisher_vec_prod(vec):
        """Evaluate one operator-vector product and return it as a column.

        The transposed ``vec`` is reshaped by ``unflatten_like`` against
        ``params`` and passed to ``eval_obsfisher_vec_prod``; the value that
        call returns is reshaped to a single column.

        Args:
            vec: Tensor supporting ``.t()``; presumably a flat column
                vector with one entry per element of ``params`` (confirm
                against the caller).

        Returns:
            The result of ``eval_obsfisher_vec_prod`` viewed as ``(-1, 1)``.
        """
        # The function attribute stands in for a static call counter; it is
        # expected to be initialised by the enclosing scope before first use.
        obsfisher_vec_prod.count += 1
        direction = unflatten_like(vec.t(), params)

        tic = time.time()
        product = eval_obsfisher_vec_prod(
            direction, params, net, criterion, dataloader, use_cuda)
        elapsed = time.time() - tic

        if verbose:
            if rank == 0:
                report = "   Iter: %d  time: %f" % (
                    obsfisher_vec_prod.count, elapsed)
                print(report)

        return product.view(-1, 1)
示例#3
0
    def fisher_vec_prod(vec):
        """Evaluate one operator-vector product and return it unchanged.

        The transposed ``vec`` is reshaped by ``unflatten_like`` against
        ``net.parameters()`` and passed to ``eval_fisher_vec_prod`` together
        with ``fvp_matmul`` from the enclosing scope.

        Args:
            vec: Tensor supporting ``.t()``; presumably a flat column
                vector with one entry per parameter element of ``net``
                (confirm against the caller).

        Returns:
            Whatever ``eval_fisher_vec_prod`` returns, with no reshaping.
        """
        # The function attribute stands in for a static call counter; it is
        # expected to be initialised by the enclosing scope before first use.
        fisher_vec_prod.count += 1
        direction = unflatten_like(vec.t(), net.parameters())

        tic = time.time()
        product = eval_fisher_vec_prod(
            direction, net, dataloader, use_cuda, fvp_matmul=fvp_matmul)
        elapsed = time.time() - tic

        if verbose:
            if rank == 0:
                report = "   Iter: %d  time: %f" % (fisher_vec_prod.count, elapsed)
                print(report)

        return product
示例#4
0
 def update_params(self, vec, model):
     """Overwrite the parameters of ``model`` with the values held in ``vec``.

     ``vec`` is viewed as a single row and split by ``unflatten_like`` into
     one tensor per entry of ``model.parameters()``; each parameter is then
     detached in place and overwritten with its corresponding piece.

     Args:
         vec: Tensor holding the new parameter values. It is viewed as
             ``(1, -1)`` before unflattening, so it presumably contains one
             value per parameter element of ``model`` (confirm against
             ``unflatten_like``).
         model: Object exposing ``parameters()``; its parameters are
             modified in place.

     Returns:
         None.
     """
     # Materialise the parameters once so the same sequence is used both as
     # the shape template and as the assignment target.
     param_list = list(model.parameters())
     vec_list = unflatten_like(likeTensorList=param_list,
                               vector=vec.view(1, -1))
     for param, v in zip(param_list, vec_list):
         param.detach_()
         # copy_ rather than mul_(0.0).add_(v): zeroing by multiplication
         # turns NaN/inf entries into NaN (nan * 0 == nan), so a diverged
         # parameter would never take the requested value.
         param.copy_(v)