def _read_keys(key, x1, x2): """Read dropout key. `key` might be a tuple of two rng keys or a single rng key or None. In either case, `key` will be mapped into two rng keys `key1` and `key2` to make sure `(x1==x2) == (key1==key2)`. """ if key is None or x2 is None: key1 = key2 = key elif isinstance(key, tuple) and len(key) == 2: key1, key2 = key new_key = np.where(utils.x1_is_x2(key1, key2), random.fold_in(key2, 1), key2) key2 = np.where(utils.x1_is_x2(x1, x2), key1, new_key) warnings.warn( 'The value of `key[1]` might be replaced by a new value if ' 'key[0] == key[1] and x1 != x2 or key[0] != key[1] and ' 'x1 == x2.') elif isinstance(key, np.ndarray): key1 = key key2 = np.where(utils.x1_is_x2(x1, x2), key1, random.fold_in(key, 1)) else: raise TypeError(type(key)) return key1, key2
def kernel_fn_sample_once(x1: np.ndarray,
                          x2: Optional[np.ndarray],
                          key: PRNGKey,
                          get: Get,
                          **apply_fn_kwargs):
  """Computes `kernel_fn` for one random draw of parameters.

  `key` is split into three subkeys. The first initializes the parameters
  through `init_fn(..., x1.shape)`. The other two are combined into the
  `keys` argument passed to `kernel_fn`. `init_fn` and `kernel_fn` are not
  defined here; presumably they come from an enclosing scope.

  Args:
    x1: first inputs; only its `.shape` is passed to `init_fn`.
    x2: second inputs, or `None`.
    key: rng key for this single sample.
    get: forwarded unchanged to `kernel_fn`.
    **apply_fn_kwargs: forwarded unchanged to `kernel_fn`.

  Returns:
    Whatever `kernel_fn` returns for the sampled parameters.
  """
  subkeys = random.split(key, 3)
  param_key, drop_key_first, drop_key_second = subkeys

  # Same inputs -> the first dropout key alone; different inputs -> the
  # stacked pair. NOTE: `np.where` broadcasts, so the result has the stacked
  # shape either way (with identical rows when the inputs coincide).
  inputs_match = utils.x1_is_x2(x1, x2)
  key_pair = np.stack([drop_key_first, drop_key_second])
  keys = np.where(inputs_match, drop_key_first, key_pair)

  _init_out, params = init_fn(param_key, x1.shape)
  return kernel_fn(x1, x2, get, params, keys=keys, **apply_fn_kwargs)
def kernel_fn_sample_once(x1, x2, key, get):
  """Computes `kernel_fn` for one random draw of parameters.

  `key` is split into three subkeys. One initializes the parameters through
  `init_fn(init_key, x1.shape)`. The other two form the `keys` argument
  passed to `kernel_fn`. `init_fn` and `kernel_fn` are not defined here;
  presumably they come from an enclosing scope.

  Args:
    x1: first inputs; only its `.shape` is passed to `init_fn`.
    x2: second inputs, or `None`.
    key: rng key for this single sample.
    get: forwarded unchanged to `kernel_fn`.

  Returns:
    Whatever `kernel_fn` returns for the sampled parameters.
  """
  init_key, dropout_key1, dropout_key2 = random.split(key, 3)
  # Stack the pair explicitly instead of passing a Python tuple: assuming
  # `np` is `jax.numpy`, current JAX rejects tuple/list array arguments.
  # Where a tuple was accepted, it converted to this same stacked array.
  keys = np.where(utils.x1_is_x2(x1, x2), dropout_key1,
                  np.stack([dropout_key1, dropout_key2]))
  _, params = init_fn(init_key, x1.shape)
  # NOTE(review): this passes `(params, get)` while the annotated variant of
  # this function passes `(get, params)`; confirm against the signature of
  # the `kernel_fn` in scope.
  return kernel_fn(x1, x2, params, get, keys=keys)