def multinomial(rng, logits, num_samples):
    """Samples category indices by inverting the categorical CDF.

    The logits are turned into probabilities with a softmax, accumulated into
    a CDF along the last axis, and compared against uniform draws.

    Args:
      rng: A JAX PRNGKey.
      logits: Unnormalized log-probabilities, of shape ``[categories]`` or
        ``[batch_size, categories]``.
      num_samples: Number of samples to draw per distribution.

    Returns:
      Chosen categories, of shape ``[num_samples]`` for rank-1 ``logits`` or
      ``[batch_size, num_samples]`` for rank-2 ``logits``.

    Raises:
      ValueError: If ``logits`` is neither rank-1 nor rank-2.
    """
    # NOTE(tycai): Currently, tf.multinomial uses CDF for non-XLA CPU only.
    # We may want to switch to the Gumbel trick as used in XLA.
    rank = len(logits.shape)
    if rank < 1 or rank > 2:
        raise ValueError("Logits must be rank-1 or rank-2.")

    cdf = jnp.cumsum(jax.nn.softmax(logits), axis=-1)
    batch_dims = logits.shape[:-1]

    # Special-case num_samples == 1 due to TPU padding, as in TF2XLA.
    # https://github.com/tensorflow/tensorflow/blob/b1608511d5a50d05825c4025b0c347e8689a241f/tensorflow/compiler/tf2xla/kernels/categorical_op.cc#L79
    if num_samples == 1:
        draws = jax.random.uniform(rng, batch_dims + (1,))
        # `draws > cdf` is True for every category whose cumulative mass is
        # still below the draw; argmin returns the first False entry.
        picked = jnp.argmin(draws > cdf, axis=-1)
        return picked[..., None]

    draws = jax.random.uniform(rng, (num_samples,) + batch_dims + (1,))
    picked = jnp.argmin(draws > cdf, axis=-1)
    # Samples were drawn along the leading axis; move them to the trailing one.
    return jnp.transpose(picked)
def body(state: State):
    """Performs one reassignment step, moving at most one point to another cluster.

    Reads ``points``, ``mask``, ``N``, ``D``, ``cluster_dist_metric`` and
    ``update_metric_state`` from the enclosing scope.

    Args:
      state: Loop state exposing ``i``, ``done``, ``cluster_id`` and
        ``metric_state`` (with a per-cluster ``num_k`` array).

    Returns:
      The state with ``i`` incremented, the updated ``cluster_id`` and
      ``metric_state``, and ``done`` set when the assignment did not change.
    """
    # N, K: value of cluster_dist_metric for every point against every cluster.
    dist_to_clusters = vmap(
        lambda point: cluster_dist_metric(point, state.metric_state))(points)

    # N: value for the cluster each point currently belongs to.
    row_index = jnp.arange(N, dtype=jnp.int_)
    dist_to_own = vmap(lambda n, k: dist_to_clusters[n, k])(
        row_index, state.cluster_id)

    # N, K: negative entries mark clusters that score lower than the current one.
    gain = dist_to_clusters - dist_to_own[:, None]

    # N: best alternative per point.
    best_target = jnp.argmin(gain, axis=-1)
    best_gain = jnp.min(gain, axis=-1)

    # A point may only leave a cluster that currently has more than D + 1
    # members, and only if it is selected by `mask`.
    donor_ok = state.metric_state.num_k[state.cluster_id] > D + 1
    best_gain = jnp.where(mask & donor_ok, best_gain, jnp.inf)

    # Move only the single point with the lowest relative value.
    mover = jnp.argmin(best_gain)
    k_to = best_target[mover]
    new_cluster_id = dynamic_update_slice(
        state.cluster_id, k_to[None], mover[None])

    new_metric_state = update_metric_state(new_cluster_id)
    converged = jnp.all(new_cluster_id == state.cluster_id)
    return state._replace(i=state.i + 1,
                          done=converged,
                          cluster_id=new_cluster_id,
                          metric_state=new_metric_state)
def onnx_argmax(x, axis=0, keepdims=1, select_last_index=0):
    """Returns the indices of the maximum values of ``x`` along ``axis``.

    Args:
      x: Input array.
      axis: Axis along which the maximum is searched.
      keepdims: If truthy, the reduced axis is kept with size 1.
      select_last_index: If 0, ties resolve to the first occurrence of the
        maximum; otherwise to the last occurrence.

    Returns:
      Integer array of indices into ``axis``.
    """
    if select_last_index == 0:
        y = jnp.argmax(x, axis=axis)
    else:
        # The first maximum of the reversed axis is the last maximum of the
        # original one; map the reversed position back to an original index.
        reversed_x = jnp.flip(x, axis)
        y = x.shape[axis] - jnp.argmax(reversed_x, axis=axis) - 1
    if keepdims:
        y = jnp.expand_dims(y, axis)
    return y
