Exemple #1
0
    def __init__(self, model, n_classes):
        """Re-create an FCN around ``model.backbone`` with new heads.

        Both the main and the auxiliary classifier are replaced by freshly
        constructed ``FCNHead`` instances that predict ``n_classes`` channels.
        Their input widths are read from the first layer of the matching head
        on ``model``.
        """
        # Kept as a function-local import, as in the original code.
        from torchvision.models.segmentation.fcn import FCNHead

        main_width = model.classifier[0].in_channels
        aux_width = model.aux_classifier[0].in_channels

        # Main head first, auxiliary head second.
        heads = (FCNHead(main_width, n_classes), FCNHead(aux_width, n_classes))

        # super(FCN, self) starts the method lookup *after* torchvision's FCN
        # in the MRO, so the constructor of FCN's parent class is the one that
        # receives (backbone, classifier, aux_classifier).
        super(torchvision.models.segmentation.fcn.FCN,
              self).__init__(model.backbone, *heads)
    def __init__(self, num_convs=3, fine_tune=False):
        """Build the ResNet-50 feature extractor, graph stages and MLP head.

        Args:
            num_convs: Number of GCNConv + TopKPooling stages to create. The
                check below accepts ``1 < num_convs <= 8``.
            fine_tune: When False, every backbone parameter is frozen
                (``requires_grad = False``).

        Raises:
            AssertionError: If ``num_convs`` is outside the accepted range.
        """
        super(MyModel, self).__init__()
        # NOTE(review): the condition rejects num_convs == 1 while the message
        # speaks of "less than 1" -- confirm which of the two is intended.
        assert 8 >= num_convs > 1, "Cannot have less than 1 or greater than 8 convolutional+pooling layers."
        self.num_convs = num_convs
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        # Simple Segmentation Model: dilated ResNet-50 exposing layer4 as 'out'.
        self.backbone = models.resnet50(
            pretrained=True, replace_stride_with_dilation=[False, True, True])
        self.backbone = IntermediateLayerGetter(
            self.backbone, return_layers={'layer4': 'out'})
        if not fine_tune:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # NOTE(review): 1028 looks like a typo (the layers below all use
        # 2048-wide features) -- left unchanged, confirm before relying on
        # self.classifier.
        in_channels = 1028
        classes = 200  # actually classes
        self.classifier = FCNHead(in_channels, classes)

        # Graph convolution layers and graph pooling layers.
        # They are kept in an nn.ModuleDict rather than a plain dict: modules
        # stored in a plain dict are not registered as sub-modules, so their
        # weights would be missing from parameters(), state_dict() and
        # module-level .to()/.cuda() calls. Key-based access is unchanged.
        self.convolutions = torch.nn.ModuleDict()
        for conv in range(num_convs):
            stage = str(conv + 1)
            self.convolutions['conv' + stage] = GCNConv(2048, 2048).to(
                self.device)
            self.convolutions['pool' + stage] = TopKPooling(
                2048, ratio=.6).to(self.device)

        # Final output: three linear layers and two ReLU activations.
        self.lin1 = torch.nn.Linear(2048, 1024).to(self.device)
        self.lin2 = torch.nn.Linear(1024, 1024).to(self.device)
        self.lin3 = torch.nn.Linear(1024, classes).to(self.device)
        self.act1 = torch.nn.ReLU().to(self.device)
        self.act2 = torch.nn.ReLU().to(self.device)
def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True, **kwargs):
    """Assemble a segmentation model on top of a dilated ResNet backbone.

    Args:
        name: Key into the head/model table, ``'deeplabv3'`` or ``'fcn'``.
        backbone_name: Name of the constructor looked up in ``resnet.__dict__``.
        num_classes: Number of output channels of every head.
        aux: When truthy, ``layer3`` is also returned (as ``'aux'``) and an
            auxiliary ``FCNHead`` is attached.
        pretrained_backbone: Forwarded to the backbone constructor as
            ``pretrained``.
        **kwargs: Only ``resnet_local_path`` is read; it is forwarded to the
            backbone constructor (``None`` when absent).

    Returns:
        The assembled model.
    """
    make_backbone = resnet.__dict__[backbone_name]
    trunk = make_backbone(
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True],
        resnet_local_path=kwargs.get("resnet_local_path"),
    )

    layers_to_return = {'layer4': 'out'}
    if aux:
        layers_to_return['layer3'] = 'aux'
    trunk = IntermediateLayerGetter(trunk, return_layers=layers_to_return)

    # The auxiliary head is built before the main head.
    aux_classifier = FCNHead(1024, num_classes) if aux else None

    model_map = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    head_cls = model_map[name][0]
    classifier = head_cls(2048, num_classes)
    model_cls = model_map[name][1]

    return model_cls(trunk, classifier, aux_classifier)
