Пример #1
0
 def cond_func(val):
     """Decide whether the decoding while-loop should run another step.

     The loop keeps going only while the step counter is still below
     ``seq_len - 1`` AND at least one entry of ``val.done`` is still false.
     """
     some_row_unfinished = jnp.logical_not(jnp.all(val.done))
     within_length = val.step < seq_len - 1
     return jnp.logical_and(within_length, some_row_unfinished)
Пример #2
0
def assert_func(error: Error, pred: Bool, msg: str,
                payload: Optional[Payload]) -> Error:
    """Fold a new assertion into an existing ``Error`` value.

    A fresh code is drawn from ``next_code()`` and registered with ``msg``.
    If ``error`` is already set, its code and payload are kept; otherwise the
    new code and ``payload`` (``init_payload`` when ``None``) are selected.
    The resulting error flag is set when ``error`` was set or ``pred`` is false.
    """
    new_code = next_code()
    if payload is None:
        payload = init_payload
    already_failed = error.err
    messages = {new_code: msg, **error.msgs}
    return Error(
        already_failed | jnp.logical_not(pred),
        lax.select(already_failed, error.code, new_code),
        messages,
        lax.select(already_failed, error.payload, payload),
    )
Пример #3
0
def add_ones_to_line(single_group):
    """Return every group obtained by switching on one extra entry.

    Args:
        single_group: 1-D boolean array.

    Returns:
        A boolean JAX array with one row per ``False`` entry of
        ``single_group``; row ``r`` is a copy of ``single_group`` in which the
        ``r``-th ``False`` entry has been set to ``True``.
    """
    zero_positions, = np.where(np.logical_not(single_group))
    k_zeros = len(zero_positions)
    # One copy of the input per zero entry (bool + bool acts as logical OR).
    groups = (np.zeros((k_zeros, single_group.shape[0]), dtype=bool) +
              single_group[np.newaxis, :])
    # `jax.ops.index_update` has been removed from JAX; the indexed-update
    # property `.at[...].set(...)` is its replacement. Converting first keeps
    # this working whether `groups` is a NumPy or a JAX array.
    rows = np.arange(0, k_zeros)
    return jax.numpy.asarray(groups).at[rows, zero_positions].set(True)
Пример #4
0
def assert_batching_rule(batched_args, batch_dims, *, msgs):
    """Batching rule for the assert primitive.

    Moves the batch dimension of each of the three operands (predicate, code,
    payload) to the front, wraps them in an ``Error`` whose flag is the negated
    predicate, and passes it to ``check_error``. Produces no outputs.
    """
    paired = list(zip(batched_args, batch_dims))
    # The batch size is read from the first operand that is actually mapped.
    size = next(arg.shape[bdim] for arg, bdim in paired
                if bdim is not batching.not_mapped)
    pred, code, payload = (batching.bdim_at_front(arg, bdim, size)
                           for arg, bdim in paired)
    check_error(Error(jnp.logical_not(pred), code, msgs, payload))
    return [], []
Пример #5
0
    def simplex_proposal(self, rng_key, x, grad_, hess_):
        """Draw a Dirichlet proposal whose concentration is built from ``hess_``.

        For every row of ``hess_`` the largest off-diagonal entry is subtracted
        from the diagonal entry; the concentration is
        ``1 - x**2 * (diag - max_off_diag) + W_CORRECTION``.
        ``grad_`` is accepted but not used here.

        Returns:
            A tuple ``(sample reshaped to x.shape, the Dirichlet distribution)``.
        """
        dim = hess_.shape[0]
        off_diagonal = np.logical_not(np.eye(dim, dtype=bool))
        # Boolean indexing flattens the selection; restore one row per Hessian row.
        off_diagonal_rows = hess_[off_diagonal].reshape(dim, -1)
        max_non_diag_hess = np.max(off_diagonal_rows, axis=1)
        concentration = 1 - x**2 * (np.diag(hess_) - max_non_diag_hess)

        proposal = dist.Dirichlet(concentration=concentration + W_CORRECTION)
        draw = proposal.sample(rng_key).reshape(x.shape)
        return draw, proposal
Пример #6
0
 def cond_func(val):
     """Return True while the iteration should keep running.

     Iteration stops as soon as the iteration budget is used up, the error
     has diverged (above ``divergence_tol`` or NaN), or both the error and the
     position update are below their tolerances.
     """
     _, i, norm_delta_q, error = val
     out_of_iters = i >= max_iters
     has_diverged = np.logical_or(error > divergence_tol, np.isnan(error))
     has_converged = np.logical_and(error < convergence_tol,
                                    norm_delta_q < position_tol)
     should_stop = np.logical_or(out_of_iters,
                                 np.logical_or(has_diverged, has_converged))
     return np.logical_not(should_stop)
Пример #7
0
    def get_init_state(self, batch_size: int):
        """Sample starting coordinates by rejection sampling.

        Points are drawn uniformly from a box around the scene segments and
        redrawn until none of them is reported as inside a polygon and each is
        farther than ``radius + radius / 2`` from its closest segment.

        Args:
            batch_size: size of a batch

        Returns:
            [batch_size, 2] array with initial coordinates
        """
        _onp_segments = onp.asarray(self.jax_scene.segments)
        # Extra clearance on top of the radius: half the configured radius.
        eps = self.config["constants"]["radius"] / 2

        # Bounding box of all segment coordinates (last axis: index 0 and 1).
        max_x, min_x = onp.max(_onp_segments[:, :,
                                             0]), onp.min(_onp_segments[:, :,
                                                                        0])
        max_y, min_y = onp.max(_onp_segments[:, :,
                                             1]), onp.min(_onp_segments[:, :,
                                                                        1])
        # Indices (into the batch) of the points that still need a valid draw.
        remaining_idxs = np.arange(batch_size)
        # First draw uses the bounding box padded by 1 on every side.
        init_proposal = onp.random.uniform((min_x - 1, min_y - 1),
                                           (max_x + 1, max_y + 1),
                                           size=(batch_size, 2))
        proposal_jax = np.asarray(init_proposal)
        while True:
            is_inner = if_points_inside_any_polygon(proposal_jax,
                                                    self.jax_scene)
            _, distance = find_closest_segment_to_points_batch(
                proposal_jax, self.jax_scene.segments)
            # A point is acceptable when it is neither inside a polygon nor
            # within radius + eps of the closest segment.
            acceptable = onp.asarray(
                np.logical_not(
                    np.logical_or(
                        is_inner,
                        distance <= self.config["constants"]["radius"] + eps)))
            # Write the accepted points back into their original batch slots.
            init_proposal[remaining_idxs[acceptable], :] = onp.array(
                proposal_jax)[acceptable, :]
            if np.all(acceptable):
                break
            logger.debug("Resampling starting position")
            remaining_idxs = remaining_idxs[np.logical_not(acceptable)]
            # NOTE(review): resampling uses the unpadded bounding box, unlike
            # the first draw above (padded by 1) - confirm this is intended.
            proposal_jax = np.asarray(
                onp.random.uniform((min_x, min_y), (max_x, max_y),
                                   size=(len(remaining_idxs), 2)))
        return np.array(init_proposal)
