def compute(self, particles, particle_info, loss_fn):
    """Return a kernel closure whose bandwidth uses a median heuristic.

    The bandwidth is ``median(d) ** 2 * factor + 1e-5``, where ``d`` ranges
    over the pairwise distances of all ordered particle pairs (the diagonal
    self-pairs included) and ``factor`` is
    ``self.bandwidth_factor(particles.shape[0])``.

    Args:
        particles: Particle array whose leading axis indexes the particles
            (presumably shape ``(N, D)`` or ``(N,)`` -- confirm with callers).
        particle_info: Not used by this method.
        loss_fn: Not used by this method.

    Returns:
        ``kernel(x, y)`` computing ``exp(-diff**2 / bandwidth)``, where
        ``diff`` is ``safe_norm(x - y, ord=2)`` when ``self._normed()`` and
        ``x`` has at least one axis, otherwise the elementwise ``x - y``.
        When ``self._mode == "matrix"`` the value is expanded to a matrix:
        multiplied by an identity for ``matrix_mode == "norm_diag"``, placed
        on a diagonal otherwise.
    """
    # diffs[i, j] = particles[j] - particles[i]; shape N x N (x D).
    diffs = jnp.expand_dims(particles, axis=0) - jnp.expand_dims(particles, axis=1)
    if self._normed() and particles.ndim == 2:
        diffs = safe_norm(diffs, ord=2, axis=-1)  # N x N x D -> N x N
    # One row per ordered pair: N * N x k. The reshape always yields a 2-D
    # array, so the row-wise norm below always applies.
    diffs = jnp.reshape(diffs, (diffs.shape[0] * diffs.shape[1], -1))
    factor = self.bandwidth_factor(particles.shape[0])
    diff_norms = safe_norm(diffs, ord=2, axis=-1)  # N * N
    # The 1e-5 term keeps the bandwidth from collapsing to zero when the
    # median distance is (near) zero, so the division in ``kernel`` stays
    # defined (assuming a non-negative ``factor``).
    bandwidth = jnp.median(diff_norms) ** 2 * factor + 1e-5

    def kernel(x, y):
        # Collapse x - y to a norm only in normed modes and for non-scalar x.
        diff = safe_norm(x - y, ord=2) if self._normed() and x.ndim >= 1 else x - y
        kernel_res = jnp.exp(-(diff**2) / bandwidth)
        if self._mode == "matrix":
            if self.matrix_mode == "norm_diag":
                # kernel_res times an x.shape[0] x x.shape[0] identity.
                return kernel_res * jnp.identity(x.shape[0])
            # kernel_res placed on the diagonal of a matrix.
            return jnp.diag(kernel_res)
        return kernel_res

    return kernel
