Exemplo n.º 1
0
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None,
                 div_val=1, tie_projs=[False], pre_lnorm=False,
                 tgt_len=None, ext_len=None, mem_len=None,
                 cutoffs=[], adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1, th=0, fn="attn_scores.npy"):
        """Build the model and precompute per-layer attention prune masks.

        Args:
            n_token: Vocabulary size handed to the embedding and output head.
            n_layer: Number of decoder layers to stack.
            n_head, d_model, d_head, d_inner: Sizes forwarded to every
                decoder layer.
            dropout: Rate for ``self.drop``; also forwarded to every layer.
            dropatt: Forwarded to every layer as ``dropatt``.
            tie_weight: When truthy, output-layer weights are shared with
                the word embedding.
            d_embed: Embedding width; ``None`` means "use ``d_model``".
            div_val: Forwarded to the embedding and the adaptive softmax.
            tie_projs: One flag per output projection; truthy entries share
                an embedding projection with the output criterion.
            pre_lnorm: Forwarded to every layer.
            tgt_len, ext_len, mem_len: Summed into ``self.max_klen``, so
                despite the ``None`` defaults all three must be numbers.
            cutoffs: Forwarded to the embedding and the adaptive softmax.
            adapt_inp: Accepted but not used by this constructor.
            same_length, clamp_len: Stored on the instance unchanged.
            attn_type: 0 = default attention, 1 = learnable embeddings,
                2 or 3 = absolute embeddings.  Any other value leaves
                ``self.layers`` empty.
            sample_softmax: ``> 0`` selects the sampled-softmax head (value
                is forwarded to ``LogUniformSampler``); otherwise the
                adaptive softmax criterion is used.
            th: Pruning threshold as a fraction; it is multiplied by 100 and
                used as the percentile of each layer's scores.
            fn: Path of the ``.npy`` file holding average attention scores.
        """
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.attn_type = attn_type

        self.th = th

        self.layers = nn.ModuleList()

        if attn_type == 0:  # the default attention
            for _ in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )
        elif attn_type == 1:  # learnable embeddings
            for _ in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )
        elif attn_type in [2, 3]:  # absolute embeddings
            for _ in range(n_layer):
                self.layers.append(
                    DecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                    cutoffs, div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        # Load average attention patterns.  The original author documents the
        # layout as (n_layer, qlen, klen, 1, n_head); the indexing below only
        # requires a 5-D array whose first axis enumerates layers.
        weights = np.load(fn)

        # init prune mask attribute (float array holding 0.0 / 1.0)
        self.prune_masks = np.zeros(weights.shape)

        # ``th`` is a fraction while np.percentile expects a percentage.  The
        # percentile has to be derived from ``th`` *before* it is used; the
        # previous statement order read the fraction before assigning it and
        # raised UnboundLocalError on every construction.
        target_percentile = self.th * 100

        # Iterate over the layers present in the file without rebinding the
        # ``n_layer`` / ``n_head`` constructor arguments.
        for i in range(weights.shape[0]):
            w = weights[i, :, :, :, :]
            # NOTE(review): NumPy >= 1.22 renames ``interpolation`` to
            # ``method``; kept as-is for compatibility with older NumPy.
            tau = np.percentile(w, target_percentile, interpolation='nearest')
            # Mark every entry at or below this layer's threshold.
            self.prune_masks[i, :, :, :, :] = w <= tau

        self._create_params()
Exemplo n.º 2
0
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
                 tie_projs=[False], pre_lnorm=False, tgt_len=None,
                 ext_len=None, mem_len=None, cutoffs=[], adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1):
        """Assemble the embedding, the decoder stack and the output head.

        Sample configuration noted by the original author: n_token=1000,
        d_embed=d_model=200, n_head=2, d_head=2, n_layer=4, tgt_len=36,
        mem_len=36, ext_len=0 (so max_klen=72), same_length=False,
        clamp_len=-1.

        ``attn_type`` picks the layer class: 0 = default attention,
        1 = learnable embeddings, 2 or 3 = absolute embeddings; any other
        value leaves ``self.layers`` empty.  ``sample_softmax > 0`` selects
        the sampled-softmax head, otherwise the adaptive softmax criterion
        is built.  ``adapt_inp`` is accepted but not used here.
        """
        super(MemTransformerLM, self).__init__()

        # Fall back to the model width when no embedding width is given.
        if d_embed is None:
            d_embed = d_model

        self.n_token = n_token
        self.d_embed, self.d_model = d_embed, d_model
        self.n_head, self.d_head = n_head, d_head

        # Plain word vectors, no positional information.
        self.word_emb = AdaptiveEmbedding(
            n_token, d_embed, d_model, cutoffs, div_val=div_val)
        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer
        self.tgt_len, self.mem_len, self.ext_len = tgt_len, mem_len, ext_len
        self.max_klen = tgt_len + ext_len + mem_len
        self.attn_type = attn_type

        # Pick the layer class and its keyword arguments once, then stack it.
        layer_cls = None
        layer_kwargs = {}
        if attn_type == 0:  # the default attention
            layer_cls = RelPartialLearnableDecoderLayer
            layer_kwargs = dict(tgt_len=tgt_len, ext_len=ext_len,
                                mem_len=mem_len, dropatt=dropatt,
                                pre_lnorm=pre_lnorm)
        elif attn_type == 1:  # learnable embeddings
            layer_cls = RelLearnableDecoderLayer
            layer_kwargs = dict(tgt_len=tgt_len, ext_len=ext_len,
                                mem_len=mem_len, dropatt=dropatt,
                                pre_lnorm=pre_lnorm)
        elif attn_type in [2, 3]:  # absolute embeddings
            layer_cls = DecoderLayer
            layer_kwargs = dict(dropatt=dropatt, pre_lnorm=pre_lnorm)

        self.layers = nn.ModuleList()
        if layer_cls is not None:
            for _ in range(n_layer):
                self.layers.append(
                    layer_cls(n_head, d_model, d_head, d_inner, dropout,
                              **layer_kwargs))

        self.sample_softmax = sample_softmax
        if sample_softmax > 0:
            # Sampled softmax head.
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)
        else:
            # Adaptive softmax (including standard softmax).
            self.crit = ProjectedAdaptiveLogSoftmax(
                n_token, d_embed, d_model, cutoffs, div_val=div_val)

            if tie_weight:
                for idx, out_layer in enumerate(self.crit.out_layers):
                    out_layer.weight = self.word_emb.emb_layers[idx].weight

            if tie_projs:
                for idx, tied in enumerate(tie_projs):
                    if not tied:
                        continue
                    if div_val == 1:
                        # A single shared projection exists only when the
                        # embedding and model widths differ.
                        if d_model != d_embed:
                            self.crit.out_projs[idx] = self.word_emb.emb_projs[0]
                    else:
                        self.crit.out_projs[idx] = self.word_emb.emb_projs[idx]

        self.same_length = same_length
        self.clamp_len = clamp_len

        # Initialise the vector attributes.
        self._create_params()
