def layer_alignment(model, output_fn, loader, n_output, centering=True):
    """Compute, per layer, the alignment between the layer's NTK block and
    the (centered) one-hot targets of *loader*.

    Returns a list of floats, one per layer of *model*, in the order the
    layers appear in its LayerCollection.
    """
    full_lc = LayerCollection.from_model(model)

    # Build the centered one-hot target vector once, shared by all layers.
    labels = torch.cat([args[1] for args in iter(loader)])
    y = one_hot(labels).float()
    y = y - y.mean(dim=0)
    target_vec = FVector(vector_repr=y.t().contiguous())
    target_sq_norm = torch.norm(target_vec.get_flat_representation()) ** 2

    per_layer = []
    for layer_item in full_lc.layers.items():
        single_lc = LayerCollection()
        single_lc.add_layer(*layer_item)
        jac = Jacobian(layer_collection=single_lc,
                       model=model,
                       loader=loader,
                       function=output_fn,
                       n_output=n_output,
                       centering=centering)
        gram = FMatDense(jac)
        # Alignment = y^T K y / (||K||_F * ||y||^2)
        quad_form = gram.vTMv(target_vec)
        ratio = quad_form / (gram.frobenius_norm() * target_sq_norm)
        per_layer.append(ratio.item())
    return per_layer
def get_fullyconnect_onlylast_task():
    """Same task as get_fullyconnect_task, restricted to the last layer only."""
    loader, full_lc, _, net, output_fn, n_output = get_fullyconnect_task()

    last_layer_lc = LayerCollection()
    # popitem() removes the most recently inserted entry, i.e. the last layer.
    last_layer_lc.add_layer(*full_lc.layers.popitem())

    last_params = net.net[-1].parameters()
    return loader, last_layer_lc, last_params, net, output_fn, n_output
def test_dot():
    """PVector.dot must agree with torch.dot on the flat representations,
    for flat/flat, dict/dict and mixed representations (both orders)."""
    model = ConvNet()
    lc = LayerCollection.from_model(model)

    def assert_dot_matches(u, v):
        # Reference value computed on the flattened parameter vectors.
        check_ratio(
            torch.dot(u.get_flat_representation(),
                      v.get_flat_representation()),
            u.dot(v))

    # flat <.> flat
    assert_dot_matches(random_pvector(lc), random_pvector(lc))

    # dict <.> dict
    assert_dot_matches(random_pvector_dict(lc), random_pvector_dict(lc))

    # mixed representations, checked in both orders
    a = random_pvector(lc)
    b = random_pvector_dict(lc)
    assert_dot_matches(a, b)
    assert_dot_matches(b, a)
def test_size():
    """PVector.size() must match the size of its flat representation."""
    model = ConvNet()
    lc = LayerCollection.from_model(model)
    for vec in (random_pvector(lc), random_pvector_dict(lc)):
        assert vec.size() == vec.get_flat_representation().size()
def test_norm():
    """PVector.norm() must match torch.norm of the flat representation."""
    model = ConvNet()
    lc = LayerCollection.from_model(model)
    for vec in (random_pvector(lc), random_pvector_dict(lc)):
        check_ratio(torch.norm(vec.get_flat_representation()), vec.norm())
def get_batchnorm_conv_linear_task():
    """Build a small MNIST task on BatchNormConvLinearNet, restricted to the
    two layers registered last in its LayerCollection.

    Returns (loader, layer_collection, parameters, net, output_fn, n_output).
    """
    subset = Subset(get_mnist(), range(70))
    loader = DataLoader(dataset=subset, batch_size=30, shuffle=False)

    net = BatchNormConvLinearNet()
    to_device_model(net)
    net.eval()

    def output_fn(input, target):
        return net(to_device(input))

    full_lc = LayerCollection.from_model(net)
    sub_lc = LayerCollection()
    # Keep the two most recently registered layers of the full collection.
    # NOTE(review): an earlier comment said "fc1 and fc2" but the parameters
    # gathered below come from conv1/conv2 — confirm which layers popitem()
    # actually yields here.
    sub_lc.add_layer(*full_lc.layers.popitem())
    sub_lc.add_layer(*full_lc.layers.popitem())

    parameters = list(net.conv2.parameters()) + list(net.conv1.parameters())
    return (loader, sub_lc, parameters, net, output_fn, 2)
