예제 #1
0
    def _train_single_epoch(self, batchgen: MultiThreadedAugmenter, epoch):
        """
        Trains the network a single epoch

        Every batch yielded by ``batchgen`` is passed, together with its
        index, to ``self.closure_fn`` (which also receives the trainer's
        optimizers, losses, metrics and fold). Afterwards the batch generator
        is shut down via ``batchgen._finish()`` -- this happens even if
        training raises, so its workers are not left running.

        Parameters
        ----------
        batchgen : MultiThreadedAugmenter
            Generator yielding the training batches
        epoch : int
            current epoch

        """
        # NOTE(review): this only sets the flag on the top-level module object;
        # for a torch ``nn.Module``, ``self.module.train()`` would also
        # propagate it to submodules -- confirm which backends this supports.
        self.module.training = True

        # Progress-bar total: batches per worker generator times the number
        # of worker processes.
        n_batches = batchgen.generator.num_batches * batchgen.num_processes

        try:
            with tqdm(enumerate(batchgen),
                      unit=' batch',
                      total=n_batches,
                      desc='Epoch %d' % epoch) as pbar:
                for batch_nr, batch in pbar:
                    # closure_fn returns three values that are not needed here
                    _, _, _ = self.closure_fn(self.module,
                                              batch,
                                              optimizers=self.optimizers,
                                              losses=self.losses,
                                              metrics=self.metrics,
                                              fold=self.fold,
                                              batch_nr=batch_nr)
        finally:
            # Always release the augmenter's background workers; previously an
            # exception inside closure_fn skipped this call entirely.
            batchgen._finish()
예제 #2
0
        np_prediction[np_prediction > 0] = 1 # tumor core
        np_prediction[np_prediction < 0] = 0

    np_cut = center_crop_3D_image(np_prediction[0,0], patient_data.shape[2:])

    # if args.multi_class:
    #    dice = np_dice_multi_class(np_cut, patient_data[0,3,:,:,:])
    # else:
    #    dice = np_dice(np_cut, patient_data[0,3,:,:,:])
    # logging.info("{}, {}".format(idx, dice))
    # dices.append(dice)

    # repair labels
    np_cut[np_cut == 3] = 4
    print(f"seg output shape is {np_cut.shape}")

    output_path = '/'.join(target_patients[idx].split('/')[-2:])
    output_path = os.path.join('augmented_segmentation_output', args.model_name, output_path + '.nii.gz')

    print(f"The output path is {output_path}")
    if not os.path.exists(os.path.dirname(output_path)):
        try:
            os.makedirs(os.path.dirname(output_path))
        except OSError as exc: # Guard against race condition
            logging.info('An error occured when trying to create the saving directory!')

    save_segmentation_as_nifti(np_cut, meta_data, output_path)

    test_gen_new._finish()
    del test_dl_new
    continue