def _build_encoder_cell(self, hparams, base_gpu=0):
    """Construct the multi-layer RNN cell used by the encoder.

    Delegates to model_helper.create_rnn_cell, reading every setting
    (cell size, depth, dropout keep probability, GPU count) from
    `hparams`; `base_gpu` selects the first device to place layers on.
    """
    cell_kwargs = dict(
        num_units=hparams.num_units,
        num_layers=hparams.num_layers,
        keep_prob=hparams.keep_prob,
        num_gpus=hparams.num_gpus,
        base_gpu=base_gpu,
    )
    return model_helper.create_rnn_cell(**cell_kwargs)
def _build_encoder(self, embedding, input, num_units=512, num_layers=1,
                   keep_prob=0.9):
    """Embed the source token ids and run them through an RNN encoder.

    Args:
      embedding: embedding matrix to look the source token ids up in.
      input: int tensor of source token ids, batch-major on entry.
        NOTE(review): shadows the `input` builtin; name kept so existing
        keyword callers keep working.
      num_units: RNN cell size. Previously hard-coded to 512; now
        overridable so callers can pass hparams.num_units, matching how
        the sibling builder methods read their settings from hparams.
      num_layers: number of RNN layers (previously hard-coded to 1).
      keep_prob: dropout keep probability (previously hard-coded to 0.9).

    Returns:
      The (outputs, state) pair returned by tf.nn.dynamic_rnn.
    """
    if self.time_major:
        # dynamic_rnn with time_major=True expects [time, batch] inputs,
        # so transpose the id matrix before the embedding lookup.
        input = tf.transpose(input)
    encoder_emb_inp = tf.nn.embedding_lookup(embedding, input)
    cell = model_helper.create_rnn_cell(num_units, num_layers, keep_prob)
    return tf.nn.dynamic_rnn(
        cell,
        encoder_emb_inp,
        dtype=tf.float32,
        time_major=self.time_major)
def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                        source_sequence_length):
    """Build a RNN cell with attention mechanism that can be used by decoder."""
    num_units = hparams.num_units
    num_layers = hparams.num_layers
    num_gpus = hparams.num_gpus
    beam_width = hparams.beam_width
    dtype = tf.float32
    # Attention memory must be batch-major ([batch, time, units]);
    # undo the encoder's time-major layout if needed.
    if self.time_major:
        memory = tf.transpose(encoder_outputs, [1, 0, 2])
    else:
        memory = encoder_outputs
    if not self.training and beam_width > 0:
        # Beam-search inference: replicate memory, source lengths, and
        # encoder state `beam_width` times so each beam hypothesis
        # attends over its own copy; the effective batch size grows by
        # the same factor.
        memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
        source_sequence_length = tf.contrib.seq2seq.tile_batch(
            source_sequence_length, multiplier=beam_width)
        encoder_state = tf.contrib.seq2seq.tile_batch(
            encoder_state, multiplier=beam_width)
        batch_size = self.batch_size * beam_width
    else:
        # Training, or greedy (beam_width == 0) inference.
        batch_size = self.batch_size
    # memory_sequence_length masks attention scores past each source
    # sentence's true length.
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units, memory, memory_sequence_length=source_sequence_length)
    cell = model_helper.create_rnn_cell(num_units=num_units,
                                        num_layers=num_layers,
                                        keep_prob=hparams.keep_prob,
                                        num_gpus=num_gpus)
    # Only generate alignment in greedy INFER mode.
    alignment_history = (not self.training and beam_width == 0)
    cell = tf.contrib.seq2seq.AttentionWrapper(
        cell,
        attention_mechanism,
        attention_layer_size=num_units,
        alignment_history=alignment_history,
        name="attention")
    if self.training:
        # A bug in TensorFlow prevents this to work at inference time.
        # Therefore skipped.
        cell = tf.contrib.rnn.DeviceWrapper(
            cell, model_helper.get_device_str(num_layers - 1, num_gpus))
    if hparams.pass_hidden_state:
        # Seed the decoder with the encoder's final state while keeping
        # the AttentionWrapper's own zero-initialized attention fields.
        decoder_initial_state = cell.zero_state(
            batch_size, dtype).clone(cell_state=encoder_state)
    else:
        decoder_initial_state = cell.zero_state(batch_size, dtype)
    return cell, decoder_initial_state