Example 1
    def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                       validate_args, name):
        """Helper to __init__ which ensures override batch/event_shape are valid."""
        if override_shape is None:
            override_shape = []

        override_shape = tf.convert_to_tensor(override_shape,
                                              dtype=tf.int32,
                                              name=name)

        if not dtype_util.is_integer(override_shape.dtype):
            raise TypeError("shape override must be an integer")

        override_is_scalar = _is_scalar_from_shape_tensor(override_shape)
        if tf.get_static_value(override_is_scalar):
            return self._empty

        dynamic_assertions = []

        if tensorshape_util.rank(override_shape.shape) is not None:
            if tensorshape_util.rank(override_shape.shape) != 1:
                raise ValueError("shape override must be a vector")
        elif validate_args:
            dynamic_assertions += [
                assert_util.assert_rank(
                    override_shape,
                    1,
                    message="shape override must be a vector")
            ]

        if tf.get_static_value(override_shape) is not None:
            if any(s < 0 for s in tf.get_static_value(override_shape)):
                raise ValueError(
                    "shape override must have non-negative elements")
        elif validate_args:
            dynamic_assertions += [
                assert_util.assert_non_negative(
                    override_shape,
                    message="shape override must have non-negative elements")
            ]

        is_both_nonscalar = prefer_static.logical_and(
            prefer_static.logical_not(base_is_scalar),
            prefer_static.logical_not(override_is_scalar))
        if tf.get_static_value(is_both_nonscalar) is not None:
            if tf.get_static_value(is_both_nonscalar):
                raise ValueError("base distribution not scalar")
        elif validate_args:
            dynamic_assertions += [
                assert_util.assert_equal(
                    is_both_nonscalar,
                    False,
                    message="base distribution not scalar")
            ]

        if not dynamic_assertions:
            return override_shape
        return distribution_util.with_dependencies(dynamic_assertions,
                                                   override_shape)
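The pattern above (validate statically when the value is known at graph-construction time, and fall back to a runtime assertion only when `validate_args` is set) recurs throughout these examples. A minimal self-contained sketch of the same pattern using only public TensorFlow ops; the helper name `validate_non_negative_vector` is hypothetical:

import tensorflow as tf

def validate_non_negative_vector(shape, validate_args):
    """Sketch of the static-first / dynamic-fallback validation pattern."""
    shape = tf.convert_to_tensor(shape, dtype=tf.int32)
    shape_ = tf.get_static_value(shape)
    if shape_ is not None:
        # Static path: the value is known now, so fail at trace time.
        if (shape_ < 0).any():
            raise ValueError('shape must have non-negative elements')
        return shape
    if validate_args:
        # Dynamic path: attach a runtime check via control dependencies.
        assertion = tf.debugging.assert_non_negative(
            shape, message='shape must have non-negative elements')
        with tf.control_dependencies([assertion]):
            return tf.identity(shape)
    return shape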
Example 2
  def _assert_compatible_shape(self, index, sample_shape, samples):
    requested_shape, _ = self._expand_sample_shape_to_vector(
        tf.convert_to_tensor(sample_shape, dtype=tf.int32),
        name='requested_shape')
    actual_shape = prefer_static.shape(samples)
    actual_rank = prefer_static.rank_from_shape(actual_shape)
    requested_rank = prefer_static.rank_from_shape(requested_shape)

    # We test for two properties we expect of yielded distributions:
    # (1) The rank of the tensor of generated samples must be at least
    #     as large as the rank requested.
    # (2) The requested shape must be a prefix of the shape of the
    #     generated tensor of samples.
    # We attempt to perform test (1) statically first.
    # We don't need to do this explicitly for test (2) because
    # `assert_equal` evaluates statically if it can.
    static_actual_rank = tf.get_static_value(actual_rank)
    static_requested_rank = tf.get_static_value(requested_rank)

    assertion_message = ('Samples yielded by distribution #{} are not '
                         'consistent with `sample_shape` passed to '
                         '`JointDistributionCoroutine` '
                         'distribution.'.format(index))

    # TODO(b/138738650): Remove this static check.
    if (static_actual_rank is not None and
        static_requested_rank is not None):
      # We're able to statically check the rank
      if static_actual_rank < static_requested_rank:
        raise ValueError(assertion_message)
      else:
        control_dependencies = []
    else:
      # We're not able to statically check the rank
      control_dependencies = [
          assert_util.assert_greater_equal(
              actual_rank, requested_rank,
              message=assertion_message)
          ]

    with tf.control_dependencies(control_dependencies):
      trimmed_actual_shape = actual_shape[:requested_rank]

    control_dependencies = [
        assert_util.assert_equal(
            requested_shape, trimmed_actual_shape,
            message=assertion_message)
    ]

    return control_dependencies
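For intuition (illustrative values, not from the original source): a sample tensor of shape [3, 2] is compatible with a requested `sample_shape` of [3] because [3] is a prefix of [3, 2], while [4] is not. A standalone analogue of test (2) using public ops:

import tensorflow as tf

samples = tf.zeros([3, 2])
requested_shape = tf.constant([3])
actual_shape = tf.shape(samples)
# The requested shape must be a prefix of the actual sample shape.
trimmed_actual_shape = actual_shape[:tf.size(requested_shape)]
tf.debugging.assert_equal(requested_shape, trimmed_actual_shape)  # passes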
Example 3
def assert_finite(x, data=None, summarize=None, message=None, name=None):
    """Assert all elements of `x` are finite.

  Args:
    x:  Numeric `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_finite".

  Returns:
    `x`, with a runtime assertion that all its elements are finite attached
    when the check cannot be performed statically. If static checks determine
    all elements of `x` are finite, `x` is returned unchanged.

  Raises:
    ValueError:  If static checks determine any element of `x` is not finite.
  """
    with tf.name_scope(name or 'assert_finite'):
        x_ = tf.get_static_value(x)
        if x_ is not None:
            if not np.all(np.isfinite(x_)):
                raise ValueError(message)
            return x
        assertion = tf1.assert_equal(tf.math.is_finite(x),
                                     tf.ones_like(x, tf.bool),
                                     data=data,
                                     summarize=summarize,
                                     message=message)
        with tf.control_dependencies([assertion]):
            return tf.identity(x)
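Usage sketch, assuming the snippet's imports (`tf`, `tf1 = tf.compat.v1`, `np`) are in scope: a statically-known non-finite input fails at trace time instead of producing a runtime assertion op.

x = tf.constant([1., 2., float('inf')])
try:
    assert_finite(x, message='x must be finite')
except ValueError as e:
    print(e)  # x must be finite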
Example 4
def _maybe_validate_perm(perm, validate_args, name=None):
    """Checks that `perm` is valid."""
    with tf.name_scope(name or 'maybe_validate_perm'):
        assertions = []
        if not dtype_util.is_integer(perm.dtype):
            raise TypeError('`perm` must be integer type')

        msg = '`perm` must be a vector.'
        if tensorshape_util.rank(perm.shape) is not None:
            if tensorshape_util.rank(perm.shape) != 1:
                raise ValueError(msg[:-1] + ', saw rank: {}.'.format(
                    tensorshape_util.rank(perm.shape)))
        elif validate_args:
            assertions += [assert_util.assert_rank(perm, 1, message=msg)]

        perm_ = tf.get_static_value(perm)
        msg = '`perm` must be a valid permutation vector.'
        if perm_ is not None:
            if not np.all(np.arange(np.size(perm_)) == np.sort(perm_)):
                raise ValueError(msg[:-1] + ', saw: {}.'.format(perm_))
        elif validate_args:
            assertions += [
                assert_util.assert_equal(tf.sort(perm),
                                         tf.range(tf.size(perm)),
                                         message=msg)
            ]

        return assertions
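Usage sketch, assuming the TFP-internal helpers imported by the snippet are available: a statically-known `perm` with a repeated index fails immediately, whereas a symbolic `perm` would only accumulate assertions when `validate_args=True`.

perm = tf.constant([0, 2, 2])
try:
    _maybe_validate_perm(perm, validate_args=True)
except ValueError as e:
    print(e)  # `perm` must be a valid permutation vector, saw: [0 2 2].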
Example 5
 def _event_shape(self):
     sample_shape = tf.TensorShape(tf.get_static_value(self.sample_shape))
     if (tensorshape_util.rank(sample_shape) is None or
             tensorshape_util.rank(self.distribution.event_shape) is None):
         return tf.TensorShape(None)
     return tensorshape_util.concatenate(sample_shape,
                                         self.distribution.event_shape)
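The concatenation is visible through the public API; as a sketch, wrapping a three-dimensional MVN in `Sample([2])` yields event shape [2, 3]:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

d = tfd.Sample(tfd.MultivariateNormalDiag(loc=tf.zeros(3)), sample_shape=[2])
print(d.event_shape)  # (2, 3)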
Example 6
 def _reshape_part(part, dtype, event_shape):
     part = tf.cast(part, dtype)
     static_rank = tf.get_static_value(
         ps.rank_from_shape(event_shape))
     if static_rank == 1:
         return part
     new_shape = ps.concat([ps.shape(part)[:-1], event_shape],
                           axis=-1)
     return tf.reshape(part, ps.cast(new_shape, tf.int32))
Example 7
def non_negative_axis(axis, rank, name=None):  # pylint:disable=redefined-outer-name
  """Make (possibly negatively indexed) `axis` argument non-negative."""
  with tf.name_scope(name or 'non_negative_axis'):
    if axis is None:
      return None
    if rank is None:
      raise ValueError('Argument `rank` cannot be `None`.')
    dtype = dtype_util.as_numpy_dtype(
        dtype_util.common_dtype([axis, rank], dtype_hint=tf.int32))
    rank_ = tf.get_static_value(rank)
    axis_ = tf.get_static_value(axis)
    if rank_ is None or axis_ is None:
      axis = tf.convert_to_tensor(axis, dtype=dtype, name='axis')
      rank = tf.convert_to_tensor(rank, dtype=dtype, name='rank')
      return tf.where(axis < 0, rank + axis, axis)
    axis_ = np.array(axis_, dtype=dtype)
    rank_ = np.array(rank_, dtype=dtype)
    return np.where(axis_ < 0, axis_ + rank_, axis_)
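Usage sketch, assuming the snippet's `dtype_util` import is available: with both arguments static the result is computed in NumPy; if either were a symbolic tensor, the `tf.where` branch would run instead.

print(non_negative_axis(axis=-1, rank=3))  # 2
print(non_negative_axis(axis=1, rank=3))   # 1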
Example 8
def _pick_scalar_condition(pred, cond_true, cond_false):
    """Convenience function which chooses the condition based on the predicate."""
    # Note: This function is only valid if all of pred, cond_true, and cond_false
    # are scalars. This means its semantics are arguably more like tf.cond than
    # tf.where even though we use tf.where to implement it.
    pred_ = tf.get_static_value(tf.convert_to_tensor(pred))
    if pred_ is None:
        return tf.where(pred, cond_true, cond_false)
    return cond_true if pred_ else cond_false
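Usage sketch: a predicate known at trace time selects a branch directly in Python, so no `tf.where` op is ever created.

print(_pick_scalar_condition(True, 1., 0.))   # 1.0, chosen in Python
print(_pick_scalar_condition(False, 1., 0.))  # 0.0, chosen in Python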
Example 9
def _kl_sample(a, b, name='kl_sample'):
    """Batched KL divergence `KL(a || b)` for Sample distributions.

  We can leverage the fact that:

  ```
  KL(Sample(a) || Sample(b)) = sum(KL(a || b))
  ```

  where the sum is over the `sample_shape` dims.

  Args:
    a: Instance of `Sample` distribution.
    b: Instance of `Sample` distribution.
    name: (optional) name to use for created ops.
      Default value: `"kl_sample"`.

  Returns:
    kldiv: Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the `sample_shape` of `a` and `b` don't match.
  """
    assertions = []
    a_ss = tf.get_static_value(a.sample_shape)
    b_ss = tf.get_static_value(b.sample_shape)
    msg = '`a.sample_shape` must be identical to `b.sample_shape`.'
    if a_ss is not None and b_ss is not None:
        if not np.array_equal(a_ss, b_ss):
            raise ValueError(msg)
    elif a.validate_args or b.validate_args:
        assertions.append(
            assert_util.assert_equal(a.sample_shape,
                                     b.sample_shape,
                                     message=msg))
    with tf.control_dependencies(assertions):
        kl = kullback_leibler.kl_divergence(a.distribution,
                                            b.distribution,
                                            name=name)
        n = prefer_static.reduce_prod(a.sample_shape)
        return tf.cast(x=n, dtype=kl.dtype) * kl
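The identity can be checked through the public API; as a sketch, the Sample-wrapped KL equals the base KL scaled by the number of iid draws:

import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Sample(tfd.Normal(0., 1.), sample_shape=[3])
b = tfd.Sample(tfd.Normal(1., 1.), sample_shape=[3])
print(tfd.kl_divergence(a, b))  # 3 * KL(N(0,1) || N(1,1)) = 3 * 0.5 = 1.5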
Example 10
def _setdiff1d(a, b, aminusb=True, validate_indices=True):
  """Compute set difference of elements in last dimension of `a` and `b`."""
  if not aminusb:
    raise NotImplementedError(
        'Argument `aminusb != True` is currently unimplemented.')
  if not validate_indices:
    raise NotImplementedError(
        'Argument `validate_indices != True` is currently unimplemented.')
  with tf.name_scope('setdiff1d'):
    dtype = dtype_util.as_numpy_dtype(
        dtype_util.common_dtype([a, b], dtype_hint=tf.int32))
    a_ = tf.get_static_value(a)
    b_ = tf.get_static_value(b)
    if a_ is None or b_ is None:
      a = tf.convert_to_tensor(a, dtype=dtype, name='a')
      b = tf.convert_to_tensor(b, dtype=dtype, name='b')
      return tf.sparse.to_dense(tf.sets.difference(
          a[tf.newaxis], b[tf.newaxis]))[0]
    a_ = np.array(a_, dtype=dtype)
    b_ = np.array(b_, dtype=dtype)
    return np.setdiff1d(a_, b_)
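Usage sketch, assuming the snippet's `dtype_util` import is available: constant inputs take the NumPy fast path; symbolic inputs lower to `tf.sets.difference`.

print(_setdiff1d([1, 2, 3, 4], [2, 4]))  # [1 3], computed via np.setdiff1d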
Example 11
 def _maybe_rotate_dims(self, x, rotate_right=False):
     """Helper which rolls left event_dims left or right event_dims right."""
     needs_rotation_const = tf.get_static_value(self._needs_rotation)
     if needs_rotation_const is not None and not needs_rotation_const:
         return x
     ndims = prefer_static.rank(x)
     n = (ndims -
          self._rotate_ndims) if rotate_right else self._rotate_ndims
     perm = prefer_static.concat(
         [prefer_static.range(n, ndims),
          prefer_static.range(0, n)], axis=0)
     return tf.transpose(a=x, perm=perm)
Example 12
def softmax(x, axis, name=None):
    """Equivalent to tf.math.softmax but works around b/70297725."""
    with tf.name_scope(name or "softmax"):
        x = tf.convert_to_tensor(x, name="x")
        ndims = (tensorshape_util.rank(x.shape) if tensorshape_util.rank(
            x.shape) is not None else tf.rank(x, name="ndims"))
        axis = tf.convert_to_tensor(axis, dtype=tf.int32, name="axis")
        axis_ = tf.get_static_value(axis)
        if axis_ is not None:
            axis = int(ndims + axis_ if axis_ < 0 else axis_)
        else:
            axis = tf.where(axis < 0, ndims + axis, axis)
    return tf.math.softmax(x, axis=axis)
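Usage sketch, assuming the snippet's `tensorshape_util` import is available: a negative `axis` is resolved statically here, so the workaround adds no extra graph ops in the common case.

x = tf.math.log([[1., 1., 2.]])
print(softmax(x, axis=-1))  # [[0.25, 0.25, 0.5]]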
Example 13
 def _event_shape(self, shape, static_perm_to_shape):
     """Helper for _forward and _inverse_event_shape."""
     rightmost_ = tf.get_static_value(self.rightmost_transposed_ndims)
     if tensorshape_util.rank(shape) is None or rightmost_ is None:
         return tf.TensorShape(None)
     if tensorshape_util.rank(shape) < rightmost_:
         raise ValueError(
             'Invalid shape: min event ndims={} but got {}'.format(
                 rightmost_, shape))
     perm_ = tf.get_static_value(self.perm, partial=True)
     if perm_ is None:
         return shape[:tensorshape_util.rank(shape) -
                      rightmost_].concatenate([None] * int(rightmost_))
     # We can use elimination to reidentify a single None dimension.
     if sum(p is None for p in perm_) == 1:
         present = np.argsort([-1 if p is None else p for p in perm_])
         for i, p in enumerate(present[1:]):  # The -1 sorts to position 0.
             if i != p:
                 perm_ = [i if p is None else p for p in perm_]
                 break
     return shape[:tensorshape_util.rank(shape) - rightmost_].concatenate(
         static_perm_to_shape(
             shape[tensorshape_util.rank(shape) - rightmost_:], perm_))
Example 14
    def __init__(self,
                 distribution,
                 reinterpreted_batch_ndims=None,
                 validate_args=False,
                 name=None):
        """Construct a `Independent` distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims
        which will be regarded as event dims. When `None` all but the first
        batch axis (batch axis 0) will be transferred to event dimensions
        (analogous to `tf.layers.flatten`).
      validate_args: Python `bool`.  Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      name: The name for ops managed by the distribution.
        Default value: `Independent + distribution.name`.

    Raises:
      ValueError: if `reinterpreted_batch_ndims` exceeds
        `distribution.batch_ndims`
    """
        parameters = dict(locals())
        name = name or "Independent" + distribution.name
        self._distribution = distribution
        with tf.name_scope(name) as name:
            if reinterpreted_batch_ndims is None:
                reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims(
                    distribution)
            reinterpreted_batch_ndims = tf.convert_to_tensor(
                reinterpreted_batch_ndims,
                dtype=tf.int32,
                name="reinterpreted_batch_ndims")
            self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
            self._static_reinterpreted_batch_ndims = tf.get_static_value(
                reinterpreted_batch_ndims)
            if self._static_reinterpreted_batch_ndims is not None:
                self._reinterpreted_batch_ndims = self._static_reinterpreted_batch_ndims
            super(Independent, self).__init__(
                dtype=self._distribution.dtype,
                reparameterization_type=self._distribution.
                reparameterization_type,
                validate_args=validate_args,
                allow_nan_stats=self._distribution.allow_nan_stats,
                parameters=parameters,
                name=name)
            self._runtime_assertions = self._make_runtime_assertions(
                distribution, reinterpreted_batch_ndims, validate_args)
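Usage sketch via the public API: reinterpreting the rightmost batch dimension of a [2, 3]-batch Normal as an event dimension leaves batch shape [2] and event shape [3].

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

ind = tfd.Independent(tfd.Normal(loc=tf.zeros([2, 3]), scale=1.),
                      reinterpreted_batch_ndims=1)
print(ind.batch_shape, ind.event_shape)  # (2,) (3,)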
Example 15
def broadcast_shape(x_shape, y_shape):
  """Computes the shape of a broadcast.

  When both arguments are statically-known, the broadcasted shape will be
  computed statically and returned as a `TensorShape`.  Otherwise, a rank-1
  `Tensor` will be returned.

  Args:
    x_shape: A `TensorShape` or rank-1 integer `Tensor`.  The input `Tensor` is
      broadcast against this shape.
    y_shape: A `TensorShape` or rank-1 integer `Tensor`.  The input `Tensor` is
      broadcast against this shape.

  Returns:
    shape: A `TensorShape` or rank-1 integer `Tensor` representing the
      broadcasted shape.
  """
  x_shape_static = tf.get_static_value(x_shape)
  y_shape_static = tf.get_static_value(y_shape)
  if (x_shape_static is None) or (y_shape_static is None):
    return tf.broadcast_dynamic_shape(x_shape, y_shape)

  return tf.broadcast_static_shape(
      tf.TensorShape(x_shape_static), tf.TensorShape(y_shape_static))
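Usage sketch: with fully static inputs the result is a `TensorShape` computed at trace time; if either argument were a symbolic tensor, a rank-1 `Tensor` would be returned instead.

print(broadcast_shape([3, 1], [1, 4]))  # TensorShape([3, 4])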
Example 16
def _maybe_check_valid_shape(shape, validate_args):
    """Check that a shape Tensor is int-type and otherwise sane."""
    if not dtype_util.is_integer(shape.dtype):
        raise TypeError('{} dtype ({}) should be `int`-like.'.format(
            shape, dtype_util.name(shape.dtype)))

    assertions = []

    message = '`{}` rank should be <= 1.'
    if tensorshape_util.rank(shape.shape) is not None:
        if tensorshape_util.rank(shape.shape) > 1:
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_less(tf.rank(shape),
                                    2,
                                    message=message.format(shape)))

    shape_ = tf.get_static_value(shape)

    message = '`{}` elements must have at most one `-1`.'
    if shape_ is not None:
        if sum(shape_ == -1) > 1:
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_less(tf.reduce_sum(
                tf.cast(tf.equal(shape, -1), tf.int32)),
                                    2,
                                    message=message.format(shape)))

    message = '`{}` elements must be either positive integers or `-1`.'
    if shape_ is not None:
        if np.any(shape_ < -1):
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_greater(shape,
                                       -2,
                                       message=message.format(shape)))

    return assertions
Example 17
def _get_static_predicate(pred):
  """Helper function for statically evaluating predicates in `cond`."""
  if tf.is_tensor(pred):
    pred_value = tf.get_static_value(tf.convert_to_tensor(pred))

    # TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
    # pylint: disable=protected-access
    if pred_value is None:
      pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
                                                        pred._as_tf_output())
    # pylint: enable=protected-access
    if pred_value in (0, 1, True, False):
      pred_value = bool(pred_value)

  elif pred in (0, 1, True, False):  # Accept 1/0 as valid boolean values
    # This branch also casts np.array(False), tf.EagerTensor(True), etc.
    pred_value = bool(pred)
  else:
    raise TypeError('`pred` must be a Tensor, or a Python bool, or 1 or 0. '
                    'Found instead: {}'.format(pred))
  return pred_value
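Usage sketch: eager tensors and plain Python booleans both evaluate statically; only a predicate whose value is unknown at graph-construction time leaves `pred_value` as `None`.

print(_get_static_predicate(tf.constant(True)))  # True
print(_get_static_predicate(0))                  # False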
Example 18
    def __init__(self, power=0., validate_args=False, name='power_transform'):
        """Instantiates the `PowerTransform` bijector.

    Args:
      power: Python `float` scalar indicating the transform power, i.e.,
        `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      ValueError: if `power < 0` or is not known statically.
    """
        with tf.name_scope(name) as name:
            power = tf.get_static_value(
                tf.convert_to_tensor(power, name='power'))
            if power is None or power < 0:
                raise ValueError('`power` must be a non-negative TF constant.')
            self._power = power
            super(PowerTransform, self).__init__(forward_min_event_ndims=0,
                                                 validate_args=validate_args,
                                                 name=name)
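Usage sketch via the public API: `power=0.` is the limiting case `Y = exp(X)`, and a `power` that is not a TF constant raises at construction time, per the static check above.

import tensorflow_probability as tfp
tfb = tfp.bijectors

print(tfb.PowerTransform(power=0.).forward(1.))  # e ~= 2.7183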
Example 19
 def _make_runtime_assertions(self, distribution, reinterpreted_batch_ndims,
                              validate_args):
     assertions = []
     static_reinterpreted_batch_ndims = tf.get_static_value(
         reinterpreted_batch_ndims)
     batch_ndims = tensorshape_util.rank(distribution.batch_shape)
     if batch_ndims is not None and static_reinterpreted_batch_ndims is not None:
         if static_reinterpreted_batch_ndims > batch_ndims:
             raise ValueError("reinterpreted_batch_ndims({}) cannot exceed "
                              "distribution.batch_ndims({})".format(
                                  static_reinterpreted_batch_ndims,
                                  batch_ndims))
     elif validate_args:
         assertions.append(
             assert_util.assert_less_equal(
                 reinterpreted_batch_ndims,
                 prefer_static.rank_from_shape(
                     distribution.batch_shape_tensor,
                     distribution.batch_shape),
                 message=("reinterpreted_batch_ndims cannot exceed "
                          "distribution.batch_ndims")))
     return assertions
Example 20
def _maybe_validate_rightmost_transposed_ndims(rightmost_transposed_ndims,
                                               validate_args,
                                               name=None):
    """Checks that `rightmost_transposed_ndims` is valid."""
    with tf.name_scope(name or 'maybe_validate_rightmost_transposed_ndims'):
        assertions = []
        if not dtype_util.is_integer(rightmost_transposed_ndims.dtype):
            raise TypeError(
                '`rightmost_transposed_ndims` must be integer type.')

        if tensorshape_util.rank(rightmost_transposed_ndims.shape) is not None:
            if tensorshape_util.rank(rightmost_transposed_ndims.shape) != 0:
                raise ValueError(
                    '`rightmost_transposed_ndims` must be a scalar, '
                    'saw rank: {}.'.format(
                        tensorshape_util.rank(
                            rightmost_transposed_ndims.shape)))
        elif validate_args:
            assertions += [
                assert_util.assert_rank(rightmost_transposed_ndims, 0)
            ]

        rightmost_transposed_ndims_ = tf.get_static_value(
            rightmost_transposed_ndims)
        msg = '`rightmost_transposed_ndims` must be non-negative.'
        if rightmost_transposed_ndims_ is not None:
            if rightmost_transposed_ndims_ < 0:
                raise ValueError(
                    msg[:-1] +
                    ', saw: {}.'.format(rightmost_transposed_ndims_))
        elif validate_args:
            assertions += [
                assert_util.assert_non_negative(rightmost_transposed_ndims,
                                                message=msg)
            ]

        return assertions
Example 21
    def __init__(self,
                 cat,
                 components,
                 validate_args=False,
                 allow_nan_stats=True,
                 use_static_graph=False,
                 name="Mixture"):
        """Initialize a Mixture distribution.

    A `Mixture` is defined by a `Categorical` (`cat`, representing the
    mixture probabilities) and a list of `Distribution` objects
    all having matching dtype, batch shape, event shape, and continuity
    properties (the components).

    The `num_classes` of `cat` must be possible to infer at graph construction
    time and match `len(components)`.

    Args:
      cat: A `Categorical` distribution instance, representing the probabilities
          of `components`.
      components: A list or tuple of `Distribution` instances.
        Each instance must have the same type, be defined on the same domain,
        and have matching `event_shape` and `batch_shape`.
      validate_args: Python `bool`, default `False`. If `True`, raise a runtime
        error if batch or event ranks are inconsistent between cat and any of
        the distributions. This is only checked if the ranks cannot be
        determined statically at graph construction time.
      allow_nan_stats: Boolean, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      use_static_graph: Calls to `sample` will not rely on dynamic tensor
        indexing, allowing for some static graph compilation optimizations, but
        at the expense of sampling all underlying distributions in the mixture.
        (Possibly useful when running on TPUs).
        Default value: `False` (i.e., use dynamic indexing).
      name: A name for this distribution (optional).

    Raises:
      TypeError: If cat is not a `Categorical`, or `components` is not
        a list or tuple, or the elements of `components` are not
        instances of `Distribution`, or do not have matching `dtype`.
      ValueError: If `components` is an empty list or tuple, or its
        elements do not have a statically known event rank.
        If `cat.num_classes` cannot be inferred at graph creation time,
        or the constant value of `cat.num_classes` is not equal to
        `len(components)`, or all `components` and `cat` do not have
        matching static batch shapes, or all components do not
        have matching static event shapes.
    """
        parameters = dict(locals())
        if not isinstance(cat, categorical.Categorical):
            raise TypeError(
                "cat must be a Categorical distribution, but saw: %s" % cat)
        if not components:
            raise ValueError("components must be a non-empty list or tuple")
        if not isinstance(components, (list, tuple)):
            raise TypeError("components must be a list or tuple, but saw: %s" %
                            components)
        if not all(
                isinstance(c, distribution.Distribution) for c in components):
            raise TypeError(
                "all entries in components must be Distribution instances"
                " but saw: %s" % components)

        dtype = components[0].dtype
        if not all(d.dtype == dtype for d in components):
            raise TypeError("All components must have the same dtype, but saw "
                            "dtypes: %s" % [(d.name, d.dtype)
                                            for d in components])
        static_event_shape = components[0].event_shape
        static_batch_shape = cat.batch_shape
        for di, d in enumerate(components):
            if not tensorshape_util.is_compatible_with(static_batch_shape,
                                                       d.batch_shape):
                raise ValueError(
                    "components[{}] batch shape must be compatible with cat "
                    "shape and other component batch shapes".format(di))
            static_event_shape = tensorshape_util.merge_with(
                static_event_shape, d.event_shape)
            static_batch_shape = tensorshape_util.merge_with(
                static_batch_shape, d.batch_shape)
        if tensorshape_util.rank(static_event_shape) is None:
            raise ValueError(
                "Expected to know rank(event_shape) from components, but "
                "none of the components provide a static number of ndims")

        # Ensure that all batch and event ndims are consistent.
        with tf.name_scope(name) as name:
            num_components = cat._num_categories()
            static_num_components = tf.get_static_value(num_components)
            if static_num_components is None:
                raise ValueError(
                    "Could not infer number of classes from cat and unable "
                    "to compare this value to the number of components passed in."
                )
            # Possibly convert from numpy 0-D array.
            static_num_components = int(static_num_components)
            if static_num_components != len(components):
                raise ValueError(
                    "cat.num_classes != len(components): %d vs. %d" %
                    (static_num_components, len(components)))

            cat_batch_shape = cat.batch_shape_tensor()
            cat_batch_rank = tf.size(cat_batch_shape)
            if validate_args:
                batch_shapes = [d.batch_shape_tensor() for d in components]
                batch_ranks = [tf.size(bs) for bs in batch_shapes]
                check_message = ("components[%d] batch shape must match cat "
                                 "batch shape")
                self._assertions = [
                    assert_util.assert_equal(cat_batch_rank,
                                             batch_ranks[di],
                                             message=check_message % di)
                    for di in range(len(components))
                ]
                self._assertions += [
                    assert_util.assert_equal(cat_batch_shape,
                                             batch_shapes[di],
                                             message=check_message % di)
                    for di in range(len(components))
                ]
            else:
                self._assertions = []

            self._cat = cat
            self._components = list(components)
            self._num_components = static_num_components
            self._static_event_shape = static_event_shape
            self._static_batch_shape = static_batch_shape

            self._use_static_graph = use_static_graph
            if use_static_graph and static_num_components is None:
                raise ValueError(
                    "Number of categories must be known statically when "
                    "`use_static_graph=True`.")

        super(Mixture, self).__init__(
            dtype=dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name)
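Usage sketch via the public API: a two-component Gaussian mixture in which `cat.num_classes` (here 2, statically inferable) matches `len(components)`.

import tensorflow_probability as tfp
tfd = tfp.distributions

mix = tfd.Mixture(cat=tfd.Categorical(probs=[0.3, 0.7]),
                  components=[tfd.Normal(-1., 1.), tfd.Normal(2., 0.5)])
print(mix.sample(3))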
Example 22
    def _sample_n(self, n, seed=None):
        if self._use_static_graph:
            with tf.control_dependencies(self._assertions):
                # This sampling approach is almost the same as the approach used by
                # `MixtureSameFamily`. The differences are due to having a list of
                # `Distribution` objects rather than a single object, and maintaining
                # random seed management that is consistent with the non-static code
                # path.
                samples = []
                cat_samples = self.cat.sample(n, seed=seed)
                stream = SeedStream(seed, salt="Mixture")

                for c in range(self.num_components):
                    samples.append(self.components[c].sample(n, seed=stream()))
                stack_axis = -1 - tensorshape_util.rank(
                    self._static_event_shape)
                x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
                npdt = dtype_util.as_numpy_dtype(x.dtype)
                mask = tf.one_hot(
                    indices=cat_samples,  # [n, B]
                    depth=self._num_components,  # == k
                    on_value=npdt(1),
                    off_value=npdt(0))  # [n, B, k]
                mask = distribution_util.pad_mixture_dimensions(
                    mask, self, self._cat,
                    tensorshape_util.rank(
                        self._static_event_shape))  # [n, B, k, [1]*e]
                return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

        with tf.control_dependencies(self._assertions):
            n = tf.convert_to_tensor(n, name="n")
            static_n = tf.get_static_value(n)
            n = int(static_n) if static_n is not None else n
            cat_samples = self.cat.sample(n, seed=seed)

            static_samples_shape = cat_samples.shape
            if tensorshape_util.is_fully_defined(static_samples_shape):
                samples_shape = tensorshape_util.as_list(static_samples_shape)
                samples_size = tensorshape_util.num_elements(
                    static_samples_shape)
            else:
                samples_shape = tf.shape(cat_samples)
                samples_size = tf.size(cat_samples)
            static_batch_shape = self.batch_shape
            if tensorshape_util.is_fully_defined(static_batch_shape):
                batch_shape = tensorshape_util.as_list(static_batch_shape)
                batch_size = tensorshape_util.num_elements(static_batch_shape)
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = tf.reduce_prod(batch_shape)
            static_event_shape = self.event_shape
            if tensorshape_util.is_fully_defined(static_event_shape):
                event_shape = np.array(
                    tensorshape_util.as_list(static_event_shape),
                    dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw cat sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                             samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = tf.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = tf.reshape(
                tf.tile(tf.range(0, batch_size), [n]), samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are, e.g., [0, 1, 0, 1, 0, 1].
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [1 1] and [0 0 0 0]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 0, 0, 0
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = tf.dynamic_partition(
                data=batch_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)
            samples_class = [None for _ in range(self.num_components)]

            stream = SeedStream(seed, salt="Mixture")

            for c in range(self.num_components):
                n_class = tf.size(partitioned_samples_indices[c])
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=stream())

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * tf.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = tf.reshape(
                    samples_class_c,
                    tf.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = tf.gather(samples_class_c,
                                            lookup_partitioned_batch_indices,
                                            name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            lhs_flat_ret = tf.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = tf.reshape(
                lhs_flat_ret,
                tf.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            tensorshape_util.set_shape(
                ret,
                tensorshape_util.concatenate(static_samples_shape,
                                             self.event_shape))
            return ret
Example 23
def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  s = _shape(input)
  s_ = tf.get_static_value(s)
  if s_ is not None:
    return np.zeros(s_, _numpy_dtype(dtype or input.dtype))
  return tf.zeros(s, dtype or input.dtype, name)
Example 24
    def __init__(self,
                 distribution,
                 bijector,
                 batch_shape=None,
                 event_shape=None,
                 kwargs_split_fn=_default_kwargs_split_fn,
                 validate_args=False,
                 parameters=None,
                 name=None):
        """Construct a Transformed Distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_event()`.
      kwargs_split_fn: Python `callable` which takes a kwargs `dict` and returns
        a tuple of kwargs `dict`s for each of the `distribution` and `bijector`
        parameters respectively.
        Default value: `_default_kwargs_split_fn` (i.e.,
            `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                             kwargs.get('bijector_kwargs', {}))`)
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      parameters: Locals dict captured by subclass constructor, to be used for
        copy/slice re-instantiation operations.
      name: Python `str` name prefixed to Ops created by this class. Default:
        `bijector.name + distribution.name`.
    """
        parameters = dict(locals()) if parameters is None else parameters
        name = name or (("" if bijector is None else bijector.name) +
                        (distribution.name or ""))
        with tf.name_scope(name) as name:
            self._kwargs_split_fn = (_default_kwargs_split_fn
                                     if kwargs_split_fn is None else
                                     kwargs_split_fn)
            # For convenience we define some handy constants.
            self._zero = tf.constant(0, dtype=tf.int32, name="zero")
            self._empty = tf.constant([], dtype=tf.int32, name="empty")

            # We will keep track of a static and dynamic version of
            # self._is_{batch,event}_override. This way we can do more prior to graph
            # execution, including possibly raising Python exceptions.

            self._override_batch_shape = self._maybe_validate_shape_override(
                batch_shape, distribution.is_scalar_batch(), validate_args,
                "batch_shape")
            self._is_batch_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_batch_shape),
                    self._zero))
            self._is_maybe_batch_override = bool(
                tf.get_static_value(self._override_batch_shape) is None
                or tf.get_static_value(self._override_batch_shape).size != 0)

            self._override_event_shape = self._maybe_validate_shape_override(
                event_shape, distribution.is_scalar_event(), validate_args,
                "event_shape")
            self._is_event_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_event_shape),
                    self._zero))
            self._is_maybe_event_override = bool(
                tf.get_static_value(self._override_event_shape) is None
                or tf.get_static_value(self._override_event_shape).size != 0)

            # To convert a scalar distribution into a multivariate distribution we
            # will draw dims from the sample dims, which are otherwise iid. This is
            # easy to do except in the case that the base distribution has batch dims
            # and we're overriding event shape. When that case happens the event dims
            # will incorrectly be to the left of the batch dims. In this case we'll
            # cyclically permute left the new dims.
            self._needs_rotation = prefer_static.reduce_all([
                self._is_event_override,
                prefer_static.logical_not(self._is_batch_override),
                prefer_static.logical_not(distribution.is_scalar_batch())
            ])
            override_event_ndims = prefer_static.rank_from_shape(
                self._override_event_shape)
            self._rotate_ndims = _pick_scalar_condition(
                self._needs_rotation, override_event_ndims, 0)
            # We'll be reducing the head dims (if at all), i.e., this will be []
            # if we don't need to reduce.
            self._reduce_event_indices = prefer_static.range(
                self._rotate_ndims - override_event_ndims, self._rotate_ndims)

        self._distribution = distribution
        self._bijector = bijector
        super(TransformedDistribution, self).__init__(
            dtype=self._distribution.dtype,
            reparameterization_type=self._distribution.reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=self._distribution.allow_nan_stats,
            parameters=parameters,
            name=name)
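Usage sketch via the public API: the classic log-normal construction, transforming a scalar Normal through `Exp` with no batch or event shape overrides.

import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

log_normal = tfd.TransformedDistribution(distribution=tfd.Normal(0., 1.),
                                         bijector=tfb.Exp())
print(log_normal.sample(2))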
Example 25
    def __init__(self,
                 perm=None,
                 rightmost_transposed_ndims=None,
                 validate_args=False,
                 name='transpose'):
        """Instantiates the `Transpose` bijector.

    Args:
      perm: Positive `int32` vector-shaped `Tensor` representing permutation of
        rightmost dims (for forward transformation).  Note that the `0`th index
        represents the first of the rightmost dims and the largest value must be
        `rightmost_transposed_ndims - 1` and corresponds to `tf.rank(x) - 1`.
        Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
        specified.
        Default value:
        `tf.range(start=rightmost_transposed_ndims, limit=-1, delta=-1)`.
      rightmost_transposed_ndims: Positive `int32` scalar-shaped `Tensor`
        representing the number of rightmost dimensions to permute.
        Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
        specified.
        Default value: `tf.size(perm)`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      ValueError: if both or neither `perm` and `rightmost_transposed_ndims` are
        specified.
      NotImplementedError: if `rightmost_transposed_ndims` is not known prior to
        graph execution.
    """
        with tf.name_scope(name) as name:
            if (rightmost_transposed_ndims is None) == (perm is None):
                raise ValueError('Must specify exactly one of '
                                 '`rightmost_transposed_ndims` and `perm`.')
            if rightmost_transposed_ndims is not None:
                rightmost_transposed_ndims = tf.convert_to_tensor(
                    rightmost_transposed_ndims,
                    dtype_hint=np.int32,
                    name='rightmost_transposed_ndims')
                rightmost_transposed_ndims_ = tf.get_static_value(
                    rightmost_transposed_ndims)
                assertions = _maybe_validate_rightmost_transposed_ndims(
                    rightmost_transposed_ndims, validate_args)
                if assertions:
                    with tf.control_dependencies(assertions):
                        rightmost_transposed_ndims = tf.identity(
                            rightmost_transposed_ndims)
                perm_start = (distribution_util.prefer_static_value(
                    rightmost_transposed_ndims) - 1)
                perm = tf.range(start=perm_start,
                                limit=-1,
                                delta=-1,
                                name='perm')
            else:  # perm is not None:
                perm = tf.convert_to_tensor(perm,
                                            dtype_hint=np.int32,
                                            name='perm')
                rightmost_transposed_ndims = tf.size(
                    perm, name='rightmost_transposed_ndims')
                rightmost_transposed_ndims_ = tf.get_static_value(
                    rightmost_transposed_ndims)
                assertions = _maybe_validate_perm(perm, validate_args)
                if assertions:
                    with tf.control_dependencies(assertions):
                        perm = tf.identity(perm)

            # TODO(b/110828604): If bijector base class ever supports dynamic
            # `min_event_ndims`, then this class already works dynamically and the
            # following five lines can be removed.
            if rightmost_transposed_ndims_ is None:
                raise NotImplementedError(
                    '`rightmost_transposed_ndims` must be '
                    'known prior to graph execution.')
            else:
                rightmost_transposed_ndims_ = int(rightmost_transposed_ndims_)

            self._perm = perm
            self._rightmost_transposed_ndims = rightmost_transposed_ndims
            super(Transpose, self).__init__(
                forward_min_event_ndims=rightmost_transposed_ndims_,
                is_constant_jacobian=True,
                validate_args=validate_args,
                name=name)
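Usage sketch via the public API: specifying `perm` (and therefore not `rightmost_transposed_ndims`) transposes the two rightmost dimensions.

import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Transpose(perm=[1, 0])
print(bij.forward([[1., 2.], [3., 4.]]))  # [[1., 3.], [2., 4.]]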
Example 26
def concat_vectors(*args):
    """Concatenates input vectors, statically if possible."""
    args_ = [tf.get_static_value(x) for x in args]
    if any(vec is None for vec in args_):
        return tf.concat(args, axis=0)
    return [val for vec in args_ for val in vec]  # pylint: disable=g-complex-comprehension
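Usage sketch: with all-static inputs the concatenation happens in Python and a plain list is returned; any symbolic argument falls back to `tf.concat`.

print(concat_vectors([1, 2], [3]))  # [1, 2, 3], a Python list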
Example 27
    def __init__(self,
                 df,
                 scale_operator,
                 input_output_cholesky=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name=None):
        """Construct Wishart distributions.

    Args:
      df: `float` or `double` tensor, the degrees of freedom of the
        distribution(s). `df` must be greater than or equal to `k`.
      scale_operator: `float` or `double` instance of `LinearOperator`.
      input_output_cholesky: Python `bool`. If `True`, functions whose input or
        output have the semantics of samples assume inputs are in Cholesky form
        and return outputs in Cholesky form. In particular, if this flag is
        `True`, input to `log_prob` is presumed of Cholesky form and output from
        `sample`, `mean`, and `mode` are of Cholesky form.  Setting this
        argument to `True` is purely a computational optimization and does not
        change the underlying distribution; for instance, `mean` returns the
        Cholesky of the mean, not the mean of Cholesky factors. The `variance`
        and `stddev` methods are unaffected by this flag.
        Default value: `False` (i.e., input/output does not have Cholesky
        semantics).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if scale is not floating-type
      TypeError: if scale.dtype != df.dtype
      ValueError: if df < k, where scale operator event shape is
        `(k, k)`
    """
        parameters = dict(locals())
        self._input_output_cholesky = input_output_cholesky
        with tf.name_scope(name) as name:
            with tf.name_scope("init"):
                if not dtype_util.is_floating(scale_operator.dtype):
                    raise TypeError(
                        "scale_operator.dtype=%s is not a floating-point type"
                        % scale_operator.dtype)
                if not scale_operator.is_square:
                    raise ValueError("scale_operator must be square.")

                self._scale_operator = scale_operator
                self._df = tf.convert_to_tensor(df,
                                                dtype=scale_operator.dtype,
                                                name="df")
                dtype_util.assert_same_float_dtype(
                    [self._df, self._scale_operator])
                if tf.compat.dimension_value(
                        self._scale_operator.shape[-1]) is None:
                    self._dimension = tf.cast(
                        self._scale_operator.domain_dimension_tensor(),
                        dtype=self._scale_operator.dtype,
                        name="dimension")
                else:
                    self._dimension = tf.convert_to_tensor(
                        tf.compat.dimension_value(
                            self._scale_operator.shape[-1]),
                        dtype=self._scale_operator.dtype,
                        name="dimension")
                df_val = tf.get_static_value(self._df)
                dim_val = tf.get_static_value(self._dimension)
                if df_val is not None and dim_val is not None:
                    df_val = np.asarray(df_val)
                    if not df_val.shape:
                        df_val = [df_val]
                    if np.any(df_val < dim_val):
                        raise ValueError(
                            "Degrees of freedom (df = %s) cannot be less than "
                            "dimension of scale matrix (scale.dimension = %s)"
                            % (df_val, dim_val))
                elif validate_args:
                    assertions = assert_util.assert_less_equal(
                        self._dimension,
                        self._df,
                        message=("Degrees of freedom (df = %s) cannot be "
                                 "less than dimension of scale matrix "
                                 "(scale.dimension = %s)" %
                                 (self._dimension, self._df)))
                    self._df = distribution_util.with_dependencies(
                        [assertions], self._df)
        super(_WishartLinearOperator, self).__init__(
            dtype=self._scale_operator.dtype,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            parameters=parameters,
            name=name)
Example 28
    def __init__(self,
                 initial_distribution,
                 transition_distribution,
                 observation_distribution,
                 num_steps,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HiddenMarkovModel"):
        """Initialize hidden Markov model.

    Args:
      initial_distribution: A `Categorical`-like instance.
        Determines probability of first hidden state in Markov chain.
        The number of categories must match the number of categories of
        `transition_distribution` as well as both the rightmost batch
        dimension of `transition_distribution` and the rightmost batch
        dimension of `observation_distribution`.
      transition_distribution: A `Categorical`-like instance.
        The rightmost batch dimension indexes the probability distribution
        of each hidden state conditioned on the previous hidden state.
      observation_distribution: A `tfp.distributions.Distribution`-like
        instance.  The rightmost batch dimension indexes the distribution
        of each observation conditioned on the corresponding hidden state.
      num_steps: The number of steps taken in Markov chain. A python `int`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Raises:
      ValueError: if `num_steps` is not at least 1.
      ValueError: if `initial_distribution` does not have scalar `event_shape`.
      ValueError: if `transition_distribution` does not have scalar
        `event_shape.`
      ValueError: if `transition_distribution` and `observation_distribution`
        are fully defined but don't have matching rightmost dimension.
    """

        parameters = dict(locals())

        # pylint: disable=protected-access
        with tf.name_scope(name) as name:
            self._runtime_assertions = []  # pylint: enable=protected-access

            num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.rank(num_steps),
                        0,
                        message="`num_steps` must be a scalar")
                ]
                self._runtime_assertions += [
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message="`num_steps` must be at least 1.")
                ]

            self._initial_distribution = initial_distribution
            self._observation_distribution = observation_distribution
            self._transition_distribution = transition_distribution

            if (initial_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        initial_distribution.event_shape) != 0):
                raise ValueError(
                    "`initial_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(initial_distribution.event_shape_tensor())[0],
                        0,
                        message="`initial_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (transition_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.event_shape) != 0):
                raise ValueError(
                    "`transition_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(
                            transition_distribution.event_shape_tensor())[0],
                        0,
                        message="`transition_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (transition_distribution.batch_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.batch_shape) == 0):
                raise ValueError(
                    "`transition_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(transition_distribution.batch_shape_tensor()),
                        0,
                        message="`transition_distribution` can't have scalar "
                        "batches")
                ]

            if (observation_distribution.batch_shape is not None
                    and tensorshape_util.rank(
                        observation_distribution.batch_shape) == 0):
                raise ValueError(
                    "`observation_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(observation_distribution.batch_shape_tensor()),
                        0,
                        message="`observation_distribution` can't have scalar "
                        "batches")
                ]

            # Infer number of hidden states and check consistency
            # between transitions and observations
            with tf.control_dependencies(self._runtime_assertions):
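                # Prefer the statically known rightmost batch dimension;
                # fall back to the dynamic batch-shape tensor when the
                # static shape is unavailable.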
                self._num_states = (
                    (transition_distribution.batch_shape
                     and transition_distribution.batch_shape[-1])
                    or transition_distribution.batch_shape_tensor()[-1])

                observation_states = (
                    (observation_distribution.batch_shape
                     and observation_distribution.batch_shape[-1])
                    or observation_distribution.batch_shape_tensor()[-1])

            if (tf.is_tensor(self._num_states)
                    or tf.is_tensor(observation_states)):
                if validate_args:
                    self._runtime_assertions += [
                        assert_util.assert_equal(
                            self._num_states,
                            observation_states,
                            message="`transition_distribution` and "
                            "`observation_distribution` must agree on "
                            "last dimension of batch size")
                    ]
            elif self._num_states != observation_states:
                raise ValueError("`transition_distribution` and "
                                 "`observation_distribution` must agree on "
                                 "last dimension of batch size")

            self._log_init = _extract_log_probs(self._num_states,
                                                initial_distribution)
            self._log_trans = _extract_log_probs(self._num_states,
                                                 transition_distribution)

            self._num_steps = num_steps
            self._num_states = tf.shape(self._log_init)[-1]

            self._underlying_event_rank = tf.size(
                self._observation_distribution.event_shape_tensor())

            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                self.static_event_shape = tf.TensorShape([
                    num_steps_
                ]).concatenate(self._observation_distribution.event_shape)
            else:
                self.static_event_shape = None

            with tf.control_dependencies(self._runtime_assertions):
                self.static_batch_shape = tf.broadcast_static_shape(
                    self._initial_distribution.batch_shape,
                    tf.broadcast_static_shape(
                        self._transition_distribution.batch_shape[:-1],
                        self._observation_distribution.batch_shape[:-1]))

            # pylint: disable=protected-access
            super(HiddenMarkovModel, self).__init__(
                dtype=self._observation_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
            # pylint: enable=protected-access

            self._parameters = parameters
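
A minimal usage sketch of the constructor above, following the standard TFP
docs example (assuming `tensorflow_probability` is importable):

import tensorflow_probability as tfp
tfd = tfp.distributions

# Two hidden states with Gaussian observations.
hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7)

# Draw a length-7 observation series and score it under the model.
series = hmm.sample()
log_prob = hmm.log_prob(series)
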
Example n. 29
    def __init__(self, permutation, axis=-1, validate_args=False, name=None):
        """Creates the `Permute` bijector.

    Args:
      permutation: An `int`-like vector-shaped `Tensor` representing the
        permutation to apply to the `axis` dimension of the transformed
        `Tensor`.
      axis: Scalar `int` `Tensor` representing the dimension over which to
        `tf.gather`. `axis` must be relative to the end (reading left to
        right) and thus must be negative.
        Default value: `-1` (i.e., rightmost).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str`, name given to ops managed by this object.

    Raises:
      TypeError: if `not dtype_util.is_integer(permutation.dtype)`.
      ValueError: if `permutation` does not contain exactly one of each of
        `{0, 1, ..., d-1}`.
      NotImplementedError: if `axis` is not known prior to graph execution.
      NotImplementedError: if `axis` is not negative.
    """
        with tf.name_scope(name or "permute") as name:
            axis = tf.convert_to_tensor(axis, name="axis")
            if not dtype_util.is_integer(axis.dtype):
                raise TypeError("axis.dtype ({}) should be `int`-like.".format(
                    dtype_util.name(axis.dtype)))
            permutation = tf.convert_to_tensor(permutation, name="permutation")
            if not dtype_util.is_integer(permutation.dtype):
                raise TypeError(
                    "permutation.dtype ({}) should be `int`-like.".format(
                        dtype_util.name(permutation.dtype)))
            p = tf.get_static_value(permutation)
            if p is not None:
                if set(p) != set(np.arange(p.size)):
                    raise ValueError(
                        "Permutation over `d` must contain exactly one of "
                        "each of `{0, 1, ..., d}`.")
            elif validate_args:
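                # Sorting the negated values via `top_k` and re-negating
                # yields the permutation in ascending order; a valid
                # permutation must then equal `range(d)` exactly.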
                p, _ = tf.math.top_k(-permutation,
                                     k=tf.shape(permutation)[-1],
                                     sorted=True)
                permutation = distribution_util.with_dependencies([
                    assert_util.assert_equal(
                        -p,
                        tf.range(tf.size(p)),
                        message=(
                            "Permutation over `d` must contain exactly one of "
                            "each of `{0, 1, ..., d}`.")),
                ], permutation)
            axis_ = tf.get_static_value(axis)
            if axis_ is None:
                raise NotImplementedError(
                    "`axis` must be known prior to graph "
                    "execution.")
            elif axis_ >= 0:
                raise NotImplementedError(
                    "`axis` must be relative to the rightmost "
                    "dimension, i.e., negative.")
            else:
                forward_min_event_ndims = int(np.abs(axis_))
            self._permutation = permutation
            self._axis = axis
            super(Permute, self).__init__(
                forward_min_event_ndims=forward_min_event_ndims,
                is_constant_jacobian=True,
                validate_args=validate_args,
                name=name)
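
A minimal usage sketch, assuming the public `tfp.bijectors.Permute`:

import tensorflow as tf
import tensorflow_probability as tfp

# Reverse the rightmost dimension of a rank-1 input.
reverse = tfp.bijectors.Permute(permutation=[2, 1, 0])

x = tf.constant([-1., 0., 1.])
y = reverse.forward(x)       # [1., 0., -1.]
x_back = reverse.inverse(y)  # [-1., 0., 1.]
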
Example n. 30
def _replace_event_shape_in_tensorshape(input_tensorshape, event_shape_in,
                                        event_shape_out):
    """Replaces the event shape dims of a `TensorShape`.

  Args:
    input_tensorshape: a `TensorShape` instance in which to attempt replacing
      event shape.
    event_shape_in: `Tensor` shape representing the event shape expected to
      be present in (rightmost dims of) `input_tensorshape`. Must be
      compatible with the rightmost dims of `input_tensorshape`.
    event_shape_out: `Tensor` shape representing the new event shape, i.e.,
      the replacement of `event_shape_in`.

  Returns:
    output_tensorshape: `TensorShape` with the rightmost `event_shape_in`
      replaced by `event_shape_out`. Might be partially defined, i.e.,
      `TensorShape(None)`.
    is_validated: Python `bool` indicating static validation happened.

  Raises:
    ValueError: if we can determine the event shape portion of
      `input_tensorshape` as well as `event_shape_in` both statically, and they
      are not compatible. "Compatible" here means that they are identical on
      any dims that are not -1 in `event_shape_in`.
  """
    event_shape_in_ndims = tensorshape_util.num_elements(event_shape_in.shape)
    if tensorshape_util.rank(
            input_tensorshape) is None or event_shape_in_ndims is None:
        return tf.TensorShape(None), False  # Not is_validated.

    input_non_event_ndims = tensorshape_util.rank(
        input_tensorshape) - event_shape_in_ndims
    if input_non_event_ndims < 0:
        raise ValueError(
            'Input has fewer ndims ({}) than event shape ndims ({}).'.format(
                tensorshape_util.rank(input_tensorshape),
                event_shape_in_ndims))

    input_non_event_tensorshape = input_tensorshape[:input_non_event_ndims]
    input_event_tensorshape = input_tensorshape[input_non_event_ndims:]

    # Check that `input_event_shape_` and `event_shape_in` are compatible in the
    # sense that they have equal entries in any position that isn't a `-1` in
    # `event_shape_in`. Note that our validations at construction time ensure
    # there is at most one such entry in `event_shape_in`.
    event_shape_in_ = tf.get_static_value(event_shape_in)
    is_validated = (tensorshape_util.is_fully_defined(input_event_tensorshape)
                    and event_shape_in_ is not None)
    if is_validated:
        input_event_shape_ = np.int32(input_event_tensorshape)
        mask = event_shape_in_ >= 0
        explicit_input_event_shape_ = input_event_shape_[mask]
        explicit_event_shape_in_ = event_shape_in_[mask]
        if not np.all(explicit_input_event_shape_ == explicit_event_shape_in_):
            raise ValueError(
                'Input `event_shape` does not match `event_shape_in`. '
                '({} vs {}).'.format(input_event_shape_, event_shape_in_))

    event_tensorshape_out = tensorshape_util.constant_value_as_shape(
        event_shape_out)
    if tensorshape_util.rank(event_tensorshape_out) is None:
        output_tensorshape = tf.TensorShape(None)
    else:
        output_tensorshape = tensorshape_util.concatenate(
            input_non_event_tensorshape, event_tensorshape_out)

    return output_tensorshape, is_validated
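
To illustrate the contract above with concrete shapes, here is a hypothetical
sketch using plain `TensorShape` arithmetic (the variable names are
illustrative, not from the library):

import numpy as np
import tensorflow as tf

input_tensorshape = tf.TensorShape([4, 2, 3])     # batch [4] + event [2, 3]
event_shape_in = np.array([2, 3], dtype=np.int32)
event_shape_out = np.array([6], dtype=np.int32)

# Keep the non-event (leftmost) dims and append the replacement event shape.
non_event_ndims = input_tensorshape.rank - event_shape_in.size
output = input_tensorshape[:non_event_ndims].concatenate(event_shape_out)
# output == TensorShape([4, 6]); the [2, 3] event dims were replaced by [6].
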