def get_model(args):
    """Build a single-channel (grayscale) classifier from a torchvision backbone.

    The backbone's first convolution is replaced with one that accepts a single
    input channel, and the classification head(s) are resized to
    ``args.n_outputs``.

    Two construction paths exist:

    * ``args.pre_trained_checkpoint`` is set: the architecture is built with
      ``pretrained=False``, adapted as above, and then the ``'model'`` entry of
      the checkpoint file is loaded as its state dict.
    * otherwise: the architecture is built with ``pretrained=True`` and the new
      first conv is initialised from the original first layer's state dict,
      with its weight averaged over the input-channel dimension (dim 1).

    Args:
        args: namespace-like object providing ``model_name`` (one of
            ``'inception'``, ``'resnet'``, ``'resnext'``),
            ``pre_trained_checkpoint`` (path or ``None``) and ``n_outputs``.

    Returns:
        The adapted model.

    Raises:
        ValueError: if ``args.model_name`` is not a supported architecture.
    """
    name = args.model_name
    checkpoint = args.pre_trained_checkpoint
    from_checkpoint = checkpoint is not None

    # Build the backbone and its replacement head(s) / first conv. Creation order
    # is kept identical to the original per-architecture code.
    if name == 'inception':
        model = models.inception_v3(pretrained=not from_checkpoint, transform_input=False)
        model.fc = nn.Linear(2048, args.n_outputs)
        model.AuxLogits = InceptionAux(768, args.n_outputs)
        model.aux_logits = False
        new_conv = BasicConv2d(1, 32, kernel_size=3, stride=2)
        conv_attr, weight_key = 'Conv2d_1a_3x3', 'conv.weight'
    elif name in ('resnet', 'resnext'):
        builder = models.resnet50 if name == 'resnet' else models.resnext50_32x4d
        model = builder(pretrained=not from_checkpoint)
        model.fc = nn.Linear(in_features=2048, out_features=args.n_outputs, bias=True)
        new_conv = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        conv_attr, weight_key = 'conv1', 'weight'
    else:
        # Previously fell through and crashed on `return model` with UnboundLocalError.
        raise ValueError(
            "Unsupported model_name {!r}; expected one of 'inception', 'resnet', 'resnext'."
            .format(name))

    if from_checkpoint:
        # Swap in the 1-channel conv first so the checkpoint's keys/shapes match.
        setattr(model, conv_attr, new_conv)
        sd = torch.load(checkpoint)['model']
        model.load_state_dict(sd)
    else:
        # Reuse the pretrained first layer, collapsing its RGB kernels to one
        # channel by averaging over the input-channel axis.
        first_layer_sd = getattr(model, conv_attr).state_dict()
        first_layer_sd[weight_key] = first_layer_sd[weight_key].mean(dim=1, keepdim=True)
        new_conv.load_state_dict(first_layer_sd)
        setattr(model, conv_attr, new_conv)
    return model
示例#2
0
 def __init__(self, num_classes=NLABEL, aux_logits=False, transform_input=False):
     """Create an Inception-v3 style network whose first conv takes 4 input channels.

     Args:
         num_classes: output size of ``fc`` (and of ``AuxLogits`` when enabled).
         aux_logits: when true, an ``InceptionAux`` head is registered between
             ``Mixed_6e`` and ``Mixed_7a``.
         transform_input: stored on the instance; not read in this constructor.
     """
     super(Inception3, self).__init__()
     self.aux_logits = aux_logits
     self.transform_input = transform_input
     # Registration order below is also the order in which ``self.modules()``
     # visits (and re-initialises) the layers in the loop at the end.
     self.Conv2d_1a_3x3 = BasicConv2d(4, 32, kernel_size=3, stride=2)
     self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
     self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
     self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
     self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
     self.Mixed_5b = InceptionA(192, pool_features=32)
     self.Mixed_5c = InceptionA(256, pool_features=64)
     self.Mixed_5d = InceptionA(288, pool_features=64)
     self.Mixed_6a = InceptionB(288)
     self.Mixed_6b = InceptionC(768, channels_7x7=128)
     self.Mixed_6c = InceptionC(768, channels_7x7=160)
     self.Mixed_6d = InceptionC(768, channels_7x7=160)
     self.Mixed_6e = InceptionC(768, channels_7x7=192)
     if aux_logits:
         self.AuxLogits = InceptionAux(768, num_classes)
     self.Mixed_7a = InceptionD(768)
     self.Mixed_7b = InceptionE(1280)
     self.Mixed_7c = InceptionE(2048)
     self.fc = nn.Linear(2048, num_classes)
     self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

     import scipy.stats as stats

     for layer in self.modules():
         if isinstance(layer, (nn.Conv2d, nn.Linear)):
             # Truncated normal clipped at +/-2 standard deviations; a layer may
             # override the default std of 0.1 via a ``stddev`` attribute.
             sigma = getattr(layer, 'stddev', 0.1)
             dist = stats.truncnorm(-2, 2, scale=sigma)
             target = layer.weight
             sample = torch.Tensor(dist.rvs(target.numel()))
             target.data.copy_(sample.view(target.size()))
         elif isinstance(layer, nn.BatchNorm2d):
             # Identity affine transform for every BatchNorm.
             layer.weight.data.fill_(1)
             layer.bias.data.zero_()
示例#3
0
    def __init__(self, num_classes=1000, transform_input=True):
        """Build the Inception-v3 layer stack (3-channel stem through ``Mixed_7c``) plus heads.

        Args:
            num_classes: output size of both ``AuxLogits`` and the final ``fc`` layer.
            transform_input: stored on the instance; not read in this constructor.
        """
        super(InceptionV3UptoPool3, self).__init__()
        self.transform_input = transform_input
        # Registration order below is also the order in which ``self.modules()``
        # visits (and re-initialises) the layers in the loop at the end.
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        self.AuxLogits = InceptionAux(768, num_classes)
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)
        self.fc = nn.Linear(2048, num_classes)

        import scipy.stats as stats

        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                # Truncated normal clipped at +/-2 standard deviations; the std
                # defaults to 0.1 unless the layer defines ``stddev``.
                sigma = getattr(layer, 'stddev', 0.1)
                dist = stats.truncnorm(-2, 2, scale=sigma)
                param = layer.weight.data
                sample = torch.Tensor(dist.rvs(param.numel())).view(param.size())
                param.copy_(sample)
            elif isinstance(layer, nn.BatchNorm2d):
                # Identity affine transform for every BatchNorm.
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
    def __init__(self, pre=True):
        """Wrap torchvision's Inception-v3 and adapt its stem to 4-channel input.

        Args:
            pre: passed as ``pretrained`` to ``torchvision.models.inception_v3``.
                When true, the new 4-channel stem is seeded from the pretrained
                stem; otherwise it keeps its fresh initialisation.
        """
        super().__init__()
        self.encoder = torchvision.models.inception_v3(pretrained=pre)

        conv1 = BasicConv2d(4, 32, kernel_size=3, stride=2)
        if pre:
            # Start from the whole pretrained stem state (conv weight AND the
            # BatchNorm affine params / running stats). Copying only the conv
            # weight would silently reset the stem's BatchNorm to defaults while
            # every other layer stays pretrained.
            stem_sd = self.encoder.Conv2d_1a_3x3.state_dict()
            w = stem_sd['conv.weight']
            # Input channels 0-2 reuse the pretrained kernels; the extra 4th
            # channel is the mean of the kernels for channels 0 and 2.
            stem_sd['conv.weight'] = torch.cat(
                (w, 0.5 * (w[:, :1, :, :] + w[:, 2:, :, :])), dim=1)
            conv1.load_state_dict(stem_sd)
        self.encoder.Conv2d_1a_3x3 = conv1
        self.encoder.AuxLogits = InceptionAux(768, num_class())
        self.encoder.fc = nn.Linear(2048, num_class())
