def __init__(self, previous, num_classes, num_blocks_per_stage=None, network_props=None, deep_supervision=False,
                 upscale_logits=False, block=BasicResidualBlock):
        super(ResidualUNetDecoder, self).__init__()
        self.num_classes = num_classes
        self.deep_supervision = deep_supervision
        """
        We assume the bottleneck is part of the encoder, so we can start with upsample -> concat here
        """
        previous_stages = previous.stages
        previous_stage_output_features = previous.stage_output_features
        previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
        previous_stage_conv_op_kernel_size = previous.stage_conv_op_kernel_size

        if network_props is None:
            self.props = previous.props
        else:
            self.props = network_props

        if self.props['conv_op'] == nn.Conv2d:
            transpconv = nn.ConvTranspose2d
            upsample_mode = "bilinear"
        elif self.props['conv_op'] == nn.Conv3d:
            transpconv = nn.ConvTranspose3d
            upsample_mode = "trilinear"
        else:
            raise ValueError("unknown convolution dimensionality, conv op: %s" % str(self.props['conv_op']))

        if num_blocks_per_stage is None:
            num_blocks_per_stage = previous.num_blocks_per_stage[:-1][::-1]
        assert len(num_blocks_per_stage) == len(previous.num_blocks_per_stage) - 1

        self.stage_pool_kernel_size = previous_stage_pool_kernel_size
        self.stage_output_features = previous_stage_output_features
        self.stage_conv_op_kernel_size = previous_stage_conv_op_kernel_size

        num_stages = len(previous_stages) - 1  # one stage fewer than the encoder: the bottleneck stage stays on the encoder side

        self.tus = []
        self.stages = []
        self.deep_supervision_outputs = []

        # cumulative upsampling factor per stage; only needed when upscale_logits is True
        cum_upsample = np.cumprod(np.vstack(self.stage_pool_kernel_size), axis=0).astype(int)

        for i, s in enumerate(np.arange(num_stages)[::-1]):
            features_below = previous_stage_output_features[s + 1]
            features_skip = previous_stage_output_features[s]

            # kernel size == stride == pooling size of the stage below, so the
            # transpose conv exactly undoes that pooling step
            self.tus.append(transpconv(features_below, features_skip, previous_stage_pool_kernel_size[s + 1],
                                       previous_stage_pool_kernel_size[s + 1], bias=False))
            # after the transpose conv we concatenate the skip connection, so the stage sees 2 * features_skip channels
            self.stages.append(ResidualLayer(2 * features_skip, features_skip, previous_stage_conv_op_kernel_size[s],
                                             self.props, num_blocks_per_stage[i], None, block))

            # auxiliary segmentation heads for deep supervision; the highest-resolution
            # stage (s == 0) is covered by self.segmentation_output instead
            if deep_supervision and s != 0:
                seg_layer = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, False)
                if upscale_logits:
                    upsample = Upsample(scale_factor=cum_upsample[s], mode=upsample_mode)
                    self.deep_supervision_outputs.append(nn.Sequential(seg_layer, upsample))
                else:
                    self.deep_supervision_outputs.append(seg_layer)

        self.segmentation_output = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, False)

        self.tus = nn.ModuleList(self.tus)
        self.stages = nn.ModuleList(self.stages)
        self.deep_supervision_outputs = nn.ModuleList(self.deep_supervision_outputs)
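
The loop above builds one such block per decoder stage: a transpose convolution whose kernel and stride equal that stage's pooling size (so it exactly undoes the pooling), followed by a conv stage on the concatenation with the encoder skip. Below is a minimal standalone sketch of a single step in plain PyTorch; the channel counts and the Conv-BN-ReLU stand-in for ResidualLayer are illustrative assumptions, not values taken from the example.

import torch
import torch.nn as nn

features_below, features_skip, pool = 256, 128, (2, 2)   # illustrative only

# transpose conv with kernel == stride == pooling size: exactly undoes that pooling
tu = nn.ConvTranspose2d(features_below, features_skip, pool, pool, bias=False)
# stand-in for ResidualLayer: a plain Conv-BN-ReLU stage on the concatenated features
stage = nn.Sequential(
    nn.Conv2d(2 * features_skip, features_skip, 3, padding=1, bias=False),
    nn.BatchNorm2d(features_skip),
    nn.ReLU(inplace=True),
)

below = torch.randn(1, features_below, 16, 16)   # output of the stage below
skip = torch.randn(1, features_skip, 32, 32)     # matching encoder skip connection
x = stage(torch.cat((tu(below), skip), dim=1))   # concatenation doubles the channel count
print(x.shape)                                   # torch.Size([1, 128, 32, 32])
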
Example #2
    def __init__(self,
                 input_channels,
                 base_num_features,
                 num_classes,
                 num_pool,
                 num_conv_per_stage=2,
                 feat_map_mul_on_downscale=2,
                 conv_op=nn.Conv2d,
                 norm_op=nn.BatchNorm2d,
                 norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d,
                 dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU,
                 nonlin_kwargs=None,
                 deep_supervision=True,
                 dropout_in_localization=False,
                 final_nonlin=softmax_helper,
                 weightInitializer=InitWeights_He(1e-2),
                 pool_op_kernel_sizes=None,
                 conv_kernel_sizes=None,
                 upscale_logits=False,
                 convolutional_pooling=False,
                 convolutional_upsampling=False,
                 max_num_features=None,
                 basic_block=ConvDropoutNormNonlin,
                 seg_output_use_bias=False):
        super(SAWNet, self).__init__(
            input_channels, base_num_features, num_classes, num_pool,
            num_conv_per_stage, feat_map_mul_on_downscale, conv_op, norm_op,
            norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin,
            nonlin_kwargs, deep_supervision, dropout_in_localization,
            final_nonlin, weightInitializer, pool_op_kernel_sizes,
            conv_kernel_sizes, upscale_logits, convolutional_pooling,
            convolutional_upsampling, max_num_features, basic_block,
            seg_output_use_bias)

        self.conv_blocks_w = []
        self.tuw = []
        self.w_outputs = []

        # the auxiliary 'w' decoder branch below is built with 2D ops regardless of conv_op
        upsample_mode = 'bilinear'
        transpconv = nn.ConvTranspose2d

        output_features = base_num_features
        input_features = input_channels
        print("num_pool:", num_pool)
        print("input_features:", input_features)
        for d in range(num_pool):
            input_features = output_features
            output_features = int(
                np.round(output_features * feat_map_mul_on_downscale))
            output_features = min(output_features, self.max_num_features)
            print("{}: {}".format(d, input_features))

        # SAUnit (defined elsewhere in the repo) operates on the bottleneck feature maps;
        # after the loop above, output_features holds the bottleneck width
        self.sau = SAUnit(output_features)

        if self.convolutional_upsampling:
            final_num_features = output_features  # base_num_features
        else:
            final_num_features = self.conv_blocks_context[-1].output_channels

        for u in range(num_pool):
            nfeatures_from_down = final_num_features
            nfeatures_from_skip = self.conv_blocks_context[-(
                2 + u)].output_channels
            n_features_after_tu_and_concat = nfeatures_from_skip * 2

            # the first conv reduces the number of features to match those of skip
            # the following convs work on that number of features
            # if not convolutional upsampling then the final conv reduces the num of features again
            if u != num_pool - 1 and not self.convolutional_upsampling:
                final_num_features = self.conv_blocks_context[-(
                    3 + u)].output_channels
            else:
                final_num_features = nfeatures_from_skip

            if not self.convolutional_upsampling:
                self.tuw.append(
                    Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)],
                             mode=upsample_mode))
            else:
                # print("u: {} nfeatures_from_down: {} nfeatures_from_skip: {}".format(u, nfeatures_from_down,
                #                                                                     nfeatures_from_skip))
                self.tuw.append(
                    transpconv(nfeatures_from_down,
                               nfeatures_from_skip,
                               pool_op_kernel_sizes[-(u + 1)],
                               pool_op_kernel_sizes[-(u + 1)],
                               bias=False))

            self.conv_blocks_w.append(
                nn.Sequential(
                    StackedConvLayers(n_features_after_tu_and_concat,
                                      nfeatures_from_skip,
                                      num_conv_per_stage - 1,
                                      self.conv_op,
                                      self.conv_kwargs,
                                      self.norm_op,
                                      self.norm_op_kwargs,
                                      self.dropout_op,
                                      self.dropout_op_kwargs,
                                      self.nonlin,
                                      self.nonlin_kwargs,
                                      basic_block=basic_block),
                    StackedConvLayers(nfeatures_from_skip,
                                      final_num_features,
                                      1,
                                      self.conv_op,
                                      self.conv_kwargs,
                                      self.norm_op,
                                      self.norm_op_kwargs,
                                      self.dropout_op,
                                      self.dropout_op_kwargs,
                                      self.nonlin,
                                      self.nonlin_kwargs,
                                      basic_block=basic_block)))

        for ds in range(len(self.conv_blocks_w)):
            # todo: relu?
            # one-channel output per w-decoder stage (not num_classes logits)
            self.w_outputs.append(
                conv_op(self.conv_blocks_w[ds][-1].output_channels, 1, 1, 1, 0,
                        1, 1, seg_output_use_bias))

        self.conv_blocks_w = nn.ModuleList(self.conv_blocks_w)
        self.tuw = nn.ModuleList(self.tuw)
        self.w_outputs = nn.ModuleList(self.w_outputs)
        # self.final_conv = nn.Conv2d(32, 1, (1, 1))

        if self.weightInitializer is not None:
            self.apply(self.weightInitializer)
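
For reference, the first loop in this example builds no layers; it only replays the encoder's width schedule so that output_features holds the bottleneck width handed to SAUnit. A short standalone sketch of that bookkeeping, using illustrative numbers (base width 32, multiplier 2, cap 512 and num_pool 5 are assumptions, not values from the example):

widths, w = [], 32                       # base_num_features (assumed)
for d in range(5):                       # num_pool (assumed)
    w = min(int(round(w * 2)), 512)      # feat_map_mul_on_downscale, max_num_features
    widths.append(w)
print(widths)                            # [64, 128, 256, 512, 512] -> bottleneck width is 512
# The decoder loop then walks these widths back down: stage u upsamples the current
# features and concatenates them with the skip from conv_blocks_context[-(2 + u)].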