Example #1
def update(i, g, state):
  x, g_sq, m = state
  g_sq += np.square(g)
  g_sq_inv_sqrt = np.where(g_sq > 0, 1. / np.sqrt(g_sq), 0.0)
  m = (1. - momentum) * (g * g_sq_inv_sqrt) + momentum * m
  x = x - step_size(i) * m
  return x, g_sq, m
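
This reads like the inner `update` of a JAX-style `(init, update, get_params)`
optimizer triple, accumulating squared gradients (AdaGrad-style) with momentum.
A minimal sketch of the enclosing factory, assuming `step_size` may be a
schedule, with the `momentum_adagrad`/`init`/`get_params` names invented for
illustration:

import jax.numpy as np

def momentum_adagrad(step_size, momentum=0.9):
  # `step_size` may be a constant or a schedule mapping step -> learning rate.
  if not callable(step_size):
    step_size = lambda i, s=step_size: s

  def init(x0):
    # state = (params, accumulated squared grads, momentum buffer)
    return x0, np.zeros_like(x0), np.zeros_like(x0)

  def update(i, g, state):
    x, g_sq, m = state
    g_sq += np.square(g)
    g_sq_inv_sqrt = np.where(g_sq > 0, 1. / np.sqrt(g_sq), 0.0)
    m = (1. - momentum) * (g * g_sq_inv_sqrt) + momentum * m
    x = x - step_size(i) * m
    return x, g_sq, m

  def get_params(state):
    return state[0]

  return init, update, get_params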
Example #2
def update(i, g, state):
  x, m, vs = state
  vs = [broadcast_into(g.ndim, v, axis) for axis, v in enumerate(vs)]
  accum = functools.reduce(np.minimum, vs) + np.square(g)
  accum_inv_sqrt = np.where(accum > 0, 1. / np.sqrt(accum), 0)
  m = (1. - momentum) * (g * accum_inv_sqrt) + momentum * m
  x = x - step_size(i) * m
  vs = [accum.max(splice(range(x.ndim), j, [])) for j in range(x.ndim)]
  return x, m, vs
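
The snippet relies on two helpers from its enclosing scope. A plausible sketch
of them, matching how they are used above (not necessarily the source's exact
definitions): `splice` rewrites one position of a sequence, and
`broadcast_into` reshapes a per-axis accumulator so it broadcasts along one
axis of the gradient.

def splice(seq, i, x):
  # Return `seq` as a list with position `i` replaced by the items of `x`;
  # an empty `x` deletes that position (as in the `accum.max(...)` call).
  lst = list(seq)
  lst[i:i + 1] = x
  return lst

def broadcast_into(ndim, x, axis):
  # Index a 1-D accumulator `x` with `None`s everywhere except `axis`, so it
  # broadcasts against an `ndim`-dimensional gradient.
  idx = splice([None] * ndim, axis, [slice(None)])
  return x[tuple(idx)]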
Example #3
def kernel_fn_sample_once(x1: np.ndarray, x2: Optional[np.ndarray],
                          key: PRNGKey, get: Get, **apply_fn_kwargs):
  init_key, dropout_key1, dropout_key2 = tf_split(key, 3)
  keys = np.where(utils.x1_is_x2(x1, x2), dropout_key1,
                  np.stack([dropout_key1, dropout_key2]))
  _, params = init_fn(init_key, x1.shape)
  return kernel_fn(x1, x2, get, params, keys=keys, **apply_fn_kwargs)
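
A hypothetical call, assuming `init_fn` and `kernel_fn` close over the network
being sampled, `key` is an rng key compatible with `tf_split`, and
`x_train`/`x_test` are data arrays:

nngp_sample = kernel_fn_sample_once(x_train, x_test, key, 'nngp')
# With `x2=None` (or `x1 is x2`), both dropout keys coincide, so the two
# network applications share the same dropout mask.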
Example #4
def get_masked_array(x: ArrayOrList,
                     mask_constant: Optional[float] = None) -> MaskedArray:
  """Return `x` with entries equal to `mask_constant` zeroed-out, and the mask.

  The mask returned is a boolean `np.ndarray` with masked indices having `True`.

  Args:
    x: `np.ndarray` to mask. If `x` is a `MaskedArray`, treat it as
      `(masked_x, mask)` and pass it through.
    mask_constant: an optional `float`, the value in inputs to be considered as
      masked (e.g. padding in a batch of sentences). `None` means no masking.
      Can also be `np.nan`, `np.inf` etc.

  Returns:
    A `MaskedArray` of `(masked_x, boolean_mask)`.
  """
  if isinstance(x, list):
    x_array = []
    mask_array = []
    for x_ in x:
      masked_array = get_masked_array(x_, mask_constant)
      x_array.append(masked_array.masked_value)
      mask_array.append(masked_array.mask)
    return MaskedArray(x_array, mask_array)

  if x is None:
    mask = None

  elif isinstance(x, MaskedArray):
    masked_value = x.masked_value
    mask = x.mask
    x = masked_value

  elif isinstance(x, (np.ndarray, onp.ndarray)):
    x = np.asarray(x)
    if mask_constant is None:
      mask = None
    else:
      choice_a = lambda: np.array(tf.math.is_nan(x))
      choice_b = lambda: x == mask_constant
      mask = tf.cond(tf.math.is_nan(mask_constant), choice_a, choice_b)
  else:
    raise TypeError(x, type(x))

  if mask is not None:
    x = np.where(mask, np.zeros((), x.dtype), x)

  return MaskedArray(x, mask)  # pytype: disable=wrong-arg-count
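
For instance, treating zeros as padding in a batch of "sentences" (a sketch):

batch = np.array([[1., 2., 0., 0., 0.],
                  [3., 4., 5., 0., 0.]])
masked = get_masked_array(batch, mask_constant=0.)
# masked.mask is True exactly where batch == 0., and masked.masked_value has
# those entries zeroed out.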
Example #5
def _read_keys(key, x1, x2):
  """Read dropout key.

     `key` might be a tuple of two rng keys or a single rng key or None. In
     either case, `key` will be mapped into two rng keys `key1` and `key2` to
     make sure `(x1==x2) == (key1==key2)`.
  """

  if key is None or x2 is None:
    key1 = key2 = key
  elif isinstance(key, tuple) and len(key) == 2:
    key1, key2 = key
    new_key = np.where(utils.x1_is_x2(key1, key2),
                       random.fold_in(key2, 1), key2)
    key2 = np.where(utils.x1_is_x2(x1, x2), key1, new_key)
    warnings.warn('The value of `key[1]` might be replaced by a new value if '
                  '`key[0] == key[1]` and `x1 != x2`, or if '
                  '`key[0] != key[1]` and `x1 == x2`.')
  elif isinstance(key, np.ndarray):
    key1 = key
    key2 = np.where(utils.x1_is_x2(x1, x2), key1, random.fold_in(key, 1))
  else:
    raise TypeError(type(key))
  return key1, key2
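
The three accepted forms of `key`, sketched (assuming a JAX-style PRNG API):

k = random.PRNGKey(0)
_read_keys(None, x1, x2)      # -> (None, None): dropout keys disabled
_read_keys(k, x1, x2)         # -> (k, k) iff x1 == x2, else (k, fold_in(k, 1))
_read_keys((k1, k2), x1, x2)  # explicit pair; key2 may be fixed up to keep
                              # `(x1 == x2) == (key1 == key2)`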
Example #6
 def apply_fun(params, inputs, **kwargs):
     rng = kwargs.get('rng', None)
     if rng is None:
         msg = (
             "Dropout layer requires apply_fun to be called with a PRNG key "
             "argument. That is, instead of `apply_fun(params, inputs)`, call "
             "it like `apply_fun(params, inputs, rng)` where `rng` is a "
             "jax.random.PRNGKey value.")
         raise ValueError(msg)
     if mode == 'train':
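          # Here `rate` acts as the keep-probability: each unit is kept with
          # probability `rate` and the kept units are rescaled by `1 / rate`
          # so the activations keep their expected value (inverted dropout).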
         prob = tf.ones(inputs.shape) * rate
         keep = stateless_uniform(
             shape=inputs.shape, seed=rng, minval=0, maxval=1) < prob
         return tfnp.where(keep, inputs / rate, 0)
     else:
         return inputs
Example #7
def f(x):
    # Note that the shape of the input to `len` is data dependent.
    return len(np.where(x)[0])
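
Eagerly this works, but under tracing the output shape depends on the values
in `x`, so (assuming `np` is `jax.numpy`) it cannot be compiled:

f(np.array([0., 1., 2.]))   # == 2 (number of nonzero entries)
# jax.jit(f)(np.array([0., 1., 2.]))  # fails: one-argument `np.where` needs
#                                     # concrete values to determine its shape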
Example #8
def mask(x: Optional[np.ndarray], mask_mat: Optional[np.ndarray]):
  if x is None or mask_mat is None:
    return x
  return np.where(mask_mat, np.zeros((), x.dtype), x)
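
For example (a sketch):

x = np.arange(4.)
mask_mat = np.array([True, False, True, False])
mask(x, mask_mat)   # -> [0., 1., 0., 3.]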
Example #9
def clip_grads(grad_tree, max_norm):
  """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
  norm = l2_norm(grad_tree)
  normalize = lambda g: np.where(norm < max_norm, g, g * (max_norm / norm))
  return tree_map(normalize, grad_tree)
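
A usage sketch: the same scale factor is applied to every leaf, so clipping
preserves the gradient's direction.

grads = {'w': np.ones((2, 2)), 'b': np.ones((2,))}
clipped = clip_grads(grads, max_norm=1.0)   # global L2 norm of `clipped` is 1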
Example #10
    def predict_fn(
        t: Optional[ArrayOrScalar] = None,
        fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
        fx_test_0: Optional[ArrayOrScalar] = None,
        k_test_train: Optional[np.ndarray] = None
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]:
        """Return output predictions on train [and test] set[s] at time[s] `t`.

        Args:
          t:
            a scalar or array of scalars of any shape, in strictly increasing
            order. `t=None` is equivalent to `t=np.inf` and may not converge.
            Equivalent of training steps (but can be fractional).
          fx_train_or_state_0:
            either (a) output of the network at `t == 0` on the training set or
            (b) complete ODE state (`predict.ODEState`). Pass an ODE state if
            you want to operate on the full ODE state instead of output
            variables only (useful for inspecting auxiliary variables or
            resuming an optimizer with auxiliary variables from a specific
            state). Note that only the `momentum != None` optimizer currently
            has auxiliary variables. To initialize an ODE state from scratch,
            call `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is
            passed, an ODE state is returned. `fx_train_0=None` means to not
            compute predictions on the training set.
          fx_test_0:
            output of the network at `t == 0` on the test set. `fx_test_0=None`
            means to not compute predictions on the test set.
          k_test_train:
            kernel relating test data with training data. Must have the shape
            of `zip(y_test.shape, y_train.shape)` with `trace_axes` absent.
            Pass `k_test_train=None` if you only need predictions on the
            training set.

        Returns:
          `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 is not None`,
          with potentially additional leading time dimensions matching
          `t.shape`. Alternatively can return an `ODEState` at time[s] `t`.

        Raises:
          ValueError: if `fx_test_0` is not `None`, but `k_test_train` is
            `None`.
        """
        _check_inputs(fx_train_or_state_0, fx_test_0, k_test_train)

        t = np.array(t if t is not None else np.inf, dtype) * learning_rate
        t_shape = t.shape
        t = t.reshape((-1, ))

        # The ODE solver requires `t[0]` to be the time where `fx_train_0` [and
        # `fx_test_0`] are evaluated, but also a strictly increasing sequence of
        # timesteps, so we always temporarily prepend an [almost] `0`.
        t0 = np.where(t[0] == 0, np.full((1,), -1e-24, t.dtype),
                      np.zeros((1,), t.dtype))
        t = np.concatenate([t0, t])

        # Solve the ODE.
        fx_test_shape = _get_fx_test_shape(y_train, k_test_train, trace_axes)
        state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
        state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)

        # Remove the added `t0`.
        trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
        trim_tree = lambda tree: tree_map(trim, tree)
        state_t = trim_tree(state_t)

        # `ODEState` -> `ODEState`
        if isinstance(fx_train_or_state_0, ODEState):
            return state_t

        # `np.ndarray` -> `np.ndarray`
        fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

        if fx_train_or_state_0 is not None and fx_test_0 is None:
            return fx_train_t
        if fx_test_0 is not None and fx_train_or_state_0 is None:
            return fx_test_t
        return fx_train_t, fx_test_t
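
A hypothetical call, with the data and kernel names invented for illustration:

fx_train_t, fx_test_t = predict_fn(
    t=np.array([1., 10., 100.]),          # three training horizons
    fx_train_or_state_0=fx_train_0,
    fx_test_0=fx_test_0,
    k_test_train=k_test_train)
# Both outputs gain a leading dimension of size 3, matching `t.shape`.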