def load_models(args):
    if args.task == 'AE':
        if args.dataset == 'mnist':
            netG = generators.MNISTgenerator(args).cuda()
            netD = discriminators.MNISTdiscriminator(args).cuda()
            netE = encoders.MNISTencoder(args).cuda()

        elif args.dataset == 'cifar10':
            netG = generators.CIFARgenerator(args).cuda()
            netD = discriminators.CIFARdiscriminator(args).cuda()
            netE = encoders.CIFARencoder(args).cuda()

    elif args.task == 'sr':
        if args.dataset == 'cifar10':
            netG = generators.genResNet(args).cuda()
            netD = discriminators.SRdiscriminatorCIFAR(args).cuda()
            vgg = vgg19_bn(pretrained=True).cuda()
            netE = VGGextraction(vgg)

        elif args.dataset == 'imagenet':
            netG = generators.genResNet(args, (3, 224, 224)).cuda(0)
            netD = None  # discriminator is not used for the ImageNet SR setup
            vgg = vgg19_bn(pretrained=True).cuda(1)
            netE = VGGextraction(vgg).cuda(1)

    print(netG, netD, netE)
    return (netG, netD, netE)
Example #2
def get_net(name):
    if name == 'densenet121':
        net = densenet121()
    elif name == 'densenet161':
        net = densenet161()
    elif name == 'densenet169':
        net = densenet169()
    elif name == 'googlenet':
        net = googlenet()
    elif name == 'inception_v3':
        net = inception_v3()
    elif name == 'mobilenet_v2':
        net = mobilenet_v2()
    elif name == 'resnet18':
        net = resnet18()
    elif name == 'resnet34':
        net = resnet34()
    elif name == 'resnet50':
        net = resnet50()
    elif name == 'resnet_orig':
        net = resnet_orig()
    elif name == 'vgg11_bn':
        net = vgg11_bn()
    elif name == 'vgg13_bn':
        net = vgg13_bn()
    elif name == 'vgg16_bn':
        net = vgg16_bn()
    elif name == 'vgg19_bn':
        net = vgg19_bn()
    else:
        print(f'{name} is not a valid model name', file=sys.stderr)
        sys.exit(1)

    return net.to(device)
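
The if/elif ladder above can be collapsed into a name-to-constructor table. A minimal sketch of that alternative, using torchvision.models constructors for self-containment (the original imports its own implementations, and only a subset of the names is listed here):

import sys

import torch
from torchvision.models import resnet18, resnet34, resnet50, vgg16_bn, vgg19_bn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# name -> constructor lookup; extend with the remaining architectures as needed
MODEL_FACTORIES = {
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
    'vgg16_bn': vgg16_bn,
    'vgg19_bn': vgg19_bn,
}

def get_net_from_table(name):
    try:
        net = MODEL_FACTORIES[name]()
    except KeyError:
        print(f'{name} is not a valid model name', file=sys.stderr)
        sys.exit(1)
    return net.to(device)
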
Example #3
def select_img_network(img_net_type, image_size, latent_len):
    if img_net_type == 'resnet50':
        from resnet import resnet50
        img_encoder = resnet50(image_size, pretrained=True)
        from resnet import deresnet50
        img_decoder = deresnet50(image_size, latent_len)
    elif img_net_type == 'vgg19bn':
        from vgg import vgg19_bn
        img_encoder = vgg19_bn(image_size, pretrained=True)
        from vgg import devgg
        img_decoder = devgg(image_size)
    elif img_net_type == 'wrn':
        from wrn import WideResNet
        img_encoder = WideResNet(image_size)
        from resnet import deresnet50
        img_decoder = deresnet50(image_size, latent_len)
    elif img_net_type == 'wiser':
        from wiser import wiser
        img_encoder = wiser()
        from resnet import deresnet50
        img_decoder = deresnet50(image_size, latent_len)
    else:
        raise ValueError('Please specify the image-channel backbone as one of resnet50/vgg19bn/wrn/wiser')

    return img_encoder, img_decoder
Example #4
    def init_net(self):

        net_args = {
            "pretrained": True,
            "n_input_channels": len(self.kwargs["static"]["imagery_bands"])
        }

        # https://pytorch.org/docs/stable/torchvision/models.html
        if self.kwargs["net"] == "resnet18":
            self.model = resnet.resnet18(**net_args)
        elif self.kwargs["net"] == "resnet34":
            self.model = resnet.resnet34(**net_args)
        elif self.kwargs["net"] == "resnet50":
            self.model = resnet.resnet50(**net_args)
        elif self.kwargs["net"] == "resnet101":
            self.model = resnet.resnet101(**net_args)
        elif self.kwargs["net"] == "resnet152":
            self.model = resnet.resnet152(**net_args)
        elif self.kwargs["net"] == "vgg11":
            self.model = vgg.vgg11(**net_args)
        elif self.kwargs["net"] == "vgg11_bn":
            self.model = vgg.vgg11_bn(**net_args)
        elif self.kwargs["net"] == "vgg13":
            self.model = vgg.vgg13(**net_args)
        elif self.kwargs["net"] == "vgg13_bn":
            self.model = vgg.vgg13_bn(**net_args)
        elif self.kwargs["net"] == "vgg16":
            self.model = vgg.vgg16(**net_args)
        elif self.kwargs["net"] == "vgg16_bn":
            self.model = vgg.vgg16_bn(**net_args)
        elif self.kwargs["net"] == "vgg19":
            self.model = vgg.vgg19(**net_args)
        elif self.kwargs["net"] == "vgg19_bn":
            self.model = vgg.vgg19_bn(**net_args)

        else:
            raise ValueError("Invalid network specified: {}".format(
                self.kwargs["net"]))

        #  run type: 1 = fine tune, 2 = fixed feature extractor
        #  - replace run type option with "# of layers to fine tune"
        if self.kwargs["run_type"] == 2:
            layer_count = len(list(self.model.parameters()))
            for layer, param in enumerate(self.model.parameters()):
                if layer <= layer_count - 5:
                    param.requires_grad = False

        # Parameters of newly constructed modules have requires_grad=True by default
        # get existing number for input features
        # set new number for output features to number of categories being classified
        # see: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
        if "resnet" in self.kwargs["net"]:
            num_ftrs = self.model.fc.in_features
            self.model.fc = nn.Linear(num_ftrs, self.ncats)
        elif "vgg" in self.kwargs["net"]:
            num_ftrs = self.model.classifier[6].in_features
            self.model.classifier[6] = nn.Linear(num_ftrs, self.ncats)
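
The head replacement at the end of init_net follows the standard torchvision fine-tuning recipe: read the existing in_features and swap in a new Linear sized to the number of categories. A stand-alone shape check, assuming a plain torchvision resnet18 and 5 target classes (both are illustrative values, not the original configuration):

import torch
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18()                       # weights omitted for the sketch
num_ftrs = model.fc.in_features          # existing number of input features
model.fc = nn.Linear(num_ftrs, 5)        # new head sized to the number of categories

with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))
print(out.shape)                         # torch.Size([2, 5])
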
Example #5
def create_vgg19bn(load_weights=False):
    vgg19_bn_ft = vgg19_bn(pretrained=True)
    #vgg19_bn_ft.classifier = nn.Linear(25088, 3)
    vgg19_bn_ft.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),
                                           nn.ReLU(True), nn.Dropout(),
                                           nn.Linear(4096, 4096),
                                           nn.ReLU(True), nn.Dropout(),
                                           nn.Linear(4096, 3))

    vgg19_bn_ft = vgg19_bn_ft.cuda()

    vgg19_bn_ft.name = 'vgg19bn'
    return vgg19_bn_ft
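
A quick way to confirm the rebuilt classifier produces three logits is to push a dummy batch through it. A minimal sketch, run on CPU and without pretrained weights purely for illustration (the snippet above uses pretrained=True and .cuda()):

import torch
import torch.nn as nn
from torchvision.models import vgg19_bn

model = vgg19_bn()
model.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),
                                 nn.ReLU(True), nn.Dropout(),
                                 nn.Linear(4096, 4096),
                                 nn.ReLU(True), nn.Dropout(),
                                 nn.Linear(4096, 3))

with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 3]) -- one logit per class
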
Example #6
def main():
    print(f"Train numbers:{len(dataset)}")

    # first train run this line
    model = vgg19_bn().to(device)
    # Load model
    # if device == 'cuda':
    #     model = torch.load(MODEL_PATH + MODEL_NAME).to(device)
    # else:
    #     model = torch.load(MODEL_PATH + MODEL_NAME, map_location='cpu')
    # loss function (criterion)
    cast = torch.nn.CrossEntropyLoss().to(device)
    # Optimization
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=1e-8)
    step = 1
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()

        # start timing this epoch
        start = time.time()

        for images, labels in dataset_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = cast(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f"Step [{step * BATCH_SIZE}/{NUM_EPOCHS * len(dataset)}], "
                  f"Loss: {loss.item():.8f}.")
            step += 1

        # elapsed time for this training epoch
        end = time.time()
        print(f"Epoch [{epoch}/{NUM_EPOCHS}], "
              f"time: {end - start} sec!")

        # Save the model checkpoint
        torch.save(model, MODEL_PATH + '/' + MODEL_NAME)
    print(f"Model save to {MODEL_PATH + '/' + MODEL_NAME}.")
Example #7
def infer():
    best_epoch = 273   # epoch numbering starts from 0
    checkpoint_file = '../save_da_vgg19_bn/checkpoint_{}.tar'.format(best_epoch)
    checkpoint = torch.load(checkpoint_file)
    print('best epoch is', checkpoint['epoch'])
    
    # import vgg
    sys.path.insert(0, '../src')
    import vgg
    
    model = vgg.vgg19_bn()
    model.features = torch.nn.DataParallel(model.features)
    state = checkpoint['state_dict']
    model.load_state_dict(state)
    print(model.parameters)
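
Wrapping model.features in DataParallel before load_state_dict matters because a checkpoint saved from a DataParallel model stores its keys with a 'module.' prefix. An alternative sketch that strips the prefix so the weights load into a plain, non-parallel model (the helper name is hypothetical):

def strip_module_prefix(state_dict):
    # drop the 'module.' inserted by nn.DataParallel from every key
    return {k.replace('module.', '', 1): v for k, v in state_dict.items()}

# model = vgg.vgg19_bn()                                  # plain model, no DataParallel
# model.load_state_dict(strip_module_prefix(checkpoint['state_dict']))
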
Example #8
def model_select(args):
    if args.usenet == "bn_alexnet":
        model = bn_alexnet(pretrained=False,
                           num_classes=args.numof_classes).to(device)
        return model
    elif args.usenet == "vgg16":
        model = vgg16_bn(pretrained=False,
                         num_classes=args.numof_classes).to(device)
        return model
    elif args.usenet == "vgg19":
        model = vgg19_bn(pretrained=False,
                         num_classes=args.numof_classes).to(device)
        return model
    elif args.usenet == "resnet18":
        model = resnet18(pretrained=False,
                         num_classes=args.numof_classes).to(device)
        return model
    elif args.usenet == "resnet34":
        model = resnet34(pretrained=False,
                         num_classes=args.numof_classes).to(device)
        return model
    elif args.usenet == "resnet50":
        model = resnet50(pretrained=False,
                         num_classes=args.numof_classes).to(device)
        return model
    elif args.usenet == "resnet101":
        model = resnet101(pretrained=False,
                          num_classes=args.numof_classes).to(device)
        return model
    elif args.usenet == "resnet152":
        model = resnet152(pretrained=False,
                          num_classes=args.numof_classes).to(device)
        return model
    elif args.usenet == "resnet200":
        model = resnet200(pretrained=False,
                          num_classes=args.numof_classes).to(device)
        return model
    elif args.usenet == "resnext101":
        model = resnext101(pretrained=False,
                           num_classes=args.numof_classes).to(device)
        return model
    elif args.usenet == "densenet161":
        model = densenet161(pretrained=False,
                            num_classes=args.numof_classes).to(device)
        return model
    raise ValueError('{} is not a supported network'.format(args.usenet))
Example #9
def create_vgg19bn(load_weights=False, freeze=False):
    vgg19_bn_ft = vgg19_bn(pretrained=True)
    if freeze:
        for param in vgg19_bn_ft.parameters():
            param.requires_grad = False
    #vgg19_bn_ft.classifier = nn.Linear(25088, 3)
    vgg19_bn_ft.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),
                                           nn.ReLU(True), nn.Dropout(),
                                           nn.Linear(4096, 4096),
                                           nn.ReLU(True), nn.Dropout(),
                                           nn.Linear(4096, 1), nn.Sigmoid())

    vgg19_bn_ft = vgg19_bn_ft.cuda()

    vgg19_bn_ft.name = 'vgg19bn'
    vgg19_bn_ft.max_num = 1
    #vgg19_bn_ft.batch_size = 32
    return vgg19_bn_ft
Example #10
    def __init__(self, w1=100, w2=0.1, w3=0.5, w4=1):
        """

        Return weighted sum of CoarseNet, EdgeNet, DetailsNet and Adversarial losses averaged over
        all losses in each mini-batch.

        :param w1: Weight of CoarseNet loss
        :param w2: Weight of EdgeNet loss
        :param w3: Weight of Local Patch loss
        :param w4: Weight of Adversarial loss
        """

        super(DetailsLoss, self).__init__()
        self.w1 = w1
        self.w2 = w2
        self.w3 = w3
        self.w4 = w4
        self.l1_loss = nn.L1Loss(reduction='mean')
        self.MSE_loss = nn.MSELoss(reduction='mean')
        self.BCE_loss = nn.BCELoss(reduction='mean')
        self.vgg19_bn = vgg19_bn(pretrained=True)
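
In this kind of loss the vgg19_bn typically acts as a fixed feature extractor for a perceptual term. A minimal sketch of that idea (not the author's forward pass), assuming features are taken from an intermediate block and compared with MSE; the layer index is illustrative:

import torch
import torch.nn as nn
from torchvision.models import vgg19_bn

class PerceptualLoss(nn.Module):
    def __init__(self, layer_index=26):
        super().__init__()
        vgg = vgg19_bn()                            # pretrained weights would be loaded in practice
        self.features = vgg.features[:layer_index]  # truncate at an intermediate conv block
        for p in self.features.parameters():
            p.requires_grad = False                 # keep the extractor frozen
        self.mse = nn.MSELoss()

    def forward(self, x, y):
        return self.mse(self.features(x), self.features(y))
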
Example #11
def main(arg_seed, arg_timestamp):
    random_seed = arg_seed
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True  # need to set to True as well

    print('Random Seed {}\n'.format(arg_seed))

    # -- training parameters
    num_epoch = args.epoch
    milestone = [50, 75]
    batch_size = args.batch
    num_workers = 2

    weight_decay = 1e-3
    gamma = 0.2
    current_delta = args.delta

    lr = args.lr
    start_epoch = 0

    # -- specify dataset
    # data augmentation
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    trainset = Animal10(split='train', transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                              worker_init_fn=_init_fn, drop_last=True)

    testset = Animal10(split='test', transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size * 4, shuffle=False, num_workers=num_workers)

    num_class = 10

    print('train data size:', len(trainset))
    print('test data size:', len(testset))

    # -- create log file
    if arg_timestamp:
        time_stamp = time.strftime("%Y%m%d-%H%M%S")
        file_name = 'Ours(' + time_stamp + ').txt'
    else:
        file_name = 'Ours.txt'

    log_dir = check_folder('logs')
    file_name = os.path.join(log_dir, file_name)
    saver = open(file_name, "w")

    saver.write(args.__repr__() + "\n\n")
    saver.flush()

    # -- set network, optimizer, scheduler, etc
    net = vgg19_bn(num_classes=num_class, pretrained=False)
    net = nn.DataParallel(net)

    optimizer = optim.SGD(net.parameters(), lr=lr, weight_decay=weight_decay)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = net.to(device)

    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=milestone, gamma=gamma)
    criterion = torch.nn.CrossEntropyLoss()

    # -- misc
    iterations = 0
    f_record = torch.zeros([args.rollWindow, len(trainset), num_class])

    for epoch in range(start_epoch, num_epoch):
        train_correct = 0
        train_loss = 0
        train_total = 0

        net.train()

        for i, (images, labels, indices) in enumerate(trainloader):
            if images.size(0) == 1:  # when batch size equals 1, skip, due to batch normalization
                continue

            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_total += images.size(0)
            _, predicted = outputs.max(1)
            train_correct += predicted.eq(labels).sum().item()

            f_record[epoch % args.rollWindow, indices] = F.softmax(outputs.detach().cpu(), dim=1)

            iterations += 1
            if iterations % 100 == 0:
                cur_train_acc = train_correct / train_total * 100.
                cur_train_loss = train_loss / train_total
                cprint('epoch: {}\titerations: {}\tcurrent train accuracy: {:.4f}\ttrain loss:{:.4f}'.format(
                    epoch, iterations, cur_train_acc, cur_train_loss), 'yellow')

                if iterations % 5000 == 0:
                    saver.write('epoch: {}\titerations: {}\ttrain accuracy: {}\ttrain loss: {}\n'.format(
                        epoch, iterations, cur_train_acc, cur_train_loss))
                    saver.flush()

        train_acc = train_correct / train_total * 100.

        cprint('epoch: {}'.format(epoch), 'yellow')
        cprint('train accuracy: {:.4f}\ntrain loss: {:.4f}'.format(train_acc, train_loss), 'yellow')
        saver.write('epoch: {}\ntrain accuracy: {}\ntrain loss: {}\n'.format(epoch, train_acc, train_loss))
        saver.flush()

        exp_lr_scheduler.step()

        if epoch >= args.warm_up:
            f_x = f_record.mean(0)
            y_tilde = trainset.targets

            y_corrected, current_delta = lrt_correction(y_tilde, f_x, current_delta=current_delta, delta_increment=0.1)

            logging.info('Current delta:\t{}\n'.format(current_delta))

            trainset.update_corrupted_label(y_corrected)

        # testing
        net.eval()
        test_total = 0
        test_correct = 0
        with torch.no_grad():
            for i, (images, labels, _) in enumerate(testloader):
                images, labels = images.to(device), labels.to(device)

                outputs = net(images)

                test_total += images.size(0)
                _, predicted = outputs.max(1)
                test_correct += predicted.eq(labels).sum().item()

            test_acc = test_correct / test_total * 100.

        cprint('>> current test accuracy: {:.4f}'.format(test_acc), 'cyan')

        saver.write('>> current test accuracy: {}\n'.format(test_acc))
        saver.flush()

    saver.close()
Example #12
    batch_eval = model.batch_iter(x_test, y_test)
    total_acc = 0.0
    data_len = len(x_test)
    for x_batch, y_batch in batch_eval:
        batch_len = len(x_batch)
        outputs = net(x_batch)
        _, prediction = torch.max(outputs.data, 1)
        correct = (prediction == y_batch).sum().item()
        acc = correct / batch_len
        total_acc += acc * batch_len
    return total_acc / data_len


#### Model Training Configs ####
# cnn = Net().to(device)
net = vgg19_bn().to(device)
optimizer = Adam(net.parameters(), lr=0.001,
                 betas=(0.9, 0.999))  # use the Adam optimizer
loss_fn = nn.CrossEntropyLoss()  # define the loss function

#### Training ####
best_accuracy = 0
for i in range(args.EPOCHS):
    net.train()
    x_train, y_train, x_test, y_test = dataset.next_batch(args.BATCH)  # fetch the next batch of data

    x_train = torch.from_numpy(x_train)
    y_train = torch.from_numpy(y_train)
    x_train = x_train.float().to(device)
    y_train = y_train.long().to(device)
Example #13
    dataset = Rand_num(csv_path, img_path, 224, None)
    validationset = Rand_num(validation_label, validation_data, 224, None)
    sampler = RandomSampler(dataset)
    val_sampler = RandomSampler(validationset)
    loader = DataLoader(dataset, batch_size = batch_size, sampler = sampler, shuffle = False, num_workers=2)
    val_loader = DataLoader(validationset, batch_size = batch_size, sampler = val_sampler, shuffle = False, num_workers=2)
    print(datetime.datetime.now())
    print('dataset loading complete')

#    dataiter = iter(loader)
#    images, labels = dataiter.next()
#    print (images)
#    images=tensor_to_img(images)
#    print (labels)
#    print (images)
    vgg_model = vgg.vgg19_bn(num_classes=7*7*10)
    net = Net(batch_size)
    if load_checkpoint:
        vgg_model.load_state_dict(torch.load(SAVE_PATH))
    print('network loaded')

    net.cuda()
    vgg_model.cuda()

    optimizer = optim.Adam(vgg_model.parameters(), lr=0.0001)
    #optimizer = optim.SGD(vgg_model.parameters(), lr=0.01)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True, threshold=0.000001, eps=1e-16)
    for epoch in range(2000):
        for i, data in enumerate(loader, 0):
            # get the inputs
            inputs, labels = data
Example #14
    if model_name == 'vgg11':
        model = vgg.vgg11(pretrained=pretrain_check)
    elif model_name == 'vgg11_bn':
        model = vgg.vgg11_bn(pretrained=pretrain_check)
    elif model_name == 'vgg13':
        model = vgg.vgg13(pretrained=pretrain_check)
    elif model_name == 'vgg13_bn':
        model = vgg.vgg13_bn(pretrained=pretrain_check)
    elif model_name == 'vgg16':
        model = vgg.vgg16(pretrained=pretrain_check)
    elif model_name == 'vgg16_bn':
        model = vgg.vgg16_bn(pretrained=pretrain_check)
    elif model_name == 'vgg19':
        model = vgg.vgg19(pretrained=pretrain_check)
    elif model_name == 'vgg19_bn':
        model = vgg.vgg19_bn(pretrained=pretrain_check)
    model.eval()
    model = torch.nn.DataParallel(model).cuda()

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(optimizer,
                                  factor=0.01,
                                  patience=patience,
                                  mode='min')

    criterion = nn.CrossEntropyLoss()
Example #15
    def __init__(self):
        super(Loss, self).__init__()
        self.l1_loss = nn.L1Loss()
        self.l2_loss = nn.MSELoss()
        self.vgg = vgg19_bn(pretrained=False)
Example #16
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=args.train_bs,
                                          shuffle=True,
                                          num_workers=3)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=2000,
                                         shuffle=False,
                                         num_workers=3)
truncloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=args.prune_bs,
                                          num_workers=3)

if args.vgg_type == 'vgg16':
    model = vgg.vgg16_bn(num_classes=N_CLASSES)
elif args.vgg_type == 'vgg19':
    model = vgg.vgg19_bn(num_classes=N_CLASSES)
model = model.to(device)

model.train()
x, y = map(lambda x: x.to(device), next(iter(trainloader)))
p = model(x)
loss = F.cross_entropy(p, y)
loss.backward()

agg_tensor = []
for child in model.modules():
    if isinstance(child, vgg.MaskedConv2d) or isinstance(
            child, vgg.MaskedLinear):
        agg_tensor += child.get_mask_grad()

agg_tensor = torch.cat(agg_tensor, dim=0).cpu().numpy()
Example #17
# training data loader; settings assumed to mirror the test loader below
loader_train = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='../data', train=True, transform=transforms.Compose([
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]), download=True),
        batch_size=128, shuffle=True, pin_memory=True)

loader_test = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=128, shuffle=False, pin_memory=True)


# Load the pretrained model
net = vgg19_bn()

# for m in net.modules():
#     if isinstance(m, nn.Conv2d):
#         m.set_mask(torch.rand((2,3,4)))
#print('ok')

if torch.cuda.is_available():
    print('CUDA enabled.')
    net.cuda()
print("--- Pretrained network loaded ---")

criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.RMSprop(net.parameters(), lr=param['learning_rate'],
#                                 weight_decay=param['weight_decay'])
optimizer = torch.optim.SGD(net.parameters(), param['learning_rate'],
                            weight_decay=param['weight_decay'])  # weight_decay assumed from the RMSprop line above
Example #18
def create_lowrank_model(orig_model):
    lrm = vgg.vgg19_bn()
    lrm.features = lowrankify(lrm.features, args['K_vals'])
    approx_lowrank_weights(orig_model.features, lrm.features)
    return lrm
Example #19
def get_model(args):
    network = args.network

    if network == 'vgg11':
        model = vgg.vgg11(num_classes=args.class_num)
    elif network == 'vgg13':
        model = vgg.vgg13(num_classes=args.class_num)
    elif network == 'vgg16':
        model = vgg.vgg16(num_classes=args.class_num)
    elif network == 'vgg19':
        model = vgg.vgg19(num_classes=args.class_num)
    elif network == 'vgg11_bn':
        model = vgg.vgg11_bn(num_classes=args.class_num)
    elif network == 'vgg13_bn':
        model = vgg.vgg13_bn(num_classes=args.class_num)
    elif network == 'vgg16_bn':
        model = vgg.vgg16_bn(num_classes=args.class_num)
    elif network == 'vgg19_bn':
        model = vgg.vgg19_bn(num_classes=args.class_num)
    elif network == 'resnet18':
        model = models.resnet18(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=(model.conv1.bias is not None))
    elif network == 'resnet34':
        model = models.resnet34(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=(model.conv1.bias is not None))
    elif network == 'resnet50':
        model = models.resnet50(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=(model.conv1.bias is not None))
    elif network == 'resnet101':
        model = models.resnet101(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=(model.conv1.bias is not None))
    elif network == 'resnet152':
        model = models.resnet152(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=(model.conv1.bias is not None))
    elif network == 'densenet121':
        model = densenet.densenet121(num_classes=args.class_num)
    elif network == 'densenet169':
        model = densenet.densenet169(num_classes=args.class_num)
    elif network == 'densenet161':
        model = densenet.densenet161(num_classes=args.class_num)
    elif network == 'densenet201':
        model = densenet.densenet201(num_classes=args.class_num)
    else:
        raise ValueError('{} is not a supported network'.format(network))

    return model
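
The repeated conv1 replacement for single-channel input can be factored into one helper. A minimal sketch, assuming a torchvision-style resnet whose first layer is named conv1:

import torch.nn as nn
from torchvision import models

def adapt_first_conv(model, in_channels=1):
    # rebuild conv1 with a new number of input planes, keeping the other settings
    old = model.conv1
    model.conv1 = nn.Conv2d(in_channels,
                            old.out_channels,
                            kernel_size=old.kernel_size,
                            stride=old.stride,
                            padding=old.padding,
                            bias=old.bias is not None)
    return model

model = adapt_first_conv(models.resnet18(num_classes=10), in_channels=1)
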