def __init__(self, params):
    """Build a Gaussian with a diagonal-plus-rank-one precision matrix.

    The last axis of `params` is the concatenation of three equal-length
    chunks ``[μ, log D, u]``, each of size ``dim``; any leading axes are
    carried through as batch axes. The precision (inverse covariance) is
    ``C_inv = D + u uᵀ`` with ``D = diag(exp(log D))``. The covariance ``C``
    is obtained in closed form with the Sherman–Morrison formula.

    Attributes set:
        dim: dimension of the Gaussian, ``params.shape[-1] // 3``.
        μ: mean as a column vector, shape ``(..., dim, 1)``.
        R: matrix satisfying ``R @ Rᵀ == C``, so ``μ + R ε`` has covariance
            ``C`` when ``ε`` is standard normal.
        C_inv: precision matrix ``D + u uᵀ``, shape ``(..., dim, dim)``.
        C: covariance matrix ``D⁻¹ - η D⁻¹ u uᵀ D⁻¹`` where
            ``η = 1 / (1 + uᵀ D⁻¹ u)``.
        logdetC: ``log det C = log η - Σ log D`` (``right_squeeze2``
            presumably drops the trailing 1×1 axes of ``log η`` — confirm).

    Args:
        params: tensor whose last dimension is statically known (it is
            passed through ``int``) and a multiple of 3.

    Raises:
        AssertionError: if the last dimension of `params` is not a
            multiple of 3.
    """
    super(Gaussian, self).__init__(params)

    # The last axis must hold the three concatenated parameter chunks.
    concat_dim = int(params.shape[-1])
    assert concat_dim % 3 == 0, (
        "last dimension of params must be a multiple of 3, got {}".format(concat_dim))

    # Split [μ | log D | u] along the last axis; `...` keeps any batch axes.
    self.dim = concat_dim // 3
    μ_flat = params[..., :self.dim]
    logD_flat = params[..., self.dim:2 * self.dim]
    u_flat = params[..., 2 * self.dim:]

    # D and the powers of it needed below are diagonal, so build them
    # straight from log D rather than inverting / square-rooting D.
    D = tf.matrix_diag(K.exp(logD_flat))
    D_inv = tf.matrix_diag(K.exp(- logD_flat))
    D_inv_sqrt = tf.matrix_diag(K.exp(- .5 * logD_flat))

    # Shared pieces of the Sherman–Morrison inverse of D + u uᵀ.
    u = K.expand_dims(u_flat, -1)        # (..., dim, 1)
    uT = tf.matrix_transpose(u)          # (..., 1, dim)
    D_inv_u = D_inv @ u                  # reused three times below
    uT_D_inv_u = uT @ D_inv_u            # (..., 1, 1), non-negative since D⁻¹ > 0
    η = 1.0 / (1.0 + uT_D_inv_u)

    self.μ = K.expand_dims(μ_flat, -1)
    # R = D^{-1/2} - k D⁻¹ u uᵀ D^{-1/2} with k = (1 - √η) / (uᵀ D⁻¹ u).
    # Because uᵀ D⁻¹ u = (1 - η) / η, k equals η / (1 + √η) exactly. The
    # unsimplified form evaluates 0/0 = NaN when u == 0 and loses precision
    # in 1 - √η when uᵀ D⁻¹ u is small; this form has neither problem.
    self.R = D_inv_sqrt - (η / (1.0 + K.sqrt(η))) * (D_inv_u @ uT @ D_inv_sqrt)
    self.C_inv = D + u @ uT
    # Sherman–Morrison: (D + u uᵀ)⁻¹ = D⁻¹ - η D⁻¹ u uᵀ D⁻¹.
    self.C = D_inv - η * (D_inv_u @ uT @ D_inv)
    # Matrix determinant lemma: det C = 1 / det(D + u uᵀ) = η / det D.
    self.logdetC = right_squeeze2(K.log(η)) - K.sum(logD_flat, axis=-1)
def logprobability(self, v):
    """Return the log-density of `v` under this Gaussian.

    Evaluates ``-½ (dim · log 2π + log det C + (v - μ)ᵀ C⁻¹ (v - μ))`` using
    the ``C_inv`` and ``logdetC`` attributes computed in ``__init__``.

    Args:
        v: point(s) to evaluate, with the Gaussian's dimension on the last
            axis. Batch axes are aligned with ``self.μ`` by
            ``broadcast_left`` (presumably by adding leading axes — confirm).

    Returns:
        Log-probabilities, one per point (the trailing 1×1 axes of the
        quadratic form are removed by ``right_squeeze2``).
    """
    # Turn `v` into a column vector and align its batch axes with μ.
    v = broadcast_left(K.expand_dims(v, -1), self.μ)
    centered = v - self.μ
    # Squared Mahalanobis distance; C_inv is available directly, so no
    # matrix inversion is needed here.
    sq_mahalanobis = right_squeeze2(
        tf.matrix_transpose(centered) @ self.C_inv @ centered)
    return - .5 * (self.dim * np.log(2 * np.pi) + self.logdetC + sq_mahalanobis)
def kl_to_normal(self):
    """Return KL(N(μ, C) ‖ N(0, I)) for this Gaussian.

    Uses the closed form ``½ (tr C - log det C + μᵀμ - dim)`` with the
    ``C``, ``logdetC`` and ``μ`` attributes computed in ``__init__``.
    ``tf.trace`` sums the diagonal of each inner-most matrix, so batch axes
    are preserved.
    """
    trace_C = tf.trace(self.C)
    μT_μ = right_squeeze2(tf.matrix_transpose(self.μ) @ self.μ)
    return .5 * (trace_C - self.logdetC + μT_μ - self.dim)