def forward(
    self,
    query: flow.Tensor,
    key: flow.Tensor,
    value: flow.Tensor,
    mask: Optional[flow.Tensor] = None,
) -> Tuple[flow.Tensor, flow.Tensor]:
    """Multi-head attention forward pass.

    Args:
        query, key, value: input tensors; assumed shape
            [batch, seq_len, hidden] given the view/transpose reshapes
            below — TODO confirm against callers.
        mask: optional attention mask forwarded to ``self.attention``
            as ``attn_mask``.

    Returns:
        A ``(output, attention)`` pair: the dense-projected context and
        the attention weights produced by ``self.attention``.
    """
    bsz = query.size(0)

    # Linear projections for Q, K, V.
    q = self.query(query)
    k = self.key(key)
    v = self.value(value)

    # Split into heads: [B, T, hidden] -> [B, num_heads, T, dims_per_head].
    def split_heads(t: flow.Tensor) -> flow.Tensor:
        return t.view(
            bsz, -1, self.num_attention_heads, self.dims_per_head
        ).transpose(1, 2)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # Per-head attention (implementation lives in self.attention).
    context, attention = self.attention(q, k, v, attn_mask=mask)

    # Merge heads back: [B, num_heads, T, d] -> [B, T, hidden_size].
    context = (
        context.transpose(1, 2).contiguous().view(bsz, -1, self.hidden_size)
    )

    output = self.dense(context)
    return output, attention
def topk_accuracy(output: Tensor, target: Tensor, topk: Sequence[int] = (1, )) -> List[Tensor]:
    """Compute top-k accuracy (as a percentage) for each k in ``topk``.

    Adapted from
    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L411

    Args:
        output: [B, C] scores/logits for C-way classification.
        target: [B] integer class indices, or [B, C] one-hot labels.
        topk: the k values to report accuracy for.

    Returns:
        One scalar tensor per k, each the percentage of samples whose
        true class appears among the model's top-k predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    if target.ndim == 2:
        # Possibly one-hot target: recover the class index per row.
        # BUGFIX: use ``.indices`` (the argmax), not ``.values`` —
        # ``.values`` collapses every one-hot row to 1, breaking the
        # comparison against predicted class indices below.
        target = target.max(dim=1).indices
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()  # [maxk, B]: row r holds each sample's rank-r prediction
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # A sample counts as correct at k if any of its top-k rows match.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=False)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res