Example #1
def AttentiveCNN_match(context,
                       query,
                       context_mask,
                       query_mask,
                       scope='AttentiveCNN_Block',
                       residual=False,
                       normalize_output=False,
                       reuse=None,
                       **kwargs):
    """Matches a context against a query with attention-augmented CNN encoding.

    The context is encoded with a width-3 CNN, the query-attended context with
    a width-1 CNN, and the two are fused through tanh, with an optional
    residual connection and layer normalization.
    """
    with tf.variable_scope(scope, reuse=reuse):
        cnn_wo_att = CNN_encode(context,
                                filter_size=3,
                                direction='none',
                                act_fn=None)
        att_context, _ = Attentive_match(context, query, context_mask,
                                         query_mask)
        cnn_att = CNN_encode(att_context,
                             filter_size=1,
                             direction='none',
                             act_fn=None)
        output = tf.nn.tanh(cnn_wo_att + cnn_att)
        if residual:
            # Residual connection
            output += context

        if normalize_output:
            # Normalize
            output = layer_norm(output)  # (N, T_q, C)

        return output
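A minimal usage sketch for AttentiveCNN_match, assuming TensorFlow 1.x graph mode, that CNN_encode, Attentive_match, and layer_norm are importable from the same module, and that CNN_encode preserves the channel dimension (needed for the residual connection). All shapes below are illustrative.

import tensorflow as tf

# Hypothetical shapes: context of length 30, query of length 10, 128-dim features.
context = tf.placeholder(tf.float32, [None, 30, 128], name='context')
query = tf.placeholder(tf.float32, [None, 10, 128], name='query')
context_mask = tf.placeholder(tf.float32, [None, 30], name='context_mask')
query_mask = tf.placeholder(tf.float32, [None, 10], name='query_mask')

# Fuse the plain CNN encoding of the context with the CNN encoding of the
# query-attended context; add the residual and normalize the result.
matched = AttentiveCNN_match(context, query, context_mask, query_mask,
                             residual=True, normalize_output=True)
# matched keeps the context shape: (N, 30, 128)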
Example #2
def TCN_encode(seqs,
               num_layers,
               normalize_output=True,
               scope='tcn_encode_block',
               reuse=None,
               layer_norm_scope='layer_norm',
               **kwargs):
    """Encodes a sequence with a stack of dilated residual dual-CNN blocks.

    The dilation doubles at each layer (1, 2, 4, ...), so the receptive field
    grows exponentially with depth; the final output is optionally
    layer-normalized.
    """
    with tf.variable_scope(scope, reuse=reuse):
        outputs = [seqs]
        for i in range(num_layers):
            dilation_size = 2**i
            out = Res_DualCNN_encode(outputs[-1],
                                     dilation=dilation_size,
                                     scope='res_biconv_%d' % i,
                                     **kwargs)
            outputs.append(out)
        result = outputs[-1]
        if normalize_output:
            result = layer_norm(result, scope=layer_norm_scope, reuse=reuse)
        return result
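A usage sketch for TCN_encode, assuming TensorFlow 1.x and that Res_DualCNN_encode and layer_norm come from the same module. With four layers the dilation sizes are 1, 2, 4, 8, so the receptive field grows exponentially with depth.

import tensorflow as tf

# Hypothetical shape: sequences of length 100 with 64-dim embeddings.
seqs = tf.placeholder(tf.float32, [None, 100, 64], name='seqs')

# Stack four residual dual-CNN layers with dilations 1, 2, 4, 8 and
# layer-normalize the final output.
encoded = TCN_encode(seqs, num_layers=4, normalize_output=True)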
Example #3
def MH_Att_encode(queries,
                  keys,
                  num_units=None,
                  num_heads=8,
                  dropout_keep_rate=1.0,
                  causality=False,
                  scope='MultiHead_Attention_Block',
                  reuse=None,
                  residual=False,
                  normalize_output=True,
                  **kwargs):
    """Applies multihead attention.

    Args:
      queries: A 3d tensor with shape of [N, T_q, C_q].
      keys: A 3d tensor with shape of [N, T_k, C_k].
      num_units: A scalar. Attention size.
      dropout_rate: A floating point number.
      is_training: Boolean. Controller of mechanism for dropout.
      causality: Boolean. If true, units that reference the future are masked.
      num_heads: An int. Number of heads.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.

    Returns
      A 3d tensor with shape of (N, T_q, C)
    """
    # Set the fallback for num_units: use the input dimensionality when it is
    # unset, and force it when a residual connection is requested so that the
    # shapes match.
    if num_units is None or residual:
        num_units = queries.get_shape().as_list()[-1]
    with tf.variable_scope(scope, reuse=reuse):
        # Linear projections
        Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu)  # (N, T_q, C)
        K = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)
        V = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)

        # Split and concat
        Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, C/h)
        K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)
        V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)

        # Multiplication
        outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)

        # Scale
        outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

        # Key Masking
        key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1)))  # (N, T_k)
        key_masks = tf.tile(key_masks, [num_heads, 1])  # (h*N, T_k)
        key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1])  # (h*N, T_q, T_k)

        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
        outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Causality = Future blinding
        if causality:
            diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
            tril = tf.contrib.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
            masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1])  # (h*N, T_q, T_k)

            paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
            outputs = tf.where(tf.equal(masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Activation
        outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)

        # Query Masking
        query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1)))  # (N, T_q)
        query_masks = tf.tile(query_masks, [num_heads, 1])  # (h*N, T_q)
        query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
        outputs *= query_masks  # broadcasting. (h*N, T_q, T_k)

        # Dropouts
        outputs = tf.nn.dropout(outputs, keep_prob=dropout_keep_rate)

        # Weighted sum
        outputs = tf.matmul(outputs, V_)  # ( h*N, T_q, C/h)

        # Restore shape
        outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, C)

        if residual:
            # Residual connection
            outputs += queries

        if normalize_output:
            # Normalize
            outputs = layer_norm(outputs)  # (N, T_q, C)

    return outputs
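A self-attention usage sketch for MH_Att_encode, assuming TensorFlow 1.x. Passing the same tensor as queries and keys with causality=True gives decoder-style masked self-attention; note that the key and query masks are derived from all-zero time steps, so inputs are expected to be zero-padded.

import tensorflow as tf

# Hypothetical shape: zero-padded sequences of length 50 with 256-dim features.
x = tf.placeholder(tf.float32, [None, 50, 256], name='x')

# Causal multi-head self-attention: 8 heads, 10% attention dropout, residual
# connection and layer normalization on the output.
self_att = MH_Att_encode(x, x,
                         num_heads=8,
                         dropout_keep_rate=0.9,
                         causality=True,
                         residual=True,
                         normalize_output=True)
# self_att has shape (N, 50, 256); residual=True forces num_units to 256.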
Example #4
def Transformer_match(context,
                      query,
                      context_mask,
                      query_mask,
                      num_units=None,
                      num_heads=1,
                      dropout_keep_rate=1.0,
                      causality=False,
                      scope='MultiHead_Attention_Block',
                      reuse=None,
                      residual=False,
                      normalize_output=False,
                      **kwargs):
    """Applies multihead attention.

    Args:
      context: A 3d tensor with shape of [N, T_q, C_q].
      query: A 3d tensor with shape of [N, T_k, C_k].
      num_units: A scalar. Attention size.
      dropout_rate: A floating point number.
      is_training: Boolean. Controller of mechanism for dropout.
      causality: Boolean. If true, units that reference the future are masked.
      num_heads: An int. Number of heads.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.

    Returns
      A 3d tensor with shape of (N, T_q, C)
    """
    # Set the fallback for num_units: use the input dimensionality when it is
    # unset, and force it when a residual connection is requested so that the
    # shapes match.
    if num_units is None or residual:
        num_units = context.get_shape().as_list()[-1]
    with tf.variable_scope(scope, reuse=reuse):
        # Linear projections
        Q = tf.layers.dense(context, num_units, activation=tf.nn.relu)  # (N, T_q, C)
        K = tf.layers.dense(query, num_units, activation=tf.nn.relu)  # (N, T_k, C)
        V = tf.layers.dense(query, num_units, activation=tf.nn.relu)  # (N, T_k, C)

        # Split and concat
        Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, C/h)
        K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)
        V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)

        # Multiplication
        outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)

        # Scale
        outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

        # Key masking (the key side here is the query sequence)
        if query_mask is None:
            query_mask = tf.sign(tf.abs(tf.reduce_sum(query, axis=-1)))  # (N, T_k)

        mask1 = tf.tile(query_mask, [num_heads, 1])  # (h*N, T_k)
        mask1 = tf.tile(tf.expand_dims(mask1, 1), [1, tf.shape(context)[1], 1])  # (h*N, T_q, T_k)

        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
        outputs = tf.where(tf.equal(mask1, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Causality = Future blinding
        if causality:
            diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
            tril = tf.contrib.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
            masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1])  # (h*N, T_q, T_k)

            paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
            outputs = tf.where(tf.equal(masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Activation
        outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)

        # Query masking (the query side here is the context sequence)
        if context_mask is None:
            context_mask = tf.sign(tf.abs(tf.reduce_sum(context, axis=-1)))  # (N, T_q)

        mask2 = tf.tile(context_mask, [num_heads, 1])  # (h*N, T_q)
        mask2 = tf.tile(tf.expand_dims(mask2, -1), [1, 1, tf.shape(query)[1]])  # (h*N, T_q, T_k)
        outputs *= mask2  # (h*N, T_q, T_k)

        # Dropouts
        outputs = tf.nn.dropout(outputs, keep_prob=dropout_keep_rate)

        # Weighted sum
        outputs = tf.matmul(outputs, V_)  # ( h*N, T_q, C/h)

        # Restore shape
        outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, C)

        if residual:
            # Residual connection
            outputs += context

        if normalize_output:
            # Normalize
            outputs = layer_norm(outputs)  # (N, T_q, C)

    return outputs
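A cross-attention usage sketch for Transformer_match, assuming TensorFlow 1.x. Here the masks are built explicitly from sequence lengths with tf.sequence_mask rather than inferred from all-zero rows; the shapes and lengths below are illustrative.

import tensorflow as tf

# Hypothetical shapes: context of length 40, query of length 12, 128-dim features.
context = tf.placeholder(tf.float32, [None, 40, 128], name='context')
query = tf.placeholder(tf.float32, [None, 12, 128], name='query')
context_len = tf.placeholder(tf.int32, [None], name='context_len')
query_len = tf.placeholder(tf.int32, [None], name='query_len')

# Binary masks: 1 for real tokens, 0 for padding.
context_mask = tf.cast(tf.sequence_mask(context_len, 40), tf.float32)  # (N, T_q)
query_mask = tf.cast(tf.sequence_mask(query_len, 12), tf.float32)      # (N, T_k)

# Each context position attends over the query; single head, no future masking.
matched = Transformer_match(context, query, context_mask, query_mask,
                            num_heads=1,
                            dropout_keep_rate=0.9,
                            residual=True,
                            normalize_output=True)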