def random_adjacency(key: jnp.ndarray, num_nodes: int, num_edges: int, dtype=jnp.float32) -> COO:
    """
    Get the adjacency matrix of a random connected undirected graph.

    Note that `num_edges` is only approximate. The process of creating edges is:
        - sample `num_edges` random edges
        - remove self-edges
        - add ring edges
        - add reverse edges
        - filter duplicates

    Args:
        key: `jax.random.PRNGKey`.
        num_nodes: number of nodes in returned graph.
        num_edges: number of random internal edges initially added.
        dtype: dtype of returned JAXSparse.

    Returns:
        COO, shape (num_nodes, num_nodes), weights all ones.
    """
    shape = num_nodes, num_nodes
    # Sample flat indices (row * num_nodes + col) directly as integers.
    # Drawing float32 uniforms and casting to int32 cannot reach every integer
    # above 2**24, which makes many edges unreachable for num_nodes > 4096.
    internal_indices = jax.random.randint(
        key,
        shape=(num_edges,),
        minval=0,
        maxval=num_nodes**2,
        dtype=jnp.int32,
    )
    # remove randomly sampled self-edges.
    self_edges = (internal_indices // num_nodes) == (internal_indices % num_nodes)
    internal_indices = internal_indices[jnp.logical_not(self_edges)]

    # add a ring so we know the graph is connected
    r = jnp.arange(num_nodes, dtype=jnp.int32)
    ring_indices = r * num_nodes + (r + 1) % num_nodes
    indices = jnp.concatenate((internal_indices, ring_indices))

    # add reverse indices: swap (row, col) -> (col, row) for every edge
    coords = jnp.unravel_index(indices, shape)
    coords_rev = coords[::-1]
    indices_rev = jnp.ravel_multi_index(coords_rev, shape)
    indices = jnp.concatenate((indices, indices_rev))

    # filter out duplicates (also leaves the flat indices sorted)
    indices = jnp.unique(indices)

    row, col = jnp.unravel_index(indices, shape)
    return COO((jnp.ones((row.size,), dtype=dtype), row, col), shape=shape)
def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None,
                  subsampled_points=None):
    """Build decoder queries from the output position encoding.

    Args:
        inputs: array whose leading dimension is used as the batch size.
        modality_sizes: accepted but not used by this method.
        inputs_without_pos: array concatenated in front of the position
            embedding (last axis) when `self._concat_preprocessed_input` is set.
        subsampled_points: optional flat indices into `self._output_index_dim`;
            when given, only those positions are encoded.

    Returns:
        The position embedding, optionally prefixed with `inputs_without_pos`.

    Raises:
        ValueError: if concatenation is enabled but `inputs_without_pos` is None.
    """
    assert self._position_encoding_type != 'none'  # Queries come from elsewhere

    batch = inputs.shape[0]
    if subsampled_points is None:
        queries = self.output_pos_enc(batch_size=batch)
    else:
        # unravel_index yields one index array per output axis; stacking them
        # on axis 1 gives an [n, d] array of integer coordinates.
        per_axis = jnp.unravel_index(subsampled_points, self._output_index_dim)
        grid = jnp.stack(per_axis, axis=1)
        # Rescale the integer coordinates to the [-1, 1] range.
        extent = jnp.array(self._output_index_dim)[None, :]
        grid = -1 + 2 * grid / extent
        # Repeat the same coordinates for every batch element.
        grid = jnp.broadcast_to(grid[None], [batch, grid.shape[0], grid.shape[1]])
        queries = self.output_pos_enc(batch_size=batch, pos=grid)
        queries = jnp.reshape(queries, [queries.shape[0], -1, queries.shape[-1]])

    if self._concat_preprocessed_input:
        if inputs_without_pos is None:
            raise ValueError('Value is required for inputs_without_pos if'
                             ' concat_preprocessed_input is True')
        queries = jnp.concatenate([inputs_without_pos, queries], axis=-1)

    return queries
