def main(hdf_file):
    """Run a patch-wise inference demo whose network uses valid (unpadded) convolutions.

    Image patches are extracted with an extra padding of (2, 2, 2) so that the
    two unpadded 3x3x3 convolutions of the dummy network produce outputs of the
    32x32x32 patch shape. The patch outputs are stitched back into whole
    subjects by a ``SubjectAssembler``. For every completed subject, the labels,
    images and image properties are read directly from the dataset.

    Args:
        hdf_file: Path to the pymia dataset file.
    """
    # Pad the image entries to compensate for the valid convolutions below.
    padded_extractor = extr.PadDataExtractor(
        (2, 2, 2),
        extr.DataExtractor(categories=(defs.KEY_IMAGES, )),
    )
    # Move axis 3 (channels) to the front; one axis more than the 2-D case.
    to_channel_first = tfm.Permute(permutation=(3, 0, 1, 2), entries=(defs.KEY_IMAGES, ))
    # The patch shape equals the network output shape.
    patch_indexing = extr.PatchWiseIndexing(patch_shape=(32, 32, 32))
    datasource = extr.PymiaDatasource(hdf_file, patch_indexing, padded_extractor, to_channel_first)

    # Reads whole subjects (properties, labels, images) independent of the patch indexing.
    subject_extractor = extr.ComposeExtractor([
        extr.ImagePropertiesExtractor(),
        extr.DataExtractor(categories=(defs.KEY_LABELS, defs.KEY_IMAGES)),
    ])
    assembler = assm.SubjectAssembler(datasource)

    # torch specific handling
    torch_datasource = pymia_torch.PytorchDatasetAdapter(datasource)
    batch_loader = torch_data.dataloader.DataLoader(torch_datasource, batch_size=2, shuffle=False)

    # Dummy CNN with valid (padding=0) instead of same convolutions.
    layers = [
        nn.Conv3d(in_channels=2, out_channels=8, kernel_size=3, padding=0),
        nn.Conv3d(in_channels=8, out_channels=1, kernel_size=3, padding=0),
        nn.Sigmoid(),
    ]
    net = nn.Sequential(*layers)

    # Note: this disables autograd globally, not only inside this function.
    torch.set_grad_enabled(False)

    last_batch_idx = len(batch_loader) - 1
    for batch_idx, batch in enumerate(batch_loader):
        patches = batch[defs.KEY_IMAGES]
        sample_indices = batch[defs.KEY_SAMPLE_INDEX]

        output = net(patches)
        # Move the channel axis (1) to the end before handing it to the assembler.
        channel_last = output.numpy().transpose((0, 2, 3, 4, 1))
        assembler.add_batch(channel_last, sample_indices.numpy(), batch_idx == last_batch_idx)

        for subject_index in assembler.subjects_ready:
            subject_prediction = assembler.get_assembled_subject(subject_index)
            whole_subject = datasource.direct_extract(subject_extractor, subject_index)
            target = whole_subject[defs.KEY_LABELS]
            image_properties = whole_subject[defs.KEY_PROPERTIES]
            # subject_prediction, target and image_properties are available here
            # (e.g. for evaluation or saving); this demo does not use them further.
