Example #1
def _standard_gamma_grad(sample, alpha):
    samples = np.reshape(sample, -1)
    alphas = np.reshape(alpha, -1)
    grads = vmap(_standard_gamma_grad_one)(samples, alphas)
    return grads.reshape(alpha.shape)
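A minimal usage sketch of the flatten → vmap → reshape pattern above, with a hypothetical stand-in for `_standard_gamma_grad_one` (the real per-element gradient is defined elsewhere in the source):

```python
import jax.numpy as np
from jax import vmap

def _per_element_grad(sample, alpha):
    # hypothetical elementwise function standing in for _standard_gamma_grad_one
    return sample / alpha

sample = np.ones((2, 3))
alpha = 2.0 * np.ones((2, 3))
flat_grads = vmap(_per_element_grad)(np.reshape(sample, -1), np.reshape(alpha, -1))
grads = flat_grads.reshape(alpha.shape)   # back to (2, 3)
```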
Example #2
def gradient_descent(g_dd, y_train, loss, g_td=None):
  """Predicts the outcome of function space gradient descent training on `loss`.

  Solves for continuous-time gradient descent using an ODE solver.

  Solves the function space ODE for continuous gradient descent with a given
  loss (detailed in [*]) given a Neural Tangent Kernel over the dataset. This
  function returns a function that predicts the time evolution for function
  space points at arbitrary times. Note that times are continuous and are
  measured in units of the learning rate so that t = learning_rate * steps.

  This function uses the scipy ode solver with the 'dopri5' algorithm.

  [*] https://arxiv.org/abs/1902.06720

  Example:
    ```python
    >>> from jax.experimental import stax
    >>> from neural_tangents import predict
    >>>
    >>> train_time = 1e-7
    >>> kernel_fn = empirical(f)
    >>> g_td = kernel_fn(x_test, x_train, params)
    >>>
    >>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
    >>> predict_fn = predict.gradient_descent(
    >>>     g_dd, y_train, cross_entropy, g_td)
    >>>
    >>> fx_train_initial = f(params, x_train)
    >>> fx_test_initial = f(params, x_test)
    >>>
    >>> fx_train_final, fx_test_final = predict_fn(
    >>>     train_time, fx_train_initial, fx_test_initial)
    ```
  Args:
    g_dd: A Kernel on the training data. The kernel should be an `np.ndarray` of
      shape [n_train * output_dim, n_train * output_dim] or [n_train, n_train].
      In the latter case it is assumed that the kernel is block diagonal over
      the logits.
    y_train: A `np.ndarray` of shape [n_train, output_dim] of labels for the
      training data.
    loss: A loss function whose signature is loss(fx, y_hat) where fx is an
      `np.ndarray` of function space outputs of the network and y_hat are
      targets. Note: the loss function should treat the batch and output
        dimensions symmetrically.
    g_td: A Kernel relating training data with test data. The kernel should be
      an `np.ndarray` of shape [n_test * output_dim, n_train * output_dim] or
      [n_test, n_train]. Note: g_td should have been created in the convention
        kernel_fn(x_test, x_train, params).

  Returns:
    A function that predicts outputs after t = learning_rate * steps of
    training.

    If g_td is None:
      The function returned is predict(t, fx). Here fx is an `np.ndarray` of
      network outputs and has shape [n_train, output_dim], t is a floating point
      time. predict(t, fx) returns an `np.ndarray` of predictions of shape
      [n_train, output_dim].

    If g_td is not None:
      If a test set Kernel is specified then it returns a function,
      predict(t, fx_train, fx_test). Here fx_train and fx_test are ndarrays of
      network outputs and have shape [n_train, output_dim] and
      [n_test, output_dim] respectively and t is a floating point time.
      predict(t, fx_train, fx_test) returns a tuple of predictions of shape
      [n_train, output_dim] and [n_test, output_dim] for train and test points
      respectively.
  """

  output_dimension = y_train.shape[-1]

  g_dd = empirical.flatten_features(g_dd)

  def fl(fx):
    """Flatten outputs."""
    return np.reshape(fx, (-1,))

  def ufl(fx):
    """Unflatten outputs."""
    return np.reshape(fx, (-1, output_dimension))

  # These functions are used inside the integrator only if the kernel is
  # diagonal over the logits.
  ifl = lambda x: x
  iufl = lambda x: x

  # Check to see whether the kernel has a logit dimension.
  if y_train.size > g_dd.shape[-1]:
    out_dim, ragged = divmod(y_train.size, g_dd.shape[-1])
    if ragged or out_dim != y_train.shape[-1]:
      raise ValueError('Shape of y_train is incompatible with the shape of g_dd.')
    ifl = fl
    iufl = ufl

  y_train = np.reshape(y_train, (-1))
  grad_loss = grad(functools.partial(loss, y_hat=y_train))

  if g_td is None:
    dfx_dt = lambda unused_t, fx: -ifl(np.dot(g_dd, iufl(grad_loss(fx))))

    def predict(dt, fx=0.):
      r = ode(dfx_dt).set_integrator('dopri5')
      r.set_initial_value(fl(fx), 0)
      r.integrate(dt)

      return ufl(r.y)
  else:
    g_td = empirical.flatten_features(g_td)

    def dfx_dt(unused_t, fx, train_size):
      fx_train = fx[:train_size]
      dfx_train = -ifl(np.dot(g_dd, iufl(grad_loss(fx_train))))
      dfx_test = -ifl(np.dot(g_td, iufl(grad_loss(fx_train))))
      return np.concatenate((dfx_train, dfx_test), axis=0)

    def predict(dt, fx_train=0., fx_test=0.):
      r = ode(dfx_dt).set_integrator('dopri5')

      fx = fl(np.concatenate((fx_train, fx_test), axis=0))
      train_size, output_dim = fx_train.shape
      r.set_initial_value(fx, 0).set_f_params(train_size * output_dim)
      r.integrate(dt)
      fx = ufl(r.y)

      return fx[:train_size], fx[train_size:]

  return predict
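The integration machinery above, isolated as a runnable sketch: `scipy.integrate.ode` with the 'dopri5' integrator, applied to toy dynamics in place of the NTK ODE.

```python
import numpy as onp
from scipy.integrate import ode

# Toy right-hand side standing in for dfx_dt: exponential decay toward zero.
dfx_dt = lambda unused_t, fx: -fx

r = ode(dfx_dt).set_integrator('dopri5')
r.set_initial_value(onp.array([1.0, 2.0]), 0)
r.integrate(3.0)
print(r.y)  # approximately [exp(-3), 2 * exp(-3)]
```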
Example #3
 def fl(fx):
   """Flatten outputs."""
   return np.reshape(fx, (-1,))
Example #4
    def upper_bound_softmax_plus_affine(
        self,
        inner_dual_vars: Any,
        opt_instance: InnerVerifInstance,
        key: jnp.array,
        step: int,
    ) -> jnp.array:
        """Upper bound (softmax + affine)-type problem with cvxpy.

    Upper bound obj'*softmax(x) - lagrangian_form(x) subject to l<=x<=u

    Note that this function cannot be differentiated through; using it at
    training time will lead to an error.

    Args:
      inner_dual_vars: jax () scalar.
      opt_instance: Inner optimization instance.
      key: RNG key.
      step: outer optimization iteration number.

    Returns:
      Optimal value.
    Raises:
      ValueError if Lagrangian form is not supported or if the problem is not
        solved to optimality.
    """
        if not isinstance(opt_instance.lagrangian_form_pre,
                          lagrangian_form.Linear):
            raise ValueError('Unsupported Lagrangian form.')

        lower = opt_instance.bounds[0].lb_pre
        upper = opt_instance.bounds[0].ub_pre

        def lagr_form(x):
            val = opt_instance.lagrangian_form_pre.apply(
                x, opt_instance.lagrange_params_pre, step)
            return jnp.reshape(val, ())

        # extract coeff_linear via autodiff (including negative sign here)
        coeff_linear = -jax.grad(lagr_form)(jnp.zeros_like(lower))

        assert len(opt_instance.affine_fns) == 1
        # extract coeff_softmax via autodiff
        coeff_softmax_fn = lambda x: jnp.reshape(opt_instance.affine_fns[0](x),
                                                 ())
        coeff_softmax = jax.grad(coeff_softmax_fn)(jnp.zeros_like(lower))

        if opt_instance.spec_type == verify_utils.SpecType.ADVERSARIAL_SOFTMAX:
            upper_bounding_method = exact_opt_softmax.exact_opt_softmax_plus_affine
        else:
            upper_bounding_method = upper_bound_softmax_plus_affine_exact
        upper_bound, _ = upper_bounding_method(
            c_linear=np.array(coeff_linear).squeeze(0).astype(np.float64),
            c_softmax=np.array(coeff_softmax).squeeze(0).astype(np.float64),
            lb=np.array(lower).squeeze(0).astype(np.float64),
            ub=np.array(upper).squeeze(0).astype(np.float64),
        )

        constant = (coeff_softmax_fn(jnp.zeros_like(lower)) -
                    lagr_form(jnp.zeros_like(lower)))

        result = jnp.array(upper_bound) + constant

        return jnp.reshape(result, [lower.shape[0]])
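The coefficient-extraction trick used above, shown in isolation: for an affine scalar function, `jax.grad` evaluated at zero recovers the linear coefficients, and the value at zero recovers the constant term. `affine` here is a toy function, not part of the verification code.

```python
import jax
import jax.numpy as jnp

w = jnp.array([1.0, -2.0, 3.0])
affine = lambda x: jnp.reshape(jnp.dot(w, x) + 0.5, ())  # toy affine function

coeffs = jax.grad(affine)(jnp.zeros(3))    # recovers w
constant = affine(jnp.zeros(3))            # recovers the offset 0.5
```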
Example #5
    def solve_max_exp(
        self,
        inner_dual_vars: Any,
        opt_instance: InnerVerifInstance,
        key: jnp.array,
        step: int,
    ) -> jnp.array:
        """Solve inner max problem for final layer with uncertainty specification.

    Maximize obj'*softmax(x) - lagrangian_form(x) subject to l<=x<=u

    Args:
      inner_dual_vars: () jax scalar.
      opt_instance: Inner optimization instance.
      key: RNG key.
      step: outer optimization iteration number.

    Returns:
      opt: Optimal value.
    """
        assert opt_instance.is_last
        l = opt_instance.bounds[0].lb_pre
        u = opt_instance.bounds[0].ub_pre

        def lagr_form(x):
            val = opt_instance.lagrangian_form_pre.apply(
                x, opt_instance.lagrange_params_pre, step)
            return jnp.reshape(val, ())

        affine_obj = lambda x: jnp.reshape(opt_instance.affine_fns[0](x), ())
        assert len(opt_instance.affine_fns) == 1

        def max_objective_fn(anyx):
            return affine_obj(jax.nn.softmax(anyx)) - lagr_form(anyx)

        min_objective_fn = lambda x: -max_objective_fn(x)

        opt = optax.adam(self._learning_rate)
        grad_fn = jax.grad(min_objective_fn)

        def cond_fn(inputs):
            it, x, grad_x, _ = inputs
            not_converged = jnp.logical_not(has_converged(x, grad_x, l, u))
            return jnp.logical_and(it < self._n_iter, not_converged)

        def body_fn(inputs):
            it, x, _, opt_state = inputs
            grad_x = grad_fn(x)
            updates, opt_state = opt.update(grad_x, opt_state, x)
            x = optax.apply_updates(x, updates)
            x = jnp.clip(x, l, u)
            it = it + 1
            return it, x, grad_x, opt_state

        def find_max_from_init(x):
            opt_state = opt.init(x)

            # iteration, x, grad_x, opt_state
            init_val = (jnp.zeros(()), x, jnp.ones_like(x), opt_state)
            _, adv_x, _, _ = jax.lax.while_loop(cond_fn, body_fn, init_val)

            adv_x = jnp.clip(adv_x, l, u)

            return jnp.reshape(max_objective_fn(jax.lax.stop_gradient(adv_x)),
                               (1, ))

        # initialization heuristic 1: max when ignoring softmax
        mask_ignore_softmax = jax.grad(lagr_form)(jnp.ones_like(u)) < 0
        x = mask_ignore_softmax * u + (1 - mask_ignore_softmax) * l
        objective_1 = find_max_from_init(x)

        # initialization heuristic 2: max when ignoring affine
        mask_ignore_affine = jax.grad(affine_obj)(jnp.ones_like(u)) > 0
        x = mask_ignore_affine * u + (1 - mask_ignore_affine) * l
        objective_2 = find_max_from_init(x)

        # also try at boundaries
        objective_3 = find_max_from_init(l)
        objective_4 = find_max_from_init(u)

        # select best of runs
        objective = jnp.maximum(jnp.maximum(objective_1, objective_2),
                                jnp.maximum(objective_3, objective_4))

        return objective
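A self-contained miniature of the projected-gradient loop above (optax Adam inside `jax.lax.while_loop`, clipping back into the box after every step). The bounds, learning rate, iteration cap, and objective are made up for the sketch.

```python
import jax
import jax.numpy as jnp
import optax

l, u = -jnp.ones(3), jnp.ones(3)                       # box constraints
max_objective = lambda x: -jnp.sum((x - 0.3) ** 2)     # toy objective to maximize
grad_fn = jax.grad(lambda x: -max_objective(x))        # minimize the negation

opt = optax.adam(1e-1)

def cond_fn(carry):
    it, _, _ = carry
    return it < 100

def body_fn(carry):
    it, x, opt_state = carry
    updates, opt_state = opt.update(grad_fn(x), opt_state, x)
    x = jnp.clip(optax.apply_updates(x, updates), l, u)  # project back into the box
    return it + 1, x, opt_state

x0 = jnp.zeros(3)
_, x_star, _ = jax.lax.while_loop(cond_fn, body_fn, (0, x0, opt.init(x0)))
# max_objective(x_star) approximates the box-constrained maximum.
```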
Example #6
 def decollapse_and_split(self, x):        
     # Decollapse batches and split alpha from color channels
     x = jnp.reshape(x, (x.shape[0]//self.num_slots, self.num_slots, *x.shape[1:])) # Decollapse batches from slots
     x, alphas = jnp.array_split(x, [x.shape[-1]-1], -1)
     return x, alphas
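Concrete shapes for the decollapse/split above, assuming 3 color channels plus 1 alpha channel and `num_slots` slots folded into the batch dimension:

```python
import jax.numpy as jnp

num_slots = 4
x = jnp.ones((2 * num_slots, 8, 8, 3 + 1))     # (batch * slots, h, w, rgb + alpha)
x = jnp.reshape(x, (x.shape[0] // num_slots, num_slots, *x.shape[1:]))
rgb, alphas = jnp.array_split(x, [x.shape[-1] - 1], -1)
# rgb.shape == (2, 4, 8, 8, 3), alphas.shape == (2, 4, 8, 8, 1)
```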
Example #7
 def _flatten(params):
     """Flattens and concatenates all tensors in params to a single vector."""
     params, _ = tree_flatten(params)
     return jnp.concatenate([jnp.reshape(param, [-1]) for param in params])
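Usage sketch for the flattener above; `tree_flatten` is assumed to come from `jax.tree_util`:

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten

params = {'b': jnp.zeros(3), 'w': jnp.ones((2, 3))}
leaves, _ = tree_flatten(params)                                  # leaves in a fixed order
vec = jnp.concatenate([jnp.reshape(p, [-1]) for p in leaves])     # shape (9,)
```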
Example #8
 def candidate_fn(R, **kwargs):
     return np.broadcast_to(
         np.reshape(np.arange(R.shape[0]), (1, R.shape[0])),
         (R.shape[0], R.shape[0]))
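What the broadcast above produces: every particle gets the full list of candidate indices 0..N-1 (here N = 3).

```python
import jax.numpy as np

R = np.zeros((3, 2))   # positions; only the leading dimension matters here
idx = np.broadcast_to(np.reshape(np.arange(R.shape[0]), (1, R.shape[0])),
                      (R.shape[0], R.shape[0]))
# idx == [[0, 1, 2],
#         [0, 1, 2],
#         [0, 1, 2]]
```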
Example #9
 def copy_values_from_cell(value, cell_value, cell_id):
     scatter_indices = np.reshape(cell_id, (-1, ))
     cell_value = np.reshape(cell_value, (-1, ) + cell_value.shape[-2:])
     return ops.index_update(value, scatter_indices, cell_value)
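`jax.ops.index_update` has since been removed from JAX; a sketch of the same scatter with the current `.at[...].set(...)` API, with shapes made up for illustration:

```python
import jax.numpy as np

value = np.zeros((6, 2, 2))
cell_value = np.arange(8.0).reshape((2, 1, 2, 2))          # per-cell data
cell_id = np.array([[4], [1]])

scatter_indices = np.reshape(cell_id, (-1,))
cell_value = np.reshape(cell_value, (-1,) + cell_value.shape[-2:])
value = value.at[scatter_indices].set(cell_value)          # rows 4 and 1 updated
```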
Example #10
def cell_list(
        box_size: Box,
        minimum_cell_size: float,
        cell_capacity_or_example_R: Union[int, Array],
        buffer_size_multiplier: float = 1.1) -> Callable[[Array], CellList]:
    r"""Returns a function that partitions point data spatially. 

  Given a set of points {x_i \in R^d} with associated data {k_i \in R^m} it is
  often useful to partition the points / data spatially. A simple partitioning
  that can be implemented efficiently within XLA is a dense partition into a
  uniform grid called a cell list.

  Since XLA requires that shapes be statically specified, we allocate fixed
  sized buffers for each cell. The size of this buffer can either be specified
  manually or it can be estimated automatically from a set of positions. Note,
  if the distribution of points changes significantly it is likely that the
  buffer sizes will have to be adjusted.

  This partitioning will likely form the groundwork for parallelizing
  simulations over different accelerators.

  Args:
    box_size: A float or an ndarray of shape [spatial_dimension] specifying the
      size of the system. Note, this code is written for the case where the
      boundaries are periodic. If this is not the case, then the current code
      will be slightly less efficient.
    minimum_cell_size: A float specifying the minimum side length of each cell.
      Cells are enlarged so that they exactly fill the box.
    cell_capacity_or_example_R: Either an integer specifying the maximum
      number of particles that can be stored in each cell or an ndarray of
      positions of shape [particle_count, spatial_dimension] that is used to
      estimate the cell_capacity.
    buffer_size_multiplier: A floating point multiplier that multiplies the
      estimated cell capacity to allow for fluctuations in the maximum cell
      occupancy.
  Returns:
    A function `cell_list_fn(R, **kwargs)` that partitions positions, `R`, and
    side data specified by kwargs into a cell list. Returns a CellList
    containing the partition.
  """

    if isinstance(box_size, np.ndarray):
        box_size = onp.array(box_size)
        if len(box_size.shape) == 1:
            box_size = np.reshape(box_size, (1, -1))

    if isinstance(minimum_cell_size, np.ndarray):
        minimum_cell_size = onp.array(minimum_cell_size)

    cell_capacity = cell_capacity_or_example_R
    if _is_variable_compatible_with_positions(cell_capacity):
        cell_capacity = _estimate_cell_capacity(cell_capacity, box_size,
                                                minimum_cell_size,
                                                buffer_size_multiplier)
    elif not isinstance(cell_capacity, int):
        msg = (
            'cell_capacity_or_example_R must either be an integer '
            'specifying the cell capacity or a set of positions that will be used '
            'to estimate a cell capacity. Found {}.'.format(
                type(cell_capacity)))
        raise ValueError(msg)

    def build_cells(R, **kwargs):
        N = R.shape[0]
        dim = R.shape[1]

        if dim != 2 and dim != 3:
            # NOTE(schsam): Do we want to check this in compute_fn as well?
            raise ValueError(
                'Cell list spatial dimension must be 2 or 3. Found {}'.format(
                    dim))

        neighborhood_tile_count = 3**dim

        _, cell_size, cells_per_side, cell_count = \
            _cell_dimensions(dim, box_size, minimum_cell_size)

        hash_multipliers = _compute_hash_constants(dim, cells_per_side)

        # Create cell list data.
        particle_id = lax.iota(np.int64, N)
        # NOTE(schsam): We use the convention that particles that are successfully
        # copied have their true id whereas empty slots have id = N.
        # Then when we copy data back from the grid, copy it to an array of shape
        # [N + 1, output_dimension] and then truncate it to an array of shape
        # [N, output_dimension] which ignores the empty slots.
        mask_id = np.ones((N, ), np.int64) * N
        cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
        cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

        # It might be worth adding an occupied mask. However, that will involve
        # more compute since often we will do a mask for species that will include
        # an occupancy test. It seems easier to design around this empty_data_value
        # for now and revisit the issue if it comes up later.
        empty_kwarg_value = 10**5
        cell_kwargs = {}
        for k, v in kwargs.items():
            if not isinstance(v, np.ndarray):
                raise ValueError(
                    ('Data must be specified as an ndarray. Found "{}" with '
                     'type {}'.format(k, type(v))))
            if v.shape[0] != R.shape[0]:
                raise ValueError((
                    'Data must be specified per-particle (an ndarray with shape '
                    '(R.shape[0], ...)). Found "{}" with shape {}'.format(
                        k, v.shape)))
            kwarg_shape = v.shape[1:] if v.ndim > 1 else (1, )
            cell_kwargs[k] = empty_kwarg_value * np.ones(
                (cell_count * cell_capacity, ) + kwarg_shape, v.dtype)

        indices = np.array(R / cell_size, dtype=i32)
        hashes = np.sum(indices * hash_multipliers, axis=1)

        # Copy the particle data into the grid. Here we use a trick to allow us to
        # copy into all cells simultaneously using a single lax.scatter call. To do
        # this we first sort particles by their cell hash. We then assign each
        # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
        # is a flat list that repeats 0, .., cell_capacity. So long as there are
        # fewer than cell_capacity particles per cell, each particle is guaranteed
        # to get a cell id that is unique.
        sort_map = np.argsort(hashes)
        sorted_R = R[sort_map]
        sorted_hash = hashes[sort_map]
        sorted_id = particle_id[sort_map]

        sorted_kwargs = {}
        for k, v in kwargs.items():
            sorted_kwargs[k] = v[sort_map]

        sorted_cell_id = np.mod(lax.iota(np.int64, N), cell_capacity)
        sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

        cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
        sorted_id = np.reshape(sorted_id, (N, 1))
        cell_id = ops.index_update(cell_id, sorted_cell_id, sorted_id)
        cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
        cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

        for k, v in sorted_kwargs.items():
            if v.ndim == 1:
                v = np.reshape(v, v.shape + (1, ))
            cell_kwargs[k] = ops.index_update(cell_kwargs[k], sorted_cell_id,
                                              v)
            cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k],
                                                    cells_per_side, dim)

        return CellList(cell_R, cell_id, cell_kwargs)  # pytype: disable=wrong-arg-count

    return build_cells
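A small sketch of the hashing and slot-assignment trick from `build_cells`, with hand-picked constants (a 3x3 grid of unit cells, capacity 2, and hash multipliers chosen for a row-major flattening):

```python
import jax.numpy as np

R = np.array([[0.5, 2.5], [1.5, 0.5], [0.7, 2.7]])   # three particles in a 3x3 box
cell_size = 1.0
hash_multipliers = np.array([1, 3])                  # row-major flattening of the grid
cell_capacity = 2

indices = np.array(R / cell_size, dtype=np.int32)    # integer cell coordinates
hashes = np.sum(indices * hash_multipliers, axis=1)  # flat cell index per particle

sort_map = np.argsort(hashes)
sorted_hash = hashes[sort_map]
slot = np.mod(np.arange(R.shape[0]), cell_capacity)  # 0, 1, 0, 1, ... within the sort
sorted_cell_id = sorted_hash * cell_capacity + slot  # unique buffer slot per particle,
                                                     # as long as no cell overflows
```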
Example #11
    def build_cells(R, **kwargs):
        N = R.shape[0]
        dim = R.shape[1]

        if dim != 2 and dim != 3:
            # NOTE(schsam): Do we want to check this in compute_fn as well?
            raise ValueError(
                'Cell list spatial dimension must be 2 or 3. Found {}'.format(
                    dim))

        neighborhood_tile_count = 3**dim

        _, cell_size, cells_per_side, cell_count = \
            _cell_dimensions(dim, box_size, minimum_cell_size)

        hash_multipliers = _compute_hash_constants(dim, cells_per_side)

        # Create cell list data.
        particle_id = lax.iota(np.int64, N)
        # NOTE(schsam): We use the convention that particles that are successfully
        # copied have their true id whereas empty slots have id = N.
        # Then when we copy data back from the grid, copy it to an array of shape
        # [N + 1, output_dimension] and then truncate it to an array of shape
        # [N, output_dimension] which ignores the empty slots.
        mask_id = np.ones((N, ), np.int64) * N
        cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
        cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

        # It might be worth adding an occupied mask. However, that will involve
        # more compute since often we will do a mask for species that will include
        # an occupancy test. It seems easier to design around this empty_data_value
        # for now and revisit the issue if it comes up later.
        empty_kwarg_value = 10**5
        cell_kwargs = {}
        for k, v in kwargs.items():
            if not isinstance(v, np.ndarray):
                raise ValueError(
                    ('Data must be specified as an ndarray. Found "{}" with '
                     'type {}'.format(k, type(v))))
            if v.shape[0] != R.shape[0]:
                raise ValueError((
                    'Data must be specified per-particle (an ndarray with shape '
                    '(R.shape[0], ...)). Found "{}" with shape {}'.format(
                        k, v.shape)))
            kwarg_shape = v.shape[1:] if v.ndim > 1 else (1, )
            cell_kwargs[k] = empty_kwarg_value * np.ones(
                (cell_count * cell_capacity, ) + kwarg_shape, v.dtype)

        indices = np.array(R / cell_size, dtype=i32)
        hashes = np.sum(indices * hash_multipliers, axis=1)

        # Copy the particle data into the grid. Here we use a trick to allow us to
        # copy into all cells simultaneously using a single lax.scatter call. To do
        # this we first sort particles by their cell hash. We then assign each
        # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
        # is a flat list that repeats 0, .., cell_capacity. So long as there are
        # fewer than cell_capacity particles per cell, each particle is guaranteed
        # to get a cell id that is unique.
        sort_map = np.argsort(hashes)
        sorted_R = R[sort_map]
        sorted_hash = hashes[sort_map]
        sorted_id = particle_id[sort_map]

        sorted_kwargs = {}
        for k, v in kwargs.items():
            sorted_kwargs[k] = v[sort_map]

        sorted_cell_id = np.mod(lax.iota(np.int64, N), cell_capacity)
        sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

        cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
        sorted_id = np.reshape(sorted_id, (N, 1))
        cell_id = ops.index_update(cell_id, sorted_cell_id, sorted_id)
        cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
        cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

        for k, v in sorted_kwargs.items():
            if v.ndim == 1:
                v = np.reshape(v, v.shape + (1, ))
            cell_kwargs[k] = ops.index_update(cell_kwargs[k], sorted_cell_id,
                                              v)
            cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k],
                                                    cells_per_side, dim)

        return CellList(cell_R, cell_id, cell_kwargs)  # pytype: disable=wrong-arg-count
Example #12
 def tec_to_dtec(tec):
     tec = tec.reshape((nant, ndir, ntime))
     dtec = jnp.reshape(tec - tec[0, :, :], (-1, ))
     return dtec
Example #13
def reshape_to_broadcast(array: jnp.array, shape: tuple, axis: int):
    """ reshapes the `array` to be broadcastable to `shape`"""
    new_shape = [1 for _ in shape]
    new_shape[axis] = shape[axis]
    return jnp.reshape(array, new_shape)
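Usage sketch for `reshape_to_broadcast` above:

```python
import jax.numpy as jnp

array = jnp.arange(4.0)                                       # shape (4,)
col = reshape_to_broadcast(array, shape=(3, 4, 5), axis=1)    # shape (1, 4, 1)
result = col * jnp.ones((3, 4, 5))                            # broadcasts along axis 1
```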
Example #14
def unpack(v, sh):
    sz = np0.cumsum([prod(s) for s in sh])
    return [np.reshape(x, s) for x, s in zip(np.split(v, sz), sh)]
Example #15
 def compute():
     sample = self.sample(shape=(sample_size, ))
     sample = np.reshape(sample, newshape=(sample_size, self.d))
     return self.compute_metrics(sample)
Example #16
 def mask_self_fn(idx):
     self_mask = idx == np.reshape(np.arange(idx.shape[0]),
                                   (idx.shape[0], 1))
     return np.where(self_mask, idx.shape[0], idx)
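What `mask_self_fn` does on a small example: entries equal to the row's own index are replaced with N (= `idx.shape[0]`), which marks them as invalid neighbors.

```python
import jax.numpy as np

idx = np.array([[0, 1],
                [1, 2],
                [0, 2]])      # candidate neighbor ids for 3 particles
masked = mask_self_fn(idx)
# masked == [[3, 1],
#            [3, 2],
#            [0, 3]]
```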
Example #17
    def _sp_mcmc(self, rng_key, unconstr_params, *args, **kwargs):
        # 0. Separate classical and stein parameters
        classic_uparams = {
            p: v
            for p, v in unconstr_params.items() if
            p not in self.guide_param_names or self.classic_guide_params_fn(p)
        }
        stein_uparams = {
            p: v
            for p, v in unconstr_params.items() if p not in classic_uparams
        }

        # Fix classical parameters for MCMC run
        self.mcmc.sampler._model = handlers.substitute(
            self.mcmc.sampler._model, self.constrain_fn(classic_uparams))

        # 1. Run warmup on a subset of particles to tune the MCMC state
        warmup_key, mcmc_key = jax.random.split(rng_key)
        if self.mcmc._warmup_state is None:
            stein_subset_uparams = {
                p: v[0:self.num_mcmc_particles]
                for p, v in stein_uparams.items()
            }
            self.mcmc.warmup(
                warmup_key,
                *args,
                init_params=self.constrain_fn(stein_subset_uparams),
                **kwargs)

        # 2. Choose MCMC particles
        mcmc_key, choice_key = jax.random.split(mcmc_key)
        if self.num_mcmc_particles == self.num_stein_particles:
            idxs = np.arange(self.num_stein_particles)
        else:
            if self.sp_mcmc_crit == 'rand':
                idxs = jax.random.shuffle(
                    choice_key, np.arange(
                        self.num_stein_particles))[:self.num_mcmc_particles]
            elif self.sp_mcmc_crit == 'infl':
                _, grads = self._svgd_loss_and_grads(choice_key,
                                                     unconstr_params, *args,
                                                     **kwargs)
                ksd = np.linalg.norm(np.concatenate([
                    np.reshape(grads[p], (self.num_stein_particles, -1))
                    for p in stein_uparams.keys()
                ],
                                                    axis=-1),
                                     ord=2,
                                     axis=-1)
                idxs = np.argsort(ksd)[:self.num_mcmc_particles]
            else:
                assert False, "Unsupported SP MCMC criterion: {}".format(
                    self.sp_mcmc_crit)

        # 3. Run MCMC on chosen particles
        stein_subset_uparams = {p: v[idxs] for p, v in stein_uparams.items()}
        self.mcmc.run(mcmc_key,
                      *args,
                      init_params=self.constrain_fn(stein_subset_uparams),
                      **kwargs)
        samples_subset_stein_params = self.mcmc.get_samples(
            group_by_chain=True)
        sss_uparams = self.uconstrain_fn(samples_subset_stein_params)

        # 4. Select best MCMC iteration to update particles
        scores = jax.vmap(lambda sss_uparam_i: self._score_sp_mcmc(
            mcmc_key, idxs, stein_uparams, sss_uparam_i, classic_uparams, *
            args, **kwargs),
                          in_axes=1)(sss_uparams)
        mcmc_idx = np.argmax(scores)
        stein_uparams = {
            p: ops.index_update(v, idxs, sss_uparams[:, mcmc_idx])
            for p, v in stein_uparams.items()
        }
        return {**stein_uparams, **classic_uparams}
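A small sketch of the influence ranking used in step 2 ('infl' criterion): flatten each particle's gradients across parameters, take per-particle norms, and keep the particles with the smallest scores. Shapes and values here are made up.

```python
import jax.numpy as np

num_stein_particles, num_mcmc_particles = 5, 2
grads = {'w': np.ones((5, 3)), 'b': 2.0 * np.ones((5, 1))}

flat = np.concatenate(
    [np.reshape(grads[p], (num_stein_particles, -1)) for p in grads],
    axis=-1)                                       # shape (5, 4)
ksd = np.linalg.norm(flat, ord=2, axis=-1)         # one score per particle
idxs = np.argsort(ksd)[:num_mcmc_particles]        # particles with the smallest score
```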
Example #18
 def _validate_sample(self, value):
     mask = super(ImproperUniform, self)._validate_sample(value)
     batch_dim = jnp.ndim(value) - len(self.event_shape)
     if batch_dim < jnp.ndim(mask):
         mask = jnp.all(jnp.reshape(mask, jnp.shape(mask)[:batch_dim] + (-1,)), -1)
     return mask
Example #19
 def tile_grid(self, x):
     # takes slots (batch, k, d) and returns (batch*k, w, h, d):
     # collapse batches (so standard layers can be applied per slot) and tile
     # each slot's features across a w x h grid (spatial broadcast).
     # This representational mapping is reminiscent of grid cells / conceptual spaces.
     x = jnp.reshape(x, (x.shape[0]*x.shape[1], 1, 1, x.shape[-1]))
     return jnp.tile(x, [1,*self.C['spatial_broadcast_dims'],1])
Example #20
 def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
     """Splits logits for actions in different controls and flattens controls."""
     return np.reshape(x, (x.shape[0], -1, n_actions))
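Shape sketch for the layer above, assuming logits whose last dimension packs `n_controls * n_actions` values:

```python
import jax.numpy as np

n_actions = 4
x = np.zeros((2, 7, 3 * n_actions))                # (batch, time, controls * actions)
y = np.reshape(x, (x.shape[0], -1, n_actions))     # (2, 21, 4): controls folded into time
```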
Example #21
  def __call__(self, x: Array, *, train: bool, debug: bool = False):
    """Applies the module."""
    input_shape = x.shape
    b, h, w, _ = input_shape

    fh, fw = self.patches.size
    gh, gw = h // fh, w // fw

    if self.backbone_configs.type == 'vit' and self.decoder_configs.type == 'linear':
      assert self.backbone_configs.ens_size == 1

    if self.backbone_configs.type == 'vit' and self.decoder_configs.type == 'linear_be':
      raise NotImplementedError(
          'Configuration with encoder {} and decoder {} is not implemented'
          .format(
              self.backbone_configs.type,
              self.decoder_configs.type,
          ))

    if self.backbone_configs.type == 'vit':
      x, out = segmenter.ViTBackbone(
          mlp_dim=self.backbone_configs.mlp_dim,
          num_layers=self.backbone_configs.num_layers,
          num_heads=self.backbone_configs.num_heads,
          patches=self.patches,
          hidden_size=self.backbone_configs.hidden_size,
          dropout_rate=self.backbone_configs.dropout_rate,
          attention_dropout_rate=self.backbone_configs.attention_dropout_rate,
          classifier=self.backbone_configs.classifier,
          name='backbone')(
              x, train=train)
    elif self.backbone_configs.type == 'vit_be':
      x, out = ViTBackboneBE(
          mlp_dim=self.backbone_configs.mlp_dim,
          num_layers=self.backbone_configs.num_layers,
          num_heads=self.backbone_configs.num_heads,
          patches=self.patches,
          hidden_size=self.backbone_configs.hidden_size,
          dropout_rate=self.backbone_configs.dropout_rate,
          attention_dropout_rate=self.backbone_configs.attention_dropout_rate,
          classifier=self.backbone_configs.classifier,
          ens_size=self.backbone_configs.ens_size,
          random_sign_init=self.backbone_configs.random_sign_init,
          be_layers=self.backbone_configs.be_layers,
          name='backbone')(
              x, train=train)
    else:
      raise ValueError(f'Unknown backbone: {self.backbone_configs.type}.')

    if self.decoder_configs.type == 'linear':
      output_projection = nn.Dense(
          self.num_classes,
          kernel_init=self.head_kernel_init,
          name='output_projection')
    elif self.decoder_configs.type == 'linear_be':
      output_projection = ed.nn.DenseBatchEnsemble(
          self.num_classes,
          self.backbone_configs.ens_size,
          activation=None,
          alpha_init=ed.nn.utils.make_sign_initializer(
              self.backbone_configs.get('random_sign_init')),
          gamma_init=ed.nn.utils.make_sign_initializer(
              self.backbone_configs.get('random_sign_init')),
          kernel_init=self.head_kernel_init,
          name='output_projection_be')
    else:
      raise ValueError(
          f'Decoder type {self.decoder_configs.type} is not defined.')

    ens_size = self.backbone_configs.get('ens_size')

    # Linear head only, like Segmenter baseline:
    # https://arxiv.org/abs/2105.05633
    x = jnp.reshape(x, [b * ens_size, gh, gw, -1])
    x = output_projection(x)

    # Resize bilinearly:
    x = jax.image.resize(x, [b * ens_size, h, w, x.shape[-1]], 'linear')
    out['logits'] = x

    new_input_shape = tuple([
        input_shape[0] * ens_size,
    ] + list(input_shape[1:-1]))
    assert new_input_shape == x.shape[:-1], (
        'BE Input and output shapes do not match: %d vs. %d.', new_input_shape,
        x.shape[:-1])

    return x, out
Example #22
 def _may_repeat(self, z):
     """Enforces rank 2 on z, repeating itself if needed to match the batch."""
     z = np.array(z)
     if len(z.shape) < len(self.x.shape):
         z = np.reshape(np.tile(z, [self._batch]), (self._batch, -1))
     return z
Example #23
 def lagr_form(x):
     val = opt_instance.lagrangian_form_pre.apply(
         x, opt_instance.lagrange_params_pre, step)
     return jnp.reshape(val, ())
Example #24
 def _squash_sample_dims(v: jnp.array) -> jnp.array:
     old_shape = jnp.shape(v)
     assert len(old_shape) >= 2
     new_shape = (old_shape[0] * old_shape[1], *old_shape[2:])
     reshaped_v = jnp.reshape(v, new_shape)
     return reshaped_v
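Usage sketch for `_squash_sample_dims` above, assuming the two leading axes are, for example, chains and samples:

```python
import jax.numpy as jnp

v = jnp.zeros((4, 3, 2))          # (num_chains, num_samples, event_dim)
flat = _squash_sample_dims(v)     # shape (12, 2)
```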
Example #25
 def sample(self, key, sample_shape=()):
     return jnp.reshape(
         random.split(key,
                      np.prod(sample_shape).astype(np.int32)),
         sample_shape + self.event_shape)
Example #26
def flatten(x):
    return np.reshape(x, (x.shape[0], -1))
Example #27
def momentum(g_dd, y_train, loss, learning_rate, g_td=None, momentum=0.9):
  r"""Predicts the outcome of function space training using momentum descent.

  Solves a continuous-time version of standard momentum instead of
  Nesterov momentum using an ODE solver.

  Solves the function space ODE for momentum with a given loss (detailed
  in [*]) given a Neural Tangent Kernel over the dataset. This function returns
  a triplet of functions that initialize state variables, predicts the time
  evolution for function space points at arbitrary times and retrieves the
  function-space outputs from the state. Note that times are continuous and are
  measured in units of the learning rate so that
  t = \sqrt(learning_rate) * steps.

  This function uses the scipy ode solver with the 'dopri5' algorithm.

  [*] https://arxiv.org/abs/1902.06720

  Example:
    ```python
    >>> train_time = 1e-7
    >>> learning_rate = 1e-2
    >>>
    >>> kernel_fn = empirical(f)
    >>> g_td = kernel_fn(x_test, x_train, params)
    >>>
    >>> from jax.experimental import stax
    >>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
    >>> init_fn, predict_fn, get_fn = predict.momentum(
    >>>                   g_dd, y_train, cross_entropy, learning_rate, g_td)
    >>>
    >>> fx_train_initial = f(params, x_train)
    >>> fx_test_initial = f(params, x_test)
    >>>
    >>> lin_state = init_fn(fx_train_initial, fx_test_initial)
    >>> lin_state = predict_fn(lin_state, train_time)
    >>> fx_train_final, fx_test_final = get_fn(lin_state)
    ```

  Args:
    g_dd: Kernel on the training data. The kernel should be an `np.ndarray` of
      shape [n_train * output_dim, n_train * output_dim].
    y_train: A `np.ndarray` of shape [n_train, output_dim] of labels for the
      training data.
    loss: A loss function whose signature is loss(fx, y_hat) where fx is an
      `np.ndarray` of function space outputs of the network and y_hat are
      labels. Note: the loss function should treat the batch and output
        dimensions symmetrically.
    learning_rate:  A float specifying the learning rate.
    g_td: Kernel relating training data with test data. Should be an
      `np.ndarray` of shape [n_test * output_dim, n_train * output_dim]. Note:
        g_td should have been created in the convention g_td = kernel_fn(x_test,
        x_train, params).
    momentum: float specifying the momentum.

  Returns:
    Functions to predict outputs after t = \sqrt(learning_rate) * steps of
    training. Generically three functions are returned, an init_fn that creates
    auxiliary velocity variables needed for optimization and packs them into
    a state variable, a predict_fn that computes the time-evolution of the state
    for some dt, and a get_fn that extracts the predictions from the state.

    If g_td is None:
      init_fn(fx_train): Takes a single `np.ndarray` of shape
        [n_train, output_dim] and returns a tuple containing the output_dim as
        an int and an `np.ndarray` of shape [2 * n_train * output_dim].

      predict_fn(state, dt): Takes a state described above and a floating point
        time. Returns a new state with the same type and shape.

      get_fn(state): Takes a state and returns an `np.ndarray` of shape
        [n_train, output_dim].

    If g_td is not None:
      init_fn(fx_train, fx_test): Takes two `np.ndarray`s of shape
        [n_train, output_dim] and [n_test, output_dim] respectively. Returns a
        tuple with an int giving 2 * n_train * output_dim, an int containing the
        output_dim, and an `np.ndarray` of shape
        [2 * (n_train + n_test) * output_dim].

      predict_fn(state, dt): Takes a state described above and a floating point
        time. Returns a new state with the same type and shape.

      get_fn(state): Takes a state and returns two `np.ndarray` of shape
        [n_train, output_dim] and [n_test, output_dim] respectively.
  """
  output_dimension = y_train.shape[-1]

  g_dd = empirical.flatten_features(g_dd)

  momentum = (momentum - 1.0) / np.sqrt(learning_rate)

  def fl(fx):
    """Flatten outputs."""
    return np.reshape(fx, (-1,))

  def ufl(fx):
    """Unflatten outputs."""
    return np.reshape(fx, (-1, output_dimension))

  # These functions are used inside the integrator only if the kernel is
  # diagonal over the logits.
  ifl = lambda x: x
  iufl = lambda x: x

  # Check to see whether the kernel has a logit dimension.
  if y_train.size > g_dd.shape[-1]:
    out_dim, ragged = divmod(y_train.size, g_dd.shape[-1])
    if ragged or out_dim != y_train.shape[-1]:
      raise ValueError('Shape of y_train is incompatible with the shape of g_dd.')
    ifl = fl
    iufl = ufl

  y_train = np.reshape(y_train, (-1))
  grad_loss = grad(functools.partial(loss, y_hat=y_train))

  if g_td is None:

    def dr_dt(unused_t, r):
      fx, qx = np.split(r, 2)
      dfx = qx
      dqx = momentum * qx - ifl(np.dot(g_dd, iufl(grad_loss(fx))))
      return np.concatenate((dfx, dqx), axis=0)

    def init_fn(fx_train=0.):
      fx_train = fl(fx_train)
      qx_train = np.zeros_like(fx_train)
      return np.concatenate((fx_train, qx_train), axis=0)

    def predict_fn(state, dt):
      state = state

      solver = ode(dr_dt).set_integrator('dopri5')
      solver.set_initial_value(state, 0)
      solver.integrate(dt)

      return solver.y

    def get_fn(state):
      return ufl(np.split(state, 2)[0])

  else:
    g_td = empirical.flatten_features(g_td)

    def dr_dt(unused_t, r, train_size):
      train, test = r[:train_size], r[train_size:]
      fx_train, qx_train = np.split(train, 2)
      _, qx_test = np.split(test, 2)
      dfx_train = qx_train
      dqx_train = \
          momentum * qx_train - ifl(np.dot(g_dd, iufl(grad_loss(fx_train))))
      dfx_test = qx_test
      dqx_test = \
          momentum * qx_test - ifl(np.dot(g_td, iufl(grad_loss(fx_train))))
      return np.concatenate((dfx_train, dqx_train, dfx_test, dqx_test), axis=0)

    def init_fn(fx_train=0., fx_test=0.):
      train_size = fx_train.shape[0]
      fx_train, fx_test = fl(fx_train), fl(fx_test)
      qx_train = np.zeros_like(fx_train)
      qx_test = np.zeros_like(fx_test)
      return (2 * train_size * output_dimension,
              np.concatenate((fx_train, qx_train, fx_test, qx_test), axis=0))

    def predict_fn(state, dt):
      train_size, state = state
      solver = ode(dr_dt).set_integrator('dopri5')
      solver.set_initial_value(state, 0).set_f_params(train_size)
      solver.integrate(dt)

      return train_size, solver.y

    def get_fn(state):
      train_size, state = state
      train, test = state[:train_size], state[train_size:]
      return ufl(np.split(train, 2)[0]), ufl(np.split(test, 2)[0])

  return init_fn, predict_fn, get_fn
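Isolating the `set_f_params` pattern used in the `g_td` branch above: extra static arguments (here the train size) are passed to the ODE right-hand side through the solver. The dynamics are a toy stand-in.

```python
import numpy as onp
from scipy.integrate import ode

def dr_dt(unused_t, r, train_size):
    # Toy dynamics: damp the first `train_size` entries, leave the rest fixed.
    out = onp.zeros_like(r)
    out[:train_size] = -r[:train_size]
    return out

solver = ode(dr_dt).set_integrator('dopri5')
solver.set_initial_value(onp.ones(4), 0).set_f_params(2)
solver.integrate(1.0)
print(solver.y)   # first two entries decayed toward exp(-1), last two still 1.0
```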
Example #28
def gradient_descent_predictor(g_dd, y_train, loss, g_td=None):
    """Predicts the outcome of function space training using gradient descent.

  Solves the function space ODE for gradient descent with a given loss (detailed
  in [*]) given a Neural Tangent Kernel over the dataset. This function returns
  a function that predicts the time evolution for function space points at
  arbitrary times. Note that times are continuous and are measured in units of
  the learning rate so that t = learning_rate * steps.

  This function uses the scipy ode solver with the 'dopri5' algorithm.

  [*] https://arxiv.org/abs/1806.07572

  Example:
    >>> train_time = 1e-7
    >>> kernel_fn = ntk(f)
    >>> g_dd = compute_spectrum(kernel_fn(params, x_train))
    >>> g_td = kernel_fn(params, x_test, x_train)
    >>>
    >>> from jax.experimental import stax
    >>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
    >>> predict_fn = gradient_descent_predictor(
    >>>                   g_dd, train_y, cross_entropy, g_td)
    >>>
    >>> fx_train_initial = f(params, x_train)
    >>> fx_test_initial = f(params, x_test)
    >>>
    >>> fx_train_final, fx_test_final = predict_fn(
    >>>          fx_train_initial, fx_test_initial, train_time)

  Args:
    g_dd: A Kernel on the training data. The kernel should be an ndarray of
      shape [n_train * output_dim, n_train * output_dim].
    y_train: An ndarray of shape [n_train, output_dim] of labels for the
      training data.
    loss: A loss function whose signature is loss(fx, y_hat) where fx is an
      ndarray of function space outputs of the network and y_hat are
      targets.

      Note: the loss function should treat the batch and output dimensions
      symmetrically.
    g_td: A Kernel relating training data with test data. The kernel should be
      an ndarray of shape [n_test * output_dim, n_train * output_dim].

      Note: g_td should have been created in the convention
      kernel_fn(params, x_test, x_train).

  Returns:
    A function that predicts outputs after t = learning_rate * steps of
    training.

    If g_td is None:
      The function returned is predict(fx, t). Here fx is an ndarray of network
      outputs and has shape [n_train, output_dim], t is a floating point time.
      predict(fx, t) returns an ndarray of predictions of shape
      [n_train, output_dim].

    If g_td is not None:
      If a test set Kernel is specified then it returns a function,
      predict(fx_train, fx_test, t). Here fx_train and fx_test are ndarrays of
      network outputs and have shape [n_train, output_dim] and
      [n_test, output_dim] respectively and t is a floating point time.
      predict(fx_train, fx_test, t) returns a tuple of predictions of shape
      [n_train, output_dim] and [n_test, output_dim] for train and test points
      respectively.
  """
    y_train = np.reshape(y_train, (-1))
    grad_loss = grad(functools.partial(loss, y_hat=y_train))

    def fl(fx):
        """Flatten outputs."""
        return np.reshape(fx, (-1, ))

    def ufl(fx, output_dim):
        """Unflatten outputs."""
        return np.reshape(fx, (-1, output_dim))

    if g_td is None:
        dfx_dt = lambda unused_t, fx: -np.dot(g_dd, grad_loss(fx))

        def predict(fx, dt):
            r = ode(dfx_dt).set_integrator('dopri5')
            r.set_initial_value(fl(fx), 0)
            r.integrate(dt)

            return ufl(r.y, fx.shape[-1])
    else:

        def dfx_dt(unused_t, fx, train_size):
            fx_train = fx[:train_size]
            dfx_train = -np.dot(g_dd, grad_loss(fx_train))
            dfx_test = -np.dot(g_td, grad_loss(fx_train))
            return np.concatenate((dfx_train, dfx_test), axis=0)

        def predict(fx_train, fx_test, dt):
            r = ode(dfx_dt).set_integrator('dopri5')

            fx = fl(np.concatenate((fx_train, fx_test), axis=0))
            train_size, output_dim = fx_train.shape
            r.set_initial_value(fx, 0).set_f_params(train_size * output_dim)
            r.integrate(dt)
            fx = ufl(r.y, output_dim)

            return fx[:train_size], fx[train_size:]

    return predict
Example #29
 def ufl(fx):
   """Unflatten outputs."""
   return np.reshape(fx, (-1, output_dimension))
Example #30
def shape_as_image(images, labels, dummy_dim=False):
  target_shape = (-1, 1, 28, 28, 1) if dummy_dim else (-1, 28, 28, 1)
  return np.reshape(images, target_shape), labels
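Usage sketch for `shape_as_image` with a flat MNIST-style batch:

```python
import jax.numpy as np

images = np.zeros((32, 784))                              # flattened 28x28 grayscale images
labels = np.zeros((32, 10))
x, y = shape_as_image(images, labels)                     # x.shape == (32, 28, 28, 1)
x5, _ = shape_as_image(images, labels, dummy_dim=True)    # x5.shape == (32, 1, 28, 28, 1)
```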