def test_pair_correlation_species(self, dtype):
    """Checks the per-species pair correlation peaks and output dtype.

    The two species-0 particles are a distance 1 apart and the two species-1
    particles a distance 2 apart, so each averaged curve is expected to peak
    at the radius bin closest to that distance.
    """
    def displacement(Ra, Rb, **kwargs):
        return Ra - Rb

    positions = np.array([[1, 0], [0, 0], [10, 1], [10, 3]], dtype=dtype)
    species = np.array([0, 0, 1, 1])
    radii = np.linspace(0, 2, 60, dtype=dtype)

    g_fn = quantity.pair_correlation(displacement, radii, f32(0.1), species)
    g_0, g_1 = g_fn(positions)
    # Average over the leading axis before locating the peak.
    g_0 = np.mean(g_0, axis=0)
    g_1 = np.mean(g_1, axis=0)

    peak_for_0 = np.argmin((radii - 1.)**2)
    peak_for_1 = np.argmin((radii - 2.)**2)
    self.assertAllClose(np.argmax(g_0), peak_for_0)
    self.assertAllClose(np.argmax(g_1), peak_for_1)

    assert g_0.dtype == dtype
    assert g_1.dtype == dtype
def fit(self, X, y=None):
    """Runs ``self.n_iter`` rounds of assignment followed by center adjustment.

    Args:
      X: Data passed to ``initialize_centers``, ``dist_fun`` and
        ``adjust_centers``.
      y: Unused.

    Returns:
      ``self.clusters``: for each row of the distance array, the index of the
      smallest entry along the last axis, as computed in the final iteration.
    """
    self.initialize_centers(X)
    progress = tqdm(range(self.n_iter))
    for _ in progress:
        # Despite the name, `inertia_` holds the full output of `dist_fun`.
        distances = self.dist_fun(X, self.centers)
        self.inertia_ = distances
        self.clusters = jnp.argmin(distances, axis=-1)
        self.adjust_centers(X)
    return self.clusters
def _iter_body(state):
    """Runs one moving-average k-means update on the loop state.

    Reads ``data``, ``prev_counts``, ``eps``, ``decay``, ``counts_decay``,
    ``dead_threshold`` and ``_dead_centroid_fix`` from the enclosing scope.

    Args:
      state: Tuple ``(i, centroids, counts, key)``.

    Returns:
      Tuple ``(i + 1, centroids, counts, key)`` where ``centroids`` and
      ``counts`` are the decayed averages, after ``_dead_centroid_fix`` has
      been applied when the smallest count falls below ``dead_threshold``.
    """
    i, centroids, counts, key = state

    # Squared Euclidean distances via ||x||^2 + ||c||^2 - 2 <x, c>.
    centroids_norm = jnp.sum(centroids**2, axis=-1, keepdims=True)  # K x 1
    data_norm = jnp.sum(data**2, axis=-1, keepdims=True)  # N x 1
    dot_product = jnp.matmul(data, jnp.transpose(centroids))  # N x K
    distances = (data_norm + jnp.transpose(centroids_norm)
                 - 2 * dot_product)  # N x K

    labels = jnp.argmin(distances, axis=1)
    one_hot = jax.nn.one_hot(labels, prev_counts.shape[0])  # N x K

    # Per-centroid mean of the assigned points; eps guards empty clusters.
    dw = jnp.matmul(jnp.transpose(one_hot), data)  # K x dim
    count = jnp.expand_dims(jnp.sum(one_hot, axis=0), axis=-1)  # K x 1
    dw = dw / (count + eps)
    centroids = decay * centroids + (1 - decay) * dw

    key, sub_key = jax.random.split(key)
    counts = counts_decay * counts + (1 - counts_decay) * count
    counts, centroids = jax.lax.cond(
        pred=jnp.min(counts) < dead_threshold,
        true_fun=_dead_centroid_fix,
        false_fun=lambda operand: (operand[0], operand[1]),
        operand=(counts, centroids, sub_key))

    # Carry the decayed (and possibly repaired) counts forward, not the raw
    # per-step `count`, so that the moving average is not discarded.
    return (i + 1, centroids, counts, key)
def _compute_assignments(numerical_points, categorical_points,
                         numerical_prototypes, categorical_prototypes,
                         norm_ord, gamma, nan_friendly):
    """Assigns each point to its closest prototype and returns the total cost.

    The dissimilarity is the value of ``compute_kmeans_distance`` on the
    numerical part plus ``gamma`` times the number of categorical attributes
    that differ from the prototype.

    Args:
      numerical_points: Numerical attributes of the points.
      categorical_points: Categorical attributes of the points.
      numerical_prototypes: Numerical attributes of the prototypes.
      categorical_prototypes: Categorical attributes of the prototypes.
      norm_ord: Norm order forwarded to the k-means distance and cost helpers.
      gamma: Weight of the categorical term.
      nan_friendly: If truthy, both point arrays are first passed through
        ``_fill_nans`` together with the prototypes.

    Returns:
      Tuple ``(assignment, cost)``: the index of the chosen prototype for each
      point and the scalar cost of that assignment.
    """
    if nan_friendly:
        numerical_points, categorical_points = _fill_nans(
            numerical_points, numerical_prototypes,
            categorical_points, categorical_prototypes)

    numerical_dist = compute_kmeans_distance(
        numerical_points, numerical_prototypes, norm_ord)
    # For every point, count the mismatching attributes against each prototype.
    categorical_dist = jax.vmap(
        lambda point: jax.vmap(jnp.sum)(categorical_prototypes != point)
    )(categorical_points)

    dist = numerical_dist + gamma * categorical_dist
    assignment = jnp.argmin(dist, axis=1)

    numerical_cost = compute_kmeans_cost(
        numerical_points, numerical_prototypes, assignment, norm_ord)
    # The cost must use the same mismatch count as the assignment distance;
    # counting matches would make better assignments look more expensive.
    categorical_cost = jnp.sum(
        categorical_prototypes[assignment, :] != categorical_points)
    cost = numerical_cost + gamma * categorical_cost  # Review paper gamma
    return assignment, cost
