def run_main(context):
    """Plot the saved metadata clustering models against the dataset metadata.

    Loads the subject split (``split_datasets.joblib``) and the clustering models
    (``clustering_models.joblib``) stored in ``context["path_output"]``, gathers the
    metadata of the train/valid/test subsets, prints min/max/median per metadata
    type and writes one decision-boundary plot per metadata type to
    ``<path_output>/cluster_metadata/<metadata>.png``.

    Args:
        context (dict): configuration; the keys read are ``"path_output"``,
            ``"path_data"``, ``"contrast_train_validation"`` and ``"contrast_test"``.
    """
    no_transform = torch_transforms.Compose([
        imed_transforms.CenterCrop([128, 128]),
        imed_transforms.NumpyToTensor(),
        imed_transforms.NormalizeInstance(),
    ])

    out_dir = context["path_output"]
    split_dct = joblib.load(os.path.join(out_dir, "split_datasets.joblib"))
    metadata_dct = {}
    for subset in ['train', 'valid', 'test']:
        ds = imed_loader.BidsDataset(context["path_data"],
                                     subject_lst=split_dct[subset],
                                     contrast_lst=context["contrast_train_validation"]
                                     if subset != "test" else context["contrast_test"],
                                     transform=no_transform,
                                     slice_filter_fn=imed_loader_utils.SliceFilter())

        # Exactly one dataset per subset, so its metadata is taken as-is. (The former
        # "concatenate if already present" branch tested the top-level dict, whose
        # keys are subset names, and therefore could never run.)
        metadata_dct[subset] = {m: ds.metadata[m] for m in metadata_type}

    cluster_dct = joblib.load(os.path.join(out_dir, "clustering_models.joblib"))

    out_dir = os.path.join(out_dir, "cluster_metadata")
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(out_dir, exist_ok=True)

    for m in metadata_type:
        values = [v for s in ['train', 'valid', 'test'] for v in metadata_dct[s][m]]
        print('\n{}: Min={}, Max={}, Median={}'.format(m, min(values), max(values), np.median(values)))
        plot_decision_boundaries(metadata_dct, cluster_dct[m], metadata_range[m], m, os.path.join(out_dir, m + '.png'))
# Example #2
# 0
def run_main(args):
    """Print the ground-truth class balance of the train/valid/test splits.

    Reads a JSON configuration from ``args.c``, splits the BIDS dataset, loads
    each split with batch size 1 and, for every sample, records the percentage
    of ground-truth voxels equal to 1. Statistics are printed for each split and
    for all splits combined.

    Args:
        args: parsed command-line arguments; ``args.c`` is the path to the JSON
            configuration file.
    """
    with open(args.c, "r") as fhandle:
        context = json.load(fhandle)

    transform_lst = torch_transforms.Compose([
        imed_transforms.Resample(wspace=0.75, hspace=0.75),
        imed_transforms.CenterCrop([128, 128]),
        imed_transforms.NumpyToTensor(),
        imed_transforms.NormalizeInstance(),
    ])

    train_lst, valid_lst, test_lst = imed_loader_utils.split_dataset(
        context["bids_path"], context["center_test"], context["split_method"],
        context["random_seed"])

    balance_dct = {}
    for ds_lst, ds_name in zip([train_lst, valid_lst, test_lst],
                               ['train', 'valid', 'test']):
        print("\nLoading {} set.\n".format(ds_name))
        ds = imed_loader.BidsDataset(
            context["bids_path"],
            subject_lst=ds_lst,
            target_suffix=context["target_suffix"],
            contrast_lst=context["contrast_test"]
            if ds_name == 'test' else context["contrast_train_validation"],
            metadata_choice=context["metadata"],
            contrast_balance=context["contrast_balance"],
            transform=transform_lst,
            slice_filter_fn=imed_utils.SliceFilter())

        print("Loaded {} axial slices for the {} set.".format(
            len(ds), ds_name))
        ds_loader = DataLoader(ds,
                               batch_size=1,
                               shuffle=False,
                               pin_memory=False,
                               collate_fn=imed_loader_utils.imed_collate,
                               num_workers=1)

        balance_lst = []
        for batch in ds_loader:
            # ``np.int`` was only an alias of the builtin ``int`` and was removed in
            # NumPy 1.24, so the builtin is used directly.
            gt_sample = batch["gt"].numpy().astype(int)[0, 0, :, :]
            nb_ones = (gt_sample == 1).sum()
            balance_lst.append(nb_ones * 100.0 / gt_sample.size)

        balance_dct[ds_name] = balance_lst

    for ds_name in balance_dct:
        print('\nClass balance in {} set:'.format(ds_name))
        print_stats(balance_dct[ds_name])

    print('\nClass balance in full set:')
    print_stats([e for lst in balance_dct.values() for e in lst])
