Example #1
    def add_batch(self,
                  to_assemble: typing.Union[np.ndarray,
                                            typing.Dict[str, np.ndarray]],
                  sample_indices: np.ndarray,
                  last_batch=False,
                  **kwargs):
        """see :meth:`Assembler.add_batch`"""

        if not isinstance(to_assemble, dict):
            to_assemble = {'__prediction': to_assemble}

        sample_indices = sample_indices.tolist()  # convert np.int64 entries to plain Python int

        for batch_idx, sample_idx in enumerate(sample_indices):
            subject_index, index_expression = self.datasource.indices[sample_idx]

            plane_dimension = self._get_plane_dimension(index_expression)

            if plane_dimension not in self.planes:
                self.planes[plane_dimension] = SubjectAssembler(
                    self.datasource, self.zero_fn)

            indexing = index_expression.get_indexing()
            if not isinstance(indexing, list):
                indexing = [indexing]

            # full subject shape (numpy format); adjusted below to the shape required for this plane
            extractor = extr.ImagePropertyShapeExtractor(numpy_format=True)
            required_plane_shape = self.datasource.direct_extract(extractor, subject_index)[defs.KEY_SHAPE]

            index_at_plane = indexing[plane_dimension]
            if isinstance(index_at_plane, tuple):
                # a range in the off-plane direction (tuple)
                required_off_plane_size = index_at_plane[1] - index_at_plane[0]
                required_plane_shape[plane_dimension] = required_off_plane_size
            else:
                # a single slice in the off-plane direction (int)
                required_plane_shape.pop(plane_dimension)
            transform = tfm.SizeCorrection(tuple(required_plane_shape),
                                           entries=tuple(to_assemble.keys()))
            self.planes[plane_dimension].assemble_interaction_fn = ApplyTransformInteractionFn(transform)

            self.planes[plane_dimension].add_sample(to_assemble, batch_idx, sample_idx)

        ready = None
        for plane_assembler in self.planes.values():
            if last_batch:
                plane_assembler.end()

            if ready is None:
                ready = set(plane_assembler.subjects_ready)
            else:
                ready.intersection_update(plane_assembler.subjects_ready)
        self._subjects_ready = ready
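
For orientation, a minimal driver sketch for this method follows. The names `assembler`, `model`, and `loader`, the batch keys, and the call `get_assembled_subject` (provided by pymia's `SubjectAssembler`) are illustrative assumptions, not part of the example above:

# Hypothetical driver loop (sketch); all names here are assumptions.
num_batches = len(loader)
for i, batch in enumerate(loader):
    prediction = model(batch['images'])  # per-slice predictions as a np.ndarray
    assembler.add_batch(prediction,  # wrapped into {'__prediction': ...} internally
                        batch['sample_indices'],
                        last_batch=(i == num_batches - 1))
    for subject_idx in assembler.subjects_ready:  # subjects complete in every plane
        assembled = assembler.get_assembled_subject(subject_idx)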
Example #2
    def add_batch(self, to_assemble, batch: dict, last_batch=False):
        if 'index_expr' not in batch:
            raise ValueError(
                'SubjectAssembler requires "index_expr" to be extracted (use IndexingExtractor)'
            )
        if 'shape' not in batch:
            raise ValueError(
                'SubjectAssembler requires "shape" to be extracted (use ImageShapeExtractor)'
            )

        if not isinstance(to_assemble, dict):
            to_assemble = {'__prediction': to_assemble}

        for idx in range(len(batch['index_expr'])):
            index_expr = batch['index_expr'][idx]
            if isinstance(index_expr, bytes):
                # is pickled
                index_expr = pickle.loads(index_expr)
            plane_dimension = self._get_plane_dimension(index_expr)

            if plane_dimension not in self.planes:
                self.planes[plane_dimension] = SubjectAssembler(self.zero_fn)

            indexing = index_expr.get_indexing()
            if not isinstance(indexing, list):
                indexing = [indexing]

            required_plane_shape = list(batch['shape'][idx])

            index_at_plane = indexing[plane_dimension]
            if isinstance(index_at_plane, tuple):
                # a range in the off-plane direction (tuple)
                required_off_plane_size = index_at_plane[1] - index_at_plane[0]
                required_plane_shape[plane_dimension] = required_off_plane_size
            else:
                # a single slice in the off-plane direction (int)
                required_plane_shape.pop(plane_dimension)
            transform = tfm.SizeCorrection(tuple(required_plane_shape),
                                           entries=tuple(to_assemble.keys()))
            self.planes[plane_dimension].on_sample_fn = TransformSampleFn(transform)

            self.planes[plane_dimension].add_sample(to_assemble, batch, idx)

        ready = None
        for plane_assembler in self.planes.values():
            if last_batch:
                plane_assembler.end()

            if ready is None:
                ready = set(plane_assembler.subjects_ready)
            else:
                ready.intersection_update(plane_assembler.subjects_ready)
        self._subjects_ready = ready
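
The `bytes` branch above exists because index expressions may arrive pickled (e.g. when transported inside a collated batch). A self-contained round-trip of that mechanism, using a plain tuple as a stand-in for the actual index expression object:

import pickle

index_expr = ('slice', 17)  # stand-in for a real index expression object
stored = pickle.dumps(index_expr)  # as it might be stored in a batch
restored = pickle.loads(stored) if isinstance(stored, bytes) else stored
assert restored == index_expr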
Example #3
import numpy as np

import pymia.data.transformation as tfm


def size_correction(params: dict):
    data = params['__prediction']
    idx = params['batch_idx']
    batch = params['batch']

    data = np.transpose(data, (1, 2, 0))  # transpose back from PyTorch channels-first convention
    data = np.argmax(data, -1)  # convert class scores to class labels

    # correct size: crop/pad to the subject's in-plane shape (drop the leading slice dimension)
    correct_shape = batch['shape'][idx][1:]
    transform = tfm.SizeCorrection(correct_shape, 0, entries=('data',))
    data = transform({'data': data})['data']

    data = np.expand_dims(data, -1)  # for dataset convention
    return data, batch
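
This function has the shape of the per-sample hook used in Example #2: it receives the raw prediction, undoes the PyTorch channels-first layout, collapses class scores to labels, and crops/pads to the subject's in-plane shape. A direct invocation sketch with placeholder shapes (the array sizes below are illustrative assumptions):

# Direct invocation (sketch); shapes are illustrative placeholders.
logits = np.random.rand(4, 32, 32).astype(np.float32)  # (classes, H, W), PyTorch-style
batch = {'shape': [(10, 40, 40)]}  # full subject shape for sample 0
data, batch = size_correction({'__prediction': logits,
                               'batch_idx': 0,
                               'batch': batch})
print(data.shape)  # (40, 40, 1): labels, size-corrected, trailing channel added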
Example #4
    def __init__(self,
                 config: cfg.Configuration,
                 subjects_train,
                 subjects_valid,
                 subjects_test,
                 collate_fn=pymia_cnv.TorchCollate(
                     ('images', 'labels', 'mask_fg', 'mask_t1h2o'))):
        super().__init__()

        indexing_strategy = pymia_extr.SliceIndexing()

        self.dataset = pymia_extr.ParameterizableDataset(
            config.database_file,
            indexing_strategy,
            pymia_extr.SubjectExtractor(),  # for the usual select_indices
            None)  # no extraction transform at construction time

        self.no_subjects_train = len(subjects_train)
        self.no_subjects_valid = len(subjects_valid)
        self.no_subjects_test = 0

        # get sampler ids by subjects
        sampler_ids_train = pymia_extr.select_indices(
            self.dataset, pymia_extr.SubjectSelection(subjects_train))
        sampler_ids_valid = pymia_extr.select_indices(
            self.dataset, pymia_extr.SubjectSelection(subjects_valid))

        # define extractors
        self.extractor_train = pymia_extr.ComposeExtractor([
            pymia_extr.DataExtractor(categories=('images', 'labels')),
            pymia_extr.IndexingExtractor(),  # for SubjectAssembler (assembling)
            pymia_extr.ImageShapeExtractor()  # for SubjectAssembler (shape)
        ])

        self.extractor_valid = pymia_extr.ComposeExtractor([
            pymia_extr.DataExtractor(categories=('images', 'labels')),
            pymia_extr.IndexingExtractor(),  # for SubjectAssembler (assembling)
            pymia_extr.ImageShapeExtractor()  # for SubjectAssembler (shape)
        ])

        self.extractor_test = pymia_extr.ComposeExtractor([
            pymia_extr.SubjectExtractor(),
            pymia_extr.DataExtractor(categories=('labels', )),
            pymia_extr.ImagePropertiesExtractor(),
            pymia_extr.ImageShapeExtractor()
        ])

        # define transforms for extraction
        self.extraction_transform_train = pymia_tfm.ComposeTransform([
            pymia_tfm.SizeCorrection((cfg.TENSOR_WIDTH, cfg.TENSOR_HEIGHT)),
            pymia_tfm.Permute((2, 0, 1)),
            pymia_tfm.Squeeze(entries=('labels',), squeeze_axis=0),  # for PyTorch loss functions
            pymia_tfm.LambdaTransform(lambda_fn=lambda np_data: np_data.astype(np.int64),
                                      entries=('labels',)),  # for PyTorch loss functions
            pymia_tfm.ToTorchTensor()
        ])

        self.extraction_transform_valid = pymia_tfm.ComposeTransform([
            pymia_tfm.SizeCorrection((cfg.TENSOR_WIDTH, cfg.TENSOR_HEIGHT)),
            pymia_tfm.Permute((2, 0, 1)),
            pymia_tfm.Squeeze(entries=('labels',), squeeze_axis=0),  # for PyTorch loss functions
            pymia_tfm.LambdaTransform(lambda_fn=lambda np_data: np_data.astype(np.int64),
                                      entries=('labels',)),  # for PyTorch loss functions
            pymia_tfm.ToTorchTensor()
        ])

        self.extraction_transform_test = None

        # define loaders
        training_sampler = pymia_extr.SubsetRandomSampler(sampler_ids_train)
        self.loader_train = pymia_extr.DataLoader(self.dataset,
                                                  config.batch_size_training,
                                                  sampler=training_sampler,
                                                  collate_fn=collate_fn,
                                                  num_workers=1)

        validation_sampler = pymia_extr.SubsetSequentialSampler(sampler_ids_valid)
        self.loader_valid = pymia_extr.DataLoader(self.dataset,
                                                  config.batch_size_testing,
                                                  sampler=validation_sampler,
                                                  collate_fn=collate_fn,
                                                  num_workers=1)

        self.loader_test = None
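
Finally, a short usage sketch for this handler. `DataHandler` is an assumed name for the class whose `__init__` is shown above, and the `set_extractor`/`set_transform` calls are assumed to be how the pymia `ParameterizableDataset` is reconfigured per phase; the loop itself touches only attributes defined above:

# Hypothetical usage (sketch); `DataHandler` and the set_* calls are assumptions.
data_handler = DataHandler(config, subjects_train, subjects_valid, subjects_test=[])

data_handler.dataset.set_extractor(data_handler.extractor_train)
data_handler.dataset.set_transform(data_handler.extraction_transform_train)
for batch in data_handler.loader_train:
    images, labels = batch['images'], batch['labels']  # torch tensors after ToTorchTensor
    # ... forward/backward pass ...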