def _testDynamicDecodeRNN(self, time_major, has_attention):
    """Smoke-test BeamSearchDecoder driven by dynamic_decode.

    Builds an LSTM decoder, optionally wrapped in Bahdanau attention over a
    random memory, decodes with beam search for at most
    ``max(decoder_sequence_length)`` steps, and checks both the static
    (graph-time) and the evaluated shapes of the beam scores / predicted ids.

    Args:
      time_major: forwarded to ``dynamic_decode`` as ``output_time_major``;
        the expected shapes swap their first two dimensions accordingly.
      has_attention: when True, wrap the cell in an ``AttentionWrapper`` whose
        memory, memory lengths and initial cell state are tiled
        ``beam_width`` times.
    """
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None)
    beam_width = 3

    # cached_session() is the supported replacement for the deprecated
    # test_session().
    with self.cached_session() as sess:
        batch_size_tensor = constant_op.constant(batch_size)
        embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
        cell = rnn_cell.LSTMCell(cell_depth)
        initial_state = cell.zero_state(batch_size, dtypes.float32)
        if has_attention:
            inputs = array_ops.placeholder_with_default(
                np.random.randn(batch_size, decoder_max_time,
                                input_depth).astype(np.float32),
                shape=(None, None, input_depth))
            # Beam search runs on batch_size * beam_width rows, so the
            # attention memory, its lengths and the initial state are tiled.
            tiled_inputs = beam_search_decoder.tile_batch(
                inputs, multiplier=beam_width)
            tiled_sequence_length = beam_search_decoder.tile_batch(
                encoder_sequence_length, multiplier=beam_width)
            attention_mechanism = attention_wrapper.BahdanauAttention(
                num_units=attention_depth,
                memory=tiled_inputs,
                memory_sequence_length=tiled_sequence_length)
            initial_state = beam_search_decoder.tile_batch(
                initial_state, multiplier=beam_width)
            cell = attention_wrapper.AttentionWrapper(
                cell=cell,
                attention_mechanism=attention_mechanism,
                attention_layer_size=attention_depth,
                alignment_history=False)
        cell_state = cell.zero_state(
            dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
        if has_attention:
            cell_state = cell_state.clone(cell_state=initial_state)
        bsd = beam_search_decoder.BeamSearchDecoder(
            cell=cell,
            embedding=embedding,
            start_tokens=array_ops.fill([batch_size_tensor], start_token),
            end_token=end_token,
            initial_state=cell_state,
            beam_width=beam_width,
            output_layer=output_layer,
            length_penalty_weight=0.0)

        final_outputs, final_state, final_sequence_lengths = (
            decoder.dynamic_decode(
                bsd, output_time_major=time_major, maximum_iterations=max_out))

        def _t(shape):
            # Expected shapes are written batch-major; swap for time-major.
            if time_major:
                return (shape[1], shape[0]) + shape[2:]
            return shape

        self.assertIsInstance(
            final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
        self.assertIsInstance(
            final_state, beam_search_decoder.BeamSearchDecoderState)

        beam_search_decoder_output = final_outputs.beam_search_decoder_output
        # Static shapes: the time dimension is unknown at graph build time.
        self.assertEqual(
            _t((batch_size, None, beam_width)),
            tuple(beam_search_decoder_output.scores.get_shape().as_list()))
        self.assertEqual(
            _t((batch_size, None, beam_width)),
            tuple(final_outputs.predicted_ids.get_shape().as_list()))

        sess.run(variables.global_variables_initializer())
        sess_results = sess.run({
            'final_outputs': final_outputs,
            'final_state': final_state,
            'final_sequence_lengths': final_sequence_lengths
        })

        max_sequence_length = np.max(sess_results['final_sequence_lengths'])

        # A smoke test: evaluated time dimension equals the longest decode.
        self.assertEqual(
            _t((batch_size, max_sequence_length, beam_width)),
            sess_results['final_outputs'].beam_search_decoder_output.scores.shape)
        self.assertEqual(
            _t((batch_size, max_sequence_length, beam_width)),
            sess_results['final_outputs'].beam_search_decoder_output
            .predicted_ids.shape)
def get_decoder(cell, y__ref_flag, x_ref_flag, tgt_ref_flag, beam_width=None):
    """Build the text decoder, optionally with attention and/or a copy mechanism.

    Behaviour is driven by ``Config`` flags and by encoder tensors
    (``sent_enc_outputs``, ``sd_enc_outputs``, ``sent_sequence_length``,
    ``sd_sequence_length``, ``sent_ids``, ``sd_ids``, ``data_batch``) that are
    defined outside this function.

    Args:
      cell: the base RNN cell of the decoder.
      y__ref_flag: key selecting the y_ encoder outputs / lengths / ids.
      x_ref_flag: key selecting the x encoder outputs / lengths / ids.
      tgt_ref_flag: key into ``y_strs`` used to fetch the target text ids that
        are passed to the copy cell as ``input_ids``; when None, the copy cell
        is built with ``input_ids=None``.
      beam_width: if not None, the copy-memory tensors are tiled ``beam_width``
        times and, with attention, the beam-search variant of the attention
        cell is used.

    Returns:
      With attention and no copying: the ``tx.modules.AttentionRNNDecoder``.
      Otherwise: a ``tx.modules.BasicRNNDecoder`` over the (possibly attention-
      and copy-wrapped) cell.

    Raises:
      ValueError: if ``Config.attn_flag`` is set but neither
        ``Config.attn_y_`` nor ``Config.attn_x`` is.
    """
    # With copying the decoder's own projection is the identity
    # (presumably because CopyNetWrapper already emits vocabulary-sized
    # scores — verify); otherwise the decoder projects to the vocabulary.
    output_layer_params = ({
        "output_layer": tf.identity
    } if Config.copy_flag else {
        "vocab_size": vocab.size
    })

    if Config.attn_flag:  # attention
        if Config.attn_x and Config.attn_y_:
            # Both memories are concatenated along axis 1; there is no single
            # length per example for the result, so no attention mask is
            # applied (padding in either memory can receive weight).
            memory = tf.concat(
                [
                    sent_enc_outputs[y__ref_flag],
                    sd_enc_outputs[x_ref_flag]
                ],
                axis=1,
            )
            memory_sequence_length = None
        elif Config.attn_y_:
            memory = sent_enc_outputs[y__ref_flag]
            memory_sequence_length = sent_sequence_length[y__ref_flag]
        elif Config.attn_x:
            memory = sd_enc_outputs[x_ref_flag]
            memory_sequence_length = sd_sequence_length[x_ref_flag]
        else:
            raise ValueError(
                "Config.attn_flag is set but neither Config.attn_y_ nor "
                "Config.attn_x is; cannot choose an attention memory.")
        attention_decoder = tx.modules.AttentionRNNDecoder(
            cell=cell,
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            hparams=Config.config_model.attention_decoder,
            **output_layer_params)
        if not Config.copy_flag:
            return attention_decoder
        # Keep the attention-wrapped cell and stack the copy mechanism on it.
        cell = (attention_decoder.cell if beam_width is None else
                attention_decoder._get_beam_search_cell(beam_width))

    if Config.copy_flag:  # copynet
        # y_ memories drop their first position ([:, 1:]) and their
        # lengths are reduced by one to match.
        kwargs = {
            "y__ids": sent_ids[y__ref_flag][:, 1:],
            "y__states": sent_enc_outputs[y__ref_flag][:, 1:],
            "y__lengths": sent_sequence_length[y__ref_flag] - 1,
            "x_ids": sd_ids[x_ref_flag]["value"],
            "x_states": sd_enc_outputs[x_ref_flag],
            "x_lengths": sd_sequence_length[x_ref_flag],
        }
        if tgt_ref_flag is not None:
            kwargs.update({
                "input_ids":
                    data_batch["{}_text_ids".format(
                        y_strs[tgt_ref_flag])][:, :-1]
            })
        # Order matters: get_copy_scores below indexes the memories in the
        # same y_-then-x order.
        memory_prefixes = []
        if Config.copy_y_:
            memory_prefixes.append("y_")
        if Config.copy_x:
            memory_prefixes.append("x")
        if beam_width is not None:
            kwargs = {
                name: tile_batch(value, beam_width)
                for name, value in kwargs.items()
            }

        def get_get_copy_scores(memory_ids_states_lengths, output_size):
            # Project each copy memory's states to the query size once.
            memory_copy_states = [
                tf.layers.dense(
                    memory_states,
                    units=output_size,
                    activation=None,
                    use_bias=False,
                ) for _, memory_states, _ in memory_ids_states_lengths
            ]

            def get_copy_scores(query, coverities=None):
                """Return one (batch, memory_len) score tensor per enabled memory."""
                ret = []
                for enabled in (Config.copy_y_, Config.copy_x):
                    if not enabled:
                        continue
                    idx = len(ret)
                    memory = memory_copy_states[idx]
                    if coverities is not None:
                        memory = memory + tf.layers.dense(
                            coverities[idx],
                            units=output_size,
                            activation=None,
                            use_bias=False,
                        )
                    memory = tf.nn.tanh(memory)
                    ret.append(tf.einsum("bim,bm->bi", memory, query))
                return ret

            return get_copy_scores

        coverity_dim = (Config.config_model.coverage_state_dim
                        if Config.coverage else None)
        coverity_rnn_cell_hparams = (Config.config_model.coverage_rnn_cell
                                     if Config.coverage else None)
        cell = CopyNetWrapper(
            cell=cell,
            vocab_size=vocab.size,
            memory_ids_states_lengths=[
                tuple(kwargs["{}_{}".format(prefix, s)]
                      for s in ("ids", "states", "lengths"))
                for prefix in memory_prefixes
            ],
            input_ids=kwargs["input_ids"]
            if tgt_ref_flag is not None else None,
            get_get_copy_scores=get_get_copy_scores,
            coverity_dim=coverity_dim,
            coverity_rnn_cell_hparams=coverity_rnn_cell_hparams,
            disabled_vocab_size=Config.disabled_vocab_size,
            eps=Config.eps,
        )

    decoder = tx.modules.BasicRNNDecoder(
        cell=cell, hparams=Config.config_model.decoder, **output_layer_params)
    return decoder
def _testDynamicDecodeRNN(self, time_major, has_attention,
                          with_alignment_history=False):
    """Run BeamSearchDecoder through dynamic_decode and verify output shapes.

    An LSTM decoder is used, optionally wrapped in Bahdanau attention over a
    random memory; with attention a coverage penalty of 0.2 is also applied.
    Static (graph-time) and evaluated shapes of the beam scores and predicted
    ids are both checked.

    Args:
      time_major: forwarded to ``dynamic_decode`` as ``output_time_major``.
      has_attention: wrap the cell in an ``AttentionWrapper`` (memory, lengths
        and initial state tiled by the beam width) and enable the coverage
        penalty.
      with_alignment_history: forwarded to the ``AttentionWrapper``.
    """

    def _oriented(shape):
        # Shapes below are spelled batch-major; swap leading dims if needed.
        return (shape[1], shape[0]) + shape[2:] if time_major else shape

    enc_lengths = np.array([3, 2, 3, 1, 1])
    dec_lengths = np.array([2, 0, 1, 2, 3])
    batch_size, beam_width = 5, 3
    decoder_max_time, input_depth = 4, 7
    cell_depth, attention_depth = 9, 6
    vocab_size, embedding_dim = 20, 50
    start_token, end_token = 0, vocab_size - 1
    max_out = max(dec_lengths)
    output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None)

    with self.cached_session() as sess:
        batch_size_tensor = constant_op.constant(batch_size)
        embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
        cell = rnn_cell.LSTMCell(cell_depth)
        initial_state = cell.zero_state(batch_size, dtypes.float32)
        coverage_penalty_weight = 0.2 if has_attention else 0.0

        if has_attention:
            memory = array_ops.placeholder_with_default(
                np.random.randn(batch_size, decoder_max_time,
                                input_depth).astype(np.float32),
                shape=(None, None, input_depth))
            tiled_memory = beam_search_decoder.tile_batch(
                memory, multiplier=beam_width)
            tiled_lengths = beam_search_decoder.tile_batch(
                enc_lengths, multiplier=beam_width)
            mechanism = attention_wrapper.BahdanauAttention(
                num_units=attention_depth,
                memory=tiled_memory,
                memory_sequence_length=tiled_lengths)
            initial_state = beam_search_decoder.tile_batch(
                initial_state, multiplier=beam_width)
            cell = attention_wrapper.AttentionWrapper(
                cell=cell,
                attention_mechanism=mechanism,
                attention_layer_size=attention_depth,
                alignment_history=with_alignment_history)

        cell_state = cell.zero_state(
            dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
        if has_attention:
            cell_state = cell_state.clone(cell_state=initial_state)

        bsd = beam_search_decoder.BeamSearchDecoder(
            cell=cell,
            embedding=embedding,
            start_tokens=array_ops.fill([batch_size_tensor], start_token),
            end_token=end_token,
            initial_state=cell_state,
            beam_width=beam_width,
            output_layer=output_layer,
            length_penalty_weight=0.0,
            coverage_penalty_weight=coverage_penalty_weight)
        final_outputs, final_state, final_sequence_lengths = (
            decoder.dynamic_decode(
                bsd, output_time_major=time_major, maximum_iterations=max_out))

        self.assertIsInstance(
            final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput)
        self.assertIsInstance(
            final_state, beam_search_decoder.BeamSearchDecoderState)

        # Graph-time shapes: decode length is not known statically.
        expected_static = _oriented((batch_size, None, beam_width))
        raw_scores = final_outputs.beam_search_decoder_output.scores
        self.assertEqual(expected_static,
                         tuple(raw_scores.get_shape().as_list()))
        self.assertEqual(
            expected_static,
            tuple(final_outputs.predicted_ids.get_shape().as_list()))

        sess.run(variables.global_variables_initializer())
        results = sess.run({
            'final_outputs': final_outputs,
            'final_state': final_state,
            'final_sequence_lengths': final_sequence_lengths
        })

        max_sequence_length = np.max(results['final_sequence_lengths'])

        # A smoke test: evaluated shapes use the longest decoded length.
        expected_runtime = _oriented(
            (batch_size, max_sequence_length, beam_width))
        evaluated = results['final_outputs'].beam_search_decoder_output
        self.assertEqual(expected_runtime, evaluated.scores.shape)
        self.assertEqual(expected_runtime, evaluated.predicted_ids.shape)
def impl(features, mode, hp):
    """Build the hierarchical-attention response generator graph for ``mode``.

    Each context utterance is encoded separately with a bidirectional GRU, a
    ``ContextAttentionMechanism`` attends over the stacked encodings, and a GRU
    decoder wrapped in ``AttentionWrapper`` produces the response.

    Args:
      features: dict of input tensors. Always read: 'contexts',
        'context_utterance_length', 'context_length'. In TRAIN/EVAL also
        'response_in', 'response_out', 'response_mask'.
      mode: ``modekeys.TRAIN``, ``modekeys.EVAL`` or ``modekeys.PREDICT``.
      hp: hyper-parameter object (batch_size, vocab_size, beam_width, ...).

    Returns:
      TRAIN: ``(loss, debug_tensors)``; loss is the masked cross entropy summed
        over time, averaged over the batch, plus ``hp.lambda_l2`` times the L2
        norm of every trainable variable whose name does not contain 'bias'.
      EVAL: perplexity, ``exp`` of the mean cross entropy over unmasked target
        positions.
      PREDICT: dict with 'response_ids' and 'response_lens'. With beam search
        'response_ids' is (batch, beam_width, time) and 'response_lens' is None.
      Any other ``mode``: None.
    """
    contexts = features[
        'contexts']  # batch_size,max_con_length(with query),max_sen_length
    context_utterance_length = features[
        'context_utterance_length']  # batch_size,max_con_length
    context_length = features['context_length']  # batch_size

    if mode == modekeys.TRAIN or mode == modekeys.EVAL:
        response_in = features['response_in']  # batch,max_res_sen
        response_out = features['response_out']  # batch,max_res_sen
        response_mask = features[
            'response_mask']  # batch,max_res_sen, tf.float32
        batch_size = hp.batch_size
    else:
        # Static batch size; assumes the PREDICT input has a known batch dim.
        batch_size = context_utterance_length.shape[0].value

    with tf.variable_scope('embedding_layer', reuse=tf.AUTO_REUSE):
        embedding_w = get_embedding_matrix(hp.word_dim, mode, hp.vocab_size,
                                           random_seed, hp.word_embed_path,
                                           hp.vocab_path)
        contexts = tf.nn.embedding_lookup(embedding_w, contexts,
                                          'context_embedding')
        if mode == modekeys.TRAIN or mode == modekeys.EVAL:
            response_in = tf.nn.embedding_lookup(embedding_w, response_in,
                                                 'response_in_embedding')

    with tf.variable_scope('utterance_encoding_layer', reuse=tf.AUTO_REUSE):
        kernel_initializer = tf.random_normal_initializer(
            mean=0.0, stddev=0.1, seed=random_seed + 1)
        bias_initializer = tf.zeros_initializer()
        fw_cell = tf.nn.rnn_cell.GRUCell(num_units=hp.word_rnn_num_units,
                                         kernel_initializer=kernel_initializer,
                                         bias_initializer=bias_initializer)
        kernel_initializer = tf.random_normal_initializer(
            mean=0.0, stddev=0.1, seed=random_seed - 1)
        bias_initializer = tf.zeros_initializer()
        bw_cell = tf.nn.rnn_cell.GRUCell(num_units=hp.word_rnn_num_units,
                                         kernel_initializer=kernel_initializer,
                                         bias_initializer=bias_initializer)

        context_t = tf.transpose(contexts, perm=[
            1, 0, 2, 3
        ])  # max_con_length(with query),batch_size,max_sen_length
        context_utterance_length_t = tf.transpose(
            context_utterance_length, perm=[1, 0])  # max_con_length, batch_size
        utterance_splits = tf.split(
            context_t, hp.max_context_length,
            axis=0)  # 1,batch_size,max_sen_length
        length_splits = tf.split(context_utterance_length_t,
                                 hp.max_context_length,
                                 axis=0)  # 1,batch_size

        # The same fw/bw cells (and therefore weights) encode every utterance.
        utterance_encodings = []
        for utterance, length in zip(utterance_splits, length_splits):
            utterance = tf.squeeze(utterance, axis=0)
            length = tf.squeeze(length, axis=0)
            utterance_hidden_states, _ = tf.nn.bidirectional_dynamic_rnn(
                fw_cell,
                bw_cell,
                utterance,
                sequence_length=length,
                initial_state_fw=fw_cell.zero_state(batch_size, tf.float32),
                initial_state_bw=bw_cell.zero_state(batch_size, tf.float32))
            utterance_encoding = tf.concat(utterance_hidden_states, axis=2)
            utterance_encodings.append(
                tf.expand_dims(utterance_encoding, axis=0))
        utterance_encodings = tf.concat(
            utterance_encodings,
            axis=0)  # max_con_length,batch_size,max_sen,2*word_rnn_num_units

    with tf.variable_scope('hierarchical_attention_layer',
                           reuse=tf.AUTO_REUSE):
        if mode == modekeys.PREDICT and hp.beam_width != 0:
            # tile_batch tiles along axis 0, so move batch to the front,
            # tile, then restore the context-major layout.
            utterance_encodings = tf.transpose(utterance_encodings,
                                               perm=[1, 0, 2, 3])
            utterance_encodings = tile_batch(utterance_encodings,
                                             multiplier=hp.beam_width)
            utterance_encodings = tf.transpose(utterance_encodings,
                                               perm=[1, 0, 2, 3])
            context_utterance_length_t = tf.transpose(
                context_utterance_length_t, perm=[1, 0])
            context_utterance_length_t = tile_batch(context_utterance_length_t,
                                                    multiplier=hp.beam_width)
            context_utterance_length_t = tf.transpose(
                context_utterance_length_t, perm=[1, 0])
            context_length = tile_batch(context_length,
                                        multiplier=hp.beam_width)
        attention_mechanism = ContextAttentionMechanism(
            context_attn_units=hp.context_attn_units,
            utte_attn_units=hp.utte_attn_units,
            context=utterance_encodings,
            context_utterance_length=context_utterance_length_t,
            max_context_length=hp.max_context_length,
            context_rnn_num_units=hp.context_rnn_num_units,
            context_actual_length=context_length)

    with tf.variable_scope('decoder_layer', reuse=tf.AUTO_REUSE):
        kernel_initializer = tf.random_normal_initializer(
            mean=0.0, stddev=0.1, seed=random_seed + 3)
        bias_initializer = tf.zeros_initializer()
        decoder_cell = tf.nn.rnn_cell.GRUCell(
            num_units=hp.decoder_rnn_num_units,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer)
        attn_cell = AttentionWrapper(
            decoder_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=None,
            output_attention=False)  # output_attention should be False
        output_layer = layers_core.Dense(
            units=hp.vocab_size, activation=None,
            use_bias=False)  # should use no activation and no bias

        if mode == modekeys.TRAIN:
            sequence_length = tf.constant(value=hp.max_sentence_length,
                                          dtype=tf.int32,
                                          shape=[batch_size])
            helper = TrainingHelper(inputs=response_in,
                                    sequence_length=sequence_length)
            decoder = BasicDecoder(cell=attn_cell,
                                   helper=helper,
                                   initial_state=attn_cell.zero_state(
                                       batch_size, tf.float32),
                                   output_layer=output_layer)
            final_outputs, _, _ = dynamic_decode(decoder=decoder,
                                                 impute_finished=True,
                                                 parallel_iterations=32,
                                                 swap_memory=True)
            logits = final_outputs.rnn_output
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=response_out, logits=logits)
            cross_entropy = tf.multiply(cross_entropy, response_mask)
            cross_entropy = tf.reduce_sum(cross_entropy, axis=1)
            loss = tf.reduce_mean(cross_entropy)
            l2_norm = hp.lambda_l2 * tf.add_n([
                tf.nn.l2_loss(var)
                for var in tf.trainable_variables()
                if 'bias' not in var.name
            ])
            loss = loss + l2_norm
            debug_tensors = []
            return loss, debug_tensors

        elif mode == modekeys.EVAL:
            sequence_length = tf.constant(value=hp.max_sentence_length,
                                          dtype=tf.int32,
                                          shape=[batch_size])
            helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=response_in, sequence_length=sequence_length)
            decoder = BasicDecoder(cell=attn_cell,
                                   helper=helper,
                                   initial_state=attn_cell.zero_state(
                                       batch_size, tf.float32),
                                   output_layer=output_layer)
            final_outputs, _, _ = dynamic_decode(decoder=decoder,
                                                 impute_finished=True,
                                                 parallel_iterations=32,
                                                 swap_memory=True)
            logits = final_outputs.rnn_output
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=response_out, logits=logits)
            # Average only over unmasked target positions. A plain
            # reduce_mean over (batch, time) would also count padded
            # positions and bias the perplexity downward. Assumes
            # response_mask is a 0/1 padding mask; the max() guards an
            # all-padding batch.
            total_cross_entropy = tf.reduce_sum(cross_entropy * response_mask)
            num_tokens = tf.maximum(tf.reduce_sum(response_mask), 1.0)
            ppl = tf.exp(total_cross_entropy / num_tokens)
            return ppl

        elif mode == modekeys.PREDICT:
            if hp.beam_width == 0:
                # Greedy decoding; token 1 starts and token 2 ends a response.
                helper = GreedyEmbeddingHelper(embedding=embedding_w,
                                               start_tokens=tf.constant(
                                                   1, tf.int32,
                                                   shape=[batch_size]),
                                               end_token=2)
                initial_state = attn_cell.zero_state(batch_size=batch_size,
                                                     dtype=tf.float32)
                decoder = BasicDecoder(cell=attn_cell,
                                       helper=helper,
                                       initial_state=initial_state,
                                       output_layer=output_layer)
                final_outputs, _, final_sequence_lengths = dynamic_decode(
                    decoder, maximum_iterations=hp.max_sentence_length)
                results = {}
                results['response_ids'] = final_outputs.sample_id
                results['response_lens'] = final_sequence_lengths
                return results
            else:
                # Beam search: the decoder state covers batch * beam rows,
                # while start_tokens use the un-tiled batch size.
                decoder_initial_state = attn_cell.zero_state(
                    batch_size=batch_size * hp.beam_width, dtype=tf.float32)
                decoder = BeamSearchDecoder(
                    cell=attn_cell,
                    embedding=embedding_w,
                    start_tokens=tf.constant(1, tf.int32, shape=[batch_size]),
                    end_token=2,
                    initial_state=decoder_initial_state,
                    beam_width=hp.beam_width,
                    output_layer=output_layer)
                final_outputs, _, _ = dynamic_decode(
                    decoder,
                    impute_finished=False,
                    maximum_iterations=hp.max_sentence_length)
                final_outputs = final_outputs.predicted_ids  # b,s,beam_width
                final_outputs = tf.transpose(final_outputs,
                                             perm=[0, 2, 1])  # b,beam_width,s
                results = {}
                results['response_ids'] = final_outputs
                results['response_lens'] = None
                return results
