コード例 #1
0
def mass_base(args):
    """Fill in the MASS "base" hyper-parameter defaults, then apply ``base_architecture``.

    Each attribute is only defaulted when ``args`` does not already define it,
    so explicitly supplied values always win.
    """
    mass_defaults = (
        # encoder
        ('encoder_embed_dim', 768),
        ('encoder_ffn_embed_dim', 3072),
        ('encoder_attention_heads', 12),
        ('encoder_layers', 6),
        ('encoder_learned_pos', True),
        # regularisation / activation
        ('dropout', 0.1),
        ('attention_dropout', 0.1),
        ('activation_dropout', 0.1),
        ('activation_fn', 'gelu'),
        # decoder
        ('decoder_embed_dim', 768),
        ('decoder_ffn_embed_dim', 3072),
        ('decoder_attention_heads', 12),
        ('decoder_layers', 6),
        ('decoder_learned_pos', True),
        # embedding sharing / normalisation
        ('share_decoder_input_output_embed', True),
        ('share_all_embeddings', True),
        ('layernorm_embedding', True),
        # initialisation
        ('apply_bert_init', False),
    )
    for attr_name, fallback in mass_defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))

    base_architecture(args)
コード例 #2
0
    def init_from_config(cls, impl, decoder_kwargs, embedding):
        """Create a module instance that wraps a decoder built from a config.

        Args:
            impl: backend name; only ``"fairseq"`` is supported.
            decoder_kwargs: decoder configuration overrides (mapping of
                argument name to value).
            embedding: token embedding module; ``embedding.weight.shape[0]`` is
                used as the vocabulary size.

        Returns:
            The new module, with ``embedding``, ``decoder_kwargs``, ``model``
            and ``is_initialized`` set.

        Raises:
            NotImplementedError: if ``impl`` is not ``"fairseq"``.
        """
        module = cls(impl)

        module.embedding = embedding
        module.decoder_kwargs = decoder_kwargs

        if impl == "fairseq":
            # Precedence, lowest to highest: fairseq CLI defaults, then
            # base_architecture defaults, then the caller's decoder_kwargs.
            args = {}

            # fairseq default args (parse an empty argv explicitly)
            ap = ArgumentParser()
            FairseqModel.add_args(ap)
            args.update(vars(ap.parse_args([])))

            # fairseq base architecture args
            ns = Namespace(**decoder_kwargs)
            base_architecture(ns)
            args.update(vars(ns))

            # our args always take precedence
            args.update(decoder_kwargs)

            namespace = Namespace(**args)

            # Placeholder "dictionary" with one entry per vocabulary item.
            # The previous set comprehension ``{0 for _ in range(n)}``
            # collapsed to ``{0}`` (length 1). NOTE(review): presumably the
            # decoder only needs len() of this object -- confirm against
            # FairseqDecoder.
            vocab_size = embedding.weight.shape[0]
            dumb_dict = [0] * vocab_size

            module.model = FairseqDecoder(namespace, dumb_dict, embedding)
        else:
            raise NotImplementedError()

        module.is_initialized = True

        return module
コード例 #3
0
ファイル: transformer_modular.py プロジェクト: varisd/fairseq
def transformer_modular(args):
    """Set defaults for the modular transformer on top of ``transformer.base_architecture``.

    The ``*_attention_heads_active`` defaults are derived from
    ``encoder_attention_heads`` / ``decoder_attention_heads``. Those are filled
    in by ``transformer.base_architecture`` when the user did not give them, so
    the derived defaults are computed only after it has run. Reading them
    earlier raised ``AttributeError`` whenever the head counts were not
    supplied explicitly.
    """
    args.encoder_modular_layer_indices = getattr(
        args, 'encoder_modular_layer_indices', '()')
    args.decoder_modular_layer_indices = getattr(
        args, 'decoder_modular_layer_indices', '()')
    # NOTE: separate encoder-decoder attention settings
    # (enc_dec_attention_heads_active / enc_dec_modular_layer_indices) were
    # sketched here previously but are disabled.
    args.share_encoder_ctrl = getattr(args, 'share_encoder_ctrl', False)

    args.module_ctrl_hidden_depth = getattr(args, 'module_ctrl_hidden_depth',
                                            0)
    args.module_ctrl_hidden_dim = getattr(args, 'module_ctrl_hidden_dim', None)
    args.module_ctrl_word_dropout = getattr(args, 'module_ctrl_word_dropout',
                                            0.0)
    args.module_ctrl_type = getattr(args, 'module_ctrl_type', 'joint')
    args.module_ctrl_init = getattr(args, 'module_ctrl_init', 'uniform')

    transformer.base_architecture(args)

    # Derived defaults: by default every attention head is active. These need
    # the head counts, which are guaranteed to be present only at this point.
    args.encoder_attention_heads_active = getattr(
        args, 'encoder_attention_heads_active', args.encoder_attention_heads)
    args.decoder_attention_heads_active = getattr(
        args, 'decoder_attention_heads_active', args.decoder_attention_heads)