Exemplo n.º 3
0
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
                 tie_projs=[False], pre_lnorm=False, tgt_len=None,
                 ext_len=None, mem_len=None, cutoffs=[], adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1):
        """Create the word embedding, decoder layers and output criterion.

        ``n_head`` is the number of heads and ``d_head`` the head dimension.
        ``d_embed`` defaults to ``d_model`` when ``None``.  ``tgt_len``,
        ``ext_len`` and ``mem_len`` are added together for ``self.max_klen``
        and therefore must be numbers.  ``adapt_inp`` is accepted but not
        used by this constructor.
        """
        super(MemTransformerLM, self).__init__()

        if d_embed is None:
            d_embed = d_model

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(
            n_token, d_embed, d_model, cutoffs, div_val=div_val)
        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len
        self.attn_type = attn_type

        # Build the stack as a list first, then wrap it in a ModuleList.
        # (Original author's note: these look like all decoder layers.)
        if attn_type == 0:
            # The default attention; per the original author this variant
            # has relative positional encoding.
            stack = [
                RelPartialLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
                for _ in range(n_layer)
            ]
        elif attn_type == 1:
            # Learnable embeddings (original author's note: Shaw et al.).
            stack = [
                RelLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
                for _ in range(n_layer)
            ]
        elif attn_type in [2, 3]:
            # Absolute embeddings.  Original author's note: 2 corresponds to
            # the original transformer, 3 is for Al-Rfou.  Example sizes from
            # the same note: d_inner (FFN dimension)=2100, n_head=10,
            # d_model=410, d_head=41, drop=0.1, dropatt=0, pre_lnorm=False.
            stack = [
                DecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
                for _ in range(n_layer)
            ]
        else:
            # Unrecognised attention type: the model has no decoder layers.
            stack = []
        self.layers = nn.ModuleList(stack)

        self.sample_softmax = sample_softmax

        if sample_softmax > 0:
            # Sampled softmax.
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                # Tie the weights of the output and input embedding matrix.
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)
        else:
            # Adaptive softmax (including standard softmax).
            # TODO: understand what is adaptive softmax
            self.crit = ProjectedAdaptiveLogSoftmax(
                n_token, d_embed, d_model, cutoffs, div_val=div_val)

            if tie_weight:
                emb_layers = self.word_emb.emb_layers
                for pos, out_layer in enumerate(self.crit.out_layers):
                    out_layer.weight = emb_layers[pos].weight

            if tie_projs:
                for pos, flag in enumerate(tie_projs):
                    if not flag:
                        continue
                    if div_val == 1:
                        if d_model != d_embed:
                            self.crit.out_projs[pos] = self.word_emb.emb_projs[0]
                    else:
                        self.crit.out_projs[pos] = self.word_emb.emb_projs[pos]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()
