def attention(query, key, value, query_size=(4, 4), key_size=(8, 8)):
    """Block-local attention: each query patch attends over a larger key/value patch.

    Args:
        query: tensor of shape (batch, height, width, depth).
        key: tensor with the same batch/spatial shape as `query` and depth `depth`.
        value: tensor with the same batch/spatial shape and depth `value_depth`.
        query_size: (width, height) of each query patch; must evenly divide the
            spatial dimensions.
        key_size: (width, height) of the key/value patch each query patch
            attends to.

    Returns:
        A `(distribution, response)` pair: softmax attention weights reshaped to
        (batch, height/qh, width/qw, qw*qh, kw*kh), and the attended values
        re-assembled to the original spatial layout via `reroll`.
    """
    # NOTE(review): a second `attention` defined later in this file shadows this
    # one at import time — confirm which definition callers actually get.
    _, height, width, depth = query.shape.as_list()
    _, _, _, value_depth = value.shape.as_list()
    assert width % query_size[0] == 0
    assert height % query_size[1] == 0
    # NOTE(review): per test_pad_shape, pad() grows each axis by kernel - 1 in
    # total, so this pads 2*(key - query) - 1 pixels rather than the
    # key - query needed for key windows to tile exactly at stride query_size —
    # verify against the definitions of pad()/unroll().
    padding_kernel_size = ((key_size[0] - query_size[0]) * 2,
                           (key_size[1] - query_size[1]) * 2)
    # Flatten each patch to (num_patches, patch_pixels, depth) for batched matmul.
    unrolled_query = unroll(query, kernel_size=query_size, strides=query_size)
    unrolled_query = tf.reshape(unrolled_query,
                                (-1, query_size[0] * query_size[1], depth))
    unrolled_key = unroll(pad(key, kernel_size=padding_kernel_size),
                          kernel_size=key_size, strides=query_size)
    unrolled_key = tf.reshape(unrolled_key,
                              (-1, key_size[0] * key_size[1], depth))
    unrolled_value = unroll(pad(value, kernel_size=padding_kernel_size),
                            kernel_size=key_size, strides=query_size)
    unrolled_value = tf.reshape(unrolled_value,
                                (-1, key_size[0] * key_size[1], value_depth))
    tf.logging.debug('attention tensor query: %s', unrolled_query.get_shape())
    tf.logging.debug('attention tensor key: %s', unrolled_key.get_shape())
    tf.logging.debug('attention tensor value: %s', unrolled_value.get_shape())
    # Scaled dot-product attention over the flattened patches.
    distribution = tf.matmul(unrolled_query,
                             tf.transpose(unrolled_key, perm=[0, 2, 1]))
    distribution = tf.nn.softmax(distribution / tf.sqrt(float(depth)), axis=-1)
    tf.logging.debug('attention tensor distribution: %s',
                     distribution.get_shape())
    response = tf.matmul(distribution, unrolled_value)
    tf.logging.debug('attention tensor response: %s', response.get_shape())
    response = reroll(response, width, height, value_depth, query_size,
                      query_size)
    tf.logging.debug('attention tensor reshaped response: %s',
                     response.get_shape())
    distribution = tf.reshape(distribution, [
        -1, height // query_size[1], width // query_size[0],
        query_size[0] * query_size[1], key_size[0] * key_size[1]
    ])
    # BUG FIX: this debug line previously logged response.get_shape() under the
    # "reshaped distribution" label; log the reshaped distribution itself.
    tf.logging.debug('attention tensor reshaped distribution: %s',
                     distribution.get_shape())
    return distribution, response
def test_unrolled_index():
    """Unrolling a padded 32x32 ramp tensor yields the expected 3x3 neighborhoods."""
    tensor = tf.constant(np.arange(4 * 32 * 32 * 1).reshape(4, 32, 32, 1))
    unrolled = unroll(pad(tensor, kernel_size=(3, 3)), kernel_size=(3, 3))
    with tf.Session() as session:
        unrolled = session.run(unrolled)
    batch_offset = 32 * 32 * 1  # values in batch 1 are shifted by one image
    # (batch, flattened pixel index, expected 3x3 neighborhood); zeros mark padding.
    cases = [
        (0, 32 * 0 + 0, [0, 0, 0, 0, 0, 1, 0, 32, 33]),
        (0, 32 * 1 + 1, [0, 1, 2, 32, 33, 34, 64, 65, 66]),
        (0, 32 * 1 + 2, [1, 2, 3, 33, 34, 35, 65, 66, 67]),
        (0, 32 * 2 + 1, [32, 33, 34, 64, 65, 66, 96, 97, 98]),
        (0, 32 * 31 + 31, [990, 991, 0, 1022, 1023, 0, 0, 0, 0]),
        (1, 32 * 1 + 1,
         np.array([0, 1, 2, 32, 33, 34, 64, 65, 66]) + batch_offset),
        (1, 32 * 1 + 2,
         np.array([1, 2, 3, 33, 34, 35, 65, 66, 67]) + batch_offset),
        (1, 32 * 2 + 1,
         np.array([32, 33, 34, 64, 65, 66, 96, 97, 98]) + batch_offset),
    ]
    for batch, index, neighborhood in cases:
        actual = unrolled[batch, index, :, 0].flatten()
        assert (actual == np.array(neighborhood)).all()
