Ejemplo n.º 1
0
def test_build_logits_with_batch_comparison():
    """Check the logits shape when ``_build_logits`` runs with ``compare_batch=True``.

    Builds random unit-norm query/key batches and a ``tf.Variable`` buffer of
    unit-norm rows, then asserts the result is rank 2 with shape
    ``(batch_size, buffer_rows + batch_size)``. Only the shape is checked, so
    the (unseeded) random values do not affect the outcome.
    """
    batch_size, embed_dim, buffer_rows = 7, 5, 13

    def unit_rows(num_rows):
        # Standard-normal float32 samples, each row L2-normalised.
        samples = np.random.normal(
            0, 1, size=(num_rows, embed_dim)).astype(np.float32)
        return tf.nn.l2_normalize(samples, axis=1)

    # Draw order (queries, keys, buffer) is kept so the global RNG is
    # consumed exactly as before.
    queries = unit_rows(batch_size)
    keys = unit_rows(batch_size)
    memory = tf.Variable(unit_rows(buffer_rows))

    logits = _build_logits(queries, keys, memory, compare_batch=True)

    expected_width = buffer_rows + batch_size
    assert len(logits.shape) == 2
    assert logits.shape[0] == batch_size
    assert logits.shape[1] == expected_width
Ejemplo n.º 2
0
def test_build_logits_with_mochi():
    """Check the logits shape when ``_build_logits`` is given ``N`` and ``s``.

    ``N`` and ``s`` are passed positionally after the buffer. The expected
    result is rank 2 with shape ``(batch_size, buffer_rows + 1 + s)``.
    NOTE(review): the test name suggests the MoCHi hard-negative mixing path.
    Presumably ``N`` bounds the hard negatives considered and ``s`` is the
    number of extra synthetic logits appended; confirm against
    ``_build_logits``.
    """
    batch_size, embed_dim = 7, 5
    buffer_rows, N, s = 13, 6, 2

    def unit_rows(num_rows):
        # Standard-normal float32 samples, each row L2-normalised.
        samples = np.random.normal(
            0, 1, size=(num_rows, embed_dim)).astype(np.float32)
        return tf.nn.l2_normalize(samples, axis=1)

    # Draw order (queries, keys, buffer) is kept so the global RNG is
    # consumed exactly as before.
    queries = unit_rows(batch_size)
    keys = unit_rows(batch_size)
    memory = tf.Variable(unit_rows(buffer_rows))

    logits = _build_logits(queries, keys, memory, N, s)

    expected_width = buffer_rows + 1 + s
    assert len(logits.shape) == 2
    assert logits.shape[0] == batch_size
    assert logits.shape[1] == expected_width