Example #1
    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        eval_hess_vec_prod(vec, params, net, criterion, dataloader, use_cuda)
        prod_time = time.time() - start_time
        if verbose and rank == 0:
            print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)
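Closures like hess_vec_prod expose the Hessian only through matrix-vector products, so in practice they are handed to an iterative eigensolver. The following is a minimal usage sketch (not from the original source), assuming the closure and a parameter count n_par as in the example above and a model living on the CPU; it wraps the closure in a SciPy LinearOperator and asks Lanczos (eigsh) for the leading eigenvalue.

import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator, eigsh

def top_hessian_eigenvalue(hess_vec_prod, n_par):
    # eigsh works with NumPy vectors, the closure with torch tensors of
    # shape (n_par, 1), so matvec converts back and forth.
    def matvec(v):
        v_t = torch.from_numpy(v).float().unsqueeze(-1)
        hv = hess_vec_prod(v_t)
        return hv.detach().cpu().numpy().ravel()

    op = LinearOperator((n_par, n_par), matvec=matvec, dtype=np.float32)
    eigvals = eigsh(op, k=1, which="LA", return_eigenvectors=False)
    return float(eigvals.max())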
Example #2
    def hvp(rhs):
        padded_rhs = torch.zeros(total_pars,
                                 rhs.shape[-1],
                                 device=rhs.device,
                                 dtype=rhs.dtype)

        padded_rhs[mask == 1] = rhs  # scatter rhs into the zero-padded buffer (a binary `mask` over the flat parameters is assumed in scope)
        padded_rhs = unflatten_like(padded_rhs.t(), model.parameters())
        eval_hess_vec_prod(padded_rhs,
                           net=model,
                           criterion=loss,
                           inputs=train_x,
                           targets=train_y,
                           dataloader=loader,
                           use_cuda=use_cuda)
        full_hvp = gradtensor_to_tensor(model, include_bn=True)
        return full_hvp.unsqueeze(-1)
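All of these closures delegate the actual product to a helper named eval_hess_vec_prod. Below is a minimal sketch of how such a Hessian-vector product can be computed with double backpropagation (Pearlmutter's trick); it illustrates the technique, not the library's own implementation, whose signature also covers dataloaders and CUDA placement.

import torch

def eval_hess_vec_prod_sketch(vec_list, net, criterion, inputs, targets):
    # vec_list holds one tensor per parameter, shaped like that parameter.
    net.zero_grad()
    params = [p for p in net.parameters() if p.requires_grad]
    loss = criterion(net(inputs), targets)
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Differentiate the gradient-vector dot product a second time.
    dot = sum((g * v).sum() for g, v in zip(grads, vec_list))
    hvs = torch.autograd.grad(dot, params)
    # Stash the product in .grad so gradtensor_to_tensor(net) can read it.
    for p, hv in zip(params, hvs):
        p.grad = hv.detach()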
Example #3
    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        eval_hess_vec_prod(vec,
                           params,
                           net,
                           criterion,
                           inputs=inputs,
                           targets=targets,
                           dataloader=dataloader,
                           use_cuda=use_cuda)
        prod_time = time.time() - start_time
        if verbose:
            print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)
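The snippets also rely on two small utilities, unflatten_like and gradtensor_to_tensor. The sketches below are inferred from how they are called above and are only assumptions about their behaviour; the library versions may differ, for instance in how batch-norm parameters are filtered.

import torch

def unflatten_like(vector, like_tensor_list):
    # Split a (1, n_par) row vector into tensors shaped like each parameter
    # (for simplicity this sketch handles a single row at a time).
    out, i = [], 0
    for tensor in like_tensor_list:
        n = tensor.numel()
        out.append(vector[:, i:i + n].reshape(tensor.shape))
        i += n
    return out

def gradtensor_to_tensor(net, include_bn=False):
    # Concatenate every parameter gradient into one flat vector; skipping
    # 1-D parameters (biases / batch norm) unless include_bn is set is a
    # common convention and is assumed here.
    keep = lambda p: p.grad is not None and (include_bn or p.dim() > 1)
    return torch.cat([p.grad.reshape(-1) for p in net.parameters() if keep(p)])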
Example #4
    def hess_vec_prod(vec):
        padded_rhs = torch.zeros(N, vec.shape[-1],
                                 device=vec.device, dtype=vec.dtype)
        padded_rhs[mask == 1] = vec

        print("vec shape = ", vec.shape)
        print("padded shape = ", padded_rhs.shape)
        hess_vec_prod.count += 1  # simulates a static variable
        padded_rhs = unflatten_like(padded_rhs.t(), net.parameters())

        start_time = time.time()
        eval_hess_vec_prod(padded_rhs, net=net, criterion=criterion,
                           dataloader=dataloader,
                           use_cuda=use_cuda)
        prod_time = time.time() - start_time
        out = gradtensor_to_tensor(net, include_bn=True)
        
        sliced = out[mask==1].unsqueeze(-1)
        print("sliced shape = ", sliced.shape)
        return sliced
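Example #4 restricts the product to a subset of coordinates selected by a binary mask. The setup below is a hedged sketch of the pieces that closure captures; N, mask, and the count attribute are assumptions consistent with how they are used above, and net stands for the model the closure closes over.

import torch

device = next(net.parameters()).device
N = sum(p.numel() for p in net.parameters())               # length of the flat parameter vector
mask = torch.zeros(N, device=device)
mask[torch.randperm(N, device=device)[: N // 10]] = 1.0    # e.g. keep a random 10% of coordinates
hess_vec_prod.count = 0                                    # the closure increments this on every call

The closure then maps a vector of length int(mask.sum()) to the corresponding entries of the full Hessian-vector product.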
def main():

    ##########################
    ## SET UP TRAINING DATA ##
    ##########################
    X, Y = twospirals(500, noise=1.5)
    train_x, train_y = torch.FloatTensor(X), torch.FloatTensor(Y).unsqueeze(-1)

    test_X, test_Y = twospirals(100, 1.5)
    test_x, test_y = torch.FloatTensor(test_X), torch.FloatTensor(
        test_Y).unsqueeze(-1)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(2)
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        train_x, train_y = train_x.cuda(), train_y.cuda()
        test_x, test_y = test_x.cuda(), test_y.cuda()

    #############################
    ## SOME HYPERS AND STORAGE ##
    #############################
    widths = [i for i in range(5, 6)]
    loss_func = torch.nn.BCEWithLogitsLoss()

    in_dim = 2
    out_dim = 1

    hessians = []
    n_pars = []
    test_errors = []
    ###############
    ## MAIN LOOP ##
    ###############
    for width_ind, width in enumerate(widths):
        model = hess.nets.SimpleNet(in_dim,
                                    out_dim,
                                    n_hidden=5,
                                    hidden_size=20,
                                    activation=torch.nn.ELU(),
                                    bias=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        n_par = sum(p.numel() for p in model.parameters())
        n_pars.append(n_par)

        ## TRAIN MODEL ##
        for step in range(2000):
            optimizer.zero_grad()
            outputs = model(train_x)

            loss = loss_func(outputs, train_y)
            loss.backward()
            optimizer.step()
        print("model %i trained" % width)

        hessian = torch.zeros(n_par, n_par)
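        # Build the dense Hessian one column at a time: the HVP with unit vector e_pp gives column pp.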
        for pp in range(n_par):
            base_vec = torch.zeros(n_par).unsqueeze(0)
            base_vec[0, pp] = 1.

            base_vec = utils.unflatten_like(base_vec, model.parameters())
            utils.eval_hess_vec_prod(base_vec,
                                     model,
                                     criterion=torch.nn.BCEWithLogitsLoss(),
                                     inputs=train_x,
                                     targets=train_y)
            if pp == 0:
                output = utils.gradtensor_to_tensor(model, include_bn=True)
                hessian = torch.zeros(output.nelement(), output.nelement())
                hessian[:, pp] = output

            hessian[:, pp] = utils.gradtensor_to_tensor(model,
                                                        include_bn=True).cpu()

        ## SAVE THOSE OUTPUTS ##
        hessians.append(hessian)
        test_errors.append(loss_func(model(test_x), test_y).item())

    ## SAVE EVERYTHING ##
    fpath = "./"
    fname = "hessians.P"
    with open(fpath + fname, 'wb') as fp:
        pickle.dump(hessians, fp)

    fname = "test_errors.P"
    with open(fpath + fname, 'wb') as fp:
        pickle.dump(test_errors, fp)

    fname = "n_pars.P"
    with open(fpath + fname, 'wb') as fp:
        pickle.dump(n_pars, fp)

    fname = "widths.pt"
    torch.save(widths, fpath + fname)
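Once main() has run, the pickled Hessians can be loaded back and their spectra inspected directly. A short post-processing sketch using the file names written above (the Hessian is symmetric up to round-off, so eigvalsh applies after symmetrizing):

import pickle
import torch

with open("./hessians.P", "rb") as fp:
    hessians = pickle.load(fp)

for h in hessians:
    h = h.cpu()
    h_sym = 0.5 * (h + h.t())                  # symmetrize against round-off asymmetry
    eigvals = torch.linalg.eigvalsh(h_sym)     # eigenvalues in ascending order
    print("min / max eigenvalue:", eigvals[0].item(), eigvals[-1].item())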