Пример #1
0
def test_register_record_function(tmpdir):
    """Check that ``RegisterRecordFunction`` labels each submodule in the profiler events.

    A ``Sequential`` of ``Linear -> ReLU -> Linear`` is run once under a
    ``PyTorchProfiler`` while ``RegisterRecordFunction`` is active. Afterwards the
    profiler's function events must contain a ``[pl][module]`` entry for the
    container (``layer``) and one for each of its three children
    (``layer.0``, ``layer.1``, ``layer.2``).

    Args:
        tmpdir: temporary directory passed to the profiler as ``dirpath``
            (presumably the pytest ``tmpdir`` fixture).
    """
    # Only turn on the profiler's CUDA option (and move the model/batch to the
    # GPU below) when a GPU is actually available, so the test also runs on CPU.
    use_cuda = torch.cuda.is_available()
    pytorch_profiler = PyTorchProfiler(
        export_to_chrome=False,
        use_cuda=use_cuda,
        dirpath=tmpdir,
        filename="profiler",
        schedule=None,
        on_trace_ready=None,
    )

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), torch.nn.Linear(1, 1))

    model = TestModel()
    # Named ``batch`` rather than ``input`` so the builtin is not shadowed.
    batch = torch.rand((1, 1))

    if use_cuda:
        model = model.cuda()
        batch = batch.cuda()

    # Profiler is entered first and exited last, exactly as with nested ``with``s.
    with pytorch_profiler.profile("a"), RegisterRecordFunction(model):
        model(batch)

    # NOTE(review): ``describe()`` is presumably what finalises the profiler so
    # that ``function_events`` is populated — confirm against PyTorchProfiler.
    pytorch_profiler.describe()
    event_names = {event.name for event in pytorch_profiler.function_events}

    expected_names = (
        "[pl][module]torch.nn.modules.container.Sequential: layer",
        "[pl][module]torch.nn.modules.linear.Linear: layer.0",
        "[pl][module]torch.nn.modules.activation.ReLU: layer.1",
        "[pl][module]torch.nn.modules.linear.Linear: layer.2",
    )
    for expected in expected_names:
        assert expected in event_names, f"missing profiler event: {expected}"
def test_pytorch_profiler_register_record_function_deprecation_warning():
    """Constructing ``RegisterRecordFunction`` must raise a deprecation warning.

    ``pytest.deprecated_call`` fails the test unless a ``DeprecationWarning`` or
    ``PendingDeprecationWarning`` whose text matches ``match`` is emitted inside
    the ``with`` body.
    """
    # ``match`` is treated as a regular expression searched within the warning
    # text. NOTE(review): this presumably mirrors the message emitted by
    # RegisterRecordFunction verbatim (including the repeated "in in") — keep
    # the two in sync rather than "fixing" the wording here alone.
    expected_message = "RegisterRecordFunction` is deprecated in v1.7 and will be removed in in v1.9."

    with pytest.deprecated_call(match=expected_message):
        _ = RegisterRecordFunction(None)