Example 1
 def wrapped(*args, **kwargs):
     in_specs = tree_util.tree_map(state.make_array_spec, args)
     kwargs = kwargs_util.filter_kwargs(template._spec, kwargs)  # pylint: disable=protected-access
     out_specs = template._spec(*in_specs, **kwargs)  # pylint: disable=protected-access
     return tree_util.tree_map(state.make_array_spec, out_specs)
Example 2
def _zeros_like_pytree(x):
    return tree_map(Zero.from_value, x)
Example 3
def gmres(A,
          b,
          x0=None,
          *,
          tol=1e-5,
          atol=0.0,
          restart=20,
          maxiter=None,
          M=None,
          solve_method='batched'):
    """
  GMRES solves the linear system A x = b for x, given A and b.

  A is specified as a function performing A(vi) -> vf = A @ vi, and in principle
  need not have any particular special properties, such as symmetry. However,
  convergence is often slow for nearly symmetric operators.

  Parameters
  ----------
  A: ndarray or function
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with
      the same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array, optional
      Starting guess for the solution. Must have the same structure as ``b``.
      If this is unspecified, zeroes are used.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``gmres``.
  restart : integer, optional
      Size of the Krylov subspace ("number of iterations") built between
      restarts. GMRES works by approximating the true solution x as its
      projection into a Krylov space of this dimension - this parameter
      therefore bounds the maximum accuracy achievable from any guess
      solution. Larger values increase both number of iterations and iteration
      cost, but may be necessary for convergence. The algorithm terminates
      early if convergence is achieved before the full subspace is built.
      Default is 20.
  maxiter : integer
      Maximum number of times to rebuild the size-``restart`` Krylov space
      starting from the solution found at the last iteration. If GMRES
      halts or is very slow, decreasing this parameter may help.
      Default is infinite.
  M : ndarray or function
      Preconditioner for A.  The preconditioner should approximate the
      inverse of A.  Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.
  solve_method : 'incremental' or 'batched'
      The 'incremental' solve method builds a QR decomposition for the Krylov
      subspace incrementally during the GMRES process using Givens rotations.
      This improves numerical stability and gives a free estimate of the
      residual norm that allows for early termination within a single "restart".
      In contrast, the 'batched' solve method solves the least squares problem
      from scratch at the end of each GMRES iteration. It does not allow for
      early termination, but has much less overhead on GPUs.

  See also
  --------
  scipy.sparse.linalg.gmres
  jax.lax.custom_linear_solve
  """

    if x0 is None:
        x0 = tree_map(jnp.zeros_like, b)
    if M is None:
        M = _identity
    A = _normalize_matvec(A)
    M = _normalize_matvec(M)

    b, x0 = device_put((b, x0))
    size = sum(bi.size for bi in tree_leaves(b))

    if maxiter is None:
        maxiter = 10 * size  # copied from scipy
    restart = min(restart, size)

    if tree_structure(x0) != tree_structure(b):
        raise ValueError('x0 and b must have matching tree structure: '
                         f'{tree_structure(x0)} vs {tree_structure(b)}')

    b_norm = _norm(b)
    atol = jnp.maximum(tol * b_norm, atol)

    Mb = M(b)
    Mb_norm = _norm(Mb)
    ptol = Mb_norm * jnp.minimum(1.0, atol / b_norm)

    if solve_method == 'incremental':
        gmres_func = _gmres_incremental
    elif solve_method == 'batched':
        gmres_func = _gmres_batched
    else:
        raise ValueError(
            f"invalid solve_method {solve_method}, must be either "
            "'incremental' or 'batched'")

    def _solve(A, b):
        return _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M,
                            gmres_func)

    x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)

    failed = jnp.isnan(_norm(x))
    info = jnp.where(failed, x=-1, y=0)
    return x, info
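A minimal usage sketch for the function above, assuming it is the `gmres` exported from `jax.scipy.sparse.linalg`; the pytree right-hand side and the block-diagonal operator below are made up for illustration.

import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

# The right-hand side may be a pytree; `A` must then map pytrees to pytrees
# with the same structure and shapes.
b = {'u': jnp.array([1.0, 2.0]), 'v': jnp.array([3.0])}

def A(x):
    # Block-diagonal, non-symmetric operator acting on the pytree.
    return {'u': jnp.array([[4.0, 1.0], [0.0, 3.0]]) @ x['u'],
            'v': 2.0 * x['v']}

x, info = gmres(A, b, tol=1e-6, restart=5, solve_method='incremental')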
Example 4
def clip_grads(grad_tree, max_norm):
    """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
    norm = l2_norm(grad_tree)
    normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm))
    return tree_map(normalize, grad_tree)
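A small self-contained sketch of how `clip_grads` might be used. The `l2_norm` helper is not shown above, so the version here (a global L2 norm over all leaves) is an assumption; both helpers resemble the ones in `jax.example_libraries.optimizers`.

import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map  # `tree_map` is used by `clip_grads` above

def l2_norm(tree):
    # Assumed helper: global L2 norm over all leaves of the pytree.
    return jnp.sqrt(sum(jnp.vdot(g, g) for g in tree_leaves(tree)))

grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([0.0])}  # global norm 5.0
clipped = clip_grads(grads, max_norm=1.0)  # each leaf scaled by 1.0 / 5.0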
Example 5
def _tree_ones_like(tree):
  def f(x):
    return jnp.ones_like(x)
  return tu.tree_map(f, tree)
Example 6
def _allgather(x, dim, size, index, axis_name, axis_index_groups=None):
    outs = tree_util.tree_map(partial(_expand, dim, size, index), x)
    return psum(outs, axis_name, axis_index_groups=axis_index_groups)
Example 7
    def predict_fn(
        t: ArrayOrScalar = None,
        fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
        fx_test_0: ArrayOrScalar = None,
        ntk_test_train: np.ndarray = None
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]:
        """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape in strictly increasing order.
        `t=None` is equivalent to `t=np.inf` and may not converge. Equivalent of
        training steps (but can be fractional).
      fx_train_or_state_0:
        either (a) output of the network at `t == 0` on the training set or (b)
        complete ODE state (`predict.ODEState`). Pass an ODE state if you want
        to operate on the full ODE state instead of output variables only
        (useful for inspecting auxiliary variables or resuming an optimizer with
        auxiliary variables from a specific state). Note that only optimizers
        with `momentum != None` currently have auxiliary variables. To
        initialize an ODE state from scratch, call
        `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed, an
        ODE state is returned. `fx_train_0=None` means to not compute
        predictions on the training set.
      fx_test_0:
        output of the network at `t == 0` on the test set. `fx_test_0=None`
        means to not compute predictions on the test set.
      ntk_test_train:
        kernel relating test data with training data. Must have the shape of
        `zip(y_test.shape, y_train.shape)` with `trace_axes` absent. Pass
        `ntk_test_train=None` if you only need predictions on the training set.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`.
      Alternatively can return an `ODEState` at time[s] `t`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `ntk_test_train` is `None`.
    """
        _check_inputs(fx_train_or_state_0, fx_test_0, ntk_test_train)

        t = np.array(t if t is not None else np.inf, dtype) * learning_rate
        t_shape = t.shape
        t = t.reshape((-1, ))

        # ODE solver requires `t[0]` to be the time where `fx_train_0` [and
        # `fx_test_0`] are evaluated, but also a strictly increasing sequence of
        # timesteps, so we always temporarily prepend an [almost] `0`.
        identity = lambda x: x
        t0 = np.where(t[0] == 0, np.full((1, ), -1e-24, t.dtype),
                      np.zeros((1, ), t.dtype))
        t = np.concatenate([t0, t])

        # Solve the ODE.
        fx_test_shape = _get_fx_test_shape(y_train, ntk_test_train, trace_axes)
        state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
        state_t = ode.odeint(get_dstate_dt(ntk_test_train), state_0, t)

        # Remove the added `t0`.
        trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
        trim_tree = lambda tree: tree_map(trim, tree)
        state_t = trim_tree(state_t)

        # `ODEState` -> `ODEState`
        if isinstance(fx_train_or_state_0, ODEState):
            return state_t

        # `np.ndarray` -> `np.ndarray`
        fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

        if fx_train_or_state_0 is not None and fx_test_0 is None:
            return fx_train_t
        if fx_test_0 is not None and fx_train_or_state_0 is None:
            return fx_test_t
        return fx_train_t, fx_test_t
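A hypothetical call based only on the signature and docstring above; the shapes and the factory that produced `predict_fn` are assumptions for illustration.

import jax.numpy as jnp

# Hypothetical shapes: 50 train points, 10 test points, 3 output units.
fx_train_0 = jnp.zeros((50, 3))        # network output on the train set at t == 0
fx_test_0 = jnp.zeros((10, 3))         # network output on the test set at t == 0
ntk_test_train = jnp.ones((10, 50))    # test/train kernel with `trace_axes` absent

# Predictions at three training times; returns `(fx_train_t, fx_test_t)`,
# each with a leading time dimension of shape (3,).
fx_train_t, fx_test_t = predict_fn(
    t=jnp.array([1.0, 10.0, 100.0]),
    fx_train_or_state_0=fx_train_0,
    fx_test_0=fx_test_0,
    ntk_test_train=ntk_test_train)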
Example 8
def tree_size(tree):
    """
    Returns the sum of the sizes of all leaves in the tree, i.e. the total
    number of scalars in the pytree.
    """
    return sum(tree_leaves(tree_map(lambda x: x.size, tree)))
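A minimal check of the definition above, assuming `jax.numpy` arrays as leaves.

import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map  # used by `tree_size` above

params = {'w': jnp.zeros((3, 4)), 'b': jnp.zeros((4,)), 'extra': [jnp.ones(2)]}
assert tree_size(params) == 3 * 4 + 4 + 2  # 18 scalars in total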
Example 9
    def _ApplyGraphNet(graph):
        """Applies a configured GraphNetwork to a graph.

    This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261

    There is one difference. For the nodes update the class aggregates over the
    sender edges and receiver edges separately. This is a bit more general than
    the algorithm described in the paper. The original behaviour can be
    recovered by using only the receiver edge aggregations for the update.

    In addition this implementation supports softmax attention over incoming
    edge features.

    Many popular Graph Neural Networks can be implemented as special cases of
    GraphNets, for more information please see the paper.

    Args:
      graph: a `GraphsTuple` containing the graph.

    Returns:
      Updated `GraphsTuple`.
    """
        # pylint: disable=g-long-lambda
        nodes, edges, receivers, senders, globals_, n_node, n_edge = graph
        # Equivalent to jnp.sum(n_node), but jittable
        sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
        sum_n_edge = senders.shape[0]
        if not tree.tree_all(
                tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
            raise ValueError(
                'All node arrays in nest must contain the same number of nodes.'
            )

        sent_attributes = tree.tree_map(lambda n: n[senders], nodes)
        received_attributes = tree.tree_map(lambda n: n[receivers], nodes)
        # Here we scatter the global features to the corresponding edges,
        # giving us tensors of shape [num_edges, global_feat].
        global_edge_attributes = tree.tree_map(
            lambda g: jnp.repeat(
                g, n_edge, axis=0, total_repeat_length=sum_n_edge), globals_)

        if update_edge_fn:
            edges = update_edge_fn(edges, sent_attributes, received_attributes,
                                   global_edge_attributes)

        if attention_logit_fn:
            logits = attention_logit_fn(edges, sent_attributes,
                                        received_attributes,
                                        global_edge_attributes)
            tree_calculate_weights = functools.partial(attention_normalize_fn,
                                                       segment_ids=receivers,
                                                       num_segments=sum_n_node)
            weights = tree.tree_map(tree_calculate_weights, logits)
            edges = attention_reduce_fn(edges, weights)

        if update_node_fn:
            sent_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_nodes_fn(e, senders, sum_n_node),
                edges)
            received_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_nodes_fn(e, receivers, sum_n_node
                                                       ), edges)
            # Here we scatter the global features to the corresponding nodes,
            # giving us tensors of shape [num_nodes, global_feat].
            global_attributes = tree.tree_map(
                lambda g: jnp.repeat(
                    g, n_node, axis=0, total_repeat_length=sum_n_node),
                globals_)
            nodes = update_node_fn(nodes, sent_attributes, received_attributes,
                                   global_attributes)

        if update_global_fn:
            n_graph = n_node.shape[0]
            graph_idx = jnp.arange(n_graph)
            # To aggregate nodes and edges from each graph to global features,
            # we first construct tensors that map the node to the corresponding graph.
            # For example, if you have `n_node=[1,2]`, we construct the tensor
            # [0, 1, 1]. We then do the same for edges.
            node_gr_idx = jnp.repeat(graph_idx,
                                     n_node,
                                     axis=0,
                                     total_repeat_length=sum_n_node)
            edge_gr_idx = jnp.repeat(graph_idx,
                                     n_edge,
                                     axis=0,
                                     total_repeat_length=sum_n_edge)
            # We use the aggregation function to pool the nodes/edges per graph.
            node_attributes = tree.tree_map(
                lambda n: aggregate_nodes_for_globals_fn(
                    n, node_gr_idx, n_graph), nodes)
            edge_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_globals_fn(
                    e, edge_gr_idx, n_graph), edges)
            # These pooled nodes are the inputs to the global update fn.
            globals_ = update_global_fn(node_attributes, edge_attributes,
                                        globals_)
        # pylint: enable=g-long-lambda
        return gn_graph.GraphsTuple(nodes=nodes,
                                    edges=edges,
                                    receivers=receivers,
                                    senders=senders,
                                    globals=globals_,
                                    n_node=n_node,
                                    n_edge=n_edge)
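If this is jraph's `GraphNetwork` apply function (which the structure suggests), a toy invocation through the public API might look as follows; the concatenation update functions and feature sizes are arbitrary choices for illustration, not part of the snippet above.

import jax.numpy as jnp
import jraph

# One graph with 3 nodes and 2 directed edges.
graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)),       # [num_nodes, node_feat]
    edges=jnp.ones((2, 5)),       # [num_edges, edge_feat]
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    globals=jnp.ones((1, 6)),     # [num_graphs, global_feat]
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]))

# Update functions that simply concatenate their inputs along the feature axis.
def update_edge_fn(edges, sent, received, globals_):
    return jnp.concatenate([edges, sent, received, globals_], axis=-1)

def update_node_fn(nodes, sent, received, globals_):
    return jnp.concatenate([nodes, sent, received, globals_], axis=-1)

def update_global_fn(node_attrs, edge_attrs, globals_):
    return jnp.concatenate([node_attrs, edge_attrs, globals_], axis=-1)

net = jraph.GraphNetwork(update_edge_fn=update_edge_fn,
                         update_node_fn=update_node_fn,
                         update_global_fn=update_global_fn)
updated = net(graph)  # a new GraphsTuple with updated nodes, edges and globals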
Example 10
 def reject_update(_):
     return (tree_map(jnp.zeros_like, updates), inner_state)
Example 11
def _is_on_cpu(x):
  return tree_all(tree_map(_arr_is_on_cpu, x))
Example 12
    def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
        # 0. Separate model and guide parameters, since only guide parameters are updated using Stein
        classic_uparams = {
            p: v
            for p, v in unconstr_params.items() if
            p not in self.guide_param_names or self.classic_guide_params_fn(p)
        }
        stein_uparams = {
            p: v
            for p, v in unconstr_params.items() if p not in classic_uparams
        }
        # 1. Collect each guide parameter into monolithic particles that capture correlations between parameter values across each individual particle
        stein_particles, unravel_pytree, unravel_pytree_batched = ravel_pytree(
            stein_uparams, batch_dims=1)
        particle_info = self._calc_particle_info(stein_uparams,
                                                 stein_particles.shape[0])

        # 2. Calculate loss and gradients for each parameter (broadcasting by num_loss_particles for increased variance reduction)
        def scaled_loss(rng_key, classic_params, stein_params):
            params = {**classic_params, **stein_params}
            loss_val = self.loss.loss(
                rng_key, params,
                handlers.scale(self.model, self.loss_temperature), self.guide,
                *args, **kwargs, **self.static_kwargs)
            return -loss_val

        kernel_particle_loss_fn = lambda ps: scaled_loss(
            rng_key, self.constrain_fn(classic_uparams),
            self.constrain_fn(unravel_pytree(ps)))
        loss, particle_ljp_grads = jax.vmap(
            jax.value_and_grad(kernel_particle_loss_fn))(stein_particles)
        classic_param_grads = jax.vmap(lambda ps: jax.grad(
            lambda cps: scaled_loss(rng_key, self.constrain_fn(cps),
                                    self.constrain_fn(unravel_pytree(ps))))
                                       (classic_uparams))(stein_particles)
        classic_param_grads = tree_map(jax.partial(np.mean, axis=0),
                                       classic_param_grads)

        # 3. Calculate kernel on monolithic particle
        kernel = self.kernel_fn.compute(stein_particles, particle_info,
                                        kernel_particle_loss_fn)

        # 4. Calculate the attractive force and repulsive force on the monolithic particles
        attractive_force = jax.vmap(lambda y: np.sum(jax.vmap(
            lambda x, x_ljp_grad: self._apply_kernel(kernel, x, y, x_ljp_grad)
        )(stein_particles, particle_ljp_grads),
                                                     axis=0))(stein_particles)
        repulsive_force = jax.vmap(lambda y: np.sum(jax.vmap(
            lambda x: self.repulsion_temperature * self._kernel_grad(
                kernel, x, y))(stein_particles),
                                                    axis=0))(stein_particles)
        particle_grads = (attractive_force +
                          repulsive_force) / self.num_stein_particles

        # 5. Decompose the monolithic particle forces back to concrete parameter values
        stein_param_grads = unravel_pytree_batched(particle_grads)

        # 6. Return loss and gradients (based on parameter forces)
        res_grads = tree_map(lambda x: -x, {
            **classic_param_grads,
            **stein_param_grads
        })
        return -np.mean(loss), res_grads
Example 13
 def z(t):
     p = tree_multimap(np.add, params, tree_map(lambda x: t * x, dfdw))
     return f(p, x1)
Example 14
    @partial(pmap, axis_name='batch')
    def spmd_update(params, batch):
        grads = grad(loss)(params, batch)
        # We compute the total gradients, summing across the device-mapped axis,
        # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum.
        grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch'))
                 for dw, db in grads]
        return [(w - step_size * dw, b - step_size * db)
                for (w, b), (dw, db) in zip(params, grads)]

    # We replicate the parameters so that the constituent arrays have a leading
    # dimension of size equal to the number of devices we're pmapping over.
    init_params = init_random_params(param_scale, layer_sizes)
    replicate_array = lambda x: onp.broadcast_to(x, (num_devices, ) + x.shape)
    replicated_params = tree_map(replicate_array, init_params)

    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            replicated_params = spmd_update(replicated_params, next(batches))
        epoch_time = time.time() - start_time

        # We evaluate using the jitted `accuracy` function (not using pmap) by
        # grabbing just one of the replicated parameter values.
        params = tree_map(lambda x: x[0], replicated_params)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy {}".format(train_acc))
        print("Test set accuracy {}".format(test_acc))
Example 15
def _get_value_from_index(xs, i):
    return tree_map(lambda x: x[i], xs)
Example 16
def _cx2flt(c):
    # convert a complex-valued tree to a pair of real-valued trees
    return tree_multimap(lambda *xs: tuple(xs),
                         *tree_map(lambda x: (x.real, x.imag), c))
Example 17
    def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
        """
        Run the MCMC samplers and collect samples.

        :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
            For multi-chains, a batch of `num_chains` keys can be supplied. If `rng_key`
            does not have batch_size, it will be split into a batch of `num_chains` keys.
        :param args: Arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init` method.
            These are typically the arguments needed by the `model`.
        :param extra_fields: Extra fields (aside from `z`, `diverging`) from :data:`numpyro.infer.mcmc.HMCState`
            to collect during the MCMC run.
        :type extra_fields: tuple or list
        :param init_params: Initial parameters to begin sampling. The type must be consistent
            with the input type to `potential_fn`.
        :param kwargs: Keyword arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init`
            method. These are typically the keyword arguments needed by the `model`.

        .. note:: JAX allows Python code to continue even when the compiled code has not finished yet.
            This can cause trouble when trying to profile the code for speed.
            See https://jax.readthedocs.io/en/latest/async_dispatch.html and
            https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling jax programs.
        """
        self._args = args
        self._kwargs = kwargs
        init_state = self._get_cached_init_state(rng_key, args, kwargs)
        if self.num_chains > 1 and rng_key.ndim == 1:
            rng_key = random.split(rng_key, self.num_chains)

        if self._warmup_state is not None:
            self._set_collection_params(0, self.num_samples, self.num_samples,
                                        "sample")
            init_state = self._warmup_state._replace(rng_key=rng_key)

        if init_params is not None and self.num_chains > 1:
            prototype_init_val = tree_flatten(init_params)[0][0]
            if jnp.shape(prototype_init_val)[0] != self.num_chains:
                raise ValueError(
                    '`init_params` must have the same leading dimension'
                    ' as `num_chains`.')
        assert isinstance(extra_fields, (tuple, list))
        collect_fields = tuple(
            set((self._sample_field, ) + tuple(self._default_fields) +
                tuple(extra_fields)))
        partial_map_fn = partial(self._single_chain_mcmc,
                                 args=args,
                                 kwargs=kwargs,
                                 collect_fields=collect_fields)
        map_args = (rng_key, init_state, init_params)
        if self.num_chains == 1:
            states_flat, last_state = partial_map_fn(map_args)
            states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
        else:
            if self.chain_method == 'sequential':
                states, last_state = _laxmap(partial_map_fn, map_args)
            elif self.chain_method == 'parallel':
                states, last_state = pmap(partial_map_fn)(map_args)
            else:
                assert self.chain_method == 'vectorized'
                states, last_state = partial_map_fn(map_args)
                # swap num_samples x num_chains to num_chains x num_samples
                states = tree_map(lambda x: jnp.swapaxes(x, 0, 1), states)
            states_flat = tree_map(
                lambda x: jnp.reshape(x, (-1, ) + x.shape[2:]), states)
        self._last_state = last_state
        self._states = states
        self._states_flat = states_flat
        self._set_collection_params()
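For context, this `run` method is usually reached through NumPyro's public API, roughly as below; the model and the warmup/sample counts are placeholders chosen for illustration.

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(data):
    mu = numpyro.sample('mu', dist.Normal(0.0, 10.0))
    numpyro.sample('obs', dist.Normal(mu, 1.0), obs=data)

data = random.normal(random.PRNGKey(1), (100,)) + 3.0
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000, num_chains=1)
mcmc.run(random.PRNGKey(0), data, extra_fields=('potential_energy',))
samples = mcmc.get_samples()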
Example 18
def tree_broadcast(full_treedef, tree, is_leaf=None):
    full_tree = tree_fill(0, full_treedef)
    return tree_map(tree_fill_like, tree, full_tree, is_leaf=is_leaf)
Example 19
 def is_array(x):
   return tree_all(tree_map(
       lambda x: isinstance(x, (onp.ndarray, np.ndarray)), x))
Example 20
def sdScene(scene, pos):
    distances = vmap(partial(sdObj, pos))(scene)
    d = np.min(distances)
    i = np.argmin(distances)
    return (tree_map(lambda x: x[i], scene), d)
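`sdObj` and the scene layout are not shown above; a hypothetical sphere SDF and a struct-of-arrays scene make the pattern concrete. Everything below is an assumption about the original code, using the same `np` alias for `jax.numpy` that `sdScene` appears to rely on.

import jax.numpy as np   # alias apparently used by `sdScene` above
from jax import vmap
from jax.tree_util import tree_map
from functools import partial

# Hypothetical per-object SDF: a sphere given by {'center': (3,), 'radius': ()}.
def sdObj(pos, obj):
    return np.linalg.norm(pos - obj['center']) - obj['radius']

# A "scene" as a struct of arrays: the leading axis enumerates objects, so the
# `vmap` inside `sdScene` maps over objects.
scene = {'center': np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]),
         'radius': np.array([1.0, 0.5])}

nearest, d = sdScene(scene, np.array([0.0, 0.0, 2.5]))
# d == 1.5 (the unit sphere is closest); `nearest` is that sphere's entry of `scene`.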
Example 21
def get_shallow_tree(is_leaf, tree):
    """Returns a shallow tree, expanding only when is_leaf(subtree) is False."""
    return tree_util.tree_map(is_leaf, tree, is_leaf=is_leaf)
Example 22
        for i in range(num_devices):
            params = jax.device_put(get_params(op_state), jax.devices()[i])
            _grad = jax.device_put(
                grad(loss)(params, batch_list[i]),
                jax.devices()[i])
        _grad = jax.device_put(_grad, jax.devices()[0])
        k = jax.device_put(k, jax.devices()[0])
        op_state = opt_update(k, _grad, op_state)
        return op_state

    replicate_array = lambda x: jnp.broadcast_to(x, (num_devices, ) + x.shape)
    allreduce = True

    if allreduce:
        op_state = opt_init(init_params)
        replicated_op_state = tree_map(replicate_array, op_state)
        for i in range(num_steps):
            #params, treedef = tree_flatten(params)
            if i == 3:
                cu_prof_start()
            new_batch = next(batches)
            start_time = time.time()
            replicated_op_state = allreduce_spmd_update(
                jnp.array([i] * num_devices), replicated_op_state, new_batch)
            if i == 3:
                cu_prof_stop()
            end_time = time.time() - start_time
            print("time:", end_time)
    else:
        op_state = jax.device_put(opt_init(init_params), jax.devices()[0])
Example 23
def _tree_copy(tree):
    def f(x):
        # JAX arrays are immutable and have no in-place `fill`; building a
        # fresh array from `x` yields a copy of each leaf.
        return jnp.array(x)
    return tu.tree_map(f, tree)
Example 24
 def ps_pre_process(op_state):
     params = get_params(op_state)
     replicated_op_params = tree_map(replicate_array, params)
     return replicated_op_params
Example 25
 def in_avals(self):
   """Tree of input avals."""
   return tree_util.tree_map(lambda x: x.aval, self.args_info)
Example 26
 def ps_post_process(grads, op_state, i):
     grads = tree_map(lambda x: jnp.sum(x, axis=0), grads)
     op_state = opt_update(i, grads, op_state)
     return op_state
Example 27
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
    """Use Conjugate Gradient iteration to solve ``Ax = b``.

  The numerics of JAX's ``cg`` should exactly match SciPy's ``cg`` (up to
  numerical precision), but note that the interface is slightly different: you
  need to supply the linear operator ``A`` as a function instead of a sparse
  matrix or ``LinearOperator``.

  Derivatives of ``cg`` are implemented via implicit differentiation with
  another ``cg`` solve, rather than by differentiating *through* the solver.
  They will be accurate only if both solves converge.

  Parameters
  ----------
  A: ndarray or function
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)``. ``A`` must represent a
      hermitian, positive definite matrix, and must return array(s) with the
      same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array
      Starting guess for the solution. Must have the same structure as ``b``.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
  maxiter : integer
      Maximum number of iterations.  Iteration will stop after maxiter
      steps even if the specified tolerance has not been achieved.
  M : ndarray or function
      Preconditioner for A.  The preconditioner should approximate the
      inverse of A.  Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.

  See also
  --------
  scipy.sparse.linalg.cg
  jax.lax.custom_linear_solve
  """
    if x0 is None:
        x0 = tree_map(jnp.zeros_like, b)

    b, x0 = device_put((b, x0))

    if maxiter is None:
        size = sum(bi.size for bi in tree_leaves(b))
        maxiter = 10 * size  # copied from scipy

    if M is None:
        M = _identity
    A = _normalize_matvec(A)
    M = _normalize_matvec(M)

    if tree_structure(x0) != tree_structure(b):
        raise ValueError('x0 and b must have matching tree structure: '
                         f'{tree_structure(x0)} vs {tree_structure(b)}')

    if _shapes(x0) != _shapes(b):
        raise ValueError('arrays in x0 and b must have matching shapes: '
                         f'{_shapes(x0)} vs {_shapes(b)}')

    cg_solve = partial(_cg_solve,
                       x0=x0,
                       tol=tol,
                       atol=atol,
                       maxiter=maxiter,
                       M=M)

    # real-valued positive-definite linear operators are symmetric
    def real_valued(x):
        return not issubclass(x.dtype.type, np.complexfloating)

    symmetric = all(map(real_valued, tree_leaves(b)))
    x = lax.custom_linear_solve(A,
                                b,
                                solve=cg_solve,
                                transpose_solve=cg_solve,
                                symmetric=symmetric)
    info = None  # TODO(shoyer): return the real iteration count here
    return x, info
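A minimal sketch of calling `cg` on a small symmetric positive definite system through `jax.scipy.sparse.linalg`, with the operator given as a matvec closure; the 2x2 system is made up for illustration.

import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Symmetric positive definite 2x2 system.
A_mat = jnp.array([[4.0, 1.0],
                   [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

x, info = cg(lambda v: A_mat @ v, b, tol=1e-8)
# `x` agrees with jnp.linalg.solve(A_mat, b) up to the tolerance; `info` is None.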
Example 28
    def test_jit_or_pmap_broadcast(self):
        def kernel_fn(x1,
                      x2,
                      do_flip,
                      keys,
                      do_square,
                      params,
                      _unused=None,
                      p=0.65):
            res = np.abs(np.matmul(x1, x2))
            if do_square:
                res *= res
            if do_flip:
                res = -res

            res *= random.uniform(keys) * p
            return [res, params]

        params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
        x2 = np.arange(0, 10).reshape((10, ))
        keys = random.PRNGKey(1)

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=0)
        x1 = np.arange(0, 10).reshape((1, 10))
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=0):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=True,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=True)
                    self.assertAllClose(res_1, res_2, True)

        test_utils.stub_out_pmap(batch, 1)
        x1 = np.arange(0, 10).reshape((1, 10))
        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=1)
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=1):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=False,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None)
                    self.assertAllClose(res_1[0], res_2[0], True)
                    self.assertAllClose(
                        tree_map(partial(np.expand_dims, axis=0), res_1[1]),
                        res_2[1], True)

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=2)
        x1 = np.arange(0, 20).reshape((2, 10))
        test_utils.stub_out_pmap(batch, 2)

        def broadcast(arg):
            return np.broadcast_to(arg, (2, ) + arg.shape)

        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=2):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      p=0.2)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None,
                                              p=0.2)
                    self.assertAllClose(res_1[0][0], res_2[0][0], True)
                    self.assertAllClose(res_1[0][1], res_2[0][1], True)
                    self.assertAllClose(tree_map(broadcast, res_1[1]),
                                        res_2[1], True)
Example 29
def _mul(scalar, tree):
    return tree_map(partial(operator.mul, scalar), tree)
Example 30
 def wrapped(*args, **kwargs):
     in_specs = tree_util.tree_map(state.make_array_spec, args)
     out_specs = layer.spec(*in_specs, **kwargs)
     return tree_util.tree_map(state.make_array_spec, out_specs)