def compute(self, particles, particle_info, loss_fn):
    """Return a random-feature kernel closure built from the ``.init`` state.

    The bandwidth is ``d_med ** 2 * factor + 1e-5``, where ``d_med`` is the
    element at index ``len(d) // 2`` of the sorted pairwise distances ``d``
    (an actual sample, i.e. the upper median for an even count) and
    ``factor = self.bandwidth_factor(particles.shape[0])``. The factor is
    computed from the full particle count, before optional subsetting via
    ``self._bandwidth_subset_indices``.

    Args:
        particles: Particle array; must have the same shape as
            ``self._random_weights``.
        particle_info: Not used by this method.
        loss_fn: Not used by this method.

    Returns:
        ``kernel(x, y)`` returning the sum, over the (optionally
        ``self.random_indices``-selected) weight/bias pairs ``(w, b)``, of
        ``feature(x, w, b) * feature(y, w, b)`` with
        ``feature = sqrt(2) * cos((x @ w + b) / bandwidth)``.

    Raises:
        RuntimeError: If ``self._random_weights`` is ``None`` (``.init`` not
            called yet).
        ValueError: If ``particles.shape`` differs from the random weights'
            shape.
    """
    if self._random_weights is None:
        raise RuntimeError(
            "The `.init` method should be called first to initialize the"
            " random weights, biases and subset indices."
        )
    if particles.shape != self._random_weights.shape:
        raise ValueError(
            "Shapes of `particles` and the random weights are mismatched, got {}"
            " and {}.".format(particles.shape, self._random_weights.shape)
        )

    # Computed before subsetting, so it reflects the full particle count.
    factor = self.bandwidth_factor(particles.shape[0])
    if self.bandwidth_subset is not None:
        particles = particles[self._bandwidth_subset_indices]
    # diffs[i, j] = particles[j] - particles[i]; shape M x M (x D).
    diffs = jnp.expand_dims(particles, axis=0) - jnp.expand_dims(particles, axis=1)
    if particles.ndim == 2:
        diffs = safe_norm(diffs, ord=2, axis=-1)  # M x M x D -> M x M
    # One row per ordered pair; the reshape always yields a 2-D array.
    diffs = jnp.reshape(diffs, (diffs.shape[0] * diffs.shape[1], -1))
    diff_norms = safe_norm(diffs, ord=2, axis=-1)  # M * M
    # Take the selected distance from the same array argsort ranked, so the
    # bandwidth is a scalar regardless of the particles' rank.
    median_idx = jnp.argsort(diff_norms)[diff_norms.shape[0] // 2]
    bandwidth = diff_norms[median_idx] ** 2 * factor + 1e-5

    def feature(x, w, b):
        # Random cosine feature for a single weight/bias pair.
        return jnp.sqrt(2) * jnp.cos((x @ w + b) / bandwidth)

    def kernel(x, y):
        ws = (
            self._random_weights
            if self.random_indices is None
            else self._random_weights[self.random_indices]
        )
        bs = (
            self._random_biases
            if self.random_indices is None
            else self._random_biases[self.random_indices]
        )
        # Sum (not mean) of feature products over all selected pairs.
        return jnp.sum(
            jax.vmap(lambda w, b: feature(x, w, b) * feature(y, w, b))(ws, bs)
        )

    return kernel
def kernel(x, y):
    """Evaluate ``exp(-diff**2 / bandwidth)`` for the pair ``(x, y)``.

    ``diff`` is ``safe_norm(x - y, ord=2)`` when ``self._normed()`` holds and
    ``x`` has at least one axis; otherwise it is the elementwise ``x - y``.
    Outside ``"matrix"`` mode the value is returned as is. In ``"matrix"``
    mode it is multiplied by an ``x.shape[0]``-sized identity when
    ``self.matrix_mode == "norm_diag"``, and placed on a diagonal otherwise.
    ``self`` and ``bandwidth`` come from the enclosing scope.
    """
    collapse_to_norm = self._normed() and x.ndim >= 1
    delta = safe_norm(x - y, ord=2) if collapse_to_norm else x - y
    similarity = jnp.exp(-(delta**2) / bandwidth)
    if self._mode != "matrix":
        return similarity
    if self.matrix_mode == "norm_diag":
        return similarity * jnp.identity(x.shape[0])
    return jnp.diag(similarity)
def test_safe_norm(axis, ord):
    """Compare ``safe_norm`` with ``jnp.linalg.norm`` on tiny-magnitude input.

    ``axis`` and ``ord`` are supplied by the test parametrization (the name
    ``ord`` shadows the builtin but must match the parameter id). When
    ``axis`` is ``None`` and ``ord`` is given, the expected value adds
    ``1e-5**ord`` to every entry -- presumably mirroring an offset that
    ``safe_norm`` applies; verify against ``safe_norm``'s implementation.
    """
    m = np.array([[1.0e-5, 2e-5, 3e-5], [-1e-5, 1e-5, 0]])
    # Forward ``ord`` so the parametrized order is actually passed to
    # safe_norm; the ``None`` case keeps relying on safe_norm's default.
    norm_kwargs = {"axis": axis} if ord is None else {"axis": axis, "ord": ord}
    actual = safe_norm(m, **norm_kwargs)
    offset = 1e-5**ord if axis is None and ord is not None else 0.0
    expected = jnp.linalg.norm(m + offset, ord=ord, axis=axis)
    # NOTE(review): atol=1e-4 is larger than every entry of ``m`` (all
    # <= 3e-5 in magnitude), so this comparison is loose. Tightening it
    # likely requires reproducing safe_norm's zero-guard exactly -- confirm
    # before changing.
    assert_allclose(actual, expected, atol=1e-4)
def kernel(x, y):
    """Evaluate ``(self.const**2 + diff**2) ** self.expon`` for ``(x, y)``.

    ``diff`` is ``safe_norm(x - y, ord=2, axis=-1)`` when
    ``self._mode == "norm"``; in any other mode it is the elementwise
    difference ``x - y``. ``self`` comes from the enclosing scope.
    """
    delta = x - y
    if self._mode == "norm":
        delta = safe_norm(delta, ord=2, axis=-1)
    return (self.const**2 + delta**2) ** self.expon