Example #1
0
def _read_keys(key, x1, x2):
    """Read dropout key.

     `key` might be a tuple of two rng keys or a single rng key or None. In
     either case, `key` will be mapped into two rng keys `key1` and `key2` to
     make sure `(x1==x2) == (key1==key2)`.
  """

    if key is None or x2 is None:
        key1 = key2 = key
    elif isinstance(key, tuple) and len(key) == 2:
        key1, key2 = key
        new_key = np.where(utils.x1_is_x2(key1, key2), random.fold_in(key2, 1),
                           key2)
        key2 = np.where(utils.x1_is_x2(x1, x2), key1, new_key)
        warnings.warn(
            'The value of `key[1]` might be replaced by a new value if '
            'key[0] == key[1] and x1 != x2 or key[0] != key[1] and '
            'x1 == x2.')
    elif isinstance(key, np.ndarray):
        key1 = key
        key2 = np.where(utils.x1_is_x2(x1, x2), key1, random.fold_in(key, 1))
    else:
        raise TypeError(type(key))
    return key1, key2
Example #2
0
 def kernel_fn_sample_once(x1: np.ndarray, x2: Optional[np.ndarray],
                           key: PRNGKey, get: Get, **apply_fn_kwargs):
     """Sample one set of parameters and evaluate `kernel_fn` with them.

     `key` is split into three sub-keys. The first initializes parameters
     via `init_fn(init_key, x1.shape)` and the other two are dropout keys.
     When `utils.x1_is_x2(x1, x2)` holds, `np.where` selects the first
     dropout key (broadcast against the stacked pair). Otherwise it selects
     the two dropout keys stacked together.

     NOTE(review): `get` is passed to `kernel_fn` before `params` here.
     Another definition of this closure in the same snippet uses the
     opposite order, so confirm against `kernel_fn`'s signature.

     Args:
       x1: first input; its shape is passed to `init_fn`.
       x2: second input, or `None`.
       key: rng key consumed by this single sample.
       get: forwarded unchanged to `kernel_fn`.
       **apply_fn_kwargs: forwarded unchanged to `kernel_fn`.

     Returns:
       Whatever `kernel_fn(x1, x2, get, params, keys=..., **apply_fn_kwargs)`
       returns.
     """
     init_key, dropout_a, dropout_b = random.split(key, 3)
     same_inputs = utils.x1_is_x2(x1, x2)
     stacked_dropout = np.stack([dropout_a, dropout_b])
     dropout_keys = np.where(same_inputs, dropout_a, stacked_dropout)
     _unused, params = init_fn(init_key, x1.shape)
     return kernel_fn(x1, x2, get, params, keys=dropout_keys,
                      **apply_fn_kwargs)
 def kernel_fn_sample_once(x1, x2, key, get):
     """Evaluate `kernel_fn` once, with parameters drawn from `key`.

     Three sub-keys are split off `key`: one for `init_fn`, two for dropout.
     When `utils.x1_is_x2(x1, x2)` holds, `np.where` selects the first
     dropout key. Otherwise it selects the pair `(dropout_1, dropout_2)`,
     which it converts to an array.

     NOTE(review): at this indentation this definition shadows the earlier
     `kernel_fn_sample_once` in the same snippet. It also passes `params`
     before `get` to `kernel_fn`, whereas that definition passes `get`
     first. Confirm which definition is live and which argument order
     `kernel_fn` expects.

     Args:
       x1: first input; its shape is passed to `init_fn`.
       x2: second input, compared against `x1` via `utils.x1_is_x2`.
       key: rng key consumed by this single sample.
       get: forwarded unchanged to `kernel_fn`.

     Returns:
       Whatever `kernel_fn(x1, x2, params, get, keys=...)` returns.
     """
     init_key, drop_first, drop_second = random.split(key, 3)
     inputs_match = utils.x1_is_x2(x1, x2)
     dropout_keys = np.where(inputs_match, drop_first,
                             (drop_first, drop_second))
     _unused, params = init_fn(init_key, x1.shape)
     return kernel_fn(x1, x2, params, get, keys=dropout_keys)