示例#1
0
    def forward(
        self,
        query: flow.Tensor,
        key: flow.Tensor,
        value: flow.Tensor,
        mask: Optional[flow.Tensor] = None,
    ) -> Tuple[flow.Tensor, flow.Tensor]:
        """Run multi-head attention over the given query/key/value tensors.

        Each input is first passed through its own projection
        (``self.query`` / ``self.key`` / ``self.value``), split into
        ``self.num_attention_heads`` heads of ``self.dims_per_head`` features,
        handed to ``self.attention`` and finally merged back and passed
        through ``self.dense``.

        Args:
            query: tensor whose first dimension is the batch size.
                Presumably ``(batch, seq_len, hidden)`` -- confirm with caller.
            key: tensor projected by ``self.key``.
            value: tensor projected by ``self.value``.
            mask: optional mask forwarded unchanged to ``self.attention`` as
                ``attn_mask``; its expected shape is defined by that module.

        Returns:
            A pair ``(output, attention)``: the result of ``self.dense`` on the
            merged heads, and the second value returned by ``self.attention``.
        """
        n_batch = query.size(0)

        # Linear projections, in the same order as the inputs.
        q_proj = self.query(query)
        k_proj = self.key(key)
        v_proj = self.value(value)

        def split_heads(projected):
            # (batch, -1, heads, dims_per_head) -> swap dims 1 and 2 so the
            # head dimension comes right after the batch dimension.
            per_head = projected.view(
                n_batch, -1, self.num_attention_heads, self.dims_per_head
            )
            return per_head.transpose(1, 2)

        q_heads = split_heads(q_proj)
        k_heads = split_heads(k_proj)
        v_heads = split_heads(v_proj)

        context, attention = self.attention(
            q_heads, k_heads, v_heads, attn_mask=mask
        )

        # Undo the head split: move heads back next to the feature dimension
        # and flatten them into ``hidden_size``. ``contiguous`` is required
        # before ``view`` because ``transpose`` returns a strided view.
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(n_batch, -1, self.hidden_size)

        return self.dense(merged), attention
示例#2
0
def topk_accuracy(output: Tensor, target: Tensor,
                  topk: Sequence[int] = (1, )) -> List[Tensor]:
    """Compute top-k classification accuracy (in percent) for each k.

    Adapted from
    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L411

    Args:
        output: [B, C] scores for C-way classification.
        target: [B] class indices, or [B, C] one-hot (or soft) labels, in
            which case the class with the largest entry per row is used.
        topk: the values of k to evaluate; each must be <= C.

    Returns:
        A list with one 0-dim float tensor per entry of ``topk``, holding the
        percentage of samples whose target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    if target.ndim == 2:
        # Possibly one-hot target: recover the class *index* of each row.
        # (Taking the row maximum's value would yield 1 for every sample.)
        target = target.argmax(dim=1)

    # pred: [maxk, B] after the transpose, row i holding the (i+1)-th best
    # class for every sample.
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # A sample is a top-k hit if any of its first k predictions matches;
        # at most one can, so summing the boolean hits counts samples.
        correct_k = correct[:k].reshape(-1).float().sum()
        res.append(correct_k.mul(100.0 / batch_size))

    return res