def _encode(self, input_dict):
    """Embed the source tokens and run them through the encoder RNN.

    :param input_dict: a Python dictionary. Must define:
        * src_sequence - source token ids, looked up in the encoder
          embedding matrix (presumably [batch_size, time], or
          [time, batch_size] when the ``time_major`` param is set -
          TODO confirm against the data layer)
        * src_length - source lengths, forwarded to ``tf.nn.dynamic_rnn``
          as ``sequence_length``
    :return: a Python dictionary with:
        * outputs - the outputs returned by ``tf.nn.dynamic_rnn``
        * state - the final state returned by ``tf.nn.dynamic_rnn``
        * src_lengths - the same reference as ``input_dict['src_length']``
        * encoder_input - the same reference as
          ``input_dict['src_sequence']``
    """
    # TODO: make a separate level of config for cell_params?
    rnn_cell_params = copy.deepcopy(self.params)
    rnn_cell_params["num_units"] = self.params['encoder_cell_units']

    self._enc_emb_w = tf.get_variable(
        name="EncoderEmbeddingMatrix",
        shape=[self._src_vocab_size, self._src_emb_size],
        dtype=tf.float32)

    # Dropout keep-probabilities are only read from the config while
    # training; in every other mode nothing is dropped.
    is_training = self._mode == "train"
    input_keep = (
        self.params['encoder_dp_input_keep_prob'] if is_training else 1.0)
    output_keep = (
        self.params['encoder_dp_output_keep_prob'] if is_training else 1.0)

    self._encoder_cell_fw = create_rnn_cell(
        cell_type=self.params['encoder_cell_type'],
        cell_params=rnn_cell_params,
        num_layers=self.params['encoder_layers'],
        dp_input_keep_prob=input_keep,
        dp_output_keep_prob=output_keep,
        residual_connections=self.params['encoder_use_skip_connections'],
    )

    rnn_options = {
        'time_major': self.params.get("time_major", False),
        'swap_memory': self.params.get("use_swap_memory", False),
    }

    # NOTE(review): the matrix is read through ``enc_emb_w`` rather than
    # ``_enc_emb_w``; presumably a property defined elsewhere on the class.
    token_vectors = tf.nn.embedding_lookup(
        self.enc_emb_w, input_dict['src_sequence'])
    embedded_inputs = tf.cast(token_vectors, self.params['dtype'])

    source_lengths = input_dict['src_length']
    rnn_outputs, rnn_state = tf.nn.dynamic_rnn(
        cell=self._encoder_cell_fw,
        inputs=embedded_inputs,
        sequence_length=source_lengths,
        dtype=embedded_inputs.dtype,
        **rnn_options
    )

    result = {}
    result['outputs'] = rnn_outputs
    result['state'] = rnn_state
    result['src_lengths'] = input_dict['src_length']
    result['encoder_input'] = input_dict['src_sequence']
    return result
