Example #1

import importlib
import numbers

import torch
import torch.nn.functional as F
from torch.nn import Parameter, init

# Note: FusedLayerNormAffineFunction and FusedLayerNormFunction are assumed to be
# apex-style autograd Functions defined elsewhere in the module, wrapping the
# fused_layer_norm_cuda extension.

class ResidualFusedLayerNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(ResidualFusedLayerNorm, self).__init__()

        global fused_layer_norm_cuda
        fused_layer_norm_cuda = importlib.import_module(
            "fused_layer_norm_cuda")

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (int(normalized_shape / 64), 1)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self._weight = Parameter(torch.Tensor(*normalized_shape))
            self.aweight = Parameter(torch.zeros(8, 63))
            self._bias = Parameter(torch.Tensor(*normalized_shape))
            self.abias = Parameter(torch.zeros(8, 63))
            self.normalized_shape = torch.Size((8, 64))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.elementwise_affine:
            init.ones_(self._weight)
            init.zeros_(self._bias)

    def forward(self, input):
        # Expand the base column self._weight into the full (8, 64) weight by
        # cumulatively adding the 63 learned residual columns in self.aweight.
        aweight = list(self.aweight.chunk(63, 1))
        pre = self._weight
        for i in range(63):
            aweight[i] = pre + aweight[i]
            pre = aweight[i]
        aweight = torch.cat(tuple(aweight), -1)
        self.weight = torch.cat((self._weight, aweight), dim=-1)
        # Same cumulative expansion for the bias.
        abias = list(self.abias.chunk(63, 1))
        pre = self._bias
        for i in range(63):
            abias[i] = pre + abias[i]
            pre = abias[i]
        abias = torch.cat(tuple(abias), -1)
        self.bias = torch.cat((self._bias, abias), dim=-1)
        if not input.is_cuda:
            return F.layer_norm(input, self.normalized_shape, self.weight,
                                self.bias, self.eps)
        if self.elementwise_affine:
            return FusedLayerNormAffineFunction.apply(input, self.weight,
                                                      self.bias,
                                                      self.normalized_shape,
                                                      self.eps)
        else:
            return FusedLayerNormFunction.apply(input, self.normalized_shape,
                                                self.eps)

    def extra_repr(self):
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
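
A minimal usage sketch for Example #1 (an assumption, not part of the source): it presumes the fused_layer_norm_cuda extension is built and a CUDA device is available, and that an integral normalized_shape of 512 yields the hard-coded (8, 64) normalized shape, so the trailing input dimensions must be (8, 64).

# Hypothetical usage; requires the fused_layer_norm_cuda extension and a GPU.
ln = ResidualFusedLayerNorm(normalized_shape=512).cuda()
x = torch.randn(16, 8, 64, device="cuda")  # trailing dims match the (8, 64) normalized shape
y = ln(x)                                  # CUDA input dispatches to FusedLayerNormAffineFunction
print(y.shape)                             # torch.Size([16, 8, 64])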
Example #2

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, Parameter
from torch.nn.functional import pad

# Note: qc is assumed to be the surrounding project's module framework
# (providing qc.Module and qc.Hypers) and is not imported here.

class Attention(qc.Module):
    hs = qc.Hypers(
        ["d_model", "n_heads", "d_k", "d_v"],
        {
            "add_b_kv": False,
            "add_zero_attn": False,
            "batch_first": False,
            "bias": True,
            "drop": 0.0,
        },
    )

    w_pack = w_q = w_k = w_v = None
    b_pack = b_q = b_k = b_v = None

    def __init__(self, n_heads, d_model, hs=[], **kw):
        if n_heads is not None:
            kw.update(n_heads=n_heads)
        if d_model is not None:
            kw.update(d_model=d_model)
        super().__init__([self.hs] + hs, **kw)
        cfg = self.cfg
        n, h = cfg.n_heads, cfg.d_model
        assert h % n == 0
        d_k = cfg.d_k if cfg.d_k is not None else h
        d_v = cfg.d_v if cfg.d_v is not None else h
        self.pack = d_k == h and d_v == h
        kw = {"device": cfg.device, "dtype": cfg.dtype}
        if self.pack:
            self.w_pack = Parameter(torch.empty((3 * h, h), **kw))
            self.register_parameter("w_q", None)
            self.register_parameter("w_k", None)
            self.register_parameter("w_v", None)
        else:
            self.register_parameter("w_pack", None)
            self.w_q = Parameter(torch.empty((h, h), **kw))
            self.w_k = Parameter(torch.empty((h, d_k), **kw))
            self.w_v = Parameter(torch.empty((h, d_v), **kw))
        if cfg.bias:
            self.b_pack = Parameter(torch.empty(3 * h, **kw))
        else:
            self.register_parameter("b_pack", None)
        self.out = Linear(h, h, bias=cfg.bias, **kw)
        if cfg.add_b_kv:
            self.b_k = Parameter(torch.empty((1, 1, h), **kw))
            self.b_v = Parameter(torch.empty((1, 1, h), **kw))
        else:
            self.register_parameter("b_k", None)
            self.register_parameter("b_v", None)

    def build(self, _):
        if not self.is_built():
            with torch.no_grad():
                self.reset_params()

    def reset_params(self):
        if self.pack:
            nn.init.xavier_uniform_(self.w_pack)
        else:
            nn.init.xavier_uniform_(self.w_q)
            nn.init.xavier_uniform_(self.w_k)
            nn.init.xavier_uniform_(self.w_v)
        if self.b_pack is not None:
            nn.init.constant_(self.b_pack, 0.0)
            nn.init.constant_(self.out.bias, 0.0)
        if self.b_k is not None:
            nn.init.xavier_normal_(self.b_k)
        if self.b_v is not None:
            nn.init.xavier_normal_(self.b_v)

    def forward(self, q, k, v, mask=None, k_mask=None, need_weights=True, average=True):
        cfg = self.cfg
        is_batched = q.dim() == 3
        if cfg.batch_first and is_batched:
            q, k, v = [x.transpose(1, 0) for x in (q, k, v)]
        if self.pack:
            y, w = self.multi_head_attention_forward(
                q,
                k,
                v,
                mask,
                k_mask,
                cfg.add_zero_attn,
                need_weights=need_weights,
                average=average,
            )
        else:
            y, w = self.multi_head_attention_forward(
                q,
                k,
                v,
                mask,
                k_mask,
                cfg.add_zero_attn,
                need_weights=need_weights,
                average=average,
            )
        if cfg.batch_first and is_batched:
            return y.transpose(1, 0), w
        else:
            return y, w

    def project_packed(self, q, k, v):
        w, b = self.w_pack, self.b_pack
        if k is v:
            if q is k:
                return F.linear(q, w, b).chunk(3, dim=-1)
            else:
                H = q.size(-1)
                w_q, w_kv = w.split([H, H * 2])
                if b is None:
                    b_q = b_kv = None
                else:
                    b_q, b_kv = b.split([H, H * 2])
                return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
        else:
            w_q, w_k, w_v = w.chunk(3)
            if b is None:
                b_q = b_k = b_v = None
            else:
                b_q, b_k, b_v = b.chunk(3)
            return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)

    def project(self, q, k, v, bs):
        w_q, w_k, w_v = self.w_q, self.w_k, self.w_v
        H, Dk, Dv = q.size(-1), k.size(-1), v.size(-1)
        assert w_q.shape == (H, H) and w_k.shape == (H, Dk) and w_v.shape == (H, Dv)
        b_q, b_k, b_v = bs
        assert b_q is None or b_q.shape == (H,)
        assert b_k is None or b_k.shape == (H,)
        assert b_v is None or b_v.shape == (H,)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)

    def attention(self, q, k, v, mask=None):
        cfg = self.cfg
        B, Nt, H = q.shape
        q = q / math.sqrt(H)
        w = torch.bmm(q, k.transpose(-2, -1))
        if mask is not None:
            w += mask
        w = F.softmax(w, dim=-1)
        if self.training and cfg.drop > 0.0:
            w = F.dropout(w, p=cfg.drop)
        y = torch.bmm(w, v)
        return y, w

    def is_batched(self, q, k, v, k_mask, mask):
        if q.dim() == 3:
            assert k.dim() == 3 and v.dim() == 3
            if k_mask is not None:
                assert k_mask.dim() == 2
            if mask is not None:
                assert mask.dim() in (2, 3)
            return True
        assert q.dim() == 2
        assert k.dim() == 2 and v.dim() == 2
        if k_mask is not None:
            assert k_mask.dim() == 1
        if mask is not None:
            assert mask.dim() in (2, 3)
            if mask.dim() == 3:
                assert mask.shape == (self.cfg.n_heads, q.shape[0], k.shape[0])
        return False

    def multi_head_attention_forward(
        self,
        q,
        k,
        v,
        mask=None,
        k_mask=None,
        add_zero_attn=None,
        need_weights=True,
        static_k=None,
        static_v=None,
        average=True,
    ):
        if not self.is_batched(q, k, v, k_mask, mask):
            q = q.unsqueeze(1)
            k = k.unsqueeze(1)
            v = v.unsqueeze(1)
            if k_mask is not None:
                k_mask = k_mask.unsqueeze(0)
        cfg = self.cfg
        h, n = cfg.d_model, cfg.n_heads
        b_q, b_k, b_v = self.b_q, self.b_k, self.b_v
        if self.pack:
            assert k.shape == v.shape
            q, k, v = self.project_packed(q, k, v)
        else:
            assert k.shape[:2] == v.shape[:2]
            # Use separate names for the per-projection biases so that b_k / b_v
            # keep referring to the add_b_kv biases appended to k and v below.
            if self.b_pack is None:
                p_q = p_k = p_v = None
            else:
                p_q, p_k, p_v = self.b_pack.chunk(3)
            q, k, v = self.project(q, k, v, (p_q, p_k, p_v))
        d_tgt, d_batch, _ = q.shape
        d_src, _, _ = k.shape
        if mask is not None:
            assert mask.is_floating_point() or mask.dtype == torch.bool
            if mask.dim() == 2:
                assert mask.shape == (d_tgt, d_src)
                mask = mask.unsqueeze(0)
            else:
                assert mask.shape == (d_batch * n, d_tgt, d_src)
        if b_k is not None and b_v is not None:
            assert static_k is None
            assert static_v is None
            k = torch.cat([k, b_k.repeat(1, d_batch, 1)])
            v = torch.cat([v, b_v.repeat(1, d_batch, 1)])
            if mask is not None:
                mask = pad(mask, (0, 1))
            if k_mask is not None:
                k_mask = pad(k_mask, (0, 1))
        else:
            assert b_k is None
            assert b_v is None
        d_head = h // n
        q = q.contiguous().view(d_tgt, d_batch * n, d_head).transpose(0, 1)
        if static_k is None:
            k = k.contiguous().view(k.shape[0], d_batch * n, d_head).transpose(0, 1)
        else:
            assert static_k.size(0) == d_batch * n
            assert static_k.size(2) == d_head
            k = static_k
        if static_v is None:
            v = v.contiguous().view(v.shape[0], d_batch * n, d_head).transpose(0, 1)
        else:
            assert static_v.size(0) == d_batch * n
            assert static_v.size(2) == d_head
            v = static_v
        if add_zero_attn:
            zero_attn_shape = (d_batch * n, 1, d_head)
            k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
            v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
            if mask is not None:
                mask = pad(mask, (0, 1))
            if k_mask is not None:
                k_mask = pad(k_mask, (0, 1))
        d_src = k.size(1)
        if k_mask is not None:
            assert k_mask.shape == (d_batch, d_src)
            k_mask = (
                k_mask.view(d_batch, 1, 1, d_src)
                .expand(-1, n, -1, -1)
                .reshape(d_batch * n, 1, d_src)
            )
            if mask is None:
                mask = k_mask
            elif mask.dtype == torch.bool:
                mask = mask.logical_or(k_mask)
            else:
                mask = mask.masked_fill(k_mask, float("-inf"))
        if mask is not None and mask.dtype == torch.bool:
            mask = torch.zeros_like(mask, dtype=q.dtype).masked_fill_(mask, float("-inf"))
        y, w = self.attention(q, k, v, mask)  # scaled dot-product attention defined above
        y = y.transpose(0, 1).contiguous().view(d_tgt, d_batch, h)
        y = F.linear(y, self.out.weight, self.out.bias)
        if need_weights:
            w = w.view(d_batch, n, d_tgt, d_src)
            if average:
                w = w.sum(dim=1) / n
            return y, w
        else:
            return y, None
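
A minimal usage sketch for Example #2 (an assumption, not part of the source): it presumes qc.Module behaves like torch.nn.Module, that cfg.device and cfg.dtype default to None, and that build() may be called manually to initialize the empty parameters.

# Hypothetical usage under the assumptions above.
attn = Attention(n_heads=8, d_model=512)
attn.build(None)              # runs reset_params() on the uninitialized parameters
q = torch.randn(10, 4, 512)   # (seq_len, batch, d_model) since batch_first=False
y, w = attn(q, q, q)          # packed self-attention path
print(y.shape, w.shape)       # torch.Size([10, 4, 512]) torch.Size([4, 10, 10])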