def update(state, inp):
    """Advances the scalar Kalman-style phase tracker by one sample.

    ``Q``, ``R``, ``beta`` and ``const`` are read from the enclosing scope.

    Args:
      state: Tuple ``(Psi_c, P_c, Psi_a)`` holding the current phase estimate,
        its variance and the moving average of the phase estimate.
      inp: Tuple ``(y, x, train)``: received sample, reference symbol and a
        flag selecting the reference symbol over the decided one.

    Returns:
      Tuple ``(state, out)`` with the updated state and
      ``out = (Psi_c, d)``, where ``Psi_c`` is the estimate from before this
      update and ``d`` is the symbol used for it.
    """
    Psi_c, P_c, Psi_a = state
    y, x, train = inp

    # Prediction: the phase is carried over and its variance grows by Q.
    Psi_p = Psi_c
    P_p = P_c + Q

    # The averaged phase de-rotates the sample before the symbol decision.
    Psi_a = beta * Psi_a + (1 - beta) * Psi_c
    nearest = const[jnp.argmin(jnp.abs(const - y * jnp.exp(-1j * Psi_a)))]
    d = jnp.where(train, x, nearest)

    H = 1j * d * jnp.exp(1j * Psi_p)
    K = P_p * H.conj() / (H * P_p * H.conj() + R)
    v = y - d * jnp.exp(1j * Psi_p)

    # The emitted phase is the one from before the correction below.
    out = (Psi_c, d)

    Psi_c = Psi_p + K * v
    P_c = (1. - K * H) * P_p
    return (Psi_c, P_c, Psi_a), out
def update(i, state, inp):
    """Advances the phase tracker by one sample, optionally adapting Q and R.

    ``beta``, ``alpha``, ``const``, ``train`` and ``akf`` are read from the
    enclosing scope.

    Args:
      i: Step index, forwarded to ``train``.
      state: Tuple ``(Psi_c, P_c, Psi_a, Q, R)``.
      inp: Tuple ``(y, x)``: received sample and reference symbol.

    Returns:
      Tuple ``(state, out)`` with the updated state and
      ``out = (Psi_c, (Q, R))`` holding the values from before this update.
    """
    Psi_c, P_c, Psi_a, Q, R = state
    y, x = inp

    Psi_p = Psi_c
    P_p = P_c + Q

    # exponential moving average
    Psi_a = beta * Psi_a + (1 - beta) * Psi_c

    nearest = const[jnp.argmin(jnp.abs(const - y * jnp.exp(-1j * Psi_a)))]
    d = jnp.where(train(i), x, nearest)

    H = 1j * d * jnp.exp(1j * Psi_p)
    K = P_p * H.conj() / (H * P_p * H.conj() + R)
    v = y - d * jnp.exp(1j * Psi_p)

    out = (Psi_c, (Q, R))

    Psi_c = Psi_p + K * v
    P_c = (1. - K * H) * P_p
    e = y - d * jnp.exp(1j * Psi_c)

    if akf:
        # Blend the previous Q and R with estimates built from the
        # innovation v and the post-update residual e.
        Q = alpha * Q + (1 - alpha) * K * v * v.conj() * K.conj()
        R = alpha * R + (1 - alpha) * (e * e.conj() + H * P_p * H.conj())

    return (Psi_c, P_c, Psi_a, Q, R), out
def sample_best_permutation(key, coupling, cost, num_trials=10):
    """Samples several permutation matrices and keeps the cheapest one.

    See the **Convex Relaxations for Permutation Problems** paper for a rough
    explanation of the algorithm.

    Args:
      key: jnp.ndarray that functions as a PRNG key.
      coupling: jnp.ndarray of shape [N, N].
      cost: jnp.ndarray of shape [N, N].
      num_trials: int, determines how many permutations are sampled.

    Returns:
      permutation matrix: jnp.ndarray of shape [N, N] of floating point type;
      the sampled permutation with the lowest transport cost.
    """
    # One independent key per trial.
    trial_keys = jax.random.split(key, num_trials)
    batched_sampler = jax.vmap(
        sample_permutation, in_axes=(0, None), out_axes=0)
    candidates = batched_sampler(trial_keys, coupling)

    # Transport cost of a candidate is the sum of its elementwise product
    # with the cost matrix.
    transport_costs = jnp.sum(
        candidates * cost[jnp.newaxis, :, :], axis=(1, 2))
    return candidates[jnp.argmin(transport_costs)]