Exemplo n.º 4
0
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
                 tie_projs=[False], pre_lnorm=False, tgt_len=None,
                 ext_len=None, mem_len=None, cutoffs=[], adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1, rnnenc=False, rnndim=0, layer_list='',
                 future_len=0, attn_layerlist='', merge_type='direct'):
        """Build the model, optionally with per-layer LSTMs and attention penalty.

        Beyond the usual size/dropout arguments:

        * ``layer_list`` - whitespace-separated layer indices; parsed to ints
          and stored as both ``self.layer_list`` and ``self.rnnlayer_list``.
        * ``rnnenc`` / ``rnndim`` - when ``rnnenc`` is truthy and ``rnndim``
          is non-zero, one single-layer ``nn.LSTM(d_model, rnndim)`` is
          created per entry of ``layer_list``.
        * ``merge_type`` - ``'gating'`` or ``'project'`` additionally creates
          ``self.rnnproj`` (only inside the RNN branch above).
        * ``attn_layerlist`` - whitespace-separated layer indices whose
          decoder layer receives ``penalty=True`` (attn_type 0, 2 and 3).
        * ``future_len`` - added to ``tgt_len + ext_len + mem_len`` for
          ``self.max_klen``.

        ``adapt_inp`` is accepted but not used by this constructor.
        """
        super(MemTransformerLM, self).__init__()

        if d_embed is None:
            d_embed = d_model

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(
            n_token, d_embed, d_model, cutoffs, div_val=div_val)
        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.future_len = future_len
        self.max_klen = tgt_len + ext_len + mem_len + future_len
        self.attn_type = attn_type

        # RNN hidden state carry on.  Both attributes deliberately refer to
        # the very same list object.
        rnn_layer_ids = [int(tok) for tok in layer_list.split()]
        self.layer_list = rnn_layer_ids
        self.rnnlayer_list = rnn_layer_ids
        print("rnn layer list: {}".format(self.rnnlayer_list))

        if rnnenc and rnndim != 0:
            if merge_type in ('gating', 'project'):
                self.rnnproj = nn.Linear(rnndim + d_model, d_model)
            # One LSTM per listed layer.
            self.rnn_list = nn.ModuleList(
                [nn.LSTM(d_model, rnndim, 1) for _ in self.rnnlayer_list])

        # Attention penalisation.
        self.attn_pen_layers = [int(tok) for tok in attn_layerlist.split()]
        self.merge_type = merge_type

        self.layers = nn.ModuleList()
        if attn_type == 0:
            # The default attention.
            for idx in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm,
                        penalty=idx in self.attn_pen_layers))
        elif attn_type == 1:
            # Learnable embeddings; this variant is built without a penalty
            # argument.
            for _ in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm))
        elif attn_type in [2, 3]:
            # Absolute embeddings.
            for idx in range(n_layer):
                self.layers.append(
                    DecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        dropatt=dropatt, pre_lnorm=pre_lnorm,
                        penalty=idx in self.attn_pen_layers))

        self.sample_softmax = sample_softmax
        if sample_softmax > 0:
            # Sampled softmax.
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)
        else:
            # Adaptive softmax (including standard softmax).
            self.crit = ProjectedAdaptiveLogSoftmax(
                n_token, d_embed, d_model, cutoffs, div_val=div_val)

            if tie_weight:
                for idx, out_layer in enumerate(self.crit.out_layers):
                    out_layer.weight = self.word_emb.emb_layers[idx].weight

            if tie_projs:
                for idx, tied in enumerate(tie_projs):
                    if not tied:
                        continue
                    if div_val == 1:
                        if d_model != d_embed:
                            self.crit.out_projs[idx] = self.word_emb.emb_projs[0]
                    else:
                        self.crit.out_projs[idx] = self.word_emb.emb_projs[idx]

        self.rnnenc = rnnenc
        self.rnndim = rnndim
        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()
