def sample_ellipsoid(key, center, radii, rotation, unit_cube_constraint=False):
    """
    Sample uniformly inside an ellipsoid.

    When unit_cube_constraint=True, the randomly chosen radius is constrained
    so that the sample stays inside the unit cube. With u(t) = R @ (t * n) + c,
    the face u(t) = 1 gives t = (1 - c) / (R @ n); we take the minimum positive
    t satisfying this, and likewise for the intersection with the zero faces.

    Args:
        key: PRNG key.
        center: [D] ellipsoid center.
        radii: [D] ellipsoid radii.
        rotation: [D, D] rotation matrix.

    Returns:
        [D] sampled point.
    """
    direction_key, radii_key = random.split(key, 2)
    direction = random.normal(direction_key, shape=radii.shape)
    if unit_cube_constraint:
        direction = direction / jnp.linalg.norm(direction)
        R = rotation * radii
        D = R @ direction
        t0 = -center / D
        t1 = jnp.reciprocal(D) + t0
        t0 = jnp.where(t0 < 0., jnp.inf, t0)
        t1 = jnp.where(t1 < 0., jnp.inf, t1)
        t = jnp.minimum(jnp.min(t0), jnp.min(t1))
        t = jnp.minimum(t, 1.)
        return jnp.exp(
            jnp.log(random.uniform(radii_key, minval=0., maxval=t))
            / radii.size) * D + center
    log_norm = jnp.log(jnp.linalg.norm(direction))
    log_radius = jnp.log(random.uniform(radii_key)) / radii.size
    # x = direction * (radius / norm)
    x = direction * jnp.exp(log_radius - log_norm)
    return circle_to_ellipsoid(x, center, radii, rotation)
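# A minimal usage sketch for sample_ellipsoid (values are illustrative). The
# unconstrained branch additionally requires a `circle_to_ellipsoid` helper
# from the surrounding module, so this sketch uses unit_cube_constraint=True.
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
center = jnp.full(3, 0.5)          # ellipsoid centered inside the unit cube
radii = jnp.array([0.1, 0.2, 0.15])
rotation = jnp.eye(3)              # no rotation
point = sample_ellipsoid(key, center, radii, rotation,
                         unit_cube_constraint=True)
print(point.shape)  # (3,)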
def __init__(self, x, y, alpha=0., sigma=None, lamb=None, kernel_num=100):
    """Estimate the alpha-relative density ratio r(x) = p(x) / (alpha*p(x) + (1-alpha)*q(x)).

    Args:
        x (array-like of float): Numerator samples array. x is generated from p(x).
        y (array-like of float): Denominator samples array. y is generated from q(x).
        alpha (float or array-like, optional): Parameter that adjusts the mixing
            ratio in r(x) = p(x) / (alpha*p(x) + (1-alpha)*q(x)); set in the
            range 0-1. Defaults to 0.
        sigma (float or array-like, optional): Bandwidth of the kernel. If a
            single value is set for sigma, that value is used as the kernel
            bandwidth; if a numerical array is set, Densratio selects the
            optimum value by CV. Defaults to np.logspace(-3, 1, 9).
        lamb (float or array-like, optional): Regularization parameter. If a
            single value is set for lamb, that value is used as the
            hyperparameter; if a numerical array is set, Densratio selects the
            optimum value by CV. Defaults to np.logspace(-3, 1, 9).
        kernel_num (int, optional): The number of kernels in the linear model.
            Defaults to 100.

    Raises:
        ValueError: If x and y do not have the same number of dimensions.
    """
    self.__x = transform_data(x)
    self.__y = transform_data(y)

    if self.__x.shape[1] != self.__y.shape[1]:
        raise ValueError("x and y must have the same dimensions.")

    if sigma is None:
        sigma = np.logspace(-3, 1, 9)
    if lamb is None:
        lamb = np.logspace(-3, 1, 9)

    self.__x_num_row = self.__x.shape[0]
    self.__y_num_row = self.__y.shape[0]
    # The kernel number is the minimum of kernel_num and the number of rows of x.
    self.__kernel_num = np.min(np.array([kernel_num, self.__x_num_row])).item()
    # Randomly choose candidates for the RBF kernel centroids.
    self.__centers = np.array(rand.sample(list(self.__x), k=self.__kernel_num))
    self.__n_minimum = min(self.__x_num_row, self.__y_num_row)
    # self.__kernel = jit(partial(gauss_kernel, centers=self.__centers))

    self._RuLSIF(x=self.__x,
                 y=self.__y,
                 alpha=alpha,
                 s_sigma=np.atleast_1d(sigma),
                 s_lambda=np.atleast_1d(lamb))
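# A hedged usage sketch, assuming this __init__ belongs to a Densratio-style
# class (called `Densratio` here); the class name is an assumption, not taken
# from the original source.
import numpy as np

x = np.random.normal(0.0, 1.0, size=(200, 1))   # samples from p(x)
y = np.random.normal(0.5, 1.2, size=(200, 1))   # samples from q(x)
ratio_model = Densratio(x, y, alpha=0.1, kernel_num=50)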
def test_neighbor_list_build_time_dependent(self, dtype, dim):
    key = random.PRNGKey(1)

    if dim == 2:
        box_fn = lambda t: np.array([[9.0, t], [0.0, 3.75]], f32)
    elif dim == 3:
        box_fn = lambda t: np.array([[9.0, 0.0, t],
                                     [0.0, 4.0, 0.0],
                                     [0.0, 0.0, 7.25]])
    min_length = np.min(np.diag(box_fn(0.)))
    cutoff = f32(1.23)
    # TODO(schsam): Get cell-list working with anisotropic cell sizes.
    cell_size = cutoff / min_length

    displacement, _ = space.periodic_general(box_fn)
    metric = space.metric(displacement)

    R = random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
    N = R.shape[0]

    neighbor_list_fn = partition.neighbor_list(metric, 1., cutoff, 0.0, 1.1,
                                               cell_size=cell_size,
                                               t=np.array(0.))
    idx = neighbor_list_fn(R, t=np.array(0.25)).idx
    R_neigh = R[idx]
    mask = idx < N

    metric = partial(metric, t=f32(0.25))
    d = vmap(vmap(metric, (None, 0)))
    dR = d(R, R_neigh)

    d_exact = space.map_product(metric)
    dR_exact = d_exact(R, R)

    dR = np.where(dR < cutoff, dR, 0) * mask
    dR_exact = np.where(dR_exact < cutoff, dR_exact, 0)

    dR = np.sort(dR, axis=1)
    dR_exact = np.sort(dR_exact, axis=1)

    for i in range(dR.shape[0]):
        dR_row = dR[i]
        dR_row = dR_row[dR_row > 0.]

        dR_exact_row = dR_exact[i]
        dR_exact_row = dR_exact_row[dR_exact_row > 0.]

        self.assertAllClose(dR_row, dR_exact_row)
def loss_func(params, target_params, state, target_state, rng, transition_batch):
    rngs = hk.PRNGSequence(rng)
    S = self.q.observation_preprocessor(next(rngs), transition_batch.S)
    A = self.q.action_preprocessor(next(rngs), transition_batch.A)
    # Clip importance weights to reduce variance.
    W = jnp.clip(transition_batch.W, 0.1, 10.)

    # Regularization term.
    if self.policy_regularizer is None:
        regularizer = 0.
    else:
        # Flip sign (typical example: regularizer = -beta * entropy).
        regularizer = -self.policy_regularizer.batch_eval(
            target_params['reg'], target_params['reg_hparams'],
            target_state['reg'], next(rngs), transition_batch)

    Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
    G = self.target_func(target_params, target_state, next(rngs),
                         transition_batch)
    G += regularizer
    loss = self.loss_function(G, Q, W)

    dLoss_dQ = jax.grad(self.loss_function, argnums=1)
    # e.g. (G - Q) if the loss function is MSE
    td_error = -Q.shape[0] * dLoss_dQ(G, Q)

    # Target-network estimate (is this worth computing?)
    Q_targ_list = []
    qs = list(zip(self.q_targ_list, target_params['q_targ'],
                  target_state['q_targ']))
    for q, pm, st in qs:
        Q_targ, _ = q.function_type1(pm, st, next(rngs), S, A, False)
        assert Q_targ.ndim == 1, f"bad shape: {Q_targ.shape}"
        Q_targ_list.append(Q_targ)
    Q_targ_list = jnp.stack(Q_targ_list, axis=-1)
    assert Q_targ_list.ndim == 2, f"bad shape: {Q_targ_list.shape}"
    Q_targ = jnp.min(Q_targ_list, axis=-1)

    chex.assert_equal_shape([td_error, W, Q_targ])
    metrics = {
        f'{self.__class__.__name__}/loss': loss,
        f'{self.__class__.__name__}/td_error': jnp.mean(W * td_error),
        f'{self.__class__.__name__}/td_error_targ':
            jnp.mean(-dLoss_dQ(Q, Q_targ, W)),
    }
    return loss, (td_error, state_new, metrics)
def svd_truncated(A, chi_max=None, cutoff=DEFAULT_CUTOFF, epsilon=DEFAULT_EPS,
                  return_norm_change=False):
    """
    Like `svd`, but keeps at most `chi_max` singular values and discards
    singular values below `cutoff`.
    """
    if return_norm_change:
        U, S, Vh, norm_change = svd_reduced(A, tolerance=cutoff,
                                            epsilon=epsilon,
                                            return_norm_change=True)
        if chi_max is not None:
            k = np.min([chi_max, len(S)])
            old_S_norm = np.linalg.norm(S)
            U = dynamic_slice_in_dim(U, 0, k, 1)
            S = dynamic_slice_in_dim(S, 0, k, 0)
            Vh = dynamic_slice_in_dim(Vh, 0, k, 0)
            norm_change *= np.linalg.norm(S) / old_S_norm  # FIXME is this correct?
        return U, S, Vh, norm_change
    else:
        U, S, Vh = svd_reduced(A, tolerance=cutoff, epsilon=epsilon,
                               return_norm_change=False)
        if chi_max is not None:
            k = np.min([chi_max, len(S)])
            U = dynamic_slice_in_dim(U, 0, k, 1)
            S = dynamic_slice_in_dim(S, 0, k, 0)
            Vh = dynamic_slice_in_dim(Vh, 0, k, 0)
        return U, S, Vh
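# For intuition, a self-contained sketch of the same truncation idea using
# jnp.linalg.svd directly (`svd_reduced`, `DEFAULT_CUTOFF` and
# `dynamic_slice_in_dim` come from the surrounding module and are not
# reproduced here; this version is not jit-safe since `keep` is data-dependent).
import jax.numpy as jnp

def svd_truncated_dense(A, chi_max=None, cutoff=1e-12):
    U, S, Vh = jnp.linalg.svd(A, full_matrices=False)
    keep = int(jnp.sum(S > cutoff))   # drop singular values below the cutoff
    if chi_max is not None:
        keep = min(keep, chi_max)     # cap the number of kept singular values
    return U[:, :keep], S[:keep], Vh[:keep, :]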
def actor_loss(policy_params: networks_lib.Params,
               critic_params: networks_lib.Params,
               sac_alpha: jnp.ndarray,
               transitions: types.Transition,
               key: jnp.ndarray) -> jnp.ndarray:
    """Computes the loss for the policy."""
    dist_params = networks.policy_network.apply(
        policy_params, transitions.observation)
    action = networks.sample(dist_params, key)
    log_prob = networks.log_prob(dist_params, action)
    q_action = networks.critic_network.apply(
        critic_params, transitions.observation, action)
    min_q = jnp.min(q_action, axis=-1)
    return jnp.mean(sac_alpha * log_prob - min_q)
def plot_posterior(self, X, ax, T):
    mu1, mu2, _, sigma1_p, sigma2_p, __ = self.compute_posterior_params(X, T)
    samples = self.generate_posterior_samples(40, X, T)
    spread_samples = samples.mean(axis=0) + (samples - samples.mean(axis=0)) * 2.5
    mi = np.min(spread_samples, axis=0)
    ma = np.max(spread_samples, axis=0)
    xs = np.linspace(mi[0], ma[0], 1000)
    ys = np.linspace(mi[1], ma[1], 1000)
    # Grid variables renamed to XX, YY to avoid shadowing the data argument X.
    XX, YY = np.meshgrid(xs, ys)
    Z = self.banana_density(XX, YY, mu1, mu2, sigma1_p, sigma2_p,
                            self.a, self.b, self.m)
    ax.contour(XX, YY, Z)
def single_interp(xstar):
    # [N] angular separations between the training points x and xstar.
    dx = great_circle_sep(ra1=x[:, 0] * jnp.pi / 180.,
                          dec1=x[:, 1] * jnp.pi / 180.,
                          ra2=xstar[0] * jnp.pi / 180.,
                          dec2=xstar[1] * jnp.pi / 180.)
    # dx = jnp.linalg.norm(xstar - x, axis=-1)
    # Nearest-neighbour distance, ignoring exact matches.
    nn_dist = jnp.min(jnp.where(dx == 0., jnp.inf, dx))
    dx = dx / nn_dist
    weight = jnp.exp(-0.5 * dx**2)
    if outliers is not None:
        weight = jnp.where(outliers, 0., weight)
    weight /= jnp.sum(weight)
    return jnp.sum(y * weight)
def improvement(self, y_batch):
    """Return how much a batch of y values can improve over the incumbent.

    Args:
        y_batch: (q, t) shaped array of t y values corresponding to a batch of
            size q of x-locations. Each column of this array corresponds to one
            realization of y values at q x-locations evaluated/predicted t
            times.

    Returns:
        (t,) shaped array of non-negative float values indicating the
        improvement the best of the q y values within each of the t
        realizations achieves over the incumbent.
    """
    difference = self.incumbent - jnp.min(y_batch, axis=0)
    return jnp.maximum(0.0, difference)
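# A small worked example (values are illustrative): with incumbent 0.5 and two
# realizations of y at q=3 locations, the per-realization improvement is
# max(0, incumbent - min(column)).
import jax.numpy as jnp

incumbent = 0.5
y_batch = jnp.array([[0.4, 0.9],
                     [0.7, 0.6],
                     [0.2, 0.8]])                  # shape (q=3, t=2)
best = jnp.min(y_batch, axis=0)                    # [0.2, 0.6]
improvement = jnp.maximum(0.0, incumbent - best)   # [0.3, 0.0]
print(improvement)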
def actor_loss(policy_params: networks_lib.Params,
               q_params: networks_lib.Params,
               alpha: jnp.ndarray,
               transitions: types.Transition,
               key: networks_lib.PRNGKey) -> jnp.ndarray:
    dist_params = networks.policy_network.apply(
        policy_params, transitions.observation)
    action = networks.sample(dist_params, key)
    log_prob = networks.log_prob(dist_params, action)
    q_action = networks.q_network.apply(
        q_params, transitions.observation, action)
    min_q = jnp.min(q_action, axis=-1)
    actor_loss = alpha * log_prob - min_q
    return jnp.mean(actor_loss)
def testShapesAndValues(self):
    agent = self._create_test_agent()
    self.assertEqual(agent._support.shape[0], self._num_atoms)
    self.assertEqual(jnp.min(agent._support), -self._vmax)
    self.assertEqual(jnp.max(agent._support), self._vmax)

    state = onp.ones((1, 28224))
    net_output = agent.online_network(state)

    self.assertEqual(net_output.logits.shape,
                     (1, self._num_actions, self._num_atoms))
    self.assertEqual(net_output.probabilities.shape, net_output.logits.shape)
    self.assertEqual(net_output.logits.shape[1], self._num_actions)
    self.assertEqual(net_output.logits.shape[2], self._num_atoms)
    self.assertEqual(net_output.q_values.shape, (1, self._num_actions))
def update_preconditioner(config, optimizer, p_update_grad_vars, rng, state,
                          train_iter):
    """Computes preconditioner state using samples from the dataloader."""
    # TODO(basv): support multiple hosts.
    values = jax.tree_map(jnp.zeros_like, optimizer.target)
    eps = config.precon_est_eps
    n_batches = config.precon_est_batches
    for _ in range(n_batches):
        rng, est_key = jax.random.split(rng)
        batch = next(train_iter)
        batch = input_pipeline.load_and_shard_tf_batch(config, batch)
        if not config.debug_run:
            # Shard the step PRNG key.
            sharded_keys = common_utils.shard_prng_key(est_key)
        else:
            sharded_keys = est_key
        values = p_update_grad_vars(optimizer, state, batch, sharded_keys,
                                    values)
    stds = jax.tree_map(
        lambda v: jnp.sqrt(eps + (1 / n_batches) * jnp.mean(v)), values)
    std_min = jnp.min(jnp.asarray(jax.tree_leaves(stds)))
    # TODO(basv): verify preconditioner estimate.
    new_precon = jax.tree_map(lambda s, x: jnp.ones_like(x) * (s / std_min),
                              stds, optimizer.target)

    def convert_momentum(new_precon, state):
        """Converts momenta to the new preconditioner."""
        if config.weight_norm == 'learned':
            state = state.direction_state
        old_precon = state.preconditioner
        momentum = state.momentum
        m_c = jnp.power(old_precon, -.5) * momentum
        m = jnp.power(new_precon, .5) * m_c
        return m

    # TODO(basv): verify momentum convert.
    new_momentum = jax.tree_map(convert_momentum, new_precon,
                                optimizer.state.param_states)
    # TODO(basv): verify this is replaced correctly, check replicated.
    optimizer = replace_param_state(config, optimizer,
                                    preconditioner=new_precon,
                                    momentum=new_momentum)
    return optimizer, rng
def get_peaks_single(history, tvec, int=0, Tint=0):
    """
    Calculates the peak prevalence for a single run, with or without an
    intervention.

    history: 2D array of values for each variable at each timepoint
    tvec: 1D vector of timepoints
    int: optional, 1 or 0 for whether or not there was an intervention.
        Defaults to 0. (Note: this parameter name shadows the builtin `int`.)
    Tint: optional, timepoint (days) at which the intervention was started
    """
    delta_t = tvec[1] - tvec[0]

    if int == 0:
        time_int = 0
    else:
        time_int = Tint

    # Final values
    print('Final recovered: {:3.1f}%'.format(100 * history[-1][6]))
    print('Final deaths: {:3.1f}%'.format(100 * history[-1][5]))
    print('Remaining infections: {:3.1f}%'.format(
        100 * np.sum(history[-1][1:5], axis=-1)))

    # Peak prevalence
    print('Peak I1: {:3.1f}%'.format(100 * np.max(history[:, 2])))
    print('Peak I2: {:3.1f}%'.format(100 * np.max(history[:, 3])))
    print('Peak I3: {:3.1f}%'.format(100 * np.max(history[:, 4])))

    # Time of peaks
    print('Time of peak I1: {:3.1f} days'.format(
        np.argmax(history[:, 2]) * delta_t - time_int))
    print('Time of peak I2: {:3.1f} days'.format(
        np.argmax(history[:, 3]) * delta_t - time_int))
    print('Time of peak I3: {:3.1f} days'.format(
        np.argmax(history[:, 4]) * delta_t - time_int))

    # First time when all infections go extinct
    all_cases = history[:, 1] + history[:, 2] + history[:, 3] + history[:, 4]
    extinct = np.where(all_cases == 0)[0]
    if len(extinct) != 0:
        extinction_time = np.min(extinct) * delta_t - time_int
        print('Time of extinction of all infections: {:3.1f} days'.format(
            extinction_time))
    else:
        print('Infections did not go extinct by end of simulation')
    return
def _test_online_smoothing_pf_full(self):
    if not hasattr(self, 'sim_samps'):
        self.sim_samps = self.ssm_scenario.simulate(self.t, random.PRNGKey(0))
    pf = BootstrapFilter()
    len_t = len(self.t)
    rkeys = random.split(random.PRNGKey(0), len_t)
    particles = initiate_particles(self.ssm_scenario, pf, self.n,
                                   rkeys[0], self.sim_samps.y[0], self.t[0])
    for i in range(1, len_t):
        particles = propagate_particle_smoother(self.ssm_scenario, pf,
                                                particles,
                                                self.sim_samps.y[i],
                                                self.t[i], rkeys[i], 3, False)

    npt.assert_array_less(
        ((self.sim_samps.x[:, 0]
          - jnp.max(particles.value, axis=1)[:, 0]) > 0).mean(), 0.1)
    npt.assert_array_less(
        ((self.sim_samps.x[:, 0]
          - jnp.min(particles.value, axis=1)[:, 0]) < 0).mean(), 0.1)
def get_domain_extension(
    data: np.ndarray,
    extension: Union[float, int],
) -> Tuple[float, float]:
    """Gets the extension for the support.

    Parameters
    ----------
    data : np.ndarray
        the input data from which to take the maximum and minimum
    extension : Union[float, int]
        the extension; an int is interpreted as a percentage

    Returns
    -------
    lb : float
        the new extended lower bound for the data
    ub : float
        the new extended upper bound for the data
    """
    # In the case of an int, interpret it as a percentage and convert to float.
    if isinstance(extension, int):
        extension = float(extension / 100)

    # get the domain
    domain = np.abs(np.max(data) - np.min(data))

    # extend the domain
    domain_ext = extension * domain

    # get the extended bounds (renamed `up` -> `ub` to match the docstring)
    lb = np.min(data) - domain_ext
    ub = np.max(data) + domain_ext

    return lb, ub
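# A quick worked example: data spanning [0, 10] with a 10% extension gives a
# domain of 10, an extension of 1.0, and bounds (-1.0, 11.0).
import numpy as np

data = np.array([0.0, 3.0, 10.0])
lb, ub = get_domain_extension(data, 10)  # int 10 -> extension of 0.10
print(lb, ub)  # -1.0 11.0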
def test_euclidean_point_cloud_parallel_weights(self, lse_mode):
    """Two point clouds, parallel execution for batched histograms."""
    self.rng, *rngs = jax.random.split(self.rng, 2)
    batch = 4
    a = jax.random.uniform(rngs[0], (batch, self.n))
    b = jax.random.uniform(rngs[0], (batch, self.m))
    a = a / jnp.sum(a, axis=1)[:, jnp.newaxis]
    b = b / jnp.sum(b, axis=1)[:, jnp.newaxis]
    threshold = 1e-3
    geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, online=True)
    # Note: as written, the batched histograms a and b built above are unused;
    # the solver is invoked with self.a and self.b.
    errors = sinkhorn.sinkhorn(
        geom, a=self.a, b=self.b, threshold=threshold,
        lse_mode=lse_mode).errors
    err = errors[errors > -1][-1]
    self.assertGreater(jnp.min(threshold - err), 0)
def minmax(self, data, max_val=None, min_val=None):
    if min_val is None:
        self.min = np.min(data, axis=0)
    else:
        self.min = min_val
    if max_val is None:
        self.max = np.max(data, axis=0)
    else:
        self.max = max_val
    minmax_data = (data - self.min) / (self.max - self.min)
    return minmax_data
def _assert_is_diagonal(self, j, axis1, axis2, constant_diagonal: bool):
    c = j.shape[axis1]
    self.assertEqual(c, j.shape[axis2])
    mask_shape = [c if i in (axis1, axis2) else 1 for i in range(j.ndim)]
    mask = np.eye(c, dtype=np.bool_).reshape(mask_shape)

    # Check that removing the diagonal makes the array all 0.
    j_masked = np.where(mask, np.zeros((), j.dtype), j)
    self.assertAllClose(np.zeros_like(j, j.dtype), j_masked)

    if constant_diagonal:
        # Check that the diagonal is constant.
        if j.size != 0:
            j_diagonals = np.diagonal(j, axis1=axis1, axis2=axis2)
            self.assertAllClose(np.min(j_diagonals, -1),
                                np.max(j_diagonals, -1))
def plot_mean_response(mus: np.ndarray, x: np.ndarray, y: np.ndarray,
                       response) -> plt.Figure:
    fig = plt.figure()
    plt.plot(x, y, '.', alpha=0.3, label='data')
    xx = np.linspace(np.min(x), np.max(x), 100)
    yy = []
    # Loop variable renamed from `x` to `xi` to avoid shadowing the data argument.
    for xi in xx:
        yy.append(response(mus[-1, :], xi))
    yy = np.array(yy).squeeze()
    plt.plot(xx, yy, label='mean response')
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title("Mean response")
    return fig
def compute_periodic_snr_debug_info(state):
    debug_info = OrderedDict()
    snr_state = state.snr_state
    if snr_state.snr_matrix is not None:
        w, v = jax_eig(snr_state.snr_matrix)
        D = jnp.linalg.inv(v) @ snr_state.snr_matrix @ v
        D = jnp.diag(D).real
        debug_info['debug_info/max_real_eig'] = jnp.max(D)
        debug_info['debug_info/min_real_eig'] = jnp.min(D)
        debug_info['debug_info/avg_real_eig'] = jnp.mean(D)
        debug_info['debug_info/std_real_eig'] = jnp.std(D)
        debug_info['debug_info/num_real_eig_gt_zero'] = jnp.sum(D > 0.)
        debug_info['debug_info/ratio_real_eig_gt_zero'] = jnp.mean(D > 0.)
    return debug_info
def get_bmu_distance_squares(self, bmu_loc):
    if self.periodic:
        # Periodic (toroidal) topology: compare against all wrap-around copies
        # of the map and keep the shortest distance.
        offsets = jnp.array([[0, 0],
                             [-self.m, -self.n], [self.m, self.n],
                             [-self.m, 0], [self.m, 0],
                             [0, -self.n], [0, self.n],
                             [-self.m, self.n], [self.m, -self.n]])
        # offsets = jax.device_put(offsets)
        offset_locations = jnp.array(
            [self.locations + offset for offset in offsets])
        batch_cdist = jax.vmap(jax_cdist_0, (None, 0))
        distances = batch_cdist(bmu_loc, offset_locations)
        min_dists = jnp.min(distances, axis=0)
        return min_dists
    else:
        distances = jax_cdist_0(bmu_loc, self.locations)
        return distances
def visualize_depth(depth,
                    acc=None,
                    near=None,
                    far=None,
                    curve_fn=lambda x: jnp.log(x + jnp.finfo(jnp.float32).eps),
                    modulus=0,
                    colormap=None):
    """Visualize a depth map.

    Args:
        depth: A depth map.
        acc: An accumulation map, in [0, 1].
        near: The depth of the near plane; if None, just use the min().
        far: The depth of the far plane; if None, just use the max().
        curve_fn: A curve function applied to `depth`, `near`, and `far` before
            the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).
        modulus: If > 0, mod the normalized depth by `modulus`. Use (0, 1].
        colormap: A colormap function. If None (default), matplotlib's viridis
            is used if modulus == 0, sinebow otherwise.

    Returns:
        An RGB visualization of `depth`.
    """
    # If `near` or `far` are None, identify the min/max non-NaN values.
    # Bug fix: the replacement value for NaNs must be passed via the `nan=`
    # keyword; positionally it would bind to `copy`.
    eps = jnp.finfo(jnp.float32).eps
    near = near or jnp.min(jnp.nan_to_num(depth, nan=jnp.inf)) - eps
    far = far or jnp.max(jnp.nan_to_num(depth, nan=-jnp.inf)) + eps

    # Curve all values.
    depth, near, far = [curve_fn(x) for x in [depth, near, far]]

    # Wrap the values around if requested.
    if modulus > 0:
        value = jnp.mod(depth, modulus) / modulus
        colormap = colormap or sinebow
    else:
        # Scale to [0, 1].
        value = jnp.nan_to_num(jnp.clip((depth - near) / (far - near), 0, 1))
        colormap = colormap or cm.get_cmap('viridis')

    vis = colormap(value)[:, :, :3]

    # Set non-accumulated pixels to white.
    if acc is not None:
        vis = vis * acc[:, :, None] + (1 - acc)[:, :, None]

    return vis
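# A minimal usage sketch (assuming `sinebow` and matplotlib's `cm` are in
# scope, as the function requires); the depth values are illustrative.
import jax.numpy as jnp

depth = jnp.linspace(1.0, 10.0, 64 * 64).reshape(64, 64)
acc = jnp.ones_like(depth)        # fully accumulated everywhere
rgb = visualize_depth(depth, acc=acc)
print(rgb.shape)  # (64, 64, 3)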
def resolve_clashes(x0, box0, min_dist=0.1):
    def urt(x, box):
        # Upper-right triangle of the pairwise distance matrix.
        distance_matrix = distance(x, box)
        i, j = np.triu_indices(len(distance_matrix), k=1)
        return distance_matrix[i, j]

    dij = urt(x0, box0)
    x_shape = x0.shape
    box_shape = box0.shape

    if np.min(dij) < min_dist:
        # print('some distances too small')
        print(
            f"before optimization: min(dij) = {np.min(dij)} < min_dist threshold ({min_dist})"
        )
        # print('smallest few distances', sorted(dij)[:10])

        def unflatten(xbox):
            n = x_shape[0] * x_shape[1]
            x = xbox[:n].reshape(x_shape)
            box = xbox[n:].reshape(box_shape)
            return x, box

        def U_repulse(xbox):
            x, box = unflatten(xbox)
            dij = urt(x, box)
            return np.sum(np.where(dij < min_dist, (dij - min_dist)**2, 0))

        def fun(xbox):
            v, g = value_and_grad(U_repulse)(xbox)
            return float(v), onp.array(g, onp.float64)

        initial_state = np.hstack([x0.flatten(), box0.flatten()])
        # print(f'penalty before: {U_repulse(initial_state)}')
        result = minimize(fun, initial_state, jac=True, method="L-BFGS-B")
        # print(f'penalty after minimization: {U_repulse(result.x)}')
        x, box = unflatten(result.x)
        dij = urt(x, box)
        print(f"after optimization: min(dij) = {np.min(dij)}")
        return x, box
    else:
        return x0, box0
def get_data_min_max(records):
    """Get the min and max for each feature across the dataset."""
    cache_path = os.path.join(records.processed_folder,
                              "minmax_" + str(records.quantization) + '.pt')
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as cache_file:
            data = pickle.load(cache_file)
        data_min, data_max = data
        return data_min, data_max

    data_min, data_max = None, None
    for b, (tt, vals, mask) in enumerate(records):
        if b % 100 == 0:
            print(b, len(records))
        n_features = vals.shape[-1]
        batch_min = []
        batch_max = []
        for i in range(n_features):
            non_missing_vals = vals[:, i][mask[:, i] == 1]
            if len(non_missing_vals) == 0:
                batch_min.append(jnp.inf)
                batch_max.append(-jnp.inf)
            else:
                batch_min.append(jnp.min(non_missing_vals))
                batch_max.append(jnp.max(non_missing_vals))
        batch_min = jnp.stack(batch_min)
        batch_max = jnp.stack(batch_max)

        if (data_min is None) and (data_max is None):
            data_min = batch_min
            data_max = batch_max
        else:
            data_min = jnp.minimum(data_min, batch_min)
            data_max = jnp.maximum(data_max, batch_max)

    with open(cache_path, "wb") as cache_file:
        pickle.dump((data_min, data_max), cache_file)
    return data_min, data_max
def next_mask(self, prev_sel, size, rng):
    # Choose the degrees of the next layer.
    max_connection = self.dim - 1 if self.triangular_jacobian == False else self.dim

    if self.method == "random":
        # Bug fix: the original referenced `sel` before assignment and the bare
        # name `dim`; use `prev_sel` and `self.dim` instead.
        sel = random.randint(rng,
                             shape=(size,),
                             minval=min(jnp.min(prev_sel), max_connection),
                             maxval=self.dim)
    elif "sequential" in self.method:
        sel = jnp.arange(size) % max(1, max_connection) + min(1, max_connection)
        if self.method == "shuffled_sequential":
            sel = random.permutation(rng, sel)
    else:
        assert 0, "Invalid mask method"

    # Create the new mask.
    mask = (prev_sel[:, None] <= sel).astype(jnp.int32)
    return mask, sel
def top_k_accuracy(top_k: int,
                   logits: JTensor,
                   label_ids: Optional[JTensor] = None,
                   label_probs: Optional[JTensor] = None,
                   weights: Optional[JTensor] = None) -> JTensor:
    """Computes the top-k accuracy given the logits and labels.

    Args:
        top_k: An int scalar, specifying the value of top-k.
        logits: A [..., C] float tensor corresponding to the logits.
        label_ids: A [...] int vector corresponding to the class labels. One of
            label_ids and label_probs should be presented.
        label_probs: A [..., C] float vector corresponding to the class
            probabilities. Must be presented if label_ids is None.
        weights: A [...] float vector corresponding to the weight to assign to
            each example. If None, each example is weighted equally.

    Returns:
        The top-k accuracy represented as a `JTensor`.

    Raises:
        ValueError: If neither `label_ids` nor `label_probs` are provided.
    """
    if label_ids is None and label_probs is None:
        raise ValueError("One of label_ids and label_probs should be given.")
    if label_ids is None:
        label_ids = jnp.argmax(label_probs, axis=-1)
    if weights is None:
        # Assumed fix: `weights` is optional but used unconditionally below, so
        # default to uniform weights.
        weights = jnp.ones(label_ids.shape, dtype=logits.dtype)

    values, _ = jax.lax.top_k(logits, k=top_k)
    threshold = jnp.min(values, axis=-1)

    # Reshape logits to [-1, C].
    logits_reshaped = jnp.reshape(logits, [-1, logits.shape[-1]])

    # Reshape label_ids to [-1, 1].
    label_ids_reshaped = jnp.reshape(label_ids, [-1, 1])
    logits_slice = jnp.take_along_axis(logits_reshaped,
                                       label_ids_reshaped,
                                       axis=-1)[..., 0]

    # Reshape logits_slice back to the original shape to be compatible with
    # weights.
    logits_slice = jnp.reshape(logits_slice, label_ids.shape)

    correct = jnp.greater_equal(logits_slice, threshold)
    correct_sum = jnp.sum(correct * weights)
    all_sum = jnp.maximum(1.0, jnp.sum(weights))
    return correct_sum / all_sum
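# A small worked example (illustrative values): with top_k=2, an example counts
# as correct when its label logit is >= the k-th largest logit in that row.
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.1],    # label 1 is 2nd best -> correct
                    [0.2, 0.3, 2.5]])   # label 0 is 3rd best -> incorrect
label_ids = jnp.array([1, 0])
acc = top_k_accuracy(2, logits, label_ids=label_ids)
print(acc)  # 0.5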
def beam_search_cond_fn(state):
    """Beam search state termination condition fn."""
    # 1. is less than max length?
    not_max_length_yet = state.cur_len < max_length

    # 2. can the new beams still improve?
    best_running_score = state.running_scores[:, -1:] / (
        max_length ** length_penalty)
    worst_finished_score = jnp.where(
        state.is_sent_finished,
        jnp.min(state.scores, axis=1, keepdims=True),
        np.array(-1.0e7))
    improvement_still_possible = jnp.all(
        worst_finished_score < best_running_score)

    # 3. is there still a beam that has not finished?
    still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

    return not_max_length_yet & still_open_beam & improvement_still_possible
def spin_wiener_filter(data_q, data_u, ncov_diag_Q, ncov_diag_U,
                       input_ps_map_E, input_ps_map_B, iterations=10):
    """
    Elsner-Wandelt messenger-field Wiener filter adapted for spin-2 fields
    (CMB polarization or galaxy weak lensing).

    Parameters
    ----------
    data_q : Q square image data (e.g. gamma1)
    data_u : U square image data (e.g. gamma2)
    ncov_diag_Q : Q noise variance per pixel (assumed uncorrelated)
    ncov_diag_U : U noise variance per pixel (assumed uncorrelated)
    input_ps_map_E : 1D power P(k) for the E-mode signal power spectrum
        evaluated at 2D components k1, k2 as a square image
    input_ps_map_B : 1D power P(k) for the B-mode signal power spectrum
        evaluated at 2D components k1, k2 as a square image
    iterations : number of iterations

    Returns
    -------
    s_q, s_u : Wiener-filtered Q and U signals
    """
    tcov_diag = jnp.min(jnp.array([ncov_diag_Q, ncov_diag_U]))
    scov_ft_E = jnp.fft.fftshift(input_ps_map_E)
    scov_ft_B = jnp.fft.fftshift(input_ps_map_B)
    s_q = jnp.zeros(data_q.shape)
    s_u = jnp.zeros(data_q.shape)
    for i in jnp.arange(iterations):
        # In the Q, U representation: mix data and current signal estimate
        # through the messenger field t.
        t_Q = (tcov_diag / ncov_diag_Q) * data_q + (
            (ncov_diag_Q - tcov_diag) / ncov_diag_Q) * s_q
        t_U = (tcov_diag / ncov_diag_U) * data_u + (
            (ncov_diag_U - tcov_diag) / ncov_diag_U) * s_u
        # In the E, B representation: apply the Wiener filter per mode.
        t_E, t_B = ks93(t_Q, t_U)
        s_E = (scov_ft_E / (scov_ft_E + tcov_diag)) * jnp.fft.fft2(t_E)
        s_B = (scov_ft_B / (scov_ft_B + tcov_diag)) * jnp.fft.fft2(t_B)
        s_E = jnp.fft.ifft2(s_E)
        s_B = jnp.fft.ifft2(s_B)
        # Back to the Q, U representation.
        s_q, s_u = ks93inv(s_E, s_B)
    return s_q, s_u
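# A hedged usage sketch, assuming `ks93`/`ks93inv` (Kaiser-Squires transforms)
# are available in the module; the flat power spectra and noise levels below
# are illustrative, not taken from the original source.
import jax.numpy as jnp

npix = 64
data_q = jnp.zeros((npix, npix))        # e.g. observed gamma1
data_u = jnp.zeros((npix, npix))        # e.g. observed gamma2
ncov_q = 0.1 * jnp.ones((npix, npix))   # per-pixel noise variance (Q)
ncov_u = 0.1 * jnp.ones((npix, npix))   # per-pixel noise variance (U)
ps_E = jnp.ones((npix, npix))           # flat E-mode power, for illustration
ps_B = 1e-3 * jnp.ones((npix, npix))    # small B-mode power
s_q, s_u = spin_wiener_filter(data_q, data_u, ncov_q, ncov_u, ps_E, ps_B,
                              iterations=5)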
def _decode(self, feats):
    b, t, d = feats.shape
    mean_val = jnp.mean(feats)
    max_val = jnp.max(feats)
    min_val = jnp.min(feats)
    frames = jnp.reshape(feats, [b * t, d])
    # Indices of entries above the threshold; `size` pads the result so the
    # shape is static. Padded indices land at 0, which is reset below.
    val = jnp.where(frames > 0.8, size=b * t * d)
    hist = jnp.zeros([d])
    hist = hist.at[val[1]].add(1)
    hist = hist.at[0].set(1)
    nframes = jnp.array(b * t)
    metrics = {
        'mean': (mean_val, nframes),
        'max': (max_val, nframes),
        'min': (min_val, nframes),
        'hist': (hist, nframes),
    }
    return metrics
def initialize_slices(T, b):
    B1s = []
    B2s = []
    B3s = []
    B2B3s = []
    bj0 = 0
    mindim = jnp.min(T.shape)
    for bj in range(0, mindim - b, b):
        bj0 = bj
        B1s.append(index[:bj])
        B2s.append(index[bj:bj + b])
        B3s.append(index[bj + b:])
        B2B3s.append(index[bj:])
    for bj in range(bj0 + b, mindim, b):
        B1s.append(index[:bj])
        B2B3s.append(index[bj:])
    return [B1s, B2s, B3s, B2B3s]