def testBahdanauMonotonicNormalized(self):
    # Normalized Bahdanau monotonic attention ("ones" kernel initializer) with
    # a query layer; the recorded alignment history is compared as well.
    f32 = np.dtype("float32")
    i32 = np.dtype("int32")

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=4.5706983),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=0.0))
    lstm_state = rnn_cell.LSTMStateTuple(
        c=ResultSummary(shape=(5, 9), dtype=f32, mean=1.6005473),
        h=ResultSummary(shape=(5, 9), dtype=f32, mean=0.77863038))
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=lstm_state,
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=7.3326721),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.12258384),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=f32, mean=0.12258384),
        alignment_history=())
    history = ResultSummary(shape=(3, 5, 8), dtype=f32, mean=0.12258384)

    self._testWithAttention(
        wrapper.BahdanauMonotonicAttentionV2,
        decoder_output,
        wrapper_state,
        alignment_history=True,
        expected_final_alignment_history=history,
        create_query_layer=True,
        create_attention_kwargs={
            "kernel_initializer": "ones",
            "normalize": True,
        })
def testLuongNotNormalized(self):
    # Luong attention built without extra constructor kwargs; only the
    # attention mechanism depth (9) is overridden for the shared driver.
    f32 = np.dtype("float32")
    i32 = np.dtype("int32")

    def summary(shape, mean, dtype=f32):
        return ResultSummary(shape=shape, dtype=dtype, mean=mean)

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 0.05481226),
        sample_id=summary((5, 3), 3.13333333, dtype=i32))
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=[
            summary((5, 9), 0.38453412),
            summary((5, 9), 0.5785929),
        ],
        attention=summary((5, 6), 0.16311775),
        time=3,
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=())

    self._testWithAttention(
        wrapper.LuongAttentionV2,
        decoder_output,
        wrapper_state,
        attention_mechanism_depth=9)
def testBahdanauMonotonicNormalized(self):
    # Normalized Bahdanau monotonic attention with a query layer; the cell
    # state is expected as a two-element list of summaries.
    f32 = np.dtype("float32")
    i32 = np.dtype("int32")

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=0.043294173),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=3.53333333))
    cell_summaries = [
        ResultSummary(shape=(5, 9), dtype=f32, mean=0.40034312),
        ResultSummary(shape=(5, 9), dtype=f32, mean=0.5925445),
    ]
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=cell_summaries,
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=0.096119694),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.1211452),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=f32, mean=0.1211452),
        alignment_history=())
    history = ResultSummary(shape=(3, 5, 8), dtype=f32, mean=0.12258384)

    self._testWithAttention(
        wrapper.BahdanauMonotonicAttentionV2,
        decoder_output,
        wrapper_state,
        alignment_history=True,
        expected_final_alignment_history=history,
        create_query_layer=True,
        create_attention_kwargs={
            "kernel_initializer": "ones",
            "normalize": True,
        })
def testLuongMonotonicScaled(self):
    # Scaled Luong monotonic attention (scale=True) with mechanism depth 9;
    # the alignment history is recorded and checked.
    f32 = np.dtype("float32")
    i32 = np.dtype("int32")

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=3.159497),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=0.0))
    lstm_state = rnn_cell.LSTMStateTuple(
        c=ResultSummary(shape=(5, 9), dtype=f32, mean=1.072384),
        h=ResultSummary(shape=(5, 9), dtype=f32, mean=0.50331038))
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=lstm_state,
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=5.3079605),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.11467695),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=f32, mean=0.11467695),
        alignment_history=())
    history = ResultSummary(shape=(3, 5, 8), dtype=f32, mean=0.11899644)

    self._testWithAttention(
        wrapper.LuongMonotonicAttentionV2,
        decoder_output,
        wrapper_state,
        attention_mechanism_depth=9,
        alignment_history=True,
        expected_final_alignment_history=history,
        create_attention_kwargs={"scale": True})
def testNotUseAttentionLayer(self):
    # Bahdanau attention with attention_layer_size=None (no attention layer)
    # and a query layer created by the shared driver.
    f32 = np.dtype("float32")
    i32 = np.dtype("int32")

    def summary(shape, mean, dtype=f32):
        return ResultSummary(shape=shape, dtype=dtype, mean=mean)

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 10), 0.072406612),
        sample_id=summary((5, 3), 3.86666666, dtype=i32))
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=[
            summary((5, 9), 0.61177742),
            summary((5, 9), 1.032002),
        ],
        attention=summary((5, 10), 0.011346335),
        time=3,
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=())

    self._testWithAttention(
        wrapper.BahdanauAttentionV2,
        decoder_output,
        wrapper_state,
        attention_layer_size=None,
        create_query_layer=True,
        create_attention_kwargs={"kernel_initializer": "ones"})
def testBahdanauMonotonicNotNormalized(self):
    # Bahdanau monotonic attention without normalization ("ones" kernel
    # initializer), with a query layer and alignment history enabled.
    f32 = np.dtype("float32")
    i32 = np.dtype("int32")

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=0.041342419),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=3.53333333))
    cell_summaries = [
        ResultSummary(shape=(5, 9), dtype=f32, mean=0.33866978),
        ResultSummary(shape=(5, 9), dtype=f32, mean=0.46913195),
    ]
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=cell_summaries,
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=0.092498459),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.12079944),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=f32, mean=0.12079944),
        alignment_history=())
    history = ResultSummary(shape=(3, 5, 8), dtype=f32, mean=0.121448785067)

    self._testWithAttention(
        wrapper.BahdanauMonotonicAttentionV2,
        decoder_output,
        wrapper_state,
        alignment_history=True,
        expected_final_alignment_history=history,
        create_query_layer=True,
        create_attention_kwargs={"kernel_initializer": "ones"})
def testBahdanauNotNormalized(self):
    # Bahdanau attention without normalization ("ones" kernel initializer),
    # with a query layer and alignment history enabled.
    f32 = np.dtype(np.float32)
    i32 = np.dtype(np.int32)

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=0.051747426),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=3.33333333))
    cell_summaries = [
        ResultSummary(shape=(5, 9), dtype=f32, mean=0.44189346),
        ResultSummary(shape=(5, 9), dtype=f32, mean=0.65429491),
    ]
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=cell_summaries,
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=0.073610783),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.125),
        attention_state=ResultSummary(shape=(5, 8), dtype=f32, mean=0.125),
        alignment_history=())
    history = ResultSummary(shape=(3, 5, 8), dtype=f32, mean=0.125)

    self._testWithAttention(
        wrapper.BahdanauAttentionV2,
        decoder_output,
        wrapper_state,
        alignment_history=True,
        create_query_layer=True,
        expected_final_alignment_history=history,
        create_attention_kwargs={"kernel_initializer": "ones"})
def testBahdanauNormalized(self):
    # Normalized Bahdanau attention ("ones" kernel initializer) with a query
    # layer; no alignment history is requested.
    f32 = np.dtype("float32")
    i32 = np.dtype("int32")

    def summary(shape, mean, dtype=f32):
        return ResultSummary(shape=shape, dtype=dtype, mean=mean)

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 0.047594748),
        sample_id=summary((5, 3), 3.6, dtype=i32))
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=[
            summary((5, 9), 0.41311637),
            summary((5, 9), 0.61683208),
        ],
        attention=summary((5, 6), 0.090581432),
        time=3,
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=())

    self._testWithAttention(
        wrapper.BahdanauAttentionV2,
        decoder_output,
        wrapper_state,
        create_query_layer=True,
        create_attention_kwargs={
            "kernel_initializer": "ones",
            "normalize": True,
        })
def testLuongMonotonicScaled(self):
    # Scaled Luong monotonic attention (scale=True) with mechanism depth 9;
    # the cell state is expected as a two-element list of summaries.
    f32 = np.dtype("float32")
    i32 = np.dtype("int32")

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=0.027387079),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=3.13333333))
    cell_summaries = [
        ResultSummary(shape=(5, 9), dtype=f32, mean=0.32660431),
        ResultSummary(shape=(5, 9), dtype=f32, mean=0.52464348),
    ]
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=cell_summaries,
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=0.089345723),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.11831035),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=f32, mean=0.11831035),
        alignment_history=())
    history = ResultSummary(shape=(3, 5, 8), dtype=f32, mean=0.12194442004)

    self._testWithAttention(
        wrapper.LuongMonotonicAttentionV2,
        decoder_output,
        wrapper_state,
        attention_mechanism_depth=9,
        alignment_history=True,
        expected_final_alignment_history=history,
        create_attention_kwargs={"scale": True})
def testLuongScaled(self):
    # Scaled Luong attention (scale=True) with mechanism depth 9; no
    # alignment history is requested.
    f32 = np.dtype("float32")
    i32 = np.dtype("int32")

    def summary(shape, mean, dtype=f32):
        return ResultSummary(shape=shape, dtype=dtype, mean=mean)

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 2.6605489),
        sample_id=summary((5, 3), 0.0, dtype=i32))
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=summary((5, 9), 0.88403547),
            h=summary((5, 9), 0.37819088)),
        attention=summary((5, 6), 4.0846314),
        time=3,
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=())

    self._testWithAttention(
        wrapper.LuongAttentionV2,
        decoder_output,
        wrapper_state,
        attention_mechanism_depth=9,
        create_attention_kwargs={"scale": True})
