Example #1
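The snippets below read as methods of a unittest.TestCase exercising objax.Grad. A minimal scaffold with the imports they appear to assume (the class name is a guess; everything else is inferred from usage):

import inspect
import unittest
from typing import Dict, List, Tuple

import jax
import jax.numpy as jn
import numpy as np

import objax
from objax.typing import JaxArray

class TestGrad(unittest.TestCase):
    # ... the test methods below go here ...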
    def test_grad_constant(self):
        """Test if constants are preserved."""
        # Set data
        ndim = 1
        data = np.array([[1.0], [3.0], [5.0], [-10.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.zeros(ndim))
        b = objax.TrainVar(jn.ones(1))
        m = objax.ModuleList([w, b])

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            return 0.5 * ((y - pred)**2).mean()

        # The gradient should change after the value of b (captured as a constant) changes.
        grad = objax.Grad(loss, objax.VarCollection({'w': w}))
        g_old = grad(data, labels)
        b.assign(-b.value)
        g_new = grad(data, labels)
        self.assertNotEqual(g_old[0][0], g_new[0][0])

        # The same should hold when the gradient function is compiled with Jit.
        grad = objax.Jit(objax.Grad(loss, objax.VarCollection({'w': w})),
                         m.vars())
        g_old = grad(data, labels)
        b.assign(-b.value)
        g_new = grad(data, labels)
        self.assertNotEqual(g_old[0][0], g_new[0][0])
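Because b is left out of the VarCollection passed to Grad, it is captured as a constant rather than differentiated; the assertions check that a later update to it still flows into the gradient, i.e. its value is not baked in at trace time. A hand check of the two gradients with plain NumPy (numbers worked out from the data above):

import numpy as np

# d/dw 0.5 * mean((y - (w*x + b))**2) = mean((w*x + b - y) * x), with w = 0 here.
x = np.array([1.0, 3.0, 5.0, -10.0])
y = np.array([1.0, 2.0, 3.0, 4.0])
print(((1.0 - y) * x).mean())   # b = +1 ->  4.25
print(((-1.0 - y) * x).mean())  # b = -1 ->  4.75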
Example #2
    def test_grad_linear_and_inputs(self):
        """Test if gradient of inputs and variables has the correct values for linear regression."""
        # Set data
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.array([2, 3], jn.float32))
        b = objax.TrainVar(jn.array([1], jn.float32))

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            return 0.5 * ((y - pred) ** 2).mean()

        expect_gw = [37.25, 69.0]
        expect_gb = [13.75]
        expect_gx = [[4.0, 6.0], [8.5, 12.75], [13.0, 19.5], [2.0, 3.0]]
        expect_gy = [-2.0, -4.25, -6.5, -1.0]

        grad0 = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0,))
        g = grad0(data, labels)
        self.assertEqual(g[0].tolist(), expect_gx)
        self.assertEqual(g[1].tolist(), expect_gw)
        self.assertEqual(g[2].tolist(), expect_gb)

        grad1 = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(1,))
        g = grad1(data, labels)
        self.assertEqual(g[0].tolist(), expect_gy)
        self.assertEqual(g[1].tolist(), expect_gw)
        self.assertEqual(g[2].tolist(), expect_gb)

        grad01 = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0, 1))
        g = grad01(data, labels)
        self.assertEqual(g[0].tolist(), expect_gx)
        self.assertEqual(g[1].tolist(), expect_gy)
        self.assertEqual(g[2].tolist(), expect_gw)
        self.assertEqual(g[3].tolist(), expect_gb)

        grad10 = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(1, 0))
        g = grad10(data, labels)
        self.assertEqual(g[0].tolist(), expect_gy)
        self.assertEqual(g[1].tolist(), expect_gx)
        self.assertEqual(g[2].tolist(), expect_gw)
        self.assertEqual(g[3].tolist(), expect_gb)

        # With variables=None, only the input gradients are returned.
        grad_inputs = objax.Grad(loss, None, input_argnums=(0, 1))
        g = grad_inputs(data, labels)
        self.assertEqual(g[0].tolist(), expect_gx)
        self.assertEqual(g[1].tolist(), expect_gy)
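Every expected value above follows from the residuals r = x.w + b - y = [8, 17, 26, 4] and the batch mean over n = 4 points: gw = sum(r_i * x_i) / n, gb = mean(r), gx_i = r_i * w / n, gy_i = -r_i / n. A standalone NumPy check of that arithmetic:

import numpy as np

data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
labels = np.array([1.0, 2.0, 3.0, 4.0])
w, b = np.array([2.0, 3.0]), 1.0
r = data @ w + b - labels                 # residuals: [8., 17., 26., 4.]
n = labels.size
print((r[:, None] * data).sum(0) / n)     # gw: [37.25, 69.]
print(r.mean())                           # gb: 13.75
print(r[:, None] * w / n)                 # gx, one row per sample
print(-r / n)                             # gy: [-2., -4.25, -6.5, -1.]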
Example #3
    def test_grad_linear(self):
        """Test if gradient has the correct value for linear regression."""
        # Set data
        ndim = 2
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.zeros(ndim))
        b = objax.TrainVar(jn.zeros(1))

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            return 0.5 * ((y - pred)**2).mean()

        grad = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}))
        g = grad(data, labels)

        self.assertEqual(g[0].shape, tuple([ndim]))
        self.assertEqual(g[1].shape, tuple([1]))

        g_expect_w = -(data * np.tile(labels, (ndim, 1)).transpose()).mean(0)
        g_expect_b = np.array([-labels.mean()])
        np.testing.assert_allclose(g[0], g_expect_w)
        np.testing.assert_allclose(g[1], g_expect_b)
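With w and b initialized to zero, pred is identically zero, so the analytic gradients reduce to g_w = -mean(x * y, axis=0) and g_b = [-mean(y)], which is exactly what the expected values compute. The tile/transpose expression can also be written with broadcasting (an equivalent formulation of the same quantity, reusing data and labels from the test):

g_expect_w = -(data * labels[:, None]).mean(0)  # same value as the tile/transpose form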
Example #4
    def test_grad_signature(self):
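        """Test that Grad exposes the expected gradient signature."""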
        def f(x: JaxArray, y) -> Tuple[JaxArray, Dict[str, JaxArray]]:
            return (x + y).mean(), {'x': x, 'y': y}

        def df(x: JaxArray, y) -> List[JaxArray]:
            pass  # Signature of the differential of f

        g = objax.Grad(f, objax.VarCollection())
        self.assertEqual(inspect.signature(g), inspect.signature(df))
Example #5
    def test_trainvar_jit_assign(self):
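        """Test TrainVar assignment inside Grad and Jit."""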
        # Set data
        ndim = 2
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.zeros(ndim))
        b = objax.TrainVar(jn.zeros(1))

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            b.assign(b.value + 1)
            w.assign(w.value - 1)
            return 0.5 * ((y - pred)**2).mean()

        grad = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}))

        def jloss(wb, x, y):
            w, b = wb
            pred = jn.dot(x, w) + b
            return 0.5 * ((y - pred)**2).mean()

        def jit_op(x, y):
            g = grad(x, y)
            b.assign(b.value * 2)
            w.assign(w.value * 3)
            return g

        jit_op = objax.Jit(jit_op, objax.VarCollection(dict(b=b, w=w)))
        jgrad = jax.grad(jloss)

        jg = jgrad([w.value, b.value], data, labels)
        g = jit_op(data, labels)
        self.assertEqual(g[0].shape, tuple([ndim]))
        self.assertEqual(g[1].shape, tuple([1]))
        np.testing.assert_allclose(g[0], jg[0])
        np.testing.assert_allclose(g[1], jg[1])
        self.assertEqual(w.value.tolist(), [-3., -3.])
        self.assertEqual(b.value.tolist(), [2.])

        jg = jgrad([w.value, b.value], data, labels)
        g = jit_op(data, labels)
        np.testing.assert_allclose(g[0], jg[0])
        np.testing.assert_allclose(g[1], jg[1])
        self.assertEqual(w.value.tolist(), [-12., -12.])
        self.assertEqual(b.value.tolist(), [6.])
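The expected variable values follow from tracking the in-place assignments, since each jit_op call first runs loss (inside grad) and then its own tail. A hand trace:

# Start:               w = [0, 0]        b = [0]
# 1st grad() via loss: w = [-1, -1]      b = [1]
# 1st jit_op tail:     w = [-3, -3]      b = [2]
# 2nd grad() via loss: w = [-4, -4]      b = [3]
# 2nd jit_op tail:     w = [-12, -12]    b = [6]

jgrad is evaluated on the pre-call values each time, which is why its results agree with the jitted objax gradients despite the mutation.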
Example #6
    def test_grad_logistic(self):
        """Test if gradient has the correct value for logistic regression."""
        # Set data
        ndim = 2
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, -1.0, 1.0, -1.0])

        # Set model parameters for logistic regression.
        w = objax.TrainVar(jn.ones(ndim))

        def loss(x, y):
            xyw = jn.dot(x * np.tile(y, (ndim, 1)).transpose(), w.value)
            return jn.log(jn.exp(-xyw) + 1).mean(0)

        grad = objax.Grad(loss, objax.VarCollection({'w': w}))
        g = grad(data, labels)

        self.assertEqual(g[0].shape, tuple([ndim]))

        xw = np.dot(data, w.value)
        g_expect_w = -(data * np.tile(labels / (1 + np.exp(labels * xw)), (ndim, 1)).transpose()).mean(0)
        np.testing.assert_allclose(g[0], g_expect_w, atol=1e-7)
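The closed form used for g_expect_w is the derivative of the logistic loss, d/dw log(1 + exp(-y * x.w)) = -y * x / (1 + exp(y * x.w)), averaged over the batch. A standalone NumPy restatement using broadcasting instead of tile:

import numpy as np

data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
labels = np.array([1.0, -1.0, 1.0, -1.0])
w = np.ones(2)
xw = data @ w
g_w = -(data * (labels / (1 + np.exp(labels * xw)))[:, None]).mean(0)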
Example #7
    def test_transform(self):
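        """Test the repr of gradient and compilation transforms."""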
        def myloss(x):
            return (x ** 2).mean()

        g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
        gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
        gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(), noise_multiplier=1.,
                                                    l2_norm_clip=0.5, microbatch=1)
        self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
        self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
        self.assertEqual(repr(gvp), 'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
                                    ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
        self.assertEqual(repr(objax.Jit(gv)),
                         'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
        self.assertEqual(repr(objax.Jit(myloss, vc=objax.VarCollection())),
                         'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
        self.assertEqual(repr(objax.Parallel(gv)),
                         "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
                         " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
        self.assertEqual(repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
                         'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
        self.assertEqual(repr(objax.ForceArgs(gv, training=True, word='hello')),
                         "objax.ForceArgs(module=GradValues, training=True, word='hello')")