Beispiel #1
0
    def build_model(cls, args, task):
        """Construct a multilingual transformer for ``task``.

        One encoder is created per source language and one decoder per target
        language appearing in ``task.model_lang_pairs``. Depending on the
        ``share_*`` flags on ``args``, whole modules and/or their token
        embeddings are shared across languages.

        Returns:
            MultilingualTransformerModel: built from two ``OrderedDict``s that
            map each language pair to its encoder and decoder.

        Raises:
            ValueError: if ``--share-all-embeddings`` is combined with
                mismatching embedding dims or a distinct decoder embed path.
        """
        from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
        assert isinstance(task, MultilingualTranslationTask)

        # Checkpoints from older versions may lack newer arguments.
        base_multilingual_architecture(args)

        for positions_attr in ('max_source_positions', 'max_target_positions'):
            if not hasattr(args, positions_attr):
                setattr(args, positions_attr, 1024)

        # Every model language pair is written as "src-tgt".
        split_pairs = [pair.split('-') for pair in task.model_lang_pairs]
        src_langs = [parts[0] for parts in split_pairs]
        tgt_langs = [parts[1] for parts in split_pairs]

        # Sharing a whole encoder/decoder implies sharing its embeddings.
        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            """Create an embedding table sized for ``dictionary``; when
            ``path`` is given, pretrained vectors are loaded from it."""
            table = Embedding(len(dictionary), embed_dim, dictionary.pad())
            if path:
                utils.load_embedding(
                    utils.parse_embedding(path), dictionary, table)
            return table

        shared_encoder_embed_tokens = None
        shared_decoder_embed_tokens = None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim'
                )
            if (args.decoder_embed_path
                    and args.decoder_embed_path != args.encoder_embed_path):
                raise ValueError(
                    '--share-all-embeddings not compatible with --decoder-embed-path'
                )
            # A single table over all languages serves both sides.
            shared_encoder_embed_tokens = shared_decoder_embed_tokens = (
                FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=task.langs,
                    embed_dim=args.encoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.encoder_embed_path,
                ))
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                shared_encoder_embed_tokens = (
                    FairseqMultiModel.build_shared_embeddings(
                        dicts=task.dicts,
                        langs=src_langs,
                        embed_dim=args.encoder_embed_dim,
                        build_embedding=build_embedding,
                        pretrained_embed_path=args.encoder_embed_path,
                    ))
            if args.share_decoder_embeddings:
                shared_decoder_embed_tokens = (
                    FairseqMultiModel.build_shared_embeddings(
                        dicts=task.dicts,
                        langs=tgt_langs,
                        embed_dim=args.decoder_embed_dim,
                        build_embedding=build_embedding,
                        pretrained_embed_path=args.decoder_embed_path,
                    ))

        # Per-language module caches.
        lang_encoders, lang_decoders = {}, {}

        def get_encoder(lang):
            """Return the cached encoder for ``lang``, building it on first use."""
            if lang in lang_encoders:
                return lang_encoders[lang]
            embed_tokens = shared_encoder_embed_tokens
            if embed_tokens is None:
                embed_tokens = build_embedding(
                    task.dicts[lang], args.encoder_embed_dim,
                    args.encoder_embed_path)
            encoder = cls._get_module_class(
                True, args, task.dicts[lang], embed_tokens, src_langs)
            lang_encoders[lang] = encoder
            return encoder

        def get_decoder(lang):
            """Return the cached decoder for ``lang``, building it on first use."""
            if lang in lang_decoders:
                return lang_decoders[lang]
            embed_tokens = shared_decoder_embed_tokens
            if embed_tokens is None:
                embed_tokens = build_embedding(
                    task.dicts[lang], args.decoder_embed_dim,
                    args.decoder_embed_path)
            decoder = cls._get_module_class(
                False, args, task.dicts[lang], embed_tokens, tgt_langs)
            lang_decoders[lang] = decoder
            return decoder

        # When sharing, the module built for the first pair serves all pairs.
        shared_encoder = get_encoder(src_langs[0]) if args.share_encoders else None
        shared_decoder = get_decoder(tgt_langs[0]) if args.share_decoders else None

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs,
                                       tgt_langs):
            encoders[lang_pair] = (
                get_encoder(src) if shared_encoder is None else shared_encoder)
            decoders[lang_pair] = (
                get_decoder(tgt) if shared_decoder is None else shared_decoder)

        return MultilingualTransformerModel(encoders, decoders)
