def __init__(self,
                 config: cfg.Configuration,
                 subjects_train,
                 subjects_valid,
                 subjects_test,
                 is_subject_selection: bool = True,
                 collate_fn=lib_cnv.TensorFlowCollate(),
                 padding_size: tuple = (0, 0, 0)):
        """Build the patch-wise dataset, its extractors, transforms and loaders.

        Args:
            config: Configuration object. This method reads ``patch_size``,
                ``database_file``, ``maps``, ``batch_size_training`` and
                ``batch_size_testing`` from it, plus ``indices_dir`` when
                ``is_subject_selection`` is false.
            subjects_train: Subjects of the training split.
            subjects_valid: Subjects of the validation split.
            subjects_test: Subjects of the test split.
            is_subject_selection: If true, the sampler ids are selected from
                the dataset by subject. Otherwise they are loaded with
                ``pkl.load_sampler_ids`` from ``config.indices_dir``.
            collate_fn: Collate function handed to all three data loaders.
            padding_size: Padding given to the ``PadDataExtractor`` that wraps
                the image extraction for training and validation.
        """
        super().__init__()

        patch_indexing = pymia_extr.PatchWiseIndexing(
            patch_shape=config.patch_size, ignore_incomplete=False)

        # the SubjectExtractor is needed by select_indices further down
        self.dataset = pymia_extr.ParameterizableDataset(
            config.database_file, patch_indexing,
            pymia_extr.SubjectExtractor(), None)

        (self.no_subjects_train,
         self.no_subjects_valid,
         self.no_subjects_test) = (len(subjects_train), len(subjects_valid),
                                   len(subjects_test))

        if is_subject_selection:
            # one selection per split, in the order train, valid, test
            ids_train, ids_valid, ids_test = (
                pymia_extr.select_indices(
                    self.dataset, pymia_extr.SubjectSelection(subject_group))
                for subject_group in (subjects_train, subjects_valid,
                                      subjects_test))
        else:
            # the sampler ids come from previously stored indices files
            ids_train, ids_valid, ids_test = pkl.load_sampler_ids(
                config.indices_dir, pkl.PATCH_WISE_FILE_NAME, subjects_train,
                subjects_valid, subjects_test)

        # --- extractors ---------------------------------------------------
        zero_padding = (0, 0, 0)

        def image_extractor():
            # images are the only category padded with padding_size
            return pymia_extr.PadDataExtractor(
                padding=padding_size,
                extractor=pymia_extr.DataExtractor(
                    categories=(pymia_def.KEY_IMAGES, )))

        def label_extractor():
            return pymia_extr.PadDataExtractor(
                padding=zero_padding,
                extractor=pymia_extr.SelectiveDataExtractor(
                    selection=config.maps, category=pymia_def.KEY_LABELS))

        def mask_extractor(*mask_ids):
            return pymia_extr.PadDataExtractor(
                padding=zero_padding,
                extractor=pymia_extr.DataExtractor(categories=mask_ids))

        # NamesExtractor comes first: it is required by SelectiveDataExtractor
        self.extractor_train = pymia_extr.ComposeExtractor([
            pymia_extr.NamesExtractor(),
            image_extractor(),
            label_extractor(),
            mask_extractor(defs.ID_MASK_FG, defs.ID_MASK_T1H2O),
            pymia_extr.IndexingExtractor(),
            pymia_extr.ImageShapeExtractor(),
        ])

        # the validation loss needs labels and masks too, so validation
        # extracts the same entries as training plus the two ROI masks
        self.extractor_valid = pymia_extr.ComposeExtractor([
            pymia_extr.NamesExtractor(),
            image_extractor(),
            label_extractor(),
            mask_extractor(defs.ID_MASK_FG, defs.ID_MASK_T1H2O,
                           defs.ID_MASK_ROI, defs.ID_MASK_ROI_T1H2O),
            pymia_extr.IndexingExtractor(),
            pymia_extr.ImageShapeExtractor(),
        ])

        # the test extractor uses no padding and adds subject, image
        # properties and normalization information
        test_mask_ids = (defs.ID_MASK_FG, defs.ID_MASK_T1H2O,
                         defs.ID_MASK_ROI, defs.ID_MASK_ROI_T1H2O)
        self.extractor_test = pymia_extr.ComposeExtractor([
            pymia_extr.NamesExtractor(),
            pymia_extr.SubjectExtractor(),
            pymia_extr.SelectiveDataExtractor(selection=config.maps,
                                              category=pymia_def.KEY_LABELS),
            pymia_extr.DataExtractor(categories=test_mask_ids),
            pymia_extr.ImagePropertiesExtractor(),
            pymia_extr.ImageShapeExtractor(),
            ext.NormalizationExtractor(),
        ])

        # --- extraction transforms ---------------------------------------
        # After extraction the first dimension is the batch dimension,
        # e.g. shape (1, 16, 16, 4) instead of (16, 16, 4); squeeze it away.
        squeezed_entries = (pymia_def.KEY_IMAGES, pymia_def.KEY_LABELS,
                            defs.ID_MASK_FG, defs.ID_MASK_T1H2O)
        self.extraction_transform_train = pymia_tfm.Squeeze(
            entries=squeezed_entries, squeeze_axis=0)
        self.extraction_transform_valid = pymia_tfm.Squeeze(
            entries=squeezed_entries, squeeze_axis=0)
        self.extraction_transform_test = None

        # --- loaders -----------------------------------------------------
        def make_loader(batch_size, sampler):
            return pymia_extr.DataLoader(self.dataset,
                                         batch_size,
                                         sampler=sampler,
                                         collate_fn=collate_fn,
                                         num_workers=1)

        # training draws its ids through the random sampler ...
        random_train_sampler = pymia_extr.SubsetRandomSampler(ids_train)
        self.loader_train = make_loader(config.batch_size_training,
                                        random_train_sampler)

        # ... validation and testing go through the sequential sampler
        sequential_valid_sampler = pymia_extr.SubsetSequentialSampler(
            ids_valid)
        self.loader_valid = make_loader(config.batch_size_testing,
                                        sequential_valid_sampler)

        sequential_test_sampler = pymia_extr.SubsetSequentialSampler(ids_test)
        self.loader_test = make_loader(config.batch_size_testing,
                                       sequential_test_sampler)
