Exemplo n.º 1
0
 def build_text_encoder(cls, args, src_dictionary, spch_encoder):
     """Build the text-side TransformerEncoder and optionally tie its top
     layers (and final layer norm) to those of the speech encoder.

     Side effect: when ``args.encoder_shared_layers`` is positive it is
     clamped in place so it does not exceed ``args.speech_encoder_layers``
     or ``args.text_encoder_layers``.

     Args:
         args: model configuration namespace; the text encoder config is
             read from it (``encoder_text_embed_dim`` and
             ``text_encoder_layers`` become ``encoder_embed_dim`` and
             ``encoder_layers``).
         src_dictionary: source dictionary; ``len()`` and ``pad()`` size
             the token embedding.
         spch_encoder: speech encoder providing ``layer_norm`` and
             ``transformer_layers`` to share; when ``args.add_speech_eos``
             is set its ``.encoder`` attribute is used instead.

     Returns:
         The constructed text encoder, with its last
         ``args.encoder_shared_layers`` layers and its ``layer_norm``
         replaced by ``cls.set_shared_layer(...)`` results when sharing.

     Raises:
         ValueError: if a text layer to be shared is not an instance of the
             matching speech layer's class and is not named
             ``TransformerEncoderLayerBase``/``TransformerEncoderLayer``.
     """
     if args.encoder_shared_layers > 0:
         # Cannot share more layers than the shallower encoder provides.
         args.encoder_shared_layers = min(
             args.encoder_shared_layers,
             args.speech_encoder_layers,
             args.text_encoder_layers,
         )

     # TransformerEncoder config fields, in order; each is read from the
     # same-named attribute on ``args`` unless remapped here.
     renamed = {
         "encoder_embed_dim": "encoder_text_embed_dim",
         "encoder_layers": "text_encoder_layers",
     }
     fields = (
         "encoder_embed_dim",
         "encoder_ffn_embed_dim",
         "encoder_layers",
         "encoder_layerdrop",
         "encoder_attention_heads",
         "encoder_learned_pos",
         "max_source_positions",
         "dropout",
         "encoder_normalize_before",
         "activation_dropout",
         "attention_dropout",
         "activation_fn",
         "adaptive_input",
         "no_token_positional_embeddings",
         "no_scale_embedding",
         "quant_noise_pq",
     )
     model_args = namedtuple("args", fields)(
         *(getattr(args, renamed.get(field, field)) for field in fields)
     )

     embed_tokens = nn.Embedding(
         len(src_dictionary), model_args.encoder_embed_dim, src_dictionary.pad()
     )
     text_encoder = TransformerEncoder(model_args, src_dictionary, embed_tokens)

     speech = spch_encoder.encoder if args.add_speech_eos else spch_encoder
     n_shared = args.encoder_shared_layers
     if n_shared <= 0:
         return text_encoder

     level = args.encoder_shared_layer_level
     text_encoder.layer_norm = cls.set_shared_layer(
         level, text_encoder.layer_norm, speech.layer_norm
     )
     # Pair the last ``n_shared`` speech layers with the last ``n_shared``
     # text layers, in order.
     first_text_idx = args.text_encoder_layers - n_shared
     for offset, speech_layer in enumerate(speech.transformer_layers[-n_shared:]):
         idx = first_text_idx + offset
         text_layer = text_encoder.layers[idx]
         same_class = isinstance(text_layer, type(speech_layer))
         if not same_class and text_layer._get_name() not in (
             "TransformerEncoderLayerBase",
             "TransformerEncoderLayer",
         ):
             raise ValueError(
                 "The shared layers are expected from the same class"
             )
         text_encoder.layers[idx] = cls.set_shared_layer(
             level, text_layer, speech_layer
         )
     return text_encoder