Exemple #4
0
def _segm_resnet(name,
                 backbone_name,
                 num_classes,
                 aux,
                 pretrained_backbone=True):
    """Assemble an FCN on a dilated ResNet whose stem conv is re-created.

    Args:
        name: Key into the head/model table; only ``'fcn'`` is registered.
        backbone_name: Name of the constructor looked up in ``resnet.__dict__``.
        num_classes: Number of output channels of every head.
        aux: When truthy, ``layer3`` is also returned (as ``'aux'``) and an
            auxiliary ``FCNHead`` is attached.
        pretrained_backbone: Forwarded to the backbone constructor as
            ``pretrained``.

    Returns:
        The assembled model.
    """
    make_backbone = resnet.__dict__[backbone_name]
    trunk = make_backbone(pretrained=pretrained_backbone,
                          replace_stride_with_dilation=[False, True, True])

    # conv1 is replaced by a newly constructed 3->64 7x7 convolution.
    stem_options = dict(kernel_size=(7, 7),
                        stride=(2, 2),
                        padding=(3, 3),
                        bias=False)
    trunk.conv1 = nn.Conv2d(3, 64, **stem_options)

    layers_to_return = {'layer4': 'out'}
    if aux:
        layers_to_return['layer3'] = 'aux'
    trunk = IntermediateLayerGetter(trunk, return_layers=layers_to_return)

    # The auxiliary head is built before the main head.
    aux_classifier = FCNHead(1024, num_classes) if aux else None

    model_map = {
        'fcn': (FCNHead, FCN),
    }
    head_cls = model_map[name][0]
    classifier = head_cls(2048, num_classes)
    model_cls = model_map[name][1]

    return model_cls(trunk, classifier, aux_classifier)
Exemple #5
0
def _segm_resnet(name, backbone_name, num_classes, aux, **kwargs):
    """Assemble a segmentation model and wrap it in ``ColourTransferModel``.

    Args:
        name: Key into the head/model table, ``'deeplabv3'`` or ``'fcn'``.
        backbone_name: A dict with ``'arch'`` and ``'customs'`` entries, a
            path to an existing file, or an architecture name; each case is
            resolved through a different ``model_utils`` call.
        num_classes: Number of output channels of every head.
        aux: When truthy, ``layer3`` is also returned (as ``'aux'``) and an
            auxiliary ``FCNHead`` is attached.
        **kwargs: Forwarded to ``model_utils`` (file and name cases) and to
            ``ColourTransferModel``.

    Returns:
        The ``ColourTransferModel`` wrapping the assembled model.
    """

    def _stage_width(stage):
        # num_features of the second-to-last child of the stage's last block.
        last_block = list(stage)[-1]
        return list(last_block.children())[-2].num_features

    # FIXME: 1000 and _
    if isinstance(backbone_name, dict):
        trunk = model_utils.which_architecture(backbone_name['arch'],
                                               backbone_name['customs'])
    elif os.path.isfile(backbone_name):
        trunk, _ = model_utils.which_network_classification(
            backbone_name, 1000, **kwargs)
    else:
        trunk = model_utils.which_architecture(backbone_name, **kwargs)

    layers_to_return = {'layer4': 'out'}
    if aux:
        layers_to_return['layer3'] = 'aux'
    trunk = IntermediateLayerGetter(trunk, return_layers=layers_to_return)

    # The auxiliary head is built before the main head.
    aux_classifier = None
    if aux:
        aux_classifier = FCNHead(_stage_width(trunk.layer3), num_classes)

    model_map = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    main_width = _stage_width(trunk.layer4)
    classifier = model_map[name][0](main_width, num_classes)
    model_cls = model_map[name][1]

    segmenter = model_cls(trunk, classifier, aux_classifier)
    return ColourTransferModel(segmenter, **kwargs)
