Example 1
def rotate(q, v):
    """Rotate a vector using a quaternion."""
    # Create the quaternion representation of the vector.
    q_v = jnp.concatenate([v, jnp.zeros_like(v[..., :1])], axis=-1)
    return im(multiply(multiply(q, q_v), conjugate(q)))
Example 2
def conjugate(q):
    """Compute the conjugate of a quaternion."""
    return jnp.concatenate([-im(q), re(q)], axis=-1)
Example 3
def multiply(q1, q2):
    """Multiply two quaternions."""
    c = (re(q1) * im(q2) + re(q2) * im(q1) + jnp.cross(im(q1), im(q2)))
    w = re(q1) * re(q2) - jnp.dot(im(q1), im(q2))
    return jnp.concatenate([c, w], axis=-1)
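The three helpers above rely on `re` and `im` accessors that are not shown. Below is a minimal sketch of how they could look for the (x, y, z, w) layout implied by `rotate` and `conjugate`, together with a small usage check; the accessor definitions are inferred from that layout and are not the original source.

import jax.numpy as jnp

def re(q):
    """Real (scalar) part; stored last in the assumed (x, y, z, w) layout."""
    return q[..., 3:]

def im(q):
    """Imaginary (vector) part."""
    return q[..., :3]

# Rotate (1, 0, 0) by 90 degrees about the z-axis using a unit quaternion.
half = jnp.pi / 4
q = jnp.array([0.0, 0.0, jnp.sin(half), jnp.cos(half)])  # (x, y, z, w)
v = jnp.array([1.0, 0.0, 0.0])
print(rotate(q, v))  # approximately [0., 1., 0.]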
Example 4
        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model. Flatten the beam dimension into batch
            # dimension for feeding into the model.
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                ))
            model_outputs = model(input_token,
                                  params=params,
                                  **state.model_kwargs)

            # Unflatten beam dimension in the last-position logits.
            logits = unflatten_beam_dim(model_outputs.logits[:, -1],
                                        batch_size, num_beams)
            # Unflatten beam dimension in attention cache arrays.
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams),
                model_outputs.past_key_values)

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(flatten_beam_dim(state.running_sequences),
                                         flatten_beam_dim(log_probs),
                                         state.cur_len)
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores,
                                                    axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*K candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            beams_to_keep = 2 * num_beams
            # Gather the top 2*K scores from _all_ beams.
            topk_log_probs, topk_indices = lax.top_k(log_probs,
                                                     k=beams_to_keep)
            # Recover the beam index by floor division.
            topk_beam_indices = topk_indices // vocab_size
            # Gather the 2*K top beams.
            topk_running_sequences = gather_beams(state.running_sequences,
                                                  topk_beam_indices,
                                                  batch_size, beams_to_keep)
            # Recover token id by modulo division and expand id array for broadcasting.
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            # Update sequences for the 2*K top-k new sequences.
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences,
                                                      topk_ids,
                                                      (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Did any of these sequences reach an end marker? To prevent just-finished
            # sequences from staying in the set of active (running) beam search
            # sequences, set their log probs to a very large negative value.
            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
            running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs,
                                                   k=num_beams)[1],
                                         axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices,
                batch_size, num_beams)

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = (jnp.broadcast_to(
                state.is_sent_finished.all(axis=-1, keepdims=True),
                did_topk_just_finished.shape)
                                       & early_stopping)
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate(
                [state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs],
                                            axis=1)
            merged_is_sent_finished = jnp.concatenate(
                [state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores,
                                                     k=num_beams)[1],
                                           axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished],
                topk_merged_indices, batch_size, num_beams)

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices,
                                                next_topk_indices, batch_size,
                                                num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size,
                                      num_beams)
            model_outputs["past_key_values"] = jax.tree_map(
                lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )
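The body above also depends on beam-reshaping helpers (`flatten_beam_dim`, `unflatten_beam_dim`, `gather_beams`) defined elsewhere in the generation utilities. The following is only a rough sketch of their behaviour, reconstructed from how they are used here rather than copied from the original source:

import jax
import jax.numpy as jnp

def flatten_beam_dim(tensor):
    """Merge the leading (batch, beams) axes into a single (batch * beams) axis."""
    if tensor.ndim == 0:  # scalar leaves (e.g. a cache index) pass through
        return tensor
    return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

def unflatten_beam_dim(tensor, batch_size, num_beams):
    """Split the flat leading axis back into (batch, beams)."""
    if tensor.ndim == 0:
        return tensor
    return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

def gather_beams(nested, beam_indices, batch_size, new_num_beams):
    """Select, per batch item, the beams indexed by beam_indices from every array in a pytree."""
    # beam_indices has shape (batch_size, new_num_beams); broadcast batch indices against it.
    batch_indices = jnp.arange(batch_size)[:, None]

    def gather_fn(tensor):
        if tensor.ndim == 0:
            return tensor
        return tensor[batch_indices, beam_indices]

    return jax.tree_util.tree_map(gather_fn, nested)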
Example 5
def normalize_GradientGP(XF, yF, XG, yG):
    """Pack the (XF, yF) and (XG, yG) observations into a batch dict.

    No normalization is applied; the returned constants are identity values.
    """
    y = np.concatenate([yF, yG], axis=0)
    batch = {'XF': XF, 'XG': XG, 'yF': yF, 'yG': yG, 'y': y}
    norm_const = {'mu_X': 0.0, 'sigma_X': 1.0, 'mu_y': 0.0, 'sigma_y': 1.0}
    return batch, norm_const
Example 6
def extend_vector(p):
    """Extend n-1 arbitrary floats into n values in the range (0,1) that sum to one.
    Based on https://stackoverflow.com/questions/3589214/generate-random-numbers-summing-to-a-predefined-value
    """
    wsort = jnp.sort(1 / (1 + jnp.exp(-p)))
    return jnp.concatenate((wsort[0:1], jnp.diff(wsort), 1 - wsort[-1:]))
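The trick is that the sigmoid squashes the n-1 inputs into (0, 1), sorting orders them, and the concatenation of the smallest value, the successive differences, and the gap up to 1 telescopes to a sum of exactly one. A quick illustrative check (not part of the original source):

import jax.numpy as jnp

p = jnp.array([0.3, -1.2, 2.0])   # n - 1 = 3 unconstrained floats
w = extend_vector(p)              # 4 positive weights
print(w)                          # approx. [0.23, 0.34, 0.31, 0.12]
print(w.sum())                    # 1.0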
Example 7
def _flatten(params):
    """Flattens and concatenates all tensors in params to a single vector."""
    params, _ = tree_flatten(params)
    return jnp.concatenate([jnp.reshape(param, [-1]) for param in params])
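For reference, a small usage sketch; the parameter pytree below is made up for illustration:

import jax.numpy as jnp
from jax.tree_util import tree_flatten

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}  # hypothetical pytree
vec = _flatten(params)
print(vec.shape)  # (9,)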
Example 8
    f_distill = jnp.zeros((x_distill.shape[0], 10))
    # run fgb steps
    residual = jnp.zeros((x_train.shape[0], 10))
    print("start running local updates")
    for local_step in range(hyperparams.num_local_steps):
        # print(local_step)
        key, subkey = random.split(key)
        target = - vg_ce(f_data, y_train) + residual  # (negative functional gradient direction)
        params = regression_oracle(model, jnp.expand_dims(batch.x, axis=0), jnp.expand_dims(target, axis=0), subkey, hyperparams)
        # new_weight = hyperparams.lr_0 / (round * hyperparams.num_local_steps + local_step + 1) ** .5
        new_weight = hyperparams.lr_0 / (local_step + 1) ** .5 * jnp.ones((1, 1))
        # new_weight = hyperparams.lr_0
        # predict = jnp.concatenate([model.apply(opt.target, _x) for _x in xs])
        classifier = Classifier(params, jnp.ones((1, 1)))
        classifier_fn = get_classifier_fn(classifier)
        predict = jnp.concatenate([classifier_fn(x) for x in xs], axis=0)
        residual = target - predict
        # params_list.append(new_params)
        # weight_list.append(new_weight)
        f_data += predict * new_weight



        # print("test fgb result")
        # predict_test = model.apply(opt.target, x_test)
        predict_test = jnp.concatenate([classifier_fn(x) for x in xts], axis=0)
        f_x_test += predict_test * new_weight
        test_loss = v_ce(f_x_test, y_test)
        pred = jnp.argmax(f_x_test, axis=1)
        correct = jnp.true_divide(
            jnp.sum(jnp.equal(pred, jnp.reshape(y_test, pred.shape))),
Example 9
def plot_aleatoric_var_vs_time(gp, solver, traj_init, traj_opt=None):
    params = {
        "text.usetex": True,
        "text.latex.preamble": [
            "\\usepackage{amssymb}",
            "\\usepackage{amsmath}",
        ],
    }
    plt.rcParams.update(params)

    mixing_probs_init = jax.vmap(
        single_mogpe_mixing_probability,
        (0, None, None, None, None, None, None),
    )(
        traj_init[:, 0:2],
        gp.Z,
        gp.kernel,
        gp.mean_func,
        gp.q_mu,
        False,
        gp.q_sqrt,
    )
    if mixing_probs_init.shape[-1] == 1:
        mixing_probs_init = np.concatenate(
            [mixing_probs_init, 1 - mixing_probs_init], -1
        )
    if traj_opt is not None:
        mixing_probs_opt = jax.vmap(
            single_mogpe_mixing_probability,
            (0, None, None, None, None, None, None),
        )(
            traj_opt[:, 0:2],
            gp.Z,
            gp.kernel,
            gp.mean_func,
            gp.q_mu,
            False,
            gp.q_sqrt,
        )
        if mixing_probs_opt.shape[-1] == 1:
            mixing_probs_opt = np.concatenate(
                [mixing_probs_opt, 1 - mixing_probs_opt], -1
            )

    noise_vars = np.array(gp.noise_vars).reshape(-1, 1)
    var_init = mixing_probs_init @ noise_vars

    var_opt = 0
    if traj_opt is not None:
        var_opt = mixing_probs_opt @ noise_vars
        print("var opt")
        print(var_opt.shape)

    fig, ax = plt.subplots(1, 1, figsize=(6.4, 2.8))

    ax.set_xlabel("Time $t$")
    ax.set_ylabel(r"$\sum_{k=1}^K\Pr(\alpha=k|\mathbf{x}) (\sigma^{(k)})^2$")

    ax.plot(
        solver.times, var_init, color=color_init, label="Initial trajectory"
    )
    if traj_opt is not None:
        ax.plot(
            solver.times,
            var_opt,
            color=color_opt,
            label="Optimised trajectory",
        )
    ax.legend()
    sum_var_init = np.sum(var_init)
    sum_var_opt = np.sum(var_opt)
    print("Sum aleatoric var init = ", sum_var_init)
    print("Sum aleatoric var opt = ", sum_var_opt)
    return fig, ax
Example 10
def _cat(dim, *x):
    if len(x) == 1:
        return x[0]
    return np.concatenate(x, axis=dim)
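An illustrative usage note (not from the original source): with a single array the helper returns it unchanged, avoiding a needless copy; with several arrays it concatenates along the requested axis.

import numpy as np

a = np.ones((2, 3))
b = np.zeros((2, 3))
print(_cat(0, a).shape)     # (2, 3) - single input is returned as-is
print(_cat(0, a, b).shape)  # (4, 3)
print(_cat(1, a, b).shape)  # (2, 6)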