class MemTransformerLM(nn.Module):
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
                 tie_projs=[False], pre_lnorm=False, tgt_len=None, ext_len=None,
                 mem_len=None, cutoffs=[], adapt_inp=False, same_length=False,
                 attn_type=0, clamp_len=-1, sample_softmax=-1):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.attn_type = attn_type

        self.layers = nn.ModuleList()
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                self.layers.append(RelPartialLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm))
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                self.layers.append(RelLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm))
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                self.layers.append(DecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    dropatt=dropatt, pre_lnorm=pre_lnorm))

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)
        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                    cutoffs, div_val=div_val)
            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()

    def backward_compatible(self):
        self.sample_softmax = -1

    def _create_params(self):
        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        elif self.attn_type == 1:  # learnable
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2:  # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3:  # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer + 1):
                empty = torch.empty(0, dtype=param.dtype, device=param.device)
                mems.append(empty)
            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        # does not deal with None
        if mems is None:
            return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):
                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None):
        qlen, bsz = dec_inp.size()

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1 + mlen)
                             + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None]  # -1
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1 + mlen).byte()[:, :, None]

        hids = []
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, pos_emb, self.r_w_bias, self.r_r_bias,
                                 dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1:  # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len:]
                    r_bias = self.r_bias[i][-self.clamp_len:]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, r_emb, self.r_w_bias[i], r_bias,
                                 dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2:  # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen - cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                core_out = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

        return core_out, new_mems

    def forward(self, data, target, *mems):
        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
        # So, have to initialize size(0) mems inside the model forward.
        # Moreover, have to return new_mems to allow nn.DataParallel to piece
        # them together.
        if not mems:
            mems = self.init_mems()

        tgt_len = target.size(0)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]
        if self.sample_softmax > 0 and self.training:
            assert self.tie_weight
            logit = sample_logits(self.word_emb, self.out_layer.bias, target,
                                  pred_hid, self.sampler)
            loss = -F.log_softmax(logit, -1)[:, :, 0]
        else:
            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
            loss = loss.view(tgt_len, -1)

        if new_mems is None:
            return [loss]
        else:
            return [loss] + new_mems

    def forward_generate(self, data, *mems):
        if not mems:
            mems = self.init_mems()

        tgt_len = data.size(0)
        batch_size = data.size(1)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]

        assert self.crit.n_clusters == 0

        logits = self.crit._compute_logit(pred_hid.view(-1, pred_hid.size(-1)),
                                          self.crit.out_layers[0].weight,
                                          self.crit.out_layers[0].bias,
                                          self.crit.out_projs[0])
        logits = logits.view(tgt_len, batch_size, -1)

        if new_mems is None:
            return [logits]
        else:
            return [logits] + new_mems
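
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original file).
# It assumes the rest of mem_transformer.py (AdaptiveEmbedding,
# RelPartialLearnableDecoderLayer, ProjectedAdaptiveLogSoftmax, ...) is in
# scope and that the PyTorch version the repo targets is installed (the
# byte() attention masks above predate boolean-mask enforcement). The toy
# hyper-parameters and the names `vocab_size`, `seq_len`, `batch_size` are
# assumptions for the demo. It shows that `forward` returns [loss] + new_mems
# and how the memories are threaded between consecutive segments.
if __name__ == '__main__':
    vocab_size, seq_len, batch_size = 1000, 36, 4
    model = MemTransformerLM(vocab_size, n_layer=4, n_head=2, d_model=200,
                             d_head=100, d_inner=400, dropout=0.1, dropatt=0.0,
                             tgt_len=seq_len, ext_len=0, mem_len=seq_len,
                             attn_type=0)
    # r_w_bias / r_r_bias are created uninitialised; the repo normally
    # initialises them in train.py, so do it here to keep the sketch
    # self-contained.
    for name, p in model.named_parameters():
        if 'r_w_bias' in name or 'r_r_bias' in name:
            torch.nn.init.normal_(p, 0.0, 0.02)

    mems = tuple()  # the first segment starts with no memory
    for step in range(3):  # three consecutive segments
        data = torch.randint(0, vocab_size, (seq_len, batch_size))
        target = torch.randint(0, vocab_size, (seq_len, batch_size))
        ret = model(data, target, *mems)
        loss, mems = ret[0], ret[1:]  # ret is [loss] + new_mems
        # loss: (seq_len, batch); each mem: (mem_len, batch, d_model)
        print(step, loss.shape, mems[0].shape)
# ---------------------------------------------------------------------------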
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
             dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
             tie_projs=[False], pre_lnorm=False, tgt_len=None, ext_len=None,
             mem_len=None, cutoffs=[], adapt_inp=False, same_length=False,
             attn_type=0, clamp_len=-1, sample_softmax=-1):
    super(MemTransformerLM, self).__init__()
    self.n_token = n_token  # 1000

    d_embed = d_model if d_embed is None else d_embed  # 200
    self.d_embed = d_embed  # 200
    self.d_model = d_model  # 200
    self.n_head = n_head  # 2
    self.d_head = d_head  # 2

    # plain word embeddings, no positional information
    self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                      div_val=div_val)

    self.drop = nn.Dropout(dropout)

    self.n_layer = n_layer  # 4

    self.tgt_len = tgt_len  # 36
    self.mem_len = mem_len  # 36
    self.ext_len = ext_len  # 0
    self.max_klen = tgt_len + ext_len + mem_len  # 72

    self.attn_type = attn_type

    self.layers = nn.ModuleList()
    if attn_type == 0:  # the default attention
        for i in range(n_layer):
            self.layers.append(RelPartialLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm))
    elif attn_type == 1:  # learnable embeddings
        for i in range(n_layer):
            self.layers.append(RelLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm))
    elif attn_type in [2, 3]:  # absolute embeddings
        for i in range(n_layer):
            self.layers.append(DecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                dropatt=dropatt, pre_lnorm=pre_lnorm))

    self.sample_softmax = sample_softmax
    # use sampled softmax
    if sample_softmax > 0:
        self.out_layer = nn.Linear(d_model, n_token)
        if tie_weight:
            self.out_layer.weight = self.word_emb.weight
        self.tie_weight = tie_weight
        self.sampler = LogUniformSampler(n_token, sample_softmax)
    # use adaptive softmax (including standard softmax)
    else:
        self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                cutoffs, div_val=div_val)
        if tie_weight:
            for i in range(len(self.crit.out_layers)):
                self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
        if tie_projs:
            for i, tie_proj in enumerate(tie_projs):
                if tie_proj and div_val == 1 and d_model != d_embed:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                elif tie_proj and div_val != 1:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[i]

    self.same_length = same_length  # False
    self.clamp_len = clamp_len  # -1

    self._create_params()  # initialise the positional-embedding and bias parameters
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
             dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
             tie_projs=[False], pre_lnorm=False, tgt_len=None, ext_len=None,
             mem_len=None, cutoffs=[], adapt_inp=False, same_length=False,
             attn_type=0, clamp_len=-1, sample_softmax=-1):
    super(MemTransformerLM, self).__init__()
    self.n_token = n_token

    d_embed = d_model if d_embed is None else d_embed
    self.d_embed = d_embed
    self.d_model = d_model
    self.n_head = n_head  # number of heads
    self.d_head = d_head  # head dimension

    self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                      div_val=div_val)

    self.drop = nn.Dropout(dropout)

    self.n_layer = n_layer

    self.tgt_len = tgt_len
    self.mem_len = mem_len
    self.ext_len = ext_len
    self.max_klen = tgt_len + ext_len + mem_len

    self.attn_type = attn_type

    # These look like all decoder layers
    self.layers = nn.ModuleList()
    if attn_type == 0:
        # the default attention; this will have relative positional encoding.
        for i in range(n_layer):
            self.layers.append(RelPartialLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm))
    elif attn_type == 1:  # learnable embeddings (Shaw et al.)
        for i in range(n_layer):
            self.layers.append(RelLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm))
    elif attn_type in [2, 3]:
        # absolute embeddings; 2 corresponds to the original transformer, 3 to Al-Rfou.
        for i in range(n_layer):
            # d_inner is the FFN dimension, e.g. n_head=10, d_model=410, d_head=41,
            # d_inner=2100, dropout=0.1, dropatt=0, pre_lnorm=False
            self.layers.append(DecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                dropatt=dropatt, pre_lnorm=pre_lnorm))

    self.sample_softmax = sample_softmax
    if sample_softmax > 0:  # use sampled softmax
        self.out_layer = nn.Linear(d_model, n_token)
        if tie_weight:
            # Tie the weights of the output and input embedding matrix.
            self.out_layer.weight = self.word_emb.weight
        self.tie_weight = tie_weight
        self.sampler = LogUniformSampler(n_token, sample_softmax)
    # use adaptive softmax (including standard softmax)
    # TODO: understand what is adaptive softmax
    else:
        self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                cutoffs, div_val=div_val)
        if tie_weight:
            for i in range(len(self.crit.out_layers)):
                self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
        if tie_projs:
            for i, tie_proj in enumerate(tie_projs):
                if tie_proj and div_val == 1 and d_model != d_embed:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                elif tie_proj and div_val != 1:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[i]

    self.same_length = same_length
    self.clamp_len = clamp_len

    self._create_params()
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
             dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
             tie_projs=[False], pre_lnorm=False, tgt_len=None, ext_len=None,
             mem_len=None, cutoffs=[], adapt_inp=False, same_length=False,
             attn_type=0, clamp_len=-1, sample_softmax=-1, th=0,
             fn="attn_scores.npy"):
    super(MemTransformerLM, self).__init__()
    self.n_token = n_token

    d_embed = d_model if d_embed is None else d_embed
    self.d_embed = d_embed
    self.d_model = d_model
    self.n_head = n_head
    self.d_head = d_head

    self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                      div_val=div_val)

    self.drop = nn.Dropout(dropout)

    self.n_layer = n_layer

    self.tgt_len = tgt_len
    self.mem_len = mem_len
    self.ext_len = ext_len
    self.max_klen = tgt_len + ext_len + mem_len

    self.attn_type = attn_type
    self.th = th

    self.layers = nn.ModuleList()
    if attn_type == 0:  # the default attention
        for i in range(n_layer):
            self.layers.append(RelPartialLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm))
    elif attn_type == 1:  # learnable embeddings
        for i in range(n_layer):
            self.layers.append(RelLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm))
    elif attn_type in [2, 3]:  # absolute embeddings
        for i in range(n_layer):
            self.layers.append(DecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                dropatt=dropatt, pre_lnorm=pre_lnorm))

    self.sample_softmax = sample_softmax
    # use sampled softmax
    if sample_softmax > 0:
        self.out_layer = nn.Linear(d_model, n_token)
        if tie_weight:
            self.out_layer.weight = self.word_emb.weight
        self.tie_weight = tie_weight
        self.sampler = LogUniformSampler(n_token, sample_softmax)
    # use adaptive softmax (including standard softmax)
    else:
        self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                cutoffs, div_val=div_val)
        if tie_weight:
            for i in range(len(self.crit.out_layers)):
                self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
        if tie_projs:
            for i, tie_proj in enumerate(tie_projs):
                if tie_proj and div_val == 1 and d_model != d_embed:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                elif tie_proj and div_val != 1:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[i]

    self.same_length = same_length
    self.clamp_len = clamp_len

    # load average attention patterns
    # (n_layer, qlen, klen, 1, n_head)
    weights = np.load(fn)
    shape = weights.shape
    n_layer, n_head = shape[0], shape[-1]

    # init prune mask attribute
    self.prune_masks = np.zeros(weights.shape)
    target_percentage = self.th
    target_percentile = target_percentage * 100
    for i in range(0, n_layer):
        w = weights[i, :, :, :, :]
        tau = np.percentile(w, target_percentile, interpolation='nearest')
        self.prune_masks[i, :, :, :, :] = w <= tau

    self._create_params()
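
# ---------------------------------------------------------------------------
# Sketch of the per-layer thresholding used above (illustrative; the shapes
# and the `th` value are dummy data, not from the original file). `th` is read
# as a fraction, converted to a percentile, and every average attention weight
# at or below that layer's percentile value is marked for pruning.
import numpy as np

dummy = np.random.rand(4, 36, 72, 1, 2)  # (n_layer, qlen, klen, 1, n_head)
th = 0.3                                 # prune the lowest 30% of average weights
masks = np.zeros(dummy.shape, dtype=bool)
for i in range(dummy.shape[0]):
    tau = np.percentile(dummy[i], th * 100, interpolation='nearest')
    masks[i] = dummy[i] <= tau           # True marks positions to prune
# ---------------------------------------------------------------------------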
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
             dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
             tie_projs=[False], pre_lnorm=False, tgt_len=None, ext_len=None,
             mem_len=None, cutoffs=[], adapt_inp=False, same_length=False,
             attn_type=0, clamp_len=-1, sample_softmax=-1, rnnenc=False,
             rnndim=0, layer_list='', future_len=0, attn_layerlist='',
             merge_type='direct'):
    super(MemTransformerLM, self).__init__()
    self.n_token = n_token

    d_embed = d_model if d_embed is None else d_embed
    self.d_embed = d_embed
    self.d_model = d_model
    self.n_head = n_head
    self.d_head = d_head

    self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                      div_val=div_val)

    self.drop = nn.Dropout(dropout)

    self.n_layer = n_layer

    self.tgt_len = tgt_len
    self.mem_len = mem_len
    self.ext_len = ext_len
    self.future_len = future_len
    self.max_klen = tgt_len + ext_len + mem_len + future_len

    self.attn_type = attn_type

    # RNN hidden state carry on
    self.layer_list = [int(i) for i in layer_list.split()]
    self.rnnlayer_list = self.layer_list
    print("rnn layer list: {}".format(self.rnnlayer_list))
    if rnnenc and rnndim != 0:
        if merge_type in ['gating', 'project']:
            self.rnnproj = nn.Linear(rnndim + d_model, d_model)
        self.rnn_list = nn.ModuleList(
            [nn.LSTM(d_model, rnndim, 1) for i in range(len(self.rnnlayer_list))])

    # attn penalisation
    self.attn_pen_layers = [int(i) for i in attn_layerlist.split()]
    self.merge_type = merge_type

    self.layers = nn.ModuleList()
    if attn_type == 0:  # the default attention
        for i in range(n_layer):
            # dropatt = dropatt * 2 if (i == 0 and rnnenc) else dropatt
            use_penalty = i in self.attn_pen_layers
            self.layers.append(RelPartialLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm, penalty=use_penalty))
    elif attn_type == 1:  # learnable embeddings
        for i in range(n_layer):
            self.layers.append(RelLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm))
    elif attn_type in [2, 3]:  # absolute embeddings
        for i in range(n_layer):
            use_penalty = i in self.attn_pen_layers
            # dropatt = dropatt * 0 if (i == 0 and rnnenc) else dropatt
            self.layers.append(DecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                dropatt=dropatt, pre_lnorm=pre_lnorm, penalty=use_penalty))

    self.sample_softmax = sample_softmax
    # use sampled softmax
    if sample_softmax > 0:
        self.out_layer = nn.Linear(d_model, n_token)
        if tie_weight:
            self.out_layer.weight = self.word_emb.weight
        self.tie_weight = tie_weight
        self.sampler = LogUniformSampler(n_token, sample_softmax)
    # use adaptive softmax (including standard softmax)
    else:
        self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                cutoffs, div_val=div_val)
        if tie_weight:
            for i in range(len(self.crit.out_layers)):
                self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
        if tie_projs:
            for i, tie_proj in enumerate(tie_projs):
                if tie_proj and div_val == 1 and d_model != d_embed:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                elif tie_proj and div_val != 1:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[i]

    self.rnnenc = rnnenc
    self.rnndim = rnndim
    self.same_length = same_length
    self.clamp_len = clamp_len

    self._create_params()
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
             dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
             tie_projs=[False], pre_lnorm=False, tgt_len=None, ext_len=None,
             mem_len=None, cutoffs=[], adapt_inp=False, same_length=False,
             attn_type=0, clamp_len=-1, sample_softmax=-1):
    '''
    :param n_token: vocabulary size
    :param n_layer: 16
    :param n_head: 10
    :param d_model: 410
    :param d_head: 41
    :param d_inner: 2100
    :param dropout: 0.1
    :param dropatt: 0.0
    :param tie_weight: True
    :param d_embed: 410
    :param div_val: 1
    :param tie_projs: [F, T, T, T]
    :param pre_lnorm: False
    :param tgt_len: 150
    :param ext_len: 0
    :param mem_len: 150
    :param cutoffs: [20000, 40000, 200000]
    :param adapt_inp:
    :param same_length: False during training
    :param attn_type: 0
    :param clamp_len: -1
    :param sample_softmax: -1
    '''
    super(MemTransformerLM, self).__init__()
    self.n_token = n_token

    d_embed = d_model if d_embed is None else d_embed
    self.d_embed = d_embed
    self.d_model = d_model
    self.n_head = n_head
    self.d_head = d_head

    self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                      div_val=div_val)

    self.drop = nn.Dropout(dropout)

    self.n_layer = n_layer

    self.tgt_len = tgt_len
    self.mem_len = mem_len
    self.ext_len = ext_len
    self.max_klen = tgt_len + ext_len + mem_len

    self.attn_type = attn_type

    self.layers = nn.ModuleList()
    if attn_type == 0:  # the default attention
        for i in range(n_layer):
            self.layers.append(RelPartialLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm))
    elif attn_type == 1:  # learnable embeddings
        for i in range(n_layer):
            self.layers.append(RelLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm))
    elif attn_type in [2, 3]:  # absolute embeddings
        for i in range(n_layer):
            self.layers.append(DecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                dropatt=dropatt, pre_lnorm=pre_lnorm))

    self.sample_softmax = sample_softmax
    # use sampled softmax
    if sample_softmax > 0:
        self.out_layer = nn.Linear(d_model, n_token)
        if tie_weight:
            self.out_layer.weight = self.word_emb.weight
        self.tie_weight = tie_weight
        self.sampler = LogUniformSampler(n_token, sample_softmax)
    # use adaptive softmax (including standard softmax)
    else:
        self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                cutoffs, div_val=div_val)
        if tie_weight:
            for i in range(len(self.crit.out_layers)):
                self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
        if tie_projs:
            for i, tie_proj in enumerate(tie_projs):
                if tie_proj and div_val == 1 and d_model != d_embed:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                elif tie_proj and div_val != 1:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[i]

    self.same_length = same_length
    self.clamp_len = clamp_len

    self._create_params()
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
             dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
             tie_projs=[False], pre_lnorm=False, tgt_len=None, ext_len=None,
             mem_len=None, cutoffs=[], adapt_inp=False, same_length=False,
             attn_type=0, clamp_len=-1, sample_softmax=-1):
    super(MemTransformerLM, self).__init__()
    # print("###### MemTransformerLM parameters ######")
    # print("n_token={}".format(n_token), "###")
    # print("n_layer={}".format(n_layer), "###")
    # print("n_head", n_head)
    # print("d_model", d_model)
    # print("d_inner", d_inner)
    # print("dropout", dropout)
    # print("dropatt", dropatt)
    # print("tie_weight", tie_weight)
    # print("div_val", div_val)
    # print("tie_projs", tie_projs)
    # print("pre_lnorm", pre_lnorm)
    # print("tgt_len", tgt_len)
    # print("ext_len", ext_len)
    # print("mem_len", mem_len)
    # print("cutoffs", cutoffs)
    # print("adapt_inp", adapt_inp)
    # print("same_length", same_length)
    # print("attn_type", attn_type)
    # print("clamp_len", clamp_len)
    # print("sample_softmax", sample_softmax)
    # print("###### MemTransformerLM parameters ######")
    self.n_token = n_token

    d_embed = d_model if d_embed is None else d_embed
    self.d_embed = d_embed
    self.d_model = d_model
    self.n_head = n_head
    self.d_head = d_head

    self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                      div_val=div_val)

    self.drop = nn.Dropout(dropout)

    self.n_layer = n_layer

    self.tgt_len = tgt_len
    self.mem_len = mem_len
    self.ext_len = ext_len
    self.max_klen = tgt_len + ext_len + mem_len

    self.attn_type = attn_type
    self.model = "training"

    self.layers = nn.ModuleList()
    # print("##### ModuleList #####")
    # print("n_layer=", attn_type)
    # print("attn_type=", attn_type)
    # attn_type 0
    if attn_type == 0:  # the default attention
        for i in range(n_layer):
            self.layers.append(RelPartialLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm))
    elif attn_type == 1:  # learnable embeddings
        for i in range(n_layer):
            self.layers.append(RelLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm))
    elif attn_type in [2, 3]:  # absolute embeddings
        for i in range(n_layer):
            self.layers.append(DecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                dropatt=dropatt, pre_lnorm=pre_lnorm))

    self.sample_softmax = sample_softmax
    # print("###### sample_softmax=", sample_softmax)
    # use sampled softmax
    # sample_softmax = -1, tie_weight = True, tie_projs = [False]
    if sample_softmax > 0:
        self.out_layer = nn.Linear(d_model, n_token)
        if tie_weight:
            self.out_layer.weight = self.word_emb.weight
        self.tie_weight = tie_weight
        self.sampler = LogUniformSampler(n_token, sample_softmax)
    # use adaptive softmax (including standard softmax)
    else:
        self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                cutoffs, div_val=div_val)
        if tie_weight:
            for i in range(len(self.crit.out_layers)):
                self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
        if tie_projs:
            for i, tie_proj in enumerate(tie_projs):
                if tie_proj and div_val == 1 and d_model != d_embed:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                elif tie_proj and div_val != 1:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[i]

    self.same_length = same_length
    self.clamp_len = clamp_len

    self._create_params()
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
             dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
             tie_projs=[False], pre_lnorm=False, tgt_len=None, ext_len=None,
             mem_len=None, cutoffs=[], adapt_inp=False, same_length=False,
             attn_type=0, act_type=0, use_bn=False, clamp_len=-1,
             sample_softmax=-1, glu_type=0, glu_layers=[]):
    super(MemTransformerLM, self).__init__()
    self.n_token = n_token

    d_embed = d_model if d_embed is None else d_embed
    self.d_embed = d_embed
    self.d_model = d_model
    self.n_head = n_head
    self.d_head = d_head

    self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                      div_val=div_val)

    self.drop = nn.Dropout(dropout)

    self.n_layer = n_layer

    self.tgt_len = tgt_len
    self.mem_len = mem_len
    self.ext_len = ext_len
    self.max_klen = tgt_len + ext_len + mem_len

    self.act_type = act_type
    self.attn_type = attn_type
    print('glu type: ', glu_type)

    self.layers = nn.ModuleList()
    if attn_type == 0:  # the default attention
        for i in range(n_layer):
            # layers not listed in glu_layers use no gate (glu_type 0)
            if i in glu_layers:
                print('use GLU in layer-{}'.format(i))
                layer_glu_type = glu_type
            else:
                layer_glu_type = 0
            self.layers.append(RelPartialLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm,
                glu_type=layer_glu_type, act_type=act_type, use_bn=use_bn))
    elif attn_type == 1:  # learnable embeddings
        for i in range(n_layer):
            if i in glu_layers:
                print('use GLU in layer-{}'.format(i))
                layer_glu_type = glu_type
            else:
                layer_glu_type = 0
            self.layers.append(RelLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                dropatt=dropatt, pre_lnorm=pre_lnorm,
                glu_type=layer_glu_type, act_type=act_type))
    elif attn_type in [2, 3]:  # absolute embeddings
        for i in range(n_layer):
            # layers not specified in glu_layers fall back to GLU type 0
            if i in glu_layers:
                print('use GLU in layer-{}'.format(i))
                layer_glu_type = glu_type
            else:
                layer_glu_type = 0
            self.layers.append(DecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                layer_glu_type, act_type,
                dropatt=dropatt, pre_lnorm=pre_lnorm))

    self.sample_softmax = sample_softmax
    # use sampled softmax
    if sample_softmax > 0:
        self.out_layer = nn.Linear(d_model, n_token)
        if tie_weight:
            self.out_layer.weight = self.word_emb.weight
        self.tie_weight = tie_weight
        self.sampler = LogUniformSampler(n_token, sample_softmax)
    # use adaptive softmax (including standard softmax)
    else:
        self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                cutoffs, div_val=div_val)
        if tie_weight:
            for i in range(len(self.crit.out_layers)):
                self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
        if tie_projs:
            for i, tie_proj in enumerate(tie_projs):
                if tie_proj and div_val == 1 and d_model != d_embed:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                elif tie_proj and div_val != 1:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[i]

    self.same_length = same_length
    self.clamp_len = clamp_len

    self._create_params()
class MemTransformerLM(nn.Module):
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None, div_val=1,
                 tie_projs=[False], pre_lnorm=False, tgt_len=None, ext_len=None,
                 mem_len=None, cutoffs=[], adapt_inp=False, same_length=False,
                 attn_type=0, clamp_len=-1, sample_softmax=-1):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token
        self.training_steps = 0
        self.compute = 0

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head
        self.d_inner = d_inner
        self.n_layer = n_layer
        self.attn_type = attn_type
        self.tie_weight = tie_weight
        self.tie_projs = tie_projs

        self.drop = nn.Dropout(dropout)
        self.dropout_p = dropout
        self.dropatt_p = dropatt
        self.pre_lnorm = pre_lnorm

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                          div_val=div_val)

        self.layers = nn.ModuleList()
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                self.layers.append(RelPartialLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm))
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                self.layers.append(RelLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm))
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                self.layers.append(DecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    dropatt=dropatt, pre_lnorm=pre_lnorm))

        self.sample_softmax = sample_softmax
        self._create_params()
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)
        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                    cutoffs, div_val=div_val)
            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

    def backward_compatible(self, tie_weight, tie_projs):
        self.sample_softmax = -1
        self.tie_weight = tie_weight
        self.tie_projs = tie_projs

    def _create_params(self):
        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            # self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            # self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        elif self.attn_type == 1:  # learnable
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2:  # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3:  # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer + 1):
                empty = torch.empty(0, dtype=param.dtype, device=param.device)
                mems.append(empty)
            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        # does not deal with None
        if mems is None:
            return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):
                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None):
        qlen, bsz = dec_inp.size()

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1 + mlen)
                             + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None]  # -1
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1 + mlen).bool()[:, :, None]

        hids = []
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, pos_emb,
                                 dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1:  # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len:]
                    r_bias = self.r_bias[i][-self.clamp_len:]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, r_emb, self.r_w_bias[i], r_bias,
                                 dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2:  # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen - cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                core_out = layer(core_out, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

        return core_out, new_mems

    def forward(self, data, target, *mems):
        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
        # So, have to initialize size(0) mems inside the model forward.
        # Moreover, have to return new_mems to allow nn.DataParallel to piece
        # them together.
        if not mems:
            mems = self.init_mems()

        tgt_len = target.size(0)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]
        if self.sample_softmax > 0 and self.training:
            assert self.tie_weight
            logit = sample_logits(self.word_emb, self.out_layer.bias, target,
                                  pred_hid, self.sampler)
            loss = -F.log_softmax(logit, -1)[:, :, 0]
        else:
            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.reshape(-1))
            loss = loss.view(tgt_len, -1)

        if new_mems is None:
            return [loss]
        else:
            return [loss] + new_mems

    def expand_layers(self, n_add, strategy="repeat", function=None):
        assert self.attn_type == 0, \
            f"only works with default attention mode, not mode {self.attn_type}"
        assert strategy in ["repeat", "reinit", "repeat_bottom", "reinit_bottom", "duplicate"], \
            f"initialization mode {strategy} not implemented"
        duplicate = "duplicate" in strategy
        bottom = "bottom" in strategy
        new_layers = nn.ModuleList([])
        if duplicate:
            assert n_add % len(self.layers) == 0, \
                f"duplicating the network requires the number of extra layers {n_add} " \
                f"to be an integer multiple of previous length {len(self.layers)}"
            factor = n_add // len(self.layers)
            # we append them in the order they'll be found in the network later
            for layer in self.layers:
                for _ in range(factor):
                    new_layers.append(deepcopy(layer))
            # we interleave the new layers on top of the layer they duplicate
            positions = [i for i in range(len(self.layers) + n_add)
                         if i % len(self.layers) != 0]
            for i, new_layer in enumerate(new_layers):
                self.layers.insert(positions[i], new_layer)
        else:
            for _ in range(n_add):
                new_layer = deepcopy(self.layers[0 if bottom else -1])
                if "reinit" in strategy:
                    new_layer.apply(function)
                new_layers.append(new_layer)
            if bottom:
                # not as elegant as extending the end but we have to add modules
                # one by one at the start; the count i is to make sure they're in
                # the same order in new_layers and self.layers
                for i, layer in enumerate(new_layers):
                    self.layers.insert(i, layer)
            else:
                self.layers.extend(new_layers)
        self.n_layer += n_add
        return new_layers

    def add_heads(self, ratio, strategy="duplicate", function=None,
                  expand_inner=True, expand_embeddings=True, change_freq=False):
        assert self.attn_type == 0, \
            f"only works with default attention mode, not mode {self.attn_type}"
        assert strategy in ["reinit", "duplicate"], \
            f"initialization mode {strategy} not implemented"
        self.d_model *= ratio
        self.n_head *= ratio
        if expand_inner:
            self.d_inner *= ratio
        if expand_embeddings:
            self.d_embed *= ratio
        for layer in self.layers:
            layer.add_heads(ratio, strategy, function, expand_inner)
        if expand_embeddings:
            self.crit.widen(ratio, strategy, function)
            self.word_emb.widen(ratio, strategy, function, self.tie_weight,
                                self.tie_projs)
        self.pos_emb.widen(ratio, change_freq)
        return self.layers
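
# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not from the original file) of how the
# growth hooks above might be driven. `reinit_weights` is a hypothetical
# re-initialisation callback passed as `function`; the commented calls assume
# `model` was built with attn_type=0, which both methods require.
def reinit_weights(module):
    # reset Linear sub-modules of a freshly added layer
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# model.expand_layers(len(model.layers), strategy="duplicate")        # double the depth, copies interleaved
# model.expand_layers(2, strategy="reinit", function=reinit_weights)  # stack two freshly initialised layers on top
# model.add_heads(2, strategy="duplicate")                            # double n_head/d_model (and d_inner, embeddings)
# ---------------------------------------------------------------------------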