def get_conv_skip_task():
    """Build a small MNIST task on a conv net with skip connections.

    Returns (loader, layer_collection, parameters, net, output_fn, n_output).
    """
    subset = Subset(get_mnist(), range(70))
    loader = DataLoader(dataset=subset, batch_size=30, shuffle=False)

    net = ConvNetWithSkipConnection()
    to_device_model(net)
    net.eval()

    def output_fn(input, target):
        return net(to_device(input))

    lc = LayerCollection.from_model(net)
    return (loader, lc, net.parameters(), net, output_fn, 3)
def get_fullyconnect_task(normalization='none'):
    """Build a small MNIST task on a fully-connected net.

    Returns (loader, layer_collection, parameters, net, output_fn, n_output).
    """
    subset = Subset(get_mnist(), range(70))
    loader = DataLoader(dataset=subset, batch_size=30, shuffle=False)

    net = FCNet(out_size=3, normalization=normalization)
    to_device_model(net)
    net.eval()

    def output_fn(input, target):
        return net(to_device(input))

    lc = LayerCollection.from_model(net)
    return (loader, lc, net.parameters(), net, output_fn, 3)
def get_batchnorm_nonlinear_task():
    """Build a small MNIST task on BatchNormNonLinearNet.

    Returns (loader, layer_collection, parameters, net, output_fn, n_output).
    """
    subset = Subset(get_mnist(), range(70))
    loader = DataLoader(dataset=subset, batch_size=30, shuffle=False)

    net = BatchNormNonLinearNet()
    to_device_model(net)
    net.eval()

    def output_fn(input, target):
        return net(to_device(input))

    lc = LayerCollection.from_model(net)
    return (loader, lc, net.parameters(), net, output_fn, 5)
def get_fullyconnect_kfac_task(bs=300):
    """Build a 1000-sample MNIST task suited for KFAC tests.

    bs: batch size of the returned DataLoader.
    Returns (loader, layer_collection, parameters, net, output_fn, n_output).
    """
    data = Subset(get_mnist(), range(1000))
    data = to_onexdataset(data, device)
    loader = DataLoader(dataset=data, batch_size=bs, shuffle=False)

    net = Net(in_size=18 * 18)
    net.to(device)
    net.eval()

    def output_fn(input, target):
        return net(to_device(input))

    lc = LayerCollection.from_model(net)
    return (loader, lc, net.parameters(), net, output_fn, 10)
def get_linear_fc_task():
    """Build a 1000-sample MNIST task on LinearFCNet.

    Returns (loader, layer_collection, parameters, net, output_fn, n_output).
    """
    subset = Subset(get_mnist(), range(1000))
    loader = DataLoader(dataset=subset, batch_size=300, shuffle=False)

    net = LinearFCNet()
    net.to(device)

    def output_fn(input, target):
        return net(to_device(input))

    lc = LayerCollection.from_model(net)
    return (loader, lc, net.parameters(), net, output_fn, 2)
def get_conv_task(normalization='none'):
    """Build a 1000-sample MNIST task on ConvNet.

    normalization: forwarded to the ConvNet constructor.
    Returns (loader, layer_collection, parameters, net, output_fn, n_output).
    """
    subset = Subset(get_mnist(), range(1000))
    loader = DataLoader(dataset=subset, batch_size=300, shuffle=False)

    net = ConvNet(normalization=normalization)
    net.to(device)

    def output_fn(input, target):
        return net(to_device(input))

    lc = LayerCollection.from_model(net)
    return (loader, lc, net.parameters(), net, output_fn, 3)
def get_convnet_kfc_task(bs=300):
    """Build a 1000-sample MNIST task on ConvNet suited for KFC tests.

    bs: batch size of the returned DataLoader.
    Returns (loader, layer_collection, parameters, net, output_fn, n_output).
    """
    # BUG FIX: a trailing comma after the MNIST(...) call previously made
    # train_set a 1-tuple, so Subset indexed the tuple instead of the
    # dataset (IndexError for any index > 0).
    train_set = datasets.MNIST(root=default_datapath, train=True,
                               download=True,
                               transform=transforms.ToTensor())
    train_set = Subset(train_set, range(1000))
    train_loader = DataLoader(dataset=train_set, batch_size=bs, shuffle=False)

    net = ConvNet()
    net.to(device)
    net.eval()

    def output_fn(input, target):
        return net(to_device(input))

    layer_collection = LayerCollection.from_model(net)
    return (train_loader, layer_collection, net.parameters(), net,
            output_fn, 10)
