예제 #1
0
def mnasnet1_3(pretrained: bool,
               progress: bool = True,
               requires_grad: bool = True):
    """Build torchvision's MNASNet (depth multiplier 1.3) and set its trainability.

    Args:
        pretrained: forwarded to ``models.mnasnet1_3``.
        progress: forwarded to ``models.mnasnet1_3``.
        requires_grad: flag applied to every parameter of the network
            (False freezes the whole model, True leaves it trainable).

    Returns:
        The constructed network.
    """
    net = models.mnasnet1_3(pretrained=pretrained, progress=progress)
    # Module.requires_grad_ sets the flag on all parameters and returns the module.
    return net.requires_grad_(requires_grad)
예제 #2
0
def eval(model_name: str) -> None:
    """Evaluate an MNASNet variant on the ImageNet validation set and print results.

    Prints the sample-weighted mean cross-entropy loss, followed by a list of
    ``(metric_name, mean_value)`` pairs for each metric in ``metrics.default()``.

    Note: the name shadows the builtin ``eval`` in this module; it is kept for
    compatibility with existing callers.

    Args:
        model_name: one of "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3".

    Raises:
        ValueError: if ``model_name`` is not one of the supported variants.
    """
    # Only the 0.5 and 1.0 variants are requested with pretrained=True; the
    # 0.75 and 1.3 variants are constructed without that flag.
    builders = {
        "mnasnet0_5": lambda: models.mnasnet0_5(num_classes=1000, pretrained=True),
        "mnasnet0_75": lambda: models.mnasnet0_75(num_classes=1000),
        "mnasnet1_0": lambda: models.mnasnet1_0(num_classes=1000, pretrained=True),
        "mnasnet1_3": lambda: models.mnasnet1_3(num_classes=1000),
    }
    build = builders.get(model_name)
    if build is None:
        raise ValueError("Don't know how to evaluate {}".format(model_name))
    model = build().cuda()

    criterion = torch.nn.CrossEntropyLoss().cuda()
    val_dataset = imagenet.validation(IMAGENET_DIR)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=512,
                                             shuffle=False,
                                             num_workers=8,
                                             pin_memory=True)
    all_metrics = metrics.default()

    model.eval()
    with torch.no_grad():
        loss_total = 0.0
        sample_count = 0
        metric_dict = collections.defaultdict(list)
        for inputs, truth in tqdm.tqdm(val_loader):
            outputs = model(inputs.cuda()).cpu()
            batch_size = truth.size(0)
            # CrossEntropyLoss() averages over the batch; weight each batch by
            # its size so a short final batch does not skew the dataset mean.
            loss_total += criterion(outputs, truth).item() * batch_size
            sample_count += batch_size
            # NOTE(review): metric values are still averaged per batch; if
            # metric_fn returns per-batch means they carry the same small
            # final-batch bias — confirm whether they should be size-weighted.
            for name, metric_fn in all_metrics:
                metric_dict[name].append(metric_fn(outputs, truth))

        mean_loss = loss_total / sample_count if sample_count else float("nan")
        print(mean_loss,
              [(name, numpy.mean(vals)) for name, vals in metric_dict.items()])
예제 #3
0
def get_mnasnet1_3(class_num):
    """Return a pretrained MNASNet-1.3 whose last classifier layer outputs ``class_num`` logits.

    ``set_parameter_requires_grad`` is applied to the network before the final
    linear layer is replaced, so the replacement layer is not affected by it.

    Returns:
        A ``(model, input_size)`` tuple; ``input_size`` is 224.
    """
    net = models.mnasnet1_3(pretrained=True)
    set_parameter_requires_grad(net)
    net.name = 'mnasnet1_3'

    head = net.classifier
    # Swap the final linear layer for one sized to the target number of classes.
    head[1] = nn.Linear(head[1].in_features, class_num)

    input_size = 224
    return net, input_size
예제 #4
0
 def DefaultBackBone(self, pretrained=False, progress=True):
     """Set ``self.model`` to torchvision's MNASNet with depth multiplier 1.3.

     Architecture from "MnasNet: Platform-Aware Neural Architecture Search for Mobile".

     :param pretrained: If True, load a model pre-trained on ImageNet.
     :type pretrained: bool
     :param progress: If True, display a progress bar of the download on stderr.
     :type progress: bool
     """
     backbone = models.mnasnet1_3(pretrained=pretrained, progress=progress)
     self.model = backbone