def choose_representer_from_gram(G, factors):
    """Returns the index whose Gram column is closest to a weighted combination.

    For every index ``j`` the quantity
    ``factors @ G @ factors + G[j, j] - 2 * (factors @ G)[j]`` is evaluated
    and the index of the smallest value is returned.

    Args:
      G: Square Gram matrix.
      factors: 1-d weight vector with one entry per row of ``G``.

    Returns:
      Index of the minimising entry.
    """
    weighted = np.dot(factors, G)
    self_term = np.dot(factors, weighted).flatten()
    point_term = np.diag(G)
    cross_term = 2 * weighted
    rkhs_distances_sq = np.squeeze(self_term + point_term - cross_term)

    best = np.argmin(rkhs_distances_sq)
    assert best < rkhs_distances_sq.size
    return best
def step(states, r):
    """Runs one predict/correct step of the phase filter with adaptive Q.

    ``A`` and ``const`` are read from the enclosing scope.

    Args:
      states: Tuple ``(Q, R, x_c, P_c)``.
      r: Received sample.

    Returns:
      Tuple ``((Q, R, x_c, P_c), x_p[:, 0])``: the updated filter state and
      the predicted state vector of this step.
    """
    Q, R, x_c, P_c = states

    # Prediction.
    x_p = A @ x_c
    P_p = A @ P_c @ A.T + Q

    # Decide the symbol after removing the predicted phase, then rebuild the
    # expected sample from it.
    phasor = jnp.exp(1j * x_p[0, 0])
    decided = const[jnp.argmin(jnp.abs(const - r * phasor.conj()))]
    r_hat_p = decided * phasor
    resid = r - r_hat_p

    # Real/imaginary parts are stacked as a 2-vector.
    H = jnp.array([[-r_hat_p.imag, 0],
                   [r_hat_p.real, 0]])
    innovation = jnp.array([[resid.real],
                            [resid.imag]])

    S = H @ P_p @ H.T + R
    K = P_p @ H.T @ jnp.linalg.inv(S)
    x_c = x_p + K @ innovation
    P_c = P_p - K @ H @ P_p

    # Adapt Q only; R is passed through unchanged.
    beta = .99
    Q = beta * Q + (1. - beta) * K @ innovation @ innovation.T @ K.T

    return (Q, R, x_c, P_c), x_p[:, 0]
def fprop(self,
          input_ids: JTensor,
          input_embs: JTensor,
          paddings: Optional[JTensor] = None,
          segment_pos: Optional[JTensor] = None) -> JTensor:
    """Adds VQ ngram layer embeddings to the input embeddings.

    Args:
      input_ids: Input unigram id tensor of shape [B, L] or [B, L, N]. Unused;
        accepted only for consistency with the Ngrammer API.
      input_embs: Input unigram embedding tensor of shape [B, L, D] to which
        the ngram embedding is added.
      paddings: If not None, a tensor of shape [B, L] corresponding to padding.
      segment_pos: If not None, a tensor of shape [B, L] corresponding to the
        position of an id in a packed sequence.

    Returns:
      outputs: Input embedding with the VQ ngram added, of shape [B, L, D].
    """
    del input_ids  # Unused.

    embs = self._cast_to_fprop_dtype(input_embs)

    # Distances of shape [B, L, N, K]; the second output is not needed.
    distances, _ = self.vq_layer.fprop(embs, paddings=paddings)

    # [B, L, N]: index of the smallest distance per entry.
    cluster_ids = jnp.argmin(distances, -1)

    # [B, L, D].
    return self.ngram_layer.fprop(cluster_ids, embs, paddings, segment_pos)
def body_fn(i, vals):
    """Runs one continuous sub-trajectory followed by one discrete update.

    ``time_unit``, ``update_continuous`` and ``update_discrete`` are read from
    the enclosing scope.

    Args:
      i: Loop index (unused).
      vals: Tuple ``(rng_key, hmc_state, z_discrete, ke_discrete,
        delta_pe_sum, arrival_times)``.

    Returns:
      The same six-element tuple after the update.
    """
    (rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum,
     arrival_times) = vals

    # The entry with the earliest arrival time is the one updated this round.
    idx = jnp.argmin(arrival_times)
    next_arrival = arrival_times[idx]

    # NB: length of each sub-trajectory is scaled from the current
    # min(arrival_times) (see the note at total_time below)
    trajectory_length = next_arrival * time_unit

    # Shift all arrival times by the elapsed amount and reset the chosen one.
    arrival_times = ops.index_update(arrival_times - next_arrival, idx, 1.0)

    # this is a trick, so that in a sub-trajectory of HMC, we always accept
    # the new proposal
    hmc_state = hmc_state._replace(trajectory_length=trajectory_length,
                                   potential_energy=jnp.inf)

    # Algo 1, line 7: perform a sub-trajectory
    hmc_state = update_continuous(hmc_state, z_discrete)

    # Algo 1, line 8: perform a discrete update
    rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum = update_discrete(
        idx, rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum)

    return (rng_key, hmc_state, z_discrete, ke_discrete, delta_pe_sum,
            arrival_times)
