Example #1
 def _inverse(self, y):
     map_values = tf.convert_to_tensor(self.map_values)
     flat_y = tf.reshape(y, shape=[-1])
     # Search for the indices of map_values that are closest to flat_y.
     # Since map_values is strictly increasing, the closest is either the
     # first one that is strictly greater than flat_y, or the one before it.
     upper_candidates = tf.minimum(
         tf.size(map_values) - 1,
         tf.searchsorted(map_values, values=flat_y, side='right'))
     lower_candidates = tf.maximum(0, upper_candidates - 1)
     candidates = tf.stack([lower_candidates, upper_candidates], axis=-1)
     lower_cand_diff = tf.abs(flat_y - self._forward(lower_candidates))
     upper_cand_diff = tf.abs(flat_y - self._forward(upper_candidates))
     if self.validate_args:
         with tf.control_dependencies([
                 assert_util.assert_near(tf.minimum(lower_cand_diff,
                                                    upper_cand_diff),
                                         0,
                                         message='inverse value not found')
         ]):
             candidates = tf.identity(candidates)
     candidate_selector = tf.stack([
         tf.range(tf.size(flat_y), dtype=tf.int32),
         tf.argmin([lower_cand_diff, upper_cand_diff], output_type=tf.int32)
     ],
                                   axis=-1)
     return tf.reshape(tf.gather_nd(candidates, candidate_selector),
                       shape=y.shape)
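The lookup trick above is easier to see in isolation. Below is a minimal, self-contained sketch of the same nearest-neighbor search via tf.searchsorted (the map_values and y here are made-up illustrative inputs, not from the original bijector):

import tensorflow as tf

map_values = tf.constant([0.0, 1.0, 4.0, 9.0])
y = tf.constant([0.9, 4.2, 8.0])
# Index of the first element strictly greater than each y, clipped to range.
upper = tf.minimum(tf.size(map_values) - 1,
                   tf.searchsorted(map_values, y, side='right'))
lower = tf.maximum(0, upper - 1)
# Pick whichever neighbor is closer to y.
pick_upper = (tf.gather(map_values, upper) - y) < (y - tf.gather(map_values, lower))
nearest_idx = tf.where(pick_upper, upper, lower)
print(nearest_idx.numpy())  # [1 2 3]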
Example #2
  def _parameter_control_dependencies(self, is_init):
    if not self.validate_args:
      return []
    mean_direction = tf.convert_to_tensor(self.mean_direction)
    concentration = tf.convert_to_tensor(self.concentration)

    assertions = []
    if is_init != tensor_util.is_ref(self._mean_direction):
      assertions.append(
          assert_util.assert_greater(
              tf.shape(mean_direction)[-1],
              1,
              message='`mean_direction` may not have scalar event shape'))
      assertions.append(
          assert_util.assert_less_equal(
              tf.shape(mean_direction)[-1],
              5,
              message='von Mises-Fisher ndims > 5 is not currently supported'))
      assertions.append(
          assert_util.assert_near(
              1.,
              tf.linalg.norm(mean_direction, axis=-1),
              message='`mean_direction` must be unit-length'))
    if is_init != tensor_util.is_ref(self._concentration):
      assertions.append(
          assert_util.assert_non_negative(
              concentration, message='`concentration` must be non-negative'))
    return assertions
Example #3
    def _sample_control_dependencies(self, x):
        assertions = []
        if tensorshape_util.is_fully_defined(x.shape[-2:]):
            if not (tensorshape_util.dims(x.shape)[-2] ==
                    tensorshape_util.dims(x.shape)[-1] == self.dimension):
                raise ValueError(
                    'Input dimension mismatch: expected [..., {}, {}], got {}'.
                    format(self.dimension, self.dimension,
                           tensorshape_util.dims(x.shape)))
        elif self.validate_args:
            msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
                self.dimension, self.dimension, tf.shape(x))
            assertions.append(
                assert_util.assert_equal(tf.shape(x)[-2],
                                         self.dimension,
                                         message=msg))
            assertions.append(
                assert_util.assert_equal(tf.shape(x)[-1],
                                         self.dimension,
                                         message=msg))

        if self.validate_args:
            assertions.append(
                assert_util.assert_near(
                    x,
                    tf.linalg.band_part(x, -1, 0),
                    message='Cholesky factors must be lower triangular.'))
        return assertions
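The lower-triangularity check relies on the semantics of tf.linalg.band_part: with num_lower=-1 and num_upper=0 it keeps the lower triangle and zeroes everything above the diagonal, so x equals band_part(x, -1, 0) exactly when x is lower triangular. A quick sketch with an illustrative matrix:

import tensorflow as tf

x = tf.constant([[1.0, 0.0],
                 [0.5, 2.0]])
print(tf.reduce_all(x == tf.linalg.band_part(x, -1, 0)).numpy())  # True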
Example #4
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    logits = self._logits
    probs = self._probs
    param, name = (probs, 'probs') if logits is None else (logits, 'logits')

    # In init, we can always build shape and dtype checks because
    # we assume shape doesn't change for Variable backed args.
    if is_init:
      if not dtype_util.is_floating(param.dtype):
        raise TypeError('Argument `{}` must have floating type.'.format(name))

      msg = 'Argument `{}` must have rank at least 1.'.format(name)
      shape_static = tensorshape_util.dims(param.shape)
      if shape_static is not None:
        if len(shape_static) < 1:
          raise ValueError(msg)
      elif self.validate_args:
        param = tf.convert_to_tensor(param)
        assertions.append(
            assert_util.assert_rank_at_least(param, 1, message=msg))
        with tf.control_dependencies(assertions):
          param = tf.identity(param)

      msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
      msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
          name, dtype_util.max(tf.int32))
      event_size = shape_static[-1] if shape_static is not None else None
      if event_size is not None:
        if event_size < 1:
          raise ValueError(msg1)
        if event_size > dtype_util.max(tf.int32):
          raise ValueError(msg2)
      elif self.validate_args:
        param = tf.convert_to_tensor(param)
        assertions.append(assert_util.assert_greater_equal(
            tf.shape(param)[-1], 1, message=msg1))
        # NOTE: For now, we leave out a runtime assertion that
        # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
        # will fail before we get to this point.

    if not self.validate_args:
      assert not assertions  # Should never happen.
      return []

    if probs is not None:
      probs = param  # reuse tensor conversion from above
      if is_init != tensor_util.is_ref(probs):
        probs = tf.convert_to_tensor(probs)
        one = tf.ones([], dtype=probs.dtype)
        assertions.extend([
            assert_util.assert_non_negative(probs),
            assert_util.assert_less_equal(probs, one),
            assert_util.assert_near(
                tf.reduce_sum(probs, axis=-1), one,
                message='Argument `probs` must sum to 1.'),
        ])

    return assertions
Example #5
  def _inverse(self, y):
    # To derive the inverse mapping note that:
    #   y[i] = exp(x[i]) / normalization
    # and
    #   y[end] = 1 / normalization.
    # Thus:
    # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization)
    #      = log(exp(x[i])/normalization) - log(y[end])
    #      = log(y[i]) - log(y[end])

    # Do this first to make sure CSE catches that it'll happen again in
    # _inverse_log_det_jacobian.

    assertions = []
    if self.validate_args:
      assertions.append(assert_util.assert_near(
          tf.reduce_sum(y, axis=-1),
          tf.ones([], y.dtype),
          2. * np.finfo(dtype_util.as_numpy_dtype(y.dtype)).eps,
          message='Last dimension of `y` must sum to `1`.'))
      assertions.append(assert_util.assert_less_equal(
          y, tf.ones([], y.dtype),
          message='Elements of `y` must be less than or equal to `1`.'))
      assertions.append(assert_util.assert_non_negative(
          y, message='Elements of `y` must be non-negative.'))

    with tf.control_dependencies(assertions):
      x = tf.math.log(y)
      x, log_normalization = tf.split(x, num_or_size_splits=[-1, 1], axis=-1)
    return x - log_normalization
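The derivation in the comment can be checked numerically. A minimal sketch (with made-up logits) of the round trip x -> softmax(concat([x, 0])) -> x:

import tensorflow as tf

x = tf.constant([[0.5, -1.0, 2.0]])
# Forward: append a zero logit, then softmax onto the simplex.
y = tf.nn.softmax(tf.concat([x, tf.zeros_like(x[..., :1])], axis=-1), axis=-1)
# Inverse per the comment: x[i] = log(y[i]) - log(y[end]).
x_recovered = tf.math.log(y[..., :-1]) - tf.math.log(y[..., -1:])
print(tf.reduce_max(tf.abs(x - x_recovered)).numpy())  # ~0 up to float error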
Example #6
    def _parameter_control_dependencies(self, is_init):
        if not self.validate_args:
            return []
        mean_direction = tf.convert_to_tensor(self.mean_direction)
        concentration = tf.convert_to_tensor(self.concentration)

        assertions = []
        if is_init != tensor_util.is_ref(self._mean_direction):
            assertions.append(
                assert_util.assert_greater(
                    tf.shape(mean_direction)[-1],
                    1,
                    message=
                    '`mean_direction` must be a vector of at least size 2.'))
            assertions.append(
                assert_util.assert_near(
                    tf.cast(1., self.dtype),
                    tf.linalg.norm(mean_direction, axis=-1),
                    message='`mean_direction` must be unit-length'))
        if is_init != tensor_util.is_ref(self._concentration):
            assertions.append(
                assert_util.assert_non_negative(
                    concentration,
                    message='`concentration` must be non-negative'))
        return assertions
Example #7
    def _sample_control_dependencies(self, x):
        assertions = []
        if tensorshape_util.is_fully_defined(x.shape[-2:]):
            if not (tensorshape_util.dims(x.shape)[-2] ==
                    tensorshape_util.dims(x.shape)[-1] == self.dimension):
                raise ValueError(
                    'Input dimension mismatch: expected [..., {}, {}], got {}'.
                    format(self.dimension, self.dimension,
                           tensorshape_util.dims(x.shape)))
        elif self.validate_args:
            msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
                self.dimension, self.dimension, tf.shape(x))
            assertions.append(
                assert_util.assert_equal(tf.shape(x)[-2],
                                         self.dimension,
                                         message=msg))
            assertions.append(
                assert_util.assert_equal(tf.shape(x)[-1],
                                         self.dimension,
                                         message=msg))

        if self.validate_args and not self.input_output_cholesky:
            assertions.append(
                assert_util.assert_less_equal(
                    dtype_util.as_numpy_dtype(x.dtype)(-1),
                    x,
                    message='Correlations must be >= -1.',
                    summarize=30))
            assertions.append(
                assert_util.assert_less_equal(
                    x,
                    dtype_util.as_numpy_dtype(x.dtype)(1),
                    message='Correlations must be <= 1.',
                    summarize=30))
            assertions.append(
                assert_util.assert_near(
                    tf.linalg.diag_part(x),
                    dtype_util.as_numpy_dtype(x.dtype)(1),
                    message='Self-correlations must be = 1.',
                    summarize=30))
            assertions.append(
                assert_util.assert_near(
                    x,
                    tf.linalg.matrix_transpose(x),
                    message='Correlation matrices must be symmetric.',
                    summarize=30))
        return assertions
Example #8
 def _assert_valid_sample(self, x):
   if not self.validate_args:
     return x
   return distribution_util.with_dependencies([
       assert_util.assert_non_positive(x),
       assert_util.assert_near(
           tf.zeros([], dtype=self.dtype), tf.reduce_logsumexp(x, axis=[-1])),
   ], x)
Example #9
 def _is_valid_correlation_cholesky(self, x):
   if not self.validate_args:
     return []
   return [
       assert_util.assert_near(
           x,
           tf.linalg.band_part(x, -1, 0),
           message='Cholesky factors must be lower triangular.')
   ]
Example #10
 def _maybe_assert_valid_sample(self, x, dtype):
     if not self.validate_args:
         return x
     one = tf.ones([], dtype=dtype)
     return distribution_util.with_dependencies([
         assert_util.assert_non_negative(x),
         assert_util.assert_less_equal(x, one),
         assert_util.assert_near(one, tf.reduce_sum(x, axis=[-1])),
     ], x)
Example #11
 def _maybe_assert_valid_sample(self, x):
     """Checks the validity of a sample."""
     if not self.validate_args:
         return x
     return distribution_util.with_dependencies([
         assert_util.assert_positive(x, message="samples must be positive"),
         assert_util.assert_near(
             tf.ones([], dtype=self.dtype),
             tf.reduce_sum(input_tensor=x, axis=-1),
             message="sample last-dimension must sum to `1`"),
     ], x)
Example #12
 def _is_valid_correlation_matrix(self, x):
     if not self.validate_args or self.input_output_cholesky:
         return []
     return [
         assert_util.assert_less_equal(
             dtype_util.as_numpy_dtype(x.dtype)(-1),
             x,
             message='Correlations must be >= -1.'),
         assert_util.assert_less_equal(
             x,
             dtype_util.as_numpy_dtype(x.dtype)(1),
             message='Correlations must be <= 1.'),
         assert_util.assert_near(tf.linalg.diag_part(x),
                                 dtype_util.as_numpy_dtype(x.dtype)(1),
                                 message='Self-correlations must be = 1.'),
         assert_util.assert_near(
             x,
             tf.linalg.matrix_transpose(x),
             message='Correlation matrices must be symmetric')
     ]
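For reference, a matrix that passes all four checks above looks like the following (values are illustrative): entries in [-1, 1], unit diagonal, and symmetric.

import tensorflow as tf

corr = tf.constant([[1.0, -0.3, 0.2],
                    [-0.3, 1.0, 0.5],
                    [0.2, 0.5, 1.0]])
tf.debugging.assert_near(corr, tf.linalg.matrix_transpose(corr))   # symmetric
tf.debugging.assert_near(tf.linalg.diag_part(corr), tf.ones([3]))  # unit diagonal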
Example #13
 def _maybe_assert_valid_sample(self, x):
     """Checks the validity of a sample."""
     if not self.validate_args:
         return []
     return [
         assert_util.assert_positive(x, message='samples must be positive'),
         assert_util.assert_near(
             tf.ones([], dtype=self.dtype),
             tf.reduce_sum(x, axis=-1),
             message='sample last-dimension must sum to `1`'),
     ]
Example #14
 def _sample_control_dependencies(self, x):
   """Checks the validity of a sample."""
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.append(assert_util.assert_non_negative(
       x, message='Samples must be non-negative.'))
   assertions.append(assert_util.assert_near(
       tf.ones([], dtype=self.dtype),
       tf.reduce_sum(x, axis=-1),
       message='Sample last-dimension must sum to `1`.'))
   return assertions
Example #15
 def _assert_ops(
     self,
     ode_fn,
     initial_time,
     initial_state,
     solution_times,
     previous_solver_state,
     rtol,
     atol,
     first_step_size,
     safety_factor,
     min_step_size_factor,
     max_step_size_factor,
     max_num_steps,
     solution_times_by_solver
 ):
   """Constructs dynamic assertions that validate input values to `_solve`."""
   assert_ops = []
    if not self._validate_args:
     return assert_ops
   if solution_times_by_solver:
     final_time = solution_times.final_time
     assert_ops.append(
         util.assert_positive(final_time - initial_time,
                              'final_time - initial_time'))
   else:
     assert_ops += [
         util.assert_increasing(solution_times, 'solution_times'),
         util.assert_nonnegative(solution_times[0] - initial_time,
                                 'solution_times[0] - initial_time'),
     ]
   if previous_solver_state is not None:
     state_diff = initial_state - previous_solver_state.current_state
     assert_states_match = assert_util.assert_near(
         tf.norm(state_diff), 0., message='`previous_solver_state` does not '
         'match the `initial_state`.')
     assert_ops.append(assert_states_match)
   if self._max_num_steps is not None:
     assert_ops.append(util.assert_positive(max_num_steps, 'max_num_steps'))
   assert_ops += [
       util.assert_positive(rtol, 'rtol'),
       util.assert_positive(atol, 'atol'),
       util.assert_positive(first_step_size, 'first_step_size'),
       util.assert_positive(safety_factor, 'safety_factor'),
       util.assert_positive(
           min_step_size_factor, 'min_step_size_factor'),
       util.assert_positive(
           max_step_size_factor, 'max_step_size_factor'),
   ]
   derivative = ode_fn(initial_time, initial_state)
   tf.nest.assert_same_structure(initial_state, derivative)
   return assert_ops
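These assertions fire when a solver is constructed with validation enabled. A hedged usage sketch (solver choice and values are illustrative, assuming the standard tfp.math.ode.BDF interface):

import tensorflow as tf
import tensorflow_probability as tfp

# Solve dy/dt = -y from y(0) = 1 with validation enabled.
results = tfp.math.ode.BDF(validate_args=True).solve(
    ode_fn=lambda t, y: -y,
    initial_time=0.0,
    initial_state=tf.constant(1.0),
    solution_times=[0.5, 1.0])
print(results.states.numpy())  # ~ exp(-t)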
Example #16
 def _validate_correlationness(self, x):
     if not self.validate_args or self.input_output_cholesky:
         return x
     checks = [
         assert_util.assert_less_equal(
             dtype_util.as_numpy_dtype(x.dtype)(-1),
             x,
             message='Correlations must be >= -1.'),
         assert_util.assert_less_equal(
             x,
             dtype_util.as_numpy_dtype(x.dtype)(1),
             message='Correlations must be <= 1.'),
         assert_util.assert_near(tf.linalg.diag_part(x),
                                 dtype_util.as_numpy_dtype(x.dtype)(1),
                                 message='Self-correlations must be = 1.'),
         assert_util.assert_near(
             x,
             tf.linalg.matrix_transpose(x),
             message='Correlation matrices must be symmetric')
     ]
     with tf.control_dependencies(checks):
         return tf.identity(x)
Example #17
def maybe_assert_categorical_param_correctness(is_init, validate_args, probs,
                                               logits):
    """Return assertions for `Categorical`-type distributions."""
    assertions = []

    # In init, we can always build shape and dtype checks because
    # we assume shape doesn't change for Variable backed args.
    if is_init:
        x, name = (probs, 'probs') if logits is None else (logits, 'logits')

        if not dtype_util.is_floating(x.dtype):
            raise TypeError(
                'Argument `{}` must have floating type.'.format(name))

        msg = 'Argument `{}` must have rank at least 1.'.format(name)
        ndims = tensorshape_util.rank(x.shape)
        if ndims is not None:
            if ndims < 1:
                raise ValueError(msg)
        elif validate_args:
            x = tf.convert_to_tensor(x)
            probs = x if logits is None else None  # Retain tensor conversion.
            logits = x if probs is None else None
            assertions.append(
                assert_util.assert_rank_at_least(x, 1, message=msg))

    if not validate_args:
        assert not assertions  # Should never happen.
        return []

    if logits is not None:
        if is_init != tensor_util.is_mutable(logits):
            logits = tf.convert_to_tensor(logits)
            assertions.extend(
                distribution_util.assert_categorical_event_shape(logits))

    if probs is not None:
        if is_init != tensor_util.is_mutable(probs):
            probs = tf.convert_to_tensor(probs)
            assertions.extend([
                assert_util.assert_non_negative(probs),
                assert_util.assert_near(
                    tf.reduce_sum(probs, axis=-1),
                    np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype)),
                    message='Argument `probs` must sum to 1.')
            ])
            assertions.extend(
                distribution_util.assert_categorical_event_shape(probs))

    return assertions
Example #18
 def _sample_control_dependencies(self, x):
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.append(assert_util.assert_non_positive(
       x,
       message=('Samples must be less than or equal to `0` for '
                '`ExpRelaxedOneHotCategorical` or `1` for '
                '`RelaxedOneHotCategorical`.')))
   assertions.append(assert_util.assert_near(
       tf.zeros([], dtype=self.dtype), tf.reduce_logsumexp(x, axis=[-1]),
        message=('Final dimension of samples must sum to `0` for '
                '`ExpRelaxedOneHotCategorical` or `1` '
                'for `RelaxedOneHotCategorical`.')))
   return assertions
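The logsumexp check encodes the fact that a valid ExpRelaxedOneHotCategorical sample is a vector of log-probabilities. A quick sketch with illustrative values:

import tensorflow as tf

probs = tf.constant([0.2, 0.3, 0.5])
x = tf.math.log(probs)  # a valid sample: exp(x) sums to 1
print(tf.reduce_logsumexp(x, axis=-1).numpy())  # ~0.0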
Example #19
 def _maybe_assert_valid_sample(self, samples):
   """Check counts for proper shape, values, then return tensor version."""
   if not self.validate_args:
     return samples
   with tf.control_dependencies([
       assert_util.assert_near(
           1.,
           tf.linalg.norm(samples, axis=-1),
           message='samples must be unit length'),
       assert_util.assert_equal(
           tf.shape(samples)[-1:],
           self.event_shape_tensor(),
           message=('samples must have innermost dimension matching that of '
                    '`self.mean_direction`')),
   ]):
     return tf.identity(samples)
Example #20
    def _sample_control_dependencies(self, samples):
        inner_sample_dim = samples.shape[-1]
        shape_msg = ('Samples must have innermost dimension matching that of '
                     '`self.dimension`. Found {}, expected {}'.format(
                         inner_sample_dim, self.dimension))
        if inner_sample_dim is not None:
            if self.dimension != inner_sample_dim:
                raise ValueError(shape_msg)

        assertions = []
        if not self.validate_args:
            return assertions
        assertions.append(
            assert_util.assert_near(tf.cast(1., dtype=self.dtype),
                                    tf.linalg.norm(samples, axis=-1),
                                    message='Samples must be unit length.'))
        assertions.append(
            assert_util.assert_equal(tf.shape(samples)[-1:],
                                     self.dimension,
                                     message=shape_msg))
        return assertions
Example #21
  def _sample_control_dependencies(self, samples):
    """Check samples for proper shape and whether samples are unit vectors."""
    inner_sample_dim = samples.shape[-1]
    event_size = self.event_shape[-1]
    shape_msg = ('Samples must have innermost dimension matching that of '
                 '`self.mean_direction`.')
    if event_size is not None and inner_sample_dim is not None:
      if event_size != inner_sample_dim:
        raise ValueError(shape_msg)

    assertions = []
    if not self.validate_args:
      return assertions
    assertions.append(assert_util.assert_near(
        1.,
        tf.linalg.norm(samples, axis=-1),
        message='Samples must be unit length.'))
    assertions.append(assert_util.assert_equal(
        tf.shape(samples)[-1:],
        self.event_shape_tensor(),
        message=shape_msg))
    return assertions
Example #22
  def _parameter_control_dependencies(self, is_init):
    """Validate parameters."""
    bw, bh, kd = None, None, None
    try:
      shape = tf.broadcast_static_shape(self.bin_widths.shape,
                                        self.bin_heights.shape)
    except ValueError as e:
      raise ValueError('`bin_widths`, `bin_heights` must broadcast: {}'.format(
          str(e)))
    bin_sizes_shape = shape
    try:
      shape = tf.broadcast_static_shape(shape[:-1], self.knot_slopes.shape[:-1])
    except ValueError as e:
      raise ValueError(
          '`bin_widths`, `bin_heights`, and `knot_slopes` must broadcast on '
          'batch axes: {}'.format(str(e)))

    assertions = []
    if (tensorshape_util.is_fully_defined(bin_sizes_shape[-1:]) and
        tensorshape_util.is_fully_defined(self.knot_slopes.shape[-1:])):
      if tensorshape_util.rank(self.knot_slopes.shape) > 0:
        num_interior_knots = tensorshape_util.dims(bin_sizes_shape)[-1] - 1
        if tensorshape_util.dims(
            self.knot_slopes.shape)[-1] not in (1, num_interior_knots):
          raise ValueError(
              'Innermost axis of non-scalar `knot_slopes` must broadcast with '
              '{}; got {}.'.format(num_interior_knots, self.knot_slopes.shape))
    elif self.validate_args:
      if is_init != any(
          tensor_util.is_ref(t)
          for t in (self.bin_widths, self.bin_heights, self.knot_slopes)):
        bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
        bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
        kd = _ensure_at_least_1d(self.knot_slopes) if kd is None else kd
        shape = tf.broadcast_dynamic_shape(
            tf.shape((bw + bh)[..., :-1]), tf.shape(kd))
        assertions.append(
            assert_util.assert_greater(
                tf.shape(shape)[0],
                tf.zeros([], dtype=shape.dtype),
                message='`(bin_widths + bin_heights)[..., :-1]` must broadcast '
                'with `knot_slopes` to at least 1-D.'))

    if not self.validate_args:
      assert not assertions
      return assertions

    if (is_init != tensor_util.is_ref(self.bin_widths) or
        is_init != tensor_util.is_ref(self.bin_heights)):
      bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
      bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
      assertions += [
          assert_util.assert_near(
              tf.reduce_sum(bw, axis=-1),
              tf.reduce_sum(bh, axis=-1),
              message='`sum(bin_widths, axis=-1)` must equal '
              '`sum(bin_heights, axis=-1)`.'),
      ]
    if is_init != tensor_util.is_ref(self.bin_widths):
      bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
      assertions += [
          assert_util.assert_positive(
              bw, message='`bin_widths` must be positive.'),
      ]
    if is_init != tensor_util.is_ref(self.bin_heights):
      bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
      assertions += [
          assert_util.assert_positive(
              bh, message='`bin_heights` must be positive.'),
      ]
    if is_init != tensor_util.is_ref(self.knot_slopes):
      kd = _ensure_at_least_1d(self.knot_slopes) if kd is None else kd
      assertions += [
          assert_util.assert_positive(
              kd, message='`knot_slopes` must be positive.'),
      ]
    return assertions
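Parameters satisfying the checks above are: positive bin_widths and bin_heights whose innermost sums match, and positive knot_slopes broadcasting against the number of interior knots. A hedged usage sketch (illustrative values, assuming the standard tfp.bijectors.RationalQuadraticSpline constructor):

import tensorflow as tf
import tensorflow_probability as tfp

bin_widths = tf.constant([0.5, 1.0, 0.5])
bin_heights = tf.constant([0.8, 0.4, 0.8])  # sums match: 2.0 == 2.0
knot_slopes = tf.constant([1.0, 1.0])       # one per interior knot
spline = tfp.bijectors.RationalQuadraticSpline(
    bin_widths=bin_widths, bin_heights=bin_heights, knot_slopes=knot_slopes,
    validate_args=True)
print(spline.forward(tf.constant([0.3])).numpy())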
Example #23
  def _sample_n(self, n, seed=None):
    dim0_seed, otherdims_seed = samplers.split_seed(seed,
                                                    salt='von_mises_fisher')
    # The sampling strategy relies on the fact that vMF variates are symmetric
    # about the mean direction. Accordingly, if we have a sampling strategy for
    # the away-from-mean angle, then we can uniformly sample the remaining
    # dimensions on the S^{dim-2} sphere, and rotate these samples from a
    # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
    #
    # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
    # von-Mises distributed `x` value in [-1, 1], then uniformly select what
    # amounts to an "up" or "down" additional degree of freedom after unit
    # normalizing, followed by a final rotation to the desired mean direction
    # from a basis of (1, 0).
    #
    # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
    # unit sphere over which the distribution is uniform, in particular the
    # circle where x = \hat{x} intersects the unit sphere. We pick a point on
    # that circle, then rotate to the desired mean direction from a basis of
    # (1, 0, 0).
    mean_direction = tf.convert_to_tensor(self.mean_direction)
    concentration = tf.convert_to_tensor(self.concentration)
    event_dim = (
        tf.compat.dimension_value(self.event_shape[0]) or
        self._event_shape_tensor(mean_direction=mean_direction)[0])

    sample_batch_shape = ps.concat([[n], self._batch_shape_tensor(
        mean_direction=mean_direction, concentration=concentration)], axis=0)
    dim = tf.cast(event_dim - 1, self.dtype)
    if event_dim == 3:
      samples_dim0 = self._sample_3d(n,
                                     mean_direction=mean_direction,
                                     concentration=concentration,
                                     seed=dim0_seed)
    else:
      # Wood'94 provides a rejection algorithm to sample the x coordinate.
      # Wood'94 definition of b:
      # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
      # https://stats.stackexchange.com/questions/156729 suggests:
      b = dim / (2 * concentration +
                 tf.sqrt(4 * concentration**2 + dim**2))
      # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
      #     https://github.com/nicola-decao/s-vae-tf/
      x = (1 - b) / (1 + b)
      c = concentration * x + dim * tf.math.log1p(-x**2)
      beta = beta_lib.Beta(dim / 2, dim / 2)

      def cond_fn(w, should_continue, seed):
        del w, seed
        return tf.reduce_any(should_continue)

      def body_fn(w, should_continue, seed):
        """While loop body for sampling the angle `w`."""
        beta_seed, unif_seed, next_seed = samplers.split_seed(seed, n=3)
        z = beta.sample(sample_shape=sample_batch_shape, seed=beta_seed)
        # set_shape needed here because of b/139013403
        tensorshape_util.set_shape(z, w.shape)
        w = tf.where(should_continue,
                     (1. - (1. + b) * z) / (1. - (1. - b) * z),
                     w)
        if not self.allow_nan_stats:
          w = tf.debugging.check_numerics(w, 'w')
        unif = samplers.uniform(
            sample_batch_shape, seed=unif_seed, dtype=self.dtype)
        # set_shape needed here because of b/139013403
        tensorshape_util.set_shape(unif, w.shape)
        should_continue = should_continue & (
            concentration * w + dim * tf.math.log1p(-x * w) - c <
            # Use log1p(-unif) to prevent log(0) and ensure that log(1) is
            # possible.
            tf.math.log1p(-unif))
        return w, should_continue, next_seed

      w = tf.zeros(sample_batch_shape, dtype=self.dtype)
      should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
      samples_dim0, _, _ = tf.while_loop(
          cond=cond_fn, body=body_fn,
          loop_vars=(w, should_continue, dim0_seed))
      samples_dim0 = samples_dim0[..., tf.newaxis]
    if not self._allow_nan_stats:
      # Verify samples are w/in -1, 1, with useful error output tensors (top
      # value rather than all values).
      with tf.control_dependencies([
          assert_util.assert_less_equal(
              samples_dim0,
              dtype_util.as_numpy_dtype(self.dtype)(1.01)),
          assert_util.assert_greater_equal(
              samples_dim0,
              dtype_util.as_numpy_dtype(self.dtype)(-1.01)),
      ]):
        samples_dim0 = tf.identity(samples_dim0)
    samples_otherdims_shape = ps.concat([sample_batch_shape, [event_dim - 1]],
                                        axis=0)
    unit_otherdims = tf.math.l2_normalize(
        samplers.normal(
            samples_otherdims_shape, seed=otherdims_seed, dtype=self.dtype),
        axis=-1)
    samples = tf.concat([
        samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
        tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
    ], axis=-1)
    samples = tf.math.l2_normalize(samples, axis=-1)
    if not self.allow_nan_stats:
      samples = tf.debugging.check_numerics(samples, 'samples')

    # Runtime assert that samples are unit length.
    if not self.allow_nan_stats:
      worst, _ = tf.math.top_k(
          tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
      with tf.control_dependencies([
          assert_util.assert_near(
              dtype_util.as_numpy_dtype(self.dtype)(0),
              worst,
              atol=1e-4,
              summarize=100)
      ]):
        samples = tf.identity(samples)
    # The samples generated are symmetric around a mode at (1, 0, 0, ..., 0).
    # Now, we move the mode to `self.mean_direction` using a rotation matrix.
    if not self.allow_nan_stats:
      # Assert that the basis vector rotates to the mean direction, as expected.
      basis = tf.cast(tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                      self.dtype)
      with tf.control_dependencies([
          assert_util.assert_less(
              tf.linalg.norm(
                  self._rotate(basis, mean_direction=mean_direction) -
                  mean_direction, axis=-1),
              dtype_util.as_numpy_dtype(self.dtype)(1e-5))
      ]):
        return self._rotate(samples, mean_direction=mean_direction)
    return self._rotate(samples, mean_direction=mean_direction)
Example #24
  def __init__(self,
               mean_direction,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name='VonMisesFisher'):
    """Creates a new `VonMisesFisher` instance.

    Args:
      mean_direction: Floating-point `Tensor` with shape [B1, ... Bn, D].
        A unit vector indicating the mode of the distribution, or the
        unit-normalized direction of the mean. (This is *not* in general the
        mean of the distribution; the mean is not generally in the support of
        the distribution.) NOTE: `D` is currently restricted to <= 5.
      concentration: Floating-point `Tensor` having batch shape [B1, ... Bn]
        broadcastable with `mean_direction`. The level of concentration of
        samples around the `mean_direction`. `concentration=0` indicates a
        uniform distribution over the unit hypersphere, and `concentration=+inf`
        indicates a `Deterministic` distribution (delta function) at
        `mean_direction`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: For known-bad arguments, i.e. unsupported event dimension.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([mean_direction, concentration],
                                      tf.float32)
      mean_direction = tf.convert_to_tensor(
          mean_direction, name='mean_direction', dtype=dtype)
      concentration = tf.convert_to_tensor(
          concentration, name='concentration', dtype=dtype)
      assertions = [
          assert_util.assert_non_negative(
              concentration, message='`concentration` must be non-negative'),
          assert_util.assert_greater(
              tf.shape(mean_direction)[-1],
              1,
              message='`mean_direction` may not have scalar event shape'),
          assert_util.assert_near(
              1.,
              tf.linalg.norm(mean_direction, axis=-1),
              message='`mean_direction` must be unit-length')
      ] if validate_args else []
      static_event_dim = tf.compat.dimension_value(
          tensorshape_util.with_rank_at_least(mean_direction.shape, 1)[-1])
      if static_event_dim is not None and static_event_dim > 5:
        raise ValueError('vMF ndims > 5 is not currently supported')
      elif validate_args:
        assertions += [
            assert_util.assert_less_equal(
                tf.shape(mean_direction)[-1],
                5,
                message='vMF ndims > 5 is not currently supported')
        ]
      with tf.control_dependencies(assertions):
        self._mean_direction = tf.identity(mean_direction)
        self._concentration = tf.identity(concentration)
      dtype_util.assert_same_float_dtype(
          [self._mean_direction, self._concentration])
      # mean_direction is always reparameterized.
      # concentration is only for event_dim==3, via an inversion sampler.
      reparameterization_type = (
          reparameterization.FULLY_REPARAMETERIZED
          if static_event_dim == 3 else
          reparameterization.NOT_REPARAMETERIZED)
      super(VonMisesFisher, self).__init__(
          dtype=self._concentration.dtype,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          reparameterization_type=reparameterization_type,
          parameters=parameters,
          name=name)
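A minimal usage sketch of the constructor above (illustrative values): a vMF distribution on S^2 with a unit-norm mean direction, exercising the validations just described.

import tensorflow as tf
import tensorflow_probability as tfp

vmf = tfp.distributions.VonMisesFisher(
    mean_direction=tf.constant([0.0, 0.0, 1.0]),
    concentration=tf.constant(2.0),
    validate_args=True)
samples = vmf.sample(5, seed=42)
print(tf.linalg.norm(samples, axis=-1).numpy())  # ~1.0: samples are unit length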
Example #25
    def __init__(self,
                 loc=None,
                 covariance_matrix=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='MultivariateNormalFullCovariance'):
        """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and
    `covariance_matrix` arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `covariance_matrix`. The last dimension of `loc` (if provided) must
    broadcast with this.

    A non-batch `covariance_matrix` matrix is a `k x k` symmetric positive
    definite matrix.  In other words it is (real) symmetric with all eigenvalues
    strictly positive.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      covariance_matrix: Floating-point, symmetric positive definite `Tensor` of
        same `dtype` as `loc`.  The strict upper triangle of `covariance_matrix`
        is ignored, so if `covariance_matrix` is not symmetric no error will be
        raised (unless `validate_args is True`).  `covariance_matrix` has shape
        `[B1, ..., Bb, k, k]` where `b >= 0` and `k` is the event size.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value '`NaN`' to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if neither `loc` nor `covariance_matrix` are specified.
    """
        parameters = dict(locals())

        # Convert the covariance_matrix up to a scale_tril and call MVNTriL.
        with tf.name_scope(name) as name:
            with tf.name_scope('init'):
                dtype = dtype_util.common_dtype([loc, covariance_matrix],
                                                tf.float32)
                loc = loc if loc is None else tf.convert_to_tensor(
                    loc, name='loc', dtype=dtype)
                if covariance_matrix is None:
                    scale_tril = None
                else:
                    covariance_matrix = tf.convert_to_tensor(
                        covariance_matrix,
                        name='covariance_matrix',
                        dtype=dtype)
                    if validate_args:
                        covariance_matrix = distribution_util.with_dependencies(
                            [
                                assert_util.assert_near(
                                    covariance_matrix,
                                    tf.linalg.matrix_transpose(
                                        covariance_matrix),
                                    message='Matrix was not symmetric')
                            ], covariance_matrix)
                    # No need to validate that covariance_matrix is non-singular.
                    # LinearOperatorLowerTriangular has an assert_non_singular method that
                    # is called by the Bijector.
                    # However, cholesky() ignores the upper triangular part, so we do need
                    # to separately assert symmetric.
                    scale_tril = tf.linalg.cholesky(covariance_matrix)
                super(MultivariateNormalFullCovariance,
                      self).__init__(loc=loc,
                                     scale_tril=scale_tril,
                                     validate_args=validate_args,
                                     allow_nan_stats=allow_nan_stats,
                                     name=name)
        self._parameters = parameters
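And a sketch of constructing the distribution above with a symmetric positive definite matrix (illustrative values; as the comments note, the class simply Cholesky-factors the covariance and delegates to the MVNTriL parameterization):

import tensorflow as tf
import tensorflow_probability as tfp

cov = tf.constant([[2.0, 0.3],
                   [0.3, 1.0]])  # symmetric positive definite
mvn = tfp.distributions.MultivariateNormalFullCovariance(
    loc=[1.0, -1.0], covariance_matrix=cov, validate_args=True)
print(mvn.log_prob([0.5, -0.5]).numpy())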