def _indices(key):
    """Draw `nse` random multi-dimensional indices into `sparse_shape`.

    Reads `sparse_shape`, `sparse_size`, `nse`, `n_sparse` and
    `unique_indices` from the enclosing scope.

    Returns:
        An array with one row per sampled entry and one column per sparse
        dimension; an uninitialised `(nse, n_sparse)` integer array when
        `sparse_shape` is empty.
    """
    if sparse_shape:
        # Sample flat positions, then expand them into per-axis coordinates.
        flat_positions = random.choice(
            key,
            sparse_size,
            shape=(nse, ),
            replace=not unique_indices,
        )
        per_axis = jnp.unravel_index(flat_positions, sparse_shape)
        return jnp.column_stack(per_axis)
    return jnp.empty((nse, n_sparse), dtype=int)
def index_to_coordinate_array(idx, offset=4, repeat=1):
    """Turn an array of index values into a tuple of coordinate arrays.

    Args:
        idx: array whose first three dimensions are read as (H, W, C).
        offset: spacing between consecutive cells in the flat index space; it
            is also the size of the last axis of the unravelled shape.
        repeat: how many times each cell's base position is repeated before
            being added to the flattened `idx`.

    Returns:
        Tuple of four coordinate arrays indexing a `(H, W, C, offset)` array.
    """
    height, width, channels = idx.shape[:3]
    num_cells = height * width * channels
    # The input indices will be spread out by some offset: each cell gets its
    # own base position, and the index value selects a slot relative to it.
    cell_base = offset * jnp.arange(num_cells).repeat(repeat)
    flat = idx.ravel() + cell_base
    return jnp.unravel_index(flat, (height, width, channels, offset))
def get_cluster(prototypes, prototypes_density):
    """Pick the prototype located at the maximum of the density.

    Args:
        prototypes: array indexed as `[0, dim, *grid_position]`
            (assumes shape (1, num_dims, *prototypes_density.shape) — TODO confirm).
        prototypes_density: array whose argmax selects the grid position.

    Returns:
        A pair `(cluster, density)`: the selected prototype reshaped to
        `(1, num_dims, 1, ..., 1)` with `num_dims` trailing singleton axes, and
        the density value at the selected position.
    """
    num_dims = prototypes.shape[1]

    # Position of the density peak, as one index per density axis.
    flat_peak = jnp.argmax(prototypes_density)
    peak = jnp.unravel_index(flat_peak, prototypes_density.shape)

    # Gather every coordinate component of the prototype at the peak.
    all_components = tuple(range(num_dims))
    picked = prototypes[(0, all_components, *peak)]

    # Re-insert singleton axes so the result broadcasts like `prototypes`.
    singleton_axes = tuple(range(1, num_dims + 1))
    cluster = jnp.expand_dims(picked, axis=singleton_axes)[jnp.newaxis, ...]
    return cluster, prototypes_density[peak]
