def AttentiveCNN_match(context, query, context_mask, query_mask,
                       scope='AttentiveCNN_Block', residual=False,
                       normalize_output=False, reuse=None, **kwargs):
    """Fuses a convolution over the context with a convolution over the
    attention-aligned context (attentive convolution)."""
    with tf.variable_scope(scope, reuse=reuse):
        # Convolution over the raw context (no attention)
        cnn_wo_att = CNN_encode(context, filter_size=3, direction='none', act_fn=None)
        # Align the context to the query, then convolve the aligned representation
        att_context, _ = Attentive_match(context, query, context_mask, query_mask)
        cnn_att = CNN_encode(att_context, filter_size=1, direction='none', act_fn=None)
        output = tf.nn.tanh(cnn_wo_att + cnn_att)
        if residual:  # Residual connection
            output += context
        if normalize_output:  # Normalize
            output = layer_norm(output)  # (N, T_q, C)
    return output
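
# Illustrative usage sketch, not part of the original module. It assumes
# `tensorflow` is imported as `tf` (TF 1.x) and that CNN_encode and
# Attentive_match are defined elsewhere in this module with channel-preserving
# defaults; batch size, lengths, and hidden size below are made up.
def _demo_attentive_cnn_match():
    context = tf.random_normal([4, 30, 64])   # (N, T_q, C)
    query = tf.random_normal([4, 10, 64])     # (N, T_k, C)
    context_mask = tf.ones([4, 30])           # (N, T_q), 1 = real token
    query_mask = tf.ones([4, 10])             # (N, T_k)
    return AttentiveCNN_match(context, query, context_mask, query_mask)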
def TCN_encode(seqs, num_layers, normalize_output=True, scope='tcn_encode_block',
               reuse=None, layer_norm_scope='layer_norm', **kwargs):
    with tf.variable_scope(scope, reuse=reuse):
        outputs = [seqs]
        for i in range(num_layers):
            dilation_size = 2 ** i
            out = Res_DualCNN_encode(outputs[-1], dilation=dilation_size,
                                     scope='res_biconv_%d' % i, **kwargs)
            outputs.append(out)
        result = outputs[-1]
        if normalize_output:
            result = layer_norm(result, scope=layer_norm_scope, reuse=reuse)
    return result
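
# Illustrative usage sketch, not part of the original module. TCN_encode stacks
# num_layers residual dual-CNN blocks with dilations 1, 2, 4, ..., so the
# receptive field grows exponentially with depth; Res_DualCNN_encode is assumed
# to be defined elsewhere in this module, and the shapes below are made up.
def _demo_tcn_encode():
    seqs = tf.random_normal([4, 50, 64])      # (N, T, C)
    return TCN_encode(seqs, num_layers=3)     # dilations 1, 2, 4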
def MH_Att_encode(queries, keys, num_units=None, num_heads=8,
                  dropout_keep_rate=1.0, causality=False,
                  scope='MultiHead_Attention_Block', reuse=None,
                  residual=False, normalize_output=True, **kwargs):
    """Applies multihead attention.

    Args:
      queries: A 3d tensor with shape of [N, T_q, C_q].
      keys: A 3d tensor with shape of [N, T_k, C_k].
      num_units: A scalar. Attention size.
      num_heads: An int. Number of heads.
      dropout_keep_rate: A floating point number. Keep probability for the
        attention dropout.
      causality: Boolean. If true, units that reference the future are masked.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.
      residual: Boolean. If true, add the queries to the output.
      normalize_output: Boolean. If true, apply layer normalization to the output.

    Returns:
      A 3d tensor with shape of (N, T_q, C)
    """
    # Set the fall back option for num_units
    if num_units is None or residual:
        num_units = queries.get_shape().as_list()[-1]

    with tf.variable_scope(scope, reuse=reuse):
        # Linear projections
        Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu)  # (N, T_q, C)
        K = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)
        V = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)

        # Split and concat
        Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, C/h)
        K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)
        V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)

        # Multiplication
        outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)

        # Scale
        outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

        # Key Masking
        key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1)))  # (N, T_k)
        key_masks = tf.tile(key_masks, [num_heads, 1])  # (h*N, T_k)
        key_masks = tf.tile(tf.expand_dims(key_masks, 1),
                            [1, tf.shape(queries)[1], 1])  # (h*N, T_q, T_k)

        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
        outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Causality = Future blinding
        if causality:
            diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
            tril = tf.contrib.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
            masks = tf.tile(tf.expand_dims(tril, 0),
                            [tf.shape(outputs)[0], 1, 1])  # (h*N, T_q, T_k)

            paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
            outputs = tf.where(tf.equal(masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Activation
        outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)

        # Query Masking
        query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1)))  # (N, T_q)
        query_masks = tf.tile(query_masks, [num_heads, 1])  # (h*N, T_q)
        query_masks = tf.tile(tf.expand_dims(query_masks, -1),
                              [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
        outputs *= query_masks  # (h*N, T_q, T_k)

        # Dropouts
        outputs = tf.nn.dropout(outputs, keep_prob=dropout_keep_rate)

        # Weighted sum
        outputs = tf.matmul(outputs, V_)  # (h*N, T_q, C/h)

        # Restore shape
        outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, C)

        if residual:  # Residual connection
            outputs += queries

        if normalize_output:  # Normalize
            outputs = layer_norm(outputs)  # (N, T_q, C)

    return outputs
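
# Illustrative usage sketch, not part of the original module: causal
# self-attention over a batch of encoded sequences. Batch size, sequence
# length, and hidden size are made up; the hidden size must be divisible by
# num_heads because of the head split, and zero-padded timesteps are masked
# automatically since the key/query masks are inferred from the summed
# embeddings inside MH_Att_encode.
def _demo_mh_att_encode():
    seqs = tf.random_normal([4, 20, 128])                       # (N, T, C)
    return MH_Att_encode(seqs, seqs, num_heads=8,
                         dropout_keep_rate=0.9,
                         causality=True,                        # blind future steps
                         residual=True, normalize_output=True)  # (N, T, C)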
def Transformer_match(context, query, context_mask, query_mask,
                      num_units=None, num_heads=1, dropout_keep_rate=1.0,
                      causality=False, scope='MultiHead_Attention_Block',
                      reuse=None, residual=False, normalize_output=False,
                      **kwargs):
    """Applies multihead attention from the context over the query.

    Args:
      context: A 3d tensor with shape of [N, T_q, C_q].
      query: A 3d tensor with shape of [N, T_k, C_k].
      context_mask: A 2d 0/1 tensor with shape of [N, T_q], or None to infer
        the mask from the context.
      query_mask: A 2d 0/1 tensor with shape of [N, T_k], or None to infer
        the mask from the query.
      num_units: A scalar. Attention size.
      num_heads: An int. Number of heads.
      dropout_keep_rate: A floating point number. Keep probability for the
        attention dropout.
      causality: Boolean. If true, units that reference the future are masked.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.
      residual: Boolean. If true, add the context to the output.
      normalize_output: Boolean. If true, apply layer normalization to the output.

    Returns:
      A 3d tensor with shape of (N, T_q, C)
    """
    # Set the fall back option for num_units
    if num_units is None or residual:
        num_units = context.get_shape().as_list()[-1]

    with tf.variable_scope(scope, reuse=reuse):
        # Linear projections
        Q = tf.layers.dense(context, num_units, activation=tf.nn.relu)  # (N, T_q, C)
        K = tf.layers.dense(query, num_units, activation=tf.nn.relu)  # (N, T_k, C)
        V = tf.layers.dense(query, num_units, activation=tf.nn.relu)  # (N, T_k, C)

        # Split and concat
        Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, C/h)
        K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)
        V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)

        # Multiplication
        outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)

        # Scale
        outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

        # Key masking (the query sequence plays the role of the keys)
        if query_mask is None:
            query_mask = tf.sign(tf.abs(tf.reduce_sum(query, axis=-1)))  # (N, T_k)
        mask1 = tf.tile(query_mask, [num_heads, 1])  # (h*N, T_k)
        mask1 = tf.tile(tf.expand_dims(mask1, 1),
                        [1, tf.shape(context)[1], 1])  # (h*N, T_q, T_k)

        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
        outputs = tf.where(tf.equal(mask1, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Causality = Future blinding
        if causality:
            diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
            tril = tf.contrib.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
            masks = tf.tile(tf.expand_dims(tril, 0),
                            [tf.shape(outputs)[0], 1, 1])  # (h*N, T_q, T_k)

            paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
            outputs = tf.where(tf.equal(masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Activation
        outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)

        # Query masking (the context sequence plays the role of the queries)
        if context_mask is None:
            context_mask = tf.sign(tf.abs(tf.reduce_sum(context, axis=-1)))  # (N, T_q)
        mask2 = tf.tile(context_mask, [num_heads, 1])  # (h*N, T_q)
        mask2 = tf.tile(tf.expand_dims(mask2, -1),
                        [1, 1, tf.shape(query)[1]])  # (h*N, T_q, T_k)
        outputs *= mask2  # (h*N, T_q, T_k)

        # Dropouts
        outputs = tf.nn.dropout(outputs, keep_prob=dropout_keep_rate)

        # Weighted sum
        outputs = tf.matmul(outputs, V_)  # (h*N, T_q, C/h)

        # Restore shape
        outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, C)

        if residual:  # Residual connection
            outputs += context

        if normalize_output:  # Normalize
            outputs = layer_norm(outputs)  # (N, T_q, C)

    return outputs
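
# Illustrative usage sketch, not part of the original module: cross-attention
# where each context position attends over the query sequence. The masks are
# 0/1 float tensors of shape (N, T); passing None for either mask falls back
# to inferring it from the summed embeddings, as the function itself does.
# Batch size, lengths, and hidden size below are made up.
def _demo_transformer_match():
    context = tf.random_normal([4, 30, 128])  # (N, T_q, C)
    query = tf.random_normal([4, 10, 128])    # (N, T_k, C)
    context_mask = tf.ones([4, 30])           # (N, T_q)
    query_mask = tf.ones([4, 10])             # (N, T_k)
    return Transformer_match(context, query, context_mask, query_mask,
                             num_heads=1, dropout_keep_rate=1.0)  # (N, T_q, C)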