def build_text_encoder(cls, args, src_dictionary, spch_encoder):
    """Build the text ``TransformerEncoder``, optionally tied to the speech encoder.

    When ``args.encoder_shared_layers`` is positive it is first clamped (in
    place, on ``args``) to the smaller of the speech and text encoder depths.
    The final layer norm and the last ``encoder_shared_layers`` layers of the
    text encoder are then passed, pairwise with their speech-encoder
    counterparts, through ``cls.set_shared_layer``; what that call does with
    each pair is controlled by ``args.encoder_shared_layer_level`` and is not
    visible from this method.

    Args:
        args: model options; ``encoder_shared_layers`` may be overwritten.
        src_dictionary: source dictionary, used for vocabulary size and pad index.
        spch_encoder: speech encoder; if ``args.add_speech_eos`` is set, its
            ``.encoder`` attribute is used instead of the object itself.

    Returns:
        The constructed text encoder.

    Raises:
        ValueError: if a text layer selected for sharing is neither an instance
            of the matching speech layer's class nor one of the accepted
            transformer encoder layer classes (checked by name).
    """
    if args.encoder_shared_layers > 0:
        # Cannot share more layers than the shallower encoder has.
        layer_cap = min(args.text_encoder_layers, args.speech_encoder_layers)
        args.encoder_shared_layers = min(args.encoder_shared_layers, layer_cap)

    # TransformerEncoder options are copied from ``args``; two of them live
    # under a different attribute name there.
    renamed = {
        "encoder_embed_dim": "encoder_text_embed_dim",
        "encoder_layers": "text_encoder_layers",
    }
    fields = (
        "encoder_embed_dim",
        "encoder_ffn_embed_dim",
        "encoder_layers",
        "encoder_layerdrop",
        "encoder_attention_heads",
        "encoder_learned_pos",
        "max_source_positions",
        "dropout",
        "encoder_normalize_before",
        "activation_dropout",
        "attention_dropout",
        "activation_fn",
        "adaptive_input",
        "no_token_positional_embeddings",
        "no_scale_embedding",
        "quant_noise_pq",
    )
    values = [getattr(args, renamed.get(name, name)) for name in fields]
    model_args = namedtuple("args", fields)(*values)

    vocab_size = len(src_dictionary)
    pad_idx = src_dictionary.pad()
    token_embedding = nn.Embedding(vocab_size, model_args.encoder_embed_dim, pad_idx)
    text_encoder = TransformerEncoder(model_args, src_dictionary, token_embedding)

    # With add_speech_eos the layers to share live on the ``.encoder`` attribute.
    speech_core = spch_encoder.encoder if args.add_speech_eos else spch_encoder

    if args.encoder_shared_layers <= 0:
        return text_encoder

    num_shared = args.encoder_shared_layers
    share_level = args.encoder_shared_layer_level
    first_shared = args.text_encoder_layers - num_shared

    text_encoder.layer_norm = cls.set_shared_layer(
        share_level, text_encoder.layer_norm, speech_core.layer_norm
    )

    # Pair the last ``num_shared`` speech layers with the last text layers.
    speech_tail = speech_core.transformer_layers[-num_shared:]
    for offset, speech_layer in enumerate(speech_tail):
        idx = first_shared + offset
        text_layer = text_encoder.layers[idx]
        same_class = isinstance(text_layer, type(speech_layer))
        if not same_class and text_layer._get_name() not in (
            'TransformerEncoderLayerBase',
            'TransformerEncoderLayer',
        ):
            raise ValueError(
                "The shared layers are expected from the same class"
            )
        text_encoder.layers[idx] = cls.set_shared_layer(
            share_level, text_layer, speech_layer
        )
    return text_encoder
