Example #1
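Examples #1 and #2 assume the usual pymia example imports; the module aliases below are inferred from the code and are listed here for completeness:

import pymia.data.assembler as assm
import pymia.data.definition as defs
import pymia.data.extraction as extr
import pymia.data.transformation as tfm
import pymia.data.backends.pytorch as pymia_torch
import torch
import torch.nn as nn
import torch.utils.data as torch_data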
def main(hdf_file):

    # Use a pad extractor to compensate for the valid convolutions of the network;
    # the patches are padded with actual image information rather than zeros
    extractor = extr.PadDataExtractor(
        (2, 2, 2), extr.DataExtractor(categories=(defs.KEY_IMAGES, )))

    # Adapted permutation due to the additional dimension
    transform = tfm.Permute(permutation=(3, 0, 1, 2),
                            entries=(defs.KEY_IMAGES, ))

    # Create a patch indexing strategy with a patch_shape equal to the network's output shape
    indexing_strategy = extr.PatchWiseIndexing(patch_shape=(32, 32, 32))
    dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor,
                                   transform)

    direct_extractor = extr.ComposeExtractor([
        extr.ImagePropertiesExtractor(),
        extr.DataExtractor(categories=(defs.KEY_LABELS, defs.KEY_IMAGES))
    ])
    assembler = assm.SubjectAssembler(dataset)

    # PyTorch-specific handling: wrap the datasource for use with a standard DataLoader
    pytorch_dataset = pymia_torch.PytorchDatasetAdapter(dataset)
    loader = torch_data.dataloader.DataLoader(pytorch_dataset,
                                              batch_size=2,
                                              shuffle=False)
    # dummy CNN with valid convolutions instead of same convolutions
    dummy_network = nn.Sequential(
        nn.Conv3d(in_channels=2, out_channels=8, kernel_size=3, padding=0),
        nn.Conv3d(in_channels=8, out_channels=1, kernel_size=3, padding=0),
        nn.Sigmoid())
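    # Size check: each valid 3x3x3 convolution trims one voxel per border, so the two
    # convolutions above reduce every spatial dimension by 4. The (2, 2, 2) padding of the
    # PadDataExtractor therefore grows each 32^3 patch to 36^3 at the network input, and
    # the output is 32^3 again, matching the patch_shape of the indexing strategy.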
    torch.set_grad_enabled(False)

    nb_batches = len(loader)

    # looping over the data in the dataset
    for i, batch in enumerate(loader):

        x, sample_indices = batch[defs.KEY_IMAGES], batch[defs.KEY_SAMPLE_INDEX]
        prediction = dummy_network(x)

        # convert to numpy and move the channel dimension back to the last position
        numpy_prediction = prediction.numpy().transpose((0, 2, 3, 4, 1))

        is_last = i == nb_batches - 1
        assembler.add_batch(numpy_prediction, sample_indices.numpy(), is_last)

        for subject_index in assembler.subjects_ready:
            subject_prediction = assembler.get_assembled_subject(subject_index)

            direct_sample = dataset.direct_extract(direct_extractor, subject_index)
            target, image_properties = direct_sample[defs.KEY_LABELS], direct_sample[defs.KEY_PROPERTIES]
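The snippet ends with target and image_properties extracted but unused. A typical next step is to convert the assembled prediction back to a SimpleITK image for saving; a minimal sketch, placed inside the subjects_ready loop and assuming pymia's conversion module and a hypothetical output path:

import pymia.data.conversion as conv
import SimpleITK as sitk

# convert the assembled numpy volume back to a SimpleITK image using the
# stored image properties (origin, spacing, direction)
prediction_image = conv.NumpySimpleITKImageBridge.convert(subject_prediction, image_properties)
sitk.WriteImage(prediction_image, 'prediction.mha')  # hypothetical output path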
Example #2
def main(hdf_file, is_meta):

    # is_meta selects between data stored in the HDF5 file itself and a metadata-only
    # dataset that holds file paths, from which the data is loaded on the fly
    if not is_meta:
        extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES, ))
    else:
        extractor = extr.FilesystemDataExtractor(
            categories=(defs.KEY_IMAGES, ))

    transform = tfm.Permute(permutation=(2, 0, 1), entries=(defs.KEY_IMAGES, ))

    indexing_strategy = extr.SliceIndexing()
    dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor,
                                   transform)

    direct_extractor = extr.ComposeExtractor([
        extr.ImagePropertiesExtractor(),
        extr.DataExtractor(categories=(defs.KEY_LABELS, defs.KEY_IMAGES))
    ])
    assembler = assm.SubjectAssembler(dataset)

    # PyTorch-specific handling
    pytorch_dataset = pymia_torch.PytorchDatasetAdapter(dataset)
    loader = torch_data.dataloader.DataLoader(pytorch_dataset,
                                              batch_size=2,
                                              shuffle=False)
    dummy_network = nn.Sequential(
        nn.Conv2d(in_channels=2, out_channels=8, kernel_size=3, padding=1),
        nn.Conv2d(in_channels=8, out_channels=1, kernel_size=3, padding=1),
        nn.Sigmoid())
    torch.set_grad_enabled(False)

    nb_batches = len(loader)

    # looping over the data in the dataset
    for i, batch in enumerate(loader):

        x, sample_indices = batch[defs.KEY_IMAGES], batch[defs.KEY_SAMPLE_INDEX]
        prediction = dummy_network(x)

        numpy_prediction = prediction.numpy().transpose((0, 2, 3, 1))

        is_last = i == nb_batches - 1
        assembler.add_batch(numpy_prediction, sample_indices.numpy(), is_last)

        for subject_index in assembler.subjects_ready:
            subject_prediction = assembler.get_assembled_subject(subject_index)

            direct_sample = dataset.direct_extract(direct_extractor, subject_index)
            target, image_properties = direct_sample[defs.KEY_LABELS], direct_sample[defs.KEY_PROPERTIES]
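A sketch of the two invocation modes; the file names are placeholders, not part of the original snippet:

main('example-dataset.h5', is_meta=False)     # image data stored inside the HDF5 file
main('example-metadataset.h5', is_meta=True)  # HDF5 file holds only file paths; data is read from disk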
Example #3
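Examples #3 and #4 additionally assume the following imports, a module-level device, and a local unet module providing the model (the latter two are assumptions inferred from the code, not part of pymia itself):

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as torch_data
import torch.utils.tensorboard as tensorboard

import pymia.data.assembler as assm
import pymia.data.augmentation as augm
import pymia.data.definition as defs
import pymia.data.extraction as extr
import pymia.data.transformation as tfm
import pymia.data.backends.pytorch as pymia_torch
import pymia.evaluation.evaluator as eval_
import pymia.evaluation.metric as metric
import pymia.evaluation.writer as writer

import unet  # local example module defining UNetModel (assumed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # assumed module-level device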
def main(hdf_file, log_dir):
    # initialize the evaluator with the metrics and the labels to evaluate
    metrics = [metric.DiceCoefficient()]
    labels = {1: 'WHITEMATTER',
              2: 'GREYMATTER',
              3: 'HIPPOCAMPUS',
              4: 'AMYGDALA',
              5: 'THALAMUS'}
    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # we want to log the mean and standard deviation of the metrics among all subjects of the dataset
    functions = {'MEAN': np.mean, 'STD': np.std}
    statistics_aggregator = writer.StatisticsAggregator(functions=functions)
    console_writer = writer.ConsoleStatisticsWriter(functions=functions)

    # initialize TensorBoard writer
    tb = tensorboard.SummaryWriter(os.path.join(log_dir, 'logging-example-torch'))

    # set up the training datasource
    train_subjects, valid_subjects = ['Subject_1', 'Subject_2', 'Subject_3'], ['Subject_4']
    extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES, defs.KEY_LABELS))
    indexing_strategy = extr.SliceIndexing()

    augmentation_transforms = [augm.RandomElasticDeformation(), augm.RandomMirror()]
    transforms = [tfm.Permute(permutation=(2, 0, 1)), tfm.Squeeze(entries=(defs.KEY_LABELS,))]
    train_transforms = tfm.ComposeTransform(augmentation_transforms + transforms)
    train_dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, train_transforms,
                                         subject_subset=train_subjects)

    # set up the validation datasource
    valid_transforms = tfm.ComposeTransform([tfm.Permute(permutation=(2, 0, 1))])
    valid_dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, valid_transforms,
                                         subject_subset=valid_subjects)
    direct_extractor = extr.ComposeExtractor(
        [extr.SubjectExtractor(),
         extr.ImagePropertiesExtractor(),
         extr.DataExtractor(categories=(defs.KEY_LABELS,))]
    )
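    # the assembler reassembles slice-wise predictions back into full 3-D volumes per subject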
    assembler = assm.SubjectAssembler(valid_dataset)

    # PyTorch-specific handling
    pytorch_train_dataset = pymia_torch.PytorchDatasetAdapter(train_dataset)
    train_loader = torch_data.dataloader.DataLoader(pytorch_train_dataset, batch_size=16, shuffle=True)

    pytorch_valid_dataset = pymia_torch.PytorchDatasetAdapter(valid_dataset)
    valid_loader = torch_data.dataloader.DataLoader(pytorch_valid_dataset, batch_size=16, shuffle=False)

    u_net = unet.UNetModel(ch_in=2, ch_out=6, n_channels=16, n_pooling=3).to(device)

    print(u_net)

    optimizer = optim.Adam(u_net.parameters(), lr=1e-3)
    train_batches = len(train_loader)

    # looping over the data in the dataset
    epochs = 100
    for epoch in range(epochs):
        u_net.train()
        print(f'Epoch {epoch + 1}/{epochs}')

        # training
        print('training')
        for i, batch in enumerate(train_loader):
            x, y = batch[defs.KEY_IMAGES].to(device), batch[defs.KEY_LABELS].to(device).long()
            logits = u_net(x)

            optimizer.zero_grad()
            loss = F.cross_entropy(logits, y)
            loss.backward()
            optimizer.step()

            tb.add_scalar('train/loss', loss.item(), epoch*train_batches + i)
            print(f'[{i + 1}/{train_batches}]\tloss: {loss.item()}')

        # validation
        print('validation')
        with torch.no_grad():
            u_net.eval()
            valid_batches = len(valid_loader)
            for i, batch in enumerate(valid_loader):
                x, sample_indices = batch[defs.KEY_IMAGES].to(device), batch[defs.KEY_SAMPLE_INDEX]

                logits = u_net(x)
                prediction = logits.argmax(dim=1, keepdim=True)

                numpy_prediction = prediction.cpu().numpy().transpose((0, 2, 3, 1))

                is_last = i == valid_batches - 1
                assembler.add_batch(numpy_prediction, sample_indices.numpy(), is_last)

                for subject_index in assembler.subjects_ready:
                    subject_prediction = assembler.get_assembled_subject(subject_index)

                    # extract from the validation datasource (the assembler's subject
                    # indices refer to the validation subjects, not the training ones)
                    direct_sample = valid_dataset.direct_extract(direct_extractor, subject_index)
                    target, image_properties = direct_sample[defs.KEY_LABELS], direct_sample[defs.KEY_PROPERTIES]

                    # evaluate the prediction against the reference
                    evaluator.evaluate(subject_prediction[..., 0], target[..., 0], direct_sample[defs.KEY_SUBJECT])

            # calculate mean and standard deviation of each metric
            results = statistics_aggregator.calculate(evaluator.results)
            # log to TensorBoard into category valid
            for result in results:
                tb.add_scalar(f'valid/{result.metric}-{result.id_}', result.value, epoch)

            console_writer.write(evaluator.results)

            # clear results such that the evaluator is ready for the next evaluation
            evaluator.clear()
Example #4
def main(hdf_file: str, log_dir: str):
    # initialize the evaluator with the metrics and the labels to evaluate
    metrics = [metric.DiceCoefficient()]
    labels = {1: 'WHITEMATTER', 2: 'GREYMATTER', 5: 'THALAMUS'}
    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # we want to log the mean and standard deviation of the metrics among all subjects of the dataset
    functions = {'MEAN': np.mean, 'STD': np.std}
    statistics_aggregator = writer.StatisticsAggregator(functions=functions)

    # initialize TensorBoard writer
    tb = tensorboard.SummaryWriter(
        os.path.join(log_dir, 'logging-example-torch'))

    # initialize the data handling
    transform = tfm.Permute(permutation=(2, 0, 1), entries=(defs.KEY_IMAGES, ))
    dataset = extr.PymiaDatasource(
        hdf_file, extr.SliceIndexing(),
        extr.DataExtractor(categories=(defs.KEY_IMAGES, )), transform)
    pytorch_dataset = pymia_torch.PytorchDatasetAdapter(dataset)
    loader = torch_data.dataloader.DataLoader(pytorch_dataset,
                                              batch_size=100,
                                              shuffle=False)

    assembler = assm.SubjectAssembler(dataset)
    direct_extractor = extr.ComposeExtractor([
        extr.SubjectExtractor(),  # extraction of the subject name
        extr.ImagePropertiesExtractor(),  # extraction of image properties (origin, spacing, etc.) for storage
        extr.DataExtractor(categories=(defs.KEY_LABELS, ))  # extraction of "labels" entries for evaluation
    ])

    # initialize a dummy network, which returns a random prediction
    class DummyNetwork(nn.Module):
        def forward(self, x):
            # random class labels in [0, 5] with shape (B, 1, H, W)
            return torch.randint(0, 6, (x.size(0), 1, *x.size()[2:]))

    dummy_network = DummyNetwork()
    torch.manual_seed(0)  # set seed for reproducibility

    nb_batches = len(loader)

    epochs = 10
    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')
        for i, batch in enumerate(loader):
            # get the data from batch and predict
            x, sample_indices = batch[defs.KEY_IMAGES], batch[defs.KEY_SAMPLE_INDEX]
            prediction = dummy_network(x)

            # translate the prediction to numpy and back to (B)HWC (channel last)
            numpy_prediction = prediction.numpy().transpose((0, 2, 3, 1))

            # add the batch prediction to the assembler
            is_last = i == nb_batches - 1
            assembler.add_batch(numpy_prediction, sample_indices.numpy(), is_last)

            # process the subjects/images that are fully assembled
            for subject_index in assembler.subjects_ready:
                subject_prediction = assembler.get_assembled_subject(subject_index)

                # extract the target and image properties via direct extract
                direct_sample = dataset.direct_extract(direct_extractor, subject_index)
                reference, image_properties = direct_sample[defs.KEY_LABELS], direct_sample[defs.KEY_PROPERTIES]

                # evaluate the prediction against the reference
                evaluator.evaluate(subject_prediction[..., 0], reference[..., 0], direct_sample[defs.KEY_SUBJECT])

        # calculate mean and standard deviation of each metric
        results = statistics_aggregator.calculate(evaluator.results)
        # log to TensorBoard into category train
        for result in results:
            tb.add_scalar(f'train/{result.metric}-{result.id_}', result.value, epoch)

        # clear results such that the evaluator is ready for the next evaluation
        evaluator.clear()
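For completeness, a minimal entry point in the style of the pymia examples; the default paths are placeholders, not part of the original snippet:

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Evaluation logging example')
    parser.add_argument('--hdf_file', type=str, default='example-data.h5', help='Path to the dataset file (placeholder).')
    parser.add_argument('--log_dir', type=str, default='logs', help='Path to the logging directory (placeholder).')
    args = parser.parse_args()
    main(args.hdf_file, args.log_dir)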