Example no. 1
    def __init__(self, params):
        """TODOC"""

        super(Gaussian, self).__init__(params)

        # Check the flattened parameters have the right shape
        concat_dim = int(params.shape[-1])
        assert concat_dim % 3 == 0
        # Extract the separate parameters
        self.dim = concat_dim // 3
        outer_slices = [slice(None)] * (len(params.shape) - 1)
        μ_flat = params[outer_slices + [slice(self.dim)]]
        logD_flat = params[outer_slices + [slice(self.dim, 2 * self.dim)]]
        u_flat = params[outer_slices + [slice(2 * self.dim, None)]]

        # Prepare the D matrix
        D = tf.matrix_diag(K.exp(logD_flat))
        D_inv = tf.matrix_diag(K.exp(- logD_flat))
        D_inv_sqrt = tf.matrix_diag(K.exp(- .5 * logD_flat))

        # Some pre-computations
        u = K.expand_dims(u_flat, -1)
        uT = tf.matrix_transpose(u)
        uT_D_inv_u = uT @ D_inv @ u
        η = 1.0 / (1.0 + uT_D_inv_u)

        self.μ = K.expand_dims(μ_flat, -1)
        # Rank-one (Sherman–Morrison) quantities: C = (D + u uᵀ)⁻¹,
        # R is a square root of C (R Rᵀ = C), and logdetC = log det C
        self.R = D_inv_sqrt - (((1 - K.sqrt(η)) / uT_D_inv_u) * (D_inv @ u @ uT @ D_inv_sqrt))
        self.C_inv = D + u @ uT
        self.C = D_inv - (η * (D_inv @ u @ uT @ D_inv))
        self.logdetC = right_squeeze2(K.log(η)) - K.sum(logD_flat, axis=-1)
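
For reference, the constructor above uses the rank-one identities for a matrix of the form C⁻¹ = D + u uᵀ with D diagonal, namely the Sherman–Morrison formula and the matrix determinant lemma,

\[
C = (D + u u^\top)^{-1} = D^{-1} - \eta\, D^{-1} u u^\top D^{-1},
\qquad
\eta = \frac{1}{1 + u^\top D^{-1} u},
\qquad
\log\det C = \log\eta - \sum_i \log D_{ii},
\]

which match the expressions assigned to self.C and self.logdetC, while R satisfies R Rᵀ = C. A minimal NumPy sketch (independent of the class; the variable names are chosen for illustration) that checks these identities numerically:

import numpy as np

rng = np.random.default_rng(0)
dim = 4
logD = rng.normal(size=dim)
u = rng.normal(size=(dim, 1))

D = np.diag(np.exp(logD))
D_inv = np.diag(np.exp(-logD))
D_inv_sqrt = np.diag(np.exp(-0.5 * logD))

a = float(u.T @ D_inv @ u)   # uᵀ D⁻¹ u
eta = 1.0 / (1.0 + a)

C = D_inv - eta * (D_inv @ u @ u.T @ D_inv)
R = D_inv_sqrt - ((1 - np.sqrt(eta)) / a) * (D_inv @ u @ u.T @ D_inv_sqrt)

assert np.allclose(C, np.linalg.inv(D + u @ u.T))                     # Sherman–Morrison
assert np.allclose(R @ R.T, C)                                        # R Rᵀ recovers C
assert np.isclose(np.log(eta) - logD.sum(), np.linalg.slogdet(C)[1])  # determinant lemma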
Example no. 2
    def logprobability(self, v):
        """Log-density of the Gaussian evaluated at `v`."""
        # Turn `v` into a column vector
        v = K.expand_dims(v, -1)
        # Check shapes and broadcast
        v = broadcast_left(v, self.μ)
        return - .5 * (self.dim * np.log(2 * np.pi) + 2 * self.logdetS
                       + right_squeeze2(tf.matrix_transpose(v - self.μ)
                                        @ self.S2_inv
                                        @ (v - self.μ)))
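
Assuming self.S2_inv holds the inverse covariance Σ⁻¹ and self.logdetS holds half of log det Σ (which would account for the factor of 2), this is the standard multivariate normal log-density

\[
\log \mathcal{N}(v;\, \mu, \Sigma)
  = -\tfrac{1}{2}\left(d \log 2\pi + \log\det\Sigma + (v - \mu)^\top \Sigma^{-1} (v - \mu)\right),
\]

with d = self.dim.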
Example no. 3
    def kl_to_normal(self):
        """KL divergence from this Gaussian to a standard normal N(0, I)."""
        return .5 * (self.traceS2 - 2 * self.logdetS
                     + right_squeeze2(tf.matrix_transpose(self.μ) @ self.μ)
                     - self.dim)
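
Under the same naming assumption (self.traceS2 = tr Σ, 2 · self.logdetS = log det Σ), this is the closed-form KL divergence from N(μ, Σ) to the standard normal:

\[
\mathrm{KL}\bigl(\mathcal{N}(\mu, \Sigma)\,\|\,\mathcal{N}(0, I)\bigr)
  = \tfrac{1}{2}\left(\operatorname{tr}\Sigma + \mu^\top \mu - d - \log\det\Sigma\right).
\]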