Example #1
0
 def test_categorical_variable(self):
     """Check that categorical_variable lookups follow the index layout.

     The same [2, 2] int64 placeholder is fed twice: once with a grid of
     indices and once with its transpose. The two embedding results must have
     equal shapes and be transposes of each other along their first two axes.
     """
     random_seed.set_random_seed(42)
     with self.test_session() as session:
         indices = array_ops.placeholder(dtypes.int64, [2, 2])
         lookup = ops.categorical_variable(
             indices, n_classes=5, embedding_size=10, name="my_cat_var")
         session.run(variables.global_variables_initializer())

         def embed(grid):
             # Evaluate the embedding tensor for one feed of index values.
             return session.run(lookup, feed_dict={indices.name: grid})

         row_major = embed([[0, 1], [2, 3]])
         col_major = embed([[0, 2], [1, 3]])
     self.assertEqual(row_major.shape, col_major.shape)
     # Swapping the first two axes of the transposed feed's result must
     # reproduce the original feed's result exactly.
     self.assertAllEqual(np.transpose(col_major, axes=[1, 0, 2]), row_major)
Example #2
0
 def test_categorical_variable(self):
   """Embedding lookups for transposed index grids are transposed results.

   Runs ops.categorical_variable on a [2, 2] placeholder with two index
   grids, the second being the transpose of the first. It then asserts that
   the outputs share a shape and match after swapping their first two axes.
   """
   random_seed.set_random_seed(42)
   index_grids = ([[0, 1], [2, 3]], [[0, 2], [1, 3]])
   with self.test_session() as sess:
     ids = array_ops.placeholder(dtypes.int64, [2, 2])
     embedded = ops.categorical_variable(
         ids, n_classes=5, embedding_size=10, name="my_cat_var")
     sess.run(variables.global_variables_initializer())
     # Evaluate the grids in order: the plain grid first, then its transpose.
     first, second = [
         sess.run(embedded, feed_dict={ids.name: grid})
         for grid in index_grids
     ]
   self.assertEqual(first.shape, second.shape)
   swapped = np.transpose(second, axes=[1, 0, 2])
   self.assertAllEqual(swapped, first)