Пример #8
0
 def state_needs_iteration(self, theta: Parameters, augmented: TheAugmentedState) -> bool:
     """Tell whether another iteration is required.

     Args:
         theta: The parameters (not consulted here).
         augmented: The state.
     Returns: True while the iteration count is below ``maximum_iterations``
         and it is not the case that both ``minimum_iterations`` has been
         reached and ``self.converged(augmented)`` holds.
     """
     iterations = augmented.iterations
     may_stop = jnp.logical_and(iterations >= self.minimum_iterations,
                                self.converged(augmented))
     below_cap = iterations < self.maximum_iterations
     return jnp.logical_and(below_cap, jnp.logical_not(may_stop))
Пример #9
0
def random_adjacency(key: jnp.ndarray,
                     num_nodes: int,
                     num_edges: int,
                     dtype=jnp.float32) -> COO:
    """
    Build the adjacency matrix of a random connected undirected graph.

    `num_edges` is only approximate. Edges are created as follows:
    - sample `num_edges` random edges
    - remove self-edges
    - add ring edges
    - add reverse edges
    - filter duplicates

    Args:
        key: `jax.random.PRNGKey`.
        num_nodes: number of nodes in returned graph.
        num_edges: number of random internal edges initially sampled.
        dtype: dtype of the returned values.

    Returns:
        COO, shape (num_nodes, num_nodes), weights all ones.
    """
    shape = num_nodes, num_nodes

    # Edges are handled as flat indices `row * num_nodes + col`.
    sampled = jax.random.uniform(
        key,
        shape=(num_edges, ),
        dtype=jnp.float32,
        maxval=num_nodes**2,
    ).astype(jnp.int32)
    src = sampled // num_nodes
    dst = sampled % num_nodes
    # Drop sampled edges that start and end on the same node.
    sampled = sampled[jnp.logical_not(src == dst)]

    # A ring i -> (i + 1) % num_nodes guarantees the graph is connected.
    nodes = jnp.arange(num_nodes, dtype=jnp.int32)
    ring = nodes * num_nodes + (nodes + 1) % num_nodes
    forward = jnp.concatenate((sampled, ring))

    # Mirror every edge by swapping its row and column.
    rows, cols = jnp.unravel_index(forward, shape)
    backward = jnp.ravel_multi_index((cols, rows), shape)

    # Remove duplicates and convert back to coordinates.
    unique_flat = jnp.unique(jnp.concatenate((forward, backward)))
    row, col = jnp.unravel_index(unique_flat, shape)
    values = jnp.ones((row.size, ), dtype=dtype)
    return COO((values, row, col), shape=shape)
Пример #10
0
    def __init__(self, v, s):
        """Set up the combined tree and the flat view of the entries from ``v``.

        ``v`` and ``s`` are stored as plain dicts when they are ``DotMap``
        instances. Both are merged (``s`` first, then ``v``) into ``self.c``,
        which is flattened. A NaN-filled copy of that tree receives ``v`` only,
        so the non-NaN positions of its flat form identify the entries that
        came from ``v``; those values are stored in ``self.x``.
        """
        self.v = v.toDict() if isinstance(v, DotMap) else v
        self.s = s.toDict() if isinstance(s, DotMap) else s

        # NOTE(review): the merges below use the raw `s` / `v` arguments, not
        # the converted `self.s` / `self.v` - confirm this is intended.
        self.c = {}
        for source in (s, v):
            merge(self.c, source)
        self.c_flat, self.idx, self.tree = flatten(self.c)

        self.v_tree = nan_like(self.c)
        merge(self.v_tree, v)
        self.v_flat, _unused_idx, _unused_tree = flatten(self.v_tree)

        came_from_v = jnp.logical_not(jnp.isnan(self.v_flat))
        self.update_idx = jnp.where(came_from_v)
        self.x = self.v_flat[self.update_idx]