예제 #5
0
def train(model_name: str) -> None:
    """Train an MNASNet variant on ImageNet with cosine annealing and warmup.

    Hyper-parameters come from ``TRAINING_PARAMS[model_name]``. The model is
    wrapped in ``DataParallel`` when more than one GPU is visible.

    Args:
        model_name: one of "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3".

    Raises:
        ValueError: if ``model_name`` is not one of the supported variants.
    """
    supported = ("mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3")
    if model_name not in supported:
        raise ValueError("Don't know how to train {}".format(model_name))
    # Each supported name is also the torchvision constructor name.
    model = getattr(models, model_name)(num_classes=1000).cuda()
    params = TRAINING_PARAMS[model_name]

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=params["base_lr"],
        momentum=params["momentum"],
        weight_decay=params["weight_decay"],
        nesterov=True,
    )
    criterion = torch.nn.CrossEntropyLoss().cuda()
    scheduler = CosineWithWarmup(optimizer, WARMUP, 0.1, params["num_epochs"])

    train_dataset = imagenet.training(IMAGENET_DIR)
    val_dataset = imagenet.validation(IMAGENET_DIR)

    message = "Training {}, cosine annealing with warmup. Parameters: {}".format(
        model_name, params)
    runner = trainer.Trainer(
        ".",
        message,
        "multiclass_classification",
        True,
        model,
        optimizer,
        criterion,
        scheduler,
        metrics.default(),
        cudnn_autotune=True,
    )

    runner.fit(
        train_dataset,
        val_dataset,
        num_epochs=params["num_epochs"],
        batch_size=params["batch_size"],
        num_workers=multiprocessing.cpu_count(),
    )
# NOTE(review): fragment of a longer if/elif chain on ``TopModelName``; the
# opening ``if`` (and earlier branches) are not part of this snippet. Regular
# branches bind ``pre_model``; the "*_test" branches bind ``model`` and set
# ``test=True``.
elif TopModelName=='resnext101_32x8d':
	pre_model = models.resnext101_32x8d(pretrained=True, progress=True)
elif TopModelName=="wide_resnet50_2":
	pre_model = models.wide_resnet50_2(pretrained=True, progress=True)
elif TopModelName=='wide_resnet101_2':
	pre_model = models.wide_resnet101_2(pretrained=True, progress=True)
##############################

elif TopModelName=="mnasnet0_5":
	pre_model = models.mnasnet0_5(pretrained=True, progress=True)
elif TopModelName=="mnasnet0_75":
	pre_model = models.mnasnet0_75(pretrained=True, progress=True)
elif TopModelName=="mnasnet1_0":
	pre_model = models.mnasnet1_0(pretrained=True, progress=True)
elif TopModelName=="mnasnet1_3":
	pre_model = models.mnasnet1_3(pretrained=True, progress=True)
	
elif TopModelName=="efficientnet-b0":
	# Imported lazily so the efficientnet_pytorch package is only needed for this branch.
	from efficientnet_pytorch import EfficientNet
	pre_model = EfficientNet.from_pretrained('efficientnet-b0')
######################### Testing Models #################################
elif TopModelName=="resnext50_32x4d_test":
	model = models.resnext50_32x4d(pretrained=True, progress=True)
	# Replace the first conv with a newly constructed 1-input-channel conv
	# (its weights are freshly created, not the pretrained ones).
	model.conv1=nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
	# NOTE(review): this only rewrites the ``out_features`` attribute; an
	# nn.Linear's forward uses its existing weight tensor, so the layer likely
	# still produces the original number of outputs. Replacing the layer
	# (e.g. nn.Linear(model.fc.in_features, 2)) may be what was intended — confirm.
	model.fc.out_features=2
	test=True
elif TopModelName=="vgg16_test":
	model=models.vgg16(pretrained=True, progress=True)
	# Replace the first conv with a newly constructed 1-input-channel conv.
	model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	# NOTE(review): same concern as above — setting ``out_features`` does not
	# appear to resize the existing linear layer; confirm intent.
	model.classifier[6].out_features = 2
	test=True
예제 #7
0
 def test_mnasnet1_3(self):
     """Hand MNASNet-1.3 (built with ``self.pretrained``) to ``process_model`` with its ``_C_tests`` counterpart."""
     net = models.mnasnet1_3(self.pretrained)
     process_model(net, self.image, _C_tests.forward_mnasnet1_3, 'MNASNet1_3')
