def test_tensors(self):
    """Check that tensors() deduplicates shared variables."""
    vshared = objax.TrainVar(jn.ones(1))
    vc = objax.VarCollection([('a', objax.TrainVar(jn.zeros(1))), ('b', vshared)])
    vc += objax.VarCollection([('c', vshared)])
    self.assertEqual(len(vc.tensors()), 2)
    self.assertEqual([x.sum() for x in vc.tensors()], [0, 1])

def test_gradvalues_constant(self):
    """Test if constants are preserved."""
    # Set data
    ndim = 1
    data = np.array([[1.0], [3.0], [5.0], [-10.0]])
    labels = np.array([1.0, 2.0, 3.0, 4.0])
    # Set model parameters for linear regression.
    w = objax.TrainVar(jn.zeros(ndim))
    b = objax.TrainVar(jn.ones(1))
    m = objax.ModuleList([w, b])

    def loss(x, y):
        pred = jn.dot(x, w.value) + b.value
        return 0.5 * ((y - pred) ** 2).mean()

    # The gradient should change after the value of b (the constant) changes.
    gv = objax.GradValues(loss, objax.VarCollection({'w': w}))
    g_old, v_old = gv(data, labels)
    b.assign(-b.value)
    g_new, v_new = gv(data, labels)
    self.assertNotEqual(g_old[0][0], g_new[0][0])
    # When compiled with Jit, the gradient should also change after the value of b changes.
    gv = objax.Jit(objax.GradValues(loss, objax.VarCollection({'w': w})), m.vars())
    g_old, v_old = gv(data, labels)
    b.assign(-b.value)
    g_new, v_new = gv(data, labels)
    self.assertNotEqual(g_old[0][0], g_new[0][0])

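# Why the assertions above hold (a sketch): GradValues only differentiates with
# respect to w, so b participates as a plain value. With pred = x.w + b, the
# gradient dL/dw = mean_i((pred_i - y_i) * x_i) depends on the current b, so
# flipping b's sign must change it. In the Jit case, m.vars() contains b, so b
# is presumably fed in on every call rather than baked into the compiled graph.
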
def test_gradvalues_linear_and_inputs(self):
    """Test if gradient of inputs and variables has the correct values for linear regression."""
    # Set data
    data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    labels = np.array([1.0, 2.0, 3.0, 4.0])
    # Set model parameters for linear regression.
    w = objax.TrainVar(jn.array([2, 3], jn.float32))
    b = objax.TrainVar(jn.array([1], jn.float32))

    def loss(x, y):
        pred = jn.dot(x, w.value) + b.value
        return 0.5 * ((y - pred) ** 2).mean()

    expect_loss = loss(data, labels)
    expect_gw = [37.25, 69.0]
    expect_gb = [13.75]
    expect_gx = [[4.0, 6.0], [8.5, 12.75], [13.0, 19.5], [2.0, 3.0]]
    expect_gy = [-2.0, -4.25, -6.5, -1.0]
    gv0 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0,))
    g, v = gv0(data, labels)
    self.assertEqual(v[0], expect_loss)
    self.assertEqual(g[0].tolist(), expect_gx)
    self.assertEqual(g[1].tolist(), expect_gw)
    self.assertEqual(g[2].tolist(), expect_gb)
    gv1 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(1,))
    g, v = gv1(data, labels)
    self.assertEqual(v[0], expect_loss)
    self.assertEqual(g[0].tolist(), expect_gy)
    self.assertEqual(g[1].tolist(), expect_gw)
    self.assertEqual(g[2].tolist(), expect_gb)
    gv01 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0, 1))
    g, v = gv01(data, labels)
    self.assertEqual(v[0], expect_loss)
    self.assertEqual(g[0].tolist(), expect_gx)
    self.assertEqual(g[1].tolist(), expect_gy)
    self.assertEqual(g[2].tolist(), expect_gw)
    self.assertEqual(g[3].tolist(), expect_gb)
    gv10 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(1, 0))
    g, v = gv10(data, labels)
    self.assertEqual(v[0], expect_loss)
    self.assertEqual(g[0].tolist(), expect_gy)
    self.assertEqual(g[1].tolist(), expect_gx)
    self.assertEqual(g[2].tolist(), expect_gw)
    self.assertEqual(g[3].tolist(), expect_gb)
    gv_novars = objax.GradValues(loss, None, input_argnums=(0, 1))
    g, v = gv_novars(data, labels)
    self.assertEqual(v[0], expect_loss)
    self.assertEqual(g[0].tolist(), expect_gx)
    self.assertEqual(g[1].tolist(), expect_gy)

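# A worked check of the expected values above: with w = [2, 3], b = 1 and N = 4,
# the residuals r_i = pred_i - y_i come out to [8, 17, 26, 4]. Then
#   dL/dw   = mean_i(r_i * x_i) = [37.25, 69.0]
#   dL/db   = mean_i(r_i)       = 13.75
#   dL/dx_i = r_i * w / N       (e.g. 8 * [2, 3] / 4 = [4.0, 6.0])
#   dL/dy_i = -r_i / N          (e.g. -8 / 4 = -2.0)
# which reproduces expect_gw, expect_gb, expect_gx and expect_gy.
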
def test_name_conflict(self):
    """Check name conflict raises a ValueError."""
    vc1 = objax.VarCollection([('a', objax.TrainVar(jn.zeros(1)))])
    vc2 = objax.VarCollection([('a', objax.TrainVar(jn.ones(1)))])
    with self.assertRaises(ValueError):
        vc1 + vc2
    with self.assertRaises(ValueError):
        vc1.update(vc2)
    with self.assertRaises(ValueError):
        vc1['a'] = objax.TrainVar(jn.ones(1))

def test_len_iter(self):
    """Verify length and iterator."""
    v1 = objax.TrainVar(jn.zeros(1))
    vshared = objax.TrainVar(jn.ones(1))
    vc1 = objax.VarCollection([('a', v1), ('b', vshared)])
    vc2 = objax.VarCollection([('c', vshared)])
    vc = vc1 + vc2
    self.assertEqual(len(vc), 3)
    self.assertEqual(len(vc.keys()), 3)
    self.assertEqual(len(vc.items()), 3)
    self.assertEqual(len(vc.values()), 3)
    self.assertEqual(len(list(vc)), 2)  # Iterating the collection yields unique variables.

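# Note on the asymmetry above: a VarCollection maps names to variables, so len,
# keys, items and values count entries (3 here, since vshared appears under both
# 'b' and 'c'), while iterating the collection itself deduplicates shared
# variables (2 here), presumably so that a consumer such as an optimizer updates
# each variable exactly once.
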
def test_opt(self):
    """Check optimizer representations."""
    self.assertEqual(repr(objax.optimizer.Adam(objax.VarCollection())),
                     'objax.optimizer.Adam(beta1=0.9, beta2=0.999, eps=1e-08)')
    self.assertEqual(repr(objax.optimizer.LARS(objax.VarCollection())),
                     'objax.optimizer.LARS(momentum=0.9, weight_decay=0.0001, tc=0.001, eps=1e-05)')
    self.assertEqual(repr(objax.optimizer.Momentum(objax.VarCollection())),
                     'objax.optimizer.Momentum(momentum=0.9, nesterov=False)')
    self.assertEqual(repr(objax.optimizer.SGD(objax.VarCollection())),
                     'objax.optimizer.SGD()')
    self.assertEqual(repr(objax.optimizer.ExponentialMovingAverage(objax.VarCollection())),
                     'objax.optimizer.ExponentialMovingAverage(momentum=0.999, debias=False, eps=1e-06)')
    self.assertEqual(repr(objax.optimizer.ExponentialMovingAverageModule(objax.Module())),
                     'objax.optimizer.ExponentialMovingAverageModule(momentum=0.999, debias=False, eps=1e-06)')

def test_init_list(self):
    """Initialize a VarCollection with a list."""
    vc = objax.VarCollection([('a', objax.TrainVar(jn.zeros(1))), ('b', objax.TrainVar(jn.ones(1)))])
    self.assertEqual(len(vc), 2)
    self.assertEqual(vc['a'].value.sum(), 0)
    self.assertEqual(vc['b'].value.sum(), 1)

def test_init_dict(self):
    """Initialize a VarCollection with a dict."""
    vc = objax.VarCollection({'a': objax.TrainVar(jn.zeros(1)), 'b': objax.TrainVar(jn.ones(1))})
    self.assertEqual(len(vc), 2)
    self.assertEqual(vc['a'].value.sum(), 0)
    self.assertEqual(vc['b'].value.sum(), 1)

def test_gradvalues_linear(self):
    """Test if gradient has the correct value for linear regression."""
    # Set data
    ndim = 2
    data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    labels = np.array([1.0, 2.0, 3.0, 4.0])
    # Set model parameters for linear regression.
    w = objax.TrainVar(jn.zeros(ndim))
    b = objax.TrainVar(jn.zeros(1))

    def loss(x, y):
        pred = jn.dot(x, w.value) + b.value
        return 0.5 * ((y - pred) ** 2).mean()

    gv = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}))
    g, v = gv(data, labels)
    self.assertEqual(g[0].shape, (ndim,))
    self.assertEqual(g[1].shape, (1,))
    g_expect_w = -(data * np.tile(labels, (ndim, 1)).transpose()).mean(0)
    g_expect_b = np.array([-labels.mean()])
    np.testing.assert_allclose(g[0], g_expect_w)
    np.testing.assert_allclose(g[1], g_expect_b)
    np.testing.assert_allclose(v[0], loss(data, labels))

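# Where g_expect_w and g_expect_b come from: since w and b start at zero,
# pred = 0 and the residual is -y, so
#   dL/dw = mean_i((pred_i - y_i) * x_i) = -mean_i(y_i * x_i)
#   dL/db = mean_i(pred_i - y_i)         = -mean_i(y_i)
# which is exactly what the numpy expressions above compute.
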
def test_trainvar_jit_assign(self):
    """Test that TrainVar assignments inside Grad and Jit take effect."""
    # Set data
    ndim = 2
    data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    labels = np.array([1.0, 2.0, 3.0, 4.0])
    # Set model parameters for linear regression.
    w = objax.TrainVar(jn.zeros(ndim))
    b = objax.TrainVar(jn.zeros(1))

    def loss(x, y):
        pred = jn.dot(x, w.value) + b.value
        b.assign(b.value + 1)
        w.assign(w.value - 1)
        return 0.5 * ((y - pred) ** 2).mean()

    grad = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}))

    def jloss(wb, x, y):
        w, b = wb
        pred = jn.dot(x, w) + b
        return 0.5 * ((y - pred) ** 2).mean()

    def jit_op(x, y):
        g = grad(x, y)
        b.assign(b.value * 2)
        w.assign(w.value * 3)
        return g

    jit_op = objax.Jit(jit_op, objax.VarCollection(dict(b=b, w=w)))
    jgrad = jax.grad(jloss)

    jg = jgrad([w.value, b.value], data, labels)
    g = jit_op(data, labels)
    self.assertEqual(g[0].shape, (ndim,))
    self.assertEqual(g[1].shape, (1,))
    np.testing.assert_allclose(g[0], jg[0])
    np.testing.assert_allclose(g[1], jg[1])
    self.assertEqual(w.value.tolist(), [-3., -3.])
    self.assertEqual(b.value.tolist(), [2.])
    jg = jgrad([w.value, b.value], data, labels)
    g = jit_op(data, labels)
    np.testing.assert_allclose(g[0], jg[0])
    np.testing.assert_allclose(g[1], jg[1])
    self.assertEqual(w.value.tolist(), [-12., -12.])
    self.assertEqual(b.value.tolist(), [6.])

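# Tracing the variable updates in the test above: each jit_op call first runs
# loss via grad (w -= 1, b += 1) and then rescales (w *= 3, b *= 2). Starting
# from w = 0, b = 0:
#   call 1: w = (0 - 1) * 3 = -3,    b = (0 + 1) * 2 = 2
#   call 2: w = (-3 - 1) * 3 = -12,  b = (2 + 1) * 2 = 6
# matching the assertions and showing the assignments survive compilation.
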
def test_grad_signature(self):
    """Check that Grad exposes the signature of the differentiated function."""
    def f(x: JaxArray, y) -> Tuple[JaxArray, Dict[str, JaxArray]]:
        return (x + y).mean(), {'x': x, 'y': y}

    def df(x: JaxArray, y) -> List[JaxArray]:
        pass  # Expected signature of the derivative of f.

    g = objax.Grad(f, objax.VarCollection())
    self.assertEqual(inspect.signature(g), inspect.signature(df))

def test_parallel_concat_broadcast(self):
    """Parallel inference with a broadcasted scalar input."""
    f = lambda x, y: x + y
    x = objax.random.normal((96, 3))
    d = jn.float32(0.5)
    y = f(x, d)
    fp = objax.Parallel(f, objax.VarCollection())
    z = fp(x, d)
    self.assertTrue(jn.array_equal(y, z))

def test_constant_optimization(self):
    """Check that variables not given to Jit are captured as constants."""
    m = objax.nn.Linear(3, 4)
    jit_constant = objax.Jit(m, objax.VarCollection())
    x = objax.random.normal((10, 3))
    self.assertEqual(((m(x) - jit_constant(x)) ** 2).sum(), 0)

    # Modify m (which was supposed to be constant!)
    m.b.assign(m.b.value + 1)
    self.assertEqual(((m(x) - jit_constant(x)) ** 2).sum(), 40)

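# Where the 40 comes from: jit_constant was compiled with an empty VarCollection,
# so m's weights were baked in at trace time. After m.b is shifted by +1, every
# entry of m(x) differs from jit_constant(x) by exactly 1, and the output has
# shape (10, 4), so the squared-error sum is 10 * 4 * 1**2 = 40.
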
def get_g(microbatch, l2_norm_clip, batch_axis=(0,)):
    gv_priv = objax.privacy.dpsgd.PrivateGradValues(loss, objax.VarCollection({'w': w}),
                                                    noise_multiplier, l2_norm_clip, microbatch,
                                                    batch_axis=batch_axis)
    g_priv, v_priv = gv_priv(data)
    return g_priv

def test_transform(self):
    """Check the representations of transforms."""
    def myloss(x):
        return (x ** 2).mean()

    g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(), noise_multiplier=1.,
                                                l2_norm_clip=0.5, microbatch=1)
    self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gvp),
                     'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
                     ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
    self.assertEqual(repr(objax.Jit(gv)),
                     'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
    self.assertEqual(repr(objax.Jit(myloss, vc=objax.VarCollection())),
                     'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
    self.assertEqual(repr(objax.Parallel(gv)),
                     "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
                     " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
    self.assertEqual(repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
                     'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
    self.assertEqual(repr(objax.ForceArgs(gv, training=True, word='hello')),
                     "objax.ForceArgs(module=GradValues, training=True, word='hello')")

def __init__(self, methodname):
    """Initialize the test class."""
    super().__init__(methodname)
    self.data = jn.array([1.0, 2.0, 3.0, 4.0])
    self.W = objax.TrainVar(jn.array([[1., 2., 3., 4.],
                                      [5., 6., 7., 8.],
                                      [9., 0., 1., 2.]]))
    self.b = objax.TrainVar(jn.array([-1., 0., 1.]))

    # f_lin(x) = W*x + b
    @objax.Function.with_vars(objax.VarCollection({'w': self.W, 'b': self.b}))
    def f_lin(x):
        return jn.dot(self.W.value, x) + self.b.value

    self.f_lin = f_lin

def test_grad_logistic(self):
    """Test if gradient has the correct value for logistic regression."""
    # Set data
    ndim = 2
    data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    labels = np.array([1.0, -1.0, 1.0, -1.0])
    # Set model parameters for logistic regression.
    w = objax.TrainVar(jn.ones(ndim))

    def loss(x, y):
        xyw = jn.dot(x * np.tile(y, (ndim, 1)).transpose(), w.value)
        return jn.log(jn.exp(-xyw) + 1).mean(0)

    grad = objax.Grad(loss, objax.VarCollection({'w': w}))
    g = grad(data, labels)
    self.assertEqual(g[0].shape, (ndim,))
    xw = np.dot(data, w.value)
    g_expect_w = -(data * np.tile(labels / (1 + np.exp(labels * xw)), (ndim, 1)).transpose()).mean(0)
    np.testing.assert_allclose(g[0], g_expect_w, atol=1e-7)

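# Derivation of g_expect_w: with the margin s_i = y_i * (x_i . w), the
# per-example loss is log(1 + exp(-s_i)) and
#   d/dw log(1 + exp(-s_i)) = -y_i * x_i * exp(-s_i) / (1 + exp(-s_i))
#                           = -y_i * x_i / (1 + exp(s_i)),
# so the gradient is -mean_i(x_i * y_i / (1 + exp(y_i * x_i . w))), which is
# what the numpy expression above computes.
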
def test_rename(self):
    """Test VarCollection renaming with pattern pairs, a callable and compiled regexps."""
    vc = objax.VarCollection({'baab': objax.TrainVar(jn.zeros(()) + 1),
                              'baaab': objax.TrainVar(jn.zeros(()) + 2),
                              'baaaab': objax.TrainVar(jn.zeros(()) + 3),
                              'abba': objax.TrainVar(jn.zeros(()) + 4),
                              'acca': objax.TrainVar(jn.zeros(()) + 5)})
    vcr = vc.rename(objax.util.Renamer({'aa': 'x', 'bb': 'y'}))
    self.assertEqual(vc['baab'], vcr['bxb'])
    self.assertEqual(vc['baaab'], vcr['bxab'])
    self.assertEqual(vc['baaaab'], vcr['bxxb'])
    self.assertEqual(vc['abba'], vcr['aya'])
    self.assertEqual(vc['acca'], vcr['acca'])

    def my_rename(x):
        return x.replace('aa', 'x').replace('bb', 'y')

    vcr = vc.rename(objax.util.Renamer(my_rename))
    self.assertEqual(vc['baab'], vcr['bxb'])
    self.assertEqual(vc['baaab'], vcr['bxab'])
    self.assertEqual(vc['baaaab'], vcr['bxxb'])
    self.assertEqual(vc['abba'], vcr['aya'])
    self.assertEqual(vc['acca'], vcr['acca'])

    vcr = vc.rename(objax.util.Renamer([(re.compile('a{2}'), 'x'), (re.compile('bb'), 'y')]))
    self.assertEqual(vc['baab'], vcr['bxb'])
    self.assertEqual(vc['baaab'], vcr['bxab'])
    self.assertEqual(vc['baaaab'], vcr['bxxb'])
    self.assertEqual(vc['abba'], vcr['aya'])
    self.assertEqual(vc['acca'], vcr['acca'])

    vcr = vc.rename(objax.util.Renamer([(re.compile('a{2}'), 'x'), (re.compile('xa'), 'y')]))
    self.assertEqual(vc['baab'], vcr['bxb'])
    self.assertEqual(vc['baaab'], vcr['byb'])
    self.assertEqual(vc['baaaab'], vcr['bxxb'])
    self.assertEqual(vc['abba'], vcr['abba'])
    self.assertEqual(vc['acca'], vcr['acca'])

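# Note on the last Renamer: rules apply in order, and later rules see the output
# of earlier ones. 'baaab' -> 'bxab' (the first non-overlapping 'aa' becomes 'x',
# leaving one 'a') -> 'byb' ('xa' becomes 'y'). By contrast 'baaaab' -> 'bxxb'
# contains no 'xa', and 'abba' contains no 'aa' at all, so both keep the forms
# asserted above.
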
def test_assign(self):
    """Add a variable to a VarCollection by assignment."""
    vc = objax.VarCollection({'a': objax.TrainVar(jn.zeros(1))})
    vc['b'] = objax.TrainVar(jn.ones(1))
    self.assertEqual(len(vc), 2)
    self.assertEqual(vc['a'].value.sum(), 0)
    self.assertEqual(vc['b'].value.sum(), 1)

def test_replicate_shape_assert(self):
    """Test that replicating variables does not trip shape assertions."""
    vc = objax.VarCollection({'var': objax.TrainVar(jn.zeros(5))})
    with vc.replicate():
        self.assertEqual(len(vc['var'].value.shape), 2)
        self.assertEqual(vc['var'].value.shape[-1], 5)

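# Note: inside the replicate() context each variable is copied across the local
# devices, presumably one replica per device, so the (5,)-shaped variable gains a
# leading device axis (ndim 2, trailing dimension still 5) and the original shape
# is restored on exit; this is the layout used for parallel execution.
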