Beispiel #2
0
    def build_model(cls, args, task):
        """Build a multilingual convolutional transformer for ``task``.

        Encoders are cached per source language and decoders per target
        language of ``task.model_lang_pairs``. ``--share-encoders`` /
        ``--share-decoders`` replace them with a single module built for the
        first pair. Encoders are optionally initialised from
        ``args.pretrained_encoder`` and decoders from
        ``args.pretrained_decoder``.

        Returns:
            MultilingualConvolutionalTransformerModel: encoders/decoders keyed
            by language pair.

        Raises:
            ValueError: if the pretrained decoder checkpoint contains
                ``decoder.*`` parameters that the new decoder does not have.
        """
        from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
        assert isinstance(task, MultilingualTranslationTask)

        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 100000
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 100000

        src_langs = [
            lang_pair.split('-')[0] for lang_pair in task.model_lang_pairs
        ]
        tgt_langs = [
            lang_pair.split('-')[1] for lang_pair in task.model_lang_pairs
        ]

        # A shared decoder implies shared decoder embeddings.
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            """Create an embedding for ``dictionary``, optionally loading
            pretrained vectors from ``path``."""
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        # build shared embeddings (if applicable)
        shared_decoder_embed_tokens = None
        if args.share_decoder_embeddings:
            shared_decoder_embed_tokens = (
                FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=tgt_langs,
                    embed_dim=args.decoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.decoder_embed_path,
                ))

        # encoders/decoders for each language
        lang_encoders, lang_decoders = {}, {}

        def get_encoder(src_lang, tgt_lang):
            """Return the encoder cached for ``src_lang``, building it on first use.

            NOTE(review): the encoder is constructed with the dictionary of
            ``tgt_lang`` from the first pair that requests ``src_lang``. It is
            then reused for every other pair with that source. Confirm this is
            intended when one source language maps to several targets.
            """
            if src_lang not in lang_encoders:
                lang_encoders[src_lang] = TokenWiseConvolutionalTransformerEncoder(
                    args,
                    task.dicts[tgt_lang],
                    audio_features=args.input_feat_per_channel,
                    langs=task.langs)
                if args.pretrained_encoder is not None:
                    checkpoint_utils.load_pretrained_component_from_model(
                        lang_encoders[src_lang], args.pretrained_encoder,
                        args.allow_partial_restore)
            return lang_encoders[src_lang]

        def resize_model_to_new_dict(weights_tensor):
            """Return a copy of the 2-D ``weights_tensor`` with
            ``len(task.langs)`` extra Xavier-initialised rows appended; the
            original rows are copied over unchanged."""
            old_shape = weights_tensor.shape
            new_tensor = weights_tensor.new_empty(
                (old_shape[0] + len(task.langs), old_shape[1]))
            nn.init.xavier_uniform_(
                new_tensor, gain=nn.init.calculate_gain('relu'))
            new_tensor[:old_shape[0], :] = weights_tensor
            return new_tensor

        def load_pretrained_decoder(decoder):
            """Load the ``decoder.*`` parameters of ``args.pretrained_decoder``
            into ``decoder`` in place."""
            model_state = load_checkpoint_to_cpu(args.pretrained_decoder)["model"]
            if args.encoder_langtok is not None or args.decoder_langtok:
                # Language tokens enlarge the vocabulary by len(task.langs)
                # (presumably one token per language). The embedding and
                # output-projection matrices therefore have to grow to match.
                sum_merge = (args.decoder_langtok
                             and args.langtok_merge_strategy == "sum")
                # With the "sum" strategy the token table is stored under
                # embed_tokens.base_embeddings, so the resized table goes
                # there and the old key is dropped afterwards.
                embed_tokens_key = (
                    "decoder.embed_tokens.base_embeddings.weight"
                    if sum_merge else "decoder.embed_tokens.weight")
                model_state[embed_tokens_key] = resize_model_to_new_dict(
                    model_state["decoder.embed_tokens.weight"])
                model_state["decoder.output_projection.weight"] = (
                    resize_model_to_new_dict(
                        model_state["decoder.output_projection.weight"]))
                if sum_merge:
                    del model_state["decoder.embed_tokens.weight"]

            # Keep only decoder parameters and strip the "decoder." prefix,
            # e.g. decoder.input_layers.0.0.weight -> input_layers.0.0.weight.
            # Matching on the full prefix (dot included) avoids mangling
            # unrelated keys that merely start with "decoder".
            prefix = "decoder."
            component_state = OrderedDict(
                (key[len(prefix):], value)
                for key, value in model_state.items()
                if key.startswith(prefix))
            incompatible_keys = decoder.load_state_dict(
                component_state, strict=(not args.allow_partial_restore))
            if incompatible_keys.unexpected_keys:
                logger.error(
                    "Cannot load the following keys from checkpoint: {}"
                    .format(incompatible_keys.unexpected_keys))
                raise ValueError("Cannot load from checkpoint: {}".format(
                    args.pretrained_decoder))
            if incompatible_keys.missing_keys:
                logger.info(
                    "Loaded checkpoint misses the parameters: {}".format(
                        incompatible_keys.missing_keys))

        def get_decoder(lang):
            """Return the decoder cached for target language ``lang``,
            building (and optionally pretrained-initialising) it on first use."""
            if lang not in lang_decoders:
                if shared_decoder_embed_tokens is not None:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                else:
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang], args.decoder_embed_dim,
                        args.decoder_embed_path)
                # NOTE(review): the last argument is the task-wide
                # --target-lang rather than ``lang``; confirm this is intended.
                lang_decoders[lang] = TokenWiseTransformerDecoder(
                    args, task.dicts[lang], decoder_embed_tokens,
                    task.args.target_lang)
                if args.pretrained_decoder is not None:
                    load_pretrained_decoder(lang_decoders[lang])
            return lang_decoders[lang]

        # shared encoders/decoders (if applicable)
        shared_encoder, shared_decoder = None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0], tgt_langs[0])
        if args.share_decoders:
            shared_decoder = get_decoder(tgt_langs[0])

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs,
                                       tgt_langs):
            encoders[lang_pair] = (
                shared_encoder if shared_encoder is not None
                else get_encoder(src, tgt))
            decoders[lang_pair] = (
                shared_decoder if shared_decoder is not None
                else get_decoder(tgt))

        return MultilingualConvolutionalTransformerModel(encoders, decoders)
