def _build_stochastic_x1_cross_entropy(self, qx1_samples, batch_indices=None):
    """Build a sample-based estimate of the initial-state cross-entropy.

    Evaluates ``mvn_logp`` / ``diag_mvn_logp`` (helpers defined elsewhere;
    presumably Gaussian log-densities) on the samples centred at
    ``self.px1_mu``, averages over samples, sums over sequences and returns
    the negated result.

    :param qx1_samples: samples of the first latent state, indexed by
        sequence along the first axis (assumed
        n_seq-or-batch_size x n_samples x E -- TODO confirm)
    :param batch_indices: optional indices of the sequences that make up the
        current minibatch; position ``s`` of ``qx1_samples`` then corresponds
        to sequence ``batch_indices[s]``. ``None`` means all ``self.n_seq``
        sequences are present.
    :return: negated sum over sequences of the per-sequence sample mean,
        multiplied by ``n_seq / batch_size`` when ``batch_indices`` is given
    """
    minibatch = batch_indices is not None
    diag_px1 = self.px1_cov_chol.shape.ndims == 1 or self.multi_diag_px1_cov
    shared_mu = self.px1_mu.shape.ndims == 1

    if self.multi_diag_px1_cov or self.px1_cov_chol.shape.ndims == 3:
        # One covariance factor per sequence: handle the sequences one by one.
        x1_ce = 0.
        for s in range(self.batch_size if minibatch else self.n_seq):
            # b_s indexes the per-sequence parameters, s indexes the samples.
            b_s = batch_indices[s] if minibatch else s
            _px1_mu = self.px1_mu if shared_mu else self.px1_mu[b_s]
            if diag_px1:
                _x1_ce = diag_mvn_logp(qx1_samples[s] - _px1_mu,
                                       self.px1_cov_chol[b_s])
            else:
                _x1_ce = mvn_logp(tf.transpose(qx1_samples[s] - _px1_mu),
                                  self.px1_cov_chol[b_s])
            x1_ce += tf.reduce_mean(_x1_ce)
    else:
        # A single covariance factor shared by every sequence: vectorise.
        if shared_mu:
            _px1_mu = self.px1_mu
        else:
            # Per-sequence means must be restricted to the sequences in the
            # minibatch (as the loop branch does through b_s); otherwise the
            # leading axis of px1_mu (all sequences) would not line up with
            # the leading axis of qx1_samples (minibatch only).
            seq_mu = (tf.gather(self.px1_mu, batch_indices)
                      if minibatch else self.px1_mu)
            _px1_mu = seq_mu[:, None, :]
        if diag_px1:
            x1_ce = diag_mvn_logp(qx1_samples - _px1_mu, self.px1_cov_chol)
        else:
            x1_ce = mvn_logp(
                tf.transpose(qx1_samples - _px1_mu, [2, 0, 1]),
                self.px1_cov_chol)
        # Mean over the last axis, then sum over what remains.
        x1_ce = tf.reduce_sum(tf.reduce_mean(x1_ce, -1))

    if minibatch:
        # Rescale the minibatch sum so it estimates the sum over all sequences.
        x1_ce *= float(self.n_seq) / float(self.batch_size)
    return -x1_ce
def logp(self, X, inputs=None):
    """Evaluate the transition log-probability along a latent trajectory.

    The residual is the difference between every state after the first
    (``X[..., 1:, :]``) and ``self.conditional_mean`` evaluated at the
    preceding states (``X[..., :-1, :]``). It is scored with ``mvn_logp``
    when ``self.Qchol`` has rank 2 and with ``diag_mvn_logp`` when it has
    rank 1 (both helpers are defined elsewhere).

    :param X: latent states of rank 2 or 3
        (assumed T x E or n_samples x T x E -- TODO confirm)
    :param inputs: optional value forwarded unchanged to
        ``self.conditional_mean``
    :return: whatever the selected log-density helper returns for the
        residuals
    :raises ValueError: if ``self.Qchol`` is neither rank 1 nor rank 2
    """
    q_rank = self.Qchol.shape.ndims
    if q_rank not in (1, 2):
        # Fail loudly instead of silently falling through and returning None.
        raise ValueError(
            "Qchol must have rank 1 (diagonal) or 2 (full Cholesky factor), "
            "got rank {}".format(q_rank))

    d = X[..., 1:, :] - self.conditional_mean(X[..., :-1, :], inputs=inputs)
    if q_rank == 2:
        # Move the state dimension to the front before calling mvn_logp.
        dim_perm = [2, 0, 1] if X.shape.ndims == 3 else [1, 0]
        return mvn_logp(tf.transpose(d, dim_perm), self.Qchol)
    return diag_mvn_logp(d, self.Qchol)
def logp(self, X, Y):
    """
    Log-probability of the observations given the latent states, using a
    1 x 1 factor equal to sqrt(self.noise).

    :param X: latent state (T x E) or (n_samples x T x E)
    :param Y: observations (T x D)
    :return: log P(Y|X(n)) (T) or (n_samples x T)
    """
    residual = Y - self.conditional_mean(X)
    # Bring the last axis of the residual to the front for mvn_logp.
    if X.shape.ndims == 3:
        perm = [2, 0, 1]
    else:
        perm = [1, 0]
    residual_t = tf.transpose(residual, perm)
    noise_chol = tf.ones([1, 1], dtype=tf.float64) * tf.sqrt(self.noise)
    return mvn_logp(residual_t, noise_chol)
def logp(self, X, Y):
    """
    Log-probability of the observations given the latent states, scored
    with the factor stored in self.Rchol.

    :param X: latent state (T x E) or (n_samples x T x E)
    :param Y: observations (T x D)
    :return: log P(Y|X(n)) (T) or (n_samples x T)
    """
    residual = Y - self.conditional_mean(X)
    # Bring the last axis of the residual to the front for mvn_logp.
    if X.shape.ndims == 3:
        perm = [2, 0, 1]
    else:
        perm = [1, 0]
    residual_t = tf.transpose(residual, perm)
    return mvn_logp(residual_t, self.Rchol)
def _build_stochastic_x1_cross_entropy(self, qx1_samples):
    """Return the negated mean of ``mvn_logp`` over the initial-state samples.

    The samples are centred at ``self.px1_mu``, transposed, and scored by
    ``mvn_logp`` (defined elsewhere) with ``self.px1_cov_chol``.

    :param qx1_samples: samples of the first latent state
        (assumed n_samples x E -- TODO confirm)
    :return: scalar, minus the mean of the ``mvn_logp`` values
    """
    centred = qx1_samples - self.px1_mu
    sample_logp = mvn_logp(tf.transpose(centred), self.px1_cov_chol)
    mean_logp = tf.reduce_mean(sample_logp)
    return -mean_logp