def test_loss_dict(self):
    """Check StatsHandler logging when output_transform maps the output to a dict.

    The engine's per-iteration output (a zero tensor) is wrapped by
    ``output_transform`` into ``{'myLoss1': output}``. The engine runs 2 epochs
    of 3 iterations. Every captured log line from the ``test_logging`` logger
    that sits at one of the hard-coded output indices 1-3 or 6-8 must mention
    ``myLoss1``.
    """
    log_stream = StringIO()
    # logging.basicConfig() silently does nothing once the root logger already
    # has a handler (pytest installs one, and so may earlier tests). That would
    # leave log_stream empty and turn every check below into a no-op. Attach a
    # dedicated handler with basicConfig's default format instead, and detach
    # it (restoring the root level) when the run finishes.
    root_logger = logging.getLogger()
    previous_level = root_logger.level
    capture_handler = logging.StreamHandler(log_stream)
    capture_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
    root_logger.addHandler(capture_handler)
    root_logger.setLevel(logging.INFO)
    key_to_handler = 'test_logging'
    key_to_print = 'myLoss1'
    try:
        # set up engine
        def _train_func(engine, batch):
            return torch.tensor(0.0)

        engine = Engine(_train_func)

        # set up testing handler
        stats_handler = StatsHandler(
            name=key_to_handler,
            output_transform=lambda x: {key_to_print: x},
        )
        stats_handler.attach(engine)

        engine.run(range(3), max_epochs=2)
    finally:
        root_logger.removeHandler(capture_handler)
        root_logger.setLevel(previous_level)

    # check logging output
    output_str = log_stream.getvalue()
    grep = re.compile('.*{}.*'.format(key_to_handler))
    has_key_word = re.compile('.*{}.*'.format(key_to_print))
    handler_lines = 0
    for idx, line in enumerate(output_str.split('\n')):
        if grep.match(line):
            handler_lines += 1
            if idx in [1, 2, 3, 6, 7, 8]:
                self.assertTrue(has_key_word.match(line))
    # Guard against an empty capture silently making the loop above vacuous.
    self.assertGreater(handler_lines, 0)
def test_metrics_print(self):
    """Check that StatsHandler logs engine metrics after each epoch.

    A dummy metric ``testing_metric`` is incremented by 0.1 on every
    EPOCH_COMPLETED event, and it is registered before the StatsHandler is
    attached. The engine runs 2 epochs of 3 iterations. Every captured log
    line from the ``test_logging`` logger that sits at the hard-coded output
    index 5 or 10 must mention ``testing_metric``.
    """
    log_stream = StringIO()
    # logging.basicConfig() silently does nothing once the root logger already
    # has a handler (pytest installs one, and so may earlier tests). That would
    # leave log_stream empty and turn every check below into a no-op. Attach a
    # dedicated handler with basicConfig's default format instead, and detach
    # it (restoring the root level) when the run finishes.
    root_logger = logging.getLogger()
    previous_level = root_logger.level
    capture_handler = logging.StreamHandler(log_stream)
    capture_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
    root_logger.addHandler(capture_handler)
    root_logger.setLevel(logging.INFO)
    key_to_handler = 'test_logging'
    key_to_print = 'testing_metric'
    try:
        # set up engine
        def _train_func(engine, batch):
            return torch.tensor(0.0)

        engine = Engine(_train_func)

        # set up dummy metric
        @engine.on(Events.EPOCH_COMPLETED)
        def _update_metric(engine):
            current_metric = engine.state.metrics.get(key_to_print, 0.1)
            engine.state.metrics[key_to_print] = current_metric + 0.1

        # set up testing handler
        stats_handler = StatsHandler(name=key_to_handler)
        stats_handler.attach(engine)

        engine.run(range(3), max_epochs=2)
    finally:
        root_logger.removeHandler(capture_handler)
        root_logger.setLevel(previous_level)

    # check logging output
    output_str = log_stream.getvalue()
    grep = re.compile('.*{}.*'.format(key_to_handler))
    has_key_word = re.compile('.*{}.*'.format(key_to_print))
    handler_lines = 0
    for idx, line in enumerate(output_str.split('\n')):
        if grep.match(line):
            handler_lines += 1
            if idx in [5, 10]:
                self.assertTrue(has_key_word.match(line))
    # Guard against an empty capture silently making the loop above vacuous.
    self.assertGreater(handler_lines, 0)
n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ 'net': net, 'opt': opt }) dice_metric = MeanDice(add_sigmoid=True, output_transform=lambda output: (output[0][0], output[2])) dice_metric.attach(trainer, "Training Dice") logging.basicConfig(stream=sys.stdout, level=logging.INFO) stats_logger = StatsHandler() stats_logger.attach(trainer) @trainer.on(Events.EPOCH_COMPLETED) def log_training_loss(engine): # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) # tensor of ones to use where for converting labels to zero and ones ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32) first_output_tensor = engine.state.output[0][1][0].detach().cpu() # log model output to tensorboard, as three dimensional tensor with no channels dimension img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64, 255,
with torch.no_grad(): seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, net, device) return seg_probs, seg.to(device) evaluator = Engine(_sliding_window_processor) # add evaluation metric to the evaluator engine MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice') # StatsHandler prints loss at every iteration and prints metrics at every epoch, # we don't need to print loss for the evaluator, so just print metrics; users can also customize print functions val_stats_handler = StatsHandler( name='evaluator', output_transform=lambda x: None # no need to print loss value, so disable per-iteration output ) val_stats_handler.attach(evaluator) # for the array data format, assume the 3rd item of batch data is the meta_data file_saver = SegmentationSaver( output_path='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator', batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0])) file_saver.attach(evaluator) # the model was trained by the "unet_training_array" example
# adding checkpoint handler to save models (network params and optimizer state) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ 'net': net, 'opt': opt }) # StatsHandler prints loss at every iteration and prints metrics at every epoch, # we don't set metrics for the trainer here, so it just prints loss; users can also customize print functions # and can use output_transform to convert engine.state.output if it's not the loss value train_stats_handler = StatsHandler(name='trainer') train_stats_handler.attach(trainer) # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler train_tensorboard_stats_handler = TensorBoardStatsHandler() train_tensorboard_stats_handler.attach(trainer) # Set parameters for validation validation_every_n_epochs = 1 metric_name = 'Accuracy' # add evaluation metric to the evaluator engine val_metrics = {metric_name: Accuracy()} # ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration, # users can add output_transform to return other values evaluator = create_supervised_evaluator(net,