def test_multiple_decorators(self, func, torch):
    """Stack two ``output`` decorators and check each one names its own slot.

    ``func`` is configured to return a pair of tensors. The inner decorator
    targets position 0 with names ``("A", )`` and the outer one targets
    position 1 with names ``("B", )``. The wrapped callable must have been
    called exactly once despite the double wrapping.
    """
    outputs = (torch.ones(10), torch.ones(10))
    func.return_value = outputs
    expected_shape = outputs[0].shape

    inner = output(0, shape=expected_shape, name=("A", ))(func)
    decorated = output(1, shape=expected_shape, name=("B", ))(inner)

    result = decorated(torch.ones(10))

    func.assert_called_once()
    for index, expected_names in enumerate((("A", ), ("B", ))):
        assert_has_names(result[index], expected_names)
def test_preserves_docstring(self):
    """The decorated function exposes a non-empty docstring equal to the original's."""

    def documented(arg1, arg2=None, *args, **kwargs):
        """foo bar baz"""

    wrapped = output(name=("A", ))(documented)
    wrapped_doc = wrapped.__doc__

    # Guard against a trivially-true comparison of two empty docstrings.
    assert wrapped_doc
    assert wrapped_doc == documented.__doc__
def test_validates_shape(self, pre, post, torch):
    """A tensor of shape ``post`` passes through ``output(shape=pre)`` keeping its shape.

    ``pre`` and ``post`` are supplied from outside this method (presumably
    by parametrization -- confirm against the enclosing class).
    """
    expected = torch.ones(*post)

    def produce():
        return expected

    returned = output(shape=pre)(produce)()

    assert returned.shape == post
def test_coerce_named_output(self, shape, names_in, names_out, torch):
    """A tensor created with ``names_in`` comes back carrying ``names_out``.

    The wrapped function returns a tensor built with ``names=names_in``;
    after decoration with ``output(name=names_out)`` the result is checked
    against ``names_out``.
    """
    source = torch.ones(*shape, names=names_in)

    def produce():
        return source

    returned = output(name=names_out)(produce)()

    assert_has_names(returned, names_out)
def test_selects_correct_arg_by_pos(self, pos, torch):
    """Only the return value at index ``pos`` is named; the other stays unnamed.

    The wrapped function returns two tensors. After decoration with
    ``output(pos, name=...)``, the element at ``pos`` must carry the added
    names and the remaining element must be unnamed.
    """
    first, second = torch.ones(10), torch.ones(10)

    def produce():
        return first, second

    added_name = ("A", )
    wrapped = output(pos, name=added_name)(produce)
    remaining = list(wrapped())

    # Pull the targeted element out so that index 0 is the untouched one.
    selected = remaining.pop(pos)
    assert_has_names(selected, added_name)
    assert_is_unnamed(remaining[0])
def test_preserves_signature(self):
    """Decorating with ``output`` leaves the function's signature unchanged.

    The sample function mixes a positional parameter, a defaulted parameter,
    ``*args`` and ``**kwargs`` so that each parameter kind is compared.
    """

    def sample(arg1, arg2=None, *args, **kwargs):
        pass

    wrapped = output(name=("A", ))(sample)
    wrapped_signature = signature(wrapped)
    original_signature = signature(sample)

    assert wrapped_signature == original_signature
def test_calls_original_func(self, func, torch):
    """Calling the decorated callable invokes the wrapped ``func`` exactly once.

    ``func`` is configured to return a tensor whose shape matches the
    ``shape`` given to ``output``.
    """
    sample = torch.ones(10)
    func.return_value = sample

    wrapped = output(shape=sample.shape)(func)
    wrapped(sample)

    func.assert_called_once()