def __init__(self, model, n_classes):
    """Rebuild ``model`` (an already-constructed torchvision FCN) with fresh
    classifier heads predicting ``n_classes`` classes.

    The input channel counts for the new heads are read from the first conv of
    the existing classifier / aux_classifier.
    """
    # Last layer (the classifier), fully convolutional
    from torchvision.models.segmentation.fcn import FCNHead
    # Modify classifier layer to predict requested number of classes
    classifier = FCNHead(model.classifier[0].in_channels, n_classes)
    # Modify aux classifier layer to predict requested number of classes
    aux_classifier = FCNHead(model.aux_classifier[0].in_channels, n_classes)
    # Call Module constructor to initialize network module.
    # NOTE(review): super(FCN, self) deliberately skips FCN.__init__ and calls
    # the constructor *above* FCN in the MRO with (backbone, classifier,
    # aux_classifier) — this assumes self's class derives from FCN; confirm.
    super(torchvision.models.segmentation.fcn.FCN, self).__init__(model.backbone, classifier, aux_classifier)
def __init__(self, num_convs=3, fine_tune=False):
    """Segmentation model: (optionally frozen) dilated ResNet-50 backbone
    followed by graph-convolution + top-k pooling stages and a linear head.

    Args:
        num_convs: number of GCNConv/TopKPooling stages, 1..8.
        fine_tune: when False, backbone parameters are frozen.
    """
    super(MyModel, self).__init__()
    # The message promises a 1..8 range, but the original condition
    # (`num_convs > 1`) rejected 1 — align the check with the message.
    assert 8 >= num_convs >= 1, "Cannot have less than 1 or greater than 8 convolutional+pooling layers."
    self.num_convs = num_convs
    self.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")
    # Simple Segmentation Model: dilated ResNet-50 truncated at layer4.
    self.backbone = models.resnet50(
        pretrained=True, replace_stride_with_dilation=[False, True, True])
    self.backbone = IntermediateLayerGetter(
        self.backbone, return_layers={'layer4': 'out'})
    if not fine_tune:
        for p in self.backbone.parameters():
            p.requires_grad = False
    # NOTE(review): layer4 of a dilated ResNet-50 outputs 2048 channels;
    # 1028 looks like a typo (1024? 2048?) — confirm against the forward pass.
    in_channels = 1028
    classes = 200  # actually classes
    self.classifier = FCNHead(in_channels, classes)
    # Graph convolution layers and graph pooling layers.
    # ModuleDict (not a plain dict) so these layers are registered as
    # submodules and their parameters reach model.parameters()/the optimizer.
    self.convolutions = torch.nn.ModuleDict()
    for conv in range(num_convs):
        self.convolutions['conv' + str(conv + 1)] = GCNConv(2048, 2048).to(
            self.device)
        self.convolutions['pool' + str(conv + 1)] = TopKPooling(
            2048, ratio=.6).to(self.device)
    # Final Output MLP.
    self.lin1 = torch.nn.Linear(2048, 1024).to(self.device)
    self.lin2 = torch.nn.Linear(1024, 1024).to(self.device)
    self.lin3 = torch.nn.Linear(1024, classes).to(self.device)
    self.act1 = torch.nn.ReLU().to(self.device)
    self.act2 = torch.nn.ReLU().to(self.device)
def _segm_resnet(name, backbone_name, num_classes, aux,
                 pretrained_backbone=True, **kwargs):
    """Assemble an FCN or DeepLabV3 model on a dilated ResNet backbone."""
    trunk = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True],
        resnet_local_path=kwargs.get("resnet_local_path"),
    )
    # Expose layer4 as 'out'; add layer3 as 'aux' when an auxiliary loss is used.
    wanted = {'layer4': 'out'}
    if aux:
        wanted['layer3'] = 'aux'
    trunk = IntermediateLayerGetter(trunk, return_layers=wanted)
    # 1024 = channel count fed to the auxiliary head (from layer3).
    aux_classifier = FCNHead(1024, num_classes) if aux else None
    head_cls, model_cls = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }[name]
    # 2048 = channel count fed to the main head (from layer4).
    classifier = head_cls(2048, num_classes)
    return model_cls(trunk, classifier, aux_classifier)
