Example #1
def align_block(u, v, c_mask, q_mask, filters=128, dropout=0.0):
    with tf.variable_scope("Interactive_Alignment"):
        # attention
        E = tf.matmul(v, u, transpose_b=True)  # [bs, len_q, len_c]
        E_ = tf.nn.softmax(exp_mask(E, tf.expand_dims(q_mask, axis=-1)),
                           axis=1)  # [bs, len_q, len_c]
        v_E = tf.matmul(E_, v, transpose_a=True)  # [bs, len_c, dim]

        # fusion
        uv = tf.concat([u, v_E, u * v_E, u - v_E], axis=-1)
        x = tf.nn.tanh(conv1d(uv, filters, 1, name='Wr'))
        g = tf.nn.sigmoid(conv1d(uv, filters, 1, name='Wg'))
        h = g * x + (1 - g) * u  # [bs, len_c, dim]

    with tf.variable_scope("Self_Alignment"):
        # attention
        B = tf.matmul(h, h, transpose_b=True)  # [bs, len_c, len_c]
        B = tf.matrix_set_diag(B, tf.zeros([tf.shape(B)[0], tf.shape(B)[-1]]))
        B_ = tf.nn.softmax(exp_mask(B, tf.expand_dims(c_mask, axis=-1)),
                           axis=1)  # [bs, len_c, len_c]
        h_B = tf.matmul(B_, h, transpose_a=True)

        # fusion
        hh = tf.concat([h, h_B, h * h_B, h - h_B], axis=-1)
        x = tf.nn.tanh(conv1d(hh, filters, 1, name='Wr'))
        g = tf.nn.sigmoid(conv1d(hh, filters, 1, name='Wg'))
        Z = g * x + (1 - g) * h  # [bs, len_c, dim]

    with tf.variable_scope("Evidence_Collection"):
        R = BiLSTM(Z, filters // 2, name='bilstm',
                   dropout=dropout)  # [bs, len_c, dim]

    return R
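A minimal usage sketch for the block above. The tensor names, shapes, and hyperparameters here are illustrative assumptions, and the conv1d, exp_mask, and BiLSTM helpers it calls are expected to be in scope:

import tensorflow as tf

# Hypothetical inputs: encoded context u and question v, plus float padding
# masks (1.0 at valid positions), in whatever form exp_mask expects.
bs, len_c, len_q, dim = 32, 400, 50, 128
u = tf.random_normal([bs, len_c, dim])
v = tf.random_normal([bs, len_q, dim])
c_mask = tf.ones([bs, len_c])
q_mask = tf.ones([bs, len_q])
R = align_block(u, v, c_mask, q_mask, filters=dim, dropout=0.1)  # [bs, len_c, dim]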
Example #2
def start_logits(R, s, mask, filters=28):
    with tf.variable_scope("Start_Pointer"):
        logits1 = tf.concat([R, s, R * s, R - s], axis=-1)
        logits1 = tf.nn.tanh(conv1d(logits1, filters, 1, name="W1"))
        logits1 = tf.squeeze(conv1d(logits1, 1, 1, name="W1t"), axis=-1)
        logits1 = exp_mask(logits1, mask)
    return logits1
def ffn_self_attention_layer(x,
                             filter_depth,
                             output_depth,
                             num_parts,
                             dropout_rate,
                             share_kv=False,
                             name=None):
  """Self-attention feedforward layer.

  We use self-attention to do feedforward computations. We apply this function
  positionwise where for each position, we linearly transform the output to have
  depth filter_depth, and break up the result depth-wise into num_parts
  contiguous parts.  The parts self-attend, then we concatenate the results
  depth-wise, and we linearly transform to a depth of output_depth. The
  goal is to get multiplicative interactions between components of a
  representation.

  Args:
    x: a Tensor with shape [batch, length, channels]
    filter_depth: an integer
    output_depth: an integer
    num_parts: an integer dividing filter depth
    dropout_rate: a floating point number
    share_kv: Share the key value transform
    name: an optional string

  Returns:
    A Tensor.
  """

  with tf.variable_scope(
      name, default_name="feedforward_self_attention", values=[x]):
    x_shape = tf.shape(x)
    part_depth = filter_depth // num_parts
    if not share_kv:
      combined = common_layers.conv1d(
          x, filter_depth * 3, 1, name="qkv_transform")
      combined = tf.expand_dims(combined, axis=2)
      q, k, v = tf.split(combined, 3, axis=3)
    else:
      q = tf.expand_dims(
          common_layers.conv1d(x, filter_depth, 1, name="q_transform"), axis=2)
      kv_combined = tf.expand_dims(
          common_layers.conv1d(
              tf.concat([x, x], axis=1), filter_depth, 1, name="kv_transform"),
          axis=2)
      k, v = tf.split(kv_combined, [x_shape[1], x_shape[1]], axis=1)

    batch_q = tf.reshape(q, [-1, 1, num_parts, part_depth])
    batch_k = tf.reshape(k, [-1, 1, num_parts, part_depth])
    batch_v = tf.reshape(v, [-1, 1, num_parts, part_depth])

    batch_q *= part_depth**-0.5
    # non-masked bias
    bias = None
    x = dot_product_attention(batch_q, batch_k, batch_v, bias, dropout_rate)
    x = tf.reshape(x, [x_shape[0], x_shape[1], filter_depth])
    x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
    return x
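A quick call sketch, with made-up shapes, assuming common_layers and dot_product_attention are imported as in tensor2tensor:

import tensorflow as tf

x = tf.random_normal([8, 20, 64])  # [batch, length, channels]
y = ffn_self_attention_layer(x, filter_depth=256, output_depth=64,
                             num_parts=4, dropout_rate=0.0)
# y has shape [8, 20, 64]: 4 parts of depth 64 self-attend at each position.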
Example #4
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    # Run the basic autoencoder part first.
    basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
    if hparams.autoregressive_mode == "none":
      assert not hparams.autoregressive_forget_base
      return basic_result, losses
    shape = common_layers.shape_list(basic_result)
    basic1d = tf.reshape(basic_result, [shape[0], -1, shape[3]])
    # During autoregressive inference, don't resample.
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hasattr(hparams, "sampled_basic1d_tensor"):
        basic1d = hparams.sampled_basic1d_tensor
      else:
        hparams.sampled_basic1d_tensor = basic1d
    # Prepare inputs for autoregressive modes.
    if common_layers.shape_list(features["targets"])[1] == 1:
      # This happens on the first step of predictions.
      assert hparams.mode == tf.estimator.ModeKeys.PREDICT
      features["targets"] = tf.zeros_like(basic_result)
    targets_dropout = common_layers.mix(
        features["targets"], tf.zeros_like(basic_result),
        hparams.bottleneck_warmup_steps, is_training,
        max_prob=1.0 - hparams.autoregressive_dropout, broadcast_last=True)
    # Sometimes it's useful to look at non-autoregressive evals.
    if (hparams.mode == tf.estimator.ModeKeys.EVAL and
        hparams.autoregressive_eval_pure_autoencoder):
      targets_dropout = tf.zeros_like(basic_result)
    # Now combine the basic reconstruction with shifted targets.
    targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[3]])
    targets_shifted = common_layers.shift_right_3d(targets1d)
    concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
    # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
    if hparams.autoregressive_forget_base:
      concat1d = tf.reshape(features["targets"], [shape[0], -1, shape[3]])
      concat1d = common_layers.shift_right_3d(concat1d)
    # The autoregressive part depends on the mode.
    if hparams.autoregressive_mode == "conv3":
      res = common_layers.conv1d(concat1d, shape[3], 3, padding="LEFT",
                                 activation=common_layers.belu,
                                 name="autoregressive_conv3")
      return tf.reshape(res, shape), losses
    if hparams.autoregressive_mode == "conv5":
      res = common_layers.conv1d(concat1d, shape[3], 5, padding="LEFT",
                                 activation=common_layers.belu,
                                 name="autoregressive_conv5")
      return tf.reshape(res, shape), losses
    if hparams.autoregressive_mode == "sru":
      res = common_layers.conv1d(concat1d, shape[3], 3, padding="LEFT",
                                 activation=common_layers.belu,
                                 name="autoregressive_sru_conv3")
      res = common_layers.sru(res)
      return tf.reshape(res, shape), losses

    raise ValueError("Unsupported autoregressive mode: %s"
                     % hparams.autoregressive_mode)
Example #5
def start_logits(R, s, mask, filters=128, name='Start_Pointer'):
    with tf.variable_scope(name):
        if R.get_shape()[-1] == s.get_shape()[-1]:
            logits1 = tf.concat([R, s, R * s, R - s], axis=-1)
        else:
            logits1 = tf.concat([R, s], axis=-1)
        logits1 = tf.nn.tanh(conv1d(logits1, filters, 1, name='Wt'))
        logits1 = tf.squeeze(conv1d(logits1, 1, 1, name='Wf'), axis=-1)
        logits1 = exp_mask(logits1, mask)
    return logits1
Example #6
def end_logits(R, logits1, s, mask, filters=128):
    with tf.variable_scope("End_Pointer"):
        l = R * tf.expand_dims(logits1, axis=-1)  # [bs, len_c, dim]
        s_ = tf.concat([s, l, s * l, s - l], axis=-1)
        x = tf.nn.relu(conv1d(s_, filters, 1, name="Wr"))  # [bs, len_c, dim]
        g = tf.sigmoid(conv1d(s_, filters, 1, name="Wg"))  # [bs, len_c, dim]
        s_ = g * x + (1 - g) * s  # [bs, len_c, dim]

        logits2 = tf.concat([R, s_, R * s_, R - s_], axis=-1)
        logits2 = tf.nn.tanh(conv1d(logits2, filters, 1, name="W2"))
        logits2 = tf.squeeze(conv1d(logits2, 1, 1, name="W2t"), axis=-1)
        logits2 = exp_mask(logits2, mask)
    return logits2
Example #7
    def body(self, features):
        hparams = self._hparams
        shape = common_layers.shape_list(features["targets"])
        # Run the basic autoencoder part first.
        basic_result, losses = super(AutoencoderAutoregressive,
                                     self).body(features)
        # Prepare inputs for autoregressive modes.
        targets_keep_prob = 1.0 - hparams.autoregressive_dropout
        targets_dropout = common_layers.dropout_with_broadcast_dims(
            features["targets"], targets_keep_prob, broadcast_dims=[-1])
        targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[3]])
        targets_shifted = common_layers.shift_right_3d(targets1d)
        basic1d = tf.reshape(basic_result, [shape[0], -1, shape[3]])
        concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
        # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
        if hparams.autoregressive_forget_base:
            concat1d = tf.reshape(features["targets"],
                                  [shape[0], -1, shape[3]])
            concat1d = common_layers.shift_right_3d(concat1d)
        # The autoregressive part depends on the mode.
        if hparams.autoregressive_mode == "none":
            assert not hparams.autoregressive_forget_base
            return basic_result, losses
        if hparams.autoregressive_mode == "conv3":
            res = common_layers.conv1d(concat1d,
                                       shape[3],
                                       3,
                                       padding="LEFT",
                                       activation=common_layers.belu,
                                       name="autoregressive_conv3")
            return tf.reshape(res, shape), losses
        if hparams.autoregressive_mode == "conv5":
            res = common_layers.conv1d(concat1d,
                                       shape[3],
                                       5,
                                       padding="LEFT",
                                       activation=common_layers.belu,
                                       name="autoregressive_conv5")
            return tf.reshape(res, shape), losses
        if hparams.autoregressive_mode == "sru":
            res = common_layers.conv1d(concat1d,
                                       shape[3],
                                       3,
                                       padding="LEFT",
                                       activation=common_layers.belu,
                                       name="autoregressive_sru_conv3")
            res = common_layers.sru(res)
            return tf.reshape(res, shape), losses

        raise ValueError("Unsupported autoregressive mode: %s" %
                         hparams.autoregressive_mode)
def summary_vector(q_emb, c_maxlen, mask):
    with tf.variable_scope("Question_Summary"):
        alpha = tf.nn.softmax(exp_mask(tf.squeeze(conv1d(q_emb, 1, 1), axis=-1), mask))
        s = tf.expand_dims(alpha, axis=-1) * q_emb
        s = tf.reduce_sum(s, axis=1, keepdims=True)  # [bs, 1, dim]
        s = tf.tile(s, [1, c_maxlen, 1])  # [bs, len_c, dim]
    return s
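A hypothetical end-to-end wiring of the pointer pieces above, assuming R is an align_block output, q_emb/q_mask/c_mask come from the encoder, and the start_logits/end_logits variants shown earlier are in scope:

s = summary_vector(q_emb, c_maxlen=400, mask=q_mask)           # [bs, 400, dim]
logits1 = start_logits(R, s, mask=c_mask, filters=128)         # [bs, 400]
logits2 = end_logits(R, logits1, s, mask=c_mask, filters=128)  # [bs, 400]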
Example #9
  def testConv1d(self):
    x = np.random.rand(5, 7, 11)
    with self.test_session() as session:
      y = common_layers.conv1d(tf.constant(x, dtype=tf.float32), 13, 1)
      session.run(tf.global_variables_initializer())
      res = session.run(y)
    self.assertEqual(res.shape, (5, 7, 13))
Example #10
def align_block(u,
                v,
                c_mask,
                q_mask,
                Lambda,
                filters=128,
                E_0=None,
                B_0=None,
                Z_0=None):
    with tf.variable_scope("Interactive_Alignment"):
        u_ = tf.nn.relu(conv1d(u, filters, 1, name="Wu"))
        v_ = tf.nn.relu(conv1d(v, filters, 1, name="Wv"))
        E = tf.matmul(v_, u_, transpose_b=True)  # [bs, len_q, len_c]
        if E_0 is not None:
            E += (Lambda * E_0)
        E_ = tf.nn.softmax(exp_mask(E, tf.expand_dims(q_mask, axis=-1)),
                           axis=1)  # [bs, len_q, len_c]
        v_E = tf.matmul(E_, v, transpose_a=True)  # [bs, len_c, dim]

        # fusion
        uv = tf.concat([u, v_E, u * v_E, u - v_E], axis=-1)
        x_ = tf.nn.relu(conv1d(uv, filters, 1, name="Wr"))
        g = tf.nn.sigmoid(conv1d(uv, filters, 1, name="Wg"))
        o = g * x_ + (1 - g) * u  # [bs, len_c, dim]

    with tf.variable_scope("Self_Alignment"):
        # attention
        h_1 = tf.nn.relu(conv1d(o, filters, 1, name="Wh1"))
        h_2 = tf.nn.relu(conv1d(o, filters, 1, name="Wh2"))
        B = tf.matmul(h_2, h_1, transpose_b=True)  # [bs, len_c, len_c]
        if B_0 is not None:
            B += (Lambda * B_0)
        B_ = tf.nn.softmax(exp_mask(B, tf.expand_dims(c_mask, axis=-1)),
                           axis=1)  # [bs, len_c, len_c]
        o_B = tf.matmul(B_, o, transpose_a=True)

        # fusion
        oo = tf.concat([o, o_B, o * o_B, o - o_B], axis=-1)
        x_ = tf.nn.relu(conv1d(oo, filters, 1, name="Wr"))
        g = tf.nn.sigmoid(conv1d(oo, filters, 1, name="Wg"))
        Z = g * x_ + (1 - g) * o  # [bs, len_c, dim]
    with tf.variable_scope("Evidence_Collection"):
        if Z_0 is not None:
            Z = tf.concat([Z, Z_0[0], Z_0[1]], axis=-1)
        R = layers.Bidirectional(
            layers.LSTM(filters // 2,
                        return_sequences=True))(Z)  # [bs, len_c, dim]

    # return the E_t, B_t
    E_t = tf.nn.softmax(exp_mask(E, tf.expand_dims(c_mask, axis=1)),
                        axis=-1)  # [bs, len_q, len_c]
    E_t = tf.matmul(E_t, B_)
    B_t = tf.nn.softmax(exp_mask(B, tf.expand_dims(c_mask, axis=1)),
                        axis=-1)  # [bs, len_c, len_c]
    B_t = tf.matmul(B_t, B_)

    return R, Z, E_t, B_t
  def _compute(inp, depth, filter_width, padding, name):
    if filter_width == 1:
      return common_layers.dense(inp, depth, use_bias=False, name=name)
    else:
      return common_layers.conv1d(inp,
                                  depth,
                                  filter_width,
                                  padding=padding,
                                  name=name)
def end_logits(R, logits1, s, mask, filters=128, name='End_Pointer'):
    with tf.variable_scope(name):
        l = R * tf.expand_dims(tf.nn.softmax(logits1, axis=-1), axis=-1)  # [bs, len_c, dim] ## R3 * p1
        if s.get_shape()[-1] == l.get_shape()[-1]:
            s_ = tf.concat([s, l, s * l, s - l], axis=-1) ## x; y; x * y; x - y
        else:
            s_ = tf.concat([s, l], axis=-1)
        x = tf.nn.relu(conv1d(s_, filters, 1, name='Wr'))  # [bs, len_c, dim] ## relu(Wr * [x; y; x * y; x - y])
        g = tf.nn.sigmoid(conv1d(s_, filters, 1, name='Wg'))  # [bs, len_c, dim] ## sigmoid(Wg * [x; y; x * y; x - y])
        s_ = g * x + (1 - g) * s  # [bs, len_c, dim] ## g * x~ + (1 - g) * x

        if R.get_shape()[-1] == s_.get_shape()[-1]:
            logits2 = tf.concat([R, s_, R * s_, R - s_], axis=-1) ## R; s~; R * s~; R - s~
        else:
            logits2 = tf.concat([R, s_], axis=-1)
        logits2 = tf.nn.tanh(conv1d(logits2, filters, 1, name='Wt')) ## ## tanh(W2 * logit)
        logits2 = tf.squeeze(conv1d(logits2, 1, 1, name='Wf'), axis=-1)
        logits2 = exp_mask(logits2, mask) ## exp
    return logits2
def FeedForward(x, filters, dropout, name):
    with tf.variable_scope(name):
        x = tf.nn.relu(
            conv1d(x,
                   filters,
                   kernel_size=1,
                   padding='same',
                   name='FFN_1',
                   kernel_initializer=initializer_relu(),
                   kernel_regularizer=regularizer))
        x = conv1d(x,
                   filters,
                   kernel_size=1,
                   padding='same',
                   name="FFN_2",
                   kernel_initializer=initializer(),
                   kernel_regularizer=regularizer)
        x = tf.nn.dropout(x, 1 - dropout)
    return x
def transform_qkv(query_antecedent,
                  key_antecedent,
                  value_antecedent,
                  total_key_depth,
                  total_value_depth,
                  q_filter_width=1,
                  kv_filter_width=1,
                  q_padding="VALID",
                  kv_padding="VALID"):
    """Computes query, key and value.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    key_antecedent: a Tensor with shape [batch, length_kv, channels]
    value_antecedent: a Tensor with shape [batch, length_kv, channels]
    total_key_depth: an integer
    total_value_depth: an integer
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
    to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.

  Returns:
    q, k, v : [batch, length, depth] tensors
  """
    q = common_layers.conv1d(query_antecedent,
                             total_key_depth,
                             q_filter_width,
                             padding=q_padding,
                             name="q_transform")
    k = common_layers.conv1d(key_antecedent,
                             total_key_depth,
                             kv_filter_width,
                             padding=kv_padding,
                             name="k_transform")
    v = common_layers.conv1d(value_antecedent,
                             total_value_depth,
                             kv_filter_width,
                             padding=kv_padding,
                             name="v_transform")
    return q, k, v
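For example (hypothetical tensors; in self-attention the three antecedents can be the same tensor):

x = tf.random_normal([2, 10, 64])  # assumed [batch, length, channels]
q, k, v = transform_qkv(x, x, x, total_key_depth=64, total_value_depth=64)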
Example #16
def linear_sum_attention(x, mask, dropout):
    alpha = tf.squeeze(conv1d(x,
                              1,
                              1,
                              kernel_initializer=initializer,
                              kernel_regularizer=regularizer),
                       axis=-1)  # [bs, c_len]
    alpha = exp_mask(alpha, mask)  # [bs, c_len]
    alpha = tf.expand_dims(tf.nn.softmax(alpha), axis=1)  # [bs, 1, c_len]
    x = tf.squeeze(tf.matmul(alpha, x), axis=1)  # [bs, dim]
    x = tf.nn.dropout(x, 1.0 - dropout)
    return x
def compute_qkv_pos(query_antecedent,
                    memory_antecedent,
                    total_key_depth,
                    total_value_depth,
                    qkv_padding="VALID"):
    """Computes query, key and value.
  Returns:
    q, k, v : [batch, length, depth] tensors
  """
    if memory_antecedent is None:
        # self attention
        combined = common_layers.conv1d(query_antecedent,
                                        total_key_depth * 2 +
                                        total_value_depth,
                                        1,
                                        padding=qkv_padding,
                                        name="qkv_transform")
        q, k, v = tf.split(
            combined, [total_key_depth, total_key_depth, total_value_depth],
            axis=2)
        return q, k, v

    # encoder-decoder attention
    q = common_layers.conv1d(query_antecedent,
                             total_key_depth,
                             1,
                             padding=qkv_padding,
                             name="q_transform")
    k = common_layers.conv1d(query_antecedent,
                             total_key_depth,
                             1,
                             padding=qkv_padding,
                             name="k_transform")
    v = common_layers.conv1d(memory_antecedent,
                             total_value_depth,
                             1,
                             padding=qkv_padding,
                             name="v_transform")

    return q, k, v
Example #18
def answer_block(R, z1, filters, c_mask, dropout=0.0, return_logits=True):
    # start
    z_s = tf.tile(tf.expand_dims(z1, axis=1),
                  [1, tf.shape(R)[1], 1])  # [bs, 1*c_len, dim]
    s = feed_forward(tf.concat([R, z_s, R * z_s], axis=-1), 'st', filters,
                     dropout)  # [bs, c_len, dim]
    s_logits = exp_mask(tf.squeeze(conv1d(s, 1, 1, name='Ws'), axis=-1),
                        c_mask)  # [bs, c_len]
    s = tf.expand_dims(tf.nn.softmax(s_logits),
                       axis=-1)  # [bs, c_len]->[bs, c_len, 1]

    # get z2
    u = tf.squeeze(tf.matmul(R, s, transpose_a=True),
                   axis=-1)  # [bs, dim, 1]->[bs, dim]
    zu = tf.concat([z1, u, z1 * u, z1 - u], axis=-1)
    z_s_ = tf.nn.tanh(dense(zu, filters, name='Wru'))
    g = tf.nn.sigmoid(dense(zu, filters, name='Wgu'))
    z2 = g * z_s_ + (1 - g) * z1  # [bs, dim]

    # end
    z_e = tf.tile(tf.expand_dims(z2, axis=1),
                  [1, tf.shape(R)[1], 1])  # [bs, 1*c_len, dim]
    e = feed_forward(tf.concat([R, z_e, R * z_e], axis=-1), 'ed', filters,
                     dropout)
    e_logits = exp_mask(tf.squeeze(conv1d(e, 1, 1, name='We'), axis=-1),
                        c_mask)  # [bs, c_len]
    e = tf.expand_dims(tf.nn.softmax(e_logits), axis=-1)

    # get z3
    v = tf.squeeze(tf.matmul(R, e, transpose_a=True), axis=-1)
    zv = tf.concat([z2, v, z2 * v, z2 - v], axis=-1)
    z_e_ = tf.nn.tanh(dense(zv, filters, name='Wrv'))
    g = tf.nn.sigmoid(dense(zv, filters, name='Wgv'))
    z3 = g * z_e_ + (1 - g) * z2  # [bs, dim]

    if return_logits:
        return s_logits, e_logits
    else:
        return z3
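A sketch of how the block might be driven, assuming z1 is an initial answer state such as a question summary vector, and that the feed_forward, dense, and conv1d helpers are in scope:

# Hypothetical driver: R is [bs, c_len, dim] evidence, z1 is [bs, dim].
s_logits, e_logits = answer_block(R, z1, filters=128, c_mask=c_mask,
                                  dropout=0.1, return_logits=True)
# With return_logits=False the block instead returns the refined state z3,
# which can seed another answer hop.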
Example #19
def preprocess_inputs(inputs, hidden_size):
    """Transform input size and add positional encodings."""
    if inputs.shape.as_list()[-1] != hidden_size:
        # Project to proper size
        inputs = common_layers.conv1d(inputs=inputs,
                                      filters=hidden_size,
                                      kernel_size=1,
                                      activation=None,
                                      padding='SAME')
    net = inputs
    net = common_attention.add_timing_signal_nd(net)

    return net
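A small sketch, with assumed shapes, of projecting mismatched inputs before adding the timing signal (common_layers and common_attention imported as in tensor2tensor):

import tensorflow as tf

inputs = tf.random_normal([4, 50, 37])            # depth 37 != hidden_size
net = preprocess_inputs(inputs, hidden_size=128)  # projected to depth 128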
def DotProductProject(x1, x2, filters, dropout):
    x1 = tf.nn.dropout(x1, 1 - dropout)  # [bs, c_len, dim]
    x2 = tf.nn.dropout(x2, 1 - dropout)  # [bs, q_len, dim]
    x1 = conv1d(x1,
                filters,
                kernel_size=1,
                padding='same',
                name='conv',
                reuse=tf.AUTO_REUSE,
                kernel_initializer=initializer_relu(),
                kernel_regularizer=regularizer)  # [bs, c_len, filters]
    x1 = tf.nn.relu(layer_norm(x1))
    x2 = conv1d(x2,
                filters,
                kernel_size=1,
                padding='same',
                name='conv',
                reuse=tf.AUTO_REUSE,
                kernel_initializer=initializer_relu(),
                kernel_regularizer=regularizer)  # [bs, q_len, filters]
    x2 = tf.nn.relu(layer_norm(x2))
    S = tf.matmul(x1, x2, transpose_b=True)  # [bs, c_len, q_len]
    return S
def SumAttention(x, mask, dropout):
    x = tf.nn.dropout(x, 1 - dropout)
    alpha = tf.squeeze(conv1d(x,
                              1,
                              1,
                              kernel_initializer=initializer(),
                              name='sum_conv',
                              kernel_regularizer=regularizer),
                       axis=-1)  # [bs, c_len]
    alpha = exp_mask(alpha, mask)  # [bs, c_len]
    alpha = tf.expand_dims(tf.nn.softmax(alpha), axis=1)  # [bs, 1, c_len]
    x = tf.squeeze(tf.matmul(alpha, x),
                   axis=1)  # x:[bs, c_len, dim] -> [bs, dim]
    return x
def align_block(u, v, c_mask, q_mask, Lambda, filters=128, E_0=None, B_0=None, Z_0=None, dropout=0.0):
    with tf.variable_scope("Interactive_Alignment"):
        # attention
        u_ = tf.nn.relu(conv1d(u, filters, 1, name="Wu"))  # [bs, len_c, dim]
        v_ = tf.nn.relu(conv1d(v, filters, 1, name="Wv"))  # [bs, len_q, dim]
        E = tf.matmul(v_, u_, transpose_b=True)  # [bs, len_q, len_c] ## relu(WuU)*relu(WvV)
        if E_0 is not None: ## block 2, 3
            E += (Lambda * E_0)
        E_ = tf.nn.softmax(exp_mask(E, tf.expand_dims(q_mask, axis=-1)), axis=1)  # [bs, len_q, len_c]
        v_E = tf.matmul(E_, v, transpose_a=True)  # [bs, len_c, dim] ## v~ = V * softmax(E)

        # fusion
        uv = tf.concat([u, v_E, u * v_E, u - v_E], axis=-1) ## x; y; x * y; x - y
        x = tf.nn.relu(conv1d(uv, filters, 1, name='Wr')) ## x~ = relu(Wr[x; y; x * y; x - y])
        g = tf.nn.sigmoid(conv1d(uv, filters, 1, name='Wg')) ## g = sigmoid(Wg[x; y; x * y; x - y])
        h = g * x + (1 - g) * u  # [bs, len_c, dim] ## o = g * x~ + (1 - g) * x

    with tf.variable_scope("Self_Alignment"):
        # attention
        h_1 = tf.nn.relu(conv1d(h, filters, 1, name='Wh1')) ## h1 = relu(Wh1 * h)
        h_2 = tf.nn.relu(conv1d(h, filters, 1, name='Wh2')) ## h2 = relu(Wh2 * h)
        B = tf.matmul(h_2, h_1, transpose_b=True)  # [bs, len_c, len_c] ## self-attention scores
        if B_0 is not None: ## block 2, 3
            B += (Lambda * B_0)
        B_ = tf.nn.softmax(exp_mask(B, tf.expand_dims(c_mask, axis=-1)), axis=1)  # [bs, len_c, len_c]
        h_B = tf.matmul(B_, h, transpose_a=True) ## H~ = H * softmax(B)

        # fusion
        hh = tf.concat([h, h_B, h * h_B, h - h_B], axis=-1)
        x = tf.nn.relu(conv1d(hh, filters, 1, name='Wr'))
        g = tf.nn.sigmoid(conv1d(hh, filters, 1, name='Wg'))
        Z = g * x + (1 - g) * h  # [bs, len_c, dim]

    with tf.variable_scope("Evidence_Collection"):
        if Z_0 is not None: ## block 3
            Z = tf.concat([Z, Z_0[0], Z_0[1]], axis=-1)
        R = BiLSTM(Z, filters // 2, name='bilstm', dropout=dropout)  # [bs, len_c, dim] ## R = BiLSTM(Z)

    # return E_t, B_t: the re-normalized attention maps composed with the
    # current self-alignment B_, i.e. the attention history for the next block
    E_t = tf.nn.softmax(exp_mask(E, tf.expand_dims(c_mask, axis=1)), axis=-1)  # [bs, len_q, len_c]
    E_t = tf.matmul(E_t, B_)  ## carry E through the self-alignment
    B_t = tf.nn.softmax(exp_mask(B, tf.expand_dims(c_mask, axis=1)), axis=-1)  # [bs, len_c, len_c]
    B_t = tf.matmul(B_t, B_)  ## carry B through the self-alignment

    return R, Z, E_t, B_t
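A hypothetical three-block stack showing how E_t/B_t (and the last two Z's) feed the next block as E_0/B_0/Z_0; the outer variable scopes are an assumption to keep each block's weights separate:

lam = 3.0  # assumed scalar for the attention-history weight Lambda
with tf.variable_scope("align_1"):
    R1, Z1, E1, B1 = align_block(u, v, c_mask, q_mask, lam)
with tf.variable_scope("align_2"):
    R2, Z2, E2, B2 = align_block(R1, v, c_mask, q_mask, lam,
                                 E_0=E1, B_0=B1)
with tf.variable_scope("align_3"):
    R3, Z3, E3, B3 = align_block(R2, v, c_mask, q_mask, lam,
                                 E_0=E2, B_0=B2, Z_0=[Z1, Z2])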
def compute_q(query_antecedent,
              total_key_depth,
              q_filter_width=1,
              q_padding="VALID"):
    """Computes query.
  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    total_key_depth: an integer
    q_filter_width: An integer specifying how wide you want the query to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
  Returns:
    q: [batch, length, depth] tensors
  """
    q = common_layers.conv1d(query_antecedent,
                             total_key_depth,
                             q_filter_width,
                             padding=q_padding,
                             name="q_transform")
    return q
def CharCNN(x,
            char_limit,
            char_dim,
            filters,
            maxlen,
            kernel_size=5,
            name='char_conv'):
    x = tf.reshape(x, [-1, char_limit, char_dim])
    x = tf.nn.relu(
        conv1d(x,
               filters,
               kernel_size=kernel_size,
               name=name,
               padding='same',
               kernel_initializer=initializer_relu(),
               kernel_regularizer=regularizer))
    x = tf.reduce_max(x, axis=1)
    x = tf.reshape(x, [-1, maxlen, filters])
    return x
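A usage sketch under assumed shapes: char ids are embedded, then CharCNN max-pools a per-word representation.

chars = tf.zeros([32, 400, 16], dtype=tf.int32)     # hypothetical char ids
char_mat = tf.get_variable("char_mat", [1000, 64])  # hypothetical char embeddings
ch_emb = tf.nn.embedding_lookup(char_mat, chars)    # [32, 400, 16, 64]
ch_repr = CharCNN(ch_emb, char_limit=16, char_dim=64, filters=96, maxlen=400)
# ch_repr: [32, 400, 96] -- one max-pooled char feature vector per token.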
Example #26
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    # Run the basic autoencoder part first.
    basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
    if hparams.autoregressive_mode == "none":
      assert not hparams.autoregressive_forget_base
      return basic_result, losses
    shape = common_layers.shape_list(basic_result)
    basic1d = tf.reshape(basic_result, [shape[0], -1, shape[3]])
    # During autoregressive inference, don't resample.
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hasattr(hparams, "sampled_basic1d_tensor"):
        basic1d = hparams.sampled_basic1d_tensor
      else:
        hparams.sampled_basic1d_tensor = basic1d
    # Prepare inputs for autoregressive modes.
    if common_layers.shape_list(features["targets"])[1] == 1:
      # This happens on the first step of predictions.
      assert hparams.mode == tf.estimator.ModeKeys.PREDICT
      features["targets"] = tf.zeros_like(basic_result)
    targets_dropout = common_layers.mix(
        features["targets"],
        tf.zeros_like(basic_result),
        hparams.bottleneck_warmup_steps,
        is_training,
        max_prob=1.0 - hparams.autoregressive_dropout,
        broadcast_last=True)
    # Sometimes it's useful to look at non-autoregressive evals.
    if (hparams.mode == tf.estimator.ModeKeys.EVAL and
        hparams.autoregressive_eval_pure_autoencoder):
      targets_dropout = tf.zeros_like(basic_result)
    # Now combine the basic reconstruction with shifted targets.
    targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[3]])
    targets_shifted = common_layers.shift_right_3d(targets1d)
    concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
    # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
    if hparams.autoregressive_forget_base:
      concat1d = tf.reshape(features["targets"], [shape[0], -1, shape[3]])
      concat1d = common_layers.shift_right_3d(concat1d)
    # The autoregressive part depends on the mode.
    if hparams.autoregressive_mode == "conv3":
      res = common_layers.conv1d(
          concat1d,
          shape[3],
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv3")
      return tf.reshape(res, shape), losses
    if hparams.autoregressive_mode == "conv5":
      res = common_layers.conv1d(
          concat1d,
          shape[3],
          5,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv5")
      return tf.reshape(res, shape), losses
    if hparams.autoregressive_mode == "sru":
      res = common_layers.conv1d(
          concat1d,
          shape[3],
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_sru_conv3")
      res = common_layers.sru(res)
      return tf.reshape(res, shape), losses

    raise ValueError(
        "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)
  def testConv1d(self):
    x = np.random.rand(5, 7, 11)
    y = common_layers.conv1d(tf.constant(x, dtype=tf.float32), 13, 1)
    self.evaluate(tf.global_variables_initializer())
    res = self.evaluate(y)
    self.assertEqual(res.shape, (5, 7, 13))
Example #28
  def body(self, features):
    hparams = self.hparams
    # Run the basic autoencoder part first.
    basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
    if hparams.autoregressive_mode == "none":
      assert not hparams.autoregressive_forget_base
      return basic_result, losses
    if "training" in losses:
      plain_training_loss = losses.pop("training")
      losses["plain"] = plain_training_loss
    res_shape = common_layers.shape_list(basic_result)
    vocab_size = self._problem_hparams.modality["targets"].top_dimensionality
    targets = tf.one_hot(features["targets_raw"], vocab_size)
    # Prepare inputs for autoregressive modes.
    if common_layers.shape_list(features["targets"])[1] == 1:
      # This happens on the first step of predictions.
      assert hparams.mode == tf.estimator.ModeKeys.PREDICT
      targets = tf.zeros_like(basic_result)
    targets = self.embed(targets)
    if hparams.autoregressive_gumbel_sample:
      basic_hot = self.gumbel_sample(basic_result)
    else:
      basic_hot = basic_result
    basic_result = self.embed(basic_hot)
    shape = common_layers.shape_list(basic_result)
    basic1d = tf.reshape(basic_result, [shape[0], -1, shape[-1]])
    targets = tf.reshape(targets, common_layers.shape_list(basic_result))
    # During autoregressive inference, don't resample.
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hasattr(hparams, "sampled_basic1d_tensor"):
        basic1d = hparams.sampled_basic1d_tensor
      else:
        hparams.sampled_basic1d_tensor = basic1d
    # Sometimes it's useful to look at non-autoregressive evals.
    targets_dropout = targets
    if (hparams.mode == tf.estimator.ModeKeys.EVAL and
        hparams.autoregressive_eval_pure_autoencoder):
      targets_dropout = tf.zeros_like(basic_result)
    # Now combine the basic reconstruction with shifted targets.
    targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[-1]])
    targets_shifted = common_layers.shift_right_3d(targets1d)
    concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
    # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
    if hparams.autoregressive_forget_base:
      concat1d = tf.reshape(targets, [shape[0], -1, shape[-1]])
      concat1d = common_layers.shift_right_3d(concat1d)
    # The autoregressive part depends on the mode.
    if hparams.autoregressive_mode == "conv3":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv3")
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses
    if hparams.autoregressive_mode == "conv5":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          5,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv5")
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses
    if hparams.autoregressive_mode == "sru":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_sru_conv3")
      res = common_layers.sru(res)
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses

    raise ValueError(
        "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)
Example #29
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        image_shapes=None,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
                        name=None):
    """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels]
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    attention_type: a string, either "dot_product" or "local_mask_right" or
                    "local_unmasked"
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    name: an optional string

  Returns:
    A Tensor.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))

    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):
        if memory_antecedent is None:
            # self attention
            combined = common_layers.conv1d(query_antecedent,
                                            total_key_depth * 2 +
                                            total_value_depth,
                                            1,
                                            name="qkv_transform")
            q, k, v = tf.split(
                combined,
                [total_key_depth, total_key_depth, total_value_depth],
                axis=2)
        else:
            q = common_layers.conv1d(query_antecedent,
                                     total_key_depth,
                                     1,
                                     name="q_transform")
            combined = common_layers.conv1d(memory_antecedent,
                                            total_key_depth +
                                            total_value_depth,
                                            1,
                                            name="kv_transform")
            k, v = tf.split(combined, [total_key_depth, total_value_depth],
                            axis=2)
        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)
        key_depth_per_head = total_key_depth // num_heads
        q *= key_depth_per_head**-0.5
        if attention_type == "dot_product":
            x = dot_product_attention(q, k, v, bias, dropout_rate,
                                      image_shapes)
        elif attention_type == "local_mask_right":
            x = masked_local_attention_1d(q, k, v, block_length=block_length)
        else:
            assert attention_type == "local_unmasked"
            x = unmasked_local_attention_1d(q,
                                            k,
                                            v,
                                            block_length=block_length,
                                            filter_width=block_width)
        x = combine_heads(x)
        x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
        return x
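For self-attention, memory_antecedent is None and the bias can be omitted; shapes below are assumptions, and split_heads/combine_heads/dot_product_attention are expected in scope as in the snippet:

x = tf.random_normal([2, 10, 64])
y = multihead_attention(x, None, bias=None,
                        total_key_depth=64, total_value_depth=64,
                        output_depth=64, num_heads=8, dropout_rate=0.0)
# y: [2, 10, 64]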
Example #30
  def body(self, features):
    hparams = self.hparams
    # Run the basic autoencoder part first.
    basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
    if hparams.autoregressive_mode == "none":
      assert not hparams.autoregressive_forget_base
      return basic_result, losses
    if "training" in losses:
      plain_training_loss = losses.pop("training")
      losses["plain"] = plain_training_loss
    res_shape = common_layers.shape_list(basic_result)
    vocab_size = self._problem_hparams.vocab_size["targets"]
    if hasattr(self._hparams, "vocab_divisor"):
      vocab_size += (-vocab_size) % self._hparams.vocab_divisor
    targets = tf.one_hot(features["targets_raw"], vocab_size)
    # Prepare inputs for autoregressive modes.
    if common_layers.shape_list(features["targets"])[1] == 1:
      # This happens on the first step of predictions.
      assert hparams.mode == tf.estimator.ModeKeys.PREDICT
      targets = tf.zeros_like(basic_result)
    targets = self.embed(targets)
    if hparams.autoregressive_gumbel_sample:
      basic_hot = self.gumbel_sample(basic_result)
    else:
      basic_hot = basic_result
    basic_result = self.embed(basic_hot)
    shape = common_layers.shape_list(basic_result)
    basic1d = tf.reshape(basic_result, [shape[0], -1, shape[-1]])
    targets = tf.reshape(targets, common_layers.shape_list(basic_result))
    # During autoregressive inference, don't resample.
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hasattr(hparams, "sampled_basic1d_tensor"):
        basic1d = hparams.sampled_basic1d_tensor
      else:
        hparams.sampled_basic1d_tensor = basic1d
    # Sometimes it's useful to look at non-autoregressive evals.
    targets_dropout = targets
    if (hparams.mode == tf.estimator.ModeKeys.EVAL and
        hparams.autoregressive_eval_pure_autoencoder):
      targets_dropout = tf.zeros_like(basic_result)
    # Now combine the basic reconstruction with shifted targets.
    targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[-1]])
    targets_shifted = common_layers.shift_right_3d(targets1d)
    concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
    # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
    if hparams.autoregressive_forget_base:
      concat1d = tf.reshape(targets, [shape[0], -1, shape[-1]])
      concat1d = common_layers.shift_right_3d(concat1d)
    # The autoregressive part depends on the mode.
    if hparams.autoregressive_mode == "conv3":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv3")
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses
    if hparams.autoregressive_mode == "conv5":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          5,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv5")
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses
    if hparams.autoregressive_mode == "sru":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_sru_conv3")
      res = common_layers.sru(res)
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses

    raise ValueError(
        "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)
Example #31
  def testConv1d(self):
    x = np.random.rand(5, 7, 11)
    y = common_layers.conv1d(tf.constant(x, dtype=tf.float32), 13, 1)
    self.evaluate(tf.global_variables_initializer())
    res = self.evaluate(y)
    self.assertEqual(res.shape, (5, 7, 13))
Example #32
def feed_forward(x, name, filters=128, dropout=0.0):
    x = tf.nn.relu(conv1d(x, filters, 1, name=name + '_FF1'))
    x = conv1d(x, filters, 1, name=name + '_FF2')
    x = tf.nn.dropout(x, 1 - dropout)
    return x
Example #33
def parameter_attention(x,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        memory_rows,
                        num_heads,
                        dropout_rate,
                        name=None):
    """Attention over parameters.

  We use the same multi-headed attention as in the other layers, but the memory
  keys and values are model parameters.  There is no linear transformation
  on the keys or values.

  We are also a bit more careful about memory usage, since the number of
  memory positions may be very large.

  Args:
    x: a Tensor with shape [batch, length_q, channels]
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    memory_rows: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    name: an optional string

  Returns:
    A Tensor.
  """
    with tf.variable_scope(name,
                           default_name="parameter_attention",
                           values=[x]):
        head_size_k = total_key_depth // num_heads
        head_size_v = total_value_depth // num_heads
        var_shape_k = [num_heads, memory_rows, head_size_k]
        var_shape_v = [num_heads, memory_rows, head_size_v]
        k = tf.get_variable("k",
                            var_shape_k,
                            initializer=tf.random_normal_initializer(
                                0, output_depth**-0.5)) * (num_heads**0.5)
        v = tf.get_variable("v",
                            var_shape_v,
                            initializer=tf.random_normal_initializer(
                                0, output_depth**-0.5)) * (output_depth**0.5)
        batch_size = tf.shape(x)[0]
        length = tf.shape(x)[1]
        q = common_layers.conv1d(x, total_key_depth, 1, name="q_transform")
        if dropout_rate:
            # This is a cheaper form of attention dropout where we use the
            # same dropout decisions across batch elements and query positions,
            # but different decisions across heads and memory positions.
            v = tf.nn.dropout(v,
                              1.0 - dropout_rate,
                              noise_shape=[num_heads, memory_rows, 1])
        # query is [batch, length, hidden_size]
        # reshape and transpose it to [heads, batch * length, head_size]
        q = tf.reshape(q, [batch_size, length, num_heads, head_size_k])
        q = tf.transpose(q, [2, 0, 1, 3])
        q = tf.reshape(q, [num_heads, batch_size * length, head_size_k])
        weights = tf.matmul(q, k, transpose_b=True)
        weights = tf.nn.softmax(weights)
        y = tf.matmul(weights, v)
        y = tf.reshape(y, [num_heads, batch_size, length, head_size_v])
        y = tf.transpose(y, [1, 2, 0, 3])
        y = tf.reshape(y, [batch_size, length, total_value_depth])
        y.set_shape([None, None, total_value_depth])
        y = common_layers.conv1d(y, output_depth, 1, name="output_transform")

        return y
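A call sketch with assumed sizes; memory_rows fixes how many learned key/value slots the layer attends over:

x = tf.random_normal([2, 10, 64])
y = parameter_attention(x, total_key_depth=64, total_value_depth=64,
                        output_depth=64, memory_rows=128, num_heads=4,
                        dropout_rate=0.1)  # [2, 10, 64]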