Пример #11
0
def get_rotation_pytree(src: Any, dst: Any) -> Any:
    """
    Takes two n-dimensional vectors/Pytree and returns an
    nxn rotation matrix mapping src to dst.

    The rotation acts in the plane spanned by the normalized ``src`` and
    ``dst`` and leaves the orthogonal complement of that plane unchanged.

    Every sanity check in this function only prints a diagnostic message;
    none of them raises, and the matrix is returned regardless.
    """
    def __assert_rotation(R):
        # Prints (does not raise) when R is not 2-D, not square, or when
        # R R^T / R^T R deviate from the identity.
        if R.ndim != 2:
            print("R must be a matrix")
        a, b = R.shape
        if a != b:
            print("R must be square")
        # NOTE(review): the reference value is 0.0, so `rtol` contributes
        # nothing to the tolerance; only the default `atol` of isclose applies.
        # NOTE(review): the test below is for orthogonality although the
        # printed message says "diagonal".
        if (not jnp.isclose(
                jnp.abs(jnp.eye(a) - jnp.dot(R, R.T)).max(), 0.0, rtol=0.5)
            ) or (not jnp.isclose(
                jnp.abs(jnp.eye(a) - jnp.dot(R.T, R)).max(), 0.0, rtol=0.5)):
            print("R is not diagonal")

    if not pytree_shape_array_equal(src, dst):
        print("cjax and dst must be 1-dimensional arrays with the same shape.")

    x = pytree_normalized(src)
    y = pytree_normalized(dst)
    n = len(dst)

    # compute angle between x and y in their spanning space
    theta = jnp.arccos(jnp.dot(
        x, y))  # they are normalized so there is no denominator
    if jnp.isclose(theta, 0):
        print("x and y are co-linear")
    # construct the 2d rotation matrix connecting x to y in their spanning space
    R = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                   [jnp.sin(theta), jnp.cos(theta)]])
    __assert_rotation(R)
    # get projections onto Span<x,y> and its orthogonal complement
    u = x
    # v: the normalized component of y orthogonal to u (Gram-Schmidt step)
    v = pytree_normalized(pytree_sub(y, (jnp.dot(u, y) * u)))
    P = jnp.outer(u, u.T) + jnp.outer(
        v, v.T)  # projection onto 2d space spanned by x and y
    Q = jnp.eye(
        n) - P  # projection onto the orthogonal complement of Span<x,y>
    # lift the rotation matrix into the n-dimensional space
    uv = jnp.hstack((u[:, None], v[:, None]))

    R = Q + jnp.dot(uv, jnp.dot(R, uv.T))
    __assert_rotation(R)
    # final check: R applied to x should reproduce y (within rtol=0.25)
    if jnp.any(jnp.logical_not(jnp.isclose(jnp.dot(R, x), y, rtol=0.25))):
        print("Rotation matrix did not work")
    return R
    def cond(val):
        """Tell the rejection-sampling loop whether to draw again.

        Args:
            val: A three-element tuple; only the middle element, the acceptance
                flag of the previous proposal, is inspected.

        Returns:
            out: The negated acceptance flag, i.e. true while the proposal has
                not been accepted and sampling should continue.

        """
        _proposal, accepted, _iteration = val
        keep_sampling = jnp.logical_not(accepted)
        return keep_sampling
Пример #13
0
def solve_implicit(ks, a, b, c, d, b_edge=None, d_edge=None):
    """Mask the tridiagonal coefficients with ``ks`` and solve the system.

    Along the third axis, positions at or beyond ``ks`` (where ``ks >= 0``)
    are treated as active ("water"); the position equal to ``ks`` is the edge.
    Outside the active region the system is replaced by a trivial one
    (main diagonal 1, everything else 0). ``b_edge`` / ``d_edge``, when given,
    override ``b`` / ``d`` at the edge position.

    Returns:
        ``(solution, water_mask, a_tri, b_tri, c_tri, d_tri)``.
    """
    ks_expanded = ks[:, :, np.newaxis]
    levels = np.arange(a.shape[2])[np.newaxis, np.newaxis, :]
    land_mask = (ks >= 0)[:, :, np.newaxis]
    edge_mask = land_mask & (levels == ks_expanded)
    water_mask = land_mask & (levels >= ks_expanded)

    # The lower diagonal is additionally zeroed at the edge position.
    a_tri = water_mask * a * np.logical_not(edge_mask)
    c_tri = water_mask * c

    b_tri = where(water_mask, b, 1.)
    if b_edge is not None:
        b_tri = where(edge_mask, b_edge, b_tri)

    d_tri = water_mask * d
    if d_edge is not None:
        d_tri = where(edge_mask, d_edge, d_tri)

    solution = solve_tridiag(a_tri, b_tri, c_tri, d_tri)
    return solution, water_mask, a_tri, b_tri, c_tri, d_tri
Пример #14
0
def sequence_prediction_metrics(
        logits: jnp.ndarray,
        labels: jnp.ndarray,
        mask: Optional[jnp.ndarray] = None) -> Dict[str, float]:
    """Compute the metrics for sequence prediction.

    Log-probabilities that are not selected (non-label classes, masked-out
    positions) are dropped with ``jnp.where`` rather than multiplied by zero,
    so a ``-inf`` log-probability in such a position cannot turn the result
    into NaN.

    Args:
      logits: [B, T, V] array of logits.
      labels: [B, T] array of labels.
      mask: [B, T] array of binary masks, if provided.

    Returns:
      metrics: a dictionary with keys ``loss``, ``total_loss``, ``total_count``,
        ``token_accuracy``, ``seq_accuracy`` and ``log_probs``.
    """
    vocab_size = logits.shape[-1]
    logps = jax.nn.log_softmax(logits)
    labels_one_hot = jax.nn.one_hot(labels, vocab_size)
    # Select instead of multiply: (-inf) * 0 would be NaN.
    class_logps = jnp.sum(jnp.where(labels_one_hot > 0, logps, 0.0), axis=-1)
    prediction_correct = jnp.argmax(logits, axis=-1) == labels
    if mask is not None:
        # Zero out masked positions before weighting, for the same reason.
        masked_logps = mask * jnp.where(mask != 0, class_logps, 0.0)
        total_count = jnp.sum(mask)
        tokens_correct = jnp.sum(prediction_correct * mask)
        # A sequence is correct when every unmasked token is correct.
        seq_correct = jnp.all(jnp.logical_or(prediction_correct,
                                             jnp.logical_not(mask)),
                              axis=-1)
    else:
        masked_logps = class_logps
        total_count = np.prod(class_logps.shape)
        tokens_correct = jnp.sum(prediction_correct)
        seq_correct = jnp.all(prediction_correct, axis=-1)

    token_accuracy = tokens_correct.astype(jnp.float32) / total_count
    seq_accuracy = jnp.mean(seq_correct)
    log_probs = jnp.mean(jnp.sum(masked_logps, axis=-1))
    total_loss = -jnp.sum(masked_logps)
    loss = total_loss / total_count
    return dict(
        loss=loss,
        total_loss=total_loss,
        total_count=total_count,
        token_accuracy=token_accuracy,
        seq_accuracy=seq_accuracy,
        log_probs=log_probs,
    )
