예제 #1
0
    def test_lax_custom_linear_solve(self):
        """Filtered tracebacks for a failing callback passed to lax.custom_linear_solve.

        The always-failing callback ``err`` is placed first in the first
        argument slot (where ``f2`` passes ``matvec``), then in the third slot
        (where ``f1`` passes ``solve``). Each case expects the calling frame
        followed by the ``err`` frame.
        """
        if not traceback_util.filtered_tracebacks_supported():
            raise unittest.SkipTest('Filtered tracebacks not supported')

        # Callback that always fails; the `return ()` is unreachable while
        # assertions are enabled.
        def err(*_):
            assert False
            return ()

        # Well-behaved stand-ins for whichever argument slot `err` is not in.
        matvec = lambda v: v
        solve = lambda mv, b: 1.
        b = 1.

        # NOTE: the `return` lines of f1/f2 must stay textually identical to the
        # source-line strings passed to check_filtered_stack_trace below.
        def f1():
            return lax.custom_linear_solve(err, b, solve)

        def f2():
            return lax.custom_linear_solve(matvec, b, err)

        # `check_filtered_stack_trace` is defined elsewhere; presumably it calls
        # the function, expects the given exception type, and compares the
        # filtered trace against these (function name, source line) frames.
        check_filtered_stack_trace(
            self, AssertionError, f1,
            [('f1', 'return lax.custom_linear_solve(err, b, solve)'),
             ('err', 'assert False')])
        check_filtered_stack_trace(
            self, AssertionError, f2,
            [('f2', 'return lax.custom_linear_solve(matvec, b, err)'),
             ('err', 'assert False')])
예제 #2
0
    def test_lax_custom_root(self):
        """Filtered tracebacks for a failing callback passed to lax.custom_root.

        The always-failing callback ``err`` is placed in each callable
        argument slot in turn:
        - the third slot (f1);
        - the fourth slot (f2);
        - the first slot, where ``g`` normally goes (f3).

        The third and fourth slots are presumably ``solve`` and
        ``tangent_solve``. Each case expects the calling frame followed by
        the ``err`` frame.
        """
        if not traceback_util.filtered_tracebacks_supported():
            raise unittest.SkipTest('Filtered tracebacks not supported')

        # Callback that always fails; the `return ()` is unreachable while
        # assertions are enabled.
        def err(*_):
            assert False
            return ()

        # Function whose root is sought; (x - 1)**2 has its root at x = 1.
        def g(x):
            return (x - 1.)**2.

        # Trivial solver stand-in that ignores its arguments and returns 1.
        def solve(*_):
            return 1.

        # NOTE: the `return` lines of f1/f2/f3 must stay textually identical to
        # the source-line strings passed to check_filtered_stack_trace below.
        def f1():
            return lax.custom_root(g, 0., err, solve)

        def f2():
            return lax.custom_root(g, 0., solve, err)

        def f3():
            return lax.custom_root(err, 0., solve, solve)

        # `check_filtered_stack_trace` is defined elsewhere; presumably it calls
        # the function, expects the given exception type, and compares the
        # filtered trace against these (function name, source line) frames.
        check_filtered_stack_trace(
            self, AssertionError, f1,
            [('f1', 'return lax.custom_root(g, 0., err, solve)'),
             ('err', 'assert False')])
        check_filtered_stack_trace(
            self, AssertionError, f2,
            [('f2', 'return lax.custom_root(g, 0., solve, err)'),
             ('err', 'assert False')])
        check_filtered_stack_trace(
            self, AssertionError, f3,
            [('f3', 'return lax.custom_root(err, 0., solve, solve)'),
             ('err', 'assert False')])
