import tensorflow as tf  # TF1.x-style code: uses tf.variable_scope / tf.get_variable


def padding_attention_mask(from_tensor, to_mask, *, dtype):
    """Create a 4D attention mask from a 2D padding mask.

    Args:
      from_tensor: Tensor of shape [batch_size, heads, from_seq_length, to_seq_length].
      to_mask: int32 Tensor of shape [batch_size, to_seq_length].

    Returns:
      float Tensor of shape [batch_size, 1, from_seq_length, to_seq_length].
    """
    from_shape = shape_list(from_tensor)
    batch_size = from_shape[0]
    from_seq_length = from_shape[2]

    to_shape = shape_list(to_mask)
    to_seq_length = to_shape[1]

    to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
                      dtype=dtype)

    # We don't assume that `from_tensor` is a mask (although it could be). We
    # don't actually care if we attend *from* padding tokens (only *to* padding
    # tokens), so we create a tensor of all ones.
    #
    # `broadcast_ones` = [batch_size, from_seq_length, 1]
    broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1],
                             dtype=dtype)

    # Broadcast [batch, from_seq, 1] against [batch, 1, to_seq] to get
    # [batch_size, from_seq_length, to_seq_length].
    mask = broadcast_ones * to_mask
    # Insert a broadcastable heads axis: [batch_size, 1, from_seq_length, to_seq_length].
    mask = tf.expand_dims(mask, axis=1)
    return mask
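

# All of the snippets on this page rely on a `shape_list` helper that is not
# shown here. A common implementation (this one follows GPT-2's model.py)
# returns static dimensions where known and falls back to dynamic ones:
def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
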
def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    """Kernel-size-1 convolution: a dense projection of the last axis from nx to nf."""
    with tf.variable_scope(scope):
        *start, nx = shape_list(x)
        w = tf.get_variable(
            'w', [1, nx, nf],
            initializer=tf.random_normal_initializer(stddev=w_init_stdev))
        b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0))
        # Flatten the leading dims, apply the [nx, nf] matmul, then restore them.
        c = tf.reshape(
            tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b,
            start + [nf])
        return c
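

# A minimal usage sketch (the layer width and scope name are illustrative, not
# from the original): GPT-2 uses conv1d as a fused query/key/value projection.
def _conv1d_example():
    hidden = tf.placeholder(tf.float32, [None, None, 768])  # [batch, seq, n_state]
    qkv = conv1d(hidden, 'c_attn', 3 * 768)                 # [batch, seq, 3 * n_state]
    return qkv
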
def left_to_right_attention_mask(from_tensor, to_tensor, *, dtype):
    """1's in the lower triangle, counting from the lower right corner.

    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns - nd), but doesn't
    produce garbage on TPUs.

    Returns:
      float Tensor of shape [1, 1, nd, ns].
    """
    # `==` on Tensors is unreliable (identity in TF1, elementwise in TF2), so
    # require the same tensor object explicitly.
    assert from_tensor is to_tensor, "from_tensor and to_tensor must be the same tensor."
    _, _, nd, ns = shape_list(from_tensor)
    # Query position i may attend to key position j iff i >= j - (ns - nd),
    # i.e. never to a position in its future.
    i = tf.range(nd)[:, None]
    j = tf.range(ns)
    m = i >= j - ns + nd
    m = tf.cast(m, dtype)
    m = tf.reshape(m, [1, 1, nd, ns])
    return m
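

# Hedged usage sketch (names and the -1e10 constant follow GPT-2's
# mask_attn_weights pattern; `logits` and `to_mask` are illustrative inputs):
# the two masks multiply into a single 0/1 mask that is applied to the raw
# attention logits before the softmax, pushing masked positions toward -inf.
def _masked_attention_example(logits, to_mask):
    # logits: [batch, heads, nd, ns]; to_mask: int32 [batch, ns].
    causal = left_to_right_attention_mask(logits, logits, dtype=logits.dtype)
    padding = padding_attention_mask(logits, to_mask, dtype=logits.dtype)
    mask = causal * padding  # 1 where attention is allowed, 0 elsewhere
    logits = logits * mask - 1e10 * (1 - mask)
    return tf.nn.softmax(logits, axis=-1)
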
def merge_states(x):
    """Smash the last two dimensions of x into a single dimension."""
    *start, a, b = shape_list(x)
    return tf.reshape(x, start + [a * b])
def split_states(x, n):
    """Reshape the last dimension of x into [n, x.shape[-1]/n]."""
    *start, m = shape_list(x)
    return tf.reshape(x, start + [n, m // n])
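

# Hedged usage sketch (the transpose pattern matches GPT-2's split_heads /
# merge_heads; `n_head` is an illustrative parameter): split_states and
# merge_states are inverses, used to move between a flat [batch, seq, features]
# layout and a per-head [batch, heads, seq, features_per_head] layout.
def _heads_example(x, n_head=12):
    # [batch, seq, features] -> [batch, n_head, seq, features // n_head]
    heads = tf.transpose(split_states(x, n_head), [0, 2, 1, 3])
    # ... attention over `heads` would happen here ...
    # Back to [batch, seq, features].
    return merge_states(tf.transpose(heads, [0, 2, 1, 3]))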