Example #1
    def test_misconfig(self, input_params):
        with self.assertRaisesRegex(ValueError, 'compatib'):
            dice_metric = MeanDice(**input_params)

            y_pred = torch.Tensor([[0, 1], [1, 0]])
            y = torch.ones((2, 1))
            dice_metric.update([y_pred, y])
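The input_params argument in these tests is injected by parameterization. A minimal sketch of how such a case could be wired up with the parameterized package (the parameter values are illustrative assumptions, not the actual MONAI test table):

import unittest
from parameterized import parameterized
from monai.handlers import MeanDice

# hypothetical pair of settings assumed to be mutually incompatible
TEST_CASE_BAD = {'add_sigmoid': True, 'mutually_exclusive': True}

class TestMeanDiceMisconfig(unittest.TestCase):

    @parameterized.expand([(TEST_CASE_BAD,)])
    def test_misconfig(self, input_params):
        with self.assertRaisesRegex(ValueError, 'compatib'):
            MeanDice(**input_params)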
Example #2
    def test_compute(self, input_params, expected_avg):
        dice_metric = MeanDice(**input_params)

        y_pred = torch.Tensor([[0, 1], [1, 0]])
        y = torch.ones((2, 1))
        dice_metric.update([y_pred, y])

        y_pred = torch.Tensor([[0, 1], [1, 0]])
        y = torch.Tensor([[1.], [0.]])
        dice_metric.update([y_pred, y])

        avg_dice = dice_metric.compute()
        self.assertAlmostEqual(avg_dice, expected_avg)
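For reference, the average that compute() returns can be reproduced by hand: the Dice score of a prediction A against a label B is 2|A∩B| / (|A| + |B|), accumulated over the items seen by update(). A self-contained sketch in plain PyTorch (independent of MONAI, assuming already-binarized inputs):

import torch

def binary_dice(y_pred, y, eps=1e-8):
    # per-item Dice: 2*|A∩B| / (|A| + |B|) over flattened binary masks
    y_pred = y_pred.float().flatten(1)
    y = y.float().flatten(1)
    intersection = (y_pred * y).sum(dim=1)
    return (2.0 * intersection + eps) / (y_pred.sum(dim=1) + y.sum(dim=1) + eps)

print(binary_dice(torch.tensor([[1., 0.]]), torch.tensor([[1., 1.]])))  # ~0.6667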
Example #3
    def test_shape_mismatch(self, input_params, _expected):
        dice_metric = MeanDice(**input_params)
        with self.assertRaises((AssertionError, ValueError)):
            y_pred = torch.Tensor([[0, 1], [1, 0]])
            y = torch.ones((2, 3))
            dice_metric.update([y_pred, y])

        with self.assertRaises((AssertionError, ValueError)):
            y_pred = torch.Tensor([[0, 1], [1, 0]])
            y = torch.ones((3, 2))
            dice_metric.update([y_pred, y])
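The test accepts either AssertionError or ValueError because the mismatch may be caught at different validation layers. An illustrative guard of the kind that update() performs (not MONAI's actual source):

def check_shapes(y_pred, y):
    # batch sizes must agree, and the label must be single-channel
    # or shaped like the prediction
    if y_pred.shape[0] != y.shape[0]:
        raise ValueError(f'batch size mismatch: {y_pred.shape} vs {y.shape}')
    if y.shape[1] not in (1, y_pred.shape[1]):
        raise ValueError(f'channel mismatch: {y_pred.shape} vs {y.shape}')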
Example #4
# `net`, `opt`, `loss` and `device` come from earlier in the original script;
# the output_transform packs [y_pred, loss.item(), y] into engine.state.output
trainer = create_supervised_trainer(net, opt, loss, device, False,
                                    output_transform=lambda x, y, y_pred, loss:
                                    [y_pred, loss.item(), y])

checkpoint_handler = ModelCheckpoint('./',
                                     'net',
                                     n_saved=10,
                                     require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                          handler=checkpoint_handler,
                          to_save={
                              'net': net,
                              'opt': opt
                          })

dice_metric = MeanDice(add_sigmoid=True,
                       output_transform=lambda output:
                       (output[0][0], output[2]))
dice_metric.attach(trainer, "Training Dice")

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
stats_logger = StatsHandler()
stats_logger.attach(trainer)


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(engine):
    # log the loss to TensorBoard; engine.state.output[1] is loss.item() from the output_transform above
    writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch)

    # tensor of ones used with torch.where to convert label values to zeros and ones
    ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32)
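The ones tensor feeds torch.where, which maps every positive label value to 1 and keeps zeros, i.e. it binarizes the labels. A self-contained illustration:

import torch

labels = torch.tensor([[0, 2, 5], [0, 0, 1]], dtype=torch.int32)
ones = torch.ones(labels.shape, dtype=torch.int32)
print(torch.where(labels > 0, ones, labels))  # [[0, 1, 1], [0, 0, 1]]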
Example #5
sw_batch_size = 4


def _sliding_window_processor(engine, batch):
    net.eval()
    img, seg, meta_data = batch
    with torch.no_grad():
        seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, net,
                                             device)
        return seg_probs, seg.to(device)


evaluator = Engine(_sliding_window_processor)

# add evaluation metric to the evaluator engine
MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')

# StatsHandler prints the loss at every iteration and the metrics at every epoch;
# the evaluator has no per-iteration loss to report, so only metrics are printed
# (the print functions can also be customized)
val_stats_handler = StatsHandler(
    name='evaluator',
    output_transform=lambda x: None  # no loss to print, so disable per-iteration output
)
val_stats_handler.attach(evaluator)

# for the array data format, assume the 3rd item of the batch data is the meta_data
file_saver = SegmentationSaver(
    output_path='tempdir',
    output_ext='.nii.gz',
    output_postfix='seg',
    batch_transform=lambda batch: batch[2])  # pass the meta_data (3rd batch item) to the saver
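Like the other handlers, the saver is attached to the engine; running the evaluator then drives sliding-window inference, metric computation and file output together. A sketch assuming a val_loader has been created:

file_saver.attach(evaluator)
state = evaluator.run(val_loader)
print('Mean Dice:', state.metrics['Mean_Dice'])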
Example #6
    third_output_tensor = engine.state.output[0][1][2].detach().cpu()
    img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64,
                                                 255, engine.state.epoch)
    third_label_tensor = torch.where(engine.state.batch[1][2] > 0, ones, engine.state.batch[1][2])
    img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64,
                                     255, engine.state.epoch)
    engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1])

writer = SummaryWriter()

# Set parameters for validation
validation_every_n_epochs = 1
metric_name = 'Mean_Dice'

# add evaluation metric to the evaluator engine
val_metrics = {metric_name: MeanDice(add_sigmoid=True)}
evaluator = create_supervised_evaluator(net, val_metrics, device, True,
                                        output_transform=lambda x, y, y_pred: (y_pred[0], y))

# Add stats event handler to print validation stats via evaluator
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
val_stats_handler = StatsHandler()
val_stats_handler.attach(evaluator)

# Add early stopping handler to evaluator.
early_stopper = EarlyStopping(patience=4,
                              score_function=stopping_fn_from_metric(metric_name),
                              trainer=trainer)
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

# create a validation data loader
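A minimal sketch of the loader the comment refers to (val_ds and its batch size are assumptions, not the original code):

from torch.utils.data import DataLoader

val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=True)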
Example #7
# StatsHandler prints the loss at every iteration and the metrics at every epoch;
# no metrics are set for the trainer here, so only the loss is printed
# (the print functions can be customized, and output_transform can convert
# engine.state.output if it is not the loss value)
train_stats_handler = StatsHandler(name='trainer')
train_stats_handler.attach(trainer)

# TensorBoardStatsHandler plots the loss at every iteration and the metrics at every epoch, analogous to StatsHandler
train_tensorboard_stats_handler = TensorBoardStatsHandler()
train_tensorboard_stats_handler.attach(trainer)


# Set parameters for validation
validation_every_n_iters = 5
metric_name = 'Mean_Dice'
# add evaluation metric to the evaluator engine
val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}

# the ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every
# iteration; an output_transform can be added to return other values
evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)


@trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
def run_validation(engine):
    evaluator.run(val_loader)


# Add early stopping handler to evaluator
early_stopper = EarlyStopping(patience=4,
                              score_function=stopping_fn_from_metric(metric_name),
                              trainer=trainer)
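stopping_fn_from_metric turns a metric name into the score function EarlyStopping expects (larger scores are better); conceptually it reduces to the sketch below, and the handler still has to be attached to the evaluator, as in the previous example:

def score_fn(engine):
    # higher Mean_Dice is better, matching EarlyStopping's convention
    return engine.state.metrics[metric_name]

evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)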