def test_misconfig(self, input_params):
    # Building MeanDice from input_params and feeding it one batch is expected
    # to raise a ValueError whose message matches 'compatib'.
    with self.assertRaisesRegex(ValueError, 'compatib'):
        # The constructor call stays inside the context so that an error raised
        # at construction time is accepted as well as one raised by update().
        metric = MeanDice(**input_params)
        metric.update([
            torch.Tensor([[0, 1], [1, 0]]),  # prediction, shape (2, 2)
            torch.ones((2, 1)),              # target, shape (2, 1)
        ])
def test_shape_mismatch(self, input_params, _expected):
    # A (2, 2) prediction is paired with targets of shape (2, 3) and (3, 2);
    # each update must raise either AssertionError or ValueError.
    metric = MeanDice(**input_params)
    accepted_errors = (AssertionError, ValueError)
    for target_shape in ((2, 3), (3, 2)):
        with self.assertRaises(accepted_errors):
            prediction = torch.Tensor([[0, 1], [1, 0]])
            target = torch.ones(target_shape)
            metric.update([prediction, target])
def test_compute(self, input_params, expected_avg):
    # Two batches with the same (2, 2) prediction but different targets are fed
    # to the metric; the value from compute() is compared against expected_avg.
    metric = MeanDice(**input_params)
    batch_targets = (
        torch.ones((2, 1)),
        torch.Tensor([[1.], [0.]]),
    )
    for target in batch_targets:
        # a fresh prediction tensor is built for every update, as in the two explicit steps
        prediction = torch.Tensor([[0, 1], [1, 0]])
        metric.update([prediction, target])

    self.assertAlmostEqual(metric.compute(), expected_avg)
output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ 'net': net, 'opt': opt }) dice_metric = MeanDice(add_sigmoid=True, output_transform=lambda output: (output[0][0], output[2])) dice_metric.attach(trainer, "Training Dice") logging.basicConfig(stream=sys.stdout, level=logging.INFO) stats_logger = StatsHandler() stats_logger.attach(trainer) @trainer.on(Events.EPOCH_COMPLETED) def log_training_loss(engine): # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) # tensor of ones to use where for converting labels to zero and ones ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32)
sw_batch_size = 4 def _sliding_window_processor(engine, batch): net.eval() img, seg, meta_data = batch with torch.no_grad(): seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, net, device) return seg_probs, seg.to(device) evaluator = Engine(_sliding_window_processor) # add evaluation metric to the evaluator engine MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice') # StatsHandler prints loss at every iteration and prints metrics at every epoch, # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions val_stats_handler = StatsHandler( name='evaluator', output_transform=lambda x: None # no need to print loss value, so disable per iteration output ) val_stats_handler.attach(evaluator) # for the array data format, assume the 3rd item of batch data is the meta_data file_saver = SegmentationSaver( output_path='tempdir', output_ext='.nii.gz', output_postfix='seg',
third_output_tensor = engine.state.output[0][1][2].detach().cpu() img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64, 255, engine.state.epoch) third_label_tensor = torch.where(engine.state.batch[1][2] > 0, ones, engine.state.batch[1][2]) img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64, 255, engine.state.epoch) engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1]) writer = SummaryWriter() # Set parameters for validation validation_every_n_epochs = 1 metric_name = 'Mean_Dice' # add evaluation metric to the evaluator engine val_metrics = {metric_name: MeanDice(add_sigmoid=True)} evaluator = create_supervised_evaluator(net, val_metrics, device, True, output_transform=lambda x, y, y_pred: (y_pred[0], y)) # Add stats event handler to print validation stats via evaluator logging.basicConfig(stream=sys.stdout, level=logging.INFO) val_stats_handler = StatsHandler() val_stats_handler.attach(evaluator) # Add early stopping handler to evaluator. early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) # create a validation data loader
# StatsHandler prints loss at every iteration and print metrics at every epoch, # we don't set metrics for trainer here, so just print loss, user can also customize print functions # and can use output_transform to convert engine.state.output if it's not loss value train_stats_handler = StatsHandler(name='trainer') train_stats_handler.attach(trainer) # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler train_tensorboard_stats_handler = TensorBoardStatsHandler() train_tensorboard_stats_handler.attach(trainer) validation_every_n_iters = 5 # Set parameters for validation metric_name = 'Mean_Dice' # add evaluation metric to the evaluator engine val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)} # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration, # user can add output_transform to return other values evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters)) def run_validation(engine): evaluator.run(val_loader) # Add early stopping handler to evaluator early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)