Example #1
def lqr_continuous_time_infinite_horizon(A, B, Q, R, N):
    # Take the last dimension, in case we try to do some kind of broadcasting
    # thing in the future.
    x_dim = A.shape[-1]

    # pylint: disable=line-too-long
    # See https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator#Infinite-horizon,_continuous-time_LQR.
    A1 = A - B @ jp.linalg.solve(R, N.T)
    Q1 = Q - N @ jp.linalg.solve(R, N.T)

    # See https://en.wikipedia.org/wiki/Algebraic_Riccati_equation#Solution.
    H = jp.block([[A1, -B @ jp.linalg.solve(R, B.T)], [-Q1, -A1.T]])
    eigvals, eigvectors = jp.linalg.eig(H)

    # For large-ish systems (e.g. x_dim = 7), we sometimes find eigenvalues
    # with an imaginary component. That's an unfortunate consequence of the
    # numerical instability in the eigendecomposition, so the asserts below
    # are disabled:
    # assert (eigvals.imag == jp.zeros_like(eigvals, dtype=jp.float32)).all()
    # assert (eigvectors.imag == jp.zeros_like(eigvectors, dtype=jp.float32)).all()

    # Now it should be safe to take out only the real components.
    eigvals = eigvals.real
    eigvectors = eigvectors.real

    argsort = jp.argsort(eigvals)
    ix = argsort[:x_dim]
    U = eigvectors[:, ix]
    P = U[x_dim:, :] @ jp.linalg.inv(U[:x_dim, :])

    K = jp.linalg.solve(R, (B.T @ P + N.T))
    return K, P, eigvals[ix]
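
A minimal usage sketch (not part of the original source), assuming `jp` above is `jax.numpy`: the double integrator, stabilized by state feedback. Note that `jax.numpy.linalg.eig` has historically been CPU-only.

import jax.numpy as jp

A = jp.array([[0., 1.], [0., 0.]])   # double integrator dynamics
B = jp.array([[0.], [1.]])           # force input
Q = jp.eye(2)                        # state cost
R = jp.eye(1)                        # control cost
N = jp.zeros((2, 1))                 # no state-control cross term

K, P, eigvals = lqr_continuous_time_infinite_horizon(A, B, Q, R, N)
# The closed loop x' = (A - B @ K) x should be stable: the returned
# eigenvalues all have negative real part.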
Example #2
def compute_eigenvalue_decomposition(Ms,
                                     sort_by='magnitude',
                                     do_compute_lefts=True):
    """Compute the eigenvalues of the matrix M. No assumptions are made on M.

  Arguments: 
    M: 3D np.array nmatrices x dim x dim matrix
    do_compute_lefts: Compute the left eigenvectors? Requires a pseudo-inverse 
      call.

  Returns: 
    list of dictionaries with eigenvalues components: sorted 
      eigenvalues, sorted right eigenvectors, and sored left eigenvectors 
      (as column vectors).
  """
    if sort_by == 'magnitude':
        sort_fun = onp.abs
    elif sort_by == 'real':
        sort_fun = onp.real
    else:
        raise ValueError(f"Unknown sort_by: {sort_by}")

    decomps = []
    L = None
    for M in Ms:
        evals, R = onp.linalg.eig(M)
        indices = onp.flipud(onp.argsort(sort_fun(evals)))
        if do_compute_lefts:
            L = onp.linalg.pinv(R).T  # as columns
            L = L[:, indices]
        decomps.append({'evals': evals[indices], 'R': R[:, indices], 'L': L})

    return decomps
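
A short usage sketch (not from the original project), assuming `onp` above is plain NumPy:

import numpy as onp

Ms = onp.random.RandomState(0).randn(4, 3, 3)  # four random 3x3 matrices
decomps = compute_eigenvalue_decomposition(Ms, sort_by='magnitude')
# Each dict holds eigenvalues sorted by magnitude, plus the matching right
# ('R') and left ('L') eigenvectors as columns.
print(decomps[0]['evals'])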
Example #3
def test_modelSelection(family, prior, method):
    p = 5
    n = 700
    key = random.PRNGKey(0)
    X = random.normal(key, (n, p))
    key, subkey1, subkey2 = random.split(key, 3)
    mu = 1.7 * X[:, 1] - 1.6 * X[:, 2]
    truth = jnp.full((p, ), False)
    truth = truth.at[[1, 2]].set(True)
    if family == "logistic":
        y = (mu + random.normal(subkey1, (n, )) > 0).astype(jnp.int32)
    elif family == "poisson":
        y = random.poisson(subkey2, lam=jnp.exp(mu), shape=(n, ))
    fmt = f"{{:0{p}b}}"
    gammes = np.array([list(fmt.format(i)) for i in range(2**p)])
    gammes = jnp.array(gammes == "0")[:-1, :]

    _, modprobs = modelSelection(X,
                                 y,
                                 gammes,
                                 family=family,
                                 prior=prior,
                                 method=method)
    order = jnp.argsort(modprobs)[::-1]
    assert np.all(np.isfinite(modprobs))
    if family == "logistic":
        assert np.all(gammes[order[0], :] == truth), gammes[order[0], :]
Example #4
def eigh(H,
         precision=lax.Precision.HIGHEST,
         symmetrize=True,
         termination_size=128):
    """ Computes the eigendecomposition of the symmetric/Hermitian matrix H.

  Args:
    H: The `n x n` Hermitian input.
    precision: The matmul precision.
    symmetrize: If True, `0.5 * (H + H.conj().T)` rather than `H` is used.
    termination_size: Recursion ends once the blocks reach this linear size.
  Returns:
    vals: The `n` eigenvalues of `H`, sorted from lowest to highest.
    vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
      of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
      to numerical error.
  """
    nrows, ncols = H.shape
    if nrows != ncols:
        raise TypeError(f"Input H of shape {H.shape} must be square.")

    if ncols <= termination_size:
        return jnp.linalg.eigh(H)

    evals, evecs = _eigh_work(H, precision=precision)
    sort_idxs = jnp.argsort(evals)
    evals = evals[sort_idxs]
    evecs = evecs[:, sort_idxs]
    return evals, evecs
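
A quick sanity-check sketch: for a small matrix the function falls through to `jnp.linalg.eigh`, so the residual should vanish up to float error.

import jax.numpy as jnp

H = jnp.array([[2., 1.], [1., 2.]])
vals, vecs = eigh(H)              # small n, so this defers to jnp.linalg.eigh
print(vals)                       # [1. 3.]
print(jnp.linalg.norm(H @ vecs - vecs * vals))  # ~0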
Example #5
def CM_eigenvectors_EVsorted(Z, R, N=0, cutoff=10):
    ''' Computes the eigenvectors of the unsorted Coulomb matrix,
    sorted by their eigenvalues. A cutoff is possible at a given
    length or for certain eigenvalue magnitudes.


    Parameters
    ----------
    Z : 1 x n dimensional array
        contains nuclear charges
    R : 3 x n dimensional array
        contains nuclear positions
    N : float
        number of electrons in system
        here: meaningless, can remain empty
    cutoff : int
        maximum number of eigenvectors returned

    Return
    ------
    sorted_evec : matrix
        contains the eigenvectors of the CM as columns,
        sorted by their eigenvalues
    '''
    M = CM_full_unsorted_matrix(Z, R)  # local name M avoids shadowing the parameter N
    ev, evec = jnp.linalg.eigh(M)
    order = jnp.argsort(ev)[:min(ev.size, cutoff)]

    # eigenvectors are the columns of evec, so select along the second axis
    sorted_evec = evec[:, order]

    return sorted_evec
Example #6
def top_k_error_rate_metric(logits: jnp.ndarray,
                            one_hot_labels: jnp.ndarray,
                            k: int = 5,
                            mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
  """Returns the top-K error rate between some predictions and some labels.

  Args:
    logits: Output of the model.
    one_hot_labels: One-hot encoded labels. Dimensions should match the logits.
    k: Number of classes the model is allowed to predict for each example.
    mask: Mask to apply to the loss to ignore some samples (usually, the padding
      of the batch). Array of ones and zeros.

  Returns:
    The error rate (1 - accuracy), averaged over the first dimension (samples).
  """
  if mask is None:
    mask = jnp.ones([logits.shape[0]])
  mask = mask.reshape([logits.shape[0]])
  true_labels = jnp.argmax(one_hot_labels, -1).reshape([-1, 1])
  top_k_preds = jnp.argsort(logits, axis=-1)[:, -k:]
  # Flatten so the hit vector broadcasts against the [num_samples] mask.
  hit = jax.vmap(jnp.isin)(true_labels, top_k_preds).reshape([-1])
  error_rate = 1 - ((hit * mask).sum() / mask.sum())
  # Set to zero if there is no non-masked samples.
  return jnp.nan_to_num(error_rate)
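
A hedged usage sketch: three samples, four classes, top-2 error.

import jax
import jax.numpy as jnp

logits = jnp.array([[2., 1., 0., -1.],
                    [0., 3., 2., 1.],
                    [1., 0., 0., 2.]])
one_hot_labels = jnp.eye(4)[jnp.array([0, 3, 1])]
# Sample 0: label 0 is in its top-2 {0, 1}. Sample 1: label 3 is not in
# {1, 2}. Sample 2: label 1 is not in {0, 3}. Error rate = 2/3.
print(top_k_error_rate_metric(logits, one_hot_labels, k=2))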
Example #7
File: kernels.py Project: ab-10/numpyro
    def compute(self, particles, particle_info, loss_fn):
        if self._random_weights is None:
            self._random_weights = jnp.array(npr.randn(*particles.shape))
            self._random_biases = jnp.array(
                npr.rand(*particles.shape) * 2 * np.pi)
        factor = self.bandwidth_factor(particles.shape[0])
        if self.bandwidth_subset is not None:
            particles = particles[npr.choice(particles.shape[0],
                                             self.bandwidth_subset)]
        diffs = jnp.expand_dims(particles, axis=0) - jnp.expand_dims(
            particles, axis=1)  # N x N x D
        if particles.ndim == 2:
            diffs = safe_norm(diffs, ord=2, axis=-1)  # N x N x D -> N x N
        diffs = jnp.reshape(diffs,
                            (diffs.shape[0] * diffs.shape[1], -1))  # N * N x 1
        if diffs.ndim == 2:
            diff_norms = safe_norm(diffs, ord=2, axis=-1)
        else:
            diff_norms = diffs
        median = jnp.argsort(diff_norms)[int(diffs.shape[0] / 2)]
        bandwidth = jnp.abs(diffs)[median]**2 * factor + 1e-5

        def feature(x, w, b):
            return jnp.sqrt(2) * jnp.cos((x @ w + b) / bandwidth)

        def kernel(x, y):
            ws = self._random_weights if self.random_indices is None else self._random_weights[
                self.random_indices]
            bs = self._random_biases if self.random_indices is None else self._random_biases[
                self.random_indices]
            return jnp.sum(
                jax.vmap(lambda w, b: feature(x, w, b) * feature(y, w, b))(ws,
                                                                           bs))

        return kernel
Example #8
def PCA_visual(p_final):
    '''
    Visualize a matrix (N x M) as N points
    '''

    loss_p = lambda p: loss(p_final[0], p, omega0, sx)
    Hess = jacrev(jacrev(loss_p))(p_final[1])

    w, v = jnp.linalg.eig(Hess)
    v_real = v.real
    # eig returns complex eigenvalues; sort by real part and take the 3 largest
    arglist = jnp.argsort(w.real)[-3:]

    pca = PCA(n_components=2)  # scatter in 2d panel
    reduced = pca.fit_transform(v_real[:, arglist].transpose())

    t = reduced.transpose()
    # t_main = t[:,arglist]

    fig = plt.figure()
    ax1 = fig.add_subplot(111)

    ax1.scatter(t[0], t[1], marker='v', label='main')
    # ax1.scatter(t[0], t[1],marker='o',label='full')
    # ax1.scatter(t_main[0],t_main[1],marker='v',label='main')

    plt.legend(loc='upper left')
    plt.show()
Example #9
def sorted_topk_indicators(x, k, sort_by=SortBy.POSITION):
    """Finds the (sorted) positions of the topk values in x.

  Args:
    x: The input scores of dimension (d,).
    k: The number of top elements to find.
    sort_by: Strategy to order the extracted values. This is useful when this
      function is applied to many perturbed inputs and averaged. As topk's
      output does not have a fixed order, the indicator vectors could be
      swapped and the average of the indicators would not be spiky.

  Returns:
    Indicator vectors in a tensor of shape (k, d)
  """
    n = x.shape[-1]
    values, ranks = jax.lax.top_k(x, k)

    if sort_by == SortBy.NONE:
        sorted_ranks = ranks
    if sort_by == SortBy.VALUES:
        sorted_ranks = ranks[jnp.argsort(values)]
    if sort_by == SortBy.POSITION:
        sorted_ranks = jnp.sort(ranks)

    one_hot_fn = jax.vmap(functools.partial(jax.nn.one_hot, num_classes=n))
    indicators = one_hot_fn(sorted_ranks)
    return indicators
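
A small usage sketch, assuming `SortBy` and the module's imports (`jax`, `jnp`, `functools`) are in scope:

import jax.numpy as jnp

x = jnp.array([0.1, 2.0, -1.0, 0.7])
indicators = sorted_topk_indicators(x, k=2)  # top-2 scores sit at indices 1 and 3
# With the default SortBy.POSITION, row 0 marks index 1 and row 1 marks
# index 3; the result has shape (k, d) = (2, 4).
print(indicators)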
Example #10
def init_simplex_sampler_state(live_points_U):
    N, D = live_points_U.shape
    inter_point_distance = squared_norm(live_points_U, live_points_U)
    inter_point_distance = jnp.where(inter_point_distance == 0., jnp.inf,
                                     inter_point_distance)
    knn_indices = jnp.argsort(inter_point_distance, axis=-1)[:, :D + 1]
    return SimplexSamplerState(knn_indices=knn_indices)
Example #11
 def topk_mask_internal(value):
     assert value.ndim == 1
     indices = jnp.argsort(value)
     k = jnp.round(density_fraction * jnp.size(value)).astype(jnp.int32)
     mask = jnp.greater_equal(np.arange(value.size), value.size - k)
     mask = jnp.zeros_like(mask).at[indices].set(mask)
     return mask.astype(np.int32)
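
A usage sketch with hypothetical values, assuming the enclosing scope defines `density_fraction` (here taken as 0.5, i.e. keep the top half):

import jax.numpy as jnp

density_fraction = 0.5
value = jnp.array([0.3, 0.9, 0.1, 0.5])
print(topk_mask_internal(value))  # marks the two largest entries: [0 1 0 1]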
Example #12
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
    out = lax.gather_p.bind(operand,
                            start_indices,
                            dimension_numbers=dimension_numbers,
                            slice_sizes=slice_sizes,
                            unique_indices=unique_indices,
                            indices_are_sorted=indices_are_sorted,
                            mode=mode,
                            fill_value=fill_value)

    if ErrorCategory.OOB not in enabled_errors:
        return out, error

    # compare to OOB masking logic in lax._gather_translation_rule
    dnums = dimension_numbers
    operand_dims = np.array(operand.shape)
    num_batch_dims = len(start_indices.shape) - 1

    upper_bound = operand_dims[np.array(dnums.start_index_map)]
    upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)]
    upper_bound = jnp.expand_dims(upper_bound,
                                  axis=tuple(range(num_batch_dims)))
    in_bounds = (start_indices >= 0) & (start_indices <= upper_bound)

    msg = f'out-of-bounds indexing at {summary()}: '
    msg += 'index {payload} is out of bounds for '
    msg += f'array of shape {operand.shape}.'
    start_indices, in_bounds = jnp.ravel(start_indices), jnp.ravel(in_bounds)
    # Report first index which is out-of-bounds (in row-major order).
    payload = start_indices[jnp.argsort(in_bounds, axis=0)[0]]

    return out, assert_func(error, jnp.all(in_bounds), msg, payload)
Example #13
        def split_top_k(split_queries):
            split_scores = jnp.einsum('qd,rvd->qrv', split_queries, table)

            # Find highest scoring vector for each row.
            top_id_by_row = jnp.argmax(split_scores, axis=-1)
            top_score_by_row = jnp.max(split_scores, axis=-1)

            # Take k highest scores among all rows.
            top_row_idx = jnp.argsort(top_score_by_row,
                                      axis=-1)[:, :-self.k_top - 1:-1]

            # Sub-select best indices for k best rows.
            ids_by_topk_row = jut.matmul_slice(top_id_by_row, top_row_idx)

            # Gather highest scoring vectors for k best rows.
            split_topk_values = table[top_row_idx, ids_by_topk_row]

            # Convert row indices to indices into flattened table.
            top_table_id_by_row = top_id_by_row + jnp.arange(
                0, table_size, scores_per_row)
            # Get best ids into flattened table.
            split_topk_ids = jut.matmul_slice(top_table_id_by_row, top_row_idx)

            split_topk_scores = jut.matmul_slice(top_score_by_row, top_row_idx)

            return split_topk_values, split_topk_scores, split_topk_ids
Example #14
def nucleaus_filter(logits, top_p=0.9, top_k=None):
    sorted_logits = jnp.sort(logits)[:, ::-1]  # sort descending
    sorted_indices = jnp.argsort(logits)[:, ::-1]
    cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits), axis=-1)

    if top_k is not None:
        # Keep only top_k tokens
        indices_range = jnp.arange(len(sorted_indices[0]))
        indices_range = jnp.stack([indices_range] * len(sorted_indices),
                                  axis=0)

        sorted_indices_to_remove = jnp.where(indices_range > top_k,
                                             sorted_indices, 0)

        _, indices_to_remove = jax.lax.sort_key_val(sorted_indices,
                                                    sorted_indices_to_remove)

        logit_mask = 1e10 * indices_to_remove

        logits -= logit_mask

    # Remove tokens with cumulative probability above a threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove = jnp.concatenate((jnp.zeros_like(
        sorted_indices_to_remove[:, :1]), sorted_indices_to_remove),
                                               axis=-1)[:, :-1]

    _, indices_to_remove = jax.lax.sort_key_val(sorted_indices,
                                                sorted_indices_to_remove)

    logit_mask = 1e10 * indices_to_remove

    logits -= logit_mask

    return logits
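
A usage sketch for the filter above, on a batch of one distribution:

import jax
import jax.numpy as jnp

logits = jnp.log(jnp.array([[0.5, 0.3, 0.1, 0.1]]))
filtered = nucleaus_filter(logits, top_p=0.7)
# Cumulative probabilities are [0.5, 0.8, 0.9, 1.0]; with top_p=0.7 only the
# two most likely tokens survive, the rest get a -1e10 penalty.
print(jax.nn.softmax(filtered, axis=-1))  # ~[0.625, 0.375, 0., 0.]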
Example #15
def sample_permutation(key, coupling):
  """Samples a permutation matrix from a doubly stochastic coupling matrix.

  CAREFUL: the couplings that come out of the sinkhorn solver
  are not doubly stochastic but 1/dim * doubly_stochastic.

  See **Convex Relaxations for Permutation Problems** paper for rough
  explanation of the algorithm.
  Best to use by drawing multiple samples and picking the permutation with
  lowest cost as sometimes permutations seem to be drawn with high cost.
  the sample_best_permutation method does this.

  Args:
    key: jnp.ndarray that functions as a PRNG key.
    coupling: jnp.ndarray of shape [N, N] which must have marginals such that
      coupling.sum(0) == 1. and coupling.sum(1) == 1. Note that in sinkhorn we
      usually output couplings with marginals that sum to 1/N.

  Returns:
    permutation matrix: jnp.ndarray of shape [N, N] of floating dtype.
  """
  dim = coupling.shape[0]

  # random monotonic vector v without duplicates.
  v = jax.random.choice(key, 10 * dim, shape=(dim,), replace=False)
  v = jnp.sort(v) * 10.

  w = jnp.dot(coupling, v)
  # Sorting w will give the row indices of the permutation matrix.
  row_ind = jnp.argsort(w)
  col_ind = jnp.arange(0, dim)

  # Compute permutation matrix from row and column indices
  perm = idx2permutation(row_ind, col_ind)
  return perm
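
A usage sketch, with `idx2permutation` assumed to be the project's helper for building a permutation matrix from row/column indices. For an identity coupling the sample should be the identity permutation:

import jax
import jax.numpy as jnp

coupling = jnp.eye(4)  # already a permutation, hence doubly stochastic
perm = sample_permutation(jax.random.PRNGKey(0), coupling)
print(perm)  # identity matrix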
Example #16
def batched_neighbor_inds(confs, inds_l, inds_r, cutoff, boxes):
    """Given candidate interacting pairs (inds_l, inds_r),
        inds_l.shape == n_interactions
    exclude most pairs whose distances are >= cutoff (neighbor_inds_l, neighbor_inds_r)
        neighbor_inds_l.shape == (len(confs), max_n_neighbors)
        where the total number of neighbors returned for each conf in confs is the same
        max_n_neighbors

    This padding causes some amount of wasted effort, but keeps things nice and fixed-dimensional
        for later XLA steps
    """
    assert len(confs.shape) == 3
    distances = vmap(distance_on_pairs)(confs[:, inds_l], confs[:, inds_r],
                                        boxes)
    assert distances.shape == (len(confs), len(inds_l))

    neighbor_masks = distances < cutoff
    # how many total neighbors?

    n_neighbors = np.sum(neighbor_masks, 1)
    max_n_neighbors = max(n_neighbors)

    assert max_n_neighbors > 0

    # sorting in order of [falses, ..., trues]
    keep_inds = np.argsort(neighbor_masks, axis=1)[:, -max_n_neighbors:]
    neighbor_inds_l = inds_l[keep_inds]
    neighbor_inds_r = inds_r[keep_inds]

    assert neighbor_inds_l.shape == (len(confs), max_n_neighbors)
    assert neighbor_inds_l.shape == neighbor_inds_r.shape

    return neighbor_inds_l, neighbor_inds_r
Example #17
    def sum_from_unique(
            cls,
            input: np.ndarray,
            mean: bool = True) -> Tuple[np.ndarray, np.ndarray, "SparseReduce"]:
        un, cts = np.unique(input, return_counts=True)
        un_idx = [
            np.argwhere(input == un[i]).flatten() for i in range(un.size)
        ]
        l_arr = np.array([i.size for i in un_idx])
        argsort = np.argsort(l_arr)
        un_sorted = un[argsort]
        cts_sorted = cts[argsort]
        un_idx_sorted = [un_idx[i] for i in argsort]

        change = list(
            np.argwhere(
                l_arr[argsort][:-1] - l_arr[argsort][1:] != 0).flatten() + 1)
        change.insert(0, 0)
        change.append(len(l_arr))
        change = np.array(change)

        el = []
        for i in range(len(change) - 1):
            el.append(
                np.array([
                    un_idx_sorted[j] for j in range(change[i], change[i + 1])
                ]))

        return un_sorted, cts_sorted, SparseReduce(el, mean)
Example #18
def nystrom_inv(gram, n_comp, regul=0.) -> np.array:
    p = random.permutation(gram.shape[0])
    ip = np.argsort(p)
    (vec_in, vec_out, λ) = nystrom_eigh(gram[p, :][:, p], n_comp, regul)
    vec = np.vstack([vec_in, vec_out])
    rval = vec @ np.diag(1. / (λ + regul)) @ vec.T
    return rval[ip, :][:, ip]
Example #19
 def prior_sample(self, num_samps, t=None):
     """
     Sample from the model prior f~N(0,K) multiple times using a nested loop.
     :param num_samps: the number of samples to draw [scalar]
     :param t: the input locations at which to sample (defaults to train+test set) [N_samp, 1]
     :return:
         f_sample: the prior samples [N_samp, func_dim, S]
     """
     self.update_model(softplus_list(self.prior.hyp))
     if t is None:
         t = self.t_all
     else:
         x_ind = np.argsort(t[:, 0])
         t = t[x_ind]
     dt = np.concatenate([np.array([0.0]), np.diff(t[:, 0])])
     N = dt.shape[0]
     with loops.Scope() as s:
         s.f_sample = np.zeros([N, self.func_dim, num_samps])
         s.m = np.linalg.cholesky(self.Pinf) @ random.normal(random.PRNGKey(99), shape=[self.state_dim, 1])
         for i in s.range(num_samps):
             s.m = np.linalg.cholesky(self.Pinf) @ random.normal(random.PRNGKey(i), shape=[self.state_dim, 1])
             for k in s.range(N):
                 A = self.prior.state_transition(dt[k], self.prior.hyp)  # transition and noise process matrices
                 Q = self.Pinf - A @ self.Pinf @ A.T
                 C = np.linalg.cholesky(Q + 1e-6 * np.eye(self.state_dim))  # <--- can be a bit unstable
                  # we need to provide a different PRNG seed every time
                  # (i * N + k is unique for each (i, k) pair):
                  s.m = A @ s.m + C @ random.normal(random.PRNGKey(i * N + k), shape=[self.state_dim, 1])
                 H = self.prior.measurement_model(t[k, 1:], softplus_list(self.prior.hyp))
                 f = (H @ s.m).T
                 s.f_sample = index_add(s.f_sample, index[k, ..., i], np.squeeze(f))
     return s.f_sample
Example #20
def data_split(batch: Batch, num_workers: int, s: float):
    # s encodes the heterogeneity of the data split
    if num_workers == 1:
        return batch
    n_data = batch.x.shape[0]

    n_homo_data = int(n_data * s)

    assert 0 < n_homo_data < n_data

    data_homo, data_hetero = jnp.split(batch.x, [n_homo_data])
    label_homo, label_hetero = jnp.split(batch.y, [n_homo_data])

    data_homo_list = jnp.split(data_homo, num_workers)
    label_homo_list = jnp.split(label_homo, num_workers)

    index = jnp.argsort(label_hetero)
    label_hetero_sorted = label_hetero[index]
    data_hetero_sorted = data_hetero[index]
    data_hetero_list = jnp.split(data_hetero_sorted, num_workers)
    label_hetero_list = jnp.split(label_hetero_sorted, num_workers)

    data_list = [
        jnp.concatenate([data_homo, data_hetero], axis=0)
        for data_homo, data_hetero in zip(data_homo_list, data_hetero_list)
    ]
    label_list = [
        jnp.concatenate([label_homo, label_hetero], axis=0)
        for label_homo, label_hetero in zip(label_homo_list, label_hetero_list)
    ]

    return [Batch(x, y) for x, y in zip(data_list, label_list)]
Example #21
File: kernels.py Project: ab-10/numpyro
    def compute(self, particles, particle_info, loss_fn):
        diffs = jnp.expand_dims(particles, axis=0) - jnp.expand_dims(
            particles, axis=1)  # N x N (x D)
        if self._normed() and particles.ndim == 2:
            diffs = safe_norm(diffs, ord=2, axis=-1)  # N x N x D -> N x N
        diffs = jnp.reshape(
            diffs, (diffs.shape[0] * diffs.shape[1], -1))  # N * N (x D)
        factor = self.bandwidth_factor(particles.shape[0])
        if diffs.ndim == 2:
            diff_norms = safe_norm(diffs, ord=2, axis=-1)
        else:
            diff_norms = diffs
        median = jnp.argsort(diff_norms)[int(diffs.shape[0] / 2)]
        bandwidth = jnp.abs(diffs)[median]**2 * factor + 1e-5
        if self._normed():
            bandwidth = bandwidth[0]

        def kernel(x, y):
            diff = safe_norm(
                x - y, ord=2) if self._normed() and x.ndim >= 1 else x - y
            kernel_res = jnp.exp(-diff**2 / bandwidth)
            if self._mode == 'matrix':
                if self.matrix_mode == 'norm_diag':
                    return kernel_res * jnp.identity(x.shape[0])
                else:
                    return jnp.diag(kernel_res)
            else:
                return kernel_res

        return kernel
Example #22
    def evaluate_example(self, example: SingleExample,
                         prediction: SinglePrediction) -> MeanStat:
        """Computes token top k accuracy for a single sequence example.

    Args:
      example: One example with target in range [0, num_classes) of shape
        [max_length].
      prediction: Unnormalized prediction for ``example`` of shape
        [max_length, num_classes]

    Returns:
      MeanStat for token top k accuracy for a single sequence example or at each
        token position if ``per_position`` is ``True``.
    """
        target = example[self.target_key]
        pred = prediction if self.pred_key is None else prediction[
            self.pred_key]
        if self.logits_mask is not None:
            logits_mask = jnp.array(self.logits_mask)
            pred += logits_mask
        target_weight = get_target_weight(target, self.masked_target_values)
        top_k_pred = jnp.argsort(-pred, axis=1)[:, :self.k]
        correct = jnp.any(jnp.transpose(top_k_pred) == target,
                          axis=0).astype(jnp.float32)
        if self.per_position:
            return MeanStat.new(correct * target_weight, target_weight)
        return MeanStat.new(jnp.sum(correct * target_weight),
                            jnp.sum(target_weight))
Example #23
    def sampling_loop_body_fn(state):
        """Sampling loop state update."""
        i, sequences, cache, cur_token, ended, rng, tokens_to_logits_state = state

        # Split RNG for sampling.
        rng1, rng2 = random.split(rng)

        # Call fast-decoder model on current tokens to get raw next-position logits.
        logits, new_cache, new_tokens_to_logits_state = tokens_to_logits(
            cur_token, cache, internal_state=tokens_to_logits_state)
        logits = logits / temperature

        # Mask out the BOS token.
        if masked_tokens is not None:
            mask = common_utils.onehot(jnp.array(masked_tokens),
                                       num_classes=logits.shape[-1],
                                       on_value=LARGE_NEGATIVE)
            mask = jnp.sum(mask,
                           axis=0)[None, :]  # Combine multiple masks together
            logits = logits + mask

        # Apply the repetition penalty.
        if repetition_penalty != 1:
            logits = apply_repetition_penalty(
                sequences,
                logits,
                i,
                repetition_penalty=repetition_penalty,
                repetition_window=repetition_window,
                repetition_penalty_normalize=repetition_penalty_normalize)

        # Mask out everything but the top-k entries.
        if top_k is not None:
            # Compute top_k_index and top_k_threshold with shapes (batch_size, 1).
            top_k_index = jnp.argsort(logits,
                                      axis=-1)[:, ::-1][:, top_k - 1:top_k]
            top_k_threshold = jnp.take_along_axis(logits, top_k_index, axis=-1)
            logits = jnp.where(logits < top_k_threshold,
                               jnp.full_like(logits, LARGE_NEGATIVE), logits)
        # Sample next token from logits.
        sample = multinomial(rng1, logits)
        next_token = sample.astype(jnp.int32)
        # Only use sampled tokens if we have past the out_of_prompt_marker.
        out_of_prompt = (sequences[:, i + 1] == out_of_prompt_marker)
        next_token = (next_token * out_of_prompt +
                      sequences[:, i + 1] * ~out_of_prompt)
        # If end-marker reached for batch item, only emit padding tokens.
        next_token = next_token[:, None]
        next_token_or_endpad = jnp.where(ended,
                                         jnp.full_like(next_token, pad_token),
                                         next_token)
        ended |= (next_token_or_endpad == end_marker)
        # Add current sampled tokens to recorded sequences.
        new_sequences = lax.dynamic_update_slice(sequences,
                                                 next_token_or_endpad,
                                                 (0, i + 1))
        return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended,
                rng2, new_tokens_to_logits_state)
Example #24
def _argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None):  # pylint: disable=unused-argument
    """Numpy implementation of `tf.argsort`."""
    if direction == 'ASCENDING':
        pass
    elif direction == 'DESCENDING':
        values = np.negative(values)
    else:
        raise ValueError('Unrecognized direction: {}.'.format(direction))
    return np.argsort(values, axis, kind='stable' if stable else 'quicksort')
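
A quick check of the NumPy fallback against its contract (a sketch, not from the source):

import numpy as np

values = np.array([3, 1, 2])
print(_argsort(values))                          # [1 2 0]
print(_argsort(values, direction='DESCENDING'))  # [0 2 1]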
Example #25
 def reshape_inv(y):
     # Expand the extra dims hanging off the end, "b_extra_sh".
     # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
     # could have different batch dims than a and b, because of broadcasting.
     y_extra_shape = array_ops.concat(
         (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
     y_extra_on_end = array_ops.reshape(y, y_extra_shape)
     inverse_perm = np.argsort(perm)
     return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
Example #26
def top_k_approx(scores, k=100):
    """Returns approximate topk highest scores for each row.

  The api is same as jax.lax.top_k, so this can be used as a drop in replacement
  as long as num dims of scores tensor is 2. For more dimensions, please use one
  or more vmap(s) to be able to use it.

  In essence, we perform a jnp.max operation, which can be thought of as a
  lossy top 1, on fixed-length windows of items. We can control the amount of
  approximation by changing the window length: the smaller it gets, the better
  the approximation, but at the cost of performance.

  Once we have the max for all the windows, we apply regular slow but exact
  jax.lax.top_k over reduced set of items.

  Args:
    scores: [num_rows, num_cols] shaped tensor. Will return top K over last dim.
    k: How many top scores to return for each row.

  Returns:
    Topk scores, topk ids. Both shaped [num_rows, k]
  """
    num_queries = scores.shape[0]
    num_items = scores.shape[1]

    # Make this bigger to improve recall. Should be between [1, k].
    num_windows_multiplier = 5
    window_lengths = num_items // k // num_windows_multiplier + 1
    padded_num_items = k * num_windows_multiplier * window_lengths

    print(f"scores shape: {scores.shape}")
    print(f"padded_num_items: {padded_num_items}")
    print(f"num_items: {num_items}")
    scores = jnp.pad(scores, ((0, 0), (0, padded_num_items - num_items)),
                     mode="constant",
                     constant_values=jnp.NINF)
    scores = jnp.reshape(
        scores, (num_queries, k * num_windows_multiplier, window_lengths))
    approx_top_local_scores = jnp.max(scores, axis=2)

    sorted_approx_top_scores_across_local = jnp.flip(jnp.sort(
        approx_top_local_scores, axis=1),
                                                     axis=1)
    approx_top_ids_across_local = jnp.flip(jnp.argsort(approx_top_local_scores,
                                                       axis=1),
                                           axis=1)[:, :k]

    approx_top_local_ids = jnp.argmax(scores, axis=2)
    offsets = jnp.arange(0, padded_num_items, window_lengths)
    approx_top_ids_with_offsets = approx_top_local_ids + offsets
    approx_top_ids = slice_2d(approx_top_ids_with_offsets,
                              approx_top_ids_across_local)

    topk_scores = sorted_approx_top_scores_across_local[:, :k]
    topk_ids = approx_top_ids

    return topk_scores, topk_ids
Example #27
File: eigh.py Project: wayfeng/jax
def _projector_subspace(P, H, rank, maxiter=2):
    """ Decomposes the `n x n` rank `rank` Hermitian projector `P` into
  an `n x rank` isometry `Vm` such that `P = Vm @ Vm.conj().T` and
  an `n x (n - rank)` isometry `Vm` such that -(I - P) = Vp @ Vp.conj().T`.

  The subspaces are computed using the naiive QR eigendecomposition
  algorithm, which converges very quickly due to the sharp separation
  between the relevant eigenvalues of the projector.

  Args:
    P: A rank-`rank` Hermitian projector into the space of `H`'s
       first `rank` eigenpairs.
    H: The aforementioned Hermitian matrix, which is used to track
       convergence.
    rank: Rank of `P`.
    maxiter: Maximum number of iterations.
  Returns:
    Vm, Vp: Isometries into the eigenspaces described in the docstring.
  """
    # Choose an initial guess: the `rank` largest-norm columns of P.
    column_norms = jnp.linalg.norm(P, axis=1)
    sort_idxs = jnp.argsort(-column_norms)  # descending: largest norms first
    X = P[:, sort_idxs]
    X = X[:, :rank]

    H_norm = jnp.linalg.norm(H)
    thresh = 10 * jnp.finfo(X.dtype).eps * H_norm

    # First iteration skips the matmul.
    def body_f_after_matmul(X):
        Q, _ = jnp.linalg.qr(X, mode="complete")
        V1 = Q[:, :rank]
        V2 = Q[:, rank:]
        # TODO: might be able to get away with lower precision here
        error_matrix = jnp.dot(V2.conj().T, H, precision=lax.Precision.HIGHEST)
        error_matrix = jnp.dot(error_matrix,
                               V1,
                               precision=lax.Precision.HIGHEST)
        error = jnp.linalg.norm(error_matrix) / H_norm
        return V1, V2, error

    def cond_f(args):
        _, _, j, error = args
        still_counting = j < maxiter
        unconverged = error > thresh
        return jnp.logical_and(still_counting, unconverged)[0]

    def body_f(args):
        V1, _, j, _ = args
        X = jnp.dot(P, V1, precision=lax.Precision.HIGHEST)
        V1, V2, error = body_f_after_matmul(X)
        return V1, V2, j + 1, error

    V1, V2, error = body_f_after_matmul(X)
    one = jnp.ones(1, dtype=jnp.int32)
    V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
    return V1, V2
Example #28
def weighted_percentile(x, w, ps, assume_sorted=False):
    """Compute the weighted percentile(s) of a single vector."""
    x = x.reshape([-1])
    w = w.reshape([-1])
    if not assume_sorted:
        sortidx = jnp.argsort(jax.lax.stop_gradient(x))
        x, w = x[sortidx], w[sortidx]
    acc_w = jnp.cumsum(w)
    return jnp.interp(jnp.array(ps) * (acc_w[-1] / 100), acc_w, x)
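
A usage sketch: with uniform weights this reduces to ordinary percentiles of the sorted data.

import jax.numpy as jnp

x = jnp.array([4., 1., 3., 2.])
w = jnp.ones_like(x)
print(weighted_percentile(x, w, [25., 50., 75.]))  # [1. 2. 3.]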
Example #29
def permuted_inverse_cdf_coupling(logits_1, logits_2, permutation_seed=1):
    """Constructs the matrix for an inverse CDF coupling under a permutation."""
    dim, = logits_1.shape
    perm = jnp.argsort(
        jax.random.uniform(jax.random.PRNGKey(permutation_seed), shape=[dim]))
    invperm = jnp.argsort(perm)
    p1 = jnp.exp(logits_1)[perm]
    p2 = jnp.exp(logits_2)[perm]
    p1_bins = jnp.concatenate([jnp.array([0.]), jnp.cumsum(p1)])
    p2_bins = jnp.concatenate([jnp.array([0.]), jnp.cumsum(p2)])

    # Value in bin (i, j): overlap between bin ranges
    def get(i, j):
        left = jnp.maximum(p1_bins[i], p2_bins[j])
        right = jnp.minimum(p1_bins[i + 1], p2_bins[j + 1])
        return jnp.where(left < right, right - left, 0.0)

    return jax.vmap(lambda i: jax.vmap(lambda j: get(i, j))(invperm))(invperm)
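
A sanity-check sketch: whatever permutation is drawn internally, the coupling's marginals should reproduce the two input distributions.

import jax.numpy as jnp

logits_1 = jnp.log(jnp.array([0.2, 0.5, 0.3]))
logits_2 = jnp.log(jnp.array([0.6, 0.1, 0.3]))
coupling = permuted_inverse_cdf_coupling(logits_1, logits_2)
print(coupling.sum(axis=1))  # ~[0.2 0.5 0.3]
print(coupling.sum(axis=0))  # ~[0.6 0.1 0.3]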
Example #30
 def knn(self, instance):
     distances = jnp.array(
         [self.euclidean_distance(instance, x) for x in self.locations])
     # the above could be done in parallel
     # get the nearest k points: argsort all distances, then take the first k
     k_neighbors = jnp.argsort(distances)[:self.k]
     vote = Counter(self.labels[k_neighbors])
     return vote.most_common()[0][0]