Code Example #1
File: adv_train.py  Project: jtx1999/perceptual-advex
            # Excerpt begins mid-conditional: earlier branches set defaults
            # for other datasets. Any hyperparameter not given on the
            # command line receives a per-dataset default here.
            if args.num_epochs is None:
                args.num_epochs = 100
        elif (args.dataset.startswith('imagenet')
              or args.dataset == 'bird_or_bicycle'):
            if args.lr is None:
                args.lr = 1e-1
            if args.lr_schedule is None:
                args.lr_schedule = '30,60,80'
            if args.num_epochs is None:
                args.num_epochs = 90

    # Seed every source of randomness so training runs are reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    dataset, model = get_dataset_model(args)
    if isinstance(model, FeatureModel):
        # Feature models are frozen by default; unfreeze them for training.
        model.allow_train()
    if torch.cuda.is_available():
        model.cuda()

    if args.lpips_model is not None:
        # Load a separate model whose features define the LPIPS distance
        # used by the perceptual attacks.
        _, lpips_model = get_dataset_model(args,
                                           checkpoint_fname=args.lpips_model)
        if torch.cuda.is_available():
            lpips_model.cuda()

    train_loader, val_loader = dataset.make_loaders(workers=4,
                                                    batch_size=args.batch_size)

    # Each --attack argument is a Python constructor expression; eval()
    # instantiates it. This trusts the command line, so only pass
    # expressions you control.
    attacks = [eval(attack_str) for attack_str in args.attack]
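
The eval-based construction above expects each --attack argument to be a
Python constructor expression. A minimal, self-contained sketch of the
pattern follows; NoOpAttack is a hypothetical stand-in for the project's
real attack classes, not part of the repository:

import argparse

from torch import nn

class NoOpAttack(nn.Module):
    """Hypothetical attack that returns its inputs unchanged."""
    def __init__(self, model, eps=0.0):
        super().__init__()
        self.model = model
        self.eps = eps

    def forward(self, inputs, labels):
        # A real attack would return adversarially perturbed inputs.
        return inputs

parser = argparse.ArgumentParser()
parser.add_argument('--attack', type=str, nargs='+',
                    default=['NoOpAttack(model, eps=0.1)'])
args = parser.parse_args([])  # empty argv, so the default is used

model = nn.Linear(4, 2)  # placeholder model
# eval() turns each expression string into a live attack object.
attacks = [eval(attack_str) for attack_str in args.attack]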
Code Example #2
    # Excerpt begins inside the __main__ block, after the ArgumentParser
    # has been created. Requires: import argparse, torch; from torch
    # import nn; from typing import List.
    add_dataset_model_arguments(parser, include_checkpoint=True)
    parser.add_argument('attacks', metavar='attack', type=str, nargs='+',
                        help='attack names')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='number of examples/minibatch')
    parser.add_argument('--parallel', type=int, default=1,
                        help='number of GPUs to evaluate on')
    parser.add_argument('--num_batches', type=int, required=False,
                        help='number of batches (default entire dataset)')
    parser.add_argument('--per_example', action='store_true', default=False,
                        help='output per-example accuracy')
    parser.add_argument('--output', type=str, help='output CSV')

    args = parser.parse_args()

    dataset, model = get_dataset_model(args)
    # Only the validation split is needed for evaluation.
    _, val_loader = dataset.make_loaders(1, args.batch_size, only_val=True)

    model.eval()
    if torch.cuda.is_available():
        model.cuda()

    attack_names: List[str] = args.attacks
    # As in the training script, each attack is given as a constructor
    # expression and instantiated with eval().
    attacks = [eval(attack_name) for attack_name in attack_names]

    # Replicate the model and each attack across args.parallel GPUs.
    # The attacks are nn.Modules, so they can be wrapped in DataParallel
    # just like the model.
    if torch.cuda.is_available():
        device_ids = list(range(args.parallel))
        model = nn.DataParallel(model, device_ids)
        attacks = [nn.DataParallel(attack, device_ids) for attack in attacks]
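
Wrapping each attack in nn.DataParallel works because the attacks are
nn.Module subclasses whose forward(inputs, labels) returns an adversarial
batch. A sketch of such an attack, using a generic FGSM step rather than
any attack class from the repository:

import torch
from torch import nn
import torch.nn.functional as F

class FGSMAttack(nn.Module):
    """DataParallel-friendly attack module (assumed interface)."""
    def __init__(self, model, eps=8 / 255):
        super().__init__()
        self.model = model
        self.eps = eps

    def forward(self, inputs, labels):
        inputs = inputs.clone().detach().requires_grad_(True)
        loss = F.cross_entropy(self.model(inputs), labels)
        grad, = torch.autograd.grad(loss, inputs)
        # One signed-gradient step, clamped back to the valid image range.
        return (inputs + self.eps * grad.sign()).clamp(0, 1).detach()

Because the attack is a module, nn.DataParallel can scatter each batch
across GPUs and run the attack on every replica in parallel.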
Code Example #3
import argparse
import copy

import torch

# get_dataset_model, add_dataset_model_arguments, and DATASETS are
# helpers from the perceptual-advex package.

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Common corruptions evaluation')

    add_dataset_model_arguments(parser, include_checkpoint=True)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--num_batches',
                        type=int,
                        required=False,
                        help='number of batches (default entire dataset)')
    parser.add_argument('--output', type=str, help='output CSV')

    args = parser.parse_args()

    # Load the model under evaluation and look up the corruption dataset.
    _, model = get_dataset_model(args)
    dataset_cls = DATASETS[args.dataset]

    # Build an AlexNet baseline on the same dataset; common-corruption
    # benchmarks report errors normalized by AlexNet's (see the mCE
    # sketch below).
    alexnet_args = copy.deepcopy(args)
    alexnet_args.arch = 'alexnet'
    alexnet_args.checkpoint = None
    if args.dataset == 'cifar10c':
        alexnet_checkpoint_fname = 'data/checkpoints/alexnet_cifar.pt'
    elif args.dataset == 'imagenet100c':
        alexnet_checkpoint_fname = 'data/checkpoints/alexnet_imagenet100.pt'
    else:
        raise ValueError(f'Invalid dataset "{args.dataset}"')
    _, alexnet = get_dataset_model(alexnet_args,
                                   checkpoint_fname=alexnet_checkpoint_fname)

    model.eval()
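
The AlexNet baseline loaded above is the standard reference point for
common-corruption benchmarks: a model's error on each corruption is
divided by AlexNet's error on the same corruption, and the average ratio
is the mCE of Hendrycks & Dietterich. A sketch of that normalization,
assuming per-corruption error rates have already been computed (the
numbers below are made up):

def mean_corruption_error(model_errors, alexnet_errors):
    """Average per-corruption error, normalized by the AlexNet baseline."""
    ratios = [model_errors[c] / alexnet_errors[c] for c in model_errors]
    return sum(ratios) / len(ratios)

model_errors = {'gaussian_noise': 0.30, 'fog': 0.25}
alexnet_errors = {'gaussian_noise': 0.60, 'fog': 0.50}
print(mean_corruption_error(model_errors, alexnet_errors))  # 0.5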