def __init__(self, n_dimensions, n_channels, n_classes, base_filters, final_convolution_layer, residualConnections=False):
    """Assemble a 4-level encoder/decoder U-Net from the project's building blocks.

    Parameters are forwarded to the parent constructor, which is expected to
    set up ``self.Conv`` / ``self.Dropout`` / ``self.InstanceNorm`` layer
    factories and ``self.n_channels`` / ``self.final_convolution_layer``.
    ``residualConnections`` toggles residual variants of the conv modules.
    """
    super(unet, self).__init__(n_dimensions, n_channels, n_classes, base_filters, final_convolution_layer)

    # Shorthand for the layer factories selected by the parent class.
    conv, drop, norm = self.Conv, self.Dropout, self.InstanceNorm
    res = residualConnections
    # Channel widths for the five resolution levels: f, 2f, 4f, 8f, 16f.
    f1, f2, f4, f8, f16 = (base_filters * m for m in (1, 2, 4, 8, 16))

    # --- contracting (encoder) path ---
    self.ins = in_conv(self.n_channels, f1, conv, drop, norm, res=res)
    self.ds_0 = DownsamplingModule(f1, f2, conv, drop, norm)
    self.en_1 = EncodingModule(f2, f2, conv, drop, norm, res=res)
    self.ds_1 = DownsamplingModule(f2, f4, conv, drop, norm)
    self.en_2 = EncodingModule(f4, f4, conv, drop, norm, res=res)
    self.ds_2 = DownsamplingModule(f4, f8, conv, drop, norm)
    self.en_3 = EncodingModule(f8, f8, conv, drop, norm, res=res)
    self.ds_3 = DownsamplingModule(f8, f16, conv, drop, norm)
    self.en_4 = EncodingModule(f16, f16, conv, drop, norm, res=res)

    # --- expanding (decoder) path ---
    # Decoding modules take 2x channels because of skip-connection concat.
    self.us_3 = UpsamplingModule(f16, f8, conv, drop, norm)
    self.de_3 = DecodingModule(f16, f8, conv, drop, norm, res=res)
    self.us_2 = UpsamplingModule(f8, f4, conv, drop, norm)
    self.de_2 = DecodingModule(f8, f4, conv, drop, norm, res=res)
    self.us_1 = UpsamplingModule(f4, f2, conv, drop, norm)
    self.de_1 = DecodingModule(f4, f2, conv, drop, norm, res=res)
    self.us_0 = UpsamplingModule(f2, f1, conv, drop, norm)
    # NOTE(review): there is no top-level de_0 here; out_conv receives the
    # 2*base_filters concatenation directly — confirm forward() matches.
    self.out = out_conv(f2, n_classes, conv, drop, norm, final_convolution_layer=self.final_convolution_layer, res=res)
