Exemple #1
0
def _log_average_probs_maybe_check_args(sample_axis, event_axis,
                                        validate_args):
    """Checks that `sample_axis` and `event_axis` share no entries.

    Each axis argument is set-differenced against the other one. Removing the
    other axis' entries must leave the axis unchanged; otherwise the two
    overlap.

    Args:
      sample_axis: Axis argument compared against `event_axis`.
      event_axis: Axis argument compared against `sample_axis`.
      validate_args: When truthy, checks that cannot be resolved statically
        are emitted as assertions instead of being skipped.

    Returns:
      A (possibly empty) list of the assertions created by
      `_assert_array_equal`.

    Raises:
      ValueError: If a set difference is available as a NumPy value and
        differs from the static value of the corresponding axis.
    """
    overlap_msg = 'Arguments `sample_axis` and `event_axis` must be distinct.'
    checks = []
    axis_pairs = (
        (sample_axis, event_axis, 'sample_setdiff_rank_check'),
        (event_axis, sample_axis, 'event_setdiff_rank_check'),
    )
    for axis, other_axis, check_name in axis_pairs:
        remaining = ps.setdiff1d(axis, other_axis)
        if not ps.is_numpy(remaining):
            # Value only known at run time: defer to an assertion, but only
            # when the caller asked for validation.
            if validate_args:
                checks.append(
                    _assert_array_equal(remaining,
                                        axis,
                                        message=overlap_msg,
                                        name=check_name))
            continue
        if not np.array_equal(remaining, tf.get_static_value(axis)):
            raise ValueError(overlap_msg)
    return checks
Exemple #2
0
def maybe_check_wont_broadcast(flat_xs, validate_args):
  """Verifies that the elements of `flat_xs` all have the same shape.

  Args:
    flat_xs: Iterable (generators are accepted) of values whose shapes are
      compared pairwise, each one against its predecessor.
    validate_args: When falsy no check is performed at all and `flat_xs` is
      returned as a `tuple`.

  Returns:
    A `tuple` with the elements of `flat_xs`. When the shapes are not all
    statically known, every element is passed through `tf.identity` under
    control dependencies on the shape-equality assertions.

  Raises:
    ValueError: If all shapes are statically known and two neighbouring
      shapes differ.
  """
  flat_xs = tuple(flat_xs)  # So we can receive generators.
  if not validate_args:
    # Note: we don't try static validation because it is theoretically
    # possible that a user wants to take advantage of broadcasting.
    # Only when `validate_args` is `True` do we enforce the validation.
    return flat_xs
  msg = 'Broadcasting probably indicates an error in model specification.'
  s = tuple(prefer_static.shape(x) for x in flat_xs)
  if all(prefer_static.is_numpy(s_) for s_ in s):
    # `np.array_equal` compares the shape vectors exactly. An elementwise
    # `a == b` would itself broadcast, so shapes of different rank such as
    # `[3]` and `[3, 3]` would wrongly be reported as equal (or, on recent
    # NumPy, raise an unrelated error when the lengths are incompatible).
    if not all(np.array_equal(a, b) for a, b in zip(s[1:], s[:-1])):
      raise ValueError(msg)
    return flat_xs
  # NOTE(review): `assert_equal` presumably also broadcasts its operands, so
  # shape vectors of different length may still slip through on this dynamic
  # path -- confirm and tighten if so.
  assertions = [assert_util.assert_equal(a, b, message=msg)
                for a, b in zip(s[1:], s[:-1])]
  with tf.control_dependencies(assertions):
    return tuple(tf.identity(x) for x in flat_xs)