Exemple #6
0
    def __init__(self, num_classes, backbone_fn, chip_size=224):
        """Build a dilated backbone plus a DeepLab head and an auxiliary FCN head.

        Args:
            num_classes: Number of output channels for both heads.
            backbone_fn: Backbone constructor handed to ``create_body``; its
                ``__name__`` selects the backbone-specific handling below.
            chip_size: Side length of the square input size passed to
                ``model_sizes`` when probing the backbone.
        """
        super().__init__()
        # Backbones flagged as multispectral are cut at the index recorded in
        # their backbone metadata; all others are built without an explicit cut.
        if getattr(backbone_fn, '_is_multispectral', False):
            self.backbone = create_body(backbone_fn,
                                        pretrained=True,
                                        cut=_get_backbone_meta(
                                            backbone_fn.__name__)['cut'])
        else:
            self.backbone = create_body(backbone_fn, pretrained=True)

        backbone_name = backbone_fn.__name__

        ## Support for different backbones: for densenet/vgg the modules of
        ## interest are taken from the first child of the body, otherwise the
        ## body's direct children are used.
        if "densenet" in backbone_name or "vgg" in backbone_name:
            hookable_modules = list(self.backbone.children())[0]
        else:
            hookable_modules = list(self.backbone.children())

        # Index (counted from the end) of the first module to hook and dilate.
        if "vgg" in backbone_name:
            modify_dilation_index = -5
        else:
            modify_dilation_index = -2

        # Substring used to pick which sub-modules get their dilation changed.
        # Note that 'conv' also matches names such as 'conv1' and 'conv2'.
        if backbone_name == 'resnet18' or backbone_name == 'resnet34':
            module_to_check = 'conv'
        else:
            module_to_check = 'conv2'

        ## Hook at the index where we need to get the auxiliary logits out
        self.hook = hook_output(hookable_modules[modify_dilation_index])

        # Counts the Conv2d modules seen so far on the vgg path.
        custom_idx = 0
        for i, module in enumerate(hookable_modules[modify_dilation_index:]):
            # Dilation and padding grow with the module's position: 2, 4, ...
            dilation = 2 * (i + 1)
            padding = 2 * (i + 1)
            for n, m in module.named_modules():
                if module_to_check in n:
                    # Dilate in place and remove the stride.
                    m.dilation, m.padding, m.stride = (dilation, dilation), (
                        padding, padding), (1, 1)
                elif 'downsample.0' in n:
                    # Remove the stride of the matching 'downsample.0' module.
                    m.stride = (1, 1)

            # On vgg, modules that are themselves Conv2d are dilated as well,
            # using custom_idx instead of the enumeration index.
            if "vgg" in backbone_fn.__name__:
                if isinstance(module, nn.Conv2d):
                    dilation = 2 * (custom_idx + 1)
                    padding = 2 * (custom_idx + 1)
                    module.dilation, module.padding, module.stride = (
                        dilation, dilation), (padding, padding), (1, 1)
                    custom_idx += 1

        ## returns the size of various activations
        ## NOTE(review): presumably this call is also what fills
        ## self.hook.stored, which is read on the next statement -- confirm.
        feature_sizes = model_sizes(self.backbone, size=(chip_size, chip_size))
        ## Getting the number of channels present in the activation stored inside the hook
        num_channels_aux_classifier = self.hook.stored.shape[1]
        ## Get number of channels in the last layer
        num_channels_classifier = feature_sizes[-1][1]

        self.classifier = DeepLabHead(num_channels_classifier, num_classes)
        self.aux_classifier = FCNHead(num_channels_aux_classifier, num_classes)
