def update(i, g, state):
    """Advance the optimizer state by one step using gradient `g`.

    The gradient is rescaled element-wise by the inverse square root of the
    running sum of squared gradients, blended into a momentum buffer, and the
    buffer is then used to move the parameters.

    Args:
      i: step index; forwarded to `step_size` to obtain the current step size.
      g: gradient for the current step.
      state: tuple `(x, g_sq, m)` holding the parameters, the running sum of
        squared gradients and the momentum buffer.

    Returns:
      The new `(x, g_sq, m)` tuple.
    """
    params, sq_grad_sum, velocity = state

    # Fold the current squared gradient into the running sum.
    sq_grad_sum += np.square(g)

    # Entries that have never seen a non-zero gradient get a scale of 0
    # instead of the infinity that `1 / sqrt(0)` would produce.
    inv_root = np.where(sq_grad_sum > 0, 1. / np.sqrt(sq_grad_sum), 0.0)

    fresh_term = (1. - momentum) * (g * inv_root)
    carried_term = momentum * velocity
    velocity = fresh_term + carried_term

    params = params - step_size(i) * velocity
    return params, sq_grad_sum, velocity
def update(i, g, state):
    """Advance the optimizer state by one step using per-axis accumulators.

    Args:
      i: step index; forwarded to `step_size` to obtain the current step size.
      g: gradient for the current step.
      state: tuple `(x, m, vs)` holding the parameters, the momentum buffer
        and a sequence of accumulators (one per entry of `vs`).

    Returns:
      The new `(x, m, vs)` tuple, where `vs` has one entry per axis of `x`.
    """
    x, m, vs = state

    # NOTE(review): `broadcast_into` presumably reshapes each accumulator so it
    # broadcasts against `g` along its own axis -- confirm against the helper.
    expanded = [
        broadcast_into(g.ndim, v, axis) for axis, v in enumerate(vs)
    ]

    # Element-wise minimum over all accumulators, plus the new squared gradient.
    accum = functools.reduce(np.minimum, expanded) + np.square(g)

    # Zero scale where nothing has been accumulated, avoiding `1 / sqrt(0)`.
    inv_root = np.where(accum > 0, 1. / np.sqrt(accum), 0)

    m = (1. - momentum) * (g * inv_root) + momentum * m
    x = x - step_size(i) * m

    # NOTE(review): `splice(range(ndim), j, [])` presumably drops axis `j`, so
    # each new accumulator is the max of `accum` over every other axis.
    all_axes = range(x.ndim)
    vs = [accum.max(splice(range(x.ndim), j, [])) for j in all_axes]
    return x, m, vs
def kernel_fn_sample_once(x1: np.ndarray,
                          x2: Optional[np.ndarray],
                          key: PRNGKey,
                          get: Get,
                          **apply_fn_kwargs):
    """Initialise parameters from `key` and evaluate `kernel_fn` once.

    `key` is split three ways: one part initialises the parameters via
    `init_fn`, the other two serve as dropout keys for the two inputs.

    Args:
      x1: first batch of inputs; its shape is passed to `init_fn`.
      x2: optional second batch of inputs.
      key: random key to split.
      get: forwarded unchanged to `kernel_fn`.
      **apply_fn_kwargs: extra keyword arguments forwarded to `kernel_fn`.

    Returns:
      Whatever `kernel_fn` returns for the sampled parameters.
    """
    subkeys = tf_split(key, 3)
    param_key = subkeys[0]
    first_dropout_key = subkeys[1]
    second_dropout_key = subkeys[2]

    # When `utils.x1_is_x2` reports the inputs as identical, the first dropout
    # key is selected; otherwise the stacked pair of distinct keys is used.
    inputs_coincide = utils.x1_is_x2(x1, x2)
    key_pair = np.stack([first_dropout_key, second_dropout_key])
    keys = np.where(inputs_coincide, first_dropout_key, key_pair)

    _, params = init_fn(param_key, x1.shape)
    return kernel_fn(x1, x2, get, params, keys=keys, **apply_fn_kwargs)