def __init__(
    self,
    parameters: dict,
    residualConnections=False,
):
    """Assemble a 4-level U-Net driven by a GaNDLF-style ``parameters`` dict.

    Requires ``parameters["patch_size"]`` to be divisible by 16 (four
    downsampling stages); exits with a message otherwise. The parent
    constructor is expected to provide ``self.Conv``, ``self.Dropout``,
    ``self.Norm``, ``self.base_filters``, ``self.n_channels``,
    ``self.n_classes``, ``self.linear_interpolation_mode``,
    ``self.final_convolution_layer`` and ``self.sigmoid_input_multiplier``.

    Parameters
    ----------
    parameters : dict
        Overall parameters dictionary (must contain "patch_size" and
        "model" sub-dict with an "architecture" key).
    residualConnections : bool
        Whether the conv modules use residual connections.
    """
    self.network_kwargs = {"res": residualConnections}
    super(unet, self).__init__(parameters)
    if not (checkPatchDivisibility(parameters["patch_size"])):
        # BUG FIX: sys.exit() accepts at most one argument; the previous
        # two-argument call raised TypeError instead of exiting with the
        # intended diagnostic. Concatenate into a single message instead.
        sys.exit(
            "The patch size is not divisible by 16, which is required for "
            + parameters["model"]["architecture"]
        )
    # --- contracting (encoder) path ---
    self.ins = in_conv(
        input_channels=self.n_channels,
        output_channels=self.base_filters,
        conv=self.Conv,
        dropout=self.Dropout,
        norm=self.Norm,
        network_kwargs=self.network_kwargs,
    )
    self.ds_0 = DownsamplingModule(
        input_channels=self.base_filters,
        output_channels=self.base_filters * 2,
        conv=self.Conv,
        norm=self.Norm,
    )
    self.en_1 = EncodingModule(
        input_channels=self.base_filters * 2,
        output_channels=self.base_filters * 2,
        conv=self.Conv,
        dropout=self.Dropout,
        norm=self.Norm,
        network_kwargs=self.network_kwargs,
    )
    self.ds_1 = DownsamplingModule(
        input_channels=self.base_filters * 2,
        output_channels=self.base_filters * 4,
        conv=self.Conv,
        norm=self.Norm,
    )
    self.en_2 = EncodingModule(
        input_channels=self.base_filters * 4,
        output_channels=self.base_filters * 4,
        conv=self.Conv,
        dropout=self.Dropout,
        norm=self.Norm,
        network_kwargs=self.network_kwargs,
    )
    self.ds_2 = DownsamplingModule(
        input_channels=self.base_filters * 4,
        output_channels=self.base_filters * 8,
        conv=self.Conv,
        norm=self.Norm,
    )
    self.en_3 = EncodingModule(
        input_channels=self.base_filters * 8,
        output_channels=self.base_filters * 8,
        conv=self.Conv,
        dropout=self.Dropout,
        norm=self.Norm,
        network_kwargs=self.network_kwargs,
    )
    self.ds_3 = DownsamplingModule(
        input_channels=self.base_filters * 8,
        output_channels=self.base_filters * 16,
        conv=self.Conv,
        norm=self.Norm,
    )
    self.en_4 = EncodingModule(
        input_channels=self.base_filters * 16,
        output_channels=self.base_filters * 16,
        conv=self.Conv,
        dropout=self.Dropout,
        norm=self.Norm,
        network_kwargs=self.network_kwargs,
    )
    # --- expanding (decoder) path; decoders take 2x channels (skip concat) ---
    self.us_3 = UpsamplingModule(
        input_channels=self.base_filters * 16,
        output_channels=self.base_filters * 8,
        conv=self.Conv,
        interpolation_mode=self.linear_interpolation_mode,
    )
    self.de_3 = DecodingModule(
        input_channels=self.base_filters * 16,
        output_channels=self.base_filters * 8,
        conv=self.Conv,
        norm=self.Norm,
        network_kwargs=self.network_kwargs,
    )
    self.us_2 = UpsamplingModule(
        input_channels=self.base_filters * 8,
        output_channels=self.base_filters * 4,
        conv=self.Conv,
        interpolation_mode=self.linear_interpolation_mode,
    )
    self.de_2 = DecodingModule(
        input_channels=self.base_filters * 8,
        output_channels=self.base_filters * 4,
        conv=self.Conv,
        norm=self.Norm,
        network_kwargs=self.network_kwargs,
    )
    self.us_1 = UpsamplingModule(
        input_channels=self.base_filters * 4,
        output_channels=self.base_filters * 2,
        conv=self.Conv,
        interpolation_mode=self.linear_interpolation_mode,
    )
    self.de_1 = DecodingModule(
        input_channels=self.base_filters * 4,
        output_channels=self.base_filters * 2,
        conv=self.Conv,
        norm=self.Norm,
        network_kwargs=self.network_kwargs,
    )
    self.us_0 = UpsamplingModule(
        input_channels=self.base_filters * 2,
        output_channels=self.base_filters,
        conv=self.Conv,
        interpolation_mode=self.linear_interpolation_mode,
    )
    self.de_0 = DecodingModule(
        input_channels=self.base_filters * 2,
        output_channels=self.base_filters * 2,
        conv=self.Conv,
        norm=self.Norm,
        network_kwargs=self.network_kwargs,
    )
    self.out = out_conv(
        input_channels=self.base_filters * 2,
        output_channels=self.n_classes,
        conv=self.Conv,
        norm=self.Norm,
        network_kwargs=self.network_kwargs,
        final_convolution_layer=self.final_convolution_layer,
        sigmoid_input_multiplier=self.sigmoid_input_multiplier,
    )
