Example #1
def make_features(x: Tensor[DType, N]) -> Tensor[DType, N, D4]:
    """Builds features i.e. a matrix with columns [x, x^2, x^3, x^4]."""
    # x = x.unsqueeze(1)
    x2 = torch.unsqueeze(x, 1)
    # original: return torch.cat([x ** i for i in range(1, POLY_DEGREE + 1)], 1)
    r: Tensor[DType, N, D4] = torch.cat(
        [x2**i for i in range(1, POLY_DEGREE + 1)], 1
    )
    return r
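
A minimal, hypothetical usage sketch (not part of the original example): it assumes POLY_DEGREE = 4, which is what the D4 dimension in the return annotation implies, and checks the runtime shape that make_features produces.

import torch

POLY_DEGREE = 4  # assumption: matches the D4 dimension in the return annotation

x = torch.randn(32)          # runtime shape (32,), i.e. Tensor[DType, N] with N = 32
features = make_features(x)  # runtime shape (32, 4), i.e. Tensor[DType, N, D4]
assert features.shape == (32, 4)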
Example #2
    def forward(
        self,
        input: Tensor[BS, QLEN, DIM],
        mask: Union[Tensor[BS, KLEN], Tensor[BS, KLEN, KLEN]],
        kv: Optional[Tensor[BS, KLEN, DIM]],
        cache: Optional[
            Dict[
                int,
                Tuple[
                    Tensor[BS, N_HEADS, Any, DIM_PER_HEAD],
                    Tensor[BS, N_HEADS, Any, DIM_PER_HEAD],
                ],
            ]
        ],
        cache_slen: int,
        dim_per_head: DIM_PER_HEAD,
    ) -> Tensor[BS, QLEN, DIM]:
        """
        Self-attention (if kv is None) or attention
        over source sentence (provided by kv).
        """
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
        bs, qlen, dim = input.size()

        if kv is None:
            klen = qlen if cache is None else cache_slen + qlen
        else:
            klen = kv.size(1)

        # original: dim_per_head = dim // self.n_heads
        # dim_per_head cannot be a literal here, so it is passed in as an argument instead
        mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

        def shape(x: Tensor[BS, Any, DIM]) -> Tensor[BS, N_HEADS, Any, DIM_PER_HEAD]:
            """projection"""
            # variables defined outside the function body are not typed,
            # so their types are declared again here
            bs: BS
            dim_per_head: DIM_PER_HEAD
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(
            x: Tensor[BS, N_HEADS, QLEN, DIM_PER_HEAD]
        ) -> Tensor[BS, QLEN, Any]:
            """compute context"""
            return (
                x.transpose(1, 2)
                .contiguous()
                .view(bs, -1, mult(self.n_heads, dim_per_head))
            )

        q = shape(self.q_lin(input))  # (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k_lin(input))  # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(input))  # (bs, n_heads, qlen, dim_per_head)
        elif cache is None or self.layer_id not in cache:
            k = v = kv
            k = shape(self.k_lin(k))  # (bs, n_heads, klen, dim_per_head)
            v = shape(self.v_lin(v))  # (bs, n_heads, klen, dim_per_head)

        if cache is not None:
            if self.layer_id in cache:
                if kv is None:
                    k_, v_ = cache[self.layer_id]
                    k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
                    v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
                else:
                    k, v = cache[self.layer_id]
            cache[self.layer_id] = (k, v)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
        scores: Tensor[BS, N_HEADS, QLEN, KLEN] = torch.matmul(
            q, k.transpose(2, 3)
        )  # (bs, n_heads, qlen, klen)
        mask2 = (
            (mask == 0).view(mask_reshape).expand_as(scores)
        )  # (bs, n_heads, qlen, klen)
        scores2 = scores.masked_fill(mask2, -float("inf"))  # (bs, n_heads, qlen, klen)
        weights = F.softmax(scores2.float(), dim=-1).type_as(
            scores
        )  # (bs, n_heads, qlen, klen)
        weights = F.dropout(
            weights, p=self.dropout, training=self.training
        )  # (bs, n_heads, qlen, klen)
        context: Tensor[BS, N_HEADS, QLEN, DIM_PER_HEAD] = torch.matmul(
            weights, v
        )  # (bs, n_heads, qlen, dim_per_head)
        context2: Tensor[BS, QLEN, DIM] = unshape(context)  # (bs, qlen, dim)

        return self.out_lin(context2)
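
As a side note (not part of the original code), the shape/unshape helpers above are a plain view/transpose round trip; the following standalone PyTorch sketch, with made-up sizes, shows the layouts involved.

import torch

# Illustrative sizes only; bs, qlen, n_heads and dim_per_head are assumptions
bs, qlen, n_heads, dim_per_head = 2, 3, 4, 8
dim = n_heads * dim_per_head

x = torch.randn(bs, qlen, dim)                                 # (bs, qlen, dim)
heads = x.view(bs, -1, n_heads, dim_per_head).transpose(1, 2)  # (bs, n_heads, qlen, dim_per_head)
back = heads.transpose(1, 2).contiguous().view(bs, -1, dim)    # (bs, qlen, dim)
assert torch.equal(x, back)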