def compare_KLs(sess, feed_dict, mu, Q_chol, P_chols):
    """Assert that ``KL`` agrees with ``gauss_kl`` for several prior forms.

    The reference value comes from ``gauss_kl``; its inputs are obtained by
    transposing a rank-2 ``mu`` (or appending an axis to a rank-1 ``mu``) and
    by prepending an axis to ``Q_chol`` when it is not already rank 3.

    Args:
        sess: session used to evaluate every tensor.
        feed_dict: placeholder values passed to each ``sess.run`` call.
        mu: mean-difference tensor handed to ``KL`` as ``mu_diff``.
        Q_chol: tensor handed to ``KL`` as ``Q_chol``.
        P_chols: iterable of prior factors. A rank-1 entry is treated as the
            diagonal of the factor; anything else as a full factor.

    Raises:
        AssertionError: via ``assert_allclose`` when two values disagree.
    """
    def evaluate(tensor):
        return sess.run(tensor, feed_dict=feed_dict)

    if mu.shape.ndims == 2:
        mu_gpflow = tf.transpose(mu)
    else:
        mu_gpflow = mu[:, None]
    if Q_chol.shape.ndims == 3:
        Q_chol_gpflow = Q_chol
    else:
        Q_chol_gpflow = Q_chol[None, ...]

    # No prior supplied on either side.
    reference = evaluate(
        gauss_kl(q_mu=mu_gpflow, q_sqrt=Q_chol_gpflow, K=None))
    candidate = evaluate(KL(mu_diff=mu, Q_chol=Q_chol, P_chol=None, P=None))
    assert_allclose(reference, candidate)

    for P_chol in P_chols:
        is_diagonal = P_chol.shape.ndims == 1
        if is_diagonal:
            P = tf.square(P_chol)
            K_dense = tf.diag(P)
        else:
            P = tf.matmul(P_chol, P_chol, transpose_b=True)
            K_dense = P

        reference = evaluate(
            gauss_kl(q_mu=mu_gpflow, q_sqrt=Q_chol_gpflow, K=K_dense))

        # The prior may be given by its factor, by P itself, or by both;
        # each specification has to reproduce the same reference value.
        prior_specs = (
            {'P_chol': P_chol, 'P': None},
            {'P_chol': None, 'P': P},
            {'P_chol': P_chol, 'P': P},
        )
        for spec in prior_specs:
            candidate = evaluate(KL(mu_diff=mu, Q_chol=Q_chol, **spec))
            assert_allclose(reference, candidate)
def _build_transition_expectations(self, qx_samples, batch_indices=None):
    """Accumulate the transition log-density term over sequences.

    For every sequence, ``self.transitions.logp`` is evaluated on
    ``qx_samples[s]``, summed over axis 1 and then averaged; the
    per-sequence results are added together.

    Args:
        qx_samples: indexable collection with one entry per processed
            sequence (presumably samples from q(x) -- confirm with caller).
        batch_indices: optional indices of the sequences in the minibatch.
            When given, ``self.batch_size`` sequences are processed and the
            total is rescaled; when ``None``, all ``self.n_seq`` sequences
            are used.

    Returns:
        The accumulated expectation, minus ``KL(Umu, Ucov_chol)`` when the
        transitions are ``GPTransitions``.
    """
    uses_gp_transitions = isinstance(self.transitions, GPTransitions)
    full_data = batch_indices is None

    # The KL over U is subtracted once at the end, so logp must not do it.
    logp_kwargs = {'subtract_KL_U': False} if uses_gp_transitions else {}
    n_terms = self.n_seq if full_data else self.batch_size

    tr_expectations = 0.
    for s in range(n_terms):
        seq_index = s if full_data else batch_indices[s]

        if self.transitions.OBSERVATIONS_AS_INPUT:
            # Observations act as inputs; the final one is dropped.
            inputs = self.gather_from_list(self.Y, seq_index)[:-1]
        elif self.inputs is None:
            inputs = None
        else:
            inputs = self.gather_from_list(self.inputs, seq_index)

        logp = self.transitions.logp(qx_samples[s], inputs, **logp_kwargs)
        tr_expectations += tf.reduce_mean(tf.reduce_sum(logp, 1))

    if not full_data:
        batch_latent_steps = tf.gather(self.T_latent_tf, batch_indices)
        sum_T_l_batch = tf.cast(
            tf.reduce_sum(batch_latent_steps), gp.settings.float_type)
        # Rescale the minibatch total to the full data set; presumably each
        # sequence with T latent steps contributes T - 1 transitions.
        full_count = self.sum_T_latent - self.n_seq
        batch_count = sum_T_l_batch - self.batch_size
        tr_expectations *= full_count / batch_count

    if uses_gp_transitions:
        KL_U = KL(self.transitions.Umu, self.transitions.Ucov_chol)
        tr_expectations -= KL_U

    return tr_expectations