def testBahdanauMonotonicNotNormalized(self):
    # Bahdanau monotonic attention without normalization; the cell state is
    # expected as an LSTMStateTuple and alignment history is checked.
    f32 = np.dtype("float32")
    i32 = np.dtype("int32")

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=5.9850435),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=0.0))
    lstm_state = rnn_cell.LSTMStateTuple(
        c=ResultSummary(shape=(5, 9), dtype=f32, mean=1.6752492),
        h=ResultSummary(shape=(5, 9), dtype=f32, mean=0.76052248))
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=lstm_state,
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=8.361186),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.10989678),
        attention_state=ResultSummary(
            shape=(5, 8), dtype=f32, mean=0.10989678),
        alignment_history=())
    history = ResultSummary(shape=(3, 5, 8), dtype=f32, mean=0.117412611)

    self._testWithAttention(
        wrapper.BahdanauMonotonicAttentionV2,
        decoder_output,
        wrapper_state,
        alignment_history=True,
        expected_final_alignment_history=history,
        create_query_layer=True,
        create_attention_kwargs={"kernel_initializer": "ones"})
def testBahdanauNormalized(self):
    # Normalized Bahdanau attention with a query layer; the cell state is
    # expected as an LSTMStateTuple.
    f32 = np.dtype("float32")
    i32 = np.dtype("int32")

    def summary(shape, mean, dtype=f32):
        return ResultSummary(shape=shape, dtype=dtype, mean=mean)

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=summary((5, 3, 6), 3.9548259),
        sample_id=summary((5, 3), 0.0, dtype=i32))
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=rnn_cell.LSTMStateTuple(
            c=summary((5, 9), 1.4652209),
            h=summary((5, 9), 0.70997983)),
        attention=summary((5, 6), 6.3075728),
        time=3,
        alignments=summary((5, 8), 0.125),
        attention_state=summary((5, 8), 0.125),
        alignment_history=())

    self._testWithAttention(
        wrapper.BahdanauAttentionV2,
        decoder_output,
        wrapper_state,
        create_query_layer=True,
        create_attention_kwargs={
            "kernel_initializer": "ones",
            "normalize": True,
        })
def testBahdanauNotNormalized(self):
    # Bahdanau attention without normalization; the cell state is expected as
    # an LSTMStateTuple and alignment history is checked.
    f32 = np.dtype(np.float32)
    i32 = np.dtype(np.int32)

    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=ResultSummary(shape=(5, 3, 6), dtype=f32, mean=4.8290324),
        sample_id=ResultSummary(shape=(5, 3), dtype=i32, mean=0))
    lstm_state = rnn_cell.LSTMStateTuple(
        c=ResultSummary(shape=(5, 9), dtype=f32, mean=1.6432636),
        h=ResultSummary(shape=(5, 9), dtype=f32, mean=0.75866824))
    wrapper_state = wrapper.AttentionWrapperState(
        cell_state=lstm_state,
        attention=ResultSummary(shape=(5, 6), dtype=f32, mean=6.7445569),
        time=3,
        alignments=ResultSummary(shape=(5, 8), dtype=f32, mean=0.125),
        attention_state=ResultSummary(shape=(5, 8), dtype=f32, mean=0.125),
        alignment_history=())
    history = ResultSummary(shape=(3, 5, 8), dtype=f32, mean=0.125)

    self._testWithAttention(
        wrapper.BahdanauAttentionV2,
        decoder_output,
        wrapper_state,
        alignment_history=True,
        create_query_layer=True,
        expected_final_alignment_history=history,
        create_attention_kwargs={"kernel_initializer": "ones"})
def testStepWithSampleEmbeddingHelper(self):
    # One BasicDecoder step with SampleEmbeddingHelper: check static shapes and
    # dtypes, then that finished flags and next inputs follow the sampled ids.
    num_rows = 5
    vocab_size = 7
    cell_depth = vocab_size  # cell's logits must match vocabulary size
    embedding_dim = 10
    end_token = 1
    np.random.seed(0)
    start_tokens = np.random.randint(0, vocab_size, size=num_rows)

    with self.session(use_gpu=True) as sess, variable_scope.variable_scope(
            "testStepWithSampleEmbeddingHelper",
            initializer=init_ops.constant_initializer(0.01)):
        embeddings = np.random.randn(vocab_size, embedding_dim).astype(
            np.float32)
        cell = rnn_cell.LSTMCell(vocab_size)
        helper = helper_py.SampleEmbeddingHelper(
            embeddings, start_tokens, end_token, seed=0)
        zero_state = cell.zero_state(dtype=dtypes.float32, batch_size=num_rows)
        decoder = basic_decoder.BasicDecoder(
            cell=cell, helper=helper, initial_state=zero_state)

        output_size = decoder.output_size
        output_dtype = decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(
                cell_depth, tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            output_dtype)

        first_finished, first_inputs, first_state = decoder.initialize()
        step_result = decoder.step(
            constant_op.constant(0), first_inputs, first_state)
        step_outputs, step_state, step_next_inputs, step_finished = step_result
        batch_size_t = decoder.batch_size

        for state in (first_state, step_state):
            self.assertTrue(isinstance(state, rnn_cell.LSTMStateTuple))
        self.assertTrue(
            isinstance(step_outputs, basic_decoder.BasicDecoderOutput))

        state_shape = (num_rows, cell_depth)
        self.assertEqual(state_shape, step_outputs[0].get_shape())
        self.assertEqual((num_rows,), step_outputs[1].get_shape())
        for state in (first_state, step_state):
            for component in (state[0], state[1]):
                self.assertEqual(state_shape, component.get_shape())

        sess.run(variables.global_variables_initializer())
        fetches = {
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished,
        }
        results = sess.run(fetches)

        sample_ids = results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
        # A row is finished exactly when it sampled the end token, and its
        # next input is the embedding row of the sampled id.
        self.assertAllEqual(sample_ids == end_token, results["step_finished"])
        self.assertAllEqual(
            embeddings[sample_ids], results["step_next_inputs"])
def _testStepWithScheduledOutputTrainingHelper(self, use_next_input_layer):
    """Check one `BasicDecoder` step driven by `ScheduledOutputTrainingHelper`.

    Builds an LSTM cell and a helper with sampling probability 0.5, runs
    `initialize()` plus a single `step()` at time 0, and verifies static
    shapes, the finished flags, and that each batch row's next input is
    either the (optionally projected) cell output, for rows whose
    `sample_id` is nonzero, or the ground-truth input at time step 1.

    Args:
      use_next_input_layer: When true, the cell depth is 6 and a bias-free
        `Dense` layer of width `input_depth` is passed to the helper as
        `next_input_layer`; otherwise the cell depth equals `input_depth`.
    """
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = input_depth
    if use_next_input_layer:
        cell_depth = 6

    with self.test_session() as sess:
        inputs = np.random.randn(batch_size, max_time,
                                 input_depth).astype(np.float32)
        cell = core_rnn_cell.LSTMCell(cell_depth)
        half = constant_op.constant(0.5)

        next_input_layer = None
        if use_next_input_layer:
            next_input_layer = layers_core.Dense(input_depth, use_bias=False)

        helper = helper_py.ScheduledOutputTrainingHelper(
            inputs=inputs,
            sequence_length=sequence_length,
            sampling_probability=half,
            time_major=False,
            next_input_layer=next_input_layer)

        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))

        output_size = my_decoder.output_size
        output_dtype = my_decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(
                cell_depth, tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            output_dtype)

        (first_finished, first_inputs, first_state) = my_decoder.initialize()
        (step_outputs, step_state, step_next_inputs,
         step_finished) = my_decoder.step(
             constant_op.constant(0), first_inputs, first_state)

        if use_next_input_layer:
            # Reference projection of the raw cell output, fetched alongside
            # the decoder tensors and compared against the next inputs below.
            output_after_next_input_layer = next_input_layer(
                step_outputs.rnn_output)

        batch_size_t = my_decoder.batch_size

        self.assertIsInstance(first_state, core_rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_state, core_rnn_cell.LSTMStateTuple)
        self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
        self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
        self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

        sess.run(variables.global_variables_initializer())

        fetches = {
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished
        }
        if use_next_input_layer:
            fetches[
                "output_after_next_input_layer"] = output_after_next_input_layer

        sess_results = sess.run(fetches)

        self.assertAllEqual([False, False, False, False, True],
                            sess_results["first_finished"])
        self.assertAllEqual([False, False, False, True, True],
                            sess_results["step_finished"])

        sample_ids = sess_results["step_outputs"].sample_id
        batch_where_not_sampling = np.where(np.logical_not(sample_ids))
        batch_where_sampling = np.where(sample_ids)

        if use_next_input_layer:
            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_sampling],
                sess_results["output_after_next_input_layer"]
                [batch_where_sampling])
        else:
            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_sampling],
                sess_results["step_outputs"].rnn_output[batch_where_sampling])

        # `batch_where_not_sampling` is the 1-tuple returned by `np.where`, so
        # `inputs[batch_where_not_sampling, 1]` has shape
        # (1, k, input_depth). The singleton axis is axis 0; squeezing axis 1
        # raises ValueError whenever the number of non-sampled rows k != 1.
        self.assertAllClose(
            sess_results["step_next_inputs"][batch_where_not_sampling],
            np.squeeze(inputs[batch_where_not_sampling, 1], axis=0))
