Code Example #1
def plot_mixing_prob_and_start_end(gp, solver, traj_opt=None):

    # plot original GP
    Xnew, xx, yy = create_grid(gp.X, N=961)

    # TODO need to change gp.X to gp.Z (and gp.q_mu) for sparse
    mixing_probs = jax.vmap(
        single_mogpe_mixing_probability,
        (0, None, None, None, None, None, None),
    )(
        Xnew,
        # gp.X,
        gp.Z,
        gp.kernel,
        gp.mean_func,
        gp.q_mu,
        False,
        gp.q_sqrt,
    )
    # mixing_probs = mogpe_mixing_probability(Xnew,
    #                                         gp.X,
    #                                         gp.kernel,
    #                                         mean_func=gp.mean_func,
    #                                         f=gp.Y,
    #                                         q_sqrt=gp.q_sqrt,
    #                                         full_cov=False)
    fig, ax = plt.subplots(1, 1)
    plot_contourf(
        fig,
        ax,
        xx,
        yy,
        # mixing_probs[:, 1:2],
        mixing_probs[:, 0:1],
        label="$\Pr(\\alpha=1 | \mathbf{x})$",
    )
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")

    plot_omitted_data(fig, ax, color="k")

    plot_start_and_end_pos(fig, ax, solver)

    plot_traj(
        fig,
        ax,
        solver.state_guesses,
        color=color_init,
        label="Initial trajectory",
    )
    if traj_opt is not None:
        plot_traj(
            fig, ax, traj_opt, color=color_opt, label="Optimised trajectory"
        )
    ax.legend()

    return fig, ax
Code Example #2
 def select_all_tensors(all_tensors, indices):
     return jax.vmap(select_tensor)(all_tensors, indices)
Code Example #3
def loss(r_surf, nn, sg, weight, p):
    l = r(p, theta)  # `r` and `theta` are defined elsewhere in the original script
    dl = l[:, :-1, :] - l[:, 1:, :]
    return np.sum(pmap_quadratic_flux(r_surf, nn, sg, dl,
                                      l[:, :-1, :])) + weight * np.sum(dl)


#######################################################################
# JAX/Python Function Transformations
#######################################################################

# functional programming (Python)
objective_function = partial(loss, r_surf, nn, sg, 0.1)

# vectorize (JAX)
biot_savart_surface = vmap(vmap(biot_savart, (0, None, None), 0),
                           (1, None, None), 1)

# automatic differentiation (JAX)
grad_func = grad(objective_function)  # d output / d input

# jit-compile (JAX)
jit_grad_func = jit(grad_func)

# SPMD parallelization (JAX)
pmap_quadratic_flux = pmap(quadratic_flux, in_axes=(0, 0, 0, None, None))

#######################################################################
# Optimization
#######################################################################

print("loss is {}".format(objective_function(p)))
Code Example #4
def plot_svgp_jacobian_mean(gp, solver, traj_opt=None):
    params = {
        "text.usetex": True,
        "text.latex.preamble": [
            "\\usepackage{amssymb}",
            "\\usepackage{amsmath}",
        ],
    }
    plt.rcParams.update(params)

    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    def gp_jacobian_all(x):
        if len(x.shape) == 1:
            x = x.reshape(1, -1)
        return gp_jacobian(
            x,
            gp.Z,
            gp.kernel,
            gp.mean_func,
            f=gp.q_mu,
            q_sqrt=gp.q_sqrt,
            full_cov=False,
        )

    mu_j, var_j = jax.vmap(gp_jacobian_all, in_axes=(0))(Xnew)
    print("gp jacobain mu var")
    print(mu_j.shape)
    print(var_j.shape)
    # mu = np.prod(mu, 1)
    # var = np.diagonal(var, axis1=-2, axis2=-1)
    # var = np.prod(var, 1)
    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        # mu,
        # var,
        llabel=r"$\mathbb{E}[h^{(1)}]$",
        rlabel=r"$\mathbb{V}[h^{(1)}]$",
    )

    for ax in axs:
        ax.quiver(Xnew[:, 0], Xnew[:, 1], mu_j[:, 0], mu_j[:, 1], color="k")

        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        plot_omitted_data(fig, ax, color="k")
        # ax.scatter(gp.X[:, 0], gp.X[:, 1])
        plot_traj(
            fig,
            ax,
            solver.state_guesses,
            color=color_init,
            label="Initial trajectory",
        )
        if traj_opt is not None:
            plot_traj(
                fig,
                ax,
                traj_opt,
                color=color_opt,
                label="Optimised trajectory",
            )
    axs[0].legend()
    return fig, axs
Code Example #5
File: static.py (Project: wjurayj/ergo)
def dist_logloss(dist_class, fixed_params, opt_params, data):
    dist = dist_class.from_params(fixed_params, opt_params, traceable=True)
    if data.size == 1:
        return -dist.logpdf(data)
    scores = vmap(dist.logpdf)(data)
    return -np.sum(scores)
Code Example #6
File: pull2.py (Project: Dream7-Kim/graduation_code)
def BW(phim,phiw,fm,fw,phi,f):
    a = np.moveaxis(vmap(partial(_BW,Sbc=phi))(phim,phiw),1,0)
    b = np.moveaxis(vmap(partial(_BW,Sbc=f))(fm,fw),1,0)
    result = dplex.deinsum('ij,ij->ij',a,b)
    return result
Code Example #7
def func_type1(S, A, is_training):
    # custom haiku function: s,a -> q(s,a)
    value = hk.Sequential([...])
    X = jax.vmap(jnp.kron)(S, A)  # or jnp.concatenate((S, A), axis=-1) or whatever you like
    return value(X)  # output shape: (batch_size,)
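As a shape check for the line above (independent of haiku; the arrays below are illustrative), vmapping jnp.kron over the batch axis turns per-example state and action vectors of sizes ds and da into features of size ds * da:

import jax
import jax.numpy as jnp

S = jnp.ones((4, 3))  # batch of 4 states, ds = 3
A = jnp.ones((4, 2))  # batch of 4 actions, da = 2

X_kron = jax.vmap(jnp.kron)(S, A)         # shape (4, 6)
X_cat = jnp.concatenate((S, A), axis=-1)  # shape (4, 5)
print(X_kron.shape, X_cat.shape)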
Code Example #8
 def inbetween(x):
     return 1 + vmap(innermost)(x)
Code Example #9
def _flip_state_batch_default_impl(hilb, key, states, indxs, scalar_rule):
    keys = jax.random.split(key, states.shape[0])
    res = jax.vmap(scalar_rule, in_axes=(None, 0, 0, 0), out_axes=0)(
        hilb, keys, states, indxs
    )
    return res
Code Example #10
File: random_test.py (Project: AnyaP/jax)
  def testIssue1789(self):
    def f(x):
      return random.gamma(random.PRNGKey(0), x)

    grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))
Code Example #11
def _random_state_batch_default_impl(hilb, key, size, dtype, scalar_rule):
    keys = jax.random.split(key, size)
    res = jax.vmap(scalar_rule, in_axes=(None, 0, None), out_axes=0)(hilb, keys, dtype)
    return res
Code Example #12
File: base.py (Project: tensorflow/probability)
def custom_layer_cau_batch(vals, dims, *, num_consts, in_tree, out_tree,
                           kwargs, **params):
    """Batching rule for layer_cau primitive to handle custom layers."""
    if all(dim is batching.not_mapped for dim in dims):
        return layer_cau_p.bind(*vals,
                                num_consts=num_consts,
                                in_tree=in_tree,
                                out_tree=out_tree,
                                kwargs=kwargs,
                                **params)
    orig_vals, orig_dims = vals, dims
    vals, dims = vals[num_consts:], dims[num_consts:]
    args = tree_util.tree_unflatten(in_tree, vals)
    dims_ = [not_mapped if dim is None else dim for dim in dims]
    layer, args = args[0], args[1:]
    if hasattr(layer, '_call_and_update_batched'):
        num_params = len(tree_util.tree_leaves(layer))
        layer_dims, arg_dims = dims_[:num_params], dims_[num_params:]
        if kwargs['has_rng']:
            rng, args = args[0], args[1:]
            rng_dim, arg_dims = arg_dims[0], arg_dims[1:]
        mapping_over_layer = all(layer_dim is not not_mapped
                                 for layer_dim in layer_dims)
        mapping_over_args = all(arg_dim is not not_mapped
                                for arg_dim in arg_dims)
        assert mapping_over_layer or mapping_over_args, (layer_dims, arg_dims)
        if not mapping_over_layer and mapping_over_args:
            if kwargs['has_rng']:
                if rng_dim is not not_mapped:
                    arg_dims = tuple(None if dim is not_mapped else dim
                                     for dim in arg_dims)
                    map_fun = jax.vmap(
                        lambda layer, rng, *args: _layer_cau_batched(
                            layer,
                            rng,
                            *args,  # pylint: disable=unnecessary-lambda, g-long-lambda
                            **kwargs),
                        in_axes=(None, rng_dim) + (None, ) * len(arg_dims))
                else:
                    map_fun = lambda layer, *args: _layer_cau_batched(
                        layer,
                        *args,  # pylint: disable=unnecessary-lambda, g-long-lambda
                        **kwargs)
                vals_out, update_out = map_fun(layer, rng, *args)
            else:
                vals_out, update_out = _layer_cau_batched(
                    layer, *args, **kwargs)
            vals_out = tree_util.tree_leaves(vals_out)
            update_out = tree_util.tree_leaves(update_out)
            assert all(dim == 0 for dim in arg_dims)
            # Assume dimensions out are consistent
            dims_out = (0, ) * len(vals_out)
            dims_update = (None, ) * len(update_out)
            assert len(vals_out) == len(dims_out)
            assert len(update_out) == len(dims_update)
            return vals_out + update_out, dims_out + dims_update
    batched, out_dims = primitive.batch_fun(
        lu.wrap_init(
            layer_cau_p.impl,
            dict(params,
                 num_consts=num_consts,
                 in_tree=in_tree,
                 out_tree=out_tree,
                 kwargs=kwargs)), orig_dims)
    return batched.call_wrapped(*orig_vals), out_dims()
Code Example #13
    y_at_t__backwards = solution_and_adjoint_variable_at_t[:, 0, :]
    adjoint_variable_at_t = solution_and_adjoint_variable_at_t[:, 1, :]

    J_entire_trajectory__reverse_classical_solve = loss_function_entire_trajectory(
        y_at_t__backwards, parameters_guess)

    # The initial condition was not dependent on the parameters
    d_u0__d_theta = jnp.zeros((2, 4))

    # We still have to do an integration over the time horizon
    dynamic_sensitivity_jacobian = lambda t, y, params: jnp.array(
        jax.jacobian(model, argnums=2)(t, y, *params)).T
    # The jit is not really advantageous, because we are only calling the function once
    vectorized_dynamic_sensitivity_jacobian = jax.jit(
        jax.vmap(dynamic_sensitivity_jacobian,
                 in_axes=(0, 1, None),
                 out_axes=2))
    del_f__del_theta__at_t = vectorized_dynamic_sensitivity_jacobian(
        t_discrete, y_at_t, parameters_guess)

    adjoint_variable_matmul_del_f__del_theta_at_t = jnp.einsum(
        "iN,ijN->jN", adjoint_variable_at_t, del_f__del_theta__at_t)

    d_J__d_theta__at_end__adjoint = (
        adjoint_variable_at_t[:, -1].T @ d_u0__d_theta + jnp.zeros((1, 4)) +
        integrate.trapezoid(adjoint_variable_matmul_del_f__del_theta_at_t,
                            t_discrete,
                            axis=-1))

    time_adjoint__at_end = time.time_ns() - time_adjoint__at_end
Code Example #14
    # print(f"t.shape {t.shape}")
    # print(f"X.shape) {X.shape}")
    input = jnp.concatenate((t, X), 0)  # M x D+1
    # print("here?")
    activations = input
    counter = 0
    for w, b in params[:-1]:
        outputs = jnp.dot(activations, w) + b
        activations = relu(outputs)
        counter += 1
    final_w, final_b = params[-1]
    u = jnp.dot(activations, final_w) + final_b
    return jnp.reshape(u, ())  #need scalar for grad


vforward = vmap(forward, in_axes=(None, 0, 0))


def grad_forward(params, t, X):
    gradu = grad(forward, argnums=(2))  #<wrt X only not params or t
    # partial(grad(loss), params)
    Du = gradu(params, t, X)
    return Du


vgrad_forward = vmap(grad_forward, in_axes=(None, 0, 0))

# def Dg_tf(X):  # M x D
#
#     g = g_tf(X)
#     Dg = torch.autograd.grad(outputs=[g], inputs=[X], grad_outputs=torch.ones_like(g), allow_unused=True,
Code Example #15
def map_product(
        metric_or_displacement: DisplacementOrMetricFn
) -> DisplacementOrMetricFn:
    """Vectorizes a metric or displacement function over all pairs."""
    return vmap(vmap(metric_or_displacement, (0, None), 0), (None, 0), 0)
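Usage sketch (the displacement function below is a stand-in, not part of the library snippet): the nested vmap lifts a function of two points to all pairs of points.

import jax.numpy as jnp
from jax import vmap

def displacement(Ra, Rb):
    # toy displacement between two points
    return Ra - Rb

pairwise = vmap(vmap(displacement, (0, None), 0), (None, 0), 0)

Ra = jnp.zeros((5, 3))
Rb = jnp.ones((4, 3))
print(pairwise(Ra, Rb).shape)  # (4, 5, 3), one displacement per pair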
Code Example #16
File: util.py (Project: kumsh/numpyro)
def fori_collect(lower, upper, body_fun, init_val, transform=identity,
                 progbar=True, return_last_val=False, collection_size=None, **progbar_opts):
    """
    This looping construct works like :func:`~jax.lax.fori_loop` but with the additional
    effect of collecting values from the loop body. In addition, this allows for
    post-processing of these samples via `transform`, and progress bar updates.
    Note that, `progbar=False` will be faster, especially when collecting a
    lot of samples. Refer to example usage in :func:`~numpyro.infer.mcmc.hmc`.

    :param int lower: the index to start the collective work. In other words,
        we will skip collecting the first `lower` values.
    :param int upper: number of times to run the loop body.
    :param body_fun: a callable that takes a collection of
        `np.ndarray` and returns a collection with the same shape and
        `dtype`.
    :param init_val: initial value to pass as argument to `body_fun`. Can
        be any Python collection type containing `np.ndarray` objects.
    :param transform: a callable to post-process the values returned by `body_fn`.
    :param progbar: whether to post progress bar updates.
    :param bool return_last_val: If `True`, the last value is also returned.
        This has the same type as `init_val`.
    :param int collection_size: Size of the returned collection. If not specified,
        the size will be ``upper - lower``. If the size is larger than
        ``upper - lower``, only the top ``upper - lower`` entries will be non-zero.
    :param `**progbar_opts`: optional additional progress bar arguments. A
        `diagnostics_fn` can be supplied which when passed the current value
        from `body_fun` returns a string that is used to update the progress
        bar postfix. Also a `progbar_desc` keyword argument can be supplied
        which is used to label the progress bar.
    :return: collection with the same type as `init_val` with values
        collected along the leading axis of `np.ndarray` objects.
    """
    assert lower <= upper
    collection_size = upper - lower if collection_size is None else collection_size
    assert collection_size >= upper - lower
    init_val_flat, unravel_fn = ravel_pytree(transform(init_val))

    @cached_by(fori_collect, body_fun, transform)
    def _body_fn(i, vals):
        val, collection, lower_idx = vals
        val = body_fun(val)
        i = np.where(i >= lower_idx, i - lower_idx, 0)
        collection = ops.index_update(collection, i, ravel_pytree(transform(val))[0])
        return val, collection, lower_idx

    collection = np.zeros((collection_size,) + init_val_flat.shape)
    if not progbar:
        last_val, collection, _ = fori_loop(0, upper, _body_fn, (init_val, collection, lower))
    else:
        diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
        progbar_desc = progbar_opts.pop('progbar_desc', lambda x: '')

        vals = (init_val, collection, device_put(lower))
        if upper == 0:
            # special case, only compiling
            jit(_body_fn)(0, vals)
        else:
            with tqdm.trange(upper) as t:
                for i in t:
                    vals = jit(_body_fn)(i, vals)
                    t.set_description(progbar_desc(i), refresh=False)
                    if diagnostics_fn:
                        t.set_postfix_str(diagnostics_fn(vals[0]), refresh=False)

        last_val, collection, _ = vals

    unravel_collection = vmap(unravel_fn)(collection)
    return (unravel_collection, last_val) if return_last_val else unravel_collection
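A hedged usage sketch (toy body function, not from the numpyro docs; assumes the helper imports used inside the snippet are in scope), collecting the iterates of a simple update while skipping the first two:

import jax.numpy as np

def body_fun(x):
    return x + 1.0

# fori_collect as defined above: runs body_fun 5 times, keeps the last 3 iterates
samples = fori_collect(2, 5, body_fun, np.zeros(3), progbar=False)
print(samples.shape)  # (3, 3)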
Code Example #17
def map_bond(
        metric_or_displacement: DisplacementOrMetricFn
) -> DisplacementOrMetricFn:
    """Vectorizes a metric or displacement function over bonds."""
    return vmap(metric_or_displacement, (0, 0), 0)
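Usage sketch (illustrative metric function): here vmap simply zips the two endpoint arrays, producing one value per bond.

import jax.numpy as jnp
from jax import vmap

def metric(Ra, Rb):
    return jnp.linalg.norm(Ra - Rb)

per_bond = vmap(metric, (0, 0), 0)

Ra = jnp.zeros((7, 3))
Rb = jnp.ones((7, 3))
print(per_bond(Ra, Rb).shape)  # (7,), one distance per bond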
Code Example #18
File: static.py (Project: wjurayj/ergo)
def logistic_mixture_logpdf(params, data):
    # params are assumed to be normalized
    if data.size == 1:
        return logistic_mixture_logpdf1(params, data)
    scores = vmap(partial(logistic_mixture_logpdf1, params))(data)
    return np.sum(scores)
Code Example #19
File: checkify_test.py (Project: John1Tang/jax)
 def test_nd_payloads(self):
     cf = checkify.checkify(lambda x, i: x[i], errors=checkify.index_checks)
     errs, _ = jax.vmap(cf)(jnp.ones((3, 2)), jnp.array([5, 0, 100]))
     self.assertIsNotNone(errs.get())
     self.assertIn("index 5", errs.get())
     self.assertIn("index 100", errs.get())
Code Example #20
 def _vmapped_projection(self, supports, weights, target_support):
     return jax.vmap(rainbow_agent.project_distribution,
                     in_axes=(0, 0, None))(supports, weights,
                                           target_support)
Code Example #21
 def update_chains(state, rng_key):
     keys = jax.random.split(rng_key, self.num_chains)
     new_states, info = jax.vmap(self.kernel, in_axes=(0, 0))(keys,
                                                              state)
     return new_states, info
Code Example #22
File: pull2.py (Project: Dream7-Kim/graduation_code)
def phase(_theta,_rho):
    result = vmap(_phase)(_theta,_rho)
    return result
Code Example #23
File: mnist-mlp.py (Project: adambozson/vae-jax)
def relu(x: Array) -> Array:
    return np.maximum(0, x)


@jit
def predict(params: List[Tuple[Array]], image: Array) -> Array:
    activations = image
    for w, b in params[:-1]:
        out = np.dot(w, activations) + b
        activations = relu(out)
    w, b = params[-1]
    logits = np.dot(w, activations) + b
    return logits - jax.scipy.special.logsumexp(logits)


batch_predict = vmap(predict, in_axes=(None, 0))


def loss(params: List[Tuple[Array]], images: Array, targets: Array) -> float:
    preds = batch_predict(params, images)
    return -np.sum(preds * targets)


@jit
def update(params: List[Tuple[Array]], x: Array, y: Array) -> List[Array]:
    grads = grad(loss)(params, x, y)
    return [
        (w - step_size * dw, b - step_size * db)
        for (w, b), (dw, db) in zip(params, grads)
    ]
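A minimal way to exercise the functions above (the initializer, step_size, and dummy data below are assumptions for illustration, not part of the original file):

import jax.numpy as np  # the snippet uses jax.numpy as np
from jax import random

step_size = 1e-3  # referenced as a global by update()

def init_params(key, sizes):
    # random dense layers matching predict()'s (w, b) layout
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, wkey, bkey = random.split(key, 3)
        params.append((0.01 * random.normal(wkey, (d_out, d_in)),
                       0.01 * random.normal(bkey, (d_out,))))
    return params

params = init_params(random.PRNGKey(0), [784, 128, 10])
images = np.ones((32, 784))
targets = np.eye(10)[np.zeros(32, dtype=np.int32)]  # dummy one-hot labels
params = update(params, images, targets)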
Code Example #24
 def mass_matrix_inv_mul(self, q: jnp.ndarray, v: jnp.ndarray,
                         **kwargs) -> jnp.ndarray:
     """Computes the product of the inverse mass matrix with a vector."""
     if self.kinetic_func_form in ("separable_net", "dep_net"):
         raise ValueError(
             "It is not possible to compute `M^-1 p` when using a "
             "network for the kinetic energy.")
     if self.kinetic_func_form in ("pure_quad", "embed_quad"):
         return v
     if self.kinetic_func_form == "matrix_diag_quad":
         if self.parametrize_mass_matrix:
             m_diag_log = hk.get_parameter(
                 "MassMatrixDiagLog",
                 shape=[self.system_dim],
                 init=hk.initializers.Constant(0.0))
             m_inv_diag = 1.0 / (jnp.exp(m_diag_log) + self.mass_eps)
         else:
             m_inv_diag_log = hk.get_parameter(
                 "InvMassMatrixDiagLog",
                 shape=[self.system_dim],
                 init=hk.initializers.Constant(0.0))
             m_inv_diag = jnp.exp(m_inv_diag_log) + self.mass_eps
         return m_inv_diag * v
     if self.kinetic_func_form == "matrix_quad":
         if self.parametrize_mass_matrix:
             m_triu = hk.get_parameter(
                 "MassMatrixU",
                 shape=[self.system_dim, self.system_dim],
                 init=hk.initializers.Identity())
             m_triu = jnp.triu(m_triu)
             m = jnp.matmul(m_triu.T, m_triu)
             m = m + self.mass_eps * jnp.eye(self.system_dim)
             solve = jnp.linalg.solve
             for _ in range(v.ndim + 1 - m.ndim):
                 solve = jax.vmap(solve, in_axes=(None, 0))
             return solve(m, v)
         else:
             m_inv_triu = hk.get_parameter(
                 "InvMassMatrixU",
                 shape=[self.system_dim, self.system_dim],
                 init=hk.initializers.Identity())
             m_inv_triu = jnp.triu(m_inv_triu)
             m_inv = jnp.matmul(m_inv_triu.T, m_inv_triu)
             m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim)
             return self.feature_matrix_vector(m_inv, v)
     if self.kinetic_func_form in ("matrix_dep_diag_quad",
                                   "matrix_dep_diag_embed_quad"):
         if self.parametrize_mass_matrix:
             m_diag_log = self.mass_matrix_net(q, **kwargs)
             m_inv_diag = 1.0 / (jnp.exp(m_diag_log) + self.mass_eps)
         else:
             m_inv_diag_log = self.mass_matrix_net(q, **kwargs)
             m_inv_diag = jnp.exp(m_inv_diag_log) + self.mass_eps
         return m_inv_diag * v
     if self.kinetic_func_form in ("matrix_dep_quad",
                                   "matrix_dep_embed_quad"):
         if self.parametrize_mass_matrix:
             m_triu = self.mass_matrix_net(q, **kwargs)
             m_triu = utils.triu_matrix_from_v(m_triu, self.system_dim)
             m = jnp.matmul(jnp.swapaxes(m_triu, -2, -1), m_triu)
             m = m + self.mass_eps * jnp.eye(self.system_dim)
             return jnp.linalg.solve(m, v)
         else:
             m_inv_triu = self.mass_matrix_net(q, **kwargs)
             m_inv_triu = utils.triu_matrix_from_v(m_inv_triu,
                                                   self.system_dim)
             m_inv = jnp.matmul(jnp.swapaxes(m_inv_triu, -2, -1),
                                m_inv_triu)
             m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim)
             return self.feature_matrix_vector(m_inv, v)
     raise NotImplementedError()
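The "matrix_quad" branch wraps jnp.linalg.solve in one vmap per extra leading batch axis of v, so a single mass matrix can be solved against arbitrarily batched vectors. A standalone sketch of that pattern (toy shapes, not tied to the class above):

import jax
import jax.numpy as jnp

m = 2.0 * jnp.eye(3)       # one (3, 3) mass matrix
v = jnp.ones((4, 5, 3))    # vectors with two leading batch axes

solve = jnp.linalg.solve
for _ in range(v.ndim + 1 - m.ndim):   # add one vmap per extra batch axis
    solve = jax.vmap(solve, in_axes=(None, 0))

print(solve(m, v).shape)  # (4, 5, 3); every entry equals 0.5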
Code Example #25
def map_product(metric_or_displacement):
  return vmap(vmap(metric_or_displacement, (0, None), 0), (None, 0), 0)
Code Example #26
def plot_aleatoric_var_vs_time(gp, solver, traj_init, traj_opt=None):
    params = {
        "text.usetex": True,
        "text.latex.preamble": [
            "\\usepackage{amssymb}",
            "\\usepackage{amsmath}",
        ],
    }
    plt.rcParams.update(params)

    mixing_probs_init = jax.vmap(
        single_mogpe_mixing_probability,
        (0, None, None, None, None, None, None),
    )(
        traj_init[:, 0:2],
        gp.Z,
        gp.kernel,
        gp.mean_func,
        gp.q_mu,
        False,
        gp.q_sqrt,
    )
    if mixing_probs_init.shape[-1] == 1:
        mixing_probs_init = np.concatenate(
            [mixing_probs_init, 1 - mixing_probs_init], -1
        )
    if traj_opt is not None:
        mixing_probs_opt = jax.vmap(
            single_mogpe_mixing_probability,
            (0, None, None, None, None, None, None),
        )(
            traj_opt[:, 0:2],
            gp.Z,
            gp.kernel,
            gp.mean_func,
            gp.q_mu,
            False,
            gp.q_sqrt,
        )
        if mixing_probs_opt.shape[-1] == 1:
            mixing_probs_opt = np.concatenate(
                [mixing_probs_opt, 1 - mixing_probs_opt], -1
            )

    noise_vars = np.array(gp.noise_vars).reshape(-1, 1)
    var_init = mixing_probs_init @ noise_vars

    var_opt = 0
    if traj_opt is not None:
        var_opt = mixing_probs_opt @ noise_vars
        print("var opt")
        print(var_opt.shape)

    fig, ax = plt.subplots(1, 1, figsize=(6.4, 2.8))

    ax.set_xlabel("Time $t$")
    ax.set_ylabel(r"$\sum_{k=1}^K\Pr(\alpha=k|\mathbf{x}) (\sigma^{(k)})^2$")

    ax.plot(
        solver.times, var_init, color=color_init, label="Initial trajectory"
    )
    if traj_opt is not None:
        ax.plot(
            solver.times,
            var_opt,
            color=color_opt,
            label="Optimised trajectory",
        )
    ax.legend()
    sum_var_init = np.sum(var_init)
    sum_var_opt = np.sum(var_opt)
    print("Sum aleatoric var init = ", sum_var_init)
    print("Sum aleatoric var opt = ", sum_var_opt)
    return fig, ax
Code Example #27
def map_bond(metric_or_displacement):
  return vmap(metric_or_displacement, (0, 0), 0)
Code Example #28
def compute_ssim(img0,
                 img1,
                 max_val,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03,
                 return_map=False):
  """Computes SSIM from two images.

  This function was modeled after tf.image.ssim, and should produce comparable
  output.

  Args:
    img0: array. An image of size [..., width, height, num_channels].
    img1: array. An image of size [..., width, height, num_channels].
    max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
    filter_size: int >= 1. Window size.
    filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
    k1: float > 0. One of the SSIM dampening parameters.
    k2: float > 0. One of the SSIM dampening parameters.
    return_map: Bool. If True, will cause the per-pixel SSIM "map" to be returned.

  Returns:
    Each image's mean SSIM, or a tensor of individual values if `return_map`.
  """
  # Construct a 1D Gaussian blur filter.
  hw = filter_size // 2
  shift = (2 * hw - filter_size + 1) / 2
  f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma)**2
  filt = jnp.exp(-0.5 * f_i)
  filt /= jnp.sum(filt)

  # Blur in x and y (faster than the 2D convolution).
  filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid")
  filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid")

  # Vmap the blurs to the tensor size, and then compose them.
  num_dims = len(img0.shape)
  map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1])
  for d in map_axes:
    filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d)
    filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d)
  filt_fn = lambda z: filt_fn1(filt_fn2(z))

  mu0 = filt_fn(img0)
  mu1 = filt_fn(img1)
  mu00 = mu0 * mu0
  mu11 = mu1 * mu1
  mu01 = mu0 * mu1
  sigma00 = filt_fn(img0**2) - mu00
  sigma11 = filt_fn(img1**2) - mu11
  sigma01 = filt_fn(img0 * img1) - mu01

  # Clip the variances and covariances to valid values.
  # Variance must be non-negative:
  sigma00 = jnp.maximum(0., sigma00)
  sigma11 = jnp.maximum(0., sigma11)
  sigma01 = jnp.sign(sigma01) * jnp.minimum(
      jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01))

  c1 = (k1 * max_val)**2
  c2 = (k2 * max_val)**2
  numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
  denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
  ssim_map = numer / denom
  ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims)))
  return ssim_map if return_map else ssim
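Usage sketch (random images, assuming the snippet's jnp/jsp imports are in scope):

import jax
import jax.numpy as jnp

key0, key1 = jax.random.split(jax.random.PRNGKey(0))
img0 = jax.random.uniform(key0, (64, 64, 3))
img1 = jax.random.uniform(key1, (64, 64, 3))

print(compute_ssim(img0, img1, max_val=1.0))  # scalar mean SSIM
print(compute_ssim(img0, img0, max_val=1.0))  # identical images, approx 1.0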
Code Example #29
 def wrapped_fn(Ra, Rb, **kwargs):
   return vmap(vmap(metric_or_displacement, (None, 0)))(-Ra, -Rb, **kwargs)
Code Example #30
File: jax.py (Project: Aurametrix/AScript)
import jax.numpy as np
from jax import grad, jit, vmap
from functools import partial

def predict(params, inputs):
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  preds = predict(params, inputs)
  return np.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function
perex_grads = jit(lambda params, inputs, targets:  # fast per-example gradients
                  vmap(partial(grad_fun, params))(inputs, targets))
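A quick shape check of the per-example gradient wrapper (toy parameters and data, illustrative only):

from jax import random

W = 0.1 * random.normal(random.PRNGKey(0), (3, 2))
b = np.zeros(2)
params = [(W, b)]

inputs = np.ones((5, 3))   # 5 examples, 3 features
targets = np.zeros((5, 2))

per_example = perex_grads(params, inputs, targets)
# one (W, b) gradient pair per example: shapes (5, 3, 2) and (5, 2)
print(per_example[0][0].shape, per_example[0][1].shape)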
Code Example #31
 def objective_vectorized(X):  # X is (N,2)
     f = vmap(objective)(X)
     return f