def __call__(
    self,
    loss_fn: utils.LossFn,
    rng: chex.PRNGKey,
    inputs: chex.Array,
) -> chex.Array:
    """Performs an optimization multiple times by tiling the inputs.

    Args:
      loss_fn: Loss evaluated on the optimized inputs.
      rng: PRNG key forwarded to the wrapped optimizer.
      inputs: Inputs to optimize; the leading axis is the batch axis when
        ``self._has_batch_dim`` is set.

    Returns:
      With a batch dimension: for every batch element, the optimized input
      from the restart with the lowest loss. Without a batch dimension: the
      tuple ``(opt_inputs, opt_losses)`` from a single optimizer run.
    """
    # NOTE(review): this branch returns a tuple while the batched path returns
    # a single array; kept as-is because callers may rely on it.
    if not self._has_batch_dim:
        single_run = self._wrapped_optimizer(loss_fn, rng, inputs)
        return single_run, loss_fn(single_run)

    num_restarts = self._restarts_using_tiling
    batch_size = inputs.shape[0]

    # Repeat the batch along its leading axis only.
    # Shape is [num_restarts * batch_size, ...].
    repeats = [num_restarts] + [1] * len(inputs.shape[1:])
    tiled_inputs = jnp.tile(inputs, repeats)

    # Optimize every copy, then regroup the losses as [restart, batch].
    opt_inputs = self._wrapped_optimizer(loss_fn, rng, tiled_inputs)
    opt_losses = jnp.reshape(loss_fn(opt_inputs), [num_restarts, batch_size])

    # Extract best.
    best_restart = jnp.argmin(opt_losses, axis=0)
    batch_index = jnp.arange(batch_size)
    per_restart = jnp.reshape(
        opt_inputs, (num_restarts, batch_size) + opt_inputs.shape[1:])
    return per_restart[best_restart, batch_index]
def apply(self, data):
    """Assigns each galaxy to the bin of its nearest centroid.

    Parameters:
    -----------
    data: mapping from column name to a 1-d array with one entry per galaxy.
      The columns read are "r", "ri" and "rz", plus "gr" when
      ``self.bands == "griz"``.

    Returns:
    tomographic_selections: integer array of size Ngalaxies holding, for each
      galaxy, the index of the closest entry of ``self.centroids``.
    """
    if self.bands == "griz":
        columns = ["r", "gr", "ri", "rz"]
    else:
        columns = ["r", "ri", "rz"]

    # One row per galaxy, one column per selected feature.
    features = np.asarray([data[name] for name in columns]).T
    projected = features @ self.eigs

    # Distance of every projected point to every centroid; stacking gives
    # shape [n_centroids, n_galaxies].
    distances = jnp.asarray([
        jnp.linalg.norm(projected - center, axis=1)
        for center in self.centroids
    ])

    # Which category these would be assigned to based on their distances.
    return jnp.argmin(distances, axis=0)
def measure_cd(x, sr, start=-0.25, end=0.25, bins=2000, wavlen=1550e-9):
    '''
    Blind estimate of the accumulated chromatic dispersion of ``x``.

    A metric is evaluated with ``xop.frft`` for each of ``bins`` orders swept
    linearly from ``start`` to ``end``; the dispersion value belonging to the
    order with the smallest metric is returned.

    Args:
      x: input signal; its leading dimension is used as the length N.
      sr: sample rate.
      start: first swept order.
      end: last swept order.
      bins: number of swept orders.
      wavlen: wavelength used to convert to a dispersion value.

    Returns:
      Tuple ``(Dz_hat, L, Dz_set)``: the estimate, the metric for every swept
      order, and the dispersion value associated with every swept order.

    References:
      Zhou, H., Li, B., Tang, et. al, 2016. Fractional fourier
      transformation-based blind chromatic dispersion estimation for coherent
      optical communications. Journal of Lightwave Technology, 34(10),
      pp.2371-2380.
    '''
    c = 299792458.
    p = jnp.linspace(start, end, bins)
    N = x.shape[0]

    def f(_, pi):
        return None, jnp.sum(
            jnp.abs(xop.frft(jnp.abs(xop.frft(x, pi))**2, -1))**2)

    # Use `scan` instead of `vmap` here to avoid a potentially large memory
    # allocation. `scan` scales well to large `bins`, but its run time has a
    # lower bound (e.g. 600ms at bins=1), possibly related to how `frft` was
    # ported.
    # TODO review `frft`
    _, L = xop.scan(f, None, p)

    B2z = jnp.tan(jnp.pi/2 - (p - 1) / 2 * jnp.pi)/(sr * 2 * jnp.pi / N * sr)
    Dz_set = -B2z / wavlen**2 * 2 * jnp.pi * c  # the swept set of CD metrics
    Dz_hat = Dz_set[jnp.argmin(L)]  # estimated accumulated CD
    return Dz_hat, L, Dz_set