# Example #3
# 0
def run_main(context):
    """Collect the distinct metadata values of each split and save them as JSON.

    For each subset (``train``, ``validation``, ``test``), every BIDS dataset listed
    under ``context["bids_path_<subset>"]`` is loaded and its metadata values are
    accumulated per metadata type. Duplicates are then removed. The result
    (subset -> metadata type -> list of unique values) is written to
    ``<log_directory>/metadata_config.json``.

    Args:
        context (dict): configuration; the keys read are ``"log_directory"``,
            ``"bids_path_train"``, ``"bids_path_validation"``, ``"bids_path_test"``,
            ``"contrast_train_validation"`` and ``"contrast_test"``.
    """
    no_transform = torch_transforms.Compose([
        imed_transforms.CenterCrop([128, 128]),
        imed_transforms.NumpyToTensor(),
        imed_transforms.NormalizeInstance(),
    ])

    out_dir = context["log_directory"]
    metadata_dct = {}
    for subset in ['train', 'validation', 'test']:
        # Start from empty lists so the values of *every* dataset in the subset
        # accumulate. Previously the membership test looked at the top-level dict
        # (keyed by subset name), so each dataset overwrote the previous one.
        subset_dct = {m: [] for m in metadata_type}
        for bids_ds in tqdm(context["bids_path_" + subset],
                            desc="Loading " + subset + " set"):
            ds = imed_loader.BidsDataset(
                bids_ds,
                contrast_lst=context["contrast_train_validation"]
                if subset != "test" else context["contrast_test"],
                transform=no_transform,
                slice_filter_fn=imed_loader_utils.SliceFilter())

            for m in metadata_type:
                subset_dct[m].extend(ds.metadata[m])

        # Keep each value once (set() does not preserve order).
        metadata_dct[subset] = {m: list(set(values)) for m, values in subset_dct.items()}

    with open(out_dir + "/metadata_config.json", 'w') as fp:
        json.dump(metadata_dct, fp)