Пример #15
0
    def setup(self):
        """Build ``self.flow`` as a chain of masked coupling layers.

        One ``distrax.MaskedCoupling`` layer is created per entry of
        ``self.conditioners``. The mask alternates 0/1 over the flattened
        event and is inverted after every layer.
        """
        parity = jnp.arange(0, np.prod(self.event_shape)) % 2
        current_mask = jnp.reshape(parity, self.event_shape).astype(bool)

        coupling_layers = []
        for conditioner in self.conditioners:
            coupling_layers.append(
                distrax.MaskedCoupling(mask=current_mask,
                                       bijector=self.bijector_fn,
                                       conditioner=conditioner))
            # The next layer uses the complementary mask.
            current_mask = jnp.logical_not(current_mask)

        self.flow = distrax.Chain(coupling_layers)
Пример #16
0
def macula_matrix(d, k, n):
  """Produces d-separable design matrix.

  Rows are indexed by the d-subsets of ``range(n)`` and columns by the
  k-subsets, both in ``itertools.combinations`` order. Entry ``(i, j)`` is
  True exactly when the i-th d-subset is contained in the j-th k-subset.

  Args:
    d: size of the subsets indexing the rows.
    k: size of the subsets indexing the columns.
    n: size of the ground set.

  Returns:
    Boolean array of shape ``(comb(n, d), comb(n, k))``.
  """
  # https://core.ac.uk/download/pdf/82758506.pdf
  # exact=True keeps the binomial coefficients as exact integers instead of
  # truncating a floating-point approximation.
  n_groups = int(scipy.special.comb(n, d, exact=True))
  n_cols = int(scipy.special.comb(n, k, exact=True))

  # Indicator vectors of every d-subset and every k-subset.
  d_vec = np.zeros((n_groups, n), dtype=bool)
  k_vec = np.zeros((n_cols, n), dtype=bool)
  for i, comb_g in enumerate(itertools.combinations(range(n), d)):
    d_vec[i, list(comb_g)] = True
  for j, comb_c in enumerate(itertools.combinations(range(n), k)):
    k_vec[j, list(comb_c)] = True

  # A d-subset lies inside a k-subset iff their intersection has size d.
  # One integer matrix product replaces the Python-level double loop.
  overlap = d_vec.astype(np.int64) @ k_vec.T.astype(np.int64)
  return overlap == d
Пример #17
0
def observe_nb2(name, latent, det_prob, dispersion, obs=None):
    """Register an ``NB2`` observation site with mean ``det_prob * latent``.

    Observations that are non-finite or negative are replaced by 0.0 and
    excluded through ``numpyro.handlers.mask``; a warning is issued when any
    such value is present. The mean is also recorded as the deterministic
    site ``"mean_" + name``.

    Returns:
        The value returned by ``numpyro.sample`` for the site ``name``.
    """
    if obs is None:
        mask = True
    else:
        mask = np.isfinite(obs) & (obs >= 0.0)
        obs = np.where(mask, obs, 0.0)

    has_invalid = np.any(np.logical_not(mask))
    if has_invalid:
        warnings.warn('Some observed values are invalid')

    mean = np.broadcast_to(det_prob, latent.shape) * latent
    numpyro.deterministic("mean_" + name, mean)

    likelihood = NB2(mu=mean, k=dispersion)
    with numpyro.handlers.mask(mask_array=mask):
        return numpyro.sample(name, likelihood, obs=obs)
Пример #18
0
def make_flow_model(event_shape: Sequence[int], num_layers: int,
                    hidden_sizes: Sequence[int],
                    num_bins: int) -> distrax.Transformed:
    """Creates the flow model.

    The model is a uniform base distribution on the unit cube transformed by
    the inverse of a chain of ``num_layers`` masked coupling layers, each using
    a rational-quadratic spline on [0, 1] and a complementary mask to the
    previous layer.
    """
    def bijector_fn(params: Array):
        return distrax.RationalQuadraticSpline(params,
                                               range_min=0.,
                                               range_max=1.)

    # A rational-quadratic spline needs `num_bins` widths, `num_bins` heights
    # and `num_bins + 1` knot slopes: `3 * num_bins + 1` parameters in total.
    num_bijector_params = 3 * num_bins + 1

    # Alternating binary mask over the flattened event.
    parity = jnp.arange(0, np.prod(event_shape)) % 2
    mask = jnp.reshape(parity, event_shape).astype(bool)

    layers = []
    for _ in range(num_layers):
        conditioner = make_conditioner(event_shape, hidden_sizes,
                                       num_bijector_params)
        layers.append(
            distrax.MaskedCoupling(mask=mask,
                                   bijector=bijector_fn,
                                   conditioner=conditioner))
        # Flip the mask after each layer.
        mask = jnp.logical_not(mask)

    # We invert the flow so that the `forward` method is called with `log_prob`.
    flow = distrax.Inverse(distrax.Chain(layers))
    unit_uniform = distrax.Uniform(low=jnp.zeros(event_shape),
                                   high=jnp.ones(event_shape))
    base_distribution = distrax.Independent(
        unit_uniform, reinterpreted_batch_ndims=len(event_shape))

    return distrax.Transformed(base_distribution, flow)
