Example #1
def test_vectorize_module_one_arg_one_static_missing_batched(self):
    """Vectorize module with a batched argument and a static one, with an incomplete batch_axis."""
    f = objax.nn.Linear(3, 4)
    with self.assertRaises(AssertionError):
        _ = objax.Vectorize(lambda x, y: f(x) + y,
                            f.vars(),
                            batch_axis=(0,))
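Like every snippet on this page, this test method is excerpted from a larger test module. A minimal sketch of the surrounding context the snippets assume (the class name is an assumption; only the imports are dictated by the code itself):

import unittest

import jax.numpy as jn
import numpy as np
import objax


class TestVectorize(unittest.TestCase):
    ...  # the test methods shown on this page go here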
Example #2
def test_vectorize_module_one_arg(self):
    """Vectorize module with a single argument."""
    f = objax.nn.Linear(3, 4)
    fv = objax.Vectorize(f)
    x = objax.random.normal((96, 3))
    y = f(x)
    yv = fv(x)
    self.assertTrue(jn.array_equal(y, yv))
Example #3
def test_vectorize_module_one_arg_transposed_axis(self):
    """Vectorize module with a single argument and a transposed batch axis."""
    f = objax.nn.Linear(3, 4)
    fv = objax.Vectorize(f, batch_axis=(1,))
    x = objax.random.normal((96, 3))
    y = f(x)
    yv = fv(x.T)
    self.assertTrue(jn.array_equal(y, yv))
Example #4
def test_vectorize_module_one_arg_one_static_positional_syntax(self):
    """Vectorize module with a batched argument and a static one, using positional syntax."""
    f = objax.nn.Linear(3, 4)
    c = objax.random.normal([4])
    fv = objax.Vectorize(lambda *a: f(a[0]) + a[1],
                         f.vars(),
                         batch_axis=(0, None))
    x = objax.random.normal((96, 3))
    y = f(x) + c
    yv = fv(x, c)
    self.assertTrue(jn.array_equal(y, yv))
Example #5
def test_vectorize_module_one_arg_one_static(self):
    """Vectorize module with a batched argument and a static one."""
    f = objax.nn.Linear(3, 4)
    c = objax.random.normal([4])
    fv = objax.Vectorize(lambda x, y: f(x) + y,
                         f.vars(),
                         batch_axis=(0, None))
    x = objax.random.normal((96, 3))
    y = f(x) + c
    yv = fv(x, c)
    self.assertTrue(jn.array_equal(y, yv))
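For comparison, batch_axis plays the same role as in_axes in jax.vmap: one entry per argument, 0 to map over the leading axis and None to broadcast the argument unchanged to every batch element. A rough plain-JAX sketch of the computation above (w is a stand-in for the Linear weights, an illustrative assumption):

import jax
import jax.numpy as jn

w = jn.ones((3, 4))                      # stand-in for Linear(3, 4) weights
g = lambda x, y: x @ w + y               # per-example computation
gv = jax.vmap(g, in_axes=(0, None))      # 0: batched arg, None: broadcast arg
out = gv(jn.ones((96, 3)), jn.ones(4))   # out.shape == (96, 4)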
Example #6
def test_vectorize_module_one_arg_one_static_missing_batched_call(self):
    """Vectorize module with a batched argument and a static one, with an incomplete batch_axis.
    Catch the exception at call time for variadic functions."""
    f = objax.nn.Linear(3, 4)
    c = objax.random.normal([4])
    fv = objax.Vectorize(lambda *a: f(a[0]) + a[1],
                         f.vars(),
                         batch_axis=(0,))
    x = objax.random.normal((96, 3))
    with self.assertRaises(AssertionError):
        _ = fv(x, c)
Example #7
def test_trainvar_assign_multivalue(self):
    m = objax.ModuleList([objax.TrainVar(jn.array((1., 2.)))])

    def increase(x):
        m[0].assign(m[0].value + x)
        return x * 2

    x = np.arange(10)[:, None]
    vec_increase = objax.Vectorize(increase, m.vars())
    y = vec_increase(x)
    self.assertEqual(y.tolist(), (2 * np.arange(10))[:, None].tolist())
    self.assertEqual(m[0].value.tolist(), [5.5, 6.5])
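When the vectorized function assigns to a variable, Vectorize reduces the per-example values back into a single value; the expected result shows the reduction is the mean over the batch here: batch entry i writes (1 + i, 2 + i) for i = 0..9, and the mean of those is (1 + 4.5, 2 + 4.5) = (5.5, 6.5).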
Example #8
def test_trainvar_assign(self):
    m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])

    def increase(x):
        m[0].assign(m[0].value + 1)
        return x + 1

    x = np.arange(10)[:, None]
    vec_increase = objax.Vectorize(increase, m.vars())
    y = vec_increase(x)
    self.assertEqual(y.tolist(), np.arange(1, 11)[:, None].tolist())
    self.assertEqual(m[0].value.tolist(), [1., 1.])
Example #9
def test_vectorize_module_two_args_mixed_axis(self):
    """Vectorize module with two arguments with mixed batch axes."""
    f1 = objax.nn.Linear(3, 4)
    f2 = objax.nn.Linear(5, 3)
    f = lambda x, y: jn.concatenate([f1(x), f2(y)], axis=1)
    fv = objax.Vectorize(
        lambda x, y: jn.concatenate([f1(x), f2(y)], axis=0),
        f1.vars('f1') + f2.vars('f2'),
        batch_axis=(0, 1))
    x1 = objax.random.normal((96, 3))
    x2 = objax.random.normal((96, 5))
    y = f(x1, x2)
    yv = fv(x1, x2.T)
    self.assertTrue(jn.array_equal(y, yv))
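Inside the vectorized function each call sees one example with no batch dimension, so the per-example outputs are concatenated along axis 0, while the plain f concatenates the full batches along axis 1. batch_axis=(0, 1) declares that the second argument carries its batch dimension on axis 1, which is why the call passes x2.T.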
Example #10
def test_vectorize_random_function_reseed(self):
    class RandomReverse(objax.Module):
        def __init__(self):
            self.keygen = objax.random.Generator(1337)

        def __call__(self, x):
            r = objax.random.randint([], 0, 2, generator=self.keygen)
            return x + r * (x[::-1] - x), r

    random_reverse = RandomReverse()
    vector_reverse = objax.Vectorize(random_reverse)
    x = jn.arange(20).reshape((10, 2))
    random_reverse.keygen.seed(0)
    y, r = vector_reverse(x)
    self.assertEqual(r.tolist(), [1, 1, 1, 1, 1, 0, 0, 1, 0, 0])
    self.assertEqual(y.tolist(),
                     [[1, 0], [3, 2], [5, 4], [7, 6], [9, 8], [10, 11],
                      [12, 13], [15, 14], [16, 17], [18, 19]])
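The distinct per-row values of r show that Vectorize gives each batch entry its own random draw rather than sharing one generator state across the batch; reseeding keygen to 0 just makes the drawn sequence deterministic for the assertions.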
Example #11
def test_transform(self):
    def myloss(x):
        return (x ** 2).mean()

    g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(), noise_multiplier=1.,
                                                l2_norm_clip=0.5, microbatch=1)
    self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gvp), 'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
                                ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
    self.assertEqual(repr(objax.Jit(gv)),
                     'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
    self.assertEqual(repr(objax.Jit(myloss, vc=objax.VarCollection())),
                     'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
    self.assertEqual(repr(objax.Parallel(gv)),
                     "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
                     " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
    self.assertEqual(repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
                     'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
    self.assertEqual(repr(objax.ForceArgs(gv, training=True, word='hello')),
                     "objax.ForceArgs(module=GradValues, training=True, word='hello')")
Example #12
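This excerpt implements MAML-style meta-learning; it uses a network net, a loss function loss, and an optimizer opt that are defined outside the snippet. A minimal sketch of the kind of setup assumed (the layer sizes and the choice of Adam are illustrative assumptions, not part of the excerpt):

import objax

# Small regression network; sizes are placeholders for illustration.
net = objax.nn.Sequential([objax.nn.Linear(1, 40), objax.functional.relu,
                           objax.nn.Linear(40, 1)])
opt = objax.optimizer.Adam(net.vars())


def loss(x, y):
    return ((net(x) - y) ** 2).mean()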
gv = objax.GradValues(loss, net.vars())


def maml_loss(x1, y1, x2, y2, alpha=0.1):
    net_vars = net.vars()
    original_weights = net_vars.tensors()  # Save original weights
    g_x1y1 = gv(x1, y1)[0]  # Compute gradient at (x1, y1)
    # Apply gradient update using SGD
    net_vars.assign([v - alpha * g for v, g in zip(original_weights, g_x1y1)])
    loss_x2y2 = loss(x2, y2)
    net_vars.assign(original_weights)  # Restore original weights
    return loss_x2y2


vec_maml_loss = objax.Vectorize(maml_loss,
                                gv.vars(),
                                batch_axis=(0, 0, 0, 0, None))


def batch_maml_loss(x1, y1, x2, y2, alpha=0.1):
    return vec_maml_loss(x1, y1, x2, y2, alpha).mean()


maml_gv = objax.GradValues(batch_maml_loss, vec_maml_loss.vars())


def train_op(x1, y1, x2, y2):
    g, v = maml_gv(x1, y1, x2, y2)
    opt(0.001, g)
    return v