Ejemplo n.º 1
0
def load_dataset(data_list,
                 bids_path,
                 transforms_params,
                 model_params,
                 target_suffix,
                 roi_params,
                 contrast_params,
                 slice_filter_params,
                 slice_axis,
                 multichannel,
                 dataset_type="training",
                 requires_undo=False,
                 metadata_type=None,
                 object_detection_params=None,
                 soft_gt=False,
                 device=None,
                 cuda_available=None,
                 **kwargs):
    """Build the dataset object matching the model type.

    Three dataset classes can be returned: Bids3DDataset when the model is "Modified3DUNet" or when
    ``model_params['is_2d']`` is present and False, imed_adaptative.HDF5Dataset when the model is "HeMISUnet",
    and BidsDataset in every other case.

    Args:
        data_list (list): Subject names list.
        bids_path (str): Path to the BIDS dataset.
        transforms_params (dict): Dictionary containing transformations for "training", "validation", "testing" (keys),
            eg output of imed_transforms.get_subdatasets_transforms.
        model_params (dict): Dictionary containing model parameters.
        target_suffix (list of str): List of suffixes for target masks.
        roi_params (dict): Contains ROI related parameters. The dictionary passed by the caller is left untouched:
            when 'ROICrop' is not in ``transforms_params``, a copy with "slice_filter_roi" set to None is used.
        contrast_params (dict): Contains image contrasts related parameters.
        slice_filter_params (dict): Contains slice_filter parameters, see :doc:`configuration_file` for more details.
        slice_axis (string): Choice between "axial", "sagittal", "coronal" ; controls the axis used to extract the 2D
            data.
        multichannel (bool): If True, the input contrasts are combined as input channels for the model. Otherwise, each
            contrast is processed individually (ie different sample / tensor).
        dataset_type (str): Choice between "training", "validation" or "testing". Only used in the log message.
        requires_undo (bool): If True, the transformations without undo_transform will be discarded.
        metadata_type (str): Choice between None, "mri_params", "contrasts".
        object_detection_params (dict): Object detection parameters.
        soft_gt (bool): If True, ground truths will be converted to float32, otherwise to uint8 and binarized
            (to save memory).
        device: Forwarded to imed_loader_utils.SliceFilter (2D and HeMIS datasets only).
        cuda_available: Forwarded to imed_loader_utils.SliceFilter (2D and HeMIS datasets only).
        **kwargs: Accepted for call-site convenience and ignored.

    Returns:
        Bids3DDataset, imed_adaptative.HDF5Dataset or BidsDataset: The loaded dataset.

    Note: For more details on the parameters transform_params, target_suffix, roi_params, contrast_params,
    slice_filter_params and object_detection_params see :doc:`configuration_file`.
    """
    # Compose transforms (on a deep copy, so that the caller's transform parameters are preserved).
    transform_lst, _ = imed_transforms.prepare_transforms(
        copy.deepcopy(transforms_params), requires_undo)

    # If ROICrop is not part of the transforms, then enforce no slice filtering based on ROI data.
    # A shallow copy is used so that this override does not leak into the caller's dictionary, which
    # may be shared with other calls whose transforms do include ROICrop.
    if 'ROICrop' not in transforms_params:
        roi_params = dict(roi_params, slice_filter_roi=None)

    model_name = model_params["name"]
    # Single source of truth for "is this a 3D dataset", reused for the log message below.
    is_3d = model_name == "Modified3DUNet" or not model_params.get('is_2d', True)

    if is_3d:
        dataset = Bids3DDataset(
            bids_path,
            subject_lst=data_list,
            target_suffix=target_suffix,
            roi_params=roi_params,
            contrast_params=contrast_params,
            metadata_choice=metadata_type,
            slice_axis=imed_utils.AXIS_DCT[slice_axis],
            transform=transform_lst,
            multichannel=multichannel,
            model_params=model_params,
            object_detection_params=object_detection_params,
            soft_gt=soft_gt)

    elif model_name == "HeMISUnet":
        dataset = imed_adaptative.HDF5Dataset(
            root_dir=bids_path,
            subject_lst=data_list,
            model_params=model_params,
            contrast_params=contrast_params,
            target_suffix=target_suffix,
            slice_axis=imed_utils.AXIS_DCT[slice_axis],
            transform=transform_lst,
            metadata_choice=metadata_type,
            slice_filter_fn=imed_loader_utils.SliceFilter(
                **slice_filter_params,
                device=device,
                cuda_available=cuda_available),
            roi_params=roi_params,
            object_detection_params=object_detection_params,
            soft_gt=soft_gt)
    else:
        # Task selection
        task = imed_utils.get_task(model_name)

        dataset = BidsDataset(bids_path,
                              subject_lst=data_list,
                              target_suffix=target_suffix,
                              roi_params=roi_params,
                              contrast_params=contrast_params,
                              metadata_choice=metadata_type,
                              slice_axis=imed_utils.AXIS_DCT[slice_axis],
                              transform=transform_lst,
                              multichannel=multichannel,
                              slice_filter_fn=imed_loader_utils.SliceFilter(
                                  **slice_filter_params,
                                  device=device,
                                  cuda_available=cuda_available),
                              soft_gt=soft_gt,
                              object_detection_params=object_detection_params,
                              task=task)
        dataset.load_filenames()

    if is_3d:
        # NOTE(review): "length_3D" is presumably the sub-volume shape set in 3D model configurations;
        # confirm the key. When it is absent the size is simply left out of the message.
        volume_size = model_params.get("length_3D")
        if volume_size is not None:
            print("Loaded {} volumes of size {} for the {} set.".format(
                len(dataset), volume_size, dataset_type))
        else:
            print("Loaded {} volumes for the {} set.".format(
                len(dataset), dataset_type))
    else:
        print("Loaded {} {} slices for the {} set.".format(
            len(dataset), slice_axis, dataset_type))

    return dataset
