Пример #1
0
    def test_loss_dict(self):
        """Verify that a dict-valued engine output is logged under its key.

        The handler's ``output_transform`` wraps the engine output in a dict
        keyed by ``'myLoss1'``. After running the engine, every captured log
        line that mentions the handler name and sits at one of the expected
        line indices must also mention that key.
        """
        handler_name = 'test_logging'
        loss_key = 'myLoss1'

        # Capture root-logger output in memory.
        # NOTE(review): logging.basicConfig is a no-op when the root logger
        # already has handlers; in that case nothing is captured and the loop
        # below asserts nothing -- confirm this is acceptable.
        captured = StringIO()
        logging.basicConfig(stream=captured, level=logging.INFO)

        # engine whose step function always returns a zero tensor
        def _train_func(engine, batch):
            return torch.tensor(0.0)

        engine = Engine(_train_func)

        # handler under test: turns the engine output into {key: value}
        handler = StatsHandler(name=handler_name,
                               output_transform=lambda x: {loss_key: x})
        handler.attach(engine)

        engine.run(range(3), max_epochs=2)

        # inspect the captured log, line by line
        handler_pattern = re.compile('.*{}.*'.format(handler_name))
        key_pattern = re.compile('.*{}.*'.format(loss_key))
        expected_rows = (1, 2, 3, 6, 7, 8)
        for row, text in enumerate(captured.getvalue().split('\n')):
            if not handler_pattern.match(text):
                continue
            if row in expected_rows:
                self.assertTrue(key_pattern.match(text))
Пример #2
0
    def test_metrics_print(self):
        """Verify that engine metrics show up in the handler's log output.

        A dummy handler on ``EPOCH_COMPLETED`` stores a metric named
        ``'testing_metric'`` in ``engine.state.metrics``. After running the
        engine, every captured log line that mentions the handler name and
        sits at one of the expected line indices must also mention that
        metric name.
        """
        handler_name = 'test_logging'
        metric_key = 'testing_metric'

        # Capture root-logger output in memory.
        # NOTE(review): logging.basicConfig is a no-op when the root logger
        # already has handlers; in that case nothing is captured and the loop
        # below asserts nothing -- confirm this is acceptable.
        captured = StringIO()
        logging.basicConfig(stream=captured, level=logging.INFO)

        # engine whose step function always returns a zero tensor
        def _train_func(engine, batch):
            return torch.tensor(0.0)

        engine = Engine(_train_func)

        # dummy metric: bump the stored value by 0.1 after each epoch
        # (starting from 0.1 when the key is not present yet)
        @engine.on(Events.EPOCH_COMPLETED)
        def _update_metric(engine):
            metrics = engine.state.metrics
            metrics[metric_key] = metrics.get(metric_key, 0.1) + 0.1

        # handler under test, attached after the dummy metric updater
        handler = StatsHandler(name=handler_name)
        handler.attach(engine)

        engine.run(range(3), max_epochs=2)

        # inspect the captured log, line by line
        handler_pattern = re.compile('.*{}.*'.format(handler_name))
        key_pattern = re.compile('.*{}.*'.format(metric_key))
        expected_rows = (5, 10)
        for row, text in enumerate(captured.getvalue().split('\n')):
            if not handler_pattern.match(text):
                continue
            if row in expected_rows:
                self.assertTrue(key_pattern.match(text))
Пример #3
0
                                     n_saved=10,
                                     require_empty=False)
# Run checkpoint_handler every time the trainer fires EPOCH_COMPLETED,
# handing it the network and the optimizer as the objects to save.
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                          handler=checkpoint_handler,
                          to_save={
                              'net': net,
                              'opt': opt
                          })

# Attach a MeanDice metric to the trainer under the name "Training Dice".
# output_transform feeds the metric (output[0][0], output[2]) of the engine output;
# presumably these are the prediction and the label -- TODO confirm against the trainer's step function.
dice_metric = MeanDice(add_sigmoid=True,
                       output_transform=lambda output:
                       (output[0][0], output[2]))
dice_metric.attach(trainer, "Training Dice")

# Send INFO-level log records to stdout, then attach a StatsHandler
# (constructed with its defaults) to the trainer.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
stats_logger = StatsHandler()
stats_logger.attach(trainer)


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(engine):
    # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform
    writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch)

    # tensor of ones to use where for converting labels to zero and ones
    ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32)
    first_output_tensor = engine.state.output[0][1][0].detach().cpu()
    # log model output to tensorboard, as three dimensional tensor with no channels dimension
    img2tensorboard.add_animated_gif_no_channels(writer,
                                                 "first_output_final_batch",
                                                 first_output_tensor, 64, 255,
Пример #4
0
    with torch.no_grad():
        seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, net,
                                             device)
        return seg_probs, seg.to(device)


# Build the evaluation engine; _sliding_window_processor is its per-iteration step function.
evaluator = Engine(_sliding_window_processor)

# Add the evaluation metric to the evaluator engine under the name 'Mean_Dice'.
MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')

# StatsHandler prints the loss at every iteration and prints metrics at every epoch.
# We don't need to print the loss for the evaluator, so only metrics are printed;
# users can also customize the print functions.
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=lambda x:
    None  # no need to print the loss value, so disable per-iteration output
)
val_stats_handler.attach(evaluator)

# For the array data format, assume the 3rd item of the batch data is the meta_data
# (batch_transform picks x[2]); output_transform runs predict_segmentation on output[0].
file_saver = SegmentationSaver(
    output_path='tempdir',
    output_ext='.nii.gz',
    output_postfix='seg',
    name='evaluator',
    batch_transform=lambda x: x[2],
    output_transform=lambda output: predict_segmentation(output[0]))
file_saver.attach(evaluator)

# the model was trained by the "unet_training_array" example
Пример #5
0
# Add a checkpoint handler to save models (network params and optimizer stats) during training.
checkpoint_handler = ModelCheckpoint('./runs/',
                                     'net',
                                     n_saved=10,
                                     require_empty=False)
# Run the checkpoint handler every time the trainer fires EPOCH_COMPLETED,
# handing it the network and the optimizer as the objects to save.
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                          handler=checkpoint_handler,
                          to_save={
                              'net': net,
                              'opt': opt
                          })

# StatsHandler prints the loss at every iteration and prints metrics at every epoch.
# No metrics are set for the trainer here, so only the loss is printed; users can also
# customize the print functions, and can use output_transform to convert
# engine.state.output if it is not a loss value.
train_stats_handler = StatsHandler(name='trainer')
train_stats_handler.attach(trainer)

# TensorBoardStatsHandler plots the loss at every iteration and plots metrics at every epoch,
# the same as StatsHandler.
train_tensorboard_stats_handler = TensorBoardStatsHandler()
train_tensorboard_stats_handler.attach(trainer)

# Set parameters for validation
validation_every_n_epochs = 1

# Name under which the validation metric is registered.
metric_name = 'Accuracy'
# Define the evaluation metric(s) for the evaluator engine, keyed by metric name.
val_metrics = {metric_name: Accuracy()}
# ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
evaluator = create_supervised_evaluator(net,