def __call__(self, dataset: extr.Dataset,
                 data_config: cfg.DataConfiguration, **kwargs):
        """Build a sampler over the indices chosen by the configured selection strategy.

        Args:
            dataset: Dataset to sample from; must be an ``extr.ParameterizableDataset``.
            data_config: Provides ``selection_strategy`` (one parameter object or a
                list/tuple of them), ``selection_extractor`` and ``shuffle``.
            **kwargs: Must contain ``entries`` when a selection parameter has
                ``type == 'subject'``; the entries are stored as that parameter's
                ``params``.

        Returns:
            ``extr.SubsetRandomSampler`` if ``data_config.shuffle`` is truthy,
            otherwise ``extr.SubsetSequentialSampler``, over the selected indices.

        Raises:
            ValueError: If ``dataset`` has the wrong type, if ``entries`` is missing
                while a subject selection is configured, or if more than one subject
                selection parameter is configured.
        """
        if not isinstance(dataset, extr.ParameterizableDataset):
            # use the class's own name; ``cls.__class__.__name__`` would yield 'type'
            raise ValueError('dataset needs to be of type {}'.format(
                extr.ParameterizableDataset.__name__))

        selection_params = data_config.selection_strategy
        if not isinstance(selection_params, (list, tuple)):
            selection_params = [selection_params]

        subject_selection_params = [
            p for p in selection_params if p.type == 'subject'
        ]
        if subject_selection_params:
            if 'entries' not in kwargs:
                raise ValueError('"entries" needed in kwargs to build sampler')
            entries = kwargs['entries']

            # explicit check instead of ``assert``, which is removed under ``python -O``
            if len(subject_selection_params) != 1:
                raise ValueError(
                    'expected exactly one "subject" selection strategy, got {}'.format(
                        len(subject_selection_params)))
            subject_selection_param = subject_selection_params[0]  # type: cfg.ParameterClass

            # add the subject entries to the parameters of the type
            # (this mutates the parameter object held by data_config in place)
            subject_selection_param.params = entries

        selection_extractor = factory.get_extractor(
            data_config.selection_extractor)
        selection_strategy = factory.get_selection_strategy(selection_params)

        indices = self.get_indices(dataset, selection_extractor,
                                   selection_strategy)
        if data_config.shuffle:
            return extr.SubsetRandomSampler(indices=indices)
        return extr.SubsetSequentialSampler(indices=indices)
    def __call__(self, dataset: extr.Dataset,
                 data_config: cfg.DataConfiguration, **kwargs):
        """Build a sampler that iterates exactly the indices in ``kwargs['entries']``.

        ``dataset`` is not inspected. The returned sampler is an
        ``extr.SubsetRandomSampler`` when ``data_config.shuffle`` is truthy and an
        ``extr.SubsetSequentialSampler`` otherwise.

        Raises:
            ValueError: If ``entries`` is not supplied in ``kwargs``.
        """
        if 'entries' in kwargs:
            sampler_type = (extr.SubsetRandomSampler if data_config.shuffle
                            else extr.SubsetSequentialSampler)
            return sampler_type(indices=kwargs['entries'])
        raise ValueError('"entries" needed in kwargs to build sampler')
    def __init__(self,
                 config: cfg.Configuration,
                 subjects_train,
                 subjects_valid,
                 subjects_test,
                 collate_fn=None):
        """Set up the point-cloud dataset, extractors, transforms and loaders.

        Args:
            config: Configuration. Reads ``no_points``, ``database_file``,
                ``use_image_information``, ``image_information_config.spatial_size``,
                ``use_jitter``, ``use_rotation``, ``use_point_feature``,
                ``batch_size_training`` and ``batch_size_testing``.
            subjects_train: Subjects whose samples are used for training.
            subjects_valid: Subjects used for validation; the test loader reuses
                the validation sampler.
            subjects_test: Not used by this handler.
            collate_fn: Collate function shared by all loaders. Defaults to a fresh
                ``pymia_cnv.TensorFlowCollate()`` per instance. When
                ``config.use_image_information`` is set, the image-information
                category is appended to ``collate_fn.entries`` in place.
        """
        super().__init__()

        # Create the default per call. A default instance in the signature is
        # shared by all calls, and it is mutated below (``collate_fn.entries +=``),
        # so repeated instantiations kept appending to the same object.
        if collate_fn is None:
            collate_fn = pymia_cnv.TensorFlowCollate()

        indexing_strategy = PointCloudIndexing(config.no_points)

        self.dataset = pymia_extr.ParameterizableDataset(
            config.database_file,
            indexing_strategy,
            pymia_extr.SubjectExtractor(),  # for the usual select_indices
            None)

        self.no_subjects_train = len(subjects_train)
        self.no_subjects_valid = len(subjects_valid)
        # same as validation for this kind of cross validation (subjects_test unused)
        self.no_subjects_test = len(subjects_valid)

        # get sampler ids by subjects
        sampler_ids_train = pymia_extr.select_indices(
            self.dataset, pymia_extr.SubjectSelection(subjects_train))
        sampler_ids_valid = pymia_extr.select_indices(
            self.dataset, pymia_extr.SubjectSelection(subjects_valid))

        categories = ('images', 'labels')
        if config.use_image_information:
            image_information_categories = (data.KEY_IMAGE_INFORMATION, )
            categories += image_information_categories
            collate_fn.entries += image_information_categories

        # define point cloud shuffler for augmentation; sizes maps subject index -> size
        sizes = {}
        for idx in range(len(self.dataset.get_subjects())):
            sample = self.dataset.direct_extract(PointCloudSizeExtractor(),
                                                 idx)
            sizes[idx] = sample['size']
        self.point_cloud_shuffler = aug.PointCloudShuffler(sizes)
        # will only shuffle once at instantiation because shuffle() is not
        # called during training (see set_seed)
        self.point_cloud_shuffler_valid = aug.PointCloudShuffler(sizes)

        data_extractor_train = aug.ShuffledDataExtractor(
            self.point_cloud_shuffler, categories)
        data_extractor_valid = aug.ShuffledDataExtractor(
            self.point_cloud_shuffler_valid, categories)
        data_extractor_test = aug.ShuffledDataExtractor(
            self.point_cloud_shuffler_valid, categories=('indices', 'labels'))

        # define extractors
        self.extractor_train = pymia_extr.ComposeExtractor([
            pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
            pymia_extr.SubjectExtractor(),  # required for plotting
            PointCloudSizeExtractor(),  # to init_shape in SubjectAssembler
            data_extractor_train,
            pymia_extr.IndexingExtractor(),  # for SubjectAssembler (assembling)
            pymia_extr.ImageShapeExtractor()  # for SubjectAssembler (shape)
        ])

        self.extractor_valid = pymia_extr.ComposeExtractor([
            pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
            pymia_extr.SubjectExtractor(),  # required for plotting
            PointCloudSizeExtractor(),  # to init_shape in SubjectAssembler
            data_extractor_valid,
            pymia_extr.IndexingExtractor(),
            pymia_extr.ImageShapeExtractor()
        ])

        self.extractor_test = pymia_extr.ComposeExtractor([
            pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
            pymia_extr.SubjectExtractor(),
            # we need the indices, i.e. the point's coordinates,
            # to convert the point cloud back to an image
            data_extractor_test,
            # the ground truth is used for the validation at
            # config.save_validation_nth_epoch
            pymia_extr.DataExtractor(
                categories=('gt', ),
                ignore_indexing=True),
            pymia_extr.ImagePropertiesExtractor(),
            pymia_extr.ImageShapeExtractor()
        ])

        # define transforms for extraction
        # (labels: squeeze the last axis and cast to int64 for PyTorch loss functions)
        self.extraction_transform_train = pymia_tfm.ComposeTransform([
            pymia_tfm.Squeeze(entries=('labels', ),
                              squeeze_axis=-1),
            pymia_tfm.LambdaTransform(
                lambda_fn=lambda np_data: np_data.astype(np.int64),
                entries=('labels', )),
        ])

        self.extraction_transform_valid = pymia_tfm.ComposeTransform([
            pymia_tfm.Squeeze(entries=('labels', ),
                              squeeze_axis=-1),
            pymia_tfm.LambdaTransform(
                lambda_fn=lambda np_data: np_data.astype(np.int64),
                entries=('labels', ))
        ])

        if config.use_jitter:
            self.extraction_transform_train.transforms.append(
                aug.PointCloudJitter())
            self.extraction_transform_valid.transforms.append(
                aug.PointCloudJitter())

        if config.use_rotation:
            self.extraction_transform_train.transforms.append(
                aug.PointCloudRotate())
            self.extraction_transform_valid.transforms.append(
                aug.PointCloudRotate())

        # need to add probability concatenation after augmentation!
        if config.use_point_feature:
            self.extraction_transform_train.transforms.append(
                ConcatenateCoordinatesAndPointFeatures())
            self.extraction_transform_valid.transforms.append(
                ConcatenateCoordinatesAndPointFeatures())

        if config.use_image_information:
            spatial_size = config.image_information_config.spatial_size

            def slice_patches(np_data):
                # center-crop axes 1-3 to spatial_size; the indexing below
                # requires np_data to have exactly 5 axes
                z = (np_data.shape[1] - spatial_size) // 2
                y = (np_data.shape[2] - spatial_size) // 2
                x = (np_data.shape[3] - spatial_size) // 2

                np_data = np_data[:, z:(z + spatial_size),
                                  y:(y + spatial_size),
                                  x:(x + spatial_size), :]
                return np_data

            self.extraction_transform_train.transforms.append(
                pymia_tfm.LambdaTransform(
                    lambda_fn=slice_patches,
                    entries=image_information_categories))
            self.extraction_transform_valid.transforms.append(
                pymia_tfm.LambdaTransform(
                    lambda_fn=slice_patches,
                    entries=image_information_categories))

        # define loaders
        training_sampler = pymia_extr.SubsetRandomSampler(sampler_ids_train)
        self.loader_train = pymia_extr.DataLoader(self.dataset,
                                                  config.batch_size_training,
                                                  sampler=training_sampler,
                                                  collate_fn=collate_fn,
                                                  num_workers=1)

        validation_sampler = pymia_extr.SubsetSequentialSampler(
            sampler_ids_valid)
        self.loader_valid = pymia_extr.DataLoader(self.dataset,
                                                  config.batch_size_testing,
                                                  sampler=validation_sampler,
                                                  collate_fn=collate_fn,
                                                  num_workers=1)

        # testing deliberately reuses the validation sampler (see no_subjects_test)
        self.loader_test = pymia_extr.DataLoader(self.dataset,
                                                 config.batch_size_testing,
                                                 sampler=validation_sampler,
                                                 collate_fn=collate_fn,
                                                 num_workers=1)

        self.extraction_transform_test = None
