Example no. 1
    def initialize_network(self):
        """
        changed deep supervision to False
        :return:
        """
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = MyGroupNorm

        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = MyGroupNorm

        norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'num_groups': 8}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet(
            self.num_input_channels, self.base_num_features, self.num_classes,
            len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
            conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
            net_nonlin, net_nonlin_kwargs, True, False, lambda x: x,
            InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes, False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper
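MyGroupNorm and softmax_helper are used above but not defined in this section. A minimal sketch, assuming MyGroupNorm is a thin nn.GroupNorm wrapper that takes the channel count as its first argument (the way Generic_UNet instantiates its norm ops, which is why num_groups appears in norm_op_kwargs) and that softmax_helper applies a softmax over the channel dimension; both definitions are assumptions, not code from this repository:

import torch.nn as nn
import torch.nn.functional as F

class MyGroupNorm(nn.GroupNorm):
    # Hypothetical wrapper: accepts num_channels first so it can be built like
    # InstanceNorm, i.e. norm_op(output_channels, **norm_op_kwargs).
    def __init__(self, num_channels, eps=1e-5, affine=True, num_groups=8):
        super().__init__(num_groups=num_groups, num_channels=num_channels,
                         eps=eps, affine=affine)

# Assumed helper: softmax over the class/channel axis, applied at inference time.
softmax_helper = lambda x: F.softmax(x, dim=1)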
Example no. 2
    def initialize_network(self):
        """
        replace Generic_UNet with the implementation above for super speeds
        """
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d

        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet_DP(
            self.num_input_channels, self.base_num_features, self.num_classes,
            len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
            conv_op, norm_op, norm_op_kwargs, dropout_op,
            dropout_op_kwargs, net_nonlin, net_nonlin_kwargs, True, False,
            InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes, False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper
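Example no. 3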
    def initialize_network(self):
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d

        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet(
            self.num_input_channels, self.base_num_features, self.num_classes,
            len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
            conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
            net_nonlin, net_nonlin_kwargs, True, False, lambda x: x,
            InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes, False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = nn.Sigmoid()
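This variant sets the inference nonlinearity to nn.Sigmoid() instead of the softmax_helper used by the other examples, i.e. each output channel is treated as an independent label rather than one of a set of mutually exclusive classes. A small illustration of the difference (the tensor shape is an assumption):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 3, 8, 8, 8)        # (batch, classes/labels, *spatial), assumed shape

probs_softmax = F.softmax(logits, dim=1)   # mutually exclusive classes: sums to 1 over dim 1
probs_sigmoid = torch.sigmoid(logits)      # independent labels: each channel in [0, 1] on its own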
Example no. 4
    def initialize_network(self):
        """
        - momentum 0.99
        - SGD instead of Adam
        - self.lr_scheduler = None because we do poly_lr
        - deep supervision = True
        """
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d

        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet(
            self.num_input_channels, self.base_num_features, self.num_classes,
            len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
            conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
            net_nonlin, net_nonlin_kwargs, True, False, lambda x: x,
            InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes, False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper
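The docstring above refers to poly_lr, which is not shown in this section. A hedged sketch of a polynomial learning-rate decay that matches that description; the exponent value and the commented usage are assumptions, not taken from this trainer:

def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    # Polynomial decay from initial_lr down to 0 at max_epochs.
    return initial_lr * (1 - epoch / max_epochs) ** exponent

# Assumed usage: with self.lr_scheduler = None, the trainer would set the SGD
# learning rate manually at the start of every epoch, e.g.:
# for param_group in self.optimizer.param_groups:
#     param_group['lr'] = poly_lr(self.epoch, self.max_num_epochs, self.initial_lr)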
Example no. 5
    def initialize_network_optimizer_and_scheduler(self):
        """
        This is specific to the U-Net and must be adapted for other network architectures
        :return:
        """
        # self.print_to_log_file(self.net_num_pool_op_kernel_sizes)
        # self.print_to_log_file(self.net_conv_kernel_sizes)

        net_numpool = len(self.net_num_pool_op_kernel_sizes)

        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d
        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet(
            self.num_input_channels, self.base_num_features, self.num_classes,
            net_numpool, 2, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
            dropout_op_kwargs,
            net_nonlin, net_nonlin_kwargs, False, False, lambda x: x,
            InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes, False, True, True)
        self.optimizer = torch.optim.Adam(self.network.parameters(),
                                          self.initial_lr,
                                          weight_decay=self.weight_decay,
                                          amsgrad=True)
        self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.2,
            patience=self.lr_scheduler_patience,
            verbose=True,
            threshold=self.lr_scheduler_eps,
            threshold_mode="abs")
        self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper
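For context, a minimal usage sketch of the scheduler configured above: ReduceLROnPlateau in 'min' mode multiplies the learning rate by factor once the monitored value (for example a smoothed training loss) has not improved by more than threshold for patience epochs. The model, patience, and threshold values below are placeholders, not values taken from this trainer:

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=3e-5, amsgrad=True)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2,
                                           patience=30, threshold=1e-3,
                                           threshold_mode='abs')

for epoch in range(5):
    epoch_loss = 1.0 / (epoch + 1)   # placeholder for the monitored metric
    scheduler.step(epoch_loss)       # reduces the LR once the metric plateaus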
Example no. 6
    def on_epoch_end(self):
        """
        overwrite patience-based early stopping. Always run to 1000 epochs
        """
        super().on_epoch_end()
        continue_training = self.epoch < self.max_num_epochs

        # it can rarely happen that the momentum of nnUNetTrainerV2 is too high for some dataset. If at epoch 100 the
        # estimated validation Dice is still 0 then we reduce the momentum from 0.99 to 0.95
        if self.epoch == 100:
            if self.all_val_eval_metrics[-1] == 0:
                self.optimizer.param_groups[0]["momentum"] = 0.95
                self.network.apply(InitWeights_He(1e-2))
                self.print_to_log_file(
                    "At epoch 100, the mean foreground Dice was 0. This can be caused by a too "
                    "high momentum. High momentum (0.99) is good for datasets where it works, but "
                    "sometimes causes issues such as this one. Momentum has now been reduced to "
                    "0.95 and network weights have been reinitialized")
        return continue_training
Example no. 7
    def initialize_network(self):
        """
        - momentum 0.99
        - SGD instead of Adam
        - self.lr_scheduler = None because we do poly_lr
        - deep supervision = True
        - GELU nonlinearity
        - I am sure I forgot something here

        Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though
        :return:
        """
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d

        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = GeLU
        net_nonlin_kwargs = {}
        self.network = Generic_UNet(
            self.num_input_channels, self.base_num_features, self.num_classes,
            len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
            conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
            net_nonlin, net_nonlin_kwargs, True, False, lambda x: x,
            InitWeights_He(), self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes, False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper
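GeLU is used as the network nonlinearity here but is not defined in this section. A minimal sketch, assuming it is a thin module with the nonlin(**nonlin_kwargs) interface that Generic_UNet expects; torch's built-in nn.GELU would serve the same purpose:

import torch.nn as nn
import torch.nn.functional as F

class GeLU(nn.Module):
    # Hypothetical definition: plain GELU activation with no constructor
    # arguments, which matches net_nonlin_kwargs = {} above.
    def forward(self, x):
        return F.gelu(x)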
Example no. 8
    def initialize_network(self):
        if self.threeD:
            cfg = get_default_network_config(3, None, norm_type="in")

        else:
            cfg = get_default_network_config(1, None, norm_type="in")

        stage_plans = self.plans['plans_per_stage'][self.stage]
        conv_kernel_sizes = stage_plans['conv_kernel_sizes']
        blocks_per_stage_encoder = stage_plans['num_blocks_encoder']
        blocks_per_stage_decoder = stage_plans['num_blocks_decoder']
        pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes']

        self.network = FabiansUNet(self.num_input_channels,
                                   self.base_num_features,
                                   blocks_per_stage_encoder, 2,
                                   pool_op_kernel_sizes, conv_kernel_sizes,
                                   cfg, self.num_classes,
                                   blocks_per_stage_decoder, True, False, 320,
                                   InitWeights_He(1e-2))

        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper
Example no. 9
    def __init__(self,
                 input_channels,
                 base_num_features,
                 num_classes,
                 num_pool,
                 num_conv_per_stage=2,
                 feat_map_mul_on_downscale=2,
                 conv_op=nn.Conv2d,
                 norm_op=nn.BatchNorm2d,
                 norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d,
                 dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU,
                 nonlin_kwargs=None,
                 deep_supervision=True,
                 dropout_in_localization=False,
                 final_nonlin=softmax_helper,
                 weightInitializer=InitWeights_He(1e-2),
                 pool_op_kernel_sizes=None,
                 conv_kernel_sizes=None,
                 upscale_logits=False,
                 convolutional_pooling=False,
                 convolutional_upsampling=False,
                 max_num_features=None,
                 basic_block=ConvDropoutNormNonlin,
                 seg_output_use_bias=False):
        """
        basically more flexible than v1, architecture is the same

        Does this look complicated? Nah bro. Functionality > usability

        This does everything you need, including world peace.

        Questions? -> [email protected]
        """
        super(Generic_UNet, self).__init__()
        self.convolutional_upsampling = convolutional_upsampling
        self.convolutional_pooling = convolutional_pooling
        self.upscale_logits = upscale_logits
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.5, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}

        self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True}

        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.weightInitializer = weightInitializer
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.dropout_op = dropout_op
        self.num_classes = num_classes
        self.final_nonlin = final_nonlin
        self._deep_supervision = deep_supervision
        self.do_ds = deep_supervision

        if conv_op == nn.Conv2d:
            upsample_mode = 'bilinear'
            pool_op = nn.MaxPool2d
            transpconv = nn.ConvTranspose2d
            if pool_op_kernel_sizes is None:
                pool_op_kernel_sizes = [(2, 2)] * num_pool
            if conv_kernel_sizes is None:
                conv_kernel_sizes = [(3, 3)] * (num_pool + 1)
        elif conv_op == nn.Conv3d:
            upsample_mode = 'trilinear'
            pool_op = nn.MaxPool3d
            transpconv = nn.ConvTranspose3d
            if pool_op_kernel_sizes is None:
                pool_op_kernel_sizes = [(2, 2, 2)] * num_pool
            if conv_kernel_sizes is None:
                conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1)
        else:
            raise ValueError(
                "unknown convolution dimensionality, conv op: %s" %
                str(conv_op))

        # per-axis product of all pooling kernels; input patch sizes must be divisible by this
        self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes,
                                                        0,
                                                        dtype=np.int64)
        self.pool_op_kernel_sizes = pool_op_kernel_sizes
        self.conv_kernel_sizes = conv_kernel_sizes

        # 'same' padding for stride-1 convs: pad 1 where the kernel size is 3, 0 where it is 1
        self.conv_pad_sizes = []
        for krnl in self.conv_kernel_sizes:
            self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl])

        if max_num_features is None:
            if self.conv_op == nn.Conv3d:
                self.max_num_features = self.MAX_NUM_FILTERS_3D
            else:
                self.max_num_features = self.MAX_FILTERS_2D
        else:
            self.max_num_features = max_num_features

        self.conv_blocks_context = []
        self.conv_blocks_localization = []
        self.td = []
        self.tu = []
        self.seg_outputs = []

        output_features = base_num_features
        input_features = input_channels

        for d in range(num_pool):
            # determine the first stride
            if d != 0 and self.convolutional_pooling:
                first_stride = pool_op_kernel_sizes[d - 1]
            else:
                first_stride = None

            self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d]
            self.conv_kwargs['padding'] = self.conv_pad_sizes[d]
            # add convolutions
            self.conv_blocks_context.append(
                StackedConvLayers(input_features,
                                  output_features,
                                  num_conv_per_stage,
                                  self.conv_op,
                                  self.conv_kwargs,
                                  self.norm_op,
                                  self.norm_op_kwargs,
                                  self.dropout_op,
                                  self.dropout_op_kwargs,
                                  self.nonlin,
                                  self.nonlin_kwargs,
                                  first_stride,
                                  basic_block=basic_block))
            if not self.convolutional_pooling:
                self.td.append(pool_op(pool_op_kernel_sizes[d]))
            input_features = output_features
            output_features = int(
                np.round(output_features * feat_map_mul_on_downscale))

            output_features = min(output_features, self.max_num_features)

        # now the bottleneck.
        # determine the first stride
        if self.convolutional_pooling:
            first_stride = pool_op_kernel_sizes[-1]
        else:
            first_stride = None

        # the output of the last conv must match the number of features from the skip connection if we are not using
        # convolutional upsampling. If we use convolutional upsampling then the reduction in feature maps will be
        # done by the transposed conv
        if self.convolutional_upsampling:
            final_num_features = output_features
        else:
            final_num_features = self.conv_blocks_context[-1].output_channels

        self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool]
        self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool]
        self.conv_blocks_context.append(
            nn.Sequential(
                StackedConvLayers(input_features,
                                  output_features,
                                  num_conv_per_stage - 1,
                                  self.conv_op,
                                  self.conv_kwargs,
                                  self.norm_op,
                                  self.norm_op_kwargs,
                                  self.dropout_op,
                                  self.dropout_op_kwargs,
                                  self.nonlin,
                                  self.nonlin_kwargs,
                                  first_stride,
                                  basic_block=basic_block),
                StackedConvLayers(output_features,
                                  final_num_features,
                                  1,
                                  self.conv_op,
                                  self.conv_kwargs,
                                  self.norm_op,
                                  self.norm_op_kwargs,
                                  self.dropout_op,
                                  self.dropout_op_kwargs,
                                  self.nonlin,
                                  self.nonlin_kwargs,
                                  basic_block=basic_block)))

        # if we don't want to do dropout in the localization pathway then we set the dropout prob to zero here
        if not dropout_in_localization:
            old_dropout_p = self.dropout_op_kwargs['p']
            self.dropout_op_kwargs['p'] = 0.0

        # now lets build the localization pathway
        for u in range(num_pool):
            nfeatures_from_down = final_num_features
            # self.conv_blocks_context[-1] is the bottleneck, so start with -2
            nfeatures_from_skip = self.conv_blocks_context[-(2 + u)].output_channels
            n_features_after_tu_and_concat = nfeatures_from_skip * 2

            # the first conv reduces the number of features to match those of skip
            # the following convs work on that number of features
            # if not convolutional upsampling then the final conv reduces the num of features again
            if u != num_pool - 1 and not self.convolutional_upsampling:
                final_num_features = self.conv_blocks_context[-(3 + u)].output_channels
            else:
                final_num_features = nfeatures_from_skip

            if not self.convolutional_upsampling:
                self.tu.append(
                    Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)],
                             mode=upsample_mode))
            else:
                self.tu.append(
                    transpconv(nfeatures_from_down,
                               nfeatures_from_skip,
                               pool_op_kernel_sizes[-(u + 1)],
                               pool_op_kernel_sizes[-(u + 1)],
                               bias=False))

            self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[-(u + 1)]
            self.conv_kwargs['padding'] = self.conv_pad_sizes[-(u + 1)]
            self.conv_blocks_localization.append(
                nn.Sequential(
                    StackedConvLayers(n_features_after_tu_and_concat,
                                      nfeatures_from_skip,
                                      num_conv_per_stage - 1,
                                      self.conv_op,
                                      self.conv_kwargs,
                                      self.norm_op,
                                      self.norm_op_kwargs,
                                      self.dropout_op,
                                      self.dropout_op_kwargs,
                                      self.nonlin,
                                      self.nonlin_kwargs,
                                      basic_block=basic_block),
                    StackedConvLayers(nfeatures_from_skip,
                                      final_num_features,
                                      1,
                                      self.conv_op,
                                      self.conv_kwargs,
                                      self.norm_op,
                                      self.norm_op_kwargs,
                                      self.dropout_op,
                                      self.dropout_op_kwargs,
                                      self.nonlin,
                                      self.nonlin_kwargs,
                                      basic_block=basic_block)))

        # one 1x1 convolution segmentation head per decoder stage (used for the deep supervision outputs)
        for ds in range(len(self.conv_blocks_localization)):
            self.seg_outputs.append(
                conv_op(self.conv_blocks_localization[ds][-1].output_channels,
                        num_classes, 1, 1, 0, 1, 1, seg_output_use_bias))

        self.upscale_logits_ops = []
        cum_upsample = np.cumprod(np.vstack(pool_op_kernel_sizes),
                                  axis=0)[::-1]
        for usl in range(num_pool - 1):
            if self.upscale_logits:
                self.upscale_logits_ops.append(
                    Upsample(scale_factor=tuple(
                        [int(i) for i in cum_upsample[usl + 1]]),
                             mode=upsample_mode))
            else:
                self.upscale_logits_ops.append(lambda x: x)

        if not dropout_in_localization:
            self.dropout_op_kwargs['p'] = old_dropout_p

        # register all modules properly
        self.conv_blocks_localization = nn.ModuleList(
            self.conv_blocks_localization)
        self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context)
        self.td = nn.ModuleList(self.td)
        self.tu = nn.ModuleList(self.tu)
        self.seg_outputs = nn.ModuleList(self.seg_outputs)
        if self.upscale_logits:
            self.upscale_logits_ops = nn.ModuleList(
                self.upscale_logits_ops
            )  # lambda x:x is not a Module so we need to distinguish here

        if self.weightInitializer is not None:
            self.apply(self.weightInitializer)
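A hedged usage sketch of this constructor, mirroring the 3D instance-norm configuration used by the initialize_network methods above (Generic_UNet and InitWeights_He come from the surrounding code). The patch size is an assumption, and the commented forward call assumes the forward method (not shown here) returns one prediction per deep-supervision level when do_ds is True:

import torch
import torch.nn as nn

pool_kernels = [(2, 2, 2)] * 5
conv_kernels = [(3, 3, 3)] * 6

net = Generic_UNet(
    1, 32, 3,                 # input channels, base_num_features, num_classes
    len(pool_kernels), 2, 2,  # num_pool, num_conv_per_stage, feat_map_mul_on_downscale
    nn.Conv3d, nn.InstanceNorm3d, {'eps': 1e-5, 'affine': True},
    nn.Dropout3d, {'p': 0, 'inplace': True},
    nn.LeakyReLU, {'negative_slope': 1e-2, 'inplace': True},
    True, False, lambda x: x, InitWeights_He(1e-2),
    pool_kernels, conv_kernels, False, True, True)

# Spatial sizes must be divisible by net.input_shape_must_be_divisible_by
# (the per-axis product of the pooling kernels, here 32).
x = torch.zeros(1, 1, 64, 64, 64)
# outputs = net(x)  # assumed: a list of predictions, one per resolution, when deep supervision is on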
Example no. 10
    def initialize_network(self):
        """
        changed deep supervision to False
        :return:
        """
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = FRN3D

        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            # the 2D FRN configuration is not implemented; the assignment below is never reached
            raise NotImplementedError
            norm_op = nn.BatchNorm2d

        norm_op_kwargs = {'eps': 1e-6}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = Identity
        net_nonlin_kwargs = {}
        self.network = Generic_UNet(
            self.num_input_channels, self.base_num_features, self.num_classes,
            len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
            conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
            net_nonlin, net_nonlin_kwargs, True, False, lambda x: x,
            InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes, False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper
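FRN3D and Identity are used above but not defined in this section. A minimal sketch following the published Filter Response Normalization formulation (normalization plus a thresholded linear unit, which is why the network nonlinearity is set to Identity here); this is an assumption about what FRN3D does, not code from this repository:

import torch
import torch.nn as nn

class Identity(nn.Module):
    # Assumed trivial pass-through nonlinearity.
    def forward(self, x):
        return x

class FRN3D(nn.Module):
    # Filter Response Normalization + TLU for 5D tensors (N, C, D, H, W).
    def __init__(self, num_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, num_features, 1, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_features, 1, 1, 1))
        self.tau = nn.Parameter(torch.zeros(1, num_features, 1, 1, 1))

    def forward(self, x):
        nu2 = x.pow(2).mean(dim=(2, 3, 4), keepdim=True)         # mean squared activation per channel
        x = x * torch.rsqrt(nu2 + self.eps)                      # filter response normalization
        return torch.max(self.gamma * x + self.beta, self.tau)   # thresholded linear unit (TLU)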
Example no. 11
    def __init__(self, input_channels, base_num_features, num_classes, num_pool, num_conv_per_stage=2,
                 feat_map_mul_on_downscale=2, conv_op=nn.Conv2d,
                 norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None, deep_supervision=True, dropout_in_localization=False,
                 final_nonlin=softmax_helper, weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None,
                 conv_kernel_sizes=None,
                 upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False):
        """
        basically more flexible than v1, architecture is the same

        Does this look complicated? Nah bro. Functionality > usability

        This does everything you need, including world peace.

        Questions? -> [email protected]
        """
        super(Generic_UNet, self).__init__()
        # (the original Generic_UNet constructor body was commented out here;
        # it duplicates the implementation shown in Example no. 9)

        # dawn changed here
        # the model generator is replaced here in a rather crude way
        self.conv_op = conv_op
        self.num_classes = num_classes
        # deeplab

        BatchNorm = SynchronizedBatchNorm2d
        input_channels = 3
        self.preprocessor = build_preprocessor(input_channels, 3, BatchNorm)
        self.backbone = build_backbone('resnet', 16, BatchNorm)
        self.aspp = build_aspp('resnet', 16, BatchNorm)
        self.decoder = build_decoder(num_classes, 'resnet', BatchNorm)
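The forward pass for this modified constructor is not shown. A hedged sketch of how such components are typically wired in a DeepLabV3+ style model (preprocessor, backbone with low-level skip features, ASPP, decoder, final upsampling); this is an assumption based on the common DeepLabV3+ layout, not this repository's actual forward method:

import torch.nn.functional as F

def forward(self, x):
    inp = self.preprocessor(x)                 # assumed: adapts the input for the ResNet backbone
    feat, low_level_feat = self.backbone(inp)  # deep features plus early skip features
    feat = self.aspp(feat)                     # atrous spatial pyramid pooling
    feat = self.decoder(feat, low_level_feat)  # fuse with the low-level features
    # resize the logits back to the input resolution
    return F.interpolate(feat, size=x.shape[2:], mode='bilinear', align_corners=True)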