Example #1
File: loading.py  Project: renebidart/evol
def net_from_args(args, num_classes, IM_SIZE):
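    """Build the requested network together with a file-name prefix such as 'resnet-<depth>'."""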
    if (args.net_type == 'PreActResNet18'):
        net = PreActResNet18()
        file_name = 'PreActResNet18-'
    elif (args.net_type == 'resnet'):
        net = ResNet(args.depth, num_classes, IM_SIZE)
        file_name = 'resnet-' + str(args.depth)
    elif (args.net_type == 'wide-resnet'):
        net = Wide_ResNet(args.depth, args.widen_factor, args.dropout,
                          num_classes, IM_SIZE)
        file_name = 'wide-resnet-' + str(args.depth) + 'x' + str(
            args.widen_factor)
    elif (args.net_type == 'SimpleNetMNIST'):
        net = SimpleNetMNIST()
        file_name = 'SimpleNetMNIST'
    else:
        print('Error : Wrong net type')
        sys.exit(1)
    return net, file_name


# def load_net_cifar(model_loc):
#     """ Make a model
#     Network must be saved in the form model_name-depth, where this is a unique identifier
#     """
#     model_file = Path(model_loc).name
#     model_name = model_file.split('-')[0]
#     print('Loading model_file', model_file)
#     if (model_name == 'vggnet'):
#         model = VGG(int(model_file.split('-')[1]), 10)
#     elif (model_name == 'resnet'):
#         model = ResNet(int(model_file.split('-')[1]), 10)
#     # so ugly
#     elif (model_name == 'preact_resnet'):
#         if model_file.split('/')[-1].split('_')[2] == 'model':
#             model = PreActResNet(int(model_file.split('-')[1].split('_')[0]), 10)
#         else:
#             model = PResNetReg(int(model_file.split('-')[1]), float(model_file.split('-')[2]), 1, 10)

#     elif (model_name == 'wide'):
#         model = Wide_ResNet(model_file.split('-')[2][0:2], model_file.split('-')[2][2:4], 0, 10, 32)

#     # Dumb ones
#     elif (model_name == 'PResNetRegNoRelU'):
#         model = PResNetRegNoRelU(int(model_file.split('-')[1]), float(model_file.split('-')[2]), 1, 10)

#     else:
#         print(f'Error : {model_file} not found')
#         sys.exit(0)
#     model.load_state_dict(torch.load(model_loc)['state_dict'])
#     return model
Example #2
def get_net(network: str, num_classes) -> torch.nn.Module:
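    """Return a freshly constructed network matching `network`, or None if the name is unknown."""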
    return VGG('VGG16', num_classes=num_classes) if network == 'VGG16' else \
        ResNet34(num_classes=num_classes) if network == 'ResNet34' else \
        PreActResNet18(num_classes=num_classes) if network == 'PreActResNet18' else \
        GoogLeNet(num_classes=num_classes) if network == 'GoogLeNet' else \
        densenet_cifar(num_classes=num_classes) if network == 'densenet_cifar' else \
        ResNeXt29_2x64d(num_classes=num_classes) if network == 'ResNeXt29_2x64d' else \
        MobileNet(num_classes=num_classes) if network == 'MobileNet' else \
        MobileNetV2(num_classes=num_classes) if network == 'MobileNetV2' else \
        DPN92(num_classes=num_classes) if network == 'DPN92' else \
        ShuffleNetG2(num_classes=num_classes) if network == 'ShuffleNetG2' else \
        SENet18(num_classes=num_classes) if network == 'SENet18' else \
        ShuffleNetV2(1, num_classes=num_classes) if network == 'ShuffleNetV2' else \
        EfficientNetB0(
            num_classes=num_classes) if network == 'EfficientNetB0' else None
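
The chain of conditional expressions in Example #2 is really a name-to-constructor table, and it silently returns None for unknown names. A dict of factory callables expresses the same mapping more directly; the sketch below is illustrative only, assumes torch and the same model constructors are imported exactly as in the snippet above, and repeats just a few of the entries.

# Sketch of a dict-based alternative to the conditional-expression chain above.
# Assumes VGG, ResNet34, PreActResNet18, ShuffleNetV2, EfficientNetB0, ... are
# importable as in Example #2; only a subset of the networks is listed here.
_FACTORIES = {
    'VGG16': lambda n: VGG('VGG16', num_classes=n),
    'ResNet34': lambda n: ResNet34(num_classes=n),
    'PreActResNet18': lambda n: PreActResNet18(num_classes=n),
    'ShuffleNetV2': lambda n: ShuffleNetV2(1, num_classes=n),
    'EfficientNetB0': lambda n: EfficientNetB0(num_classes=n),
}

def get_net(network: str, num_classes) -> torch.nn.Module:
    try:
        return _FACTORIES[network](num_classes)
    except KeyError:
        raise ValueError('Unknown network: {}'.format(network))
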
Example #3
    def _get_model(self, backbone):
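        """Instantiate the requested backbone on self.args.device, falling back to ResNet-50 if the name is not recognised."""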
        if backbone == 'resnet18':
            model = resnet18(pretrained=True, num_classes=self.args.classnum).to(self.args.device)
        elif backbone == 'resnet34':
            model = resnet34(pretrained=True, num_classes=self.args.classnum).to(self.args.device)
        elif backbone == 'resnet50':
            model = resnet50(pretrained=True, num_classes=self.args.classnum).to(self.args.device)
        elif backbone == 'resnet101':
            model = resnet101(pretrained=True, num_classes=self.args.classnum).to(self.args.device)
        elif backbone == 'resnet152':
            model = resnet152(pretrained=True, num_classes=self.args.classnum).to(self.args.device)
        elif backbone == 'preact_resnet18':
            model = PreActResNet18(num_classes=self.args.classnum, input_size=self.args.image_size,
                                   input_dim=self.args.input_dim).to(self.args.device)
        elif backbone == 'preact_resnet34':
            model = PreActResNet34(num_classes=self.args.classnum, input_size=self.args.image_size,
                                   input_dim=self.args.input_dim).to(self.args.device)
        elif backbone == 'preact_resnet50':
            model = PreActResNet50(num_classes=self.args.classnum, input_size=self.args.image_size,
                                   input_dim=self.args.input_dim).to(self.args.device)
        elif backbone == 'preact_resnet101':
            model = PreActResNet101(num_classes=self.args.classnum, input_size=self.args.image_size,
                                    input_dim=self.args.input_dim).to(self.args.device)
        elif backbone == 'preact_resnet152':
            model = PreActResNet152(num_classes=self.args.classnum, input_size=self.args.image_size,
                                    input_dim=self.args.input_dim).to(self.args.device)
        elif backbone == 'densenet121':
            model = densenet121(num_classes=self.args.classnum, pretrained=True).to(self.args.device)
        elif backbone == 'densenet161':
            model = densenet161(num_classes=self.args.classnum, pretrained=True).to(self.args.device)
        elif backbone == 'densenet169':
            model = densenet169(num_classes=self.args.classnum, pretrained=True).to(self.args.device)
        elif backbone == 'densenet201':
            model = densenet201(num_classes=self.args.classnum, pretrained=True).to(self.args.device)
        elif backbone == 'mlp':
            model = MLPNet().to(self.args.device)
        elif backbone == 'cnn_small' or backbone == "CNN_SMALL":
            model = CNN_small(self.args.classnum).to(self.args.device)
        elif backbone == "cnn" or backbone == "CNN":
            model = CNN(n_outputs=self.args.classnum, input_channel=self.args.input_dim, linear_num=self.args.linear_num).to(self.args.device)
        else:
            print("No matched backbone. Using ResNet50...")
            model = resnet50(pretrained=True, num_classes=self.args.classnum,
                             input_size=self.args.image_size).to(self.args.device)

        return model
Example #4
 def select(self, model, args):
     """
     Selector utility to create models from model directory
     :param model: which model to select. Current choices are: (cnn | resnet | preact_resnet | densenet | wresnet)
     :param args: parsed arguments carrying the architecture hyperparameters (activation, resdepth, widen_factor, ...)
     :return: neural network to be trained
     """
     if model == 'cnn':
         net = SimpleModel(in_shape=self.in_shape,
                           activation=args.activation,
                           num_classes=self.num_classes,
                           filters=args.filters,
                           strides=args.strides,
                           kernel_sizes=args.kernel_sizes,
                           linear_widths=args.linear_widths,
                           use_batch_norm=args.use_batch_norm)
     else:
         assert (args.dataset != 'MNIST' and args.dataset != 'Fashion-MNIST'), \
             "Cannot use resnet or densenet for mnist style data"
         if model == 'resnet':
             assert args.resdepth in [18, 34, 50, 101, 152], \
                 "Non-standard and unsupported resnet depth ({})".format(args.resdepth)
             if args.resdepth == 18:
                 net = ResNet18(self.num_classes)
             elif args.resdepth == 34:
                 net = ResNet34(self.num_classes)
             elif args.resdepth == 50:
                 net = ResNet50(self.num_classes)
             elif args.resdepth == 101:
                 net = ResNet101(self.num_classes)
             else:
                  net = ResNet152(self.num_classes)
         elif model == 'densenet':
             assert args.resdepth in [121, 161, 169, 201], \
                 "Non-standard and unsupported densenet depth ({})".format(args.resdepth)
             if args.resdepth == 121:
                 net = DenseNet121(
                     growth_rate=12, num_classes=self.num_classes
                 )  # NB NOTE: growth rate controls cifar implementation
             elif args.resdepth == 161:
                 net = DenseNet161(growth_rate=12,
                                   num_classes=self.num_classes)
             elif args.resdepth == 169:
                 net = DenseNet169(growth_rate=12,
                                   num_classes=self.num_classes)
             else:
                 net = DenseNet201(growth_rate=12,
                                   num_classes=self.num_classes)
         elif model == 'preact_resnet':
             assert args.resdepth in [18, 34, 50, 101, 152], \
                 "Non-standard and unsupported preact resnet depth ({})".format(args.resdepth)
             if args.resdepth == 18:
                 net = PreActResNet18(self.num_classes)
             elif args.resdepth == 34:
                 net = PreActResNet34(self.num_classes)
             elif args.resdepth == 50:
                 net = PreActResNet50(self.num_classes)
             elif args.resdepth == 101:
                 net = PreActResNet101(self.num_classes)
             else:
                  net = PreActResNet152(self.num_classes)
         elif model == 'wresnet':
             assert ((args.resdepth - 4) % 6 == 0), \
                 "Wideresnet depth of {} not supported, must fulfill: (depth - 4) % 6 = 0".format(args.resdepth)
             net = WideResNet(depth=args.resdepth,
                              num_classes=self.num_classes,
                              widen_factor=args.widen_factor)
         else:
             raise NotImplementedError(
                 'Model {} not supported'.format(model))
     return net
Example #5
EMBEDDING_SIZE = 500 if dataset == 'mnist' else 512


def experiment_id(dataset, k, tau, nloglr, method):
    return 'baseline-resnet-%s-%s-k%d-t%d-b%d' % (dataset, method, k, tau,
                                                  nloglr)


e_id = experiment_id(dataset, k, tau * 10, args.nloglr, method)

gpu = torch.device('cuda')

if dataset == 'mnist':
    h_phi = ConvNet().to(gpu)
else:
    h_phi = PreActResNet18(
        num_channels=3 if dataset == 'cifar10' else 1).to(gpu)

optimizer = torch.optim.SGD(h_phi.parameters(),
                            lr=LEARNING_RATE,
                            momentum=0.9,
                            weight_decay=5e-4)

linear_layer = torch.nn.Linear(EMBEDDING_SIZE, 10).to(device=gpu)
ce_loss = torch.nn.CrossEntropyLoss()

batched_train = split.get_train_loader(NUM_TRAIN_QUERIES)


def train(epoch):
    h_phi.train()
    to_average = []
Example #6
    def select(self, model, path_fc=False, upsample='pixel'):
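        """Build the requested architecture; ResNet/DenseNet/PreActResNet variants are chosen via self.resdepth."""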
        if model == 'cnn':
            net = SimpleModel(
                in_shape=self.in_shape,
                activation=self.activation,
                num_classes=self.num_classes,
                filters=self.filters,
            )
        else:
            assert (self.dataset != 'MNIST' and self.dataset != 'Fashion-MNIST'
                    ), "Cannot use resnet or densenet for mnist style data"
            if model == 'resnet':
                assert self.resdepth in [
                    18, 34, 50, 101, 152
                ], "Non-standard and unsupported resnet depth ({})".format(
                    self.resdepth)
                if self.resdepth == 18:
                    net = ResNet18()
                elif self.resdepth == 34:
                    net = ResNet34()
                elif self.resdepth == 50:
                    net = ResNet50()
                elif self.resdepth == 101:
                    net = ResNet101()
                else:
                    net = ResNet152()
            elif model == 'densenet':
                assert self.resdepth in [
                    121, 161, 169, 201
                ], "Non-standard and unsupported densenet depth ({})".format(
                    self.resdepth)
                if self.resdepth == 121:
                    net = DenseNet121()
                elif self.resdepth == 161:
                    net = DenseNet161()
                elif self.resdepth == 169:
                    net = DenseNet169()
                else:
                    net = DenseNet201()
            elif model == 'preact_resnet':
                assert self.resdepth in [
                    10, 18, 34, 50, 101, 152
                ], "Non-standard and unsupported preact resnet depth ({})".format(
                    self.resdepth)
                if self.resdepth == 10:
                    net = PreActResNet10(path_fc=path_fc,
                                         num_classes=self.num_classes,
                                         upsample=upsample)
                elif self.resdepth == 18:
                    net = PreActResNet18()
                elif self.resdepth == 34:
                    net = PreActResNet34()
                elif self.resdepth == 50:
                    net = PreActResNet50()
                elif self.resdepth == 101:
                    net = PreActResNet101()
                else:
                    net = PreActResNet152()
            elif model == 'wresnet':
                assert (
                    (self.resdepth - 4) % 6 == 0
                ), "Wideresnet depth of {} not supported, must fulfill: (depth - 4) % 6 = 0".format(
                    self.resdepth)
                net = WideResNet(depth=self.resdepth,
                                 num_classes=self.num_classes,
                                 widen_factor=self.widen_factor)
            else:
                raise NotImplementedError(
                    'Model {} not supported'.format(model))

        return net
Example #7
def get_model(model_name, dataset_name):
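    """Resolve (grayscale, num_classes) from the dataset name, then construct and return the named model."""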
    if dataset_name == "mnist":
        grayscale = True
        num_classes = 10
    elif dataset_name == "cifar10":
        grayscale = False
        num_classes = 10
    elif dataset_name == "cifar100":
        grayscale = False
        num_classes = 100
    elif dataset_name == "tiny-imagenet":
        grayscale = False
        num_classes = 200
    elif dataset_name == "clothing1m":
        grayscale = False
        num_classes = 14
    else:
        raise NameError("Invalid dataset")

    if model_name == "jocor_model":
        if dataset_name == "mnist":
            net = MLPNet()
        elif dataset_name == "cifar10":
            net = CNN(n_outputs=10)
        elif dataset_name == "cifar100":
            net = CNN(n_outputs=100)
        elif dataset_name == "tiny-imagenet":
            net = PreActResNet18(num_classes)
        elif dataset_name == "clothing1m":
            model = getattr(resnet, "resnet18")
            net = model(grayscale, num_classes)
            net = torch.nn.DataParallel(net, device_ids=[0, 1, 2, 3])
        return net

    elif model_name.startswith("attention"):
        model = getattr(attention, model_name)
    elif model_name.startswith("densenet"):
        model = getattr(densenet, model_name)
    elif model_name.startswith("googlenet"):
        model = getattr(googlenet, model_name)
    elif model_name.startswith("inceptionv3"):
        model = getattr(inceptionv3, model_name)
    elif model_name.startswith("inception"):
        model = getattr(inceptionv4, model_name)
    elif model_name.startswith("mobilenetv2"):
        model = getattr(mobilenetv2, model_name)
    elif model_name.startswith("mobilenet"):
        model = getattr(mobilenet, model_name)
    elif model_name.startswith("nasnet"):
        model = getattr(nasnet, model_name)
    elif model_name.startswith("preactresnet"):
        model = getattr(preactresnet, model_name)
    elif model_name.startswith("resnet"):
        model = getattr(resnet, model_name)
    elif model_name.startswith("resnext"):
        model = getattr(resnext, model_name)
    elif model_name.startswith("rir"):
        model = getattr(rir, model_name)
    elif model_name.startswith("seresnet"):
        model = getattr(senet, model_name)
    elif model_name.startswith("shufflenetv2"):
        model = getattr(shufflenetv2, model_name)
    elif model_name.startswith("shufflenet"):
        model = getattr(shufflenet, model_name)
    elif model_name.startswith("squeezenet"):
        model = getattr(squeezenet, model_name)
    elif model_name.startswith("vgg"):
        model = getattr(vgg, model_name)
    elif model_name.startswith("xception"):
        model = getattr(xception, model_name)
    else:
        raise NameError("Invalid model")

    net = model(grayscale, num_classes)
    if dataset_name == "clothing1m":
        net = torch.nn.DataParallel(net, device_ids=[0, 1, 2, 3])

    return net
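
Example #7 picks the constructor with getattr on a module selected by the model-name prefix (note that more specific prefixes such as "mobilenetv2" are tested before "mobilenet"). A compact sketch of that dispatch pattern is shown below; importlib stands in for the project's explicit imports, and the package and prefix names are assumptions rather than the project's real layout.

import importlib

def build_by_name(model_name, grayscale, num_classes, package='models'):
    # Hypothetical helper mirroring the prefix -> module dispatch above.
    # More specific prefixes must come before their shorter counterparts.
    prefixes = ('preactresnet', 'mobilenetv2', 'mobilenet', 'shufflenetv2',
                'shufflenet', 'resnext', 'resnet', 'densenet', 'vgg')
    for prefix in prefixes:
        if model_name.startswith(prefix):
            module = importlib.import_module('{}.{}'.format(package, prefix))
            ctor = getattr(module, model_name)  # e.g. preactresnet.preactresnet18
            return ctor(grayscale, num_classes)
    raise NameError('Invalid model')
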
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--stage', default='train', type=str)

    parser.add_argument('--gpus', default='0,1,2,3', type=str)
    parser.add_argument('--max_epoch', default=200, type=int)
    parser.add_argument('--lr_decay_steps', default='160,190,200', type=str)
    parser.add_argument('--exp', default='', type=str)
    parser.add_argument('--res_path', default='', type=str)
    parser.add_argument('--resume_path', default='', type=str)
    parser.add_argument('--pretrain_path', default='', type=str)

    parser.add_argument('--dataset', default='imagenet', type=str)
    parser.add_argument('--lr', default=0.03, type=float)
    parser.add_argument('--lr_decay_rate', default=0.1, type=float)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--weight_decay', default=5e-4, type=float)
    parser.add_argument('--n_workers', default=32, type=int)
    parser.add_argument('--n_background', default=4096, type=int)
    parser.add_argument('--t', default=0.07, type=float)
    parser.add_argument('--m', default=0.5, type=float)
    parser.add_argument('--dropout', action='store_true')
    parser.add_argument('--blur', action='store_true')
    parser.add_argument('--cos', action='store_true')

    parser.add_argument('--network', default='resnet18', type=str)
    parser.add_argument('--mix', action='store_true')
    parser.add_argument('--not_hardpos', action='store_true')
    parser.add_argument('--InvP', type=int, default=1)
    parser.add_argument('--ramp_up', default='binary', type=str)
    parser.add_argument('--lam_inv', default=0.6, type=float)
    parser.add_argument('--lam_mix', default=1.0, type=float)
    parser.add_argument('--diffusion_layer', default=3, type=int)
    # for cifar10 the best diffusion_layer is 3; for cifar100 it is 4
    # for imagenet, only diffusion_layer = 3 has been tested
    parser.add_argument('--K_nearst', default=4, type=int)
    parser.add_argument('--n_pos', default=50, type=int)
    # for cifar10 the best n_pos is 20; for cifar100 the best is 10 or 20
    parser.add_argument('--exclusive', default=1, type=int)
    parser.add_argument('--nonlinearhead', default=0, type=int)
    # exclusive is best set to 0

    global args
    args = parser.parse_args()
    exp_identifier = get_expidentifier([
        'mix', 'network', 'lam_inv', 'lam_mix', 'diffusion_layer', 'K_nearst',
        'n_pos', 'exclusive', 'max_epoch', 'ramp_up', 'nonlinearhead', 't',
        'weight_decay'
    ], args)
    if not args.InvP: exp_identifier = 'hard'
    args.exp = os.path.join(args.exp, exp_identifier)

    os.makedirs(args.exp, exist_ok=True)
    for sub in ('runs', 'models', 'logs'):
        os.makedirs(os.path.join(args.exp, sub), exist_ok=True)

    logger = getLogger(args.exp)

    device_ids = [int(x) for x in args.gpus.split(',')]
    device = torch.device('cuda:0')

    if args.dataset.startswith('cifar'):
        train_loader, val_loader, train_ordered_labels, train_dataset, val_dataset = cifar.get_dataloader(
            args)
    elif args.dataset.startswith('imagenet'):
        train_loader, val_loader, train_ordered_labels, train_dataset, val_dataset = imagenet.get_instance_dataloader(
            args)
    elif args.dataset == 'svhn':
        train_loader, val_loader, train_ordered_labels, train_dataset, val_dataset = svhn.get_dataloader(
            args)

    # create model
    if args.network == 'alexnet':
        network = alexnet(128)
    elif args.network == 'alexnet_cifar':
        network = AlexNet_cifar(128)
    elif args.network == 'resnet18_cifar':
        network = ResNet18_cifar(128,
                                 dropout=args.dropout,
                                 non_linear_head=args.nonlinearhead)
    elif args.network == 'resnet50_cifar':
        network = ResNet50_cifar(128, dropout=args.dropout)
    elif args.network == 'wide_resnet28':
        network = WideResNetInstance(28, 2)
    elif args.network == 'resnet18':
        network = resnet18(non_linear_head=args.nonlinearhead)
    elif args.network == 'pre-resnet18':
        network = PreActResNet18(128)
    elif args.network == 'resnet50':
        network = resnet50(non_linear_head=args.nonlinearhead)
    elif args.network == 'pre-resnet50':
        network = PreActResNet50(128)
    else:
        raise ValueError('Unknown network: {}'.format(args.network))
    network = nn.DataParallel(network, device_ids=device_ids)
    network.to(device)

    # create optimizer

    if args.network == 'pre-resnet18' or args.network == 'pre-resnet50':
        logging.info(
            colorful(
                'Exclude bns from weight decay, copied from LocalAggregation proposed by Zhuang et al [ICCV 2019]'
            ))
        parameters = exclude_bn_weight_bias_from_weight_decay(
            network, weight_decay=args.weight_decay)
    else:
        parameters = network.parameters()

    optimizer = torch.optim.SGD(
        parameters,
        lr=args.lr,
        momentum=0.9,
        weight_decay=args.weight_decay,
    )

    cudnn.benchmark = True

    # create memory_bank
    global writer
    writer = SummaryWriter(comment='InvariancePropagation',
                           logdir=os.path.join(args.exp, 'runs'))
    memory_bank = objective.MemoryBank_v1(len(train_dataset),
                                          train_ordered_labels,
                                          writer,
                                          device,
                                          m=args.m)

    # create criterion
    criterionA = objective.InvariancePropagationLoss(
        args.t,
        diffusion_layer=args.diffusion_layer,
        k=args.K_nearst,
        n_pos=args.n_pos,
        exclusive=args.exclusive,
        InvP=args.InvP,
        hard_pos=(not args.not_hardpos))
    criterionB = objective.MixPointLoss(args.t)
    if args.ramp_up == 'binary':
        ramp_up = lambda i_epoch: objective.BinaryRampUp(i_epoch, 30)
    elif args.ramp_up == 'gaussian':
        ramp_up = lambda i_epoch: objective.GaussianRampUp(i_epoch, 30, 5)
    elif args.ramp_up == 'zero':
        ramp_up = lambda i_epoch: 1

    logging.info(beautify(args))
    start_epoch = 0
    if args.pretrain_path != '' and args.pretrain_path != 'none':
        logging.info('loading pretrained file from {}'.format(
            args.pretrain_path))
        checkpoint = torch.load(args.pretrain_path)
        network.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        _memory_bank = checkpoint['memory_banks']
        try:
            _neigh = checkpoint['neigh']
            memory_bank.neigh = _neigh
        except KeyError:
            logging.info(
                colorful(
                    'The pretrained checkpoint has no neigh and needs an epoch to re-calculate it'
                ))
        memory_bank.points = _memory_bank
        start_epoch = checkpoint['epoch']
    else:
        initialize_memorybank(network, train_loader, device, memory_bank)
    logging.info('start training')
    best_acc = 0.0

    try:
        for i_epoch in range(start_epoch, args.max_epoch):
            adjust_learning_rate(args.lr,
                                 args.lr_decay_steps,
                                 optimizer,
                                 i_epoch,
                                 lr_decay_rate=args.lr_decay_rate,
                                 cos=args.cos,
                                 max_epoch=args.max_epoch)
            train(i_epoch, network, criterionA, criterionB, optimizer,
                  train_loader, device, memory_bank, ramp_up)

            save_name = 'checkpoint.pth'
            checkpoint = {
                'epoch': i_epoch + 1,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
                'memory_banks': memory_bank.points,
                'neigh': memory_bank.neigh,
            }
            torch.save(checkpoint, os.path.join(args.exp, 'models', save_name))

            # scheduler.step()
            # validate(network, memory_bank, val_loader, train_ordered_labels, device)
            acc = kNN(i_epoch,
                      network,
                      memory_bank,
                      val_loader,
                      train_ordered_labels,
                      K=200,
                      sigma=0.07)
            if acc >= best_acc:
                best_acc = acc
                torch.save(checkpoint,
                           os.path.join(args.exp, 'models', 'best.pth'))
            if i_epoch in [30, 60, 120, 160, 200, 400, 600]:
                torch.save(
                    checkpoint,
                    os.path.join(args.exp, 'models',
                                 '{}.pth'.format(i_epoch + 1)))

            args.y_best_acc = best_acc
            logging.info(
                colorful('[Epoch: {}] val acc: {:.4f}'.format(i_epoch, acc)))
            logging.info(
                colorful('[Epoch: {}] best acc: {:.4f}'.format(
                    i_epoch, best_acc)))
            writer.add_scalar('acc', acc, i_epoch + 1)

            with torch.no_grad():
                for name, param in network.named_parameters():
                    if 'bn' not in name:
                        writer.add_histogram(name, param, i_epoch)

            # cluster
    except KeyboardInterrupt as e:
        logging.info('KeyboardInterrupt at {} Epochs'.format(i_epoch))
        save_result(args)
        exit()

    save_result(args)
Example #9
        #    labels = labels.type_as(torch.LongTensor()).view(-1) - 1

        images = Variable(images, requires_grad=False).cuda()
        labels = Variable(labels, requires_grad=False).cuda()
        pred, _ = cnn(images)
        test_loss += loss_func(pred, labels).item()
        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels.data).sum().item()
    val_acc = correct / total
    val_loss = test_loss / total
    cnn.train()
    return val_acc, val_loss

if args.model == 'resnet18':
    cnn = PreActResNet18(channels=num_channels, num_classes=num_classes)
elif args.model == 'resnet34':
    cnn = PreActResNet34(channels=num_channels, num_classes=num_classes)
elif args.model == 'resnet50':
    cnn = PreActResNet50(channels=num_channels, num_classes=num_classes)
elif args.model == 'resnet101':
    cnn = PreActResNet101(channels=num_channels, num_classes=num_classes)
elif args.model == 'resnet152':
    cnn = PreActResNet152(channels=num_channels, num_classes=num_classes)
elif args.model == 'vgg':
    cnn = VGG(depth=16, num_classes=num_classes, channels=num_channels)
elif args.model == 'wideresnet':
    if args.dataset == 'svhn':
        cnn = Wide_ResNet(depth=16, num_classes=num_classes, widen_factor=8, dropout_rate=args.dropout_rate)
    else:
        cnn = Wide_ResNet(depth=28, num_classes=num_classes, widen_factor=10, dropout_rate=args.dropout_rate)