    def __init__(self, feature_size, self_attention_layer,
                 cross_attention_layer, feed_forward_layer, adapter_dict,
                 adapter_bottleneck_size, dropout_rate):
        super(TransformerDecoderLayer, self).__init__()
        self.feature_size = feature_size

        self.self_attention_layer = self_attention_layer
        self.cross_attention_layer = cross_attention_layer
        self.feed_forward_layer = feed_forward_layer

        # One bottleneck adapter (with its own residual sublayer connection) per
        # domain; adapter_dict and adapter_bottleneck_size are parallel sequences.
        adapters = nn.ModuleDict({})
        sublayer_connection_for_adapter = nn.ModuleDict({})
        _adapters = collections.OrderedDict()
        _sublayer_connection_for_adapter = collections.OrderedDict()
        for domain, size in zip(adapter_dict, adapter_bottleneck_size):
            _adapters[domain] = PositionWiseFeedForward(input_dim=feature_size,
                                                        ff_dim=size,
                                                        dropout=dropout_rate)
            _sublayer_connection_for_adapter[domain] = SublayerConnection(
                feature_size, dropout_rate)
        adapters.update(_adapters)
        sublayer_connection_for_adapter.update(
            _sublayer_connection_for_adapter)

        self.adapters = adapters
        self.sublayer_connection_for_adapter = sublayer_connection_for_adapter

        self.sublayer_with_cache = clones(
            SublayerConnectionWithCache(feature_size, dropout_rate), 2)
        self.sublayer = SublayerConnection(feature_size, dropout_rate)
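
The per-domain PositionWiseFeedForward modules above (with ff_dim set to the bottleneck size) act as adapters, and each SublayerConnection wraps one in a residual branch. These snippets rely on their enclosing modules' imports (collections, copy, torch.nn as nn) and project classes (PositionWiseFeedForward, SublayerConnection, SublayerConnectionWithCache, clones, the attention layers), which are not repeated here. Below is a minimal plain-PyTorch sketch of the bottleneck-adapter pattern these pieces are assumed to implement; every name in it is illustrative, not the project's API.

import torch
import torch.nn as nn


class BottleneckAdapter(nn.Module):
    """Down-project, nonlinearity, up-project, residual add (illustration only)."""

    def __init__(self, feature_size, bottleneck_size, dropout_rate=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(feature_size)
        self.down = nn.Linear(feature_size, bottleneck_size)
        self.up = nn.Linear(bottleneck_size, feature_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Pre-norm residual branch, mirroring what SublayerConnection is assumed to do.
        h = self.up(torch.relu(self.down(self.norm(x))))
        return x + self.dropout(h)


adapter = BottleneckAdapter(feature_size=512, bottleneck_size=64)
out = adapter(torch.randn(2, 10, 512))  # (batch, seq_len, feature_size)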
Example #2
def make_transformer_with_split_position(model_config, vocab):
    attention = MultiHeadedAttention(head_num=model_config['head_num'],
                                     feature_size=model_config['feature_size'],
                                     dropout=model_config['dropout_rate'])
    attention_with_cache = MultiHeadedAttentionWithCache(
        head_num=model_config['head_num'],
        feature_size=model_config['feature_size'],
        dropout=model_config['dropout_rate'])
    feed_forward = PositionWiseFeedForward(
        input_dim=model_config['feature_size'],
        ff_dim=model_config['feedforward_dim'],
        dropout=model_config['dropout_rate'])

    model = TransformerWithSplitPosition(
        src_embedding_layer=Embeddings(
            emb_size=model_config['feature_size'],
            vocab_size=len(vocab['src']),
            dropout=model_config['dropout_rate'],
            linear_combination=model_config['position_linear_combination'],
        ),
        trg_embedding_layer=Embeddings(
            emb_size=model_config['feature_size'],
            vocab_size=len(vocab['trg']),
            dropout=model_config['dropout_rate'],
            linear_combination=model_config['position_linear_combination'],
        ),
        encoder=TransformerEncoder(
            layer=TransformerEncoderLayer(
                feature_size=model_config['feature_size'],
                self_attention_layer=copy.deepcopy(attention),
                feed_forward_layer=copy.deepcopy(feed_forward),
                dropout_rate=model_config['dropout_rate']),
            feature_size=model_config['feature_size'],
            num_layers=model_config['num_layers'],
        ),
        decoder=TransformerDecoder(
            layer=TransformerDecoderLayer(
                feature_size=model_config['feature_size'],
                self_attention_layer=copy.deepcopy(attention_with_cache),
                cross_attention_layer=copy.deepcopy(attention_with_cache),
                feed_forward_layer=copy.deepcopy(feed_forward),
                dropout_rate=model_config['dropout_rate']),
            num_layers=model_config['num_layers'],
            feature_size=model_config['feature_size'],
        ),
        generator=SimpleGenerator(feature_size=model_config['feature_size'],
                                  vocab_size=len(vocab['trg'])),
        vocab=vocab,
        share_decoder_embedding=model_config['share_decoder_embedding'],
        share_enc_dec_embedding=model_config['share_enc_dec_embedding'],
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model
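
The factory reads only a fixed set of keys from model_config. A hedged usage sketch follows; the key names match those read above, while the values, the toy vocab, and passing plain dicts as vocab are illustrative assumptions (in practice the project's own vocabulary objects would be used, since the model also keeps vocab around).

# Hypothetical configuration; the keys mirror those the factory reads above.
model_config = {
    'head_num': 8,
    'feature_size': 512,
    'feedforward_dim': 2048,
    'dropout_rate': 0.1,
    'num_layers': 6,
    'position_linear_combination': False,
    'share_decoder_embedding': True,
    'share_enc_dec_embedding': False,
}

# Only len(vocab['src']) / len(vocab['trg']) are needed to size the embeddings here,
# so any sized mapping is enough for a quick smoke test.
toy_vocab = {'src': {'<pad>': 0, '<unk>': 1, 'hello': 2},
             'trg': {'<pad>': 0, '<unk>': 1, 'bonjour': 2}}

model = make_transformer_with_split_position(model_config, toy_vocab)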
Example #3
    def init_adapter(
        self,
        feature_size,
        dropout_rate,
        adapter_dict,
        adapter_bottleneck_size,
    ):

        # define two new sublayer connections for each domain, each with a
        # domain-specific layer norm
        domain_specific_sublayer_connection = nn.ModuleDict({})
        _domain_specific_sublayer_connection = collections.OrderedDict()
        for domain in adapter_dict:
            _domain_specific_sublayer_connection[domain] = clones(
                SublayerConnection(feature_size, dropout_rate), 2)
        domain_specific_sublayer_connection.update(
            _domain_specific_sublayer_connection)
        self.domain_specific_sublayer_connection = domain_specific_sublayer_connection

        # define two adapter layers per domain: one for the self-attention
        # sublayer and one for the feed-forward sublayer
        domain_specific_adapter_for_self_attn = nn.ModuleDict({})
        _domain_specific_adapter_for_self_attn = collections.OrderedDict()
        domain_specific_adapter_for_ffn = nn.ModuleDict({})
        _domain_specific_adapter_for_ffn = collections.OrderedDict()
        domain_specific_sublayer_connection_for_adapter = nn.ModuleDict({})
        _domain_specific_sublayer_connection_for_adapter = collections.OrderedDict()
        for domain, domain_sz in zip(adapter_dict, adapter_bottleneck_size):
            _domain_specific_adapter_for_self_attn[
                domain] = PositionWiseFeedForward(feature_size, domain_sz,
                                                  dropout_rate)
            _domain_specific_adapter_for_ffn[domain] = PositionWiseFeedForward(
                feature_size, domain_sz, dropout_rate)
            _domain_specific_sublayer_connection_for_adapter[domain] = clones(
                SublayerConnection(feature_size, dropout_rate), 2)

        domain_specific_adapter_for_self_attn.update(
            _domain_specific_adapter_for_self_attn)
        domain_specific_adapter_for_ffn.update(
            _domain_specific_adapter_for_ffn)
        domain_specific_sublayer_connection_for_adapter.update(
            _domain_specific_sublayer_connection_for_adapter)
        self.domain_specific_adapter_for_self_attn = domain_specific_adapter_for_self_attn
        self.domain_specific_adapter_for_ffn = domain_specific_adapter_for_ffn
        self.domain_specific_sublayer_connection_for_adapter = domain_specific_sublayer_connection_for_adapter
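
At run time a layer built this way presumably indexes these nn.ModuleDicts with the name of the active domain, so only that domain's adapter is applied while every domain's parameters stay registered. A minimal sketch of that lookup pattern, with a hypothetical class and forward signature that are not the project's actual API:

import torch
import torch.nn as nn


class DomainAdapterBank(nn.Module):
    """One small adapter per domain, selected by name (illustration only)."""

    def __init__(self, feature_size, domain_sizes):
        super().__init__()
        self.adapters = nn.ModuleDict({
            domain: nn.Sequential(nn.Linear(feature_size, size),
                                  nn.ReLU(),
                                  nn.Linear(size, feature_size))
            for domain, size in domain_sizes.items()
        })

    def forward(self, x, domain):
        # ModuleDict lookup: every domain's parameters are registered on the module,
        # but only the selected branch runs in this forward pass.
        return x + self.adapters[domain](x)


bank = DomainAdapterBank(feature_size=512, domain_sizes={'news': 64, 'medical': 128})
y = bank(torch.randn(2, 10, 512), domain='medical')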
Example #4
def make_transformer_language_model(model_config, vocab):
    attention = MultiHeadedAttention(head_num=model_config['head_num'],
                                     feature_size=model_config['feature_size'],
                                     dropout=model_config['dropout_rate'])
    feed_forward = PositionWiseFeedForward(
        input_dim=model_config['feature_size'],
        ff_dim=model_config['feedforward_dim'],
        dropout=model_config['dropout_rate'])

    model = TransformerLanguageModel(
        embedding_layer=Embeddings(vocab_size=len(vocab['text']),
                                   emb_size=model_config['feature_size'],
                                   dropout=model_config['dropout_rate'],
                                   max_len=5000),
        decoder=TransformerEncoder(
            layer=TransformerEncoderLayer(
                feature_size=model_config['feature_size'],
                self_attention_layer=copy.deepcopy(attention),
                feed_forward_layer=copy.deepcopy(feed_forward),
                dropout_rate=model_config['dropout_rate'],
                layer_norm_rescale=model_config['layer_norm_rescale']),
            feature_size=model_config['feature_size'],
            num_layers=model_config['num_layers'],
            layer_norm_rescale=model_config['layer_norm_rescale'],
        ),
        generator=SimpleGenerator(feature_size=model_config['feature_size'],
                                  vocab_size=len(vocab['text']),
                                  bias=model_config['generator_bias']),
        vocab=vocab,
        share_embedding=model_config['share_embedding'],
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model
Example #5
def make_transformer_with_full_adapter(model_config, vocab):
    attention = MultiHeadedAttention(head_num=model_config['head_num'],
                                     feature_size=model_config['feature_size'],
                                     dropout=model_config['dropout_rate'])
    attention_with_cache = MultiHeadedAttentionWithCache(
        head_num=model_config['head_num'],
        feature_size=model_config['feature_size'],
        dropout=model_config['dropout_rate'])
    feed_forward = PositionWiseFeedForward(
        input_dim=model_config['feature_size'],
        ff_dim=model_config['feedforward_dim'],
        dropout=model_config['dropout_rate'])

    model = TransformerWithFullAdapter(
        src_embedding_layer=EmbeddingWithAdapter(
            vocab_size=len(vocab['src']),
            emb_size=model_config['feature_size'],
            dropout=model_config['dropout_rate'],
            domain_adapter_dict=model_config['domain_adapter_dict'],
            max_len=5000),
        trg_embedding_layer=EmbeddingWithAdapter(
            vocab_size=len(vocab['trg']),
            emb_size=model_config['feature_size'],
            dropout=model_config['dropout_rate'],
            domain_adapter_dict=model_config['domain_adapter_dict'],
            max_len=5000),
        encoder=TransformerEncoderWithStackedAdapter(
            layer=TransformerEncoderLayerWithStackedAdapter(
                feature_size=model_config['feature_size'],
                self_attention_layer=copy.deepcopy(attention),
                feed_forward_layer=copy.deepcopy(feed_forward),
                domain_adapter_dict=model_config['domain_adapter_dict'],
                dropout_rate=model_config['dropout_rate'],
            ),
            feature_size=model_config['feature_size'],
            num_layers=model_config['num_layers'],
        ),
        decoder=TransformerDecoderWithStackedAdapter(
            layer=TransformerDecoderLayerWithStackedAdapter(
                feature_size=model_config['feature_size'],
                self_attention_layer=copy.deepcopy(attention_with_cache),
                cross_attention_layer=copy.deepcopy(attention_with_cache),
                feed_forward_layer=copy.deepcopy(feed_forward),
                domain_adapter_dict=model_config['domain_adapter_dict'],
                dropout_rate=model_config['dropout_rate']),
            num_layers=model_config['num_layers'],
            feature_size=model_config['feature_size'],
        ),
        generator=SimpleGeneratorWithAdapter(
            feature_size=model_config['feature_size'],
            vocab_size=len(vocab['trg']),
            domain_adapter_dict=model_config['domain_adapter_dict'],
            bias=model_config['generator_bias']),
        vocab=vocab,
        share_decoder_embedding=model_config['share_decoder_embedding'],
        share_enc_dec_embedding=model_config['share_enc_dec_embedding'],
        domain_dict=model_config['domain_adapter_dict'],
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    for name, param in model.named_parameters():
        # if param.dim() > 1 and 'adapter' in name:
        #     nn.init.zeros_(param)
        if 'memory_score_bias' in name:
            nn.init.xavier_uniform_(param)

    return model
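
Every *WithAdapter component above receives the same model_config['domain_adapter_dict']. Its exact schema is not visible in this snippet; a plausible, clearly hypothetical shape is a mapping from domain name to adapter bottleneck width, which a layer can split into the parallel adapter_dict / adapter_bottleneck_size sequences used by the decoder-layer __init__ in the first example.

# Hypothetical shape for model_config['domain_adapter_dict']:
# domain name -> bottleneck width of that domain's adapter.
domain_adapter_dict = {
    'news': 64,
    'medical': 128,
    'law': 128,
}

# Split into the parallel sequences consumed by the adapter __init__ above.
adapter_dict = list(domain_adapter_dict.keys())
adapter_bottleneck_size = list(domain_adapter_dict.values())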
def make_transformer_with_classifier_adapter(model_config, vocab):
    attention = MultiHeadedAttention(head_num=model_config['head_num'],
                                     feature_size=model_config['feature_size'],
                                     dropout=model_config['dropout_rate'])
    attention_with_cache = MultiHeadedAttentionWithCache(
        head_num=model_config['head_num'],
        feature_size=model_config['feature_size'],
        dropout=model_config['dropout_rate'])
    feed_forward = PositionWiseFeedForward(
        input_dim=model_config['feature_size'],
        ff_dim=model_config['feedforward_dim'],
        dropout=model_config['dropout_rate'])

    if model_config['classifier_type'] == 'simple':
        classifier = SimpleClassifier(
            input_dim=model_config['feature_size'],
            feature_size=model_config['classify_feature_size'],
            class_num=model_config['domain_class_num'],
        )
    elif model_config['classifier_type'] == 'cnn':
        classifier = CNNClassifier(num_class=model_config['domain_class_num'],
                                   input_dim=model_config['feature_size'],
                                   kernel_nums=model_config['kernel_nums'],
                                   kernel_sizes=model_config['kernel_sizes'],
                                   dropout_rate=model_config['dropout_rate'])
    else:
        classifier = None

    model = TransformerWithClassifierAdapter(
        src_embedding_layer=Embeddings(vocab_size=len(vocab['src']),
                                       emb_size=model_config['feature_size'],
                                       dropout=model_config['dropout_rate'],
                                       max_len=5000),
        trg_embedding_layer=Embeddings(vocab_size=len(vocab['trg']),
                                       emb_size=model_config['feature_size'],
                                       dropout=model_config['dropout_rate'],
                                       max_len=5000),
        encoder=TransformerEncoderWithClassifierAdapter(
            layer=TransformerEncoderLayerWithClassifierAdapter(
                feature_size=model_config['feature_size'],
                self_attention_layer=copy.deepcopy(attention),
                feed_forward_layer=copy.deepcopy(feed_forward),
                domain_adapter_dict=model_config['domain_adapter_dict'],
                dropout_rate=model_config['dropout_rate'],
            ),
            feature_size=model_config['feature_size'],
            num_layers=model_config['num_layers'],
            domain_label_dict=model_config['domain_dict']),
        decoder=TransformerDecoderWithClassifierAdapter(
            layer=TransformerDecoderLayerWithClassifierAdapter(
                feature_size=model_config['feature_size'],
                self_attention_layer=copy.deepcopy(attention_with_cache),
                cross_attention_layer=copy.deepcopy(attention_with_cache),
                feed_forward_layer=copy.deepcopy(feed_forward),
                domain_adapter_dict=model_config['domain_adapter_dict'],
                dropout_rate=model_config['dropout_rate']),
            num_layers=model_config['num_layers'],
            feature_size=model_config['feature_size'],
            domain_label_dict=model_config['domain_dict'],
        ),
        generator=SimpleGenerator(feature_size=model_config['feature_size'],
                                  vocab_size=len(vocab['trg']),
                                  bias=model_config['generator_bias']),
        emb_classifier=copy.deepcopy(classifier),
        domain_mask=model_config['domain_mask'],
        vocab=vocab,
        share_decoder_embedding=model_config['share_decoder_embedding'],
        share_enc_dec_embedding=model_config['share_enc_dec_embedding'],
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    for name, param in model.named_parameters():
        # if param.dim() > 1 and 'adapter' in name:
        #     nn.init.zeros_(param)
        if 'memory_score_bias' in name:
            nn.init.xavier_uniform_(param)

    return model
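
The classifier branch is selected purely by model_config['classifier_type']: 'simple' and 'cnn' build a classifier, and any other value silently leaves it as None. A hedged sketch of the extra keys each branch reads; the concrete values, and the assumption that kernel_nums and kernel_sizes are parallel lists, are illustrative.

# 'simple': feed-forward classification head.
simple_classifier_cfg = {
    'classifier_type': 'simple',
    'classify_feature_size': 256,
    'domain_class_num': 3,
}

# 'cnn': convolutional classification head.
cnn_classifier_cfg = {
    'classifier_type': 'cnn',
    'domain_class_num': 3,
    'kernel_nums': [64, 64, 64],
    'kernel_sizes': [3, 4, 5],
}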
    def init_adapter(self, feature_size, dropout_rate, adapter_dict, adapter_bottleneck_size):

        domain_specific_sublayer_with_cache = nn.ModuleDict({})
        _domain_specific_sublayer_with_cache = collections.OrderedDict()
        for domain in adapter_dict:
            _domain_specific_sublayer_with_cache[domain] = clones(
                SublayerConnectionWithCache(feature_size, dropout_rate), 2)
        domain_specific_sublayer_with_cache.update(_domain_specific_sublayer_with_cache)
        self.domain_specific_sublayer_with_cache = domain_specific_sublayer_with_cache

        domain_specific_sublayer = nn.ModuleDict({})
        _domain_specific_sublayer = collections.OrderedDict()
        for domain in adapter_dict:
            _domain_specific_sublayer[domain] = SublayerConnection(feature_size, dropout_rate)
        domain_specific_sublayer.update(_domain_specific_sublayer)
        self.domain_specific_sublayer = domain_specific_sublayer

        # define three adapters per domain: one each for the self-attention,
        # cross-attention, and feed-forward sublayers

        domain_specific_adapter_for_self_attn = nn.ModuleDict({})
        _domain_specific_adapter_for_self_attn = collections.OrderedDict()
        domain_specific_adapter_for_cross_attn = nn.ModuleDict({})
        _domain_specific_adapter_for_cross_attn = collections.OrderedDict()
        domain_specific_adapter_for_ffn = nn.ModuleDict({})
        _domain_specific_adapter_for_ffn = collections.OrderedDict()

        domain_specific_sublayer_for_self_attn_adapter = nn.ModuleDict({})
        domain_specific_sublayer_for_cross_attn_adapter = nn.ModuleDict({})
        domain_specific_sublayer_for_ffn_adapter = nn.ModuleDict({})
        _domain_specific_sublayer_for_self_attn_adapter = collections.OrderedDict()
        _domain_specific_sublayer_for_cross_attn_adapter = collections.OrderedDict()
        _domain_specific_sublayer_for_ffn_adapter = collections.OrderedDict()

        for domain, domain_sz in zip(adapter_dict, adapter_bottleneck_size):
            _domain_specific_adapter_for_self_attn[domain] = PositionWiseFeedForward(feature_size,
                                                                                     domain_sz,
                                                                                     dropout_rate)
            _domain_specific_adapter_for_cross_attn[domain] = PositionWiseFeedForward(feature_size,
                                                                                      domain_sz,
                                                                                      dropout_rate)
            _domain_specific_adapter_for_ffn[domain] = PositionWiseFeedForward(feature_size,
                                                                               domain_sz,
                                                                               dropout_rate)

            _domain_specific_sublayer_for_self_attn_adapter[domain] = SublayerConnectionWithCache(feature_size, dropout_rate)
            _domain_specific_sublayer_for_cross_attn_adapter[domain] = SublayerConnectionWithCache(feature_size, dropout_rate)
            _domain_specific_sublayer_for_ffn_adapter[domain] = SublayerConnection(feature_size, dropout_rate)

        domain_specific_adapter_for_self_attn.update(_domain_specific_adapter_for_self_attn)
        domain_specific_adapter_for_cross_attn.update(_domain_specific_adapter_for_cross_attn)
        domain_specific_adapter_for_ffn.update(_domain_specific_adapter_for_ffn)
        domain_specific_sublayer_for_self_attn_adapter.update(_domain_specific_sublayer_for_self_attn_adapter)
        domain_specific_sublayer_for_cross_attn_adapter.update(_domain_specific_sublayer_for_cross_attn_adapter)
        domain_specific_sublayer_for_ffn_adapter.update(_domain_specific_sublayer_for_ffn_adapter)

        self.domain_specific_adapter_for_self_attn = domain_specific_adapter_for_self_attn
        self.domain_specific_adapter_for_ffn = domain_specific_adapter_for_ffn
        self.domain_specific_adapter_for_cross_attn = domain_specific_adapter_for_cross_attn
        self.domain_specific_sublayer_for_self_attn_adapter = domain_specific_sublayer_for_self_attn_adapter
        self.domain_specific_sublayer_for_cross_attn_adapter = domain_specific_sublayer_for_cross_attn_adapter
        self.domain_specific_sublayer_for_ffn_adapter = domain_specific_sublayer_for_ffn_adapter
def make_transformer_with_mix_adapter_update(model_config, vocab):
    attention = MultiHeadedAttention(
        head_num=model_config['head_num'],
        feature_size=model_config['feature_size'],
        dropout=model_config['dropout_rate']
    )
    attention_with_cache = MultiHeadedAttentionWithCache(
        head_num=model_config['head_num'],
        feature_size=model_config['feature_size'],
        dropout=model_config['dropout_rate']
    )
    feed_forward = PositionWiseFeedForward(
        input_dim=model_config['feature_size'],
        ff_dim=model_config['feedforward_dim'],
        dropout=model_config['dropout_rate']
    )

    adapter = None
    if 'domain_adapter_dict' in model_config:
        adapter = MixtureOfAdapter(
            domain_adapter_dict=model_config['domain_adapter_dict'],
            feature_size=model_config['feature_size'],
            dropout_rate=model_config['dropout_rate'],
            # domain_list=model_config['domain_list'],
            # domain_inner_gate_list=model_config['domain_inner_gate_list'],
            # gate_activate_func=model_config['adapter_gate_activate'],
            # stack_between_adapter_and_experts=model_config['stack_between_adapter_and_experts'] if
            # 'stack_between_adapter_and_experts' in model_config else False
        )

    enc_adapter = adapter
    if 'enc_domain_adapter_dict' in model_config:
        enc_adapter = MixtureOfAdapter(
            domain_adapter_dict=model_config['enc_domain_adapter_dict'],
            feature_size=model_config['feature_size'],
            dropout_rate=model_config['dropout_rate'],
        )

    dec_adapter = adapter
    if 'dec_domain_adapter_dict' in model_config:
        dec_adapter = MixtureOfAdapter(
            domain_adapter_dict=model_config['dec_domain_adapter_dict'],
            feature_size=model_config['feature_size'],
            dropout_rate=model_config['dropout_rate'],
        )

    model = TransformerWithMixAdapter(
        src_embedding_layer=Embeddings(
            vocab_size=len(vocab['src']),
            emb_size=model_config['feature_size'],
            dropout=model_config['dropout_rate'],
            max_len=5000
        ),
        trg_embedding_layer=Embeddings(
            vocab_size=len(vocab['trg']),
            emb_size=model_config['feature_size'],
            dropout=model_config['dropout_rate'],
            max_len=5000
        ),
        encoder=TransformerEncoderWithMixAdapter(
            layer=TransformerEncoderLayerWithMixAdapter(feature_size=model_config['feature_size'],
                                                        self_attention_layer=copy.deepcopy(attention),
                                                        feed_forward_layer=copy.deepcopy(feed_forward),
                                                        adapters=copy.deepcopy(enc_adapter),
                                                        dropout_rate=model_config['dropout_rate'],
                                                        ),
            feature_size=model_config['feature_size'],
            num_layers=model_config['num_layers'],
        ),
        decoder=TransformerDecoderWithMixAdapter(
            layer=TransformerDecoderLayerWithMixAdapter(feature_size=model_config['feature_size'],
                                                        self_attention_layer=copy.deepcopy(
                                                            attention_with_cache),
                                                        cross_attention_layer=copy.deepcopy(
                                                            attention_with_cache),
                                                        feed_forward_layer=copy.deepcopy(feed_forward),
                                                        adapters=copy.deepcopy(dec_adapter),
                                                        dropout_rate=model_config['dropout_rate'],
                                                        ),
            num_layers=model_config['num_layers'],
            feature_size=model_config['feature_size'],
        ),
        generator=SimpleGenerator(feature_size=model_config['feature_size'],
                                  vocab_size=len(vocab['trg']),
                                  bias=model_config['generator_bias']),
        vocab=vocab,
        share_decoder_embedding=model_config['share_decoder_embedding'],
        share_enc_dec_embedding=model_config['share_enc_dec_embedding'],
    )

    # for p in model.parameters():
    #     if p.dim() > 1:
    #         nn.init.xavier_uniform_(p)

    for name, param in model.named_parameters():
        if param.dim() > 1 and 'W' not in name and 'B' not in name:
            nn.init.xavier_uniform_(param)
        elif param.dim() > 1:
            # parameters whose names contain 'W' or 'B' are left to their module's own init
            print('module self init', name, param.size())

    # for name, param in model.named_parameters():
    #     if param.dim() > 1 and 'adapter' in name:
    #     #     nn.init.zeros_(param)
    #     # if 'adapter' in name:
    #         print('init adapter', name)
    #         nn.init.xavier_uniform_(param, 0.001)

    return model
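
The encoder and decoder both default to the shared domain_adapter_dict and are overridden only when an encoder- or decoder-specific key is present in model_config. A hedged configuration fragment illustrating that precedence; the sizes are illustrative and the usual transformer keys from the earlier examples are still required.

mix_adapter_cfg_fragment = {
    # Shared adapters used by both encoder and decoder by default.
    'domain_adapter_dict': {'news': 64, 'medical': 64},
    # Optional override: if present, the decoder gets these (larger) adapters
    # while the encoder keeps the shared ones above.
    'dec_domain_adapter_dict': {'news': 128, 'medical': 128},
}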
Example #9
def make_transformer_with_diff_size_stacked_adapter(model_config, vocab):
    attention = MultiHeadedAttention(head_num=model_config['head_num'],
                                     feature_size=model_config['feature_size'],
                                     dropout=model_config['dropout_rate'])
    attention_with_cache = MultiHeadedAttentionWithCache(
        head_num=model_config['head_num'],
        feature_size=model_config['feature_size'],
        dropout=model_config['dropout_rate'])
    feed_forward = PositionWiseFeedForward(
        input_dim=model_config['feature_size'],
        ff_dim=model_config['feedforward_dim'],
        dropout=model_config['dropout_rate'])

    encoder_layers = []
    decoder_layers = []
    for i in range(0, model_config['num_layers']):
        if 'encoder.' + str(i) in model_config['domain_adapter_dict']:
            adapter_dict = model_config['domain_adapter_dict']['encoder.' + str(i)]
        else:
            adapter_dict = model_config['domain_adapter_dict']['default']

        encoder_layer = TransformerEncoderLayerWithDiffSizeStackedAdapter(
            feature_size=model_config['feature_size'],
            self_attention_layer=copy.deepcopy(attention),
            feed_forward_layer=copy.deepcopy(feed_forward),
            domain_adapter_dict=adapter_dict,
            dropout_rate=model_config['dropout_rate'],
        )
        encoder_layers.append(encoder_layer)

        if 'decoder.' + str(i) in model_config['domain_adapter_dict']:
            adapter_dict = model_config['domain_adapter_dict']['decoder.' + str(i)]
        else:
            adapter_dict = model_config['domain_adapter_dict']['default']

        decoder_layer = TransformerDecoderLayerWithDiffSizeStackedAdapter(
            feature_size=model_config['feature_size'],
            self_attention_layer=copy.deepcopy(attention_with_cache),
            cross_attention_layer=copy.deepcopy(attention_with_cache),
            feed_forward_layer=copy.deepcopy(feed_forward),
            domain_adapter_dict=adapter_dict,
            dropout_rate=model_config['dropout_rate'])
        decoder_layers.append(decoder_layer)

    model = TransformerWithAdapter(
        src_embedding_layer=Embeddings(vocab_size=len(vocab['src']),
                                       emb_size=model_config['feature_size'],
                                       dropout=model_config['dropout_rate'],
                                       max_len=5000),
        trg_embedding_layer=Embeddings(vocab_size=len(vocab['trg']),
                                       emb_size=model_config['feature_size'],
                                       dropout=model_config['dropout_rate'],
                                       max_len=5000),
        encoder=TransformerEncoderWithDiffSizeStackedAdapter(
            layers=encoder_layers,
            feature_size=model_config['feature_size'],
            num_layers=model_config['num_layers'],
        ),
        decoder=TransformerDecoderWithDiffSizeStackedAdapter(
            layers=decoder_layers,
            num_layers=model_config['num_layers'],
            feature_size=model_config['feature_size'],
        ),
        generator=SimpleGenerator(feature_size=model_config['feature_size'],
                                  vocab_size=len(vocab['trg']),
                                  bias=model_config['generator_bias']),
        vocab=vocab,
        share_decoder_embedding=model_config['share_decoder_embedding'],
        share_enc_dec_embedding=model_config['share_enc_dec_embedding'],
    )

    # for name, param in model.named_parameters():
    #     print(name)

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    for name, param in model.named_parameters():
        # if param.dim() > 1 and 'adapter' in name:
        #     nn.init.zeros_(param)
        if 'memory_score_bias' in name:
            nn.init.xavier_uniform_(param)

    return model
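
In this factory domain_adapter_dict is keyed per layer: each encoder/decoder layer looks up 'encoder.<i>' or 'decoder.<i>' and falls back to the 'default' entry. A hedged example of such a configuration for num_layers = 6 follows; the nested per-domain values follow the same hypothetical domain-to-bottleneck-size shape assumed earlier.

# Hypothetical per-layer adapter configuration (num_layers = 6).
model_config_fragment = {
    'domain_adapter_dict': {
        # Layers named explicitly get their own adapter spec ...
        'encoder.0': {'news': 32, 'medical': 32},
        'decoder.5': {'news': 256, 'medical': 256},
        # ... every other encoder/decoder layer falls back to this entry.
        'default': {'news': 64, 'medical': 64},
    }
}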