def predict(self,
            input_ids,
            enc_padding_mask,
            decoder_type=config.draft_decoder_type,
            beam_size=config.beam_size,
            length_penalty=config.length_penalty,
            temperature=config.softmax_temperature,
            top_p=config.top_p,
            top_k=config.top_k):
        """Run the draft-only decoding path: encode, then draft-decode.

        Returns a 4-tuple (draft_sequence, draft_attention, None, None);
        the trailing Nones stand in for the refined outputs this path
        does not produce.

        NOTE(review): `decoder_type` is accepted but never forwarded to
        `draft_decoder` — confirm whether it should be passed through.
        """
        num_sequences = tf.shape(input_ids)[0]
        # dec_padding_mask equals enc_padding_mask here, so the encoder
        # mask is reused directly.
        # enc_output: (batch_size, inp_seq_len, d_model)
        encoder_states = self.encoder(input_ids, False, enc_padding_mask)
        # Draft pass: (batch_size, seq_len, vocab_len) logits-derived
        # sequence plus its attention distribution.
        draft_sequence, draft_attention = draft_decoder(
            self,
            input_ids,
            enc_output=encoder_states,
            beam_size=beam_size,
            length_penalty=length_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            batch_size=num_sequences)
        # No refinement stage on this path — refined slots are None.
        return (draft_sequence, draft_attention, None, None)
    def predict(self,
                input_ids,
                draft_decoder_type,
                beam_size,
                length_penalty,
                temperature,
                top_p,
                top_k,
                refine_decoder_type=config.refine_decoder_type):
        """Two-stage decode: draft pass followed by a refinement pass.

        Returns a 4-tuple (draft_sequence, draft_attention,
        refined_sequence, refined_attention).

        NOTE(review): `draft_decoder_type` is accepted but never used
        inside this method — confirm whether `draft_decoder` should
        receive it.
        """
        num_sequences = tf.shape(input_ids)[0]
        # The encoder returns a tuple; element 0 holds the hidden
        # states: (batch_size, seq_len, d_bert).
        encoder_states = self.encoder(input_ids)[0]
        # Stage 1 — draft sequence sampled from the encoder states.
        draft_sequence, draft_attention = draft_decoder(
            self,
            input_ids,
            enc_output=encoder_states,
            beam_size=beam_size,
            length_penalty=length_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            batch_size=num_sequences)
        # Stage 2 — refine the draft with the configured refine decoder.
        refined_sequence, refined_attention = self.refined_output_sequence_sampling(
            input_ids,
            enc_output=encoder_states,
            draft_output_sequence=draft_sequence,
            decoder_type=refine_decoder_type,
            batch_size=num_sequences,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k)
        return (draft_sequence, draft_attention,
                refined_sequence, refined_attention)