Exemple #3
0
def _log_loosum_exp_impl(logx, axis, keepdims, compute_mean):
    """Implementation for `*loosum*` functions.

    For every element of `logx` this computes the log-sum-exp of all *other*
    elements along `axis` (the "leave-one-out" log-sum-exp), together with the
    ordinary log-sum-exp over `axis`.

    Args:
      logx: Value convertible to a `Tensor` holding log-space values.
      axis: Dimensions to reduce over, or `None` to reduce over all of them.
        Values NumPy can represent numerically are cast to `int32`; anything
        NumPy can only hold as an object array is converted with
        `tf.convert_to_tensor`.
      keepdims: When falsy, the reduced dimensions are removed from the
        returned `log_sum_x`.
      compute_mean: When truthy, the results are turned from log-sums into
        log-means by subtracting `log(max(1, n - 1))` from the leave-one-out
        term and `log(n)` from the full term.

    Returns:
      log_loosum_x: `Tensor` with the static shape of `logx` holding the
        leave-one-out log-sum-exp (or log-mean-exp).
      log_sum_x: `Tensor` holding the log-sum-exp (or log-mean-exp) over
        `axis`.
      n: The number of elements reduced, cast to the dtype of `logx`.
    """
    with tf.name_scope('log_loosum_exp_impl'):
        logx = tf.convert_to_tensor(logx, name='logx')
        dtype = dtype_util.as_numpy_dtype(logx.dtype)

        if axis is not None:
            x = np.array(axis)
            # Compare dtypes by value: a dtype instance is never identical to
            # the `object` type, and the `np.object` alias no longer exists in
            # current NumPy releases.
            axis = (tf.convert_to_tensor(
                axis, name='axis', dtype_hint=tf.int32)
                    if x.dtype == object else x.astype(np.int32))

        log_sum_x = tf.reduce_logsumexp(logx, axis=axis, keepdims=True)

        # Later we'll want to compute the mean from a sum so we calculate the
        # number of reduced elements, n.
        n = prefer_static.size(logx) // prefer_static.size(log_sum_x)
        n = prefer_static.cast(n, dtype)

        # log_loosum_x[i] =
        # = logsumexp(logx[j] : j != i)
        # = log( exp(logsumexp(logx)) - exp(logx[i]) )
        # = log( exp(logsumexp(logx - logx[i])) exp(logx[i])  - exp(logx[i]))
        # = logx[i] + log(exp(logsumexp(logx - logx[i])) - 1)
        # = logx[i] + log(exp(logsumexp(logx) - logx[i]) - 1)
        # = logx[i] + softplus_inverse(logsumexp(logx) - logx[i])
        d = log_sum_x - logx
        # We use `d != 0` rather than `d > 0.` because `d < 0.` should never
        # happen; if it does we want to complain loudly (which
        # `softplus_inverse` will).
        d_ok = tf.not_equal(d, 0.)
        # Substitute a harmless value where `d == 0` so the unselected branch
        # of the final `tf.where` is still evaluated on a valid input.
        safe_d = tf.where(d_ok, d, 1.)
        d_ok_result = logx + softplus_inverse(safe_d)

        neg_inf = tf.constant(-np.inf, dtype=dtype)

        # When not(d_ok) and is_positive_and_largest then we manually compute
        # the log_loosum_x. (We can efficiently do this for any one point but
        # not all, hence we still need the above calculation.) This is good
        # because when this condition is met, we cannot use the above
        # calculation; it's -inf. We now compute the log-leave-out-max-sum,
        # replicate it to every point and make sure to select it only when we
        # need to.
        max_logx = tf.reduce_max(logx, axis=axis, keepdims=True)
        is_positive_and_largest = (logx > 0.) & tf.equal(logx, max_logx)
        log_lomsum_x = tf.reduce_logsumexp(tf.where(is_positive_and_largest,
                                                    neg_inf, logx),
                                           axis=axis,
                                           keepdims=True)
        d_not_ok_result = tf.where(is_positive_and_largest, log_lomsum_x,
                                   neg_inf)

        log_loosum_x = tf.where(d_ok, d_ok_result, d_not_ok_result)

        # We now squeeze log_sum_x as if we had used `keepdims=False`.
        # TODO(b/136176077): These mental gymnastics could all be replaced
        # with `tf.squeeze(log_sum_x, axis)` if tf.squeeze supported Tensor
        # valued `axis` arguments.
        if not keepdims:
            # `kept_axes` indexes the dimensions of `logx` that survive the
            # reduction; it is empty when everything is reduced.
            if axis is None:
                kept_axes = np.array([], dtype=np.int32)
            else:
                rank = prefer_static.rank(logx)
                kept_axes = prefer_static.setdiff1d(
                    prefer_static.range(rank),
                    prefer_static.non_negative_axis(axis, rank))
            squeeze_shape = tf.gather(prefer_static.shape(logx),
                                      indices=kept_axes)
            log_sum_x = tf.reshape(log_sum_x, shape=squeeze_shape)
            if prefer_static.is_numpy(kept_axes):
                tensorshape_util.set_shape(log_sum_x,
                                           np.array(logx.shape)[kept_axes])

        # Set static shapes just in case we lost them.
        tensorshape_util.set_shape(n, [])
        tensorshape_util.set_shape(log_loosum_x, logx.shape)

        if not compute_mean:
            return log_loosum_x, log_sum_x, n

        # NOTE(review): the builtin `max` relies on Python truthiness of the
        # comparison; presumably `n` is expected to be statically known (or
        # eager) here -- confirm.
        log_nm1 = prefer_static.log(max(1., n - 1.))
        log_n = prefer_static.log(n)
        return log_loosum_x - log_nm1, log_sum_x - log_n, n