def _build_decoder(self, encoder_outputs, encoder_state):
    """Build the attention decoder and decode for the current mode.

    In TRAIN mode the decoder is teacher-forced on ``self.iterator.target_in``
    for ``self.iterator.target_length`` steps. Otherwise it decodes from the
    target SOS token, greedily when ``beam_width <= 1`` and with beam search
    when ``beam_width > 1``, for at most ``round(2 * max(source_length))``
    steps.

    Args:
      encoder_outputs: encoder outputs, used as the attention memory.
      encoder_state: encoder final state. It is tiled for beam search but is
        not currently used to initialise the decoder (see the note below).

    Returns:
      ``(output_ids, outputs)``: for TRAIN and greedy decoding, the decoder's
      ``sample_id`` and ``rnn_output``; for beam search, ``predicted_ids`` and
      the beam ``scores``.
    """
    with tf.name_scope("seq_decoder"):
        is_training = self.mode == tf.contrib.learn.ModeKeys.TRAIN
        use_beam_search = (not is_training) and self.beam_width > 1
        batch_size = self.batch_size

        # The attention memory is the encoder output, so it is always masked
        # with the *source* lengths. Target lengths are used only by the
        # TrainingHelper below.
        source_sequence_length = self.iterator.source_length

        if use_beam_search:
            # BeamSearchDecoder works on batch_size * beam_width rows, so every
            # batch-major tensor it touches is tiled to match.
            batch_size = batch_size * self.beam_width
            encoder_outputs = beam_search_decoder.tile_batch(
                encoder_outputs, multiplier=self.beam_width)
            encoder_state = nest.map_structure(
                lambda s: beam_search_decoder.tile_batch(s, self.beam_width),
                encoder_state)
            source_sequence_length = beam_search_decoder.tile_batch(
                source_sequence_length, multiplier=self.beam_width)

        # Build a separate cell per layer: passing the same cell object to
        # MultiRNNCell for every layer would tie all layers' weights.
        decoder_cell = MultiRNNCell([
            single_rnn_cell(self.hparams.unit_type, self.num_units,
                            self.dropout)
            for _ in range(self.num_layers_decoder)
        ])
        decoder_cell = InputProjectionWrapper(decoder_cell,
                                              num_proj=self.num_units)
        attention_mechanism = create_attention_mechanism(
            self.hparams.attention_mechanism,
            self.num_units,
            memory=encoder_outputs,
            source_sequence_length=source_sequence_length)
        decoder_cell = wrapper.AttentionWrapper(
            decoder_cell,
            attention_mechanism,
            attention_layer_size=self.num_units,
            output_attention=True,
            alignment_history=False)

        # NOTE(review): the decoder starts from an all-zero state; encoder_state
        # is not fed in. An earlier comment intended to set the
        # AttentionWrapperState's cell_state to the encoder state
        # (``.clone(cell_state=encoder_state)``); only do that after confirming
        # the encoder and decoder state structures match.
        initial_state = decoder_cell.zero_state(batch_size=batch_size,
                                                dtype=tf.float32)

        embeddings_decoder = tf.get_variable(
            "embedding_decoder", [self.num_decoder_symbols, self.num_units],
            initializer=self.initializer,
            dtype=tf.float32)

        output_layer = Dense(units=self.num_decoder_symbols,
                             use_bias=True,
                             name="output_layer")

        if is_training:
            decoder_inputs = tf.nn.embedding_lookup(embeddings_decoder,
                                                    self.iterator.target_in)
            decoder_helper = helper.TrainingHelper(
                decoder_inputs, sequence_length=self.iterator.target_length)
            dec = basic_decoder.BasicDecoder(decoder_cell,
                                             decoder_helper,
                                             initial_state,
                                             output_layer=output_layer)
            final_outputs, _, _ = decoder.dynamic_decode(dec)
        else:

            def embedding_fn(inputs):
                return tf.nn.embedding_lookup(embeddings_decoder, inputs)

            # Cap decoding at twice the longest source sentence in the batch.
            decoding_length_factor = 2.0
            max_encoder_length = tf.reduce_max(self.iterator.source_length)
            maximum_iterations = tf.to_int32(
                tf.round(
                    tf.to_float(max_encoder_length) * decoding_length_factor))

            tgt_sos_id = tf.cast(
                self.tgt_vocab_table.lookup(tf.constant(self.hparams.sos)),
                tf.int32)
            tgt_eos_id = tf.cast(
                self.tgt_vocab_table.lookup(tf.constant(self.hparams.eos)),
                tf.int32)
            # Both decoders take one start token per *un-tiled* example.
            start_tokens = tf.fill([self.batch_size], tgt_sos_id)
            end_token = tgt_eos_id

            if use_beam_search:
                dec = beam_search_decoder.BeamSearchDecoder(
                    cell=decoder_cell,
                    embedding=embedding_fn,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    initial_state=initial_state,
                    output_layer=output_layer,
                    beam_width=self.beam_width)
            else:
                decoder_helper = helper.GreedyEmbeddingHelper(
                    embedding=embedding_fn,
                    start_tokens=start_tokens,
                    end_token=end_token)
                dec = basic_decoder.BasicDecoder(decoder_cell,
                                                 decoder_helper,
                                                 initial_state,
                                                 output_layer=output_layer)

            final_outputs, _, _ = decoder.dynamic_decode(
                dec,
                # swap_memory=True,
                maximum_iterations=maximum_iterations)

        if use_beam_search:
            output_ids = final_outputs.predicted_ids
            outputs = final_outputs.beam_search_decoder_output.scores
        else:
            output_ids = final_outputs.sample_id
            outputs = final_outputs.rnn_output
        return output_ids, outputs
