예제 #1
 def func(x):
     """Run a ``lax.while_loop`` over the carry ``(x, 1)`` until its second slot reaches 5.

     Each iteration replaces the first slot with the outer-scope value ``y`` and
     passes the second slot through ``hcb.id_print`` before adding one to it.
     """
     def keep_going(carry):
         return carry[1] < 5

     def advance(carry):
         return y, hcb.id_print(carry[1]) + 1

     return lax.while_loop(keep_going, advance, (x, 1))
예제 #2
    def _sample(
        input_ids: None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        prng_key: Optional[jnp.ndarray] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        logits_warper: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)

        # initialize state
        state = SampleState(

        def sample_search_cond_fn(state):
            """Loop predicate: truthy while sampling should continue.

            Sampling stops once ``state.cur_len`` equals ``max_length`` or every
            entry of ``state.is_sent_finished`` is set.
            """
            out_of_room = state.cur_len == max_length
            everything_done = jnp.all(state.is_sent_finished)
            return ~jnp.logical_or(out_of_room, everything_done)

        def sample_search_body_fn(state):
            """state update fn."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(state.running_token, params=params, **state.model_kwargs)

            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_p, top_k, temperature
            logits = logits_warper(logits, logits, state.cur_len)

            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
            next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return SampleState(
                cur_len=state.cur_len + 1,

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = sample_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
            state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)
예제 #3
    def _beam_search(
        input_ids: None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
        This beam search function is heavily inspired by Flax's official example:

        def flatten_beam_dim(tensor):
            """Merge the two leading axes of ``tensor`` into a single axis.

            Zero-dimensional inputs (e.g. a cache index) are returned untouched.
            """
            if tensor.ndim != 0:
                merged = tensor.shape[0] * tensor.shape[1]
                tensor = tensor.reshape((merged,) + tensor.shape[2:])
            return tensor

        def unflatten_beam_dim(tensor, batch_size, num_beams):
            """Split the leading axis of ``tensor`` into ``(batch_size, num_beams)``.

            Zero-dimensional inputs (e.g. a cache index) are returned untouched.
            """
            is_scalar = tensor.ndim == 0
            if is_scalar:
                return tensor
            target_shape = (batch_size, num_beams) + tensor.shape[1:]
            return tensor.reshape(target_shape)

        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
            Gathers the beam slices indexed by beam_indices into new beam array.
            batch_indices = jnp.reshape(
                jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)

            def gather_fn(tensor):
                """Select ``tensor[batch_indices, beam_indices]``; return scalars unchanged.

                ``batch_indices`` and ``beam_indices`` come from the enclosing scope.
                """
                # ignore scalars (e.g. cache index)
                if tensor.ndim == 0:
                    return tensor
                # The gather has to run for every non-scalar leaf. When this return was
                # nested under the scalar branch it was unreachable, so non-scalar
                # leaves fell off the end of the function and became None.
                return tensor[batch_indices, beam_indices]

            return jax.tree_map(gather_fn, nested)

        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        batch_size, num_beams, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch,beam-item holding current token in loop.
        sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

        # per batch,beam-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

        # per batch,beam-item score, logprobs
        running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # flatten beam dim
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)

        # initialize state
        state = BeamSearchState(

        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""

            # 1. is less than max length?
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve?
            best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

            # 3. is there still a beam that has not finished?
            still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

            return not_max_length_yet & still_open_beam & improvement_still_possible

        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model.  Flatten the beam dimension into batch
            # dimension for feeding into the model.
            # unflatten beam dimension
            # Unflatten beam dimension in attention cache arrays
            input_token = flatten_beam_dim(
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
            model_outputs = model(input_token, params=params, **state.model_kwargs)

            logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(
                flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*k candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            # Gather the top 2*K scores from _all_ beams.
            # Gather 2*k top beams.
            # Recover the beam index by floor division.
            # Recover token id by modulo division and expand Id array for broadcasting.
            # Update sequences for the 2*K top-k new sequences.
            beams_to_keep = 2 * num_beams
            topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
            topk_beam_indices = topk_indices // vocab_size
            topk_running_sequences = gather_beams(
                state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Update current sequences:
            # Did any of these sequences reach an end marker?
            # To prevent these just finished sequences from being added to the current sequences
            # set of active beam search sequences, set their log probs to a very large
            # negative value.
            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
            running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = (
                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
                & early_stopping
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
            merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
            model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[-1] > 1:
            state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)

        if not trace:
            state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

        # Account for the edge-case where there are no finished sequences for a
        # particular batch item. If so, return running sequences for that batch item.
        none_finished = jnp.any(state.is_sent_finished, axis=1)
        sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
        scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)

        # take best beam for each batch
        sequences = sequences[:, -1]
        scores = scores[:, -1]

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
예제 #4
 def fun(x, y):
     """Run ``lax.while_loop`` with ``cond_fun`` and ``body_fun`` on the carry ``(x, y)``."""
     carry = (x, y)
     return lax.while_loop(cond_fun, body_fun, carry)
예제 #5
 def loop(init):
     """Run ``lax.while_loop`` on ``(init, 0)`` and return the second slot of the final carry."""
     _, iterations = lax.while_loop(loop_cond, loop_body, (init, 0))
     return iterations
예제 #6
 def count_to(N):
     """Start at 0.0 and add 1.0 inside ``lax.while_loop`` until the value is no longer below ``N``."""
     def below_limit(total):
         return total < N

     def add_one(total):
         return total + 1.0

     return lax.while_loop(below_limit, add_one, 0.0)
예제 #7
 def fun(x, y):
     """Add ``y`` to ``x`` inside ``lax.while_loop`` for as long as the running value is below 3."""
     def under_three(acc):
         return acc < 3

     def add_step(acc):
         return acc + y

     return lax.while_loop(under_three, add_step, x)
예제 #8
파일: linalg.py 프로젝트: jbampton/jax
def eigh_tridiagonal(d,
    if not eigvals_only:
        raise NotImplementedError(
            "Calculation of eigenvectors is not implemented")

    def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
        """Evaluate the Sturm sequence recurrence and count its small pivots.

        For every entry of ``x`` the recurrence is
        ``q_0 = alpha[0] - x`` and ``q_i = alpha[i] - beta_sq[i - 1] / q_{i-1} - x``.
        The returned int32 array has the shape of ``x`` and counts the steps where
        the pivot was negative (step 0) or at most ``pivmin`` (later steps).
        """
        size = alpha.shape[0]

        # Step 0 needs special care when x equals alpha[0]: the count uses the raw
        # pivot, and a zero pivot is then swapped for `alpha0_perturbation`.
        pivot = alpha[0] - x
        negatives = jnp.where(
            pivot < 0,
            jnp.ones(x.shape, dtype=jnp.int32),
            jnp.zeros(x.shape, dtype=jnp.int32),
        )
        pivot = jnp.where(alpha[0] == x, alpha0_perturbation, pivot)

        def advance(index, pivot, negatives):
            # One recurrence step; pivots at or below `pivmin` are counted and then
            # replaced by min(pivot, -pivmin).
            pivot = alpha[index] - beta_sq[index - 1] / pivot - x
            too_small = pivot <= pivmin
            negatives = jnp.where(too_small, negatives + 1, negatives)
            pivot = jnp.where(too_small, jnp.minimum(pivot, -pivmin), pivot)
            return pivot, negatives

        def run_block(carry, width):
            # Apply `width` consecutive steps, unrolled in Python.
            index, pivot, negatives = carry
            for offset in range(width):
                pivot, negatives = advance(index + offset, pivot, negatives)
            return index + width, pivot, negatives

        # Peel off ((size - 1) % block) steps up front so that the while loop below
        # can always advance by a full block of unrolled steps.
        block = 16
        carry = run_block((1, pivot, negatives), (size - 1) % block)

        # Remaining steps: a while loop whose body is unrolled `block` times.
        carry = lax.while_loop(
            lambda c: jnp.less(c[0], size),
            lambda c: run_block(c, block),
            carry,
        )
        return carry[2]

    alpha = jnp.asarray(d)
    beta = jnp.asarray(e)
    supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64,
    if alpha.dtype != beta.dtype:
        raise TypeError(
            "diagonal and off-diagonal values must have same dtype, "
            f"got {alpha.dtype} and {beta.dtype}")
    if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes:
        raise TypeError(
            "Only float32 and float64 inputs are supported as inputs "
            "to jax.scipy.linalg.eigh_tridiagonal, got "
            f"{alpha.dtype} and {beta.dtype}")
    n = alpha.shape[0]
    if n <= 1:
        return jnp.real(alpha)

    if jnp.issubdtype(alpha.dtype, jnp.complexfloating):
        alpha = jnp.real(alpha)
        beta_sq = jnp.real(beta * jnp.conj(beta))
        beta_abs = jnp.sqrt(beta_sq)
        beta_abs = jnp.abs(beta)
        beta_sq = jnp.square(beta)

    # Estimate the largest and smallest eigenvalues of T using the Gershgorin
    # circle theorem.
    off_diag_abs_row_sum = jnp.concatenate(
        [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
    lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
    lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
    # Upper bound on 2-norm of T.
    t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

    # Compute the smallest allowed pivot in the Sturm sequence to avoid
    # overflow.
    finfo = np.finfo(alpha.dtype)
    one = np.ones([], dtype=alpha.dtype)
    safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
    pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
    alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
    abs_tol = finfo.eps * t_norm
    if tol is not None:
        abs_tol = jnp.maximum(tol, abs_tol)

    # In the worst case, when the absolute tolerance is eps*lambda_est_max and
    # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
    # as there are bits in the mantissa plus 1.
    # The proof is left as an exercise to the reader.
    max_it = finfo.nmant + 1

    # Determine the indices of the desired eigenvalues, based on select and
    # select_range.
    if select == 'a':
        target_counts = jnp.arange(n)
    elif select == 'i':
        if select_range[0] > select_range[1]:
            raise ValueError('Got empty index range in select_range.')
        target_counts = jnp.arange(select_range[0], select_range[1] + 1)
    elif select == 'v':
        # TODO(phawkins): requires dynamic shape support.
        raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
        raise ValueError("'select must have a value in {'a', 'i', 'v'}.")

    # Run binary search for all desired eigenvalues in parallel, starting from
    # the interval lightly wider than the estimated
    # [lambda_est_min, lambda_est_max].
    fudge = 2.1  # We widen starting interval the Gershgorin interval a bit.
    norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
    lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
    upper = lambda_est_max + norm_slack + fudge * pivmin

    # Pre-broadcast the scalars used in the Sturm sequence for improved
    # performance.
    target_shape = jnp.shape(target_counts)
    lower = jnp.broadcast_to(lower, shape=target_shape)
    upper = jnp.broadcast_to(upper, shape=target_shape)
    mid = 0.5 * (upper + lower)
    pivmin = jnp.broadcast_to(pivmin, target_shape)
    alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

    # Start parallel binary searches.
    def cond(args):
        """Continue while fewer than ``max_it`` steps ran and the widest bracket exceeds ``abs_tol``."""
        step, lo, _, hi = args
        within_budget = jnp.less(step, max_it)
        widest_gap = jnp.amax(hi - lo)
        return jnp.logical_and(within_budget, jnp.less(abs_tol, widest_gap))

    def body(args):
        """One bisection step: move each bracket end to the midpoint based on the Sturm count."""
        step, lo, probe, hi = args
        counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, probe)
        # `counts` are integers, so "counts > target" is exactly the complement of
        # "counts <= target"; a single mask drives both bracket updates.
        at_or_below = counts <= target_counts
        next_lo = jnp.where(at_or_below, probe, lo)
        next_hi = jnp.where(at_or_below, hi, probe)
        next_probe = 0.5 * (next_lo + next_hi)
        return step + 1, next_lo, next_probe, next_hi

    _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
    return mid
예제 #9
 def fun(x):
     """Add 2 to ``x`` inside ``lax.while_loop`` for as long as the running value is below 3."""
     def under_three(acc):
         return acc < 3

     def add_two(acc):
         return acc + 2

     return lax.while_loop(under_three, add_two, x)
예제 #10
파일: bfgs.py 프로젝트: wayfeng/jax
def minimize_bfgs(
    fun: Callable,
    x0: jnp.ndarray,
    maxiter: Optional[int] = None,
    gtol: float = 1e-5,
    line_search_maxiter: int = 10,
) -> _BFGSResults:
    """Minimize a function using BFGS.

  Implements the BFGS algorithm from
    Algorithm 6.1 from Wright and Nocedal, 'Numerical Optimization', 1999, pg.

    fun: function of the form f(x) where x is a flat ndarray and returns a real
      scalar. The function should be composed of operations with vjp defined.
    x0: initial guess.
    maxiter: maximum number of iterations.
    norm: order of norm for convergence check. Default inf.
    gtol: terminates minimization when |grad|_norm < g_tol.
    line_search_maxiter: maximum number of linesearch iterations.

    Optimization result.

    if maxiter is None:
        maxiter = jnp.size(x0) * 200

    d = x0.shape[0]

    initial_H = jnp.eye(d, dtype=x0.dtype)
    f_0, g_0 = jax.value_and_grad(fun)(x0)
    state = _BFGSResults(
        converged=jnp.linalg.norm(g_0, ord=norm) < gtol,
        old_old_fval=f_0 + jnp.linalg.norm(g_0) / 2,

    def cond_fun(state):
        """Keep iterating while not converged, not failed, and ``state.k`` is below ``maxiter``."""
        not_converged = jnp.logical_not(state.converged)
        not_failed = jnp.logical_not(state.failed)
        has_budget = state.k < maxiter
        return not_converged & not_failed & has_budget

    def body_fun(state):
        p_k = -_dot(state.H_k, state.g_k)
        line_search_results = line_search(
        state = state._replace(
            nfev=state.nfev + line_search_results.nfev,
            ngev=state.ngev + line_search_results.ngev,
        s_k = line_search_results.a_k * p_k
        x_kp1 = state.x_k + s_k
        f_kp1 = line_search_results.f_k
        g_kp1 = line_search_results.g_k
        y_k = g_kp1 - state.g_k
        rho_k = jnp.reciprocal(_dot(y_k, s_k))

        sy_k = s_k[:, jnp.newaxis] * y_k[jnp.newaxis, :]
        w = jnp.eye(d) - rho_k * sy_k
        H_kp1 = (_einsum('ij,jk,lk', w, state.H_k, w) +
                 rho_k * s_k[:, jnp.newaxis] * s_k[jnp.newaxis, :])
        H_kp1 = jnp.where(jnp.isfinite(rho_k), H_kp1, state.H_k)
        converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

        state = state._replace(
            k=state.k + 1,
        return state

    state = lax.while_loop(cond_fun, body_fun, state)
    status = jnp.where(
        0,  # converged
            state.k == maxiter,
            1,  # max iters reached
                2 + state.line_search_status,  # ls failed (+ reason)
                -1,  # undefined
    state = state._replace(status=status)
    return state
예제 #11
def _minimize_lbfgs(
    fun: Callable,
    x0: jnp.ndarray,
    maxiter: Optional[int] = None,
    maxcor: int = 10,
    ftol: float = 2.220446049250313e-09,
    gtol: float = 1e-05,
    maxfun: Optional[int] = None,
    maxgrad: Optional[int] = None,
    maxls: int = 20,
  Minimize a function using L-BFGS

  Implements the L-BFGS algorithm from
    Algorithm 7.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 176-185
  And generalizes to complex variables from
     Sorber, L., Barel, M.V. and Lathauwer, L.D., 2012.
     "Unconstrained optimization of real functions in complex variables"
     SIAM Journal on Optimization, 22(3), pp.879-898.

    fun: function of the form f(x) where x is a flat ndarray and returns a real scalar.
      The function should be composed of operations with vjp defined.
    x0: initial guess
    maxiter: maximum number of iterations
    norm: order of norm for convergence check. Default inf.
    maxcor: maximum number of metric corrections ("history size")
    ftol: terminates the minimization when `(f_k - f_{k+1}) < ftol`
    gtol: terminates the minimization when `|g_k|_norm < gtol`
    maxfun: maximum number of function evaluations
    maxgrad: maximum number of gradient evaluations
    maxls: maximum number of line search steps (per iteration)

    Optimization results.
    d = len(x0)
    dtype = jnp.dtype(x0)

    # ensure there is at least one termination condition
    if (maxiter is None) and (maxfun is None) and (maxgrad is None):
        maxiter = d * 200

    # set others to inf, such that >= is supported
    if maxiter is None:
        maxiter = jnp.inf
    if maxfun is None:
        maxfun = jnp.inf
    if maxgrad is None:
        maxgrad = jnp.inf

    # initial evaluation
    f_0, g_0 = jax.value_and_grad(fun)(x0)
    state_initial = LBFGSResults(
        s_history=jnp.zeros((maxcor, d), dtype=dtype),
        y_history=jnp.zeros((maxcor, d), dtype=dtype),
        rho_history=jnp.zeros((maxcor, ), dtype=dtype),

    def cond_fun(state: LBFGSResults):
        """Keep iterating while the state is neither converged nor failed."""
        still_searching = ~state.converged
        still_healthy = ~state.failed
        return still_searching & still_healthy

    def body_fun(state: LBFGSResults):
        # find search direction
        p_k = _two_loop_recursion(state)

        # line search
        ls_results = line_search(

        # evaluate at next iterate
        s_k = ls_results.a_k * p_k
        x_kp1 = state.x_k + s_k
        f_kp1 = ls_results.f_k
        g_kp1 = ls_results.g_k
        y_k = g_kp1 - state.g_k
        rho_k_inv = jnp.real(_dot(y_k, s_k))
        rho_k = jnp.reciprocal(rho_k_inv)
        gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))

        # replacements for next iteration
        status = 0
        status = jnp.where(state.f_k - f_kp1 < ftol, 4, status)
        status = jnp.where(state.ngev >= maxgrad, 3, status)  # type: ignore
        status = jnp.where(state.nfev >= maxfun, 2, status)  # type: ignore
        status = jnp.where(state.k >= maxiter, 1, status)  # type: ignore
        status = jnp.where(ls_results.failed, 5, status)

        converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

        state = state._replace(
            failed=(status > 0) & (~converged),
            k=state.k + 1,
            nfev=state.nfev + ls_results.nfev,
            ngev=state.ngev + ls_results.ngev,
            status=jnp.where(converged, 0, status),

        return state

    return lax.while_loop(cond_fun, body_fun, state_initial)
예제 #12
def beam_search(inputs,
  """Beam search for transformer machine translation.

  Args:
    inputs: array: [batch_size, length] int32 sequence of tokens.
    cache: flax attention cache.
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices and cache and returning next-token logits and updated cache.
    beam_size: int: number of beams to use in beam search.
    alpha: float: scaling factor for brevity penalty.
    eos_id: int: id of end-of-sentence token for target vocabulary.
    max_decode_len: int: maximum length of decoded translations; when None,
      the length of `inputs` (its second dimension) is used.

  Returns:
     Tuple of:
       [batch_size, beam_size, max_decode_len] top-scoring sequences
       [batch_size, beam_size] beam-search scores.
  # NOTE(review): this listing looks truncated -- the rest of the signature,
  # the docstring terminator and the arguments of several calls below are
  # missing. Restore them from the original source before relying on it.
  # We liberally annotate shape information for clarity below.

  batch_size = inputs.shape[0]
  if max_decode_len is None:
    max_decode_len = inputs.shape[1]
  end_marker = jnp.array(eos_id)

  # initialize beam search state
  # NOTE(review): the remaining arguments of this call are not visible here.
  beam_search_init_state = beam_init(batch_size,

  def beam_search_loop_cond_fn(state):
    """Beam search loop termination condition."""
    # Have we reached max decoding length?
    not_at_end = (state.cur_index < max_decode_len - 1)

    # Is no further progress in the beam search possible?
    # Get the best possible scores from alive sequences. The last column of
    # `live_logprobs` is read as the best live score here.
    min_brevity_penalty = brevity_penalty(alpha, max_decode_len)
    best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
    # Get the worst scores from finished sequences.
    worst_finished_scores = jnp.min(
        state.finished_scores, axis=1, keepdims=True)
    # Mask out scores from slots without any actual finished sequences.
    worst_finished_scores = jnp.where(
        state.finished_flags, worst_finished_scores, NEG_INF)
    # If no best possible live score is better than current worst finished
    # scores, the search cannot improve the finished set further.
    search_terminated = jnp.all(worst_finished_scores > best_live_scores)

    # If we're not at the max decode length, and the search hasn't terminated,
    # continue looping.
    return not_at_end & (~search_terminated)

  def beam_search_loop_body_fn(state):
    """Beam search loop state update function."""
    # Collect the current position slice along length to feed the fast
    # autoregressive decoder model.  Flatten the beam dimension into batch
    # dimension for feeding into the model.
    # --> [batch * beam, 1]
    # NOTE(review): only the start indices and slice sizes are passed to
    # `lax.dynamic_slice` below; the array being sliced appears to be missing.
    flat_ids = flatten_beam_dim(lax.dynamic_slice(
        (0, 0, state.cur_index),
        (batch_size, beam_size, 1)))
    # Flatten beam dimension into batch to be compatible with model.
    # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
    flat_cache = jax.tree_map(flatten_beam_dim, state.cache)

    # Call fast-decoder model on current tokens to get next-position logits.
    # --> [batch * beam, vocab]
    flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)

    # unflatten beam dimension
    # [batch * beam, vocab] --> [batch, beam, vocab]
    logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
    # Unflatten beam dimension in attention cache arrays
    # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
    new_cache = jax.tree_map(
        lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)

    # Gather log probabilities from logits
    candidate_log_probs = jax.nn.log_softmax(logits)
    # Add new logprobs to existing prefix logprobs.
    # --> [batch, beam, vocab]
    log_probs = (candidate_log_probs +
                 jnp.expand_dims(state.live_logprobs, axis=2))

    # We'll need the vocab size, gather it from the log probability dimension.
    vocab_size = log_probs.shape[2]

    # Each item in batch has beam_size * vocab_size candidate sequences.
    # For each item, get the top 2*k candidates with the highest log-
    # probabilities. We gather the top 2*K beams here so that even if the best
    # K sequences reach EOS simultaneously, we have another K sequences
    # remaining to continue the live beam search.
    beams_to_keep = 2 * beam_size
    # Flatten beam and vocab dimensions.
    flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size))
    # Gather the top 2*K scores from _all_ beams.
    # --> [batch, 2*beams], [batch, 2*beams]
    topk_log_probs, topk_indices = lax.top_k(flat_log_probs, k=beams_to_keep)
    # Recover the beam index by floor division.
    topk_beam_indices = topk_indices // vocab_size
    # Gather 2*k top beams.
    # --> [batch, 2*beams, length]
    # NOTE(review): the other `gather_beams` calls in this function pass an
    # index array as the second argument; it appears to be missing here
    # (presumably `topk_beam_indices` -- confirm against the original).
    topk_seq = gather_beams(state.live_seqs,
                            batch_size, beams_to_keep)

    # Append the most probable 2*K token IDs to the top 2*K sequences
    # Recover token id by modulo division and expand Id array for broadcasting.
    # --> [batch, 2*beams, 1]
    topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
    # Update sequences for the 2*K top-k new sequences.
    # --> [batch, 2*beams, length]
    topk_seq = lax.dynamic_update_slice(
        topk_seq, topk_ids, (0, 0, state.cur_index + 1))

    # Update LIVE (in-progress) sequences:
    # Did any of these sequences reach an end marker?
    # --> [batch, 2*beams]
    newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
    # To prevent these newly finished sequences from being added to the LIVE
    # set of active beam search sequences, set their log probs to a very large
    # negative value.
    new_log_probs = topk_log_probs + newly_finished * NEG_INF
    # Determine the top k beam indices (from top 2*k beams) from log probs.
    # --> [batch, beams]
    _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
    # Reverse the order of the selected indices along the beam axis.
    new_topk_indices = jnp.flip(new_topk_indices, axis=1)
    # Gather the top k beams (from top 2*k beams).
    # --> [batch, beams, length], [batch, beams]
    top_alive_seq, top_alive_log_probs = gather_beams(
        [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size)

    # Determine the top k beam indices from the original set of all beams.
    # --> [batch, beams]
    top_alive_indices = gather_beams(
        topk_beam_indices, new_topk_indices, batch_size, beam_size)
    # With these, gather the top k beam-associated caches.
    # --> {[batch, beams, ...], ...}
    top_alive_cache = gather_beams(
        new_cache, top_alive_indices, batch_size, beam_size)

    # Update FINISHED (reached end of sentence) sequences:
    # Calculate new seq scores from log probabilities.
    new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)
    # Mask out the still unfinished sequences by adding large negative value.
    # --> [batch, 2*beams]
    new_scores += (~newly_finished) * NEG_INF

    # Combine sequences, scores, and flags along the beam dimension and compare
    # new finished sequence scores to existing finished scores and select the
    # best from the new set of beams.
    finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
        [state.finished_seqs, topk_seq], axis=1)
    finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_scores, new_scores], axis=1)
    finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_flags, newly_finished], axis=1)
    # --> [batch, beams, length], [batch, beams], [batch, beams]
    top_finished_seq, top_finished_scores, top_finished_flags = (
        gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                          finished_scores, batch_size, beam_size))

    # NOTE(review): the remaining fields of the returned state are not visible.
    return BeamState(cur_index=state.cur_index + 1,

  # Run while loop and get final beam search state.
  # NOTE(review): the body function and initial state arguments are not visible.
  final_state = lax.while_loop(beam_search_loop_cond_fn,

  # Account for the edge-case where there are no finished sequences for a
  # particular batch item. If so, return live sequences for that batch item.
  # NOTE: despite its name, `none_finished` is True for batch items that have
  # at least one finished sequence (it is `jnp.any` over the finished flags).
  # --> [batch]
  none_finished = jnp.any(final_state.finished_flags, axis=1)
  # NOTE(review): the two value arguments of each `jnp.where` below are not
  # visible in this listing.
  # --> [batch, beams, length]
  finished_seqs = jnp.where(none_finished[:, None, None],
  # --> [batch, beams]
  finished_scores = jnp.where(none_finished[:, None],

  return finished_seqs, finished_scores
예제 #13
 def g(x):
     """Apply ``f`` to ``x`` once per step while a float counter runs 0. -> 10.

     The loop carry is ``(counter, value)``; the counter starts at ``0.`` and
     grows by ``1.`` per iteration for as long as it is below 10.

     Returns:
         The final ``(counter, value)`` carry from ``lax.while_loop``.
     """
     def below_limit(state):
         return state[0] < 10

     def advance(state):
         return state[0] + 1., f(state[1])

     start = (0., x)
     return lax.while_loop(below_limit, advance, start)
예제 #14
 def func(x):
     """Increment ``x`` by one until it is no longer below 4.

     Mirrors the C loop ``for (i = x; i < 4; i++);`` -- a start value that is
     already 4 or more is returned unchanged.
     """
     def below_four(value):
         return value < 4

     def plus_one(value):
         return value + 1

     return lax.while_loop(below_four, plus_one, x)
예제 #15
 def func(x):
     """Loop while an integer counter (starting at 1) is below 5.

     The carry is ``(value, counter)``. On every iteration the first slot is
     replaced by ``y`` (a name taken from the enclosing scope) and the counter
     is passed through ``hcb.id_print`` (output directed to
     ``testing_stream``) before being incremented by one.

     Returns:
         The final ``(value, counter)`` carry from ``lax.while_loop``.
     """
     def counter_below_five(carry):
         return carry[1] < 5

     def print_and_bump(carry):
         return (
             y,
             hcb.id_print(carry[1], output_stream=testing_stream) + 1,
         )

     initial = (x, 1)
     return lax.while_loop(counter_below_five, print_and_bump, initial)
예제 #16
 def fun(x, y):
     """Advance ``x`` in steps of 2 until it is no longer less than ``y``.

     Returns ``x`` unchanged when it is already greater than or equal to ``y``.
     """
     def still_below(value):
         return value < y

     def add_two(value):
         return value + 2

     return lax.while_loop(still_below, add_two, x)
예제 #17
 def f(x):
     """Iterate ``(sin, cos)`` on the pair ``(x, x)`` while the first entry is negative.

     Each step maps the carry ``(a, b)`` to ``(sin(a), cos(b))``. When the loop
     stops, the second entry plus ``1.`` is returned; for a non-negative ``x``
     the loop body never runs and the result is ``x + 1.``.
     """
     def first_is_negative(state):
         return state[0] < 0.

     def sin_cos_step(state):
         return jnp.sin(state[0]), jnp.cos(state[1])

     final_state = lax.while_loop(first_is_negative, sin_cos_step, (x, x))
     return final_state[1] + 1.
예제 #18
 def f():
     """Run ``lax.while_loop`` with an always-false predicate.

     ``err`` (a name from the enclosing scope) is supplied as the loop body and
     the carry is the empty tuple.
     """
     def never(_):
         return False

     return lax.while_loop(never, err, ())