Example #1
  def testReadVariableInsideFunction(self, distribution, run_functions_eagerly):

    def_function.run_functions_eagerly(run_functions_eagerly)

    # Get devices on which variables will be placed. Default strategy does not
    # define this, so assume cpu:0 in that case.
    try:
      devices = distribution.extended.parameter_devices
    except RuntimeError:
      devices = ["cpu:0"]

    with distribution.scope():
      v = variables.Variable(0.)
      if isinstance(v, values.DistributedVariable):
        for i in range(len(devices)):
          # NOTE: Assigning manually to component variables so we can test
          # different values on different devices. Using .assign on the
          # mirrored variable itself will lead to a synchronization which
          # will prohibit testing different values.
          replica_variable = v._values[i]
          replica_variable.assign(math_ops.cast(i, dtypes.float32))

    @def_function.function
    def read():
      return v.read_value()

    # Verify that the value from each device is read when in that device's
    # scope. Doing this inside the strategy scope is needed to force function
    # retracing on each device; otherwise `read()` is only traced once, on the
    # first device, and subsequent variable reads always return the value from
    # the first replica.
    with distribution.scope():
      for i, d in enumerate(devices):
        with ops.device(d):
          self.assertEqual(math_ops.cast(i, dtypes.float32), read())
Example #2
    def testDefaultDeviceInsideNestedFunctionWithScope(self, distribution,
                                                       run_functions_eagerly):

        def_function.run_functions_eagerly(run_functions_eagerly)
        try:
            worker = distribution.extended.worker_devices[0]
        except RuntimeError:
            worker = None
        expected_device = (device_util.canonicalize("cpu:0", worker)
                           if run_functions_eagerly else "")
        with distribution.scope():

            @def_function.function
            def foo():
                with ops.device("cpu:0"):

                    @def_function.function
                    def bar():
                        one = array_ops.ones([])
                        self.assertEqual(expected_device, one.device)
                        return one + 1

                    bar()

            foo()
Example #3
  def _benchmark_matmul_forward_backward_2_by_2_CPU(self, run_eager=False):
    def_function.run_functions_eagerly(run_eager)
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_defun_matmul_forward_backward(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)
    def_function.run_functions_eagerly(False)
Example #4
  def testRunFunctionsEagerly(self):
    try:
      original_setting = def_function.functions_run_eagerly()
      def_function.run_functions_eagerly(True)
      x = constant_op.constant(1.)
      with forwardprop.ForwardAccumulator(x, 2.) as acc:
        y = x * 3.
      self.assertAllClose(6., acc.jvp(y))
    finally:
      def_function.run_functions_eagerly(original_setting)
Example #5
    def test_cond(self, run_functions_eagerly):
        try:
            def_function.run_functions_eagerly(run_functions_eagerly)
            with self.device:
                pred = self.device.pack([True, False])
                capture = self.device.pack([[1.], [2.]])
                result = control_flow_ops.cond(
                    pred, def_function.function(lambda: capture * 2.),
                    def_function.function(lambda: capture * 4.))
            self.assertAllClose([[2.], [8.]], self.device.unpack(result))
        finally:
            def_function.run_functions_eagerly(False)
Example #6
    def testDefaultDeviceInsideFunctionWithScope(self, distribution,
                                                 run_functions_eagerly):

        def_function.run_functions_eagerly(run_functions_eagerly)
        expected_device = (device_util.canonicalize("cpu:0")
                           if run_functions_eagerly else "")
        with distribution.scope():
            with ops.device_v2("cpu:0"):

                @def_function.function
                def add():
                    one = array_ops.ones([])
                    self.assertEqual(expected_device, one.device)
                    return one + 1

                add()
Example #7
def pfor(loop_fn,
         iters,
         fallback_to_while_loop=True,
         parallel_iterations=None):
    """Equivalent to running `loop_fn` `iters` times and stacking the outputs.

  `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters`
  times, with input from 0 to `iters - 1`, and stacking corresponding output of
  each iteration. However the implementation does not use a `tf.while_loop`.
  Instead it adds new operations to the graph that collectively compute the same
  value as what running `loop_fn` in a loop would compute.


  This is an experimental feature and currently has a lot of limitations:
    - There should be no data dependency between the different iterations. For
      example, a future iteration should not depend on a value or side-effect of
      a previous iteration.
    - Stateful kernels may mostly not be supported since these often imply a
      data dependency or ordering of the iterations. We do support a limited set
      of such stateful kernels though (like RandomFoo, Variable operations like
      reads, etc).
    - Conversion works only on a limited set of kernels for which a converter
      has been registered.
    - `loop_fn` has limited support for control flow operations. `tf.cond` in
      particular is not supported.
    - `loop_fn` should return a nested structure of Tensors or Operations.
      However, if an Operation is returned, it should have zero outputs.
    - The shape and dtype of `loop_fn` outputs should not depend on the input
      to `loop_fn`.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and optionally a keyword argument `pfor_config` set
      to a PForConfig object. It returns a possibly nested structure of Tensor
      or Operation objects. Note that if setting `parallel_iterations` argument
      to something other than None, `loop_fn` may be called more than once
      during graph construction. So it may need to avoid mutating global state.
    iters: Number of iterations for which to run `loop_fn`.
    fallback_to_while_loop: If true, on failing to vectorize an operation, pfor
      falls back to using a `tf.while_loop` to dispatch the iterations.
    parallel_iterations: A knob to control how many iterations are vectorized
      and dispatched in parallel. The default value of None corresponds to
      vectorizing all the iterations.  If `parallel_iterations` is smaller than
      `iters`, then chunks of at most that many iterations are dispatched in
      sequence. This knob can be used to control the total memory usage.

  Returns:
    Returns a nested structure of stacked tensor objects with the same nested
    structure as the output of `loop_fn`.
  Raises:
    ValueError: If parallel_iterations is not None and not an integer > 1.
  """
  def f():
    return _pfor_impl(loop_fn,
                      iters,
                      fallback_to_while_loop=fallback_to_while_loop,
                      parallel_iterations=parallel_iterations)

  # Note that we wrap into a tf.function if in eager execution mode or under
  # XLA compilation. The latter is so that we don't compile operations like
  # tf.placeholder that are created by the loop body.
  functions_run_eagerly = None
  if context.executing_eagerly() or _is_under_xla_context():
    functions_run_eagerly = def_function.functions_run_eagerly()
    if functions_run_eagerly:
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps using "
          "tf.config.run_functions_eagerly. Vectorization "
          "primitives (e.g. tf.vectorized_map) require tf.function to work. "
          "These primitives will override the disable.")
      def_function.run_functions_eagerly(False)
    f = def_function.function(f)
  outputs = f()
  if functions_run_eagerly is not None:
    def_function.run_functions_eagerly(functions_run_eagerly)
  return outputs
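A minimal usage sketch of the API documented above. The row-squaring workload and the 8-element batch are illustrative assumptions, and `pfor` is assumed importable from `tensorflow.python.ops.parallel_for.control_flow_ops`, the module this definition lives in:

import tensorflow as tf
from tensorflow.python.ops.parallel_for.control_flow_ops import pfor

x = tf.random.normal([8, 3])

def loop_fn(i):
  # Iteration `i` only reads row `i`, so iterations stay independent,
  # as the docstring above requires.
  return tf.gather(x, i) ** 2

# Stacks the 8 per-iteration outputs into a single [8, 3] tensor.
squared = pfor(loop_fn, iters=8)

# Trades latency for memory: dispatch the iterations in chunks of 4.
squared_chunked = pfor(loop_fn, iters=8, parallel_iterations=4)

The public wrapper around the same machinery is `tf.vectorized_map`, which maps over the leading axis of its inputs instead of passing an iteration index to the loop body.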
Example #8
 def setUp(self):
     super().setUp()
     # Clear the state for every test.
     def_function.run_functions_eagerly(False)
Example #9
 def setup(self):
     # Clear the state for every test.
     def_function.run_functions_eagerly(False)