def propose_spin_flip_Z2(key, s, info):
    """Propose a single spin flip, occasionally followed by a global flip.

    Args:
        key: `jax.random.PRNGKey` driving both random choices.
        s: array of spins with values in {0, 1}.
        info: unused.

    Returns:
        A copy of `s` with one uniformly chosen site flipped; with
        probability 1/5 the whole configuration is additionally inverted.
    """
    idxKey, flipKey = jax.random.split(key)
    idx = random.randint(idxKey, (1, ), 0, s.size)[0]
    idx = jnp.unravel_index(idx, s.shape)
    update = (s[idx] + 1) % 2
    # `jax.ops.index_update` was removed from JAX; `.at[].set()` is the
    # supported functional-update API with the same semantics.
    s = jnp.asarray(s).at[idx].set(update)
    # On average, do a global spin flip every 5 updates to
    # reflect Z_2 symmetry
    doFlip = random.randint(flipKey, (1, ), 0, 5)[0]
    return jax.lax.cond(doFlip == 0, lambda x: 1 - x, lambda x: x, s)
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
    """Run a gather and, if enabled, record an out-of-bounds index error.

    The gather itself is always performed via `lax.gather_p.bind`. When
    `ErrorCategory.OOB` is in `enabled_errors`, the start indices are checked
    against the operand shape and the first offending index (if any) is
    attached to the returned error through `assert_func`.

    Returns:
        `(out, error)` where `out` is the gather result and `error` is either
        the incoming error (check disabled) or the result of `assert_func`.
    """
    out = lax.gather_p.bind(
        operand,
        start_indices,
        dimension_numbers=dimension_numbers,
        slice_sizes=slice_sizes,
        unique_indices=unique_indices,
        indices_are_sorted=indices_are_sorted,
        mode=mode,
        fill_value=fill_value,
    )

    if ErrorCategory.OOB not in enabled_errors:
        return out, error

    # compare to OOB masking logic in lax._gather_translation_rule
    index_map = np.array(dimension_numbers.start_index_map)
    # Largest valid start along each indexed axis: axis size minus slice size.
    max_start = np.array(operand.shape)[index_map] - np.array(slice_sizes)[index_map]
    # Prepend one singleton axis per batch dimension of `start_indices`.
    batch_axes = tuple(range(len(start_indices.shape) - 1))
    max_start = jnp.expand_dims(max_start, axis=batch_axes)
    in_bounds = (start_indices >= 0) & (
        start_indices <= max_start.astype(start_indices.dtype))

    # Get first OOB index, axis and axis size so it can be added to the error msg.
    # argmin over booleans returns the position of the first False entry.
    first_bad = jnp.argmin(in_bounds)
    first_bad_nd = jnp.unravel_index(first_bad, start_indices.shape)
    oob_axis = jnp.array(dimension_numbers.start_index_map)[first_bad_nd[-1]]
    oob_axis_size = jnp.array(operand.shape)[oob_axis]
    oob_index = jnp.ravel(start_indices)[first_bad]
    payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)

    # Only the first two pieces are f-strings; the {payloadN} placeholders are
    # left literal for later substitution.
    msg = (f'out-of-bounds indexing at {summary()} for array of '
           f'shape {operand.shape}: '
           'index {payload0} is out of bounds for axis {payload1} '
           'with size {payload2}.')

    return out, assert_func(error, jnp.all(in_bounds), msg, payload)
def unravelindex(indices, dims, order=order):
    """Unravel flat `indices` into coordinates for an array of shape `dims`.

    `order` is accepted for signature compatibility only; it is not forwarded
    to `jnp.unravel_index`.
    """
    del order  # intentionally unused
    return jnp.unravel_index(indices, shape=dims)