Exemple #4
0
def trace_scan(loop_fn,
               initial_state,
               elems,
               trace_fn,
               parallel_iterations=10,
               name=None):
    """Runs `loop_fn` over the slices of `elems`, recording a trace per step.

    Starting from `initial_state`, the state is threaded through
    `loop_fn(state, elem)` once for every slice `elem` taken from `elems`
    along its first dimension, in order. After each step `trace_fn` is applied
    to the new state and its (flattened) outputs are written into
    `TensorArray`s which are stacked at the end.

    Args:
      loop_fn: A callable taking the current state (same structure as
        `initial_state`) and a slice of `elems`, and returning the next state.
      initial_state: A `Tensor` or a nested collection of `Tensor`s used as
        the state in the first iteration.
      elems: A `Tensor` that is split along the first dimension; each slice is
        passed to `loop_fn`.
      trace_fn: A callable taking a state and returning a `Tensor` or a nested
        collection of `Tensor`s to record. It is also called once on
        `initial_state` to determine dtypes and element shapes.
      parallel_iterations: Passed to the internal `tf.while_loop`.
      name: Name scope used in this function. Default: 'trace_scan'.

    Returns:
      final_state: The state returned by the last call to `loop_fn`.
      trace: Same structure as the output of `trace_fn(initial_state)`, with
        each `Tensor` being the stack of the values recorded at every step.
    """
    scope_name = name or 'trace_scan'
    with tf.name_scope(scope_name), tf1.variable_scope(
            tf1.get_variable_scope()) as var_scope:
        if var_scope.caching_device is None:
            if not tf.executing_eagerly():
                var_scope.set_caching_device(lambda op: op.device)

        def _as_tensor(part):
            return tf.convert_to_tensor(part, name='initial_state')

        initial_state = tf.nest.map_structure(_as_tensor, initial_state)
        elems = tf.convert_to_tensor(elems, name='elems')

        num_elems = prefer_static.size0(elems)
        known_length = None
        if prefer_static.is_numpy(num_elems):
            known_length = num_elems

        # A TensorArray is used in part because of XLA, which had trouble with
        # non-statically known indices: `elems[i]` errored whereas reading
        # from the array worked.
        elems_ta = tf.TensorArray(
            elems.dtype, size=num_elems, element_shape=elems.shape[1:])
        elems_ta = elems_ta.unstack(elems)

        def _make_trace_array(traced):
            return tf.TensorArray(
                traced.dtype, size=num_elems, element_shape=traced.shape)

        trace_arrays = tf.nest.map_structure(_make_trace_array,
                                             trace_fn(initial_state))

        def _keep_going(step, *unused_loop_vars):
            return step < num_elems

        def _step(step, state, arrays):
            next_state = loop_fn(state, elems_ta.read(step))
            flat_arrays = tf.nest.flatten(arrays)
            flat_values = tf.nest.flatten(trace_fn(next_state))
            written = []
            for array, value in zip(flat_arrays, flat_values):
                written.append(array.write(step, value))
            next_arrays = tf.nest.pack_sequence_as(arrays, written)
            return step + 1, next_state, next_arrays

        loop_result = tf.while_loop(
            cond=_keep_going,
            body=_step,
            loop_vars=(0, initial_state, trace_arrays),
            parallel_iterations=parallel_iterations)
        final_state, trace_arrays = loop_result[1], loop_result[2]

        stacked_trace = tf.nest.map_structure(lambda array: array.stack(),
                                              trace_arrays)

        # Put the leading dimension back into the static shape when known.
        def _with_static_length(stacked):
            leading = tf.TensorShape(known_length)
            stacked.set_shape(leading.concatenate(stacked.shape[1:]))
            return stacked

        stacked_trace = tf.nest.map_structure(_with_static_length,
                                              stacked_trace)
        return final_state, stacked_trace
