def testFoldIn(self, dtype):
    """Test for `fold_in`: derived seeds are well-formed and mutually distinct.

    Starting from the int32 seed `[1, 2]`, two seeds are derived by successive
    `stateless.fold_in` calls (folding in 3, then folding 4 into that result).
    Each derived seed must have shape `[2]` and the parameterized `dtype`.
    The original seed (cast to `dtype`) and both derived seeds are then passed
    to `assertNoEqualPair`, which presumably asserts that no two of them are
    equal (helper defined elsewhere -- confirm).
    """
    orig_seed = constant_op.constant([1, 2], dtype='int32')
    first = stateless.fold_in(orig_seed, constant_op.constant(3, dtype=dtype))
    # Chain the second fold from the first result, not from orig_seed.
    second = stateless.fold_in(first, constant_op.constant(4, dtype=dtype))
    derived_seeds = [first, second]

    for derived in derived_seeds:
        self.assertEqual(derived.shape, [2])
        self.assertDTypeEqual(derived.dtype, dtype)

    self.assertNoEqualPair([math_ops.cast(orig_seed, dtype)] + derived_seeds)
def _read_keys(key, x1, x2):
    """Read dropout key and expand it into a pair `(key1, key2)`.

    `key` may be None, a tuple of exactly two rng keys, or a single rng key
    (an `np.ndarray`). The returned pair is chosen so that
    `(x1 == x2) == (key1 == key2)`, where "equal" is judged by
    `utils.x1_is_x2`.

    If `key` is None or `x2` is None, both returned keys are `key` itself.

    Returns:
      A tuple `(key1, key2)`.

    Raises:
      TypeError: if `key` is not one of the accepted forms. This includes a
        tuple whose length is not 2. The exception argument is `type(key)`.
    """
    # Nothing to disambiguate: no key given, or there is no second input.
    if key is None or x2 is None:
        return key, key

    if isinstance(key, tuple) and len(key) == 2:
        first, second = key
        # Build a candidate that differs from `first`: when the two supplied
        # keys coincide, perturb the second one via fold_in.
        keys_collide = utils.x1_is_x2(first, second)
        distinct = np.where(keys_collide, random.fold_in(second, 1), second)
        # Same inputs -> same key; different inputs -> the distinct candidate.
        inputs_match = utils.x1_is_x2(x1, x2)
        second = np.where(inputs_match, first, distinct)
        warnings.warn('The value of `key[1]` might be replaced by a new value if '
                      'key[0] == key[1] and x1 != x2 or key[0] != key[1] and '
                      'x1 == x2.')
        return first, second

    if isinstance(key, np.ndarray):
        # Single key: reuse it for identical inputs, otherwise derive a new one.
        return key, np.where(utils.x1_is_x2(x1, x2), key, random.fold_in(key, 1))

    raise TypeError(type(key))