Example #1
        def compute(i, tas):
            """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if fn_output_signature and result_value structure don't match
        ValueType: if fn_output_signature and result_value lengths don't match
      """
            elems_value_batchable = [ta.read(i) for ta in elems_batchable_ta]
            elems_value_flat = _elems_value_batchable_to_flat(
                elems_value_batchable, elems_flat_signature)
            elems_value = elems_unflatten(elems_value_flat)
            ag_ctx = autograph_ctx.control_status_ctx()
            autographed_fn = autograph.tf_convert(fn, ag_ctx)
            result_value = autographed_fn(elems_value)
            nest.assert_same_structure(fn_output_signature or elems,
                                       result_value)
            result_value_flat = nest.flatten(result_value)
            result_value_batchable = _result_value_flat_to_batchable(
                result_value_flat, result_flat_signature)
            tas = [
                ta.write(i, value)
                for (ta, value) in zip(tas, result_value_batchable)
            ]
            return (i + 1, tas)
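A minimal usage sketch of the public `tf.map_fn` API whose loop body is shown above, assuming a TF version that supports `fn_output_signature`; the lambda and the dict keys are illustrative.

```python
import tensorflow as tf

# Map a function whose output structure differs from its input; the
# fn_output_signature tells map_fn how to build the result TensorArrays.
elems = tf.constant([1.0, 2.0, 3.0])
result = tf.map_fn(
    lambda x: {"x": x, "x_squared": x * x},
    elems,
    fn_output_signature={"x": tf.TensorSpec([], tf.float32),
                         "x_squared": tf.TensorSpec([], tf.float32)})
# `result` is a dict of two tensors, each of shape [3].
```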
Example #2
  def _call_for_each_replica(self, fn, args, kwargs):
    if isinstance(fn, def_function.Function):
      wrapped = self._cfer_fn_cache.get(fn)
      if wrapped is None:
        # We need to wrap fn such that it triggers _call_for_each_replica inside
        # the tf.function.
        wrapped = fn._clone(  # pylint: disable=protected-access
            python_function=functools.partial(self._call_for_each_replica,
                                              fn.python_function))
        self._cfer_fn_cache[fn] = wrapped
      return wrapped(args, kwargs)

    if context.executing_eagerly():
      logging.log_first_n(
          logging.WARN, "Using %s eagerly has significant "
          "overhead currently. We will be working on improving "
          "this in the future, but for now please wrap "
          "`call_for_each_replica` or `experimental_run` or "
          "`run` inside a tf.function to get the best performance." %
          self._container_strategy().__class__.__name__, 5)
    else:
      # When a tf.function is wrapped to trigger _call_for_each_replica (see
      # the other branch above), AutoGraph stops conversion at
      # _call_for_each_replica itself (TF library functions are whitelisted).
      # This makes sure that the Python function originally passed to
      # the tf.function is still converted.
      fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())

    return _call_for_each_replica(self._container_strategy(), self._devices,
                                  fn, args, kwargs)
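The eager-overhead warning above suggests wrapping the distributed call in a `tf.function`. Below is a minimal sketch of that pattern using the public `strategy.run` API; the strategy choice and function names are illustrative.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_fn(x):
  # Runs once per replica with that replica's portion of the inputs.
  return x * 2.0

@tf.function  # wrapping in tf.function avoids the per-call eager overhead
def distributed_step(x):
  return strategy.run(replica_fn, args=(x,))

out = distributed_step(tf.constant(1.0))
```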
Example #3
def call_for_each_replica(strategy, fn, args=None, kwargs=None):
    """Call `fn` on each worker devices(replica).

  It's highly recommended to wrap the call to this function inside a
  `tf.function`, otherwise the performance is poor.

  Args:
    strategy: `tf.distribute.Strategy`.
    fn: function to call on each worker device.
    args: positional arguments to `fn`.
    kwargs: keyword arguments to `fn`.

  Returns:
    Wrapped returned value of `fn` from all replicas.
  """
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}

    if isinstance(fn, def_function.Function):
        # Don't lift up the tf.function decoration if `fn` is compiled with XLA
        # and all devices are GPU. In this case we will use collectives to do
        # cross-device communication, thus no merge_call is in the path.
        if fn._jit_compile and all(  # pylint: disable=protected-access
            [_is_gpu_device(d) for d in strategy.extended.worker_devices]):
            return _call_for_each_replica(strategy, fn, args, kwargs)

        if strategy not in _cfer_fn_cache:
            _cfer_fn_cache[strategy] = weakref.WeakKeyDictionary()
        wrapped = _cfer_fn_cache[strategy].get(fn)
        if wrapped is None:
            # We need to wrap fn such that it triggers _call_for_each_replica inside
            # the tf.function. We use _clone() instead of @tf.function wrapped
            # call_for_each_replica() because we would like to retain the arguments to
            # the @tf.function decorator of fn.
            wrapped = fn._clone(  # pylint: disable=protected-access
                python_function=functools.partial(call_for_each_replica,
                                                  strategy,
                                                  fn.python_function))
            _cfer_fn_cache[strategy][fn] = wrapped
        return wrapped(args, kwargs)

    if context.executing_eagerly():
        logging.log_first_n(
            logging.WARN, "Using %s eagerly has significant "
            "overhead currently. We will be working on improving "
            "this in the future, but for now please wrap "
            "`call_for_each_replica` or `experimental_run` or "
            "`run` inside a tf.function to get "
            "the best performance." % strategy.__class__.__name__, 5)
    else:
        # When a tf.function is wrapped to trigger _call_for_each_replica (see
        # the other branch above), AutoGraph stops conversion at
        # _call_for_each_replica itself (TF library functions are allowlisted).
        # This makes sure that the Python function originally passed to
        # the tf.function is still converted.
        fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())

    return _call_for_each_replica(strategy, fn, args, kwargs)
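A hedged sketch of the kind of function that would take the XLA fast path checked at the top of this example: a jit-compiled `tf.function` dispatched through `strategy.run`. It assumes a TF version where `tf.function(jit_compile=...)` is available and that the strategy's worker devices are GPUs.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumes GPU devices are available

@tf.function(jit_compile=True)
def replica_fn(x):
  # XLA-compiled per-replica computation; per the check above, this skips the
  # _clone-based wrapping when every worker device is a GPU.
  return tf.reduce_sum(x * x)

result = strategy.run(replica_fn, args=(tf.ones([4]),))
```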
Example #4
  def test_tf_convert_whitelisted_method(self):

    model = sequential.Sequential([core.Dense(2)])
    converted_call = api.tf_convert(
        model.call, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
    _, converted_target = tf_decorator.unwrap(converted_call)
    self.assertIs(converted_target.__func__, model.call.__func__)
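The public analogue of the allowlisting exercised by these tests is `tf.autograph.experimental.do_not_convert`, which marks a function so AutoGraph passes it through unchanged. A minimal sketch; the function names are illustrative.

```python
import tensorflow as tf

@tf.autograph.experimental.do_not_convert
def plain_python(x):
  # Left untouched by AutoGraph even when called from converted code.
  return x + 1

@tf.function
def wrapper(x):
  return plain_python(x)

print(wrapper(tf.constant(1)))
```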
Example #5
    def experimental_run_v2(self, fn, args=(), kwargs=None):
        """See base class."""
        validate_experimental_run_function(fn)

        # Note: the target function is converted to graph even when in Eager mode,
        # so autograph is on by default here.
        fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
        return self.extended.tpu_run(fn, args, kwargs)
Example #6
 def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
     # tf.distribute supports Eager functions, so AutoGraph should not be applied
     # when the caller is also in Eager mode.
     fn = autograph.tf_convert(fn,
                               ag_ctx.control_status_ctx(),
                               convert_by_default=False)
     return self.extended.tpu_run(fn, args, kwargs)
Example #7
    def run(self, fn, args=(), kwargs=None, options=None):
        """See base class."""
        validate_run_function(fn)

        # Note: the target function is converted to graph even when in Eager mode,
        # so autograph is on by default here.
        fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
        options = options or distribute_lib.RunOptions()
        return self.extended.tpu_run(fn, args, kwargs, options)
Example #8
    def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):
        """Run `fn` on each replica, with the given arguments.

    Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
    "per-replica" values, such as those produced by a "distributed `Dataset`",
    when `fn` is executed on a particular replica, it will be executed with the
    component of those "per-replica" values that corresponds to that replica.

    `fn` may call `tf.distribute.get_replica_context()` to access members such
    as `all_reduce`.

    All arguments in `args` or `kwargs` should either be nests of tensors or
    per-replica objects containing tensors or composite tensors.

    Users can pass strategy-specific options to the `options` argument. An
    example that enables bucketizing dynamic shapes in
    `TPUStrategy.experimental_run_v2` is:
    ```python

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    options = tf.distribute.RunOptions()
    options.experimental_bucketizing_dynamic_shape = True

    iterator = iter(inputs)

    @tf.function()
    def step_fn(inputs):
      output = tf.reduce_sum(inputs)
      return output

    strategy.experimental_run_v2(step_fn, args=(next(iterator),),
                                 options=options)
    ```

    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Merged return value of `fn` across replicas. The structure of the return
      value is the same as the return value from `fn`. Each element in the
      structure can either be "per-replica" `Tensor` objects or `Tensor`s
      (for example, if running on a single replica).
    """
        validate_experimental_run_function(fn)

        fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
        options = options or distribute_lib.RunOptions()
        return self.extended.tpu_run(fn, args, kwargs, options)
Example #9
    def test_tf_convert_tf_decorator_allowlist_method(self):
        def wrap(f):
            def wrapper(*args, **kwargs):
                return wrapper.__wrapped__(*args, **kwargs)

            return tf_decorator.make_decorator(f, wrapper)

        class TestClass(object):
            @wrap
            def method(self):
                return converter_testing.is_inside_generated_code()

        converter_testing.allowlist(TestClass.method)

        obj = TestClass()
        # It's intended that tf_convert modifies the original method in this case.
        # This is not desirable, but options are limited.
        api.tf_convert(obj.method,
                       ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
        self.assertTrue(obj.method())
Example #10
    def test_tf_convert_allowlisted_method(self):
        class TestClass(object):
            def method(self):
                return converter_testing.is_inside_generated_code()

        converter_testing.allowlist(TestClass.method)

        obj = TestClass()
        converted_call = api.tf_convert(
            obj.method, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
        _, converted_target = tf_decorator.unwrap(converted_call)
        self.assertIs(converted_target.__func__, obj.method.__func__)
Example #11
    def test_tf_convert_whitelisted_method(self):

        if six.PY2:
            self.skipTest('Test bank not compatible with Python 2.')

        class TestClass(object):
            def method(self):
                return converter_testing.is_inside_generated_code()

        converter_testing.whitelist(TestClass.method)

        obj = TestClass()
        converted_call = api.tf_convert(
            obj.method, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
        _, converted_target = tf_decorator.unwrap(converted_call)
        self.assertIs(converted_target.__func__, obj.method.__func__)
Example #12
 def while_body(i, *ta_list):
     """Body of while loop."""
     fn_conv = autograph.tf_convert(loop_fn,
                                    autograph_ctx.control_status_ctx())
     fn_output = nest.flatten(fn_conv(i))
     if len(fn_output) != len(flat_loop_fn_dtypes):
         raise ValueError(
             f"Number of expected outputs {len(flat_loop_fn_dtypes)}, does not "
             f"match the number of actual outputs {len(fn_output)} from loop_fn: "
             f"{loop_fn} with output {fn_output}.")
     outputs = []
     del is_none_list[:]
     is_none_list.extend(x is None for x in fn_output)
     for out, ta in zip(fn_output, ta_list):
         # TODO(agarwal): support returning Operation objects from loop_fn.
         if out is not None:
             # out may be a ref tensor, wrap it in identity to get a non-ref tensor.
             ta = ta.write(i, array_ops.expand_dims(out, 0))
         outputs.append(ta)
     return tuple([i + 1] + outputs)
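A self-contained sketch of the same accumulate-into-TensorArray pattern using only public ops (`tf.while_loop` plus `tf.TensorArray`); the helper name and the loop function are illustrative.

```python
import tensorflow as tf

def accumulate(loop_fn, iters):
  """Runs loop_fn(i) for i in [0, iters) and stacks the results."""
  ta = tf.TensorArray(tf.float32, size=iters)

  def body(i, ta):
    # Write the i-th result and advance the counter, mirroring while_body above.
    return i + 1, ta.write(i, loop_fn(i))

  _, ta = tf.while_loop(lambda i, _: i < iters, body, [0, ta])
  return ta.stack()

out = accumulate(lambda i: tf.cast(i, tf.float32) ** 2, 5)  # [0., 1., 4., 9., 16.]
```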
Example #13
    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates metric statistics.
        `y_true` and `y_pred` should have the same shape.

        Args:

        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
        sample_weight: Optional `sample_weight` acts as a
            coefficient for the metric. If a scalar is provided, then the metric is
            simply scaled by the given value. If `sample_weight` is a tensor of size
            `[batch_size]`, then the metric for each sample of the batch is rescaled
            by the corresponding element in the `sample_weight` vector. If the shape
            of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted
            to this shape), then each metric element of `y_pred` is scaled by the
            corresponding value of `sample_weight`. (Note on `dN-1`: all metric
            functions reduce by 1 dimension, usually the last axis (-1)).

        Returns:
        Update op.
        """
        y_true = math_ops.cast(y_true, self._dtype)
        # `y_pred` is expected to be a dict of tensors; cast each entry to the
        # metric's dtype.
        y_pred = {k: math_ops.cast(v, self._dtype) for k, v in y_pred.items()}
        [
            y_true,
            y_pred,
        ], sample_weight = metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [y_true, y_pred], sample_weight)
        # Note: `squeeze_or_expand_dimensions` is intentionally skipped here
        # because `y_pred` is a dict rather than a single tensor.

        ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx())
        matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
        return super(MeanMetricWrapper,
                     self).update_state(matches, sample_weight=sample_weight)
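A hedged sketch of a metric function this modified wrapper could drive, where `y_pred` is a dict of tensors; the `"logits"` key and the function name are assumptions, not part of the original code.

```python
import tensorflow as tf

def dict_mse(y_true, y_pred):
  # `y_pred` is a dict; only the (hypothetical) "logits" entry is scored.
  return tf.reduce_mean(tf.square(y_true - y_pred["logits"]), axis=-1)

y_true = tf.constant([[1.0, 0.0]])
y_pred = {"logits": tf.constant([[0.8, 0.2]]), "features": tf.zeros([1, 4])}
print(dict_mse(y_true, y_pred))
```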
Example #14
    def experimental_run_v2(self, fn, args=(), kwargs=None):
        """See base class."""
        validate_experimental_run_function(fn)

        fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
        return self.extended.tpu_run(fn, args, kwargs)
Example #15
 def test_fn(ctx):
   return api.tf_convert(f, ctx, convert_by_default=False)()
Example #16
 def test_fn(ctx):
   return api.tf_convert(f, ctx)()
Example #17
 def test_fn(ctx):
   return api.tf_convert(decorated_f, ctx)()
Example #18
    def inference(self, inputs, *args, **kwargs):

        call_context = base_layer_utils.call_context()
        input_list = nest.flatten(inputs)

        # We will attempt to build a TF graph if & only if all inputs are symbolic.
        # This is always the case in graph mode. It can also be the case in eager
        # mode when all inputs can be traced back to `keras.Input()` (when building
        # models using the functional API).
        build_graph = tf_utils.are_all_symbolic_tensors(input_list)

        # Accept NumPy and scalar inputs by converting to Tensors.
        if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
            def _convert_non_tensor(x):
                # Don't call `ops.convert_to_tensor` on all `inputs` because
                # `SparseTensors` can't be converted to `Tensor`.
                if isinstance(x, (np.ndarray, float, int)):
                    return ops.convert_to_tensor(x)
                return x
            inputs = nest.map_structure(_convert_non_tensor, inputs)
            input_list = nest.flatten(inputs)

        # Handle `mask` propagation from previous layer to current layer. Masks can
        # be propagated explicitly via the `mask` argument, or implicitly via
        # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
        # explicitly take priority.
        mask_arg_passed_by_framework = False
        input_masks = self._collect_input_masks(inputs, args, kwargs)
        if (self._expects_mask_arg and input_masks is not None and
                not self._call_arg_was_passed('mask', args, kwargs)):
            mask_arg_passed_by_framework = True
            kwargs['mask'] = input_masks

        # If `training` argument was not explicitly passed, propagate `training`
        # value from this layer's calling layer.
        training_arg_passed_by_framework = False
        # Priority 1: `training` was explicitly passed.
        if self._call_arg_was_passed('training', args, kwargs):
            training_value = self._get_call_arg_value('training', args, kwargs)
            if not self._expects_training_arg:
                kwargs.pop('training')
        else:
            training_value = None
            # Priority 2: `training` was passed to a parent layer.
            if call_context.training is not None:
                training_value = call_context.training
            # Priority 3a: `learning_phase()` has been set.
            elif backend.global_learning_phase_is_set():
                training_value = backend.learning_phase()
            # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
            elif build_graph:
                with backend.get_graph().as_default():
                    if base_layer_utils.is_in_keras_graph():
                        training_value = backend.learning_phase()

            if self._expects_training_arg and training_value is not None:
                # Force the training_value to be bool type, which matches the
                # contract for layer/model call args.
                if tensor_util.is_tensor(training_value):
                    training_value = math_ops.cast(training_value, dtypes.bool)
                else:
                    training_value = bool(training_value)
                kwargs['training'] = training_value
                training_arg_passed_by_framework = True

        # Only create Keras history if at least one tensor originates from a
        # `keras.Input`. Otherwise this Layer may be being used outside the Keras
        # framework.
        if build_graph and base_layer_utils.needs_keras_history(inputs):
            base_layer_utils.create_keras_history(inputs)

        # Clear eager losses on top level model call.
        # We are clearing the losses only on the top level model call and not on
        # every layer/model call because layer/model may be reused.
        if (base_layer_utils.is_in_eager_or_tf_function() and
                not call_context.in_call):
            self._clear_losses()

        with call_context.enter(self, inputs, build_graph, training_value):
            # Check input assumptions set after layer building, e.g. input shape.
            if build_graph:
                # Symbolic execution on symbolic tensors. We will attempt to build
                # the corresponding TF subgraph inside `backend.get_graph()`
                # TODO(reedwm): We should assert input compatibility after the inputs
                # are casted, not before.
                input_spec.assert_input_compatibility(self.input_spec, inputs,
                                                      self.name)
                if (any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list)
                        and self._supports_ragged_inputs is False):  # pylint: disable=g-bool-id-comparison
                    raise ValueError(
                        'Layer %s does not support RaggedTensors as input. '
                        'Inputs received: %s. You can try converting your '
                        'input to a uniform tensor.' % (self.name, inputs))

                graph = backend.get_graph()
                with graph.as_default(), backend.name_scope(self._name_scope()):
                    # Build layer if applicable (if the `build` method has been
                    # overridden).
                    self._maybe_build(inputs)
                    cast_inputs = self._maybe_cast_inputs(inputs)

                    # Wrapping `call` function in autograph to allow for dynamic control
                    # flow and control dependencies in call. We are limiting this to
                    # subclassed layers as autograph is strictly needed only for
                    # subclassed layers and models.
                    # tf_convert will respect the value of autograph setting in the
                    # enclosing tf.function, if any.
                    if (base_layer_utils.is_subclassed(self) and
                            not base_layer_utils.from_saved_model(self)):
                        call_fn = autograph.tf_convert(
                                self._inference, ag_ctx.control_status_ctx())
                    else:
                        call_fn = self._inference

                    if not self.dynamic:
                        try:
                            with base_layer_utils.autocast_context_manager(
                                    self._compute_dtype):
                                # Add auto_control_deps in V2 when they are not already added by
                                # a `tf.function`.
                                if (ops.executing_eagerly_outside_functions() and
                                        not base_layer_utils.is_in_eager_or_tf_function()):
                                    with auto_control_deps.AutomaticControlDependencies() as acd:
                                        outputs = call_fn(cast_inputs, *args, **kwargs)
                                        # Wrap Tensors in `outputs` in `tf.identity` to avoid
                                        # circular dependencies.
                                        outputs = base_layer_utils.mark_as_return(outputs, acd)
                                else:
                                    outputs = call_fn(cast_inputs, *args, **kwargs)

                        except errors.OperatorNotAllowedInGraphError as e:
                            raise TypeError('You are attempting to use Python control '
                                            'flow in a layer that was not declared to be '
                                            'dynamic. Pass `dynamic=True` to the class '
                                            'constructor.\nEncountered error:\n"""\n' +
                                            str(e) + '\n"""')
                    else:
                        # We will use static shape inference to return symbolic tensors
                        # matching the specifications of the layer outputs.
                        # Since `self.dynamic` is True, we will never attempt to
                        # run the underlying TF graph (which is disconnected).
                        # TODO(fchollet): consider py_func as an alternative, which
                        # would enable us to run the underlying graph if needed.
                        outputs = self._symbolic_call(inputs)

                    if outputs is None:
                        raise ValueError('A layer\'s `call` method should return a '
                                         'Tensor or a list of Tensors, not None '
                                         '(layer: ' + self.name + ').')
                    if base_layer_utils.have_all_keras_metadata(inputs):
                        if training_arg_passed_by_framework:
                            kwargs.pop('training')
                        if mask_arg_passed_by_framework:
                            kwargs.pop('mask')
                        inputs, outputs = self._set_connectivity_metadata_(
                                inputs, outputs, args, kwargs)
                    self._handle_activity_regularization(inputs, outputs)
                    self._set_mask_metadata(inputs, outputs, input_masks)
                    if hasattr(self, '_set_inputs') and not self.inputs:
                        # Subclassed network: explicitly set metadata normally set by
                        # a call to self._set_inputs().
                        # TODO(b/120997007): This should be done in Eager as well, but
                        # causes garbage collection issues because of the placeholders
                        # created on the default Keras graph.
                        self._set_inputs(inputs, outputs)
            else:
                # Eager execution on data tensors.
                with backend.name_scope(self._name_scope()):
                    self._maybe_build(inputs)
                    cast_inputs = self._maybe_cast_inputs(inputs)
                    with base_layer_utils.autocast_context_manager(
                            self._compute_dtype):
                        outputs = self._inference(cast_inputs, *args, **kwargs)
                    self._handle_activity_regularization(inputs, outputs)
                    self._set_mask_metadata(inputs, outputs, input_masks)

        return outputs
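A minimal sketch of the situation the AutoGraph wrapping above exists for: a subclassed layer whose `call` uses Python control flow on a tensor, traced inside a `tf.function`. The layer and function names are illustrative.

```python
import tensorflow as tf

class ClippedScale(tf.keras.layers.Layer):
  """Subclassed layer with data-dependent control flow in call."""

  def call(self, inputs):
    # Under AutoGraph this `if` on a Tensor becomes a tf.cond when traced,
    # which is what the tf_convert of the call function above enables.
    if tf.reduce_mean(inputs) > 0:
      return inputs * 2.0
    return inputs

layer = ClippedScale()

@tf.function
def apply_layer(x):
  return layer(x)

print(apply_layer(tf.constant([[1.0, -3.0]])))
```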
Example #19
def _pfor_impl(loop_fn,
               iters,
               fallback_to_while_loop,
               parallel_iterations=None,
               pfor_config=None,
               warn=False):
    """Implementation of pfor."""
    assert not context.executing_eagerly()
    loop_fn_has_config = _loop_fn_has_config(loop_fn)
    existing_ops = set(ops.get_default_graph().get_operations())
    iters_value = tensor_util.constant_value(iters)
    # Run the loop body
    with ops.name_scope("loop_body"):
        loop_var = array_ops.placeholder_with_default(0, shape=[])
        if loop_fn_has_config:
            if pfor_config is None:
                pfor_config = PForConfig()
                pfor_config._set_iters(iters)  # pylint: disable=protected-access
            loop_fn_outputs = loop_fn(loop_var,
                                      **{PFOR_CONFIG_ARG: pfor_config})
        else:
            assert pfor_config is None
            f = autograph.tf_convert(loop_fn,
                                     autograph_ctx.control_status_ctx())
            loop_fn_outputs = f(loop_var)
        loop_fn_output_tensors = nest.map_structure(_composite_to_tensors,
                                                    loop_fn_outputs)

    # Convert outputs to Tensor if needed.
    tmp_loop_fn_outputs = []
    for loop_fn_output in nest.flatten(loop_fn_output_tensors):
        if (loop_fn_output is not None and not isinstance(
                loop_fn_output,
                (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))):
            if isinstance(loop_fn_output, indexed_slices.IndexedSlices):
                logging.warn(
                    "Converting %s to a dense representation may make it slow."
                    " Alternatively, output the indices and values of the"
                    " IndexedSlices separately, and handle the vectorized"
                    " outputs directly." % loop_fn_output)
                loop_fn_output = ops.convert_to_tensor(loop_fn_output)
            else:
                loop_fn_output = ops.convert_to_tensor(loop_fn_output)
        tmp_loop_fn_outputs.append(loop_fn_output)
    loop_fn_output_tensors = nest.pack_sequence_as(loop_fn_output_tensors,
                                                   tmp_loop_fn_outputs)

    new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
    iters = ops.convert_to_tensor(iters)
    if parallel_iterations is not None:
        if parallel_iterations < 1:
            raise ValueError(
                "Argument `parallel_iterations` must be None or a positive integer. "
                f"Received: {parallel_iterations}.")
        if parallel_iterations == 1:
            raise ValueError(
                "Found `parallel_iterations == 1`. Use `for_loop` instead.")
        if iters_value is not None and iters_value < parallel_iterations:
            parallel_iterations = None
    if parallel_iterations is None:
        with ops.name_scope("pfor"):
            converter = PFor(loop_var,
                             iters,
                             new_ops,
                             fallback_to_while_loop=fallback_to_while_loop,
                             pfor_config=pfor_config,
                             warn=warn)
            flattened_output_tensors = []
            for loop_fn_output in nest.flatten(loop_fn_output_tensors):
                output = converter.convert(loop_fn_output)
                flattened_output_tensors.append(output)
    else:
        if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
            raise ValueError(
                "Setting `parallel_iterations` currently unsupported if "
                "reductions across iterations are performed.")
        num_tiled_iterations = iters // parallel_iterations
        num_remaining_iterations = iters % parallel_iterations
        # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
        # a tf.function and extract the graph from there to vectorize it.
        with ops.name_scope("pfor_untiled"):
            converter = PFor(loop_var,
                             num_remaining_iterations,
                             new_ops,
                             fallback_to_while_loop=fallback_to_while_loop,
                             pfor_config=pfor_config)
            remaining_output_tensors = []
            flattened_output_tensors = nest.flatten(loop_fn_output_tensors)
            for loop_fn_output in flattened_output_tensors:
                output = converter.convert(loop_fn_output)
                remaining_output_tensors.append(output)

        with ops.name_scope("pfor_tiled"):
            loop_fn_dtypes = [
                ops.convert_to_tensor(x).dtype
                for x in flattened_output_tensors
            ]

            def tiled_loop_body(j):
                offset = j * parallel_iterations + num_remaining_iterations

                def tiled_loop_fn(i, pfor_config=None):
                    if loop_fn_has_config:
                        loop_fn_outputs = loop_fn(i + offset,
                                                  pfor_config=pfor_config)
                    else:
                        loop_fn_outputs = loop_fn(i + offset)
                    return nest.flatten(
                        # Stacking across iterations requires explicit Tensors.
                        nest.map_structure(_composite_to_tensors,
                                           loop_fn_outputs))

                return _pfor_impl(
                    tiled_loop_fn,
                    parallel_iterations,
                    fallback_to_while_loop=fallback_to_while_loop,
                    pfor_config=pfor_config)

            tiled_output_tensors = for_loop(tiled_loop_body,
                                            loop_fn_dtypes,
                                            num_tiled_iterations,
                                            parallel_iterations=1)
            tiled_output_tensors = [
                _flatten_first_two_dims(y) for y in tiled_output_tensors
            ]

        with ops.name_scope("pfor"):
            if iters_value is None or iters_value % parallel_iterations:
                output_tensors = control_flow_ops.cond(
                    math_ops.equal(num_remaining_iterations, 0),
                    lambda: tiled_output_tensors,
                    lambda: [
                        array_ops.concat([x, y], axis=0)  # pylint: disable=g-long-lambda
                        for x, y in zip(remaining_output_tensors,
                                        tiled_output_tensors)
                    ])
            else:
                output_tensors = tiled_output_tensors
            flattened_output_tensors = nest.flatten(output_tensors)

            for output, original_output in zip(
                    flattened_output_tensors,
                    nest.flatten(loop_fn_output_tensors)):
                # Restore any shape information lost from tiling.
                # TODO(b/174254748): this may not be correct for stacked `variant`s.
                output.set_shape(
                    tensor_shape.TensorShape([iters_value]).concatenate(
                        original_output.shape))

    return nest.map_structure_up_to(
        loop_fn_outputs,
        functools.partial(_composite_from_tensors, batch_size=iters_value),
        nest.pack_sequence_as(loop_fn_output_tensors,
                              flattened_output_tensors), loop_fn_outputs)
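The public entry point that funnels into this implementation is `tf.vectorized_map`. A minimal sketch; the per-row function is illustrative.

```python
import tensorflow as tf

# Vectorize a per-row computation; the rows are processed via the pfor
# machinery shown above rather than an explicit Python loop.
batch = tf.random.normal([8, 3])
row_norms = tf.vectorized_map(
    lambda row: tf.sqrt(tf.reduce_sum(row * row)), batch)
# `row_norms` has shape [8].
```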
Example #20
 def test_fn(ctx, expect_converted):
   return api.tf_convert(f, ctx)(expect_converted)
Example #21
 def call(self, y_true, y_pred):
     if tensor_util.is_tensor(y_pred) and tensor_util.is_tensor(y_true):
         y_pred, y_true = tf_losses_util.squeeze_or_expand_dimensions(
             y_pred, y_true)
     ag_fn = autograph.tf_convert(self.fn, ag_ctx.control_status_ctx())
     return ag_fn(y_true, y_pred, **self._fn_kwargs)
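A hedged sketch of the user-facing path that ends up in this wrapper: a plain Python loss function passed to `model.compile`, which Keras wraps and then invokes as `fn(y_true, y_pred)` as shown above. The loss name and model are illustrative.

```python
import tensorflow as tf

def scaled_mse(y_true, y_pred):
  # Called by the wrapper with (y_true, y_pred) after shape adjustment.
  return 0.5 * tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss=scaled_mse)
```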