def test_vectorize_module_one_arg_one_static_missing_batched(self):
    """Constructing Vectorize with an incomplete batch_axis must fail.

    The wrapped function declares two positional arguments while
    ``batch_axis`` only describes one of them; the test expects the
    constructor itself to raise ``AssertionError``.
    """
    linear = objax.nn.Linear(3, 4)

    def shifted_linear(x, y):
        return linear(x) + y

    with self.assertRaises(AssertionError):
        objax.Vectorize(shifted_linear, linear.vars(), batch_axis=(0,))
def test_vectorize_module_one_arg(self):
    """Check a vectorized single-argument module against the direct call."""
    linear = objax.nn.Linear(3, 4)
    vectorized = objax.Vectorize(linear)
    inputs = objax.random.normal((96, 3))

    expected = linear(inputs)
    actual = vectorized(inputs)

    # Exact equality is required, not just closeness.
    outputs_match = jn.array_equal(expected, actual)
    self.assertTrue(outputs_match)
def test_vectorize_module_one_arg_transposed_axis(self):
    """Check a single-argument module vectorized over axis 1 of its input."""
    linear = objax.nn.Linear(3, 4)
    vectorized = objax.Vectorize(linear, batch_axis=(1,))
    inputs = objax.random.normal((96, 3))

    expected = linear(inputs)
    # The vectorized version is configured with batch_axis=(1,), so it is
    # fed the transposed input.
    actual = vectorized(inputs.T)

    outputs_match = jn.array_equal(expected, actual)
    self.assertTrue(outputs_match)
def test_vectorize_module_one_arg_one_static_positional_syntax(self):
    """Vectorize a variadic function: one batched input, one static input.

    The wrapped function receives its inputs through ``*args``;
    ``batch_axis=(0, None)`` marks the first as batched and the second as
    not batched.
    """
    linear = objax.nn.Linear(3, 4)
    offset = objax.random.normal([4])

    def shifted_linear(*args):
        return linear(args[0]) + args[1]

    vectorized = objax.Vectorize(shifted_linear, linear.vars(), batch_axis=(0, None))
    inputs = objax.random.normal((96, 3))

    expected = linear(inputs) + offset
    actual = vectorized(inputs, offset)

    outputs_match = jn.array_equal(expected, actual)
    self.assertTrue(outputs_match)
def test_vectorize_module_one_arg_one_static(self):
    """Vectorize a two-argument function: one batched input, one static input.

    ``batch_axis=(0, None)`` marks the first argument as batched along
    axis 0 and the second as not batched.
    """
    linear = objax.nn.Linear(3, 4)
    offset = objax.random.normal([4])

    def shifted_linear(x, y):
        return linear(x) + y

    vectorized = objax.Vectorize(shifted_linear, linear.vars(), batch_axis=(0, None))
    inputs = objax.random.normal((96, 3))

    expected = linear(inputs) + offset
    actual = vectorized(inputs, offset)

    outputs_match = jn.array_equal(expected, actual)
    self.assertTrue(outputs_match)
def test_vectorize_module_one_arg_one_static_missing_batched_call(self):
    """An incomplete batch_axis on a variadic function must fail at call time.

    The wrapped function takes ``*args`` and ``batch_axis`` has a single
    entry. Construction happens outside the ``assertRaises`` block; the
    ``AssertionError`` is expected only when two inputs are passed.
    """
    linear = objax.nn.Linear(3, 4)
    offset = objax.random.normal([4])

    def shifted_linear(*args):
        return linear(args[0]) + args[1]

    vectorized = objax.Vectorize(shifted_linear, linear.vars(), batch_axis=(0,))
    inputs = objax.random.normal((96, 3))

    with self.assertRaises(AssertionError):
        vectorized(inputs, offset)
def test_trainvar_assign_multivalue(self):
    """Assign a per-sample value to a TrainVar inside a vectorized function.

    Each sample adds its own input to the variable, which starts at
    ``[1., 2.]``. The test expects the variable to end at ``[5.5, 6.5]``
    (presumably the per-sample results averaged over the batch) and the
    outputs to be the doubled inputs.
    """
    m = objax.ModuleList([objax.TrainVar(jn.array((1., 2.)))])

    def increase(x):
        updated = m[0].value + x
        m[0].assign(updated)
        return x * 2

    inputs = np.arange(10)[:, None]
    vec_increase = objax.Vectorize(increase, m.vars())
    outputs = vec_increase(inputs)

    expected_outputs = (2 * np.arange(10))[:, None].tolist()
    self.assertEqual(outputs.tolist(), expected_outputs)
    self.assertEqual(m[0].value.tolist(), [5.5, 6.5])
def test_trainvar_assign(self):
    """Assign the same value to a TrainVar from every vectorized sample.

    Each sample increments the zero-initialized variable by one. The test
    expects the variable to end at ``[1., 1.]`` and the outputs to be the
    inputs plus one.
    """
    m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])

    def increase(x):
        bumped = m[0].value + 1
        m[0].assign(bumped)
        return x + 1

    inputs = np.arange(10)[:, None]
    vec_increase = objax.Vectorize(increase, m.vars())
    outputs = vec_increase(inputs)

    expected_outputs = np.arange(1, 11)[:, None].tolist()
    self.assertEqual(outputs.tolist(), expected_outputs)
    self.assertEqual(m[0].value.tolist(), [1., 1.])
