Beispiel #1
0
 def _event_shape(self):
     """Static event shape: sample shape followed by the inner event shape.

     Returns:
       A `tf.TensorShape`. It has unknown rank when either the static value
       of `self.sample_shape` or the wrapped distribution's `event_shape`
       has unknown rank.
     """
     static_sample_shape = tf.TensorShape(
         tf.get_static_value(self.sample_shape))
     if tensorshape_util.rank(static_sample_shape) is None:
         return tf.TensorShape(None)
     inner_event_shape = self.distribution.event_shape
     if tensorshape_util.rank(inner_event_shape) is None:
         return tf.TensorShape(None)
     return tensorshape_util.concatenate(static_sample_shape,
                                         inner_event_shape)
Beispiel #2
0
 def _event_shape(self):
     """Static event shape: reinterpreted batch dims plus inner event shape.

     Returns:
       A `tf.TensorShape`. It has unknown rank when
       `self._static_reinterpreted_batch_ndims` is `None`. When only the
       wrapped distribution's batch rank is unknown, the reinterpreted
       dimensions are reported as unknown (`None`) entries.
     """
     inner_batch_shape = self.distribution.batch_shape
     ndims = self._static_reinterpreted_batch_ndims
     if ndims is None:
         return tf.TensorShape(None)
     inner_batch_rank = tensorshape_util.rank(inner_batch_shape)
     if inner_batch_rank is None:
         # The number of absorbed dims is known even if their sizes are not.
         absorbed = tf.TensorShape([None] * int(ndims))
     else:
         # The rightmost `ndims` batch dims are treated as event dims.
         absorbed = inner_batch_shape[inner_batch_rank - ndims:]
     return tensorshape_util.concatenate(absorbed,
                                         self.distribution.event_shape)
 def _call_and_reshape_output(self,
                              fn,
                              event_shape_list=None,
                              static_event_shape_list=None,
                              extra_kwargs=None):
     """Calls `fn` and reshapes its output to batch shape plus event shape.

     Args:
       fn: Callable producing the value to reshape. Called with
         `**extra_kwargs` when `extra_kwargs` is truthy, otherwise with no
         arguments.
       event_shape_list: Optional list of shape tensors appended after
         `self._batch_shape_unexpanded` to form the target shape. Defaults
         to `[self._event_shape_tensor()]`.
       static_event_shape_list: Optional list of static shapes concatenated
         to form the static event shape. Defaults to `[self.event_shape]`.
       extra_kwargs: Optional dict of keyword arguments forwarded to `fn`.

     Returns:
       The reshaped result of `fn`, with its static shape set when both
       `self.batch_shape` and `self.event_shape` have known rank.
     """
     # `extra_kwargs` is a dict rather than `**extra_kwargs` because the
     # user-supplied kwargs could themselves contain keys named `fn`,
     # `event_shape_list`, `static_event_shape_list` or `extra_kwargs`.
     with tf.control_dependencies(self._runtime_assertions):
         if event_shape_list is None:
             event_shape_list = [self._event_shape_tensor()]
         if static_event_shape_list is None:
             static_event_shape_list = [self.event_shape]
         target_shape = tf.concat(
             [self._batch_shape_unexpanded] + event_shape_list, axis=0)
         raw_output = fn(**extra_kwargs) if extra_kwargs else fn()
         result = tf.reshape(raw_output, target_shape)
         # Static shape can only be attached when both ranks are known.
         if tensorshape_util.rank(self.batch_shape) is None:
             return result
         if tensorshape_util.rank(self.event_shape) is None:
             return result
         static_event_shape = tf.TensorShape([])
         for piece in static_event_shape_list:
             static_event_shape = tensorshape_util.concatenate(
                 static_event_shape, piece)
         tensorshape_util.set_shape(
             result,
             tensorshape_util.concatenate(self.batch_shape,
                                          static_event_shape))
         return result
