Example #1
import numpy as np
import tensorflow as tf  # TF 1.x API (tf.logging, tf.Session)


# Block attention: each query_size block of the input attends to the key_size
# window around it. Sizes are (width, height) pairs. `pad`, `unroll`, and
# `reroll` are repo helpers not shown on this page (sketches follow below).
def attention(query, key, value, query_size=(4, 4), key_size=(8, 8)):
    _, height, width, depth = query.shape.as_list()
    _, _, _, value_depth = value.shape.as_list()
    assert width % query_size[0] == 0
    assert height % query_size[1] == 0

    padding_kernel_size = ((key_size[0] - query_size[0]) * 2,
                           (key_size[1] - query_size[1]) * 2)

    unrolled_query = unroll(query, kernel_size=query_size, strides=query_size)
    unrolled_query = tf.reshape(unrolled_query,
                                (-1, query_size[0] * query_size[1], depth))

    unrolled_key = unroll(pad(key, kernel_size=padding_kernel_size),
                          kernel_size=key_size,
                          strides=query_size)
    unrolled_key = tf.reshape(unrolled_key,
                              (-1, key_size[0] * key_size[1], depth))

    unrolled_value = unroll(pad(value, kernel_size=padding_kernel_size),
                            kernel_size=key_size,
                            strides=query_size)
    unrolled_value = tf.reshape(unrolled_value,
                                (-1, key_size[0] * key_size[1], value_depth))

    tf.logging.debug('attention tensor query: %s', unrolled_query.get_shape())
    tf.logging.debug('attention tensor key:   %s', unrolled_key.get_shape())
    tf.logging.debug('attention tensor value: %s', unrolled_value.get_shape())

    distribution = tf.matmul(unrolled_query,
                             tf.transpose(unrolled_key, perm=[0, 2, 1]))
    distribution = tf.nn.softmax(distribution / tf.sqrt(float(depth)), axis=-1)
    tf.logging.debug('attention tensor distribution: %s',
                     distribution.get_shape())

    response = tf.matmul(distribution, unrolled_value)
    tf.logging.debug('attention tensor response: %s', response.get_shape())

    # Fold the per-block responses back into a (batch, height, width,
    # value_depth) map; kernel size and strides both equal query_size,
    # mirroring the query-side unroll above.
    response = reroll(response, width, height, value_depth, query_size,
                      query_size)
    tf.logging.debug('attention tensor reshaped response: %s',
                     response.get_shape())

    distribution = tf.reshape(distribution, [
        -1, height // query_size[1], width // query_size[0],
        query_size[0] * query_size[1], key_size[0] * key_size[1]
    ])
    tf.logging.debug('attention tensor reshaped distribution: %s',
                     distribution.get_shape())

    return distribution, response
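
`reroll` is not shown on this page. For the non-overlapping layout used here (kernel size and strides both equal to query_size), a sketch that inverts the query-side unroll could look like the following; the argument order mirrors the call above and is an assumption, not the repository's definition:

def reroll(tensor, width, height, depth, kernel_size, strides):
    # Rows arrive ordered row-major over (batch, block_row, block_col), each
    # holding one flattened kernel_size block; assumes strides == kernel_size.
    blocks_h = height // strides[1]
    blocks_w = width // strides[0]
    tensor = tf.reshape(tensor, (-1, blocks_h, blocks_w,
                                 kernel_size[1], kernel_size[0], depth))
    # Interleave block rows with intra-block rows, then merge back to (H, W).
    tensor = tf.transpose(tensor, perm=[0, 1, 3, 2, 4, 5])
    return tf.reshape(tensor, (-1, height, width, depth))
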
Example #2

def test_unrolled_index():
    tensor = tf.constant(np.arange(4 * 32 * 32 * 1).reshape(4, 32, 32, 1))
    unrolled = unroll(pad(tensor, kernel_size=(3, 3)), kernel_size=(3, 3))
    with tf.Session() as session:
        unrolled = session.run(unrolled)

    assert (unrolled[0, 32 * 0 + 0, :, 0].flatten()
            == np.array([0, 0, 0, 0, 0, 1, 0, 32, 33])).all()
    assert (unrolled[0, 32 * 1 + 1, :, 0].flatten()
            == np.array([0, 1, 2, 32, 33, 34, 64, 65, 66])).all()
    assert (unrolled[0, 32 * 1 + 2, :, 0].flatten()
            == np.array([1, 2, 3, 33, 34, 35, 65, 66, 67])).all()
    assert (unrolled[0, 32 * 2 + 1, :, 0].flatten()
            == np.array([32, 33, 34, 64, 65, 66, 96, 97, 98])).all()
    assert (unrolled[0, 32 * 31 + 31, :, 0].flatten()
            == np.array([990, 991, 0, 1022, 1023, 0, 0, 0, 0])).all()

    # The second image's patches hold the same pattern shifted by one image's
    # worth of elements (32 * 32 * 1).
    assert (unrolled[1, 32 * 1 + 1, :, 0].flatten()
            == np.array([0, 1, 2, 32, 33, 34, 64, 65, 66]) + 32 * 32 * 1).all()
    assert (unrolled[1, 32 * 1 + 2, :, 0].flatten()
            == np.array([1, 2, 3, 33, 34, 35, 65, 66, 67]) + 32 * 32 * 1).all()
    assert (unrolled[1, 32 * 2 + 1, :, 0].flatten()
            == np.array([32, 33, 34, 64, 65, 66, 96, 97, 98]) + 32 * 32 * 1).all()
Example #3

def test_unrolled_shape():
    tensor = tf.zeros((4, 32, 32, 3))

    assert unroll(pad(tensor, kernel_size=(3, 3)),
                  kernel_size=(3, 3)).shape.as_list() == [4, 32 * 32, 3 * 3, 3]
    assert unroll(pad(tensor, kernel_size=(5, 5)),
                  kernel_size=(5, 5)).shape.as_list() == [4, 32 * 32, 5 * 5, 3]
    assert unroll(pad(tensor, kernel_size=(3, 3)),
                  kernel_size=(3, 3),
                  strides=(2, 2)).shape.as_list() == [
                      4, (32 // 2) * (32 // 2), 3 * 3, 3
                  ]

    tensor = tf.zeros((4, 16, 32, 3))

    assert unroll(pad(tensor, kernel_size=(3, 3)),
                  kernel_size=(3, 3)).shape.as_list() == [4, 16 * 32, 3 * 3, 3]
    assert unroll(pad(tensor, kernel_size=(5, 5)),
                  kernel_size=(5, 5)).shape.as_list() == [4, 16 * 32, 5 * 5, 3]
    assert unroll(pad(tensor, kernel_size=(3, 3)),
                  kernel_size=(3, 3),
                  strides=(2, 2)).shape.as_list() == [
                      4, (16 // 2) * (32 // 2), 3 * 3, 3
                  ]
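
`pad` and `unroll` themselves are not part of this page. A minimal sketch consistent with the shape and index tests here (and with test_pad_shape below), assuming kernel_size and strides are (width, height) pairs and that pad zero-pads (k - 1) // 2 pixels on each side; tf.extract_image_patches (TF 1.x) does the actual window extraction:

def pad(tensor, kernel_size):
    # Zero-pad height and width by (k - 1) // 2 on each side; kernel_size is
    # (width, height), so (3, 5) takes (4, 32, 32, 3) to (4, 36, 34, 3).
    pad_w = (kernel_size[0] - 1) // 2
    pad_h = (kernel_size[1] - 1) // 2
    return tf.pad(tensor, [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]])


def unroll(tensor, kernel_size, strides=(1, 1)):
    # Extract every kernel-sized window and flatten the spatial grid, giving
    # (batch, positions, kernel_area, depth) with patches ordered row-major,
    # as the index tests above expect.
    _, _, _, depth = tensor.shape.as_list()
    patches = tf.extract_image_patches(
        tensor,
        ksizes=[1, kernel_size[1], kernel_size[0], 1],
        strides=[1, strides[1], strides[0], 1],
        rates=[1, 1, 1, 1],
        padding='VALID')
    _, out_height, out_width, _ = patches.shape.as_list()
    return tf.reshape(patches, (-1, out_height * out_width,
                                kernel_size[0] * kernel_size[1], depth))
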
Example #4
# Per-position attention: every output position attends over the kernel_size
# neighbourhood around it (the query is unrolled with a 1x1 kernel).
def attention(query, key, value, kernel_size=(5, 5), strides=(1, 1)):
    _, height, width, depth = query.shape.as_list()
    _, _, _, value_depth = value.shape.as_list()

    unrolled_query = tf.reshape(
        unroll(query, kernel_size=(1, 1), strides=strides), (-1, 1, depth))
    unrolled_key = tf.reshape(
        unroll(pad(key, kernel_size=kernel_size),
               kernel_size=kernel_size,
               strides=strides), (-1, kernel_size[0] * kernel_size[1], depth))
    unrolled_value = tf.reshape(
        unroll(pad(value, kernel_size=kernel_size),
               kernel_size=kernel_size,
               strides=strides),
        (-1, kernel_size[0] * kernel_size[1], value_depth))
    tf.logging.debug('attention tensor query: %s', unrolled_query.get_shape())
    tf.logging.debug('attention tensor key:   %s', unrolled_key.get_shape())
    tf.logging.debug('attention tensor value: %s', unrolled_value.get_shape())

    distribution = tf.matmul(unrolled_query,
                             tf.transpose(unrolled_key, perm=[0, 2, 1]))
    distribution = tf.nn.softmax(distribution / tf.sqrt(float(depth)), axis=-1)
    tf.logging.debug('attention tensor distribution: %s',
                     distribution.get_shape())

    response = tf.matmul(distribution, unrolled_value)
    tf.logging.debug('attention tensor response: %s', response.get_shape())

    # Reshape: one (1, kh, kw) attention map per output position.
    distribution = tf.reshape(distribution,
                              (-1, height // strides[1], width // strides[0],
                               1, kernel_size[1], kernel_size[0]))
    response = tf.reshape(
        response, (-1, height // strides[1], width // strides[0], value_depth))
    tf.logging.debug('attention tensor reshaped response: %s',
                     response.get_shape())
    return distribution, response
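
A smoke test for this variant (shapes hypothetical; value may carry a different depth than query and key):

query = tf.zeros((2, 32, 32, 64))
key = tf.zeros((2, 32, 32, 64))
value = tf.zeros((2, 32, 32, 128))
distribution, response = attention(query, key, value)
# Each position attends over its own 5x5 neighbourhood:
# distribution: (2, 32, 32, 1, 5, 5), response: (2, 32, 32, 128)
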
Example #5

def test_unrolled_index_with_strides():
    tensor = tf.constant(np.arange(4 * 32 * 32 * 1).reshape(4, 32, 32, 1))

    unrolled = unroll(pad(tensor, kernel_size=(3, 3)),
                      kernel_size=(3, 3),
                      strides=(2, 2))
    with tf.Session() as session:
        unrolled = session.run(unrolled)
    assert (unrolled[0, 16 * 0 + 0, :, 0].flatten()
            == np.array([0, 0, 0, 0, 0, 1, 0, 32, 33])).all()
    assert (unrolled[0, 16 * 0 + 1, :, 0].flatten()
            == np.array([0, 0, 0, 1, 2, 3, 33, 34, 35])).all()
    assert (unrolled[0, 16 * 1 + 0, :, 0].flatten()
            == np.array([0, 32, 33, 0, 64, 65, 0, 96, 97])).all()
Example #6

def test_pad_shape():
    tensor = tf.zeros((4, 32, 32, 3))

    assert pad(tensor, kernel_size=(3, 3)).shape.as_list() == [4, 34, 34, 3]
    assert pad(tensor, kernel_size=(5, 5)).shape.as_list() == [4, 36, 36, 3]
    assert pad(tensor, kernel_size=(3, 5)).shape.as_list() == [4, 36, 34, 3]
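
test_pad_shape only uses a square input, so the (width, height) ordering of kernel_size is easy to misread; a non-square check, consistent with the assertions above and the pad sketch earlier:

tensor = tf.zeros((4, 16, 32, 3))  # (batch, height, width, channels)
assert pad(tensor, kernel_size=(3, 5)).shape.as_list() == [4, 20, 34, 3]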