def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    """Build an FCN segmentation network from a dilated torchvision ResNet."""
    net = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True])
    # NOTE(review): this replaces the stem conv with an identically-shaped one,
    # effectively re-initializing its weights — confirm that is intended.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2),
                          padding=(3, 3), bias=False)
    layer_names = {'layer4': 'out'}
    if aux:
        layer_names['layer3'] = 'aux'
    net = IntermediateLayerGetter(net, return_layers=layer_names)
    # Auxiliary head reads layer3 (1024 channels); main head reads layer4 (2048).
    aux_head = FCNHead(1024, num_classes) if aux else None
    head_cls, model_cls = {'fcn': (FCNHead, FCN)}[name]
    return model_cls(net, head_cls(2048, num_classes), aux_head)
def _segm_resnet(name, backbone_name, num_classes, aux, **kwargs):
    """Construct FCN/DeepLabV3 around a classification backbone.

    ``backbone_name`` may be an architecture spec dict, a checkpoint file
    path, or an architecture name.
    """
    # FIXME: 1000 and _
    if isinstance(backbone_name, dict):
        backbone = model_utils.which_architecture(backbone_name['arch'],
                                                  backbone_name['customs'])
    elif os.path.isfile(backbone_name):
        backbone, _ = model_utils.which_network_classification(
            backbone_name, 1000, **kwargs)
    else:
        backbone = model_utils.which_architecture(backbone_name, **kwargs)

    wanted_layers = {'layer4': 'out'}
    if aux:
        wanted_layers['layer3'] = 'aux'
    backbone = IntermediateLayerGetter(backbone, return_layers=wanted_layers)

    def planes_of(stage):
        # Channel count of a ResNet stage, read from the second-to-last child
        # (presumably the final normalization layer) of its last block.
        return list(list(stage)[-1].children())[-2].num_features

    aux_classifier = (FCNHead(planes_of(backbone.layer3), num_classes)
                      if aux else None)
    head_cls, base_cls = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }[name]
    classifier = head_cls(planes_of(backbone.layer4), num_classes)
    return ColourTransferModel(
        base_cls(backbone, classifier, aux_classifier), **kwargs)
def __init__(self, num_classes, backbone_fn, chip_size=224):
    """DeepLabV3-style model built on an arbitrary classification backbone.

    The last blocks of the backbone are converted to dilated convolutions
    (stride 1) so spatial resolution is preserved, a forward hook captures an
    intermediate activation for the auxiliary head, and head channel counts
    are inferred by running a dummy forward pass of size ``chip_size``.
    """
    super().__init__()
    if getattr(backbone_fn, '_is_multispectral', False):
        # Multispectral backbones carry their own cut point metadata.
        self.backbone = create_body(backbone_fn, pretrained=True,
                                    cut=_get_backbone_meta(
                                        backbone_fn.__name__)['cut'])
    else:
        self.backbone = create_body(backbone_fn, pretrained=True)
    backbone_name = backbone_fn.__name__
    ## Support for different backbones
    if "densenet" in backbone_name or "vgg" in backbone_name:
        # densenet/vgg bodies wrap their layers in one extra container.
        hookable_modules = list(self.backbone.children())[0]
    else:
        hookable_modules = list(self.backbone.children())
    if "vgg" in backbone_name:
        modify_dilation_index = -5
    else:
        modify_dilation_index = -2
    if backbone_name == 'resnet18' or backbone_name == 'resnet34':
        # Basic blocks name their conv 'conv'; bottlenecks use 'conv2'.
        module_to_check = 'conv'
    else:
        module_to_check = 'conv2'
    ## Hook at the index where we need to get the auxiliary logits out
    self.hook = hook_output(hookable_modules[modify_dilation_index])
    custom_idx = 0
    # Convert the trailing blocks to dilated convs with stride 1; dilation and
    # padding grow per block (2, 4, ...) so the receptive field keeps growing.
    for i, module in enumerate(hookable_modules[modify_dilation_index:]):
        dilation = 2 * (i + 1)
        padding = 2 * (i + 1)
        for n, m in module.named_modules():
            if module_to_check in n:
                m.dilation, m.padding, m.stride = (dilation, dilation), (
                    padding, padding), (1, 1)
            elif 'downsample.0' in n:
                # Keep the residual shortcut at stride 1 as well.
                m.stride = (1, 1)
        if "vgg" in backbone_fn.__name__:
            # vgg layers are plain Conv2d modules, patched directly; the
            # per-conv counter drives their dilation schedule.
            if isinstance(module, nn.Conv2d):
                dilation = 2 * (custom_idx + 1)
                padding = 2 * (custom_idx + 1)
                module.dilation, module.padding, module.stride = (
                    dilation, dilation), (padding, padding), (1, 1)
                custom_idx += 1
    ## Returns the size of various activations
    feature_sizes = model_sizes(self.backbone, size=(chip_size, chip_size))
    ## Getting the number of channels present in the stored activation inside the hook
    num_channels_aux_classifier = self.hook.stored.shape[1]
    ## Get number of channels in the last layer
    num_channels_classifier = feature_sizes[-1][1]
    self.classifier = DeepLabHead(num_channels_classifier, num_classes)
    self.aux_classifier = FCNHead(num_channels_aux_classifier, num_classes)