コード例 #4
0
def base_multilingual_architecture(args):
    """Apply ``base_architecture``, then default every multilingual sharing flag to off."""
    base_architecture(args)
    sharing_flags = (
        'share_encoder_embeddings',
        'share_decoder_embeddings',
        'share_encoders',
        'share_decoders',
    )
    for flag in sharing_flags:
        setattr(args, flag, getattr(args, flag, False))
コード例 #5
0
def transformer_mmt_base(args):
    """MMT transformer "base" configuration; identical to ``base_architecture``.

    When this alias was written (16/05/2019), fairseq's ``base_architecture``
    used these main defaults:

    * encoder_embed_dim = decoder_embed_dim = 512
    * encoder_ffn_embed_dim = decoder_ffn_embed_dim = 2048
    * encoder_attention_heads = decoder_attention_heads = 8
    * encoder_layers = decoder_layers = 6
    * dropout = 0.1
    """
    base_architecture(args)
コード例 #6
0
ファイル: models.py プロジェクト: mindis/Faster_Transformers
def transformer_small(args):
    """Small transformer: 3 layers, 4 heads, 256-d embeddings and 512-d FFN on both sides.

    Only attributes missing from ``args`` are defaulted; the rest is delegated
    to ``transformer.base_architecture``.
    """
    per_side_defaults = (
        ('embed_dim', 256),
        ('ffn_embed_dim', 512),
        ('attention_heads', 4),
        ('layers', 3),
    )
    for side in ('encoder', 'decoder'):
        for suffix, fallback in per_side_defaults:
            key = side + '_' + suffix
            setattr(args, key, getattr(args, key, fallback))
    transformer.base_architecture(args)
コード例 #7
0
def transformer_iwslt_de_en(args):
    """IWSLT de-en transformer: 6 layers, 4 heads, 512-d embeddings, 1024-d FFN per side."""
    embed_dim, ffn_dim, heads, depth = 512, 1024, 4, 6

    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', embed_dim)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', ffn_dim)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', heads)
    args.encoder_layers = getattr(args, 'encoder_layers', depth)

    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', ffn_dim)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', heads)
    args.decoder_layers = getattr(args, 'decoder_layers', depth)

    base_architecture(args)
コード例 #8
0
def transformer_big(args):
    """Big transformer: 1024-d embeddings, 4096-d FFN and 16 attention heads per side.

    The encoder also defaults ``encoder_normalize_before`` to False. Layer
    counts and all remaining settings come from ``base_architecture``.
    """
    embed_dim, ffn_dim, heads = 1024, 4096, 16

    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", embed_dim)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", ffn_dim)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", heads)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)

    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", ffn_dim)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", heads)

    base_architecture(args)
コード例 #9
0
def transformer_xlarge(args):
    """Extra-large transformer: 2048-d embeddings, 8192-d FFN, 16 heads per side.

    ``encoder_normalize_before`` defaults to False; everything else is left to
    ``base_architecture``.
    """
    for name, value in (
        ("encoder_embed_dim", 2048),
        ("encoder_ffn_embed_dim", 8192),
        ("encoder_attention_heads", 16),
        ("encoder_normalize_before", False),
        ("decoder_embed_dim", 2048),
        ("decoder_ffn_embed_dim", 8192),
        ("decoder_attention_heads", 16),
    ):
        setattr(args, name, getattr(args, name, value))
    base_architecture(args)
