Example #1
    def test_parallel_bntrain_concat(self):
        """Parallel inference (concat reduction) with batch norm in train mode."""
        f = objax.nn.Sequential([objax.nn.Conv2D(3, 4, k=1), objax.nn.BatchNorm2D(4), objax.nn.Conv2D(4, 2, k=1)])
        x = objax.random.normal((64, 3, 16, 16))
        states, values = [], []
        tensors = f.vars().tensors()
        for it in range(2):
            values.append(f(x, training=True))
            states += [f[1].running_var.value, f[1].running_mean.value]
        f.vars().assign(tensors)

        self.assertAlmostEqual(((values[0] - values[1]) ** 2).sum(), 0, places=8)
        self.assertGreater(((states[0] - states[2]) ** 2).sum(), 0)
        self.assertGreater(((states[1] - states[3]) ** 2).sum(), 0)

        fp8 = objax.Parallel(lambda x: f(x, training=True), vc=f.vars())
        x8 = jn.broadcast_to(x, (8, 64, 3, 16, 16)).reshape((-1, 3, 16, 16))
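        # Broadcasting x to 8 identical copies means every device (the test assumes 8 devices)
        # sees the same batch, so per-device batch norm statistics match the single-device run above.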
        tensors = fp8.vars().tensors()
        for it in range(2):
            with fp8.vars().replicate():
                z = fp8(x8).reshape((-1,) + values[it].shape)
            self.assertAlmostEqual(((f[1].running_var.value - states[2 * it]) ** 2).sum(), 0, delta=1e-12)
            self.assertAlmostEqual(((f[1].running_mean.value - states[2 * it + 1]) ** 2).sum(), 0, delta=1e-12)
            self.assertLess(((z - values[it][None]) ** 2).sum(), 5e-7)
        fp8.vars().assign(tensors)
Example #2
    def test_parallel_ema_shared(self):
        """Parallel EMA with weight sharing."""
        f = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.nn.BatchNorm0D(4),
                                 objax.nn.Linear(4, 2), objax.nn.Dropout(1)])
        ema = objax.optimizer.ExponentialMovingAverage(f.vars().subset(objax.TrainVar))
        ema_f = ema.replace_vars(f)
        ema()
        all_vars = f.vars() + ema.vars()

        ema_fp = objax.Parallel(lambda x: ema_f(x, training=False), all_vars)
        ema_fps = objax.Parallel(lambda x: ema_f(x, training=False), all_vars + f.vars('shared'))
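        # all_vars + f.vars('shared') lists f's variables a second time under a 'shared' name
        # prefix (same underlying objects), i.e. weight sharing inside the variable collection.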
        x = objax.random.normal((96, 3))
        with all_vars.replicate():
            z = ema_fp(x)
            zs = ema_fps(x)
        self.assertLess(jn.abs(z - zs).mean(), 1e-6)
Example #3
    def test_parallel_syncbntrain_concat(self):
        """Parallel inference (concat reduction) with synced batch norm in train mode."""
        f = objax.nn.Sequential([
            objax.nn.Conv2D(3, 4, k=1),
            objax.nn.BatchNorm2D(4),
            objax.nn.Conv2D(4, 2, k=1)
        ])
        fs = objax.nn.Sequential(f[:])
        fs[1] = objax.nn.SyncedBatchNorm2D(4)
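        # SyncedBatchNorm2D aggregates batch statistics across devices, so splitting the batch
        # over several devices should reproduce the single-device (full-batch) statistics.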
        x = objax.random.normal((96, 3, 16, 16))
        states, values = [], []
        tensors = f.vars().tensors()
        for it in range(2):
            values.append(f(x, training=True))
            states += [f[1].running_var.value, f[1].running_mean.value]
        f.vars().assign(tensors)

        self.assertEqual(((values[0] - values[1])**2).sum(), 0)
        self.assertGreater(((states[0] - states[2])**2).sum(), 0)
        self.assertGreater(((states[1] - states[3])**2).sum(), 0)

        fp = objax.Parallel(lambda x: fs(x, training=True), vc=fs.vars())
        for it in range(2):
            with fp.vars().replicate():
                z = fp(x)
            self.assertAlmostEqual(
                ((fs[1].running_var.value - states[2 * it])**2).sum(),
                0,
                delta=1e-12)
            self.assertAlmostEqual(
                ((fs[1].running_mean.value - states[2 * it + 1])**2).sum(),
                0,
                delta=1e-12)
            self.assertLess(((z - values[it])**2).sum(), 1e-7)
Example #4
    def test_parallel_ema(self):
        """Parallel EMA."""
        f = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.nn.BatchNorm0D(4),
                                 objax.nn.Linear(4, 2), objax.nn.Dropout(1)])
        ema = objax.optimizer.ExponentialMovingAverage(f.vars().subset(objax.TrainVar))
        ema_f = ema.replace_vars(f)
        ema()
        all_vars = f.vars() + ema.vars()

        fp = objax.Parallel(lambda x: f(x, training=False), f.vars())
        ema_fp = objax.Parallel(lambda x: ema_f(x, training=False), all_vars)
        x = objax.random.normal((96, 3))
        with all_vars.replicate():
            y = fp(x)
            z = ema_fp(x)
        self.assertGreater(jn.abs(y - z).mean(), 1e-3)
Example #5
    def test_parallel_concat(self):
        """Parallel inference (concat reduction)."""
        f = objax.nn.Linear(3, 4)
        x = objax.random.normal((96, 3))
        y = f(x)
        fp = objax.Parallel(f)
        with fp.vars().replicate():
            z = fp(x)
        self.assertTrue(jn.array_equal(y, z))
Example #6
    def test_parallel_concat_broadcast(self):
        """Parallel inference with broadcasted scalar input."""
        f = lambda x, y: x + y
        x = objax.random.normal((96, 3))
        d = jn.float32(0.5)
        y = f(x, d)
        fp = objax.Parallel(f, objax.VarCollection())
        z = fp(x, d)
        self.assertTrue(jn.array_equal(y, z))
Example #7
    def __init__(self, nclass: int, model: Callable, **kwargs):
        super().__init__(nclass, **kwargs)
        self.model = model(nin=3, nclass=nclass, **kwargs)
        model_vars = self.model.vars()
        self.ema = objax.optimizer.ExponentialMovingAverage(model_vars.subset(objax.TrainVar),
                                                            momentum=0.999, debias=False)
        self.opt = objax.optimizer.Momentum(model_vars, momentum=0.9, nesterov=True)

        def loss_function(x, u, y):
            xu = jn.concatenate([x, u.reshape((-1,) + u.shape[2:])], axis=0)
            logit = self.model(xu, training=True)
            logit_x = logit[:x.shape[0]]
            logit_weak = logit[x.shape[0]::2]
            logit_strong = logit[x.shape[0] + 1::2]
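            # u stacks a (weak, strong) augmentation pair per unlabeled image, so after the
            # reshape the unlabeled logits alternate weak/strong rows (the even/odd slices above).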

            xe = objax.functional.loss.cross_entropy_logits(logit_x, y).mean()

            pseudo_labels = objax.functional.stop_gradient(objax.functional.softmax(logit_weak))
            pseudo_mask = (pseudo_labels.max(axis=1) >= self.params.confidence).astype(logit_weak.dtype)
            xeu = objax.functional.loss.cross_entropy_logits_sparse(logit_strong, pseudo_labels.argmax(axis=1))
            xeu = (xeu * pseudo_mask).mean()

            wd = 0.5 * sum(objax.functional.loss.l2(v.value) for k, v in model_vars.items() if k.endswith('.w'))

            loss = xe + self.params.wu * xeu + self.params.wd * wd
            return loss, {'losses/xe': xe,
                          'losses/xeu': xeu,
                          'losses/wd': wd,
                          'monitors/mask': pseudo_mask.mean()}

        gv = objax.GradValues(loss_function, model_vars)

        def train_op(step, x, u, y):
            g, v = gv(x, u, y)
            fstep = step[0] / (FLAGS.train_kimg << 10)
            lr = self.params.lr * jn.cos(fstep * (7 * jn.pi) / (2 * 8))
            self.opt(lr, objax.functional.parallel.pmean(g))
            self.ema()
            return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

        eval_op = self.ema.replace_vars(lambda x: objax.functional.softmax(self.model(x, training=False)))
        self.eval_op = objax.Parallel(eval_op, model_vars + self.ema.vars())
        self.train_op = objax.Parallel(train_op, self.vars(), reduce=lambda x: x[0])
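
A minimal usage sketch for the two parallel ops above (not part of the original code). It assumes a `trainer` instance of this class, a loop bound `total_steps`, and placeholder batches `labeled_x`, `labeled_y`, `unlabeled_u`, `test_x` whose leading dimension is divisible by the device count; it also assumes the step is passed as one value per device so that `step[0]` is defined inside `train_op`:

    import jax
    import jax.numpy as jn

    with trainer.vars().replicate():  # copy every variable onto all devices
        for step in range(total_steps):
            # one step value per device (assumption about how the op is driven)
            step_array = jn.full((jax.device_count(),), step, dtype=jn.float32)
            monitors = trainer.train_op(step_array, labeled_x, unlabeled_u, labeled_y)
        probs = trainer.eval_op(test_x)  # EMA weights, softmax probabilities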
Example #8
    def test_parallel_concat_multi_output(self):
        """Parallel inference (concat reduction) for multiple outputs."""
        f = objax.nn.Linear(3, 4)
        x = objax.random.normal((96, 3))
        y = f(x)
        fp = objax.Parallel(lambda x: [f(x), f(-x)], vc=f.vars())
        with fp.vars().replicate():
            z1, z2 = fp(x)
        self.assertTrue(jn.array_equal(z1, y))
        self.assertTrue(jn.array_equal(z2, -y))
Example #9
    def test_trainvar_assign(self):
        m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])

        def increase():
            m[0].assign(m[0].value + 1)
            return m[0].value

        para_increase = objax.Parallel(increase, m.vars())
        with m.vars().replicate():
            para_increase()
        self.assertEqual(m[0].value.tolist(), [1., 1.])
Example #10
    def test_trainvar_assign_multivalue(self):
        m = objax.ModuleList([objax.TrainVar(jn.array((1., 2.)))])

        def increase(x):
            m[0].assign(m[0].value + x)
            return m[0].value

        para_increase = objax.Parallel(increase, m.vars())
        with m.vars().replicate():
            para_increase(jn.arange(8))
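        # Each of the 8 devices adds its slice of arange(8) to (1., 2.); on exiting
        # replicate(), TrainVar values are averaged across devices: 1 + mean(0..7) = 4.5.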
        self.assertEqual(m[0].value.tolist(), [4.5, 5.5])
Example #11
    def test_parallel_bneval_concat(self):
        """Parallel inference (concat reduction) with batch norm in eval mode."""
        f = objax.nn.Sequential([objax.nn.Conv2D(3, 4, k=1), objax.nn.BatchNorm2D(4), objax.nn.Conv2D(4, 2, k=1)])
        x = objax.random.normal((96, 3, 16, 16))
        y = f(x, training=False)
        fp = objax.Parallel(lambda x: f(x, training=False), f.vars())
        with fp.vars().replicate():
            z = fp(x)
        self.assertTrue(jn.array_equal(z, y))
        self.assertEqual(((f[1].running_var.value - 1) ** 2).sum(), 0)
        self.assertEqual((f[1].running_mean.value ** 2).sum(), 0)
Example #12
    def test_parallel_list(self):
        """Parallel inference (concat reduction) on a list of inputs."""
        f = objax.nn.Linear(3, 4)
        g = lambda x: f(x[0]) + x[1][:, jn.newaxis]
        x1 = objax.random.normal((96, 3))
        x2 = objax.random.normal((96,))
        y = g([x1, x2])
        fp = objax.Parallel(g, vc=f.vars())
        with fp.vars().replicate():
            z = fp([x1, x2])
        self.assertTrue(jn.array_equal(y, z))
Example #13
    def __init__(self, model: Callable, nclass: int, **kwargs):
        super().__init__(nclass, **kwargs)
        self.model = model(3, nclass)
        model_vars = self.model.vars()
        self.opt = objax.optimizer.Momentum(model_vars)
        self.ema = objax.optimizer.ExponentialMovingAverage(model_vars,
                                                            momentum=0.999,
                                                            debias=True)
        print(model_vars)

        def loss(x, label):
            logit = self.model(x, training=True)
            loss_wd = 0.5 * sum(
                (v.value**2).sum()
                for k, v in model_vars.items() if k.endswith('.w'))
            loss_xe = objax.functional.loss.cross_entropy_logits(logit,
                                                                 label).mean()
            return loss_xe + loss_wd * self.params.weight_decay, {
                'losses/xe': loss_xe,
                'losses/wd': loss_wd
            }

        gv = objax.GradValues(loss, model_vars)

        def train_op(progress, x, y):
            g, v = gv(x, y)
            lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
            self.opt(lr, objax.functional.parallel.pmean(g))
            self.ema()
            return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

        def predict_op(x):
            return objax.functional.softmax(self.model(x, training=False))

        self.predict = objax.Parallel(self.ema.replace_vars(predict_op),
                                      model_vars + self.ema.vars())
        self.train_op = objax.Parallel(train_op,
                                       self.vars(),
                                       reduce=lambda x: x[0])
Example #14
    def test_trainvar_and_ref_assign(self):
        m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])
        m.append(objax.TrainRef(m[0]))
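        # m[1] is a TrainRef to m[0], so both assigns in increase() update the same
        # underlying TrainVar; each call therefore adds 2 in total.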

        def increase():
            m[0].assign(m[0].value + 1)
            m[1].assign(m[1].value + 1)
            return m[0].value

        para_increase = objax.Parallel(increase, m.vars())
        with m.vars().replicate():
            para_increase()
        self.assertEqual(m[0].value.tolist(), [2., 2.])
Example #15
    def test_parallel_train_op(self):
        """Parallel train op."""
        f = objax.nn.Sequential([
            objax.nn.Linear(3, 4), objax.functional.relu,
            objax.nn.Linear(4, 2)
        ])
        centers = objax.random.normal((1, 2, 3))
        centers *= objax.functional.rsqrt((centers**2).sum(2, keepdims=True))
        x = (objax.random.normal((256, 1, 3), stddev=0.1) + centers).reshape(
            (512, 3))
        y = jn.concatenate([
            jn.zeros((256, 1), dtype=jn.uint32),
            jn.ones((256, 1), dtype=jn.uint32)
        ],
                           axis=1)
        y = y.reshape((512, ))
        opt = objax.optimizer.Momentum(f.vars())
        all_vars = f.vars('f') + opt.vars('opt')

        def loss(x, y):
            xe = objax.functional.loss.cross_entropy_logits_sparse(f(x), y)
            return xe.mean()

        gv = objax.GradValues(loss, f.vars())

        def train_op(x, y):
            g, v = gv(x, y)
            opt(0.05, g)
            return v

        tensors = all_vars.tensors()
        loss_value = np.array([train_op(x, y)[0] for _ in range(10)])
        var_values = {k: v.value for k, v in all_vars.items()}
        all_vars.assign(tensors)

        self.assertGreater(loss_value.min(), 0)

        def train_op_para(x, y):
            g, v = gv(x, y)
            opt(0.05, objax.functional.parallel.pmean(g))
            return objax.functional.parallel.pmean(v)

        fp = objax.Parallel(train_op_para, vc=all_vars, reduce=lambda x: x[0])
        with all_vars.replicate():
            loss_value_p = np.array([fp(x, y)[0] for _ in range(10)])
        var_values_p = {k: v.value for k, v in all_vars.items()}

        self.assertLess(jn.abs(loss_value_p / loss_value - 1).max(), 1e-6)
        for k, v in var_values.items():
            self.assertLess(((v - var_values_p[k])**2).sum(), 1e-12, msg=k)
Example #16
    def __init__(self, model: Callable, nclass: int, **kwargs):
        super().__init__(nclass, **kwargs)
        self.model = model(3, nclass)
        self.opt = objax.optimizer.Momentum(self.model.vars())
        self.model_ema = objax.optimizer.ExponentialMovingAverageModule(
            self.model, momentum=0.999, debias=True)

        @objax.Function.with_vars(self.model.vars())
        def loss(x, label):
            logit = self.model(x, training=True)
            loss_wd = 0.5 * sum(
                (v.value**2).sum()
                for k, v in self.model.vars().items() if k.endswith('.w'))
            loss_xe = objax.functional.loss.cross_entropy_logits(logit,
                                                                 label).mean()
            return loss_xe + loss_wd * self.params.weight_decay, {
                'losses/xe': loss_xe,
                'losses/wd': loss_wd
            }

        gv = objax.GradValues(loss, self.model.vars())

        @objax.Function.with_vars(self.vars())
        def train_op(progress, x, y):
            g, v = gv(x, y)
            lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
            self.opt(lr, objax.functional.parallel.pmean(g))
            self.model_ema.update_ema()
            return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

        self.predict = objax.Parallel(
            objax.nn.Sequential([
                objax.ForceArgs(self.model_ema, training=False),
                objax.functional.softmax
            ]))
        self.train_op = objax.Parallel(train_op, reduce=lambda x: x[0])
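
A minimal usage sketch for the ops defined above (assumptions: `net` is an instance of this class, `train_x`/`train_y`/`test_x` are placeholder batches whose leading dimension is divisible by the device count, and `num_steps` is the loop bound). The scalar `progress` is broadcast to every device, as in the scalar-broadcast test earlier:

    with net.vars().replicate():  # replicate model, optimizer and EMA variables
        for step in range(num_steps):
            progress = jn.float32(step / num_steps)
            monitors = net.train_op(progress, train_x, train_y)
        probs = net.predict(test_x)  # EMA model in eval mode, softmax outputs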
Example #17
    def test_parallel_weight_decay(self):
        """Parallel weight decay."""
        f = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.nn.Linear(4, 2)])
        fvars = f.vars()

        def loss_fn():
            return 0.5 * sum((v.value ** 2).sum() for k, v in fvars.items() if k.endswith('.w'))

        tensors = fvars.tensors()
        loss_value = loss_fn()
        fvars.assign(tensors)
        self.assertGreater(loss_value, 0)

        fp = objax.Parallel(loss_fn, vc=fvars, reduce=lambda x: x[0])
        with fvars.replicate():
            loss_value_p = fp()
        self.assertLess(abs(loss_value_p / loss_value - 1), 1e-6)
Example #18
    def __init__(self):
        fn = functools.partial(wide_resnet.WideResNet, depth=28, width=6)
        self.model = fn(6, 2)

        model_vars = self.model.vars()
        self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)

        def predict_op(x, y):
            # The model takes the two images and checks if they correspond
            # to the same original image.
            xx = jn.concatenate([jn.abs(x), jn.abs(y)], axis=1)
            return self.model(xx, training=False)

        self.predict = objax.Jit(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
        self.predict_fast = objax.Parallel(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
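
A sketch contrasting the two prediction paths above (assumptions: `net` is an instance of this class and `img_a`/`img_b` are batches of images; for the Parallel path the batch size must be divisible by the device count and the variables must be replicated first):

    logits = net.predict(img_a, img_b)           # objax.Jit: single device, whole batch at once
    with net.vars().replicate():
        logits = net.predict_fast(img_a, img_b)  # objax.Parallel: batch split across devices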
Example #19
    def test_jit_parallel_reducers(self):
        """JIT parallel reductions."""
        f = objax.nn.Linear(3, 4)
        x = objax.random.normal((96, 3))
        y = f(x)
        zl = []
        for reduce in (lambda x: x,
                       lambda x: x[0],
                       lambda x: x.mean(0),
                       lambda x: x.sum(0)):
            fp = objax.Jit(objax.Parallel(f, reduce=reduce))
            with fp.vars().replicate():
                zl.append(fp(x))
        znone, zfirst, zmean, zsum = zl
        self.assertAlmostEqual(jn.square(jn.array(y.split(8)) - znone).sum(), 0, places=8)
        self.assertAlmostEqual(jn.square(y.split(8)[0] - zfirst).sum(), 0, places=8)
        self.assertAlmostEqual(jn.square(np.mean(y.split(8), 0) - zmean).sum(), 0, places=8)
        self.assertAlmostEqual(jn.square(np.sum(y.split(8), 0) - zsum).sum(), 0, places=8)
Example #20
File: repr.py  Project: utkarshgiri/objax
    def test_transform(self):
        def myloss(x):
            return (x ** 2).mean()

        g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
        gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
        gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(), noise_multiplier=1.,
                                                    l2_norm_clip=0.5, microbatch=1)
        self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
        self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
        self.assertEqual(repr(gvp), 'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
                                    ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
        self.assertEqual(repr(objax.Jit(gv)),
                         'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
        self.assertEqual(repr(objax.Jit(myloss, vc=objax.VarCollection())),
                         'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
        self.assertEqual(repr(objax.Parallel(gv)),
                         "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
                         " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
        self.assertEqual(repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
                         'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
        self.assertEqual(repr(objax.ForceArgs(gv, training=True, word='hello')),
                         "objax.ForceArgs(module=GradValues, training=True, word='hello')")