Example #1
    def testZeroTimeAgreement(self, train_shape, test_shape, network,
                              out_logits):
        """Test that the NTK and NNGP agree at t=0."""
        _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                       train_shape)
        _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

        reg = 1e-7
        predictor = predict.gradient_descent_mse_ensemble(ker_fun,
                                                          x_train,
                                                          y_train,
                                                          diag_reg=reg)

        for x in (None, 'x_test'):
            with self.subTest(x=x):
                x = x if x is None else x_test
                zero = predictor(t=0.0,
                                 x_test=x,
                                 get=('NTK', 'NNGP'),
                                 compute_cov=True)
                if x is None:
                    k = ker_fun(x_train, None, get='nngp')
                    ref = (np.zeros_like(y_train, k.dtype), k)
                else:
                    ref = (np.zeros((test_shape[0], out_logits)),
                           ker_fun(x_test, None, get='nngp'))

                self.assertAllClose((ref,) * 2, zero, check_dtypes=False)
                if x is None:
                    zero_x = predictor(t=0.0,
                                       x_test=x_train,
                                       get=('NTK', 'NNGP'),
                                       compute_cov=True)
                    self.assertAllClose((ref,) * 2, zero_x)
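For context, a minimal end-to-end sketch of the property this test checks, assuming the public `neural_tangents` API (`stax` and `predict.gradient_descent_mse_ensemble`); all data arrays are hypothetical:

import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

# Hypothetical data.
x_train, y_train = jnp.ones((4, 8)), jnp.ones((4, 1))
x_test = jnp.ones((2, 8))

# Infinite-width network and its analytic kernel.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-7)

# At t=0 the ensemble mean is zero and the covariance is the prior NNGP
# kernel, which is exactly what the test above asserts.
mean, cov = predict_fn(t=0.0, x_test=x_test, get='nngp', compute_cov=True)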
Example #2
    def get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape):
        if isinstance(fx_train_or_state_0, ODEState):
            fx_train_0 = fx_train_or_state_0.fx_train
            fx_test_0 = fx_train_or_state_0.fx_test
            qx_train_0 = fx_train_or_state_0.qx_train
            qx_test_0 = fx_train_or_state_0.qx_test
        else:
            fx_train_0 = fx_train_or_state_0
            qx_train_0 = qx_test_0 = None

        if fx_train_0 is None:
            fx_train_0 = np.zeros_like(y_train, dtype)
        else:
            fx_train_0 = np.broadcast_to(fx_train_0, y_train.shape)

        if fx_test_0 is not None:
            fx_test_0 = np.broadcast_to(fx_test_0, fx_test_shape)

        if momentum is None:
            if qx_train_0 is not None or qx_test_0 is not None:
                raise ValueError('Got passed momentum state variables, while '
                                 '`momentum is None`.')
        else:
            qx_train_0 = (np.zeros_like(y_train, dtype) if qx_train_0 is None
                          else np.broadcast_to(qx_train_0, y_train.shape))
            qx_test_0 = (None if fx_test_0 is None else
                         (np.zeros(fx_test_shape, dtype) if qx_test_0 is None
                          else np.broadcast_to(qx_test_0, fx_test_shape)))

        return ODEState(fx_train_0, fx_test_0, qx_train_0, qx_test_0)  # pytype: disable=wrong-arg-count
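For reference, a minimal sketch of constructing such a state by hand, assuming `ODEState` is `neural_tangents.predict.ODEState` (shapes hypothetical):

import jax.numpy as jnp
from neural_tangents import predict

y_train = jnp.zeros((4, 1))  # hypothetical training targets
fx_test_shape = (2, 1)       # hypothetical test output shape

# Outputs-only state; with `momentum is None` the q-variables stay unset.
state_0 = predict.ODEState(jnp.zeros_like(y_train), jnp.zeros(fx_test_shape))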
Example #3
 def init_fun(rng, input_shape):
   # Move the batch and channel dimensions of the input shape so that the
   # result is in "NHWC" data format.
   shape = [input_shape[batch_dim]]
   for i in range(len(input_shape)):
     if i not in [batch_dim, channel_dim]:
       shape.append(input_shape[i])
   shape.append(input_shape[channel_dim])
   out_shape = reduce_window_shape_tuple(shape, window_shape,
                                         strides, padding)
   return tfnp.zeros(out_shape), ()
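A self-contained check of the dimension shuffle above, using hypothetical NCHW dimension numbers:

# Hypothetical "NCHW" input: batch_dim=0, channel_dim=1.
input_shape = (8, 3, 32, 32)
batch_dim, channel_dim = 0, 1

shape = [input_shape[batch_dim]]
for i in range(len(input_shape)):
    if i not in (batch_dim, channel_dim):
        shape.append(input_shape[i])
shape.append(input_shape[channel_dim])

assert tuple(shape) == (8, 32, 32, 3)  # "NHWC"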
Example #4
 def init_fun(rng, input_shape):
   output_shape = input_shape[:-1] + (out_dim,)
   keys = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=2)
   k1 = keys[0]
   k2 = keys[1]
   # Collapse each shape-(2,) key into a scalar seed: with an integer dtype
   # and `minval=maxval=None`, `stateless_uniform` draws over the full int32
   # range.
   k1 = stateless_uniform(shape=[], seed=k1, minval=None, maxval=None, dtype=tf.int32)
   k2 = stateless_uniform(shape=[], seed=k2, minval=None, maxval=None, dtype=tf.int32)
   W = W_init(seed=k1, shape=(input_shape[-1], out_dim))
   b = b_init(seed=k2, shape=(out_dim,))
   return tfnp.zeros(output_shape), (W.numpy(), b.numpy())
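The key handling above can be reproduced with stock TensorFlow; a minimal sketch, assuming `split` and `stateless_uniform` wrap `tf.random.experimental.stateless_split` and `tf.random.stateless_uniform` (seed values hypothetical):

import tensorflow as tf

rng = tf.constant([0, 42], dtype=tf.int32)  # hypothetical seed pair
keys = tf.random.experimental.stateless_split(rng, num=2)
k1, k2 = keys[0], keys[1]  # each of shape (2,)

# With an integer dtype and no bounds, stateless_uniform samples over the
# full int32 range, yielding one scalar seed per initializer.
s1 = tf.random.stateless_uniform(shape=[], seed=k1, minval=None,
                                 maxval=None, dtype=tf.int32)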
Example #5
def get_masked_array(x: ArrayOrList,
                     mask_constant: float = None) -> MaskedArray:
  """Return `x` with entries equal to `mask_constant` zeroed-out, and the mask.

  The mask returned is a boolean `np.ndarray` with masked indices having `True`.

  Args:
    x: `np.ndarray` to mask. If `x` is a `MaskedArray`, treat it as
      `(masked_x, mask)` and pass it through.
    mask_constant: an optional `float`, the value in inputs to be considered as
      masked (e.g. padding in a batch of sentences). `None` means no masking.
      Can also be `np.nan`, `np.inf` etc.

  Returns:
    A `MaskedArray` of `(masked_x, boolean_mask)`.
  """
  if isinstance(x, list):
    x_array = []
    mask_array = []
    for x_ in x:
      masked_array = get_masked_array(x_, mask_constant)
      x_array.append(masked_array.masked_value)
      mask_array.append(masked_array.mask)
    return MaskedArray(x_array, mask_array)

  if x is None:
    mask = None

  elif isinstance(x, MaskedArray):
    masked_value = x.masked_value
    mask = x.mask
    x = masked_value

  elif isinstance(x, (np.ndarray, onp.ndarray)):
    x = np.asarray(x)
    if mask_constant is None:
      mask = None
    else:
      choice_a = lambda: np.array(tf.math.is_nan(x))
      choice_b = lambda: x == mask_constant
      mask = tf.cond(tf.math.is_nan(mask_constant), choice_a, choice_b)
  else:
    raise TypeError(x, type(x))

  if mask is not None:
    x = np.where(mask, np.zeros((), x.dtype), x)

  return MaskedArray(x, mask)  # pytype: disable=wrong-arg-count
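A minimal usage sketch (hypothetical inputs), matching the docstring's padded-sentences example; `onp` is standard NumPy as in the surrounding module:

import numpy as onp

# A batch of "sentences" padded with -1.
sentences = onp.asarray([[3., 7., -1., -1.],
                         [5., 2., 9., -1.]])
masked = get_masked_array(sentences, mask_constant=-1.)
# masked.masked_value: the padding entries are zeroed out.
# masked.mask: True exactly at the padded positions.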
Example #6
    def testSize(self):
        def run_test(arr):
            onp_arr = arr.numpy() if isinstance(arr, tf.Tensor) else arr
            self.assertEqual(np_size(arr), onp.size(onp_arr))

        run_test(np.array([1]))
        run_test(np.array([1, 2, 3, 4, 5]))
        run_test(np.ones((2, 3, 2)))
        run_test(np.ones((3, 2)))
        run_test(np.zeros((5, 6, 7)))
        run_test(1)
        run_test(onp.ones((3, 2, 1)))
        run_test(tf.constant(5))
        run_test(tf.constant([1, 1, 1]))
Example #7
 def init_fun(rng, input_shape):
   input_shape = shape_conversion(input_shape)
   filter_shape_iter = iter(filter_shape)
   # Walk `rhs_spec`: 'O' -> output channels, 'I' -> input channels, and
   # spatial characters consume successive entries of `filter_shape`.
   kernel_shape = [out_chan if c == 'O' else
                   input_shape[lhs_spec.index('C')] if c == 'I' else
                   next(filter_shape_iter) for c in rhs_spec]
   output_shape = conv_general_shape_tuple(
       input_shape, kernel_shape, strides, padding, dimension_numbers)
   # Bias keeps `out_chan` in the channel position, broadcasts over the
   # remaining dimensions, and drops leading singleton axes.
   bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
   bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
   keys = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=2)
   k1 = keys[0]
   k2 = keys[1]
   W = W_init(seed=k1, shape=kernel_shape)
   b = b_init(stddev=1e-6, seed=k2, shape=bias_shape)
   return tfnp.zeros(output_shape), (W, b)
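A quick walk-through of the `kernel_shape` comprehension above for hypothetical dimension numbers ('NHWC', 'HWIO', 'NHWC'):

out_chan, in_chan = 16, 3
filter_shape, rhs_spec = (3, 3), 'HWIO'

filter_shape_iter = iter(filter_shape)
kernel_shape = [out_chan if c == 'O' else
                in_chan if c == 'I' else
                next(filter_shape_iter) for c in rhs_spec]
assert kernel_shape == [3, 3, 3, 16]  # H, W, I, O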
Example #8
        def run_test(*args):
            num_samples = 1000
            # High tolerance keeps the number of samples low; otherwise the
            # test takes a long time to run.
            tol = 0.1
            np_random.seed(10)
            outputs = [np_random.randn(*args) for _ in range(num_samples)]

            # Test output shape.
            for output in outputs:
                self.assertEqual(output.shape, tuple(args))
                default_dtype = (np.float64 if np_dtypes.is_allow_float64()
                                 else np.float32)
                self.assertEqual(output.dtype.as_numpy_dtype, default_dtype)

            if np.prod(args):  # Don't bother with empty arrays.
                outputs = [output.tolist() for output in outputs]

                # Test that the properties of normal distribution are satisfied.
                mean = np.mean(outputs, axis=0)
                stddev = np.std(outputs, axis=0)
                self.assertAllClose(mean, np.zeros(args), atol=tol)
                self.assertAllClose(stddev, np.ones(args), atol=tol)

                # Test that outputs are different with different seeds.
                np_random.seed(20)
                diff_seed_outputs = [
                    np_random.randn(*args).tolist() for _ in range(num_samples)
                ]
                self.assertNotAllClose(outputs, diff_seed_outputs)

                # Test that outputs are the same with the same seed.
                np_random.seed(10)
                same_seed_outputs = [
                    np_random.randn(*args).tolist() for _ in range(num_samples)
                ]
                self.assertAllClose(outputs, same_seed_outputs)
Example #9
 def test_broadcast(self, low_shape, high_shape, size):
     low = np.zeros(low_shape).astype(np.float64)
     high = np.ones(high_shape).astype(np.float64)
     self._test(low=low, high=high, size=size)
Example #10
def mask(x: Optional[np.ndarray], mask_mat: Optional[np.ndarray]):
  if x is None or mask_mat is None:
    return x
  return np.where(mask_mat, np.zeros((), x.dtype), x)
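A minimal sketch of the masking semantics with hypothetical arrays, assuming `np` is `jax.numpy` as in the surrounding module:

import jax.numpy as np

x = np.arange(6.).reshape(2, 3)
mask_mat = np.array([[True, False, False],
                     [False, True, False]])

# Entries where `mask_mat` is True are replaced by 0 of `x`'s dtype.
mask(x, mask_mat)  # -> [[0., 1., 2.], [3., 0., 5.]]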
Example #11
 def init(x0):
   vs = [np.zeros(sz, dtype=x0.dtype) for sz in x0.shape]
   return x0, np.zeros_like(x0), vs
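A sketch of the state this produces for a hypothetical 2x3 parameter, assuming `np` is `jax.numpy`; note there is one accumulator vector per axis of `x0`:

import jax.numpy as np

x0 = np.zeros((2, 3))
x, m, vs = init(x0)

assert m.shape == (2, 3)                            # full-size buffer
assert vs[0].shape == (2,) and vs[1].shape == (3,)  # per-axis accumulators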
Example #12
 def init_fun(rng, input_shape):
   output_shape = input_shape[0], functools.reduce(op.mul, input_shape[1:], 1)
   return tfnp.zeros(output_shape), ()
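A quick check of the flatten arithmetic above with a hypothetical shape:

import functools
import operator as op

input_shape = (-1, 4, 5, 3)
output_shape = input_shape[0], functools.reduce(op.mul, input_shape[1:], 1)
assert output_shape == (-1, 60)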
Example #13
    def predict_fn(
        t: ArrayOrScalar = None,
        fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
        fx_test_0: ArrayOrScalar = None,
        k_test_train: np.ndarray = None
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]:
        """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape in strictly increasing order.
        `t=None` is equivalent to `t=np.inf` and may not converge. Equivalent of
        training steps (but can be fractional).
      fx_train_or_state_0:
        either (a) output of the network at `t == 0` on the training set or (b)
        complete ODE state (`predict.ODEState`). Pass an ODE state if you want
        to operate on the full ODE state instead of output variables only
        (useful for inspecting auxiliary variables or resuming an optimizer with
        auxiliary variables from a specific state. Note that only
        `momentum != None` optimizer currently has auxiliary variables. To
        initialize an ODE state from scratch, call
        `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed, an
        ODE state is returned. `fx_train_0=None` means to not compute
        predictions on the training set.
      fx_test_0:
        output of the network at `t == 0` on the test set. `fx_test_0=None`
        means to not compute predictions on the test set.
      k_test_train:
        kernel relating test data with training data. Must have the shape of
        `zip(y_test.shape, y_train.shape)` with `trace_axes` absent. Pass
        `k_test_train=None` if you only need predictions on the training set.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`.
      Alternatively can return an `ODEState` at time[s] `t`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
    """
        _check_inputs(fx_train_or_state_0, fx_test_0, k_test_train)

        t = np.array(t if t is not None else np.inf, dtype) * learning_rate
        t_shape = t.shape
        t = t.reshape((-1,))

        # The ODE solver requires `t[0]` to be the time at which `fx_train_0`
        # [and `fx_test_0`] are evaluated, but also a strictly increasing
        # sequence of timesteps, so we always temporarily prepend an [almost]
        # `0` at the start.
        t0 = np.where(t[0] == 0, np.full((1,), -1e-24, t.dtype),
                      np.zeros((1,), t.dtype))
        t = np.concatenate([t0, t])

        # Solve the ODE.
        fx_test_shape = _get_fx_test_shape(y_train, k_test_train, trace_axes)
        state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
        state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)

        # Remove the added `t0`.
        trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
        trim_tree = lambda tree: tree_map(trim, tree)
        state_t = trim_tree(state_t)

        # `ODEState` -> `ODEState`
        if isinstance(fx_train_or_state_0, ODEState):
            return state_t

        # `np.ndarray` -> `np.ndarray`
        fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

        if fx_train_or_state_0 is not None and fx_test_0 is None:
            return fx_train_t
        if fx_test_0 is not None and fx_train_or_state_0 is None:
            return fx_test_t
        return fx_train_t, fx_test_t
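A minimal usage sketch under the docstring's contract (all arrays and kernels hypothetical):

# Predictions on both sets at several times, starting from zero outputs.
ts = np.array([1., 10., 100.])
fx_train_t, fx_test_t = predict_fn(t=ts,
                                   fx_train_or_state_0=0.,
                                   fx_test_0=0.,
                                   k_test_train=k_test_train)
# The leading time axis matches `ts`:
# fx_train_t.shape == ts.shape + y_train.shape.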
Example #14
 def init_fun(rng, input_shape):
   return tfnp.zeros(input_shape), ()
Example #15
 def init_fun(rng, input_shape):
   ax = axis % len(input_shape[0])
   concat_size = sum(shape[ax] for shape in input_shape)
   out_shape = input_shape[0][:ax] + (concat_size,) + input_shape[0][ax+1:]
   return tfnp.zeros(out_shape), ()
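A quick check of the concatenation arithmetic above, with hypothetical shapes and `axis=-1`:

input_shape = [(4, 3), (4, 5)]
axis = -1

ax = axis % len(input_shape[0])                        # -> 1
concat_size = sum(shape[ax] for shape in input_shape)  # -> 8
out_shape = input_shape[0][:ax] + (concat_size,) + input_shape[0][ax + 1:]
assert out_shape == (4, 8)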
Example #16
def FanInSum():
  """Layer construction function for a fan-in sum layer."""
  init_fun = lambda rng, input_shape: (tfnp.zeros(input_shape[0]), ())
  apply_fun = lambda params, inputs, **kwargs: sum(inputs)
  return init_fun, apply_fun
Example #17
 def init_fun(rng, input_shape):
   return ([tfnp.zeros(input_shape)] * num, ())
Example #18
def Identity():
  """Layer construction function for an identity layer."""
  init_fun = lambda rng, input_shape: (tfnp.zeros(input_shape), ())
  apply_fun = lambda params, inputs, **kwargs: inputs
  return init_fun, apply_fun