def maybe_assert_bernoulli_param_correctness(
    is_init, validate_args, probs, logits):
  """Return assertions for `Bernoulli`-type distributions."""
  if is_init:
    x, name = (probs, 'probs') if logits is None else (logits, 'logits')
    if not dtype_util.is_floating(x.dtype):
      raise TypeError(
          'Argument `{}` must have floating type.'.format(name))

  if not validate_args:
    return []

  assertions = []

  if probs is not None:
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.constant(1., probs.dtype)
      assertions += [
          assert_util.assert_non_negative(
              probs, message='probs has components less than 0.'),
          assert_util.assert_less_equal(
              probs, one, message='probs has components greater than 1.')
      ]

  return assertions
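A minimal usage sketch (hypothetical setup; assumes the TFP internals above are in scope): for a `tf.Variable`-backed `probs`, `is_init != tensor_util.is_ref(probs)` is true on ordinary calls, so the range assertions are rebuilt each time.

import tensorflow as tf

probs = tf.Variable([0.2, 0.7], dtype=tf.float32)  # reference-type parameter
assertions = maybe_assert_bernoulli_param_correctness(
    is_init=False, validate_args=True, probs=probs, logits=None)
with tf.control_dependencies(assertions):
  safe_probs = tf.identity(probs)  # evaluating this runs the checks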
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
    """Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
    batch_shape_static = tensorshape_util.constant_value_as_shape(new_shape)
    if tensorshape_util.is_fully_defined(batch_shape_static):
        return np.int32(batch_shape_static), batch_shape_static, []
    with tf.name_scope(name or 'calculate_reshape'):
        original_size = tf.reduce_prod(original_shape)
        implicit_dim = tf.equal(new_shape, -1)
        size_implicit_dim = (original_size //
                             tf.maximum(1, -tf.reduce_prod(new_shape)))
        expanded_new_shape = tf.where(  # Assumes exactly one `-1`.
            implicit_dim, size_implicit_dim, new_shape)
        validations = [] if not validate else [  # pylint: disable=g-long-ternary
            assert_util.assert_rank(
                original_shape, 1, message='Original shape must be a vector.'),
            assert_util.assert_rank(
                new_shape, 1, message='New shape must be a vector.'),
            assert_util.assert_less_equal(
                tf.math.count_nonzero(implicit_dim, dtype=tf.int32),
                1,
                message='At most one dimension can be unknown.'),
            assert_util.assert_positive(
                expanded_new_shape, message='Shape elements must be >=-1.'),
            assert_util.assert_equal(tf.reduce_prod(expanded_new_shape),
                                     original_size,
                                     message='Shape sizes do not match.'),
        ]
        return expanded_new_shape, batch_shape_static, validations
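A quick worked example (hypothetical shapes): with original shape [2, 6] and new shape [3, -1], the implicit dimension resolves to 12 / 3 = 4.

import tensorflow as tf

expanded, static_shape, checks = calculate_reshape(
    original_shape=tf.constant([2, 6]),
    new_shape=tf.constant([3, -1]),
    validate=True)
with tf.control_dependencies(checks):
    expanded = tf.identity(expanded)  # [3, 4]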
 def _maybe_assert_valid_sample(self, x, dtype):
     if not self.validate_args:
         return x
     one = tf.ones([], dtype=dtype)
     return distribution_util.with_dependencies([
         assert_util.assert_non_negative(x),
         assert_util.assert_less_equal(x, one),
         assert_util.assert_near(one, tf.reduce_sum(x, axis=[-1])),
     ], x)
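For reference, `distribution_util.with_dependencies` used above behaves like this minimal sketch built from public TF ops (a simplified stand-in, not the TFP implementation):

import tensorflow as tf

def with_dependencies(dependencies, output_tensor):
    # Force the assert ops in `dependencies` to run before
    # `output_tensor` is produced.
    with tf.control_dependencies(dependencies):
        return tf.identity(output_tensor)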
Example #4
 def _maybe_assert_valid_y(self, y):
   if not self.validate_args:
     return []
   is_nonnegative = assert_util.assert_non_negative(
       y, message='Inverse transformation input must be non-negative.')
   less_than_one = assert_util.assert_less_equal(
       y,
       tf.constant(1., y.dtype),
       message='Inverse transformation input must be less than or equal to 1.')
   return [is_nonnegative, less_than_one]
Example #5
 def _maybe_assert_valid_sample(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args:
         return counts
     counts = distribution_util.embed_check_nonnegative_integer_form(counts)
     msg = ('Sampled counts must be itemwise less than '
            'or equal to `total_count` parameter.')
     return distribution_util.with_dependencies([
         assert_util.assert_less_equal(
             counts, self.total_count, message=msg),
     ], counts)
Example #6
 def _maybe_assert_valid(self, x):
     if not self.validate_args:
         return x
     return distribution_util.with_dependencies([
         assert_util.assert_non_negative(
             x, message="sample must be non-negative"),
         assert_util.assert_less_equal(
             x,
             tf.ones([], self.concentration0.dtype),
             message="sample must be no larger than `1`."),
     ], x)
Example #7
 def _validate_correlationness(self, x):
     if not self.validate_args or self.input_output_cholesky:
         return x
     checks = [
         assert_util.assert_less_equal(
             dtype_util.as_numpy_dtype(x.dtype)(-1),
             x,
             message='Correlations must be >= -1.'),
         assert_util.assert_less_equal(
             x,
             dtype_util.as_numpy_dtype(x.dtype)(1),
             message='Correlations must be <= 1.'),
         assert_util.assert_near(tf.linalg.diag_part(x),
                                 dtype_util.as_numpy_dtype(x.dtype)(1),
                                 message='Self-correlations must be = 1.'),
         assert_util.assert_near(
             x,
             tf.linalg.matrix_transpose(x),
              message='Correlation matrices must be symmetric.')
     ]
     with tf.control_dependencies(checks):
         return tf.identity(x)
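The same three properties can be spot-checked eagerly; a minimal NumPy sketch, independent of the class above:

import numpy as np

def is_correlation_matrix(x, atol=1e-6):
    # Entries in [-1, 1], unit diagonal, symmetry: the same conditions
    # the graph-mode assertions above enforce.
    return (np.all(x >= -1. - atol) and np.all(x <= 1. + atol)
            and np.allclose(np.diagonal(x, axis1=-2, axis2=-1), 1., atol=atol)
            and np.allclose(x, np.swapaxes(x, -1, -2), atol=atol))

is_correlation_matrix(np.array([[1., 0.5], [0.5, 1.]]))  # True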
 def _assertions(self, t):
     if not self.validate_args:
         return []
     return [
         assert_util.assert_non_negative(
             t,
             message="Inverse transformation input must be non-negative."
         ),
         assert_util.assert_less_equal(
             t,
             dtype_util.as_numpy_dtype(t.dtype)(1.),
             message=
             "Inverse transformation input must be less than or equal "
             "to 1.")
     ]
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if self._probs is not None:
         if is_init != tensor_util.is_ref(self._probs):
             probs = tf.convert_to_tensor(self._probs)
             assertions.append(
                 assert_util.assert_positive(
                     probs, message='Argument `probs` must be positive.'))
             assertions.append(
                 assert_util.assert_less_equal(
                     probs,
                     dtype_util.as_numpy_dtype(self.dtype)(1.),
                     message=
                     'Argument `probs` must be less than or equal to 1.'))
     return assertions
Example #10
 def _make_runtime_assertions(self, distribution, reinterpreted_batch_ndims,
                              validate_args):
     assertions = []
     static_reinterpreted_batch_ndims = tf.get_static_value(
         reinterpreted_batch_ndims)
     batch_ndims = tensorshape_util.rank(distribution.batch_shape)
     if batch_ndims is not None and static_reinterpreted_batch_ndims is not None:
         if static_reinterpreted_batch_ndims > batch_ndims:
             raise ValueError("reinterpreted_batch_ndims({}) cannot exceed "
                              "distribution.batch_ndims({})".format(
                                  static_reinterpreted_batch_ndims,
                                  batch_ndims))
     elif validate_args:
         assertions.append(
             assert_util.assert_less_equal(
                 reinterpreted_batch_ndims,
                 prefer_static.rank_from_shape(
                     distribution.batch_shape_tensor,
                     distribution.batch_shape),
                 message=("reinterpreted_batch_ndims cannot exceed "
                          "distribution.batch_ndims")))
     return assertions
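A hypothetical usage through the public `Independent` distribution: when the batch rank is statically known, an oversized `reinterpreted_batch_ndims` fails eagerly rather than via a runtime assertion.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([3, 2]), scale=1.)         # batch rank 2
ind = tfd.Independent(base, reinterpreted_batch_ndims=1)  # OK
# tfd.Independent(base, reinterpreted_batch_ndims=3)      # raises ValueError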
def maybe_assert_negative_binomial_param_correctness(is_init, validate_args,
                                                     total_count, probs,
                                                     logits):
    """Return assertions for `NegativeBinomial`-type distributions."""
    if is_init:
        x, name = (probs, 'probs') if logits is None else (logits, 'logits')
        if not dtype_util.is_floating(x.dtype):
            raise TypeError(
                'Argument `{}` must have floating type.'.format(name))

    if not validate_args:
        return []

    assertions = []
    if is_init != tensor_util.is_ref(total_count):
        total_count = tf.convert_to_tensor(total_count)
        assertions.extend([
            assert_util.assert_non_negative(
                total_count,
                message='`total_count` has components less than 0.'),
            distribution_util.assert_integer_form(
                total_count,
                message='`total_count` has fractional components.')
        ])
    if probs is not None:
        if is_init != tensor_util.is_ref(probs):
            probs = tf.convert_to_tensor(probs)
            one = tf.constant(1., probs.dtype)
            assertions.extend([
                assert_util.assert_non_negative(
                    probs, message='`probs` has components less than 0.'),
                assert_util.assert_less_equal(
                    probs,
                    one,
                    message='`probs` has components greater than 1.')
            ])

    return assertions
    def _prob(self, x):
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        if self.validate_args:
            with tf.control_dependencies([
                    assert_util.assert_greater_equal(x, low),
                    assert_util.assert_less_equal(x, high)
            ]):
                x = tf.identity(x)

        interval_length = high - low
        # This is the pdf when low <= x <= high. The density looks like
        # a triangle, so we treat each line segment separately.
        result_inside_interval = tf.where(
            (x >= low) & (x <= peak),
            # Line segment from (low, 0) to (peak, 2 / (high - low)).
            2. * (x - low) / (interval_length * (peak - low)),
            # Line segment from (peak, 2 / (high - low)) to (high, 0).
            2. * (high - x) / (interval_length * (high - peak)))

        return tf.where((x < low) | (x > high), tf.zeros_like(x),
                        result_inside_interval)
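A quick numeric check of the two segments (hypothetical values): with low=0, high=2, peak=1 the density peaks at 2 / (high - low) = 1.

import numpy as np

low, high, peak = 0., 2., 1.
x = np.array([-0.5, 0.5, 1.0, 1.5, 2.5])
rising = 2. * (x - low) / ((high - low) * (peak - low))
falling = 2. * (high - x) / ((high - low) * (high - peak))
pdf = np.where((x < low) | (x > high), 0.,
               np.where(x <= peak, rising, falling))
# pdf -> [0. , 0.5, 1. , 0.5, 0. ]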
Example #13
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        logits = self._logits
        probs = self._probs
        param, name = (probs, 'probs') if logits is None else (logits,
                                                               'logits')

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:
            if not dtype_util.is_floating(param.dtype):
                raise TypeError(
                    'Argument `{}` must have floating type.'.format(name))

            msg = 'Argument `{}` must have rank at least 1.'.format(name)
            shape_static = tensorshape_util.dims(param.shape)
            if shape_static is not None:
                if len(shape_static) < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_rank_at_least(param, 1, message=msg))

            msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
            msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
                name, tf.int32.max)
            event_size = shape_static[-1] if shape_static is not None else None
            if event_size is not None:
                if event_size < 1:
                    raise ValueError(msg1)
                if event_size > tf.int32.max:
                    raise ValueError(msg2)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_greater_equal(tf.shape(param)[-1:],
                                                     1,
                                                     message=msg1))
                # NOTE: For now, we leave out a runtime assertion that
                # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
                # will fail before we get to this point.

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_ref(self.temperature):
            assertions.append(assert_util.assert_positive(self.temperature))

        if probs is not None:
            probs = param  # reuse tensor conversion from above
            if is_init != tensor_util.is_ref(probs):
                probs = tf.convert_to_tensor(probs)
                one = tf.ones([], dtype=probs.dtype)
                assertions.extend([
                    assert_util.assert_non_negative(probs),
                    assert_util.assert_less_equal(probs, one),
                    assert_util.assert_near(
                        tf.reduce_sum(probs, axis=-1),
                        one,
                        message='Argument `probs` must sum to 1.'),
                ])

        return assertions
    def _sample_n(self, n, seed=None):
        seed = SeedStream(seed, salt='von_mises_fisher')
        # The sampling strategy relies on the fact that vMF variates are symmetric
        # about the mean direction. Accordingly, if we have a sampling strategy for
        # the away-from-mean angle, then we can uniformly sample the remaining
        # dimensions on the S^{dim-2} sphere for free, and rotate these samples from a
        # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
        #
        # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
        # von-Mises distributed `x` value in [-1, 1], then uniformly select what
        # amounts to an "up" or "down" additional degree of freedom after unit
        # normalizing, followed by a final rotation to the desired mean direction
        # from a basis of (1, 0).
        #
        # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
        # unit sphere over which the distribution is uniform, in particular the
        # circle where x = \hat{x} intersects the unit sphere. We pick a point on
        # that circle, then rotate to the desired mean direction from a basis of
        # (1, 0, 0).
        event_dim = (tf.compat.dimension_value(self.event_shape[0])
                     or self._event_shape_tensor()[0])

        sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()],
                                       axis=0)
        dim = tf.cast(event_dim - 1, self.dtype)
        if event_dim == 3:
            samples_dim0 = self._sample_3d(n, seed=seed)
        else:
            # Wood'94 provides a rejection algorithm to sample the x coordinate.
            # Wood'94 definition of b:
            # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
            # https://stats.stackexchange.com/questions/156729 suggests:
            b = dim / (2 * self.concentration +
                       tf.sqrt(4 * self.concentration**2 + dim**2))
            # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
            #     https://github.com/nicola-decao/s-vae-tf/
            x = (1 - b) / (1 + b)
            c = self.concentration * x + dim * tf.math.log1p(-x**2)
            beta = beta_lib.Beta(dim / 2, dim / 2)

            def cond_fn(w, should_continue):
                del w
                return tf.reduce_any(should_continue)

            def body_fn(w, should_continue):
                z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
                # set_shape needed here because of b/139013403
                z.set_shape(w.shape)
                w = tf.where(should_continue,
                             (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
                w = tf.debugging.check_numerics(w, 'w')
                unif = tf.random.uniform(sample_batch_shape,
                                         seed=seed(),
                                         dtype=self.dtype)
                # set_shape needed here because of b/139013403
                unif.set_shape(w.shape)
                should_continue = tf.logical_and(
                    should_continue,
                    self.concentration * w + dim * tf.math.log1p(-x * w) - c <
                    tf.math.log(unif))
                return w, should_continue

            w = tf.zeros(sample_batch_shape, dtype=self.dtype)
            should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
            samples_dim0 = tf.while_loop(cond=cond_fn,
                                         body=body_fn,
                                         loop_vars=(w, should_continue))[0]
            samples_dim0 = samples_dim0[..., tf.newaxis]
        if not self._allow_nan_stats:
            # Verify samples are within [-1, 1], with useful error output
            # tensors (top value rather than all values).
            with tf.control_dependencies([
                    assert_util.assert_less_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(1.01),
                        data=[
                            tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]
                        ]),
                    assert_util.assert_greater_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(-1.01),
                        data=[
                            -tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]
                        ])
            ]):
                samples_dim0 = tf.identity(samples_dim0)
        samples_otherdims_shape = tf.concat(
            [sample_batch_shape, [event_dim - 1]], axis=0)
        unit_otherdims = tf.math.l2_normalize(tf.random.normal(
            samples_otherdims_shape, seed=seed(), dtype=self.dtype),
                                              axis=-1)
        samples = tf.concat(
            [
                samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
                tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
            ],
            axis=-1)
        samples = tf.math.l2_normalize(samples, axis=-1)
        if not self._allow_nan_stats:
            samples = tf.debugging.check_numerics(samples, 'samples')

        # Runtime assert that samples are unit length.
        if not self._allow_nan_stats:
            worst, idx = tf.math.top_k(
                tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
            with tf.control_dependencies([
                    assert_util.assert_near(
                        dtype_util.as_numpy_dtype(self.dtype)(0),
                        worst,
                        data=[
                            worst, idx,
                            tf.gather(tf.reshape(samples, [-1, event_dim]),
                                      idx)
                        ],
                        atol=1e-4,
                        summarize=100)
            ]):
                samples = tf.identity(samples)
        # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
        # Now, we move the mode to `self.mean_direction` using a rotation matrix.
        if not self._allow_nan_stats:
            # Assert that the basis vector rotates to the mean direction, as expected.
            basis = tf.cast(
                tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                self.dtype)
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.linalg.norm(self._rotate(basis) -
                                       self.mean_direction,
                                       axis=-1),
                        dtype_util.as_numpy_dtype(self.dtype)(1e-5))
            ]):
                return self._rotate(samples)
        return self._rotate(samples)
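The rejection step follows Wood (1994); a minimal NumPy sketch of one proposal round for scalar `kappa` and `dim` (hypothetical helper, mirroring `body_fn` above):

import numpy as np

def vmf_proposal_round(kappa, dim, rng):
    b = dim / (2. * kappa + np.sqrt(4. * kappa**2 + dim**2))
    x = (1. - b) / (1. + b)
    c = kappa * x + dim * np.log1p(-x**2)
    z = rng.beta(dim / 2., dim / 2.)
    w = (1. - (1. + b) * z) / (1. - (1. - b) * z)  # candidate cosine
    accept = kappa * w + dim * np.log1p(-x * w) - c >= np.log(rng.uniform())
    return w, accept

w, ok = vmf_proposal_round(kappa=2., dim=4., rng=np.random.default_rng(0))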
    def __init__(self,
                 mean_direction,
                 concentration,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='VonMisesFisher'):
        """Creates a new `VonMisesFisher` instance.

    Args:
      mean_direction: Floating-point `Tensor` with shape [B1, ... Bn, D].
        A unit vector indicating the mode of the distribution, or the
        unit-normalized direction of the mean. (This is *not* in general the
        mean of the distribution; the mean is not generally in the support of
        the distribution.) NOTE: `D` is currently restricted to <= 5.
      concentration: Floating-point `Tensor` having batch shape [B1, ... Bn]
        broadcastable with `mean_direction`. The level of concentration of
        samples around the `mean_direction`. `concentration=0` indicates a
        uniform distribution over the unit hypersphere, and `concentration=+inf`
        indicates a `Deterministic` distribution (delta function) at
        `mean_direction`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: For known-bad arguments, i.e. unsupported event dimension.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([mean_direction, concentration],
                                            tf.float32)
            mean_direction = tf.convert_to_tensor(mean_direction,
                                                  name='mean_direction',
                                                  dtype=dtype)
            concentration = tf.convert_to_tensor(concentration,
                                                 name='concentration',
                                                 dtype=dtype)
            assertions = [
                assert_util.assert_non_negative(
                    concentration,
                    message='`concentration` must be non-negative'),
                assert_util.assert_greater(
                    tf.shape(mean_direction)[-1],
                    1,
                    message='`mean_direction` may not have scalar event shape'
                ),
                assert_util.assert_near(
                    1.,
                    tf.linalg.norm(mean_direction, axis=-1),
                    message='`mean_direction` must be unit-length')
            ] if validate_args else []
            static_event_dim = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(mean_direction.shape,
                                                    1)[-1])
            if static_event_dim is not None and static_event_dim > 5:
                raise ValueError('vMF ndims > 5 is not currently supported')
            elif validate_args:
                assertions += [
                    assert_util.assert_less_equal(
                        tf.shape(mean_direction)[-1],
                        5,
                        message='vMF ndims > 5 is not currently supported')
                ]
            with tf.control_dependencies(assertions):
                self._mean_direction = tf.identity(mean_direction)
                self._concentration = tf.identity(concentration)
            dtype_util.assert_same_float_dtype(
                [self._mean_direction, self._concentration])
            # mean_direction is always reparameterized.
            # concentration is only for event_dim==3, via an inversion sampler.
            reparameterization_type = (reparameterization.FULLY_REPARAMETERIZED
                                       if static_event_dim == 3 else
                                       reparameterization.NOT_REPARAMETERIZED)
            super(VonMisesFisher, self).__init__(
                dtype=self._concentration.dtype,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                reparameterization_type=reparameterization_type,
                parameters=parameters,
                name=name)
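Hypothetical usage of the public distribution (with event dimension 3, sampling is fully reparameterized):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

vmf = tfd.VonMisesFisher(
    mean_direction=tf.math.l2_normalize([1., 1., 0.]),
    concentration=10.,
    validate_args=True)
samples = vmf.sample(5)  # shape [5, 3]; rows are (approximately) unit-norm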
Example #16
    def __init__(self,
                 df,
                 scale_operator,
                 input_output_cholesky=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name=None):
        """Construct Wishart distributions.

    Args:
      df: `float` or `double` tensor, the degrees of freedom of the
        distribution(s). `df` must be greater than or equal to `k`.
      scale_operator: `float` or `double` instance of `LinearOperator`.
      input_output_cholesky: Python `bool`. If `True`, functions whose input or
        output have the semantics of samples assume inputs are in Cholesky form
        and return outputs in Cholesky form. In particular, if this flag is
        `True`, input to `log_prob` is presumed of Cholesky form and output from
        `sample`, `mean`, and `mode` are of Cholesky form.  Setting this
        argument to `True` is purely a computational optimization and does not
        change the underlying distribution; for instance, `mean` returns the
        Cholesky of the mean, not the mean of Cholesky factors. The `variance`
        and `stddev` methods are unaffected by this flag.
        Default value: `False` (i.e., input/output does not have Cholesky
        semantics).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if scale is not floating-type
      TypeError: if scale.dtype != df.dtype
      ValueError: if df < k, where scale operator event shape is
        `(k, k)`
    """
        parameters = dict(locals())
        self._input_output_cholesky = input_output_cholesky
        with tf.name_scope(name) as name:
            with tf.name_scope("init"):
                if not dtype_util.is_floating(scale_operator.dtype):
                    raise TypeError(
                        "scale_operator.dtype=%s is not a floating-point type"
                        % scale_operator.dtype)
                if not scale_operator.is_square:
                    raise ValueError("scale_operator must be square.")

                self._scale_operator = scale_operator
                self._df = tf.convert_to_tensor(df,
                                                dtype=scale_operator.dtype,
                                                name="df")
                dtype_util.assert_same_float_dtype(
                    [self._df, self._scale_operator])
                if tf.compat.dimension_value(
                        self._scale_operator.shape[-1]) is None:
                    self._dimension = tf.cast(
                        self._scale_operator.domain_dimension_tensor(),
                        dtype=self._scale_operator.dtype,
                        name="dimension")
                else:
                    self._dimension = tf.convert_to_tensor(
                        tf.compat.dimension_value(
                            self._scale_operator.shape[-1]),
                        dtype=self._scale_operator.dtype,
                        name="dimension")
                df_val = tf.get_static_value(self._df)
                dim_val = tf.get_static_value(self._dimension)
                if df_val is not None and dim_val is not None:
                    df_val = np.asarray(df_val)
                    if not df_val.shape:
                        df_val = [df_val]
                    if np.any(df_val < dim_val):
                        raise ValueError(
                            "Degrees of freedom (df = %s) cannot be less than "
                            "dimension of scale matrix (scale.dimension = %s)"
                            % (df_val, dim_val))
                elif validate_args:
                    assertions = assert_util.assert_less_equal(
                        self._dimension,
                        self._df,
                        message=("Degrees of freedom (df = %s) cannot be "
                                 "less than dimension of scale matrix "
                                 "(scale.dimension = %s)" %
                                 (self._df, self._dimension)))
                    self._df = distribution_util.with_dependencies(
                        [assertions], self._df)
        super(_WishartLinearOperator, self).__init__(
            dtype=self._scale_operator.dtype,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            parameters=parameters,
            name=name)
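Hypothetical usage via a public Cholesky-parameterized wrapper built on this base class (assuming `tfd.WishartTriL` is available in this TFP version; here k = 2, so df must be >= 2):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

scale_tril = tf.linalg.cholesky([[2., 0.5], [0.5, 1.]])
wishart = tfd.WishartTriL(df=3., scale_tril=scale_tril, validate_args=True)
sample = wishart.sample()  # a 2 x 2 positive-definite matrix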