Пример #1
0
 def _maybe_assert_valid_sample(self, x, loc):
     """Checks the validity of a sample."""
     if not self.validate_args:
         return []
     return [
         assert_util.assert_greater_equal(
             x, loc, message='x is not in the support of the distribution')
     ]
Пример #2
0
  def _prob(self, x):
    """Density `c * s**c / x**(c + 1)` with `c = concentration`, `s = scale`.

    Points outside the support are handled by `self._extend_support`, which
    is given `alt=0.` as the substitute value.

    Args:
      x: Points at which to evaluate the density.

    Returns:
      The density evaluated at `x`.
    """
    concentration = tf.convert_to_tensor(self.concentration)
    scale = tf.convert_to_tensor(self.scale)

    support_checks = []
    if self.validate_args:
      support_checks.append(
          assert_util.assert_greater_equal(
              x, scale,
              message='`x` is not in the support of the distribution.'))

    def prob_on_support(z):
      return concentration * (scale**concentration) / (z**(concentration + 1))

    # Ops built by `_extend_support` run after the (optional) support check.
    with tf.control_dependencies(support_checks):
      return self._extend_support(x, scale, prob_on_support, alt=0.)
Пример #3
0
    def _assert_compatible_shape(self, index, sample_shape, samples):
        """Builds assertions that `samples` is consistent with `sample_shape`.

        Two properties are required of the samples yielded by distribution
        number `index`:
          (1) their rank is at least the rank of the requested shape, and
          (2) the requested shape is a prefix of their shape.
        Property (1) is decided statically when both ranks are known (raising
        immediately on failure); otherwise it becomes a runtime assertion that
        guards the slicing used for (2). Property (2) is always returned as an
        `assert_equal` op, which evaluates statically when it can.

        Args:
          index: Position of the distribution; used only in the error message.
          sample_shape: Requested sample shape.
          samples: The samples produced by that distribution.

        Returns:
          A one-element list containing the prefix-equality assertion.

        Raises:
          ValueError: if the ranks are statically known and the samples have
            lower rank than requested.
        """
        requested_shape, _ = self._expand_sample_shape_to_vector(
            tf.convert_to_tensor(sample_shape, dtype=tf.int32),
            name='requested_shape')
        actual_shape = prefer_static.shape(samples)
        actual_rank = prefer_static.rank_from_shape(actual_shape)
        requested_rank = prefer_static.rank_from_shape(requested_shape)

        static_actual_rank = tf.get_static_value(actual_rank)
        static_requested_rank = tf.get_static_value(requested_rank)

        assertion_message = ('Samples yielded by distribution #{} are not '
                             'consistent with `sample_shape` passed to '
                             '`JointDistributionCoroutine` '
                             'distribution.'.format(index))

        # TODO Remove this static check (b/138738650)
        rank_checks = []
        if static_actual_rank is None or static_requested_rank is None:
            # At least one rank is dynamic: defer property (1) to run time.
            rank_checks.append(
                assert_util.assert_greater_equal(actual_rank,
                                                 requested_rank,
                                                 message=assertion_message))
        elif static_actual_rank < static_requested_rank:
            raise ValueError(assertion_message)

        # The rank check must precede slicing, otherwise the slice (and the
        # equality check below) could be ill-defined.
        with tf.control_dependencies(rank_checks):
            trimmed_actual_shape = actual_shape[:requested_rank]

        return [
            assert_util.assert_equal(requested_shape,
                                     trimmed_actual_shape,
                                     message=assertion_message)
        ]
Пример #4
0
  def _log_prob(self, x):
    """Log density `log(c) + c * log(s) - (c + 1) * log(x)`.

    Here `c = concentration` and `s = scale`. Points outside the support are
    handled by `self._extend_support`, which is given `alt=-np.inf` as the
    substitute value.

    Args:
      x: Points at which to evaluate the log density.

    Returns:
      The log density evaluated at `x`.
    """
    concentration = tf.convert_to_tensor(self.concentration)
    scale = tf.convert_to_tensor(self.scale)

    if self.validate_args:
      support_checks = [
          assert_util.assert_greater_equal(
              x, scale,
              message='`x` is not in the support of the distribution.')
      ]
    else:
      support_checks = []

    def log_prob_on_support(z):
      log_coeff = (tf.math.log(concentration) +
                   concentration * tf.math.log(scale))
      return log_coeff - (concentration + 1.) * tf.math.log(z)

    # Ops built by `_extend_support` run after the (optional) support check.
    with tf.control_dependencies(support_checks):
      return self._extend_support(
          x, scale, log_prob_on_support, alt=-np.inf)
