Code Example #1
File: test_handler_stats.py    Project: mhubii/MONAI
    def test_loss_dict(self):
        log_stream = StringIO()
        logging.basicConfig(stream=log_stream, level=logging.INFO)
        key_to_handler = 'test_logging'
        key_to_print = 'myLoss1'

        # set up engine
        def _train_func(engine, batch):
            return torch.tensor(0.0)

        engine = Engine(_train_func)

        # set up testing handler
        stats_handler = StatsHandler(
            name=key_to_handler, output_transform=lambda x: {key_to_print: x})
        stats_handler.attach(engine)

        engine.run(range(3), max_epochs=2)

        # check logging output
        output_str = log_stream.getvalue()
        grep = re.compile('.*{}.*'.format(key_to_handler))
        has_key_word = re.compile('.*{}.*'.format(key_to_print))
        for idx, line in enumerate(output_str.split('\n')):
            if grep.match(line):
                if idx in [1, 2, 3, 6, 7, 8]:
                    self.assertTrue(has_key_word.match(line))
Code Example #2
File: test_handler_stats.py    Project: mhubii/MONAI
    def test_metrics_print(self):
        log_stream = StringIO()
        logging.basicConfig(stream=log_stream, level=logging.INFO)
        key_to_handler = 'test_logging'
        key_to_print = 'testing_metric'

        # set up engine
        def _train_func(engine, batch):
            return torch.tensor(0.0)

        engine = Engine(_train_func)

        # set up dummy metric
        @engine.on(Events.EPOCH_COMPLETED)
        def _update_metric(engine):
            current_metric = engine.state.metrics.get(key_to_print, 0.1)
            engine.state.metrics[key_to_print] = current_metric + 0.1

        # set up testing handler
        stats_handler = StatsHandler(name=key_to_handler)
        stats_handler.attach(engine)

        engine.run(range(3), max_epochs=2)

        # check logging output
        output_str = log_stream.getvalue()
        grep = re.compile('.*{}.*'.format(key_to_handler))
        has_key_word = re.compile('.*{}.*'.format(key_to_print))
        for idx, line in enumerate(output_str.split('\n')):
            if grep.match(line):
                if idx in [5, 10]:
                    self.assertTrue(has_key_word.match(line))
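Code examples #1 and #2 above are methods of a unittest test case and omit their imports. A minimal sketch of the scaffolding they assume is shown below; the class name TestHandlerStats and the __main__ runner are assumptions for illustration, not taken from the source file:

import logging
import re
import unittest
from io import StringIO

import torch
from ignite.engine import Engine, Events

from monai.handlers import StatsHandler


class TestHandlerStats(unittest.TestCase):  # assumed class name; the two test methods above go here
    ...


if __name__ == "__main__":
    unittest.main()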
Code Example #3
checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10,
                                     require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                          handler=checkpoint_handler,
                          to_save={
                              'net': net,
                              'opt': opt
                          })

dice_metric = MeanDice(add_sigmoid=True,
                       output_transform=lambda output: (output[0][0], output[2]))
dice_metric.attach(trainer, "Training Dice")

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
stats_logger = StatsHandler()
stats_logger.attach(trainer)


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(engine):
    # log the loss (second item of engine.state.output, i.e. loss.item() from the output_transform) to TensorBoard
    writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch)

    # tensor of ones, used by torch.where to convert the labels to zeros and ones
    ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32)
    first_output_tensor = engine.state.output[0][1][0].detach().cpu()
    # log the model output to TensorBoard as a three-dimensional tensor without a channel dimension
    img2tensorboard.add_animated_gif_no_channels(writer,
                                                 "first_output_final_batch",
                                                 first_output_tensor, 64, 255,
                                                 engine.state.epoch)
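This snippet uses writer and img2tensorboard without showing their setup. A plausible sketch of that setup, assuming PyTorch's TensorBoard SummaryWriter and MONAI's img2tensorboard helpers (the exact module path can differ between MONAI releases):

from torch.utils.tensorboard import SummaryWriter

from monai.visualize import img2tensorboard  # provides add_animated_gif / add_animated_gif_no_channels

writer = SummaryWriter(log_dir='./runs')  # log_dir is an arbitrary choice for this sketch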
Code Example #4
        return seg_probs, seg.to(device)


evaluator = Engine(_sliding_window_processor)

# add evaluation metric to the evaluator engine
MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')

# StatsHandler prints the loss at every iteration and the metrics at every epoch;
# the evaluator does not need per-iteration loss output, so only print the metrics
# (users can also customize the print functions)
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=lambda x: None,  # no need to print the loss value, so disable per-iteration output
)
val_stats_handler.attach(evaluator)

# for the array data format, assume the 3rd item of the batch data is the meta_data
file_saver = SegmentationSaver(
    output_path='tempdir',
    output_ext='.nii.gz',
    output_postfix='seg',
    name='evaluator',
    batch_transform=lambda x: x[2],
    output_transform=lambda output: predict_segmentation(output[0]))
file_saver.attach(evaluator)

# the model was trained by the "unet_training_array" example
ckpt_loader = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth',
                               load_dict={'net': net})
ckpt_loader.attach(evaluator)
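With the metric, stats handler, segmentation saver and checkpoint loader attached, evaluation is a single Engine.run call. A hedged usage sketch, where val_loader is a hypothetical DataLoader yielding (image, segmentation, meta_data) batches as expected by the batch_transform above:

# one pass over the validation data; the CheckpointLoader restores the network
# weights when the run starts, before the first iteration
state = evaluator.run(val_loader, max_epochs=1)
print('evaluation Mean_Dice:', state.metrics['Mean_Dice'])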
Code Example #5
# Since the network outputs logits and a segmentation, we need a custom loss function.
def _loss_fn(i, j):
    return loss(i[0], j)

# Create trainer
device = torch.device("cuda:0")
trainer = create_supervised_trainer(net, opt, _loss_fn, device, False,
                                    output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])

# adding checkpoint handler to save models (network params and optimizer stats) during training
checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                          handler=checkpoint_handler,
                          to_save={'net': net, 'opt': opt})
train_stats_handler = StatsHandler()
train_stats_handler.attach(trainer)

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(engine):
    # log the loss (second item of engine.state.output, i.e. loss.item() from the output_transform) to TensorBoard
    writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch)

    # tensor of ones, used by torch.where below to convert the labels to zeros and ones
    ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32)
    first_output_tensor = engine.state.output[0][1][0].detach().cpu()
    # log the model output to TensorBoard as a three-dimensional tensor without a channel dimension
    img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64,
                                                 255, engine.state.epoch)
    # get label tensor and convert to single class
    first_label_tensor = torch.where(engine.state.batch[1][0] > 0, ones, engine.state.batch[1][0])
    # log the label tensor to TensorBoard; the label taken from the batch still has a channel dimension
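The log_training_loss handler above only fires once the trainer is actually run. A hedged sketch of launching the training loop, where train_loader is a hypothetical DataLoader of (image, label) batches and the epoch count is arbitrary:

# start training; every completed epoch triggers the checkpoint handler,
# the StatsHandler epoch summary and the log_training_loss callback above
state = trainer.run(train_loader, max_epochs=5)
writer.close()  # flush the TensorBoard event file when training finishes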