def test_loss_dict(self):
    """StatsHandler should print the loss key when iteration output is a dict."""
    stream = StringIO()
    logging.basicConfig(stream=stream, level=logging.INFO)
    logger_name = 'test_logging'
    loss_key = 'myLoss1'

    # dummy trainer: every iteration yields a constant zero loss
    def _train_func(engine, batch):
        return torch.tensor(0.0)

    engine = Engine(_train_func)

    # handler under test wraps the raw iteration output into a dict
    handler = StatsHandler(name=logger_name,
                           output_transform=lambda x: {loss_key: x})
    handler.attach(engine)
    engine.run(range(3), max_epochs=2)

    # lines from our logger at the per-iteration indices must mention the loss key
    logger_pat = re.compile(f'.*{logger_name}.*')
    loss_pat = re.compile(f'.*{loss_key}.*')
    iteration_lines = {1, 2, 3, 6, 7, 8}
    for idx, line in enumerate(stream.getvalue().split('\n')):
        if logger_pat.match(line) and idx in iteration_lines:
            self.assertTrue(loss_pat.match(line))
def test_metrics_print(self):
    """StatsHandler should print engine metrics in the epoch-completion summary."""
    stream = StringIO()
    logging.basicConfig(stream=stream, level=logging.INFO)
    logger_name = 'test_logging'
    metric_name = 'testing_metric'

    # dummy trainer: every iteration yields a constant zero loss
    def _train_func(engine, batch):
        return torch.tensor(0.0)

    engine = Engine(_train_func)

    # dummy metric: bump its value by 0.1 at the end of every epoch
    @engine.on(Events.EPOCH_COMPLETED)
    def _update_metric(engine):
        current_metric = engine.state.metrics.get(metric_name, 0.1)
        engine.state.metrics[metric_name] = current_metric + 0.1

    # handler under test, with its default output handling
    handler = StatsHandler(name=logger_name)
    handler.attach(engine)
    engine.run(range(3), max_epochs=2)

    # the epoch-summary lines (indices 5 and 10) must mention the metric name
    logger_pat = re.compile(f'.*{logger_name}.*')
    metric_pat = re.compile(f'.*{metric_name}.*')
    for idx, line in enumerate(stream.getvalue().split('\n')):
        if logger_pat.match(line) and idx in (5, 10):
            self.assertTrue(metric_pat.match(line))
require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ 'net': net, 'opt': opt }) dice_metric = MeanDice(add_sigmoid=True, output_transform=lambda output: (output[0][0], output[2])) dice_metric.attach(trainer, "Training Dice") logging.basicConfig(stream=sys.stdout, level=logging.INFO) stats_logger = StatsHandler() stats_logger.attach(trainer) @trainer.on(Events.EPOCH_COMPLETED) def log_training_loss(engine): # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) # tensor of ones to use where for converting labels to zero and ones ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32) first_output_tensor = engine.state.output[0][1][0].detach().cpu() # log model output to tensorboard, as three dimensional tensor with no channels dimension img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64, 255, engine.state.epoch)
return seg_probs, seg.to(device) evaluator = Engine(_sliding_window_processor) # add evaluation metric to the evaluator engine MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice') # StatsHandler prints loss at every iteration and print metrics at every epoch, # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions val_stats_handler = StatsHandler( name='evaluator', output_transform=lambda x: None # no need to print loss value, so disable per iteration output ) val_stats_handler.attach(evaluator) # for the arrary data format, assume the 3rd item of batch data is the meta_data file_saver = SegmentationSaver( output_path='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator', batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0])) file_saver.attach(evaluator) # the model was trained by "unet_training_array" exmple ckpt_saver = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net}) ckpt_saver.attach(evaluator)
# Since network outputs logits and segmentation, we need a custom function. def _loss_fn(i, j): return loss(i[0], j) # Create trainer device = torch.device("cuda:0") trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) train_stats_handler = StatsHandler() train_stats_handler.attach(trainer) @trainer.on(Events.EPOCH_COMPLETED) def log_training_loss(engine): # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) # tensor of ones to use where for converting labels to zero and ones ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32) first_output_tensor = engine.state.output[0][1][0].detach().cpu() # log model output to tensorboard, as three dimensional tensor with no channels dimension img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64, 255, engine.state.epoch) # get label tensor and convert to single class first_label_tensor = torch.where(engine.state.batch[1][0] > 0, ones, engine.state.batch[1][0]) # log label tensor to tensorboard, there is a channel dimension when getting label from batch