Beispiel #3
0
    def build_model(self, args, task, cls_dictionary=MaskedLMDictionary):
        """Construct a multilingual summarization ``Generator`` for ``task``.

        Encoders/decoders are ``TransformerEncoderFromPretrainedModel`` /
        ``TransformerDecoderFromPretrainedModel`` instances. These presumably
        read ``args.pretrained_checkpoint``; confirm in their definitions.

        Args:
            args: model arguments; must define ``pretrained_checkpoint``, and
                ``--init-encoder-only`` / ``--init-decoder-only`` are mutually
                exclusive.
            task: a ``MultilingualSummarization`` whose source and target
                dictionaries are instances of ``cls_dictionary``.
            cls_dictionary: class both task dictionaries must be.

        Returns:
            Generator: built from two ``OrderedDict``s mapping each pair in
            ``task.model_lang_pairs`` to its encoder and decoder.

        Raises:
            ValueError: if ``--share-all-embeddings`` is combined with
                mismatching embedding dims or a distinct decoder embed path.
        """
        from fairseq.tasks.multilingual_summarization import MultilingualSummarization
        assert isinstance(task, MultilingualSummarization)
        assert hasattr(args, "pretrained_checkpoint"), (
            "You must specify a path for --pretrained-checkpoint to use ")
        assert (isinstance(task.source_dictionary, cls_dictionary)
                and isinstance(task.target_dictionary, cls_dictionary)), (
            "You should use a MaskedLMDictionary when using --arch. ")
        assert not (getattr(args, "init_encoder_only", False)
                    and getattr(args, "init_decoder_only", False)), (
            "Only one of --init-encoder-only and --init-decoder-only can be set.")

        # Checkpoints from older versions may lack newer arguments.
        multilingual_base_architecture(args)

        # Positions default to the (pre)training sample length.
        for positions_attr in ('max_source_positions', 'max_target_positions'):
            if not hasattr(args, positions_attr):
                setattr(args, positions_attr, args.tokens_per_sample)

        # Every model language pair is written as "src-tgt".
        split_pairs = [pair.split('-') for pair in task.model_lang_pairs]
        src_langs = [parts[0] for parts in split_pairs]
        tgt_langs = [parts[1] for parts in split_pairs]

        # Sharing a whole encoder/decoder implies sharing its embeddings.
        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            """Create an embedding table sized for ``dictionary``; when
            ``path`` is given, pretrained vectors are loaded from it."""
            table = Embedding(len(dictionary), embed_dim, dictionary.pad())
            if path:
                utils.load_embedding(
                    utils.parse_embedding(path), dictionary, table)
            return table

        shared_encoder_embed_tokens = None
        shared_decoder_embed_tokens = None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim'
                )
            if (args.decoder_embed_path
                    and args.decoder_embed_path != args.encoder_embed_path):
                raise ValueError(
                    '--share-all-embeddings not compatible with --decoder-embed-path'
                )
            # A single table over all languages serves both sides.
            shared_encoder_embed_tokens = shared_decoder_embed_tokens = (
                FairseqMultiModel.build_shared_embeddings(
                    dicts=task.dicts,
                    langs=task.langs,
                    embed_dim=args.encoder_embed_dim,
                    build_embedding=build_embedding,
                    pretrained_embed_path=args.encoder_embed_path,
                ))
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                shared_encoder_embed_tokens = (
                    FairseqMultiModel.build_shared_embeddings(
                        dicts=task.dicts,
                        langs=src_langs,
                        embed_dim=args.encoder_embed_dim,
                        build_embedding=build_embedding,
                        pretrained_embed_path=args.encoder_embed_path,
                    ))
            if args.share_decoder_embeddings:
                shared_decoder_embed_tokens = (
                    FairseqMultiModel.build_shared_embeddings(
                        dicts=task.dicts,
                        langs=tgt_langs,
                        embed_dim=args.decoder_embed_dim,
                        build_embedding=build_embedding,
                        pretrained_embed_path=args.decoder_embed_path,
                    ))

        # Per-language module caches.
        lang_encoders, lang_decoders = {}, {}

        def get_encoder(lang):
            """Return the cached encoder for ``lang``, building it on first use."""
            if lang in lang_encoders:
                return lang_encoders[lang]
            embed_tokens = shared_encoder_embed_tokens
            if embed_tokens is None:
                embed_tokens = build_embedding(
                    task.dicts[lang], args.encoder_embed_dim,
                    args.encoder_embed_path)
            encoder = TransformerEncoderFromPretrainedModel(
                args, task.dicts[lang], embed_tokens)
            lang_encoders[lang] = encoder
            return encoder

        def get_decoder(lang):
            """Return the cached decoder for ``lang``, building it on first use."""
            if lang in lang_decoders:
                return lang_decoders[lang]
            embed_tokens = shared_decoder_embed_tokens
            if embed_tokens is None:
                embed_tokens = build_embedding(
                    task.dicts[lang], args.decoder_embed_dim,
                    args.decoder_embed_path)
            decoder = TransformerDecoderFromPretrainedModel(
                args, task.dicts[lang], embed_tokens)
            lang_decoders[lang] = decoder
            return decoder

        # When sharing, the module built for the first pair serves all pairs.
        shared_encoder = get_encoder(src_langs[0]) if args.share_encoders else None
        shared_decoder = get_decoder(tgt_langs[0]) if args.share_decoders else None

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs,
                                       tgt_langs):
            encoders[lang_pair] = (
                get_encoder(src) if shared_encoder is None else shared_encoder)
            decoders[lang_pair] = (
                get_decoder(tgt) if shared_decoder is None else shared_decoder)

        return Generator(encoders, decoders)
    def build_model(cls, args, task):
        """Construct a multilingual model for ``task``.

        One encoder per source language and one decoder per target language
        of ``task.lang_pairs`` are created lazily and cached. With
        ``--share-encoders`` / ``--share-decoders``, a single module (built
        for the first pair's language) is reused for every pair instead.

        Returns:
            ``cls(task, encoders, decoders)``, where both ``OrderedDict``s are
            keyed by language pair.
        """
        for positions_attr in ("max_source_positions", "max_target_positions"):
            if not hasattr(args, positions_attr):
                setattr(args, positions_attr, 1024)

        # Language pairs are written as "src-tgt".
        split_pairs = [lang_pair.split("-") for lang_pair in task.lang_pairs]
        src_langs = [parts[0] for parts in split_pairs]
        tgt_langs = [parts[1] for parts in split_pairs]

        # Sharing a whole encoder/decoder implies sharing its embeddings.
        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        # Per-language module caches.
        lang_encoders, lang_decoders = {}, {}

        def get_encoder(lang, shared_encoder_embed_tokens=None):
            """Return the cached encoder for ``lang``, building it on first use.

            Unless ``shared_encoder_embed_tokens`` is supplied, a fresh
            embedding is created and ``args.encoder_pretrained_embed`` is
            passed to ``utils.load_embedding`` for it.
            """
            if lang in lang_encoders:
                return lang_encoders[lang]
            src_dict = task.dicts[lang]
            if shared_encoder_embed_tokens is not None:
                embed_tokens = shared_encoder_embed_tokens
            else:
                embed_tokens = common_layers.Embedding(
                    num_embeddings=len(src_dict),
                    embedding_dim=args.encoder_embed_dim,
                    padding_idx=src_dict.pad(),
                    freeze_embed=args.encoder_freeze_embed,
                    normalize_embed=getattr(args, "encoder_normalize_embed",
                                            False),
                )
                utils.load_embedding(
                    embedding=embed_tokens,
                    dictionary=src_dict,
                    pretrained_embed=args.encoder_pretrained_embed,
                )
            encoder = cls.single_model_cls.build_encoder(
                args, src_dict, embed_tokens=embed_tokens)
            lang_encoders[lang] = encoder
            return encoder

        def get_decoder(lang, shared_decoder_embed_tokens=None):
            """Return the cached decoder for target language ``lang``,
            building it on first use.

            Unless ``shared_decoder_embed_tokens`` is supplied, a fresh
            embedding is created and ``args.decoder_pretrained_embed`` is
            passed to ``utils.load_embedding`` for it.
            """
            if lang in lang_decoders:
                return lang_decoders[lang]
            tgt_dict = task.dicts[lang]
            if shared_decoder_embed_tokens is not None:
                embed_tokens = shared_decoder_embed_tokens
            else:
                embed_tokens = common_layers.Embedding(
                    num_embeddings=len(tgt_dict),
                    embedding_dim=args.decoder_embed_dim,
                    padding_idx=tgt_dict.pad(),
                    freeze_embed=args.decoder_freeze_embed,
                )
                utils.load_embedding(
                    embedding=embed_tokens,
                    dictionary=tgt_dict,
                    pretrained_embed=args.decoder_pretrained_embed,
                )
            # NOTE(review): decoders are keyed by target language only, so the
            # dictionary passed in the second position is the target one as
            # well; confirm against build_decoder's signature.
            decoder = cls.single_model_cls.build_decoder(
                args,
                task.dicts[lang],
                tgt_dict,
                embed_tokens=embed_tokens)
            lang_decoders[lang] = decoder
            return decoder

        # Shared modules are created BEFORE the shared embedding tables below.
        # As a result, a shared encoder/decoder builds its own embedding, and
        # the tables are only consumed on the non-shared path in the loop.
        shared_encoder = get_encoder(src_langs[0]) if args.share_encoders else None
        shared_decoder = get_decoder(tgt_langs[0]) if args.share_decoders else None

        shared_encoder_embed_tokens = None
        shared_decoder_embed_tokens = None
        if args.share_encoder_embeddings:
            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=src_langs,
                embed_dim=args.encoder_embed_dim,
                build_embedding=common_layers.build_embedding,
                pretrained_embed_path=None,
            )
        if args.share_decoder_embeddings:
            shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=tgt_langs,
                embed_dim=args.decoder_embed_dim,
                build_embedding=common_layers.build_embedding,
                pretrained_embed_path=None,
            )

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src_lang, tgt_lang in zip(task.lang_pairs, src_langs,
                                                 tgt_langs):
            if shared_encoder is None:
                encoders[lang_pair] = get_encoder(
                    src_lang,
                    shared_encoder_embed_tokens=shared_encoder_embed_tokens)
            else:
                encoders[lang_pair] = shared_encoder
            if shared_decoder is None:
                decoders[lang_pair] = get_decoder(
                    tgt_lang,
                    shared_decoder_embed_tokens=shared_decoder_embed_tokens)
            else:
                decoders[lang_pair] = shared_decoder

        return cls(task, encoders, decoders)
