def tf_abstract_eval(f):
  """Builds a shape/dtype-only ("abstract") evaluator for `f`.

  The returned callable performs the same computation as `f`, but only on
  shapes and dtypes (i.e. shape inference). It delegates to
  `tf_np_extensions.eval_on_shapes`.

  Args:
    f: the function to be transformed.

  Returns:
    A callable with the following behavior:
    - Arguments: the same as `f`'s, except that any argument may instead be a
      `ShapeDtype` describing only its shape and dtype.
    - Return value: the result of the dry run, with every `tf.TensorSpec` leaf
      converted to a `ShapeDtype` and the nested structure of `f`'s return
      value preserved. Non-`tf.TensorSpec` leaves are returned unchanged.
    - Errors: it raises `ValueError` when an output shape of the dry run is
      not fully defined.
  """
  shape_fn = tf_np_extensions.eval_on_shapes(f)

  def _as_input_spec(leaf):
    # Only `ShapeDtype` placeholders are converted; real values pass through.
    if not isinstance(leaf, ShapeDtype):
      return leaf
    return tf.TensorSpec(leaf.shape, leaf.dtype)

  def _as_shape_dtype(leaf):
    # TODO(wangpeng): handle partial output shapes using `tf.shape`.
    if not isinstance(leaf, tf.TensorSpec):
      return leaf
    out_shape = leaf.shape
    if not out_shape.is_fully_defined():
      raise ValueError(
          "The output shapes (%s) of the dry-run'ed function are"
          ' not fully defined.' % out_shape)
    return ShapeDtype(tuple(out_shape.as_list()),
                      np.dtype(leaf.dtype.as_numpy_dtype))

  def shape_only_f(*args):
    input_specs = tf.nest.map_structure(_as_input_spec, args)
    dry_run_result = shape_fn(*input_specs)
    return tf.nest.map_structure(_as_shape_dtype, dry_run_result)

  return shape_only_f
def testScanShape(self, f, inputs, expected_outputs):
  """Checks the shape-only evaluation of `extensions.scan` with body `f`.

  `scan` is partially applied to `f`. The result is passed through
  `extensions.eval_on_shapes`, with the third positional input treated as
  static. The output is then compared against `expected_outputs`.
  """
  scan_over_f = functools.partial(extensions.scan, f)
  shape_evaluator = extensions.eval_on_shapes(scan_over_f,
                                              static_argnums=(2,))
  actual_outputs = shape_evaluator(*inputs)
  self.assertAllEqual(expected_outputs, actual_outputs)
def f_prime(*args):
  """Returns zero arrays shaped like the outputs `f` would produce on `args`.

  `f` and `kwargs` come from the enclosing scope; `kwargs` is forwarded to
  `extensions.eval_on_shapes`. The nested output structure is preserved.
  """
  def _zeros_for(spec):
    return tf_np.zeros(spec.shape, spec.dtype)

  abstract_outputs = extensions.eval_on_shapes(f, **kwargs)(*args)
  return tf.nest.map_structure(_zeros_for, abstract_outputs)
def f_prime(a, b):
  """Returns a `tf_np` zero array matching the shape/dtype of `f(a, b)`.

  `f` comes from the enclosing scope. It is evaluated only on shapes via
  `extensions.eval_on_shapes`.
  """
  abstract_f = extensions.eval_on_shapes(f)
  out_spec = abstract_f(a, b)
  return tf_np.zeros(shape=out_spec.shape, dtype=out_spec.dtype)
def transformer(f):
  """Wraps `f` with `extensions.eval_on_shapes` using its default options."""
  shape_only_f = extensions.eval_on_shapes(f)
  return shape_only_f
def f_prime(a, b):
  """Returns an `array_creation` zero array matching the shape/dtype of `f(a, b)`.

  `f` comes from the enclosing scope. It is evaluated only on shapes via
  `extensions.eval_on_shapes`.
  """
  abstract_f = extensions.eval_on_shapes(f)
  out_spec = abstract_f(a, b)
  return array_creation.zeros(shape=out_spec.shape, dtype=out_spec.dtype)
