import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models.inception import BasicConv2d, InceptionAux


def get_model(args):
    if args.model_name == 'inception':
        if args.pre_trained_checkpoint is not None:
            model = models.inception_v3(pretrained=False, transform_input=False)
            model.fc = nn.Linear(2048, args.n_outputs)
            model.AuxLogits = InceptionAux(768, args.n_outputs)
            model.aux_logits = False
            # Single-channel stem to match the checkpoint, then load its weights.
            new_conv = BasicConv2d(1, 32, kernel_size=3, stride=2)
            model.Conv2d_1a_3x3 = new_conv
            sd = torch.load(args.pre_trained_checkpoint)['model']
            model.load_state_dict(sd)
        else:
            model = models.inception_v3(pretrained=True, transform_input=False)
            model.fc = nn.Linear(2048, args.n_outputs)
            model.AuxLogits = InceptionAux(768, args.n_outputs)
            model.aux_logits = False
            # Adapt the pretrained RGB stem to 1-channel input by averaging the
            # first conv's weights across the colour dimension.
            new_conv = BasicConv2d(1, 32, kernel_size=3, stride=2)
            first_layer_sd = model.Conv2d_1a_3x3.state_dict()
            first_layer_sd['conv.weight'] = first_layer_sd['conv.weight'].mean(dim=1, keepdim=True)
            new_conv.load_state_dict(first_layer_sd)
            model.Conv2d_1a_3x3 = new_conv
    elif args.model_name == 'resnet':
        if args.pre_trained_checkpoint is not None:
            model = models.resnet50(pretrained=False)
            model.fc = nn.Linear(in_features=2048, out_features=args.n_outputs, bias=True)
            new_conv = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            model.conv1 = new_conv
            sd = torch.load(args.pre_trained_checkpoint)['model']
            model.load_state_dict(sd)
        else:
            model = models.resnet50(pretrained=True)
            model.fc = nn.Linear(in_features=2048, out_features=args.n_outputs, bias=True)
            # Same grayscale adaptation as above, on the resnet stem.
            new_conv = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            first_layer_sd = model.conv1.state_dict()
            first_layer_sd['weight'] = first_layer_sd['weight'].mean(dim=1, keepdim=True)
            new_conv.load_state_dict(first_layer_sd)
            model.conv1 = new_conv
    elif args.model_name == 'resnext':
        if args.pre_trained_checkpoint is not None:
            model = models.resnext50_32x4d(pretrained=False)
            model.fc = nn.Linear(in_features=2048, out_features=args.n_outputs, bias=True)
            new_conv = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            model.conv1 = new_conv
            sd = torch.load(args.pre_trained_checkpoint)['model']
            model.load_state_dict(sd)
        else:
            model = models.resnext50_32x4d(pretrained=True)
            model.fc = nn.Linear(in_features=2048, out_features=args.n_outputs, bias=True)
            new_conv = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            first_layer_sd = model.conv1.state_dict()
            first_layer_sd['weight'] = first_layer_sd['weight'].mean(dim=1, keepdim=True)
            new_conv.load_state_dict(first_layer_sd)
            model.conv1 = new_conv
    return model
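
# Hedged usage sketch for get_model above. argparse.Namespace stands in for the
# real CLI args; the field names come from the function body, but the values
# here are illustrative assumptions.
from argparse import Namespace

example_args = Namespace(model_name='resnet', n_outputs=5, pre_trained_checkpoint=None)
example_model = get_model(example_args)     # ResNet-50 with a 1-channel stem, 5 outputs
x = torch.randn(2, 1, 224, 224)             # single-channel (e.g. grayscale) batch
logits = example_model(x)                   # -> torch.Size([2, 5])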
# Inception3.__init__ (constructor excerpt); NLABEL and the BasicConv2d /
# Inception* blocks are defined elsewhere in the source module.
def __init__(self, num_classes=NLABEL, aux_logits=False, transform_input=False):
    super(Inception3, self).__init__()
    self.aux_logits = aux_logits
    self.transform_input = transform_input
    self.Conv2d_1a_3x3 = BasicConv2d(4, 32, kernel_size=3, stride=2)  # 4-channel input stem
    self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
    self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
    self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
    self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
    self.Mixed_5b = InceptionA(192, pool_features=32)
    self.Mixed_5c = InceptionA(256, pool_features=64)
    self.Mixed_5d = InceptionA(288, pool_features=64)
    self.Mixed_6a = InceptionB(288)
    self.Mixed_6b = InceptionC(768, channels_7x7=128)
    self.Mixed_6c = InceptionC(768, channels_7x7=160)
    self.Mixed_6d = InceptionC(768, channels_7x7=160)
    self.Mixed_6e = InceptionC(768, channels_7x7=192)
    if aux_logits:
        self.AuxLogits = InceptionAux(768, num_classes)
    self.Mixed_7a = InceptionD(768)
    self.Mixed_7b = InceptionE(1280)
    self.Mixed_7c = InceptionE(2048)
    self.fc = nn.Linear(2048, num_classes)
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    # Truncated-normal init for conv/linear weights; constant init for batch norm.
    for m in self.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            import scipy.stats as stats
            stddev = m.stddev if hasattr(m, 'stddev') else 0.1
            X = stats.truncnorm(-2, 2, scale=stddev)
            values = torch.Tensor(X.rvs(m.weight.numel()))
            values = values.view(m.weight.size())
            m.weight.data.copy_(values)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
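
# Minimal standalone sketch of the truncated-normal initialisation used in the
# constructor above: sample from N(0, stddev^2) truncated to +/- 2 standard
# deviations and copy into a layer's weight tensor. trunc_normal_init_ is a
# hypothetical helper name, not from the original source.
import scipy.stats as stats


def trunc_normal_init_(module, stddev=0.1):
    X = stats.truncnorm(-2, 2, scale=stddev)
    values = torch.as_tensor(X.rvs(module.weight.numel()), dtype=module.weight.dtype)
    with torch.no_grad():
        module.weight.copy_(values.view_as(module.weight))


demo_conv = nn.Conv2d(3, 32, kernel_size=3)
trunc_normal_init_(demo_conv)  # weights now ~ truncated normal with std 0.1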
def __init__(self, num_classes=1000, transform_input=True):
    super(InceptionV3UptoPool3, self).__init__()
    self.transform_input = transform_input
    self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
    self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
    self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
    self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
    self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
    self.Mixed_5b = InceptionA(192, pool_features=32)
    self.Mixed_5c = InceptionA(256, pool_features=64)
    self.Mixed_5d = InceptionA(288, pool_features=64)
    self.Mixed_6a = InceptionB(288)
    self.Mixed_6b = InceptionC(768, channels_7x7=128)
    self.Mixed_6c = InceptionC(768, channels_7x7=160)
    self.Mixed_6d = InceptionC(768, channels_7x7=160)
    self.Mixed_6e = InceptionC(768, channels_7x7=192)
    self.AuxLogits = InceptionAux(768, num_classes)
    self.Mixed_7a = InceptionD(768)
    self.Mixed_7b = InceptionE(1280)
    self.Mixed_7c = InceptionE(2048)
    self.fc = nn.Linear(2048, num_classes)
    for m in self.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            import scipy.stats as stats
            stddev = m.stddev if hasattr(m, 'stddev') else 0.1
            X = stats.truncnorm(-2, 2, scale=stddev)
            values = torch.Tensor(X.rvs(m.weight.data.numel()))
            values = values.view(m.weight.data.size())
            m.weight.data.copy_(values)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
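
# One way to populate InceptionV3UptoPool3 with torchvision's pretrained
# Inception v3 weights (an assumption: the block names here mirror torchvision's,
# so strict=False only skips parameters this class does not define).
import torchvision

inc_pool3 = InceptionV3UptoPool3(num_classes=1000)
inc_pool3.load_state_dict(
    torchvision.models.inception_v3(pretrained=True).state_dict(), strict=False)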
def __init__(self, pre=True):
    super().__init__()
    self.encoder = torchvision.models.inception_v3(pretrained=pre)
    conv1 = BasicConv2d(4, 32, kernel_size=3, stride=2)
    if pre:
        # Build the 4th input channel from the pretrained RGB stem: the new
        # channel's weights are the mean of the R and B channel weights.
        w = self.encoder.Conv2d_1a_3x3.conv.weight
        conv1.conv.weight = nn.Parameter(
            torch.cat((w, 0.5 * (w[:, :1, :, :] + w[:, 2:, :, :])), dim=1))
    self.encoder.Conv2d_1a_3x3 = conv1
    self.encoder.AuxLogits = InceptionAux(768, num_class())
    self.encoder.fc = nn.Linear(2048, num_class())
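
# Self-contained sketch of the first-layer trick above: extend a pretrained
# 3-channel conv to 4 input channels by synthesising the extra channel's
# weights (here, as in the class above, the mean of the R and B channels).
# The plain nn.Conv2d stands in for the pretrained stem.
old_stem = nn.Conv2d(3, 32, kernel_size=3, stride=2, bias=False)
new_stem = nn.Conv2d(4, 32, kernel_size=3, stride=2, bias=False)
w = old_stem.weight
new_stem.weight = nn.Parameter(torch.cat((w, 0.5 * (w[:, :1] + w[:, 2:])), dim=1))
assert new_stem.weight.shape == (32, 4, 3, 3)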
def get_model(args):
    if args.model_name == 'inception':
        if args.pre_trained_checkpoint is not None:
            model = models.inception_v3(pretrained=False, transform_input=False)
            model.fc = nn.Linear(2048, args.n_outputs)
            model.AuxLogits = InceptionAux(768, args.n_outputs)
            model.aux_logits = False
            new_conv = BasicConv2d(1, 32, kernel_size=3, stride=2)
            model.Conv2d_1a_3x3 = new_conv
            sd = torch.load(args.pre_trained_checkpoint)['model']
            model.load_state_dict(sd)
            return model
    # Original code printed 'Missing model.' and then returned an unbound
    # variable; fail loudly instead.
    raise ValueError('Missing model: unknown model_name {!r} or no checkpoint supplied.'.format(args.model_name))
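
# The checkpoint loaders above index the saved dict with a 'model' key; this is
# a matching save-side sketch ('checkpoint.pth' is an illustrative path).
def save_checkpoint(model, path='checkpoint.pth'):
    torch.save({'model': model.state_dict()}, path)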
def __init__(self, num_classes=80, aux_logits=True, transform_input=False, apply_avgpool=False):
    super(Inception3, self).__init__()
    self.aux_logits = aux_logits
    self.transform_input = transform_input
    self.apply_avgpool = apply_avgpool
    self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
    self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
    self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
    self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
    self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
    self.Mixed_5b = InceptionA(192, pool_features=32)
    self.Mixed_5c = InceptionA(256, pool_features=64)
    self.Mixed_5d = InceptionA(288, pool_features=64)
    self.Mixed_6a = InceptionB(288)
    self.Mixed_6b = InceptionC(768, channels_7x7=128)
    self.Mixed_6c = InceptionC(768, channels_7x7=160)
    self.Mixed_6d = InceptionC(768, channels_7x7=160)
    self.Mixed_6e = InceptionC(768, channels_7x7=192)
    if aux_logits:
        self.AuxLogits = InceptionAux(768, num_classes)
    self.Mixed_7a = InceptionD(768)
    self.Mixed_7b = InceptionE(1280)
    self.Mixed_7c = InceptionE(2048)
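
# The constructor above stops at Mixed_7c (a headless feature extractor). A
# hedged sketch of the forward pass it implies, following torchvision's layer
# ordering (the max-pools between stages are functional, so they do not show
# up in __init__). For 3x299x299 inputs this yields (N, 2048, 8, 8), or
# (N, 2048, 1, 1) when apply_avgpool is set.
import torch.nn.functional as F


def forward(self, x):
    x = self.Conv2d_1a_3x3(x)
    x = self.Conv2d_2a_3x3(x)
    x = self.Conv2d_2b_3x3(x)
    x = F.max_pool2d(x, kernel_size=3, stride=2)
    x = self.Conv2d_3b_1x1(x)
    x = self.Conv2d_4a_3x3(x)
    x = F.max_pool2d(x, kernel_size=3, stride=2)
    x = self.Mixed_5b(x)
    x = self.Mixed_5c(x)
    x = self.Mixed_5d(x)
    x = self.Mixed_6a(x)
    x = self.Mixed_6b(x)
    x = self.Mixed_6c(x)
    x = self.Mixed_6d(x)
    x = self.Mixed_6e(x)
    x = self.Mixed_7a(x)
    x = self.Mixed_7b(x)
    x = self.Mixed_7c(x)
    if self.apply_avgpool:
        x = F.adaptive_avg_pool2d(x, (1, 1))
    return x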
def get_model(model_type, model_name, num_classes, pretrained=True, **kwargs):
    """
    :param model_type: (ModelType): type of model we're trying to obtain (classification or segmentation)
    :param model_name: (string): name of the model. By convention (for classification models) lowercase names
        represent pretrained model variants while Uppercase do not.
    :param num_classes: (int): number of classes to initialize with (this will replace the last classification
        layer or set the number of segmented classes)
    :param pretrained: (bool): whether to load the default pretrained version of the model

        NOTE! For classification, the lowercase model names are the pretrained variants while the Uppercase
        model names are not. The only exception applies to torch.hub models (all efficientnet, mixnet,
        mobilenetv3, mnasnet, spnasnet variants), where a single lowercase string can be used for both vanilla
        and pretrained versions. Otherwise, it is an ERROR to specify an Uppercase model name variant with
        pretrained=True, but one can specify a lowercase model variant with pretrained=False (default: True)

    :return: model
    """
    if model_name not in get_supported_models(model_type) and not model_name.startswith('TEST'):
        raise ValueError('The supplied model name: {} was not found in the list of acceptable model names.'
                         ' Use get_supported_models() to obtain a list of supported models.'.format(model_name))
    print("INFO: Loading Model: -- " + model_name + " with number of classes: " + str(num_classes))

    if model_type == ModelType.CLASSIFICATION:
        torch_hub_names = torch.hub.list('rwightman/gen-efficientnet-pytorch')
        if model_name in torch_hub_names:
            model = torch.hub.load('rwightman/gen-efficientnet-pytorch', model_name,
                                   pretrained=pretrained, num_classes=num_classes)
        else:
            # 1. Load model (pretrained or vanilla)
            fc_name = get_fc_names(model_name=model_name, model_type=model_type)[-1]  # last layer name only
            new_fc = None  # custom layer to replace with (if None, it will be handled generically)
            if model_name in torch_models.__dict__:
                print('INFO: Loading torchvision model: {}\t Pretrained: {}'.format(model_name, pretrained))
                model = torch_models.__dict__[model_name](pretrained=pretrained)  # a torchvision model
            else:
                net_list = ['fbresnet', 'inception', 'mobilenet', 'nasnet', 'polynet', 'resnext',
                            'se_resnet', 'senet', 'shufflenet', 'xception']
                if pretrained:
                    print('INFO: Loading a pretrained model: {}'.format(model_name))
                    if 'dpn' in model_name:
                        model = classification.__dict__[model_name](pretrained=True)  # pywick classification package
                    elif any(net_name in model_name for net_name in net_list):
                        model = classification.__dict__[model_name](pretrained='imagenet')
                else:
                    print('INFO: Loading a vanilla model: {}'.format(model_name))
                    model = classification.__dict__[model_name](pretrained=None)  # pretrained must be None for the extra models... go figure

            # 2. Create custom FC layers for non-standardized models
            if 'squeezenet' in model_name:
                final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
                new_fc = nn.Sequential(nn.Dropout(p=0.5), final_conv, nn.ReLU(inplace=True),
                                       nn.AvgPool2d(13, stride=1))
                model.num_classes = num_classes
            elif 'vgg' in model_name:
                new_fc = nn.Sequential(nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(),
                                       nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(),
                                       nn.Linear(4096, num_classes))
            elif 'inception3' in model_name.lower() or 'inception_v3' in model_name.lower():
                # Replace the extra aux_logits FC layer if aux_logits are enabled
                if getattr(model, 'aux_logits', False):
                    model.AuxLogits = InceptionAux(768, num_classes)
            elif 'dpn' in model_name.lower():
                old_fc = getattr(model, fc_name)
                new_fc = nn.Conv2d(old_fc.in_channels, num_classes, kernel_size=1, bias=True)

            # 3. For standard FC layers (nn.Linear) perform a reflection lookup and generate a new FC
            if new_fc is None:
                old_fc = getattr(model, fc_name)
                new_fc = nn.Linear(old_fc.in_features, num_classes)

            # 4. Perform replacement of the last FC / Linear layer with a new one
            setattr(model, fc_name, new_fc)
        return model

    elif model_type == ModelType.SEGMENTATION:
        """
        Additional Segmentation Option Parameters
        -----------------------------------------
        BiSeNet
            - :param backbone: (str, default: 'resnet18') The type of backbone to use (one of `{'resnet18'}`)
            - :param aux: (bool, default: False) Whether to output auxiliary loss (typically an FC loss to help with multi-class segmentation)
        DANet_ResnetXXX, DUNet_ResnetXXX, EncNet, OCNet_XXX_XXX, PSANet_XXX
            - :param aux: (bool, default: False) Whether to output auxiliary loss (typically an FC loss to help with multi-class segmentation)
            - :param backbone: (str, default: 'resnet101') The type of backbone to use (one of `{'resnet50', 'resnet101', 'resnet152'}`)
            - :param norm_layer: (PyTorch nn.Module, default: nn.BatchNorm2d) The normalization layer to use. Typically it is not necessary to change this parameter unless you know what you're doing.
        DenseASPP_XXX
            - :param aux: (bool, default: False) Whether to output auxiliary loss (typically an FC loss to help with multi-class segmentation)
            - :param backbone: (str, default: 'densenet161') The type of backbone to use (one of `{'densenet121', 'densenet161', 'densenet169', 'densenet201'}`)
            - :param dilate_scale: (int, default: 8) The size of the dilation to use (one of `{8, 16}`)
            - :param norm_layer: (PyTorch nn.Module, default: nn.BatchNorm2d) The normalization layer to use. Typically it is not necessary to change this parameter unless you know what you're doing.
        DRNSeg
            - :param model_name: (str - required) The type of backbone to use. One of `{'DRN_C_42', 'DRN_C_58', 'DRN_D_38', 'DRN_D_54', 'DRN_D_105'}`
        EncNet_ResnetXXX
            - :param aux: (bool, default: False) Whether to output auxiliary loss (typically an FC loss to help with multi-class segmentation)
            - :param backbone: (str, default: 'resnet101') The type of backbone to use (one of `{'resnet50', 'resnet101', 'resnet152'}`)
            - :param norm_layer: (PyTorch nn.Module, default: nn.BatchNorm2d) The normalization layer to use. Typically it is not necessary to change this parameter unless you know what you're doing.
            - :param se_loss: (bool, default: True) Whether to compute se_loss
            - :param lateral: (bool, default: False)
        frrn
            - :param model_type: (str - required) The type of model to use. One of `{'A', 'B'}`
        GCN, GCN_DENSENET, GCN_NASNET, GCN_PSP, GCN_RESNEXT
            - :param k: (int - optional) The size of global kernel
        GCN_PSP, GCN_RESNEXT, Unet_stack
            - :param input_size: (int - required) The size of output image (will be square)
        LinkCeption, LinkDenseNet121, LinkDenseNet161, LinkInceptionResNet, LinkNet18, LinkNet34, LinkNet50, LinkNet101, LinkNet152, LinkNeXt, CoarseLinkNet50
            - :param num_channels: (int, default: 3) Number of channels in the image (e.g. 3 = RGB)
            - :param is_deconv: (bool, default: False)
            - :param decoder_kernel_size: (int, default: 3) Size of the decoder kernel
        PSPNet
            - :param backend: (str, default: 'densenet121') The type of extractor to use. One of `{'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'densenet121'}`
        RefineNet4Cascade, RefineNet4CascadePoolingImproved
            - :param input_shape: (tuple(int, int), default: (1, 512) - required!) Tuple representing input shape (num_channels, dim)
            - :param freeze_resnet: (bool, default: False) Whether to freeze the underlying resnet
        """
        model_exists = False
        for m_name in get_supported_models(model_type):
            if model_name in m_name:
                model_exists = True
                break

        if model_exists:
            # Print warnings and helpful messages for nets that require additional configuration
            if model_name in ['GCN_PSP', 'GCN_RESNEXT', 'RefineNet4Cascade',
                              'RefineNet4CascadePoolingImproved', 'Unet_stack']:
                print('WARN: Did you remember to set the input_size parameter: (int) ?')
            elif model_name in ['RefineNet4Cascade', 'RefineNet4CascadePoolingImproved']:
                print('WARN: Did you remember to set the input_shape parameter: tuple(int, int)?')

            # logic to switch between different constructors
            if (model_name in ['FusionNet', 'Enet', 'frrn', 'Tiramisu57', 'Tiramisu67', 'Tiramisu101']
                    or model_name.startswith('UNet')) and pretrained:
                print('WARN: FusionNet, Enet, FRRN, Tiramisu, UNetXXX do not have a pretrained model! '
                      'Empty model has been created instead.')
            net = segmentation.__dict__[model_name](num_classes=num_classes, pretrained=pretrained, **kwargs)
        else:
            raise Exception('Combination of type: {} and model_name: {} is not valid'.format(model_type, model_name))
        return net
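
# Hedged usage sketch for the dispatcher above. get_model, ModelType,
# get_supported_models and the model registries come from the surrounding
# (pywick-style) module; the two model names are assumptions that must appear
# in get_supported_models() for the calls to succeed.
clf_model = get_model(ModelType.CLASSIFICATION, 'resnet18', num_classes=10, pretrained=True)
seg_model = get_model(ModelType.SEGMENTATION, 'Unet', num_classes=2, pretrained=False)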
def get_model(type, model_name, num_classes, input_size, pretrained=True):
    '''
    :param type: str
        one of {'classification', 'segmentation'}
    :param model_name: str
        name of the model. By convention (for classification models) lowercase names represent pretrained
        model variants while Uppercase do not.
    :param num_classes: int
        number of classes to initialize with (this will replace the last classification layer or set the
        number of segmented classes)
    :param input_size: (int, int)
        Segmentation-only param. What size of input the network will accept e.g. (256, 256), (512, 512)
    :param pretrained: bool
        whether to load the default pretrained version of the model

        NOTE! For classification, the lowercase model names are the pretrained variants while the Uppercase
        model names are not. It is an ERROR to specify an Uppercase model name variant with pretrained=True,
        but one can specify a lowercase model variant with pretrained=False (default: True)

    :return: model
    '''
    if model_name not in get_supported_models(type) and not model_name.startswith('TEST'):
        raise ValueError('The supplied model name: {} was not found in the list of acceptable model names.'
                         ' Use get_supported_models() to obtain a list of supported models.'.format(model_name))
    print("INFO: Loading Model: -- " + model_name + " with number of classes: " + str(num_classes))

    if type == 'classification':
        # 1. Load model (pretrained or vanilla)
        fc_name = 'last_linear'  # most common name of the last layer (to be replaced)
        if model_name in torch_models.__dict__:
            print('INFO: Loading a pretrained?: {} model: {}'.format(pretrained, model_name))
            model = torch_models.__dict__[model_name](pretrained=pretrained)  # a model from the torchvision package
            if 'densenet' in model_name:  # apparently densenet is special..
                fc_name = 'classifier'  # the name of the last layer to be replaced in torchvision models
            else:
                fc_name = 'fc'  # the name of the last layer to be replaced in torchvision models
        else:
            if pretrained:
                print('INFO: Loading a pretrained model: {}'.format(model_name))
                if 'dpn' in model_name:
                    model = classification.__dict__[model_name](pretrained=True)  # a model from the pywick classification package
                elif 'inception' in model_name or 'nasnet' in model_name or 'polynet' in model_name \
                        or 'resnext' in model_name or 'se_resnet' in model_name or 'xception' in model_name:
                    model = classification.__dict__[model_name](pretrained='imagenet')
            else:
                print('INFO: Loading a vanilla model: {}'.format(model_name))
                model = classification.__dict__[model_name](pretrained=None)  # pretrained must be set to None for the extra models... go figure

        # 2. Tweak FC layer with the num_classes provided
        # Custom handling of models that have non-standard classifier layers
        if 'squeezenet' in model_name:
            final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
            classifier = nn.Sequential(
                nn.Dropout(p=0.5),
                final_conv,
                nn.ReLU(inplace=True),
                nn.AvgPool2d(13, stride=1)
            )
            model.classifier = classifier
            model.num_classes = num_classes
        elif 'vgg' in model_name:
            classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes)
            )
            model.classifier = classifier
        elif 'inception3' in model_name.lower() or 'inception_v3' in model_name.lower():
            # Two FC layers to replace if aux_logits are enabled
            if getattr(model, 'aux_logits', False):
                model.AuxLogits = InceptionAux(768, num_classes)
            model.fc = nn.Linear(2048, num_classes)
        elif 'dpn' in model_name.lower():
            old_fc = getattr(model, fc_name)
            new_fc = nn.Conv2d(old_fc.in_channels, num_classes, kernel_size=1, bias=True)
            setattr(model, fc_name, new_fc)
        else:
            # perform standard replacement of the last FC / Linear layer with a new one
            old_fc = getattr(model, fc_name)
            new_fc = nn.Linear(old_fc.in_features, num_classes)
            setattr(model, fc_name, new_fc)
        return model

    elif type == 'segmentation':
        if model_name == 'Enet':  # standard enet
            net = ENet(num_classes=num_classes)
            if pretrained:
                print("WARN: Enet does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'deeplabv2_ASPP':  # Deeplab Atrous Convolutions
            net = DeepLabv2_ASPP(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'deeplabv2_FOV':  # Deeplab FOV
            net = DeepLabv2_FOV(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'deeplabv3':  # Deeplab V3!
            net = DeepLabv3(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'deeplabv3_Plus':  # Deeplab V3+
            net = DeepLabv3_plus(num_classes=num_classes, pretrained=pretrained)
        elif 'DRN_' in model_name:
            net = DRNSeg(model_name=model_name, classes=num_classes, pretrained=pretrained)
        elif model_name == 'FRRN_A':  # FRRN
            net = frrn(num_classes=num_classes, model_type='A')
            if pretrained:
                print("FRRN_A does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'FRRN_B':  # FRRN
            net = frrn(num_classes=num_classes, model_type='B')
            if pretrained:
                print("FRRN_B does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'FusionNet':  # FusionNet
            net = FusionNet(num_classes=num_classes)
            if pretrained:
                print("FusionNet does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'GCN':  # GCN Resnet
            net = GCN(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'GCN_VisDa':  # Another GCN Implementation
            net = GCN_VisDa(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'GCN_Densenet':  # Another GCN Implementation
            net = GCN_DENSENET(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'GCN_PSP':  # Another GCN Implementation
            net = GCN_PSP(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'GCN_NASNetA':  # Another GCN Implementation
            net = GCN_NASNET(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'GCN_Resnext':  # Another GCN Implementation
            net = GCN_RESNEXT(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'Linknet':  # Linknet34
            net = LinkNet34(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'PSPNet':
            net = PSPNet(num_classes=num_classes, pretrained=pretrained, backend='resnet101')
        elif model_name == 'Resnet_DUC':
            net = ResNetDUC(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Resnet_DUC_HDC':
            net = ResNetDUCHDC(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Resnet_GCN':  # GCN Resnet 2
            net = ResnetGCN(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Segnet':  # standard segnet
            net = SegNet(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_DiLinknet':
            net = TEST_DiLinknet(num_classes=num_classes, pretrained=False)
        elif model_name == 'TEST_DLR_Resnet':
            net = create_DLR_V3_pretrained(num_classes=num_classes)
        elif model_name == 'TEST_DLX_Resnet':
            net = create_DLX_V3_pretrained(num_classes=num_classes)
        elif model_name == 'TEST_PSPNet2':
            net = TEST_PSPNet2(num_classes=num_classes)
        elif model_name == 'TEST_DLV2':
            net = TEST_DLV2(n_classes=num_classes, n_blocks=[3, 4, 23, 3], pyramids=[6, 12, 18, 24])
        elif model_name == 'TEST_DLV3_Xception':
            net = TEST_DLV3_Xception(n_classes=num_classes, os=8, pretrained=True, _print=False)
        elif model_name == 'TEST_DLV3':
            net = TEST_DLV3(n_classes=num_classes, n_blocks=[3, 4, 23, 3], pyramids=[12, 24, 36],
                            grids=[1, 2, 4], output_stride=8)
        elif model_name == 'TEST_LinkCeption':
            net = TEST_LinkCeption(num_classes=num_classes)
        elif model_name == 'TEST_LinkDensenet121':
            net = TEST_LinkDenseNet121(num_classes=num_classes)
        elif model_name == 'TEST_Linknet50':
            net = TEST_Linknet50(num_classes=num_classes)
        elif model_name == 'TEST_Linknet101':
            net = TEST_Linknet101(num_classes=num_classes)
        elif model_name == 'TEST_Linknet152':
            net = TEST_Linknet152(num_classes=num_classes)
        elif model_name == 'TEST_Linknext':
            net = TEST_Linknext(num_classes=num_classes)
        elif model_name == 'TEST_FCDensenet':
            net = TEST_FCDensenet(out_channels=num_classes)
        elif model_name == 'TEST_Tiramisu57':
            net = TEST_Tiramisu57(num_classes=num_classes)
        elif model_name == 'TEST_Unet_nested_dilated':
            net = TEST_Unet_nested_dilated(n_classes=num_classes)
        elif model_name == 'TEST_Unet_plus_plus':
            net = Unet_Plus_Plus(in_channels=3, n_classes=num_classes)
        elif model_name == 'Tiramisu57':  # Tiramisu
            net = FCDenseNet57(n_classes=num_classes)
            if pretrained:
                print("Tiramisu57 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'Tiramisu67':  # Tiramisu
            net = FCDenseNet67(n_classes=num_classes)
            if pretrained:
                print("Tiramisu67 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'Tiramisu103':  # Tiramisu
            net = FCDenseNet103(n_classes=num_classes)
            if pretrained:
                print("Tiramisu103 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'Unet':  # standard unet
            net = UNet(num_classes=num_classes)
            if pretrained:
                print("UNet does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet256':  # Unet for 256px square imgs
            net = UNet256(in_shape=(3, 256, 256))
            if pretrained:
                print("UNet256 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet512':  # Unet for 512px square imgs
            net = UNet512(in_shape=(3, 512, 512))
            if pretrained:
                print("UNet512 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet1024':  # Unet for 1024px square imgs
            net = UNet1024(in_shape=(3, 1024, 1024))
            if pretrained:
                print("UNet1024 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet960':  # Another Unet specifically with 960px resolution
            net = UNet960(filters=12)
            if pretrained:
                print("UNet960 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'unet_dilated':  # dilated unet
            net = uNetDilated(num_classes=num_classes)
        elif model_name == 'Unet_res':  # residual unet
            net = UNetRes(num_class=num_classes)
            if pretrained:
                print("UNet_res does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet_stack':  # Stacked Unet variation with resnet connections
            net = UNet_stack(input_size=(input_size, input_size), filters=12)
            if pretrained:
                print("UNet_stack does not have a pretrained model! Empty model has been created instead.")
        else:
            raise Exception('Combination of type: {} and model_name: {} is not valid'.format(type, model_name))
        return net
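
# Hedged example calls for the string-typed variant above (the model names are
# assumptions and must appear in get_supported_models(type)):
seg_net = get_model('segmentation', 'Unet', num_classes=2, input_size=256, pretrained=False)
cls_net = get_model('classification', 'resnet18', num_classes=10, input_size=None, pretrained=True)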
def __init__(self, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000,
             aux_logits=True, transform_input=False):
    super(Inception3Fc, self).__init__()
    model_inception = inception_v3(pretrained=True)
    self.aux_logits = aux_logits
    self.transform_input = transform_input
    self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
    self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
    self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
    self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
    self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
    self.Mixed_5b = InceptionA(192, pool_features=32)
    self.Mixed_5c = InceptionA(256, pool_features=64)
    self.Mixed_5d = InceptionA(288, pool_features=64)
    self.Mixed_6a = InceptionB(288)
    self.Mixed_6b = InceptionC(768, channels_7x7=128)
    self.Mixed_6c = InceptionC(768, channels_7x7=160)
    self.Mixed_6d = InceptionC(768, channels_7x7=160)
    self.Mixed_6e = InceptionC(768, channels_7x7=192)
    if aux_logits:
        self.AuxLogits = InceptionAux(768, class_num)
    self.Mixed_7a = InceptionD(768)
    self.Mixed_7b = InceptionE(1280)
    self.Mixed_7c = InceptionE(2048)
    self.fc = nn.Linear(2048, class_num)
    # self.avgpool = model_xception.avgpool
    self.feature_layers = nn.Sequential(
        self.Conv2d_1a_3x3,
        self.Conv2d_2a_3x3,
        self.Conv2d_2b_3x3,
        self.Conv2d_3b_1x1,
        self.Conv2d_4a_3x3,
        self.Mixed_5b,
        self.Mixed_5c,
        self.Mixed_5d,
        self.Mixed_6a,
        self.Mixed_6b,
        self.Mixed_6c,
        self.Mixed_6d,
        self.Mixed_6e,
        self.Mixed_7a,
        self.Mixed_7b,
        self.Mixed_7c,
    )
    ####################
    self.use_bottleneck = use_bottleneck
    self.new_cls = new_cls
    # print("classes inside network", new_cls)
    if new_cls:
        if self.use_bottleneck:
            # Insert a bottleneck between the 2048-d features and the new head.
            print(bottleneck_dim)
            self.bottleneck = nn.Linear(model_inception.fc.in_features, bottleneck_dim)
            self.fc = nn.Linear(bottleneck_dim, class_num)
            self.bottleneck.apply(init_weights)
            self.fc.apply(init_weights)
            self.__in_features = bottleneck_dim
        else:
            self.fc = nn.Linear(model_inception.fc.in_features, class_num)
            self.fc.apply(init_weights)
            self.__in_features = model_inception.fc.in_features
    else:
        # Keep the original ImageNet classifier head.
        self.fc = model_inception.fc
        self.__in_features = model_inception.fc.in_features
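
# Hedged instantiation sketch for Inception3Fc above: route the 2048-d features
# through a 256-d bottleneck into a fresh 31-way classifier head (class_num=31
# is an arbitrary example; init_weights comes from the surrounding module).
backbone = Inception3Fc(use_bottleneck=True, bottleneck_dim=256, new_cls=True, class_num=31)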
def get_model(model_type, model_name, num_classes, input_size, pretrained=True):
    """
    :param model_type: (ModelType): type of model we're trying to obtain (classification or segmentation)
    :param model_name: (string): name of the model. By convention (for classification models) lowercase names
        represent pretrained model variants while Uppercase do not.
    :param num_classes: (int): number of classes to initialize with (this will replace the last classification
        layer or set the number of segmented classes)
    :param input_size: (int, int): Segmentation-only param. What size of input the network will accept
        e.g. (256, 256), (512, 512)
    :param pretrained: (bool): whether to load the default pretrained version of the model

        NOTE! For classification, the lowercase model names are the pretrained variants while the Uppercase
        model names are not. It is an ERROR to specify an Uppercase model name variant with pretrained=True,
        but one can specify a lowercase model variant with pretrained=False (default: True)

    :return: model
    """
    if model_name not in get_supported_models(model_type) and not model_name.startswith('TEST'):
        raise ValueError('The supplied model name: {} was not found in the list of acceptable model names.'
                         ' Use get_supported_models() to obtain a list of supported models.'.format(model_name))
    print("INFO: Loading Model: -- " + model_name + " with number of classes: " + str(num_classes))

    if model_type == ModelType.CLASSIFICATION:
        # 1. Load model (pretrained or vanilla)
        fc_name = get_fc_names(model_name=model_name, model_type=model_type)[-1]  # last layer name only
        new_fc = None  # custom layer to replace with (if None, it will be handled generically)
        if model_name in torch_models.__dict__:
            print('INFO: Loading torchvision model: {}\t Pretrained: {}'.format(model_name, pretrained))
            model = torch_models.__dict__[model_name](pretrained=pretrained)  # a torchvision model
        else:
            net_list = ['fbresnet', 'inception', 'mobilenet', 'nasnet', 'polynet', 'resnext',
                        'se_resnet', 'senet', 'shufflenet', 'xception']
            if pretrained:
                print('INFO: Loading a pretrained model: {}'.format(model_name))
                if 'dpn' in model_name:
                    model = classification.__dict__[model_name](pretrained=True)  # pywick classification package
                elif any(net_name in model_name for net_name in net_list):
                    model = classification.__dict__[model_name](pretrained='imagenet')
            else:
                print('INFO: Loading a vanilla model: {}'.format(model_name))
                model = classification.__dict__[model_name](pretrained=None)  # pretrained must be None for the extra models... go figure

        # 2. Create custom FC layers for non-standardized models
        if 'squeezenet' in model_name:
            final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
            new_fc = nn.Sequential(nn.Dropout(p=0.5), final_conv, nn.ReLU(inplace=True),
                                   nn.AvgPool2d(13, stride=1))
            model.num_classes = num_classes
        elif 'vgg' in model_name:
            new_fc = nn.Sequential(nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(),
                                   nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(),
                                   nn.Linear(4096, num_classes))
        elif 'inception3' in model_name.lower() or 'inception_v3' in model_name.lower():
            # Replace the extra aux_logits FC layer if aux_logits are enabled
            if getattr(model, 'aux_logits', False):
                model.AuxLogits = InceptionAux(768, num_classes)
        elif 'dpn' in model_name.lower():
            old_fc = getattr(model, fc_name)
            new_fc = nn.Conv2d(old_fc.in_channels, num_classes, kernel_size=1, bias=True)

        # 3. For standard FC layers (nn.Linear) perform a reflection lookup and generate a new FC
        if new_fc is None:
            old_fc = getattr(model, fc_name)
            new_fc = nn.Linear(old_fc.in_features, num_classes)

        # 4. Perform replacement of the last FC / Linear layer with a new one
        setattr(model, fc_name, new_fc)
        return model

    elif model_type == ModelType.SEGMENTATION:
        if model_name == 'Enet':  # standard enet
            net = ENet(num_classes=num_classes)
            if pretrained:
                print("WARN: Enet does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'deeplabv2_ASPP':  # Deeplab Atrous Convolutions
            net = DeepLabv2_ASPP(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'deeplabv2_FOV':  # Deeplab FOV
            net = DeepLabv2_FOV(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'deeplabv3':  # Deeplab V3!
            net = DeepLabv3(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'deeplabv3_Plus':  # Deeplab V3+
            net = DeepLabv3_plus(num_classes=num_classes, pretrained=pretrained)
        elif 'DRN_' in model_name:
            net = DRNSeg(model_name=model_name, classes=num_classes, pretrained=pretrained)
        elif model_name == 'FRRN_A':  # FRRN
            net = frrn(num_classes=num_classes, model_type='A')
            if pretrained:
                print("FRRN_A does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'FRRN_B':  # FRRN
            net = frrn(num_classes=num_classes, model_type='B')
            if pretrained:
                print("FRRN_B does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'FusionNet':  # FusionNet
            net = FusionNet(num_classes=num_classes)
            if pretrained:
                print("FusionNet does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'GCN':  # GCN Resnet
            net = GCN(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'GCN_VisDa':  # Another GCN Implementation
            net = GCN_VisDa(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'GCN_Densenet':  # Another GCN Implementation
            net = GCN_DENSENET(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'GCN_PSP':  # Another GCN Implementation
            net = GCN_PSP(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'GCN_NASNetA':  # Another GCN Implementation
            net = GCN_NASNET(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'GCN_Resnext':  # Another GCN Implementation
            net = GCN_RESNEXT(num_classes=num_classes, input_size=input_size, pretrained=pretrained)
        elif model_name == 'Linknet':  # Linknet34
            net = LinkNet34(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'PSPNet':
            net = PSPNet(num_classes=num_classes, pretrained=pretrained, backend='resnet101')
        elif model_name == 'RefineNet4Cascade':
            net = RefineNet4Cascade((1, input_size), num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'RefineNet4CascadePoolingImproved':
            net = RefineNet4Cascade((1, input_size), num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Resnet_DUC':
            net = ResNetDUC(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Resnet_DUC_HDC':
            net = ResNetDUCHDC(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Resnet_GCN':  # GCN Resnet 2
            net = ResnetGCN(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'Segnet':  # standard segnet
            net = SegNet(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_BiSeNet_Res18':
            net = TEST_BiSeNet_Res18(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_DANet_Res50':
            net = TEST_DANet_Res50(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_DANet_Res101':
            net = TEST_DANet_Res101(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_DANet_Res152':
            net = TEST_DANet_Res152(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_DenseASPP_121':
            net = TEST_DenseASPP_121(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_DenseASPP_161':
            net = TEST_DenseASPP_161(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_DenseASPP_169':
            net = TEST_DenseASPP_169(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_DenseASPP_201':
            net = TEST_DenseASPP_201(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_DiLinknet':
            net = TEST_DiLinknet(num_classes=num_classes, pretrained=False)
        elif model_name == 'TEST_DLR_Resnet':
            net = create_DLR_V3_pretrained(num_classes=num_classes)
        elif model_name == 'TEST_DLX_Resnet':
            net = create_DLX_V3_pretrained(num_classes=num_classes)
        elif model_name == 'TEST_EncNet_Res101':
            net = TEST_EncNet_Res101(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_EncNet_Res152':
            net = TEST_EncNet_Res152(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_PSPNet2':
            net = TEST_PSPNet2(num_classes=num_classes)
        elif model_name == 'TEST_DLV2':
            net = TEST_DLV2(n_classes=num_classes, n_blocks=[3, 4, 23, 3], pyramids=[6, 12, 18, 24])
        elif model_name == 'TEST_DLV3_Xception':
            net = TEST_DLV3_Xception(n_classes=num_classes, os=8, pretrained=True, _print=False)
        elif model_name == 'TEST_DLV3':
            net = TEST_DLV3(n_classes=num_classes, n_blocks=[3, 4, 23, 3], pyramids=[12, 24, 36],
                            grids=[1, 2, 4], output_stride=8)
        elif model_name == 'TEST_FCDensenet':
            net = TEST_FCDensenet(out_channels=num_classes)
        elif model_name == 'TEST_LinkCeption':
            net = TEST_LinkCeption(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_LinkDensenet121':
            net = TEST_LinkDenseNet121(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_LinkDensenet161':
            net = TEST_LinkDenseNet161(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_Linknet50':
            net = TEST_Linknet50(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_Linknet101':
            net = TEST_Linknet101(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_Linknet152':
            net = TEST_Linknet152(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_LinkNext_Mnas':
            net = TEST_LinkNext_Mnas(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_Linknext':
            net = TEST_Linknext(num_classes=num_classes)
        elif model_name == 'TEST_OCNet_Base_Res101':
            net = TEST_OCNet_Base_Res101(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_OCNet_ASP_Res101':
            net = TEST_OCNet_ASP_Res101(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_OCNet_Pyr_Res101':
            net = TEST_OCNet_Pyr_Res101(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_OCNet_Base_Res152':
            net = TEST_OCNet_Base_Res152(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_OCNet_ASP_Res152':
            net = TEST_OCNet_ASP_Res152(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_OCNet_Pyr_Res152':
            net = TEST_OCNet_Pyr_Res152(num_classes=num_classes, pretrained=pretrained)
        elif model_name == 'TEST_Tiramisu57':
            net = TEST_Tiramisu57(num_classes=num_classes)
        elif model_name == 'TEST_Unet_nested_dilated':
            net = TEST_Unet_nested_dilated(n_classes=num_classes)
        elif model_name == 'TEST_Unet_plus_plus':
            net = Unet_Plus_Plus(in_channels=3, n_classes=num_classes)
        elif model_name == 'Tiramisu57':  # Tiramisu
            net = FCDenseNet57(n_classes=num_classes)
            if pretrained:
                print("Tiramisu57 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'Tiramisu67':  # Tiramisu
            net = FCDenseNet67(n_classes=num_classes)
            if pretrained:
                print("Tiramisu67 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'Tiramisu103':  # Tiramisu
            net = FCDenseNet103(n_classes=num_classes)
            if pretrained:
                print("Tiramisu103 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'Unet':  # standard unet
            net = UNet(num_classes=num_classes)
            if pretrained:
                print("UNet does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet256':  # Unet for 256px square imgs
            net = UNet256(in_shape=(3, 256, 256))
            if pretrained:
                print("UNet256 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet512':  # Unet for 512px square imgs
            net = UNet512(in_shape=(3, 512, 512))
            if pretrained:
                print("UNet512 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet1024':  # Unet for 1024px square imgs
            net = UNet1024(in_shape=(3, 1024, 1024))
            if pretrained:
                print("UNet1024 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet960':  # Another Unet specifically with 960px resolution
            net = UNet960(filters=12)
            if pretrained:
                print("UNet960 does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'unet_dilated':  # dilated unet
            net = uNetDilated(num_classes=num_classes)
        elif model_name == 'Unet_res':  # residual unet
            net = UNetRes(num_class=num_classes)
            if pretrained:
                print("UNet_res does not have a pretrained model! Empty model has been created instead.")
        elif model_name == 'UNet_stack':  # Stacked Unet variation with resnet connections
            net = UNet_stack(input_size=(input_size, input_size), filters=12)
            if pretrained:
                print("UNet_stack does not have a pretrained model! Empty model has been created instead.")
        else:
            raise Exception('Combination of type: {} and model_name: {} is not valid'.format(model_type, model_name))
        return net
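
# Standalone sketch of the reflection-based head replacement from steps 3-4
# above: look the final layer up by name, rebuild it for the new class count,
# and set it back on the model (the 21-class count is an arbitrary example).
import torch.nn as nn
import torchvision.models as torch_models

m = torch_models.resnet18(pretrained=False)
fc_name = 'fc'                                           # last-layer attribute name for torchvision resnets
old_fc = getattr(m, fc_name)
setattr(m, fc_name, nn.Linear(old_fc.in_features, 21))   # e.g. 21 classes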
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models.inception import InceptionAux

model = models.inception_v3()
model.fc = nn.Linear(2048, 26)
model.AuxLogits = InceptionAux(768, 26)  # replace the aux head before loading the 26-class checkpoint
model.load_state_dict(
    torch.load('Best_inception_v3_fl_enhanced.pth.tar')['state_dict'])


class MyInception(nn.Module):
    def __init__(self, num_classes, aux_logits=True):
        super(MyInception, self).__init__()
        self.aux_logits = aux_logits
        self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3
        self.Conv2d_2a_3x3 = model.Conv2d_2a_3x3
        self.Conv2d_2b_3x3 = model.Conv2d_2b_3x3
        self.Conv2d_3b_1x1 = model.Conv2d_3b_1x1
        self.Conv2d_4a_3x3 = model.Conv2d_4a_3x3
        self.Mixed_5b = model.Mixed_5b
        self.Mixed_5c = model.Mixed_5c
        self.Mixed_5d = model.Mixed_5d
        self.Mixed_6a = model.Mixed_6a
        self.Mixed_6b = model.Mixed_6b
        self.Mixed_6c = model.Mixed_6c
        self.Mixed_6d = model.Mixed_6d
        self.Mixed_6e = model.Mixed_6e
        if aux_logits:
            self.AuxLogits = model.AuxLogits  # reuse the fine-tuned auxiliary head