def build_encoder(cls, args, task):
    """Build the dual-input (speech + text) encoder.

    Optionally initializes the speech branch from a pretrained wav2vec
    checkpoint (validating key architecture flags against it), then either
    stacks the text encoder layers on top of the speech encoder or shares
    layers between the two, and finally loads any pretrained dual-input or
    initialization checkpoints.

    Fix: removed a commented-out dead-code line (old build_text_encoder call).
    """
    text_encoder = cls.build_text_encoder(args, task.src_dict)
    speech_encoder = cls.build_speech_encoder(args)
    if args.load_pretrained_wav2vec_encoder:
        # Map wav2vec checkpoint submodules onto our speech-encoder modules.
        component_pairs = (
            ("feature_extractor", speech_encoder.subsample),
            ("post_extract_proj", speech_encoder.feat_proj),
            ("layer_norm", speech_encoder.feat_layer_norm),
            ("encoder.pos_conv", speech_encoder.embed_positions),
            ("encoder.layers", speech_encoder.layers),
            ("encoder.layer_norm", speech_encoder.layer_norm),
            ("mask_emb", speech_encoder.mask_emb),
        )
        state = cls.load_pretrained_speech_text_components(
            args.load_pretrained_wav2vec_encoder, component_pairs
        )
        # Sanity-check that architecture flags agree with the checkpoint;
        # strictness is controlled by --no-strict-check-pretrain-model.
        cls.check_args(
            args.encoder_normalize_before
            == state["cfg"]["model"]["layer_norm_first"],
            not args.no_strict_check_pretrain_model,
            f"encoder_normalize_before {args.encoder_normalize_before} doesn't match with the pretrained model",
        )
        cls.check_args(
            args.activation_fn == state["cfg"]["model"]["activation_fn"],
            not args.no_strict_check_pretrain_model,
            f"activation_fn {args.activation_fn} doesn't match with the pretrained model",
        )

    if getattr(args, "stacked_encoder", False):
        if args.encoder_shared_text_layers_from_begin > 0:
            raise ValueError(
                "We can not stack encoders and share encoders at the same time!"
            )
        speech_encoder = StackedSpeechWavTransformerEncoder(
            speech_encoder, text_encoder.layers, text_encoder.layer_norm
        )
    else:
        cls.share_speech_text_encoder(
            speech_encoder, text_encoder, args.encoder_shared_text_layers_from_begin
        )
    # 0 => apply the cross-attentive regularization before the last layer;
    # -1 => disabled.
    cross_attentive_loss_before_last_layer = (
        0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
    )
    encoder = DualInputEncoder(
        args,
        speech_encoder,
        text_encoder,
        task.src_dict,
        cross_attentive_loss_before_last_layer,
    )
    if args.load_pretrained_speech_text_encoder:
        component_pairs = (
            ("encoder.sup_s2s_speech_encoder", encoder.spch_encoder),
            ("encoder.text_encoder", encoder.text_encoder),
        )
        cls.load_pretrained_speech_text_components(
            args.load_pretrained_speech_text_encoder, component_pairs
        )
    if getattr(args, "load_init_encoder", "") != "":
        checkpoint_utils.load_pretrained_component_from_model(
            encoder, args.load_init_encoder
        )
    return encoder
