예제 #1
0
    def compute(self, particles, particle_info, loss_fn):
        """Build an RBF kernel whose bandwidth follows the median heuristic.

        The bandwidth is ``median(d) ** 2 * factor + 1e-5``. Here ``d`` are the
        pairwise particle distances (self-pairs included) and
        ``factor = self.bandwidth_factor(particles.shape[0])``. The ``1e-5``
        offset avoids a zero bandwidth when the median distance is 0.

        Args:
            particles: Array whose leading axis indexes the particles
                (presumably ``(N, D)`` -- TODO confirm with callers).
            particle_info: Not used by this method.
            loss_fn: Not used by this method.

        Returns:
            A callable ``kernel(x, y)`` that evaluates the kernel on one pair
            of particles.
        """
        # Pairwise differences between all particles: N x N (x D).
        diffs = jnp.expand_dims(particles, axis=0) - jnp.expand_dims(particles, axis=1)

        if self._normed() and particles.ndim == 2:
            diffs = safe_norm(diffs, ord=2, axis=-1)  # N x N x D -> N x N
        # Flatten the two pair axes. The result is always 2-D: (N * N, k).
        diffs = jnp.reshape(diffs, (diffs.shape[0] * diffs.shape[1], -1))
        factor = self.bandwidth_factor(particles.shape[0])

        # ``diffs`` is always 2-D after the reshape, so reduce its trailing
        # axis unconditionally to get one distance per particle pair.
        diff_norms = safe_norm(diffs, ord=2, axis=-1)  # N * N

        bandwidth = jnp.median(diff_norms) ** 2 * factor + 1e-5

        def kernel(x, y):
            """Evaluate ``exp(-diff**2 / bandwidth)`` for one pair ``(x, y)``.

            ``diff`` is ``safe_norm(x - y, ord=2)`` when the kernel is normed and
            ``x`` has at least one dimension; otherwise it is ``x - y``.

            In ``"matrix"`` mode the value is lifted to a matrix. For
            ``matrix_mode == "norm_diag"`` it is multiplied by an identity of size
            ``x.shape[0]``; otherwise ``jnp.diag`` places it on a diagonal.
            """
            diff = safe_norm(x - y, ord=2) if self._normed() and x.ndim >= 1 else x - y
            kernel_res = jnp.exp(-(diff**2) / bandwidth)
            if self._mode == "matrix":
                if self.matrix_mode == "norm_diag":
                    return kernel_res * jnp.identity(x.shape[0])
                else:
                    return jnp.diag(kernel_res)
            else:
                return kernel_res

        return kernel