Ejemplo n.º 2
0
def test_HeMIS(p=0.0001):
    """Run a short HeMISUnet training loop on an HDF5 dataset and time each stage.

    An ``imed_adaptative.HDF5Dataset`` is built for a single subject and ``models.HeMISUnet`` is trained on it
    for ``N_EPOCHS`` epochs. At the end of every epoch, ``p`` is replaced by ``p ** (2 / 3)``, the dataset is
    updated with this value and the data loader is rebuilt. The mean and standard deviation of the time spent
    in each stage (init, load, reload, pred, opt, gen, scheduler) are printed once training is over.

    Args:
        p (float): Starting value of the quantity passed to ``dataset.update`` (presumably the probability of
            a missing contrast -- confirm against ``HDF5Dataset.update``).
    """
    print('[INFO]: Starting test ... \n')

    contrasts = ['T1w', 'T2w', 'T2star']
    subjects = ['sub-unf01']

    transform_lst, _ = imed_transforms.prepare_transforms({
        "Resample": {"wspace": 0.75, "hspace": 0.75},
        "CenterCrop": {"size": [48, 48]},
        "NumpyToTensor": {},
    })

    print('[INFO]: Creating dataset ...\n')
    model_params = {
        "name": "HeMISUnet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "missing_probability": 0.00001,
        "missing_probability_growth": 0.9,
        "contrasts": ["T1w", "T2w"],
        "ram": False,
        "path_hdf5": 'testing_data/mytestfile.hdf5',
        "csv_path": 'testing_data/hdf5.csv',
        "target_lst": ["T2w"],
        "roi_lst": ["T2w"],
    }
    dataset = imed_adaptative.HDF5Dataset(
        root_dir=PATH_BIDS,
        subject_lst=subjects,
        model_params=model_params,
        # Independent list objects are handed over so the dataset cannot alias `contrasts`.
        contrast_params={"contrast_lst": list(contrasts), "balance": {}},
        target_suffix=["_lesion-manual"],
        slice_axis=2,
        transform=transform_lst,
        metadata_choice=False,
        dim=2,
        slice_filter_fn=imed_loader_utils.SliceFilter(filter_empty_input=True,
                                                      filter_empty_mask=True),
        roi_params={"suffix": "_seg-manual", "slice_filter_roi": None})

    dataset.load_into_ram(list(contrasts))
    print("[INFO]: Dataset RAM status:")
    print(dataset.status)
    print("[INFO]: In memory Dataframe:")
    print(dataset.dataframe)

    # TODO
    # ds_train.filter_roi(nb_nonzero_thr=10)

    def build_loader():
        # The loader is rebuilt after each dataset update, hence the helper.
        return DataLoader(dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          pin_memory=True,
                          collate_fn=imed_loader_utils.imed_collate,
                          num_workers=1)

    train_loader = build_loader()

    model = models.HeMISUnet(contrasts=contrasts,
                             depth=3,
                             drop_rate=DROPOUT,
                             bn_momentum=BN)
    print(model)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(GPU_NUMBER)
        print("Using GPU number {}".format(GPU_NUMBER))
        model.cuda()

    # Initializing optimizer and scheduler
    step_scheduler_batch = False
    optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCHS)

    # One list of durations (in seconds) per timed stage.
    timings = {stage: [] for stage in
               ("init", "load", "reload", "pred", "opt", "gen", "scheduler")}

    for epoch in tqdm(range(1, N_EPOCHS + 1), desc="Training"):
        epoch_start = time.time()

        tick = time.time()
        # The learning rate is queried as part of the timed "init" stage; its value is not used.
        current_lr = scheduler.get_last_lr()[0]
        model.train()
        timings["init"].append(time.time() - tick)

        batches_done = 0
        gen_mark = 0
        for batch_idx, batch in enumerate(train_loader):
            # "gen" is the time the loader needed to produce this batch, measured from the end
            # of the previous iteration; there is no previous iteration for the first batch.
            if batch_idx > 0:
                timings["gen"].append(time.time() - gen_mark)

            tick = time.time()
            input_samples = imed_utils.unstack_tensors(batch["input"])
            gt_samples = batch["gt"]

            print(batch["input_metadata"][0][0]["missing_mod"])
            missing_mod = imed_training.get_metadata(batch["input_metadata"], model_params)

            n_entries = len(input_samples) * len(input_samples[0])
            print("Number of missing contrasts = {}."
                  .format(n_entries - missing_mod.sum()))
            print("len input = {}".format(len(input_samples)))
            print("Batch = {}, {}".format(input_samples[0].shape, gt_samples[0].shape))

            if use_cuda:
                var_input = imed_utils.cuda(input_samples)
                var_gt = imed_utils.cuda(gt_samples, non_blocking=True)
            else:
                var_input, var_gt = input_samples, gt_samples
            timings["load"].append(time.time() - tick)

            tick = time.time()
            preds = model(var_input, missing_mod)
            timings["pred"].append(time.time() - tick)

            tick = time.time()
            criterion = losses.DiceLoss()
            loss = -criterion(preds, var_gt)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()

            batches_done += 1
            timings["opt"].append(time.time() - tick)

            gen_mark = time.time()

        tick = time.time()
        if not step_scheduler_batch:
            scheduler.step()
        timings["scheduler"].append(time.time() - tick)

        tick = time.time()
        print("[INFO]: Updating Dataset")
        p = p ** (2 / 3)
        dataset.update(p=p)
        print("[INFO]: Reloading dataset")
        train_loader = build_loader()
        timings["reload"].append(time.time() - tick)

        epoch_duration = time.time() - epoch_start
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, epoch_duration))

    # Summary: one line per stage, in a fixed order.
    report = (
        ('Mean SD init {} -- {}', "init"),
        ('Mean SD load {} -- {}', "load"),
        ('Mean SD reload {} -- {}', "reload"),
        ('Mean SD pred {} -- {}', "pred"),
        ('Mean SD opt {} --  {}', "opt"),
        ('Mean SD gen {} -- {}', "gen"),
        ('Mean SD scheduler {} -- {}', "scheduler"),
    )
    for template, stage in report:
        durations = timings[stage]
        print(template.format(np.mean(durations), np.std(durations)))
Ejemplo n.º 3
0
def test_hdf5(download_data_testing_test_files, loader_parameters):
    """Check the BIDS -> HDF5 conversion, the HDF5 dataframe, the HDF5 dataset and its data loader.

    Steps: build a BIDS dataframe, convert one subject to an HDF5 file, print the file architecture,
    generate the associated dataframe, create an ``imed_adaptative.HDF5Dataset`` from it, load it into
    RAM and finally fetch one batch through a ``DataLoader`` (moved to the GPU when one is available).

    The function returns None: pytest expects test functions to return None and warns otherwise.

    Args:
        download_data_testing_test_files: Not used in the body; presumably a pytest fixture requested
            for its side effect (making the testing data available) -- confirm in the test configuration.
        loader_parameters (dict): Loader configuration; the keys "contrast_params", "target_suffix" and
            "roi_params" are read here and the whole dictionary is passed to ``BidsDataframe``.
    """
    print('[INFO]: Starting test ... \n')

    bids_df = imed_loader_utils.BidsDataframe(loader_parameters,
                                              __tmp_dir__,
                                              derivatives=True)

    contrast_params = loader_parameters["contrast_params"]
    target_suffix = loader_parameters["target_suffix"]
    roi_params = loader_parameters["roi_params"]

    train_lst = ['sub-unf01']

    training_transform_dict = {
        "Resample": {
            "wspace": 0.75,
            "hspace": 0.75
        },
        "CenterCrop": {
            "size": [48, 48]
        },
        "NumpyToTensor": {}
    }
    transform_lst, _ = imed_transforms.prepare_transforms(
        training_transform_dict)

    # NOTE(review): elsewhere in this project the metadata choice is documented as "contrasts" (plural);
    # confirm that "contrast" is the intended value here.
    bids_to_hdf5 = imed_adaptative.BIDStoHDF5(
        bids_df=bids_df,
        subject_file_lst=train_lst,
        path_hdf5=os.path.join(__data_testing_dir__, 'mytestfile.hdf5'),
        target_suffix=target_suffix,
        roi_params=roi_params,
        contrast_lst=contrast_params["contrast_lst"],
        metadata_choice="contrast",
        transform=transform_lst,
        contrast_balance={},
        slice_axis=2,
        slice_filter_fn=imed_loader_utils.SliceFilter(filter_empty_input=True,
                                                      filter_empty_mask=True))

    # Checking architecture
    def print_attrs(name, obj):
        # Callback for h5py's visititems: dump the name, type and attributes of each visited object.
        print("\nName of the object: {}".format(name))
        print("Type: {}".format(type(obj)))
        print("Including the following attributes:")
        for key, val in obj.attrs.items():
            print("    %s: %s" % (key, val))

    print('\n[INFO]: HDF5 architecture:')
    # NOTE(review): the file is opened in append mode although only reads are visible here; confirm
    # whether imed_adaptative.Dataframe needs write access before switching to read-only.
    with h5py.File(bids_to_hdf5.path_hdf5, "a") as hdf5_file:
        hdf5_file.visititems(print_attrs)
        print('\n[INFO]: HDF5 file successfully generated.')
        print('[INFO]: Generating dataframe ...\n')

        df = imed_adaptative.Dataframe(hdf5_file=hdf5_file,
                                       contrasts=['T1w', 'T2w', 'T2star'],
                                       path=os.path.join(
                                           __data_testing_dir__, 'hdf5.csv'),
                                       target_suffix=['T1w', 'T2w', 'T2star'],
                                       roi_suffix=['T1w', 'T2w', 'T2star'],
                                       dim=2,
                                       filter_slices=True)

        print(df.df)

        print('\n[INFO]: Dataframe successfully generated. ')
        print('[INFO]: Creating dataset ...\n')

        model_params = {
            "name": "HeMISUnet",
            "dropout_rate": 0.3,
            "bn_momentum": 0.9,
            "depth": 2,
            "in_channel": 1,
            "out_channel": 1,
            "missing_probability": 0.00001,
            "missing_probability_growth": 0.9,
            "contrasts": ["T1w", "T2w"],
            "ram": False,
            "path_hdf5": os.path.join(__data_testing_dir__, 'mytestfile.hdf5'),
            "csv_path": os.path.join(__data_testing_dir__, 'hdf5.csv'),
            "target_lst": ["T2w"],
            "roi_lst": ["T2w"]
        }

        dataset = imed_adaptative.HDF5Dataset(
            bids_df=bids_df,
            subject_file_lst=train_lst,
            target_suffix=target_suffix,
            slice_axis=2,
            model_params=model_params,
            contrast_params=contrast_params,
            transform=transform_lst,
            metadata_choice=False,
            dim=2,
            slice_filter_fn=imed_loader_utils.SliceFilter(
                filter_empty_input=True, filter_empty_mask=True),
            roi_params=roi_params)

        dataset.load_into_ram(['T1w', 'T2w', 'T2star'])
        print("Dataset RAM status:")
        print(dataset.status)
        print("In memory Dataframe:")
        print(dataset.dataframe)
        print('\n[INFO]: Test passed successfully. ')

        print("\n[INFO]: Starting loader test ...")

        cuda_available = torch.cuda.is_available()
        device = torch.device(
            "cuda:" + str(GPU_ID) if cuda_available else "cpu")
        if cuda_available:
            torch.cuda.set_device(device)
            print("Using GPU ID {}".format(device))

        train_loader = DataLoader(dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=False,
                                  pin_memory=True,
                                  collate_fn=imed_loader_utils.imed_collate,
                                  num_workers=1)

        # Only the first batch is fetched: this is enough to exercise the loading path.
        for batch in train_loader:
            input_samples, gt_samples = batch["input"], batch["gt"]
            print("len input = {}".format(len(input_samples)))
            print("Batch = {}, {}".format(input_samples[0].shape,
                                          gt_samples[0].shape))

            if cuda_available:
                # The transfers are performed only to check that they succeed; results are not needed.
                imed_utils.cuda(input_samples)
                imed_utils.cuda(gt_samples, non_blocking=True)

            break
        print(
            "Congrats your dataloader works! You can go home now and get a beer."
        )