示例#2
0
    def __init__(self,
                 config: cfg.Configuration,
                 subjects_train,
                 subjects_valid,
                 subjects_test,
                 collate_fn=pymia_cnv.TorchCollate(
                     ('images', 'labels', 'mask_fg', 'mask_t1h2o'))):
        """Build the slice-wise dataset with training and validation loaders.

        No test loader is created: ``loader_test`` is ``None`` and
        ``no_subjects_test`` is set to ``0``.

        Args:
            config: Configuration object. This method reads
                ``database_file``, ``batch_size_training`` and
                ``batch_size_testing`` from it.
            subjects_train: Subjects of the training split.
            subjects_valid: Subjects of the validation split.
            subjects_test: Accepted but not used by this method.
            collate_fn: Collate function handed to both data loaders.
        """
        super().__init__()

        slice_indexing = pymia_extr.SliceIndexing()

        # the SubjectExtractor is needed by select_indices further down
        self.dataset = pymia_extr.ParameterizableDataset(
            config.database_file, slice_indexing,
            pymia_extr.SubjectExtractor(), None)

        self.no_subjects_train = len(subjects_train)
        self.no_subjects_valid = len(subjects_valid)
        self.no_subjects_test = 0

        # sampler ids are selected by subject (training first, then validation)
        ids_train = pymia_extr.select_indices(
            self.dataset, pymia_extr.SubjectSelection(subjects_train))
        ids_valid = pymia_extr.select_indices(
            self.dataset, pymia_extr.SubjectSelection(subjects_valid))

        # --- extractors ---------------------------------------------------
        def slice_extractor():
            # training and validation extract exactly the same entries
            return pymia_extr.ComposeExtractor([
                pymia_extr.DataExtractor(categories=('images', 'labels')),
                # for SubjectAssembler (assembling)
                pymia_extr.IndexingExtractor(),
                # for SubjectAssembler (shape)
                pymia_extr.ImageShapeExtractor(),
            ])

        self.extractor_train = slice_extractor()
        self.extractor_valid = slice_extractor()

        self.extractor_test = pymia_extr.ComposeExtractor([
            pymia_extr.SubjectExtractor(),
            pymia_extr.DataExtractor(categories=('labels', )),
            pymia_extr.ImagePropertiesExtractor(),
            pymia_extr.ImageShapeExtractor(),
        ])

        # --- extraction transforms ---------------------------------------
        def slice_transform():
            # size correction, then axis permutation (2, 0, 1); the labels
            # are squeezed and cast to int64 for the PyTorch loss functions
            # before everything is turned into torch tensors
            tensor_size = (cfg.TENSOR_WIDTH, cfg.TENSOR_HEIGHT)
            return pymia_tfm.ComposeTransform([
                pymia_tfm.SizeCorrection(tensor_size),
                pymia_tfm.Permute((2, 0, 1)),
                pymia_tfm.Squeeze(entries=('labels', ), squeeze_axis=0),
                pymia_tfm.LambdaTransform(
                    lambda_fn=lambda array: array.astype(np.int64),
                    entries=('labels', )),
                pymia_tfm.ToTorchTensor(),
            ])

        self.extraction_transform_train = slice_transform()
        self.extraction_transform_valid = slice_transform()
        self.extraction_transform_test = None

        # --- loaders -----------------------------------------------------
        def make_loader(batch_size, sampler):
            return pymia_extr.DataLoader(self.dataset,
                                         batch_size,
                                         sampler=sampler,
                                         collate_fn=collate_fn,
                                         num_workers=1)

        random_train_sampler = pymia_extr.SubsetRandomSampler(ids_train)
        self.loader_train = make_loader(config.batch_size_training,
                                        random_train_sampler)

        sequential_valid_sampler = pymia_extr.SubsetSequentialSampler(
            ids_valid)
        self.loader_valid = make_loader(config.batch_size_testing,
                                        sequential_valid_sampler)

        self.loader_test = None
