Ejemplo n.º 1
0
    def __getitem__(self, index):
        """Return the specific index pair subvolume (input, ground truth).

        The whole volume of the handler referenced by ``self.indexes[index]``
        is transformed, then the subvolume delimited by the coordinates stored
        for this index is cut out of the transformed stacks.

        Args:
            index (int): Subvolume index.

        Returns:
            dict: With keys ``'input'`` (input subvolume), ``'gt'`` (ground
            truth subvolume, or ``None`` when the transform returns no GT
            stack), ``'input_metadata'`` and ``'gt_metadata'``. Every entry
            of ``'input_metadata'`` gets a ``'coord'`` list
            ``[x_min, x_max, y_min, y_max, z_min, z_max]`` used to
            reconstruct the full volume.
        """
        coord = self.indexes[index]
        seg_pair, _ = self.handlers[coord['handler_index']]

        # Clean transforms params from previous transforms, i.e. remove params
        # from previous iterations so that the coming transforms are different.
        # Work on deep copies so the handler's stored metadata is left intact:
        # each subvolume of a given handler needs its own coordinates for
        # reconstruction.
        metadata_input = imed_loader_utils.clean_metadata(
            copy.deepcopy(seg_pair['input_metadata']))
        metadata_gt = imed_loader_utils.clean_metadata(
            copy.deepcopy(seg_pair['gt_metadata']))

        # Run transforms on the input images
        stack_input, metadata_input = self.transform(sample=seg_pair['input'],
                                                     metadata=metadata_input,
                                                     data_type="im")
        # Propagate the input transform params to the ground truth metadata
        metadata_gt = imed_loader_utils.update_metadata(
            metadata_input, metadata_gt)

        # Run transforms on the ground truth
        stack_gt, metadata_gt = self.transform(sample=seg_pair['gt'],
                                               metadata=metadata_gt,
                                               data_type="gt")
        # Make sure stack_gt is binarized unless self.soft_gt is set
        if stack_gt is not None and not self.soft_gt:
            stack_gt = imed_postpro.threshold_predictions(stack_gt, thr=0.5)

        bounds = [coord["x_min"], coord["x_max"], coord["y_min"],
                  coord["y_max"], coord["z_min"], coord["z_max"]]
        # Add coordinates to metadata to reconstruct the volume. Each entry
        # gets its own list so an in-place edit of one cannot leak into others.
        for metadata in metadata_input:
            metadata['coord'] = list(bounds)

        # Crop all channels at once (leading axis kept whole). Basic slicing
        # returns a view and is clipped at the stack bounds: no zero-padding
        # is applied when the window extends past the volume.
        window = (slice(None),
                  slice(coord['x_min'], coord['x_max']),
                  slice(coord['y_min'], coord['y_max']),
                  slice(coord['z_min'], coord['z_max']))

        return {
            'input': stack_input[window],
            'gt': stack_gt[window] if stack_gt is not None else None,
            'input_metadata': metadata_input,
            'gt_metadata': metadata_gt,
        }
Ejemplo n.º 2
0
    def __getitem__(self, index):
        """Return the processed sample at ``index`` (input, GT, ROI, metadata).

        The ROI is transformed first, then the input image, then the ground
        truth; each stage inherits the transform params of the previous one
        through ``imed_loader_utils.update_metadata``.

        Args:
            index (int): Slice index.

        Returns:
            dict: Keys ``'input'``, ``'gt'``, ``'roi'``, ``'input_metadata'``,
            ``'gt_metadata'`` and ``'roi_metadata'``.
        """
        seg_slice, roi_slice = self.indexes[index]

        # Drop transform params left by earlier calls so that the coming
        # transforms are different.
        # NOTE(review): the metadata objects held in self.indexes are passed
        # to clean_metadata and the transforms directly (no copy is made).
        input_meta = imed_loader_utils.clean_metadata(
            seg_slice['input_metadata'])
        roi_meta = imed_loader_utils.clean_metadata(roi_slice['gt_metadata'])
        gt_meta = imed_loader_utils.clean_metadata(seg_slice['gt_metadata'])

        # Stage 1 - ROI. It goes first because its transform params
        # (e.g. ROICrop) are needed by the following stages.
        roi_stack, roi_meta = self.transform(sample=roi_slice["gt"],
                                             metadata=roi_meta,
                                             data_type="roi")

        # Stage 2 - input images, inheriting the ROI transform params.
        input_meta = imed_loader_utils.update_metadata(roi_meta, input_meta)
        input_stack, input_meta = self.transform(sample=seg_slice["input"],
                                                 metadata=input_meta,
                                                 data_type="im")

        # Stage 3 - ground truth metadata inherits the input transform params.
        gt_meta = imed_loader_utils.update_metadata(input_meta, gt_meta)

        if self.task == "segmentation":
            gt_stack, gt_meta = self.transform(sample=seg_slice["gt"],
                                               metadata=gt_meta,
                                               data_type="gt")
            # Binarize the ground truth unless self.soft_gt is set.
            if gt_stack is not None and not self.soft_gt:
                gt_stack = imed_postpro.threshold_predictions(gt_stack,
                                                              thr=0.5)
        else:
            # Any other task (classification): labels are never transformed.
            # ".expand(1)" adds the leading label axis so the label matches
            # the segmentation convention (n_label x h x w x d).
            # NOTE(review): assumes seg_slice["gt"][0] is a 0-d or
            # single-element numpy array; expand(1) fails otherwise.
            gt_stack = torch.from_numpy(seg_slice["gt"][0]).expand(1)

        return {
            'input': input_stack,
            'gt': gt_stack,
            'roi': roi_stack,
            'input_metadata': input_meta,
            'gt_metadata': gt_meta,
            'roi_metadata': roi_meta,
        }