Exemplo n.º 2
0
    def build_encoder(cls, args, task):
        """Build the dual-input (speech + text) encoder.

        The text encoder is a TransformerEncoder built from a deep copy of
        ``args`` carrying the ``mbart_*`` dropout values and
        ``max_source_positions = 1024``. It is optionally loaded from
        ``args.load_pretrained_mbart_from``. The speech encoder is a
        Wav2VecEncoderWithAdaptor built from the unmodified ``args``.

        Depending on the flags, the speech encoder is then replaced by:
        - a StackedWav2VecEncoderWithAdaptor that includes the text layers
          (``stack_w2v_mbart_encoder``), or
        - the same stacking without the text encoder's final layer norm
          (``stack_w2v_mbart_nonorm_encoder``), or
        - a SharedEncoder (``share_w2v_text_encoder``).

        Every parameter is then frozen unless selected for fine-tuning via
        ``args.finetune_w2v_params`` / ``args.finetune_mbart_encoder_params``
        and ``need_finetuning``.

        Args:
            args: model configuration namespace (read, not modified).
            task: provides ``src_dict`` for the text embedding and for
                DualInputEncoder.

        Returns:
            A DualInputEncoder wrapping the speech and text encoders.

        Raises:
            ValueError: if ``stack_w2v_mbart_encoder`` is combined with a
                truthy ``share_w2v_text_encoder``.
        """
        # The text encoder gets a private copy of ``args`` with the mbart_*
        # dropout settings; the speech encoder below uses ``args`` itself.
        _args = copy.deepcopy(args)
        _args.dropout = args.mbart_dropout
        _args.attention_dropout = args.mbart_attention_dropout
        _args.activation_dropout = args.mbart_activation_dropout
        _args.max_source_positions = 1024
        enc_emb = nn.Embedding(
            len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
        )
        text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
        spch_encoder = Wav2VecEncoderWithAdaptor(args)
        if getattr(args, "load_pretrained_mbart_from", None):
            text_encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=text_encoder, checkpoint=args.load_pretrained_mbart_from
            )

        def _stack_text_layers_on(w2v_with_adaptor):
            # Build the stacked encoder from the wav2vec encoder, the current
            # text encoder layers / final layer_norm, and the adaptor.
            return StackedWav2VecEncoderWithAdaptor(
                w2v_with_adaptor.w2v_encoder,
                text_encoder.layers,
                text_encoder.layer_norm,
                w2v_with_adaptor.adaptor,
                args.drop_w2v_layers,
            )

        if getattr(args, "stack_w2v_mbart_encoder", False):
            # Explicit check instead of ``assert`` so it still runs under
            # ``python -O``. Truthiness matches how the share flag is read in
            # the last branch below.
            if getattr(args, "share_w2v_text_encoder", False):
                raise ValueError(
                    "stack_w2v_mbart_encoder cannot be combined with "
                    "share_w2v_text_encoder"
                )
            spch_encoder = _stack_text_layers_on(spch_encoder)
        elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
            # Same stacking, but the text encoder's final layer norm is
            # dropped (also on text_encoder itself).
            text_encoder.layer_norm = None
            spch_encoder = _stack_text_layers_on(spch_encoder)
        elif getattr(args, "share_w2v_text_encoder", False):
            spch_encoder = SharedEncoder(
                spch_encoder.w2v_encoder,
                text_encoder,
                spch_encoder.adaptor,
                args.shared_w2v_layers,
            )

        def _freeze_unless_finetuned(module, finetune_attr):
            # requires_grad is True only when ``args`` has ``finetune_attr``
            # (per safe_hasattr) and need_finetuning() selects the parameter
            # name; False otherwise. The attribute check is loop-invariant.
            enabled = safe_hasattr(args, finetune_attr)
            spec = getattr(args, finetune_attr, None)
            for name, param in module.named_parameters():
                param.requires_grad = bool(enabled and need_finetuning(spec, name))

        # Freeze pretrained models by default. The text-encoder pass runs
        # second, so it decides requires_grad for any parameter reachable
        # from both modules.
        _freeze_unless_finetuned(spch_encoder, "finetune_w2v_params")
        _freeze_unless_finetuned(text_encoder, "finetune_mbart_encoder_params")

        # 0 when attentive cost regularization is enabled, -1 otherwise;
        # the value is interpreted by DualInputEncoder.
        cross_attentive_loss_before_last_layer = (
            0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
        )
        encoder = DualInputEncoder(
            args,
            spch_encoder,
            text_encoder,
            task.src_dict,
            cross_attentive_loss_before_last_layer,
        )
        return encoder