示例#3
0
def main(hdf_file: str):
    """Compare every sample of the dataset in ``hdf_file`` with its source files.

    For each sample, the files listed by the ``FilesExtractor`` are read again
    from disk and compared with the extracted images, labels, mask, numerical
    values and gender.

    Args:
        hdf_file: Dataset file handed to ``extr.PymiaDatasource``.

    Raises:
        ValueError: On the first entry that differs from its source file.
    """

    def read_with_channel_axis(directory, file_name):
        """Read an image file and return its array with a trailing axis added."""
        sitk_image = sitk.ReadImage(os.path.join(directory, file_name))
        # by convention the last dimension is the number of channels
        return np.expand_dims(sitk.GetArrayFromImage(sitk_image), axis=-1)

    def read_lines(directory, file_name):
        """Return all lines of a text file."""
        with open(os.path.join(directory, file_name), 'r') as handle:
            return handle.readlines()

    def value_after_colon(line):
        """Return the stripped text after the first ':' of ``line``."""
        return line.split(':')[1].strip()

    extraction_steps = [
        extr.NamesExtractor(),
        extr.DataExtractor(),
        extr.SelectiveDataExtractor(),
        extr.DataExtractor(('numerical', ), ignore_indexing=True),
        extr.DataExtractor(('gender', ), ignore_indexing=True),
        extr.DataExtractor(('mask', ), ignore_indexing=False),
        extr.SubjectExtractor(),
        extr.FilesExtractor(categories=(defs.KEY_IMAGES, defs.KEY_LABELS,
                                        'mask', 'numerical', 'gender')),
        extr.IndexingExtractor(),
        extr.ImagePropertiesExtractor(),
    ]
    extractor = extr.ComposeExtractor(extraction_steps)
    dataset = extr.PymiaDatasource(hdf_file, extr.SliceIndexing(), extractor)

    for sample_idx in range(len(dataset)):
        item = dataset[sample_idx]

        index_expr = item[defs.KEY_INDEX_EXPR]  # type: data.IndexExpression
        root = item[defs.KEY_FILE_ROOT]

        # images: one file per channel, normalized to zero mean and unit
        # standard deviation before the comparison
        image = None  # type: sitk.Image
        image_files = item[defs.KEY_PLACEHOLDER_FILES.format('images')]
        for channel, file in enumerate(image_files):
            image = sitk.ReadImage(os.path.join(root, file))
            raw = sitk.GetArrayFromImage(image).astype(np.float32)
            normalized = (raw - raw.mean()) / raw.std()
            expected_slice = normalized[index_expr.expression]
            if (expected_slice != item[defs.KEY_IMAGES][..., channel]).any():
                raise ValueError('slice not equal')

        # the properties are taken from whichever image was read last
        if conv.ImageProperties(image) != item[defs.KEY_PROPERTIES]:
            raise ValueError('image properties not equal')

        # labels
        for file in item[defs.KEY_PLACEHOLDER_FILES.format('labels')]:
            expected_slice = read_with_channel_axis(
                root, file)[index_expr.expression]
            if (expected_slice != item[defs.KEY_LABELS]).any():
                raise ValueError('slice not equal')

        # mask
        for file in item[defs.KEY_PLACEHOLDER_FILES.format('mask')]:
            expected_slice = read_with_channel_axis(
                root, file)[index_expr.expression]
            if (expected_slice != item['mask']).any():
                raise ValueError('slice not equal')

        # numerical values: age on the first line, gpa on the second
        for file in item[defs.KEY_PLACEHOLDER_FILES.format('numerical')]:
            lines = read_lines(root, file)
            age = float(value_after_colon(lines[0]))
            gpa = float(value_after_colon(lines[1]))
            if not (age == item['numerical'][0][0]
                    and gpa == item['numerical'][0][1]):
                raise ValueError('value not equal')

        # gender: third line of the file
        for file in item[defs.KEY_PLACEHOLDER_FILES.format('gender')]:
            gender = value_after_colon(read_lines(root, file)[2])
            if gender != str(item['gender'][0]):
                raise ValueError('value not equal')

    print('All test passed!')
