Beispiel #1
0
 def build(self, input_shape):
     """Creates the multi-head projection sub-layers.

     Instantiates three input projections (query, key, value) and one output
     projection. Each is a `common_layer.Dense3D` with `num_heads` heads of
     `hidden_size // num_heads` units, a "glorot_uniform" (Xavier) kernel
     initializer and no bias. The method then defers to the parent `build`.

     Args:
       input_shape: Forwarded unchanged to the parent class's `build`.
     """
     # Per-head width. Floor division means this assumes hidden_size is a
     # multiple of num_heads (presumably validated elsewhere - TODO confirm).
     head_dim = self.hidden_size // self.num_heads

     def _projection(layer_name, **extra_kwargs):
         # All four projections share the head layout, the initializer and
         # the bias-free setup. Only the name (and the output flag) differ.
         return common_layer.Dense3D(
             self.num_heads,
             head_dim,
             kernel_initializer="glorot_uniform",
             use_bias=False,
             name=layer_name,
             **extra_kwargs)

     # Input projections that produce Q, K and V.
     self.query_dense_layer = _projection("query")
     self.key_dense_layer = _projection("key")
     self.value_dense_layer = _projection("value")
     # Linear layer applied after scaled dot-product attention. The
     # output_projection flag presumably makes Dense3D fold the heads back
     # into one hidden dimension - verify against common_layer.Dense3D.
     self.output_dense_layer = _projection(
         "output_transform", output_projection=True)
     super(Attention, self).build(input_shape)
Beispiel #2
0
 def build(self, input_shape):
     """Creates the projection sub-layers used by LSH self-attention.

     Builds one projection named "sharedQK", one for values and one output
     projection. Each is a `common_layer.Dense3D` with `num_heads` heads of
     `hidden_size // num_heads` units, a "glorot_uniform" kernel initializer
     and no bias. The method then defers to the parent `build`.

     Args:
       input_shape: Forwarded unchanged to the parent class's `build`.
     """
     # Per-head width. Assumes hidden_size divides evenly by num_heads
     # (floor division silently truncates otherwise - TODO confirm upstream).
     head_dim = self.hidden_size // self.num_heads
     # Constructor options common to every projection in this layer.
     common_kwargs = dict(kernel_initializer="glorot_uniform", use_bias=False)

     # A single projection, presumably shared by queries and keys (as its
     # name suggests); there is no separate key projection here.
     self.sharedQK_dense_layer = common_layer.Dense3D(
         self.num_heads, head_dim, name="sharedQK", **common_kwargs)
     self.value_dense_layer = common_layer.Dense3D(
         self.num_heads, head_dim, name="value", **common_kwargs)
     # Post-attention linear layer. output_projection=True presumably maps
     # the per-head outputs back to hidden_size - verify in common_layer.
     self.output_dense_layer = common_layer.Dense3D(
         self.num_heads,
         head_dim,
         output_projection=True,
         name="output_transform",
         **common_kwargs)
     super(LshSelfAttention, self).build(input_shape)