Example #1
 def _batch_shape(self):
   scalar_shape = tf.TensorShape([])
   return tf.broadcast_static_shape(
       scalar_shape if self.amplitude is None else self.amplitude.shape,
       scalar_shape if self.length_scale is None else self.length_scale.shape)
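Every snippet on this page centers on `tf.broadcast_static_shape`, which applies NumPy-style broadcasting rules to two static `tf.TensorShape`s and raises a `ValueError` if they are incompatible. A minimal standalone sketch of its behaviour (not taken from any of the examples):

import tensorflow as tf

# A scalar shape broadcasts against any batch shape, as in Example #1.
print(tf.broadcast_static_shape(tf.TensorShape([]), tf.TensorShape([5, 3])))      # (5, 3)
# Ordinary NumPy-style broadcasting of two partially matching shapes.
print(tf.broadcast_static_shape(tf.TensorShape([3, 1]), tf.TensorShape([1, 4])))  # (3, 4)
# Incompatible shapes fail at graph-construction time:
# tf.broadcast_static_shape(tf.TensorShape([2]), tf.TensorShape([3]))  -> ValueError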
Example #2
 def _batch_shape(self):
     return tf.broadcast_static_shape(self.loc.shape,
                                      self.concentration.shape)
Example #3
 def _batch_shape(self):
   return tf.broadcast_static_shape(
       (self._probs if self._logits is None else self._logits).shape[:-1],
       self.total_count.shape)
Example #4
def _mvnormal_quasi(sample_shape,
                    mean,
                    random_type,
                    seed,
                    covariance_matrix=None,
                    scale_matrix=None,
                    validate_args=False,
                    dtype=None,
                    **kwargs):
    """Returns normal draws using low-discrepancy sequences."""
    if scale_matrix is None and covariance_matrix is None:
        scale_matrix = tf.linalg.eye(tf.shape(mean)[-1], dtype=mean.dtype)
    elif scale_matrix is None and covariance_matrix is not None:
        covariance_matrix = tf.convert_to_tensor(covariance_matrix,
                                                 dtype=dtype,
                                                 name='covariance_matrix')
        scale_matrix = tf.linalg.cholesky(covariance_matrix)
    else:
        scale_matrix = tf.convert_to_tensor(scale_matrix,
                                            dtype=dtype,
                                            name='scale_matrix')
    scale_shape = scale_matrix.shape
    dim = scale_shape[-1]
    if mean is None:
        mean = tf.zeros([dim], dtype=scale_matrix.dtype)
    # Batch shape of the output
    batch_shape = tf.broadcast_static_shape(mean.shape, scale_shape[:-1])
    # Reverse elements of the batch shape
    batch_shape_reverse = tf.TensorShape(reversed(batch_shape))
    # Transposed shape of the output
    output_shape_t = tf.concat([batch_shape_reverse, sample_shape], -1)
    # Number of quasi random samples
    num_samples = tf.reduce_prod(output_shape_t) // dim
    # Number of initial low discrepancy sequence numbers to skip
    if 'skip' in kwargs:
        skip = kwargs['skip']
    else:
        skip = 0
    if random_type == RandomType.SOBOL:
        # Shape [num_samples, dim] of the Sobol samples
        low_discrepancy_seq = sobol.sample(dim=dim,
                                           num_results=num_samples,
                                           skip=skip,
                                           dtype=mean.dtype)
    else:  # HALTON or HALTON_RANDOMIZED random_type
        if 'randomization_params' in kwargs:
            randomization_params = kwargs['randomization_params']
        else:
            randomization_params = None
        randomized = random_type == RandomType.HALTON_RANDOMIZED
        # Shape [num_samples, dim] of the Halton samples
        low_discrepancy_seq, _ = halton.sample(
            dim=dim,
            sequence_indices=tf.range(skip, skip + num_samples),
            randomized=randomized,
            randomization_params=randomization_params,
            seed=seed,
            validate_args=validate_args,
            dtype=mean.dtype)

    # Transpose to the shape [dim, num_samples]
    low_discrepancy_seq = tf.transpose(low_discrepancy_seq)
    size_sample = tf.size(sample_shape)
    size_batch = tf.size(batch_shape)
    # Permutation for `output_shape_t` to the output shape
    permutation = tf.concat([
        tf.range(size_batch, size_batch + size_sample),
        tf.range(size_batch - 1, -1, -1)
    ], -1)
    # Reshape the low-discrepancy samples to the correct output shape
    low_discrepancy_seq = tf.transpose(
        tf.reshape(low_discrepancy_seq, output_shape_t), permutation)
    # Apply the inverse Normal CDF to the low-discrepancy samples to obtain
    # the corresponding Normal samples
    samples = tf.math.erfinv((low_discrepancy_seq - 0.5) * 2) * _SQRT_2
    return mean + tf.linalg.matvec(scale_matrix, samples)
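The heart of the function above is its last two lines: the uniform low-discrepancy points are pushed through the inverse standard-normal CDF and then correlated with the scale (Cholesky) matrix. Below is a self-contained sketch of just that transform using only public TensorFlow ops; the dimension, sample count, and lower-triangular scale matrix are made up for illustration.

import math
import tensorflow as tf

dim, num_samples = 3, 1024
# Uniform Sobol points in [0, 1)^dim, shape [num_samples, dim].
sobol_uniform = tf.math.sobol_sample(dim=dim, num_results=num_samples, dtype=tf.float64)
# Inverse standard-normal CDF: Phi^{-1}(u) = sqrt(2) * erfinv(2u - 1).
std_normal = tf.math.erfinv((sobol_uniform - 0.5) * 2.0) * math.sqrt(2.0)
# Correlate the draws with a lower-triangular scale (Cholesky) factor.
scale = tf.constant([[1.0, 0.0, 0.0],
                     [0.5, 1.0, 0.0],
                     [0.2, 0.3, 1.0]], dtype=tf.float64)
samples = tf.linalg.matvec(scale, std_normal)  # shape [num_samples, dim]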
Example #5
    def _parameter_control_dependencies(self, is_init):
        """Validate parameters."""
        bw, bh, kd = None, None, None
        try:
            shape = tf.broadcast_static_shape(self.bin_widths.shape,
                                              self.bin_heights.shape)
        except ValueError as e:
            raise ValueError(
                '`bin_widths`, `bin_heights` must broadcast: {}'.format(
                    str(e)))
        bin_sizes_shape = shape
        try:
            shape = tf.broadcast_static_shape(shape[:-1],
                                              self.knot_slopes.shape[:-1])
        except ValueError as e:
            raise ValueError(
                '`bin_widths`, `bin_heights`, and `knot_slopes` must broadcast on '
                'batch axes: {}'.format(str(e)))

        assertions = []
        if (tensorshape_util.is_fully_defined(bin_sizes_shape[-1:])
                and tensorshape_util.is_fully_defined(
                    self.knot_slopes.shape[-1:])):
            if tensorshape_util.rank(self.knot_slopes.shape) > 0:
                num_interior_knots = tensorshape_util.dims(
                    bin_sizes_shape)[-1] - 1
                if tensorshape_util.dims(self.knot_slopes.shape)[-1] not in (
                        1, num_interior_knots):
                    raise ValueError(
                        'Innermost axis of non-scalar `knot_slopes` must broadcast with '
                        '{}; got {}.'.format(num_interior_knots,
                                             self.knot_slopes.shape))
        elif self.validate_args:
            if is_init != any(
                    tensor_util.is_ref(t)
                    for t in (self.bin_widths, self.bin_heights,
                              self.knot_slopes)):
                bw = tf.convert_to_tensor(
                    self.bin_widths) if bw is None else bw
                bh = tf.convert_to_tensor(
                    self.bin_heights) if bh is None else bh
                kd = _ensure_at_least_1d(
                    self.knot_slopes) if kd is None else kd
                shape = tf.broadcast_dynamic_shape(
                    tf.shape((bw + bh)[..., :-1]), tf.shape(kd))
                assertions.append(
                    assert_util.assert_greater(
                        tf.shape(shape)[0],
                        tf.zeros([], dtype=shape.dtype),
                        message=
                        '`(bin_widths + bin_heights)[..., :-1]` must broadcast '
                        'with `knot_slopes` to at least 1-D.'))

        if not self.validate_args:
            assert not assertions
            return assertions

        if (is_init != tensor_util.is_ref(self.bin_widths)
                or is_init != tensor_util.is_ref(self.bin_heights)):
            bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
            bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
            assertions += [
                assert_util.assert_near(
                    tf.reduce_sum(bw, axis=-1),
                    tf.reduce_sum(bh, axis=-1),
                    message='`sum(bin_widths, axis=-1)` must equal '
                    '`sum(bin_heights, axis=-1)`.'),
            ]
        if is_init != tensor_util.is_ref(self.bin_widths):
            bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
            assertions += [
                assert_util.assert_positive(
                    bw, message='`bin_widths` must be positive.'),
            ]
        if is_init != tensor_util.is_ref(self.bin_heights):
            bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
            assertions += [
                assert_util.assert_positive(
                    bh, message='`bin_heights` must be positive.'),
            ]
        if is_init != tensor_util.is_ref(self.knot_slopes):
            kd = _ensure_at_least_1d(self.knot_slopes) if kd is None else kd
            assertions += [
                assert_util.assert_positive(
                    kd, message='`knot_slopes` must be positive.'),
            ]
        return assertions
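The parameter names here (`bin_widths`, `bin_heights`, `knot_slopes`) match TensorFlow Probability's `RationalQuadraticSpline` bijector, which is presumably where this validation routine lives. A rough usage sketch with invented parameter values; constructing the bijector with `validate_args=True` is what attaches the positivity and `sum(bin_widths) == sum(bin_heights)` assertions built above:

import tensorflow_probability as tfp

spline = tfp.bijectors.RationalQuadraticSpline(
    bin_widths=[0.5, 1.0, 0.5],   # sums to 2.0
    bin_heights=[0.4, 1.2, 0.4],  # must sum to the same total as bin_widths
    knot_slopes=[1.0, 1.0],       # one positive slope per interior knot
    range_min=-1.0,
    validate_args=True)
y = spline.forward([0.0, 0.3])    # maps points of [-1, 1] through the spline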
Example #6
 def _batch_shape(self):
     return tf.broadcast_static_shape(self.loc.shape, self.scale.shape)
Example #7
 def _batch_shape(self):
     if self.to_shape is None:
         return tf.broadcast_static_shape(
             self.distribution.batch_shape,
             tf.TensorShape(tf.get_static_value(self.with_shape)))
     return tf.TensorShape(tf.get_static_value(self.to_shape))
Example #8
 def _batch_shape(self):
     return tf.broadcast_static_shape(self.loc.shape,
                                      self.cutpoints.shape[:-1])
Example #9
def _kl_brute_force(a, b, name=None):
    """Batched KL divergence `KL(a || b)` for multivariate Normals.

  With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b` and
  covariance `C_a`, `C_b` respectively,

  ```
  KL(a || b) = 0.5 * ( L - k + T + Q ),
  L := Log[Det(C_b)] - Log[Det(C_a)]
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
  ```

  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
  methods for solving systems with `C_b` may be available, a dense version of
  (the square root of) `C_a` is used, so performance is `O(B s k**2)` where `B`
  is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
  and `y`.

  Args:
    a: Instance of `MultivariateNormalLinearOperator`.
    b: Instance of `MultivariateNormalLinearOperator`.
    name: (optional) name to use for created ops. Default "kl_mvn".

  Returns:
    Batchwise `KL(a || b)`.
  """
    def squared_frobenius_norm(x):
        """Helper to make KL calculation slightly more readable."""
        # http://mathworld.wolfram.com/FrobeniusNorm.html
        # The gradient of KL[p,q] is not defined when p==q. The culprit is
        # tf.norm, i.e., we cannot use the commented out code.
        # return tf.square(tf.norm(x, ord="fro", axis=[-2, -1]))
        return tf.reduce_sum(tf.square(x), axis=[-2, -1])

    # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
    # supports something like:
    #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
    def is_diagonal(x):
        """Helper to identify if `LinearOperator` has only a diagonal component."""
        return (isinstance(x, tf.linalg.LinearOperatorIdentity)
                or isinstance(x, tf.linalg.LinearOperatorScaledIdentity)
                or isinstance(x, tf.linalg.LinearOperatorDiag))

    with tf.name_scope(name or 'kl_mvn'):
        # Calculation is based on:
        # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
        # and,
        # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
        # i.e.,
        #   If Ca = AA', Cb = BB', then
        #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
        #                  = tr[inv(B) A A' inv(B)']
        #                  = tr[(inv(B) A) (inv(B) A)']
        #                  = sum_{ij} (inv(B) A)_{ij}**2
        #                  = ||inv(B) A||_F**2
        # where ||.||_F is the Frobenius norm and the second equality follows from
        # the cyclic permutation property.
        if is_diagonal(a.scale) and is_diagonal(b.scale):
            # Using `stddev` because it handles expansion of Identity cases.
            b_inv_a = (a.stddev() / b.stddev())[..., tf.newaxis]
        else:
            b_inv_a = b.scale.solve(a.scale.to_dense())
        kl_div = (b.scale.log_abs_determinant() -
                  a.scale.log_abs_determinant() + 0.5 *
                  (-tf.cast(a.scale.domain_dimension_tensor(), a.dtype) +
                   squared_frobenius_norm(b_inv_a) + squared_frobenius_norm(
                       b.scale.solve((b.mean() - a.mean())[..., tf.newaxis]))))
        tensorshape_util.set_shape(
            kl_div, tf.broadcast_static_shape(a.batch_shape, b.batch_shape))
        return kl_div
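In TensorFlow Probability, a brute-force routine of this form is registered as the KL implementation for pairs of `MultivariateNormalLinearOperator` distributions, so the public `kl_divergence` entry point reaches it for such pairs. A short sketch with invented locations and scales:

import tensorflow_probability as tfp

tfd = tfp.distributions
a = tfd.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 1.])
b = tfd.MultivariateNormalDiag(loc=[1., -1.], scale_diag=[2., 0.5])
kl = tfd.kl_divergence(a, b)  # batchwise KL(a || b); a scalar Tensor here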
Example #10
 def _batch_shape(self):
     x = self._probs if self._logits is None else self._logits
     return tf.broadcast_static_shape(self.total_count.shape, x.shape)
Example #11
 def _batch_shape(self):
     if self._is_fixed_inputs_empty():
         return self._base_kernel.batch_shape
     return tf.broadcast_static_shape(
         self._base_kernel.batch_shape,
         self._fixed_inputs.shape[:-(self._base_kernel.feature_ndims + 1)])
Example #12
 def _batch_shape(self):
     return tf.broadcast_static_shape(
         self.distribution.batch_shape,
         self.mixture_distribution.logits.shape)[:-1]
Example #13
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        if is_init:
            axis_ = tf.get_static_value(self._axis)
            if axis_ is not None and axis_ < 0:
                raise ValueError('Axis should be non-negative, %d was given' %
                                 axis_)
            if axis_ is None:
                assertions.append(tf.assert_greater_equal(self._axis, 0))

            all_event_shapes = [d.event_shape for d in self._distributions]
            if all(
                    tensorshape_util.is_fully_defined(event_shape)
                    for event_shape in all_event_shapes):
                if all_event_shapes[1:] != all_event_shapes[:-1]:
                    raise ValueError(
                        'Distributions must have the same `event_shape`; '
                        'found: {}'.format(all_event_shapes))

            all_batch_shapes = [d.batch_shape for d in self._distributions]
            if all(
                    tensorshape_util.is_fully_defined(batch_shape)
                    for batch_shape in all_batch_shapes):
                batch_shape = all_batch_shapes[0].as_list()
                batch_shape[self._axis] = 1
                for b in all_batch_shapes[1:]:
                    b = b.as_list()
                    if len(batch_shape) != len(b):
                        raise ValueError(
                            'Incompatible batch shape %s with %s' %
                            (batch_shape, b))
                    b[self._axis] = 1
                    tf.broadcast_static_shape(
                        tensorshape_util.constant_value_as_shape(batch_shape),
                        tensorshape_util.constant_value_as_shape(b))

        if not self.validate_args:
            return []

        if self.validate_args:
            # Validate that event shapes all match.
            all_event_shapes = [d.event_shape for d in self._distributions]
            if not all(
                    tensorshape_util.is_fully_defined(event_shape)
                    for event_shape in all_event_shapes):
                all_event_shape_tensors = [
                    d.event_shape_tensor() for d in self._distributions
                ]

                def _get_shapes(static_shape, dynamic_shape):
                    if tensorshape_util.is_fully_defined(static_shape):
                        return static_shape
                    else:
                        return dynamic_shape

                event_shapes = tf.nest.map_structure(_get_shapes,
                                                     all_event_shapes,
                                                     all_event_shape_tensors)
                event_shapes = tf.nest.flatten(event_shapes)
                assertions.extend(
                    assert_util.assert_equal(
                        e1,
                        e2,
                        message='Distributions should have same event shapes.')
                    for e1, e2 in zip(event_shapes[1:], event_shapes[:-1]))

            # Validate that batch shapes are broadcastable and concatenable along
            # the specified axis.
            if not all(
                    tensorshape_util.is_fully_defined(d.batch_shape)
                    for d in self._distributions):
                for i, d in enumerate(self._distributions[:-1]):
                    assertions.append(
                        tf.assert_equal(
                            tf.size(d.batch_shape_tensor()),
                            tf.size(
                                self._distributions[i +
                                                    1].batch_shape_tensor())))

                batch_shape_tensors = [
                    ps.tensor_scatter_nd_update(d.batch_shape_tensor(),
                                                updates=1,
                                                indices=[self._axis])
                    for d in self._distributions
                ]
                assertions.append(
                    functools.reduce(tf.broadcast_dynamic_shape,
                                     batch_shape_tensors[1:],
                                     batch_shape_tensors[0]))
        return assertions