def test_get_rng_state(test_case):
    """Verify that ``flow.get_rng_state`` mirrors the default generator's state.

    The comparison is made once on the untouched generator and once more after
    drawing a 100x100 ``float32`` sample on CPU from that same generator.
    """
    generator = flow.default_generator

    def assert_states_match():
        # Query the generator first, then the global accessor, and compare.
        from_generator = generator.get_state()
        from_global = flow.get_rng_state()
        test_case.assertTrue(np.allclose(from_generator.numpy(), from_global.numpy()))

    assert_states_match()
    flow.randn(100, 100, dtype=flow.float32, device="cpu", generator=generator)
    assert_states_match()
def test_set_rng_state(test_case):
    """Verify that ``flow.set_rng_state`` restores a previously captured state.

    A state is captured and then advanced by another ``randn`` draw. It must
    therefore differ from the captured one. After restoring the captured state,
    ``get_rng_state`` must report it exactly, and the next draw must reproduce
    the samples originally drawn from that state.
    """
    flow.randn(100, 100)
    state = flow.get_rng_state()
    first_draw = flow.randn(100, 100)
    new_state = flow.get_rng_state()
    # Drawing samples advances the generator, so the snapshot must have changed.
    # RNG states are raw bytes: compare exactly rather than with a tolerance.
    test_case.assertFalse(np.array_equal(new_state.numpy(), state.numpy()))

    flow.set_rng_state(state)
    new_state = flow.get_rng_state()
    test_case.assertTrue(np.array_equal(new_state.numpy(), state.numpy()))

    # Restoring the state must make the generator replay the same samples.
    replayed_draw = flow.randn(100, 100)
    test_case.assertTrue(np.array_equal(first_draw.numpy(), replayed_draw.numpy()))
def test_set_get_rng_state(test_case):
    """Verify that a state installed with ``flow.set_rng_state`` is read back unchanged.

    The global RNG state is saved first and restored in ``finally``. Otherwise
    the arbitrary state installed here would leak into every test that runs
    afterwards in the same process.
    """
    saved_state = flow.get_rng_state()
    # NOTE(review): flow.ByteTensor(5000) may leave its contents uninitialized.
    # The test only needs get_rng_state to echo back whatever was set.
    x = flow.ByteTensor(5000)
    try:
        flow.set_rng_state(x)
        y = flow.get_rng_state()
        # RNG states are raw bytes: require an exact match.
        test_case.assertTrue(np.array_equal(x.numpy(), y.numpy()))
    finally:
        flow.set_rng_state(saved_state)