def test_lax_custom_linear_solve(self):
  """An error raised in a custom_linear_solve callback gets a filtered trace.

  The failing callback `err` is passed once in the first argument position
  (where `matvec` goes in f2) and once in the third (where `solve` goes in f1).
  """
  if not traceback_util.filtered_tracebacks_supported():
    raise unittest.SkipTest('Filtered tracebacks not supported')

  # Callback that always fails. `return ()` is only reached if asserts are
  # stripped (python -O), in which case no AssertionError is raised and the
  # checks below would fail.
  def err(*_):
    assert False
    return ()

  # Trivial stand-ins: identity matvec and a solver that ignores its inputs.
  matvec = lambda v: v
  solve = lambda mv, b: 1.
  b = 1.

  # NOTE: the expected-frame patterns below quote these return lines verbatim;
  # keep the pattern strings in sync if these lines are edited.
  def f1():
    return lax.custom_linear_solve(err, b, solve)

  def f2():
    return lax.custom_linear_solve(matvec, b, err)

  # check_filtered_stack_trace is defined elsewhere in this file. From its
  # arguments it presumably asserts that the thunk raises the given exception
  # type and that the filtered traceback ends with the listed
  # (function name, source text) frames.
  check_filtered_stack_trace(
      self, AssertionError, f1,
      [('f1', 'return lax.custom_linear_solve(err, b, solve)'),
       ('err', 'assert False')])
  check_filtered_stack_trace(
      self, AssertionError, f2,
      [('f2', 'return lax.custom_linear_solve(matvec, b, err)'),
       ('err', 'assert False')])
def test_lax_custom_root(self):
  """An error raised in any custom_root callback gets a filtered trace.

  The failing callback `err` is passed in turn as the third argument (f1),
  the fourth argument (f2) and the first argument (f3) of lax.custom_root.
  """
  if not traceback_util.filtered_tracebacks_supported():
    raise unittest.SkipTest('Filtered tracebacks not supported')

  # Callback that always fails. `return ()` is only reached if asserts are
  # stripped (python -O).
  def err(*_):
    assert False
    return ()

  # Residual function whose root is x = 1.
  def g(x):
    return (x - 1.)**2.

  # Solver stand-in that ignores its inputs and returns a constant.
  def solve(*_):
    return 1.

  # NOTE: the expected-frame patterns below quote these return lines verbatim;
  # keep the pattern strings in sync if these lines are edited.
  def f1():
    return lax.custom_root(g, 0., err, solve)

  def f2():
    return lax.custom_root(g, 0., solve, err)

  def f3():
    return lax.custom_root(err, 0., solve, solve)

  # check_filtered_stack_trace is defined elsewhere in this file; presumably it
  # asserts the exception type and the trailing (function name, source text)
  # frames of the filtered traceback.
  check_filtered_stack_trace(
      self, AssertionError, f1,
      [('f1', 'return lax.custom_root(g, 0., err, solve)'),
       ('err', 'assert False')])
  check_filtered_stack_trace(
      self, AssertionError, f2,
      [('f2', 'return lax.custom_root(g, 0., solve, err)'),
       ('err', 'assert False')])
  check_filtered_stack_trace(
      self, AssertionError, f3,
      [('f3', 'return lax.custom_root(err, 0., solve, solve)'),
       ('err', 'assert False')])
def test_lax_cond(self):
  """An error raised in a lax.cond branch function gets a filtered trace."""
  if not traceback_util.filtered_tracebacks_supported():
    raise unittest.SkipTest('Filtered tracebacks not supported')

  # Branch function that always fails (asserts must be enabled).
  def err(_):
    assert False
    return ()

  # `err` is the first branch function and the other branch is a no-op lambda.
  # NOTE: this line is quoted verbatim in the expected frames below.
  def f():
    return lax.cond(True, err, lambda _: (), ())

  # check_filtered_stack_trace (defined elsewhere in this file) presumably
  # asserts the exception type and the trailing filtered frames.
  check_filtered_stack_trace(self, AssertionError, f, [
      ('f', 'return lax.cond(True, err, lambda _: (), ())'),
      ('err', 'assert False')])
def test_lax_scan(self):
  """An error raised in the lax.scan body function gets a filtered trace."""
  if not traceback_util.filtered_tracebacks_supported():
    raise unittest.SkipTest('Filtered tracebacks not supported')

  # Body function that always fails (asserts must be enabled).
  def err(*_):
    assert False
    return ()

  # Empty carry and empty xs. The trailing 3 is presumably the explicit scan
  # length, needed because xs carries no length information. TODO confirm.
  # NOTE: this line is quoted verbatim in the expected frames below.
  def f():
    return lax.scan(err, (), (), 3)

  # check_filtered_stack_trace (defined elsewhere in this file) presumably
  # asserts the exception type and the trailing filtered frames.
  check_filtered_stack_trace(self, AssertionError, f,
                             [('f', 'return lax.scan(err, (), (), 3)'),
                              ('err', 'assert False')])
def test_lax_map(self):
  """An error raised in the function given to lax.map gets a filtered trace."""
  if not traceback_util.filtered_tracebacks_supported():
    raise unittest.SkipTest('Filtered tracebacks not supported')

  # Mapped function that always fails (asserts must be enabled).
  def err(_):
    assert False
    return ()

  # NOTE: the return line is quoted verbatim in the expected frames below.
  def f():
    xs = jnp.ones(3)
    return lax.map(err, xs)

  # check_filtered_stack_trace (defined elsewhere in this file) presumably
  # asserts the exception type and the trailing filtered frames.
  check_filtered_stack_trace(self, AssertionError, f,
                             [('f', 'return lax.map(err, xs)'),
                              ('err', 'assert False')])
