示例#1
0
    def _sample_rotation(self, shape, vecs, rng):
        """Samples a rotation matrix, either randomly or based on `vecs`.

        Args:
          shape: Shape of the matrix to produce. On the data-dependent path it
            must have exactly three entries, read as (n_dim, n_hashes, r_div_2).
          vecs: Rank-2 array whose rows form the pool the data-dependent
            rotation is drawn from. Not touched when `self._data_rotation` is
            falsy.
          rng: PRNG key driving all sampling.

        Returns:
          If `self._data_rotation` is falsy, float32 normal samples of shape
          `shape`. Otherwise the differences between pairs of rows of `vecs`,
          transposed so that the row dimension of `vecs` comes first, i.e.
          (n_dim, n_hashes, r_div_2).
        """
        if not self._data_rotation:
            gaussian = jax.random.normal(rng, shape)
            return gaussian.astype('float32')

        assert len(shape) == 3
        _, n_hashes, r_div_2 = shape
        assert len(vecs.shape) == 2
        n_vecs = vecs.shape[0]

        pair_shape = (n_hashes, r_div_2)
        n_pairs = n_hashes * r_div_2
        # Every pair consumes two rows of `vecs`.
        num_needed = 2 * n_pairs

        key_a, key_b = backend.random.split(rng, num=2)
        if n_vecs >= num_needed:
            # Enough rows available: sample without replacement by shuffling
            # all row indices and keeping the first `num_needed` of them.
            permuted = jax.random.shuffle(key_a, np.arange(n_vecs))
            picked = np.reshape(permuted[:num_needed], (2, ) + pair_shape)
            first_idxs, second_idxs = picked[0], picked[1]
        else:
            # Too few rows: draw both endpoints independently, with
            # replacement. Each index array has shape (n_hashes, r_div_2).
            first_idxs = jax.random.randint(
                key_a, pair_shape, minval=0, maxval=n_vecs)
            second_idxs = jax.random.randint(
                key_b, pair_shape, minval=0, maxval=n_vecs)

        if self._data_rotation_farthest:
            # shape = (n_hashes * r_div_2, n_dim)
            anchors = vecs[np.reshape(first_idxs, (-1, ))]

            # A second split of the incoming `rng` supplies the key used to
            # draw the candidate partners.
            _, candidate_key = backend.random.split(rng)
            # shape = (self._data_rotation_farthest_num, n_hashes * r_div_2)
            candidate_idxs = jax.random.randint(
                candidate_key, (self._data_rotation_farthest_num, n_pairs),
                minval=0, maxval=n_vecs)
            candidates = vecs[candidate_idxs]

            # Score every candidate by the negated absolute dot product with
            # its anchor; the arg-max therefore selects, per anchor, the
            # candidate whose dot product is closest to zero.
            # shape = candidate_idxs.shape
            scores = -np.abs(np.einsum('hd,chd->ch', anchors, candidates))
            # shape = (n_hashes * r_div_2,)
            best = np.argmax(scores, axis=0)
            # shape = (n_hashes * r_div_2, n_dim)
            partners = candidates[best, np.arange(n_pairs)]

            # Restore the (n_hashes, r_div_2, n_dim) layout.
            starts = np.reshape(anchors, pair_shape + (-1, ))
            ends = np.reshape(partners, pair_shape + (-1, ))
        else:
            # shape = (n_hashes, r_div_2, n_dim)
            starts = vecs[first_idxs]
            ends = vecs[second_idxs]

        # shape = (n_dim, n_hashes, r_div_2)
        return np.transpose(ends - starts, axes=[2, 0, 1])
