def test_random_state(self):
    """Check RandomState seeding and how split() advances the stored key."""
    first = objax.RandomState(0)
    second = objax.RandomState(1)
    replay = objax.RandomState(0)  # same seed as `first`, consumed differently below

    # Freshly constructed states expose a key derived from their seed.
    self.assertEqual(first.value.tolist(), [0, 0])
    self.assertEqual(second.value.tolist(), [0, 1])

    # Each split(1) returns one subkey and replaces the stored key.
    # The expectations are ordered: every split depends on the key left by the previous one.
    single_split_expectations = (
        ([2718843009, 1272950319], [4146024105, 967050713]),
        ([1278412471, 2182328957], [2384771982, 3928867769]),
    )
    for expected_key, expected_state in single_split_expectations:
        key = first.split(1)[0]
        self.assertEqual(key.tolist(), expected_key)
        self.assertEqual(first.value.tolist(), expected_state)

    # A single split(2) on an untouched seed-0 state yields two subkeys at once.
    key_a, key_b = replay.split(2)
    self.assertEqual(key_a.tolist(), [3186719485, 3840466878])
    self.assertEqual(key_b.tolist(), [2562233961, 1946702221])
    self.assertEqual(replay.value.tolist(), [2467461003, 428148500])
def test_var_hierarchy(self):
    """Check the base classes each objax variable kind does and does not inherit from."""
    train = objax.TrainVar(jn.zeros(2))
    state = objax.StateVar(jn.zeros(2))
    ref = objax.TrainRef(train)
    rng = objax.RandomState(0)

    # Each entry: (variable, classes it must be an instance of, classes it must not be).
    expectations = (
        (train, (objax.TrainVar, objax.BaseVar), (objax.BaseState,)),
        (state, (objax.BaseVar, objax.BaseState), (objax.TrainVar,)),
        (ref, (objax.BaseVar, objax.BaseState), (objax.TrainVar,)),
        (rng, (objax.BaseVar, objax.BaseState, objax.StateVar), (objax.TrainVar,)),
    )
    for var, parents, non_parents in expectations:
        for cls in parents:
            self.assertIsInstance(var, cls)
        for cls in non_parents:
            self.assertNotIsInstance(var, cls)