示例#1
0
 def testInitialization(self):
     """Check that the factor's covariance variable is a length-vocab_size vector.

     Builds the factor from three single-id examples inside a fresh graph,
     instantiates its covariance variables, and checks the static shape of
     ``cov``.
     """
     vocab_size = 5
     with tf.Graph().as_default():
         ids = tf.constant([[0], [1], [4]])
         emb_factor = ff.EmbeddingInputKroneckerFactor((ids,), vocab_size)
         emb_factor.instantiate_cov_variables()
         expected_shape = [vocab_size]
         self.assertEqual(emb_factor.cov.shape.as_list(), expected_shape)
示例#2
0
    def testCovarianceUpdateOp(self):
        """Run the covariance update op and compare it with per-id frequencies.

        The three input ids (0, 1 and 4) each occur once. The expected result
        is therefore 1/3 at those vocabulary positions and 0 elsewhere.
        """
        vocab_size = 5
        with tf.Graph().as_default():
            ids = tf.constant([[0], [1], [4]])
            emb_factor = ff.EmbeddingInputKroneckerFactor((ids,), vocab_size)
            emb_factor.instantiate_cov_variables()
            # NOTE(review): 0.0 is presumably the moving-average decay, so the
            # op's output should reflect only this batch; confirm in ff.
            update_op = emb_factor.make_covariance_update_op(0.0)

            expected = np.array([1., 1., 0., 0., 1.]) / 3.
            with self.test_session() as session:
                session.run(tf.global_variables_initializer())
                self.assertAllClose(expected, session.run(update_op))