def attention_core(self, q, k, v, attn_mask):
    """ Core math operations of multihead attention layer """
    q = self._split_heads(q)  # [batch_size * n_heads * n_q * (k_dim/n_heads)]
    k = self._split_heads(k)  # [batch_size * n_heads * n_kv * (k_dim/n_heads)]
    v = self._split_heads(v)  # [batch_size * n_heads * n_kv * (v_dim/n_heads)]

    key_depth_per_head = self.key_depth / self.num_heads
    q = q / math.sqrt(key_depth_per_head)

    # Dot-product attention
    # logits: (batch_size * n_heads * n_q * n_kv)
    attn_bias = MultiHeadAttn.ATTN_BIAS_VALUE * (1 - attn_mask)
    logits = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) + attn_bias
    weights = tf.nn.softmax(logits)

    tf.add_to_collection("AttnWeights", weights)
    tf.add_to_collection(
        lib.meta.ATTENTIONS,
        lib.meta.Attention(self.scope, weights, logits, attn_mask))

    if is_dropout_enabled():
        weights = dropout(weights, 1.0 - self.attn_dropout)

    x = tf.matmul(
        weights,  # [batch_size * n_heads * n_q * n_kv]
        v         # [batch_size * n_heads * n_kv * (v_depth/n_heads)]
    )

    combined_x = self._combine_heads(x)

    if is_dropout_enabled():
        combined_x = dropout(combined_x, 1.0 - self.attn_value_dropout)

    return combined_x
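
# ----------------------------------------------------------------------
# Minimal standalone sketch (not part of the original layer) of the same
# scaled dot-product attention math performed in attention_core above,
# written with plain TensorFlow 1.x ops. The function name, the bias_value
# default, and the mask convention (1 = attend, 0 = block) are assumptions
# inferred from the comments above; it relies on the module-level `math`
# and `tf` imports already used by this file.
def scaled_dot_product_attention_sketch(q, k, v, attn_mask, bias_value=-1e9):
    """q: [batch, n_heads, n_q, d_k]; k: [batch, n_heads, n_kv, d_k];
    v: [batch, n_heads, n_kv, d_v]; attn_mask broadcastable to
    [batch, 1, n_q, n_kv], with 1 marking visible positions."""
    d_k = q.shape.as_list()[-1]
    q = q / math.sqrt(d_k)                      # scale queries by sqrt(d_k)
    logits = tf.matmul(q, k, transpose_b=True)  # [batch, n_heads, n_q, n_kv]
    logits += bias_value * (1.0 - attn_mask)    # push masked logits towards -inf
    weights = tf.nn.softmax(logits)             # attention distribution over n_kv
    return tf.matmul(weights, v)                # [batch, n_heads, n_q, d_v]
# ----------------------------------------------------------------------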
def __call__(self, query_inp, attn_mask, kv_inp=None, kv=None):
    """
    query_inp: [batch_size * n_q * inp_dim]
    attn_mask: [batch_size * 1 * n_q * n_kv]
    kv_inp: [batch_size * n_kv * inp_dim]
    kv: [batch_size * n_kv * (key_depth + value_depth)], precomputed keys/values
    -----------------------------------------------
    results: [batch_size * n_q * output_depth]
    """
    assert kv is None or kv_inp is None, "please only feed one of kv or kv_inp"
    with tf.name_scope(self.name) as scope:
        if kv_inp is not None or kv is not None:
            q = self.query_conv(query_inp)
            if kv is None:
                kv = self.kv_conv(kv_inp)
            k, v = tf.split(kv, [self.key_depth, self.value_depth], axis=2)
        else:
            combined = self.combined_conv(query_inp)
            q, k, v = tf.split(
                combined,
                [self.key_depth, self.key_depth, self.value_depth],
                axis=2)

        q = self._split_heads(q)  # [batch_size * n_heads * n_q * (k_dim/n_heads)]
        k = self._split_heads(k)  # [batch_size * n_heads * n_kv * (k_dim/n_heads)]
        v = self._split_heads(v)  # [batch_size * n_heads * n_kv * (v_dim/n_heads)]

        key_depth_per_head = self.key_depth / self.num_heads
        q = q / math.sqrt(key_depth_per_head)

        # Dot-product attention
        # logits: (batch_size * n_heads * n_q * n_kv)
        attn_bias = MultiHeadAttn.ATTN_BIAS_VALUE * (1 - attn_mask)
        logits = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) + attn_bias
        weights = tf.nn.softmax(logits)

        tf.add_to_collection("AttnWeights", weights)
        tf.add_to_collection(
            lib.meta.ATTENTIONS,
            lib.meta.Attention(scope, weights, logits, attn_mask))

        if is_dropout_enabled():
            weights = dropout(weights, 1.0 - self.attn_dropout)

        x = tf.matmul(
            weights,  # [batch_size * n_heads * n_q * n_kv]
            v         # [batch_size * n_heads * n_kv * (v_depth/n_heads)]
        )
        # x: [batch, n_heads, n_q, (v_depth/n_heads)]

        # ======================== apply the gate ========================
        gated_x = self.gate * x
        # =================================================================

        combined_x = self._combine_heads(gated_x)

        if is_dropout_enabled():
            combined_x = dropout(combined_x, 1.0 - self.attn_value_dropout)

        outputs = self.out_conv(combined_x)

        return outputs
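
# ----------------------------------------------------------------------
# Hedged sketch of the gating step marked "apply the gate" above. Only the
# elementwise multiply and its broadcasting come from the code; the shape
# and parameterization of self.gate are assumptions. A simple sigmoid gate
# per head is shown here; the actual layer may use a different (e.g.
# stochastic) relaxation. Relies on the module-level `tf` import.
def apply_head_gate_sketch(x, gate_logits):
    """x: [batch, n_heads, n_q, d_v_per_head]; gate_logits: [n_heads]."""
    gate = tf.reshape(tf.sigmoid(gate_logits), [1, -1, 1, 1])  # one scalar gate per head
    return gate * x  # a gate near 0 effectively switches the head off
# ----------------------------------------------------------------------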