def main(hdf_file, is_meta):
    """Run slice-wise inference with a dummy 2-D network and reassemble subjects.

    Each slice passes through two same-padded 3x3 convolutions and a sigmoid.
    The batch outputs are stitched back into whole subjects by a
    ``SubjectAssembler``. For every completed subject, the labels, images and
    image properties are read directly from the dataset.

    Args:
        hdf_file: Path to the pymia dataset file.
        is_meta: If truthy, images are read with ``FilesystemDataExtractor``
            (presumably from files referenced by a meta dataset -- verify);
            otherwise they are read with ``DataExtractor``.
    """
    image_keys = (defs.KEY_IMAGES, )
    extractor_type = extr.FilesystemDataExtractor if is_meta else extr.DataExtractor
    image_extractor = extractor_type(categories=image_keys)

    # Move axis 2 (channels) to the front, as the convolutions expect.
    to_channel_first = tfm.Permute(permutation=(2, 0, 1), entries=image_keys)
    datasource = extr.PymiaDatasource(hdf_file, extr.SliceIndexing(), image_extractor, to_channel_first)

    # Reads whole subjects (properties, labels, images) independent of the slice indexing.
    subject_extractor = extr.ComposeExtractor([
        extr.ImagePropertiesExtractor(),
        extr.DataExtractor(categories=(defs.KEY_LABELS, defs.KEY_IMAGES)),
    ])
    assembler = assm.SubjectAssembler(datasource)

    # torch specific handling
    torch_datasource = pymia_torch.PytorchDatasetAdapter(datasource)
    batch_loader = torch_data.dataloader.DataLoader(torch_datasource, batch_size=2, shuffle=False)

    layers = [
        nn.Conv2d(in_channels=2, out_channels=8, kernel_size=3, padding=1),
        nn.Conv2d(in_channels=8, out_channels=1, kernel_size=3, padding=1),
        nn.Sigmoid(),
    ]
    net = nn.Sequential(*layers)

    # Note: this disables autograd globally, not only inside this function.
    torch.set_grad_enabled(False)

    last_batch_idx = len(batch_loader) - 1
    for batch_idx, batch in enumerate(batch_loader):
        slices = batch[defs.KEY_IMAGES]
        sample_indices = batch[defs.KEY_SAMPLE_INDEX]

        output = net(slices)
        # Move the channel axis (1) to the end before handing it to the assembler.
        channel_last = output.numpy().transpose((0, 2, 3, 1))
        assembler.add_batch(channel_last, sample_indices.numpy(), batch_idx == last_batch_idx)

        for subject_index in assembler.subjects_ready:
            subject_prediction = assembler.get_assembled_subject(subject_index)
            whole_subject = datasource.direct_extract(subject_extractor, subject_index)
            target = whole_subject[defs.KEY_LABELS]
            image_properties = whole_subject[defs.KEY_PROPERTIES]
            # subject_prediction, target and image_properties are available here
            # (e.g. for evaluation or saving); this demo does not use them further.
