Example No. 1
def summary_plot(losses, hmc_samples, hmcecs_samples, hmc_runtime, hmcecs_runtime):
    fig, ax = plt.subplots(2, 2)
    ax[0, 0].plot(losses, 'r')
    ax[0, 0].set_title('SVI losses')
    ax[0, 0].set_ylabel('ELBO')

    if hmc_runtime > hmcecs_runtime:
        ax[0, 1].bar([0], hmc_runtime, label='hmc', color='b')
        ax[0, 1].bar([0], hmcecs_runtime, label='hmcecs', color='r')
    else:
        ax[0, 1].bar([0], hmcecs_runtime, label='hmcecs', color='r')
        ax[0, 1].bar([0], hmc_runtime, label='hmc', color='b')
    ax[0, 1].set_title('Runtime')
    ax[0, 1].set_ylabel('Seconds')
    ax[0, 1].legend()
    ax[0, 1].set_xticks([])

    ax[1, 0].plot(jnp.sort(hmc_samples['theta'].mean(0)), 'or')
    ax[1, 0].plot(jnp.sort(hmcecs_samples['theta'].mean(0)), 'b')
    ax[1, 0].set_title(r'$\mathrm{\mathbb{E}}[\theta]$')

    ax[1, 1].plot(jnp.sort(hmc_samples['theta'].var(0)), 'or')
    ax[1, 1].plot(jnp.sort(hmcecs_samples['theta'].var(0)), 'b')
    ax[1, 1].set_title(r'Var$[\theta]$')

    for a in ax[1, :]:
        a.set_xticks([])

    fig.tight_layout()
    fig.savefig('hmcecs_plot.pdf', bbox_inches='tight')
Example No. 3
def test_DenseSymm_infeatures(symmetries, use_bias, mode):
    rng = nk.jax.PRNGSeq(0)

    g, hi, perms = _setup_symm(symmetries, N=8)

    if mode == "matrix":
        ma = nk.nn.DenseSymm(
            symmetries=perms,
            mode=mode,
            features=8,
            use_bias=use_bias,
            bias_init=uniform(),
        )
    else:
        ma = nk.nn.DenseSymm(
            symmetries=perms,
            shape=tuple(g.extent),
            mode=mode,
            features=8,
            use_bias=use_bias,
            bias_init=uniform(),
        )

    pars = ma.init(rng.next(), hi.random_state(rng.next(), 2).reshape(1, 2, -1))

    v = hi.random_state(rng.next(), 6).reshape(3, 2, -1)
    vals = [ma.apply(pars, v[..., p]) for p in np.asarray(perms)]
    for val in vals:
        assert jnp.allclose(jnp.sort(val, -1), jnp.sort(vals[0], -1))
Example No. 4
def get_logit_snip_masks(params, nn_density_level, predict, x_batch,
                         batch_input_shape, global_prune_bool=True):

    def norm_square_logits(params, f, x):
        return np.sum(f(params, x) ** 2)

    init_grads = grad(norm_square_logits)(params, predict,
                                          x_batch.reshape(batch_input_shape))

    thres_list = [None] * len(params)

    if global_prune_bool:  # global pruning
        cs = [abs(init_grads[idx][0] * params[idx][0]).flatten()
              for idx in range(len(params)) if len(params[idx]) == 2]
        pooled_cs = np.hstack(cs)
        idx = int((1 - nn_density_level) * len(pooled_cs))
        # threshold: entries below the threshold will be removed
        thres = np.sort(pooled_cs)[idx]
        thres_list = [thres] * len(params)
    else:  # layerwise pruning
        for layer_index in range(len(params)):
            if len(params[layer_index]) == 2:
                cs = abs(init_grads[layer_index][0] *
                         params[layer_index][0]).flatten()
                idx = int((1 - nn_density_level) * len(cs))
                # threshold: entries below the threshold will be removed
                thres = np.sort(cs)[idx]
                thres_list[layer_index] = thres

    masks = []
    for layer_index in range(len(params)):
        if len(params[layer_index]) < 2:
            # In this case, the layer has no weight and bias parameters.
            masks.append([])
        elif len(params[layer_index]) == 2:
            # In this case, the layer holds a (weights, biases) tuple.
            weights = params[layer_index][0]
            weights_grad = init_grads[layer_index][0]
            layer_cs = np.abs(weights * weights_grad)
            # 0 for weights whose saliency falls below the threshold, 1 otherwise
            this_mask = np.float32(layer_cs >= thres_list[layer_index])
            masks.append(this_mask)
        else:
            raise NotImplementedError

    return masks
Example No. 5
    def test_neighbor_list_build_time_dependent(self, dtype, dim):
        key = random.PRNGKey(1)

        if dim == 2:
            box_fn = lambda t: np.array([[9.0, t], [0.0, 3.75]], f32)
        elif dim == 3:
            box_fn = lambda t: np.array([[9.0, 0.0, t], [0.0, 4.0, 0.0],
                                         [0.0, 0.0, 7.25]])
        min_length = np.min(np.diag(box_fn(0.)))
        cutoff = f32(1.23)
        # TODO(schsam): Get cell-list working with anisotropic cell sizes.
        cell_size = cutoff / min_length

        displacement, _ = space.periodic_general(box_fn)
        metric = space.metric(displacement)

        R = random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
        N = R.shape[0]
        neighbor_list_fn = partition.neighbor_list(metric,
                                                   1.,
                                                   cutoff,
                                                   0.0,
                                                   1.1,
                                                   cell_size=cell_size,
                                                   t=np.array(0.))

        idx = neighbor_list_fn(R, t=np.array(0.25)).idx
        R_neigh = R[idx]
        mask = idx < N

        metric = partial(metric, t=f32(0.25))
        d = vmap(vmap(metric, (None, 0)))
        dR = d(R, R_neigh)

        d_exact = space.map_product(metric)
        dR_exact = d_exact(R, R)

        dR = np.where(dR < cutoff, dR, 0) * mask
        dR_exact = np.where(dR_exact < cutoff, dR_exact, 0)

        dR = np.sort(dR, axis=1)
        dR_exact = np.sort(dR_exact, axis=1)

        for i in range(dR.shape[0]):
            dR_row = dR[i]
            dR_row = dR_row[dR_row > 0.]

            dR_exact_row = dR_exact[i]
            dR_exact_row = dR_exact_row[dR_exact_row > 0.]

            self.assertAllClose(dR_row, dR_exact_row)
Example No. 6
def test_forced_identifiability_prior():
    from jax import random
    prior = PriorChain().push(ForcedIdentifiabilityPrior('x', 10, 0., 10.))
    for i in range(10):
        out = prior(random.uniform(random.PRNGKey(i), shape=(prior.U_ndims, )))
        assert jnp.all(jnp.sort(out['x'], axis=0) == out['x'])
        assert jnp.all((out['x'] >= 0.) & (out['x'] <= 10.))
    prior = PriorChain().push(
        ForcedIdentifiabilityPrior('x', 10, jnp.array([0., 0.]), 10.))
    for i in range(10):
        out = prior(random.uniform(random.PRNGKey(i), shape=(prior.U_ndims, )))
        assert out['x'].shape == (10, 2)
        assert jnp.all(jnp.sort(out['x'], axis=0) == out['x'])
        assert jnp.all((out['x'] >= 0.) & (out['x'] <= 10.))
Example No. 7
def sample_pdf(key, bins, weights, origins, directions, z_vals,
               num_coarse_samples, use_stratified_sampling):
    """Hierarchical sampling.

  Args:
    key: jnp.ndarray(float32), [2,], random number generator.
    bins: jnp.ndarray(float32), [batch_size, n_bins + 1].
    weights: jnp.ndarray(float32), [batch_size, n_bins].
    origins: ray origins.
    directions: ray directions.
    z_vals: jnp.ndarray(float32), [batch_size, n_coarse_samples].
    num_coarse_samples: int, the number of samples.
    use_stratified_sampling: bool, whether to use stratified sampling.

  Returns:
    z_vals: jnp.ndarray(float32),
      [batch_size, n_coarse_samples + num_fine_samples].
    points: jnp.ndarray(float32),
      [batch_size, n_coarse_samples + num_fine_samples, 3].
  """
    z_samples = piecewise_constant_pdf(key, bins, weights, num_coarse_samples,
                                       use_stratified_sampling)
    # Compute united z_vals and sample points
    z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
    return z_vals, (origins[..., None, :] +
                    z_vals[..., None] * directions[..., None, :])
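
A quick standalone check of the merge step (a sketch, not from the source project): merging per-ray coarse depths with the new samples via a single sorted concatenation keeps each ray's depths ordered.

import jax.numpy as jnp

z_vals = jnp.array([[0.1, 0.5, 0.9]])    # [batch_size=1, n_coarse=3]
z_samples = jnp.array([[0.3, 0.7]])      # [batch_size=1, n_fine=2]
merged = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
# merged == [[0.1, 0.3, 0.5, 0.7, 0.9]]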
Example No. 8
File: data.py Project: jackd/grax
def subgraph(single: SemiSupervisedSingle,
             indices: jnp.ndarray) -> SemiSupervisedSingle:
    indices = jnp.asarray(indices, jnp.int32)
    assert indices.ndim == 1, indices.shape
    index_dtype = indices.dtype
    assert jnp.issubdtype(index_dtype, jnp.integer)
    indices = jnp.sort(indices)
    adj = ops.gather(ops.gather(single.graph, indices, axis=0),
                     indices,
                     axis=1)

    node_features = single.node_features
    if isinstance(node_features, JAXSparse):
        node_features = ops.gather(node_features, indices, axis=0)
    else:
        node_features = node_features[indices]

    remap_indices = (jnp.zeros(
        (single.graph.shape[0], ), index_dtype).at[indices].set(
            jnp.arange(indices.size, dtype=index_dtype)))

    def valid_ids(ids):
        return None if ids is None else remap_indices[ids]

    return SemiSupervisedSingle(
        node_features,
        adj,
        single.labels[indices],
        train_ids=valid_ids(single.train_ids),
        validation_ids=valid_ids(single.validation_ids),
        test_ids=valid_ids(single.test_ids),
    )
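
The remap_indices construction above translates original node ids into positions within the compacted subgraph. A minimal standalone illustration (toy sizes, not from the project):

import jax.numpy as jnp

indices = jnp.array([2, 5, 7], jnp.int32)  # nodes kept from a 10-node graph
remap = (jnp.zeros((10,), jnp.int32)
         .at[indices].set(jnp.arange(indices.size, dtype=jnp.int32)))
# remap[2] == 0, remap[5] == 1, remap[7] == 2, so train/validation/test ids
# expressed in the old numbering can be looked up with remap[ids].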
Example No. 9
def sample_per_class(
    rng: PRNGKey,
    labels: jnp.ndarray,
    examples_per_class: int,
    valid_indices: tp.Optional[jnp.ndarray] = None,
    num_classes: tp.Optional[int] = None,
) -> jnp.ndarray:
    assert labels.ndim == 1, labels.shape
    if valid_indices is None:
        valid_indices = jnp.arange(labels.size, dtype=jnp.int64)
        valid_labels = labels
    else:
        assert valid_indices.ndim == 1, valid_indices.shape
        valid_labels = labels[valid_indices]

    if num_classes is None:
        num_classes = jnp.max(labels) + 1
    all_class_indices = []
    for class_index, key in enumerate(jax.random.split(rng, num_classes)):
        (class_indices, ) = jnp.where(valid_labels == class_index)
        assert class_indices.size >= examples_per_class
        class_indices = jax.random.choice(key,
                                          class_indices,
                                          (examples_per_class, ),
                                          replace=False)
        class_indices = valid_indices[class_indices]
        all_class_indices.append(class_indices)
    class_indices = jnp.concatenate(all_class_indices)
    class_indices = jnp.sort(class_indices)
    return class_indices
Example No. 10
def sorted_topk_indicators(x, k, sort_by=SortBy.POSITION):
    """Finds the (sorted) positions of the topk values in x.

  Args:
    x: The input scores of dimension (d,).
    k: The number of top elements to find.
    sort_by: Strategy to order the extracted values. This is useful when the
      function is applied to many perturbed inputs and the results are averaged.
      As top-k's output does not have a fixed order, the indicator vectors could
      otherwise be swapped, and the average of the indicators would not be spiky.

  Returns:
    Indicator vectors in a tensor of shape (k, d)
  """
    n = x.shape[-1]
    values, ranks = jax.lax.top_k(x, k)

    if sort_by == SortBy.NONE:
        sorted_ranks = ranks
    elif sort_by == SortBy.VALUES:
        sorted_ranks = ranks[jnp.argsort(values)]
    elif sort_by == SortBy.POSITION:
        sorted_ranks = jnp.sort(ranks)

    one_hot_fn = jax.vmap(functools.partial(jax.nn.one_hot, num_classes=n))
    indicators = one_hot_fn(sorted_ranks)
    return indicators
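
A quick usage sketch, assuming the function above and its SortBy enum are in scope:

import jax.numpy as jnp

x = jnp.array([0.1, 2.0, -1.0, 0.7])
ind = sorted_topk_indicators(x, k=2)  # default sort_by=SortBy.POSITION
# the top-2 values sit at positions 1 and 3, so ind is
# [[0., 1., 0., 0.],
#  [0., 0., 0., 1.]]  -- shape (k, d), rows ordered by position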
Example No. 11
  def rs(self, key, budget):
    """Random search sampling method with uniform distribution."""
    if operator.xor(self.x_precollect is None, self.y_precollect is None):
      raise ValueError(
          'Both x_precollect and y_precollect need to be provided.')

    if self.x_precollect is not None:
      ind_pre = point_in_search_space(self.x_precollect, self.search_space)
      x_precollect_in_search_space = self.x_precollect[ind_pre, :]
      y_precollect_in_search_space = self.y_precollect[ind_pre, :]

      if y_precollect_in_search_space.shape[0] < budget:
        raise ValueError(
            'budget is larger than the precollected data in the search space.')
      ind_chosen = jax.random.choice(
          key,
          y_precollect_in_search_space.shape[0],
          shape=(budget,),
          replace=False)
      ind_chosen = jnp.sort(ind_chosen)
      x = x_precollect_in_search_space[ind_chosen, :]
      y = y_precollect_in_search_space[ind_chosen, :]
      additional_info_dict = {}
    else:
      x = jax.random.uniform(
          key,
          shape=(budget, self.search_space.shape[0]),
          minval=self.search_space[:, 0],
          maxval=self.search_space[:, 1])
      y, additional_info_dict = self.objective_fn(x)
      ind_chosen = None
    return x, y, additional_info_dict, ind_chosen
Example No. 12
def nucleaus_filter(logits, top_p=0.9, top_k=None):
    sorted_logits = jnp.sort(logits)[:, ::-1]  # sort descending
    sorted_indices = jnp.argsort(logits)[:, ::-1]
    cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits), axis=-1)

    if top_k is not None:
        # Keep only top_k tokens
        indices_range = jnp.arange(len(sorted_indices[0]))
        indices_range = jnp.stack([indices_range] * len(sorted_indices),
                                  axis=0)

        sorted_indices_to_remove = jnp.where(indices_range > top_k,
                                             sorted_indices, 0)

        _, indices_to_remove = jax.lax.sort_key_val(sorted_indices,
                                                    sorted_indices_to_remove)

        logit_mask = 1e10 * indices_to_remove

        logits -= logit_mask

    # Remove tokens with cumulative probability above a threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove = jnp.concatenate((jnp.zeros_like(
        sorted_indices_to_remove[:, :1]), sorted_indices_to_remove),
                                               axis=-1)[:, :-1]

    _, indices_to_remove = jax.lax.sort_key_val(sorted_indices,
                                                sorted_indices_to_remove)

    logit_mask = 1e10 * indices_to_remove

    logits -= logit_mask

    return logits
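
A small usage sketch, assuming the function above is in scope (the example distribution is made up):

import jax
import jax.numpy as jnp

logits = jnp.log(jnp.array([[0.5, 0.3, 0.1, 0.05, 0.05]]))
filtered = nucleaus_filter(logits, top_p=0.9)
probs = jax.nn.softmax(filtered, axis=-1)
# tokens outside the 0.9 nucleus have ~1e10 subtracted from their logits,
# so their probability mass collapses to ~0 after the softmax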
Example No. 13
def l1_unit_projection(x):
  """Euclidean projection to L1 unit ball i.e. argmin_{|v|_1<= 1} |x-v|_2.

  Args:
    x: An array of size dim x num.

  Returns:
    An array of size dim x num, the projection to the unit L1 ball.
  """
  # https://dl.acm.org/citation.cfm?id=1390191
  xshape = x.shape
  if len(x.shape) == 1:
    x = x.reshape(-1, 1)
  eshape = x.shape
  v = jnp.abs(x.reshape((-1, eshape[-1])))
  u = jnp.sort(v, axis=0)
  u = u[::-1, :]  # descending
  arange = (1 + jnp.arange(eshape[0])).reshape((-1, 1))
  usum = (jnp.cumsum(u, axis=0) - 1) / arange
  rho = jnp.max(((u - usum) > 0) * arange - 1, axis=0, keepdims=True)
  thx = jnp.take_along_axis(usum, rho, axis=0)
  w = (v - thx).clip(a_min=0)
  w = jnp.where(jnp.linalg.norm(v, ord=1, axis=0, keepdims=True) > 1, w, v)
  x = w.reshape(eshape) * jnp.sign(x)
  return x.reshape(xshape)
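
A quick sanity check of the projection (a sketch, assuming the function above is in scope):

import jax.numpy as jnp

x = jnp.array([3.0, -1.0, 0.5])
v = l1_unit_projection(x)        # -> [1., 0., 0.], with jnp.abs(v).sum() == 1
y = jnp.array([0.2, -0.3, 0.1])  # already inside the unit L1 ball
assert jnp.allclose(l1_unit_projection(y), y)  # returned unchanged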
Example No. 14
def _holm_bonferroni(p):
    """Performs Holm-Bonferroni correction for pvalues to account for multiple
    comparisons.

    Parameters
    ----------
    p: numpy.array
        array of pvalues

    Returns
    -------
    numpy.array
        corrected pvalues
    """
    K = len(p)
    sort_index = -np.ones(K, dtype=np.int64)
    sorted_p = np.sort(p)
    sorted_p_adj = sorted_p * (K - np.arange(K))
    for j in range(K):
        idx = (p == sorted_p[j]) & (sort_index < 0)
        num_ties = len(sort_index[idx])
        sort_index[idx] = np.arange(j, (j + num_ties), dtype=np.int64)

    sorted_holm_p = [min([max(sorted_p_adj[:k]), 1]) for k in range(1, K + 1)]
    holm_p = [sorted_holm_p[sort_index[k]] for k in range(K)]
    return holm_p
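
A worked example of the correction (values made up):

import numpy as np

p = np.array([0.01, 0.04, 0.03])
print(_holm_bonferroni(p))  # [0.03, 0.06, 0.06]
# the smallest p-value is scaled by K=3, the next by 2, the largest by 1,
# and a running maximum keeps the corrected values monotone (capped at 1)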
Example No. 15
def near_square_wave(
    n_train: int = 80,
    input_noise: float = 0.15,
    output_noise: float = 0.3,
    n_test: int = 400,
    random_state: int = 123,
):
    """Generates a near-square wave"""
    # function
    f = lambda x: np.sin(1.0 * np.pi / 1.6 * np.cos(5 + 0.5 * x))

    # create clean inputs
    x_mu = np.linspace(-10, 10, n_train)

    # clean outputs
    y = f(x_mu)

    # generate noise
    x_rng = check_random_state(random_state)
    y_rng = check_random_state(random_state + 1)

    # noisy inputs
    x = x_mu + input_noise * x_rng.randn(x_mu.shape[0])

    # noisy outputs
    y = f(x_mu) + output_noise * y_rng.randn(x_mu.shape[0])

    # test points
    x_test = np.linspace(-12, 12, n_test) + x_rng.randn(n_test)
    y_test = f(np.linspace(-12, 12, n_test))
    x_test = np.sort(x_test)

    return x[:, None], y[:, None], x_test[:, None], y_test
Example No. 16
def kth_percent_distance(dists: np.ndarray, k: float = 0.3) -> np.ndarray:
    """kth percent distance in a gram matrix

    This calculates the kth percent in an (NxN) matrix.
    It sorts all distance values and then retrieves the
    kth value as a percentage of the number of samples.

    Parameters
    ----------
    dists : jax.numpy.ndarray
        the distance matrix already calculated (n_samples, n_samples)

    k : float
        the fraction of samples defining the kth distance (default=0.3)

    Returns
    -------
    kth_dist : jax.numpy.ndarray
        the neighbours up to the kth distance
    """
    # kth distance calculation (default k=0.3, i.e. the 30th percentile)
    kth_sample = int(k * dists.shape[0])

    # take the Kth neighbours of that distance
    k_dist = np.sort(dists)[:, kth_sample]

    return k_dist
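
A usage sketch with plain numpy, assuming the function above is in scope (toy data):

import numpy as np

pts = np.linspace(0.0, 1.0, 10)
dists = np.abs(pts[:, None] - pts[None, :])  # (10, 10) distance matrix
k_dist = kth_percent_distance(dists, k=0.3)
# each entry is the row's distance at sorted index int(0.3 * 10) == 3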
Example No. 17
def render_rays_fine(rays,
                     z_vals,
                     weights,
                     num_importance,
                     perturbation=True,
                     rng=None):
    """Render rays for the fine model.
    Args:
        rays: (2, num_rays, 3) origin and direction generated rays
        z_vals: (num_rays, num_samples) depths of the sampled positions
        weights: (num_rays, num_samples) weights assigned to each sampled color for the coarse model
        num_importance: number of samples used in the fine model
        perturbation: whether to apply jitter on each ray or not
        rng: random key
    Returns:
        pts: (num_rays, num_samples + num_importance, 3) points in space to evaluate model at
        z_vals: (num_rays, num_samples + num_importance) depths of the sampled positions
        z_samples: (num_rays) standard deviation of distances along ray for each sample
    """
    rays_o, rays_d = rays

    z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
    z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], num_importance,
                           perturbation, rng)
    z_samples = lax.stop_gradient(z_samples)

    # obtain all points to evaluate color density at
    z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
    z_vals = z_vals.astype(rays_d.dtype)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    return pts, z_vals, jnp.std(z_samples, axis=-1)
Example No. 18
def sample_permutation(key, coupling):
  """Samples a permutation matrix from a doubly stochastic coupling matrix.

  CAREFUL: the couplings that come out of the sinkhorn solver
  are not doubly stochastic but 1/dim * doubly_stochastic.

  See **Convex Relaxations for Permutation Problems** paper for rough
  explanation of the algorithm.
  Best to use by drawing multiple samples and picking the permutation with
  lowest cost as sometimes permutations seem to be drawn with high cost.
  the sample_best_permutation method does this.

  Args:
    key: jnp.ndarray that functions as a PRNG key.
    coupling: jnp.ndarray of shape [N, N] which must have marginals such that
      coupling.sum(0) == 1. and coupling.sum(1) == 1. Note that in sinkhorn we
      usually output couplings with marginals that sum to 1/N.

  Returns:
    permutation matrix: jnp.ndarray of shape [N, N] of floating dtype.
  """
  dim = coupling.shape[0]

  # random monotonic vector v without duplicates.
  v = jax.random.choice(key, 10 * dim, shape=(dim,), replace=False)
  v = jnp.sort(v) * 10.

  w = jnp.dot(coupling, v)
  # Sorting w will give the row indices of the permutation matrix.
  row_ind = jnp.argsort(w)
  col_ind = jnp.arange(0, dim)

  # Compute permutation matrix from row and column indices
  perm = idx2permutation(row_ind, col_ind)
  return perm
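
idx2permutation is defined elsewhere in the project; a minimal version consistent with how it is called here might be (a hypothetical sketch):

import jax.numpy as jnp

def idx2permutation(row_ind, col_ind):
  # hypothetical reconstruction: scatter ones at (row_ind[i], col_ind[i])
  dim = row_ind.shape[0]
  return jnp.zeros((dim, dim), dtype=jnp.float32).at[row_ind, col_ind].set(1.0)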
Example No. 19
def test_prior_mll():
    """
    Test that the MLL evaluation works with priors attached to the parameter values.
    """
    key = jr.PRNGKey(123)
    x = jnp.sort(jr.uniform(key, minval=-5.0, maxval=5.0, shape=(100, 1)),
                 axis=0)
    f = lambda x: jnp.sin(jnp.pi * x) / (jnp.pi * x)
    y = f(x) + jr.normal(key, shape=x.shape) * 0.1
    posterior = Prior(kernel=RBF()) * Gaussian()

    params = initialise(posterior)
    config = get_defaults()
    constrainer, unconstrainer = build_all_transforms(params.keys(), config)
    params = unconstrainer(params)
    print(params)

    mll = marginal_ll(posterior, transform=constrainer)

    priors = {
        "lengthscale": tfd.Gamma(1.0, 1.0),
        "variance": tfd.Gamma(2.0, 2.0),
        "obs_noise": tfd.Gamma(2.0, 2.0),
    }
    mll_eval = mll(params, x, y)
    mll_eval_priors = mll(params, x, y, priors)

    assert pytest.approx(mll_eval) == jnp.array(-103.28180663)
    assert pytest.approx(mll_eval_priors) == jnp.array(-105.509218857)
Example No. 20
def sample_pdf(key, bins, weights, origins, directions, z_vals, num_samples,
               randomized):
    """Hierarchical sampling.

  Args:
    key: jnp.ndarray(float32), [2,], random number generator.
    bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
    weights: jnp.ndarray(float32), [batch_size, num_bins].
    origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
    directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
    z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
    num_samples: int, the number of samples.
    randomized: bool, use randomized samples.

  Returns:
    z_vals: jnp.ndarray(float32),
      [batch_size, num_coarse_samples + num_fine_samples].
    points: jnp.ndarray(float32),
      [batch_size, num_coarse_samples + num_fine_samples, 3].
  """
    z_samples = piecewise_constant_pdf(key, bins, weights, num_samples,
                                       randomized)
    # Compute united z_vals and sample points
    z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
    coords = cast_rays(z_vals, origins, directions)
    return z_vals, coords
Example No. 21
def sample_pdf(key, bins, weights, rays, z_vals, num_samples, randomized):
    """Hierarchical sampling.

  Args:
    key: jnp.ndarray(float32), [2,], random number generator.
    bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
    weights: jnp.ndarray(float32), [batch_size, num_bins].
    rays: jnp.ndarray(float32), [batch_size, 6].
    z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
    num_samples: int, the number of samples.
    randomized: bool, use randomized samples.

  Returns:
    z_vals: jnp.ndarray(float32),
      [batch_size, num_coarse_samples + num_fine_samples].
    points: jnp.ndarray(float32),
      [batch_size, num_coarse_samples + num_fine_samples, 3].
  """
    z_samples = piecewise_constant_pdf(key, bins, weights, num_samples,
                                       randomized)
    origins = rays[Ellipsis, 0:3]
    directions = rays[Ellipsis, 3:6]
    # Compute united z_vals and sample points
    z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
    return z_vals, (origins[Ellipsis, None, :] +
                    z_vals[Ellipsis, None] * directions[Ellipsis, None, :])
Example No. 22
def get_boundaries_intersections(z, d, trust_radius):
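    # Solve ||z + t * d|| = trust_radius for t: the quadratic
    # a * t**2 + b * t + c = 0, returning both roots in ascending order.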
    a = jnp.vdot(d, d)
    b = 2 * jnp.vdot(z, d)
    c = jnp.vdot(z, z) - trust_radius**2
    sqrt_discriminant = jnp.sqrt(b * b - 4 * a * c)
    ta = (-b - sqrt_discriminant) / (2 * a)
    tb = (-b + sqrt_discriminant) / (2 * a)
    return jnp.sort(jnp.stack([ta, tb]))
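
A worked example (a sketch, assuming the function above is in scope): from the origin along a unit direction with trust radius 2, the roots are -2 and 2.

import jax.numpy as jnp

z = jnp.zeros(3)
d = jnp.array([1.0, 0.0, 0.0])
print(get_boundaries_intersections(z, d, 2.0))  # [-2.  2.]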
Example No. 23
def get_top_k_weights(
    top_k_fraction: float,
    restarting_weights: Array,
    scaled_advantages: Array,
    axis_name: Optional[str] = None,
    use_stop_gradient: bool = True,
):
  """Get the weights for the top top_k_fraction of advantages.

  Args:
    top_k_fraction: The fraction of weights to use.
    restarting_weights: Restarting weights, shape E*, 0 means that this step is
      the start of a new episode and we ignore losses at this step because the
      agent cannot influence these.
    scaled_advantages: The advantages for each example (shape E*), scaled by
      temperature.
    axis_name: Optional axis name for `pmap`. If `None`, computations are
      performed locally on each device.
    use_stop_gradient: bool indicating whether or not to apply stop gradient.

  Returns:
    Weights for the top top_k_fraction of advantages
  """
  chex.assert_equal_shape([scaled_advantages, restarting_weights])
  chex.assert_type([scaled_advantages, restarting_weights], float)

  if not 0.0 < top_k_fraction <= 1.0:
    raise ValueError(
        f"`top_k_fraction` must be in (0, 1], got {top_k_fraction}")
  logging.info("[vmpo_e_step] top_k_fraction: %f", top_k_fraction)

  if top_k_fraction < 1.0:
    # Don't include the restarting samples in the determination of top-k.
    valid_scaled_advantages = scaled_advantages - (
        1.0 - restarting_weights) * _INFINITY
    # Determine the minimum top-k value across all devices.
    if axis_name:
      all_valid_scaled_advantages = jax.lax.all_gather(
          valid_scaled_advantages, axis_name=axis_name)
    else:
      all_valid_scaled_advantages = valid_scaled_advantages
    top_k = int(top_k_fraction * jnp.size(all_valid_scaled_advantages))
    if top_k == 0:
      raise ValueError(
          "top_k_fraction too low to get any valid scaled advantages.")
    # TODO(b/160450251): Use jnp.partition(all_valid_scaled_advantages, top_k)
    #   when this is implemented in jax.
    top_k_min = jnp.sort(jnp.reshape(all_valid_scaled_advantages, [-1]))[-top_k]
    # Fold the top-k into the restarting weights.
    top_k_weights = jnp.greater_equal(valid_scaled_advantages,
                                      top_k_min).astype(jnp.float32)
    top_k_weights = jax.lax.select(
        use_stop_gradient, jax.lax.stop_gradient(top_k_weights), top_k_weights)
    top_k_restarting_weights = restarting_weights * top_k_weights
  else:
    top_k_restarting_weights = restarting_weights

  return top_k_restarting_weights
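
The heart of the selection is the sort-based threshold. A standalone toy version of just that step (without the cross-device gather and restart masking):

import jax.numpy as jnp

scaled_advantages = jnp.array([0.1, 2.0, -0.5, 1.2, 0.3])
top_k = int(0.4 * scaled_advantages.size)  # keep the top 40% -> 2 samples
top_k_min = jnp.sort(jnp.reshape(scaled_advantages, [-1]))[-top_k]
weights = (scaled_advantages >= top_k_min).astype(jnp.float32)
# weights == [0., 1., 0., 1., 0.]: only the two largest advantages survive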
Example No. 24
def mask(x, mask_constant, mask_axis, key, p):
    if mask_constant is not None:
        mask_shape = [
            1 if i in mask_axis else s for i, s in enumerate(x.shape)
        ]
        mask_mat = jax.random.bernoulli(key, p=p, shape=mask_shape)
        x = np.where(mask_mat, mask_constant, x)
        x = np.sort(x, 1)
    return x
Example No. 25
 def test_topk_one_array(self, k):
   n = 20
   x = jax.random.uniform(self.rng, (n,))
   axis = 0
   xs = soft_sort.sort(x, axis=axis, topk=k, epsilon=1e-3)
   outsize = k if 0 < k < n else n
   self.assertEqual(xs.shape, (outsize,))
   self.assertTrue(jnp.alltrue(jnp.diff(xs, axis=axis) >= 0.0))
   self.assertAllClose(xs, jnp.sort(x, axis=axis)[-outsize:], atol=0.01)
Example No. 26
def gini(array):
    """Calculate the Gini coefficient of a numpy array."""
    # based on bottom eq: http://www.statsdirect.com/help/content/image/stat0206_wmf.gif
    # from: http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm
    array = np.sort(array)  # values must be sorted
    index = np.arange(1, array.shape[0] + 1)  # index per array element
    n = array.shape[0]  # number of array elements
    return np.sum((2 * index - n - 1) * array) / (n * np.sum(array))  # Gini coefficient
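
Two boundary cases confirm the formula (assuming the function above is in scope):

import numpy as np

print(gini(np.ones(100)))                  # perfectly equal -> 0.0
print(gini(np.array([0.0] * 99 + [1.0])))  # fully concentrated -> 0.99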
Example No. 27
def top_k_approx(scores, k=100):
    """Returns approximate topk highest scores for each row.

  The api is same as jax.lax.top_k, so this can be used as a drop in replacement
  as long as num dims of scores tensor is 2. For more dimensions, please use one
  or more vmap(s) to be able to use it.

  In essence, we perform a jnp.max operation, which can be thought of as a
  lossy top-1, on fixed-length windows of items. We can control the amount of
  approximation by changing the window length: the smaller it gets, the better
  the approximation, at the cost of performance.

  Once we have the max for all the windows, we apply regular slow but exact
  jax.lax.top_k over reduced set of items.

  Args:
    scores: [num_rows, num_cols] shaped tensor. Will return top K over last dim.
    k: How many top scores to return for each row.

  Returns:
    Topk scores, topk ids. Both shaped [num_rows, k]
  """
    num_queries = scores.shape[0]
    num_items = scores.shape[1]

    # Make this bigger to improve recall. Should be between [1, k].
    num_windows_multiplier = 5
    window_lengths = num_items // k // num_windows_multiplier + 1
    padded_num_items = k * num_windows_multiplier * window_lengths

    print(f"scores shape: {scores.shape}")
    print(f"padded_num_items: {padded_num_items}")
    print(f"num_items: {num_items}")
    scores = jnp.pad(scores, ((0, 0), (0, padded_num_items - num_items)),
                     mode="constant",
                     constant_values=jnp.NINF)
    scores = jnp.reshape(
        scores, (num_queries, k * num_windows_multiplier, window_lengths))
    approx_top_local_scores = jnp.max(scores, axis=2)

    sorted_approx_top_scores_across_local = jnp.flip(jnp.sort(
        approx_top_local_scores, axis=1),
                                                     axis=1)
    approx_top_ids_across_local = jnp.flip(jnp.argsort(approx_top_local_scores,
                                                       axis=1),
                                           axis=1)[:, :k]

    approx_top_local_ids = jnp.argmax(scores, axis=2)
    offsets = jnp.arange(0, padded_num_items, window_lengths)
    approx_top_ids_with_offsets = approx_top_local_ids + offsets
    approx_top_ids = slice_2d(approx_top_ids_with_offsets,
                              approx_top_ids_across_local)

    topk_scores = sorted_approx_top_scores_across_local[:, :k]
    topk_ids = approx_top_ids

    return topk_scores, topk_ids
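
slice_2d is not shown; given how it is used above, a plausible minimal implementation gathers one entry per row for each selected column index (a hypothetical sketch):

import jax.numpy as jnp

def slice_2d(values, idx):
    # for each row i, pick values[i, idx[i, j]] for every j
    return jnp.take_along_axis(values, idx, axis=1)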
Example No. 28
 def __init__(self, dim, bounds):
     """
     :param dim: dimension of the space
     :param bounds: a list of floats. E.g [b_1, ..., b_k] representing k cubes, centered at (0,...,0) and sides b_i*2
         we assume all bounds b_i are positive
     """
     self.dim = dim
     self.pos_sorted_bounds = np.sort(bounds)
     assert self.pos_sorted_bounds[
         0] > 0  # we assume all bounds (thus the smallest) are positive
Example No. 29
 def test_soft_quantile_normalization(self):
   rngs = jax.random.split(self.rng, 2)
   x = jax.random.uniform(rngs[0], shape=(100,))
   mu, sigma = 2.0, 1.2
   y = mu + sigma * jax.random.normal(self.rng, shape=(48,))
   mu_target, sigma_target = y.mean(), y.std()
   qn = soft_sort.quantile_normalization(x, jnp.sort(y), epsilon=1e-4)
   mu_transform, sigma_transform = qn.mean(), qn.std()
   self.assertAllClose([mu_transform, sigma_transform],
                       [mu_target, sigma_target], rtol=0.05)
Example No. 30
 def gen(key: jnp.ndarray):
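     # note: jax.random.shuffle is deprecated in newer JAX releases;
     # jax.random.permutation(key, remaining_ids) is the modern equivalent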
     feature_ids = jax.random.shuffle(key, remaining_ids)
     for batch_index in range(length):
         i = feature_ids[batch_size * batch_index:batch_size *
                         (batch_index + 1)]
         i = jnp.concatenate((ids, i))
         i = jnp.sort(i)
         f = features[i]
         p = prop[:, i]
         yield ((p, f, input_ids), labels)