Example #1
    def build_encoder(cls, args, task):
        # text_encoder = cls.build_text_encoder(args, task.source_dictionary )
        text_encoder = cls.build_text_encoder(args, task.src_dict)
        speech_encoder = cls.build_speech_encoder(args)
        if args.load_pretrained_wav2vec_encoder:
            component_pairs = (
                ("feature_extractor", speech_encoder.subsample),
                ("post_extract_proj", speech_encoder.feat_proj),
                ("layer_norm", speech_encoder.feat_layer_norm),
                ("encoder.pos_conv", speech_encoder.embed_positions),
                ("encoder.layers", speech_encoder.layers),
                ("encoder.layer_norm", speech_encoder.layer_norm),
                ("mask_emb", speech_encoder.mask_emb),
            )
            state = cls.load_pretrained_speech_text_components(
                args.load_pretrained_wav2vec_encoder, component_pairs)
            cls.check_args(
                args.encoder_normalize_before == state["cfg"]["model"]
                ["layer_norm_first"],
                not args.no_strict_check_pretrain_model,
                f"encoder_normalize_before {args.encoder_normalize_before} doesn't match with the pretrained model",
            )
            cls.check_args(
                args.activation_fn == state["cfg"]["model"]["activation_fn"],
                not args.no_strict_check_pretrain_model,
                f"activation_fn {args.activation_fn} doesn't match with the pretrained model",
            )

        if getattr(args, "stacked_encoder", False):
            if args.encoder_shared_text_layers_from_begin > 0:
                raise ValueError(
                    "We can not stack encoders and share encoders at the same time!"
                )
            speech_encoder = StackedSpeechWavTransformerEncoder(
                speech_encoder, text_encoder.layers, text_encoder.layer_norm)
        else:
            cls.share_speech_text_encoder(
                speech_encoder, text_encoder,
                args.encoder_shared_text_layers_from_begin)

        cross_attentive_loss_before_last_layer = (0 if getattr(
            args, "attentive_cost_regularization", 0.0) > 0.0 else -1)
        encoder = DualInputEncoder(
            args,
            speech_encoder,
            text_encoder,
            task.src_dict,
            cross_attentive_loss_before_last_layer,
        )
        if args.load_pretrained_speech_text_encoder:
            component_pairs = (
                ("encoder.sup_s2s_speech_encoder", encoder.spch_encoder),
                ("encoder.text_encoder", encoder.text_encoder),
            )
            cls.load_pretrained_speech_text_components(
                args.load_pretrained_speech_text_encoder, component_pairs)
        if getattr(args, "load_init_encoder", "") != "":
            checkpoint_utils.load_pretrained_component_from_model(
                encoder, args.load_init_encoder)
        return encoder
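
For context, Example #1's component_pairs maps checkpoint key prefixes to the sub-modules they should initialize. A minimal sketch of what a pair-wise loader like load_pretrained_speech_text_components plausibly does (the internals below are assumptions, not the actual fairseq code):

    # Minimal sketch of a component-pair loader; internals are assumed,
    # not copied from fairseq.
    from fairseq import checkpoint_utils


    def load_components_sketch(checkpoint_path, component_pairs):
        # Load the full checkpoint once, then slice out each component.
        state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint_path)
        for prefix, module in component_pairs:
            # Keep only keys under "<prefix>." and strip that prefix so the
            # remaining keys line up with the bare sub-module.
            sub_state = {
                key[len(prefix) + 1:]: value
                for key, value in state["model"].items()
                if key.startswith(prefix + ".")
            }
            module.load_state_dict(sub_state, strict=True)
        # Return the raw state so callers can sanity-check state["cfg"],
        # as Example #1 does for layer_norm_first and activation_fn.
        return state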
Example #2
    def build_decoder(cls, args, task):
        text_decoder = cls.build_text_decoder(args, task.target_dictionary)
        compute_cross_attentive_loss = (
            getattr(args, "attentive_cost_regularization", 0.0) > 0.0)
        cross_attentive_loss_without_norm = getattr(
            args, "attentive_cost_without_normalize", False)
        cross_attentive_loss_reverse = (
            False  # getattr(args, "attentive_cost_reverse", False)
        )
        if getattr(args, "load_pretrained_text_decoder", "") != "":
            checkpoint_utils.load_pretrained_component_from_model(
                text_decoder, args.load_pretrained_text_decoder)

        if args.load_pretrained_speech_text_decoder:
            component_pairs = (("decoder.text_decoder", text_decoder), )
            cls.load_pretrained_speech_text_components(
                args.load_pretrained_speech_text_decoder, component_pairs)

        decoder = TransformerMultiInputDecoder(
            dictionary=task.target_dictionary,
            spch_decoder=text_decoder,
            text_decoder=text_decoder,
            compute_cross_attentive_loss=compute_cross_attentive_loss,
            cross_attentive_loss_with_norm=not cross_attentive_loss_without_norm,
            cross_attentive_loss_reverse=cross_attentive_loss_reverse,
        )
        if getattr(args, "load_init_decoder", "") != "":
            checkpoint_utils.load_pretrained_component_from_model(
                decoder, args.load_init_decoder)
        return decoder
Example #3
 def get_encoder(src_lang, tgt_lang):
     if src_lang not in lang_encoders:
         lang_encoders[src_lang] = TokenWiseConvolutionalTransformerEncoder(
                 args,
                 task.dicts[tgt_lang],
                 audio_features=args.input_feat_per_channel,
                 langs=task.langs)
         if args.pretrained_encoder is not None:
             checkpoint_utils.load_pretrained_component_from_model(
                 lang_encoders[src_lang], args.pretrained_encoder,
                 args.allow_partial_restore)
     return lang_encoders[src_lang]
Example #4
 def build_encoder(cls, args):
     encoder = ConvTransformerEncoder(args)
     if getattr(args, "load_pretrained_encoder_from", None):
         encoder = checkpoint_utils.load_pretrained_component_from_model(
             component=encoder, checkpoint=args.load_pretrained_encoder_from
         )
     return encoder
Example #5
 def build_decoder(cls, args, task, embed_tokens):
     decoder = TransformerDecoderNoExtra(args, task.target_dictionary, embed_tokens)
     if getattr(args, "load_pretrained_decoder_from", None):
         decoder = checkpoint_utils.load_pretrained_component_from_model(
             component=decoder, checkpoint=args.load_pretrained_decoder_from
         )
     return decoder
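
All of these examples ultimately delegate to checkpoint_utils.load_pretrained_component_from_model, which slices an "encoder.*" or "decoder.*" subtree out of a full-model checkpoint and loads it into the given component. A simplified sketch of that behavior (mirroring the inline load_feature_extractor helper in Example #10; see fairseq's checkpoint_utils for the authoritative version):

    # Simplified sketch of load_pretrained_component_from_model.
    from collections import OrderedDict

    from fairseq import checkpoint_utils
    from fairseq.models import FairseqEncoder


    def load_component_sketch(component, checkpoint):
        state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
        # Full-model keys look like "encoder.layers.0..." / "decoder...";
        # choose the prefix that matches the component being restored.
        prefix = "encoder" if isinstance(component, FairseqEncoder) else "decoder"
        component_state_dict = OrderedDict()
        for key in state["model"].keys():
            if key.startswith(prefix):
                # Drop "<prefix>." so keys match the standalone component.
                component_state_dict[key[len(prefix) + 1:]] = state["model"][key]
        component.load_state_dict(component_state_dict, strict=True)
        return component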
Example #6
    def build_encoder(cls, args, task):
        spch_encoder = DualInputEncoder.build_spch_encoder(args)
        text_encoder = DualInputEncoder.build_text_encoder(
            args, task.src_dict, spch_encoder
        )
        cross_attentive_loss_before_last_layer = (
            0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
        )
        encoder = DualInputEncoder(
            args,
            spch_encoder,
            text_encoder,
            task.src_dict,
            cross_attentive_loss_before_last_layer,
        )
        if args.init_scale != 1.0:
            with torch.no_grad():
                for param in encoder.parameters():
                    param.data.mul_(args.init_scale)
        if args.load_pretrain_text_encoder != "":
            checkpoint_utils.load_pretrained_component_from_model(
                text_encoder, args.load_pretrain_text_encoder
            )
        if args.load_pretrain_speech_encoder != "":
            if hasattr(spch_encoder, "encoder"):
                checkpoint_utils.load_pretrained_component_from_model(
                    spch_encoder.encoder, args.load_pretrain_speech_encoder
                )
            else:
                checkpoint_utils.load_pretrained_component_from_model(
                    spch_encoder, args.load_pretrain_speech_encoder
                )
        if (
            args.load_pretrain_text_encoder_last != ""
        ):  # if share encoder, speech encoder parameters will be used.
            # It provides a chance to use pre-trained mt encoder instead
            checkpoint_utils.load_pretrained_component_from_model(
                text_encoder, args.load_pretrain_text_encoder_last
            )

        if args.load_pretrain_encoder != "":
            checkpoint_utils.load_pretrained_component_from_model(
                encoder, args.load_pretrain_encoder
            )
        return encoder
Example #7
 def build_encoder(cls, args):
     encoder = S2TTransformerEncoder(args)
     if getattr(args, "load_pretrained_encoder_from", None):
         encoder = checkpoint_utils.load_pretrained_component_from_model(
             component=encoder,
             checkpoint=args.load_pretrained_encoder_from)
         logger.info(f"loaded pretrained encoder from: "
                     f"{args.load_pretrained_encoder_from}")
     return encoder
Example #8
    def build_decoder(cls, args, task, embed_tokens):
        tgt_dict = task.tgt_dict

        decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)

        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder,
                checkpoint=args.load_pretrained_decoder_from)
        return decoder
Example #9
    def build_encoder(cls, args):
        encoder = SequenceEncoder(args,
                                  AugmentedMemoryConvTransformerEncoder(args))

        if getattr(args, "load_pretrained_encoder_from", None) is not None:
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder,
                checkpoint=args.load_pretrained_encoder_from)

        return encoder
Example #10
    def build_encoder(cls, args, dictionary):
        text_encoder = cls.build_text_encoder(args, dictionary)
        if getattr(args, "load_pretrained_mbart_encoder_from", None):
            text_encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=text_encoder,
                checkpoint=args.load_pretrained_mbart_encoder_from,
            )
        speech_encoder = cls.build_speech_encoder(args)
        if getattr(args, "load_pretrained_feature_extractor_from", None):

            def load_feature_extractor(component, checkpoint):
                if not PathManager.exists(checkpoint):
                    raise IOError(
                        "Model file not found: {}".format(checkpoint))
                state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
                component_state_dict = OrderedDict()

                component_prefix = "feature_extractor"
                for key in state["model"].keys():
                    if key.startswith(component_prefix):
                        component_subkey = key[len(component_prefix) + 1:]
                        component_state_dict[component_subkey] = state[
                            "model"][key]
                component.load_state_dict(component_state_dict, strict=True)
                return component

            speech_encoder.subsample = load_feature_extractor(
                speech_encoder.subsample,
                args.load_pretrained_feature_extractor_from)
        speech_s2s_encoder = speech_encoder
        unsup_speech_encoder = cls.build_unsup_speech_encoder(
            args, speech_encoder)
        if getattr(args, "stacked_encoder", "none") != "none":
            if args.encoder_shared_text_layers_from_begin > 0:
                raise ValueError(
                    "We can not stack encoders and share encoders at the same time!"
                )
            speech_s2s_encoder = StackedSpeechWavTransformerEncoder(
                speech_encoder, text_encoder.layers, text_encoder.layer_norm)
            if args.stacked_encoder == "all":
                speech_encoder = speech_s2s_encoder
                unsup_speech_encoder = StackedSpeechWavTransformerEncoder(
                    unsup_speech_encoder, text_encoder.layers,
                    text_encoder.layer_norm)
        else:
            cls.share_speech_text_encoder(
                speech_encoder, text_encoder,
                args.encoder_shared_text_layers_from_begin)
        return SpeechTextPreTrainEncoder(
            dictionary,
            speech_encoder,
            speech_s2s_encoder,
            unsup_speech_encoder,
            text_encoder,
        )
Example #11
 def build_decoder(cls, cfg: Wav2Vec2Seq2SeqModConfig, tgt_dict, embed_tokens):
     decoder = TransformerDecoderMod(cfg, tgt_dict, embed_tokens)
     if getattr(cfg, "load_pretrained_decoder_from", None):
         decoder = checkpoint_utils.load_pretrained_component_from_model(
             component=decoder, checkpoint=cfg.load_pretrained_decoder_from
         )
         logger.info(
             f"loaded pretrained decoder from: "
             f"{cfg.load_pretrained_decoder_from}"
         )
     return decoder
Example #12
 def build_decoder(cls, args, text_dictionary, speech_dictionary,
                   speech_output_embedding):
     text_decoder = cls.build_text_decoder(args, text_dictionary)
     speech_decoder = cls.build_dummy_speech_decoder(
         args, speech_dictionary, speech_output_embedding)
     if getattr(args, "load_pretrained_mbart_decoder_from", None):
         text_decoder = checkpoint_utils.load_pretrained_component_from_model(
             component=text_decoder,
             checkpoint=args.load_pretrained_mbart_decoder_from,
         )
     return SpeechTextPreTrainDecoder(text_dictionary, speech_decoder,
                                      text_decoder)
Example #13
    def build_decoder(cls, args, task, embed_tokens):
        tgt_dict = task.tgt_dict

        from examples.simultaneous_translation.models.transformer_monotonic_attention import (
            TransformerMonotonicDecoder, )

        decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)

        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder,
                checkpoint=args.load_pretrained_decoder_from)
        return decoder
Example #14
 def build_encoder(cls, args):
     encoder = S2TTransformerEncoder(args)
     pretraining_path = getattr(args, "load_pretrained_encoder_from", None)
     if pretraining_path is not None:
         if not Path(pretraining_path).exists():
             logger.warning(
                 f"skipped pretraining because {pretraining_path} does not exist"
             )
         else:
             encoder = checkpoint_utils.load_pretrained_component_from_model(
                 component=encoder, checkpoint=pretraining_path)
             logger.info(
                 f"loaded pretrained encoder from: {pretraining_path}")
     return encoder
Example #15
File: berard.py Project: veralily/fairseq
 def build_encoder(cls, args, task):
     encoder = BerardEncoder(
         input_layers=literal_eval(args.input_layers),
         conv_layers=literal_eval(args.conv_layers),
         in_channels=args.input_channels,
         input_feat_per_channel=args.input_feat_per_channel,
         num_blstm_layers=args.num_blstm_layers,
         lstm_size=args.lstm_size,
         dropout=args.dropout,
     )
     if getattr(args, "load_pretrained_encoder_from", None):
         encoder = checkpoint_utils.load_pretrained_component_from_model(
             component=encoder,
             checkpoint=args.load_pretrained_encoder_from)
     return encoder
Example #16
File: berard.py Project: veralily/fairseq
 def build_decoder(cls, args, task):
     decoder = LSTMDecoder(
         dictionary=task.target_dictionary,
         embed_dim=args.decoder_embed_dim,
         num_layers=args.decoder_num_layers,
         hidden_size=args.decoder_hidden_dim,
         dropout=args.dropout,
         encoder_output_dim=2 * args.lstm_size,  # bidirectional
         attention_dim=args.attention_dim,
         output_layer_dim=args.output_layer_dim,
     )
     if getattr(args, "load_pretrained_decoder_from", None):
         decoder = checkpoint_utils.load_pretrained_component_from_model(
             component=decoder,
             checkpoint=args.load_pretrained_decoder_from)
     return decoder
Example #17
    def build_decoder(cls, args, task):
        _args = copy.deepcopy(args)
        _args.dropout = args.mbart_dropout
        _args.attention_dropout = args.mbart_attention_dropout
        _args.activation_dropout = args.mbart_activation_dropout
        _args.max_target_positions = 1024
        dec_emb = nn.Embedding(
            len(task.tgt_dict), _args.encoder_embed_dim, task.tgt_dict.pad()
        )
        decoder = TransformerDecoder(_args, task.tgt_dict, dec_emb)
        if getattr(args, "load_pretrained_mbart_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_mbart_from
            )
        if getattr(args, "no_final_norm_decoder", False):
            decoder.layer_norm = None
        for k, p in decoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_mbart_decoder_params"
            ) and need_finetuning(
                args.finetune_mbart_decoder_params, k
            ):
                p.requires_grad = True
            else:
                p.requires_grad = False

        compute_cross_attentive_loss = (
            getattr(args, "attentive_cost_regularization", 0.0) > 0.0
        )
        cross_attentive_loss_without_norm = getattr(
            args, "attentive_cost_without_normalize", False
        )
        cross_attentive_loss_reverse = (
            False  # getattr(args, "attentive_cost_reverse", False)
        )
        decoder = TransformerMultiInputDecoder(
            dictionary=task.target_dictionary,
            spch_decoder=decoder,
            text_decoder=decoder,
            compute_cross_attentive_loss=compute_cross_attentive_loss,
            cross_attentive_loss_with_norm=not cross_attentive_loss_without_norm,
            cross_attentive_loss_reverse=cross_attentive_loss_reverse,
        )
        return decoder
Example #18
    def build_encoder(cls, args):
        print(args)
        data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml)
        args.input_feat_per_channel = data_cfg.input_feat_per_channel
        args.input_channels = data_cfg.input_transformed_channels

        encoder = S2SConformerEncoder(args)
        pretraining_path = getattr(args, "load_pretrained_encoder_from", None)
        if pretraining_path is not None:
            if not Path(pretraining_path).exists():
                logger.warning(
                    f"skipped pretraining because {pretraining_path} does not exist"
                )
            else:
                encoder = checkpoint_utils.load_pretrained_component_from_model(
                    component=encoder, checkpoint=pretraining_path)
                logger.info(
                    f"loaded pretrained encoder from: {pretraining_path}")
        return encoder
Example #19
    def build_decoder(cls, args, task, embed_tokens):
        _args = copy.deepcopy(args)
        _args.dropout = args.decoder_dropout
        _args.attention_dropout = args.decoder_attention_dropout
        _args.activation_dropout = args.decoder_activation_dropout
        _args.max_target_positions = 1024

        decoder = TransformerDecoder(_args, task.target_dictionary,
                                     embed_tokens)
        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder,
                checkpoint=args.load_pretrained_decoder_from)
        for k, p in decoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(args, 'finetune_decoder_params'
                            ) and XMTransformerModel.finetune_params(
                                args.finetune_decoder_params, k):
                p.requires_grad = True
            else:
                p.requires_grad = False
        return decoder
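
Examples #17, #19, and #20 (below) share a freeze-by-default convention: every parameter of a pretrained sub-module is set to requires_grad = False unless its name passes a user-supplied finetuning filter. A hypothetical filter in that spirit (the real need_finetuning / XMTransformerModel.finetune_params helpers live in the fairseq examples and may match names differently):

    # Hypothetical name filter; the matching rule is an assumption.
    def allow_finetuning(finetune_params, param_name):
        if finetune_params == "all":
            return True
        # Treat the flag as comma-separated substrings selecting
        # parameter groups by name, e.g. "layers.11,layer_norm".
        return any(pat in param_name for pat in finetune_params.split(","))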
Example #20
    def build_encoder(cls, args, task):
        _args = copy.deepcopy(args)
        _args.dropout = args.mbart_dropout
        _args.attention_dropout = args.mbart_attention_dropout
        _args.activation_dropout = args.mbart_activation_dropout
        _args.max_source_positions = 1024
        enc_emb = nn.Embedding(
            len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
        )
        text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
        spch_encoder = Wav2VecEncoderWithAdaptor(args)
        if getattr(args, "load_pretrained_mbart_from", None):
            text_encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=text_encoder, checkpoint=args.load_pretrained_mbart_from
            )
        if getattr(args, "stack_w2v_mbart_encoder", False):
            assert getattr(args, "share_w2v_text_encoder", False) is False
            spch_encoder = StackedWav2VecEncoderWithAdaptor(
                spch_encoder.w2v_encoder,
                text_encoder.layers,
                text_encoder.layer_norm,
                spch_encoder.adaptor,
                args.drop_w2v_layers,
            )
        elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
            text_encoder.layer_norm = None
            spch_encoder = StackedWav2VecEncoderWithAdaptor(
                spch_encoder.w2v_encoder,
                text_encoder.layers,
                text_encoder.layer_norm,
                spch_encoder.adaptor,
                args.drop_w2v_layers,
            )
        elif getattr(args, "share_w2v_text_encoder", False):
            spch_encoder = SharedEncoder(
                spch_encoder.w2v_encoder,
                text_encoder,
                spch_encoder.adaptor,
                args.shared_w2v_layers,
            )

        for k, p in spch_encoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_w2v_params"
            ) and need_finetuning(args.finetune_w2v_params, k):
                p.requires_grad = True
            else:
                p.requires_grad = False
        for k, p in text_encoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(
                args, "finetune_mbart_encoder_params"
            ) and need_finetuning(
                args.finetune_mbart_encoder_params, k
            ):
                p.requires_grad = True
            else:
                p.requires_grad = False
        cross_attentive_loss_before_last_layer = (
            0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
        )
        encoder = DualInputEncoder(
            args,
            spch_encoder,
            text_encoder,
            task.src_dict,
            cross_attentive_loss_before_last_layer,
        )
        return encoder
Example #21
    def build_decoder(cls, args, task):
        dec_cfg = {
            "decoder_layerdrop": args.decoder_layerdrop,
            "share_decoder_input_output_embed":
            args.share_decoder_input_output_embed,
            "decoder_embed_dim": args.decoder_embed_dim,
            "max_target_positions": args.max_target_positions,
            "dropout": args.dropout,
            "encoder_learned_pos": args.encoder_learned_pos,
            "decoder_learned_pos": args.decoder_learned_pos,
            "layernorm_embedding": args.layernorm_embedding,
            "decoder_normalize_before": args.decoder_normalize_before,
            "activation_dropout": args.activation_dropout,
            "attention_dropout": args.attention_dropout,
            "decoder_ffn_embed_dim": args.decoder_ffn_embed_dim,
            "decoder_layers": args.decoder_layers,
            "decoder_attention_heads": args.decoder_attention_heads,
            "decoder_output_dim": args.decoder_embed_dim,
            "no_scale_embedding": args.no_scale_embedding,
            "adaptive_input": args.adaptive_input,
            "quant_noise_pq": args.quant_noise_pq,
            "adaptive_softmax_cutoff": args.adaptive_softmax_cutoff,
            "tie_adaptive_weights": args.tie_adaptive_weights,
            "no_token_positional_embeddings":
            args.no_token_positional_embeddings,
        }
        dec_cfg = namedtuple("args", dec_cfg.keys())(*dec_cfg.values())
        dec_emb = nn.Embedding(
            len(task.target_dictionary),
            args.decoder_embed_dim,
            task.target_dictionary.pad(),
        )
        compute_cross_attentive_loss = (
            getattr(args, "attentive_cost_regularization", 0.0) > 0.0)
        cross_attentive_loss_without_norm = getattr(
            args, "attentive_cost_without_normalize", False)
        cross_attentive_loss_reverse = (
            False  # getattr(args, "attentive_cost_reverse", False)
        )

        text_decoder = TransformerDecoder(dec_cfg, task.target_dictionary,
                                          dec_emb)
        spch_decoder = TransformerDecoder(dec_cfg, task.target_dictionary,
                                          dec_emb)
        spch_decoder = TransformerMultiInputDecoder.share_spchdecoder(
            args, text_decoder, spch_decoder)
        decoder = TransformerMultiInputDecoder(
            dictionary=task.target_dictionary,
            spch_decoder=spch_decoder,
            text_decoder=text_decoder,
            compute_cross_attentive_loss=compute_cross_attentive_loss,
            cross_attentive_loss_with_norm=not cross_attentive_loss_without_norm,
            cross_attentive_loss_reverse=cross_attentive_loss_reverse,
        )
        if args.init_scale != 1.0:
            with torch.no_grad():
                for param in decoder.parameters():
                    param.data.mul_(args.init_scale)
        if args.load_pretrain_decoder != "":
            try:
                checkpoint_utils.load_pretrained_component_from_model(
                    decoder, args.load_pretrain_decoder)
            except RuntimeError:
                checkpoint_utils.load_pretrained_component_from_model(
                    decoder.text_decoder, args.load_pretrain_decoder)
                if args.decoder_shared_layer_level > 0:
                    checkpoint_utils.load_pretrained_component_from_model(
                        decoder.spch_decoder, args.load_pretrain_decoder)

        return decoder
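
Note the fallback at the end of Example #21: loading the checkpoint into the combined TransformerMultiInputDecoder is attempted first, and if the strict load fails with a RuntimeError (key mismatch), the same checkpoint is re-applied to the inner text_decoder and, when decoder layers are shared, to spch_decoder as well. This lets a plain text-decoder checkpoint initialize the multi-input model without a separate conversion step.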