  def testEmbeddingVariableForGradientDescent(self):
    ev = embedding_variable_ops.EmbeddingVariable(
        embedding_dim=3,
        ktype=dtypes.int64,
        initializer=init_ops.ones_initializer(dtypes.float32))

    def loss_fn(ev):
      # loss = sum(2 * emb), so d(loss)/d(emb) is 2.0 for every looked-up row.
      emb = embedding_ops.embedding_lookup(
          ev, math_ops.cast([0, 1, 2, 5, 6, 7], dtypes.int64))
      fun = math_ops.multiply(emb, 2.0, name='multiply')
      loss = math_ops.reduce_sum(fun, name='reduce_sum')
      return loss

    gs = training_util.get_or_create_global_step()
    opt = gradient_descent.GradientDescentOptimizer(0.1)
    g_v = opt.compute_gradients(lambda: loss_fn(ev), [ev])
    train_op = opt.apply_gradients(g_v)
    emb = embedding_ops.embedding_lookup(
        ev, math_ops.cast([0, 1, 2, 5, 6, 7], dtypes.int64))
    init = variables.global_variables_initializer()
    self.assertEqual(None, self.evaluate(init))
    self.assertEqual(None, self.evaluate(train_op))
    emb_result = self.evaluate(emb)
    grad_result = self.evaluate(g_v[0][0])
    # One SGD step from the ones initializer: 1.0 - 0.1 * 2.0 == 0.8.
    for i in range(6):
      for j in range(3):
        self.assertAlmostEqual(.8, emb_result[i][j], delta=1e-05)
        self.assertAlmostEqual(2., grad_result.values[i][j], delta=1e-05)
  def testEmbeddingVariableForSaveRestore(self):
    ev = embedding_variable_ops.EmbeddingVariable(
        embedding_dim=2,
        initializer=init_ops.random_normal_initializer(),
        ktype=dtypes.int64)
    var_emb = embedding_ops.embedding_lookup(
        ev, math_ops.cast([0, 1, 2], dtypes.int64))
    loss = math_ops.reduce_sum(var_emb)
    optimizer = gradient_descent.GradientDescentOptimizer(0.1)
    with ops.control_dependencies([var_emb]):
      train_op = optimizer.minimize(loss)
    saver = saver_module.Saver()
    init = variables.global_variables_initializer()
    with session.Session() as sess:
      sess.run([init])
      # Take three optimizer steps, snapshot the embeddings, then checkpoint.
      # Note: "ckpt" is a save-path prefix relative to the working directory.
      sess.run([train_op, var_emb])
      sess.run([train_op, var_emb])
      sess.run([train_op, var_emb])
      save = sess.run(var_emb)
      saver.save(sess, "ckpt")
    with session.Session() as sess:
      saver.restore(sess, "ckpt")
      restore = sess.run(var_emb)
    # The restored embeddings must match the values saved after training.
    for i in range(3):
      for j in range(2):
        self.assertAlmostEqual(save[i][j], restore[i][j], delta=1e-05)
  def testEmbeddingVariableForTypeNotMatch(self):
    # Looking up int64 ids in an EmbeddingVariable keyed by int32 must raise.
    with self.assertRaises(errors.InvalidArgumentError):
      ev = embedding_variable_ops.EmbeddingVariable(
          embedding_dim=3,
          ktype=dtypes.int32,
          initializer=init_ops.ones_initializer(dtypes.float32))
      emb = embedding_ops.embedding_lookup(
          ev, math_ops.cast([0, 1, 2, 5, 6, 7], dtypes.int64))
  def testEmbeddingVariableForGeneralConstInitializer(self):
    ev = embedding_variable_ops.EmbeddingVariable(
        embedding_dim=3,
        ktype=dtypes.int64,
        initializer=init_ops.ones_initializer(dtypes.float32))
    emb = embedding_ops.embedding_lookup(
        ev, math_ops.cast([1, 6], dtypes.int64))
    init = variables.global_variables_initializer()
    self.assertEqual(None, self.evaluate(init))
    # Both looked-up rows should come back as the constant initializer value.
    self.assertAllEqual([[1., 1., 1.]] * 2, self.evaluate(emb))
  def testEmbeddingVariableForGetShape(self):
    ev = embedding_variable_ops.EmbeddingVariable(
        embedding_dim=3, initializer=init_ops.ones_initializer(dtypes.float32))
    emb = embedding_ops.embedding_lookup(
        ev, math_ops.cast([0, 1, 2, 5, 6, 7], dtypes.int32))
    shape = ev.total_count()
    init = variables.global_variables_initializer()
    self.assertEqual(None, self.evaluate(init))
    self.evaluate(emb)
    # Six ids were looked up, each with a 3-dim embedding, so the variable's
    # total shape is [6, 3].
    self.assertAllEqual([6, 3], self.evaluate(shape))
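
  # Not one of the original tests: a minimal usage sketch, assuming the same
  # module aliases used above (embedding_variable_ops, embedding_ops, init_ops,
  # math_ops, dtypes, variables) and only the EmbeddingVariable constructor
  # arguments already exercised in this file. It demonstrates the basic
  # construct -> lookup -> initialize -> read flow.
  def _embeddingVariableLookupSketch(self):
    ev = embedding_variable_ops.EmbeddingVariable(
        embedding_dim=3,
        ktype=dtypes.int64,
        initializer=init_ops.ones_initializer(dtypes.float32))
    emb = embedding_ops.embedding_lookup(
        ev, math_ops.cast([10, 20], dtypes.int64))
    init = variables.global_variables_initializer()
    self.evaluate(init)
    # With the ones initializer, both rows come back as [1., 1., 1.].
    return self.evaluate(emb)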