def _decode(self, input_dict):
    """Decode the encoder representation with beam search.

    No target tensors are read; decoding runs for at most twice the
    longest source length in the batch.

    :param input_dict: Python dictionary with inputs to decoder.
        Must define:
        * encoder_output - a dictionary that provides
            * outputs - encoder outputs, tiled ``beam_width`` times and
              handed to ``self._build_attention``
            * src_lengths - source lengths, tiled the same way
    :return: a Python dictionary with:
        * logits - ``predicted_ids[:, :, 0]`` of the decoder output, i.e.
          token ids rather than real logits (index 0 of the last axis is
          presumably the top-scoring beam - TODO confirm against
          ``BeamSearchDecoder``)
        * samples - the same slice as ``logits``
        * final_state - final state returned by ``dynamic_decode``
        * final_sequence_lengths - sequence lengths returned by
          ``dynamic_decode``
    """
    encoder_result = input_dict['encoder_output']
    encoder_outputs = encoder_result['outputs']
    enc_src_lengths = encoder_result['src_lengths']

    self._dec_emb_w = tf.get_variable(
        name='DecoderEmbeddingMatrix',
        shape=[self._tgt_vocab_size, self._tgt_emb_size],
        dtype=tf.float32
    )
    self._output_projection_layer = tf.layers.Dense(
        self._tgt_vocab_size, use_bias=False,
    )

    rnn_cell_params = copy.deepcopy(self.params)
    rnn_cell_params["num_units"] = self.params['decoder_cell_units']

    # Dropout keep-probabilities are only read from the config in train mode.
    is_training = self._mode == "train"
    input_keep = (
        self.params['decoder_dp_input_keep_prob'] if is_training else 1.0)
    output_keep = (
        self.params['decoder_dp_output_keep_prob'] if is_training else 1.0)

    # For GNMT-style attention the cells stay a plain list without residual
    # connections: the first cell is popped off below and the remaining ones
    # go through ``self._add_residual_wrapper``.
    uses_gnmt_attention = self.params['attention_type'].startswith('gnmt')
    if uses_gnmt_attention:
        use_residual = False
    else:
        use_residual = self.params['decoder_use_skip_connections']

    self._decoder_cells = create_rnn_cell(
        cell_type=self.params['decoder_cell_type'],
        cell_params=rnn_cell_params,
        num_layers=self.params['decoder_layers'],
        dp_input_keep_prob=input_keep,
        dp_output_keep_prob=output_keep,
        residual_connections=use_residual,
        wrap_to_multi_rnn=not uses_gnmt_attention,
    )

    # Repeat every batch entry ``beam_width`` times for the beam search.
    beam_memory = tf.contrib.seq2seq.tile_batch(
        encoder_outputs, multiplier=self._beam_width,
    )
    beam_memory_lengths = tf.contrib.seq2seq.tile_batch(
        enc_src_lengths, multiplier=self._beam_width,
    )
    attention_mechanism = self._build_attention(
        beam_memory, beam_memory_lengths,
    )

    if uses_gnmt_attention:
        bottom_cell = AttentionWrapper(
            self._decoder_cells.pop(0),
            attention_mechanism=attention_mechanism,
            attention_layer_size=None,  # don't use attention layer.
            output_attention=False,
            name="gnmt_attention")
        attentive_decoder_cell = GNMTAttentionMultiCell(
            bottom_cell,
            self._add_residual_wrapper(self._decoder_cells),
            use_new_attention=(self.params['attention_type'] == 'gnmt_v2'))
    else:
        attentive_decoder_cell = AttentionWrapper(
            cell=self._decoder_cells,
            attention_mechanism=attention_mechanism,
        )

    batch_size_tensor = tf.constant(self._batch_size)

    def embedding_fn(ids):
        # Look up decoder embeddings and cast them to the model dtype.
        return tf.cast(
            tf.nn.embedding_lookup(self._dec_emb_w, ids),
            dtype=self.params['dtype'])

    start_tokens = tf.tile([self.GO_SYMBOL], [self._batch_size])
    beam_batch_size = batch_size_tensor * self._beam_width
    initial_state = attentive_decoder_cell.zero_state(
        dtype=encoder_outputs.dtype,
        batch_size=beam_batch_size,
    )
    decoder = BeamSearchDecoder(
        cell=attentive_decoder_cell,
        embedding=embedding_fn,
        start_tokens=start_tokens,
        end_token=self.END_SYMBOL,
        initial_state=initial_state,
        beam_width=self._beam_width,
        output_layer=self._output_projection_layer,
        length_penalty_weight=self._length_penalty_weight
    )

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)
    max_decode_steps = tf.reduce_max(enc_src_lengths) * 2
    final_outputs, final_state, final_sequence_lengths = \
        tf.contrib.seq2seq.dynamic_decode(
            decoder=decoder,
            maximum_iterations=max_decode_steps,
            swap_memory=use_swap_memory,
            output_time_major=time_major,
        )

    predicted_ids = final_outputs.predicted_ids
    return {
        'logits': predicted_ids[:, :, 0],
        'samples': predicted_ids[:, :, 0],
        'final_state': final_state,
        'final_sequence_lengths': final_sequence_lengths,
    }
