Code example #1
File: tf.py  Project: wangdongya/trax
def tf_abstract_eval(f):
    """Returns a function that evaluates `f` given input shapes and dtypes.

  It transforms function `f` to a function that performs the same computation as
  `f` but only on shapes and dtypes (a.k.a. shape inference).

  Args:
    f: the function to be transformed.

  Returns:
    A function whose input arguments can be either the same as `f`'s or only
    their shapes/dtypes represented by `ShapeDtype`, and whose return values are
    `ShapeDtype`s with the same nested structure as `f`'s return values.
  """
    f_shape = tf_np_extensions.eval_on_shapes(f)

    def from_shape_type(x):
        if isinstance(x, ShapeDtype):
            return tf.TensorSpec(x.shape, x.dtype)
        else:
            return x

    def to_shape_type(x):  # pylint: disable=missing-docstring
        # TODO(wangpeng): handle partial output shapes using `tf.shape`.
        def to_numpy_shape(s):
            if s.is_fully_defined():
                return tuple(s.as_list())
            else:
                raise ValueError(
                    "The output shapes (%s) of the dry-run'ed function are"
                    ' not fully defined.' % s)

        def to_numpy_dtype(t):
            return np.dtype(t.as_numpy_dtype)

        if isinstance(x, tf.TensorSpec):
            return ShapeDtype(to_numpy_shape(x.shape), to_numpy_dtype(x.dtype))
        else:
            return x

    def f_return(*args):
        args = tf.nest.map_structure(from_shape_type, args)
        res = f_shape(*args)
        return tf.nest.map_structure(to_shape_type, res)

    return f_return
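Below is a minimal usage sketch of the function above; it is not part of tf.py. It assumes numpy is available as `np` and that trax's `ShapeDtype` (assumed importable from `trax.shapes`) is constructed as `ShapeDtype(shape, dtype)`.

import numpy as np
from trax.shapes import ShapeDtype  # assumed import path

def add(x, y):
  return x + y

abstract_add = tf_abstract_eval(add)

# Only shapes and dtypes flow through the call; no real tensors are computed.
out = abstract_add(ShapeDtype((8, 16), np.float32), ShapeDtype((16,), np.float32))
# `out` is a ShapeDtype with shape (8, 16) (via broadcasting) and dtype float32.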
Code example #2
File: extensions_test.py  Project: yliu45/trax
 def testScanShape(self, f, inputs, expected_outputs):
   outputs = extensions.eval_on_shapes(
       functools.partial(extensions.scan, f), static_argnums=(2,))(*inputs)
   self.assertAllEqual(expected_outputs, outputs)
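For reference, here is a hedged sketch of `eval_on_shapes` with `static_argnums`, separate from `scan` (not from the test file): `static_argnums` marks positional arguments that are passed as plain Python values rather than traced as tensors. The import path, and the assumption that concrete arrays are accepted as inputs, follow the other examples on this page.

import numpy as onp
from trax.tf_numpy import extensions  # assumed import path

def scale_sum(x, y, n):  # `n` is a plain Python int, treated as static
  return (x + y) * n

shape_fn = extensions.eval_on_shapes(scale_sum, static_argnums=(2,))
spec = shape_fn(onp.zeros((3, 4), onp.float32), onp.zeros((4,), onp.float32), 3)
# `spec` describes the output shape (3, 4) without running the arithmetic on real data.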
Code example #3
File: extensions_test.py  Project: yliu45/trax
 def f_prime(*args):
   res = extensions.eval_on_shapes(f, **kwargs)(*args)
   return tf.nest.map_structure(
       lambda x: tf_np.zeros(x.shape, x.dtype), res)
Code example #4
File: extensions_test.py  Project: zhaoqiuye/trax
 def f_prime(a, b):
     shape_dtype = extensions.eval_on_shapes(f)(a, b)
     return tf_np.zeros(shape=shape_dtype.shape,
                        dtype=shape_dtype.dtype)
Code example #5
File: extensions_test.py  Project: zhaoqiuye/trax
 def transformer(f):
     return extensions.eval_on_shapes(f)
Code example #6
 def f_prime(a, b):
   shape_dtype = extensions.eval_on_shapes(f)(a, b)
   return array_creation.zeros(shape=shape_dtype.shape,
                               dtype=shape_dtype.dtype)
Code example #7
File: test_util.py  Project: yangliuy/trax
  def _CompileAndCheck(self,
                       fun,
                       args_maker,
                       check_dtypes,
                       rtol=None,
                       atol=None,
                       check_eval_on_shapes=True,
                       check_incomplete_shape=False,
                       check_unknown_rank=True,
                       static_argnums=()):
    """Compiles the function and checks the results.

    Args:
      fun: the function to be checked.
      args_maker: a callable that returns a tuple which will be used as the
        positional arguments.
      check_dtypes: whether to check that the result dtypes from non-compiled
        and compiled runs agree.
      rtol: relative tolerance for allclose assertions.
      atol: absolute tolerance for allclose assertions.
      check_eval_on_shapes: whether to run `eval_on_shapes` on the function and
        check that the result shapes and dtypes are correct.
      check_incomplete_shape: whether to check that the function can handle
        incomplete shapes (including those with and without a known rank).
      check_unknown_rank: (only has effect when check_incomplete_shape is True)
        whether to check that the function can handle unknown ranks.
      static_argnums: indices of arguments to be treated as static arguments for
        `jit` and `eval_on_shapes`.
    """
    args = args_maker()

    for x in args:
      if not hasattr(x, 'dtype'):
        # If there is an input that doesn't have dtype info, jit and
        # eval_on_shapes may pick a different dtype for it than numpy, so we
        # skip the dtype check.
        check_dtypes = False

    # `wrapped_fun` and `python_should_be_executing` are used to check that when
    # the jitted function is called the second time, the original Python
    # function won't be executed.
    def wrapped_fun(*args):
      self.assertTrue(python_should_be_executing)
      return fun(*args)

    python_ans = fun(*args)

    python_shapes = tf.nest.map_structure(lambda x: onp.shape(x), python_ans)
    onp_shapes = tf.nest.map_structure(lambda x: onp.shape(onp.asarray(x)),
                                       python_ans)
    self.assertEqual(python_shapes, onp_shapes)

    cfun = npe.jit(wrapped_fun, static_argnums=static_argnums)
    python_should_be_executing = True
    monitored_ans = cfun(*args)

    python_should_be_executing = False
    compiled_ans = cfun(*args)

    self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol)
    self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)

    # Run `cfun` with a different set of arguments to check that changing
    # arguments won't cause recompilation.

    new_args = args_maker()

    skip_retracing_test = False
    for old, new in zip(args, new_args):
      if npe.most_precise_int_dtype(old) != npe.most_precise_int_dtype(new):
        # If the old and new arguments result in different dtypes (because they
        # fall into different value ranges), tf-numpy will retrace, so we skip
        # the no-retrace test.
        skip_retracing_test = True

    if not skip_retracing_test:
      python_should_be_executing = True
      new_python_ans = fun(*new_args)
      python_should_be_executing = False
      compiled_ans = cfun(*new_args)
      self.assertAllClose(new_python_ans, compiled_ans, check_dtypes, atol,
                          rtol)

    if check_eval_on_shapes:
      # Check that npe.eval_on_shapes can get complete output shapes given
      # complete input shapes.
      cfun = npe.eval_on_shapes(fun, static_argnums=static_argnums)
      compiled_ans = cfun(*args)
      flat_python_ans = tf.nest.flatten(python_ans)
      flat_compiled_ans = tf.nest.flatten(compiled_ans)
      self.assertEqual(len(flat_python_ans), len(flat_compiled_ans))
      for a, b in zip(flat_python_ans, flat_compiled_ans):
        if hasattr(a, 'shape'):
          self.assertEqual(a.shape, b.shape)
        if check_dtypes and hasattr(a, 'dtype'):
          self.assertEqual(tf.as_dtype(a.dtype), b.dtype)

    # If some argument doesn't have a `dtype` attr (e.g. a Python scalar), we
    # skip incomplete-shape checks, since shape specs need dtype. It's OK to
    # skip since the same incomplete-shape checks will run for []-shaped arrays.
    if check_incomplete_shape and all(hasattr(x, 'dtype') for x in args):
      # Check partial shapes with known ranks.
      # Numpy scalars (created by e.g. np.int32(5)) have `dtype` but not
      # `shape`.
      if all(hasattr(x, 'shape') for x in args):
        specs = [tf.TensorSpec([None] * len(x.shape), x.dtype) for x in args]
        cfun = npe.jit(
            fun, static_argnums=static_argnums, input_signature=specs)
        compiled_ans = cfun(*args)
        self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)

      if check_unknown_rank:
        # Check unknown ranks.
        specs = [tf.TensorSpec(None, x.dtype) for x in args]
        cfun = npe.jit(
            fun, static_argnums=static_argnums, input_signature=specs)
        compiled_ans = cfun(*args)
        self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)
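Finally, a standalone sketch (not from test_util.py) of the two central checks this helper performs: `eval_on_shapes` on complete input shapes, and `jit` with an `input_signature` of partially-unknown shapes. The aliases mirror the helper above (`npe` for the tf-numpy extensions module, `onp` for plain numpy); the import paths are assumptions.

import numpy as onp
import tensorflow.compat.v2 as tf
from trax.tf_numpy import extensions as npe  # assumed import path

def fun(x, y):
  return x * 2.0 + y

args = (onp.ones((2, 3), onp.float32), onp.ones((3,), onp.float32))

# eval_on_shapes: complete input shapes should yield complete output shapes.
spec = npe.eval_on_shapes(fun)(*args)
# `spec` is a shape/dtype spec for the output; its shape should be (2, 3).

# Incomplete shapes: jit with an input_signature whose dimensions are unknown.
specs = [tf.TensorSpec([None] * x.ndim, x.dtype) for x in args]
out = npe.jit(fun, input_signature=specs)(*args)
# `out` should match fun(*args) numerically, with shape (2, 3).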