def build_decoder(cls, args, task):
    """Build the multi-input decoder; speech and text branches share one
    TransformerDecoder instance.

    Idiom fixes: replaced redundant ``True if X else False`` ternaries with
    the boolean expressions themselves (behavior unchanged).
    """
    text_decoder = cls.build_text_decoder(args, task.target_dictionary)
    # A comparison already yields a bool.
    compute_cross_attentive_loss = (
        getattr(args, "attentive_cost_regularization", 0.0) > 0.0
    )
    cross_attentive_loss_without_norm = getattr(
        args, "attentive_cost_without_normalize", False
    )
    cross_attentive_loss_reverse = (
        False  # getattr(args, "attentive_cost_reverse", False)
    )
    if getattr(args, "load_pretrained_text_decoder", "") != "":
        checkpoint_utils.load_pretrained_component_from_model(
            text_decoder, args.load_pretrained_text_decoder
        )
    if args.load_pretrained_speech_text_decoder:
        component_pairs = (("decoder.text_decoder", text_decoder),)
        cls.load_pretrained_speech_text_components(
            args.load_pretrained_speech_text_decoder, component_pairs
        )
    decoder = TransformerMultiInputDecoder(
        dictionary=task.target_dictionary,
        spch_decoder=text_decoder,  # fully shared with the text branch
        text_decoder=text_decoder,
        compute_cross_attentive_loss=compute_cross_attentive_loss,
        cross_attentive_loss_with_norm=not cross_attentive_loss_without_norm,
        cross_attentive_loss_reverse=cross_attentive_loss_reverse,
    )
    if getattr(args, "load_init_decoder", "") != "":
        checkpoint_utils.load_pretrained_component_from_model(
            decoder, args.load_init_decoder
        )
    return decoder
def get_encoder(src_lang, tgt_lang):
    """Lazily build and cache one encoder per source language.

    NOTE(review): closes over `lang_encoders`, `args` and `task` from the
    enclosing scope (not visible in this chunk).
    """
    if src_lang not in lang_encoders:
        built = TokenWiseConvolutionalTransformerEncoder(
            args,
            task.dicts[tgt_lang],
            audio_features=args.input_feat_per_channel,
            langs=task.langs,
        )
        lang_encoders[src_lang] = built
        if args.pretrained_encoder is not None:
            checkpoint_utils.load_pretrained_component_from_model(
                built,
                args.pretrained_encoder,
                args.allow_partial_restore,
            )
    return lang_encoders[src_lang]
def build_encoder(cls, args):
    """Create a ConvTransformer encoder, warm-started from a checkpoint
    when --load-pretrained-encoder-from is set."""
    encoder = ConvTransformerEncoder(args)
    pretrained = getattr(args, "load_pretrained_encoder_from", None)
    if pretrained:
        encoder = checkpoint_utils.load_pretrained_component_from_model(
            component=encoder, checkpoint=pretrained
        )
    return encoder
def build_decoder(cls, args, task, embed_tokens):
    """Create a TransformerDecoderNoExtra, warm-started from a checkpoint
    when --load-pretrained-decoder-from is set."""
    decoder = TransformerDecoderNoExtra(args, task.target_dictionary, embed_tokens)
    pretrained = getattr(args, "load_pretrained_decoder_from", None)
    if pretrained:
        decoder = checkpoint_utils.load_pretrained_component_from_model(
            component=decoder, checkpoint=pretrained
        )
    return decoder
def build_encoder(cls, args, task):
    """Assemble the dual-input encoder, rescale its parameters if requested,
    and apply the various optional pretrained checkpoints in order."""
    spch_encoder = DualInputEncoder.build_spch_encoder(args)
    text_encoder = DualInputEncoder.build_text_encoder(
        args, task.src_dict, spch_encoder
    )
    # 0 => regularize before the last layer; -1 => disabled.
    use_attentive_cost = getattr(args, "attentive_cost_regularization", 0.0) > 0.0
    cross_attentive_loss_before_last_layer = 0 if use_attentive_cost else -1
    encoder = DualInputEncoder(
        args,
        spch_encoder,
        text_encoder,
        task.src_dict,
        cross_attentive_loss_before_last_layer,
    )
    if args.init_scale != 1.0:
        # Rescale all parameters in place without tracking gradients.
        with torch.no_grad():
            for param in encoder.parameters():
                param.data.mul_(args.init_scale)
    if args.load_pretrain_text_encoder != "":
        checkpoint_utils.load_pretrained_component_from_model(
            text_encoder, args.load_pretrain_text_encoder
        )
    if args.load_pretrain_speech_encoder != "":
        # Some speech encoders wrap the real encoder in an .encoder attribute.
        speech_target = (
            spch_encoder.encoder
            if hasattr(spch_encoder, "encoder")
            else spch_encoder
        )
        checkpoint_utils.load_pretrained_component_from_model(
            speech_target, args.load_pretrain_speech_encoder
        )
    if args.load_pretrain_text_encoder_last != "":
        # If encoders are shared, the speech load above would win; loading the
        # text encoder last lets a pretrained MT encoder take precedence.
        checkpoint_utils.load_pretrained_component_from_model(
            text_encoder, args.load_pretrain_text_encoder_last
        )
    if args.load_pretrain_encoder != "":
        checkpoint_utils.load_pretrained_component_from_model(
            encoder, args.load_pretrain_encoder
        )
    return encoder
