def test_grad_constant(self):
    """Test if constants are preserved."""
    # Set data.
    ndim = 1
    data = np.array([[1.0], [3.0], [5.0], [-10.0]])
    labels = np.array([1.0, 2.0, 3.0, 4.0])
    # Set model parameters for linear regression.
    w = objax.TrainVar(jn.zeros(ndim))
    b = objax.TrainVar(jn.ones(1))
    m = objax.ModuleList([w, b])

    def loss(x, y):
        pred = jn.dot(x, w.value) + b.value
        return 0.5 * ((y - pred) ** 2).mean()

    # The gradient should change after the value of b (the constant) changes.
    grad = objax.Grad(loss, objax.VarCollection({'w': w}))
    g_old = grad(data, labels)
    b.assign(-b.value)
    g_new = grad(data, labels)
    self.assertNotEqual(g_old[0][0], g_new[0][0])

    # When compiled with Jit, the gradient should also change after the value of b
    # (the constant) changes.
    grad = objax.Jit(objax.Grad(loss, objax.VarCollection({'w': w})), m.vars())
    g_old = grad(data, labels)
    b.assign(-b.value)
    g_new = grad(data, labels)
    self.assertNotEqual(g_old[0][0], g_new[0][0])
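# Note on why this holds: only w is differentiated, so b is captured by Grad as a
# constant. The gradient dL/dw = -mean((y - x.w - b) * x) still depends on b, so
# flipping the sign of b must change the gradient of w. Under Jit, b is passed in
# through m.vars(), which lets the compiled function see the updated value instead
# of baking in a stale constant.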
def test_grad_linear_and_inputs(self):
    """Test if gradient of inputs and variables has the correct values for linear regression."""
    # Set data.
    data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    labels = np.array([1.0, 2.0, 3.0, 4.0])
    # Set model parameters for linear regression.
    w = objax.TrainVar(jn.array([2, 3], jn.float32))
    b = objax.TrainVar(jn.array([1], jn.float32))

    def loss(x, y):
        pred = jn.dot(x, w.value) + b.value
        return 0.5 * ((y - pred) ** 2).mean()

    expect_gw = [37.25, 69.0]
    expect_gb = [13.75]
    expect_gx = [[4.0, 6.0], [8.5, 12.75], [13.0, 19.5], [2.0, 3.0]]
    expect_gy = [-2.0, -4.25, -6.5, -1.0]

    grad0 = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0,))
    g = grad0(data, labels)
    self.assertEqual(g[0].tolist(), expect_gx)
    self.assertEqual(g[1].tolist(), expect_gw)
    self.assertEqual(g[2].tolist(), expect_gb)

    grad1 = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(1,))
    g = grad1(data, labels)
    self.assertEqual(g[0].tolist(), expect_gy)
    self.assertEqual(g[1].tolist(), expect_gw)
    self.assertEqual(g[2].tolist(), expect_gb)

    grad01 = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0, 1))
    g = grad01(data, labels)
    self.assertEqual(g[0].tolist(), expect_gx)
    self.assertEqual(g[1].tolist(), expect_gy)
    self.assertEqual(g[2].tolist(), expect_gw)
    self.assertEqual(g[3].tolist(), expect_gb)

    grad10 = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(1, 0))
    g = grad10(data, labels)
    self.assertEqual(g[0].tolist(), expect_gy)
    self.assertEqual(g[1].tolist(), expect_gx)
    self.assertEqual(g[2].tolist(), expect_gw)
    self.assertEqual(g[3].tolist(), expect_gb)

    grad10 = objax.Grad(loss, None, input_argnums=(0, 1))
    g = grad10(data, labels)
    self.assertEqual(g[0].tolist(), expect_gx)
    self.assertEqual(g[1].tolist(), expect_gy)
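# Derivation of the hard-coded expectations above (for reference): with w = [2, 3]
# and b = [1], pred_i = x_i . w + b, so the residuals r_i = pred_i - y_i over the
# four rows are [8, 17, 26, 4]. For L = (1/2n) * sum_i r_i^2 with n = 4:
#   dL/dw   = mean(r_i * x_i)  = [37.25, 69.0]                  (expect_gw)
#   dL/db   = mean(r_i)        = [13.75]                        (expect_gb)
#   dL/dx_i = (r_i / n) * w, e.g. (8 / 4) * [2, 3] = [4.0, 6.0] (expect_gx)
#   dL/dy_i = -r_i / n         = [-2.0, -4.25, -6.5, -1.0]      (expect_gy)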
def test_grad_linear(self):
    """Test if gradient has the correct value for linear regression."""
    # Set data.
    ndim = 2
    data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    labels = np.array([1.0, 2.0, 3.0, 4.0])
    # Set model parameters for linear regression.
    w = objax.TrainVar(jn.zeros(ndim))
    b = objax.TrainVar(jn.zeros(1))

    def loss(x, y):
        pred = jn.dot(x, w.value) + b.value
        return 0.5 * ((y - pred) ** 2).mean()

    grad = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}))
    g = grad(data, labels)
    self.assertEqual(g[0].shape, (ndim,))
    self.assertEqual(g[1].shape, (1,))
    g_expect_w = -(data * np.tile(labels, (ndim, 1)).transpose()).mean(0)
    g_expect_b = np.array([-labels.mean()])
    np.testing.assert_allclose(g[0], g_expect_w)
    np.testing.assert_allclose(g[1], g_expect_b)
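# For reference: with w and b initialized to zero, pred = 0, so the residual is
# simply -y and the gradients reduce to dL/dw = -mean(y_i * x_i) over the batch
# and dL/db = [-mean(y_i)], which is exactly what g_expect_w and g_expect_b
# compute via np.tile.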
def test_grad_signature(self):
    """Test that Grad preserves the signature of the differentiated function."""

    def f(x: JaxArray, y) -> Tuple[JaxArray, Dict[str, JaxArray]]:
        return (x + y).mean(), {'x': x, 'y': y}

    def df(x: JaxArray, y) -> List[JaxArray]:
        pass  # Signature of the differential of f.

    g = objax.Grad(f, objax.VarCollection())
    self.assertEqual(inspect.signature(g), inspect.signature(df))
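# The stub df documents what the test checks: the callable returned by Grad takes
# the same inputs (x, y) as f, but returns a list of gradients, hence the
# List[JaxArray] return annotation in the compared signature.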
def test_trainvar_jit_assign(self):
    """Test TrainVar assignment inside Grad and Jit."""
    # Set data.
    ndim = 2
    data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    labels = np.array([1.0, 2.0, 3.0, 4.0])
    # Set model parameters for linear regression.
    w = objax.TrainVar(jn.zeros(ndim))
    b = objax.TrainVar(jn.zeros(1))

    def loss(x, y):
        pred = jn.dot(x, w.value) + b.value
        b.assign(b.value + 1)
        w.assign(w.value - 1)
        return 0.5 * ((y - pred) ** 2).mean()

    grad = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}))

    def jloss(wb, x, y):
        w, b = wb
        pred = jn.dot(x, w) + b
        return 0.5 * ((y - pred) ** 2).mean()

    def jit_op(x, y):
        g = grad(x, y)
        b.assign(b.value * 2)
        w.assign(w.value * 3)
        return g

    jit_op = objax.Jit(jit_op, objax.VarCollection(dict(b=b, w=w)))
    jgrad = jax.grad(jloss)

    jg = jgrad([w.value, b.value], data, labels)
    g = jit_op(data, labels)
    self.assertEqual(g[0].shape, (ndim,))
    self.assertEqual(g[1].shape, (1,))
    np.testing.assert_allclose(g[0], jg[0])
    np.testing.assert_allclose(g[1], jg[1])
    self.assertEqual(w.value.tolist(), [-3., -3.])
    self.assertEqual(b.value.tolist(), [2.])

    jg = jgrad([w.value, b.value], data, labels)
    g = jit_op(data, labels)
    np.testing.assert_allclose(g[0], jg[0])
    np.testing.assert_allclose(g[1], jg[1])
    self.assertEqual(w.value.tolist(), [-12., -12.])
    self.assertEqual(b.value.tolist(), [6.])
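# State trace for the assertions above: each jit_op call first runs grad, whose
# loss body applies b += 1 and w -= 1 once, then jit_op applies b *= 2 and w *= 3.
# Starting from w = [0, 0], b = [0]:
#   call 1: w -> [-1, -1] -> [-3, -3];    b -> [1] -> [2]
#   call 2: w -> [-4, -4] -> [-12, -12];  b -> [3] -> [6]
# Since pred uses the entry values of w and b, the returned gradient matches
# jax.grad(jloss) evaluated at the same pre-call values.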
def test_grad_logistic(self):
    """Test if gradient has the correct value for logistic regression."""
    # Set data.
    ndim = 2
    data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    labels = np.array([1.0, -1.0, 1.0, -1.0])
    # Set model parameters for logistic regression.
    w = objax.TrainVar(jn.ones(ndim))

    def loss(x, y):
        xyw = jn.dot(x * np.tile(y, (ndim, 1)).transpose(), w.value)
        return jn.log(jn.exp(-xyw) + 1).mean(0)

    grad = objax.Grad(loss, objax.VarCollection({'w': w}))
    g = grad(data, labels)
    self.assertEqual(g[0].shape, (ndim,))
    xw = np.dot(data, w.value)
    g_expect_w = -(data * np.tile(labels / (1 + np.exp(labels * xw)), (ndim, 1)).transpose()).mean(0)
    np.testing.assert_allclose(g[0], g_expect_w, atol=1e-7)
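# For reference: the loss is L(w) = mean_i log(1 + exp(-y_i * x_i . w)), and since
# d/dz log(1 + exp(-z)) = -1 / (1 + exp(z)), the gradient is
#   dL/dw = -mean_i(x_i * y_i / (1 + exp(y_i * x_i . w))),
# which is what g_expect_w computes from xw = data . w.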
def test_transform(self):
    """Test the repr of gradient and compilation transforms."""

    def myloss(x):
        return (x ** 2).mean()

    g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(), noise_multiplier=1.,
                                                l2_norm_clip=0.5, microbatch=1)
    self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gvp),
                     'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
                     ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
    self.assertEqual(repr(objax.Jit(gv)),
                     'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
    self.assertEqual(repr(objax.Jit(myloss, vc=objax.VarCollection())),
                     'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
    self.assertEqual(repr(objax.Parallel(gv)),
                     "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
                     " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
    self.assertEqual(repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
                     'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
    self.assertEqual(repr(objax.ForceArgs(gv, training=True, word='hello')),
                     "objax.ForceArgs(module=GradValues, training=True, word='hello')")