def main(hdf_file, log_dir):
    """Train a U-Net slice-wise and log per-epoch validation Dice statistics.

    Subjects 1-3 are used for training, with random elastic deformation and
    mirroring. Subject 4 is used for validation. After every epoch the
    validation slice predictions are reassembled into whole subjects and
    evaluated with the Dice coefficient. The per-label mean/std is then logged
    to TensorBoard under ``valid/`` and printed to the console. The training
    loss is logged per batch under ``train/loss``.

    Args:
        hdf_file: Path to the pymia dataset file.
        log_dir: Directory under which the TensorBoard run
            ``logging-example-torch`` is created.

    Note:
        Uses a ``device`` name that is defined outside this function.
    """
    # initialize the evaluator with the metrics and the labels to evaluate
    metrics = [metric.DiceCoefficient()]
    labels = {1: 'WHITEMATTER', 2: 'GREYMATTER', 3: 'HIPPOCAMPUS', 4: 'AMYGDALA', 5: 'THALAMUS'}
    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # log the mean and standard deviation of the metrics among all subjects of the dataset
    functions = {'MEAN': np.mean, 'STD': np.std}
    statistics_aggregator = writer.StatisticsAggregator(functions=functions)
    console_writer = writer.ConsoleStatisticsWriter(functions=functions)

    # initialize TensorBoard writer; closed in the finally clause so events are flushed
    tb = tensorboard.SummaryWriter(os.path.join(log_dir, 'logging-example-torch'))
    try:
        # setup the training datasource
        train_subjects, valid_subjects = ['Subject_1', 'Subject_2', 'Subject_3'], ['Subject_4']
        extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES, defs.KEY_LABELS))
        indexing_strategy = extr.SliceIndexing()

        augmentation_transforms = [augm.RandomElasticDeformation(), augm.RandomMirror()]
        transforms = [tfm.Permute(permutation=(2, 0, 1)), tfm.Squeeze(entries=(defs.KEY_LABELS,))]
        train_transforms = tfm.ComposeTransform(augmentation_transforms + transforms)
        train_dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, train_transforms,
                                             subject_subset=train_subjects)

        # setup the validation datasource (no augmentation)
        valid_transforms = tfm.ComposeTransform([tfm.Permute(permutation=(2, 0, 1))])
        valid_dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, valid_transforms,
                                             subject_subset=valid_subjects)

        direct_extractor = extr.ComposeExtractor(
            [extr.SubjectExtractor(),
             extr.ImagePropertiesExtractor(),
             extr.DataExtractor(categories=(defs.KEY_LABELS,))]
        )
        assembler = assm.SubjectAssembler(valid_dataset)

        # torch specific handling
        pytorch_train_dataset = pymia_torch.PytorchDatasetAdapter(train_dataset)
        train_loader = torch_data.dataloader.DataLoader(pytorch_train_dataset, batch_size=16, shuffle=True)
        pytorch_valid_dataset = pymia_torch.PytorchDatasetAdapter(valid_dataset)
        valid_loader = torch_data.dataloader.DataLoader(pytorch_valid_dataset, batch_size=16, shuffle=False)

        u_net = unet.UNetModel(ch_in=2, ch_out=6, n_channels=16, n_pooling=3).to(device)
        print(u_net)

        optimizer = optim.Adam(u_net.parameters(), lr=1e-3)

        # batch counts do not change between epochs
        train_batches = len(train_loader)
        valid_batches = len(valid_loader)

        epochs = 100
        for epoch in range(epochs):
            u_net.train()
            print(f'Epoch {epoch + 1}/{epochs}')

            # training
            print('training')
            for i, batch in enumerate(train_loader):
                x, y = batch[defs.KEY_IMAGES].to(device), batch[defs.KEY_LABELS].to(device).long()
                logits = u_net(x)

                optimizer.zero_grad()
                loss = F.cross_entropy(logits, y)
                loss.backward()
                optimizer.step()

                # global step = number of training batches seen so far
                tb.add_scalar('train/loss', loss.item(), epoch * train_batches + i)
                print(f'[{i + 1}/{train_batches}]\tloss: {loss.item()}')

            # validation
            print('validation')
            with torch.no_grad():
                u_net.eval()
                for i, batch in enumerate(valid_loader):
                    x, sample_indices = batch[defs.KEY_IMAGES].to(device), batch[defs.KEY_SAMPLE_INDEX]
                    logits = u_net(x)
                    prediction = logits.argmax(dim=1, keepdim=True)

                    # move the channel axis (1) to the end before handing it to the assembler
                    numpy_prediction = prediction.cpu().numpy().transpose((0, 2, 3, 1))

                    is_last = i == valid_batches - 1
                    assembler.add_batch(numpy_prediction, sample_indices.numpy(), is_last)

                    for subject_index in assembler.subjects_ready:
                        subject_prediction = assembler.get_assembled_subject(subject_index)

                        # the assembler was built on the validation datasource, so resolve
                        # the subject index against that same datasource
                        direct_sample = valid_dataset.direct_extract(direct_extractor, subject_index)
                        target = direct_sample[defs.KEY_LABELS]

                        # evaluate the prediction against the reference
                        evaluator.evaluate(subject_prediction[..., 0], target[..., 0],
                                           direct_sample[defs.KEY_SUBJECT])

            # calculate mean and standard deviation of each metric
            results = statistics_aggregator.calculate(evaluator.results)
            # log to TensorBoard into category valid
            for result in results:
                tb.add_scalar(f'valid/{result.metric}-{result.id_}', result.value, epoch)

            console_writer.write(evaluator.results)

            # clear results such that the evaluator is ready for the next evaluation
            evaluator.clear()
    finally:
        # flush pending events and release the event file
        tb.close()
