Ejemplo n.º 1
0
def select_model(m):
    """Build a CIFAR network and place it on the GPU.

    Args:
        m: model name. ``'large'`` selects ``pblm.cifar_model_large``;
            every other value falls back to ``pblm.cifar_model``.

    Returns:
        The freshly constructed network after ``.cuda()`` has been applied.
    """
    # Unrecognized names do not raise; they silently get the default model.
    builder = pblm.cifar_model_large if m == 'large' else pblm.cifar_model
    return builder().cuda()
Ejemplo n.º 2
0
def select_model(m):
    """Return a freshly built CIFAR network living on the GPU.

    ``m`` chooses the architecture:
      * ``'large'``  -> ``pblm.cifar_model_large()``
      * ``'resnet'`` -> ``pblm.cifar_model_resnet`` sized by the global
        ``args.resnet_N`` and ``args.resnet_factor``
      * anything else -> ``pblm.cifar_model()`` (unknown names do not raise)
    """
    if m == 'large':
        return pblm.cifar_model_large().cuda()
    if m == 'resnet':
        resnet = pblm.cifar_model_resnet(N=args.resnet_N,
                                         factor=args.resnet_factor)
        return resnet.cuda()
    return pblm.cifar_model().cuda()
Ejemplo n.º 3
0
def select_model(m):
    """Build the requested CIFAR network and place it on the GPU.

    Args:
        m: model name, either ``'small'`` (``pblm.cifar_model``) or
            ``'large'`` (``pblm.cifar_model_large``).

    Returns:
        The constructed network after ``.cuda()`` has been applied.

    Raises:
        ValueError: if ``m`` is not one of the supported names. The message
            includes the rejected value to make CLI typos easy to spot.
    """
    if m == 'small':
        model = pblm.cifar_model().cuda()
    elif m == 'large':
        model = pblm.cifar_model_large().cuda()
    else:
        # NOTE: a ResNet option is intentionally not offered by this variant.
        raise ValueError(
            "model argument {!r} not recognized for cifar "
            "(expected 'small' or 'large')".format(m))
    return model
def select_model(m):
    """Construct a CIFAR network on ``device`` and run ``summary`` on it.

    Args:
        m: model name. ``'large'`` -> ``pblm.cifar_model_large``;
            ``'resnet'`` -> ``pblm.cifar_model_resnet`` configured from the
            global ``args.resnet_N`` / ``args.resnet_factor``; anything
            else -> ``pblm.cifar_model``.

    Returns:
        The network after ``.to(device)``.
    """
    if m == 'large':
        net = pblm.cifar_model_large()
    elif m == 'resnet':
        net = pblm.cifar_model_resnet(N=args.resnet_N,
                                      factor=args.resnet_factor)
    else:
        net = pblm.cifar_model()
    net = net.to(device)

    # NOTE(review): (3, 32, 32) is presumably channels x height x width for a
    # single CIFAR image; `summary` is defined elsewhere (likely torchsummary)
    # -- confirm it runs on the same device as `device`.
    summary(net, (3, 32, 32))
    return net
Ejemplo n.º 5
0
def select_model(m):
    """Instantiate the requested CIFAR network and move it onto the GPU.

    Args:
        m: model name. ``'large'``, ``'resnet'``, ``'m1'`` and ``'m2'`` pick
            specific ``pblm`` constructors (``'m1'``/``'m2'`` also print a
            notice first); any other value falls back to
            ``pblm.cifar_model()``.

    Returns:
        The network after ``.cuda()`` has been applied.
    """
    if m == 'large':
        net = pblm.cifar_model_large()
    elif m == 'resnet':
        # ResNet depth/width come from the global ``args`` namespace.
        net = pblm.cifar_model_resnet(N=args.resnet_N,
                                      factor=args.resnet_factor)
    elif m == 'm1':
        print('using a reduced sized network')
        net = pblm.cifar_model_m1()
    elif m == 'm2':
        print('using a slightly reduced sized network')
        net = pblm.cifar_model_m2()
    else:
        net = pblm.cifar_model()
    # Every branch ends with the same device move, so do it once here.
    return net.cuda()
