def test_from_dict_to_pvector():
    eps = 1e-8
    model = ConvNet()
    v = PVector.from_model(model)
    d1 = v.get_dict_representation()
    v2 = PVector(v.layer_collection, vector_repr=v.get_flat_representation())
    d2 = v2.get_dict_representation()
    assert d1.keys() == d2.keys()
    for k in d1.keys():
        assert torch.norm(d1[k][0] - d2[k][0]) < eps
        if len(d1[k]) == 2:
            assert torch.norm(d1[k][1] - d2[k][1]) < eps
def test_grad_dict_repr():
    loader, lc, parameters, model, function, n_output = get_conv_gn_task()

    d, _ = next(iter(loader))
    scalar_output = model(to_device(d)).sum()
    vec = PVector.from_model(model)

    grad_nng = grad(scalar_output, vec, retain_graph=True)

    scalar_output.backward()
    grad_direct = PVector.from_model_grad(model)

    check_tensors(grad_direct.get_flat_representation(),
                  grad_nng.get_flat_representation())
def test_detach():
    eps = 1e-8
    model = ConvNet()
    pvec = PVector.from_model(model)
    pvec_clone = pvec.clone()

    # first, check gradients on pvec vs pvec_clone
    loss = torch.norm(pvec_clone.get_flat_representation())
    loss.backward()
    pvec_clone_dict = pvec_clone.get_dict_representation()
    pvec_dict = pvec.get_dict_representation()
    for layer_id, layer in pvec.layer_collection.layers.items():
        assert torch.norm(pvec_dict[layer_id][0].grad) > eps
        assert pvec_clone_dict[layer_id][0].grad is None
        pvec_dict[layer_id][0].grad.zero_()
        if layer.bias is not None:
            assert torch.norm(pvec_dict[layer_id][1].grad) > eps
            assert pvec_clone_dict[layer_id][1].grad is None
            pvec_dict[layer_id][1].grad.zero_()

    # second, check that gradients stay at 0 when using a detached vector
    y = torch.tensor(1., requires_grad=True)
    loss = torch.norm(pvec.detach().get_flat_representation()) + y
    loss.backward()
    for layer_id, layer in pvec.layer_collection.layers.items():
        assert torch.norm(pvec_dict[layer_id][0].grad) < eps
        if layer.bias is not None:
            assert torch.norm(pvec_dict[layer_id][1].grad) < eps
def test_from_to_model():
    model1 = ConvNet()
    model2 = ConvNet()

    w1 = PVector.from_model(model1).clone()
    w2 = PVector.from_model(model2).clone()

    model3 = ConvNet()
    w1.copy_to_model(model3)
    # now model1 and model3 should be the same
    for p1, p3 in zip(model1.parameters(), model3.parameters()):
        check_tensors(p1, p3)

    diff_1_2 = w2 - w1
    diff_1_2.add_to_model(model3)
    # now model2 and model3 should be the same
    for p2, p3 in zip(model2.parameters(), model3.parameters()):
        check_tensors(p2, p3)
def test_PVector_pickle():
    _, _, _, model, _, _ = get_conv_task()

    vec = PVector.from_model(model)

    with open('/tmp/PVec.pkl', 'wb') as f:
        pkl.dump(vec, f)
    with open('/tmp/PVec.pkl', 'rb') as f:
        vec_pkl = pkl.load(f)

    check_tensors(vec.get_flat_representation(),
                  vec_pkl.get_flat_representation())
def test_pspace_kfac_eigendecomposition():
    """
    Check the KFAC eigendecomposition by comparing Mv products against v,
    where v are the top eigenvectors. The remaining ones can be more
    erratic because of numerical precision.
    """
    eps = 1e-3
    loader, lc, parameters, model, function, n_output = get_fullyconnect_task()

    generator = Jacobian(layer_collection=lc,
                         model=model,
                         loader=loader,
                         function=function,
                         n_output=n_output)

    M_kfac = PMatKFAC(generator)
    M_kfac.compute_eigendecomposition()
    evals, evecs = M_kfac.get_eigendecomposition()

    # Loop through the top 3 eigenvectors of each KFE block
    l_to_m, _ = lc.get_layerid_module_maps(model)
    for l_id, layer in lc.layers.items():
        for i_a in range(-3, 0):
            for i_g in range(-3, 0):
                evec_v = dict()
                for l_id2, layer2 in lc.layers.items():
                    m = l_to_m[l_id2]
                    if l_id2 == l_id:
                        v_a = evecs[l_id][0][:, i_a].unsqueeze(0)
                        v_g = evecs[l_id][1][:, i_g].unsqueeze(1)
                        evec_block = kronecker(v_g, v_a)
                        evec_tuple = (evec_block[:, :-1].contiguous(),
                                      evec_block[:, -1].contiguous())
                        evec_v[l_id] = evec_tuple
                    else:
                        evec_v[l_id2] = (torch.zeros_like(m.weight),
                                         torch.zeros_like(m.bias))
                evec_v = PVector(lc, dict_repr=evec_v)
                Mv = M_kfac.mv(evec_v)
                angle_v_Mv = angle(Mv, evec_v)
                assert angle_v_Mv < 1 + eps and angle_v_Mv > 1 - eps
                norm_mv = torch.norm(Mv.get_flat_representation())
                check_ratio(evals[l_id][0][i_a] * evals[l_id][1][i_g],
                            norm_mv)
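# The test above relies on the Kronecker eigen-structure used by KFAC:
# if A v_a = lambda_a v_a and G v_g = lambda_g v_g, then
# (G kron A)(v_g kron v_a) = (lambda_g * lambda_a)(v_g kron v_a).
# Below is a minimal, self-contained sketch of that identity using plain torch
# ops; the random PSD factors A and G are assumptions made for illustration only.
import torch


def _kfac_kron_eigen_sketch():
    a = torch.randn(4, 4)
    g = torch.randn(3, 3)
    A = a @ a.t()      # symmetric PSD factor (stands in for the activation covariance)
    G = g @ g.t()      # symmetric PSD factor (stands in for the backprop covariance)
    evals_a, evecs_a = torch.linalg.eigh(A)
    evals_g, evecs_g = torch.linalg.eigh(G)
    # top eigenvector of the Kronecker product, built from the per-factor tops
    v = torch.kron(evecs_g[:, -1], evecs_a[:, -1])
    Mv = torch.kron(G, A) @ v
    lam = evals_g[-1] * evals_a[-1]
    # Mv should align with v, scaled by the product of the factor eigenvalues
    assert torch.norm(Mv - lam * v) < 1e-4 * torch.norm(Mv)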
def test_clone():
    eps = 1e-8
    model = ConvNet()
    pvec = PVector.from_model(model)
    pvec_clone = pvec.clone()

    l_to_m, _ = pvec.layer_collection.get_layerid_module_maps(model)
    for layer_id, layer in pvec.layer_collection.layers.items():
        m = l_to_m[layer_id]
        assert m.weight is pvec.get_dict_representation()[layer_id][0]
        assert (m.weight is not
                pvec_clone.get_dict_representation()[layer_id][0])
        assert (torch.norm(m.weight -
                           pvec_clone.get_dict_representation()[layer_id][0])
                < eps)
        if m.bias is not None:
            assert m.bias is pvec.get_dict_representation()[layer_id][1]
            assert (m.bias is not
                    pvec_clone.get_dict_representation()[layer_id][1])
            assert (torch.norm(m.bias -
                               pvec_clone.get_dict_representation()[layer_id][1])
                    < eps)
def grad(output, vec, *args, **kwargs):
    """
    Computes the gradient of `output` with respect to the `PVector` `vec`

    .. warning::
        This function only works when `vec` has been created from leaf
        nodes of the graph (e.g. model parameters)

    :param output: The scalar quantity to be differentiated
    :param vec: a `PVector`
    :return: a `PVector` of gradients of `output` w.r.t `vec`
    """
    if vec.dict_repr is not None:
        # map all parameters to a flat list, remembering for each layer its
        # position in the list and its number of parameter tensors
        params = []
        pos = []
        lengths = []
        current_pos = 0
        for k in vec.dict_repr.keys():
            p = vec.dict_repr[k]
            params += list(p)
            pos.append(current_pos)
            lengths.append(len(p))
            current_pos = current_pos + len(p)
        grad_list = torch.autograd.grad(output, params, *args, **kwargs)
        dict_repr_grad = dict()
        for k, p, l in zip(vec.dict_repr.keys(), pos, lengths):
            if l == 1:
                dict_repr_grad[k] = (grad_list[p], )
            elif l == 2:
                dict_repr_grad[k] = (grad_list[p], grad_list[p + 1])
        return PVector(vec.layer_collection, dict_repr=dict_repr_grad)
    else:
        raise RuntimeError('grad only works when the vector is created '
                           'from leaf nodes in the computation graph')
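# A hedged usage sketch for `grad`: `some_model` and `x` below are placeholders
# for any torch.nn.Module and a matching input batch; this mirrors
# test_grad_dict_repr above rather than introducing new behavior.
def _grad_usage_sketch(some_model, x):
    vec = PVector.from_model(some_model)        # built from leaf parameters
    loss = some_model(x).sum()                  # any scalar output
    g = grad(loss, vec, retain_graph=True)      # PVector of d(loss)/d(parameters)
    return g.get_flat_representation()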
def test_jacobian_pblockdiag():
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_blockdiag = PMatBlockDiag(generator)
        dw = random_pvector(lc, device=device)
        dense_tensor = PMat_blockdiag.get_dense_tensor()

        # Test get_diag
        check_tensors(torch.diag(dense_tensor),
                      PMat_blockdiag.get_diag())

        # Test frobenius
        frob_PMat = PMat_blockdiag.frobenius_norm()
        frob_direct = (dense_tensor**2).sum()**.5
        check_ratio(frob_direct, frob_PMat)

        # Test trace
        trace_PMat = PMat_blockdiag.trace()
        trace_direct = torch.trace(dense_tensor)
        check_ratio(trace_PMat, trace_direct)

        # Test mv
        mv_direct = torch.mv(dense_tensor, dw.get_flat_representation())
        check_tensors(mv_direct,
                      PMat_blockdiag.mv(dw).get_flat_representation())

        # Test vTMv
        check_ratio(torch.dot(mv_direct, dw.get_flat_representation()),
                    PMat_blockdiag.vTMv(dw))

        # Test solve
        regul = 1e-3
        Mv_regul = torch.mv(dense_tensor +
                            regul * torch.eye(PMat_blockdiag.size(0),
                                              device=device),
                            dw.get_flat_representation())
        Mv_regul = PVector(layer_collection=lc,
                           vector_repr=Mv_regul)
        dw_using_inv = PMat_blockdiag.solve(Mv_regul, regul=regul)
        check_tensors(dw.get_flat_representation(),
                      dw_using_inv.get_flat_representation(), eps=5e-3)

        # Test inv
        PMat_inv = PMat_blockdiag.inverse(regul=regul)
        check_tensors(dw.get_flat_representation(),
                      PMat_inv.mv(PMat_blockdiag.mv(dw) + regul * dw)
                      .get_flat_representation(), eps=5e-3)

        # Test add, sub, rmul
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_blockdiag2 = PMatBlockDiag(generator)

        check_tensors(PMat_blockdiag.get_dense_tensor() +
                      PMat_blockdiag2.get_dense_tensor(),
                      (PMat_blockdiag + PMat_blockdiag2).get_dense_tensor())
        check_tensors(PMat_blockdiag.get_dense_tensor() -
                      PMat_blockdiag2.get_dense_tensor(),
                      (PMat_blockdiag - PMat_blockdiag2).get_dense_tensor())
        check_tensors(1.23 * PMat_blockdiag.get_dense_tensor(),
                      (1.23 * PMat_blockdiag).get_dense_tensor())
def test_jacobian_pdiag_vs_pdense():
    for get_task in nonlinear_tasks:
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_diag = PMatDiag(generator)
        PMat_dense = PMatDense(generator)
        dw = random_pvector(lc, device=device)

        # Test get_dense_tensor
        matrix_diag = PMat_diag.get_dense_tensor()
        matrix_dense = PMat_dense.get_dense_tensor()
        check_tensors(torch.diag(matrix_diag),
                      torch.diag(matrix_dense))
        assert torch.norm(matrix_diag -
                          torch.diag(torch.diag(matrix_diag))) < 1e-5

        # Test trace
        check_ratio(torch.trace(matrix_diag),
                    PMat_diag.trace())

        # Test frobenius
        check_ratio(torch.norm(matrix_diag),
                    PMat_diag.frobenius_norm())

        # Test mv
        mv_direct = torch.mv(matrix_diag, dw.get_flat_representation())
        mv_PMat_diag = PMat_diag.mv(dw)
        check_tensors(mv_direct,
                      mv_PMat_diag.get_flat_representation())

        # Test vTMv
        vTMv_direct = torch.dot(mv_direct, dw.get_flat_representation())
        vTMv_PMat_diag = PMat_diag.vTMv(dw)
        check_ratio(vTMv_direct, vTMv_PMat_diag)

        # Test inverse
        regul = 1e-3
        PMat_diag_inverse = PMat_diag.inverse(regul)
        prod = torch.mm(matrix_diag + regul * torch.eye(lc.numel(),
                                                        device=device),
                        PMat_diag_inverse.get_dense_tensor())
        check_tensors(torch.eye(lc.numel(), device=device),
                      prod)

        # Test solve
        regul = 1e-3
        Mv_regul = torch.mv(matrix_diag +
                            regul * torch.eye(PMat_diag.size(0),
                                              device=device),
                            dw.get_flat_representation())
        Mv_regul = PVector(layer_collection=lc,
                           vector_repr=Mv_regul)
        dw_using_inv = PMat_diag.solve(Mv_regul, regul=regul)
        check_tensors(dw.get_flat_representation(),
                      dw_using_inv.get_flat_representation(), eps=5e-3)

        # Test get_diag
        diag_direct = torch.diag(matrix_diag)
        diag_PMat_diag = PMat_diag.get_diag()
        check_tensors(diag_direct, diag_PMat_diag)

        # Test add, sub, rmul
        loader, lc, parameters, model, function, n_output = get_task()
        model.train()
        generator = Jacobian(layer_collection=lc,
                             model=model,
                             loader=loader,
                             function=function,
                             n_output=n_output)
        PMat_diag2 = PMatDiag(generator)

        check_tensors(PMat_diag.get_dense_tensor() +
                      PMat_diag2.get_dense_tensor(),
                      (PMat_diag + PMat_diag2).get_dense_tensor())
        check_tensors(PMat_diag.get_dense_tensor() -
                      PMat_diag2.get_dense_tensor(),
                      (PMat_diag - PMat_diag2).get_dense_tensor())
        check_tensors(1.23 * PMat_diag.get_dense_tensor(),
                      (1.23 * PMat_diag).get_dense_tensor())
def test_jacobian_pdense():
    for get_task in nonlinear_tasks:
        for centering in [True, False]:
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc,
                                 model=model,
                                 loader=loader,
                                 function=function,
                                 n_output=n_output,
                                 centering=centering)
            PMat_dense = PMatDense(generator)
            dw = random_pvector(lc, device=device)

            # Test get_diag
            check_tensors(torch.diag(PMat_dense.get_dense_tensor()),
                          PMat_dense.get_diag())

            # Test frobenius
            frob_PMat = PMat_dense.frobenius_norm()
            frob_direct = (PMat_dense.get_dense_tensor()**2).sum()**.5
            check_ratio(frob_direct, frob_PMat)

            # Test trace
            trace_PMat = PMat_dense.trace()
            trace_direct = torch.trace(PMat_dense.get_dense_tensor())
            check_ratio(trace_PMat, trace_direct)

            # Test solve
            # NB: regul is high since the matrix is not full rank
            regul = 1e-3
            Mv_regul = torch.mv(PMat_dense.get_dense_tensor() +
                                regul * torch.eye(PMat_dense.size(0),
                                                  device=device),
                                dw.get_flat_representation())
            Mv_regul = PVector(layer_collection=lc,
                               vector_repr=Mv_regul)
            dw_using_inv = PMat_dense.solve(Mv_regul, regul=regul)
            check_tensors(dw.get_flat_representation(),
                          dw_using_inv.get_flat_representation(), eps=5e-3)

            # Test inv
            PMat_inv = PMat_dense.inverse(regul=regul)
            check_tensors(dw.get_flat_representation(),
                          PMat_inv.mv(PMat_dense.mv(dw) + regul * dw)
                          .get_flat_representation(), eps=5e-3)

            # Test add, sub, rmul
            loader, lc, parameters, model, function, n_output = get_task()
            model.train()
            generator = Jacobian(layer_collection=lc,
                                 model=model,
                                 loader=loader,
                                 function=function,
                                 n_output=n_output,
                                 centering=centering)
            PMat_dense2 = PMatDense(generator)

            check_tensors(PMat_dense.get_dense_tensor() +
                          PMat_dense2.get_dense_tensor(),
                          (PMat_dense + PMat_dense2).get_dense_tensor())
            check_tensors(PMat_dense.get_dense_tensor() -
                          PMat_dense2.get_dense_tensor(),
                          (PMat_dense - PMat_dense2).get_dense_tensor())
            check_tensors(1.23 * PMat_dense.get_dense_tensor(),
                          (1.23 * PMat_dense).get_dense_tensor())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

if True:
    # print('load')
    # id_epoch = ''
    # model.load_state_dict(torch.load('/home/pezeshki/scratch/dd/Deep-Double-Descent/runs2/cifar10/resnet_' + str(int(label_noise*100)) + '_k' + str(k) + '/ckpt' + str(id_epoch) + '.pkl')['net'])

    # flat_params = []
    # for p in model.parameters():
    #     flat_params += [p.view(-1)]
    # flat_params = torch.cat(flat_params)
    flat_params = PVector.from_model(model).get_flat_representation()

    # keep the running sums on the same device as the model
    sums = torch.zeros(*flat_params.shape).to(device)
    sums_sqr = torch.zeros(*flat_params.shape).to(device)

model.eval()


def output_fn(input, target):
    # input = input.to('cuda')
    return model(input)


layer_collection = LayerCollection.from_model(model)
layer_collection.numel()

# loader = torch.utils.data.DataLoader(
#     test_data, batch_size=150, shuffle=False, num_workers=0,
#     drop_last=False)
loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch_size,
                                     shuffle=True, num_workers=0,
def implicit_mv(self, v, examples):
    """
    Computes the matrix-vector product Fv implicitly, i.e. without ever
    materializing the matrix: the hooks first accumulate Jv during a backward
    pass through the graph, then <Jv, f> is backpropagated to the parameters
    to obtain J^T (Jv), averaged over the examples.
    """
    # add hooks
    self.handles += self._add_hooks(self._hook_savex,
                                    self._hook_compute_Jv,
                                    self.l_to_m.values())

    self._v = v.get_dict_representation()
    parameters = []
    output = dict()
    for layer_id, layer in self.layer_collection.layers.items():
        mod = self.l_to_m[layer_id]
        mod_class = mod.__class__.__name__
        if mod_class in ['BatchNorm1d', 'BatchNorm2d']:
            raise NotImplementedError
        parameters.append(mod.weight)
        output[mod.weight] = torch.zeros_like(mod.weight)
        if layer.bias is not None:
            parameters.append(mod.bias)
            output[mod.bias] = torch.zeros_like(mod.bias)

    device = next(self.model.parameters()).device

    loader = self._get_dataloader(examples)
    n_examples = len(loader.sampler)

    self.i_output = 0
    self.start = 0
    for d in loader:
        inputs = d[0]
        inputs.requires_grad = True
        bs = inputs.size(0)
        f_output = self.function(*d).view(bs, self.n_output)
        for i in range(self.n_output):
            # TODO reuse instead of reallocating memory
            self._Jv = torch.zeros((1, bs), device=device)

            self.compute_switch = True
            torch.autograd.grad(f_output[:, i].sum(dim=0), [inputs],
                                retain_graph=True,
                                only_inputs=True)
            self.compute_switch = False
            pseudo_loss = torch.dot(self._Jv[0, :], f_output[:, i])
            grads = torch.autograd.grad(pseudo_loss,
                                        parameters,
                                        retain_graph=i < self.n_output - 1,
                                        only_inputs=True)
            for i_p, p in enumerate(parameters):
                output[p].add_(grads[i_p])

    output_dict = dict()
    for layer_id, layer in self.layer_collection.layers.items():
        mod = self.l_to_m[layer_id]
        if layer.bias is None:
            output_dict[layer_id] = (output[mod.weight] / n_examples, )
        else:
            output_dict[layer_id] = (output[mod.weight] / n_examples,
                                     output[mod.bias] / n_examples)

    # remove hooks and clean up temporary state
    self.xs = dict()
    del self._Jv
    del self._v
    del self.compute_switch
    for h in self.handles:
        h.remove()

    return PVector(layer_collection=self.layer_collection,
                   dict_repr=output_dict)
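# A hedged usage sketch for implicit_mv, assuming it is a method of the
# Jacobian generator (as its attributes suggest): the implicit product is
# compared against a PMatDense explicitly materialized from the same generator.
# get_fullyconnect_task, random_pvector and check_tensors are assumed to be the
# test helpers used in the tests above.
def _implicit_mv_sketch():
    loader, lc, parameters, model, function, n_output = get_fullyconnect_task()
    generator = Jacobian(layer_collection=lc, model=model, loader=loader,
                         function=function, n_output=n_output)
    v = random_pvector(lc, device=next(model.parameters()).device)
    Fv_implicit = generator.implicit_mv(v, loader)
    Fv_dense = PMatDense(generator).mv(v)
    check_tensors(Fv_dense.get_flat_representation(),
                  Fv_implicit.get_flat_representation())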