def test_step_with_greedy_embedding_helper():
    """Run one BasicDecoder step with a GreedyEmbeddingSampler and check outputs."""
    batch_size = 5
    vocabulary_size = 7
    # The cell's logit dimension must equal the vocabulary size.
    cell_depth = vocabulary_size
    input_depth = 10
    start_tokens = np.random.randint(0, vocabulary_size, size=batch_size)
    end_token = 1

    embeddings = np.random.randn(vocabulary_size, input_depth).astype(np.float32)
    embeddings_const = tf.constant(embeddings)
    cell = tf.keras.layers.LSTMCell(vocabulary_size)
    sampler = sampler_py.GreedyEmbeddingSampler()
    initial_state = cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    dec = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
    first_finished, first_inputs, first_state = dec.initialize(
        embeddings_const,
        start_tokens=start_tokens,
        end_token=end_token,
        initial_state=initial_state,
    )

    output_size = dec.output_size
    output_dtype = dec.output_dtype
    assert output_size == basic_decoder.BasicDecoderOutput(
        cell_depth, tf.TensorShape([])
    )
    assert output_dtype == basic_decoder.BasicDecoderOutput(tf.float32, tf.int32)

    step_outputs, step_state, step_next_inputs, step_finished = dec.step(
        tf.constant(0), first_inputs, first_state
    )

    assert len(first_state) == 2
    assert len(step_state) == 2
    assert isinstance(step_outputs, basic_decoder.BasicDecoderOutput)
    assert step_outputs[0].shape == (batch_size, cell_depth)
    assert step_outputs[1].shape == (batch_size,)
    # Both LSTM state tensors (h, c) of the initial and stepped states have the
    # same (batch, depth) shape.
    for state_part in (*first_state, *step_state):
        assert state_part.shape == (batch_size, cell_depth)

    # Greedy sampling picks the argmax of the logits; the next inputs are the
    # embedding rows of the sampled ids, and "finished" marks end-token hits.
    expected_sample_ids = np.argmax(step_outputs.rnn_output, -1)
    expected_step_finished = expected_sample_ids == end_token
    expected_step_next_inputs = embeddings[expected_sample_ids]

    np.testing.assert_equal(np.asanyarray([False] * batch_size), first_finished)
    np.testing.assert_equal(expected_step_finished, step_finished)
    assert step_outputs.sample_id.dtype == output_dtype.sample_id
    np.testing.assert_equal(expected_sample_ids, step_outputs.sample_id)
    np.testing.assert_equal(expected_step_next_inputs, step_next_inputs)
def test_dynamic_decode_tflite_conversion():
    """Convert a greedy decode loop to TFLite and check the batch-size guard.

    Builds a ``tf.function`` that runs ``decoder.dynamic_decode`` with
    ``enable_tflite_convertible=True``, converts its concrete function with the
    TFLite converter, then verifies that tracing with batch size > 1 raises
    ``InvalidArgumentError``.
    """
    if test_utils.is_gpu_available():
        pytest.skip("cpu-only test")
    units = 10
    vocab_size = 20
    cell = tf.keras.layers.LSTMCell(units)
    sampler = sampler_py.GreedyEmbeddingSampler()
    embeddings = tf.random.uniform([vocab_size, units])
    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)

    @tf.function
    def _decode(start_tokens, end_token):
        batch_size = tf.size(start_tokens)
        initial_state = cell.get_initial_state(
            batch_size=batch_size, dtype=tf.float32
        )
        return decoder.dynamic_decode(
            my_decoder,
            maximum_iterations=5,
            enable_tflite_convertible=True,
            decoder_init_input=embeddings,
            decoder_init_kwargs=dict(
                initial_state=initial_state,
                start_tokens=start_tokens,
                end_token=end_token,
            ),
        )

    concrete_function = _decode.get_concrete_function(
        tf.TensorSpec([1], dtype=tf.int32), tf.TensorSpec([], dtype=tf.int32)
    )
    # BUG FIX: the original compared ``tf.__version__[:3] >= "2.7"``
    # lexicographically, which misclassifies two-digit minors: for TF "2.10"
    # the slice is "2.1", and "2.1" < "2.7", so the legacy single-argument
    # converter call was taken on new TF versions. Compare parsed
    # (major, minor) integer tuples instead.
    tf_version = tuple(int(part) for part in tf.__version__.split(".")[:2])
    if tf_version >= (2, 7):
        # TF >= 2.7 expects the originating tf.function as the second argument.
        converter = tf.lite.TFLiteConverter.from_concrete_functions(
            [concrete_function], _decode
        )
    else:
        converter = tf.lite.TFLiteConverter.from_concrete_functions(
            [concrete_function]
        )
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    _ = converter.convert()

    with pytest.raises(tf.errors.InvalidArgumentError, match="batch size"):
        # Batch size > 1 should throw an error.
        _decode.get_concrete_function(
            tf.TensorSpec([2], dtype=tf.int32), tf.TensorSpec([], dtype=tf.int32)
        )
def testStepWithGreedyEmbeddingHelper(self):
    """Single decode step with GreedyEmbeddingSampler, session-style variant."""
    batch_size = 5
    vocabulary_size = 7
    # The cell's logit dimension must equal the vocabulary size.
    cell_depth = vocabulary_size
    input_depth = 10
    start_tokens = np.random.randint(0, vocabulary_size, size=batch_size)
    end_token = 1
    with self.cached_session(use_gpu=True):
        embeddings = np.random.randn(vocabulary_size, input_depth).astype(
            np.float32
        )
        embeddings_const = tf.constant(embeddings)
        cell = tf.keras.layers.LSTMCell(vocabulary_size)
        sampler = sampler_py.GreedyEmbeddingSampler()
        initial_state = cell.get_initial_state(
            batch_size=batch_size, dtype=tf.float32
        )
        dec = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
        first_finished, first_inputs, first_state = dec.initialize(
            embeddings_const,
            start_tokens=start_tokens,
            end_token=end_token,
            initial_state=initial_state,
        )
        output_size = dec.output_size
        output_dtype = dec.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(cell_depth, tf.TensorShape([])),
            output_size,
        )
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(tf.float32, tf.int32), output_dtype
        )
        step_outputs, step_state, step_next_inputs, step_finished = dec.step(
            tf.constant(0), first_inputs, first_state
        )
        batch_size_t = dec.batch_size
        self.assertLen(first_state, 2)
        self.assertLen(step_state, 2)
        self.assertTrue(isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
        self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        # Both LSTM state tensors (h, c) of the initial and stepped states
        # share the same (batch, depth) static shape.
        for state_part in (*first_state, *step_state):
            self.assertEqual((batch_size, cell_depth), state_part.get_shape())
        self.evaluate(tf.compat.v1.global_variables_initializer())
        results = self.evaluate(
            {
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished,
            }
        )
        # Greedy sampling picks the argmax of the logits; the next inputs are
        # the embedding rows of the sampled ids, and "finished" marks
        # end-token hits.
        expected_sample_ids = np.argmax(results["step_outputs"].rnn_output, -1)
        expected_step_finished = expected_sample_ids == end_token
        expected_step_next_inputs = embeddings[expected_sample_ids]
        self.assertAllEqual([False] * batch_size, results["first_finished"])
        self.assertAllEqual(expected_step_finished, results["step_finished"])
        self.assertEqual(
            output_dtype.sample_id, results["step_outputs"].sample_id.dtype
        )
        self.assertAllEqual(expected_sample_ids, results["step_outputs"].sample_id)
        self.assertAllEqual(expected_step_next_inputs, results["step_next_inputs"])