Example no. 1
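# Context: this snippet is the constructor of an SDNet model class; the
# imports and class statement below are reconstructed so the excerpt parses.
# The base class name (ModelBase here) and the unet, ModalityEncoder, Decoder,
# and Segmentor helpers are assumed to be defined in sibling modules.
import sys
from copy import deepcopy

class SDNet(ModelBase):  # base class name is an assumption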
    def __init__(
        self,
        parameters: dict,
    ):
        super(SDNet, self).__init__(parameters)
        # number of disentangled anatomy (spatial) and modality (style) factors
        self.anatomy_factors = 8
        self.modality_factors = 8

        if parameters["patch_size"] != [224, 224, 1]:
            print(
                "WARNING: The patch size is not 224x224, which is required for sdnet. Using default patch size instead",
                file=sys.stderr,
            )
            parameters["patch_size"] = [224, 224, 1]

        if parameters["batch_size"] == 1:
            raise ValueError(
                "'batch_size' needs to be greater than 1 for 'sdnet'")

        # mixed precision (amp) is not supported for sdnet
        parameters["model"]["amp"] = False
        # sdnet relies on instance normalization
        parameters["model"]["norm_type"] = "instance"

        # the content (anatomy) encoder is a U-Net producing one output channel
        # per anatomy factor; deepcopy so the overrides below do not leak back
        parameters_unet = deepcopy(parameters)
        parameters_unet["model"]["num_classes"] = self.anatomy_factors
        parameters_unet["model"]["norm_type"] = "instance"
        parameters_unet["model"]["final_layer"] = None

        self.cencoder = unet(parameters_unet)  # content (anatomy) encoder
        self.mencoder = ModalityEncoder(
            parameters,
            self.anatomy_factors,
            self.modality_factors,
        )
        self.decoder = Decoder(
            parameters,
            self.anatomy_factors,
        )
        self.segmentor = Segmentor(
            parameters,
            self.anatomy_factors,
        )
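
For illustration, a minimal sketch of how this constructor might be invoked. The parameters dictionary below is hypothetical and only shows the keys the constructor actually reads; a real configuration would carry whatever additional keys the base class and the U-Net encoder expect.

# hypothetical configuration; only keys read by the constructor are shown
parameters = {
    "patch_size": [224, 224, 1],  # anything else triggers the warning and is overridden
    "batch_size": 4,              # must be greater than 1, or the constructor raises
    "model": {"amp": True, "norm_type": "batch"},  # both forcibly overridden above
}
model = SDNet(parameters)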
Example no. 2
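Imports reconstructed for the snippet; densenet, unet, fcn, uinc, MSDNet, VGG, make_layers, cfg, and checkPatchDivisibility are assumed to come from sibling modules of the original package.

import sys
from collections import Counter

import numpy as np
import torch.nn as nn
import torchvision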
def get_model(which_model, n_dimensions, n_channels, n_classes, base_filters, final_convolution_layer, psize, batch_size, **kwargs):
    '''
    Instantiate and return the requested model architecture.

    kwargs can be used to pass keyword arguments that are not explicitly defined.
    '''

    divisibilityCheck_patch = True
    divisibilityCheck_baseFilter = True

    divisibilityCheck_denom_patch = 16 # for unet/resunet/uinc
    divisibilityCheck_denom_baseFilter = 4 # for uinc
    
    if which_model == 'resunet':
        model = unet(n_dimensions, n_channels, n_classes, base_filters, final_convolution_layer = final_convolution_layer, residualConnections=True)
        divisibilityCheck_baseFilter = False
    elif which_model == 'unet':
        model = unet(n_dimensions, n_channels, n_classes, base_filters, final_convolution_layer = final_convolution_layer)
        divisibilityCheck_baseFilter = False
    elif which_model == 'fcn':
        model = fcn(n_dimensions, n_channels, n_classes, base_filters, final_convolution_layer = final_convolution_layer)
        # not enough information to perform checking for this, yet
        divisibilityCheck_patch = False 
        divisibilityCheck_baseFilter = False
    elif which_model == 'uinc':
        model = uinc(n_dimensions, n_channels, n_classes, base_filters, final_convolution_layer = final_convolution_layer)
    elif which_model == 'msdnet':
        model = MSDNet(n_dimensions, n_channels, n_classes, base_filters, final_convolution_layer = final_convolution_layer)
    elif which_model in ['densenet121', 'densenet161', 'densenet169', 'densenet201', 'densenet264']: # regressor/classifier networks
        # the requested depth is encoded in the model name itself
        model = densenet.generate_model(model_depth=int(which_model[len('densenet'):]),
                                        num_classes=n_classes,
                                        n_dimensions=n_dimensions,
                                        n_input_channels=n_channels, final_convolution_layer = final_convolution_layer)
    elif which_model == 'vgg16':
        vgg_config = cfg['D']
        num_final_features = vgg_config[-2]
        divisibility_factor = Counter(vgg_config)['M']
        if psize[-1] == 1:
            psize_altered = np.array(psize[:-1])
        else:
            psize_altered = np.array(psize)
        divisibilityCheck_patch = False 
        divisibilityCheck_baseFilter = False
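        # each 'M' (max-pool) entry in VGG configuration 'D' halves the spatial
        # dimensions, so the classifier input is the final channel count times
        # the spatial extent after 2**divisibility_factor downsampling
        # (batch_size is kept in the product, as in the original code)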
        featuresForClassifier = batch_size * num_final_features * np.prod(psize_altered // 2**divisibility_factor)
        layers = make_layers(cfg['D'], n_dimensions, n_channels)
        # n_classes is coming from 'class_list' in config, which needs to be changed to use a different variable for regression
        model = VGG(n_dimensions, layers, featuresForClassifier, n_classes, final_convolution_layer = final_convolution_layer)
    elif which_model == 'brain_age':
        if n_dimensions != 2:
            sys.exit("Brain age prediction only works on 2D data")
        model = torchvision.models.vgg16(pretrained = True)
        model.final_convolution_layer = None
        # Freeze training for all layers
        for param in model.features.parameters():
            param.requires_grad = False  # the original 'require_grad' was a silent typo and froze nothing
        # Newly created modules have require_grad=True by default
        num_features = model.classifier[6].in_features
        features = list(model.classifier.children())[:-1] # Remove last layer
        #features.extend([nn.AvgPool2d(1024), nn.Linear(num_features,1024),nn.ReLU(True), nn.Dropout2d(0.8), nn.Linear(1024,1)]) # RuntimeError: non-empty 2D or 3D (batch mode) tensor expected for input
        features.extend([nn.Linear(num_features, 1024), nn.ReLU(True), nn.Dropout(0.8), nn.Linear(1024, 1)]) # plain Dropout: the classifier input here is 2D (N, features)
        model.classifier = nn.Sequential(*features) # Replace the model classifier
        divisibilityCheck_patch = False 
        divisibilityCheck_baseFilter = False
        
    else:
        print('WARNING: Could not find the requested model \'' + which_model + '\' in the implementation; using \'resunet\' instead', file = sys.stderr)
        which_model = 'resunet'
        model = unet(n_dimensions, n_channels, n_classes, base_filters, final_convolution_layer = final_convolution_layer, residualConnections=True)
    
    # check divisibility
    if divisibilityCheck_patch:
        if not checkPatchDivisibility(psize, divisibilityCheck_denom_patch):
            sys.exit('The \'patch_size\' should be divisible by \'' + str(divisibilityCheck_denom_patch) + '\' for the \'' + which_model + '\' architecture')
    if divisibilityCheck_baseFilter:
        if base_filters % divisibilityCheck_denom_baseFilter != 0:
            sys.exit('The \'base_filters\' should be divisible by \'' + str(divisibilityCheck_denom_baseFilter) + '\' for the \'' + which_model + '\' architecture')
    
    return model
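
A hedged usage sketch; the argument values are illustrative, and passing a string for final_convolution_layer is an assumption about the convention the sibling modules expect.

model = get_model(
    which_model='resunet',
    n_dimensions=3,
    n_channels=4,
    n_classes=2,
    base_filters=32,
    final_convolution_layer='softmax',  # assumed convention
    psize=[64, 64, 64],  # must be divisible by 16 for unet/resunet/uinc
    batch_size=1,        # only consumed by the 'vgg16' branch
)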