def _create_deeplab(num_class, backbone, pretrained=True, **kwargs):
    '''
    Create default torchvision pretrained model with resnet101.

    Args:
        num_class: number of output classes for both heads.
        backbone: unused; kept for signature compatibility.
        pretrained: load COCO-pretrained weights. Bug fix: this flag was
            previously ignored — the call hard-coded ``pretrained=True``.
    '''
    model = models.segmentation.deeplabv3_resnet101(pretrained=pretrained,
                                                    progress=True,
                                                    **kwargs)
    model = _DeepLabOverride(model.backbone, model.classifier,
                             model.aux_classifier)
    # ResNet-101: layer4 feeds 2048 channels to the main head, layer3 feeds
    # 1024 to the auxiliary head.
    model.classifier = DeepLabHead(2048, num_class)
    model.aux_classifier = FCNHead(1024, num_class)
    return model
def __init__(self, num_classes=23):
    """Segmentation network: pretrained FCN-ResNet101 with a fresh head.

    Only the new classifier head is trainable; the backbone stays frozen.
    (A large block of commented-out VGG/upsampling experiment code was
    removed — it was dead code.)
    """
    super(SegmentationNN, self).__init__()
    self.fcn = models.segmentation.fcn_resnet101(pretrained=True)
    # Replace the 21-class COCO/VOC head with one sized for our classes;
    # drop the auxiliary head entirely.
    self.fcn.classifier = FCNHead(2048, num_classes)
    self.fcn.aux_classifier = None
    # Freeze everything except the (new) classifier parameters.
    for name, param in self.fcn.named_parameters():
        param.requires_grad = "classifier" in name
def __init__(self, in_chan=3, out_chan=2, pretrained=False):
    """FCN-ResNet50 wrapper with configurable input/output channel counts.

    ``pretrained`` controls only the backbone weights; the segmentation
    head is always freshly initialized.
    """
    super(FCN50, self).__init__()
    self.model = torchvision.models.segmentation.fcn_resnet50(
        pretrained=False,
        pretrained_backbone=pretrained)
    # New head producing `out_chan` maps from the 2048-channel layer4 output.
    self.model.classifier = FCNHead(2048, out_chan)
    # Swap the stem conv when the input is not 3-channel RGB.
    if in_chan != 3:
        self.model.backbone.conv1 = nn.Conv2d(in_chan, 64,
                                              kernel_size=7,
                                              stride=2,
                                              padding=3,
                                              bias=False)