示例#4
0
def main(config_file: str):
    """Run the slice-wise extraction, assembling and evaluation example.

    The configuration is loaded and printed. A slice-indexed dataset is then
    split by subject into a training and a testing loader. In every epoch the
    training batches are iterated without doing any work (placeholder for a
    real training step), the test labels are fed to a ``SubjectAssembler`` as
    if they were predictions, and each assembled subject is evaluated against
    its label image.

    Args:
        config_file: Path handed to ``cfg.load`` to read the configuration.
    """
    config = cfg.load(config_file, cfg.Configuration)
    print(config)

    indexing_strategy = pymia_extr.SliceIndexing()  # slice-wise extraction
    extraction_transform = None  # we do not want to apply any transformation on the slices after extraction
    # define an extractor for training, i.e. what information we would like to extract per sample
    train_extractor = pymia_extr.ComposeExtractor([pymia_extr.NamesExtractor(),
                                                   pymia_extr.DataExtractor(),
                                                   pymia_extr.SelectiveDataExtractor()])

    # define an extractor for testing, i.e. what information we would like to extract per sample
    # note that usually we don't use labels for testing, i.e. the SelectiveDataExtractor is only used for this example
    test_extractor = pymia_extr.ComposeExtractor([pymia_extr.NamesExtractor(),
                                                  pymia_extr.IndexingExtractor(),
                                                  pymia_extr.DataExtractor(),
                                                  pymia_extr.SelectiveDataExtractor(),
                                                  pymia_extr.ImageShapeExtractor()])

    # define an extractor for evaluation, i.e. what information we would like to extract per sample
    eval_extractor = pymia_extr.ComposeExtractor([pymia_extr.NamesExtractor(),
                                                  pymia_extr.SubjectExtractor(),
                                                  pymia_extr.SelectiveDataExtractor(),
                                                  pymia_extr.ImagePropertiesExtractor()])

    # define the data set
    dataset = pymia_extr.ParameterizableDataset(config.database_file,
                                                indexing_strategy,
                                                pymia_extr.SubjectExtractor(),  # for select_indices() below
                                                extraction_transform)

    # generate train / test split for data set
    # we use Subject_0, Subject_1 and Subject_2 for training and Subject_3 for testing
    sampler_ids_train = pymia_extr.select_indices(dataset,
                                                  pymia_extr.SubjectSelection(('Subject_0', 'Subject_1', 'Subject_2')))
    # the trailing comma matters: ('Subject_3') is just the string 'Subject_3',
    # whereas the training selection above receives a tuple of names
    sampler_ids_test = pymia_extr.select_indices(dataset,
                                                 pymia_extr.SubjectSelection(('Subject_3', )))

    # set up training data loader
    training_sampler = pymia_extr.SubsetRandomSampler(sampler_ids_train)
    training_loader = pymia_extr.DataLoader(dataset, config.batch_size_training, sampler=training_sampler,
                                            collate_fn=collate_batch, num_workers=1)

    # set up testing data loader
    testing_sampler = pymia_extr.SubsetSequentialSampler(sampler_ids_test)
    testing_loader = pymia_extr.DataLoader(dataset, config.batch_size_testing, sampler=testing_sampler,
                                           collate_fn=collate_batch, num_workers=1)

    # demonstrates a direct extraction; the result is not used further and the
    # name is re-bound inside the evaluation loop below
    sample = dataset.direct_extract(train_extractor, 0)  # extract a subject

    evaluator = init_evaluator()  # initialize evaluator

    for epoch in range(config.epochs):  # epochs loop
        dataset.set_extractor(train_extractor)
        for batch in training_loader:  # batches for training
            # feed_dict = batch_to_feed_dict(x, y, batch, True)  # e.g. for TensorFlow
            # train model, e.g.:
            # sess.run([train_op, loss], feed_dict=feed_dict)
            pass

        # subject assembler for testing (a fresh one per epoch)
        subject_assembler = pymia_asmbl.SubjectAssembler()

        dataset.set_extractor(test_extractor)
        for batch in testing_loader:  # batches for testing
            # feed_dict = batch_to_feed_dict(x, y, batch, False)  # e.g. for TensorFlow
            # test model, e.g.:
            # prediction = sess.run(y_model, feed_dict=feed_dict)
            prediction = np.stack(batch['labels'], axis=0)  # we use the labels as predictions such that we can validate the assembler
            subject_assembler.add_batch(prediction, batch)

        # evaluate all test images
        # NOTE(review): list() snapshots the keys before iterating - presumably
        # get_assembled_subject() removes entries from predictions; confirm
        # before simplifying this to a direct iteration.
        for subject_idx in list(subject_assembler.predictions.keys()):
            # convert prediction and labels back to SimpleITK images
            sample = dataset.direct_extract(eval_extractor, subject_idx)
            label_image = pymia_conv.NumpySimpleITKImageBridge.convert(sample['labels'],
                                                                       sample['properties'])

            assembled = subject_assembler.get_assembled_subject(sample['subject_index'])
            prediction_image = pymia_conv.NumpySimpleITKImageBridge.convert(assembled, sample['properties'])
            evaluator.evaluate(prediction_image, label_image, sample['subject'])  # evaluate prediction
