from chainer import Chain, ChainList


class TransformerEncoder(Chain):
    """Stack of `depth` encoder units applied one after another."""

    def __init__(self, depth, num_heads, model_dim, ff_dim, p_drop):
        super().__init__()
        with self.init_scope():
            # TransformerEncoderUnit is assumed to be defined elsewhere.
            self.unit_links = ChainList()
            for i in range(depth):
                self.unit_links.append(
                    TransformerEncoderUnit(num_heads, model_dim, ff_dim, p_drop))

    def forward(self, input_encodings, input_masks=None):
        # Feed the output of each unit into the next one.
        unit_inputs = [input_encodings]
        for unit_link in self.unit_links:
            x = unit_inputs[-1]
            o = unit_link(x, input_masks=input_masks)
            unit_inputs.append(o)
        return unit_inputs[-1]
from math import sqrt

import chainer.functions as F
import chainer.links as L
from chainer import Chain, ChainList


class MultiHeadAttention(Chain):
    """Projects queries/keys/values once per head, attends, then recombines."""

    def __init__(self, num_heads, model_dim, key_dim, value_dim):
        super().__init__()
        self.num_heads = num_heads
        self.model_dim = model_dim
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.multi_head_dim = num_heads * value_dim
        self.scale = 1. / sqrt(key_dim)
        with self.init_scope():
            self.head_query_links = ChainList()
            self.head_key_links = ChainList()
            self.head_value_links = ChainList()
            for i in range(num_heads):
                self.head_query_links.append(L.Linear(model_dim, key_dim))
                self.head_key_links.append(L.Linear(model_dim, key_dim))
                self.head_value_links.append(L.Linear(model_dim, value_dim))
            self.output_link = L.Linear(self.multi_head_dim, model_dim)

    def forward(self, queries, keys, values, mask=None):
        # queries/keys/values: (batch, length, model_dim)
        heads = []
        for i in range(self.num_heads):
            query_projection = self.head_query_links[i](queries, n_batch_axes=2)
            key_projection = self.head_key_links[i](keys, n_batch_axes=2)
            value_projection = self.head_value_links[i](values, n_batch_axes=2)
            head = scaled_dot_product_attention(
                query_projection, key_projection, value_projection,
                scale=self.scale, mask=mask)
            heads.append(head)
        # Concatenate the per-head outputs and project back to model_dim.
        multi_head = F.concat(heads, axis=-1)
        return self.output_link(multi_head, n_batch_axes=2)
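
MultiHeadAttention relies on a scaled_dot_product_attention helper that is not defined in this section. The following is a minimal sketch of such a helper, consistent with the call site above; the shape conventions and the boolean mask semantics are assumptions, not necessarily the implementation used here.

import chainer.functions as F


def scaled_dot_product_attention(queries, keys, values, scale, mask=None):
    # Sketch only; signature inferred from the call in MultiHeadAttention.forward.
    # queries, keys: (batch, length, key_dim); values: (batch, length, value_dim).
    # mask (assumed): boolean array of shape (batch, query_length, key_length),
    # True where attending is allowed.
    scores = F.matmul(queries, keys, transb=True) * scale  # (batch, q_len, k_len)
    if mask is not None:
        # Suppress disallowed positions with a large negative value before softmax.
        neg_inf = scores.xp.full(scores.shape, -1e9, dtype=scores.dtype)
        scores = F.where(mask, scores, neg_inf)
    weights = F.softmax(scores, axis=2)  # normalise over the key axis
    return F.matmul(weights, values)     # (batch, q_len, value_dim)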