def log_prob(self, x):
    """Return the log-density of ``x`` under this distribution.

    The value computed is::

        LOGPROBC * d - _logdet(L) - 0.5 * ||L^{-1} (x - loc)||^2

    with ``L = self.scale_tril`` and ``d = self.d``.  Assuming ``LOGPROBC``
    is ``-0.5 * log(2 * pi)`` and ``_logdet`` returns ``log|det L|``
    (neither is visible here -- TODO confirm), this is the multivariate
    normal log-pdf with covariance ``L L^T``.

    Args:
        x: Samples whose last axis has length ``d``. The inverted
            ``scale_tril`` (shape ``batch_shape + (d, d)``) must broadcast
            to ``x.shape + (d,)`` and ``loc`` must broadcast to ``x.shape``.

    Returns:
        The log-density, with shape ``x.shape[:-1]``.
    """
    d = self.d
    # Invert every triangular factor as one flat (N, d, d) batch, then
    # restore the original batch shape.
    flat_inv = _batch_triangular_inv(self.scale_tril.reshape(-1, d, d))
    tril_inv = flat_inv.reshape(self.batch_shape + (d, d))

    tril_inv = broadcast.broadcast_to(tril_inv, x.shape + (d,))
    centered = x - broadcast.broadcast_to(self.loc, x.shape)

    # z = L^{-1} (x - loc) as a column vector, shape x.shape + (1,).
    z = matmul.matmul(tril_inv, expand_dims.expand_dims(centered, axis=-1))
    # z^T z has shape x.shape[:-1] + (1, 1); drop both singleton axes.
    sq_dist = matmul.matmul(swapaxes.swapaxes(z, -1, -2), z)
    sq_dist = squeeze.squeeze(squeeze.squeeze(sq_dist, axis=-1), axis=-1)

    log_norm = LOGPROBC * d - self._logdet(self.scale_tril)
    return broadcast.broadcast_to(log_norm, sq_dist.shape) - 0.5 * sq_dist
def log_prob(self, x):
    """Evaluate the log-density at ``x``.

    Computes ``LOGPROBC * d - _logdet(scale_tril) - 0.5 * q``, where
    ``q = w^T w`` and ``w = scale_tril^{-1} (x - loc)``.  If ``LOGPROBC``
    is ``-0.5 * log(2 * pi)`` and ``_logdet`` is ``log|det scale_tril|``
    (both defined outside this method -- TODO confirm), this is the
    multivariate normal log-pdf with covariance
    ``scale_tril @ scale_tril^T``.

    NOTE(review): this file appears to define ``log_prob`` twice with the
    same computation; if both are in one class, only the later one is used.

    Args:
        x: Samples whose last axis has length ``d``. It must be compatible
            with broadcasting ``loc`` to ``x.shape`` and the inverse scale
            to ``x.shape + (d,)``.

    Returns:
        Log-density values, with shape ``x.shape[:-1]``.
    """
    dim = self.d
    # Batched triangular inversion on a flattened (N, dim, dim) view.
    batch_inv = _batch_triangular_inv(
        self.scale_tril.reshape(-1, dim, dim))
    inv_scale = broadcast.broadcast_to(
        batch_inv.reshape(self.batch_shape + (dim, dim)),
        x.shape + (dim,))
    loc = broadcast.broadcast_to(self.loc, x.shape)

    # Whiten the residual as a column vector, then take its squared norm
    # via w^T w, which leaves two trailing singleton axes to remove.
    diff_col = expand_dims.expand_dims(x - loc, axis=-1)
    white = matmul.matmul(inv_scale, diff_col)
    quad = matmul.matmul(swapaxes.swapaxes(white, -1, -2), white)
    for _ in range(2):
        quad = squeeze.squeeze(quad, axis=-1)

    const = LOGPROBC * dim - self._logdet(self.scale_tril)
    return broadcast.broadcast_to(const, quad.shape) - 0.5 * quad