def test_detach():
    eps = 1e-8
    model = ConvNet()
    pvec = PVector.from_model(model)
    pvec_clone = pvec.clone()

    # first check grad on pvec_clone
    loss = torch.norm(pvec_clone.get_flat_representation())
    loss.backward()
    pvec_clone_dict = pvec_clone.get_dict_representation()
    pvec_dict = pvec.get_dict_representation()
    for layer_id, layer in pvec.layer_collection.layers.items():
        assert torch.norm(pvec_dict[layer_id][0].grad) > eps
        assert pvec_clone_dict[layer_id][0].grad is None
        pvec_dict[layer_id][0].grad.zero_()
        if layer.bias is not None:
            assert torch.norm(pvec_dict[layer_id][1].grad) > eps
            assert pvec_clone_dict[layer_id][1].grad is None
            pvec_dict[layer_id][1].grad.zero_()

    # second, check that grads stay at 0 when backpropagating
    # through a detached PVector
    y = torch.tensor(1., requires_grad=True)
    loss = torch.norm(pvec.detach().get_flat_representation()) + y
    loss.backward()
    for layer_id, layer in pvec.layer_collection.layers.items():
        assert torch.norm(pvec_dict[layer_id][0].grad) < eps
        if layer.bias is not None:
            assert torch.norm(pvec_dict[layer_id][1].grad) < eps
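# For reference, a minimal plain-PyTorch sketch (assuming nothing beyond
# torch itself) of the detach semantics the test above relies on: gradients
# flow to the original tensor but never through a detached view of it.
def _sketch_detach_semantics():
    w = torch.randn(3, requires_grad=True)
    torch.norm(w).backward()
    assert w.grad is not None  # grad reaches the leaf tensor

    w.grad.zero_()
    y = torch.tensor(1., requires_grad=True)
    # only the `y` branch of the graph carries gradient
    (torch.norm(w.detach()) + y).backward()
    assert torch.all(w.grad == 0)  # nothing flows through the detached view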
def test_from_dict_to_pvector():
    eps = 1e-8
    model = ConvNet()
    v = PVector.from_model(model)
    d1 = v.get_dict_representation()
    v2 = PVector(v.layer_collection,
                 vector_repr=v.get_flat_representation())
    d2 = v2.get_dict_representation()
    assert d1.keys() == d2.keys()
    for k in d1.keys():
        assert torch.norm(d1[k][0] - d2[k][0]) < eps
        if len(d1[k]) == 2:
            assert torch.norm(d1[k][1] - d2[k][1]) < eps
def test_from_to_model():
    model1 = ConvNet()
    model2 = ConvNet()

    w1 = PVector.from_model(model1).clone()
    w2 = PVector.from_model(model2).clone()

    model3 = ConvNet()
    w1.copy_to_model(model3)
    # now model1 and model3 should be the same
    for p1, p3 in zip(model1.parameters(), model3.parameters()):
        check_tensors(p1, p3)

    diff_1_2 = w2 - w1
    diff_1_2.add_to_model(model3)
    # now model2 and model3 should be the same
    for p2, p3 in zip(model2.parameters(), model3.parameters()):
        check_tensors(p2, p3)
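# A minimal sketch of the same copy-then-add pattern using only
# torch.nn.utils, independent of PVector: writing w1 into a third model and
# then adding (w2 - w1) should reproduce w2.
def _sketch_copy_add():
    from torch.nn.utils import parameters_to_vector, vector_to_parameters
    m1, m2, m3 = (torch.nn.Linear(4, 2) for _ in range(3))
    v1 = parameters_to_vector(m1.parameters())
    v2 = parameters_to_vector(m2.parameters())
    vector_to_parameters(v1 + (v2 - v1), m3.parameters())
    for p2, p3 in zip(m2.parameters(), m3.parameters()):
        assert torch.allclose(p2, p3)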
def test_PVector_pickle():
    _, _, _, model, _, _ = get_conv_task()

    vec = PVector.from_model(model)
    with open('/tmp/PVec.pkl', 'wb') as f:
        pkl.dump(vec, f)
    with open('/tmp/PVec.pkl', 'rb') as f:
        vec_pkl = pkl.load(f)

    check_tensors(vec.get_flat_representation(),
                  vec_pkl.get_flat_representation())
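# A minimal sketch of the underlying guarantee the test above depends on:
# torch tensors survive a pickle round trip bit-exactly (the '/tmp' path
# mirrors the test and is an assumption about a writable location).
def _sketch_tensor_pickle():
    t = torch.randn(5)
    with open('/tmp/tensor_sketch.pkl', 'wb') as f:
        pkl.dump(t, f)
    with open('/tmp/tensor_sketch.pkl', 'rb') as f:
        t2 = pkl.load(f)
    assert torch.equal(t, t2)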
def test_grad_dict_repr():
    loader, lc, parameters, model, function, n_output = get_conv_gn_task()

    d, _ = next(iter(loader))
    scalar_output = model(to_device(d)).sum()
    vec = PVector.from_model(model)

    grad_nng = grad(scalar_output, vec, retain_graph=True)

    scalar_output.backward()
    grad_direct = PVector.from_model_grad(model)

    check_tensors(grad_direct.get_flat_representation(),
                  grad_nng.get_flat_representation())
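# A plain-autograd sketch of the equivalence checked above, with no
# nngeometry objects involved: torch.autograd.grad and .backward() compute
# the same gradients for the same scalar output.
def _sketch_grad_vs_backward():
    m = torch.nn.Linear(3, 1)
    out = m(torch.randn(2, 3)).sum()
    g = torch.autograd.grad(out, m.parameters(), retain_graph=True)
    out.backward()
    for gi, p in zip(g, m.parameters()):
        assert torch.allclose(gi, p.grad)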
def test_clone():
    eps = 1e-8
    model = ConvNet()
    pvec = PVector.from_model(model)
    pvec_clone = pvec.clone()

    l_to_m, _ = pvec.layer_collection.get_layerid_module_maps(model)
    for layer_id, layer in pvec.layer_collection.layers.items():
        m = l_to_m[layer_id]
        assert m.weight is pvec.get_dict_representation()[layer_id][0]
        assert (m.weight is not
                pvec_clone.get_dict_representation()[layer_id][0])
        assert (torch.norm(m.weight -
                           pvec_clone.get_dict_representation()[layer_id][0])
                < eps)
        if m.bias is not None:
            assert m.bias is pvec.get_dict_representation()[layer_id][1]
            assert (m.bias is not
                    pvec_clone.get_dict_representation()[layer_id][1])
            assert (torch.norm(m.bias -
                               pvec_clone.get_dict_representation()[layer_id][1])
                    < eps)
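# The identity-versus-value distinction tested above, in a minimal
# plain-torch form: a clone holds equal values in fresh storage.
def _sketch_clone_semantics():
    w = torch.randn(3)
    c = w.clone()
    assert c is not w
    assert torch.norm(w - c) < 1e-8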
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# accumulators for running statistics over the flat parameter vector
flat_params = PVector.from_model(model).get_flat_representation()
sums = torch.zeros(*flat_params.shape, device=device)
sums_sqr = torch.zeros(*flat_params.shape, device=device)

model.eval()


def output_fn(input, target):
    return model(input)


layer_collection = LayerCollection.from_model(model)
layer_collection.numel()

loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch_size,
                                     shuffle=True, num_workers=0,
                                     drop_last=False)
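# Hypothetical continuation (not in the original fragment): the `sums` and
# `sums_sqr` buffers suggest accumulating first and second moments of the
# flat parameter vector, e.g. across saved checkpoints, from which a
# per-parameter variance follows. A sketch under that assumption;
# `state_dicts` is a hypothetical iterable of checkpoint state dicts.
def accumulate_param_moments(state_dicts):
    n = 0
    for sd in state_dicts:
        model.load_state_dict(sd)
        flat = PVector.from_model(model).get_flat_representation().detach()
        sums.add_(flat)
        sums_sqr.add_(flat ** 2)
        n += 1
    mean = sums / n
    var = sums_sqr / n - mean ** 2
    return mean, var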