コード例 #10
0
ファイル: __init__.py プロジェクト: baoy-nlp/NAT-fairseq
def transformer_iwslt16_de_en(args):
    """IWSLT'16 de-en transformer: 278-d embeddings, 507-d FFN, 2 heads, 5 layers per side."""
    sizes = dict(embed_dim=278, ffn_embed_dim=507, attention_heads=2, layers=5)
    for prefix in ('encoder_', 'decoder_'):
        for field, default in sizes.items():
            attr = prefix + field
            setattr(args, attr, getattr(args, attr, default))
    base_architecture(args)
コード例 #11
0
    def build_model_base(cls, args, task):
        """Build the encoder and decoder for a new model instance.

        Returns:
            tuple: ``(encoder, decoder)`` created by ``cls.build_encoder`` and
            ``cls.build_decoder``.

        Raises:
            ValueError: if ``--share-all-embeddings`` is requested with
                different dictionaries, different embedding dims, or a
                decoder embedding path that differs from the encoder's.
        """

        # make sure all arguments are present in older models
        base_architecture(args)

        # A config may carry these attributes explicitly set to None; treat
        # that like a missing attribute so the default still applies.
        if getattr(args, 'max_source_positions', None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, 'max_target_positions', None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None):
            """Create an embedding over *dictionary*, optionally loading vectors from *path*."""
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()

            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError(
                    '--share-all-embeddings requires a joined dictionary')
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim'
                )
            if args.decoder_embed_path and (args.decoder_embed_path !=
                                            args.encoder_embed_path):
                raise ValueError(
                    '--share-all-embeddings not compatible with --decoder-embed-path'
                )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
            )
            # One embedding table serves encoder input, decoder input and
            # decoder output projection.
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(src_dict,
                                                   args.encoder_embed_dim,
                                                   args.encoder_embed_path)
            decoder_embed_tokens = build_embedding(tgt_dict,
                                                   args.decoder_embed_dim,
                                                   args.decoder_embed_path)

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return (encoder, decoder)
コード例 #12
0
    def build_model(cls, args, task):
        """Build a model from ``args`` and the task's target dictionary.

        The encoder is built from ``args`` alone. The decoder also receives
        ``task.target_dictionary`` (the caption vocabulary). The final model is
        assembled by ``cls.do_build_model``.
        """
        transformer.base_architecture(args)

        # Treat an attribute explicitly set to None like a missing one; a bare
        # hasattr() check would keep the None.
        if getattr(args, 'max_target_positions', None) is None:
            args.max_target_positions = transformer.DEFAULT_MAX_TARGET_POSITIONS

        captions_dict = task.target_dictionary

        encoder = cls.do_build_encoder(args)
        decoder = cls.do_build_decoder(args, captions_dict)
        return cls.do_build_model(encoder, decoder)
コード例 #13
0
ファイル: lanmt.py プロジェクト: George0828Zhang/NAT
    def build_model(cls, args, task):
        """Build a new model instance.

        Mirrors ``models.transformer``'s construction of the encoder and
        decoder, and additionally builds a prior and a posterior. The posterior
        needs the vocabularies and token embeddings, which is why it is
        constructed here rather than elsewhere.

        Raises:
            ValueError: on inconsistent ``--share-all-embeddings`` settings.
        """
        # make sure all arguments are present in older models
        base_architecture(args)

        # "--*-layers-to-keep" overrides the layer count with the number of
        # comma-separated indices it lists.
        for side in ("encoder", "decoder"):
            kept_layers = getattr(args, side + "_layers_to_keep")
            if kept_layers:
                setattr(args, side + "_layers", len(kept_layers.split(",")))

        for attr, default in (
            ("max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS),
            ("max_target_positions", DEFAULT_MAX_TARGET_POSITIONS),
        ):
            if getattr(args, attr, None) is None:
                setattr(args, attr, default)

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        share_all = args.share_all_embeddings
        if share_all:
            if src_dict != tgt_dict:
                raise ValueError(
                    "--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (args.decoder_embed_path !=
                                            args.encoder_embed_path):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )

        src_embed = cls.build_embedding(args, src_dict,
                                        args.encoder_embed_dim,
                                        args.encoder_embed_path)
        if share_all:
            tgt_embed = src_embed
            args.share_decoder_input_output_embed = True
        else:
            tgt_embed = cls.build_embedding(args, tgt_dict,
                                            args.decoder_embed_dim,
                                            args.decoder_embed_path)

        encoder = cls.build_encoder(args, src_dict, src_embed)
        decoder = cls.build_decoder(args, tgt_dict, tgt_embed)

        prior = cls.build_prior(args)
        posterior = cls.build_posterior(args, task, src_dict, tgt_dict,
                                        src_embed, tgt_embed)
        return cls(args, encoder, decoder, prior, posterior)
