Example #1
    def perform_attention(self, query, dim_per_head, key, relations_keys, mask,
                          value, relations_values, batch_size, head_count,
                          query_len, key_len, shape, unshape, attn_type=None):

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        # batch x num_heads x query_len x key_len
        query_key = torch.matmul(query, key.transpose(2, 3))

        if self.max_relative_positions > 0 and attn_type == "self":
            scores = query_key + relative_matmul(query, relations_keys, True)
        else:
            scores = query_key
        scores = scores.float()

        if mask is not None:
            mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)
        drop_attn = self.dropout(attn)

        context_original = torch.matmul(drop_attn, value)

        if self.max_relative_positions > 0 and attn_type == "self":
            context = unshape(
                context_original +
                relative_matmul(drop_attn, relations_values, False))
        else:
            context = unshape(context_original)

        output = self.final_linear(context)

        # Return one attn
        top_attn = attn \
            .view(batch_size, head_count,
                  query_len, key_len)[:, 0, :, :] \
            .contiguous()

        return output, top_attn
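
A minimal sketch of what the `shape`/`unshape` callables passed into this fragment are expected to do (the same helpers are defined inline in the later examples): `shape` splits the model dimension into heads, `unshape` merges it back. Toy sizes below are assumptions for illustration.

import torch

batch_size, seq_len, head_count, dim_per_head = 2, 5, 4, 8
x = torch.randn(batch_size, seq_len, head_count * dim_per_head)

# shape: (batch, len, model_dim) -> (batch, heads, len, dim_per_head)
shaped = x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)
print(shaped.size())             # torch.Size([2, 4, 5, 8])

# unshape: inverse transform back to (batch, len, model_dim)
unshaped = shaped.transpose(1, 2).contiguous() \
    .view(batch_size, -1, head_count * dim_per_head)
print(torch.equal(x, unshaped))  # True
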
Example #2
    def forward(self,
                key,
                value,
                query,
                mask=None,
                layer_cache=None,
                attn_type=None,
                gold_par_attn=None,
                gold_ch_attn=None):
        """
        Compute the context vector and the attention vectors.

        Args:
           key (FloatTensor): set of `key_len`
               key vectors ``(batch, key_len, dim)``
           value (FloatTensor): set of `key_len`
               value vectors ``(batch, key_len, dim)``
           query (FloatTensor): set of `query_len`
               query vectors  ``(batch, query_len, dim)``
           mask: binary mask 1/0 indicating which keys have
               zero / non-zero attention ``(batch, query_len, key_len)``
        Returns:
           (FloatTensor, FloatTensor, FloatTensor, FloatTensor, FloatTensor):

           * output context vectors ``(batch, query_len, dim)``
           * attention of the first head ``(batch, query_len, key_len)``
           * attention of the second head ``(batch, query_len, key_len)``
           * predicted child labels (``None`` unless ``gold_ch_attn`` is given)
           * predicted parent labels (``None`` unless ``gold_ch_attn`` is given)
        """

        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        key_len = key.size(1)
        query_len = query.size(1)

        def shape(x):
            """Projection."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """Compute context."""
            return x.transpose(1, 2).contiguous() \
                    .view(batch_size, -1, head_count * dim_per_head)

        def predict_ch_label(attn, value):

            if not self.opt.biaffine:
                value = unshape(value)
            ch_attn_repeat = torch.repeat_interleave(attn, value.size(2), dim=2) \
                .view(value.size(0), value.size(1), value.size(1), value.size(2))
            value_repeat = torch.repeat_interleave(value, value.size(1), dim=1) \
                .view(value.size(0), value.size(1), value.size(1), value.size(2)).transpose(1,2).contiguous()
            chs = ch_attn_repeat * value_repeat
            ch_label_h = torch.cat([chs, value_repeat], 3)
            if self.opt.biaffine:
                w_ch = self.Wlabel_ch_linear(chs, value_repeat)
                b_ch = self.blabel_ch_linear(ch_label_h)
                ch_labels = w_ch + b_ch
            else:
                ch_labels = self.p_ch_label(ch_label_h)
            return ch_labels

        def predict_par_label(attn, value):
            if not self.opt.biaffine:
                value = unshape(value)
            par = torch.matmul(attn, value)
            par_label_h = torch.cat([par, value], 2)
            #par_label_h = torch.cat([value, par],2)
            if self.opt.biaffine:
                w_par = self.Wlabel_par_linear(par, value)
                b_par = self.blabel_par_linear(par_label_h)
                par_labels = w_par + b_par
            else:
                par_labels = self.p_par_label(par_label_h)
            return par_labels

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if attn_type == "self":
                query, key, value = self.linear_query(query),\
                                    self.linear_keys(query),\
                                    self.linear_values(query)
                key = shape(key)
                value = shape(value)
                if layer_cache["self_keys"] is not None:
                    key = torch.cat((layer_cache["self_keys"], key), dim=2)
                if layer_cache["self_values"] is not None:
                    value = torch.cat((layer_cache["self_values"], value),
                                      dim=2)
                layer_cache["self_keys"] = key
                layer_cache["self_values"] = value
            elif attn_type == "context":
                query = self.linear_query(query)
                if layer_cache["memory_keys"] is None:
                    key, value = self.linear_keys(key),\
                                 self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
                else:
                    key, value = layer_cache["memory_keys"],\
                               layer_cache["memory_values"]
                layer_cache["memory_keys"] = key
                layer_cache["memory_values"] = value
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        if self.max_relative_positions > 0 and attn_type == "self":
            key_len = key.size(2)
            # 1 or key_len x key_len
            relative_positions_matrix = generate_relative_positions_matrix(
                key_len,
                self.max_relative_positions,
                cache=True if layer_cache is not None else False)
            #  1 or key_len x key_len x dim_per_head
            relations_keys = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))
            #  1 or key_len x key_len x dim_per_head
            relations_values = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))

        query = shape(query)

        key_len = key.size(2)
        query_len = query.size(2)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)

        if self.label_emb is not None and self.opt.biaffine:
            query_key = torch.matmul(query[:, 2:], key.transpose(2, 3)[:, 2:])
            w_par = torch.matmul(self.Warc_par_linear(query[:, 0]),
                                 key.transpose(2, 3)[:, 0])
            w_ch = torch.matmul(self.Warc_ch_linear(query[:, 1]),
                                key.transpose(2, 3)[:, 1])
            b_par = self.barc_par_linear(query[:, 0]).repeat_interleave(
                query.size(2), dim=2)
            b_ch = self.barc_ch_linear(query[:, 1]).repeat_interleave(
                query.size(2), dim=2)
            arc_par = w_par + b_par
            arc_ch = w_ch + b_ch
            query_key = torch.cat(
                [arc_par.unsqueeze(1),
                 arc_ch.unsqueeze(1), query_key], dim=1)
        else:
            # batch x num_heads x query_len x key_len
            query_key = torch.matmul(query, key.transpose(2, 3))

        if self.max_relative_positions > 0 and attn_type == "self":
            scores = query_key + relative_matmul(query, relations_keys, True)
        else:
            scores = query_key
        scores = scores.float()
        if mask is not None:
            mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
            scores = scores.masked_fill(mask, -1e18)
        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)

        drop_attn = self.dropout(attn)

        context_original = torch.matmul(drop_attn, value)

        if self.max_relative_positions > 0 and attn_type == "self":
            context = unshape(
                context_original +
                relative_matmul(drop_attn, relations_values, False))
        else:
            context = unshape(context_original)

        ch_labels = None
        par_labels = None

        if gold_ch_attn is not None:
            if self.opt.biaffine:
                par_labels = predict_par_label(gold_par_attn, value[:, 0])
                ch_labels = predict_ch_label(gold_ch_attn, value[:, 1])
            else:
                par_labels = predict_par_label(gold_par_attn, value)
                ch_labels = predict_ch_label(gold_ch_attn, value)

        output = self.final_linear(context)
        top_attn = attn \
            .view(batch_size, head_count,
                  query_len, key_len)[:, 0, :, :] \
            .contiguous()
        second_attn = attn \
            .view(batch_size, head_count,
                  query_len, key_len)[:, 1, :, :] \
            .contiguous()

        return output, top_attn, second_attn, ch_labels, par_labels
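
A small sketch (toy tensors, hypothetical step loop) of how the `layer_cache["self_keys"]` / `layer_cache["self_values"]` entries above grow during incremental decoding: each step's freshly projected key/value, already shaped to (batch, heads, 1, dim_per_head), is concatenated onto the cache along the length dimension (dim=2).

import torch

batch, heads, dim_per_head = 2, 4, 8
layer_cache = {"self_keys": None, "self_values": None}

for step in range(3):
    # each decoding step contributes one new position per batch element
    key = torch.randn(batch, heads, 1, dim_per_head)
    value = torch.randn(batch, heads, 1, dim_per_head)
    if layer_cache["self_keys"] is not None:
        key = torch.cat((layer_cache["self_keys"], key), dim=2)
    if layer_cache["self_values"] is not None:
        value = torch.cat((layer_cache["self_values"], value), dim=2)
    layer_cache["self_keys"] = key
    layer_cache["self_values"] = value

print(layer_cache["self_keys"].size())   # torch.Size([2, 4, 3, 8])
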
Example #3
    def forward(self,
                key,
                value,
                query,
                mask=None,
                layer_cache=None,
                attn_type=None,
                decoder=False):
        """
        Compute the context vector and the attention vectors.

        Args:
           key (FloatTensor): set of `key_len`
               key vectors ``(batch, key_len, dim)``
           value (FloatTensor): set of `key_len`
               value vectors ``(batch, key_len, dim)``
           query (FloatTensor): set of `query_len`
               query vectors  ``(batch, query_len, dim)``
           mask: binary mask 1/0 indicating which keys have
               zero / non-zero attention ``(batch, query_len, key_len)``
           decoder: indicates the self-attention is coming from the decoder.
        Returns:
           (FloatTensor, FloatTensor):

           * output context vectors ``(batch, query_len, dim)``
           * Attention vector in heads ``(batch, head, query_len, key_len)``.
        """

        # CHECKS
        # batch, k_len, d = key.size()
        # batch_, k_len_, d_ = value.size()
        # aeq(batch, batch_)
        # aeq(k_len, k_len_)
        # aeq(d, d_)
        # batch_, q_len, d_ = query.size()
        # aeq(batch, batch_)
        # aeq(d, d_)
        # aeq(self.model_dim % 8, 0)
        # if mask is not None:
        #    batch_, q_len_, k_len_ = mask.size()
        #    aeq(batch_, batch)
        #    aeq(k_len_, k_len)
        #    aeq(q_len_ == q_len)
        # END CHECKS

        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        key_len = key.size(1)
        query_len = query.size(1)

        use_causal = decoder is True and attn_type == 'self' and self.training is True

        def shape(x):
            """Projection."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """Compute context."""
            return x.transpose(1, 2).contiguous() \
                    .view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if attn_type == "self":
                query, key, value = self.linear_query(query),\
                                    self.linear_keys(query),\
                                    self.linear_values(query)
                key = shape(key)
                value = shape(value)
                if layer_cache["self_keys"] is not None:
                    key = torch.cat((layer_cache["self_keys"], key), dim=2)
                if layer_cache["self_values"] is not None:
                    value = torch.cat((layer_cache["self_values"], value),
                                      dim=2)
                layer_cache["self_keys"] = key
                layer_cache["self_values"] = value
            elif attn_type == "context":
                query = self.linear_query(query)
                if layer_cache["memory_keys"] is None:
                    key, value = self.linear_keys(key),\
                                 self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
                else:
                    key, value = layer_cache["memory_keys"],\
                               layer_cache["memory_values"]
                layer_cache["memory_keys"] = key
                layer_cache["memory_values"] = value
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        if self.max_relative_positions > 0 and attn_type == "self":
            key_len = key.size(2)
            # 1 or key_len x key_len
            relative_positions_matrix = generate_relative_positions_matrix(
                key_len,
                self.max_relative_positions,
                cache=True if layer_cache is not None else False)
            #  1 or key_len x key_len x dim_per_head
            relations_keys = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))
            #  1 or key_len x key_len x dim_per_head
            relations_values = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))

        query = shape(query)

        key_len = key.size(2)
        query_len = query.size(2)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        # batch x num_heads x query_len x key_len
        query_key = torch.matmul(query, key.transpose(2, 3))
        # Elliott: mask out some backward pass.
        if use_causal:
            assert query_len == key_len
            bk_mask = 1.0 - torch.diag(
                torch.ones(query_len, device=mask.device)).unsqueeze(0).repeat(
                    [batch_size, 1, 1]).to(mask.dtype)
            # [bz, len, len]
            bk_mask = (bk_mask + mask).gt(0).to(query_key.dtype)
            # [bz, 1, len, len]
            incre_mask = (bk_mask.to(mask.dtype) - mask).to(
                query_key.dtype).unsqueeze(1)
            # [bz, 1, len, len]
            bk_mask = bk_mask.unsqueeze(1)
            # [bz, num_heads, len, len]
            query_key_detach = query_key.detach()
            # [bz, num_heads, len, len]
            query_key = bk_mask * query_key + incre_mask * query_key_detach

        if self.max_relative_positions > 0 and attn_type == "self":
            scores = query_key + relative_matmul(query, relations_keys, True)
        else:
            scores = query_key
        scores = scores.float()

        if mask is not None:
            mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)
        drop_attn = self.dropout(attn)

        if use_causal:
            # [bz, num_heads, q_len, k_len, 1] * [bz, num_heads, 1, k_len, dim] --> [bz, num_heads, q_len, k_len, dim]
            context_original = drop_attn.unsqueeze(-1) * value.unsqueeze(2)
            context_original_detach = context_original.detach()
            context_original = bk_mask.unsqueeze(
                -1) * context_original + incre_mask.unsqueeze(
                    -1) * context_original_detach
            # [bz, num_heads, q_len, dim]
            context_original = context_original.sum(3)

        else:
            context_original = torch.matmul(drop_attn, value)

        if self.max_relative_positions > 0 and attn_type == "self":
            context = unshape(
                context_original +
                relative_matmul(drop_attn, relations_values, False))
        else:
            context = unshape(context_original)

        output = self.final_linear(context)
        # CHECK
        # batch_, q_len_, d_ = output.size()
        # aeq(q_len, q_len_)
        # aeq(batch, batch_)
        # aeq(d, d_)

        # Return multi-head attn
        attns = attn \
            .view(batch_size, head_count,
                  query_len, key_len)

        return output, attns
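
A minimal sketch (toy 1-D tensors, assumed names) of the detach-and-mask pattern behind the `use_causal` branch above, where `bk_mask`/`incre_mask` weight `query_key` and `context_original` against their detached copies: in the sketch the two masks are complementary, so the forward values are unchanged while gradients are blocked through the positions covered by the detached copy.

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
keep = torch.tensor([1.0, 0.0, 1.0])    # gradient allowed through these positions
block = 1.0 - keep                      # value kept, gradient blocked

y = keep * x + block * x.detach()       # forward values identical to x
y.sum().backward()
print(x.grad)                           # tensor([1., 0., 1.])
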
Example #4
    def forward(self,
                key,
                value,
                query,
                mask=None,
                layer_cache=None,
                type=None):
        """
        Compute the context vector and the attention vectors.

        Args:
           key (FloatTensor): set of `key_len`
               key vectors ``(batch, key_len, dim)``
           value (FloatTensor): set of `key_len`
               value vectors ``(batch, key_len, dim)``
           query (FloatTensor): set of `query_len`
               query vectors  ``(batch, query_len, dim)``
           mask: binary mask indicating which keys have
               non-zero attention ``(batch, query_len, key_len)``
        Returns:
           (FloatTensor, FloatTensor):

           * output context vectors ``(batch, query_len, dim)``
           * one of the attention vectors ``(batch, query_len, key_len)``
        """

        # CHECKS
        # batch, k_len, d = key.size()
        # batch_, k_len_, d_ = value.size()
        # aeq(batch, batch_)
        # aeq(k_len, k_len_)
        # aeq(d, d_)
        # batch_, q_len, d_ = query.size()
        # aeq(batch, batch_)
        # aeq(d, d_)
        # aeq(self.model_dim % 8, 0)
        # if mask is not None:
        #    batch_, q_len_, k_len_ = mask.size()
        #    aeq(batch_, batch)
        #    aeq(k_len_, k_len)
        #    aeq(q_len_ == q_len)
        # END CHECKS

        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        key_len = key.size(1)
        query_len = query.size(1)
        device = key.device

        def shape(x):
            """Projection."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """Compute context."""
            return x.transpose(1, 2).contiguous() \
                    .view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if type == "self":
                query, key, value = self.linear_query(query),\
                                    self.linear_keys(query),\
                                    self.linear_values(query)
                key = shape(key)
                value = shape(value)
                if layer_cache["self_keys"] is not None:
                    key = torch.cat((layer_cache["self_keys"].to(device), key),
                                    dim=2)
                if layer_cache["self_values"] is not None:
                    value = torch.cat(
                        (layer_cache["self_values"].to(device), value), dim=2)
                layer_cache["self_keys"] = key
                layer_cache["self_values"] = value
            elif type == "context":
                query = self.linear_query(query)
                if layer_cache["memory_keys"] is None:
                    key, value = self.linear_keys(key),\
                                 self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
                else:
                    key, value = layer_cache["memory_keys"],\
                               layer_cache["memory_values"]
                layer_cache["memory_keys"] = key
                layer_cache["memory_values"] = value
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        if self.max_relative_positions > 0 and type == "self":
            key_len = key.size(2)
            # 1 or key_len x key_len
            relative_positions_matrix = generate_relative_positions_matrix(
                key_len,
                self.max_relative_positions,
                cache=True if layer_cache is not None else False)
            #  1 or key_len x key_len x dim_per_head
            relations_keys = self.relative_positions_embeddings(
                relative_positions_matrix.to(device))
            #  1 or key_len x key_len x dim_per_head
            relations_values = self.relative_positions_embeddings(
                relative_positions_matrix.to(device))

        query = shape(query)

        key_len = key.size(2)
        query_len = query.size(2)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        # batch x num_heads x query_len x key_len
        query_key = torch.matmul(query, key.transpose(2, 3))

        if self.max_relative_positions > 0 and type == "self":
            scores = query_key + relative_matmul(query, relations_keys, True)
        else:
            scores = query_key
        scores = scores.float()

        if mask is not None:
            mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)
        drop_attn = self.dropout(attn)

        context_original = torch.matmul(drop_attn, value)

        if self.max_relative_positions > 0 and type == "self":
            context = unshape(
                context_original +
                relative_matmul(drop_attn, relations_values, False))
        else:
            context = unshape(context_original)

        output = self.final_linear(context)
        # CHECK
        # batch_, q_len_, d_ = output.size()
        # aeq(q_len, q_len_)
        # aeq(batch, batch_)
        # aeq(d, d_)

        # Return one attn
        top_attn = attn \
            .view(batch_size, head_count,
                  query_len, key_len)[:, 0, :, :] \
            .contiguous()

        return output, top_attn
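
A minimal sketch (toy scores, assumed shapes) of steps 2-3 above: the boolean mask is broadcast as [B, 1, 1, key_len], masked positions are filled with -1e18, and the softmax then gives those keys (numerically) zero attention weight.

import torch
import torch.nn.functional as F

scores = torch.randn(1, 2, 3, 4)                   # [batch, heads, query_len, key_len]
mask = torch.tensor([[False, False, True, True]])  # last two keys are padding
mask = mask.view(1, 1, 1, 4)                       # broadcast over heads and queries

scores = scores.float().masked_fill(mask, -1e18)
attn = F.softmax(scores, dim=-1)
print(attn[0, 0, 0])   # e.g. tensor([0.62, 0.38, 0.00, 0.00]); masked keys get ~0
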
Example #5
    def forward(self, key, value, query, mask=None,
                layer_cache=None, type=None):
        """
        Compute the context vector and the attention vectors.

        Args:
           key (FloatTensor): set of `key_len`
               key vectors ``(batch, key_len, dim)``
           value (FloatTensor): set of `key_len`
               value vectors ``(batch, key_len, dim)``
           query (FloatTensor): set of `query_len`
               query vectors  ``(batch, query_len, dim)``
           mask: binary mask indicating which keys have
               non-zero attention ``(batch, query_len, key_len)``
        Returns:
           (FloatTensor, FloatTensor):

           * output context vectors ``(batch, query_len, dim)``
           * one of the attention vectors ``(batch, query_len, key_len)``
        """

        # CHECKS
        # batch, k_len, d = key.size()
        # batch_, k_len_, d_ = value.size()
        # aeq(batch, batch_)
        # aeq(k_len, k_len_)
        # aeq(d, d_)
        # batch_, q_len, d_ = query.size()
        # aeq(batch, batch_)
        # aeq(d, d_)
        # aeq(self.model_dim % 8, 0)
        # if mask is not None:
        #    batch_, q_len_, k_len_ = mask.size()
        #    aeq(batch_, batch)
        #    aeq(k_len_, k_len)
        #    aeq(q_len_ == q_len)
        # END CHECKS

        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        key_len = key.size(1)
        query_len = query.size(1)
        device = key.device

        def shape(x):
            """Projection."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """Compute context."""
            return x.transpose(1, 2).contiguous() \
                    .view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if type == "self":
                query, key, value = self.linear_query(query),\
                                    self.linear_keys(query),\
                                    self.linear_values(query)
                key = shape(key)
                value = shape(value)
                if layer_cache["self_keys"] is not None:
                    key = torch.cat(
                        (layer_cache["self_keys"].to(device), key),
                        dim=2)
                if layer_cache["self_values"] is not None:
                    value = torch.cat(
                        (layer_cache["self_values"].to(device), value),
                        dim=2)
                layer_cache["self_keys"] = key
                layer_cache["self_values"] = value
            elif type == "context":
                query = self.linear_query(query)
                if layer_cache["memory_keys"] is None:
                    key, value = self.linear_keys(key),\
                                 self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
                else:
                    key, value = layer_cache["memory_keys"],\
                               layer_cache["memory_values"]
                layer_cache["memory_keys"] = key
                layer_cache["memory_values"] = value
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        if self.max_relative_positions > 0 and type == "self":
            key_len = key.size(2)
            # 1 or key_len x key_len
            relative_positions_matrix = generate_relative_positions_matrix(
                key_len, self.max_relative_positions,
                cache=True if layer_cache is not None else False)
            #  1 or key_len x key_len x dim_per_head
            relations_keys = self.relative_positions_embeddings(
                relative_positions_matrix.to(device))
            #  1 or key_len x key_len x dim_per_head
            relations_values = self.relative_positions_embeddings(
                relative_positions_matrix.to(device))

        query = shape(query)

        key_len = key.size(2)
        query_len = query.size(2)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        # batch x num_heads x query_len x key_len
        query_key = torch.matmul(query, key.transpose(2, 3))

        if self.max_relative_positions > 0 and type == "self":
            scores = query_key + relative_matmul(query, relations_keys, True)
        else:
            scores = query_key
        scores = scores.float()

        if mask is not None:
            mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)
        drop_attn = self.dropout(attn)

        context_original = torch.matmul(drop_attn, value)

        if self.max_relative_positions > 0 and type == "self":
            context = unshape(context_original
                              + relative_matmul(drop_attn,
                                                relations_values,
                                                False))
        else:
            context = unshape(context_original)

        output = self.final_linear(context)
        # CHECK
        # batch_, q_len_, d_ = output.size()
        # aeq(q_len, q_len_)
        # aeq(batch, batch_)
        # aeq(d, d_)

        # Return one attn
        top_attn = attn \
            .view(batch_size, head_count,
                  query_len, key_len)[:, 0, :, :] \
            .contiguous()

        return output, top_attn
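
For orientation, a sketch of the kind of index matrix a helper like `generate_relative_positions_matrix` is expected to produce for the embedding lookup above: pairwise offsets clipped to ±max_relative_positions and shifted to non-negative indices. This is an assumed illustration of the helper's behaviour, not code taken from this listing.

import torch

length, max_relative_positions = 5, 2

positions = torch.arange(length)
# pairwise offsets j - i, clipped to [-max, +max]
distance = positions.unsqueeze(0) - positions.unsqueeze(1)
distance = distance.clamp(-max_relative_positions, max_relative_positions)
# shift to non-negative indices usable with an embedding table
relative_positions_matrix = distance + max_relative_positions
print(relative_positions_matrix)
# tensor([[2, 3, 4, 4, 4],
#         [1, 2, 3, 4, 4],
#         [0, 1, 2, 3, 4],
#         [0, 0, 1, 2, 3],
#         [0, 0, 0, 1, 2]])
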
Example #6
    def forward(self, key, value, query, mask=None,
                layer_cache=None, attn_type=None):
        """
        Compute the context vector and the attention vectors.

        Args:
           key (FloatTensor): set of `key_len`
               key vectors ``(batch, key_len, dim)``
           value (FloatTensor): set of `key_len`
               value vectors ``(batch, key_len, dim)``
           query (FloatTensor): set of `query_len`
               query vectors  ``(batch, query_len, dim)``
           mask: binary mask 1/0 indicating which keys have
               zero / non-zero attention ``(batch, query_len, key_len)``
        Returns:
           (FloatTensor, FloatTensor):

           * output context vectors ``(batch, query_len, dim)``
           * one of the attention vectors ``(batch, query_len, key_len)``
        """

        # CHECKS
        # batch, k_len, d = key.size()
        # batch_, k_len_, d_ = value.size()
        # aeq(batch, batch_)
        # aeq(k_len, k_len_)
        # aeq(d, d_)
        # batch_, q_len, d_ = query.size()
        # aeq(batch, batch_)
        # aeq(d, d_)
        # aeq(self.model_dim % 8, 0)
        # if mask is not None:
        #    batch_, q_len_, k_len_ = mask.size()
        #    aeq(batch_, batch)
        #    aeq(k_len_, k_len)
        #    aeq(q_len_ == q_len)
        # END CHECKS

        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        key_len = key.size(1)
        query_len = query.size(1)

        def shape(x):
            """Projection."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """Compute context."""
            return x.transpose(1, 2).contiguous() \
                    .view(batch_size, -1, head_count * dim_per_head)

        if self.with_saliency_selection:
            selection_query = self.linear_selection_query(query)
            selection_key = self.linear_selection_key(key)

            selection_key = shape(selection_key)
            selection_query = shape(selection_query)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if attn_type == "self":
                query, key, value = self.linear_query(query),\
                                    self.linear_keys(query),\
                                    self.linear_values(query)
                key = shape(key)
                value = shape(value)
                if layer_cache["self_keys"] is not None:
                    key = torch.cat(
                        (layer_cache["self_keys"], key),
                        dim=2)
                if layer_cache["self_values"] is not None:
                    value = torch.cat(
                        (layer_cache["self_values"], value),
                        dim=2)
                layer_cache["self_keys"] = key
                layer_cache["self_values"] = value
            elif attn_type == "context":
                query = self.linear_query(query)
                if layer_cache["memory_keys"] is None:
                    key, value = self.linear_keys(key),\
                                 self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
                else:
                    key, value = layer_cache["memory_keys"],\
                               layer_cache["memory_values"]
                layer_cache["memory_keys"] = key
                layer_cache["memory_values"] = value

        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)


        if self.max_relative_positions > 0 and attn_type == "self":
            key_len = key.size(2)
            # 1 or key_len x key_len
            relative_positions_matrix = generate_relative_positions_matrix(
                key_len, self.max_relative_positions,
                cache=True if layer_cache is not None else False)
            #  1 or key_len x key_len x dim_per_head
            relations_keys = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))
            #  1 or key_len x key_len x dim_per_head
            relations_values = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))


        if self.with_focus_attention:
            glo = torch.mean(query, dim=1, keepdim=True)

            c = self.tanh(self.linear_focus_query(query) + self.linear_focus_global(glo))
            # c = self.tanh(self.linear_focus_query(query))#  + self.linear_focus_global(glo))
            c = shape(c)

            p = c * self.up
            p = p.sum(3).squeeze()
            z = c * self.uz
            z = z.sum(3).squeeze()

            P = self.sigmoid(p) * key_len
            Z = self.sigmoid(z) * key_len

            j = torch.arange(start=0, end=key_len, dtype=P.dtype,
                             device=P.device).unsqueeze(0).unsqueeze(0).unsqueeze(0)
            P = P.unsqueeze(-1)
            Z = Z.unsqueeze(-1)

            G = - (j-P)**2 * 2 / (Z**2)

        query = shape(query)
        if self.with_saliency_selection:
            # gate_key = self.linear_selection_key(unshape(key))
            # gate_query = self.linear_selection_query(unshape(query))

            # gate_key = shape(gate_key)
            # gate_query = shape(gate_query)

            gate = self.sigmoid(torch.matmul(selection_query, selection_key.transpose(2, 3)))

        key_len = key.size(2)
        query_len = query.size(2)


        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)

        # batch x num_heads x query_len x key_len
        query_key = torch.matmul(query, key.transpose(2, 3))

        if self.max_relative_positions > 0 and attn_type == "self":
            scores = query_key + relative_matmul(query, relations_keys, True)
        else:
            scores = query_key
        scores = scores.float()

        if self.with_focus_attention:
            scores = scores + G

        if mask is not None:
            mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)
        if self.with_saliency_selection:
            new_attn = attn * gate
            drop_attn = self.dropout(new_attn)
        else:
            drop_attn = self.dropout(attn)

        context_original = torch.matmul(drop_attn, value)

        if self.max_relative_positions > 0 and attn_type == "self":
            context = unshape(context_original
                              + relative_matmul(drop_attn,
                                                relations_values,
                                                False))
        else:
            context = unshape(context_original)

        output = self.final_linear(context)
        # CHECK
        # batch_, q_len_, d_ = output.size()
        # aeq(q_len, q_len_)
        # aeq(batch, batch_)
        # aeq(d, d_)

        # Return one attn
        top_attn = attn \
            .view(batch_size, head_count,
                  query_len, key_len)[:, 0, :, :] \
            .contiguous()

        return output, top_attn
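
A small sketch (toy values, hypothetical center and window) of the focus bias `G` above: for each query, a predicted center `P` and window `Z` define a Gaussian-shaped penalty over key positions `j`, added to the scores before the softmax so that attention concentrates around `P`.

import torch
import torch.nn.functional as F

key_len = 8
j = torch.arange(key_len, dtype=torch.float)   # key positions
P = torch.tensor(3.0)                          # predicted focus center
Z = torch.tensor(2.0)                          # predicted focus window

G = -(j - P) ** 2 * 2 / (Z ** 2)               # 0 at j == P, negative elsewhere
scores = torch.zeros(key_len)                  # uniform raw scores for illustration
attn = F.softmax(scores + G, dim=-1)
print(attn.argmax().item())                    # 3 -- mass concentrates around P
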
Example #7
    def forward(self,
                self_kvq,
                ctx_kv,
                self_mask=None,
                ctx_mask=None,
                layer_cache=None,
                type=None):
        """
        Compute the context vector and the attention vectors.

        Args:
           self_kvq (FloatTensor): set of `self_len` vectors used as
               keys, values, and queries ``(batch, self_len, dim)``
           ctx_kv (FloatTensor): set of `ctx_len` context vectors used as
               keys and values ``(batch, ctx_len, dim)``
           self_mask: binary mask indicating which self keys have
               non-zero attention ``(batch, self_len, self_len)``
           ctx_mask: binary mask indicating which context keys have
               non-zero attention ``(batch, 1, ctx_len)``
        Returns:
           (FloatTensor, FloatTensor, FloatTensor):

           * output context vectors ``(batch, self_len, dim)``
           * first-head attention over the context keys
             ``(batch, self_len, ctx_len)``
           * full attention over self and context keys
             ``(batch, head, self_len, self_len + ctx_len)``
        """

        # CHECKS
        # batch, k_len, d = key.size()
        # batch_, k_len_, d_ = value.size()
        # aeq(batch, batch_)
        # aeq(k_len, k_len_)
        # aeq(d, d_)
        # batch_, q_len, d_ = query.size()
        # aeq(batch, batch_)
        # aeq(d, d_)
        # aeq(self.model_dim % 8, 0)
        # if mask is not None:
        #    batch_, q_len_, k_len_ = mask.size()
        #    aeq(batch_, batch)
        #    aeq(k_len_, k_len)
        #    aeq(q_len_ == q_len)
        # END CHECKS

        batch_size = self_kvq.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        self_len = self_kvq.size(1)
        ctx_len = ctx_kv.size(1)
        device = self_kvq.device

        def shape(x):
            """Projection."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """Compute context."""
            return x.transpose(1, 2).contiguous() \
                    .view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            query, self_key, self_value = self.linear_query(self_kvq),\
                                          self.linear_keys(self_kvq),\
                                          self.linear_values(self_kvq)
            #self_key = shape(self_key)
            #self_value = shape(self_value)
            if layer_cache["self_keys"] is not None:
                self_key = torch.cat(
                    (layer_cache["self_keys"].to(device), self_key), dim=1)
            if layer_cache["self_values"] is not None:
                self_value = torch.cat(
                    (layer_cache["self_values"].to(device), self_value), dim=1)
            layer_cache["self_keys"] = self_key
            layer_cache["self_values"] = self_value

            if layer_cache["memory_keys"] is None:
                ctx_key = self.ctx_linear_keys(ctx_kv)  # [batch, ctx_len, dim]
                ctx_value = self.ctx_linear_values(ctx_kv)
                layer_cache["memory_keys"] = ctx_key
                layer_cache["memory_values"] = ctx_value
            else:
                ctx_key = layer_cache["memory_keys"]
                ctx_value = layer_cache["memory_values"]
        else:
            self_key = self.linear_keys(self_kvq)  # [batch, self_len, dim]
            self_value = self.linear_values(self_kvq)
            query = self.linear_query(self_kvq)

            ctx_key = self.ctx_linear_keys(ctx_kv)  # [batch, ctx_len, dim]
            ctx_value = self.ctx_linear_values(ctx_kv)

        self_len = self_key.size(
            1)  # Need to do this again to include the layer_cache length
        ctx_len = ctx_key.shape[1]

        key = torch.cat((self_key, ctx_key), dim=1)
        value = torch.cat((self_value, ctx_value), dim=1)

        key = shape(key)
        value = shape(value)

        if self.max_relative_positions > 0 and type == "self":
            raise NotImplementedError
            key_len = key.size(2)
            # 1 or key_len x key_len
            relative_positions_matrix = generate_relative_positions_matrix(
                key_len,
                self.max_relative_positions,
                cache=True if layer_cache is not None else False)
            #  1 or key_len x key_len x dim_per_head
            relations_keys = self.relative_positions_embeddings(
                relative_positions_matrix.to(device))
            #  1 or key_len x key_len x dim_per_head
            relations_values = self.relative_positions_embeddings(
                relative_positions_matrix.to(device))

        query = shape(query)

        key_len = key.size(2)  # self_len+ctx_len
        query_len = query.size(2)  # self_len

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        # batch x num_heads x query_len x key_len
        query_key = torch.matmul(query, key.transpose(
            2, 3))  # [batch, head, self_len, self_len+ctx_len]

        if self.ctx_weight_param:
            query_key[..., self_len:] += self.ctx_bias
        #print(query_key.mean(), query_key.std())

        if self.max_relative_positions > 0 and type == "self":
            scores = query_key + relative_matmul(query, relations_keys, True)
        else:
            scores = query_key
        scores = scores.float()

        if self_mask is not None:
            self_mask = self_mask.unsqueeze(1)  # [B, 1, self_len, self_len]
            scores[:, :, :, :self_len] = scores[:, :, :, :
                                                self_len].masked_fill(
                                                    self_mask, -1e18)
        if ctx_mask is not None:
            ctx_mask = ctx_mask.unsqueeze(1)  # [B, 1, 1, ctx_len]
            scores[:, :, :,
                   self_len:] = scores[:, :, :,
                                       self_len:].masked_fill(ctx_mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)
        drop_attn = self.dropout(attn)

        context_original = torch.matmul(drop_attn,
                                        value)  # [batch, head, self_len, dim]

        if self.max_relative_positions > 0 and type == "self":
            context = unshape(
                context_original +
                relative_matmul(drop_attn, relations_values, False))
        else:
            context = unshape(context_original)

        output = self.final_linear(context)
        # CHECK
        # batch_, q_len_, d_ = output.size()
        # aeq(q_len, q_len_)
        # aeq(batch, batch_)
        # aeq(d, d_)

        # Return one attn (to context)
        ctx_attn_probs = attn[:, :, :, self_len:]
        ctx_attn_probs = ctx_attn_probs / ctx_attn_probs.sum(dim=-1,
                                                             keepdim=True)

        top_attn = ctx_attn_probs \
            .view(batch_size, head_count,
                  query_len, ctx_len)[:, 0, :, :] \
            .contiguous()

        return output, top_attn, attn
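
A minimal sketch (toy numbers) of the final renormalization above: an attention row covers self_len + ctx_len keys, so the slice over the context keys is divided by its own sum to become a proper distribution over the context positions alone.

import torch

self_len, ctx_len = 3, 2
# one attention row over self_len + ctx_len keys (sums to 1)
attn_row = torch.tensor([0.2, 0.3, 0.1, 0.3, 0.1])

ctx_attn = attn_row[self_len:]                       # tensor([0.3, 0.1])
ctx_attn = ctx_attn / ctx_attn.sum(dim=-1, keepdim=True)
print(ctx_attn)                                      # tensor([0.75, 0.25])
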
Example #8
    def forward(self,
                key,
                value,
                query,
                grh=None,
                mask=None,
                layer_cache=None,
                attn_type=None):
        """
        Compute the context vector and the attention vectors.

        Args:
           key (FloatTensor): set of `key_len`
               key vectors ``(batch, key_len, dim)``
           value (FloatTensor): set of `key_len`
               value vectors ``(batch, key_len, dim)``
           query (FloatTensor): set of `query_len`
               query vectors  ``(batch, query_len, dim)``
           mask: binary mask 1/0 indicating which keys have
               zero / non-zero attention ``(batch, query_len, key_len)``
        Returns:
           (FloatTensor, FloatTensor):

           * output context vectors ``(batch, query_len, dim)``
           * one of the attention vectors ``(batch, query_len, key_len)``
        """

        # CHECKS
        # batch, k_len, d = key.size()
        # batch_, k_len_, d_ = value.size()
        # aeq(batch, batch_)
        # aeq(k_len, k_len_)
        # aeq(d, d_)
        # batch_, q_len, d_ = query.size()
        # aeq(batch, batch_)
        # aeq(d, d_)
        # aeq(self.model_dim % 8, 0)
        # if mask is not None:
        #    batch_, q_len_, k_len_ = mask.size()
        #    aeq(batch_, batch)
        #    aeq(k_len_, k_len)
        #    aeq(q_len_ == q_len)
        # END CHECKS

        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        key_len = key.size(1)
        query_len = query.size(1)

        def shape(x):
            """Projection."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """Compute context."""
            return x.transpose(1, 2).contiguous() \
                    .view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if attn_type == "self":
                query, key, value = self.linear_query(query),\
                                    self.linear_keys(query),\
                                    self.linear_values(query)
                key = shape(key)
                value = shape(value)
                if layer_cache["self_keys"] is not None:
                    key = torch.cat((layer_cache["self_keys"], key), dim=2)
                if layer_cache["self_values"] is not None:
                    value = torch.cat((layer_cache["self_values"], value),
                                      dim=2)
                layer_cache["self_keys"] = key
                layer_cache["self_values"] = value
            elif attn_type == "context":
                query = self.linear_query(query)
                if layer_cache["memory_keys"] is None:
                    key, value = self.linear_keys(key),\
                                 self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
                else:
                    key, value = layer_cache["memory_keys"],\
                               layer_cache["memory_values"]
                layer_cache["memory_keys"] = key
                layer_cache["memory_values"] = value
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        if self.max_relative_positions > 0 and attn_type == "self":
            key_len = key.size(2)
            # 1 or key_len x key_len
            relative_positions_matrix = generate_relative_positions_matrix(
                key_len,
                self.max_relative_positions,
                cache=True if layer_cache is not None else False)
            #  1 or key_len x key_len x dim_per_head
            relations_keys = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))
            #  1 or key_len x key_len x dim_per_head
            relations_values = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))

        query = shape(query)

        key_len = key.size(2)
        query_len = query.size(2)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        # batch x num_heads x query_len x key_len
        query_key = torch.matmul(query, key.transpose(2, 3))

        if self.max_relative_positions > 0 and attn_type == "self":
            scores = query_key + relative_matmul(query, relations_keys, True)
        else:
            scores = query_key
        scores = scores.float()

        if mask is not None:
            mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)
        drop_attn = self.dropout(attn)

        context_original = torch.matmul(drop_attn, value)

        if self.max_relative_positions > 0 and attn_type == "self":
            context = unshape(
                context_original +
                relative_matmul(drop_attn, relations_values, False))
        else:
            context = unshape(context_original)
        if self.gate:
            gate = torch.sigmoid(self.gate_linear[0](context))
            gate_context = gate * context
        # CHECK
        # batch_, q_len_, d_ = output.size()
        # aeq(q_len, q_len_)
        # aeq(batch, batch_)
        # aeq(d, d_)

        # Return one attn
        top_attn = attn \
            .view(batch_size, head_count,
                  query_len, key_len)[:, 0, :, :] \
            .contiguous()

        ## above is fully-connected graph
        ## Multi-View self-attention
        if grh is not None:
            assert query_len == key_len
            #assert key_len-1 != grh[0][-1][0], "the num of nodes is not consistent"
            views = []
            index = [(0, 1), (2, 3), (4, 5), (6, 7)]

            # whole sub graph
            h_i = self.linear_attention[index[-1][0]](value)
            h_j = self.linear_attention[index[-1][1]](value)
            e = nn.functional.leaky_relu(
                h_i +
                h_j.transpose(2, 3))  # default alpha=0.01, but =0.2 in tf

            grh_mask = torch.ones_like(grh)
            adj = (grh_mask < grh).unsqueeze(1).expand(-1, self.head_count, -1,
                                                       -1)
            zero = torch.ones_like(e) * (-9e15)

            e_shape = e.shape
            attention = self.softmax(e.where(adj > 0, zero))  # 17 8 56 56

            whole_sub_view = torch.matmul(attention, value)
            if self.gate:
                gate = torch.sigmoid(self.gate_linear[-1](
                    unshape(whole_sub_view)))
                views.append(gate * unshape(whole_sub_view))
            else:
                views.append(unshape(whole_sub_view))

            # edge-aware sub graph
            for i in range(self.edge_type - 1):
                h_i = self.linear_attention[index[i][0]](value)
                h_j = self.linear_attention[index[i][1]](value)
                e = nn.functional.leaky_relu(
                    h_i +
                    h_j.transpose(2, 3))  # default alpha=0.01, but =0.2 in tf

                label_id = i + 2  # +2 because the mask below starts from ones_like, so 1 cannot be an edge type
                grh_mask = torch.ones_like(grh) * label_id
                eye = (torch.eye(grh.size(-1), dtype=torch.int64) *
                       (4 - label_id)).to(grh.device)
                grh_mask = grh_mask + eye
                adj = (grh_mask == grh).unsqueeze(1).expand(
                    -1, self.head_count, -1, -1)
                zero = torch.ones_like(e) * (-9e15)

                e_shape = e.shape
                attention = self.softmax(e.where(adj > 0, zero))  # 17 8 56 56

                sub_view = torch.matmul(attention, value)
                if self.gate:
                    gate = torch.sigmoid(self.gate_linear[i + 1](
                        unshape(sub_view)))
                    views.append(gate * unshape(sub_view))
                else:
                    views.append(unshape(sub_view))

            if self.fusion == "cat":
                ## TODO: MAX_P LSTM
                if self.gate:
                    Views = [gate_context] + views
                else:
                    Views = [context] + views
                output = torch.cat(Views, dim=-1)

                return self.sub_final_linear(output), top_attn

        output = self.final_linear(context)

        return output, top_attn
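
A small sketch (toy graph, assumed adjacency) of the view-specific attention above: raw scores are kept only where the adjacency matrix allows an edge, every other position is replaced by a large negative constant, and the softmax then spreads attention over graph neighbours only.

import torch
import torch.nn.functional as F

num_nodes = 4
e = torch.randn(num_nodes, num_nodes)          # raw pairwise attention scores
adj = torch.tensor([[1, 1, 0, 0],
                    [1, 1, 1, 0],
                    [0, 1, 1, 1],
                    [0, 0, 1, 1]])             # 1 where an edge exists

zero = torch.ones_like(e) * (-9e15)
attention = F.softmax(e.where(adj > 0, zero), dim=-1)
print(attention[0])   # non-zero weight only on node 0's neighbours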