Beispiel #5
0
    def build_model(cls, args, task):
        """Build a multilingual model whose modules are keyed by language.

        Monolingual corpora are tagged by appending
        ``_<constants.MONOLINGUAL_DATA_IDENTIFIER>`` to a language name. Both
        "lang" and the suffixed name resolve to the same encoder/decoder. One
        encoder per source language and one decoder per target language are
        created lazily and cached, unless ``--share-encoders`` /
        ``--share-decoders`` request a single module built for the first
        language pair.

        Args:
            args: model arguments. ``max_source_positions`` and
                ``max_target_positions`` default to 1024 when absent;
                ``share_encoders`` / ``share_decoders`` force the matching
                ``share_*_embeddings`` flag on.
            task: provides ``lang_pairs`` ("src-tgt" strings) and ``dicts``,
                looked up here by suffix-stripped language name.

        Returns:
            ``cls(task, encoders, decoders)``, where both ``OrderedDict``s are
            keyed by language pair.
        """
        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        src_langs = [lang_pair.split("-")[0] for lang_pair in task.lang_pairs]
        tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.lang_pairs]

        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        # encoders/decoders for each language
        lang_encoders, lang_decoders = {}, {}

        def strip_suffix(lang):
            """
            Both "lang" and "lang_mono" languages share the same encoder/decoder
            since they belong to the same language but use bilingual and monolingual
            corpora respectively to train.
            So use "lang" as model key for both "lang" and "lang_mono" by stripping
            the suffix "_mono" when the name ends with it (only a trailing suffix
            is removed).
            """
            suffix = f"_{constants.MONOLINGUAL_DATA_IDENTIFIER}"
            if lang.endswith(suffix):
                lang = lang[:-len(suffix)]
            return lang

        def get_encoder(lang):
            """Return the cached encoder for ``lang`` (suffix stripped),
            building it and its embedding on first use."""
            lang = strip_suffix(lang)
            if lang not in lang_encoders:
                src_dict = task.dicts[lang]
                encoder_embed_tokens = common_layers.Embedding(
                    num_embeddings=len(src_dict),
                    embedding_dim=args.encoder_embed_dim,
                    padding_idx=src_dict.pad(),
                    freeze_embed=args.encoder_freeze_embed,
                    normalize_embed=getattr(args, "encoder_normalize_embed",
                                            False),
                )
                utils.load_embedding(
                    embedding=encoder_embed_tokens,
                    dictionary=src_dict,
                    pretrained_embed=args.encoder_pretrained_embed,
                )
                lang_encoders[lang] = cls.single_model_cls.build_encoder(
                    args, src_dict, embed_tokens=encoder_embed_tokens)
            return lang_encoders[lang]

        def get_decoder(lang_pair, shared_decoder_embed_tokens=None):
            """
            Fetch decoder for the input `lang_pair` (a full "src-tgt" string).
            Decoders are cached by the suffix-stripped target language of the
            pair; the source language selects the source dictionary and
            controls the vocabulary-reduction hack below.
            """
            # NOTE(review): with --share-decoders this forces the flag to
            # False, which (per the condition below) lets vocab reduction be
            # dropped when source == target. Confirm this is the intent.
            if args.share_decoders:
                args.remove_vr_if_same_lang_at_enc_and_dec = False
            source_lang, target_lang = (strip_suffix(lang)
                                        for lang in lang_pair.split("-"))
            if target_lang not in lang_decoders:
                # hack to prevent VR for denoising autoencoder. We remove vocab
                # reduction params if we have lang-lang_any_suffix
                args_maybe_modified = copy.deepcopy(args)
                if (source_lang == target_lang
                        and not args.remove_vr_if_same_lang_at_enc_and_dec):
                    args_maybe_modified.vocab_reduction_params = None
                tgt_dict = task.dicts[target_lang]
                if shared_decoder_embed_tokens is None:
                    decoder_embed_tokens = common_layers.Embedding(
                        num_embeddings=len(tgt_dict),
                        embedding_dim=args.decoder_embed_dim,
                        padding_idx=tgt_dict.pad(),
                        freeze_embed=args.decoder_freeze_embed,
                    )

                    utils.load_embedding(
                        embedding=decoder_embed_tokens,
                        dictionary=tgt_dict,
                        pretrained_embed=args.decoder_pretrained_embed,
                    )
                else:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                lang_decoders[
                    target_lang] = cls.single_model_cls.build_decoder(
                        args_maybe_modified,
                        task.dicts[source_lang],
                        tgt_dict,
                        embed_tokens=decoder_embed_tokens,
                    )
            return lang_decoders[target_lang]

        # shared encoders/decoders (if applicable)
        shared_encoder, shared_decoder, shared_decoder_embed_tokens = None, None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0])
        if args.share_decoders:
            # get_decoder parses a full "src-tgt" pair; passing a bare target
            # language would fail to unpack into (source, target).
            shared_decoder = get_decoder(task.lang_pairs[0])

        # Built after the shared decoder, so a shared decoder keeps its own
        # embedding; these tables are used only on the non-shared path.
        if args.share_decoder_embeddings:
            shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=[strip_suffix(tgt_lang) for tgt_lang in tgt_langs],
                embed_dim=args.decoder_embed_dim,
                build_embedding=common_layers.build_embedding,
                pretrained_embed_path=None,
            )
        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src in zip(task.lang_pairs, src_langs):
            encoders[lang_pair] = (shared_encoder if shared_encoder is not None
                                   else get_encoder(src))
            decoders[lang_pair] = (
                shared_decoder if shared_decoder is not None else get_decoder(
                    lang_pair,
                    shared_decoder_embed_tokens=shared_decoder_embed_tokens))

        return cls(task, encoders, decoders)
