Пример #1
0
def test_tensor_property_name():
    """Tensors are rendered with ``name=value`` labels when ``property_name=True``.

    Snoops a small function that builds a ``torch.randn((5, 8))`` tensor and
    checks that both the assignment to ``x`` and the returned value are
    reported in the labelled tensor format.
    """
    sink = io.StringIO()
    labelled = torchsnooper.TensorFormat(property_name=True)

    @torchsnooper.snoop(sink, tensor_format=labelled)
    def my_function():
        x = torch.randn((5, 8))
        return x

    my_function()

    output = sink.getvalue()
    # Echo the captured trace; presumably kept so a failing match is easy to inspect.
    print(output)

    # The same repr is expected for the local ``x`` and for the return value.
    expected_repr = (
        'tensor<shape=(5, 8), dtype=float32, device=cpu, requires_grad=False>'
    )
    expected_entries = (
        CallEntry(),
        LineEntry(),
        VariableEntry('x', expected_repr),
        LineEntry(),
        ReturnEntry(),
        ReturnValueEntry(expected_repr),
    )
    assert_output(output, expected_entries)
Пример #2
0
def test_tensor_property_name():
    """Labelled tensor format (``property_name=True``) with the default properties.

    Snoops a function that builds a ``torch.randn((5, 8))`` tensor and checks
    the full labelled repr, including the NaN/Inf and memory-format fields,
    for both the local ``x`` and the returned value.
    """
    sink = io.StringIO()
    labelled = torchsnooper.TensorFormat(property_name=True)

    # A very large max_variable_length, presumably so the long repr below is
    # not truncated by the snooper.
    @torchsnooper.snoop(sink, max_variable_length=100000, tensor_format=labelled)
    def my_function():
        x = torch.randn((5, 8))
        return x

    my_function()

    output = sink.getvalue()
    # Echo the captured trace; presumably kept so a failing match is easy to inspect.
    print(output)

    # The same repr is expected for the local ``x`` and for the return value.
    expected_repr = (
        'tensor<shape=(5, 8), dtype=float32, device=cpu, requires_grad=False, '
        'has_nan=False, has_inf=False, memory_format=torch.contiguous>'
    )
    expected_entries = (
        SourcePathEntry(),
        CallEntry(),
        LineEntry(),
        VariableEntry('x', expected_repr),
        LineEntry(),
        ReturnEntry(),
        ReturnValueEntry(expected_repr),
        ElapsedTimeEntry(),
    )
    assert_output(output, expected_entries)
Пример #3
0
def test_tensor_property_selector():
    """Only the properties passed via ``properties`` appear in the tensor repr.

    Without ``property_name=True`` the values are printed unlabelled. The
    expected output also shows that the selected ``requires_grad`` property
    contributes nothing to the repr when it is ``False``.
    """
    sink = io.StringIO()
    selected = ('shape', 'device', 'requires_grad')
    fmt = torchsnooper.TensorFormat(properties=selected)

    @torchsnooper.snoop(sink, tensor_format=fmt)
    def my_function():
        x = torch.randn((5, 8))
        return x

    my_function()

    output = sink.getvalue()
    # Echo the captured trace; presumably kept so a failing match is easy to inspect.
    print(output)

    # The same repr is expected for the local ``x`` and for the return value.
    expected_repr = 'tensor<(5, 8), cpu>'
    expected_entries = (
        CallEntry(),
        LineEntry(),
        VariableEntry('x', expected_repr),
        LineEntry(),
        ReturnEntry(),
        ReturnValueEntry(expected_repr),
    )
    assert_output(output, expected_entries)