Example #1
    def _covariance(self):
        # Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/

        event_dimension_static = tf.compat.dimension_value(self.event_shape[0])

        mean_direction = tf.convert_to_tensor(self.mean_direction)
        concentration = tf.convert_to_tensor(self.concentration)
        event_dimension = tf.cast(
            self._event_shape_tensor(mean_direction)[0], self.dtype)
        # Guard against division by zero at concentration == 0; that case is
        # routed to the uniform-sphere covariance by the tf.where below.
        safe_conc = tf.where(concentration > 0, concentration,
                             tf.ones_like(concentration))[..., tf.newaxis]
        h = tfp_math.bessel_iv_ratio(event_dimension / 2, safe_conc)
        intermediate = (
            tf.matmul(mean_direction[..., :, tf.newaxis],
                      mean_direction[..., tf.newaxis, :]) *
            (1 - event_dimension * h / safe_conc - h**2)[..., tf.newaxis])
        cov = tf.linalg.set_diag(
            intermediate,
            tf.linalg.diag_part(intermediate) + (h / safe_conc))
        # At concentration == 0 the distribution is uniform on the sphere,
        # whose covariance is I / d.
        return tf.where(
            concentration[..., tf.newaxis, tf.newaxis] > 0., cov,
            tf.linalg.eye(event_dimension_static,
                          batch_shape=self._batch_shape_tensor(
                              mean_direction=mean_direction,
                              concentration=concentration)) / event_dimension)
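
The method implements Sigma = (h / kappa) * I + (1 - d * h / kappa - h^2) * mu mu^T, where h = I_{d/2}(kappa) / I_{d/2-1}(kappa), mu is mean_direction, kappa is concentration, and d is the event dimension. A minimal usage sketch, not part of the source; the tfd alias and the concrete parameter values are assumptions:

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    # 2-dimensional vMF, i.e. the von Mises distribution on the circle.
    vmf = tfd.VonMisesFisher(mean_direction=[1., 0.], concentration=2.)
    print(vmf.covariance())  # [2, 2] matrix; for this mu the variance
                             # orthogonal to mean_direction is the larger one.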
Example #2
    def _covariance(self):
        # Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/

        event_dimension_static = tf.compat.dimension_value(self.event_shape[0])

        # TODO(b/141142878): Enable this; numerically unstable.
        if event_dimension_static is not None and event_dimension_static > 2:
            raise NotImplementedError(
                'vMF covariance is numerically unstable for dim>2')

        mean_direction = tf.convert_to_tensor(self.mean_direction)
        concentration = tf.convert_to_tensor(self.concentration)
        event_dimension = tf.cast(
            self._event_shape_tensor(mean_direction)[0], self.dtype)
        # Guard against division by zero at concentration == 0; that case is
        # routed to the uniform-sphere covariance by the tf.where below.
        safe_conc = tf.where(concentration > 0, concentration,
                             tf.ones_like(concentration))[..., tf.newaxis]
        h = tfp_math.bessel_iv_ratio(event_dimension / 2, safe_conc)
        intermediate = (
            tf.matmul(mean_direction[..., :, tf.newaxis],
                      mean_direction[..., tf.newaxis, :]) *
            (1 - event_dimension * h / safe_conc - h**2)[..., tf.newaxis])
        cov = tf.linalg.set_diag(
            intermediate,
            tf.linalg.diag_part(intermediate) + (h / safe_conc))
        # At concentration == 0 the distribution is uniform on the sphere,
        # whose covariance is I / d.
        return tf.where(
            concentration[..., tf.newaxis, tf.newaxis] > 0., cov,
            tf.linalg.eye(event_dimension_static,
                          batch_shape=self._batch_shape_tensor(
                              mean_direction=mean_direction,
                              concentration=concentration)) / event_dimension)
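
This variant adds a guard in front of the same computation: when the event dimension is statically known and greater than 2, the covariance is refused outright because it is numerically unstable (tracked in b/141142878). A sketch of the observable behavior, assuming a version where the guard is active:

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    vmf3 = tfd.VonMisesFisher(mean_direction=[0., 0., 1.], concentration=1.)
    try:
        vmf3.covariance()
    except NotImplementedError as e:
        print(e)  # vMF covariance is numerically unstable for dim>2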
Example #3
    def _mean(self):
        # Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/
        concentration = tf.convert_to_tensor(self.concentration)
        mean_direction = tf.convert_to_tensor(self.mean_direction)

        event_dimension = tf.cast(
            self._event_shape_tensor(mean_direction)[0], self.dtype)
        # Guard against division by zero at concentration == 0; the uniform
        # case is handled by the tf.where below.
        safe_conc = tf.where(concentration > 0, concentration,
                             tf.ones_like(concentration))
        safe_mean = mean_direction * (tfp_math.bessel_iv_ratio(
            event_dimension / 2., safe_conc)[..., tf.newaxis])
        # The mean of the uniform distribution on the sphere is the origin.
        return tf.where(concentration[..., tf.newaxis] > 0., safe_mean,
                        tf.zeros_like(safe_mean))
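
The mean is mean_direction scaled by the Bessel ratio h(kappa) = I_{d/2}(kappa) / I_{d/2-1}(kappa), which lies in [0, 1): averaging points on the unit sphere lands strictly inside it, and the mean shrinks to the origin as the distribution flattens toward uniform. A minimal sketch with assumed aliases and values:

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    vmf = tfd.VonMisesFisher(mean_direction=[1., 0., 0.], concentration=5.)
    print(vmf.mean())  # ~[0.80, 0., 0.]: direction preserved, norm < 1.

    # concentration == 0 takes the uniform branch of the tf.where above.
    uniform = tfd.VonMisesFisher(mean_direction=[1., 0., 0.], concentration=0.)
    print(uniform.mean())  # [0., 0., 0.]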
Example #4
    def _entropy(self):
        mean_direction = tf.convert_to_tensor(self.mean_direction)
        event_dimension = tf.cast(
            self._event_shape_tensor(mean_direction=mean_direction)[0],
            dtype=self.dtype)
        concentration = tf.convert_to_tensor(self.concentration)
        # Compared to [1] (see the KL(VonMisesFisher || SphericalUniform)
        # derivation for the exact reference), we add the log normalization
        # constant rather than subtracting it. This is because TFP follows
        # the convention of normalizing a distribution p(x) by a constant
        # 1 / Z; taking the log introduces the negative sign.
        entropy = -concentration * tfp_math.bessel_iv_ratio(
            event_dimension / 2., concentration) + self._log_normalization(
                concentration=concentration)

        return tf.broadcast_to(
            entropy,
            self._batch_shape_tensor(mean_direction=mean_direction,
                                     concentration=concentration))
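
Written out, the snippet computes H = -kappa * h(kappa) + log Z(kappa), with h(kappa) = I_{d/2}(kappa) / I_{d/2-1}(kappa) and Z the vMF normalization constant. For d == 2, log Z(kappa) = log(2 * pi * I_0(kappa)), which allows a cross-check against SciPy; the sketch below is an illustration under those assumptions, not part of the source:

    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp
    from scipy import special

    tfd = tfp.distributions

    kappa = 1.
    vmf = tfd.VonMisesFisher(mean_direction=[1., 0.], concentration=kappa)

    # -kappa * h(kappa) + log Z(kappa), spelled out with SciPy Bessel
    # functions; should agree with vmf.entropy() up to float precision.
    h = special.iv(1., kappa) / special.iv(0., kappa)
    reference = -kappa * h + np.log(2. * np.pi * special.i0(kappa))
    print(vmf.entropy().numpy(), reference)  # both ~1.6275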