# Example #4
# 0
def test_image_orientation():
    """Check that a GT mask survives the load -> transform -> undo-transform round trip.

    For each dimension ('2d', '3d') and each slice axis (0, 1, 2):

    - The ground truth of subject ``sub-unf01`` is loaded through ``BidsDataset``
      (2D) or ``Bids3DDataset`` (3D) with the training transforms.
    - It is restored with the undo transform returned by
      ``imed_transforms.prepare_transforms``.
    - In 2D, the volume is re-stacked slice by slice.

    The reconstruction is thresholded and compared (Dice >= 0.8) with the same
    ground truth in three orientations: as returned by
    ``SegmentationPair.get_pair_data``, in RAS canonical orientation, and in the
    orientation of the file on disk.
    """
    # Use GPU number GPU_NUMBER when CUDA is available, otherwise the CPU.
    device = torch.device("cuda:" + str(GPU_NUMBER) if torch.cuda.is_available() else "cpu")
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(device)
        print("Using GPU number {}".format(device))

    train_lst = ['sub-unf01']

    # Resample -> center crop -> tensor conversion -> NormalizeInstance,
    # with NormalizeInstance applied to 'im' only.
    training_transform_dict = {
        "Resample":
            {
                "wspace": 1.5,
                "hspace": 1,
                "dspace": 3
            },
        "CenterCrop":
            {
                "size": [176, 128, 160]
            },
        "NumpyToTensor": {},
        "NormalizeInstance": {"applied_to": ['im']}
    }

    # Besides the forward pipeline, this returns the undo transform used below
    # to bring each GT sample back to its pre-transform state.
    tranform_lst, training_undo_transform = imed_transforms.prepare_transforms(training_transform_dict)

    # Only passed to the 3D dataset. length_3D / stride_3D equal the CenterCrop size.
    model_params = {
            "name": "Modified3DUNet",
            "dropout_rate": 0.3,
            "bn_momentum": 0.9,
            "depth": 2,
            "in_channel": 1,
            "out_channel": 1,
            "length_3D": [176, 128, 160],
            "stride_3D": [176, 128, 160],
            "attention": False,
            "n_filters": 8
        }
    contrast_params = {
        "contrast_lst": ['T1w'],
        "balance": {}
    }

    for dim in ['2d', '3d']:
        for slice_axis in [0, 1, 2]:
            if dim == '2d':
                ds = imed_loader.BidsDataset(PATH_BIDS,
                                             subject_lst=train_lst,
                                             target_suffix=["_seg-manual"],
                                             contrast_params=contrast_params,
                                             metadata_choice=False,
                                             slice_axis=slice_axis,
                                             transform=tranform_lst,
                                             multichannel=False)
                # load_filenames() is called explicitly only for the 2D dataset.
                ds.load_filenames()
            else:
                ds = imed_loader.Bids3DDataset(PATH_BIDS,
                                               subject_lst=train_lst,
                                               target_suffix=["_seg-manual"],
                                               model_params=model_params,
                                               contrast_params=contrast_params,
                                               metadata_choice=False,
                                               slice_axis=slice_axis,
                                               transform=tranform_lst,
                                               multichannel=False)

            loader = DataLoader(ds, batch_size=1,
                                shuffle=True, pin_memory=True,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=1)

            # Reference GT for the first subject, loaded directly rather than through the loader.
            input_filename, gt_filename, roi_filename, metadata = ds.filename_pairs[0]
            segpair = imed_loader.SegmentationPair(input_filename, gt_filename, metadata=metadata,
                                                   slice_axis=slice_axis)
            nib_original = nib.load(gt_filename[0])
            # Get image with original, ras and hwd orientations
            input_init = nib_original.get_fdata()
            input_ras = nib.as_closest_canonical(nib_original).get_fdata()
            img, gt = segpair.get_pair_data()
            input_hwd = gt[0]

            # In 2D, pred_tmp_lst collects the undone slices and z_tmp_lst their
            # slice indices. The indices are needed because the loader shuffles
            # (shuffle=True), so slices arrive out of order.
            pred_tmp_lst, z_tmp_lst = [], []
            for i, batch in enumerate(loader):
                # batch["input_metadata"] = batch["input_metadata"][0]  # Take only metadata from one input
                # batch["gt_metadata"] = batch["gt_metadata"][0]  # Take only metadata from one label

                for smp_idx in range(len(batch['gt'])):
                    # undo transformations
                    if dim == '2d':
                        preds_idx_undo, metadata_idx = training_undo_transform(batch["gt"][smp_idx],
                                                                               batch["gt_metadata"][smp_idx],
                                                                               data_type='gt')

                        # add new sample to pred_tmp_lst
                        pred_tmp_lst.append(preds_idx_undo[0])
                        z_tmp_lst.append(int(batch['input_metadata'][smp_idx][0]['slice_index']))

                    else:
                        preds_idx_undo, metadata_idx = training_undo_transform(batch["gt"][smp_idx],
                                                                               batch["gt_metadata"][smp_idx],
                                                                               data_type='gt')

                    fname_ref = metadata_idx[0]['gt_filenames'][0]

                    # 2D: reconstruct and check once the last batch has been undone.
                    # 3D: check every undone sample directly.
                    if (pred_tmp_lst and i == len(loader) - 1) or dim == '3d':
                        # save the completely processed file as a nii
                        nib_ref = nib.load(fname_ref)
                        nib_ref_can = nib.as_closest_canonical(nib_ref)

                        if dim == '2d':
                            # Put the slices back in index order along slice_axis
                            # and stack them on the last axis.
                            tmp_lst = []
                            for z in range(nib_ref_can.header.get_data_shape()[slice_axis]):
                                tmp_lst.append(pred_tmp_lst[z_tmp_lst.index(z)])
                            arr = np.stack(tmp_lst, axis=-1)
                        else:
                            arr = np.array(preds_idx_undo[0])

                        # verify image after transform, undo transform and 3D reconstruction
                        input_hwd_2 = imed_postpro.threshold_predictions(arr)
                        # Some difference are generated due to transform and undo transform
                        # (i.e. Resample interpolation)
                        assert imed_metrics.dice_score(input_hwd_2, input_hwd) >= 0.8
                        # Same check against the RAS-canonical GT (input_ras) ...
                        input_ras_2 = imed_loader_utils.orient_img_ras(input_hwd_2, slice_axis)
                        assert imed_metrics.dice_score(input_ras_2, input_ras) >= 0.8
                        # ... and against the GT as stored on disk (input_init).
                        input_init_2 = imed_loader_utils.reorient_image(input_hwd_2, slice_axis, nib_ref, nib_ref_can)
                        assert imed_metrics.dice_score(input_init_2, input_init) >= 0.8

                        # re-init pred_stack_lst
                        pred_tmp_lst, z_tmp_lst = [], []