Exemplo n.º 1
0
 def setup_DA_params(self):
     """Run the nnUNetTrainerV2 augmentation setup, then clear "do_mirror".

     After the parent setup has populated ``self.data_aug_params``, the
     ``"do_mirror"`` entry is forced to False (presumably disabling mirror
     augmentation for this trainer).
     """
     nnUNetTrainerV2.setup_DA_params(self)
     aug_params = self.data_aug_params
     aug_params["do_mirror"] = False
    def setup_DA_params(self):
        """Configure the data-augmentation parameters for this trainer.

        Runs ``nnUNetTrainerV2.setup_DA_params`` first, then rebuilds
        ``self.data_aug_params`` from the module-level defaults:

        * 3D: rotation range of +/-90 degrees on x, y and z (with dummy-2D
          augmentation, elastic parameters and ``rotation_x`` are taken from
          the 2D defaults instead);
        * 2D: ``do_dummy_2D_aug`` is forced to False, and a +/-180 degree
          ``rotation_x`` range is used when the longest patch side exceeds
          1.5x the shortest;
        * scale range (0.65, 1.6) with independent per-axis scaling, elastic
          deformation, additive brightness, gamma range (0.5, 1.6) and 4
          cached batches per thread.

        Sets ``self.deep_supervision_scales``, ``self.data_aug_params`` and
        ``self.basic_generator_patch_size``.

        The default dicts are deep-copied before modification, so these
        settings no longer leak into ``default_3D_augmentation_params`` /
        ``default_2D_augmentation_params`` (and from there into any other
        trainer created in the same process).
        """
        # Function-scope import keeps this method self-contained.
        from copy import deepcopy

        nnUNetTrainerV2.setup_DA_params(self)

        # Reciprocal cumulative pooling factors per stage, prefixed with the
        # full-resolution entry [1, 1, 1]; the last (smallest-scale) entry is
        # dropped.
        self.deep_supervision_scales = [[1, 1, 1]] + [
            list(i) for i in 1 / np.cumprod(
                np.vstack(self.net_num_pool_op_kernel_sizes), axis=0)
        ][:-1]

        # +/-90 degrees expressed in radians.
        quarter_turn = (-90. / 360 * 2. * np.pi, 90. / 360 * 2. * np.pi)

        if self.threeD:
            # Private copy: mutating the shared module-level dict would change
            # the defaults seen by every later user of it.
            self.data_aug_params = deepcopy(default_3D_augmentation_params)
            self.data_aug_params['rotation_x'] = quarter_turn
            self.data_aug_params['rotation_y'] = quarter_turn
            self.data_aug_params['rotation_z'] = quarter_turn
            if self.do_dummy_2D_aug:
                self.data_aug_params["dummy_2D"] = True
                self.print_to_log_file("Using dummy2d data augmentation")
                # Take these settings from the 2D defaults instead.
                for key in ("elastic_deform_alpha", "elastic_deform_sigma",
                            "rotation_x"):
                    self.data_aug_params[key] = \
                        default_2D_augmentation_params[key]
        else:
            self.do_dummy_2D_aug = False
            self.data_aug_params = deepcopy(default_2D_augmentation_params)
            # Elongated patches (longest side > 1.5x shortest) get a full
            # +/-180 degree rotation range; set on the copy only, the shared
            # default is left untouched.
            if max(self.patch_size) / min(self.patch_size) > 1.5:
                self.data_aug_params['rotation_x'] = (-180. / 360 * 2. * np.pi,
                                                      180. / 360 * 2. * np.pi)
        self.data_aug_params[
            "mask_was_used_for_normalization"] = self.use_mask_for_norm

        # With dummy-2D augmentation only the trailing axes (patch_size[1:])
        # go through the spatial transform; the first axis keeps its size.
        if self.do_dummy_2D_aug:
            patch_size_for_spatialtransform = self.patch_size[1:]
        else:
            patch_size_for_spatialtransform = self.patch_size

        # NOTE(review): get_patch_size reads 'scale_range' as it is at this
        # point (the copied default, possibly as modified by the parent call
        # above). The (0.65, 1.6) override below is applied only afterwards,
        # so basic_generator_patch_size does not account for it. Confirm this
        # is intended.
        self.basic_generator_patch_size = get_patch_size(
            patch_size_for_spatialtransform,
            self.data_aug_params['rotation_x'],
            self.data_aug_params['rotation_y'],
            self.data_aug_params['rotation_z'],
            self.data_aug_params['scale_range'])
        if self.do_dummy_2D_aug:
            # Re-attach the untouched first axis in front of the enlarged ones.
            self.basic_generator_patch_size = np.array(
                [self.patch_size[0]] + list(self.basic_generator_patch_size))

        # Remaining overrides; insertion order matches the original sequence
        # of assignments.
        self.data_aug_params.update({
            'selected_seg_channels': [0],
            'patch_size_for_spatialtransform': patch_size_for_spatialtransform,
            # rotation
            "p_rot": 0.3,
            # scaling
            "scale_range": (0.65, 1.6),
            "p_scale": 0.3,
            "independent_scale_factor_for_each_axis": True,
            "p_independent_scale_per_axis": 0.3,
            # elastic deformation
            "do_elastic": True,
            "p_eldef": 0.2,
            "eldef_deformation_scale": (0, 0.25),
            # additive brightness
            "do_additive_brightness": True,
            "additive_brightness_mu": 0,
            "additive_brightness_sigma": 0.2,
            "additive_brightness_p_per_sample": 0.3,
            "additive_brightness_p_per_channel": 0.5,
            # gamma
            'gamma_range': (0.5, 1.6),
            # data loading
            'num_cached_per_thread': 4,
        })