def _segm_resnet(name, backbone_name, num_classes, aux, **kwargs):
    """Build FCN/DeepLabV3 over a backbone given as a spec dict, a checkpoint
    path, or an architecture name."""
    # FIXME: 1000 and _
    if isinstance(backbone_name, dict):
        kwargs.pop('pretrained', None)
        backbone = model_utils.which_architecture(
            backbone_name['arch'], backbone_name['customs']
        )
    elif os.path.isfile(backbone_name):
        kwargs.pop('pretrained', None)
        backbone, _ = model_utils.which_network_classification(
            backbone_name, 1000, **kwargs
        )
    else:
        backbone = model_utils.which_architecture(backbone_name, **kwargs)

    wanted = {'layer4': 'out'}
    if aux:
        wanted['layer3'] = 'aux'
    backbone = IntermediateLayerGetter(backbone, return_layers=wanted)

    def stage_planes(stage):
        # Bottleneck blocks keep the channel count on the second-to-last
        # child; basic blocks keep it on the last.
        block = list(stage)[-1]
        child_idx = -2 if isinstance(
            block, (cresnet.Bottleneck, presnet.Bottleneck)) else -1
        return list(block.children())[child_idx].num_features

    aux_classifier = (FCNHead(stage_planes(backbone.layer3), num_classes)
                      if aux else None)
    head_cls, base_cls = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }[name]
    classifier = head_cls(stage_planes(backbone.layer4), num_classes)
    return base_cls(backbone, classifier, aux_classifier)
def pre_fcn_resnet101(in_channel, out_channel):
    """FCN-ResNet101 initialised from COCO weights, then adapted to custom
    input/output channel counts."""
    model = fcn_resnet101(pretrained=False, progress=False)
    url = "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth"  # COCO
    state = model.state_dict()
    coco_weights = model_zoo.load_url(url, progress=False)
    # Merge in only those pretrained tensors whose names exist in this model.
    state.update({k: v for k, v in coco_weights.items() if k in state})
    model.load_state_dict(state)
    # Adapt the stem to `in_channel` inputs and the head to `out_channel` maps.
    model.backbone.conv1 = nn.Conv2d(in_channel, 64,
                                     kernel_size=7,
                                     stride=2,
                                     padding=3,
                                     bias=False)
    model.classifier = FCNHead(2048, out_channel)
    return model