def build_encoder(cls, args):
    """Build an S2T Transformer encoder; warm-start and log when a
    pretrained checkpoint is configured."""
    encoder = S2TTransformerEncoder(args)
    pretrained = getattr(args, "load_pretrained_encoder_from", None)
    if pretrained:
        encoder = checkpoint_utils.load_pretrained_component_from_model(
            component=encoder, checkpoint=pretrained
        )
        logger.info(
            f"loaded pretrained encoder from: "
            f"{args.load_pretrained_encoder_from}"
        )
    return encoder
def build_decoder(cls, args, task, embed_tokens):
    """Build a monotonic Transformer decoder over task.tgt_dict, optionally
    warm-started from a checkpoint."""
    decoder = TransformerMonotonicDecoder(args, task.tgt_dict, embed_tokens)
    pretrained = getattr(args, "load_pretrained_decoder_from", None)
    if pretrained:
        decoder = checkpoint_utils.load_pretrained_component_from_model(
            component=decoder, checkpoint=pretrained
        )
    return decoder
def build_encoder(cls, args):
    """Wrap an augmented-memory ConvTransformer encoder in a SequenceEncoder,
    optionally warm-started from a checkpoint."""
    inner = AugmentedMemoryConvTransformerEncoder(args)
    encoder = SequenceEncoder(args, inner)
    pretrained = getattr(args, "load_pretrained_encoder_from", None)
    # Note: explicit `is not None` check (an empty string would still load).
    if pretrained is not None:
        encoder = checkpoint_utils.load_pretrained_component_from_model(
            component=encoder, checkpoint=pretrained
        )
    return encoder
def build_encoder(cls, args, dictionary):
    """Build the pre-training encoder bundle: text encoder, speech encoder,
    supervised-s2s speech encoder and unsupervised speech encoder, with
    optional pretrained initialization and stacking/sharing of text layers.
    """
    text_encoder = cls.build_text_encoder(args, dictionary)
    if getattr(args, "load_pretrained_mbart_encoder_from", None):
        text_encoder = checkpoint_utils.load_pretrained_component_from_model(
            component=text_encoder,
            checkpoint=args.load_pretrained_mbart_encoder_from,
        )
    speech_encoder = cls.build_speech_encoder(args)
    if getattr(args, "load_pretrained_feature_extractor_from", None):

        def load_feature_extractor(component, checkpoint):
            # Copy only the "feature_extractor.*" weights from the checkpoint
            # into `component`, stripping that prefix from each key.
            if not PathManager.exists(checkpoint):
                raise IOError(
                    "Model file not found: {}".format(checkpoint))
            state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
            component_state_dict = OrderedDict()
            component_prefix = "feature_extractor"
            for key in state["model"].keys():
                if key.startswith(component_prefix):
                    # +1 skips the "." separator after the prefix
                    component_subkey = key[len(component_prefix) + 1:]
                    component_state_dict[component_subkey] = state[
                        "model"][key]
            # strict=True: the checkpoint must cover every parameter.
            component.load_state_dict(component_state_dict, strict=True)
            return component

        speech_encoder.subsample = load_feature_extractor(
            speech_encoder.subsample,
            args.load_pretrained_feature_extractor_from)
    # By default the supervised-s2s path reuses the plain speech encoder.
    speech_s2s_encoder = speech_encoder
    unsup_speech_encoder = cls.build_unsup_speech_encoder(
        args, speech_encoder)
    if getattr(args, "stacked_encoder", "none") != "none":
        if args.encoder_shared_text_layers_from_begin > 0:
            raise ValueError(
                "We can not stack encoders and share encoders at the same time!"
            )
        # Stack the text-encoder layers on top of the speech encoder for
        # the supervised-s2s path.
        speech_s2s_encoder = StackedSpeechWavTransformerEncoder(
            speech_encoder, text_encoder.layers, text_encoder.layer_norm)
        if args.stacked_encoder == "all":
            # "all" stacks every branch, including the unsupervised one.
            speech_encoder = speech_s2s_encoder
            unsup_speech_encoder = StackedSpeechWavTransformerEncoder(
                unsup_speech_encoder, text_encoder.layers,
                text_encoder.layer_norm)
    else:
        cls.share_speech_text_encoder(
            speech_encoder, text_encoder,
            args.encoder_shared_text_layers_from_begin)
    return SpeechTextPreTrainEncoder(
        dictionary,
        speech_encoder,
        speech_s2s_encoder,
        unsup_speech_encoder,
        text_encoder,
    )