Beispiel #6
0
    def build_model(cls, args, task):
        """Build a new model instance."""
        assert isinstance(task, MultilingualTranslationTask)

        # make sure all arguments are present in older models
        base_multilingual_architecture(args)

        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024



        src_langs = [lang_pair.split('-')[0] for lang_pair in task.model_lang_pairs]
        tgt_langs = [lang_pair.split('-')[1] for lang_pair in task.model_lang_pairs]

        lang2idx = task.lang2idx
        lang2idx2idx = [-1]*(max(task.lang2idx.values())+1)


        # import pdb; pdb.set_trace()
        for idx, l in enumerate(lang2idx.keys()):
            lang2idx2idx[lang2idx[l]] = idx
        # define semantic and syntactic matrices
        no_langs = len([i for i in lang2idx2idx if i>-1])

        M = nn.Parameter(torch.randn(args.encoder_embed_dim, args.encoder_embed_dim//2))
        N = nn.Parameter(torch.randn(no_langs, args.encoder_embed_dim, args.encoder_embed_dim//2))


        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        # build shared embeddings (if applicable)
        shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
            if args.decoder_embed_path and (
                    args.decoder_embed_path != args.encoder_embed_path):
                raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
            shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=task.langs,
                embed_dim=args.encoder_embed_dim,
                build_embedding=build_embedding,
                pretrained_embed_path=args.encoder_embed_path,
            )
            shared_decoder_embed_tokens = shared_encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            if args.share_encoder_embeddings:
                shared_encoder_embed_tokens = (
                    FairseqMultiModel.build_shared_embeddings(
                        dicts=task.dicts,
                        langs=src_langs,
                        embed_dim=args.encoder_embed_dim,
                        build_embedding=build_embedding,
                        pretrained_embed_path=args.encoder_embed_path,
                    )
                )
            if args.share_decoder_embeddings:
                shared_decoder_embed_tokens = (
                    FairseqMultiModel.build_shared_embeddings(
                        dicts=task.dicts,
                        langs=tgt_langs,
                        embed_dim=args.decoder_embed_dim,
                        build_embedding=build_embedding,
                        pretrained_embed_path=args.decoder_embed_path,
                    )
                )

        # encoders/decoders for each language
        lang_encoders, lang_decoders = {}, {}


        def get_encoder(lang):
            if lang not in lang_encoders:
                if shared_encoder_embed_tokens is not None:
                    encoder_embed_tokens = shared_encoder_embed_tokens
                else:
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path
                    )
                lang_encoders[lang] = TransformerEncoder(args, task.dicts[lang], encoder_embed_tokens, lang2idx2idx, M, N)
            return lang_encoders[lang]

        def get_decoder(lang):
            if lang not in lang_decoders:
                if shared_decoder_embed_tokens is not None:
                    decoder_embed_tokens = shared_decoder_embed_tokens
                else:
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang], args.decoder_embed_dim, args.decoder_embed_path
                    )
                lang_decoders[lang] = TransformerDecoder(args, task.dicts[lang], decoder_embed_tokens, lang2idx2idx, M, N)
            return lang_decoders[lang]

        # shared encoders/decoders (if applicable)
        shared_encoder, shared_decoder = None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0])
        if args.share_decoders:
            shared_decoder = get_decoder(tgt_langs[0])

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
            encoders[lang_pair] = shared_encoder if shared_encoder is not None else get_encoder(src)
            decoders[lang_pair] = shared_decoder if shared_decoder is not None else get_decoder(tgt)

        return MultilingualTransformerModel(encoders, decoders)