def deeplabv3(pretrained=False, resnet="res103", head_in_ch=2048, num_classes=21):
    """Build a SmallDeepLab model on the selected ResNet backbone.

    Bug fix: the backbone looked up from the ``resnet`` argument was
    previously discarded and ``resnet103`` was hard-coded, so
    ``resnet="res101"`` silently built the wrong network. The selected
    function is now actually used (and the parameter is no longer shadowed).
    """
    backbone_fn = {"res101": resnet101, "res103": resnet103}[resnet]
    net = SmallDeepLab(
        backbone=IntermediateLayerGetter(
            backbone_fn(pretrained=True,
                        replace_stride_with_dilation=[False, True, True]),
            return_layers={
                'layer2': 'res2',
                'layer4': 'out'
            }),
        classifier=DeepLabHead(head_in_ch, num_classes),
        aux_classifier=FCNHead(head_in_ch // 2, num_classes))
    if pretrained:
        # NOTE(review): these COCO weights correspond to the resnet101
        # variant; loading them into a res103-based net may not fit — confirm.
        state_dict = load_state_dict_from_url(
            'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
            progress=True)
        net.load_state_dict(state_dict)
    return net
def createFCNResnet101(outputchannels=1, feature_extract=True):
    """FCNResnet101 class with custom head

    Args:
        outputchannels (int, optional): The number of output channels in your
        dataset masks. Defaults to 1.
        feature_extract (bool, optional): If False the whole model is trained
        otherwise only the classifier (head) is trained

    Returns:
        model: Returns the FCN model with the ResNet101 backbone.
    """
    model = models.segmentation.fcn_resnet101(pretrained=True, progress=True)
    # (removed a leftover debug `print(model)` that dumped the whole network)
    if feature_extract:
        # Freeze the backbone; the new head is created *after* this loop and
        # therefore remains trainable.
        for param in model.parameters():
            param.requires_grad = False
    model.classifier = FCNHead(2048, outputchannels)
    # Set the model in training mode
    model.train()
    return model
def build_fcn_resnet(name, backbone_fct, num_classes, aux, layers,
                     pretrained_backbone=False):
    """Constructs a custom ResNet backbone and modifies for fully
    convolutional, semantic segmentation

    Args:
        name (str): either deeplabv3 or fcn
        backbone_fct (function): the model function for the non-fcn-only backbone
        num_classes (int): number of classes
        aux (bool): use of auxiliary loss
        layers (list): configuration of layer-blocks (Bottlenecks)
        pretrained_backbone (bool): If True, returns a model pre-trained on
            COCO train2017 which contains the same classes as Pascal VOC
    """
    trunk = backbone_fct(layers=layers,
                         pretrained=pretrained_backbone,
                         replace_stride_with_dilation=[False, True, True])
    outputs = {'layer4': 'out'}
    if aux:
        outputs['layer3'] = 'aux'
    trunk = IntermediateLayerGetter(trunk, return_layers=outputs)
    # layer3 carries 1024 channels (aux head); layer4 carries 2048 (main head).
    aux_head = FCNHead(1024, num_classes) if aux else None
    head_cls, model_cls = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }[name]
    return model_cls(trunk, head_cls(2048, num_classes), aux_head)
def _segm_resnet(name, backbone_name, num_classes, aux, dilation,
                 pretrained_backbone=True):
    """Build FCN or DeepLabV3; the DeepLab head's atrous rates are chosen
    from the backbone `dilation` configuration."""
    trunk = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=dilation)
    grabbed = {'layer4': 'out'}
    if aux:
        grabbed['layer3'] = 'aux'
    trunk = IntermediateLayerGetter(trunk, return_layers=grabbed)
    aux_head = FCNHead(1024, num_classes) if aux else None
    head_cls, model_cls = {
        'deeplabv3': (CustomDeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }[name]
    # print('dilation: {}'.format(dilation))
    if name == 'fcn':
        head = head_cls(2048, num_classes)
    elif (not dilation[0]) and dilation[1] and dilation[2]:
        # both layer3 and layer4 dilated -> larger atrous rates
        head = head_cls(2048, num_classes, [12, 24, 36])
    elif (not dilation[0]) and (not dilation[1]) and dilation[2]:
        # only layer4 dilated -> smaller atrous rates
        head = head_cls(2048, num_classes, [6, 12, 18])
    else:
        # Preserved best-effort behavior: report and return None.
        print('invalid dilation value')
        return None
    return model_cls(trunk, head, aux_head)
def _segm_resnet(name, backbone_name, num_classes, aux,
                 pretrained_backbone=True,
                 replace_stride_with_dilation=None,
                 rates=None,
                 return_layers=None):
    """Build an FCN or DeepLabV3 model on a torchvision ResNet backbone.

    Fixes over the original:
      * mutable default arguments replaced with ``None`` sentinels;
      * when ``aux`` is set, ``'layer3'`` is added to ``return_layers`` so the
        backbone actually emits the ``'aux'`` feature map the aux head needs
        (previously the default never included it);
      * atrous ``rates`` are passed only to the DeepLab head — ``FCNHead``
        takes no third positional argument, so the 'fcn' path raised.
    """
    if replace_stride_with_dilation is None:
        replace_stride_with_dilation = [False, True, True]
    if rates is None:
        rates = [12, 24, 36]
    # Copy so a caller-supplied dict is never mutated.
    return_layers = (dict(return_layers) if return_layers is not None
                     else {'layer4': 'out'})
    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=replace_stride_with_dilation)
    aux_classifier = None
    if aux:
        return_layers.setdefault('layer3', 'aux')
        aux_classifier = FCNHead(1024, num_classes)  # layer3 channels
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
    model_map = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    inplanes = 2048  # layer4 channels
    if name == 'deeplabv3':
        # Assumes the project's DeepLabHead accepts a rates argument, as the
        # original call implied.
        classifier = model_map[name][0](inplanes, num_classes, rates)
    else:
        classifier = model_map[name][0](inplanes, num_classes)
    base_model = model_map[name][1]
    return base_model(backbone, classifier, aux_classifier)