Exemple #7
0
def _create_deeplab(num_class, backbone, pretrained=True, **kwargs):
    '''
    Create a torchvision DeepLabV3-ResNet101 with heads sized for ``num_class``.

    Args:
        num_class: Number of output channels of the main and auxiliary heads.
        backbone: Currently unused; kept for interface compatibility.
        pretrained: Forwarded to ``deeplabv3_resnet101``. Previously this
            argument was ignored and ``True`` was always used.
        **kwargs: Extra keyword arguments for ``deeplabv3_resnet101``.

    Returns:
        A ``_DeepLabOverride`` built from the torchvision model's backbone,
        with a new ``DeepLabHead(2048, num_class)`` and
        ``FCNHead(1024, num_class)``.
    '''
    # An auxiliary head is always attached below, so the backbone must also
    # return the 'aux' feature map when pretrained weights are not requested.
    kwargs.setdefault('aux_loss', True)
    model = models.segmentation.deeplabv3_resnet101(pretrained=pretrained, progress=True, **kwargs)
    model = _DeepLabOverride(model.backbone, model.classifier, model.aux_classifier)
    model.classifier = DeepLabHead(2048, num_class)
    model.aux_classifier = FCNHead(1024, num_class)

    return model
Exemple #8
0
    def __init__(self, num_classes=23):
        """Wrap a pretrained FCN-ResNet101 whose head predicts ``num_classes``.

        The stock classifier is replaced by a new ``FCNHead(2048, num_classes)``
        and the auxiliary classifier is removed. Afterwards only parameters
        whose name contains ``"classifier"`` are left trainable.
        """
        super(SegmentationNN, self).__init__()

        # An earlier, commented-out VGG16-BN based design with hand-written
        # upsampling layers used to live here; it was dead code.
        segmenter = models.segmentation.fcn_resnet101(pretrained=True)
        segmenter.classifier = FCNHead(2048, num_classes)
        segmenter.aux_classifier = None
        self.fcn = segmenter

        # Freeze everything except the (new) classifier head.
        for param_name, weight in self.fcn.named_parameters():
            weight.requires_grad = "classifier" in param_name
Exemple #9
0
 def __init__(self, in_chan=3, out_chan=2, pretrained=False):
     """Create an FCN-ResNet50 with a custom head and optional input stem.

     Args:
         in_chan: Number of input channels; when it differs from 3 the
             backbone's first convolution is replaced accordingly.
         out_chan: Number of output channels of the new ``FCNHead``.
         pretrained: Forwarded as ``pretrained_backbone``; the segmentation
             model itself is always created with ``pretrained=False``.
     """
     super(FCN50, self).__init__()
     segmenter = torchvision.models.segmentation.fcn_resnet50(
         pretrained=False, pretrained_backbone=pretrained)
     segmenter.classifier = FCNHead(2048, out_chan)
     self.model = segmenter

     if in_chan == 3:
         # The stock 3-channel stem is kept as is.
         return

     stem_options = dict(kernel_size=7, stride=2, padding=3, bias=False)
     self.model.backbone.conv1 = nn.Conv2d(in_chan, 64, **stem_options)
def _segm_resnet(name, backbone_name, num_classes, aux, **kwargs):
    """Assemble a segmentation model, sizing the heads from the backbone.

    Args:
        name: Key into the head/model table, ``'deeplabv3'`` or ``'fcn'``.
        backbone_name: A dict with ``'arch'`` and ``'customs'`` entries, a
            path to an existing file, or an architecture name; each case is
            resolved through a different ``model_utils`` call.
        num_classes: Number of output channels of every head.
        aux: When truthy, ``layer3`` is also returned (as ``'aux'``) and an
            auxiliary ``FCNHead`` is attached.
        **kwargs: Forwarded to ``model_utils``; a ``'pretrained'`` entry is
            dropped first in the dict and file cases.

    Returns:
        The assembled model.
    """

    def _stage_width(stage):
        # Depending on Bottleneck or basic block, num_features is read from
        # the second-to-last or the last child of the stage's final block.
        last_block = list(stage)[-1]
        children = list(last_block.children())
        if isinstance(last_block, (cresnet.Bottleneck, presnet.Bottleneck)):
            return children[-2].num_features
        return children[-1].num_features

    # FIXME: 1000 and _
    if isinstance(backbone_name, dict):
        kwargs.pop('pretrained', None)
        trunk = model_utils.which_architecture(
            backbone_name['arch'], backbone_name['customs']
        )
    elif os.path.isfile(backbone_name):
        kwargs.pop('pretrained', None)
        trunk, _ = model_utils.which_network_classification(
            backbone_name, 1000, **kwargs
        )
    else:
        trunk = model_utils.which_architecture(backbone_name, **kwargs)

    layers_to_return = {'layer4': 'out'}
    if aux:
        layers_to_return['layer3'] = 'aux'
    trunk = IntermediateLayerGetter(trunk, return_layers=layers_to_return)

    # The auxiliary head is built before the main head.
    aux_classifier = None
    if aux:
        aux_classifier = FCNHead(_stage_width(trunk.layer3), num_classes)

    model_map = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    main_width = _stage_width(trunk.layer4)
    classifier = model_map[name][0](main_width, num_classes)
    model_cls = model_map[name][1]

    return model_cls(trunk, classifier, aux_classifier)