Ejemplo n.º 6
0
    # Boolean switch checked in the dataset dispatch below.
    parser.add_argument('--fashion', action='store_true')
    # Free-form model name (None when omitted); only 'cifar' is acted on below.
    parser.add_argument('--model')

    args = parser.parse_args()

    # Dataset dispatch: build train/test loaders, instantiate the matching
    # architecture on `device`, and load pretrained weights from disk.
    # NOTE(review): torch.load is called without map_location; if these
    # checkpoints contain CUDA tensors and `device` is the CPU, loading may
    # fail -- consider torch.load(path, map_location=device).
    if args.mnist: 
        train_loader, test_loader = pblm.mnist_loaders(args.batch_size)
        model = pblm.mnist_model().to(device)
        # This file is passed straight to load_state_dict (no 'state_dict' key).
        model.load_state_dict(torch.load('icml/mnist_epochs_100_baseline_model.pth'))
    elif args.svhn: 
        train_loader, test_loader = pblm.svhn_loaders(args.batch_size)
        model = pblm.svhn_model().to(device)
        # Checkpoint dict; the weights are stored under its 'state_dict' key.
        model.load_state_dict(torch.load('pixel2/svhn_small_batch_size_50_epochs_100_epsilon_0.0078_l1_proj_50_l1_test_median_l1_train_median_lr_0.001_opt_adam_schedule_length_20_seed_0_starting_epsilon_0.001_checkpoint.pth')['state_dict'])
    elif args.model == 'cifar': 
        train_loader, test_loader = pblm.cifar_loaders(args.batch_size)
        model = pblm.cifar_model().to(device)
        # Checkpoint dict; the weights are stored under its 'state_dict' key.
        model.load_state_dict(torch.load('pixel2/cifar_small_batch_size_50_epochs_100_epsilon_0.0347_l1_proj_50_l1_test_median_l1_train_median_lr_0.05_momentum_0.9_opt_sgd_schedule_length_20_seed_0_starting_epsilon_0.001_weight_decay_0.0005_checkpoint.pth')['state_dict'])
    elif args.har:
        # NOTE(review): placeholder -- no loaders or model are created, so
        # unless `model` was bound earlier (not visible here) the
        # `model.parameters()` loop below raises NameError/UnboundLocalError.
        pass
    elif args.fashion: 
        # NOTE(review): same unimplemented placeholder as the har branch.
        pass
    else:
        raise ValueError("Need to specify which problem.")
    # Freeze every weight so autograd stops tracking gradients for the model
    # parameters (presumably only input gradients are needed later -- confirm).
    for p in model.parameters(): 
        p.requires_grad = False

    # Width of the final layer; assumes `model` is indexable (e.g.
    # nn.Sequential) and ends in a layer exposing `out_features` such as
    # nn.Linear -- TODO confirm for every branch.
    num_classes = model[-1].out_features

    # Accumulators presumably filled by code past this fragment (not visible).
    correct = []
    incorrect = []
    l = []
Ejemplo n.º 7
0
    # Dataset dispatch: build train/test loaders, instantiate the matching
    # architecture on the GPU, and load pretrained weights from disk.
    if args.mnist:
        train_loader, test_loader = pblm.mnist_loaders(args.batch_size)
        model = pblm.mnist_model().cuda()
        # This file is passed straight to load_state_dict (no 'state_dict' key).
        model.load_state_dict(
            torch.load('icml/mnist_epochs_100_baseline_model.pth'))
    elif args.svhn:
        train_loader, test_loader = pblm.svhn_loaders(args.batch_size)
        model = pblm.svhn_model().cuda()
        # Checkpoint dict; the weights are stored under its 'state_dict' key.
        model.load_state_dict(
            torch.load(
                'pixel2/svhn_small_batch_size_50_epochs_100_epsilon_0.0078_l1_proj_50_l1_test_median_l1_train_median_lr_0.001_opt_adam_schedule_length_20_seed_0_starting_epsilon_0.001_checkpoint.pth'
            )['state_dict'])
    elif args.model == 'cifar':
        train_loader, test_loader = pblm.cifar_loaders(args.batch_size)
        model = pblm.cifar_model().cuda()
        # Checkpoint dict; the weights are stored under its 'state_dict' key.
        model.load_state_dict(
            torch.load(
                'pixel2/cifar_small_batch_size_50_epochs_100_epsilon_0.0347_l1_proj_50_l1_test_median_l1_train_median_lr_0.05_momentum_0.9_opt_sgd_schedule_length_20_seed_0_starting_epsilon_0.001_weight_decay_0.0005_checkpoint.pth'
            )['state_dict'])
    elif args.har:
        # NOTE(review): placeholder -- no loaders or model are created, so
        # unless `model` was bound earlier (not visible here) the
        # `model.parameters()` loop below raises NameError/UnboundLocalError.
        pass
    elif args.fashion:
        # NOTE(review): same unimplemented placeholder as the har branch.
        pass
    else:
        raise ValueError("Need to specify which problem.")
    # Freeze every weight so autograd stops tracking gradients for the model
    # parameters (presumably only input gradients are needed later -- confirm).
    for p in model.parameters():
        p.requires_grad = False

    # Width of the final layer; assumes `model` is indexable (e.g.
    # nn.Sequential) and ends in a layer exposing `out_features` such as
    # nn.Linear -- TODO confirm for every branch.
    num_classes = model[-1].out_features