Example #1
def test_from_dict_to_pvector():
    eps = 1e-8
    model = ConvNet()
    v = PVector.from_model(model)
    d1 = v.get_dict_representation()
    v2 = PVector(v.layer_collection, vector_repr=v.get_flat_representation())
    d2 = v2.get_dict_representation()
    assert d1.keys() == d2.keys()
    for k in d1.keys():
        assert torch.norm(d1[k][0] - d2[k][0]) < eps
        if len(d1[k]) == 2:
            assert torch.norm(d1[k][1] - d2[k][1]) < eps
Example #2
def test_grad_dict_repr():
    loader, lc, parameters, model, function, n_output = get_conv_gn_task()

    d, _ = next(iter(loader))
    scalar_output = model(to_device(d)).sum()
    vec = PVector.from_model(model)

    grad_nng = grad(scalar_output, vec, retain_graph=True)

    scalar_output.backward()
    grad_direct = PVector.from_model_grad(model)

    check_tensors(grad_direct.get_flat_representation(),
                  grad_nng.get_flat_representation())
Example #3
def test_detach():
    eps = 1e-8
    model = ConvNet()
    pvec = PVector.from_model(model)
    pvec_clone = pvec.clone()

    # first, backprop through the clone: gradients accumulate on the
    # original leaf parameters, while the cloned tensors get no .grad
    loss = torch.norm(pvec_clone.get_flat_representation())
    loss.backward()
    pvec_clone_dict = pvec_clone.get_dict_representation()
    pvec_dict = pvec.get_dict_representation()
    for layer_id, layer in pvec.layer_collection.layers.items():
        assert torch.norm(pvec_dict[layer_id][0].grad) > eps
        assert pvec_clone_dict[layer_id][0].grad is None
        pvec_dict[layer_id][0].grad.zero_()
        if layer.bias is not None:
            assert torch.norm(pvec_dict[layer_id][1].grad) > eps
            assert pvec_clone_dict[layer_id][1].grad is None
            pvec_dict[layer_id][1].grad.zero_()

    # second, check that no gradient reaches the original parameters
    # when the PVector is detached before use
    y = torch.tensor(1., requires_grad=True)  # keeps loss differentiable
    loss = torch.norm(pvec.detach().get_flat_representation()) + y
    loss.backward()
    for layer_id, layer in pvec.layer_collection.layers.items():
        assert torch.norm(pvec_dict[layer_id][0].grad) < eps
        if layer.bias is not None:
            assert torch.norm(pvec_dict[layer_id][1].grad) < eps
Example #4
def test_from_to_model():
    model1 = ConvNet()
    model2 = ConvNet()

    w1 = PVector.from_model(model1).clone()
    w2 = PVector.from_model(model2).clone()

    model3 = ConvNet()
    w1.copy_to_model(model3)
    # now model1 and model3 should be the same

    for p1, p3 in zip(model1.parameters(), model3.parameters()):
        check_tensors(p1, p3)

    # apply the difference w2 - w1 on top of model3 (currently equal to model1)
    diff_1_2 = w2 - w1
    diff_1_2.add_to_model(model3)
    # now model2 and model3 should be the same

    for p2, p3 in zip(model2.parameters(), model3.parameters()):
        check_tensors(p2, p3)
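
A hedged extension of the same pattern: since parameter differences are themselves PVectors, interpolating between two models in parameter space takes only a few lines. This sketch assumes PVector also supports addition and scalar multiplication, which the test above does not exercise; interpolate_models is a hypothetical helper.

def interpolate_models(model_a, model_b, t=0.5):
    # hypothetical helper: assumes PVector implements + and scalar *,
    # in addition to the - exercised by the test above
    w_a = PVector.from_model(model_a).clone()
    w_b = PVector.from_model(model_b).clone()
    target = ConvNet()
    # w_a + t * (w_b - w_a) equals w_a at t=0 and w_b at t=1
    (w_a + t * (w_b - w_a)).copy_to_model(target)
    return target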
Example #5
def test_PVector_pickle():
    _, _, _, model, _, _ = get_conv_task()

    vec = PVector.from_model(model)

    with open('/tmp/PVec.pkl', 'wb') as f:
        pkl.dump(vec, f)

    with open('/tmp/PVec.pkl', 'rb') as f:
        vec_pkl = pkl.load(f)

    check_tensors(vec.get_flat_representation(),
                  vec_pkl.get_flat_representation())
Example #6
def test_pspace_kfac_eigendecomposition():
    """
    Check KFAC eigendecomposition by comparing Mv products with v
    where v are the top eigenvectors. The remaining ones can be
    more erratic because of numerical precision
    """
    eps = 1e-3
    loader, lc, parameters, model, function, n_output = get_fullyconnect_task()

    generator = Jacobian(layer_collection=lc,
                         model=model,
                         loader=loader,
                         function=function,
                         n_output=n_output)

    M_kfac = PMatKFAC(generator)
    M_kfac.compute_eigendecomposition()
    evals, evecs = M_kfac.get_eigendecomposition()
    # Loop through all vectors in KFE
    l_to_m, _ = lc.get_layerid_module_maps(model)
    for l_id, layer in lc.layers.items():
        for i_a in range(-3, 0):
            for i_g in range(-3, 0):
                evec_v = dict()
                for l_id2, layer2 in lc.layers.items():
                    m = l_to_m[l_id2]
                    if l_id2 == l_id:
                        v_a = evecs[l_id][0][:, i_a].unsqueeze(0)
                        v_g = evecs[l_id][1][:, i_g].unsqueeze(1)
                        evec_block = kronecker(v_g, v_a)
                        evec_tuple = (evec_block[:, :-1].contiguous(),
                                      evec_block[:, -1].contiguous())
                        evec_v[l_id] = evec_tuple
                    else:
                        evec_v[l_id2] = (torch.zeros_like(m.weight),
                                         torch.zeros_like(m.bias))
                evec_v = PVector(lc, dict_repr=evec_v)
                Mv = M_kfac.mv(evec_v)
                angle_v_Mv = angle(Mv, evec_v)
                assert 1 - eps < angle_v_Mv < 1 + eps
                norm_mv = torch.norm(Mv.get_flat_representation())
                check_ratio(evals[l_id][0][i_a] * evals[l_id][1][i_g], norm_mv)
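
The identity this test relies on: eigenpairs of a Kronecker product are Kronecker products of the factors' eigenpairs, i.e. if A a = alpha a and G g = gamma g, then (G kron A)(g kron a) = gamma alpha (g kron a). A minimal self-contained check of that property in plain torch (independent of the library):

import torch

A = torch.randn(4, 4, dtype=torch.float64)
A = A @ A.t()                                   # symmetric factor, as in KFAC
G = torch.randn(3, 3, dtype=torch.float64)
G = G @ G.t()
evals_a, evecs_a = torch.linalg.eigh(A)
evals_g, evecs_g = torch.linalg.eigh(G)

v = torch.kron(evecs_g[:, -1], evecs_a[:, -1])  # top Kronecker eigenvector
Mv = torch.kron(G, A) @ v
assert torch.allclose(Mv, evals_g[-1] * evals_a[-1] * v)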
Example #7
def test_clone():
    eps = 1e-8
    model = ConvNet()
    pvec = PVector.from_model(model)
    pvec_clone = pvec.clone()
    l_to_m, _ = pvec.layer_collection.get_layerid_module_maps(model)

    for layer_id, layer in pvec.layer_collection.layers.items():
        m = l_to_m[layer_id]
        assert m.weight is pvec.get_dict_representation()[layer_id][0]
        assert (m.weight is not
                pvec_clone.get_dict_representation()[layer_id][0])
        assert (torch.norm(m.weight -
                           pvec_clone.get_dict_representation()[layer_id][0])
                < eps)
        if m.bias is not None:
            assert m.bias is pvec.get_dict_representation()[layer_id][1]
            assert (m.bias is not
                    pvec_clone.get_dict_representation()[layer_id][1])
            assert (torch.norm(m.bias -
                               pvec_clone.get_dict_representation()[layer_id]
                               [1])
                    < eps)
Example #8
def grad(output, vec, *args, **kwargs):
    """
    Computes the gradient of `output` with respect to the `PVector` `vec`

    .. warning:: This function only works when `vec` has internally been
        created from leaf nodes of the graph (e.g. model parameters)

    :param output: The scalar quantity to be differentiated
    :param vec: a `PVector`
    :return: a `PVector` of gradients of `output` w.r.t `vec`
    """
    if vec.dict_repr is not None:
        # map all parameters to a list
        params = []
        pos = []
        lengths = []
        current_pos = 0
        for k in vec.dict_repr.keys():
            p = vec.dict_repr[k]
            params += list(p)
            pos.append(current_pos)
            lengths.append(len(p))
            current_pos = current_pos + len(p)

        grad_list = torch.autograd.grad(output, params, *args, **kwargs)
        dict_repr_grad = dict()

        for k, p, l in zip(vec.dict_repr.keys(), pos, lengths):
            if l == 1:
                dict_repr_grad[k] = (grad_list[p], )
            elif l == 2:
                dict_repr_grad[k] = (grad_list[p], grad_list[p + 1])

        return PVector(vec.layer_collection, dict_repr=dict_repr_grad)
    else:
        raise RuntimeError('grad only works when the vector is created ' +
                           'from leaf nodes in the computation graph')
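
A minimal usage sketch, mirroring test_grad_dict_repr above. It assumes model and a batch of inputs d (already on the right device) exist; the PVector must wrap the model's leaf parameters, e.g. built with PVector.from_model, not a detached clone:

vec = PVector.from_model(model)          # wraps the leaf parameters
scalar_output = model(d).sum()           # any scalar function of the output
g = grad(scalar_output, vec, retain_graph=True)
# the same gradients via plain autograd:
scalar_output.backward()
g_direct = PVector.from_model_grad(model)
# g and g_direct now hold identical flat representations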
Example #9
def test_jacobian_pblockdiag():
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_blockdiag = PMatBlockDiag(generator)
        dw = random_pvector(lc, device=device)
        dense_tensor = PMat_blockdiag.get_dense_tensor()

        # Test get_diag
        check_tensors(torch.diag(dense_tensor),
                      PMat_blockdiag.get_diag())

        # Test frobenius
        frob_PMat = PMat_blockdiag.frobenius_norm()
        frob_direct = (dense_tensor**2).sum()**.5
        check_ratio(frob_direct, frob_PMat)

        # Test trace
        trace_PMat = PMat_blockdiag.trace()
        trace_direct = torch.trace(dense_tensor)
        check_ratio(trace_PMat, trace_direct)

        # Test mv
        mv_direct = torch.mv(dense_tensor, dw.get_flat_representation())
        check_tensors(mv_direct,
                      PMat_blockdiag.mv(dw).get_flat_representation())

        # Test vTMV
        check_ratio(torch.dot(mv_direct, dw.get_flat_representation()),
                    PMat_blockdiag.vTMv(dw))

        # Test solve
        regul = 1e-3
        Mv_regul = torch.mv(dense_tensor +
                            regul * torch.eye(PMat_blockdiag.size(0),
                                              device=device),
                            dw.get_flat_representation())
        Mv_regul = PVector(layer_collection=lc,
                           vector_repr=Mv_regul)
        dw_using_inv = PMat_blockdiag.solve(Mv_regul, regul=regul)
        check_tensors(dw.get_flat_representation(),
                      dw_using_inv.get_flat_representation(), eps=5e-3)

        # Test inv
        PMat_inv = PMat_blockdiag.inverse(regul=regul)
        check_tensors(dw.get_flat_representation(),
                      PMat_inv.mv(PMat_blockdiag.mv(dw) + regul * dw)
                      .get_flat_representation(), eps=5e-3)

        # Test add, sub, rmul
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_blockdiag2 = PMatBlockDiag(generator)

        check_tensors(PMat_blockdiag.get_dense_tensor() +
                      PMat_blockdiag2.get_dense_tensor(),
                      (PMat_blockdiag + PMat_blockdiag2)
                      .get_dense_tensor())
        check_tensors(PMat_blockdiag.get_dense_tensor() -
                      PMat_blockdiag2.get_dense_tensor(),
                      (PMat_blockdiag - PMat_blockdiag2)
                      .get_dense_tensor())
        check_tensors(1.23 * PMat_blockdiag.get_dense_tensor(),
                      (1.23 * PMat_blockdiag).get_dense_tensor())
Example #10
def test_jacobian_pdiag_vs_pdense():
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_diag = PMatDiag(generator)
        PMat_dense = PMatDense(generator)
        dw = random_pvector(lc, device=device)

        # Test get_dense_tensor
        matrix_diag = PMat_diag.get_dense_tensor()
        matrix_dense = PMat_dense.get_dense_tensor()
        check_tensors(torch.diag(matrix_diag),
                      torch.diag(matrix_dense))
        assert torch.norm(matrix_diag -
                          torch.diag(torch.diag(matrix_diag))) < 1e-5

        # Test trace
        check_ratio(torch.trace(matrix_diag),
                    PMat_diag.trace())

        # Test frobenius
        check_ratio(torch.norm(matrix_diag),
                    PMat_diag.frobenius_norm())

        # Test mv
        mv_direct = torch.mv(matrix_diag, dw.get_flat_representation())
        mv_PMat_diag = PMat_diag.mv(dw)
        check_tensors(mv_direct,
                      mv_PMat_diag.get_flat_representation())

        # Test vTMv
        vTMv_direct = torch.dot(mv_direct, dw.get_flat_representation())
        vTMv_PMat_diag = PMat_diag.vTMv(dw)
        check_ratio(vTMv_direct, vTMv_PMat_diag)

        # Test inverse
        regul = 1e-3
        PMat_diag_inverse = PMat_diag.inverse(regul)
        prod = torch.mm(matrix_diag + regul * torch.eye(lc.numel(),
                                                        device=device),
                        PMat_diag_inverse.get_dense_tensor())
        check_tensors(torch.eye(lc.numel(), device=device),
                      prod)

        # Test solve
        regul = 1e-3
        Mv_regul = torch.mv(matrix_diag +
                            regul * torch.eye(PMat_diag.size(0),
                                              device=device),
                            dw.get_flat_representation())
        Mv_regul = PVector(layer_collection=lc,
                           vector_repr=Mv_regul)
        dw_using_inv = PMat_diag.solve(Mv_regul, regul=regul)
        check_tensors(dw.get_flat_representation(),
                      dw_using_inv.get_flat_representation(), eps=5e-3)

        # Test get_diag
        diag_direct = torch.diag(matrix_diag)
        diag_PMat_diag = PMat_diag.get_diag()
        check_tensors(diag_direct, diag_PMat_diag)

        # Test add, sub, rmul
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_diag2 = PMatDiag(generator)

        check_tensors(PMat_diag.get_dense_tensor() +
                      PMat_diag2.get_dense_tensor(),
                      (PMat_diag + PMat_diag2).get_dense_tensor())
        check_tensors(PMat_diag.get_dense_tensor() -
                      PMat_diag2.get_dense_tensor(),
                      (PMat_diag - PMat_diag2).get_dense_tensor())
        check_tensors(1.23 * PMat_diag.get_dense_tensor(),
                      (1.23 * PMat_diag).get_dense_tensor())
Example #11
def test_jacobian_pdense():
    for get_task in nonlinear_tasks:
        for centering in [True, False]:
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc,
                                 model=model,
                                 loader=loader,
                                 function=function,
                                 n_output=n_output,
                                 centering=centering)
            PMat_dense = PMatDense(generator)
            dw = random_pvector(lc, device=device)

            # Test get_diag
            check_tensors(torch.diag(PMat_dense.get_dense_tensor()),
                          PMat_dense.get_diag())

            # Test frobenius
            frob_PMat = PMat_dense.frobenius_norm()
            frob_direct = (PMat_dense.get_dense_tensor()**2).sum()**.5
            check_ratio(frob_direct, frob_PMat)

            # Test trace
            trace_PMat = PMat_dense.trace()
            trace_direct = torch.trace(PMat_dense.get_dense_tensor())
            check_ratio(trace_PMat, trace_direct)

            # Test solve
            # NB: regul is high since the matrix is not full rank
            regul = 1e-3
            Mv_regul = torch.mv(PMat_dense.get_dense_tensor() +
                                regul * torch.eye(PMat_dense.size(0),
                                                  device=device),
                                dw.get_flat_representation())
            Mv_regul = PVector(layer_collection=lc,
                               vector_repr=Mv_regul)
            dw_using_inv = PMat_dense.solve(Mv_regul, regul=regul)
            check_tensors(dw.get_flat_representation(),
                          dw_using_inv.get_flat_representation(), eps=5e-3)

            # Test inv
            PMat_inv = PMat_dense.inverse(regul=regul)
            check_tensors(dw.get_flat_representation(),
                          PMat_inv.mv(PMat_dense.mv(dw) + regul * dw)
                          .get_flat_representation(), eps=5e-3)

            # Test add, sub, rmul
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc,
                                 model=model,
                                 loader=loader,
                                 function=function,
                                 n_output=n_output,
                                 centering=centering)
            PMat_dense2 = PMatDense(generator)

            check_tensors(PMat_dense.get_dense_tensor() +
                          PMat_dense2.get_dense_tensor(),
                          (PMat_dense + PMat_dense2).get_dense_tensor())
            check_tensors(PMat_dense.get_dense_tensor() -
                          PMat_dense2.get_dense_tensor(),
                          (PMat_dense - PMat_dense2).get_dense_tensor())
            check_tensors(1.23 * PMat_dense.get_dense_tensor(),
                          (1.23 * PMat_dense).get_dense_tensor())
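
The solve and mv operations exercised above are the ingredients of a natural-gradient-style update. A hedged sketch reusing the same task setup; the step size lr is hypothetical, and the final line assumes PVector supports scalar multiplication:

generator = Jacobian(layer_collection=lc, model=model, loader=loader,
                     function=function, n_output=n_output)
F = PMatDense(generator)                 # Fisher/GGN-like metric
inputs, _ = next(iter(loader))           # assumed already on the right device
model.zero_grad()
model(inputs).sum().backward()           # any scalar loss would do here
g = PVector.from_model_grad(model)
step = F.solve(g, regul=1e-3)            # damped solve: (F + regul*I)^{-1} g
lr = 0.1                                 # hypothetical step size
(-lr * step).add_to_model(model)         # assumes scalar * PVector exists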
Example #12
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model = model.to(device)

if True:
    # print('load')
    # id_epoch = ''
    # model.load_state_dict(torch.load('/home/pezeshki/scratch/dd/Deep-Double-Descent/runs2/cifar10/resnet_' + str(int(label_noise*100)) + '_k' + str(k) + '/ckpt' + str(id_epoch) + '.pkl')['net'])

    # flat_params = []
    # for p in model.parameters():
    #     flat_params += [p.view(-1)]
    # flat_params = torch.cat(flat_params)
    flat_params = PVector.from_model(model).get_flat_representation()
    sums = torch.zeros_like(flat_params)        # running sum of parameters
    sums_sqr = torch.zeros_like(flat_params)    # running sum of squares

    model.eval()
    def output_fn(input, target):
        # input = input.to('cuda')
        return model(input)

    layer_collection = LayerCollection.from_model(model)
    layer_collection.numel()  # total parameter count (result unused here)

    # loader = torch.utils.data.DataLoader(
    #     test_data, batch_size=150, shuffle=False, num_workers=0,
    #     drop_last=False)
    loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch_size,
                                         shuffle=True, num_workers=0,
                                         drop_last=False)
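
    # The sums and sums_sqr buffers suggest the script goes on to accumulate
    # first and second moments of the flat parameter vector across
    # checkpoints; a hypothetical continuation under that reading
    # (n_checkpoints and the checkpoint-loading step are assumptions):
    n_checkpoints = 10                            # hypothetical count
    for _ in range(n_checkpoints):
        # (load the next checkpoint into `model` here, as in the commented
        # load_state_dict call above)
        flat = PVector.from_model(model).get_flat_representation().detach()
        sums += flat
        sums_sqr += flat ** 2
    mean = sums / n_checkpoints
    var = sums_sqr / n_checkpoints - mean ** 2    # per-parameter variance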
Example #13
    def implicit_mv(self, v, examples):
        """Implicit matrix-vector product Fv, where F = 1/n sum_x J_x^T J_x
        is estimated over `examples`, computed without materializing F."""
        # add hooks
        self.handles += self._add_hooks(self._hook_savex,
                                        self._hook_compute_Jv,
                                        self.l_to_m.values())

        self._v = v.get_dict_representation()
        parameters = []
        output = dict()
        for layer_id, layer in self.layer_collection.layers.items():
            mod = self.l_to_m[layer_id]
            mod_class = mod.__class__.__name__
            if mod_class in ['BatchNorm1d', 'BatchNorm2d']:
                raise NotImplementedError
            parameters.append(mod.weight)
            output[mod.weight] = torch.zeros_like(mod.weight)
            if layer.bias is not None:
                parameters.append(mod.bias)
                output[mod.bias] = torch.zeros_like(mod.bias)

        device = next(self.model.parameters()).device
        loader = self._get_dataloader(examples)
        n_examples = len(loader.sampler)

        self.i_output = 0
        self.start = 0
        for d in loader:
            inputs = d[0]
            inputs.requires_grad = True
            bs = inputs.size(0)

            f_output = self.function(*d).view(bs, self.n_output)
            for i in range(self.n_output):
                # TODO reuse instead of reallocating memory
                self._Jv = torch.zeros((1, bs), device=device)

                self.compute_switch = True
                # this backward pass triggers the hooks, which accumulate
                # the directional derivative J v into self._Jv
                torch.autograd.grad(f_output[:, i].sum(dim=0), [inputs],
                                    retain_graph=True,
                                    only_inputs=True)
                self.compute_switch = False
                # differentiating <Jv, f> w.r.t. the parameters gives J^T (Jv)
                pseudo_loss = torch.dot(self._Jv[0, :], f_output[:, i])
                grads = torch.autograd.grad(pseudo_loss,
                                            parameters,
                                            retain_graph=i < self.n_output - 1,
                                            only_inputs=True)
                for i_p, p in enumerate(parameters):
                    output[p].add_(grads[i_p])

        output_dict = dict()
        for layer_id, layer in self.layer_collection.layers.items():
            mod = self.l_to_m[layer_id]
            if layer.bias is None:
                output_dict[layer_id] = (output[mod.weight] / n_examples, )
            else:
                output_dict[layer_id] = (output[mod.weight] / n_examples,
                                         output[mod.bias] / n_examples)

        # reset buffers and remove hooks
        self.xs = dict()
        del self._Jv
        del self._v
        del self.compute_switch
        for h in self.handles:
            h.remove()

        return PVector(layer_collection=self.layer_collection,
                       dict_repr=output_dict)
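
What implicit_mv computes is the matrix-free product Fv with F = (1/n) sum_x J_x^T J_x: the first autograd.grad pass triggers the hooks that accumulate Jv, and differentiating the pseudo-loss <Jv, f> with respect to the parameters yields J^T (Jv). A toy sketch of the same two-pass idea using torch.autograd.functional (whose jvp is documented to use the same double-backward trick):

import torch
from torch.autograd.functional import jvp, vjp, jacobian

x = torch.randn(7, 5)
f = lambda w: torch.tanh(x @ w)    # f: R^5 -> R^7, so J = df/dw is 7 x 5
w = torch.randn(5)
v = torch.randn(5)

_, Jv = jvp(f, w, v)               # forward product J v
_, JTJv = vjp(f, w, Jv)            # backward product J^T (J v)

J = jacobian(f, w)                 # dense check, only viable at toy sizes
assert torch.allclose(JTJv, J.t() @ (J @ v), atol=1e-6)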