Beispiel #7
0
    def build_model(cls, args, task):
        """Build a new SemiSupervisedModel instance.

        Builds one RNN encoder per source language and one RNN decoder per
        target language appearing in ``task.lang_pairs``. A language and its
        monolingual variant (``<lang>_<MONOLINGUAL_DATA_IDENTIFIER>``) map to
        the same encoder/decoder.

        Args:
            args: model arguments. Mutated in place: missing
                ``max_source_positions`` / ``max_target_positions`` default to
                1024, and ``share_encoders`` / ``share_decoders`` force the
                matching ``share_*_embeddings`` flag on.
            task: a ``PytorchTranslateSemiSupervised`` task providing
                ``lang_pairs`` (``"src-tgt"`` strings) and ``dicts``.

        Returns:
            SemiSupervisedModel wrapping ``task`` and one encoder/decoder per
            entry of ``task.lang_pairs``.
        """
        assert isinstance(task, PytorchTranslateSemiSupervised)

        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        src_langs = [lang_pair.split("-")[0] for lang_pair in task.lang_pairs]
        tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.lang_pairs]

        if args.share_encoders:
            args.share_encoder_embeddings = True
        if args.share_decoders:
            args.share_decoder_embeddings = True

        # encoders/decoders for each language (keyed by suffix-stripped lang)
        lang_encoders, lang_decoders = {}, {}
        mono_suffix = f"_{constants.MONOLINGUAL_DATA_IDENTIFIER}"

        def strip_suffix(lang):
            """
            Both "lang" and "lang_mono" languages share the same encoder/decoder
            since they belong to the same language but use bilingual and
            monolingual corpora respectively to train.
            So use "lang" as model key for both "lang" and "lang_mono" by
            stripping the monolingual suffix when it terminates the name.
            """
            # Only strip a true suffix; a substring check followed by slicing
            # off the tail would mangle names containing the marker elsewhere.
            if lang.endswith(mono_suffix):
                lang = lang[:-len(mono_suffix)]
            return lang

        # TODO(T35638969): Generalize this to be able to use other model
        # classes like Transformer. TransformerModel does not currently have
        # build_encoder and build_decoder methods.

        def get_encoder(lang):
            """Return (building on first use) the encoder for ``lang``."""
            lang = strip_suffix(lang)
            if lang not in lang_encoders:
                lang_encoders[lang] = RNNModel.build_encoder(
                    args, task.dicts[lang])
            return lang_encoders[lang]

        def get_decoder(lang_pair, shared_decoder_embed_tokens=None):
            """
            Fetch decoder for the input `lang_pair` (a "src-tgt" string), which
            denotes the target language of the model. The decoder is built on
            first use and cached by suffix-stripped target language.
            """
            source_lang, target_lang = (strip_suffix(lang)
                                        for lang in lang_pair.split("-"))
            if target_lang not in lang_decoders:
                # hack to prevent VR for denoising autoencoder. We remove vocab
                # reduction params if we have lang-lang_any_suffix
                args_maybe_modified = copy.deepcopy(args)
                if source_lang == target_lang:
                    args_maybe_modified.vocab_reduction_params = None

                lang_decoders[target_lang] = RNNModel.build_decoder(
                    args_maybe_modified,
                    task.dicts[source_lang],
                    task.dicts[target_lang],
                    embedding_module=shared_decoder_embed_tokens,
                )
            return lang_decoders[target_lang]

        # shared encoders/decoders (if applicable)
        shared_encoder, shared_decoder, shared_decoder_embed_tokens = None, None, None
        if args.share_encoders:
            shared_encoder = get_encoder(src_langs[0])

        # Build the shared decoder embeddings BEFORE any shared decoder, so a
        # shared decoder actually receives them (share_decoders implies
        # share_decoder_embeddings above).
        if args.share_decoder_embeddings:
            shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=[strip_suffix(tgt_lang) for tgt_lang in tgt_langs],
                embed_dim=args.decoder_embed_dim,
                build_embedding=common_layers.build_embedding,
                pretrained_embed_path=None,
            )

        if args.share_decoders:
            # get_decoder expects a full "src-tgt" pair, not a bare target
            # language; passing tgt_langs[0] would fail to unpack.
            shared_decoder = get_decoder(
                task.lang_pairs[0],
                shared_decoder_embed_tokens=shared_decoder_embed_tokens)

        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src in zip(task.lang_pairs, src_langs):
            encoders[lang_pair] = (shared_encoder if shared_encoder is not None
                                   else get_encoder(src))
            decoders[lang_pair] = (
                shared_decoder if shared_decoder is not None else get_decoder(
                    lang_pair,
                    shared_decoder_embed_tokens=shared_decoder_embed_tokens))

        return SemiSupervisedModel(task, encoders, decoders)
