コード例 #1
0
ファイル: tf_ops_test.py プロジェクト: samuela/jax
    def test_while_batched(self, with_function=True):
        """A vmapped while loop whose carry is an (index, accumulator) pair.

        `product` computes x * y by repeated addition; vmapping it over both
        arguments batches the loop predicate and body. The JAX result and the
        converted TF result are compared component by component.
        """
        with jax_to_tf.enable_jit():

            def product(x, y):
                # Equivalent to "x * y" implemented as:
                #      res = 0.
                #      for(i=0; i < y; i++)
                #         res += x
                # The while loop returns the final (i, res) carry.
                return lax.while_loop(
                    lambda idx_carry: idx_carry[0] < y, lambda idx_carry:
                    (idx_carry[0] + 1, idx_carry[1] + x), (0, 0.))

            # We use vmap to compute result[j, i] = ys[j] * xs[i]
            # (the outer vmap maps over ys, the inner one over xs).
            xs = np.arange(4, dtype=np.int32)
            ys = np.arange(5, dtype=np.int32)

            def product_xs_y(xs, y):
                return jax.vmap(product, in_axes=(0, None))(xs, y)

            def product_xs_ys(xs, ys):
                return jax.vmap(product_xs_y, in_axes=(None, 0))(xs, ys)

            f_jax = product_xs_ys
            f_tf = jax_to_tf.convert(f_jax)
            if with_function:
                f_tf = tf.function(f_tf)
            res_jax = f_jax(xs, ys)
            res_tf = f_tf(xs, ys)
            # zip() would silently drop unmatched components, so first check
            # that both sides return the same number of carry components.
            self.assertEqual(len(res_jax), len(res_tf))
            for r_tf, r_jax in zip(res_tf, res_jax):
                np.testing.assert_allclose(r_tf, r_jax)
コード例 #2
0
ファイル: control_flow_ops_test.py プロジェクト: stilling/jax
  def test_while(self, with_function=False):
    """A while loop whose cond and body functions capture outer constants."""
    # Constants captured by the while loop's cond and body functions.
    cond_const = np.ones(3, dtype=np.float32)
    body_const1 = np.full_like(cond_const, 1.)
    body_const2 = np.full_like(cond_const, 2.)

    def func(x):
      # Equivalent to:
      #      c = x            (the test below passes x = [1, 1, 1])
      #      for(i=0; i < 3; i++)
      #        c += [1, 1, 1] + [2, 2, 2]
      #
      # The function is set-up so that it captures constants in the
      # body of the functionals. This covers some cases in the representation
      # of the lax.while primitive.
      def cond(idx_carry):
        i, _ = idx_carry
        # Loop bound is sum(cond_const) == 3; tie_in gives the captured
        # constant a data dependence on the loop index i.
        return i < jnp.sum(lax.tie_in(i, cond_const))  # Capture cond_const

      def body(idx_carry):
        i, c = idx_carry
        return (i + 1, c + body_const1 + body_const2)

      return lax.while_loop(cond, body, (0, x))

    with jax_to_tf.enable_jit():
      self.ConvertAndCompare(func, cond_const, with_function=with_function)
コード例 #3
0
ファイル: control_flow_ops_test.py プロジェクト: stilling/jax
  def test_cond_multiple_results(self, with_function=False):
    """lax.cond whose branches each return a (value, constant) pair."""
    def true_branch(t):
      return t + 1., 1.

    def false_branch(f):
      return f + 2., 2.

    def f_jax(pred, x):
      return lax.cond(pred, true_branch, false_branch, x)

    with jax_to_tf.enable_jit():
      # Exercise both branches, True first.
      for pred in (True, False):
        self.ConvertAndCompare(f_jax, pred, 1., with_function=with_function)
コード例 #4
0
ファイル: control_flow_ops_test.py プロジェクト: stilling/jax
  def test_cond(self, with_function=False):
    """lax.cond with a single-value result, checked for both predicates."""
    def add_one(t):
      return t + 1.

    def identity(f):
      return f

    def f_jax(pred, x):
      return lax.cond(pred, add_one, identity, x)

    with jax_to_tf.enable_jit():
      # Exercise both branches, True first.
      for pred in (True, False):
        self.ConvertAndCompare(f_jax, pred, 1., with_function=with_function)
コード例 #5
0
ファイル: control_flow_ops_test.py プロジェクト: stilling/jax
  def test_while_single_carry(self, with_function=False):
    """A while loop whose carry is a single scalar."""
    def keep_going(counter):
      return counter < 4

    def step(counter):
      return counter + 1

    def func(x):
      # Same as the C loop:  for(i=x; i < 4; i++);
      return lax.while_loop(keep_going, step, x)

    with jax_to_tf.enable_jit():
      self.ConvertAndCompare(func, 0, with_function=with_function)
