Example #1
  def _prob(self, x):
    if self.validate_args:
      with tf.control_dependencies([
          assert_util.assert_greater_equal(x, self.low),
          assert_util.assert_less_equal(x, self.high)
      ]):
        x = tf.identity(x)

    broadcast_x_to_high = _broadcast_to(x, [self.high])
    left_of_peak = tf.logical_and(
        broadcast_x_to_high > self.low, broadcast_x_to_high <= self.peak)

    interval_length = self.high - self.low
    # This is the pdf when low <= x <= high. The pdf graph is a triangle,
    # so we treat each line segment separately.
    result_inside_interval = tf.where(
        left_of_peak,
        # Line segment from (self.low, 0) to (self.peak, 2 / (self.high -
        # self.low)).
        2. * (x - self.low) / (interval_length * (self.peak - self.low)),
        # Line segment from (self.peak, 2 / (self.high - self.low)) to
        # (self.high, 0).
        2. * (self.high - x) / (interval_length * (self.high - self.peak)))

    broadcast_x_to_peak = _broadcast_to(x, [self.peak])
    outside_interval = tf.logical_or(
        broadcast_x_to_peak < self.low, broadcast_x_to_peak > self.high)

    broadcast_shape = tf.broadcast_dynamic_shape(
        tf.shape(input=x), self.batch_shape_tensor())

    return tf.where(
        outside_interval,
        tf.zeros(broadcast_shape, dtype=self.dtype),
        result_inside_interval)
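The two tf.where branches implement the standard triangular density. As a sanity check, here is a minimal NumPy sketch of the same piecewise formula (scalar parameters, illustrative only, not from the codebase):

import numpy as np

def triangular_pdf(x, low, peak, high):
  # Rises from (low, 0) to (peak, 2 / (high - low)), then falls to (high, 0).
  up = 2. * (x - low) / ((high - low) * (peak - low))
  down = 2. * (high - x) / ((high - low) * (high - peak))
  inside = np.where(x <= peak, up, down)
  return np.where((x < low) | (x > high), 0., inside)

# The density peaks at 2 / (high - low), here 2., and is 0 outside [low, high].
assert np.isclose(triangular_pdf(0.5, 0., 0.5, 1.), 2.)
assert triangular_pdf(1.5, 0., 0.5, 1.) == 0.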
Example #2
    def _variance(self):
        tailweight = tf.convert_to_tensor(self.tailweight)
        scale = tf.convert_to_tensor(self.scale)
        # For tailweight < 0.5, the variance is finite. See Eq (18) in
        # https://www.hindawi.com/journals/tswj/2015/909231/
        var = (
            tf.cast(tf.pow(1. - 2. * tailweight, -3. / 2.), dtype=self.dtype) *
            tf.math.square(scale))
        # We need to put the tf.where inside the outer tf.where to ensure we never
        # hit a NaN in the gradient.
        result_where_defined = tf.where(
            tailweight < 0.5, var,
            tf.convert_to_tensor(np.inf, dtype=self.dtype))

        if self.allow_nan_stats:
            return tf.where(tailweight < 1.0, result_where_defined,
                            tf.convert_to_tensor(np.nan, self.dtype))
        else:
            return distribution_util.with_dependencies([
                assert_util.assert_greater_equal(
                    tf.ones([], dtype=self.dtype),
                    tailweight,
                    message=
                    "variance not defined for components of tailweight >= 1"),
            ], result_where_defined)
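Eq (18) in the linked paper gives this variance in closed form. A plain NumPy rendering of the same three regimes (hypothetical scalar inputs, for illustration only):

import numpy as np

def heavy_tail_variance(scale, tailweight):
  # Finite for tailweight < 0.5, infinite for 0.5 <= tailweight < 1,
  # undefined (NaN) for tailweight >= 1; mirrors the nested tf.where above.
  if tailweight >= 1.0:
    return np.nan
  if tailweight >= 0.5:
    return np.inf
  return scale**2 * (1. - 2. * tailweight)**(-1.5)

assert heavy_tail_variance(2., 0.) == 4.  # reduces to scale**2 at tailweight=0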
Example #3
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
  strm = tfp_test_util.test_seed_stream()
  initialization = tfd.MultivariateNormalFullCovariance(
      loc=tf.zeros(event_size),
      covariance_matrix=tf.eye(event_size)).sample(
          batch_size, seed=strm())
  samples, _ = run_nuts_chain(
      event_size, batch_size, num_steps=1,
      initial_state=initialization, **kwargs)
  answer = samples[0][-1]
  check_cdf_agrees = (
      st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
          answer, initialization, num_projections=100, false_fail_rate=1e-6))
  check_sample_shape = assert_util.assert_equal(
      tf.shape(answer)[0], batch_size)
  movement = tf.linalg.norm(answer - initialization, axis=-1)
  # This movement distance (0.3) was copied from the univariate case.
  check_movement = assert_util.assert_greater_equal(
      tf.reduce_mean(movement), 0.3)
  check_enough_power = assert_util.assert_less(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
          batch_size, batch_size, false_fail_rate=1e-8, false_pass_rate=1e-6),
      0.055)
  return (
      check_cdf_agrees,
      check_sample_shape,
      check_movement,
      check_enough_power,
  )
Example #4
def _maybe_assert_float_matrix(logu, validate_args):
    """Assertion check for the scores matrix to be float type."""
    logu = tf.convert_to_tensor(logu, dtype_hint=tf.float32, name='logu')

    if not dtype_util.is_floating(logu.dtype):
        raise TypeError('Input argument must be `float` type.')

    assertions = []
    # Check scores is a matrix.
    msg = 'Input argument must be a (batch of) matrix.'
    rank = tensorshape_util.rank(logu.shape)
    if rank is not None:
        if rank < 2:
            raise ValueError(msg)
    elif validate_args:
        assertions.append(assert_util.assert_rank_at_least(logu, 2, msg))

    # Check scores has the shape [..., N, M], M >= N
    msg = 'Input argument must be a (batch of) matrix of the shape [N, M], M >= N.'
    if (rank is not None
            and tensorshape_util.is_fully_defined(logu.shape[-2:])):
        if logu.shape[-2] > logu.shape[-1]:
            raise ValueError(msg)
    elif validate_args:
        n, m = tf.unstack(tf.shape(logu)[-2:])
        assertions.append(assert_util.assert_greater_equal(m, n, message=msg))
    return assertions
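The function above follows a recurring TFP pattern: validate statically when the rank is known at trace time, and fall back to a graph assertion only under validate_args. A condensed sketch of that pattern in plain TensorFlow (hypothetical helper, not part of the library):

import tensorflow as tf

def maybe_assert_matrix(t, validate_args):
  # Static check when the rank is known; deferred assertion otherwise.
  assertions = []
  if t.shape.rank is not None:
    if t.shape.rank < 2:
      raise ValueError('Expected a (batch of) matrix.')
  elif validate_args:
    assertions.append(tf.debugging.assert_rank_at_least(
        t, 2, message='Expected a (batch of) matrix.'))
  return assertions

# Downstream ops are gated on the returned assertions.
with tf.control_dependencies(maybe_assert_matrix(tf.ones([3, 4]), True)):
  result = tf.reduce_sum(tf.ones([3, 4]))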
Example #5
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    logits = self._logits
    probs = self._probs
    param, name = (probs, 'probs') if logits is None else (logits, 'logits')

    # In init, we can always build shape and dtype checks because
    # we assume shape doesn't change for Variable backed args.
    if is_init:
      if not dtype_util.is_floating(param.dtype):
        raise TypeError('Argument `{}` must have floating type.'.format(name))

      msg = 'Argument `{}` must have rank at least 1.'.format(name)
      shape_static = tensorshape_util.dims(param.shape)
      if shape_static is not None:
        if len(shape_static) < 1:
          raise ValueError(msg)
      elif self.validate_args:
        param = tf.convert_to_tensor(param)
        assertions.append(
            assert_util.assert_rank_at_least(param, 1, message=msg))
        with tf.control_dependencies(assertions):
          param = tf.identity(param)

      msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
      msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
          name, dtype_util.max(tf.int32))
      event_size = shape_static[-1] if shape_static is not None else None
      if event_size is not None:
        if event_size < 1:
          raise ValueError(msg1)
        if event_size > dtype_util.max(tf.int32):
          raise ValueError(msg2)
      elif self.validate_args:
        param = tf.convert_to_tensor(param)
        assertions.append(assert_util.assert_greater_equal(
            tf.shape(param)[-1], 1, message=msg1))
        # NOTE: For now, we leave out a runtime assertion that
        # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
        # will fail before we get to this point.

    if not self.validate_args:
      assert not assertions  # Should never happen.
      return []

    if probs is not None:
      probs = param  # reuse tensor conversion from above
      if is_init != tensor_util.is_ref(probs):
        probs = tf.convert_to_tensor(probs)
        one = tf.ones([], dtype=probs.dtype)
        assertions.extend([
            assert_util.assert_non_negative(probs),
            assert_util.assert_less_equal(probs, one),
            assert_util.assert_near(
                tf.reduce_sum(probs, axis=-1), one,
                message='Argument `probs` must sum to 1.'),
        ])

    return assertions
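The `is_init != is_ref` test defers the entry-wise checks for variable-backed parameters to call time rather than running them once at init. The three `probs` checks condensed into standalone TensorFlow (toy tensor, for illustration):

import tensorflow as tf

probs = tf.constant([0.2, 0.3, 0.5])
one = tf.ones([], dtype=probs.dtype)
checks = [
    tf.debugging.assert_non_negative(probs),                       # each entry >= 0
    tf.debugging.assert_less_equal(probs, one),                    # each entry <= 1
    tf.debugging.assert_near(tf.reduce_sum(probs, axis=-1), one),  # normalized
]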
Example #6
  def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
    """Check whether event_ndims is atleast min_event_ndims."""
    event_ndims = tf.convert_to_tensor(event_ndims, name='event_ndims')
    event_ndims_ = tf.get_static_value(event_ndims)
    assertions = []

    if not dtype_util.is_integer(event_ndims.dtype):
      raise ValueError('Expected integer dtype, got dtype {}'.format(
          event_ndims.dtype))

    if event_ndims_ is not None:
      if tensorshape_util.rank(event_ndims.shape) != 0:
        raise ValueError('Expected scalar event_ndims, got shape {}'.format(
            event_ndims.shape))
      if min_event_ndims > event_ndims_:
        raise ValueError('event_ndims ({}) must be at least '
                         'min_event_ndims ({})'.format(event_ndims_,
                                                       min_event_ndims))
    elif self.validate_args:
      assertions += [
          assert_util.assert_greater_equal(event_ndims, min_event_ndims)
      ]

    if tensorshape_util.is_fully_defined(event_ndims.shape):
      if tensorshape_util.rank(event_ndims.shape) != 0:
        raise ValueError('Expected scalar shape, got ndims {}'.format(
            tensorshape_util.rank(event_ndims.shape)))

    elif self.validate_args:
      assertions += [
          assert_util.assert_rank(event_ndims, 0, message='Expected scalar.')
      ]
    return assertions
Example #7
 def _maybe_assert_valid_sample(self, x, loc):
     """Checks the validity of a sample."""
     if not self.validate_args:
         return []
     return [
         assert_util.assert_greater_equal(
             x, loc, message='x is not in the support of the distribution')
     ]
Example #8
 def _check_arg_and_apply_f(*args, **kwargs):
   dist = args[0]
   x = args[1]
   with tf.control_dependencies([
       assert_util.assert_greater_equal(
           x, dist.loc, message="x is not in the support of the distribution")
   ] if dist.validate_args else []):
     return f(*args, **kwargs)
Example #9
 def _parameter_control_dependencies(self, is_init):
   if not self.validate_args:
     return []
   assertions = []
   if is_init != tensor_util.is_ref(self._num_steps):
     assertions.append(assert_util.assert_greater_equal(
         self._num_steps, 1,
         message='Argument `num_steps` must be at least 1.'))
   return assertions
Example #10
 def _sample_control_dependencies(self, x):
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.append(assert_util.assert_greater_equal(
       x, self.low, message='Sample must be greater than or equal to `low`.'))
   assertions.append(assert_util.assert_less_equal(
       x, self.high, message='Sample must be less than or equal to `high`.'))
   return assertions
Example #11
 def _sample_control_dependencies(self, x):
   """Checks the validity of a sample."""
   assertions = []
   if not self.validate_args:
     return assertions
   loc = tf.convert_to_tensor(self.loc)
   assertions.append(
       assert_util.assert_greater_equal(
           x, loc, message='Sample must be greater than or equal to `loc`.'))
   return assertions
Example #12
 def _maybe_assert_valid_x(self, x):
     if not self.validate_args:
         return []
     return [
         assert_util.assert_greater_equal(
             x,
             self.loc,
             message=
              'Forward transformation input must be greater than or equal to `loc`.')
     ]
Example #13
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        scores = self._scores
        param, name = (scores, 'scores')

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:
            if not dtype_util.is_floating(param.dtype):
                raise TypeError(
                    'Argument `{}` must have floating type.'.format(name))

            msg = 'Argument `{}` must have rank at least 1.'.format(name)
            shape_static = tensorshape_util.dims(param.shape)
            if shape_static is not None:
                if len(shape_static) < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_rank_at_least(param, 1, message=msg))
                with tf.control_dependencies(assertions):
                    param = tf.identity(param)

            msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
            msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
                name, tf.int32.max)
            event_size = shape_static[-1] if shape_static is not None else None
            if event_size is not None:
                if event_size < 1:
                    raise ValueError(msg1)
                if event_size > tf.int32.max:
                    raise ValueError(msg2)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_greater_equal(tf.shape(param)[-1],
                                                     1,
                                                     message=msg1))
                # NOTE: For now, we leave out a runtime assertion that
                # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
                # will fail before we get to this point.

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_ref(scores):
            scores = tf.convert_to_tensor(scores)
            assertions.extend([
                assert_util.assert_positive(scores),
            ])

        return assertions
Example #14
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if is_init != tensor_util.is_ref(self._tailweight):
         assertions.append(
             assert_util.assert_greater_equal(
                 self._tailweight,
                 tf.zeros([], dtype=self.dtype),
                 message="Argument `tailweight` must be non-negative."))
     return assertions
Example #15
  def _prob(self, x):
    concentration = tf.convert_to_tensor(self.concentration)
    scale = tf.convert_to_tensor(self.scale)
    with tf.control_dependencies([
        assert_util.assert_greater_equal(
            x, scale, message='`x` is not in the support of the distribution.')
    ] if self.validate_args else []):

      def prob_on_support(z):
        return concentration * (scale**concentration) / (z**(concentration + 1))

      return self._extend_support(x, scale, prob_on_support, alt=0.)
Example #16
  def _prob(self, x):
    with tf.control_dependencies([
        assert_util.assert_greater_equal(
            x,
            self.scale,
            message="x is not in the support of the distribution.")
    ] if self.validate_args else []):

      def prob_on_support(z):
        return (self.concentration * (self.scale ** self.concentration) /
                (z ** (self.concentration + 1)))
      return self._extend_support(x, prob_on_support, alt=0.)
Example #17
 def _call_quantile(self, value, name, **kwargs):
   with self._name_and_control_scope(name):
     dtype = tf.float32 if tf.nest.is_nested(self.dtype) else self.dtype
     value = tf.convert_to_tensor(value, name='value', dtype_hint=dtype)
     if self.validate_args:
       value = distribution_util.with_dependencies([
           assert_util.assert_less_equal(value, tf.cast(1, value.dtype),
                                         message='`value` must be <= 1'),
           assert_util.assert_greater_equal(value, tf.cast(0, value.dtype),
                                            message='`value` must be >= 0')
       ], value)
     return self._quantile(value, **kwargs)
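`_call_quantile` guards its argument to [0, 1] before dispatching to `_quantile`. A standalone sketch of the same guard in plain TensorFlow (hypothetical wrapper, illustrative only):

import tensorflow as tf

def checked_quantile(quantile_fn, value):
  # Probabilities must lie in [0, 1]; fail fast otherwise.
  value = tf.convert_to_tensor(value, dtype_hint=tf.float32)
  with tf.control_dependencies([
      tf.debugging.assert_greater_equal(value, tf.cast(0, value.dtype),
                                        message='`value` must be >= 0'),
      tf.debugging.assert_less_equal(value, tf.cast(1, value.dtype),
                                     message='`value` must be <= 1'),
  ]):
    return quantile_fn(tf.identity(value))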
Example #18
 def _assertions(self, t):
   if not self.validate_args:
     return []
   return [
       assert_util.assert_greater_equal(
           t,
           dtype_util.as_numpy_dtype(t.dtype)(-1),
           message="Inverse transformation input must be >= -1."),
       assert_util.assert_less_equal(
           t,
           dtype_util.as_numpy_dtype(t.dtype)(1),
           message="Inverse transformation input must be <= 1.")
   ]
Example #19
    def _log_prob(self, x):
        with tf.control_dependencies([
                assert_util.assert_greater_equal(
                    x,
                    self.scale,
                    message="x is not in the support of the distribution.")
        ] if self.validate_args else []):

            def log_prob_on_support(z):
                return (tf.math.log(self.concentration) +
                        self.concentration * tf.math.log(self.scale) -
                        (self.concentration + 1.) * tf.math.log(z))

            return self._extend_support(x, log_prob_on_support, alt=-np.inf)
Example #20
    def _assert_compatible_shape(self, index, sample_shape, samples):
        requested_shape, _ = self._expand_sample_shape_to_vector(
            tf.convert_to_tensor(sample_shape, dtype=tf.int32),
            name='requested_shape')
        actual_shape = prefer_static.shape(samples)
        actual_rank = prefer_static.rank_from_shape(actual_shape)
        requested_rank = prefer_static.rank_from_shape(requested_shape)

        # We test for two properties we expect of yielded distributions:
        # (1) The rank of the tensor of generated samples must be at least
        #     as large as the rank requested.
        # (2) The requested shape must be a prefix of the shape of the
        #     generated tensor of samples.
        # We attempt to perform test (1) statically first.
        # We don't need to do this explicitly for test (2) because
        # `assert_equal` evaluates statically if it can.
        static_actual_rank = tf.get_static_value(actual_rank)
        static_requested_rank = tf.get_static_value(requested_rank)

        assertion_message = ('Samples yielded by distribution #{} are not '
                             'consistent with `sample_shape` passed to '
                             '`JointDistributionCoroutine` '
                             'distribution.'.format(index))

        # TODO Remove this static check (b/138738650)
        if (static_actual_rank is not None
                and static_requested_rank is not None):
            # We're able to statically check the rank
            if static_actual_rank < static_requested_rank:
                raise ValueError(assertion_message)
            else:
                control_dependencies = []
        else:
            # We're not able to statically check the rank
            control_dependencies = [
                assert_util.assert_greater_equal(actual_rank,
                                                 requested_rank,
                                                 message=assertion_message)
            ]

        with tf.control_dependencies(control_dependencies):
            trimmed_actual_shape = actual_shape[:requested_rank]

        control_dependencies = [
            assert_util.assert_equal(requested_shape,
                                     trimmed_actual_shape,
                                     message=assertion_message)
        ]

        return control_dependencies
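Test (2) above asserts that the requested shape is a prefix of the realized sample shape. A toy eager-mode illustration of that prefix property (shapes invented):

import tensorflow as tf

requested = tf.constant([3, 2], dtype=tf.int32)   # sample_shape = [3, 2]
samples = tf.zeros([3, 2, 5])                     # an event dim of size 5 appended
trimmed = tf.shape(samples)[:tf.size(requested)]
tf.debugging.assert_equal(requested, trimmed)     # passes: [3, 2] == [3, 2]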
Example #21
  def _log_prob(self, x):
    concentration = tf.convert_to_tensor(self.concentration)
    scale = tf.convert_to_tensor(self.scale)
    with tf.control_dependencies([
        assert_util.assert_greater_equal(
            x, scale, message='`x` is not in the support of the distribution.')
    ] if self.validate_args else []):

      def log_prob_on_support(z):
        # This can also be written as log(c) + c * log(s) - (c + 1) * log(z).
        # However, when c >> 1 and s and z are of the same magnitude, the
        # terms c * log(s) and (c + 1) * log(z) nearly cancel, so that form
        # loses precision; the ratio form below avoids the cancellation.
        return (tf.math.log(concentration / z) +
                concentration * tf.math.log(scale / z))

      return self._extend_support(
          x, scale, log_prob_on_support, alt=-np.inf)
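The comment above concerns catastrophic cancellation in the Pareto log-pdf. A small NumPy check (illustrative values) that the ratio form agrees with the naive expansion, as it must in exact arithmetic:

import numpy as np

c, s, z = 100., 2.0, 2.5   # concentration >> 1; scale and x of similar size
naive = np.log(c) + c * np.log(s) - (c + 1.) * np.log(z)
stable = np.log(c / z) + c * np.log(s / z)
assert np.isclose(naive, stable)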
Example #22
 def _sample_control_dependencies(self, x):
   assertions = []
   if not self.validate_args:
     return assertions
   loc = tf.convert_to_tensor(self.loc)
   scale = tf.convert_to_tensor(self.scale)
   concentration = tf.convert_to_tensor(self.concentration)
   assertions.append(assert_util.assert_greater_equal(
       x, loc, message='Sample must be greater than or equal to `loc`.'))
   assertions.append(assert_util.assert_equal(
       tf.logical_or(tf.greater_equal(concentration, 0),
                     tf.less_equal(x, loc - scale / concentration)),
       True,
       message=('If `concentration < 0`, sample must be less than or '
                'equal to `loc - scale / concentration`.'),
       summarize=100))
   return assertions
Example #23
def assert_univariate_target_conservation(test, target_d, step_size):
    # Sample count limited partly by memory reliably available on Forge.  The test
    # remains reasonable even if the nuts recursion limit is severely curtailed
    # (e.g., 3 or 4 levels), so use that to recover some memory footprint and bump
    # the sample count.
    num_samples = int(5e4)
    num_steps = 1
    strm = test_util.test_seed_stream()
    # We wrap the initial values in `tf.identity` to avoid broken gradients
    # resulting from a bijector cache hit, since bijectors of the same
    # type/parameterization now share a cache.
    # TODO(b/72831017): Fix broken gradients caused by bijector caching.
    initialization = tf.identity(target_d.sample([num_samples], seed=strm()))

    @tf.function(autograph=False)
    def run_chain():
        nuts = tfp.experimental.mcmc.PreconditionedNoUTurnSampler(
            target_d.log_prob,
            step_size=step_size,
            max_tree_depth=3,
            unrolled_leapfrog_steps=2)
        result = tfp.mcmc.sample_chain(num_results=num_steps,
                                       num_burnin_steps=0,
                                       current_state=initialization,
                                       trace_fn=None,
                                       kernel=nuts,
                                       seed=strm())
        return result

    result = run_chain()
    test.assertAllEqual([num_steps, num_samples], result.shape)
    answer = result[0]
    check_cdf_agrees = st.assert_true_cdf_equal_by_dkwm(answer,
                                                        target_d.cdf,
                                                        false_fail_rate=1e-6)
    check_enough_power = assert_util.assert_less(
        st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
            num_samples, false_fail_rate=1e-6, false_pass_rate=1e-6), 0.025)
    movement = tf.abs(answer - initialization)
    test.assertAllEqual([num_samples], movement.shape)
    # This movement distance (1 * step_size) was selected by reducing until 100
    # runs with independent seeds all passed.
    check_movement = assert_util.assert_greater_equal(tf.reduce_mean(movement),
                                                      1 * step_size)
    return (check_cdf_agrees, check_enough_power, check_movement)
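The power check inverts the Dvoretzky-Kiefer-Wolfowitz bound P(sup |F_n - F| > eps) <= 2 exp(-2 n eps^2). As a rough sketch of where the 0.025 gate comes from (ignoring the false-pass side handled by the st helper), the discrepancy detectable at a given false-fail rate is:

import numpy as np

def dkw_eps(n, false_fail_rate):
  # Smallest eps the two-sided DKW bound certifies at the given rate.
  return np.sqrt(np.log(2. / false_fail_rate) / (2. * n))

print(dkw_eps(int(5e4), 1e-6))  # ~0.012, comfortably under the 0.025 gate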
Example #24
def _check_at_least_two_chains(accept_prob, reduce_chain_axis_names,
                               validate_args, message):
    """Checks that the number of chains is at least 2."""
    # Number of total chains is local batch size * distributed axis size
    local_axis_size = ps.size(accept_prob)
    distributed_axis_size = int(
        ps.reduce_prod([
            distribute_lib.get_axis_size(a) for a in reduce_chain_axis_names
        ]))
    num_chains = local_axis_size * distributed_axis_size
    num_chains_ = tf.get_static_value(num_chains)
    if num_chains_ is not None:
        if num_chains_ < 2:
            raise ValueError('{} Got: {}'.format(message, num_chains_))
    elif validate_args:
        with tf.control_dependencies(
            [assert_util.assert_greater_equal(num_chains, 2, message=message)]):
            accept_prob = tf.identity(accept_prob)
    return accept_prob
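`tf.get_static_value` lets the chain-count check run at trace time when possible, so graph-mode users get an immediate Python ValueError instead of a deferred op failure. A condensed sketch of that shortcut (toy shapes):

import tensorflow as tf

num_chains = tf.size(tf.zeros([4, 3]))      # 12 chains in this toy setup
num_chains_ = tf.get_static_value(num_chains)
if num_chains_ is not None:                 # resolvable without running the graph
  assert num_chains_ >= 2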
Example #25
def assert_univariate_target_conservation(test, target_d, step_size):
  # Sample count limited partly by memory reliably available on Forge.  The test
  # remains reasonable even if the nuts recursion limit is severely curtailed
  # (e.g., 3 or 4 levels), so use that to recover some memory footprint and bump
  # the sample count.
  num_samples = int(5e4)
  num_steps = 1
  strm = tfp.util.SeedStream(salt='univariate_nuts_test', seed=1)
  initialization = target_d.sample([num_samples], seed=strm())

  @tf.function(autograph=False)
  def run_chain():
    nuts = tfp.mcmc.NoUTurnSampler(
        target_d.log_prob,
        step_size=step_size,
        max_tree_depth=3,
        unrolled_leapfrog_steps=2,
        seed=strm())
    result, _ = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=0,
        current_state=initialization,
        kernel=nuts)
    return result

  result = run_chain()
  test.assertAllEqual([num_steps, num_samples], result.shape)
  answer = result[0]
  check_cdf_agrees = st.assert_true_cdf_equal_by_dkwm(
      answer, target_d.cdf, false_fail_rate=1e-6)
  check_enough_power = assert_util.assert_less(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
          num_samples, false_fail_rate=1e-6, false_pass_rate=1e-6), 0.025)
  movement = tf.abs(answer - initialization)
  test.assertAllEqual([num_samples], movement.shape)
  # This movement distance (1 * step_size) was selected by reducing until 100
  # runs with independent seeds all passed.
  check_movement = assert_util.assert_greater_equal(
      tf.reduce_mean(movement), 1 * step_size)
  return (check_cdf_agrees, check_enough_power, check_movement)
Example #26
  def _prob(self, x):
    if self.validate_args:
      with tf.control_dependencies([
          assert_util.assert_greater_equal(x, self.low),
          assert_util.assert_less_equal(x, self.high)
      ]):
        x = tf.identity(x)

    interval_length = self.high - self.low
    # This is the pdf when low <= x <= high. The pdf graph is a triangle,
    # so we treat each line segment separately.
    result_inside_interval = tf.where(
        (x >= self.low) & (x <= self.peak),
        # Line segment from (self.low, 0) to (self.peak, 2 / (self.high -
        # self.low)).
        2. * (x - self.low) / (interval_length * (self.peak - self.low)),
        # Line segment from (self.peak, 2 / (self.high - self.low)) to
        # (self.high, 0).
        2. * (self.high - x) / (interval_length * (self.high - self.peak)))

    return tf.where((x < self.low) | (x > self.high),
                    tf.zeros_like(x),
                    result_inside_interval)
Example #27
    def _prob(self, x):
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        if self.validate_args:
            with tf.control_dependencies([
                    assert_util.assert_greater_equal(x, low),
                    assert_util.assert_less_equal(x, high)
            ]):
                x = tf.identity(x)

        interval_length = high - low
        # This is the pdf when low <= x <= high. The pdf graph is a
        # triangle, so we treat each line segment separately.
        result_inside_interval = tf.where(
            (x >= low) & (x <= peak),
            # Line segment from (low, 0) to (peak, 2 / (high - low)).
            2. * (x - low) / (interval_length * (peak - low)),
            # Line segment from (peak, 2 / (high - low)) to (high, 0).
            2. * (high - x) / (interval_length * (high - peak)))

        return tf.where((x < low) | (x > high), tf.zeros_like(x),
                        result_inside_interval)
Example #28
    def __init__(self,
                 initial_distribution,
                 transition_distribution,
                 observation_distribution,
                 num_steps,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HiddenMarkovModel"):
        """Initialize hidden Markov model.

    Args:
      initial_distribution: A `Categorical`-like instance.
        Determines probability of first hidden state in Markov chain.
        The number of categories must match the number of categories of
        `transition_distribution` as well as both the rightmost batch
        dimension of `transition_distribution` and the rightmost batch
        dimension of `observation_distribution`.
      transition_distribution: A `Categorical`-like instance.
        The rightmost batch dimension indexes the probability distribution
        of each hidden state conditioned on the previous hidden state.
      observation_distribution: A `tfp.distributions.Distribution`-like
        instance.  The rightmost batch dimension indexes the distribution
        of each observation conditioned on the corresponding hidden state.
      num_steps: The number of steps taken in Markov chain. A python `int`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Raises:
      ValueError: if `num_steps` is not at least 1.
      ValueError: if `initial_distribution` does not have scalar `event_shape`.
      ValueError: if `transition_distribution` does not have scalar
        `event_shape.`
      ValueError: if `transition_distribution` and `observation_distribution`
        are fully defined but don't have matching rightmost dimension.
    """

        parameters = dict(locals())

        # pylint: disable=protected-access
        with tf.name_scope(name) as name:
            self._runtime_assertions = []  # pylint: enable=protected-access

            num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.rank(num_steps),
                        0,
                        message="`num_steps` must be a scalar")
                ]
                self._runtime_assertions += [
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message="`num_steps` must be at least 1.")
                ]

            self._initial_distribution = initial_distribution
            self._observation_distribution = observation_distribution
            self._transition_distribution = transition_distribution

            if (initial_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        initial_distribution.event_shape) != 0):
                raise ValueError(
                    "`initial_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(initial_distribution.event_shape_tensor())[0],
                        0,
                        message="`initial_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (transition_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.event_shape) != 0):
                raise ValueError(
                    "`transition_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(
                            transition_distribution.event_shape_tensor())[0],
                        0,
                        message="`transition_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (tensorshape_util.dims(transition_distribution.batch_shape)
                    is not None and tensorshape_util.rank(
                        transition_distribution.batch_shape) == 0):
                raise ValueError(
                    "`transition_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(transition_distribution.batch_shape_tensor()),
                        0,
                        message="`transition_distribution` can't have scalar "
                        "batches")
                ]

            if (tensorshape_util.dims(observation_distribution.batch_shape)
                    is not None and tensorshape_util.rank(
                        observation_distribution.batch_shape) == 0):
                raise ValueError(
                    "`observation_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(observation_distribution.batch_shape_tensor()),
                        0,
                        message="`observation_distribution` can't have scalar "
                        "batches")
                ]

            # Infer number of hidden states and check consistency
            # between transitions and observations
            with tf.control_dependencies(self._runtime_assertions):
                self._num_states = (
                    (tensorshape_util.dims(transition_distribution.batch_shape)
                     is not None and tensorshape_util.as_list(
                         transition_distribution.batch_shape)[-1])
                    or transition_distribution.batch_shape_tensor()[-1])

                observation_states = (
                    (tensorshape_util.dims(
                        observation_distribution.batch_shape) is not None
                     and tensorshape_util.as_list(
                         observation_distribution.batch_shape)[-1])
                    or observation_distribution.batch_shape_tensor()[-1])

            if (tf.is_tensor(self._num_states)
                    or tf.is_tensor(observation_states)):
                if validate_args:
                    self._runtime_assertions += [
                        assert_util.assert_equal(
                            self._num_states,
                            observation_states,
                            message="`transition_distribution` and "
                            "`observation_distribution` must agree on "
                            "last dimension of batch size")
                    ]
            elif self._num_states != observation_states:
                raise ValueError("`transition_distribution` and "
                                 "`observation_distribution` must agree on "
                                 "last dimension of batch size")

            self._log_init = _extract_log_probs(self._num_states,
                                                initial_distribution)
            self._log_trans = _extract_log_probs(self._num_states,
                                                 transition_distribution)

            self._num_steps = num_steps
            self._num_states = tf.shape(self._log_init)[-1]

            self._underlying_event_rank = tf.size(
                self._observation_distribution.event_shape_tensor())

            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                self.static_event_shape = tf.TensorShape([
                    num_steps_
                ]).concatenate(self._observation_distribution.event_shape)
            else:
                self.static_event_shape = None

            with tf.control_dependencies(self._runtime_assertions):
                self.static_batch_shape = tf.broadcast_static_shape(
                    self._initial_distribution.batch_shape,
                    tf.broadcast_static_shape(
                        self._transition_distribution.batch_shape[:-1],
                        self._observation_distribution.batch_shape[:-1]))

            # pylint: disable=protected-access
            super(HiddenMarkovModel, self).__init__(
                dtype=self._observation_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
            # pylint: enable=protected-access

            self._parameters = parameters
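The shape contract in the docstring (scalar events; matching rightmost batch dimensions) is easiest to see in a usage sketch. The values below are illustrative; the constructor call mirrors the documented TFP API:

import tensorflow_probability as tfp
tfd = tfp.distributions

# Two hidden states; transition rows and observation batch both have size 2.
hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7)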
Example #29
  def _sample_n(self, n, seed=None):
    dim0_seed, otherdims_seed = samplers.split_seed(seed,
                                                    salt='von_mises_fisher')
    # The sampling strategy relies on the fact that vMF variates are symmetric
    # about the mean direction. Accordingly, if we have a sampling strategy for
    # the away-from-mean angle, then we can uniformly sample the remaining
    # dimensions on the S^{dim-2} sphere, and rotate these samples from a
    # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
    #
    # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
    # von-Mises distributed `x` value in [-1, 1], then uniformly select what
    # amounts to a "up" or "down" additional degree of freedom after unit
    # normalizing, followed by a final rotation to the desired mean direction
    # from a basis of (1, 0).
    #
    # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
    # unit sphere over which the distribution is uniform, in particular the
    # circle where x = \hat{x} intersects the unit sphere. We pick a point on
    # that circle, then rotate to the desired mean direction from a basis of
    # (1, 0, 0).
    mean_direction = tf.convert_to_tensor(self.mean_direction)
    concentration = tf.convert_to_tensor(self.concentration)
    event_dim = (
        tf.compat.dimension_value(self.event_shape[0]) or
        self._event_shape_tensor(mean_direction=mean_direction)[0])

    sample_batch_shape = ps.concat([[n], self._batch_shape_tensor(
        mean_direction=mean_direction, concentration=concentration)], axis=0)
    dim = tf.cast(event_dim - 1, self.dtype)
    if event_dim == 3:
      samples_dim0 = self._sample_3d(n,
                                     mean_direction=mean_direction,
                                     concentration=concentration,
                                     seed=dim0_seed)
    else:
      # Wood'94 provides a rejection algorithm to sample the x coordinate.
      # Wood'94 definition of b:
      # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
      # https://stats.stackexchange.com/questions/156729 suggests:
      b = dim / (2 * concentration +
                 tf.sqrt(4 * concentration**2 + dim**2))
      # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
      #     https://github.com/nicola-decao/s-vae-tf/
      x = (1 - b) / (1 + b)
      c = concentration * x + dim * tf.math.log1p(-x**2)
      beta = beta_lib.Beta(dim / 2, dim / 2)

      def cond_fn(w, should_continue, seed):
        del w, seed
        return tf.reduce_any(should_continue)

      def body_fn(w, should_continue, seed):
        """While loop body for sampling the angle `w`."""
        beta_seed, unif_seed, next_seed = samplers.split_seed(seed, n=3)
        z = beta.sample(sample_shape=sample_batch_shape, seed=beta_seed)
        # set_shape needed here because of b/139013403
        tensorshape_util.set_shape(z, w.shape)
        w = tf.where(should_continue,
                     (1. - (1. + b) * z) / (1. - (1. - b) * z),
                     w)
        if not self.allow_nan_stats:
          w = tf.debugging.check_numerics(w, 'w')
        unif = samplers.uniform(
            sample_batch_shape, seed=unif_seed, dtype=self.dtype)
        # set_shape needed here because of b/139013403
        tensorshape_util.set_shape(unif, w.shape)
        should_continue = should_continue & (
            concentration * w + dim * tf.math.log1p(-x * w) - c <
            # Use log1p(-unif) to prevent log(0) and ensure that log(1) is
            # possible.
            tf.math.log1p(-unif))
        return w, should_continue, next_seed

      w = tf.zeros(sample_batch_shape, dtype=self.dtype)
      should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
      samples_dim0, _, _ = tf.while_loop(
          cond=cond_fn, body=body_fn,
          loop_vars=(w, should_continue, dim0_seed))
      samples_dim0 = samples_dim0[..., tf.newaxis]
    if not self._allow_nan_stats:
      # Verify samples are w/in -1, 1, with useful error output tensors (top
      # value rather than all values).
      with tf.control_dependencies([
          assert_util.assert_less_equal(
              samples_dim0,
              dtype_util.as_numpy_dtype(self.dtype)(1.01)),
          assert_util.assert_greater_equal(
              samples_dim0,
              dtype_util.as_numpy_dtype(self.dtype)(-1.01)),
      ]):
        samples_dim0 = tf.identity(samples_dim0)
    samples_otherdims_shape = ps.concat([sample_batch_shape, [event_dim - 1]],
                                        axis=0)
    unit_otherdims = tf.math.l2_normalize(
        samplers.normal(
            samples_otherdims_shape, seed=otherdims_seed, dtype=self.dtype),
        axis=-1)
    samples = tf.concat([
        samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
        tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
    ], axis=-1)
    samples = tf.math.l2_normalize(samples, axis=-1)
    if not self.allow_nan_stats:
      samples = tf.debugging.check_numerics(samples, 'samples')

    # Runtime assert that samples are unit length.
    if not self.allow_nan_stats:
      worst, _ = tf.math.top_k(
          tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
      with tf.control_dependencies([
          assert_util.assert_near(
              dtype_util.as_numpy_dtype(self.dtype)(0),
              worst,
              atol=1e-4,
              summarize=100)
      ]):
        samples = tf.identity(samples)
    # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
    # Now, we move the mode to `self.mean_direction` using a rotation matrix.
    if not self.allow_nan_stats:
      # Assert that the basis vector rotates to the mean direction, as expected.
      basis = tf.cast(tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                      self.dtype)
      with tf.control_dependencies([
          assert_util.assert_less(
              tf.linalg.norm(
                  self._rotate(basis, mean_direction=mean_direction) -
                  mean_direction, axis=-1),
              dtype_util.as_numpy_dtype(self.dtype)(1e-5))
      ]):
        return self._rotate(samples, mean_direction=mean_direction)
    return self._rotate(samples, mean_direction=mean_direction)
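The b used above is the algebraic rewrite of Wood '94 suggested in the linked Stack Exchange post; the two forms are equal in exact arithmetic, but the original subtracts nearly equal quantities. A NumPy demonstration (values chosen to expose the cancellation):

import numpy as np

kappa, dim = 1e8, 3.
wood = (-2. * kappa + np.sqrt(4. * kappa**2 + dim**2)) / dim
stable = dim / (2. * kappa + np.sqrt(4. * kappa**2 + dim**2))
# Identical algebraically; in float64 the Wood form loses most of its digits.
print(wood, stable)  # stable is ~7.5e-9; wood is dominated by rounding error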
Example #30
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    # Check num_steps is a scalar that's at least 1.
    if is_init != tensor_util.is_ref(self.num_steps):
      num_steps = tf.convert_to_tensor(self.num_steps)
      num_steps_ = tf.get_static_value(num_steps)
      if num_steps_ is not None:
        if np.ndim(num_steps_) != 0:
          raise ValueError(
              '`num_steps` must be a scalar but it has rank {}'.format(
                  np.ndim(num_steps_)))
        if num_steps_ < 1:
          raise ValueError('`num_steps` must be at least 1.')
      elif self.validate_args:
        message = '`num_steps` must be a scalar'
        assertions.append(
            assert_util.assert_rank_at_most(self.num_steps, 0, message=message))
        assertions.append(
            assert_util.assert_greater_equal(
                num_steps, 1,
                message='`num_steps` must be at least 1.'))

    # Check that the initial distribution has scalar events over the
    # integers.
    if is_init and not dtype_util.is_integer(self.initial_distribution.dtype):
      raise ValueError(
          '`initial_distribution.dtype` ({}) is not over integers'.format(
              dtype_util.name(self.initial_distribution.dtype)))

    if tensorshape_util.rank(self.initial_distribution.event_shape) is not None:
      if tensorshape_util.rank(self.initial_distribution.event_shape) != 0:
        raise ValueError('`initial_distribution` must have scalar `event_dim`s')
    elif self.validate_args:
      assertions += [
          assert_util.assert_equal(
              ps.size(self.initial_distribution.event_shape_tensor()),
              0,
              message='`initial_distribution` must have scalar `event_dim`s'),
      ]

    # Check that the transition distribution is over the integers.
    if (is_init and
        not dtype_util.is_integer(self.transition_distribution.dtype)):
      raise ValueError(
          '`transition_distribution.dtype` ({}) is not over integers'.format(
              dtype_util.name(self.transition_distribution.dtype)))

    # Check observations have non-scalar batches.
    # The graph version of this assertion is incorporated as
    # a control dependency of the transition/observation
    # compatibility test.
    if tensorshape_util.rank(self.observation_distribution.batch_shape) == 0:
      raise ValueError(
          "`observation_distribution` can't have scalar batches")

    # Check transitions have non-scalar batches.
    # The graph version of this assertion is incorporated as
    # a control dependency of the transition/observation
    # compatibility test.
    if tensorshape_util.rank(self.transition_distribution.batch_shape) == 0:
      raise ValueError(
          "`transition_distribution` can't have scalar batches")

    # Check compatibility of transition distribution and observation
    # distribution.
    tdbs = self.transition_distribution.batch_shape
    odbs = self.observation_distribution.batch_shape
    if (tensorshape_util.dims(tdbs) is not None and
        tf.compat.dimension_value(odbs[-1]) is not None):
      if (tf.compat.dimension_value(tdbs[-1]) !=
          tf.compat.dimension_value(odbs[-1])):
        raise ValueError(
            '`transition_distribution` and `observation_distribution` '
            'must agree on last dimension of batch size')
    elif self.validate_args:
      tdbs = self.transition_distribution.batch_shape_tensor()
      odbs = self.observation_distribution.batch_shape_tensor()
      transition_precondition = assert_util.assert_greater(
          ps.size(tdbs), 0,
          message=('`transition_distribution` can\'t have scalar '
                   'batches'))
      observation_precondition = assert_util.assert_greater(
          ps.size(odbs), 0,
          message=('`observation_distribution` can\'t have scalar '
                   'batches'))
      with tf.control_dependencies([
          transition_precondition,
          observation_precondition]):
        assertions += [
            assert_util.assert_equal(
                tdbs[-1],
                odbs[-1],
                message=('`transition_distribution` and '
                         '`observation_distribution` '
                         'must agree on last dimension of batch size'))]

    return assertions