Example #1
def run_main(args):
    with open(args.c, "r") as fhandle:
        context = json.load(fhandle)

    transform_lst = torch_transforms.Compose([
        imed_transforms.Resample(wspace=0.75, hspace=0.75),
        imed_transforms.CenterCrop([128, 128]),
        imed_transforms.NumpyToTensor(),
        imed_transforms.NormalizeInstance(),
    ])

    train_lst, valid_lst, test_lst = imed_loader_utils.split_dataset(
        context["bids_path"], context["center_test"], context["split_method"],
        context["random_seed"])

    balance_dct = {}
    for ds_lst, ds_name in zip([train_lst, valid_lst, test_lst],
                               ['train', 'valid', 'test']):
        print("\nLoading {} set.\n".format(ds_name))
        ds = imed_loader.BidsDataset(
            context["bids_path"],
            subject_lst=ds_lst,
            target_suffix=context["target_suffix"],
            contrast_lst=context["contrast_test"]
            if ds_name == 'test' else context["contrast_train_validation"],
            metadata_choice=context["metadata"],
            contrast_balance=context["contrast_balance"],
            transform=transform_lst,
            slice_filter_fn=imed_utils.SliceFilter())

        print("Loaded {} axial slices for the {} set.".format(
            len(ds), ds_name))
        ds_loader = DataLoader(ds,
                               batch_size=1,
                               shuffle=False,
                               pin_memory=False,
                               collate_fn=imed_loader_utils.imed_collate,
                               num_workers=1)

        balance_lst = []
        for i, batch in enumerate(ds_loader):
            gt_sample = batch["gt"].numpy().astype(int)[0, 0, :, :]
            nb_ones = (gt_sample == 1).sum()
            nb_voxels = gt_sample.size
            balance_lst.append(nb_ones * 100.0 / nb_voxels)

        balance_dct[ds_name] = balance_lst

    for ds_name in balance_dct:
        print('\nClass balance in {} set:'.format(ds_name))
        print_stats(balance_dct[ds_name])

    print('\nClass balance in full set:')
    print_stats([e for d in balance_dct for e in balance_dct[d]])
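
The example above calls a print_stats helper and reads args.c, neither of which is defined in the snippet. A minimal sketch of both, assuming print_stats simply summarizes the per-slice percentages (the helper body and the CLI flag are hypothetical):

import argparse

def print_stats(values):
    # Hypothetical helper: summarize a list of per-slice lesion percentages.
    arr = np.asarray(values)
    print('Mean: {:.3f}%'.format(arr.mean()))
    print('Median: {:.3f}%'.format(np.median(arr)))
    print('Inter-quartile range: [{:.3f}%, {:.3f}%]'.format(
        np.percentile(arr, 25), np.percentile(arr, 75)))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', required=True,
                        help='Path to the JSON configuration file.')
    run_main(parser.parse_args())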
Example #2
def run_main(context):
    no_transform = torch_transforms.Compose([
        imed_transforms.CenterCrop([128, 128]),
        imed_transforms.NumpyToTensor(),
        imed_transforms.NormalizeInstance(),
    ])

    out_dir = context["log_directory"]
    split_dct = joblib.load(os.path.join(out_dir, "split_datasets.joblib"))
    metadata_dct = {}
    for subset in ['train', 'valid', 'test']:
        metadata_dct[subset] = {}
        ds = imed_loader.BidsDataset(
            context["bids_path"],
            subject_lst=split_dct[subset],
            contrast_lst=context["contrast_train_validation"]
            if subset != "test" else context["contrast_test"],
            transform=no_transform,
            slice_filter_fn=imed_utils.SliceFilter())

        for m in metadata_type:
            if m in metadata_dct[subset]:
                metadata_dct[subset][m] = [
                    v for m_lst in [metadata_dct[subset][m], ds.metadata[m]]
                    for v in m_lst
                ]
            else:
                metadata_dct[subset][m] = ds.metadata[m]

    cluster_dct = joblib.load(os.path.join(out_dir,
                                           "clustering_models.joblib"))

    out_dir = os.path.join(out_dir, "cluster_metadata")
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    for m in metadata_type:
        values = [
            v for s in ['train', 'valid', 'test'] for v in metadata_dct[s][m]
        ]
        print('\n{}: Min={}, Max={}, Median={}'.format(m, min(values),
                                                       max(values),
                                                       np.median(values)))
        plot_decision_boundaries(metadata_dct, cluster_dct[m],
                                 metadata_range[m], m,
                                 os.path.join(out_dir, m + '.png'))
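
Both this example and the next assume module-level metadata_type and metadata_range globals that are not shown. A sketch of what they might hold, assuming the metadata of interest are standard MRI acquisition parameters; the names, ranges, and linspace structure are illustrative guesses, not taken from the source:

# Hypothetical globals assumed by run_main above; all values are illustrative.
metadata_type = ['FlipAngle', 'EchoTime', 'RepetitionTime']
metadata_range = {
    'FlipAngle': np.linspace(0., 360., 5000),        # degrees
    'EchoTime': np.linspace(1e-3, 1e-1, 5000),       # seconds
    'RepetitionTime': np.linspace(1e-3, 1e1, 5000),  # seconds
}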
Example #3
def run_main(context):
    no_transform = torch_transforms.Compose([
        imed_transforms.CenterCrop([128, 128]),
        imed_transforms.NumpyToTensor(),
        imed_transforms.NormalizeInstance(),
    ])

    out_dir = context["log_directory"]
    metadata_dct = {}
    for subset in ['train', 'validation', 'test']:
        metadata_dct[subset] = {}
        for bids_ds in tqdm(context["bids_path_" + subset],
                            desc="Loading " + subset + " set"):
            ds = imed_loader.BidsDataset(
                bids_ds,
                contrast_lst=context["contrast_train_validation"]
                if subset != "test" else context["contrast_test"],
                transform=no_transform,
                slice_filter_fn=imed_utils.SliceFilter())

            for m in metadata_type:
                if m in metadata_dct[subset]:
                    metadata_dct[subset][m] = [
                        v
                        for m_lst in [metadata_dct[subset][m], ds.metadata[m]]
                        for v in m_lst
                    ]
                else:
                    metadata_dct[subset][m] = ds.metadata[m]

        for m in metadata_type:
            metadata_dct[subset][m] = list(set(metadata_dct[subset][m]))

    with open(os.path.join(out_dir, "metadata_config.json"), 'w') as fp:
        json.dump(metadata_dct, fp)

    return
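
For reference, the metadata_config.json written above maps each subset to its deduplicated metadata values, so a downstream script could reload it as follows (a trivial sketch using the same path):

with open(os.path.join(out_dir, "metadata_config.json"), "r") as fp:
    metadata_config = json.load(fp)
for subset, metadata in metadata_config.items():
    print(subset, {m: len(v) for m, v in metadata.items()})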
Example #4
def test_hdf5():
    print('[INFO]: Starting test ... \n')
    train_lst = ['sub-unf01']

    training_transform_dict = {
        "Resample": {
            "wspace": 0.75,
            "hspace": 0.75
        },
        "CenterCrop": {
            "size": [48, 48]
        },
        "NumpyToTensor": {}
    }
    transform_lst, _ = imed_transforms.prepare_transforms(
        training_transform_dict)

    roi_params = {"suffix": "_seg-manual", "slice_filter_roi": None}

    hdf5_file = imed_adaptative.Bids_to_hdf5(
        PATH_BIDS,
        subject_lst=train_lst,
        hdf5_name='testing_data/mytestfile.hdf5',
        target_suffix=["_lesion-manual"],
        roi_params=roi_params,
        contrast_lst=['T1w', 'T2w', 'T2star'],
        metadata_choice="contrast",
        transform=transform_lst,
        contrast_balance={},
        slice_axis=2,
        slice_filter_fn=imed_utils.SliceFilter(filter_empty_input=True,
                                               filter_empty_mask=True))

    # Checking architecture
    def print_attrs(name, obj):
        print("\nName of the object: {}".format(name))
        print("Type: {}".format(type(obj)))
        print("Including the following attributes:")
        for key, val in obj.attrs.items():
            print("    %s: %s" % (key, val))

    print('\n[INFO]: HDF5 architecture:')
    hdf5_file.hdf5_file.visititems(print_attrs)
    print('\n[INFO]: HDF5 file successfully generated.')
    print('[INFO]: Generating dataframe ...\n')

    df = imed_adaptative.Dataframe(hdf5=hdf5_file.hdf5_file,
                                   contrasts=['T1w', 'T2w', 'T2star'],
                                   path='testing_data/hdf5.csv',
                                   target_suffix=['T1w', 'T2w', 'T2star'],
                                   roi_suffix=['T1w', 'T2w', 'T2star'],
                                   dim=2,
                                   filter_slices=True)

    print(df.df)

    print('\n[INFO]: Dataframe successfully generated. ')
    print('[INFO]: Creating dataset ...\n')

    model_params = {
        "name": "HeMISUnet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "missing_probability": 0.00001,
        "missing_probability_growth": 0.9,
        "contrasts": ["T1w", "T2w"],
        "ram": False,
        "hdf5_path": 'testing_data/mytestfile.hdf5',
        "csv_path": 'testing_data/hdf5.csv',
        "target_lst": ["T2w"],
        "roi_lst": ["T2w"]
    }
    contrast_params = {"contrast_lst": ['T1w', 'T2w', 'T2star'], "balance": {}}

    dataset = imed_adaptative.HDF5Dataset(
        root_dir=PATH_BIDS,
        subject_lst=train_lst,
        target_suffix="_lesion-manual",
        slice_axis=2,
        model_params=model_params,
        contrast_params=contrast_params,
        transform=transform_lst,
        metadata_choice=False,
        dim=2,
        slice_filter_fn=imed_utils.SliceFilter(filter_empty_input=True,
                                               filter_empty_mask=True),
        roi_params=roi_params)

    dataset.load_into_ram(['T1w', 'T2w', 'T2star'])
    print("Dataset RAM status:")
    print(dataset.status)
    print("In memory Dataframe:")
    print(dataset.dataframe)
    print('\n[INFO]: Test passed successfully. ')

    print("\n[INFO]: Starting loader test ...")

    device = torch.device(
        "cuda:" + str(GPU_NUMBER) if torch.cuda.is_available() else "cpu")
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(device)
        print("Using GPU number {}".format(device))

    train_loader = DataLoader(dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              pin_memory=True,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=1)

    for i, batch in enumerate(train_loader):
        input_samples, gt_samples = batch["input"], batch["gt"]
        print("len input = {}".format(len(input_samples)))
        print("Batch = {}, {}".format(input_samples[0].shape,
                                      gt_samples[0].shape))

        if cuda_available:
            var_input = imed_utils.cuda(input_samples)
            var_gt = imed_utils.cuda(gt_samples, non_blocking=True)
        else:
            var_input = input_samples
            var_gt = gt_samples

        break
    os.remove('testing_data/mytestfile.hdf5')
    print("Congrats, your dataloader works! "
          "You can go home now and get a beer.")
    return 0
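
The test references several module-level constants (PATH_BIDS, BATCH_SIZE, GPU_NUMBER) that are defined elsewhere in the test module. Plausible placeholder definitions, purely as assumptions:

# Hypothetical test constants; the values are placeholders, not from the source.
PATH_BIDS = 'testing_data/'
BATCH_SIZE = 4
GPU_NUMBER = 0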
Example #5
def load_dataset(data_list,
                 bids_path,
                 transforms_params,
                 model_params,
                 target_suffix,
                 roi_params,
                 contrast_params,
                 slice_filter_params,
                 slice_axis,
                 multichannel,
                 dataset_type="training",
                 requires_undo=False,
                 metadata_type=None,
                 object_detection_params=None,
                 soft_gt=False,
                 device=None,
                 cuda_available=None,
                 **kwargs):
    """Get loader appropriate loader according to model type. Available loaders are Bids3DDataset for 3D data,
    BidsDataset for 2D data and HDF5Dataset for HeMIS.

    Args:
        data_list (list): Subject names list.
        bids_path (str): Path to the BIDS dataset.
        transforms_params (dict): Dictionary containing transformations for "training", "validation", "testing" (keys),
            eg output of imed_transforms.get_subdatasets_transforms.
        model_params (dict): Dictionary containing model parameters.
        target_suffix (list of str): List of suffixes for target masks.
        roi_params (dict): Contains ROI related parameters.
        contrast_params (dict): Contains image contrasts related parameters.
        slice_filter_params (dict): Contains slice_filter parameters, see :doc:`configuration_file` for more details.
        slice_axis (str): Choice between "axial", "sagittal", "coronal"; controls the axis used to extract the 2D
            data.
        multichannel (bool): If True, the input contrasts are combined as input channels for the model. Otherwise, each
            contrast is processed individually (i.e. a different sample/tensor).
        metadata_type (str): Choice between None, "mri_params", "contrasts".
        dataset_type (str): Choice between "training", "validation" or "testing".
        requires_undo (bool): If True, the transformations without undo_transform will be discarded.
        object_detection_params (dict): Object detection parameters.
        soft_gt (bool): If True, ground truths will be converted to float32, otherwise to uint8 and binarized
            (to save memory).
    Returns:
        BidsDataset

    Note: For more details on the parameters transforms_params, target_suffix, roi_params, contrast_params,
    slice_filter_params and object_detection_params see :doc:`configuration_file`.
    """
    # Compose transforms
    transform_lst, _ = imed_transforms.prepare_transforms(
        copy.deepcopy(transforms_params), requires_undo)

    # If ROICrop is not part of the transforms, then enforce no slice filtering based on ROI data.
    if 'ROICrop' not in transforms_params:
        roi_params["slice_filter_roi"] = None

    if model_params["name"] == "UNet3D":
        dataset = Bids3DDataset(
            bids_path,
            subject_lst=data_list,
            target_suffix=target_suffix,
            roi_params=roi_params,
            contrast_params=contrast_params,
            metadata_choice=metadata_type,
            slice_axis=imed_utils.AXIS_DCT[slice_axis],
            transform=transform_lst,
            multichannel=multichannel,
            model_params=model_params,
            object_detection_params=object_detection_params,
            soft_gt=soft_gt)

    elif model_params["name"] == "HeMISUnet":
        dataset = imed_adaptative.HDF5Dataset(
            root_dir=bids_path,
            subject_lst=data_list,
            model_params=model_params,
            contrast_params=contrast_params,
            target_suffix=target_suffix,
            slice_axis=imed_utils.AXIS_DCT[slice_axis],
            transform=transform_lst,
            metadata_choice=metadata_type,
            slice_filter_fn=imed_utils.SliceFilter(
                **slice_filter_params,
                device=device,
                cuda_available=cuda_available),
            roi_params=roi_params,
            object_detection_params=object_detection_params,
            soft_gt=soft_gt)
    else:
        # Task selection
        task = imed_utils.get_task(model_params["name"])

        dataset = BidsDataset(bids_path,
                              subject_lst=data_list,
                              target_suffix=target_suffix,
                              roi_params=roi_params,
                              contrast_params=contrast_params,
                              metadata_choice=metadata_type,
                              slice_axis=imed_utils.AXIS_DCT[slice_axis],
                              transform=transform_lst,
                              multichannel=multichannel,
                              slice_filter_fn=imed_utils.SliceFilter(
                                  **slice_filter_params,
                                  device=device,
                                  cuda_available=cuda_available),
                              soft_gt=soft_gt,
                              object_detection_params=object_detection_params,
                              task=task)
        dataset.load_filenames()

    if model_params["name"] != "UNet3D":
        print("Loaded {} {} slices for the {} set.".format(
            len(dataset), slice_axis, dataset_type))
    else:
        print("Loaded {} volumes of size {} for the {} set.".format(
            len(dataset), slice_axis, dataset_type))

    return dataset
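
A usage sketch of load_dataset for the default 2D case; every literal below (subject names, paths, suffixes, parameter values) is illustrative rather than taken from the source:

dataset = load_dataset(
    data_list=['sub-01', 'sub-02'],
    bids_path='path/to/bids_dataset',
    transforms_params={'CenterCrop': {'size': [128, 128]},
                       'NumpyToTensor': {}},
    model_params={'name': 'Unet'},
    target_suffix=['_lesion-manual'],
    roi_params={'suffix': None, 'slice_filter_roi': None},
    contrast_params={'contrast_lst': ['T2w'], 'balance': {}},
    slice_filter_params={'filter_empty_input': True,
                         'filter_empty_mask': False},
    slice_axis='axial',
    multichannel=False,
    dataset_type='training')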
Example #6
def test_HeMIS(p=0.0001):
    print('[INFO]: Starting test ... \n')
    training_transform_dict = {
        "Resample": {
            "wspace": 0.75,
            "hspace": 0.75
        },
        "CenterCrop": {
            "size": [48, 48]
        },
        "NumpyToTensor": {}
    }

    transform_lst, _ = imed_transforms.prepare_transforms(
        training_transform_dict)

    roi_params = {"suffix": "_seg-manual", "slice_filter_roi": None}

    train_lst = ['sub-unf01']
    contrasts = ['T1w', 'T2w', 'T2star']

    print('[INFO]: Creating dataset ...\n')
    model_params = {
        "name": "HeMISUnet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "missing_probability": 0.00001,
        "missing_probability_growth": 0.9,
        "contrasts": ["T1w", "T2w"],
        "ram": False,
        "hdf5_path": 'testing_data/mytestfile.hdf5',
        "csv_path": 'testing_data/hdf5.csv',
        "target_lst": ["T2w"],
        "roi_lst": ["T2w"]
    }
    contrast_params = {"contrast_lst": ['T1w', 'T2w', 'T2star'], "balance": {}}
    dataset = imed_adaptative.HDF5Dataset(
        root_dir=PATH_BIDS,
        subject_lst=train_lst,
        model_params=model_params,
        contrast_params=contrast_params,
        target_suffix=["_lesion-manual"],
        slice_axis=2,
        transform=transform_lst,
        metadata_choice=False,
        dim=2,
        slice_filter_fn=imed_utils.SliceFilter(filter_empty_input=True,
                                               filter_empty_mask=True),
        roi_params=roi_params)

    dataset.load_into_ram(['T1w', 'T2w', 'T2star'])
    print("[INFO]: Dataset RAM status:")
    print(dataset.status)
    print("[INFO]: In memory Dataframe:")
    print(dataset.dataframe)

    # TODO
    # ds_train.filter_roi(nb_nonzero_thr=10)

    train_loader = DataLoader(dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              pin_memory=True,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=1)

    model = models.HeMISUnet(contrasts=contrasts,
                             depth=3,
                             drop_rate=DROPOUT,
                             bn_momentum=BN)

    print(model)
    cuda_available = torch.cuda.is_available()

    if cuda_available:
        torch.cuda.set_device(GPU_NUMBER)
        print("Using GPU number {}".format(GPU_NUMBER))
        model.cuda()

    # Initializing optimizer and scheduler
    step_scheduler_batch = False
    optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCHS)

    load_lst, reload_lst, pred_lst, opt_lst = [], [], [], []
    schedul_lst, init_lst, gen_lst = [], [], []

    for epoch in tqdm(range(1, N_EPOCHS + 1), desc="Training"):
        start_time = time.time()

        start_init = time.time()
        lr = scheduler.get_last_lr()[0]
        model.train()

        tot_init = time.time() - start_init
        init_lst.append(tot_init)

        num_steps = 0
        start_gen = 0
        for i, batch in enumerate(train_loader):
            if i > 0:
                tot_gen = time.time() - start_gen
                gen_lst.append(tot_gen)

            start_load = time.time()
            input_samples, gt_samples = imed_utils.unstack_tensors(
                batch["input"]), batch["gt"]

            print(batch["input_metadata"][0][0]["missing_mod"])
            missing_mod = imed_training.get_metadata(batch["input_metadata"],
                                                     model_params)

            print("Number of missing contrasts = {}.".format(
                len(input_samples) * len(input_samples[0]) -
                missing_mod.sum()))
            print("len input = {}".format(len(input_samples)))
            print("Batch = {}, {}".format(input_samples[0].shape,
                                          gt_samples[0].shape))

            if cuda_available:
                var_input = imed_utils.cuda(input_samples)
                var_gt = imed_utils.cuda(gt_samples, non_blocking=True)
            else:
                var_input = input_samples
                var_gt = gt_samples

            tot_load = time.time() - start_load
            load_lst.append(tot_load)

            start_pred = time.time()
            preds = model(var_input, missing_mod)
            tot_pred = time.time() - start_pred
            pred_lst.append(tot_pred)

            start_opt = time.time()
            loss = -losses.DiceLoss()(preds, var_gt)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()

            num_steps += 1
            tot_opt = time.time() - start_opt
            opt_lst.append(tot_opt)

            start_gen = time.time()

        start_schedul = time.time()
        if not step_scheduler_batch:
            scheduler.step()
        tot_schedul = time.time() - start_schedul
        schedul_lst.append(tot_schedul)

        start_reload = time.time()
        print("[INFO]: Updating Dataset")
        p = p**(2 / 3)
        dataset.update(p=p)
        print("[INFO]: Reloading dataset")
        train_loader = DataLoader(dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True,
                                  pin_memory=True,
                                  collate_fn=imed_loader_utils.imed_collate,
                                  num_workers=1)
        tot_reload = time.time() - start_reload
        reload_lst.append(tot_reload)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

    print('Mean SD init {} -- {}'.format(np.mean(init_lst), np.std(init_lst)))
    print('Mean SD load {} -- {}'.format(np.mean(load_lst), np.std(load_lst)))
    print('Mean SD reload {} -- {}'.format(np.mean(reload_lst),
                                           np.std(reload_lst)))
    print('Mean SD pred {} -- {}'.format(np.mean(pred_lst), np.std(pred_lst)))
    print('Mean SD opt {} -- {}'.format(np.mean(opt_lst), np.std(opt_lst)))
    print('Mean SD gen {} -- {}'.format(np.mean(gen_lst), np.std(gen_lst)))
    print('Mean SD scheduler {} -- {}'.format(np.mean(schedul_lst),
                                              np.std(schedul_lst)))
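
As in the earlier HDF5 test, this benchmark assumes module-level constants that are not shown. Placeholder definitions, with values chosen only for illustration:

# Hypothetical constants for the benchmark above; not taken from the source.
PATH_BIDS = 'testing_data/'
BATCH_SIZE = 4
GPU_NUMBER = 0
DROPOUT = 0.3     # drop_rate passed to HeMISUnet
BN = 0.9          # batch-norm momentum
INIT_LR = 1e-3    # initial learning rate for Adam
N_EPOCHS = 2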