Пример #5
0
    def _prob(self, x):
        """Triangular density on `[low, high]` with mode at `peak`.

        The density rises linearly from 0 at `low` to `2 / (high - low)` at
        `peak`, then falls linearly back to 0 at `high`. It is 0 outside
        `[low, high]`. The degenerate cases `peak == low` and `peak == high`
        are handled without producing 0/0.

        Args:
          x: Points at which to evaluate the density.

        Returns:
          The density at `x`, broadcast against `low`, `high` and `peak`.

        When `self.validate_args` is True, runtime assertions additionally
        require every `x` to lie in `[low, high]`.
        """
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        if self.validate_args:
            with tf.control_dependencies([
                    assert_util.assert_greater_equal(x, low),
                    assert_util.assert_less_equal(x, high)
            ]):
                x = tf.identity(x)

        interval_length = high - low
        # Height of the triangle at its mode; both line segments meet here.
        peak_density = 2. / interval_length

        # Widths of the rising and falling segments. Either may be zero
        # (peak == low or peak == high). Substituting 1 for a zero width keeps
        # the unused `tf.where` branch finite (no 0/0 or x/0, which would
        # otherwise yield NaN values or NaN gradients). A zero-width segment
        # is never selected inside [low, high] because `x < peak` (resp.
        # `x > peak`) cannot hold there.
        rise = peak - low
        fall = high - peak
        safe_rise = tf.where(rise > 0., rise, tf.ones_like(rise))
        safe_fall = tf.where(fall > 0., fall, tf.ones_like(fall))

        # Line segment from (low, 0) to (peak, 2 / (high - low)).
        rising = peak_density * (x - low) / safe_rise
        # Line segment from (peak, 2 / (high - low)) to (high, 0).
        falling = peak_density * (high - x) / safe_fall

        # Exactly at the mode use the peak height directly. NaN inputs fail
        # both comparisons and fall through to `falling`, so NaN propagates.
        result_inside_interval = tf.where(
            x < peak, rising,
            tf.where(tf.equal(x, peak), peak_density, falling))

        return tf.where((x < low) | (x > high), tf.zeros_like(x),
                        result_inside_interval)
Пример #6
0
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        logits = self._logits
        probs = self._probs
        param, name = (probs, 'probs') if logits is None else (logits,
                                                               'logits')

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:
            if not dtype_util.is_floating(param.dtype):
                raise TypeError(
                    'Argument `{}` must having floating type.'.format(name))

            msg = 'Argument `{}` must have rank at least 1.'.format(name)
            shape_static = tensorshape_util.dims(param.shape)
            if shape_static is not None:
                if len(shape_static) < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_rank_at_least(param, 1, message=msg))
                with tf.control_dependencies(assertions):
                    param = tf.identity(param)

            msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
            msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
                name, tf.int32.max)
            event_size = shape_static[-1] if shape_static is not None else None
            if event_size is not None:
                if event_size < 1:
                    raise ValueError(msg1)
                if event_size > tf.int32.max:
                    raise ValueError(msg2)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_greater_equal(tf.shape(param)[-1],
                                                     1,
                                                     message=msg1))
                # NOTE: For now, we leave out a runtime assertion that
                # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
                # will fail before we get to this point.

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if probs is not None:
            probs = param  # reuse tensor conversion from above
            if is_init != tensor_util.is_ref(probs):
                probs = tf.convert_to_tensor(probs)
                one = tf.ones([], dtype=probs.dtype)
                assertions.extend([
                    assert_util.assert_non_negative(probs),
                    assert_util.assert_less_equal(probs, one),
                    assert_util.assert_near(
                        tf.reduce_sum(probs, axis=-1),
                        one,
                        message='Argument `probs` must sum to 1.'),
                ])

        return assertions
