def _sample_rotation(self, shape, vecs, rng):
  """Samples a rotation matrix, either randomly or based on `vecs`.

  When `self._data_rotation` is off, the "rotation" is i.i.d. standard normal
  noise of the requested shape. Otherwise every rotation column is the
  difference of two rows drawn from `vecs`. With
  `self._data_rotation_farthest`, the second row is chosen among
  `self._data_rotation_farthest_num` random candidates as the one whose
  absolute dot product with the first row is smallest (closest to orthogonal).

  Args:
    shape: 3-tuple `(n_dim, n_hashes, r_div_2)`. In the data-rotation case only
      `n_hashes` and `r_div_2` are read; the output's leading dimension comes
      from `vecs.shape[1]`.
    vecs: 2-D array `(n_vecs, n_dim)` to sample rows from. Only used when
      `self._data_rotation` is set.
    rng: JAX PRNG key.

  Returns:
    A float32 array of shape `shape` when data rotation is off. Otherwise an
    array of shape `(n_dim, n_hashes, r_div_2)` with the dtype of `vecs`.
  """
  if not self._data_rotation:
    return jax.random.normal(rng, shape).astype('float32')

  assert len(shape) == 3
  unused_n_dim, n_hashes, r_div_2 = shape
  assert len(vecs.shape) == 2
  n_vecs = vecs.shape[0]
  n_rotations = n_hashes * r_div_2

  rng1, rng2 = backend.random.split(rng, num=2)

  # We need to sample 2 * n_hashes * r_div_2 vectors from `vecs` at random.
  num_needed = 2 * n_rotations
  if n_vecs < num_needed:
    # Too few vectors to sample without replacement: sample with replacement.
    # shape = (n_hashes, r_div_2)
    random_idxs_1 = jax.random.randint(rng1, (n_hashes, r_div_2), 0, n_vecs)
    random_idxs_2 = jax.random.randint(rng2, (n_hashes, r_div_2), 0, n_vecs)
  else:
    # Sample without replacement. `jax.random.permutation` on an integer
    # permutes `arange(n_vecs)`; it replaces the deprecated
    # `jax.random.shuffle`.
    shuffled_indices = jax.random.permutation(rng1, n_vecs)
    random_idxs = np.reshape(shuffled_indices[:num_needed],
                             (2, n_hashes, r_div_2))
    random_idxs_1 = random_idxs[0]
    random_idxs_2 = random_idxs[1]

  if self._data_rotation_farthest:
    # shape = (n_hashes * r_div_2,)
    flat_idxs_1 = np.reshape(random_idxs_1, (-1,))
    # shape = (n_hashes * r_div_2, n_dim)
    random_vecs_1 = vecs[flat_idxs_1]

    # Sample candidates for vec2s. `random_idxs_2` is unused on this path, so
    # `rng2` is free to drive the candidate draw. Re-splitting the consumed
    # `rng` would only reproduce `rng2`.
    # shape = (self._data_rotation_farthest_num, n_hashes * r_div_2)
    candidate_idxs_2 = jax.random.randint(
        rng2, (self._data_rotation_farthest_num, n_rotations), 0, n_vecs)
    # shape = (farthest_num, n_hashes * r_div_2, n_dim)
    candidate_vecs_2 = vecs[candidate_idxs_2]
    # Negated |dot product|: the argmax picks the candidate closest to
    # orthogonal to its vec1.
    # shape = (farthest_num, n_hashes * r_div_2)
    distances = -np.abs(
        np.einsum('hd,chd->ch', random_vecs_1, candidate_vecs_2))
    # shape = (n_hashes * r_div_2,)
    farthest_idxs = np.argmax(distances, axis=0)
    # shape = (n_hashes * r_div_2, n_dim)
    random_vecs_2 = candidate_vecs_2[farthest_idxs, np.arange(n_rotations)]

    # reshape to (n_hashes, r_div_2, n_dim)
    random_vecs_1 = np.reshape(random_vecs_1, (n_hashes, r_div_2, -1))
    random_vecs_2 = np.reshape(random_vecs_2, (n_hashes, r_div_2, -1))
  else:
    # shape = (n_hashes, r_div_2, n_dim)
    random_vecs_1 = vecs[random_idxs_1]
    random_vecs_2 = vecs[random_idxs_2]

  # shape = (n_dim, n_hashes, r_div_2)
  return np.transpose(random_vecs_2 - random_vecs_1, axes=[2, 0, 1])
