Example #1
    def test_tensors(self):
        """Check that tensors() returns each shared variable only once."""
        vshared = objax.TrainVar(jn.ones(1))
        vc = objax.VarCollection([('a', objax.TrainVar(jn.zeros(1))),
                                  ('b', vshared)])
        vc += objax.VarCollection([('c', vshared)])
        self.assertEqual(len(vc.tensors()), 2)
        self.assertEqual([x.sum() for x in vc.tensors()], [0, 1])
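
The same deduplication applies outside of tests: a variable registered under several names is stored only once. A minimal sketch (the names 'net.w' and 'ema.w' are illustrative):

    import objax
    import jax.numpy as jn

    shared = objax.TrainVar(jn.ones(2))
    vc = objax.VarCollection({'net.w': shared, 'ema.w': shared})
    print(len(vc))            # 2 registered names...
    print(len(vc.tensors()))  # ...but only 1 unique tensor.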
Example #2
    def test_gradvalues_constant(self):
        """Test if constants are preserved."""
        # Set data
        ndim = 1
        data = np.array([[1.0], [3.0], [5.0], [-10.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.zeros(ndim))
        b = objax.TrainVar(jn.ones(1))
        m = objax.ModuleList([w, b])

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            return 0.5 * ((y - pred)**2).mean()

        # The gradient should change after the value of b (the constant) changes.
        gv = objax.GradValues(loss, objax.VarCollection({'w': w}))
        g_old, v_old = gv(data, labels)
        b.assign(-b.value)
        g_new, v_new = gv(data, labels)
        self.assertNotEqual(g_old[0][0], g_new[0][0])

        # When compiled with Jit, the gradient should likewise change after the value of b (the constant) changes.
        gv = objax.Jit(objax.GradValues(loss, objax.VarCollection({'w': w})),
                       m.vars())
        g_old, v_old = gv(data, labels)
        b.assign(-b.value)
        g_new, v_new = gv(data, labels)
        self.assertNotEqual(g_old[0][0], g_new[0][0])
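
The jitted variant only passes because m.vars(), which contains b, is handed to Jit, so the current value of b is fed into the compiled function on every call. A hedged sketch of the failure mode (not part of the test; compare Example #13 below): if only w were tracked, b would be baked in as a compile-time constant.

    # Sketch: b is absent from the collection, so it would be captured as a
    # constant when the function is traced.
    gv_const = objax.Jit(objax.GradValues(loss, objax.VarCollection({'w': w})),
                         objax.VarCollection({'w': w}))
    # After b.assign(-b.value), gv_const(data, labels) would still return the old gradients.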
Example #3
    def test_gradvalues_linear_and_inputs(self):
        """Test if gradient of inputs and variables has the correct values for linear regression."""
        # Set data
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.array([2, 3], jn.float32))
        b = objax.TrainVar(jn.array([1], jn.float32))

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            return 0.5 * ((y - pred) ** 2).mean()

        expect_loss = loss(data, labels)
        expect_gw = [37.25, 69.0]
        expect_gb = [13.75]
        expect_gx = [[4.0, 6.0], [8.5, 12.75], [13.0, 19.5], [2.0, 3.0]]
        expect_gy = [-2.0, -4.25, -6.5, -1.0]

        gv0 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0,))
        g, v = gv0(data, labels)
        self.assertEqual(v[0], expect_loss)
        self.assertEqual(g[0].tolist(), expect_gx)
        self.assertEqual(g[1].tolist(), expect_gw)
        self.assertEqual(g[2].tolist(), expect_gb)

        gv1 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(1,))
        g, v = gv1(data, labels)
        self.assertEqual(v[0], expect_loss)
        self.assertEqual(g[0].tolist(), expect_gy)
        self.assertEqual(g[1].tolist(), expect_gw)
        self.assertEqual(g[2].tolist(), expect_gb)

        gv01 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0, 1))
        g, v = gv01(data, labels)
        self.assertEqual(v[0], expect_loss)
        self.assertEqual(g[0].tolist(), expect_gx)
        self.assertEqual(g[1].tolist(), expect_gy)
        self.assertEqual(g[2].tolist(), expect_gw)
        self.assertEqual(g[3].tolist(), expect_gb)

        gv10 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(1, 0))
        g, v = gv10(data, labels)
        self.assertEqual(v[0], expect_loss)
        self.assertEqual(g[0].tolist(), expect_gy)
        self.assertEqual(g[1].tolist(), expect_gx)
        self.assertEqual(g[2].tolist(), expect_gw)
        self.assertEqual(g[3].tolist(), expect_gb)

        gv_inputs = objax.GradValues(loss, None, input_argnums=(0, 1))  # No variables: input gradients only.
        g, v = gv_inputs(data, labels)
        self.assertEqual(v[0], expect_loss)
        self.assertEqual(g[0].tolist(), expect_gx)
        self.assertEqual(g[1].tolist(), expect_gy)
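
The variants above pin down the output ordering: gradients for input_argnums come first, in the order requested, followed by gradients for the variables in the collection. Summarizing the assertions:

    # Output order of g:
    #   input_argnums=(0,)                    -> [gx, gw, gb]
    #   input_argnums=(1,)                    -> [gy, gw, gb]
    #   input_argnums=(0, 1)                  -> [gx, gy, gw, gb]
    #   input_argnums=(1, 0)                  -> [gy, gx, gw, gb]
    #   variables=None, input_argnums=(0, 1)  -> [gx, gy]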
Example #4
    def test_name_conflict(self):
        """Check name conflict raises a ValueError."""
        vc1 = objax.VarCollection([('a', objax.TrainVar(jn.zeros(1)))])
        vc2 = objax.VarCollection([('a', objax.TrainVar(jn.ones(1)))])
        with self.assertRaises(ValueError):
            vc1 + vc2

        with self.assertRaises(ValueError):
            vc1.update(vc2)

        with self.assertRaises(ValueError):
            vc1['a'] = objax.TrainVar(jn.ones(1))
Example #5
    def test_len_iter(self):
        """Verify length and iterator."""
        v1 = objax.TrainVar(jn.zeros(1))
        vshared = objax.TrainVar(jn.ones(1))
        vc1 = objax.VarCollection([('a', v1), ('b', vshared)])
        vc2 = objax.VarCollection([('c', vshared)])
        vc = vc1 + vc2
        self.assertEqual(len(vc), 3)
        self.assertEqual(len(vc.keys()), 3)
        self.assertEqual(len(vc.items()), 3)
        self.assertEqual(len(vc.values()), 3)
        self.assertEqual(len(list(vc)), 2)  # Iteration yields each unique variable once.
Example #6
    def test_opt(self):
        """Check the repr strings of the built-in optimizers."""
        self.assertEqual(repr(objax.optimizer.Adam(objax.VarCollection())),
                         'objax.optimizer.Adam(beta1=0.9, beta2=0.999, eps=1e-08)')
        self.assertEqual(repr(objax.optimizer.LARS(objax.VarCollection())),
                         'objax.optimizer.LARS(momentum=0.9, weight_decay=0.0001, tc=0.001, eps=1e-05)')
        self.assertEqual(repr(objax.optimizer.Momentum(objax.VarCollection())),
                         'objax.optimizer.Momentum(momentum=0.9, nesterov=False)')
        self.assertEqual(repr(objax.optimizer.SGD(objax.VarCollection())),
                         'objax.optimizer.SGD()')
        self.assertEqual(repr(objax.optimizer.ExponentialMovingAverage(objax.VarCollection())),
                         'objax.optimizer.ExponentialMovingAverage(momentum=0.999, debias=False, eps=1e-06)')
        self.assertEqual(repr(objax.optimizer.ExponentialMovingAverageModule(objax.Module())),
                         'objax.optimizer.ExponentialMovingAverageModule(momentum=0.999, debias=False, eps=1e-06)')
Example #7
    def test_init_list(self):
        """Initialize a VarCollection with a list."""
        vc = objax.VarCollection([('a', objax.TrainVar(jn.zeros(1))),
                                  ('b', objax.TrainVar(jn.ones(1)))])
        self.assertEqual(len(vc), 2)
        self.assertEqual(vc['a'].value.sum(), 0)
        self.assertEqual(vc['b'].value.sum(), 1)
Example #8
    def test_init_dict(self):
        """Initialize a VarCollection with a dict."""
        vc = objax.VarCollection({'a': objax.TrainVar(jn.zeros(1)),
                                  'b': objax.TrainVar(jn.ones(1))})
        self.assertEqual(len(vc), 2)
        self.assertEqual(vc['a'].value.sum(), 0)
        self.assertEqual(vc['b'].value.sum(), 1)
Example #9
    def test_gradvalues_linear(self):
        """Test if gradient has the correct value for linear regression."""
        # Set data
        ndim = 2
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.zeros(ndim))
        b = objax.TrainVar(jn.zeros(1))

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            return 0.5 * ((y - pred)**2).mean()

        gv = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}))
        g, v = gv(data, labels)

        self.assertEqual(g[0].shape, tuple([ndim]))
        self.assertEqual(g[1].shape, tuple([1]))

        g_expect_w = -(data * np.tile(labels, (ndim, 1)).transpose()).mean(0)
        g_expect_b = np.array([-labels.mean()])
        np.testing.assert_allclose(g[0], g_expect_w)
        np.testing.assert_allclose(g[1], g_expect_b)
        np.testing.assert_allclose(v[0], loss(data, labels))
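
Outside of tests, GradValues is typically paired with an optimizer. A minimal sketch reusing the model above (the choice of SGD and the learning rate are arbitrary):

    opt = objax.optimizer.SGD(objax.VarCollection({'w': w, 'b': b}))

    def train_op(x, y, lr=0.1):
        g, v = gv(x, y)  # Gradients and [loss value].
        opt(lr, g)       # One SGD step on w and b.
        return v

    train_op(data, labels)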
Example #10
    def test_trainvar_jit_assign(self):
        """Test TrainVar assignment inside functions transformed by Grad and Jit."""
        # Set data
        ndim = 2
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.zeros(ndim))
        b = objax.TrainVar(jn.zeros(1))

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            b.assign(b.value + 1)
            w.assign(w.value - 1)
            return 0.5 * ((y - pred)**2).mean()

        grad = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}))

        def jloss(wb, x, y):
            w, b = wb
            pred = jn.dot(x, w) + b
            return 0.5 * ((y - pred)**2).mean()

        def jit_op(x, y):
            g = grad(x, y)
            b.assign(b.value * 2)
            w.assign(w.value * 3)
            return g

        jit_op = objax.Jit(jit_op, objax.VarCollection(dict(b=b, w=w)))
        jgrad = jax.grad(jloss)

        jg = jgrad([w.value, b.value], data, labels)
        g = jit_op(data, labels)
        self.assertEqual(g[0].shape, tuple([ndim]))
        self.assertEqual(g[1].shape, tuple([1]))
        np.testing.assert_allclose(g[0], jg[0])
        np.testing.assert_allclose(g[1], jg[1])
        self.assertEqual(w.value.tolist(), [-3., -3.])
        self.assertEqual(b.value.tolist(), [2.])

        jg = jgrad([w.value, b.value], data, labels)
        g = jit_op(data, labels)
        np.testing.assert_allclose(g[0], jg[0])
        np.testing.assert_allclose(g[1], jg[1])
        self.assertEqual(w.value.tolist(), [-12., -12.])
        self.assertEqual(b.value.tolist(), [6.])
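
The expected values can be traced by hand: each jit_op call first runs loss inside Grad (which assigns b += 1 and w -= 1), then doubles b and triples w. Starting from w=0, b=0:

    # Call 1: loss assigns w=-1, b=1; jit_op then sets w=-3,  b=2.
    # Call 2: loss assigns w=-4, b=3; jit_op then sets w=-12, b=6.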
Example #11
    def test_grad_signature(self):
        def f(x: JaxArray, y) -> Tuple[JaxArray, Dict[str, JaxArray]]:
            return (x + y).mean(), {'x': x, 'y': y}

        def df(x: JaxArray, y) -> List[JaxArray]:
            pass  # Expected signature of the gradient function of f.

        g = objax.Grad(f, objax.VarCollection())
        self.assertEqual(inspect.signature(g), inspect.signature(df))
Example #12
    def test_parallel_concat_broadcast(self):
        """Parallel inference with broadcasted scalar input."""
        f = lambda x, y: x + y
        x = objax.random.normal((96, 3))
        d = jn.float32(0.5)
        y = f(x, d)
        fp = objax.Parallel(f, objax.VarCollection())
        z = fp(x, d)
        self.assertTrue(jn.array_equal(y, z))
Example #13
File: jit.py Project: srxzr/objax
    def test_constant_optimization(self):
        m = objax.nn.Linear(3, 4)
        jit_constant = objax.Jit(m, objax.VarCollection())

        x = objax.random.normal((10, 3))
        self.assertEqual(((m(x) - jit_constant(x)) ** 2).sum(), 0)

        # Modify m (which was supposed to be constant!)
        m.b.assign(m.b.value + 1)
        self.assertEqual(((m(x) - jit_constant(x)) ** 2).sum(), 40)
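
The fix is to let Jit track the module's variables so that updated values are fed in on each call; a short sketch of the conventional pattern:

    # Sketch: passing m.vars() keeps the jitted function in sync with m.
    jit_live = objax.Jit(m, m.vars())
    m.b.assign(m.b.value + 1)
    # ((m(x) - jit_live(x)) ** 2).sum() stays 0.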
Example #14
    def get_g(microbatch, l2_norm_clip, batch_axis=(0,)):
        # loss, w, noise_multiplier and data come from the enclosing test.
        gv_priv = objax.privacy.dpsgd.PrivateGradValues(
            loss,
            objax.VarCollection({'w': w}),
            noise_multiplier,
            l2_norm_clip,
            microbatch,
            batch_axis=batch_axis)
        g_priv, v_priv = gv_priv(data)
        return g_priv
Example #15
    def test_transform(self):
        def myloss(x):
            return (x ** 2).mean()

        g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
        gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
        gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(), noise_multiplier=1.,
                                                    l2_norm_clip=0.5, microbatch=1)
        self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
        self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
        self.assertEqual(repr(gvp), 'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
                                    ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
        self.assertEqual(repr(objax.Jit(gv)),
                         'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
        self.assertEqual(repr(objax.Jit(myloss, vc=objax.VarCollection())),
                         'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
        self.assertEqual(repr(objax.Parallel(gv)),
                         "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
                         " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
        self.assertEqual(repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
                         'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
        self.assertEqual(repr(objax.ForceArgs(gv, training=True, word='hello')),
                         "objax.ForceArgs(module=GradValues, training=True, word='hello')")
Example #16
    def __init__(self, methodname):
        """Initialize the test class."""
        super().__init__(methodname)

        self.data = jn.array([1.0, 2.0, 3.0, 4.0])

        self.W = objax.TrainVar(
            jn.array([[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 0., 1., 2.]]))
        self.b = objax.TrainVar(jn.array([-1., 0., 1.]))

        # f_lin(x) = W*x + b
        @objax.Function.with_vars(
            objax.VarCollection({
                'w': self.W,
                'b': self.b
            }))
        def f_lin(x):
            return jn.dot(self.W.value, x) + self.b.value

        self.f_lin = f_lin
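
Function.with_vars is what lets objax transforms see the TrainVars captured by a plain closure. A hedged usage sketch with this setup:

    # Sketch: f_lin carries its variables, so transforms can track them.
    f_jit = objax.Jit(self.f_lin)
    f_jit(self.data)  # Same result as self.f_lin(self.data).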
Example #17
    def test_grad_logistic(self):
        """Test if gradient has the correct value for logistic regression."""
        # Set data
        ndim = 2
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, -1.0, 1.0, -1.0])

        # Set model parameters for logistic regression.
        w = objax.TrainVar(jn.ones(ndim))

        def loss(x, y):
            xyw = jn.dot(x * np.tile(y, (ndim, 1)).transpose(), w.value)
            return jn.log(jn.exp(-xyw) + 1).mean(0)

        grad = objax.Grad(loss, objax.VarCollection({'w': w}))
        g = grad(data, labels)

        self.assertEqual(g[0].shape, tuple([ndim]))

        xw = np.dot(data, w.value)
        g_expect_w = -(data * np.tile(labels / (1 + np.exp(labels * xw)), (ndim, 1)).transpose()).mean(0)
        np.testing.assert_allclose(g[0], g_expect_w, atol=1e-7)
Example #18
    def test_rename(self):
        vc = objax.VarCollection({
            'baab': objax.TrainVar(jn.zeros(()) + 1),
            'baaab': objax.TrainVar(jn.zeros(()) + 2),
            'baaaab': objax.TrainVar(jn.zeros(()) + 3),
            'abba': objax.TrainVar(jn.zeros(()) + 4),
            'acca': objax.TrainVar(jn.zeros(()) + 5)})
        vcr = vc.rename(objax.util.Renamer({'aa': 'x', 'bb': 'y'}))
        self.assertEqual(vc['baab'], vcr['bxb'])
        self.assertEqual(vc['baaab'], vcr['bxab'])
        self.assertEqual(vc['baaaab'], vcr['bxxb'])
        self.assertEqual(vc['abba'], vcr['aya'])
        self.assertEqual(vc['acca'], vcr['acca'])

        def my_rename(x):
            return x.replace('aa', 'x').replace('bb', 'y')

        vcr = vc.rename(objax.util.Renamer(my_rename))
        self.assertEqual(vc['baab'], vcr['bxb'])
        self.assertEqual(vc['baaab'], vcr['bxab'])
        self.assertEqual(vc['baaaab'], vcr['bxxb'])
        self.assertEqual(vc['abba'], vcr['aya'])
        self.assertEqual(vc['acca'], vcr['acca'])

        vcr = vc.rename(objax.util.Renamer([(re.compile('a{2}'), 'x'), (re.compile('bb'), 'y')]))
        self.assertEqual(vc['baab'], vcr['bxb'])
        self.assertEqual(vc['baaab'], vcr['bxab'])
        self.assertEqual(vc['baaaab'], vcr['bxxb'])
        self.assertEqual(vc['abba'], vcr['aya'])
        self.assertEqual(vc['acca'], vcr['acca'])

        vcr = vc.rename(objax.util.Renamer([(re.compile('a{2}'), 'x'), (re.compile('xa'), 'y')]))
        self.assertEqual(vc['baab'], vcr['bxb'])
        self.assertEqual(vc['baaab'], vcr['byb'])
        self.assertEqual(vc['baaaab'], vcr['bxxb'])
        self.assertEqual(vc['abba'], vcr['abba'])
        self.assertEqual(vc['acca'], vcr['acca'])
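
The last Renamer is the subtle case: rules are applied sequentially, so the output of the first rule can match the second. Tracing the names above:

    # 'baaab'  --a{2}->x-->  'bxab'   --xa->y-->  'byb'
    # 'baaaab' --a{2}->x-->  'bxxb'   (no 'xa' remains)
    # 'abba' and 'acca' match neither pattern, so they are unchanged.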
Example #19
    def test_assign(self):
        """Assign a new variable by name."""
        vc = objax.VarCollection({'a': objax.TrainVar(jn.zeros(1))})
        vc['b'] = objax.TrainVar(jn.ones(1))
        self.assertEqual(len(vc), 2)
        self.assertEqual(vc['a'].value.sum(), 0)
        self.assertEqual(vc['b'].value.sum(), 1)
Example #20
    def test_replicate_shape_assert(self):
        """Test that replicating variables yields the expected shapes without asserting."""
        vc = objax.VarCollection({'var': objax.TrainVar(jn.zeros(5))})
        with vc.replicate():
            self.assertEqual(len(vc['var'].value.shape), 2)
            self.assertEqual(vc['var'].value.shape[-1], 5)
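
replicate() is the mechanism objax.Parallel relies on: inside the context each variable gains a leading device axis, and the original shape is restored on exit. A minimal sketch (assuming replication across the local devices, whose count varies by machine):

    import jax

    vc = objax.VarCollection({'var': objax.TrainVar(jn.zeros(5))})
    with vc.replicate():
        assert vc['var'].value.shape == (jax.local_device_count(), 5)
    assert vc['var'].value.shape == (5,)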