def _create_decoder_cell(self):
    """Build the attention decoder cell together with its initial state.

    When beam search is on, the encoder outputs, states and lengths are tiled
    ``cfg.beam_size`` times so they line up with the expanded decoder batch.
    The attention mechanism is Luong when ``cfg.attention == "luong"`` and
    Bahdanau otherwise.

    With ``cfg.top_attention`` only the last of ``cfg.num_layers`` RNN layers
    is wrapped in attention. The lower layers start from the matching entries
    of the encoder states, and the top layer starts from the wrapper's zero
    state, which is built with ``initial_cell_state`` set to the last encoder
    state. Otherwise the whole stack is wrapped and the encoder states are
    cloned into the wrapper state.

    Returns:
      ``(cells, dec_init_states)``: the decoder cell and its initial state.
    """
    memory = self.enc_outputs
    encoder_states = self.enc_states
    memory_lengths = self.enc_seq_len

    if self.use_beam_search:
        beam_size = self.cfg.beam_size
        memory = tile_batch(memory, multiplier=beam_size)
        encoder_states = nest.map_structure(
            lambda s: tile_batch(s, beam_size), encoder_states)
        memory_lengths = tile_batch(self.enc_seq_len, multiplier=beam_size)
        batch_size = self.batch_size * beam_size
    else:
        batch_size = self.batch_size

    with tf.variable_scope("attention"):
        attention_cls = (LuongAttention if self.cfg.attention == "luong" else
                         BahdanauAttention)
        attention_mechanism = attention_cls(
            num_units=self.cfg.num_units,
            memory=memory,
            memory_sequence_length=memory_lengths)

    def cell_input_fn(inputs, attention):
        # With input feeding, concatenate the previous attention onto the
        # inputs and project back to num_units; otherwise pass inputs through.
        # reference: https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/AttentionWrapper
        if self.cfg.use_attention_input_feeding:
            feeding_layer = tf.layers.Dense(self.cfg.num_units,
                                            dtype=tf.float32,
                                            name='attn_input_feeding')
            return feeding_layer(tf.concat([inputs, attention], axis=-1))
        return inputs

    if not self.cfg.top_attention:
        stacked = MultiRNNCell(
            [self._create_rnn_cell() for _ in range(self.cfg.num_layers)])
        wrapped = AttentionWrapper(stacked,
                                   attention_mechanism=attention_mechanism,
                                   name="Attention_Wrapper",
                                   attention_layer_size=self.cfg.num_units,
                                   initial_cell_state=encoder_states,
                                   cell_input_fn=cell_input_fn)
        start_state = wrapped.zero_state(
            batch_size=batch_size,
            dtype=tf.float32).clone(cell_state=encoder_states)
        return wrapped, start_state

    # Attention only on the top decoder layer.
    layer_cells = [self._create_rnn_cell() for _ in range(self.cfg.num_layers)]
    layer_cells[-1] = AttentionWrapper(layer_cells[-1],
                                       attention_mechanism=attention_mechanism,
                                       name="Attention_Wrapper",
                                       attention_layer_size=self.cfg.num_units,
                                       initial_cell_state=encoder_states[-1],
                                       cell_input_fn=cell_input_fn)
    layer_states = list(encoder_states)
    layer_states[-1] = layer_cells[-1].zero_state(batch_size=batch_size,
                                                  dtype=tf.float32)
    start_state = tuple(layer_states)
    stacked = MultiRNNCell(layer_cells)
    return stacked, start_state