def interp(x, xp, fp):
    """
    Simple equivalent of np.interp that computes a linear interpolation.

    The segment is chosen from the grid node nearest to ``x`` and the side of
    that node on which ``x`` lies. No checks are performed: queries outside
    ``[xp[0], xp[-1]]`` are extrapolated along the first or last segment
    instead of being clamped as np.interp would do.

    TODO: Implement proper interpolation!

    Args:
      x: scalar query point.
      xp: 1d array of increasing sample positions.
      fp: 1d array of sample values, same length as ``xp``.

    Returns:
      The interpolated value at ``x``.
    """
    # First we find the nearest neighbour, kept away from the two end nodes
    # so that both ind - 1 and ind + 1 are valid indices.
    ind = np.argmin((x - xp) ** 2)
    ind = np.clip(ind, 1, len(xp) - 2)
    xi = xp[ind]

    # Figure out if we are on the right or the left of the nearest node. The
    # comparison uses the unclipped query so that points in the first interval
    # pick the segment on their left; the offset is an integer because it is
    # used for indexing. A query exactly on the node uses the right segment.
    step = np.where(x < xi, -1, 1)
    neighbour = ind + step

    # Perform linear interpolation on the selected segment.
    a = (fp[neighbour] - fp[ind]) / (xp[neighbour] - xp[ind])
    b = fp[ind] - a * xp[ind]
    return a * x + b
def update(state, inp):
    """Performs one normalised-gradient update of ``w``, ``f`` and ``s``.

    ``mimo``, ``const``, ``eps``, ``grad_max``, ``mu_w``, ``mu_f`` and
    ``mu_s`` are read from the enclosing scope.

    Args:
      state: Tuple ``(w, f, s)``.
      inp: Tuple ``(u, x, train)``: input block, reference symbols and a flag
        selecting the reference symbols over the decided ones.

    Returns:
      Tuple ``(state, out)`` with the updated ``(w, f, s)`` and
      ``out = (w, f, s, d)`` holding the values from before the update
      together with the symbols ``d`` used for the errors.
    """
    w, f, s = state
    u, x, train = inp

    v = mimo(w, u)
    z = v * f * s

    # Nearest constellation point for every entry of z.
    nearest = const[jnp.argmin(jnp.abs(const[:, None] - z[None, :]), axis=0)]
    d = jnp.where(train, x, nearest)

    # Unit-modulus product of the phases of f and s.
    psi_hat = jnp.abs(f)/f * jnp.abs(s)/s
    e_p = d * psi_hat - v
    e_f = d - f * v
    e_s = d - s * f * v

    gw = -e_p[:, None, None] * u.conj().T[None, ...]
    gf = -1. / (jnp.abs(v)**2 + eps) * e_f * v.conj()
    gs = -1. / (jnp.abs(f * v)**2 + eps) * e_s * (f * v).conj()

    # clip the grads of f and s which are less regulated than w,
    # it may stabilize this algo. in some corner cases?
    gf = jnp.where(jnp.abs(gf) > grad_max[0],
                   gf / jnp.abs(gf) * grad_max[0], gf)
    gs = jnp.where(jnp.abs(gs) > grad_max[1],
                   gs / jnp.abs(gs) * grad_max[1], gs)

    out = (w, f, s, d)

    # update
    new_state = (w - mu_w * gw, f - mu_f * gf, s - mu_s * gs)
    return new_state, out
def step_cpane_ekf(params, inputs):
    """Advances the scalar Kalman-style phase tracker by one sample.

    Args:
      params: Tuple ``(Q, R, Psi_c, P_c, Psi_a, beta, const)``: process and
        measurement noise, current phase estimate and its variance, moving
        average of the phase estimate, averaging factor and constellation.
      inputs: Tuple ``(r, x, train)``: received sample, reference symbol and
        a boolean selecting data-aided mode.

    Returns:
      Tuple ``(params, outputs)`` with the updated parameter tuple and
      ``outputs = (Psi_c, d)``, where ``Psi_c`` is the estimate from before
      this update and ``d`` is the symbol used for it.
    """
    Q, R, Psi_c, P_c, Psi_a, beta, const = params
    r, x, train = inputs

    Psi_p = Psi_c
    P_p = P_c + Q

    Psi_a = beta * Psi_a + (1 - beta) * Psi_c

    # `cond(pred, true_fun, false_fun, operand)`: both branches close over
    # the values they need, so the operand is an unused placeholder.
    d = lax.cond(
        train,
        lambda _: x,  # data-aided mode
        lambda _: const[jnp.argmin(
            jnp.abs(const - r * jnp.exp(-1j * Psi_a)))],  # decision directed mode
        None)

    H = 1j * d * jnp.exp(1j * Psi_p)
    K = P_p * H.conj() / (H * P_p * H.conj() + R)
    v = r - d * jnp.exp(1j * Psi_p)

    outputs = (Psi_c, d)  # return averaged decision results

    Psi_c = Psi_p + K * v
    P_c = (1. - K * H) * P_p

    params = (Q, R, Psi_c, P_c, Psi_a, beta, const)
    return params, outputs