예제 #8
0
def modeldict(mmd):
    """Instantiate an ImageNet classification network from a short model key.

    Keys ending in ``_pt`` load pretrained weights. The same key without the
    suffix builds the architecture with ``pretrained=False`` (torchvision) or
    ``pretrained=None`` (``pretrainedmodels``). The torchvision keys
    'shufflenet20', 'mnasnet075' and 'mnasnet13' exist only without ``_pt``.

    Args:
        mmd: model key, e.g. 'resnet50_pt', 'vgg16bn', 'dpn92_pt'.

    Returns:
        The constructed model. Only the selected model is built.

    Side effects:
        On an unknown key, prints 'Invalid model name. Terminated.' and exits
        the process with status 1.
    """
    import sys

    # (key, torchvision constructor, extra kwargs). Each entry provides
    # '<key>' (pretrained=False) and '<key>_pt' (pretrained=True).
    torchvision_pairs = (
        ('alexnet', 'alexnet', {}),
        ('vgg11', 'vgg11', {}),
        ('vgg11bn', 'vgg11_bn', {}),
        ('vgg13', 'vgg13', {}),
        ('vgg13bn', 'vgg13_bn', {}),
        ('vgg16', 'vgg16', {}),
        ('vgg16bn', 'vgg16_bn', {}),
        ('vgg19', 'vgg19', {}),
        ('vgg19bn', 'vgg19_bn', {}),
        ('resnet18', 'resnet18', {}),
        ('resnet34', 'resnet34', {}),
        ('resnet50', 'resnet50', {}),
        ('resnet101', 'resnet101', {}),
        ('resnet152', 'resnet152', {}),
        ('squeezenet10', 'squeezenet1_0', {}),
        ('squeezenet11', 'squeezenet1_1', {}),
        ('densenet121', 'densenet121', {}),
        ('densenet161', 'densenet161', {}),
        ('densenet169', 'densenet169', {}),
        ('densenet201', 'densenet201', {}),
        ('inception', 'inception_v3', {'aux_logits': False}),
        ('googlenet', 'googlenet', {'aux_logits': False}),
        ('shufflenet05', 'shufflenet_v2_x0_5', {}),
        ('shufflenet10', 'shufflenet_v2_x1_0', {}),
        ('mobilenet', 'mobilenet_v2', {}),
        ('resnext50_32x4d', 'resnext50_32x4d', {}),
        ('resnext101_32x8d', 'resnext101_32x8d', {}),
        ('wide_resnet50_2', 'wide_resnet50_2', {}),
        ('wide_resnet101_2', 'wide_resnet101_2', {}),
        ('mnasnet05', 'mnasnet0_5', {}),
        ('mnasnet10', 'mnasnet1_0', {}),
    )
    # torchvision architectures offered only with pretrained=False.
    torchvision_untrained = (
        ('shufflenet20', 'shufflenet_v2_x2_0'),
        ('mnasnet075', 'mnasnet0_75'),
        ('mnasnet13', 'mnasnet1_3'),
    )
    # (pretrainedmodels name == key, weight set used for '<key>_pt').
    pretrainedmodels_pairs = (
        ('inceptionv4', 'imagenet'),
        ('inceptionresnetv2', 'imagenet'),
        ('xception', 'imagenet'),
        ('dpn68', 'imagenet'),
        ('dpn98', 'imagenet'),
        ('dpn131', 'imagenet'),
        ('dpn68b', 'imagenet+5k'),
        ('dpn92', 'imagenet+5k'),
        ('dpn107', 'imagenet+5k'),
        ('fbresnet152', 'imagenet'),
    )

    specs = {}
    for key, ctor, extra in torchvision_pairs:
        specs[key] = ('torchvision', ctor, dict(extra, pretrained=False))
        specs[key + '_pt'] = ('torchvision', ctor, dict(extra, pretrained=True))
    for key, ctor in torchvision_untrained:
        specs[key] = ('torchvision', ctor, {'pretrained': False})
    for key, weights in pretrainedmodels_pairs:
        specs[key] = ('pretrainedmodels', key,
                      {'num_classes': 1000, 'pretrained': None})
        specs[key + '_pt'] = ('pretrainedmodels', key,
                              {'num_classes': 1000, 'pretrained': weights})

    spec = specs.get(mmd) if isinstance(mmd, str) else None
    if spec is None:
        print('Invalid model name. Terminated.')
        # Exit non-zero so shells/CI see the failure. sys.exit is used because
        # the ``exit`` builtin is only installed by the ``site`` module.
        sys.exit(1)

    library, ctor, kwargs = spec
    if library == 'torchvision':
        return getattr(models, ctor)(**kwargs)
    return pretrainedmodels.__dict__[ctor](**kwargs)