示例#5
0
def main(hdf_file: str):
    """Verify the content of the dataset in ``hdf_file`` against the raw files.

    Every sample is extracted together with the names of the files it came
    from. The files are read again and compared with the extracted images,
    image properties, labels, mask, numerical values and sex.

    Args:
        hdf_file: Dataset file handed to ``extr.ParameterizableDataset``.

    Raises:
        ValueError: As soon as one entry differs from its source file.
    """

    def check_single_channel(root_dir, file_names, index_expression, stored):
        """Compare each file's indexed array with the stored entry."""
        for file_name in file_names:
            sitk_image = sitk.ReadImage(os.path.join(root_dir, file_name))
            # due to the convention of having the last dim as number of
            # channels
            array = np.expand_dims(sitk.GetArrayFromImage(sitk_image),
                                   axis=-1)
            if (array[index_expression.expression] != stored).any():
                raise ValueError('slice not equal')

    extractor = extr.ComposeExtractor([
        extr.NamesExtractor(),
        extr.DataExtractor(),
        extr.SelectiveDataExtractor(),
        extr.DataExtractor(('numerical', ), ignore_indexing=True),
        extr.DataExtractor(('sex', ), ignore_indexing=True),
        extr.DataExtractor(('mask', ), ignore_indexing=False),
        extr.SubjectExtractor(),
        extr.FilesExtractor(
            categories=('images', 'labels', 'mask', 'numerical', 'sex')),
        extr.IndexingExtractor(),
        extr.ImagePropertiesExtractor(),
    ])
    dataset = extr.ParameterizableDataset(hdf_file, extr.SliceIndexing(),
                                          extractor)

    sample_count = len(dataset)
    for sample_no in range(sample_count):
        item = dataset[sample_no]

        index_expr = item['index_expr']  # type: pymia_data.IndexExpression
        root = item['file_root']

        # each image file is one channel; it is normalized to zero mean and
        # unit standard deviation before being compared
        image = None  # type: sitk.Image
        for channel_no, file in enumerate(item['images_files']):
            image = sitk.ReadImage(os.path.join(root, file))
            voxels = sitk.GetArrayFromImage(image).astype(np.float32)
            voxels = (voxels - voxels.mean()) / voxels.std()
            mismatch = voxels[index_expr.expression] != item['images'][
                ..., channel_no]
            if mismatch.any():
                raise ValueError('slice not equal')

        # any image will do for the properties; the last one read is used
        image_properties = conv.ImageProperties(image)
        if image_properties != item['properties']:
            raise ValueError('image properties not equal')

        check_single_channel(root, item['labels_files'], index_expr,
                             item['labels'])
        check_single_channel(root, item['mask_files'], index_expr,
                             item['mask'])

        # numerical file: age on the first line, gpa on the second
        for file in item['numerical_files']:
            with open(os.path.join(root, file), 'r') as text_file:
                content = text_file.readlines()
            age, gpa = (float(content[line_no].split(':')[1].strip())
                        for line_no in (0, 1))
            if age != item['numerical'][0][0]:
                raise ValueError('value not equal')
            if gpa != item['numerical'][0][1]:
                raise ValueError('value not equal')

        # sex file: the value sits on the third line
        for file in item['sex_files']:
            with open(os.path.join(root, file), 'r') as text_file:
                third_line = text_file.readlines()[2]
            sex = third_line.split(':')[1].strip()
            if sex != str(item['sex'][0]):
                raise ValueError('value not equal')

    print('All test passed!')