コード例 #14
0
def latent_multilingual_architecture(args):
    """Latent multilingual transformer: 12-layer encoder, 24-layer decoder, fully shared.

    Both sides default to 512-d embeddings, 1024-d FFN and 4 heads. Encoders,
    decoders and their embeddings are shared across languages by default.
    """
    defaults = {
        "encoder_embed_dim": 512,
        "encoder_ffn_embed_dim": 1024,
        "encoder_attention_heads": 4,
        "encoder_layers": 12,
        "decoder_embed_dim": 512,
        "decoder_ffn_embed_dim": 1024,
        "decoder_attention_heads": 4,
        "decoder_layers": 24,
        "share_encoders": True,
        "share_decoders": True,
        "share_encoder_embeddings": True,
        "share_decoder_embeddings": True,
    }
    for key, value in defaults.items():
        setattr(args, key, getattr(args, key, value))

    base_architecture(args)
コード例 #15
0
def latent_multilingual_architecture(args):
    """Latent multilingual transformer defaults (12-layer encoder, 24-layer decoder).

    Embedding width 512, FFN width 1024 and 4 attention heads on both sides.
    All sharing flags default to True.
    """
    width, ffn_width, heads = 512, 1024, 4

    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', width)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', ffn_width)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', heads)
    args.encoder_layers = getattr(args, 'encoder_layers', 12)

    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', width)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', ffn_width)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', heads)
    args.decoder_layers = getattr(args, 'decoder_layers', 24)

    for flag in ('share_encoders', 'share_decoders',
                 'share_encoder_embeddings', 'share_decoder_embeddings'):
        setattr(args, flag, getattr(args, flag, True))

    base_architecture(args)
