def test_build_logits_with_batch_comparison():
    """Check the shape of the logits produced with ``compare_batch=True``.

    Builds L2-normalized random queries, keys and a ``tf.Variable`` buffer of
    ``K`` entries. It then asserts that ``_build_logits`` returns a rank-2
    result with one row per query and ``K + batch_size`` columns. Presumably
    these are one column per buffer entry plus one per in-batch key; confirm
    against ``_build_logits``.
    """
    batch_size = 7
    embed_dim = 5
    K = 13
    # Fixed seed so any failure reproduces; only shapes are asserted, so the
    # particular values drawn do not affect the expected result.
    rng = np.random.RandomState(0)

    def unit_rows(num_rows):
        # Standard-normal float32 rows of width embed_dim, L2-normalized
        # along axis 1.
        return tf.nn.l2_normalize(
            rng.normal(0, 1, size=(num_rows, embed_dim)).astype(np.float32),
            axis=1)

    q = unit_rows(batch_size)
    k = unit_rows(batch_size)
    buffer = tf.Variable(unit_rows(K))
    all_logits = _build_logits(q, k, buffer, compare_batch=True)
    assert len(all_logits.shape) == 2
    assert all_logits.shape[0] == batch_size
    assert all_logits.shape[1] == K + batch_size
def test_build_logits_with_mochi():
    """Check the shape of the logits when ``N`` and ``s`` are supplied.

    Builds L2-normalized random queries, keys and a ``tf.Variable`` buffer of
    ``K`` entries, and passes ``N`` and ``s`` positionally to
    ``_build_logits``. It then asserts a rank-2 result with one row per query
    and ``K + 1 + s`` columns. Presumably that is one positive, ``K`` buffer
    negatives and ``s`` synthesized negatives; confirm against
    ``_build_logits``.
    """
    batch_size = 7
    embed_dim = 5
    K = 13
    N = 6
    s = 2
    # Fixed seed so any failure reproduces; only shapes are asserted, so the
    # particular values drawn do not affect the expected result.
    rng = np.random.RandomState(0)

    def unit_rows(num_rows):
        # Standard-normal float32 rows of width embed_dim, L2-normalized
        # along axis 1.
        return tf.nn.l2_normalize(
            rng.normal(0, 1, size=(num_rows, embed_dim)).astype(np.float32),
            axis=1)

    q = unit_rows(batch_size)
    k = unit_rows(batch_size)
    buffer = tf.Variable(unit_rows(K))
    all_logits = _build_logits(q, k, buffer, N, s)
    assert len(all_logits.shape) == 2
    assert all_logits.shape[0] == batch_size
    assert all_logits.shape[1] == K + 1 + s