コード例 #1
0
 def testInitialization(self):
     """Checks the shape of the factor's covariance variable.

     After instantiate_cov_variables(), the covariance variable returned by
     get_cov_var() is expected to have shape [vocab_size].
     """
     with tf_ops.Graph().as_default():
         input_ids = array_ops.constant([[0], [1], [4]])
         vocab_size = 5
         # Pass the ids as a 1-tuple, the same form testCovarianceUpdateOp
         # uses for this constructor (presumably one entry per tower --
         # confirm against ff.EmbeddingInputKroneckerFactor). Passing a bare
         # Tensor here was inconsistent with that usage.
         factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
         factor.instantiate_cov_variables()
         cov = factor.get_cov_var()
         self.assertEqual(cov.shape.as_list(), [vocab_size])
コード例 #2
0
  def testCovarianceUpdateOp(self):
    """Checks the value produced by the covariance update op.

    Ids 0, 1 and 4 each occur in exactly one of the three input rows, so the
    update op built with argument 0.0 should produce [1, 1, 0, 0, 1] / 3.
    NOTE(review): 0.0 presumably is the EMA decay, making the result the
    fresh batch estimate -- confirm against make_covariance_update_op.
    """
    vocabulary_size = 5
    expected_cov = np.array([1., 1., 0., 0., 1.]) / 3.

    with tf_ops.Graph().as_default():
      ids = array_ops.constant([[0], [1], [4]])
      embedding_factor = ff.EmbeddingInputKroneckerFactor(
          (ids,), vocabulary_size)
      embedding_factor.instantiate_cov_variables()
      update_op = embedding_factor.make_covariance_update_op(0.0)

      with self.test_session() as session:
        # Variables must be initialized before the update op can read them.
        session.run(tf_variables.global_variables_initializer())
        updated_cov = session.run(update_op)
        self.assertAllClose(expected_cov, updated_cov)