def __init__(
    self,
    parameters: dict,
    residualConnections=False,
):
    """Build a lightweight multi-layer U-Net with a configurable depth.

    All levels share the same channel width (``self.base_filters``); the
    per-level modules are stored in ``ModuleList``s indexed by depth.
    ``parameters["model"]["depth"]`` defaults to 4; the actual number of
    layers is clamped to what the patch size supports (must be divisible
    by 2^i per dimension), and the constructor exits if fewer than 2
    layers fit.

    Parameters
    ----------
    parameters : dict
        Overall parameters dictionary (must contain "patch_size" and a
        "model" sub-dict).
    residualConnections : bool
        Whether the conv modules use residual connections.
    """
    self.network_kwargs = {"res": residualConnections}
    super(light_unet_multilayer, self).__init__(parameters)
    # Idiom: "x not in y" instead of "not (x in y)"; removed dead
    # commented-out network_kwargs assignment.
    if "depth" not in parameters["model"]:
        parameters["model"]["depth"] = 4
        print("Default depth set to 4.")
    # How many layers the patch size actually supports.
    patch_check = checkPatchDimensions(
        parameters["patch_size"], numlay=parameters["model"]["depth"]
    )
    if patch_check != parameters["model"]["depth"] and patch_check >= 2:
        print(
            "The patch size is not large enough for desired depth. It is expected that each dimension of the patch size is divisible by 2^i, where i is in a integer greater than or equal to 2. Only the first %d layers will run."
            % patch_check
        )
    elif patch_check < 2:
        sys.exit(
            "The patch size is not large enough for desired depth. It is expected that each dimension of the patch size is divisible by 2^i, where i is in a integer greater than or equal to 2."
        )
    self.num_layers = patch_check
    self.ins = in_conv(
        input_channels=self.n_channels,
        output_channels=self.base_filters,
        conv=self.Conv,
        dropout=self.Dropout,
        norm=self.Norm,
        network_kwargs=self.network_kwargs,
    )
    # One module of each kind per layer; appended in ds/us/de/en order to
    # preserve the original parameter-registration order.
    self.ds = ModuleList([])
    self.en = ModuleList([])
    self.us = ModuleList([])
    self.de = ModuleList([])
    for _ in range(self.num_layers):
        self.ds.append(
            DownsamplingModule(
                input_channels=self.base_filters,
                output_channels=self.base_filters,
                conv=self.Conv,
                norm=self.Norm,
            )
        )
        self.us.append(
            UpsamplingModule(
                input_channels=self.base_filters,
                output_channels=self.base_filters,
                conv=self.Conv,
                interpolation_mode=self.linear_interpolation_mode,
            )
        )
        self.de.append(
            DecodingModule(
                input_channels=self.base_filters * 2,
                output_channels=self.base_filters,
                conv=self.Conv,
                norm=self.Norm,
                network_kwargs=self.network_kwargs,
            )
        )
        self.en.append(
            EncodingModule(
                input_channels=self.base_filters,
                output_channels=self.base_filters,
                conv=self.Conv,
                dropout=self.Dropout,
                norm=self.Norm,
                network_kwargs=self.network_kwargs,
            )
        )
    self.out = out_conv(
        input_channels=self.base_filters,
        output_channels=self.n_classes,
        conv=self.Conv,
        norm=self.Norm,
        network_kwargs=self.network_kwargs,
        final_convolution_layer=self.final_convolution_layer,
        sigmoid_input_multiplier=self.sigmoid_input_multiplier,
    )