コード例 #16
0
    def build_model(cls, args, task):
        """Build a speech-to-text model: an ``ASRFeature`` front-end plus a token decoder.

        Returns:
            TransformerModel: built from ``cls.build_encoder`` and
            ``cls.build_decoder``.
        """

        # make sure all arguments are present in older models
        base_architecture(args)

        # Treat attributes explicitly set to None like missing ones so the
        # defaults still apply (a bare hasattr() check would keep None).
        if getattr(args, 'max_source_positions', None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, 'max_target_positions', None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        tgt_dict = task.target_dictionary

        encoder_embed_speech = ASRFeature(
            cmvn=args.cmvn,
            n_mels=args.fbank_dim,
            dropout=args.dropout,
            # NOTE: assumes load_dataset is called before build_model
            sample_rate=task.sample_rate,
            n_fft=args.stft_dim,
            stride=args.stft_stride,
            n_subsample=args.encoder_subsample_layers,
            odim=args.encoder_embed_dim,
        )

        def build_embedding(dictionary, embed_dim, path=None):
            """Create an embedding over *dictionary*, optionally loading vectors from *path*."""
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        decoder_embed_tokens = build_embedding(tgt_dict,
                                               args.decoder_embed_dim,
                                               args.decoder_embed_path)

        # The speech front-end is given padding_idx = -1. Using
        # decoder_embed_tokens.padding_idx here was considered and left
        # disabled. NOTE(review): presumably -1 marks "no padding token" for
        # the encoder -- confirm.
        setattr(encoder_embed_speech, "padding_idx", -1)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        encoder = cls.build_encoder(args, encoder_embed_speech)
        return TransformerModel(encoder, decoder)
コード例 #17
0
def prime_base_architecture(args):
    """Apply ``base_architecture``, then fill PRIME-specific defaults.

    The default table (including its lists) is rebuilt on every call, so
    separate ``args`` objects never share a mutable default.
    """
    base_architecture(args)
    prime_defaults = {
        "use_att": ['es', 'ds', 'dc'],
        "kernel_size": 0,
        "attn_dynamic_type": 0,
        "attn_cat_relu": 0,
        "attn_wide_kernels": [3, 15],
        "dynamic_gate": 1,
        "dynamic_depth_kernels": [3, 3, 3, 7, 7, 7, 7, 7, 7, 15, 15, 15],
        "dynamic_padding": 0,
        "attn_dynamic_cat": 1,
        "input_dropout": 0,
        "init_method": 'km',
        "lnv": 'origin',
    }
    for attr, fallback in prime_defaults.items():
        setattr(args, attr, getattr(args, attr, fallback))
コード例 #18
0
    def build_model(cls, args, task):
        """Build a new encoder-only model instance.

        The encoder's token embedding is built over ``task.target_dictionary``
        using the *decoder* embedding settings (``decoder_embed_dim`` and
        ``decoder_embed_path``).
        """
        # make sure all arguments are present in older models
        base_architecture(args)

        for layers_attr, keep_attr in (
            ("encoder_layers", "encoder_layers_to_keep"),
            ("decoder_layers", "decoder_layers_to_keep"),
        ):
            kept = getattr(args, keep_attr)
            if kept:
                setattr(args, layers_attr, len(kept.split(",")))

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        token_embedding = cls.build_embedding(
            args,
            task.target_dictionary,
            args.decoder_embed_dim,
            args.decoder_embed_path,
        )
        encoder = cls.build_encoder(args, task.target_dictionary,
                                    token_embedding)
        return cls(args, encoder)
コード例 #19
0
def base_monotonic_architecture(args):
    """Apply ``base_architecture``; ``encoder_unidirectional`` defaults to False."""
    base_architecture(args)
    flag = "encoder_unidirectional"
    setattr(args, flag, getattr(args, flag, False))
コード例 #20
0
    def build_model_base(cls, args, task):
        """Build the encoder and decoder for a new model instance.

        Token embeddings may be split into ``args.num_embedding_chunks`` equal
        chunks along the embedding dimension. In that case each side gets an
        ``nn.ModuleList`` of chunk embeddings instead of one ``Embedding``.

        Returns:
            tuple: ``(encoder, decoder)``.

        Raises:
            ValueError: on inconsistent ``--share-all-embeddings`` settings.
            AssertionError: on chunking configurations that are not supported.
        """

        # make sure all arguments are present in older models
        base_architecture(args)

        # Treat attributes explicitly set to None like missing ones so the
        # defaults still apply (a bare hasattr() check would keep None).
        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
            """Build one embedding, or a ModuleList of ``num_embed_chunks`` chunk embeddings."""
            # The embedding dimension must split evenly into the chunks.
            assert embed_dim % num_embed_chunks == 0, (
                f"Embedding dimension = {embed_dim} should be divisible by "
                + f"the number of embedding chunks = {num_embed_chunks}"
            )
            assert path is None or num_embed_chunks == 1, (
                "Loading embedding from a path with number of embedding chunks > 1"
                + " is not yet supported"
            )
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            # if provided, load from preloaded dictionaries
            if path:
                emb = Embedding(num_embeddings, embed_dim, padding_idx)
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            else:
                embed_chunk_dim = embed_dim // num_embed_chunks
                emb = nn.ModuleList()
                for _ in range(num_embed_chunks):
                    emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx))
            return emb

        num_embed_chunks = args.num_embedding_chunks
        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
                num_embed_chunks,
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            assert args.share_decoder_input_output_embed or num_embed_chunks == 1, (
                "Not sharing decoder I/O embeddings is not yet supported with number of "
                + "embedding chunks > 1"
            )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
                num_embed_chunks,
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict,
                args.decoder_embed_dim,
                args.decoder_embed_path,
                num_embed_chunks,
            )

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return (encoder, decoder)
def sparse_base_architecture(args):
    """Apply ``base_architecture``, then default the sparse-attention options.

    ``top_k`` defaults to -1 and ``print_attn_score`` to False.
    """
    base_architecture(args)
    for option, fallback in (("top_k", -1), ("print_attn_score", False)):
        setattr(args, option, getattr(args, option, fallback))
