# Assumed context: `np` is numpy, `nn` is torch.nn; ResidualLayer, BasicResidualBlock,
# Upsample and the encoder passed in as `previous` come from the surrounding nnU-Net-style codebase.
def __init__(self, previous, num_classes, num_blocks_per_stage=None, network_props=None,
             deep_supervision=False, upscale_logits=False, block=BasicResidualBlock):
    super(ResidualUNetDecoder, self).__init__()
    self.num_classes = num_classes
    self.deep_supervision = deep_supervision
    """
    We assume the bottleneck is part of the encoder, so we can start with upsample -> concat here
    """
    previous_stages = previous.stages
    previous_stage_output_features = previous.stage_output_features
    previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
    previous_stage_conv_op_kernel_size = previous.stage_conv_op_kernel_size

    if network_props is None:
        self.props = previous.props
    else:
        self.props = network_props

    if self.props['conv_op'] == nn.Conv2d:
        transpconv = nn.ConvTranspose2d
        upsample_mode = "bilinear"
    elif self.props['conv_op'] == nn.Conv3d:
        transpconv = nn.ConvTranspose3d
        upsample_mode = "trilinear"
    else:
        raise ValueError("unknown convolution dimensionality, conv op: %s" % str(self.props['conv_op']))

    if num_blocks_per_stage is None:
        # mirror the encoder: drop its last (bottleneck) entry and reverse the order
        num_blocks_per_stage = previous.num_blocks_per_stage[:-1][::-1]

    assert len(num_blocks_per_stage) == len(previous.num_blocks_per_stage) - 1

    self.stage_pool_kernel_size = previous_stage_pool_kernel_size
    self.stage_output_features = previous_stage_output_features
    self.stage_conv_op_kernel_size = previous_stage_conv_op_kernel_size

    # we have one stage less than the encoder, as the first stage here is what comes after the bottleneck
    num_stages = len(previous_stages) - 1

    self.tus = []
    self.stages = []
    self.deep_supervision_outputs = []

    # only used for upscale_logits
    cum_upsample = np.cumprod(np.vstack(self.stage_pool_kernel_size), axis=0).astype(int)

    for i, s in enumerate(np.arange(num_stages)[::-1]):
        features_below = previous_stage_output_features[s + 1]
        features_skip = previous_stage_output_features[s]

        self.tus.append(transpconv(features_below, features_skip, previous_stage_pool_kernel_size[s + 1],
                                   previous_stage_pool_kernel_size[s + 1], bias=False))
        # after the transposed conv we concatenate the skip, so the stage sees 2 * features_skip channels
        self.stages.append(ResidualLayer(2 * features_skip, features_skip, previous_stage_conv_op_kernel_size[s],
                                         self.props, num_blocks_per_stage[i], None, block))

        if deep_supervision and s != 0:
            seg_layer = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, False)
            if upscale_logits:
                upsample = Upsample(scale_factor=cum_upsample[s], mode=upsample_mode)
                self.deep_supervision_outputs.append(nn.Sequential(seg_layer, upsample))
            else:
                self.deep_supervision_outputs.append(seg_layer)

    self.segmentation_output = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, False)

    self.tus = nn.ModuleList(self.tus)
    self.stages = nn.ModuleList(self.stages)
    self.deep_supervision_outputs = nn.ModuleList(self.deep_supervision_outputs)
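# A minimal, self-contained sketch (not part of the class above) showing why each
# transposed conv uses kernel_size == stride == pool_kernel_size: with that choice it
# exactly inverts the corresponding pooling step, so the upsampled map matches the skip
# connection spatially and the concatenation yields 2 * features_skip channels, which is
# what ResidualLayer expects. The tensor sizes below are hypothetical example values.
import torch
from torch import nn

features_below, features_skip, pool_kernel = 64, 32, 2
tu = nn.ConvTranspose2d(features_below, features_skip, pool_kernel, pool_kernel, bias=False)

below = torch.randn(1, features_below, 16, 16)  # output of the stage below
skip = torch.randn(1, features_skip, 32, 32)    # encoder skip at this resolution

up = tu(below)                # (1, 32, 32, 32): pooling by 2 is exactly undone
x = torch.cat((up, skip), 1)  # (1, 64, 32, 32): 2 * features_skip channels
print(up.shape, x.shape)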
def __init__(self, input_channels, base_num_features, num_classes, num_pool, num_conv_per_stage=2,
             feat_map_mul_on_downscale=2, conv_op=nn.Conv2d, norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
             dropout_op=nn.Dropout2d, dropout_op_kwargs=None, nonlin=nn.LeakyReLU, nonlin_kwargs=None,
             deep_supervision=True, dropout_in_localization=False, final_nonlin=softmax_helper,
             weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None, conv_kernel_sizes=None,
             upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False,
             max_num_features=None, basic_block=ConvDropoutNormNonlin, seg_output_use_bias=False):
    super(SAWNet, self).__init__(input_channels, base_num_features, num_classes, num_pool, num_conv_per_stage,
                                 feat_map_mul_on_downscale, conv_op, norm_op, norm_op_kwargs, dropout_op,
                                 dropout_op_kwargs, nonlin, nonlin_kwargs, deep_supervision,
                                 dropout_in_localization, final_nonlin, weightInitializer,
                                 pool_op_kernel_sizes, conv_kernel_sizes, upscale_logits,
                                 convolutional_pooling, convolutional_upsampling, max_num_features,
                                 basic_block, seg_output_use_bias)

    self.conv_blocks_w = []
    self.tuw = []
    self.w_outputs = []

    upsample_mode = 'bilinear'
    transpconv = nn.ConvTranspose2d

    # replay the encoder's feature-map progression to recover the bottleneck width
    output_features = base_num_features
    input_features = input_channels
    print("num_pool:", num_pool)
    print("input_features:", input_features)
    for d in range(num_pool):
        input_features = output_features
        output_features = int(np.round(output_features * feat_map_mul_on_downscale))
        output_features = min(output_features, self.max_num_features)
        print("{}: {}".format(d, input_features))

    self.sau = SAUnit(output_features)

    if self.convolutional_upsampling:
        final_num_features = output_features  # base_num_features
    else:
        final_num_features = self.conv_blocks_context[-1].output_channels

    # build the auxiliary "w" decoder, mirroring the localization pathway of the parent U-Net
    for u in range(num_pool):
        nfeatures_from_down = final_num_features
        nfeatures_from_skip = self.conv_blocks_context[-(2 + u)].output_channels
        n_features_after_tu_and_concat = nfeatures_from_skip * 2

        # the first conv reduces the number of features to match those of the skip connection,
        # the following convs work on that number of features and, if we do not use convolutional
        # upsampling, the final conv reduces the number of features again
        if u != num_pool - 1 and not self.convolutional_upsampling:
            final_num_features = self.conv_blocks_context[-(3 + u)].output_channels
        else:
            final_num_features = nfeatures_from_skip

        if not self.convolutional_upsampling:
            self.tuw.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)], mode=upsample_mode))
        else:
            # print("u: {} nfeatures_from_down: {} nfeatures_from_skip: {}".format(
            #     u, nfeatures_from_down, nfeatures_from_skip))
            self.tuw.append(transpconv(nfeatures_from_down, nfeatures_from_skip, pool_op_kernel_sizes[-(u + 1)],
                                       pool_op_kernel_sizes[-(u + 1)], bias=False))

        self.conv_blocks_w.append(nn.Sequential(
            StackedConvLayers(n_features_after_tu_and_concat, nfeatures_from_skip, num_conv_per_stage - 1,
                              self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs,
                              self.dropout_op, self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs,
                              basic_block=basic_block),
            StackedConvLayers(nfeatures_from_skip, final_num_features, 1, self.conv_op, self.conv_kwargs,
                              self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
                              self.nonlin, self.nonlin_kwargs, basic_block=basic_block)))

    # one single-channel 1x1 output head per "w" decoder stage
    for ds in range(len(self.conv_blocks_w)):  # todo: relu?
        self.w_outputs.append(conv_op(self.conv_blocks_w[ds][-1].output_channels,
                                      1, 1, 1, 0, 1, 1, seg_output_use_bias))

    self.conv_blocks_w = nn.ModuleList(self.conv_blocks_w)
    self.tuw = nn.ModuleList(self.tuw)
    self.w_outputs = nn.ModuleList(self.w_outputs)
    # self.final_conv = nn.Conv2d(32, 1, (1, 1))

    if self.weightInitializer is not None:
        self.apply(self.weightInitializer)
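# A hedged usage sketch: constructing SAWNet with a typical nnU-Net-style 2D
# configuration. The kernel sizes and kwargs below are illustrative assumptions, not
# values taken from this repository; SAWNet, softmax_helper, InitWeights_He and
# ConvDropoutNormNonlin are assumed to be importable from the surrounding codebase.
from torch import nn

num_pool = 5
net = SAWNet(input_channels=1, base_num_features=32, num_classes=2, num_pool=num_pool,
             num_conv_per_stage=2, feat_map_mul_on_downscale=2, conv_op=nn.Conv2d,
             norm_op=nn.BatchNorm2d, norm_op_kwargs={'eps': 1e-5, 'affine': True},
             dropout_op=nn.Dropout2d, dropout_op_kwargs={'p': 0, 'inplace': True},
             nonlin=nn.LeakyReLU, nonlin_kwargs={'negative_slope': 1e-2, 'inplace': True},
             deep_supervision=True, final_nonlin=softmax_helper,
             weightInitializer=InitWeights_He(1e-2),
             pool_op_kernel_sizes=[(2, 2)] * num_pool,
             conv_kernel_sizes=[(3, 3)] * (num_pool + 1),
             convolutional_pooling=True, convolutional_upsampling=True,
             max_num_features=480)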