예제 #3
0
  def test_lax_cond(self):
    """Filtered traceback for a failing branch function passed to lax.cond."""
    if not traceback_util.filtered_tracebacks_supported():
      raise unittest.SkipTest('Filtered tracebacks not supported')

    # Branch function that always fails; the `return ()` is unreachable while
    # assertions are enabled.
    def err(_):
      assert False
      return ()

    # `err` is the second argument (presumably the branch for a True
    # predicate); the other branch is a no-op. The `return` line must stay
    # textually identical to the string in the expected-frames list below.
    def f():
      return lax.cond(True, err, lambda _: (), ())

    # Expected filtered trace: the caller `f`, then the failing `err`.
    # `check_filtered_stack_trace` is defined elsewhere in the file.
    check_filtered_stack_trace(self, AssertionError, f, [
        ('f', 'return lax.cond(True, err, lambda _: (), ())'),
        ('err', 'assert False')])
예제 #4
0
    def test_lax_scan(self):
        """Filtered traceback for a failing body function passed to lax.scan."""
        if not traceback_util.filtered_tracebacks_supported():
            raise unittest.SkipTest('Filtered tracebacks not supported')

        # Body function that always fails; the `return ()` is unreachable while
        # assertions are enabled.
        def err(*_):
            assert False
            return ()

        # Scans `err` with empty-tuple carry/inputs and an explicit length of 3
        # (argument roles assumed from lax.scan's signature). The `return` line
        # must stay textually identical to the expected-frames string below.
        def f():
            return lax.scan(err, (), (), 3)

        # Expected filtered trace: the caller `f`, then the failing `err`.
        # `check_filtered_stack_trace` is defined elsewhere in the file.
        check_filtered_stack_trace(self, AssertionError, f,
                                   [('f', 'return lax.scan(err, (), (), 3)'),
                                    ('err', 'assert False')])
예제 #5
0
    def test_lax_map(self):
        """Filtered traceback for a failing function passed to lax.map."""
        if not traceback_util.filtered_tracebacks_supported():
            raise unittest.SkipTest('Filtered tracebacks not supported')

        # Mapped function that always fails; the `return ()` is unreachable
        # while assertions are enabled.
        def err(_):
            assert False
            return ()

        # Maps `err` over a length-3 array of ones. The `return` line must stay
        # textually identical to the expected-frames string below.
        def f():
            xs = jnp.ones(3)
            return lax.map(err, xs)

        # Expected filtered trace: the caller `f`, then the failing `err`.
        # `check_filtered_stack_trace` is defined elsewhere in the file.
        check_filtered_stack_trace(self, AssertionError, f,
                                   [('f', 'return lax.map(err, xs)'),
                                    ('err', 'assert False')])
예제 #6
0
  def test_lax_associative_scan(self):
    """Filtered traceback for a failing combine function in lax.associative_scan."""
    if not traceback_util.filtered_tracebacks_supported():
      raise unittest.SkipTest('Filtered tracebacks not supported')

    # Combine function that always fails; the `return ()` is unreachable while
    # assertions are enabled.
    def err(*_):
      assert False
      return ()

    # Runs the scan over [0., 1., 2., 3.]. The `return` line must stay
    # textually identical to the expected-frames string below.
    def f():
      xs = jnp.arange(4.)
      return lax.associative_scan(err, xs)

    # Expected filtered trace: the caller `f`, then the failing `err`.
    # `check_filtered_stack_trace` is defined elsewhere in the file.
    check_filtered_stack_trace(self, AssertionError, f, [
        ('f', 'return lax.associative_scan(err, xs)'),
        ('err', 'assert False')])
예제 #7
0
  def test_lax_while_loop(self):
    """Filtered traceback for a failing body function passed to lax.while_loop."""
    if not traceback_util.filtered_tracebacks_supported():
      raise unittest.SkipTest('Filtered tracebacks not supported')

    # Body function that always fails; the `return ()` is unreachable while
    # assertions are enabled.
    def err(*_):
      assert False
      return ()

    # The predicate is always False, yet an AssertionError from `err` is
    # expected. Presumably while_loop traces the body even though it would
    # never run at runtime. The `return` line must stay textually identical
    # to the expected-frames string below.
    def f():
      pred = lambda _: False
      return lax.while_loop(pred, err, ())

    # Expected filtered trace: the caller `f`, then the failing `err`.
    # `check_filtered_stack_trace` is defined elsewhere in the file.
    check_filtered_stack_trace(self, AssertionError, f, [
        ('f', 'return lax.while_loop(pred, err, ())'),
        ('err', 'assert False')])