def _encode(self, input_dict):
    """Encode the source with one bidirectional layer plus a unidirectional stack.

    The first layer is a bidirectional RNN built without dropout or
    residual connections. Its forward and backward outputs are concatenated
    along axis 2 and fed to ``encoder_layers - 1`` unidirectional cells.
    Those cells use the configured dropout in train mode, and every one
    except the first is wrapped in a ``ResidualWrapper``.

    :param input_dict: a Python dictionary. Must define:
        * src_sequence - source token ids, looked up in the encoder
          embedding matrix
        * src_length - source lengths, used as ``sequence_length`` for
          both RNN passes
    :return: a Python dictionary with:
        * outputs - outputs of the unidirectional stack
        * state - final state of the unidirectional stack
        * src_lengths - the same reference as ``input_dict['src_length']``
        * encoder_input - the same reference as
          ``input_dict['src_sequence']``
    :raises ValueError: if the ``encoder_layers`` param is smaller than 2
    """
    self._enc_emb_w = tf.get_variable(
        name="EncoderEmbeddingMatrix",
        shape=[self._src_vocab_size, self._src_emb_size],
        dtype=tf.float32)

    if self.params['encoder_layers'] < 2:
        raise ValueError("GNMT encoder must have at least 2 layers")

    rnn_cell_params = copy.deepcopy(self.params)
    rnn_cell_params["num_units"] = self.params['encoder_cell_units']

    # Both directions of the first layer are built from the same settings:
    # a single layer, no dropout, no residual connections.
    first_layer_kwargs = dict(
        cell_type=self.params['encoder_cell_type'],
        cell_params=rnn_cell_params,
        num_layers=1,
        dp_input_keep_prob=1.0,
        dp_output_keep_prob=1.0,
        residual_connections=False,
    )
    with tf.variable_scope("Level1FW"):
        self._encoder_l1_cell_fw = create_rnn_cell(**first_layer_kwargs)
    with tf.variable_scope("Level1BW"):
        self._encoder_l1_cell_bw = create_rnn_cell(**first_layer_kwargs)

    # Dropout keep-probabilities are only read from the config in train mode.
    is_training = self._mode == "train"
    input_keep = (
        self.params['encoder_dp_input_keep_prob'] if is_training else 1.0)
    output_keep = (
        self.params['encoder_dp_output_keep_prob'] if is_training else 1.0)

    with tf.variable_scope("UniDirLevel"):
        self._encoder_cells = create_rnn_cell(
            cell_type=self.params['encoder_cell_type'],
            cell_params=rnn_cell_params,
            num_layers=self.params['encoder_layers'] - 1,
            dp_input_keep_prob=input_keep,
            dp_output_keep_prob=output_keep,
            residual_connections=False,
            wrap_to_multi_rnn=False,
        )
        # add residual connections starting from the third layer
        # (index 0 of this list is the second encoder layer overall)
        for layer_idx in range(1, len(self._encoder_cells)):
            self._encoder_cells[layer_idx] = tf.contrib.rnn.ResidualWrapper(
                self._encoder_cells[layer_idx])

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)

    # NOTE(review): the matrix is read through ``enc_emb_w`` rather than
    # ``_enc_emb_w``; presumably a property defined elsewhere on the class.
    token_vectors = tf.nn.embedding_lookup(
        self.enc_emb_w, input_dict['src_sequence'])
    embedded_inputs = tf.cast(token_vectors, self.params['dtype'])

    # first bi-directional layer
    bidir_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=self._encoder_l1_cell_fw,
        cell_bw=self._encoder_l1_cell_bw,
        inputs=embedded_inputs,
        sequence_length=input_dict['src_length'],
        swap_memory=use_swap_memory,
        time_major=time_major,
        dtype=embedded_inputs.dtype,
    )
    encoder_l1_outputs = tf.concat(bidir_outputs, 2)

    # stack of unidirectional layers
    unidir_stack = tf.contrib.rnn.MultiRNNCell(self._encoder_cells)
    stack_outputs, stack_state = tf.nn.dynamic_rnn(
        cell=unidir_stack,
        inputs=encoder_l1_outputs,
        sequence_length=input_dict['src_length'],
        swap_memory=use_swap_memory,
        time_major=time_major,
        dtype=encoder_l1_outputs.dtype,
    )

    result = {}
    result['outputs'] = stack_outputs
    result['state'] = stack_state
    result['src_lengths'] = input_dict['src_length']
    result['encoder_input'] = input_dict['src_sequence']
    return result
