Code example #1
File: test_utils.py Project: jackd/grax
import jax
import jax.numpy as jnp
from jax.experimental.sparse import COO  # assumed source of COO in this project


def random_adjacency(key: jnp.ndarray,
                     num_nodes: int,
                     num_edges: int,
                     dtype=jnp.float32) -> COO:
    """
    Get the adjacency matrix of a random connected undirected graph.

    Note that `num_edges` is only approximate. The process of creating edges is:
    - sample `num_edges` random edges
    - remove self-edges
    - add ring edges
    - add reverse edges
    - filter duplicates

    Args:
        key: `jax.random.PRNGKey`.
        num_nodes: number of nodes in returned graph.
        num_edges: number of random internal edges initially added.
        dtype: dtype of returned JAXSparse.

    Returns:
        COO, shape (num_nodes, num_nodes), weights all ones.
    """
    shape = num_nodes, num_nodes

    internal_indices = jax.random.uniform(
        key,
        shape=(num_edges, ),
        dtype=jnp.float32,
        maxval=num_nodes**2,
    ).astype(jnp.int32)
    # remove randomly sampled self-edges.
    self_edges = (internal_indices // num_nodes) == (internal_indices %
                                                     num_nodes)
    internal_indices = internal_indices[jnp.logical_not(self_edges)]

    # add a ring so we know the graph is connected
    r = jnp.arange(num_nodes, dtype=jnp.int32)
    ring_indices = r * num_nodes + (r + 1) % num_nodes
    indices = jnp.concatenate((internal_indices, ring_indices))

    # add reverse indices
    coords = jnp.unravel_index(indices, shape)
    coords_rev = coords[-1::-1]
    indices_rev = jnp.ravel_multi_index(coords_rev, shape)
    indices = jnp.concatenate((indices, indices_rev))

    # filter out duplicates
    indices = jnp.unique(indices)
    row, col = jnp.unravel_index(indices, shape)
    return COO((jnp.ones((row.size, ), dtype=dtype), row, col), shape=shape)
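
A minimal usage sketch for the function above (assuming `COO` is `jax.experimental.sparse.COO`, as imported above, whose instances provide `todense`):

key = jax.random.PRNGKey(0)
adj = random_adjacency(key, num_nodes=5, num_edges=10)
dense = adj.todense()
# the reverse-edge step makes the adjacency symmetric
assert bool(jnp.all(dense == dense.T))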
Code example #2
File: perceiver.py Project: yynst2/deepmind-research
    def decoder_query(self,
                      inputs,
                      modality_sizes=None,
                      inputs_without_pos=None,
                      subsampled_points=None):
        assert self._position_encoding_type != 'none'  # Queries come from elsewhere
        if subsampled_points is not None:
            # unravel_index returns a tuple (x_idx, y_idx, ...)
            # stack to get the [n, d] tensor of coordinates
            pos = jnp.stack(jnp.unravel_index(subsampled_points,
                                              self._output_index_dim),
                            axis=1)
            # Map these coordinates to [-1, 1]
            pos = -1 + 2 * pos / jnp.array(self._output_index_dim)[None, :]
            pos = jnp.broadcast_to(
                pos[None], [inputs.shape[0], pos.shape[0], pos.shape[1]])
            pos_emb = self.output_pos_enc(batch_size=inputs.shape[0], pos=pos)
            pos_emb = jnp.reshape(pos_emb,
                                  [pos_emb.shape[0], -1, pos_emb.shape[-1]])
        else:
            pos_emb = self.output_pos_enc(batch_size=inputs.shape[0])
        if self._concat_preprocessed_input:
            if inputs_without_pos is None:
                raise ValueError('Value is required for inputs_without_pos if'
                                 ' concat_preprocessed_input is True')
            pos_emb = jnp.concatenate([inputs_without_pos, pos_emb], axis=-1)

        return pos_emb
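
The stack-then-rescale pattern above is easy to check in isolation; a minimal sketch (the shapes are illustrative, not from any Perceiver config):

import jax.numpy as jnp

flat = jnp.array([0, 5, 11])
pos = jnp.stack(jnp.unravel_index(flat, (3, 4)), axis=1)
# pos == [[0, 0], [1, 1], [2, 3]]: one (row, col) coordinate per flat index
pos = -1 + 2 * pos / jnp.array([3, 4])[None, :]  # mapped into [-1, 1]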
Code example #3
File: random.py Project: matthewfeickert/jax
def _indices(key):
    if not sparse_shape:
        return jnp.empty((nse, n_sparse), dtype=int)
    flat_ind = random.choice(key,
                             sparse_size,
                             shape=(nse,),
                             replace=not unique_indices)
    return jnp.column_stack(jnp.unravel_index(flat_ind, sparse_shape))
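
Because replace=False makes the sampled flat indices distinct, the stacked rows are distinct coordinates; a small self-contained sketch:

import jax
import jax.numpy as jnp

flat_ind = jax.random.choice(jax.random.PRNGKey(0), 12, shape=(4,), replace=False)
ind = jnp.column_stack(jnp.unravel_index(flat_ind, (3, 4)))
# ind has shape (4, 2): four unique (row, col) positions in a 3x4 array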
Code example #4
def index_to_coordinate_array(idx, offset=4, repeat=1):
    # Turn an array of index values into a tuple of coordinate arrays
    H, W, C = idx.shape[:3]

    # The input indices will be spread out by some offset
    flat_coordinates = idx.ravel() + offset * jnp.arange(
        H * W * C).repeat(repeat)

    return jnp.unravel_index(flat_coordinates, (H, W, C, offset))
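
A quick shape check in the simplest setting (repeat=1; for repeat > 1 the function appears to expect idx to carry a trailing dimension of size repeat so the addition broadcasts):

import jax.numpy as jnp

idx = jnp.zeros((2, 2, 1), dtype=jnp.int32)
coords = jnp.stack(index_to_coordinate_array(idx, offset=4), axis=1)
# coords has shape (4, 4): H * W * C rows of (h, w, c, o) coordinates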
Code example #5
File: _mountain.py Project: shpotes/clustering
def get_cluster(prototypes, prototypes_density):
    num_dims = prototypes.shape[1]
    cluster_id = jnp.unravel_index(jnp.argmax(prototypes_density),
                                   prototypes_density.shape)
    cluster = prototypes[(0, tuple(range(num_dims)), *cluster_id)]
    cluster = jnp.expand_dims(cluster,
                              axis=tuple(range(1, num_dims + 1)))[jnp.newaxis,
                                                                  ...]
    return cluster, prototypes_density[cluster_id]
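
The core move is the classic argmax-to-coordinates idiom, shown standalone:

import jax.numpy as jnp

density = jnp.array([[0.1, 0.7],
                     [0.3, 0.2]])
peak = jnp.unravel_index(jnp.argmax(density), density.shape)
# peak == (0, 1): the row and column of the densest prototype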
Code example #6
import jax
import jax.numpy as jnp


def propose_spin_flip_Z2(key, s, info):
    idxKey, flipKey = jax.random.split(key)
    idx = jax.random.randint(idxKey, (1, ), 0, s.size)[0]
    idx = jnp.unravel_index(idx, s.shape)
    update = (s[idx] + 1) % 2
    s = s.at[idx].set(update)  # jax.ops.index_update has been removed from JAX
    # On average, do a global spin flip every 5 updates to
    # reflect Z_2 symmetry
    doFlip = jax.random.randint(flipKey, (1, ), 0, 5)[0]
    return jax.lax.cond(doFlip == 0, lambda x: 1 - x, lambda x: x, s)
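
A minimal usage sketch (the info argument is unused by the function, so None works here):

key = jax.random.PRNGKey(0)
s = jnp.zeros((4, 4), dtype=jnp.int32)
s_new = propose_spin_flip_Z2(key, s, info=None)
# one spin is flipped, and with probability 1/5 the whole lattice is inverted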
Code example #7
File: checkify.py Project: xueeinstein/jax
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
    out = lax.gather_p.bind(operand,
                            start_indices,
                            dimension_numbers=dimension_numbers,
                            slice_sizes=slice_sizes,
                            unique_indices=unique_indices,
                            indices_are_sorted=indices_are_sorted,
                            mode=mode,
                            fill_value=fill_value)

    if ErrorCategory.OOB not in enabled_errors:
        return out, error

    # compare to OOB masking logic in lax._gather_translation_rule
    dnums = dimension_numbers
    operand_dims = np.array(operand.shape)
    num_batch_dims = len(start_indices.shape) - 1

    upper_bound = operand_dims[np.array(dnums.start_index_map)]
    upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)]
    upper_bound = jnp.expand_dims(upper_bound,
                                  axis=tuple(range(num_batch_dims)))
    in_bounds = (start_indices >= 0) & (start_indices <= upper_bound.astype(
        start_indices.dtype))

    # Get first OOB index, axis and axis size so it can be added to the error msg.
    flat_idx = jnp.argmin(in_bounds)
    multi_idx = jnp.unravel_index(flat_idx, start_indices.shape)
    oob_axis = jnp.array(dnums.start_index_map)[multi_idx[-1]]
    oob_axis_size = jnp.array(operand.shape)[oob_axis]
    oob_index = jnp.ravel(start_indices)[flat_idx]
    payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)
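    # Note: the {payload0}/{payload1}/{payload2} placeholders in the message
    # below are deliberately not f-string fields; checkify fills them in from
    # `payload` when the error is reported.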

    msg = (f'out-of-bounds indexing at {summary()} for array of '
           f'shape {operand.shape}: '
           'index {payload0} is out of bounds for axis {payload1} '
           'with size {payload2}.')

    return out, assert_func(error, jnp.all(in_bounds), msg, payload)
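
The flat-argmin-to-coordinates step is the same unravel idiom in miniature:

import jax.numpy as jnp

in_bounds = jnp.array([[True, True],
                       [True, False]])
flat_idx = jnp.argmin(in_bounds)  # first False in flat (row-major) order
multi_idx = jnp.unravel_index(flat_idx, in_bounds.shape)
# multi_idx == (1, 1): batch position and index-vector axis of the OOB entry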
Code example #8
File: jax_dispatch.py Project: jeffreyenos/aesara
def unravelindex(indices, dims, order=order):
    return jnp.unravel_index(indices, dims)
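
Note that jnp.unravel_index takes no order argument (JAX assumes row-major layout), so the order parameter captured here is silently ignored; presumably this dispatch is only reached for C-ordered graphs.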
Code example #9
def loss_fn(
    model,
    padded_example_and_rng,
    static_metadata,
    regularization_weights = None,
    reinforce_weight = 1.0,
    baseline_weight = 0.001,
):
  """Loss function for multi-pointer task.

  Args:
    model: The model to evaluate.
    padded_example_and_rng: Padded example to evaluate on, with a PRNGKey.
    static_metadata: Padding configuration for the example, since this may vary
      for different examples.
    regularization_weights: Associates side output key regexes with
      regularization penalties.
    reinforce_weight: Weight to give to the reinforce term.
    baseline_weight: Weight to give to the baseline.

  Returns:
    Tuple of loss and metrics.
  """
  padded_example, rng = padded_example_and_rng

  # Run the model.
  with side_outputs.collect_side_outputs() as collected_side_outputs:
    with flax.nn.stochastic(rng):
      joint_log_probs = model(padded_example, static_metadata)

  # Computing the loss:
  # Extract log-probabilities for the correct location.
  log_probs_at_bug = joint_log_probs[padded_example.bug_node_index, :]
  # Compute p(repair) = sum[ p(node) p(repair | node) ]
  # -> log p(repair) = logsumexp[ log p(node) + log p (repair | node) ]
  log_prob_joint = jax.scipy.special.logsumexp(
      log_probs_at_bug + jnp.log(padded_example.repair_node_mask))

  # Metrics:
  # Marginal log probabilities:
  log_prob_bug = jax.scipy.special.logsumexp(log_probs_at_bug)
  log_prob_repair = jax.scipy.special.logsumexp(
      jax.scipy.special.logsumexp(joint_log_probs, axis=0) +
      jnp.log(padded_example.repair_node_mask))

  # Conditional log probabilities:
  log_prob_repair_given_bug = log_prob_joint - log_prob_bug
  log_prob_bug_given_repair = log_prob_joint - log_prob_repair

  # Majority accuracy (1 if we assign the correct tuple > 50%):
  # (note that this is easier to compute, since we can't currently aggregate
  # probability separately for each candidate.)
  log_half = jnp.log(0.5)
  majority_acc_joint = log_prob_joint > log_half

  # Probabilities associated with each node.
  node_node_probs = jnp.exp(joint_log_probs)
  # Accumulate across unique candidates by identifier. This has the same shape,
  # but only the first few values will be populated.
  node_candidate_probs = padded_example.unique_candidate_operator.apply_add(
      in_array=node_node_probs,
      out_array=jnp.zeros_like(node_node_probs),
      in_dims=[1],
      out_dims=[1])

  # Classify: 50% decision boundary
  only_buggy_probs = node_candidate_probs.at[0, :].set(0).at[:, 0].set(0)
  p_buggy = jnp.sum(only_buggy_probs)
  pred_nobug = p_buggy <= 0.5

  # Localize/repair: take most likely bug position, conditioned on being buggy
  pred_bug_loc, pred_cand_id = jnp.unravel_index(
      jnp.argmax(only_buggy_probs), only_buggy_probs.shape)

  actual_nobug = jnp.array(padded_example.bug_node_index == 0)

  actual_bug = jnp.logical_not(actual_nobug)
  pred_bug = jnp.logical_not(pred_nobug)

  metrics = {
      'nll/joint':
          -log_prob_joint,
      'nll/marginal_bug':
          -log_prob_bug,
      'nll/marginal_repair':
          -log_prob_repair,
      'nll/repair_given_bug':
          -log_prob_repair_given_bug,
      'nll/bug_given_repair':
          -log_prob_bug_given_repair,
      'inaccuracy/legacy_overall':
          1 - majority_acc_joint,
      'inaccuracy/overall':
          (~((actual_nobug & pred_nobug) |
             (actual_bug & pred_bug &
              (pred_bug_loc == padded_example.bug_node_index) &
              (pred_cand_id == padded_example.repair_id)))),
      'inaccuracy/classification_overall': (actual_nobug != pred_nobug),
      'inaccuracy/classification_given_nobug':
          train_util.RatioMetric(
              numerator=(actual_nobug & ~pred_nobug), denominator=actual_nobug),
      'inaccuracy/classification_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug & ~pred_bug), denominator=actual_bug),
      'inaccuracy/localized_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~(pred_bug_loc == padded_example.bug_node_index)),
              denominator=actual_bug),
      'inaccuracy/repaired_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~(pred_cand_id == padded_example.repair_id)),
              denominator=actual_bug),
      'inaccuracy/localized_repaired_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~((pred_bug_loc == padded_example.bug_node_index) &
                             (pred_cand_id == padded_example.repair_id))),
              denominator=actual_bug),
      'inaccuracy/overall_given_bug':
          train_util.RatioMetric(
              numerator=(actual_bug
                         & ~(pred_bug &
                             (pred_bug_loc == padded_example.bug_node_index) &
                             (pred_cand_id == padded_example.repair_id))),
              denominator=actual_bug),
  }

  loss = -log_prob_joint

  for k, v in collected_side_outputs.items():
    # Flax collection keys will start with "/".
    if v.shape == ():  # pylint: disable=g-explicit-bool-comparison
      metrics['side' + k] = v

  if regularization_weights:
    total_regularization = 0
    for query, weight in regularization_weights.items():
      logging.info('Regularizing side outputs matching query %s', query)
      found = False
      for k, v in collected_side_outputs.items():
        if re.search(query, k):
          found = True
          logging.info('Regularizing %s with weight %f', k, weight)
          total_regularization += weight * v
      if not found:
        raise ValueError(
            f'Regularization query {query} did not match any side output. '
            f'Side outputs were {set(collected_side_outputs.keys())}')

    loss = loss + total_regularization

  is_single_sample = any(
      k.endswith('one_sample_log_prob_per_edge_per_node')
      for k in collected_side_outputs)
  if is_single_sample:
    log_prob, = [
        v for k, v in collected_side_outputs.items()
        if k.endswith('one_sample_log_prob_per_edge_per_node')
    ]
    baseline, = [
        v for k, v in collected_side_outputs.items()
        if k.endswith('one_sample_reward_baseline')
    ]

    num_real_nodes = padded_example.input_graph.bundle.graph_metadata.num_nodes
    valid_mask = (
        jnp.arange(static_metadata.bundle_padding.static_max_metadata.num_nodes)
        < num_real_nodes)
    log_prob = jnp.where(valid_mask[None, :], log_prob, 0)
    total_log_prob = jnp.sum(log_prob)

    reinforce_virtual_cost = (
        total_log_prob * jax.lax.stop_gradient(loss - baseline))
    baseline_penalty = jnp.square(loss - baseline)

    reinforce_virtual_cost_zeroed = reinforce_virtual_cost - jax.lax.stop_gradient(
        reinforce_virtual_cost)

    loss = (
        loss + reinforce_weight * reinforce_virtual_cost_zeroed +
        baseline_weight * baseline_penalty)
    metrics['reinforce_virtual_cost'] = reinforce_virtual_cost
    metrics['baseline_penalty'] = baseline_penalty
    metrics['baseline'] = baseline
    metrics['total_log_prob'] = total_log_prob

  metrics = jax.tree_map(lambda x: x.astype(jnp.float32), metrics)
  return loss, metrics
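
The marginalization spelled out in the comments above (log p(repair) = logsumexp[ log p(node) + log p(repair | node) ]) can be sanity-checked with two nodes:

import jax
import jax.numpy as jnp

log_p_node = jnp.log(jnp.array([0.5, 0.5]))
log_p_repair_given_node = jnp.log(jnp.array([0.2, 0.4]))
log_p_repair = jax.scipy.special.logsumexp(log_p_node + log_p_repair_given_node)
# exp(log_p_repair) == 0.5 * 0.2 + 0.5 * 0.4 == 0.3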
Code example #10
def unravel_index(indices, shape):
  indices = _remove_jaxarray(indices)
  shape = _remove_jaxarray(shape)
  return jnp.unravel_index(indices, shape)
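
A minimal sketch of the wrapper in use, with a stand-in _remove_jaxarray (hypothetical here) that unwraps a wrapped-array container:

import jax.numpy as jnp

def _remove_jaxarray(x):
    # stand-in: unwrap a wrapped array if present, pass plain values through
    return getattr(x, 'value', x)

row, col = unravel_index(jnp.array([1, 5]), (2, 3))
# row == [0, 1], col == [1, 2]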
Code example #11
import jax.numpy as jnp
from mpi4py import MPI  # assumed setup: the excerpt uses mpi_rank / mpi_size below

import mpi4jax  # noqa: E402

#
# MPI setup
#

mpi_comm = MPI.COMM_WORLD
mpi_rank = mpi_comm.Get_rank()
mpi_size = mpi_comm.Get_size()

supported_nproc = (1, 2, 4, 6, 8, 16)
if mpi_size not in supported_nproc:
    raise RuntimeError(f"Got invalid number of MPI processes: {mpi_size}. "
                       f"Please choose one of these: {supported_nproc}.")

nproc_y = min(mpi_size, 2)
nproc_x = mpi_size // nproc_y

proc_idx = jnp.unravel_index(mpi_rank, (nproc_y, nproc_x))

#
# Grid setup
#

# we use 1 cell overlap on each side of the domain
nx_global = 360 + 2
ny_global = 180 + 2

# grid spacing in metres
dx = 5e3
dy = 5e3

# make sure processes divide the domain evenly
assert (nx_global - 2) % nproc_x == 0
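
A quick check of the rank-to-grid mapping (assuming mpi_size == 8, so the process grid is (2, 4)):

print(jnp.unravel_index(5, (2, 4)))
# (1, 1): rank 5 sits at process row 1, column 1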
Code example #12
    def body(state: State):
        # upon the start of each iteration the state is consistent.
        # we use the consistent state to calculate the reassignment metrics.
        # we then reassign and update the state so that it is consistent again.
        # K, N
        # K
        log_f_k = log_factor_k(state.cluster_id, state.log_maha_k, state.num_k,
                               state.logdetC_k)

        def single_log_h(log_f_k, log_maha_k, num_k, logdetC_k):
            log_d = log_maha_k + log_f_k
            log_VS_k = log_VS + jnp.log(num_k) - jnp.log(num_S)
            return log_ellipsoid_volume(logdetC_k, num_k,
                                        log_f_k) + log_d - log_VS_k

        # K, N
        log_h_k = vmap(single_log_h)(log_f_k, state.log_maha_k, state.num_k,
                                     state.logdetC_k)
        h_k = jnp.exp(log_h_k)
        # K, K, N
        delta_F = h_k[:, None, :] - h_k
        # Can reassign if mask says we are working on that node and there would be at least D+1 points in that cluster
        # after taking from it. And, if delta_F < 0.
        able_to_reassign = mask & (state.num_k[state.cluster_id] > D + 1)
        delta_F_masked = jnp.where(able_to_reassign, delta_F, jnp.inf)

        # (k_to, k_from, n_reassign) = jnp.where(delta_F == min_delta_F)
        (k_to, k_from,
         n_reassign) = jnp.unravel_index(jnp.argmin(delta_F_masked.flatten()),
                                         delta_F.shape)
        # dynamic update index arrays of sufficient length for all
        dyn_k_to_idx = jnp.concatenate([k_to[None], jnp.asarray([0, 0])])
        dyn_k_from_idx = jnp.concatenate([k_from[None], jnp.asarray([0, 0])])

        ###
        # update the state

        ###
        # cluster id
        cluster_id = dynamic_update_slice(state.cluster_id, dyn_k_to_idx[0:1],
                                          n_reassign[None])

        ###
        # num_k
        # the donor cluster loses the point; the receiving cluster gains it
        num_from = state.num_k[k_from] - 1
        num_to = state.num_k[k_to] + 1
        num_k = dynamic_update_slice(state.num_k, num_from[None],
                                     dyn_k_from_idx[0:1])
        num_k = dynamic_update_slice(num_k, num_to[None], dyn_k_to_idx[0:1])

        ###
        # ellipsoid parameters
        x_n = points[n_reassign, :]
        mu_from = state.mu_k[k_from, :] + (state.mu_k[k_from, :] -
                                           x_n) / (state.num_k[k_from] - 1)
        C_from, logdetC_from = rank_one_update_matrix_inv(
            state.C_k[k_from, :, :],
            state.logdetC_k[k_from],
            x_n - mu_from,
            x_n - state.mu_k[k_from, :],
            add=False)
        # print(C_from, logdetC_from)
        mu_to = state.mu_k[
            k_to, :] + (x_n - state.mu_k[k_to, :]) / (state.num_k[k_to] + 1)
        C_to, logdetC_to = rank_one_update_matrix_inv(state.C_k[k_to, :, :],
                                                      state.logdetC_k[k_to],
                                                      x_n - mu_to,
                                                      x_n -
                                                      state.mu_k[k_to, :],
                                                      add=True)
        print('from', state.logdetC_k[k_from])
        # print(C_to, logdetC_to)
        mu_k = dynamic_update_slice(state.mu_k, mu_from[None, :],
                                    dyn_k_from_idx[0:2])
        mu_k = dynamic_update_slice(mu_k, mu_to[None, :], dyn_k_to_idx[0:2])
        C_k = dynamic_update_slice(state.C_k, C_from[None, :, :],
                                   dyn_k_from_idx)
        C_k = dynamic_update_slice(C_k, C_to[None, :, :], dyn_k_to_idx)
        logdetC_k = dynamic_update_slice(state.logdetC_k, logdetC_from[None],
                                         dyn_k_from_idx[0:1])
        logdetC_k = dynamic_update_slice(logdetC_k, logdetC_to[None],
                                         dyn_k_to_idx[0:1])

        ###
        # maha

        precision_from = C_from * num_from
        precision_to = C_to * num_to
        log_maha_from = jnp.log(
            vmap(lambda point: (point - mu_from) @ precision_from @ (
                point - mu_from))(points))
        log_maha_to = jnp.log(
            vmap(lambda point:
                 (point - mu_to) @ precision_to @ (point - mu_to))(points))
        log_maha_k = dynamic_update_slice(state.log_maha_k,
                                          log_maha_from[None, :],
                                          dyn_k_from_idx[0:2])
        log_maha_k = dynamic_update_slice(log_maha_k, log_maha_to[None, :],
                                          dyn_k_to_idx[0:2])

        # estimate volumes of current clustering
        log_f_k = log_factor_k(cluster_id, log_maha_k, num_k, logdetC_k)
        log_VE_k = vmap(log_ellipsoid_volume)(logdetC_k, num_k, log_f_k)
        log_VS_k = jnp.log(num_k) - jnp.log(num_S)
        log_V_sum = logsumexp(log_VE_k)
        new_loss = log_V_sum - log_VS
        loss_decreased = new_loss < state.min_loss
        delay = jnp.where(loss_decreased, 0, state.delay + 1)
        min_loss = jnp.where(loss_decreased, new_loss, state.min_loss)
        print(jnp.min(delta_F_masked), log_V_sum, logdetC_k)
        done = jnp.all(cluster_id == state.cluster_id) \
               | (delay >= 10) \
               | jnp.any(num_k < D + 1) \
               | jnp.isnan(log_V_sum) \
               | (jnp.min(delta_F_masked) >= 0.)
        # ['i', 'done', 'cluster_id', 'C_k', 'logdetC_k',
        # 'mu_k', 'log_maha_k', 'num_k',
        # 'log_VE_k', 'log_VS_k',
        # 'min_loss', 'delay']
        state = state._replace(i=state.i + 1,
                               done=done,
                               cluster_id=cluster_id,
                               C_k=C_k,
                               logdetC_k=logdetC_k,
                               mu_k=mu_k,
                               log_maha_k=log_maha_k,
                               num_k=num_k,
                               log_VE_k=log_VE_k,
                               log_VS_k=log_VS_k,
                               min_loss=min_loss,
                               delay=delay)
        return state
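
The selection step in isolation: argmin over the flattened tensor, then unravel back into (k_to, k_from, n):

import jax.numpy as jnp

delta_F = jnp.array([[[0., -1.], [2., 3.]],
                     [[4., -5.], [6., 7.]]])
k_to, k_from, n = jnp.unravel_index(jnp.argmin(delta_F), delta_F.shape)
# (k_to, k_from, n) == (1, 0, 1): the most negative entry, i.e. the single
# reassignment that decreases the loss the most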
Code example #13
import jax
import jax.numpy as jnp


def propose_spin_flip(key, s, info):
    idx = jax.random.randint(key, (1, ), 0, s.size)[0]
    idx = jnp.unravel_index(idx, s.shape)
    update = (s[idx] + 1) % 2
    return s.at[idx].set(update)  # jax.ops.index_update has been removed from JAX
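
And a one-line check that exactly one spin changes per proposal:

key = jax.random.PRNGKey(42)
s = jnp.zeros((3, 3), dtype=jnp.int32)
assert int(jnp.sum(propose_spin_flip(key, s, info=None) != s)) == 1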