def get_masked_array(x: ArrayOrList,
                     mask_constant: float = None) -> MaskedArray:
    """Return `x` with entries equal to `mask_constant` zeroed-out, and the mask.

    The mask returned is a boolean `np.ndarray` with masked indices having
    `True`.

    Args:
      x: `np.ndarray` to mask. If `x` is a `MaskedArray`, treat it as
        `(masked_x, mask)` and pass it through. If `x` is a `list`, every
        element is processed recursively. If `x` is `None`, it is passed
        through with no mask.
      mask_constant: an optional `float`, the value in inputs to be considered
        as masked (e.g. padding in a batch of sentences). `None` means no
        masking. Can also be `np.nan`, `np.inf` etc.

    Returns:
      A `MaskedArray` of `(masked_x, boolean_mask)`. For a `list` input both
      fields are lists with one entry per element of `x`.

    Raises:
      TypeError: if `x` is not a `list`, `None`, a `MaskedArray` or an array.
    """
    if isinstance(x, list):
        entries = [get_masked_array(x_, mask_constant) for x_ in x]
        return MaskedArray([entry.masked_value for entry in entries],
                           [entry.mask for entry in entries])

    # This must be a single `if/elif` chain: with a separate `if` for the
    # `MaskedArray` case, `x=None` would fall through to the final `else` and
    # raise `TypeError` instead of being passed through.
    if x is None:
        mask = None

    elif isinstance(x, MaskedArray):
        mask = x.mask
        x = x.masked_value

    elif isinstance(x, (np.ndarray, onp.ndarray)):
        x = np.asarray(x)
        if mask_constant is None:
            mask = None
        else:
            # `x == nan` is never true, so a NaN `mask_constant` needs an
            # explicit `is_nan` test rather than an equality comparison.
            choice_a = lambda: np.array(tf.math.is_nan(x))
            choice_b = lambda: x == mask_constant
            mask = tf.cond(tf.math.is_nan(mask_constant), choice_a, choice_b)

    else:
        raise TypeError(x, type(x))

    if mask is not None:
        # Replace masked entries with a zero of the same dtype as `x`.
        x = np.where(mask, np.zeros((), x.dtype), x)

    return MaskedArray(x, mask)  # pytype: disable=wrong-arg-count
def _read_keys(key, x1, x2):
    """Turn `key` into a pair of dropout keys `(key1, key2)`.

    `key` may be `None`, a single rng key, or a tuple of two rng keys. The
    result is chosen so that `(x1==x2) == (key1==key2)`.

    Args:
      key: `None`, an `np.ndarray` rng key, or a 2-tuple of rng keys.
      x1: first input.
      x2: second input, or `None`.

    Returns:
      A tuple `(key1, key2)`.

    Raises:
      TypeError: if `key` is none of the supported kinds.
    """
    # Nothing to derive when there is no key or no second input.
    if key is None or x2 is None:
        return key, key

    if isinstance(key, tuple) and len(key) == 2:
        key1, second = key
        # If the two supplied keys coincide, derive a different candidate for
        # the second one so it can be used when the inputs differ.
        keys_coincide = utils.x1_is_x2(key1, second)
        candidate = np.where(keys_coincide, random.fold_in(second, 1), second)
        key2 = np.where(utils.x1_is_x2(x1, x2), key1, candidate)
        warnings.warn('The value of `key[1]` might be replaced by a new value if '
                      'key[0] == key[1] and x1 != x2 or key[0] != key[1] and '
                      'x1 == x2.')
        return key1, key2

    if isinstance(key, np.ndarray):
        # Single key: reuse it for identical inputs, fold in a constant
        # otherwise.
        key2 = np.where(utils.x1_is_x2(x1, x2), key, random.fold_in(key, 1))
        return key, key2

    raise TypeError(type(key))
def apply_fun(params, inputs, **kwargs):
    """Apply dropout to `inputs`.

    In `'train'` mode each entry of `inputs` is kept when a uniform sample in
    `[0, 1)` is below `rate` and is divided by `rate`; all other entries are
    set to 0. In any other mode `inputs` is returned unchanged.

    Args:
      params: unused by this layer.
      inputs: array of activations.
      **kwargs: must contain `rng`, the seed for the random mask. It is
        required in every mode, and must be passed by keyword.

    Returns:
      The (possibly masked and rescaled) inputs.

    Raises:
      ValueError: if no `rng` keyword argument is supplied.
    """
    rng = kwargs.get('rng', None)
    if rng is None:
        # `rng` is only accepted through `**kwargs`, so the suggested call must
        # use the keyword form; a third positional argument raises `TypeError`.
        msg = (
            "Dropout layer requires apply_fun to be called with a PRNG key "
            "argument. That is, instead of `apply_fun(params, inputs)`, call "
            "it like `apply_fun(params, inputs, rng=rng)` where `rng` is a "
            "jax.random.PRNGKey value.")
        raise ValueError(msg)

    if mode != 'train':
        return inputs

    prob = tf.ones(inputs.shape) * rate
    keep = stateless_uniform(
        shape=inputs.shape, seed=rng, minval=0, maxval=1) < prob
    # Divide kept entries by `rate` (the keep probability used above).
    return tfnp.where(keep, inputs / rate, 0)