def update(i, state, inp):
    """Advances the frame-wise phase/slope tracker by one frame.

    The phase over a frame is modelled as ``z[0] + n * z[1]`` with ``n``
    centred on the frame. ``const``, ``R``, ``alpha``, ``train`` and ``akf``
    are read from the enclosing scope.

    Args:
      i: Step index, forwarded to ``train`` and ``akf``.
      state: Tuple ``(z_c, P_c, Q)``.
      inp: Tuple ``(y, x)``: received frame and reference symbols.

    Returns:
      Tuple ``(state, out)`` with the updated state and
      ``out = (z_p[1, 0], phi_p)``: the predicted slope and the predicted
      per-sample phase of this frame.
    """
    z_c, P_c, Q = state
    y, x = inp

    N = y.shape[0]  # frame size
    A = jnp.array([[1, N], [0, 1]])
    I = jnp.eye(2)
    n = jnp.arange(N) - (N - 1) / 2

    # Prediction.
    z_p = A @ z_c
    P_p = A @ P_c @ A.T + Q
    phi_p = z_p[0, 0] + n * z_p[1, 0]  # linear approx.

    # De-rotate with the predicted phase and pick the symbols.
    s_p = y * jnp.exp(-1j * phi_p)
    nearest = const[jnp.argmin(
        jnp.abs(const[None, :] - s_p[:, None]), axis=-1)]
    d = jnp.where(train(i), x, nearest)

    scd_p = s_p * d.conj()
    sumscd_p = jnp.sum(scd_p)
    offset_term = jnp.arctan(sumscd_p.imag / sumscd_p.real)
    slope_term = (jnp.sum(n * scd_p)).imag / (jnp.sum(n * n * scd_p)).real
    e = jnp.array([[offset_term], [slope_term]])

    # Correction.
    G = P_p @ jnp.linalg.pinv(P_p + R)
    z_c = z_p + G @ e
    P_c = (I - G) @ P_p

    # NOTE(review): the last factor is `G`, not `G.T` — confirm this is
    # intended.
    Q_adapted = alpha * Q + (1 - alpha) * (G @ e @ e.T @ G)
    Q = jnp.where(akf(i), Q_adapted, Q)

    out = (z_p[1, 0], phi_p)
    return (z_c, P_c, Q), out
def decision(const, v, stopgrad=True):
    """Maps every entry of ``v`` to its nearest constellation point.

    Args:
      const: 1-d array of constellation points.
      v: 1-d array of samples.
      stopgrad: If true, the result is wrapped in ``stop_gradient``.

    Returns:
      Array with the same length as ``v`` holding, for each sample, the entry
      of ``const`` at the smallest absolute distance.

    Raises:
      ValueError: If ``v`` has more than one dimension.
    """
    if v.ndim > 1:
        raise ValueError(f'ndim = 1 is expected, but got {v.ndim} instead')

    # Rows index constellation points, columns index samples.
    distances = jnp.abs(const[:, None] - v[None, :])
    nearest = jnp.argmin(distances, axis=0)
    symbols = const[nearest]

    if stopgrad:
        return stop_gradient(symbols)
    return symbols
def ddqnBestAction(Q1, cos_th, sin_th, thdot):
    """Returns the entry of ``U_opts`` with the lowest estimated value.

    ``state_action_template``, ``jnn`` and ``U_opts`` are read from the
    enclosing scope.

    Args:
      Q1: Parameters passed to ``jnn.predict``.
      cos_th: Cosine of the angle.
      sin_th: Sine of the angle.
      thdot: Angular velocity.

    Returns:
      The element of ``U_opts`` at the index of the smallest estimate.
    """
    # Column vector [cos_th, sin_th, thdot, 1]^T, broadcast against the
    # template.
    features = jnp.array([[cos_th, sin_th, thdot, 1]]).T
    s_a = state_action_template * features
    val_ests = jnn.predict(Q1, s_a)

    # NOTE: we use argMIN here since everything is framed as cost not reward.
    # Note also that this does not need to be clipped; all the opts are in
    # the right range.
    best = jnp.argmin(val_ests)
    return U_opts[best]
def predict(self, X):
    """Assigns every row of ``X`` to the cluster at the smallest cosine distance.

    The result is also stored, converted with ``onp.array``, in
    ``self.labels_``.

    Args:
      X: Array-like of points; converted to ``self._dtype``.

    Returns:
      Array with the index of the chosen cluster for every point.
    """
    points = jnp.array(X, dtype=self._dtype)
    similarity = cosine_similarity(points, self.clusters, norm_axis=1)
    # Distance is one minus similarity, so the minimum is the most similar.
    nearest = jnp.argmin(1 - similarity, axis=1)
    self.labels_ = onp.array(nearest)
    return nearest