示例#2
0
    def hash_vectors(self, vecs, rng):
        """Hashes `vecs` into buckets using random rotations.

        See https://arxiv.org/pdf/1509.02897.pdf
        A different random rotation is sampled for each round of hashing to
        decrease the probability of hash misses.

        Args:
          vecs: Array to hash. The einsum below treats it as rank 2 and
            contracts its last axis with the sampled rotations.
          rng: PRNG key; it is split into one key for sampling the rotations
            and one for `self.drop_for_hash`.

        Returns:
          A rank-1 array of bucket ids. With `self._rehash_each_round` set, the
          ids of hashing round `h` are shifted by `h * self.n_buckets` so that
          rounds do not overlap; otherwise each item contributes the ids of
          its top `self.n_hashes` buckets.
        """
        n_buckets = self.n_buckets
        assert n_buckets % 2 == 0

        # Without factorization a single block of rotations covers all buckets.
        factorize = self._factorize_hash
        factors, rot_size = [n_buckets], n_buckets
        if factorize and isinstance(factorize, list):
            # An explicit list of factors was given: every factor must be even
            # and together they must multiply to n_buckets.
            factors = factorize
            assert all(f % 2 == 0 for f in factors)
            rot_size = sum(factors)
            product = 1
            for f in factors:
                product *= f
            assert product == n_buckets
        elif factorize:
            # Only a flag was given. Write n_buckets = factor * rest with both
            # parts even and factor + rest minimal, by scanning downwards from
            # sqrt(n_buckets) and taking the first divisor that qualifies.
            factor = 0
            for candidate in range(int(math.sqrt(n_buckets)), 0, -1):
                rest, remainder = divmod(n_buckets, candidate)
                if remainder == 0 and candidate % 2 == 0 and rest % 2 == 0:
                    factor = candidate
                    break
            if factor > 2:  # Factor of 2 does not warrant the effort.
                rest = n_buckets // factor
                rot_size = factor + rest
                factors = [factor, rest]

        rehash = self._rehash_each_round
        n_rounds = self.n_hashes if rehash else 1
        rotations_shape = (vecs.shape[-1], n_rounds, rot_size // 2)

        rng = jax.lax.tie_in(vecs, rng)
        rotation_rng, dropout_rng = backend.random.split(rng)
        random_rotations = self._sample_rotation(rotations_shape, vecs,
                                                 rotation_rng)

        # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
        # hashing, so it's shared between them. Check if that's what we want.
        dropped_vecs = self.drop_for_hash(vecs, dropout_rng)
        rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

        if not rehash:
            assert not factorize
            # In this configuration, we map each item to the top self.n_hashes
            # buckets: sort the bucket ids by their rotated value and keep the
            # last self.n_hashes ids of every item.
            signed = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
            signed = np.squeeze(signed, 0)
            bucket_ids = jax.lax.tie_in(vecs, np.arange(signed.shape[-1]))
            bucket_ids = np.broadcast_to(np.reshape(bucket_ids, (1, -1)),
                                         signed.shape)
            _, ranked_ids = jax.lax.sort_key_val(signed,
                                                 bucket_ids,
                                                 dimension=-1)
            top_ids = ranked_ids[:, -self.n_hashes:]
            return np.reshape(np.moveaxis(top_ids, 0, -1), (-1, ))

        if factorize and len(factors) > 1:
            # n_buckets was factorized as the product of `factors`. Hash each
            # factor on its own slice of the rotations and combine the partial
            # ids as digits of a mixed-radix number.
            buckets = None
            start, stride = 0, 1
            for f in factors:
                half = f // 2
                piece = rotated_vecs[..., start:start + half]
                digit = np.argmax(np.concatenate([piece, -piece], axis=-1),
                                  axis=-1)
                buckets = digit if buckets is None else buckets + stride * digit
                start += half
                stride *= f
        else:
            signed = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
            buckets = np.argmax(signed, axis=-1)

        # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
        # bucket numbers from different hashing rounds don't overlap.
        round_offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
        round_offsets = np.reshape(round_offsets * n_buckets, (-1, 1))
        return np.reshape(buckets + round_offsets, (-1, ))
示例#3
0
def Accuracy(x, axis=-1, **kw):
    """Returns an elementwise mask of correct predictions.

    Args:
      x: A pair `(prediction, target)`.
      axis: Axis of `prediction` over which the arg-max is taken.
      **kw: Accepted and ignored.

    Returns:
      The result of comparing the arg-max of `prediction` along `axis` with
      `target` for equality.
    """
    del kw  # Extra keyword arguments are deliberately discarded.
    scores, labels = x
    return np.equal(np.argmax(scores, axis=axis), labels)