def _decode(self, input_dict):
    """Decode the encoder representation with an attention RNN decoder.

    In "train" mode the decoder is fed the embedded target sequence through
    a ``TrainingHelper`` and runs for at most ``max(tgt_length)`` steps.
    In "infer" and "eval" modes it decodes greedily through a
    ``GreedyEmbeddingHelper`` for at most twice the longest source length.

    :param input_dict: Python dictionary with inputs to decoder.
        Must define:
        * encoder_output - a dictionary that provides
            * outputs - encoder outputs handed to ``self._build_attention``
            * src_lengths - source lengths
        * tgt_sequence - Only during training. Target token ids, looked up
          in the decoder embedding matrix
        * tgt_length - Only during training. Target lengths
    :return: a Python dictionary with:
        * logits - ``rnn_output`` of the decoder (after the output
          projection layer)
        * samples - argmax of ``logits`` over the last axis
        * final_state - final state returned by ``dynamic_decode``
        * final_sequence_lengths - sequence lengths returned by
          ``dynamic_decode``
    :raises ValueError: if the mode is not "train", "infer" or "eval"
    """
    encoder_outputs = input_dict['encoder_output']['outputs']
    enc_src_lengths = input_dict['encoder_output']['src_lengths']

    # Targets are only consumed by the training branch below, so they are
    # only required in train mode. Greedy infer/eval decoding must not fail
    # with a KeyError when no targets are supplied.
    is_training = self._mode == "train"
    tgt_inputs = input_dict['tgt_sequence'] if is_training else None
    tgt_lengths = input_dict['tgt_length'] if is_training else None

    self._dec_emb_w = tf.get_variable(
        name='DecoderEmbeddingMatrix',
        shape=[self._tgt_vocab_size, self._tgt_emb_size],
        dtype=tf.float32,
    )
    self._output_projection_layer = tf.layers.Dense(
        self._tgt_vocab_size, use_bias=False,
    )

    cell_params = copy.deepcopy(self.params)
    cell_params["num_units"] = self.params['decoder_cell_units']

    # Dropout keep-probabilities are only read from the config in train mode.
    if is_training:
        dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
        dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
    else:
        dp_input_keep_prob = 1.0
        dp_output_keep_prob = 1.0

    # For GNMT-style attention the cells stay a plain list without residual
    # connections: the first cell is popped off below and the remaining ones
    # go through ``self._add_residual_wrapper``.
    uses_gnmt_attention = self.params['attention_type'].startswith('gnmt')
    if uses_gnmt_attention:
        residual_connections = False
        wrap_to_multi_rnn = False
    else:
        residual_connections = self.params['decoder_use_skip_connections']
        wrap_to_multi_rnn = True

    self._decoder_cells = create_rnn_cell(
        cell_type=self.params['decoder_cell_type'],
        cell_params=cell_params,
        num_layers=self.params['decoder_layers'],
        dp_input_keep_prob=dp_input_keep_prob,
        dp_output_keep_prob=dp_output_keep_prob,
        residual_connections=residual_connections,
        wrap_to_multi_rnn=wrap_to_multi_rnn,
    )

    attention_mechanism = self._build_attention(
        encoder_outputs,
        enc_src_lengths,
    )
    if uses_gnmt_attention:
        attention_cell = self._decoder_cells.pop(0)
        attention_cell = AttentionWrapper(
            attention_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=None,
            output_attention=False,
            name="gnmt_attention")
        attentive_decoder_cell = GNMTAttentionMultiCell(
            attention_cell,
            self._add_residual_wrapper(self._decoder_cells),
            use_new_attention=(self.params['attention_type'] == 'gnmt_v2'))
    else:
        attentive_decoder_cell = AttentionWrapper(
            cell=self._decoder_cells,
            attention_mechanism=attention_mechanism,
        )

    if is_training:
        # Teacher forcing: feed the embedded ground-truth target tokens.
        input_vectors = tf.cast(
            tf.nn.embedding_lookup(self._dec_emb_w, tgt_inputs),
            dtype=self.params['dtype'])
        helper = tf.contrib.seq2seq.TrainingHelper(
            inputs=input_vectors,
            sequence_length=tgt_lengths)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=attentive_decoder_cell,
            helper=helper,
            output_layer=self._output_projection_layer,
            initial_state=attentive_decoder_cell.zero_state(
                self._batch_size, dtype=encoder_outputs.dtype,
            ),
        )
    elif self._mode in ("infer", "eval"):
        # Greedy decoding: the helper embeds the ids it is given through
        # this callable, which also casts to the model dtype.
        embedding_fn = lambda ids: tf.cast(
            tf.nn.embedding_lookup(self._dec_emb_w, ids),
            dtype=self.params['dtype'])
        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding=embedding_fn,
            start_tokens=tf.fill([self._batch_size], self.GO_SYMBOL),
            end_token=self.END_SYMBOL)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=attentive_decoder_cell,
            helper=helper,
            initial_state=attentive_decoder_cell.zero_state(
                batch_size=self._batch_size, dtype=encoder_outputs.dtype,
            ),
            output_layer=self._output_projection_layer,
        )
    else:
        raise ValueError(
            "Unknown mode for decoder: {}".format(self._mode)
        )

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)
    if is_training:
        maximum_iterations = tf.reduce_max(tgt_lengths)
    else:
        # No target lengths outside training: cap the decode at twice the
        # longest source sequence in the batch.
        maximum_iterations = tf.reduce_max(enc_src_lengths) * 2

    final_outputs, final_state, final_sequence_lengths = \
        tf.contrib.seq2seq.dynamic_decode(
            decoder=decoder,
            impute_finished=True,
            maximum_iterations=maximum_iterations,
            swap_memory=use_swap_memory,
            output_time_major=time_major,
        )

    return {
        'logits': final_outputs.rnn_output,
        'samples': tf.argmax(final_outputs.rnn_output, axis=-1),
        'final_state': final_state,
        'final_sequence_lengths': final_sequence_lengths,
    }