Example #1
from chainer import Chain, ChainList


class TransformerEncoder(Chain):
    def __init__(self, depth, num_heads, model_dim, ff_dim, p_drop):
        super().__init__()
        with self.init_scope():
            # A stack of `depth` identical encoder units.
            self.unit_links = ChainList()
            for _ in range(depth):
                self.unit_links.append(
                    TransformerEncoderUnit(num_heads, model_dim, ff_dim,
                                           p_drop))

    def forward(self, input_encodings, input_masks=None):
        # Thread the encodings through each encoder unit in turn.
        x = input_encodings
        for unit_link in self.unit_links:
            x = unit_link(x, input_masks=input_masks)
        return x
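
The encoder's units come from `TransformerEncoderUnit`, which is defined elsewhere in the same codebase and not shown here. A minimal shape-check sketch, using a hypothetical pass-through stub in place of the real unit:

import numpy as np


class TransformerEncoderUnit(Chain):
    # Hypothetical stand-in for the real unit, which applies self-attention
    # and a feed-forward block; this stub only preserves shapes.
    def __init__(self, num_heads, model_dim, ff_dim, p_drop):
        super().__init__()

    def forward(self, x, input_masks=None):
        return x


encoder = TransformerEncoder(depth=2, num_heads=4, model_dim=64,
                             ff_dim=128, p_drop=0.1)
encodings = np.zeros((8, 10, 64), dtype=np.float32)  # (batch, seq, model_dim)
out = encoder(encodings)
assert out.shape == (8, 10, 64)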
Example #2
from math import sqrt

import chainer.functions as F
import chainer.links as L
from chainer import Chain, ChainList


class MultiHeadAttention(Chain):
    def __init__(self, num_heads, model_dim, key_dim, value_dim):
        super().__init__()
        self.num_heads = num_heads
        self.model_dim = model_dim
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.multi_head_dim = num_heads * value_dim
        self.scale = 1. / sqrt(key_dim)  # standard 1/sqrt(d_k) attention scaling
        with self.init_scope():
            # Separate query/key/value projections for each head.
            self.head_query_links = ChainList()
            self.head_key_links = ChainList()
            self.head_value_links = ChainList()
            for _ in range(num_heads):
                self.head_query_links.append(L.Linear(model_dim, key_dim))
                self.head_key_links.append(L.Linear(model_dim, key_dim))
                self.head_value_links.append(L.Linear(model_dim, value_dim))
            # Project the concatenated heads back to the model dimension.
            self.output_link = L.Linear(self.multi_head_dim, model_dim)

    def forward(self, queries, keys, values, mask=None):
        heads = []
        for i in range(self.num_heads):
            # n_batch_axes=2 applies each linear layer along the last axis,
            # treating (batch, sequence) as batch dimensions.
            query_projection = self.head_query_links[i](queries,
                                                        n_batch_axes=2)
            key_projection = self.head_key_links[i](keys, n_batch_axes=2)
            value_projection = self.head_value_links[i](values, n_batch_axes=2)

            head = scaled_dot_product_attention(query_projection,
                                                key_projection,
                                                value_projection,
                                                scale=self.scale,
                                                mask=mask)

            heads.append(head)

        # Concatenate the per-head outputs: (batch, seq, num_heads * value_dim).
        multi_head = F.concat(heads, axis=-1)
        return self.output_link(multi_head, n_batch_axes=2)
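
Each head above delegates the actual attention computation to `scaled_dot_product_attention`, which the source does not show. A minimal sketch, assuming the standard formulation softmax(scale * Q K^T) V with an optional boolean mask (True where attention is allowed); a plausible reconstruction, not the codebase's actual helper:

import chainer.functions as F


def scaled_dot_product_attention(queries, keys, values, scale, mask=None):
    # queries: (batch, n_queries, key_dim), keys: (batch, n_keys, key_dim),
    # values: (batch, n_keys, value_dim).
    scores = F.matmul(queries, keys, transb=True) * scale
    if mask is not None:
        # Assumed convention: boolean mask of shape (batch, n_queries, n_keys).
        neg_inf = scores.xp.full(scores.shape, -1e9, dtype=scores.dtype)
        scores = F.where(mask, scores, neg_inf)
    weights = F.softmax(scores, axis=2)  # normalize over the key axis
    return F.matmul(weights, values)     # (batch, n_queries, value_dim)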