Exemplo n.º 5
0
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
                 tie_projs=[False], pre_lnorm=False, tgt_len=None,
                 ext_len=None, mem_len=None, cutoffs=[], adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1):
        '''Construct the embedding, decoder layers and output criterion.

        The values after each parameter are the example configuration
        recorded by the original author, not defaults.

        :param n_token: size of the vocabulary
        :param n_layer: 16
        :param n_head: 10
        :param d_model: 410
        :param d_head: 41
        :param d_inner: 2100
        :param dropout: 0.1
        :param dropatt: 0.0
        :param tie_weight: True
        :param d_embed: 410
        :param div_val: 1
        :param tie_projs: [F,T,T,T]
        :param pre_lnorm: False
        :param tgt_len: 150
        :param ext_len: 0
        :param mem_len: 150
        :param cutoffs: [20000, 40000, 200000]
        :param adapt_inp: accepted but not used by this constructor
        :param same_length: False during training
        :param attn_type: 0
        :param clamp_len: -1
        :param sample_softmax: -1
        '''
        super(MemTransformerLM, self).__init__()

        # Resolve the embedding width before anything consumes it.
        if d_embed is None:
            d_embed = d_model

        self.n_token = n_token
        self.d_embed, self.d_model = d_embed, d_model
        self.n_head, self.d_head = n_head, d_head

        self.word_emb = AdaptiveEmbedding(
            n_token, d_embed, d_model, cutoffs, div_val=div_val)
        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer
        self.tgt_len, self.mem_len, self.ext_len = tgt_len, mem_len, ext_len
        self.max_klen = tgt_len + ext_len + mem_len
        self.attn_type = attn_type

        # Choose which decoder layer to instantiate; ``None`` means the
        # attention type is unrecognised and the stack stays empty.
        make_layer = None
        if attn_type == 0:  # the default attention
            def make_layer():
                return RelPartialLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
        elif attn_type == 1:  # learnable embeddings
            def make_layer():
                return RelLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
        elif attn_type in [2, 3]:  # absolute embeddings
            def make_layer():
                return DecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)

        self.layers = nn.ModuleList()
        if make_layer is not None:
            for _ in range(n_layer):
                self.layers.append(make_layer())

        self.sample_softmax = sample_softmax
        if sample_softmax > 0:
            # use sampled softmax
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)
        else:
            # use adaptive softmax (including standard softmax)
            self.crit = ProjectedAdaptiveLogSoftmax(
                n_token, d_embed, d_model, cutoffs, div_val=div_val)

            if tie_weight:
                for k, out_layer in enumerate(self.crit.out_layers):
                    out_layer.weight = self.word_emb.emb_layers[k].weight

            if tie_projs:
                for k, share in enumerate(tie_projs):
                    if not share:
                        continue
                    if div_val == 1:
                        if d_model != d_embed:
                            self.crit.out_projs[k] = self.word_emb.emb_projs[0]
                    else:
                        self.crit.out_projs[k] = self.word_emb.emb_projs[k]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
                 tie_projs=[False], pre_lnorm=False, tgt_len=None,
                 ext_len=None, mem_len=None, cutoffs=[], adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1):
        """Set up the embedding, decoder stack and output head.

        ``d_embed`` falls back to ``d_model`` when ``None``.  ``tgt_len``,
        ``ext_len`` and ``mem_len`` are summed into ``self.max_klen`` and so
        must be numbers.  ``self.model`` is initialised to ``"training"``.
        ``attn_type`` selects the decoder layer class (0 default attention,
        1 learnable embeddings, 2/3 absolute embeddings; anything else gives
        an empty stack).  With the signature defaults (sample_softmax=-1,
        tie_weight=True, tie_projs=[False]) the adaptive-softmax branch is
        taken.  ``adapt_inp`` is accepted but not used here.
        """
        super(MemTransformerLM, self).__init__()

        if d_embed is None:
            d_embed = d_model

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(
            n_token, d_embed, d_model, cutoffs, div_val=div_val)
        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len
        self.attn_type = attn_type

        self.model = "training"

        # Collect the decoder layers in a plain list, then register them.
        if attn_type == 0:
            # the default attention
            decoder_stack = [
                RelPartialLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
                for _ in range(n_layer)
            ]
        elif attn_type == 1:
            # learnable embeddings
            decoder_stack = [
                RelLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
                for _ in range(n_layer)
            ]
        elif attn_type in [2, 3]:
            # absolute embeddings
            decoder_stack = [
                DecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
                for _ in range(n_layer)
            ]
        else:
            decoder_stack = []
        self.layers = nn.ModuleList(decoder_stack)

        self.sample_softmax = sample_softmax

        if sample_softmax > 0:
            # use sampled softmax
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)
        else:
            # use adaptive softmax (including standard softmax)
            self.crit = ProjectedAdaptiveLogSoftmax(
                n_token, d_embed, d_model, cutoffs, div_val=div_val)

            if tie_weight:
                source_layers = self.word_emb.emb_layers
                for j, target_layer in enumerate(self.crit.out_layers):
                    target_layer.weight = source_layers[j].weight

            if tie_projs:
                for j, wants_tie in enumerate(tie_projs):
                    if not wants_tie:
                        continue
                    if div_val == 1:
                        if d_model != d_embed:
                            self.crit.out_projs[j] = self.word_emb.emb_projs[0]
                    else:
                        self.crit.out_projs[j] = self.word_emb.emb_projs[j]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()
    def __init__(self,
                 n_token,
                 n_layer,
                 n_head,
                 d_model,
                 d_head,
                 d_inner,
                 dropout,
                 dropatt,
                 tie_weight=True,
                 d_embed=None,
                 div_val=1,
                 tie_projs=[False],
                 pre_lnorm=False,
                 tgt_len=None,
                 ext_len=None,
                 mem_len=None,
                 cutoffs=[],
                 adapt_inp=False,
                 same_length=False,
                 attn_type=0,
                 act_type=0,
                 use_bn=False,
                 clamp_len=-1,
                 sample_softmax=-1,
                 glu_type=0,
                 glu_layers=[]):
        """Build the embedding, the decoder stack and the output modules.

        Args:
            n_token: Vocabulary size, forwarded to the embedding and to the
                output layer / criterion.
            n_layer: Number of decoder layers appended to ``self.layers``.
            n_head, d_model, d_head, d_inner: Size arguments forwarded to
                every decoder layer.
            dropout: Probability for ``self.drop``; also forwarded to the
                decoder layers.
            dropatt: Forwarded to the decoder layers as ``dropatt``.
            tie_weight: When true, output weights are bound to the
                embedding weights.
            d_embed: Embedding size; falls back to ``d_model`` when None.
            div_val: Forwarded to the adaptive embedding and criterion.
            tie_projs: Per-index flags selecting which output projections
                are bound to the embedding projections.
            pre_lnorm: Forwarded to the decoder layers.
            tgt_len, ext_len, mem_len: Stored on the instance and summed
                into ``self.max_klen``; despite the ``None`` defaults they
                must be numeric or that sum raises ``TypeError``.
            cutoffs: Forwarded to the adaptive embedding and criterion.
            adapt_inp: Accepted but not used by this constructor.
            same_length, clamp_len: Stored on the instance.
            attn_type: 0 -> ``RelPartialLearnableDecoderLayer``,
                1 -> ``RelLearnableDecoderLayer``, 2 or 3 -> ``DecoderLayer``.
                Any other value leaves ``self.layers`` empty.
            act_type: Stored on the instance and forwarded to every layer.
            use_bn: Forwarded to the layers only when ``attn_type == 0``.
            sample_softmax: ``> 0`` builds a linear output layer plus a
                ``LogUniformSampler``; otherwise a
                ``ProjectedAdaptiveLogSoftmax`` criterion is built.
            glu_type: GLU variant handed to the layers whose index is listed
                in ``glu_layers``.
            glu_layers: Indices of the layers that receive ``glu_type``;
                every other layer is built with ``glu_type=0`` (no gate).

        Note:
            The list defaults are shared between calls; this method only
            reads them and passes them on.
        """
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token,
                                          d_embed,
                                          d_model,
                                          cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.act_type = act_type
        self.attn_type = attn_type
        print('glu type: ', glu_type)
        self.layers = nn.ModuleList()

        # NOTE: the per-layer GLU type is kept in a separate local. Assigning
        # 0 back to ``glu_type`` inside the loop would disable the gate for
        # every later layer, including the ones listed in ``glu_layers``.
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                if i in glu_layers:
                    print('use GLU in layer-{}'.format(i))
                    layer_glu_type = glu_type
                else:
                    layer_glu_type = 0  # no gate
                self.layers.append(
                    RelPartialLearnableDecoderLayer(n_head,
                                                    d_model,
                                                    d_head,
                                                    d_inner,
                                                    dropout,
                                                    tgt_len=tgt_len,
                                                    ext_len=ext_len,
                                                    mem_len=mem_len,
                                                    dropatt=dropatt,
                                                    pre_lnorm=pre_lnorm,
                                                    glu_type=layer_glu_type,
                                                    act_type=act_type,
                                                    use_bn=use_bn))
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                if i in glu_layers:
                    print('use GLU in layer-{}'.format(i))
                    layer_glu_type = glu_type
                else:
                    layer_glu_type = 0  # no gate
                self.layers.append(
                    RelLearnableDecoderLayer(n_head,
                                             d_model,
                                             d_head,
                                             d_inner,
                                             dropout,
                                             tgt_len=tgt_len,
                                             ext_len=ext_len,
                                             mem_len=mem_len,
                                             dropatt=dropatt,
                                             pre_lnorm=pre_lnorm,
                                             glu_type=layer_glu_type,
                                             act_type=act_type))
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                # layers not listed in glu_layers get GLU type 0
                if i in glu_layers:
                    print('use GLU in layer-{}'.format(i))
                    layer_glu_type = glu_type
                else:
                    layer_glu_type = 0  # no gate
                self.layers.append(
                    DecoderLayer(n_head,
                                 d_model,
                                 d_head,
                                 d_inner,
                                 dropout,
                                 layer_glu_type,
                                 act_type,
                                 dropatt=dropatt,
                                 pre_lnorm=pre_lnorm))

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token,
                                                    d_embed,
                                                    d_model,
                                                    cutoffs,
                                                    div_val=div_val)

            if tie_weight:
                # bind each output layer's weight to the embedding layer
                # with the same index
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[
                        i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        # with div_val == 1 every tied output projection
                        # is bound to embedding projection 0
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()