def testBiasBatchCoordinates(self):
    """Check attention_bias_coordinates against a hand-built mask.

    Each query row and each key column carries a single integer coordinate.
    The expected bias is 0 where the query and key coordinates are equal and
    -1e9 everywhere else.
    """
    q = tf.constant([0, 0, 1, 1, 1, 1, 2, 2, 2], dtype=tf.int32)
    q = tf.expand_dims(q, axis=-1)
    k = tf.constant([0, 0, 0, 2, 2, 3, 3, 3], dtype=tf.int32)
    k = tf.expand_dims(k, axis=-1)

    # One row per query entry, one column per key entry. A 1 marks a masked
    # pair and is scaled to -1e9 below; the trailing comments give the query
    # coordinate for that group of rows.
    ground_truth = np.array([
        [0, 0, 0, 1, 1, 1, 1, 1],  # 0
        [0, 0, 0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],  # 1 (just masked)
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0, 1, 1, 1],  # 2
        [1, 1, 1, 0, 0, 1, 1, 1],
        [1, 1, 1, 0, 0, 1, 1, 1],
    ], np.float32) * -1e9

    bias = common_attention.attention_bias_coordinates(q, k)

    # self.evaluate replaces the deprecated test_session()/Tensor.eval() pair.
    # The inputs are plain constants, so no variable initializer is run.
    self.assertAllClose(self.evaluate(bias), ground_truth)
def testBiasBatchCoordinates(self):
    """Test the batch coordinates mask.

    Builds a column of query coordinates and a column of key coordinates,
    then verifies that the bias is 0 for pairs with equal coordinates and
    -1e9 for all other pairs.
    """
    q = tf.constant([0, 0, 1, 1, 1, 1, 2, 2, 2], dtype=tf.int32)
    q = tf.expand_dims(q, axis=-1)
    k = tf.constant([0, 0, 0, 2, 2, 3, 3, 3], dtype=tf.int32)
    k = tf.expand_dims(k, axis=-1)

    # Rows follow the query entries, columns follow the key entries. Entries
    # equal to 1 become -1e9 after scaling; the trailing comments name the
    # query coordinate of each row group.
    ground_truth = np.array([
        [0, 0, 0, 1, 1, 1, 1, 1],  # 0
        [0, 0, 0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],  # 1 (just masked)
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0, 1, 1, 1],  # 2
        [1, 1, 1, 0, 0, 1, 1, 1],
        [1, 1, 1, 0, 0, 1, 1, 1],
    ], np.float32) * -1e9

    bias = common_attention.attention_bias_coordinates(q, k)

    # Evaluate through the test case helper instead of opening a deprecated
    # test_session(); only constants are built here, so nothing needs
    # initializing first.
    self.assertAllClose(self.evaluate(bias), ground_truth)
def testBiasBatchCoordinates(self):
    """Verify the coordinate-based attention bias on a small example.

    Query/key pairs with equal coordinates are expected to get a bias of 0;
    every other pair is expected to get -1e9.
    """
    query_coords = tf.expand_dims(
        tf.constant([0, 0, 1, 1, 1, 1, 2, 2, 2], dtype=tf.int32), axis=-1)
    key_coords = tf.expand_dims(
        tf.constant([0, 0, 0, 2, 2, 3, 3, 3], dtype=tf.int32), axis=-1)

    # 1 = masked pair, 0 = visible pair. Rows follow the query entries and
    # columns follow the key entries; the trailing comments give the query
    # coordinate of each row group.
    masked = np.array(
        [
            [0, 0, 0, 1, 1, 1, 1, 1],  # 0
            [0, 0, 0, 1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1, 1, 1, 1],  # 1 (just masked)
            [1, 1, 1, 1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1, 1, 1, 1],
            [1, 1, 1, 0, 0, 1, 1, 1],  # 2
            [1, 1, 1, 0, 0, 1, 1, 1],
            [1, 1, 1, 0, 0, 1, 1, 1],
        ],
        np.float32,
    )
    expected = masked * -1e9

    actual = self.evaluate(
        common_attention.attention_bias_coordinates(query_coords, key_coords))
    self.assertAllClose(actual, expected)