def build_decoder(cls, cfg: Wav2Vec2Seq2SeqModConfig, tgt_dict, embed_tokens):
    """Build a modified Transformer decoder; warm-start it and log the
    checkpoint path when one is configured."""
    decoder = TransformerDecoderMod(cfg, tgt_dict, embed_tokens)
    pretrained = getattr(cfg, "load_pretrained_decoder_from", None)
    if pretrained:
        decoder = checkpoint_utils.load_pretrained_component_from_model(
            component=decoder, checkpoint=pretrained
        )
        logger.info(
            f"loaded pretrained decoder from: "
            f"{cfg.load_pretrained_decoder_from}"
        )
    return decoder
def build_decoder(cls, args, text_dictionary, speech_dictionary,
                  speech_output_embedding):
    """Build the pre-training decoder pair: a text decoder (optionally
    initialized from mBART) plus a dummy speech decoder."""
    txt_dec = cls.build_text_decoder(args, text_dictionary)
    spch_dec = cls.build_dummy_speech_decoder(
        args, speech_dictionary, speech_output_embedding
    )
    mbart_ckpt = getattr(args, "load_pretrained_mbart_decoder_from", None)
    if mbart_ckpt:
        txt_dec = checkpoint_utils.load_pretrained_component_from_model(
            component=txt_dec,
            checkpoint=mbart_ckpt,
        )
    return SpeechTextPreTrainDecoder(text_dictionary, spch_dec, txt_dec)
def build_decoder(cls, args, task, embed_tokens):
    """Build a monotonic Transformer decoder.

    The decoder class is imported lazily inside the function (presumably to
    avoid a circular import — the original kept it local too).
    """
    from examples.simultaneous_translation.models.transformer_monotonic_attention import (
        TransformerMonotonicDecoder,
    )

    decoder = TransformerMonotonicDecoder(args, task.tgt_dict, embed_tokens)
    pretrained = getattr(args, "load_pretrained_decoder_from", None)
    if pretrained:
        decoder = checkpoint_utils.load_pretrained_component_from_model(
            component=decoder, checkpoint=pretrained
        )
    return decoder
def build_encoder(cls, args):
    """Build the S2T encoder; load pretrained weights only when the
    checkpoint file actually exists, warning otherwise."""
    encoder = S2TTransformerEncoder(args)
    pretraining_path = getattr(args, "load_pretrained_encoder_from", None)
    if pretraining_path is None:
        return encoder
    if not Path(pretraining_path).exists():
        # A missing checkpoint is tolerated: warn and train from scratch.
        logger.warning(
            f"skipped pretraining because {pretraining_path} does not exist"
        )
    else:
        encoder = checkpoint_utils.load_pretrained_component_from_model(
            component=encoder, checkpoint=pretraining_path
        )
        logger.info(
            f"loaded pretrained encoder from: {pretraining_path}")
    return encoder
def build_encoder(cls, args, task):
    """Build a Berard-style BLSTM encoder; the layer specs arrive as
    string-encoded Python literals on args."""
    encoder = BerardEncoder(
        input_layers=literal_eval(args.input_layers),
        conv_layers=literal_eval(args.conv_layers),
        in_channels=args.input_channels,
        input_feat_per_channel=args.input_feat_per_channel,
        num_blstm_layers=args.num_blstm_layers,
        lstm_size=args.lstm_size,
        dropout=args.dropout,
    )
    pretrained = getattr(args, "load_pretrained_encoder_from", None)
    if pretrained:
        encoder = checkpoint_utils.load_pretrained_component_from_model(
            component=encoder, checkpoint=pretrained
        )
    return encoder
