コード例 #1
0
 def testFoldIn(self, dtype):
   """Test that `fold_in` derives well-formed, distinct seeds.

   Folds the data values 3 and then 4 into a seed in succession, each time
   starting from the previously derived seed. It then checks that every
   derived seed has shape [2] and dtype `dtype`. Finally it passes the
   original seed (cast to `dtype`) and all derived seeds to
   `assertNoEqualPair`.

   Args:
     dtype: dtype of the folded-in data tensors. The derived seeds are
       expected to carry this dtype.
   """
   orig_seed = constant_op.constant([1, 2], dtype='int32')
   new_seeds = []
   seed = orig_seed
   for data in (3, 4):
     # Chain the fold-ins: each new seed is derived from the previous one.
     seed = stateless.fold_in(seed, constant_op.constant(data, dtype=dtype))
     new_seeds.append(seed)
   for s in new_seeds:
     self.assertEqual(s.shape, [2])
     # `assertDTypeEqual` takes the array/tensor itself (it converts it to an
     # ndarray and compares that array's dtype), not a DType object.
     self.assertDTypeEqual(s, dtype)
   # Cast the int32 original seed so every compared seed shares `dtype`.
   # NOTE(review): `assertNoEqualPair` is defined outside this view; presumably
   # it asserts that no two list elements are equal - confirm.
   self.assertNoEqualPair([math_ops.cast(orig_seed, dtype)] + new_seeds)
コード例 #2
0
def _read_keys(key, x1, x2):
  """Read dropout key.

     `key` might be a tuple of two rng keys or a single rng key or None. In
     either case, `key` will be mapped into two rng keys `key1` and `key2` to
     make sure `(x1==x2) == (key1==key2)`.
  """

  if key is None or x2 is None:
    key1 = key2 = key
  elif isinstance(key, tuple) and len(key) == 2:
    key1, key2 = key
    new_key = np.where(utils.x1_is_x2(key1, key2),
                       random.fold_in(key2, 1), key2)
    key2 = np.where(utils.x1_is_x2(x1, x2), key1, new_key)
    warnings.warn('The value of `key[1]` might be replaced by a new value if '
                  'key[0] == key[1] and x1 != x2 or key[0] != key[1] and '
                  'x1 == x2.')
  elif isinstance(key, np.ndarray):
    key1 = key
    key2 = np.where(utils.x1_is_x2(x1, x2), key1, random.fold_in(key, 1))
  else:
    raise TypeError(type(key))
  return key1, key2