예제 #2
0
    def compute(self, particles, particle_info, loss_fn):
        """Build a random-feature kernel whose bandwidth uses a median heuristic.

        The bandwidth is ``d_med ** 2 * factor + 1e-5``. ``d_med`` is the pairwise
        distance at position ``len // 2`` of the sorted distances (the upper
        median when the count is even). ``factor`` is
        ``self.bandwidth_factor(particles.shape[0])``, computed from the full
        particle count even when ``bandwidth_subset`` restricts which particles
        enter the distance computation.

        Args:
            particles: Array with the same shape as the random weights set up
                by ``init``; leading axis indexes the particles.
            particle_info: Not used by this method.
            loss_fn: Not used by this method.

        Returns:
            A callable ``kernel(x, y)``. It computes
            ``sum_i feature(x, w_i, b_i) * feature(y, w_i, b_i)`` over the
            (optionally ``random_indices``-selected) random weights and biases,
            with ``feature = sqrt(2) * cos((x @ w + b) / bandwidth)``.

        Raises:
            RuntimeError: If the random weights have not been initialized.
            ValueError: If ``particles`` and the random weights differ in shape.
        """
        if self._random_weights is None:
            raise RuntimeError(
                "The `.init` method should be called first to initialize the"
                " random weights, biases and subset indices."
            )
        if particles.shape != self._random_weights.shape:
            raise ValueError(
                "Shapes of `particles` and the random weights are mismatched, got {}"
                " and {}.".format(particles.shape, self._random_weights.shape)
            )
        factor = self.bandwidth_factor(particles.shape[0])
        if self.bandwidth_subset is not None:
            # Only this subset of particles contributes to the bandwidth estimate.
            particles = particles[self._bandwidth_subset_indices]
        diffs = jnp.expand_dims(particles, axis=0) - jnp.expand_dims(
            particles, axis=1
        )  # N x N x D
        if particles.ndim == 2:
            diffs = safe_norm(diffs, ord=2, axis=-1)  # N x N x D -> N x N
        # Flatten the two pair axes. The result is always 2-D: (N * N, k).
        diffs = jnp.reshape(diffs, (diffs.shape[0] * diffs.shape[1], -1))
        # Always 2-D here, so reduce the trailing axis: one distance per pair.
        diff_norms = safe_norm(diffs, ord=2, axis=-1)  # N * N
        median = jnp.argsort(diff_norms)[diffs.shape[0] // 2]
        # Index the per-pair distances, not the 2-D ``diffs``. Indexing ``diffs``
        # returns a whole row of shape (k,), so the bandwidth would not be the
        # scalar median distance that ``median`` was selected for.
        bandwidth = diff_norms[median] ** 2 * factor + 1e-5

        def feature(x, w, b):
            """Random cosine feature of ``x`` for weight ``w`` and bias ``b``."""
            return jnp.sqrt(2) * jnp.cos((x @ w + b) / bandwidth)

        def kernel(x, y):
            """Sum of feature products of ``x`` and ``y`` over the random features."""
            # Optionally restrict to the rows selected by ``random_indices``.
            ws = (
                self._random_weights
                if self.random_indices is None
                else self._random_weights[self.random_indices]
            )
            bs = (
                self._random_biases
                if self.random_indices is None
                else self._random_biases[self.random_indices]
            )
            return jnp.sum(
                jax.vmap(lambda w, b: feature(x, w, b) * feature(y, w, b))(ws, bs)
            )

        return kernel
예제 #3
0
 def kernel(x, y):
     """Return the RBF value ``exp(-diff**2 / bandwidth)`` for the pair ``(x, y)``.

     ``self`` and ``bandwidth`` are free names taken from the enclosing scope
     (not shown here). ``diff`` is ``safe_norm(x - y, ord=2)`` when
     ``self._normed()`` is true and ``x`` has at least one dimension; otherwise
     it is the raw difference ``x - y``.

     In ``"matrix"`` mode the value is lifted to a matrix. For
     ``matrix_mode == "norm_diag"`` it is multiplied by an identity of size
     ``x.shape[0]``; otherwise ``jnp.diag`` places it on a diagonal.
     """
     if self._normed() and x.ndim >= 1:
         diff = safe_norm(x - y, ord=2)
     else:
         diff = x - y
     value = jnp.exp(-(diff**2) / bandwidth)
     if self._mode != "matrix":
         return value
     if self.matrix_mode == "norm_diag":
         return value * jnp.identity(x.shape[0])
     return jnp.diag(value)
예제 #4
0
def test_safe_norm(axis, ord):
    """Compare ``safe_norm`` with ``jnp.linalg.norm`` on a small near-zero matrix.

    ``axis`` and ``ord`` are presumably supplied by a pytest parametrization
    outside this snippet. When ``ord`` is None, ``safe_norm`` is called with its
    own default order. Otherwise the same ``ord`` is passed to both sides.
    """
    m = np.array([[1.0e-5, 2e-5, 3e-5], [-1e-5, 1e-5, 0]])
    # Forward ``ord`` so the function under test is exercised with the same order
    # as the reference. The reference's ``1e-5**ord`` perturbation (applied when
    # ``axis`` is None) is only meaningful if both sides use that ``ord``.
    norm_kwargs = {} if ord is None else {"ord": ord}
    actual = safe_norm(m, axis=axis, **norm_kwargs)
    expected = jnp.linalg.norm(
        m + (1e-5**ord if axis is None and ord is not None else 0.0),
        ord=ord,
        axis=axis,
    )
    # NOTE(review): atol=1e-4 exceeds the magnitude of every entry of ``m``, so
    # this comparison is loose. Consider a tighter atol/rtol once the expected
    # behaviour for ``ord=None`` is pinned down.
    assert_allclose(actual, expected, atol=1e-4)
예제 #5
0
 def kernel(x, y):
     """Evaluate ``(self.const**2 + diff**2) ** self.expon`` for the pair ``(x, y)``.

     ``self`` is a free name taken from the enclosing scope (not shown here).
     In ``"norm"`` mode, ``diff`` is ``safe_norm`` (ord=2) of ``x - y`` over the
     last axis. Otherwise it is the elementwise difference ``x - y``.
     """
     norm_mode = self._mode == "norm"
     diff = x - y
     if norm_mode:
         diff = safe_norm(diff, ord=2, axis=-1)
     base = self.const**2 + diff**2
     return base ** self.expon