def testPairwiseFlatInnerMatrix(self):
  """pairwise_flat_inner_projected must agree with batch_ops.pairwise_flat_inner.

  Both batches are projected onto the tangent space at the same point. The
  specialised routine is then also required to raise ValueError when one
  argument is not a tangent-space projection, and when the two arguments
  were projected onto different tangent spaces.
  """
  shape = ((2, 3, 4), None)
  first_batch = initializers.random_matrix_batch(shape, 4, batch_size=3,
                                                 dtype=self.dtype)
  second_batch = initializers.random_matrix_batch(shape, 4, batch_size=4,
                                                  dtype=self.dtype)
  point = initializers.random_matrix(shape, 3, dtype=self.dtype)
  first_proj = riemannian.project(first_batch, point)
  second_proj = riemannian.project(second_batch, point)

  # Reference: the generic pairwise inner product, which knows nothing about
  # tangent spaces.
  expected = batch_ops.pairwise_flat_inner(first_proj, second_proj)
  computed = riemannian.pairwise_flat_inner_projected(first_proj, second_proj)
  with self.test_session() as sess:
    expected_val, computed_val = sess.run((expected, computed))
    self.assertAllClose(expected_val, computed_val, atol=1e-5, rtol=1e-5)

  with self.assertRaises(ValueError):
    # The second argument was never projected onto the tangent space.
    riemannian.pairwise_flat_inner_projected(first_proj, second_batch)

  other_point = initializers.random_matrix(shape, 3, dtype=self.dtype)
  second_proj_elsewhere = riemannian.project(second_batch, other_point)
  with self.assertRaises(ValueError):
    # The two arguments live in tangent spaces at different points.
    riemannian.pairwise_flat_inner_projected(first_proj, second_proj_elsewhere)
def testPairwiseFlatInnerVectorsWithMatrix(self):
  """Tests pairwise_flat_inner of two TT-vector batches with a TT matrix.

  With a batch of 2 vectors ``u``, a batch of 3 vectors ``v`` and a matrix
  ``A``, the op must compute ``res[i, j] = u[i]^T * A * v[j]``.
  """
  tt_vectors_1 = initializers.random_matrix_batch(((2, 3), None), batch_size=2,
                                                  dtype=self.dtype)
  tt_vectors_2 = initializers.random_matrix_batch(((2, 3), None), batch_size=3,
                                                  dtype=self.dtype)
  matrix = initializers.random_matrix(((2, 3), (2, 3)), dtype=self.dtype)
  res_actual = batch_ops.pairwise_flat_inner(tt_vectors_1, tt_vectors_2, matrix)
  # A ((2, 3), None) TT vector has 2 * 3 = 6 entries; flatten each into a row.
  full_vectors_1 = tf.reshape(ops.full(tt_vectors_1), (2, 6))
  full_vectors_2 = tf.reshape(ops.full(tt_vectors_2), (3, 6))
  with self.test_session() as sess:
    res = sess.run(
        (res_actual, full_vectors_1, full_vectors_2, ops.full(matrix)))
    res_actual_val, vectors_1_val, vectors_2_val, matrix_val = res
    # All bilinear forms at once: (2, 6) x (6, 6) x (6, 3) -> (2, 3).
    # This replaces an element-by-element double loop of np.dot calls.
    res_desired_val = vectors_1_val.dot(matrix_val).dot(vectors_2_val.T)
    # Same tolerance as the other TT-matrix tests; the default is too tight
    # for float32 accumulation.
    self.assertAllClose(res_desired_val, res_actual_val, atol=1e-5, rtol=1e-5)