示例#5
0
def get_model(args):
    """Build a 1-channel Inception-v3 and load its weights from a checkpoint.

    Only ``args.model_name == 'inception'`` with a non-``None``
    ``args.pre_trained_checkpoint`` is supported. The model is built with
    ``pretrained=False``; its heads are resized to ``args.n_outputs``. Its first
    conv is replaced by a 1-input-channel ``BasicConv2d``, and then the
    checkpoint's ``'model'`` entry is loaded as the state dict.

    Args:
        args: namespace-like object with ``model_name``,
            ``pre_trained_checkpoint`` and ``n_outputs``.

    Returns:
        The loaded model.

    Raises:
        ValueError: if ``model_name`` is not ``'inception'`` or no checkpoint
            path is given (there is no model to return in either case).
    """
    if args.model_name != 'inception':
        raise ValueError(
            "Unsupported model_name {!r}; only 'inception' is available.".format(args.model_name))
    if args.pre_trained_checkpoint is None:
        # Formerly printed this message and then crashed on `return model`
        # with UnboundLocalError; fail explicitly instead.
        raise ValueError('Missing model.')

    model = models.inception_v3(pretrained=False,
                                transform_input=False)
    model.fc = nn.Linear(2048, args.n_outputs)
    model.AuxLogits = InceptionAux(768, args.n_outputs)
    model.aux_logits = False
    # Single-channel stem; must be in place before loading the checkpoint so
    # parameter shapes match.
    new_conv = BasicConv2d(1, 32, kernel_size=3, stride=2)
    model.Conv2d_1a_3x3 = new_conv
    sd = torch.load(args.pre_trained_checkpoint)['model']
    model.load_state_dict(sd)
    return model
