dl_dw = torch.autograd.grad(loss, model.parameters(), create_graph=True) d2l_dw2 = torch.autograd.grad(dl_dw, model.parameters(), ones, create_graph=True) sum_dl2_dw2 = sum([torch.abs(g).sum() if args.abs else g.sum() for g in d2l_dw2]) loss = loss + args.sf * sum_dl2_dw2 loss.backward() optimiser.step() print(loss) loss_lst.append(loss) Hess = FullHessian(crit='MSELoss', loader=training_loader, device=device, model=model, double=False, num_classes=10, hessian_type='Hessian', init_poly_deg=64, poly_deg=128, spectrum_margin=0.05, poly_points=1024, SSI_iters=128 ) lmin, lmax = Hess.compute_lb_ub() lmax_lst.append(lmax) lmin_lst.append(lmin) if args.num_epochs > args.scaling_epoch: f=osp.join(ckpt_dir , 'scaling_lmax_lmin.npz') f_loss = osp.join(images_dir,'scaling_loss.png') f_lmax = osp.join(images_dir,'scaling_lmax.png')
map_location=lambda storage, loc: storage) if 'model' in state_dict.keys(): state_dict = state_dict['model'] model.load_state_dict(state_dict, strict=True) model = model.to(device) model = model.eval() logger.info('weights loaded from path: {}'.format(weights_path)) logger.info('for epoch: {}'.format(epoch_number)) Hess = FullHessian(crit='CrossEntropyLoss', loader=loader, device=device, model=model, num_classes=C, hessian_type='Hessian', init_poly_deg=64, poly_deg=128, spectrum_margin=0.05, poly_points=1024, SSI_iters=128) lmin, lmax = Hess.compute_lb_ub() lmax_list.append((epoch_number, lmax)) lmin_list.append((epoch_number, lmin)) logger.info( 'computation finished for epoch number: {}'.format(epoch_number)) print(lmax) lmax_path = osp.join(ckpt_dir, 'full_lambda_max.pkl')
examples_per_class=args.examples_per_class, epc_seed=row.epc_seed, root=osp.join(args.dataset_root, args.dataset), train=True, transform=transform, download=True) loader = DataLoader(dataset=subset_dataset, drop_last=False, batch_size=args.batch_size) logger.info('Starting G decomposition...') res = FullHessian( crit='CrossEntropyLoss', loader=loader, device=device, model=model, num_classes=row.num_classes, hessian_type='G', ).compute_G_decomp() logger.info('G decomposition finished...') logger.info('Starting G eigenspectrum computation') Hess = FullHessian( crit='CrossEntropyLoss', loader=loader, device=device, model=model, num_classes=row.num_classes, hessian_type='G', init_poly_deg= 64, # number of iterations used to compute maximal/minimal eigenvalue
for i in range(args.num_models): model_weights_path = osp.join(ckpt_dir, f'model_weights_epochs_{i}.pth') print(f'Processing model number {i}') state_dict = torch.load(model_weights_path, device) model.load_state_dict(state_dict, strict=True) model = model.to(device) C = 10 Hess = FullHessian( crit='CrossEntropyLoss', loader=loader, device=device, model=model, num_classes=C, hessian_type='Hessian', init_poly_deg= 64, # number of iterations used to compute maximal/minimal eigenvalue poly_deg= 128, # the higher the parameter the better the approximation spectrum_margin=0.05, poly_points=1024, # number of points in spectrum approximation SSI_iters=128, # iterations of subspace iterations ) # Exact outliers computation using subspace iteration eigvecs_cur, eigvals_cur, _ = Hess.SubspaceIteration() # Flatten out eigvecs_cur as it is a list top_subspace = torch.zeros(0, device=device) for _ in range(len(eigvecs_cur)): b = torch.zeros(0, device=device)
download=True ) loader = DataLoader(dataset=subset_dataset, drop_last=False, batch_size=args.batch_size) C = row.num_classes logger.info('Starting hessiandecomposition...') Hess = FullHessian(crit='CrossEntropyLoss', loader=loader, device=device, model=model, num_classes=C, hessian_type='Hessian', init_poly_deg=64, poly_deg=128, spectrum_margin=0.05, poly_points=1024, SSI_iters=128 ) Hess_eigval, \ Hess_eigval_density = Hess.LanczosLoop(denormalize=True) # the higher the parameter the better the approximation H = FullHessian(crit='CrossEntropyLoss', loader=loader, device=device, model=model, num_classes=C, hessian_type='H',
epc_seed=row.epc_seed, root=osp.join(args.dataset_root, args.dataset), train=True, transform=transform, download=True ) loader = DataLoader(dataset=subset_dataset, drop_last=False, batch_size=args.batch_size) C = row.num_classes logger.info('Starting G decomposition...') res = FullHessian(crit='CrossEntropyLoss', loader=loader, device=device, model=model, num_classes=C, hessian_type='G', ).compute_G_decomp() logger.info('G decomposed') logger.info('Starting TSNE...') tsne_embedded = TSNE(n_components=2, metric='precomputed', perplexity=C).fit_transform(res['dist']) logger.info('TSNE finished...') # t-SNE X delta_c_X = tsne_embedded[:C,0] delta_ccp_X = tsne_embedded[C:,0] # t-SNE Y delta_c_Y = tsne_embedded[:C,1]
def train(model, optimizer, scheduler, dataloaders, criterion, device,
          num_epochs=100, args=None, dataset_sizes=None,
          images_dir=None, ckpt_dir=None):
    """Train ``model`` and record full/layerwise Hessian eigenspectra each epoch.

    Per epoch: one SGD pass over ``dataloaders['train']``; then a fixed subset
    loader is rebuilt and used to approximate the eigenvalue spectrum of the
    full Hessian (accumulated into the returned array) and of each valid
    layer's Hessian (saved to per-layer ``.npz`` files); finally train/test
    loss and accuracy are evaluated and the curve plots are refreshed.

    Args:
        model: network to train; trained in place.
        optimizer: optimizer, stepped once per batch.
        scheduler: LR scheduler, stepped once per epoch.
        dataloaders: dict with ``'train'`` and ``'test'`` DataLoaders.
        criterion: loss function mapping (output, target) -> scalar loss.
        device: torch device batches are moved to.
        num_epochs: number of training epochs.
        args: namespace providing ``dataset``, ``dataset_root``,
            ``examples_per_class`` and ``batch_size`` (required in practice).
        dataset_sizes: per-phase dataset sizes passed to ``evaluate_model``;
            defaults to ``{'train': 5e4, 'test': 1e4}``.
        images_dir: output directory for ``loss.png`` / ``acc.png`` (required).
        ckpt_dir: output directory for eigenspectrum files (required).

    Returns:
        ``np.ndarray`` with ``2 * num_epochs`` rows: for each epoch, one row of
        eigenvalues followed by one row of eigenvalue densities. Also saved to
        ``ckpt_dir/training_eigenspectrum_full.npy``.

    Raises:
        ValueError: if ``images_dir`` or ``ckpt_dir`` is None.
        Exception: if ``args.dataset`` is not a supported dataset name.
    """
    # Explicit validation instead of ``assert`` (asserts vanish under -O).
    if images_dir is None:
        raise ValueError('images_dir must be provided')
    if ckpt_dir is None:
        raise ValueError('ckpt_dir must be provided')
    if dataset_sizes is None:
        # Default lives here, not in the signature, to avoid a shared
        # mutable default argument.
        dataset_sizes = {'train': 5e4, 'test': 1e4}

    logger = logging.getLogger('train')
    loss_list = {'train': [], 'test': []}
    acc_list = {'train': [], 'test': []}
    loss_image_path = osp.join(images_dir, 'loss.png')
    acc_image_path = osp.join(images_dir, 'acc.png')
    full_eigenspectrums_path = osp.join(ckpt_dir,
                                        'training_eigenspectrum_full.npy')

    model.train()
    full_eigenspectrums = []
    C = config.num_classes
    valid_layers = get_valid_layers(model)

    for epoch in range(num_epochs):
        logger.info('epoch: %d', epoch)

        # One epoch of SGD over the training loader.
        with torch.enable_grad():
            for batch, truth in dataloaders['train']:
                batch = batch.to(device)
                truth = truth.to(device)
                optimizer.zero_grad()
                output = model(batch)
                loss = criterion(output, truth)
                loss.backward()
                optimizer.step()
        scheduler.step()

        # Updates finished for this epoch; rebuild the evaluation subset
        # loader used for all spectrum computations below.
        loader = _make_spectrum_loader(args)

        # Full-network Hessian eigenspectrum (Lanczos approximation).
        Hess = FullHessian(crit='CrossEntropyLoss',
                           loader=loader,
                           device=device,
                           model=model,
                           num_classes=C,
                           hessian_type='Hessian',
                           init_poly_deg=64,
                           poly_deg=128,
                           spectrum_margin=0.05,
                           poly_points=1024,
                           SSI_iters=128)
        Hess_eigval, \
            Hess_eigval_density = Hess.LanczosLoop(denormalize=True)
        # Two rows per epoch: eigenvalues, then their densities.
        full_eigenspectrums.append(Hess_eigval)
        full_eigenspectrums.append(Hess_eigval_density)

        # Per-layer Hessian eigenspectra, one .npz file per (epoch, layer).
        for layer_name, _ in model.named_parameters():
            if layer_name not in valid_layers:
                continue
            layer_hess = LayerHessian(crit='CrossEntropyLoss',
                                      loader=loader,
                                      device=device,
                                      model=model,
                                      num_classes=C,
                                      layer_name=layer_name,
                                      hessian_type='Hessian',
                                      init_poly_deg=64,
                                      poly_deg=128,
                                      spectrum_margin=0.05,
                                      poly_points=1024,
                                      SSI_iters=128)
            layer_eigval, \
                layer_eigval_density = layer_hess.LanczosLoop(denormalize=True)
            layerwise_eigenspectrums_path = osp.join(
                ckpt_dir,
                'training_eigenspectrums_epoch_{}_layer_{}.npz'.format(
                    epoch, layer_name))
            np.savez(layerwise_eigenspectrums_path,
                     eigval=layer_eigval,
                     eigval_density=layer_eigval_density)

        # Evaluate both phases and refresh the curve plots once per epoch.
        for phase in ['train', 'test']:
            stats = evaluate_model(model, criterion, dataloaders[phase],
                                   device, dataset_sizes[phase])
            loss_list[phase].append(stats['loss'])
            acc_list[phase].append(stats['acc'])
            logger.info('%s:', phase)
            logger.info('\tloss:%s', stats['loss'])
            logger.info('\tacc :%s', stats['acc'])
            if phase == 'test':
                _plot_curves(loss_list, acc_list,
                             loss_image_path, acc_image_path)

    full_eigenspectrums = np.array(full_eigenspectrums)
    # Sanity check: exactly two rows (values + densities) were appended
    # per epoch.
    assert full_eigenspectrums.shape[0] % 2 == 0
    assert full_eigenspectrums.shape[0] // 2 == num_epochs
    np.save(full_eigenspectrums_path, full_eigenspectrums)
    return full_eigenspectrums


