def evaluation_procedure(config):
    """Train model and evaluate eigenvalues with given configuration."""
    # Setup data
    augmentations = False
    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'], augmentations=augmentations,
                                             normalize=True, shuffle=False)

    if config['model'] == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, 2048)),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(2048, 2048)),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(2048, 1024)),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(1024, 10))]))
    elif config['model'] == 'ResNet':
        net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10)
    else:
        raise NotImplementedError()

    linear_classifier = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3072, 10))
    linear_classifier.to(**config['setup'])
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
        linear_classifier = torch.nn.DataParallel(linear_classifier)
    net.eval()

    # Optimizer and loss
    optimizer = torch.optim.SGD(linear_classifier.parameters(), lr=config['lr'], momentum=0.9,
                                weight_decay=config['weight_decay'])
    config['epochs'] = config['epochs_linear']
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 75, 85, 95], gamma=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Check initial model
    analyze_model(linear_classifier, trainloader, testloader, loss_fn, config)
    linear_classifier.to(**config['setup'])
    net.to(**config['setup'])

    # Train
    print('Starting training linear classifier ...')
    dl.train(linear_classifier, optimizer, scheduler, loss_fn, trainloader, config, dryrun=args.dryrun)

    # Analyze results
    print('----Results after training linear classifier ------------')
    analyze_model(linear_classifier, trainloader, testloader, loss_fn, config)
    for name, param in linear_classifier.named_parameters():
        dprint(name, param)
        param.requires_grad = False

    # Check full model
    print('----Distill learned classifier onto network ------------')
    config['epochs'] = config['epochs_distill']
    loss_distill = torch.nn.KLDivLoss(reduction='batchmean')
    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[120, 180, 240], gamma=0.2)
    dl.distill(linear_classifier, net, optimizer, scheduler, loss_distill, trainloader, config,
               dryrun=args.dryrun)

    # Analyze results
    analyze_model(net, trainloader, testloader, loss_fn, config)
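# analyze_model() is a repo helper that is not shown in this file; from its use below it reports
# accuracies, losses, the training-gradient norm and the extreme Hessian eigenvalues (maxeig, mineig).
# The following is only a minimal sketch of how such extreme eigenvalues could be estimated with
# Hessian-vector products and power iteration -- an illustrative assumption, not the repo's code.
import torch


def hessian_extreme_eigs_sketch(net, loss_fn, inputs, targets, iters=20, shift=10.0):
    """Estimate the largest and smallest Hessian eigenvalues of the loss w.r.t. the parameters."""
    params = [p for p in net.parameters() if p.requires_grad]
    loss = loss_fn(net(inputs), targets)
    grads = torch.autograd.grad(loss, params, create_graph=True)

    def hvp(vec):
        # Hessian-vector product via double backprop.
        dot = sum((g * v).sum() for g, v in zip(grads, vec))
        return torch.autograd.grad(dot, params, retain_graph=True)

    def power_iteration(matvec):
        # Assumes the dominant eigenvalue of the operator is positive.
        vec = [torch.randn_like(p) for p in params]
        eigval = 0.0
        for _ in range(iters):
            norm = torch.sqrt(sum((v ** 2).sum() for v in vec))
            vec = [v / norm for v in vec]
            new_vec = matvec(vec)
            eigval = sum((nv * v).sum() for nv, v in zip(new_vec, vec)).item()
            vec = list(new_vec)
        return eigval

    max_eig = power_iteration(hvp)
    # Power iteration on (shift*I - H) converges to shift - min_eig, provided shift lies above
    # the midpoint of the spectrum.
    min_eig = shift - power_iteration(lambda v: [shift * vi - hi for vi, hi in zip(v, hvp(v))])
    return max_eig, min_eig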
def main():
    """Check ntks in a single call."""
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'], augmentations=False, shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16, widen_factor=config['width'], dropout_rate=0.0, num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(config['width'], config['width'])),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(config['width'], config['width'])),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10, width_mult=config['width'], round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'], 4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(OrderedDict([
            ('conv0', torch.nn.Conv2d(3, 1 * config['width'], kernel_size=3, padding=1)),
            ('relu0', torch.nn.ReLU()),
            # ('pool0', torch.nn.MaxPool2d(3)),
            ('conv1', torch.nn.Conv2d(1 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu1', torch.nn.ReLU()),
            # ('pool1', torch.nn.MaxPool2d(3)),
            ('conv2', torch.nn.Conv2d(2 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu2', torch.nn.ReLU()),
            # ('pool2', torch.nn.MaxPool2d(3)),
            ('conv3', torch.nn.Conv2d(2 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu3', torch.nn.ReLU()),
            ('pool3', torch.nn.MaxPool2d(3)),
            ('conv4', torch.nn.Conv2d(4 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu4', torch.nn.ReLU()),
            ('pool4', torch.nn.MaxPool2d(3)),
            ('flatten', torch.nn.Flatten()),
            ('linear', torch.nn.Linear(36 * config['width'], 10))]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    try:
        net.load_state_dict(torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth',
                                       map_location=device))
        print('Initialized net loaded from file.')
    except Exception as e:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth'
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {path}')

    num_params = sum([p.numel() for p in net.parameters()])
    print(f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
          f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    # NTK sample before training
    ntk_matrix_before = batch_wise_ntk(net, trainloader, samplesize=args.sampling)
    plt.imshow(ntk_matrix_before)
    plt.savefig(config['path'] + f'{args.net}{config["width"]}_CIFAR_NTK_BEFORE.png',
                bbox_inches='tight', dpi=1200)
    ntk_matrix_before_norm = np.linalg.norm(ntk_matrix_before.flatten())
    print(f'The total norm of the NTK sample before training is {ntk_matrix_before_norm:.2f}')

    param_norm_before = np.sqrt(np.sum([p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_before:.2f}')

    if args.pdist:
        pdist_init, cos_init, prod_init = batch_feature_correlations(trainloader)
        pdist_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_init])
        cos_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_init])
        prod_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_init])
        print(f'The total norm of feature distances before training is {pdist_init_norm:.2f}')
        print(f'The total norm of feature cosine similarity before training is {cos_init_norm:.2f}')
        print(f'The total norm of feature inner product before training is {prod_init_norm:.2f}')
        save_plot(pdist_init, trainloader, name='pdist_before_training')
        save_plot(cos_init, trainloader, name='cosine_before_training')
        save_plot(prod_init, trainloader, name='prod_before_training')

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    try:
        net.load_state_dict(torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth',
                                       map_location=device))
        print('Net loaded from file.')
    except Exception as e:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth'
        dl.train(net, optimizer, scheduler, loss_fn, trainloader, config, path=None, dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Net saved to file.')
        else:
            print(f'Would save to {path}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    param_norm_after = np.sqrt(np.sum([p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_after:.2f}')

    # NTK sample after training
    ntk_matrix_after = batch_wise_ntk(net, trainloader, samplesize=args.sampling)
    plt.imshow(ntk_matrix_after)
    plt.savefig(config['path'] + f'{args.net}{config["width"]}_CIFAR_NTK_AFTER.png',
                bbox_inches='tight', dpi=1200)
    ntk_matrix_after_norm = np.linalg.norm(ntk_matrix_after.flatten())
    print(f'The total norm of the NTK sample after training is {ntk_matrix_after_norm:.2f}')

    # Absolute difference
    ntk_matrix_diff = np.abs(ntk_matrix_before - ntk_matrix_after)
    plt.imshow(ntk_matrix_diff)
    plt.savefig(config['path'] + f'{args.net}{config["width"]}_CIFAR_NTK_DIFF.png',
                bbox_inches='tight', dpi=1200)
    ntk_matrix_diff_norm = np.linalg.norm(ntk_matrix_diff.flatten())
    print(f'The total norm of the NTK sample diff is {ntk_matrix_diff_norm:.2f}')

    # Relative difference
    ntk_matrix_rdiff = np.abs(ntk_matrix_before - ntk_matrix_after) / (np.abs(ntk_matrix_before) + 1e-4)
    plt.imshow(ntk_matrix_rdiff)
    plt.savefig(config['path'] + f'{args.net}{config["width"]}_CIFAR_NTK_RDIFF.png',
                bbox_inches='tight', dpi=1200)
    ntk_matrix_rdiff_norm = np.linalg.norm(ntk_matrix_rdiff.flatten())
    print(f'The total norm of the NTK sample relative diff is {ntk_matrix_rdiff_norm:.2f}')

    # Entry-wise correlation of the NTK sample before and after training
    n1_mean = np.mean(ntk_matrix_before)
    n2_mean = np.mean(ntk_matrix_after)
    matrix_corr = (ntk_matrix_before - n1_mean) * (ntk_matrix_after - n2_mean) / \
        np.std(ntk_matrix_before) / np.std(ntk_matrix_after)
    plt.imshow(matrix_corr)
    plt.savefig(config['path'] + f'{args.net}{config["width"]}_CIFAR_NTK_CORR.png',
                bbox_inches='tight', dpi=1200)
    corr_coeff = np.mean(matrix_corr)
    print(f'The Correlation coefficient of the NTK sample before and after training is {corr_coeff:.2f}')

    # Cosine-type similarity of the NTK sample before and after training
    matrix_sim = (ntk_matrix_before * ntk_matrix_after) / \
        np.sqrt(np.sum(ntk_matrix_before**2) * np.sum(ntk_matrix_after**2))
    plt.imshow(matrix_sim)
    plt.savefig(config['path'] + f'{args.net}{config["width"]}_CIFAR_NTK_SIM.png',
                bbox_inches='tight', dpi=1200)
    corr_tom = np.sum(matrix_sim)
    print(f'The Similarity coefficient of the NTK sample before and after training is {corr_tom:.2f}')

    save_output(args.table_path, name='ntk', width=config['width'], num_params=num_params,
                before_norm=ntk_matrix_before_norm, after_norm=ntk_matrix_after_norm,
                diff_norm=ntk_matrix_diff_norm, rdiff_norm=ntk_matrix_rdiff_norm,
                param_norm_before=param_norm_before, param_norm_after=param_norm_after,
                corr_coeff=corr_coeff, corr_tom=corr_tom)

    if args.pdist:
        # Check feature maps after training
        pdist_after, cos_after, prod_after = batch_feature_correlations(trainloader)
        pdist_after_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_after])
        cos_after_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_after])
        prod_after_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_after])
        print(f'The total norm of feature distances after training is {pdist_after_norm:.2f}')
        print(f'The total norm of feature cosine similarity after training is {cos_after_norm:.2f}')
        print(f'The total norm of feature inner product after training is {prod_after_norm:.2f}')
        save_plot(pdist_after, trainloader, name='pdist_after_training')
        save_plot(cos_after, trainloader, name='cosine_after_training')
        save_plot(prod_after, trainloader, name='prod_after_training')

        # Check feature map differences (normalized by the norm before training)
        pdist_ndiff = [np.abs(co1 - co2) / pdist_init_norm for co1, co2 in zip(pdist_init, pdist_after)]
        cos_ndiff = [np.abs(co1 - co2) / cos_init_norm for co1, co2 in zip(cos_init, cos_after)]
        prod_ndiff = [np.abs(co1 - co2) / prod_init_norm for co1, co2 in zip(prod_init, prod_after)]
        pdist_ndiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_ndiff])
        cos_ndiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_ndiff])
        prod_ndiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_ndiff])
        print(f'The total norm normalized diff of feature distances after training is {pdist_ndiff_norm:.2f}')
        print(f'The total norm normalized diff of feature cosine similarity after training is {cos_ndiff_norm:.2f}')
        print(f'The total norm normalized diff of feature inner product after training is {prod_ndiff_norm:.2f}')
        save_plot(pdist_ndiff, trainloader, name='pdist_ndiff')
        save_plot(cos_ndiff, trainloader, name='cosine_ndiff')
        save_plot(prod_ndiff, trainloader, name='prod_ndiff')

        # Check feature map differences (relative, entry-wise)
        pdist_rdiff = [np.abs(co1 - co2) / (np.abs(co1) + 1e-6) for co1, co2 in zip(pdist_init, pdist_after)]
        cos_rdiff = [np.abs(co1 - co2) / (np.abs(co1) + 1e-6) for co1, co2 in zip(cos_init, cos_after)]
        prod_rdiff = [np.abs(co1 - co2) / (np.abs(co1) + 1e-6) for co1, co2 in zip(prod_init, prod_after)]
        pdist_rdiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_rdiff])
        cos_rdiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_rdiff])
        prod_rdiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_rdiff])
        print(f'The total norm relative diff of feature distances after training is {pdist_rdiff_norm:.2f}')
        print(f'The total norm relative diff of feature cosine similarity after training is {cos_rdiff_norm:.2f}')
        print(f'The total norm relative diff of feature inner product after training is {prod_rdiff_norm:.2f}')
        save_plot(pdist_rdiff, trainloader, name='pdist_rdiff')
        save_plot(cos_rdiff, trainloader, name='cosine_rdiff')
        save_plot(prod_rdiff, trainloader, name='prod_rdiff')

        save_output(args.table_path, 'pdist', width=config['width'], num_params=num_params,
                    pdist_init_norm=pdist_init_norm, pdist_after_norm=pdist_after_norm,
                    pdist_ndiff_norm=pdist_ndiff_norm, pdist_rdiff_norm=pdist_rdiff_norm,
                    cos_init_norm=cos_init_norm, cos_after_norm=cos_after_norm,
                    cos_ndiff_norm=cos_ndiff_norm, cos_rdiff_norm=cos_rdiff_norm,
                    prod_init_norm=prod_init_norm, prod_after_norm=prod_after_norm,
                    prod_ndiff_norm=prod_ndiff_norm, prod_rdiff_norm=prod_rdiff_norm)

        # Save raw data
        # raw_pkg = dict(pdist_init=pdist_init, cos_init=cos_init, prod_init=prod_init,
        #                pdist_after=pdist_after, cos_after=cos_after, prod_after=prod_after,
        #                pdist_ndiff=pdist_ndiff, cos_ndiff=cos_ndiff, prod_ndiff=prod_ndiff,
        #                pdist_rdiff=pdist_rdiff, cos_rdiff=cos_rdiff, prod_rdiff=prod_rdiff)
        # path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_rawmaps.pth'
        # torch.save(raw_pkg, path)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
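# batch_wise_ntk() is defined elsewhere in the repo. As a point of reference only, here is a
# minimal sketch of an empirical NTK sample on a handful of images, assuming the entry
# K[i, j] = <grad_theta f(x_i), grad_theta f(x_j)> with the logits summed to one scalar per image;
# the repo's helper (and its samplesize/plotting conventions) may differ.
import numpy as np
import torch


def batch_wise_ntk_sketch(net, dataloader, samplesize=10):
    """Empirical NTK Gram matrix of per-sample parameter gradients."""
    net.eval()
    device = next(net.parameters()).device
    params = [p for p in net.parameters() if p.requires_grad]
    inputs, _ = next(iter(dataloader))
    inputs = inputs[:samplesize].to(device)

    grads = []
    for i in range(inputs.shape[0]):
        out = net(inputs[i:i + 1]).sum()  # scalar surrogate: sum of logits for image i
        g = torch.autograd.grad(out, params)
        grads.append(torch.cat([gi.reshape(-1) for gi in g]))

    ntk = torch.zeros(len(grads), len(grads))
    for i, gi in enumerate(grads):
        for j, gj in enumerate(grads):
            ntk[i, j] = torch.dot(gi, gj).item()
    return ntk.numpy()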
def evaluation_procedure(config):
    """Train model and evaluate eigenvalues with given configuration."""
    # Setup data
    augmentations = False
    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'], augmentations=augmentations,
                                             normalize=args.normalize, shuffle=False)

    class Restrict(torch.nn.Module):
        def __init__(self, subrank):
            super(Restrict, self).__init__()
            self.shape = int(subrank)

        def forward(self, x):
            return x[:, :self.shape]

    if config['model'] == 'MLP':
        fullnet = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, args.width)),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(args.width, args.width)),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(args.width, args.width)),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(args.width, 10))]))
        # breakpoint()
        subnet = torch.nn.Sequential(torch.nn.Flatten(), Restrict(args.width),
                                     *list(fullnet.children())[-config['subnet_depth']:])
    else:
        raise NotImplementedError()

    subnet.to(**config['setup'])
    fullnet.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        subnet = torch.nn.DataParallel(subnet)
        fullnet = torch.nn.DataParallel(fullnet)
    subnet.eval()
    fullnet.eval()

    # Optimizer and loss
    optimizer = torch.optim.SGD(subnet.parameters(), lr=config['lr'], momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 200, 400, 600, 700], gamma=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Check initial model
    analyze_model(subnet, trainloader, testloader, loss_fn, config)
    subnet.to(**config['setup'])
    fullnet.to(**config['setup'])

    # Train
    print('Starting training subnet ...........................................')
    dl.train(subnet, optimizer, scheduler, loss_fn, trainloader, config, dryrun=args.dryrun)

    # Analyze results
    print('----Results after training subnet -----------------------------------------------------------')
    analyze_model(subnet, trainloader, testloader, loss_fn, config)
    for name, param in subnet.named_parameters():
        dprint(name, param)

    # Check full model
    print('----Extend to full model and check local optimality -----------------------------------------')
    # assert all([p1 is p2 for (p1, p2) in zip(fullnet[-1].parameters(), subnet.parameters())])
    bias_first = True
    bias_offset = 2
    for name, param in fullnet.named_parameters():
        if all([param is not p for p in subnet.parameters()]):
            dprint(f'Currently setting {name}')
            if 'weight' in name:
                torch.nn.init.eye_(param)
                dprint(f'{name} set to Id.')
            elif 'bias' in name:
                if bias_first:
                    torch.nn.init.constant_(param, bias_offset)
                    bias_first = False
                    dprint(f'{name} set to 1.')
                else:
                    torch.nn.init.constant_(param, 0)
                    dprint(f'{name} set to 0.')
                # if normalize=False, input will be in [0,1] so no bias is necessary
            elif 'conv.weight' in name:
                torch.nn.init.dirac_(param)
                dprint(f'{name} set to dirac.')
        else:
            if 'linear3.bias' in name:
                Axb = subnet(bias_offset * torch.ones(1, 3072, **config['setup'])).detach().squeeze()
                param.data -= Axb - param.data
                dprint(f'{name} set to b - Ax')
    print('Model extended to full model.')
    for name, param in fullnet.named_parameters():
        dprint(name, param)

    # Analyze results
    analyze_model(fullnet, trainloader, testloader, loss_fn, config)

    # Finetune
    print('Finetune full net .................................................')
    config['full_batch'] = False
    optimizer = torch.optim.SGD(subnet.parameters(), lr=1e-4, momentum=0.9, weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 200, 400, 600, 700], gamma=0.1)
    dl.train(fullnet, optimizer, scheduler, loss_fn, trainloader, config, dryrun=args.dryrun)
    analyze_model(fullnet, trainloader, testloader, loss_fn, config)
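# A self-contained illustration of the extension trick used above: a trained subnetwork can be
# embedded in a larger network without changing its function by setting the new weights to the
# identity, giving the first new bias a positive offset (so the ReLUs act as the identity on
# inputs in [0, 1]), and correcting the final bias. This is a toy demonstration of the principle,
# not the exact construction performed by evaluation_procedure().
import torch


def _identity_extension_demo():
    torch.manual_seed(0)
    subnet = torch.nn.Linear(8, 3)  # stands in for the trained subnet

    bias_offset = 2.0
    first = torch.nn.Linear(8, 8)   # new layer, initialized to the identity plus a bias offset
    torch.nn.init.eye_(first.weight)
    torch.nn.init.constant_(first.bias, bias_offset)

    last = torch.nn.Linear(8, 3)    # reuses the trained weights, with a corrected bias
    last.weight.data = subnet.weight.data.clone()
    # b_new = b - A @ (offset * 1): subtract the constant that the identity layer adds.
    last.bias.data = subnet.bias.data - subnet.weight.data @ (bias_offset * torch.ones(8))

    fullnet = torch.nn.Sequential(first, torch.nn.ReLU(), last)
    x = torch.rand(5, 8)  # inputs in [0, 1], so ReLU(x + offset) == x + offset
    print(torch.allclose(fullnet(x), subnet(x), atol=1e-6))  # True: same function, more parameters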
def main():
    """Check ntks in a single call."""
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'], augmentations=False, shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16, widen_factor=config['width'], dropout_rate=0.0, num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(config['width'], config['width'])),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(config['width'], config['width'])),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10, width_mult=config['width'], round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'], 4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(OrderedDict([
            ('conv0', torch.nn.Conv2d(3, 1 * config['width'], kernel_size=3, padding=1)),
            ('relu0', torch.nn.ReLU()),
            # ('pool0', torch.nn.MaxPool2d(3)),
            ('conv1', torch.nn.Conv2d(1 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu1', torch.nn.ReLU()),
            # ('pool1', torch.nn.MaxPool2d(3)),
            ('conv2', torch.nn.Conv2d(2 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu2', torch.nn.ReLU()),
            # ('pool2', torch.nn.MaxPool2d(3)),
            ('conv3', torch.nn.Conv2d(2 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu3', torch.nn.ReLU()),
            ('pool3', torch.nn.MaxPool2d(3)),
            ('conv4', torch.nn.Conv2d(4 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu4', torch.nn.ReLU()),
            ('pool4', torch.nn.MaxPool2d(3)),
            ('flatten', torch.nn.Flatten()),
            ('linear', torch.nn.Linear(36 * config['width'], 10))]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    try:
        net.load_state_dict(torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth',
                                       map_location=device))
        print('Initialized net loaded from file.')
    except Exception as e:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth'
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {path}')

    num_params = sum([p.numel() for p in net.parameters()])
    print(f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
          f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    try:
        net.load_state_dict(torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth',
                                       map_location=device))
        print('Net loaded from file.')
    except Exception as e:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth'
        dl.train(net, optimizer, scheduler, loss_fn, trainloader, config, path=None, dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Net saved to file.')
        else:
            print(f'Would save to {path}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    save_output(args.table_path, name='ntk', width=config['width'], num_params=num_params,
                before_norm=ntk_matrix_before_norm, after_norm=ntk_matrix_after_norm,
                diff_norm=ntk_matrix_diff_norm, rdiff_norm=ntk_matrix_rdiff_norm,
                param_norm_before=param_norm_before, param_norm_after=param_norm_after,
                corr_coeff=corr_coeff, corr_tom=corr_tom)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
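# save_output() collects scalar results into the table at args.table_path; its implementation is
# not part of this file. A plausible minimal sketch -- assuming it simply appends one row of
# keyword arguments to a CSV file named after the experiment -- is given below; the actual helper
# may use a different file layout.
import csv
import os


def save_output_sketch(table_path, name='ntk', **kwargs):
    """Append one row of results to a CSV table, writing the header on first use."""
    os.makedirs(table_path, exist_ok=True)
    fname = os.path.join(table_path, f'table_{name}.csv')
    new_file = not os.path.isfile(fname)
    with open(fname, 'a', newline='') as handle:
        writer = csv.DictWriter(handle, fieldnames=list(kwargs.keys()))
        if new_file:
            writer.writeheader()
        writer.writerow(kwargs)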
def evaluation_procedure(config):
    """Train model and evaluate eigenvalues with given configuration."""
    # Setup data
    augmentations = False
    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'], augmentations=augmentations)

    # Setup Network
    if config['model'] == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, 2048)),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(2048, 2048)),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(2048, 1024)),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(1024, 10))]))
    elif config['model'] == 'MLPsmall':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, 256)),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(256, 256)),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(256, 256)),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(256, 10))]))
    elif config['model'] == 'MLPsmallB':
        # BatchNorm1d on the flattened features, with unique keys so no layer is silently dropped.
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, 256)),
            ('relu0', torch.nn.ReLU()),
            ('bn0', torch.nn.BatchNorm1d(256)),
            ('linear1', torch.nn.Linear(256, 256)),
            ('relu1', torch.nn.ReLU()),
            ('bn1', torch.nn.BatchNorm1d(256)),
            ('linear2', torch.nn.Linear(256, 256)),
            ('relu2', torch.nn.ReLU()),
            ('bn2', torch.nn.BatchNorm1d(256)),
            ('linear3', torch.nn.Linear(256, 10))]))
    elif config['model'] == 'ResNet':
        net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10)
    elif config['model'] == 'L-MLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, 2048)),
            ('linear1', torch.nn.Linear(2048, 2048)),
            ('linear2', torch.nn.Linear(2048, 1024)),
            ('linear3', torch.nn.Linear(1024, 10))]))
    elif config['model'] == 'L-ResNet':
        net = ResNetLinear(BasicBlockLinear, [2, 2, 2, 2], num_classes=10)
    net.to(**config['setup'])
    net = torch.nn.DataParallel(net)
    net.eval()

    def initialize_net(net, init):
        for name, param in net.named_parameters():
            with torch.no_grad():
                if init == 'default':
                    pass
                elif init == 'zero':
                    param.zero_()
                elif init == 'low_bias':
                    if 'bias' in name:
                        param -= 20
                elif init == 'high_bias':
                    if 'bias' in name:
                        param += 20
                elif init == 'equal':
                    torch.nn.init.constant_(param, 0.001)
                elif init == 'variant_bias':
                    if 'bias' in name:
                        torch.nn.init.uniform_(param, -args.var, args.var)

    initialize_net(net, config['init'])

    # Optimizer and loss
    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=args.mom,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 250, 350], gamma=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Analyze model before training
    analyze_model(net, trainloader, testloader, loss_fn, config)

    # Train
    print('Starting training ...')
    dl.train(net, optimizer, scheduler, loss_fn, trainloader, config, dryrun=args.dryrun)

    # Analyze results
    acc_train, acc_test, loss_train, loss_trainw, grd_train, maxeig, mineig = analyze_model(
        net, trainloader, testloader, loss_fn, config)
    save_output(args.table_path, init=config['init'], var=args.var,
                acc_train=acc_train, acc_test=acc_test, loss_train=loss_train, loss_trainw=loss_trainw,
                grd_train=grd_train, maxeig=maxeig, mineig=mineig)
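# For context on the linear variants above ('L-MLP', 'L-ResNet'): without nonlinearities a stack
# of Linear layers collapses to a single affine map, so these models probe the loss landscape of
# an over-parameterized linear predictor. A quick check of this standard fact:
import torch


def _deep_linear_collapse_demo():
    torch.manual_seed(0)
    deep_linear = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 3))
    w0, b0 = deep_linear[0].weight, deep_linear[0].bias
    w1, b1 = deep_linear[1].weight, deep_linear[1].bias
    w2, b2 = deep_linear[2].weight, deep_linear[2].bias
    W = w2 @ w1 @ w0              # collapsed weight
    b = w2 @ (w1 @ b0 + b1) + b2  # collapsed bias
    x = torch.randn(5, 8)
    print(torch.allclose(deep_linear(x), x @ W.T + b, atol=1e-5))  # True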
def evaluation_procedure(config):
    """Train model and evaluate eigenvalues with given configuration."""
    # Setup data
    trainset = torchvision.datasets.FakeData(size=50_000, image_size=(3, 32, 32), num_classes=10,
                                             transform=transforms.ToTensor(), random_offset=0)
    cc = torch.cat([trainset[i][0].reshape(3, -1) for i in range(len(trainset))], dim=1)
    data_mean = torch.mean(cc, dim=1).tolist()
    data_std = torch.std(cc, dim=1).tolist()
    print(f'Data mean is {data_mean}, data std is {data_std}')

    if config['data_augmentation'] is None:
        padding = 0
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(data_mean, data_std)])
    elif config['data_augmentation'] == 'default':
        padding = 2
        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=padding, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(data_mean, data_std)])
    elif config['data_augmentation'] == 'alot':
        padding = 4
        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=padding, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(data_mean, data_std)])

    trainset = torchvision.datasets.FakeData(size=50_000, image_size=(3, 32, 32), num_classes=10,
                                             transform=transform, random_offset=0)
    testset = torchvision.datasets.FakeData(size=1_000, image_size=(3, 32, 32), num_classes=10,
                                            transform=transforms.ToTensor(), random_offset=267914296)  # fib42

    aug = max((padding * 2)**2 * 2, 1)
    aug_datapoints = len(trainset) * aug

    if config['extra_data'] == 'add_extra_data':
        padding = 0
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(data_mean, data_std)])
        trainset = torchvision.datasets.FakeData(size=aug_datapoints, image_size=(3, 32, 32), num_classes=10,
                                                 transform=transform, random_offset=0)
        config['epochs'] = max(config['epochs'] // aug, 20)
        print(f'Effective epochs reduced to {config["epochs"]}')
    elif config['extra_data'] == 'do_augmentation':
        pass

    num_workers = torch.get_num_threads() if torch.get_num_threads() > 0 else 0
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=config['batch_size'], shuffle=True,
                                              num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(testset, batch_size=config['batch_size'], shuffle=True,
                                             num_workers=num_workers)

    # Setup Network
    if config['model'] == 'ResNetWide1':
        net = ResNet18()
    elif config['model'] == 'ResNetWide2':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], num_classes=10, widen_factor=2)
    elif config['model'] == 'ResNetWide5':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], num_classes=10, widen_factor=5)
    elif config['model'] == 'ResNetWide7':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], num_classes=10, widen_factor=7)
    else:
        net = WideResNet(BasicBlock, [2, 2, 2, 2], num_classes=10, widen_factor=10)
    net.to(**config['setup'])
    net = torch.nn.DataParallel(net)
    net.eval()

    num_params = sum([p.numel() for p in net.parameters()])
    print(f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
          f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')
    aug_datapoints = len(trainloader.dataset) * max((padding * 2)**2 * 2, 1)
    print(f'Number of params: {num_params} - number of aug. data points: {aug_datapoints} '
          f'- ratio : {aug_datapoints / num_params * 100:.2f}%')

    # Optimizer and loss
    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9,
                                weight_decay=config['weight_decay'])
    if config['extra_data'] == 'add_extra_data':
        scheduling = [max(100 // aug, 5), max(250 // aug, 10), max(350 // aug, 15)]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=scheduling, gamma=0.1)
        print(f'New scheduling is {scheduling}.')
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 250, 350], gamma=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()
    full_batch = config['full_batch']

    print('Start training ...')
    dl.train(net, optimizer, scheduler, loss_fn, trainloader, config, path=None, dryrun=args.dryrun)

    if config['save_models']:
        torch.save(net.state_dict(), 'models/' + config['model'] + '_random_labels_' +
                   str(config['weight_decay']) + str(config['data_augmentation']) + '.pth')

    print(f'Accuracy of the network on training images: {(100 * dl.get_accuracy(net, trainloader, config))} %')
    print(f'Accuracy of the network on test images: {(100 * dl.get_accuracy(net, testloader, config))} %')
    print(f'Loss in training is {dl.compute_loss(net, loss_fn, trainloader, config):.12f}')
    print(f'Loss in testing is {dl.compute_loss(net, loss_fn, testloader, config):.12f}')

    grd_train = dl.gradient_norm(trainloader, net, loss_fn, config['setup']['device'], config['weight_decay'])
    grd_test = dl.gradient_norm(testloader, net, loss_fn, config['setup']['device'], config['weight_decay'])
    print(f'Gradient norm in training is {grd_train:.12f}')
    print(f'Gradient norm in testing is {grd_test:.12f}')
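# For reference, the augmentation count above is aug = max((padding * 2)**2 * 2, 1): with the
# 'default' setting (padding=2) this is 4**2 * 2 = 32 augmented variants per image, with 'alot'
# (padding=4) it is 8**2 * 2 = 128, and without augmentation it stays at 1.
#
# dl.gradient_norm() is a repo helper that is not shown here. A minimal sketch of such a
# full-dataset gradient norm -- assuming the weight decay enters as an L2 penalty
# (wd/2)*||theta||^2, which may differ from the repo's definition -- could look like this:
import torch


def gradient_norm_sketch(dataloader, net, loss_fn, device, weight_decay):
    """L2 norm of the full-dataset gradient of the weight-decay-regularized training loss."""
    net.eval()
    params = [p for p in net.parameters() if p.requires_grad]
    grad_accum = [torch.zeros_like(p) for p in params]
    num_samples = 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        loss = loss_fn(net(inputs), targets) * inputs.shape[0]  # undo the batch-mean reduction
        grads = torch.autograd.grad(loss, params)
        for acc, g in zip(grad_accum, grads):
            acc += g
        num_samples += inputs.shape[0]
    total = sum(((acc / num_samples + weight_decay * p.detach()) ** 2).sum()
                for acc, p in zip(grad_accum, params))
    return torch.sqrt(total).item()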
def main():
    """Check ntks in a single call."""
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'], augmentations=False, shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16, widen_factor=config['width'], dropout_rate=0.0, num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(config['width'], config['width'])),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(config['width'], config['width'])),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, config['width'])),
            ('relu0', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10, width_mult=config['width'], round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'], 4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(OrderedDict([
            ('conv0', torch.nn.Conv2d(3, 1 * config['width'], kernel_size=3, padding=1)),
            ('relu0', torch.nn.ReLU()),
            # ('pool0', torch.nn.MaxPool2d(3)),
            ('conv1', torch.nn.Conv2d(1 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu1', torch.nn.ReLU()),
            # ('pool1', torch.nn.MaxPool2d(3)),
            ('conv2', torch.nn.Conv2d(2 * config['width'], 2 * config['width'], kernel_size=3, padding=1)),
            ('relu2', torch.nn.ReLU()),
            # ('pool2', torch.nn.MaxPool2d(3)),
            ('conv3', torch.nn.Conv2d(2 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu3', torch.nn.ReLU()),
            ('pool3', torch.nn.MaxPool2d(3)),
            ('conv4', torch.nn.Conv2d(4 * config['width'], 4 * config['width'], kernel_size=3, padding=1)),
            ('relu4', torch.nn.ReLU()),
            ('pool4', torch.nn.MaxPool2d(3)),
            ('flatten', torch.nn.Flatten()),
            ('linear', torch.nn.Linear(36 * config['width'], 10))]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    num_params = sum([p.numel() for p in net.parameters()])
    print(f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
          f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    def batch_feature_correlations(dataloader, device=torch.device('cpu')):
        net.eval()
        net.to(device)
        dist_maps = list()
        cosine_maps = list()
        prod_maps = list()
        hooks = []

        def batch_wise_feature_correlation(self, input, output):
            feat_vec = input[0].detach().view(dataloader.batch_size, -1)
            dist_maps.append(torch.cdist(feat_vec, feat_vec, 2).detach().cpu().numpy())

            cosine_map = np.empty((dataloader.batch_size, dataloader.batch_size))
            prod_map = np.empty((dataloader.batch_size, dataloader.batch_size))
            for row in range(dataloader.batch_size):
                cosine_map[row, :] = torch.nn.functional.cosine_similarity(
                    feat_vec[row:row + 1, :], feat_vec, dim=1, eps=1e-8).detach().cpu().numpy()
                prod_map[row, :] = torch.mean(feat_vec[row:row + 1, :] * feat_vec, dim=1).detach().cpu().numpy()
            cosine_maps.append(cosine_map)
            prod_maps.append(prod_map)

        if isinstance(net, torch.nn.DataParallel):
            hooks.append(net.module.linear.register_forward_hook(batch_wise_feature_correlation))
        else:
            if args.net in ['MLP', 'TwoLP']:
                hooks.append(net.linear3.register_forward_hook(batch_wise_feature_correlation))
            elif args.net in ['VGG', 'MobileNetV2']:
                hooks.append(net.classifier.register_forward_hook(batch_wise_feature_correlation))
            else:
                hooks.append(net.linear.register_forward_hook(batch_wise_feature_correlation))

        for inputs, _ in dataloader:
            outputs = net(inputs.to(device))
            if args.dryrun:
                break

        for hook in hooks:
            hook.remove()

        return dist_maps, cosine_maps, prod_maps

    # Check feature maps before training
    pdist_init, cos_init, prod_init = batch_feature_correlations(trainloader)
    pdist_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_init])
    cos_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_init])
    prod_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_init])
    print(f'The total norm of feature distances before training is {pdist_init_norm:.2f}')
    print(f'The total norm of feature cosine similarity before training is {cos_init_norm:.2f}')
    print(f'The total norm of feature inner product before training is {prod_init_norm:.2f}')
    save_plot(pdist_init, trainloader, name='pdist_before_training')
    save_plot(cos_init, trainloader, name='cosine_before_training')
    save_plot(prod_init, trainloader, name='prod_before_training')

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    try:
        net.load_state_dict(torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '.pth',
                                       map_location=device))
        print('Net loaded from file.')
    except Exception as e:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '.pth'
        dl.train(net, optimizer, scheduler, loss_fn, trainloader, config, path=None, dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Net saved to file.')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    # Check feature maps after training
    pdist_after, cos_after, prod_after = batch_feature_correlations(trainloader)
    pdist_after_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_after])
    cos_after_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_after])
    prod_after_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_after])
    print(f'The total norm of feature distances after training is {pdist_after_norm:.2f}')
    print(f'The total norm of feature cosine similarity after training is {cos_after_norm:.2f}')
    print(f'The total norm of feature inner product after training is {prod_after_norm:.2f}')
    save_plot(pdist_after, trainloader, name='pdist_after_training')
    save_plot(cos_after, trainloader, name='cosine_after_training')
    save_plot(prod_after, trainloader, name='prod_after_training')

    # Check feature map differences (normalized by the norm before training)
    pdist_ndiff = [np.abs(co1 - co2) / pdist_init_norm for co1, co2 in zip(pdist_init, pdist_after)]
    cos_ndiff = [np.abs(co1 - co2) / cos_init_norm for co1, co2 in zip(cos_init, cos_after)]
    prod_ndiff = [np.abs(co1 - co2) / prod_init_norm for co1, co2 in zip(prod_init, prod_after)]
    pdist_ndiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_ndiff])
    cos_ndiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_ndiff])
    prod_ndiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_ndiff])
    print(f'The total norm normalized diff of feature distances after training is {pdist_ndiff_norm:.2f}')
    print(f'The total norm normalized diff of feature cosine similarity after training is {cos_ndiff_norm:.2f}')
    print(f'The total norm normalized diff of feature inner product after training is {prod_ndiff_norm:.2f}')
    save_plot(pdist_ndiff, trainloader, name='pdist_ndiff')
    save_plot(cos_ndiff, trainloader, name='cosine_ndiff')
    save_plot(prod_ndiff, trainloader, name='prod_ndiff')

    # Check feature map differences (relative, entry-wise)
    pdist_rdiff = [np.abs(co1 - co2) / (np.abs(co1) + 1e-6) for co1, co2 in zip(pdist_init, pdist_after)]
    cos_rdiff = [np.abs(co1 - co2) / (np.abs(co1) + 1e-6) for co1, co2 in zip(cos_init, cos_after)]
    prod_rdiff = [np.abs(co1 - co2) / (np.abs(co1) + 1e-6) for co1, co2 in zip(prod_init, prod_after)]
    pdist_rdiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in pdist_rdiff])
    cos_rdiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_rdiff])
    prod_rdiff_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in prod_rdiff])
    print(f'The total norm relative diff of feature distances after training is {pdist_rdiff_norm:.2f}')
    print(f'The total norm relative diff of feature cosine similarity after training is {cos_rdiff_norm:.2f}')
    print(f'The total norm relative diff of feature inner product after training is {prod_rdiff_norm:.2f}')
    save_plot(pdist_rdiff, trainloader, name='pdist_rdiff')
    save_plot(cos_rdiff, trainloader, name='cosine_rdiff')
    save_plot(prod_rdiff, trainloader, name='prod_rdiff')

    save_output(args.table_path, width=config['width'], num_params=num_params,
                pdist_init_norm=pdist_init_norm, pdist_after_norm=pdist_after_norm,
                pdist_ndiff_norm=pdist_ndiff_norm, pdist_rdiff_norm=pdist_rdiff_norm,
                cos_init_norm=cos_init_norm, cos_after_norm=cos_after_norm,
                cos_ndiff_norm=cos_ndiff_norm, cos_rdiff_norm=cos_rdiff_norm,
                prod_init_norm=prod_init_norm, prod_after_norm=prod_after_norm,
                prod_ndiff_norm=prod_ndiff_norm, prod_rdiff_norm=prod_rdiff_norm)

    # Save raw data
    raw_pkg = dict(pdist_init=pdist_init, cos_init=cos_init, prod_init=prod_init,
                   pdist_after=pdist_after, cos_after=cos_after, prod_after=prod_after,
                   pdist_ndiff=pdist_ndiff, cos_ndiff=cos_ndiff, prod_ndiff=prod_ndiff,
                   pdist_rdiff=pdist_rdiff, cos_rdiff=cos_rdiff, prod_rdiff=prod_rdiff)
    path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_rawmaps.pth'
    torch.save(raw_pkg, path)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
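# The raw correlation maps saved above can be reloaded later for inspection, e.g.:
#
#   raw_pkg = torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_rawmaps.pth')
#   pdist_init = raw_pkg['pdist_init']  # list of per-batch numpy arrays, exactly as stored above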