def loss_fn(
    model,
    padded_example_and_rng,
    static_metadata,
    regularization_weights = None,
    reinforce_weight = 1.0,
    baseline_weight = 0.001,
):
  """Loss function for multi-pointer task.

  The loss is the negative joint log-probability of the (bug location, repair)
  pair, plus optional regularization of collected side outputs and, when the
  model emits single-sample side outputs, a REINFORCE term and a baseline
  penalty.

  Args:
    model: The model to evaluate.
    padded_example_and_rng: Padded example to evaluate on, with a PRNGKey.
    static_metadata: Padding configuration for the example, since this may vary
      for different examples.
    regularization_weights: Associates side output key regexes with
      regularization penalties.
    reinforce_weight: Weight to give to the reinforce term.
    baseline_weight: Weight to give to the baseline.

  Returns:
    Tuple of loss and metrics.

  Raises:
    ValueError: If a regularization query matches no collected side output.
  """
  padded_example, rng = padded_example_and_rng

  # Run the model.
  with side_outputs.collect_side_outputs() as collected_side_outputs:
    with flax.nn.stochastic(rng):
      joint_log_probs = model(padded_example, static_metadata)

  # Computing the loss:
  # Extract logits for the correct location.
  log_probs_at_bug = joint_log_probs[padded_example.bug_node_index, :]
  # Compute p(repair) = sum[ p(node) p(repair | node) ]
  # -> log p(repair) = logsumexp[ log p(node) + log p (repair | node) ]
  # Adding log(mask) sends masked-out entries to -inf so they drop out of the
  # logsumexp.
  log_prob_joint = jax.scipy.special.logsumexp(
      log_probs_at_bug + jnp.log(padded_example.repair_node_mask))

  # Metrics:
  # Marginal log probabilities:
  log_prob_bug = jax.scipy.special.logsumexp(log_probs_at_bug)
  log_prob_repair = jax.scipy.special.logsumexp(
      jax.scipy.special.logsumexp(joint_log_probs, axis=0) +
      jnp.log(padded_example.repair_node_mask))

  # Conditional log probabilities:
  log_prob_repair_given_bug = log_prob_joint - log_prob_bug
  log_prob_bug_given_repair = log_prob_joint - log_prob_repair

  # Majority accuracy (1 if we assign the correct tuple > 50%):
  # (note that this is easier to compute, since we can't currently aggregate
  # probability separately for each candidate.)
  log_half = jnp.log(0.5)
  majority_acc_joint = log_prob_joint > log_half

  # Probabilities associated with each node.
  node_node_probs = jnp.exp(joint_log_probs)
  # Accumulate across unique candidates by identifier. This has the same shape,
  # but only the first few values will be populated.
  node_candidate_probs = padded_example.unique_candidate_operator.apply_add(
      in_array=node_node_probs,
      out_array=jnp.zeros_like(node_node_probs),
      in_dims=[1],
      out_dims=[1])

  # Classify: 50% decision boundary
  # Row 0 and column 0 are zeroed before summing; `actual_nobug` below treats
  # bug_node_index == 0 as the "no bug" case.
  only_buggy_probs = node_candidate_probs.at[0, :].set(0).at[:, 0].set(0)
  p_buggy = jnp.sum(only_buggy_probs)
  pred_nobug = p_buggy <= 0.5

  # Localize/repair: take most likely bug position, conditioned on being buggy
  pred_bug_loc, pred_cand_id = jnp.unravel_index(
      jnp.argmax(only_buggy_probs), only_buggy_probs.shape)

  actual_nobug = jnp.array(padded_example.bug_node_index == 0)
  actual_bug = jnp.logical_not(actual_nobug)
  pred_bug = jnp.logical_not(pred_nobug)

  metrics = {
      'nll/joint':
          -log_prob_joint,
      'nll/marginal_bug':
          -log_prob_bug,
      'nll/marginal_repair':
          -log_prob_repair,
      'nll/repair_given_bug':
          -log_prob_repair_given_bug,
      'nll/bug_given_repair':
          -log_prob_bug_given_repair,
      'inaccuracy/legacy_overall':
          1 - majority_acc_joint,
      # Correct if either "no bug" was predicted for a bug-free example, or
      # the bug was detected, localized and repaired correctly.
      'inaccuracy/overall':
          (~((actual_nobug & pred_nobug) |
             (actual_bug & pred_bug &
              (pred_bug_loc == padded_example.bug_node_index) &
              (pred_cand_id == padded_example.repair_id)))),
      'inaccuracy/classification_overall': (actual_nobug != pred_nobug),
      'inaccuracy/classification_given_nobug':
          train_util.RatioMetric(
              numerator=(actual_nobug & ~pred_nobug), denominator=actual_nobug),
      'inaccuracy/classification_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug & ~pred_bug), denominator=actual_bug),
      'inaccuracy/localized_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug &
                         ~(pred_bug_loc == padded_example.bug_node_index)),
              denominator=actual_bug),
      'inaccuracy/repaired_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug &
                         ~(pred_cand_id == padded_example.repair_id)),
              denominator=actual_bug),
      'inaccuracy/localized_repaired_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug &
                         ~((pred_bug_loc == padded_example.bug_node_index) &
                           (pred_cand_id == padded_example.repair_id))),
              denominator=actual_bug),
      'inaccuracy/overall_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug &
                         ~(pred_bug &
                           (pred_bug_loc == padded_example.bug_node_index) &
                           (pred_cand_id == padded_example.repair_id))),
              denominator=actual_bug),
  }

  loss = -log_prob_joint

  # Surface every scalar side output as a metric.
  for k, v in collected_side_outputs.items():
    # Flax collection keys will start with "/".
    if v.shape == ():  # pylint: disable=g-explicit-bool-comparison
      metrics['side' + k] = v

  if regularization_weights:
    total_regularization = 0
    for query, weight in regularization_weights.items():
      logging.info('Regularizing side outputs matching query %s', query)
      found = False
      for k, v in collected_side_outputs.items():
        if re.search(query, k):
          found = True
          logging.info('Regularizing %s with weight %f', k, weight)
          total_regularization += weight * v
      # A query that matches nothing is treated as a configuration error.
      if not found:
        raise ValueError(
            f'Regularization query {query} did not match any side output. '
            f'Side outputs were {set(collected_side_outputs.keys())}')

    loss = loss + total_regularization

  is_single_sample = any(
      k.endswith('one_sample_log_prob_per_edge_per_node')
      for k in collected_side_outputs)
  if is_single_sample:
    # The single-element unpacking requires exactly one matching side output
    # for each of the two suffixes.
    log_prob, = [
        v for k, v in collected_side_outputs.items()
        if k.endswith('one_sample_log_prob_per_edge_per_node')
    ]
    baseline, = [
        v for k, v in collected_side_outputs.items()
        if k.endswith('one_sample_reward_baseline')
    ]

    # Zero out log-probabilities belonging to padding nodes before summing.
    num_real_nodes = padded_example.input_graph.bundle.graph_metadata.num_nodes
    valid_mask = (
        jnp.arange(static_metadata.bundle_padding.static_max_metadata.num_nodes)
        < num_real_nodes)
    log_prob = jnp.where(valid_mask[None, :], log_prob, 0)
    total_log_prob = jnp.sum(log_prob)

    reinforce_virtual_cost = (
        total_log_prob * jax.lax.stop_gradient(loss - baseline))
    baseline_penalty = jnp.square(loss - baseline)

    # Subtracting the stop-gradient copy makes this term evaluate to zero
    # while keeping its gradient, so it changes the gradient of the loss but
    # not its value.
    reinforce_virtual_cost_zeroed = reinforce_virtual_cost - jax.lax.stop_gradient(
        reinforce_virtual_cost)

    loss = (
        loss + reinforce_weight * reinforce_virtual_cost_zeroed +
        baseline_weight * baseline_penalty)
    metrics['reinforce_virtual_cost'] = reinforce_virtual_cost
    metrics['baseline_penalty'] = baseline_penalty
    metrics['baseline'] = baseline
    metrics['total_log_prob'] = total_log_prob

  # Cast every metric leaf (including booleans) to float32.
  metrics = jax.tree_map(lambda x: x.astype(jnp.float32), metrics)
  return loss, metrics