Пример #19
0
    def __call__(self, rng: np.ndarray, particles: np.ndarray, rho: float,
                 log_posterior_params: Dict[str, np.ndarray],
                 log_base_measure_params: Dict[str, np.ndarray]):
        """Carries out procedure 4 in https://arxiv.org/pdf/1101.6037.pdf.

    Proposals are drawn from the model and each particle is replaced by its
    proposal with the Metropolis-Hastings acceptance probability.

    One expects that fit_model has been called right before to store the model
    in self.model

    Args:
     rng: np.ndarray<int> random key
     particles: np.ndarray [n_particles,n_patients] plausible infections states
     rho: float, scaling for posterior.
     log_posterior_params: Dict of parameters to compute log-posterior.
     log_base_measure_params: Dict of parameters to compute log-base measure.

    Returns:
     A np.ndarray representing the new particles.
    """
        proposal_key, accept_key = jax.random.split(rng, 2)
        n_samples = particles.shape[0]

        proposed, logprop_proposed, logprop_particles = self.sample_from_model(
            proposal_key, particles)

        def tempered(states):
            return bayes.tempered_logpos_logbase(states, log_posterior_params,
                                                 log_base_measure_params, rho)

        llparticles = tempered(particles)
        llproposed = tempered(proposed)

        # Log of the acceptance ratio, including the proposal correction.
        logratio = llproposed - llparticles + logprop_particles - logprop_proposed
        accept_prob = np.minimum(np.exp(logratio), 1)
        uniforms = jax.random.uniform(accept_key, shape=(n_samples, ))
        accepted = uniforms < accept_prob
        rejected = np.logical_not(accepted)
        # Row-wise selection: proposal where accepted, old particle otherwise.
        return (accepted[:, np.newaxis] * proposed +
                rejected[:, np.newaxis] * particles)
Пример #20
0
  def update(updates, state, params=None):
    """Apply the inner update unless the incoming updates are non-finite.

    The inner transformation runs when every leaf of ``updates`` is finite, or
    when the count of consecutive non-finite steps exceeds
    ``max_consecutive_errors``. Otherwise zero updates are returned and the
    inner state is left untouched. The returned state carries the refreshed
    counters.
    """
    inner_state = state.inner_state
    leaves = tree_flatten(updates)[0]
    leaf_is_finite = [jnp.all(jnp.isfinite(leaf)) for leaf in leaves]
    isfinite = jnp.all(jnp.array(leaf_is_finite))
    # The consecutive counter resets to zero on a finite step.
    notfinite_count = jnp.where(isfinite, jnp.zeros([], jnp.int64),
                                1 + state.notfinite_count)

    def _run_inner(_):
      return inner.update(updates, inner_state, params)

    def _skip(_):
      return (tree_map(jnp.zeros_like, updates), inner_state)

    should_apply = jnp.logical_or(isfinite,
                                  notfinite_count > max_consecutive_errors)
    updates, new_inner_state = lax.cond(
        should_apply, _run_inner, _skip, operand=None)

    new_state = ApplyIfFiniteState(
        notfinite_count=notfinite_count,
        last_finite=isfinite,
        total_notfinite=jnp.logical_not(isfinite) + state.total_notfinite,
        inner_state=new_inner_state)
    return updates, new_state
Пример #21
0
    def __call__(self, inputs: Tuple[ndarray, ndarray]):
        """Run the LSTM over the sequence and score decoder steps against encoder steps.

        ``inputs`` is a pair ``(input_seq, input_mask)``; ``input_seq`` must be
        3-dimensional with the batch on the first axis. The first
        ``self.padded_input_len`` LSTM outputs are treated as encoder states
        (plus the initial hidden state appended at the end), the rest as
        decoder states. Returns log-softmax attention scores with the batch
        axis first and the encoder-position axis last.
        """
        input_seq, input_mask = inputs
        batch_size, _, _ = input_seq.shape

        # Swap the first two axes of both inputs for the unroll.
        seq_tm = jnp.swapaxes(input_seq, 0, 1)
        mask_tm = jnp.swapaxes(input_mask, 0, 1)

        # Reset signal: the negated mask, shifted one step to the right
        # (the first entry stays as it was).
        negated = jnp.logical_not(mask_tm)
        reset_mask = negated.at[1:].set(negated[:-1])

        init_state: hk.LSTMState = self.lstm.initial_state(batch_size)  # type: ignore
        hx, _ = hk.dynamic_unroll(self.lstm, (seq_tm, reset_mask), init_state)

        # Split encoder/decoder states; the initial hidden state is appended
        # and serves as the encoder state for the [end] token.
        split = self.padded_input_len
        encoder_hx = jnp.concatenate([hx[:split], init_state.hidden[None]],
                                     axis=0)
        decoder_hx = hx[split:]

        # Values and queries for the attention mechanism.
        encoder_value = self.enc_att_fc(encoder_hx)[None]
        decoder_query = self.dec_att_fc(decoder_hx)[:, None]

        # Scaled dot-product energy.
        products = encoder_value * decoder_query
        energy = jnp.sum(products, axis=-1) / math.sqrt(products.shape[-1])

        # Positions excluded by the input mask get -inf energy.
        visible = mask_tm[:split + 1][None]
        energy = jnp.where(visible, energy, float('-inf'))

        # Normalize over the encoder positions.
        energy = jax.nn.log_softmax(energy, axis=1)

        # Batch first, logit last.
        return jnp.transpose(energy, [2, 0, 1])
Пример #22
0
def compute_tp_fp_fn_weighted(
    predictions: Array, labels: Array, weights: Array,
    ignore_class: Optional[int]) -> Tuple[float, float, float]:
  """Compute true positives, false positives and false negatives.

  A position counts as a true positive when prediction and label agree, and
  as both a false positive and a false negative when they differ. With
  `ignore_class` set, false positives whose prediction is that class, false
  negatives whose label is that class, and true positives of that class are
  dropped. Each count is the sum of `weights` over the selected positions.

  Args:
   predictions: [batch, length] categorical predictions int array.
   labels: [batch, length] categorical labels int array.
   weights: [batch, length].
   ignore_class: which class to ignore in the computations

  Returns:
    Tuple with numbers of true positive, false positive and false negative
      predictions.
  """
  true_positives = (predictions == labels)
  false_positives = jnp.logical_not(true_positives)
  false_negatives = false_positives

  if ignore_class is not None:
    dont_ignore_predictions = (predictions != ignore_class)
    dont_ignore_labels = (labels != ignore_class)
    # For true positives prediction == label, so filtering on the prediction
    # is the same as filtering on the label.
    true_positives = jnp.logical_and(true_positives, dont_ignore_predictions)
    false_positives = jnp.logical_and(false_positives, dont_ignore_predictions)
    false_negatives = jnp.logical_and(false_negatives, dont_ignore_labels)

  def get_weighted_sum(values):
    # Element-wise product then full reduction. `jnp.dot` is only a weighted
    # sum for 1-D inputs; on [batch, length] arrays it is a matrix product.
    values = values.astype(weights.dtype)
    return jnp.sum(values * weights)

  n_true_positive = get_weighted_sum(true_positives)
  n_false_positive = get_weighted_sum(false_positives)
  n_false_negative = get_weighted_sum(false_negatives)

  return n_true_positive, n_false_positive, n_false_negative
