import tensorflow as tf

# `config`, `model`, `log`, CLS_ID, MASK_ID, and the helpers (nucleus_sampling,
# top_k_sampling, topp_topk, sampling, mask_timestamp, with_column,
# create_masks) are provided by the surrounding project modules.


def refined_summary_sampling(inp,
                             enc_output,
                             draft_summary,
                             padding_mask,
                             sampling_type='greedy',
                             temperature=0.9,
                             p=0.9,
                             k=25,
                             beam_search=False,
                             training=False):
    """
    Inference call, builds a refined summary.

    It first masks each word of the summary draft one by one, then feeds
    the draft to BERT to generate context vectors.
    """
    log.info(f"Building: 'Refined {sampling_type} decoder'")
    refined_summary = draft_summary
    for i in range(1, config.summ_length):
        # Mask position i of the draft.
        # (batch_size, seq_len)
        refined_summary_ = mask_timestamp(refined_summary, i, MASK_ID)
        # (batch_size, seq_len, d_bert)
        context_vectors = model.bert_model(refined_summary_)[0]
        # (batch_size, seq_len, vocab_len), (batch_size, seq_len, d_bert), (_)
        dec_output, dec_logits_i, attention_dist = model.decoder(
            context_vectors,
            enc_output,
            training=training,
            look_ahead_mask=None,
            padding_mask=padding_mask)
        # (batch_size, 1, vocab_len)
        dec_output_i = dec_output[:, i:i + 1, :]
        if sampling_type == 'nucleus':
            preds = tf.cast(nucleus_sampling(dec_output_i / temperature, p=p),
                            tf.int32)
        elif sampling_type == 'topk':
            preds = tf.cast(top_k_sampling(dec_output_i / temperature, k=k),
                            tf.int32)
        elif sampling_type == 'topktopp':
            preds = tf.cast(topp_topk(dec_output_i / temperature, p=p, k=k),
                            tf.int32)
        elif sampling_type == 'random_sampling':
            preds = tf.cast(sampling(dec_output_i / temperature), tf.int32)
        else:
            preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
        # Write the predicted token back into position i of the summary.
        refined_summary = with_column(refined_summary, i, preds)
    # (batch_size, seq_len), (_)
    return refined_summary, attention_dist
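# A minimal sketch of the `mask_timestamp` / `with_column` helpers used above,
# assuming they splice a value into column i of an int32 (batch, seq_len)
# tensor; names with a `_sketch` suffix are illustrative, not the project's
# actual definitions.

def mask_timestamp_sketch(x, i, mask_id):
    """Replace column i of x with mask_id, leaving the other columns intact."""
    n = tf.shape(x)[0]
    mask_col = tf.fill([n, 1], tf.cast(mask_id, x.dtype))
    return tf.concat([x[:, :i], mask_col, x[:, i + 1:]], axis=1)


def with_column_sketch(x, i, column):
    """Write `column` (shape (batch, 1)) into position i of x. When i equals
    the current width this degenerates to an append, which is how the draft
    decoder grows its input by one token per step."""
    return tf.concat([x[:, :i], tf.cast(column, x.dtype), x[:, i + 1:]], axis=1)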
def beam_search_decoder(output):
    # Closes over `enc_output`, `dec_padding_mask`, and `input_ids` from the
    # enclosing inference scope.
    # (batch_size, seq_len, d_bert)
    embeddings = model.embedding(output)
    predictions, dec_op, attention_weights = model.decoder(
        embeddings, enc_output, False, None, dec_padding_mask)
    if config.copy_gen:
        predictions = model.decoder.pointer_generator(
            dec_op[:, -1:, :],
            predictions[:, -1:, :],
            attention_weights[:, :, -1:, :],
            input_ids,
            tf.shape(input_ids)[1],
            tf.shape(predictions[:, -1:, :])[1],
            training=False,
        )
    # (batch_size, 1, target_vocab_size)
    return predictions[:, -1:, :]
def beam_search_decoder(output):
    _, _, dec_padding_mask = create_masks(input_ids, output)
    # (batch_size, seq_len, d_bert)
    embeddings = model.embedding(output)
    predictions, dec_op, attention_weights = model.decoder(
        input_ids, embeddings, enc_output, False, None, dec_padding_mask)
    if config.copy_gen:
        predictions = model.decoder.pointer_generator(
            dec_op,
            predictions,
            attention_weights,
            input_ids,
            tf.shape(input_ids)[1],
            tf.shape(output)[-1],
            False,
        )
    # (batch_size, 1, target_vocab_size)
    return predictions[:, -1:, :]
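# A hedged usage sketch: each `beam_search_decoder` above is a per-step scoring
# function meant to be driven by an external beam-search routine, which this
# file does not define. The driver call below, its signature, `config.beam_size`,
# and `SEP_ID` are assumptions for illustration, not the project's actual API.
#
# predicted_ids, beam_scores = beam_search(
#     beam_search_decoder,                      # ids so far -> next-token dist
#     initial_ids=tf.ones([N], tf.int32) * CLS_ID,
#     beam_size=config.beam_size,
#     decode_length=config.summ_length,
#     vocab_size=config.target_vocab_size,
#     alpha=0.6,                                # length-normalization exponent
#     eos_id=SEP_ID)                            # hypothetical end-of-summary id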
def draft_summary_sampling(inp,
                           enc_output,
                           look_ahead_mask,
                           padding_mask,
                           sampling_type='greedy',
                           temperature=0.9,
                           p=0.9,
                           k=25,
                           training=False):
    """
    Inference call, builds a draft summary auto-regressively.
    """
    log.info(f"Building: 'Draft {sampling_type} decoder'")
    N = tf.shape(enc_output)[0]
    # Seed the decoder input with the [CLS] token.
    # (batch_size, 1)
    dec_input = tf.ones([N, 1], dtype=tf.int32) * CLS_ID
    summary, dec_outputs, dec_logits, attention_dists = [], [], [], []
    summary += [dec_input]
    for i in range(0, config.summ_length):
        _, _, dec_padding_mask = create_masks(inp, dec_input)
        # (batch_size, i+1, d_bert)
        embeddings = model.embedding(dec_input)
        # (batch_size, i+1, vocab), (_)
        dec_output, dec_logits_i, attention_dist = model.decoder(
            embeddings, enc_output, training, look_ahead_mask, padding_mask)
        if config.copy_gen:
            dec_output = model.decoder.pointer_generator(
                dec_logits_i,
                dec_output,
                attention_dist,
                inp,
                tf.shape(inp)[1],
                tf.shape(dec_output)[1],
                training=False,
            )
        # Keep only the distribution for the newest position.
        # (batch_size, 1, vocab)
        dec_output_i = dec_output[:, -1:, :]
        if sampling_type == 'nucleus':
            preds = tf.cast(nucleus_sampling(dec_output_i / temperature, p=p),
                            tf.int32)
        elif sampling_type == 'topk':
            preds = tf.cast(top_k_sampling(dec_output_i / temperature, k=k),
                            tf.int32)
        elif sampling_type == 'random_sampling':
            preds = tf.cast(sampling(dec_output_i / temperature), tf.int32)
        elif sampling_type == 'topktopp':
            preds = tf.cast(topp_topk(dec_output_i / temperature, p=p, k=k),
                            tf.int32)
        else:
            preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
        dec_outputs += [dec_output_i]
        dec_logits += [dec_logits_i[:, -1:, :]]
        summary += [preds]
        # Append the sampled token so the next step conditions on it.
        dec_input = with_column(dec_input, i + 1, preds)
    summary = tf.concat(summary, axis=1)
    # (batch_size, seq_len), (_)
    return summary, attention_dist
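# A minimal sketch of one of the sampling helpers, assuming it receives the
# (batch, 1, vocab) temperature-scaled distribution sliced above and returns
# token ids of shape (batch, 1). The project's real `top_k_sampling` may
# differ; `nucleus_sampling` follows the same pattern with a cumulative-
# probability cutoff over the sorted logits instead of a fixed k.

def top_k_sampling_sketch(logits, k=25):
    """Sample only from the k highest-scoring tokens at each position."""
    values, _ = tf.math.top_k(logits, k=k)
    min_value = values[..., -1:]                        # k-th largest logit
    logits = tf.where(logits < min_value,
                      tf.fill(tf.shape(logits), -1e9),  # mask the tail
                      logits)
    flat = tf.reshape(logits, [-1, tf.shape(logits)[-1]])
    # (batch, 1) sampled ids, int64 until the caller casts to int32
    return tf.reshape(tf.random.categorical(flat, 1), [-1, 1])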
def refined_summary_sampling(inp,
                             enc_output,
                             draft_summary,
                             padding_mask,
                             sampling_type='greedy',
                             temperature=0.9,
                             p=0.9,
                             k=25,
                             beam_search=False,
                             training=False):
    """
    Inference call, builds a refined summary.

    It first masks each word of the summary draft one by one, then feeds
    the draft to BERT to generate context vectors.
    """
    log.info(f"Building: 'Refined {sampling_type} decoder'")
    N = tf.shape(enc_output)[0]
    # Treat the draft as a batch of one sequence.
    refined_summary = tf.expand_dims(draft_summary, 0)
    dec_outputs = []
    dec_logits = []
    attention_dists = []
    for i in range(1, config.summ_length):
        # (batch_size, seq_len)
        refined_summary_ = mask_timestamp(refined_summary, i, MASK_ID)
        # (batch_size, seq_len, d_bert)
        context_vectors = model.bert_model(refined_summary_)[0]
        # (batch_size, seq_len, vocab_len), (batch_size, seq_len, d_model), (_)
        dec_output, dec_logits_i, attention_dist = model.decoder(
            inp,
            context_vectors,
            enc_output,
            training=training,
            look_ahead_mask=None,
            padding_mask=padding_mask)
        # (batch_size, 1, vocab_len)
        dec_output_i = dec_output[:, i:i + 1, :]
        if sampling_type == 'nucleus':
            preds = tf.cast(
                nucleus_sampling(tf.squeeze(dec_output_i) / temperature, p=p),
                tf.int32)
        elif sampling_type == 'topk':
            preds = tf.cast(
                top_k_sampling(tf.squeeze(dec_output_i) / temperature, k=k),
                tf.int32)
        elif sampling_type == 'random_sampling':
            preds = tf.cast(
                sampling(tf.squeeze(dec_output_i) / temperature), tf.int32)
        else:
            preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
        dec_outputs += [dec_output_i]
        dec_logits += [dec_logits_i[:, i:i + 1, :]]
        refined_summary = with_column(refined_summary, i, preds)
        attention_dists += [attention_dist[:, i:i + 1, :]]
    # The loop starts at position 1, so prepend one-hot [CLS] rows for
    # position 0 before running the pointer-generator over the full sequence.
    cls_concat_dec_outputs = tf.tile(
        tf.expand_dims(tf.one_hot([CLS_ID], config.target_vocab_size), axis=0),
        [N, 1, 1])
    cls_concat_dec_logits = tf.tile(
        tf.expand_dims(tf.one_hot([CLS_ID], config.d_model), axis=0),
        [N, 1, 1])
    dec_outputs = tf.reshape(dec_outputs, (1, -1, config.target_vocab_size))
    dec_logits = tf.reshape(dec_logits, (1, -1, config.d_model))
    attention_dists = tf.reshape(attention_dists, (1, -1, config.doc_length))
    dec_outputs = tf.concat([cls_concat_dec_outputs, dec_outputs], axis=1)
    dec_logits = tf.concat([cls_concat_dec_logits, dec_logits], axis=1)
    if config.copy_gen:
        predictions = model.decoder.pointer_generator(
            dec_logits,
            dec_outputs,
            attention_dists,
            inp,
            tf.shape(inp)[-1],
            tf.shape(dec_outputs)[1],
            training=training)
        refined_summary = tf.cast(tf.argmax(predictions, axis=-1),
                                  dtype=tf.int32)
    # (seq_len,), (_)
    return tf.squeeze(refined_summary, axis=0), attention_dist
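# A hedged sketch of how the two-stage inference above fits together: encode
# the source once, draft a summary auto-regressively, then refine it with BERT
# context vectors. The return order of `create_masks` is assumed here
# (enc padding, combined look-ahead, dec padding); check the project's own
# definition before using this.
#
# enc_output = model.bert_model(input_ids)[0]
# _, combined_mask, dec_padding_mask = create_masks(input_ids, input_ids)
# draft, _ = draft_summary_sampling(input_ids, enc_output, combined_mask,
#                                   dec_padding_mask, sampling_type='topk', k=25)
# refined, _ = refined_summary_sampling(input_ids, enc_output, draft[0],
#                                       dec_padding_mask, sampling_type='greedy')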