def build_decoder(cls, args, task):
    """Build an LSTM decoder sized to match the bidirectional BLSTM
    encoder's output dimension."""
    decoder = LSTMDecoder(
        dictionary=task.target_dictionary,
        embed_dim=args.decoder_embed_dim,
        num_layers=args.decoder_num_layers,
        hidden_size=args.decoder_hidden_dim,
        dropout=args.dropout,
        encoder_output_dim=2 * args.lstm_size,  # bidirectional
        attention_dim=args.attention_dim,
        output_layer_dim=args.output_layer_dim,
    )
    pretrained = getattr(args, "load_pretrained_decoder_from", None)
    if pretrained:
        decoder = checkpoint_utils.load_pretrained_component_from_model(
            component=decoder, checkpoint=pretrained
        )
    return decoder
def build_decoder(cls, args, task):
    """Build an mBART-initialized Transformer decoder and wrap it as a
    multi-input decoder; the same decoder instance serves both the speech
    and the text branches (fully shared).

    Idiom fix: replaced redundant ``True if X else False`` forms with the
    boolean expressions themselves (behavior unchanged).
    """
    # Apply mBART-specific overrides on a deep copy so `args` is untouched.
    _args = copy.deepcopy(args)
    _args.dropout = args.mbart_dropout
    _args.attention_dropout = args.mbart_attention_dropout
    _args.activation_dropout = args.mbart_activation_dropout
    _args.max_target_positions = 1024
    dec_emb = nn.Embedding(
        len(task.tgt_dict), _args.encoder_embed_dim, task.tgt_dict.pad()
    )
    decoder = TransformerDecoder(_args, task.tgt_dict, dec_emb)
    if getattr(args, "load_pretrained_mbart_from", None):
        decoder = checkpoint_utils.load_pretrained_component_from_model(
            component=decoder, checkpoint=args.load_pretrained_mbart_from
        )
    if getattr(args, "no_final_norm_decoder", False):
        decoder.layer_norm = None
    for k, p in decoder.named_parameters():
        # Freeze pretrained models by default; unfreeze only parameters
        # explicitly selected via --finetune-mbart-decoder-params.
        if safe_hasattr(
            args, "finetune_mbart_decoder_params"
        ) and need_finetuning(
            args.finetune_mbart_decoder_params, k
        ):
            p.requires_grad = True
        else:
            p.requires_grad = False

    compute_cross_attentive_loss = (
        getattr(args, "attentive_cost_regularization", 0.0) > 0.0
    )
    cross_attentive_loss_without_norm = getattr(
        args, "attentive_cost_without_normalize", False
    )
    cross_attentive_loss_reverse = (
        False  # getattr(args, "attentive_cost_reverse", False)
    )
    decoder = TransformerMultiInputDecoder(
        dictionary=task.target_dictionary,
        spch_decoder=decoder,
        text_decoder=decoder,
        compute_cross_attentive_loss=compute_cross_attentive_loss,
        cross_attentive_loss_with_norm=not cross_attentive_loss_without_norm,
        cross_attentive_loss_reverse=cross_attentive_loss_reverse,
    )
    return decoder
def build_encoder(cls, args):
    """Build the S2S Conformer encoder, deriving input feature dimensions
    from the dataset config, and optionally load pretrained weights when the
    checkpoint file exists.

    Fix: removed a stray ``print(args)`` debug statement left in the code.
    """
    data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml)
    # The encoder's input dims come from the data config, not the CLI.
    args.input_feat_per_channel = data_cfg.input_feat_per_channel
    args.input_channels = data_cfg.input_transformed_channels
    encoder = S2SConformerEncoder(args)
    pretraining_path = getattr(args, "load_pretrained_encoder_from", None)
    if pretraining_path is not None:
        if not Path(pretraining_path).exists():
            # A missing checkpoint is tolerated: warn and train from scratch.
            logger.warning(
                f"skipped pretraining because {pretraining_path} does not exist"
            )
        else:
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=pretraining_path
            )
            logger.info(
                f"loaded pretrained encoder from: {pretraining_path}")
    return encoder