def test_vectorize_module_two_args_mixed_axis(self):
    """Vectorize a two-argument function whose inputs use different batch axes.

    The reference concatenates the two layer outputs along axis 1 on the
    full batch. The vectorized function concatenates along axis 0 and is
    built with ``batch_axis=(0, 1)``, so its second input is passed
    transposed.
    """
    f1 = objax.nn.Linear(3, 4)
    f2 = objax.nn.Linear(5, 3)

    def batched_reference(x, y):
        return jn.concatenate([f1(x), f2(y)], axis=1)

    def per_sample(x, y):
        return jn.concatenate([f1(x), f2(y)], axis=0)

    vectorized = objax.Vectorize(per_sample,
                                 f1.vars('f1') + f2.vars('f2'),
                                 batch_axis=(0, 1))
    x1 = objax.random.normal((96, 3))
    x2 = objax.random.normal((96, 5))

    expected = batched_reference(x1, x2)
    actual = vectorized(x1, x2.T)

    outputs_match = jn.array_equal(expected, actual)
    self.assertTrue(outputs_match)
def test_vectorize_random_function_reseed(self):
    """Vectorize a module that draws random numbers, after reseeding it.

    The module draws an integer from ``randint([], 0, 2)`` and uses it to
    choose between the input and its reversal. The generator is reseeded
    to 0 after the ``Vectorize`` wrapper is built, and the test pins the
    exact draws and outputs for the batch.
    """

    class RandomReverse(objax.Module):
        def __init__(self):
            self.keygen = objax.random.Generator(1337)

        def __call__(self, x):
            flip = objax.random.randint([], 0, 2, generator=self.keygen)
            # flip == 1 yields x[::-1]; flip == 0 yields x unchanged.
            return x + flip * (x[::-1] - x), flip

    random_reverse = RandomReverse()
    vector_reverse = objax.Vectorize(random_reverse)
    inputs = jn.arange(20).reshape((10, 2))

    random_reverse.keygen.seed(0)
    outputs, flips = vector_reverse(inputs)

    expected_flips = [1, 1, 1, 1, 1, 0, 0, 1, 0, 0]
    expected_outputs = [[1, 0], [3, 2], [5, 4], [7, 6], [9, 8],
                        [10, 11], [12, 13], [15, 14], [16, 17], [18, 19]]
    self.assertEqual(flips.tolist(), expected_flips)
    self.assertEqual(outputs.tolist(), expected_outputs)
def test_transform(self):
    """Pin the ``repr`` strings of the objax transformation wrappers.

    Covers Grad, GradValues, PrivateGradValues, Jit (around a module and
    around a plain function), Parallel, Vectorize and ForceArgs.
    """

    # The function name is part of every expected repr string below.
    def myloss(x):
        return (x ** 2).mean()

    grad = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    grad_values = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    private_grad_values = objax.privacy.dpsgd.PrivateGradValues(
        myloss, objax.VarCollection(),
        noise_multiplier=1., l2_norm_clip=0.5, microbatch=1)

    self.assertEqual(repr(grad), 'objax.Grad(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(grad_values), 'objax.GradValues(f=myloss, input_argnums=(0,))')
    self.assertEqual(
        repr(private_grad_values),
        'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
        ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')

    jit_of_module = objax.Jit(grad_values)
    self.assertEqual(
        repr(jit_of_module),
        'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')

    jit_of_function = objax.Jit(myloss, vc=objax.VarCollection())
    self.assertEqual(
        repr(jit_of_function),
        'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')

    parallel = objax.Parallel(grad_values)
    self.assertEqual(
        repr(parallel),
        "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
        " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")

    vectorized = objax.Vectorize(myloss, vc=objax.VarCollection())
    self.assertEqual(
        repr(vectorized),
        'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')

    forced = objax.ForceArgs(grad_values, training=True, word='hello')
    self.assertEqual(
        repr(forced),
        "objax.ForceArgs(module=GradValues, training=True, word='hello')")
# Gradient-and-value transform of the task loss with respect to the
# network variables. `loss`, `net` and `opt` are defined outside this snippet.
gv = objax.GradValues(loss, net.vars())


def maml_loss(x1, y1, x2, y2, alpha=0.1):
    """Return the loss on (x2, y2) after one SGD step taken on (x1, y1).

    The network variables are temporarily overwritten with the stepped
    weights and put back to their saved values before returning.

    Args:
        x1, y1: inputs and targets used to compute the gradient step.
        x2, y2: inputs and targets on which the returned loss is evaluated.
        alpha: step size applied to the gradient.

    Returns:
        The value of ``loss(x2, y2)`` computed with the stepped weights.
    """
    variables = net.vars()
    saved_weights = variables.tensors()
    # The first element of the GradValues result holds the gradients.
    gradients = gv(x1, y1)[0]
    stepped_weights = [w - alpha * dw for w, dw in zip(saved_weights, gradients)]
    variables.assign(stepped_weights)
    outer_loss = loss(x2, y2)
    # Put the saved weights back so the temporary step does not persist.
    variables.assign(saved_weights)
    return outer_loss


# The four data arguments are batched along axis 0; alpha is not batched.
vec_maml_loss = objax.Vectorize(maml_loss, gv.vars(), batch_axis=(0, 0, 0, 0, None))


def batch_maml_loss(x1, y1, x2, y2, alpha=0.1):
    """Return the mean of the vectorized MAML loss."""
    per_task_loss = vec_maml_loss(x1, y1, x2, y2, alpha)
    return per_task_loss.mean()


maml_gv = objax.GradValues(batch_maml_loss, vec_maml_loss.vars())


def train_op(x1, y1, x2, y2):
    """Run one optimizer update on the batched MAML loss.

    Calls ``opt`` with the fixed first argument 0.001 and the gradients,
    then returns the second element of the ``maml_gv`` result.
    """
    gradients, values = maml_gv(x1, y1, x2, y2)
    opt(0.001, gradients)
    return values