示例#1
0
def _tf_range_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over a TF range (and elides it).

  Rather than iterating over the range tensor itself, this reads the
  `start`, `limit` and `delta` inputs of the op that produced `iter_` and
  runs a while loop over an explicit counter.

  Args:
    iter_: Tensor whose producing op has three inputs, unpacked as
      (start, limit, delta) -- assumed to be a range op.
    extra_test: Optional zero-argument callable. When not None, it is only
      evaluated (via `control_flow_ops.cond`) if the range bound test passes.
    body: Callable invoked with the current iterate value once per iteration.
    get_state: Zero-argument callable returning the tuple of loop variables.
    set_state: Callable receiving the updated tuple of loop variables.
    symbol_names: Tuple of names for the values returned by `get_state`.
    opts: Dict of loop options. `opts['iterate_names']` is read; outside of
      XLA contexts `opts['maximum_iterations']` is set in place.
  """
  start, limit, delta = iter_.op.inputs

  # Mutable holder for the loop counter, shared by the closures below.
  iterate = compat_util.BasicRef(start)

  def _value_or(name, var, default):
    # Substitute `default` for the iterate symbol while it is still Undefined
    # (presumably: the loop variable has no value before the loop starts).
    if (name == opts['iterate_names'] and isinstance(var, variables.Undefined)):
      return default
    return var

  def aug_get_state():
    # Prepend the internal counter to the user-visible loop state.
    state_vars = get_state()
    state_vars = tuple(
        _value_or(name, var, iterate.value)
        for name, var in zip(symbol_names, state_vars))
    return (iterate.value,) + state_vars

  def aug_set_state(aug_loop_vars):
    # TODO(mdan): Use starred assignment once we can switch to Py3-only syntax.
    iterate.value, loop_vars = aug_loop_vars[0], aug_loop_vars[1:]
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    # Run the user body on the current value, then advance by `delta`.
    body(iterate.value)
    iterate.value += delta

  def aug_test():
    # TODO(b/159713842): Remove once constant folding works.
    # When the sign of `delta` is statically known, emit a single comparison;
    # otherwise build a test that is correct for either sign of `delta`.
    const_delta = tensor_util.constant_value(delta)
    if const_delta is not None:
      if const_delta >= 0:
        main_test = iterate.value < limit
      else:
        main_test = iterate.value > limit
    else:
      main_test = math_ops.logical_or(
          math_ops.logical_and(delta >= 0, iterate.value < limit),
          math_ops.logical_and(delta < 0, iterate.value > limit))

    # `extra_test` is only evaluated when the range bound test passes.
    if extra_test is not None:
      main_test = control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  # TODO(b/134181679): Remove.
  # Outside XLA, bound the loop with the statically-computed range length.
  if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    opts['maximum_iterations'] = math_ops.cast(
        misc.get_range_len(start, limit, delta), dtypes.int32)

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts)
示例#2
0
  def _capture_helper(self, tensor, name):
    """Captures `tensor` into this gradient graph, routing forward intermediates.

    Tensors that are not from the forward graph, or that are already inputs
    or outputs of it, are captured normally by the base class. Other
    forward-graph tensors (intermediates) are handled as follows:
      * In an XLA context, they are appended to `self.xla_intermediates` (if
        not already captured), `self.if_op_needs_rewrite` is set, and the
        tensor is captured directly.
      * Resource tensors are traced back to the forward graph input they come
        from, and that input is captured instead.
      * Other tensors are wrapped in an optional in the forward graph. An
        existing wrapping is reused if it is already a forward graph output;
        otherwise a new one is created and `self.if_op_needs_rewrite` is set.
        The optional is captured and its value is unwrapped in this graph.
    On the non-XLA paths, the result is cached in `self._indirect_captures`.

    Args:
      tensor: The tensor to capture.
      name: Name passed through to the base class `_capture_helper`.

    Returns:
      The tensor to use inside this graph in place of `tensor`.
    """
    # Not a forward-graph intermediate: ordinary capture.
    if (tensor.graph is not self._forward_graph or
        tensor in self._forward_graph.inputs or
        tensor in self._forward_graph.outputs):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if tensor not in self.captures:
        self.xla_intermediates.append(tensor)
        self.if_op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    # Reuse a previously computed indirect capture for this intermediate.
    captured_tensor = self._indirect_captures.get(tensor)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the forward
    # graph and capture that.

    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              consumer.outputs[0] in self._forward_graph.outputs):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_dataset_ops.optional_from_value([tensor])
          self.if_op_needs_rewrite = True
        self._wrapped_intermediates[tensor] = optional

      optional = self._wrapped_intermediates[tensor]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      # Unwrap the single value stored in the captured optional.
      captured_tensor = gen_dataset_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape])[0]

    self._indirect_captures[tensor] = captured_tensor
    return captured_tensor
示例#3
0
def _is_inside_compilation():
    """Reports whether ops created now would be compiled by XLA.

    Returns:
      True when the default graph (or one of its parents) is in an XLA
      context and no outside-compilation attribute scope is active;
      False otherwise.
    """
    default_graph = ops.get_default_graph()
    scope_attrs = default_graph._attr_scope_map  # pylint: disable=protected-access

    if not control_flow_util.GraphOrParentsInXlaContext(default_graph):
        return False
    # Inside XLA, only an active outside-compilation scope opts out.
    return scopes.OUTSIDE_COMPILATION_NAME not in scope_attrs
示例#4
0
def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
                           init_vars, basic_symbol_names,
                           composite_symbol_names, opts):
    """Overload of for_stmt that iterates over TF entities that admit a length.

    The elements of `iter_` are unstacked into a TensorArray and read by index
    inside a while loop whose first loop variable is the internal index.

    Args:
      iter_: Tensor to iterate over along its first dimension; its length is
        obtained with `py_builtins.len_`.
      extra_test: Optional callable taking the loop variables; when not None,
        it is only evaluated (via `control_flow_ops.cond`) while the index is
        in range.
      body: Callable `body(iterate, *loop_vars)` returning the new loop
        variables (or a falsy value when there are none).
      get_state: Passed through to `_tf_while_stmt`.
      set_state: Passed through to `_tf_while_stmt`.
      init_vars: Tuple of initial loop variable values.
      basic_symbol_names: Names of the basic loop variables.
      composite_symbol_names: Passed through to `_tf_while_stmt`.
      opts: Dict of loop options; in XLA contexts `opts['maximum_iterations']`
        is set in place.

    Returns:
      The loop results with the internal index stripped, or `()` when
      `_tf_while_stmt` does not return a tuple or list.
    """
    # Helper defined elsewhere; presumably rejects undefined loop variables.
    _disallow_undefs_into_loop(*init_vars)

    n = py_builtins.len_(iter_)
    # TODO(b/117628877): Revisit performance once XLA has the necessary support.
    # Note: using a TensorArray creates an extra copy, but can calculate
    # gradients more efficiently than StridedSlice.
    ta = tensor_array_ops.TensorArray(iter_.dtype, size=n)
    iter_ = ta.unstack(iter_)

    def while_body(iterate_index, *loop_vars):
        """Main loop body."""
        iterate = iter_.read(iterate_index)
        new_vars = body(iterate, *loop_vars)

        # The advanced index always leads the loop variables.
        loop_vars = (iterate_index + 1, )
        if new_vars:
            loop_vars += new_vars

        return loop_vars

    def while_cond(iterate_index, *loop_vars):
        # `extra_test` only runs when the index is still in range.
        if extra_test is not None:
            return control_flow_ops.cond(iterate_index < n,
                                         lambda: extra_test(*loop_vars),
                                         lambda: False)
        return iterate_index < n

    # TODO(b/134181679): Let the op itself handle optimizations.
    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
        opts['maximum_iterations'] = n

    # The index starts at zero with the same dtype as `n`.
    results = _tf_while_stmt(
        while_cond,
        while_body,
        get_state,
        set_state,
        (array_ops.zeros_like(n), ) + init_vars,
        ('<internal iterate>', ) + basic_symbol_names,
        composite_symbol_names,
        opts,
    )

    # Note: the iteration index is not returned by the while loop, however
    # if a symbol with the same name exists outside the loop, it will be captured
    # by the loop variables and ultimately updated correctly.
    # NOTE(review): a 1-element tuple/list is returned unchanged (i.e. it still
    # holds the internal index); confirm _tf_while_stmt never produces one.
    if isinstance(results, (tuple, list)):
        assert len(results) >= 1  # Has at least the iterate.
        if len(results) > 1:
            results = results[1:]
    else:
        results = ()

    return results
def _is_xla():
  """Reports whether the current trace targets XLA compilation.

  Under JAX this is always True. With TensorFlow it is True only when not
  executing eagerly and the default graph (or one of its parents) is in an
  XLA context.
  """
  if JAX_MODE:
    return True
  # Deferred imports keep TFP-on-JAX free of a TensorFlow dependency.
  from tensorflow.python.framework import ops  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
  from tensorflow.python.ops import control_flow_util  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
  if tf.executing_eagerly():
    return False
  return control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())
示例#6
0
def smart_for_loop(loop_num_iter,
                   body_fn,
                   initial_loop_vars,
                   parallel_iterations=10,
                   unroll_threshold=1,
                   name=None):
    """Construct a for loop, preferring a python loop if `n` is statically known.

    Given `loop_num_iter` and `body_fn`, return an op corresponding to executing
    `body_fn` `loop_num_iter` times, feeding previous outputs of `body_fn` into
    the next iteration.

    If `loop_num_iter` is statically known (and small enough, see
    `unroll_threshold`), the op is constructed via python for loop, and
    otherwise a `tf.while_loop` is used.

    Args:
      loop_num_iter: `Integer` `Tensor` representing the number of loop
        iterations.
      body_fn: Callable to be executed `loop_num_iter` times.
      initial_loop_vars: Listlike object (e.g. `list` or `tuple`) of `Tensors`
        to be passed in to `body_fn`'s first execution.
      parallel_iterations: The number of iterations allowed to run in parallel.
        It must be a positive integer. See `tf.while_loop` for more details.
        Default value: `10`.
      unroll_threshold: Integer denoting the maximum number of iterations to
        unroll, if possible. If `loop_num_iter > unroll_threshold` a
        `tf.while_loop` will always be used, even if `loop_num_iter` is
        statically known.
        Default value: `1`.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., "smart_for_loop").

    Returns:
      result: `Tensor` representing applying `body_fn` iteratively `n` times.
    """
    with tf.name_scope(name or 'smart_for_loop'):
        loop_num_iter_ = tf.get_static_value(loop_num_iter)
        if (loop_num_iter_ is None or tf.executing_eagerly()
                # large values for loop_num_iter_ will cause ridiculously slow
                # graph compilation time (GitHub issue #1033)
                or loop_num_iter_ > unroll_threshold or
                control_flow_util.GraphOrParentsInXlaContext(
                    tf1.get_default_graph())):
            # Cast to int32 to run the comparison against i in host memory,
            # where while/LoopCond needs it.
            loop_num_iter = tf.cast(loop_num_iter, dtype=tf.int32)
            return tf.while_loop(
                cond=lambda i, *args: i < loop_num_iter,
                body=lambda i, *args: [i + 1] + list(body_fn(*args)),
                # list() so tuple (or other listlike) inputs can be
                # concatenated; `[x] + some_tuple` raises TypeError.
                loop_vars=[np.int32(0)] + list(initial_loop_vars),
                parallel_iterations=parallel_iterations)[1:]
        result = initial_loop_vars
        for _ in range(loop_num_iter_):
            result = body_fn(*result)
        return result
 def _variadic_reduce_no_grad(operands, inits, axis, reducer):
   """Reduces `operands` with `reducer`, dispatching on the execution backend.

   Uses `jax.lax.reduce` under JAX; `_variadic_reduce` when executing eagerly
   or when the default graph is not in an XLA context; `_xla_reduce` otherwise.
   """
   if JAX_MODE:
     from jax import lax  # pylint: disable=g-import-not-at-top
     return lax.reduce(
         operands, init_values=inits, dimensions=axis, computation=reducer)
   use_generic_reduce = (
       tf.executing_eagerly() or
       not control_flow_util.GraphOrParentsInXlaContext(
           tf1.get_default_graph()))
   if use_generic_reduce:
     return _variadic_reduce(
         operands, init=inits, axis=axis, reducer=reducer)
   return _xla_reduce(operands, inits, axis)
示例#8
0
            def f(x):
                """Builds a while loop with a nested cond under XLA.

                Returns the second loop variable after at most 10 iterations.
                """
                assert control_flow_util.GraphOrParentsInXlaContext(
                    ops.get_default_graph())
                x = ops.convert_to_tensor(x)

                def loop_body(step, acc):
                    next_step = step + 1
                    # After step 2, accumulate x**2 instead of the constant 3.
                    next_acc = control_flow_ops.cond(
                        step > 2, lambda: acc + (x**2), lambda: acc + 3)
                    return next_step, next_acc

                initial_vars = (constant_op.constant(0),
                                constant_op.constant(3.))
                final_vars = control_flow_ops.while_loop(
                    lambda step, *_: step < 10,
                    loop_body, initial_vars,
                    maximum_iterations=10)
                return final_vars[1]
示例#9
0
def _xla_compile(func: Callable, *args: TfVal) -> TfVal:
    """Runs `func(*args)` under XLA compilation.

    This works around cases where, without XLA, a TF op behaves differently
    from what JAX expects. When the default graph (or one of its parents) is
    already in an XLA context, `func` is called directly instead of invoking
    XLA again.
    """
    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
        return func(*args)
    (compiled_result,) = tf.xla.experimental.compile(func, args)
    return compiled_result
示例#10
0
    def _log_unnorm_prob(self, x, concentration, name=None):
        """Returns the unnormalized log density of an LKJ distribution.

        The unnormalized density is det(x) ** (concentration - 1), so this
        returns (concentration - 1) * log det(x).

        Args:
          x: `float` or `double` `Tensor` of correlation matrices with shape
            `B + [D, D]`, where `B` broadcasts with the shape of
            `concentration`. When `self.input_output_cholesky` is True, the
            log-determinant is computed as 2 * sum(log(diag(x))), i.e. `x` is
            treated as a triangular Cholesky factor.
          concentration: `float` or `double` `Tensor`. The positive
            concentration parameter of the LKJ distributions.
          name: Python `str` name prefixed to Ops created by this function.

        Returns:
          log_p: A Tensor of the unnormalized log density of each matrix
            element of `x`, with respect to an LKJ distribution with parameter
            the corresponding element of `concentration`.
        """
        with tf.name_scope(name or 'log_unnorm_prob_lkj'):
            x = tf.convert_to_tensor(x, name='x')
            if self.input_output_cholesky:
                # det(L L^T) is the squared product of diag(L).
                logdet = 2.0 * tf.reduce_sum(
                    tf.math.log(tf.linalg.diag_part(x)), axis=[-1])
            elif (not tf.executing_eagerly()
                  and control_flow_util.GraphOrParentsInXlaContext(
                      tf1.get_default_graph())):
                # TODO(b/162937268): Remove the hackaround.
                # Under XLA, take log det as the sum of log singular values.
                singular_values = tf.linalg.svd(x, compute_uv=False)
                logdet = tf.math.reduce_sum(tf.math.log(singular_values), -1)
            else:
                # `logdet` can reject PSD matrices with near-zero eigenvalues
                # because of its internal Cholesky decomposition, so the
                # slower but more robust `slogdet` (no Cholesky) is used.
                # An alternative would be masking non-PSD inputs to -inf via
                # `self_adjoint_eigvals`, but those can come out slightly
                # negative even for PSD matrices.
                logdet = tf.linalg.slogdet(x).log_abs_determinant
            return (concentration - 1.) * logdet
示例#11
0
def maybe_propagate_compile_time_consts_in_xla(op):
    """Asks XLA to propagate compile-time constants into a While loop body.

    Ops inside the loop body (e.g. `max_num_elements` of `EmptyTensorList`)
    may need compile-time constants. Ideally this would always be enabled,
    but it does not work with legacy functionalized while_loops. The
    attribute is only set when `op.graph` (or one of its parents) is in an
    XLA context.

    Args:
      op: A `While` Operation; its attributes are modified in place.
    """
    if not control_flow_util.GraphOrParentsInXlaContext(op.graph):
        return
    propagate_attr = attr_value_pb2.AttrValue(b=True)
    op._set_attr("_xla_propagate_compile_time_consts", propagate_attr)  # pylint: disable=protected-access
示例#12
0
def smart_for_loop(loop_num_iter,
                   body_fn,
                   initial_loop_vars,
                   parallel_iterations=10,
                   name=None):
    """Construct a for loop, preferring a python loop if `n` is statically known.

    Given `loop_num_iter` and `body_fn`, return an op corresponding to executing
    `body_fn` `loop_num_iter` times, feeding previous outputs of `body_fn` into
    the next iteration.

    If `loop_num_iter` is statically known, the op is constructed via python for
    loop, and otherwise a `tf.while_loop` is used (also when executing eagerly
    or inside an XLA context).

    Args:
      loop_num_iter: `Integer` `Tensor` representing the number of loop
        iterations.
      body_fn: Callable to be executed `loop_num_iter` times.
      initial_loop_vars: Listlike object (e.g. `list` or `tuple`) of `Tensors`
        to be passed in to `body_fn`'s first execution.
      parallel_iterations: The number of iterations allowed to run in parallel.
        It must be a positive integer. See `tf.while_loop` for more details.
        Default value: `10`.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., "smart_for_loop").

    Returns:
      result: `Tensor` representing applying `body_fn` iteratively `n` times.
    """
    with tf.compat.v1.name_scope(name, 'smart_for_loop',
                                 [loop_num_iter, initial_loop_vars]):
        loop_num_iter_ = tf.get_static_value(
            tf.convert_to_tensor(value=loop_num_iter,
                                 dtype=tf.int64,
                                 name='loop_num_iter'))
        if (loop_num_iter_ is None or tf.executing_eagerly()
                or control_flow_util.GraphOrParentsInXlaContext(
                    tf.compat.v1.get_default_graph())):
            return tf.while_loop(
                cond=lambda i, *args: i < loop_num_iter,
                body=lambda i, *args: [i + 1] + list(body_fn(*args)),
                # list() so tuple (or other listlike) inputs can be
                # concatenated; `[x] + some_tuple` raises TypeError.
                loop_vars=[np.int64(0)] + list(initial_loop_vars),
                parallel_iterations=parallel_iterations)[1:]
        result = initial_loop_vars
        for _ in range(loop_num_iter_):
            result = body_fn(*result)
        return result
示例#13
0
def _known_len_tf_for_stmt(iter_, extra_test, body, get_state, set_state,
                           symbol_names, opts):
    """Overload of for_stmt that iterates over TF entities that admit a length.

    The elements of `iter_` are unstacked into a TensorArray and read by an
    explicit index that is threaded through the while loop as an extra,
    internal loop variable.

    Args:
      iter_: Tensor to iterate over; its length is obtained with
        `py_builtins.len_`.
      extra_test: Optional zero-argument callable; when not None, it is only
        evaluated (via `control_flow_ops.cond`) while the index is in range.
      body: Callable invoked with the current element once per iteration.
      get_state: Zero-argument callable returning the tuple of loop variables.
      set_state: Callable receiving the updated tuple of loop variables.
      symbol_names: Tuple of names for the values returned by `get_state`.
      opts: Dict of loop options; outside XLA contexts
        `opts['maximum_iterations']` is set in place.
    """
    n = py_builtins.len_(iter_)

    # TODO(b/117628877): Revisit performance once XLA has the necessary support.
    # Note: using a TensorArray creates an extra copy, but can calculate
    # gradients more efficiently than StridedSlice.
    ta = tensor_array_ops.TensorArray(iter_.dtype, size=n)
    iter_ = ta.unstack(iter_)

    # Mutable holder for the element index, shared by the closures below.
    iterate_index = compat_util.BasicRef(0)

    def aug_get_state():
        # Prepend the internal index to the user-visible loop state.
        return (iterate_index.value, ) + get_state()

    def aug_set_state(aug_loop_vars):
        # TODO(mdan): Use starred assignment once we can switch to Py3-only syntax.
        iterate_index.value, loop_vars = aug_loop_vars[0], aug_loop_vars[1:]
        # The iteration index is not "output" by the for loop. If the iterate
        # is used outside the loop, it will appear in the loop vars separately.
        set_state(loop_vars)

    def aug_body():
        # Run the user body on the current element, then advance the index.
        body(iter_.read(iterate_index.value))
        iterate_index.value += 1

    def aug_test():
        main_test = iterate_index.value < n
        # `extra_test` only runs when the index is still in range.
        if extra_test is not None:
            return control_flow_ops.cond(main_test, extra_test, lambda: False)
        return main_test

    # TODO(b/159186914): Remove.
    if not control_flow_util.GraphOrParentsInXlaContext(
            ops.get_default_graph()):
        opts['maximum_iterations'] = n

    _tf_while_stmt(
        aug_test,
        aug_body,
        aug_get_state,
        aug_set_state,
        ('<internal iterate>', ) + symbol_names,
        opts,
    )
示例#14
0
def _tf_ragged_for_stmt(
    iter_, extra_test, body, get_state, set_state, symbol_names, opts):
  """Overload of for_stmt that iterates over TF ragged tensors.

  Iterates over the rows (outermost dimension) of `iter_` with a while loop
  over an explicit row index.

  Args:
    iter_: The ragged tensor to iterate over; `body` receives `iter_[i]` for
      each row index `i`.
    extra_test: Optional zero-argument callable; when not None, it is only
      evaluated (via `control_flow_ops.cond`) while the index is in range.
    body: Callable invoked with the current row once per iteration.
    get_state: Zero-argument callable returning the tuple of loop variables.
    set_state: Callable receiving the updated tuple of loop variables.
    symbol_names: Tuple of names for the values returned by `get_state`.
    opts: Dict of loop options; outside XLA contexts
      `opts['maximum_iterations']` is set in place.
  """
  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  # TODO(mdan): Move this into len()? Requires eager support.
  if iter_.shape and iter_.shape[0] is not None:
    n = iter_.shape[0]
  else:
    # The loop runs once per row, so the bound is the number of rows.
    # (`row_lengths()[0]` is the length of the *first row*, not the row
    # count.) `nrows()` has the same dtype as `row_lengths()`.
    n = iter_.nrows()

  # Mutable holder for the row index, shared by the closures below.
  iterate_index = compat_util.BasicRef(0)

  def aug_get_state():
    return (iterate_index.value,) + get_state()

  def aug_set_state(aug_loop_vars):
    # TODO(mdan): Use starred assignment once we can switch to Py3-only syntax.
    iterate_index.value, loop_vars = aug_loop_vars[0], aug_loop_vars[1:]
    # The iteration index is not "output" by the for loop. If the iterate
    # is used outside the loop, it will appear in the loop vars separately.
    set_state(loop_vars)

  def aug_body():
    body(iter_[iterate_index.value])
    iterate_index.value += 1

  def aug_test():
    main_test = iterate_index.value < n
    if extra_test is not None:
      return control_flow_ops.cond(main_test, extra_test, lambda: False)
    return main_test

  # TODO(b/159186914): Remove.
  if not control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    opts['maximum_iterations'] = n

  _tf_while_stmt(
      aug_test,
      aug_body,
      aug_get_state,
      aug_set_state,
      ('<internal iterate>',) + symbol_names,
      opts)
示例#15
0
def outside_compilation_scope(name="outside"):
  """Scope that places the enclosed operations on the host.

  Operations created inside this scope are placed on the default host device,
  outside the current compilation scope. Computations that the IPU does not
  support, or is not suited to, can be offloaded to the host this way.

  Example:

  .. code-block:: python

    def my_net(a):
      with ipu_scope("/device:IPU:0"):
        b = a * a
        with outside_compilation_scope():
          c = b + 2  # Placed on the host.
        d = b + c
        return d

  Args:
    name: Base name for the outside compilation scope; it is made unique
      within the default graph.

  Returns:
    A context

  Raises:
    ValueError: if the default graph is not in an XLA context, or if an
      outside compilation scope is already active.
  """
  graph = ops.get_default_graph()

  if not control_flow_util.GraphOrParentsInXlaContext(graph):
    raise ValueError(
        "outside_compilation_scope is only allowed in XLA context")

  # pylint: disable=protected-access
  if OUTSIDE_COMPILATION_NAME in graph._attr_scope_map:
    raise ValueError("Illegal nesting of outside_compilation_scope")

  scope_name = graph.unique_name(name, mark_as_used=True)
  scope_attrs = {
      OUTSIDE_COMPILATION_NAME:
          attr_value_pb2.AttrValue(s=scope_name.encode()),
  }

  # The name scope lowers the risk of op-name collisions when these ops are
  # later moved from the current graph to the outside graph.
  with ops.name_scope(scope_name):
    with graph._attr_scope(scope_attrs):
      yield
  # pylint: enable=protected-access
示例#16
0
  def _maybe_warn_increased_dof(self,
                                component_name,
                                component_ldj,
                                increased_dof):
    """Raises, asserts, or warns when `increased_dof` may be True.

    Args:
      component_name: Name of the nested component, used in the message.
      component_ldj: The component's log-det-jacobian; the check is skipped
        when it is statically a scalar zero.
      increased_dof: Boolean (Python or `Tensor`) flag for increased degrees
        of freedom.

    Returns:
      None when the check is statically unnecessary, or when running under XLA
      without `validate_args`. With `validate_args`, an `assert_equal` op
      (unless the flag is statically truthy). Otherwise, the result of a
      `ps.cond` that prints a warning to stderr when the flag is True.

    Raises:
      ValueError: if `self._validate_args` is set and `increased_dof` is
        statically truthy.
    """
    ldj_is_static_zero = (tf.get_static_value(tf.rank(component_ldj)) == 0
                          and tf.get_static_value(component_ldj) == 0)
    if ldj_is_static_zero:
      return

    static_dof = tf.get_static_value(increased_dof)
    if static_dof is False:  # pylint: disable=g-bool-id-comparison
      return

    message = (
        'Nested component "{}" in composition "{}" operates on inputs '
        'with increased degrees of freedom. This may result in an '
        'incorrect log_det_jacobian.'
    ).format(component_name, self.name)

    if self._validate_args:
      # Fail immediately if statically known, otherwise at run time.
      if static_dof:
        raise ValueError(message)
      return assert_util.assert_equal(False, increased_dof, message)

    # XLA has no StringFormat or Print ops, so no warning can be emitted.
    in_xla = (not tf.executing_eagerly() and
              control_flow_util.GraphOrParentsInXlaContext(
                  tf1.get_default_graph()))
    if in_xla:
      return

    warning = 'WARNING: ' + message
    return ps.cond(
        pred=increased_dof,
        true_fn=lambda: tf.print(warning, output_stream=sys.stderr),
        false_fn=tf.no_op)
示例#17
0
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            current_state,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  Adaptation runs in windows: one fast window, four slow windows whose sizes
  double each time (these also update the momentum distribution from running
  variances), and a final fast window. Each window's resulting step size is
  carried forward. Then `n_draws` sampling steps are run with the adapted
  kernel settings.

  Args:
    n_draws: Number of draws in the final sampling phase.
    joint_dist: Joint distribution to sample from; passed to `_setup_mcmc`.
    kind: Sampler kind forwarded to the window helpers (HMC or NUTS).
    n_chains: Number of chains; also the leading batch dimension of the
      momentum distribution.
    proposal_kernel_kwargs: Dict of kernel kwargs. It is updated in place with
      the target log prob, step size, and momentum distribution.
    num_adaptation_steps: Total adaptation steps, split into windows by
      `_get_window_sizes`.
    current_state: Optional initial position, passed to `_setup_mcmc`.
    dual_averaging_kwargs: Forwarded to the fast and slow window helpers.
    trace_fn: Optional trace function; when None, nothing is traced and only
      states are returned.
    return_final_kernel_results: Whether to also return the final kernel
      results.
    discard_tuning: When True, only the draws from the final sampling phase
      are returned; otherwise adaptation draws and traces are concatenated in
      front of them.
    seed: Seed, sanitized and split for each phase.
    **pins: Values pinned in `joint_dist`, forwarded to `_setup_mcmc`.

  Returns:
    `sample.CheckpointableStatesAndTrace` when `return_final_kernel_results`
    is True; otherwise `sample.StatesAndTrace`, or just the states when
    `trace_fn` is None.

  Raises:
    ValueError: if no `step_size` is given in `proposal_kernel_kwargs` and the
      target has a non-scalar batch shape.
  """
  # With no user trace_fn, trace nothing and later return only the states.
  if trace_fn is None:
    trace_fn = lambda *args: ()
    no_trace = True
  else:
    no_trace = False

  if (tf.executing_eagerly() or
      not control_flow_util.GraphOrParentsInXlaContext(
          tf1.get_default_graph())):
    # A Tensor num_draws argument breaks XLA, which requires static TensorArray
    # trace_fn result allocation sizes.
    num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)

  # Separate seeds for setup, the first fast window, and everything after.
  setup_seed, init_seed, seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=3)
  (target_log_prob_fn, initial_transformed_position, bijector,
   step_broadcast, batch_shape) = _setup_mcmc(
       joint_dist,
       n_chains=n_chains,
       init_position=current_state,
       seed=setup_seed,
       **pins)

  # Without a user step size, derive one from the initial position; this is
  # only supported for a scalar batch shape.
  if proposal_kernel_kwargs.get('step_size') is None:
    if batch_shape.shape != (0,):  # Scalar batch has a 0-vector shape.
      raise ValueError('Batch target density must specify init_step_size. Got '
                       f'batch shape {batch_shape} from joint {joint_dist}.')

    init_step_size = _get_step_size(initial_transformed_position,
                                    target_log_prob_fn)

  else:
    init_step_size = step_broadcast(proposal_kernel_kwargs['step_size'])

  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': init_step_size,
      'momentum_distribution': _init_momentum(
          initial_transformed_position,
          batch_shape=ps.concat([[n_chains], batch_shape], axis=0))})

  first_window_size, slow_window_size, last_window_size = _get_window_sizes(
      num_adaptation_steps)

  all_traces = []
  # Using tf.function here and on _slow_window_closure caches tracing
  # of _fast_window and _slow_window, respectively, within a single
  # call to windowed sampling.  Why not annotate _fast_window and
  # _slow_window directly?  Two reasons:
  # - Caching across calls to windowed sampling is probably futile,
  #   because the trace function and bijector will be different Python
  #   objects, preventing cache hits.
  # - The cache of a global tf.function sticks around for the lifetime
  #   of the Python process, potentially leaking memory.
  @tf.function(autograph=False)
  def _fast_window_closure(proposal_kernel_kwargs,
                           window_size,
                           initial_position,
                           seed):
    return _fast_window(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=window_size,
        initial_position=initial_position,
        bijector=bijector,
        trace_fn=trace_fn,
        seed=seed)
  # First fast window.
  draws, trace, step_size, running_variances = _fast_window_closure(
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      window_size=first_window_size,
      initial_position=initial_transformed_position,
      seed=init_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})

  # One list of draws per state part; each window appends to it.
  all_draws = [[d] for d in draws]
  all_traces.append(trace)
  # Four slow-window seeds plus the seed for the remaining phases.
  *slow_seeds, seed = samplers.split_seed(seed, n=5)
  @tf.function(autograph=False)
  def _slow_window_closure(proposal_kernel_kwargs,
                           window_size,
                           initial_position,
                           running_variances,
                           seed):
    return _slow_window(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=window_size,
        initial_position=initial_position,
        initial_running_variance=running_variances,
        bijector=bijector,
        trace_fn=trace_fn,
        seed=seed)
  for idx, slow_seed in enumerate(slow_seeds):
    # Slow window sizes double each time.
    window_size = slow_window_size * (2**idx)

    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    (draws, trace, step_size, running_variances, momentum_distribution
     ) = _slow_window_closure(
         proposal_kernel_kwargs=proposal_kernel_kwargs,
         window_size=window_size,
         initial_position=[d[-1] for d in draws],
         running_variances=running_variances,
         seed=slow_seed)
    for all_d, d in zip(all_draws, draws):
      all_d.append(d)
    all_traces.append(trace)
    proposal_kernel_kwargs.update(
        {'step_size': step_size,
         'momentum_distribution': momentum_distribution})

  # Final fast window, continuing from the last slow-window draws.
  fast_seed, sample_seed = samplers.split_seed(seed)
  draws, trace, step_size, _ = _fast_window_closure(
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      window_size=last_window_size,
      initial_position=[d[-1] for d in draws],
      seed=fast_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})
  for all_d, d in zip(all_draws, draws):
    all_d.append(d)
  all_traces.append(trace)

  # Sampling phase with the adapted kernel settings.
  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      num_draws=n_draws,
      initial_position=[d[-1] for d in draws],
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      seed=sample_seed)

  if discard_tuning:
    # Return only the sampling-phase draws, mapped back through the bijector.
    if return_final_kernel_results:
      draws, trace, fkr = ret
      return sample.CheckpointableStatesAndTrace(
          all_states=bijector.inverse(draws),
          trace=trace,
          final_kernel_results=fkr)
    else:
      draws, trace = ret
      if no_trace:
        return bijector.inverse(draws)
      else:
        return sample.StatesAndTrace(all_states=bijector.inverse(draws),
                                     trace=trace)
  else:
    # Concatenate adaptation and sampling draws/traces along axis 0.
    if return_final_kernel_results:
      draws, trace, fkr = ret
      for all_d, d in zip(all_draws, draws):
        all_d.append(d)
      all_traces.append(trace)
      return sample.CheckpointableStatesAndTrace(
          all_states=bijector.inverse(
              [tf.concat(d, axis=0) for d in all_draws]),
          trace=tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                      *all_traces, expand_composites=True),
          final_kernel_results=fkr)
    else:
      draws, trace = ret
      for all_d, d in zip(all_draws, draws):
        all_d.append(d)
      all_states = bijector.inverse([tf.concat(d, axis=0) for d in all_draws])
      if no_trace:
        return all_states
      else:
        all_traces.append(trace)
        return sample.StatesAndTrace(
            all_states=all_states,
            trace=tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                        *all_traces, expand_composites=True))
