Code example #1
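These scripts come from the same training codebase and are shown without their import headers. A minimal sketch of the shared imports is given below, assuming the usual module names; the project-specific helpers (load_data_from_arguments, CacheDatasetWrapper, SubsetDataWrapper, get_loaders_from_datasets, methods, metrics, callbacks, training, utils, and the influence/NTK utilities) are assumed to be importable from the surrounding repository, and their exact module paths are not reproduced here.

# Standard-library and third-party imports used across the examples below.
import argparse
import json
import os
import pickle

import numpy as np
import torch
from tqdm import tqdm

# Project-specific helpers (e.g. load_data_from_arguments, CacheDatasetWrapper,
# methods, metrics, callbacks, training, utils) are assumed to be imported from
# the surrounding repository; their module paths are not shown here.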
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device', '-d', default='cuda', help='specifies the main device')
    parser.add_argument('--all_device_ids', nargs='+', type=str, default=None,
                        help="If not None, this list specifies devices for multiple GPU training. "
                             "The first device should match with the main device (args.device).")
    parser.add_argument('--batch_size', '-b', type=int, default=256)
    parser.add_argument('--epochs', '-e', type=int, default=400)
    parser.add_argument('--stopping_param', type=int, default=2**30)
    parser.add_argument('--save_iter', '-s', type=int, default=10)
    parser.add_argument('--vis_iter', '-v', type=int, default=10)
    parser.add_argument('--log_dir', '-l', type=str, default=None)
    parser.add_argument('--seed', type=int, default=42)

    # data parameters
    parser.add_argument('--dataset', '-D', type=str, default='mnist')
    parser.add_argument('--data_augmentation', '-A', action='store_true', dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--error_prob', '-n', type=float, default=0.0)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--clean_validation', action='store_true', default=False)
    parser.add_argument('--resize_to_imagenet', action='store_true', dest='resize_to_imagenet')
    parser.set_defaults(resize_to_imagenet=False)
    parser.add_argument('--cache_dataset', action='store_true', dest='cache_dataset')
    parser.set_defaults(cache_dataset=False)
    parser.add_argument('--num_workers', type=int, default=0, help='number of workers in data loaders')

    # hyper-parameters
    parser.add_argument('--model_class', '-m', type=str, default='ClassifierL2')

    parser.add_argument('--l2_reg_coef', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'])

    args = parser.parse_args()
    print(args)

    # Load data
    train_data, val_data, test_data, _ = load_data_from_arguments(args, build_loaders=False)

    if args.cache_dataset:
        train_data = CacheDatasetWrapper(train_data)
        val_data = CacheDatasetWrapper(val_data)
        test_data = CacheDatasetWrapper(test_data)

    train_loader, val_loader, test_loader = get_loaders_from_datasets(train_data, val_data, test_data,
                                                                      batch_size=args.batch_size,
                                                                      num_workers=args.num_workers)

    # Options
    optimization_args = {
        'optimizer': {
            'name': args.optimizer,
            'lr': args.lr,
        }
    }

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)

    model = model_class(input_shape=train_loader.dataset[0][0].shape,
                        architecture_args=architecture_args,
                        l2_reg_coef=args.l2_reg_coef,
                        device=args.device,
                        seed=args.seed)

    metrics_list = [metrics.Accuracy(output_key='pred')]
    if args.dataset == 'imagenet':
        metrics_list.append(metrics.TopKAccuracy(k=5, output_key='pred'))

    callbacks_list = [callbacks.SaveBestWithMetric(metric=metrics_list[0], partition='val', direction='max')]

    stopper = callbacks.EarlyStoppingWithMetric(metric=metrics_list[0], stopping_param=args.stopping_param,
                                                partition='val', direction='max')

    training.train(model=model,
                   train_loader=train_loader,
                   val_loader=val_loader,
                   epochs=args.epochs,
                   save_iter=args.save_iter,
                   vis_iter=args.vis_iter,
                   optimization_args=optimization_args,
                   log_dir=args.log_dir,
                   args_to_log=args,
                   stopper=stopper,
                   metrics=metrics_list,
                   callbacks=callbacks_list,
                   device_ids=args.all_device_ids)
Code example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device',
                        '-d',
                        default='cuda',
                        help='specifies the main device')
    parser.add_argument(
        '--all_device_ids',
        nargs='+',
        type=str,
        default=None,
        help=
        "If not None, this list specifies devices for multiple GPU training. "
        "The first device should match with the main device (args.device).")
    parser.add_argument('--batch_size', '-b', type=int, default=256)
    parser.add_argument('--epochs', '-e', type=int, default=400)
    parser.add_argument('--stopping_param', type=int, default=2**30)
    parser.add_argument('--save_iter', '-s', type=int, default=2**30)
    parser.add_argument('--vis_iter', '-v', type=int, default=2**30)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument(
        '--num_accumulation_steps',
        default=1,
        type=int,
        help='Number of training steps to accumulate before updating weights')

    # data parameters
    parser.add_argument('--dataset', '-D', type=str, default='mnist')
    parser.add_argument('--data_augmentation',
                        '-A',
                        action='store_true',
                        dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--error_prob', '-n', type=float, default=0.0)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--clean_validation',
                        action='store_true',
                        default=False)
    parser.add_argument('--resize_to_imagenet',
                        action='store_true',
                        dest='resize_to_imagenet')
    parser.set_defaults(resize_to_imagenet=False)
    parser.add_argument('--cache_dataset',
                        action='store_true',
                        dest='cache_dataset')
    parser.set_defaults(cache_dataset=False)
    parser.add_argument(
        '--sample_ranking_file',
        type=str,
        default=None,
        help=
        'Points to a pickle file that stores an ordering of examples from least to '
        'most important. The most important args.exclude_ratio fraction of samples '
        'will be excluded from training.')
    parser.add_argument('--exclude_ratio',
                        type=float,
                        default=0.0,
                        help='Fraction of examples to exclude.')
    parser.add_argument('--exclude_side',
                        type=str,
                        default='top',
                        choices=['top', 'bottom'],
                        help='Which end of the ordering to exclude samples from.')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers in data loaders')

    # hyper-parameters
    parser.add_argument('--model_class',
                        '-m',
                        type=str,
                        default='ClassifierL2')

    parser.add_argument('--l2_reg_coef', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--optimizer',
                        type=str,
                        default='adam',
                        choices=['adam', 'sgd'])
    parser.add_argument('--random_baseline_seed', type=int, default=42)

    parser.add_argument('--output_dir',
                        '-o',
                        type=str,
                        default='sample_info/results/data-summarization/')
    parser.add_argument('--baseline_name', '-B', type=str, required=True)
    parser.add_argument('--exp_name', '-E', type=str, required=True)

    args = parser.parse_args()
    print(args)

    # set tensorboard log directory
    args.log_dir = os.path.join(args.output_dir, args.baseline_name,
                                args.exp_name, 'logs')
    utils.make_path(args.log_dir)

    # Load data
    train_data, val_data, test_data, _ = load_data_from_arguments(
        args, build_loaders=False)

    # exclude samples
    np.random.seed(args.random_baseline_seed)
    order = np.random.permutation(len(train_data))

    # if sample ranking file is given, take the order from there
    if args.sample_ranking_file is not None:
        with open(args.sample_ranking_file, 'rb') as f:
            order = pickle.load(f)

    exclude_count = int(args.exclude_ratio * len(train_data))
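    # note: order[-0:] would select the entire array, so the zero-count case is handled explicitly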
    if exclude_count == 0:
        exclude_indices = []
    else:
        if args.exclude_side == 'top':
            exclude_indices = order[-exclude_count:]
        else:
            exclude_indices = order[:exclude_count]

    train_data = SubsetDataWrapper(dataset=train_data,
                                   exclude_indices=exclude_indices)

    if args.cache_dataset:
        train_data = CacheDatasetWrapper(train_data)
        val_data = CacheDatasetWrapper(val_data)
        test_data = CacheDatasetWrapper(test_data)

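    # when a single (possibly accumulated) batch already covers the whole training set, shuffling is unnecessary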
    shuffle_train = (args.batch_size * args.num_accumulation_steps <
                     len(train_data))
    train_loader, val_loader, test_loader = get_loaders_from_datasets(
        train_data,
        val_data,
        test_data,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle_train=shuffle_train)

    # Options
    optimization_args = {
        'optimizer': {
            'name': args.optimizer,
            'lr': args.lr,
        }
    }

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)

    model = model_class(input_shape=train_loader.dataset[0][0].shape,
                        architecture_args=architecture_args,
                        l2_reg_coef=args.l2_reg_coef,
                        device=args.device,
                        seed=args.seed)

    # Put the model permanently in eval mode. This ensures that if the network has pretrained BatchNorm
    # layers, their running averages stay fixed.
    utils.put_always_eval_mode(model)

    metrics_list = [
        metrics.Accuracy(output_key='pred',
                         one_hot=(train_data[0][1].ndim > 0))
    ]
    if args.dataset == 'imagenet':
        metrics_list.append(metrics.TopKAccuracy(k=5, output_key='pred'))

    stopper = callbacks.EarlyStoppingWithMetric(
        metric=metrics_list[0],
        stopping_param=args.stopping_param,
        partition='val',
        direction='max')

    training.train(model=model,
                   train_loader=train_loader,
                   val_loader=val_loader,
                   epochs=args.epochs,
                   save_iter=args.save_iter,
                   vis_iter=args.vis_iter,
                   optimization_args=optimization_args,
                   log_dir=args.log_dir,
                   args_to_log=args,
                   stopper=stopper,
                   metrics=metrics_list,
                   device_ids=args.all_device_ids,
                   num_accumulation_steps=args.num_accumulation_steps)

    val_preds = utils.apply_on_dataset(model=model,
                                       dataset=val_data,
                                       cpu=True,
                                       partition='val',
                                       batch_size=args.batch_size)['pred']
    val_acc = metrics_list[0].value(epoch=args.epochs - 1, partition='val')

    file_name = f'results-{args.exclude_ratio:.4f}'
    if args.baseline_name == 'random':
        file_name += f'-{args.random_baseline_seed}'
    file_name += '.pkl'
    file_path = os.path.join(args.output_dir, args.baseline_name,
                             args.exp_name, file_name)
    utils.make_path(os.path.dirname(file_path))
    with open(file_path, 'wb') as f:
        pickle.dump({
            'val_preds': val_preds,
            'val_acc': val_acc,
            'args': args
        }, f)
Code example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device',
                        '-d',
                        default='cuda',
                        help='specifies the main device')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--cpu', dest='cpu', action='store_true')
    parser.set_defaults(cpu=False)

    # data parameters
    parser.add_argument(
        '--dataset',
        '-D',
        type=str,
        default='mnist4vs9',
        choices=[
            'mnist4vs9', 'synthetic', 'cifar10-cat-vs-dog', 'cats-and-dogs'
        ],
        help='Which dataset to use. One can add more choices if needed.')
    parser.add_argument('--data_augmentation',
                        '-A',
                        action='store_true',
                        dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--error_prob', '-n', type=float, default=0.0)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--clean_validation',
                        action='store_true',
                        default=False)
    parser.add_argument('--resize_to_imagenet',
                        action='store_true',
                        dest='resize_to_imagenet')
    parser.set_defaults(resize_to_imagenet=False)
    parser.add_argument('--cache_dataset',
                        action='store_true',
                        dest='cache_dataset')
    parser.set_defaults(cache_dataset=False)

    # hyper-parameters
    parser.add_argument('--model_class',
                        '-m',
                        type=str,
                        default='ClassifierL2')

    parser.add_argument('--l2_reg_coef', type=float, default=0.0)
    parser.add_argument('--damping', type=float, default=1e-10)
    parser.add_argument('--scale', type=float, default=10.0)
    parser.add_argument('--recursion_depth', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=128)

    parser.add_argument('--output_dir',
                        '-o',
                        type=str,
                        default='sample_info/results/ground-truth/')
    parser.add_argument('--exp_name', '-E', type=str, required=True)
    args = parser.parse_args()
    print(args)

    # Build data
    train_data, val_data, test_data, _ = load_data_from_arguments(
        args, build_loaders=False)
    if args.cache_dataset:
        train_data = CacheDatasetWrapper(train_data)
        val_data = CacheDatasetWrapper(val_data)
        test_data = CacheDatasetWrapper(test_data)

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)

    model = model_class(input_shape=train_data[0][0].shape,
                        architecture_args=architecture_args,
                        l2_reg_coef=args.l2_reg_coef,
                        seed=args.seed,
                        device=args.device)

    # load the final parameters
    saved_file_path = os.path.join(args.output_dir, 'ground-truth',
                                   args.exp_name, 'full-data-training.pkl')
    with open(saved_file_path, 'rb') as f:
        saved_data = pickle.load(f)

    params = dict(model.named_parameters())
    for k, v in saved_data['weights'].items():
        params[k].data = v.to(args.device)

    # compute per example gradients (d loss / d weights for train and d pred / d weights for validation)
    train_grads = gradients.get_weight_gradients(
        model=model,
        dataset=train_data,
        cpu=args.cpu,
        description='computing per example gradients on train data')

    jacobian_estimator = JacobianEstimator()
    val_grads = jacobian_estimator.compute_jacobian(
        model=model,
        dataset=val_data,
        cpu=args.cpu,
        description='computing jacobian on validation data')

    # compute weight and prediction influences
    weight_vectors = []
    weight_quantities = []

    pred_vectors = []
    pred_quantities = []

    for sample_idx in tqdm(range(len(train_data)),
                           desc='computing influences'):
        # compute weights
        v = []
        for k in dict(model.named_parameters()).keys():
            v.append(train_grads[k][sample_idx].to(model.device))
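        # approximate the inverse Hessian-vector product H^{-1} v with the LiSSA recursion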
        inv_hvp = inverse_hvp_lissa(model,
                                    dataset=train_data,
                                    v=v,
                                    batch_size=args.batch_size,
                                    recursion_depth=args.recursion_depth,
                                    damping=args.damping,
                                    scale=args.scale)
        if args.cpu:
            inv_hvp = [utils.to_cpu(a) for a in inv_hvp]

        for a in inv_hvp:
            if torch.isnan(a).any():
                raise ValueError(
                    "Inverse hessian vector product contains NaNs. Increase the scale."
                )

        cur_weight_influence = 1.0 / len(train_data) * torch.cat(
            [a.flatten() for a in inv_hvp])
        weight_vectors.append(cur_weight_influence)
        weight_quantities.append(torch.sum(cur_weight_influence**2))

        # compute for predictions
        cur_pred_influences = []
        for val_sample_idx in range(len(val_data)):
            val_grad_flat = []
            for k, v in dict(model.named_parameters()).items():
                val_grad_flat.append(val_grads[k][val_sample_idx].flatten())
            val_grad_flat = torch.cat(val_grad_flat, dim=0)
            cur_pred_influences.append(
                torch.dot(cur_weight_influence, val_grad_flat))

        cur_pred_influences = torch.stack(cur_pred_influences)
        pred_vectors.append(cur_pred_influences)
        pred_quantities.append(torch.sum(cur_pred_influences**2))

    # save weights
    meta = {'description': 'weight influence functions', 'args': args}

    exp_dir = os.path.join(args.output_dir, 'influence-functions',
                           args.exp_name)
    process_results(vectors=weight_vectors,
                    quantities=weight_quantities,
                    meta=meta,
                    exp_name='weights',
                    output_dir=exp_dir,
                    train_data=train_data)

    # save preds
    meta = {'description': 'pred influence functions', 'args': args}

    exp_dir = os.path.join(args.output_dir, 'influence-functions',
                           args.exp_name)
    process_results(vectors=pred_vectors,
                    quantities=pred_quantities,
                    meta=meta,
                    exp_name='pred',
                    output_dir=exp_dir,
                    train_data=train_data)
Code example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device',
                        '-d',
                        default='cuda',
                        help='specifies the main device')
    parser.add_argument('--seed', type=int, default=42)

    # data parameters
    parser.add_argument('--dataset', '-D', type=str, default='mnist4vs9')
    parser.add_argument('--data_augmentation',
                        '-A',
                        action='store_true',
                        dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--error_prob', '-n', type=float, default=0.0)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--clean_validation',
                        action='store_true',
                        default=False)
    parser.add_argument('--resize_to_imagenet',
                        action='store_true',
                        dest='resize_to_imagenet')
    parser.set_defaults(resize_to_imagenet=False)
    parser.add_argument('--cache_dataset',
                        action='store_true',
                        dest='cache_dataset')
    parser.set_defaults(cache_dataset=False)

    # hyper-parameters
    parser.add_argument('--model_class',
                        '-m',
                        type=str,
                        default='ClassifierL2')

    parser.add_argument('--l2_reg_coef', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=1e-2, help='Learning rate')

    parser.add_argument(
        '--output_dir',
        '-o',
        type=str,
        default='sample_info/results/data-summarization/orders/')
    parser.add_argument('--exp_name', '-E', type=str, required=True)

    # which measures to compute
    parser.add_argument('--which_measure',
                        '-w',
                        type=str,
                        required=True,
                        choices=['weights-plain', 'predictions'])

    # NTK arguments
    parser.add_argument('--t', '-t', type=int, default=None)
    parser.add_argument('--projection',
                        type=str,
                        default='none',
                        choices=['none', 'random-subset', 'very-sparse'])
    parser.add_argument('--cpu', dest='cpu', action='store_true')
    parser.set_defaults(cpu=False)
    parser.add_argument('--large_model_regime',
                        dest='large_model_regime',
                        action='store_true')
    parser.add_argument('--random_subset_n_select', type=int, default=2000)
    parser.set_defaults(large_model_regime=False)

    args = parser.parse_args()
    print(args)

    # Load data
    train_data, val_data, test_data, _ = load_data_from_arguments(
        args, build_loaders=False)

    if args.cache_dataset:
        train_data = CacheDatasetWrapper(train_data)
        val_data = CacheDatasetWrapper(val_data)
        test_data = CacheDatasetWrapper(test_data)

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)

    model = model_class(input_shape=train_data[0][0].shape,
                        architecture_args=architecture_args,
                        l2_reg_coef=args.l2_reg_coef,
                        device=args.device,
                        seed=args.seed)
    model.eval()
    print("Number of parameters: ", utils.get_num_parameters(model))

    iter_idx = 0
    exclude_indices = []

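    # iteratively exclude 5% of the remaining examples (those with the smallest scores) until 95% of the data is removed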
    while len(exclude_indices) / len(train_data) < 0.95:
        print(f"Computing the order for iteration {iter_idx}")

        # Prepare the needed terms
        cur_train_data = SubsetDataWrapper(train_data,
                                           exclude_indices=exclude_indices)
        n = len(cur_train_data)
        ret = prepare_needed_items(model=model,
                                   train_data=cur_train_data,
                                   test_data=val_data,
                                   projection=args.projection,
                                   cpu=args.cpu)

        quantities = None
        order_file_name = None

        # weights without SGD
        if args.which_measure == 'weights-plain':
            _, quantities = weight_stability(
                t=args.t,
                n=n,
                eta=args.lr / n,
                init_params=ret['init_params'],
                jacobians=ret['train_jacobians'],
                ntk=ret['ntk'],
                init_preds=ret['train_init_preds'],
                Y=ret['train_Y'],
                l2_reg_coef=n * args.l2_reg_coef,
                continuous=False,
                without_sgd=True,
                model=model,
                dataset=cur_train_data,
                large_model_regime=args.large_model_regime,
                return_change_vectors=False)

            order_file_name = f'iter{iter_idx}-weights.pkl'

        # test prediction
        if args.which_measure == 'predictions':
            _, quantities = test_pred_stability(
                t=args.t,
                n=n,
                eta=args.lr / n,
                ntk=ret['ntk'],
                test_train_ntk=ret['test_train_ntk'],
                train_init_preds=ret['train_init_preds'],
                test_init_preds=ret['test_init_preds'],
                train_Y=ret['train_Y'],
                l2_reg_coef=n * args.l2_reg_coef,
                continuous=False)

            order_file_name = f'iter{iter_idx}-predictions.pkl'

        # save the order
        relative_order = np.argsort(
            utils.to_numpy(torch.stack(quantities).flatten()))
        absolute_order = [
            cur_train_data.include_indices[rel_idx]
            for rel_idx in relative_order
        ]
        absolute_order = exclude_indices + absolute_order
        file_path = os.path.join(args.output_dir, args.exp_name,
                                 order_file_name)
        utils.make_path(os.path.dirname(file_path))
        with open(file_path, 'wb') as f:
            pickle.dump(absolute_order, f)

        # remove 5% of the remaining samples
        exclude_count = int(0.05 * len(cur_train_data))
        new_exclude_indices = [
            cur_train_data.include_indices[rel_idx]
            for rel_idx in relative_order[:exclude_count]
        ]
        exclude_indices.extend(new_exclude_indices)
        iter_idx += 1
        print(len(exclude_indices))
Code example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device',
                        '-d',
                        default='cuda',
                        help='specifies the main device')
    parser.add_argument(
        '--all_device_ids',
        nargs='+',
        type=str,
        default=None,
        help=
        "If not None, this list specifies devices for multiple GPU training. "
        "The first device should match with the main device (args.device).")
    parser.add_argument('--batch_size', '-b', type=int, default=2**20)
    parser.add_argument('--epochs', '-e', type=int, default=2000)
    parser.add_argument('--stopping_param', type=int, default=2**20)
    parser.add_argument('--save_iter', '-s', type=int, default=2**20)
    parser.add_argument('--vis_iter', '-v', type=int, default=2**20)
    parser.add_argument('--log_dir',
                        '-l',
                        type=str,
                        default='sample_info/logs/junk')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument(
        '--num_accumulation_steps',
        default=1,
        type=int,
        help='Number of training steps to accumulate before updating weights')

    # data parameters
    parser.add_argument(
        '--dataset',
        '-D',
        type=str,
        default='mnist4vs9',
        choices=[
            'mnist4vs9', 'synthetic', 'cifar10-cat-vs-dog', 'cats-and-dogs'
        ],
        help='Which dataset to use. One can add more choices if needed.')
    parser.add_argument('--data_augmentation',
                        '-A',
                        action='store_true',
                        dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--error_prob', '-n', type=float, default=0.0)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--clean_validation',
                        action='store_true',
                        default=False)
    parser.add_argument('--resize_to_imagenet',
                        action='store_true',
                        dest='resize_to_imagenet')
    parser.set_defaults(resize_to_imagenet=False)
    parser.add_argument('--cache_dataset',
                        action='store_true',
                        dest='cache_dataset')
    parser.set_defaults(cache_dataset=False)
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers in data loaders')
    parser.add_argument('--exclude_index',
                        type=int,
                        default=None,
                        help='Index of an example to remove.')

    # hyper-parameters
    parser.add_argument('--model_class',
                        '-m',
                        type=str,
                        default='ClassifierL2')
    parser.add_argument('--linearized', dest='linearized', action='store_true')
    parser.set_defaults(linearized=False)

    parser.add_argument('--l2_reg_coef', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--optimizer',
                        type=str,
                        default='sgd',
                        choices=['adam', 'sgd'])

    parser.add_argument(
        '--output_dir',
        '-o',
        type=str,
        default='sample_info/results/ground-truth/ground-truth/')
    parser.add_argument('--exp_name', '-E', type=str, required=True)
    args = parser.parse_args()
    print(args)

    # Build data
    train_data, val_data, test_data, _ = load_data_from_arguments(
        args, build_loaders=False)

    # exclude the example
    if args.exclude_index is not None:
        train_data = SubsetDataWrapper(dataset=train_data,
                                       exclude_indices=[args.exclude_index])

    if args.cache_dataset:
        train_data = CacheDatasetWrapper(train_data)
        val_data = CacheDatasetWrapper(val_data)
        test_data = CacheDatasetWrapper(test_data)

    shuffle_train = (args.batch_size * args.num_accumulation_steps <
                     len(train_data))
    train_loader, val_loader, test_loader = get_loaders_from_datasets(
        train_data,
        val_data,
        test_data,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle_train=shuffle_train)

    # Options
    optimization_args = {
        'optimizer': {
            'name': args.optimizer,
            'lr': args.lr,
        }
    }

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)

    model = model_class(input_shape=train_loader.dataset[0][0].shape,
                        architecture_args=architecture_args,
                        l2_reg_coef=args.l2_reg_coef,
                        seed=args.seed,
                        device=args.device)

    # Put the model permanently in eval mode. This ensures that if the network has pretrained BatchNorm
    # layers, their running averages stay fixed.
    utils.put_always_eval_mode(model)

    if args.linearized:
        print("Using a linearized model")
        model = LinearizedModelV2(model=model,
                                  train_data=train_data,
                                  val_data=val_data,
                                  l2_reg_coef=args.l2_reg_coef)

    if args.dataset == 'synthetic':
        model.visualize = (lambda *args, **kwargs: {}
                           )  # no visualization is needed

    metrics_list = [metrics.Accuracy(output_key='pred')]

    training.train(model=model,
                   train_loader=train_loader,
                   val_loader=val_loader,
                   epochs=args.epochs + 1,
                   save_iter=args.save_iter,
                   vis_iter=args.vis_iter,
                   optimization_args=optimization_args,
                   log_dir=args.log_dir,
                   args_to_log=args,
                   metrics=metrics_list,
                   device_ids=args.all_device_ids,
                   num_accumulation_steps=args.num_accumulation_steps)

    params = dict(model.named_parameters())
    for k in params.keys():
        params[k] = utils.to_cpu(params[k])
    val_preds = utils.apply_on_dataset(model=model,
                                       dataset=val_data,
                                       cpu=True,
                                       partition='val',
                                       batch_size=args.batch_size)['pred']
    val_acc = metrics_list[0].value(epoch=args.epochs, partition='val')

    exp_dir = os.path.join(args.output_dir, args.exp_name)

    # if this is the full-dataset run, save params and val_preds; otherwise save results under the excluded
    # index so they can later be compared with the saved full-data weights/predictions
    if args.exclude_index is None:
        file_path = os.path.join(exp_dir, 'full-data-training.pkl')
    else:
        file_path = os.path.join(exp_dir, f'{args.exclude_index}.pkl')

    utils.make_path(os.path.dirname(file_path))
    with open(file_path, 'wb') as f:
        pickle.dump(
            {
                'weights': params,
                'val_preds': val_preds,
                'val_acc': val_acc,
                'args': args
            }, f)
Code example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device',
                        '-d',
                        default='cuda',
                        help='specifies the main device')
    parser.add_argument(
        '--all_device_ids',
        nargs='+',
        type=str,
        default=None,
        help=
        "If not None, this list specifies devices for multiple GPU training. "
        "The first device should match with the main device (args.device).")
    parser.add_argument('--batch_size', '-b', type=int, default=256)
    parser.add_argument('--epochs', '-e', type=int, default=400)
    parser.add_argument('--stopping_param', type=int, default=2**30)
    parser.add_argument('--save_iter', '-s', type=int, default=10)
    parser.add_argument('--vis_iter', '-v', type=int, default=10)
    parser.add_argument('--log_dir', '-l', type=str, default=None)
    parser.add_argument('--seed', type=int, default=42)

    # data parameters
    parser.add_argument('--dataset', '-D', type=str, default='corrupt4_mnist')
    parser.add_argument('--data_augmentation',
                        '-A',
                        action='store_true',
                        dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--error_prob', '-n', type=float, default=0.0)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--clean_validation',
                        action='store_true',
                        default=False)

    # hyper-parameters
    parser.add_argument('--model_class',
                        '-m',
                        type=str,
                        default='ClassifierL2WithGradCollector')

    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--optimizer',
                        type=str,
                        default='adam',
                        choices=['adam', 'sgd'])

    parser.add_argument('--output_dir',
                        '-o',
                        type=str,
                        default='results/stability/mnist-4vs9-1000-samples/')
    args = parser.parse_args()
    print(args)

    # Load data
    # TODO: remove hard coding
    train_data, val_data, test_data, _ = load_data_from_arguments(
        {
            'dataset': 'mnist',
            'num_train_examples': 10 * 500
        },
        build_loaders=False)
    train_data = BinaryDatasetWrapper(train_data, which_labels=(4, 9))
    val_data = BinaryDatasetWrapper(val_data, which_labels=(4, 9))
    test_data = BinaryDatasetWrapper(test_data, which_labels=(4, 9))

    train_data = ReturnSampleIndexWrapper(train_data)
    val_data = ReturnSampleIndexWrapper(val_data)
    test_data = ReturnSampleIndexWrapper(test_data)

    train_loader, val_loader, test_loader = get_loaders_from_datasets(
        train_data,
        val_data,
        test_data,
        batch_size=2**30,
        shuffle_train=False,
        num_workers=0)

    # Options
    optimization_args = {
        'optimizer': {
            'name': args.optimizer,
            'lr': args.lr,
            'weight_decay': args.weight_decay
        }
    }

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    ts = range(100, 401, 100)

    for t in ts:
        model_class = getattr(methods, args.model_class)

        model = model_class(input_shape=train_loader.dataset[0][0][0].shape,
                            architecture_args=architecture_args,
                            device=args.device,
                            seed=args.seed)

        metrics_list = [metrics.Accuracy(output_key='pred')]

        training.train(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            epochs=t,
            save_iter=args.save_iter,
            vis_iter=2**30,  # NOTE: never visualize
            optimization_args=optimization_args,
            log_dir=args.log_dir,
            args_to_log=args,
            metrics=metrics_list,
            device_ids=args.all_device_ids)

        vectors = model._grad_updates

        norms = []
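        # for each training example, sum the L2 norms of its per-parameter gradient updates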
        for i in range(len(train_data)):
            grad_dict = vectors[i]
            norm = 0.0
            for k, v in grad_dict.items():
                norm += torch.norm(v.flatten())
            norms.append(norm)

        quantities = norms

        meta = {
            'description':
            'Total gradient update per example. The measure is the norm of the total gradient update.',
            'time': t,
            'continuous': False,
            'args': args
        }

        process_results(vectors=vectors,
                        quantities=quantities,
                        meta=meta,
                        exp_name=f'total-grad-t{t}',
                        output_dir=args.output_dir,
                        train_data=train_data.dataset)
Code example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device', '-d', default='cuda', help='specifies the main device')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch_size', '-b', type=int, default=256)

    # data parameters
    parser.add_argument('--dataset', '-D', type=str, default='mnist4vs9')
    parser.add_argument('--data_augmentation', '-A', action='store_true', dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--error_prob', '-n', type=float, default=0.0)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--clean_validation', action='store_true', default=False)
    parser.add_argument('--resize_to_imagenet', action='store_true', dest='resize_to_imagenet')
    parser.set_defaults(resize_to_imagenet=False)
    parser.add_argument('--cache_dataset', action='store_true', dest='cache_dataset')
    parser.set_defaults(cache_dataset=False)

    # hyper-parameters
    parser.add_argument('--model_class', '-m', type=str, default='ClassifierL2')

    parser.add_argument('--l2_reg_coef', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=1e-2, help='Learning rate')

    parser.add_argument('--output_dir', '-o', type=str, default='sample_info/results/ground-truth/informativeness')
    parser.add_argument('--exp_name', '-E', type=str, required=True)

    # which measures to compute
    parser.add_argument('--which_measures', '-w', type=str, nargs='+', required=True,
                        help="Options are 'weights-full', 'weights-plain', and 'predictions'")

    # NTK arguments
    parser.add_argument('--t', '-t', type=int, default=None)
    parser.add_argument('--projection', type=str, default='none', choices=['none', 'random-subset', 'very-sparse'])
    parser.add_argument('--cpu', dest='cpu', action='store_true')
    parser.set_defaults(cpu=False)
    parser.add_argument('--large_model_regime', dest='large_model_regime', action='store_true')
    parser.set_defaults(large_model_regime=False)
    parser.add_argument('--random_subset_n_select', type=int, default=2000)
    parser.add_argument('--return_change_vectors', dest='return_change_vectors', action='store_true')
    parser.set_defaults(return_change_vectors=False)

    args = parser.parse_args()
    print(args)

    # Load data
    train_data, val_data, test_data, _ = load_data_from_arguments(args, build_loaders=False)
    if args.cache_dataset:
        train_data = CacheDatasetWrapper(train_data)
        val_data = CacheDatasetWrapper(val_data)
        test_data = CacheDatasetWrapper(test_data)

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)

    model = model_class(input_shape=train_data[0][0].shape,
                        architecture_args=architecture_args,
                        l2_reg_coef=args.l2_reg_coef,
                        device=args.device,
                        seed=args.seed)
    model.eval()
    print("Number of parameters: ", utils.get_num_parameters(model))

    # Prepare the needed terms
    ret = prepare_needed_items(model=model, train_data=train_data, test_data=val_data,
                               projection=args.projection, cpu=args.cpu, batch_size=args.batch_size,
                               random_subset_n_select=args.random_subset_n_select)
    n = len(train_data)

    exp_dir = os.path.join(args.output_dir, args.exp_name)

    # weights with SGD
    if 'weights-full' in args.which_measures:
        vectors, quantities = weight_stability(t=args.t, n=n, eta=args.lr / n, init_params=ret['init_params'],
                                               jacobians=ret['train_jacobians'], ntk=ret['ntk'],
                                               init_preds=ret['train_init_preds'], Y=ret['train_Y'],
                                               l2_reg_coef=n * args.l2_reg_coef, continuous=False,
                                               without_sgd=False, model=model, dataset=train_data,
                                               large_model_regime=args.large_model_regime,
                                               return_change_vectors=False,
                                               batch_size=args.batch_size)

        meta = {
            'description': f'weights (full) at epoch {args.t}',
            'continuous': False,
            'args': args
        }

        process_results(vectors=vectors, quantities=quantities, meta=meta,
                        exp_name=f'weight-full-t{args.t}', output_dir=exp_dir, train_data=train_data)

    # weights without SGD
    if 'weights-plain' in args.which_measures:
        vectors, quantities = weight_stability(t=args.t, n=n, eta=args.lr / n, init_params=ret['init_params'],
                                               jacobians=ret['train_jacobians'], ntk=ret['ntk'],
                                               init_preds=ret['train_init_preds'], Y=ret['train_Y'],
                                               l2_reg_coef=n * args.l2_reg_coef, continuous=False,
                                               without_sgd=True, model=model, dataset=train_data,
                                               large_model_regime=args.large_model_regime,
                                               return_change_vectors=args.return_change_vectors,
                                               batch_size=args.batch_size)

        meta = {
            'description': f'weights (plain) at epoch {args.t}',
            'continuous': False,
            'args': args
        }

        process_results(vectors=vectors, quantities=quantities, meta=meta,
                        exp_name=f'weight-plain-t{args.t}', output_dir=exp_dir, train_data=train_data)

    # test prediction
    if 'predictions' in args.which_measures:
        vectors, quantities = test_pred_stability(t=args.t, n=n, eta=args.lr / n, ntk=ret['ntk'],
                                                  test_train_ntk=ret['test_train_ntk'],
                                                  train_init_preds=ret['train_init_preds'],
                                                  test_init_preds=ret['test_init_preds'],
                                                  train_Y=ret['train_Y'],
                                                  l2_reg_coef=n * args.l2_reg_coef,
                                                  continuous=False)
        meta = {
            'description': f'validation predictions at epoch {args.t}',
            'continuous': False,
            'args': args
        }

        process_results(vectors=vectors, quantities=quantities, meta=meta,
                        exp_name=f'predictions-t{args.t}', output_dir=exp_dir, train_data=train_data)
Code example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device',
                        '-d',
                        default='cuda',
                        help='specifies the main device')
    parser.add_argument('--seed', type=int, default=42)

    # data parameters
    parser.add_argument(
        '--dataset',
        '-D',
        type=str,
        default='mnist4vs9',
        choices=['mnist4vs9', 'synthetic', 'cifar10-cat-vs-dog'],
        help='Which dataset to use. One can add more choices if needed.')
    parser.add_argument('--data_augmentation',
                        '-A',
                        action='store_true',
                        dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--error_prob', '-n', type=float, default=0.0)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--clean_validation',
                        action='store_true',
                        default=False)
    parser.add_argument('--resize_to_imagenet',
                        action='store_true',
                        dest='resize_to_imagenet')
    parser.set_defaults(resize_to_imagenet=False)
    parser.add_argument('--cache_dataset',
                        action='store_true',
                        dest='cache_dataset')
    parser.set_defaults(cache_dataset=False)

    # hyper-parameters
    parser.add_argument('--model_class',
                        '-m',
                        type=str,
                        default='ClassifierL2')

    parser.add_argument('--l2_reg_coef', type=float, default=0.0)

    parser.add_argument('--output_dir',
                        '-o',
                        type=str,
                        default='sample_info/results/ground-truth/')
    parser.add_argument('--exp_name', '-E', type=str, required=True)
    args = parser.parse_args()
    print(args)

    # Build data
    train_data, val_data, test_data, _ = load_data_from_arguments(
        args, build_loaders=False)
    if args.cache_dataset:
        train_data = CacheDatasetWrapper(train_data)
        val_data = CacheDatasetWrapper(val_data)
        test_data = CacheDatasetWrapper(test_data)

    train_loader, val_loader, test_loader = get_loaders_from_datasets(
        train_data, val_data, test_data, batch_size=2**30, shuffle_train=False)

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)

    model = model_class(input_shape=train_data[0][0].shape,
                        architecture_args=architecture_args,
                        l2_reg_coef=args.l2_reg_coef,
                        seed=args.seed,
                        device=args.device)

    # load the final parameters
    saved_file_path = os.path.join(args.output_dir, 'ground-truth',
                                   args.exp_name, 'full-data-training.pkl')
    with open(saved_file_path, 'rb') as f:
        saved_data = pickle.load(f)

    params = dict(model.named_parameters())
    for k, v in saved_data['weights'].items():
        params[k].data = v.to(args.device)

    # brute-force computation of the Hessian and its inverse
    total_loss = 0.0
    for x, y in train_loader:
        out = model.forward(inputs=[x], labels=[y])
        losses, _ = model.compute_loss(inputs=[x], labels=[y], outputs=out)
        total_loss = total_loss + sum([v for k, v in losses.items()])

    with utils.Timing(description='Computing the Hessian'):
        H = hessian(ys=[total_loss], xs=tuple(model.parameters()))

    params = tuple(model.parameters())
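    # flatten the block structure of H into a single (total_params x total_params) matrix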
    for i in range(len(H)):
        for j in range(len(H[i])):
            ni = params[i].nelement()
            nj = params[j].nelement()
            H[i][j] = H[i][j].reshape((ni, nj))
        H[i] = torch.cat(H[i], dim=1)
    H = torch.cat(H, dim=0)
    # add extra eps to the diagonal to make it invertible
    if args.l2_reg_coef < 1e-10:
        H += 1e-10 * torch.eye(H.shape[0], dtype=torch.float, device=H.device)
    print(f"Hessian shape: {H.shape}")
    H_inv = torch.inverse(H)

    # compute per example gradients (d loss / d weights for train and d pred / d weights for validation)
    train_grads = gradients.get_weight_gradients(
        model=model,
        dataset=train_data,
        cpu=False,
        description='computing per example gradients on train data')

    jacobian_estimator = JacobianEstimator()
    val_grads = jacobian_estimator.compute_jacobian(
        model=model,
        dataset=val_data,
        cpu=False,
        description='computing jacobian on validation data')

    # compute weight and prediction influences
    weight_vectors = []
    weight_quantities = []

    pred_vectors = []
    pred_quantities = []

    for sample_idx in tqdm(range(len(train_data)),
                           desc='computing influences'):
        # compute for weights
        train_grad_flat = []
        for k, v in dict(model.named_parameters()).items():
            train_grad_flat.append(train_grads[k][sample_idx].flatten())
        train_grad_flat = torch.cat(train_grad_flat, dim=0)

        cur_weight_influence = 1.0 / len(train_data) * torch.mm(
            H_inv, train_grad_flat.view((-1, 1)))
        cur_weight_influence = cur_weight_influence.view((-1, ))
        weight_vectors.append(cur_weight_influence)
        weight_quantities.append(torch.sum(cur_weight_influence**2))

        # compute for predictions
        cur_pred_influences = []
        for val_sample_idx in range(len(val_data)):
            val_grad_flat = []
            for k, v in dict(model.named_parameters()).items():
                val_grad_flat.append(val_grads[k][val_sample_idx].flatten())
            val_grad_flat = torch.cat(val_grad_flat, dim=0)
            cur_pred_influences.append(
                torch.dot(cur_weight_influence, val_grad_flat))

        cur_pred_influences = torch.stack(cur_pred_influences)
        pred_vectors.append(cur_pred_influences)
        pred_quantities.append(torch.sum(cur_pred_influences**2))

    # save weights
    meta = {'description': 'weight influence functions', 'args': args}

    exp_dir = os.path.join(args.output_dir, 'influence-functions-brute-force',
                           args.exp_name)
    process_results(vectors=weight_vectors,
                    quantities=weight_quantities,
                    meta=meta,
                    exp_name='weights',
                    output_dir=exp_dir,
                    train_data=train_data)

    # save preds
    meta = {'description': 'pred influence functions', 'args': args}

    exp_dir = os.path.join(args.output_dir, 'influence-functions-brute-force',
                           args.exp_name)
    process_results(vectors=pred_vectors,
                    quantities=pred_quantities,
                    meta=meta,
                    exp_name='pred',
                    output_dir=exp_dir,
                    train_data=train_data)