Code Example #1
    def __init__(self, pathModel, nnArchitecture, nnClassCount, transCrop):

        #---- Initialize the network
        if nnArchitecture == 'DENSE-NET-121':
            model = densenet121(False).cuda()
        elif nnArchitecture == 'DENSE-NET-169':
            model = densenet169(False).cuda()
        elif nnArchitecture == 'DENSE-NET-201':
            model = densenet201(False).cuda()

        model = torch.nn.DataParallel(model).cuda()

        modelCheckpoint = torch.load(pathModel)
        model.load_state_dict(modelCheckpoint['best_model_wts'], strict=False)

        self.model = model.module.features
        self.model.eval()

        #---- Initialize the weights
        self.weights = list(self.model.parameters())[-2]

        #---- Initialize the image transform - resize + normalize
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        transformList = []
        transformList.append(transforms.Resize(transCrop))
        transformList.append(transforms.ToTensor())
        transformList.append(normalize)

        self.transformSequence = transforms.Compose(transformList)
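
The constructor above only prepares the feature extractor, the final-layer weights, and the input transform; producing the heatmap itself happens elsewhere in the class. A minimal sketch of that step, assuming `weights` is the per-channel weight vector stored above (`make_cam` is illustrative, not the project's actual method):

import torch
import torch.nn.functional as F

def make_cam(features, weights, size):
    # features: (1, C, H, W) output of self.model; weights: (C,) channel weights
    cam = (features * weights.view(1, -1, 1, 1)).sum(dim=1, keepdim=True)
    cam = F.interpolate(cam, size=size, mode='bilinear', align_corners=False)
    cam = cam - cam.min()                      # normalize to [0, 1]
    return cam / (cam.max() + 1e-8)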
Code Example #2
def get_net(name):
    if name == 'densenet121':
        net = densenet121()
    elif name == 'densenet161':
        net = densenet161()
    elif name == 'densenet169':
        net = densenet169()
    elif name == 'googlenet':
        net = googlenet()
    elif name == 'inception_v3':
        net = inception_v3()
    elif name == 'mobilenet_v2':
        net = mobilenet_v2()
    elif name == 'resnet18':
        net = resnet18()
    elif name == 'resnet34':
        net = resnet34()
    elif name == 'resnet50':
        net = resnet50()
    elif name == 'resnet_orig':
        net = resnet_orig()
    elif name == 'vgg11_bn':
        net = vgg11_bn()
    elif name == 'vgg13_bn':
        net = vgg13_bn()
    elif name == 'vgg16_bn':
        net = vgg16_bn()
    elif name == 'vgg19_bn':
        net = vgg19_bn()
    else:
        print(f'{name} is not a valid model name')
        sys.exit(1)

    return net.to(device)
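
For reference, a typical call site for the factory above (a sketch; it assumes the torchvision constructors, `sys`, and the module-level `device` that get_net already relies on are in scope):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = get_net('densenet169')                      # the factory moves the model to `device`
print(sum(p.numel() for p in net.parameters()))   # rough parameter-count check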
Code Example #3
def load_model():
    model_path = os.path.join(
        os.path.dirname(
            os.path.abspath(__file__)), 'models/model18.pth'
        )
    # print(model_path)
    model = densenet169()
    if torch.cuda.is_available():
        model = model.cuda()
    model.load_state_dict(
        torch.load(model_path, map_location=lambda storage, loc: storage)['weights']
    )
    return model
Code Example #4
    def __init__(self, densenet_path, resnet_path, vgg_path):
        super(Ensemble, self).__init__()

        self.densenet = densenet169(pretrained=True, droprate=0)
        self.densenet.load_state_dict(torch.load(densenet_path))

        self.resnet = resnet101()
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, 1)
        self.resnet.load_state_dict(torch.load(resnet_path))

        self.vgg = vgg16_bn()
        self.vgg.classifier[6] = nn.Linear(4096, 1)
        self.vgg.load_state_dict(torch.load(vgg_path))
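
The snippet shows only the constructor; a forward pass that averages the three single-logit heads might look like the following sketch (assuming each backbone ends in one output unit, as the Linear(..., 1) replacements suggest):

    def forward(self, x):
        # Average the three sigmoid probabilities into a single ensemble score.
        p_dense = torch.sigmoid(self.densenet(x))
        p_res = torch.sigmoid(self.resnet(x))
        p_vgg = torch.sigmoid(self.vgg(x))
        return (p_dense + p_res + p_vgg) / 3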
Code Example #5
File: main.py  Project: mrkovaliv/X-Ray
print('Wt1 valid:', Wt1['valid'])


class Loss(torch.nn.modules.Module):
    def __init__(self, Wt1, Wt0):
        super(Loss, self).__init__()
        self.Wt1 = Wt1
        self.Wt0 = Wt0

    def forward(self, inputs, targets, phase):
        loss = -(self.Wt1[phase] * targets * inputs.log() + self.Wt0[phase] *
                 (1 - targets) * (1 - inputs).log())
        return loss


model = densenet169(pretrained=True)
model = model.cuda()

criterion = Loss(Wt1, Wt0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       mode='min',
                                                       patience=1,
                                                       verbose=True)

# #### Train model
model = train_model(model,
                    criterion,
                    optimizer,
                    dataloaders,
                    scheduler,
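
Note that this Loss expects probabilities (already passed through a sigmoid) rather than raw logits. A quick numeric check with made-up weights (illustrative values, not the ones computed from the dataset):

probs = torch.tensor([0.9, 0.2])       # post-sigmoid model outputs
targets = torch.tensor([1.0, 0.0])
check = Loss({'train': 0.6}, {'train': 0.4})
print(check(probs, targets, 'train'))  # ≈ tensor([0.0632, 0.0893])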
Code Example #6
    extClassifier.load_state_dict(best_classifier_wts)
    model_serial = str(datetime.now().timestamp())
    # torch.save(model.state_dict(),
    #            os.path.join(r'C:\Users\wzuo\Developer\ML for APT\models', model_serial + '.model'))
    # torch.save(patchModel.state_dict(),
    #            os.path.join(r'C:\Users\wzuo\Developer\ML for APT\models', model_serial + '.patchModel'))
    torch.save(extClassifier.state_dict(),
               os.path.join(model_path, model_serial + '.clsmodel'))
    with open(os.path.join(model_path, model_serial + '.json'), 'w') as fp:
        json.dump(param_dict, fp)

    return extClassifier


model_global = MD.densenet201(pretrained=True)
model_local = MD.densenet169(pretrained=True)

param_dict['model_global'] = 'densenet201'
param_dict['model_local'] = 'densenet169'

# todo change architecture from here
#patch_model = MD.densenet121(pretrained=True)
#patch_model = MD.SimpleNet()
#param_dict['patch_base'] = 'densenet121'

for param in model_global.parameters():
    param.requires_grad = False

for param in model_local.parameters():
    param.requires_grad = False
# Parameters of newly constructed modules have requires_grad=True by default
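
As the final comment notes, parameters of newly constructed modules default to requires_grad=True, so a new head added on top of the frozen backbones is the only part the optimizer updates. A sketch of that step (the 2-class head and learning rate are assumptions for illustration, not the project's configuration):

model_global.classifier = nn.Linear(model_global.classifier.in_features, 2)
trainable = [p for p in model_global.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)  # only the new head is trained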
Code Example #7
    mask = mask - np.min(mask)
    mask = mask / np.max(mask)

    mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    cv2.imwrite(save_dir, np.uint8(255 * cam))


if __name__ == "__main__":
    img_dir = "image2.png"
    input_shapes = (3, 320, 320)
    model = densenet169(input_shapes=input_shapes, num_classes=1)

    model.load_state_dict(torch.load('model.pth', map_location='cpu'))
    model.cuda()
    model.eval()

    _transform = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    img = default_loader(img_dir)
    img = _transform(img).unsqueeze(0).cuda()
    cls_score, cam, grad_cams = model(img)
    print(cls_score)
Code Example #8
def main():
    torch.manual_seed(23)
    # Band_num = 2
    # Tag_id = 4
    data_l = data_loader_(batch_size=64, proportion=0.85, shuffle=True, data_add=2, onehot=False, data_size=224, nb_classes=100)
    print(data_l.train_length)
    print(data_l.test_length)
    # print 'loading....'
    # trX = np.load('bddog/trX.npy')
    # trY = np.load('bddog/trY.npy')
    # print 'load train data'
    # trX = torch.from_numpy(trX).float()
    # trY = torch.from_numpy(trY).long()
    # teX = np.load('bddog/teX.npy').astype(np.float)
    # teY = np.load('bddog/teY.npy')
    # print 'load test data'
    # teX[:, 0, ...] -= MEAN_VALUE[0]
    # teX[:, 1, ...] -= MEAN_VALUE[1]
    # teX[:, 2, ...] -= MEAN_VALUE[2]
    # teX = torch.from_numpy(teX).float()
    # teY = torch.from_numpy(teY).long()
    # print 'numpy data to tensor'
    # n_examples = len(trX)
    # n_classes = 100
    # model = torch.load('models/resnet_model_pretrained_adam_2_2_SGD_1.pkl')
    model = densenet169(pretrained=True)
    print('===============================')
    print(model)
    # for param in model.parameters():
    #     param.requires_grad = False
    # model.classifier[-1] = nn.Linear(4096, 100)
    # n = model.classifier[-1].weight.size(1)
    # model.classifier[-1].weight.data.normal_(0, 0.01)
    # model.classifier[-1].bias.data.zero_()

    # VGG16 classifier layer
    # model.classifier = nn.Sequential(
    #     nn.Linear(512 * 7 * 7, 4096),
    #     nn.ReLU(inplace=True),
    #     nn.Dropout(),
    #     nn.Linear(4096, 4096),
    #     nn.ReLU(inplace=True),
    #     nn.Dropout(),
    #     nn.Linear(4096, 100),
    # )
    # count = 0
    # print '==============================='
    # for module in model.modules():
    #     print '**** %d' % count
    #     print(module)
    #     count+=1
    # print '==============================='
    # count= 0
    # model.classifier[6] = nn.Linear(4096, 100)
    # for m in model.classifier:
    #     if count == 6:
    #         m = nn.Linear(4096, 100)
    #         if isinstance(m, nn.Conv2d):
    #             n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    #             m.weight.data.normal_(0, math.sqrt(2. / n))
    #             if m.bias is not None:
    #                 m.bias.data.zero_()
    #         elif isinstance(m, nn.BatchNorm2d):
    #             m.weight.data.fill_(1)
    #             m.bias.data.zero_()
    #         elif isinstance(m, nn.Linear):
    #             n = m.weight.size(1)
    #             m.weight.data.normal_(0, 0.01)
    #             m.bias.data.zero_()
    #     count+=1
    # try:
    #     print model.classifier[0]
    # except Exception as e:
    #     print e

    # print '==============================='
    # for module in model.modules()[-7:]:
    #     print '****'
    #     print(module)
    # resnet50 FC layer
    # model.group1 = nn.Sequential(
    #     OrderedDict([
    #         ('fc', nn.Linear(2048, 100))
    #     ])
    # )
    model.classifier = nn.Linear(model.classifier.in_features, 100)  # densenet169 feature size is 1664
    # ignored_params = list(map(id, model.group2.parameters()))
    # base_params = filter(lambda p: id(p) not in ignored_params,
    #                      model.parameters())
    # print '==============================='
    # print model
    model = model.cuda()
    loss = torch.nn.CrossEntropyLoss(reduction='mean')
    loss = loss.cuda()
    # fine-tune only part of the network
    # optimizer = optim.SGD(model.group2.parameters(), lr=(1e-03), momentum=0.9,weight_decay=0.001)
    # optimizer = optim.Adam([{'params':model.layer4[2].parameters()},
    #                         {'params':model.group2.parameters()}
    #                         ],lr=(1e-04),eps=1e-08, betas=(0.9, 0.999), weight_decay=0.0005)
    # optimizer_a = optim.Adam([{'params':model.group2.parameters()}
    #                         ],lr=(1e-04))

    # optimizer = optim.Adam(model.group1.parameters(),lr=(1e-04))

    # optimizer.lr = (1e-04)
    # print optimizer.lr
    # print optimizer.momentum
    # for param_group in optimizer.param_groups:
    #     print param_group['lr']
    # fine-tune the whole network
    optimizer = optim.SGD(model.parameters(), lr=(0.001), momentum=0.9, weight_decay=0.0005)
    batch_size = data_l.batch_szie
    data_aug_num = data_l.data_add
    mini_batch_size = batch_size // data_aug_num
    epochs = 1000
    print('1')
    for e in range(epochs):
        cost = 0.0
        train_acc = 0.0
        if e == 4:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.3
        if e == 8:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.3


        num_batches_train = data_l.train_length // mini_batch_size
        print(num_batches_train)
        train_acc = 0.0
        cost = 0.0
        k = 1
        for k in range(num_batches_train+1):
            batch_train_data_X, batch_train_data_Y = data_l.get_train_data()
            batch_train_data_X = batch_train_data_X.transpose(0, 3, 1, 2)
            # batch_train_data_X[:, 0, ...] -= MEAN_VALUE[0]
            # batch_train_data_X[:, 1, ...] -= MEAN_VALUE[1]
            # batch_train_data_X[:, 2, ...] -= MEAN_VALUE[2]
            # print batch_train_data_X.shape
            # print batch_train_data_Y.shape
            # batch_train_data_X = preprocess_input(batch_train_data_X)
            torch_batch_train_data_X = torch.from_numpy(batch_train_data_X).float()
            torch_batch_train_data_Y = torch.from_numpy(batch_train_data_Y).long()
            cost_temp, acc_temp = train(model, loss, optimizer, torch_batch_train_data_X, torch_batch_train_data_Y)
            train_acc += acc_temp
            cost += cost_temp
            if (k + 1) % 10 == 0:
                print('now step train loss is : %f' % cost_temp)
                print('now step train acc is : %f' % acc_temp)
            if (k + 1) % 20 == 0:
                print('all average train loss is : %f' % (cost / (k + 1)))
                print('all average train acc is : %f' % (train_acc / (k + 1)))
            # if (k + 1) % 100 == 0:
            #     model.training = False
            #     acc = 0.0
            #     num_batches_test = data_l.test_length / batch_size
            #     for j in range(num_batches_test+1):
            #         teX, teY = data_l.get_test_data()
            #         teX = teX.transpose(0, 3, 1, 2)
            #         # teX[:, 0, ...] -= MEAN_VALUE[0]
            #         # teX[:, 1, ...] -= MEAN_VALUE[1]
            #         # teX[:, 2, ...] -= MEAN_VALUE[2]
            #         teX = preprocess_input(teX)
            #         teX = torch.from_numpy(teX).float()
            #         # teY = torch.from_numpy(teY).long()
            #         predY = predict(model, teX)
            #         # print predY.dtype
            #         # print teY[start:end]
            #         acc += 1. * np.mean(predY == teY)
            #         # print ('Epoch %d ,Step %d, acc = %.2f%%'%(e,k,100.*np.mean(predY==teY[start:end])))
            #     model.training = True
            #     print 'Epoch %d ,Step %d, all test acc is : %f' % (e, k, acc / num_batches_test)
            #     torch.save(model, 'models/inception_model_pretrained_%s_%s_%s_1.pkl' % ('SGD', str(e), str(k)))
        # model.training = False
        acc = 0.0
        num_batches_test = data_l.test_length // batch_size
        for j in range(num_batches_test+1):
            teX, teY = data_l.get_test_data()
            teX = teX.transpose(0, 3, 1, 2)
            # teX[:, 0, ...] -= 0.5
            # teX[:, 1, ...] -= 0.5
            # teX[:, 2, ...] -= 0.5
            # teX = preprocess_input(teX)
            teX = torch.from_numpy(teX).float()
            # teY = torch.from_numpy(teY).long()
            predY = predict(model, teX)
            # print predY.dtype
            # print teY[start:end]
            acc += 1. * np.mean(predY == teY)
            # print ('Epoch %d ,Step %d, acc = %.2f%%'%(e,k,100.*np.mean(predY==teY[start:end])))
        # model.training = True
        print('Epoch %d ,Step %d, all test acc is : %f' % (e, k, acc / num_batches_test))
        torch.save(model, 'models/densenet161_model_pretrained_%s_%s_%s_4.pkl' % ('SGD', str(e), str(k)))
    print('train over')
Code Example #9
    data_cat = ['train', 'valid']  # data categories
    dataloaders = get_dataloaders(study_data, batch_size)
    dataset_sizes = {x: len(study_data[x]) for x in data_cat}

    # tai = total abnormal images, tni = total normal images
    tai = {x: get_count(study_data[x], 'positive') for x in data_cat}
    tni = {x: get_count(study_data[x], 'negative') for x in data_cat}

    # Find the weights of abnormal images and normal images
    Wt1 = {x: (tni[x] / (tni[x] + tai[x])) for x in data_cat}
    Wt0 = {x: (tai[x] / (tni[x] + tai[x])) for x in data_cat}
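    # Worked example with illustrative counts (not the actual dataset): if a split
    # has tai = 800 abnormal and tni = 1200 normal images, then
    # Wt1 = 1200 / 2000 = 0.6 and Wt0 = 800 / 2000 = 0.4, so the rarer class
    # receives the larger weight when these are passed to the Loss in Code Example #5.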

    # For training & testing individual models
    if model_type != 'ensemble':
        if model_type == "dense":
            model = densenet169(pretrained=True, droprate=droprate)

        elif model_type == "vgg":
            model = vgg16_bn(pretrained=True)
            model.classifier[6] = nn.Linear(4096, 1)

        elif model_type == "shufflenet":
            model = shufflenet_v2_x1_0()
            num_ftrs = model.fc.in_features
            model.fc = nn.Linear(num_ftrs, 1)

        else:
            model = resnet101()
            num_ftrs = model.fc.in_features
            model.fc = nn.Linear(num_ftrs, 1)
Code Example #10

torch_dataset_train = data.TensorDataset(train_data, train_label)
torch_dataset_val = data.TensorDataset(val_data, val_label)
train_loader = data.DataLoader(dataset=torch_dataset_train,
                               batch_size=BATCH_SIZE,
                               shuffle=True,
                               drop_last=True)
val_loader = data.DataLoader(dataset=torch_dataset_val,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             drop_last=True)
test_loader = data.DataLoader(test_data,
                              batch_size=test_num,
                              shuffle=False,
                              drop_last=True)

model = densenet169()
model = ACSConverter(model)
model = model.cuda()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# train and evaluate
for epoch in range(NUM_EPOCHS):
    print(epoch)
    train_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    for step, (batch_x, batch_y) in enumerate(train_loader):
Code Example #11
def generate_model(opt):
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from resnet import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnet.resnet10(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 18:
            model = resnet.resnet18(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 34:
            model = resnet.resnet34(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 50:
            model = resnet.resnet50(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = resnet.resnet101(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = resnet.resnet152(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
        elif opt.model_depth == 200:
            model = resnet.resnet200(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        if opt.model_depth == 50:
            model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                         shortcut_type=opt.resnet_shortcut,
                                         k=opt.wide_resnet_k,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        if opt.model_depth == 50:
            model = resnext.resnet50(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     cardinality=opt.resnext_cardinality,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = resnext.resnet101(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = resnext.resnet152(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        if opt.model_depth == 18:
            model = pre_act_resnet.resnet18(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 34:
            model = pre_act_resnet.resnet34(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 50:
            model = pre_act_resnet.resnet50(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = pre_act_resnet.resnet101(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = pre_act_resnet.resnet152(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 200:
            model = pre_act_resnet.resnet200(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        if opt.model_depth == 121:
            model = densenet.densenet121(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
        elif opt.model_depth == 169:
            model = densenet.densenet169(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
        elif opt.model_depth == 201:
            model = densenet.densenet201(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
        elif opt.model_depth == 264:
            model = densenet.densenet264(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
Code Example #12
def get_model(args):
    network = args.network

    if network == 'vgg11':
        model = vgg.vgg11(num_classes=args.class_num)
    elif network == 'vgg13':
        model = vgg.vgg13(num_classes=args.class_num)
    elif network == 'vgg16':
        model = vgg.vgg16(num_classes=args.class_num)
    elif network == 'vgg19':
        model = vgg.vgg19(num_classes=args.class_num)
    elif network == 'vgg11_bn':
        model = vgg.vgg11_bn(num_classes=args.class_num)
    elif network == 'vgg13_bn':
        model = vgg.vgg13_bn(num_classes=args.class_num)
    elif network == 'vgg16_bn':
        model = vgg.vgg16_bn(num_classes=args.class_num)
    elif network == 'vgg19_bn':
        model = vgg.vgg19_bn(num_classes=args.class_num)
    elif network == 'resnet18':
        model = models.resnet18(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=model.conv1.bias)
    elif network == 'resnet34':
        model = models.resnet34(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=model.conv1.bias)
    elif network == 'resnet50':
        model = models.resnet50(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=model.conv1.bias)
    elif network == 'resnet101':
        model = models.resnet101(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=model.conv1.bias)
    elif network == 'resnet152':
        model = models.resnet152(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=model.conv1.bias)
    elif network == 'densenet121':
        model = densenet.densenet121(num_classes=args.class_num)
    elif network == 'densenet169':
        model = densenet.densenet169(num_classes=args.class_num)
    elif network == 'densenet161':
        model = densenet.densenet161(num_classes=args.class_num)
    elif network == 'densenet201':
        model = densenet.densenet201(num_classes=args.class_num)
    else:
        raise ValueError(f'unknown network: {network}')

    return model
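
A minimal call site for the factory above (the argparse.Namespace stand-in is only for illustration; the real script builds args from its own parser):

import argparse

args = argparse.Namespace(network='densenet169', class_num=2)
model = get_model(args)
print(model.classifier.out_features)  # 2, assuming a torchvision-style densenet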
Code Example #13
def generate_model(opt):
    assert opt.mode in ['score', 'feature']
    if opt.mode == 'score':
        last_fc = True
    elif opt.mode == 'feature':
        last_fc = False

    assert opt.model_name in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    if opt.model_name == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        if opt.model_depth == 10:
            model = resnet.resnet10(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration,
                                    last_fc=last_fc)
        elif opt.model_depth == 18:
            model = resnet.resnet18(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration,
                                    last_fc=last_fc)
        elif opt.model_depth == 34:
            model = resnet.resnet34(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration,
                                    last_fc=last_fc)
        elif opt.model_depth == 50:
            model = resnet.resnet50(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration,
                                    last_fc=last_fc)
        elif opt.model_depth == 101:
            model = resnet.resnet101(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration,
                                     last_fc=last_fc)
        elif opt.model_depth == 152:
            model = resnet.resnet152(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration,
                                     last_fc=last_fc)
        elif opt.model_depth == 200:
            model = resnet.resnet200(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration,
                                     last_fc=last_fc)
    elif opt.model_name == 'wideresnet':
        assert opt.model_depth in [50]

        if opt.model_depth == 50:
            model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                         shortcut_type=opt.resnet_shortcut,
                                         k=opt.wide_resnet_k,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration,
                                         last_fc=last_fc)
    elif opt.model_name == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        if opt.model_depth == 50:
            model = resnext.resnet50(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     cardinality=opt.resnext_cardinality,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration,
                                     last_fc=last_fc)
        elif opt.model_depth == 101:
            model = resnext.resnet101(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration,
                                      last_fc=last_fc)
        elif opt.model_depth == 152:
            model = resnext.resnet152(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration,
                                      last_fc=last_fc)
    elif opt.model_name == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        if opt.model_depth == 18:
            model = pre_act_resnet.resnet18(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                last_fc=last_fc)
        elif opt.model_depth == 34:
            model = pre_act_resnet.resnet34(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                last_fc=last_fc)
        elif opt.model_depth == 50:
            model = pre_act_resnet.resnet50(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                last_fc=last_fc)
        elif opt.model_depth == 101:
            model = pre_act_resnet.resnet101(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                last_fc=last_fc)
        elif opt.model_depth == 152:
            model = pre_act_resnet.resnet152(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                last_fc=last_fc)
        elif opt.model_depth == 200:
            model = pre_act_resnet.resnet200(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                last_fc=last_fc)
    elif opt.model_name == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        if opt.model_depth == 121:
            model = densenet.densenet121(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration,
                                         last_fc=last_fc)
        elif opt.model_depth == 169:
            model = densenet.densenet169(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration,
                                         last_fc=last_fc)
        elif opt.model_depth == 201:
            model = densenet.densenet201(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration,
                                         last_fc=last_fc)
        elif opt.model_depth == 264:
            model = densenet.densenet264(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration,
                                         last_fc=last_fc)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    return model
Code Example #14
File: model.py  Project: whudonggu/classification
    def inference(self, inputs, is_training=True, reuse=False):
        """
        Forward pass of the network; returns the logits tensor. keep_prob is the dropout keep probability and is set to 1 at inference time.
        """
        if self.model == "vgg16":
            if self.scope is None:
                self.scope = 'vgg_16'

            logits = vgg16.vgg_16(inputs=inputs,
                                  num_classes=self.n_classes,
                                  is_training=is_training,
                                  reuse=reuse,
                                  dropout_keep_prob=self.dropprob,
                                  scope=self.scope,
                                  weight_decay=self.l2_rate,
                                  use_batch_norm=self.use_bn,
                                  batch_norm_decay=self.bn_decay,
                                  batch_norm_epsilon=self.bn_epsilon,
                                  batch_norm_scale=self.bn_scale)
        elif self.model == "res50":
            if self.scope is None:
                self.scope = 'resnet_v1_50'
            logits = resnet.resnet_50(inputs=inputs,
                                      num_classes=self.n_classes,
                                      is_training=is_training,
                                      reuse=reuse,
                                      use_se_module=False,
                                      scope=self.scope,
                                      weight_decay=self.l2_rate,
                                      use_batch_norm=self.use_bn,
                                      batch_norm_decay=self.bn_decay,
                                      batch_norm_epsilon=self.bn_epsilon,
                                      batch_norm_scale=self.bn_scale)
        elif self.model == "res50_senet":
            if self.scope is None:
                self.scope = 'resnet_v1_50'
            logits = resnet.resnet_50(inputs=inputs,
                                      num_classes=self.n_classes,
                                      is_training=is_training,
                                      reuse=reuse,
                                      use_se_module=True,
                                      scope=self.scope,
                                      weight_decay=self.l2_rate,
                                      use_batch_norm=self.use_bn,
                                      batch_norm_decay=self.bn_decay,
                                      batch_norm_epsilon=self.bn_epsilon,
                                      batch_norm_scale=self.bn_scale)
        elif self.model == "densenet":
            if self.scope is None:
                self.scope = 'densenet169'
            logits = densenet.densenet169(inputs=inputs,
                                          num_classes=self.n_classes,
                                          is_training=is_training,
                                          reuse=reuse,
                                          dropout_keep_prob=self.dropprob,
                                          scope=self.scope,
                                          weight_decay=self.l2_rate,
                                          use_batch_norm=self.use_bn,
                                          batch_norm_decay=self.bn_decay,
                                          batch_norm_epsilon=self.bn_epsilon,
                                          batch_norm_scale=self.bn_scale)
        else:
            raise ValueError("Unknown cost function: " % cost_name)
        return tf.squeeze(logits)