示例#6
0
    def __init__(self, num_classes=80, aux_logits=True, transform_input=False, apply_avgpool=False):
        """Construct the Inception-v3 layer stack from the stem through ``Mixed_7c``.

        No final classifier (``fc``) is created in this constructor;
        ``num_classes`` only sizes the optional ``AuxLogits`` head.

        Args:
            num_classes: output size of ``AuxLogits`` (when enabled).
            aux_logits: when true, an ``InceptionAux`` head is registered
                between ``Mixed_6e`` and ``Mixed_7a``.
            transform_input: stored on the instance; not read in this constructor.
            apply_avgpool: stored on the instance; not read in this constructor
                (presumably consulted by ``forward`` -- TODO confirm).
        """
        super(Inception3, self).__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.apply_avgpool = apply_avgpool

        # NOTE: the order of these assignments fixes the module registration
        # order (and hence parameter / state_dict key order); keep it stable so
        # existing checkpoints continue to line up.
        # Stem: BasicConv2d(3, 32, ...) -- presumably 3 input channels (RGB).
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)
示例#7
0
def get_model(model_type, model_name, num_classes, pretrained=True, **kwargs):
    """
    :param model_type: (ModelType):
        type of model we're trying to obtain (classification or segmentation)
    :param model_name: (string):
        name of the model. By convention (for classification models) lowercase names represent pretrained model variants while Uppercase do not.
    :param num_classes: (int):
        number of classes to initialize with (this will replace the last classification layer or set the number of segmented classes)
    :param pretrained: (bool):
        whether to load the default pretrained version of the model
        NOTE! NOTE! For classification, the lowercase model names are the pretrained variants while the Uppercase model names are not.
        The only exception applies to torch.hub models (all efficientnet, mixnet, mobilenetv3, mnasnet, spnasnet variants) where a single
        lower-case string can be used for vanilla and pretrained versions. Otherwise, it is IN ERROR to specify an Uppercase model name variant
        with pretrained=True but one can specify a lowercase model variant with pretrained=False
        (default: True)
    :param kwargs:
        extra keyword arguments forwarded to the segmentation model constructor (ignored for classification)
    :return: model
    :raises ValueError:
        if model_name is not supported for model_type, if pretrained=True is requested for a classification
        model that has no known pretrained loader, or if model_type is neither classification nor segmentation
    :raises Exception:
        if a segmentation model_name does not match any supported segmentation model
    """

    if model_name not in get_supported_models(
            model_type) and not model_name.startswith('TEST'):
        raise ValueError(
            'The supplied model name: {} was not found in the list of acceptable model names.'
            ' Use get_supported_models() to obtain a list of supported models.'
            .format(model_name))

    print("INFO: Loading Model:   --   " + model_name +
          "  with number of classes: " + str(num_classes))

    if model_type == ModelType.CLASSIFICATION:
        # NOTE: torch.hub.list resolves the hub repo on every call (may need network access / a local cache).
        torch_hub_names = torch.hub.list('rwightman/gen-efficientnet-pytorch')
        if model_name in torch_hub_names:
            model = torch.hub.load('rwightman/gen-efficientnet-pytorch',
                                   model_name,
                                   pretrained=pretrained,
                                   num_classes=num_classes)
        else:
            # 1. Load model (pretrained or vanilla)
            # we're only interested in the last layer name
            fc_name = get_fc_names(model_name=model_name, model_type=model_type)[-1]
            new_fc = None  # Custom layer to replace with (if none, then it will be handled generically)
            if model_name in torch_models.__dict__:
                print('INFO: Loading torchvision model: {}\t Pretrained: {}'.
                      format(model_name, pretrained))
                model = torch_models.__dict__[model_name](
                    pretrained=pretrained
                )  # find a model included in the torchvision package
            else:
                net_list = [
                    'fbresnet', 'inception', 'mobilenet', 'nasnet', 'polynet',
                    'resnext', 'se_resnet', 'senet', 'shufflenet', 'xception'
                ]
                if pretrained:
                    print('INFO: Loading a pretrained model: {}'.format(
                        model_name))
                    if 'dpn' in model_name:
                        model = classification.__dict__[model_name](
                            pretrained=True
                        )  # find a model included in the pywick classification package
                    elif any(net_name in model_name for net_name in net_list):
                        model = classification.__dict__[model_name](
                            pretrained='imagenet')
                    else:
                        # Without this branch `model` stays unbound and the
                        # function crashes later with UnboundLocalError.
                        raise ValueError(
                            'No pretrained loader is known for model: {}.'
                            ' Use pretrained=False to load a vanilla model.'.format(model_name))
                else:
                    print(
                        'INFO: Loading a vanilla model: {}'.format(model_name))
                    model = classification.__dict__[model_name](
                        pretrained=None
                    )  # pretrained must be set to None for the extra models... go figure

            # 2. Create custom FC layers for non-standardized models
            if 'squeezenet' in model_name:
                final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
                new_fc = nn.Sequential(nn.Dropout(p=0.5), final_conv,
                                       nn.ReLU(inplace=True),
                                       nn.AvgPool2d(13, stride=1))
                model.num_classes = num_classes
            elif 'vgg' in model_name:
                new_fc = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),
                                       nn.ReLU(True), nn.Dropout(),
                                       nn.Linear(4096, 4096), nn.ReLU(True),
                                       nn.Dropout(),
                                       nn.Linear(4096, num_classes))
            elif 'inception3' in model_name.lower(
            ) or 'inception_v3' in model_name.lower():
                # Replace the extra aux_logits FC layer if aux_logits are enabled
                if getattr(model, 'aux_logits', False):
                    model.AuxLogits = InceptionAux(768, num_classes)
            elif 'dpn' in model_name.lower():
                old_fc = getattr(model, fc_name)
                new_fc = nn.Conv2d(old_fc.in_channels,
                                   num_classes,
                                   kernel_size=1,
                                   bias=True)

            # 3. For standard FC layers (nn.Linear) perform a reflection lookup and generate a new FC
            if new_fc is None:
                old_fc = getattr(model, fc_name)
                new_fc = nn.Linear(old_fc.in_features, num_classes)

            # 4. perform replacement of the last FC / Linear layer with a new one
            setattr(model, fc_name, new_fc)

        return model

    elif model_type == ModelType.SEGMENTATION:
        """
        Additional Segmentation Option Parameters
        -----------------------------------------

        BiSeNet
            - :param backbone: (str, default: 'resnet18') The type of backbone to use (one of `{'resnet18'}`)
            - :param aux: (bool, default: False) Whether to output auxiliary loss (typically an FC loss to help with multi-class segmentation)

        DANet_ResnetXXX, DUNet_ResnetXXX, EncNet, OCNet_XXX_XXX, PSANet_XXX
            - :param aux: (bool, default: False) Whether to output auxiliary loss (typically an FC loss to help with multi-class segmentation)
            - :param backbone: (str, default: 'resnet101') The type of backbone to use (one of `{'resnet50', 'resnet101', 'resnet152'}`)
            - :param norm_layer (Pytorch nn.Module, default: nn.BatchNorm2d) The normalization layer to use. Typically it is not necessary to change this parameter unless you know what you're doing.

        DenseASPP_XXX
            - :param aux: (bool, default: False) Whether to output auxiliary loss (typically an FC loss to help with multi-class segmentation)
            - :param backbone: (str, default: 'densenet161') The type of backbone to use (one of `{'densenet121', 'densenet161', 'densenet169', 'densenet201'}`)
            - :param dilate_scale (int, default: 8) The size of the dilation to use (one of `{8, 16}`)
            - :param norm_layer (Pytorch nn.Module, default: nn.BatchNorm2d) The normalization layer to use. Typically it is not necessary to change this parameter unless you know what you're doing.

        DRNSeg
            - :param model_name: (str - required) The type of backbone to use. One of `{'DRN_C_42', 'DRN_C_58', 'DRN_D_38', 'DRN_D_54', 'DRN_D_105'}`

        EncNet_ResnetXXX
            - :param aux: (bool, default: False) Whether to output auxiliary loss (typically an FC loss to help with multi-class segmentation)
            - :param backbone: (str, default: 'resnet101') The type of backbone to use (one of `{'resnet50', 'resnet101', 'resnet152'}`)
            - :param norm_layer (Pytorch nn.Module, default: nn.BatchNorm2d) The normalization layer to use. Typically it is not necessary to change this parameter unless you know what you're doing.
            - :param se_loss (bool, default: True) Whether to compute se_loss
            - :param lateral (bool, default: False)

        frrn
            - :param model_type: (str - required) The type of model to use. One of `{'A', 'B'}`

        GCN, GCN_DENSENET, GCN_NASNET, GCN_PSP, GCN_RESNEXT
            - :param k: (int - optional) The size of global kernel

        GCN_PSP, GCN_RESNEXT, Unet_stack
            - :param input_size: (int - required) The size of output image (will be square)

        LinkCeption, 'LinkDenseNet121', 'LinkDenseNet161', 'LinkInceptionResNet', 'LinkNet18', 'LinkNet34', 'LinkNet50', 'LinkNet101', 'LinkNet152', 'LinkNeXt', 'CoarseLinkNet50'
            - :param num_channels: (int, default: 3) Number of channels in the image (e.g. 3 = RGB)
            - :param is_deconv: (bool, default: False)
            - :param decoder_kernel_size: (int, default: 3) Size of the decoder kernel

        PSPNet
            - :param backend: (str, default: densenet121) The type of extractor to use. One of `{'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'densenet121'}`

        RefineNet4Cascade, RefineNet4CascadePoolingImproved
            - :param input_shape: (tuple(int, int), default: (1, 512) - required!) Tuple representing input shape (num_channels, dim)
            - :param freeze_resnet: (bool, default: False) - whether to freeze the underlying resnet
        """

        # Loose (substring) match against the supported segmentation names.
        model_exists = any(model_name in m_name
                           for m_name in get_supported_models(model_type))
        if not model_exists:
            raise Exception(
                'Combination of type: {} and model_name: {} is not valid'.
                format(model_type, model_name))

        # Print warnings and helpful messages for nets that require additional configuration.
        # RefineNet models need `input_shape`, not `input_size` (see docstring above); they
        # used to be listed in both branches, which made the second one unreachable.
        if model_name in ['GCN_PSP', 'GCN_RESNEXT', 'Unet_stack']:
            print(
                'WARN: Did you remember to set the input_size parameter: (int) ?'
            )
        elif model_name in [
                'RefineNet4Cascade', 'RefineNet4CascadePoolingImproved'
        ]:
            print(
                'WARN: Did you remember to set the input_shape parameter: tuple(int, int)?'
            )

        # Warn only when a pretrained model was actually requested (the original
        # `A or B and pretrained` parsed as `A or (B and pretrained)`).
        no_pretrained_models = [
            'FusionNet', 'Enet', 'frrn', 'Tiramisu57', 'Tiramisu67',
            'Tiramisu101'
        ]
        if pretrained and (model_name in no_pretrained_models
                           or model_name.startswith('UNet')):
            print(
                "WARN: FusionNet, Enet, FRRN, Tiramisu, UNetXXX do not have a pretrained model! Empty model as been created instead."
            )

        net = segmentation.__dict__[model_name](num_classes=num_classes,
                                                pretrained=pretrained,
                                                **kwargs)
        return net

    # Any other model_type previously fell through to `return net` with `net` unbound.
    raise ValueError('Unsupported model_type: {}'.format(model_type))
