Example #1
    def forward(self, queries, keys, values, attn_mask, visualize=False):
        # q, k, v have each passed through a linear layer, but they start out as the same input
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
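        # scores has shape (B, H, L, S): one attention logit per head and query-key pair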
        # How can this be visualized? Here the raw scores are saved to disk when requested.
        if visualize:
            attention_weight = scores.to('cpu').detach().numpy().copy()
            np.save("./results/attention/attention_weight.npy",
                    attention_weight)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)

            # Dynamic network? Does the structure change depending on the incoming data?
            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return (V.contiguous(), A)
        else:
            return (V.contiguous(), None)
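Examples #1-#3 build the mask with `TriangularCausalMask(B, L, device=...)` and read its `.mask` attribute, which must be a boolean tensor of shape (B, 1, L, L) that is True at the future positions a query is not allowed to attend to. A minimal sketch compatible with that usage (not necessarily the exact class from the original repository):

import torch

class TriangularCausalMask:
    """Boolean mask, True strictly above the diagonal, i.e. at the
    future positions a causal query must not attend to."""
    def __init__(self, B, L, device="cpu"):
        mask_shape = [B, 1, L, L]  # the singleton dim broadcasts over heads
        with torch.no_grad():
            self._mask = torch.triu(
                torch.ones(mask_shape, dtype=torch.bool, device=device),
                diagonal=1)

    @property
    def mask(self):
        return self._mask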
Example #2
    def forward(self, queries, keys, values, attn_mask):
        B, L_Q, H, E = queries.shape
        _, L, _, _ = keys.shape
        # _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L_Q, device=queries.device)

            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        return A
Example #3
  def forward(self, queries, keys, values, attn_mask):
    B, L, H, E = queries.shape
    _, S, _, D = values.shape
    scale = self.scale or 1./sqrt(E)

    scores = torch.einsum("blhe,bshe->bhls", queries, keys)
    if self.mask_flag:
      if attn_mask is None:
        attn_mask = TriangularCausalMask(B, L, device=queries.device)
      scores.masked_fill_(attn_mask.mask, -np.inf)

    A = self.dropout(torch.softmax(scale * scores, dim=-1))
    V = torch.einsum("bhls,bshd->blhd", A, values)

    if self.output_attention:
      return (V.contiguous(), A)
    else:
      return (V.contiguous(), None)
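All of these snippets assume a surrounding module that provides `self.scale`, `self.mask_flag`, `self.dropout` and `self.output_attention`. A minimal, self-contained sketch of such a wrapper (the class name `FullAttentionSketch` is hypothetical, not from the original repository), reusing the TriangularCausalMask sketch above and called on random tensors to show the (batch, length, heads, dim) shape convention:

import numpy as np
import torch
import torch.nn as nn
from math import sqrt

class FullAttentionSketch(nn.Module):
    def __init__(self, mask_flag=True, scale=None,
                 attention_dropout=0.1, output_attention=False):
        super().__init__()
        self.mask_flag = mask_flag
        self.scale = scale
        self.dropout = nn.Dropout(attention_dropout)
        self.output_attention = output_attention

    def forward(self, queries, keys, values, attn_mask=None):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                # TriangularCausalMask as sketched after Example #1
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)
        return V.contiguous(), (A if self.output_attention else None)

B, L, H, E = 2, 16, 8, 64
q = torch.randn(B, L, H, E)
k = torch.randn(B, L, H, E)
v = torch.randn(B, L, H, E)
out, attn = FullAttentionSketch(output_attention=True)(q, k, v)
print(out.shape)   # torch.Size([2, 16, 8, 64])
print(attn.shape)  # torch.Size([2, 8, 16, 16])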
Example #4
    def call(self, inputs, attn_mask=None):
        queries, keys, values = inputs
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        scores = tf.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L)

            # https://stackoverflow.com/questions/47447272/does-tensorflow-have-the-function-similar-to-pytorchs-masked-fill
            num = 3.4 * math.pow(10, 38)
            scores = (scores * attn_mask.mask) + (-((attn_mask.mask * num + num) - num))

        A = self.dropout(tf.keras.activations.softmax(scale * scores, axis=-1))
        V = tf.einsum("bhls,bshd->blhd", A, values)

        return V
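The arithmetic with `num` emulates PyTorch's `masked_fill_` in TensorFlow, following the linked Stack Overflow answer. Assuming `attn_mask.mask` is boolean with True at the positions to block (the convention of the PyTorch examples), a more direct sketch of just the masking step uses `tf.where`; the shapes here are made up for illustration:

import tensorflow as tf

# Standalone illustration of the masking step only.
B, H, L, S = 2, 8, 16, 16
scores = tf.random.normal([B, H, L, S])
# True strictly above the diagonal, i.e. at the future positions to block.
mask = tf.equal(tf.linalg.band_part(tf.ones([L, S]), -1, 0), 0)
masked_scores = tf.where(mask, tf.fill(tf.shape(scores), -1e9), scores)
A = tf.keras.activations.softmax(masked_scores, axis=-1)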