Exemple #5
0
  def _parameter_control_dependencies(self, is_init):
    """Validates the `axis` and `paddings` parameters.

    Each group of checks runs only when `is_init` differs from whether the
    parameter(s) involved are references (`tensor_util.is_ref`). Conditions
    that can be decided from static information raise immediately; the
    remaining ones become assertions, and only when `self.validate_args` is
    set.

    Args:
      is_init: Python `bool` compared against `tensor_util.is_ref(...)` to
        decide which parameters are validated in this call.

    Returns:
      A (possibly empty) list of assertion ops.

    Raises:
      ValueError: If a statically detectable condition is violated: `axis`
        has rank above 1, is not strictly negative or has repeated entries;
        `paddings` is not a non-empty `[n, 2]` array or has negative entries;
        or `axis` and `paddings` have a different number of elements.
    """
    assertions = []

    # Tensor versions of the parameters, converted lazily and at most once.
    axis = None
    paddings = None

    if is_init != tensor_util.is_ref(self.axis):
      # First we check the shape of the axis argument.
      msg = 'Argument `axis` must be scalar or vector.'
      if tensorshape_util.rank(self.axis.shape) is not None:
        if tensorshape_util.rank(self.axis.shape) > 1:
          raise ValueError(msg)
      elif self.validate_args:
        if axis is None:
          axis = tf.convert_to_tensor(self.axis)
        assertions.append(assert_util.assert_rank_at_most(
            axis, 1, message=msg))
      # Next we check the values of the axis argument.
      axis_ = tf.get_static_value(self.axis)
      msg = 'Argument `axis` must be negative.'
      if axis_ is not None:
        if np.any(axis_ > -1):
          raise ValueError(msg)
      elif self.validate_args:
        if axis is None:
          axis = tf.convert_to_tensor(self.axis)
        assertions.append(assert_util.assert_less(axis, 0, message=msg))
      msg = 'Argument `axis` elements must be unique.'
      if axis_ is not None:
        if len(np.array(axis_).reshape(-1)) != len(np.unique(axis_)):
          raise ValueError(msg)
      elif self.validate_args:
        if axis is None:
          axis = tf.convert_to_tensor(self.axis)
        # Count the distinct entries with `tf.unique`, mirroring the static
        # `np.unique` check above. A set difference needs two operands and
        # cannot be used to deduplicate a single tensor. The reshape makes a
        # scalar `axis` acceptable to `tf.unique`, which expects a vector.
        flat_axis = tf.reshape(axis, shape=[-1])
        assertions.append(assert_util.assert_equal(
            prefer_static.size0(flat_axis),
            prefer_static.size0(tf.unique(flat_axis)[0]),
            message=msg))

    if is_init != tensor_util.is_ref(self.paddings):
      # First we check the shape of the paddings argument.
      msg = 'Argument `paddings` must be a vector of pairs.'
      if tensorshape_util.is_fully_defined(self.paddings.shape):
        shape = np.int32(self.paddings.shape)
        if len(shape) != 2 or shape[0] < 1 or shape[1] != 2:
          raise ValueError(msg)
      elif self.validate_args:
        if paddings is None:
          paddings = tf.convert_to_tensor(self.paddings)
        # The rank assertion is attached as a control dependency of `shape`
        # so that it runs before the per-dimension assertions index into it.
        with tf.control_dependencies([
            assert_util.assert_equal(tf.rank(paddings), 2, message=msg)]):
          shape = tf.shape(paddings)
          assertions.extend([
              assert_util.assert_greater(shape[0], 0, message=msg),
              assert_util.assert_equal(shape[1], 2, message=msg),
          ])
      # Next we check the values of the paddings argument.
      paddings_ = tf.get_static_value(self.paddings)
      msg = 'Argument `paddings` must be non-negative.'
      if paddings_ is not None:
        if np.any(paddings_ < 0):
          raise ValueError(msg)
      elif self.validate_args:
        if paddings is None:
          paddings = tf.convert_to_tensor(self.paddings)
        assertions.append(assert_util.assert_greater(
            paddings, -1, message=msg))

    if is_init != (tensor_util.is_ref(self.axis) and
                   tensor_util.is_ref(self.paddings)):
      # Finally we check that `axis` and `paddings` agree in length, using
      # static values where available and tensors otherwise.
      axis_ = tf.get_static_value(self.axis)
      if axis_ is None and axis is None:
        axis = tf.convert_to_tensor(self.axis)
      len_axis = prefer_static.size0(prefer_static.reshape(
          axis if axis_ is None else axis_, shape=-1))

      paddings_ = tf.get_static_value(self.paddings)
      if paddings_ is None and paddings is None:
        paddings = tf.convert_to_tensor(self.paddings)
      len_paddings = prefer_static.size0(
          paddings if paddings_ is None else paddings_)

      msg = ('Arguments `axis` and `paddings` must have the same number '
             'of elements.')
      if (prefer_static.is_numpy(len_axis) and
          prefer_static.is_numpy(len_paddings)):
        if len_axis != len_paddings:
          raise ValueError(msg + ' Saw: {}, {}.'.format(
              self.axis, self.paddings))
      elif self.validate_args:
        assertions.append(assert_util.assert_equal(
            len_axis, len_paddings, message=msg))

    return assertions
