def test_simple_failing_func(decorator: Decorator) -> None:
    """Check that an exception raised by a decorated function is cached in a dict.

    The first call must raise the original exception. A second call with the
    same arguments must raise again. The dict handler must then hold an
    ``Exception`` whose first argument is that original exception.
    """
    # Store a ref to the thrown exception outside the function so we can
    # check it's the same one that gets cached.
    exception = None

    @decorator
    def f(a: int, b: int) -> int:
        nonlocal exception
        # This exception should be cached by the wrapper, so we only see it
        # once.
        exception = RuntimeError("failure")
        raise exception

    with Context(dict()) as d:
        # Unlike a bare try/except, pytest.raises fails the test if the call
        # does not raise at all.
        with pytest.raises(RuntimeError) as excinfo:
            f(1, 2)
        assert excinfo.value is exception

        # A repeated call with the same arguments must raise again.
        with pytest.raises(RuntimeError):
            f(1, 2)

        assert type(d[key(f, 1, 2)]) is Exception
        assert d[key(f, 1, 2)].args[0] is exception
def test_simple_func(decorator: Decorator) -> None:
    """Check that a decorated function's result is cached in a dict under its key.

    Both the computing call and the repeated (cacheable) call must return the
    correct value. The dict must hold that value under ``key(f, 1, 2)``.
    """

    @decorator
    def f(a: int, b: int) -> int:
        return a + b

    with Context(dict()) as d:
        # First call computes the result.
        assert f(1, 2) == 3
        # Second call with identical arguments must return the same value.
        assert f(1, 2) == 3
        assert d[key(f, 1, 2)] == 3
def test_transforms(decorator: Decorator) -> None:
    """Check that calls to a decorated function can be intercepted and replaced.

    ``g`` calls ``f`` twice. After a transform is registered for the wrapped
    ``f``, the result of ``g`` reflects the transformed ``f``.
    """

    @decorator
    def f(x: int) -> int:
        return x

    @decorator
    def g(a: int, b: int) -> int:
        return f(a) + f(b)

    call_handler = TransformedCallHandler()
    with Context(call_handler):
        # No transform registered yet: g(1, 2) is plain 1 + 2.
        assert g(1, 2) == 3

        def add_one(x):
            return x + 1

        # With f replaced by add_one: (1 + 1) + (2 + 1) == 5.
        call_handler.transforms[f.__wrapped__] = add_one  # type: ignore
        assert g(1, 2) == 5
def test_mock_null_handler(decorator: Decorator) -> None:
    """Check that a null mock handler is called correctly.

    The mock reports every key as absent. The decorated function must
    therefore run, and its result must be stored exactly once. Nothing may be
    read back from the handler.
    """

    @decorator
    def f(a: int, b: int) -> int:
        return a + b

    mock_handler = MagicMock()
    # Simulate an empty cache: every membership test misses.
    mock_handler.__contains__.return_value = False
    mock_handler.__getitem__.return_value = None
    mock_handler.__setitem__.return_value = None

    with Context(mock_handler):
        assert f(1, 2) == 3
        expected_key = key(f, 1, 2)
        mock_handler.__contains__.assert_called_once_with(expected_key)
        mock_handler.__getitem__.assert_not_called()
        mock_handler.__setitem__.assert_called_once_with(expected_key, 3)
def test_fib(decorator: Decorator) -> None:
    """Check that caching a recursive fibonacci stores each computed value."""

    @decorator
    def fib(x: int) -> int:
        if x > 1:
            return fib(x - 1) + fib(x - 2)
        return 1

    # Argument -> expected fibonacci value (with fib(0) == fib(1) == 1).
    expected = {0: 1, 1: 1, 2: 2, 3: 3}

    with Context(dict()) as d:
        for arg, value in expected.items():
            assert fib(arg) == value
        for arg, value in expected.items():
            assert d[key(fib, arg)] == value
def test_simple_graph_exception(decorator: Decorator) -> None:
    """Check that an exception raised in a nested call is recorded by GraphCallHandler.

    ``g`` calls ``f``, which raises. The handler must store the exception in
    ``retvals`` under both ``f``'s and ``g``'s keys. It must also record ``g``
    as the parent of ``f`` in ``parents``.
    """
    # Store a ref to the thrown exception outside the function so we can
    # check it's the same one that gets recorded.
    exception = None

    @decorator
    def f(a: int, b: int) -> int:
        nonlocal exception
        # This exception should be cached by the wrapper, so we only see it
        # once.
        exception = RuntimeError("failure")
        raise exception

    @decorator
    def g(a: int, b: int) -> int:
        return f(a, b)

    a = 1
    b = 2
    handler = GraphCallHandler()
    with Context(handler):
        # Unlike a bare try/except, pytest.raises fails the test if the call
        # does not raise at all.
        with pytest.raises(RuntimeError) as excinfo:
            g(a, b)
        assert excinfo.value is exception

        # A repeated call with the same arguments must raise again.
        with pytest.raises(RuntimeError):
            g(a, b)

        # Exceptions get cached twice (under f's key and g's key) - should
        # this be the case, or do we re-call and throw from source?
        assert type(handler.retvals[key(f, a, b)]) is Exception
        assert type(handler.retvals[key(g, a, b)]) is Exception
        assert handler.retvals[key(f, a, b)].args[0] is exception
        assert handler.retvals[key(g, a, b)].args[0] is exception

        # g is a root call; f was called from g.
        assert handler.parents[key(g, a, b)] == set()
        assert handler.parents[key(f, a, b)] == {key(g, a, b)}
def test_mock_cached_handler(decorator: Decorator) -> None:
    """Check that a fixed value mock handler is called correctly.

    The mock reports every key as present. The decorated function must
    therefore not run inside the context. The cached object must be returned
    as-is, and nothing may be stored.
    """
    # Use a unique sentinel so the `is` check below is meaningful. A small int
    # such as -1 is interned by CPython, so identity would hold even if the
    # wrapper returned a different -1 than the one the handler supplied.
    return_value = object()
    handler = MagicMock()
    handler.__contains__.return_value = True
    handler.__getitem__.return_value = return_value
    handler.__setitem__.return_value = None

    @decorator
    def f(a: int, b: int) -> int:
        raise AssertionError("this function should not be called")

    # Outside a Context the decorated function runs and raises.
    with pytest.raises(AssertionError):
        f(1, 2)

    with Context(handler):
        assert f(1, 2) is return_value
        handler.__contains__.assert_called_once_with(key(f, 1, 2))
        handler.__getitem__.assert_called_once_with(key(f, 1, 2))
        handler.__setitem__.assert_not_called()
branch = self.stack[-1].add( key, style=("bold" if key not in self.seen else "#7F7F7F") ) self.stack.append(branch) return False def __getitem__(self, key): raise NotImplementedError def __setitem__(self, key, value): self.seen.add(key) branch = self.stack.pop(-1) branch.label = repr(key) + " = " + repr(value) with Context(RichTreeCallHandler()) as handler: g(1, 2) print(handler.stack[0]) @shift def fib(n): if n < 2: return 1 return fib(n - 1) + fib(n - 2) with Context(RichTreeCallHandler()) as handler: fib(7) print(handler.stack[0])