Example #1
def lu_reconstruct_assertions(lower_upper, perm, validate_args):
    """Returns list of assertions related to `lu_reconstruct` assumptions."""
    assertions = []

    message = 'Input `lower_upper` must have at least 2 dimensions.'
    if tensorshape_util.rank(lower_upper.shape) is not None:
        if tensorshape_util.rank(lower_upper.shape) < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(lower_upper,
                                             rank=2,
                                             message=message))

    message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
    if (tensorshape_util.rank(lower_upper.shape) is not None
            and tensorshape_util.rank(perm.shape) is not None):
        if (tensorshape_util.rank(lower_upper.shape) !=
                tensorshape_util.rank(perm.shape) + 1):
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank(lower_upper,
                                    rank=tf.rank(perm) + 1,
                                    message=message))

    message = '`lower_upper` must be square.'
    if tensorshape_util.is_fully_defined(lower_upper.shape[-2:]):
        if lower_upper.shape[-2] != lower_upper.shape[-1]:
            raise ValueError(message)
    elif validate_args:
        m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
        assertions.append(assert_util.assert_equal(m, n, message=message))

    return assertions
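
A minimal usage sketch (hypothetical caller; assumes TensorFlow 2.x): the returned list is meant to gate downstream ops via `tf.control_dependencies`, so static failures raise immediately while graph-time checks only run when `validate_args` is `True`.

import tensorflow as tf

def checked_reconstruct_inputs(lower_upper, perm, validate_args=False):
    # Hypothetical caller: run the returned checks before touching the inputs.
    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    with tf.control_dependencies(assertions):
        return tf.identity(lower_upper), tf.identity(perm)
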
Example #2
    def _make_columnar(self, x):
        """Ensures non-scalar input has at least one column.

    Example:
      If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]`.

      If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.

      If `x = 1` then the output is unchanged.

    Args:
      x: `Tensor`.

    Returns:
      columnar_x: `Tensor` with at least two dimensions.
    """
        if tensorshape_util.rank(x.shape) is not None:
            if tensorshape_util.rank(x.shape) == 1:
                x = x[tf.newaxis, :]
            return x
        shape = tf.shape(x)
        maybe_expanded_shape = tf.concat([
            shape[:-1],
            distribution_util.pick_vector(tf.equal(tf.rank(x), 1), [1],
                                          np.array([], dtype=np.int32)),
            shape[-1:],
        ], 0)
        return tf.reshape(x, maybe_expanded_shape)
    def _pad_sample_dims(self, x):
        with tf.name_scope("pad_sample_dims"):
            ndims = (tensorshape_util.rank(x.shape)
                     if tensorshape_util.rank(x.shape) is not None
                     else tf.rank(x))
            shape = tf.shape(x)
            d = ndims - self._event_ndims
            x = tf.reshape(
                x, shape=tf.concat([shape[:d], [1], shape[d:]], axis=0))
            return x
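
For intuition, a standalone sketch of the same reshape (hypothetical helper; assumes `event_ndims` is a plain Python int): it inserts a length-1 axis immediately to the left of the event dimensions.

import tensorflow as tf

def pad_sample_dims(x, event_ndims):
    # Illustrative sketch: insert a singleton axis just before the event dims,
    # e.g. [sample, batch, event] -> [sample, batch, 1, event].
    shape = tf.shape(x)
    d = tf.rank(x) - event_ndims
    return tf.reshape(x, tf.concat([shape[:d], [1], shape[d:]], axis=0))

print(pad_sample_dims(tf.zeros([4, 3, 2]), event_ndims=1).shape)  # (4, 3, 1, 2)
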
def softmax(x, axis, name=None):
    """Equivalent to tf.math.softmax but works around b/70297725."""
    with tf.name_scope(name or "softmax"):
        x = tf.convert_to_tensor(x, name="x")
        ndims = (tensorshape_util.rank(x.shape) if tensorshape_util.rank(
            x.shape) is not None else tf.rank(x, name="ndims"))
        axis = tf.convert_to_tensor(axis, dtype=tf.int32, name="axis")
        axis_ = tf.get_static_value(axis)
        if axis_ is not None:
            axis = int(ndims + axis_ if axis_ < 0 else axis_)
        else:
            axis = tf.where(axis < 0, ndims + axis, axis)
    return tf.math.softmax(x, axis=axis)
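
A hedged usage note: with a negative `axis` and a statically known rank, the wrapper resolves the axis to a Python int before delegating, so the result matches `tf.math.softmax` directly.

import tensorflow as tf

# Illustrative check of the wrapper defined above (hypothetical values).
logits = tf.constant([[1., 2., 3.], [1., 1., 1.]])
tf.debugging.assert_near(softmax(logits, axis=-1),
                         tf.math.softmax(logits, axis=-1))
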
    def _log_prob(self, x):
        # By convention, we always put the grid points right-most.
        y = tf.stack([aff.inverse(x) for aff in self.interpolated_affine],
                     axis=-1)
        log_prob = tf.reduce_sum(self.distribution.log_prob(y), axis=-2)
        # Because the affine transformation has a constant Jacobian, it is the
        # case that `affine.fldj(x) = -affine.ildj(x)`. This is not true in
        # general.
        fldj = tf.stack([
            aff.forward_log_det_jacobian(
                x, event_ndims=tf.rank(self.event_shape_tensor()))
            for aff in self.interpolated_affine
        ], axis=-1)
        return tf.reduce_logsumexp(
            self.mixture_distribution.logits - fldj + log_prob, axis=-1)
Example #6
def _broadcast_event_and_samples(event, samples, event_ndims):
  """Broadcasts the event or samples."""
  # This is the shape of self.samples, without the samples axis, i.e. the shape
  # of the result of a call to dist.sample(). This way we can broadcast it with
  # event to get a properly-sized event, then add the singleton dim back at
  # -event_ndims - 1.
  samples_shape = tf.concat(
      [
          tf.shape(samples)[:-event_ndims - 1],
          tf.shape(samples)[tf.rank(samples) - event_ndims:]
      ],
      axis=0)
  event = event * tf.ones(samples_shape, dtype=event.dtype)
  event = tf.expand_dims(event, axis=-event_ndims - 1)
  samples = samples * tf.ones_like(event, dtype=samples.dtype)

  return event, samples
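
A shape-level illustration with hypothetical sizes: for `samples` holding 5 draws of a 3-vector per batch member of a [7] batch, the event is broadcast against the batch and given a singleton draws axis.

import tensorflow as tf

samples = tf.zeros([7, 5, 3])   # [batch, num_draws, event]
event = tf.zeros([3])
event_b, samples_b = _broadcast_event_and_samples(event, samples, event_ndims=1)
print(event_b.shape, samples_b.shape)   # (7, 1, 3) (7, 5, 3)
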
Example #7
    def _prob(self, x):
        if self.validate_args:
            is_vector_check = assert_util.assert_rank_at_least(x, 1)
            right_vec_space_check = assert_util.assert_equal(
                self.event_shape_tensor(),
                tf.gather(tf.shape(x), tf.rank(x) - 1),
                message=("Argument 'x' not defined in the same space R^k as "
                         "this distribution"))
            with tf.control_dependencies([is_vector_check]):
                with tf.control_dependencies([right_vec_space_check]):
                    x = tf.identity(x)
        loc = tf.convert_to_tensor(self.loc)
        return tf.cast(
            tf.reduce_all(tf.abs(x - loc) <= self._slack(loc), axis=-1),
            dtype=self.dtype)
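
For intuition, the same indicator computation in isolation (hypothetical values, with the slack fixed at 0.5 instead of `self._slack(loc)`):

import tensorflow as tf

loc = tf.constant([0., 1.])
x = tf.constant([[0.2, 1.4],
                 [0.2, 1.6]])
# 1 where every component of x lies within the slack of loc, else 0.
prob = tf.cast(tf.reduce_all(tf.abs(x - loc) <= 0.5, axis=-1), tf.float32)
print(prob.numpy())   # [1. 0.]
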
Example #8
def _maybe_check_valid_shape(shape, validate_args):
    """Check that a shape Tensor is int-type and otherwise sane."""
    if not dtype_util.is_integer(shape.dtype):
        raise TypeError('{} dtype ({}) should be `int`-like.'.format(
            shape, dtype_util.name(shape.dtype)))

    assertions = []

    message = '`{}` rank should be <= 1.'
    if tensorshape_util.rank(shape.shape) is not None:
        if tensorshape_util.rank(shape.shape) > 1:
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_less(tf.rank(shape),
                                    2,
                                    message=message.format(shape)))

    shape_ = tf.get_static_value(shape)

    message = '`{}` elements must have at most one `-1`.'
    if shape_ is not None:
        if sum(shape_ == -1) > 1:
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_less(tf.reduce_sum(
                tf.cast(tf.equal(shape, -1), tf.int32)),
                                    2,
                                    message=message.format(shape)))

    message = '`{}` elements must be either positive integers or `-1`.'
    if shape_ is not None:
        if np.any(shape_ < -1):
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_greater(shape,
                                       -2,
                                       message=message.format(shape)))

    return assertions
    def _sample_shape(self, x):
        """Computes graph and static `sample_shape`."""
        x_ndims = (tf.rank(x) if tensorshape_util.rank(x.shape) is None
                   else tensorshape_util.rank(x.shape))
        event_ndims = (tf.size(self.event_shape_tensor())
                       if tensorshape_util.rank(self.event_shape) is None
                       else tensorshape_util.rank(self.event_shape))
        batch_ndims = (tf.size(self._batch_shape_unexpanded)
                       if tensorshape_util.rank(self.batch_shape) is None
                       else tensorshape_util.rank(self.batch_shape))
        sample_ndims = x_ndims - batch_ndims - event_ndims
        if isinstance(sample_ndims, int):
            static_sample_shape = x.shape[:sample_ndims]
        else:
            static_sample_shape = tf.TensorShape(None)
        if tensorshape_util.is_fully_defined(static_sample_shape):
            sample_shape = np.int32(static_sample_shape)
        else:
            sample_shape = tf.shape(x)[:sample_ndims]
        return sample_shape, static_sample_shape
Example #10
    def _forward_log_det_jacobian(self, x, **kwargs):
        x = tf.convert_to_tensor(x, name="x")

        fldj = tf.cast(0., dtype=dtype_util.base_dtype(x.dtype))

        if not self.bijectors:
            return fldj

        event_ndims = self._maybe_get_static_event_ndims(
            self.forward_min_event_ndims)

        if _use_static_shape(x, event_ndims):
            event_shape = x.shape[tensorshape_util.rank(x.shape) -
                                  event_ndims:]
        else:
            event_shape = tf.shape(x)[tf.rank(x) - event_ndims:]

        # TODO(b/129973548): Document and simplify.
        for b in reversed(self.bijectors):
            fldj = fldj + b.forward_log_det_jacobian(
                x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
            if _use_static_shape(x, event_ndims):
                event_shape = b.forward_event_shape(event_shape)
                event_ndims = self._maybe_get_static_event_ndims(
                    tensorshape_util.rank(event_shape))
            else:
                event_shape = b.forward_event_shape_tensor(event_shape)
                event_shape_ = distribution_util.maybe_get_static_value(
                    event_shape)
                event_ndims = tf.size(event_shape)
                event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)

                if event_ndims_ is not None and event_shape_ is not None:
                    event_ndims = event_ndims_
                    event_shape = event_shape_

            x = b.forward(x, **kwargs.get(b.name, {}))

        return fldj
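
The loop above implements the composition rule: for `Chain([b2, b1])`, `fldj(x) = b1.fldj(x) + b2.fldj(b1.forward(x))`, with `event_ndims` and `event_shape` re-derived after every bijector. A hedged numeric check using two standard bijectors (assumes `tensorflow_probability` is importable as `tfp`):

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

chain = tfb.Chain([tfb.Exp(), tfb.Softplus()])   # forward: exp(softplus(x))
x = tf.constant(0.3)
expected = (tfb.Softplus().forward_log_det_jacobian(x, event_ndims=0) +
            tfb.Exp().forward_log_det_jacobian(tf.math.softplus(x), event_ndims=0))
tf.debugging.assert_near(chain.forward_log_det_jacobian(x, event_ndims=0), expected)
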
Example #11
def _validate_block_sizes(block_sizes, bijectors, validate_args):
  """Helper to validate block sizes."""
  block_sizes_shape = block_sizes.shape
  if tensorshape_util.is_fully_defined(block_sizes_shape):
    if (tensorshape_util.rank(block_sizes_shape) != 1 or
        (tensorshape_util.num_elements(block_sizes_shape) != len(bijectors))):
      raise ValueError(
          '`block_sizes` must be `None`, or a vector of the same length as '
          '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
          'length {}'.format(block_sizes_shape, len(bijectors)))
    return block_sizes
  elif validate_args:
    message = ('`block_sizes` must be `None`, or a vector of the same length '
               'as `bijectors`.')
    with tf.control_dependencies([
        assert_util.assert_equal(
            tf.size(block_sizes), len(bijectors), message=message),
        assert_util.assert_equal(tf.rank(block_sizes), 1)
    ]):
      return tf.identity(block_sizes)
  else:
    return block_sizes
Example #12
def assert_rank_at_most(x,
                        rank,
                        data=None,
                        summarize=None,
                        message=None,
                        name=None):
    """Assert `x` has rank equal to `rank` or smaller.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.assert_rank_at_most(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor`.
    rank:  Scalar `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_at_most".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or lower.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has wrong rank.
  """
    with tf.name_scope(name or 'assert_rank_at_most'):
        return tf1.assert_less_equal(tf.rank(x),
                                     rank,
                                     data=data,
                                     summarize=summarize,
                                     message=message)
Example #13
    def _inverse_log_det_jacobian(self, y, **kwargs):
        y = tf.convert_to_tensor(y, name="y")
        ildj = tf.cast(0., dtype=dtype_util.base_dtype(y.dtype))

        if not self.bijectors:
            return ildj

        event_ndims = self._maybe_get_static_event_ndims(
            self.inverse_min_event_ndims)

        if _use_static_shape(y, event_ndims):
            event_shape = y.shape[tensorshape_util.rank(y.shape) -
                                  event_ndims:]
        else:
            event_shape = tf.shape(y)[tf.rank(y) - event_ndims:]

        # TODO(b/129973548): Document and simplify.
        for b in self.bijectors:
            ildj = ildj + b.inverse_log_det_jacobian(
                y, event_ndims=event_ndims, **kwargs.get(b.name, {}))

            if _use_static_shape(y, event_ndims):
                event_shape = b.inverse_event_shape(event_shape)
                event_ndims = self._maybe_get_static_event_ndims(
                    tensorshape_util.rank(event_shape))
            else:
                event_shape = b.inverse_event_shape_tensor(event_shape)
                event_shape_ = distribution_util.maybe_get_static_value(
                    event_shape)
                event_ndims = tf.size(event_shape)
                event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)

                if event_ndims_ is not None and event_shape_ is not None:
                    event_ndims = event_ndims_
                    event_shape = event_shape_

            y = b.inverse(y, **kwargs.get(b.name, {}))
        return ildj
    def __init__(self,
                 initial_distribution,
                 transition_distribution,
                 observation_distribution,
                 num_steps,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HiddenMarkovModel"):
        """Initialize hidden Markov model.

    Args:
      initial_distribution: A `Categorical`-like instance.
        Determines probability of first hidden state in Markov chain.
        The number of categories must match the number of categories of
        `transition_distribution` as well as both the rightmost batch
        dimension of `transition_distribution` and the rightmost batch
        dimension of `observation_distribution`.
      transition_distribution: A `Categorical`-like instance.
        The rightmost batch dimension indexes the probability distribution
        of each hidden state conditioned on the previous hidden state.
      observation_distribution: A `tfp.distributions.Distribution`-like
        instance.  The rightmost batch dimension indexes the distribution
        of each observation conditioned on the corresponding hidden state.
      num_steps: The number of steps taken in Markov chain. A python `int`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Raises:
      ValueError: if `num_steps` is not at least 1.
      ValueError: if `initial_distribution` does not have scalar `event_shape`.
      ValueError: if `transition_distribution` does not have scalar
        `event_shape.`
      ValueError: if `transition_distribution` and `observation_distribution`
        are fully defined but don't have matching rightmost dimension.
    """

        parameters = dict(locals())

        # pylint: disable=protected-access
        with tf.name_scope(name) as name:
            self._runtime_assertions = []  # pylint: enable=protected-access

            num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.rank(num_steps),
                        0,
                        message="`num_steps` must be a scalar")
                ]
                self._runtime_assertions += [
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message="`num_steps` must be at least 1.")
                ]

            self._initial_distribution = initial_distribution
            self._observation_distribution = observation_distribution
            self._transition_distribution = transition_distribution

            if (initial_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        initial_distribution.event_shape) != 0):
                raise ValueError(
                    "`initial_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(initial_distribution.event_shape_tensor())[0],
                        0,
                        message="`initial_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (transition_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.event_shape) != 0):
                raise ValueError(
                    "`transition_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(
                            transition_distribution.event_shape_tensor())[0],
                        0,
                        message="`transition_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (transition_distribution.batch_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.batch_shape) == 0):
                raise ValueError(
                    "`transition_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(transition_distribution.batch_shape_tensor()),
                        0,
                        message="`transition_distribution` can't have scalar "
                        "batches")
                ]

            if (observation_distribution.batch_shape is not None
                    and tensorshape_util.rank(
                        observation_distribution.batch_shape) == 0):
                raise ValueError(
                    "`observation_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(observation_distribution.batch_shape_tensor()),
                        0,
                        message="`observation_distribution` can't have scalar "
                        "batches")
                ]

            # Infer number of hidden states and check consistency
            # between transitions and observations
            with tf.control_dependencies(self._runtime_assertions):
                self._num_states = (
                    (transition_distribution.batch_shape
                     and transition_distribution.batch_shape[-1])
                    or transition_distribution.batch_shape_tensor()[-1])

                observation_states = (
                    (observation_distribution.batch_shape
                     and observation_distribution.batch_shape[-1])
                    or observation_distribution.batch_shape_tensor()[-1])

            if (tf.is_tensor(self._num_states)
                    or tf.is_tensor(observation_states)):
                if validate_args:
                    self._runtime_assertions += [
                        assert_util.assert_equal(
                            self._num_states,
                            observation_states,
                            message="`transition_distribution` and "
                            "`observation_distribution` must agree on "
                            "last dimension of batch size")
                    ]
            elif self._num_states != observation_states:
                raise ValueError("`transition_distribution` and "
                                 "`observation_distribution` must agree on "
                                 "last dimension of batch size")

            self._log_init = _extract_log_probs(self._num_states,
                                                initial_distribution)
            self._log_trans = _extract_log_probs(self._num_states,
                                                 transition_distribution)

            self._num_steps = num_steps
            self._num_states = tf.shape(self._log_init)[-1]

            self._underlying_event_rank = tf.size(
                self._observation_distribution.event_shape_tensor())

            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                self.static_event_shape = tf.TensorShape([
                    num_steps_
                ]).concatenate(self._observation_distribution.event_shape)
            else:
                self.static_event_shape = None

            with tf.control_dependencies(self._runtime_assertions):
                self.static_batch_shape = tf.broadcast_static_shape(
                    self._initial_distribution.batch_shape,
                    tf.broadcast_static_shape(
                        self._transition_distribution.batch_shape[:-1],
                        self._observation_distribution.batch_shape[:-1]))

            # pylint: disable=protected-access
            super(HiddenMarkovModel, self).__init__(
                dtype=self._observation_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
            # pylint: enable=protected-access

            self._parameters = parameters
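
A hedged construction sketch matching the shape requirements in the docstring (3 hidden states, scalar observations; parameter values are hypothetical, and the public `tfp.distributions` API is assumed):

import tensorflow_probability as tfp
tfd = tfp.distributions

hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.1, 0.1]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.2, 0.1],
                                                   [0.2, 0.6, 0.2],
                                                   [0.1, 0.2, 0.7]]),
    observation_distribution=tfd.Normal(loc=[0., 5., 10.], scale=[1., 1., 1.]),
    num_steps=7)
print(hmm.event_shape)   # (7,)
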
Example #15
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
    """Slices a single parameter of a distribution.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.
  """
    # Extend param shape with ones on the left to match dist_batch_shape.
    param_shape = tf.shape(input=param)
    insert_ones = tf.ones(
        [tf.size(input=dist_batch_shape) + param_event_ndims - tf.rank(param)],
        dtype=param_shape.dtype)
    new_param_shape = tf.concat([insert_ones, param_shape], axis=0)
    full_batch_param = tf.reshape(param, new_param_shape)
    param_slices = []
    # We separately track the batch axis from the parameter axis because we want
    # them to align for positive indexing, and be offset by param_event_ndims for
    # negative indexing.
    param_dim_idx = 0
    batch_dim_idx = 0
    for slc in slices:
        if slc is tf.newaxis:
            param_slices.append(slc)
            continue
        if slc is Ellipsis:
            if batch_dim_idx < 0:
                raise ValueError(
                    'Found multiple `...` in slices {}'.format(slices))
            param_slices.append(slc)
            # Switch over to negative indexing for the broadcast check.
            num_remaining_non_newaxis_slices = sum([
                s is not tf.newaxis
                for s in slices[slices.index(Ellipsis) + 1:]
            ])
            batch_dim_idx = -num_remaining_non_newaxis_slices
            param_dim_idx = batch_dim_idx - param_event_ndims
            continue
        # Find the batch dimension sizes for both parameter and distribution.
        param_dim_size = new_param_shape[param_dim_idx]
        batch_dim_size = dist_batch_shape[batch_dim_idx]
        is_broadcast = batch_dim_size > param_dim_size
        # Slices are denoted by start:stop:step.
        if isinstance(slc, slice):
            start, stop, step = slc.start, slc.stop, slc.step
            if start is not None:
                start = tf.where(is_broadcast, 0, start)
            if stop is not None:
                stop = tf.where(is_broadcast, 1, stop)
            if step is not None:
                step = tf.where(is_broadcast, 1, step)
            param_slices.append(slice(start, stop, step))
        else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
            param_slices.append(tf.where(is_broadcast, 0, slc))
        param_dim_idx += 1
        batch_dim_idx += 1
    param_slices.extend([ALL_SLICE] * param_event_ndims)
    return full_batch_param.__getitem__(param_slices)
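
The `is_broadcast` handling is the subtle part: when a parameter was broadcast over a batch dimension, any index into that dimension must collapse to 0. A standalone sketch with hypothetical shapes:

import tensorflow as tf

dist_batch_shape = tf.constant([5])
param = tf.zeros([1])            # parameter broadcast across the batch of 5
requested = 2                    # user slice d[2]
is_broadcast = dist_batch_shape[0] > tf.shape(param)[0]
effective = tf.where(is_broadcast, 0, requested)
print(int(effective))            # 0: the single shared value is reused
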
    def _validate_sample_arg(self, x):
        """Helper which validates sample arg, e.g., input to `log_prob`."""
        with tf.name_scope('validate_sample_arg'):
            x_ndims = (tf.rank(x) if tensorshape_util.rank(x.shape) is None
                       else tensorshape_util.rank(x.shape))
            event_ndims = (tf.size(self.event_shape_tensor())
                           if tensorshape_util.rank(self.event_shape) is None
                           else tensorshape_util.rank(self.event_shape))
            batch_ndims = (tf.size(self._batch_shape_unexpanded)
                           if tensorshape_util.rank(self.batch_shape) is None
                           else tensorshape_util.rank(self.batch_shape))
            expected_batch_event_ndims = batch_ndims + event_ndims

            if (isinstance(x_ndims, int)
                    and isinstance(expected_batch_event_ndims, int)):
                if x_ndims < expected_batch_event_ndims:
                    raise NotImplementedError(
                        'Broadcasting is not supported; too few batch and event dims '
                        '(expected at least {}, saw {}).'.format(
                            expected_batch_event_ndims, x_ndims))
                ndims_assertion = []
            elif self.validate_args:
                ndims_assertion = [
                    assert_util.assert_greater_equal(
                        x_ndims,
                        expected_batch_event_ndims,
                        message=('Broadcasting is not supported; too few '
                                 'batch and event dims.'),
                        name='assert_batch_and_event_ndims_large_enough'),
                ]

            if (tensorshape_util.is_fully_defined(self.batch_shape)
                    and tensorshape_util.is_fully_defined(self.event_shape)):
                expected_batch_event_shape = np.int32(
                    tensorshape_util.concatenate(self.batch_shape,
                                                 self.event_shape))
            else:
                expected_batch_event_shape = tf.concat([
                    self.batch_shape_tensor(),
                    self.event_shape_tensor(),
                ],
                                                       axis=0)

            sample_ndims = x_ndims - expected_batch_event_ndims
            if isinstance(sample_ndims, int):
                sample_ndims = max(sample_ndims, 0)
            if (isinstance(sample_ndims, int) and
                    tensorshape_util.is_fully_defined(x.shape[sample_ndims:])):
                actual_batch_event_shape = np.int32(x.shape[sample_ndims:])
            else:
                sample_ndims = tf.maximum(sample_ndims, 0)
                actual_batch_event_shape = tf.shape(x)[sample_ndims:]

            if (isinstance(expected_batch_event_shape, np.ndarray)
                    and isinstance(actual_batch_event_shape, np.ndarray)):
                if any(expected_batch_event_shape != actual_batch_event_shape):
                    raise NotImplementedError(
                        'Broadcasting is not supported; '
                        'unexpected batch and event shape '
                        '(expected {}, saw {}).'.format(
                            expected_batch_event_shape,
                            actual_batch_event_shape))
                # We need to set the final runtime assertions to `ndims_assertion`
                # since it's possible this assertion was created. We could add a
                # condition to only do so if `self.validate_args == True`; however,
                # this is redundant as `ndims_assertion` already encodes this
                # information.
                runtime_assertions = ndims_assertion
            elif self.validate_args:
                # We need to make the `ndims_assertion` a control dep because otherwise
                # TF itself might raise an exception owing to this assertion being
                # ill-defined, i.e., one cannot even compare different-rank Tensors.
                with tf.control_dependencies(ndims_assertion):
                    shape_assertion = assert_util.assert_equal(
                        expected_batch_event_shape,
                        actual_batch_event_shape,
                        message=('Broadcasting is not supported; '
                                 'unexpected batch and event shape.'),
                        name='assert_batch_and_event_shape_same')
                runtime_assertions = [shape_assertion]
            else:
                runtime_assertions = []

            return runtime_assertions
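
The essential check, in isolation (hypothetical shapes): the trailing dimensions of `x` must equal the concatenation of batch and event shapes, and only the leading sample dimensions may differ.

import tensorflow as tf

x = tf.zeros([7, 2, 3])                      # [sample, batch, event]
expected = tf.constant([2, 3], tf.int32)     # batch_shape ++ event_shape
sample_ndims = tf.rank(x) - tf.size(expected)
tf.debugging.assert_equal(expected, tf.shape(x)[sample_ndims:])   # passes
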
Example #17
    def _log_prob(self, x):
        if self.input_output_cholesky:
            x_sqrt = x
        else:
            # Complexity: O(nbk**3)
            x_sqrt = tf.linalg.cholesky(x)

        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        x_ndims = tf.rank(x_sqrt)
        num_singleton_axes_to_prepend = (
            tf.maximum(tf.size(batch_shape) + 2, x_ndims) - x_ndims)
        x_with_prepended_singletons_shape = tf.concat([
            tf.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
            tf.shape(x_sqrt)
        ], 0)
        x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
        ndims = tf.rank(x_sqrt)
        # sample_ndims = ndims - batch_ndims - event_ndims
        sample_ndims = ndims - tf.size(batch_shape) - 2
        sample_shape = tf.shape(x_sqrt)[:sample_ndims]

        # We need to be able to pre-multiply each matrix by its corresponding
        # batch scale matrix. Since a Distribution Tensor supports multiple
        # samples per batch, this means we need to reshape the input matrix `x`
        # so that the first b dimensions are batch dimensions and the last two
        # are of shape [dimension, dimension * number_of_samples]. Doing these
        # gymnastics allows us to do a batch_solve.
        #
        # After we're done with sqrt_solve (the batch operation) we need to undo
        # this reshaping so what we're left with is a Tensor partitionable by
        # sample, batch, event dimensions.

        # Complexity: O(nbk**2) since transpose must access every element.
        scale_sqrt_inv_x_sqrt = x_sqrt
        perm = tf.concat(
            [tf.range(sample_ndims, ndims),
             tf.range(0, sample_ndims)], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)
        last_dim_size = (
            tf.cast(self.dimension, dtype=tf.int32) *
            tf.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
        shape = tf.concat([
            x_with_prepended_singletons_shape[sample_ndims:-2],
            [tf.cast(self.dimension, dtype=tf.int32), last_dim_size]
        ],
                          axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

        # Complexity: O(nbM*k) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
        # this step has complexity O(nbk^3).
        scale_sqrt_inv_x_sqrt = self.scale_operator.solve(
            scale_sqrt_inv_x_sqrt)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = tf.concat(
            [tf.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
            axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
        perm = tf.concat([
            tf.range(ndims - sample_ndims, ndims),
            tf.range(0, ndims - sample_ndims)
        ], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)

        # Write V = SS', X = LL'. Then:
        # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
        #              = tr[inv(S) L L' inv(S)']
        #              = tr[(inv(S) L) (inv(S) L)']
        #              = sum_{ik} (inv(S) L)_{ik}**2
        # The second equality follows from the cyclic permutation property.
        # Complexity: O(nbk**2)
        trace_scale_inv_x = tf.reduce_sum(tf.square(scale_sqrt_inv_x_sqrt),
                                          axis=[-2, -1])

        # Complexity: O(nbk)
        half_log_det_x = tf.reduce_sum(tf.math.log(
            tf.linalg.diag_part(x_sqrt)),
                                       axis=[-1])

        # Complexity: O(nbk**2)
        log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
                    0.5 * trace_scale_inv_x - self.log_normalization())

        # Set shape hints.
        # Try to merge what we know from the input x with what we know from the
        # parameters of this distribution.
        if tensorshape_util.rank(
                x.shape) is not None and tensorshape_util.rank(
                    self.batch_shape) is not None:
            tensorshape_util.set_shape(
                log_prob,
                tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

        return log_prob
Example #18
  s_ = tf.get_static_value(s)
  if s_ is not None:
    return np.ones(s_, dtype_util.as_numpy_dtype(dtype or input.dtype))
  return tf.ones(s, dtype or s.dtype, name)
ones_like = _copy_docstring(tf.ones_like, _ones_like)

range = _prefer_static(  # pylint: disable=redefined-builtin
    tf.range,
    lambda start, limit=None, delta=1, dtype=None, name='range': np.arange(  # pylint: disable=g-long-lambda
        start, limit, delta).astype(_numpy_dtype(
            dtype or np.array(tf.get_static_value(start)).dtype)))

rank = _copy_docstring(
    tf.rank,
    lambda input, name=None: (  # pylint: disable=redefined-builtin,g-long-lambda
        tf.rank(input) if tensorshape_util.rank(input.shape) is None
        else np.int32(tensorshape_util.rank(input.shape))))

reduce_all = _prefer_static(
    tf.reduce_all,
    lambda input_tensor, axis=None, keepdims=False, name=None: np.all(  # pylint: disable=g-long-lambda
        input_tensor, axis, keepdims=keepdims))

reduce_any = _prefer_static(
    tf.reduce_any,
    lambda input_tensor, axis=None, keepdims=False, name=None: np.any(  # pylint: disable=g-long-lambda
        input_tensor, axis, keepdims=keepdims))

reduce_prod = _prefer_static(
    tf.reduce_prod,
    lambda input_tensor, axis=None, keepdims=False, name=None: np.prod(  # pylint: disable=g-long-lambda
        input_tensor, axis, keepdims=keepdims))
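
The pattern behind these wrappers, in miniature (standalone sketch, hypothetical helper name): use the NumPy implementation when the value is statically known, otherwise fall back to the TF op so the result can still be computed in the graph.

import numpy as np
import tensorflow as tf

def static_or_dynamic_rank(x):
    # Mirrors `rank` above: a np.int32 when the rank is statically known,
    # else the tf.rank op.
    static_rank = x.shape.rank
    return tf.rank(x) if static_rank is None else np.int32(static_rank)

print(static_or_dynamic_rank(tf.zeros([2, 3])))   # 2, usable in Python control flow
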
Example #19
    def _transpose(self, x, perm):
        perm = self._make_perm(tf.rank(x), perm)
        return tf.transpose(a=x, perm=perm)
    def _observation_log_probs(self, observations, mask):
        """Compute and shape tensor of log probs associated with observations.."""

        # Let E be the underlying event shape
        #     M the number of steps in the HMM
        #     N the number of states of the HMM
        #
        # Then the incoming observations have shape
        #
        # observations : batch_o [M] E
        #
        # and the mask (if present) has shape
        #
        # mask : batch_m [M]
        #
        # Let this HMM distribution have batch shape batch_d
        # We need to broadcast all three of these batch shapes together
        # into the shape batch.
        #
        # We need to move the step dimension to the first dimension to make
        # them suitable for folding or scanning over.
        #
        # When we call `log_prob` for our observations we need to
        # do this for each state the observation could correspond to.
        # We do this by expanding the dimensions by 1 so we end up with:
        #
        # observations : [M] batch [1] [E]
        #
        # After calling `log_prob` we get
        #
        # observation_log_probs : [M] batch [N]
        #
        # We wish to use `mask` to select from this so we also
        # reshape and broadcast it up to shape
        #
        # mask : [M] batch [N]

        observation_tensor_shape = tf.shape(observations)
        observation_batch_shape = observation_tensor_shape[
            :-1 - self._underlying_event_rank]
        observation_event_shape = observation_tensor_shape[
            -1 - self._underlying_event_rank:]

        if mask is not None:
            mask_tensor_shape = tf.shape(mask)
            mask_batch_shape = mask_tensor_shape[:-1]

        batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                 self.batch_shape_tensor())

        if mask is not None:
            batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                                     mask_batch_shape)
        observations = tf.broadcast_to(
            observations,
            tf.concat([batch_shape, observation_event_shape], axis=0))
        observation_rank = tf.rank(observations)
        underlying_event_rank = self._underlying_event_rank
        observations = distribution_util.move_dimension(
            observations, observation_rank - underlying_event_rank - 1, 0)
        observations = tf.expand_dims(observations,
                                      observation_rank - underlying_event_rank)
        observation_log_probs = self._observation_distribution.log_prob(
            observations)

        if mask is not None:
            mask = tf.broadcast_to(
                mask, tf.concat([batch_shape, [self._num_steps]], axis=0))
            mask = distribution_util.move_dimension(mask, -1, 0)
            observation_log_probs = tf.where(
                mask[..., tf.newaxis], tf.zeros_like(observation_log_probs),
                observation_log_probs)

        return observation_log_probs
Example #21
    def _forward_log_det_jacobian(self, x):
        # Let Y be a symmetric, positive definite matrix and write:
        #   Y = X X.T
        # where X is lower-triangular.
        #
        # Observe that,
        #   dY[i,j]/dX[a,b]
        #   = d/dX[a,b] { X[i,:] X[j,:] }
        #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
        #
        # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is
        # symmetric and X is lower-triangular, we need vectors of dimension:
        #   d = p (p + 1) / 2
        # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
        #   k = { i (i + 1) / 2 + j   i>=j
        #       { undef               i<j
        # and assume zero-based indexes. When k is undef, the element is dropped.
        # Example:
        #           j      k
        #        0 1 2 3  /
        #    0 [ 0 . . . ]
        # i  1 [ 1 2 . . ]
        #    2 [ 3 4 5 . ]
        #    3 [ 6 7 8 9 ]
        # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
        # slight abuse: k(i,j)=undef means the element is dropped.)
        #
        # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
        # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
        # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
        # (1) j<=i<a thus i,j!=a.
        # (2) i=a>j  thus i,j!=a.
        #
        # Since the Jacobian is lower-triangular, we need only compute the product
        # of diagonal elements:
        #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
        #   = X[j,j] + I[i=j] X[i,j],
        # i.e., 2 X[j,j] when i = j and X[j,j] otherwise. Each column j
        # contributes p-j diagonal entries, exactly one of which (i = j)
        # carries the extra factor of 2, so we conclude:
        #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
        diag = tf.linalg.diag_part(x)

        # We now ensure diag is columnar (i.e., has at least two dimensions).
        # E.g., if `diag = [1, 2, 3]` then the output is `[[1, 2, 3]]` and if
        # `diag = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.
        diag = self._make_columnar(diag)

        with tf.control_dependencies(self._assertions(x)):
            # Create a vector equal to: [p, p-1, ..., 2, 1].
            if tf.compat.dimension_value(x.shape[-1]) is None:
                p_int = tf.shape(x)[-1]
                p_float = tf.cast(p_int, dtype=x.dtype)
            else:
                p_int = tf.compat.dimension_value(x.shape[-1])
                p_float = dtype_util.as_numpy_dtype(x.dtype)(p_int)
            exponents = tf.linspace(p_float, 1., p_int)

            sum_weighted_log_diag = tf.squeeze(tf.matmul(
                tf.math.log(diag), exponents[..., tf.newaxis]),
                                               axis=-1)
            fldj = p_float * np.log(2.) + sum_weighted_log_diag

            # We finally need to undo adding an extra column in non-scalar cases
            # where there is a single matrix as input.
            if tensorshape_util.rank(x.shape) is not None:
                if tensorshape_util.rank(x.shape) == 2:
                    fldj = tf.squeeze(fldj, axis=-1)
                return fldj

            shape = tf.shape(fldj)
            maybe_squeeze_shape = tf.concat([
                shape[:-1],
                distribution_util.pick_vector(tf.equal(
                    tf.rank(x), 2), np.array([], dtype=np.int32), shape[-1:])
            ], 0)
            return tf.reshape(fldj, maybe_squeeze_shape)
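
A numeric spot-check of the determinant formula derived in the comment above, for p = 2 with hypothetical entries: the Jacobian of vec[Y] with respect to vec[X] is lower-triangular and its determinant equals 2^p X[0,0]^2 X[1,1].

import numpy as np

x00, x10, x11 = 1.5, 0.3, 2.0
# vec[X] = [X00, X10, X11], Y = X X^T, vec[Y] = [Y00, Y10, Y11].
jac = np.array([[2 * x00, 0.0, 0.0],        # dY00 / d(X00, X10, X11)
                [x10, x00, 0.0],            # dY10 / d(X00, X10, X11)
                [0.0, 2 * x10, 2 * x11]])   # dY11 / d(X00, X10, X11)
assert np.isclose(np.linalg.det(jac), 2.**2 * x00**2 * x11)
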
Example #22
def _replicate(n, tensor):
    """Replicate the input tensor n times along a new (major) dimension."""
    # TODO(axch) Does this already exist somewhere?  Should it get contributed?
    multiples = tf.concat([[n], tf.ones([tf.rank(tensor)], dtype=n.dtype)],
                          axis=0)
    return tf.tile(tensor[tf.newaxis], multiples)
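
A hedged usage example: `_replicate` prepends a new leading axis of size `n` by tiling (note `n` is expected to be an integer `Tensor`, since its `dtype` is read).

import tensorflow as tf

t = tf.constant([[1, 2], [3, 4]])
print(_replicate(tf.constant(3), t).shape)   # (3, 2, 2)
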