Example #1
    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
        _args = copy.deepcopy(args)
        if args.adaptor_proj or args.encoder_proj:  # not V0 arch
            _args.encoder_embed_dim = _args.decoder_embed_dim
        _args.dropout = args.decoder_dropout
        _args.attention_dropout = args.decoder_attention_dropout
        _args.activation_dropout = args.decoder_activation_dropout

        decoder = TransformerDecoder(_args, task.target_dictionary,
                                     embed_tokens)
        decoder = cls.maybe_load_pretrained(
            decoder, getattr(args, "load_pretrained_decoder_from", None))

        # Only parameters matched by args.finetune_decoder_params stay trainable;
        # everything else in the (possibly pretrained) decoder is frozen.
        for k, p in decoder.named_parameters():
            p.requires_grad = need_finetuning(args.finetune_decoder_params, k)
        return decoder
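
The loop above relies on a need_finetuning helper that is not shown in the
snippet. A minimal sketch of such a helper, assuming the fine-tuning spec is
either the string "all" or a comma-separated list of parameter-name substrings
(this convention is an assumption, not code taken from the snippet):

def need_finetuning(ft_params, param_name):
    # Assumed convention: "all" unfreezes everything; otherwise a parameter
    # stays trainable if any comma-separated substring matches its name.
    if ft_params == "all":
        return True
    return any(spec in param_name for spec in ft_params.split(","))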
Example #2
    @classmethod
    def build_decoder(cls, args, task):
        _args = copy.deepcopy(args)
        _args.dropout = args.mbart_dropout
        _args.attention_dropout = args.mbart_attention_dropout
        _args.activation_dropout = args.mbart_activation_dropout
        _args.max_target_positions = 1024
        dec_emb = nn.Embedding(
            len(task.tgt_dict), _args.encoder_embed_dim, task.tgt_dict.pad()
        )
        decoder = TransformerDecoder(_args, task.tgt_dict, dec_emb)
        if getattr(args, "load_pretrained_mbart_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_mbart_from
            )
        if getattr(args, "no_final_norm_decoder", False):
            decoder.layer_norm = None
        for k, p in decoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_mbart_decoder_params"
            ) and need_finetuning(
                args.finetune_mbart_decoder_params, k
            ):
                p.requires_grad = True
            else:
                p.requires_grad = False

        compute_cross_attentive_loss = (
            getattr(args, "attentive_cost_regularization", 0.0) > 0.0
        )
        cross_attentive_loss_without_norm = getattr(
            args, "attentive_cost_without_normalize", False
        )
        # Hard-coded to False; the commented-out alternative was
        # getattr(args, "attentive_cost_reverse", False).
        cross_attentive_loss_reverse = False
        # The same (optionally frozen) mBART decoder is shared between the
        # speech and text input branches.
        decoder = TransformerMultiInputDecoder(
            dictionary=task.target_dictionary,
            spch_decoder=decoder,
            text_decoder=decoder,
            compute_cross_attentive_loss=compute_cross_attentive_loss,
            cross_attentive_loss_with_norm=not cross_attentive_loss_without_norm,
            cross_attentive_loss_reverse=cross_attentive_loss_reverse,
        )
        return decoder
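
Examples #2 and #3 guard the fine-tuning check with safe_hasattr. A minimal
sketch of that helper, assuming (as in fairseq-style code) that a missing
attribute and an attribute explicitly set to None should be treated the same:

def safe_hasattr(obj, k):
    # Returns True only if attribute k exists on obj and is not None.
    return getattr(obj, k, None) is not None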
Example #3
    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
        _args = copy.deepcopy(args)
        _args.dropout = args.decoder_dropout
        _args.attention_dropout = args.decoder_attention_dropout
        _args.activation_dropout = args.decoder_activation_dropout
        _args.max_target_positions = 1024

        decoder = TransformerDecoder(_args, task.target_dictionary,
                                     embed_tokens)
        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder,
                checkpoint=args.load_pretrained_decoder_from)
        for k, p in decoder.named_parameters():
            # Freeze pretrained models by default; only parameters matched by
            # args.finetune_decoder_params remain trainable.
            if safe_hasattr(
                args, "finetune_decoder_params"
            ) and XMTransformerModel.finetune_params(
                args.finetune_decoder_params, k
            ):
                p.requires_grad = True
            else:
                p.requires_grad = False
        return decoder
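
All three variants are classmethods that a fairseq-style model wires together
in its build_model hook. A hypothetical caller, sketched only to show where
args, task, and embed_tokens come from; the build_encoder hook and the
surrounding model class are illustrative assumptions, not taken from the
snippets above:

    @classmethod
    def build_model(cls, args, task):
        # Illustrative sketch: decoder input embeddings are built from the
        # target dictionary and handed to build_decoder (Examples #1 and #3).
        embed_tokens = nn.Embedding(
            len(task.target_dictionary),
            args.decoder_embed_dim,
            task.target_dictionary.pad(),
        )
        encoder = cls.build_encoder(args, task)  # assumed companion hook
        decoder = cls.build_decoder(args, task, embed_tokens)
        return cls(encoder, decoder)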