def base_x_transformer(args):
    """X-transformer base configuration; currently identical to ``base_architecture``."""
    base_architecture(args)
コード例 #23
0
ファイル: transformer_pg.py プロジェクト: AK391/m2m100
    def build_model(cls, args, task):
        """Build a new pointer-generator transformer instance.

        The source and target dictionaries must compare equal. The last
        ``args.source_position_markers`` dictionary entries get no embedding
        rows of their own. They are intended to share the ``<unk>`` embedding,
        so each input position can use a distinct marker token while word
        identities are recovered from the original source text.

        Raises:
            ValueError: if the dictionaries differ or the shared-embedding
                options are inconsistent.
        """
        # make sure all arguments are present in older models
        base_architecture(args)

        for side in ("encoder", "decoder"):
            kept_layers = getattr(args, side + "_layers_to_keep")
            if kept_layers:
                setattr(args, side + "_layers", len(kept_layers.split(",")))

        for attr, default in (
            ("max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS),
            ("max_target_positions", DEFAULT_MAX_TARGET_POSITIONS),
        ):
            if getattr(args, attr, None) is None:
                setattr(args, attr, default)
        # Depends on max_source_positions, so it must come after the loop above.
        if getattr(args, "source_position_markers", None) is None:
            args.source_position_markers = args.max_source_positions

        src_dict = task.source_dictionary
        tgt_dict = task.target_dictionary
        if src_dict != tgt_dict:
            raise ValueError("Pointer-generator requires a joined dictionary")

        def make_embedding(dictionary, embed_dim, path=None):
            # Only the non-marker part of the dictionary gets embedding rows;
            # the trailing marker indices are reported as mapped to <unk>.
            vocab_size = len(dictionary) - args.source_position_markers
            pad_index = dictionary.pad()
            unk_index = dictionary.unk()
            logger.info(
                "dictionary indices from {0} to {1} will be mapped to {2}".format(
                    vocab_size, len(dictionary) - 1, unk_index))
            table = Embedding(vocab_size, embed_dim, pad_index, unk_index)
            # if provided, load from preloaded dictionaries
            if path:
                utils.load_embedding(utils.parse_embedding(path), dictionary,
                                     table)
            return table

        share_all = args.share_all_embeddings
        if share_all:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (args.decoder_embed_path !=
                                            args.encoder_embed_path):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )

        src_embed = make_embedding(src_dict, args.encoder_embed_dim,
                                   args.encoder_embed_path)
        if share_all:
            tgt_embed = src_embed
            args.share_decoder_input_output_embed = True
        else:
            tgt_embed = make_embedding(tgt_dict, args.decoder_embed_dim,
                                       args.decoder_embed_path)

        encoder = cls.build_encoder(args, src_dict, src_embed)
        decoder = cls.build_decoder(args, tgt_dict, tgt_embed)
        return cls(args, encoder, decoder)
コード例 #24
0
ファイル: transformer_pg.py プロジェクト: AK391/m2m100
def transformer_pointer_generator(args):
    """Pointer-generator defaults: one alignment head taken from the last decoder layer.

    A negative ``alignment_layer`` counts back from ``decoder_layers``. It is
    resolved after ``base_architecture`` has fixed the decoder depth.
    """
    for option, fallback in (("alignment_heads", 1), ("alignment_layer", -1)):
        setattr(args, option, getattr(args, option, fallback))
    base_architecture(args)
    layer = args.alignment_layer
    if layer < 0:
        args.alignment_layer = args.decoder_layers + layer
