    def update_loss(self):
        # train the first 500 epochs with CE only, linearly transition from CE to Dice
        # between epochs 500 and 750, and use Dice only from epoch 750 onward

        if self.epoch <= 500:
            weight_ce = 2
            weight_dice = 0
        elif 500 < self.epoch <= 750:
            weight_ce = 2 - 2 / 250 * (self.epoch - 500)
            weight_dice = 0 + 2 / 250 * (self.epoch - 500)
        elif 750 < self.epoch <= self.max_num_epochs:
            weight_ce = 0
            weight_dice = 2
        else:
            raise RuntimeError("Invalid epoch: %d" % self.epoch)

        self.print_to_log_file("weight ce", weight_ce, "weight dice",
                               weight_dice)

        self.loss = DC_and_CE_loss(
            {
                'batch_dice': self.batch_dice,
                'smooth': 1e-5,
                'do_bg': False
            }, {},
            weight_ce=weight_ce,
            weight_dice=weight_dice)

        self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
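# --- Standalone sketch (not part of the trainer) ---------------------------------
# To make the CE -> Dice schedule above concrete, the helper below reproduces just
# the weight arithmetic from update_loss and prints the weights for a few sample
# epochs. The default max_num_epochs of 1000 is a hypothetical value used only for
# illustration; the real value comes from the trainer configuration.
def ce_dice_weights(epoch, max_num_epochs=1000):
    if epoch <= 500:
        return 2.0, 0.0                      # CE only
    elif epoch <= 750:
        ramp = 2 / 250 * (epoch - 500)       # linear ramp over 250 epochs
        return 2.0 - ramp, ramp
    elif epoch <= max_num_epochs:
        return 0.0, 2.0                      # Dice only
    raise RuntimeError("Invalid epoch: %d" % epoch)


for e in (0, 500, 625, 750, 1000):
    print(e, ce_dice_weights(e))
# 0    (2.0, 0.0)
# 500  (2.0, 0.0)
# 625  (1.0, 1.0)
# 750  (0.0, 2.0)
# 1000 (0.0, 2.0)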
Example #2
    def initialize(self, training=True, force_load_plans=False):
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            ################# Here we wrap the loss for deep supervision ############
            # we need to know the number of outputs of the network
            net_numpool = len(self.net_num_pool_op_kernel_sizes)

            # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
            # this gives higher resolution outputs more weight in the loss
            weights = np.array([1 / (2**i) for i in range(net_numpool)])

            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
            # (the standalone sketch after this method works these values out for a concrete case)
            mask = np.array([True] + [i < net_numpool - 1 for i in range(1, net_numpool)])
            weights[~mask] = 0
            weights = weights / weights.sum()
            self.ds_loss_weights = weights
            # now wrap the loss
            self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
            ################# END ###################

            self.folder_with_preprocessed_data = join(
                self.dataset_directory,
                self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow as a result. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_insaneDA_augmentation2(
                    self.dl_tr,
                    self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" %
                                       (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" %
                                       (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            assert isinstance(self.network,
                              (SegmentationNetwork, nn.DataParallel))
        else:
            self.print_to_log_file(
                'self.was_initialized is True, not running self.initialize again'
            )
        self.was_initialized = True
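# --- Standalone sketch (not part of the trainer) ---------------------------------
# Concrete values for the deep supervision weighting above, assuming a hypothetical
# net_numpool of 5. Only the weight arithmetic from initialize is reproduced here.
import numpy as np

net_numpool = 5                                                  # hypothetical
weights = np.array([1 / (2 ** i) for i in range(net_numpool)])   # [1, 0.5, 0.25, 0.125, 0.0625]
mask = np.array([True] + [i < net_numpool - 1 for i in range(1, net_numpool)])
weights[~mask] = 0           # drop the lowest resolution output: [1, 0.5, 0.25, 0.125, 0]
weights = weights / weights.sum()
print(weights)               # approx. [0.533, 0.267, 0.133, 0.067, 0.]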
    def initialize(self, training=True, force_load_plans=False):
        if not self.was_initialized:
            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)
            self.setup_DA_params()

            ################# Here we wrap the loss for deep supervision ############
            # we need to know the number of outputs of the network
            net_numpool = len(self.net_num_pool_op_kernel_sizes)
            # each output gets a weight that halves as the resolution decreases,
            # giving higher resolution outputs more weight in the loss
            weights = np.array([1 / (2**i) for i in range(net_numpool)])
            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
            mask = np.array([i < net_numpool - 1 for i in range(net_numpool)])
            weights[~mask] = 0
            weights = weights / weights.sum()
            self.ds_loss_weights = weights
            # now wrap the loss (a sketch of this kind of weighted wrapper follows after this method)
            self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
            ################# END ###################

            self.folder_with_preprocessed_data = join(
                self.dataset_directory,
                self.plans['data_identifier'] + "_stage%d" % self.stage)

            if training:
                if not isdir(self.folder_with_segs_from_prev_stage):
                    raise RuntimeError(
                        "Cannot run final stage of cascade. Run corresponding 3d_lowres first and predict the "
                        "segmentations for the next stage")

                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow as a result. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr,
                    self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" %
                                       (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" %
                                       (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            assert isinstance(self.network,
                              (SegmentationNetwork, nn.DataParallel))
        else:
            self.print_to_log_file(
                'self.was_initialized is True, not running self.initialize again'
            )

        self.was_initialized = True
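
# --- Standalone sketch (not part of the trainer) ---------------------------------
# MultipleOutputLoss2 comes from nnU-Net. As a rough illustration of the idea, a
# weighted multi-output wrapper can look like the hypothetical class below: the
# wrapped loss is evaluated on every deep supervision output and combined with the
# ds_loss_weights computed above. This is a sketch, not the nnU-Net implementation.
from torch import nn


class WeightedMultiOutputLoss(nn.Module):        # hypothetical name, for illustration
    def __init__(self, loss, weight_factors):
        super().__init__()
        self.loss = loss                         # e.g. a DC_and_CE_loss instance
        self.weight_factors = weight_factors     # e.g. the ds_loss_weights above

    def forward(self, outputs, targets):
        # outputs and targets are lists with one entry per deep supervision resolution
        total = 0.0
        for w, out, tgt in zip(self.weight_factors, outputs, targets):
            if w != 0:
                total = total + w * self.loss(out, tgt)
        return total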