def NTK_Left_SV(net, X, y):
    """Return the left singular vectors U of the NTK Gram matrix of *net*
    evaluated on the batch (X, y)."""
    def output_fn(input, target):
        return net(input)

    lc = LayerCollection.from_model(net)
    lc.numel()  # kept from the original; result unused

    loader = DataLoader(TensorDataset(X, y))
    jac_gen = Jacobian(layer_collection=lc, model=net, loader=loader,
                       function=output_fn, n_output=1)

    J = jac_gen.get_jacobian()[0]
    gram = J @ J.t()
    # NOTE(review): torch.svd is deprecated in favor of torch.linalg.svd;
    # kept as-is to preserve exact behavior.
    U, _, _ = torch.svd(gram, some=False)
    return U
def __init__(self, model, function=None, n_output=1,
             centering=False, layer_collection=None):
    """Set up the generator state for *model*.

    function: callable applied to each batch; defaults to forwarding the
        first positional argument through the model.
    n_output: number of model outputs handled.
    centering: whether outputs are centered (stored, used elsewhere).
    layer_collection: layers to operate on; defaults to all of *model*'s.
    """
    self.model = model
    self.handles = []
    self.xs = dict()
    self.n_output = n_output
    self.centering = centering

    if function is None:
        # Default function: run the first batch element through the model.
        def function(*batch):
            return model(batch[0])
    self.function = function

    self.layer_collection = (LayerCollection.from_model(model)
                             if layer_collection is None
                             else layer_collection)
    # Maps between layer ids and modules, used to locate parameters in the
    # flattened representation.
    self.l_to_m, self.m_to_l = \
        self.layer_collection.get_layerid_module_maps(model)
def test_sub():
    """PVector subtraction must match subtraction of flat representations,
    for flat/flat, dict/dict and mixed representations."""
    model = ConvNet()
    lc = LayerCollection.from_model(model)

    def assert_sub_matches(u, v):
        got = (u - v).get_flat_representation()
        expected = (u.get_flat_representation()
                    - v.get_flat_representation())
        assert torch.norm(got - expected) < 1e-5

    assert_sub_matches(random_pvector(lc), random_pvector(lc))
    assert_sub_matches(random_pvector_dict(lc), random_pvector_dict(lc))
    assert_sub_matches(random_pvector(lc), random_pvector_dict(lc))
# NOTE(review): fragment of a benchmark script. The `transforms.Compose([`
# that opens this list, and the remainder of perform_timing(), lie outside
# this chunk — do not treat this span as complete code.
    transforms.ToTensor(),
    # CIFAR-10 per-channel mean / std normalization constants.
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])
trainset = datasets.CIFAR10(root='/tmp/data', train=True,
                            download=True, transform=transform)
# Only 100 samples are kept — this is a quick timing run, not training.
trainset = Subset(trainset, range(100))
trainloader = DataLoader(trainset, batch_size=50,
                         shuffle=False, num_workers=1)
# %%
from resnet import ResNet50
resnet = ResNet50().cuda()
layer_collection = LayerCollection.from_model(resnet)
v = random_pvector(LayerCollection.from_model(resnet), device='cuda')
print(f'{layer_collection.numel()} parameters')
# %%
# compute timings and display FIMs
def perform_timing():
    timings = dict()
    # NOTE(review): `repr` shadows the builtin; loop body continues past
    # this chunk.
    for repr in [PMatImplicit, PMatDiag, PMatEKFAC, PMatKFAC,
                 PMatQuasiDiag]:
        print('Timing representation:')
        pprint.pprint(repr)
# model.load_state_dict(torch.load('/home/pezeshki/scratch/dd/Deep-Double-Descent/runs2/cifar10/resnet_' + str(int(label_noise*100)) + '_k' + str(k) + '/ckpt' + str(id_epoch) + '.pkl')['net']) # flat_params = [] # for p in model.parameters(): # flat_params += [p.view(-1)] # flat_params = torch.cat(flat_params) flat_params = PVector.from_model(model).get_flat_representation() sums = torch.zeros(*flat_params.shape).cuda() sums_sqr = torch.zeros(*flat_params.shape).cuda() model.eval() def output_fn(input, target): # input = input.to('cuda') return model(input) layer_collection = LayerCollection.from_model(model) layer_collection.numel() # loader = torch.utils.data.DataLoader( # test_data, batch_size=150, shuffle=False, num_workers=0, # drop_last=False) loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch_size, shuffle=True, num_workers=0, drop_last=False) it = iter(loader) for X, y in tqdm(it): X = X.cuda() y = y.cuda() batch = TensorDataset(X, y)