示例#8
0
def get_model(type, model_name, num_classes, input_size, pretrained=True):
    '''
    :param type: str
        one of {'classification', 'segmentation'}
    :param model_name: str
        name of the model. By convention (for classification models) lowercase names represent pretrained model variants while Uppercase do not.
    :param num_classes: int
        number of classes to initialize with (this will replace the last classification layer or set the number of segmented classes)
    :param input_size: (int,int)
        Segmentation-only param. What size of input the network will accept e.g. (256, 256), (512, 512)
    :param pretrained: bool
        whether to load the default pretrained version of the model
        NOTE! NOTE! For classification, the lowercase model names are the pretrained variants while the Uppercase model names are not.
        It is IN ERROR to specify an Uppercase model name variant with pretrained=True but one can specify a lowercase model variant with pretrained=False
        (default: True)
    :return model
    '''
    if model_name not in get_supported_models(type) and not model_name.startswith('TEST'):
        raise ValueError('The supplied model name: {} was not found in the list of acceptable model names.'
                         ' Use get_supported_models() to obtain a list of supported models.')

    print("INFO: Loading Model:   --   " + model_name + "  with number of classes: " + str(num_classes))
    
    if type == 'classification':

        # 1. Load model (pretrained or vanilla)
        fc_name = 'last_linear'  # most common name of the last layer (to be replaced)
        if model_name in torch_models.__dict__:
            print('INFO: Loading a pretrained?: {}   model: {}'.format(pretrained, model_name))
            model = torch_models.__dict__[model_name](pretrained=pretrained)  # find a model included in the torchvision package
            if 'densenet' in model_name:    # apparently densenet is special..
                fc_name = 'classifier'  # the name of the last layer to be replaced in torchvision models
            else:
                fc_name = 'fc'  # the name of the last layer to be replaced in torchvision models
        else:
            if pretrained:
                print('INFO: Loading a pretrained model: {}'.format(model_name))
                if 'dpn' in model_name:
                    model = classification.__dict__[model_name](pretrained=True)  # find a model included in the wick classification package
                elif 'inception' in model_name or 'nasnet' in model_name or 'polynet' in model_name or 'resnext' in model_name\
                        or 'se_resnet' in model_name or 'xception' in model_name:
                    model = classification.__dict__[model_name](pretrained='imagenet')
            else:
                print('INFO: Loading a vanilla model: {}'.format(model_name))
                model = classification.__dict__[model_name](pretrained=None)  # pretrained must be set to None for the extra models... go figure

        # 2. Tweak FC layer with the num_classes provided
        # Custom handling of models that have non-standard classifier layers
        if 'squeezenet' in model_name:
            final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
            classifier = nn.Sequential(
                nn.Dropout(p=0.5),
                final_conv,
                nn.ReLU(inplace=True),
                nn.AvgPool2d(13, stride=1)
            )
            model.classifier = classifier
            model.num_classes = num_classes
        elif 'vgg' in model_name:
            classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes)
            )
            model.classifier = classifier
        elif 'inception3' in model_name.lower() or 'inception_v3' in model_name.lower():
            # Two FC layers if aux_logits are enabled
            if getattr(model, 'aux_logits', False):
                model.AuxLogits = InceptionAux(768, num_classes)
            model.fc = nn.Linear(2048, num_classes)

        elif 'dpn' in model_name.lower():
            old_fc = getattr(model, fc_name)
            new_fc = nn.Conv2d(old_fc.in_channels, num_classes, kernel_size=1, bias=True)
            setattr(model, fc_name, new_fc)
        else:  # perform standard replacement of the last FC / Linear layer with a new one
            old_fc = getattr(model, fc_name)
            new_fc = nn.Linear(old_fc.in_features, num_classes)
            setattr(model, fc_name, new_fc)

        return model

    elif type == 'segmentation':
        if model_name == 'Enet':                                            # standard enet
            net = ENet(num_classes=num_classes)
            if pretrained:
                print("WARN: Enet does not have a pretrained model! Empty model as been created instead.")
        elif model_name == 'deeplabv2_ASPP':                                # Deeplab Atrous Convolutions
            net = DeepLabv2_ASPP(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'deeplabv2_FOV':                                 # Deeplab FOV
            net = DeepLabv2_FOV(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'deeplabv3':                                     # Deeplab V3!
            net = DeepLabv3(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'deeplabv3_Plus':  # Deeplab V3!
            net = DeepLabv3_plus(num_classes=num_classes, pretrained=pretrained)
        elif 'DRN_' in model_name:
            net = DRNSeg(model_name=model_name, classes=num_classes, pretrained=pretrained)
        elif model_name == 'FRRN_A':                                        # FRRN
            net = frrn(num_classes=num_classes, model_type='A')
            if pretrained:
                print("FRRN_A Does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'FRRN_B':                                        # FRRN
            net = frrn(num_classes=num_classes, model_type='B')
            if pretrained:
                print("FRRN_B Does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'FusionNet':                                     # FusionNet
            net = FusionNet(num_classes=num_classes)
            if pretrained:
                print("FusionNet Does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'GCN':                                           # GCN Resnet
            net = GCN(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'GCN_VisDa':                                     # Another GCN Implementation
            net = GCN_VisDa(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'GCN_Densenet':                                     # Another GCN Implementation
            net = GCN_DENSENET(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'GCN_PSP':                                     # Another GCN Implementation
            net = GCN_PSP(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'GCN_NASNetA':                                     # Another GCN Implementation
            net = GCN_NASNET(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'GCN_Resnext':                                     # Another GCN Implementation
            net = GCN_RESNEXT(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'Linknet':                                       # Linknet34
            net = LinkNet34(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'PSPNet':
            net = PSPNet(num_classes=num_classes, pretrained=pretrained, backend='resnet101')
        elif model_name == 'Resnet_DUC':
            net = ResNetDUC(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Resnet_DUC_HDC':
            net = ResNetDUCHDC(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Resnet_GCN':                                    # GCN Resnet 2
            net = ResnetGCN(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Segnet':                                          # standard segnet
            net = SegNet(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_DiLinknet':
            net = TEST_DiLinknet(num_classes=num_classes, pretrained=False)
        elif model_name == 'TEST_DLR_Resnet':
            net = create_DLR_V3_pretrained(num_classes=num_classes)
        elif model_name == 'TEST_DLX_Resnet':
            net = create_DLX_V3_pretrained(num_classes=num_classes)
        elif model_name == 'TEST_PSPNet2':
            net = TEST_PSPNet2(num_classes=num_classes)
        elif model_name == 'TEST_DLV2':
            net = TEST_DLV2(n_classes=num_classes, n_blocks=[3, 4, 23, 3], pyramids=[6, 12, 18, 24])
            net = TEST_DLV3_Xception(n_classes=num_classes, os=8, pretrained=True, _print=False)
        elif model_name == 'TEST_DLV3':
            net = TEST_DLV3(n_classes=num_classes, n_blocks=[3, 4, 23, 3], pyramids=[12, 24, 36], grids=[1, 2, 4], output_stride=8)
        elif model_name == 'TEST_LinkCeption':
            net = TEST_LinkCeption(num_classes=num_classes)
        elif model_name == 'TEST_LinkDensenet121':
            net = TEST_LinkDenseNet121(num_classes=num_classes)
        elif model_name == 'TEST_Linknet50':
            net = TEST_Linknet101(num_classes=num_classes)
        elif model_name == 'TEST_Linknet101':
            net = TEST_Linknet101(num_classes=num_classes)
        elif model_name == 'TEST_Linknet152':
            net = TEST_Linknet152(num_classes=num_classes)
        elif model_name == 'TEST_Linknext':
            net = TEST_Linknext(num_classes=num_classes)
        elif model_name == 'TEST_FCDensenet':
            net = TEST_FCDensenet(out_channels=num_classes)
        elif model_name == 'TEST_Tiramisu57':
            net = TEST_Tiramisu57(num_classes=num_classes)
        elif model_name == 'TEST_Unet_nested_dilated':
            net = TEST_Unet_nested_dilated(n_classes=num_classes)
        elif model_name == 'TEST_Unet_plus_plus':
            net = Unet_Plus_Plus(in_channels=3, n_classes=num_classes)
        elif model_name == 'Tiramisu57':  # Tiramisu
            net = FCDenseNet57(n_classes=num_classes)
            if pretrained:
                print("Tiramisu67 Does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'Tiramisu67':                                     # Tiramisu
            net = FCDenseNet67(n_classes=num_classes)
            if pretrained:
                print("Tiramisu67 Does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'Tiramisu103':                                   # Tiramisu
            net = FCDenseNet103(n_classes=num_classes)
            if pretrained:
                print("Tiramisu103 Does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'Unet':                                          # standard unet
            net = UNet(num_classes=num_classes)
            if pretrained:
                print("UNet Does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet256':                                       # Unet for 256px square imgs
            net = UNet256(in_shape=(3,256,256))
            if pretrained:
                print("UNet256 Does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet512':                                       # Unet for 512px square imgs
            net = UNet512(in_shape=(3, 512, 512))
            if pretrained:
                print("UNet512 Does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet1024':                                      # Unet for 1024px square imgs
            net = UNet1024(in_shape=(3, 1024, 1024))
            if pretrained:
                print("UNet1024 Does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet960':                                       # Another Unet specifically with 960px resolution
            net = UNet960(filters=12)
            if pretrained:
                print("UNet960 Does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'unet_dilated':                                  # dilated unet
            net = uNetDilated(num_classes=num_classes)
        elif model_name == 'Unet_res':                                      # residual unet
            net = UNetRes(num_class=num_classes)
            if pretrained:
                print("UNet_res Does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet_stack':                                    # Stacked Unet variation with resnet connections
            net = UNet_stack(input_size=(input_size, input_size), filters=12)
            if pretrained:
                print("UNet_stack Does not have a pretrained model! Empty model has been created instead.")
        else:
            raise Exception('Combination of type: {} and model_name: {} is not valid'.format(type, model_name))

    return net
示例#9
0
    def __init__(self,
                 use_bottleneck=True,
                 bottleneck_dim=256,
                 new_cls=False,
                 class_num=1000,
                 aux_logits=True,
                 transform_input=False):
        """Build an Inception-v3 feature extractor with an optional bottleneck and classifier head.

        The convolutional backbone (``Conv2d_1a_3x3`` ... ``Mixed_7c``) is taken from
        an ImageNet-pretrained ``inception_v3`` so its weights are actually used.

        Args:
            use_bottleneck: only consulted when ``new_cls`` is True; if set, a
                ``Linear(in_features, bottleneck_dim)`` layer is inserted before the head.
            bottleneck_dim: output width of the bottleneck layer.
            new_cls: if True, create a new ``class_num``-way head (initialized with
                ``init_weights``); otherwise reuse the pretrained ``fc`` as-is.
            class_num: number of outputs of the new head and of the auxiliary head.
            aux_logits: stored on the instance; if True, a freshly initialized
                ``InceptionAux(768, class_num)`` is attached as ``self.AuxLogits``.
            transform_input: stored on the instance; not read in this method.
        """
        super(Inception3Fc, self).__init__()

        model_inception = inception_v3(pretrained=True)

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Reuse the pretrained backbone modules. Previously fresh, randomly initialized
        # layers were constructed here, so the weights loaded above were thrown away.
        self.Conv2d_1a_3x3 = model_inception.Conv2d_1a_3x3
        self.Conv2d_2a_3x3 = model_inception.Conv2d_2a_3x3
        self.Conv2d_2b_3x3 = model_inception.Conv2d_2b_3x3
        self.Conv2d_3b_1x1 = model_inception.Conv2d_3b_1x1
        self.Conv2d_4a_3x3 = model_inception.Conv2d_4a_3x3
        self.Mixed_5b = model_inception.Mixed_5b
        self.Mixed_5c = model_inception.Mixed_5c
        self.Mixed_5d = model_inception.Mixed_5d
        self.Mixed_6a = model_inception.Mixed_6a
        self.Mixed_6b = model_inception.Mixed_6b
        self.Mixed_6c = model_inception.Mixed_6c
        self.Mixed_6d = model_inception.Mixed_6d
        self.Mixed_6e = model_inception.Mixed_6e
        if aux_logits:
            # New auxiliary head sized for class_num (the pretrained one is ImageNet-sized).
            self.AuxLogits = InceptionAux(768, class_num)
        self.Mixed_7a = model_inception.Mixed_7a
        self.Mixed_7b = model_inception.Mixed_7b
        self.Mixed_7c = model_inception.Mixed_7c
        # Placeholder that is always replaced below; assigned here so ``fc`` keeps the
        # same position in the module registry (and so in parameters()/state_dict()).
        self.fc = nn.Linear(2048, class_num)

        # Same module objects as the attributes above (shared, not copied).
        self.feature_layers = nn.Sequential(
            self.Conv2d_1a_3x3,
            self.Conv2d_2a_3x3,
            self.Conv2d_2b_3x3,
            self.Conv2d_3b_1x1,
            self.Conv2d_4a_3x3,
            self.Mixed_5b,
            self.Mixed_5c,
            self.Mixed_5d,
            self.Mixed_6a,
            self.Mixed_6b,
            self.Mixed_6c,
            self.Mixed_6d,
            self.Mixed_6e,
            self.Mixed_7a,
            self.Mixed_7b,
            self.Mixed_7c,
        )

        self.use_bottleneck = use_bottleneck
        self.new_cls = new_cls
        backbone_features = model_inception.fc.in_features
        # NOTE: self.__in_features is name-mangled to _Inception3Fc__in_features;
        # presumably read by an accessor not shown here.
        if new_cls:
            if self.use_bottleneck:
                self.bottleneck = nn.Linear(backbone_features, bottleneck_dim)
                self.fc = nn.Linear(bottleneck_dim, class_num)
                self.bottleneck.apply(init_weights)
                self.fc.apply(init_weights)
                self.__in_features = bottleneck_dim
            else:
                self.fc = nn.Linear(backbone_features, class_num)
                self.fc.apply(init_weights)
                self.__in_features = backbone_features
        else:
            # Reuse the pretrained head: its output size is whatever model_inception.fc
            # produces (presumably 1000 ImageNet classes), independent of class_num.
            self.fc = model_inception.fc
            self.__in_features = backbone_features
示例#10
0
def get_model(model_type,
              model_name,
              num_classes,
              input_size,
              pretrained=True):
    """
    Instantiate a classification or segmentation model by name.

    :param model_type: (ModelType):
        type of model we're trying to obtain (classification or segmentation)
    :param model_name: (string):
        name of the model. By convention (for classification models) lowercase names represent pretrained model variants while Uppercase do not.
    :param num_classes: (int):
        number of classes to initialize with (this will replace the last classification layer or set the number of segmented classes)
    :param input_size: (int,int):
        Segmentation-only param. What size of input the network will accept e.g. (256, 256), (512, 512)
        NOTE(review): some branches below (RefineNet*, UNet_stack) use it as a scalar - confirm the expected type.
    :param pretrained: (bool):
        whether to load the default pretrained version of the model
        NOTE! NOTE! For classification, the lowercase model names are the pretrained variants while the Uppercase model names are not.
        It is IN ERROR to specify an Uppercase model name variant with pretrained=True but one can specify a lowercase model variant with pretrained=False
        (default: True)
    :return: model
    :raises ValueError: if model_name is not supported, if a pretrained classification
        model matches no known weight-loading convention, or if model_type is neither
        CLASSIFICATION nor SEGMENTATION.
    :raises Exception: if a segmentation model_name has no construction branch.
    """

    if model_name not in get_supported_models(
            model_type) and not model_name.startswith('TEST'):
        raise ValueError(
            'The supplied model name: {} was not found in the list of acceptable model names.'
            ' Use get_supported_models() to obtain a list of supported models.'
            .format(model_name))

    print("INFO: Loading Model:   --   " + model_name +
          "  with number of classes: " + str(num_classes))

    if model_type == ModelType.CLASSIFICATION:

        # 1. Load model (pretrained or vanilla)
        # Only the last layer's name matters: that is the layer replaced below.
        fc_name = get_fc_names(model_name=model_name, model_type=model_type)[-1]
        new_fc = None  # Custom layer to replace with (if none, then it will be handled generically)
        if model_name in torch_models.__dict__:
            print(
                'INFO: Loading torchvision model: {}\t Pretrained: {}'.format(
                    model_name, pretrained))
            model = torch_models.__dict__[model_name](
                pretrained=pretrained
            )  # find a model included in the torchvision package
        else:
            net_list = [
                'fbresnet', 'inception', 'mobilenet', 'nasnet', 'polynet',
                'resnext', 'se_resnet', 'senet', 'shufflenet', 'xception'
            ]
            if pretrained:
                print(
                    'INFO: Loading a pretrained model: {}'.format(model_name))
                if 'dpn' in model_name:
                    model = classification.__dict__[model_name](
                        pretrained=True
                    )  # find a model included in the pywick classification package
                elif any(net_name in model_name for net_name in net_list):
                    model = classification.__dict__[model_name](
                        pretrained='imagenet')
                else:
                    # Previously fell through with `model` unbound, failing later
                    # with an UnboundLocalError far from the real cause.
                    raise ValueError(
                        'No pretrained-weights loading convention is known for model: {}.'
                        ' Try pretrained=False.'.format(model_name))
            else:
                print('INFO: Loading a vanilla model: {}'.format(model_name))
                model = classification.__dict__[model_name](
                    pretrained=None
                )  # pretrained must be set to None for the extra models... go figure

        # 2. Create custom FC layers for non-standardized models
        if 'squeezenet' in model_name:
            final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
            new_fc = nn.Sequential(nn.Dropout(p=0.5), final_conv,
                                   nn.ReLU(inplace=True),
                                   nn.AvgPool2d(13, stride=1))
            model.num_classes = num_classes
        elif 'vgg' in model_name:
            new_fc = nn.Sequential(nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True),
                                   nn.Dropout(), nn.Linear(4096, 4096),
                                   nn.ReLU(True), nn.Dropout(),
                                   nn.Linear(4096, num_classes))
        elif 'inception3' in model_name.lower(
        ) or 'inception_v3' in model_name.lower():
            # Replace the extra aux_logits FC layer if aux_logits are enabled
            if getattr(model, 'aux_logits', False):
                model.AuxLogits = InceptionAux(768, num_classes)
        elif 'dpn' in model_name.lower():
            old_fc = getattr(model, fc_name)
            new_fc = nn.Conv2d(old_fc.in_channels,
                               num_classes,
                               kernel_size=1,
                               bias=True)

        # 3. For standard FC layers (nn.Linear) perform a reflection lookup and generate a new FC
        if new_fc is None:
            old_fc = getattr(model, fc_name)
            new_fc = nn.Linear(old_fc.in_features, num_classes)

        # 4. perform replacement of the last FC / Linear layer with a new one
        setattr(model, fc_name, new_fc)

        return model

    elif model_type == ModelType.SEGMENTATION:
        if model_name == 'Enet':  # standard enet
            net = ENet(num_classes=num_classes)
            if pretrained:
                print(
                    "WARN: Enet does not have a pretrained model! Empty model as been created instead."
                )
        elif model_name == 'deeplabv2_ASPP':  # Deeplab Atrous Convolutions
            net = DeepLabv2_ASPP(num_classes=num_classes,
                                 pretrained=pretrained)
        elif model_name == 'deeplabv2_FOV':  # Deeplab FOV
            net = DeepLabv2_FOV(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'deeplabv3':  # Deeplab V3!
            net = DeepLabv3(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'deeplabv3_Plus':  # Deeplab V3!
            net = DeepLabv3_plus(num_classes=num_classes,
                                 pretrained=pretrained)
        elif 'DRN_' in model_name:
            net = DRNSeg(model_name=model_name,
                         classes=num_classes,
                         pretrained=pretrained)
        elif model_name == 'FRRN_A':  # FRRN
            net = frrn(num_classes=num_classes, model_type='A')
            if pretrained:
                print(
                    "FRRN_A Does not have a pretrained model! Empty model has been created instead."
                )
        elif model_name == 'FRRN_B':  # FRRN
            net = frrn(num_classes=num_classes, model_type='B')
            if pretrained:
                print(
                    "FRRN_B Does not have a pretrained model! Empty model has been created instead."
                )
        elif model_name == 'FusionNet':  # FusionNet
            net = FusionNet(num_classes=num_classes)
            if pretrained:
                print(
                    "FusionNet Does not have a pretrained model! Empty model has been created instead."
                )
        elif model_name == 'GCN':  # GCN Resnet
            net = GCN(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'GCN_VisDa':  # Another GCN Implementation
            net = GCN_VisDa(num_classes=num_classes,
                            input_size=input_size,
                            pretrained=pretrained)
        elif model_name == 'GCN_Densenet':  # Another GCN Implementation
            net = GCN_DENSENET(num_classes=num_classes,
                               input_size=input_size,
                               pretrained=pretrained)
        elif model_name == 'GCN_PSP':  # Another GCN Implementation
            net = GCN_PSP(num_classes=num_classes,
                          input_size=input_size,
                          pretrained=pretrained)
        elif model_name == 'GCN_NASNetA':  # Another GCN Implementation
            net = GCN_NASNET(num_classes=num_classes,
                             input_size=input_size,
                             pretrained=pretrained)
        elif model_name == 'GCN_Resnext':  # Another GCN Implementation
            net = GCN_RESNEXT(num_classes=num_classes,
                              input_size=input_size,
                              pretrained=pretrained)
        elif model_name == 'Linknet':  # Linknet34
            net = LinkNet34(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'PSPNet':
            net = PSPNet(num_classes=num_classes,
                         pretrained=pretrained,
                         backend='resnet101')
        elif model_name == 'RefineNet4Cascade':
            net = RefineNet4Cascade((1, input_size),
                                    num_classes=num_classes,
                                    pretrained=pretrained)
        elif model_name == 'RefineNet4CascadePoolingImproved':
            # NOTE(review): builds the same class as the 'RefineNet4Cascade' branch;
            # a dedicated RefineNet4CascadePoolingImproved class may have been intended - confirm.
            net = RefineNet4Cascade((1, input_size),
                                    num_classes=num_classes,
                                    pretrained=pretrained)
        elif model_name == 'Resnet_DUC':
            net = ResNetDUC(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Resnet_DUC_HDC':
            net = ResNetDUCHDC(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Resnet_GCN':  # GCN Resnet 2
            net = ResnetGCN(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Segnet':  # standard segnet
            net = SegNet(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_BiSeNet_Res18':
            net = TEST_BiSeNet_Res18(num_classes=num_classes,
                                     pretrained=pretrained)
        elif model_name == 'TEST_DANet_Res50':
            net = TEST_DANet_Res50(num_classes=num_classes,
                                   pretrained=pretrained)
        elif model_name == 'TEST_DANet_Res101':
            net = TEST_DANet_Res101(num_classes=num_classes,
                                    pretrained=pretrained)
        elif model_name == 'TEST_DANet_Res152':
            net = TEST_DANet_Res152(num_classes=num_classes,
                                    pretrained=pretrained)
        elif model_name == 'TEST_DenseASPP_121':
            net = TEST_DenseASPP_121(num_classes=num_classes,
                                     pretrained=pretrained)
        elif model_name == 'TEST_DenseASPP_161':
            net = TEST_DenseASPP_161(num_classes=num_classes,
                                     pretrained=pretrained)
        elif model_name == 'TEST_DenseASPP_169':
            net = TEST_DenseASPP_169(num_classes=num_classes,
                                     pretrained=pretrained)
        elif model_name == 'TEST_DenseASPP_201':
            net = TEST_DenseASPP_201(num_classes=num_classes,
                                     pretrained=pretrained)
        elif model_name == 'TEST_DiLinknet':
            net = TEST_DiLinknet(num_classes=num_classes, pretrained=False)
        elif model_name == 'TEST_DLR_Resnet':
            net = create_DLR_V3_pretrained(num_classes=num_classes)
        elif model_name == 'TEST_DLX_Resnet':
            net = create_DLX_V3_pretrained(num_classes=num_classes)
        elif model_name == 'TEST_EncNet_Res101':
            net = TEST_EncNet_Res101(num_classes=num_classes,
                                     pretrained=pretrained)
        elif model_name == 'TEST_EncNet_Res152':
            net = TEST_EncNet_Res152(num_classes=num_classes,
                                     pretrained=pretrained)
        elif model_name == 'TEST_PSPNet2':
            net = TEST_PSPNet2(num_classes=num_classes)
        elif model_name == 'TEST_DLV2':
            net = TEST_DLV2(n_classes=num_classes,
                            n_blocks=[3, 4, 23, 3],
                            pyramids=[6, 12, 18, 24])
        elif model_name == 'TEST_DLV3_Xception':
            # Previously this construction sat inside the 'TEST_DLV2' branch and silently
            # replaced the DLV2 network; it now has its own name. Arguments are unchanged
            # (including the hard-coded pretrained=True).
            net = TEST_DLV3_Xception(n_classes=num_classes,
                                     os=8,
                                     pretrained=True,
                                     _print=False)
        elif model_name == 'TEST_DLV3':
            net = TEST_DLV3(n_classes=num_classes,
                            n_blocks=[3, 4, 23, 3],
                            pyramids=[12, 24, 36],
                            grids=[1, 2, 4],
                            output_stride=8)
        elif model_name == 'TEST_FCDensenet':
            net = TEST_FCDensenet(out_channels=num_classes)
        elif model_name == 'TEST_LinkCeption':
            net = TEST_LinkCeption(num_classes=num_classes,
                                   pretrained=pretrained)
        elif model_name == 'TEST_LinkDensenet121':
            net = TEST_LinkDenseNet121(num_classes=num_classes,
                                       pretrained=pretrained)
        elif model_name == 'TEST_LinkDensenet161':
            net = TEST_LinkDenseNet161(num_classes=num_classes,
                                       pretrained=pretrained)
        elif model_name == 'TEST_Linknet50':
            net = TEST_Linknet50(num_classes=num_classes,
                                 pretrained=pretrained)
        elif model_name == 'TEST_Linknet101':
            net = TEST_Linknet101(num_classes=num_classes,
                                  pretrained=pretrained)
        elif model_name == 'TEST_Linknet152':
            net = TEST_Linknet152(num_classes=num_classes,
                                  pretrained=pretrained)
        elif model_name == 'TEST_LinkNext_Mnas':
            net = TEST_LinkNext_Mnas(num_classes=num_classes,
                                     pretrained=pretrained)
        elif model_name == 'TEST_Linknext':
            net = TEST_Linknext(num_classes=num_classes)
        elif model_name == 'TEST_OCNet_Base_Res101':
            net = TEST_OCNet_Base_Res101(num_classes=num_classes,
                                         pretrained=pretrained)
        elif model_name == 'TEST_OCNet_ASP_Res101':
            net = TEST_OCNet_ASP_Res101(num_classes=num_classes,
                                        pretrained=pretrained)
        elif model_name == 'TEST_OCNet_Pyr_Res101':
            net = TEST_OCNet_Pyr_Res101(num_classes=num_classes,
                                        pretrained=pretrained)
        elif model_name == 'TEST_OCNet_Base_Res152':
            net = TEST_OCNet_Base_Res152(num_classes=num_classes,
                                         pretrained=pretrained)
        elif model_name == 'TEST_OCNet_ASP_Res152':
            net = TEST_OCNet_ASP_Res152(num_classes=num_classes,
                                        pretrained=pretrained)
        elif model_name == 'TEST_OCNet_Pyr_Res152':
            net = TEST_OCNet_Pyr_Res152(num_classes=num_classes,
                                        pretrained=pretrained)
        elif model_name == 'TEST_Tiramisu57':
            net = TEST_Tiramisu57(num_classes=num_classes)
        elif model_name == 'TEST_Unet_nested_dilated':
            net = TEST_Unet_nested_dilated(n_classes=num_classes)
        elif model_name == 'TEST_Unet_plus_plus':
            net = Unet_Plus_Plus(in_channels=3, n_classes=num_classes)
        elif model_name == 'Tiramisu57':  # Tiramisu
            net = FCDenseNet57(n_classes=num_classes)
            if pretrained:
                # Message previously named the wrong variant ("Tiramisu67").
                print(
                    "Tiramisu57 Does not have a pretrained model! Empty model has been created instead."
                )
        elif model_name == 'Tiramisu67':  # Tiramisu
            net = FCDenseNet67(n_classes=num_classes)
            if pretrained:
                print(
                    "Tiramisu67 Does not have a pretrained model! Empty model has been created instead."
                )
        elif model_name == 'Tiramisu103':  # Tiramisu
            net = FCDenseNet103(n_classes=num_classes)
            if pretrained:
                print(
                    "Tiramisu103 Does not have a pretrained model! Empty model has been created instead."
                )
        elif model_name == 'Unet':  # standard unet
            net = UNet(num_classes=num_classes)
            if pretrained:
                print(
                    "UNet Does not have a pretrained model! Empty model has been created instead."
                )
        elif model_name == 'UNet256':  # Unet for 256px square imgs
            net = UNet256(in_shape=(3, 256, 256))
            if pretrained:
                print(
                    "UNet256 Does not have a pretrained model! Empty model has been created instead."
                )
        elif model_name == 'UNet512':  # Unet for 512px square imgs
            net = UNet512(in_shape=(3, 512, 512))
            if pretrained:
                print(
                    "UNet512 Does not have a pretrained model! Empty model has been created instead."
                )
        elif model_name == 'UNet1024':  # Unet for 1024px square imgs
            net = UNet1024(in_shape=(3, 1024, 1024))
            if pretrained:
                print(
                    "UNet1024 Does not have a pretrained model! Empty model has been created instead."
                )
        elif model_name == 'UNet960':  # Another Unet specifically with 960px resolution
            net = UNet960(filters=12)
            if pretrained:
                print(
                    "UNet960 Does not have a pretrained model! Empty model has been created instead."
                )
        elif model_name == 'unet_dilated':  # dilated unet
            net = uNetDilated(num_classes=num_classes)
        elif model_name == 'Unet_res':  # residual unet
            net = UNetRes(num_class=num_classes)
            if pretrained:
                print(
                    "UNet_res Does not have a pretrained model! Empty model has been created instead."
                )
        elif model_name == 'UNet_stack':  # Stacked Unet variation with resnet connections
            net = UNet_stack(input_size=(input_size, input_size), filters=12)
            if pretrained:
                print(
                    "UNet_stack Does not have a pretrained model! Empty model has been created instead."
                )
        else:
            raise Exception(
                'Combination of type: {} and model_name: {} is not valid'.
                format(model_type, model_name))

        return net

    # Any other model_type previously reached `return net` with `net` unbound.
    raise ValueError('Unsupported model_type: {}'.format(model_type))
示例#11
0
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models.inception import InceptionAux

# Build an Inception-v3 with a 26-way main head and restore weights from a local checkpoint.
model = models.inception_v3()
model.fc = torch.nn.Linear(2048, 26)
# NOTE(review): this assigns to the boolean flag attribute ``aux_logits``. nn.Module then
# registers the new head as a submodule named ``aux_logits``. The existing ``model.AuxLogits``
# (with its default 1000-way fc) is left in place. Loading only succeeds if the checkpoint
# was saved from a model built the same way. If the training script used
# ``model.AuxLogits = InceptionAux(768, 26)``, change this line to match — confirm.
model.aux_logits = InceptionAux(768, 26)

# map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only machine.
# load_state_dict then copies the values into this (CPU) model's parameters.
# torch.load unpickles the file: only load trusted checkpoints.
checkpoint = torch.load('Best_inception_v3_fl_enhanced.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])


class MyInception(nn.Module):
    def __init__(self, num_classes, aux_logits=True):
        super(MyInception, self).__init__()
        self.aux_logits = aux_logits
        self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3
        self.Conv2d_2a_3x3 = model.Conv2d_2a_3x3
        self.Conv2d_2b_3x3 = model.Conv2d_2b_3x3
        self.Conv2d_3b_1x1 = model.Conv2d_3b_1x1
        self.Conv2d_4a_3x3 = model.Conv2d_4a_3x3
        self.Mixed_5b = model.Mixed_5b
        self.Mixed_5c = model.Mixed_5c
        self.Mixed_5d = model.Mixed_5d
        self.Mixed_6a = model.Mixed_6a
        self.Mixed_6b = model.Mixed_6b
        self.Mixed_6c = model.Mixed_6c
        self.Mixed_6d = model.Mixed_6d
        self.Mixed_6e = model.Mixed_6e
        if aux_logits: