def test_save_model_grads():
    """
    Tests gradient-norm logging with `GradNormLogger` during training.

    Runs `SupervisedRunner.train` with criterion, optimizer and
    `GradNormLogger` callbacks, followed by
    `_OnBatchEndCheckGradsCallback`, which is constructed with the
    logger's ``grad_norm_prefix``. The callbacks are registered in that
    order so the checking callback comes after the logger.

    The log directory and the dataset directory are removed afterwards,
    even when data loading or training raises.
    """
    logdir = "./logs"
    dataset_root = "./dataset"

    try:
        loaders = _get_loaders(
            root=dataset_root, batch_size=4, num_workers=1
        )
        # Peek at one batch only to learn the per-sample input shape.
        images, _ = next(iter(loaders["train"]))
        _, c, h, w = images.shape
        input_shape = (c, h, w)

        model = _SimpleNet(input_shape)
        criterion = nn.CrossEntropyLoss()
        optimizer = Adam(model.parameters())

        criterion_callback = CriterionCallback()
        optimizer_callback = OptimizerCallback()
        save_model_grads_callback = GradNormLogger()
        prefix = save_model_grads_callback.grad_norm_prefix
        test_callback = _OnBatchEndCheckGradsCallback(prefix)

        # Keyword order is preserved, so the check callback is listed last.
        callbacks = collections.OrderedDict(
            loss=criterion_callback,
            optimizer=optimizer_callback,
            grad_norm=save_model_grads_callback,
            test_callback=test_callback,
        )

        runner = SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            logdir=logdir,
            callbacks=callbacks,
            check=True,
            verbose=True,
        )
    finally:
        # Always clean up so a failing run does not leave artifacts that
        # could affect other tests. ignore_errors keeps a missing
        # directory from masking the original exception.
        shutil.rmtree(logdir, ignore_errors=True)
        shutil.rmtree(dataset_root, ignore_errors=True)
def test_tracer_callback():
    """
    Tests a feature of `TracerCallback` for model tracing during training

    Builds the expected trace file path from `get_trace_name` under
    ``<logdir>/trace``, runs `SupervisedRunner.train` with criterion,
    optimizer and `TracerCallback` callbacks, and then
    `_OnStageEndCheckModelTracedCallback`, which is given that path and
    a sample batch of inputs.

    The log directory and the dataset directory are removed afterwards,
    even when data loading or training raises.
    """
    logdir = "./logs"
    dataset_root = "./dataset"

    try:
        loaders = _get_loaders(
            root=dataset_root, batch_size=4, num_workers=1
        )
        # One batch supplies both the input shape and the sample inputs
        # handed to the check callback; targets are not needed here.
        images, _ = next(iter(loaders["train"]))
        _, c, h, w = images.shape
        input_shape = (c, h, w)

        model = _TracedNet(input_shape)
        criterion = nn.CrossEntropyLoss()
        optimizer = Adam(model.parameters())

        method_name = "forward"
        mode = "eval"
        requires_grad = False
        checkpoint_name = "best"
        opt_level = None

        # The same settings feed both the expected file name and the
        # TracerCallback configuration below, so they must stay in sync.
        trace_name = get_trace_name(
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            additional_string=checkpoint_name,
        )
        tracing_path = Path(logdir) / "trace" / trace_name
        criterion_callback = CriterionCallback()
        optimizer_callback = OptimizerCallback()
        tracer_callback = TracerCallback(
            metric="loss",
            minimize=False,
            trace_mode=mode,
            mode=checkpoint_name,
            do_once=True,
            method_name=method_name,
            requires_grad=requires_grad,
            opt_level=opt_level,
        )
        test_callback = _OnStageEndCheckModelTracedCallback(
            path=tracing_path, inputs=images,
        )

        # Keyword order is preserved, so the check callback is listed last.
        callbacks = collections.OrderedDict(
            loss=criterion_callback,
            optimizer=optimizer_callback,
            tracer_callback=tracer_callback,
            test_callback=test_callback,
        )

        runner = SupervisedRunner(input_key="x")
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            logdir=logdir,
            callbacks=callbacks,
            check=True,
            verbose=True,
        )
    finally:
        # Always clean up so a failing run does not leave artifacts that
        # could affect other tests. ignore_errors keeps a missing
        # directory from masking the original exception.
        shutil.rmtree(logdir, ignore_errors=True)
        shutil.rmtree(dataset_root, ignore_errors=True)