def __init__(self, n_token, n_layer, eval_n_layer, n_head, d_model, d_head, d_inner,
             dropout, dropatt, tie_weights=True, d_embed=None, div_val=1,
             tie_projs=[False], pre_lnorm=False, wnorm=False, tgt_len=None,
             mem_len=None, local_size=0, pretrain_steps=1, cutoffs=[], load=''):
    super().__init__()
    self.n_token = n_token

    d_embed = d_model if d_embed is None else d_embed
    self.d_embed = d_embed
    self.d_model = d_model
    self.n_head = n_head
    self.d_head = d_head

    self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, div_val=div_val)
    self.iodrop = VariationalDropout()
    self.dropout = dropout
    self.drop = VariationalHidDropout(dropout=dropout)
    self.pretrain_steps = pretrain_steps

    self.tgt_len = tgt_len
    self.mem_len = mem_len
    self.local_size = local_size
    self.max_klen = tgt_len + mem_len

    self.n_layer = n_layer
    self.eval_n_layer = eval_n_layer

    self.inject_conv = nn.Conv1d(d_model, 3 * d_model, kernel_size=1)
    self.pos_emb = PositionalEmbedding(self.d_model)

    self.func = RelPartialLearnableDecoderLayer(n_head, d_model, d_head, d_inner,
                                                dropout=dropout, dropatt=dropatt,
                                                pre_lnorm=pre_lnorm, local_size=local_size)
    self.func_copy = copy.deepcopy(self.func)
    if wnorm:
        # If wnorm is specified, we need to make sure to do the deepcopy first,
        # because pytorch prevents non-user-defined variables from being copied.
        # The deepcopy will only access and copy `xxx.weight` instead of
        # messing with the `xxx.weight_g` that wnorm actually optimizes.
        self.func.wnorm()
    for params in self.func_copy.parameters():
        params.requires_grad_(False)  # Turn off autograd for func_copy

    self.deq = TransformerDEQForward(self.func)
    self.deqback = TransformerDEQBackward(self.func, self.func_copy)

    # use adaptive softmax (including standard softmax)
    # (Note: To use sample softmax, refer to the Transformer-XL implementation)
    self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model, cutoffs, div_val=div_val)

    if tie_weights:
        for i in range(len(self.crit.out_layers)):
            self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

    if tie_projs:
        for i, tie_proj in enumerate(tie_projs):
            if tie_proj and div_val == 1 and d_model != d_embed:
                self.crit.out_projs[i] = self.word_emb.emb_projs[0]
            elif tie_proj and div_val != 1:
                self.crit.out_projs[i] = self.word_emb.emb_projs[i]

    if len(load) > 0:
        params_dict = torch.load(load)
        self.load_weights(params_dict)
        print(f"Finished loading. d_embed={self.inject_conv.weight.data.size(1)}")
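
# The tie_weights / tie_projs blocks above (and in the constructors below) share
# storage between the input embedding and the softmax output matrix. The sketch
# below shows the same idea with plain nn.Embedding / nn.Linear instead of the
# adaptive classes, which do this per embedding cluster. `vocab` and `d_model`
# here are illustrative values, not taken from this repository.
import torch
import torch.nn as nn

vocab, d_model = 1000, 64
emb = nn.Embedding(vocab, d_model)
out = nn.Linear(d_model, vocab, bias=False)
out.weight = emb.weight                              # single shared Parameter
assert out.weight.data_ptr() == emb.weight.data_ptr()

x = torch.randint(vocab, (8,))
logits = out(emb(x))                                 # gradients flow into one tensor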
def __init__(self, n_token_in, n_token_out, n_layer, n_head, d_model, d_head,
             d_inner, dropout, dropatt, conv_size=7, conv_emb=False, pre_conv=False,
             tie_weight=True, d_embed=None, tie_projs=[False], tgt_len=None,
             mem_len=None, ext_ds=None, cutoffs=[], same_length=False, clamp_len=-1):
    super().__init__()
    self.n_token_in = n_token_in
    self.n_token_out = n_token_out

    d_embed = d_model if d_embed is None else d_embed
    self.d_embed = d_embed
    self.d_model = d_model
    self.n_head = n_head
    self.d_head = d_head
    self.conv_size = conv_size
    self.pre_conv = pre_conv
    self.conv_emb = conv_emb

    if conv_emb:
        self.word_emb = ConvEmbeddings(
            n_token_in, d_embed, d_model, conv_size, cutoffs,
        )
    else:
        self.word_emb = AdaptiveEmbedding(
            n_token_in, d_embed, d_model, cutoffs,
        )

    self.drop = nn.Dropout(dropout)

    self.n_layer = n_layer

    self.tgt_len = tgt_len
    self.mem_len = mem_len
    self.ext_ds = ext_ds
    self.max_klen = tgt_len + mem_len

    self.layers = nn.ModuleList()
    for i in range(n_layer):
        self.layers.append(
            RelPartialLearnableDecoderLayer(
                n_head, d_model, d_head, d_inner, dropout,
                conv_size=conv_size, pre_conv=pre_conv,
                tgt_len=tgt_len, mem_len=mem_len, dropatt=dropatt))

    self.out_layer = nn.Linear(d_model, n_token_out)

    # use adaptive softmax (including standard softmax)
    self.crit = ProjectedAdaptiveLogSoftmax(n_token_out, d_embed, d_model, cutoffs)

    if tie_weight:
        for i in range(len(self.crit.out_layers)):
            self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

    if tie_projs:
        for i, tie_proj in enumerate(tie_projs):
            if tie_proj and d_model != d_embed:
                self.crit.out_projs[i] = self.word_emb.emb_projs[0]
            elif tie_proj:
                self.crit.out_projs[i] = self.word_emb.emb_projs[i]

    self.same_length = same_length
    self.clamp_len = clamp_len

    self._create_params()
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
             dropout, dropatt, dtype, tie_weight=True, d_embed=None,
             div_val=1, tie_projs=[False], pre_lnorm=False,
             tgt_len=None, ext_len=None, mem_len=None,
             cutoffs=[], adapt_inp=False, same_length=False,
             attn_type=0, clamp_len=-1, sample_softmax=-1):
    super(MemTransformerLM, self).__init__()
    self.n_token = n_token

    d_embed = d_model if d_embed is None else d_embed
    self.d_embed = d_embed
    self.d_model = d_model
    self.n_head = n_head
    self.d_head = d_head

    self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                      div_val=div_val)

    self.drop = nn.Dropout(dropout)

    self.tie_weight = tie_weight
    self.tie_projs = tie_projs
    self.div_val = div_val

    self.n_layer = n_layer

    self.tgt_len = tgt_len
    self.mem_len = mem_len
    self.ext_len = ext_len
    self.max_klen = tgt_len + ext_len + mem_len

    self.attn_type = attn_type

    self.layers = nn.ModuleList()
    # the default attention
    if attn_type == 0:
        for i in range(n_layer):
            self.layers.append(
                RelPartialLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm))
    # learnable embeddings
    elif attn_type == 1:
        for i in range(n_layer):
            self.layers.append(
                RelLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm))
    # absolute embeddings
    elif attn_type in [2, 3]:
        for i in range(n_layer):
            self.layers.append(
                DecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    dropatt=dropatt, pre_lnorm=pre_lnorm))

    self.sample_softmax = sample_softmax
    # use sampled softmax
    if sample_softmax > 0:
        self.out_layer = nn.Linear(d_model, n_token)
        self.tie_weight = tie_weight
        self.sampler = LogUniformSampler(n_token, sample_softmax)
    # use adaptive softmax (including standard softmax)
    else:
        if tie_weight:
            emb_layers = [i.weight for i in self.word_emb.emb_layers]
        else:
            emb_layers = None

        emb_projs = self.word_emb.emb_projs

        self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                cutoffs, div_val=div_val,
                                                tie_projs=tie_projs,
                                                out_projs=emb_projs,
                                                out_layers_weights=emb_layers)

    self.same_length = same_length
    self.clamp_len = clamp_len

    self._create_params()
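
# `PositionalEmbedding(self.d_model)` is referenced in the constructors above and
# in `_create_params` below, but its definition is not part of this excerpt.
# The sketch below follows the standard Transformer-XL sinusoidal implementation;
# treat it as an assumption about the missing class, not code copied from this
# repository.
import torch
import torch.nn as nn


class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super().__init__()
        self.demb = demb
        # inverse frequencies 1 / 10000^(2i/d)
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        # outer product: [klen] x [demb/2] -> [klen, demb/2]
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
        if bsz is not None:
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        return pos_emb[:, None, :]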
class MemTransformerLM(nn.Module):
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None,
                 div_val=1, tie_projs=[False], pre_lnorm=False,
                 tgt_len=None, ext_len=None, mem_len=None,
                 cutoffs=[], adapt_inp=False, same_length=False,
                 attn_type=0, clamp_len=-1, sample_softmax=-1):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.attn_type = attn_type

        self.layers = nn.ModuleList()
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)
        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                    cutoffs, div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()

    def backward_compatible(self):
        self.sample_softmax = -1

    def _create_params(self):
        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        elif self.attn_type == 1:  # learnable
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2:  # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3:  # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer + 1):
                empty = torch.empty(0, dtype=param.dtype,
                                    device=param.device)
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        # does not deal with None
        if mems is None:
            return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):
                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None):
        qlen, bsz = dec_inp.size()

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1 + mlen)
                             + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None]  # -1
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1 + mlen).byte()[:, :, None]

        hids = []
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, pos_emb, self.r_w_bias, self.r_r_bias,
                                 dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1:  # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len:]
                    r_bias = self.r_bias[i][-self.clamp_len:]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, r_emb, self.r_w_bias[i], r_bias,
                                 dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2:  # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen - cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

        return core_out, new_mems
    def forward_generate(self, data, *mems):
        if not mems:
            mems = self.init_mems()

        tgt_len = data.size(0)
        batch_size = data.size(1)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]

        assert self.crit.n_clusters == 0

        logits = self.crit._compute_logit(
            pred_hid.view(-1, pred_hid.size(-1)),
            self.crit.out_layers[0].weight,
            self.crit.out_layers[0].bias,
            self.crit.out_projs[0])
        logits = logits.view(tgt_len, batch_size, -1)

        if new_mems is None:
            return [logits]
        else:
            return [logits] + new_mems

    def forward(self, data, target, *mems):
        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
        # So, have to initialize size(0) mems inside the model forward.
        # Moreover, have to return new_mems to allow nn.DataParallel to piece
        # them together.
        if not mems:
            mems = self.init_mems()

        tgt_len = target.size(0)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]
        if self.sample_softmax > 0 and self.training:
            assert self.tie_weight
            logit = sample_logits(self.word_emb, self.out_layer.bias, target,
                                  pred_hid, self.sampler)
            loss = -F.log_softmax(logit, -1)[:, :, 0]
        else:
            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
            loss = loss.view(tgt_len, -1)

        if new_mems is None:
            return [loss]
        else:
            return [loss] + new_mems
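

# A small, self-contained illustration of the caching window used by
# `_update_mems` above: concatenate the old memory with the new hidden states
# along the time dimension and keep the `mem_len` steps that precede the
# extended context. The helper name and the toy sizes (mlen, qlen, mem_len,
# ext_len, bsz, d) are made up for this example, not taken from any
# configuration in this repository.
import torch


def update_mem_window(mem, hid, qlen, mlen, mem_len, ext_len):
    # cache steps [mlen + qlen - ext_len - mem_len, mlen + qlen - ext_len)
    with torch.no_grad():
        end_idx = mlen + max(0, qlen - ext_len)
        beg_idx = max(0, end_idx - mem_len)
        cat = torch.cat([mem, hid], dim=0)      # [mlen + qlen, bsz, d]
        return cat[beg_idx:end_idx].detach()


mlen, qlen, mem_len, ext_len, bsz, d = 4, 3, 5, 0, 2, 8
mem = torch.randn(mlen, bsz, d)
hid = torch.randn(qlen, bsz, d)
new_mem = update_mem_window(mem, hid, qlen, mlen, mem_len, ext_len)
print(new_mem.shape)  # torch.Size([5, 2, 8]) -- the last mem_len steps are kept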