def _latent_vars_distribution(self, x, seq_lens):
  """Computes the parameters of the variational distribution over potentials.

  Args:
    x: `tf.Tensor` of shape `batch_size` x `max_seq_len` x d; sequences of
      features for the current batch. Its static shape is read with
      `get_shape().as_list()`, so `batch_size` and `max_seq_len` must be
      statically known.
    seq_lens: `tf.Tensor` of shape `batch_size`; lengths of input sequences.

  Returns:
    A tuple containing 4 `tf.Tensors`.
    `m_un`: a `tf.Tensor` of shape `n_labels` x `batch_size` x `max_seq_len`;
      the expectations of the unary potentials (zero at padded positions).
    `S_un`: a `tf.Tensor` of shape
      `n_labels` x `batch_size` x `max_seq_len` x `max_seq_len`;
      the covariance matrix of unary potentials.
    `m_bin`: a `tf.Tensor` of shape `max_seq_len`^2; the expectations
      of binary potentials.
      NOTE(review): binary potentials are usually indexed by label pairs
      (`n_labels`^2); confirm the documented shape against `self.bin_mu`.
    `S_bin`: a `tf.Tensor` of shape `max_seq_len`^2 x `max_seq_len`^2; the
      covariance matrix of binary potentials (see the note on `m_bin`).
  """
  batch_size, max_len, _ = x.get_shape().as_list()
  n_labels = self.n_labels

  # Gather the feature vectors of all valid (non-padded) positions into a flat
  # `sum_len` x d matrix. `indices` keeps their (sequence, position)
  # coordinates so that results can be scattered back afterwards.
  sequence_mask = tf.sequence_mask(seq_lens, maxlen=max_len)
  indices = tf.cast(tf.where(sequence_mask), tf.int32)
  x_flat = tf.gather_nd(x, indices)

  # Unary means for every valid position: a `sum_len` x `n_labels` matrix of
  # pairwise flat inner products between the interpolation weights and
  # `self.mus`.
  w = self.inputs.interpolate_on_batch(self.cov.project(x_flat))
  m_un_flat = batch_ops.pairwise_flat_inner(w, self.mus)
  # Scatter back into the padded layout (padded positions stay zero), then
  # move the label axis to the front.
  m_un = tf.scatter_nd(indices, m_un_flat, [batch_size, max_len, n_labels])
  m_un = tf.transpose(m_un, [2, 0, 1])

  # Unary covariance: K_nn + w^T (L L^T) w - w^T K_mm w, with L = sigma_ls.
  # NOTE(review): `self.sigma_ls` is used as-is, with no lower-triangular
  # projection; confirm this matches how the covariance factor is built
  # elsewhere (e.g. any `_kron_tril` applied at prediction time).
  sigmas = ops.tt_tt_matmul(self.sigma_ls, t3f.transpose(self.sigma_ls))
  K_mms = self._K_mms()
  K_nn = self._K_nns(x)
  S_un = K_nn
  S_un += _kron_sequence_pairwise_quadratic_form(sigmas, w, seq_lens, max_len)
  S_un -= _kron_sequence_pairwise_quadratic_form(K_mms, w, seq_lens, max_len)
  S_un = self._remove_extra_elems(seq_lens, S_un)

  # Binary potentials depend only on model parameters, not on the batch.
  m_bin = tf.identity(self.bin_mu)
  S_bin = tf.matmul(self.bin_sigma_l, tf.transpose(self.bin_sigma_l))
  return m_un, S_un, m_bin, S_bin
def testPairwiseFlatInnerTensor(self):
  """pairwise_flat_inner of two TT-tensor batches equals the dense Gram matrix."""
  batch_size = 5
  n_elements = 2 * 3 * 2
  left = initializers.random_tensor_batch((2, 3, 2), batch_size=batch_size)
  right = initializers.random_tensor_batch((2, 3, 2), batch_size=batch_size)
  computed = batch_ops.pairwise_flat_inner(left, right)

  # Dense reference: flatten every tensor into a row and take all
  # row-by-row dot products.
  left_rows = tf.reshape(ops.full(left), (batch_size, n_elements))
  right_rows = tf.reshape(ops.full(right), (batch_size, n_elements))
  gram = tf.matmul(left_rows, tf.transpose(right_rows))
  expected = tf.squeeze(gram)

  with self.test_session() as sess:
    computed_val, expected_val = sess.run((computed, expected))
    self.assertAllClose(expected_val, computed_val)
def _predict_process_values(self, x, with_variance=False, test=False):
  """Predicts the latent process at the points `x`.

  Args:
    x: inputs; passed to `self.cov.project`.
    with_variance: if True, the predictive variances are computed and returned
      as well.
    test: forwarded to `self.cov.project` as its `test` argument.

  Returns:
    `mean`, the pairwise flat inner products of the interpolation weights with
    `self.mus`, when `with_variance` is False; otherwise the tuple
    `(mean, variances)`.
  """
  w = self.inputs.interpolate_on_batch(self.cov.project(x, test=test))
  mean = batch_ops.pairwise_flat_inner(w, self.mus)
  if not with_variance:
    return mean

  K_mms = self._K_mms()
  # `_kron_tril` presumably keeps only the lower-triangular part of each
  # Kronecker factor, so that Sigma = L L^T is a valid covariance.
  sigma_ls = _kron_tril(self.sigma_ls)
  sigmas = ops.tt_tt_matmul(sigma_ls, ops.transpose(sigma_ls))
  # variances = w^T Sigma w - w^T K_mm w + cov_0 (broadcast over the first
  # axis).
  variances = pairwise_quadratic_form(sigmas, w, w)
  variances -= pairwise_quadratic_form(K_mms, w, w)
  variances += self.cov.cov_0()[None, :]
  return mean, variances
def testPairwiseFlatInnerMatrix(self):
  """pairwise_flat_inner of two TT-matrix batches equals the dense Gram matrix."""
  batch_size = 5
  n_elements = (2 * 3) * (2 * 3)
  tt_matrices_a = initializers.random_matrix_batch(((2, 3), (2, 3)),
                                                   batch_size=batch_size)
  tt_matrices_b = initializers.random_matrix_batch(((2, 3), (2, 3)),
                                                   batch_size=batch_size)
  computed = batch_ops.pairwise_flat_inner(tt_matrices_a, tt_matrices_b)

  # Dense reference: flatten every 6 x 6 matrix into a 36-element row and
  # take all row-by-row dot products.
  rows_a = tf.reshape(ops.full(tt_matrices_a), (batch_size, n_elements))
  rows_b = tf.reshape(ops.full(tt_matrices_b), (batch_size, n_elements))
  gram = tf.matmul(rows_a, tf.transpose(rows_b))
  expected = tf.squeeze(gram)

  with self.test_session() as sess:
    computed_val, expected_val = sess.run((computed, expected))
    self.assertAllClose(expected_val, computed_val, atol=1e-5, rtol=1e-5)