class MemTransformerLM(nn.Module):
    def __init__(self,
                 n_token,
                 n_layer,
                 n_head,
                 d_model,
                 d_head,
                 d_inner,
                 dropout,
                 dropatt,
                 tie_weight=True,
                 d_embed=None,
                 div_val=1,
                 tie_projs=[False],
                 pre_lnorm=False,
                 tgt_len=None,
                 ext_len=None,
                 mem_len=None,
                 cutoffs=[],
                 adapt_inp=False,
                 same_length=False,
                 attn_type=0,
                 clamp_len=-1,
                 sample_softmax=-1):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token,
                                          d_embed,
                                          d_model,
                                          cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.attn_type = attn_type

        self.layers = nn.ModuleList()
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(n_head,
                                                    d_model,
                                                    d_head,
                                                    d_inner,
                                                    dropout,
                                                    tgt_len=tgt_len,
                                                    ext_len=ext_len,
                                                    mem_len=mem_len,
                                                    dropatt=dropatt,
                                                    pre_lnorm=pre_lnorm))
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(n_head,
                                             d_model,
                                             d_head,
                                             d_inner,
                                             dropout,
                                             tgt_len=tgt_len,
                                             ext_len=ext_len,
                                             mem_len=mem_len,
                                             dropatt=dropatt,
                                             pre_lnorm=pre_lnorm))
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(n_head,
                                 d_model,
                                 d_head,
                                 d_inner,
                                 dropout,
                                 dropatt=dropatt,
                                 pre_lnorm=pre_lnorm))

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token,
                                                    d_embed,
                                                    d_model,
                                                    cutoffs,
                                                    div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[
                        i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()

    def backward_compatible(self):
        self.sample_softmax = -1

    def _create_params(self):
        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head,
                                                      self.d_head))
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head,
                                                      self.d_head))
        elif self.attn_type == 1:  # learnable
            self.r_emb = nn.Parameter(
                torch.Tensor(self.n_layer, self.max_klen, self.n_head,
                             self.d_head))
            self.r_w_bias = nn.Parameter(
                torch.Tensor(self.n_layer, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(
                torch.Tensor(self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2:  # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3:  # absolute deeper SA
            self.r_emb = nn.Parameter(
                torch.Tensor(self.n_layer, self.max_klen, self.n_head,
                             self.d_head))

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer + 1):
                empty = torch.empty(0, dtype=param.dtype, device=param.device)
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        # does not deal with None
        if mems is None: return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):

                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None):
        qlen, bsz = dec_inp.size()

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (
                torch.triu(all_ones, 1 + mlen) +
                torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None]  # -1
        else:
            dec_attn_mask = torch.triu(word_emb.new_ones(qlen, klen),
                                       diagonal=1 + mlen).byte()[:, :, None]

        hids = []
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen - 1,
                                   -1,
                                   -1.0,
                                   device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out,
                                 pos_emb,
                                 self.r_w_bias,
                                 self.r_r_bias,
                                 dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1:  # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len:]
                    r_bias = self.r_bias[i][-self.clamp_len:]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out,
                                 r_emb,
                                 self.r_w_bias[i],
                                 r_bias,
                                 dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2:  # absolute
            pos_seq = torch.arange(klen - 1,
                                   -1,
                                   -1.0,
                                   device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                core_out = layer(core_out,
                                 dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(
                            mlen - cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                core_out = layer(core_out,
                                 dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)

        # arguments in the order declared by _update_mems(hids, mems, qlen, mlen)
        new_mems = self._update_mems(hids, mems, qlen, mlen)

        return core_out, new_mems

    def forward(self, data, target, *mems):
        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
        # So, have to initialize size(0) mems inside the model forward.
        # Moreover, have to return new_mems to allow nn.DataParallel to piece
        # them together.
        if not mems: mems = self.init_mems()

        tgt_len = target.size(0)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]
        if self.sample_softmax > 0 and self.training:
            assert self.tie_weight
            logit = sample_logits(self.word_emb, self.out_layer.bias, target,
                                  pred_hid, self.sampler)
            loss = -F.log_softmax(logit, -1)[:, :, 0]
        else:
            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)),
                             target.view(-1))
            loss = loss.view(tgt_len, -1)

        if new_mems is None:
            return [loss]
        else:
            return [loss] + new_mems

    def forward_generate(self, data, *mems):
        if not mems: mems = self.init_mems()

        tgt_len = data.size(0)
        batch_size = data.size(1)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]

        assert self.crit.n_clusters == 0

        logits = self.crit._compute_logit(pred_hid.view(-1, pred_hid.size(-1)),
                                          self.crit.out_layers[0].weight,
                                          self.crit.out_layers[0].bias,
                                          self.crit.out_projs[0])
        logits = logits.view(tgt_len, batch_size, -1)

        if new_mems is None:
            return [logits]
        else:
            return [logits] + new_mems
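# A minimal usage sketch for the class above, assuming the rest of the
# Transformer-XL module (RelPartialLearnableDecoderLayer, AdaptiveEmbedding,
# ProjectedAdaptiveLogSoftmax, PositionalEmbedding, LogUniformSampler) is
# defined as in the surrounding file. All hyperparameters are illustrative.
import torch

model = MemTransformerLM(n_token=1000, n_layer=4, n_head=2, d_model=200,
                         d_head=100, d_inner=400, dropout=0.1, dropatt=0.0,
                         tgt_len=36, ext_len=0, mem_len=36, cutoffs=[])

# r_w_bias / r_r_bias are allocated with torch.Tensor (uninitialized memory);
# the original training script initializes all weights externally, so mimic
# that here for the relative-attention biases.
with torch.no_grad():
    model.r_w_bias.normal_(0.0, 0.02)
    model.r_r_bias.normal_(0.0, 0.02)

mems = tuple()  # empty on the first segment; forward() then calls init_mems()
for _ in range(3):  # consecutive segments of one long token stream
    data = torch.randint(0, 1000, (36, 8))    # (tgt_len, batch) input tokens
    target = torch.randint(0, 1000, (36, 8))  # next-token targets, same shape
    ret = model(data, target, *mems)
    loss, mems = ret[0], ret[1:]              # per-token NLL + updated memories
    loss.mean().backward()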
Example #2
    def __init__(self,
                 n_token,
                 n_layer,
                 n_head,
                 d_model,
                 d_head,
                 d_inner,
                 dropout,
                 dropatt,
                 tie_weight=True,
                 d_embed=None,
                 div_val=1,
                 tie_projs=[False],
                 pre_lnorm=False,
                 tgt_len=None,
                 ext_len=None,
                 mem_len=None,
                 cutoffs=[],
                 adapt_inp=False,
                 same_length=False,
                 attn_type=0,
                 clamp_len=-1,
                 sample_softmax=-1):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token    # 1000

        d_embed = d_model if d_embed is None else d_embed    # 200
        self.d_embed = d_embed    # 200
        self.d_model = d_model    # 200
        self.n_head = n_head    # 2
        self.d_head = d_head    # 2

        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                          div_val=div_val)    # word embeddings only, no positional information

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer    # 4

        self.tgt_len = tgt_len    # 36
        self.mem_len = mem_len    # 36
        self.ext_len = ext_len    # 0
        self.max_klen = tgt_len + ext_len + mem_len    # 72

        self.attn_type = attn_type

        self.layers = nn.ModuleList()
        if attn_type == 0:    # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(n_head,
                                                    d_model,
                                                    d_head,
                                                    d_inner,
                                                    dropout,
                                                    tgt_len=tgt_len,
                                                    ext_len=ext_len,
                                                    mem_len=mem_len,
                                                    dropatt=dropatt,
                                                    pre_lnorm=pre_lnorm))
        elif attn_type == 1:    # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(n_head,
                                             d_model,
                                             d_head,
                                             d_inner,
                                             dropout,
                                             tgt_len=tgt_len,
                                             ext_len=ext_len,
                                             mem_len=mem_len,
                                             dropatt=dropatt,
                                             pre_lnorm=pre_lnorm))
        elif attn_type in [2, 3]:    # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(n_head,
                                 d_model,
                                 d_head,
                                 d_inner,
                                 dropout,
                                 dropatt=dropatt,
                                 pre_lnorm=pre_lnorm))

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token,
                                                    d_embed,
                                                    d_model,
                                                    cutoffs,
                                                    div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length    # False
        self.clamp_len = clamp_len    # -1

        self._create_params()    # create the positional / bias parameters
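# Small illustration of the weight tying performed in the constructor above:
# assigning one module's .weight to another makes both share a single
# Parameter object, so the input embedding and the output softmax weights are
# updated together. Plain nn modules are used here for clarity.
import torch.nn as nn

emb = nn.Embedding(1000, 200)
out = nn.Linear(200, 1000)
out.weight = emb.weight                                # shared, not copied
assert out.weight.data_ptr() == emb.weight.data_ptr()  # same underlying storage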
Example #3
    def __init__(self,
                 n_token,
                 n_layer,
                 n_head,
                 d_model,
                 d_head,
                 d_inner,
                 dropout,
                 dropatt,
                 tie_weight=True,
                 d_embed=None,
                 div_val=1,
                 tie_projs=[False],
                 pre_lnorm=False,
                 tgt_len=None,
                 ext_len=None,
                 mem_len=None,
                 cutoffs=[],
                 adapt_inp=False,
                 same_length=False,
                 attn_type=0,
                 clamp_len=-1,
                 sample_softmax=-1):

        super(MemTransformerLM, self).__init__()

        self.n_token = n_token
        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head  # number of heads
        self.d_head = d_head  # head dimension
        self.word_emb = AdaptiveEmbedding(n_token,
                                          d_embed,
                                          d_model,
                                          cutoffs,
                                          div_val=div_val)
        self.drop = nn.Dropout(dropout)
        self.n_layer = n_layer
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len
        self.attn_type = attn_type

        self.layers = nn.ModuleList()  # These look like all decoder layers

        if attn_type == 0:  # the default attention. this will have relative positional encoding.
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(n_head,
                                                    d_model,
                                                    d_head,
                                                    d_inner,
                                                    dropout,
                                                    tgt_len=tgt_len,
                                                    ext_len=ext_len,
                                                    mem_len=mem_len,
                                                    dropatt=dropatt,
                                                    pre_lnorm=pre_lnorm))

        elif attn_type == 1:  # learnable embeddings. Shaw et al.
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(n_head,
                                             d_model,
                                             d_head,
                                             d_inner,
                                             dropout,
                                             tgt_len=tgt_len,
                                             ext_len=ext_len,
                                             mem_len=mem_len,
                                             dropatt=dropatt,
                                             pre_lnorm=pre_lnorm))

        elif attn_type in [2, 3]:  # absolute embeddings; 2 is the original transformer, 3 is Al-Rfou et al.
            for i in range(n_layer):
                # example values: n_head=10, d_model=410, d_head=41, d_inner=2100 (FFN dim), dropout=0.1, dropatt=0, pre_lnorm=False
                self.layers.append(
                    DecoderLayer(n_head,
                                 d_model,
                                 d_head,
                                 d_inner,
                                 dropout,
                                 dropatt=dropatt,
                                 pre_lnorm=pre_lnorm))

        self.sample_softmax = sample_softmax

        if sample_softmax > 0:  # use sampled softmax
            self.out_layer = nn.Linear(d_model, n_token)

            if tie_weight:
                self.out_layer.weight = self.word_emb.weight  # Tie the weights of output and input embedding matrix.

            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        # TODO: understand what is adaptive softmax
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token,
                                                    d_embed,
                                                    d_model,
                                                    cutoffs,
                                                    div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[
                        i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]

                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()
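# Regarding the TODO above: rough intuition for the adaptive softmax (a sketch,
# not the actual ProjectedAdaptiveLogSoftmax implementation). The vocabulary is
# split at `cutoffs` into a frequent head cluster plus progressively rarer tail
# clusters, and with div_val > 1 each tail cluster gets a smaller projection
# dimension, which is where the speed/memory savings come from. Numbers below
# are illustrative only.
cutoffs, n_token, d_embed, div_val = [20000, 40000, 200000], 260000, 410, 4
cutoff_ends = [0] + cutoffs + [n_token]
for i in range(len(cutoff_ends) - 1):
    lo, hi = cutoff_ends[i], cutoff_ends[i + 1]
    d_proj_i = d_embed // (div_val ** i)
    print('cluster {}: tokens [{}, {}), projection dim {}'.format(i, lo, hi, d_proj_i))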
Example #4
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None, 
                 div_val=1, tie_projs=[False], pre_lnorm=False,
                 tgt_len=None, ext_len=None, mem_len=None, 
                 cutoffs=[], adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1, 
                 sample_softmax=-1, th=0, fn="attn_scores.npy"):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, 
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.attn_type = attn_type

        self.th = th

        self.layers = nn.ModuleList()

        if attn_type == 0: # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )
        elif attn_type == 1: # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )
        elif attn_type in [2, 3]: # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model, 
                                                    cutoffs, div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        # load average attention patterns
        # (n_layer, qlen, klen, 1, n_head)
        weights = np.load(fn)
        shape = weights.shape
        n_layer, n_head = shape[0], shape[-1]
        # init prune mask attribute
        self.prune_masks = np.zeros(weights.shape)
        target_percentage = self.th
        target_percentile = target_percentage * 100
        for i in range(0, n_layer):
            w = weights[i, :, :, :, :]
            tau = np.percentile(w, target_percentile, interpolation='nearest')
            self.prune_masks[i, :, :, :, :] = w <= tau

        self._create_params()
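# Toy illustration of the percentile-based prune mask built above, with
# made-up numbers: entries at or below the percentile threshold `tau` are
# flagged for pruning. (Newer NumPy spells the keyword `method='nearest'`.)
import numpy as np

w = np.array([[0.10, 0.50], [0.05, 0.90]])
tau = np.percentile(w, 25, interpolation='nearest')
prune_mask = w <= tau
print(tau, prune_mask)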
Example #5
    def __init__(self,
                 n_token,
                 n_layer,
                 n_head,
                 d_model,
                 d_head,
                 d_inner,
                 dropout,
                 dropatt,
                 tie_weight=True,
                 d_embed=None,
                 div_val=1,
                 tie_projs=[False],
                 pre_lnorm=False,
                 tgt_len=None,
                 ext_len=None,
                 mem_len=None,
                 cutoffs=[],
                 adapt_inp=False,
                 same_length=False,
                 attn_type=0,
                 clamp_len=-1,
                 sample_softmax=-1,
                 rnnenc=False,
                 rnndim=0,
                 layer_list='',
                 future_len=0,
                 attn_layerlist='',
                 merge_type='direct'):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token,
                                          d_embed,
                                          d_model,
                                          cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.future_len = future_len
        self.max_klen = tgt_len + ext_len + mem_len + future_len

        self.attn_type = attn_type

        # RNN hidden state carry on
        self.layer_list = [int(i) for i in layer_list.split()]
        self.rnnlayer_list = self.layer_list
        print("rnn layer list: {}".format(self.rnnlayer_list))
        if rnnenc and rnndim != 0:
            if merge_type in ['gating', 'project']:
                self.rnnproj = nn.Linear(rnndim + d_model, d_model)
            self.rnn_list = nn.ModuleList([
                nn.LSTM(d_model, rnndim, 1)
                for i in range(len(self.rnnlayer_list))
            ])
        # attn penalisation
        self.attn_pen_layers = [int(i) for i in attn_layerlist.split()]
        self.merge_type = merge_type

        self.layers = nn.ModuleList()
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                # dropatt = dropatt * 2 if (i == 0 and rnnenc) else dropatt
                use_penalty = i in self.attn_pen_layers
                self.layers.append(
                    RelPartialLearnableDecoderLayer(n_head,
                                                    d_model,
                                                    d_head,
                                                    d_inner,
                                                    dropout,
                                                    tgt_len=tgt_len,
                                                    ext_len=ext_len,
                                                    mem_len=mem_len,
                                                    dropatt=dropatt,
                                                    pre_lnorm=pre_lnorm,
                                                    penalty=use_penalty))
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(n_head,
                                             d_model,
                                             d_head,
                                             d_inner,
                                             dropout,
                                             tgt_len=tgt_len,
                                             ext_len=ext_len,
                                             mem_len=mem_len,
                                             dropatt=dropatt,
                                             pre_lnorm=pre_lnorm))
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                use_penalty = i in self.attn_pen_layers
                # dropatt = dropatt * 0 if (i == 0 and rnnenc) else dropatt
                self.layers.append(
                    DecoderLayer(n_head,
                                 d_model,
                                 d_head,
                                 d_inner,
                                 dropout,
                                 dropatt=dropatt,
                                 pre_lnorm=pre_lnorm,
                                 penalty=use_penalty))

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token,
                                                    d_embed,
                                                    d_model,
                                                    cutoffs,
                                                    div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[
                        i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.rnnenc = rnnenc
        self.rnndim = rnndim
        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()
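# A sketch of how the 'project' merge declared above might combine the LSTM
# output with the transformer hidden state; the forward pass is not part of
# this excerpt, so treat the wiring below as an assumption.
import torch
import torch.nn as nn

d_model, rnndim, qlen, bsz = 200, 128, 36, 8
lstm = nn.LSTM(d_model, rnndim, 1)              # mirrors an entry of self.rnn_list
rnnproj = nn.Linear(rnndim + d_model, d_model)  # mirrors self.rnnproj

core_out = torch.randn(qlen, bsz, d_model)      # transformer hidden states
rnn_out, _ = lstm(core_out)                     # (qlen, bsz, rnndim)
core_out = rnnproj(torch.cat([core_out, rnn_out], dim=-1))  # back to (qlen, bsz, d_model)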
Example #6
    def __init__(self,
                 n_token,
                 n_layer,
                 n_head,
                 d_model,
                 d_head,
                 d_inner,
                 dropout,
                 dropatt,
                 tie_weight=True,
                 d_embed=None,
                 div_val=1,
                 tie_projs=[False],
                 pre_lnorm=False,
                 tgt_len=None,
                 ext_len=None,
                 mem_len=None,
                 cutoffs=[],
                 adapt_inp=False,
                 same_length=False,
                 attn_type=0,
                 clamp_len=-1,
                 sample_softmax=-1):
        '''

        :param n_token: vocabulary size
        :param n_layer: 16
        :param n_head: 10
        :param d_model: 410
        :param d_head:  41
        :param d_inner: 2100
        :param dropout: 0.1
        :param dropatt: 0.0
        :param tie_weight: True
        :param d_embed: 410
        :param div_val: 1
        :param tie_projs: [F,T,T,T]
        :param pre_lnorm: False
        :param tgt_len: 150
        :param ext_len: 0
        :param mem_len: 150
        :param cutoffs: [20000, 40000, 200000]
        :param adapt_inp:
        :param same_length: False during training
        :param attn_type: 0
        :param clamp_len: -1
        :param sample_softmax: -1
        '''
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token,
                                          d_embed,
                                          d_model,
                                          cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)
        self.n_layer = n_layer
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len
        self.attn_type = attn_type
        self.layers = nn.ModuleList()
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(n_head,
                                                    d_model,
                                                    d_head,
                                                    d_inner,
                                                    dropout,
                                                    tgt_len=tgt_len,
                                                    ext_len=ext_len,
                                                    mem_len=mem_len,
                                                    dropatt=dropatt,
                                                    pre_lnorm=pre_lnorm))
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(n_head,
                                             d_model,
                                             d_head,
                                             d_inner,
                                             dropout,
                                             tgt_len=tgt_len,
                                             ext_len=ext_len,
                                             mem_len=mem_len,
                                             dropatt=dropatt,
                                             pre_lnorm=pre_lnorm))
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(n_head,
                                 d_model,
                                 d_head,
                                 d_inner,
                                 dropout,
                                 dropatt=dropatt,
                                 pre_lnorm=pre_lnorm))

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token,
                                                    d_embed,
                                                    d_model,
                                                    cutoffs,
                                                    div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[
                        i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()
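Example #7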
    def __init__(self,
                 n_token,
                 n_layer,
                 n_head,
                 d_model,
                 d_head,
                 d_inner,
                 dropout,
                 dropatt,
                 tie_weight=True,
                 d_embed=None,
                 div_val=1,
                 tie_projs=[False],
                 pre_lnorm=False,
                 tgt_len=None,
                 ext_len=None,
                 mem_len=None,
                 cutoffs=[],
                 adapt_inp=False,
                 same_length=False,
                 attn_type=0,
                 clamp_len=-1,
                 sample_softmax=-1):
        super(MemTransformerLM, self).__init__()

        # print("######  MemTransformerLM parameters ######")
        # print("n_token={}".format(n_token),"###")
        # print("n_layer={}".format(n_layer),"###")
        # print("n_head",n_head)
        # print("d_model",d_model)
        # print("d_inner",d_inner)
        # print("dropout",dropout)
        # print("dropatt",dropatt)
        # print("tie_weight",tie_weight)
        # print("div_val",div_val)
        # print("tie_projs",tie_projs)
        # print("pre_lnorm",pre_lnorm)
        #print("tgt_len",tgt_len)
        #print("ext_len",ext_len)
        #print("mem_len",mem_len)
        # print("cutoffs",cutoffs)
        # print("adapt_inp",adapt_inp)
        # print("same_length",same_length)
        # print("attn_type",attn_type)
        # print("clamp_len",clamp_len)
        # print("sample_softmax",sample_softmax)
        # print("######  MemTransformerLM parameters ######")

        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token,
                                          d_embed,
                                          d_model,
                                          cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.attn_type = attn_type

        self.model = "training"

        self.layers = nn.ModuleList()
        #print("##### ModuleList #####")
        #print("n_layer=",attn_type)
        #print("attn_type=",attn_type)
        # attn_type 0
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(n_head,
                                                    d_model,
                                                    d_head,
                                                    d_inner,
                                                    dropout,
                                                    tgt_len=tgt_len,
                                                    ext_len=ext_len,
                                                    mem_len=mem_len,
                                                    dropatt=dropatt,
                                                    pre_lnorm=pre_lnorm))
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(n_head,
                                             d_model,
                                             d_head,
                                             d_inner,
                                             dropout,
                                             tgt_len=tgt_len,
                                             ext_len=ext_len,
                                             mem_len=mem_len,
                                             dropatt=dropatt,
                                             pre_lnorm=pre_lnorm))
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(n_head,
                                 d_model,
                                 d_head,
                                 d_inner,
                                 dropout,
                                 dropatt=dropatt,
                                 pre_lnorm=pre_lnorm))

        self.sample_softmax = sample_softmax

        #print("###### sample_softmax=",sample_softmax)

        # use sampled softmax
        # sample_softmax = -1, tie_weight = True, tie_projs = [False]
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)
        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token,
                                                    d_embed,
                                                    d_model,
                                                    cutoffs,
                                                    div_val=div_val)
            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[
                        i].weight
            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()
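Example #8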
    def __init__(self,
                 n_token,
                 n_layer,
                 n_head,
                 d_model,
                 d_head,
                 d_inner,
                 dropout,
                 dropatt,
                 tie_weight=True,
                 d_embed=None,
                 div_val=1,
                 tie_projs=[False],
                 pre_lnorm=False,
                 tgt_len=None,
                 ext_len=None,
                 mem_len=None,
                 cutoffs=[],
                 adapt_inp=False,
                 same_length=False,
                 attn_type=0,
                 act_type=0,
                 use_bn=False,
                 clamp_len=-1,
                 sample_softmax=-1,
                 glu_type=0,
                 glu_layers=[]):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token,
                                          d_embed,
                                          d_model,
                                          cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.act_type = act_type
        self.attn_type = attn_type
        print('glu type: ', glu_type)
        self.layers = nn.ModuleList()
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                if i in glu_layers:
                    print('use GLU in layer-{}'.format(i))
                    layer_glu_type = glu_type
                else:
                    layer_glu_type = 0  # no gate for layers not listed in glu_layers
                self.layers.append(
                    RelPartialLearnableDecoderLayer(n_head,
                                                    d_model,
                                                    d_head,
                                                    d_inner,
                                                    dropout,
                                                    tgt_len=tgt_len,
                                                    ext_len=ext_len,
                                                    mem_len=mem_len,
                                                    dropatt=dropatt,
                                                    pre_lnorm=pre_lnorm,
                                                    glu_type=layer_glu_type,
                                                    act_type=act_type,
                                                    use_bn=use_bn))
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                if i in glu_layers:
                    print('use GLU in layer-{}'.format(i))
                    layer_glu_type = glu_type
                else:
                    layer_glu_type = 0  # no gate for layers not listed in glu_layers
                self.layers.append(
                    RelLearnableDecoderLayer(n_head,
                                             d_model,
                                             d_head,
                                             d_inner,
                                             dropout,
                                             tgt_len=tgt_len,
                                             ext_len=ext_len,
                                             mem_len=mem_len,
                                             dropatt=dropatt,
                                             pre_lnorm=pre_lnorm,
                                             glu_type=layer_glu_type,
                                             act_type=act_type))
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                # layers not specified in glu_layers get GLU type 0 (no gate)
                if i in glu_layers:
                    print('use GLU in layer-{}'.format(i))
                    layer_glu_type = glu_type
                else:
                    layer_glu_type = 0
                self.layers.append(
                    DecoderLayer(n_head,
                                 d_model,
                                 d_head,
                                 d_inner,
                                 dropout,
                                 layer_glu_type,
                                 act_type,
                                 dropatt=dropatt,
                                 pre_lnorm=pre_lnorm))

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token,
                                                    d_embed,
                                                    d_model,
                                                    cutoffs,
                                                    div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[
                        i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()
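# For reference, a gated linear unit in its standard form (a sketch; the
# glu_type handling inside DecoderLayer is defined elsewhere in the module):
# half of a doubled projection gates the other half.
import torch
import torch.nn as nn

x = torch.randn(36, 8, 200)
proj = nn.Linear(200, 2 * 200)
a, b = proj(x).chunk(2, dim=-1)
gated = a * torch.sigmoid(b)  # same as torch.nn.functional.glu(proj(x), dim=-1)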
Example #9
class MemTransformerLM(nn.Module):
    def __init__(
        self,
        n_token,
        n_layer,
        n_head,
        d_model,
        d_head,
        d_inner,
        dropout,
        dropatt,
        tie_weight=True,
        d_embed=None,
        div_val=1,
        tie_projs=[False],
        pre_lnorm=False,
        tgt_len=None,
        ext_len=None,
        mem_len=None,
        cutoffs=[],
        adapt_inp=False,
        same_length=False,
        attn_type=0,
        clamp_len=-1,
        sample_softmax=-1,
    ):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token
        self.training_steps = 0
        self.compute = 0

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head
        self.d_inner = d_inner
        self.n_layer = n_layer
        self.attn_type = attn_type
        self.tie_weight = tie_weight
        self.tie_projs = tie_projs

        self.drop = nn.Dropout(dropout)
        self.dropout_p = dropout
        self.dropatt_p = dropatt
        self.pre_lnorm = pre_lnorm

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.word_emb = AdaptiveEmbedding(n_token,
                                          d_embed,
                                          d_model,
                                          cutoffs,
                                          div_val=div_val)

        self.layers = nn.ModuleList()
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(n_head,
                                                    d_model,
                                                    d_head,
                                                    d_inner,
                                                    dropout,
                                                    tgt_len=tgt_len,
                                                    ext_len=ext_len,
                                                    mem_len=mem_len,
                                                    dropatt=dropatt,
                                                    pre_lnorm=pre_lnorm))
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(n_head,
                                             d_model,
                                             d_head,
                                             d_inner,
                                             dropout,
                                             tgt_len=tgt_len,
                                             ext_len=ext_len,
                                             mem_len=mem_len,
                                             dropatt=dropatt,
                                             pre_lnorm=pre_lnorm))
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(n_head,
                                 d_model,
                                 d_head,
                                 d_inner,
                                 dropout,
                                 dropatt=dropatt,
                                 pre_lnorm=pre_lnorm))

        self.sample_softmax = sample_softmax

        self._create_params()
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token,
                                                    d_embed,
                                                    d_model,
                                                    cutoffs,
                                                    div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[
                        i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

    def backward_compatible(self, tie_weight, tie_projs):
        self.sample_softmax = -1
        self.tie_weight = tie_weight
        self.tie_projs = tie_projs

    def _create_params(self):
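        # create the positional parameters required by the chosen attention type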
        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            # self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            # self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        elif self.attn_type == 1:  # learnable
            self.r_emb = nn.Parameter(
                torch.Tensor(self.n_layer, self.max_klen, self.n_head,
                             self.d_head))
            self.r_w_bias = nn.Parameter(
                torch.Tensor(self.n_layer, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(
                torch.Tensor(self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2:  # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3:  # absolute deeper SA
            self.r_emb = nn.Parameter(
                torch.Tensor(self.n_layer, self.max_klen, self.n_head,
                             self.d_head))

    def reset_length(self, tgt_len, ext_len, mem_len):
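        # switch the segment/extended-context/memory lengths, e.g. between
        # training and evaluation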
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self):
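        # return one empty memory tensor per layer plus one for the embedding
        # output, or None when memory is disabled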
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer + 1):
                empty = torch.empty(0, dtype=param.dtype, device=param.device)
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
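        # slide the memory window: append the new hidden states and keep only
        # the most recent `mem_len` steps, excluding the extended context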
        # nothing to cache when memory is disabled
        if mems is None: return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):
                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None):
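        # embed the input, run the decoder stack, and return the last hidden
        # states together with the updated memories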
        qlen, bsz = dec_inp.size()

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
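        # causal mask over the `klen` keys (memory + current segment); with
        # `same_length`, an extra lower-triangular mask caps the context so
        # every query attends to the same number of positions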
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (
                torch.triu(all_ones, 1 + mlen) +
                torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None]  # (qlen, klen, 1)
        else:
            dec_attn_mask = torch.triu(word_emb.new_ones(qlen, klen),
                                       diagonal=1 + mlen).bool()[:, :, None]

        hids = []
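        # `hids` collects the embedding output and every layer's output so
        # that _update_mems can cache them for the next segment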
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen - 1,
                                   -1,
                                   -1.0,
                                   device=word_emb.device,
                                   dtype=word_emb.dtype)
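            # relative positions from the farthest key (klen - 1) down to the
            # current position (0); clamp_len caps very long-range offsets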
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out,
                                 pos_emb,
                                 dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1:  # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len:]
                    r_bias = self.r_bias[i][-self.clamp_len:]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out,
                                 r_emb,
                                 self.r_w_bias[i],
                                 r_bias,
                                 dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2:  # absolute
            pos_seq = torch.arange(klen - 1,
                                   -1,
                                   -1.0,
                                   device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
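                # the first layer's cached states get the absolute positional
                # embeddings of the memory positions added in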
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                core_out = layer(core_out,
                                 dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(
                            mlen - cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                core_out = layer(core_out,
                                 dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)

        # _update_mems expects (hids, mems, qlen, mlen)
        new_mems = self._update_mems(hids, mems, qlen, mlen)

        return core_out, new_mems

    def forward(self, data, target, *mems):
        # nn.DataParallel cannot broadcast zero-length tensors, so the empty
        # mems have to be initialized inside the model's forward. new_mems is
        # also returned so nn.DataParallel can gather the pieces afterwards.
        if not mems: mems = self.init_mems()

        tgt_len = target.size(0)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]
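        # only the last `tgt_len` positions are scored; any leading (extended)
        # context is dropped before computing the loss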
        if self.sample_softmax > 0 and self.training:
            assert self.tie_weight
            logit = sample_logits(self.word_emb, self.out_layer.bias, target,
                                  pred_hid, self.sampler)
            loss = -F.log_softmax(logit, -1)[:, :, 0]
        else:
            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)),
                             target.reshape(-1))
            loss = loss.view(tgt_len, -1)

        if new_mems is None:
            return [loss]
        else:
            return [loss] + new_mems

    def expand_layers(self, n_add, strategy="repeat", function=None):
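        # grow the network by `n_add` decoder layers: "repeat"/"repeat_bottom"
        # copy the top/bottom layer, "reinit"/"reinit_bottom" additionally
        # re-initialize the copies with `function`, and "duplicate" interleaves
        # copies directly above the layers they duplicate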
        assert self.attn_type == 0, f"only works with default attention mode, not mode {self.attn_type}"
        assert strategy in ["repeat", "reinit", "repeat_bottom", "reinit_bottom", "duplicate"], \
            f"initialization mode {strategy} not implemented"
        duplicate = "duplicate" in strategy
        bottom = "bottom" in strategy
        new_layers = nn.ModuleList([])

        if duplicate:
            assert n_add % len(self.layers) == 0, \
                f"duplicating the network requires the number of extra layers {n_add} " \
                f"to be an integer multiple of the previous depth {len(self.layers)}"
            factor = n_add // len(self.layers)
            # we append them in the order they'll be found in the network later
            for layer in self.layers:
                for _ in range(factor):
                    new_layers.append(deepcopy(layer))
            # we interleave the new layers on top of the layer they duplicate
            # position p keeps an original layer iff p % (factor + 1) == 0;
            # every other position receives one of the copies
            positions = [
                i for i in range(len(self.layers) + n_add)
                if i % (factor + 1) != 0
            ]
            for i, new_layer in enumerate(new_layers):
                self.layers.insert(positions[i], new_layer)

        else:
            for _ in range(n_add):
                new_layer = deepcopy(self.layers[0 if bottom else -1])
                if "reinit" in strategy:
                    new_layer.apply(function)
                new_layers.append(new_layer)
            if bottom:
                # nn.ModuleList cannot be extended at the front, so the new layers
                # are inserted one by one; the index i keeps them in the same
                # order as in new_layers
                for i, layer in enumerate(new_layers):
                    self.layers.insert(i, layer)
            else:
                self.layers.extend(new_layers)

        self.n_layer += n_add
        return new_layers

    def add_heads(self,
                  ratio,
                  strategy="duplicate",
                  function=None,
                  expand_inner=True,
                  expand_embeddings=True,
                  change_freq=False):
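        # widen the model by an integer `ratio`: scale the width and head count
        # (optionally the inner FFN and embedding sizes too) and delegate the
        # per-module widening to the layers and, when the embeddings are
        # expanded, to the softmax, word-embedding, and positional-embedding modules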
        assert self.attn_type == 0, f"only works with default attention mode, not mode {self.attn_type}"
        assert strategy in ["reinit", "duplicate"], \
            f"initialization mode {strategy} not implemented"

        self.d_model *= ratio
        self.n_head *= ratio
        if expand_inner:
            self.d_inner *= ratio
        if expand_embeddings:
            self.d_embed *= ratio

        for layer in self.layers:
            layer.add_heads(ratio, strategy, function, expand_inner)

        if expand_embeddings:
            self.crit.widen(ratio, strategy, function)
            self.word_emb.widen(ratio, strategy, function, self.tie_weight,
                                self.tie_projs)
            self.pos_emb.widen(ratio, change_freq)

        return self.layers