def test_multiple_decorators(self, func, torch):
    """Stacked ``input`` decorators each name only the argument they target.

    The first decorator names ``arg1`` as ``("A",)`` and the second names
    ``arg2`` as ``("B",)``; the wrapped mock must receive both tensors with
    their respective names after a single call.
    """
    args = [torch.ones(10), torch.ones(10)]
    # Give the intermediate wrapper a real name: ``_`` conventionally marks a
    # discarded value, but this one is decorated again on the next line.
    inner = input("arg1", shape=args[0].shape, name=("A", ))(func)
    decorated = input("arg2", shape=args[1].shape, name=("B", ))(inner)
    decorated(arg1=args[0], arg2=args[1])
    func.assert_called_once()
    # The call above uses keywords, but the test expects the tensors in the
    # positional slots of the recorded call.
    assert_has_names(func.call_args[0][0], ("A", ))
    assert_has_names(func.call_args[0][1], ("B", ))
    def test_preserves_docstring(self):
        """Decorating a function with ``input`` keeps its non-empty ``__doc__``."""

        def undecorated(arg1, arg2=None, *args, **kwargs):
            """foo bar baz"""

        wrapped = input("arg1", name=("A", ))(undecorated)
        doc = wrapped.__doc__
        # Guard against both docstrings being None, which would compare equal.
        assert doc
        assert doc == undecorated.__doc__
 def test_selects_correct_arg_by_name(self, func, torch):
     """Only the argument named in the decorator receives the new names."""
     first, second = torch.ones(10), torch.ones(10)
     expected_names = ("A", )
     wrapper = input("arg1", name=expected_names)(func)
     wrapper(arg1=first, arg2=second)
     func.assert_called_once()
     positional = func.call_args[0]
     assert_has_names(positional[0], expected_names)
     # The untargeted argument must be left without names.
     assert_is_unnamed(positional[1])
 def test_optional(self, func, torch, optional, arg_given):
     """Omitting ``arg2`` is an error unless the decorator marks it optional.

     When ``arg2`` is required but not supplied, the call must raise
     ``ValueError`` without reaching the wrapped function. Every other
     combination of ``optional``/``arg_given`` must forward exactly one call.
     """
     decorated = input("arg2", shape=(None, ), optional=optional)(func)
     kwargs = {"arg1": torch.ones(10)}
     if arg_given:
         kwargs["arg2"] = torch.ones(10)
     if not optional and not arg_given:
         with pytest.raises(ValueError):
             decorated(**kwargs)
         # Validation failure should stop the call before the wrapped
         # function is invoked.
         func.assert_not_called()
     else:
         decorated(**kwargs)
         # Match the other tests: one decorated call -> one wrapped call.
         func.assert_called_once()
 def test_coerce_named_input(self, func, shape, names_in, names_out, torch):
     """A tensor named ``names_in`` reaches ``func`` named ``names_out``."""
     source = torch.ones(*shape, names=names_in)
     wrapper = input("arg1", name=names_out)(func)
     wrapper(source)
     func.assert_called_once()
     positional_args = func.call_args[0]
     assert_has_names(positional_args[0], names_out)
 def test_drop_names(self, func, shape, names, torch):
     """With ``drop_names=True`` the wrapped function receives an unnamed tensor."""
     named = torch.ones(*shape, names=names)
     wrapper = input("arg1", name=names, drop_names=True)(func)
     wrapper(named)
     func.assert_called_once()
     first_arg = func.call_args[0][0]
     assert_is_unnamed(first_arg)
 def test_calls_original_func(self, func, torch):
     """A tensor matching the declared shape is passed through to ``func`` once."""
     ones = torch.ones(10)
     expected_shape = ones.shape
     wrapper = input("arg1", shape=expected_shape)(func)
     wrapper(ones)
     func.assert_called_once()
    def test_kw_only_args(self, torch):
        """The decorator handles functions whose parameters are keyword-only.

        Besides not raising, the keyword-only ``arg1`` must reach the wrapped
        function carrying the decorator's names.
        """
        received = {}

        def f(*, arg1, arg2=None, **kwargs):
            # Record the forwarded value so the test can verify naming,
            # not merely that the call succeeded.
            received["arg1"] = arg1

        decorated = input("arg1", name=("A", ))(f)
        decorated(arg1=torch.ones(10))
        assert "arg1" in received
        assert_has_names(received["arg1"], ("A", ))
 def test_preserves_signature(self, func):
     """``input`` leaves the wrapped function's signature unchanged."""
     wrapped = input("arg1", name=("A", ))(func)
     actual = signature(wrapped)
     expected = signature(func)
     assert actual == expected
 def test_validates_shape(self, func, pre, post, torch):
     """A tensor shaped ``post`` passes the ``pre`` shape spec and is forwarded.

     The tensor that reaches ``func`` must still have shape ``post``.
     """
     sample = torch.ones(*post)
     spec_checked = input("arg1", shape=pre)(func)
     spec_checked(sample)
     func.assert_called_once()
     forwarded = func.call_args[0][0]
     assert forwarded.shape == post