def __init__(
    self,
    parameters: dict,
):
    """Build a UNETR: a ViT-style transformer encoder with CNN decoders.

    Validates that ``inner_patch_size`` is a power of 2 (defaulting it to
    the first image dimension), that the image size supports the implied
    depth, and that the image is divisible by and larger than the inner
    patch. Exits with a diagnostic on any violation. The parent
    constructor is expected to provide ``self.n_dimensions``,
    ``self.n_channels``, ``self.n_classes``, ``self.Conv``, ``self.Norm``,
    ``self.ConvTranspose``, ``self.final_convolution_layer`` and
    ``self.sigmoid_input_multiplier``.

    Parameters
    ----------
    parameters : dict
        Overall parameters dictionary (must contain "patch_size" and a
        "model" sub-dict).
    """
    super(unetr, self).__init__(parameters)

    if "inner_patch_size" not in parameters["model"]:
        parameters["model"]["inner_patch_size"] = parameters["patch_size"][0]
        print("Default inner patch size set to %d." % parameters["patch_size"][0])
    # The key is guaranteed present here (set above if missing), so the
    # original redundant `if "inner_patch_size" in ...` guard was removed.
    if np.ceil(np.log2(parameters["model"]["inner_patch_size"])) != np.floor(
        np.log2(parameters["model"]["inner_patch_size"])
    ):
        sys.exit("The inner patch size must be a power of 2.")

    self.patch_size = parameters["model"]["inner_patch_size"]
    self.depth = int(np.log2(self.patch_size))
    patch_check = checkPatchDimensions(parameters["patch_size"], self.depth)

    if patch_check != self.depth and patch_check >= 2:
        print(
            "The image size is not large enough for desired depth. It is expected that each dimension of the image is divisible by 2^i, where i is in a integer greater than or equal to 2. Only the first %d layers will run."
            % patch_check
        )
    elif patch_check < 2:
        sys.exit(
            "The image size is not large enough for desired depth. It is expected that each dimension of the image is divisible by 2^i, where i is in a integer greater than or equal to 2."
        )

    if "num_heads" not in parameters["model"]:
        parameters["model"]["num_heads"] = 12
        print(
            "Default number of heads in multi-head self-attention (MSA) set to 12."
        )
    if "embed_dim" not in parameters["model"]:
        parameters["model"]["embed_dim"] = 768
        print("Default size of embedded dimension set to 768.")

    if self.n_dimensions == 2:
        self.img_size = parameters["patch_size"][0:2]
    elif self.n_dimensions == 3:
        self.img_size = parameters["patch_size"]

    self.num_layers = 3 * self.depth  # number of transformer layers
    # Tap transformer outputs at every third layer for the skip paths.
    self.out_layers = np.arange(2, self.num_layers, 3)
    self.num_heads = parameters["model"]["num_heads"]
    self.embed_size = parameters["model"]["embed_dim"]

    # Number of transformer patches per spatial dimension.
    self.patch_dim = [i // self.patch_size for i in self.img_size]
    if not all(i % self.patch_size == 0 for i in self.img_size):
        sys.exit(
            "The image size is not divisible by the patch size in at least 1 dimension. UNETR is not defined in this case."
        )
    if not all(self.patch_size <= i for i in self.img_size):
        sys.exit("The inner patch size must be smaller than the input image.")

    self.transformer = _Transformer(
        img_size=self.img_size,
        patch_size=self.patch_size,
        in_feats=self.n_channels,
        embed_size=self.embed_size,
        num_heads=self.num_heads,
        mlp_dim=2048,
        num_layers=self.num_layers,
        out_layers=self.out_layers,
        Conv=self.Conv,
        Norm=self.Norm,
    )

    self.upsampling = ModuleList([])
    self.convs = ModuleList([])
    for i in range(self.depth - 1):
        # Deconv stack that brings the i-th transformer tap up to the
        # resolution where it joins the decoder.
        tempconvs = nn.Sequential()
        tempconvs.add_module(
            "conv0",
            _DeconvConvBlock(
                self.embed_size,
                32 * 2**self.depth,
                self.Norm,
                self.Conv,
                self.ConvTranspose,
            ),
        )
        for j in range(self.depth - 2, i, -1):
            tempconvs.add_module(
                "conv%d" % j,
                _DeconvConvBlock(
                    128 * 2**j,
                    128 * 2 ** (j - 1),
                    self.Norm,
                    self.Conv,
                    self.ConvTranspose,
                ),
            )
        self.convs.append(tempconvs)
        # Upsampling block that doubles resolution at decoder level i.
        self.upsampling.append(
            _UpsampleBlock(
                128 * 2 ** (i + 1), self.Norm, self.Conv, self.ConvTranspose
            )
        )
    # Upsampling for the final transformer output (no conv stack).
    self.upsampling.append(
        self.ConvTranspose(
            in_channels=self.embed_size,
            out_channels=32 * 2**self.depth,
            kernel_size=2,
            stride=2,
            padding=0,
            output_padding=0,
        )
    )

    # Full-resolution input branch (skip connection from the raw image).
    self.input_conv = nn.Sequential()
    self.input_conv.add_module(
        "conv1", _ConvBlock(self.n_channels, 32, self.Norm, self.Conv)
    )
    self.input_conv.add_module("conv2", _ConvBlock(32, 64, self.Norm, self.Conv))

    # Output head: fuse the 128-channel decoder output down to n_classes.
    self.output_conv = nn.Sequential()
    self.output_conv.add_module("conv1", _ConvBlock(128, 64, self.Norm, self.Conv))
    self.output_conv.add_module("conv2", _ConvBlock(64, 64, self.Norm, self.Conv))
    self.output_conv.add_module(
        "conv3",
        out_conv(
            64,
            self.n_classes,
            conv_kwargs={
                "kernel_size": 1,
                "stride": 1,
                "padding": 0,
                "bias": False,
            },
            norm=self.Norm,
            conv=self.Conv,
            final_convolution_layer=self.final_convolution_layer,
            sigmoid_input_multiplier=self.sigmoid_input_multiplier,
        ),
    )