def build_decoder(cls, args, task, embed_tokens):
    """Build a Transformer decoder with decoder-specific dropout overrides,
    optionally warm-started, then freeze all parameters except those
    selected for finetuning."""
    _args = copy.deepcopy(args)
    _args.dropout = args.decoder_dropout
    _args.attention_dropout = args.decoder_attention_dropout
    _args.activation_dropout = args.decoder_activation_dropout
    _args.max_target_positions = 1024
    decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens)
    pretrained = getattr(args, "load_pretrained_decoder_from", None)
    if pretrained:
        decoder = checkpoint_utils.load_pretrained_component_from_model(
            component=decoder, checkpoint=pretrained
        )
    # Freeze pretrained models by default; the hasattr check is loop-invariant
    # so it is hoisted, and short-circuiting still skips finetune_params when
    # no filter is configured.
    has_filter = safe_hasattr(args, 'finetune_decoder_params')
    for name, param in decoder.named_parameters():
        param.requires_grad = has_filter and XMTransformerModel.finetune_params(
            args.finetune_decoder_params, name
        )
    return decoder
def build_encoder(cls, args, task):
    """Build the dual-input encoder: an mBART text encoder plus a wav2vec
    speech encoder, optionally stacked or shared, with pretrained weights
    frozen except for explicitly finetuned parameters."""
    # Apply mBART-specific overrides on a deep copy so `args` is untouched.
    _args = copy.deepcopy(args)
    _args.dropout = args.mbart_dropout
    _args.attention_dropout = args.mbart_attention_dropout
    _args.activation_dropout = args.mbart_activation_dropout
    _args.max_source_positions = 1024
    enc_emb = nn.Embedding(
        len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
    )
    text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
    spch_encoder = Wav2VecEncoderWithAdaptor(args)
    if getattr(args, "load_pretrained_mbart_from", None):
        text_encoder = checkpoint_utils.load_pretrained_component_from_model(
            component=text_encoder, checkpoint=args.load_pretrained_mbart_from
        )
    if getattr(args, "stack_w2v_mbart_encoder", False):
        # Stacking and sharing are mutually exclusive.
        assert getattr(args, "share_w2v_text_encoder", False) is False
        spch_encoder = StackedWav2VecEncoderWithAdaptor(
            spch_encoder.w2v_encoder,
            text_encoder.layers,
            text_encoder.layer_norm,
            spch_encoder.adaptor,
            args.drop_w2v_layers,
        )
    elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
        # Same stacking, but drop the text encoder's final layer norm first.
        text_encoder.layer_norm = None
        spch_encoder = StackedWav2VecEncoderWithAdaptor(
            spch_encoder.w2v_encoder,
            text_encoder.layers,
            text_encoder.layer_norm,
            spch_encoder.adaptor,
            args.drop_w2v_layers,
        )
    elif getattr(args, "share_w2v_text_encoder", False):
        spch_encoder = SharedEncoder(
            spch_encoder.w2v_encoder,
            text_encoder,
            spch_encoder.adaptor,
            args.shared_w2v_layers,
        )

    for k, p in spch_encoder.named_parameters():
        # Freeze pretrained models by default
        if safe_hasattr(
            args, "finetune_w2v_params"
        ) and need_finetuning(args.finetune_w2v_params, k):
            p.requires_grad = True
        else:
            p.requires_grad = False
    for k, p in text_encoder.named_parameters():
        # Freeze pretrained models by default
        if safe_hasattr(
            args, "finetune_mbart_encoder_params"
        ) and need_finetuning(
            args.finetune_mbart_encoder_params, k
        ):
            p.requires_grad = True
        else:
            p.requires_grad = False
    # 0 => apply cross-attentive regularization before the last layer;
    # -1 => disabled.
    cross_attentive_loss_before_last_layer = (
        0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
    )
    encoder = DualInputEncoder(
        args,
        spch_encoder,
        text_encoder,
        task.src_dict,
        cross_attentive_loss_before_last_layer,
    )
    return encoder
