Example #1
0
    def compute(self, particles, particle_info, loss_fn):
        """Build an RBF kernel closure with a median-heuristic bandwidth.

        The bandwidth is ``median(d) ** 2 * factor + 1e-5``, where ``d`` are
        the pairwise distances between all particles and
        ``factor = self.bandwidth_factor(N)`` with ``N = particles.shape[0]``.

        Args:
            particles: Array whose leading axis indexes the ``N`` particles.
            particle_info: Not used by this method.
            loss_fn: Not used by this method.

        Returns:
            A function ``kernel(x, y)`` computing ``exp(-diff**2 / bandwidth)``
            for one pair of particles (see its docstring for matrix modes).
        """
        # All pairwise differences: N x N (x D).
        diffs = jnp.expand_dims(particles, axis=0) - jnp.expand_dims(
            particles, axis=1)
        if self._normed() and particles.ndim == 2:
            # Collapse the feature axis to a distance: N x N x D -> N x N.
            diffs = safe_norm(diffs, ord=2, axis=-1)
        # Flatten the pair axes. reshape(..., -1) always yields a 2-D array,
        # (N * N, 1) or (N * N, D), so a per-row norm is always well defined.
        diffs = jnp.reshape(diffs, (diffs.shape[0] * diffs.shape[1], -1))
        factor = self.bandwidth_factor(particles.shape[0])
        diff_norms = safe_norm(diffs, ord=2, axis=-1)  # one distance per pair
        # The 1e-5 offset keeps the bandwidth nonzero when the median
        # distance is 0 (e.g. all particles coincide).
        bandwidth = jnp.median(diff_norms)**2 * factor + 1e-5

        def kernel(x, y):
            """Evaluate the RBF kernel on one pair of particles.

            ``diff`` is the Euclidean norm of ``x - y`` when the kernel is
            normed and ``x`` is at least 1-D, otherwise the raw difference.
            In ``"matrix"`` mode the result is ``kernel_res * I`` when
            ``matrix_mode == "norm_diag"`` and ``diag(kernel_res)`` otherwise.
            """
            diff = safe_norm(
                x - y, ord=2) if self._normed() and x.ndim >= 1 else x - y
            kernel_res = jnp.exp(-(diff**2) / bandwidth)
            if self._mode == "matrix":
                if self.matrix_mode == "norm_diag":
                    return kernel_res * jnp.identity(x.shape[0])
                else:
                    return jnp.diag(kernel_res)
            else:
                return kernel_res

        return kernel
Example #2
0
    def compute(self, particles, particle_info, loss_fn):
        """Build a random-feature kernel closure with a median bandwidth.

        On the first call, random weights (``npr.randn``) and biases
        (``npr.rand * 2 * pi``) with the same shape as ``particles`` are drawn
        and cached on ``self``; later calls reuse them. The bandwidth is the
        squared median pairwise particle distance times
        ``self.bandwidth_factor(N)``, plus ``1e-5``.

        Args:
            particles: Array whose leading axis indexes the ``N`` particles.
            particle_info: Not used by this method.
            loss_fn: Not used by this method.

        Returns:
            A function ``kernel(x, y)`` that sums ``feature(x) * feature(y)``
            over the cached random weights/biases (optionally restricted to
            ``self.random_indices``).
        """
        if self._random_weights is None:
            # Draw the random features once; they are cached and reused.
            self._random_weights = jnp.array(npr.randn(*particles.shape))
            self._random_biases = jnp.array(
                npr.rand(*particles.shape) * 2 * np.pi)
        factor = self.bandwidth_factor(particles.shape[0])
        if self.bandwidth_subset is not None:
            # NOTE(review): npr.choice samples with replacement by default, so
            # the subset may repeat particles (adding zero distances and
            # lowering the median); confirm whether replace=False was intended.
            particles = particles[npr.choice(particles.shape[0],
                                             self.bandwidth_subset)]
        diffs = jnp.expand_dims(particles, axis=0) - jnp.expand_dims(
            particles, axis=1)  # N x N (x D)
        if particles.ndim == 2:
            diffs = safe_norm(diffs, ord=2, axis=-1)  # N x N x D -> N x N
        # Flatten the pair axes; reshape(..., -1) always yields a 2-D array.
        diffs = jnp.reshape(diffs, (diffs.shape[0] * diffs.shape[1], -1))
        diff_norms = safe_norm(diffs, ord=2, axis=-1)  # one distance per pair
        # Upper median via argsort (not the interpolating jnp.median).
        median = jnp.argsort(diff_norms)[diffs.shape[0] // 2]
        # Index the per-pair distances that were sorted, not |diffs|: for
        # particles with several feature axes |diffs| holds per-component
        # differences, which would turn the bandwidth into a vector.
        bandwidth = diff_norms[median]**2 * factor + 1e-5

        def feature(x, w, b):
            """Return ``sqrt(2) * cos((x @ w + b) / bandwidth)``."""
            return jnp.sqrt(2) * jnp.cos((x @ w + b) / bandwidth)

        def kernel(x, y):
            """Sum ``feature(x) * feature(y)`` over the selected features."""
            ws = (self._random_weights if self.random_indices is None else
                  self._random_weights[self.random_indices])
            bs = (self._random_biases if self.random_indices is None else
                  self._random_biases[self.random_indices])
            return jnp.sum(
                jax.vmap(lambda w, b: feature(x, w, b) * feature(y, w, b))(ws,
                                                                           bs))

        return kernel
Example #3
0
 def kernel(x, y):
     """Return the RBF similarity ``exp(-diff**2 / bandwidth)`` of x and y.

     ``diff`` is the 2-norm of ``x - y`` when ``self._normed()`` holds and
     ``x`` is at least 1-D; otherwise it is the element-wise ``x - y``.
     Outside ``'matrix'`` mode the raw result is returned; in ``'matrix'``
     mode it is scaled onto an identity (``'norm_diag'``) or placed on a
     diagonal (any other ``matrix_mode``).
     """
     if self._normed() and x.ndim >= 1:
         diff = safe_norm(x - y, ord=2)
     else:
         diff = x - y
     similarity = jnp.exp(- diff ** 2 / bandwidth)
     if self._mode != 'matrix':
         return similarity
     if self.matrix_mode == 'norm_diag':
         return similarity * jnp.identity(x.shape[0])
     return jnp.diag(similarity)
Example #4
0
 def kernel(x, y):
     """Return ``(const**2 + diff**2) ** expon`` for the pair (x, y).

     In ``"norm"`` mode ``diff`` is the 2-norm of ``x - y`` over the last
     axis; otherwise it is the element-wise difference ``x - y``.
     """
     delta = x - y
     if self._mode == "norm":
         delta = safe_norm(delta, ord=2, axis=-1)
     return (self.const**2 + delta**2)**self.expon