Пример #23
0
    def add_mean(i, means):
        """Place mean number ``i``, keeping it away from the earlier means.

        Candidates are drawn uniformly inside ``bounds`` until none of the
        means with index below ``i`` is closer than ``min_distance`` to the
        candidate. The accepted candidate is written to ``means[i]`` and the
        updated array is returned.
        """
        placed = jnp.arange(num_means) < i
        # Weight 1 for means that are already placed, +inf for the others so
        # that they can never count as "too close". Built with `jnp.where`:
        # the arithmetic form `placed * 1. + ~placed * inf` evaluates
        # `0 * inf = nan` for the placed entries, which makes every distance
        # comparison False and silently disables the spacing check.
        mask = jnp.where(placed, 1., jnp.inf)

        def spaced_mean_body(state):
            # Draw a fresh candidate; everything else in the state is carried.
            key, means, mask, unused_sample = state
            key, subkey = jax.random.split(key)
            sample = jax.random.uniform(subkey,
                                        shape=[data_dim],
                                        minval=bounds[0],
                                        maxval=bounds[1])
            return key, means, mask, sample

        def spaced_mean_cond(state):
            # Keep sampling while any placed mean is closer than min_distance.
            unused_key, means, mask, sample = state
            dists = mask * jax.vmap(dist, in_axes=(0, None))(means, sample)
            return jnp.any(jnp.less(dists, min_distance))

        # The loop starts from `means[0]` as the candidate; for i >= 1 it is at
        # distance 0 from a placed mean, so at least one real draw is made.
        _, _, _, new_mean = jax.lax.while_loop(spaced_mean_cond,
                                               spaced_mean_body,
                                               (key, means, mask, means[0]))
        # `jax.ops.index_update` has been removed from JAX; use `.at[].set()`.
        means = jnp.asarray(means).at[i].set(new_mean)
        return means
Пример #24
0
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
    """Bind a scatter primitive and attach out-of-bounds and NaN assertions.

    The primitive is always bound. When ``ErrorCategory.OOB`` is not in
    ``enabled_errors`` the incoming ``error`` is returned unchanged; otherwise
    an out-of-bounds assertion and then a NaN assertion on the output are
    folded into it.

    Returns:
        ``(out, error)``: the primitive's output and the updated error.
    """
    scatter_params = dict(
        update_jaxpr=update_jaxpr,
        update_consts=update_consts,
        dimension_numbers=dimension_numbers,
        indices_are_sorted=indices_are_sorted,
        unique_indices=unique_indices,
        mode=mode,
    )
    out = prim.bind(operand, indices, updates, **scatter_params)

    # NOTE(review): with OOB checks disabled the NaN check below is skipped
    # too - confirm this is intended.
    if ErrorCategory.OOB not in enabled_errors:
        return out, error

    in_bounds = scatter_in_bounds(operand, indices, updates, dimension_numbers)
    oob_msg = f'out-of-bounds indexing while updating at {summary()}'
    error_after_oob = assert_func(error, in_bounds, oob_msg, None)

    output_has_nan = jnp.any(jnp.isnan(out))
    no_nans = jnp.logical_not(output_has_nan)
    nan_msg = f'nan generated by primitive {prim.name} at {summary()}'
    return out, assert_func(error_after_oob, no_nans, nan_msg, None)
Пример #25
0
 def while_cond(state):
     """Keep looping while the first element of ``state`` (the done flag) is false."""
     finished, _unused_a, _unused_b = state
     keep_going = jnp.logical_not(finished)
     return keep_going
