def vgg13_bn(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
    return model
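# Usage sketch for the factory above. `pretrained=True` fetches the ImageNet
# weights from model_urls via model_zoo, so it requires network access.
model = vgg13_bn(pretrained=True)
model.eval()  # disable dropout and freeze batch-norm statistics for inference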
def __init__(self, vgg_model: torchvision_models.VGG):
    super().__init__()
    self.normalizer = ImageNetNormalizer()
    self.model = vgg_model.eval()
    # Slice the VGG feature stack into five blocks at the standard
    # perceptual-loss boundaries of torchvision's VGG16 features
    # (relu1_2, relu2_2, relu3_3, relu4_3, relu5_3).
    self.layer1 = nn.Sequential(self.model.features[:4])
    self.layer2 = nn.Sequential(self.model.features[4:9])
    self.layer3 = nn.Sequential(self.model.features[9:16])
    self.layer4 = nn.Sequential(self.model.features[16:23])
    self.layer5 = nn.Sequential(self.model.features[23:30])
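# Hypothetical forward pass for the extractor above (the actual method is
# defined elsewhere in the class): normalize, then chain the five slices and
# return every intermediate activation.
def forward(self, x):
    x = self.normalizer(x)
    h1 = self.layer1(x)
    h2 = self.layer2(h1)
    h3 = self.layer3(h2)
    h4 = self.layer4(h3)
    h5 = self.layer5(h4)
    return [h1, h2, h3, h4, h5]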
def get_model_by_name(model_name) -> Module:
    if model_name == "vgg":
        return VGG(vgg11_bn().features, num_classes=2)
    if model_name == "mobilenetv2":
        return MobileNetV2(num_classes=2)
    if model_name == "simple":
        return SimpleNetwork()
    if model_name == "squeezenet":
        return SqueezeNet(version=1.1, num_classes=2)
    raise ValueError(f"Invalid model name: {model_name}.")
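# Usage sketch: every branch returns a two-class classifier. The input size
# is an assumption here; use whatever the surrounding pipeline feeds in.
import torch

net = get_model_by_name("squeezenet")
logits = net(torch.randn(1, 3, 224, 224))  # shape: (1, 2)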
def __init__(self, requires_grad=False, pretrained=True):
    super(vgg16, self).__init__()
    vgg = VGG(make_layers([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                           512, 512, 512, 'M', 512, 512, 512, 'M']),
              init_weights=False)
    vgg_weight = torch.load('vgg16-397923af.pth')
    vgg.load_state_dict(vgg_weight)
    self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 3, 1, 1)
    self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 3, 1, 1)
    vgg_pretrained_features = vgg.features
    del vgg_weight
    self.slice1 = torch.nn.Sequential()
    self.slice2 = torch.nn.Sequential()
    self.slice3 = torch.nn.Sequential()
    self.slice4 = torch.nn.Sequential()
    self.slice5 = torch.nn.Sequential()
    self.N_slices = 5
    # Distribute the pretrained feature layers over the five slices
    # (relu1_2, relu2_2, relu3_3, relu4_3, relu5_3 boundaries).
    for x in range(4):
        self.slice1.add_module(str(x), vgg_pretrained_features[x])
    for x in range(4, 9):
        self.slice2.add_module(str(x), vgg_pretrained_features[x])
    for x in range(9, 16):
        self.slice3.add_module(str(x), vgg_pretrained_features[x])
    for x in range(16, 23):
        self.slice4.add_module(str(x), vgg_pretrained_features[x])
    for x in range(23, 30):
        self.slice5.add_module(str(x), vgg_pretrained_features[x])
    if not requires_grad:
        for param in self.parameters():
            param.requires_grad = False
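# Hypothetical forward for the class above (its real forward is not shown):
# normalize with the stored ImageNet statistics, then collect all five slice
# activations, LPIPS-style.
def forward(self, x):
    x = (x - self.mean) / self.std
    h1 = self.slice1(x)
    h2 = self.slice2(h1)
    h3 = self.slice3(h2)
    h4 = self.slice4(h3)
    h5 = self.slice5(h4)
    return [h1, h2, h3, h4, h5]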
def vgg16(pretrained=False, model_path=None, **kwargs):
    """VGG 16-layer model (configuration "D")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        model_path (str, optional): local checkpoint to load instead of the
            model zoo URL when pretrained is True
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['D']), **kwargs)
    if pretrained:
        if model_path is not None:
            model.load_state_dict(torch.load(model_path))
        else:
            model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
    return model
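# Usage sketch: load weights from a local checkpoint instead of downloading.
# The file name is illustrative; any VGG16 state dict with matching keys works.
model = vgg16(pretrained=True, model_path='vgg16-397923af.pth')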
def main():
    """Check NTKs in a single call."""
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'],
                                             augmentations=False, shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16, widen_factor=config['width'],
                          dropout_rate=0.0, num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(config['width'], config['width'])),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(config['width'], config['width'])),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10, width_mult=config['width'],
                          round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        # Note: the filter drops the 'M' (max-pool) entries, so this variant
        # scales the channel counts but builds a pooling-free feature stack.
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'], 4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(OrderedDict([
            ('conv0', torch.nn.Conv2d(3, 1 * config['width'], kernel_size=3, padding=1)),
            ('relu0', torch.nn.ReLU()),
            # ('pool0', torch.nn.MaxPool2d(3)),
            ('conv1', torch.nn.Conv2d(1 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu1', torch.nn.ReLU()),
            # ('pool1', torch.nn.MaxPool2d(3)),
            ('conv2', torch.nn.Conv2d(2 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu2', torch.nn.ReLU()),
            # ('pool2', torch.nn.MaxPool2d(3)),
            ('conv3', torch.nn.Conv2d(2 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu3', torch.nn.ReLU()),
            ('pool3', torch.nn.MaxPool2d(3)),
            ('conv4', torch.nn.Conv2d(4 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu4', torch.nn.ReLU()),
            ('pool4', torch.nn.MaxPool2d(3)),
            ('flatten', torch.nn.Flatten()),
            ('linear', torch.nn.Linear(36 * config['width'], 10))]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    try:
        net.load_state_dict(
            torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth',
                       map_location=device))
        print('Initialized net loaded from file.')
    except Exception:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth'
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {path}')

    num_params = sum(p.numel() for p in net.parameters())
    print(f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
          f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    ntk_matrix_before = batch_wise_ntk(net, trainloader, samplesize=args.sampling)
    plt.imshow(ntk_matrix_before)
    plt.savefig(config['path'] + f'{args.net}{config["width"]}_CIFAR_NTK_BEFORE.png',
                bbox_inches='tight', dpi=1200)
    ntk_matrix_before_norm = np.linalg.norm(ntk_matrix_before.flatten())
    print(f'The total norm of the NTK sample before training is {ntk_matrix_before_norm:.2f}')
    param_norm_before = np.sqrt(np.sum(
        [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_before:.2f}')

    if args.pdist:
        pdist_init, cos_init, prod_init = batch_feature_correlations(trainloader)
        pdist_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_init])
        cos_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_init])
        prod_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_init])
        print(f'The total norm of feature distances before training is {pdist_init_norm:.2f}')
        print(f'The total norm of feature cosine similarity before training is {cos_init_norm:.2f}')
        print(f'The total norm of feature inner product before training is {prod_init_norm:.2f}')
        save_plot(pdist_init, trainloader, name='pdist_before_training')
        save_plot(cos_init, trainloader, name='cosine_before_training')
        save_plot(prod_init, trainloader, name='prod_before_training')

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    try:
        net.load_state_dict(
            torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth',
                       map_location=device))
        print('Net loaded from file.')
    except Exception:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth'
        dl.train(net, optimizer, scheduler, loss_fn, trainloader, config,
                 path=None, dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Net saved to file.')
        else:
            print(f'Would save to {path}')

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    param_norm_after = np.sqrt(np.sum(
        [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_after:.2f}')

    ntk_matrix_after = batch_wise_ntk(net, trainloader, samplesize=args.sampling)
    plt.imshow(ntk_matrix_after)
    plt.savefig(config['path'] + f'{args.net}{config["width"]}_CIFAR_NTK_AFTER.png',
                bbox_inches='tight', dpi=1200)
    ntk_matrix_after_norm = np.linalg.norm(ntk_matrix_after.flatten())
    print(f'The total norm of the NTK sample after training is {ntk_matrix_after_norm:.2f}')

    ntk_matrix_diff = np.abs(ntk_matrix_before - ntk_matrix_after)
    plt.imshow(ntk_matrix_diff)
    plt.savefig(config['path'] + f'{args.net}{config["width"]}_CIFAR_NTK_DIFF.png',
                bbox_inches='tight', dpi=1200)
    ntk_matrix_diff_norm = np.linalg.norm(ntk_matrix_diff.flatten())
    print(f'The total norm of the NTK sample diff is {ntk_matrix_diff_norm:.2f}')

    ntk_matrix_rdiff = np.abs(ntk_matrix_before - ntk_matrix_after) / (np.abs(ntk_matrix_before) + 1e-4)
    plt.imshow(ntk_matrix_rdiff)
    plt.savefig(config['path'] + f'{args.net}{config["width"]}_CIFAR_NTK_RDIFF.png',
                bbox_inches='tight', dpi=1200)
    ntk_matrix_rdiff_norm = np.linalg.norm(ntk_matrix_rdiff.flatten())
    print(f'The total norm of the NTK sample relative diff is {ntk_matrix_rdiff_norm:.2f}')

    n1_mean = np.mean(ntk_matrix_before)
    n2_mean = np.mean(ntk_matrix_after)
    matrix_corr = (ntk_matrix_before - n1_mean) * (ntk_matrix_after - n2_mean) / \
        np.std(ntk_matrix_before) / np.std(ntk_matrix_after)
    plt.imshow(matrix_corr)
    plt.savefig(config['path'] + f'{args.net}{config["width"]}_CIFAR_NTK_CORR.png',
                bbox_inches='tight', dpi=1200)
    corr_coeff = np.mean(matrix_corr)
    print(f'The correlation coefficient of the NTK sample before and after training is {corr_coeff:.2f}')

    matrix_sim = (ntk_matrix_before * ntk_matrix_after) / \
        np.sqrt(np.sum(ntk_matrix_before**2) * np.sum(ntk_matrix_after**2))
    plt.imshow(matrix_sim)
    plt.savefig(config['path'] + f'{args.net}{config["width"]}_CIFAR_NTK_SIM.png',
                bbox_inches='tight', dpi=1200)
    corr_tom = np.sum(matrix_sim)
    print(f'The similarity coefficient of the NTK sample before and after training is {corr_tom:.2f}')

    save_output(args.table_path, name='ntk', width=config['width'],
                num_params=num_params, before_norm=ntk_matrix_before_norm,
                after_norm=ntk_matrix_after_norm, diff_norm=ntk_matrix_diff_norm,
                rdiff_norm=ntk_matrix_rdiff_norm,
                param_norm_before=param_norm_before,
                param_norm_after=param_norm_after, corr_coeff=corr_coeff,
                corr_tom=corr_tom)

    if args.pdist:
        # Check feature maps after training
        pdist_after, cos_after, prod_after = batch_feature_correlations(trainloader)
        pdist_after_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_after])
        cos_after_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_after])
        prod_after_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_after])
        print(f'The total norm of feature distances after training is {pdist_after_norm:.2f}')
        print(f'The total norm of feature cosine similarity after training is {cos_after_norm:.2f}')
        print(f'The total norm of feature inner product after training is {prod_after_norm:.2f}')
        save_plot(pdist_after, trainloader, name='pdist_after_training')
        save_plot(cos_after, trainloader, name='cosine_after_training')
        save_plot(prod_after, trainloader, name='prod_after_training')

        # Check feature map differences, normalized by the initial norms
        pdist_ndiff = [np.abs(co1 - co2) / pdist_init_norm
                       for co1, co2 in zip(pdist_init, pdist_after)]
        cos_ndiff = [np.abs(co1 - co2) / cos_init_norm
                     for co1, co2 in zip(cos_init, cos_after)]
        prod_ndiff = [np.abs(co1 - co2) / prod_init_norm
                      for co1, co2 in zip(prod_init, prod_after)]
        pdist_ndiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_ndiff])
        cos_ndiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_ndiff])
        prod_ndiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_ndiff])
        print(f'The total norm normalized diff of feature distances after training is {pdist_ndiff_norm:.2f}')
        print(f'The total norm normalized diff of feature cosine similarity after training is {cos_ndiff_norm:.2f}')
        print(f'The total norm normalized diff of feature inner product after training is {prod_ndiff_norm:.2f}')
        save_plot(pdist_ndiff, trainloader, name='pdist_ndiff')
        save_plot(cos_ndiff, trainloader, name='cosine_ndiff')
        save_plot(prod_ndiff, trainloader, name='prod_ndiff')

        # Check feature map differences, relative to the initial entries
        pdist_rdiff = [np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
                       for co1, co2 in zip(pdist_init, pdist_after)]
        cos_rdiff = [np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
                     for co1, co2 in zip(cos_init, cos_after)]
        prod_rdiff = [np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
                      for co1, co2 in zip(prod_init, prod_after)]
        pdist_rdiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_rdiff])
        cos_rdiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_rdiff])
        prod_rdiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_rdiff])
        print(f'The total norm relative diff of feature distances after training is {pdist_rdiff_norm:.2f}')
        print(f'The total norm relative diff of feature cosine similarity after training is {cos_rdiff_norm:.2f}')
        print(f'The total norm relative diff of feature inner product after training is {prod_rdiff_norm:.2f}')
        save_plot(pdist_rdiff, trainloader, name='pdist_rdiff')
        save_plot(cos_rdiff, trainloader, name='cosine_rdiff')
        save_plot(prod_rdiff, trainloader, name='prod_rdiff')

        save_output(args.table_path, name='pdist', width=config['width'],
                    num_params=num_params,
                    pdist_init_norm=pdist_init_norm, pdist_after_norm=pdist_after_norm,
                    pdist_ndiff_norm=pdist_ndiff_norm, pdist_rdiff_norm=pdist_rdiff_norm,
                    cos_init_norm=cos_init_norm, cos_after_norm=cos_after_norm,
                    cos_ndiff_norm=cos_ndiff_norm, cos_rdiff_norm=cos_rdiff_norm,
                    prod_init_norm=prod_init_norm, prod_after_norm=prod_after_norm,
                    prod_ndiff_norm=prod_ndiff_norm, prod_rdiff_norm=prod_rdiff_norm)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
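# batch_wise_ntk is defined elsewhere in this codebase. As a point of
# reference, here is a minimal sketch of a single empirical-NTK entry under
# the assumption that it computes pairwise Jacobian inner products
# <df(x1)/dtheta, df(x2)/dtheta>; all names below are illustrative.
import torch

def ntk_entry(net, x1, x2):
    """One empirical NTK entry for inputs x1, x2, summed over output logits."""
    params = [p for p in net.parameters() if p.requires_grad]

    def grad_vec(x):
        out = net(x.unsqueeze(0)).sum()  # scalar surrogate: sum of the logits
        grads = torch.autograd.grad(out, params)
        return torch.cat([g.reshape(-1) for g in grads])

    return torch.dot(grad_vec(x1), grad_vec(x2)).item()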
def main():
    """Check NTKs in a single call."""
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'],
                                             augmentations=False, shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16, widen_factor=config['width'],
                          dropout_rate=0.0, num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(config['width'], config['width'])),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(config['width'], config['width'])),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10, width_mult=config['width'],
                          round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'], 4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(OrderedDict([
            ('conv0', torch.nn.Conv2d(3, 1 * config['width'], kernel_size=3, padding=1)),
            ('relu0', torch.nn.ReLU()),
            # ('pool0', torch.nn.MaxPool2d(3)),
            ('conv1', torch.nn.Conv2d(1 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu1', torch.nn.ReLU()),
            # ('pool1', torch.nn.MaxPool2d(3)),
            ('conv2', torch.nn.Conv2d(2 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu2', torch.nn.ReLU()),
            # ('pool2', torch.nn.MaxPool2d(3)),
            ('conv3', torch.nn.Conv2d(2 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu3', torch.nn.ReLU()),
            ('pool3', torch.nn.MaxPool2d(3)),
            ('conv4', torch.nn.Conv2d(4 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu4', torch.nn.ReLU()),
            ('pool4', torch.nn.MaxPool2d(3)),
            ('flatten', torch.nn.Flatten()),
            ('linear', torch.nn.Linear(36 * config['width'], 10))]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    try:
        net.load_state_dict(
            torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth',
                       map_location=device))
        print('Initialized net loaded from file.')
    except Exception:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth'
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {path}')

    num_params = sum(p.numel() for p in net.parameters())
    print(f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
          f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    param_norm_before = np.sqrt(np.sum(
        [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_before:.2f}')

    # Keep a copy of the initial parameters to measure the training drift below.
    net_init = [p.detach().clone() for p in net.parameters()]

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()
    analyze_model(net, trainloader, testloader, loss_fn, config)
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    try:
        net.load_state_dict(
            torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth',
                       map_location=device))
        print('Net loaded from file.')
    except Exception as e:  # :>
        print(repr(e))
        print('Could not find model data ... aborting ...')
        return

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    param_norm_after = np.sqrt(np.sum(
        [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_after:.2f}')

    # Total parameter change: ||theta_after - theta_before||_2
    change_total = 0.0
    for p1, p2 in zip(net_init, net.parameters()):
        change_total += (p1 - p2).detach().pow(2).sum()
    change_total = change_total.sqrt().cpu().numpy()

    # Size-normalized change: root of the summed per-tensor mean squares
    change_rel = 0.0
    for p1, p2 in zip(net_init, net.parameters()):
        change_rel += (p1 - p2).detach().pow(2).mean()
    change_rel = change_rel.sqrt().cpu().numpy()

    # Sum of per-tensor norms of the change
    change_nrmsum = 0.0
    for p1, p2 in zip(net_init, net.parameters()):
        change_nrmsum += (p1 - p2).norm()
    change_nrmsum = change_nrmsum.cpu().numpy()

    # Analyze results
    acc_train, acc_test, loss_train, loss_trainw, grd_train = analyze_model(
        net, trainloader, testloader, loss_fn, config)

    save_output(args.table_path, name='ntk_stats', width=config['width'],
                num_params=num_params, acc_train=acc_train, acc_test=acc_test,
                loss_train=loss_train, loss_trainw=loss_trainw,
                grd_train=grd_train, param_norm_before=param_norm_before,
                param_norm_after=param_norm_after, change_total=change_total,
                change_rel=change_rel, change_nrmsum=change_nrmsum)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
def main():
    """Check NTKs in a single call."""
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'],
                                             augmentations=False, shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16, widen_factor=config['width'],
                          dropout_rate=0.0, num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(config['width'], config['width'])),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(config['width'], config['width'])),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10, width_mult=config['width'],
                          round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'], 4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(OrderedDict([
            ('conv0', torch.nn.Conv2d(3, 1 * config['width'], kernel_size=3, padding=1)),
            ('relu0', torch.nn.ReLU()),
            # ('pool0', torch.nn.MaxPool2d(3)),
            ('conv1', torch.nn.Conv2d(1 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu1', torch.nn.ReLU()),
            # ('pool1', torch.nn.MaxPool2d(3)),
            ('conv2', torch.nn.Conv2d(2 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu2', torch.nn.ReLU()),
            # ('pool2', torch.nn.MaxPool2d(3)),
            ('conv3', torch.nn.Conv2d(2 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu3', torch.nn.ReLU()),
            ('pool3', torch.nn.MaxPool2d(3)),
            ('conv4', torch.nn.Conv2d(4 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu4', torch.nn.ReLU()),
            ('pool4', torch.nn.MaxPool2d(3)),
            ('flatten', torch.nn.Flatten()),
            ('linear', torch.nn.Linear(36 * config['width'], 10))]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    try:
        net.load_state_dict(
            torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth',
                       map_location=device))
        print('Initialized net loaded from file.')
    except Exception:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth'
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {path}')

    num_params = sum(p.numel() for p in net.parameters())
    print(f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
          f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    try:
        net.load_state_dict(
            torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth',
                       map_location=device))
        print('Net loaded from file.')
    except Exception:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth'
        dl.train(net, optimizer, scheduler, loss_fn, trainloader, config,
                 path=None, dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Net saved to file.')
        else:
            print(f'Would save to {path}')

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    # NOTE: the NTK statistics below (ntk_matrix_*_norm, param_norm_*, corr_*)
    # are computed in the fuller variant of this experiment and are undefined
    # in this trimmed script, so the call is left disabled.
    # save_output(args.table_path, name='ntk', width=config['width'],
    #             num_params=num_params, before_norm=ntk_matrix_before_norm,
    #             after_norm=ntk_matrix_after_norm, diff_norm=ntk_matrix_diff_norm,
    #             rdiff_norm=ntk_matrix_rdiff_norm,
    #             param_norm_before=param_norm_before,
    #             param_norm_after=param_norm_after, corr_coeff=corr_coeff,
    #             corr_tom=corr_tom)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
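# save_output is not defined in these snippets. A minimal sketch of what it
# could look like, assuming it appends one row of named scalars to a CSV
# table per experiment; the file layout here is an assumption, not the
# original implementation.
import csv
import os

def save_output(table_path, name='ntk', **kwargs):
    fname = os.path.join(table_path, f'table_{name}.csv')
    write_header = not os.path.isfile(fname)
    with open(fname, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(kwargs.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(kwargs)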
def main():
    """Check NTKs in a single call."""
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'],
                                             augmentations=False, shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16, widen_factor=config['width'],
                          dropout_rate=0.0, num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(config['width'], config['width'])),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(config['width'], config['width'])),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10, width_mult=config['width'],
                          round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'], 4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(OrderedDict([
            ('conv0', torch.nn.Conv2d(3, 1 * config['width'], kernel_size=3, padding=1)),
            ('relu0', torch.nn.ReLU()),
            # ('pool0', torch.nn.MaxPool2d(3)),
            ('conv1', torch.nn.Conv2d(1 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu1', torch.nn.ReLU()),
            # ('pool1', torch.nn.MaxPool2d(3)),
            ('conv2', torch.nn.Conv2d(2 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu2', torch.nn.ReLU()),
            # ('pool2', torch.nn.MaxPool2d(3)),
            ('conv3', torch.nn.Conv2d(2 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu3', torch.nn.ReLU()),
            ('pool3', torch.nn.MaxPool2d(3)),
            ('conv4', torch.nn.Conv2d(4 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu4', torch.nn.ReLU()),
            ('pool4', torch.nn.MaxPool2d(3)),
            ('flatten', torch.nn.Flatten()),
            ('linear', torch.nn.Linear(36 * config['width'], 10))]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    num_params = sum(p.numel() for p in net.parameters())
    print(f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
          f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    def batch_feature_correlations(dataloader, device=torch.device('cpu')):
        net.eval()
        net.to(device)
        dist_maps = list()
        cosine_maps = list()
        prod_maps = list()
        hooks = []

        def batch_wise_feature_correlation(self, input, output):
            feat_vec = input[0].detach().view(dataloader.batch_size, -1)
            dist_maps.append(torch.cdist(feat_vec, feat_vec, p=2).detach().cpu().numpy())
            cosine_map = np.empty((dataloader.batch_size, dataloader.batch_size))
            prod_map = np.empty((dataloader.batch_size, dataloader.batch_size))
            for row in range(dataloader.batch_size):
                cosine_map[row, :] = torch.nn.functional.cosine_similarity(
                    feat_vec[row:row + 1, :], feat_vec, dim=1,
                    eps=1e-8).detach().cpu().numpy()
                prod_map[row, :] = torch.mean(feat_vec[row:row + 1, :] * feat_vec,
                                              dim=1).detach().cpu().numpy()
            cosine_maps.append(cosine_map)
            prod_maps.append(prod_map)

        # Hook the input of the final linear layer, whichever attribute name
        # it has in the chosen architecture.
        if isinstance(net, torch.nn.DataParallel):
            hooks.append(net.module.linear.register_forward_hook(
                batch_wise_feature_correlation))
        else:
            if args.net in ['MLP', 'TwoLP']:
                hooks.append(net.linear3.register_forward_hook(
                    batch_wise_feature_correlation))
            elif args.net in ['VGG', 'MobileNetV2']:
                hooks.append(net.classifier.register_forward_hook(
                    batch_wise_feature_correlation))
            else:
                hooks.append(net.linear.register_forward_hook(
                    batch_wise_feature_correlation))

        for inputs, _ in dataloader:
            net(inputs.to(device))
            if args.dryrun:
                break

        for hook in hooks:
            hook.remove()

        return dist_maps, cosine_maps, prod_maps

    pdist_init, cos_init, prod_init = batch_feature_correlations(trainloader)
    pdist_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_init])
    cos_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_init])
    prod_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_init])
    print(f'The total norm of feature distances before training is {pdist_init_norm:.2f}')
    print(f'The total norm of feature cosine similarity before training is {cos_init_norm:.2f}')
    print(f'The total norm of feature inner product before training is {prod_init_norm:.2f}')
    save_plot(pdist_init, trainloader, name='pdist_before_training')
    save_plot(cos_init, trainloader, name='cosine_before_training')
    save_plot(prod_init, trainloader, name='prod_before_training')

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    try:
        net.load_state_dict(
            torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '.pth',
                       map_location=device))
        print('Net loaded from file.')
    except Exception:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '.pth'
        dl.train(net, optimizer, scheduler, loss_fn, trainloader, config,
                 path=None, dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Net saved to file.')

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    # Check feature maps after training
    pdist_after, cos_after, prod_after = batch_feature_correlations(trainloader)
    pdist_after_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_after])
    cos_after_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_after])
    prod_after_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_after])
    print(f'The total norm of feature distances after training is {pdist_after_norm:.2f}')
    print(f'The total norm of feature cosine similarity after training is {cos_after_norm:.2f}')
    print(f'The total norm of feature inner product after training is {prod_after_norm:.2f}')
    save_plot(pdist_after, trainloader, name='pdist_after_training')
    save_plot(cos_after, trainloader, name='cosine_after_training')
    save_plot(prod_after, trainloader, name='prod_after_training')

    # Check feature map differences, normalized by the initial norms
    pdist_ndiff = [np.abs(co1 - co2) / pdist_init_norm
                   for co1, co2 in zip(pdist_init, pdist_after)]
    cos_ndiff = [np.abs(co1 - co2) / cos_init_norm
                 for co1, co2 in zip(cos_init, cos_after)]
    prod_ndiff = [np.abs(co1 - co2) / prod_init_norm
                  for co1, co2 in zip(prod_init, prod_after)]
    pdist_ndiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_ndiff])
    cos_ndiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_ndiff])
    prod_ndiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_ndiff])
    print(f'The total norm normalized diff of feature distances after training is {pdist_ndiff_norm:.2f}')
    print(f'The total norm normalized diff of feature cosine similarity after training is {cos_ndiff_norm:.2f}')
    print(f'The total norm normalized diff of feature inner product after training is {prod_ndiff_norm:.2f}')
    save_plot(pdist_ndiff, trainloader, name='pdist_ndiff')
    save_plot(cos_ndiff, trainloader, name='cosine_ndiff')
    save_plot(prod_ndiff, trainloader, name='prod_ndiff')

    # Check feature map differences, relative to the initial entries
    pdist_rdiff = [np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
                   for co1, co2 in zip(pdist_init, pdist_after)]
    cos_rdiff = [np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
                 for co1, co2 in zip(cos_init, cos_after)]
    prod_rdiff = [np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
                  for co1, co2 in zip(prod_init, prod_after)]
    pdist_rdiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_rdiff])
    cos_rdiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_rdiff])
    prod_rdiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_rdiff])
    print(f'The total norm relative diff of feature distances after training is {pdist_rdiff_norm:.2f}')
    print(f'The total norm relative diff of feature cosine similarity after training is {cos_rdiff_norm:.2f}')
    print(f'The total norm relative diff of feature inner product after training is {prod_rdiff_norm:.2f}')
    save_plot(pdist_rdiff, trainloader, name='pdist_rdiff')
    save_plot(cos_rdiff, trainloader, name='cosine_rdiff')
    save_plot(prod_rdiff, trainloader, name='prod_rdiff')

    save_output(args.table_path, width=config['width'], num_params=num_params,
                pdist_init_norm=pdist_init_norm, pdist_after_norm=pdist_after_norm,
                pdist_ndiff_norm=pdist_ndiff_norm, pdist_rdiff_norm=pdist_rdiff_norm,
                cos_init_norm=cos_init_norm, cos_after_norm=cos_after_norm,
                cos_ndiff_norm=cos_ndiff_norm, cos_rdiff_norm=cos_rdiff_norm,
                prod_init_norm=prod_init_norm, prod_after_norm=prod_after_norm,
                prod_ndiff_norm=prod_ndiff_norm, prod_rdiff_norm=prod_rdiff_norm)

    # Save raw data
    raw_pkg = dict(pdist_init=pdist_init, cos_init=cos_init, prod_init=prod_init,
                   pdist_after=pdist_after, cos_after=cos_after, prod_after=prod_after,
                   pdist_ndiff=pdist_ndiff, cos_ndiff=cos_ndiff, prod_ndiff=prod_ndiff,
                   pdist_rdiff=pdist_rdiff, cos_rdiff=cos_rdiff, prod_rdiff=prod_rdiff)
    path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_rawmaps.pth'
    torch.save(raw_pkg, path)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
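# Usage sketch: the raw correlation maps saved above can be reloaded later
# for offline analysis from the same path the function writes to.
raw_pkg = torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_rawmaps.pth')
pdist_init, pdist_after = raw_pkg['pdist_init'], raw_pkg['pdist_after']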