Esempio n. 1
0
    def test_trainvar_and_ref_assign(self):
        """Assign to a TrainVar and to a TrainRef of it inside a vectorized function.

        The module list holds a zero-initialised TrainVar followed by a TrainRef
        built from that same variable. The vectorized function adds one through
        each entry in turn and returns its input plus one.
        """
        variables = objax.ModuleList([objax.TrainVar(jn.zeros(2))])
        variables.append(objax.TrainRef(variables[0]))

        def increase(x):
            # Entry 0 is the variable, entry 1 the reference built from it; the
            # final assertion expects both increments to accumulate on entry 0.
            for index in (0, 1):
                variables[index].assign(variables[index].value + 1)
            return x + 1

        column = np.arange(10).reshape(-1, 1)
        vec_increase = objax.Vectorize(increase, variables.vars())
        result = vec_increase(column)

        expected_output = [[value] for value in range(1, 11)]
        self.assertEqual(result.tolist(), expected_output)
        self.assertEqual(variables[0].value.tolist(), [2.0, 2.0])
Esempio n. 2
0
 def test_vars(self):
     """Check the repr() text of TrainVar, TrainRef, StateVar and a generator key.

     The three variable reprs differ only in their first line and, for the
     TrainRef, in one extra closing parenthesis; the continuation lines of the
     array printout are the same, so they are spelled out once below.
     """
     shape = [1, 2, 3, 2, 1]
     # Continuation lines of the printed zero array, shared by every expectation.
     array_rows = ['                [0.]],',
                   '               [[0.],',
                   '                [0.]],',
                   '               [[0.],',
                   '                [0.]]],',
                   '              [[[0.],',
                   '                [0.]],',
                   '               [[0.],',
                   '                [0.]],',
                   '               [[0.],']
     var_closing = '                [0.]]]]], dtype=float32), reduce=reduce_mean)'
     ref_closing = '                [0.]]]]], dtype=float32), reduce=reduce_mean))'

     def expected_repr(opening, closing):
         # Join the opening line, the shared rows and the closing line.
         return '\n'.join([opening, *array_rows, closing])

     train_var = objax.TrainVar(jn.zeros(shape))
     self.assertEqual(repr(train_var),
                      expected_repr('objax.TrainVar(DeviceArray([[[[[0.],', var_closing))

     train_ref = objax.TrainRef(train_var)
     self.assertEqual(repr(train_ref),
                      expected_repr('objax.TrainRef(ref=objax.TrainVar(DeviceArray([[[[[0.],', ref_closing))

     state_var = objax.StateVar(jn.zeros(shape))
     self.assertEqual(repr(state_var),
                      expected_repr('objax.StateVar(DeviceArray([[[[[0.],', var_closing))

     generator_key = objax.random.Generator().key
     self.assertEqual(repr(generator_key), 'objax.RandomState(DeviceArray([0, 0], dtype=uint32))')