예제 #8
0
  def test_nested_jit_and_grad(self):
    """Filtered traceback through nested jit-decorated functions and grad.

    The expected exception is TypeError, and the expected frames stop at
    ``inbetween``. So the error is expected from ``grad(innermost)(x)``
    itself, before ``innermost``'s ``assert False`` is reached.
    NOTE(review): presumably grad rejects the integer-valued array argument
    ``jnp.array([1, 2])``; confirm against jax.grad's input checks.
    """
    if not traceback_util.filtered_tracebacks_supported():
      raise unittest.SkipTest('Filtered tracebacks not supported')

    @jit
    def innermost(x):
      assert False
    @jit
    def inbetween(x):
      return 1 + grad(innermost)(x)
    @jit
    def outermost(x):
      return 2 + inbetween(x)

    # Zero-argument entry point. Its frame name is '<lambda>', which is what
    # the expected-frames list checks.
    f = lambda: outermost(jnp.array([1, 2]))

    # Expected filtered trace: lambda -> outermost -> inbetween; the
    # jit/grad-internal frames are expected to be filtered out.
    # `check_filtered_stack_trace` is defined elsewhere in the file.
    check_filtered_stack_trace(self, TypeError, f, [
        ('<lambda>', 'f = lambda: outermost'),
        ('outermost', 'return 2 + inbetween(x)'),
        ('inbetween', 'return 1 + grad(innermost)(x)')])
예제 #9
0
  def test_cause_chain(self):
    """Filtered tracebacks keep the exception cause chain intact.

    ``outer`` catches the ValueError from ``inner`` and re-raises it as a
    TypeError ``from`` the original. The test checks three things:
    - the TypeError's filtered frames;
    - that its ``__cause__`` is the ValueError;
    - that the ValueError's own ``__cause__`` is a
      ``traceback_util.FilteredStackTrace``.
    """
    if not traceback_util.filtered_tracebacks_supported():
      raise unittest.SkipTest('Filtered tracebacks not supported')

    @jit
    def inner(x):
      raise ValueError('inner')
    @jit
    def outer(x):
      try:
        inner(x)
      except ValueError as e:
        # Explicit chaining: sets the TypeError's __cause__ to the ValueError.
        raise TypeError('outer') from e

    f = lambda: outer(1.)

    # `check_filtered_stack_trace` and `get_exception` are defined elsewhere;
    # presumably get_exception calls f and returns the raised TypeError.
    check_filtered_stack_trace(self, TypeError, f, [
        ('<lambda>', 'f = lambda: outer'),
        ('outer', 'raise TypeError')])
    e = get_exception(TypeError, f)
    self.assertIsInstance(e.__cause__, ValueError)
    # The filtered-trace marker is expected one level further down the chain.
    self.assertIsInstance(e.__cause__.__cause__,
                          traceback_util.FilteredStackTrace)
예제 #10
0
def skip_if_unsupported_filter_mode(filter_mode):
    """Raise unittest.SkipTest when ``filter_mode`` can't be exercised here.

    Args:
      filter_mode: traceback filtering mode under test. Two values have
        requirements:
        - ``"remove_frames"`` needs
          ``traceback_util.filtered_tracebacks_supported()`` to be true;
        - ``"tracebackhide"`` needs Python 3.7 or newer.
        Any other value is never skipped.

    Raises:
      unittest.SkipTest: if the requirement for ``filter_mode`` is not met.
    """
    if filter_mode == "remove_frames":
        if not traceback_util.filtered_tracebacks_supported():
            raise unittest.SkipTest('Filtered tracebacks not supported')
        return

    python_too_old = sys.version_info[:2] < (3, 7)
    if filter_mode == "tracebackhide" and python_too_old:
        raise unittest.SkipTest('Tracebackhide requires Python 3.7 or newer')