def compute(self, particles, particle_info, loss_fn):
    """Build a random-feature kernel whose bandwidth follows a median heuristic.

    On the first call this samples and caches on ``self`` one weight per
    entry of ``particles`` (standard normal, via ``npr.randn``) and one phase
    per entry (uniform in [0, 2*pi), via ``npr.rand``).  Later calls reuse the
    cached arrays, even if ``particles`` changes shape — NOTE(review): confirm
    the particle shape is fixed for the lifetime of this object.

    The bandwidth is ``median_pair_distance**2 * self.bandwidth_factor(N) +
    1e-5``.  The median is taken over all N*N ordered pairs, diagonal
    included.  When ``self.bandwidth_subset`` is set, it is estimated on a
    random subset of the particles instead.

    Args:
      particles: particle array whose first axis indexes particles
        (N x D per the shape comments below).
      particle_info: unused by this method.
      loss_fn: unused by this method.

    Returns:
      ``kernel(x, y)``, summing the products of the random features of ``x``
      and ``y`` over the (optionally ``self.random_indices``-selected)
      weights.
    """
    if self._random_weights is None:
        # Lazily draw the random features once; they are reused afterwards.
        self._random_weights = jnp.array(npr.randn(*particles.shape))
        self._random_biases = jnp.array(
            npr.rand(*particles.shape) * 2 * np.pi)
    factor = self.bandwidth_factor(particles.shape[0])
    if self.bandwidth_subset is not None:
        num_particles = particles.shape[0]
        # Sample without replacement: duplicated rows would add spurious
        # zero distances and bias the median bandwidth downwards.
        subset_size = min(self.bandwidth_subset, num_particles)
        particles = particles[
            npr.choice(num_particles, subset_size, replace=False)]
    diffs = jnp.expand_dims(particles, axis=0) - jnp.expand_dims(
        particles, axis=1)  # N x N x D
    if particles.ndim == 2:
        diffs = safe_norm(diffs, ord=2, axis=-1)  # N x N x D -> N x N
    # Flatten all pairs into rows (N * N x 1 here); the result is always
    # 2-D, so a per-row norm is always well defined.
    diffs = jnp.reshape(diffs, (diffs.shape[0] * diffs.shape[1], -1))
    diff_norms = safe_norm(diffs, ord=2, axis=-1)
    median = jnp.argsort(diff_norms)[diffs.shape[0] // 2]
    bandwidth = jnp.abs(diffs)[median]**2 * factor + 1e-5

    def feature(x, w, b):
        # Random Fourier feature sqrt(2) * cos(w.x / h + b).  Only the
        # projection is scaled by the bandwidth: the phase b must remain
        # uniform over a full period so the cos(a + c + 2b) cross term
        # averages to zero.
        return jnp.sqrt(2) * jnp.cos((x @ w) / bandwidth + b)

    def kernel(x, y):
        # Optionally restrict to a subset of the cached random features.
        ws = (self._random_weights if self.random_indices is None
              else self._random_weights[self.random_indices])
        bs = (self._random_biases if self.random_indices is None
              else self._random_biases[self.random_indices])
        return jnp.sum(
            jax.vmap(lambda w, b: feature(x, w, b) * feature(y, w, b))(
                ws, bs))

    return kernel
def compute(self, particles, particle_info, loss_fn):
    """Build an RBF kernel ``exp(-d**2 / bandwidth)`` via a median heuristic.

    The bandwidth is ``|diff at median pair|**2 * self.bandwidth_factor(N)
    + 1e-5``.  The median pair is chosen over all N*N ordered particle pairs,
    diagonal included.  When ``self._normed()`` is true, pair differences are
    reduced to L2 norms and the bandwidth is a scalar.  Otherwise it keeps
    one entry per flattened feature.

    Args:
      particles: particle array whose first axis indexes particles
        (N x D per the shape comments below).
      particle_info: unused by this method.
      loss_fn: unused by this method.

    Returns:
      ``kernel(x, y)``.  In ``self._mode == 'matrix'`` it returns a matrix
      (identity-scaled for ``matrix_mode == 'norm_diag'``, otherwise
      diagonal); in any other mode it returns the raw kernel value(s).
    """
    diffs = jnp.expand_dims(particles, axis=0) - jnp.expand_dims(
        particles, axis=1)  # N x N (x D)
    if self._normed() and particles.ndim == 2:
        diffs = safe_norm(diffs, ord=2, axis=-1)  # N x N x D -> N x N
    # Flatten all pairs into rows: N * N x 1 when normed, otherwise
    # N * N x D for 2-D particles.  The result is always 2-D.
    diffs = jnp.reshape(diffs, (diffs.shape[0] * diffs.shape[1], -1))
    factor = self.bandwidth_factor(particles.shape[0])
    diff_norms = safe_norm(diffs, ord=2, axis=-1)
    median = jnp.argsort(diff_norms)[diffs.shape[0] // 2]
    bandwidth = jnp.abs(diffs)[median]**2 * factor + 1e-5
    if self._normed():
        # Normed diffs have a single column; collapse to a scalar.
        bandwidth = bandwidth[0]

    def kernel(x, y):
        diff = safe_norm(
            x - y, ord=2) if self._normed() and x.ndim >= 1 else x - y
        kernel_res = jnp.exp(-diff**2 / bandwidth)
        if self._mode == 'matrix':
            if self.matrix_mode == 'norm_diag':
                return kernel_res * jnp.identity(x.shape[0])
            else:
                # NOTE(review): if kernel_res is a scalar here (normed case),
                # jnp.diag will reject it — confirm this combination of
                # modes is never used.
                return jnp.diag(kernel_res)
        else:
            return kernel_res

    return kernel
def kernel(x, y):
    """Evaluate ``exp(-d**2 / bandwidth)`` between ``x`` and ``y``.

    ``d`` is ``safe_norm(x - y, ord=2)`` when ``self._normed()`` holds and
    ``x`` is at least 1-D; otherwise it is the raw difference ``x - y``.
    ``self`` and ``bandwidth`` are free variables here — presumably captured
    from an enclosing ``compute``; confirm against the caller.

    Returns the kernel value(s) directly unless ``self._mode == 'matrix'``.
    In that case it returns ``value * I`` for ``matrix_mode == 'norm_diag'``
    and ``jnp.diag(value)`` otherwise.
    """
    if self._normed() and x.ndim >= 1:
        distance = safe_norm(x - y, ord=2)
    else:
        distance = x - y
    similarity = jnp.exp(-distance**2 / bandwidth)
    if self._mode != 'matrix':
        return similarity
    if self.matrix_mode == 'norm_diag':
        return similarity * jnp.identity(x.shape[0])
    return jnp.diag(similarity)
def kernel(x, y):
    """Return ``(self.const**2 + d**2) ** self.expon``.

    ``d`` is the L2 norm of ``x - y`` over the last axis when
    ``self._mode == 'norm'``; otherwise it is the raw difference ``x - y``
    and the result is computed element-wise.
    """
    if self._mode == 'norm':
        distance = safe_norm(x - y, ord=2, axis=-1)
    else:
        distance = x - y
    offset = self.const**2
    return (offset + distance**2)**self.expon