def _make_spectrum_loader(args):
    """Build the fixed subset DataLoader used for Hessian spectrum estimation.

    Pads, tensorizes and normalizes images with the per-dataset mean/std,
    then draws ``args.examples_per_class`` examples per class with the seed
    from ``config``.

    Raises:
        Exception: if ``args.dataset`` is not a supported dataset name.
    """
    mean, std = get_mean_std(args.dataset)
    pad = int((config.padded_im_size - config.im_size) / 2)
    transform = transforms.Compose([
        transforms.Pad(pad),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    if args.dataset in ['MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100']:
        full_dataset = getattr(datasets, args.dataset)
        subset_dataset = get_subset_dataset(
            full_dataset=full_dataset,
            examples_per_class=args.examples_per_class,
            epc_seed=config.epc_seed,
            root=osp.join(args.dataset_root, args.dataset),
            train=True,
            transform=transform,
            download=True)
    elif args.dataset in ['STL10', 'SVHN']:
        # These torchvision datasets take ``split`` instead of ``train``.
        full_dataset = getattr(datasets, args.dataset)
        subset_dataset = get_subset_dataset(
            full_dataset=full_dataset,
            examples_per_class=args.examples_per_class,
            epc_seed=config.epc_seed,
            root=osp.join(args.dataset_root, args.dataset),
            split='train',
            transform=transform,
            download=True)
    else:
        raise Exception('Unknown dataset: {}'.format(args.dataset))
    return data.DataLoader(dataset=subset_dataset,
                           drop_last=False,
                           batch_size=args.batch_size)


def _plot_curves(loss_list, acc_list, loss_image_path, acc_image_path):
    """Overwrite the loss and accuracy curve images with the latest history."""
    plt.clf()
    plt.plot(loss_list['test'], label='test_loss')
    plt.plot(loss_list['train'], label='train_loss')
    plt.legend()
    plt.savefig(loss_image_path)
    plt.clf()
    plt.plot(acc_list['test'], label='test_acc')
    plt.plot(acc_list['train'], label='train_acc')
    plt.legend()
    plt.savefig(acc_image_path)
    plt.clf()