def forward(
        self,
        dec_input,
        dec_attn_mask,
        enc_output,
        enc_attn_mask,
        layer_past=None,
        get_key_value=False,
    ):
        # convert to Megatron mask
        dec_attn_mask_3d = build_attention_mask_3d(
            source_mask=dec_attn_mask,
            target_mask=dec_attn_mask,
            attn_mask_type=self.model_attn_mask_type,
        )
        enc_dec_attn_mask_3d = build_attention_mask_3d(
            source_mask=dec_attn_mask,
            target_mask=enc_attn_mask,
            attn_mask_type=AttnMaskType.padding,
        )

        # transformer decoder
        dec_output = self.model(
            dec_input,
            attn_mask_postprocess(dec_attn_mask_3d),
            layer_past=layer_past,
            get_key_value=get_key_value,
            encoder_output=enc_output,
            enc_dec_attn_mask=attn_mask_postprocess(enc_dec_attn_mask_3d),
        )

        return dec_output
    def forward(
        self,
        enc_input,
        enc_attn_mask,
        layer_past=None,
        get_key_value=False,
    ):
        # convert to Megatron mask
        enc_attn_mask_3d = build_attention_mask_3d(
            source_mask=enc_attn_mask,
            target_mask=enc_attn_mask,
            attn_mask_type=self.model_attn_mask_type,
        )

        # transformer encoder
        enc_output = self.model(
            enc_input,
            attn_mask_postprocess(enc_attn_mask_3d),
            layer_past=layer_past,
            get_key_value=get_key_value,
        )
        # we copy input mask for transformer
        enc_output_mask = enc_attn_mask

        return enc_output, enc_output_mask