Exemplo n.º 1
0
    def forward(self, inputs, weights):
        """Scales `inputs` by their root mean square, then applies gamma/beta.

        Args:
            inputs: Array to normalize. The mean square is taken over every
                axis except the first and the last (for a 4-D input laid out
                as (B, W, H, C) that means the two middle axes).
            weights: A `(gamma, beta, epsilon_l)` triple. `epsilon_l` is
                either `base.EMPTY_WEIGHTS` (no extra epsilon) or an indexable
                whose first element's absolute value is added to the initial
                epsilon.

        Returns:
            `gamma * inputs / sqrt(mean(inputs**2) + epsilon) + beta`.
        """
        gamma, beta, epsilon_l = weights

        epsilon = self._init_epsilon
        if epsilon_l is not base.EMPTY_WEIGHTS:
            # Out-of-place add on purpose: `epsilon += ...` would modify
            # `self._init_epsilon` itself whenever that attribute is a mutable
            # array, making epsilon grow on every call.
            epsilon = epsilon + np.abs(epsilon_l[0])

        # Reduce over everything except the first (B) and last (C) axes.
        axis = tuple(range(1, len(np.shape(inputs)) - 1))
        # Mean square per (batch, channel); reduced axes kept with size 1,
        # e.g. (B, 1, 1, C) for a 4-D input.
        nu2 = np.mean(inputs**2, axis=axis, keepdims=True)
        # Same shape as `inputs`, e.g. (B, W, H, C).
        xhat = inputs / np.sqrt(nu2 + epsilon)

        return gamma * xhat + beta
Exemplo n.º 2
0
    def _sample_rotation(self, shape, vecs, rng):
        """Samples a rotation matrix, either randomly or based on `vecs`.

        Args:
            shape: Shape of the result. When data-based rotation is enabled it
                must be a 3-tuple `(n_dim, n_hashes, r_div_2)`.
            vecs: 2-D array of vectors, one per row, to draw from when
                data-based rotation is enabled; not read otherwise.
            rng: Random key used for all sampling in this call.

        Returns:
            If `self._data_rotation` is falsy, a float32 array of standard
            normal samples with the given `shape`. Otherwise an array of shape
            `(n_dim, n_hashes, r_div_2)` holding differences between pairs of
            rows sampled from `vecs`.
        """

        if not self._data_rotation:
            return jax.random.normal(rng, shape).astype('float32')

        assert len(shape) == 3
        unused_n_dim, n_hashes, r_div_2 = shape

        assert len(vecs.shape) == 2
        n_vecs = vecs.shape[0]

        rng1, rng2 = backend.random.split(rng, num=2)

        # We need to sample 2 * n_hashes * r_div_2 vectors from `vecs` at random.
        num_needed = 2 * n_hashes * r_div_2
        if n_vecs < num_needed:
            # Not enough rows for distinct picks: sample with replacement.
            # shape = (n_hashes, r_div_2)
            random_idxs_1 = jax.random.randint(rng1, (n_hashes, r_div_2), 0,
                                               n_vecs)
            random_idxs_2 = jax.random.randint(rng2, (n_hashes, r_div_2), 0,
                                               n_vecs)
        else:
            # Sample without replacement.
            # `jax.random.shuffle` is deprecated in JAX; `permutation` is its
            # replacement for shuffling a 1-D array.
            shuffled_indices = jax.random.permutation(rng1, np.arange(n_vecs))
            random_idxs = np.reshape(shuffled_indices[:num_needed],
                                     (2, n_hashes, r_div_2))
            random_idxs_1 = random_idxs[0]
            random_idxs_2 = random_idxs[1]

        if self._data_rotation_farthest:
            # This branch picks its own second vectors; `random_idxs_2` from
            # above is not used here.
            # shape = (n_hashes * r_div_2, )
            random_idxs_1 = np.reshape(random_idxs_1, (-1, ))
            random_vecs_1 = vecs[random_idxs_1]

            # Sample candidates for vec2s.
            # NOTE(review): `rng` was already split above, so splitting it
            # again reuses the parent key. Left as-is to keep the sampled
            # values unchanged -- confirm whether a fresh key is wanted.
            rng, subrng = backend.random.split(rng)
            # shape = (self._data_rotation_farthest_num, n_hashes * r_div_2)
            candidate_idxs_2 = jax.random.randint(
                subrng, (self._data_rotation_farthest_num, n_hashes * r_div_2),
                0, n_vecs)
            candidate_vecs_2 = vecs[candidate_idxs_2]
            # Negated absolute dot product of each candidate with its first
            # vector, so the argmax below selects the candidate whose dot
            # product is closest to zero.
            # shape = candidate_idxs_2.shape
            distances = -np.abs(
                np.einsum('hd,chd->ch', random_vecs_1, candidate_vecs_2))
            # shape = (n_hashes * r_div_2,)
            farthest_idxs = np.argmax(distances, axis=0)
            # One chosen candidate per position:
            # shape = (n_hashes * r_div_2, n_dim)
            random_vecs_2 = candidate_vecs_2[farthest_idxs,
                                             np.arange(n_hashes * r_div_2)]

            # reshape to (n_hashes, r_div_2, n_dim)
            random_vecs_1 = np.reshape(random_vecs_1, (n_hashes, r_div_2, -1))
            random_vecs_2 = np.reshape(random_vecs_2, (n_hashes, r_div_2, -1))
        else:
            # shape = (n_hashes, r_div_2, n_dim)
            random_vecs_1 = vecs[random_idxs_1]
            random_vecs_2 = vecs[random_idxs_2]

        # Move the vector dimension to the front.
        # shape = (n_dim, n_hashes, r_div_2)
        return np.transpose(random_vecs_2 - random_vecs_1, axes=[2, 0, 1])
Exemplo n.º 3
0
def SaturationCost(x, limit=0.9):
    """Returns the elementwise minimum of 0 and `abs(x) - limit`.

    Args:
        x: Scalar or array of values.
        limit: Threshold subtracted from `abs(x)`.

    Returns:
        `abs(x) - limit` where `abs(x)` is below `limit`, and 0 elsewhere.
    """
    # NOTE(review): the result is never positive. If the intent is to
    # penalize values whose magnitude exceeds `limit`, `np.maximum` would be
    # expected here -- confirm with callers before changing.
    margin = np.abs(x) - limit
    return np.minimum(0, margin)