Beispiel #8
0
    def build_model(cls, cfg: MultilingualRNNModelConfig, task):
        """Build a new model instance.

        Creates one RNNEncoder per source language and one RNNDecoder per
        target language listed in ``task.model_lang_pairs``, optionally sharing
        embeddings and/or whole encoders/decoders as requested by ``cfg``.

        Args:
            cfg: model configuration. Mutated in place: missing
                ``max_source_positions`` / ``max_target_positions`` are set to
                the module defaults, sharing a whole encoder/decoder turns on
                the matching embedding-sharing flag, and
                ``share_all_embeddings`` turns on
                ``share_decoder_input_output_embed``.
            task: a ``MultilingualTranslationTask`` providing ``dicts``,
                ``langs`` and ``model_lang_pairs`` (``"src-tgt"`` strings).

        Returns:
            MultilingualRNNModel built from per-pair encoder/decoder mappings.

        Raises:
            ValueError: if ``share_all_embeddings`` is requested with differing
                encoder/decoder embedding dims, or with a decoder embedding
                path that differs from the encoder one.
        """
        from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
        assert isinstance(task, MultilingualTranslationTask)

        # Configs from older checkpoints may lack these fields.
        for field_name, fallback in (
            ('max_source_positions', DEFAULT_MAX_SOURCE_POSITIONS),
            ('max_target_positions', DEFAULT_MAX_TARGET_POSITIONS),
        ):
            if not hasattr(cfg, field_name):
                setattr(cfg, field_name, fallback)

        split_pairs = [pair.split('-') for pair in task.model_lang_pairs]
        src_langs = [parts[0] for parts in split_pairs]
        tgt_langs = [parts[1] for parts in split_pairs]

        # Sharing a whole module implies sharing its embedding table.
        if cfg.share_encoders:
            cfg.share_encoder_embeddings = True
        if cfg.share_decoders:
            cfg.share_decoder_embeddings = True

        def build_embedding(dictionary, embed_dim, path=None):
            """Create an Embedding sized to ``dictionary``; load weights from ``path`` if given."""
            table = Embedding(len(dictionary), embed_dim, dictionary.pad())
            if path:
                utils.load_embedding(utils.parse_embedding(path), dictionary, table)
            return table

        def shared_table(langs, embed_dim, pretrained_path):
            """Build one embedding table shared across ``langs``."""
            return FairseqMultiModel.build_shared_embeddings(
                dicts=task.dicts,
                langs=langs,
                embed_dim=embed_dim,
                build_embedding=build_embedding,
                pretrained_embed_path=pretrained_path,
            )

        # build shared embeddings (if applicable)
        shared_encoder_embed_tokens = shared_decoder_embed_tokens = None
        if cfg.share_all_embeddings:
            if cfg.encoder_embed_dim != cfg.decoder_embed_dim:
                raise ValueError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim'
                )
            if cfg.decoder_embed_path and cfg.decoder_embed_path != cfg.encoder_embed_path:
                raise ValueError(
                    '--share-all-embeddings not compatible with --decoder-embed-path'
                )
            shared_encoder_embed_tokens = shared_table(
                task.langs, cfg.encoder_embed_dim, cfg.encoder_embed_path)
            shared_decoder_embed_tokens = shared_encoder_embed_tokens
            cfg.share_decoder_input_output_embed = True
        else:
            if cfg.share_encoder_embeddings:
                shared_encoder_embed_tokens = shared_table(
                    src_langs, cfg.encoder_embed_dim, cfg.encoder_embed_path)
            if cfg.share_decoder_embeddings:
                shared_decoder_embed_tokens = shared_table(
                    tgt_langs, cfg.decoder_embed_dim, cfg.decoder_embed_path)

        def resolve_dropout(layer_specific):
            """A negative layer-specific dropout falls back to ``cfg.dropout``."""
            return layer_specific if layer_specific >= 0 else cfg.dropout

        # encoders/decoders for each language, built lazily and cached
        lang_encoders, lang_decoders = {}, {}

        def get_encoder(lang):
            """Return (building on first use) the encoder for ``lang``."""
            if lang in lang_encoders:
                return lang_encoders[lang]
            embed_tokens = shared_encoder_embed_tokens
            if embed_tokens is None:
                embed_tokens = build_embedding(
                    task.dicts[lang], cfg.encoder_embed_dim, cfg.encoder_embed_path)
            lang_encoders[lang] = RNNEncoder(
                dictionary=task.dicts[lang],
                embed_dim=cfg.encoder_embed_dim,
                hidden_size=cfg.encoder_hidden_size,
                num_layers=cfg.encoder_layers,
                dropout_in=resolve_dropout(cfg.encoder_dropout_in),
                dropout_out=resolve_dropout(cfg.encoder_dropout_out),
                bidirectional=cfg.encoder_bidirectional,
                pretrained_embed=embed_tokens,
                rnn_type=cfg.rnn_type,
                max_source_positions=cfg.max_source_positions,
            )
            return lang_encoders[lang]

        def get_decoder(lang):
            """Return (building on first use) the decoder for ``lang``."""
            if lang in lang_decoders:
                return lang_decoders[lang]
            embed_tokens = shared_decoder_embed_tokens
            if embed_tokens is None:
                embed_tokens = build_embedding(
                    task.dicts[lang], cfg.decoder_embed_dim, cfg.decoder_embed_path)
            # Adaptive softmax cutoffs only apply with the adaptive loss.
            softmax_cutoff = None
            if cfg.criterion == "adaptive_loss":
                softmax_cutoff = utils.eval_str_list(
                    cfg.adaptive_softmax_cutoff, type=int)
            lang_decoders[lang] = RNNDecoder(
                dictionary=task.dicts[lang],
                embed_dim=cfg.decoder_embed_dim,
                hidden_size=cfg.decoder_hidden_size,
                out_embed_dim=cfg.decoder_out_embed_dim,
                num_layers=cfg.decoder_layers,
                attention_type=cfg.attention_type,
                dropout_in=resolve_dropout(cfg.decoder_dropout_in),
                dropout_out=resolve_dropout(cfg.decoder_dropout_out),
                rnn_type=cfg.rnn_type,
                # NOTE(review): if RNNEncoder's output width grows when
                # bidirectional, this may need to be the encoder's actual
                # output size rather than encoder_hidden_size -- confirm.
                encoder_output_units=cfg.encoder_hidden_size,
                pretrained_embed=embed_tokens,
                share_input_output_embed=cfg.share_decoder_input_output_embed,
                adaptive_softmax_cutoff=softmax_cutoff,
                max_target_positions=cfg.max_target_positions,
                residuals=False,
            )
            return lang_decoders[lang]

        # shared encoders/decoders (if applicable)
        shared_encoder = get_encoder(src_langs[0]) if cfg.share_encoders else None
        shared_decoder = get_decoder(tgt_langs[0]) if cfg.share_decoders else None

        # Build per pair, interleaving encoder then decoder as before so that
        # module creation order is unchanged.
        encoders, decoders = OrderedDict(), OrderedDict()
        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
            encoders[lang_pair] = get_encoder(src) if shared_encoder is None else shared_encoder
            decoders[lang_pair] = get_decoder(tgt) if shared_decoder is None else shared_decoder

        return MultilingualRNNModel(encoders, decoders)