def _covariance(self):
  # Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/
  event_dimension_static = tf.compat.dimension_value(self.event_shape[0])
  # TODO(b/141142878): Enable this; numerically unstable.
  if event_dimension_static is not None and event_dimension_static > 2:
    raise NotImplementedError(
        'vMF covariance is numerically unstable for dim > 2')
  mean_direction = tf.convert_to_tensor(self.mean_direction)
  concentration = tf.convert_to_tensor(self.concentration)
  event_dimension = tf.cast(
      self._event_shape_tensor(mean_direction)[0], self.dtype)
  # Substitute a safe concentration where kappa == 0; that branch is handled
  # separately below.
  safe_conc = tf.where(concentration > 0, concentration,
                       tf.ones_like(concentration))[..., tf.newaxis]
  # h(kappa) = I_{d/2}(kappa) / I_{d/2 - 1}(kappa), the mean resultant length.
  h = tfp_math.bessel_iv_ratio(event_dimension / 2., safe_conc)
  intermediate = (
      tf.matmul(mean_direction[..., :, tf.newaxis],
                mean_direction[..., tf.newaxis, :]) *
      (1. - event_dimension * h / safe_conc - h**2)[..., tf.newaxis])
  cov = tf.linalg.set_diag(
      intermediate,
      tf.linalg.diag_part(intermediate) + (h / safe_conc))
  # At kappa == 0 the distribution is uniform on the sphere, whose covariance
  # is eye(d) / d.
  return tf.where(
      concentration[..., tf.newaxis, tf.newaxis] > 0., cov,
      tf.linalg.eye(
          event_dimension_static,
          batch_shape=self._batch_shape_tensor(
              mean_direction=mean_direction,
              concentration=concentration)) / event_dimension)
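
# Example usage (a minimal sketch, assuming the public
# `tfp.distributions.VonMisesFisher` wrapper around these methods; per the
# dim > 2 guard above, `covariance` currently only supports the 2-D case):
#
#   import tensorflow as tf
#   import tensorflow_probability as tfp
#
#   vmf = tfp.distributions.VonMisesFisher(
#       mean_direction=tf.math.l2_normalize([1., 1.], axis=-1),
#       concentration=2.)
#   vmf.covariance()  # A [2, 2] matrix; at concentration 0 it is eye(2) / 2.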
def _mean(self):
  # Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/
  concentration = tf.convert_to_tensor(self.concentration)
  mean_direction = tf.convert_to_tensor(self.mean_direction)
  event_dimension = tf.cast(
      self._event_shape_tensor(mean_direction)[0], self.dtype)
  # Substitute a safe concentration where kappa == 0; that branch is masked
  # out below.
  safe_conc = tf.where(concentration > 0, concentration,
                       tf.ones_like(concentration))
  # The mean is the mean direction scaled by the mean resultant length
  # I_{d/2}(kappa) / I_{d/2 - 1}(kappa).
  safe_mean = mean_direction * (tfp_math.bessel_iv_ratio(
      event_dimension / 2., safe_conc)[..., tf.newaxis])
  # At kappa == 0 the distribution is uniform on the sphere and the mean is 0.
  return tf.where(concentration[..., tf.newaxis] > 0.,
                  safe_mean, tf.zeros_like(safe_mean))
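
# Example usage (a minimal sketch via the public wrapper; the norm of the
# mean is the ratio I_{d/2}(kappa) / I_{d/2 - 1}(kappa), which grows from 0
# at kappa == 0 toward 1 as kappa -> inf; the numbers below are approximate):
#
#   import tensorflow as tf
#   import tensorflow_probability as tfp
#
#   vmf = tfp.distributions.VonMisesFisher(
#       mean_direction=[0., 0., 1.], concentration=[0., 2., 100.])
#   tf.norm(vmf.mean(), axis=-1)  # roughly [0., 0.54, 0.99]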
def _entropy(self):
  mean_direction = tf.convert_to_tensor(self.mean_direction)
  event_dimension = tf.cast(
      self._event_shape_tensor(mean_direction=mean_direction)[0],
      dtype=self.dtype)
  concentration = tf.convert_to_tensor(self.concentration)
  # Compared to [1] (see the KL(VonMisesFisher || SphericalUniform)
  # implementation for the exact reference), we add the log normalization
  # constant rather than subtract it. This is because in TFP we have the
  # convention that we normalize our distributions p(x) by a constant 1 / Z,
  # so taking the log gives us a negative sign.
  entropy = -concentration * tfp_math.bessel_iv_ratio(
      event_dimension / 2., concentration) + self._log_normalization(
          concentration=concentration)
  return tf.broadcast_to(
      entropy,
      self._batch_shape_tensor(mean_direction=mean_direction,
                               concentration=concentration))
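
# Example usage (a minimal sketch via the public wrapper; the analytic
# entropy should be close to a Monte Carlo estimate of -E[log p(X)]):
#
#   import tensorflow as tf
#   import tensorflow_probability as tfp
#
#   vmf = tfp.distributions.VonMisesFisher(
#       mean_direction=[0., 0., 1.], concentration=2.)
#   samples = vmf.sample(10000, seed=42)
#   mc_entropy = -tf.reduce_mean(vmf.log_prob(samples))
#   vmf.entropy()  # Scalar, close to `mc_entropy`.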