def testInitialization(self):
    """Check the static shape of the instantiated covariance variable.

    Builds an EmbeddingInputKroneckerFactor from three single-id rows
    (ids 0, 1 and 4) over a vocabulary of size 5. It then asserts that,
    after `instantiate_cov_variables()`, `factor.cov` has static shape
    `[vocab_size]`, i.e. a rank-1 tensor rather than a square matrix.
    """
    num_ids = 5
    graph = tf.Graph()
    # Build everything in a fresh graph so state from other tests cannot leak in.
    with graph.as_default():
        ids = tf.constant([[0], [1], [4]])
        kron_factor = ff.EmbeddingInputKroneckerFactor((ids,), num_ids)
        kron_factor.instantiate_cov_variables()
        self.assertEqual(kron_factor.cov.shape.as_list(), [num_ids])
def testCovarianceUpdateOp(self):
    """Check the value produced by one covariance update op.

    The factor is built from three rows holding ids 0, 1 and 4. It is
    updated once via `make_covariance_update_op(0.0)`. The result must
    equal 1/3 at positions 0, 1 and 4 and 0 at positions 2 and 3.
    """
    num_ids = 5
    # Each of ids 0, 1, 4 occurs in exactly one of the three input rows.
    expected_cov = np.array([1., 1., 0., 0., 1.]) / 3.
    with tf.Graph().as_default():
        ids = tf.constant([[0], [1], [4]])
        kron_factor = ff.EmbeddingInputKroneckerFactor((ids,), num_ids)
        kron_factor.instantiate_cov_variables()
        # NOTE(review): 0.0 is presumably the EMA decay, so the update should
        # overwrite the initial value entirely; confirm against
        # make_covariance_update_op's signature.
        update_op = kron_factor.make_covariance_update_op(0.0)
        # NOTE(review): `test_session` is deprecated in newer TF releases in
        # favor of `cached_session`; migrate if the pinned TF version allows.
        with self.test_session() as session:
            session.run(tf.global_variables_initializer())
            updated_cov = session.run(update_op)
            self.assertAllClose(expected_cov, updated_cov)