def pre_fcn_resnet101(in_channel, out_channel):
    """Return an FCN-ResNet101 seeded from the COCO checkpoint.

    Weights from the downloaded checkpoint are copied for every key that the
    freshly built model also has. Afterwards the first backbone convolution
    and the classifier head are replaced by new, untrained layers.

    Args:
        in_channel: Input channels of the new first convolution.
        out_channel: Output channels of the new ``FCNHead``.
    """
    url = "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth"  # COCO
    model = fcn_resnet101(pretrained=False, progress=False)

    merged_state = model.state_dict()
    downloaded = model_zoo.load_url(url, progress=False)
    # Overwrite only the entries the model already knows about.
    for key, tensor in downloaded.items():
        if key in merged_state:
            merged_state[key] = tensor
    model.load_state_dict(merged_state)

    stem_options = dict(kernel_size=7, stride=2, padding=3, bias=False)
    model.backbone.conv1 = nn.Conv2d(in_channel, 64, **stem_options)
    model.classifier = FCNHead(2048, out_channel)
    return model
Exemple #12
0
def deeplabv3(pretrained=False,
              resnet="res103",
              head_in_ch=2048,
              num_classes=21):
    """Build a ``SmallDeepLab`` on a dilated ResNet backbone.

    Args:
        pretrained: When truthy, the torchvision DeepLabV3-ResNet101 COCO
            checkpoint is downloaded and loaded into the assembled network.
        resnet: Backbone selector, ``"res101"`` or ``"res103"``.
        head_in_ch: Input channels of the DeepLab head; the auxiliary FCN head
            receives half of it.
        num_classes: Number of output channels of both heads.

    Returns:
        The assembled network.

    Raises:
        KeyError: If ``resnet`` is not one of the known selectors.
    """
    # Resolve the requested constructor. Previously the result of this lookup
    # was discarded and resnet103 was always instantiated.
    backbone_fn = {"res101": resnet101, "res103": resnet103}[resnet]

    trunk = IntermediateLayerGetter(
        backbone_fn(pretrained=True,
                    replace_stride_with_dilation=[False, True, True]),
        return_layers={
            'layer2': 'res2',
            'layer4': 'out'
        })
    net = SmallDeepLab(backbone=trunk,
                       classifier=DeepLabHead(head_in_ch, num_classes),
                       aux_classifier=FCNHead(head_in_ch // 2, num_classes))
    if pretrained:
        state_dict = load_state_dict_from_url(
            'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
            progress=True)
        net.load_state_dict(state_dict)
    return net
Exemple #13
0
def createFCNResnet101(outputchannels=1, feature_extract=True):
    """Return a pretrained FCN-ResNet101 with a new classifier head.

    Args:
        outputchannels (int, optional): Number of output channels of the new
            head, i.e. of the dataset masks. Defaults to 1.
        feature_extract (bool, optional): If True, every parameter of the
            loaded model is frozen before the new head is attached, so only
            the head is trained. If False the whole model is trained.

    Returns:
        model: The FCN model with the ResNet101 backbone, in training mode.
    """
    model = models.segmentation.fcn_resnet101(pretrained=True, progress=True)
    print(model)

    if feature_extract:
        # Freeze all loaded weights; the head created below stays trainable.
        for weight in model.parameters():
            weight.requires_grad_(False)

    model.classifier = FCNHead(2048, outputchannels)
    # Switch to training mode before handing the model back.
    model.train()
    return model
Exemple #14
0
def build_fcn_resnet(name,
                     backbone_fct,
                     num_classes,
                     aux,
                     layers,
                     pretrained_backbone=False):
    """Construct a custom ResNet backbone and wrap it for semantic segmentation.

    Args:
        name (str): either deeplabv3 or fcn
        backbone_fct (function): constructor of the (non fully-convolutional)
            backbone; called with ``layers``, ``pretrained`` and
            ``replace_stride_with_dilation``
        num_classes (int): number of classes
        aux (bool): use of auxiliary loss; also adds ``layer3`` as ``'aux'``
        layers (list): configuration of layer-blocks (Bottlenecks)
        pretrained_backbone (bool): forwarded to ``backbone_fct`` as
            ``pretrained``

    Returns:
        The assembled model.
    """
    trunk = backbone_fct(layers=layers,
                         pretrained=pretrained_backbone,
                         replace_stride_with_dilation=[False, True, True])

    layers_to_return = {'layer4': 'out'}
    if aux:
        layers_to_return['layer3'] = 'aux'
    trunk = IntermediateLayerGetter(trunk, return_layers=layers_to_return)

    # The auxiliary head is built before the main head.
    aux_classifier = FCNHead(1024, num_classes) if aux else None

    model_map = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    head_cls = model_map[name][0]
    classifier = head_cls(2048, num_classes)
    model_cls = model_map[name][1]

    return model_cls(trunk, classifier, aux_classifier)
Exemple #15
0
def _segm_resnet(name,
                 backbone_name,
                 num_classes,
                 aux,
                 dilation,
                 pretrained_backbone=True):
    """Assemble a segmentation model with a configurable dilation pattern.

    Args:
        name: ``'fcn'`` or ``'deeplabv3'``.
        backbone_name: Name of the constructor looked up in ``resnet.__dict__``.
        num_classes: Number of output channels of every head.
        aux: When truthy, ``layer3`` is also returned (as ``'aux'``) and an
            auxiliary ``FCNHead`` is attached.
        dilation: Three flags forwarded as ``replace_stride_with_dilation``.
            For non-FCN heads they also pick the rates passed to the head:
            ``[12, 24, 36]`` for (off, on, on), ``[6, 12, 18]`` for
            (off, off, on).
        pretrained_backbone: Forwarded to the backbone constructor as
            ``pretrained``.

    Returns:
        The assembled model, or ``None`` (after printing a message) when a
        non-FCN head is requested with an unsupported dilation pattern.
    """
    make_backbone = resnet.__dict__[backbone_name]
    trunk = make_backbone(pretrained=pretrained_backbone,
                          replace_stride_with_dilation=dilation)

    layers_to_return = {'layer4': 'out'}
    if aux:
        layers_to_return['layer3'] = 'aux'
    trunk = IntermediateLayerGetter(trunk, return_layers=layers_to_return)

    # The auxiliary head is built before the main head.
    aux_classifier = FCNHead(1024, num_classes) if aux else None

    model_map = {
        'deeplabv3': (CustomDeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    inplanes = 2048

    # Work out the head arguments first; the table lookup happens afterwards
    # so an unsupported dilation is reported before any lookup by name.
    if name == 'fcn':
        head_args = (inplanes, num_classes)
    elif (not dilation[0]) and dilation[2]:
        rates = [12, 24, 36] if dilation[1] else [6, 12, 18]
        head_args = (inplanes, num_classes, rates)
    else:
        print('invalid dilation value')
        return None

    head_cls, model_cls = model_map[name]
    classifier = head_cls(*head_args)
    return model_cls(trunk, classifier, aux_classifier)
Exemple #16
0
def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True, replace_stride_with_dilation=[False, True, True], rates=[12, 24, 36], return_layers = {'layer4': 'out'}):
    """Assemble a segmentation model with configurable dilation and rates.

    Args:
        name: Key into the head/model table, ``'deeplabv3'`` or ``'fcn'``.
        backbone_name: Name of the constructor looked up in ``resnet.__dict__``.
        num_classes: Number of output channels of every head.
        aux: When truthy, an auxiliary ``FCNHead(1024, num_classes)`` is
            attached. NOTE(review): the default ``return_layers`` has no
            ``'aux'`` entry, so callers enabling ``aux`` presumably need to
            pass one -- confirm.
        pretrained_backbone: Forwarded to the backbone constructor as
            ``pretrained``.
        replace_stride_with_dilation: Forwarded to the backbone constructor.
        rates: Third positional argument of the DeepLab head; ignored for
            ``'fcn'``.
        return_layers: Mapping handed to ``IntermediateLayerGetter``.

    Returns:
        The assembled model.

    The list/dict defaults are shared between calls; this function never
    mutates them itself.
    """
    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=replace_stride_with_dilation)

    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None
    if aux:
        inplanes = 1024
        aux_classifier = FCNHead(inplanes, num_classes)

    model_map = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    inplanes = 2048
    head_cls, base_model = model_map[name]
    if name == 'fcn':
        # FCNHead takes only (in_channels, channels), as for the auxiliary
        # head above; passing ``rates`` as well raised a TypeError.
        classifier = head_cls(inplanes, num_classes)
    else:
        classifier = head_cls(inplanes, num_classes, rates)

    model = base_model(backbone, classifier, aux_classifier)

    return model
Exemple #17
0
    def __init__(self, num_classes: int = 2, ignore_index: Optional[int] = None, lr: float = 0.001,
                 weight_decay: float = 0.001, aux_loss_factor: float = 0.3):
        """Set up a DeepLabv3-ResNet101 with new heads, its loss and metrics.

        Args:
            num_classes: Number of output channels of both heads and of the
                class-aware metrics.
            ignore_index: Passed to the loss and to every metric.
            lr: Stored on the instance as ``lr``.
            weight_decay: Stored on the instance as ``weight_decay``.
            aux_loss_factor: Stored on the instance as ``aux_loss_factor``.
        """
        super().__init__()

        # Hyper-parameters, kept on the instance.
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.lr = lr
        self.weight_decay = weight_decay
        self.aux_loss_factor = aux_loss_factor

        # Pre-trained DeepLabv3 with freshly initialised heads
        # (auxiliary head first, then the main head).
        segmenter = deeplabv3_resnet101(pretrained=True, progress=True)
        segmenter.aux_classifier = FCNHead(1024, self.num_classes)
        segmenter.classifier = DeepLabHead(2048, self.num_classes)
        self.model = segmenter

        # Everything is trainable except the backbone.
        self.model.requires_grad_(True)
        self.model.backbone.requires_grad_(False)

        # Loss function
        self.focal_tversky_loss = FocalTverskyMetric(
            self.num_classes,
            alpha=0.7,
            beta=0.3,
            gamma=4.0 / 3.0,
            ignore_index=self.ignore_index,
        )

        # Metrics
        self.accuracy_metric = Accuracy(ignore_index=self.ignore_index)
        self.iou_metric = JaccardIndex(
            num_classes=self.num_classes,
            reduction="none",
            ignore_index=self.ignore_index,
        )
        # Precision and recall share the same configuration.
        weighted_options = dict(
            num_classes=self.num_classes,
            ignore_index=self.ignore_index,
            average='weighted',
            mdmc_average='global',
        )
        self.precision_metric = Precision(**weighted_options)
        self.recall_metric = Recall(**weighted_options)
# Architectures built from stock torchvision segmentation models:
#   arch name -> (torchvision factory name,
#                 in_channels of the final 1x1 conv of the main head)
_TORCHVISION_ARCHS = {
    "deeplabv3_resnet101_orig": ("deeplabv3_resnet101", 256),
    "fcn_resnet101_orig": ("fcn_resnet101", 512),
}

# Architectures assembled around the project's CNN backbone:
#   arch name -> (CNN backbone name, head family,
#                 head in_channels with lower_features,
#                 head in_channels otherwise,
#                 whether parameter counts are printed)
_CUSTOM_ARCHS = {
    "deeplabv3_resnet101": ("resnet101", "deeplabv3", 256, 2048, False),
    "deeplabv3_vgg16": ("vgg16", "deeplabv3", 64, 512, True),
    "deeplabv3_mobilenet": ("mobilenet", "deeplabv3", 32, 1280, False),
    "deeplabv3_squeezenet": ("squeezenet", "deeplabv3", 128, 512, False),
    "fcn_vgg16": ("vgg16", "fcn", 64, 512, False),
    "fcn_resnet101": ("resnet101", "fcn", 256, 2048, True),
    "fcn_squeezenet": ("squeezenet", "fcn", 128, 512, True),
    "fcn_mobilenet": ("mobilenet", "fcn", 32, 1280, True),
}


def _load_checkpoint(model, pre_trained_path):
    """Load the ``'net'`` state dict stored at ``pre_trained_path`` into ``model``."""
    print("load pre-trained-weights ... ")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    model_dict_state = torch.load(pre_trained_path)
    model.load_state_dict(model_dict_state['net'])


def _print_param_counts(model):
    """Print the total and the trainable number of parameters of ``model``."""
    total_params = sum(p.numel() for p in model.parameters())
    print("total_params:" + str(total_params))
    total_trainable_params = sum(p.numel() for p in model.parameters()
                                 if p.requires_grad)
    print("total_trainable_params: " + str(total_trainable_params))


def loadModel(model_arch="",
              classes=None,
              pre_trained_path=None,
              expType=None,
              trainable_backbone_flag=False,
              lower_features=False):
    """Build the segmentation model selected by ``model_arch``.

    Args:
        model_arch: One of the keys of ``_TORCHVISION_ARCHS`` (pretrained
            torchvision models whose last head layers are replaced) or of
            ``_CUSTOM_ARCHS`` (project ``CNN`` backbone plus a DeepLab or FCN
            head, without auxiliary classifier).
        classes: Sized collection of class labels; only ``len(classes)`` is
            used, as the number of output channels.
        pre_trained_path: Optional checkpoint path; when given, its ``'net'``
            entry is loaded into the model.
        expType: Unused; kept for interface compatibility.
        trainable_backbone_flag: For the torchvision variants, the
            ``requires_grad`` value applied to every loaded parameter. For the
            ``CNN``-based variants it is passed as ``pretrained``.
            NOTE(review): confirm that wiring the trainable flag into
            ``pretrained`` is intended.
        lower_features: Passed to ``CNN``; when equal to ``True`` the smaller
            head input width is used.

    Returns:
        Tuple ``(model, features)`` where ``features`` is the backbone.

    Raises:
        SystemExit: With status 1 when ``model_arch`` is not recognised
            (previously ``exit()`` was called, which reports status 0).
    """
    print("Load model architecture ... ")
    report_params = False

    if model_arch in _TORCHVISION_ARCHS:
        factory_name, head_mid_channels = _TORCHVISION_ARCHS[model_arch]
        # The message names the architecture that was really selected; several
        # of the former copy-pasted messages named a different one.
        print(model_arch + " architecture selected ...")
        factory = getattr(models.segmentation, factory_name)
        model = factory(pretrained=True, progress=True)

        for params in model.parameters():
            params.requires_grad = trainable_backbone_flag

        # Replace the final 1x1 convolutions; new layers are trainable.
        model.classifier[-1] = torch.nn.Conv2d(head_mid_channels,
                                               len(classes),
                                               kernel_size=(1, 1))
        model.aux_classifier[-1] = torch.nn.Conv2d(256,
                                                   len(classes),
                                                   kernel_size=(1, 1))
        features = model.backbone

    elif model_arch in _CUSTOM_ARCHS:
        (backbone_name, family, low_channels, full_channels,
         report_params) = _CUSTOM_ARCHS[model_arch]
        print(model_arch + " architecture selected ...")
        backbone_net = CNN(model_arch=backbone_name,
                           n_classes=len(classes),
                           include_top=False,
                           pretrained=trainable_backbone_flag,
                           lower_features=lower_features)

        # Kept as an equality test (not plain truthiness) so that the
        # accepted values of lower_features stay exactly as before.
        head_in = low_channels if lower_features == True else full_channels
        if family == "deeplabv3":
            head_cls, model_cls = DeepLabHead, models.segmentation.DeepLabV3
        else:
            head_cls, model_cls = FCNHead, models.segmentation.FCN

        classifier = nn.Sequential(head_cls(head_in, len(classes)))
        features = backbone_net
        model = model_cls(backbone=backbone_net,
                          classifier=classifier,
                          aux_classifier=None)

    else:
        print("ERROR: select valid model architecture!")
        raise SystemExit(1)

    if pre_trained_path is not None:
        _load_checkpoint(model, pre_trained_path)

    if report_params:
        _print_param_counts(model)

    return model, features