def test_multiple_decorators(self, func, torch):
    """Stacked ``input`` decorators each name their own target argument.

    ``arg1`` is decorated first (innermost) with names ``("A",)``, then ``arg2``
    (outermost) with ``("B",)``. The wrapped mock must receive both
    positional arguments carrying the respective names.
    """
    first, second = torch.ones(10), torch.ones(10)
    inner = input("arg1", shape=first.shape, name=("A", ))(func)
    outer = input("arg2", shape=second.shape, name=("B", ))(inner)
    outer(arg1=first, arg2=second)

    func.assert_called_once()
    positional = func.call_args[0]
    assert_has_names(positional[0], ("A", ))
    assert_has_names(positional[1], ("B", ))
def test_preserves_docstring(self):
    """The decorated wrapper exposes the wrapped function's docstring."""
    def f(arg1, arg2=None, *args, **kwargs):
        """foo bar baz"""

    wrapper = input("arg1", name=("A", ))(f)
    doc = wrapper.__doc__
    # Guard against both sides being None/empty and comparing equal trivially.
    assert doc
    assert doc == f.__doc__
def test_selects_correct_arg_by_name(self, func, torch):
    """Only the argument named in the decorator receives dimension names."""
    first = torch.ones(10)
    second = torch.ones(10)
    names = ("A", )
    wrapped = input("arg1", name=names)(func)
    wrapped(arg1=first, arg2=second)

    func.assert_called_once()
    passed = func.call_args[0]
    assert_has_names(passed[0], names)
    # arg2 was not targeted, so it must reach func untouched (unnamed).
    assert_is_unnamed(passed[1])
def test_optional(self, func, torch, optional, arg_given):
    """Exercise the ``optional`` flag of the ``input`` decorator on ``arg2``.

    ``optional`` and ``arg_given`` are used as booleans (presumably supplied by
    pytest parametrization).

    - If ``arg2`` is required and omitted, the call must raise ``ValueError``.
      The wrapped function must not have been invoked.
    - In every other combination, the call succeeds and the wrapped function
      is invoked exactly once.
    """
    decorated = input("arg2", shape=(None, ), optional=optional)(func)

    if not optional and not arg_given:
        with pytest.raises(ValueError):
            decorated(arg1=torch.ones(10))
        # Validation failure must short-circuit before the wrapped call.
        func.assert_not_called()
        return

    if arg_given:
        decorated(arg1=torch.ones(10), arg2=torch.ones(10))
    else:
        decorated(arg1=torch.ones(10))
    func.assert_called_once()
def test_coerce_named_input(self, func, shape, names_in, names_out, torch):
    """A tensor already carrying ``names_in`` reaches func carrying ``names_out``."""
    named = torch.ones(*shape, names=names_in)
    input("arg1", name=names_out)(func)(named)

    func.assert_called_once()
    received = func.call_args[0][0]
    assert_has_names(received, names_out)
def test_drop_names(self, func, shape, names, torch):
    """With ``drop_names=True`` the wrapped function receives an unnamed tensor."""
    named = torch.ones(*shape, names=names)
    stripping = input("arg1", name=names, drop_names=True)
    stripping(func)(named)

    func.assert_called_once()
    received = func.call_args[0][0]
    assert_is_unnamed(received)
def test_calls_original_func(self, func, torch):
    """Invoking the decorated callable forwards exactly one call to ``func``."""
    x = torch.ones(10)
    wrapped = input("arg1", shape=x.shape)(func)
    wrapped(x)
    func.assert_called_once()
def test_kw_only_args(self, torch):
    """The decorator can target a keyword-only parameter.

    Previously this test only checked that the call did not raise. It now
    records what the wrapped function receives, then asserts two things:
    the function ran exactly once, and the keyword-only ``arg1`` arrived
    carrying the requested names.
    """
    received = []

    def f(*, arg1, arg2=None, **kwargs):
        """foo bar baz"""
        received.append(arg1)

    decorated = input("arg1", name=("A", ))(f)
    decorated(arg1=torch.ones(10))

    assert len(received) == 1
    assert_has_names(received[0], ("A", ))
def test_preserves_signature(self, func):
    """The wrapper reports the same ``inspect.signature`` as the wrapped callable."""
    wrapped = input("arg1", name=("A", ))(func)
    actual = signature(wrapped)
    expected = signature(func)
    assert actual == expected
def test_validates_shape(self, func, pre, post, torch):
    """A tensor of shape ``post`` is accepted by shape spec ``pre``.

    The tensor is forwarded to ``func`` with its shape unchanged.
    """
    x = torch.ones(*post)
    checked = input("arg1", shape=pre)(func)
    checked(x)

    func.assert_called_once()
    forwarded = func.call_args[0][0]
    assert forwarded.shape == post