def cond_func(val): """Whether the while loop should continue.""" # We continue the greedy search iff both: # (1) We have yet to exceed the max steps set by p.decoder.seqlen, AND; # (2) At least one row in the batch has not terminated. length_ok = val.step < seq_len - 1 all_rows_done = jnp.all(val.done) return jnp.logical_and(length_ok, jnp.logical_not(all_rows_done))
def assert_func(error: Error, pred: Bool, msg: str,
                payload: Optional[Payload]) -> Error:
  code = next_code()
  payload = init_payload if payload is None else payload
  out_err = error.err | jnp.logical_not(pred)
  out_code = lax.select(error.err, error.code, code)
  out_payload = lax.select(error.err, error.payload, payload)
  return Error(out_err, out_code, {code: msg, **error.msgs}, out_payload)
def add_ones_to_line(single_group):
  indices, = np.where(np.logical_not(single_group))
  k_zeros = len(indices)
  groups = (np.zeros((k_zeros, single_group.shape[0]), dtype=bool)
            + single_group[np.newaxis, :])
  groups = jax.ops.index_update(
      groups, jax.ops.index[np.arange(0, k_zeros), indices], True)
  return groups
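# Note (assumption): `jax.ops.index_update` has been removed from newer JAX
# releases; on current versions the same update can be written with the
# `.at[...]` API, e.g.:
#   groups = jnp.asarray(groups).at[np.arange(k_zeros), indices].set(True)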
def assert_batching_rule(batched_args, batch_dims, *, msgs):
  size = next(x.shape[dim] for x, dim in zip(batched_args, batch_dims)
              if dim is not batching.not_mapped)
  pred, code, payload = (batching.bdim_at_front(a, d, size)
                         for a, d in zip(batched_args, batch_dims))
  err = Error(jnp.logical_not(pred), code, msgs, payload)
  check_error(err)
  return [], []
def simplex_proposal(self, rng_key, x, grad_, hess_):
  max_non_diag_hess = np.max(
      hess_[np.logical_not(np.eye(hess_.shape[0], dtype=bool))].reshape(
          hess_.shape[0], -1),
      axis=1)
  concentration = 1 - x**2 * (np.diag(hess_) - max_non_diag_hess)
  dist_ = dist.Dirichlet(concentration=concentration + W_CORRECTION)
  return dist_.sample(rng_key).reshape(x.shape), dist_
def cond_func(val):
  q, i, norm_delta_q, error = val
  diverged = np.logical_or(error > divergence_tol, np.isnan(error))
  converged = np.logical_and(error < convergence_tol,
                             norm_delta_q < position_tol)
  return np.logical_not(
      np.logical_or(i >= max_iters, np.logical_or(diverged, converged)))
def get_init_state(self, batch_size: int):
  """Returns jax object with initial state.

  Args:
    batch_size: size of a batch

  Returns:
    [batch_size, 2] jax array with initial coordinates
  """
  _onp_segments = onp.asarray(self.jax_scene.segments)
  eps = self.config["constants"]["radius"] / 2
  max_x, min_x = onp.max(_onp_segments[:, :, 0]), onp.min(_onp_segments[:, :, 0])
  max_y, min_y = onp.max(_onp_segments[:, :, 1]), onp.min(_onp_segments[:, :, 1])
  remaining_idxs = np.arange(batch_size)
  init_proposal = onp.random.uniform((min_x - 1, min_y - 1),
                                     (max_x + 1, max_y + 1),
                                     size=(batch_size, 2))
  proposal_jax = np.asarray(init_proposal)
  while True:
    is_inner = if_points_inside_any_polygon(proposal_jax, self.jax_scene)
    _, distance = find_closest_segment_to_points_batch(
        proposal_jax, self.jax_scene.segments)
    acceptable = onp.asarray(
        np.logical_not(
            np.logical_or(
                is_inner,
                distance <= self.config["constants"]["radius"] + eps)))
    init_proposal[remaining_idxs[acceptable], :] = onp.array(
        proposal_jax)[acceptable, :]
    if np.all(acceptable):
      break
    logger.debug("Resampling starting position")
    remaining_idxs = remaining_idxs[np.logical_not(acceptable)]
    proposal_jax = np.asarray(
        onp.random.uniform((min_x, min_y), (max_x, max_y),
                           size=(len(remaining_idxs), 2)))
  return np.array(init_proposal)
def state_needs_iteration(self, theta: Parameters,
                          augmented: TheAugmentedState) -> bool:
  """
  Args:
    theta: The parameters.
    augmented: The state.
  Returns:
    True while iteration needs to continue.
  """
  enough_iterations = augmented.iterations >= self.minimum_iterations
  converged = self.converged(augmented)
  not_too_many_iterations = augmented.iterations < self.maximum_iterations
  return jnp.logical_and(
      not_too_many_iterations,
      jnp.logical_not(jnp.logical_and(enough_iterations, converged)))
def random_adjacency(key: jnp.ndarray,
                     num_nodes: int,
                     num_edges: int,
                     dtype=jnp.float32) -> COO:
  """Get the adjacency matrix of a random connected undirected graph.

  Note that `num_edges` is only approximate. The process of creating edges is:
    - sample `num_edges` random edges
    - remove self-edges
    - add ring edges
    - add reverse edges
    - filter duplicates

  Args:
    key: `jax.random.PRNGKey`.
    num_nodes: number of nodes in returned graph.
    num_edges: number of random internal edges initially added.
    dtype: dtype of returned JAXSparse.

  Returns:
    COO, shape (num_nodes, num_nodes), weights all ones.
  """
  shape = num_nodes, num_nodes
  internal_indices = jax.random.uniform(
      key,
      shape=(num_edges,),
      dtype=jnp.float32,
      maxval=num_nodes**2,
  ).astype(jnp.int32)
  # remove randomly sampled self-edges.
  self_edges = (internal_indices // num_nodes) == (internal_indices % num_nodes)
  internal_indices = internal_indices[jnp.logical_not(self_edges)]
  # add a ring so we know the graph is connected
  r = jnp.arange(num_nodes, dtype=jnp.int32)
  ring_indices = r * num_nodes + (r + 1) % num_nodes
  indices = jnp.concatenate((internal_indices, ring_indices))
  # add reverse indices
  coords = jnp.unravel_index(indices, shape)
  coords_rev = coords[-1::-1]
  indices_rev = jnp.ravel_multi_index(coords_rev, shape)
  indices = jnp.concatenate((indices, indices_rev))
  # filter out duplicates
  indices = jnp.unique(indices)
  row, col = jnp.unravel_index(indices, shape)
  return COO((jnp.ones((row.size,), dtype=dtype), row, col), shape=shape)
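# Usage sketch (assumption): build a small random graph and inspect the result.
# `COO` is whatever sparse container the surrounding module imports.
key = jax.random.PRNGKey(0)
adj = random_adjacency(key, num_nodes=8, num_edges=16)
print(adj.shape)  # (8, 8)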
def __init__(self, v, s):
  self.v = v.toDict() if isinstance(v, DotMap) else v
  self.s = s.toDict() if isinstance(s, DotMap) else s
  self.c = {}
  merge(self.c, s)
  merge(self.c, v)
  self.c_flat, self.idx, self.tree = flatten(self.c)
  self.v_tree = nan_like(self.c)
  merge(self.v_tree, v)
  self.v_flat, _, _ = flatten(self.v_tree)
  self.update_idx = jnp.where(jnp.logical_not(jnp.isnan(self.v_flat)))
  self.x = self.v_flat[self.update_idx]
def get_rotation_pytree(src: Any, dst: Any) -> Any:
  """Takes two n-dimensional vectors/pytrees and returns an nxn rotation matrix
  mapping src to dst. Raises Value Error when unsuccessful.
  """

  def __assert_rotation(R):
    if R.ndim != 2:
      print("R must be a matrix")
    a, b = R.shape
    if a != b:
      print("R must be square")
    if (not jnp.isclose(
        jnp.abs(jnp.eye(a) - jnp.dot(R, R.T)).max(), 0.0, rtol=0.5)) or (
            not jnp.isclose(
                jnp.abs(jnp.eye(a) - jnp.dot(R.T, R)).max(), 0.0, rtol=0.5)):
      print("R is not orthogonal")

  if not pytree_shape_array_equal(src, dst):
    print("src and dst must be 1-dimensional arrays with the same shape.")
  x = pytree_normalized(src)
  y = pytree_normalized(dst)
  n = len(dst)
  # compute angle between x and y in their spanning space
  theta = jnp.arccos(jnp.dot(x, y))  # they are normalized so there is no denominator
  if jnp.isclose(theta, 0):
    print("x and y are co-linear")
  # construct the 2d rotation matrix connecting x to y in their spanning space
  R = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                 [jnp.sin(theta), jnp.cos(theta)]])
  __assert_rotation(R)
  # get projections onto Span<x,y> and its orthogonal complement
  u = x
  v = pytree_normalized(pytree_sub(y, (jnp.dot(u, y) * u)))
  P = jnp.outer(u, u.T) + jnp.outer(v, v.T)  # projection onto 2d space spanned by x and y
  Q = jnp.eye(n) - P  # projection onto the orthogonal complement of Span<x,y>
  # lift the rotation matrix into the n-dimensional space
  uv = jnp.hstack((u[:, None], v[:, None]))
  R = Q + jnp.dot(uv, jnp.dot(R, uv.T))
  __assert_rotation(R)
  if jnp.any(jnp.logical_not(jnp.isclose(jnp.dot(R, x), y, rtol=0.25))):
    print("Rotation matrix did not work")
  return R
def cond(val):
  """Check whether or not the proposal has been accepted.

  Args:
    val: A tuple containing the previous proposal, whether or not it was
      accepted (it wasn't), and the current iteration of the rejection
      sampling loop.

  Returns:
    out: A boolean for whether or not to continue sampling. If the sample was
      rejected, try again. Otherwise, return the accepted sample.
  """
  _, isacc, _ = val
  return jnp.logical_not(isacc)
def solve_implicit(ks, a, b, c, d, b_edge=None, d_edge=None):
  land_mask = (ks >= 0)[:, :, np.newaxis]
  edge_mask = land_mask & (
      np.arange(a.shape[2])[np.newaxis, np.newaxis, :] == ks[:, :, np.newaxis])
  water_mask = land_mask & (
      np.arange(a.shape[2])[np.newaxis, np.newaxis, :] >= ks[:, :, np.newaxis])
  a_tri = water_mask * a * np.logical_not(edge_mask)
  b_tri = where(water_mask, b, 1.)
  if b_edge is not None:
    b_tri = where(edge_mask, b_edge, b_tri)
  c_tri = water_mask * c
  d_tri = water_mask * d
  if d_edge is not None:
    d_tri = where(edge_mask, d_edge, d_tri)
  return solve_tridiag(a_tri, b_tri, c_tri, d_tri), water_mask, a_tri, b_tri, c_tri, d_tri
def sequence_prediction_metrics(
    logits: jnp.ndarray,
    labels: jnp.ndarray,
    mask: Optional[jnp.ndarray] = None) -> Dict[str, float]:
  """Compute the metrics for sequence prediction.

  Args:
    logits: [B, T, V] array of logits.
    labels: [B, T] array of labels.
    mask: [B, T] array of binary masks, if provided.

  Returns:
    metrics: a dictionary of metrics.
  """
  vocab_size = logits.shape[-1]
  logps = jax.nn.log_softmax(logits)
  labels_one_hot = hk.one_hot(labels, vocab_size)
  class_logps = jnp.sum(logps * labels_one_hot, axis=-1)
  prediction_correct = jnp.argmax(logits, axis=-1) == labels
  if mask is not None:
    masked_logps = mask * class_logps
    total_count = jnp.sum(mask)
    tokens_correct = jnp.sum(prediction_correct * mask)
    seq_correct = jnp.all(
        jnp.logical_or(prediction_correct, jnp.logical_not(mask)), axis=-1)
  else:
    masked_logps = class_logps
    total_count = np.prod(class_logps.shape)
    tokens_correct = jnp.sum(prediction_correct)
    seq_correct = jnp.all(prediction_correct, axis=-1)
  token_accuracy = tokens_correct.astype(jnp.float32) / total_count
  seq_accuracy = jnp.mean(seq_correct)
  log_probs = jnp.mean(jnp.sum(masked_logps, axis=-1))
  total_loss = -jnp.sum(masked_logps)
  loss = total_loss / total_count
  return dict(
      loss=loss,
      total_loss=total_loss,
      total_count=total_count,
      token_accuracy=token_accuracy,
      seq_accuracy=seq_accuracy,
      log_probs=log_probs,
  )
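# Usage sketch (assumption): toy shapes showing the expected [B, T, V] / [B, T]
# inputs; relies on the surrounding module's jax / jnp / hk imports.
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 7))  # [B, T, V]
labels = jnp.zeros((2, 5), dtype=jnp.int32)                   # [B, T]
mask = jnp.ones((2, 5))                                       # [B, T]
metrics = sequence_prediction_metrics(logits, labels, mask)
print(sorted(metrics.keys()))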
def setup(self):
  # Alternating binary mask.
  mask = jnp.arange(0, np.prod(self.event_shape)) % 2
  mask = jnp.reshape(mask, self.event_shape)
  mask = mask.astype(bool)
  layers = []
  for conditioner in self.conditioners:
    layer = distrax.MaskedCoupling(
        mask=mask, bijector=self.bijector_fn, conditioner=conditioner)
    layers.append(layer)
    # Flip the mask after each layer.
    mask = jnp.logical_not(mask)
  # Chain layers to create the flow.
  self.flow = distrax.Chain(layers)
def macula_matrix(d, k, n):
  """Produces d-separable design matrix."""
  # https://core.ac.uk/download/pdf/82758506.pdf
  n_groups = int(scipy.special.comb(n, d))
  n_cols = int(scipy.special.comb(n, k))
  new_groups = np.zeros((n_groups, n_cols), dtype=bool)
  comb_groups = itertools.combinations(range(n), d)
  comb_cols = itertools.combinations(range(n), k)
  d_vec = np.zeros((n_groups, n), dtype=bool)
  k_vec = np.zeros((n_cols, n), dtype=bool)
  for i, comb_g in enumerate(comb_groups):
    d_vec[i, comb_g] = True
  for j, comb_c in enumerate(comb_cols):
    k_vec[j, comb_c] = True
  for i in range(n_groups):
    for j in range(n_cols):
      new_groups[i, j] = np.all(
          np.logical_or(np.logical_not(d_vec[i, :]), k_vec[j, :]))
  return new_groups
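# Usage sketch (assumption): rows index d-subsets and columns index k-subsets of
# {0, ..., n-1}; an entry is True iff the d-subset is contained in the k-subset.
m = macula_matrix(d=2, k=3, n=5)
print(m.shape)  # (comb(5, 2), comb(5, 3)) == (10, 10)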
def observe_nb2(name, latent, det_prob, dispersion, obs=None):
  mask = True
  if obs is not None:
    mask = np.isfinite(obs) & (obs >= 0.0)
    obs = np.where(mask, obs, 0.0)
    if np.any(np.logical_not(mask)):
      warnings.warn('Some observed values are invalid')
  det_prob = np.broadcast_to(det_prob, latent.shape)
  mean = det_prob * latent
  numpyro.deterministic("mean_" + name, mean)
  d = NB2(mu=mean, k=dispersion)
  with numpyro.handlers.mask(mask_array=mask):
    y = numpyro.sample(name, d, obs=obs)
  return y
def make_flow_model(event_shape: Sequence[int],
                    num_layers: int,
                    hidden_sizes: Sequence[int],
                    num_bins: int) -> distrax.Transformed:
  """Creates the flow model."""
  # Alternating binary mask.
  mask = jnp.arange(0, np.prod(event_shape)) % 2
  mask = jnp.reshape(mask, event_shape)
  mask = mask.astype(bool)

  def bijector_fn(params: Array):
    return distrax.RationalQuadraticSpline(params, range_min=0., range_max=1.)

  # Number of parameters for the rational-quadratic spline:
  # - `num_bins` bin widths
  # - `num_bins` bin heights
  # - `num_bins + 1` knot slopes
  # for a total of `3 * num_bins + 1` parameters.
  num_bijector_params = 3 * num_bins + 1

  layers = []
  for _ in range(num_layers):
    layer = distrax.MaskedCoupling(
        mask=mask,
        bijector=bijector_fn,
        conditioner=make_conditioner(event_shape, hidden_sizes,
                                     num_bijector_params))
    layers.append(layer)
    # Flip the mask after each layer.
    mask = jnp.logical_not(mask)

  # We invert the flow so that the `forward` method is called with `log_prob`.
  flow = distrax.Inverse(distrax.Chain(layers))
  base_distribution = distrax.Independent(
      distrax.Uniform(low=jnp.zeros(event_shape), high=jnp.ones(event_shape)),
      reinterpreted_batch_ndims=len(event_shape))
  return distrax.Transformed(base_distribution, flow)
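# Usage sketch (assumption): because `make_conditioner` normally builds Haiku
# modules, the flow is wrapped in hk.transform before it can be initialised and
# evaluated. Shapes and sizes below are illustrative only.
@hk.without_apply_rng
@hk.transform
def log_prob(x):
  model = make_flow_model(event_shape=(2,), num_layers=4,
                          hidden_sizes=[64, 64], num_bins=8)
  return model.log_prob(x)

x = 0.5 * jnp.ones((16, 2))  # data assumed to live in [0, 1]^2
params = log_prob.init(jax.random.PRNGKey(0), x)
lp = log_prob.apply(params, x)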
def __call__(self, rng: np.ndarray, particles: np.ndarray, rho: float,
             log_posterior_params: Dict[str, np.ndarray],
             log_base_measure_params: Dict[str, np.ndarray]):
  """Carries out procedure 4 in https://arxiv.org/pdf/1101.6037.pdf.

  One expects that fit_model has been called right before to store the model
  in self.model.

  Args:
    rng: np.ndarray<int> random key
    particles: np.ndarray [n_particles, n_patients] plausible infection states
    rho: float, scaling for posterior.
    log_posterior_params: Dict of parameters to compute log-posterior.
    log_base_measure_params: Dict of parameters to compute log-base measure.

  Returns:
    A np.ndarray representing the new particles.
  """
  rngs = jax.random.split(rng, 2)
  n_samples = particles.shape[0]
  proposed, logprop_proposed, logprop_particles = self.sample_from_model(
      rngs[0], particles)
  llparticles = bayes.tempered_logpos_logbase(particles, log_posterior_params,
                                              log_base_measure_params, rho)
  llproposed = bayes.tempered_logpos_logbase(proposed, log_posterior_params,
                                             log_base_measure_params, rho)
  logratio = llproposed - llparticles + logprop_particles - logprop_proposed
  p_replacement = np.minimum(np.exp(logratio), 1)
  replacement = (jax.random.uniform(rngs[1], shape=(n_samples,))
                 < p_replacement)
  not_replacement = np.logical_not(replacement)
  return (replacement[:, np.newaxis] * proposed
          + not_replacement[:, np.newaxis] * particles)
def update(updates, state, params=None):
  inner_state = state.inner_state
  flat_updates = tree_flatten(updates)[0]
  isfinite = jnp.all(
      jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
  notfinite_count = jnp.where(isfinite, jnp.zeros([], jnp.int64),
                              1 + state.notfinite_count)

  def do_update(_):
    return inner.update(updates, inner_state, params)

  def reject_update(_):
    return (tree_map(jnp.zeros_like, updates), inner_state)

  updates, new_inner_state = lax.cond(
      jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors),
      do_update, reject_update, operand=None)
  return updates, ApplyIfFiniteState(
      notfinite_count=notfinite_count,
      last_finite=isfinite,
      total_notfinite=jnp.logical_not(isfinite) + state.total_notfinite,
      inner_state=new_inner_state)
def __call__(self, inputs: Tuple[ndarray, ndarray]):
  input_seq, input_mask = inputs
  B, L, D = input_seq.shape
  del L, D
  input_seq = jnp.swapaxes(input_seq, 0, 1)
  input_mask = jnp.swapaxes(input_mask, 0, 1)
  reset_mask = jnp.logical_not(input_mask)
  reset_mask = reset_mask.at[1:].set(reset_mask[:-1])  # move the mask to the right
  h0c0: hk.LSTMState = self.lstm.initial_state(B)  # type: ignore
  hx, state = hk.dynamic_unroll(self.lstm, (input_seq, reset_mask), h0c0)
  del state
  # split encoder/decoder states
  encoder_hx = hx[:self.padded_input_len]
  decoder_hx = hx[self.padded_input_len:]
  # append the initial hidden state.
  # this will be the encoder state for the [end] token.
  encoder_hx = jnp.concatenate([encoder_hx, h0c0.hidden[None]], axis=0)
  # create query and value for attention mechanism
  encoder_value = self.enc_att_fc(encoder_hx)[None]
  decoder_query = self.dec_att_fc(decoder_hx)[:, None]
  # energy function
  energy = encoder_value * decoder_query
  energy = jnp.sum(energy, axis=-1) / math.sqrt(energy.shape[-1])
  # apply input sequence mask
  input_mask = input_mask[:self.padded_input_len + 1][None]
  energy = jnp.where(input_mask, energy, float('-inf'))
  # normalize
  energy = jax.nn.log_softmax(energy, axis=1)
  # batch first, logit last
  return jnp.transpose(energy, [2, 0, 1])
def compute_tp_fp_fn_weighted(
    predictions: Array, labels: Array, weights: Array,
    ignore_class: Optional[int]) -> Tuple[float, float, float]:
  """Compute true positives, false positives and false negatives.

  Args:
    predictions: [batch, length] categorical predictions int array.
    labels: [batch, length] categorical labels int array.
    weights: [batch, length].
    ignore_class: which class to ignore in the computations.

  Returns:
    Tuple with numbers of true positive, false positive and false negative
    predictions.
  """
  true_positives = (predictions == labels)
  false_positives = jnp.logical_not(true_positives)
  false_negatives = false_positives
  if ignore_class is not None:
    dont_ignore_predictions = (predictions != ignore_class)
    dont_ignore_labels = (labels != ignore_class)
    true_positives = jnp.logical_and(true_positives, dont_ignore_predictions)
    # Exactly the same as
    # true_positives = jnp.logical_and(true_positives, dont_ignore_labels)
    # since for true positives `dont_ignore_predictions` = `dont_ignore_labels`.
    false_positives = jnp.logical_and(false_positives, dont_ignore_predictions)
    false_negatives = jnp.logical_and(false_negatives, dont_ignore_labels)

  def get_weighted_sum(values):
    values = values.astype(weights.dtype)
    return jnp.dot(values, weights)

  n_true_positive = get_weighted_sum(true_positives)
  n_false_positive = get_weighted_sum(false_positives)
  n_false_negative = get_weighted_sum(false_negatives)
  return n_true_positive, n_false_positive, n_false_negative
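# Follow-up sketch (assumption): precision, recall, and F1 derived from the
# weighted counts returned above; `eps` guards against empty classes.
def f1_from_counts(n_tp, n_fp, n_fn, eps=1e-8):
  precision = n_tp / (n_tp + n_fp + eps)
  recall = n_tp / (n_tp + n_fn + eps)
  return 2 * precision * recall / (precision + recall + eps)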
def add_mean(i, means):
  mask = jnp.arange(num_means) < i
  mask = mask * 1. + (jnp.logical_not(mask)) * jnp.inf

  def spaced_mean_body(state):
    key, means, mask, unused_sample = state
    key, subkey = jax.random.split(key)
    sample = jax.random.uniform(
        subkey, shape=[data_dim], minval=bounds[0], maxval=bounds[1])
    return key, means, mask, sample

  def spaced_mean_cond(state):
    unused_key, means, mask, sample = state
    dists = mask * jax.vmap(dist, in_axes=(0, None))(means, sample)
    return jnp.any(jnp.less(dists, min_distance))

  _, _, _, new_mean = jax.lax.while_loop(spaced_mean_cond, spaced_mean_body,
                                         (key, means, mask, means[0]))
  means = jax.ops.index_update(means, i, new_mean)
  return means
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN."""
  out = prim.bind(
      operand, indices, updates,
      update_jaxpr=update_jaxpr,
      update_consts=update_consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices,
      mode=mode)
  if ErrorCategory.OOB not in enabled_errors:
    return out, error
  in_bounds = scatter_in_bounds(operand, indices, updates, dimension_numbers)
  oob_msg = f'out-of-bounds indexing while updating at {summary()}'
  oob_error = assert_func(error, in_bounds, oob_msg, None)
  no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
  nan_msg = f'nan generated by primitive {prim.name} at {summary()}'
  return out, assert_func(oob_error, no_nans, nan_msg, None)
def while_cond(state):
  done, _, _ = state
  return jnp.logical_not(done)
def newton(
    backward_differences: np.ndarray,
    max_num_iters: Union[np.ndarray, float, int],
    newton_coefficient: Union[np.ndarray, float, int],
    ode_fn_vec: Callable,
    order: Union[np.ndarray, float, int],
    step_size: Union[np.ndarray, float, int],
    time: Union[np.ndarray, float, int],
    tol: Union[np.ndarray, float],
    unitary: np.ndarray,
    upper: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
  """Runs Newton's method to solve the BDF equation."""
  initial_guess = np.sum(
      np.where(
          np.arange(MAX_ORDER + 1).reshape(-1, 1) <= order,
          backward_differences[:MAX_ORDER + 1],
          np.zeros_like(backward_differences)[:MAX_ORDER + 1],
      ),
      axis=0,
  )
  rhs_constant_term = newton_coefficient * np.sum(
      np.where(
          np.arange(1, MAX_ORDER + 1).reshape(-1, 1) <= order,
          RECIPROCAL_SUMS[1:, np.newaxis] * backward_differences[1:MAX_ORDER + 1],
          np.zeros_like(backward_differences)[1:MAX_ORDER + 1],
      ),
      axis=0,
  )
  next_time = time + step_size

  def newton_body(iterand):
    """Performs one iteration of Newton's method."""
    next_backward_difference = iterand.next_backward_difference
    next_state_vec = iterand.next_state_vec
    rhs = (newton_coefficient * step_size * ode_fn_vec(next_time, next_state_vec)
           - rhs_constant_term - next_backward_difference)
    delta = np.squeeze(
        jax.scipy.linalg.solve_triangular(
            upper,
            np.matmul(np.transpose(unitary), rhs[:, np.newaxis]),
            lower=False))
    num_iters = iterand.num_iters + 1
    next_backward_difference += delta
    next_state_vec += delta
    delta_norm = np.linalg.norm(delta)
    lipschitz_const = delta_norm / iterand.prev_delta_norm
    # Stop if method has converged.
    approx_dist_to_sol = lipschitz_const / (1.0 - lipschitz_const) * delta_norm
    close_to_sol = approx_dist_to_sol < tol
    delta_norm_is_zero = np.equal(delta_norm, np.array(0.0, dtype=np.float64))
    converged = close_to_sol | delta_norm_is_zero
    finished = converged
    # Stop if any of the following conditions are met:
    #   (A) We have hit the maximum number of iterations.
    #   (B) The method is converging too slowly.
    #   (C) The method is not expected to converge.
    too_slow = lipschitz_const > 1.0
    finished = finished | too_slow
    too_many_iters = np.equal(num_iters, max_num_iters)
    num_iters_left = max_num_iters - num_iters
    wont_converge = approx_dist_to_sol * lipschitz_const**num_iters_left > tol
    finished = finished | too_many_iters | wont_converge
    return _NewtonIterand(
        converged=converged,
        finished=finished,
        next_backward_difference=next_backward_difference,
        next_state_vec=next_state_vec,
        num_iters=num_iters,
        prev_delta_norm=delta_norm,
    )

  iterand = _NewtonIterand(
      converged=False,
      finished=False,
      next_backward_difference=np.zeros_like(initial_guess),
      next_state_vec=initial_guess,
      num_iters=0,
      prev_delta_norm=np.array(-0.0),
  )
  iterand = jax.lax.while_loop(
      lambda iterand: np.logical_not(iterand.finished), newton_body, iterand)
  # Krishna: need to double check this
  return (
      iterand.converged,
      iterand.next_backward_difference,
      iterand.next_state_vec,
      iterand.num_iters,
  )
def cond_fun(carry):
  _n, x_last, x = carry
  converged = convergence_condition(a, x, x_last)
  not_exceeded = lax.lt(_n, max_iter)
  return np.logical_and(np.logical_not(converged), not_exceeded)
def loss_fn(
    model,
    padded_example_and_rng,
    static_metadata,
    regularization_weights=None,
    reinforce_weight=1.0,
    baseline_weight=0.001,
):
  """Loss function for multi-pointer task.

  Args:
    model: The model to evaluate.
    padded_example_and_rng: Padded example to evaluate on, with a PRNGKey.
    static_metadata: Padding configuration for the example, since this may
      vary for different examples.
    regularization_weights: Associates side output key regexes with
      regularization penalties.
    reinforce_weight: Weight to give to the reinforce term.
    baseline_weight: Weight to give to the baseline.

  Returns:
    Tuple of loss and metrics.
  """
  padded_example, rng = padded_example_and_rng
  # Run the model.
  with side_outputs.collect_side_outputs() as collected_side_outputs:
    with flax.nn.stochastic(rng):
      joint_log_probs = model(padded_example, static_metadata)
  # Computing the loss:
  # Extract logits for the correct location.
  log_probs_at_bug = joint_log_probs[padded_example.bug_node_index, :]
  # Compute p(repair) = sum[ p(node) p(repair | node) ]
  #   -> log p(repair) = logsumexp[ log p(node) + log p(repair | node) ]
  log_prob_joint = jax.scipy.special.logsumexp(
      log_probs_at_bug + jnp.log(padded_example.repair_node_mask))
  # Metrics:
  # Marginal log probabilities:
  log_prob_bug = jax.scipy.special.logsumexp(log_probs_at_bug)
  log_prob_repair = jax.scipy.special.logsumexp(
      jax.scipy.special.logsumexp(joint_log_probs, axis=0)
      + jnp.log(padded_example.repair_node_mask))
  # Conditional log probabilities:
  log_prob_repair_given_bug = log_prob_joint - log_prob_bug
  log_prob_bug_given_repair = log_prob_joint - log_prob_repair
  # Majority accuracy (1 if we assign the correct tuple > 50%):
  # (note that this is easier to compute, since we can't currently aggregate
  # probability separately for each candidate.)
  log_half = jnp.log(0.5)
  majority_acc_joint = log_prob_joint > log_half
  # Probabilities associated with each node.
  node_node_probs = jnp.exp(joint_log_probs)
  # Accumulate across unique candidates by identifier. This has the same
  # shape, but only the first few values will be populated.
  node_candidate_probs = padded_example.unique_candidate_operator.apply_add(
      in_array=node_node_probs,
      out_array=jnp.zeros_like(node_node_probs),
      in_dims=[1],
      out_dims=[1])
  # Classify: 50% decision boundary
  only_buggy_probs = node_candidate_probs.at[0, :].set(0).at[:, 0].set(0)
  p_buggy = jnp.sum(only_buggy_probs)
  pred_nobug = p_buggy <= 0.5
  # Localize/repair: take most likely bug position, conditioned on being buggy
  pred_bug_loc, pred_cand_id = jnp.unravel_index(
      jnp.argmax(only_buggy_probs), only_buggy_probs.shape)
  actual_nobug = jnp.array(padded_example.bug_node_index == 0)
  actual_bug = jnp.logical_not(actual_nobug)
  pred_bug = jnp.logical_not(pred_nobug)
  metrics = {
      'nll/joint': -log_prob_joint,
      'nll/marginal_bug': -log_prob_bug,
      'nll/marginal_repair': -log_prob_repair,
      'nll/repair_given_bug': -log_prob_repair_given_bug,
      'nll/bug_given_repair': -log_prob_bug_given_repair,
      'inaccuracy/legacy_overall': 1 - majority_acc_joint,
      'inaccuracy/overall':
          (~((actual_nobug & pred_nobug) |
             (actual_bug & pred_bug
              & (pred_bug_loc == padded_example.bug_node_index)
              & (pred_cand_id == padded_example.repair_id)))),
      'inaccuracy/classification_overall': (actual_nobug != pred_nobug),
      'inaccuracy/classification_given_nobug': train_util.RatioMetric(
          numerator=(actual_nobug & ~pred_nobug), denominator=actual_nobug),
      'inaccuracy/classification_given_bug': train_util.RatioMetric(
          numerator=(actual_bug & ~pred_bug), denominator=actual_bug),
      'inaccuracy/localized_given_bug': train_util.RatioMetric(
          numerator=(actual_bug
                     & ~(pred_bug_loc == padded_example.bug_node_index)),
          denominator=actual_bug),
      'inaccuracy/repaired_given_bug': train_util.RatioMetric(
          numerator=(actual_bug
                     & ~(pred_cand_id == padded_example.repair_id)),
          denominator=actual_bug),
      'inaccuracy/localized_repaired_given_bug': train_util.RatioMetric(
          numerator=(actual_bug
                     & ~((pred_bug_loc == padded_example.bug_node_index)
                         & (pred_cand_id == padded_example.repair_id))),
          denominator=actual_bug),
      'inaccuracy/overall_given_bug': train_util.RatioMetric(
          numerator=(actual_bug
                     & ~(pred_bug
                         & (pred_bug_loc == padded_example.bug_node_index)
                         & (pred_cand_id == padded_example.repair_id))),
          denominator=actual_bug),
  }
  loss = -log_prob_joint
  for k, v in collected_side_outputs.items():
    # Flax collection keys will start with "/".
    if v.shape == ():  # pylint: disable=g-explicit-bool-comparison
      metrics['side' + k] = v
  if regularization_weights:
    total_regularization = 0
    for query, weight in regularization_weights.items():
      logging.info('Regularizing side outputs matching query %s', query)
      found = False
      for k, v in collected_side_outputs.items():
        if re.search(query, k):
          found = True
          logging.info('Regularizing %s with weight %f', k, weight)
          total_regularization += weight * v
      if not found:
        raise ValueError(
            f'Regularization query {query} did not match any side output. '
            f'Side outputs were {set(collected_side_outputs.keys())}')
    loss = loss + total_regularization
  is_single_sample = any(
      k.endswith('one_sample_log_prob_per_edge_per_node')
      for k in collected_side_outputs)
  if is_single_sample:
    log_prob, = [
        v for k, v in collected_side_outputs.items()
        if k.endswith('one_sample_log_prob_per_edge_per_node')
    ]
    baseline, = [
        v for k, v in collected_side_outputs.items()
        if k.endswith('one_sample_reward_baseline')
    ]
    num_real_nodes = padded_example.input_graph.bundle.graph_metadata.num_nodes
    valid_mask = (
        jnp.arange(static_metadata.bundle_padding.static_max_metadata.num_nodes)
        < num_real_nodes)
    log_prob = jnp.where(valid_mask[None, :], log_prob, 0)
    total_log_prob = jnp.sum(log_prob)
    reinforce_virtual_cost = (
        total_log_prob * jax.lax.stop_gradient(loss - baseline))
    baseline_penalty = jnp.square(loss - baseline)
    reinforce_virtual_cost_zeroed = (
        reinforce_virtual_cost - jax.lax.stop_gradient(reinforce_virtual_cost))
    loss = (
        loss + reinforce_weight * reinforce_virtual_cost_zeroed
        + baseline_weight * baseline_penalty)
    metrics['reinforce_virtual_cost'] = reinforce_virtual_cost
    metrics['baseline_penalty'] = baseline_penalty
    metrics['baseline'] = baseline
    metrics['total_log_prob'] = total_log_prob
  metrics = jax.tree_map(lambda x: x.astype(jnp.float32), metrics)
  return loss, metrics
def cond_fn(*args):
  """Check if all are done or reached max number of iterations."""
  i, _, done, _, _ = args[0]
  return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
def while_cond(val):
  possible_nan = jnp.sin(1. / val)
  return jnp.logical_not(jnp.isnan(possible_nan))