def _build_KL_x1(self, batch_indices=None):
    """Build the KL term for the first latent state.

    Shapes as recorded by the original author:
        qx1_mu: SxE
        qx1_cov_chol: SxExE
        px1_mu: E or SxE
        px1_cov_chol: None or E or ExE or SxE or SxExE

    Args:
        batch_indices: optional sequence indices. Ignored for tensor
            selection when ``self.chunking`` is set; otherwise the
            per-sequence tensors are gathered at these indices and the
            result is scaled by ``n_seq / batch_size``.

    Returns:
        The value of ``KL`` for the selected posterior and prior tensors.
    """
    if self.multi_diag_px1_cov:
        # Per-sequence diagonals are expanded into full matrices.
        prior_chol = tf.matrix_diag(self.px1_cov_chol)
    else:
        prior_chol = self.px1_cov_chol

    # Start from the full-data tensors and specialise below.
    prior_mu = self.px1_mu
    post_mu = self.qx1_mu
    post_chol = self.qx1_cov_chol
    minibatched = batch_indices is not None and not self.chunking

    if self.chunking:
        # Only the first entry of the posterior is used when chunking.
        post_mu = post_mu[0]
        post_chol = post_chol[0]
    elif minibatched:
        if prior_mu.shape.ndims == 2:
            prior_mu = tf.gather(prior_mu, batch_indices)
        post_mu = tf.gather(post_mu, batch_indices)
        post_chol = tf.gather(post_chol, batch_indices)
        if self.px1_cov_chol is None:
            prior_chol = None
        elif prior_chol.shape.ndims >= 3:
            # Only a rank-3 prior factor is gathered per sequence.
            prior_chol = tf.gather(prior_chol, batch_indices)

    KL_x1 = KL(post_mu - prior_mu, post_chol, P_chol=prior_chol)
    if minibatched:
        KL_x1 *= float(self.n_seq) / float(self.batch_size)
    return KL_x1
def test_whitening(self):
    """Check that KL with an explicit ``P_chol`` matches the whitened form.

    The whitened arguments are obtained by lower-triangular solves of
    ``P_chol`` against ``mu`` and ``Q_chol``; ``KL`` is then called without
    a prior factor and must return the same value.
    """
    with self.test_context() as sess:
        vector_shape = (self.D, self.M)
        matrix_shape = (self.D, self.M, self.M)
        mu = tf.placeholder(FLOAT_TYPE, shape=vector_shape)
        Q_chol = tf.placeholder(FLOAT_TYPE, shape=matrix_shape)
        P_chol = tf.placeholder(FLOAT_TYPE, shape=matrix_shape)
        feed_dict = self.get_feed_dict([mu], [Q_chol], [P_chol])

        KL_black = sess.run(KL(mu, Q_chol, P_chol=P_chol), feed_dict)

        # The solve needs a trailing axis on mu, which is removed afterwards.
        mu_as_columns = mu[:, :, None]
        white_mu = tf.matrix_triangular_solve(
            P_chol, mu_as_columns, lower=True)[..., 0]
        white_Q_chol = tf.matrix_triangular_solve(
            P_chol, Q_chol, lower=True)
        KL_white = sess.run(KL(white_mu, white_Q_chol), feed_dict)

        assert_allclose(KL_black, KL_white)