def build_encoder(cls, args, task):
    """Assemble the dual-input encoder from a wav2vec speech side and an mBART text side.

    Steps, in order:
      1. Build the text ``TransformerEncoder`` from a deep copy of ``args``
         whose dropout settings are replaced by the ``mbart_*`` values and
         whose ``max_source_positions`` is set to 1024.
      2. Build ``Wav2VecEncoderWithAdaptor(args)``.
      3. Optionally load the text encoder from
         ``args.load_pretrained_mbart_from``.
      4. Depending on the ``stack_w2v_mbart_encoder``,
         ``stack_w2v_mbart_nonorm_encoder`` and ``share_w2v_text_encoder``
         options, wrap the wav2vec encoder so that it reuses text encoder
         modules.
      5. Set ``requires_grad`` on every speech and text encoder parameter
         according to the ``finetune_*`` options (frozen when unset).

    Args:
        args: model options.
        task: task object providing ``src_dict``.

    Returns:
        A ``DualInputEncoder`` wrapping the speech and text encoders.
    """
    # The text encoder gets its own dropout configuration; work on a copy so
    # the caller's ``args`` keep their values for the speech encoder.
    mbart_args = copy.deepcopy(args)
    for target, source in (
        ("dropout", "mbart_dropout"),
        ("attention_dropout", "mbart_attention_dropout"),
        ("activation_dropout", "mbart_activation_dropout"),
    ):
        setattr(mbart_args, target, getattr(args, source))
    mbart_args.max_source_positions = 1024

    src_dict = task.src_dict
    token_embedding = nn.Embedding(
        len(src_dict), mbart_args.encoder_embed_dim, src_dict.pad()
    )
    text_encoder = TransformerEncoder(mbart_args, src_dict, token_embedding)
    wav2vec = Wav2VecEncoderWithAdaptor(args)

    pretrained_mbart = getattr(args, "load_pretrained_mbart_from", None)
    if pretrained_mbart:
        text_encoder = checkpoint_utils.load_pretrained_component_from_model(
            component=text_encoder, checkpoint=pretrained_mbart
        )

    spch_encoder = wav2vec
    stack_with_norm = getattr(args, "stack_w2v_mbart_encoder", False)
    if stack_with_norm or getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
        if stack_with_norm:
            # Stacking and sharing are mutually exclusive options.
            assert getattr(args, "share_w2v_text_encoder", False) is False
        else:
            # "nonorm" variant: drop the text encoder's final layer norm
            # before handing its modules to the stacked speech encoder.
            text_encoder.layer_norm = None
        spch_encoder = StackedWav2VecEncoderWithAdaptor(
            wav2vec.w2v_encoder,
            text_encoder.layers,
            text_encoder.layer_norm,
            wav2vec.adaptor,
            args.drop_w2v_layers,
        )
    elif getattr(args, "share_w2v_text_encoder", False):
        spch_encoder = SharedEncoder(
            wav2vec.w2v_encoder,
            text_encoder,
            wav2vec.adaptor,
            args.shared_w2v_layers,
        )

    def _mark_trainable(module, option_name):
        # Freeze pretrained models by default: a parameter stays trainable
        # only if the option exists and need_finetuning() selects its name.
        for param_name, param in module.named_parameters():
            param.requires_grad = bool(
                safe_hasattr(args, option_name)
                and need_finetuning(getattr(args, option_name), param_name)
            )

    # Speech side first, text side second. Text encoder modules are handed to
    # the speech wrappers above, so if parameters end up shared (presumably
    # they do — confirm against the wrapper classes) the text pass decides.
    _mark_trainable(spch_encoder, "finetune_w2v_params")
    _mark_trainable(text_encoder, "finetune_mbart_encoder_params")

    if getattr(args, "attentive_cost_regularization", 0.0) > 0.0:
        cross_attentive_loss_before_last_layer = 0
    else:
        cross_attentive_loss_before_last_layer = -1

    return DualInputEncoder(
        args,
        spch_encoder,
        text_encoder,
        task.src_dict,
        cross_attentive_loss_before_last_layer,
    )