예제 #9
0
def _load_pretrained_arch(model_arch):
    """Build the torchvision model named ``model_arch`` with ``pretrained=True``."""
    supported = (
        'alexnet',
        'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
        'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
        'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
        'squeezenet1_0', 'squeezenet1_1',
        'densenet121', 'densenet169', 'densenet161', 'densenet201',
        'inception_v3', 'googlenet',
        'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
        'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
        'mobilenet_v2',
        'resnext50_32x4d', 'resnext101_32x8d',
        'wide_resnet50_2', 'wide_resnet101_2',
        'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3',
    )
    if model_arch not in supported:
        # Fail fast with a clear message instead of an UnboundLocalError later.
        raise ValueError('Unsupported model architecture: {!r}'.format(model_arch))
    # Every supported name is also the torchvision constructor name.
    return getattr(models, model_arch)(pretrained=True)


def _head_name_and_in_features(model):
    """Return (name of the model's classification-head attribute, its input width)."""
    # VGG/DenseNet/MobileNet/MNASNet-style models keep their head in
    # ``classifier``; ResNet-style models keep it in ``fc``.
    head_name = 'classifier' if hasattr(model, 'classifier') else 'fc'
    head = getattr(model, head_name, None)
    if isinstance(head, nn.Module):
        # modules() yields ``head`` itself first, then its sub-layers in order,
        # so this finds the first linear layer (e.g. after a leading Dropout).
        for layer in head.modules():
            if isinstance(layer, nn.Linear):
                return head_name, layer.in_features
    raise ValueError('Cannot find a linear layer in the {!r} head of {}'.format(
        head_name, type(model).__name__))


def _evaluate(model, loader, criterion, device):
    """Return (sum of per-batch losses, sum of per-batch accuracies) over ``loader``.

    Leaves ``model`` in eval mode; the caller switches it back to train mode.
    """
    total_loss = 0.0
    total_accuracy = 0.0
    model.eval()
    # Validation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            total_loss += criterion(output, labels).item()
            # main()'s head ends in LogSoftmax, so exp() yields probabilities.
            _, top_class = torch.exp(output).topk(1, dim=1)
            equals = top_class == labels.view(*top_class.shape)
            total_accuracy += equals.float().mean().item()
    return total_loss, total_accuracy


