def predict_using_beam_search(inp,
                              beam_size=3,
                              refine_decoder_sampling_type='nucleus',
                              temperature=0.9,
                              p=0.9,
                              k=25):
    # (batch_size, 1, 1, seq_len)
    dec_padding_mask = create_padding_mask(inp)
    # draft_summary_beam_search encodes the input with BERT internally and
    # returns (decoded_ids, scores) plus the encoder output
    # (batch_size, seq_len, d_bert).
    translated_output_temp, enc_output = draft_summary_beam_search(inp, beam_size)
    # Take the sequence with the highest score (beam index 0).
    preds_draft_summary = translated_output_temp[0][:, 0, :]
    preds_refined_summary, refined_attention_dist = refined_summary_sampling(
        inp,
        enc_output=enc_output,
        padding_mask=dec_padding_mask,
        draft_summary=preds_draft_summary,
        sampling_type=refine_decoder_sampling_type,
        temperature=temperature,
        p=p,
        k=k)
    return preds_draft_summary, preds_refined_summary, refined_attention_dist
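# Both prediction entry points call create_padding_mask, which is not
# defined in this listing. A minimal sketch following the standard
# TensorFlow transformer-tutorial pattern; treating token id 0 as the
# pad id is an assumption.
def create_padding_mask(seq):
    # 1.0 where `seq` holds padding, 0.0 elsewhere.
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    # (batch_size, 1, 1, seq_len) so it broadcasts over attention logits.
    return seq[:, tf.newaxis, tf.newaxis, :]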
def predict_using_sampling(inp,
                           draft_decoder_sampling_type='topk',
                           refine_decoder_sampling_type='topk',
                           temperature=0.9,
                           p=0.9,
                           k=25):
    # (batch_size, 1, 1, seq_len)
    dec_padding_mask = create_padding_mask(inp)
    # (batch_size, seq_len, d_bert)
    enc_output = model.bert_model(inp)[0]
    # (batch_size, seq_len, vocab_len), (_)
    preds_draft_summary, draft_attention_dist = draft_summary_sampling(
        inp,
        enc_output=enc_output,
        look_ahead_mask=None,
        padding_mask=dec_padding_mask,
        sampling_type=draft_decoder_sampling_type,
        temperature=temperature,
        p=p,
        k=k)
    # (batch_size, seq_len, vocab_len), (_)
    preds_refined_summary, refined_attention_dist = refined_summary_sampling(
        inp,
        enc_output=enc_output,
        padding_mask=dec_padding_mask,
        draft_summary=preds_draft_summary,
        sampling_type=refine_decoder_sampling_type,
        temperature=temperature,
        p=p,
        k=k)
    return (preds_draft_summary, draft_attention_dist,
            preds_refined_summary, refined_attention_dist)
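# A hypothetical end-to-end call, assuming a Hugging Face-style BERT
# tokenizer; `tokenizer` and `document` are illustrative names, not part
# of the original code:
#
#   input_ids = tokenizer.encode(document, return_tensors='tf')
#   draft, draft_att, refined, refined_att = predict_using_sampling(
#       input_ids,
#       draft_decoder_sampling_type='topk',
#       refine_decoder_sampling_type='nucleus')
#   summary = tokenizer.decode(refined.numpy(), skip_special_tokens=True)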
def refined_summary_sampling(inp,
                             enc_output,
                             draft_summary,
                             padding_mask,
                             sampling_type='greedy',
                             temperature=0.9,
                             p=0.9,
                             k=25,
                             beam_search=False,
                             training=False):
    """
    Inference call; builds a refined summary.

    Masks each word in the summary draft one by one, then feeds the masked
    draft to BERT to generate context vectors for re-predicting that word.
    """
    log.info(f"Building: 'Refined {sampling_type} decoder'")
    refined_summary = draft_summary
    for i in range(1, config.summ_length):
        # (batch_size, seq_len)
        refined_summary_ = mask_timestamp(refined_summary, i, MASK_ID)
        # (batch_size, seq_len, d_bert)
        context_vectors = model.bert_model(refined_summary_)[0]
        # (batch_size, seq_len, vocab_len), (batch_size, seq_len, d_model), (_)
        dec_output, dec_logits_i, attention_dist = model.decoder(
            inp,
            context_vectors,
            enc_output,
            training=training,
            look_ahead_mask=None,
            padding_mask=padding_mask)
        # (batch_size, 1, vocab_len)
        dec_output_i = dec_output[:, i:i + 1, :]
        if sampling_type == 'nucleus':
            preds = tf.cast(nucleus_sampling(dec_output_i / temperature, p=p),
                            tf.int32)
        elif sampling_type == 'topk':
            preds = tf.cast(top_k_sampling(dec_output_i / temperature, k=k),
                            tf.int32)
        elif sampling_type == 'topktopp':
            preds = tf.cast(topp_topk(dec_output_i / temperature, p=p, k=k),
                            tf.int32)
        elif sampling_type == 'random_sampling':
            preds = tf.cast(sampling(dec_output_i / temperature), tf.int32)
        else:
            preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
        # Write the re-predicted token back into position i.
        refined_summary = with_column(refined_summary, i, preds)
    # (batch_size, seq_len), (_)
    return refined_summary, attention_dist
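# mask_timestamp and with_column are helpers the refine loop depends on
# but that are not defined in this listing. Minimal sketches, assuming
# (batch_size, seq_len) integer id tensors and a (batch_size, 1)
# prediction column:
def mask_timestamp(x, i, mask_id):
    # Replace column i of x with the mask token id.
    batch = tf.shape(x)[0]
    mask_col = tf.fill([batch, 1], mask_id)
    return tf.concat([x[:, :i], mask_col, x[:, i + 1:]], axis=1)


def with_column(x, i, column):
    # Replace column i of x with the given (batch_size, 1) column.
    return tf.concat([x[:, :i], column, x[:, i + 1:]], axis=1)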
def draft_summary_beam_search(input_ids, beam_size):
    log.info("Building: 'Draft beam search decoder'")
    batch = tf.shape(input_ids)[0]
    # (batch_size, seq_len, d_bert)
    enc_output_ = model.bert_model(input_ids)[0]
    # Tile the encoder output and input ids so every beam gets a copy:
    # (batch_size * beam_size, seq_len, d_bert)
    enc_output = tf.tile(enc_output_, multiples=[beam_size, 1, 1])
    input_ids = tf.tile(input_ids, multiples=[beam_size, 1])
    # Every sequence starts with the [CLS] token.
    dec_input = tf.fill([batch], CLS_ID)

    def beam_search_decoder(output):
        # (_), (_), (batch_size, 1, 1, seq_len)
        _, _, dec_padding_mask = create_masks(input_ids, output)
        embeddings = model.embedding(output)
        predictions, dec_op, attention_weights = model.decoder(
            input_ids, embeddings, enc_output, False, None, dec_padding_mask)
        if config.copy_gen:
            predictions = model.decoder.pointer_generator(
                dec_op,
                predictions,
                attention_weights,
                input_ids,
                tf.shape(input_ids)[1],
                tf.shape(output)[-1],
                False)
        # (batch_size, 1, target_vocab_size)
        return predictions[:, -1:, :]

    return (beam_search(beam_search_decoder,
                        dec_input,
                        beam_size,
                        config.summ_length,
                        config.input_vocab_size,
                        h_parms.length_penalty,
                        stop_early=False,
                        eos_id=[SEP_ID]),
            enc_output_)
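# The sampling helpers called by both decoders (top_k_sampling,
# nucleus_sampling, sampling) are not shown in this listing. Minimal
# sketches; flattening the logits to (batch, vocab_len) first is an
# assumption made so one implementation serves every call site above:
def top_k_sampling(logits, k=25):
    logits = tf.reshape(logits, [-1, tf.shape(logits)[-1]])
    values, _ = tf.math.top_k(logits, k=k)
    # Mask everything below the k-th largest logit, then sample.
    logits = tf.where(logits < values[:, -1:],
                      tf.ones_like(logits) * -1e9,
                      logits)
    return tf.random.categorical(logits, num_samples=1)


def nucleus_sampling(logits, p=0.9):
    logits = tf.reshape(logits, [-1, tf.shape(logits)[-1]])
    sort_idx = tf.argsort(logits, direction='DESCENDING', axis=-1)
    sorted_logits = tf.gather(logits, sort_idx, batch_dims=1)
    cum_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1),
                          axis=-1, exclusive=True)
    # Drop tokens once the cumulative probability passes p; the
    # exclusive cumsum guarantees the top token is always kept.
    sorted_logits = tf.where(cum_probs > p,
                             tf.ones_like(sorted_logits) * -1e9,
                             sorted_logits)
    # Undo the sort so sampled ids index the original vocabulary.
    logits = tf.gather(sorted_logits, tf.argsort(sort_idx, axis=-1),
                       batch_dims=1)
    return tf.random.categorical(logits, num_samples=1)


def sampling(logits):
    # Plain random sampling over the full temperature-scaled distribution.
    logits = tf.reshape(logits, [-1, tf.shape(logits)[-1]])
    return tf.random.categorical(logits, num_samples=1)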
def refined_summary_sampling(inp,
                             enc_output,
                             draft_summary,
                             padding_mask,
                             sampling_type='greedy',
                             temperature=0.9,
                             p=0.9,
                             k=25,
                             beam_search=False,
                             training=False):
    """
    Inference call; builds a refined summary with pointer-generator support.

    Masks each word in the summary draft one by one, then feeds the masked
    draft to BERT to generate context vectors for re-predicting that word.
    """
    log.info(f"Building: 'Refined {sampling_type} decoder'")
    N = tf.shape(enc_output)[0]
    # Add a batch dimension to the draft: (1, seq_len).
    refined_summary = tf.expand_dims(draft_summary, 0)
    dec_outputs = []
    dec_logits = []
    attention_dists = []
    for i in range(1, config.summ_length):
        # (batch_size, seq_len)
        refined_summary_ = mask_timestamp(refined_summary, i, MASK_ID)
        # (batch_size, seq_len, d_bert)
        context_vectors = model.bert_model(refined_summary_)[0]
        # (batch_size, seq_len, vocab_len), (batch_size, seq_len, d_model), (_)
        dec_output, dec_logits_i, attention_dist = model.decoder(
            inp,
            context_vectors,
            enc_output,
            training=training,
            look_ahead_mask=None,
            padding_mask=padding_mask)
        # (batch_size, 1, vocab_len)
        dec_output_i = dec_output[:, i:i + 1, :]
        if sampling_type == 'nucleus':
            preds = tf.cast(
                nucleus_sampling(tf.squeeze(dec_output_i) / temperature, p=p),
                tf.int32)
        elif sampling_type == 'topk':
            preds = tf.cast(
                top_k_sampling(tf.squeeze(dec_output_i) / temperature, k=k),
                tf.int32)
        elif sampling_type == 'random_sampling':
            preds = tf.cast(
                sampling(tf.squeeze(dec_output_i) / temperature), tf.int32)
        else:
            preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
        dec_outputs += [dec_output_i]
        dec_logits_i = dec_logits_i[:, i:i + 1, :]
        dec_logits += [dec_logits_i]
        refined_summary = with_column(refined_summary, i, preds)
        attention_dists += [attention_dist[:, i:i + 1, :]]
    # Prepend the [CLS] position, which the loop above never re-predicts.
    cls_concat_dec_outputs = tf.tile(
        tf.expand_dims(tf.one_hot([CLS_ID], config.target_vocab_size), axis=0),
        [N, 1, 1])
    cls_concat_dec_logits = tf.tile(
        tf.expand_dims(tf.one_hot([CLS_ID], config.d_model), axis=0),
        [N, 1, 1])
    dec_outputs = tf.reshape(dec_outputs, (1, -1, config.target_vocab_size))
    dec_logits = tf.reshape(dec_logits, (1, -1, config.d_model))
    attention_dists = tf.reshape(attention_dists, (1, -1, config.doc_length))
    dec_outputs = tf.concat([cls_concat_dec_outputs, dec_outputs], axis=1)
    dec_logits = tf.concat([cls_concat_dec_logits, dec_logits], axis=1)
    if config.copy_gen:
        predictions = model.decoder.pointer_generator(
            dec_logits,
            dec_outputs,
            attention_dists,
            inp,
            tf.shape(inp)[-1],
            tf.shape(dec_outputs)[1],
            training=training)
        refined_summary = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
    # (seq_len,), (_)
    return tf.squeeze(refined_summary, axis=0), attention_dist
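# A hypothetical standalone call to the refine decoder, assuming a single
# tokenized document `inp` of shape (1, seq_len) and an unbatched draft of
# shape (seq_len,), which matches the expand_dims/squeeze above; the
# variable names are illustrative:
#
#   enc_output = model.bert_model(inp)[0]
#   refined, attn = refined_summary_sampling(
#       inp, enc_output, draft,
#       padding_mask=create_padding_mask(inp),
#       sampling_type='nucleus')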