Example #4
0
    def __init__(self,
                 config: cfg.Configuration,
                 subjects_train,
                 subjects_valid,
                 subjects_test,
                 collate_fn=pymia_cnv.TorchCollate(
                     ('images', 'labels', 'mask_fg', 'mask_t1h2o'))):
        """Set up slice-wise extraction and loaders for PyTorch training/validation.

        Args:
            config: Reads ``database_file``, ``batch_size_training`` and
                ``batch_size_testing``.
            subjects_train: Subjects whose slices are used for training.
            subjects_valid: Subjects whose slices are used for validation.
            subjects_test: Not used. This handler provides no test loader
                (``loader_test`` is ``None`` and ``no_subjects_test`` is 0).
            collate_fn: Collate function shared by the train and valid loaders.
        """
        super().__init__()

        slice_indexing = pymia_extr.SliceIndexing()
        self.dataset = pymia_extr.ParameterizableDataset(
            config.database_file,
            slice_indexing,
            pymia_extr.SubjectExtractor(),  # required by select_indices below
            None)

        self.no_subjects_train = len(subjects_train)
        self.no_subjects_valid = len(subjects_valid)
        self.no_subjects_test = 0

        # translate subject selections into sample indices of the dataset
        train_ids, valid_ids = [
            pymia_extr.select_indices(self.dataset,
                                      pymia_extr.SubjectSelection(subjects))
            for subjects in (subjects_train, subjects_valid)
        ]

        def assembly_extractor():
            # images/labels plus what the SubjectAssembler needs:
            # indexing (for assembling) and image shape (for the output shape)
            return pymia_extr.ComposeExtractor([
                pymia_extr.DataExtractor(categories=('images', 'labels')),
                pymia_extr.IndexingExtractor(),
                pymia_extr.ImageShapeExtractor(),
            ])

        self.extractor_train = assembly_extractor()
        self.extractor_valid = assembly_extractor()
        self.extractor_test = pymia_extr.ComposeExtractor([
            pymia_extr.SubjectExtractor(),
            pymia_extr.DataExtractor(categories=('labels', )),
            pymia_extr.ImagePropertiesExtractor(),
            pymia_extr.ImageShapeExtractor(),
        ])

        def torch_transform():
            # size correction, Permute((2, 0, 1)) (presumably channels-last ->
            # channels-first; confirm), then squeeze and int64-cast the labels
            # for PyTorch loss functions, finally convert to torch tensors
            return pymia_tfm.ComposeTransform([
                pymia_tfm.SizeCorrection((cfg.TENSOR_WIDTH, cfg.TENSOR_HEIGHT)),
                pymia_tfm.Permute((2, 0, 1)),
                pymia_tfm.Squeeze(entries=('labels', ), squeeze_axis=0),
                pymia_tfm.LambdaTransform(
                    lambda_fn=lambda np_data: np_data.astype(np.int64),
                    entries=('labels', )),
                pymia_tfm.ToTorchTensor(),
            ])

        self.extraction_transform_train = torch_transform()
        self.extraction_transform_valid = torch_transform()
        self.extraction_transform_test = None

        def make_loader(batch_size, sampler):
            return pymia_extr.DataLoader(self.dataset,
                                         batch_size,
                                         sampler=sampler,
                                         collate_fn=collate_fn,
                                         num_workers=1)

        self.loader_train = make_loader(
            config.batch_size_training,
            pymia_extr.SubsetRandomSampler(train_ids))
        self.loader_valid = make_loader(
            config.batch_size_testing,
            pymia_extr.SubsetSequentialSampler(valid_ids))
        self.loader_test = None
    def __init__(self,
                 config: cfg.Configuration,
                 subjects_train,
                 subjects_valid,
                 subjects_test,
                 is_subject_selection: bool = True,
                 collate_fn=lib_cnv.TensorFlowCollate(),
                 padding_size: tuple = (0, 0, 0)):
        """Set up patch-wise extraction and the train/valid/test loaders.

        Args:
            config: Reads ``patch_size``, ``database_file``, ``indices_dir`` (only
                when ``is_subject_selection`` is false), ``maps``,
                ``batch_size_training`` and ``batch_size_testing``.
            subjects_train: Subjects of the training split.
            subjects_valid: Subjects of the validation split.
            subjects_test: Subjects of the test split.
            is_subject_selection: If true, sampler indices are computed by subject
                selection on the dataset. Otherwise they are loaded with
                ``pkl.load_sampler_ids`` from ``config.indices_dir``.
            collate_fn: Collate function shared by all three loaders.
            padding_size: Padding passed to the image extractor of the train and
                valid pipelines; labels and masks use ``(0, 0, 0)``.
        """
        super().__init__()

        patch_indexing = pymia_extr.PatchWiseIndexing(
            patch_shape=config.patch_size, ignore_incomplete=False)
        self.dataset = pymia_extr.ParameterizableDataset(
            config.database_file,
            patch_indexing,
            pymia_extr.SubjectExtractor(),  # used by select_indices
            None)

        self.no_subjects_train = len(subjects_train)
        self.no_subjects_valid = len(subjects_valid)
        self.no_subjects_test = len(subjects_test)

        if is_subject_selection:
            # derive sampler ids from the subjects of each split
            split_ids = [
                pymia_extr.select_indices(self.dataset,
                                          pymia_extr.SubjectSelection(subjects))
                for subjects in (subjects_train, subjects_valid, subjects_test)
            ]
        else:
            # read precomputed sampler ids from the indices files
            split_ids = pkl.load_sampler_ids(
                config.indices_dir, pkl.PATCH_WISE_FILE_NAME, subjects_train,
                subjects_valid, subjects_test)
        sampler_ids_train, sampler_ids_valid, sampler_ids_test = split_ids

        def pad(extractor, padding=(0, 0, 0)):
            # wrap an extractor in a PadDataExtractor (zero padding by default)
            return pymia_extr.PadDataExtractor(padding=padding, extractor=extractor)

        def padded_images():
            return pad(pymia_extr.DataExtractor(categories=(pymia_def.KEY_IMAGES, )),
                       padding_size)

        def padded_labels():
            return pad(pymia_extr.SelectiveDataExtractor(
                selection=config.maps, category=pymia_def.KEY_LABELS))

        self.extractor_train = pymia_extr.ComposeExtractor([
            pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
            padded_images(),
            padded_labels(),
            pad(pymia_extr.DataExtractor(
                categories=(defs.ID_MASK_FG, defs.ID_MASK_T1H2O))),
            pymia_extr.IndexingExtractor(),
            pymia_extr.ImageShapeExtractor(),
        ])

        # the validation loss needs labels and masks (including the ROI masks)
        self.extractor_valid = pymia_extr.ComposeExtractor([
            pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
            padded_images(),
            padded_labels(),
            pad(pymia_extr.DataExtractor(
                categories=(defs.ID_MASK_FG, defs.ID_MASK_T1H2O,
                            defs.ID_MASK_ROI, defs.ID_MASK_ROI_T1H2O))),
            pymia_extr.IndexingExtractor(),
            pymia_extr.ImageShapeExtractor(),
        ])

        self.extractor_test = pymia_extr.ComposeExtractor([
            pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
            pymia_extr.SubjectExtractor(),
            pymia_extr.SelectiveDataExtractor(selection=config.maps,
                                              category=pymia_def.KEY_LABELS),
            pymia_extr.DataExtractor(categories=(defs.ID_MASK_FG,
                                                 defs.ID_MASK_T1H2O,
                                                 defs.ID_MASK_ROI,
                                                 defs.ID_MASK_ROI_T1H2O)),
            pymia_extr.ImagePropertiesExtractor(),
            pymia_extr.ImageShapeExtractor(),
            ext.NormalizationExtractor(),
        ])

        # After extraction the first axis is a singleton (e.g. (1, 16, 16, 4)
        # instead of (16, 16, 4)), so squeeze it away.
        # NOTE(review): the ROI masks extracted for validation are not listed
        # here and therefore keep the leading axis — confirm this is intended.
        squeeze_entries = (pymia_def.KEY_IMAGES, pymia_def.KEY_LABELS,
                           defs.ID_MASK_FG, defs.ID_MASK_T1H2O)
        self.extraction_transform_train = pymia_tfm.Squeeze(
            entries=squeeze_entries, squeeze_axis=0)
        self.extraction_transform_valid = pymia_tfm.Squeeze(
            entries=squeeze_entries, squeeze_axis=0)
        self.extraction_transform_test = None

        def make_loader(batch_size, sampler):
            return pymia_extr.DataLoader(self.dataset,
                                         batch_size,
                                         sampler=sampler,
                                         collate_fn=collate_fn,
                                         num_workers=1)

        self.loader_train = make_loader(
            config.batch_size_training,
            pymia_extr.SubsetRandomSampler(sampler_ids_train))
        self.loader_valid = make_loader(
            config.batch_size_testing,
            pymia_extr.SubsetSequentialSampler(sampler_ids_valid))
        self.loader_test = make_loader(
            config.batch_size_testing,
            pymia_extr.SubsetSequentialSampler(sampler_ids_test))