示例#6
0
def main(hdf_file: str, log_dir: str):
    """Log evaluation statistics of a dummy network to TensorBoard.

    For ten epochs the images of ``hdf_file`` are fed slice-wise to a network
    that returns random labels. The predictions are assembled per subject and
    evaluated with the Dice coefficient. The mean and standard deviation of
    the results are written as scalars under ``train/``.

    The TensorBoard writer is closed when the function exits, including when
    an exception is raised inside the loop.

    Args:
        hdf_file: Dataset file handed to ``extr.PymiaDatasource``.
        log_dir: Directory in which the ``logging-example-torch`` run is
            created.
    """
    # initialize the evaluator with the metrics and the labels to evaluate
    metrics = [metric.DiceCoefficient()]
    labels = {1: 'WHITEMATTER', 2: 'GREYMATTER', 5: 'THALAMUS'}
    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # we want to log the mean and standard deviation of the metrics among all subjects of the dataset
    functions = {'MEAN': np.mean, 'STD': np.std}
    statistics_aggregator = writer.StatisticsAggregator(functions=functions)

    # initialize the data handling
    transform = tfm.Permute(permutation=(2, 0, 1), entries=(defs.KEY_IMAGES, ))
    dataset = extr.PymiaDatasource(
        hdf_file, extr.SliceIndexing(),
        extr.DataExtractor(categories=(defs.KEY_IMAGES, )), transform)
    pytorch_dataset = pymia_torch.PytorchDatasetAdapter(dataset)
    loader = torch_data.dataloader.DataLoader(pytorch_dataset,
                                              batch_size=100,
                                              shuffle=False)

    assembler = assm.SubjectAssembler(dataset)
    direct_extractor = extr.ComposeExtractor([
        extr.SubjectExtractor(),  # extraction of the subject name
        extr.ImagePropertiesExtractor(
        ),  # Extraction of image properties (origin, spacing, etc.) for storage
        extr.DataExtractor(
            categories=(defs.KEY_LABELS,
                        ))  # Extraction of "labels" entries for evaluation
    ])

    # initialize a dummy network, which returns a random prediction
    class DummyNetwork(nn.Module):
        def forward(self, x):
            # random integers in [0, 6) with a single channel and the same
            # batch size and trailing dimensions as the input
            return torch.randint(0, 6, (x.size(0), 1, *x.size()[2:]))

    dummy_network = DummyNetwork()
    torch.manual_seed(0)  # set seed for reproducibility

    nb_batches = len(loader)

    # initialize TensorBoard writer; created right before the loop so that the
    # try/finally below covers its whole lifetime
    tb = tensorboard.SummaryWriter(
        os.path.join(log_dir, 'logging-example-torch'))

    try:
        epochs = 10
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}')
            for i, batch in enumerate(loader):
                # get the data from batch and predict
                x, sample_indices = batch[defs.KEY_IMAGES], batch[
                    defs.KEY_SAMPLE_INDEX]
                prediction = dummy_network(x)

                # translate the prediction to numpy and back to (B)HWC (channel last)
                numpy_prediction = prediction.numpy().transpose((0, 2, 3, 1))

                # add the batch prediction to the assembler; the flag tells it
                # that no further batch follows in this epoch
                is_last = i == nb_batches - 1
                assembler.add_batch(numpy_prediction, sample_indices.numpy(),
                                    is_last)

                # process the subjects/images that are fully assembled
                for subject_index in assembler.subjects_ready:
                    subject_prediction = assembler.get_assembled_subject(
                        subject_index)

                    # extract the target via direct extract
                    direct_sample = dataset.direct_extract(
                        direct_extractor, subject_index)
                    reference = direct_sample[defs.KEY_LABELS]

                    # evaluate the prediction against the reference
                    evaluator.evaluate(subject_prediction[..., 0],
                                       reference[..., 0],
                                       direct_sample[defs.KEY_SUBJECT])

            # calculate mean and standard deviation of each metric
            results = statistics_aggregator.calculate(evaluator.results)
            # log to TensorBoard into category train
            for result in results:
                tb.add_scalar(f'train/{result.metric}-{result.id_}',
                              result.value, epoch)

            # clear results such that the evaluator is ready for the next evaluation
            evaluator.clear()
    finally:
        # flush pending events and release the event file
        tb.close()
