def testEmbeddingWrapperWithDynamicRnn(self):
  """Runs an EmbeddingWrapper-wrapped BasicLSTMCell through dynamic_rnn.

  The inputs are int64 ids, while dynamic_rnn is given an explicit float32
  dtype. Evaluating the outputs therefore fails if the output dtype were
  inferred from the integer input.
  """
  with self.test_session() as sess:
    with variable_scope.variable_scope("root"):
      # Literal of shape (1, 2, 1): with dynamic_rnn's default batch-major
      # layout this is one sequence of two steps, each a single id 0.
      token_ids = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64)
      seq_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64)

      lstm = core_rnn_cell_impl.BasicLSTMCell(1, state_is_tuple=True)
      wrapped_cell = core_rnn_cell_impl.EmbeddingWrapper(
          lstm, embedding_classes=1, embedding_size=2)

      rnn_outputs, _ = rnn.dynamic_rnn(
          cell=wrapped_cell,
          inputs=token_ids,
          sequence_length=seq_lengths,
          dtype=dtypes.float32)

      init_op = variables_lib.global_variables_initializer()
      sess.run([init_op])
      # This will fail if output's dtype is inferred from input's.
      sess.run(rnn_outputs)
def testEmbeddingWrapper(self):
  """Smoke-tests one step of an EmbeddingWrapper around a GRUCell.

  Variables created under the "root" scope default to a constant 0.5
  initializer, unless the cell supplies its own. Id 1, out of 3 embedding
  classes, is embedded into 2 dims and fed together with a state of
  [[0.1, 0.1]]. The test checks the new state's shape and the output values.
  """
  with self.test_session() as sess:
    const_init = init_ops.constant_initializer(0.5)
    with variable_scope.variable_scope("root", initializer=const_init):
      # Placeholder-like tensors; their values are overridden via the feed dict.
      ids = array_ops.zeros([1, 1], dtype=dtypes.int32)
      state = array_ops.zeros([1, 2])

      gru = core_rnn_cell_impl.GRUCell(2)
      cell = core_rnn_cell_impl.EmbeddingWrapper(
          gru, embedding_classes=3, embedding_size=2)
      self.assertEqual(cell.output_size, 2)

      output, new_state = cell(ids, state)
      sess.run([variables_lib.global_variables_initializer()])

      feeds = {
          ids.name: np.array([[1]]),
          state.name: np.array([[0.1, 0.1]]),
      }
      output_val, new_state_val = sess.run([output, new_state], feeds)

      self.assertEqual(new_state_val.shape, (1, 2))
      # The numbers in results were not calculated, this is just a smoke test.
      self.assertAllClose(output_val, [[0.17139, 0.17139]])