def summary_plot(losses, hmc_samples, hmcecs_samples, hmc_runtime, hmcecs_runtime):
    fig, ax = plt.subplots(2, 2)

    ax[0, 0].plot(losses, 'r')
    ax[0, 0].set_title('SVI losses')
    ax[0, 0].set_ylabel('ELBO')

    if hmc_runtime > hmcecs_runtime:
        ax[0, 1].bar([0], hmc_runtime, label='hmc', color='b')
        ax[0, 1].bar([0], hmcecs_runtime, label='hmcecs', color='r')
    else:
        ax[0, 1].bar([0], hmcecs_runtime, label='hmcecs', color='r')
        ax[0, 1].bar([0], hmc_runtime, label='hmc', color='b')
    ax[0, 1].set_title('Runtime')
    ax[0, 1].set_ylabel('Seconds')
    ax[0, 1].legend()
    ax[0, 1].set_xticks([])

    ax[1, 0].plot(jnp.sort(hmc_samples['theta'].mean(0)), 'or')
    ax[1, 0].plot(jnp.sort(hmcecs_samples['theta'].mean(0)), 'b')
    ax[1, 0].set_title(r'$\mathrm{\mathbb{E}}[\theta]$')

    ax[1, 1].plot(jnp.sort(hmc_samples['theta'].var(0)), 'or')
    ax[1, 1].plot(jnp.sort(hmcecs_samples['theta'].var(0)), 'b')
    ax[1, 1].set_title(r'Var$[\theta]$')

    for a in ax[1, :]:
        a.set_xticks([])

    fig.tight_layout()
    fig.savefig('hmcecs_plot.pdf', bbox_inches='tight')

def test_DenseSymm_infeatures(symmetries, use_bias, mode):
    rng = nk.jax.PRNGSeq(0)
    g, hi, perms = _setup_symm(symmetries, N=8)

    if mode == "matrix":
        ma = nk.nn.DenseSymm(
            symmetries=perms,
            mode=mode,
            features=8,
            use_bias=use_bias,
            bias_init=uniform(),
        )
    else:
        ma = nk.nn.DenseSymm(
            symmetries=perms,
            shape=tuple(g.extent),
            mode=mode,
            features=8,
            use_bias=use_bias,
            bias_init=uniform(),
        )

    pars = ma.init(rng.next(), hi.random_state(rng.next(), 2).reshape(1, 2, -1))

    v = hi.random_state(rng.next(), 6).reshape(3, 2, -1)
    vals = [ma.apply(pars, v[..., p]) for p in np.asarray(perms)]
    for val in vals:
        assert jnp.allclose(jnp.sort(val, -1), jnp.sort(vals[0], -1))

def get_logit_snip_masks(params, nn_density_level, predict, x_batch,
                         batch_input_shape, GlOBAL_PRUNE_BOOL=True):

    def norm_square_logits(params, f, x):
        return np.sum(f(params, x) ** 2)

    init_grads = grad(norm_square_logits)(params, predict,
                                          x_batch.reshape(batch_input_shape))

    thres_list = [None] * len(params)

    if GlOBAL_PRUNE_BOOL:
        # global pruning
        cs = [abs(init_grads[idx][0] * params[idx][0]).flatten()
              for idx in range(len(params)) if len(params[idx]) == 2]
        pooled_cs = np.hstack(cs)
        idx = int((1 - nn_density_level) * len(pooled_cs))
        # threshold: entries below the threshold will be removed
        thres = np.sort(pooled_cs)[idx]
        thres_list = [thres] * len(params)
    else:
        # layerwise pruning
        for layer_index in range(len(params)):
            if len(params[layer_index]) == 2:
                cs = abs(init_grads[layer_index][0] * params[layer_index][0]).flatten()
                idx = int((1 - nn_density_level) * len(cs))
                # threshold: entries below the threshold will be removed
                thres = np.sort(cs)[idx]
                thres_list[layer_index] = thres

    masks = []
    for layer_index in range(len(params)):
        if len(params[layer_index]) < 2:
            # In this case, the layer does not contain weight and bias parameters.
            masks.append([])
        elif len(params[layer_index]) == 2:
            # In this case, the layer contains a tuple of parameters for weights and biases.
            weights = params[layer_index][0]
            weights_grad = init_grads[layer_index][0]
            layer_cs = np.abs(weights * weights_grad)
            # 0 selected for weight parameters with magnitudes smaller than the threshold, 1 otherwise
            this_mask = np.float32(layer_cs >= thres_list[layer_index])
            masks.append(this_mask)
        else:
            raise NotImplementedError
    return masks

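# A minimal standalone sketch (not from the original code) of the SNIP-style
# criterion used above: saliency is |grad * weight|, and a global threshold
# keeps only the top `density` fraction of entries across all layers. The
# helper name and toy inputs below are hypothetical.
import jax
import jax.numpy as jnp


def snip_global_masks_sketch(weights_list, grads_list, density):
    # Pool |g * w| over all layers and find the global cut-off.
    pooled = jnp.concatenate(
        [jnp.abs(w * g).ravel() for w, g in zip(weights_list, grads_list)])
    cutoff = jnp.sort(pooled)[int((1.0 - density) * pooled.size)]
    # Keep weights whose saliency reaches the cut-off.
    return [(jnp.abs(w * g) >= cutoff).astype(jnp.float32)
            for w, g in zip(weights_list, grads_list)]


key = jax.random.PRNGKey(0)
weights = [jax.random.normal(key, (4, 3)), jax.random.normal(key, (3, 2))]
grads = [jnp.ones((4, 3)), jnp.ones((3, 2))]
masks_sketch = snip_global_masks_sketch(weights, grads, density=0.5)
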
def test_neighbor_list_build_time_dependent(self, dtype, dim):
    key = random.PRNGKey(1)

    if dim == 2:
        box_fn = lambda t: np.array([[9.0, t], [0.0, 3.75]], f32)
    elif dim == 3:
        box_fn = lambda t: np.array([[9.0, 0.0, t],
                                     [0.0, 4.0, 0.0],
                                     [0.0, 0.0, 7.25]])
    min_length = np.min(np.diag(box_fn(0.)))
    cutoff = f32(1.23)
    # TODO(schsam): Get cell-list working with anisotropic cell sizes.
    cell_size = cutoff / min_length

    displacement, _ = space.periodic_general(box_fn)
    metric = space.metric(displacement)

    R = random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
    N = R.shape[0]

    neighbor_list_fn = partition.neighbor_list(metric, 1., cutoff, 0.0, 1.1,
                                               cell_size=cell_size,
                                               t=np.array(0.))
    idx = neighbor_list_fn(R, t=np.array(0.25)).idx
    R_neigh = R[idx]
    mask = idx < N

    metric = partial(metric, t=f32(0.25))
    d = vmap(vmap(metric, (None, 0)))
    dR = d(R, R_neigh)

    d_exact = space.map_product(metric)
    dR_exact = d_exact(R, R)

    dR = np.where(dR < cutoff, dR, 0) * mask
    dR_exact = np.where(dR_exact < cutoff, dR_exact, 0)

    dR = np.sort(dR, axis=1)
    dR_exact = np.sort(dR_exact, axis=1)

    for i in range(dR.shape[0]):
        dR_row = dR[i]
        dR_row = dR_row[dR_row > 0.]

        dR_exact_row = dR_exact[i]
        dR_exact_row = dR_exact_row[dR_exact_row > 0.]

        self.assertAllClose(dR_row, dR_exact_row)

def test_forced_identifiability_prior():
    from jax import random

    prior = PriorChain().push(ForcedIdentifiabilityPrior('x', 10, 0., 10.))
    for i in range(10):
        out = prior(random.uniform(random.PRNGKey(i), shape=(prior.U_ndims,)))
        assert jnp.all(jnp.sort(out['x'], axis=0) == out['x'])
        assert jnp.all((out['x'] >= 0.) & (out['x'] <= 10.))

    prior = PriorChain().push(
        ForcedIdentifiabilityPrior('x', 10, jnp.array([0., 0.]), 10.))
    for i in range(10):
        out = prior(random.uniform(random.PRNGKey(i), shape=(prior.U_ndims,)))
        assert out['x'].shape == (10, 2)
        assert jnp.all(jnp.sort(out['x'], axis=0) == out['x'])
        assert jnp.all((out['x'] >= 0.) & (out['x'] <= 10.))

def sample_pdf(key, bins, weights, origins, directions, z_vals,
               num_coarse_samples, use_stratified_sampling):
    """Hierarchical sampling.

    Args:
      key: jnp.ndarray(float32), [2,], random number generator.
      bins: jnp.ndarray(float32), [batch_size, n_bins + 1].
      weights: jnp.ndarray(float32), [batch_size, n_bins].
      origins: ray origins.
      directions: ray directions.
      z_vals: jnp.ndarray(float32), [batch_size, n_coarse_samples].
      num_coarse_samples: int, the number of samples to draw.
      use_stratified_sampling: bool, use stratified sampling.

    Returns:
      z_vals: jnp.ndarray(float32),
        [batch_size, n_coarse_samples + num_fine_samples].
      points: jnp.ndarray(float32),
        [batch_size, n_coarse_samples + num_fine_samples, 3].
    """
    z_samples = piecewise_constant_pdf(key, bins, weights, num_coarse_samples,
                                       use_stratified_sampling)
    # Compute united z_vals and sample points.
    z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
    return z_vals, (origins[..., None, :] +
                    z_vals[..., None] * directions[..., None, :])

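# A small self-contained illustration (not part of the original code) of the
# sort-merge step above: coarse and fine depth samples along each ray are
# concatenated and re-sorted so the renderer sees monotonically increasing
# depths.
import jax.numpy as jnp

z_coarse = jnp.array([[0.1, 0.4, 0.9]])         # [batch, n_coarse]
z_fine = jnp.array([[0.25, 0.3, 0.85, 0.95]])   # [batch, n_fine]
z_all = jnp.sort(jnp.concatenate([z_coarse, z_fine], axis=-1), axis=-1)
# z_all == [[0.1, 0.25, 0.3, 0.4, 0.85, 0.9, 0.95]]
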
def subgraph(single: SemiSupervisedSingle,
             indices: jnp.ndarray) -> SemiSupervisedSingle:
    indices = jnp.asarray(indices, jnp.int32)
    assert indices.ndim == 1, indices.shape
    index_dtype = indices.dtype
    assert jnp.issubdtype(index_dtype, jnp.integer)
    indices = jnp.sort(indices)

    adj = ops.gather(ops.gather(single.graph, indices, axis=0), indices, axis=1)
    node_features = single.node_features
    if isinstance(node_features, JAXSparse):
        node_features = ops.gather(node_features, indices, axis=0)
    else:
        node_features = node_features[indices]

    remap_indices = (jnp.zeros((single.graph.shape[0],), index_dtype)
                     .at[indices].set(jnp.arange(indices.size, dtype=index_dtype)))

    def valid_ids(ids):
        return None if ids is None else remap_indices[ids]

    return SemiSupervisedSingle(
        node_features,
        adj,
        single.labels[indices],
        train_ids=valid_ids(single.train_ids),
        validation_ids=valid_ids(single.validation_ids),
        test_ids=valid_ids(single.test_ids),
    )

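# Small illustration (not from the original code) of the index remapping above:
# kept node ids are mapped to their new positions in the subgraph, so that
# train/validation/test id lists can be translated directly.
import jax.numpy as jnp

kept = jnp.array([2, 5, 7])  # kept nodes, sorted
remap = jnp.zeros((10,), jnp.int32).at[kept].set(jnp.arange(3, dtype=jnp.int32))
old_train_ids = jnp.array([5, 2])
new_train_ids = remap[old_train_ids]  # [1, 0]
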
def sample_per_class(
    rng: PRNGKey,
    labels: jnp.ndarray,
    examples_per_class: int,
    valid_indices: tp.Optional[jnp.ndarray] = None,
    num_classes: tp.Optional[int] = None,
) -> jnp.ndarray:
    assert labels.ndim == 1, labels.shape
    if valid_indices is None:
        valid_indices = jnp.arange(labels.size, dtype=jnp.int64)
        valid_labels = labels
    else:
        assert valid_indices.ndim == 1, valid_indices.shape
        valid_labels = labels[valid_indices]
    if num_classes is None:
        num_classes = jnp.max(labels) + 1

    all_class_indices = []
    for class_index, key in enumerate(jax.random.split(rng, num_classes)):
        (class_indices,) = jnp.where(valid_labels == class_index)
        assert class_indices.size >= examples_per_class
        class_indices = jax.random.choice(key, class_indices,
                                          (examples_per_class,), replace=False)
        class_indices = valid_indices[class_indices]
        all_class_indices.append(class_indices)

    class_indices = jnp.concatenate(all_class_indices)
    class_indices = jnp.sort(class_indices)
    return class_indices

def sorted_topk_indicators(x, k, sort_by=SortBy.POSITION):
    """Finds the (sorted) positions of the topk values in x.

    Args:
      x: The input scores of dimension (d,).
      k: The number of top elements to find.
      sort_by: Strategy to order the extracted values. This is useful when this
        function is applied to many perturbed inputs and the results are
        averaged. As topk's output does not have a fixed order, the indicator
        vectors could be swapped and the average of the indicators would not be
        spiky.

    Returns:
      Indicator vectors in a tensor of shape (k, d)
    """
    n = x.shape[-1]
    values, ranks = jax.lax.top_k(x, k)

    if sort_by == SortBy.NONE:
        sorted_ranks = ranks
    if sort_by == SortBy.VALUES:
        sorted_ranks = ranks[jnp.argsort(values)]
    if sort_by == SortBy.POSITION:
        sorted_ranks = jnp.sort(ranks)

    one_hot_fn = jax.vmap(functools.partial(jax.nn.one_hot, num_classes=n))
    indicators = one_hot_fn(sorted_ranks)
    return indicators

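# Hypothetical usage sketch of the position ordering above: with position
# ordering, the k indicator rows always appear in increasing index order, so
# averaging indicators over many perturbed inputs stays aligned.
import functools
import jax
import jax.numpy as jnp

x = jnp.array([0.1, 2.0, -1.0, 3.0, 0.5])
_, ranks = jax.lax.top_k(x, 2)          # indices of the two largest values: [3, 1]
sorted_ranks = jnp.sort(ranks)          # [1, 3], ordered by position
one_hot_fn = jax.vmap(functools.partial(jax.nn.one_hot, num_classes=x.shape[-1]))
indicators = one_hot_fn(sorted_ranks)   # shape (2, 5), rows one-hot at 1 and 3
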
def rs(self, key, budget):
    """Random search sampling method with uniform distribution."""
    if operator.xor(self.x_precollect is None, self.y_precollect is None):
        raise ValueError(
            'Both x_precollect and y_precollect need to be provided.')
    if self.x_precollect is not None:
        ind_pre = point_in_search_space(self.x_precollect, self.search_space)
        x_precollect_in_search_space = self.x_precollect[ind_pre, :]
        y_precollect_in_search_space = self.y_precollect[ind_pre, :]
        if y_precollect_in_search_space.shape[0] < budget:
            raise ValueError(
                'budget is larger than the precollected data in the search space.')
        ind_chosen = jax.random.choice(
            key,
            y_precollect_in_search_space.shape[0],
            shape=(budget,),
            replace=False)
        ind_chosen = jnp.sort(ind_chosen)
        x = x_precollect_in_search_space[ind_chosen, :]
        y = y_precollect_in_search_space[ind_chosen, :]
        additional_info_dict = {}
    else:
        x = jax.random.uniform(
            key,
            shape=(budget, self.search_space.shape[0]),
            minval=self.search_space[:, 0],
            maxval=self.search_space[:, 1])
        y, additional_info_dict = self.objective_fn(x)
        ind_chosen = None
    return x, y, additional_info_dict, ind_chosen

def nucleaus_filter(logits, top_p=0.9, top_k=None):
    sorted_logits = jnp.sort(logits)[:, ::-1]  # sort descending
    sorted_indices = jnp.argsort(logits)[:, ::-1]
    cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits), axis=-1)

    if top_k is not None:
        # Keep only top_k tokens
        indices_range = jnp.arange(len(sorted_indices[0]))
        indices_range = jnp.stack([indices_range] * len(sorted_indices), axis=0)

        sorted_indices_to_remove = jnp.where(indices_range > top_k,
                                             sorted_indices, 0)

        _, indices_to_remove = jax.lax.sort_key_val(sorted_indices,
                                                    sorted_indices_to_remove)

        logit_mask = 1e10 * indices_to_remove
        logits -= logit_mask

    # Remove tokens with cumulative probability above a threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove = jnp.concatenate(
        (jnp.zeros_like(sorted_indices_to_remove[:, :1]),
         sorted_indices_to_remove),
        axis=-1)[:, :-1]

    _, indices_to_remove = jax.lax.sort_key_val(sorted_indices,
                                                sorted_indices_to_remove)

    logit_mask = 1e10 * indices_to_remove
    logits -= logit_mask

    return logits

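# Hypothetical usage of the filter above (assuming jax and jnp are imported as
# in the snippet): logits of removed tokens are pushed down by roughly 1e10, so
# a subsequent softmax-and-sample step effectively ignores them.
logits_example = jnp.array([[2.0, 1.0, 0.5, -1.0]])
filtered_example = nucleaus_filter(logits_example, top_p=0.9)
# the lowest-probability tokens now carry a very large negative logit
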
def l1_unit_projection(x):
    """Euclidean projection to L1 unit ball i.e. argmin_{|v|_1<= 1} |x-v|_2.

    Args:
      x: An array of size dim x num.

    Returns:
      An array of size dim x num, the projection to the unit L1 ball.
    """
    # https://dl.acm.org/citation.cfm?id=1390191
    xshape = x.shape
    if len(x.shape) == 1:
        x = x.reshape(-1, 1)
    eshape = x.shape
    v = jnp.abs(x.reshape((-1, eshape[-1])))
    u = jnp.sort(v, axis=0)
    u = u[::-1, :]  # descending
    arange = (1 + jnp.arange(eshape[0])).reshape((-1, 1))
    usum = (jnp.cumsum(u, axis=0) - 1) / arange
    rho = jnp.max(((u - usum) > 0) * arange - 1, axis=0, keepdims=True)
    thx = jnp.take_along_axis(usum, rho, axis=0)
    w = (v - thx).clip(a_min=0)
    w = jnp.where(jnp.linalg.norm(v, ord=1, axis=0, keepdims=True) > 1, w, v)
    x = w.reshape(eshape) * jnp.sign(x)
    return x.reshape(xshape)

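# Quick sanity sketch (not part of the original; assumes jax and jnp are
# imported as the function above requires): after projection every column lies
# inside the L1 unit ball, and columns already inside it are left unchanged.
x_proj_test = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
v_proj_test = l1_unit_projection(x_proj_test)
assert jnp.all(jnp.linalg.norm(v_proj_test, ord=1, axis=0) <= 1.0 + 1e-5)
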
def _holm_bonferroni(p):
    """Performs Holm-Bonferroni correction for pvalues to account for multiple
    comparisons.

    Parameters
    ----------
    p: numpy.array
        array of pvalues

    Returns
    -------
    numpy.array
        corrected pvalues
    """
    K = len(p)
    sort_index = -np.ones(K, dtype=np.int64)
    sorted_p = np.sort(p)
    sorted_p_adj = sorted_p * (K - np.arange(K))
    for j in range(K):
        idx = (p == sorted_p[j]) & (sort_index < 0)
        num_ties = len(sort_index[idx])
        sort_index[idx] = np.arange(j, (j + num_ties), dtype=np.int64)
    sorted_holm_p = [min([max(sorted_p_adj[:k]), 1]) for k in range(1, K + 1)]
    holm_p = [sorted_holm_p[sort_index[k]] for k in range(K)]
    return holm_p

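# Illustrative call (assumes np is plain NumPy, which the in-place assignment
# above requires): the smallest p-value gets the largest multiplier and the
# running max keeps the corrected values monotone.
p_adjusted = _holm_bonferroni(np.array([0.01, 0.04, 0.03, 0.20]))
# p_adjusted is approximately [0.04, 0.09, 0.09, 0.20]
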
def near_square_wave(
    n_train: int = 80,
    input_noise: float = 0.15,
    output_noise: float = 0.3,
    n_test: int = 400,
    random_state: int = 123,
):
    """Generates a near-square wave"""
    # function
    f = lambda x: np.sin(1.0 * np.pi / 1.6 * np.cos(5 + 0.5 * x))

    # create clean inputs
    x_mu = np.linspace(-10, 10, n_train)

    # clean outputs
    y = f(x_mu)

    # generate noise
    x_rng = check_random_state(random_state)
    y_rng = check_random_state(random_state + 1)

    # noisy inputs
    x = x_mu + input_noise * x_rng.randn(x_mu.shape[0])

    # noisy outputs
    y = f(x_mu) + output_noise * y_rng.randn(x_mu.shape[0])

    # test points
    x_test = np.linspace(-12, 12, n_test) + x_rng.randn(n_test)
    y_test = f(np.linspace(-12, 12, n_test))
    x_test = np.sort(x_test)

    return x[:, None], y[:, None], x_test[:, None], y_test

def kth_percent_distance(dists: np.ndarray, k: float = 0.3) -> np.ndarray:
    """kth percent distance in a gram matrix

    This calculates the kth percent in an (NxN) matrix. It sorts all distance
    values and then retrieves the kth value as a percentage of the number of
    samples.

    Parameters
    ----------
    dists : jax.numpy.ndarray
        the distance matrix already calculated, (n_samples, n_samples)

    k : float
        the fraction of samples used for the kth distance (default=0.3)

    Returns
    -------
    kth_dist : jax.numpy.ndarray
        the neighbours up to the kth distance
    """
    # kth distance calculation
    kth_sample = int(k * dists.shape[0])

    # take the kth neighbour of that distance
    k_dist = np.sort(dists)[:, kth_sample]

    return k_dist

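# Hypothetical usage (assuming np refers to jax.numpy, as the docstring types
# suggest): for a 10x10 pairwise distance matrix and k=0.3, each sorted row is
# indexed at position int(0.3 * 10) = 3.
pts = np.linspace(0.0, 1.0, 10)
dists_toy = np.abs(pts[:, None] - pts[None, :])
k_dist_toy = kth_percent_distance(dists_toy, k=0.3)  # shape (10,)
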
def render_rays_fine(rays, z_vals, weights, num_importance, perturbation=True,
                     rng=None):
    """Render rays for the fine model.

    Args:
        rays: (2, num_rays, 3) origin and direction generated rays
        z_vals: (num_rays, num_samples) depths of the sampled positions
        weights: (num_rays, num_samples) weights assigned to each sampled color
            for the coarse model
        num_importance: number of samples used in the fine model
        perturbation: whether to apply jitter on each ray or not
        rng: random key

    Returns:
        pts: (num_rays, num_samples + num_importance, 3) points in space to
            evaluate model at
        z_vals: (num_rays, num_samples + num_importance) depths of the sampled
            positions
        z_samples: (num_rays) standard deviation of distances along ray for
            each sample
    """
    rays_o, rays_d = rays
    z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
    z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], num_importance,
                           perturbation, rng)
    z_samples = lax.stop_gradient(z_samples)

    # obtain all points to evaluate color density at
    z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
    z_vals = z_vals.astype(rays_d.dtype)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    return pts, z_vals, jnp.std(z_samples, axis=-1)

def sample_permutation(key, coupling):
    """Samples a permutation matrix from a doubly stochastic coupling matrix.

    CAREFUL: the couplings that come out of the sinkhorn solver are not doubly
    stochastic but 1/dim * doubly_stochastic.

    See **Convex Relaxations for Permutation Problems** paper for rough
    explanation of the algorithm. Best to use by drawing multiple samples and
    picking the permutation with lowest cost as sometimes permutations seem to
    be drawn with high cost. The sample_best_permutation method does this.

    Args:
      key: jnp.ndarray that functions as a PRNG key.
      coupling: jnp.ndarray of shape [N, N] which must have marginals such that
        coupling.sum(0) == 1. and coupling.sum(1) == 1. Note that in sinkhorn we
        usually output couplings with marginals that sum to 1/N.

    Returns:
      permutation matrix: jnp.ndarray of shape [N, N] of floating dtype.
    """
    dim = coupling.shape[0]
    # random monotonic vector v without duplicates.
    v = jax.random.choice(key, 10 * dim, shape=(dim,), replace=False)
    v = jnp.sort(v) * 10.

    w = jnp.dot(coupling, v)
    # Sorting w will give the row indices of the permutation matrix.
    row_ind = jnp.argsort(w)
    col_ind = jnp.arange(0, dim)

    # Compute permutation matrix from row and column indices
    perm = idx2permutation(row_ind, col_ind)
    return perm

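# Standalone sketch (hypothetical; idx2permutation is not shown above) of the
# same idea: push a strictly increasing vector v through the coupling, argsort
# the result to get the row order, then materialize a permutation matrix with
# one-hot rows. The one-hot convention here is an assumption.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
coupling = jnp.eye(4) * 0.9 + 0.025   # rows and columns each sum to 1
v = jnp.sort(jax.random.choice(key, 40, shape=(4,), replace=False)) * 10.0
row_ind = jnp.argsort(jnp.dot(coupling, v))
perm_sketch = jax.nn.one_hot(row_ind, 4)  # one 1 per row and per column
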
def test_prior_mll():
    """
    Test that the MLL evaluation works with priors attached to the parameter values.
    """
    key = jr.PRNGKey(123)
    x = jnp.sort(jr.uniform(key, minval=-5.0, maxval=5.0, shape=(100, 1)), axis=0)
    f = lambda x: jnp.sin(jnp.pi * x) / (jnp.pi * x)
    y = f(x) + jr.normal(key, shape=x.shape) * 0.1
    posterior = Prior(kernel=RBF()) * Gaussian()
    params = initialise(posterior)
    config = get_defaults()
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    print(params)
    mll = marginal_ll(posterior, transform=constrainer)
    priors = {
        "lengthscale": tfd.Gamma(1.0, 1.0),
        "variance": tfd.Gamma(2.0, 2.0),
        "obs_noise": tfd.Gamma(2.0, 2.0),
    }
    mll_eval = mll(params, x, y)
    mll_eval_priors = mll(params, x, y, priors)

    assert pytest.approx(mll_eval) == jnp.array(-103.28180663)
    assert pytest.approx(mll_eval_priors) == jnp.array(-105.509218857)

def sample_pdf(key, bins, weights, origins, directions, z_vals, num_samples,
               randomized):
    """Hierarchical sampling.

    Args:
      key: jnp.ndarray(float32), [2,], random number generator.
      bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
      weights: jnp.ndarray(float32), [batch_size, num_bins].
      origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
      directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
      z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
      num_samples: int, the number of samples.
      randomized: bool, use randomized samples.

    Returns:
      z_vals: jnp.ndarray(float32),
        [batch_size, num_coarse_samples + num_fine_samples].
      points: jnp.ndarray(float32),
        [batch_size, num_coarse_samples + num_fine_samples, 3].
    """
    z_samples = piecewise_constant_pdf(key, bins, weights, num_samples,
                                       randomized)
    # Compute united z_vals and sample points.
    z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
    coords = cast_rays(z_vals, origins, directions)
    return z_vals, coords

def sample_pdf(key, bins, weights, rays, z_vals, num_samples, randomized):
    """Hierarchical sampling.

    Args:
      key: jnp.ndarray(float32), [2,], random number generator.
      bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
      weights: jnp.ndarray(float32), [batch_size, num_bins].
      rays: jnp.ndarray(float32), [batch_size, 6].
      z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
      num_samples: int, the number of samples.
      randomized: bool, use randomized samples.

    Returns:
      z_vals: jnp.ndarray(float32),
        [batch_size, num_coarse_samples + num_fine_samples].
      points: jnp.ndarray(float32),
        [batch_size, num_coarse_samples + num_fine_samples, 3].
    """
    z_samples = piecewise_constant_pdf(key, bins, weights, num_samples,
                                       randomized)
    origins = rays[Ellipsis, 0:3]
    directions = rays[Ellipsis, 3:6]
    # Compute united z_vals and sample points.
    z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
    return z_vals, (origins[Ellipsis, None, :] +
                    z_vals[Ellipsis, None] * directions[Ellipsis, None, :])

def get_boundaries_intersections(z, d, trust_radius):
    # Solve the quadratic |z + t * d|^2 = trust_radius^2 for t and return the
    # two roots in ascending order.
    a = jnp.vdot(d, d)
    b = 2 * jnp.vdot(z, d)
    c = jnp.vdot(z, z) - trust_radius**2

    sqrt_discriminant = jnp.sqrt(b * b - 4 * a * c)
    ta = (-b - sqrt_discriminant) / (2 * a)
    tb = (-b + sqrt_discriminant) / (2 * a)
    return jnp.sort(jnp.stack([ta, tb]))

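# Worked check (not from the original source; assumes jnp is imported as the
# function above requires): for z at the origin and a unit direction, the line
# z + t * d meets a trust-region sphere of radius 2 at t = -2 and t = +2.
z_check = jnp.zeros(3)
d_check = jnp.array([1.0, 0.0, 0.0])
ts_check = get_boundaries_intersections(z_check, d_check, 2.0)
# ts_check == [-2.0, 2.0]
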
def get_top_k_weights(
    top_k_fraction: float,
    restarting_weights: Array,
    scaled_advantages: Array,
    axis_name: Optional[str] = None,
    use_stop_gradient: bool = True,
):
    """Get the weights for the top top_k_fraction of advantages.

    Args:
      top_k_fraction: The fraction of weights to use.
      restarting_weights: Restarting weights, shape E*, 0 means that this step is
        the start of a new episode and we ignore losses at this step because the
        agent cannot influence these.
      scaled_advantages: The advantages for each example (shape E*), scaled by
        temperature.
      axis_name: Optional axis name for `pmap`. If `None`, computations are
        performed locally on each device.
      use_stop_gradient: bool indicating whether or not to apply stop gradient.

    Returns:
      Weights for the top top_k_fraction of advantages
    """
    chex.assert_equal_shape([scaled_advantages, restarting_weights])
    chex.assert_type([scaled_advantages, restarting_weights], float)

    if not 0.0 < top_k_fraction <= 1.0:
        raise ValueError(
            f"`top_k_fraction` must be in (0, 1], got {top_k_fraction}")
    logging.info("[vmpo_e_step] top_k_fraction: %f", top_k_fraction)

    if top_k_fraction < 1.0:
        # Don't include the restarting samples in the determination of top-k.
        valid_scaled_advantages = scaled_advantages - (
            1.0 - restarting_weights) * _INFINITY
        # Determine the minimum top-k value across all devices.
        if axis_name:
            all_valid_scaled_advantages = jax.lax.all_gather(
                valid_scaled_advantages, axis_name=axis_name)
        else:
            all_valid_scaled_advantages = valid_scaled_advantages
        top_k = int(top_k_fraction * jnp.size(all_valid_scaled_advantages))
        if top_k == 0:
            raise ValueError(
                "top_k_fraction too low to get any valid scaled advantages.")
        # TODO(b/160450251): Use jnp.partition(all_valid_scaled_advantages, top_k)
        # when this is implemented in jax.
        top_k_min = jnp.sort(jnp.reshape(all_valid_scaled_advantages, [-1]))[-top_k]
        # Fold the top-k into the restarting weights.
        top_k_weights = jnp.greater_equal(valid_scaled_advantages,
                                          top_k_min).astype(jnp.float32)
        top_k_weights = jax.lax.select(
            use_stop_gradient, jax.lax.stop_gradient(top_k_weights),
            top_k_weights)
        top_k_restarting_weights = restarting_weights * top_k_weights
    else:
        top_k_restarting_weights = restarting_weights

    return top_k_restarting_weights

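# Minimal standalone sketch (illustrative only) of the top-k selection above:
# keep the advantages at or above the k-th largest value, with restart steps
# excluded both from the cut and from the final weights.
import jax.numpy as jnp

scaled_adv = jnp.array([0.1, 2.0, -1.0, 0.7])
restart_w = jnp.array([1.0, 1.0, 0.0, 1.0])   # third step starts a new episode
valid = scaled_adv - (1.0 - restart_w) * 1e8
top_k = int(0.5 * valid.size)                 # top_k_fraction = 0.5
top_k_min = jnp.sort(jnp.reshape(valid, [-1]))[-top_k]
top_k_w = restart_w * (valid >= top_k_min).astype(jnp.float32)
# top_k_w == [0., 1., 0., 1.]
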
def mask(x, mask_constant, mask_axis, key, p):
    if mask_constant is not None:
        mask_shape = [1 if i in mask_axis else s for i, s in enumerate(x.shape)]
        mask_mat = jax.random.bernoulli(key, p=p, shape=mask_shape)
        x = np.where(mask_mat, mask_constant, x)
        x = np.sort(x, 1)
    return x

def test_topk_one_array(self, k):
    n = 20
    x = jax.random.uniform(self.rng, (n,))
    axis = 0
    xs = soft_sort.sort(x, axis=axis, topk=k, epsilon=1e-3)
    outsize = k if 0 < k < n else n

    self.assertEqual(xs.shape, (outsize,))
    self.assertTrue(jnp.alltrue(jnp.diff(xs, axis=axis) >= 0.0))
    self.assertAllClose(xs, jnp.sort(x, axis=axis)[-outsize:], atol=0.01)

def gini(array):
    """Calculate the Gini coefficient of a numpy array."""
    # based on bottom eq: http://www.statsdirect.com/help/content/image/stat0206_wmf.gif
    # from: http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm
    array = np.sort(array)  # values must be sorted
    index = np.arange(1, array.shape[0] + 1)  # index per array element
    n = array.shape[0]  # number of array elements
    return ((np.sum((2 * index - n - 1) * array)) /
            (n * np.sum(array)))  # Gini coefficient

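# Quick check (illustrative; assumes np as imported for the function above): a
# perfectly equal array has Gini coefficient 0, while concentrating all mass in
# a single element pushes it toward 1.
print(gini(np.ones(10)))                   # 0.0
print(gini(np.array([0.0] * 9 + [1.0])))   # 0.9
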
def top_k_approx(scores, k=100):
    """Returns approximate topk highest scores for each row.

    The api is same as jax.lax.top_k, so this can be used as a drop in
    replacement as long as num dims of scores tensor is 2. For more dimensions,
    please use one or more vmap(s) to be able to use it.

    In essence, we perform a jnp.max operation, which can be thought of as
    lossy top 1, on fixed length windows of items. We can control the amount of
    approximation by changing the window length. The smaller it gets, the
    better the approximation, but at the cost of performance.

    Once we have the max for all the windows, we apply regular slow but exact
    jax.lax.top_k over the reduced set of items.

    Args:
      scores: [num_rows, num_cols] shaped tensor. Will return top K over last
        dim.
      k: How many top scores to return for each row.

    Returns:
      Topk scores, topk ids. Both shaped [num_rows, k]
    """
    num_queries = scores.shape[0]
    num_items = scores.shape[1]

    # Make this bigger to improve recall. Should be between [1, k].
    num_windows_multiplier = 5
    window_lengths = num_items // k // num_windows_multiplier + 1
    padded_num_items = k * num_windows_multiplier * window_lengths

    print(f"scores shape: {scores.shape}")
    print(f"padded_num_items: {padded_num_items}")
    print(f"num_items: {num_items}")

    scores = jnp.pad(scores, ((0, 0), (0, padded_num_items - num_items)),
                     mode="constant", constant_values=jnp.NINF)
    scores = jnp.reshape(
        scores, (num_queries, k * num_windows_multiplier, window_lengths))
    approx_top_local_scores = jnp.max(scores, axis=2)

    sorted_approx_top_scores_across_local = jnp.flip(
        jnp.sort(approx_top_local_scores, axis=1), axis=1)
    approx_top_ids_across_local = jnp.flip(
        jnp.argsort(approx_top_local_scores, axis=1), axis=1)[:, :k]

    approx_top_local_ids = jnp.argmax(scores, axis=2)
    offsets = jnp.arange(0, padded_num_items, window_lengths)
    approx_top_ids_with_offsets = approx_top_local_ids + offsets
    approx_top_ids = slice_2d(approx_top_ids_with_offsets,
                              approx_top_ids_across_local)

    topk_scores = sorted_approx_top_scores_across_local[:, :k]
    topk_ids = approx_top_ids
    return topk_scores, topk_ids

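# Toy sketch (illustrative, independent of slice_2d) of the window-max idea
# described in the docstring: each row is split into windows, the per-window
# max forms a cheap candidate set, and exact top-k runs over those candidates.
import jax
import jax.numpy as jnp

row = jnp.arange(12.0).reshape(1, 12)     # one row of 12 items
windows = row.reshape(1, 4, 3)            # 4 windows of length 3
candidates = jnp.max(windows, axis=2)     # [[2., 5., 8., 11.]]
approx_vals, _ = jax.lax.top_k(candidates, 2)
# approx_vals == [[11., 8.]]; the exact top-2 is [11., 10.] -- 10 is lost
# because it shares a window with 11, which is the recall/speed trade-off.
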
def __init__(self, dim, bounds):
    """
    :param dim: dimension of the space
    :param bounds: a list of floats. E.g [b_1, ..., b_k] representing k cubes,
        centered at (0,...,0) and sides b_i*2
        we assume all bounds b_i are positive
    """
    self.dim = dim
    self.pos_sorted_bounds = np.sort(bounds)
    # we assume all bounds (thus the smallest) are positive
    assert self.pos_sorted_bounds[0] > 0

def test_soft_quantile_normalization(self):
    rngs = jax.random.split(self.rng, 2)
    x = jax.random.uniform(rngs[0], shape=(100,))
    mu, sigma = 2.0, 1.2
    y = mu + sigma * jax.random.normal(self.rng, shape=(48,))
    mu_target, sigma_target = y.mean(), y.std()

    qn = soft_sort.quantile_normalization(x, jnp.sort(y), epsilon=1e-4)
    mu_transform, sigma_transform = qn.mean(), qn.std()
    self.assertAllClose([mu_transform, sigma_transform],
                        [mu_target, sigma_target], rtol=0.05)

def gen(key: jnp.ndarray):
    feature_ids = jax.random.shuffle(key, remaining_ids)
    for batch_index in range(length):
        i = feature_ids[batch_size * batch_index:batch_size * (batch_index + 1)]
        i = jnp.concatenate((ids, i))
        i = jnp.sort(i)
        f = features[i]
        p = prop[:, i]
        yield ((p, f, input_ids), labels)