def predict_using_beam_search(inp,
                              beam_size=3,
                              refine_decoder_sampling_type='nucleus',
                              temperature=0.9,
                              p=0.9,
                              k=25):

    dec_padding_mask = create_padding_mask(inp)
    # (batch_size, seq_len, d_bert)
    enc_output = model.bert_model(inp)[0]

    # [batch_size*beam_size, input_seq_len, d_bert]
    translated_output_temp = draft_summary_beam_search(inp, enc_output,
                                                       dec_padding_mask,
                                                       beam_size)
    # Beams are sorted by score, so take the top-scoring one (index 0)
    preds_draft_summary = translated_output_temp[0][:, 0, :]

    preds_refined_summary, refined_attention_dist = refined_summary_sampling(
        inp,
        enc_output=enc_output,
        padding_mask=dec_padding_mask,
        draft_summary=preds_draft_summary,
        sampling_type=refine_decoder_sampling_type,
        temperature=temperature,
        p=p,
        k=k)
    return preds_draft_summary, preds_refined_summary, refined_attention_dist
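
A minimal usage sketch for this entry point; tokenizer and document are illustrative assumptions, not names from this listing:

# Hypothetical usage (tokenizer and document are assumptions)
input_ids = tf.constant([tokenizer.encode(document)], dtype=tf.int32)
draft, refined, attn = predict_using_beam_search(input_ids,
                                                 beam_size=4,
                                                 refine_decoder_sampling_type='topk',
                                                 k=10)
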
def predict_using_sampling(inp,
                           draft_decoder_sampling_type='topk',
                           refine_decoder_sampling_type='topk',
                           temperature=0.9,
                           p=0.9,
                           k=25):

    dec_padding_mask = create_padding_mask(inp)

    # (batch_size, seq_len, d_bert)
    enc_output = model.bert_model(inp)[0]
    # (batch_size, seq_len, vocab_len), (_)
    preds_draft_summary, draft_attention_dist = draft_summary_sampling(
        inp,
        enc_output=enc_output,
        look_ahead_mask=None,
        padding_mask=dec_padding_mask,
        sampling_type=draft_decoder_sampling_type,
        temperature=temperature,
        p=p,
        k=k,
    )
    # (batch_size, seq_len, vocab_len), ()
    preds_refined_summary, refined_attention_dist = refined_summary_sampling(
        inp,
        enc_output=enc_output,
        padding_mask=dec_padding_mask,
        draft_summary=preds_draft_summary,
        sampling_type=refine_decoder_sampling_type,
        temperature=temperature,
        p=p,
        k=k)

    return preds_draft_summary, draft_attention_dist, preds_refined_summary, refined_attention_dist
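
Both entry points delegate token selection to sampling helpers (top_k_sampling, nucleus_sampling, topp_topk, sampling) that are defined elsewhere in the repo. As a point of reference, a top-k sampler could look like the sketch below; this is an illustrative stand-in, not the repo's implementation. Nucleus sampling is analogous, cutting the sorted cumulative probability mass at p instead of keeping a fixed k.

# Illustrative top-k sampler, assuming float logits with vocab on the last axis
def top_k_sampling(logits, k=25):
    # Keep only the k highest logits per position, then sample one id
    values, _ = tf.math.top_k(logits, k=k)
    cutoff = values[..., -1:]  # smallest logit that survives the cut
    neg_inf = tf.fill(tf.shape(logits), float('-inf'))
    filtered = tf.where(logits < cutoff, neg_inf, logits)
    flat = tf.reshape(filtered, [-1, tf.shape(logits)[-1]])
    samples = tf.random.categorical(flat, num_samples=1)
    return tf.reshape(samples, tf.shape(logits)[:-1])
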
Example #3
def refined_summary_sampling(inp,
                             enc_output,
                             draft_summary,
                             padding_mask,
                             sampling_type='greedy',
                             temperature=0.9,
                             p=0.9,
                             k=25,
                             beam_search=False,
                             training=False):
    """
        Inference call, builds a refined summary
        
        It first masks each word in the summary draft one by one,
        then feeds the draft to BERT to generate context vectors.
        """

    log.info(f"Building: 'Refined {sampling_type} decoder'")
    # Start from the draft and re-predict one masked position per iteration
    refined_summary = draft_summary
    # Position 0 holds [CLS], so refine positions 1..summ_length-1
    for i in range(1, config.summ_length):

        # (batch_size, seq_len)
        refined_summary_ = mask_timestamp(refined_summary, i, MASK_ID)

        # (batch_size, seq_len, d_bert)
        context_vectors = model.bert_model(refined_summary_)[0]

        # (batch_size, seq_len, d_bert), (_)
        dec_output, dec_logits_i, attention_dist = model.decoder(
            context_vectors,
            enc_output,
            training=training,
            look_ahead_mask=None,
            padding_mask=padding_mask)

        # (batch_size, 1, vocab_len)
        dec_output_i = dec_output[:, i:i + 1, :]
        if sampling_type == 'nucleus':
            preds = tf.cast(
                nucleus_sampling(dec_output_i / temperature, p=p), tf.int32)
        elif sampling_type == 'topk':
            preds = tf.cast(
                top_k_sampling(dec_output_i / temperature, k=k), tf.int32)
        elif sampling_type == 'topktopp':
            preds = tf.cast(
                topp_topk(dec_output_i / temperature, p=p, k=k), tf.int32)
        elif sampling_type == 'random_sampling':
            preds = tf.cast(sampling(dec_output_i / temperature), tf.int32)
        else:
            preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
        refined_summary = with_column(refined_summary, i, preds)
    # (batch_size, seq_len), (_)
    return refined_summary, attention_dist
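
The loop above relies on two helpers, mask_timestamp and with_column, that are not shown in this listing. A plausible sketch of both, assuming int32 token-id tensors of shape (batch_size, seq_len); the repo's actual implementations may differ:

# Illustrative sketches of the masking helpers used above
def mask_timestamp(ids, i, mask_id):
    # Overwrite column i of every row with the [MASK] token id
    one_hot = tf.one_hot(i, tf.shape(ids)[1], dtype=ids.dtype)
    return ids * (1 - one_hot) + mask_id * one_hot

def with_column(ids, i, column):
    # Write column (shape (batch_size, 1)) into position i of every row
    one_hot = tf.one_hot(i, tf.shape(ids)[1], dtype=ids.dtype)
    return ids * (1 - one_hot) + column * one_hot
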
Example #4
def draft_summary_beam_search(input_ids, beam_size):

    log.info(f"Building: 'Draft beam search decoder'")

    batch = tf.shape(input_ids)[0]
    end = [SEP_ID]
    # (batch_size, seq_len, d_bert)
    enc_output_ = model.bert_model(input_ids)[0]
    # Tile the encoder output and input ids so every beam shares the memory
    # (batch_size*beam_size, seq_len, d_bert)
    enc_output = tf.tile(enc_output_, multiples=[beam_size, 1, 1])
    input_ids = tf.tile(input_ids, multiples=[beam_size, 1])
    # Every beam starts decoding from [CLS]; (batch_size,)
    dec_input = tf.fill([batch], CLS_ID)
    def beam_search_decoder(output):
        # Only the decoder padding mask is needed at inference time
        _, _, dec_padding_mask = create_masks(input_ids, output)
        embeddings = model.embedding(output)
        predictions, dec_op, attention_weights = model.decoder(
            input_ids,
            embeddings,
            enc_output,
            False,
            None,
            dec_padding_mask)
        if config.copy_gen:
            predictions = model.decoder.pointer_generator(
                dec_op,
                predictions,
                attention_weights,
                input_ids,
                tf.shape(input_ids)[1],
                tf.shape(output)[-1],
                False)
        # Return only the logits of the last decoded position
        # (batch_size, 1, target_vocab_size)
        return predictions[:, -1:, :]
    return (beam_search(beam_search_decoder,
                        dec_input,
                        beam_size,
                        config.summ_length,
                        config.input_vocab_size,
                        h_parms.length_penalty,
                        stop_early=False,
                        eos_id=end),
            enc_output_)
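
A sketch of how the returned pair might be consumed, assuming beam_search follows the tensor2tensor convention of returning (decoded_ids, scores) with beams sorted best-first, which is what the [:, 0, :] indexing in predict_using_beam_search relies on:

# Hypothetical usage; input_ids is a batched int32 token-id tensor
(decoded_ids, scores), enc_output = draft_summary_beam_search(input_ids, beam_size=3)
# decoded_ids: (batch_size, beam_size, decode_length); keep the best beam
best_draft = decoded_ids[:, 0, :]
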
Example #5
def refined_summary_sampling(inp,
                             enc_output,
                             draft_summary,
                             padding_mask,
                             sampling_type='greedy',
                             temperature=0.9,
                             p=0.9,
                             k=25,
                             beam_search=False,
                             training=False):
        """
        Inference call, builds a refined summary
        
        It first masks each word in the summary draft one by one,
        then feeds the draft to BERT to generate context vectors.
        """
        
        log.info(f"Building: 'Refined {sampling_type} decoder'")
        N = tf.shape(enc_output)[0]
        # This variant expects an unbatched draft, so add a batch dimension
        refined_summary = tf.expand_dims(draft_summary, 0)
        dec_outputs = []
        dec_logits = []
        attention_dists = []
        # Position 0 holds [CLS], so refine positions 1..summ_length-1
        for i in range(1, config.summ_length):
            
            # (batch_size, seq_len)
            refined_summary_ = mask_timestamp(refined_summary, i, MASK_ID)
            # (batch_size, seq_len, d_bert)
            context_vectors = model.bert_model(refined_summary_)[0]
            # (batch_size, seq_len, d_bert), (_)
            dec_output, dec_logits_i, attention_dist = model.decoder(
                                                                    inp,
                                                                    context_vectors,
                                                                    enc_output,
                                                                    training=training,
                                                                    look_ahead_mask=None,
                                                                    padding_mask=padding_mask
                                                                  )
            
            # (batch_size, 1, vocab_len)
            dec_output_i = dec_output[:, i:i+1, :]
            if sampling_type == 'nucleus':
                preds = tf.cast(
                    nucleus_sampling(tf.squeeze(dec_output_i) / temperature, p=p),
                    tf.int32)
            elif sampling_type == 'topk':
                preds = tf.cast(
                    top_k_sampling(tf.squeeze(dec_output_i) / temperature, k=k),
                    tf.int32)
            elif sampling_type == 'random_sampling':
                preds = tf.cast(
                    sampling(tf.squeeze(dec_output_i) / temperature), tf.int32)
            else:
                preds = tf.cast(tf.argmax(dec_output_i, axis=-1), tf.int32)
            dec_outputs += [dec_output_i]
            dec_logits_i = dec_logits_i[:, i:i+1, :]
            dec_logits += [dec_logits_i]
            
            refined_summary = with_column(refined_summary, i, preds)
            attention_dists += [attention_dist[:, i:i+1, :]]
        # Prepend a one-hot [CLS] so the stacked outputs cover every position
        cls_concat_dec_outputs = tf.tile(
            tf.expand_dims(tf.one_hot([CLS_ID], config.target_vocab_size), axis=0),
            [N, 1, 1])
        cls_concat_dec_logits = tf.tile(
            tf.expand_dims(tf.one_hot([CLS_ID], config.d_model), axis=0),
            [N, 1, 1])
        dec_outputs = tf.reshape(dec_outputs, (1, -1, config.target_vocab_size))
        dec_logits = tf.reshape(dec_logits, (1, -1, config.d_model))
        attention_dists = tf.reshape(attention_dists, (1, -1, config.doc_length))
        dec_outputs = tf.concat([cls_concat_dec_outputs, dec_outputs], axis=1)
        dec_logits = tf.concat([cls_concat_dec_logits, dec_logits], axis=1)
        
        if config.copy_gen:
            predictions = model.decoder.pointer_generator(
                dec_logits,
                dec_outputs,
                attention_dists,
                inp,
                tf.shape(inp)[-1],
                tf.shape(dec_outputs)[1],
                training=training)
            refined_summary = tf.cast(tf.argmax(predictions, axis=-1),
                                      dtype=tf.int32)
        # (seq_len), (_)
        return tf.squeeze(refined_summary, axis=0), attention_dist
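
A sketch of how this variant might be called after the draft stage; input_ids and draft are illustrative assumptions, with draft being the unbatched int32 token ids that tf.expand_dims above lifts to batch size 1:

# Hypothetical usage (input_ids and draft are assumptions)
enc_output = model.bert_model(input_ids)[0]
padding_mask = create_padding_mask(input_ids)
refined, attn = refined_summary_sampling(input_ids,
                                         enc_output=enc_output,
                                         draft_summary=draft,
                                         padding_mask=padding_mask,
                                         sampling_type='nucleus',
                                         p=0.9)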