Example #6
0
def main(config_file: str):
    """Demonstrate slice-wise training, testing and evaluation loops with pymia.

    Loads the configuration from ``config_file`` and builds a slice-indexed
    dataset. It splits the dataset by subject: Subject_0, Subject_1 and
    Subject_2 for training, Subject_3 for testing. It then runs
    ``config.epochs`` epochs.

    No model is trained. Each test batch's labels are used as the
    "prediction", so the subject assembly and evaluation path can be exercised
    end to end.

    Args:
        config_file: Path of the configuration file passed to ``cfg.load``.
    """
    config = cfg.load(config_file, cfg.Configuration)
    print(config)

    indexing_strategy = pymia_extr.SliceIndexing()  # slice-wise extraction
    extraction_transform = None  # we do not want to apply any transformation on the slices after extraction
    # define an extractor for training, i.e. what information we would like to extract per sample
    train_extractor = pymia_extr.ComposeExtractor([pymia_extr.NamesExtractor(),
                                                   pymia_extr.DataExtractor(),
                                                   pymia_extr.SelectiveDataExtractor()])

    # define an extractor for testing, i.e. what information we would like to extract per sample
    # note that usually we don't use labels for testing, i.e. the SelectiveDataExtractor is only used for this example
    test_extractor = pymia_extr.ComposeExtractor([pymia_extr.NamesExtractor(),
                                                  pymia_extr.IndexingExtractor(),
                                                  pymia_extr.DataExtractor(),
                                                  pymia_extr.SelectiveDataExtractor(),
                                                  pymia_extr.ImageShapeExtractor()])

    # define an extractor for evaluation, i.e. what information we would like to extract per sample
    eval_extractor = pymia_extr.ComposeExtractor([pymia_extr.NamesExtractor(),
                                                  pymia_extr.SubjectExtractor(),
                                                  pymia_extr.SelectiveDataExtractor(),
                                                  pymia_extr.ImagePropertiesExtractor()])

    # define the data set
    dataset = pymia_extr.ParameterizableDataset(config.database_file,
                                                indexing_strategy,
                                                pymia_extr.SubjectExtractor(),  # for select_indices() below
                                                extraction_transform)

    # generate train / test split for data set
    # we use Subject_0, Subject_1 and Subject_2 for training and Subject_3 for testing
    sampler_ids_train = pymia_extr.select_indices(dataset,
                                                  pymia_extr.SubjectSelection(('Subject_0', 'Subject_1', 'Subject_2')))
    # the trailing comma is required: ('Subject_3') is just the string 'Subject_3', not a tuple
    sampler_ids_test = pymia_extr.select_indices(dataset,
                                                 pymia_extr.SubjectSelection(('Subject_3', )))

    # set up training data loader
    training_sampler = pymia_extr.SubsetRandomSampler(sampler_ids_train)
    training_loader = pymia_extr.DataLoader(dataset, config.batch_size_training, sampler=training_sampler,
                                            collate_fn=collate_batch, num_workers=1)

    # set up testing data loader
    testing_sampler = pymia_extr.SubsetSequentialSampler(sampler_ids_test)
    testing_loader = pymia_extr.DataLoader(dataset, config.batch_size_testing, sampler=testing_sampler,
                                           collate_fn=collate_batch, num_workers=1)

    # illustrates direct extraction of sample 0; the result is not used further below
    sample = dataset.direct_extract(train_extractor, 0)

    evaluator = init_evaluator()  # initialize evaluator

    for epoch in range(config.epochs):  # epochs loop
        dataset.set_extractor(train_extractor)
        for batch in training_loader:  # batches for training
            # feed_dict = batch_to_feed_dict(x, y, batch, True)  # e.g. for TensorFlow
            # train model, e.g.:
            # sess.run([train_op, loss], feed_dict=feed_dict)
            pass

        # subject assembler for testing (fresh per epoch)
        subject_assembler = pymia_asmbl.SubjectAssembler()

        dataset.set_extractor(test_extractor)
        for batch in testing_loader:  # batches for testing
            # feed_dict = batch_to_feed_dict(x, y, batch, False)  # e.g. for TensorFlow
            # test model, e.g.:
            # prediction = sess.run(y_model, feed_dict=feed_dict)
            prediction = np.stack(batch['labels'], axis=0)  # we use the labels as predictions such that we can validate the assembler
            subject_assembler.add_batch(prediction, batch)

        # evaluate all test images
        for subject_idx in list(subject_assembler.predictions.keys()):
            # convert prediction and labels back to SimpleITK images
            sample = dataset.direct_extract(eval_extractor, subject_idx)
            label_image = pymia_conv.NumpySimpleITKImageBridge.convert(sample['labels'],
                                                                       sample['properties'])

            assembled = subject_assembler.get_assembled_subject(sample['subject_index'])
            prediction_image = pymia_conv.NumpySimpleITKImageBridge.convert(assembled, sample['properties'])
            evaluator.evaluate(prediction_image, label_image, sample['subject'])  # evaluate prediction