def add_extra_summary(config, vocab_i2s, decoder_output, base_word, output_words, src_words, tgt_words,
                      inserted_words, deleted_words, oov, vocab_size, collections=None):
    pred_tokens = decoder.str_tokens(decoder_output, vocab_i2s, vocab_size, oov)
    pred_len = decoder.seq_length(decoder_output)

    ops = {}

    if config.get('logger.enable_trace', False):
        trace_summary = add_extra_summary_trace(pred_tokens, pred_len, base_word, output_words,
                                                src_words, tgt_words, inserted_words, deleted_words,
                                                vocab_i2s, collections)
        ops[ES_TRACE] = trace_summary

    if config.get('logger.enable_bleu', True):
        avg_bleu = add_extra_summary_avg_bleu(pred_tokens, pred_len, output_words, vocab_i2s, collections)
        ops[ES_BLEU] = avg_bleu

    return ops
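
# Illustrative sketch only: the ops dict returned above maps summary keys (e.g. ES_BLEU)
# to tensors, so it can be logged periodically with a standard tf.train.LoggingTensorHook.
# This is an assumption about how such a dict is typically consumed; it is not the repo's
# get_extra_summary_logger, and `every_n_iter` is an illustrative parameter.
def _extra_summary_logging_hook_sketch(ops, every_n_iter=100):
    tensors = {str(k): v for k, v in ops.items()}
    return tf.train.LoggingTensorHook(tensors=tensors, every_n_iter=every_n_iter)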
def model_fn(features, mode, config, embedding_matrix, vocab_tables):
    if mode == tf.estimator.ModeKeys.PREDICT:
        base_words, extended_base_words, \
        _, _, \
        src_words, tgt_words, \
        inserted_words, deleted_words, \
        oov = features
        output_words = extended_output_words = tgt_words
    else:
        base_words, extended_base_words, \
        output_words, extended_output_words, \
        src_words, tgt_words, \
        inserted_words, deleted_words, \
        oov = features

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    tf.add_to_collection('is_training', is_training)

    if mode != tf.estimator.ModeKeys.TRAIN:
        config.put('editor.enable_dropout', False)
        config.put('editor.dropout_keep', 1.0)

    vocab_i2s = vocab_tables[vocab.INT_TO_STR]
    vocab.init_embeddings(embedding_matrix)
    vocab_size = len(vocab_tables[vocab.RAW_WORD2ID])

    train_decoder_output, infer_decoder_output, \
    gold_dec_out, gold_dec_out_len = editor.editor_train(
        base_words, extended_base_words, output_words, extended_output_words,
        src_words, tgt_words, inserted_words, deleted_words, oov,
        vocab_size,
        config.editor.hidden_dim, config.editor.agenda_dim, config.editor.edit_dim,
        config.editor.edit_enc.micro_ev_dim, config.editor.edit_enc.num_heads,
        config.editor.encoder_layers, config.editor.decoder_layers, config.editor.attention_dim,
        config.editor.beam_width,
        config.editor.edit_enc.transformer,
        config.editor.edit_enc.wa_hidden_dim, config.editor.edit_enc.wa_hidden_layer,
        config.editor.edit_enc.meve_hidden_dim, config.editor.edit_enc.meve_hidden_layer,
        config.editor.max_sent_length, config.editor.dropout_keep, config.editor.lamb_reg,
        config.editor.norm_eps, config.editor.norm_max,
        config.editor.kill_edit, config.editor.draw_edit, config.editor.use_swap_memory,
        config.get('editor.use_beam_decoder', False),
        config.get('editor.enable_dropout', False),
        config.get('editor.no_insert_delete_attn', False),
        config.get('editor.enable_vae', True)
    )

    loss = optimizer.loss(train_decoder_output, gold_dec_out, gold_dec_out_len)
    train_op, gradients_norm = optimizer.train(loss, config.optim.learning_rate,
                                               config.optim.max_norm_observe_steps)

    tf.logging.info("Trainable variables:")
    for i in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        tf.logging.info(str(i))

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf.summary.scalar('grad_norm', gradients_norm)
        ops = add_extra_summary(config, vocab_i2s, train_decoder_output,
                                base_words, output_words, src_words, tgt_words,
                                inserted_words, deleted_words, oov, vocab_size,
                                collections=['extra'])

        hooks = [
            get_train_extra_summary_writer(config),
            get_extra_summary_logger(ops, config),
        ]

        if config.get('logger.enable_profiler', False):
            hooks.append(get_profiler_hook(config))

        return tf.estimator.EstimatorSpec(
            mode,
            train_op=train_op,
            loss=loss,
            training_hooks=hooks
        )

    elif mode == tf.estimator.ModeKeys.EVAL:
        ops = add_extra_summary(config, vocab_i2s, train_decoder_output,
                                base_words, output_words, src_words, tgt_words,
                                inserted_words, deleted_words, oov, vocab_size,
                                collections=['extra'])

        return tf.estimator.EstimatorSpec(
            mode,
            loss=loss,
            evaluation_hooks=[get_extra_summary_logger(ops, config)],
            eval_metric_ops={'bleu': tf_metrics.streaming_mean(ops[ES_BLEU])}
        )

    elif mode == tf.estimator.ModeKeys.PREDICT:
        lengths = decoder.seq_length(infer_decoder_output)
        tokens = decoder.str_tokens(infer_decoder_output, vocab_i2s, vocab_size, oov)
        # attns_weight = tf.get_collection('attns_weight')

        preds = {
            'str_tokens': tokens,
            'sample_id': decoder.sample_id(infer_decoder_output),
            'lengths': lengths,
            'joined': metrics.join_tokens(tokens, lengths),
        }

        tmee_attentions = tf.get_collection('TransformerMicroEditExtractor_Attentions')
        if len(tmee_attentions) > 0:
            preds.update({
                'tmee_attentions_st_enc_self': tmee_attentions[0][0],
                'tmee_attentions_st_dec_self': tmee_attentions[0][1],
                'tmee_attentions_st_dec_enc': tmee_attentions[0][2],

                'tmee_attentions_ts_enc_self': tmee_attentions[1][0],
                'tmee_attentions_ts_dec_self': tmee_attentions[1][1],
                'tmee_attentions_ts_dec_enc': tmee_attentions[1][2],
            })

        add_decoder_attention(config, preds)

        return tf.estimator.EstimatorSpec(
            mode,
            predictions=preds
        )
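
# Hedged usage sketch: model_fn above takes extra arguments (config, embedding_matrix,
# vocab_tables) beyond the standard Estimator model_fn signature, so one way to use it
# is to bind them first, e.g. with functools.partial. The names my_config,
# my_embedding_matrix, and my_vocab_tables are placeholders, not objects defined in
# this module, and this is not necessarily how the repo itself builds its Estimator.
def _build_estimator_sketch(my_config, my_embedding_matrix, my_vocab_tables, model_dir):
    import functools
    bound_model_fn = functools.partial(model_fn,
                                       config=my_config,
                                       embedding_matrix=my_embedding_matrix,
                                       vocab_tables=my_vocab_tables)
    return tf.estimator.Estimator(model_fn=bound_model_fn, model_dir=model_dir)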
def decoder_outputs_to_pq(decoder_output, base_sent_embed, src_words, tgt_words, src_len, tgt_len,
                          temperature_starter, decay_rate, decay_steps,
                          agenda_dim, enc_hidden_dim, enc_num_layers, dec_hidden_dim, dec_num_layers,
                          swap_memory, use_dropout=False, dropout_keep=1.0):
    with tf.variable_scope(OPS_NAME):
        # [batch]
        output_length = decoder.seq_length(decoder_output)

        # [batch x max_len x word_dim]
        output_embed = prepare_output_embed(decoder_output, temperature_starter, decay_rate, decay_steps)

        # [batch x max_len x hidden], [batch x hidden]
        hidden_states, dec_output_embedding = encoder.bidirectional_encoder(
            output_embed,
            output_length,
            enc_hidden_dim, enc_num_layers,
            use_dropout=False, dropout_keep=1.0,
            swap_memory=swap_memory,
            name='dec_output_encoder'
        )

        agenda = tf.concat([dec_output_embedding, base_sent_embed], axis=1)
        agenda = tf.layers.dense(agenda, agenda_dim, activation=None, use_bias=False,
                                 name='agenda')  # [batch x agenda_dim]

        # append START, STOP token to create decoder input and output
        out = prepare_decoder_input_output(src_words, src_len, None)
        p_dec_inp, p_dec_inp_len, p_dec_out, p_dec_out_len = out

        out = prepare_decoder_input_output(tgt_words, tgt_len, None)
        q_dec_inp, q_dec_inp_len, q_dec_out, q_dec_out_len = out

        # decode agenda twice
        p_dec, q_dec = decode_agenda(agenda, p_dec_inp, q_dec_inp, p_dec_inp_len, q_dec_inp_len,
                                     dec_hidden_dim, dec_num_layers, swap_memory,
                                     use_dropout=use_dropout, dropout_keep=dropout_keep)

        # calculate the loss
        with tf.name_scope('losses'):
            p_loss = decoding_loss(p_dec, p_dec_out, p_dec_out_len)
            q_loss = decoding_loss(q_dec, q_dec_out, q_dec_out_len)
            loss = p_loss + q_loss

        tf.summary.scalar('reconstruction_loss', loss, ['extra'])

    return loss
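
# Hedged sketch of the per-sequence decoding loss used above. decoding_loss is defined
# elsewhere in the repo; the assumption here is that it is a length-masked cross-entropy
# over the decoder logits, which a standard tf.contrib.seq2seq.sequence_loss call would
# express roughly like this. The actual implementation may differ.
def _decoding_loss_sketch(dec_output, gold_ids, gold_len):
    logits = decoder.rnn_output(dec_output)  # [batch x max_len x vocab]
    mask = tf.sequence_mask(gold_len, maxlen=tf.shape(gold_ids)[1], dtype=tf.float32)
    return tf.contrib.seq2seq.sequence_loss(logits, gold_ids, weights=mask)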
def decoder_outputs_to_edit_vector(decoder_output, temperature_starter, decay_rate, decay_steps,
                                   edit_dim, enc_hidden_dim, enc_num_layers, dense_layers, swap_memory):
    with tf.variable_scope(OPS_NAME):
        # [VOCAB x word_dim]
        embeddings = vocab.get_embeddings()

        # Extend embedding matrix to support oov tokens
        unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
        unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
        unk_embeddings = tf.tile(unk_embed, [50, 1])

        # [VOCAB+50 x word_dim]
        embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

        global_step = tf.train.get_global_step()
        temperature = tf.train.exponential_decay(temperature_starter, global_step,
                                                 decay_steps, decay_rate, name='temperature')
        tf.summary.scalar('temper', temperature, ['extra'])

        # [batch x max_len x VOCAB+50], softmax probabilities
        outputs = decoder.rnn_output(decoder_output)

        # replace non-positive probabilities with a small constant for numerical stability
        outputs = tf.where(tf.less_equal(outputs, 0), tf.ones_like(outputs) * 1e-10, outputs)

        # convert softmax probabilities to one_hot vectors
        dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

        # [batch x max_len x VOCAB+50], one_hot
        outputs_one_hot = dist.sample()

        # [batch x max_len x word_dim], one_hot^T * embedding_matrix
        outputs_embed = tf.einsum("btv,vd->btd", outputs_one_hot, embeddings_extended)

        # [batch]
        outputs_length = decoder.seq_length(decoder_output)

        # [batch x max_len x hidden], [batch x hidden]
        hidden_states, sentence_embedding = encoder.source_sent_encoder(
            outputs_embed,
            outputs_length,
            enc_hidden_dim, enc_num_layers,
            use_dropout=False, dropout_keep=1.0,
            swap_memory=swap_memory
        )

        h = sentence_embedding
        for l in dense_layers:
            h = tf.layers.dense(h, l, activation=tf.nn.relu, name='hidden_%s' % l)

        # [batch x edit_dim]
        edit_vector = tf.layers.dense(h, edit_dim, activation=None, name='linear')

    return edit_vector
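
# Small standalone demo of the relaxed one-hot sampling used above (the Gumbel-softmax /
# Concrete relaxation): at a low temperature samples are close to hard one-hot vectors,
# at a high temperature they stay diffuse, which is what keeps the embedding lookup above
# differentiable during training. `tfd` is assumed to be this module's distributions
# alias (e.g. tensorflow_probability.distributions); the probabilities are made up.
def _relaxed_one_hot_demo_sketch():
    probs = tf.constant([[0.7, 0.2, 0.1]])
    near_one_hot = tfd.RelaxedOneHotCategorical(temperature=0.1, probs=probs).sample()
    diffuse = tfd.RelaxedOneHotCategorical(temperature=5.0, probs=probs).sample()
    return near_one_hot, diffuse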