def testStepWithInferenceHelperCategorical(self):
    # One BasicDecoder step with an InferenceHelper that samples categorically
    # from the logits and feeds back a one-hot encoding of the sampled id.
    batch_size = 5
    vocabulary_size = 7
    cell_depth = vocabulary_size
    start_token = 0
    end_token = 6

    start_inputs = array_ops.one_hot(
        np.ones(batch_size) * start_token, vocabulary_size)

    def sample_fn(x):
        # The sample function samples categorically from the logits.
        return helper_py.categorical_sample(logits=x)

    def next_inputs_fn(x):
        # The next inputs are a one-hot encoding of the sampled labels.
        return array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32)

    def end_fn(sample_ids):
        return math_ops.equal(sample_ids, end_token)

    with self.session(use_gpu=True) as sess, variable_scope.variable_scope(
            "testStepWithInferenceHelper",
            initializer=init_ops.constant_initializer(0.01)):
        cell = rnn_cell.LSTMCell(vocabulary_size)
        helper = helper_py.InferenceHelper(
            sample_fn,
            sample_shape=(),
            sample_dtype=dtypes.int32,
            start_inputs=start_inputs,
            end_fn=end_fn,
            next_inputs_fn=next_inputs_fn)
        zero_state = cell.zero_state(
            dtype=dtypes.float32, batch_size=batch_size)
        decoder = basic_decoder.BasicDecoder(
            cell=cell, helper=helper, initial_state=zero_state)

        output_size = decoder.output_size
        output_dtype = decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(
                cell_depth, tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            output_dtype)

        first_finished, first_inputs, first_state = decoder.initialize()
        step_result = decoder.step(
            constant_op.constant(0), first_inputs, first_state)
        step_outputs, step_state, step_next_inputs, step_finished = step_result
        batch_size_t = decoder.batch_size

        for state in (first_state, step_state):
            self.assertTrue(isinstance(state, rnn_cell.LSTMStateTuple))
        self.assertTrue(
            isinstance(step_outputs, basic_decoder.BasicDecoderOutput))

        state_shape = (batch_size, cell_depth)
        self.assertEqual(state_shape, step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        for state in (first_state, step_state):
            for component in (state[0], state[1]):
                self.assertEqual(state_shape, component.get_shape())

        sess.run(variables.global_variables_initializer())
        fetches = {
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished,
        }
        results = sess.run(fetches)

        sample_ids = results["step_outputs"].sample_id
        self.assertEqual(output_dtype.sample_id, sample_ids.dtype)

        # Build the expected one-hot rows by hand from the sampled ids.
        one_hot_rows = np.zeros((batch_size, vocabulary_size))
        one_hot_rows[np.arange(batch_size), sample_ids] = 1.0

        self.assertAllEqual(sample_ids == end_token, results["step_finished"])
        self.assertAllEqual(one_hot_rows, results["step_next_inputs"])
def testBahndahauNormalized(self):
    # Normalized Bahdanau attention (attention_r_initializer=2.0) compared
    # against full literal arrays for the decoder output and wrapper state.
    mechanism_factory = functools.partial(
        wrapper.BahdanauAttention, normalize=True, attention_r_initializer=2.0)

    rnn_output = np.array(
        [[[1.72670335e-02, -5.83671592e-03, 6.38638902e-03,
           -8.11776379e-04, 1.12681929e-03, -1.24236047e-02],
          [1.75918192e-02, -5.73426578e-03, 6.29768707e-03,
           -8.63141613e-04, 2.03352375e-03, -1.21420780e-02],
          [1.72424167e-02, -5.66471322e-03, 6.63427915e-03,
           -6.23903936e-04, 1.68706616e-03, -1.22524602e-02]],
         [[1.79958157e-02, -9.80986748e-03, 4.73218597e-03,
           -3.89962713e-03, 1.41502675e-02, -1.48344040e-02],
          [1.82184577e-02, -9.88379307e-03, 4.90130857e-03,
           -3.91892251e-03, 1.36479288e-02, -1.53291579e-02],
          [1.83001235e-02, -1.00617753e-02, 4.97077405e-03,
           -3.94908339e-03, 1.37211196e-02, -1.52311027e-02]],
         [[7.93476030e-03, -8.46967567e-03, -7.16930721e-04,
           4.37953044e-04, 1.04503892e-03, -1.82424393e-02],
          [7.90629163e-03, -8.48819874e-03, -5.57833235e-04,
           5.02390554e-04, 6.79406337e-04, -1.84837580e-02],
          [8.14734399e-03, -8.23053624e-03, -5.92814526e-04,
           4.16347990e-04, 1.29250437e-03, -1.84548404e-02]],
         [[1.21026095e-02, -1.26739489e-02, 1.78718648e-04,
           2.68748170e-03, 7.80996867e-03, -9.69076063e-04],
          [1.17978491e-02, -1.32678337e-02, 6.00410858e-05,
           2.66301399e-03, 7.00691342e-03, -1.10030361e-03],
          [1.15651665e-02, -1.30795036e-02, -2.74205930e-04,
           2.48012133e-03, 6.94250735e-03, -8.47495161e-04]],
         [[1.02377674e-02, -8.72955937e-03, 1.22555892e-03,
           2.03830865e-03, 8.93574394e-03, -7.28237582e-03],
          [1.05115287e-02, -8.92531779e-03, 1.14568521e-03,
           1.91635895e-03, 8.94328393e-03, -7.39541650e-03],
          [1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
           2.06834078e-03, 9.36511997e-03, -7.64556089e-03]]],
        dtype=np.float32)
    # Every sampled id in the expected output is 0.
    sample_id = np.zeros((5, 3), dtype=np.int32)
    decoder_output = basic_decoder.BasicDecoderOutput(
        rnn_output=rnn_output, sample_id=sample_id)

    cell_c = np.array(
        [[-0.02209264, -0.00794879, -0.00157153, 0.01614309, -0.01383773,
          -0.00750943, -0.00824213, -0.01210296, 0.01794949],
         [0.01726926, -0.01418139, -0.0040099, 0.0319339, -0.03545783,
          -0.02142831, -0.00609501, -0.00195033, -0.01938949],
         [-0.01159083, 0.0087524, -0.01639001, -0.01400012, 0.01342422,
          -0.01041037, 0.00620991, -0.00960796, -0.00650131],
         [-0.04763237, -0.01192762, -0.00019377, 0.04103839, -0.00138058,
          0.02126443, -0.02793917, -0.05467755, -0.02912025],
         [0.02241185, -0.00141741, 0.01911988, 0.00547728, -0.01280068,
          -0.00307024, -0.00494239, 0.02169247, 0.01631995]],
        dtype=np.float32)
    cell_h = np.array(
        [[-1.10821165e-02, -3.92766716e-03, -7.99638336e-04,
          7.92923011e-03, -7.04019284e-03, -3.77124036e-03,
          -4.19876305e-03, -6.17261464e-03, 8.95325281e-03],
         [8.60597286e-03, -7.16368994e-03, -1.94644753e-03,
          1.62479617e-02, -1.76739115e-02, -1.06403306e-02,
          -3.01484042e-03, -9.74688213e-04, -9.96260438e-03],
         [-5.78098884e-03, 4.48751403e-03, -8.12216662e-03,
          -6.94991415e-03, 6.72604749e-03, -5.10144979e-03,
          3.08637507e-03, -4.71517537e-03, -3.20256175e-03],
         [-2.38018110e-02, -5.89307398e-03, -9.74484938e-05,
          2.01694984e-02, -6.82370039e-04, 1.07099237e-02,
          -1.42087601e-02, -2.70793457e-02, -1.44684138e-02],
         [1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
          2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
          -2.48806458e-03, 1.10462429e-02, 7.97568541e-03]],
        dtype=np.float32)
    attention = np.array(
        [[0.01724242, -0.00566471, 0.00663428, -0.0006239, 0.00168707,
          -0.01225246],
         [0.01830012, -0.01006178, 0.00497077, -0.00394908, 0.01372112,
          -0.0152311],
         [0.00814734, -0.00823054, -0.00059281, 0.00041635, 0.0012925,
          -0.01845484],
         [0.01156517, -0.0130795, -0.00027421, 0.00248012, 0.00694251,
          -0.0008475],
         [0.01073981, -0.00856867, 0.00152354, 0.00206834, 0.00936512,
          -0.00764556]],
        dtype=np.float32)
    wrapper_state = wrapper.DynamicAttentionWrapperState(
        cell_state=core_rnn_cell.LSTMStateTuple(c=cell_c, h=cell_h),
        attention=attention)

    self._testWithAttention(mechanism_factory, decoder_output, wrapper_state)
def testStepWithTrainingHelper(self):
    # One BasicDecoder step with TrainingHelper: check static shapes/dtypes,
    # the finished flags, and that sample ids are the argmax of the outputs.
    sequence_length = [3, 4, 3, 1, 0]
    batch_size = 5
    max_time = 8
    input_depth = 7
    cell_depth = 10

    with self.test_session() as sess:
        inputs = np.random.randn(batch_size, max_time, input_depth).astype(
            np.float32)
        cell = core_rnn_cell.LSTMCell(cell_depth)
        helper = helper_py.TrainingHelper(
            inputs, sequence_length, time_major=False)
        zero_state = cell.zero_state(
            dtype=dtypes.float32, batch_size=batch_size)
        decoder = basic_decoder.BasicDecoder(
            cell=cell, helper=helper, initial_state=zero_state)

        output_size = decoder.output_size
        output_dtype = decoder.output_dtype
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(
                cell_depth, tensor_shape.TensorShape([])),
            output_size)
        self.assertEqual(
            basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
            output_dtype)

        first_finished, first_inputs, first_state = decoder.initialize()
        step_result = decoder.step(
            constant_op.constant(0), first_inputs, first_state)
        step_outputs, step_state, step_next_inputs, step_finished = step_result
        batch_size_t = decoder.batch_size

        for state in (first_state, step_state):
            self.assertTrue(isinstance(state, core_rnn_cell.LSTMStateTuple))
        self.assertTrue(
            isinstance(step_outputs, basic_decoder.BasicDecoderOutput))

        state_shape = (batch_size, cell_depth)
        self.assertEqual(state_shape, step_outputs[0].get_shape())
        self.assertEqual((batch_size,), step_outputs[1].get_shape())
        for state in (first_state, step_state):
            for component in (state[0], state[1]):
                self.assertEqual(state_shape, component.get_shape())

        sess.run(variables.global_variables_initializer())
        fetches = {
            "batch_size": batch_size_t,
            "first_finished": first_finished,
            "first_inputs": first_inputs,
            "first_state": first_state,
            "step_outputs": step_outputs,
            "step_state": step_state,
            "step_next_inputs": step_next_inputs,
            "step_finished": step_finished,
        }
        results = sess.run(fetches)

        self.assertAllEqual([False, False, False, False, True],
                            results["first_finished"])
        self.assertAllEqual([False, False, False, True, True],
                            results["step_finished"])
        decoded = results["step_outputs"]
        self.assertAllEqual(
            np.argmax(decoded.rnn_output, -1), decoded.sample_id)
