import importlib
import math
import numbers

import torch
from torch import nn
from torch.nn import Linear, Parameter, init
from torch.nn import functional as F
from torch.nn.functional import pad


class ResidualFusedLayerNorm(torch.nn.Module):
    # Layer norm whose affine weight/bias are stored as an anchor column plus 63
    # per-column deltas and reconstructed by a running sum in forward().
    # FusedLayerNormAffineFunction / FusedLayerNormFunction are assumed to be
    # defined elsewhere in this module (as in apex).

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        global fused_layer_norm_cuda
        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (int(normalized_shape / 64), 1)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            # anchor column plus 63 columns of deltas for the weight
            self._weight = Parameter(torch.Tensor(*normalized_shape))
            self.aweight = Parameter(torch.zeros(8, 63))
            # anchor column plus 63 columns of deltas for the bias
            self._bias = Parameter(torch.Tensor(*normalized_shape))
            self.abias = Parameter(torch.zeros(8, 63))
            self.normalized_shape = torch.Size((8, 64))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.elementwise_affine:
            init.ones_(self._weight)
            init.zeros_(self._bias)

    def forward(self, input):
        if self.elementwise_affine:
            # rebuild the full weight: each delta column is accumulated onto the anchor
            aweight = list(self.aweight.chunk(63, 1))
            pre = self._weight
            for i in range(63):
                aweight[i] = pre + aweight[i]
                pre = aweight[i]
            aweight = torch.cat(tuple(aweight), -1)
            self.weight = torch.cat((self._weight, aweight), dim=-1)
            # rebuild the full bias the same way
            abias = list(self.abias.chunk(63, 1))
            pre = self._bias
            for i in range(63):
                abias[i] = pre + abias[i]
                pre = abias[i]
            abias = torch.cat(tuple(abias), -1)
            self.bias = torch.cat((self._bias, abias), dim=-1)
        if not input.is_cuda:
            return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.elementwise_affine:
            return FusedLayerNormAffineFunction.apply(
                input, self.weight, self.bias, self.normalized_shape, self.eps
            )
        return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps)

    def extra_repr(self):
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
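
# A minimal sketch (illustrative only, not a module attribute) of how forward() above
# reconstructs the full (8, 64) affine weight: column 0 is the anchor `_weight` and each
# of the 63 `aweight` columns is a delta accumulated onto it, i.e. the anchor plus a
# prefix sum along dim 1. The same construction is applied to `_bias`/`abias`.
def _residual_columns_sketch(anchor, deltas):
    # anchor: (8, 1) tensor standing in for `_weight`; deltas: (8, 63) standing in for `aweight`
    return torch.cat((anchor, anchor + torch.cumsum(deltas, dim=1)), dim=-1)  # -> (8, 64)

# e.g. _residual_columns_sketch(torch.ones(8, 1), torch.zeros(8, 63)) equals torch.ones(8, 64),
# matching reset_parameters(), which sets the anchors to ones/zeros and leaves the deltas at zero.
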
class Attention(qc.Module):
    hs = qc.Hypers(
        ["d_model", "n_heads", "d_k", "d_v"],
        {
            "add_b_kv": False,
            "add_zero_attn": False,
            "batch_first": False,
            "bias": True,
            "drop": 0.0,
        },
    )
    w_pack = w_q = w_k = w_v = None
    b_pack = b_q = b_k = b_v = None

    def __init__(self, n_heads, d_model, hs=[], **kw):
        if n_heads is not None:
            kw.update(n_heads=n_heads)
        if d_model is not None:
            kw.update(d_model=d_model)
        super().__init__([self.hs] + hs, **kw)
        cfg = self.cfg
        n, h = cfg.n_heads, cfg.d_model
        assert h % n == 0
        d_k = cfg.d_k if cfg.d_k is not None else h
        d_v = cfg.d_v if cfg.d_v is not None else h
        # a single packed projection is used when key/value widths match the model width
        self.pack = d_k == h and d_v == h
        kw = {"device": cfg.device, "dtype": cfg.dtype}
        if self.pack:
            self.w_pack = Parameter(torch.empty((3 * h, h), **kw))
            self.register_parameter("w_q", None)
            self.register_parameter("w_k", None)
            self.register_parameter("w_v", None)
        else:
            self.register_parameter("w_pack", None)
            self.w_q = Parameter(torch.empty((h, h), **kw))
            self.w_k = Parameter(torch.empty((h, d_k), **kw))
            self.w_v = Parameter(torch.empty((h, d_v), **kw))
        if cfg.bias:
            self.b_pack = Parameter(torch.empty(3 * h, **kw))
        else:
            self.register_parameter("b_pack", None)
        self.out = Linear(h, h, bias=cfg.bias, **kw)
        if cfg.add_b_kv:
            self.b_k = Parameter(torch.empty((1, 1, h), **kw))
            self.b_v = Parameter(torch.empty((1, 1, h), **kw))
        else:
            self.register_parameter("b_k", None)
            self.register_parameter("b_v", None)

    def build(self, _):
        if not self.is_built():
            with torch.no_grad():
                self.reset_params()

    def reset_params(self):
        if self.pack:
            nn.init.xavier_uniform_(self.w_pack)
        else:
            nn.init.xavier_uniform_(self.w_q)
            nn.init.xavier_uniform_(self.w_k)
            nn.init.xavier_uniform_(self.w_v)
        if self.b_pack is not None:
            nn.init.constant_(self.b_pack, 0.0)
            nn.init.constant_(self.out.bias, 0.0)
        if self.b_k is not None:
            nn.init.xavier_normal_(self.b_k)
        if self.b_v is not None:
            nn.init.xavier_normal_(self.b_v)

    def forward(self, q, k, v, mask=None, k_mask=None, need_weights=True, average=True):
        cfg = self.cfg
        is_batched = q.dim() == 3
        if cfg.batch_first and is_batched:
            # internal layout is (seq, batch, feature)
            q, k, v = [x.transpose(1, 0) for x in (q, k, v)]
        y, w = self.multi_head_attention_forward(
            q, k, v, mask, k_mask, cfg.add_zero_attn,
            need_weights=need_weights, average=average,
        )
        if cfg.batch_first and is_batched:
            return y.transpose(1, 0), w
        return y, w

    def project_packed(self, q, k, v):
        w, b = self.w_pack, self.b_pack
        if k is v:
            if q is k:
                # self-attention: one matmul, then split into q/k/v
                return F.linear(q, w, b).chunk(3, dim=-1)
            # shared key/value: project q separately from the packed k/v block
            H = q.size(-1)
            w_q, w_kv = w.split([H, H * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([H, H * 2])
            return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)

    def project(self, q, k, v, bs):
        w_q, w_k, w_v = self.w_q, self.w_k, self.w_v
        H, Dk, Dv = q.size(-1), k.size(-1), v.size(-1)
        assert w_q.shape == (H, H) and w_k.shape == (H, Dk) and w_v.shape == (H, Dv)
        b_q, b_k, b_v = bs
        assert b_q is None or b_q.shape == (H,)
        assert b_k is None or b_k.shape == (H,)
        assert b_v is None or b_v.shape == (H,)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)

    def attention(self, q, k, v, mask=None):
        cfg = self.cfg
        # q, k, v: (batch * heads, seq, d_head)
        B, Nt, H = q.shape
        q = q / math.sqrt(H)
        w = torch.bmm(q, k.transpose(-2, -1))
        if mask is not None:
            w += mask
        w = F.softmax(w, dim=-1)
        if self.training and cfg.drop > 0.0:
            w = F.dropout(w, p=cfg.drop)
        y = torch.bmm(w, v)
        return y, w

    def is_batched(self, q, k, v, k_mask, mask):
        if q.dim() == 3:
            assert k.dim() == 3 and v.dim() == 3
            if k_mask is not None:
                assert k_mask.dim() == 2
            if mask is not None:
                assert mask.dim() in (2, 3)
            return True
        assert q.dim() == 2
        assert k.dim() == 2 and v.dim() == 2
        if k_mask is not None:
            assert k_mask.dim() == 1
        if mask is not None:
            assert mask.dim() in (2, 3)
            if mask.dim() == 3:
                assert mask.shape == (self.cfg.n_heads, q.shape[0], k.shape[0])
        return False

    def multi_head_attention_forward(
        self, q, k, v, mask=None, k_mask=None, add_zero_attn=None,
        need_weights=True, static_k=None, static_v=None, average=True,
    ):
        if not self.is_batched(q, k, v, k_mask, mask):
            # promote unbatched inputs to a batch of one
            q = q.unsqueeze(1)
            k = k.unsqueeze(1)
            v = v.unsqueeze(1)
            if k_mask is not None:
                k_mask = k_mask.unsqueeze(0)
        cfg = self.cfg
        h, n = cfg.d_model, cfg.n_heads
        # b_k / b_v are the learned sequences appended to key/value (add_b_kv);
        # they are distinct from the packed projection bias b_pack split below
        b_k, b_v = self.b_k, self.b_v
        if self.pack:
            assert k.shape == v.shape
            q, k, v = self.project_packed(q, k, v)
        else:
            assert k.shape[:2] == v.shape[:2]
            if self.b_pack is None:
                bs = (None, None, None)
            else:
                bs = self.b_pack.chunk(3)
            q, k, v = self.project(q, k, v, bs)
        d_tgt, d_batch, _ = q.shape
        d_src, _, _ = k.shape
        if mask is not None:
            assert mask.is_floating_point() or mask.dtype == torch.bool
            if mask.dim() == 2:
                assert mask.shape == (d_tgt, d_src)
                mask = mask.unsqueeze(0)
            else:
                assert mask.shape == (d_batch * n, d_tgt, d_src)
        if b_k is not None and b_v is not None:
            assert static_k is None
            assert static_v is None
            k = torch.cat([k, b_k.repeat(1, d_batch, 1)])
            v = torch.cat([v, b_v.repeat(1, d_batch, 1)])
            if mask is not None:
                mask = pad(mask, (0, 1))
            if k_mask is not None:
                k_mask = pad(k_mask, (0, 1))
        else:
            assert b_k is None
            assert b_v is None
        d_head = h // n
        # fold heads into the batch dimension: (seq, batch * heads, d_head) -> (batch * heads, seq, d_head)
        q = q.contiguous().view(d_tgt, d_batch * n, d_head).transpose(0, 1)
        if static_k is None:
            k = k.contiguous().view(k.shape[0], d_batch * n, d_head).transpose(0, 1)
        else:
            assert static_k.size(0) == d_batch * n
            assert static_k.size(2) == d_head
            k = static_k
        if static_v is None:
            v = v.contiguous().view(v.shape[0], d_batch * n, d_head).transpose(0, 1)
        else:
            assert static_v.size(0) == d_batch * n
            assert static_v.size(2) == d_head
            v = static_v
        if add_zero_attn:
            # append an all-zero key/value slot so attention can "attend to nothing"
            zero_attn_shape = (d_batch * n, 1, d_head)
            k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
            v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
            if mask is not None:
                mask = pad(mask, (0, 1))
            if k_mask is not None:
                k_mask = pad(k_mask, (0, 1))
        d_src = k.size(1)
        if k_mask is not None:
            assert k_mask.shape == (d_batch, d_src)
            k_mask = (
                k_mask.view(d_batch, 1, 1, d_src)
                .expand(-1, n, -1, -1)
                .reshape(d_batch * n, 1, d_src)
            )
            if mask is None:
                mask = k_mask
            elif mask.dtype == torch.bool:
                mask = mask.logical_or(k_mask)
            else:
                mask = mask.masked_fill(k_mask, float("-inf"))
        if mask is not None and mask.dtype == torch.bool:
            # convert boolean masks to additive float masks
            mask = torch.zeros_like(mask, dtype=q.dtype).masked_fill_(mask, float("-inf"))
        y, w = self.attention(q, k, v, mask)
        y = y.transpose(0, 1).contiguous().view(d_tgt, d_batch, h)
        y = F.linear(y, self.out.weight, self.out.bias)
        if need_weights:
            w = w.view(d_batch, n, d_tgt, d_src)
            if average:
                w = w.sum(dim=1) / n
            return y, w
        return y, None
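
# A minimal, self-contained sketch (illustrative only, not part of the class above) of the
# scaled dot-product step that Attention.attention() implements, using the
# (batch * heads, seq, d_head) layout produced by multi_head_attention_forward.
# The sizes below are arbitrary placeholders.
if __name__ == "__main__":
    bn, nt, ns, d_head = 16, 10, 12, 64  # batch * heads, tgt len, src len, head width
    q = torch.randn(bn, nt, d_head)
    k = torch.randn(bn, ns, d_head)
    v = torch.randn(bn, ns, d_head)
    w = torch.bmm(q / math.sqrt(d_head), k.transpose(-2, -1))  # (bn, nt, ns) attention logits
    w = F.softmax(w, dim=-1)                                   # normalized weights (no mask, no dropout)
    y = torch.bmm(w, v)                                        # (bn, nt, d_head) attended values
    assert y.shape == (bn, nt, d_head)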