# Imports reconstructed for this listing (assumes TensorFlow 1.x; the
# original file's import block is not shown).
import tensorflow as tf
import tensorflow.contrib as tf_contrib
from tensorflow.contrib.rnn import LSTMStateTuple


def reader(encoder_inputs, source_sequence_length, mode, hparams, target_vocab_size):
    # Unlike `listener` below, the reader takes integer symbol ids and
    # one-hot encodes them before the recurrent stack.
    encoder_features = tf.one_hot(encoder_inputs, target_vocab_size)

    forward_cell_list, backward_cell_list = [], []
    for layer in range(hparams.num_layers):
        with tf.variable_scope('fw_cell_{}'.format(layer)):
            cell = lstm_cell(hparams.num_units, hparams.dropout, mode)
            forward_cell_list.append(cell)

        with tf.variable_scope('bw_cell_{}'.format(layer)):
            cell = lstm_cell(hparams.num_units, hparams.dropout, mode)
            backward_cell_list.append(cell)

    forward_cell = tf.nn.rnn_cell.MultiRNNCell(forward_cell_list)
    backward_cell = tf.nn.rnn_cell.MultiRNNCell(backward_cell_list)

    encoder_outputs, encoder_state = tf.nn.bidirectional_dynamic_rnn(
        forward_cell, backward_cell, encoder_features,
        sequence_length=source_sequence_length, dtype=tf.float32)

    # Concatenate the forward and backward outputs along the feature axis.
    encoder_outputs = tf.concat(encoder_outputs, -1)

    return (encoder_outputs, source_sequence_length), encoder_state
def listener(encoder_inputs, source_sequence_length, mode, hparams):
    # The listener consumes real-valued acoustic features directly.
    # Optionally it uses a pyramidal BLSTM stack (as in LAS, each layer
    # halves the time resolution).
    if hparams.use_pyramidal:
        return pyramidal_bilstm(encoder_inputs, source_sequence_length, mode, hparams)
    else:
        forward_cell_list, backward_cell_list = [], []
        for layer in range(hparams.num_layers):
            with tf.variable_scope('fw_cell_{}'.format(layer)):
                cell = lstm_cell(hparams.num_units, hparams.dropout, mode)
                forward_cell_list.append(cell)

            with tf.variable_scope('bw_cell_{}'.format(layer)):
                cell = lstm_cell(hparams.num_units, hparams.dropout, mode)
                backward_cell_list.append(cell)

        forward_cell = tf.nn.rnn_cell.MultiRNNCell(forward_cell_list)
        backward_cell = tf.nn.rnn_cell.MultiRNNCell(backward_cell_list)

        encoder_outputs, encoder_state = tf.nn.bidirectional_dynamic_rnn(
            forward_cell, backward_cell, encoder_inputs,
            sequence_length=source_sequence_length, dtype=tf.float32)

        encoder_outputs = tf.concat(encoder_outputs, -1)

        return (encoder_outputs, source_sequence_length), encoder_state
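# Both encoders above call an `lstm_cell` helper (and the pyramidal path a
# `pyramidal_bilstm`) that this listing does not define. The sketch below is
# a plausible minimal reconstruction of `lstm_cell`, not the original: a
# single LSTM cell whose outputs are dropped out only during training.
def lstm_cell(num_units, dropout, mode):
    cell = tf.nn.rnn_cell.LSTMCell(num_units)
    if mode == tf.estimator.ModeKeys.TRAIN and dropout > 0.0:
        # Keep probability is the complement of the dropout rate.
        cell = tf.nn.rnn_cell.DropoutWrapper(
            cell, output_keep_prob=1.0 - dropout)
    return cell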
def attend(encoder_outputs, source_sequence_length, mode, hparams):
    memory = encoder_outputs
    att_kwargs = {}

    if hparams.attention_type == 'luong':
        attention_fn = tf_contrib.seq2seq.LuongAttention
    elif hparams.attention_type == 'bahdanau':
        attention_fn = tf_contrib.seq2seq.BahdanauAttention
    elif hparams.attention_type == 'luong_monotonic':
        attention_fn = tf_contrib.seq2seq.LuongMonotonicAttention
    elif hparams.attention_type == 'bahdanau_monotonic':
        attention_fn = tf_contrib.seq2seq.BahdanauMonotonicAttention
        # Monotonic attention: inject sigmoid noise while training, switch
        # to hard (deterministic) attention at inference time.
        if mode == tf.estimator.ModeKeys.TRAIN:
            att_kwargs['sigmoid_noise'] = 1.0
        else:
            att_kwargs['mode'] = 'hard'
    elif hparams.attention_type == 'custom':
        attention_fn = CustomAttention

    attention_mechanism = attention_fn(
        hparams.num_units, memory, source_sequence_length, **att_kwargs)

    cell_list = []
    for layer in range(hparams.num_layers):
        with tf.variable_scope('decoder_cell_{}'.format(layer)):
            cell = lstm_cell(hparams.num_units, hparams.dropout, mode)
            cell_list.append(cell)

    # Keep the alignment history around everywhere but in training.
    alignment_history = (mode != tf.estimator.ModeKeys.TRAIN)

    # With binf projection enabled, the attention layer is sized to the
    # binary-feature count instead of the usual attention layer size.
    attention_layer_size = (hparams.binf_count if hparams.binf_projection
                            else hparams.attention_layer_size)

    if hparams.bottom_only:
        # LAS-style: only the bottom decoder cell attends; the remaining
        # cells are stacked on top of it.
        attention_cell = cell_list.pop(0)
        attention_cell = tf_contrib.seq2seq.AttentionWrapper(
            attention_cell, attention_mechanism,
            attention_layer_size=attention_layer_size,
            alignment_history=alignment_history)
        decoder_cell = AttentionMultiCell(attention_cell, cell_list)
    else:
        decoder_cell = tf.nn.rnn_cell.MultiRNNCell(cell_list)
        decoder_cell = tf_contrib.seq2seq.AttentionWrapper(
            decoder_cell, attention_mechanism,
            attention_layer_size=attention_layer_size,
            alignment_history=alignment_history)

    return decoder_cell
# A second, simpler variant of attend, without the monotonic and binf options.
def attend(encoder_outputs, source_sequence_length, mode, hparams):
    memory = encoder_outputs

    if hparams.attention_type == 'luong':
        attention_fn = tf.contrib.seq2seq.LuongAttention
    elif hparams.attention_type == 'bahdanau':
        attention_fn = tf.contrib.seq2seq.BahdanauAttention
    elif hparams.attention_type == 'custom':
        attention_fn = CustomAttention

    attention_mechanism = attention_fn(
        hparams.num_units, memory, source_sequence_length)

    cell_list = []
    for layer in range(hparams.num_layers):
        with tf.variable_scope('decoder_cell_{}'.format(layer)):
            cell = lstm_cell(hparams.num_units, hparams.dropout, mode)
            cell_list.append(cell)

    alignment_history = (mode != tf.estimator.ModeKeys.TRAIN)

    if hparams.bottom_only:
        attention_cell = cell_list.pop(0)
        attention_cell = tf.contrib.seq2seq.AttentionWrapper(
            attention_cell, attention_mechanism,
            attention_layer_size=hparams.attention_layer_size,
            alignment_history=alignment_history)
        decoder_cell = AttentionMultiCell(attention_cell, cell_list)
    else:
        decoder_cell = tf.nn.rnn_cell.MultiRNNCell(cell_list)
        decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            decoder_cell, attention_mechanism,
            attention_layer_size=hparams.attention_layer_size,
            alignment_history=alignment_history)

    return decoder_cell
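# Usage sketch (an assumption, not from the original listing): the wrapped
# cell returned by `attend` needs an AttentionWrapperState as its initial
# state, and the usual pattern is to start from the wrapper's zero state.
# `batch_size` here stands for the (possibly beam-tiled) runtime batch size.
decoder_cell = attend(encoder_outputs, source_sequence_length, mode, hparams)
initial_state = decoder_cell.zero_state(batch_size, tf.float32)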
def speller(encoder_outputs, encoder_state, decoder_inputs,
            source_sequence_length, target_sequence_length,
            mode, hparams):
    batch_size = tf.shape(encoder_outputs)[0]
    beam_width = hparams.beam_width

    if mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
        # Tile the encoder results so every beam sees its own copy.
        source_sequence_length = tf.contrib.seq2seq.tile_batch(
            source_sequence_length, multiplier=beam_width)
        encoder_state = tf.contrib.seq2seq.tile_batch(
            encoder_state, multiplier=beam_width)
        batch_size = batch_size * beam_width

    def embedding_fn(ids):
        # pass a callable instead of a materialised matrix to avoid OOM
        # when using one-hot encoding
        if hparams.embedding_size != 0:
            target_embedding = tf.get_variable(
                'target_embedding',
                [hparams.target_vocab_size, hparams.embedding_size],
                dtype=tf.float32,
                initializer=tf.contrib.layers.xavier_initializer())
            return tf.nn.embedding_lookup(target_embedding, ids)
        else:
            return tf.one_hot(ids, hparams.target_vocab_size)

    cell_list = []
    for layer in range(hparams.num_layers):
        with tf.variable_scope('decoder_cell_{}'.format(layer)):
            # Twice the encoder width, so the concatenated bidirectional
            # state below fits without a projection.
            cell = lstm_cell(hparams.num_units * 2, hparams.dropout, mode)
            cell_list.append(cell)
    decoder_cell = tf.nn.rnn_cell.MultiRNNCell(cell_list)

    projection_layer = tf.layers.Dense(
        hparams.target_vocab_size, use_bias=True, name='projection_layer')

    # `encoder_state` is expected to hold per-layer (forward, backward)
    # LSTMStateTuple pairs; concatenate each pair into one decoder state.
    initial_state = tuple(
        LSTMStateTuple(c=tf.concat([es[0].c, es[1].c], axis=-1),
                       h=tf.concat([es[0].h, es[1].h], axis=-1))
        for es in encoder_state[-hparams.num_layers:])

    maximum_iterations = None
    if mode != tf.estimator.ModeKeys.TRAIN:
        # Cap decoding length at a multiple of the input length.
        max_source_length = tf.reduce_max(source_sequence_length)
        maximum_iterations = tf.to_int32(tf.round(
            tf.to_float(max_source_length) * hparams.decoding_length_factor))

    if mode == tf.estimator.ModeKeys.TRAIN:
        decoder_inputs = embedding_fn(decoder_inputs)
        if hparams.sampling_probability > 0.0:
            # Scheduled sampling: occasionally feed back model predictions
            # instead of the ground-truth symbols.
            helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                decoder_inputs, target_sequence_length,
                embedding_fn, hparams.sampling_probability)
        else:
            helper = tf.contrib.seq2seq.TrainingHelper(
                decoder_inputs, target_sequence_length)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            decoder_cell, helper, initial_state,
            output_layer=projection_layer)
    elif mode == tf.estimator.ModeKeys.PREDICT and beam_width > 0:
        # `batch_size` was multiplied by `beam_width` above; undo it here.
        start_tokens = tf.fill(
            [tf.div(batch_size, beam_width)], hparams.sos_id)
        decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=embedding_fn,
            start_tokens=start_tokens,
            end_token=hparams.eos_id,
            initial_state=initial_state,
            beam_width=beam_width,
            output_layer=projection_layer)
    else:
        start_tokens = tf.fill([batch_size], hparams.sos_id)
        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding_fn, start_tokens, hparams.eos_id)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            decoder_cell, helper, initial_state,
            output_layer=projection_layer)

    decoder_outputs, final_context_state, final_sequence_length = \
        tf.contrib.seq2seq.dynamic_decode(
            decoder, maximum_iterations=maximum_iterations)

    return decoder_outputs, final_context_state, final_sequence_length
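# Composition sketch (an assumption; the surrounding model_fn is not part of
# this listing). Inside a tf.estimator model_fn the pieces chain like this;
# the `features`/`labels` keys and the helper name are hypothetical.
def build_las_graph(features, labels, mode, hparams):
    # Encode the acoustic features.
    (encoder_outputs, source_sequence_length), encoder_state = listener(
        features['source'], features['source_length'], mode, hparams)
    # Decode; at PREDICT time there are no labels to feed.
    decoder_outputs, final_context_state, final_sequence_length = speller(
        encoder_outputs, encoder_state,
        labels['target_in'] if labels is not None else None,
        source_sequence_length,
        labels['target_length'] if labels is not None else None,
        mode, hparams)
    return decoder_outputs, final_context_state, final_sequence_length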