Example #1
 def inverse_direction_fun_vec(x):
     """Return d(loss)/dx at *x* in vectorized form.

     The incoming vector is detached and re-marked as a grad-requiring
     leaf so that the backward pass below populates only its ``.grad``
     field, leaving the caller's autograd graph untouched.
     """
     leaf = x.clone().detach().requires_grad_()
     # enable_grad guards against an outer no_grad context
     with torch.enable_grad():
         loss = loss_function(DEQFunc2d.vec2list(leaf, cutoffs))
     loss.backward()
     return leaf.grad
Example #2
def adj_broyden_convergence(opa_freq,
                            n_runs=1,
                            dataset='imagenet',
                            model_size='LARGE'):
    """Compare forward fixed-point convergence of Adjoint Broyden vs Broyden.

    For ``n_runs`` training batches, the DEQ equilibrium is solved with both
    the ``adj_broyden`` and the plain ``broyden`` solvers.  Per-solver
    statistics (final residual, relative residual at the best step, step
    count) and cross-solver similarity statistics between the two fixed
    points (cosine correlation, norm ratio, absolute and relative
    difference) are accumulated.

    Args:
        opa_freq: frequency (in solver steps) at which Adjoint Broyden
            applies the inverse-direction (OPA) update; ``None`` disables it.
        n_runs: number of batches to evaluate.
        dataset: ``'imagenet'`` for ImageNet, anything else selects CIFAR-10.
        model_size: model configuration size forwarded to ``setup_model``.

    Returns:
        dict mapping statistic names to lists with one entry per run
        (per solver for the ``*_diff`` / ``*_rdiff`` / ``*_lstep`` keys).
    """
    # setup
    model = setup_model(opa_freq is not None, dataset, model_size)
    if dataset == 'imagenet':
        traindir = os.path.join(config.DATASET.ROOT + '/images',
                                config.DATASET.TRAIN_SET)
        # standard ImageNet channel statistics
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.ImageFolder(traindir, transform_train)
    else:
        # CIFAR-10 channel statistics
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        augment_list = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ] if config.DATASET.AUGMENT else []
        transform_train = transforms.Compose(augment_list + [
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.CIFAR10(root=f'{config.DATASET.ROOT}',
                                         train=True,
                                         download=True,
                                         transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=10,
        pin_memory=True,
        worker_init_fn=partial(worker_init_fn, seed=42),
    )
    iter_loader = iter(train_loader)
    solvers = {
        'adj_broyden': adj_broyden,
        'broyden': broyden,
    }
    # cross-solver statistics, one entry per run
    convergence_results = {
        'correl': [],
        'ratio': [],
        'diff': [],
        'rdiff': [],
    }
    # per-solver statistics, one entry per run and solver
    for solver_name in solvers.keys():
        convergence_results[f'{solver_name}_rdiff'] = []
        convergence_results[f'{solver_name}_diff'] = []
        convergence_results[f'{solver_name}_lstep'] = []
    for i_run in range(n_runs):
        input, target = next(iter_loader)
        target = target.cuda(non_blocking=True)
        solvers_results = {}
        x_list, z_list = model.feature_extraction(input.cuda())
        # reset/copy the equilibrium stage so both solvers see the same
        # module state (dropout masks, weight normalization, ...)
        model.fullstage._reset(z_list)
        model.fullstage_copy._copy(model.fullstage)
        for solver_name, solver in solvers.items():
            # fixed point solving
            x_list = [x.clone().detach().requires_grad_() for x in x_list]
            z_list = [z.clone() for z in z_list]
            # per-scale (channels, height, width) shapes used to flatten /
            # unflatten the multi-resolution state vector
            cutoffs = [(elem.size(1), elem.size(2), elem.size(3))
                       for elem in z_list]
            args = (27, int(1e9), None)
            nelem = sum([elem.nelement() for elem in z_list])
            # tolerance scaled with the square root of the state size
            eps = 1e-5 * np.sqrt(nelem)
            z1_est = DEQFunc2d.list2vec(z_list)
            g = lambda x: DEQFunc2d.g(model.fullstage_copy, x, x_list, cutoffs,
                                      *args)
            model.copy_modules()
            if solver_name == 'adj_broyden':
                loss_function = lambda y_est: model.get_fixed_point_loss(
                    y_est, target)

                def inverse_direction_fun_vec(x):
                    # gradient of the fixed-point loss w.r.t. x (vectorized)
                    x_temp = x.clone().detach().requires_grad_()
                    with torch.enable_grad():
                        x_list = DEQFunc2d.vec2list(x_temp, cutoffs)
                        loss = loss_function(x_list)
                    loss.backward()
                    dl_dx = x_temp.grad
                    return dl_dx

                inverse_direction_fun = inverse_direction_fun_vec
                add_kwargs = dict(
                    inverse_direction_freq=opa_freq,
                    inverse_direction_fun=inverse_direction_fun
                    if opa_freq is not None else None,
                )
            else:
                add_kwargs = {}
            result_info = solver(
                g,
                z1_est,
                threshold=config.MODEL.F_THRES,
                eps=eps,
                name="forward",
                **add_kwargs,
            )
            z1_est = result_info['result']
            convergence_results[f'{solver_name}_diff'].append(
                result_info['diff'])
            lowest_step = result_info['lowest_step']
            convergence_results[f'{solver_name}_lstep'].append(lowest_step)
            convergence_results[f'{solver_name}_rdiff'].append(
                result_info['new_trace'][lowest_step])
            solvers_results[solver_name] = z1_est.clone().detach()
        z1_adj_br = solvers_results['adj_broyden']
        z1_br = solvers_results['broyden']
        # cosine similarity between the two solvers' fixed points
        correl = torch.dot(
            torch.flatten(z1_adj_br),
            torch.flatten(z1_br),
        )
        scaling = torch.norm(z1_adj_br) * torch.norm(z1_br)
        convergence_results['correl'].append((correl / scaling).item())
        convergence_results['ratio'].append(
            (torch.norm(z1_br) / torch.norm(z1_adj_br)).item())
        convergence_results['diff'].append(
            torch.norm(z1_br - z1_adj_br).item())
        convergence_results['rdiff'].append(
            (torch.norm(z1_br - z1_adj_br) / torch.norm(z1_br)).item())

    return convergence_results
def eval_ratio_fb_classifier(
    n_epochs=100,
    pretrained=False,
    n_gpus=1,
    dataset='imagenet',
    model_size='SMALL',
    shine=False,
    fpn=False,
    gradient_correl=False,
    gradient_ratio=False,
    adjoint_broyden=False,
    opa=False,
    refine=False,
    fallback=False,
    save_at=None,
    restart_from=None,
    use_group_norm=False,
    seed=0,
    n_samples=1,
):
    """Profile the backward/forward wall-clock time ratio of a DEQ classifier.

    Builds the model, logger and data pipeline from the global ``config``,
    then for ``n_samples`` batches times one no-grad forward pass through
    the equilibrium stage and one backward pass through it, all under the
    CUDA autograd profiler.  The profiler table is printed and a Chrome
    trace is exported to ``trace.json``.

    Returns:
        list of ``time_backward / time_forward`` ratios, one per sample.
    """
    # seed python / numpy / torch RNGs for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    args = update_config_w_args(
        n_epochs=n_epochs,
        pretrained=pretrained,
        n_gpus=n_gpus,
        dataset=dataset,
        model_size=model_size,
        use_group_norm=use_group_norm,
    )
    print(colored("Setting default tensor type to cuda.FloatTensor", "cyan"))
    try:
        torch.multiprocessing.set_start_method('spawn')
    except RuntimeError:
        # start method may already be set by a previous call in this process
        pass

    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logger, final_output_dir, tb_log_dir = create_logger(
        config,
        args.cfg,
        'train',
        shine=shine,
        fpn=fpn,
        seed=seed,
        use_group_norm=use_group_norm,
        adjoint_broyden=adjoint_broyden,
        opa=opa,
        refine=refine,
        fallback=fallback,
    )

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    # NOTE(review): eval on a config-controlled string — safe only as long as
    # config.MODEL.NAME comes from a trusted configuration file.
    model = eval('models.' + config.MODEL.NAME + '.get_cls_net')(
        config,
        shine=shine,
        fpn=fpn,
        gradient_correl=gradient_correl,
        gradient_ratio=gradient_ratio,
        adjoint_broyden=adjoint_broyden,
        refine=refine,
        fallback=fallback,
        opa=opa,
    ).cuda()

    # dump_input = torch.rand(config.TRAIN.BATCH_SIZE_PER_GPU, 3, config.MODEL.IMAGE_SIZE[1], config.MODEL.IMAGE_SIZE[0]).cuda()
    # logger.info(get_model_summary(model, dump_input))

    if config.TRAIN.MODEL_FILE:
        model.load_state_dict(torch.load(config.TRAIN.MODEL_FILE))
        logger.info(
            colored('=> loading model from {}'.format(config.TRAIN.MODEL_FILE),
                    'red'))

    # copy model file
    models_dst_dir = os.path.join(final_output_dir, 'models')
    if os.path.exists(models_dst_dir):
        shutil.rmtree(models_dst_dir)

    gpus = list(config.GPUS)
    print("Finished constructing model!")

    # define loss function (criterion) and optimizer
    # NOTE(review): criterion is defined but never used in this profiling
    # routine (no optimization step is performed here).
    criterion = nn.CrossEntropyLoss().cuda()

    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        if restart_from is None:
            resume_file = 'checkpoint.pth.tar'
        else:
            resume_file = f'checkpoint_{restart_from}.pth.tar'
        model_state_file = os.path.join(final_output_dir, resume_file)
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_perf = checkpoint['perf']
            model.load_state_dict(checkpoint['state_dict'])

    # Data loading code
    dataset_name = config.DATASET.DATASET

    if dataset_name == 'imagenet':
        traindir = os.path.join(config.DATASET.ROOT + '/images',
                                config.DATASET.TRAIN_SET)
        # standard ImageNet channel statistics
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.ImageFolder(traindir, transform_train)
    else:
        assert dataset_name == "cifar10", "Only CIFAR-10 is supported at this phase"
        classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog',
                   'horse', 'ship', 'truck')  # For reference

        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        augment_list = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ] if config.DATASET.AUGMENT else []
        transform_train = transforms.Compose(augment_list + [
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.CIFAR10(root=f'{config.DATASET.ROOT}',
                                         train=True,
                                         download=True,
                                         transform=transform_train)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=True,
        num_workers=10,
        pin_memory=True,
        worker_init_fn=partial(worker_init_fn, seed=seed),
    )

    iter_loader = iter(train_loader)
    ratios = []
    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        for i_sample in range(n_samples):
            with profiler.record_function("Data loading"):
                input, target = next(iter_loader)
                input = input.cuda(non_blocking=False)
            with profiler.record_function("Feature extraction PASS"):
                x_list, z_list = model.feature_extraction(input)
            # For variational dropout mask resetting and weight normalization re-computations
            model.fullstage._reset(z_list)
            model.fullstage_copy._copy(model.fullstage)
            x_list = [x.clone().detach().requires_grad_() for x in x_list]
            z_list = [z.clone().detach().requires_grad_() for z in z_list]
            with profiler.record_function("Forward PASS"):
                start_forward = time.time()
                with torch.no_grad():
                    model.fullstage_copy(z_list, x_list)
                    # synchronize so wall-clock timing covers the CUDA work
                    torch.cuda.synchronize()
                end_forward = time.time()
                time_forward = end_forward - start_forward
            with profiler.record_function("Forward PASS enable grad"):
                # second, grad-enabled pass whose graph is used for backward
                z_list_new = model.fullstage_copy(z_list, x_list)
            z = DEQFunc2d.list2vec(z_list_new)
            with profiler.record_function("Backward PASS"):
                start_backward = time.time()
                # backprop with z itself as the incoming gradient
                z.backward(z)
                torch.cuda.synchronize()
                end_backward = time.time()
            time_backward = end_backward - start_backward
            ratios.append(time_backward / time_forward)
    print(prof.key_averages().table(sort_by="cuda_time_total"))
    prof.export_chrome_trace("trace.json")
    return ratios
Example #4
def fallback_ratio(n_runs=1, dataset='imagenet', model_size='SMALL'):
    """Count how often the SHINE fallback criterion fires on training batches.

    For ``n_runs`` batches of size 32, the forward fixed point is solved
    with Broyden, the true incoming gradient of the fixed-point loss is
    computed, and the SHINE approximate inverse direction is formed from the
    solver's low-rank factors.  A sample triggers the fallback when the norm
    of the approximate inverse direction exceeds 1.8x the norm of the
    gradient.

    Args:
        n_runs: number of batches to evaluate.
        dataset: ``'imagenet'`` for ImageNet, anything else selects CIFAR-10.
        model_size: model configuration size forwarded to ``setup_model``.

    Returns:
        int: total number of samples (across all batches) for which the
        fallback criterion was met.
    """
    # setup
    model = setup_model(False, dataset, model_size)
    if dataset == 'imagenet':
        traindir = os.path.join(config.DATASET.ROOT + '/images',
                                config.DATASET.TRAIN_SET)
        # standard ImageNet channel statistics
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.ImageFolder(traindir, transform_train)
    else:
        # CIFAR-10 channel statistics
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        augment_list = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ] if config.DATASET.AUGMENT else []
        transform_train = transforms.Compose(augment_list + [
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.CIFAR10(root=f'{config.DATASET.ROOT}',
                                         train=True,
                                         download=True,
                                         transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=10,
        pin_memory=True,
        worker_init_fn=partial(worker_init_fn, seed=42),
    )
    fallback_uses = 0
    iter_loader = iter(train_loader)
    for i_run in range(n_runs):
        input, target = next(iter_loader)
        target = target.cuda(non_blocking=True)
        # no_grad for the whole run; the loss gradient below re-enables grad
        # locally via torch.enable_grad inside inverse_direction_fun
        with torch.no_grad():
            x_list, z_list = model.feature_extraction(input.cuda())
            model.fullstage._reset(z_list)
            model.fullstage_copy._copy(model.fullstage)
            # fixed point solving
            x_list = [x.clone().detach().requires_grad_() for x in x_list]
            # per-scale (channels, height, width) shapes used to flatten /
            # unflatten the multi-resolution state vector
            cutoffs = [(elem.size(1), elem.size(2), elem.size(3))
                       for elem in z_list]
            args = (27, int(1e9), None)
            nelem = sum([elem.nelement() for elem in z_list])
            # tolerance scaled with the square root of the state size
            eps = 1e-5 * np.sqrt(nelem)
            z1_est = DEQFunc2d.list2vec(z_list)
            # start the solver from zero rather than the extracted features
            z1_est = torch.zeros_like(z1_est)
            g = lambda x: DEQFunc2d.g(model.fullstage_copy, x, x_list, cutoffs,
                                      *args)
            model.copy_modules()
            loss_function = lambda y_est: model.get_fixed_point_loss(
                y_est, target)

            def inverse_direction_fun(x):
                # gradient of the fixed-point loss w.r.t. x (vectorized)
                x_temp = x.clone().detach().requires_grad_()
                with torch.enable_grad():
                    x_list = DEQFunc2d.vec2list(x_temp, cutoffs)
                    loss = loss_function(x_list)
                loss.backward()
                dl_dx = x_temp.grad
                return dl_dx

            result_info = broyden(
                g,
                z1_est,
                threshold=config.MODEL.F_THRES,
                eps=eps,
                name="forward",
            )
            z1_est = result_info['result']
            # low-rank factors of the Broyden inverse-Jacobian estimate
            Us = result_info['Us']
            VTs = result_info['VTs']
            nstep = result_info['lowest_step']
            # compute true incoming gradient
            grad = inverse_direction_fun(z1_est)

            # SHINE approximate inverse direction from the solver factors
            inv_dir = -rmatvec(Us[:, :, :, :nstep - 1], VTs[:, :nstep - 1],
                               grad)
            # NOTE(review): the hard-coded 32 assumes every batch is full;
            # a short final batch would break this view — TODO confirm the
            # loader always yields full batches for the chosen n_runs.
            fallback_mask = inv_dir.view(
                32, -1).norm(dim=1) > 1.8 * grad.view(32, -1).norm(dim=1)
            fallback_uses += fallback_mask.sum().item()
    return fallback_uses
Example #5
def adj_broyden_correl(opa_freq,
                       n_runs=1,
                       random_prescribed=True,
                       dataset='imagenet',
                       model_size='LARGE'):
    """Measure how well each method's approximate inverse matches the truth.

    For each of ``n_runs`` batches and each method in
    {'shine-adj-br', 'shine', 'shine-opa', 'fpn'}, the forward fixed point
    is solved, the true inverse direction is obtained by solving the
    adjoint (backward) system with Broyden, and the method's approximation
    is compared to it via cosine correlation and norm ratio.  A
    norm-matched random direction is also evaluated (recorded once per run,
    under the 'fpn' method) as a baseline.

    Args:
        opa_freq: frequency of the OPA inverse-direction updates for the
            Adjoint Broyden solver; ``None`` disables them.
        n_runs: number of batches to evaluate.
        random_prescribed: if True, the "prescribed" direction is a fixed
            random vector; otherwise it is the true gradient of the
            fixed-point loss.
        dataset: ``'imagenet'`` for ImageNet, anything else selects CIFAR-10.
        model_size: model configuration size forwarded to ``setup_model``.

    Returns:
        tuple ``(methods_results, random_results)`` of dicts holding
        'correl' and 'ratio' lists.
    """
    # setup
    model = setup_model(opa_freq is not None, dataset, model_size)
    if dataset == 'imagenet':
        traindir = os.path.join(config.DATASET.ROOT + '/images',
                                config.DATASET.TRAIN_SET)
        # standard ImageNet channel statistics
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.ImageFolder(traindir, transform_train)
    else:
        # CIFAR-10 channel statistics
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        augment_list = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ] if config.DATASET.AUGMENT else []
        transform_train = transforms.Compose(augment_list + [
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.CIFAR10(root=f'{config.DATASET.ROOT}',
                                         train=True,
                                         download=True,
                                         transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=10,
        pin_memory=True,
        worker_init_fn=partial(worker_init_fn, seed=42),
    )
    methods_results = {
        method_name: {
            'correl': [],
            'ratio': []
        }
        for method_name in ['shine-adj-br', 'shine', 'shine-opa', 'fpn']
    }
    # forward solver used by each method
    methods_solvers = {
        'shine': broyden,
        'shine-adj-br': adj_broyden,
        'shine-opa': adj_broyden,
        'fpn': broyden,
    }
    random_results = {'correl': [], 'ratio': []}
    iter_loader = iter(train_loader)
    for i_run in range(n_runs):
        input, target = next(iter_loader)
        target = target.cuda(non_blocking=True)
        x_list, z_list = model.feature_extraction(input.cuda())
        # reset/copy the equilibrium stage so every method sees the same
        # module state (dropout masks, weight normalization, ...)
        model.fullstage._reset(z_list)
        model.fullstage_copy._copy(model.fullstage)
        # fixed point solving
        x_list = [x.clone().detach().requires_grad_() for x in x_list]
        # per-scale (channels, height, width) shapes used to flatten /
        # unflatten the multi-resolution state vector
        cutoffs = [(elem.size(1), elem.size(2), elem.size(3))
                   for elem in z_list]
        args = (27, int(1e9), None)
        nelem = sum([elem.nelement() for elem in z_list])
        # tolerance scaled with the square root of the state size
        eps = 1e-5 * np.sqrt(nelem)
        z1_est = DEQFunc2d.list2vec(z_list)
        # NOTE(review): these use the default tensor type — presumably CUDA
        # here (cf. the set_default_tensor_type call elsewhere); verify they
        # end up on the same device as z1_est.
        directions_dir = {
            'random': torch.randn(z1_est.shape),
            'prescribed': torch.randn(z1_est.shape),
        }
        for method_name in methods_results.keys():
            # every method starts its forward solve from zero
            z1_est = torch.zeros_like(z1_est)
            g = lambda x: DEQFunc2d.g(model.fullstage_copy, x, x_list, cutoffs,
                                      *args)
            if random_prescribed:
                inverse_direction_fun = lambda x: directions_dir['prescribed']
            else:
                model.copy_modules()
                loss_function = lambda y_est: model.get_fixed_point_loss(
                    y_est, target)

                def inverse_direction_fun_vec(x):
                    # gradient of the fixed-point loss w.r.t. x (vectorized)
                    x_temp = x.clone().detach().requires_grad_()
                    with torch.enable_grad():
                        x_list = DEQFunc2d.vec2list(x_temp, cutoffs)
                        loss = loss_function(x_list)
                    loss.backward()
                    dl_dx = x_temp.grad
                    return dl_dx

                inverse_direction_fun = inverse_direction_fun_vec

            solver = methods_solvers[method_name]
            if 'opa' in method_name:
                add_kwargs = dict(
                    inverse_direction_freq=opa_freq,
                    inverse_direction_fun=inverse_direction_fun
                    if opa_freq is not None else None,
                )
            else:
                add_kwargs = {}
            # NOTE(review): `eps` is rebound to 2e-10 further down inside
            # this loop, so for every method after the first in a run the
            # forward solve uses 2e-10 instead of 1e-5 * sqrt(nelem) —
            # verify this is intended.
            result_info = solver(
                g,
                z1_est,
                threshold=config.MODEL.F_THRES,
                eps=eps,
                name="forward",
                **add_kwargs,
            )
            z1_est = result_info['result']
            # low-rank factors of the Broyden inverse-Jacobian estimate
            Us = result_info['Us']
            VTs = result_info['VTs']
            nstep = result_info['lowest_step']
            if opa_freq is not None:
                # account for the extra OPA updates stored in the factors
                nstep += (nstep - 1) // opa_freq
            # compute true incoming gradient if needed
            if not random_prescribed:
                directions_dir['prescribed'] = inverse_direction_fun_vec(
                    z1_est)
                # making sure the random direction norm is not unrealistic
                directions_dir[
                    'random'] = directions_dir['random'] * torch.norm(
                        directions_dir['prescribed']) / torch.norm(
                            directions_dir['random'])
            # inversion on random gradients
            z1_temp = z1_est.clone().detach().requires_grad_()
            with torch.enable_grad():
                y = DEQFunc2d.g(model.fullstage_copy, z1_temp, x_list, cutoffs,
                                *args)

            # tight tolerance for the backward (adjoint) solve
            eps = 2e-10
            for direction_name, direction in directions_dir.items():

                # residual of the adjoint system; shadows the earlier
                # forward function `g`, which is no longer needed here
                def g(x):
                    y.backward(x, retain_graph=True)
                    res = z1_temp.grad + direction
                    z1_temp.grad.zero_()
                    return res

                result_info_inversion = broyden(
                    g,
                    direction,  # we initialize Jacobian Free style
                    # in order to accelerate the convergence
                    threshold=35,
                    eps=eps,
                    name="backward",
                )
                true_inv = result_info_inversion['result']
                # each method's approximate inverse direction
                inv_dir = {
                    'fpn':
                    direction,
                    'shine':
                    -rmatvec(Us[:, :, :, :nstep - 1], VTs[:, :nstep - 1],
                             direction),
                }
                inv_dir['shine-opa'] = inv_dir['shine']
                inv_dir['shine-adj-br'] = inv_dir['shine']
                approx_inv = inv_dir[method_name]
                # cosine similarity between true and approximate inverse
                correl = torch.dot(
                    torch.flatten(true_inv),
                    torch.flatten(approx_inv),
                )
                scaling = torch.norm(true_inv) * torch.norm(approx_inv)
                correl = correl / scaling
                ratio = torch.norm(true_inv) / torch.norm(approx_inv)
                if direction_name == 'prescribed':
                    methods_results[method_name]['correl'].append(
                        correl.item())
                    methods_results[method_name]['ratio'].append(ratio.item())
                else:
                    # record the random-direction baseline only once per run
                    if method_name == 'fpn':
                        random_results['correl'].append(correl.item())
                        random_results['ratio'].append(ratio.item())
            # final zero backward with retain_graph=False frees the graph
            y.backward(torch.zeros_like(true_inv), retain_graph=False)
    return methods_results, random_results