コード例 #6
0
ファイル: tf_ops_test.py プロジェクト: samuela/jax
    def test_cond(self, with_function=False):
        """lax.cond with a single-value result, checked for both predicates."""
        with jax_to_tf.enable_jit():

            def f_jax(pred, x):
                return lax.cond(pred, lambda t: t + 1., lambda f: f, x)

            converted = jax_to_tf.convert(f_jax)
            tf_fn = tf.function(converted) if with_function else converted
            # Exercise both branches, True first; TF result is evaluated
            # before the JAX reference for each predicate.
            for pred in (True, False):
                np.testing.assert_allclose(tf_fn(pred, 1.), f_jax(pred, 1.))
コード例 #7
0
ファイル: control_flow_ops_test.py プロジェクト: stilling/jax
  def test_scan(self, with_function=False):
    """lax.scan over two sequences with a body that captures a constant."""
    def f_jax(xs, ys):
      # Captured by `body` below, to exercise constant capture in scan.
      body_const = np.ones((2, ), dtype=np.float32)

      def body(carry, xy):
        x, y = xy
        new_carry = carry + x * y
        return new_carry, body_const

      return lax.scan(body, 0., (xs, ys))

    arg = np.arange(10, dtype=np.float32)
    with jax_to_tf.enable_jit():
      self.ConvertAndCompare(f_jax, arg, arg, with_function=with_function)
コード例 #8
0
ファイル: tf_ops_test.py プロジェクト: samuela/jax
    def test_while_single_carry(self, with_function=False):
        """A while loop whose carry is a single scalar."""
        with jax_to_tf.enable_jit():

            def func(x):
                # Same as the C loop:  for(i=x; i < 4; i++);
                def keep_going(counter):
                    return counter < 4

                def step(counter):
                    return counter + 1

                return lax.while_loop(keep_going, step, x)

            tf_fn = jax_to_tf.convert(func)
            if with_function:
                tf_fn = tf.function(tf_fn)
            expected = func(0)
            actual = tf_fn(0)
            np.testing.assert_allclose(expected, actual)
コード例 #9
0
ファイル: control_flow_ops_test.py プロジェクト: stilling/jax
  def test_while_batched(self, with_function=True):
    """A vmapped while loop whose carry is an (index, accumulator) pair.

    `product` computes x * y by repeated addition; vmapping it over both
    arguments batches the loop predicate and body.
    """
    def product(x, y):
      # Equivalent to "x * y" implemented as:
      #      res = 0.
      #      for(i=0; i < y; i++)
      #         res += x
      # The while loop returns the final (i, res) carry.
      return lax.while_loop(lambda idx_carry: idx_carry[0] < y,
                            lambda idx_carry: (idx_carry[0] + 1,
                                               idx_carry[1] + x),
                            (0, 0.))

    # We use vmap to compute result[j, i] = ys[j] * xs[i]
    # (the outer vmap maps over ys, the inner one over xs).
    xs = np.arange(4, dtype=np.int32)
    ys = np.arange(5, dtype=np.int32)

    def product_xs_y(xs, y):
      return jax.vmap(product, in_axes=(0, None))(xs, y)
    def product_xs_ys(xs, ys):
      return jax.vmap(product_xs_y, in_axes=(None, 0))(xs, ys)

    with jax_to_tf.enable_jit():
      self.ConvertAndCompare(product_xs_ys, xs, ys, with_function=with_function)
コード例 #10
0
ファイル: tf_ops_test.py プロジェクト: samuela/jax
    def test_while(self, with_function=False):
        """A while loop whose cond and body functions capture outer constants.

        Runs the JAX function and its TF conversion on the same input and
        compares the final (index, carry) pair component by component.
        """
        with jax_to_tf.enable_jit():
            # Constants captured by the while loop's cond and body functions.
            cond_const = np.ones(3, dtype=np.float32)
            body_const1 = np.full_like(cond_const, 1.)
            body_const2 = np.full_like(cond_const, 2.)

            def func(x):
                # Equivalent to:
                #      c = x            (the test below passes x = [1, 1, 1])
                #      for(i=0; i < 3; i++)
                #        c += [1, 1, 1] + [2, 2, 2]
                #
                # The function is set-up so that it captures constants in the
                # body of the functionals. This covers some cases in the representation
                # of the lax.while primitive.
                def cond(idx_carry):
                    i, _ = idx_carry
                    # Loop bound is sum(cond_const) == 3; tie_in gives the
                    # captured constant a data dependence on the loop index.
                    return i < jnp.sum(lax.tie_in(
                        i, cond_const))  # Capture cond_const

                def body(idx_carry):
                    i, c = idx_carry
                    return (i + 1, c + body_const1 + body_const2)

                return lax.while_loop(cond, body, (0, x))

            f_jax = func
            f_tf = jax_to_tf.convert(f_jax)
            if with_function:
                f_tf = tf.function(f_tf)
            # Named `arg` rather than `input` to avoid shadowing the builtin.
            arg = cond_const
            res_jax = f_jax(arg)
            res_tf = f_tf(arg)
            # zip() would silently drop unmatched components, so first check
            # that both sides return the same number of carry components.
            self.assertEqual(len(res_jax), len(res_tf))
            for r_jax, r_tf in zip(res_jax, res_tf):
                np.testing.assert_allclose(r_jax, r_tf)