def test_multiple_decorators(self, func, torch):
    """Stacking two ``output`` decorators names each selected output.

    The inner decorator targets output 0 with name ``("A",)`` and the outer
    one targets output 1 with name ``("B",)``. The wrapped mock must still be
    called exactly once, and each output must carry its own name.
    """
    ret = (torch.ones(10), torch.ones(10))
    func.return_value = ret
    # Both mocked outputs are torch.ones(10), so they share one shape.
    # (The original wrapped this in parentheses, which looked like a tuple
    # but was just the bare ``torch.Size``.)
    shape = ret[0].shape
    # Apply the decorators by hand: position 0 innermost, position 1 outer.
    # A real name is used here instead of ``_``, which conventionally marks
    # a discarded value rather than one consumed on the next line.
    inner = output(0, shape=shape, name=("A", ))(func)
    decorated = output(1, shape=shape, name=("B", ))(inner)
    result = decorated(torch.ones(10))
    func.assert_called_once()
    assert_has_names(result[0], ("A", ))
    assert_has_names(result[1], ("B", ))
    def test_preserves_docstring(self):
        """The decorated function exposes the original function's docstring."""
        def original(arg1, arg2=None, *args, **kwargs):
            """foo bar baz"""

        wrapped = output(name=("A", ))(original)
        doc = wrapped.__doc__
        # Check truthiness first so a None == None comparison cannot pass.
        assert doc
        assert doc == original.__doc__
    def test_validates_shape(self, pre, post, torch):
        """Decorating with ``output(shape=pre)`` accepts a ``post``-shaped tensor.

        ``pre`` and ``post`` come from the test's parametrization (not visible
        here; presumably pairs of compatible shape specs, TODO confirm). The
        decorated call must succeed and yield a tensor whose shape is ``post``.
        """
        produced = torch.ones(*post)

        def produce():
            return produced

        checked = output(shape=pre)(produce)()
        assert checked.shape == post
    def test_coerce_named_output(self, shape, names_in, names_out, torch):
        """``output(name=names_out)`` gives the returned tensor ``names_out``.

        The source tensor is built with ``names_in``. After passing through the
        decorator, the result must carry ``names_out``.
        """
        source = torch.ones(*shape, names=names_in)

        def produce():
            return source

        renamed = output(name=names_out)(produce)()
        assert_has_names(renamed, names_out)
    def test_selects_correct_arg_by_pos(self, pos, torch):
        """Only the output at index ``pos`` receives the requested names.

        The wrapped function returns two tensors. After decoration, the one at
        ``pos`` must be named and the remaining one must stay unnamed.
        """
        outputs = [torch.ones(10), torch.ones(10)]

        def produce():
            return outputs[0], outputs[1]

        names = ("A", )
        result = list(output(pos, name=names)(produce)())
        # Remove the targeted output; with two outputs exactly one remains.
        targeted = result.pop(pos)
        assert_has_names(targeted, names)
        assert_is_unnamed(result[0])
    def test_preserves_signature(self):
        """The decorated function reports the same signature as the original."""
        def original(arg1, arg2=None, *args, **kwargs):
            ...

        wrapped = output(name=("A", ))(original)
        expected = signature(original)
        assert signature(wrapped) == expected
 def test_calls_original_func(self, func, torch):
     """Calling the decorated wrapper invokes the wrapped callable once."""
     x = torch.ones(10)
     func.return_value = x
     wrapped = output(shape=x.shape)(func)
     wrapped(x)
     func.assert_called_once()