def build_decoder(cls, args, task):
    """Build the multi-input decoder for joint speech/text training.

    A minimal config is assembled as a dict and converted to a namedtuple so
    the two TransformerDecoder instances receive an attribute-style args
    object containing only the decoder-relevant fields. Both decoders share
    one embedding table, and layer sharing between them is delegated to
    TransformerMultiInputDecoder.share_spchdecoder.

    Idiom fixes: replaced redundant ``True if X else False`` forms with the
    boolean expressions themselves (behavior unchanged).
    """
    dec_cfg = {
        "decoder_layerdrop": args.decoder_layerdrop,
        "share_decoder_input_output_embed": args.share_decoder_input_output_embed,
        "decoder_embed_dim": args.decoder_embed_dim,
        "max_target_positions": args.max_target_positions,
        "dropout": args.dropout,
        "encoder_learned_pos": args.encoder_learned_pos,
        "decoder_learned_pos": args.decoder_learned_pos,
        "layernorm_embedding": args.layernorm_embedding,
        "decoder_normalize_before": args.decoder_normalize_before,
        "activation_dropout": args.activation_dropout,
        "attention_dropout": args.attention_dropout,
        "decoder_ffn_embed_dim": args.decoder_ffn_embed_dim,
        "decoder_layers": args.decoder_layers,
        "decoder_attention_heads": args.decoder_attention_heads,
        "decoder_output_dim": args.decoder_embed_dim,
        "no_scale_embedding": args.no_scale_embedding,
        "adaptive_input": args.adaptive_input,
        "quant_noise_pq": args.quant_noise_pq,
        "adaptive_softmax_cutoff": args.adaptive_softmax_cutoff,
        "tie_adaptive_weights": args.tie_adaptive_weights,
        "no_token_positional_embeddings": args.no_token_positional_embeddings,
    }
    # namedtuple gives attribute access (dec_cfg.decoder_layers, ...).
    dec_cfg = namedtuple("args", dec_cfg.keys())(*dec_cfg.values())
    dec_emb = nn.Embedding(
        len(task.target_dictionary),
        args.decoder_embed_dim,
        task.target_dictionary.pad(),
    )
    compute_cross_attentive_loss = (
        getattr(args, "attentive_cost_regularization", 0.0) > 0.0
    )
    cross_attentive_loss_without_norm = getattr(
        args, "attentive_cost_without_normalize", False
    )
    cross_attentive_loss_reverse = (
        False  # getattr(args, "attentive_cost_reverse", False)
    )
    # Both decoders share the same embedding table (dec_emb).
    text_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb)
    spch_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb)
    spch_decoder = TransformerMultiInputDecoder.share_spchdecoder(
        args, text_decoder, spch_decoder
    )
    decoder = TransformerMultiInputDecoder(
        dictionary=task.target_dictionary,
        spch_decoder=spch_decoder,
        text_decoder=text_decoder,
        compute_cross_attentive_loss=compute_cross_attentive_loss,
        cross_attentive_loss_with_norm=not cross_attentive_loss_without_norm,
        cross_attentive_loss_reverse=cross_attentive_loss_reverse,
    )
    if args.init_scale != 1.0:
        # Rescale all parameters in place without tracking gradients.
        with torch.no_grad():
            for param in decoder.parameters():
                param.data.mul_(args.init_scale)
    if args.load_pretrain_decoder != "":
        try:
            checkpoint_utils.load_pretrained_component_from_model(
                decoder, args.load_pretrain_decoder
            )
        except RuntimeError:
            # The checkpoint may only match the text decoder; load it there,
            # and also into the speech decoder when layers are shared.
            checkpoint_utils.load_pretrained_component_from_model(
                decoder.text_decoder, args.load_pretrain_decoder
            )
            if args.decoder_shared_layer_level > 0:
                checkpoint_utils.load_pretrained_component_from_model(
                    decoder.spch_decoder, args.load_pretrain_decoder
                )
    return decoder