Code example #1
def _set_model_input_shape_attr(model, arch, dataset, pretrained, cadene):
    if cadene and pretrained:
        # When using pre-trained weights, Cadene models already have an input size attribute
        # We add the batch dimension to it
        input_size = model.module.input_size if isinstance(model, torch.nn.DataParallel) else model.input_size
        shape = tuple([1] + input_size)
        set_model_input_shape_attr(model, input_shape=shape)
    elif arch == 'inception_v3':
        set_model_input_shape_attr(model, input_shape=(1, 3, 299, 299))
    else:
        set_model_input_shape_attr(model, dataset=dataset)
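
A minimal usage sketch for the helper above, under the assumption that set_model_input_shape_attr is in scope; the torchvision resnet18 used here is only an illustrative choice. With cadene=False and an architecture other than 'inception_v3', the call falls through to the dataset-based branch.

import torchvision.models as torch_models

# Illustrative only: any model object would do here; resnet18 is an arbitrary pick.
model = torch_models.resnet18()
_set_model_input_shape_attr(model, arch='resnet18', dataset='imagenet',
                            pretrained=False, cadene=False)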
Code example #2
def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
    """Create a pytorch model based on the model architecture and dataset

    Args:
        pretrained [boolean]: True if you wish to load a pretrained model.
            Some models do not have a pretrained version.
        dataset: dataset name (only 'imagenet', 'cifar10' and 'mnist' are supported)
        arch: architecture name
        parallel [boolean]: if set, use torch.nn.DataParallel
        device_ids: Devices on which model should be created -
            None - GPU if available, otherwise CPU
            -1 - CPU
            >=0 - GPU device IDs
    """
    model = None
    dataset = dataset.lower()
    cadene = False
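    # Resolution order for ImageNet: Distiller's ResNet variants, torchvision models,
    # other extra ImageNet models (random init only), then Cadene's pretrainedmodels.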
    if dataset == 'imagenet':
        if arch in RESNET_SYMS:
            model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
        elif arch in TORCHVISION_MODEL_NAMES:
            try:
                model = getattr(torch_models, arch)(pretrained=pretrained)
                if arch == "mobilenet_v2":
                    patch_torchvision_mobilenet_v2_bug(model)
            except NotImplementedError:
                # In torchvision 0.3, constructing a model with pretrained=True when no
                # pretrained weights are available raises NotImplementedError
                if not pretrained:
                    raise
        if model is None and (
                arch in imagenet_extra_models.__dict__) and not pretrained:
            model = imagenet_extra_models.__dict__[arch]()
        if model is None and (arch in pretrainedmodels.model_names):
            cadene = True
            model = pretrainedmodels.__dict__[arch](
                num_classes=1000, pretrained=(dataset if pretrained else None))
        if model is None:
            error_message = ''
            if arch not in IMAGENET_MODEL_NAMES:
                error_message = "Model {} is not supported for dataset ImageNet".format(
                    arch)
            elif pretrained:
                error_message = "Model {} (ImageNet) does not have a pretrained model".format(
                    arch)
            raise ValueError(error_message
                             or 'Failed to find model {}'.format(arch))

    elif dataset == 'cifar10':
        if pretrained:
            raise ValueError(
                "Model {} (CIFAR10) does not have a pretrained model".format(
                    arch))
        try:
            model = cifar10_models.__dict__[arch]()
        except KeyError:
            raise ValueError(
                "Model {} is not supported for dataset CIFAR10".format(arch))

    elif dataset == 'mnist':
        if pretrained:
            raise ValueError(
                "Model {} (MNIST) does not have a pretrained model".format(
                    arch))
        try:
            model = mnist_models.__dict__[arch]()
        except KeyError:
            raise ValueError(
                "Model {} is not supported for dataset MNIST".format(arch))
    else:
        raise ValueError('Could not recognize dataset {}'.format(dataset))

    msglogger.info("=> creating a %s%s model with the %s dataset" %
                   ('pretrained ' if pretrained else '', arch, dataset))
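    # Place the model on the GPU(s) unless CUDA is unavailable or device_ids == -1 forces CPU.
    # For AlexNet/VGG only the feature extractor is wrapped in DataParallel.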
    if torch.cuda.is_available() and device_ids != -1:
        device = 'cuda'
        if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel:
            model.features = torch.nn.DataParallel(model.features,
                                                   device_ids=device_ids)
        elif parallel:
            model = torch.nn.DataParallel(model, device_ids=device_ids)
    else:
        device = 'cpu'

    if cadene and pretrained:
        # When using pre-trained weights, Cadene models already have an input size attribute
        # We add the batch dimension to it
        input_size = model.module.input_size if isinstance(
            model, torch.nn.DataParallel) else model.input_size
        shape = tuple([1] + input_size)
        set_model_input_shape_attr(model, input_shape=shape)
    elif arch == 'inception_v3':
        set_model_input_shape_attr(model, input_shape=(1, 3, 299, 299))
    else:
        set_model_input_shape_attr(model, dataset=dataset)

    return model.to(device)
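
A hedged usage sketch for create_model, assuming this module's dependencies (torchvision, the Distiller-style cifar10_models, etc.) are importable; the architecture names are only examples.

# Randomly initialized CIFAR-10 model on the CPU (device_ids=-1 forces CPU).
cifar_model = create_model(pretrained=False, dataset='cifar10', arch='resnet20_cifar',
                           parallel=False, device_ids=-1)

# Pretrained torchvision ResNet-50 for ImageNet; wrapped in DataParallel when a GPU is available.
imagenet_model = create_model(pretrained=True, dataset='imagenet', arch='resnet50')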