def piecewise_constant(boundaries, values):
  boundaries = np.array(boundaries)
  values = np.array(values)
  if not boundaries.ndim == values.ndim == 1:
    raise ValueError("boundaries and values must be sequences")
  if not boundaries.shape[0] == values.shape[0] - 1:
    raise ValueError("boundaries length must be one shorter than values length")

  def schedule(i):
    return values[np.sum(i > boundaries)]
  return schedule
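A minimal usage sketch of `piecewise_constant` (hypothetical values, assuming `numpy` imported as `np`): the returned schedule indexes `values` by how many boundaries the step `i` has passed.

# Hypothetical learning-rate schedule: drop the rate at steps 10 and 20.
schedule = piecewise_constant([10, 20], [0.1, 0.01, 0.001])
schedule(5)   # 0.1   (no boundary passed)
schedule(15)  # 0.01  (past the first boundary)
schedule(25)  # 0.001 (past both boundaries)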
    def testSize(self):
        def run_test(arr):
            onp_arr = arr.numpy() if isinstance(arr, tf.Tensor) else arr
            print(onp_arr)
            self.assertEqual(np_size(arr), onp.size(onp_arr))

        run_test(np.array([1]))
        run_test(np.array([1, 2, 3, 4, 5]))
        run_test(np.ones((2, 3, 2)))
        run_test(np.ones((3, 2)))
        run_test(np.zeros((5, 6, 7)))
        run_test(1)
        run_test(onp.ones((3, 2, 1)))
        run_test(tf.constant(5))
        run_test(tf.constant([1, 1, 1]))
Example #3
        def predict_fn_finite(t, fx_train_0, fx_test_0, k_test_train):
            t = np.array(t) * learning_rate
            t_shape, t_ndim = t.shape, t.ndim
            t = t.reshape((-1, 1))

            rhs = -y_train if fx_train_0 is None else fx_train_0 - y_train
            rhs = np.moveaxis(rhs, trace_axes,
                              last_t_axes).reshape((-1, ) + rhs_shape)
            shape = t_shape + k_train_train.shape[1::2] + rhs_shape

            if fx_train_0 is not None:
                dfx_train = expm1_fn(rhs, t).reshape(shape)
                dfx_train = np.moveaxis(dfx_train, last_t_axes, trace_axes)
                fx_train_t = fx_train_0 + dfx_train

            if fx_test_0 is not None:
                dfx_test = inv_expm1_fn(rhs, t).reshape(shape)
                dfx_test = np.tensordot(k_test_train, dfx_test,
                                        (odd, non_t_axes))
                dfx_test = np.moveaxis(
                    dfx_test,
                    tuple(range(n_non_t_axes, n_non_t_axes + t_ndim)) +
                    last_t_axes,
                    tuple(range(t_ndim)) + trace_axes)
                fx_test_t = fx_test_0 + dfx_test

            if fx_train_0 is not None and fx_test_0 is not None:
                return fx_train_t, fx_test_t
            if fx_test_0 is None:
                return fx_train_t
            return fx_test_t
Example #4
    def testGradientDescentMseEnsembleGet(self, train_shape, test_shape,
                                          network, out_logits):
        _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                       train_shape)
        _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

        predictor = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                          x_train,
                                                          y_train,
                                                          diag_reg=0.)
        for x in [None, 'x_test']:
            with self.subTest(x=x):
                x = x if x is None else x_test
                out = predictor(None, x, 'ntk', compute_cov=True)
                assert isinstance(out, predict.Gaussian)

                out = predictor(1., x, 'nngp', compute_cov=True)
                assert isinstance(out, predict.Gaussian)

                out = predictor(np.array([0., 1.]),
                                x, ('ntk', ),
                                compute_cov=True)
                assert len(out) == 1 and isinstance(out[0], predict.Gaussian)

                out = predictor(2., x, ('ntk', 'nngp'), compute_cov=True)
                assert (len(out) == 2 and isinstance(out[0], predict.Gaussian)
                        and isinstance(out[1], predict.Gaussian))

                out2 = predictor(2., x, ('nngp', 'ntk'), compute_cov=True)
                self.assertAllClose(out[0], out2[1])
                self.assertAllClose(out[1], out2[0])
Example #5
    def testNTKMeanCovPrediction(self, train_shape, test_shape, network,
                                 out_logits):
        key, x_test, x_train, y_train = self._get_inputs(
            out_logits, test_shape, train_shape)
        init_fn, f, kernel_fn = stax.serial(
            stax.Dense(512, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        reg = 1e-6
        predictor = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                          x_train,
                                                          y_train,
                                                          diag_reg=reg)
        ts = np.array([1., 5., 10.])

        fx_test_inf, cov_test_inf = predictor(ts, x_test, 'ntk', True)
        self.assertEqual(cov_test_inf.shape[1], x_test.shape[0])
        self.assertGreater(np.min(np.linalg.eigh(cov_test_inf)[0]), -1e-8)

        fx_train_inf, cov_train_inf = predictor(ts, None, 'ntk', True)
        self.assertEqual(cov_train_inf.shape[1], x_train.shape[0])
        self.assertGreater(np.min(np.linalg.eigh(cov_train_inf)[0]), -1e-8)

        _kernel_fn = empirical.empirical_kernel_fn(f)
        kernel_fn = jit(
            lambda x1, x2, params: _kernel_fn(x1, x2, 'ntk', params))

        def predict_empirical(key):
            _, params = init_fn(key, train_shape)
            g_dd = kernel_fn(x_train, None, params)
            g_td = kernel_fn(x_test, x_train, params)
            predict_fn = predict.gradient_descent_mse(g_dd,
                                                      y_train,
                                                      diag_reg=reg)
            fx_train_0 = f(params, x_train)
            fx_test_0 = f(params, x_test)
            return predict_fn(ts, fx_train_0, fx_test_0, g_td)

        def predict_mc(count, key):
            key = tf_random_split(key, count)
            fx_train, fx_test = vmap(predict_empirical)(key)
            fx_train_mean = np.mean(fx_train, axis=0)
            fx_test_mean = np.mean(fx_test, axis=0)

            fx_train_centered = fx_train - fx_train_mean
            fx_test_centered = fx_test - fx_test_mean

            cov_train = PredictTest._cov_empirical(fx_train_centered)
            cov_test = PredictTest._cov_empirical(fx_test_centered)

            return fx_train_mean, fx_test_mean, cov_train, cov_test

        fx_train_mc, fx_test_mc, cov_train_mc, cov_test_mc = predict_mc(
            4096, key)
        rtol = 0.05
        self._assertAllClose(fx_train_mc, fx_train_inf, rtol)
        self._assertAllClose(cov_train_mc, cov_train_inf, rtol)
        self._assertAllClose(cov_test_mc, cov_test_inf, rtol)
        self._assertAllClose(fx_test_mc, fx_test_inf, rtol)
 def apply_fun(params, inputs, **kwargs):
   # Move the batch and channel axes into the canonical (N, ..., C) layout.
   inputs = onp.moveaxis(inputs, (batch_dim, channel_dim), (0, dim + 1))
   out = reduce_window(inputs, init_val, reducer, window_shape,
                       strides, padding)
   return rescale(out, inputs, spec) if rescale else out
Example #7
  def ntk_fn(x1: np.ndarray,
             x2: Optional[np.ndarray],
             params: PyTree,
             keys: Union[PRNGKey,
                         Tuple[PRNGKey, PRNGKey],
                         np.ndarray] = None,
             **apply_fn_kwargs) -> np.ndarray:
    """Computes a single sample of the empirical NTK (implicit differentiation).

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params:
        A `PyTree` of parameters about which we would like to compute the
        neural tangent kernel.
      keys:
        `None` or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
        dtype `uint32`. If `keys=None`, then the function `f` is deterministic
        and requires no PRNG key; else if `keys` is a single PRNG key, then `x1`
        and `x2` must be the same and share the same PRNG key; else `x1` and
        `x2` use two different PRNG keys.
      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    key1, key2 = _read_keys(keys)
    # TODO(xlc): find a good way to check utils.x1_is_x2(x1, x2) == (key1==key2)

    f1 = _get_f_params(f, x1, key1, **apply_fn_kwargs)
    f2 = f1 if x2 is None else _get_f_params(f, x2, key2, **apply_fn_kwargs)

    def delta_vjp_jvp(delta):
      def delta_vjp(delta):
        return vjp(f2, params)[1](delta)
      return _jvp(f1, _tf_to_np((params,)), delta_vjp(delta))[1]

    # Since we are taking the Jacobian of a linear function (which does not
    # depend on its coefficients), it is more efficient to substitute fx_dummy
    # for the outputs of the network. fx_dummy has the same shape as the output
    # of the network on a single piece of input data.
    fx2_struct = eval_on_shapes(f2)(params)
    fx_dummy = np.ones(fx2_struct.shape, dtype=tf.float32)

    # ntk = jacobian(delta_vjp_jvp)(fx_dummy)
    with tf.GradientTape() as tape:
      tape.watch(fx_dummy.data)
      y = delta_vjp_jvp(fx_dummy.data)
    ntk = np.array(tape.jacobian(y, fx_dummy.data))
    return _index_and_contract(ntk, trace_axes, diagonal_axes)
Example #8
def get_masked_array(x: ArrayOrList,
                     mask_constant: float = None) -> MaskedArray:
  """Return `x` with entries equal to `mask_constant` zeroed-out, and the mask.

  The mask returned is a boolean `np.ndarray` with masked indices having `True`.

  Args:
    x: `np.ndarray` to mask. If `x` is a `MaskedArray`, treat it as
      `(masked_x, mask)` and pass it through.
    mask_constant: an optional `float`, the value in inputs to be considered as
      masked (e.g. padding in a batch of sentences). `None` means no masking.
      Can also be `np.nan`, `np.inf` etc.

  Returns:
    A `MaskedArray` of `(masked_x, boolean_mask)`.
  """
  if isinstance(x, list):
    x_array = []
    mask_array = []
    for x_ in x:
      masked_array = get_masked_array(x_, mask_constant)
      x_array.append(masked_array.masked_value)
      mask_array.append(masked_array.mask)
    # fields = zip(*(get_masked_array(_x, mask_constant).astuple() for _x in x))
    # return MaskedArray(*(list(f) for f in fields))
    return MaskedArray(x_array, mask_array)

  if x is None:
    mask = None

  elif isinstance(x, MaskedArray):
    masked_value = x.masked_value
    mask = x.mask
    x = masked_value

  elif isinstance(x, np.ndarray) or isinstance(x, onp.ndarray):
    x = np.asarray(x)
    if mask_constant is None:
      mask = None
    else:
      choice_a = lambda: np.array(tf.math.is_nan(x))
      choice_b = lambda: x == mask_constant
      # mask = choice_a(x) if math.isnan(mask_constant) else choice_b(x)
      mask = tf.cond(tf.math.is_nan(mask_constant), choice_a, choice_b)
  else:
    raise TypeError(x, type(x))

  if mask is not None:
    x = np.where(mask, np.zeros((), x.dtype), x)

  return MaskedArray(x, mask)  # pytype: disable=wrong-arg-count
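A small usage sketch (hypothetical values; assumes the surrounding module's `np` and `MaskedArray` definitions): entries equal to `mask_constant` get `True` in the boolean mask and are zeroed in the returned array.

x = np.array([[1., 0., 2.],
              [0., 3., 0.]])
masked = get_masked_array(x, mask_constant=0.)
# masked.mask         == [[False, True, False], [True, False, True]]
# masked.masked_value == [[1., 0., 2.], [0., 3., 0.]]  (masked entries zeroed)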
Example #9
    def evaluate(self, x, y):
        """Returns the number of correct predictions.

        Args:
            x: 2-d array of size batch_size x image_size.
            y: 2-d array of size batch_size x num_classes.

        Returns:
            A scalar, the number of correct predictions.
        """
        y_actual = np.argmax(y, axis=1)
        y_predicted = np.argmax(self.forward(x), axis=1)
        correct = int(np.sum(np.array(y_actual == y_predicted)))
        return correct
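A quick sanity check of the counting logic above (hypothetical arrays, plain NumPy):

import numpy as np

y      = np.array([[0, 1], [1, 0], [0, 1]])          # one-hot labels
logits = np.array([[.2, .8], [.9, .1], [.7, .3]])    # network outputs
correct = int(np.sum(np.argmax(y, axis=1) == np.argmax(logits, axis=1)))
# correct == 2: the first two predictions match the labels, the third does not.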
Example #10
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
  """Compute the shape tuple of a conv given input shapes in canonical order."""
  if isinstance(pads, str):
    pads = padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads)
  if len(pads) != len(lhs_shape) - 2:
    msg = 'Wrong number of explicit pads for convolution: expected {}, got {}.'
    raise TypeError(msg.format(len(lhs_shape) - 2, len(pads)))

  lhs_padded = onp.add(lhs_shape[2:], np.sum(np.array(pads).reshape(-1, 2),
                                              axis=1))
  out_space = np.floor_divide(
    np.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
  out_space = np.maximum(0, out_space)
  assert lhs_shape[0] % batch_group_count == 0
  out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0])
  return tuple(out_shape + tuple(out_space))
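A worked sketch of the spatial-size arithmetic above (hypothetical shapes, plain NumPy mirroring the same formula: out = floor((in + pad_lo + pad_hi - kernel) / stride) + 1):

import numpy as onp

lhs_shape = (1, 3, 32, 32)          # NCHW input
rhs_shape = (8, 3, 5, 5)            # OIHW kernel
pads, strides = [(2, 2), (2, 2)], (2, 2)
lhs_padded = onp.add(lhs_shape[2:],
                     onp.sum(onp.array(pads).reshape(-1, 2), axis=1))   # [36, 36]
out_space = onp.floor_divide(onp.subtract(lhs_padded, rhs_shape[2:]),
                             strides) + 1                               # [16, 16]
# conv_shape_tuple(lhs_shape, rhs_shape, strides, pads) would return (1, 8, 16, 16).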
Example #11
    def testGradientDescentMseEnsembleTrain(self):
        key = stateless_uniform(shape=[2],
                                seed=[1, 1],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        x = np.asarray(normal((8, 4, 6, 3), seed=key))
        _, _, kernel_fn = stax.serial(stax.Conv(1, (2, 2)), stax.Relu(),
                                      stax.Conv(1, (2, 1)))
        y = np.asarray(normal((8, 2, 5, 1), seed=key))
        predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x, y)

        for t in [None, np.array([0., 1., 10.])]:
            with self.subTest(t=t):
                y_none = predictor(t, None, None, compute_cov=True)
                y_x = predictor(t, x, None, compute_cov=True)
                self._assertAllClose(y_none, y_x, 0.04)
Example #12
 def test_tf_dot_general(self, lhs_np, rhs_np, dims):
     ans = jax.lax.dot_general(lhs_np, rhs_np, dims)
     result = lax.dot_general(lhs_np, rhs_np, dims)
     self.assertAllClose(result, tfnp.array(ans))
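A hypothetical input triple for this test (dimension numbers follow `jax.lax.dot_general`): contracting the last axis of `lhs_np` with the first axis of `rhs_np`, with no batch dimensions, is an ordinary matrix product.

import numpy as onp

lhs_np = onp.ones((3, 4), dtype=onp.float32)
rhs_np = onp.ones((4, 5), dtype=onp.float32)
dims = (((1,), (0,)), ((), ()))   # ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
# jax.lax.dot_general(lhs_np, rhs_np, dims) equals lhs_np @ rhs_np, with shape (3, 5).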
Example #13
    def predict_fn(
        t: ArrayOrScalar = None,
        fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
        fx_test_0: ArrayOrScalar = None,
        k_test_train: np.ndarray = None
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]:
        """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape in strictly increasing order.
        `t=None` is equivalent to `t=np.inf` and may not converge. Equivalent of
        training steps (but can be fractional).
      fx_train_or_state_0:
        either (a) output of the network at `t == 0` on the training set or (b)
        complete ODE state (`predict.ODEState`). Pass an ODE state if you want
        to operate on the full ODE state instead of output variables only
        (useful for inspecting auxiliary variables or resuming an optimizer with
        auxiliary variables from a specific state). Note that currently only the
        `momentum != None` optimizer has auxiliary variables. To
        initialize an ODE state from scratch, call
        `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed, an
        ODE state is returned. `fx_train_0=None` means to not compute
        predictions on the training set.
      fx_test_0:
        output of the network at `t == 0` on the test set. `fx_test_0=None`
        means to not compute predictions on the test set.
      k_test_train:
        kernel relating test data with training data. Must have the shape of
        `zip(y_test.shape, y_train.shape)` with `trace_axes` absent. Pass
        `k_test_train=None` if you only need predictions on the training set.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`.
      Alternatively can return an `ODEState` at time[s] `t`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
    """
        _check_inputs(fx_train_or_state_0, fx_test_0, k_test_train)

        t = np.array(t if t is not None else np.inf, dtype) * learning_rate
        t_shape = t.shape
        t = t.reshape((-1, ))

        # ODE solver requires `t[0]` to be the time where `fx_train_0` [and
        # `fx_test_0`] are evaluated, but also a strictly increasing sequence of
        # timesteps, so we always temporarily prepend an [almost] `0` at the start.
        t0 = np.where(t[0] == 0, np.full((1, ), -1e-24, t.dtype),
                      np.zeros((1, ), t.dtype))
        t = np.concatenate([t0, t])

        # Solve the ODE.
        fx_test_shape = _get_fx_test_shape(y_train, k_test_train, trace_axes)
        state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
        state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)

        # Remove the added `t0`.
        trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
        trim_tree = lambda tree: tree_map(trim, tree)
        state_t = trim_tree(state_t)

        # `ODEState` -> `ODEState`
        if isinstance(fx_train_or_state_0, ODEState):
            return state_t

        # `np.ndarray` -> `np.ndarray`
        fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

        if fx_train_or_state_0 is not None and fx_test_0 is None:
            return fx_train_t
        if fx_test_0 is not None and fx_train_or_state_0 is None:
            return fx_test_t
        return fx_train_t, fx_test_t
Example #14
    def predict_fn(t: ArrayOrScalar = None,
                   x_test: np.ndarray = None,
                   get: Get = None,
                   compute_cov: bool = False) -> Dict[str, Gaussian]:
        """Return output mean and covariance on the test set at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape. `t=None` is treated as
        infinity and returns the same result as `t=np.inf`, but is computed
        using linear solve for test predictions instead of eigendecomposition,
        saving time and precision.
      x_test:
        test inputs. `None` means to return non-regularized (`diag_reg=0`)
        predictions on the train-set inputs. For regularized predictions, pass
        `x_test=x_train`.
      get:
        string, the mode of the Gaussian process, either "nngp" or "ntk", or a
        tuple. `get=None` is equivalent to `get=("nngp", "ntk")`.
      compute_cov:
        if `True`, compute both `mean` and `covariance`; otherwise compute only
        `mean`.

    Returns:
      `fx_test_mean_t` or `(fx_test_mean_t, fx_test_cov_t)` if
      `compute_cov == True` with potentially additional leading time dimensions.
    """
        if get is None:
            get = ('nngp', 'ntk')

        # train-train, test-train, test-test.
        k_dd, k_td, nngp_tt = get_matrices(get, x_test, compute_cov)

        # Infinite time.
        if t is None:
            return predict_inf(get)(get=get,
                                    k_test_train=k_td,
                                    nngp_test_test=nngp_tt)

        # Finite time.
        t = np.array(t) * learning_rate
        t_shape = t.shape
        t = t.reshape((-1, 1))

        def reshape_mean(mean):
            k = _get_first(k_dd if k_td is None else k_td)
            mean = mean.reshape(t_shape + k.shape[::2] + trace_shape)
            mean = np.moveaxis(mean, last_t_axes, trace_axes)
            return mean

        def reshape_cov(cov):
            k = _get_first(k_dd if k_td is None else k_td)
            cov_shape_t = t_shape + k.shape[::2] * 2
            return utils.zip_axes(cov.reshape(cov_shape_t), len(t_shape))

        out = {}

        for g in get:
            evals, evecs = eigenspace(g)

            # Training set.
            if k_td is None:
                mean = tf.einsum('ji,ti,ki,k...->tj...',
                                 evecs,
                                 -expm1(evals, t),
                                 evecs,
                                 y_train_flat,
                                 optimize=True)

            # Test set.
            else:
                neg_inv_expm1 = -inv_expm1(evals, t)
                ktd_g = utils.make_2d(getattr(k_td, g))
                mean = tf.einsum('lj,ji,ti,ki,k...->tl...',
                                 ktd_g,
                                 evecs,
                                 neg_inv_expm1,
                                 evecs,
                                 y_train_flat,
                                 optimize=True)

            mean = reshape_mean(mean)

            if nngp_tt is not None:
                nngp_dd = utils.make_2d(k_dd.nngp)

                # Training set.
                if k_td is None:
                    if g == 'nngp':
                        cov = np.einsum('ji,ti,ki->tjk',
                                        evecs,
                                        (np.maximum(evals, 0.) *
                                         np.exp(-2 * np.maximum(evals, 0.) *
                                                t / y_train.size)),
                                        evecs,
                                        optimize=True)

                    elif g == 'ntk':
                        exp = np.einsum('mi,ti,ki->tmk',
                                        evecs,
                                        np.exp(-np.maximum(evals, 0.) * t /
                                               y_train.size),
                                        evecs,
                                        optimize=True)
                        cov = np.einsum('tmk,kl,tnl->tmn',
                                        exp,
                                        nngp_dd,
                                        exp,
                                        optimize=True)

                    else:
                        raise ValueError(g)

                # Test set.
                else:
                    _nngp_tt = utils.make_2d(nngp_tt)

                    if g == 'nngp':
                        cov = _nngp_tt - np.einsum('mj,ji,ti,ki,lk->tml',
                                                   ktd_g,
                                                   evecs,
                                                   -inv_expm1(evals, 2 * t),
                                                   evecs,
                                                   ktd_g,
                                                   optimize=True)

                    elif g == 'ntk':
                        term_1 = np.einsum('mi,ti,ki,lk->tml',
                                           evecs,
                                           neg_inv_expm1,
                                           evecs,
                                           ktd_g,
                                           optimize=True)
                        term_2 = np.einsum(
                            'mj,ji,ti,ki,lk->tml',
                            ktd_g,
                            evecs,
                            neg_inv_expm1,
                            evecs,
                            utils.make_2d(k_td.nngp),  # pytype:disable=attribute-error
                            optimize=True)
                        term_2 += np.moveaxis(term_2, 1, 2)
                        cov = np.einsum('tji,jk,tkl->til',
                                        term_1,
                                        nngp_dd,
                                        term_1,
                                        optimize=True)
                        cov += -term_2 + _nngp_tt

                    else:
                        raise ValueError(g)

                out[g] = Gaussian(mean, reshape_cov(cov))

            else:
                out[g] = mean

        return out
    def test_jit_or_pmap_broadcast(self):
        def kernel_fn(x1,
                      x2,
                      do_flip,
                      keys,
                      do_square,
                      params,
                      _unused=None,
                      p=0.65):
            res = np.abs(np.matmul(x1, x2))
            if do_square:
                res *= res
            if do_flip:
                res = -res

            res *= stateless_uniform(shape=[], seed=keys) * p
            return [res, params]

        params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
        x2 = np.arange(0, 10).reshape((10, ))
        keys = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=0)
        x1 = np.arange(0, 10).reshape((1, 10))
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=0):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=True,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=True)
                    self.assertAllClose(res_1, res_2)

        test_utils.stub_out_pmap(batch, 1)
        x1 = np.arange(0, 10).reshape((1, 10))
        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=1)
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=1):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=False,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None)
                    self.assertAllClose(res_1[0], res_2[0])
                    self.assertAllClose(
                        tree_map(partial(np.expand_dims, axis=0), res_1[1]),
                        res_2[1])

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=2)
        x1 = np.arange(0, 20).reshape((2, 10))
        test_utils.stub_out_pmap(batch, 2)

        def broadcast(arg):
            return np.broadcast_to(arg, (2, ) + arg.shape)

        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=2):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      p=0.2)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None,
                                              p=0.2)
                    self.assertAllClose(res_1[0][0], res_2[0][0])
                    self.assertAllClose(res_1[0][1], res_2[0][1])
                    self.assertAllClose(tree_map(broadcast, res_1[1]),
                                        res_2[1])
Example #16
    def testNTK_NTKNNGPAgreement(self, train_shape, test_shape, network,
                                 out_logits):
        _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                       train_shape)
        _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

        reg = 1e-7
        predictor = predict.gradient_descent_mse_ensemble(ker_fun,
                                                          x_train,
                                                          y_train,
                                                          diag_reg=reg)

        ts = np.logspace(-2, 8, 10).reshape((5, 2))

        for t in (None, 'ts'):
            for x in (None, 'x_test'):
                with self.subTest(t=t, x=x):
                    x = x if x is None else x_test
                    t = t if t is None else ts

                    ntk = predictor(t=t, get='ntk', x_test=x)

                    # Test time broadcasting
                    if t is not None:
                        ntk_ind = np.array([
                            predictor(t=t, get='ntk', x_test=x)
                            for t in t.ravel()
                        ]).reshape(t.shape + ntk.shape[2:])
                        self.assertAllClose(ntk_ind, ntk)

                    # Create a hacked kernel function that always returns the ntk kernel
                    def always_ntk(x1, x2, get=('nngp', 'ntk')):
                        out = ker_fun(x1, x2, get=('nngp', 'ntk'))
                        if get == 'nngp' or get == 'ntk':
                            return out.ntk
                        else:
                            return out._replace(nngp=out.ntk)

                    predictor_ntk = predict.gradient_descent_mse_ensemble(
                        always_ntk, x_train, y_train, diag_reg=reg)

                    ntk_nngp = predictor_ntk(t=t, get='nngp', x_test=x)

                    # Test that if you use the nngp equations with the ntk
                    # kernel, you get the same mean.
                    self.assertAllClose(ntk, ntk_nngp)

                    # Next test that if you go through the NTK code path, but with only
                    # the NNGP kernel, we recreate the NNGP dynamics.
                    # Create a hacked kernel function that always returns the nngp kernel
                    def always_nngp(x1, x2, get=('nngp', 'ntk')):
                        out = ker_fun(x1, x2, get=('nngp', 'ntk'))
                        if get == 'nngp' or get == 'ntk':
                            return out.nngp
                        else:
                            return out._replace(ntk=out.nngp)

                    predictor_nngp = predict.gradient_descent_mse_ensemble(
                        always_nngp, x_train, y_train, diag_reg=reg)

                    nngp_cov = predictor(t=t,
                                         get='nngp',
                                         x_test=x,
                                         compute_cov=True).covariance

                    # test time broadcasting for covariance
                    nngp_ntk_cov = predictor_nngp(t=t,
                                                  get='ntk',
                                                  x_test=x,
                                                  compute_cov=True).covariance
                    if t is not None:
                        nngp_ntk_cov_ind = np.array([
                            predictor_nngp(t=t,
                                           get='ntk',
                                           x_test=x,
                                           compute_cov=True).covariance
                            for t in t.ravel()
                        ]).reshape(t.shape + nngp_cov.shape[2:])
                        self.assertAllClose(nngp_ntk_cov_ind, nngp_ntk_cov)

                    # Test that if you use the ntk equations with the nngp kernel,
                    # you get the same covariance (only roughly, due to the
                    # accumulation of numerical errors).
                    self.assertAllClose(nngp_cov, nngp_ntk_cov)
Example #17
def conv_transpose(lhs, rhs, strides, padding,
                   rhs_dilation=None, dimension_numbers=None,
                   transpose_kernel=False, precision=None):
  """Convenience wrapper for calculating the N-d convolution "transpose".
  This function directly calculates a fractionally strided conv rather than
  indirectly calculating the gradient (transpose) of a forward convolution.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    strides: sequence of `n` integers, sets fractional stride.
    padding: 'SAME' or 'VALID' will be set as the transpose of the corresponding
      forward conv, or a sequence of `n` integer 2-tuples describing
      before-and-after padding for each of the `n` spatial dimensions.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    dimension_numbers: tuple of dimension descriptors as in
      lax.conv_general_dilated. Defaults to tensorflow convention.
    transpose_kernel: if True flips spatial axes and swaps the input/output
      channel axes of the kernel. This makes the output of this function identical
      to the gradient-derived functions like keras.layers.Conv2DTranspose
      applied to the same kernel. For typical use in neural nets this is completely
      pointless and just makes input/output channel specification confusing.
    precision: Optional. Either `None`, which means the default precision for
      the backend, or a `Precision` enum value.

  Returns:
    Transposed N-d convolution, with output padding following the conventions of
    keras.layers.Conv2DTranspose.
  """
  assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) > 2
  ndims = len(lhs.shape)
  one = (1,) * (ndims - 2)
  # Set dimensional layout defaults if not specified.
  if dimension_numbers is None:
    if ndims == 3:
      dimension_numbers = ('NHC', 'HIO', 'NHC')
    elif ndims == 4:
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    elif ndims == 5:
      dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')
    else:
      raise ValueError('No 4+ dimensional dimension_number defaults.')
  dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  k_shape = np.take(rhs.shape, dn.rhs_spec)
  k_sdims = k_shape[2:]
  # Calculate correct output shape given padding and strides.
  pads: Union[str, Sequence[Tuple[int, int]]]
  if padding in {'SAME', 'VALID'}:
    if rhs_dilation is None:
      rhs_dilation = (1,) * (rhs.ndim - 2)
    effective_k_size = map(lambda k, r: (k-1) * r + 1, k_sdims, rhs_dilation)
    pads = [_conv_transpose_padding(k, s, padding)
            for k,s in zip(effective_k_size, strides)]
  else:
    pads = padding
  if transpose_kernel:
    # flip spatial dims and swap input / output channel axes
    rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
    rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
  return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn)
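A short usage sketch (hypothetical shapes; assumes the NHWC/HWIO defaults above): with 'SAME' padding and stride 2, the output follows keras.layers.Conv2DTranspose conventions and doubles each spatial dimension.

lhs = np.ones((1, 8, 8, 3))    # NHWC input
rhs = np.ones((3, 3, 3, 16))   # HWIO kernel: 3x3, 3 input channels, 16 output channels
out = conv_transpose(lhs, rhs, strides=(2, 2), padding='SAME')
# out.shape == (1, 16, 16, 16)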