def test_trainvar_and_ref_assign(self):
    """Assignments made inside a Vectorize'd function through a TrainVar and
    through a TrainRef pointing at that same TrainVar are both written back.

    The vectorized function maps each of 10 input rows to ``row + 1``. It also
    increments the shared variable twice: once directly and once via the ref.
    The expected final value is [2., 2.], not 20. That means the per-row
    updates are merged back into the variable, not accumulated across rows.
    """
    modules = objax.ModuleList([objax.TrainVar(jn.zeros(2))])
    # modules[1] is a reference to the very TrainVar object stored in modules[0].
    modules.append(objax.TrainRef(modules[0]))

    def bump(value):
        # Two increments of the same underlying variable: direct, then via the ref.
        modules[0].assign(modules[0].value + 1)
        modules[1].assign(modules[1].value + 1)
        return value + 1

    batch = np.arange(10)[:, None]  # 10 rows of shape (1,)
    output = objax.Vectorize(bump, modules.vars())(batch)

    expected_output = np.arange(1, 11)[:, None].tolist()
    self.assertEqual(output.tolist(), expected_output)
    self.assertEqual(modules[0].value.tolist(), [2., 2.])
def test_vars(self):
    """Check the repr of TrainVar, TrainRef, StateVar and a random generator key.

    NOTE(review): the expected literals are coupled to the installed jax
    version's array repr (``DeviceArray(...)``). Confirm them when upgrading jax.
    """
    shape = [1, 2, 3, 2, 1]
    # Every variable repr below embeds the same all-zeros array text, so the
    # array part is built once and wrapped per variable type.
    array_repr = '\n'.join([
        'DeviceArray([[[[[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]]],',
        ' [[[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]],',
        ' [[0.],',
        ' [0.]]]]], dtype=float32)',
    ])

    train_var = objax.TrainVar(jn.zeros(shape))
    expected_train = 'objax.TrainVar(' + array_repr + ', reduce=reduce_mean)'
    self.assertEqual(repr(train_var), expected_train)

    # A TrainRef's repr wraps the repr of the variable it references.
    train_ref = objax.TrainRef(train_var)
    self.assertEqual(repr(train_ref), 'objax.TrainRef(ref=' + expected_train + ')')

    state_var = objax.StateVar(jn.zeros(shape))
    expected_state = 'objax.StateVar(' + array_repr + ', reduce=reduce_mean)'
    self.assertEqual(repr(state_var), expected_state)

    self.assertEqual(repr(objax.random.Generator().key),
                     'objax.RandomState(DeviceArray([0, 0], dtype=uint32))')