def test_KL_x1_multiseq(self):
    """Check that the batched ``KL`` equals the sum of per-row ``KL`` calls.

    Five prior forms are covered: none, a shared rank-1 factor, a shared
    rank-2 factor, per-row diagonals and per-row full factors. Each batched
    value is compared with the summed output of ``tf.map_fn`` over the
    leading axis.
    """
    with self.test_context() as sess:
        mu = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M))
        Q_chol = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M, self.M))
        P_chol_1D = tf.placeholder(FLOAT_TYPE, shape=(self.M,))
        P_chol_2D = tf.placeholder(FLOAT_TYPE, shape=(self.M, self.M))
        P_chol_3D_diag = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M))
        P_chol_3D = tf.placeholder(
            FLOAT_TYPE, shape=(self.D, self.M, self.M))
        feed_dict = self.get_feed_dict(
            [mu], [Q_chol],
            [P_chol_1D, P_chol_2D, P_chol_3D_diag, P_chol_3D])

        def evaluate(tensor):
            return sess.run(tensor, feed_dict)

        def mapped_kl(shared_prior=None, per_row_prior=None):
            # Either one prior is shared by every row, or the prior is
            # sliced row by row together with mu and Q_chol.
            if per_row_prior is None:
                elems = (mu, Q_chol)

                def row_kl(a):
                    return KL(a[0], a[1], P_chol=shared_prior)
            else:
                elems = (mu, Q_chol, per_row_prior)

                def row_kl(a):
                    return KL(a[0], a[1], P_chol=a[2])
            return evaluate(tf.map_fn(row_kl, elems, FLOAT_TYPE))

        batched_priors = [
            None,
            P_chol_1D,
            P_chol_2D,
            tf.matrix_diag(P_chol_3D_diag),
            P_chol_3D,
        ]
        batched_values = [
            evaluate(KL(mu, Q_chol, P_chol=prior))
            for prior in batched_priors
        ]
        mapped_values = [
            mapped_kl(shared_prior=None),
            mapped_kl(shared_prior=P_chol_1D),
            mapped_kl(shared_prior=P_chol_2D),
            mapped_kl(per_row_prior=P_chol_3D_diag),
            mapped_kl(per_row_prior=P_chol_3D),
        ]

        for batched, mapped in zip(batched_values, mapped_values):
            assert_allclose(batched, mapped.sum())
def test_KL_samples_mu_2D(self):
    """Check that summed ``KL_samples`` output matches ``KL`` for rank-2 mu.

    ``KL_samples`` is called with a rank-2 or rank-3 ``Q_chol`` and with no
    prior factor, a rank-1 one, or a rank-2 one. The reference ``KL`` call
    receives the dense equivalents: ``tf.matrix_diag`` of the rank-2
    ``Q_chol`` and ``tf.diag`` of the rank-1 prior factor.
    """
    with self.test_context() as sess:
        mu = tf.placeholder(FLOAT_TYPE, shape=(self.M, self.D))
        Q_chol_2D = tf.placeholder(FLOAT_TYPE, shape=(self.M, self.D))
        Q_chol_3D = tf.placeholder(
            FLOAT_TYPE, shape=(self.M, self.D, self.D))
        P_chol_1D = tf.placeholder(FLOAT_TYPE, shape=(self.D,))
        P_chol_2D = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.D))
        feed_dict = self.get_feed_dict(
            [mu], [Q_chol_2D, Q_chol_3D], [P_chol_1D, P_chol_2D])

        def evaluate(tensor):
            return sess.run(tensor, feed_dict)

        # Q varies slowest so the case order is (Q_2D x 3 priors) followed
        # by (Q_3D x 3 priors) on both sides of the comparison.
        sample_cases = [
            (q, p)
            for q in (Q_chol_2D, Q_chol_3D)
            for p in (None, P_chol_1D, P_chol_2D)
        ]
        sampled_values = [
            evaluate(tf.reduce_sum(KL_samples(mu, q, p)))
            for q, p in sample_cases
        ]

        dense_Q_chols = (tf.matrix_diag(Q_chol_2D), Q_chol_3D)
        dense_P_chols = (None, tf.diag(P_chol_1D), P_chol_2D)
        reference_values = [
            evaluate(KL(mu, q, P_chol=p))
            for q in dense_Q_chols
            for p in dense_P_chols
        ]

        for sampled, reference in zip(sampled_values, reference_values):
            assert_allclose(sampled, reference)
def _build_KL_x1(self):
    """Return ``KL`` for the first-state posterior against its prior.

    ``KL`` receives the difference between ``qx1_mu`` and ``px1_mu``, the
    factor ``qx1_cov_chol``, and ``px1_cov_chol`` as ``P_chol``.
    """
    mean_difference = self.qx1_mu - self.px1_mu
    posterior_chol = self.qx1_cov_chol
    prior_chol = self.px1_cov_chol
    return KL(mean_difference, posterior_chol, P_chol=prior_chol)
def _build_KL_U(self):
    """Return ``KL`` evaluated on ``Umu`` and ``Ucov_chol``.

    No prior factor is passed, so ``KL`` uses whatever default it defines
    (presumably a standard-normal prior -- confirm against ``KL``).
    """
    inducing_mean = self.Umu
    inducing_chol = self.Ucov_chol
    return KL(inducing_mean, inducing_chol)