Пример #7
0
    def _validate_sample_arg(self, x):
        """Helper which validates sample arg, e.g., input to `log_prob`.

        Requires `x` to have at least `batch_ndims + event_ndims` dimensions
        and trailing dimensions exactly equal to `batch_shape + event_shape`
        (broadcasting is not supported). Checks decidable from static shapes
        raise immediately. Otherwise runtime assertions are built, but only
        when `self.validate_args` is True.

        Args:
          x: Sample `Tensor` to validate.

        Returns:
          List of runtime assertion ops (possibly empty).

        Raises:
          NotImplementedError: if static shapes show that `x` has too few
            dimensions, or that its trailing shape differs from
            `batch_shape + event_shape`.
        """
        with tf.name_scope('validate_sample_arg'):
            x_ndims = (tf.rank(x) if tensorshape_util.rank(x.shape) is None
                       else tensorshape_util.rank(x.shape))
            event_ndims = (tf.size(self.event_shape_tensor())
                           if tensorshape_util.rank(self.event_shape) is None
                           else tensorshape_util.rank(self.event_shape))
            batch_ndims = (tf.size(self._batch_shape_unexpanded)
                           if tensorshape_util.rank(self.batch_shape) is None
                           else tensorshape_util.rank(self.batch_shape))
            expected_batch_event_ndims = batch_ndims + event_ndims

            # Bound on every path so later branches can always read it. It is
            # populated only when a runtime rank check is needed.
            ndims_assertion = []
            if (isinstance(x_ndims, int)
                    and isinstance(expected_batch_event_ndims, int)):
                if x_ndims < expected_batch_event_ndims:
                    raise NotImplementedError(
                        'Broadcasting is not supported; too few batch and event dims '
                        '(expected at least {}, saw {}).'.format(
                            expected_batch_event_ndims, x_ndims))
            elif self.validate_args:
                ndims_assertion = [
                    assert_util.assert_greater_equal(
                        x_ndims,
                        expected_batch_event_ndims,
                        message=('Broadcasting is not supported; too few '
                                 'batch and event dims.'),
                        name='assert_batch_and_event_ndims_large_enough'),
                ]

            if (tensorshape_util.is_fully_defined(self.batch_shape)
                    and tensorshape_util.is_fully_defined(self.event_shape)):
                expected_batch_event_shape = np.int32(
                    tensorshape_util.concatenate(self.batch_shape,
                                                 self.event_shape))
            else:
                expected_batch_event_shape = tf.concat(
                    [self.batch_shape_tensor(), self.event_shape_tensor()],
                    axis=0)

            # Leading dims of `x` beyond batch+event are sample dims; clamp at
            # 0 so the slice below is well-defined even before rank checks.
            sample_ndims = x_ndims - expected_batch_event_ndims
            if isinstance(sample_ndims, int):
                sample_ndims = max(sample_ndims, 0)
            if (isinstance(sample_ndims, int) and
                    tensorshape_util.is_fully_defined(x.shape[sample_ndims:])):
                actual_batch_event_shape = np.int32(x.shape[sample_ndims:])
            else:
                sample_ndims = tf.maximum(sample_ndims, 0)
                actual_batch_event_shape = tf.shape(x)[sample_ndims:]

            if (isinstance(expected_batch_event_shape, np.ndarray)
                    and isinstance(actual_batch_event_shape, np.ndarray)):
                if np.any(expected_batch_event_shape != actual_batch_event_shape):
                    raise NotImplementedError(
                        'Broadcasting is not supported; '
                        'unexpected batch and event shape '
                        '(expected {}, saw {}).'.format(
                            expected_batch_event_shape,
                            actual_batch_event_shape))
                # The shape check passed statically, but a runtime rank check
                # may still have been created above, so return it.
                # `ndims_assertion` is non-empty only when
                # `self.validate_args` is True.
                runtime_assertions = ndims_assertion
            elif self.validate_args:
                # The rank assertion must be a control dependency; otherwise
                # TF itself may raise first, because comparing shapes of
                # different lengths is ill-defined.
                with tf.control_dependencies(ndims_assertion):
                    shape_assertion = assert_util.assert_equal(
                        expected_batch_event_shape,
                        actual_batch_event_shape,
                        message=('Broadcasting is not supported; '
                                 'unexpected batch and event shape.'),
                        name='assert_batch_and_event_shape_same')
                runtime_assertions = [shape_assertion]
            else:
                runtime_assertions = []

            return runtime_assertions
  def _sample_n(self, n, seed=None):
    """Draws `n` samples from the distribution.

    The coordinate along the mode direction is drawn first: by
    `self._sample_3d` when `event_dim == 3`, otherwise by a Wood'94 rejection
    loop. The remaining `event_dim - 1` coordinates are a uniformly random
    direction (normalized Gaussians) scaled to unit total length. The result
    is then rotated so the mode moves from (1, 0, ..., 0) to
    `self.mean_direction`.

    Args:
      n: Number of samples; becomes the leading dimension of the output.
      seed: Seed used to build a `SeedStream`. Each random op below calls
        `seed()`; presumably each call yields the next seed in the stream, so
        reordering those calls would change the samples.

    Returns:
      Samples whose last dimension is `event_dim`, with leading shape
      `[n] + batch_shape` (as built in `sample_batch_shape`), passed through
      `self._rotate` (presumably shape-preserving — confirm).
    """
    # NOTE: the salt is spelled 'vom_mises_fisher'; changing it would change
    # the seeds drawn, so it is left as-is.
    seed = SeedStream(seed, salt='vom_mises_fisher')
    # The sampling strategy relies on the fact that vMF variates are symmetric
    # about the mean direction. Accordingly, if we have a sampling strategy for
    # the away-from-mean angle, then we can uniformly sample the remaining
    # dimensions on the S^{dim-2} sphere, and rotate these samples from a
    # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
    #
    # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
    # von-Mises distributed `x` value in [-1, 1], then uniformly select what
    # amounts to an "up" or "down" additional degree of freedom after unit
    # normalizing, followed by a final rotation to the desired mean direction
    # from a basis of (1, 0).
    #
    # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
    # unit sphere over which the distribution is uniform, in particular the
    # circle where x = \hat{x} intersects the unit sphere. We pick a point on
    # that circle, then rotate to the desired mean direction from a basis of
    # (1, 0, 0).
    # Prefer the static event size; fall back to the dynamic one when the
    # static value is unavailable.
    event_dim = (
        tf.compat.dimension_value(self.event_shape[0]) or
        self._event_shape_tensor()[0])

    sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
    dim = tf.cast(event_dim - 1, self.dtype)
    # NOTE(review): if `event_dim` is a Tensor (dynamic event shape), this
    # Python `if` on a Tensor comparison may fail in graph mode — confirm.
    if event_dim == 3:
      samples_dim0 = self._sample_3d(n, seed=seed)
    else:
      # Wood'94 provides a rejection algorithm to sample the x coordinate.
      # Wood'94 definition of b:
      # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
      # https://stats.stackexchange.com/questions/156729 suggests:
      b = dim / (2 * self.concentration +
                 tf.sqrt(4 * self.concentration**2 + dim**2))
      # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
      #     https://github.com/nicola-decao/s-vae-tf/
      x = (1 - b) / (1 + b)
      c = self.concentration * x + dim * tf.math.log1p(-x**2)
      beta = beta_lib.Beta(dim / 2, dim / 2)

      # Loop until every batch member has accepted a proposal.
      def cond_fn(w, should_continue):
        del w
        return tf.reduce_any(should_continue)

      # One round of proposals; only entries still rejected are overwritten.
      def body_fn(w, should_continue):
        z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
        # set_shape needed here because of b/139013403
        z.set_shape(w.shape)
        w = tf.where(should_continue, (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
        w = tf.debugging.check_numerics(w, 'w')
        unif = tf.random.uniform(
            sample_batch_shape, seed=seed(), dtype=self.dtype)
        # set_shape needed here because of b/139013403
        unif.set_shape(w.shape)
        # Keep retrying an entry while its acceptance test fails.
        should_continue = tf.logical_and(
            should_continue,
            self.concentration * w + dim * tf.math.log1p(-x * w) - c <
            tf.math.log(unif))
        return w, should_continue

      w = tf.zeros(sample_batch_shape, dtype=self.dtype)
      should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
      samples_dim0 = tf.while_loop(
          cond=cond_fn, body=body_fn, loop_vars=(w, should_continue))[0]
      # Trailing axis so it can be concatenated with the other coordinates.
      samples_dim0 = samples_dim0[..., tf.newaxis]
    # The extra runtime checks in this method are gated on `allow_nan_stats`
    # being False (not on `validate_args`).
    if not self._allow_nan_stats:
      # Verify samples are w/in -1, 1, with useful error output tensors (top
      # value rather than all values).
      with tf.control_dependencies([
          assert_util.assert_less_equal(
              samples_dim0,
              dtype_util.as_numpy_dtype(self.dtype)(1.01),
              data=[tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
          assert_util.assert_greater_equal(
              samples_dim0,
              dtype_util.as_numpy_dtype(self.dtype)(-1.01),
              data=[-tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]])
      ]):
        samples_dim0 = tf.identity(samples_dim0)
    # Remaining coordinates: a random direction from normalized Gaussians,
    # scaled below by sqrt(1 - x0**2) so the full vector has unit length.
    samples_otherdims_shape = tf.concat([sample_batch_shape, [event_dim - 1]],
                                        axis=0)
    unit_otherdims = tf.math.l2_normalize(
        tf.random.normal(
            samples_otherdims_shape, seed=seed(), dtype=self.dtype),
        axis=-1)
    samples = tf.concat([
        samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
        tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
    ], axis=-1)
    samples = tf.math.l2_normalize(samples, axis=-1)
    if not self._allow_nan_stats:
      samples = tf.debugging.check_numerics(samples, 'samples')

    # Runtime assert that samples are unit length.
    if not self._allow_nan_stats:
      worst, idx = tf.math.top_k(
          tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
      with tf.control_dependencies([
          assert_util.assert_near(
              dtype_util.as_numpy_dtype(self.dtype)(0),
              worst,
              data=[
                  worst, idx,
                  tf.gather(tf.reshape(samples, [-1, event_dim]), idx)
              ],
              atol=1e-4,
              summarize=100)
      ]):
        samples = tf.identity(samples)
    # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
    # Now, we move the mode to `self.mean_direction` using a rotation matrix.
    if not self._allow_nan_stats:
      # Assert that the basis vector rotates to the mean direction, as expected.
      basis = tf.cast(tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                      self.dtype)
      with tf.control_dependencies([
          assert_util.assert_less(
              tf.linalg.norm(
                  self._rotate(basis) - self.mean_direction, axis=-1),
              dtype_util.as_numpy_dtype(self.dtype)(1e-5))
      ]):
        return self._rotate(samples)
    return self._rotate(samples)
    def __init__(self,
                 initial_distribution,
                 transition_distribution,
                 observation_distribution,
                 num_steps,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HiddenMarkovModel"):
        """Initialize hidden Markov model.

        Args:
          initial_distribution: A `Categorical`-like instance.
            Determines probability of first hidden state in Markov chain.
            The number of categories must match the number of categories of
            `transition_distribution` as well as both the rightmost batch
            dimension of `transition_distribution` and the rightmost batch
            dimension of `observation_distribution`.
          transition_distribution: A `Categorical`-like instance.
            The rightmost batch dimension indexes the probability distribution
            of each hidden state conditioned on the previous hidden state.
          observation_distribution: A `tfp.distributions.Distribution`-like
            instance.  The rightmost batch dimension indexes the distribution
            of each observation conditioned on the corresponding hidden state.
          num_steps: The number of steps taken in Markov chain. A python `int`.
            It must be a scalar and at least 1; this is checked only at run
            time, and only when `validate_args` is True.
          validate_args: Python `bool`, default `False`. When `True` distribution
            parameters are checked for validity despite possibly degrading runtime
            performance. When `False` invalid inputs may silently render incorrect
            outputs.
            Default value: `False`.
          allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
            (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
            result is undefined. When `False`, an exception is raised if one or
            more of the statistic's batch members are undefined.
            Default value: `True`.
          name: Python `str` name prefixed to Ops created by this class.
            Default value: "HiddenMarkovModel".

        Raises:
          ValueError: if `initial_distribution` has a statically known,
            non-scalar `event_shape`.
          ValueError: if `transition_distribution` has a statically known,
            non-scalar `event_shape`.
          ValueError: if `transition_distribution` or
            `observation_distribution` has a statically scalar batch shape.
          ValueError: if `transition_distribution` and `observation_distribution`
            are fully defined but don't have matching rightmost dimension.
        """

        # Must be the first statement so it captures exactly the arguments.
        parameters = dict(locals())

        # pylint: disable=protected-access
        with tf.name_scope(name) as name:
            self._runtime_assertions = []  # pylint: enable=protected-access

            num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.rank(num_steps),
                        0,
                        message="`num_steps` must be a scalar")
                ]
                self._runtime_assertions += [
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message="`num_steps` must be at least 1.")
                ]

            self._initial_distribution = initial_distribution
            self._observation_distribution = observation_distribution
            self._transition_distribution = transition_distribution

            # An unknown event rank (None) must not be treated as non-scalar.
            # It falls through to the runtime check instead.
            initial_event_rank = (
                tensorshape_util.rank(initial_distribution.event_shape)
                if initial_distribution.event_shape is not None else None)
            if initial_event_rank is not None and initial_event_rank != 0:
                raise ValueError(
                    "`initial_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(initial_distribution.event_shape_tensor())[0],
                        0,
                        message="`initial_distribution` must have scalar "
                        "`event_dim`s")
                ]

            transition_event_rank = (
                tensorshape_util.rank(transition_distribution.event_shape)
                if transition_distribution.event_shape is not None else None)
            if transition_event_rank is not None and transition_event_rank != 0:
                raise ValueError(
                    "`transition_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(
                            transition_distribution.event_shape_tensor())[0],
                        0,
                        message="`transition_distribution` must have scalar "
                        "`event_dim`s")
                ]

            if (transition_distribution.batch_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.batch_shape) == 0):
                raise ValueError(
                    "`transition_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(transition_distribution.batch_shape_tensor()),
                        0,
                        message="`transition_distribution` can't have scalar "
                        "batches")
                ]

            if (observation_distribution.batch_shape is not None
                    and tensorshape_util.rank(
                        observation_distribution.batch_shape) == 0):
                raise ValueError(
                    "`observation_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(observation_distribution.batch_shape_tensor()),
                        0,
                        message="`observation_distribution` can't have scalar "
                        "batches")
                ]

            # Infer number of hidden states and check consistency
            # between transitions and observations. The static last batch
            # dimension is preferred; the dynamic one is used when it is
            # unavailable.
            with tf.control_dependencies(self._runtime_assertions):
                self._num_states = (
                    (transition_distribution.batch_shape
                     and transition_distribution.batch_shape[-1])
                    or transition_distribution.batch_shape_tensor()[-1])

                observation_states = (
                    (observation_distribution.batch_shape
                     and observation_distribution.batch_shape[-1])
                    or observation_distribution.batch_shape_tensor()[-1])

            if (tf.is_tensor(self._num_states)
                    or tf.is_tensor(observation_states)):
                if validate_args:
                    self._runtime_assertions += [
                        assert_util.assert_equal(
                            self._num_states,
                            observation_states,
                            message="`transition_distribution` and "
                            "`observation_distribution` must agree on "
                            "last dimension of batch size")
                    ]
            elif self._num_states != observation_states:
                raise ValueError("`transition_distribution` and "
                                 "`observation_distribution` must agree on "
                                 "last dimension of batch size")

            self._log_init = _extract_log_probs(self._num_states,
                                                initial_distribution)
            self._log_trans = _extract_log_probs(self._num_states,
                                                 transition_distribution)

            self._num_steps = num_steps
            # Re-derived from the initial log-probs, replacing the value
            # inferred above.
            self._num_states = tf.shape(self._log_init)[-1]

            self._underlying_event_rank = tf.size(
                self._observation_distribution.event_shape_tensor())

            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                self.static_event_shape = tf.TensorShape([
                    num_steps_
                ]).concatenate(self._observation_distribution.event_shape)
            else:
                self.static_event_shape = None

            with tf.control_dependencies(self._runtime_assertions):
                self.static_batch_shape = tf.broadcast_static_shape(
                    self._initial_distribution.batch_shape,
                    tf.broadcast_static_shape(
                        self._transition_distribution.batch_shape[:-1],
                        self._observation_distribution.batch_shape[:-1]))

            # pylint: disable=protected-access
            super(HiddenMarkovModel, self).__init__(
                dtype=self._observation_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
            # pylint: enable=protected-access

            self._parameters = parameters