def main(train_loader, train_datasets, valid_loader):
    '''Fine-tune a new head on a frozen pretrained torchvision model and save a checkpoint.

    Hyper-parameters are read from ``input_args()``: ``arch``, ``hidden_units``,
    ``learning_rate``, ``device``, ``epochs`` and ``save_dir``.

    Args:
        train_loader: iterable of (images, labels) batches used for training.
        train_datasets: training dataset; its ``class_to_idx`` mapping sizes the
            output layer and is stored on the model and in the checkpoint.
        valid_loader: iterable of (images, labels) batches used for validation
            every 10 training steps.

    Returns:
        checkpoint dict with keys 'epoch_tot', 'model', 'criterion', 'optimizer',
        'optimizer.state_dict', 'model.state_dict', 'model.class_to_idx'. The
        same dict is written to ``save_dir`` with ``torch.save``.

    Raises:
        ValueError: if ``arch`` is unsupported, or the model's head contains no
            ``nn.Linear`` from which to read the feature width.
    '''
    arg = input_args()
    model_arch = arg.arch
    hidden_units = arg.hidden_units
    learning_rate = arg.learning_rate
    device = arg.device
    epochs = arg.epochs
    save_dir = arg.save_dir

    gs_vgg = _load_pretrained_arch(model_arch)

    # Freeze the pretrained backbone; only the new head below is trained.
    for parameters in gs_vgg.parameters():
        parameters.requires_grad = False

    head_name, input_layer = _head_name_and_in_features(gs_vgg)
    hidden_layers = [int(hidden_units * 0.68), int(hidden_units * 0.32)]
    # One output per class (not len(train_loader), which counts batches).
    output_layer = len(train_datasets.class_to_idx)

    new_head = nn.Sequential(
        nn.Linear(input_layer, hidden_layers[0]),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(hidden_layers[0], hidden_layers[1]),
        nn.ReLU(),
        nn.Linear(hidden_layers[1], output_layer),
        nn.LogSoftmax(dim=1))
    setattr(gs_vgg, head_name, new_head)

    # NLLLoss pairs with the LogSoftmax output of the head.
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(new_head.parameters(), lr=learning_rate)

    # NOTE(review): inception_v3 built with pretrained=True may return
    # auxiliary logits in train mode, which NLLLoss cannot consume — confirm
    # before using that architecture here.
    gs_vgg.to(device)
    epoch_tot = 0
    step_num = 0
    running_loss = 0
    print_every = 10
    for _ in range(epochs):
        epoch_tot += 1
        gs_vgg.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(gs_vgg(images), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            step_num += 1

            if step_num % print_every == 0:
                test_loss, accuracy = _evaluate(gs_vgg, valid_loader,
                                                criterion, device)
                print(
                    f"Total Epochs: {epoch_tot}.. "
                    f"Train loss: {running_loss/print_every:.3f}.. "
                    f"Test loss: {test_loss/len(valid_loader):.3f}.. "
                    f"Test accuracy: {(accuracy/len(valid_loader))*100:.1f}%")
                running_loss = 0
                gs_vgg.train()

    gs_vgg.class_to_idx = train_datasets.class_to_idx
    gs_checkpoint = {
        'epoch_tot': epoch_tot,
        'model': gs_vgg,
        'criterion': criterion,
        'optimizer': optimizer,
        'optimizer.state_dict': optimizer.state_dict(),
        'model.state_dict': gs_vgg.state_dict(),
        'model.class_to_idx': gs_vgg.class_to_idx
    }
    torch.save(gs_checkpoint, save_dir)
    return gs_checkpoint
예제 #10
0
def get_model(name, device):
    """Build an untrained network whose final layer emits a single value.

    ``"normal_cnn"`` returns the project's ``Net`` moved to ``device``. Any
    other supported name is a torchvision constructor name. The model is built
    with default arguments (plus ``aux_logits=False`` for googlenet and
    inception_v3) and moved to ``device``. Its last linear layer is then
    replaced by ``nn.Linear(in_features, 1)``, also placed on ``device``.

    Args:
        name: architecture name.
        device: target passed to ``.to()``.

    Returns:
        The model, or None when ``name`` is not recognised.
    """
    if name == "normal_cnn":
        return Net().to(device)

    # Where each family keeps the layer to replace: (attribute, index or None).
    families = (
        (("alexnet", "vgg11", "vgg13", "vgg16", "vgg19",
          "vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"),
         ("classifier", 6)),
        (("densenet121", "densenet161", "densenet169", "densenet201"),
         ("classifier", None)),
        (("resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
          "googlenet", "inception_v3",
          "shufflenet_v2_x0_5", "shufflenet_v2_x1_0",
          "shufflenet_v2_x1_5", "shufflenet_v2_x2_0",
          "resnext50_32x4d", "resnext101_32x8d",
          "wide_resnet50_2", "wide_resnet101_2"),
         ("fc", None)),
        (("mobilenet_v2", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"),
         ("classifier", 1)),
    )
    head_location = None
    for arch_names, location in families:
        if name in arch_names:
            head_location = location
            break
    if head_location is None:
        return None

    attr, index = head_location
    ctor_kwargs = {"aux_logits": False} if name in ("googlenet", "inception_v3") else {}
    backbone = getattr(models, name)(**ctor_kwargs).to(device)
    holder = getattr(backbone, attr)
    if index is None:
        in_features = holder.in_features
        setattr(backbone, attr, nn.Linear(in_features, 1).to(device))
    else:
        in_features = holder[index].in_features
        holder[index] = nn.Linear(in_features, 1).to(device)
    return backbone
예제 #11
0
 def test_mnasnet1_3(self):
     """Hand a default-constructed MNASNet-1.3 to ``process_model`` with its ``_C_tests`` counterpart."""
     reference_model = models.mnasnet1_3()
     process_model(reference_model, self.image,
                   _C_tests.forward_mnasnet1_3, "MNASNet1_3")
예제 #12
0
    def get_model(model_id, use_pretrained):
        """Instantiate the torchvision model identified by ``model_id``.

        Args:
            model_id: a ``PyTorchModelsEnum`` member.
            use_pretrained: forwarded as ``pretrained=`` to the constructor.

        Returns:
            The constructed model, or None if ``model_id`` matches no entry.
        """
        # Compared in this order with ``model_id == member``; first match wins.
        candidates = (
            (PyTorchModelsEnum.ALEXNET, "alexnet"),
            (PyTorchModelsEnum.DENSENET121, "densenet121"),
            (PyTorchModelsEnum.DENSENET161, "densenet161"),
            (PyTorchModelsEnum.DENSENET169, "densenet169"),
            (PyTorchModelsEnum.DENSENET201, "densenet201"),
            (PyTorchModelsEnum.GOOGLENET, "googlenet"),
            (PyTorchModelsEnum.INCEPTION_V3, "inception_v3"),
            (PyTorchModelsEnum.MOBILENET_V2, "mobilenet_v2"),
            (PyTorchModelsEnum.MNASNET_0_5, "mnasnet0_5"),
            # Original note: "no pretrained" weights for this one — confirm
            # against the installed torchvision version.
            (PyTorchModelsEnum.MNASNET_0_75, "mnasnet0_75"),
            (PyTorchModelsEnum.MNASNET_1_0, "mnasnet1_0"),
            (PyTorchModelsEnum.MNASNET_1_3, "mnasnet1_3"),
            (PyTorchModelsEnum.RESNET18, "resnet18"),
            (PyTorchModelsEnum.RESNET34, "resnet34"),
            (PyTorchModelsEnum.RESNET50, "resnet50"),
            (PyTorchModelsEnum.RESNET101, "resnet101"),
            (PyTorchModelsEnum.RESNET152, "resnet152"),
            (PyTorchModelsEnum.RESNEXT50, "resnext50_32x4d"),
            (PyTorchModelsEnum.RESNEXT101, "resnext101_32x8d"),
            (PyTorchModelsEnum.SHUFFLENET_V2_0_5, "shufflenet_v2_x0_5"),
            (PyTorchModelsEnum.SHUFFLENET_V2_1_0, "shufflenet_v2_x1_0"),
            (PyTorchModelsEnum.SHUFFLENET_V2_1_5, "shufflenet_v2_x1_5"),
            (PyTorchModelsEnum.SHUFFLENET_V2_2_0, "shufflenet_v2_x2_0"),
            (PyTorchModelsEnum.SQUEEZENET1_0, "squeezenet1_0"),
            (PyTorchModelsEnum.SQUEEZENET1_1, "squeezenet1_1"),
            (PyTorchModelsEnum.VGG11, "vgg11"),
            (PyTorchModelsEnum.VGG11_BN, "vgg11_bn"),
            (PyTorchModelsEnum.VGG13, "vgg13"),
            (PyTorchModelsEnum.VGG13_BN, "vgg13_bn"),
            (PyTorchModelsEnum.VGG16, "vgg16"),
            (PyTorchModelsEnum.VGG16_BN, "vgg16_bn"),
            (PyTorchModelsEnum.VGG19, "vgg19"),
            (PyTorchModelsEnum.VGG19_BN, "vgg19_bn"),
            (PyTorchModelsEnum.WIDE_RESNET50, "wide_resnet50_2"),
            (PyTorchModelsEnum.WIDE_RESNET101, "wide_resnet101_2"),
        )
        for enum_member, constructor_name in candidates:
            if model_id == enum_member:
                return getattr(models, constructor_name)(pretrained=use_pretrained)
        return None
예제 #13
0
        'batch_size':
        batch_size,
        'network':
        lambda: models.shufflenet_v2_x1_0(),
        'criterion':
        lambda: nn.CrossEntropyLoss(),
        'optimizer':
        lambda parameters: optim.SGD(parameters, lr=0.001, momentum=0.9)
    },
    # MnasNet-1.3 experiment. Network, criterion and optimizer are given as
    # factories (lambdas), presumably so each run builds fresh objects —
    # confirm against the code that consumes this list.
    {
        # NOTE(review): 'skip': True presumably excludes this entry from runs;
        # confirm how the consumer interprets it.
        'skip':
        True,
        'name': {
            'folder': 'MnasNet1.3_CrossEntropy_SGDMomentum',
            'network': 'MnasNet1.3',
            'criterion': 'CrossEntropyLoss',
            'optimizer': 'SGD with momentum',
        },
        'epochs':
        epochs,
        'batch_size':
        batch_size,
        'network':
        lambda: models.mnasnet1_3(),
        'criterion':
        lambda: nn.CrossEntropyLoss(),
        # SGD with lr=0.001 and momentum=0.9 over the parameters passed in.
        'optimizer':
        lambda parameters: optim.SGD(parameters, lr=0.001, momentum=0.9)
    },
]