Example #1
    def optimize_params(self, p0, num_iters, step_size, tolerance, verbose):
        """

        Perform gradient descent using JAX optimizers.
        """

        opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
        opt_state = opt_init(p0)

        @jit
        def step(i, opt_state):
            p = get_params(opt_state)
            g = grad(self.negative_log_evidence)(p)
            return opt_update(i, g, opt_state)

        cost_list = []
        params_list = []

        if verbose:
            self.print_progress_header(p0)

        for i in range(num_iters):

            opt_state = step(i, opt_state)
            params_list.append(get_params(opt_state))
            cost_list.append(self.negative_log_evidence(params_list[-1]))

            if verbose:
                if i % verbose == 0:
                    self.print_progress(i, params_list[-1], cost_list[-1])

            if len(params_list) > tolerance:

                if np.all(
                        np.array(cost_list[1:]) - np.array(cost_list[:-1]) > 0):
                    params = params_list[0]
                    if verbose:
                        print(
                            'Stop: cost has been monotonically increasing for {} steps.'
                            .format(tolerance))
                    break
                elif np.all(
                        np.array(cost_list[:-1]) -
                        np.array(cost_list[1:]) < 1e-5):
                    params = params_list[-1]
                    if verbose:
                        print(
                            'Stop: cost has stopped changing for {} steps.'.
                            format(tolerance))
                    break
                else:
                    params_list.pop(0)
                    cost_list.pop(0)

        else:

            params = params_list[-1]
            if verbose:
                print('Stop: reached {0} steps, final cost={1:.5f}.'.format(
                    num_iters, cost_list[-1]))

        return params
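The snippet above depends on class state; here is a minimal standalone sketch of the same `(opt_init, opt_update, get_params)` optimizer-triple pattern, with a toy quadratic standing in for `negative_log_evidence` (`jax.example_libraries.optimizers` is the current home of the former `jax.experimental.optimizers`):

import jax.numpy as jnp
from jax import grad, jit
from jax.example_libraries import optimizers

def toy_cost(p):
    # stand-in for self.negative_log_evidence
    return jnp.sum((p - 3.0) ** 2)

opt_init, opt_update, get_params = optimizers.adam(step_size=0.1)
opt_state = opt_init(jnp.zeros(2))

@jit
def step(i, opt_state):
    g = grad(toy_cost)(get_params(opt_state))
    return opt_update(i, g, opt_state)

for i in range(200):
    opt_state = step(i, opt_state)

print(get_params(opt_state))  # approaches [3., 3.]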
Example #2
 def cond_fn(curr):
     return jnp.bitwise_and(
         curr.i < SineBivariateVonMises.max_sample_iter,
         jnp.logical_not(jnp.all(curr.done)),
     )
Example #3
def initialize_model(rng_key,
                     model,
                     init_strategy=init_to_uniform,
                     dynamic_args=False,
                     model_args=(),
                     model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided with `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    substituted_model = substitute(
        seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
        substitute_fn=init_strategy)
    inv_transforms, replay_model, has_enumerate_support, model_trace = _get_model_transforms(
        substituted_model, model_args, model_kwargs)
    # substitute param sites from model_trace into the model so we
    # don't need to regenerate the parameters of `numpyro.module`
    model = substitute(model,
                       data={
                           k: site["value"]
                           for k, site in model_trace.items()
                           if site["type"] in ["param", "plate"]
                       })
    constrained_values = {
        k: v['value']
        for k, v in model_trace.items() if v['type'] == 'sample'
        and not v['is_observed'] and not v['fn'].is_discrete
    }

    if has_enumerate_support:
        from numpyro.contrib.funsor import config_enumerate, enum

        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            _validate_model(model_trace)
            model = enum(config_enumerate(model), -max_plate_nesting - 1)

    potential_fn, postprocess_fn = get_potential_fn(model,
                                                    inv_transforms,
                                                    replay_model=replay_model,
                                                    enum=has_enumerate_support,
                                                    dynamic_args=dynamic_args,
                                                    model_args=model_args,
                                                    model_kwargs=model_kwargs)

    init_strategy = init_strategy if isinstance(init_strategy,
                                                partial) else init_strategy()
    if (init_strategy.func is init_to_value) and not replay_model:
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms,
                                            init_values,
                                            invert=True)
        init_strategy = _init_to_unconstrained_value(
            values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms,
                                    constrained_values,
                                    invert=True)
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        model,
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params)

    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            with numpyro.validation_enabled(), trace() as tr:
                # validate parameters
                substituted_model(*model_args, **model_kwargs)
                # validate values
                for site in tr.values():
                    if site['type'] == 'sample':
                        with warnings.catch_warnings(record=True) as ws:
                            site['fn']._validate_sample(site['value'])
                        if len(ws) > 0:
                            for w in ws:
                                # add site information to the warning message
                                w.message.args = ("Site {}: {}".format(site["name"], w.message.args[0]),) \
                                    + w.message.args[1:]
                                warnings.showwarning(w.message,
                                                     w.category,
                                                     w.filename,
                                                     w.lineno,
                                                     file=w.file,
                                                     line=w.line)
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn,
                     postprocess_fn, model_trace)
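A hedged usage sketch; the toy model and the `ParamInfo` field name `z` below are assumptions based on NumPyro's public API, not part of the snippet:

import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer.util import initialize_model

def model():
    numpyro.sample("x", dist.Normal(0.0, 1.0))

model_info = initialize_model(random.PRNGKey(0), model)
init_params = model_info.param_info.z   # unconstrained initial values
potential_fn = model_info.potential_fn  # maps unconstrained params to energy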
Example #4
 def all(self, boolean_tensor, axis=None, keepdims=False):
     if isinstance(boolean_tensor, (tuple, list)):
         boolean_tensor = jnp.stack(boolean_tensor)
     return jnp.all(boolean_tensor, axis=axis, keepdims=keepdims)
Example #5
 def multi_errors(x):
     checkify.check(jnp.all(x < 0), "must be negative!")  # ASSERT
     x = x / 0  # DIV
     x = jnp.sin(x)  # NAN
     x = x[500]  # OOB
     return x
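These checks only fire once the function is functionalized; a minimal usage sketch with `checkify` (input chosen so the first check fails):

import jax.numpy as jnp
from jax.experimental import checkify

# wrap the function so all tracked error types are recorded instead of ignored
checked_fn = checkify.checkify(multi_errors, errors=checkify.all_checks)
err, out = checked_fn(jnp.ones(4))
err.throw()  # raises the first tracked error, here the ASSERT check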
Example #6
 def testAlphaDerivativeIsPositive(self):
     # Assert that d_loss / d_alpha > 0.
     _, _, _, alpha, _, _, d_alpha, _ = self._precompute_lossfun_inputs()
     mask = jnp.isfinite(alpha)
     self.assertTrue(
         jnp.all(d_alpha[mask] > (-300. * jnp.finfo(jnp.float32).eps)))
Example #7
 def __eq__(self, other):
     assert False, "not yet implemented checking equality of reduce"
     return (isinstance(other, self.__class__)
             and np.all(other.insp_pts == self.insp_pts)
             and other.k == self.k)
Example #8
 def solve(self, result:FiniteVec):
     if np.all(self.outp_feat.inspace_points == result.inspace_points):
         s = np.linalg.solve(self.matr @ inner(self.inp_feat, self.inp_feat), result.prefactors)
         return FiniteVec.construct_RKHS_Elem(result.k, result.inspace_points, s)
     else:
         raise AssertionError("input space points do not match")
Example #9
def neighbor_list(displacement_or_metric: DisplacementOrMetricFn,
                  box_size: Box,
                  r_cutoff: float,
                  dr_threshold: float,
                  capacity_multiplier: float = 1.25,
                  cell_size: float = None,
                  disable_cell_list: bool = False,
                  mask_self: bool = True,
                  **static_kwargs) -> NeighborFn:
    """Returns a function that builds a list neighbors for collections of points.

  Neighbor lists must balance the need to be jit compatible with the fact that
  under a jit the maximum number of neighbors cannot change (owing to static
  shape requirements). To deal with this, our `neighbor_list` returns a
  function `neighbor_fn` that can operate in two modes: 1) create a new
  neighbor list or 2) update an existing neighbor list. Case 1) cannot be jit
  compiled, and it creates a neighbor list with a maximum neighbor count of the
  current neighbor count times `capacity_multiplier`. Case 2) is jit
  compatible; if any particle has more neighbors than the maximum, the
  `did_buffer_overflow` bit will be set to `True` and a new neighbor list will
  need to be created.

  Here is a typical example of a simulation loop with neighbor lists:

  >>> init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
  >>> exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift, 1e-3)
  >>>
  >>> nbrs = neighbor_fn(R)
  >>> state = init_fn(random.PRNGKey(0), R, neighbor_idx=nbrs.idx)
  >>>
  >>> def body_fn(i, state):
  >>>   state, nbrs = state
  >>>   nbrs = neighbor_fn(state.position, nbrs)
  >>>   state = apply_fn(state, neighbor_idx=nbrs.idx)
  >>>   return state, nbrs
  >>>
  >>> step = 0
  >>> for _ in range(20):
  >>>   new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
  >>>   if nbrs.did_buffer_overflow:
  >>>     nbrs = neighbor_fn(state.position)
  >>>   else:
  >>>     state = new_state
  >>>     step += 1

  Args:
    displacement_or_metric: A function `d(R_a, R_b)` that computes the
      displacement (or distance, if a metric) between pairs of points.
    box_size: Either a float specifying the size of the box or an array of
      shape [spatial_dim] specifying the box size in each spatial dimension.
    r_cutoff: A scalar specifying the neighborhood radius.
    dr_threshold: A scalar specifying the maximum distance particles can move 
      before rebuilding the neighbor list.
    capacity_multiplier: A floating point scalar specifying the fractional
      increase in maximum neighborhood occupancy we allocate compared with the
      maximum in the example positions.
    cell_size: An optional scalar specifying the size of cells in the cell list
      used in an intermediate step.
    disable_cell_list: An optional boolean. If set to True then the neighbor
      list is constructed using only distances. This can be useful for
      debugging but should generally be left as False.
    mask_self: An optional boolean. Determines whether points can consider
      themselves to be their own neighbors.
    **static_kwargs: kwargs that get threaded through the calculation of
      example positions.
  Returns:
    A pair. The first element is a NeighborList containing the current neighbor
    list. The second element is a function
    `neighbor_list_fn(R, neighbor_list=None)` that will update the neighbor
    list. If neighbor_list is None then the function will construct a new
    neighbor list whose capacity is inferred from R. If neighbor_list is given
    then it will update the neighbor list (with fixed capacity) if any particle
    has moved more than dr_threshold / 2. Note that only
    `neighbor_list_fn(R, neighbor_list)` can be `jit` since it keeps array
    shapes fixed.
  """
    box_size = f32(box_size)

    cutoff = r_cutoff + dr_threshold
    cutoff_sq = cutoff**2
    threshold_sq = (dr_threshold / f32(2))**2
    metric_sq = _displacement_or_metric_to_metric_sq(displacement_or_metric)

    if cell_size is None:
        cell_size = cutoff

    use_cell_list = np.all(cell_size < box_size / 3.) and not disable_cell_list

    @jit
    def candidate_fn(R, **kwargs):
        return np.broadcast_to(
            np.reshape(np.arange(R.shape[0]), (1, R.shape[0])),
            (R.shape[0], R.shape[0]))

    @jit
    def cell_list_candidate_fn(cl, R, **kwargs):
        N, dim = R.shape

        R = cl.position_buffer
        idx = cl.id_buffer

        cell_idx = [idx]

        for dindex in _neighboring_cells(dim):
            if onp.all(dindex == 0):
                continue
            cell_idx += [_shift_array(idx, dindex)]

        cell_idx = np.concatenate(cell_idx, axis=-2)
        cell_idx = cell_idx[..., np.newaxis, :, :]
        cell_idx = np.broadcast_to(cell_idx,
                                   idx.shape[:-1] + cell_idx.shape[-2:])

        def copy_values_from_cell(value, cell_value, cell_id):
            scatter_indices = np.reshape(cell_id, (-1, ))
            cell_value = np.reshape(cell_value, (-1, ) + cell_value.shape[-2:])
            return ops.index_update(value, scatter_indices, cell_value)

        # NOTE(schsam): Currently, this makes a verlet list that is larger than
        # needed since the idx buffer inherets its size from the cell-list. In
        # three-dimensions this seems to translate into an occupancy of ~70%. We
        # can make this more efficient by shrinking the verlet list at the cost of
        # another sort. However, this seems possibly less efficient than just
        # computing everything.

        neighbor_idx = np.zeros((N + 1, ) + cell_idx.shape[-2:], np.int32)
        neighbor_idx = copy_values_from_cell(neighbor_idx, cell_idx, idx)
        return neighbor_idx[:-1, :, 0]

    @jit
    def prune_neighbor_list(R, idx, **kwargs):
        d = partial(metric_sq, **kwargs)
        d = vmap(vmap(d, (None, 0)))

        N = R.shape[0]
        neigh_R = R[idx]
        dR = d(R, neigh_R)

        mask = np.logical_and(dR < cutoff_sq, idx < N)
        out_idx = N * np.ones(idx.shape, np.int32)

        cumsum = np.cumsum(mask, axis=1)
        index = np.where(mask, cumsum - 1, idx.shape[1] - 1)
        p_index = np.arange(idx.shape[0])[:, None]
        out_idx = ops.index_update(out_idx, ops.index[p_index, index], idx)
        max_occupancy = np.max(cumsum[:, -1])

        return out_idx, max_occupancy

    @jit
    def mask_self_fn(idx):
        self_mask = idx == np.reshape(np.arange(idx.shape[0]),
                                      (idx.shape[0], 1))
        return np.where(self_mask, idx.shape[0], idx)

    def neighbor_list_fn(R: Array,
                         neighbor_list: NeighborList = None,
                         extra_capacity: int = 0,
                         **kwargs) -> NeighborList:
        nbrs = neighbor_list

        def neighbor_fn(R_and_overflow, max_occupancy=None):
            R, overflow = R_and_overflow
            if cell_list_fn is not None:
                cl = cell_list_fn(R)
                idx = cell_list_candidate_fn(cl, R, **kwargs)
            else:
                idx = candidate_fn(R, **kwargs)
            idx, occupancy = prune_neighbor_list(R, idx, **kwargs)
            if max_occupancy is None:
                max_occupancy = int(occupancy * capacity_multiplier +
                                    extra_capacity)
                padding = max_occupancy - occupancy
                N = R.shape[0]
                if max_occupancy > occupancy:
                    idx = np.concatenate(
                        [idx, N * np.ones((N, padding), dtype=idx.dtype)],
                        axis=1)
            idx = idx[:, :max_occupancy]
            return NeighborList(
                mask_self_fn(idx) if mask_self else idx, R,
                np.logical_or(overflow, (max_occupancy < occupancy)),
                max_occupancy, cell_list_fn)  # pytype: disable=wrong-arg-count

        if nbrs is None:
            cell_list_fn = (cell_list(box_size, cell_size, R,
                                      capacity_multiplier)
                            if use_cell_list else None)
            return neighbor_fn((R, False))
        else:
            cell_list_fn = nbrs.cell_list_fn
            neighbor_fn = partial(neighbor_fn,
                                  max_occupancy=nbrs.max_occupancy)

        d = partial(metric_sq, **kwargs)
        d = vmap(d)
        return lax.cond(np.any(d(R, nbrs.reference_position) > threshold_sq),
                        (R, nbrs.did_buffer_overflow), neighbor_fn, nbrs,
                        lambda x: x)

    return neighbor_list_fn
Example #10
 def testEigvalsInf(self):
   # https://github.com/google/jax/issues/2661
   x = np.array([[np.inf]], np.float64)
   self.assertTrue(np.all(np.isnan(np.linalg.eigvals(x))))
Example #11
import numpy as np
from pysr import pysr, sympy2jax
from jax import numpy as jnp
from jax import random
from jax import grad
import sympy

print("Test JAX 1 - test export")
x, y, z = sympy.symbols('x y z')
cosx = 1.0 * sympy.cos(x) + y
key = random.PRNGKey(0)
X = random.normal(key, (1000, 2))
true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
f, params = sympy2jax(cosx, [x, y, z])
assert jnp.all(jnp.isclose(f(X, params), true)).item()
Example #12
def test_resample():
    x = random.normal(key=random.PRNGKey(0), shape=(50,))
    logits = -jnp.ones(50)
    samples = {'x':x}
    assert jnp.all(resample(random.PRNGKey(0), samples, logits)['x'] == resample(random.PRNGKey(0), x, logits))
Example #13
def test_msqrt():
    for i in range(10):
        A = random.normal(random.PRNGKey(i),shape=(30,30))
        A = A @ A.T
        L = msqrt(A)
        assert jnp.all(jnp.isclose(A, L @ L.T))
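`msqrt` itself is not shown; a minimal sketch consistent with this test, assuming a symmetric positive semi-definite input, uses an eigendecomposition:

import jax.numpy as jnp

def msqrt(A):
    # A = V diag(w) V^T, so L = V diag(sqrt(w)) V^T satisfies L @ L.T == A
    # up to numerical tolerance; tiny negative eigenvalues are clamped to 0
    w, V = jnp.linalg.eigh(A)
    return (V * jnp.sqrt(jnp.maximum(w, 0.0))) @ V.T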
Example #14
def test_random_ortho_normal_matrix():
    H = random_ortho_matrix(random.PRNGKey(0), 3)
    assert jnp.all(jnp.isclose(H @ H.conj().T, jnp.eye(3), atol=1e-7))
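`random_ortho_matrix` is external to this test; one common construction (an assumption here, not necessarily this library's implementation) QR-decomposes a Gaussian matrix and fixes the column signs so the result is Haar-distributed:

import jax.numpy as jnp
from jax import random

def random_ortho_matrix(key, n):
    Q, R = jnp.linalg.qr(random.normal(key, (n, n)))
    # scaling each column of Q by the sign of R's diagonal removes the
    # sign ambiguity of the QR decomposition
    return Q * jnp.sign(jnp.diagonal(R))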
Example #15
def bilinear_sampler(imgs, coords, mask_value):
    """Construct a new image by bilinear sampling from the input image.
    Points falling outside the source image boundary have value of mask_value.
    Args:
        imgs: source image to be sampled from [b, h, w, c]
        coords: coordinates of source pixels to sample from [b, h, w, 2].
            height_t/width_t correspond to the dimensions of the output
            image (don't need to be the same as height_s/width_s).
            The two channels correspond to x and y coordinates respectively.
        mask_value: value of points outside of image. -1 for edge sampling.
        Returns:
            A new sampled image [height_t, width_t, channels]
    """
    coords_x, coords_y = jnp.split(coords, 2, axis=2)
    inp_size = imgs.shape
    out_size = list(coords.shape)
    out_size[2] = imgs.shape[2]

    coords_x = jnp.array(coords_x, dtype='float32')
    coords_y = jnp.array(coords_y, dtype='float32')

    y_max = jnp.array(jnp.shape(imgs)[0] - 1, dtype='float32')
    x_max = jnp.array(jnp.shape(imgs)[1] - 1, dtype='float32')
    zero = jnp.zeros([1], dtype='float32')
    eps = jnp.array([0.5], dtype='float32')

    coords_x_clipped = jnp.clip(coords_x, zero, x_max - eps)
    coords_y_clipped = jnp.clip(coords_y, zero, y_max - eps)

    x0 = jnp.floor(coords_x_clipped)
    x1 = x0 + 1
    y0 = jnp.floor(coords_y_clipped)
    y1 = y0 + 1

    x0_safe = jnp.clip(x0, zero, x_max)
    y0_safe = jnp.clip(y0, zero, y_max)
    x1_safe = jnp.clip(x1, zero, x_max)
    y1_safe = jnp.clip(y1, zero, y_max)

    # bilinear interp weights, with points outside the grid having weight 0
    # wt_x0 = (x1 - coords_x) * jnp.equal(x0, x0_safe).astype('float32')
    # wt_x1 = (coords_x - x0) * jnp.equal(x1, x1_safe).astype('float32')
    # wt_y0 = (y1 - coords_y) * jnp.equal(y0, y0_safe).astype('float32')
    # wt_y1 = (coords_y - y0) * jnp.equal(y1, y1_safe).astype('float32')

    wt_x0 = x1_safe - coords_x  # 1
    wt_x1 = coords_x - x0_safe  # 0
    wt_y0 = y1_safe - coords_y  # 1
    wt_y1 = coords_y - y0_safe  # 0

    # indices in the flat image to sample from
    dim2 = jnp.array(inp_size[1], dtype='float32')

    base_y0 = y0_safe * dim2
    base_y1 = y1_safe * dim2
    idx00 = jnp.reshape(x0_safe + base_y0, [-1])
    idx01 = x0_safe + base_y1
    idx10 = x1_safe + base_y0
    idx11 = x1_safe + base_y1

    # sample from imgs
    imgs_flat = jnp.reshape(imgs, [-1, inp_size[2]])
    imgs_flat = imgs_flat.astype('float32')
    im00 = jnp.reshape(
        jnp.take(imgs_flat, idx00.astype('int32'), axis=0), out_size)
    im01 = jnp.reshape(
        jnp.take(imgs_flat, idx01.astype('int32'), axis=0), out_size)
    im10 = jnp.reshape(
        jnp.take(imgs_flat, idx10.astype('int32'), axis=0), out_size)
    im11 = jnp.reshape(
        jnp.take(imgs_flat, idx11.astype('int32'), axis=0), out_size)

    w00 = wt_x0 * wt_y0
    w01 = wt_x0 * wt_y1
    w10 = wt_x1 * wt_y0
    w11 = wt_x1 * wt_y1

    output = jnp.clip(jnp.round(w00 * im00 + w01 * im01 + w10 * im10 +
                                w11 * im11), 0, 255)

    return jnp.where(jnp.all(mask_value >= 0),
                     jnp.where(
                         compute_mask(coords_x, coords_y, x_max, y_max),
                         output,
                         jnp.ones_like(output) *
                         jnp.reshape(jnp.array(mask_value), [1, 1, -1])
                     ),
                     output)
Example #16
def test_mean_var(jax_dist, sp_dist, params):
    n = 20000 if jax_dist in [dist.LKJ, dist.LKJCholesky] else 200000
    d_jax = jax_dist(*params)
    k = random.PRNGKey(0)
    samples = d_jax.sample(k, sample_shape=(n, ))
    # check with suitable scipy implementation if available
    if sp_dist and not _is_batched_multivariate(d_jax):
        d_sp = sp_dist(*params)
        try:
            sp_mean = d_sp.mean()
        except TypeError:  # mvn does not have .mean() method
            sp_mean = d_sp.mean
        # for multivariate distributions try .cov first
        if d_jax.event_shape:
            try:
                sp_var = np.diag(d_sp.cov())
            except TypeError:  # mvn does not have .cov() method
                sp_var = np.diag(d_sp.cov)
            except AttributeError:
                sp_var = d_sp.var()
        else:
            sp_var = d_sp.var()
        assert_allclose(d_jax.mean, sp_mean, rtol=0.01, atol=1e-7)
        assert_allclose(d_jax.variance, sp_var, rtol=0.01, atol=1e-7)
        if np.all(np.isfinite(sp_mean)):
            assert_allclose(np.mean(samples, 0),
                            d_jax.mean,
                            rtol=0.05,
                            atol=1e-2)
        if np.all(np.isfinite(sp_var)):
            assert_allclose(np.std(samples, 0),
                            np.sqrt(d_jax.variance),
                            rtol=0.05,
                            atol=1e-2)
    elif jax_dist in [dist.LKJ, dist.LKJCholesky]:
        if jax_dist is dist.LKJCholesky:
            corr_samples = np.matmul(samples, np.swapaxes(samples, -2, -1))
        else:
            corr_samples = samples
        dimension, concentration, _ = params
        # marginal of off-diagonal entries
        marginal = dist.Beta(concentration + 0.5 * (dimension - 2),
                             concentration + 0.5 * (dimension - 2))
        # scale statistics due to linear mapping
        marginal_mean = 2 * marginal.mean - 1
        marginal_std = 2 * np.sqrt(marginal.variance)
        expected_mean = np.broadcast_to(
            np.reshape(marginal_mean,
                       np.shape(marginal_mean) + (1, 1)),
            np.shape(marginal_mean) + d_jax.event_shape)
        expected_std = np.broadcast_to(
            np.reshape(marginal_std,
                       np.shape(marginal_std) + (1, 1)),
            np.shape(marginal_std) + d_jax.event_shape)
        # diagonal elements of correlation matrices are 1
        expected_mean = expected_mean * (
            1 - np.identity(dimension)) + np.identity(dimension)
        expected_std = expected_std * (1 - np.identity(dimension))

        assert_allclose(np.mean(corr_samples, axis=0),
                        expected_mean,
                        atol=0.01)
        assert_allclose(np.std(corr_samples, axis=0), expected_std, atol=0.01)
    else:
        if np.all(np.isfinite(d_jax.mean)):
            assert_allclose(np.mean(samples, 0),
                            d_jax.mean,
                            rtol=0.05,
                            atol=1e-2)
        if np.all(np.isfinite(d_jax.variance)):
            assert_allclose(np.std(samples, 0),
                            np.sqrt(d_jax.variance),
                            rtol=0.05,
                            atol=1e-2)
Example #17
 def _validate_sample(self, value):
     mask = super(ImproperUniform, self)._validate_sample(value)
     batch_dim = jnp.ndim(value) - len(self.event_shape)
     if batch_dim < jnp.ndim(mask):
         mask = jnp.all(jnp.reshape(mask, jnp.shape(mask)[:batch_dim] + (-1,)), -1)
     return mask
Example #18
 def _validate_sample(self, x):
     if not np.all(self.support(x)):
         raise ValueError('Invalid values provided to log prob method. '
                          'The value argument must be within the support.')
Example #19
 def testScaleDerivativeIsNegative(self):
     # Assert that d_loss / d_scale < 0.
     _, _, _, alpha, _, _, _, d_scale = self._precompute_lossfun_inputs()
     mask = jnp.isfinite(alpha)
     self.assertTrue(
         jnp.all(d_scale[mask] < (300. * jnp.finfo(jnp.float32).eps)))
Example #20
def initialize_model(rng_key,
                     model,
                     init_strategy=init_to_uniform,
                     dynamic_args=False,
                     model_args=(),
                     model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided with `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    substituted_model = substitute(
        seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
        substitute_fn=init_strategy)
    inv_transforms, replay_model, model_trace = _get_model_transforms(
        substituted_model, model_args, model_kwargs)
    constrained_values = {
        k: v['value']
        for k, v in model_trace.items() if v['type'] == 'sample'
        and not v['is_observed'] and not v['fn'].is_discrete
    }

    potential_fn, postprocess_fn = get_potential_fn(model,
                                                    inv_transforms,
                                                    replay_model=replay_model,
                                                    dynamic_args=dynamic_args,
                                                    model_args=model_args,
                                                    model_kwargs=model_kwargs)

    init_strategy = init_strategy if isinstance(init_strategy,
                                                partial) else init_strategy()
    if init_strategy.func is init_to_value:
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms,
                                            init_values,
                                            invert=True)
        init_strategy = _init_to_unconstrained_value(
            values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms,
                                    constrained_values,
                                    invert=True)
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        model,
        init_strategy=init_strategy,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params)

    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn,
                     postprocess_fn, model_trace)
Example #21
 def prod(self, value, axis=None):
     if not isinstance(value, jnp.ndarray):
         value = jnp.array(value)
     if value.dtype == bool:
         return jnp.all(value, axis=axis)
     return jnp.prod(value, axis=axis)
Example #22
    def optimize(self, p0, num_iters, metric, step_size, tolerance, verbose,
                 return_model):
        '''Workhorse of optimization.

        p0: dict
            A dictionary of the initial model parameters to be optimized.

        num_iters: int
            Maximum number of iterations.

        metric: str
            Method of model evaluation. Can be `mse`, `corrcoeff`, or `r2`.

        step_size: float or jax scheduler
            Learning rate.

        tolerance: int
            Tolerance for early stopping. If the training cost doesn't change
            by more than 1e-5 in the last `tolerance` steps, or the dev cost
            monotonically increases, stop.

        verbose: int
            Print progress. If verbose=0, no progress will be printed.

        return_model: str
            Return the 'best' model on dev set metrics or the 'last' model.
        '''
        @jit
        def step(i, opt_state):
            p = get_params(opt_state)
            l, g = value_and_grad(self.cost)(p)
            return l, opt_update(i, g, opt_state)

        opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
        opt_state = opt_init(p0)

        cost_train = [0] * num_iters
        cost_dev = [0] * num_iters
        metric_train = [0] * num_iters
        metric_dev = [0] * num_iters
        params_list = [0] * num_iters
        if verbose:
            if 'dev' not in self.y:
                if metric is None:
                    print('{0}\t{1:>10}\t{2:>10}'.format(
                        'Iters', 'Time (s)', 'Cost (train)'))
                else:
                    print('{0}\t{1:>10}\t{2:>10}\t{3:>10}'.format(
                        'Iters', 'Time (s)', 'Cost (train)',
                        f'{metric} (train)'))
            else:
                if metric is None:
                    print('{0}\t{1:>10}\t{2:>10}\t{3:>10}'.format(
                        'Iters', 'Time (s)', 'Cost (train)', 'Cost (dev)'))
                else:
                    print('{0}\t{1:>10}\t{2:>10}\t{3:>10}\t{4:>10}\t{5:>10}'.
                          format('Iters', 'Time (s)', 'Cost (train)',
                                 'Cost (dev)', f'{metric} (train)',
                                 f'{metric} (dev)'))

        time_start = time.time()
        for i in range(num_iters):
            cost_train[i], opt_state = step(i, opt_state)
            params_list[i] = get_params(opt_state)
            y_pred_train = self.forwardpass(p=params_list[i], kind='train')
            metric_train[i] = self._score(self.y['train'], y_pred_train,
                                          metric)

            if 'dev' in self.y:
                y_pred_dev = self.forwardpass(p=params_list[i], kind='dev')
                cost_dev[i] = self.cost(p=params_list[i],
                                        kind='dev',
                                        precomputed=y_pred_dev,
                                        penalize=False)
                metric_dev[i] = self._score(self.y['dev'], y_pred_dev, metric)

            time_elapsed = time.time() - time_start
            if verbose:
                if i % int(verbose) == 0:
                    if 'dev' not in self.y:
                        if metric is None:
                            print('{0:>5}\t{1:>10.3f}\t{2:>10.3f}'.format(
                                i, time_elapsed, cost_train[i]))
                        else:
                            print('{0:>5}\t{1:>10.3f}\t{2:>10.3f}\t{3:>10.3f}'.
                                  format(i, time_elapsed, cost_train[i],
                                         metric_train[i]))

                    else:
                        if metric is None:
                            print('{0:>5}\t{1:>10.3f}\t{2:>10.3f}\t{3:>10.3f}'.
                                  format(i, time_elapsed, cost_train[i],
                                         cost_dev[i]))
                        else:
                            print(
                                '{0:>5}\t{1:>10.3f}\t{2:>10.3f}\t{3:>10.3f}\t{4:>10.3f}\t{5:>10.3f}'
                                .format(i, time_elapsed, cost_train[i],
                                        cost_dev[i], metric_train[i],
                                        metric_dev[i]))
            if tolerance and i > 300:  # tolerance = 0: no early stop.

                total_time_elapsed = time.time() - time_start
                cost_train_slice = np.array(cost_train[i - tolerance:i])
                cost_dev_slice = np.array(cost_dev[i - tolerance:i])

                if 'dev' in self.y and np.all(
                        cost_dev_slice[1:] - cost_dev_slice[:-1] > 0):
                    if verbose:
                        print(
                            'Stop at {0} steps: cost (dev) has been monotonically increasing for {1} steps.'
                            .format(i, tolerance))
                        print('Total time elapsed: {0:.3f}s.\n'.format(
                            total_time_elapsed))
                    stop = 'dev_stop'
                    break

                if np.all(cost_train_slice[:-1] - cost_train_slice[1:] < 1e-5):
                    if verbose:
                        print(
                            'Stop at {0} steps: cost (train) has been changing by less than 1e-5 for {1} steps.'
                            .format(i, tolerance))
                        print('Total time elapsed: {0:.3f}s.\n'.format(
                            total_time_elapsed))
                    stop = 'train_stop'
                    break

        else:
            total_time_elapsed = time.time() - time_start
            stop = 'maxiter_stop'
            if verbose:
                print('Stop: reached {0} steps.'.format(num_iters))
                print('Total time elapsed: {0:.3f}s.\n'.format(
                    total_time_elapsed))

        if return_model == 'best_dev_cost':
            best = np.argmin(np.asarray(cost_dev[:i + 1]))

        elif return_model == 'best_train_cost':
            best = np.argmin(np.asarray(cost_train[:i + 1]))

        elif return_model == 'best_dev_metric':
            if metric in ['mse', 'gcv']:
                best = np.argmin(np.asarray(metric_dev[:i + 1]))
            else:
                best = np.argmax(np.asarray(metric_dev[:i + 1]))

        elif return_model == 'best_train_metric':
            if metric in ['mse', 'gcv']:
                best = np.argmin(np.asarray(metric_train[:i + 1]))
            else:
                best = np.argmax(np.asarray(metric_train[:i + 1]))

        elif return_model == 'last':
            if stop == 'dev_stop':
                best = i - tolerance
            else:
                best = i

        else:
            print(
                'Provided `return_model` is not supported. Falling back to `best_dev_cost`.'
            )
            best = np.argmin(np.asarray(cost_dev[:i + 1]))

        params = params_list[best]
        metric_dev_opt = metric_dev[best]

        self.cost_train = np.hstack(cost_train[:i + 1])
        self.cost_dev = np.hstack(cost_dev[:i + 1])
        self.metric_train = np.hstack(metric_train[:i + 1])
        self.metric_dev = np.hstack(metric_dev[:i + 1])
        self.metric_dev_opt = metric_dev_opt
        self.total_time_elapsed = total_time_elapsed

        self.all_params = params_list[:i + 1]  # not sure if this will occupy a lot of RAM.

        self.y_pred['opt'].update({'train': y_pred_train})
        if 'dev' in self.y:
            self.y_pred['opt'].update({'dev': y_pred_dev})

        return params
Example #23
def _check_synced(pytree):
  mins = jax.lax.pmin(pytree, axis_name='batch')
  equals = jax.tree_multimap(jnp.array_equal, pytree, mins)
  return jnp.all(jnp.asarray(jax.tree_leaves(equals)))
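`pmin` over the `'batch'` axis only makes sense inside a collective context; a hedged usage sketch under `pmap`, with an assumed replicated-parameter setup:

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
params = {'w': jnp.ones((n_dev, 3))}  # one identical copy per device
synced = jax.pmap(_check_synced, axis_name='batch')(params)
assert bool(synced.all())  # one boolean per device, all True when in sync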
Example #24
 def cond_fn(*args):
     """ check if all are done or reached max number of iterations """
     i, _, done, _, _ = args[0]
     return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
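A minimal sketch wiring such a predicate into `lax.while_loop`; the five-element carry layout and the toy body are assumptions for illustration:

import jax.numpy as jnp
from jax import lax

def cond_fn(*args):
    i, _, done, _, _ = args[0]
    return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

def body_fn(carry):
    i, x, done, lo, hi = carry
    x = x + 0.5                      # toy update
    return i + 1, x, x > hi, lo, hi  # an entry is done once it exceeds hi

init = (0, jnp.zeros(3), jnp.zeros(3, dtype=bool), 0.0, 1.0)
final = lax.while_loop(cond_fn, body_fn, init)  # stops once all are done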
Example #25
 def body(carry, x):
     checkify.check(jnp.all(x > 0), "should be positive")
     return carry, x
Example #26
 def evaluate_batch(params, data):
     logits = model.apply(params, data.predicates)
     correct = jnp.all(data.targets == jnp.argmax(logits, axis=-1), axis=-1)
     ce_loss = jnp.sum(cross_entropy(data.targets, logits), axis=-1)
     return correct, ce_loss
Example #27
 def same(list1, list2):
     allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3)
     elements_close = list(map(allclose, list1, list2))
     return lnp.all(lnp.array(elements_close))
Example #28
 def interaction(q, d_qtm_nk1, d_qtm_nk2):
     q_norm = jnp.linalg.norm(q @ bvec)
     flag = jnp.all(d_qtm_nk1 == d_qtm_nk2)
     res = jnp.where(q_norm == 0., 0., jnp.where(flag, Uvalue * qM / jnp.sqrt(q_norm ** 2 + Kappa ** 2), Jvalue))
     return res/num_k1**2
Example #29
    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # XXX: we don't want to apply enum to draw latent samples
            model_ = model
            if enum:
                from numpyro.contrib.funsor import enum as enum_handler

                if isinstance(model, substitute) and isinstance(
                        model.fn, enum_handler):
                    model_ = substitute(model.fn.fn, data=model.data)
                elif isinstance(model, enum_handler):
                    model_ = model.fn

            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model_, subkey),
                                      substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(
                *model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            for k, v in model_trace.items():
                if (v["type"] == "sample" and not v["is_observed"]
                        and not v["fn"].support.is_discrete):
                    constrained_values[k] = v["value"]
                    with helpful_support_errors(v):
                        inv_transforms[k] = biject_to(v["fn"].support)
            params = transform_fn(
                inv_transforms,
                {k: v
                 for k, v in constrained_values.items()},
                invert=True,
            )
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(subkey,
                                               jnp.shape(v),
                                               minval=-radius,
                                               maxval=radius)
                    key, subkey = random.split(key)

        potential_fn = partial(potential_energy,
                               model,
                               model_args,
                               model_kwargs,
                               enum=enum)
        if validate_grad:
            if forward_mode_differentiation:
                pe = potential_fn(params)
                z_grad = jacfwd(potential_fn)(params)
            else:
                pe, z_grad = value_and_grad(potential_fn)(params)
            z_grad_flat = ravel_pytree(z_grad)[0]
            is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        else:
            pe = potential_fn(params)
            is_valid = jnp.isfinite(pe)
            z_grad = None

        return i + 1, key, (params, pe, z_grad), is_valid
Example #30
@nt_tree_fn(nargs=2, reduce=lambda x: np.all(np.array(x)))
def x1_is_x2(x1: np.ndarray,
             x2: Optional[np.ndarray] = None,
             eps: float = 1e-12) -> Union[bool, np.ndarray]:
    if not isinstance(x1, (onp.ndarray, np.ndarray)):
        raise TypeError('`x1` must be an ndarray. A {} is found.'.format(
            type(x1)))

    if x2 is None:
        return True

    if x1 is x2:
        return True

    if x1.shape != x2.shape:
        return False

    # assumed completion of the truncated snippet: treat the arrays as the
    # same when every entry agrees within tolerance `eps`
    return np.all(np.abs(x1 - x2) < eps)