def unravel_index(indices, shape):
    """Unravel flat `indices` for `shape` via `jnp.unravel_index`.

    Both arguments are first passed through `_remove_jaxarray`.
    """
    plain_indices, plain_shape = (
        _remove_jaxarray(arg) for arg in (indices, shape)
    )
    return jnp.unravel_index(plain_indices, plain_shape)
import mpi4jax # noqa: E402 # # MPI setup # supported_nproc = (1, 2, 4, 6, 8, 16) if mpi_size not in supported_nproc: raise RuntimeError(f"Got invalid number of MPI processes: {mpi_size}. " f"Please choose one of these: {supported_nproc}.") nproc_y = min(mpi_size, 2) nproc_x = mpi_size // nproc_y proc_idx = jnp.unravel_index(mpi_rank, (nproc_y, nproc_x)) # # Grid setup # # we use 1 cell overlap on each side of the domain nx_global = 360 + 2 ny_global = 180 + 2 # grid spacing in metres dx = 5e3 dy = 5e3 # make sure processes divide the domain evenly assert (nx_global - 2) % nproc_x == 0
def body(state: State):
    """Perform one greedy point reassignment and return the updated state.

    Picks the (destination cluster, source cluster, point) triple with the
    smallest masked `delta_F`, moves that point, and incrementally updates the
    per-cluster counts, means, `C_k`/`logdetC_k` and log-Mahalanobis terms so
    that the state stays consistent.

    Reads `log_factor_k`, `log_ellipsoid_volume`, `rank_one_update_matrix_inv`,
    `points`, `mask`, `D`, `log_VS` and `num_S` from the enclosing scope.

    Args:
        state: the current `State`.

    Returns:
        A new `State` with `i` incremented, the `done` flag set and all
        cluster statistics refreshed.
    """
    # upon the start of each iteration the state is consistent.
    # we use the consistent state to calculate the reassignment metrics.
    # we then reassign and update the state so that it is consistent again.

    # K
    log_f_k = log_factor_k(state.cluster_id, state.log_maha_k, state.num_k,
                           state.logdetC_k)

    def single_log_h(log_f_k, log_maha_k, num_k, logdetC_k):
        log_d = log_maha_k + log_f_k
        log_VS_k = log_VS + jnp.log(num_k) - jnp.log(num_S)
        return log_ellipsoid_volume(logdetC_k, num_k, log_f_k) + log_d - log_VS_k

    # K, N
    log_h_k = vmap(single_log_h)(log_f_k, state.log_maha_k, state.num_k,
                                 state.logdetC_k)
    h_k = jnp.exp(log_h_k)
    # K, K, N
    delta_F = h_k[:, None, :] - h_k
    # Can reassign if mask says we are working on that node and there would be at least D+1 points in that cluster
    # after taking from it. And, if delta_F < 0.
    able_to_reassign = mask & (state.num_k[state.cluster_id] > D + 1)
    delta_F_masked = jnp.where(able_to_reassign, delta_F, jnp.inf)

    # Equivalent to jnp.where(delta_F == min_delta_F), but with static shapes.
    (k_to, k_from, n_reassign) = jnp.unravel_index(
        jnp.argmin(delta_F_masked.flatten()), delta_F.shape)

    # dynamic update index arrays of sufficient length for all
    dyn_k_to_idx = jnp.concatenate([k_to[None], jnp.asarray([0, 0])])
    dyn_k_from_idx = jnp.concatenate([k_from[None], jnp.asarray([0, 0])])

    ###
    # update the state

    ###
    # cluster id: point n_reassign now belongs to cluster k_to
    cluster_id = dynamic_update_slice(state.cluster_id, dyn_k_to_idx[0:1],
                                      n_reassign[None])

    ###
    # num_k: the source loses one point, the destination gains one.
    num_from = state.num_k[k_from] - 1
    # The destination count must come from k_to (it was previously derived
    # from k_from, which corrupted num_k and the precision scaling below).
    num_to = state.num_k[k_to] + 1
    num_k = dynamic_update_slice(state.num_k, num_from[None], dyn_k_from_idx[0:1])
    num_k = dynamic_update_slice(num_k, num_to[None], dyn_k_to_idx[0:1])

    ###
    # ellipsoid parameters
    x_n = points[n_reassign, :]
    mu_from = state.mu_k[k_from, :] + (state.mu_k[k_from, :] - x_n) / (
        state.num_k[k_from] - 1)
    C_from, logdetC_from = rank_one_update_matrix_inv(
        state.C_k[k_from, :, :],
        state.logdetC_k[k_from],
        x_n - mu_from,
        x_n - state.mu_k[k_from, :],
        add=False)
    mu_to = state.mu_k[k_to, :] + (x_n - state.mu_k[k_to, :]) / (
        state.num_k[k_to] + 1)
    C_to, logdetC_to = rank_one_update_matrix_inv(
        state.C_k[k_to, :, :],
        state.logdetC_k[k_to],
        x_n - mu_to,
        x_n - state.mu_k[k_to, :],
        add=True)
    # NOTE(review): debug output kept from the original implementation.
    print('from', state.logdetC_k[k_from])
    mu_k = dynamic_update_slice(state.mu_k, mu_from[None, :], dyn_k_from_idx[0:2])
    mu_k = dynamic_update_slice(mu_k, mu_to[None, :], dyn_k_to_idx[0:2])
    C_k = dynamic_update_slice(state.C_k, C_from[None, :, :], dyn_k_from_idx)
    C_k = dynamic_update_slice(C_k, C_to[None, :, :], dyn_k_to_idx)
    logdetC_k = dynamic_update_slice(state.logdetC_k, logdetC_from[None],
                                     dyn_k_from_idx[0:1])
    logdetC_k = dynamic_update_slice(logdetC_k, logdetC_to[None],
                                     dyn_k_to_idx[0:1])

    ###
    # maha: recompute the log-Mahalanobis terms of every point for the two
    # clusters that changed.
    precision_from = C_from * num_from
    precision_to = C_to * num_to
    log_maha_from = jnp.log(
        vmap(lambda point: (point - mu_from) @ precision_from @ (
            point - mu_from))(points))
    log_maha_to = jnp.log(
        vmap(lambda point: (point - mu_to) @ precision_to @ (point - mu_to))(
            points))
    log_maha_k = dynamic_update_slice(state.log_maha_k, log_maha_from[None, :],
                                      dyn_k_from_idx[0:2])
    log_maha_k = dynamic_update_slice(log_maha_k, log_maha_to[None, :],
                                      dyn_k_to_idx[0:2])

    # estimate volumes of current clustering
    log_f_k = log_factor_k(cluster_id, log_maha_k, num_k, logdetC_k)
    log_VE_k = vmap(log_ellipsoid_volume)(logdetC_k, num_k, log_f_k)
    log_VS_k = jnp.log(num_k) - jnp.log(num_S)
    log_V_sum = logsumexp(log_VE_k)
    new_loss = log_V_sum - log_VS
    # `delay` counts consecutive iterations without a new best loss.
    loss_decreased = new_loss < state.min_loss
    delay = jnp.where(loss_decreased, 0, state.delay + 1)
    min_loss = jnp.where(loss_decreased, new_loss, state.min_loss)
    # NOTE(review): debug output kept from the original implementation.
    print(jnp.min(delta_F_masked), log_V_sum, logdetC_k)
    # Stop when nothing changed, the loss stalled for 10 iterations, a cluster
    # became too small, the volume is NaN, or no move has negative delta_F.
    done = (jnp.all(cluster_id == state.cluster_id)
            | (delay >= 10)
            | jnp.any(num_k < D + 1)
            | jnp.isnan(log_V_sum)
            | (jnp.min(delta_F_masked) >= 0.))
    # ['i', 'done', 'cluster_id', 'C_k', 'logdetC_k',
    # 'mu_k', 'log_maha_k', 'num_k',
    # 'log_VE_k', 'log_VS_k',
    # 'min_loss', 'delay']
    state = state._replace(i=state.i + 1,
                           done=done,
                           cluster_id=cluster_id,
                           C_k=C_k,
                           logdetC_k=logdetC_k,
                           mu_k=mu_k,
                           log_maha_k=log_maha_k,
                           num_k=num_k,
                           log_VE_k=log_VE_k,
                           log_VS_k=log_VS_k,
                           min_loss=min_loss,
                           delay=delay)
    return state
def propose_spin_flip(key, s, info):
    """Propose a new configuration by flipping one uniformly chosen spin.

    Args:
        key: `jax.random.PRNGKey` used to pick the site.
        s: array of spins with values in {0, 1}.
        info: unused.

    Returns:
        A copy of `s` in which the chosen site is mapped 0 -> 1 or 1 -> 0.
    """
    idx = random.randint(key, (1, ), 0, s.size)[0]
    idx = jnp.unravel_index(idx, s.shape)
    update = (s[idx] + 1) % 2
    # `jax.ops.index_update` was removed from JAX; `.at[].set()` is the
    # supported functional-update API with the same semantics.
    return jnp.asarray(s).at[idx].set(update)