Exemple #6
0
 def test_num_paddings_dynamic(self):
     """Pads `[2, 3]` with a trailing pad amount given by a placeholder."""
     trailing_pad = tf1.placeholder_with_default(2, shape=None)
     padded = ps.pad([2, 3],
                     paddings=[[0, trailing_pad]],
                     constant_values=1)
     # The result may already be a NumPy value; evaluate it only otherwise.
     result = padded if ps.is_numpy(padded) else self.evaluate(padded)
     self.assertAllEqual([2, 3, 1, 1], result)
Exemple #7
0
def trace_scan(loop_fn,
               initial_state,
               elems,
               trace_fn,
               trace_criterion_fn=None,
               static_trace_allocation_size=None,
               parallel_iterations=10,
               name=None):
  """Runs `loop_fn` over the slices of `elems`, optionally tracing each step.

  Starting from `initial_state`, the state is threaded through
  `loop_fn(state, elem)` once for every slice `elem` taken from `elems` along
  its first dimension, in order. After a step, the output of `trace_fn` on the
  new state is appended to the trace if `trace_criterion_fn` is not given or
  evaluates to true for that state. The recorded values are stacked along a
  new leading dimension and returned.

  Args:
    loop_fn: A callable taking the current state (same structure as
      `initial_state`) and a slice of `elems`, and returning the next state.
    initial_state: A `Tensor` or a nested collection of `Tensor`s used as the
      state in the first iteration.
    elems: A `Tensor` that is split along the first dimension; each slice is
      passed to `loop_fn`.
    trace_fn: A callable taking a state and returning a `Tensor` or a nested
      collection of `Tensor`s to record. It is also called once on
      `initial_state` to determine dtypes and element shapes.
    trace_criterion_fn: Optional callable taking a state and returning a
      boolean indicating whether that step is traced. If `None`, every step
      is traced.
      Default value: `None`.
    static_trace_allocation_size: Optional Python `int` size of the trace to
      allocate statically. It is consulted only when a `trace_criterion_fn`
      is given, and should be an upper bound on the number of traced steps.
      It is primarily intended for contexts where static shapes are
      required, such as in XLA-compiled code.
      Default value: `None`.
    parallel_iterations: Passed to the internal `tf.while_loop`.
    name: Name scope used in this function. Default: 'trace_scan'.

  Returns:
    final_state: The state returned by the last call to `loop_fn`.
    trace: Same structure as the output of `trace_fn(initial_state)`, with
      each `Tensor` being the stack of the values recorded at the traced
      steps.
  """
  scope_name = name or 'trace_scan'
  with tf.name_scope(scope_name), tf1.variable_scope(
      tf1.get_variable_scope()) as var_scope:
    if var_scope.caching_device is None:
      if not tf.executing_eagerly():
        var_scope.set_caching_device(lambda op: op.device)

    def _as_tensor(part):
      return tf.convert_to_tensor(part, name='initial_state')

    initial_state = tf.nest.map_structure(_as_tensor, initial_state)
    elems = tf.convert_to_tensor(elems, name='elems')

    num_elems = prefer_static.size0(elems)

    # A TensorArray is used in part because of XLA, which had trouble with
    # non-statically known indices: `elems[i]` errored whereas reading from
    # the array worked.
    elems_ta = tf.TensorArray(
        elems.dtype, size=num_elems, element_shape=elems.shape[1:])
    elems_ta = elems_ta.unstack(elems)

    # Decide how the trace arrays are sized.
    if trace_criterion_fn is None:
      # Every step is traced, so the trace is exactly as long as `elems`.
      dynamic_size = prefer_static.logical_not(
          prefer_static.is_numpy(num_elems))
      initial_size = num_elems
    elif static_trace_allocation_size:
      dynamic_size = False
      initial_size = static_trace_allocation_size
    else:
      # Unknown number of traced steps: start empty and grow on demand.
      dynamic_size = True
      initial_size = 0

    def _make_trace_array(traced):
      return tf.TensorArray(traced.dtype,
                            size=initial_size,
                            dynamic_size=dynamic_size,
                            element_shape=traced.shape)

    trace_arrays = tf.nest.map_structure(_make_trace_array,
                                         trace_fn(initial_state))

    def _keep_going(step, *unused_loop_vars):
      return step < num_elems

    def _step(step, state, num_traced, arrays):
      current_elem = elems_ta.read(step)
      next_state = loop_fn(state, current_elem)

      def _record():
        # Write the traced values at the next free slot of every array.
        updated = tf.nest.map_structure(
            lambda array, value: array.write(num_traced, value),
            arrays,
            trace_fn(next_state))
        return updated, num_traced + 1

      def _skip():
        return arrays, num_traced

      should_trace = (
          trace_criterion_fn(next_state) if trace_criterion_fn else True)
      next_arrays, next_num_traced = prefer_static.cond(
          should_trace, _record, _skip)

      return step + 1, next_state, next_num_traced, next_arrays

    loop_result = tf.while_loop(
        cond=_keep_going,
        body=_step,
        loop_vars=(0, initial_state, 0, trace_arrays),
        parallel_iterations=parallel_iterations)
    final_state, trace_arrays = loop_result[1], loop_result[3]

    stacked_trace = tf.nest.map_structure(lambda array: array.stack(),
                                          trace_arrays)

    # Put the leading dimension back into the static shape when known.
    leading_dim = tf.TensorShape(None if dynamic_size else initial_size)

    def _with_static_length(stacked):
      tensorshape_util.set_shape(
          stacked, leading_dim.concatenate(stacked.shape[1:]))
      return stacked

    stacked_trace = tf.nest.map_structure(_with_static_length, stacked_trace)
    return final_state, stacked_trace