예제 #1
0
파일: jax2tf_test.py 프로젝트: alonfnt/jax
    def test_op_metadata_named(self):
        self.skipTest("include_xla_op_metadata not yet enabled")
        # Calling a jax.named_call
        # The user_frame is used to compute line numbers for ops in the test.
        user_frame = source_info_util.user_frame(source_info_util.current())

        def f_callee(x):
            return jnp.cos(x)

        def f_caller(x):
            y = jnp.tanh(x)
            z = jax.named_call(f_callee, name="callee")(y)
            return jnp.sin(z)

        x = np.ones((2, 3), np.float32)

        self.CheckOpMetadata(f_caller, x, [
            tf_test_util.OpMetadataGraph(tf_type="Tanh",
                                         source_file=__file__,
                                         source_line=user_frame.line_num + 4,
                                         op_name="jax2tf(f_caller)/tanh",
                                         op_type="tanh"),
            tf_test_util.OpMetadataGraph(
                tf_type="Cos",
                source_file=__file__,
                source_line=user_frame.line_num + 2,
                op_name="jax2tf(f_caller)/named(callee)/cos",
                op_type="cos"),
            tf_test_util.OpMetadataGraph(tf_type="Sin",
                                         source_file=__file__,
                                         source_line=user_frame.line_num + 6,
                                         op_name="jax2tf(f_caller)/sin",
                                         op_type="sin"),
        ])
예제 #2
0
파일: jax2tf_test.py 프로젝트: alonfnt/jax
    def test_op_metadata_batched_while(self):
        self.skipTest("include_xla_op_metadata not yet enabled")
        # An example with while and cond
        # The user_frame is used to compute line numbers for ops in the test.
        user_frame = source_info_util.user_frame(source_info_util.current())

        @jax.vmap
        def f_while(x):
            def body_fun(carry):
                new_carry = jnp.sin(carry)  # We look for "sin" in the graph
                return new_carry

            _, carry = lax.while_loop(
                lambda carry: jnp.all(carry <= x
                                      ),  # We look for "le" in the graph
                body_fun,
                x)
            return carry

        shape = (3, 2)
        x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

        jax_comp = jax.xla_computation(f_while)(x)
        backend = jax.lib.xla_bridge.get_backend()
        modules = backend.compile(jax_comp).hlo_modules()
        jax_opt_hlo = modules[0].to_string()
        print(f"JAX OPT HLO = {jax_opt_hlo}")

        self.CheckOpMetadata(f_while, x, [
            tf_test_util.OpMetadataGraph(
                tf_type="Sin",
                source_file=__file__,
                source_line=user_frame.line_num + 4,
                op_name="jax2tf(f_while)/while/body/sin",
                op_type="sin"),
            tf_test_util.OpMetadataGraph(
                tf_type="LessEqual",
                source_file=__file__,
                source_line=user_frame.line_num + 8,
                op_name="jax2tf(f_while)/while/body_pred/le",
                op_type="le"),
        ])
예제 #3
0
  def test_op_metadata_while_and_cond(self):
    self.skipTest("include_xla_op_metadata not yet enabled")
    # An example with while and cond
    # The user_frame is used to compute line numbers for ops in the test.
    user_frame = source_info_util.user_frame(source_info_util.current())
    def f_while_cond(x):
      def body_fun(i_acc):
        i, acc = i_acc
        return (i + 1,
                (jnp.cos(acc) +
                 lax.cond(jnp.mod(i, 2) == 0,
                          lambda acc: jnp.sin(acc),
                          lambda acc: acc,
                          acc)))

      _, acc = lax.while_loop(
          lambda i_acc: i_acc[0] <= 5,
          body_fun, (0, x))
      return acc

    x = np.ones((2, 3), np.float32)
    self.CheckOpMetadata(
        f_while_cond, x,
        [tf_test_util.OpMetadataGraph(tf_type="Cos",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 5,
                                      op_name="jax2tf(f_while_cond)/while/body/cos",
                                      op_type="cos"),
         tf_test_util.OpMetadataGraph(tf_type="Sin",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 7,
                                      op_name="jax2tf(f_while_cond)/while/body/branch_1_fun/sin",
                                      op_type="sin"),
         tf_test_util.OpMetadataGraph(tf_type="FloorMod",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 6,
                                      op_name="jax2tf(f_while_cond)/while/body/rem",
                                      op_type="rem"),
         ]
    )
예제 #4
0
파일: jax2tf_test.py 프로젝트: alonfnt/jax
    def test_op_metadata_simple(self):
        self.skipTest("include_xla_op_metadata not yet enabled")
        # A simple example
        # The user_frame is used to compute line numbers for ops in the test.
        user_frame = source_info_util.user_frame(source_info_util.current())

        def f_simple(x):
            return jnp.sin(x)

        x = np.ones((2, 3), np.float32)
        self.CheckOpMetadata(f_simple, x, [
            tf_test_util.OpMetadataGraph(tf_type="Sin",
                                         source_file=__file__,
                                         source_line=user_frame.line_num + 2,
                                         op_name="jax2tf(f_simple)/sin",
                                         op_type="sin")
        ])