Example #1
        def check_compile(**kwargs):
            """Jit-compiles `fun` and checks its results and tracing behavior.

            Runs the compiled function twice on `args`, asserting that both
            results match `python_ans` and that the second call does not
            re-execute the Python function. Then runs it on fresh arguments from
            `args_maker()` and, unless the new arguments would legitimately
            cause a retrace, asserts that no retrace happens.

            Relies on names from the enclosing scope: `self`, `fun`, `args`,
            `args_maker`, `python_ans`, `check_dtypes`, `atol`, `rtol` and
            `static_argnums`.

            Args:
              **kwargs: extra keyword arguments forwarded to `npe.jit`.
            """
            # `wrapped_fun` and `python_should_be_executing` are used to check
            # that when the jitted function is called the second time, the
            # original Python function won't be executed. The assertion below
            # is what enforces this; without it the flag is never read.
            def wrapped_fun(*args):
                self.assertTrue(python_should_be_executing)
                return fun(*args)

            # NOTE(review): this call was truncated in the source; forwarding
            # `static_argnums` and `**kwargs` is reconstructed -- confirm that
            # the enclosing method defines `static_argnums`.
            cfun = npe.jit(
                wrapped_fun, static_argnums=static_argnums, **kwargs)

            # The first call traces (running the Python function); the second
            # call with identical arguments must reuse that trace.
            python_should_be_executing = True
            monitored_ans = cfun(*args)

            python_should_be_executing = False
            compiled_ans = cfun(*args)

            self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol,
                                rtol)
            self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol,
                                rtol)

            # Run `cfun` with a different set of arguments to check that
            # changing arguments won't cause recompilation.
            new_args = args_maker()

            # If the old and new arguments result in different dtypes (because
            # they fall into different value ranges), tf-numpy will retrace, so
            # we skip the no-retrace test.
            skip_retracing_test = any(
                npe.most_precise_int_dtype(old) !=
                npe.most_precise_int_dtype(new)
                for old, new in zip(tf.nest.flatten(args),
                                    tf.nest.flatten(new_args)))

            if not skip_retracing_test:
                python_should_be_executing = True
                new_python_ans = fun(*new_args)
                python_should_be_executing = False
                # A retrace here would run `wrapped_fun` with the flag False
                # and fail its assertion.
                compiled_ans = cfun(*new_args)
                self.assertAllClose(new_python_ans, compiled_ans, check_dtypes,
                                    atol, rtol)
Example #2
  def _CompileAndCheck(self,
                       fun,
                       args_maker,
                       # NOTE(review): the original signature was truncated;
                       # the defaults below are reconstructed -- confirm them
                       # against existing callers.
                       check_dtypes=True,
                       rtol=None,
                       atol=None,
                       check_eval_on_shapes=True,
                       check_incomplete_shape=True,
                       check_unknown_rank=True,
                       static_argnums=()):
    """Compiles the function and checks the results.

    Args:
      fun: the function to be checked.
      args_maker: a callable that returns a tuple which will be used as the
        positional arguments.
      check_dtypes: whether to check that the result dtypes from non-compiled
        and compiled runs agree. Forced to False when any argument has no
        `dtype` attribute.
      rtol: relative tolerance for allclose assertions.
      atol: absolute tolerance for allclose assertions.
      check_eval_on_shapes: whether to run `eval_on_shapes` on the function and
        check that the result shapes and dtypes are correct.
      check_incomplete_shape: whether to check that the function can handle
        incomplete shapes (including those with and without a known rank).
      check_unknown_rank: (only has effect when check_incomplete_shape is True)
        whether to check that the function can handle unknown ranks.
      static_argnums: indices of arguments to be treated as static arguments for
        `jit` and `eval_on_shapes`.
    """
    args = args_maker()

    # Computed once; reused below to gate the incomplete-shape checks.
    all_have_dtype = all(hasattr(x, 'dtype') for x in args)
    if not all_have_dtype:
      # If there is an input that doesn't have dtype info, jit and
      # eval_on_shapes may pick a different dtype for it than numpy, so we
      # skip the dtype check.
      check_dtypes = False

    # `wrapped_fun` and `python_should_be_executing` are used to check that when
    # the jitted function is called the second time, the original Python
    # function won't be executed. The assertion below is what enforces this;
    # without it the flag is never read.
    def wrapped_fun(*args):
      self.assertTrue(python_should_be_executing)
      return fun(*args)

    python_ans = fun(*args)

    # Every leaf of the reference answer must report the same shape as its
    # numpy-array conversion.
    python_shapes = tf.nest.map_structure(onp.shape, python_ans)
    onp_shapes = tf.nest.map_structure(lambda x: onp.shape(onp.asarray(x)),
                                       python_ans)
    self.assertEqual(python_shapes, onp_shapes)

    cfun = npe.jit(wrapped_fun, static_argnums=static_argnums)
    # The first call traces (running the Python function); the second call with
    # identical arguments must reuse that trace.
    python_should_be_executing = True
    monitored_ans = cfun(*args)

    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol)
    self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)

    # Run `cfun` with a different set of arguments to check that changing
    # arguments won't cause recompilation.
    new_args = args_maker()

    # If the old and new arguments result in different dtypes (because they
    # fall into different value ranges), tf-numpy will retrace, so we skip the
    # no-retrace test. Arguments are flattened so nested structures are
    # compared leaf by leaf.
    skip_retracing_test = any(
        npe.most_precise_int_dtype(old) != npe.most_precise_int_dtype(new)
        for old, new in zip(tf.nest.flatten(args), tf.nest.flatten(new_args)))

    if not skip_retracing_test:
      python_should_be_executing = True
      new_python_ans = fun(*new_args)
      python_should_be_executing = False
      # A retrace here would run `wrapped_fun` with the flag False and fail its
      # assertion.
      compiled_ans = cfun(*new_args)
      self.assertAllClose(new_python_ans, compiled_ans, check_dtypes, atol,
                          rtol)

    if check_eval_on_shapes:
      # Check that npe.eval_on_shapes can get complete output shapes given
      # complete input shapes.
      cfun = npe.eval_on_shapes(fun, static_argnums=static_argnums)
      compiled_ans = cfun(*args)
      flat_python_ans = tf.nest.flatten(python_ans)
      flat_compiled_ans = tf.nest.flatten(compiled_ans)
      self.assertEqual(len(flat_python_ans), len(flat_compiled_ans))
      for a, b in zip(flat_python_ans, flat_compiled_ans):
        if hasattr(a, 'shape'):
          self.assertEqual(a.shape, b.shape)
        if check_dtypes and hasattr(a, 'dtype'):
          self.assertEqual(tf.as_dtype(a.dtype), b.dtype)

    # If some argument doesn't have a `dtype` attr (e.g. a Python scalar), we
    # skip incomplete-shape checks, since shape specs need dtype. It's OK to
    # skip since the same incomplete-shape checks will run for []-shaped arrays.
    if check_incomplete_shape and all_have_dtype:
      # Check partial shapes with known ranks. Building a rank-preserving spec
      # requires every argument to expose `shape`.
      if all(hasattr(x, 'shape') for x in args):
        specs = [tf.TensorSpec([None] * len(x.shape), x.dtype) for x in args]
        cfun = npe.jit(
            fun, static_argnums=static_argnums, input_signature=specs)
        compiled_ans = cfun(*args)
        self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)

      if check_unknown_rank:
        # Check unknown ranks.
        specs = [tf.TensorSpec(None, x.dtype) for x in args]
        cfun = npe.jit(
            fun, static_argnums=static_argnums, input_signature=specs)
        compiled_ans = cfun(*args)
        self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)