def test_different_config_different_results(self):
    """Two independently initialised configs should hash the same input differently."""
    data = tf.get_variable(name='inputs', shape=[50, 100])
    # build one config per scope so each gets its own random projections
    configs = []
    for scope_name in ('a', 'b'):
        with tf.variable_scope(scope_name):
            configs.append(lsh.get_simhash_config(100, 8))
    first_hash = lsh.simhash(data, configs[0])
    second_hash = lsh.simhash(data, configs[1])
    all_equal = tf.reduce_all(tf.equal(first_hash, second_hash))
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        self.assertFalse(sess.run(all_equal))
def test_one_at_a_time(self):
    """make sure it gives a scalar back if given a vector"""
    # a single unbatched key: rank-1 input of size 5
    inputs = tf.random_normal([5])
    conf = lsh.get_simhash_config(5, 2)
    hashed = lsh.simhash(inputs, conf)
    # [] is the shape of a scalar tensor
    self.assertEqual(hashed.get_shape().as_list(), [])
def test_shapes(self):
    """simhash on a batch of vectors should yield one hash per row"""
    batch = tf.random_normal([10, 15])
    config = lsh.get_simhash_config(15, 16)
    result = lsh.simhash(batch, config)
    expected_shape = [10, 1]
    self.assertEqual(expected_shape, result.get_shape().as_list())
def test_deterministic(self):
    """running the same hash op twice must give identical results"""
    # a variable (rather than random_normal) so the input is fixed between runs
    keys = tf.get_variable(name='inputs', shape=[1, 20])
    config = lsh.get_simhash_config(20, 8)
    hash_op = lsh.simhash(keys, config)
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        first = sess.run(hash_op)
        second = sess.run(hash_op)
        self.assertEqual(first, second)
def test_max_bits(self):
    """hashes must lie in [0, 2**bits); not an exhaustive test"""
    inputs = tf.random_normal([10, 100])
    conf = lsh.get_simhash_config(100, 4)
    hashed = lsh.simhash(inputs, conf)
    # check BOTH bounds: a 4-bit hash must be non-negative as well as < 16
    # (the original only checked the upper bound)
    in_range = tf.logical_and(
        tf.reduce_all(tf.greater_equal(hashed, 0)),
        tf.reduce_all(tf.less(hashed, 2**4)))
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        # inputs are re-sampled on every sess.run, so repeat to cover more draws
        for _ in range(100):
            self.assertTrue(sess.run(in_range))
def __init__(self, hash_bits, max_neighbours, key_size, value_shapes,
             similarity_measure=None, name='dnd'):
    """Set up the dnd.

    Args:
        hash_bits (int): how many bits for the hash. There will be
            `2**num_bits` individual buckets.
        max_neighbours (int): how many entries to store in each bucket.
            This controls the number of neighbours we check against.
            Operations will be linear in this value and it will likely
            effect learning performance significantly as well.
        key_size (int): size of the key vectors. We use the unhashed key
            vectors to compute similarities between keys we find from the
            nearest neighbour lookup.
        value_shapes (list): list of shapes for the values stored in the
            dictionary.
        similarity_measure (Optional[callable]): function which adds ops
            to compare a query key with all of the other keys in the
            bucket. If unspecified, the cosine similarity is used. Should
            be a callable which takes two input tensors: the query key
            (shaped `[key_size]`) and a `[max_neighbours, key_size]`
            tensor of keys to compare against. Should return a
            `[max_neighbours]` tensor of similarities, between 0 and 1
            where 1 means the two keys were identical.
        name (Optional[str]): a name under which to group ops and
            variables. Defaults to `dnd`.
    """
    self._name = name
    self._hash_size = hash_bits
    self._key_size = key_size
    self._bucket_size = max_neighbours
    # NOTE(review): the original source was whitespace-mangled; the exact
    # extent of this variable_scope block is ambiguous. Both the storage
    # variables and the simhash projection are assumed to live under the
    # instance's name scope — TODO confirm against version history.
    with tf.variable_scope(self._name):
        self._keys, self._values = HashDND._setup_variables(
            hash_bits, max_neighbours, key_size, value_shapes)
        self._hash_config = get_simhash_config(self._key_size,
                                               self._hash_size)
    # fall back to cosine similarity when no measure was supplied
    if not similarity_measure:
        similarity_measure = dnd.similarities.cosine_similarity
    self._similarity_measure = similarity_measure
    # presumably adds summaries tracking bucket occupancy — verify
    self._summarise_pressure()
def __init__(self, hash_bits, max_neighbours, key_size, value_shapes):
    """Set up the dnd.

    Args:
        hash_bits (int): how many bits for the hash. There will be
            `2**num_bits` individual buckets.
        max_neighbours (int): how many entries to store in each bucket.
            This controls the number of neighbours we check against.
            Operations will be linear in this value and it will likely
            effect learning performance significantly as well.
        key_size (int): size of the key vectors. We use the unhashed key
            vectors to compute similarities between keys we find from the
            nearest neighbour lookup.
        value_shapes (list): list of shapes for the values stored in the
            dictionary.
    """
    self._hash_size = hash_bits
    self._key_size = key_size
    self._bucket_size = max_neighbours
    # create the bucketed key/value storage variables
    self._keys, self._values = HashDND._setup_variables(
        hash_bits, max_neighbours, key_size, value_shapes)
    # simhash config providing the random projections used for bucketing
    self._hash_config = get_simhash_config(self._key_size, self._hash_size)
    # presumably adds summaries tracking bucket occupancy — verify
    self._summarise_pressure()
def test_config_reuse(self):
    """make sure it is trying to reuse variables"""
    # first call creates the variables under the current scope
    lsh.get_simhash_config(10, 10)
    # a second call in the same scope must collide with the existing ones
    with self.assertRaisesRegex(ValueError, '.* already exists'):
        lsh.get_simhash_config(10, 12)