def __init__(self,
             num_classes: int = 2,
             ignore_index: Optional[int] = None,
             lr: float = 0.001,
             weight_decay: float = 0.001,
             aux_loss_factor: float = 0.3):
    """Segmentation task built on a COCO-pretrained DeepLabV3-ResNet101.

    Args:
        num_classes: number of segmentation classes.
        ignore_index: label value excluded from loss and metrics (None = use all).
        lr: optimizer learning rate.
        weight_decay: optimizer weight decay.
        aux_loss_factor: weight of the auxiliary head's loss term.
    """
    super().__init__()
    self.num_classes = num_classes
    self.ignore_index = ignore_index
    self.lr = lr
    self.weight_decay = weight_decay
    self.aux_loss_factor = aux_loss_factor
    # Create model from pre-trained DeepLabv3, with both heads re-sized:
    # aux head reads layer3 (1024 ch), main head reads layer4 (2048 ch).
    self.model = deeplabv3_resnet101(pretrained=True, progress=True)
    self.model.aux_classifier = FCNHead(1024, self.num_classes)
    self.model.classifier = DeepLabHead(2048, self.num_classes)
    # Setup trainable layers: everything trainable except the frozen backbone.
    self.model.requires_grad_(True)
    self.model.backbone.requires_grad_(False)
    # Loss function and metrics.
    self.focal_tversky_loss = FocalTverskyMetric(
        self.num_classes,
        alpha=0.7,
        beta=0.3,
        gamma=4.0 / 3.0,
        ignore_index=self.ignore_index,
    )
    self.accuracy_metric = Accuracy(ignore_index=self.ignore_index)
    # Per-class IoU (reduction="none" keeps one value per class).
    self.iou_metric = JaccardIndex(
        num_classes=self.num_classes,
        reduction="none",
        ignore_index=self.ignore_index,
    )
    self.precision_metric = Precision(num_classes=self.num_classes,
                                      ignore_index=self.ignore_index,
                                      average='weighted',
                                      mdmc_average='global')
    self.recall_metric = Recall(num_classes=self.num_classes,
                                ignore_index=self.ignore_index,
                                average='weighted',
                                mdmc_average='global')
