def inverse_direction_fun_vec(x):
    """Return the gradient of the fixed-point loss w.r.t. the flat vector *x*.

    Works on a detached leaf copy of ``x`` so the caller's tensor and its
    autograd graph are left untouched.  The vector is unflattened with the
    (externally defined) ``cutoffs`` layout, fed to the (externally defined)
    ``loss_function``, and dL/dx is read back from the leaf's ``.grad``.
    """
    # Fresh leaf tensor: gradient accumulates here, not on the caller's x.
    leaf = x.clone().detach().requires_grad_()
    # enable_grad guarantees a graph is built even under an outer no_grad().
    with torch.enable_grad():
        loss = loss_function(DEQFunc2d.vec2list(leaf, cutoffs))
    loss.backward()
    return leaf.grad
def adj_broyden_convergence(opa_freq, n_runs=1, dataset='imagenet', model_size='LARGE'):
    """Compare forward-pass convergence of Adjoint Broyden vs plain Broyden.

    For each of ``n_runs`` mini-batches, runs both solvers on the same
    fixed-point problem and records per-solver diagnostics plus
    cross-solver agreement statistics between the two fixed-point estimates.

    Args:
        opa_freq: frequency of the "OPA" inverse-direction updates for the
            Adjoint Broyden solver; ``None`` disables them.
        n_runs: number of mini-batches to evaluate.
        dataset: ``'imagenet'`` or anything else (treated as CIFAR-10).
        model_size: passed through to ``setup_model``.

    Returns:
        dict with keys ``correl``, ``ratio``, ``diff``, ``rdiff`` (agreement
        between the two solvers' results) and, per solver,
        ``<solver>_diff``, ``<solver>_rdiff``, ``<solver>_lstep``.
    """
    # setup
    model = setup_model(opa_freq is not None, dataset, model_size)
    # Build the training pipeline (only used as a source of input batches).
    if dataset == 'imagenet':
        traindir = os.path.join(config.DATASET.ROOT + '/images', config.DATASET.TRAIN_SET)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.ImageFolder(traindir, transform_train)
    else:
        # CIFAR-10 branch.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        augment_list = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ] if config.DATASET.AUGMENT else []
        transform_train = transforms.Compose(augment_list + [
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.CIFAR10(root=f'{config.DATASET.ROOT}', train=True,
                                         download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=10,
        pin_memory=True,
        worker_init_fn=partial(worker_init_fn, seed=42),
    )
    iter_loader = iter(train_loader)
    solvers = {
        'adj_broyden': adj_broyden,
        'broyden': broyden,
    }
    convergence_results = {
        'correl': [],
        'ratio': [],
        'diff': [],
        'rdiff': [],
    }
    # Per-solver diagnostics collected alongside the cross-solver stats.
    for solver_name in solvers.keys():
        convergence_results[f'{solver_name}_rdiff'] = []
        convergence_results[f'{solver_name}_diff'] = []
        convergence_results[f'{solver_name}_lstep'] = []
    for i_run in range(n_runs):
        input, target = next(iter_loader)
        target = target.cuda(non_blocking=True)
        solvers_results = {}
        x_list, z_list = model.feature_extraction(input.cuda())
        # Reset dropout masks / weight-norm copies so both solvers see the
        # exact same function (presumably — TODO confirm _reset/_copy semantics).
        model.fullstage._reset(z_list)
        model.fullstage_copy._copy(model.fullstage)
        for solver_name, solver in solvers.items():
            # fixed point solving
            # Re-clone per solver so each one starts from an identical state.
            x_list = [x.clone().detach().requires_grad_() for x in x_list]
            z_list = [z.clone() for z in z_list]
            cutoffs = [(elem.size(1), elem.size(2), elem.size(3)) for elem in z_list]
            args = (27, int(1e9), None)
            nelem = sum([elem.nelement() for elem in z_list])
            # Tolerance scales with the problem size.
            eps = 1e-5 * np.sqrt(nelem)
            z1_est = DEQFunc2d.list2vec(z_list)
            g = lambda x: DEQFunc2d.g(model.fullstage_copy, x, x_list, cutoffs, *args)
            model.copy_modules()
            if solver_name == 'adj_broyden':
                loss_function = lambda y_est: model.get_fixed_point_loss(
                    y_est, target)

                def inverse_direction_fun_vec(x):
                    # dL/dx on a detached leaf copy; enable_grad guarantees a
                    # graph even if a caller wrapped us in no_grad().
                    x_temp = x.clone().detach().requires_grad_()
                    with torch.enable_grad():
                        x_list = DEQFunc2d.vec2list(x_temp, cutoffs)
                        loss = loss_function(x_list)
                    loss.backward()
                    dl_dx = x_temp.grad
                    return dl_dx
                inverse_direction_fun = inverse_direction_fun_vec
                # OPA kwargs only have an effect when opa_freq is set.
                add_kwargs = dict(
                    inverse_direction_freq=opa_freq,
                    inverse_direction_fun=inverse_direction_fun if opa_freq is not None else None,
                )
            else:
                add_kwargs = {}
            result_info = solver(
                g,
                z1_est,
                threshold=config.MODEL.F_THRES,
                eps=eps,
                name="forward",
                **add_kwargs,
            )
            z1_est = result_info['result']
            convergence_results[f'{solver_name}_diff'].append(
                result_info['diff'])
            lowest_step = result_info['lowest_step']
            convergence_results[f'{solver_name}_lstep'].append(lowest_step)
            convergence_results[f'{solver_name}_rdiff'].append(
                result_info['new_trace'][lowest_step])
            solvers_results[solver_name] = z1_est.clone().detach()
        # Cross-solver agreement: cosine similarity, norm ratio, abs/rel diff.
        z1_adj_br = solvers_results['adj_broyden']
        z1_br = solvers_results['broyden']
        correl = torch.dot(
            torch.flatten(z1_adj_br),
            torch.flatten(z1_br),
        )
        scaling = torch.norm(z1_adj_br) * torch.norm(z1_br)
        convergence_results['correl'].append((correl / scaling).item())
        convergence_results['ratio'].append(
            (torch.norm(z1_br) / torch.norm(z1_adj_br)).item())
        convergence_results['diff'].append(
            torch.norm(z1_br - z1_adj_br).item())
        convergence_results['rdiff'].append(
            (torch.norm(z1_br - z1_adj_br) / torch.norm(z1_br)).item())
    return convergence_results
def eval_ratio_fb_classifier(
        n_epochs=100,
        pretrained=False,
        n_gpus=1,
        dataset='imagenet',
        model_size='SMALL',
        shine=False,
        fpn=False,
        gradient_correl=False,
        gradient_ratio=False,
        adjoint_broyden=False,
        opa=False,
        refine=False,
        fallback=False,
        save_at=None,
        restart_from=None,
        use_group_norm=False,
        seed=0,
        n_samples=1,
):
    """Profile the backward/forward wall-clock ratio of the DEQ full stage.

    Builds the classifier from the global ``config``, then for ``n_samples``
    mini-batches times one forward pass (under ``no_grad``) and one backward
    pass through the full stage, under the CUDA autograd profiler.

    Side effects: sets the default tensor type to CUDA, configures cudnn from
    ``config``, seeds Python/NumPy/torch RNGs, may set the multiprocessing
    start method, and writes ``trace.json`` (chrome trace) in the CWD.

    Returns:
        list of ``time_backward / time_forward`` ratios, one per sample.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    args = update_config_w_args(
        n_epochs=n_epochs,
        pretrained=pretrained,
        n_gpus=n_gpus,
        dataset=dataset,
        model_size=model_size,
        use_group_norm=use_group_norm,
    )
    print(colored("Setting default tensor type to cuda.FloatTensor", "cyan"))
    # 'spawn' may already be set by a previous call; that raises RuntimeError.
    try:
        torch.multiprocessing.set_start_method('spawn')
    except RuntimeError:
        pass
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    logger, final_output_dir, tb_log_dir = create_logger(
        config,
        args.cfg,
        'train',
        shine=shine,
        fpn=fpn,
        seed=seed,
        use_group_norm=use_group_norm,
        adjoint_broyden=adjoint_broyden,
        opa=opa,
        refine=refine,
        fallback=fallback,
    )
    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))
    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED
    # Model class is resolved dynamically from the config name.
    model = eval('models.' + config.MODEL.NAME + '.get_cls_net')(
        config,
        shine=shine,
        fpn=fpn,
        gradient_correl=gradient_correl,
        gradient_ratio=gradient_ratio,
        adjoint_broyden=adjoint_broyden,
        refine=refine,
        fallback=fallback,
        opa=opa,
    ).cuda()
    # dump_input = torch.rand(config.TRAIN.BATCH_SIZE_PER_GPU, 3, config.MODEL.IMAGE_SIZE[1], config.MODEL.IMAGE_SIZE[0]).cuda()
    # logger.info(get_model_summary(model, dump_input))
    if config.TRAIN.MODEL_FILE:
        model.load_state_dict(torch.load(config.TRAIN.MODEL_FILE))
        logger.info(
            colored('=> loading model from {}'.format(config.TRAIN.MODEL_FILE),
                    'red'))
    # copy model file
    models_dst_dir = os.path.join(final_output_dir, 'models')
    if os.path.exists(models_dst_dir):
        shutil.rmtree(models_dst_dir)
    gpus = list(config.GPUS)
    print("Finished constructing model!")
    # define loss function (criterion) and optimizer
    # NOTE(review): criterion is built but never used in this profiling run.
    criterion = nn.CrossEntropyLoss().cuda()
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        if restart_from is None:
            resume_file = 'checkpoint.pth.tar'
        else:
            resume_file = f'checkpoint_{restart_from}.pth.tar'
        model_state_file = os.path.join(final_output_dir, resume_file)
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_perf = checkpoint['perf']
            model.load_state_dict(checkpoint['state_dict'])
    # Data loading code
    dataset_name = config.DATASET.DATASET
    if dataset_name == 'imagenet':
        traindir = os.path.join(config.DATASET.ROOT + '/images', config.DATASET.TRAIN_SET)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.ImageFolder(traindir, transform_train)
    else:
        assert dataset_name == "cifar10", "Only CIFAR-10 is supported at this phase"
        classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog',
                   'horse', 'ship', 'truck')  # For reference
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        augment_list = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ] if config.DATASET.AUGMENT else []
        transform_train = transforms.Compose(augment_list + [
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.CIFAR10(root=f'{config.DATASET.ROOT}', train=True,
                                         download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=True,
        num_workers=10,
        pin_memory=True,
        worker_init_fn=partial(worker_init_fn, seed=seed),
    )
    iter_loader = iter(train_loader)
    ratios = []
    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        for i_sample in range(n_samples):
            with profiler.record_function("Data loading"):
                input, target = next(iter_loader)
                input = input.cuda(non_blocking=False)
            with profiler.record_function("Feature extraction PASS"):
                x_list, z_list = model.feature_extraction(input)
            # For variational dropout mask resetting and weight normalization re-computations
            model.fullstage._reset(z_list)
            model.fullstage_copy._copy(model.fullstage)
            x_list = [x.clone().detach().requires_grad_() for x in x_list]
            z_list = [z.clone().detach().requires_grad_() for z in z_list]
            with profiler.record_function("Forward PASS"):
                # Timed forward under no_grad; synchronize so CUDA work is
                # actually finished before reading the clock.
                start_forward = time.time()
                with torch.no_grad():
                    model.fullstage_copy(z_list, x_list)
                torch.cuda.synchronize()
                end_forward = time.time()
                time_forward = end_forward - start_forward
            with profiler.record_function("Forward PASS enable grad"):
                # Second forward with autograd enabled to build the graph
                # the timed backward pass will traverse.
                z_list_new = model.fullstage_copy(z_list, x_list)
                z = DEQFunc2d.list2vec(z_list_new)
            with profiler.record_function("Backward PASS"):
                start_backward = time.time()
                # Backward with z itself as the incoming gradient (vjp probe).
                z.backward(z)
                torch.cuda.synchronize()
                end_backward = time.time()
                time_backward = end_backward - start_backward
            ratios.append(time_backward / time_forward)
    print(prof.key_averages().table(sort_by="cuda_time_total"))
    prof.export_chrome_trace("trace.json")
    return ratios
def fallback_ratio(n_runs=1, dataset='imagenet', model_size='SMALL'):
    """Count how often the SHINE fallback criterion would trigger.

    For each of ``n_runs`` mini-batches, solves the forward fixed point with
    Broyden, builds the SHINE approximate inverse direction from the stored
    low-rank factors (``Us``/``VTs``), and counts the per-sample cases where
    its norm exceeds 1.8x the true incoming gradient's norm — the fallback
    criterion threshold.

    Args:
        n_runs: number of mini-batches to evaluate.
        dataset: ``'imagenet'`` or anything else (treated as CIFAR-10).
        model_size: passed through to ``setup_model``.

    Returns:
        int: total number of samples (across all runs) where the fallback
        criterion fires.
    """
    # setup
    model = setup_model(False, dataset, model_size)
    if dataset == 'imagenet':
        traindir = os.path.join(config.DATASET.ROOT + '/images', config.DATASET.TRAIN_SET)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.ImageFolder(traindir, transform_train)
    else:
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        augment_list = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ] if config.DATASET.AUGMENT else []
        transform_train = transforms.Compose(augment_list + [
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.CIFAR10(root=f'{config.DATASET.ROOT}', train=True,
                                         download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=10,
        pin_memory=True,
        worker_init_fn=partial(worker_init_fn, seed=42),
    )
    fallback_uses = 0
    iter_loader = iter(train_loader)
    for i_run in range(n_runs):
        input, target = next(iter_loader)
        target = target.cuda(non_blocking=True)
        with torch.no_grad():
            x_list, z_list = model.feature_extraction(input.cuda())
            model.fullstage._reset(z_list)
            model.fullstage_copy._copy(model.fullstage)
            # fixed point solving
            x_list = [x.clone().detach().requires_grad_() for x in x_list]
            cutoffs = [(elem.size(1), elem.size(2), elem.size(3))
                       for elem in z_list]
            args = (27, int(1e9), None)
            nelem = sum([elem.nelement() for elem in z_list])
            # Tolerance scales with the problem size.
            eps = 1e-5 * np.sqrt(nelem)
            z1_est = DEQFunc2d.list2vec(z_list)
            z1_est = torch.zeros_like(z1_est)
            g = lambda x: DEQFunc2d.g(model.fullstage_copy, x, x_list, cutoffs, *args)
            model.copy_modules()
            loss_function = lambda y_est: model.get_fixed_point_loss(
                y_est, target)

            def inverse_direction_fun(x):
                # dL/dx on a detached leaf copy; enable_grad is required
                # because we are inside the outer no_grad() context.
                x_temp = x.clone().detach().requires_grad_()
                with torch.enable_grad():
                    x_list = DEQFunc2d.vec2list(x_temp, cutoffs)
                    loss = loss_function(x_list)
                loss.backward()
                dl_dx = x_temp.grad
                return dl_dx
            result_info = broyden(
                g,
                z1_est,
                threshold=config.MODEL.F_THRES,
                eps=eps,
                name="forward",
            )
            z1_est = result_info['result']
            Us = result_info['Us']
            VTs = result_info['VTs']
            nstep = result_info['lowest_step']
            # compute true incoming gradient
            grad = inverse_direction_fun(z1_est)
            # SHINE approximate inverse direction from the Broyden low-rank
            # factors accumulated during the forward solve.
            inv_dir = -rmatvec(Us[:, :, :, :nstep - 1], VTs[:, :nstep - 1], grad)
            # Derive the batch size from the actual batch instead of
            # hard-coding 32: the last batch of an epoch can be smaller,
            # and this keeps the check correct if the loader changes.
            batch_size = input.size(0)
            fallback_mask = inv_dir.view(
                batch_size, -1).norm(dim=1) > 1.8 * grad.view(batch_size, -1).norm(dim=1)
            fallback_uses += fallback_mask.sum().item()
    return fallback_uses
def adj_broyden_correl(opa_freq, n_runs=1, random_prescribed=True,
                       dataset='imagenet', model_size='LARGE'):
    """Measure how well each method's inverse direction matches the true one.

    For each method (SHINE, SHINE w/ adjoint Broyden, SHINE w/ OPA, FPN) the
    forward fixed point is solved, the true inverse direction is obtained by a
    Broyden *backward* solve through the network's vjp, and the method's
    approximate inverse direction is compared to it (cosine correlation and
    norm ratio).  A random direction comparison is recorded once (for 'fpn')
    as a baseline.

    Args:
        opa_freq: OPA inverse-direction frequency; ``None`` disables OPA
            kwargs and the ``nstep`` correction.
        n_runs: number of mini-batches to evaluate.
        random_prescribed: if True the "prescribed" direction is random noise;
            otherwise it is the actual fixed-point loss gradient.
        dataset: ``'imagenet'`` or anything else (treated as CIFAR-10).
        model_size: passed through to ``setup_model``.

    Returns:
        (methods_results, random_results): dicts of per-run ``correl`` and
        ``ratio`` lists, per method and for the random baseline.
    """
    # setup
    model = setup_model(opa_freq is not None, dataset, model_size)
    if dataset == 'imagenet':
        traindir = os.path.join(config.DATASET.ROOT + '/images', config.DATASET.TRAIN_SET)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.ImageFolder(traindir, transform_train)
    else:
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        augment_list = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ] if config.DATASET.AUGMENT else []
        transform_train = transforms.Compose(augment_list + [
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.CIFAR10(root=f'{config.DATASET.ROOT}', train=True,
                                         download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=10,
        pin_memory=True,
        worker_init_fn=partial(worker_init_fn, seed=42),
    )
    methods_results = {
        method_name: {
            'correl': [],
            'ratio': []
        }
        for method_name in ['shine-adj-br', 'shine', 'shine-opa', 'fpn']
    }
    # Which forward solver each method uses.
    methods_solvers = {
        'shine': broyden,
        'shine-adj-br': adj_broyden,
        'shine-opa': adj_broyden,
        'fpn': broyden,
    }
    random_results = {'correl': [], 'ratio': []}
    iter_loader = iter(train_loader)
    for i_run in range(n_runs):
        input, target = next(iter_loader)
        target = target.cuda(non_blocking=True)
        x_list, z_list = model.feature_extraction(input.cuda())
        model.fullstage._reset(z_list)
        model.fullstage_copy._copy(model.fullstage)
        # fixed point solving
        x_list = [x.clone().detach().requires_grad_() for x in x_list]
        cutoffs = [(elem.size(1), elem.size(2), elem.size(3))
                   for elem in z_list]
        args = (27, int(1e9), None)
        nelem = sum([elem.nelement() for elem in z_list])
        # Tolerance scales with the problem size.
        eps = 1e-5 * np.sqrt(nelem)
        z1_est = DEQFunc2d.list2vec(z_list)
        # Two candidate incoming directions; 'prescribed' may be overwritten
        # below with the true loss gradient when random_prescribed is False.
        directions_dir = {
            'random': torch.randn(z1_est.shape),
            'prescribed': torch.randn(z1_est.shape),
        }
        for method_name in methods_results.keys():
            # Every method starts the forward solve from zero.
            z1_est = torch.zeros_like(z1_est)
            g = lambda x: DEQFunc2d.g(model.fullstage_copy, x, x_list, cutoffs, *args)
            if random_prescribed:
                inverse_direction_fun = lambda x: directions_dir['prescribed']
            else:
                model.copy_modules()
                loss_function = lambda y_est: model.get_fixed_point_loss(
                    y_est, target)

                def inverse_direction_fun_vec(x):
                    # dL/dx on a detached leaf copy of x.
                    x_temp = x.clone().detach().requires_grad_()
                    with torch.enable_grad():
                        x_list = DEQFunc2d.vec2list(x_temp, cutoffs)
                        loss = loss_function(x_list)
                    loss.backward()
                    dl_dx = x_temp.grad
                    return dl_dx
                inverse_direction_fun = inverse_direction_fun_vec
            solver = methods_solvers[method_name]
            if 'opa' in method_name:
                add_kwargs = dict(
                    inverse_direction_freq=opa_freq,
                    inverse_direction_fun=inverse_direction_fun if opa_freq is not None else None,
                )
            else:
                add_kwargs = {}
            result_info = solver(
                g,
                z1_est,
                threshold=config.MODEL.F_THRES,
                eps=eps,
                name="forward",
                **add_kwargs,
            )
            z1_est = result_info['result']
            Us = result_info['Us']
            VTs = result_info['VTs']
            nstep = result_info['lowest_step']
            if opa_freq is not None:
                # OPA inserts extra rank-one updates every opa_freq steps.
                nstep += (nstep - 1) // opa_freq
            # compute true incoming gradient if needed
            if not random_prescribed:
                directions_dir['prescribed'] = inverse_direction_fun_vec(
                    z1_est)
                # making sure the random direction norm is not unrealistic
                directions_dir[
                    'random'] = directions_dir['random'] * torch.norm(
                        directions_dir['prescribed']) / torch.norm(
                            directions_dir['random'])
            # inversion on random gradients
            z1_temp = z1_est.clone().detach().requires_grad_()
            with torch.enable_grad():
                y = DEQFunc2d.g(model.fullstage_copy, z1_temp, x_list, cutoffs, *args)
            # NOTE(review): this clobbers the forward-solve eps, so every
            # method after the first runs its forward solve with eps=2e-10
            # instead of 1e-5*sqrt(nelem) — confirm this is intended.
            eps = 2e-10
            for direction_name, direction in directions_dir.items():
                # NOTE(review): this g shadows the forward g above; it is the
                # backward-solve residual a |-> J^T a + direction via vjp.
                def g(x):
                    y.backward(x, retain_graph=True)
                    res = z1_temp.grad + direction
                    z1_temp.grad.zero_()
                    return res
                result_info_inversion = broyden(
                    g,
                    direction,  # we initialize Jacobian Free style
                    # in order to accelerate the convergence
                    threshold=35,
                    eps=eps,
                    name="backward",
                )
                true_inv = result_info_inversion['result']
                inv_dir = {
                    'fpn': direction,
                    'shine': -rmatvec(Us[:, :, :, :nstep - 1],
                                      VTs[:, :nstep - 1], direction),
                }
                # All SHINE variants share the same low-rank inverse estimate.
                inv_dir['shine-opa'] = inv_dir['shine']
                inv_dir['shine-adj-br'] = inv_dir['shine']
                approx_inv = inv_dir[method_name]
                # Cosine similarity and norm ratio vs the true inverse.
                correl = torch.dot(
                    torch.flatten(true_inv),
                    torch.flatten(approx_inv),
                )
                scaling = torch.norm(true_inv) * torch.norm(approx_inv)
                correl = correl / scaling
                ratio = torch.norm(true_inv) / torch.norm(approx_inv)
                if direction_name == 'prescribed':
                    methods_results[method_name]['correl'].append(
                        correl.item())
                    methods_results[method_name]['ratio'].append(ratio.item())
                else:
                    # Random-direction baseline recorded only once (for fpn).
                    if method_name == 'fpn':
                        random_results['correl'].append(correl.item())
                        random_results['ratio'].append(ratio.item())
                # Flush the retained graph with a zero backward pass.
                y.backward(torch.zeros_like(true_inv), retain_graph=False)
    return methods_results, random_results