Beispiel #4
0
def _size(input, out_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns the number of elements of `input`, statically when possible.

  Args:
    input: Object with a `shape` attribute, or anything `np.array` accepts.
    out_type: Dtype of the result, forwarded to `tf.size` or converted with
      `_numpy_dtype` for the static path.
    name: Optional name forwarded to `tf.size`.

  Returns:
    A NumPy value of the requested dtype when the element count is
    statically known, otherwise the result of `tf.size`.
  """
  if not hasattr(input, 'shape'):
    x = np.array(input)
    # `x.dtype` is a `np.dtype` instance, so it has to be compared by
    # equality: an identity test against the type is always False, and the
    # `np.object` alias no longer exists in recent NumPy releases.
    input = tf.convert_to_tensor(input) if x.dtype == np.object_ else x
  n = tensorshape_util.num_elements(tf.TensorShape(input.shape))
  if n is None:
    return tf.size(input, out_type=out_type, name=name)
  return np.array(n).astype(_numpy_dtype(out_type))
 def _get_final_shape(qs):
     """Builds the `TensorShape` `batch[:-1] + [last_batch_dim + 1, qs]`.

     Args:
       qs: Size of the final dimension of the returned shape.

     Returns:
       A `tf.TensorShape` derived from `dist.batch_shape` (read from the
       enclosing scope). Its last batch dimension is incremented by one
       when statically known, and left as `None` otherwise.
     """
     batch = tensorshape_util.with_rank_at_least(dist.batch_shape, 1)
     last_dim = tf.compat.dimension_value(batch[-1])
     widened = None if last_dim is None else last_dim + 1
     return batch[:-1].concatenate(tf.TensorShape([widened, qs]))
Beispiel #6
0
 def _batch_shape(self):
     """Static batch shape left after removing the reinterpreted batch dims.

     Returns:
       A `tf.TensorShape`. It has unknown rank when either
       `self._static_reinterpreted_batch_ndims` or the wrapped
       distribution's batch rank is not statically known.
     """
     inner_batch_shape = self.distribution.batch_shape
     ndims = self._static_reinterpreted_batch_ndims
     if ndims is None:
         return tf.TensorShape(None)
     inner_batch_rank = tensorshape_util.rank(inner_batch_shape)
     if inner_batch_rank is None:
         return tf.TensorShape(None)
     # Keep only the leading dims that were not reinterpreted as event dims.
     return inner_batch_shape[:inner_batch_rank - ndims]
Beispiel #7
0
def _shape(input, out_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns the shape of `input`, statically when possible.

  Args:
    input: Object with a `shape` attribute, or anything `np.array` accepts.
    out_type: Dtype of the result, forwarded to `tf.shape` or converted with
      `_numpy_dtype` for the static path.
    name: Optional name forwarded to `tf.shape`.

  Returns:
    A NumPy array of the requested dtype when the shape is fully defined,
    otherwise the result of `tf.shape`.
  """
  if not hasattr(input, 'shape'):
    x = np.array(input)
    # `x.dtype` is a `np.dtype` instance, so it has to be compared by
    # equality: an identity test against the type is always False, and the
    # `np.object` alias no longer exists in recent NumPy releases.
    input = tf.convert_to_tensor(input) if x.dtype == np.object_ else x
  input_shape = tf.TensorShape(input.shape)
  if tensorshape_util.is_fully_defined(input_shape):
    return np.array(tensorshape_util.as_list(input_shape)).astype(
        _numpy_dtype(out_type))
  return tf.shape(input, out_type=out_type, name=name)
Beispiel #8
0
def broadcast_shape(x_shape, y_shape):
  """Computes the shape obtained by broadcasting two shapes together.

  If both arguments have a static value, the broadcast is computed
  statically and returned as a `TensorShape`. Otherwise the broadcast is
  computed dynamically and a rank-1 `Tensor` is returned.

  Args:
    x_shape: A `TensorShape` or rank-1 integer `Tensor`. The input `Tensor` is
      broadcast against this shape.
    y_shape: A `TensorShape` or rank-1 integer `Tensor`. The input `Tensor` is
      broadcast against this shape.

  Returns:
    shape: A `TensorShape` or rank-1 integer `Tensor` representing the
      broadcasted shape.
  """
  static_shapes = [tf.get_static_value(s) for s in (x_shape, y_shape)]
  if any(s is None for s in static_shapes):
    # At least one shape is only known at graph execution time.
    return tf.broadcast_dynamic_shape(x_shape, y_shape)

  x_static, y_static = static_shapes
  return tf.broadcast_static_shape(
      tf.TensorShape(x_static), tf.TensorShape(y_static))
    def event_shape(self):
        """Shape of a single sample from a single batch as a `TensorShape`.

        May be partially defined or unknown.

        Returns:
          event_shape: `None` when no distributions have been built yet.
            Otherwise the structure produced by `self._model_unflatten`
            from the `event_shape` of each most recently built
            distribution, with an unknown `TensorShape` standing in for
            any entry that is `None`.
        """
        # The following cannot leak graph Tensors in eager because
        # `event_shape` is a `TensorShape`.
        built = self._most_recently_built_distributions
        if built is None:
            return None
        return self._model_unflatten(
            d.event_shape if d is not None else tf.TensorShape(None)
            for d in built)
    def batch_shape(self):
        """Shape of a single sample from a single event index as a `TensorShape`.

        May be partially defined or unknown.

        The batch dimensions are indexes into independent, non-identical
        parameterizations of this distribution.

        Returns:
          batch_shape: `None` when no distributions have been built yet.
            Otherwise the structure produced by `self._model_unflatten`
            from the `batch_shape` of each most recently built
            distribution, with an unknown `TensorShape` standing in for
            any entry that is `None`.
        """
        # The following cannot leak graph Tensors in eager because
        # `batch_shape` is a `TensorShape`.
        built = self._most_recently_built_distributions
        if built is None:
            return None
        return self._model_unflatten(
            d.batch_shape if d is not None else tf.TensorShape(None)
            for d in built)
 def _sample_shape(self, x):
     """Computes the graph and static `sample_shape` of `x`.

     The sample dims are the leading dims of `x` left after removing the
     batch and event dims.

     Args:
       x: Value exposing a `shape` attribute and accepted by `tf.rank` and
         `tf.shape`.

     Returns:
       A pair `(sample_shape, static_sample_shape)`. `sample_shape` is a
       NumPy int32 value when the static sample shape is fully defined and
       a slice of `tf.shape(x)` otherwise. `static_sample_shape` is a
       `TensorShape`, of unknown rank when the number of sample dims is not
       a Python `int`.
     """
     # Prefer statically known ranks and fall back to graph ops otherwise.
     x_ndims = tensorshape_util.rank(x.shape)
     if x_ndims is None:
         x_ndims = tf.rank(x)
     event_ndims = tensorshape_util.rank(self.event_shape)
     if event_ndims is None:
         event_ndims = tf.size(self.event_shape_tensor())
     batch_ndims = tensorshape_util.rank(self.batch_shape)
     if batch_ndims is None:
         batch_ndims = tf.size(self._batch_shape_unexpanded)

     sample_ndims = x_ndims - batch_ndims - event_ndims
     static_sample_shape = (x.shape[:sample_ndims]
                            if isinstance(sample_ndims, int) else
                            tf.TensorShape(None))
     if tensorshape_util.is_fully_defined(static_sample_shape):
         return np.int32(static_sample_shape), static_sample_shape
     return tf.shape(x)[:sample_ndims], static_sample_shape
def determine_batch_event_shapes(grid, endpoint_affine):
    """Infers static and dynamic batch and event shapes.

    Args:
      grid: Tensor whose leading dims (all but the last two) seed the batch
        shape; per the inline notes it is shaped `[B, k, q]`.
      endpoint_affine: Iterable of objects exposing `shift` and `scale`
        attributes, either of which may be `None`; per the inline notes it
        has length `k` and each entry is shaped `[B, d, d]`.

    Returns:
      A tuple `(batch_shape, batch_shape_tensor, event_shape,
      event_shape_tensor)`. The two event entries stay `None` when no
      element of `endpoint_affine` has a `shift` or a `scale`.
    """
    with tf.name_scope("determine_batch_event_shapes"):
        # grid  # shape: [B, k, q]
        # endpoint_affine     # len=k, shape: [B, d, d]
        batch_shape = grid.shape[:-2]
        batch_shape_tensor = tf.shape(grid)[:-2]
        event_shape = None
        event_shape_tensor = None

        def _merge_event(current, current_tensor, shape, shape_tensor):
            """Adopts the first event shape seen; broadcasts later ones."""
            if current is None:
                return shape, shape_tensor
            return (tf.broadcast_static_shape(current, shape),
                    tf.broadcast_dynamic_shape(current_tensor, shape_tensor))

        for aff in endpoint_affine:
            shift = aff.shift
            if shift is not None:
                # Everything but the last dim of `shift` is batch.
                batch_shape = tf.broadcast_static_shape(
                    batch_shape, shift.shape[:-1])
                batch_shape_tensor = tf.broadcast_dynamic_shape(
                    batch_shape_tensor,
                    tf.shape(shift)[:-1])
                event_shape, event_shape_tensor = _merge_event(
                    event_shape, event_shape_tensor,
                    shift.shape[-1:],
                    tf.shape(shift)[-1:])

            scale = aff.scale
            if scale is not None:
                batch_shape = tf.broadcast_static_shape(
                    batch_shape, scale.batch_shape)
                batch_shape_tensor = tf.broadcast_dynamic_shape(
                    batch_shape_tensor, scale.batch_shape_tensor())
                event_shape, event_shape_tensor = _merge_event(
                    event_shape, event_shape_tensor,
                    tf.TensorShape([scale.range_dimension]),
                    scale.range_dimension_tensor()[tf.newaxis])

        return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor
Beispiel #13
0
 def _event_shape(self, shape, static_perm_to_shape):
     """Helper for _forward and _inverse_event_shape.

     Args:
       shape: `TensorShape` whose rightmost dims are to be permuted.
       static_perm_to_shape: Callable `(subshape, perm) -> TensorShape`
         applied to the rightmost dims and the (possibly partially known)
         static permutation.

     Returns:
       A `TensorShape`: the untouched leading dims of `shape` followed by
       the permuted rightmost dims. It has unknown rank when the rank of
       `shape` or the number of transposed dims is not statically known.
       The rightmost dims are reported as unknown when `perm` has no static
       value.

     Raises:
       ValueError: if `shape` has fewer dims than the statically known
         number of rightmost transposed dims.
     """
     rightmost_ = tf.get_static_value(self.rightmost_transposed_ndims)
     shape_rank = tensorshape_util.rank(shape)
     if shape_rank is None or rightmost_ is None:
         return tf.TensorShape(None)
     if shape_rank < rightmost_:
         raise ValueError(
             'Invalid shape: min event ndims={} but got {}'.format(
                 rightmost_, shape))
     perm_ = tf.get_static_value(self.perm, partial=True)
     if perm_ is None:
         return shape[:shape_rank - rightmost_].concatenate(
             [None] * int(rightmost_))
     # We can use elimination to reidentify a single None dimension: the
     # unknown entry must be the one axis of `range(len(perm_))` that no
     # other entry names. Looking the missing axis up directly is correct
     # for any ordering of the known entries, including when the missing
     # axis is the largest one.
     if sum(p is None for p in perm_) == 1:
         known_axes = {int(p) for p in perm_ if p is not None}
         missing_axes = [
             axis for axis in range(len(perm_)) if axis not in known_axes
         ]
         # Only fill in when the known entries are themselves consistent.
         if len(missing_axes) == 1:
             perm_ = [missing_axes[0] if p is None else p for p in perm_]
     return shape[:shape_rank - rightmost_].concatenate(
         static_perm_to_shape(shape[shape_rank - rightmost_:], perm_))
Beispiel #14
0
 def static_perm_to_shape(subshp, perm):
     """Scatters the dims of `subshp` to the positions named by `perm`.

     Entry `i` of `subshp` is written to output position `perm[i]`.
     Positions that no known entry of `perm` points at remain `None`.

     Args:
       subshp: Indexable shape whose dims are being rearranged.
       perm: Sequence of target positions; entries may be `None`.

     Returns:
       A `tf.TensorShape` with `len(perm)` dims.
     """
     dims = [None for _ in perm]
     for src_axis, dst_axis in enumerate(perm):
         if dst_axis is None:
             continue
         dims[dst_axis] = subshp[src_axis]
     return tf.TensorShape(dims)
Beispiel #15
0
 def static_perm_to_shape(subshp, perm):
     """Gathers the dims of `subshp` in the order given by `perm`.

     Args:
       subshp: Indexable shape whose dims are being rearranged.
       perm: Sequence of source positions; a `None` entry yields an unknown
         (`None`) dim in the output.

     Returns:
       A `tf.TensorShape` with `len(perm)` dims.
     """
     dims = []
     for axis in perm:
         dims.append(subshp[axis] if axis is not None else None)
     return tf.TensorShape(dims)
Beispiel #16
0
 def _event_shape(self):
     """Rank-1 static event shape holding the total number of event elements.

     Returns:
       `tf.TensorShape([total])`, where `total` sums the element counts of
       every event shape in `self._distribution.event_shape`. Returns
       `tf.TensorShape([None])` when any of those counts is unknown.
     """
     sizes = tf.nest.flatten(
         tf.nest.map_structure(tensorshape_util.num_elements,
                               self._distribution.event_shape))
     if any(size is None for size in sizes):
         return tf.TensorShape([None])
     return tf.TensorShape([sum(sizes)])
Beispiel #17
0
 def _batch_shape(self):
     """Merges every cached component batch shape into one static shape.

     Returns:
       The result of folding `tensorshape_util.merge_with` over the
       flattened `self._cached_batch_shape`, starting from a `TensorShape`
       of unknown rank.
     """
     merged = tf.TensorShape(None)
     for component_shape in tf.nest.flatten(self._cached_batch_shape):
         merged = tensorshape_util.merge_with(merged, component_shape)
     return merged
Beispiel #18
0
def _replace_event_shape_in_tensorshape(input_tensorshape, event_shape_in,
                                        event_shape_out):
    """Replaces the event shape dims of a `TensorShape`.

    Args:
      input_tensorshape: a `TensorShape` instance in which to attempt
        replacing event shape.
      event_shape_in: `Tensor` shape representing the event shape expected to
        be present in (rightmost dims of) `input_tensorshape`. Must be
        compatible with the rightmost dims of `input_tensorshape`.
      event_shape_out: `Tensor` shape representing the new event shape, i.e.,
        the replacement of `event_shape_in`.

    Returns:
      output_tensorshape: `TensorShape` with the rightmost `event_shape_in`
        replaced by `event_shape_out`. Might be partially defined, i.e.,
        `TensorShape(None)`.
      is_validated: Python `bool` indicating static validation happened.

    Raises:
      ValueError: if `input_tensorshape` has fewer dims than
        `event_shape_in` has entries.
      ValueError: if we can determine the event shape portion of
        `input_tensorshape` as well as `event_shape_in` both statically, and
        they are not compatible. "Compatible" here means that they are
        identical on any dims that are not -1 in `event_shape_in`.
    """
    num_event_dims_in = tensorshape_util.num_elements(event_shape_in.shape)
    input_rank = tensorshape_util.rank(input_tensorshape)
    if input_rank is None or num_event_dims_in is None:
        return tf.TensorShape(None), False  # Not is_validated.

    num_leading_dims = input_rank - num_event_dims_in
    if num_leading_dims < 0:
        raise ValueError(
            'Input has fewer ndims ({}) than event shape ndims ({}).'.format(
                input_rank, num_event_dims_in))

    leading_tensorshape = input_tensorshape[:num_leading_dims]
    trailing_tensorshape = input_tensorshape[num_leading_dims:]

    # Check that the trailing dims of the input and `event_shape_in` agree
    # on every position that isn't a `-1` in `event_shape_in`. Validation at
    # construction time ensures `event_shape_in` has at most one such entry.
    static_event_shape_in = tf.get_static_value(event_shape_in)
    is_validated = (tensorshape_util.is_fully_defined(trailing_tensorshape)
                    and static_event_shape_in is not None)
    if is_validated:
        actual_event_shape = np.int32(trailing_tensorshape)
        explicit = static_event_shape_in >= 0
        if not np.all(
                actual_event_shape[explicit] ==
                static_event_shape_in[explicit]):
            raise ValueError(
                'Input `event_shape` does not match `event_shape_in`. '
                '({} vs {}).'.format(actual_event_shape,
                                     static_event_shape_in))

    static_event_shape_out = tensorshape_util.constant_value_as_shape(
        event_shape_out)
    if tensorshape_util.rank(static_event_shape_out) is None:
        return tf.TensorShape(None), is_validated
    return (tensorshape_util.concatenate(leading_tensorshape,
                                         static_event_shape_out),
            is_validated)
 def _batch_shape(self):
   """Static batch shape implied by `total_count` and `concentration`.

   Returns:
     The static broadcast of the `_total_count` shape (with a trailing
     size-1 dim appended) against the `_concentration` shape, forced to
     rank at least 1 and with its last dim dropped.
   """
   total_count_shape = tf.TensorShape(
       self._total_count.shape).concatenate([1])
   concentration_shape = tf.TensorShape(self._concentration.shape)
   broadcast = tf.broadcast_static_shape(total_count_shape,
                                         concentration_shape)
   return tensorshape_util.with_rank_at_least(broadcast, 1)[:-1]
Beispiel #20
0
 def _event_shape(self):
     """Static event shape: square in the scale operator's domain dimension.

     Returns:
       `tf.TensorShape([d, d])` with
       `d = self.scale_operator.domain_dimension`.
     """
     dim = self.scale_operator.domain_dimension
     return tf.TensorShape([dim] * 2)
 def _entropy(self, **kwargs):
     """Entropy of the wrapped distribution via `_call_and_reshape_output`.

     Args:
       **kwargs: Forwarded to `self.distribution.entropy` as `extra_kwargs`.

     Returns:
       Whatever `_call_and_reshape_output` returns for the entropy call.
     """
     # An empty dynamic event-shape list and a scalar static event shape are
     # passed so that no event dims are appended to the target shape.
     scalar_event_shape = tf.TensorShape([])
     return self._call_and_reshape_output(
         self.distribution.entropy,
         event_shape_list=[],
         static_event_shape_list=[scalar_event_shape],
         extra_kwargs=kwargs)
Beispiel #22
0
 def _batch_shape(self):
   """Static batch shape: dims of `samples` before the samples axis.

   Returns:
     `self.samples.shape[:self._samples_axis]`, or a `TensorShape` of
     unknown rank when the rank of `samples` is not statically known.
   """
   samples_shape = self.samples.shape
   if samples_shape.rank is None:
     return tf.TensorShape(None)
   return samples_shape[:self._samples_axis]
Beispiel #23
0
 def _event_shape(self):
   """Static event shape: dims of `samples` after the samples axis.

   Returns:
     `self.samples.shape[self._samples_axis + 1:]`, or a `TensorShape` of
     unknown rank when the rank of `samples` is not statically known.
   """
   samples_shape = self.samples.shape
   if samples_shape.rank is None:
     return tf.TensorShape(None)
   first_event_axis = self._samples_axis + 1
   return samples_shape[first_event_axis:]
    def __init__(self,
                 initial_distribution,
                 transition_distribution,
                 observation_distribution,
                 num_steps,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HiddenMarkovModel"):
        """Initialize hidden Markov model.

        Args:
          initial_distribution: A `Categorical`-like instance.
            Determines probability of first hidden state in Markov chain.
            The number of categories must match the number of categories of
            `transition_distribution` as well as both the rightmost batch
            dimension of `transition_distribution` and the rightmost batch
            dimension of `observation_distribution`.
          transition_distribution: A `Categorical`-like instance.
            The rightmost batch dimension indexes the probability distribution
            of each hidden state conditioned on the previous hidden state.
          observation_distribution: A `tfp.distributions.Distribution`-like
            instance.  The rightmost batch dimension indexes the distribution
            of each observation conditioned on the corresponding hidden state.
          num_steps: The number of steps taken in Markov chain. A python
            `int`; it is converted to a `Tensor` internally.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
            Default value: `False`.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
            Default value: `True`.
          name: Python `str` name prefixed to Ops created by this class.
            Default value: "HiddenMarkovModel".

        Raises:
          ValueError: if `initial_distribution` is statically known to have
            a non-scalar `event_shape`.
          ValueError: if `transition_distribution` is statically known to
            have a non-scalar `event_shape`.
          ValueError: if `transition_distribution` or
            `observation_distribution` is statically known to have a scalar
            `batch_shape`.
          ValueError: if `transition_distribution` and
            `observation_distribution` have statically known but different
            rightmost batch dimensions.

        Properties that cannot be decided statically (including `num_steps`
        being a scalar that is at least 1) are only checked through runtime
        assertions, and only when `validate_args` is `True`.
        """

        # Snapshot the constructor arguments before any other local is bound.
        parameters = dict(locals())

        def _static_rank(shape):
            """Returns the statically known rank of `shape`, else `None`."""
            return None if shape is None else tensorshape_util.rank(shape)

        # pylint: disable=protected-access
        with tf.name_scope(name) as name:
            self._runtime_assertions = []  # pylint: enable=protected-access

            num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.rank(num_steps),
                        0,
                        message="`num_steps` must be a scalar")
                ]
                self._runtime_assertions += [
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message="`num_steps` must be at least 1.")
                ]

            self._initial_distribution = initial_distribution
            self._observation_distribution = observation_distribution
            self._transition_distribution = transition_distribution

            # Only reject statically when the rank is actually known. An
            # unknown rank (`None`) is not evidence of a non-scalar event and
            # is deferred to the runtime assertion below.
            initial_event_rank = _static_rank(
                initial_distribution.event_shape)
            if initial_event_rank is not None and initial_event_rank != 0:
                raise ValueError(
                    "`initial_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(initial_distribution.event_shape_tensor())[0],
                        0,
                        message="`initial_distribution` must have scalar "
                        "`event_dim`s")
                ]

            transition_event_rank = _static_rank(
                transition_distribution.event_shape)
            if (transition_event_rank is not None
                    and transition_event_rank != 0):
                raise ValueError(
                    "`transition_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(
                            transition_distribution.event_shape_tensor())[0],
                        0,
                        message="`transition_distribution` must have scalar "
                        "`event_dim`s")
                ]

            # An unknown batch rank compares unequal to 0, so it falls
            # through to the runtime assertion.
            if _static_rank(transition_distribution.batch_shape) == 0:
                raise ValueError(
                    "`transition_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(transition_distribution.batch_shape_tensor()),
                        0,
                        message="`transition_distribution` can't have scalar "
                        "batches")
                ]

            if _static_rank(observation_distribution.batch_shape) == 0:
                raise ValueError(
                    "`observation_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(observation_distribution.batch_shape_tensor()),
                        0,
                        message="`observation_distribution` can't have scalar "
                        "batches")
                ]

            # Infer number of hidden states and check consistency
            # between transitions and observations. The static rightmost
            # batch dim is preferred; the dynamic one is used when the
            # static value is unavailable (falsy).
            with tf.control_dependencies(self._runtime_assertions):
                self._num_states = (
                    (transition_distribution.batch_shape
                     and transition_distribution.batch_shape[-1])
                    or transition_distribution.batch_shape_tensor()[-1])

                observation_states = (
                    (observation_distribution.batch_shape
                     and observation_distribution.batch_shape[-1])
                    or observation_distribution.batch_shape_tensor()[-1])

            if (tf.is_tensor(self._num_states)
                    or tf.is_tensor(observation_states)):
                if validate_args:
                    self._runtime_assertions += [
                        assert_util.assert_equal(
                            self._num_states,
                            observation_states,
                            message="`transition_distribution` and "
                            "`observation_distribution` must agree on "
                            "last dimension of batch size")
                    ]
            elif self._num_states != observation_states:
                raise ValueError("`transition_distribution` and "
                                 "`observation_distribution` must agree on "
                                 "last dimension of batch size")

            self._log_init = _extract_log_probs(self._num_states,
                                                initial_distribution)
            self._log_trans = _extract_log_probs(self._num_states,
                                                 transition_distribution)

            self._num_steps = num_steps
            # From here on the number of states is read off the extracted
            # initial log-probabilities.
            self._num_states = tf.shape(self._log_init)[-1]

            self._underlying_event_rank = tf.size(
                self._observation_distribution.event_shape_tensor())

            # The static event shape is `[num_steps] + observation event
            # shape`, available only when `num_steps` is statically known.
            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                self.static_event_shape = tf.TensorShape([
                    num_steps_
                ]).concatenate(self._observation_distribution.event_shape)
            else:
                self.static_event_shape = None

            with tf.control_dependencies(self._runtime_assertions):
                self.static_batch_shape = tf.broadcast_static_shape(
                    self._initial_distribution.batch_shape,
                    tf.broadcast_static_shape(
                        self._transition_distribution.batch_shape[:-1],
                        self._observation_distribution.batch_shape[:-1]))

            # pylint: disable=protected-access
            super(HiddenMarkovModel, self).__init__(
                dtype=self._observation_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
            # pylint: enable=protected-access

            self._parameters = parameters
Beispiel #25
0
 def _event_shape(self):
     """Static event shape: always the scalar (rank-0) `TensorShape`."""
     scalar_shape = tf.TensorShape([])
     return scalar_shape
Beispiel #26
0
 def _event_shape(self):
     """Static event shape: `[dimension, dimension]`.

     Returns:
       `tf.TensorShape([d, d])` with `d = self.dimension`.
     """
     dim = self.dimension
     return tf.TensorShape([dim, dim])