Example #1
def perturb_model(model, sigma, n_pars, use_cuda):
    # draw a single N(0, sigma^2) perturbation over all n_pars parameters,
    # reshape it to match the model's parameter tensors, and add it in place
    perturb = torch.randn(n_pars) * sigma
    if use_cuda:
        perturb = perturb.cuda()
    perturb = utils.unflatten_like(perturb.unsqueeze(0), model.parameters())

    for i, par in enumerate(model.parameters()):
        par.data = par.data + perturb[i]

    return
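A minimal usage sketch for this helper, assuming `utils` is the module providing `unflatten_like`; the toy model below is hypothetical and exists only to exercise the call:

import torch
import torch.nn as nn

# throwaway model, just to give perturb_model something to perturb
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
n_pars = sum(p.numel() for p in model.parameters())

# add one draw of N(0, 0.01^2) noise to every parameter, in place
perturb_model(model, sigma=0.01, n_pars=n_pars,
              use_cuda=torch.cuda.is_available())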
Example #2
    def hess_vec_prod(vec):
        # matrix-free Hessian-vector oracle: reshape the flat column vector to
        # match the model's parameters, run one HVP, return a flat column
        hess_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        eval_hess_vec_prod(vec, params, net, criterion, dataloader, use_cuda)
        prod_time = time.time() - start_time
        if verbose and rank == 0:
            print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)
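This closure is shaped as a matrix-free Hessian oracle, so anything that only needs Hessian-vector products can consume it. A sketch of plugging it into SciPy's Lanczos eigensolver, assuming `net` from the surrounding scope and following the column-vector convention above (the `numpy_hvp` wrapper is an assumption, not the repository's driver code):

import torch
from scipy.sparse.linalg import LinearOperator, eigsh

N = sum(p.numel() for p in net.parameters())

def numpy_hvp(v):
    # SciPy hands over a flat float64 array; hess_vec_prod expects a column
    v_col = torch.from_numpy(v).float().unsqueeze(-1)
    return hess_vec_prod(v_col).squeeze(-1).cpu().numpy()

op = LinearOperator((N, N), matvec=numpy_hvp)
evals, evecs = eigsh(op, k=10)  # ten extremal Hessian eigenpairs via Lanczos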
Example #3
    def hvp(rhs):
        padded_rhs = torch.zeros(total_pars,
                                 rhs.shape[-1],
                                 device=rhs.device,
                                 dtype=rhs.dtype)
        # copy rhs into the leading rows of the zero-padded buffer (this
        # assumes the reduced coordinates occupy the leading block), so the
        # HVP below acts on rhs rather than on all zeros
        padded_rhs[:rhs.shape[0]] = rhs

        padded_rhs = unflatten_like(padded_rhs.t(), model.parameters())
        eval_hess_vec_prod(padded_rhs,
                           net=model,
                           criterion=loss,
                           inputs=train_x,
                           targets=train_y,
                           dataloader=loader,
                           use_cuda=use_cuda)
        full_hvp = gradtensor_to_tensor(model, include_bn=True)
        return full_hvp.unsqueeze(-1)
Example #4
    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable; read below
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        eval_hess_vec_prod(vec,
                           params,
                           net,
                           criterion,
                           inputs=inputs,
                           targets=targets,
                           dataloader=dataloader,
                           use_cuda=use_cuda)
        prod_time = time.time() - start_time
        if verbose:
            print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)
Example #5
    def hess_vec_prod(vec):
        padded_rhs = torch.zeros(N, vec.shape[-1],
                                 device=vec.device, dtype=vec.dtype)
        padded_rhs[mask == 1] = vec

        print("vec shape = ", vec.shape)
        print("padded shape = ", padded_rhs.shape)
        hess_vec_prod.count += 1  # simulates a static variable
        padded_rhs = unflatten_like(padded_rhs.t(), net.parameters())

        start_time = time.time()
        eval_hess_vec_prod(padded_rhs, net=net, criterion=criterion,
                           dataloader=dataloader,
                           use_cuda=use_cuda)
        prod_time = time.time() - start_time
        out = gradtensor_to_tensor(net, include_bn=True)

        sliced = out[mask == 1].unsqueeze(-1)
        print("sliced shape = ", sliced.shape)
        return sliced
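The snippet assumes a 0/1 `mask` over the flattened parameter vector that selects the coordinates the solver works in; how that mask is built is not shown. One plausible construction, keeping only the largest-magnitude 10% of weights (purely an assumption for illustration):

import torch

flat = torch.cat([p.detach().reshape(-1) for p in net.parameters()])
N = flat.numel()
k = max(1, int(0.1 * N))

# 0/1 selector over the flattened parameters
mask = torch.zeros(N)
mask[flat.abs().topk(k).indices] = 1.0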
Example #6
def get_loss_surface(basis,
                     model,
                     dataloader,
                     criterion,
                     rng=0.1,
                     n_pts=25,
                     use_cuda=False):
    """
    note that loss should be a lambda function that just takes in the model!
    """

    start_pars = model.state_dict()
    ## get out the plane ##
    dir1, dir2 = get_plane(basis)

    ## init loss surface and the vector multipliers ##
    loss_surf = torch.zeros(n_pts, n_pts)
    vec_len = torch.linspace(-rng / 2., rng / 2., n_pts)

    ## loop and get loss at each point ##
    for ii in range(n_pts):
        for jj in range(n_pts):
            perturb = dir1.mul(vec_len[ii]) + dir2.mul(vec_len[jj])
            perturb = utils.unflatten_like(perturb.t(), model.parameters())
            for i, par in enumerate(model.parameters()):
                if use_cuda:
                    par.data = par.data + perturb[i].cuda()
                else:
                    par.data = par.data + perturb[i]

            loss_surf[ii, jj] = loss_getter(model, dataloader, criterion,
                                            use_cuda)

            # restore the unperturbed weights; safe because the update above
            # rebinds par.data rather than mutating it in place, so start_pars
            # still holds the original tensors
            model.load_state_dict(start_pars)

    return loss_surf
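A sketch of visualizing the returned grid, assuming `basis` holds the two flattened directions that `get_plane` splits apart; the plotting choices are illustrative, not from the source:

import torch
import matplotlib.pyplot as plt

surf = get_loss_surface(basis, model, dataloader, criterion,
                        rng=0.1, n_pts=25, use_cuda=False)

# each axis spans [-rng/2, rng/2], matching vec_len inside the function
ticks = torch.linspace(-0.05, 0.05, 25)
X, Y = torch.meshgrid(ticks, ticks, indexing="ij")
plt.contourf(X.numpy(), Y.numpy(), surf.numpy(), levels=30)
plt.colorbar(label="loss")
plt.xlabel("dir1")
plt.ylabel("dir2")
plt.show()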
Example #7
def main():

    ## generate model and load in trained instance ##
    use_cuda = torch.cuda.is_available()
    model = Net()

    saved_model = torch.load("./model.pt", map_location="cpu")
    model.load_state_dict(saved_model)
    if use_cuda:
        model = model.cuda()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root='/datasets/cifar10/',
                                            train=True,
                                            download=False,
                                            transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=32,
                                              shuffle=False,
                                              num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='/datasets/cifar10/',
                                           train=False,
                                           download=False,
                                           transform=transform)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=32,
                                             shuffle=False,
                                             num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    ## load in eigenpairs and clean up ##
    fpath = "./"
    fname = "cifar_evecs_200.pt"
    evecs = torch.load(fpath + fname, map_location="cpu").squeeze()

    fname = "cifar_evals_200.pt"
    evals = torch.load(fpath + fname, map_location="cpu")

    keep = np.where(evals != 1)[0]
    n_evals = keep.size
    evals = evals[keep].numpy()
    evecs = evecs[:, keep].numpy()

    idx = np.abs(evals).argsort()[::-1]
    evals = torch.FloatTensor(evals[idx])
    evecs = torch.FloatTensor(evecs[:, idx])

    pars = utils.flatten(model.parameters())
    n_par = pars.numel()
    par_len = pars.norm()

    criterion = nn.CrossEntropyLoss()

    ## going to need the original model predictions ##
    model_preds = []
    with torch.no_grad():  # inference only, no graph needed
        for inputs, labels in trainloader:
            if use_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            model_preds.append(predicted.cpu())

    n_scale = 20
    n_trial = 10
    scales = torch.linspace(0, 1., n_scale)

    ## Test high curvature directions ##
    high_curve_losses = torch.zeros(n_scale, n_trial)
    n_diff_high = torch.zeros(n_scale, n_trial)
    for ii in range(n_scale):
        for tt in range(n_trial):
            alpha = torch.randn(n_evals)
            pert = evecs.matmul(alpha.unsqueeze(-1)).t()
            pert = scales[ii] * pert.div(pert.norm())
            if use_cuda:
                pert = pert.cuda()
            pert = utils.unflatten_like(pert, model.parameters())

            ## perturb ##
            for i, par in enumerate(model.parameters()):
                par.data = par.data + pert[i]

            ## compute the loss and label diffs ##
            train_loss, train_diff = compute_loss_differences(
                model, trainloader, criterion, model_preds)
            high_curve_losses[ii, tt] = train_loss
            n_diff_high[ii, tt] = train_diff

            ## need to reload pars after each perturbation ##
            model.load_state_dict(saved_model)

        ## just to track progress ##
        print("high curve scale {} of {} done".format(ii, n_scale))

    ## save the high curvature results ##
    fpath = "./"
    fname = "high_curve_losses.pt"
    torch.save(high_curve_losses, fpath + fname)

    fname = "n_diff_high.pt"
    torch.save(n_diff_high, fpath + fname)
    print("all high curvature done \n\n")

    ## go through the low curvature directions ##
    low_curve_losses = torch.zeros(n_scale, n_trial)
    n_diff_low = torch.zeros(n_scale, n_trial)
    for ii in range(n_scale):
        for tt in range(n_trial):
            alpha = torch.randn(n_par)  # random direction in parameter space
            pert = gram_schmidt(alpha,
                                evecs).unsqueeze(-1).t()  # orthogonal to evecs
            pert = scales[ii] * pert.div(pert.norm())  # scaled correctly

            if use_cuda:
                pert = pert.cuda()
            pert = utils.unflatten_like(pert, model.parameters())

            ## go through trainloader and keep track of losses/differences in preds ##
            for i, par in enumerate(model.parameters()):
                par.data = par.data + pert[i]

            ## compute the loss and label diffs ##
            train_loss, train_diff = compute_loss_differences(
                model, trainloader, criterion, model_preds)
            low_curve_losses[ii, tt] = train_loss
            n_diff_low[ii, tt] = train_diff

            model.load_state_dict(saved_model)

        print("low curve scale {} of {} done".format(ii, n_scale))

    ## save the low curvature results ##
    fpath = "./"
    fname = "low_curve_losses.pt"
    torch.save(low_curve_losses, fpath + fname)

    fname = "n_diff_low.pt"
    torch.save(n_diff_low, fpath + fname)

    print("\n\n Gucci.")
Example #8
def main():

    ##########################
    ## SET UP TRAINING DATA ##
    ##########################
    X, Y = twospirals(500, noise=1.5)
    train_x, train_y = torch.FloatTensor(X), torch.FloatTensor(Y).unsqueeze(-1)

    test_X, test_Y = twospirals(100, 1.5)
    test_x, test_y = torch.FloatTensor(test_X), torch.FloatTensor(
        test_Y).unsqueeze(-1)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(2)
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        train_x, train_y = train_x.cuda(), train_y.cuda()
        # move the test split too, otherwise the final evaluation mixes
        # CPU inputs with CUDA weights
        test_x, test_y = test_x.cuda(), test_y.cuda()

    #############################
    ## SOME HYPERS AND STORAGE ##
    #############################
    widths = [i for i in range(5, 6)]
    loss_func = torch.nn.BCEWithLogitsLoss()

    in_dim = 2
    out_dim = 1

    hessians = []
    n_pars = []
    test_errors = []
    ###############
    ## MAIN LOOP ##
    ###############
    for width_ind, width in enumerate(widths):
        model = hess.nets.SimpleNet(in_dim,
                                    out_dim,
                                    n_hidden=5,
                                    hidden_size=width,  # vary with the loop
                                    activation=torch.nn.ELU(),
                                    bias=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        n_par = sum(p.numel() for p in model.parameters())
        n_pars.append(n_par)

        ## TRAIN MODEL ##
        for step in range(2000):
            optimizer.zero_grad()
            outputs = model(train_x)

            loss = loss_func(outputs, train_y)
            loss.backward()
            optimizer.step()
        print("model %i trained" % width)

        # build the dense Hessian one column at a time: the HVP against the
        # pp-th standard basis vector recovers the pp-th Hessian column
        hessian = torch.zeros(n_par, n_par)
        for pp in range(n_par):
            base_vec = torch.zeros(n_par).unsqueeze(0)
            base_vec[0, pp] = 1.

            base_vec = utils.unflatten_like(base_vec, model.parameters())
            utils.eval_hess_vec_prod(base_vec,
                                     model,
                                     criterion=torch.nn.BCEWithLogitsLoss(),
                                     inputs=train_x,
                                     targets=train_y)
            hessian[:, pp] = utils.gradtensor_to_tensor(model,
                                                        include_bn=True).cpu()

        ## SAVE THOSE OUTPUTS ##
        hessians.append(hessian)
        test_errors.append(loss_func(model(test_x), test_y).item())

    ## SAVE EVERYTHING ##
    fpath = "./"
    fname = "hessians.P"
    with open(fpath + fname, 'wb') as fp:
        pickle.dump(hessians, fp)

    fname = "test_errors.P"
    with open(fpath + fname, 'wb') as fp:
        pickle.dump(test_errors, fp)

    fname = "n_pars.P"
    with open(fpath + fname, 'wb') as fp:
        pickle.dump(n_pars, fp)

    fname = "widths.pt"
    torch.save(widths, fpath + fname)
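A short sketch of consuming the artifacts this script writes: a Hessian assembled column by column from HVPs should come out numerically symmetric, and its spectrum gives the curvature at the trained solution (file names match the ones saved above):

import pickle
import torch

with open("./hessians.P", "rb") as fp:
    hessians = pickle.load(fp)

H = hessians[0]
print("max asymmetry:", (H - H.t()).abs().max().item())
print("largest eigenvalues:", torch.linalg.eigvalsh(H)[-5:])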