def f(x):
    """Return the number of non-zero entries found by `np.where(x)`."""
    # The array handed to `len` has a data-dependent shape: its length is the
    # number of non-zero entries of `x`, not something fixed by `x.shape`.
    first_axis_indices = np.where(x)[0]
    return len(first_axis_indices)
def mask(x: Optional[np.ndarray], mask_mat: Optional[np.ndarray]):
    """Zero out the entries of `x` selected by `mask_mat`.

    Args:
      x: array to mask, or `None`.
      mask_mat: boolean mask; entries that are `True` are replaced by 0.
        `None` means no masking.

    Returns:
      `x` unchanged if either argument is `None`, otherwise an array equal to
      `x` except for zeros (of `x`'s dtype) where `mask_mat` is true.
    """
    if x is not None and mask_mat is not None:
        zero = np.zeros((), x.dtype)
        return np.where(mask_mat, zero, x)
    return x
def clip_grads(grad_tree, max_norm):
    """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`.

    Args:
      grad_tree: pytree of gradient arrays.
      max_norm: upper bound for the norm computed by `l2_norm(grad_tree)`.

    Returns:
      A pytree with the same structure as `grad_tree`. Leaves are unchanged
      when the norm is below `max_norm`, and scaled by `max_norm / norm`
      otherwise.
    """
    total_norm = l2_norm(grad_tree)

    def _rescale(leaf):
        # Keep the leaf as is while under the limit; otherwise shrink it.
        return np.where(total_norm < max_norm,
                        leaf,
                        leaf * (max_norm / total_norm))

    return tree_map(_rescale, grad_tree)
def predict_fn(
    t: ArrayOrScalar = None,
    fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
    fx_test_0: ArrayOrScalar = None,
    k_test_train: np.ndarray = None
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]:
    """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t: a scalar or array of scalars of any shape in strictly increasing
        order. `t=None` is equivalent to `t=np.inf` and may not converge.
        Equivalent of training steps (but can be fractional).
      fx_train_or_state_0: either (a) output of the network at `t == 0` on the
        training set or (b) complete ODE state (`predict.ODEState`). Pass an
        ODE state if you want to operate on the full ODE state instead of
        output variables only (useful for inspecting auxiliary variables or
        resuming an optimizer with auxiliary variables from a specific state).
        Note that only `momentum != None` optimizer currently has auxiliary
        variables. To initialize an ODE state from scratch, call
        `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed,
        an ODE state is returned. `fx_train_0=None` means to not compute
        predictions on the training set.
      fx_test_0: output of the network at `t == 0` on the test set.
        `fx_test_0=None` means to not compute predictions on the test set.
      k_test_train: kernel relating test data with training data. Must have
        the shape of `zip(y_test.shape, y_train.shape)` with `trace_axes`
        absent. Pass `k_test_train=None` if you only need predictions on the
        training set.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`. Only
      `fx_test_t` is returned when `fx_train_or_state_0` is `None` and
      `fx_test_0` is not. Alternatively can return an `ODEState` at time[s]
      `t`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
    """
    _check_inputs(fx_train_or_state_0, fx_test_0, k_test_train)

    # Scale the time by the learning rate and flatten it for the solver; the
    # original shape is restored on the outputs below.
    t = np.array(t if t is not None else np.inf, dtype) * learning_rate
    t_shape = t.shape
    t = t.reshape((-1,))

    # ODE solver requires `t[0]` to be the time where `fx_train_0` [and
    # `fx_test_0`] are evaluated, but also a strictly increasing sequence of
    # timesteps, so we always temporarily append an [almost] `0` at the start.
    t0 = np.where(t[0] == 0,
                  np.full((1,), -1e-24, t.dtype),
                  np.zeros((1,), t.dtype))
    t = np.concatenate([t0, t])

    # Solve the ODE.
    fx_test_shape = _get_fx_test_shape(y_train, k_test_train, trace_axes)
    state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
    state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)

    # Remove the added `t0` and restore the caller's time shape.
    trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
    trim_tree = lambda tree: tree_map(trim, tree)
    state_t = trim_tree(state_t)

    # `ODEState` -> `ODEState`
    if isinstance(fx_train_or_state_0, ODEState):
        return state_t

    # `np.ndarray` -> `np.ndarray`
    fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

    if fx_train_or_state_0 is not None and fx_test_0 is None:
        return fx_train_t
    if fx_test_0 is not None and fx_train_or_state_0 is None:
        return fx_test_t
    return fx_train_t, fx_test_t