# Assumed imports for the methods below (TensorFlow 1.x + t3f); helpers such
# as _kron_tril, _kron_logdet and the pairwise quadratic-form routines are
# module-local.
import tensorflow as tf
import t3f
from t3f import ops, batch_ops
from t3f import TensorTrain, TensorTrainBatch
import t3f.kronecker as kron


def _get_mu(self, ranks, x, y):
    """Initializes the expectations mu of the latent inputs.

    Either loads pretrained values of the TT-cores of mu, or initializes
    them according to the optimal formulas from the given data.

    Args:
        ranks: TT-ranks of mu.
        x: features of a batch of objects.
        y: targets of a batch of objects.
    """
    # TODO: test if this is needed.
    w = self.inputs.interpolate_on_batch(self.cov.project(x))
    Sigma = ops.tt_tt_matmul(self.sigma_l, ops.transpose(self.sigma_l))
    temp = ops.tt_tt_matmul(w, y)
    anc = ops.tt_tt_matmul(Sigma, temp)
    # Sum the batch of TT-vectors into a single TT-vector.
    res = TensorTrain([core[0, :, :, :, :] for core in anc.tt_cores],
                      tt_ranks=[1] * (anc.ndims() + 1))
    for i in range(1, anc.get_shape()[0]):
        elem = TensorTrain([core[i, :, :, :, :] for core in anc.tt_cores],
                           tt_ranks=[1] * (anc.ndims() + 1))
        res = ops.add(res, elem)
    mu_ranks = [1] + [ranks] * (res.ndims() - 1) + [1]
    return t3f.get_variable('tt_mu',
                            initializer=TensorTrain(res.tt_cores,
                                                    res.get_raw_shape(),
                                                    mu_ranks))
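# Dense (non-TT) sketch of what the initialization above computes: mu is set
# to Sigma * sum_i y_i * w_i, where w_i is the interpolation vector of the
# i-th object and Sigma = sigma_l sigma_l^T. The numpy helper below and its
# argument names are illustrative only and not part of this module; the
# method above performs the same computation with TT-format objects.
import numpy as np

def _dense_mu_init_sketch(W, y, sigma_l):
    """W: (batch, m) interpolation vectors, y: (batch,) targets, sigma_l: (m, m)."""
    Sigma = sigma_l @ sigma_l.T
    return Sigma @ (W.T @ y)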
def _unary_complexity_penalty(self):
    """Computes the complexity penalty for unary potentials.

    This function computes KL-divergence between prior and variational
    distribution over the values of GPs at inducing inputs.

    Returns:
        A scalar `tf.Tensor` containing the complexity penalty for GPs
        determining unary potentials.
    """
    # TODO: test this
    mus = self.mus
    sigma_ls = _kron_tril(self.sigma_ls)
    sigmas = ops.tt_tt_matmul(sigma_ls, ops.transpose(sigma_ls))
    sigmas_logdet = _kron_logdet(sigma_ls)

    K_mms = self._K_mms()
    K_mms_inv = kron.inv(K_mms)
    K_mms_logdet = kron.slog_determinant(K_mms)[1]

    penalty = 0
    penalty += -K_mms_logdet
    penalty += sigmas_logdet
    penalty += -ops.tt_tt_flat_inner(sigmas, K_mms_inv)
    penalty += -ops.tt_tt_flat_inner(mus, ops.tt_tt_matmul(K_mms_inv, mus))
    return tf.reduce_sum(penalty) / 2
def _get_mus(self, mu_ranks):
    """Initializes the expectations of the variational distribution over
    the unary potentials.

    Args:
        mu_ranks: TT-ranks of mus.
    """
    # TODO: is this a good initialization?
    x_init = tf.random_normal([mu_ranks, self.d], dtype=tf.float64)
    y_init = tf.random_normal([mu_ranks], dtype=tf.float64)
    w = self.inputs.interpolate_on_batch(x_init)
    y_init_cores = [tf.reshape(y_init, (-1, 1, 1, 1, 1))]
    for core_idx in range(1, w.ndims()):
        y_init_cores += [tf.ones((mu_ranks, 1, 1, 1, 1), dtype=tf.float64)]
    y_init = t3f.TensorTrainBatch(y_init_cores)
    Sigma = ops.tt_tt_matmul(self.sigma_ls[0], ops.transpose(self.sigma_ls[0]))
    res_batch = t3f.tt_tt_matmul(Sigma, t3f.tt_tt_matmul(w, y_init))
    res = res_batch[0]
    for i in range(1, mu_ranks):
        res = res + res_batch[i]
    # Reuse the name mu_ranks for the full list of TT-ranks.
    mu_ranks = [1] + [mu_ranks] * (res.ndims() - 1) + [1]
    mu_cores = []
    for core in res.tt_cores:
        mu_cores.append(tf.tile(core[None, ...], [self.n_labels, 1, 1, 1, 1]))
    return t3f.get_variable('tt_mus',
                            initializer=TensorTrainBatch(mu_cores,
                                                         res.get_raw_shape(),
                                                         mu_ranks))
def _get_mus(self, ranks, x_init, y_init):
    """Initializes the expectations mu for all classes from the given data."""
    w = self.inputs.interpolate_on_batch(self.cov.project(x_init))
    Sigma = ops.tt_tt_matmul(self.sigma_ls[0], ops.transpose(self.sigma_ls[0]))
    temp = ops.tt_tt_matmul(w, y_init)
    anc = ops.tt_tt_matmul(Sigma, temp)
    # Sum the batch of TT-vectors into a single TT-vector.
    res = TensorTrain([core[0, :, :, :, :] for core in anc.tt_cores],
                      tt_ranks=[1] * (anc.ndims() + 1))
    for i in range(1, anc.get_shape()[0]):
        elem = TensorTrain([core[i, :, :, :, :] for core in anc.tt_cores],
                           tt_ranks=[1] * (anc.ndims() + 1))
        res = ops.add(res, elem)
    mu_ranks = [1] + [ranks] * (res.ndims() - 1) + [1]
    mu_cores = []
    for core in res.tt_cores:
        mu_cores.append(tf.tile(core[None, ...], [self.n_class, 1, 1, 1, 1]))
    return t3f.get_variable('tt_mus',
                            initializer=TensorTrainBatch(mu_cores,
                                                         res.get_raw_shape(),
                                                         mu_ranks))
def complexity_penalty(self):
    """Returns the complexity penalty term for ELBO.
    """
    mus = self.mus
    sigma_ls = _kron_tril(self.sigma_ls)
    sigmas = ops.tt_tt_matmul(sigma_ls, ops.transpose(sigma_ls))
    sigmas_logdet = _kron_logdet(sigma_ls)

    K_mms = self._K_mms()
    K_mms_inv = kron.inv(K_mms)
    K_mms_logdet = kron.slog_determinant(K_mms)[1]

    penalty = 0
    penalty += -K_mms_logdet
    penalty += sigmas_logdet
    penalty += -ops.tt_tt_flat_inner(sigmas, K_mms_inv)
    penalty += -ops.tt_tt_flat_inner(mus, ops.tt_tt_matmul(K_mms_inv, mus))
    return penalty / 2
def complexity_penalty(self):
    """Returns the complexity penalty term for ELBO of different GP models.
    """
    mu = self.mu
    sigma_l = _kron_tril(self.sigma_l)
    sigma = ops.tt_tt_matmul(sigma_l, ops.transpose(sigma_l))
    sigma_logdet = _kron_logdet(sigma_l)

    K_mm = self.K_mm()
    K_mm_inv = kron.inv(K_mm)
    K_mm_logdet = kron.slog_determinant(K_mm)[1]

    elbo = 0
    elbo += -K_mm_logdet
    elbo += sigma_logdet
    elbo += -ops.tt_tt_flat_inner(sigma, K_mm_inv)
    elbo += -ops.tt_tt_flat_inner(mu, ops.tt_tt_matmul(K_mm_inv, mu))
    return elbo / 2
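# Dense sanity-check sketch of the complexity penalties above: with
# q = N(mu, L L^T) and p = N(0, K_mm), the returned value equals
# -KL(q || p) up to an additive constant. The numpy helper below is
# illustrative only and not part of this module.
import numpy as np

def _dense_complexity_penalty_sketch(mu, sigma_l, K_mm):
    sigma = sigma_l @ sigma_l.T
    K_inv = np.linalg.inv(K_mm)
    penalty = -np.linalg.slogdet(K_mm)[1]
    penalty += np.linalg.slogdet(sigma)[1]
    penalty += -np.trace(K_inv @ sigma)
    penalty += -mu @ (K_inv @ mu)
    return penalty / 2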
def predict_process_value(self, x, with_variance=False):
    """Predicts the value of the process at point x.

    Args:
        x: data features.
        with_variance: if True, also returns the process variance at x.
    """
    mu = self.mu
    w = self.inputs.interpolate_on_batch(self.cov.project(x))

    mean = ops.tt_tt_flat_inner(w, mu)
    if not with_variance:
        return mean
    K_mm = self.K_mm()
    variance = self.cov.cov_0()
    sigma_l_w = ops.tt_tt_matmul(ops.transpose(self.sigma_l), w)
    variance += ops.tt_tt_flat_inner(sigma_l_w, sigma_l_w)
    variance -= ops.tt_tt_flat_inner(w, ops.tt_tt_matmul(K_mm, w))
    return mean, variance
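# Dense sketch of the predictive formulas used above (interpolation-based
# inducing-point GP): mean = w^T mu and
# variance = cov_0 + w^T Sigma w - w^T K_mm w, with Sigma = sigma_l sigma_l^T.
# The numpy helper below is illustrative only and not part of this module.
import numpy as np

def _dense_predict_sketch(w, mu, sigma_l, K_mm, cov_0):
    mean = w @ mu
    s = sigma_l.T @ w
    variance = cov_0 + s @ s - w @ (K_mm @ w)
    return mean, variance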
def _latent_vars_distribution(self, x, seq_lens):
    """Computes the parameters of the variational distribution over potentials.

    Args:
        x: `tf.Tensor` of shape `batch_size` x `max_seq_len` x `d`;
            sequences of features for the current batch.
        seq_lens: `tf.Tensor` of shape `batch_size`; lengths of the input
            sequences.

    Returns:
        A tuple containing 4 `tf.Tensor`s.
        `m_un`: a `tf.Tensor` of shape
            `n_labels` x `batch_size` x `max_seq_len`; the expectations of
            the unary potentials.
        `S_un`: a `tf.Tensor` of shape
            `n_labels` x `batch_size` x `max_seq_len` x `max_seq_len`; the
            covariance matrices of the unary potentials.
        `m_bin`: a `tf.Tensor` of shape `max_seq_len`^2; the expectations of
            the binary potentials.
        `S_bin`: a `tf.Tensor` of shape `max_seq_len`^2 x `max_seq_len`^2;
            the covariance matrix of the binary potentials.
    """
    batch_size, max_len, d = x.get_shape().as_list()
    n_labels = self.n_labels
    sequence_mask = tf.sequence_mask(seq_lens, maxlen=max_len)
    indices = tf.cast(tf.where(sequence_mask), tf.int32)
    x_flat = tf.gather_nd(x, indices)  # shape: sum_len x d

    w = self.inputs.interpolate_on_batch(self.cov.project(x_flat))
    m_un_flat = batch_ops.pairwise_flat_inner(w, self.mus)  # sum_len x n_labels
    shape = tf.concat([[batch_size], [max_len], [n_labels]], axis=0)
    m_un = tf.scatter_nd(indices, m_un_flat, shape)
    m_un = tf.transpose(m_un, [2, 0, 1])

    sigmas = ops.tt_tt_matmul(self.sigma_ls, t3f.transpose(self.sigma_ls))
    K_mms = self._K_mms()
    K_nn = self._K_nns(x)
    S_un = K_nn
    S_un += _kron_sequence_pairwise_quadratic_form(sigmas, w, seq_lens, max_len)
    S_un -= _kron_sequence_pairwise_quadratic_form(K_mms, w, seq_lens, max_len)
    S_un = self._remove_extra_elems(seq_lens, S_un)

    m_bin = tf.identity(self.bin_mu)
    S_bin = tf.matmul(self.bin_sigma_l, tf.transpose(self.bin_sigma_l))
    return m_un, S_un, m_bin, S_bin
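# Dense sketch of the unary covariance computed above for a single sequence
# and a single label: S_un = K_nn + W Sigma W^T - W K_mm W^T, where the rows
# of W are the interpolation vectors of the sequence elements. The numpy
# helper below is illustrative only and not part of this module.
import numpy as np

def _dense_unary_covariance_sketch(W, sigma_l, K_mm, K_nn):
    Sigma = sigma_l @ sigma_l.T
    return K_nn + W @ Sigma @ W.T - W @ K_mm @ W.T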
def _predict_process_values(self, x, with_variance=False, test=False):
    """Predicts the values of the latent processes at points x.

    Args:
        x: features of a batch of objects.
        with_variance: if True, also returns the process variances at x.
        test: flag forwarded to `self.cov.project`.
    """
    w = self.inputs.interpolate_on_batch(self.cov.project(x, test=test))
    mean = batch_ops.pairwise_flat_inner(w, self.mus)
    if not with_variance:
        return mean
    K_mms = self._K_mms()
    sigma_ls = _kron_tril(self.sigma_ls)
    sigmas = ops.tt_tt_matmul(sigma_ls, ops.transpose(sigma_ls))
    variances = pairwise_quadratic_form(sigmas, w, w)
    variances -= pairwise_quadratic_form(K_mms, w, w)
    variances += self.cov.cov_0()[None, :]
    return mean, variances
def elbo(self, w, y):
    """Evidence lower bound.

    Args:
        w: interpolation vector for the current batch.
        y: target values for the current batch.
    """
    l = tf.cast(tf.shape(y)[0], tf.float64)  # batch size
    N = tf.cast(self.N, dtype=tf.float64)

    y = tf.reshape(y, [-1])

    mu = self.gp.mu
    sigma_l = _kron_tril(self.gp.sigma_l)
    sigma = ops.tt_tt_matmul(sigma_l, ops.transpose(sigma_l))
    sigma_n = self.gp.cov.noise_variance()

    K_mm = self.gp.K_mm()

    tilde_K_ii = l * self.gp.cov.cov_0()
    tilde_K_ii -= tf.reduce_sum(
        ops.tt_tt_flat_inner(w, ops.tt_tt_matmul(K_mm, w)))

    elbo = 0
    elbo -= tf.reduce_sum(tf.square(y - ops.tt_tt_flat_inner(w, mu)))
    elbo -= tilde_K_ii
    # TODO: check whether the summed or per-object form is correct here.
    # elbo -= ops.tt_tt_flat_inner(w, ops.tt_tt_matmul(sigma, w))
    elbo -= tf.reduce_sum(ops.tt_tt_flat_inner(w, ops.tt_tt_matmul(sigma, w)))
    elbo /= 2 * sigma_n**2 * l
    elbo += self.gp.complexity_penalty() / N
    # TODO: decide whether the noise log-term should be included.
    # elbo -= tf.log(tf.abs(sigma_n))
    return -elbo[0]
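# Dense sketch of the stochastic ELBO estimate above for a batch of size l:
#   elbo = -(1 / (2 * sigma_n**2 * l)) * ( sum_i (y_i - w_i^T mu)^2
#            + sum_i (cov_0 - w_i^T K_mm w_i)
#            + sum_i w_i^T Sigma w_i )
#          + complexity_penalty / N,
# and the method returns -elbo for minimization. The numpy helper below is
# illustrative only and not part of this module.
import numpy as np

def _dense_elbo_sketch(W, y, mu, sigma_l, K_mm, cov_0, sigma_n, N,
                       complexity_penalty):
    l = y.shape[0]
    Sigma = sigma_l @ sigma_l.T
    fit = np.sum((y - W @ mu) ** 2)
    tilde_K = np.sum(cov_0 - np.einsum('ij,jk,ik->i', W, K_mm, W))
    var_term = np.sum(np.einsum('ij,jk,ik->i', W, Sigma, W))
    elbo = -(fit + tilde_K + var_term) / (2 * sigma_n ** 2 * l)
    elbo += complexity_penalty / N
    return -elbo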