コード例 #25
0
def residual_drop_transformer_architecture(args):
    """Residual-drop transformer configuration; uses ``base_architecture`` unchanged."""
    base_architecture(args)
コード例 #26
0
ファイル: model.py プロジェクト: George0828Zhang/NAT
    def build_model(cls, args, task):
        """Build a student model and a teacher model and wrap both in ``cls``.

        The student class is ``ARCH_MODEL_REGISTRY[args.student_arch]``. The
        teacher class is ``ARCH_MODEL_REGISTRY[args.teacher_arch]``, replaced
        by ``PatchedTransformerModel`` unless it subclasses
        ``NATransformerModel``. The teacher reuses the student's token
        embeddings when ``share_encoder_embeddings`` /
        ``share_decoder_embeddings`` are set; otherwise it gets fresh ones.

        Raises:
            ValueError: if ``--share-all-embeddings`` is combined with
                differing dictionaries, embedding dims, or embedding paths.
        """
        # make sure all arguments are present in older models
        base_architecture(args)

        # Sharing whole encoders implies sharing their embeddings too.
        if args.share_encoders:
            args.share_encoder_embeddings = True

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def new_src_embedding():
            return TransformerModel.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )

        def new_tgt_embedding():
            return TransformerModel.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        # --- student embeddings (shared with the teacher where configured) ---
        share_all = args.share_all_embeddings
        if share_all:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )

        src_embed = new_src_embedding()
        if share_all:
            tgt_embed = src_embed
            args.share_decoder_input_output_embed = True
        else:
            tgt_embed = new_tgt_embedding()

        # --- student ---
        student_cls = ARCH_MODEL_REGISTRY[args.student_arch]
        student = student_cls(
            args,
            student_cls.build_encoder(args, src_dict, src_embed),
            student_cls.build_decoder(args, tgt_dict, tgt_embed),
        )

        # --- teacher ---
        teacher_cls = ARCH_MODEL_REGISTRY[args.teacher_arch]
        if not issubclass(teacher_cls, NATransformerModel):
            teacher_cls = PatchedTransformerModel

        # Embeddings are built immediately before the module that uses them,
        # keeping the construction order unchanged.
        teacher_src_embed = (
            src_embed if args.share_encoder_embeddings else new_src_embedding()
        )
        teacher_encoder = teacher_cls.build_encoder(args, src_dict, teacher_src_embed)
        teacher_tgt_embed = (
            tgt_embed if args.share_decoder_embeddings else new_tgt_embedding()
        )
        teacher_decoder = teacher_cls.build_decoder(args, tgt_dict, teacher_tgt_embed)
        teacher = teacher_cls(args, teacher_encoder, teacher_decoder)

        return cls(args, student, teacher)
コード例 #27
0
def transformer_base_architecture(args):
    """Default ``print_attn_score`` to False, then apply ``base_architecture``."""
    option = "print_attn_score"
    setattr(args, option, getattr(args, option, False))
    base_architecture(args)
コード例 #28
0
def my_hyperparameters(args):
    """Custom hyper-parameter set; currently just the ``base_architecture`` defaults."""
    base_architecture(args)
コード例 #29
0
def transformer_single_shot(args):
    """Default both patch-warning flags to False, then apply ``base_architecture``."""
    for warning_flag in ("warn_patched", "warn_not_patched"):
        setattr(args, warning_flag, getattr(args, warning_flag, False))
    base_architecture(args)
コード例 #30
0
ファイル: transformer_align.py プロジェクト: veralily/fairseq
def transformer_align(args):
    """Alignment-supervised transformer defaults, applied before ``base_architecture``.

    Defaults: one alignment head from decoder layer 4, with
    ``full_context_alignment`` disabled.
    """
    alignment_defaults = {
        "alignment_heads": 1,
        "alignment_layer": 4,
        "full_context_alignment": False,
    }
    for name, value in alignment_defaults.items():
        setattr(args, name, getattr(args, name, value))
    base_architecture(args)