Example No. 1
    def __init__(self, mode=None, backbone='vgg16'):
        super().__init__()
        self.mode = mode

        if self.mode == 'nlt':
            if backbone == 'vgg16':
                print('backbone is vgg16')
                self.encoder = vgg16(nlt=True)
                self.decoder = decoder(feature_channel=512, nlt=True)

            elif backbone == 'ResNet50':
                print('backbone is ResNet50')
                self.encoder = ResNet50(nlt=True)
                self.encoder.load_state_dict(torch.load(
                    '/media/D/ht/PyTorch_Pretrained/resnet50-19c8e357.pth'),
                                             strict=False)
                self.decoder = decoder(feature_channel=1024, nlt=True)

        else:
            if backbone == 'vgg16':
                self.encoder = vgg16(pretrained=True, nlt=False)
                self.decoder = decoder(feature_channel=512)

            elif backbone == 'ResNet50':
                self.encoder = ResNet50(nlt=False)
                self.encoder.load_state_dict(torch.load(
                    '/media/D/ht/PyTorch_Pretrained/resnet50-19c8e357.pth'),
                                             strict=False)
                self.decoder = decoder(feature_channel=1024, nlt=False)
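The ResNet50 branch above loads 'resnet50-19c8e357.pth' from a hard-coded local path with strict=False, presumably so that keys that do not match the nlt-modified architecture are simply skipped. A minimal sketch of the same pattern, assuming torchvision's hosted copy of that checkpoint and its stock resnet50 in place of the project's ResNet50:

import torch
from torchvision.models import resnet50

encoder = resnet50()  # randomly initialised stand-in backbone
state_dict = torch.hub.load_state_dict_from_url(
    'https://download.pytorch.org/models/resnet50-19c8e357.pth')
missing, unexpected = encoder.load_state_dict(state_dict, strict=False)
print(len(missing), 'missing keys;', len(unexpected), 'unexpected keys')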
Example No. 2
def get_model(train_model):

    if train_model == 'resnet18':
        return resnet.resnet18()
    elif train_model == 'resnet34':
        return resnet.resnet34()
    elif train_model == 'resnet50':
        return resnet.resnet50()
    elif train_model == 'resnet101':
        return resnet.resnet101()
    elif train_model == 'resnet152':
        return resnet.resnet152()
    elif train_model == 'resnet18_copy':
        return resnet_copy.resnet18()
    elif train_model == 'resnet34_copy':
        return resnet_copy.resnet34()
    elif train_model == 'resnet50_copy':
        return resnet_copy.resnet50()
    elif train_model == 'resnet101_copy':
        return resnet_copy.resnet101()
    elif train_model == 'resnet152_copy':
        return resnet_copy.resnet152()
    elif train_model == 'vgg11':
        return vgg11()
    elif train_model == 'vgg13':
        return vgg13()
    elif train_model == 'vgg16':
        return vgg16()
    elif train_model == 'vgg19':
        return vgg19()
    elif train_model == 'nin':
        return nin()
    elif train_model == 'googlenet':
        return googlenet()
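The elif chain above can be collapsed into a table lookup; a hedged sketch (partial table, assuming the same factory callables the original file imports) that also raises on unknown names instead of silently returning None:

# Hypothetical dictionary-based variant of get_model; the factory callables
# (resnet.resnet18, resnet_copy.resnet50, vgg16, nin, googlenet, ...) are the
# ones the file above already imports.
_FACTORIES = {
    'resnet18': resnet.resnet18,
    'resnet50': resnet.resnet50,
    'resnet50_copy': resnet_copy.resnet50,
    'vgg16': vgg16,
    'nin': nin,
    'googlenet': googlenet,
}

def get_model_from_table(train_model):
    try:
        return _FACTORIES[train_model]()
    except KeyError:
        raise ValueError('unsupported model: %s' % train_model)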
Example No. 3
def get_model(args):
    model = vgg.vgg16(pretrained=True,
                      num_classes=args.num_classes,
                      att_dir=args.att_dir,
                      training_epoch=args.epoch)
    model = torch.nn.DataParallel(model).cuda()
    param_groups = model.module.get_parameter_groups()
    optimizer = optim.SGD([{
        'params': param_groups[0],
        'lr': args.lr
    }, {
        'params': param_groups[1],
        'lr': 2 * args.lr
    }, {
        'params': param_groups[2],
        'lr': 10 * args.lr
    }, {
        'params': param_groups[3],
        'lr': 20 * args.lr
    }],
                          momentum=0.9,
                          weight_decay=args.weight_decay,
                          nesterov=True)

    return model, optimizer
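A hedged driver for get_model above; the attribute names mirror what the function reads, the values are placeholders, and a CUDA device is required because of the .cuda() call:

from argparse import Namespace

# Hypothetical arguments; only the attribute names are taken from the code above.
args = Namespace(num_classes=20, att_dir='./runs/att', epoch=15,
                 lr=0.001, weight_decay=5e-4)
model, optimizer = get_model(args)
print(type(model.module).__name__, len(optimizer.param_groups))  # expects 4 groups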
Example No. 4
    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        if opt.net == 'vgg':
            self.net = vgg.vgg16(num_classes=opt.n_classes)
        elif opt.net == 'normal':
            self.net = Net(num_classes=opt.n_classes)
        elif opt.net == 'att':
            self.net = AttNet(num_classes=opt.n_classes)
        elif opt.net == 'vgg_att':
            self.net = VGGAttNet(num_classes=opt.n_classes)
        else:
            raise ValueError('[%s] cannot be used!' % opt.net)

        self.net = self.net.to(self.device)

        if not opt.is_train:
            self.load_network(self.opt.which_epoch)
        else:
            if self.opt.resume_train:
                self.load_network(self.opt.which_epoch)

            self.cls_criterion = nn.CrossEntropyLoss()
            if opt.training_type == 'att_consist':
                self.mask_criterion = nn.MSELoss()
            self.optimizer = optim.SGD(self.net.parameters(),
                                       lr=0.001,
                                       momentum=0.9)
Example No. 5
    def _get_vgg16(self, pretrained=True):
        model = vgg16(pretrained=pretrained)
        model.classifier[6] = nn.Linear(in_features=4096,
                                        out_features=125,
                                        bias=True)

        return model
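A quick shape check of the head-replacement pattern above, assuming torchvision's vgg16 (whose classifier[6] is the final 4096-dimensional linear layer, as the snippet expects):

import torch
from torch import nn
from torchvision.models import vgg16

model = vgg16()  # random weights are enough for a shape check
model.classifier[6] = nn.Linear(in_features=4096, out_features=125, bias=True)

with torch.no_grad():
    out = model(torch.zeros(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 125])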
Example No. 6
    def __init__(self, opt):
        super(FeatureLoss, self).__init__()
        self.opt = opt
        self.isTrain = opt.isTrain
        self.vgg = VGG.vgg16(pretrained = True)
        self.Tensor = torch.cuda.FloatTensor if use_gpu else torch.Tensor

        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)

        # Assuming norm_type = batch
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
        # model  of Generator Net is unet_256
        self.GeneratorNet = Generator(opt.input_nc, opt.output_nc, 8, opt.ngf, norm_layer=norm_layer,use_dropout = not opt.no_dropout)
        if use_gpu:
            self.GeneratorNet.cuda(0)
        self.GeneratorNet.apply(init_weights)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # model  of Discriminator Net is basic
            self.DiscriminatorNet = Discriminator(opt.input_nc+ opt.output_nc, opt.ndf, n_layers = 3, norm_layer = norm_layer, use_sigmoid = use_sigmoid)
            if use_gpu:
                self.DiscriminatorNet.cuda(0)
            self.DiscriminatorNet.apply(init_weights)

        if not self.isTrain or opt.continue_train:
            self.load_network(self.GeneratorNet, 'Generator', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.DiscriminatorNet, 'Discriminator', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.learning_rate = opt.lr
            # defining loss functions
            self.criterionGAN = GANLoss(use_lsgan = not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionFV = loss_FV.FeatureVectorLoss()

            self.MySchedulers = []  # initialising schedulers
            self.MyOptimizers = []  # initialising optimizers
            self.generator_optimizer = torch.optim.Adam(self.GeneratorNet.parameters(), lr=self.learning_rate, betas = (opt.beta1, 0.999))
            self.discriminator_optimizer = torch.optim.Adam(self.DiscriminatorNet.parameters(), lr=self.learning_rate, betas = (opt.beta1, 0.999))
            self.MyOptimizers.append(self.generator_optimizer)
            self.MyOptimizers.append(self.discriminator_optimizer)
            def lambda_rule(epoch):
                lr_l = 1.0 - max(0, epoch - opt.niter)/float(opt.niter_decay+1)
                return lr_l
            for optimizer in self.MyOptimizers:
                self.MySchedulers.append(lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda_rule))
                # assuming opt.lr_policy == 'lambda'


        print('<============ NETWORKS INITIATED ============>')
        print_net(self.GeneratorNet)
        if self.isTrain:
            print_net(self.DiscriminatorNet)
        print('<=============================================>')
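The lambda_rule above keeps the learning rate constant for the first opt.niter epochs and then decays it linearly towards zero over the next opt.niter_decay epochs. A small self-contained sketch of the resulting multipliers (niter=100 and niter_decay=100 are assumed values):

def lambda_rule(epoch, niter=100, niter_decay=100):
    return 1.0 - max(0, epoch - niter) / float(niter_decay + 1)

for epoch in (0, 100, 150, 200):
    print(epoch, round(lambda_rule(epoch), 3))
# 0 1.0, 100 1.0, 150 0.505, 200 0.01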
Example No. 7
def get_network(args):

    if args.net == 'vgg16':
        from models.vgg import vgg16
        model_ft = vgg16(args.num_classes, export_onnx=args.export_onnx)
    elif args.net == 'alexnet':
        from models.alexnet import alexnet
        model_ft = alexnet(num_classes=args.num_classes,
                           export_onnx=args.export_onnx)
    elif args.net == 'mobilenet':
        from models.mobilenet import mobilenet_v2
        model_ft = mobilenet_v2(pretrained=True, export_onnx=args.export_onnx)
    elif args.net == 'vgg19':
        from models.vgg import vgg19
        model_ft = vgg19(args.num_classes, export_onnx=args.export_onnx)
    else:
        if args.net == 'googlenet':
            from models.googlenet import googlenet
            model_ft = googlenet(pretrained=True)
        elif args.net == 'inception':
            from models.inception import inception_v3
            model_ft = inception_v3(args,
                                    pretrained=True,
                                    export_onnx=args.export_onnx)
        elif args.net == 'resnet18':
            from models.resnet import resnet18
            model_ft = resnet18(pretrained=True, export_onnx=args.export_onnx)
        elif args.net == 'resnet34':
            from models.resnet import resnet34
            model_ft = resnet34(pretrained=True, export_onnx=args.export_onnx)
        elif args.net == 'resnet101':
            from models.resnet import resnet101
            model_ft = resnet101(pretrained=True, export_onnx=args.export_onnx)
        elif args.net == 'resnet50':
            from models.resnet import resnet50
            model_ft = resnet50(pretrained=True, export_onnx=args.export_onnx)
        elif args.net == 'resnet152':
            from models.resnet import resnet152
            model_ft = resnet152(pretrained=True, export_onnx=args.export_onnx)
        else:
            print("The %s is not supported..." % (args.net))
            return
    if args.net == 'mobilenet':
        num_ftrs = model_ft.classifier[1].in_features
        model_ft.classifier[1] = nn.Linear(num_ftrs * 4, args.num_classes)
    else:
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, args.num_classes)
    net = model_ft

    return net
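The head swap at the end depends on where each architecture keeps its classifier: MobileNet-style models expose a classifier Sequential, ResNet-style models a single fc layer. A hedged sketch of both cases with torchvision models (the custom modules in models/ may differ; note that the snippet above multiplies in_features by 4 for its mobilenet, which torchvision's does not need):

from torch import nn
from torchvision.models import mobilenet_v2, resnet18

num_classes = 10

m = mobilenet_v2()
m.classifier[1] = nn.Linear(m.classifier[1].in_features, num_classes)

r = resnet18()
r.fc = nn.Linear(r.fc.in_features, num_classes)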
Example No. 8
def get_network(args,cfg):
    """ return given network
    """
    # pdb.set_trace()
    if args.net == 'lenet5':
        net = LeNet5().cuda()
    elif args.net == 'alexnet':
        net = alexnet(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg16':
        net = vgg16(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg13':
        net = vgg13(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg11':
        net = vgg11(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg19':
        net = vgg19(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg16_bn':
        net = vgg16_bn(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg13_bn':
        net = vgg13_bn(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg11_bn':
        net = vgg11_bn(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg19_bn':
        net = vgg19_bn(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net =='inceptionv3':
        net = inception_v3().cuda()
    # elif args.net == 'inceptionv4':
    #     net = inceptionv4().cuda()
    # elif args.net == 'inceptionresnetv2':
    #     net = inception_resnet_v2().cuda()
    elif args.net == 'resnet18':
        net = resnet18(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda(args.gpuid)
    elif args.net == 'resnet34':
        net = resnet34(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'resnet50':
        net = resnet50(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda(args.gpuid)
    elif args.net == 'resnet101':
        net = resnet101(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'resnet152':
        net = resnet152(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'squeezenet':
        net = squeezenet1_0().cuda()
    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    return net
Example No. 9
    def __init__(self,
                 rnn_dim,
                 embed_dim,
                 embed_matrix: np.array = None,
                 vgg_layer='fc2'):
        self.coco = COCODataset()
        self.vgg = vgg16(path=os.path.join(CONFIG.COCO_PATH,
                                           'vgg16_weights.npz'),
                         from_tf_check_point=False)
        self.vgg_feature_name = vgg_layer
        self.cnn_dim = self.vgg.get_layer_shape(
            self.vgg_feature_name).as_list()[-1]

        # self.cnn_dim = 512

        self.word_to_idx = self.coco.word_to_idx
        self.idx_to_word = self.coco.idx_to_word
        self.vocab_size = CONFIG.VOCAB_SIZE

        self.null_idx = self.word_to_idx['<NULL>']
        self.start_idx = self.word_to_idx['<START>']
        self.end_idx = self.word_to_idx['<END>']

        self.embed_dim = embed_dim
        self.rnn_dim = rnn_dim
        self.embed_matrix = embed_matrix

        # Suppose the caption is :
        #       <START> i play basketball ... as Klay <END>    # length 17
        # Then the input is
        #       <START> i play basketball ... as Klay           # length 16
        # And the desired output is
        #       i play basketball ... as Klay <END>             # length 16
        # If the caption is padded with <NULL>, the <NULL> tokens are not counted in the loss,
        # so we also need a mask to indicate where the <NULL> tokens are

        self.time_span = CONFIG.TIME_SPAN - 1
        self.params = {}

        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)

        with self.graph.as_default():
            with tf.name_scope('Captioning'):
                self.setup()
                self.build_graph()
Example No. 10
def train_without_augmentation_ck():
    model = vgg16()
    (X_train, Y_train), (X_test, Y_test), (X_validation, Y_validation) = ck()
    if load_weights:
        model.load_weights('model_vgg_16_ck.h5')
    history = model.fit(X_train,
                        Y_train,
                        batch_size=batch_size,
                        nb_epoch=nb_epoch,
                        validation_data=(X_validation, Y_validation),
                        shuffle=True)
    predictions = model.predict(X_test, batch_size=batch_size, verbose=1)
    evaluates = model.evaluate(X_test, Y_test)
    historic(history)
    confusion_matrix(predictions, Y_test)
    if save_weights:
        model.save_weights('model_vgg_16_ck.h5')
Example No. 11
def get_network(args):
    """ return given network
    """
    if args.task == 'cifar10':
        nclass = 10
    elif args.task == 'cifar100':
        nclass = 100
    # Yang added non-BN VGGs
    if args.net == 'vgg11':
        from models.vgg import vgg11
        net = vgg11(num_classes=nclass)
    elif args.net == 'vgg13':
        from models.vgg import vgg13
        net = vgg13(num_classes=nclass)
    elif args.net == 'vgg16':
        from models.vgg import vgg16
        net = vgg16(num_classes=nclass)
    elif args.net == 'vgg19':
        from models.vgg import vgg19
        net = vgg19(num_classes=nclass) 
    
    elif args.net == 'resnet18':
        from models.resnet import resnet18
        net = resnet18(num_classes=nclass)
    elif args.net == 'resnet34':
        from models.resnet import resnet34
        net = resnet34(num_classes=nclass)
    elif args.net == 'resnet50':
        from models.resnet import resnet50
        net = resnet50(num_classes=nclass)
    elif args.net == 'resnet101':
        from models.resnet import resnet101
        net = resnet101(num_classes=nclass)
    elif args.net == 'resnet152':
        from models.resnet import resnet152
        net = resnet152(num_classes=nclass)

    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    if args.gpu: #use_gpu
        net = net.cuda()

    return net
Example No. 12
def get_model(args):
    model = vgg.vgg16(num_classes=args.num_classes)
    model = torch.nn.DataParallel(model).cuda()

    pretrained_dict = torch.load(args.restore_from)['state_dict']
    model_dict = model.state_dict()
    
    print(model_dict.keys())
    print(pretrained_dict.keys())
    
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
    print("Weights cannot be loaded:")
    print([k for k in model_dict.keys() if k not in pretrained_dict.keys()])

    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    return  model
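The snippet keeps only checkpoint tensors whose keys exist in the current model. A self-contained sketch of the same filtering pattern with two torchvision models (and an added shape check, which the original omits):

from torchvision.models import vgg16, vgg16_bn

src = vgg16().state_dict()          # stands in for the checkpoint
dst_model = vgg16_bn()              # model we want to initialise
dst = dst_model.state_dict()

compatible = {k: v for k, v in src.items()
              if k in dst and v.shape == dst[k].shape}
print('weights that cannot be loaded:', [k for k in dst if k not in compatible])

dst.update(compatible)
dst_model.load_state_dict(dst)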
Example No. 13
def train_without_augmentation():
    model = vgg16(lr=0.0001, dropout_in=0.25, dropout_out=0.5)
    (X_train, Y_train), (X_test, Y_test), (X_validation,
                                           Y_validation) = fer2013()
    if load_weights:
        model.load_weights('model_vgg_16_eq.h5')
    history = model.fit(X_train,
                        Y_train,
                        batch_size=batch_size,
                        nb_epoch=nb_epoch,
                        validation_data=(X_validation, Y_validation),
                        shuffle=True)
    if save_weights:
        model.save_weights('vgg16_fer2013_np.h5')
    predictions = model.predict(X_test, batch_size=batch_size, verbose=1)
    evaluates = model.evaluate(X_test, Y_test)
    historic(history)
    confusion_matrix(predictions, Y_test)
Example No. 14
def main():
    args = parse_args()
    use_cuda = not args.cpu and torch.cuda.is_available()
    device = torch.device(f"cuda:{args.gpu}" if use_cuda else "cpu")

    torch.manual_seed(args.seed)
    trainset, testset, trainloader, testloader = stl10(args.batch_size)

    model = vgg.vgg16(not args.deterministic, device, False).to(device)
    saved_state = torch.load(args.saved_model)
    if args.saved_model[-4:] == '.tar':
        saved_state = saved_state['model_state_dict']
    model.load_state_dict(saved_state)

    criterion = torch.nn.CrossEntropyLoss()

    test_loss, pct_right = test(args, model, device, testloader, criterion, 10)
    print(f'test loss: {test_loss}, correct: {100*pct_right}')
Example No. 15
def get_network(args):

    if args.net == 'vgg16':
        from models.vgg import vgg16
        net = vgg16()

    elif args.net == 'vgg11':
        from models.vgg import vgg11
        net = vgg11()

    elif args.net == 'vgg13':
        from models.vgg import vgg13
        net = vgg13()

    elif args.net == 'vgg19':
        from models.vgg import vgg19
        net = vgg19()

    return net
Example No. 16
def prune_vgg(num_classes: int,
              model,
              pruning_strategy,
              cuda=True,
              dataparallel=False):
    cfg, cfg_mask = _calculate_channel_mask(model, pruning_strategy, cuda=cuda)

    if isinstance(model.classifier, nn.Sequential):
        pruned_model = vgg16_linear(num_classes=num_classes, cfg=cfg)
    elif isinstance(model.classifier, nn.Linear):
        pruned_model = vgg16(num_classes=num_classes, cfg=cfg)

    if cuda:
        pruned_model.cuda()
    if dataparallel:
        pruned_model.feature = torch.nn.DataParallel(pruned_model.feature)
    assign_model(model, pruned_model, cfg_mask)

    return pruned_model, cfg
Example No. 17
    def test_forward_pass_vgg16(self):
        # This is not really a test. Just printing out shapes of tensors.
        # Input shape in SSD paper (300x300) is different from VGG paper (224x224)
        # Here we'll find dimensions of output of each layer of VGG-16 and compare with architecture in SSD paper
        # Turns out, SSD paper expects base model output shape 512x38x38 while VGG-16 from torchvision
        # outputs 512x37x37 on the same input size.
        # MaxPool2d layers reduce the output dimension if the input dimension is odd.
        batch_size = 4
        num_channels = 3
        image_size = (300, 300)

        base_model = vgg16(pretrained=True)

        input = torch.zeros((batch_size, num_channels, image_size[0], image_size[1]), dtype=torch.float32)

        x = input
        for f in list(base_model.features.modules())[1:]:
            x = f(x)
            print(f)
            print(x.shape)
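A hedged sketch of the 37-vs-38 discrepancy the comment describes: torchvision's VGG pools use the default floor mode, so three stride-2 poolings of a 300-pixel input give 37, while ceil_mode=True (what SSD-style VGG backbones typically use for the relevant pooling layer) gives 38:

import torch
from torch import nn

x = torch.zeros(1, 3, 300, 300)
floor_pool = nn.MaxPool2d(kernel_size=2, stride=2)
ceil_pool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

h = x
for _ in range(3):
    h = floor_pool(h)
print(h.shape)  # torch.Size([1, 3, 37, 37])

h = x
for _ in range(3):
    h = ceil_pool(h)
print(h.shape)  # torch.Size([1, 3, 38, 38])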
Example No. 18
def get_model(model_type, use_gpu):

    if model_type == 'vgg16':
        from models.vgg import vgg16
        model = vgg16()
    elif model_type == 'resnet50':
        from models.resnet import resnet50
        model = resnet50()
    elif model_type == 'resnet18':
        from models.resnet import resnet18
        model = resnet18()
    elif model_type == 'googlenet':
        from models.googlenet import googlenet
        model = googlenet()
    else:
        print('this model is not supported')
        sys.exit()
    
    if use_gpu:
        model = model.cuda()
    
    return model
Example No. 19
def train_with_augmentation_ck():
    datagen = ImageDataGenerator(rotation_range=10., horizontal_flip=True)

    model = vgg16(lr=0.00005)
    (X_train, Y_train), (X_test, Y_test), (X_validation, Y_validation) = ck()

    if load_weights:
        model.load_weights('model_vgg_16_eq.h5')

    history = model.fit_generator(datagen.flow(X_train,
                                               Y_train,
                                               batch_size=batch_size),
                                  samples_per_epoch=2000,
                                  nb_epoch=nb_epoch,
                                  validation_data=(X_validation, Y_validation))
    predictions = model.predict(X_test, batch_size=batch_size, verbose=1)

    historic(history)

    confusion_matrix(predictions, Y_test)

    if save_weights:
        model.save_weights('model_vgg_16_ft_ck.h5')
Example No. 20
def get_model(args):
    model = vgg16(pretrained=True, delta=args.delta)

    model = torch.nn.DataParallel(model).cuda()
    param_groups = model.module.get_parameter_groups()

    optimizer = optim.SGD([{
        'params': param_groups[0],
        'lr': args.lr
    }, {
        'params': param_groups[1],
        'lr': 2 * args.lr
    }, {
        'params': param_groups[2],
        'lr': 10 * args.lr
    }, {
        'params': param_groups[3],
        'lr': 20 * args.lr
    }],
                          momentum=0.9,
                          weight_decay=args.weight_decay,
                          nesterov=True)

    return model, optimizer
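The four parameter groups above train at 1x, 2x, 10x and 20x the base learning rate, a common split: pretrained backbone weights and biases at 1x/2x, freshly initialised head weights and biases at 10x/20x. A hedged sketch of the same idea with a plain torchvision vgg16 (the grouping below is an assumption, not the repository's get_parameter_groups):

import torch.optim as optim
from torchvision.models import vgg16

lr = 0.001
model = vgg16()
backbone_w = [p for n, p in model.features.named_parameters() if n.endswith('weight')]
backbone_b = [p for n, p in model.features.named_parameters() if n.endswith('bias')]
head_w = [p for n, p in model.classifier.named_parameters() if n.endswith('weight')]
head_b = [p for n, p in model.classifier.named_parameters() if n.endswith('bias')]

optimizer = optim.SGD([
    {'params': backbone_w, 'lr': 1 * lr},
    {'params': backbone_b, 'lr': 2 * lr},
    {'params': head_w, 'lr': 10 * lr},
    {'params': head_b, 'lr': 20 * lr},
], momentum=0.9, weight_decay=5e-4, nesterov=True)

for i, group in enumerate(optimizer.param_groups):
    print(i, group['lr'], len(group['params']))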
Example No. 21
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print('Use GPU: {} for training'.format(args.gpu))

    if args.distributed:
        if args.dist_url == 'env://' and args.rank == -1:
            args.rank = int(os.environ['RANK'])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(  # needs the nccl backend, the init url, the node count, and the total GPU count
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )
    # create model
    print(
        "=> rank_id '{}', world_size '{}'".format(args.rank,
                                                  args.world_size), )
    print("=> creating model '{}'".format(args.model))
    if args.model == 'resnet50':
        model = resnet50()
    elif args.model == 'vgg16':
        model = vgg16()
    else:
        raise ValueError("Only support resnet50 and vgg16.")

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            # args.batch_size = int(args.batch_size / ngpus_per_node)
            # args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[args.gpu],
            )
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate
            # batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size
        # to all available GPUs
        if args.model.startswith('alexnet') or args.model.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    # optimizer = get_optimizer(model, params, args.lr)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        nesterov=False,
    )
    cudnn.benchmark = True  # setting cudnn.benchmark to True can noticeably speed up training

    # Data loading code
    train_dir = os.path.join(args.data_dir, 'train')
    # valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    train_dataset = datasets.ImageFolder(
        train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, )
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
    )

    # val_loader = torch.utils.data.DataLoader(
    #     datasets.ImageFolder(valdir, transforms.Compose([
    #         transforms.Resize(256),
    #         transforms.CenterCrop(224),
    #         transforms.ToTensor(),
    #         normalize,
    #     ])),
    #     batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=True)

    # if args.evaluate:
    #     validate(val_loader, model, criterion, args)
    #     return

    for epoch in range(args.num_epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)
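main_worker is written to be launched once per GPU. A hedged sketch of a typical launcher for it, in the style of the PyTorch ImageNet example; the argument values are placeholders and only the attribute names are taken from the code above:

import torch
import torch.multiprocessing as mp
from argparse import Namespace

if __name__ == '__main__':
    # Hypothetical arguments; the real script presumably builds these with argparse.
    args = Namespace(distributed=True, multiprocessing_distributed=True,
                     dist_backend='nccl', dist_url='tcp://127.0.0.1:23456',
                     rank=0, world_size=1, model='vgg16', lr=0.01, momentum=0.9,
                     batch_size=64, workers=4, num_epochs=1, data_dir='./imagenet')
    ngpus_per_node = torch.cuda.device_count()
    args.world_size = ngpus_per_node * args.world_size
    # Spawn one process per GPU; each runs main_worker(gpu, ngpus_per_node, args).
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))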
Example No. 22
def get_model(args, model_path=None):
    """

    :param args: parsed command-line arguments
    :param model_path: if not None, load already trained model parameters.
    :return: model
    """
    if args.scratch:  # train model from scratch
        pretrained = False
        model_dir = None
        print("=> Loading model '{}' from scratch...".format(args.model))
    else:  # train model with pretrained model
        pretrained = True
        model_dir = os.path.join(args.root_path, args.pretrained_models_path)
        print("=> Loading pretrained model '{}'...".format(args.model))

    if args.model.startswith('resnet'):

        if args.model == 'resnet18':
            model = resnet18(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'resnet34':
            model = resnet34(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'resnet50':
            model = resnet50(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'resnet101':
            model = resnet101(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'resnet152':
            model = resnet152(pretrained=pretrained, model_dir=model_dir)

        model.fc = nn.Linear(model.fc.in_features, args.num_classes)

    elif args.model.startswith('vgg'):
        if args.model == 'vgg11':
            model = vgg11(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg11_bn':
            model = vgg11_bn(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg13':
            model = vgg13(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg13_bn':
            model = vgg13_bn(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg16':
            model = vgg16(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg16_bn':
            model = vgg16_bn(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg19':
            model = vgg19(pretrained=pretrained, model_dir=model_dir)
        elif args.model == 'vgg19_bn':
            model = vgg19_bn(pretrained=pretrained, model_dir=model_dir)

        model.classifier[6] = nn.Linear(model.classifier[6].in_features, args.num_classes)

    elif args.model == 'alexnet':
        model = alexnet(pretrained=pretrained, model_dir=model_dir)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, args.num_classes)

    # Load already trained model parameters and go on training
    if model_path is not None:
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model'])

    return model
Example No. 23
import argparse

import cv2
import numpy as np

# vgg16() and preprocess_input are assumed to be provided elsewhere in the project
# (e.g. a Keras VGG16 model builder and its matching preprocessing function).
from utils import load_data

WIN_SIZE = 40
NUM_CLS = 2
parser = argparse.ArgumentParser()
parser.add_argument("--test_map_path", type=str)
parser.add_argument("--pred_path", type=str)
parser.add_argument("--model_path", type=str)
args = parser.parse_args()

TEST_MAP_PATH = args.test_map_path
PRED_PATH = args.pred_path
STRIDE = 20

model = vgg16()
model.load_weights(args.model_path)

test_data, img_names = load_data.load_all_data(TEST_MAP_PATH,
                                               '',
                                               WIN_SIZE,
                                               20,
                                               flip=False)
test_data = preprocess_input(test_data.astype('float64'))

pred = model.predict(test_data)
map_img = cv2.imread(TEST_MAP_PATH)
res = np.zeros((map_img.shape[0], map_img.shape[1]))
for i in range(pred.shape[0]):
    if pred[i, 0] > 0.99:
        idx = img_names[i].split('_')
Example No. 24
print("Length of validation set: ", len(valset))
train_loader = DataLoader(trainset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4)
val_loader = DataLoader(valset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=4)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_name = args.backbone
if model_name == "vgg16":
    model = vgg16(architecture_type="cam",
                  pretrained=True,
                  num_classes=n_classes,
                  large_feature_map=False)
elif model_name == "resnet18":
    model = resnet18('cam', num_classes=n_classes)
elif model_name == "squeezenet1_1":
    model = squeezenet1_1(num_classes=n_classes)

print(model)

checkpoint = torch.load(args.checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)
root = args.dataset_path + 'val/'
masks_root = args.masks_path
import glob
filename = [
    glob.glob(root + i + "/*.jpg") for i in object_categories
Example No. 25
def get_model(class_num):
    if (MODEL_TYPE == 'alexnet'):
        model = alexnet.alexnet(pretrained=FINETUNE)
    elif (MODEL_TYPE == 'vgg'):
        if (MODEL_DEPTH_OR_VERSION == 11):
            model = vgg.vgg11(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 13):
            model = vgg.vgg13(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 16):
            model = vgg.vgg16(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 19):
            model = vgg.vgg19(pretrained=FINETUNE)
        else:
            print('Error : VGG should have depth of either [11, 13, 16, 19]')
            sys.exit(1)
    elif (MODEL_TYPE == 'squeezenet'):
        if (MODEL_DEPTH_OR_VERSION == 0 or MODEL_DEPTH_OR_VERSION == 'v0'):
            model = squeezenet.squeezenet1_0(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 1 or MODEL_DEPTH_OR_VERSION == 'v1'):
            model = squeezenet.squeezenet1_1(pretrained=FINETUNE)
        else:
            print('Error : Squeezenet should have version of either [0, 1]')
            sys.exit(1)
    elif (MODEL_TYPE == 'resnet'):
        if (MODEL_DEPTH_OR_VERSION == 18):
            model = resnet.resnet18(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 34):
            model = resnet.resnet34(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 50):
            model = resnet.resnet50(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 101):
            model = resnet.resnet101(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 152):
            model = resnet.resnet152(pretrained=FINETUNE)
        else:
            print(
                'Error : Resnet should have depth of either [18, 34, 50, 101, 152]'
            )
            sys.exit(1)
    elif (MODEL_TYPE == 'densenet'):
        if (MODEL_DEPTH_OR_VERSION == 121):
            model = densenet.densenet121(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 169):
            model = densenet.densenet169(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 161):
            model = densenet.densenet161(pretrained=FINETUNE)
        elif (MODEL_DEPTH_OR_VERSION == 201):
            model = densenet.densenet201(pretrained=FINETUNE)
        else:
            print(
                'Error : Densenet should have depth of either [121, 169, 161, 201]'
            )
            sys.exit(1)
    elif (MODEL_TYPE == 'inception'):
        if (MODEL_DEPTH_OR_VERSION == 3 or MODEL_DEPTH_OR_VERSION == 'v3'):
            model = inception.inception_v3(pretrained=FINETUNE)
        else:
            print('Error : Inception should have version of either [3, ]')
            sys.exit(1)
    else:
        print(
            'Error : Network should be either [alexnet / squeezenet / vgg / resnet / densenet / inception]'
        )
        sys.exit(1)

    if (MODEL_TYPE == 'alexnet' or MODEL_TYPE == 'vgg'):
        num_ftrs = model.classifier[6].in_features
        feature_model = list(model.classifier.children())
        feature_model.pop()
        feature_model.append(nn.Linear(num_ftrs, class_num))
        model.classifier = nn.Sequential(*feature_model)
    elif (MODEL_TYPE == 'resnet' or MODEL_TYPE == 'inception'):
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, class_num)
    elif (MODEL_TYPE == 'densenet'):
        num_ftrs = model.classifier.in_features
        model.classifier = nn.Linear(num_ftrs, class_num)

    return model
Example No. 26
                                          shuffle=True,
                                          num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=4,
                                         shuffle=False,
                                         num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

net = vgg.vgg16(deformable=deform, num_classes=10)
PATH = './cifar_net2.pth' if deform else './cifar_net.pth'
net.load_state_dict(torch.load(PATH))

net.to(device)

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
Example No. 27
    def __init__(self):
        super(FeatureVectorLoss, self).__init__()
        self.loss = nn.L1Loss()
        self.model = VGG.vgg16()
Example No. 28
def main():
    global args, best_err1
    args = parser.parse_args()

    # TensorBoard configure
    if args.tensorboard:
        configure('%s_checkpoints/%s'%(args.dataset, args.expname))

    # CUDA
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_ids)
    if torch.cuda.is_available():
        cudnn.benchmark = True  # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        kwargs = {'num_workers': 2, 'pin_memory': True}
    else:
        kwargs = {'num_workers': 2}

    # Data loading code
    if args.dataset == 'cifar10':
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                         std=[0.2023, 0.1994, 0.2010])
    elif args.dataset == 'cifar100':
        normalize = transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                                         std=[0.2634, 0.2528, 0.2719])
    elif args.dataset == 'cub':
        normalize = transforms.Normalize(mean=[0.4862, 0.4973, 0.4293],
                                         std=[0.2230, 0.2185, 0.2472])
    elif args.dataset == 'webvision':
        normalize = transforms.Normalize(mean=[0.49274242, 0.46481857, 0.41779366],
                                         std=[0.26831809, 0.26145372, 0.27042758])
    else:
        raise Exception('Unknown dataset: {}'.format(args.dataset))

    # Transforms
    if args.augment:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.train_image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.train_image_size),
            transforms.ToTensor(),
            normalize,
        ])
    val_transform = transforms.Compose([
        transforms.Resize(args.test_image_size),
        transforms.CenterCrop(args.test_crop_image_size),
        transforms.ToTensor(),
        normalize
    ])

    # Datasets
    num_classes = 10    # default 10 classes
    if args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10('./data/', train=True, download=True, transform=train_transform)
        val_dataset = datasets.CIFAR10('./data/', train=False, download=True, transform=val_transform)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100('./data/', train=True, download=True, transform=train_transform)
        val_dataset = datasets.CIFAR100('./data/', train=False, download=True, transform=val_transform)
        num_classes = 100
    elif args.dataset == 'cub':
        train_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/train/',
                                             transform=train_transform)
        val_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/test/',
                                           transform=val_transform)
        num_classes = 200
    elif args.dataset == 'webvision':
        train_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/train',
                                             transform=train_transform)
        val_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/val',
                                           transform=val_transform)
        num_classes = 1000
    else:
        raise Exception('Unknown dataset: {}'.format(args.dataset))

    # Data Loader
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchsize, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, **kwargs)

    # Create model
    if args.model == 'AlexNet':
        model = alexnet(pretrained=False, num_classes=num_classes)
    elif args.model == 'VGG':
        use_batch_normalization = True  # default use Batch Normalization
        if use_batch_normalization:
            if args.depth == 11:
                model = vgg11_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 13:
                model = vgg13_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 16:
                model = vgg16_bn(pretrained=False, num_classes=num_classes)
            elif args.depth == 19:
                model = vgg19_bn(pretrained=False, num_classes=num_classes)
            else:
                raise Exception('Unsupported VGG depth: {}, optional depths: 11, 13, 16 or 19'.format(args.depth))
        else:
            if args.depth == 11:
                model = vgg11(pretrained=False, num_classes=num_classes)
            elif args.depth == 13:
                model = vgg13(pretrained=False, num_classes=num_classes)
            elif args.depth == 16:
                model = vgg16(pretrained=False, num_classes=num_classes)
            elif args.depth == 19:
                model = vgg19(pretrained=False, num_classes=num_classes)
            else:
                raise Exception('Unsupported VGG depth: {}, optional depths: 11, 13, 16 or 19'.format(args.depth))
    elif args.model == 'Inception':
        model = inception_v3(pretrained=False, num_classes=num_classes)
    elif args.model == 'ResNet':
        if args.depth == 18:
            model = resnet18(pretrained=False, num_classes=num_classes)
        elif args.depth == 34:
            model = resnet34(pretrained=False, num_classes=num_classes)
        elif args.depth == 50:
            model = resnet50(pretrained=False, num_classes=num_classes)
        elif args.depth == 101:
            model = resnet101(pretrained=False, num_classes=num_classes)
        elif args.depth == 152:
            model = resnet152(pretrained=False, num_classes=num_classes)
        else:
            raise Exception('Unsupported ResNet depth: {}, optional depths: 18, 34, 50, 101 or 152'.format(args.depth))
    elif args.model == 'MPN-COV-ResNet':
        if args.depth == 18:
            model = mpn_cov_resnet18(pretrained=False, num_classes=num_classes)
        elif args.depth == 34:
            model = mpn_cov_resnet34(pretrained=False, num_classes=num_classes)
        elif args.depth == 50:
            model = mpn_cov_resnet50(pretrained=False, num_classes=num_classes)
        elif args.depth == 101:
            model = mpn_cov_resnet101(pretrained=False, num_classes=num_classes)
        elif args.depth == 152:
            model = mpn_cov_resnet152(pretrained=False, num_classes=num_classes)
        else:
            raise Exception('Unsupported MPN-COV-ResNet depth: {}, optional depths: 18, 34, 50, 101 or 152'.format(args.depth))
    else:
        raise Exception('Unsupported model: {}'.format(args.model))

    # Get the number of model parameters
    print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

    if torch.cuda.is_available():
        model = model.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_err1 = checkpoint['best_err1']
            model.load_state_dict(checkpoint['state_dict'])
            print("==> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        criterion = criterion.cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # Evaluate on validation set
        err1 = validate(val_loader, model, criterion, epoch)

        # Remember best err1 and save checkpoint
        is_best = (err1 <= best_err1)
        best_err1 = min(err1, best_err1)
        print("Current best accuracy (error):", best_err1)
        save_checkpoint({
            'epoch': epoch+1,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
        }, is_best)

    print("Best accuracy (error):", best_err1)
Example No. 29
def main(set_,batch_size,model_str, base,io_scale,n_scale,interaction,n,epoch, init_rate,decay, gamma, step_size,save_dir):
   
    params = Params(set_,batch_size,model_str, base,io_scale,n_scale,interaction,n,epoch, init_rate,decay, gamma, step_size)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    downscale_kern = BesselConv2d(base,n_scales=n_scale).to(device)
    n_train = 10000
    n_test = 10000  
    num_process =8  
    resize= None
    if set_ == "cifar":
        resize= ["resize",2]
    if "dss" in model_str or 'deform' in model_str or 'final' in model_str:
        down = [downscale_kern,n_scale]
    else:
        down = None
    resize= None 
    train_gen = dataset.get_gen_rand(set_,"train", batch_size,downscale_kern=down,n=n_train, resize=resize)
    test_gen = dataset.get_gen_rand(set_,"test" ,batch_size,downscale_kern=down,n=n_test, resize=resize)
    
    with mlflow.start_run():  
        output_dir = './log'

        t_list = []
        # MLflow log parameters
        for key, value in vars(params).items():
            mlflow.log_param(key, value)
        output_dir = dirpath = tempfile.mkdtemp()
        writer = SummaryWriter(output_dir)
        print("Writing TensorBoard events locally to %s\n" % output_dir)

        #load model
        if set_ == "MNIST":
            if model_str == 'dss_v1':
                model = dss.DSS_plain(base,io_scale,n,interaction)
            elif model_str =='dss_v2':
                model = dss.DSS_plain_2(base,io_scale,n,interaction)
            elif model_str == "deform":
                model = dss2.deform_dss(base,io_scale[0],n,10)
            elif model_str == "deform2":
                model = dss2.deform_dss2(base,io_scale[0],n,10)               
        elif set_ == "cifar":
            if model_str == 'dss_v1':
                model = dss.DSS_cifar(base,io_scale,n,interaction)
            elif model_str =='dss_v2':
                model = dss.DSS_2_cifar(base,io_scale,n,interaction)
            elif model_str =="cnn":
                model = vgg.vgg16()
            elif model_str == "deform":
                model = dss2.deform_dss_cifar(base,io_scale[0],n,10)
            elif model_str == "deform_res":
                model = dcn.deform_ResNet101()
        model = model.to(device)
        model = nn.DataParallel(model,device_ids=[0])
        #model.share_memory()
        # For multiprocessing
        #mp.set_start_method('spawn')
        #num_gpus = torch.cuda.device_count()
        #rank = int(os.environ['RANK'])
        #dist.init_process_group(backend='nccl')        


        print(model,flush=True)
        model_param = model.parameters()
        print("Number of parameters = ",sum([np.prod(p.size()) for p in model_param]),flush=True)
        enter = 0 
        # train/val model
        for e in range(1,epoch+1):
            print("Epoch: %d" % (e),flush=True)
            start = time.time()
            model = funcs.train_network(model,train_gen,init_rate, step_size, gamma,decay, n_train,batch_size,e,writer) 
            end = time.time()
            t_list.append(end-start)
            acc, _ = funcs.test_network(model, test_gen, n_test, batch_size,e,writer)
            if e > 3 and acc < 0.2:
                enter = 1
                break
          
        avg_time = sum(t_list)/len(t_list)

        #test model
        if enter == 0:
            for s in [1,1.1,1.5,2.0,2.5,3.0]:
                print("Test for scale by %2.3f"%(1/s), flush=True)
                if  set_ =="cifar":
                    sc_test_gen = dataset.get_gen(set_,"test", batch_size, 1/s,n_test,down, resize=resize)
                else:
                    sc_test_gen = dataset.get_gen(set_,"test", batch_size, 1/s,n_test,down)
                model = model.to(device)
                acc,_ = funcs.test_network(model, sc_test_gen, n_test, batch_size,-1,writer,val=False)
            print("Average time taken for one epoch =",str(avg_time),flush= True)
            torch.save(model, str(save_dir))
Example No. 30
                                            train=False,
                                            download=True,
                                            transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=2)
    num_classes = 100
else:
    raise "only support dataset CIFAR10 or CIFAR100"

if args.model == "lenet":
    net = lenet.LeNet(num_classes=num_classes)

elif args.model == "vgg16":
    net = vgg.vgg16(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "vgg16_bn":
    net = vgg.vgg16_bn(num_classes=num_classes, pretrained=args.pretrain)

elif args.model == "resnet18":
    net = resnet.resnet18(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "resnet34":
    net = resnet.resnet34(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "resnet50":
    net = resnet.resnet50(num_classes=num_classes, pretrained=args.pretrain)

elif args.model == "resnetv2_18":
    net = resnet_v2.resnet18(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "resnetv2_34":
    net = resnet_v2.resnet34(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "resnetv2_50":