def test_lax_associative_scan(self):
  """An error in the associative_scan combiner gets a filtered trace."""
  if not traceback_util.filtered_tracebacks_supported():
    raise unittest.SkipTest('Filtered tracebacks not supported')

  # Combining function that always fails (asserts must be enabled).
  def err(*_):
    assert False
    return ()

  # NOTE: the return line is quoted verbatim in the expected frames below.
  def f():
    xs = jnp.arange(4.)
    return lax.associative_scan(err, xs)

  # check_filtered_stack_trace (defined elsewhere in this file) presumably
  # asserts the exception type and the trailing filtered frames.
  check_filtered_stack_trace(self, AssertionError, f, [
      ('f', 'return lax.associative_scan(err, xs)'),
      ('err', 'assert False')])
def test_lax_while_loop(self):
  """An error raised in the lax.while_loop body gets a filtered trace."""
  if not traceback_util.filtered_tracebacks_supported():
    raise unittest.SkipTest('Filtered tracebacks not supported')

  # Loop body that always fails (asserts must be enabled).
  def err(*_):
    assert False
    return ()

  # The predicate is always False, yet an AssertionError from `err` is still
  # expected. Presumably the body is traced when the loop is staged, even
  # though it never runs at execution time. TODO confirm.
  # NOTE: the return line is quoted verbatim in the expected frames below.
  def f():
    pred = lambda _: False
    return lax.while_loop(pred, err, ())

  # check_filtered_stack_trace (defined elsewhere in this file) presumably
  # asserts the exception type and the trailing filtered frames.
  check_filtered_stack_trace(self, AssertionError, f, [
      ('f', 'return lax.while_loop(pred, err, ())'),
      ('err', 'assert False')])
def test_nested_jit_and_grad(self):
  """A TypeError raised under nested jit + grad gets a filtered trace.

  The expected frames stop at `inbetween`'s grad call, so `innermost` (and
  its `assert False`) is never part of the expected trace. Presumably grad
  rejects the integer-typed input `jnp.array([1, 2])` before `innermost` is
  traced. TODO confirm.
  """
  if not traceback_util.filtered_tracebacks_supported():
    raise unittest.SkipTest('Filtered tracebacks not supported')

  @jit
  def innermost(x):
    assert False

  # NOTE: the return lines of `inbetween` and `outermost`, and the start of
  # the `f = lambda ...` line, are quoted verbatim in the expected frames.
  @jit
  def inbetween(x):
    return 1 + grad(innermost)(x)

  @jit
  def outermost(x):
    return 2 + inbetween(x)

  f = lambda: outermost(jnp.array([1, 2]))

  # check_filtered_stack_trace (defined elsewhere in this file) presumably
  # asserts the exception type and the trailing filtered frames. The patterns
  # here are prefixes of the quoted source lines.
  check_filtered_stack_trace(self, TypeError, f, [
      ('<lambda>', 'f = lambda: outermost'),
      ('outermost', 'return 2 + inbetween(x)'),
      ('inbetween', 'return 1 + grad(innermost)(x)')])
def test_cause_chain(self):
  """Filtered traces preserve an explicit `raise ... from` cause chain.

  The final assertions pin the chain
  TypeError -> ValueError -> traceback_util.FilteredStackTrace.
  """
  if not traceback_util.filtered_tracebacks_supported():
    raise unittest.SkipTest('Filtered tracebacks not supported')

  @jit
  def inner(x):
    raise ValueError('inner')

  # Re-raise the inner failure as a TypeError chained to the ValueError.
  # NOTE: the `raise TypeError` line and the start of `f = lambda ...` are
  # quoted in the expected frames below.
  @jit
  def outer(x):
    try:
      inner(x)
    except ValueError as e:
      raise TypeError('outer') from e

  f = lambda: outer(1.)

  # check_filtered_stack_trace (defined elsewhere in this file) presumably
  # asserts the exception type and the trailing filtered frames.
  check_filtered_stack_trace(self, TypeError, f, [
      ('<lambda>', 'f = lambda: outer'),
      ('outer', 'raise TypeError')])

  # get_exception (defined elsewhere in this file) presumably returns the
  # TypeError raised by f, whose cause chain is then inspected.
  e = get_exception(TypeError, f)
  self.assertIsInstance(e.__cause__, ValueError)
  self.assertIsInstance(e.__cause__.__cause__, traceback_util.FilteredStackTrace)
def skip_if_unsupported_filter_mode(filter_mode):
  """Skip the current test when `filter_mode` cannot be exercised here.

  Args:
    filter_mode: name of the traceback filtering mode under test. Only
      "remove_frames" and "tracebackhide" have environment requirements;
      any other value is accepted without skipping.

  Raises:
    unittest.SkipTest: if "remove_frames" is requested but filtered
      tracebacks are unsupported, or if "tracebackhide" is requested on a
      Python older than 3.7.
  """
  if filter_mode == "remove_frames":
    # Frame removal depends on runtime support reported by traceback_util.
    if not traceback_util.filtered_tracebacks_supported():
      raise unittest.SkipTest('Filtered tracebacks not supported')
  elif filter_mode == "tracebackhide":
    # Only the (major, minor) pair matters for this version gate.
    too_old = sys.version_info[:2] < (3, 7)
    if too_old:
      raise unittest.SkipTest('Tracebackhide requires Python 3.7 or newer')