Пример #26
0
def newton(
    backward_differences: np.ndarray,
    max_num_iters: Union[np.ndarray, float, int],
    newton_coefficient: Union[np.ndarray, float, int],
    ode_fn_vec: Callable,
    order: Union[np.ndarray, float, int],
    step_size: Union[np.ndarray, float, int],
    time: Union[np.ndarray, float, int],
    tol: Union[np.ndarray, float],
    unitary: np.ndarray,
    upper: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Runs Newton's method to solve the BDF equation.

    The initial guess is the sum of the backward differences up to ``order``.
    Each iteration solves an upper-triangular system built from ``unitary``
    and ``upper`` (presumably a QR factorization of the Newton matrix -
    confirm against the caller) and adds the result to both the state and the
    next backward difference. Iteration runs inside ``jax.lax.while_loop``
    until the ``finished`` flag is set.

    Returns:
        A tuple ``(converged, next_backward_difference, next_state_vec,
        num_iters)`` taken from the final iterand.
    """
    # Sum of backward differences of index 0..order; higher rows are zeroed.
    initial_guess = np.sum(
        np.where(
            np.arange(MAX_ORDER + 1).reshape(-1, 1) <= order,
            backward_differences[:MAX_ORDER + 1],
            np.zeros_like(backward_differences)[:MAX_ORDER + 1],
        ),
        axis=0,
    )

    # Part of the right-hand side that does not change between iterations:
    # backward differences 1..order weighted by RECIPROCAL_SUMS.
    rhs_constant_term = newton_coefficient * np.sum(
        np.where(
            np.arange(1, MAX_ORDER + 1).reshape(-1, 1) <= order,
            RECIPROCAL_SUMS[1:, np.newaxis] *
            backward_differences[1:MAX_ORDER + 1],
            np.zeros_like(backward_differences)[1:MAX_ORDER + 1],
        ),
        axis=0,
    )

    next_time = time + step_size

    def newton_body(iterand):
        """Performs one iteration of Newton's method."""
        next_backward_difference = iterand.next_backward_difference
        next_state_vec = iterand.next_state_vec

        rhs = (newton_coefficient * step_size *
               ode_fn_vec(next_time, next_state_vec) - rhs_constant_term -
               next_backward_difference)
        # Solve upper @ delta = unitary^T @ rhs.
        delta = np.squeeze(
            jax.scipy.linalg.solve_triangular(upper,
                                              np.matmul(
                                                  np.transpose(unitary),
                                                  rhs[:, np.newaxis]),
                                              lower=False))
        num_iters = iterand.num_iters + 1

        next_backward_difference += delta
        next_state_vec += delta

        # Ratio of successive update norms, used as a contraction estimate.
        delta_norm = np.linalg.norm(delta)
        lipschitz_const = delta_norm / iterand.prev_delta_norm

        # Stop if method has converged.
        approx_dist_to_sol = lipschitz_const / (1.0 -
                                                lipschitz_const) * delta_norm
        close_to_sol = approx_dist_to_sol < tol
        delta_norm_is_zero = np.equal(delta_norm,
                                      np.array(0.0, dtype=np.float64))
        converged = close_to_sol | delta_norm_is_zero
        finished = converged

        # Stop if any of the following conditions are met:
        # (A) We have hit the maximum number of iterations.
        # (B) The method is converging too slowly.
        # (C) The method is not expected to converge.
        too_slow = lipschitz_const > 1.0
        finished = finished | too_slow

        too_many_iters = np.equal(num_iters, max_num_iters)
        num_iters_left = max_num_iters - num_iters
        wont_converge = approx_dist_to_sol * lipschitz_const**num_iters_left > tol
        finished = finished | too_many_iters | wont_converge

        return _NewtonIterand(
            converged=converged,
            finished=finished,
            next_backward_difference=next_backward_difference,
            next_state_vec=next_state_vec,
            num_iters=num_iters,
            prev_delta_norm=delta_norm,
        )

    # Starting iterand. prev_delta_norm is negative zero, so the first
    # lipschitz_const is a division by -0.0.
    iterand = _NewtonIterand(
        converged=False,
        finished=False,
        next_backward_difference=np.zeros_like(initial_guess),
        next_state_vec=initial_guess,
        num_iters=0,
        prev_delta_norm=(np.array(-0.0)),
    )
    iterand = jax.lax.while_loop(
        lambda iterand: np.logical_not(iterand.finished), newton_body, iterand)
    # TODO(Krishna): need to double check this
    return (
        iterand.converged,
        iterand.next_backward_difference,
        iterand.next_state_vec,
        iterand.num_iters,
    )
Пример #27
0
 def cond_fun(carry):
     """Continue while not converged and the counter is below ``max_iter``."""
     _n, x_last, x = carry
     still_moving = np.logical_not(convergence_condition(a, x, x_last))
     under_budget = lax.lt(_n, max_iter)
     return np.logical_and(still_moving, under_budget)
def loss_fn(
    model,
    padded_example_and_rng,
    static_metadata,
    regularization_weights=None,
    reinforce_weight=1.0,
    baseline_weight=0.001,
):
  """Loss function for multi-pointer task.

  Args:
    model: The model to evaluate. Called as `model(padded_example,
      static_metadata)` and expected to return joint log-probabilities.
    padded_example_and_rng: Padded example to evaluate on, with a PRNGKey.
    static_metadata: Padding configuration for the example, since this may vary
      for different examples.
    regularization_weights: Associates side output key regexes with
      regularization penalties. Falsy (None or empty) disables regularization.
    reinforce_weight: Weight to give to the reinforce term.
    baseline_weight: Weight to give to the baseline.

  Returns:
    Tuple of loss and metrics. The loss is the negative joint log-probability,
    plus any regularization penalties, plus (when a single-sample side output
    is present) the REINFORCE surrogate and baseline penalty. Metrics is a dict
    whose leaves have all been cast to float32.

  Raises:
    ValueError: If a regularization query matches no collected side output, or
      if the single-sample side outputs do not contain exactly one
      `...one_sample_log_prob_per_edge_per_node` entry and exactly one
      `...one_sample_reward_baseline` entry.
  """
  padded_example, rng = padded_example_and_rng

  # Run the model.
  with side_outputs.collect_side_outputs() as collected_side_outputs:
    with flax.nn.stochastic(rng):
      joint_log_probs = model(padded_example, static_metadata)

  # Computing the loss:
  # Extract logits for the correct location.
  log_probs_at_bug = joint_log_probs[padded_example.bug_node_index, :]
  # Compute p(repair) = sum[ p(node) p(repair | node) ]
  # -> log p(repair) = logsumexp[ log p(node) + log p (repair | node) ]
  # Taking the log of the mask sends masked-out (zero) entries to -inf, which
  # removes them from the logsumexp.
  log_prob_joint = jax.scipy.special.logsumexp(
      log_probs_at_bug + jnp.log(padded_example.repair_node_mask))

  # Metrics:
  # Marginal log probabilities:
  log_prob_bug = jax.scipy.special.logsumexp(log_probs_at_bug)
  log_prob_repair = jax.scipy.special.logsumexp(
      jax.scipy.special.logsumexp(joint_log_probs, axis=0) +
      jnp.log(padded_example.repair_node_mask))

  # Conditional log probabilities:
  log_prob_repair_given_bug = log_prob_joint - log_prob_bug
  log_prob_bug_given_repair = log_prob_joint - log_prob_repair

  # Majority accuracy (1 if we assign the correct tuple > 50%):
  # (note that this is easier to compute, since we can't currently aggregate
  # probability separately for each candidate.)
  log_half = jnp.log(0.5)
  majority_acc_joint = log_prob_joint > log_half

  # Probabilities associated with each node.
  node_node_probs = jnp.exp(joint_log_probs)
  # Accumulate across unique candidates by identifier. This has the same shape,
  # but only the first few values will be populated.
  node_candidate_probs = padded_example.unique_candidate_operator.apply_add(
      in_array=node_node_probs,
      out_array=jnp.zeros_like(node_node_probs),
      in_dims=[1],
      out_dims=[1])

  # Classify: 50% decision boundary
  # Row 0 and column 0 are zeroed so that only the remaining mass counts as
  # "buggy" (index 0 is also what `actual_nobug` tests against below).
  only_buggy_probs = node_candidate_probs.at[0, :].set(0).at[:, 0].set(0)
  p_buggy = jnp.sum(only_buggy_probs)
  pred_nobug = p_buggy <= 0.5

  # Localize/repair: take most likely bug position, conditioned on being buggy
  pred_bug_loc, pred_cand_id = jnp.unravel_index(
      jnp.argmax(only_buggy_probs), only_buggy_probs.shape)

  actual_nobug = jnp.array(padded_example.bug_node_index == 0)

  actual_bug = jnp.logical_not(actual_nobug)
  pred_bug = jnp.logical_not(pred_nobug)

  metrics = {
      'nll/joint':
          -log_prob_joint,
      'nll/marginal_bug':
          -log_prob_bug,
      'nll/marginal_repair':
          -log_prob_repair,
      'nll/repair_given_bug':
          -log_prob_repair_given_bug,
      'nll/bug_given_repair':
          -log_prob_bug_given_repair,
      'inaccuracy/legacy_overall':
          1 - majority_acc_joint,
      'inaccuracy/overall':
          (~((actual_nobug & pred_nobug) |
             (actual_bug & pred_bug &
              (pred_bug_loc == padded_example.bug_node_index) &
              (pred_cand_id == padded_example.repair_id)))),
      'inaccuracy/classification_overall': (actual_nobug != pred_nobug),
      'inaccuracy/classification_given_nobug':
          train_util.RatioMetric(
              numerator=(actual_nobug & ~pred_nobug), denominator=actual_nobug),
      'inaccuracy/classification_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug & ~pred_bug), denominator=actual_bug),
      'inaccuracy/localized_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~(pred_bug_loc == padded_example.bug_node_index)),
              denominator=actual_bug),
      'inaccuracy/repaired_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~(pred_cand_id == padded_example.repair_id)),
              denominator=actual_bug),
      'inaccuracy/localized_repaired_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~((pred_bug_loc == padded_example.bug_node_index) &
                             (pred_cand_id == padded_example.repair_id))),
              denominator=actual_bug),
      'inaccuracy/overall_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~(pred_bug &
                             (pred_bug_loc == padded_example.bug_node_index) &
                             (pred_cand_id == padded_example.repair_id))),
              denominator=actual_bug),
  }

  loss = -log_prob_joint

  # Report every scalar side output as a metric.
  for k, v in collected_side_outputs.items():
    # Flax collection keys will start with "/".
    if v.shape == ():  # pylint: disable=g-explicit-bool-comparison
      metrics['side' + k] = v

  if regularization_weights:
    total_regularization = 0
    for query, weight in regularization_weights.items():
      logging.info('Regularizing side outputs matching query %s', query)
      found = False
      for k, v in collected_side_outputs.items():
        if re.search(query, k):
          found = True
          logging.info('Regularizing %s with weight %f', k, weight)
          total_regularization += weight * v
      # A query that matches nothing is treated as a configuration error rather
      # than silently contributing no penalty.
      if not found:
        raise ValueError(
            f'Regularization query {query} did not match any side output. '
            f'Side outputs were {set(collected_side_outputs.keys())}')

    loss = loss + total_regularization

  is_single_sample = any(
      k.endswith('one_sample_log_prob_per_edge_per_node')
      for k in collected_side_outputs)
  if is_single_sample:
    # Single-element unpacking: exactly one matching side output is required.
    log_prob, = [
        v for k, v in collected_side_outputs.items()
        if k.endswith('one_sample_log_prob_per_edge_per_node')
    ]
    baseline, = [
        v for k, v in collected_side_outputs.items()
        if k.endswith('one_sample_reward_baseline')
    ]

    # Zero out log-probs at positions past the real node count (padding) along
    # the last axis before summing.
    num_real_nodes = padded_example.input_graph.bundle.graph_metadata.num_nodes
    valid_mask = (
        jnp.arange(static_metadata.bundle_padding.static_max_metadata.num_nodes)
        < num_real_nodes)
    log_prob = jnp.where(valid_mask[None, :], log_prob, 0)
    total_log_prob = jnp.sum(log_prob)

    # The (loss - baseline) factor is held constant, so gradients flow only
    # through total_log_prob.
    reinforce_virtual_cost = (
        total_log_prob * jax.lax.stop_gradient(loss - baseline))
    baseline_penalty = jnp.square(loss - baseline)

    # Subtracting the stop-gradient copy makes this term's value exactly zero
    # while leaving its gradient intact, so it changes gradients but not the
    # reported loss value.
    reinforce_virtual_cost_zeroed = reinforce_virtual_cost - jax.lax.stop_gradient(
        reinforce_virtual_cost)

    loss = (
        loss + reinforce_weight * reinforce_virtual_cost_zeroed +
        baseline_weight * baseline_penalty)
    metrics['reinforce_virtual_cost'] = reinforce_virtual_cost
    metrics['baseline_penalty'] = baseline_penalty
    metrics['baseline'] = baseline
    metrics['total_log_prob'] = total_log_prob

  # `jax.tree_map` is a deprecated alias that newer JAX releases remove;
  # `jax.tree_util.tree_map` is the same function under its stable path.
  metrics = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), metrics)
  return loss, metrics
Пример #29
0
 def cond_fn(*args):
     """Return True while fewer than 100 steps ran and some entry is not done.

     The loop state arrives as the single positional argument, a 5-tuple whose
     first element is the step counter and whose third is the `done` flags.
     """
     step, _, finished, _, _ = args[0]
     under_limit = step < 100
     some_pending = jnp.logical_not(jnp.all(finished))
     return jnp.bitwise_and(under_limit, some_pending)
Пример #30
0
 def while_cond(val):
     """Return True as long as sin(1 / val) is not NaN."""
     reciprocal_sine = jnp.sin(1. / val)
     is_nan = jnp.isnan(reciprocal_sine)
     return jnp.logical_not(is_nan)