def test_unrolled_shape():
    """unroll(pad(...)) keeps the spatial extent and exposes the window dims."""
    for height, width in ((32, 32), (16, 32)):
        tensor = tf.zeros((4, height, width, 3))
        # Stride 1: one window per pixel, each with kernel*kernel elements.
        for k in (3, 5):
            unrolled = unroll(pad(tensor, kernel_size=(k, k)),
                              kernel_size=(k, k))
            assert unrolled.shape.as_list() == [4, height * width, k * k, 3]
        # Stride 2: one window per 2x2 cell.
        strided = unroll(pad(tensor, kernel_size=(3, 3)),
                         kernel_size=(3, 3), strides=(2, 2))
        assert strided.shape.as_list() == [
            4, (height // 2) * (width // 2), 3 * 3, 3
        ]
def attention(query, key, value, kernel_size=(5, 5), strides=(1, 1)):
    """Sliding-window attention: each query pixel attends over the kernel_size
    neighborhood of key/value around it.

    Args:
        query: tensor of shape (batch, height, width, depth).
        key: tensor with the same batch/spatial shape as `query` and depth `depth`.
        value: tensor with the same batch/spatial shape and depth `value_depth`.
        kernel_size: (width, height) of the key/value window per query pixel.
        strides: (x, y) sampling stride over the query.

    Returns:
        A `(distribution, response)` pair: attention weights reshaped to
        (batch, height/sy, width/sx, 1, kh, kw) and attended values of shape
        (batch, height/sy, width/sx, value_depth).
    """
    # NOTE(review): this redefines `attention` and shadows the earlier
    # definition in this file — confirm that is intentional.
    _, height, width, depth = query.shape.as_list()
    _, _, _, value_depth = value.shape.as_list()
    window = kernel_size[0] * kernel_size[1]
    # Each query pixel is its own 1x1 window; keys/values are padded so every
    # pixel gets a full kernel_size neighborhood.
    flat_query = tf.reshape(
        unroll(query, kernel_size=(1, 1), strides=strides), (-1, 1, depth))
    flat_key = tf.reshape(
        unroll(pad(key, kernel_size=kernel_size),
               kernel_size=kernel_size,
               strides=strides), (-1, window, depth))
    flat_value = tf.reshape(
        unroll(pad(value, kernel_size=kernel_size),
               kernel_size=kernel_size,
               strides=strides), (-1, window, value_depth))
    tf.logging.debug('attention tensor query: %s', flat_query.get_shape())
    tf.logging.debug('attention tensor key: %s', flat_key.get_shape())
    tf.logging.debug('attention tensor value: %s', flat_value.get_shape())
    # Scaled dot-product attention over each neighborhood.
    scores = tf.matmul(flat_query, tf.transpose(flat_key, perm=[0, 2, 1]))
    distribution = tf.nn.softmax(scores / tf.sqrt(float(depth)), axis=-1)
    tf.logging.debug('attention tensor distribution: %s',
                     distribution.get_shape())
    response = tf.matmul(distribution, flat_value)
    tf.logging.debug('attention tensor response: %s', response.get_shape())
    # Restore the spatial layout of both outputs.
    distribution = tf.reshape(
        distribution, (-1, height // strides[1], width // strides[0], 1,
                       kernel_size[1], kernel_size[0]))
    response = tf.reshape(
        response, (-1, height // strides[1], width // strides[0], value_depth))
    tf.logging.debug('attention tensor reshaped response: %s',
                     response.get_shape())
    return distribution, response
def test_unrolled_index_with_strides():
    """With stride 2, unroll samples a 3x3 neighborhood at every other pixel."""
    tensor = tf.constant(np.arange(4 * 32 * 32 * 1).reshape(4, 32, 32, 1))
    unrolled = unroll(pad(tensor, kernel_size=(3, 3)),
                      kernel_size=(3, 3),
                      strides=(2, 2))
    with tf.Session() as session:
        unrolled = session.run(unrolled)
    # (flattened window index, expected neighborhood); zeros mark padding.
    cases = [
        (16 * 0 + 0, [0, 0, 0, 0, 0, 1, 0, 32, 33]),
        (16 * 0 + 1, [0, 0, 0, 1, 2, 3, 33, 34, 35]),
        (16 * 1 + 0, [0, 32, 33, 0, 64, 65, 0, 96, 97]),
    ]
    for index, neighborhood in cases:
        actual = unrolled[0, index, :, 0].flatten()
        assert (actual == np.array(neighborhood)).all()
def test_pad_shape():
    """pad() grows each spatial axis by kernel - 1; kernel order is (width, height)."""
    tensor = tf.zeros((4, 32, 32, 3))
    cases = [
        ((3, 3), [4, 34, 34, 3]),
        ((5, 5), [4, 36, 36, 3]),
        ((3, 5), [4, 36, 34, 3]),  # asymmetric kernel pads the axes differently
    ]
    for kernel, expected in cases:
        assert pad(tensor, kernel_size=kernel).shape.as_list() == expected