def testLuongNormalized(self):
  """Golden-value test for `wrapper.LuongAttention` with `normalize=True`.

  Builds the attention mechanism through `functools.partial` (also passing
  `attention_r_initializer=2.0`) and hands it, together with the expected
  decoder output and the expected final attention-wrapper state, to
  `self._testWithAttention` with `attention_mechanism_depth=9`.

  The numeric literals below are reference values; how they are compared
  (tolerances etc.) is decided by `_testWithAttention`, which is not part
  of this block.
  """
  create_attention_mechanism = functools.partial(
      wrapper.LuongAttention, normalize=True, attention_r_initializer=2.0)

  # Short aliases so the literals below read like a numpy repr.
  array = np.array
  float32 = np.float32
  int32 = np.int32

  # Expected decoder output: rnn_output literal has shape (5, 3, 6) and
  # sample_id literal has shape (5, 3) (all zeros).
  expected_final_output = basic_decoder.BasicDecoderOutput(
      rnn_output=array(
          [[[1.23956744e-02, -6.88115368e-03, 3.15234554e-03,
             -1.97300944e-03, 4.79680905e-03, -1.38076628e-02],
            [1.28376717e-02, -6.78718928e-03, 3.07988771e-03,
             -2.03956687e-03, 5.68403490e-03, -1.35601182e-02],
            [1.23463338e-02, -6.76322030e-03, 3.28891934e-03,
             -1.86874042e-03, 5.47897862e-03, -1.37654068e-02]],
           [[1.54412268e-02, -1.07613346e-02, 4.43824846e-03,
             -8.81063985e-04, 1.26828086e-02, -1.21067995e-02],
            [1.57206059e-02, -1.08218864e-02, 4.61952807e-03,
             -9.61483689e-04, 1.22140013e-02, -1.26614980e-02],
            [1.57821011e-02, -1.09842420e-02, 4.66934917e-03,
             -9.85997496e-04, 1.22719472e-02, -1.25438003e-02]],
           [[9.27361846e-03, -9.66077764e-03, -9.69522633e-04,
             1.48308463e-05, 3.88664147e-03, -1.64083000e-02],
            [9.26287938e-03, -9.74234194e-03, -8.32488062e-04,
             5.83778601e-05, 3.52663640e-03, -1.66827720e-02],
            [9.50474478e-03, -9.49789397e-03, -8.71829456e-04,
             -3.09986062e-05, 4.13423358e-03, -1.66635048e-02]],
           [[1.21398102e-02, -1.27454493e-02, 1.57688977e-04,
             2.70034792e-03, 7.79653806e-03, -8.36936757e-04],
            [1.18234595e-02, -1.33170560e-02, 4.55579720e-05,
             2.67185434e-03, 6.99766818e-03, -1.00935437e-03],
            [1.16009805e-02, -1.31483339e-02, -2.94458936e-04,
             2.49248254e-03, 6.92958105e-03, -7.20315147e-04]],
           [[1.02377674e-02, -8.72955937e-03, 1.22555892e-03,
             2.03830865e-03, 8.93574394e-03, -7.28237582e-03],
            [1.05115287e-02, -8.92531779e-03, 1.14568521e-03,
             1.91635895e-03, 8.94328393e-03, -7.39541650e-03],
            [1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
             2.06834078e-03, 9.36511997e-03, -7.64556089e-03]]],
          dtype=float32),
      sample_id=array(
          [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
          dtype=int32))

  # Expected final state: LSTM c/h literals have shape (5, 9); the
  # attention literal has shape (5, 6).
  expected_final_state = wrapper.DynamicAttentionWrapperState(
      cell_state=core_rnn_cell.LSTMStateTuple(
          c=array(
              [[-0.02204949, -0.00805957, -0.001603, 0.01609283, -0.01380462,
                -0.0074945, -0.00816895, -0.01210009, 0.01795324],
               [0.01727016, -0.01420708, -0.00399973, 0.03195432, -0.03547529,
                -0.02138673, -0.00610332, -0.00191565, -0.01937822],
               [-0.01160676, 0.00876512, -0.01641791, -0.01400807, 0.01347767,
                -0.01036341, 0.00627499, -0.00963627, -0.00650573],
               [-0.04763342, -0.01192671, -0.00019402, 0.04103871, -0.00138017,
                0.02126611, -0.02793773, -0.05467714, -0.02912043],
               [0.02241185, -0.00141741, 0.01911988, 0.00547728, -0.01280068,
                -0.00307024, -0.00494239, 0.02169247, 0.01631995]],
              dtype=float32),
          h=array(
              [[-1.10610286e-02, -3.98253463e-03, -8.15684092e-04,
                7.90454168e-03, -7.02364743e-03, -3.76377185e-03,
                -4.16135695e-03, -6.17104582e-03, 8.95532966e-03],
               [8.60653073e-03, -7.17685232e-03, -1.94147974e-03,
                1.62585936e-02, -1.76823437e-02, -1.06195193e-02,
                -3.01911240e-03, -9.57308919e-04, -9.95720550e-03],
               [-5.78888878e-03, 4.49400023e-03, -8.13617278e-03,
                -6.95386063e-03, 6.75271638e-03, -5.07823005e-03,
                3.11873178e-03, -4.72912844e-03, -3.20472987e-03],
               [-2.38023344e-02, -5.89262368e-03, -9.75721487e-05,
                2.01696623e-02, -6.82163402e-04, 1.07107637e-02,
                -1.42080421e-02, -2.70791352e-02, -1.44685050e-02],
               [1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
                2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
                -2.48806458e-03, 1.10462429e-02, 7.97568541e-03]],
              dtype=float32)),
      # Each row here repeats the third (last) row of the corresponding
      # rnn_output entry above.
      attention=array(
          [[1.23463338e-02, -6.76322030e-03, 3.28891934e-03,
            -1.86874042e-03, 5.47897862e-03, -1.37654068e-02],
           [1.57821011e-02, -1.09842420e-02, 4.66934917e-03,
            -9.85997496e-04, 1.22719472e-02, -1.25438003e-02],
           [9.50474478e-03, -9.49789397e-03, -8.71829456e-04,
            -3.09986062e-05, 4.13423358e-03, -1.66635048e-02],
           [1.16009805e-02, -1.31483339e-02, -2.94458936e-04,
            2.49248254e-03, 6.92958105e-03, -7.20315147e-04],
           [1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
            2.06834078e-03, 9.36511997e-03, -7.64556089e-03]],
          dtype=float32))
  self._testWithAttention(create_attention_mechanism,
                          expected_final_output,
                          expected_final_state,
                          attention_mechanism_depth=9)
def testLuongNotNormalized(self):
  """Golden-value test for `wrapper.LuongAttention` with default arguments.

  Passes the attention class itself as the factory, together with the
  expected decoder output and the expected final attention-wrapper state,
  to `self._testWithAttention` with `attention_mechanism_depth=9`.

  The numeric literals below are reference values; how they are compared
  is decided by `_testWithAttention`, which is not part of this block.
  """
  create_attention_mechanism = wrapper.LuongAttention

  # Short aliases so the literals below read like a numpy repr.
  array = np.array
  float32 = np.float32
  int32 = np.int32

  # Expected decoder output: rnn_output literal has shape (5, 3, 6) and
  # sample_id literal has shape (5, 3) (all zeros).
  expected_final_output = basic_decoder.BasicDecoderOutput(
      rnn_output=array(
          [[[1.23641128e-02, -6.82715839e-03, 3.24165262e-03,
             -1.90772023e-03, 4.69654519e-03, -1.37025211e-02],
            [1.29463980e-02, -6.79699238e-03, 3.10124992e-03,
             -2.02869414e-03, 5.66399656e-03, -1.35517996e-02],
            [1.22659411e-02, -6.81970268e-03, 3.15135531e-03,
             -1.96937821e-03, 5.62768336e-03, -1.39173865e-02]],
           [[1.53944232e-02, -1.07725551e-02, 4.42822604e-03,
             -8.30623554e-04, 1.26549732e-02, -1.20573286e-02],
            [1.57453734e-02, -1.08157266e-02, 4.62466478e-03,
             -9.88351414e-04, 1.22286947e-02, -1.26876952e-02],
            [1.57857724e-02, -1.09536834e-02, 4.64798324e-03,
             -1.01319887e-03, 1.22695938e-02, -1.25500849e-02]],
           [[9.23123397e-03, -9.42669343e-03, -9.09919385e-04,
             6.09827694e-05, 3.90436035e-03, -1.63374804e-02],
            [9.22935922e-03, -9.57853813e-03, -7.92966573e-04,
             8.89014918e-05, 3.52671882e-03, -1.66499857e-02],
            [9.49526206e-03, -9.39475093e-03, -8.49372707e-04,
             -1.72815053e-05, 4.16132808e-03, -1.66336838e-02]],
           [[1.21248290e-02, -1.27166547e-02, 1.66158192e-04,
             2.69516627e-03, 7.80194718e-03, -8.90152063e-04],
            [1.17861275e-02, -1.32453050e-02, 6.66640699e-05,
             2.65894993e-03, 7.01114535e-03, -1.14195189e-03],
            [1.15833860e-02, -1.31145213e-02, -2.84505659e-04,
             2.48642010e-03, 6.93593081e-03, -7.82784075e-04]],
           [[1.02377674e-02, -8.72955937e-03, 1.22555892e-03,
             2.03830865e-03, 8.93574394e-03, -7.28237582e-03],
            [1.05115287e-02, -8.92531779e-03, 1.14568521e-03,
             1.91635895e-03, 8.94328393e-03, -7.39541650e-03],
            [1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
             2.06834078e-03, 9.36511997e-03, -7.64556089e-03]]],
          dtype=float32),
      sample_id=array(
          [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
          dtype=int32))

  # Expected final state: LSTM c/h literals have shape (5, 9); the
  # attention literal has shape (5, 6).
  expected_final_state = wrapper.DynamicAttentionWrapperState(
      cell_state=core_rnn_cell.LSTMStateTuple(
          c=array(
              [[-0.02204997, -0.00805805, -0.00160245, 0.01609369, -0.01380494,
                -0.00749439, -0.00817, -0.01209992, 0.01795316],
               [0.01727016, -0.01420713, -0.00399972, 0.03195436, -0.03547532,
                -0.02138666, -0.00610335, -0.00191557, -0.01937821],
               [-0.01160429, 0.00876595, -0.01641685, -0.01400784, 0.01348004,
                -0.01036458, 0.00627241, -0.00963544, -0.00650568],
               [-0.04763246, -0.01192755, -0.00019379, 0.04103841, -0.00138055,
                0.02126456, -0.02793905, -0.0546775, -0.02912027],
               [0.02241185, -0.00141741, 0.01911988, 0.00547728, -0.01280068,
                -0.00307024, -0.00494239, 0.02169247, 0.01631995]],
              dtype=float32),
          h=array(
              [[-1.10612623e-02, -3.98178305e-03, -8.15406092e-04,
                7.90496264e-03, -7.02379830e-03, -3.76371504e-03,
                -4.16189339e-03, -6.17096573e-03, 8.95528216e-03],
               [8.60652886e-03, -7.17687514e-03, -1.94147555e-03,
                1.62586085e-02, -1.76823605e-02, -1.06194830e-02,
                -3.01912241e-03, -9.57269047e-04, -9.95719433e-03],
               [-5.78764686e-03, 4.49441886e-03, -8.13564472e-03,
                -6.95375400e-03, 6.75391173e-03, -5.07880514e-03,
                3.11744539e-03, -4.72871540e-03, -3.20470310e-03],
               [-2.38018595e-02, -5.89303859e-03, -9.74571449e-05,
                2.01695058e-02, -6.82353624e-04, 1.07099945e-02,
                -1.42086931e-02, -2.70793252e-02, -1.44684194e-02],
               [1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
                2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
                -2.48806458e-03, 1.10462429e-02, 7.97568541e-03]],
              dtype=float32)),
      # Each row here repeats the third (last) row of the corresponding
      # rnn_output entry above.
      attention=array(
          [[1.22659411e-02, -6.81970268e-03, 3.15135531e-03,
            -1.96937821e-03, 5.62768336e-03, -1.39173865e-02],
           [1.57857724e-02, -1.09536834e-02, 4.64798324e-03,
            -1.01319887e-03, 1.22695938e-02, -1.25500849e-02],
           [9.49526206e-03, -9.39475093e-03, -8.49372707e-04,
            -1.72815053e-05, 4.16132808e-03, -1.66336838e-02],
           [1.15833860e-02, -1.31145213e-02, -2.84505659e-04,
            2.48642010e-03, 6.93593081e-03, -7.82784075e-04],
           [1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
            2.06834078e-03, 9.36511997e-03, -7.64556089e-03]],
          dtype=float32))
  self._testWithAttention(create_attention_mechanism,
                          expected_final_output,
                          expected_final_state,
                          attention_mechanism_depth=9)
def _testStepWithScheduledOutputTrainingHelper(
    self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs):
  """Runs one `BasicDecoder` step with a `ScheduledOutputTrainingHelper`.

  Args:
    sampling_probability: Value wrapped in a constant tensor and passed to
      the helper as its sampling probability.
    use_next_inputs_fn: If truthy, the helper is given a `next_inputs_fn`
      that one-hot encodes the argmax of the cell outputs; otherwise `None`.
    use_auxiliary_inputs: If truthy, random auxiliary inputs of depth 4 are
      generated and passed to the helper; otherwise `None`.

  After the step, rows whose sample id is non-zero are expected to receive
  the (optionally transformed) RNN output as next input, the remaining rows
  the ground-truth input at time 1; both are concatenated with the
  time-1 auxiliary inputs when those are in use.
  """
  seq_lens = [3, 4, 3, 1, 0]
  batch_size = 5
  max_time = 8
  input_depth = 7
  cell_depth = input_depth
  state_shape = (batch_size, cell_depth)

  # Auxiliary inputs are drawn before the main inputs, as in the original
  # ordering of the random draws.
  auxiliary_inputs = None
  if use_auxiliary_inputs:
    auxiliary_input_depth = 4
    auxiliary_inputs = np.random.randn(
        batch_size, max_time, auxiliary_input_depth).astype(np.float32)

  with self.session(use_gpu=True) as sess:
    inputs = np.random.randn(batch_size, max_time, input_depth)
    inputs = inputs.astype(np.float32)
    lstm = rnn_cell.LSTMCell(cell_depth)
    sampling_probability = constant_op.constant(sampling_probability)

    next_inputs_fn = None
    if use_next_inputs_fn:

      def next_inputs_fn(outputs):
        # Use deterministic function for test.
        return array_ops.one_hot(
            math_ops.argmax(outputs, axis=1), cell_depth,
            dtype=dtypes.float32)

    scheduled_helper = helper_py.ScheduledOutputTrainingHelper(
        inputs=inputs,
        sequence_length=seq_lens,
        sampling_probability=sampling_probability,
        time_major=False,
        next_inputs_fn=next_inputs_fn,
        auxiliary_inputs=auxiliary_inputs)
    zero_state = lstm.zero_state(dtype=dtypes.float32, batch_size=batch_size)
    decoder = basic_decoder.BasicDecoder(
        cell=lstm, helper=scheduled_helper, initial_state=zero_state)

    output_size = decoder.output_size
    output_dtype = decoder.output_dtype
    self.assertEqual(
        basic_decoder.BasicDecoderOutput(cell_depth,
                                         tensor_shape.TensorShape([])),
        output_size)
    self.assertEqual(
        basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
        output_dtype)

    first_finished, first_inputs, first_state = decoder.initialize()
    step_outputs, step_state, step_next_inputs, step_finished = decoder.step(
        constant_op.constant(0), first_inputs, first_state)
    if use_next_inputs_fn:
      output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output)
    batch_size_t = decoder.batch_size

    for lstm_state in (first_state, step_state):
      self.assertIsInstance(lstm_state, rnn_cell.LSTMStateTuple)
    self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
    self.assertEqual(state_shape, step_outputs[0].get_shape())
    self.assertEqual((batch_size,), step_outputs[1].get_shape())
    for lstm_state in (first_state, step_state):
      for part in (0, 1):
        self.assertEqual(state_shape, lstm_state[part].get_shape())

    sess.run(variables.global_variables_initializer())
    fetches = {
        "batch_size": batch_size_t,
        "first_finished": first_finished,
        "first_inputs": first_inputs,
        "first_state": first_state,
        "step_outputs": step_outputs,
        "step_state": step_state,
        "step_next_inputs": step_next_inputs,
        "step_finished": step_finished
    }
    if use_next_inputs_fn:
      fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn
    results = sess.run(fetches)

    self.assertAllEqual([False, False, False, False, True],
                        results["first_finished"])
    self.assertAllEqual([False, False, False, True, True],
                        results["step_finished"])

    sample_ids = results["step_outputs"].sample_id
    self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
    # Non-zero sample ids select the "sampling" rows, zeros the others.
    not_sampling = np.where(np.logical_not(sample_ids))
    sampling = np.where(sample_ids)

    if use_auxiliary_inputs:
      aux_to_concat = auxiliary_inputs[:, 1]
    else:
      # Zero-width block so the concatenations below are no-ops.
      aux_to_concat = np.zeros((batch_size, 0), dtype=np.float32)

    if use_next_inputs_fn:
      sampled_source = results["output_after_next_inputs_fn"]
    else:
      sampled_source = results["step_outputs"].rnn_output
    expected_sampled = np.concatenate(
        (sampled_source[sampling], aux_to_concat[sampling]), axis=-1)
    self.assertAllClose(results["step_next_inputs"][sampling],
                        expected_sampled)

    # Indexing with the np.where tuple adds a leading unit axis; drop only
    # that axis so a single selected row keeps its 2-D shape.
    expected_unsampled = np.concatenate(
        (np.squeeze(inputs[not_sampling, 1], axis=0),
         aux_to_concat[not_sampling]),
        axis=-1)
    self.assertAllClose(results["step_next_inputs"][not_sampling],
                        expected_unsampled)
def testBahndahauNotNormalized(self):
  """Golden-value test for `wrapper.BahdanauAttention` with default arguments.

  Passes the attention class itself as the factory, together with the
  expected decoder output and the expected final attention-wrapper state,
  to `self._testWithAttention` (no extra keyword arguments).

  The numeric literals below are reference values; how they are compared
  is decided by `_testWithAttention`, which is not part of this block.
  """
  create_attention_mechanism = wrapper.BahdanauAttention

  # Short aliases so the literals below read like a numpy repr.
  array = np.array
  float32 = np.float32
  int32 = np.int32

  # Expected decoder output: rnn_output literal has shape (5, 3, 6) and
  # sample_id literal has shape (5, 3) (all zeros).
  expected_final_outputs = basic_decoder.BasicDecoderOutput(
      rnn_output=array(
          [[[1.25166783e-02, -6.88887993e-03, 3.17239435e-03,
             -1.98234897e-03, 4.77387803e-03, -1.38330357e-02],
            [1.28883058e-02, -6.76271692e-03, 3.13419267e-03,
             -2.02183682e-03, 5.62057737e-03, -1.35373026e-02],
            [1.24917831e-02, -6.71574520e-03, 3.42238229e-03,
             -1.79501204e-03, 5.33161033e-03, -1.36620644e-02]],
           [[1.55150667e-02, -1.07274549e-02, 4.44198400e-03,
             -9.73310322e-04, 1.27242506e-02, -1.21861566e-02],
            [1.57585666e-02, -1.07965544e-02, 4.61554807e-03,
             -1.01510016e-03, 1.22341057e-02, -1.27029382e-02],
            [1.58304181e-02, -1.09712025e-02, 4.67861444e-03,
             -1.03920139e-03, 1.23004699e-02, -1.25949886e-02]],
           [[9.26700700e-03, -9.75431874e-03, -9.95740294e-04,
             -1.27463136e-06, 3.81659716e-03, -1.64887272e-02],
            [9.25191958e-03, -9.80092678e-03, -8.48566880e-04,
             5.02091134e-05, 3.46567202e-03, -1.67435352e-02],
            [9.48173273e-03, -9.52653307e-03, -8.79382715e-04,
             -3.07094306e-05, 4.05955408e-03, -1.67226996e-02]],
           [[1.21462569e-02, -1.27578378e-02, 1.54045003e-04,
             2.70257704e-03, 7.79421115e-03, -8.14041123e-04],
            [1.18412934e-02, -1.33513296e-02, 3.54760559e-05,
             2.67801876e-03, 6.99122995e-03, -9.46014654e-04],
            [1.16087487e-02, -1.31632648e-02, -2.98853614e-04,
             2.49515846e-03, 6.92677684e-03, -6.92734495e-04]],
           [[1.02377674e-02, -8.72955937e-03, 1.22555892e-03,
             2.03830865e-03, 8.93574394e-03, -7.28237582e-03],
            [1.05115287e-02, -8.92531779e-03, 1.14568521e-03,
             1.91635895e-03, 8.94328393e-03, -7.39541650e-03],
            [1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
             2.06834078e-03, 9.36511997e-03, -7.64556089e-03]]],
          dtype=float32),
      sample_id=array(
          [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
          dtype=int32))

  # Expected final state: LSTM c/h literals have shape (5, 9); the
  # attention literal has shape (5, 6).
  expected_final_state = wrapper.DynamicAttentionWrapperState(
      cell_state=core_rnn_cell.LSTMStateTuple(
          c=array(
              [[-0.0220502, -0.008058, -0.00160266, 0.01609341, -0.01380513,
                -0.00749483, -0.00816989, -0.01210028, 0.01795324],
               [0.01727026, -0.0142065, -0.00399991, 0.03195379, -0.03547479,
                -0.02138772, -0.00610318, -0.00191625, -0.01937846],
               [-0.0116077, 0.00876439, -0.01641787, -0.01400803, 0.01347527,
                -0.01036386, 0.00627491, -0.0096361, -0.00650565],
               [-0.04763387, -0.01192631, -0.00019412, 0.04103886, -0.00137999,
                0.02126684, -0.02793711, -0.05467696, -0.02912051],
               [0.02241185, -0.00141741, 0.01911988, 0.00547728, -0.01280068,
                -0.00307024, -0.00494239, 0.02169247, 0.01631995]],
              dtype=float32),
          h=array(
              [[-1.10613741e-02, -3.98175791e-03, -8.15514475e-04,
                7.90482666e-03, -7.02390168e-03, -3.76394135e-03,
                -4.16183751e-03, -6.17114361e-03, 8.95532221e-03],
               [8.60657450e-03, -7.17655150e-03, -1.94156705e-03,
                1.62583217e-02, -1.76821016e-02, -1.06200138e-02,
                -3.01904045e-03, -9.57608980e-04, -9.95732192e-03],
               [-5.78935863e-03, 4.49362956e-03, -8.13615043e-03,
                -6.95384294e-03, 6.75151078e-03, -5.07845683e-03,
                3.11869266e-03, -4.72904649e-03, -3.20469099e-03],
               [-2.38025561e-02, -5.89242764e-03, -9.76260417e-05,
                2.01697368e-02, -6.82076614e-04, 1.07111251e-02,
                -1.42077375e-02, -2.70790439e-02, -1.44685479e-02],
               [1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
                2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
                -2.48806458e-03, 1.10462429e-02, 7.97568541e-03]],
              dtype=float32)),
      # Each row here repeats the third (last) row of the corresponding
      # rnn_output entry above.
      attention=array(
          [[1.24917831e-02, -6.71574520e-03, 3.42238229e-03,
            -1.79501204e-03, 5.33161033e-03, -1.36620644e-02],
           [1.58304181e-02, -1.09712025e-02, 4.67861444e-03,
            -1.03920139e-03, 1.23004699e-02, -1.25949886e-02],
           [9.48173273e-03, -9.52653307e-03, -8.79382715e-04,
            -3.07094306e-05, 4.05955408e-03, -1.67226996e-02],
           [1.16087487e-02, -1.31632648e-02, -2.98853614e-04,
            2.49515846e-03, 6.92677684e-03, -6.92734495e-04],
           [1.07398070e-02, -8.56867433e-03, 1.52354129e-03,
            2.06834078e-03, 9.36511997e-03, -7.64556089e-03]],
          dtype=float32))
  self._testWithAttention(create_attention_mechanism,
                          expected_final_outputs,
                          expected_final_state)
def testStepWithScheduledEmbeddingTrainingHelper(self):
  """Runs one `BasicDecoder` step with a `ScheduledEmbeddingTrainingHelper`.

  The helper is built with a sampling probability of 0.5. After the step,
  rows whose sample id is greater than -1 are expected to receive the
  embedding of that id as next input; rows whose sample id equals -1 are
  expected to receive the ground-truth input at time 1.
  """
  sequence_length = [3, 4, 3, 1, 0]
  batch_size = 5
  max_time = 8
  input_depth = 7
  vocabulary_size = 10
  state_shape = (batch_size, vocabulary_size)

  with self.session(use_gpu=True) as sess:
    inputs = np.random.randn(
        batch_size, max_time, input_depth).astype(np.float32)
    embeddings = np.random.randn(
        vocabulary_size, input_depth).astype(np.float32)
    half = constant_op.constant(0.5)
    cell = rnn_cell.LSTMCell(vocabulary_size)
    helper = helper_py.ScheduledEmbeddingTrainingHelper(
        inputs=inputs,
        sequence_length=sequence_length,
        embedding=embeddings,
        sampling_probability=half,
        time_major=False)
    my_decoder = basic_decoder.BasicDecoder(
        cell=cell,
        helper=helper,
        initial_state=cell.zero_state(
            dtype=dtypes.float32, batch_size=batch_size))
    output_size = my_decoder.output_size
    output_dtype = my_decoder.output_dtype
    self.assertEqual(
        basic_decoder.BasicDecoderOutput(vocabulary_size,
                                         tensor_shape.TensorShape([])),
        output_size)
    self.assertEqual(
        basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
        output_dtype)

    (first_finished, first_inputs, first_state) = my_decoder.initialize()
    (step_outputs, step_state, step_next_inputs,
     step_finished) = my_decoder.step(
         constant_op.constant(0), first_inputs, first_state)
    batch_size_t = my_decoder.batch_size

    self.assertIsInstance(first_state, rnn_cell.LSTMStateTuple)
    self.assertIsInstance(step_state, rnn_cell.LSTMStateTuple)
    self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
    self.assertEqual(state_shape, step_outputs[0].get_shape())
    self.assertEqual((batch_size,), step_outputs[1].get_shape())
    self.assertEqual(state_shape, first_state[0].get_shape())
    self.assertEqual(state_shape, first_state[1].get_shape())
    self.assertEqual(state_shape, step_state[0].get_shape())
    self.assertEqual(state_shape, step_state[1].get_shape())
    self.assertEqual((batch_size, input_depth),
                     step_next_inputs.get_shape())

    sess.run(variables.global_variables_initializer())
    sess_results = sess.run({
        "batch_size": batch_size_t,
        "first_finished": first_finished,
        "first_inputs": first_inputs,
        "first_state": first_state,
        "step_outputs": step_outputs,
        "step_state": step_state,
        "step_next_inputs": step_next_inputs,
        "step_finished": step_finished
    })

    self.assertAllEqual([False, False, False, False, True],
                        sess_results["first_finished"])
    self.assertAllEqual([False, False, False, True, True],
                        sess_results["step_finished"])
    sample_ids = sess_results["step_outputs"].sample_id
    self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
    batch_where_not_sampling = np.where(sample_ids == -1)
    batch_where_sampling = np.where(sample_ids > -1)
    self.assertAllClose(
        sess_results["step_next_inputs"][batch_where_sampling],
        embeddings[sample_ids[batch_where_sampling]])
    # Indexing with the np.where tuple yields a (1, k, input_depth) array.
    # Squeeze only that leading axis: a bare np.squeeze would also drop the
    # row axis when exactly one row is not sampled (k == 1), producing a
    # (input_depth,) array that no longer matches the (1, input_depth)
    # slice on the left-hand side.
    self.assertAllClose(
        sess_results["step_next_inputs"][batch_where_not_sampling],
        np.squeeze(inputs[batch_where_not_sampling, 1], axis=0))
def step(self, time, inputs, state, name=None):
  """Perform a decoding step.

  The cell is run on `inputs`, its (optionally projected) outputs are
  rescored through the `_trie_scores_py_func` py_func using the trie keys
  and exclusions carried in `state`, the helper samples from the rescored
  outputs, and the trie keys are updated from the sampled ids through the
  `_amend_trie_keys_py_func` py_func.

  Args:
    time: scalar `int32` tensor.
    inputs: A (structure of) input tensors.
    state: A (structure of) state tensors and TensorArrays. Must expose
      `trie_keys` and `trie_exclude`, either directly or on its first
      element (the latter when `_is_gnmt_state(state)` holds).
    name: Name scope for any created operations.

  Returns:
    `(outputs, next_state, next_inputs, finished)`.
  """
  with ops.name_scope(name, 'TrieSamplerDecoderStep', (time, inputs, state)):
    # Decide what is fed to the cell and which object carries the trie
    # bookkeeping fields; only the plain sampler state needs unwrapping.
    if _is_attention_state(state):
      cell_input_state = state
      trie_holder = state
    elif _is_gnmt_state(state):
      cell_input_state = state
      trie_holder = state[0]
    else:
      cell_input_state = state.cell_state
      trie_holder = state

    cell_outputs, cell_state = self._cell(inputs, cell_input_state)
    state_trie_keys = trie_holder.trie_keys
    state_trie_exclude = trie_holder.trie_exclude

    if self._output_layer is not None:
      cell_outputs = self._output_layer(cell_outputs)

    # py_func drops static shape information, so remember it and restore
    # it on the rescored outputs.
    static_output_shape = cell_outputs.get_shape()
    score_fn = _trie_scores_py_func(self.trie, beam_search=False)
    cell_outputs = tf.py_func(
        score_fn,
        [cell_outputs, state_trie_keys, state_trie_exclude],
        tf.float32,
        stateful=False)
    cell_outputs.set_shape(static_output_shape)

    sample_ids = self._helper.sample(
        time=time, outputs=cell_outputs, state=cell_state)
    finished, next_inputs, next_cell_state = self._helper.next_inputs(
        time=time,
        outputs=cell_outputs,
        state=cell_state,
        sample_ids=sample_ids)

    amend_fn = _amend_trie_keys_py_func(beam_search=False)
    trie_keys = tf.py_func(
        amend_fn, [state_trie_keys, sample_ids], tf.string, stateful=False)
    trie_keys.set_shape(state_trie_keys.get_shape())

    # Re-wrap the cell state in the same flavour of state it arrived in.
    trie_fields = dict(trie_keys=trie_keys, trie_exclude=state_trie_exclude)
    if _is_attention_state(next_cell_state):
      next_state = TrieSamplerAttentionState(*next_cell_state, **trie_fields)
    elif _is_gnmt_state(next_cell_state):
      wrapped_first = TrieSamplerAttentionState(
          *next_cell_state[0], **trie_fields)
      next_state = (wrapped_first,) + next_cell_state[1:]
    else:
      next_state = TrieSamplerState(cell_state=next_cell_state, **trie_fields)

  outputs = basic_decoder.BasicDecoderOutput(cell_outputs, sample_ids)
  return outputs, next_state, next_inputs, finished
def _testStepWithTrainingHelper(self, use_output_layer):
  """Runs one `BasicDecoder` step with a `TrainingHelper`.

  Args:
    use_output_layer: If truthy, a bias-free `Dense` layer of depth 3 is
      attached to the decoder and the expected output depth becomes 3;
      otherwise no output layer is used and the expected output depth is
      the cell depth (10).
  """
  seq_lens = [3, 4, 3, 1, 0]
  batch_size = 5
  max_time = 8
  input_depth = 7
  cell_depth = 10
  output_layer_depth = 3
  state_shape = (batch_size, cell_depth)

  with self.session(use_gpu=True) as sess:
    inputs = np.random.randn(batch_size, max_time, input_depth)
    inputs = inputs.astype(np.float32)
    lstm = rnn_cell.LSTMCell(cell_depth)
    training_helper = helper_py.TrainingHelper(
        inputs, seq_lens, time_major=False)

    projection = None
    expected_output_depth = cell_depth
    if use_output_layer:
      projection = layers_core.Dense(output_layer_depth, use_bias=False)
      expected_output_depth = output_layer_depth

    zero_state = lstm.zero_state(dtype=dtypes.float32, batch_size=batch_size)
    decoder = basic_decoder.BasicDecoder(
        cell=lstm,
        helper=training_helper,
        initial_state=zero_state,
        output_layer=projection)

    output_size = decoder.output_size
    output_dtype = decoder.output_dtype
    self.assertEqual(
        basic_decoder.BasicDecoderOutput(expected_output_depth,
                                         tensor_shape.TensorShape([])),
        output_size)
    self.assertEqual(
        basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
        output_dtype)

    first_finished, first_inputs, first_state = decoder.initialize()
    step_outputs, step_state, step_next_inputs, step_finished = decoder.step(
        constant_op.constant(0), first_inputs, first_state)
    batch_size_t = decoder.batch_size

    for lstm_state in (first_state, step_state):
      self.assertIsInstance(lstm_state, rnn_cell.LSTMStateTuple)
    self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
    self.assertEqual((batch_size, expected_output_depth),
                     step_outputs[0].get_shape())
    self.assertEqual((batch_size,), step_outputs[1].get_shape())
    for lstm_state in (first_state, step_state):
      for part in (0, 1):
        self.assertEqual(state_shape, lstm_state[part].get_shape())

    if use_output_layer:
      # The output layer was accessed
      self.assertEqual(len(projection.variables), 1)

    sess.run(variables.global_variables_initializer())
    fetches = {
        "batch_size": batch_size_t,
        "first_finished": first_finished,
        "first_inputs": first_inputs,
        "first_state": first_state,
        "step_outputs": step_outputs,
        "step_state": step_state,
        "step_next_inputs": step_next_inputs,
        "step_finished": step_finished
    }
    results = sess.run(fetches)

    self.assertAllEqual([False, False, False, False, True],
                        results["first_finished"])
    self.assertAllEqual([False, False, False, True, True],
                        results["step_finished"])
    fetched_outputs = results["step_outputs"]
    self.assertEqual(output_dtype.sample_id, fetched_outputs.sample_id.dtype)
    self.assertAllEqual(
        np.argmax(fetched_outputs.rnn_output, -1), fetched_outputs.sample_id)
def testStepWithGreedyEmbeddingHelper(self):
  """Runs one `BasicDecoder` step with a `GreedyEmbeddingHelper`.

  Expects no sequence to be finished after `initialize()`, the sample ids
  to be the argmax of the RNN output, the `finished` flags to mark rows
  whose sample id equals the end token, and the next inputs to be the
  embeddings of the sample ids.
  """
  batch_size = 5
  vocabulary_size = 7
  cell_depth = vocabulary_size  # cell's logits must match vocabulary size
  input_depth = 10
  start_tokens = [0] * batch_size
  end_token = 1
  state_shape = (batch_size, cell_depth)

  with self.test_session() as sess:
    embeddings = np.random.randn(vocabulary_size, input_depth)
    embeddings = embeddings.astype(np.float32)
    lstm = core_rnn_cell.LSTMCell(vocabulary_size)
    greedy_helper = helper_py.GreedyEmbeddingHelper(
        embeddings, start_tokens, end_token)
    zero_state = lstm.zero_state(dtype=dtypes.float32, batch_size=batch_size)
    decoder = basic_decoder.BasicDecoder(
        cell=lstm, helper=greedy_helper, initial_state=zero_state)

    output_size = decoder.output_size
    output_dtype = decoder.output_dtype
    self.assertEqual(
        basic_decoder.BasicDecoderOutput(cell_depth,
                                         tensor_shape.TensorShape([])),
        output_size)
    self.assertEqual(
        basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
        output_dtype)

    first_finished, first_inputs, first_state = decoder.initialize()
    step_outputs, step_state, step_next_inputs, step_finished = decoder.step(
        constant_op.constant(0), first_inputs, first_state)
    batch_size_t = decoder.batch_size

    for lstm_state in (first_state, step_state):
      self.assertIsInstance(lstm_state, core_rnn_cell.LSTMStateTuple)
    self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
    self.assertEqual(state_shape, step_outputs[0].get_shape())
    self.assertEqual((batch_size,), step_outputs[1].get_shape())
    for lstm_state in (first_state, step_state):
      for part in (0, 1):
        self.assertEqual(state_shape, lstm_state[part].get_shape())

    sess.run(variables.global_variables_initializer())
    fetches = {
        "batch_size": batch_size_t,
        "first_finished": first_finished,
        "first_inputs": first_inputs,
        "first_state": first_state,
        "step_outputs": step_outputs,
        "step_state": step_state,
        "step_next_inputs": step_next_inputs,
        "step_finished": step_finished
    }
    results = sess.run(fetches)

    # Recompute the greedy choice in numpy from the fetched logits.
    greedy_ids = np.argmax(results["step_outputs"].rnn_output, -1)
    greedy_finished = (greedy_ids == end_token)
    greedy_next_inputs = embeddings[greedy_ids]

    self.assertAllEqual([False, False, False, False, False],
                        results["first_finished"])
    self.assertAllEqual(greedy_finished, results["step_finished"])
    self.assertAllEqual(greedy_ids, results["step_outputs"].sample_id)
    self.assertAllEqual(greedy_next_inputs, results["step_next_inputs"])
def testStepWithInferenceHelperMultilabel(self):
  """Runs one `BasicDecoder` step with a boolean-sampling `InferenceHelper`.

  The helper samples through `helper_py.bernoulli_sample` with a `bool`
  dtype, converts samples to float for the next inputs, and reports a row
  as finished according to its sample at index `end_token`. The test checks
  declared sizes/dtypes, static shapes, and that the fetched `finished`
  flags and next inputs agree with the fetched sample ids.
  """
  batch_size = 5
  vocabulary_size = 7
  cell_depth = vocabulary_size
  start_token = 0
  end_token = 6
  state_shape = (batch_size, cell_depth)

  start_inputs = array_ops.one_hot(
      np.ones(batch_size) * start_token, vocabulary_size)

  def sample_fn(logits):
    # The sample function samples independent bernoullis from the logits.
    return helper_py.bernoulli_sample(logits=logits, dtype=dtypes.bool)

  # The next inputs are a one-hot encoding of the sampled labels.
  next_inputs_fn = math_ops.to_float

  def end_fn(ids):
    return ids[:, end_token]

  with self.session(use_gpu=True) as sess:
    with variable_scope.variable_scope(
        "testStepWithInferenceHelper",
        initializer=init_ops.constant_initializer(0.01)):
      lstm = rnn_cell.LSTMCell(vocabulary_size)
      inference_helper = helper_py.InferenceHelper(
          sample_fn,
          sample_shape=[cell_depth],
          sample_dtype=dtypes.bool,
          start_inputs=start_inputs,
          end_fn=end_fn,
          next_inputs_fn=next_inputs_fn)
      zero_state = lstm.zero_state(
          dtype=dtypes.float32, batch_size=batch_size)
      decoder = basic_decoder.BasicDecoder(
          cell=lstm, helper=inference_helper, initial_state=zero_state)

      output_size = decoder.output_size
      output_dtype = decoder.output_dtype
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(cell_depth, cell_depth),
          output_size)
      self.assertEqual(
          basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.bool),
          output_dtype)

      first_finished, first_inputs, first_state = decoder.initialize()
      (step_outputs, step_state, step_next_inputs,
       step_finished) = decoder.step(
           constant_op.constant(0), first_inputs, first_state)
      batch_size_t = decoder.batch_size

      for lstm_state in (first_state, step_state):
        self.assertIsInstance(lstm_state, rnn_cell.LSTMStateTuple)
      self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
      # Both the logits and the per-label samples are (batch, depth).
      for part in (0, 1):
        self.assertEqual(state_shape, step_outputs[part].get_shape())
      for lstm_state in (first_state, step_state):
        for part in (0, 1):
          self.assertEqual(state_shape, lstm_state[part].get_shape())

      sess.run(variables.global_variables_initializer())
      fetches = {
          "batch_size": batch_size_t,
          "first_finished": first_finished,
          "first_inputs": first_inputs,
          "first_state": first_state,
          "step_outputs": step_outputs,
          "step_state": step_state,
          "step_next_inputs": step_next_inputs,
          "step_finished": step_finished
      }
      results = sess.run(fetches)

      sample_ids = results["step_outputs"].sample_id
      self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
      self.assertAllEqual(sample_ids[:, end_token], results["step_finished"])
      self.assertAllEqual(sample_ids.astype(np.float32),
                          results["step_next_inputs"])