def loadModel(model_arch="", classes=None, pre_trained_path=None, expType=None,
              trainable_backbone_flag=False, lower_features=False):
    """Build a segmentation model for the requested architecture.

    Args:
        model_arch: one of the supported architecture keys (see branches).
        classes: list of class names; only its length is used.
        pre_trained_path: optional checkpoint path; its 'net' state dict is
            loaded into the model.
        expType: unused; kept for signature compatibility.
        trainable_backbone_flag: for the *_orig torchvision models, whether
            backbone params stay trainable; for custom CNN backbones it is
            forwarded as `pretrained`.
        lower_features: use the low-level feature tap of the custom CNN
            backbone (smaller head input channel count).

    Returns:
        (model, features) — the model and its backbone/feature extractor.

    Fixes over the original:
      * unterminated string literals (flattening artifact) repaired;
      * copy-paste branch messages corrected (e.g. the squeezenet branches
        previously announced "mobilenet");
      * the heavily duplicated load/param-count/branch code is factored into
        nested helpers; `!= None` replaced with `is not None`.
    """

    def _load_pretrained(model):
        # Load checkpoint weights when a path was supplied.
        if pre_trained_path is not None:
            print("load pre-trained-weights ... ")
            model_dict_state = torch.load(
                pre_trained_path)  # + "/best_model.pth")
            model.load_state_dict(model_dict_state['net'])

    def _print_param_counts(model):
        # Find total parameters and trainable parameters.
        total_params = sum(p.numel() for p in model.parameters())
        print("total_params:" + str(total_params))
        total_trainable_params = sum(
            p.numel() for p in model.parameters() if p.requires_grad)
        print("total_trainable_params: " + str(total_trainable_params))

    def _build_custom(cnn_arch, head_cls, low_ch, high_ch, seg_cls):
        # Custom CNN backbone + (DeepLab|FCN) head wrapped in the torchvision
        # segmentation container; head input channels depend on the tap point.
        backbone_net = CNN(model_arch=cnn_arch, n_classes=len(classes),
                           include_top=False,
                           pretrained=trainable_backbone_flag,
                           lower_features=lower_features)
        in_ch = low_ch if lower_features == True else high_ch
        classifier = nn.Sequential(head_cls(in_ch, len(classes)),
                                   # nn.Softmax()
                                   )
        model = seg_cls(backbone=backbone_net, classifier=classifier,
                        aux_classifier=None)
        return model, backbone_net

    print("Load model architecture ... ")
    if model_arch == "deeplabv3_resnet101_orig":
        print("deeplab_resnet architecture selected ...")
        model = models.segmentation.deeplabv3_resnet101(pretrained=True,
                                                        progress=True)
        for params in model.parameters():
            params.requires_grad = trainable_backbone_flag
        model.classifier[-1] = torch.nn.Conv2d(256, len(classes),
                                               kernel_size=(1, 1))
        model.aux_classifier[-1] = torch.nn.Conv2d(256, len(classes),
                                                   kernel_size=(1, 1))
        features = model.backbone
        _load_pretrained(model)
        return model, features
    elif model_arch == "fcn_resnet101_orig":
        print("fcn_resnet101 architecture selected ...")  # fixed message
        model = models.segmentation.fcn_resnet101(pretrained=True,
                                                  progress=True)
        for params in model.parameters():
            params.requires_grad = trainable_backbone_flag
        model.classifier[-1] = torch.nn.Conv2d(512, len(classes),
                                               kernel_size=(1, 1))
        model.aux_classifier[-1] = torch.nn.Conv2d(256, len(classes),
                                                   kernel_size=(1, 1))
        features = model.backbone
        _load_pretrained(model)
        return model, features
    elif model_arch == "deeplabv3_resnet101":
        print("deeplabv3_resnet101 architecture selected ...")
        model, features = _build_custom("resnet101", DeepLabHead, 256, 2048,
                                        models.segmentation.DeepLabV3)
        _load_pretrained(model)
        return model, features
    elif model_arch == "deeplabv3_vgg16":
        print("deeplabv3_vgg architecture selected ...")
        model, features = _build_custom("vgg16", DeepLabHead, 64, 512,
                                        models.segmentation.DeepLabV3)
        _load_pretrained(model)
        _print_param_counts(model)
        return model, features
    elif model_arch == "deeplabv3_mobilenet":
        print("deeplabv3_mobilenet architecture selected ...")
        model, features = _build_custom("mobilenet", DeepLabHead, 32, 1280,
                                        models.segmentation.DeepLabV3)
        _load_pretrained(model)
        return model, features
    elif model_arch == "deeplabv3_squeezenet":
        print("deeplabv3_squeezenet architecture selected ...")  # fixed message
        model, features = _build_custom("squeezenet", DeepLabHead, 128, 512,
                                        models.segmentation.DeepLabV3)
        _load_pretrained(model)
        return model, features
    elif model_arch == "fcn_vgg16":
        print("fcn_vgg16 architecture selected ...")
        model, features = _build_custom("vgg16", FCNHead, 64, 512,
                                        models.segmentation.FCN)
        _load_pretrained(model)
        return model, features
    elif model_arch == "fcn_resnet101":
        print("fcn_resnet101 architecture selected ...")
        model, features = _build_custom("resnet101", FCNHead, 256, 2048,
                                        models.segmentation.FCN)
        _load_pretrained(model)
        _print_param_counts(model)
        return model, features
    elif model_arch == "fcn_squeezenet":
        print("fcn_squeezenet architecture selected ...")  # fixed message
        model, features = _build_custom("squeezenet", FCNHead, 128, 512,
                                        models.segmentation.FCN)
        _load_pretrained(model)
        _print_param_counts(model)
        return model, features
    elif model_arch == "fcn_mobilenet":
        print("fcn_mobilenet architecture selected ...")  # fixed message
        model, features = _build_custom("mobilenet", FCNHead, 32, 1280,
                                        models.segmentation.FCN)
        _load_pretrained(model)
        _print_param_counts(model)
        return model, features
    else:
        print("ERROR: select valid model architecture!")
        exit()