Example #1
def main(data_dir: str, result_file: str, result_summary_file: str):
    # initialize metrics
    metrics = [
        metric.DiceCoefficient(),
        metric.HausdorffDistance(percentile=95, metric='HDRFDST95'),
        metric.VolumeSimilarity()
    ]

    # define the labels to evaluate
    labels = {1: 'WHITEMATTER', 2: 'GREYMATTER', 5: 'THALAMUS'}

    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # get subjects to evaluate
    subject_dirs = [
        subject for subject in glob.glob(os.path.join(data_dir, '*'))
        if os.path.isdir(subject)
        and os.path.basename(subject).startswith('Subject')
    ]

    for subject_dir in subject_dirs:
        subject_id = os.path.basename(subject_dir)
        print(f'Evaluating {subject_id}...')

        # load ground truth image and create artificial prediction by erosion
        ground_truth = sitk.ReadImage(
            os.path.join(subject_dir, f'{subject_id}_GT.mha'))
        prediction = ground_truth
        for label_val in labels.keys():
            # erode each label we are going to evaluate
            prediction = sitk.BinaryErode(prediction, 1, sitk.sitkBall, 0,
                                          label_val)

        # evaluate the "prediction" against the ground truth
        evaluator.evaluate(prediction, ground_truth, subject_id)

    # use two writers to report the results
    writer.CSVWriter(result_file).write(evaluator.results)

    print('\nSubject-wise results...')
    writer.ConsoleWriter().write(evaluator.results)

    # report also mean and standard deviation among all subjects
    functions = {'MEAN': np.mean, 'STD': np.std}
    writer.CSVStatisticsWriter(result_summary_file,
                               functions=functions).write(evaluator.results)
    print('\nAggregated statistic results...')
    writer.ConsoleStatisticsWriter(functions=functions).write(
        evaluator.results)

    # clear results such that the evaluator is ready for the next evaluation
    evaluator.clear()
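
The example above expects three paths. A minimal sketch of how such a script could be invoked from the command line, assuming an argparse-based entry point (argument names and defaults here are illustrative, not part of the original example):

# illustrative __main__ block (assumed, not shown in the original example)
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluation of segmentation results.')
    parser.add_argument('--data_dir', type=str, default='../example-data',
                        help='Directory containing the Subject_* folders.')
    parser.add_argument('--result_file', type=str, default='results.csv',
                        help='CSV file for the subject-wise results.')
    parser.add_argument('--result_summary_file', type=str, default='results_summary.csv',
                        help='CSV file for the aggregated statistics.')
    args = parser.parse_args()
    main(args.data_dir, args.result_file, args.result_summary_file)
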
Example #2
def main(hdf_file, log_dir):
    # initialize the evaluator with the metrics and the labels to evaluate
    metrics = [metric.DiceCoefficient()]
    labels = {1: 'WHITEMATTER',
              2: 'GREYMATTER',
              3: 'HIPPOCAMPUS',
              4: 'AMYGDALA',
              5: 'THALAMUS'}
    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # we want to log the mean and standard deviation of the metrics among all subjects of the dataset
    functions = {'MEAN': np.mean, 'STD': np.std}
    statistics_aggregator = writer.StatisticsAggregator(functions=functions)
    console_writer = writer.ConsoleStatisticsWriter(functions=functions)

    # initialize TensorBoard writer
    tb = tensorboard.SummaryWriter(os.path.join(log_dir, 'logging-example-torch'))

    # setup the training datasource
    train_subjects, valid_subjects = ['Subject_1', 'Subject_2', 'Subject_3'], ['Subject_4']
    extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES, defs.KEY_LABELS))
    indexing_strategy = extr.SliceIndexing()

    augmentation_transforms = [augm.RandomElasticDeformation(), augm.RandomMirror()]
    transforms = [tfm.Permute(permutation=(2, 0, 1)), tfm.Squeeze(entries=(defs.KEY_LABELS,))]
    train_transforms = tfm.ComposeTransform(augmentation_transforms + transforms)
    train_dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, train_transforms,
                                         subject_subset=train_subjects)

    # setup the validation datasource
    valid_transforms = tfm.ComposeTransform([tfm.Permute(permutation=(2, 0, 1))])
    valid_dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, valid_transforms,
                                         subject_subset=valid_subjects)
    direct_extractor = extr.ComposeExtractor(
        [extr.SubjectExtractor(),
         extr.ImagePropertiesExtractor(),
         extr.DataExtractor(categories=(defs.KEY_LABELS,))]
    )
    assembler = assm.SubjectAssembler(valid_dataset)

    # torch specific handling
    pytorch_train_dataset = pymia_torch.PytorchDatasetAdapter(train_dataset)
    train_loader = torch_data.dataloader.DataLoader(pytorch_train_dataset, batch_size=16, shuffle=True)

    pytorch_valid_dataset = pymia_torch.PytorchDatasetAdapter(valid_dataset)
    valid_loader = torch_data.dataloader.DataLoader(pytorch_valid_dataset, batch_size=16, shuffle=False)

    u_net = unet.UNetModel(ch_in=2, ch_out=6, n_channels=16, n_pooling=3).to(device)

    print(u_net)

    optimizer = optim.Adam(u_net.parameters(), lr=1e-3)
    train_batches = len(train_loader)

    # looping over the data in the dataset
    epochs = 100
    for epoch in range(epochs):
        u_net.train()
        print(f'Epoch {epoch + 1}/{epochs}')

        # training
        print('training')
        for i, batch in enumerate(train_loader):
            x, y = batch[defs.KEY_IMAGES].to(device), batch[defs.KEY_LABELS].to(device).long()
            logits = u_net(x)

            optimizer.zero_grad()
            loss = F.cross_entropy(logits, y)
            loss.backward()
            optimizer.step()

            tb.add_scalar('train/loss', loss.item(), epoch*train_batches + i)
            print(f'[{i + 1}/{train_batches}]\tloss: {loss.item()}')

        # validation
        print('validation')
        with torch.no_grad():
            u_net.eval()
            valid_batches = len(valid_loader)
            for i, batch in enumerate(valid_loader):
                x, sample_indices = batch[defs.KEY_IMAGES].to(device), batch[defs.KEY_SAMPLE_INDEX]

                logits = u_net(x)
                prediction = logits.argmax(dim=1, keepdim=True)

                numpy_prediction = prediction.cpu().numpy().transpose((0, 2, 3, 1))

                is_last = i == valid_batches - 1
                assembler.add_batch(numpy_prediction, sample_indices.numpy(), is_last)

                for subject_index in assembler.subjects_ready:
                    subject_prediction = assembler.get_assembled_subject(subject_index)

                    direct_sample = train_dataset.direct_extract(direct_extractor, subject_index)
                    target, image_properties = direct_sample[defs.KEY_LABELS], direct_sample[defs.KEY_PROPERTIES]

                    # evaluate the prediction against the reference
                    evaluator.evaluate(subject_prediction[..., 0], target[..., 0], direct_sample[defs.KEY_SUBJECT])

            # calculate mean and standard deviation of each metric
            results = statistics_aggregator.calculate(evaluator.results)
            # log to TensorBoard under the valid category
            for result in results:
                tb.add_scalar(f'valid/{result.metric}-{result.id_}', result.value, epoch)

            console_writer.write(evaluator.results)

            # clear results such that the evaluator is ready for the next evaluation
            evaluator.clear()
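
The PyTorch example uses a device and several module aliases that are defined outside the excerpt. A minimal sketch of the assumed setup (CUDA if available, otherwise CPU); the torch aliases are inferred from how they are used above, while unet and the pymia aliases come from the project and from pymia itself:

# assumed PyTorch setup and aliases, inferred from their use in the excerpt above
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as torch_data
from torch.utils import tensorboard

# use the GPU when available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
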
Example #3
def main(result_dir: str, data_atlas_dir: str, data_train_dir: str,
         data_test_dir: str):
    """Brain tissue segmentation using decision forests.

    The main routine executes the medical image analysis pipeline:

        - Image loading
        - Registration
        - Pre-processing
        - Feature extraction
        - Decision forest classifier model building
        - Segmentation using the decision forest classifier model on unseen images
        - Post-processing of the segmentation
        - Evaluation of the segmentation
    """

    # load atlas images
    putil.load_atlas_images(data_atlas_dir)

    print('-' * 5, 'Training...')

    # crawl the training image directories
    crawler = futil.FileSystemDataCrawler(data_train_dir, LOADING_KEYS,
                                          futil.BrainImageFilePathGenerator(),
                                          futil.DataDirectoryFilter())
    pre_process_params = {
        'skullstrip_pre': True,
        'normalization_pre': True,
        'registration_pre': True,
        'coordinates_feature': True,
        'intensity_feature': True,
        'gradient_intensity_feature': True
    }

    # load images for training and pre-process
    images = putil.pre_process_batch(crawler.data,
                                     pre_process_params,
                                     multi_process=False)

    # generate feature matrix and label vector
    data_train = np.concatenate([img.feature_matrix[0] for img in images])
    labels_train = np.concatenate([img.feature_matrix[1]
                                   for img in images]).squeeze()

    #warnings.warn('Random forest parameters not properly set.')
    forest = sk_ensemble.RandomForestClassifier(
        max_features=images[0].feature_matrix[0].shape[1],
        n_estimators=10,
        max_depth=10)

    start_time = timeit.default_timer()
    forest.fit(data_train, labels_train)
    print(' Time elapsed:', timeit.default_timer() - start_time, 's')

    # create a result directory with timestamp
    t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    result_dir = os.path.join(result_dir, t)
    os.makedirs(result_dir, exist_ok=True)

    print('-' * 5, 'Testing...')

    # initialize evaluator
    evaluator = putil.init_evaluator()

    # crawl the test image directories
    crawler = futil.FileSystemDataCrawler(data_test_dir, LOADING_KEYS,
                                          futil.BrainImageFilePathGenerator(),
                                          futil.DataDirectoryFilter())

    # load images for testing and pre-process
    pre_process_params['training'] = False
    images_test = putil.pre_process_batch(crawler.data,
                                          pre_process_params,
                                          multi_process=False)

    images_prediction = []
    images_probabilities = []

    for img in images_test:
        print('-' * 10, 'Testing', img.id_)

        start_time = timeit.default_timer()
        predictions = forest.predict(img.feature_matrix[0])
        probabilities = forest.predict_proba(img.feature_matrix[0])
        print(' Time elapsed:', timeit.default_timer() - start_time, 's')

        # convert prediction and probabilities back to SimpleITK images
        image_prediction = conversion.NumpySimpleITKImageBridge.convert(
            predictions.astype(np.uint8), img.image_properties)
        image_probabilities = conversion.NumpySimpleITKImageBridge.convert(
            probabilities, img.image_properties)

        # evaluate segmentation without post-processing
        evaluator.evaluate(image_prediction,
                           img.images[structure.BrainImageTypes.GroundTruth],
                           img.id_)

        images_prediction.append(image_prediction)
        images_probabilities.append(image_probabilities)

    # post-process segmentation and evaluate with post-processing
    post_process_params = {'simple_post': True}
    images_post_processed = putil.post_process_batch(images_test,
                                                     images_prediction,
                                                     images_probabilities,
                                                     post_process_params,
                                                     multi_process=False)

    for i, img in enumerate(images_test):
        evaluator.evaluate(images_post_processed[i],
                           img.images[structure.BrainImageTypes.GroundTruth],
                           img.id_ + '-PP')

        # save results
        sitk.WriteImage(
            images_prediction[i],
            os.path.join(result_dir, images_test[i].id_ + '_SEG.mha'), True)
        sitk.WriteImage(
            images_post_processed[i],
            os.path.join(result_dir, images_test[i].id_ + '_SEG-PP.mha'), True)

    # use two writers to report the results
    os.makedirs(result_dir, exist_ok=True)  # generate the result directory if it does not exist
    result_file = os.path.join(result_dir, 'results.csv')
    writer.CSVWriter(result_file).write(evaluator.results)

    print('\nSubject-wise results...')
    writer.ConsoleWriter().write(evaluator.results)

    # report also mean and standard deviation among all subjects
    result_summary_file = os.path.join(result_dir, 'results_summary.csv')
    functions = {'MEAN': np.mean, 'STD': np.std}
    writer.CSVStatisticsWriter(result_summary_file,
                               functions=functions).write(evaluator.results)
    print('\nAggregated statistic results...')
    writer.ConsoleStatisticsWriter(functions=functions).write(
        evaluator.results)

    # clear results such that the evaluator is ready for the next evaluation
    evaluator.clear()
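
Examples #4 and #5 below assume that model building and prediction were done in advance and that only temporary results are reloaded. A minimal sketch of persisting the trained classifier for such a split workflow, using scikit-learn's recommended joblib serialization (file name illustrative, to be placed after forest.fit in the routine above):

# illustrative persistence of the trained forest for a split pipeline
import os
import joblib

model_file = os.path.join(result_dir, 'forest.joblib')
joblib.dump(forest, model_file)  # save the fitted classifier

# later, e.g. in a separate post-processing run:
# forest = joblib.load(model_file)
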
Example #4
def main(result_dir: str, data_atlas_dir: str, data_train_dir: str,
         data_test_dir: str, tmp_result_dir: str):
    """Brain tissue segmentation using decision forests.

    Section of the original main routine. Executes the post-processing part of the medical image analysis pipeline:

        Must be done separately in advance:
        - Image loading
        - Registration
        - Pre-processing
        - Feature extraction
        - Decision forest classifier model building
        - Segmentation using the decision forest classifier model on unseen images

        Carried out in this section of the pipeline:
        - Loading of temporary data
        - Post-processing of the segmentation
        - Evaluation of the segmentation
    """

    # load atlas images
    putil.load_atlas_images(data_atlas_dir)

    # print('-' * 5, 'Training...')
    #
    # # crawl the training image directories
    # crawler = futil.FileSystemDataCrawler(data_train_dir,
    #                                       LOADING_KEYS,
    #                                       futil.BrainImageFilePathGenerator(),
    #                                       futil.DataDirectoryFilter())
    pre_process_params = {
        'skullstrip_pre': True,
        'normalization_pre': True,
        'registration_pre': True,
        'coordinates_feature': True,
        'intensity_feature': True,
        'gradient_intensity_feature': True
    }

    # create a result directory with timestamp
    t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    result_dir = os.path.join(result_dir, t)
    os.makedirs(result_dir, exist_ok=True)

    # initialize evaluator
    evaluator = putil.init_evaluator()

    # crawl the test image directories
    crawler = futil.FileSystemDataCrawler(data_test_dir, LOADING_KEYS,
                                          futil.BrainImageFilePathGenerator(),
                                          futil.DataDirectoryFilter())

    # load necessary data to perform post processing
    pre_process_params['training'] = False
    images_test = putil.pre_process_batch(crawler.data,
                                          pre_process_params,
                                          multi_process=False)

    # load the predictions for the test images (segmentation images and probability maps)
    images_prediction, images_probabilities = putil.load_prediction_images(
        images_test, tmp_result_dir, '2020-10-30-18-31-15')
    # evaluate images without post-processing
    for i, img in enumerate(images_test):
        evaluator.evaluate(images_prediction[i],
                           img.images[structure.BrainImageTypes.GroundTruth],
                           img.id_)

    # post-process segmentation and evaluate with post-processing
    post_process_params = {
        'simple_post': True,
        'variance': 1.0,
        'preserve_background': False
    }
    images_post_processed = putil.post_process_batch(images_test,
                                                     images_prediction,
                                                     images_probabilities,
                                                     post_process_params,
                                                     multi_process=False)

    for i, img in enumerate(images_test):
        evaluator.evaluate(images_post_processed[i],
                           img.images[structure.BrainImageTypes.GroundTruth],
                           img.id_ + '-PP')

        # save results
        sitk.WriteImage(
            images_prediction[i],
            os.path.join(result_dir, images_test[i].id_ + '_SEG.mha'), True)
        sitk.WriteImage(
            images_post_processed[i],
            os.path.join(result_dir, images_test[i].id_ + '_SEG-PP.mha'), True)

    # use two writers to report the results
    os.makedirs(result_dir, exist_ok=True)  # generate the result directory if it does not exist
    result_file = os.path.join(result_dir, 'results.csv')
    writer.CSVWriter(result_file).write(evaluator.results)

    print('\nSubject-wise results...')
    writer.ConsoleWriter().write(evaluator.results)

    # report also mean and standard deviation among all subjects
    result_summary_file = os.path.join(result_dir, 'results_summary.csv')
    functions = {'MEAN': np.mean, 'STD': np.std}
    writer.CSVStatisticsWriter(result_summary_file,
                               functions=functions).write(evaluator.results)
    print('\nAggregated statistic results...')
    writer.ConsoleStatisticsWriter(functions=functions).write(
        evaluator.results)

    # clear results such that the evaluator is ready for the next evaluation
    evaluator.clear()
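
The call to putil.load_prediction_images above hard-codes the timestamp of an earlier run. Since the run directories are named with the '%Y-%m-%d-%H-%M-%S' pattern created earlier, lexicographic order equals chronological order, so the most recent run can be picked automatically. A sketch (the helper name is hypothetical, not part of the original code):

import os

def latest_run_timestamp(tmp_result_dir: str) -> str:
    """Return the name of the most recent timestamped run directory (hypothetical helper)."""
    runs = sorted(entry for entry in os.listdir(tmp_result_dir)
                  if os.path.isdir(os.path.join(tmp_result_dir, entry)))
    if not runs:
        raise FileNotFoundError(f'no run directories found in {tmp_result_dir}')
    return runs[-1]

# possible usage instead of the hard-coded timestamp:
# images_prediction, images_probabilities = putil.load_prediction_images(
#     images_test, tmp_result_dir, latest_run_timestamp(tmp_result_dir))
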
Example #5
def main(result_dir: str, data_atlas_dir: str, data_train_dir: str,
         data_test_dir: str):
    """Brain tissue segmentation using decision forests.

    Section of the original main routine. Executes a grid search over the parameters of the probabilistic keyhole filling (PKF) method:

        Must be done separately in advance:
        - Image loading
        - Registration
        - Pre-processing
        - Feature extraction
        - Decision forest classifier model building
        - Segmentation using the decision forest classifier model on unseen images

        Carried out in this section of the pipeline:
        - Loading of temporary data
        - Grid search of PKF parameter of the segmentation
        - Evaluation of the segmentation
    """

    # load atlas images
    putil.load_atlas_images(data_atlas_dir)

    print('-' * 5, 'Training...')

    # crawl the training image directories
    crawler = futil.FileSystemDataCrawler(data_train_dir, LOADING_KEYS,
                                          futil.BrainImageFilePathGenerator(),
                                          futil.DataDirectoryFilter())
    pre_process_params = {
        'skullstrip_pre': True,
        'normalization_pre': True,
        'registration_pre': True,
        'coordinates_feature': True,
        'intensity_feature': True,
        'gradient_intensity_feature': True
    }

    # load images for training and pre-process
    images = putil.pre_process_batch(crawler.data,
                                     pre_process_params,
                                     multi_process=False)

    # generate feature matrix and label vector
    data_train = np.concatenate([img.feature_matrix[0] for img in images])
    labels_train = np.concatenate([img.feature_matrix[1]
                                   for img in images]).squeeze()

    #warnings.warn('Random forest parameters not properly set.')
    forest = sk_ensemble.RandomForestClassifier(
        max_features=images[0].feature_matrix[0].shape[1],
        n_estimators=20,
        max_depth=85)

    start_time = timeit.default_timer()
    forest.fit(data_train, labels_train)
    print(' Time elapsed:', timeit.default_timer() - start_time, 's')

    # create a result directory with timestamp
    t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    result_dir = os.path.join(result_dir, t)
    os.makedirs(result_dir, exist_ok=True)

    print('-' * 5, 'Testing...')

    # initialize evaluator
    evaluator = putil.init_evaluator()

    # crawl the test image directories
    crawler = futil.FileSystemDataCrawler(data_test_dir, LOADING_KEYS,
                                          futil.BrainImageFilePathGenerator(),
                                          futil.DataDirectoryFilter())

    # load images for testing and pre-process
    pre_process_params['training'] = False
    images_test = putil.pre_process_batch(crawler.data,
                                          pre_process_params,
                                          multi_process=False)

    images_prediction = []
    images_probabilities = []

    for img in images_test:
        print('-' * 10, 'Testing', img.id_)

        start_time = timeit.default_timer()
        predictions = forest.predict(img.feature_matrix[0])
        probabilities = forest.predict_proba(img.feature_matrix[0])
        print(' Time elapsed:', timeit.default_timer() - start_time, 's')

        # convert prediction and probabilities back to SimpleITK images
        image_prediction = conversion.NumpySimpleITKImageBridge.convert(
            predictions.astype(np.uint8), img.image_properties)
        image_probabilities = conversion.NumpySimpleITKImageBridge.convert(
            probabilities, img.image_properties)

        # evaluate segmentation without post-processing
        evaluator.evaluate(image_prediction,
                           img.images[structure.BrainImageTypes.GroundTruth],
                           img.id_)

        images_prediction.append(image_prediction)
        images_probabilities.append(image_probabilities)

    # save results without post-processing
    name = 'no_PP'
    sub_dir = os.path.join(result_dir, name)
    os.makedirs(sub_dir, exist_ok=True)

    for i, img in enumerate(images_test):
        sitk.WriteImage(images_prediction[i],
                        os.path.join(sub_dir, images_test[i].id_ + '_SEG.mha'),
                        True)

    result_file = os.path.join(sub_dir, 'results.csv')
    writer.CSVWriter(result_file).write(evaluator.results)

    # report also mean and standard deviation among all subjects
    result_summary_file = os.path.join(sub_dir, 'results_summary.csv')
    functions = {'MEAN': np.mean, 'STD': np.std}
    writer.CSVStatisticsWriter(result_summary_file,
                               functions=functions).write(evaluator.results)

    # clear results such that the evaluator is ready for the next evaluation
    evaluator.clear()

    # define the parameters for the grid search
    post_process_param_list = []
    variance = np.arange(1, 2)
    preserve_background = np.asarray([False])

    #
    # # define parameters for the full grid search
    # post_process_param_list = []
    # variance = np.arange(0.5, 4.0, 0.5)
    # preserve_background = np.asarray([False, True])

    for bg in preserve_background:
        for var in variance:
            # cast numpy scalars to built-in types so the parameters stay JSON-serializable
            post_process_param_list.append({
                'simple_post': True,
                'variance': float(var),
                'preserve_background': bool(bg)
            })

    # execute post-processing with the defined parameters
    for post_process_params in post_process_param_list:

        # create sub-directory for results
        name = 'PP-V-' + str(post_process_params.get('variance')).replace('.', '_') + \
               '-BG-' + str(post_process_params.get('preserve_background'))
        sub_dir = os.path.join(result_dir, name)
        os.makedirs(sub_dir, exist_ok=True)

        # write the used parameters into a text file and store it in the result folder
        parameter_file = os.path.join(sub_dir, 'parameter.txt')
        with open(parameter_file, 'w') as f:
            json.dump(post_process_params, f)

        # post-process segmentation and evaluate with post-processing
        images_post_processed = putil.post_process_batch(images_test,
                                                         images_prediction,
                                                         images_probabilities,
                                                         post_process_params,
                                                         multi_process=False)

        for i, img in enumerate(images_test):
            evaluator.evaluate(
                images_post_processed[i],
                img.images[structure.BrainImageTypes.GroundTruth],
                img.id_ + '-PP')
            # save results
            sitk.WriteImage(
                images_post_processed[i],
                os.path.join(sub_dir, images_test[i].id_ + '_SEG-PP.mha'),
                True)

        # save all results in csv file
        result_file = os.path.join(sub_dir, 'results.csv')
        writer.CSVWriter(result_file).write(evaluator.results)

        print('\nSubject-wise results...')
        writer.ConsoleWriter().write(evaluator.results)

        # report also mean and standard deviation among all subjects
        result_summary_file = os.path.join(sub_dir, 'results_summary.csv')
        functions = {'MEAN': np.mean, 'STD': np.std}
        writer.CSVStatisticsWriter(
            result_summary_file, functions=functions).write(evaluator.results)
        print('\nAggregated statistic results...')
        writer.ConsoleStatisticsWriter(functions=functions).write(
            evaluator.results)

        # clear results such that the evaluator is ready for the next evaluation
        evaluator.clear()
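
The nested loops above enumerate the full Cartesian product of the grid. An equivalent, more compact construction using itertools.product, shown here for the commented-out, larger grid (purely illustrative):

# equivalent grid construction with itertools.product (illustrative)
import itertools
import numpy as np

variance = np.arange(0.5, 4.0, 0.5)
preserve_background = [False, True]

post_process_param_list = [
    {'simple_post': True, 'variance': float(var), 'preserve_background': bool(bg)}
    for bg, var in itertools.product(preserve_background, variance)
]
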
Example #6
def main(hdf_file, log_dir):
    # initialize the evaluator with the metrics and the labels to evaluate
    metrics = [metric.DiceCoefficient()]
    labels = {
        1: 'WHITEMATTER',
        2: 'GREYMATTER',
        3: 'HIPPOCAMPUS',
        4: 'AMYGDALA',
        5: 'THALAMUS'
    }
    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # we want to log the mean and standard deviation of the metrics among all subjects of the dataset
    functions = {'MEAN': np.mean, 'STD': np.std}
    statistics_aggregator = writer.StatisticsAggregator(functions=functions)
    console_writer = writer.ConsoleStatisticsWriter(functions=functions)

    # initialize TensorBoard writer
    summary_writer = tf.summary.create_file_writer(
        os.path.join(log_dir, 'logging-example-tensorflow'))

    # setup the training datasource
    train_subjects, valid_subjects = ['Subject_1', 'Subject_2',
                                      'Subject_3'], ['Subject_4']
    extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES,
                                               defs.KEY_LABELS))
    indexing_strategy = extr.SliceIndexing()

    augmentation_transforms = [
        augm.RandomElasticDeformation(),
        augm.RandomMirror()
    ]
    transforms = [tfm.Squeeze(entries=(defs.KEY_LABELS, ))]
    train_transforms = tfm.ComposeTransform(augmentation_transforms +
                                            transforms)
    train_dataset = extr.PymiaDatasource(hdf_file,
                                         indexing_strategy,
                                         extractor,
                                         train_transforms,
                                         subject_subset=train_subjects)

    # setup the validation datasource
    batch_size = 16
    valid_transforms = tfm.ComposeTransform([])
    valid_dataset = extr.PymiaDatasource(hdf_file,
                                         indexing_strategy,
                                         extractor,
                                         valid_transforms,
                                         subject_subset=valid_subjects)
    direct_extractor = extr.ComposeExtractor([
        extr.SubjectExtractor(),
        extr.ImagePropertiesExtractor(),
        extr.DataExtractor(categories=(defs.KEY_LABELS, ))
    ])
    assembler = assm.SubjectAssembler(valid_dataset)

    # tensorflow specific handling
    train_gen_fn = pymia_tf.get_tf_generator(train_dataset)
    tf_train_dataset = tf.data.Dataset.from_generator(
        generator=train_gen_fn,
        output_types={
            defs.KEY_IMAGES: tf.float32,
            defs.KEY_LABELS: tf.int64,
            defs.KEY_SAMPLE_INDEX: tf.int64
        })
    tf_train_dataset = tf_train_dataset.batch(batch_size).shuffle(
        len(train_dataset))

    valid_gen_fn = pymia_tf.get_tf_generator(valid_dataset)
    tf_valid_dataset = tf.data.Dataset.from_generator(
        generator=valid_gen_fn,
        output_types={
            defs.KEY_IMAGES: tf.float32,
            defs.KEY_LABELS: tf.int64,
            defs.KEY_SAMPLE_INDEX: tf.int64
        })
    tf_valid_dataset = tf_valid_dataset.batch(batch_size)

    u_net = unet.build_model(channels=2,
                             num_classes=6,
                             layer_depth=3,
                             filters_root=16)

    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)

    train_batches = len(train_dataset) // batch_size

    # looping over the data in the dataset
    epochs = 100
    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')

        # training
        print('training')
        for i, batch in enumerate(tf_train_dataset):
            x, y = batch[defs.KEY_IMAGES], batch[defs.KEY_LABELS]

            with tf.GradientTape() as tape:
                logits = u_net(x, training=True)
                loss = tf.keras.losses.sparse_categorical_crossentropy(
                    y, logits, from_logits=True)

            grads = tape.gradient(loss, u_net.trainable_variables)
            optimizer.apply_gradients(zip(grads, u_net.trainable_variables))

            train_loss(loss)

            with summary_writer.as_default():
                tf.summary.scalar('train/loss',
                                  train_loss.result(),
                                  step=epoch * train_batches + i)
            print(
                f'[{i + 1}/{train_batches}]\tloss: {train_loss.result().numpy()}'
            )

        # validation
        print('validation')
        valid_batches = len(valid_dataset) // batch_size
        for i, batch in enumerate(tf_valid_dataset):
            x, sample_indices = batch[defs.KEY_IMAGES], batch[
                defs.KEY_SAMPLE_INDEX]

            logits = u_net(x)
            prediction = tf.expand_dims(tf.math.argmax(logits, -1), -1)

            numpy_prediction = prediction.numpy()

            is_last = i == valid_batches - 1
            assembler.add_batch(numpy_prediction, sample_indices.numpy(),
                                is_last)

            for subject_index in assembler.subjects_ready:
                subject_prediction = assembler.get_assembled_subject(
                    subject_index)

                direct_sample = train_dataset.direct_extract(
                    direct_extractor, subject_index)
                target, image_properties = direct_sample[
                    defs.KEY_LABELS], direct_sample[defs.KEY_PROPERTIES]

                # evaluate the prediction against the reference
                evaluator.evaluate(subject_prediction[..., 0], target[..., 0],
                                   direct_sample[defs.KEY_SUBJECT])

        # calculate mean and standard deviation of each metric
        results = statistics_aggregator.calculate(evaluator.results)
        # log to TensorBoard under the valid category
        with summary_writer.as_default():
            for result in results:
                tf.summary.scalar(f'valid/{result.metric}-{result.id_}',
                                  result.value, epoch)

        console_writer.write(evaluator.results)

        # clear results such that the evaluator is ready for the next evaluation
        evaluator.clear()
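
The TensorFlow example above (like the PyTorch one in Example #2) relies on module aliases that are not shown in the excerpt. Assuming they follow pymia's documented example conventions, the imports would look roughly as below; verify against the installed pymia version, and note that unet refers to a project-local model definition:

# assumed import aliases for the TensorFlow example (verify against your pymia version)
import os

import numpy as np
import tensorflow as tf

import pymia.data.assembler as assm
import pymia.data.augmentation as augm
import pymia.data.definition as defs
import pymia.data.extraction as extr
import pymia.data.transformation as tfm
import pymia.data.backends.tensorflow as pymia_tf
import pymia.evaluation.evaluator as eval_
import pymia.evaluation.metric as metric
import pymia.evaluation.writer as writer

# the PyTorch variant (Example #2) would instead use, e.g.:
# import pymia.data.backends.pytorch as pymia_torch
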
Example #7
def main(result_dir: str, data_atlas_dir: str, data_train_dir: str, data_test_dir: str, parameters_file: str):
    """Brain tissue segmentation using decision forests.

    The main routine executes the medical image analysis pipeline:

        - Image loading
        - Registration
        - Pre-processing
        - Feature extraction
        - Decision forest classifier model building
        - Segmentation using the decision forest classifier model on unseen images
        - Post-processing of the segmentation
        - Evaluation of the segmentation
    """
    start_main = timeit.default_timer()
    # load atlas images
    putil.load_atlas_images(data_atlas_dir)

    print('-' * 5, 'Training...')

    # crawl the training image directories
    crawler = futil.FileSystemDataCrawler(data_train_dir,
                                          LOADING_KEYS,
                                          futil.BrainImageFilePathGenerator(),
                                          futil.DataDirectoryFilter())

    fof_parameters = {'10Percentile': True,
                      '90Percentile': True,
                      'Energy': True,
                      'Entropy': True,
                      'InterquartileRange': True,
                      'Kurtosis': True,
                      'Maximum': True,
                      'MeanAbsoluteDeviation': True,
                      'Mean': True,
                      'Median': True,
                      'Minimum': True,
                      'Range': True,
                      'RobustMeanAbsoluteDeviation': True,
                      'RootMeanSquared': True,
                      'Skewness': True,
                      'TotalEnergy': True,
                      'Uniformity': True,
                      'Variance': True}

    glcm_parameters = {'Autocorrelation': True,
                       'ClusterProminence': True,
                       'ClusterShade': True,
                       'ClusterTendency': True,
                       'Contrast': True,
                       'Correlation': True,
                       'DifferenceAverage': True,
                       'DifferenceEntropy': True,
                       'DifferenceVariance': True,
                       'Id': True,
                       'Idm': True,
                       'Idmn': True,
                       'Idn': True,
                       'Imc1': True,
                       'Imc2': True,
                       'InverseVariance': True,
                       'JointAverage': True,
                       'JointEnergy': True,
                       'JointEntropy': True,
                       'MCC': True,
                       'MaximumProbability': True,
                       'SumAverage': True,
                       'SumEntropy': True,
                       'SumSquares': True}

    pre_process_params = {'skullstrip_pre': True,
                          'normalization_pre': True,
                          'registration_pre': True,
                          'save_features': False,
                          'coordinates_feature': True,
                          'intensity_feature': False,
                          'gradient_intensity_feature': False,
                          'first_order_feature': False,
                          'first_order_feature_parameters': fof_parameters,
                          'HOG_feature': False,
                          'GLCM_features': False,
                          'GLCM_features_parameters': glcm_parameters,
                          'n_estimators': 50,
                          'max_depth': 60,
                          'experiment_name': 'default'
                          }

    # an empty JSON file keeps the defaults; a non-empty one replaces them entirely
    with open(parameters_file, 'r') as f:
        parameters = json.load(f)
    if parameters:
        pre_process_params = parameters

    # load images for training and pre-process
    images = putil.pre_process_batch(crawler.data, pre_process_params, multi_process=False)

    # generate feature matrix and label vector
    data_train = np.concatenate([img.feature_matrix[0] for img in images])
    labels_train = np.concatenate([img.feature_matrix[1] for img in images]).squeeze()
    np.nan_to_num(data_train, copy=False)

    # warnings.warn('Random forest parameters not properly set.')
    forest = sk_ensemble.RandomForestClassifier(max_features=images[0].feature_matrix[0].shape[1],
                                                n_estimators=pre_process_params['n_estimators'],  # 100
                                                max_depth=pre_process_params['max_depth'])  # 10

    # Debugging: save the training data and the indices of NaN entries for inspection
    nan_data_idx = np.argwhere(np.isnan(data_train))
    np.savez('data_train.npz', data_train)
    np.save('data_nan.npy', nan_data_idx)

    start_time = timeit.default_timer()
    forest.fit(data_train, labels_train)
    print(' Time elapsed:', timeit.default_timer() - start_time, 's')

    # create a result directory with timestamp
    result_dir = os.path.join(result_dir, pre_process_params['experiment_name'])
    os.makedirs(result_dir, exist_ok=True)

    print('-' * 5, 'Testing...')

    # initialize evaluator
    evaluator = putil.init_evaluator()

    # crawl the test image directories
    crawler = futil.FileSystemDataCrawler(data_test_dir,
                                          LOADING_KEYS,
                                          futil.BrainImageFilePathGenerator(),
                                          futil.DataDirectoryFilter())

    # load images for testing and pre-process
    pre_process_params['training'] = False
    images_test = putil.pre_process_batch(crawler.data, pre_process_params, multi_process=False)

    images_prediction = []
    images_probabilities = []

    for img in images_test:
        print('-' * 10, 'Testing', img.id_)

        start_time = timeit.default_timer()
        predictions = forest.predict(np.nan_to_num(img.feature_matrix[0], copy=False))
        probabilities = forest.predict_proba(np.nan_to_num(img.feature_matrix[0], copy=False))
        print(' Time elapsed:', timeit.default_timer() - start_time, 's')

        # convert prediction and probabilities back to SimpleITK images
        image_prediction = conversion.NumpySimpleITKImageBridge.convert(predictions.astype(np.uint8),
                                                                        img.image_properties)
        image_probabilities = conversion.NumpySimpleITKImageBridge.convert(probabilities, img.image_properties)

        # evaluate segmentation without post-processing
        evaluator.evaluate(image_prediction, img.images[structure.BrainImageTypes.GroundTruth], img.id_)

        images_prediction.append(image_prediction)
        images_probabilities.append(image_probabilities)

    # post-process segmentation and evaluate with post-processing
    post_process_params = {'simple_post': True}
    images_post_processed = putil.post_process_batch(images_test, images_prediction, images_probabilities,
                                                     post_process_params, multi_process=False)

    for i, img in enumerate(images_test):
        evaluator.evaluate(images_post_processed[i], img.images[structure.BrainImageTypes.GroundTruth],
                           img.id_ + '-PP')

        # save results
        sitk.WriteImage(images_prediction[i], os.path.join(result_dir, images_test[i].id_ + '_SEG.mha'), True)
        sitk.WriteImage(images_post_processed[i], os.path.join(result_dir, images_test[i].id_ + '_SEG-PP.mha'), True)

    # use two writers to report the results
    os.makedirs(result_dir, exist_ok=True)  # generate the result directory if it does not exist
    result_file = os.path.join(result_dir, 'results.csv')
    writer.CSVWriter(result_file).write(evaluator.results)

    print('\nSubject-wise results...')
    writer.ConsoleWriter().write(evaluator.results)

    # report also mean and standard deviation among all subjects
    result_summary_file = os.path.join(result_dir, 'results_summary.csv')
    functions = {'MEAN': np.mean, 'STD': np.std}
    writer.CSVStatisticsWriter(result_summary_file, functions=functions).write(evaluator.results)
    print('\nAggregated statistic results...')
    writer.ConsoleStatisticsWriter(functions=functions).write(evaluator.results)

    # clear results such that the evaluator is ready for the next evaluation
    evaluator.clear()
    end_main = timeit.default_timer()
    main_time = end_main - start_main

    # writing information on a txt file
    reporter.feature_writer(result_dir, pre_process_params, main_time, 'feature_report')
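
The parameters_file is loaded as JSON and, if non-empty, replaces pre_process_params entirely, so an override file must contain every key the pipeline reads. A sketch of how such a file could be generated from the defaults above (values and file name are illustrative):

# illustrative generation of a parameters file that overrides the defaults above
import json

custom_params = {'skullstrip_pre': True,
                 'normalization_pre': True,
                 'registration_pre': True,
                 'save_features': False,
                 'coordinates_feature': True,
                 'intensity_feature': True,
                 'gradient_intensity_feature': True,
                 'first_order_feature': False,
                 'first_order_feature_parameters': {},   # or the fof_parameters dict above
                 'HOG_feature': False,
                 'GLCM_features': False,
                 'GLCM_features_parameters': {},         # or the glcm_parameters dict above
                 'n_estimators': 100,
                 'max_depth': 10,
                 'experiment_name': 'rf-100-10'}

with open('parameters.json', 'w') as f:
    json.dump(custom_params, f, indent=2)
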