def hash_vectors(self, vecs, rng):
  """Assigns every vector in `vecs` to LSH buckets via random rotations.

  Angular LSH (https://arxiv.org/pdf/1509.02897.pdf): vectors are projected
  with a random rotation R (from `self._sample_rotation`), and the bucket of x
  is `argmax([xR, -xR])`. When `self._rehash_each_round` is set, a different
  rotation is sampled for each hashing round to decrease the probability of
  hash misses.

  `self._factorize_hash` optionally splits `self.n_buckets` into a product of
  even factors. Each factor f is hashed with its own `f // 2` rotation
  columns, which shrinks the rotation from `n_buckets // 2` to
  `sum(factors) // 2` columns. It may be an explicit list/tuple of factors or a
  truthy flag to pick two factors automatically.

  Args:
    vecs: 2-D array `(seqlen, dim)` of vectors to hash (the einsum below
      requires rank 2).
    rng: JAX PRNG key.

  Returns:
    A flat integer array of length `self.n_hashes * seqlen`, laid out as a
    flattened `(n_hashes, seqlen)` array.
    - With rehashing, round i's bucket ids are offset by `i * self.n_buckets`
      so that different rounds never share an id.
    - Without rehashing, a single rotation is used. Each vector is mapped to
      its top `self.n_hashes` buckets, with ids in `[0, self.n_buckets)`.
  """
  assert self.n_buckets % 2 == 0

  # If we factorize the hash, find factors dividing n_buckets nicely. Each
  # factor f consumes f // 2 rotation columns, so rot_size = sum(factors).
  rot_size, factor_list = self.n_buckets, [self.n_buckets]
  if self._factorize_hash:
    if isinstance(self._factorize_hash, (list, tuple)):
      # An explicit factorization was given: verify it and use it below.
      # Tuples are accepted too, so they are not mistaken for a bare `True`.
      factor_list = list(self._factorize_hash)
      rot_size, product = 0, 1
      for factor in factor_list:
        assert factor % 2 == 0
        product *= factor
        rot_size += factor
      assert product == self.n_buckets
    else:
      # Find one factor if just set to True.
      # We want to represent self.n_buckets = factor * rest so that
      # (1) both factor and rest are even, and (2) factor + rest is minimal.
      # To compute this we start from factor = sqrt(n_buckets) and go down
      # with it until we find one that satisfies the constraints above.
      factor = int(math.sqrt(self.n_buckets))
      while factor > 0 and not (self.n_buckets % factor == 0 and
                                factor % 2 == 0 and
                                (self.n_buckets // factor) % 2 == 0):
        factor -= 1
      if factor > 2:  # Factor of 2 does not warrant the effort.
        rot_size = factor + (self.n_buckets // factor)
        factor_list = [factor, self.n_buckets // factor]

  rotations_shape = (
      vecs.shape[-1],
      self.n_hashes if self._rehash_each_round else 1,
      rot_size // 2)

  rng = jax.lax.tie_in(vecs, rng)
  rng, subrng = backend.random.split(rng)
  random_rotations = self._sample_rotation(rotations_shape, vecs, rng)

  # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
  # hashing, so it's shared between them. Check if that's what we want.
  dropped_vecs = self.drop_for_hash(vecs, subrng)
  # shape = (rounds, seqlen, rot_size // 2), where rounds is n_hashes or 1.
  rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

  if self._rehash_each_round:
    if self._factorize_hash and len(factor_list) > 1:
      # We factorized self.n_buckets as the product of factor_list. Hash each
      # factor on its own slice of rotation columns, then combine the
      # per-factor buckets as digits of a mixed-radix number.
      buckets, cur_sum, cur_product = None, 0, 1
      for factor in factor_list:
        rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
        cur_sum += factor // 2
        rv = np.concatenate([rv, -rv], axis=-1)
        if buckets is None:
          buckets = np.argmax(rv, axis=-1)
        else:
          buckets += cur_product * np.argmax(rv, axis=-1)
        cur_product *= factor
    else:
      rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
      buckets = np.argmax(rotated_vecs, axis=-1)
    # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
    # bucket numbers from different hashing rounds don't overlap.
    offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
    offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
    buckets = np.reshape(buckets + offsets, (-1,))
  else:
    assert not self._factorize_hash
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    # In this configuration, we map each item to the top self.n_hashes buckets.
    # shape = (seqlen, n_buckets)
    rotated_vecs = np.squeeze(rotated_vecs, 0)
    bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
    bucket_range = np.reshape(bucket_range, (1, -1))
    bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)

    # Sorting bucket ids by score (ascending) puts the best ones last.
    _, buckets = jax.lax.sort_key_val(
        rotated_vecs, bucket_range, dimension=-1)
    # shape = (seqlen, n_hashes)
    buckets = buckets[:, -self.n_hashes:]
    buckets = np.reshape(np.moveaxis(buckets, 0, -1), (-1,))

  return buckets
def Accuracy(x, axis=-1, **kw):
  """Scores argmax predictions against targets, element by element.

  Args:
    x: A `(prediction, target)` pair. `prediction` holds per-class scores along
      `axis`; `target` holds class indices and is compared (with broadcasting)
      against the argmax of `prediction` over `axis`.
    axis: Axis of `prediction` that indexes classes.
    **kw: Ignored; accepted so callers may pass extra keyword arguments.

  Returns:
    Boolean array that is True where the highest-scoring class equals the
    target.
  """
  del kw  # Intentionally unused.
  scores, labels = x
  return np.equal(np.argmax(scores, axis=axis), labels)