def loss_fn(w, u):
    """Sums the absolute deviation of ``|v|**2`` from a target squared radius.

    ``mimo``, ``train``, ``Rx`` and ``Rs`` are read from the enclosing scope.
    The target is ``Rx**2`` when ``train`` is set; otherwise it is the square
    of the entry of ``Rs`` whose circle lies closest to each output sample.

    Args:
      w: Parameters passed to ``mimo``.
      u: Input passed to ``mimo``.

    Returns:
      Scalar loss.
    """
    v = mimo(w, u)[None, :]

    # Candidate points: v rescaled to each radius in Rs.
    on_rings = Rs[:, None] * v / jnp.abs(v)
    nearest_radius = Rs[jnp.argmin(jnp.abs(on_rings - v), axis=0)]

    R2 = jnp.where(train, Rx**2, nearest_radius**2)
    return jnp.sum(jnp.abs(R2 - jnp.abs(v[0, :])**2))
def material_at(self, point: Point) -> Material:
    """Returns the material of the object with the smallest SDF value at ``point``.

    Args:
      point: Query point passed to every object's ``material_at`` and ``sdf``.

    Returns:
      The material of the selected object, with every leaf picked from the
      stacked per-object materials.
    """
    # Stack the leaves of all per-object materials into one batched material.
    # `jax.tree_util.tree_map` accepts several trees, replacing the removed
    # `jax.tree_multimap`.
    per_object = [obj.material_at(point) for obj in self.objs]
    materials = jax.tree_util.tree_map(
        lambda *xs: jnp.stack(xs), *per_object)

    idx = jnp.argmin(jnp.array([obj.sdf(point) for obj in self.objs]))

    # Select a single material from the batch.
    return jax.tree_util.tree_map(lambda x: x[idx], materials)
def test_pair_correlation(self, dtype):
    """Checks the pair correlation peak location and output dtype.

    Two of the three particle pairs are a distance 1 apart, so the averaged
    curve is expected to peak at the radius bin closest to 1.
    """
    def displacement(Ra, Rb, **kwargs):
        return Ra - Rb

    positions = np.array([[1, 0], [0, 0], [0, 1]], dtype=dtype)
    radii = np.linspace(0, 2, 60, dtype=dtype)

    g_fn = quantity.pair_correlation(displacement, radii, f32(0.1))
    # Average over the leading axis before locating the peak.
    gs = np.mean(g_fn(positions), axis=0)

    expected_peak = np.argmin((radii - 1.)**2)
    assert np.argmax(gs) == expected_peak
    assert gs.dtype == dtype
def loss_fn(w, u, x, i):
    """Sums the absolute deviation of ``|v|**2`` from a target squared radius.

    ``mimo``, ``r2c``, ``train`` and ``Rs`` are read from the enclosing scope.
    The target is ``|x|**2`` when ``train(i)`` is set; otherwise it is the
    square of the entry of ``Rs`` whose circle lies closest to each output
    sample.

    Args:
      w: Parameters passed to ``mimo``.
      u: Input passed to ``mimo``.
      x: Reference symbols.
      i: Step index, forwarded to ``train``.

    Returns:
      Scalar loss.
    """
    v = r2c(mimo(w, u)[None, :])

    # Candidate points: v rescaled to each radius in Rs.
    on_rings = Rs[:, None] * v / jnp.abs(v)
    nearest_radius = Rs[jnp.argmin(jnp.abs(on_rings - v), axis=0)]

    R2 = jnp.where(train(i), jnp.abs(x)**2, nearest_radius**2)
    return jnp.sum(jnp.abs(R2 - jnp.abs(v[0, :])**2))
def linesearch(cell: PVCell, bound: Boundary, pot: Potentials,
               p: Array) -> Array:
    """Picks the step size with the smallest residual norm from a fixed grid.

    ``n_lnsrch`` step sizes are spaced evenly on [0, 2] and ``residnorm`` is
    evaluated for each. The first ``n_lnsrch // 10`` of them (the smallest
    steps, including zero) are excluded from the selection. ``n_lnsrch`` and
    ``residnorm`` are read from the enclosing scope.

    Args:
      cell: Passed through to ``residnorm``.
      bound: Passed through to ``residnorm``.
      pot: Passed through to ``residnorm``.
      p: Passed through to ``residnorm``.

    Returns:
      The selected step size.
    """
    skip = n_lnsrch // 10
    alphas = jnp.linspace(0, 2, n_lnsrch)

    # Only the step size is batched; the other arguments are shared.
    batched_residnorm = vmap(residnorm, (None, None, None, None, 0))
    R = batched_residnorm(cell, bound, pot, p, alphas)

    candidates = alphas[skip:]
    return candidates[jnp.argmin(R[skip:])]
def predict(self, X):
    """Assigns every row of ``X`` to the cluster with the highest cosine similarity.

    Args:
      X: 2-d array with one point per row.

    Returns:
      Array with the index into ``self.clusters`` of the most similar cluster
      for every point.
    """
    # Use this instance's centers rather than a module-level object.
    clusters = self.clusters
    clusters_norm = jnp.linalg.norm(clusters, ord=2, axis=1)[:, jnp.newaxis]
    points_norm = jnp.linalg.norm(X, ord=2, axis=1)

    # Rows index clusters, columns index points.
    similarity = (clusters @ X.T) / (clusters_norm * points_norm)

    # Larger cosine similarity means closer, so the maximum is selected.
    return jnp.argmax(similarity, axis=0)