def forward(self, h, attn_mask=None, mems=None,
            carry_over_fast_weight=False):
    # multihead attention
    # shape h: (len, B, n_head * d_head)
    if self.pre_lnorm:
        # layer normalization
        h = self.layer_norm(h)

    slen, bsz, _ = h.size()

    qkvb = self.qkvb_net(h)
    qkvb = qkvb.view(slen, bsz, self.n_head, 3 * self.d_head + 1)
    head_q, head_k, head_v, head_beta = torch.split(
        qkvb, (self.d_head,) * 3 + (1,), -1)
    head_beta = torch.sigmoid(head_beta)

    # reshape to (B, heads, len, dim)
    head_q = head_q.permute(1, 2, 0, 3)
    head_k = head_k.permute(1, 2, 0, 3)
    head_v = head_v.permute(1, 2, 0, 3)
    head_beta = head_beta.permute(1, 2, 0, 3)

    act = lambda x: F.relu(x)  # relu or exp
    head_k = torch.cat([act(head_k), act(-head_k)], dim=-1)
    head_q = torch.cat([act(head_q), act(-head_q)], dim=-1)

    head_k = self.mul_roll_repeat(head_k)
    head_q = self.mul_roll_repeat(head_q)

    # normalize k and q, crucial for stable training.
    head_k = head_k / head_k.sum(-1, keepdim=True)
    head_q = head_q / head_q.sum(-1, keepdim=True)

    if self.normalize_attn_scores:
        denominator_acc = torch.cumsum(head_k, dim=2)

    if mems is None:
        mem_fast_weights = torch.zeros(
            bsz, self.n_head, 2 * self.n_roll * self.d_head, self.d_head,
            device=head_k.device)
    else:
        assert carry_over_fast_weight
        mem_fast_weights, fast_denom = mems
        # bsz can be smaller for the last batch
        mem_fast_weights = mem_fast_weights[:bsz]
        if self.normalize_attn_scores:
            denominator_acc = denominator_acc + fast_denom[:bsz]

    if self.normalize_attn_scores:
        denominator = torch.einsum(
            'lbij,lbij->lbi', denominator_acc, head_q).unsqueeze(-1)

    layer_out = fast_weight_memory(
        head_q, head_k, head_v, head_beta, mem_fast_weights)
    # shape (B, n_head, len, d_head)

    if self.normalize_attn_scores:
        layer_out = self.scale * layer_out / (denominator + self.eps)
    else:
        layer_out = self.scale * layer_out

    layer_out = layer_out.transpose(1, 2)
    layer_out = layer_out.reshape(
        bsz, slen, self.n_head * self.d_head)
    layer_out = layer_out.transpose(0, 1)
    # expect [qlen, B, n_head * d_head]

    # linear projection
    attn_out = self.o_net(layer_out)
    attn_out = self.drop(attn_out)

    if self.pre_lnorm:
        # residual connection
        output = h + attn_out
    else:
        # residual connection + layer normalization
        output = self.layer_norm(h + attn_out)

    if carry_over_fast_weight:
        # last values of accumulator should be carried over.
        # clone is needed as backward modifies the data of fast weight
        if self.normalize_attn_scores:
            new_k_acc = denominator_acc[:, :, -1, :].unsqueeze(2).detach()
        else:
            new_k_acc = None
        new_mem = (mem_fast_weights.clone().detach(), new_k_acc)
        return output, new_mem

    return output
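
# Illustrative sketches (assumptions, not the repository's implementations):
# `self.mul_roll_repeat` and `fast_weight_memory` are defined elsewhere in the
# code base (the latter presumably as a custom CUDA kernel).  The pure-PyTorch
# helpers below, suffixed `_sketch`, only reproduce the behaviour the forward
# pass above appears to rely on: a DPFP-style roll-and-multiply feature map,
# and a beta-gated fast weight update written here in the delta-rule form from
# the fast weight programmer literature (the actual kernel may differ, e.g. it
# updates `mem_fast_weights` in place, whereas this sketch stays functional).
import torch


def mul_roll_repeat_sketch(x, n_roll):
    # Multiply the feature vector with `n_roll` rolled copies of itself;
    # output feature dim = n_roll * input dim (here 2 * n_roll * d_head).
    return torch.cat(
        [x * x.roll(shifts=i, dims=-1) for i in range(1, n_roll + 1)], dim=-1)


def fast_weight_memory_sketch(q, k, v, beta, W):
    # Sequential reference of a beta-gated fast weight memory.
    # q, k: (B, H, L, d_feat), v: (B, H, L, d_head),
    # beta: (B, H, L, 1), W: (B, H, d_feat, d_head).
    outs = []
    for t in range(q.size(2)):
        k_t, v_t = k[:, :, t], v[:, :, t]
        v_old = torch.einsum('bhfd,bhf->bhd', W, k_t)      # current retrieval
        W = W + torch.einsum(
            'bhd,bhf->bhfd', beta[:, :, t] * (v_t - v_old), k_t)
        outs.append(torch.einsum('bhfd,bhf->bhd', W, q[:, :, t]))
    return torch.stack(outs, dim=2)                        # (B, H, L, d_head)
# Because the state is a fixed-size matrix per head, the final fast weights
# (and the key accumulator) can be detached and handed to the next segment,
# which is what the `carry_over_fast_weight` branch above does.
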
def forward(self, h, attn_mask=None, mems=None, redraw=True,
            carry_over_fast_weight=False):
    # multihead attention
    # shape h: (len, B, n_head * d_head)
    if self.pre_lnorm:
        # layer normalization
        h = self.layer_norm(h)

    slen, bsz, _ = h.size()

    qkvb = self.qkvb_net(h)
    qkvb = qkvb.view(slen, bsz, self.n_head, 3 * self.d_head + 1)
    head_q, head_k, head_v, head_beta = torch.split(
        qkvb, (self.d_head,) * 3 + (1,), -1)
    head_beta = torch.sigmoid(head_beta)

    # reshape to (B, heads, len, dim)
    head_q = head_q.permute(1, 2, 0, 3)
    head_k = head_k.permute(1, 2, 0, 3)
    head_v = head_v.permute(1, 2, 0, 3)
    head_beta = head_beta.permute(1, 2, 0, 3)

    if redraw:
        self.proj_matrix = draw_orthogonal_random_matrix(
            self.d_head, self.proj_dim, device=h.device)

    head_q = prime(head_q, self.proj_matrix)  # (B, n_head, len, proj_dim)
    head_k = prime(head_k, self.proj_matrix)

    # normalize k and q, crucial for stable training.
    head_k = head_k / head_k.sum(-1, keepdim=True)
    head_q = head_q / head_q.sum(-1, keepdim=True)

    if self.normalize_attn_scores:
        # another version would be:
        # head_k_beta = head_k * head_beta
        # denominator_acc = torch.cumsum(head_k_beta, dim=2)
        denominator_acc = torch.cumsum(head_k, dim=2)

    if mems is None:
        mem_fast_weights = torch.zeros(
            bsz, self.n_head, 2 * self.proj_dim, self.d_head,
            device=head_k.device)
        if self.normalize_attn_scores:
            # key_denom = z(i-1) * key(i) and 1 if i=1
            # z(i) = denominator_acc
            key_denom = torch.cat(
                [torch.zeros([bsz, self.n_head, 1, self.proj_dim * 2],
                             device=head_q.device),
                 denominator_acc[:, :, :-1, :].clone()], dim=2)
            key_denom = torch.einsum('lbij,lbij->lbi', key_denom, head_k)
            key_denom = torch.cat(
                [torch.ones([bsz, self.n_head, 1], device=head_q.device),
                 key_denom[:, :, 1:].clone()], dim=2).unsqueeze(-1)
            head_beta = head_beta * key_denom
            head_k = head_k / (key_denom + self.eps)
    else:
        assert carry_over_fast_weight
        mem_fast_weights, fast_denom = mems
        # bsz can be smaller for the last batch
        mem_fast_weights = mem_fast_weights[:bsz]
        if self.normalize_attn_scores:
            key_denom = torch.cat(
                [torch.zeros([bsz, self.n_head, 1, self.proj_dim * 2],
                             device=head_q.device),
                 denominator_acc[:, :, :-1, :].clone()], dim=2)
            key_denom = key_denom + fast_denom[:bsz]
            denominator_acc = denominator_acc + fast_denom[:bsz]
            key_denom = torch.einsum(
                'lbij,lbij->lbi', key_denom, head_k).unsqueeze(-1)
            head_beta = head_beta * key_denom
            head_k = head_k / (key_denom + self.eps)

    if self.normalize_attn_scores:
        denominator = torch.einsum(
            'lbij,lbij->lbi', denominator_acc, head_q).unsqueeze(-1)

    layer_out = fast_weight_memory(
        head_q, head_k, head_v, head_beta, mem_fast_weights)
    # shape (B, n_head, len, d_head)

    if self.normalize_attn_scores:
        layer_out = self.scale * layer_out / (denominator + self.eps)
    else:
        layer_out = self.scale * layer_out

    layer_out = layer_out.transpose(1, 2)
    layer_out = layer_out.reshape(
        bsz, slen, self.n_head * self.d_head)
    layer_out = layer_out.transpose(0, 1)
    # expect [qlen, B, n_head * d_head]

    # linear projection
    attn_out = self.o_net(layer_out)
    attn_out = self.drop(attn_out)

    if self.pre_lnorm:
        # residual connection
        output = h + attn_out
    else:
        # residual connection + layer normalization
        output = self.layer_norm(h + attn_out)

    if carry_over_fast_weight:
        # last values of accumulator should be carried over.
        # clone is needed as backward modifies the data of fast weight
        if self.normalize_attn_scores:
            new_k_acc = denominator_acc[:, :, -1, :].unsqueeze(2).detach()
        else:
            new_k_acc = None
        new_mem = (mem_fast_weights.clone().detach(), new_k_acc)
        return output, new_mem

    return output
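
# Illustrative sketches (assumptions, not the repository's implementations):
# `draw_orthogonal_random_matrix` and `prime` are defined elsewhere in the
# code base.  The helpers below, suffixed `_sketch`, follow the standard
# FAVOR+ (Performer) construction that the shapes above suggest: an orthogonal
# Gaussian projection of shape (d_head, proj_dim) and positive softmax-kernel
# features of dimension 2 * proj_dim.
import math

import torch


def draw_orthogonal_random_matrix_sketch(d, m, device=None):
    # Stack blocks of orthogonalised Gaussian rows (via QR) and rescale each
    # row to the norm of a d-dimensional Gaussian vector, as in FAVOR+.
    rows = []
    while sum(b.shape[0] for b in rows) < m:
        g = torch.randn(d, d, device=device)
        q, _ = torch.linalg.qr(g)
        rows.append(q.t())
    w = torch.cat(rows, dim=0)[:m]                                # (m, d)
    norms = torch.randn(m, d, device=device).norm(dim=1, keepdim=True)
    return (norms * w).t()                                        # (d, m)


def prime_sketch(x, proj_matrix):
    # Positive random features for the softmax kernel:
    # phi(x) ~ exp(x W - ||x||^2 / 2), concatenated for +W and -W.
    # x: (B, H, L, d_head), proj_matrix: (d_head, proj_dim)
    u = torch.einsum('bhld,dm->bhlm', x, proj_matrix)
    stab = -0.5 * x.pow(2).sum(dim=-1, keepdim=True)  # -||x||^2 / 2
    m = proj_matrix.shape[-1]
    # Real implementations usually also subtract a per-vector max inside exp()
    # for numerical stability; that factor cancels under the feature-sum
    # normalization applied in the forward pass above.
    return torch.cat(
        [torch.exp(u + stab), torch.exp(-u + stab)], dim=-1) / math.sqrt(2 * m)
# With `redraw=True` the projection is re-sampled on every call, so the random
# feature map changes between forward passes while its expectation still
# approximates the softmax kernel.
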