def _CompileAndCheck(self, fun, args_maker, check_dtypes, rtol=None,
                     atol=None, check_eval_on_shapes=True,
                     check_incomplete_shape=False, check_unknown_rank=True,
                     static_argnums=()):
  """Compiles the function and checks the results.

  Checks performed, in order:
    1. `fun(*args)` is run directly. The shape of each output leaf must equal
       the shape of `onp.asarray` applied to that leaf.
    2. `fun` is compiled with `npe.jit`. Both the first (Python-executing)
       call and a second call must match the direct result. The second call
       must not re-execute the Python function.
    3. The compiled function is called on a fresh `args_maker()` result. That
       call must not re-execute the Python function either. This check is
       skipped when `npe.most_precise_int_dtype` differs between any old and
       new argument.
    4. Optional (`check_eval_on_shapes`): `npe.eval_on_shapes` output shapes
       (and dtypes, if checked) must match the direct result.
    5. Optional (`check_incomplete_shape`): `fun` is re-jitted with an
       `input_signature` of known-rank/unknown-dim specs and, optionally,
       unknown-rank specs. The results must match the direct result.

  Args:
    fun: the function to be checked.
    args_maker: a callable that returns a tuple which will be used as the
      positional arguments.
    check_dtypes: whether to check that the result dtypes from non-compiled
      and compiled runs agree. It is forced to False if any argument lacks a
      `dtype` attribute.
    rtol: relative tolerance for allclose assertions.
    atol: absolute tolerance for allclose assertions.
    check_eval_on_shapes: whether to run `eval_on_shapes` on the function and
      check that the result shapes and dtypes are correct.
    check_incomplete_shape: whether to check that the function can handle
      incomplete shapes (including those with and without a known rank).
    check_unknown_rank: (only has effect when check_incomplete_shape is True)
      whether to check that the function can handle unknown ranks.
    static_argnums: indices of arguments to be treated as static arguments
      for `jit` and `eval_on_shapes`.
  """
  args = args_maker()

  for x in args:
    if not hasattr(x, 'dtype'):
      # If there is a input that doesn't have dtype info, jit and
      # eval_on_shapes may pick a different dtype for it than numpy, so we
      # skip the dtype check.
      check_dtypes = False

  # `wrapped_fun` and `python_should_be_executing` are used to check that when
  # the jitted function is called the second time, the original Python
  # function won't be executed.
  # The closure reads `python_should_be_executing` at call time, so the
  # assignments below (made after this definition) control the check.
  def wrapped_fun(*args):
    self.assertTrue(python_should_be_executing)
    return fun(*args)

  # Reference result from running the Python function directly.
  python_ans = fun(*args)

  # Each output leaf's shape must match the shape numpy infers for it.
  python_shapes = tf.nest.map_structure(lambda x: onp.shape(x), python_ans)
  onp_shapes = tf.nest.map_structure(lambda x: onp.shape(onp.asarray(x)),
                                     python_ans)
  self.assertEqual(python_shapes, onp_shapes)

  cfun = npe.jit(wrapped_fun, static_argnums=static_argnums)
  # First call: the Python body is allowed to run.
  python_should_be_executing = True
  monitored_ans = cfun(*args)

  # Second call with the same args: any execution of `wrapped_fun` now fails
  # its `assertTrue`.
  python_should_be_executing = False
  compiled_ans = cfun(*args)

  # NOTE(review): atol and rtol are passed positionally as (atol, rtol);
  # confirm this matches assertAllClose's parameter order.
  self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol)
  self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)

  # Run `cfun` with a different set of arguments to check that changing
  # arguments won't cause recompilation.
  new_args = args_maker()
  skip_retracing_test = False
  for old, new in zip(args, new_args):
    if npe.most_precise_int_dtype(old) != npe.most_precise_int_dtype(new):
      # If the old and new arguments result in different dtypes (because they
      # fall into different value ranges), tf-numpy will retrace, so we skip
      # the no-retrace test.
      skip_retracing_test = True

  if not skip_retracing_test:
    python_should_be_executing = True
    # `fun` is the unwrapped function, so this reference call is not gated
    # by the flag.
    new_python_ans = fun(*new_args)
    python_should_be_executing = False
    # With the flag False, this call fails if `wrapped_fun` runs again.
    compiled_ans = cfun(*new_args)
    self.assertAllClose(new_python_ans, compiled_ans, check_dtypes, atol,
                        rtol)

  if check_eval_on_shapes:
    # Check that npe.eval_on_shapes can get complete output shapes given
    # complete input shapes.
    cfun = npe.eval_on_shapes(fun, static_argnums=static_argnums)
    compiled_ans = cfun(*args)
    flat_python_ans = tf.nest.flatten(python_ans)
    flat_compiled_ans = tf.nest.flatten(compiled_ans)
    self.assertEqual(len(flat_python_ans), len(flat_compiled_ans))
    for a, b in zip(flat_python_ans, flat_compiled_ans):
      if hasattr(a, 'shape'):
        self.assertEqual(a.shape, b.shape)
      if check_dtypes and hasattr(a, 'dtype'):
        self.assertEqual(tf.as_dtype(a.dtype), b.dtype)

  # If some argument doesn't have a `dtype` attr (e.g. a Python scalar), we
  # skip incomplete-shape checks, since shape specs need dtype. It's OK to
  # skip since the same incomplete-shape checks will run for []-shaped arrays.
  if check_incomplete_shape and all(hasattr(x, 'dtype') for x in args):
    # Check partial shapes with known ranks.
    # Numpy scalars (created by e.g. np.int32(5)) have `dtype` but not
    # `shape`.
    if all(hasattr(x, 'shape') for x in args):
      # Every dimension is unknown (None); only the rank is kept.
      specs = [tf.TensorSpec([None] * len(x.shape), x.dtype) for x in args]
      cfun = npe.jit(
          fun, static_argnums=static_argnums, input_signature=specs)
      compiled_ans = cfun(*args)
      self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)
    if check_unknown_rank:
      # Check unknown ranks.
      # A `None` shape leaves the rank unknown; only the dtype is needed, so
      # this check does not require a `shape` attribute.
      specs = [tf.TensorSpec(None, x.dtype) for x in args]
      cfun = npe.jit(
          fun, static_argnums=static_argnums, input_signature=specs)
      compiled_ans = cfun(*args)
      self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol)