Example #1
    def build(self):
        SEED = 42
        data = pd.read_csv(self.data)
        labels = data.label

        ############################################
        transforms = [
            RescaleIntensity((0, 1)),
            RandomAffine(),
            # no ToTensor step is needed: torchio images are already tensors
        ]
        transform = Compose(transforms)
        #############################################
        #############################################

        dataset_dir = self.dataset_dir
        dataset_dir = Path(dataset_dir)

        images_dir = dataset_dir
        labels_dir = dataset_dir
        image_paths = sorted(images_dir.glob('**/*.nii'))
        label_paths = sorted(labels_dir.glob('**/*.nii'))
        assert len(image_paths) == len(label_paths)

        # These two names are arbitrary
        MRI = 'features'
        BRAIN = 'targets'

        # split the dataset into training and validation, keeping each image
        # path paired with its label so that shuffling cannot misalign them
        from catalyst.utils import split_dataframe_train_test

        train_pairs, valid_pairs = split_dataframe_train_test(
            list(zip(image_paths, labels)), test_size=0.2, random_state=SEED)

        # training data
        subjects = []
        for image_path, label in train_pairs:
            subject_dict = {
                MRI: torchio.Image(image_path, torchio.INTENSITY),
                BRAIN: label,
            }
            subject = torchio.Subject(subject_dict)
            subjects.append(subject)
        # apply the transform built above (only the training set is augmented)
        train_data = torchio.ImagesDataset(subjects, transform=transform)

        # validation data
        subjects = []
        for image_path, label in valid_pairs:
            subject_dict = {
                MRI: torchio.Image(image_path, torchio.INTENSITY),
                BRAIN: label,
            }
            subject = torchio.Subject(subject_dict)
            subjects.append(subject)
        test_data = torchio.ImagesDataset(subjects)
        return train_data, test_data
Example #2
 def test_no_load_transform(self):
     with self.assertRaises(ValueError):
         dataset = torchio.ImagesDataset(
             self.subjects_list,
             load_image_data=False,
             transform=lambda x: x,
         )
Example #3
    def train_dataloader(self) -> DataLoader:
        training_transform = get_train_transforms()
        train_imageDataset = torchio.ImagesDataset(
            self.training_subjects, transform=training_transform)

        patches_training_set = torchio.Queue(
            subjects_dataset=train_imageDataset,
            # Maximum number of patches that can be stored in the queue.
            # Using a large number means that the queue needs to be filled less often,
            # but more CPU memory is needed to store the patches.
            max_length=self.max_queue_length,
            # Number of patches to extract from each volume.
            # A small number of patches ensures a large variability in the queue,
            # but training will be slower.
            samples_per_volume=self.samples_per_volume,
            #  A sampler used to extract patches from the volumes.
            sampler=torchio.sampler.UniformSampler(self.patch_size),
            num_workers=self.num_workers,
            # If True, the subjects dataset is shuffled at the beginning of each epoch,
            # i.e. when all patches from all subjects have been processed
            shuffle_subjects=False,
            # If True, patches are shuffled after filling the queue.
            shuffle_patches=True,
            verbose=True,
        )

        training_loader = DataLoader(patches_training_set,
                                     batch_size=self.hparams.batch_size)

        print(
            f"{ctime()}: number of training batches: {len(training_loader)}"
        )
        return training_loader
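
A minimal sketch of consuming the patch loader above (an illustration, not
part of the original: it assumes the subjects carry an image under the key
'img'; the real key depends on how self.training_subjects was built):

    for patches_batch in training_loader:
        # each batch is a dict; an image entry holds a (B, C, *patch_size) tensor
        inputs = patches_batch['img'][torchio.DATA]
        # ... forward pass, loss, and optimizer step would go here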
Example #4
 def train_dataloader(self):
     print(f"{ctime()}:  Creating Dataset...")
     subjects = get_cc539_subjects()
     transform = compose_transforms()
     subj_dataset = tio.ImagesDataset(subjects, transform=transform)
     print(f"{ctime()}:  Creating DataLoader...")
     training_loader = DataLoader(subj_dataset, batch_size=self.batch_size, num_workers=8)
     return training_loader
Example #5
def get_torchio_dataset(inputs, targets, transform):
    """
    Create a torchio dataset from lists of image and label files produced by a
    customised dataloader.
    """
    subjects = []
    for image_path, label_path in zip(inputs, targets):
        subject_dict = {
            MRI: torchio.Image(image_path, torchio.INTENSITY),
            LABEL: torchio.Image(label_path, torchio.LABEL),
        }
        subject = torchio.Subject(subject_dict)
        subjects.append(subject)

    if transform:
        dataset = torchio.ImagesDataset(subjects, transform=transform)
    else:
        dataset = torchio.ImagesDataset(subjects)

    return dataset
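
A hypothetical invocation of the helper above; MRI and LABEL are assumed to
be module-level constants (as in Example #1), and the glob patterns are
placeholders:

    from pathlib import Path

    image_files = sorted(Path('data').glob('*_img.nii'))  # placeholder pattern
    label_files = sorted(Path('data').glob('*_seg.nii'))  # placeholder pattern
    dataset = get_torchio_dataset(image_files, label_files, transform=None)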
Example #6
 def test_label_probabilities(self):
     labels = torch.Tensor((0, 0, 1, 1, 2, 1, 0)).reshape(1, 1, -1)
     subject = torchio.Subject(label=torchio.Image(tensor=labels,
                                                   type=torchio.LABEL))
     sample = torchio.ImagesDataset([subject])[0]
     probs_dict = {0: 0, 1: 50, 2: 25, 3: 25}
     sampler = LabelSampler(5, 'label', label_probabilities=probs_dict)
     probabilities = sampler.get_probability_map(sample)
     fixture = torch.Tensor((0, 0, 2 / 12, 2 / 12, 3 / 12, 2 / 12, 0))
     assert torch.all(probabilities.squeeze().eq(fixture))
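
The fixture above is consistent with each voxel of class c receiving
p_c / (n_c * sum(p)), where n_c is the number of voxels with label c and the
sum runs over the whole probability dictionary (here 100). A quick check of
that reading (an inference from the fixture, not documented LabelSampler
behaviour):

    probs = {0: 0, 1: 50, 2: 25, 3: 25}
    counts = {0: 3, 1: 3, 2: 1}          # voxel counts in the 7-voxel label map
    total = sum(probs.values())          # 100
    print({c: probs[c] / (counts[c] * total) for c in counts})
    # {0: 0.0, 1: 0.1666... (= 2/12), 2: 0.25 (= 3/12)}; matches the fixture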
Example #7
 def get_sample(self, image_shape):
     t1 = torch.rand(*image_shape)
     prob = torch.zeros_like(t1)
     prob[3, 3, 3] = 1
     subject = torchio.Subject(
         t1=torchio.Image(tensor=t1),
         prob=torchio.Image(tensor=prob),
     )
     sample = torchio.ImagesDataset([subject])[0]
     return sample
Example #8
def get_dataset(datasets):
    subjects = get_subjects(datasets)

    training_transform = get_train_transforms()
    validation_transform = get_val_transform()

    num_subjects = len(subjects)
    # print(f"{ctime()}: got a total of {num_subjects} subjects")
    # 90% of the (5074 + 359 + 21) subjects are used for training
    num_training_subjects = int(num_subjects * 0.9)

    training_subjects = subjects[:num_training_subjects]
    validation_subjects = subjects[num_training_subjects:]

    training_set = torchio.ImagesDataset(training_subjects,
                                         transform=training_transform)

    validation_set = torchio.ImagesDataset(validation_subjects,
                                           transform=validation_transform)
    return training_set, validation_set
Example #9
 def test_dataloader(self):
     test_transform = get_test_transform()
     # using all the data to test
     test_imageDataset = torchio.ImagesDataset(self.subjects,
                                               transform=test_transform)
     test_loader = DataLoader(
         test_imageDataset,
          batch_size=1,  # must be 1 because label sizes differ between subjects
         num_workers=10)
     print('Testing set:', len(test_imageDataset), 'subjects')
     return test_loader
Example #10
 def val_dataloader(self) -> DataLoader:
     val_transform = get_val_transform()
     val_imageDataset = torchio.ImagesDataset(self.validation_subjects,
                                              transform=val_transform)
     val_loader = DataLoader(
         val_imageDataset,
         batch_size=self.hparams.batch_size * 2,
         # num_workers=multiprocessing.cpu_count())
         num_workers=10)
     print('Validation set:', len(val_imageDataset), 'subjects')
     return val_loader
Example #11
 def test_coverage(self):
     dataset = torchio.ImagesDataset(self.subjects_list,
                                     transform=lambda x: x)
     _ = len(dataset)  # for coverage
     sample = dataset[0]
     output_path = self.dir / 'test.nii.gz'
     paths_dict = {'t1': output_path}
     dataset.save_sample(sample, paths_dict)
     nii = nib.load(str(output_path))
     ndims_output = len(nii.shape)
     ndims_sample = len(sample['t1'][DATA].shape)
     assert ndims_sample == ndims_output + 1
Example #12
 def train_dataloader(self) -> DataLoader:
     training_transform = get_train_transforms()
     train_imageDataset = torchio.ImagesDataset(
         self.training_subjects, transform=training_transform)
     training_loader = DataLoader(
         train_imageDataset,
         batch_size=self.hparams.batch_size,
          # num_workers=multiprocessing.cpu_count() caused "RuntimeError:
          # DataLoader worker (pid(s) ...) exited unexpectedly", so a fixed
          # worker count is used instead
          num_workers=10)
     print('Training set:', len(train_imageDataset), 'subjects')
     return training_loader
Example #13
    def __init__(self, root_dir, img_range=(0,0)):
        self.root_dir = root_dir
        self.img_range = img_range


        subject_lists = []

        # check whether a labels directory exists
        if self.root_dir[-1] != '/':
            self.root_dir += '/'

        self.is_labeled = os.path.isdir(self.root_dir + LABEL_DIR)

        self.files = [re.findall('[0-9]{4}', filename)[0] for filename in os.listdir(self.root_dir + TRAIN_DIR)]
        self.files = sorted(self.files, key=lambda f: int(f))

        # store all subjects in the list
        for img_num in range(img_range[0], img_range[1]+1):
            img_file = os.path.join(self.root_dir, TRAIN_DIR, IMG_PREFIX + self.files[img_num] + EXT)
            label_file = os.path.join(self.root_dir, LABEL_DIR, LABEL_PREFIX + self.files[img_num] + EXT)

            subject = torchio.Subject(
                torchio.Image('t1', img_file, torchio.INTENSITY),
                torchio.Image('label', label_file, torchio.LABEL)
            )

            subject_lists.append(subject)

            print(img_file)
            print(label_file)

        # Define transforms for data normalization and augmentation
        mtransforms = (
            ZNormalization(),
            #transforms.RandomNoise(std_range=(0, 0.25)),
            #transforms.RandomFlip(axes=(0,)),
        )

        self.subjects = torchio.ImagesDataset(subject_lists, transform=transforms.Compose(mtransforms))

        self.dataset = torchio.Queue(
            subjects_dataset=self.subjects,
            max_length=2,
            samples_per_volume=675,
            sampler_class=torchio.sampler.ImageSampler,
            patch_size=(240, 240, 3),
            num_workers=4,
            shuffle_subjects=False,
            shuffle_patches=True
        )

        print("Dataset details\n  Images: {}".format(self.img_range[1] - self.img_range[0] + 1))
Example #14
def test_unet() -> None:
    LEARN_RATE = 1e-4
    batch_size = 1
    channels = 1
    n_classes = 2
    model = UNet3d(initial_features=16)
    print(f"{ctime()}:  Built U-Net.")
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{ctime()}:  Total parameters: {total_params} ({trainable_params} trainable).")

    model.half()
    model.cuda()
    criterion = CrossEntropyLoss()
    criterion.cuda()

    print(f"{ctime()}:  Creating Dataset...")
    subjects = get_cc539_subjects()
    transform = compose_transforms()
    subj_dataset = tio.ImagesDataset(subjects, transform=transform)
    # batch size has to be 1 for variable-sized inputs
    print(f"{ctime()}:  Creating DataLoader...")
    training_loader = DataLoader(subj_dataset, batch_size=1, num_workers=8)
    print(f"{ctime()}:  Created DataLoader...")

    filterwarnings("ignore", message="Image.*has negative values.*")
    for i, subjects_batch in enumerate(training_loader):
        print(f"{ctime()}:  Augmenting input...")
        img = subjects_batch["img"][tio.DATA]
        print(f"{ctime()}:  Augmented input...")
        target = subjects_batch["label"][tio.DATA]
        # when using half precision we must also convert the inputs;
        # note that Tensor.half() returns a new tensor (it is not in-place)
        img = img.half()
        target = target.half()
        # If we don't move the input tensor to CUDA, we get:
        #    "Could not run 'aten::slow_conv3d_forward' with arguments from the
        #    'CUDATensorId' backend. 'aten::slow_conv3d_forward' is only available
        #    for these backends: [CPUTensorId, VariableTensorId]"
        # But if we do, we run out of memory (since the inputs are whole brains),
        # hence the downsampling below.
        img = img.cuda()
        x = F.interpolate(img, size=(90, 90, 90))
        print(f"{ctime()}:  Running model with batch of one brain...")
        out = model(x)
        print(f"{ctime()}:  Got output tensor from one brain...")
        # note: `target` is still a CPU tensor here while `out` lives on the
        # GPU, so this loss computation is expected to fail; the bare `raise`
        # below aborts the smoke test after one batch in any case
        loss = criterion(out, target)
        print(f"{ctime()}:  Computed loss for batch size of 1 brain...")
        raise
Example #15
def test() -> None:
    print(f"{ctime()}:  Creating Dataset...")
    subjects = get_cc539_subjects()
    transform = compose_transforms()
    subj_dataset = tio.ImagesDataset(subjects, transform=transform)
    # batch size has to be 1 for variable-sized inputs
    print(f"{ctime()}:  Creating DataLoader...")
    training_loader = DataLoader(subj_dataset, batch_size=1, num_workers=8)

    filterwarnings("ignore", message="Image.*has negative values.*")
    for i, subjects_batch in enumerate(training_loader):
        inputs = subjects_batch["img"][tio.DATA]
        target = subjects_batch["label"][tio.DATA]
        print(f"{ctime()}:  Got subject img and mask {i}")
        if COMPUTE_CANADA and not IN_COMPUTE_CAN_JOB:  # don't run much on node
            if i > 5:
                sys.exit(0)
Example #16
    def test_dataloader(self):
        test_imageDataset = torchio.ImagesDataset(self.test_subjects)

        # patches_validation_set = torchio.Queue(
        #     subjects_dataset=val_imageDataset,
        #     max_length=self.max_queue_length,
        #     samples_per_volume=self.samples_per_volume,
        #     sampler=torchio.sampler.UniformSampler(self.patch_size),
        #     num_workers=self.num_workers,
        #     shuffle_subjects=False,
        #     shuffle_patches=True,
        #     verbose=True,
        # )

        # batch_size must be 1 here because only one image at a time can be aggregated
        test_loader = DataLoader(test_imageDataset, batch_size=1)
        print(
            f"{ctime()}: number of test batches: {len(test_loader)}"
        )
        return test_loader
Example #17
    out_files = get_output_file(args.input)
    device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'

    logging.info("Loading model ...")
    # model = LitUnet.load_from_checkpoint('./log/checkpoint/1.pth')
    model = LitUnet().cuda()
    model.eval()
    logging.info("Model loaded !")

    subjects = [
        tio.Subject(img=tio.Image(path=file, type=tio.INTENSITY))
        for file in in_files
    ]

    val_transform = get_val_transform()
    val_imageDataset = tio.ImagesDataset(subjects, transform=val_transform)
    validation_loader = torch.utils.data.DataLoader(
        val_imageDataset,
        batch_size=1,
        num_workers=multiprocessing.cpu_count(),
    )

    for index, batch in enumerate(validation_loader):
        inputs = batch["img"][DATA].to(device)
        print(inputs.shape)
        # logits = model(inputs)
        # probabilities = torch.sigmoid(logits)
        #
        # mask = probabilities > 0.5
        # print(f"mask shape: {mask.shape}")
Example #18
 def iterate_dataset(self, paths_list):
     dataset = torchio.ImagesDataset(paths_list)
     for _ in dataset:
         pass
Example #19
    ])
    label_path_list = sorted([
        Path(f) for f in sorted(
            glob(f"{str(label_path_folder)}/**/*.nii.gz", recursive=True))
    ])

    subjects = []
    for img_path, label_path in zip(img_path_list, label_path_list):
        subject = tio.Subject(
            img=tio.Image(path=img_path, type=tio.INTENSITY),
            label=tio.Image(path=label_path, type=tio.LABEL),
        )
        subjects.append(subject)

    print(f"get {len(subjects)} of subject!")

    training_transform = get_train_transforms()

    training_set = tio.ImagesDataset(subjects, transform=training_transform)

    loader = DataLoader(
        training_set,
        batch_size=2,
        # num_workers=multiprocessing.cpu_count())
        num_workers=8)

    for batch_idx, batch in enumerate(loader):
        inputs, targets = _prepare_data(batch)
        print(f"inputs shape: {inputs.shape}")
        print(f"targets shape: {targets.shape}")
Example #20
 def test_no_load(self):
     dataset = torchio.ImagesDataset(self.subjects_list,
                                     load_image_data=False)
     for sample in dataset:
         pass
Example #21
 def iterate_dataset(subjects_list):
     dataset = torchio.ImagesDataset(subjects_list)
     for _ in dataset:
         pass
Example #22
def gridsampler_pipeline(
        input_array,
        entity_pts,
        patch_size=(64, 64, 64),
        patch_overlap=(0, 0, 0),
        batch_size=1,
):
    import torchio as tio
    from torchio import IMAGE, LOCATION
    from torchio.data.inference import GridAggregator, GridSampler

    logger.debug("Starting up gridsampler pipeline...")
    input_tensors = []
    output_tensors = []

    entity_pts = entity_pts.astype(np.int32)
    img_tens = torch.FloatTensor(input_array)

    one_subject = tio.Subject(
        img=tio.Image(tensor=img_tens, type=tio.INTENSITY),
        # the label image reuses the input tensor as a placeholder
        label=tio.Image(tensor=img_tens, type=tio.LABEL),
    )

    img_dataset = tio.ImagesDataset([one_subject])
    img_sample = img_dataset[-1]
    grid_sampler = GridSampler(img_sample, patch_size, patch_overlap)
    patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
    aggregator1 = GridAggregator(grid_sampler)
    aggregator2 = GridAggregator(grid_sampler)

    pipeline = Pipeline({
        "p": 1,
        "ordered_ops": [
            make_masks,
            make_features,
            make_sr,
            make_seg_sr,
            make_seg_cnn,
        ],
    })

    payloads = []

    with torch.no_grad():
        for patches_batch in patch_loader:
            locations = patches_batch[LOCATION]

            loc_arr = np.array(locations[0])
            loc = (loc_arr[0], loc_arr[1], loc_arr[2])
            logger.debug(f"Location: {loc}")

            # Prepare region data (IMG (Float Volume) AND GEOMETRY (3d Point))
            cropped_vol, offset_pts = crop_vol_and_pts_centered(
                input_array,
                entity_pts,
                location=loc,
                patch_size=patch_size,
                offset=True,
                debug_verbose=True,
            )

            plt.figure(figsize=(12, 12))
            plt.imshow(cropped_vol[cropped_vol.shape[0] // 2, :], cmap="gray")
            plt.scatter(offset_pts[:, 1], offset_pts[:, 2])

            logger.debug(f"Number of offset_pts: {offset_pts.shape}")
            logger.debug(
                f"Allocating memory for no. voxels: {cropped_vol.shape[0] * cropped_vol.shape[1] * cropped_vol.shape[2]}"
            )

            # payload = Patch(
            #    {"in_array": cropped_vol},
            #    offset_pts,
            #    None,
            # )

            payload = Patch(
                {"total_mask": np.random.random((4, 4))},
                {"total_anno": np.random.random((4, 4))},
                {"points": np.random.random((4, 3))},
            )
            pipeline.init_payload(payload)

            for step in pipeline:
                logger.debug(step)

            # Aggregation (Output: large volume aggregated from many smaller volumes)
            output_tensor = (torch.FloatTensor(
                payload.annotation_layers["total_mask"]).unsqueeze(
                    0).unsqueeze(1))
            logger.debug(
                f"Aggregating output tensor of shape: {output_tensor.shape}")
            aggregator1.add_batch(output_tensor, locations)

            output_tensor = (torch.FloatTensor(
                payload.annotation_layers["prediction"]).unsqueeze(
                    0).unsqueeze(1))
            logger.debug(
                f"Aggregating output tensor of shape: {output_tensor.shape}")
            aggregator2.add_batch(output_tensor, locations)
            payloads.append(payload)

    output_tensor1 = aggregator1.get_output_tensor()
    logger.debug(output_tensor1.shape)
    output_arr1 = np.array(output_tensor1.squeeze(0))

    output_tensor2 = aggregator2.get_output_tensor()
    logger.debug(output_tensor2.shape)
    output_arr2 = np.array(output_tensor2.squeeze(0))

    return [output_tensor1, output_tensor2], payloads
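
The pipeline above leans on torchio's patch-based inference machinery. Below
is a stripped-down sketch of the same sampler/aggregator round trip with an
identity "model" standing in for the pipeline (it assumes the same legacy
torchio API as the example; the shapes and the 'img' key are illustrative):

    import torch
    import torchio as tio
    from torch.utils.data import DataLoader
    from torchio.data.inference import GridAggregator, GridSampler

    volume = torch.rand(1, 64, 64, 64)  # toy single-channel volume
    subject = tio.Subject(img=tio.Image(tensor=volume, type=tio.INTENSITY))
    sample = tio.ImagesDataset([subject])[0]

    sampler = GridSampler(sample, (32, 32, 32), (0, 0, 0))
    loader = DataLoader(sampler, batch_size=1)
    aggregator = GridAggregator(sampler)

    with torch.no_grad():
        for patches_batch in loader:
            patch = patches_batch['img'][tio.DATA]  # (B, C, 32, 32, 32)
            aggregator.add_batch(patch, patches_batch[tio.LOCATION])

    reassembled = aggregator.get_output_tensor()  # the full volume again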
Example #23
                       type=torchio.INTENSITY),
    label=torchio.Image(tensor=torch.from_numpy(valid_seg),
                        type=torchio.LABEL),
)
# Define the transforms for the set of training patches
training_transform = Compose([
    RandomNoise(p=0.2),
    RandomFlip(axes=(0, 1, 2)),
    RandomBlur(p=0.2),
    OneOf({
        RandomAffine(): 0.8,
        RandomElasticDeformation(): 0.2,
    }, p=0.5),  # Changed from p=0.75 24/6/20
])
# Create the datasets
training_dataset = torchio.ImagesDataset(
    [train_subject], transform=training_transform)

validation_dataset = torchio.ImagesDataset(
    [valid_subject])
# Define the queue of sampled patches for training and validation
sampler = torchio.data.UniformSampler(PATCH_SIZE)
patches_training_set = torchio.Queue(
    subjects_dataset=training_dataset,
    max_length=MAX_QUEUE_LENGTH,
    samples_per_volume=TRAIN_PATCHES,
    sampler=sampler,
    num_workers=NUM_WORKERS,
    shuffle_subjects=False,
    shuffle_patches=True,
)
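
The queue above is itself a map-style dataset, so wiring it into training is
one more line (a sketch; BATCH_SIZE is a placeholder, and num_workers should
stay 0 because the queue already runs its own worker processes):

    training_loader = torch.utils.data.DataLoader(
        patches_training_set, batch_size=BATCH_SIZE, num_workers=0)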
Example #24
    return transform


if __name__ == "__main__":
    if not os.path.exists(cropped_resample_img_folder):
        os.mkdir(cropped_resample_img_folder)
    if not os.path.exists(cropped_resample_label_folder):
        os.mkdir(cropped_resample_label_folder)

    print(f"{ctime()}: starting ...")

    # for idx, mri in enumerate(get_path(datasets)):
    # if not COMPUTECANADA:
    # run_crop(idx, mri.img_path, mri.label_path, cropped_img_folder, cropped_label_folder)

    idx = 0

    subjects, visual_img_path_list, visual_label_path_list = get_subjects(use_cropped_resampled_data=False)

    transform = pre_transform()
    image_dataset = tio.ImagesDataset(subjects, transform=transform)
    loader = DataLoader(image_dataset,
                        batch_size=1)  # must be 1 because label sizes differ

    for batch in tqdm(loader):
        idx += run_resample(batch, cropped_resample_img_folder, cropped_resample_label_folder)

    print(f"{ctime()}: ending ...")
    print(f"Totally get {idx} imgs!")
    # show_save_img_and_label(img_2D, label_2D, bbox_percentile_80, bbox_kmeans, "./rectangle_image", idx)
Example #25
def generate_dataset(data_path,
                     data_root='',
                     ref_path=None,
                     nb_subjects=5,
                     resampling='mni',
                     masking_method='label'):
    """
    Generate a torchio dataset from a csv file defining paths to subjects.

    :param data_path: path to a csv file
    :param data_root:
    :param ref_path:
    :param nb_subjects:
    :param resampling:
    :param masking_method:
    :return:
    """
    ds = pd.read_csv(data_path)
    ds = ds.dropna(subset=['suj'])
    np.random.seed(0)
    subject_idx = np.random.choice(range(len(ds)), nb_subjects, replace=False)
    directories = ds.iloc[subject_idx, 1]
    dir_list = directories.tolist()
    dir_list = map(lambda partial_dir: data_root + partial_dir, dir_list)

    subject_list = []
    for directory in dir_list:
        img_path = glob.glob(os.path.join(directory, 's*.nii.gz'))[0]

        mask_path = glob.glob(os.path.join(directory, 'niw_Mean*'))[0]
        coregistration_path = glob.glob(os.path.join(directory, 'aff*.txt'))[0]

        coregistration = np.loadtxt(coregistration_path, delimiter=' ')
        coregistration = np.linalg.inv(coregistration)

        subject = torchio.Subject(
            t1=torchio.Image(img_path,
                             torchio.INTENSITY,
                             coregistration=coregistration),
            label=torchio.Image(mask_path, torchio.LABEL),
            #ref=torchio.Image(ref_path, torchio.INTENSITY)
            # coregistration=coregistration,
        )
        print('adding img {} \n mask {}\n'.format(img_path, mask_path))
        subject_list.append(subject)

    transforms = [
        # Resample(1),
        RescaleIntensity((0, 1), (0, 99), masking_method=masking_method),
    ]

    if resampling == 'mni':
        # resampling_transform = ResampleWithFoV(
        #     target=nib.load(ref_path), image_interpolation=Interpolation.BSPLINE, coregistration_key='coregistration'
        # )
        resampling_transform = Resample(
            target='ref',
            image_interpolation=Interpolation.BSPLINE,
            coregistration='coregistration')
        transforms.insert(0, resampling_transform)
    elif resampling == 'mm':
        # resampling_transform = ResampleWithFoV(target=nib.load(ref_path), image_interpolation=Interpolation.BSPLINE)
        resampling_transform = Resample(
            target=ref_path, image_interpolation=Interpolation.BSPLINE)
        transforms.insert(0, resampling_transform)

    transform = Compose(transforms)

    return torchio.ImagesDataset(subject_list, transform=transform)
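
A hypothetical call to the helper above (the paths are placeholders; the csv
is assumed to have a 'suj' column plus a directory column, as the code
implies):

    dataset = generate_dataset(
        'subjects.csv',              # placeholder csv path
        data_root='/data/',          # placeholder prefix for each directory
        nb_subjects=5,
        resampling='mm',
        ref_path='mni_ref.nii.gz',   # placeholder reference for 'mm' resampling
        masking_method='label',
    )
    sample = dataset[0]  # triggers resampling and intensity rescaling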
Example #26
def main(
    input_path,
    parcellation_path,
    output_image_path,
    output_label_path,
    min_volume,
    max_volume,
    volumes_path,
):
    """Console script for resector."""
    import torchio
    import resector
    hemispheres = 'left', 'right'
    input_path = Path(input_path)
    output_dir = input_path.parent
    stem = input_path.name.split('.nii')[0]  # assume it's a .nii file

    gm_paths = []
    resectable_paths = []
    for hemisphere in hemispheres:
        dst = output_dir / f'{stem}_gray_matter_{hemisphere}_seg.nii.gz'
        gm_paths.append(dst)
        if not dst.is_file():
            gm = resector.parcellation.get_gray_matter_mask(
                parcellation_path, hemisphere)
            resector.io.write(gm, dst)
        dst = output_dir / f'{stem}_resectable_{hemisphere}_seg.nii.gz'
        resectable_paths.append(dst)
        if not dst.is_file():
            resectable = resector.parcellation.get_resectable_hemisphere_mask(
                parcellation_path,
                hemisphere,
            )
            resector.io.write(resectable, dst)
    noise_path = output_dir / f'{stem}_noise.nii.gz'
    if not noise_path.is_file():
        resector.parcellation.make_noise_image(
            input_path,
            parcellation_path,
            noise_path,
        )

    if volumes_path is not None:
        import pandas as pd
        df = pd.read_csv(volumes_path)
        volumes = df.Volume.values
        kwargs = dict(volumes=volumes)
    else:
        kwargs = dict(volumes_range=(min_volume, max_volume))

    transform = torchio.Compose((
        torchio.ToCanonical(),
        resector.RandomResection(**kwargs),
    ))
    subject = torchio.Subject(
        image=torchio.Image(input_path, torchio.INTENSITY),
        resection_resectable_left=torchio.Image(resectable_paths[0],
                                                torchio.LABEL),
        resection_resectable_right=torchio.Image(resectable_paths[1],
                                                 torchio.LABEL),
        resection_gray_matter_left=torchio.Image(gm_paths[0], torchio.LABEL),
        resection_gray_matter_right=torchio.Image(gm_paths[1], torchio.LABEL),
        resection_noise=torchio.Image(noise_path, None),
    )
    dataset = torchio.ImagesDataset([subject], transform=transform)
    resected = dataset[0]
    dataset.save_sample(
        resected,
        dict(image=output_image_path, label=output_label_path),
    )

    return 0