示例#7
0
def main(hdf_file: str, log_dir: str):
    """Train a U-Net with TensorFlow and log loss and validation metrics.

    Subjects 1-3 of ``hdf_file`` are used for training (with augmentation)
    and subject 4 for validation. In every epoch the training loss is written
    to TensorBoard per step. The validation predictions are assembled per
    subject and evaluated with the Dice coefficient, and the mean and standard
    deviation of the results are written under ``valid/`` and printed to the
    console.

    Args:
        hdf_file: Dataset file handed to ``extr.PymiaDatasource``.
        log_dir: Directory in which the ``logging-example-tensorflow`` run is
            created.
    """
    # initialize the evaluator with the metrics and the labels to evaluate
    metrics = [metric.DiceCoefficient()]
    labels = {
        1: 'WHITEMATTER',
        2: 'GREYMATTER',
        3: 'HIPPOCAMPUS',
        4: 'AMYGDALA',
        5: 'THALAMUS'
    }
    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # we want to log the mean and standard deviation of the metrics among all subjects of the dataset
    functions = {'MEAN': np.mean, 'STD': np.std}
    statistics_aggregator = writer.StatisticsAggregator(functions=functions)
    console_writer = writer.ConsoleStatisticsWriter(functions=functions)

    # initialize TensorBoard writer
    summary_writer = tf.summary.create_file_writer(
        os.path.join(log_dir, 'logging-example-tensorflow'))

    # setup the training datasource
    train_subjects, valid_subjects = ['Subject_1', 'Subject_2',
                                      'Subject_3'], ['Subject_4']
    extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES,
                                               defs.KEY_LABELS))
    indexing_strategy = extr.SliceIndexing()

    augmentation_transforms = [
        augm.RandomElasticDeformation(),
        augm.RandomMirror()
    ]
    transforms = [tfm.Squeeze(entries=(defs.KEY_LABELS, ))]
    train_transforms = tfm.ComposeTransform(augmentation_transforms +
                                            transforms)
    train_dataset = extr.PymiaDatasource(hdf_file,
                                         indexing_strategy,
                                         extractor,
                                         train_transforms,
                                         subject_subset=train_subjects)

    # setup the validation datasource
    batch_size = 16
    valid_transforms = tfm.ComposeTransform([])
    valid_dataset = extr.PymiaDatasource(hdf_file,
                                         indexing_strategy,
                                         extractor,
                                         valid_transforms,
                                         subject_subset=valid_subjects)
    direct_extractor = extr.ComposeExtractor([
        extr.SubjectExtractor(),
        extr.ImagePropertiesExtractor(),
        extr.DataExtractor(categories=(defs.KEY_LABELS, ))
    ])
    assembler = assm.SubjectAssembler(valid_dataset)

    # tensorflow specific handling
    train_gen_fn = pymia_tf.get_tf_generator(train_dataset)
    tf_train_dataset = tf.data.Dataset.from_generator(
        generator=train_gen_fn,
        output_types={
            defs.KEY_IMAGES: tf.float32,
            defs.KEY_LABELS: tf.int64,
            defs.KEY_SAMPLE_INDEX: tf.int64
        })
    # NOTE(review): shuffle() is applied after batch(), so whole batches are
    # shuffled rather than individual samples - confirm this is intended.
    tf_train_dataset = tf_train_dataset.batch(batch_size).shuffle(
        len(train_dataset))

    valid_gen_fn = pymia_tf.get_tf_generator(valid_dataset)
    tf_valid_dataset = tf.data.Dataset.from_generator(
        generator=valid_gen_fn,
        output_types={
            defs.KEY_IMAGES: tf.float32,
            defs.KEY_LABELS: tf.int64,
            defs.KEY_SAMPLE_INDEX: tf.int64
        })
    tf_valid_dataset = tf_valid_dataset.batch(batch_size)

    u_net = unet.build_model(channels=2,
                             num_classes=6,
                             layer_depth=3,
                             filters_root=16)

    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    # the metric is never reset below, so its result is the running mean over
    # all training steps seen so far
    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)

    # batch() keeps the final, smaller batch, so the number of batches is the
    # ceiling of size / batch_size. Floor division would undercount whenever
    # the size is not a multiple of batch_size, which makes the summary steps
    # of consecutive epochs collide and flags the wrong batch as the last one.
    train_batches = (len(train_dataset) + batch_size - 1) // batch_size
    valid_batches = (len(valid_dataset) + batch_size - 1) // batch_size

    # looping over the data in the dataset
    epochs = 100
    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')

        # training
        print('training')
        for i, batch in enumerate(tf_train_dataset):
            x, y = batch[defs.KEY_IMAGES], batch[defs.KEY_LABELS]

            with tf.GradientTape() as tape:
                logits = u_net(x, training=True)
                loss = tf.keras.losses.sparse_categorical_crossentropy(
                    y, logits, from_logits=True)

            grads = tape.gradient(loss, u_net.trainable_variables)
            optimizer.apply_gradients(zip(grads, u_net.trainable_variables))

            train_loss(loss)

            with summary_writer.as_default():
                tf.summary.scalar('train/loss',
                                  train_loss.result(),
                                  step=epoch * train_batches + i)
            print(
                f'[{i + 1}/{train_batches}]\tloss: {train_loss.result().numpy()}'
            )

        # validation
        print('validation')
        for i, batch in enumerate(tf_valid_dataset):
            x, sample_indices = batch[defs.KEY_IMAGES], batch[
                defs.KEY_SAMPLE_INDEX]

            logits = u_net(x)
            # class with the highest logit, with the channel axis restored
            prediction = tf.expand_dims(tf.math.argmax(logits, -1), -1)

            numpy_prediction = prediction.numpy()

            # tell the assembler when no further batch follows in this epoch
            is_last = i == valid_batches - 1
            assembler.add_batch(numpy_prediction, sample_indices.numpy(),
                                is_last)

            for subject_index in assembler.subjects_ready:
                subject_prediction = assembler.get_assembled_subject(
                    subject_index)

                # the assembler works on valid_dataset, so the reference is
                # extracted through the same datasource
                direct_sample = valid_dataset.direct_extract(
                    direct_extractor, subject_index)
                target = direct_sample[defs.KEY_LABELS]

                # evaluate the prediction against the reference
                evaluator.evaluate(subject_prediction[..., 0], target[..., 0],
                                   direct_sample[defs.KEY_SUBJECT])

        # calculate mean and standard deviation of each metric
        results = statistics_aggregator.calculate(evaluator.results)
        # log to TensorBoard into category valid
        with summary_writer.as_default():
            for result in results:
                tf.summary.scalar(f'valid/{result.metric}-{result.id_}',
                                  result.value, epoch)

        console_writer.write(evaluator.results)

        # clear results such that the evaluator is ready for the next evaluation
        evaluator.clear()