def DISABLED_testStepWithGreedyEmbeddingHelper(self):
  """Runs one `BasicDecoderV2` step with a `GreedyEmbeddingSampler`.

  The decoder is initialised with the embedding tensor, random start
  tokens, the end token and the cell's zero state. Expects no sequence to
  be finished after `initialize()`, the sample ids to be the argmax of the
  RNN output, the `finished` flags to mark rows whose sample id equals the
  end token, and the next inputs to be the embeddings of the sample ids.
  """
  batch_size = 5
  vocabulary_size = 7
  cell_depth = vocabulary_size  # cell's logits must match vocabulary size
  input_depth = 10
  start_tokens = np.random.randint(0, vocabulary_size, size=batch_size)
  end_token = 1
  state_shape = (batch_size, cell_depth)

  with self.cached_session(use_gpu=True):
    embeddings = np.random.randn(vocabulary_size, input_depth)
    embeddings = embeddings.astype(np.float32)
    embeddings_t = constant_op.constant(embeddings)
    lstm = rnn_cell.LSTMCell(vocabulary_size)
    greedy_sampler = sampler_py.GreedyEmbeddingSampler()
    zero_state = lstm.zero_state(dtype=dtypes.float32, batch_size=batch_size)
    decoder = basic_decoder.BasicDecoderV2(cell=lstm, sampler=greedy_sampler)
    first_finished, first_inputs, first_state = decoder.initialize(
        embeddings_t,
        start_tokens=start_tokens,
        end_token=end_token,
        initial_state=zero_state)

    output_size = decoder.output_size
    output_dtype = decoder.output_dtype
    self.assertEqual(
        basic_decoder.BasicDecoderOutput(cell_depth,
                                         tensor_shape.TensorShape([])),
        output_size)
    self.assertEqual(
        basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
        output_dtype)

    step_outputs, step_state, step_next_inputs, step_finished = decoder.step(
        constant_op.constant(0), first_inputs, first_state)
    batch_size_t = decoder.batch_size

    for lstm_state in (first_state, step_state):
      self.assertIsInstance(lstm_state, rnn_cell.LSTMStateTuple)
    self.assertIsInstance(step_outputs, basic_decoder.BasicDecoderOutput)
    self.assertEqual(state_shape, step_outputs[0].get_shape())
    self.assertEqual((batch_size,), step_outputs[1].get_shape())
    for lstm_state in (first_state, step_state):
      for part in (0, 1):
        self.assertEqual(state_shape, lstm_state[part].get_shape())

    self.evaluate(variables.global_variables_initializer())
    fetches = {
        "batch_size": batch_size_t,
        "first_finished": first_finished,
        "first_inputs": first_inputs,
        "first_state": first_state,
        "step_outputs": step_outputs,
        "step_state": step_state,
        "step_next_inputs": step_next_inputs,
        "step_finished": step_finished
    }
    eval_result = self.evaluate(fetches)

    # Recompute the greedy choice in numpy from the fetched logits.
    fetched_outputs = eval_result["step_outputs"]
    greedy_ids = np.argmax(fetched_outputs.rnn_output, -1)
    greedy_finished = (greedy_ids == end_token)
    greedy_next_inputs = embeddings[greedy_ids]

    self.assertAllEqual([False, False, False, False, False],
                        eval_result["first_finished"])
    self.assertAllEqual(greedy_finished, eval_result["step_finished"])
    self.assertEqual(output_dtype.sample_id, fetched_outputs.sample_id.dtype)
    self.assertAllEqual(greedy_ids, fetched_outputs.sample_id)
    self.assertAllEqual(greedy_next_inputs, eval_result["step_next_inputs"])