def test_autograd_sinkhorn(self, lse_mode):
        """Test gradient w.r.t. probability weights."""
        d = 3
        for n, m in ((11, 13), (15, 9)):
            eps = 1e-3  # finite-difference step; also used to keep the random weights positive
            keys = jax.random.split(self.rng, 5)
            x = jax.random.uniform(keys[0], (n, d))
            y = jax.random.uniform(keys[1], (m, d))
            a = jax.random.uniform(keys[2], (n, )) + eps
            b = jax.random.uniform(keys[3], (m, )) + eps
            # Adding zero weights to test proper handling
            a = jax.ops.index_update(a, 0, 0)
            b = jax.ops.index_update(b, 3, 0)
            a = a / jnp.sum(a)
            b = b / jnp.sum(b)
            geom = pointcloud.PointCloud(x, y, epsilon=0.1)

            def reg_ot(a, b):
                return sinkhorn.sinkhorn(geom, a=a, b=b,
                                         lse_mode=lse_mode).reg_ot_cost

            reg_ot_and_grad = jax.jit(jax.value_and_grad(reg_ot))
            _, grad_reg_ot = reg_ot_and_grad(a, b)
            delta = jax.random.uniform(keys[4], (n, ))
            delta = delta * (a > 0)  # only perturb non-zero coordinates
            delta = delta - jnp.sum(delta) / jnp.sum(a > 0)  # center the perturbation
            delta = delta * (a > 0)  # re-mask so zero-weight coordinates stay exactly zero
            reg_ot_delta_plus = reg_ot(a + eps * delta, b)
            reg_ot_delta_minus = reg_ot(a - eps * delta, b)
            delta_dot_grad = jnp.nansum(delta * grad_reg_ot)
            self.assertFalse(jnp.any(jnp.isnan(delta_dot_grad)))
            self.assertAllClose(delta_dot_grad,
                                (reg_ot_delta_plus - reg_ot_delta_minus) /
                                (2 * eps),
                                rtol=1e-03,
                                atol=1e-02)
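Note that jax.ops.index_update, used above to zero out individual weights, was removed in later JAX releases in favor of the .at[...] indexed-update syntax. A minimal sketch of the equivalent modern call, using a toy array for illustration:

import jax.numpy as jnp

a = jnp.array([0.3, 0.2, 0.5])
# old API (removed in newer JAX): a = jax.ops.index_update(a, 0, 0)
a = a.at[0].set(0.0)   # returns a new array with a[0] set to zero
a = a / jnp.sum(a)     # renormalize, as in the test above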
Example #2
    def update(self, transition_batch):
        r"""

        Update the model parameters (weights) of the underlying function approximator.

        Parameters
        ----------
        transition_batch : TransitionBatch

            A batch of transitions.

        Returns
        -------
        metrics : dict of scalar ndarrays

            The structure of the metrics dict is ``{name: score}``.

        """
        grads, function_state, metrics = self.grads_and_metrics(
            transition_batch)
        if any(jnp.any(jnp.isnan(g)) for g in jax.tree_leaves(grads)):
            raise RuntimeError(f"found nan's in grads: {grads}")
        self.apply_grads(grads, function_state)
        return metrics
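The NaN guard above walks every leaf of the gradient pytree. A minimal standalone sketch of that check, using jax.tree_util.tree_leaves and a toy gradient dict (the dict contents are made up for illustration):

import jax
import jax.numpy as jnp

grads = {"w": jnp.array([0.1, jnp.nan]), "b": jnp.array([0.5])}  # toy pytree of gradients
if any(bool(jnp.any(jnp.isnan(g))) for g in jax.tree_util.tree_leaves(grads)):
    raise RuntimeError(f"found NaNs in grads: {grads}")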
Example #3
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
    """Checks if indices are within bounds and update does not generate NaN."""
    out = prim.bind(operand,
                    indices,
                    updates,
                    update_jaxpr=update_jaxpr,
                    update_consts=update_consts,
                    dimension_numbers=dimension_numbers,
                    indices_are_sorted=indices_are_sorted,
                    unique_indices=unique_indices,
                    mode=mode)

    if ErrorCategory.OOB not in enabled_errors:
        return out, error

    in_bounds = scatter_in_bounds(operand, indices, updates, dimension_numbers)
    oob_msg = f'out-of-bounds indexing while updating at {summary()}'
    oob_error = assert_func(error, in_bounds, oob_msg, None)

    no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
    nan_msg = f'nan generated by primitive {prim.name} at {summary()}'
    return out, assert_func(oob_error, no_nans, nan_msg, None)
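Outside of the checkify machinery above, the two conditions being asserted (in-bounds indices and a NaN-free result) can be illustrated with a plain scatter update. This is a minimal sketch; the toy arrays and the mode="drop" choice are assumptions for illustration:

import jax.numpy as jnp

operand = jnp.zeros(4)
indices = jnp.array([1, 7])                 # index 7 is out of bounds for size 4
updates = jnp.array([jnp.nan, 1.0])

in_bounds = bool(jnp.all((indices >= 0) & (indices < operand.shape[0])))  # False
out = operand.at[indices].set(updates, mode="drop")   # the out-of-bounds update is dropped
no_nans = bool(jnp.logical_not(jnp.any(jnp.isnan(out))))                  # False: NaN written at index 1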
Example #4
def grad_eigh(w, v, vg):
    """Gradient for eigenvalues and vectors of a symmetric matrix.

    Parameters
    ----------
    w: eigenvalues

    v: eigenvectors

    vg: adjoint eigenvectors
    """
    vc = v  # real
    N = 3
    # wg, vg = g          # Gradient w.r.t. eigenvalues, eigenvectors.
    w_repeated = np.repeat(w[..., np.newaxis], N, axis=-1)
    # Eigenvalue part (disabled)
    # vjp_temp = np.dot(vc * wg[..., np.newaxis, :], v.T)

    # Add eigenvector part only if non-zero backward signal is present.
    # This can avoid NaN results for degenerate cases if the function depends
    # on the eigenvalues only.

    if np.any(vg):
        off_diag = np.ones((N, N)) - np.eye(N)
        F = off_diag / (w_repeated.T - w_repeated + np.eye(N))
        # (this used to be += but we never do derivatives w.r.t. eigenvalues)
        vjp_temp = np.dot(np.dot(vc, F * np.dot(v.T, vg)), v.T)
    else:
        assert 0

    off_diag_mask = (onp.ones((3, 3)) - onp.eye(3)) / 2

    final = vjp_temp * np.eye(
        vjp_temp.shape[-1]) + (vjp_temp + vjp_temp.T) * off_diag_mask

    return final
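A minimal usage sketch for grad_eigh with a made-up symmetric 3x3 matrix, assuming (as in the body above) that np is jax.numpy and onp is NumPy:

import jax.numpy as np
import numpy as onp

A = np.array([[2.0, 0.5, 0.1],
              [0.5, 1.0, 0.3],
              [0.1, 0.3, 3.0]])   # symmetric test matrix
w, v = np.linalg.eigh(A)          # eigenvalues and eigenvectors
vg = np.ones_like(v)              # toy backward signal w.r.t. the eigenvectors
A_bar = grad_eigh(w, v, vg)       # gradient contribution w.r.t. A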
Example #5
    def test_sparse_inputs(self, act, kernel):
        key = random.PRNGKey(1)

        input_count = 4
        sparse_count = 2
        input_size = 128
        width = 4096

        # NOTE(schsam): It seems that convergence is slower when inputs are sparse.
        samples = N_SAMPLES

        if xla_bridge.get_backend().platform == 'gpu':
            jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-4
            samples = 100 * N_SAMPLES
        else:
            jtu._default_tolerance[np.onp.dtype(np.onp.float32)] = 5e-2
            jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-3

        # a batch of dense inputs
        x_dense = random.normal(key, (input_count, input_size))
        x_sparse = ops.index_update(x_dense, ops.index[:sparse_count, :], 0.)

        activation = stax.Relu() if act == 'relu' else stax.Erf()

        init_fn, apply_fn, kernel_fn = stax.serial(
            stax.Dense(width), activation,
            stax.Dense(1 if kernel == 'ntk' else width))
        exact = kernel_fn(x_sparse, None, kernel)
        mc = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn,
                                               random.split(key, 2)[0],
                                               samples)(x_sparse, None, kernel)
        mc = np.reshape(mc, exact.shape)

        assert not np.any(np.isnan(exact))
        self.assertAllClose(exact[sparse_count:, sparse_count:],
                            mc[sparse_count:, sparse_count:], True)
Example #6
def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
    """Empirically choose initial step size.

  Args:
    fun: Function to evaluate like `fun(y, t)` to compute the time
      derivative of `y`.
    t0: initial time.
    y0: initial value for the state.
    order: order of interpolation
    rtol: relative local error tolerance for solver.
    atol: absolute local error tolerance for solver.
    f0: initial value for the derivative, computed from `func(t0, y0)`.
  Returns:
    Initial step size for odeint algorithm.

  Algorithm from:
  E. Hairer, S. P. Norsett, G. Wanner,
  Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
  """
    scale = atol + np.abs(y0) * rtol
    d0 = np.linalg.norm(y0 / scale)
    d1 = np.linalg.norm(f0 / scale)
    order_pow = (1. / (order + 1.))

    h0 = np.where(np.any(np.asarray([d0 < 1e-5, d1 < 1e-5])), 1e-6,
                  0.01 * d0 / d1)

    y1 = y0 + h0 * f0
    f1 = fun(y1, t0 + h0)
    d2 = np.linalg.norm((f1 - f0) / scale) / h0

    h1 = np.where(np.all(np.asarray([d1 <= 1e-15, d2 < 1e-15])),
                  np.maximum(1e-6, h0 * 1e-3),
                  (0.01 / np.max(d1 + d2))**order_pow)

    return np.minimum(100. * h0, h1)
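A minimal usage sketch for initial_step_size with a toy linear ODE; the decay function, tolerances, and order below are assumptions for illustration (np is assumed to be jax.numpy, as in the body above):

import jax.numpy as np

fun = lambda y, t: -y                  # simple exponential decay
t0, y0 = 0.0, np.ones(3)
f0 = fun(y0, t0)
h0 = initial_step_size(fun, t0, y0, order=4, rtol=1e-6, atol=1e-9, f0=f0)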
Example #7
def newton_tol(fn, jac_fn, U, tol):
    """Newton iteration for fn(U, Uold) = 0 with simple failure checks."""
    maxit = 20
    count = 0
    res = 100
    fail = 0
    Uold = U

    while count < maxit and res > tol:
        J = jac_fn(U, Uold)
        # J = jacrev(fn)(U, Uold)
        # Jsparse = csr_matrix(J)
        y = fn(U, Uold)
        res = max(abs(y / norm(y, 2)))
        print(count, res)
        delta = solve(J, y)
        # delta = jitsolve(J, fn(U, Uold))
        # delta = spsolve(csr_matrix(J), fn(U, Uold))
        U = U - delta
        count = count + 1

    if fail == 0 and np.any(np.isnan(delta)):
        fail = 1
        print("nan solution")

    if fail == 0 and max(abs(np.imag(delta))) > 0:
        fail = 1
        print("solution complex")

    if fail == 0 and res > tol:
        fail = 1
        print("Newton fail: no convergence")

    return U, fail
Example #8
def log_density(model,
                model_args,
                model_kwargs,
                params,
                skip_dist_transforms=False):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :param bool skip_dist_transforms: whether to compute log probability of a site
        (if its prior is a transformed distribution) in its base distribution
        domain.
    :return: log of joint density and a corresponding model trace
    """
    # We skip transforms in
    #   + autoguide's model
    #   + hmc's model
    # We apply transforms in
    #   + autoguide's guide
    #   + svi's model + guide
    if skip_dist_transforms:
        model = substitute(model, base_param_map=params)
    else:
        model = substitute(model, param_map=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = 0.
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            mask = site['mask']
            scale = site['scale']
            # Early exit when all elements are masked
            if not_jax_tracer(mask) and mask is not None and not np.any(mask):
                return jax.device_put(0.), model_trace
            if intermediates:
                if skip_dist_transforms:
                    log_prob = site['fn'].base_dist.log_prob(
                        intermediates[0][0])
                else:
                    log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)

            # Minor optimizations
            # XXX: note that this may not work correctly for dynamic masks, provide
            # explicit jax.DeviceArray for masking.
            if mask is not None:
                if scale is not None:
                    log_prob = np.where(mask, scale * log_prob, 0.)
                else:
                    log_prob = np.where(mask, log_prob, 0.)
            else:
                if scale is not None:
                    log_prob = scale * log_prob
            log_prob = np.sum(log_prob)
            log_joint = log_joint + log_prob
    return log_joint, model_trace
Example #9
 def cond(vals):
     sigma, phi, phiprime, unconverged, j = vals
     return np.logical_and(np.any(unconverged), j < maxiter)
Example #10
def ancom(
    table,
    grouping,
    alpha=0.05,
    tau=0.02,
    theta=0.1,
    multiple_comparisons_correction=None,
    significance_test=None,
):
    r"""Performs a differential abundance test using ANCOM.

    This is done by calculating pairwise log ratios between all features
    and performing a significance test to determine if there is a significant
    difference in feature ratios with respect to the variable of interest.

    In an experiment with only two treatments, this test tests the following
    hypothesis for feature :math:`i`

    .. math::

        H_{0i}: \mathbb{E}[\ln(u_i^{(1)})] = \mathbb{E}[\ln(u_i^{(2)})]

    where :math:`u_i^{(1)}` is the mean abundance for feature :math:`i` in the
    first group and :math:`u_i^{(2)}` is the mean abundance for feature
    :math:`i` in the second group.

    Parameters
    ----------
    table : pd.DataFrame
        A 2D matrix of strictly positive values (i.e. counts or proportions)
        where the rows correspond to samples and the columns correspond to
        features.
    grouping : pd.Series
        Vector indicating the assignment of samples to groups.  For example,
        these could be strings or integers denoting which group a sample
        belongs to.  It must be the same length as the samples in `table`.
        The index must be the same on `table` and `grouping` but need not be
        in the same order.
    alpha : float, optional
        Significance level for each of the statistical tests.
        This can be anywhere between 0 and 1 exclusive.
    tau : float, optional
        A constant used to determine an appropriate cutoff.
        A value close to zero indicates a conservative cutoff.
        This can be anywhere between 0 and 1 exclusive.
    theta : float, optional
        Lower bound for the proportion for the W-statistic.
        If all W-statistics are lower than theta, then no features
        will be detected to be differentially significant.
        This can be anywhere between 0 and 1 exclusive.
    multiple_comparisons_correction : {None, 'holm-bonferroni'}, optional
        The multiple comparison correction procedure to run.  If None,
        then no multiple comparison correction procedure will be run.
        If 'holm-bonferroni' is specified, then the Holm-Bonferroni
        procedure [1]_ will be run.
    significance_test : function, optional
        A statistical significance function to test for significance between
        classes.  This function must be able to accept at least two 1D
        array_like arguments of floats and return a test statistic and a
        p-value. By default ``scipy.stats.f_oneway`` is used.

    Returns
    -------
    pd.DataFrame
        A table of features, their W-statistics and whether the null hypothesis
        is rejected.

        `"W"` is the W-statistic, or number of features that a single feature
        is tested to be significantly different against.

        `"reject"` indicates if feature is significantly different or not.

    See Also
    --------
    multiplicative_replacement
    scipy.stats.ttest_ind
    scipy.stats.f_oneway
    scipy.stats.wilcoxon
    scipy.stats.kruskal

    Notes
    -----
    The developers of this method recommend the following significance tests
    ([2]_, Supplementary File 1, top of page 11): if there are two groups, use
    the standard parametric t-test (``scipy.stats.ttest_ind``) or the
    non-parametric Wilcoxon rank sum test (``scipy.stats.wilcoxon``); if there
    are more than two groups, use parametric one-way ANOVA
    (``scipy.stats.f_oneway``) or non-parametric Kruskal-Wallis
    (``scipy.stats.kruskal``).  Because one-way ANOVA is
    equivalent to the standard t-test when the number of groups is two,
    we default to ``scipy.stats.f_oneway`` here, which can be used when
    there are two or more groups.  Users should refer to the documentation
    of these tests in SciPy to understand the assumptions made by each test.

    This method cannot handle any zero counts as input, since the logarithm
    of zero cannot be computed.  While this is an unsolved problem, many
    studies have shown promising results by replacing the zeros with pseudo
    counts. This can also be done via the ``multiplicative_replacement``
    method.

    References
    ----------
    .. [1] Holm, S. "A simple sequentially rejective multiple test procedure".
       Scandinavian Journal of Statistics (1979), 6.
    .. [2] Mandal et al. "Analysis of composition of microbiomes: a novel
       method for studying microbial composition", Microbial Ecology in Health
       & Disease, (2015), 26.

    Examples
    --------
    First import all of the necessary modules:

    >>> import mushi.composition as cmp
    >>> import pandas as pd

    Now let's load in a pd.DataFrame with 6 samples and 7 unknown bacteria:

    >>> table = pd.DataFrame([[12, 11, 10, 10, 10, 10, 10],
    ...                       [9,  11, 12, 10, 10, 10, 10],
    ...                       [1,  11, 10, 11, 10, 5,  9],
    ...                       [22, 21, 9,  10, 10, 10, 10],
    ...                       [20, 22, 10, 10, 13, 10, 10],
    ...                       [23, 21, 14, 10, 10, 10, 10]],
    ...                      index=['s1','s2','s3','s4','s5','s6'],
    ...                      columns=['b1','b2','b3','b4','b5','b6','b7'])

    Then create a grouping vector.  In this scenario, there
    are only two classes, and suppose these classes correspond to the
    treatment due to a drug and a control.  The first three samples
    are controls and the last three samples are treatments.

    >>> grouping = pd.Series([0, 0, 0, 1, 1, 1],
    ...                      index=['s1','s2','s3','s4','s5','s6'])

    Now run ``ancom`` and see if there are any features that have any
    significant differences between the treatment and the control.

    >>> results = cmp.ancom(table, grouping) # doctest: +SKIP
    >>> results['W'] # doctest: +SKIP
    b1    0
    b2    4
    b3    1
    b4    1
    b5    1
    b6    0
    b7    1
    Name: W, dtype: int64

    The W-statistic is the number of features that a single feature is tested
    to be significantly different against.  In this scenario, `b2` was detected
    to have significantly different abundances compared to four of the other
    species. To summarize the results from the W-statistic, let's take a look
    at the results from the hypothesis test:

    >>> results['reject'] # doctest: +SKIP
    b1    False
    b2     True
    b3    False
    b4    False
    b5    False
    b6    False
    b7    False
    Name: reject, dtype: bool

    From this we can conclude that only `b2` was significantly
    different between the treatment and the control.

    """

    if not isinstance(table, pd.DataFrame):
        raise TypeError("`table` must be a `pd.DataFrame`, "
                        "not %r." % type(table).__name__)
    if not isinstance(grouping, pd.Series):
        raise TypeError("`grouping` must be a `pd.Series`,"
                        " not %r." % type(grouping).__name__)

    if np.any(table <= 0):
        raise ValueError(
            "Cannot handle zeros or negative values in `table`. "
            "Use pseudo counts or ``multiplicative_replacement``.")

    if not 0 < alpha < 1:
        raise ValueError("`alpha`=%f is not within 0 and 1." % alpha)

    if not 0 < tau < 1:
        raise ValueError("`tau`=%f is not within 0 and 1." % tau)

    if not 0 < theta < 1:
        raise ValueError("`theta`=%f is not within 0 and 1." % theta)

    if multiple_comparisons_correction is not None:
        if multiple_comparisons_correction != "holm-bonferroni":
            raise ValueError("%r is not an available option for "
                             "`multiple_comparisons_correction`." %
                             multiple_comparisons_correction)

    if (grouping.isnull()).any():
        raise ValueError("Cannot handle missing values in `grouping`.")

    if (table.isnull()).any().any():
        raise ValueError("Cannot handle missing values in `table`.")

    groups, _grouping = onp.unique(grouping, return_inverse=True)
    grouping = pd.Series(_grouping, index=grouping.index)
    num_groups = len(groups)

    if num_groups == len(grouping):
        raise ValueError(
            "All values in `grouping` are unique. This method cannot "
            "operate on a grouping vector with only unique values (e.g., "
            "there are no 'within' variance because each group of samples "
            "contains only a single sample).")

    if num_groups == 1:
        raise ValueError(
            "All values the `grouping` are the same. This method cannot "
            "operate on a grouping vector with only a single group of samples"
            "(e.g., there are no 'between' variance because there is only a "
            "single group).")

    if significance_test is None:
        significance_test = scipy.stats.f_oneway

    table_index_len = len(table.index)
    grouping_index_len = len(grouping.index)
    mat, cats = table.align(grouping, axis=0, join="inner")
    if len(mat) != table_index_len or len(cats) != grouping_index_len:
        raise ValueError("`table` index and `grouping` "
                         "index must be consistent.")

    n_feat = mat.shape[1]

    _logratio_mat = _log_compare(mat.values, cats.values, significance_test)
    logratio_mat = _logratio_mat + _logratio_mat.T

    # Multiple comparisons
    if multiple_comparisons_correction == "holm-bonferroni":
        logratio_mat = np.apply_along_axis(_holm_bonferroni, 1, logratio_mat)
    np.fill_diagonal(logratio_mat, 1)
    W = (logratio_mat < alpha).sum(axis=1)
    c_start = W.max() / n_feat
    if c_start < theta:
        reject = np.zeros_like(W, dtype=bool)
    else:
        # Select appropriate cutoff
        cutoff = c_start - np.linspace(0.05, 0.25, 5)
        prop_cut = np.array([(W > n_feat * cut).mean() for cut in cutoff])
        dels = np.abs(prop_cut - np.roll(prop_cut, -1))
        dels[-1] = 0

        if (dels[0] < tau) and (dels[1] < tau) and (dels[2] < tau):
            nu = cutoff[1]
        elif (dels[0] >= tau) and (dels[1] < tau) and (dels[2] < tau):
            nu = cutoff[2]
        elif (dels[1] >= tau) and (dels[2] < tau) and (dels[3] < tau):
            nu = cutoff[3]
        else:
            nu = cutoff[4]
        reject = W >= nu * n_feat
    labs = mat.columns
    return pd.DataFrame({
        "W": pd.Series(W, index=labs),
        "reject": pd.Series(reject, index=labs)
    })
Example #11
    def neighbor_list_fn(position: Array,
                         neighbors: Optional[NeighborList] = None,
                         extra_capacity: int = 0,
                         **kwargs) -> NeighborList:
        nbrs = neighbors

        def neighbor_fn(position_and_overflow, max_occupancy=None):
            position, overflow = position_and_overflow
            N = position.shape[0]

            if use_cell_list:
                if neighbors is None:
                    cl = cl_fn.allocate(position,
                                        extra_capacity=extra_capacity)
                else:
                    cl = cl_fn.update(position, neighbors.cell_list_capacity)
                overflow = overflow | cl.did_buffer_overflow
                idx = cell_list_candidate_fn(cl, position)
                cl_capacity = cl.cell_capacity
            else:
                cl_capacity = None
                idx = candidate_fn(position)

            if mask_self:
                idx = mask_self_fn(idx)
            if custom_mask_function is not None:
                idx = custom_mask_function(idx)

            if is_sparse(format):
                idx, occupancy = prune_neighbor_list_sparse(
                    position, idx, **kwargs)
            else:
                idx, occupancy = prune_neighbor_list_dense(
                    position, idx, **kwargs)

            if max_occupancy is None:
                _extra_capacity = (extra_capacity if not is_sparse(format) else
                                   N * extra_capacity)
                max_occupancy = int(occupancy * capacity_multiplier +
                                    _extra_capacity)
                if max_occupancy > position.shape[0] and not is_sparse(format):
                    max_occupancy = position.shape[0]
                if max_occupancy > occupancy:
                    padding = max_occupancy - occupancy
                    pad = N * jnp.ones(
                        (idx.shape[0], padding), dtype=idx.dtype)
                    idx = jnp.concatenate([idx, pad], axis=1)
            idx = idx[:, :max_occupancy]
            update_fn = (neighbor_list_fn
                         if neighbors is None else neighbors.update_fn)
            return NeighborList(idx, position,
                                overflow | (occupancy >= max_occupancy),
                                cl_capacity, max_occupancy, format, update_fn)  # pytype: disable=wrong-arg-count

        if nbrs is None:
            return neighbor_fn((position, False))

        neighbor_fn = partial(neighbor_fn, max_occupancy=nbrs.max_occupancy)

        d = partial(metric_sq, **kwargs)
        d = vmap(d)
        return lax.cond(
            jnp.any(d(position, nbrs.reference_position) > threshold_sq),
            (position, nbrs.did_buffer_overflow), neighbor_fn, nbrs,
            lambda x: x)
Example #12
    def _beam_search(
        self,
        input_ids: None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        """
        This beam search function is heavily inspired by Flax's official example:
        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
        """
        def flatten_beam_dim(tensor):
            """Flattens the first two dimensions of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((tensor.shape[0] * tensor.shape[1], ) +
                                  tensor.shape[2:])

        def unflatten_beam_dim(tensor, batch_size, num_beams):
            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
            """
            Gathers the beam slices indexed by beam_indices into new beam array.
            """
            batch_indices = jnp.reshape(
                jnp.arange(batch_size * new_num_beams) // new_num_beams,
                (batch_size, new_num_beams))

            def gather_fn(tensor):
                # ignore scalars (e.g. cache index)
                if tensor.ndim == 0:
                    return tensor
                else:
                    return tensor[batch_indices, beam_indices]

            return jax.tree_map(gather_fn, nested)

        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        batch_size, num_beams, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch,beam-item holding current token in loop.
        sequences = jnp.full((batch_size, num_beams, max_length),
                             pad_token_id,
                             dtype=jnp.int32)
        running_sequences = jnp.full((batch_size, num_beams, max_length),
                                     pad_token_id,
                                     dtype=jnp.int32)
        running_sequences = lax.dynamic_update_slice(sequences, input_ids,
                                                     (0, 0, 0))

        # per batch,beam-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

        # per batch,beam-item score, logprobs
        running_scores = jnp.tile(
            jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)),
            [batch_size, 1])
        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # flatten beam dim
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"][
                "last_hidden_state"] = flatten_beam_dim(
                    model_kwargs["encoder_outputs"]["last_hidden_state"])
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = flatten_beam_dim(
                model_kwargs["attention_mask"])

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(
            flatten_beam_dim(input_ids), max_length, **model_kwargs)

        # initialize state
        state = BeamSearchState(
            cur_len=cur_len,
            running_sequences=running_sequences,
            running_scores=running_scores,
            sequences=sequences,
            scores=scores,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""

            # 1. is less than max length?
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve?
            best_running_score = state.running_scores[:, -1:] / (
                max_length**length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished,
                jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7))
            improvement_still_possible = jnp.all(
                worst_finished_score < best_running_score)

            # 3. is there still a beam that has not finished?
            still_open_beam = ~(jnp.all(state.is_sent_finished)
                                & early_stopping)

            return not_max_length_yet & still_open_beam & improvement_still_possible

        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model.  Flatten the beam dimension into batch
            # dimension for feeding into the model.
            # unflatten beam dimension
            # Unflatten beam dimension in attention cache arrays
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                ))
            model_outputs = model(input_token,
                                  params=params,
                                  **state.model_kwargs)

            logits = unflatten_beam_dim(model_outputs.logits[:, -1],
                                        batch_size, num_beams)
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams
                                                  ),
                model_outputs.past_key_values)

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(flatten_beam_dim(running_sequences),
                                         flatten_beam_dim(log_probs),
                                         state.cur_len)
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores,
                                                    axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*k candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            # Gather the top 2*K scores from _all_ beams.
            # Gather 2*k top beams.
            # Recover the beam index by floor division.
            # Recover token id by modulo division and expand Id array for broadcasting.
            # Update sequences for the 2*K top-k new sequences.
            beams_to_keep = 2 * num_beams
            topk_log_probs, topk_indices = lax.top_k(log_probs,
                                                     k=beams_to_keep)
            topk_beam_indices = topk_indices // vocab_size
            topk_running_sequences = gather_beams(state.running_sequences,
                                                  topk_beam_indices,
                                                  batch_size, beams_to_keep)
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences,
                                                      topk_ids,
                                                      (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Update current sequences:
            # Did any of these sequences reach an end marker?
            # To prevent these just finished sequences from being added to the current sequences
            # set of active beam search sequences, set their log probs to a very large
            # negative value.
            did_topk_just_finished = topk_sequences[:, :, state.
                                                    cur_len] == eos_token_id
            running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(
                -1.0e7)
            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs,
                                                   k=num_beams)[1],
                                         axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices,
                batch_size, num_beams)

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = (jnp.broadcast_to(
                state.is_sent_finished.all(axis=-1, keepdims=True),
                did_topk_just_finished.shape)
                                       & early_stopping)
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate(
                [state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs],
                                            axis=1)
            merged_is_sent_finished = jnp.concatenate(
                [state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores,
                                                     k=num_beams)[1],
                                           axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished],
                topk_merged_indices, batch_size, num_beams)

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices,
                                                next_topk_indices, batch_size,
                                                num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size,
                                      num_beams)
            model_outputs["past_key_values"] = jax.tree_map(
                lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[-1] > 1:
            state = partial(beam_search_body_fn,
                            input_ids_length=input_ids.shape[-1])(state)

        if not trace:
            state = self._run_loop_in_debug(beam_search_cond_fn,
                                            beam_search_body_fn, state)
        else:
            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn,
                                   state)

        # Account for the edge-case where there are no finished sequences for a
        # particular batch item. If so, return running sequences for that batch item.
        none_finished = jnp.any(state.is_sent_finished, axis=1)
        sequences = jnp.where(none_finished[:, None, None], state.sequences,
                              state.running_sequences)
        scores = jnp.where(none_finished[:, None], state.scores,
                           state.running_scores)

        # take best beam for each batch
        sequences = sequences[:, -1]
        scores = scores[:, -1]

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
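The flatten_beam_dim / unflatten_beam_dim helpers above simply fold the beam axis into the batch axis before calling the decoder and restore it afterwards. A minimal standalone sketch of that round trip with toy shapes:

import jax.numpy as jnp

batch_size, num_beams, cur_len = 2, 3, 5
x = jnp.arange(batch_size * num_beams * cur_len).reshape(batch_size, num_beams, cur_len)

flat = x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])          # (6, 5): fed to the model
restored = flat.reshape((batch_size, num_beams) + flat.shape[1:])   # (2, 3, 5): per-beam view again
assert bool((restored == x).all())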
Example #13
    def __call__(self,
                 Y,
                 Sigma,
                 mu0,
                 Gamma0,
                 Omega,
                 *control_params,
                 maxiter=None,
                 tol=1e-5,
                 momentum=0.,
                 omega_diag_range=(0, jnp.inf),
                 sigma_diag_range=(0, jnp.inf),
                 omega_window=3,
                 sigma_window=3,
                 beg=None,
                 quack=None):
        """
        Perform HMM inference on a sequence of data.

        Args:
            Y: [T, N] Data of dimension N sequence of length T
            Sigma: Initial estimate of observational covariance broadcastable to [T, N, N].
            mu0: [K] or [T, K] Estimate of initial timestep mean. If shape [T, K] it is taken to be initial estimate of
                post_mu.
            Gamma0: [K, K] or [T, K, K] Estimate of initial timestep state covariance. If shape [T, K, K] it is
                taken to be initial estimate of post_Gamma.
            Omega: Initial estimate of Levy step covariance broadcastable to [T, K, K].
            *control_params: List of control params with first dim size T.
            maxiter: int or None. Number of iterations to run, or until convergence if None. Note that convergence may
                not occur.
            tol: Convergence tolerance in post_mu broadcastable to [K]. Norm_inf is used to check.
            momentum: Momentum broadcastable to [4]. Momentum for [mu0, Gamma0, Omega, Sigma]
            omega_diag_range: 2-tuple of lower and upper bounds for omega, with each element broadcastable to [K].
            sigma_diag_range: 2-tuple of lower and upper bounds for sigma, with each element broadcastable to [N].
            omega_window: int or None. Moving average window size for Omega. If None, use the whole data sequence.
            sigma_window: int or None. Moving average window size for Sigma. If None, use the whole data sequence.
            beg: int, prepad `beg` elements to avoid large initial uncertainties in the sequence.
            quack: int, postpad `quack` elements to avoid large final uncertainties in the sequence.

        Returns:
            NonLinearDynamicsSmootherResults containing the results of the run.
                converged: bool  # Whether termination due to convergence
                niter: int  # number of iterations used
                post_mu: jnp.ndarray  # posterior mean
                post_Gamma: jnp.ndarray  # posterior covariance
                Sigma: jnp.ndarray  # estimated obs. covariance
                Omega: jnp.ndarray  # estimated transition covariance
                mu0: jnp.ndarray  # estimated initial mu
                Gamma0: jnp.ndarray  # estimated initial covariance

        """
        NonLinearDynamicsSmootherState = namedtuple(
            'NonLinearDynamicsSmootherState', [
                'done', 'i', 'Sigma_i1', 'post_mu', 'post_Gamma', 'Omega_i1',
                'mu0_i1', 'Gamma0_i1'
            ])
        if maxiter is None:
            maxiter = jnp.inf
        if maxiter <= 0:
            raise ValueError("maxiter {} should be > 0".format(maxiter))
        tol = jnp.array(tol)
        if jnp.any(tol < 0):
            raise ValueError("tol {} should be > 0".format(tol))
        omega_diag_range = (jnp.array(omega_diag_range[0]),
                            jnp.array(omega_diag_range[1]))
        sigma_diag_range = (jnp.array(sigma_diag_range[0]),
                            jnp.array(sigma_diag_range[1]))

        momentum = jnp.array(momentum)
        momentum = jnp.broadcast_to(momentum, (4, ))

        if jnp.any(momentum > 1.) | jnp.any(momentum < 0.):
            raise ValueError("Momentum {} must be in (0,1)".format(momentum))

        Sigma = jnp.broadcast_to(Sigma, Y.shape[0:1] + Sigma.shape[-2:])
        Omega = jnp.broadcast_to(Omega, Y.shape[0:1] + Omega.shape[-2:])
        mu0 = jnp.broadcast_to(mu0, Y.shape[0:1] + Omega.shape[-1:])
        Gamma0 = jnp.broadcast_to(Gamma0, Y.shape[0:1] + Omega.shape[-2:])

        if beg is None:
            beg = 0
        if quack is None:
            quack = 0
        beg = jnp.array(beg)
        quack = jnp.array(quack)

        def _pad(v, beg, quack):
            dims = len(v.shape)
            pad_width = tuple([(beg, quack)] + [(0, 0)] * (dims - 1))
            return jnp.pad(v, pad_width, mode='reflect')

        # technique to avoid large uncertainties at the start and end
        Y = _pad(Y, beg, quack)
        Sigma = _pad(Sigma, beg, quack)
        Omega = _pad(Omega, beg, quack)
        mu0 = _pad(mu0, beg, quack)
        Gamma0 = _pad(Gamma0, beg, quack)
        control_params = [_pad(v, beg, quack) for v in control_params]

        state = NonLinearDynamicsSmootherState(done=False,
                                               i=0,
                                               post_mu=mu0,
                                               post_Gamma=Gamma0,
                                               Sigma_i1=Sigma,
                                               Omega_i1=Omega,
                                               mu0_i1=mu0[0, ...],
                                               Gamma0_i1=Gamma0[0, ...])

        def body(state):
            prior_Gamma, post_mu_f, post_Gamma_f = self.forward_filter(
                Y, state.Sigma_i1, state.mu0_i1, state.Gamma0_i1,
                state.Omega_i1, *control_params)

            post_mu_b, post_Gamma_b, inter_Gamma = self.backward_filter(
                prior_Gamma, post_mu_f, post_Gamma_f)

            mu0_i, Gamma0_i, Sigma_i, Omega_i = self.parameter_estimation(
                post_mu_b,
                post_Gamma_b,
                inter_Gamma,
                Y,
                *control_params,
                omega_window=omega_window,
                sigma_window=sigma_window)

            mu0_i = state.mu0_i1 * momentum[0] + (1. - momentum[0]) * mu0_i
            Gamma0_i = state.Gamma0_i1 * momentum[1] + (1. -
                                                        momentum[1]) * Gamma0_i
            Omega_i = self.clip_covariance_diag(
                state.Omega_i1 * momentum[2] + (1. - momentum[2]) * Omega_i,
                omega_diag_range[0], omega_diag_range[1])
            Sigma_i = self.clip_covariance_diag(
                state.Sigma_i1 * momentum[3] + (1. - momentum[3]) * Sigma_i,
                sigma_diag_range[0], sigma_diag_range[1])

            max_norm = jnp.max(jnp.linalg.norm(state.post_mu - post_mu_b,
                                               axis=-1),
                               axis=0)
            converged = jnp.all(max_norm < tol) & (state.i > 0)

            state = state._replace(done=converged,
                                   i=state.i + 1,
                                   post_mu=post_mu_b,
                                   post_Gamma=post_Gamma_b,
                                   Sigma_i1=Sigma_i,
                                   Omega_i1=Omega_i,
                                   mu0_i1=mu0_i,
                                   Gamma0_i1=Gamma0_i)

            return state

        state = while_loop(lambda state: (~state.done) & (state.i < maxiter),
                           body, state)

        def _trim(v, beg, quack):
            T = v.shape[0]
            return v[beg:T - quack, ...]

        return NonLinearDynamicsSmootherResults(
            converged=state.done,
            niter=state.i,
            post_mu=_trim(state.post_mu, beg, quack),
            post_Gamma=_trim(state.post_Gamma, beg, quack),
            Sigma=_trim(state.Sigma_i1, beg, quack),
            Omega=_trim(state.Omega_i1, beg, quack),
            mu0=state.post_mu[beg, ...],  #actual estimate of beginning is this
            Gamma0=state.post_Gamma[beg,
                                    ...]  #actual estimate of beginning is this
        )
Example #14
def _newton_update(weights_0,
                   X,
                   XX_T,
                   target,
                   k,
                   method_,
                   maxiter=int(1024),
                   ftol=1e-12,
                   gtol=1e-8,
                   reg_lambda=0.0,
                   reg_mu=None,
                   ref_row=True,
                   initializer=None,
                   reg_format=None):

    L_list = [
        float(
            _objective(weights_0, X, XX_T, target, k, method_, reg_lambda,
                       reg_mu, ref_row, initializer, reg_format))
    ]

    weights = weights_0.copy()

    # TODO move this to the initialization
    if method_ is None:
        weights = jax_np.zeros_like(weights)

    for i in range(0, maxiter):

        gradient = _gradient(weights, X, XX_T, target, k, method_, reg_lambda,
                             reg_mu, ref_row, initializer, reg_format)

        if jax_np.abs(gradient).sum() < gtol:
            break

        # FIXME hessian is occasionally NaN
        hessian = _hessian(weights, X, XX_T, target, k, method_, reg_lambda,
                           reg_mu, ref_row, initializer, reg_format)

        if method_ == 'FixDiag':
            updates = gradient / hessian
        else:
            try:
                inverse = scipy.linalg.pinv2(hessian)
                updates = jax_np.matmul(inverse, gradient)
            except (np.linalg.LinAlgError, ValueError) as err:
                logging.error(err)
                updates = gradient

        for step_size in jax_np.hstack(
            (jax_np.linspace(1, 0.1, 10), jax_np.logspace(-2, -32, 31))):

            tmp_w = weights - (updates * step_size).ravel()

            if jax_np.any(jax_np.isnan(tmp_w)):
                logging.debug("{}: There are NaNs in tmp_w".format(method_))

            L = _objective(tmp_w, X, XX_T, target, k, method_, reg_lambda,
                           reg_mu, ref_row, initializer, reg_format)

            if (L - L_list[-1]) < 0:
                break

        L_list.append(float(L))

        logging.debug(
            "{}: after {} iterations log-loss = {:.7e}, sum_grad = {:.7e}".
            format(method_, i, L,
                   jax_np.abs(gradient).sum()))

        if jax_np.isnan(L):
            logging.error("{}: log-loss is NaN".format(method_))
            break

        if i >= 5:
            if (float(np.min(np.diff(L_list[-5:]))) > -ftol) & \
               (float(np.sum(np.diff(L_list[-5:])) > 0) == 0):
                weights = tmp_w.copy()
                logging.debug(
                    '{}: Terminate as there is not enough changes on loss.'.
                    format(method_))
                break

        if (L_list[-1] - L_list[-2]) > 0:
            logging.debug('{}: Terminate as the loss increased {}.'.format(
                method_, jax_np.diff(L_list[-2:])))
            break
        else:
            weights = tmp_w.copy()

    L = _objective(weights, X, XX_T, target, k, method_, reg_lambda, reg_mu,
                   ref_row, initializer, reg_format)

    logging.debug(
        "{}: after {} iterations final log-loss = {:.7e}, sum_grad = {:.7e}".
        format(method_, i, L,
               jax_np.abs(gradient).sum()))

    return weights
Example #15
def multi_ellipsoid_sampler(key, log_L_constraint, live_points_U,
                            loglikelihood_from_constrained, prior_transform,
                            sampler_state, log_X, i_min):
    """
    Does iterative multi-nest sampling with a few extra features to improve over the original algorithm.

    References:

    [1] MULTINEST: an efficient and robust Bayesian inference tool for cosmology and particle physics,
        F. Feroz et al. 2008. https://arxiv.org/pdf/0809.3437.pdf

    Args:
        key:
        log_L_constraint:
        live_points_U:
        loglikelihood_from_constrained:

    Returns:

    """
    N, D = live_points_U.shape

    # subtract the i_min
    k_from = sampler_state.cluster_id[i_min]
    num_k = dynamic_update_slice(sampler_state.num_k,
                                 sampler_state.num_k[k_from, None] - 1,
                                 k_from[None])
    sampler_state = sampler_state._replace(num_k=num_k)
    ###
    # evolve ellipsoids or potentially recalculate ellipsoids
    sampler_state = evolve_sampler_state(sampler_state, live_points_U)
    scale = 1.1**(1. / D)

    def body(state):
        (key, i, _, u_test, x_test, log_L_test) = state
        key, sample_key = random.split(key, 2)
        k, u_test = sample_multi_ellipsoid(sample_key,
                                           sampler_state.mu,
                                           sampler_state.radii * scale,
                                           sampler_state.rotation,
                                           unit_cube_constraint=True)

        x_test = prior_transform(u_test)
        log_L_test = loglikelihood_from_constrained(**x_test)
        return (key, i + 1, k, u_test, x_test, log_L_test)

    (key, num_likelihood_evaluations, ellipsoid_select, u_new,
     x_new, log_L_new) = while_loop(
         lambda state: state[-1] <= log_L_constraint, body,
         (key, 0, 0, live_points_U[0, :], prior_transform(
             live_points_U[0, :]), log_L_constraint))

    cluster_id = dynamic_update_slice(sampler_state.cluster_id,
                                      ellipsoid_select[None], i_min[None])

    num_k = dynamic_update_slice(
        sampler_state.num_k, sampler_state.num_k[ellipsoid_select, None] + 1,
        ellipsoid_select[None])
    sampler_state = sampler_state._replace(cluster_id=cluster_id, num_k=num_k)

    log_volumes = vmap(lambda radii: log_ellipsoid_volume(radii))(
        sampler_state.radii)
    log_F = logsumexp(log_volumes) - log_X

    # V(E_k) > 2 V(S_k)
    # |S_k| < D+1
    # jnp.any(log_volumes > jnp.log(sampler_state.num_k) - jnp.log(N) + log_X + jnp.log(2.))
    do_recalculate =  jnp.any(sampler_state.num_k == D) \
                    | (num_likelihood_evaluations > 3. * sampler_state.num_fev_ma) \
                    | (log_F < 0.)
    tau = 1. / N
    sampler_state = sampler_state._replace(
        num_fev_ma=sampler_state.num_fev_ma * (1. - tau) +
        tau * num_likelihood_evaluations)

    print(
        'do_recalculate', do_recalculate, 'num fev',
        num_likelihood_evaluations, '/', sampler_state.num_fev_ma, 'V(E)/V(S)',
        jnp.exp(log_F), 'V(E_k)', jnp.exp(log_volumes), '2V(S_k)',
        jnp.exp(
            jnp.log(sampler_state.num_k) - jnp.log(N) + log_X + jnp.log(2.)))

    key, recalc_key = random.split(key, 2)

    sampler_state = cond(do_recalculate,
                         lambda args: recalculate_sampler_state(*args),
                         lambda _: sampler_state,
                         (recalc_key, live_points_U, sampler_state, log_X))

    MultiEllipsoidResults = namedtuple('MultiEllipsoidResults', [
        'key', 'num_likelihood_evaluations', 'u_new', 'x_new', 'log_L_new',
        'sampler_state'
    ])
    return MultiEllipsoidResults(key, num_likelihood_evaluations, u_new, x_new,
                                 log_L_new, sampler_state)
Example #16
def nerf(key, example_batch, args):
    """Neural Randiance Field.

  Args:
    key: jnp.ndarray. Random number generator.
    example_batch: dict, an example of a batch of data.
    args: FLAGS class. Hyperparameters of nerf.

  Returns:
    model: nn.Model. Nerf model with parameters.
    state: flax.Module.state. Nerf model state for stateful parameters.
  """
    net_activation = getattr(nn, str(args.net_activation))
    rgb_activation = getattr(nn, str(args.rgb_activation))
    sigma_activation = getattr(nn, str(args.sigma_activation))

    # Assert that rgb_activation always produces outputs in [0, 1], and
    # sigma_activation always produces non-negative outputs.
    x = jnp.exp(jnp.linspace(-90, 90, 1024))
    x = jnp.concatenate([-x[::-1], x], 0)

    rgb = rgb_activation(x)
    if jnp.any(rgb < 0) or jnp.any(rgb > 1):
        raise NotImplementedError(
            "Choice of rgb_activation `{}` produces colors outside of [0, 1]".
            format(args.rgb_activation))

    sigma = sigma_activation(x)
    if jnp.any(sigma < 0):
        raise NotImplementedError(
            "Choice of sigma_activation `{}` produces negative densities".
            format(args.sigma_activation))

    model_fn = NerfModel.partial(deg_point=args.deg_point,
                                 deg_view=args.deg_view,
                                 num_coarse_samples=args.num_coarse_samples,
                                 num_fine_samples=args.num_fine_samples,
                                 use_viewdirs=args.use_viewdirs,
                                 near=args.near,
                                 far=args.far,
                                 noise_std=args.noise_std,
                                 randomized=args.randomized,
                                 white_bkgd=args.white_bkgd,
                                 net_depth=args.net_depth,
                                 net_width=args.net_width,
                                 net_depth_condition=args.net_depth_condition,
                                 net_width_condition=args.net_width_condition,
                                 skip_layer=args.skip_layer,
                                 num_rgb_channels=args.num_rgb_channels,
                                 num_sigma_channels=args.num_sigma_channels,
                                 lindisp=args.lindisp,
                                 net_activation=net_activation,
                                 rgb_activation=rgb_activation,
                                 sigma_activation=sigma_activation,
                                 legacy_posenc_order=args.legacy_posenc_order)
    with nn.stateful() as init_state:
        rays = example_batch["rays"]
        key1, key2, key3 = random.split(key, num=3)
        # TODO(barron): Determine why the rays have an unused first dimension.
        _, init_params = model_fn.init(key1, key2, key3, rays.origins[0],
                                       rays.directions[0], rays.viewdirs[0])

        model = nn.Model(model_fn, init_params)
    return model, init_state
Example #17
 def kalman_filter(self,
                   y,
                   dt,
                   params,
                   store=False,
                   mask=None,
                   site_params=None,
                   r=None):
     """
     Run the Kalman filter to get p(fₙ|y₁,...,yₙ).
     The Kalman update step involves some control flow to work out whether we are
         i) initialising the sites
         ii) using supplied sites
         iii) performing a Gaussian update with fixed parameters (e.g. in posterior sampling or ELBO calc.)
     If store is True then we compute and return the intermediate filtering distributions
     p(fₙ|y₁,...,yₙ) and sites sₙ(fₙ), otherwise we do not store the intermediates and simply
     return the energy / negative log-marginal likelihood, -log p(y).
     :param y: observed data [N, obs_dim]
     :param dt: step sizes Δtₙ = tₙ - tₙ₋₁ [N, 1]
     :param params: the model parameters, i.e the hyperparameters of the prior & likelihood
     :param store: flag to notify whether to store the intermediates
     :param mask: boolean array signifying which elements of y are observed [N, obs_dim]
     :param site_params: the Gaussian approximate likelihoods [2, N, obs_dim]
     :param r: spatial input locations
     :return:
         if store is True:
             neg_log_marg_lik: the filter energy, i.e. negative log-marginal likelihood -log p(y),
                               used for hyperparameter optimisation (learning) [scalar]
             filtered_mean: intermediate filtering means [N, state_dim, 1]
             filtered_cov: intermediate filtering covariances [N, state_dim, state_dim]
             site_mean: mean of the approximate likelihood sₙ(fₙ) [N, obs_dim]
             site_cov: variance of the approximate likelihood sₙ(fₙ) [N, obs_dim]
         otherwise:
             neg_log_marg_lik: the filter energy, i.e. negative log-marginal likelihood -log p(y),
                               used for hyperparameter optimisation (learning) [scalar]
     """
     theta_prior, theta_lik = softplus_list(params[0]), softplus(params[1])
     self.update_model(
         theta_prior
     )  # all model components that are not static must be computed inside the function
     N = dt.shape[0]
     with loops.Scope() as s:
         s.neg_log_marg_lik = 0.0  # negative log-marginal likelihood
         s.m, s.P = self.minf, self.Pinf
         if store:
             s.filtered_mean = np.zeros([N, self.state_dim, 1])
             s.filtered_cov = np.zeros([N, self.state_dim, self.state_dim])
             s.site_mean = np.zeros([N, self.func_dim, 1])
             s.site_cov = np.zeros([N, self.func_dim, self.func_dim])
         for n in s.range(N):
             y_n = y[n][..., np.newaxis]
             # -- KALMAN PREDICT --
             #  mₙ⁻ = Aₙ mₙ₋₁
             #  Pₙ⁻ = Aₙ Pₙ₋₁ Aₙ' + Qₙ, where Qₙ = Pinf - Aₙ Pinf Aₙ'
             A = self.prior.state_transition(dt[n], theta_prior)
             m_ = A @ s.m
             P_ = A @ (s.P - self.Pinf) @ A.T + self.Pinf
             # --- KALMAN UPDATE ---
             # Given previous predicted mean mₙ⁻ and cov Pₙ⁻, incorporate yₙ to get filtered mean mₙ &
             # cov Pₙ and compute the marginal likelihood p(yₙ|y₁,...,yₙ₋₁)
             H = self.prior.measurement_model(r[n], theta_prior)
             predict_mean = H @ m_
             predict_cov = H @ P_ @ H.T
             if mask is not None:  # note: this is a bit redundant but may come in handy in multi-output problems
                 y_n = np.where(mask[n][..., np.newaxis],
                                predict_mean[:y_n.shape[0]],
                                y_n)  # fill in masked obs with expectation
             log_lik_n, site_mean, site_cov = self.sites.update(
                 self.likelihood, y_n, predict_mean, predict_cov, theta_lik,
                 None)
             if site_params is not None:  # use supplied site parameters to perform the update
                 site_mean, site_cov = site_params[0][n], site_params[1][n]
             # modified Kalman update (see Nickisch et al. ICML 2018 or Wilkinson et al. ICML 2019):
             S = predict_cov + site_cov
             HP = H @ P_
             K = solve(S, HP).T  # PH'(S^-1)
             s.m = m_ + K @ (site_mean - predict_mean)
             s.P = P_ - K @ HP
             if mask is not None:  # note: this is a bit redundant but may come in handy in multi-output problems
                 s.m = np.where(np.any(mask[n]), m_, s.m)
                 s.P = np.where(np.any(mask[n]), P_, s.P)
                 log_lik_n = np.where(mask[n][..., 0],
                                      np.zeros_like(log_lik_n), log_lik_n)
             s.neg_log_marg_lik -= np.sum(log_lik_n)
             if store:
                 s.filtered_mean = index_add(s.filtered_mean, index[n, ...],
                                             s.m)
                 s.filtered_cov = index_add(s.filtered_cov, index[n, ...],
                                            s.P)
                 s.site_mean = index_add(s.site_mean, index[n, ...],
                                         site_mean)
                 s.site_cov = index_add(s.site_cov, index[n, ...], site_cov)
     if store:
         return s.neg_log_marg_lik, (s.filtered_mean, s.filtered_cov,
                                     (s.site_mean, s.site_cov))
     return s.neg_log_marg_lik
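
# For reference, here is a stand-alone sketch of a single predict/update step
# with the same structure as the loop body above (A, Pinf, H, and the site
# terms are plain arrays here; this illustrates the equations only and is not
# the class method itself).
import jax.numpy as jnp


def kalman_step_sketch(m, P, A, Pinf, H, site_mean, site_cov):
    # Predict:  m⁻ = A m,   P⁻ = A (P - Pinf) A' + Pinf
    m_ = A @ m
    P_ = A @ (P - Pinf) @ A.T + Pinf
    # Update with a Gaussian site s(f) = N(site_mean, site_cov):
    predict_mean = H @ m_
    S = H @ P_ @ H.T + site_cov
    HP = H @ P_
    K = jnp.linalg.solve(S, HP).T  # K = P⁻ H' S⁻¹
    m_new = m_ + K @ (site_mean - predict_mean)
    P_new = P_ - K @ HP
    return m_new, P_new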
Example #18
def any(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True):
    return jnp.any(a, axis=axis, keepdims=not squeeze)
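
# Quick usage sketch of the wrapper above: with the default squeeze=True the
# reduced axis is dropped, with squeeze=False it is kept as a singleton axis.
# The wrapper is restated here (as `any_`) so the snippet runs on its own.
import jax.numpy as jnp


def any_(a, axis=None, squeeze=True):
    return jnp.any(a, axis=axis, keepdims=not squeeze)


mask = jnp.array([[True, False],
                  [False, False]])
print(any_(mask, axis=1).shape)                 # (2,)
print(any_(mask, axis=1, squeeze=False).shape)  # (2, 1)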
Example #19
 def any(self, boolean_tensor, axis=None, keepdims=False):
     return jnp.any(boolean_tensor, axis=axis, keepdims=keepdims)
Example #20
def close_or_nan(delta, scale, rtol, atol):
    is_close = delta < (rtol * scale + atol)
    is_nan = np.any(np.isnan(delta))
    return np.logical_or(is_close, is_nan)
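
# Brief usage sketch of close_or_nan (assuming the definition above is in
# scope and that `np` is jax.numpy): the elementwise check is
# delta < rtol * scale + atol, and any NaN in delta makes every element pass,
# because np.any reduces the NaN test to a single scalar flag.
import jax.numpy as np

print(close_or_nan(np.array([1e-4, 5e-3]), np.array([1.0, 1.0]),
                   rtol=1e-3, atol=1e-3))  # [ True False]
print(close_or_nan(np.array([np.nan, 5e-3]), np.array([1.0, 1.0]),
                   rtol=1e-3, atol=1e-3))  # [ True  True]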
Example #21
    def sample_kernel(sa_state, model_args=(), model_kwargs=None):
        pe_fn = potential_fn
        if potential_fn_gen:
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
        zs, pes, loc, scale = sa_state.adapt_state
        # we recompute loc/scale after each iteration to avoid precision loss
        # XXX: consider to expose a setting to do this job periodically
        # to save some computations
        loc = jnp.mean(zs, 0)
        if scale.ndim == 2:
            cov = jnp.cov(zs, rowvar=False, bias=True)
            if cov.shape == ():  # JAX returns scalar for 1D input
                cov = cov.reshape((1, 1))
            cholesky = jnp.linalg.cholesky(cov)
            scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
        else:
            scale = jnp.std(zs, 0)

        rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(
            sa_state.rng_key, 4)
        _, unravel_fn = ravel_pytree(sa_state.z)

        z = loc + _sample_proposal(scale, rng_key_z)
        pe = pe_fn(unravel_fn(z))
        pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
        diverging = (pe - sa_state.potential_energy) > max_delta_energy

        # NB: all terms having the pattern *s will have shape N x ...
        # and all terms having the pattern *s_ will have shape (N + 1) x ...
        locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
        zs_ = jnp.concatenate([zs, z[None, :]])
        pes_ = jnp.concatenate([pes, pe[None]])
        locs_ = jnp.concatenate([locs, loc[None, :]])
        scales_ = jnp.concatenate([scales, scale[None, ...]])
        if scale.ndim == 2:  # dense_mass
            log_weights_ = dist.MultivariateNormal(
                locs_, scale_tril=scales_).log_prob(zs_) + pes_
        else:
            log_weights_ = dist.Normal(locs_,
                                       scales_).log_prob(zs_).sum(-1) + pes_
        # mask invalid values (nan, +inf) by -inf
        log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_,
                                 -jnp.inf)
        # get rejecting index
        j = random.categorical(rng_key_reject, log_weights_)
        zs = _numpy_delete(zs_, j)
        pes = _numpy_delete(pes_, j)
        loc = locs_[j]
        scale = scales_[j]
        adapt_state = SAAdaptState(zs, pes, loc, scale)

        # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
        itr = sa_state.i + 1
        n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = sa_state.mean_accept_prob + (
            accept_prob - sa_state.mean_accept_prob) / n

        # XXX: we make a modification of SA sampler in [1]
        # in [1], each MCMC state contains N points `zs`
        # here we do resampling to pick randomly a point from those N points
        k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
        z = unravel_fn(zs[k])
        pe = pes[k]
        return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging,
                       adapt_state, rng_key)
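
# `_numpy_delete` is not shown in this excerpt. Purely as an assumption about
# what it does (based on how it is used above), one jit-friendly way to drop
# row j from an array without dynamic shapes is sketched below.
import jax.numpy as jnp


def delete_row_sketch(x, j):
    # Build static-shape indices 0..n-2 and skip over index j.
    idx = jnp.arange(x.shape[0] - 1)
    idx = jnp.where(idx < j, idx, idx + 1)
    return x[idx]


print(delete_row_sketch(jnp.arange(5.0), 2))  # [0. 1. 3. 4.]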
Example #22
def any(a, axis=None, keepdims=None, where=None):
  a = _remove_jaxarray(a)
  r = jnp.any(a=a, axis=axis, keepdims=keepdims, where=where)
  return r if axis is None else JaxArray(r)
Example #23
 def get_receptive_field_1d(pos):
   g = grad_fn(inputs, pos)[0, :, :]
   return jnp.any((jnp.abs(g) > 1e-5).astype(jnp.uint32), axis=-1)
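
# The closure above relies on a `grad_fn` and `inputs` defined elsewhere. A
# hedged, self-contained sketch of how such a receptive-field probe can be
# built is shown below; the model, shapes, and helper names are illustrative
# assumptions, not the original test's code.
import jax
import jax.numpy as jnp


def make_receptive_field_probe(apply_fn):
    def grad_fn(inputs, pos):
        # Gradient of the summed output at position `pos` w.r.t. the inputs.
        out_at_pos = lambda x: jnp.sum(apply_fn(x)[:, pos, :])
        return jax.grad(out_at_pos)(inputs)

    def get_receptive_field_1d(inputs, pos):
        g = grad_fn(inputs, pos)[0, :, :]
        return jnp.any((jnp.abs(g) > 1e-5).astype(jnp.uint32), axis=-1)

    return get_receptive_field_1d


def toy_model(x):
    # Each output position depends on the current and previous input position.
    return x + jnp.pad(x, ((0, 0), (1, 0), (0, 0)))[:, :-1, :]


probe = make_receptive_field_probe(toy_model)
print(probe(jnp.ones((1, 8, 2)), pos=4))  # True only at positions 3 and 4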
Example #24
def tree_any(tree):
    return np.any(tree_flatten(tree_map(np.any, tree))[0])
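
# Usage sketch, restated with explicit imports (the original relies on `np`
# being jax.numpy and on jax.tree_util's tree_flatten/tree_map):
import jax.numpy as np
from jax.tree_util import tree_flatten, tree_map


def tree_any(tree):
    return np.any(tree_flatten(tree_map(np.any, tree))[0])


print(tree_any({'a': np.zeros(3), 'b': np.array([0.0, 2.0])}))  # True
print(tree_any({'a': np.zeros(3), 'b': np.zeros(2)}))           # False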
Example #25
def construct_nerf(key, example_batch, args):
    """Construct a Neural Radiance Field.

  Args:
    key: jnp.ndarray. Random number generator.
    example_batch: dict, an example of a batch of data.
    args: FLAGS class. Hyperparameters of nerf.

  Returns:
    model: nn.Model. Nerf model with parameters.
    state: flax.Module.state. Nerf model state for stateful parameters.
  """
    net_activation = getattr(nn, str(args.net_activation))
    rgb_activation = getattr(nn, str(args.rgb_activation))
    sigma_activation = getattr(nn, str(args.sigma_activation))

    # Assert that rgb_activation always produces outputs in [0, 1], and
    # sigma_activation always produce non-negative outputs.
    x = jnp.exp(jnp.linspace(-90, 90, 1024))
    x = jnp.concatenate([-x[::-1], x], 0)

    rgb = rgb_activation(x)
    if jnp.any(rgb < 0) or jnp.any(rgb > 1):
        raise NotImplementedError(
            "Choice of rgb_activation `{}` produces colors outside of [0, 1]".
            format(args.rgb_activation))

    sigma = sigma_activation(x)
    if jnp.any(sigma < 0):
        raise NotImplementedError(
            "Choice of sigma_activation `{}` produces negative densities".
            format(args.sigma_activation))

    model = NerfModel(min_deg_point=args.min_deg_point,
                      max_deg_point=args.max_deg_point,
                      deg_view=args.deg_view,
                      num_coarse_samples=args.num_coarse_samples,
                      num_fine_samples=args.num_fine_samples,
                      use_viewdirs=args.use_viewdirs,
                      near=args.near,
                      far=args.far,
                      noise_std=args.noise_std,
                      white_bkgd=args.white_bkgd,
                      net_depth=args.net_depth,
                      net_width=args.net_width,
                      net_depth_condition=args.net_depth_condition,
                      net_width_condition=args.net_width_condition,
                      skip_layer=args.skip_layer,
                      num_rgb_channels=args.num_rgb_channels,
                      num_sigma_channels=args.num_sigma_channels,
                      lindisp=args.lindisp,
                      net_activation=net_activation,
                      rgb_activation=rgb_activation,
                      sigma_activation=sigma_activation,
                      legacy_posenc_order=args.legacy_posenc_order)
    rays = example_batch["rays"]
    key1, key2, key3 = random.split(key, num=3)

    init_variables = model.init(key1,
                                rng_0=key2,
                                rng_1=key3,
                                rays=utils.namedtuple_map(
                                    lambda x: x[0], rays),
                                randomized=args.randomized)

    return model, init_variables
Example #26
 def init_kernel(
     init_params,
     num_warmup,
     adapt_state_size=None,
     inverse_mass_matrix=None,
     dense_mass=False,
     model_args=(),
     model_kwargs=None,
     rng_key=random.PRNGKey(0),
 ):
     nonlocal wa_steps
     wa_steps = num_warmup
     pe_fn = potential_fn
     if potential_fn_gen:
         if pe_fn is not None:
             raise ValueError(
                 "Only one of `potential_fn` or `potential_fn_gen` must be provided."
             )
         else:
             kwargs = {} if model_kwargs is None else model_kwargs
             pe_fn = potential_fn_gen(*model_args, **kwargs)
     rng_key_sa, rng_key_zs, rng_key_z = random.split(rng_key, 3)
     z = init_params
     z_flat, unravel_fn = ravel_pytree(z)
     if inverse_mass_matrix is None:
         inverse_mass_matrix = (
             jnp.identity(z_flat.shape[-1])
             if dense_mass
             else jnp.ones(z_flat.shape[-1])
         )
     inv_mass_matrix_sqrt = (
         jnp.linalg.cholesky(inverse_mass_matrix)
         if dense_mass
         else jnp.sqrt(inverse_mass_matrix)
     )
     if adapt_state_size is None:
         # XXX: heuristic choice
         adapt_state_size = 2 * z_flat.shape[-1]
     else:
         assert adapt_state_size > 1, "adapt_state_size should be greater than 1."
     # NB: mean is init_params
     zs = z_flat + _sample_proposal(
         inv_mass_matrix_sqrt, rng_key_zs, (adapt_state_size,)
     )
     # compute potential energies
     pes = lax.map(lambda z: pe_fn(unravel_fn(z)), zs)
     if dense_mass:
         cov = jnp.cov(zs, rowvar=False, bias=True)
         if cov.shape == ():  # JAX returns scalar for 1D input
             cov = cov.reshape((1, 1))
         cholesky = jnp.linalg.cholesky(cov)
         # if cholesky is NaN, we use the scale from `sample_proposal` here
         inv_mass_matrix_sqrt = jnp.where(
             jnp.any(jnp.isnan(cholesky)), inv_mass_matrix_sqrt, cholesky
         )
     else:
         inv_mass_matrix_sqrt = jnp.std(zs, 0)
     adapt_state = SAAdaptState(zs, pes, jnp.mean(zs, 0), inv_mass_matrix_sqrt)
     k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
     z = unravel_fn(zs[k])
     pe = pes[k]
     sa_state = SAState(
         jnp.array(0),
         z,
         pe,
         jnp.zeros(()),
         jnp.zeros(()),
         jnp.array(False),
         adapt_state,
         rng_key_sa,
     )
     return device_put(sa_state)
Example #27
def construct_model(key, example_batch, args):
    """Construct a  Light Field Neural Renderig Model.

  Args:
    key: jnp.ndarray. Random number generator.
    example_batch: dict, an example of a batch of data.
    args: FLAGS class. Hyperparameters of nerf.

  Returns:
    model: nn.Model. Nerf model with parameters.
    state: flax.Module.state. Nerf model state for stateful parameters.
  """
    net_activation = getattr(nn, str(args.model.net_activation))
    rgb_activation = getattr(nn, str(args.model.rgb_activation))
    sigma_activation = getattr(nn, str(args.model.sigma_activation))

    # Assert that rgb_activation always produces outputs in [0, 1], and
    # sigma_activation always produce non-negative outputs.
    x = jnp.exp(jnp.linspace(-90, 90, 1024))
    x = jnp.concatenate([-x[::-1], x], 0)

    rgb = rgb_activation(x)  # pylint: disable=not-callable
    if jnp.any(rgb < 0) or jnp.any(rgb > 1):
        raise NotImplementedError(
            "Choice of rgb_activation `{}` produces colors outside of [0, 1]".
            format(args.model.rgb_activation))

    sigma = sigma_activation(x)  # pylint: disable=not-callable
    if jnp.any(sigma < 0):
        raise NotImplementedError(
            "Choice of sigma_activation `{}` produces negative densities".
            format(args.model.sigma_activation))

    # We have defined some wrapper functions to extract the relevant
    # configuration so as to allow for efficient reuse.
    mlp_config = config_utils.get_mlp_config(args, net_activation)
    render_config = config_utils.get_render_params(args, rgb_activation,
                                                   sigma_activation)
    encoding_config = config_utils.get_encoding_params(args)
    lf_config = config_utils.get_lightfield_params(args)
    epipolar_config = config_utils.get_epipolar_params(args)
    epipolar_transformer_config = config_utils.get_epipolar_transformer_params(
        args)
    view_transformer_config = config_utils.get_view_transformer_params(args)

    if epipolar_config.use_learned_embedding:
        assert epipolar_transformer_config.qkv_params == view_transformer_config.qkv_params, "Currently the learned embedding are shared so the transformers need to have same qkv dim"

    model = LFNR(mlp_config=mlp_config,
                 render_config=render_config,
                 encoding_config=encoding_config,
                 lf_config=lf_config,
                 epipolar_config=epipolar_config,
                 epipolar_transformer_config=epipolar_transformer_config,
                 view_transformer_config=view_transformer_config,
                 return_attn=args.model.return_attn)

    key1, key2, key3 = random.split(key, num=3)

    init_variables = model.init(  # pylint: disable=no-member
        key1,
        rng_0=key2,
        rng_1=key3,
        batch=example_batch,
        randomized=args.model.randomized)

    return model, init_variables
Example #28
    outer_grad = estimator.grad_estimate(theta)

    if args.outer_clip > 0:
        outer_grad = jax.tree_map(
            lambda g: jnp.clip(
                g, a_min=-args.outer_clip, a_max=args.outer_clip), outer_grad)

    outer_update, outer_opt_state = outer_opt.update(outer_grad,
                                                     outer_opt_state)
    theta = optax.apply_updates(theta, outer_update)

    total_inner_iterations = args.K * outer_iteration
    total_inner_iterations_including_N = (args.K * args.n_chunks *
                                          args.n_per_chunk * outer_iteration)

    if jnp.any(jnp.isnan(theta)):
        print('=' * 80 + '\nExiting early.\n' + '=' * 80)
        sys.exit(0)

    if outer_iteration % args.print_every == 0:
        print('Outer iter: {} | Theta: {} | Theta constrained: {}'.format(
            outer_iteration, theta, to_constrained(theta)))
        sys.stdout.flush()

    if outer_iteration % args.eval_every == 0:
        key, skey = jax.random.split(key)
        stats_dict = full_evaluation_runs(skey,
                                          theta,
                                          num_eval_runs=args.num_eval_runs)
        mean_stats_dict = {
            metric: onp.mean(stats_dict[metric])
            for metric in stats_dict
        }
Example #29
def beam_search(inputs,
                cache,
                tokens_to_logits,
                beam_size=4,
                alpha=0.6,
                eos_token=EOS_ID,
                max_decode_len=None):
    """Beam search for transformer machine translation.

  Args:
    inputs: array: [batch_size, length] int32 sequence of tokens.
    cache: flax attention cache.
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices and cache and returning next-token logits and updated cache.
    beam_size: int: number of beams to use in beam search.
    alpha: float: scaling factor for brevity penalty.
    eos_token: int: end-of-sentence token for target vocabulary.
    max_decode_len: int: maximum length of decoded translations.

  Returns:
     Tuple of:
       [batch_size, beam_size, max_decode_len] top-scoring sequences
       [batch_size, beam_size] beam-search scores.
  """
    # We liberally annotate shape information for clarity below.

    batch_size = inputs.shape[0]
    if max_decode_len is None:
        max_decode_len = inputs.shape[1]
    end_marker = jnp.array(eos_token)

    # initialize beam search state
    beam_search_init_state = beam_init(batch_size, beam_size, max_decode_len,
                                       cache)

    def beam_search_loop_cond_fn(state):
        """Beam search loop termination condition."""
        # Have we reached max decoding length?
        not_at_end = (state.cur_index < max_decode_len - 1)

        # Is no further progress in the beam search possible?
        # Get the best possible scores from alive sequences.
        min_brevity_penalty = brevity_penalty(alpha, max_decode_len)
        best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
        # Get the worst scores from finished sequences.
        worst_finished_scores = jnp.min(state.finished_scores,
                                        axis=1,
                                        keepdims=True)
        # Mask out scores from slots without any actual finished sequences.
        worst_finished_scores = jnp.where(state.finished_flags,
                                          worst_finished_scores, NEG_INF)
        # If no best possible live score is better than current worst finished
        # scores, the search cannot improve the finished set further.
        search_terminated = jnp.all(worst_finished_scores > best_live_scores)

        # If we're not at the max decode length, and the search hasn't terminated,
        # continue looping.
        return not_at_end & (~search_terminated)

    def beam_search_loop_body_fn(state):
        """Beam search loop state update function."""
        # Collect the current position slice along length to feed the fast
        # autoregressive decoder model.  Flatten the beam dimension into batch
        # dimension for feeding into the model.
        # --> [batch * beam, 1]
        flat_ids = flatten_beam_dim(
            lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index),
                              (batch_size, beam_size, 1)))
        # Flatten beam dimension into batch to be compatible with model.
        # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
        flat_cache = jax.tree_map(flatten_beam_dim, state.cache)

        # Call fast-decoder model on current tokens to get next-position logits.
        # --> [batch * beam, vocab]
        flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)

        # unflatten beam dimension
        # [batch * beam, vocab] --> [batch, beam, vocab]
        logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
        # Unflatten beam dimension in attention cache arrays
        # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
        new_cache = jax.tree_map(
            lambda x: unflatten_beam_dim(x, batch_size, beam_size),
            new_flat_cache)

        # Gather log probabilities from logits
        candidate_log_probs = jax.nn.log_softmax(logits)
        # Add new logprobs to existing prefix logprobs.
        # --> [batch, beam, vocab]
        log_probs = (candidate_log_probs +
                     jnp.expand_dims(state.live_logprobs, axis=2))

        # We'll need the vocab size, gather it from the log probability dimension.
        vocab_size = log_probs.shape[2]

        # Each item in batch has beam_size * vocab_size candidate sequences.
        # For each item, get the top 2*k candidates with the highest log-
        # probabilities. We gather the top 2*K beams here so that even if the best
        # K sequences reach EOS simultaneously, we have another K sequences
        # remaining to continue the live beam search.
        beams_to_keep = 2 * beam_size
        # Flatten beam and vocab dimensions.
        flat_log_probs = log_probs.reshape(
            (batch_size, beam_size * vocab_size))
        # Gather the top 2*K scores from _all_ beams.
        # --> [batch, 2*beams], [batch, 2*beams]
        topk_log_probs, topk_indices = lax.top_k(flat_log_probs,
                                                 k=beams_to_keep)
        # Recover the beam index by floor division.
        topk_beam_indices = topk_indices // vocab_size
        # Gather 2*k top beams.
        # --> [batch, 2*beams, length]
        topk_seq = gather_beams(state.live_seqs, topk_beam_indices, batch_size,
                                beams_to_keep)

        # Append the most probable 2*K token IDs to the top 2*K sequences
        # Recover token id by modulo division and expand Id array for broadcasting.
        # --> [batch, 2*beams, 1]
        topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
        # Update sequences for the 2*K top-k new sequences.
        # --> [batch, 2*beams, length]
        topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids,
                                            (0, 0, state.cur_index + 1))

        # Update LIVE (in-progress) sequences:
        # Did any of these sequences reach an end marker?
        # --> [batch, 2*beams]
        newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
        # To prevent these newly finished sequences from being added to the LIVE
        # set of active beam search sequences, set their log probs to a very large
        # negative value.
        new_log_probs = topk_log_probs + newly_finished * NEG_INF
        # Determine the top k beam indices (from top 2*k beams) from log probs.
        # --> [batch, beams]
        _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
        new_topk_indices = jnp.flip(new_topk_indices, axis=1)
        # Gather the top k beams (from top 2*k beams).
        # --> [batch, beams, length], [batch, beams]
        top_alive_seq, top_alive_log_probs = gather_beams(
            [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size)

        # Determine the top k beam indices from the original set of all beams.
        # --> [batch, beams]
        top_alive_indices = gather_beams(topk_beam_indices, new_topk_indices,
                                         batch_size, beam_size)
        # With these, gather the top k beam-associated caches.
        # --> {[batch, beams, ...], ...}
        top_alive_cache = gather_beams(new_cache, top_alive_indices,
                                       batch_size, beam_size)

        # Update FINISHED (reached end of sentence) sequences:
        # Calculate new seq scores from log probabilities.
        new_scores = topk_log_probs / brevity_penalty(alpha,
                                                      state.cur_index + 1)
        # Mask out the still unfinished sequences by adding large negative value.
        # --> [batch, 2*beams]
        new_scores += (~newly_finished) * NEG_INF

        # Combine sequences, scores, and flags along the beam dimension and compare
        # new finished sequence scores to existing finished scores and select the
        # best from the new set of beams.
        finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
            [state.finished_seqs, topk_seq],
            axis=1)
        finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
            [state.finished_scores, new_scores],
            axis=1)
        finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
            [state.finished_flags, newly_finished],
            axis=1)
        # --> [batch, beams, length], [batch, beams], [batch, beams]
        top_finished_seq, top_finished_scores, top_finished_flags = (
            gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                              finished_scores, batch_size, beam_size))

        return BeamState(cur_index=state.cur_index + 1,
                         live_logprobs=top_alive_log_probs,
                         finished_scores=top_finished_scores,
                         live_seqs=top_alive_seq,
                         finished_seqs=top_finished_seq,
                         finished_flags=top_finished_flags,
                         cache=top_alive_cache)

    # Run while loop and get final beam search state.
    final_state = lax.while_loop(beam_search_loop_cond_fn,
                                 beam_search_loop_body_fn,
                                 beam_search_init_state)

    # Account for the edge-case where there are no finished sequences for a
    # particular batch item. If so, return live sequences for that batch item.
    # --> [batch]
    none_finished = jnp.any(final_state.finished_flags, axis=1)
    # --> [batch, beams, length]
    finished_seqs = jnp.where(none_finished[:, None, None],
                              final_state.finished_seqs, final_state.live_seqs)
    # --> [batch, beams]
    finished_scores = jnp.where(none_finished[:, None],
                                final_state.finished_scores,
                                final_state.live_logprobs)

    return finished_seqs, finished_scores
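
# `brevity_penalty` and `NEG_INF` are referenced above but not defined in this
# excerpt. A common choice, which this code appears to assume, is the GNMT
# length-normalization penalty together with a large negative sentinel:
import jax.numpy as jnp

NEG_INF = jnp.array(-1.0e7)  # assumed sentinel used to mask out scores


def brevity_penalty(alpha, length):
    # GNMT-style length normalization: ((5 + length) / 6) ** alpha (assumed).
    return jnp.power(((5.0 + length) / 6.0), alpha)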
Example #30
def multiplicative_replacement(mat, delta=None):
    r"""Replace all zeros with small non-zero values

    It uses the multiplicative replacement strategy [1]_ ,
    replacing zeros with a small positive :math:`\delta`
    and ensuring that the compositions still add up to 1.


    Parameters
    ----------
    mat: array_like
       a matrix of proportions where
       rows = compositions and
       columns = components
    delta: float, optional
       a small number to be used to replace zeros
       If delta is not specified, then the default delta is
       :math:`\delta = \frac{1}{N^2}` where :math:`N`
       is the number of components

    Returns
    -------
    numpy.ndarray, np.float64
       A matrix of proportions where all of the values
       are nonzero and each composition (row) adds up to 1

    Raises
    ------
    ValueError
       Raises an error if negative proportions are created due to a large
       `delta`.

    Notes
    -----
    This method will result in negative proportions if a large delta is chosen.

    References
    ----------
    .. [1] J. A. Martin-Fernandez. "Dealing With Zeros and Missing Values in
           Compositional Data Sets Using Nonparametric Imputation"


    Examples
    --------
    >>> import numpy as np
    >>> import mushi.composition as cmp
    >>> X = np.array([[.2,.4,.4, 0],[0,.5,.5,0]])
    >>> cmp.multiplicative_replacement(X)
    DeviceArray([[0.1875, 0.375 , 0.375 , 0.0625],
                 [0.0625, 0.4375, 0.4375, 0.0625]], dtype=float64)

    """
    mat = closure(mat)
    z_mat = mat == 0

    num_feats = mat.shape[-1]
    tot = z_mat.sum(axis=-1, keepdims=True)

    if delta is None:
        delta = (1.0 / num_feats)**2

    zcnts = 1 - tot * delta
    if np.any(zcnts < 0):
        raise ValueError("The multiplicative replacement created negative "
                         "proportions. Consider using a smaller `delta`.")
    mat = np.where(z_mat, delta, zcnts * mat)
    return mat.squeeze()