示例#18
0
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of an If op produced by cond_v2.

    Builds a gradient function graph for each branch of the forward If op.
    If either gradient graph needs forward intermediates, the forward op is
    rewritten so that both branches output them. A new cond is then built over
    the two gradient graphs, gated on the forward op's predicate
    (`if_op.inputs[0]`).

    Args:
      op: The forward If operation, or an object whose `outputs[0].op` is the
        forward If operation (e.g. a MockOp).
      *grads: Incoming gradients, passed to `_create_grad_func` for both
        branches (presumably one per output of `op` -- confirm).

    Returns:
      A list whose first element is `None` (the predicate has no gradient),
      followed by the outputs of the gradient cond built by `_build_cond`.
    """
    # Get the if operator (this logic handles the case where op is a MockOp)
    if_op = op.outputs[0].op
    true_graph, false_graph = get_func_graphs(if_op)
    # Note: op.graph != ops.get_default_graph() when we are computing the gradient
    # of a nested cond.
    assert true_graph.outer_graph == if_op.graph
    assert false_graph.outer_graph == if_op.graph

    # Create grad functions that compute the gradient of the true/false forward
    # graphs. These functions will capture tensors from the forward pass
    # functions.
    true_grad_graph = _create_grad_func(
        true_graph, grads, util.unique_grad_fn_name(true_graph.name))
    false_grad_graph = _create_grad_func(
        false_graph, grads, util.unique_grad_fn_name(false_graph.name))

    if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
        # Modify 'op' to output the intermediates needed by the grad functions. Note
        # that all needed intermediates are wrapped in optionals. Each optional
        # intermediate output will have a value iff its corresponding branch is
        # taken.
        # NOTE(skyewm): if there are any active sessions, this modification to `op`
        # may make them unrunnable!

        if control_flow_util.GraphOrParentsInXlaContext(
                ops.get_default_graph()):
            # XLA does not yet support optionals, so output intermediates directly and
            # make them match via FakeParams, which can be converted to zeros in XLA.
            # TODO(skyewm,jpienaar): can XLA support optionals?
            true_intermediates = true_grad_graph.xla_intermediates
            false_intermediates = false_grad_graph.xla_intermediates
            extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
                [true_graph, false_graph],
                [true_intermediates, false_intermediates])
        else:
            true_intermediates = true_grad_graph.wrapped_intermediates
            false_intermediates = false_grad_graph.wrapped_intermediates
            # Make outputs match by adding none optionals.
            extra_true_outputs, extra_false_outputs = _make_intermediates_match(
                [true_graph, false_graph],
                [true_intermediates, false_intermediates])

        # Append the extra intermediates to each forward branch's outputs.
        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)
        # TODO(skyewm): indicate it's an internal bug if this fails.
        _check_same_outputs(_COND, [true_graph, false_graph])

        # Rename the mutated branch graphs before re-registering them as the
        # op's branch functions.
        true_graph.name += "_rewritten"
        false_graph.name += "_rewritten"

        if_op._set_func_attr("then_branch",
                             util.create_new_tf_function(true_graph))
        if_op._set_func_attr("else_branch",
                             util.create_new_tf_function(false_graph))
        # The branch outputs were checked for agreement above, so the true
        # branch's output types/shapes describe the rewritten op.
        if_op._set_type_list_attr("Tout", true_graph.output_types)
        if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes)
        # Register the newly appended outputs on the forward op itself.
        if_op._add_outputs([t.dtype for t in extra_true_outputs],
                           [t.shape for t in extra_true_outputs])

    # Resolve references to forward graph tensors in grad graphs and ensure
    # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
    true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
    false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)

    # This modifies true_grad_graph and false_grad_graph.
    _make_output_composite_tensors_match(_COND,
                                         [true_grad_graph, false_grad_graph])

    # Build the gradient cond, gated on the same predicate as the forward op.
    outputs = _build_cond(
        if_op.inputs[0],
        true_grad_graph,
        false_grad_graph,
        true_grad_inputs,
        false_grad_inputs,
        building_gradient=True,
    )

    # The predicate has no gradient.
    return [None] + outputs
 def in_xla_context(self):
     """Reports whether the TF1 default graph, or a parent of it, is in XLA.

     Delegates to `control_flow_util.GraphOrParentsInXlaContext`; `self` is
     not consulted.
     """
     default_graph = tf1.get_default_graph()
     return control_flow_util.GraphOrParentsInXlaContext(default_graph)
示例#20
0
def dynamic_decode(decoder,
                   output_time_major: bool = False,
                   impute_finished: bool = False,
                   maximum_iterations=None,
                   parallel_iterations: int = 32,
                   swap_memory: bool = False,
                   training=None,
                   scope=None,
                   **kwargs):
    """Perform dynamic decoding with `decoder`.

    Calls initialize() once and step() repeatedly on the Decoder object.

    Args:
      decoder: A `Decoder` instance.
      output_time_major: Python boolean.  Default: `False` (batch major). If
        `True`, outputs are returned as time major tensors (this mode is
        faster). Otherwise, outputs are returned as batch major tensors (this
        adds extra time to the computation).
      impute_finished: Python boolean.  If `True`, then states for batch
        entries which are marked as finished get copied through and the
        corresponding outputs get zeroed out.  This causes some slowdown at
        each time step, but ensures that the final state and outputs have
        the correct values and that backprop ignores time steps that were
        marked as finished.
      maximum_iterations: A strictly positive `int32` scalar, the maximum
         allowed number of decoding steps. Default is `None` (decode until the
         decoder is fully done). Required when building inside an XLA
         context.
      parallel_iterations: Argument passed to `tf.while_loop`.
      swap_memory: Argument passed to `tf.while_loop`.
      training: Python boolean. Indicates whether the layer should behave
          in training  mode or in inference mode. Only relevant
          when `dropout` or `recurrent_dropout` is used.
      scope: Optional name scope to use.
      **kwargs: dict, other keyword arguments for dynamic_decode. It might
        contain arguments for `BaseDecoder` to initialize, which takes all
        tensor inputs during call().

    Returns:
      `(final_outputs, final_state, final_sequence_lengths)`.

    Raises:
      ValueError: if `maximum_iterations` is provided but is not a scalar, or
        if it is `None` while building inside an XLA context.
    """
    with variable_scope.variable_scope(scope, 'decoder') as varscope:
        # Detect whether we are being built inside an enclosing while loop.
        ctxt = ops.get_default_graph()._get_control_flow_context()
        in_while_loop = (control_flow_util.GetContainingWhileContext(ctxt)
                         is not None)

        # Cache variable reads on the consuming op's device. This is skipped
        # inside an enclosing while loop (presumably so every iteration
        # re-reads current variable values instead of a cached copy).
        if not context.executing_eagerly() and not in_while_loop:
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        # Single source of truth for XLA detection (the former
        # GetContainingXLAContext-based value was dead: always overwritten).
        is_xla = (not tf.executing_eagerly()
                  and control_flow_util.GraphOrParentsInXlaContext(
                      tf.compat.v1.get_default_graph()))

        if maximum_iterations is not None:
            maximum_iterations = tf.convert_to_tensor(
                maximum_iterations,
                dtype=tf.int32,
                name='maximum_iterations',
            )
            if maximum_iterations.shape.ndims != 0:
                raise ValueError('maximum_iterations must be a scalar')
            tf.debugging.assert_greater(
                maximum_iterations,
                0,
                message='maximum_iterations should be greater than 0',
            )
        elif is_xla:
            raise ValueError(
                'maximum_iterations is required for XLA compilation.')

        if isinstance(decoder, Decoder):
            initial_finished, initial_inputs, initial_state = (
                decoder.initialize())
        else:
            # For BaseDecoder that takes tensor inputs during call.
            decoder_init_input = kwargs.pop('decoder_init_input', None)
            decoder_init_kwargs = kwargs.pop('decoder_init_kwargs', {})
            initial_finished, initial_inputs, initial_state = decoder.initialize(
                decoder_init_input, **decoder_init_kwargs)

        # Per-step outputs substituted for finished entries when
        # `impute_finished` is set.
        zero_outputs = tf.nest.map_structure(
            lambda shape, dtype: tf.zeros(
                _prepend_batch(decoder.batch_size, shape), dtype=dtype),
            decoder.output_size,
            decoder.output_dtype,
        )

        if maximum_iterations is not None:
            initial_finished = tf.logical_or(initial_finished,
                                             0 >= maximum_iterations)
        initial_sequence_lengths = tf.zeros_like(initial_finished,
                                                 dtype=tf.int32)
        initial_time = tf.constant(0, dtype=tf.int32)

        def _shape(batch_size, from_shape):
            # Static element shape `[batch_size] + from_shape`, or None when
            # `from_shape` is not a TensorShape or is a scalar shape.
            if (not isinstance(from_shape, tf.TensorShape)
                    or from_shape.ndims == 0):
                return None
            else:
                batch_size = tf.get_static_value(
                    tf.convert_to_tensor(batch_size, name='batch_size'))
                return tf.TensorShape([batch_size]).concatenate(from_shape)

        # A statically sized TensorArray is only used when compiling for XLA
        # with a known bound; otherwise the array grows as steps are written.
        dynamic_size = maximum_iterations is None or not is_xla

        def _create_ta(s, d):
            return tf.TensorArray(
                dtype=d,
                size=0 if dynamic_size else maximum_iterations,
                dynamic_size=dynamic_size,
                element_shape=_shape(decoder.batch_size, s),
            )

        initial_outputs_ta = tf.nest.map_structure(_create_ta,
                                                   decoder.output_size,
                                                   decoder.output_dtype)

        def condition(
            unused_time,
            unused_outputs_ta,
            unused_state,
            unused_inputs,
            finished,
            unused_sequence_lengths,
        ):
            # Keep looping while at least one batch entry is unfinished.
            return tf.logical_not(tf.reduce_all(finished))

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

            Args:
              time: scalar int32 tensor.
              outputs_ta: structure of TensorArray.
              state: (structure of) state tensors and TensorArrays.
              inputs: (structure of) input tensors.
              finished: bool tensor (keeping track of what's finished).
              sequence_lengths: int32 tensor (keeping track of time of finish).

            Returns:
              `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
                next_sequence_lengths)`.
            """
            (
                next_outputs,
                decoder_state,
                next_inputs,
                decoder_finished,
            ) = decoder.step(time, inputs, state, training)
            decoder_state_sequence_lengths = False
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
                lengths = getattr(decoder_state, 'lengths', None)
                if lengths is not None:
                    # sequence lengths are provided by decoder_state.lengths;
                    # overwrite our sequence lengths.
                    decoder_state_sequence_lengths = True
                    sequence_lengths = tf.cast(lengths, tf.int32)
            else:
                next_finished = tf.logical_or(decoder_finished, finished)

            if decoder_state_sequence_lengths:
                # Just pass something through the loop; at the next iteration
                # we'll pull the sequence lengths from the decoder_state again.
                next_sequence_lengths = sequence_lengths
            else:
                # Entries still running at this step get length `time + 1`.
                next_sequence_lengths = tf.where(
                    tf.logical_not(finished),
                    tf.fill(tf.shape(sequence_lengths), time + 1),
                    sequence_lengths,
                )

            tf.nest.assert_same_structure(state, decoder_state)
            tf.nest.assert_same_structure(outputs_ta, next_outputs)
            tf.nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:

                def zero_out_finished(out, zero):
                    if finished.shape.rank < zero.shape.rank:
                        broadcast_finished = tf.broadcast_to(
                            tf.expand_dims(finished, axis=-1), zero.shape)
                        return tf.where(broadcast_finished, zero, out)
                    else:
                        return tf.where(finished, zero, out)

                emit = tf.nest.map_structure(zero_out_finished, next_outputs,
                                             zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tf.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = new.shape.ndims == 0
                if not pass_through:
                    broadcast_finished = tf.broadcast_to(
                        tf.expand_dims(finished, axis=-1), new.shape)
                    return tf.where(broadcast_finished, cur, new)
                else:
                    return new

            if impute_finished:
                next_state = tf.nest.map_structure(_maybe_copy_state,
                                                   decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = tf.nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (
                time + 1,
                outputs_ta,
                next_state,
                next_inputs,
                next_finished,
                next_sequence_lengths,
            )

        res = tf.while_loop(
            condition,
            body,
            loop_vars=(
                initial_time,
                initial_outputs_ta,
                initial_state,
                initial_inputs,
                initial_finished,
                initial_sequence_lengths,
            ),
            parallel_iterations=parallel_iterations,
            maximum_iterations=maximum_iterations,
            swap_memory=swap_memory,
        )

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        final_outputs = tf.nest.map_structure(lambda ta: ta.stack(),
                                              final_outputs_ta)

        # `finalize` is optional for decoders; NotImplementedError means
        # "keep the stacked outputs and final state as-is".
        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths)
        except NotImplementedError:
            pass

        if not output_time_major:
            final_outputs = tf.nest.map_structure(_transpose_batch_time,
                                                  final_outputs)

    return final_outputs, final_state, final_sequence_lengths
示例#21
0
def _create_keras_history_helper(tensors, processed_ops, created_layers):
    """Helper method for `create_keras_history`.

    Walks backwards from `tensors` through the ops that produced them. Every
    op not already in `processed_ops` (and whose output lacks
    `_keras_history`) is wrapped in a `TensorFlowOpLayer`. Inputs that do not
    use Keras history are captured as constants of that layer.

    Arguments:
      tensors: A structure of Tensors for which to create Keras metadata.
      processed_ops: Set. TensorFlow operations that have already been wrapped
        in `TensorFlowOpLayer` instances.
      created_layers: List. The `TensorFlowOpLayer` instances created.

    Returns:
      Tuple. First element is the updated set of TensorFlow Operations that
      have been wrapped in `TensorFlowOpLayer` instances. Second element is
      a list of the `TensorFlowOpLayer` instances created.

    Raises:
      ValueError: if an op that needs wrapping has a type starting with
        'Sparse'.
    """
    # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
    # Cannot be imported at top because of circular dependencies.
    # TODO(omalleyt): Resolve circular dependency.
    from tensorflow.python.keras.engine import base_layer  # pylint: disable=g-import-not-at-top
    tensor_list = nest.flatten(tensors)
    for tensor in tensor_list:
        if getattr(tensor, '_keras_history', None) is not None:
            continue
        op = tensor.op  # The Op that created this Tensor.
        if op in processed_ops:
            continue
        if op.type.startswith('Sparse'):
            lambda_example = """
        weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
        output = tf.keras.layers.Lambda(weights_mult)(input)
        """
            raise ValueError(
                'Sparse ops are not supported with functional models with built-in '
                'layer wrapping. Please wrap the sparse ops in a Lambda layer like'
                ': \n{lambda_example}\n'.format(
                    lambda_example=lambda_example))

        # These context checks do not depend on the individual input, so they
        # are evaluated once per op rather than once per non-Keras input.
        ds_with_session = (
            distribution_strategy_context.in_cross_replica_context() and
            not ops.executing_eagerly_outside_functions())
        using_xla = control_flow_util.GraphOrParentsInXlaContext(
            ops.get_default_graph())

        # Recursively set `_keras_history`.
        op_inputs = list(op.inputs)
        constants = {}
        layer_inputs = []
        for i, op_input in enumerate(op_inputs):
            if uses_keras_history(op_input):
                layer_inputs.append(op_input)
            # Treat any value not originating from a `keras.Input` as
            # a constant. Variables cannot be supported.
            elif ds_with_session or using_xla:
                # In Legacy Graph mode, evaluating here makes Session be
                # configured improperly. The downside of this is that saving
                # via `get_config` breaks, but SavedModel still works.
                constants[i] = op_input
            else:
                with ops.init_scope():
                    if ops.executing_eagerly_outside_functions():
                        constants[i] = backend.eval_in_eager_or_function(
                            op_input)
                    else:
                        constants[i] = backend.function([], op_input)([])
        layer_inputs = unnest_if_single_tensor(layer_inputs)
        processed_ops, created_layers = _create_keras_history_helper(
            layer_inputs, processed_ops, created_layers)
        name = op.name
        node_def = op.node_def.SerializeToString()
        op_layer = base_layer.TensorFlowOpLayer(node_def,
                                                constants=constants,
                                                name=name)
        created_layers.append(op_layer)
        op_layer._set_connectivity_metadata(  # pylint: disable=protected-access
            args=(layer_inputs, ),
            kwargs={},
            outputs=op.outputs)
        processed_ops.update([op])
    return processed_ops, created_layers
def _is_on_tpu():
    """Reports whether the default graph, or one of its parents, is in XLA.

    Despite its name, this only checks for an XLA context via
    `control_flow_util.GraphOrParentsInXlaContext` (presumably used as a
    proxy for TPU execution).
    """
    graph = framework_ops.get_default_graph()
    return control_flow_util.GraphOrParentsInXlaContext(graph)
示例#23
0
def _windowed_adaptive_impl(n_draws, joint_dist, *, kind, n_chains,
                            proposal_kernel_kwargs, num_adaptation_steps,
                            current_state, dual_averaging_kwargs, trace_fn,
                            return_final_kernel_results, discard_tuning, seed,
                            chain_axis_names, **pins):
    """Runs windowed sampling using either HMC or NUTS as internal sampler.

    Args:
      n_draws: Number of draws to return after adaptation. When
        `discard_tuning` is `False`, the `num_adaptation_steps` tuning draws
        are sampled and returned as well.
      joint_dist: Joint distribution to sample from; forwarded to
        `_setup_mcmc` together with `**pins`.
      kind: Sampler kind, forwarded to `_do_sampling`.
      n_chains: Python `int` or list of chain-batch dimensions. An `int` is
        wrapped into a one-element list.
      proposal_kernel_kwargs: Dict of proposal-kernel keyword arguments. An
        optional `'step_size'` entry seeds the initial step size. The caller's
        dict is not mutated; a local copy is extended with
        `target_log_prob_fn`, `step_size` and `momentum_distribution`.
      num_adaptation_steps: Number of adaptation (tuning) steps.
      current_state: Optional initial position, forwarded to `_setup_mcmc` as
        `init_position`.
      dual_averaging_kwargs: Dict of dual-averaging step-size adaptation
        keyword arguments. The caller's dict is not mutated. A local copy gets
        `num_adaptation_steps` set, plus defaults for `reduce_fn` and
        `experimental_reduce_chain_axis_names`.
      trace_fn: Optional callable forwarded to `_do_sampling`. If `None`, no
        trace is recorded and only the states are returned.
      return_final_kernel_results: If `True`, also return the final kernel
        results.
      discard_tuning: If `True`, the adaptation steps run as burn-in and
        their draws are not returned.
      seed: Seed; split into a setup seed and a sampling seed.
      chain_axis_names: Named axes over which step-size adaptation is reduced
        by default; also forwarded to `_do_sampling`.
      **pins: Extra keyword arguments forwarded to `_setup_mcmc` (presumably
        values pinned on `joint_dist`).

    Returns:
      A `sample.CheckpointableStatesAndTrace` if
      `return_final_kernel_results`. Otherwise, the constrained states if
      `trace_fn` is `None`, else a `sample.StatesAndTrace`.

    Raises:
      ValueError: if no `'step_size'` is given and the target has a
        non-scalar batch shape.
    """
    # Work on private copies so the caller's dicts are never mutated. A reused
    # `dual_averaging_kwargs` dict would otherwise carry the injected
    # 'num_adaptation_steps' key into the next call and trigger a spurious
    # warning below.
    proposal_kernel_kwargs = dict(proposal_kernel_kwargs)
    dual_averaging_kwargs = dict(dual_averaging_kwargs)

    if trace_fn is None:
        trace_fn = lambda *args: ()
        no_trace = True
    else:
        no_trace = False

    if isinstance(n_chains, int):
        n_chains = [n_chains]

    if (tf.executing_eagerly()
            or not control_flow_util.GraphOrParentsInXlaContext(
                tf1.get_default_graph())):
        # A Tensor num_draws argument breaks XLA, which requires static TensorArray
        # trace_fn result allocation sizes.
        num_adaptation_steps = ps.convert_to_shape_tensor(num_adaptation_steps)

    if 'num_adaptation_steps' in dual_averaging_kwargs:
        warnings.warn(
            'Dual averaging adaptation will use the value specified in'
            ' the `num_adaptation_steps` argument for its construction,'
            ' hence there is no need to specify it in the'
            ' `dual_averaging_kwargs` argument.')

    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    dual_averaging_kwargs['num_adaptation_steps'] = num_adaptation_steps
    dual_averaging_kwargs.setdefault(
        'reduce_fn',
        functools.partial(
            generic_math.reduce_log_harmonic_mean_exp,
            # There is only one log_accept_prob per chain, and we reduce across
            # all chains, so typically the all_gather will be gathering scalars,
            # which should be relatively efficient.
            experimental_allow_all_gather=True))
    # By default, reduce over named axes for step size adaptation
    dual_averaging_kwargs.setdefault('experimental_reduce_chain_axis_names',
                                     chain_axis_names)
    setup_seed, sample_seed = samplers.split_seed(samplers.sanitize_seed(seed),
                                                  n=2)
    (target_log_prob_fn, initial_transformed_position, bijector,
     step_broadcast, batch_shape,
     shard_axis_names) = _setup_mcmc(joint_dist,
                                     n_chains=n_chains,
                                     init_position=current_state,
                                     seed=setup_seed,
                                     **pins)

    if proposal_kernel_kwargs.get('step_size') is None:
        if batch_shape.shape != (0, ):  # Scalar batch has a 0-vector shape.
            raise ValueError(
                'Batch target density must specify init_step_size. Got '
                f'batch shape {batch_shape} from joint {joint_dist}.')

        init_step_size = _get_step_size(initial_transformed_position,
                                        target_log_prob_fn)

    else:
        init_step_size = step_broadcast(proposal_kernel_kwargs['step_size'])

    proposal_kernel_kwargs.update({
        'target_log_prob_fn':
        target_log_prob_fn,
        'step_size':
        init_step_size,
        'momentum_distribution':
        _init_momentum(initial_transformed_position,
                       batch_shape=ps.concat([n_chains, batch_shape], axis=0),
                       shard_axis_names=shard_axis_names)
    })

    # Mass-matrix estimation starts from zero mean and unit variance per part.
    initial_running_variance = [
        sample_stats.RunningVariance.from_stats(  # pylint: disable=g-complex-comprehension
            num_samples=tf.zeros([], part.dtype),
            mean=tf.zeros_like(part),
            variance=tf.ones_like(part))
        for part in initial_transformed_position
    ]
    # TODO(phandu): Consider splitting out warmup and post warmup phases
    # to avoid executing adaptation code during the post warmup phase.
    ret = _do_sampling(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=n_draws if discard_tuning else n_draws +
        num_adaptation_steps,
        num_burnin_steps=num_adaptation_steps if discard_tuning else 0,
        initial_position=initial_transformed_position,
        initial_running_variance=initial_running_variance,
        bijector=bijector,
        trace_fn=trace_fn,
        return_final_kernel_results=return_final_kernel_results,
        chain_axis_names=chain_axis_names,
        shard_axis_names=shard_axis_names,
        seed=sample_seed)

    # Draws come back in the unconstrained space; map them back through the
    # inverse bijector before returning.
    if return_final_kernel_results:
        draws, trace, fkr = ret
        return sample.CheckpointableStatesAndTrace(
            all_states=bijector.inverse(draws),
            trace=trace,
            final_kernel_results=fkr)
    else:
        draws, trace = ret
        if no_trace:
            return bijector.inverse(draws)
        else:
            return sample.StatesAndTrace(all_states=bijector.inverse(draws),
                                         trace=trace)
示例#24
0
def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of a Case op produced by tf.switch_case.

    Builds one gradient function graph per branch of the forward Case op. If
    any gradient graph needs forward intermediates, the forward op is
    rewritten so that every branch outputs them. A new case is then built over
    the gradient graphs, keyed on the forward op's branch index
    (`case_op.inputs[0]`).

    Args:
      op: The forward Case operation, or an object whose `outputs[0].op` is
        the forward Case operation (e.g. a MockOp).
      *grads: Incoming gradients, passed to `_create_grad_func` for every
        branch (presumably one per output of `op` -- confirm).

    Returns:
      A list whose first element is `None` (the branch index has no gradient),
      followed by the outputs of the gradient case built by `_build_case`.
    """
    # Get the Case operator (this logic handles the case where op is a MockOp)
    case_op = op.outputs[0].op
    branch_graphs = get_func_graphs(case_op)
    assert branch_graphs
    # Note: op.graph != ops.get_default_graph() when we are computing the gradient
    # of a nested cond.
    for branch_graph in branch_graphs:
        assert branch_graph.outer_graph == case_op.graph

    # Create grad functions that compute the gradient of the branch forward
    # graphs. These functions will capture tensors from the forward pass
    # functions.
    branch_grad_graphs = []
    for branch_graph in branch_graphs:
        branch_grad_graphs.append(
            _create_grad_func(branch_graph, grads,
                              util.unique_grad_fn_name(branch_graph.name)))

    if any(g.op_needs_rewrite for g in branch_grad_graphs):
        # Modify 'op' to output the intermediates needed by the grad functions. Note
        # that all needed intermediates are wrapped in optionals. Each optional
        # intermediate output will have a value iff its corresponding branch is
        # taken.
        # NOTE(bjp): if there are any active sessions, this modification to `op`
        # may make them unrunnable!

        if control_flow_util.GraphOrParentsInXlaContext(
                ops.get_default_graph()):
            # XLA does not yet support optionals, so output intermediates directly and
            # make them match via FakeParams, which can be converted to zeros in XLA.
            # TODO(bjp,jpienaar): can XLA support optionals?
            branches_intermediates = [
                branch_grad_graph.xla_intermediates
                for branch_grad_graph in branch_grad_graphs
            ]
            extra_branch_outputs = _make_intermediates_match_xla(
                branch_graphs, branches_intermediates)
        else:
            branch_intermediates = [
                g.wrapped_intermediates for g in branch_grad_graphs
            ]
            # Make outputs match by adding none optionals.
            extra_branch_outputs = _make_intermediates_match(
                branch_graphs, branch_intermediates)

        # Append each branch's extra intermediates to its forward outputs.
        for branch_graph, extra_outputs in zip(branch_graphs,
                                               extra_branch_outputs):
            branch_graph.outputs.extend(extra_outputs)
        # TODO(bjp): indicate it's an internal bug if this fails.
        _check_same_outputs(_CASE, branch_graphs)

        # Rename the mutated branch graphs before re-registering them as the
        # op's branch functions.
        for branch_graph in branch_graphs:
            branch_graph.name += "_rewritten"

        case_op._set_func_list_attr("branches", [
            util.create_new_tf_function(branch_graph)
            for branch_graph in branch_graphs
        ])
        # Branch outputs were checked for agreement above, so branch 0's
        # output types/shapes (and extra outputs) describe the rewritten op.
        case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
        case_op._set_shape_list_attr("output_shapes",
                                     branch_graphs[0].output_shapes)
        case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
                             [t.shape for t in extra_branch_outputs[0]])

    # Resolve references to forward graph tensors in grad graphs and ensure
    # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
    branches_grad_inputs = [
        _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph,
        branch_grad_graph in zip(branch_graphs, branch_grad_graphs)
    ]

    # This modifies the graphs in branch_grad_graphs.
    _make_output_composite_tensors_match(_CASE, branch_grad_graphs)

    # Build the gradient case, keyed on the same branch index as the forward op.
    outputs = _build_case(case_op.inputs[0],
                          branch_grad_graphs,
                          branches_grad_inputs,
                          name="gradient")

    # The predicate has no gradient.
    return [None] + outputs
示例#25
0
def _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                       init_vars, basic_symbol_names, composite_symbol_names,
                       opts):
    """Overload of for_stmt that iterates over a TF range (and elides it).

    The range is not materialized. Its `(start, limit, delta)` inputs are read
    from the range op, and a while loop is built whose first loop variable is
    the iterate.

    Args:
      iter_: Tensor produced by a range op whose inputs are
        `(start, limit, delta)`.
      extra_test: Optional callable taking the loop variables and returning a
        boolean; iteration continues only while it is true.
      body: Callable `(iterate, *loop_vars)` returning the new loop variables,
        or a falsy value when there are none.
      get_state, set_state: State accessors forwarded to `_tf_while_stmt`.
      init_vars: Tuple of initial loop variables.
      basic_symbol_names, composite_symbol_names: Symbol names forwarded to
        `_tf_while_stmt`.
      opts: Dict of loop options; `'maximum_iterations'` may be set here.

    Returns:
      The loop results with the leading internal iterate dropped when
      `_tf_while_stmt` returns a tuple/list of more than one element. A
      one-element tuple/list is returned unchanged, and any other result
      yields `()`.
    """
    _disallow_undefs_into_loop(*init_vars)

    start, limit, delta = iter_.op.inputs

    def while_body(iterate, *loop_vars):
        new_vars = body(iterate, *loop_vars)
        loop_vars = (iterate + delta, )

        if new_vars:
            loop_vars += new_vars

        return loop_vars

    def while_cond(iterate, *loop_vars):
        """Cond function for `tf.while_loop`."""
        def build_main_test():
            """Main iteration condition."""
            # TODO(b/138857806): The optimizer should handle this.
            # LogicalAnd is slow on GPU so we avoid adding it if `delta` is a
            # compile time constant.
            delta_const = tensor_util.constant_value(delta)
            if delta_const is not None:
                # Support single element arrays. `np.asscalar` was removed in
                # NumPy 1.23; `.item()` is its documented replacement.
                delta_const = np.asarray(delta_const).item()
                if delta_const >= 0:
                    return iterate < limit
                else:
                    return iterate > limit
            else:
                return math_ops.logical_or(
                    math_ops.logical_and(delta >= 0, iterate < limit),
                    math_ops.logical_and(delta < 0, iterate > limit))

        main_test = build_main_test()
        if extra_test is not None:
            return control_flow_ops.cond(
                main_test,
                lambda: extra_test(*loop_vars),
                lambda: False,
            )
        return main_test

    # TODO(b/134181679): The op should handle this optimizations.
    # NOTE(review): the explicit iteration bound is only applied inside an XLA
    # context here; other revisions of this overload apply it only outside
    # XLA. Confirm which condition is intended.
    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
        # This specific dtype is required by while_loop.
        opts['maximum_iterations'] = math_ops.cast(
            misc.get_range_len(start, limit, delta), dtypes.int32)

    results = _tf_while_stmt(
        while_cond,
        while_body,
        get_state,
        set_state,
        (start, ) + init_vars,
        ('<internal iterate>', ) + basic_symbol_names,
        composite_symbol_names,
        opts,
    )

    # Note: the iteration index is not returned by the while loop, however
    # if a symbol with the same name exists outside the loop, it will be captured
    # by the loop variables and ultimately updated correctly.
    if isinstance(results, (tuple, list)):
        assert len(results) >= 1  # Has at least the iterate.
        if len(results) > 1:
            results = results[1:]
    else:
        results = ()

    return results
示例#26
0
def trace_scan(loop_fn,
               initial_state,
               elems,
               trace_fn,
               trace_criterion_fn=None,
               static_trace_allocation_size=None,
               parallel_iterations=10,
               name=None):
  """A simplified version of `tf.scan` that has configurable tracing.

  This function repeatedly calls `loop_fn(state, elem)`, where `state` is the
  `initial_state` during the first iteration, and the return value of `loop_fn`
  for every iteration thereafter. `elem` is a slice of `elems` along the
  first dimension, accessed in order. Additionally, it calls `trace_fn` on the
  return value of `loop_fn`. The `Tensor`s in return values of `trace_fn` are
  stacked and returned from this function. When `trace_criterion_fn` is `None`,
  the first dimension of those `Tensor`s matches the size of `elems`.

  `initial_state` itself is never traced. `trace_fn(initial_state)` is
  evaluated once, only to discover the structure, dtypes and static shapes of
  the trace. When `trace_criterion_fn` is given, only the steps for which it
  returns `True` are written, consecutively, starting at index 0 of the
  trace.

  Args:
    loop_fn: A callable that takes in a `Tensor` or a nested collection of
      `Tensor`s with the same structure as `initial_state`, a slice of `elems`
      and returns the same structure as `initial_state`.
    initial_state: A `Tensor` or a nested collection of `Tensor`s passed to
      `loop_fn` in the first iteration.
    elems: A `Tensor` that is split along the first dimension and each element
      of which is passed to `loop_fn`.
    trace_fn: A callable that takes in the return value of `loop_fn` and returns
      a `Tensor` or a nested collection of `Tensor`s.
    trace_criterion_fn: Optional callable that takes in the return value of
      `loop_fn` and returns a boolean `Tensor` indicating whether to trace it.
      If `None`, all steps are traced.
      Default value: `None`.
    static_trace_allocation_size: Optional Python `int` size of trace to
      allocate statically. This should be an upper bound on the number of steps
      traced and is used only when the length cannot be
      statically inferred (for example, if a `trace_criterion_fn` is specified).
      It is primarily intended for contexts where static shapes are required,
      such as in XLA-compiled code. When used, the returned trace has exactly
      this leading size. Entries past the number of traced steps are not
      written by this function.
      Default value: `None`.
    parallel_iterations: Passed to the internal `tf.while_loop`.
    name: Name scope used in this function. Default: 'trace_scan'.

  Returns:
    final_state: The final return value of `loop_fn`.
    trace: The same structure as the return value of `trace_fn`, but with each
      `Tensor` being a stack of the corresponding `Tensors` in the return value
      of `trace_fn` for each traced step (every slice of `elems` when
      `trace_criterion_fn` is `None`).
  """
  with tf.name_scope(name or 'trace_scan'), tf1.variable_scope(
      tf1.get_variable_scope()) as vs:
    # In graph mode, unless the caller already configured one, cache each
    # variable read on the device of the op that performs the read.
    if vs.caching_device is None and not tf.executing_eagerly():
      vs.set_caching_device(lambda op: op.device)

    initial_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name='initial_state'),
        initial_state, expand_composites=True)
    elems = tf.convert_to_tensor(elems, name='elems')

    # Number of loop iterations (size of the leading dimension of `elems`).
    # It may be a Python value or a Tensor; the `tf.is_tensor(length)` check
    # below distinguishes the two.
    length = ps.size0(elems)

    # This is a TensorArray in part because of XLA, which had trouble with
    # non-statically known indices. I.e. elems[i] errored, but
    # elems_array.read(i) worked.
    elems_array = tf.TensorArray(
        elems.dtype, size=length, element_shape=elems.shape[1:])
    elems_array = elems_array.unstack(elems)

    # Initialize trace arrays. Choose a (dynamic_size, initial_size) pair for
    # every trace TensorArray.
    if trace_criterion_fn is None:
      # Every step is traced, so exactly `length` entries are needed. The
      # array only has to grow dynamically when `length` is a Tensor.
      dynamic_size, initial_size = tf.is_tensor(length), length
    elif static_trace_allocation_size is not None:
      # Caller-provided fixed upper bound on the number of traced steps.
      dynamic_size, initial_size = False, static_trace_allocation_size
    elif JAX_MODE or (not tf.executing_eagerly() and
                      control_flow_util.GraphOrParentsInXlaContext(
                          tf1.get_default_graph())):
      # JAX / XLA graphs: use a fixed size of `length` entries. At most one
      # step is traced per element of `elems`, so this is an upper bound.
      dynamic_size, initial_size = False, length
    else:
      # Unknown number of traced steps: start empty and grow as needed.
      dynamic_size, initial_size = True, 0
    # Evaluated only to learn the trace's structure, dtypes and static shapes;
    # these values are not written into the trace.
    initial_trace = trace_fn(initial_state)
    flat_initial_trace = tf.nest.flatten(initial_trace, expand_composites=True)
    trace_arrays = []
    for trace_elt in flat_initial_trace:
      trace_arrays.append(
          tf.TensorArray(
              trace_elt.dtype,
              size=initial_size,
              dynamic_size=dynamic_size,
              element_shape=trace_elt.shape))

    # Helper for writing a (structured) state to (structured) arrays.
    # Writes the flattened `trace_fn(state)` at index `num_steps_traced` of
    # the matching TensorArrays and returns the updated arrays.
    def trace_one_step(num_steps_traced, trace_arrays, state):
      return [ta.write(num_steps_traced, x) for ta, x in
              zip(trace_arrays,
                  tf.nest.flatten(trace_fn(state), expand_composites=True))]

    # One scan step: consume elems[i], advance the state, and trace the new
    # state if the criterion (or the default of always tracing) holds.
    def _body(i, state, num_steps_traced, trace_arrays):
      elem = elems_array.read(i)
      state = loop_fn(state, elem)

      trace_arrays, num_steps_traced = ps.cond(
          trace_criterion_fn(state) if trace_criterion_fn else True,
          lambda: (trace_one_step(num_steps_traced, trace_arrays, state),  # pylint: disable=g-long-lambda
                   num_steps_traced + 1),
          lambda: (trace_arrays, num_steps_traced))

      return i + 1, state, num_steps_traced, trace_arrays

    # Loop variables: (element index, state, number of steps traced so far,
    # trace TensorArrays).
    _, final_state, _, trace_arrays = tf.while_loop(
        cond=lambda i, *_: i < length,
        body=_body,
        loop_vars=(0, initial_state, 0, trace_arrays),
        parallel_iterations=parallel_iterations)

    # unflatten
    stacked_trace = tf.nest.pack_sequence_as(
        initial_trace, [ta.stack() for ta in trace_arrays],
        expand_composites=True)

    # Restore the static length if we know it. With dynamically sized arrays
    # this is an unknown shape, so the `set_shape` below adds no static
    # information.
    static_length = tf.TensorShape(None if dynamic_size else initial_size)

    def _merge_static_length(x):
      # Merge the static leading dimension with the element's trailing shape.
      tensorshape_util.set_shape(x, static_length.concatenate(x.shape[1:]))
      return x

    stacked_trace = tf.nest.map_structure(
        _merge_static_length, stacked_trace, expand_composites=True)
    return final_state, stacked_trace
示例#27
0
def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   swap_memory=False,
                   scope=None,
                   **kwargs):
  """Perform dynamic decoding with `decoder`.

  Calls initialize() once and step() repeatedly on the Decoder object, inside
  a `tf.while_loop`. The loop runs until every batch entry is finished, or
  until `maximum_iterations` steps have run. If the decoder implements
  `finalize()`, it is applied to the stacked outputs and final state. A
  `NotImplementedError` from `finalize()` is ignored.

  Args:
    decoder: A `Decoder` or `BaseDecoder` instance.
    output_time_major: Python boolean.  Default: `False` (batch major).  If
      `True`, outputs are returned as time major tensors (this mode is faster).
      Otherwise, outputs are returned as batch major tensors (this adds extra
      time to the computation).
    impute_finished: Python boolean.  If `True`, then states for batch
      entries which are marked as finished get copied through and the
      corresponding outputs get zeroed out.  This causes some slowdown at
      each time step, but ensures that the final state and outputs have
      the correct values and that backprop ignores time steps that were
      marked as finished.
    maximum_iterations: `int32` scalar, maximum allowed number of decoding
       steps.  Default is `None` (decode until the decoder is fully done).
       Required when building inside an XLA context. A value <= 0 marks
       every batch entry as finished before the first step.
    parallel_iterations: Argument passed to `tf.while_loop`.
    swap_memory: Argument passed to `tf.while_loop`.
    scope: Optional variable scope to use.
    **kwargs: dict, other keyword arguments for dynamic_decode. It might contain
      arguments for `BaseDecoder` to initialize, which takes all tensor inputs
      during call(). Only `decoder_init_input` (default `None`) and
      `decoder_init_kwargs` (default `{}`) are read, and only when `decoder`
      is not a `Decoder`. Any other keys are ignored.

  Returns:
    `(final_outputs, final_state, final_sequence_lengths)`. `final_outputs`
    is batch major unless `output_time_major` is `True`.

  Raises:
    TypeError: if `decoder` is not an instance of `Decoder` or `BaseDecoder`.
    ValueError: if `maximum_iterations` is provided but is not a scalar, or
      if it is `None` while building inside an XLA context.
  """
  if not isinstance(decoder, (Decoder, BaseDecoder)):
    raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                    type(decoder))

  with tf.variable_scope(scope, "decoder") as varscope:
    # Enable variable caching if it is safe to do so.
    if _should_cache_variables():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    # Normalize to an int32 tensor and require a scalar.
    if maximum_iterations is not None:
      maximum_iterations = tf.convert_to_tensor(
          maximum_iterations, dtype=tf.int32, name="maximum_iterations")
      if maximum_iterations.get_shape().ndims != 0:
        raise ValueError("maximum_iterations must be a scalar")

    if isinstance(decoder, Decoder):
      initial_finished, initial_inputs, initial_state = decoder.initialize()
    else:
      # For BaseDecoder that takes tensor inputs during call.
      decoder_init_input = kwargs.pop("decoder_init_input", None)
      decoder_init_kwargs = kwargs.pop("decoder_init_kwargs", {})
      initial_finished, initial_inputs, initial_state = decoder.initialize(
          decoder_init_input, **decoder_init_kwargs)

    # Outputs emitted for already-finished entries when impute_finished=True.
    zero_outputs = _create_zero_outputs(decoder.output_size,
                                        decoder.output_dtype,
                                        decoder.batch_size)

    # If we are in an XLA context, we set maximum_iterations on the while loop
    # and set a fixed size for TensorArrays.
    is_xla = control_flow_util.GraphOrParentsInXlaContext(
        tf.get_default_graph())
    if is_xla and maximum_iterations is None:
      raise ValueError("maximum_iterations is required for XLA compilation.")
    # With maximum_iterations <= 0 every entry starts finished, so the loop
    # condition is false and the body never runs.
    if maximum_iterations is not None:
      initial_finished = tf.logical_or(
          initial_finished, 0 >= maximum_iterations)
    initial_sequence_lengths = tf.zeros_like(
        initial_finished, dtype=tf.int32)
    initial_time = tf.constant(0, dtype=tf.int32)

    # Static element shape for an output TensorArray: `None` for non-TensorShape
    # or scalar output sizes. Otherwise the output size is prefixed with the
    # batch size as computed by `tensor_util.constant_value`.
    def _shape(batch_size, from_shape):
      if (not isinstance(from_shape, tf.TensorShape) or
          from_shape.ndims == 0):
        return None
      else:
        batch_size = tensor_util.constant_value(
            tf.convert_to_tensor(batch_size, name="batch_size"))
        return tf.TensorShape([batch_size]).concatenate(from_shape)

    # Fixed-size TensorArrays only under XLA with a known iteration bound.
    dynamic_size = maximum_iterations is None or not is_xla

    def _create_ta(s, d):
      return tf.TensorArray(
          dtype=d,
          size=0 if dynamic_size else maximum_iterations,
          dynamic_size=dynamic_size,
          element_shape=_shape(decoder.batch_size, s))

    initial_outputs_ta = tf.nest.map_structure(_create_ta, decoder.output_size,
                                               decoder.output_dtype)

    # Keep looping while at least one batch entry is unfinished.
    def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
                  finished, unused_sequence_lengths):
      return tf.logical_not(tf.reduce_all(finished))

    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
      """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
      """
      (next_outputs, decoder_state, next_inputs,
       decoder_finished) = decoder.step(time, inputs, state)
      decoder_state_sequence_lengths = False
      if decoder.tracks_own_finished:
        next_finished = decoder_finished
        lengths = getattr(decoder_state, "lengths", None)
        if lengths is not None:
          # sequence lengths are provided by decoder_state.lengths; overwrite
          # our sequence lengths.
          decoder_state_sequence_lengths = True
          sequence_lengths = tf.cast(lengths, tf.int32)
      else:
        # Once an entry is finished it stays finished.
        next_finished = tf.logical_or(decoder_finished, finished)

      if decoder_state_sequence_lengths:
        # Just pass something through the loop; at the next iteration we'll pull
        # the sequence lengths from the decoder_state again.
        next_sequence_lengths = sequence_lengths
      else:
        # Entries unfinished at the start of this step get length time + 1;
        # finished entries keep their previous length.
        next_sequence_lengths = tf.where(
            tf.logical_not(finished),
            tf.fill(tf.shape(sequence_lengths), time + 1),
            sequence_lengths)

      tf.nest.assert_same_structure(state, decoder_state)
      tf.nest.assert_same_structure(outputs_ta, next_outputs)
      tf.nest.assert_same_structure(inputs, next_inputs)

      # Zero out output values past finish
      if impute_finished:
        emit = tf.nest.map_structure(
            lambda out, zero: tf.where(finished, zero, out),
            next_outputs,
            zero_outputs)
      else:
        emit = next_outputs

      # Copy through states past finish
      def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tf.TensorArray):
          pass_through = True
        else:
          new.set_shape(cur.shape)
          pass_through = (new.shape.ndims == 0)
        return new if pass_through else tf.where(finished, cur, new)

      if impute_finished:
        next_state = tf.nest.map_structure(
            _maybe_copy_state, decoder_state, state)
      else:
        next_state = decoder_state

      # Record this step's (possibly zeroed) outputs at index `time`.
      outputs_ta = tf.nest.map_structure(lambda ta, out: ta.write(time, out),
                                         outputs_ta, emit)
      return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)

    res = tf.while_loop(
        condition,
        body,
        loop_vars=(
            initial_time,
            initial_outputs_ta,
            initial_state,
            initial_inputs,
            initial_finished,
            initial_sequence_lengths,
        ),
        parallel_iterations=parallel_iterations,
        maximum_iterations=maximum_iterations,
        swap_memory=swap_memory)

    # Indices follow the loop_vars order above.
    final_outputs_ta = res[1]
    final_state = res[2]
    final_sequence_lengths = res[5]

    # Stacked outputs are time major at this point.
    final_outputs = tf.nest.map_structure(
        lambda ta: ta.stack(), final_outputs_ta)

    # finalize() is optional for decoders; a missing implementation is ignored
    # on purpose.
    try:
      final_outputs, final_state = decoder.finalize(
          final_outputs, final_state, final_sequence_lengths)
    except NotImplementedError:
      pass

    if not output_time_major:
      final_outputs = tf.nest.map_structure(
          _transpose_batch_time, final_outputs)

  return final_outputs, final_state, final_sequence_lengths
 def _use_merge_call(self):
     """Return whether `merge_call` should be used.

     `merge_call` is disabled only when the `fn` passed to `strategy.run` is
     being compiled by XLA (the default graph or one of its parents is an XLA
     context) *and* every device in `self._devices` is a GPU. In all other
     cases this returns True.
     """
     xla_compiling = control_flow_util.GraphOrParentsInXlaContext(
         ops.get_default_graph())
     if not xla_compiling:
         return True
     # Check every device up front (a list, not a lazy generator), then keep
     # merge_call unless all of them are GPUs.
     gpu_checks = [_is_gpu_device(device) for device in self._devices]
     return not all(gpu_checks)
示例#29
0
def _add_max_iterations_hint(opts, n):
    """Store `n` as `opts['maximum_iterations']`, but only under XLA.

    `opts` is mutated in place and nothing is returned. When neither the
    default graph nor any of its parents is an XLA context, `opts` is left
    unchanged.
    """
    # TODO(b/159186914): Remove the safeguard, and always set maximum_iterations.
    inside_xla = control_flow_util.GraphOrParentsInXlaContext(
        ops.get_default_graph())
    if not inside_xla:
        return
    opts['maximum_iterations'] = n
  def _sample_n(self, n, seed=None):
    """Draws `n` samples for every batch member.

    When executing eagerly, or when the default graph is not (inside) an XLA
    context, samples come directly from
    `tf.random.stateless_parameterized_truncated_normal`. Otherwise the
    sampler draws from a standard (zero mean, unit variance) truncated normal
    whose bounds are the standardized `low`/`high`. A custom gradient wrt
    those bounds is attached, and the result is scaled by `scale` and
    shifted by `loc`.

    Args:
      n: Number of samples; becomes the leading dimension of the result.
      seed: Sampler seed. On the stateless path it is passed through
        `samplers.sanitize_seed`; on the XLA path it is forwarded unchanged to
        `random_ops.parameterized_truncated_normal`.

    Returns:
      Samples with shape `concat([[n], batch_shape])`. On the XLA path the
      final `scale`/`loc` arithmetic broadcasts against that shape.
    """
    loc, scale, low, high = self._loc_scale_low_high()
    batch_shape = self._batch_shape_tensor(
        loc=loc, scale=scale, low=low, high=high)
    sample_and_batch_shape = ps.concat([[n], batch_shape], 0)
    # TODO(b/162522020): Use this behavior unconditionally.
    if (tf.executing_eagerly() or
        not control_flow_util.GraphOrParentsInXlaContext(
            tf1.get_default_graph())):
      return tf.random.stateless_parameterized_truncated_normal(
          shape=sample_and_batch_shape,
          means=loc,
          stddevs=scale,
          minvals=low,
          maxvals=high,
          seed=samplers.sanitize_seed(seed))

    # Samples are drawn as [flat_batch, n] and rearranged to [n] + batch_shape
    # at the end.
    flat_batch_and_sample_shape = tf.stack([tf.reduce_prod(batch_shape), n])
    # For reparameterization, sample from a truncated normal with zero mean
    # and unit variance, truncated at the standardized bounds. Then scale by
    # `scale` and shift by `loc`.

    @tf.custom_gradient
    def _std_samples_with_gradients(lower, upper):
      """Standard truncated Normal with gradient support for low, high."""
      # Note: Unlike the convention in TFP, parameterized_truncated_normal
      # returns a tensor with the final dimension being the sample dimension.
      # NOTE(review): `seed` is not sanitized here, unlike the stateless path
      # above. Confirm this op accepts the seed forms callers pass.
      std_samples = random_ops.parameterized_truncated_normal(
          shape=flat_batch_and_sample_shape,
          means=0.0,
          stddevs=1.0,
          minvals=lower,
          maxvals=upper,
          dtype=self.dtype,
          seed=seed)

      def grad(dy):
        """Computes a derivative for the min and max parameters.

        This function implements the derivative wrt the truncation bounds, which
        get blocked by the sampler. We use a custom expression for numerical
        stability instead of automatic differentiation on CDF for implicit
        gradients.

        Args:
          dy: output gradients

        Returns:
          A list `[grad_l, grad_u]`: the gradients wrt `lower` and `upper`
          (in argument order), each summed over the sample (last) axis.
        """
        # std_samples has an extra dimension (the sample dimension), expand
        # lower and upper so they broadcast along this dimension.
        # See note above regarding parameterized_truncated_normal, the sample
        # dimension is the final dimension.
        lower_broadcast = lower[..., tf.newaxis]
        upper_broadcast = upper[..., tf.newaxis]

        # Position of each sample within the truncated mass:
        # (Phi(x) - Phi(a)) / (Phi(b) - Phi(a)), with Phi the standard normal
        # CDF, a = lower and b = upper.
        cdf_samples = ((special_math.ndtr(std_samples) -
                        special_math.ndtr(lower_broadcast)) /
                       (special_math.ndtr(upper_broadcast) -
                        special_math.ndtr(lower_broadcast)))

        # tiny, eps are tolerance parameters to ensure we stay away from giving
        # a zero arg to the log CDF expression.

        tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
        eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
        cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

        # Implicit reparameterization derivatives, computed in log space:
        #   dx/db = cdf * phi(b) / phi(x)
        #   dx/da = (1 - cdf) * phi(a) / phi(x)
        # where phi is the standard normal density, so the ratio of densities
        # is exp(0.5 * (x**2 - bound**2)).
        du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
                    tf.math.log(cdf_samples))
        dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
                    tf.math.log1p(-cdf_samples))

        # Reduce the gradient across the samples (the last axis), giving one
        # gradient per flattened batch member.
        grad_u = tf.reduce_sum(dy * du, axis=-1)
        grad_l = tf.reduce_sum(dy * dl, axis=-1)
        return [grad_l, grad_u]

      return std_samples, grad

    std_low, std_high = self._standardized_low_and_high(
        low=low, high=high, loc=loc, scale=scale)
    # Broadcast the standardized bounds to a common shape and flatten them.
    # Presumably that shape equals `batch_shape`, so the flat size matches
    # `flat_batch_and_sample_shape[0]`. TODO confirm.
    low_high_shp = tf.broadcast_dynamic_shape(
        tf.shape(std_low), tf.shape(std_high))
    std_low = tf.broadcast_to(std_low, low_high_shp)
    std_high = tf.broadcast_to(std_high, low_high_shp)

    std_samples = _std_samples_with_gradients(
        tf.reshape(std_low, [-1]), tf.reshape(std_high, [-1]))

    # The returned shape is [flat_batch x n]
    std_samples = tf.transpose(std_samples, perm=[1, 0])

    # [n, flat_batch] -> [n] + batch_shape, then undo the standardization.
    std_samples = tf.reshape(std_samples, sample_and_batch_shape)
    return std_samples * scale[tf.newaxis] + loc[tf.newaxis]