Example #1
def predict_next_stage(trainer, stage_to_be_predicted_folder):
    output_folder = join(pardir(trainer.output_folder), "pred_next_stage")
    maybe_mkdir_p(output_folder)

    process_manager = Pool(2)
    results = []

    for pat in trainer.dataset_val.keys():
        print(pat)
        data_file = trainer.dataset_val[pat]['data_file']
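        # the preprocessed array stacks the image channels with the segmentation as the last
        # channel; [:-1] keeps only the image channels for prediction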
        data_preprocessed = np.load(data_file)['data'][:-1]
        predicted = trainer.predict_preprocessed_data_return_softmax(
            data_preprocessed, True, 1, False, 1,
            trainer.data_aug_params['mirror_axes'], True, True, 2,
            trainer.patch_size, True)
        data_file_nofolder = data_file.split("/")[-1]
        data_file_nextstage = join(stage_to_be_predicted_folder,
                                   data_file_nofolder)
        data_nextstage = np.load(data_file_nextstage)['data']
        target_shp = data_nextstage.shape[1:]
        output_file = join(
            output_folder,
            data_file_nextstage.split("/")[-1][:-4] + "_segFromPrevStage.npz")
        results.append(
            process_manager.starmap_async(
                resample_and_save, [(predicted, target_shp, output_file)]))

    _ = [i.get() for i in results]
    process_manager.close()
    process_manager.join()
Example #2
    def initialize(self, training=True, force_load_plans=False):
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            ################# Here we wrap the loss for deep supervision ############
            # we need to know the number of outputs of the network
            net_numpool = len(self.net_num_pool_op_kernel_sizes)

            # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
            # this gives higher resolution outputs more weight in the loss
            weights = np.array([1 / (2 ** i) for i in range(net_numpool)])

            # the lowest-resolution output is not used (its weight is set to 0). Normalize weights so that they sum to 1
            mask = np.array([True if i < net_numpool - 1 else False for i in range(net_numpool)])
            weights[~mask] = 0
            weights = weights / weights.sum()
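            # example: with net_numpool = 5 the raw weights are [1, 0.5, 0.25, 0.125, 0.0625];
            # zeroing the last entry and normalizing gives roughly [0.533, 0.267, 0.133, 0.067, 0]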

            # now wrap the loss
            self.loss = MultipleOutputLoss2(self.loss, weights)
            ################# END ###################

            self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                      "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_no_augmentation(self.dl_tr, self.dl_val,
                                                                    self.data_aug_params[
                                                                        'patch_size_for_spatialtransform'],
                                                                    self.data_aug_params,
                                                                    deep_supervision_scales=self.deep_supervision_scales)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True
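
For orientation, the MultipleOutputLoss2 wrapper used above applies the base loss at every deep-supervision output and combines the results with the weights computed in initialize(). The snippet below is a minimal illustrative sketch, not the nnU-Net implementation; the class name and the list-based call convention are assumptions.

from torch import nn


class WeightedDeepSupervisionLoss(nn.Module):
    """Weighted sum of a base loss applied at every deep-supervision resolution (sketch)."""

    def __init__(self, base_loss, weights):
        super().__init__()
        self.base_loss = base_loss
        self.weights = weights

    def forward(self, outputs, targets):
        # outputs and targets are lists ordered from highest to lowest resolution
        total = self.weights[0] * self.base_loss(outputs[0], targets[0])
        for w, x, y in zip(self.weights[1:], outputs[1:], targets[1:]):
            if w != 0:
                total = total + w * self.base_loss(x, y)
        return total
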
Example #3
def split_4d(input_folder,
             num_processes=default_num_threads,
             overwrite_task_output_id=None):
    assert isdir(join(input_folder, "imagesTr")) and isdir(join(input_folder, "labelsTr")) and \
           isfile(join(input_folder, "dataset.json")), \
        "The input folder must be a valid Task folder from the Medical Segmentation Decathlon with at least the " \
        "imagesTr and labelsTr subfolders and the dataset.json file"

    while input_folder.endswith("/") or input_folder.endswith("\\"):
        input_folder = input_folder[:-1]

    full_task_name = Path(input_folder).parts[-1]

    assert full_task_name.startswith(
        "Task"
    ), "The input folder must point to a folder that starts with TaskXX_"

    first_underscore = full_task_name.find("_")
    assert first_underscore == 6, "Input folder must start with TaskXX_ where XX is a two-digit id: 00, 01, 02 etc"

    input_task_id = int(full_task_name[4:6])
    if overwrite_task_output_id is None:
        overwrite_task_output_id = input_task_id

    task_name = full_task_name[7:]

    output_folder = join(nnUNet_raw_data,
                         "Task%03.0d_" % overwrite_task_output_id + task_name)

    if isdir(output_folder):
        shutil.rmtree(output_folder)

    files = []
    output_dirs = []

    maybe_mkdir_p(output_folder)
    for subdir in ["imagesTr", "imagesTs"]:
        curr_out_dir = join(output_folder, subdir)
        if not isdir(curr_out_dir):
            maybe_mkdir_p(curr_out_dir)
        curr_dir = join(input_folder, subdir)
        nii_files = [
            join(curr_dir, i) for i in os.listdir(curr_dir)
            if i.endswith(".nii.gz")
        ]
        nii_files.sort()
        for n in nii_files:
            files.append(n)
            output_dirs.append(curr_out_dir)

    shutil.copytree(join(input_folder, "labelsTr"),
                    join(output_folder, "labelsTr"))

    p = Pool(num_processes)
    p.starmap(split_4d_nifti, zip(files, output_dirs))
    p.close()
    p.join()
    shutil.copy(join(input_folder, "dataset.json"), output_folder)
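
A hedged usage sketch (the input path is hypothetical; the folder must follow the TaskXX_Name layout checked above):

# hypothetical Medical Segmentation Decathlon task folder
split_4d("/data/MSD/Task02_Heart", num_processes=4, overwrite_task_output_id=2)
# labelsTr and dataset.json are copied over, each 4D image is split into 3D volumes,
# and the result ends up under nnUNet_raw_data in a Task%03.0d_<name> folder (e.g. Task002_Heart)
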
Example #4
def plan_and_preprocess(task_string,
                        processes_lowres=default_num_threads,
                        processes_fullres=3,
                        no_preprocessing=False):
    from nnunet.experiment_planning.experiment_planner_baseline_2DUNet import ExperimentPlanner2D
    from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner

    preprocessing_output_dir_this_task_train = join(preprocessing_output_dir,
                                                    task_string)
    cropped_out_dir = join(nnUNet_cropped_data, task_string)
    maybe_mkdir_p(preprocessing_output_dir_this_task_train)

    shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"),
                preprocessing_output_dir_this_task_train)
    shutil.copy(join(nnUNet_raw_data, task_string, "dataset.json"),
                preprocessing_output_dir_this_task_train)

    exp_planner = ExperimentPlanner(cropped_out_dir,
                                    preprocessing_output_dir_this_task_train)
    exp_planner.plan_experiment()
    if not no_preprocessing:
        exp_planner.run_preprocessing((processes_lowres, processes_fullres))

    exp_planner = ExperimentPlanner2D(
        cropped_out_dir, preprocessing_output_dir_this_task_train)
    exp_planner.plan_experiment()
    if not no_preprocessing:
        exp_planner.run_preprocessing(processes_fullres)

    # write which class is in which slice to all training cases (required to speed up 2D Dataloader)
    # This is done for all data so that if we wanted to use them with 2D we could do so

    if not no_preprocessing:
        p = Pool(default_num_threads)

        # if there is more than one my_data_identifier (different branches) then this code will run for all of them if
        # they start with the same string. Not problematic, but not pretty
        stages = [
            i for i in subdirs(
                preprocessing_output_dir_this_task_train, join=True, sort=True)
            if Path(i).parts[-1].find("stage") != -1
        ]
        for s in stages:
            print(Path(s).parts[-1])
            list_of_npz_files = subfiles(s, True, None, ".npz", True)
            list_of_pkl_files = [i[:-4] + ".pkl" for i in list_of_npz_files]
            all_classes = []
            for pk in list_of_pkl_files:
                with open(pk, 'rb') as f:
                    props = pickle.load(f)
                all_classes_tmp = np.array(props['classes'])
                all_classes.append(all_classes_tmp[all_classes_tmp >= 0])
            p.map(add_classes_in_slice_info,
                  zip(list_of_npz_files, list_of_pkl_files, all_classes))
        p.close()
        p.join()
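
A hedged invocation sketch (the task string is hypothetical but follows the TaskXXX_Name convention used in the other examples):

# hypothetical: plan the 3D and 2D experiments for an already-cropped task and run preprocessing
plan_and_preprocess("Task002_Heart", processes_lowres=8, processes_fullres=3, no_preprocessing=False)
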
Example #5
    def run_inference_and_store_results(self,output_file_tag=''):
        output_file_base_name = output_file_tag + "_nnunet_seg.nii.gz"
        
        # passing only lists of length one to predict_cases
        for inner_list in self.data.inference_loader:
            list_of_lists = [inner_list]
            
            # output filenames (list of one) include information about patient folder name
            # inferring patient folder name from all file paths for a sanity check
            # (should give the same answer)
            folder_names = [fpath.split('/')[-2] for fpath in inner_list]
            if set(folder_names) != set(folder_names[:1]):
                raise RuntimeError('Patient file paths {} unexpectedly come from different folders.'.format(inner_list))
            patient_folder_name = folder_names[0]
            output_filename = patient_folder_name + output_file_base_name
            
            final_out_folder = join(self.intermediate_out_folder, patient_folder_name)

            intermediate_output_folders = []
            
            for model_name, folds in zip(self.model_list, self.folds_list):
                output_model = join(self.intermediate_out_folder, model_name)
                intermediate_output_folders.append(output_model)
                intermediate_output_filepaths = [join(output_model, output_filename)]
                maybe_mkdir_p(output_model)
                params_folder_model = join(self.params_folder, model_name)
                
                predict_cases(model=params_folder_model, 
                            list_of_lists=list_of_lists, 
                            output_filenames=intermediate_output_filepaths, 
                            folds=folds, 
                            save_npz=True, 
                            num_threads_preprocessing=1, 
                            num_threads_nifti_save=1, 
                            segs_from_prev_stage=None, 
                            do_tta=True, 
                            mixed_precision=True,
                            overwrite_existing=True, 
                            all_in_gpu=False, 
                            step_size=0.5)

            merge(folders=intermediate_output_folders, 
                output_folder=final_out_folder, 
                threads=1, 
                override=True, 
                postprocessing_file=None, 
                store_npz=False)

            f = join(final_out_folder, output_filename)
            apply_brats_threshold(f, f, self.threshold, self.replace_with)
            load_convert_save(f)

        _ = [shutil.rmtree(i) for i in intermediate_output_folders]
Example #6
def predict_next_stage(trainer, stage_to_be_predicted_folder):
    output_folder = join(pardir(trainer.output_folder), "pred_next_stage")
    maybe_mkdir_p(output_folder)

    if 'segmentation_export_params' in trainer.plans.keys():
        force_separate_z = trainer.plans['segmentation_export_params'][
            'force_separate_z']
        interpolation_order = trainer.plans['segmentation_export_params'][
            'interpolation_order']
        interpolation_order_z = trainer.plans['segmentation_export_params'][
            'interpolation_order_z']
    else:
        force_separate_z = None
        interpolation_order = 1
        interpolation_order_z = 0

    export_pool = Pool(2)
    results = []

    for pat in trainer.dataset_val.keys():
        print(pat)
        data_file = trainer.dataset_val[pat]['data_file']
        data_preprocessed = np.load(data_file)['data'][:-1]

        predicted_probabilities = trainer.predict_preprocessed_data_return_seg_and_softmax(
            data_preprocessed,
            do_mirroring=trainer.data_aug_params["do_mirror"],
            mirror_axes=trainer.data_aug_params['mirror_axes'],
            mixed_precision=trainer.fp16)[1]

        data_file_nofolder = data_file.split("/")[-1]
        data_file_nextstage = join(stage_to_be_predicted_folder,
                                   data_file_nofolder)
        data_nextstage = np.load(data_file_nextstage)['data']
        target_shp = data_nextstage.shape[1:]
        output_file = join(
            output_folder,
            data_file_nextstage.split("/")[-1][:-4] + "_segFromPrevStage.npz")

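        # ~2e9 bytes is the practical pickling limit for sending arrays between processes;
        # dividing by 4 (bytes per float32) gives an element count and 0.85 adds a safety margin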
        if np.prod(predicted_probabilities.shape) > (
                2e9 / 4 * 0.85):  # *0.85 just to be safe
            np.save(output_file[:-4] + ".npy", predicted_probabilities)
            predicted_probabilities = output_file[:-4] + ".npy"

        results.append(
            export_pool.starmap_async(resample_and_save, [
                (predicted_probabilities, target_shp, output_file,
                 force_separate_z, interpolation_order, interpolation_order_z)
            ]))

    _ = [i.get() for i in results]
    export_pool.close()
    export_pool.join()
Example #7
def crop(task_string, override=False, num_threads=default_num_threads):
    cropped_out_dir = join(nnUNet_cropped_data, task_string)
    maybe_mkdir_p(cropped_out_dir)

    if override and isdir(cropped_out_dir):
        shutil.rmtree(cropped_out_dir)
        maybe_mkdir_p(cropped_out_dir)

    splitted_4d_output_dir_task = join(nnUNet_raw_data, task_string)
    lists, _ = create_lists_from_splitted_dataset(splitted_4d_output_dir_task)

    imgcrop = ImageCropper(num_threads, cropped_out_dir)
    imgcrop.run_cropping(lists, overwrite_existing=override)
    shutil.copy(join(nnUNet_raw_data, task_string, "dataset.json"), cropped_out_dir)
Example #8
    def initialize(self, training=True, force_load_plans=False):
        """
        For prediction of test cases just set training=False, this will prevent loading of training data and
        training batchgenerator initialization
        :param training:
        :return:
        """

        maybe_mkdir_p(self.output_folder)

        if force_load_plans or (self.plans is None):
            self.load_plans_file()

        self.process_plans(self.plans)

        self.setup_DA_params()

        self.folder_with_preprocessed_data = join(
            self.dataset_directory,
            self.plans['data_identifier'] + "_stage%d" % self.stage)
        if training:
            self.dl_tr, self.dl_val = self.get_basic_generators()
            if self.unpack_data:
                print("unpacking dataset")
                unpack_dataset(self.folder_with_preprocessed_data)
                print("done")
            else:
                print(
                    "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                    "will wait all winter for your model to finish!")
            self.tr_gen, self.val_gen = get_no_augmentation(
                self.dl_tr, self.dl_val,
                self.data_aug_params['patch_size_for_spatialtransform'],
                self.data_aug_params)
            self.print_to_log_file("TRAINING KEYS:\n %s" %
                                   (str(self.dataset_tr.keys())),
                                   also_print_to_console=False)
            self.print_to_log_file("VALIDATION KEYS:\n %s" %
                                   (str(self.dataset_val.keys())),
                                   also_print_to_console=False)
        else:
            pass
        self.initialize_network()
        assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
        self.was_initialized = True
        self.data_aug_params['mirror_axes'] = ()
Example #9
def predict_next_stage(trainer,
                       stage_to_be_predicted_folder,
                       force_separate_z=False,
                       interpolation_order=1,
                       interpolation_order_z=0):
    output_folder = join(pardir(trainer.output_folder), "pred_next_stage")
    maybe_mkdir_p(output_folder)

    export_pool = Pool(2)
    results = []

    for pat in trainer.dataset_val.keys():
        print(pat)
        data_file = trainer.dataset_val[pat]['data_file']
        data_preprocessed = np.load(data_file)['data'][:-1]
        predicted = trainer.predict_preprocessed_data_return_softmax(
            data_preprocessed, True, 1, False, 1,
            trainer.data_aug_params['mirror_axes'], True, True, 2,
            trainer.patch_size, True)
        data_file_nofolder = data_file.split("/")[-1]
        data_file_nextstage = join(stage_to_be_predicted_folder,
                                   data_file_nofolder)
        data_nextstage = np.load(data_file_nextstage)['data']
        target_shp = data_nextstage.shape[1:]
        output_file = join(
            output_folder,
            data_file_nextstage.split("/")[-1][:-4] + "_segFromPrevStage.npz")

        if np.prod(
                predicted.shape) > (2e9 / 4 * 0.85):  # *0.85 just to be safe
            np.save(output_file[:-4] + ".npy", predicted)
            predicted = output_file[:-4] + ".npy"

        results.append(
            export_pool.starmap_async(
                resample_and_save,
                [(predicted, target_shp, output_file, force_separate_z,
                  interpolation_order, interpolation_order_z)]))

    _ = [i.get() for i in results]
    export_pool.close()
    export_pool.join()
Example #10
def testing(unet, test_loader, batch_size, device, output_dir):

    start = time()
    start_eval_test = 0
    idx = 0

    maybe_mkdir_p(os.path.join(output_dir, 'images'))
    maybe_mkdir_p(os.path.join(output_dir, 'preds'))
    maybe_mkdir_p(os.path.join(output_dir, 'labels'))

    for batch in test_loader:

        image, label = batch

        pred = unet(image.to(device))

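        # the prediction and the label differ in spatial size, so crop the centred,
        # label-sized region from the prediction before taking the argmax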
        pad = int((pred.shape[-1] - label.shape[-1]) / 2)
        pred = pred[:, :, pad:label.shape[-1] + pad,
                    pad:label.shape[-1] + pad].argmax(dim=1)

        save_image(
            image[0, 0, pad:label.shape[-1] + pad, pad:label.shape[-1] + pad],
            os.path.join(output_dir, 'images', f'image{idx}.tif'))
        save_image(label[0, 0, :, :].float(),
                   os.path.join(output_dir, 'labels', f'label{idx}.tif'))
        save_image(pred[0, :, :].float(),
                   os.path.join(output_dir, 'preds', f'pred{idx}.tif'))

        idx += 1

        if start_eval_test == 0:
            test_eval = evaluation_metrics(pred[0, :, :].detach(),
                                           label[0, 0, :, :].detach())
            start_eval_test += 1
        else:
            test_eval = np.concatenate(
                (test_eval,
                 evaluation_metrics(pred[0, :, :].detach(),
                                    label[0, 0, :, :].detach())),
                axis=1)

    test = np.mean(test_eval, axis=1)
    test_std = np.std(test_eval, axis=1)

    test_iou = [test[0], test_std[0]]
    test_pe = [test[1], test_std[1]]

    np.savetxt(os.path.join(output_dir, 'test_iou.out'), test_iou)
    np.savetxt(os.path.join(output_dir, 'test_pe.out'), test_pe)

    print('Mean IoU testing:', "{:.6f}".format(test[0]))
    print('Mean PE testing :', "{:.6f}".format(test[1]))
    print('Testing took    :', "{:.6f}".format(time() - start), 's')
    print('                                                     ')

    print('Testing is finished')
Example #11
    samp_tr = int(np.round(tr_per * len(train_dataset)))
    samp_val = int(np.round(val_per * len(train_dataset)))

    # Round numbers so that we do not exceed total number of samples
    while samp_tr + samp_val > len(train_dataset):
        samp_val -= 1

    # Generate an order vector to shuffle the samples before each fold for the cross validation
    np.random.seed(SEED)
    order = np.arange(len(train_dataset))
    np.random.shuffle(order)

    if FOLDS is None:  # Training with all available samples
        # Make directory where we save model and data
        all_dir = os.path.join(CUR_DIR, 'models', f'{DATASET}', 'all')
        maybe_mkdir_p(all_dir)

        val_dataset = ImageDataset_test(root_dir, ISBI2012=ISBI2012)

        # Shuffle and load the training and validation sets
        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True)
        val_loader = DataLoader(val_dataset,
                                batch_size=batch_size,
                                shuffle=True)

        torch.cuda.empty_cache()
        unet = Unet().to(device)
        if START_FROM == -1:  # load latest model
            epoch_id = max([
Example #12
# raw_data_base_dir = "/data0/mzs/zhx/lung_lobe_seg/galaNet_raw_data" # 原始数据保存文件夹
# preprocessed_output_dir = "/data0/mzs/zhx/lung_lobe_seg/galaNet_preprocessed" # 预处理后数据存放处
# network_output_dir_base = "/data0/mzs/zhx/lung_lobe_seg/galaNet_trained_models" # 网络存放处

# raw_data_base_dir = "/home/zenghexiang/data/zenghexiang/lung_lobe_seg/galaNet_raw_data" # 原始数据保存文件夹
# preprocessed_output_dir = "/home/zenghexiang/data/zenghexiang/lung_lobe_seg/galaNet_preprocessed" # 预处理后数据存放处
# network_output_dir_base = "/home/zenghexiang/data/zenghexiang/lung_lobe_seg/galaNet_trained_models" # 网络存放处

if raw_data_base_dir is not None:
    raw_dicom_data_dir = join(raw_data_base_dir,
                              "dicom_data")  # folder for the raw DICOM data
    raw_cropped_data_dir = join(raw_data_base_dir,
                                "cropped_data")  # folder for the cropped raw data
    raw_splited_dir = join(raw_data_base_dir, "splited_data")
    maybe_mkdir_p(raw_data_base_dir)
    maybe_mkdir_p(raw_cropped_data_dir)
else:
    raise AssertionError(
        "Attention! raw_data_base_dir is not defined! Please set raw_data_base_dir in paths.py."
    )

if preprocessed_output_dir is not None:
    maybe_mkdir_p(preprocessed_output_dir)
    maybe_mkdir_p(join(preprocessed_output_dir, preprocessed_data_identifer))
    maybe_mkdir_p(join(preprocessed_output_dir, preprocessed_net_inputs))
else:
    raise AssertionError("Attention! preprocessed_output_dir is not defined! "
                         "Please set preprocessed_output_dir in paths.py.")

if network_output_dir_base is not None:
Example #13
    def validate(self,
                 do_mirroring: bool = True,
                 use_sliding_window: bool = True,
                 step_size: float = 0.5,
                 save_softmax: bool = True,
                 use_gaussian: bool = True,
                 overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw',
                 debug: bool = False,
                 all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None,
                 run_postprocessing_on_folds: bool = True):
        if isinstance(self.network, DDP):
            net = self.network.module
        else:
            net = self.network
        ds = net.do_ds
        net.do_ds = False

        current_mode = self.network.training
        self.network.eval()

        assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
        if self.dataset_val is None:
            self.load_dataset()
            self.do_split()

        if segmentation_export_kwargs is None:
            if 'segmentation_export_params' in self.plans.keys():
                force_separate_z = self.plans['segmentation_export_params'][
                    'force_separate_z']
                interpolation_order = self.plans['segmentation_export_params'][
                    'interpolation_order']
                interpolation_order_z = self.plans[
                    'segmentation_export_params']['interpolation_order_z']
            else:
                force_separate_z = None
                interpolation_order = 1
                interpolation_order_z = 0
        else:
            force_separate_z = segmentation_export_kwargs['force_separate_z']
            interpolation_order = segmentation_export_kwargs[
                'interpolation_order']
            interpolation_order_z = segmentation_export_kwargs[
                'interpolation_order_z']

        # predictions as they come from the network go here
        output_folder = join(self.output_folder, validation_folder_name)
        maybe_mkdir_p(output_folder)
        # this is for debug purposes
        my_input_args = {
            'do_mirroring': do_mirroring,
            'use_sliding_window': use_sliding_window,
            'step_size': step_size,
            'save_softmax': save_softmax,
            'use_gaussian': use_gaussian,
            'overwrite': overwrite,
            'validation_folder_name': validation_folder_name,
            'debug': debug,
            'all_in_gpu': all_in_gpu,
            'segmentation_export_kwargs': segmentation_export_kwargs,
        }
        save_json(my_input_args, join(output_folder, "validation_args.json"))

        if do_mirroring:
            if not self.data_aug_params['do_mirror']:
                raise RuntimeError(
                    "We did not train with mirroring so you cannot do inference with mirroring enabled"
                )
            mirror_axes = self.data_aug_params['mirror_axes']
        else:
            mirror_axes = ()

        pred_gt_tuples = []

        export_pool = Pool(default_num_threads)
        results = []

        all_keys = list(self.dataset_val.keys())
        my_keys = all_keys[self.local_rank::dist.get_world_size()]
        # we cannot simply iterate over all_keys because we need to know pred_gt_tuples and valid_labels of all cases
        # for evaluation (which is done by local rank 0)
        for k in my_keys:
            properties = load_pickle(self.dataset[k]['properties_file'])
            fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
            pred_gt_tuples.append([
                join(output_folder, fname + ".nii.gz"),
                join(self.gt_niftis_folder, fname + ".nii.gz")
            ])
            if k in my_keys:
                if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
                        (save_softmax and not isfile(join(output_folder, fname + ".npz"))):
                    data = np.load(self.dataset[k]['data_file'])['data']

                    print(k, data.shape)
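                    # the last channel stores the segmentation; -1 marks voxels outside the
                    # cropped nonzero region and is mapped back to background before prediction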
                    data[-1][data[-1] == -1] = 0

                    softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(
                        data[:-1],
                        do_mirroring=do_mirroring,
                        mirror_axes=mirror_axes,
                        use_sliding_window=use_sliding_window,
                        step_size=step_size,
                        use_gaussian=use_gaussian,
                        all_in_gpu=all_in_gpu,
                        mixed_precision=self.fp16)[1]

                    softmax_pred = softmax_pred.transpose(
                        [0] + [i + 1 for i in self.transpose_backward])
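                    # e.g. with transpose_backward = [2, 1, 0] the axes become [0, 3, 2, 1]:
                    # the channel axis stays first and the spatial axes are permuted back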

                    if save_softmax:
                        softmax_fname = join(output_folder, fname + ".npz")
                    else:
                        softmax_fname = None
                    """There is a problem with python process communication that prevents us from communicating obejcts
                    larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
                    communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long
                    enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
                    patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
                    then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
                    filename or np.ndarray and will handle this automatically"""
                    if np.prod(softmax_pred.shape) > (
                            2e9 / 4 * 0.85):  # *0.85 just to be safe
                        np.save(join(output_folder, fname + ".npy"),
                                softmax_pred)
                        softmax_pred = join(output_folder, fname + ".npy")

                    results.append(
                        export_pool.starmap_async(
                            save_segmentation_nifti_from_softmax,
                            ((softmax_pred,
                              join(output_folder,
                                   fname + ".nii.gz"), properties,
                              interpolation_order, self.regions_class_order,
                              None, None, softmax_fname, None,
                              force_separate_z, interpolation_order_z), )))

        _ = [i.get() for i in results]
        self.print_to_log_file("finished prediction")

        distributed.barrier()

        if self.local_rank == 0:
            # evaluate raw predictions
            self.print_to_log_file("evaluation of raw predictions")
            task = self.dataset_directory.split("/")[-1]
            job_name = self.experiment_name
            _ = aggregate_scores(pred_gt_tuples,
                                 labels=list(range(self.num_classes)),
                                 json_output_file=join(output_folder,
                                                       "summary.json"),
                                 json_name=job_name + " val tiled %s" %
                                 (str(use_sliding_window)),
                                 json_author="Fabian",
                                 json_task=task,
                                 num_threads=default_num_threads)

            if run_postprocessing_on_folds:
                # in the old nnunet we would stop here. Now we add a postprocessing. This postprocessing can remove everything
                # except the largest connected component for each class. To see if this improves results, we do this for all
                # classes and then rerun the evaluation. Those classes for which this resulted in an improved dice score will
                # have this applied during inference as well
                self.print_to_log_file("determining postprocessing")
                determine_postprocessing(
                    self.output_folder,
                    self.gt_niftis_folder,
                    validation_folder_name,
                    final_subf_name=validation_folder_name + "_postprocessed",
                    debug=debug)
                # after this the final predictions for the validation set can be found in validation_folder_name_base + "_postprocessed"
                # They are always in that folder, even if no postprocessing was applied!

            # determining postprocessing on a per-fold basis may be OK for this fold but what if another fold finds another
            # postprocessing to be better? In this case we need to consolidate. At the time the consolidation is going to be
            # done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to
            # be used later
            gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
            maybe_mkdir_p(gt_nifti_folder)
            for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
                success = False
                attempts = 0
                e = None
                while not success and attempts < 10:
                    try:
                        shutil.copy(f, gt_nifti_folder)
                        success = True
                    except OSError as copy_error:
                        # keep a reference: the name bound by "except ... as" is cleared
                        # when the except block ends, which would break the check below
                        e = copy_error
                        attempts += 1
                        sleep(1)
                if not success:
                    print("Could not copy gt nifti file %s into folder %s" %
                          (f, gt_nifti_folder))
                    if e is not None:
                        raise e

        self.network.train(current_mode)
        net.do_ds = ds
Example #14
    def run_training(self):
        """
        if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first
        continued epoch with self.initial_lr

        we also need to make sure deep supervision in the network is enabled for training, thus the wrapper
        :return:
        """
        self.maybe_update_lr(
            self.epoch
        )  # if we don't overwrite epoch then self.epoch+1 is used which is not what we
        # want at the start of the training
        if isinstance(self.network, DDP):
            net = self.network.module
        else:
            net = self.network
        ds = net.do_ds
        net.do_ds = True

        _ = self.tr_gen.next()
        _ = self.val_gen.next()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self._maybe_init_amp()

        maybe_mkdir_p(self.output_folder)
        self.plot_network_architecture()

        if cudnn.benchmark and cudnn.deterministic:
            warn(
                "torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. "
                "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! "
                "If you want deterministic then set benchmark=False")

        if not self.was_initialized:
            self.initialize(True)

        while self.epoch < self.max_num_epochs:
            self.print_to_log_file("\nepoch: ", self.epoch)
            epoch_start_time = time()
            train_losses_epoch = []

            # train one epoch
            self.network.train()

            if self.use_progress_bar:
                with trange(self.num_batches_per_epoch) as tbar:
                    for b in tbar:
                        tbar.set_description("Epoch {}/{}".format(
                            self.epoch + 1, self.max_num_epochs))

                        l = self.run_iteration(self.tr_gen, True)

                        tbar.set_postfix(loss=l)
                        train_losses_epoch.append(l)
            else:
                for _ in range(self.num_batches_per_epoch):
                    l = self.run_iteration(self.tr_gen, True)
                    train_losses_epoch.append(l)

            self.all_tr_losses.append(np.mean(train_losses_epoch))
            self.print_to_log_file("train loss : %.4f" %
                                   self.all_tr_losses[-1])

            with torch.no_grad():
                # validation with train=False
                self.network.eval()
                val_losses = []
                for b in range(self.num_val_batches_per_epoch):
                    l = self.run_iteration(self.val_gen, False, True)
                    val_losses.append(l)
                self.all_val_losses.append(np.mean(val_losses))
                self.print_to_log_file("validation loss: %.4f" %
                                       self.all_val_losses[-1])

                if self.also_val_in_tr_mode:
                    self.network.train()
                    # validation with train=True
                    val_losses = []
                    for b in range(self.num_val_batches_per_epoch):
                        l = self.run_iteration(self.val_gen, False)
                        val_losses.append(l)
                    self.all_val_losses_tr_mode.append(np.mean(val_losses))
                    self.print_to_log_file(
                        "validation loss (train=True): %.4f" %
                        self.all_val_losses_tr_mode[-1])

            self.update_train_loss_MA(
            )  # needed for lr scheduler and stopping of training

            continue_training = self.on_epoch_end()

            epoch_end_time = time()

            if not continue_training:
                # allows for early stopping
                break

            self.epoch += 1
            self.print_to_log_file("This epoch took %f s\n" %
                                   (epoch_end_time - epoch_start_time))

        self.epoch -= 1  # if we don't do this we can get a problem with loading model_final_checkpoint.

        if self.save_final_checkpoint:
            self.save_checkpoint(
                join(self.output_folder, "model_final_checkpoint.model"))

        if self.local_rank == 0:
            # now we can delete latest as it will be identical with final
            if isfile(join(self.output_folder, "model_latest.model")):
                os.remove(join(self.output_folder, "model_latest.model"))
            if isfile(join(self.output_folder, "model_latest.model.pkl")):
                os.remove(join(self.output_folder, "model_latest.model.pkl"))

        net.do_ds = ds
Example #15
    def initialize(self, training=True, force_load_plans=False):
        """
        :param training:
        :return:
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            self.folder_with_preprocessed_data = join(
                self.dataset_directory,
                self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    if self.local_rank == 0:
                        print("unpacking dataset")
                        unpack_dataset(self.folder_with_preprocessed_data)
                        print("done")
                    distributed.barrier()
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                # setting weights for deep supervision losses
                net_numpool = len(self.net_num_pool_op_kernel_sizes)

                # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
                # this gives higher resolution outputs more weight in the loss
                weights = np.array([1 / (2**i) for i in range(net_numpool)])

                # the lowest-resolution output is not used (its weight is set to 0). Normalize weights so that they sum to 1
                mask = np.array([
                    True if i < net_numpool - 1 else False
                    for i in range(net_numpool)
                ])
                weights[~mask] = 0
                weights = weights / weights.sum()
                self.ds_loss_weights = weights

                seeds_train = np.random.random_integers(
                    0, 99999, self.data_aug_params.get('num_threads'))
                seeds_val = np.random.random_integers(
                    0, 99999,
                    max(self.data_aug_params.get('num_threads') // 2, 1))
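                # one random seed per background augmentation worker (validation uses half as many workers)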
                print("seeds train", seeds_train)
                print("seeds_val", seeds_val)
                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr,
                    self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    seeds_train=seeds_train,
                    seeds_val=seeds_val,
                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" %
                                       (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" %
                                       (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            self.network = DDP(self.network, device_ids=[self.local_rank])

        else:
            self.print_to_log_file(
                'self.was_initialized is True, not running self.initialize again'
            )
        self.was_initialized = True
Example #16
    def initialize(self, training=True, force_load_plans=False):
        """
        For prediction of test cases just set training=False, this will prevent loading of training data and
        training batchgenerator initialization
        :param training:
        :return:
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            self.folder_with_preprocessed_data = join(
                self.dataset_directory,
                self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    if self.local_rank == 0:
                        print("unpacking dataset")
                        unpack_dataset(self.folder_with_preprocessed_data)
                        print("done")
                    else:
                        # we need to wait until worker 0 has finished unpacking
                        npz_files = subfiles(
                            self.folder_with_preprocessed_data,
                            suffix=".npz",
                            join=False)
                        case_ids = [i[:-4] for i in npz_files]
                        all_present = all([
                            isfile(
                                join(self.folder_with_preprocessed_data,
                                     i + ".npy")) for i in case_ids
                        ])
                        while not all_present:
                            print("worker", self.local_rank,
                                  "is waiting for unpacking")
                            sleep(3)
                            all_present = all([
                                isfile(
                                    join(self.folder_with_preprocessed_data,
                                         i + ".npy")) for i in case_ids
                            ])
                        # there is some slight chance that there may arise some error because a dataloader is loading a file
                        # that is still being written by worker 0. We ignore this for now and address it only if it becomes
                        # relevant
                        # (this can occur because while worker 0 writes the file is technically present so the other workers
                        # will proceed and eventually try to read it)
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                # setting weights for deep supervision losses
                net_numpool = len(self.net_num_pool_op_kernel_sizes)

                # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
                # this gives higher resolution outputs more weight in the loss
                weights = np.array([1 / (2**i) for i in range(net_numpool)])

                # the lowest-resolution output is not used (its weight is set to 0). Normalize weights so that they sum to 1
                mask = np.array([
                    True if i < net_numpool - 1 else False
                    for i in range(net_numpool)
                ])
                weights[~mask] = 0
                weights = weights / weights.sum()
                self.ds_loss_weights = weights

                seeds_train = np.random.random_integers(
                    0, 99999, self.data_aug_params.get('num_threads'))
                seeds_val = np.random.random_integers(
                    0, 99999,
                    max(self.data_aug_params.get('num_threads') // 2, 1))
                print("seeds train", seeds_train)
                print("seeds_val", seeds_val)
                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr,
                    self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    seeds_train=seeds_train,
                    seeds_val=seeds_val)
                self.print_to_log_file("TRAINING KEYS:\n %s" %
                                       (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" %
                                       (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            self._maybe_init_amp()
            self.network = DDP(self.network)

        else:
            self.print_to_log_file(
                'self.was_initialized is True, not running self.initialize again'
            )
        self.was_initialized = True
Example #17
default_data_identifier = 'nnUNet'
default_trainer = "nnUNetTrainerV2"
default_cascade_trainer = "nnUNetTrainerV2CascadeFullRes"

"""
PLEASE READ paths.md FOR INFORMATION TO HOW TO SET THIS UP
"""

base = os.environ['nnUNet_raw_data_base'] if "nnUNet_raw_data_base" in os.environ.keys() else None
preprocessing_output_dir = os.environ['nnUNet_preprocessed'] if "nnUNet_preprocessed" in os.environ.keys() else None
network_training_output_dir_base = os.path.join(os.environ['RESULTS_FOLDER']) if "RESULTS_FOLDER" in os.environ.keys() else None

if base is not None:
    nnUNet_raw_data = join(base, "nnUNet_raw_data")
    nnUNet_cropped_data = join(base, "nnUNet_cropped_data")
    maybe_mkdir_p(nnUNet_raw_data)
    maybe_mkdir_p(nnUNet_cropped_data)
else:
    print("nnUNet_raw_data_base is not defined and nnU-Net can only be used on data for which preprocessed files "
          "are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like "
          "this. If this is not intended, please read nnunet/paths.md for information on how to set this up properly.")
    nnUNet_cropped_data = nnUNet_raw_data = None

if preprocessing_output_dir is not None:
    maybe_mkdir_p(preprocessing_output_dir)
else:
    print("nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing "
          "or training. If this is not intended, please read nnunet/pathy.md for information on how to set this up.")
    preprocessing_output_dir = None

if network_training_output_dir_base is not None:
Example #18
    def initialize(self, training=True, force_load_plans=False):
        '''
        Print keys to visdom and set number of epochs
        '''
        timestamp = datetime.now()
        if self.usevisdom and training:
            try:
                self.plotter = get_plotter(self.model_name)
                self.plotter.plot_text(
                    "Initialising this model: %s <br> on %d_%d_%d_%02.0d_%02.0d_%02.0d"
                    % (self.model_name, timestamp.year, timestamp.month,
                       timestamp.day, timestamp.hour, timestamp.minute,
                       timestamp.second),
                    plot_name="Welcome")
            except Exception:
                print("Unable to connect to visdom.")

        #super().initialize(training, force_load_plans)
        ## ------- nnunettrainerv2 nodeepsupervision
        """
        removed deep supervision
        :return:
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            self.folder_with_preprocessed_data = join(
                self.dataset_directory,
                self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                assert self.deep_supervision_scales is None
                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr,
                    self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    classes=None,
                    pin_memory=self.pin_memory)

                self.print_to_log_file("TRAINING KEYS:\n %s" %
                                       (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" %
                                       (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            assert isinstance(self.network,
                              (SegmentationNetwork, nn.DataParallel))
        else:
            self.print_to_log_file(
                'self.was_initialized is True, not running self.initialize again'
            )
        self.was_initialized = True

        # -----------------------

        if self.freeze:
            self.initialize_optimizer_and_scheduler_freezing()

        if training and self.usevisdom:
            try:
                self.plotter.plot_text(
                    "EPOCHS: %s <br> LEARNING RATE: %s <br> TRAINING KEYS: %s <br> VALIDATION KEYS: %s"
                    % (str(self.max_num_epochs), str(self.initial_lr),
                       str(self.dataset_tr.keys()), str(
                           self.dataset_val.keys())),
                    plot_name="Dataset_Info")
            except Exception:
                print("Unable to connect to visdom.")
Example #19
    input_folder = args.input_folder
    output_folder = args.output_folder
    part_id = args.part_id
    num_parts = args.num_parts
    folds = args.folds
    save_npz = args.save_npz
    lowres_segmentations = args.lowres_segmentations
    num_threads_preprocessing = args.num_threads_preprocessing
    num_threads_nifti_save = args.num_threads_nifti_save
    tta = args.tta
    overwrite = args.overwrite_existing

    output_folder_name = join(
        network_training_output_dir, args.model, args.task_name,
        args.nnunet_trainer + "__" + args.plans_identifier)
    maybe_mkdir_p(output_folder_name)
    print("using model stored in ", output_folder_name)
    assert isdir(output_folder_name
                 ), "model output folder not found: %s" % output_folder_name

    if lowres_segmentations == "None":
        lowres_segmentations = None

    if isinstance(folds, list):
        if folds[0] == 'all' and len(folds) == 1:
            pass
        else:
            folds = [int(i) for i in folds]
    elif folds == "None":
        folds = None
    else:
Example #20
def training(unet, train_loader, val_loader, epochs, batch_size, device,
             fold_dir, DATASET):

    # Set goals for training to end
    if DATASET == 'DIC-C2DH-HeLa':
        when_to_stop = 0
        goal = 0.7756  # IoU value from table 2 in Ronneberger et al. (2015)
    elif DATASET == 'ISBI2012':
        when_to_stop = 1
        goal = 0.0611  # PE value from table 1 in Ronneberger et al. (2015)
    elif DATASET == 'PhC-C2DH-U373':
        when_to_stop = 2
        goal = 0.9203  # IoU value from table 2 in Ronneberger et al. (2015)
    else:
        when_to_stop = None

    optimizer = optim.SGD(unet.parameters(), lr=0.0001, momentum=0.99)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='min',
                                                     factor=0.1,
                                                     patience=30,
                                                     threshold=1e-3,
                                                     threshold_mode='rel',
                                                     eps=1e-7)
    my_patience = 0

    maybe_mkdir_p(os.path.join(fold_dir, 'progress'))
    maybe_mkdir_p(os.path.join(fold_dir, 'models'))

    loss_best_epoch = 100000.0

    for epoch in range(epochs + 1):

        print(' ')
        print('Epoch:', epoch)

        start = time()
        total_loss = 0
        total_loss_val = 0
        start_eval_train = 0
        start_eval_val = 0

        torch.cuda.empty_cache()

        for batch in train_loader:

            optimizer.zero_grad()

            images, labels = batch

            preds = unet(images.to(device))  # pass batch to the unet

            pad = int((preds.shape[-1] - labels.shape[-1]) / 2)
            preds = preds[:, :, pad:labels.shape[-1] + pad,
                          pad:labels.shape[-1] + pad]

            ll = torch.empty_like(preds)
            ll[:, 0, :, :] = 1 - labels[:, 0, :, :]  # background
            ll[:, 1, :, :] = labels[:, 0, :, :]  # cell
            ll = ll.to(device)

            if DATASET == 'DIC-C2DH-HeLa':
                weight_maps = weighted_map(labels.squeeze(1)).to(device)
                criterion = nn.BCEWithLogitsLoss(weight=weight_maps)
            else:
                weight_maps = class_balance(labels.squeeze(1)).to(device)
                criterion = nn.BCEWithLogitsLoss(weight=weight_maps)

            loss = criterion(preds, ll)

            loss.backward()  # compute the gradients using backprop
            optimizer.step()  # update the weights

            total_loss += loss

            preds = preds.argmax(dim=1)

            for idx in range(preds.shape[0]):
                if start_eval_train == 0 and idx == 0:  # First time in epoch we initialize train_eval
                    train_eval = evaluation_metrics(
                        preds[idx, :, :].detach(), labels[idx,
                                                          0, :, :].detach())
                    start_eval_train += 1
                else:
                    train_eval = np.concatenate(
                        (train_eval,
                         evaluation_metrics(preds[idx, :, :].detach(),
                                            labels[idx, 0, :, :].detach())),
                        axis=1)

        train_eval_epoch = np.mean(train_eval, axis=1)

        torch.cuda.empty_cache()

        with torch.no_grad():
            for batch in val_loader:

                images, labels = batch

                preds = unet(images.to(device))

                pad = int((preds.shape[-1] - labels.shape[-1]) / 2)
                preds = preds[:, :, pad:labels.shape[-1] + pad,
                              pad:labels.shape[-1] + pad]

                ll = torch.empty_like(preds)
                ll[:, 0, :, :] = 1 - labels[:, 0, :, :]  # background
                ll[:, 1, :, :] = labels[:, 0, :, :]  # cell
                ll = ll.to(device)

                if DATASET == 'DIC-C2DH-HeLa':
                    weight_maps = weighted_map(labels.squeeze(1)).to(device)
                    criterion = nn.BCEWithLogitsLoss(weight=weight_maps)
                else:
                    weight_maps = class_balance(labels.squeeze(1)).to(device)
                    criterion = nn.BCEWithLogitsLoss(weight=weight_maps)

                loss = criterion(preds, ll)

                total_loss_val += loss

                preds = preds.argmax(1)

                for idx in range(preds.shape[0]):
                    if start_eval_val == 0 and idx == 0:  # first sample of the epoch initializes val_eval
                        val_eval = evaluation_metrics(
                            preds[idx, :, :].detach(),
                            labels[idx, 0, :, :].detach())
                        start_eval_val += 1
                    else:
                        # np.concatenate returns a new array, so the result must be reassigned
                        val_eval = np.concatenate(
                            (val_eval,
                             evaluation_metrics(
                                 preds[idx, :, :].detach(),
                                 labels[idx, 0, :, :].detach())),
                            axis=1)

        val_eval_epoch = np.mean(val_eval, axis=1)

        scheduler.step(total_loss_val /
                       (len(val_loader) * batch_size))  # update the lr

        for param_group in optimizer.param_groups:
            l_rate = param_group['lr']

        loss_epoch = total_loss / (len(train_loader) * batch_size)
        loss_epoch_val = total_loss_val / (len(val_loader) * batch_size)

        if loss_epoch_val < (loss_best_epoch * (1.0 - scheduler.threshold)):
            loss_best_epoch = loss_epoch_val.item()  # store as a plain float
            print('New best epoch!')
            my_patience = 0
            PATH = os.path.join(fold_dir, 'models',
                                'unet_weight_save_best.pth')
            torch.save(unet.state_dict(), PATH)
            print('Model has been saved:')
            print(PATH)
        else:
            my_patience += 1

        print('Current lr is:             ', l_rate)
        print('Patience is:                {}/{}'.format(
            my_patience, scheduler.patience))
        print('Mean IoU training:         ',
              "{:.6f}".format(train_eval_epoch[0]))
        print('Mean PE training:          ',
              "{:.6f}".format(train_eval_epoch[1]))
        print('Mean IoU validation:       ',
              "{:.6f}".format(val_eval_epoch[0]))
        print('Mean PE validation:        ',
              "{:.6f}".format(val_eval_epoch[1]))
        print('Total training loss:       ',
              "{:.6f}".format(loss_epoch.item()))
        print('Total validation loss:     ',
              "{:.6f}".format(loss_epoch_val.item()))
        print('Best epoch validation loss:',
              "{:.6f}".format(loss_best_epoch))
        print('Epoch duration:            ', "{:.6f}".format(time() - start),
              's')
        print(' ')

        # Save progress (evaluation metrics and loss)
        if epoch == 0:
            train_eval_progress_iou = [train_eval_epoch[0]]
            train_eval_progress_pe = [train_eval_epoch[1]]
            val_eval_progress_iou = [val_eval_epoch[0]]
            val_eval_progress_pe = [val_eval_epoch[1]]
            loss_progress = [loss_epoch.item()]
            loss_progress_val = [loss_epoch_val.item()]
        elif epoch > 0:
            train_eval_progress_iou = np.concatenate(
                (train_eval_progress_iou, [train_eval_epoch[0]]))
            train_eval_progress_pe = np.concatenate(
                (train_eval_progress_pe, [train_eval_epoch[1]]))
            val_eval_progress_iou = np.concatenate(
                (val_eval_progress_iou, [val_eval_epoch[0]]))
            val_eval_progress_pe = np.concatenate(
                (val_eval_progress_pe, [val_eval_epoch[1]]))
            loss_progress = np.append(loss_progress, [loss_epoch.item()])
            loss_progress_val = np.append(loss_progress_val,
                                          [loss_epoch_val.item()])

        np.savetxt(os.path.join(fold_dir, 'progress', 'train_eval_iou.out'),
                   train_eval_progress_iou)
        np.savetxt(os.path.join(fold_dir, 'progress', 'train_eval_pe.out'),
                   train_eval_progress_pe)
        np.savetxt(os.path.join(fold_dir, 'progress', 'val_eval_iou.out'),
                   val_eval_progress_iou)
        np.savetxt(os.path.join(fold_dir, 'progress', 'val_eval_pe.out'),
                   val_eval_progress_pe)
        np.savetxt(os.path.join(fold_dir, 'progress', 'loss.out'),
                   loss_progress)
        np.savetxt(os.path.join(fold_dir, 'progress', 'loss_val.out'),
                   loss_progress_val)

        # Goal check per dataset: IoU must rise above the goal for DIC-C2DH-HeLa and
        # PhC-C2DH-U373, while the pixel error must drop below the goal for ISBI2012.
        if when_to_stop is not None:
            if when_to_stop == 1:
                goal_reached = val_eval_epoch[1] < goal
            else:
                goal_reached = val_eval_epoch[0] > goal
            if goal_reached:
                PATH = os.path.join(fold_dir, 'models',
                                    'unet_weight_save_{}.pth'.format(DATASET))
                torch.save(unet.state_dict(), PATH)
                print('The goal was reached in epoch {}!'.format(epoch))
                print('Model has been saved:')
                print(PATH)
                # break
                when_to_stop = None
            continue

        # Save the latest model every 25 epochs
        if epoch % 25 == 0:
            PATH = os.path.join(fold_dir, 'models',
                                'unet_weight_save_latest.pth')
            torch.save(unet.state_dict(), PATH)
            print('Model has been saved:')
            print(PATH)

        if l_rate < 10 * scheduler.eps and my_patience == scheduler.patience:
            print(f'LR dropped below {10 * scheduler.eps}!')
            print('Stopping training')
            print(' ')
            PATH = os.path.join(fold_dir, 'models',
                                'unet_weight_save_latest.pth')
            torch.save(unet.state_dict(), PATH)
            print('Model has been saved:')
            print(PATH)
            break

        if my_patience == scheduler.patience:
            my_patience = -1

    print('Training is finished as epoch {} has been reached'.format(epoch))
    print('                                                               ')
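The loop above relies on an evaluation_metrics helper that returns, for each sample, a column vector with the IoU in row 0 and the pixel error in row 1 (this is how train_eval_epoch and val_eval_epoch are indexed after np.mean(..., axis=1)). The helper itself is defined elsewhere; a minimal sketch of one possible implementation for binary masks, under that assumption:

import numpy as np


def evaluation_metrics(pred, label):
    # pred: HxW tensor of predicted class indices, label: HxW tensor of {0, 1}
    pred = pred.cpu().numpy().astype(bool)
    label = label.cpu().numpy().astype(bool)
    union = np.logical_or(pred, label).sum()
    iou = np.logical_and(pred, label).sum() / union if union > 0 else 1.0
    pixel_error = np.not_equal(pred, label).mean()
    # One column per sample: row 0 is IoU, row 1 is pixel error.
    return np.array([[iou], [pixel_error]])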
Example #21
0
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 10 16:11:54 2021

@author: linhai
"""

import sys
import inspect
import os
from pathlib import Path
from batchgenerators.utilities.file_and_folder_operations import join, isdir, maybe_mkdir_p, subfiles, subdirs, isfile

#print (sys.path)
curDir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parDir = os.path.dirname(curDir)
sys.path.insert(0, parDir)
#sys.path.insert(0, )
p1 = 'C:\\Research\\IMA_on_segmentation\\nnUnet\\nnUNet\\rawData\\nnUNet_raw_data\\Task05_Prostate\\imagesTr'
p2 = 'C:/Research/IMA_on_segmentation/nnUnet/nnUNet/rawData/nnUNet_raw_data\\Task05_Prostate'
p3 = 'C:\\Research\\IMA_on_segmentation\\aaa'
p4 = 'C:/Research/IMA_on_segmentation/333/aaab'
print(os.path.join(p1, "aaa") + "\\")
print(isdir(join(p1, "aaa") + "\\"))
print(p1)
print(isdir(p2))
#os.mkdir(p4)
maybe_mkdir_p(p4)
#os.makedirs(p4, exist_ok=True)
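For reference, maybe_mkdir_p from batchgenerators behaves like a recursive mkdir that tolerates already-existing directories; a rough equivalent (a sketch, not the library's actual code) is:

import os


def maybe_mkdir_p_equivalent(directory):
    # Create the directory and any missing parents; do nothing if it already exists.
    os.makedirs(directory, exist_ok=True)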
Example #22
0
# do not modify these unless you know what you are doing
my_output_identifier = "nnUNet"
default_plans_identifier = "nnUNetPlans"
default_data_identifier = 'nnUNet'

try:
    # base is the folder where the raw data is stored. You just need to set base only, the others will be created
    # automatically (they are subfolders of base).
    # Here I use environment variables to set the base folder. Environment variables allow me to use the same code on
    # different systems (and our compute cluster). You can replace this line with something like:
    # base = "/path/to/my/folder"
    base = os.environ['nnUNet_base']
    raw_dataset_dir = join(base, "nnUNet_raw")
    splitted_4d_output_dir = join(base, "nnUNet_raw_splitted")
    cropped_output_dir = join(base, "nnUNet_raw_cropped")
    maybe_mkdir_p(splitted_4d_output_dir)
    maybe_mkdir_p(raw_dataset_dir)
    maybe_mkdir_p(cropped_output_dir)
except KeyError:
    cropped_output_dir = splitted_4d_output_dir = raw_dataset_dir = base = None

# preprocessing_output_dir is where the preprocessed data is stored. If you run a training, I very strongly recommend
# that this is an SSD!
try:
    # Here I use environment variables to set the folder. Environment variables allow me to use the same code on
    # different systems (and our compute cluster). You can replace this line with something like:
    # preprocessing_output_dir = "/path/to/my/folder_with_preprocessed_data"
    preprocessing_output_dir = os.environ['nnUNet_preprocessed']
except KeyError:
    preprocessing_output_dir = None
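Because the folders are read from environment variables, those variables have to exist before the module containing the snippet above is imported. A minimal sketch with hypothetical paths (adapt them to your own filesystem, or export the variables in the shell instead):

import os
from batchgenerators.utilities.file_and_folder_operations import join

# Hypothetical locations; must be set before the configuration module is imported.
os.environ['nnUNet_base'] = '/data/nnUNet_base'
os.environ['nnUNet_preprocessed'] = '/data/nnUNet_preprocessed'

# With the variables set, the try/except blocks above resolve to these folders:
base = os.environ['nnUNet_base']
print(join(base, "nnUNet_raw"))           # raw_dataset_dir
print(join(base, "nnUNet_raw_splitted"))  # splitted_4d_output_dir
print(join(base, "nnUNet_raw_cropped"))   # cropped_output_dir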
Example #23
0
    # Look up the requested trainer class by name among the modules in
    # nnunet.training.network_training.
    trainer_class = recursive_find_python_class(
        [join(nnunet.__path__[0], "training", "network_training")],
        trainerclass, "nnunet.training.network_training")

    if trainer_class is None:
        raise RuntimeError(
            "Could not find trainer class in nnunet.training.network_training")
    else:
        assert issubclass(
            trainer_class, nnUNetTrainer
        ), "network_trainer was found but is not derived from nnUNetTrainer"

    trainer = trainer_class(plans_file,
                            fold,
                            folder_with_preprocessed_data,
                            output_folder=output_folder_name,
                            dataset_directory=dataset_directory,
                            batch_dice=batch_dice,
                            stage=stage)

    trainer.initialize(False)
    trainer.load_dataset()
    trainer.do_split()
    trainer.load_best_checkpoint(train=False)

    stage_to_be_predicted_folder = join(
        dataset_directory, trainer.plans['data_identifier'] + "_stage%d" % 1)
    output_folder = join(pardir(trainer.output_folder), "pred_next_stage")
    maybe_mkdir_p(output_folder)

    predict_next_stage(trainer, stage_to_be_predicted_folder)
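The lookup at the top of this example resolves the trainer's class name to a class object by scanning the modules under nnunet/training/network_training. A simplified sketch of that kind of lookup (not nnU-Net's exact implementation; the function name here is made up):

import importlib
import pkgutil


def find_class_by_name(folders, class_name, current_module):
    # Scan every module found in the given folders and return the first attribute
    # whose name matches class_name, or None if nothing matches.
    for _, modname, _ in pkgutil.iter_modules(folders):
        module = importlib.import_module(current_module + "." + modname)
        if hasattr(module, class_name):
            return getattr(module, class_name)
    return None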