def main(hdf_file: str, log_dir: str):
    """Evaluate random predictions of a dummy network and log Dice statistics.

    Image slices are fed through a network that returns random integer labels
    in [0, 6). The slice predictions are reassembled into whole subjects and
    compared against the reference labels with the Dice coefficient. The
    per-label mean/std is logged to TensorBoard under ``train/`` once per epoch.

    Args:
        hdf_file: Path to the pymia dataset file.
        log_dir: Directory under which the TensorBoard run
            ``logging-example-torch`` is created.
    """
    # initialize the evaluator with the metrics and the labels to evaluate
    metrics = [metric.DiceCoefficient()]
    labels = {1: 'WHITEMATTER', 2: 'GREYMATTER', 5: 'THALAMUS'}
    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # we want to log the mean and standard deviation of the metrics among all subjects of the dataset
    functions = {'MEAN': np.mean, 'STD': np.std}
    statistics_aggregator = writer.StatisticsAggregator(functions=functions)

    # initialize TensorBoard writer; closed in the finally clause so events are flushed
    tb = tensorboard.SummaryWriter(os.path.join(log_dir, 'logging-example-torch'))
    try:
        # initialize the data handling
        transform = tfm.Permute(permutation=(2, 0, 1), entries=(defs.KEY_IMAGES, ))
        dataset = extr.PymiaDatasource(hdf_file, extr.SliceIndexing(),
                                       extr.DataExtractor(categories=(defs.KEY_IMAGES, )), transform)
        pytorch_dataset = pymia_torch.PytorchDatasetAdapter(dataset)
        loader = torch_data.dataloader.DataLoader(pytorch_dataset, batch_size=100, shuffle=False)

        assembler = assm.SubjectAssembler(dataset)
        direct_extractor = extr.ComposeExtractor([
            extr.SubjectExtractor(),  # extraction of the subject name
            extr.ImagePropertiesExtractor(),  # extraction of image properties (origin, spacing, etc.) for storage
            extr.DataExtractor(categories=(defs.KEY_LABELS, )),  # extraction of "labels" entries for evaluation
        ])

        # initialize a dummy network, which returns a random prediction
        class DummyNetwork(nn.Module):
            def forward(self, x):
                # one channel of random labels in [0, 6) with the batch and spatial size of x
                return torch.randint(0, 6, (x.size(0), 1, *x.size()[2:]))

        dummy_network = DummyNetwork()
        torch.manual_seed(0)  # set seed for reproducibility

        nb_batches = len(loader)

        epochs = 10
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}')
            for i, batch in enumerate(loader):
                # get the data from batch and predict
                x, sample_indices = batch[defs.KEY_IMAGES], batch[defs.KEY_SAMPLE_INDEX]
                prediction = dummy_network(x)

                # translate the prediction to numpy and back to (B)HWC (channel last)
                numpy_prediction = prediction.numpy().transpose((0, 2, 3, 1))

                # add the batch prediction to the assembler
                is_last = i == nb_batches - 1
                assembler.add_batch(numpy_prediction, sample_indices.numpy(), is_last)

                # process the subjects/images that are fully assembled
                for subject_index in assembler.subjects_ready:
                    subject_prediction = assembler.get_assembled_subject(subject_index)

                    # extract the reference labels via direct extract; the image properties
                    # remain available as direct_sample[defs.KEY_PROPERTIES] for storage
                    direct_sample = dataset.direct_extract(direct_extractor, subject_index)
                    reference = direct_sample[defs.KEY_LABELS]

                    # evaluate the prediction against the reference
                    evaluator.evaluate(subject_prediction[..., 0], reference[..., 0],
                                       direct_sample[defs.KEY_SUBJECT])

            # calculate mean and standard deviation of each metric
            results = statistics_aggregator.calculate(evaluator.results)
            # log to TensorBoard into category train
            for result in results:
                tb.add_scalar(f'train/{result.metric}-{result.id_}', result.value, epoch)

            # clear results such that the evaluator is ready for the next evaluation
            evaluator.clear()
    finally:
        # flush pending events and release the event file
        tb.close()