Example #1
def main(hdf_file, plot_dir):
    os.makedirs(plot_dir, exist_ok=True)

    # setup the datasource
    extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES, defs.KEY_LABELS))
    indexing_strategy = extr.SliceIndexing()
    dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor)

    seed = 1
    np.random.seed(seed)
    sample_idx = 55

    # set up transformations without augmentation
    transforms_augmentation = []
    transforms_before_augmentation = [tfm.Permute(permutation=(2, 0, 1)), ]  # to have the channel-dimension first
    transforms_after_augmentation = [tfm.Squeeze(entries=(defs.KEY_LABELS,)), ]  # get rid of the channel-dimension for the labels
    train_transforms = tfm.ComposeTransform(transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
    dataset.set_transform(train_transforms)
    sample = dataset[sample_idx]
    plot_sample(plot_dir, 'none', sample)

    # augmentation with pymia
    transforms_augmentation = [augm.RandomRotation90(axes=(-2, -1)), augm.RandomMirror()]
    train_transforms = tfm.ComposeTransform(
        transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
    dataset.set_transform(train_transforms)
    sample = dataset[sample_idx]
    plot_sample(plot_dir, 'pymia', sample)

    # augmentation with batchgenerators
    transforms_augmentation = [BatchgeneratorsTransform([
        bg_tfm.spatial_transforms.MirrorTransform(axes=(0, 1), data_key=defs.KEY_IMAGES, label_key=defs.KEY_LABELS),
        bg_tfm.noise_transforms.GaussianBlurTransform(blur_sigma=(0.2, 1.0), data_key=defs.KEY_IMAGES, label_key=defs.KEY_LABELS),
    ])]
    train_transforms = tfm.ComposeTransform(
        transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
    dataset.set_transform(train_transforms)
    sample = dataset[sample_idx]
    plot_sample(plot_dir, 'batchgenerators', sample)

    # augmentation with TorchIO
    transforms_augmentation = [TorchIOTransform(
        [tio.RandomFlip(axes='LR', flip_probability=1.0, keys=(defs.KEY_IMAGES, defs.KEY_LABELS), seed=seed),
         tio.RandomAffine(scales=(0.9, 1.2), degrees=10, isotropic=False, default_pad_value='otsu',
                          image_interpolation='NEAREST', keys=(defs.KEY_IMAGES, defs.KEY_LABELS), seed=seed),
         ])]
    train_transforms = tfm.ComposeTransform(
        transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
    dataset.set_transform(train_transforms)
    sample = dataset[sample_idx]
    plot_sample(plot_dir, 'torchio', sample)
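
The plot_sample helper is not part of the excerpt. A minimal sketch of what it could look like, assuming the sample is a dict of numpy arrays (channel-first images after the Permute transform) and that matplotlib is available; the file-name scheme is made up:

import os

import matplotlib.pyplot as plt


def plot_sample(plot_dir, name, sample):
    """Saves the first image channel and the label map side by side (hypothetical helper)."""
    image = sample[defs.KEY_IMAGES][0]  # first channel; channel dimension is first after Permute
    label = sample[defs.KEY_LABELS]

    fig, axes = plt.subplots(1, 2)
    axes[0].imshow(image, cmap='gray')
    axes[0].set_title('image')
    axes[1].imshow(label)
    axes[1].set_title('labels')
    fig.savefig(os.path.join(plot_dir, '{}.png'.format(name)))
    plt.close(fig)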
Example #2
def main(hdf_file: str, data_dir: str):
    keys = [
        FileTypes.T1, FileTypes.T2, FileTypes.GT, FileTypes.MASK,
        FileTypes.AGE, FileTypes.GPA, FileTypes.SEX
    ]
    crawler = pymia_load.FileSystemDataCrawler(data_dir, keys,
                                               DataSetFilePathGenerator(),
                                               DirectoryFilter(), '.mha')

    subjects = [
        Subject(id_, file_dict) for id_, file_dict in crawler.data.items()
    ]

    if os.path.exists(hdf_file):
        os.remove(hdf_file)

    with pymia_crt.get_writer(hdf_file) as writer:
        callbacks = pymia_crt.get_default_callbacks(writer)

        # normalize the images and unsqueeze the labels and mask.
        # Unsqueezing is needed because, by convention, the number of channels is the last dimension.
        # I.e., the shape is 10 x 256 x 256 before the unsqueeze operation and 10 x 256 x 256 x 1 after.
        transform = pymia_tfm.ComposeTransform([
            pymia_tfm.IntensityNormalization(loop_axis=3,
                                             entries=('images', )),
            pymia_tfm.UnSqueeze(entries=('labels', 'mask'))
        ])

        traverser = pymia_crt.SubjectFileTraverser()
        traverser.traverse(subjects,
                           callback=callbacks,
                           load=LoadData(),
                           transform=transform)
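
Once the dataset file is written, it can be consumed with pymia's extraction API. A minimal read-back sketch, assuming the entry names ('images', 'labels') created above and the pymia >= 0.3 naming; hdf_file points at the file created above:

import pymia.data.extraction as extr

dataset = extr.PymiaDatasource(hdf_file, extr.SliceIndexing(), extr.DataExtractor(categories=('images', 'labels')))
sample = dataset[0]
print(sample['images'].shape, sample['labels'].shape)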
Example #3
def main(hdf_file: str, data_dir: str):
    if os.path.exists(hdf_file):
        raise RuntimeError(
            'Dataset file "{}" already exists'.format(hdf_file))

    # let's create some sample data
    np.random.seed(42)  # for reproducible sample data
    create_sample_data(data_dir, no_subjects=8)

    # collect the data
    collector = Collector(data_dir)
    subjects = collector.get_subject_files()
    for subject in subjects:
        print(subject.subject)

    # get the values for parametric map normalization
    min_, max_ = get_normalization_values(subjects, LoadData())

    with pymia_crt.Hdf5Writer(hdf_file) as writer:
        callbacks = pymia_crt.get_default_callbacks(writer)
        callbacks.callbacks.append(
            WriteNormalizationCallback(writer, min_, max_))

        transform = pymia_tfm.ComposeTransform([
            tfm.MRFMaskedLabelNormalization(min_, max_, data.ID_MASK_FG),
            pymia_tfm.IntensityNormalization(loop_axis=4,
                                             entries=(pymia_def.KEY_IMAGES, )),
        ])

        traverser = pymia_crt.SubjectFileTraverser()
        traverser.traverse(subjects,
                           callback=callbacks,
                           load=LoadData(),
                           transform=transform,
                           concat_fn=concat)
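
WriteNormalizationCallback is project code, not part of pymia. A rough sketch of how such a callback might look, assuming pymia's creation Callback class is re-exported at the package level, has an on_start hook, and the writer exposes write(entry, data); the entry names are made up:

class WriteNormalizationCallback(pymia_crt.Callback):
    """Writes the normalization range once to the dataset file (sketch)."""

    def __init__(self, writer, min_, max_):
        self.writer = writer
        self.min_ = min_
        self.max_ = max_

    def on_start(self, params: dict):
        self.writer.write('norm/min', self.min_)  # 'norm/min' is a made-up entry name
        self.writer.write('norm/max', self.max_)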
Example #4
def get_transform(transform_params: typing.Union[cfg.DictableParameterExt, list, tuple]) -> tfm.Transform:
    if isinstance(transform_params, (list, tuple)):
        transforms = []
        for transform_param in transform_params:
            transforms.append(get_transform(transform_param))
        return tfm.ComposeTransform(transforms)

    if transform_params.type not in transform_registry:
        raise ValueError('transform type "{}" unknown'.format(transform_params.type))
    return transform_registry[transform_params.type](**transform_params.params)
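
A usage sketch: the non-list branch only needs an object exposing .type and .params, so a SimpleNamespace can stand in for cfg.DictableParameterExt; the registry entry below is hypothetical (augm is pymia.data.augmentation, as in Example #1):

import types

# suppose the registry contains an entry like this (hypothetical):
# transform_registry = {'rotation90': augm.RandomRotation90}

param = types.SimpleNamespace(type='rotation90', params={'axes': (-2, -1)})
transform = get_transform([param, param])  # -> tfm.ComposeTransform with two RandomRotation90 transforms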
Example #5
def occlude(tester, result_dir: str, temporal_dim: int):
    for temporal_idx in range(temporal_dim):
        print('Occlude temporal dimension {}...'.format(temporal_idx))
        # modify extraction transform for occlusion experiment
        tester.data_handler.extraction_transform_valid = pymia_tfm.ComposeTransform(
            [
                pymia_tfm.Squeeze(
                    entries=(pymia_def.KEY_IMAGES, pymia_def.KEY_LABELS,
                             data.ID_MASK_FG, data.ID_MASK_T1H2O),
                    squeeze_axis=0),
                tfm.MRFTemporalOcclusion(temporal_idx=temporal_idx,
                                         entries=(pymia_def.KEY_IMAGES, ))
            ])
        tester.result_dir = os.path.join(result_dir,
                                         'occlusion{}'.format(temporal_idx))
        tester.predict()
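
tfm.MRFTemporalOcclusion is project-specific: the idea is to zero out one temporal frame of the input so that frame's contribution to the prediction can be measured. A minimal sketch of such a transform, assuming the temporal dimension is the last axis of the image array:

import pymia.data.transformation as pymia_tfm


class TemporalOcclusion(pymia_tfm.Transform):
    """Zeros out one temporal frame of the selected entries (sketch)."""

    def __init__(self, temporal_idx: int, entries=('images', )):
        self.temporal_idx = temporal_idx
        self.entries = entries

    def __call__(self, sample: dict) -> dict:
        for entry in self.entries:
            sample[entry][..., self.temporal_idx] = 0  # assumption: temporal axis is last
        return sample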
Example #6
def build_brats_dataset(params: BuildParameters):
    collector = collect.Brats17Collector(
        params.in_dir)  # BraTS 17 is the same dataset as BraTS 18
    subject_files = collector.get_subject_files()

    if params.split_file is not None:
        if params.is_train_data:
            train_subjects, valid_subjects, _ = split.load_split(
                params.split_file)
            selection = train_subjects + valid_subjects
        else:
            _, _, selection = split.load_split(params.split_file)

        subject_files = [sf for sf in subject_files if sf.subject in selection]
        assert len(subject_files) == len(selection)

    # sort the subject files according to the subject name (identifier)
    subject_files.sort(key=lambda sf: sf.subject)

    if params.prediction_path is not None:
        subject_files = params.add_prediction_fn(subject_files,
                                                 params.prediction_path)

    fh.create_dir_if_not_exists(params.out_file, is_file=True)
    fh.remove_if_exists(params.out_file)

    with crt.get_writer(params.out_file) as writer:
        callbacks = [
            crt.MonitoringCallback(),
            crt.WriteNamesCallback(writer),
            crt.WriteFilesCallback(writer),
            crt.WriteDataCallback(writer),
            crt.WriteSubjectCallback(writer),
            crt.WriteImageInformationCallback(writer),
        ]

        has_grade = params.in_dir.endswith('Training')
        if has_grade:
            callbacks.append(WriteGradeCallback(writer))
        callback = crt.ComposeCallback(callbacks)

        traverser = crt.SubjectFileTraverser()
        traverser.traverse(subject_files,
                           callback=callback,
                           load=LoadSubject(),
                           transform=tfm.ComposeTransform(params.transforms))
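
BuildParameters is not shown in the excerpt. Reconstructing its fields from how they are used above, a dataclass sketch (types and defaults assumed):

import dataclasses
import typing


@dataclasses.dataclass
class BuildParameters:
    in_dir: str
    out_file: str
    transforms: list
    is_train_data: bool = True
    split_file: str = None
    prediction_path: str = None
    add_prediction_fn: typing.Callable = None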
Example #7
def main(hdf_file: str, data_dir: str):
    if os.path.exists(hdf_file):
        raise RuntimeError('Dataset file "{}" already exists'.format(hdf_file))

    # we threshold I_Q at probability 0.1
    probability_threshold = 0.1
    # we use image information extracted from 5^3 neighborhood around each point
    spatial_size = 5

    # let's create some sample data
    np.random.seed(42)  # for reproducible sample data
    create_sample_data(data_dir, no_subjects=8)

    # collect the data
    collector = Collector(data_dir)
    subjects = collector.get_subject_files()

    for subject in subjects:
        print(subject.subject)
    print('Total of {} subjects'.format(len(subjects)))

    os.makedirs(os.path.dirname(hdf_file), exist_ok=True)

    with pymia_crt.get_writer(hdf_file) as writer:
        callbacks = pymia_crt.get_default_callbacks(writer)

        transform = pymia_tfm.ComposeTransform([
            pymia_tfm.LambdaTransform(lambda_fn=lambda np_data: np_data.astype(np.float32),
                                      entries=('images',
                                               data.KEY_IMAGE_INFORMATION,
                                               )),
            pymia_tfm.LambdaTransform(loop_axis=1, entries=('images', ), lambda_fn=normalize_unit_cube),
            pymia_tfm.IntensityNormalization(loop_axis=-1, entries=(data.KEY_IMAGE_INFORMATION,))
        ])

        traverser = pymia_crt.SubjectFileTraverser()
        traverser.traverse(subjects, callback=callbacks, load=LoadData(probability_threshold, spatial_size),
                           transform=transform, concat_fn=concat)
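
normalize_unit_cube is a project helper that is not shown. For point clouds it typically maps the coordinates into [0, 1] per axis; a plausible minimal version (the project's exact normalization may differ):

def normalize_unit_cube(np_data: np.ndarray) -> np.ndarray:
    """Scales the coordinates into the unit cube [0, 1] per axis (sketch)."""
    min_ = np_data.min(axis=0)
    range_ = np_data.max(axis=0) - min_
    return (np_data - min_) / np.where(range_ == 0, 1, range_)  # guard against zero ranges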
Example #8
def main(hdf_file, log_dir):
    # initialize the evaluator with the metrics and the labels to evaluate
    metrics = [metric.DiceCoefficient()]
    labels = {1: 'WHITEMATTER',
              2: 'GREYMATTER',
              3: 'HIPPOCAMPUS',
              4: 'AMYGDALA',
              5: 'THALAMUS'}
    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # we want to log the mean and standard deviation of the metrics among all subjects of the dataset
    functions = {'MEAN': np.mean, 'STD': np.std}
    statistics_aggregator = writer.StatisticsAggregator(functions=functions)
    console_writer = writer.ConsoleStatisticsWriter(functions=functions)

    # initialize TensorBoard writer
    tb = tensorboard.SummaryWriter(os.path.join(log_dir, 'logging-example-torch'))

    # setup the training datasource
    train_subjects, valid_subjects = ['Subject_1', 'Subject_2', 'Subject_3'], ['Subject_4']
    extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES, defs.KEY_LABELS))
    indexing_strategy = extr.SliceIndexing()

    augmentation_transforms = [augm.RandomElasticDeformation(), augm.RandomMirror()]
    transforms = [tfm.Permute(permutation=(2, 0, 1)), tfm.Squeeze(entries=(defs.KEY_LABELS,))]
    train_transforms = tfm.ComposeTransform(augmentation_transforms + transforms)
    train_dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, train_transforms,
                                         subject_subset=train_subjects)

    # setup the validation datasource
    valid_transforms = tfm.ComposeTransform([tfm.Permute(permutation=(2, 0, 1))])
    valid_dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, valid_transforms,
                                         subject_subset=valid_subjects)
    direct_extractor = extr.ComposeExtractor(
        [extr.SubjectExtractor(),
         extr.ImagePropertiesExtractor(),
         extr.DataExtractor(categories=(defs.KEY_LABELS,))]
    )
    assembler = assm.SubjectAssembler(valid_dataset)

    # torch specific handling
    pytorch_train_dataset = pymia_torch.PytorchDatasetAdapter(train_dataset)
    train_loader = torch_data.dataloader.DataLoader(pytorch_train_dataset, batch_size=16, shuffle=True)

    pytorch_valid_dataset = pymia_torch.PytorchDatasetAdapter(valid_dataset)
    valid_loader = torch_data.dataloader.DataLoader(pytorch_valid_dataset, batch_size=16, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # assumed setup; device is not defined in the excerpt
    u_net = unet.UNetModel(ch_in=2, ch_out=6, n_channels=16, n_pooling=3).to(device)

    print(u_net)

    optimizer = optim.Adam(u_net.parameters(), lr=1e-3)
    train_batches = len(train_loader)

    # looping over the data in the dataset
    epochs = 100
    for epoch in range(epochs):
        u_net.train()
        print(f'Epoch {epoch + 1}/{epochs}')

        # training
        print('training')
        for i, batch in enumerate(train_loader):
            x, y = batch[defs.KEY_IMAGES].to(device), batch[defs.KEY_LABELS].to(device).long()
            logits = u_net(x)

            optimizer.zero_grad()
            loss = F.cross_entropy(logits, y)
            loss.backward()
            optimizer.step()

            tb.add_scalar('train/loss', loss.item(), epoch*train_batches + i)
            print(f'[{i + 1}/{train_batches}]\tloss: {loss.item()}')

        # validation
        print('validation')
        with torch.no_grad():
            u_net.eval()
            valid_batches = len(valid_loader)
            for i, batch in enumerate(valid_loader):
                x, sample_indices = batch[defs.KEY_IMAGES].to(device), batch[defs.KEY_SAMPLE_INDEX]

                logits = u_net(x)
                prediction = logits.argmax(dim=1, keepdim=True)

                numpy_prediction = prediction.cpu().numpy().transpose((0, 2, 3, 1))

                is_last = i == valid_batches - 1
                assembler.add_batch(numpy_prediction, sample_indices.numpy(), is_last)

                for subject_index in assembler.subjects_ready:
                    subject_prediction = assembler.get_assembled_subject(subject_index)

                    direct_sample = train_dataset.direct_extract(direct_extractor, subject_index)
                    target, image_properties = direct_sample[defs.KEY_LABELS], direct_sample[defs.KEY_PROPERTIES]

                    # evaluate the prediction against the reference
                    evaluator.evaluate(subject_prediction[..., 0], target[..., 0], direct_sample[defs.KEY_SUBJECT])

            # calculate mean and standard deviation of each metric
            results = statistics_aggregator.calculate(evaluator.results)
            # log to TensorBoard into category valid
            for result in results:
                tb.add_scalar(f'valid/{result.metric}-{result.id_}', result.value, epoch)

            console_writer.write(evaluator.results)

            # clear results such that the evaluator is ready for the next evaluation
            evaluator.clear()
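
These example scripts are usually launched from the command line. A minimal entry point in the style of the pymia examples (argument names and defaults assumed):

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='PyTorch training and logging example')
    parser.add_argument('--hdf_file', type=str, default='data/dataset.h5', help='path to the dataset file')
    parser.add_argument('--log_dir', type=str, default='logs', help='path to the logging directory')
    args = parser.parse_args()

    main(args.hdf_file, args.log_dir)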
Example #9
    def __init__(self,
                 config: cfg.Configuration,
                 subjects_train,
                 subjects_valid,
                 subjects_test,
                 collate_fn=pymia_cnv.TensorFlowCollate()):
        super().__init__()

        indexing_strategy = PointCloudIndexing(config.no_points)

        self.dataset = pymia_extr.ParameterizableDataset(
            config.database_file,
            indexing_strategy,
            pymia_extr.SubjectExtractor(),  # for the usual select_indices
            None)

        self.no_subjects_train = len(subjects_train)
        self.no_subjects_valid = len(subjects_valid)
        self.no_subjects_test = len(subjects_valid)  # same as validation for this kind of cross-validation

        # get sampler ids by subjects
        sampler_ids_train = pymia_extr.select_indices(
            self.dataset, pymia_extr.SubjectSelection(subjects_train))
        sampler_ids_valid = pymia_extr.select_indices(
            self.dataset, pymia_extr.SubjectSelection(subjects_valid))

        categories = ('images', 'labels')
        categories_tfm = ('images', 'labels')
        if config.use_image_information:
            image_information_categories = (data.KEY_IMAGE_INFORMATION, )
            categories += image_information_categories
            categories_tfm += image_information_categories
            collate_fn.entries += image_information_categories

        # define point cloud shuffler for augmentation
        sizes = {}
        for idx in range(len(self.dataset.get_subjects())):
            sample = self.dataset.direct_extract(PointCloudSizeExtractor(),
                                                 idx)
            sizes[idx] = sample['size']
        self.point_cloud_shuffler = aug.PointCloudShuffler(sizes)
        # shuffles only once at instantiation because shuffle() is not called during training (see set_seed)
        self.point_cloud_shuffler_valid = aug.PointCloudShuffler(sizes)

        data_extractor_train = aug.ShuffledDataExtractor(
            self.point_cloud_shuffler, categories)
        data_extractor_valid = aug.ShuffledDataExtractor(
            self.point_cloud_shuffler_valid, categories)
        data_extractor_test = aug.ShuffledDataExtractor(
            self.point_cloud_shuffler_valid, categories=('indices', 'labels'))

        # define extractors
        self.extractor_train = pymia_extr.ComposeExtractor([
            pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
            pymia_extr.SubjectExtractor(),  # required for plotting
            PointCloudSizeExtractor(),  # to init_shape in SubjectAssembler
            data_extractor_train,
            pymia_extr.IndexingExtractor(),  # for SubjectAssembler (assembling)
            pymia_extr.ImageShapeExtractor()  # for SubjectAssembler (shape)
        ])

        self.extractor_valid = pymia_extr.ComposeExtractor([
            pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
            pymia_extr.SubjectExtractor(),  # required for plotting
            PointCloudSizeExtractor(),  # to init_shape in SubjectAssembler
            data_extractor_valid,
            pymia_extr.IndexingExtractor(),
            pymia_extr.ImageShapeExtractor()
        ])

        self.extractor_test = pymia_extr.ComposeExtractor([
            pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
            pymia_extr.SubjectExtractor(),
            # we need the indices, i.e. the point coordinates, to convert the point cloud back to an image
            data_extractor_test,
            # the ground truth is used for the validation at config.save_validation_nth_epoch
            pymia_extr.DataExtractor(categories=('gt', ), ignore_indexing=True),
            pymia_extr.ImagePropertiesExtractor(),
            pymia_extr.ImageShapeExtractor()
        ])

        # define transforms for extraction
        self.extraction_transform_train = pymia_tfm.ComposeTransform([
            pymia_tfm.Squeeze(entries=('labels', ),
                              squeeze_axis=-1),  # for PyTorch loss functions
            pymia_tfm.LambdaTransform(
                lambda_fn=lambda np_data: np_data.astype(np.int64),
                entries=('labels', )),  # for PyTorch loss functions
        ])

        self.extraction_transform_valid = pymia_tfm.ComposeTransform([
            pymia_tfm.Squeeze(entries=('labels', ),
                              squeeze_axis=-1),  # for PyTorch loss functions
            pymia_tfm.LambdaTransform(
                lambda_fn=lambda np_data: np_data.astype(np.int64),
                entries=('labels', ))  # for PyTorch loss functions
        ])

        if config.use_jitter:
            self.extraction_transform_train.transforms.append(
                aug.PointCloudJitter())
            self.extraction_transform_valid.transforms.append(
                aug.PointCloudJitter())

        if config.use_rotation:
            self.extraction_transform_train.transforms.append(
                aug.PointCloudRotate())
            self.extraction_transform_valid.transforms.append(
                aug.PointCloudRotate())

        # need to add probability concatenation after augmentation!
        if config.use_point_feature:
            self.extraction_transform_train.transforms.append(
                ConcatenateCoordinatesAndPointFeatures())
            self.extraction_transform_valid.transforms.append(
                ConcatenateCoordinatesAndPointFeatures())

        if config.use_image_information:
            spatial_size = config.image_information_config.spatial_size

            def slice_patches(np_data):
                z = (np_data.shape[1] - spatial_size) // 2
                y = (np_data.shape[2] - spatial_size) // 2
                x = (np_data.shape[3] - spatial_size) // 2

                np_data = np_data[:, z:(z + spatial_size),
                                  y:(y + spatial_size),
                                  x:(x + spatial_size), :]
                return np_data

            self.extraction_transform_train.transforms.append(
                pymia_tfm.LambdaTransform(
                    lambda_fn=slice_patches,
                    entries=image_information_categories))
            self.extraction_transform_valid.transforms.append(
                pymia_tfm.LambdaTransform(
                    lambda_fn=slice_patches,
                    entries=image_information_categories))

        # define loaders
        training_sampler = pymia_extr.SubsetRandomSampler(sampler_ids_train)
        self.loader_train = pymia_extr.DataLoader(self.dataset,
                                                  config.batch_size_training,
                                                  sampler=training_sampler,
                                                  collate_fn=collate_fn,
                                                  num_workers=1)

        validation_sampler = pymia_extr.SubsetSequentialSampler(
            sampler_ids_valid)
        self.loader_valid = pymia_extr.DataLoader(self.dataset,
                                                  config.batch_size_testing,
                                                  sampler=validation_sampler,
                                                  collate_fn=collate_fn,
                                                  num_workers=1)

        self.loader_test = pymia_extr.DataLoader(self.dataset,
                                                 config.batch_size_testing,
                                                 sampler=validation_sampler,
                                                 collate_fn=collate_fn,
                                                 num_workers=1)

        self.extraction_transform_test = None
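
aug.PointCloudShuffler is project code; its role is to hold one random permutation per subject so that points and labels are shuffled consistently across categories. A minimal sketch:

class PointCloudShuffler:
    """Keeps one permutation per subject; shuffle() draws new ones (sketch)."""

    def __init__(self, sizes: dict):
        self.sizes = sizes  # subject index -> number of points
        self.permutations = {}
        self.shuffle()

    def shuffle(self):
        for idx, size in self.sizes.items():
            self.permutations[idx] = np.random.permutation(size)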
Example #10
    def __init__(self,
                 config: cfg.Configuration,
                 subjects_train,
                 subjects_valid,
                 subjects_test,
                 collate_fn=pymia_cnv.TorchCollate(
                     ('images', 'labels', 'mask_fg', 'mask_t1h2o'))):
        super().__init__()

        indexing_strategy = pymia_extr.SliceIndexing()

        self.dataset = pymia_extr.ParameterizableDataset(
            config.database_file,
            indexing_strategy,
            pymia_extr.SubjectExtractor(),  # for the usual select_indices
            None)

        self.no_subjects_train = len(subjects_train)
        self.no_subjects_valid = len(subjects_valid)
        self.no_subjects_test = 0

        # get sampler ids by subjects
        sampler_ids_train = pymia_extr.select_indices(
            self.dataset, pymia_extr.SubjectSelection(subjects_train))
        sampler_ids_valid = pymia_extr.select_indices(
            self.dataset, pymia_extr.SubjectSelection(subjects_valid))

        # define extractors
        self.extractor_train = pymia_extr.ComposeExtractor([
            pymia_extr.DataExtractor(categories=('images', 'labels')),
            pymia_extr.IndexingExtractor(),  # for SubjectAssembler (assembling)
            pymia_extr.ImageShapeExtractor()  # for SubjectAssembler (shape)
        ])

        self.extractor_valid = pymia_extr.ComposeExtractor([
            pymia_extr.DataExtractor(categories=('images', 'labels')),
            pymia_extr.IndexingExtractor(),  # for SubjectAssembler (assembling)
            pymia_extr.ImageShapeExtractor()  # for SubjectAssembler (shape)
        ])

        self.extractor_test = pymia_extr.ComposeExtractor([
            pymia_extr.SubjectExtractor(),
            pymia_extr.DataExtractor(categories=('labels', )),
            pymia_extr.ImagePropertiesExtractor(),
            pymia_extr.ImageShapeExtractor()
        ])

        # define transforms for extraction
        self.extraction_transform_train = pymia_tfm.ComposeTransform([
            pymia_tfm.SizeCorrection((cfg.TENSOR_WIDTH, cfg.TENSOR_HEIGHT)),
            pymia_tfm.Permute((2, 0, 1)),
            pymia_tfm.Squeeze(entries=('labels', ),
                              squeeze_axis=0),  # for PyTorch loss functions
            pymia_tfm.LambdaTransform(
                lambda_fn=lambda np_data: np_data.astype(np.int64),
                entries=('labels', )),  # for PyTorch loss functions
            pymia_tfm.ToTorchTensor()
        ])

        self.extraction_transform_valid = pymia_tfm.ComposeTransform([
            pymia_tfm.SizeCorrection((cfg.TENSOR_WIDTH, cfg.TENSOR_HEIGHT)),
            pymia_tfm.Permute((2, 0, 1)),
            pymia_tfm.Squeeze(entries=('labels', ),
                              squeeze_axis=0),  # for PyTorch loss functions
            pymia_tfm.LambdaTransform(
                lambda_fn=lambda np_data: np_data.astype(np.int64),
                entries=('labels', )),  # for PyTorch loss functions
            pymia_tfm.ToTorchTensor()
        ])

        self.extraction_transform_test = None

        # define loaders
        training_sampler = pymia_extr.SubsetRandomSampler(sampler_ids_train)
        self.loader_train = pymia_extr.DataLoader(self.dataset,
                                                  config.batch_size_training,
                                                  sampler=training_sampler,
                                                  collate_fn=collate_fn,
                                                  num_workers=1)

        validation_sampler = pymia_extr.SubsetSequentialSampler(
            sampler_ids_valid)
        self.loader_valid = pymia_extr.DataLoader(self.dataset,
                                                  config.batch_size_testing,
                                                  sampler=validation_sampler,
                                                  collate_fn=collate_fn,
                                                  num_workers=1)

        self.loader_test = None
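
A usage sketch for such a data handler, assuming the surrounding class is called DataHandler and the collate function returns torch tensors (all names are illustrative):

handler = DataHandler(config, subjects_train=['S1', 'S2'], subjects_valid=['S3'], subjects_test=[])

for batch in handler.loader_train:
    images, labels = batch['images'], batch['labels']
    # ... forward pass, loss, optimizer step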
Example #11
def main(hdf_file, log_dir):
    # initialize the evaluator with the metrics and the labels to evaluate
    metrics = [metric.DiceCoefficient()]
    labels = {
        1: 'WHITEMATTER',
        2: 'GREYMATTER',
        3: 'HIPPOCAMPUS',
        4: 'AMYGDALA',
        5: 'THALAMUS'
    }
    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # we want to log the mean and standard deviation of the metrics among all subjects of the dataset
    functions = {'MEAN': np.mean, 'STD': np.std}
    statistics_aggregator = writer.StatisticsAggregator(functions=functions)
    console_writer = writer.ConsoleStatisticsWriter(functions=functions)

    # initialize TensorBoard writer
    summary_writer = tf.summary.create_file_writer(
        os.path.join(log_dir, 'logging-example-tensorflow'))

    # setup the training datasource
    train_subjects, valid_subjects = ['Subject_1', 'Subject_2',
                                      'Subject_3'], ['Subject_4']
    extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES,
                                               defs.KEY_LABELS))
    indexing_strategy = extr.SliceIndexing()

    augmentation_transforms = [
        augm.RandomElasticDeformation(),
        augm.RandomMirror()
    ]
    transforms = [tfm.Squeeze(entries=(defs.KEY_LABELS, ))]
    train_transforms = tfm.ComposeTransform(augmentation_transforms +
                                            transforms)
    train_dataset = extr.PymiaDatasource(hdf_file,
                                         indexing_strategy,
                                         extractor,
                                         train_transforms,
                                         subject_subset=train_subjects)

    # setup the validation datasource
    batch_size = 16
    valid_transforms = tfm.ComposeTransform([])
    valid_dataset = extr.PymiaDatasource(hdf_file,
                                         indexing_strategy,
                                         extractor,
                                         valid_transforms,
                                         subject_subset=valid_subjects)
    direct_extractor = extr.ComposeExtractor([
        extr.SubjectExtractor(),
        extr.ImagePropertiesExtractor(),
        extr.DataExtractor(categories=(defs.KEY_LABELS, ))
    ])
    assembler = assm.SubjectAssembler(valid_dataset)

    # tensorflow specific handling
    train_gen_fn = pymia_tf.get_tf_generator(train_dataset)
    tf_train_dataset = tf.data.Dataset.from_generator(
        generator=train_gen_fn,
        output_types={
            defs.KEY_IMAGES: tf.float32,
            defs.KEY_LABELS: tf.int64,
            defs.KEY_SAMPLE_INDEX: tf.int64
        })
    tf_train_dataset = tf_train_dataset.batch(batch_size).shuffle(
        len(train_dataset))
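    # note: shuffling after batching shuffles whole batches; to shuffle at the
    # sample level, the more common pattern would be:
    # tf_train_dataset = tf_train_dataset.shuffle(len(train_dataset)).batch(batch_size)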

    valid_gen_fn = pymia_tf.get_tf_generator(valid_dataset)
    tf_valid_dataset = tf.data.Dataset.from_generator(
        generator=valid_gen_fn,
        output_types={
            defs.KEY_IMAGES: tf.float32,
            defs.KEY_LABELS: tf.int64,
            defs.KEY_SAMPLE_INDEX: tf.int64
        })
    tf_valid_dataset = tf_valid_dataset.batch(batch_size)

    u_net = unet.build_model(channels=2,
                             num_classes=6,
                             layer_depth=3,
                             filters_root=16)

    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)

    train_batches = len(train_dataset) // batch_size

    # looping over the data in the dataset
    epochs = 100
    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')

        # training
        print('training')
        for i, batch in enumerate(tf_train_dataset):
            x, y = batch[defs.KEY_IMAGES], batch[defs.KEY_LABELS]

            with tf.GradientTape() as tape:
                logits = u_net(x, training=True)
                loss = tf.keras.losses.sparse_categorical_crossentropy(
                    y, logits, from_logits=True)

            grads = tape.gradient(loss, u_net.trainable_variables)
            optimizer.apply_gradients(zip(grads, u_net.trainable_variables))

            train_loss(loss)

            with summary_writer.as_default():
                tf.summary.scalar('train/loss',
                                  train_loss.result(),
                                  step=epoch * train_batches + i)
            print(
                f'[{i + 1}/{train_batches}]\tloss: {train_loss.result().numpy()}'
            )

        # validation
        print('validation')
        valid_batches = len(valid_dataset) // batch_size
        for i, batch in enumerate(tf_valid_dataset):
            x, sample_indices = batch[defs.KEY_IMAGES], batch[
                defs.KEY_SAMPLE_INDEX]

            logits = u_net(x)
            prediction = tf.expand_dims(tf.math.argmax(logits, -1), -1)

            numpy_prediction = prediction.numpy()

            is_last = i == valid_batches - 1
            assembler.add_batch(numpy_prediction, sample_indices.numpy(),
                                is_last)

            for subject_index in assembler.subjects_ready:
                subject_prediction = assembler.get_assembled_subject(
                    subject_index)

                direct_sample = train_dataset.direct_extract(
                    direct_extractor, subject_index)
                target, image_properties = direct_sample[
                    defs.KEY_LABELS], direct_sample[defs.KEY_PROPERTIES]

                # evaluate the prediction against the reference
                evaluator.evaluate(subject_prediction[..., 0], target[..., 0],
                                   direct_sample[defs.KEY_SUBJECT])

        # calculate mean and standard deviation of each metric
        results = statistics_aggregator.calculate(evaluator.results)
        # log to TensorBoard into category valid
        with summary_writer.as_default():
            for result in results:
                tf.summary.scalar(f'valid/{result.metric}-{result.id_}',
                                  result.value, epoch)

        console_writer.write(evaluator.results)

        # clear results such that the evaluator is ready for the next evaluation
        evaluator.clear()