def test_parallel_bntrain_concat(self):
    """Parallel inference (concat reduction) with batch norm in train mode."""
    f = objax.nn.Sequential([objax.nn.Conv2D(3, 4, k=1), objax.nn.BatchNorm2D(4), objax.nn.Conv2D(4, 2, k=1)])
    x = objax.random.normal((64, 3, 16, 16))
    states, values = [], []
    tensors = f.vars().tensors()
    for it in range(2):
        values.append(f(x, training=True))
        states += [f[1].running_var.value, f[1].running_mean.value]
    f.vars().assign(tensors)
    self.assertAlmostEqual(((values[0] - values[1]) ** 2).sum(), 0, places=8)
    self.assertGreater(((states[0] - states[2]) ** 2).sum(), 0)
    self.assertGreater(((states[1] - states[3]) ** 2).sum(), 0)

    fp8 = objax.Parallel(lambda x: f(x, training=True), vc=f.vars())
    x8 = jn.broadcast_to(x, (8, 64, 3, 16, 16)).reshape((-1, 3, 16, 16))
    tensors = fp8.vars().tensors()
    for it in range(2):
        with fp8.vars().replicate():
            z = fp8(x8).reshape((-1,) + values[it].shape)
        self.assertAlmostEqual(((f[1].running_var.value - states[2 * it]) ** 2).sum(), 0, delta=1e-12)
        self.assertAlmostEqual(((f[1].running_mean.value - states[2 * it + 1]) ** 2).sum(), 0, delta=1e-12)
        self.assertLess(((z - values[it][None]) ** 2).sum(), 5e-7)
    fp8.vars().assign(tensors)
def test_parallel_ema_shared(self):
    """Parallel EMA with weight sharing."""
    f = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.nn.BatchNorm0D(4),
                             objax.nn.Linear(4, 2), objax.nn.Dropout(1)])
    ema = objax.optimizer.ExponentialMovingAverage(f.vars().subset(objax.TrainVar))
    ema_f = ema.replace_vars(f)
    ema()
    all_vars = f.vars() + ema.vars()
    ema_fp = objax.Parallel(lambda x: ema_f(x, training=False), all_vars)
    ema_fps = objax.Parallel(lambda x: ema_f(x, training=False), all_vars + f.vars('shared'))
    x = objax.random.normal((96, 3))
    with all_vars.replicate():
        z = ema_fp(x)
        zs = ema_fps(x)
    self.assertLess(jn.abs(z - zs).mean(), 1e-6)
def test_parallel_syncbntrain_concat(self):
    """Parallel inference (concat reduction) with synced batch norm in train mode."""
    f = objax.nn.Sequential([objax.nn.Conv2D(3, 4, k=1), objax.nn.BatchNorm2D(4), objax.nn.Conv2D(4, 2, k=1)])
    fs = objax.nn.Sequential(f[:])
    fs[1] = objax.nn.SyncedBatchNorm2D(4)
    x = objax.random.normal((96, 3, 16, 16))
    states, values = [], []
    tensors = f.vars().tensors()
    for it in range(2):
        values.append(f(x, training=True))
        states += [f[1].running_var.value, f[1].running_mean.value]
    f.vars().assign(tensors)
    self.assertEqual(((values[0] - values[1]) ** 2).sum(), 0)
    self.assertGreater(((states[0] - states[2]) ** 2).sum(), 0)
    self.assertGreater(((states[1] - states[3]) ** 2).sum(), 0)

    fp = objax.Parallel(lambda x: fs(x, training=True), vc=fs.vars())
    for it in range(2):
        with fp.vars().replicate():
            z = fp(x)
        self.assertAlmostEqual(((fs[1].running_var.value - states[2 * it]) ** 2).sum(), 0, delta=1e-12)
        self.assertAlmostEqual(((fs[1].running_mean.value - states[2 * it + 1]) ** 2).sum(), 0, delta=1e-12)
        self.assertLess(((z - values[it]) ** 2).sum(), 1e-7)
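# Note on the two batch-norm tests: in test_parallel_bntrain_concat the input is
# broadcast so every device sees an identical copy of the full 64-sample batch,
# which is what lets plain BatchNorm2D's per-device statistics reproduce the
# serial run. SyncedBatchNorm2D instead synchronizes its batch statistics across
# devices, so here the 96-sample batch can be genuinely split (12 samples per
# device on 8 devices) and the running statistics still match the serial run,
# to within the 1e-12 tolerance asserted above.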
def test_parallel_ema(self):
    """Parallel EMA."""
    f = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.nn.BatchNorm0D(4),
                             objax.nn.Linear(4, 2), objax.nn.Dropout(1)])
    ema = objax.optimizer.ExponentialMovingAverage(f.vars().subset(objax.TrainVar))
    ema_f = ema.replace_vars(f)
    ema()
    all_vars = f.vars() + ema.vars()
    fp = objax.Parallel(lambda x: f(x, training=False), f.vars())
    ema_fp = objax.Parallel(lambda x: ema_f(x, training=False), all_vars)
    x = objax.random.normal((96, 3))
    with all_vars.replicate():
        y = fp(x)
        z = ema_fp(x)
    self.assertGreater(jn.abs(y - z).mean(), 1e-3)
def test_parallel_concat(self):
    """Parallel inference (concat reduction)."""
    f = objax.nn.Linear(3, 4)
    x = objax.random.normal((96, 3))
    y = f(x)
    fp = objax.Parallel(f)
    with fp.vars().replicate():
        z = fp(x)
    self.assertTrue(jn.array_equal(y, z))
def test_parallel_concat_broadcast(self):
    """Parallel inference with broadcasted scalar input."""
    f = lambda x, y: x + y
    x = objax.random.normal((96, 3))
    d = jn.float32(0.5)
    y = f(x, d)
    fp = objax.Parallel(f, objax.VarCollection())
    z = fp(x, d)
    self.assertTrue(jn.array_equal(y, z))
def __init__(self, nclass: int, model: Callable, **kwargs):
    super().__init__(nclass, **kwargs)
    self.model = model(nin=3, nclass=nclass, **kwargs)
    model_vars = self.model.vars()
    self.ema = objax.optimizer.ExponentialMovingAverage(model_vars.subset(objax.TrainVar),
                                                        momentum=0.999, debias=False)
    self.opt = objax.optimizer.Momentum(model_vars, momentum=0.9, nesterov=True)

    def loss_function(x, u, y):
        # Concatenate the labeled batch with the (weak, strong) unlabeled pairs so
        # a single forward pass computes all logits.
        xu = jn.concatenate([x, u.reshape((-1,) + u.shape[2:])], axis=0)
        logit = self.model(xu, training=True)
        logit_x = logit[:x.shape[0]]
        logit_weak = logit[x.shape[0]::2]        # weakly augmented unlabeled samples
        logit_strong = logit[x.shape[0] + 1::2]  # strongly augmented unlabeled samples
        xe = objax.functional.loss.cross_entropy_logits(logit_x, y).mean()
        # Pseudo-labels come from the weak branch; only confident ones contribute.
        pseudo_labels = objax.functional.stop_gradient(objax.functional.softmax(logit_weak))
        pseudo_mask = (pseudo_labels.max(axis=1) >= self.params.confidence).astype(logit_weak.dtype)
        xeu = objax.functional.loss.cross_entropy_logits_sparse(logit_strong, pseudo_labels.argmax(axis=1))
        xeu = (xeu * pseudo_mask).mean()
        wd = 0.5 * sum(objax.functional.loss.l2(v.value) for k, v in model_vars.items() if k.endswith('.w'))
        loss = xe + self.params.wu * xeu + self.params.wd * wd
        return loss, {'losses/xe': xe, 'losses/xeu': xeu, 'losses/wd': wd,
                      'monitors/mask': pseudo_mask.mean()}

    gv = objax.GradValues(loss_function, model_vars)

    def train_op(step, x, u, y):
        g, v = gv(x, u, y)
        fstep = step[0] / (FLAGS.train_kimg << 10)
        lr = self.params.lr * jn.cos(fstep * (7 * jn.pi) / (2 * 8))
        self.opt(lr, objax.functional.parallel.pmean(g))
        self.ema()
        return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

    eval_op = self.ema.replace_vars(lambda x: objax.functional.softmax(self.model(x, training=False)))
    self.eval_op = objax.Parallel(eval_op, model_vars + self.ema.vars())
    self.train_op = objax.Parallel(train_op, self.vars(), reduce=lambda x: x[0])
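# Design note on the train_op above: gradients are averaged across devices with
# objax.functional.parallel.pmean before the optimizer step, so every replica
# applies the same update and the returned monitoring dict is identical on all
# devices; reduce=lambda x: x[0] therefore just keeps the first replica's copy
# instead of concatenating 8 identical results.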
def test_parallel_concat_multi_output(self):
    """Parallel inference (concat reduction) for multiple outputs."""
    f = objax.nn.Linear(3, 4)
    x = objax.random.normal((96, 3))
    y = f(x)
    fp = objax.Parallel(lambda x: [f(x), f(-x)], vc=f.vars())
    with fp.vars().replicate():
        z1, z2 = fp(x)
    self.assertTrue(jn.array_equal(z1, y))
    self.assertTrue(jn.array_equal(z2, -y))
def test_trainvar_assign(self):
    m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])

    def increase():
        m[0].assign(m[0].value + 1)
        return m[0].value

    para_increase = objax.Parallel(increase, m.vars())
    with m.vars().replicate():
        para_increase()
    self.assertEqual(m[0].value.tolist(), [1., 1.])
def test_trainvar_assign_multivalue(self):
    m = objax.ModuleList([objax.TrainVar(jn.array((1., 2.)))])

    def increase(x):
        m[0].assign(m[0].value + x)
        return m[0].value

    para_increase = objax.Parallel(increase, m.vars())
    with m.vars().replicate():
        para_increase(jn.arange(8))
    self.assertEqual(m[0].value.tolist(), [4.5, 5.5])
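# A quick sketch of the reduction arithmetic in test_trainvar_assign_multivalue
# above (assuming 8 local devices, as the tests do): replica i receives x = [i]
# and assigns value + i; exiting replicate() writes back the mean over replicas,
# i.e. (1, 2) + mean(0..7) = (1, 2) + 3.5 = (4.5, 5.5).
import jax.numpy as jn

replicas = jn.stack([jn.array([1., 2.]) + i for i in range(8)])  # per-device values after assign
assert jn.allclose(replicas.mean(0), jn.array([4.5, 5.5]))       # value restored after replicate()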
def test_parallel_bneval_concat(self):
    """Parallel inference (concat reduction) with batch norm in eval mode."""
    f = objax.nn.Sequential([objax.nn.Conv2D(3, 4, k=1), objax.nn.BatchNorm2D(4), objax.nn.Conv2D(4, 2, k=1)])
    x = objax.random.normal((96, 3, 16, 16))
    y = f(x, training=False)
    fp = objax.Parallel(lambda x: f(x, training=False), f.vars())
    with fp.vars().replicate():
        z = fp(x)
    self.assertTrue(jn.array_equal(z, y))
    self.assertEqual(((f[1].running_var.value - 1) ** 2).sum(), 0)
    self.assertEqual((f[1].running_mean.value ** 2).sum(), 0)
def test_parallel_list(self):
    """Parallel inference (concat reduction) with list inputs."""
    f = objax.nn.Linear(3, 4)
    g = lambda x: f(x[0]) + x[1][:, jn.newaxis]
    x1 = objax.random.normal((96, 3))
    x2 = objax.random.normal((96,))
    y = g([x1, x2])
    fp = objax.Parallel(g, vc=f.vars())
    with fp.vars().replicate():
        z = fp([x1, x2])
    self.assertTrue(jn.array_equal(y, z))
def __init__(self, model: Callable, nclass: int, **kwargs):
    super().__init__(nclass, **kwargs)
    self.model = model(3, nclass)
    model_vars = self.model.vars()
    self.opt = objax.optimizer.Momentum(model_vars)
    self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)
    print(model_vars)

    def loss(x, label):
        logit = self.model(x, training=True)
        loss_wd = 0.5 * sum((v.value ** 2).sum() for k, v in model_vars.items() if k.endswith('.w'))
        loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
        return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe, 'losses/wd': loss_wd}

    gv = objax.GradValues(loss, model_vars)

    def train_op(progress, x, y):
        g, v = gv(x, y)
        lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
        self.opt(lr, objax.functional.parallel.pmean(g))
        self.ema()
        return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

    def predict_op(x):
        return objax.functional.softmax(self.model(x, training=False))

    self.predict = objax.Parallel(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
    self.train_op = objax.Parallel(train_op, self.vars(), reduce=lambda x: x[0])
def test_trainvar_and_ref_assign(self):
    m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])
    m.append(objax.TrainRef(m[0]))

    def increase():
        # m[1] is a reference to m[0], so both assigns update the same variable.
        m[0].assign(m[0].value + 1)
        m[1].assign(m[1].value + 1)
        return m[0].value

    para_increase = objax.Parallel(increase, m.vars())
    with m.vars().replicate():
        para_increase()
    self.assertEqual(m[0].value.tolist(), [2., 2.])
def test_parallel_train_op(self):
    """Parallel train op."""
    f = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.functional.relu, objax.nn.Linear(4, 2)])
    centers = objax.random.normal((1, 2, 3))
    centers *= objax.functional.rsqrt((centers ** 2).sum(2, keepdims=True))
    x = (objax.random.normal((256, 1, 3), stddev=0.1) + centers).reshape((512, 3))
    y = jn.concatenate([jn.zeros((256, 1), dtype=jn.uint32), jn.ones((256, 1), dtype=jn.uint32)], axis=1)
    y = y.reshape((512,))
    opt = objax.optimizer.Momentum(f.vars())
    all_vars = f.vars('f') + opt.vars('opt')

    def loss(x, y):
        xe = objax.functional.loss.cross_entropy_logits_sparse(f(x), y)
        return xe.mean()

    gv = objax.GradValues(loss, f.vars())

    def train_op(x, y):
        g, v = gv(x, y)
        opt(0.05, g)
        return v

    tensors = all_vars.tensors()
    loss_value = np.array([train_op(x, y)[0] for _ in range(10)])
    var_values = {k: v.value for k, v in all_vars.items()}
    all_vars.assign(tensors)
    self.assertGreater(loss_value.min(), 0)

    def train_op_para(x, y):
        g, v = gv(x, y)
        opt(0.05, objax.functional.parallel.pmean(g))
        return objax.functional.parallel.pmean(v)

    fp = objax.Parallel(train_op_para, vc=all_vars, reduce=lambda x: x[0])
    with all_vars.replicate():
        loss_value_p = np.array([fp(x, y)[0] for _ in range(10)])
    var_values_p = {k: v.value for k, v in all_vars.items()}
    self.assertLess(jn.abs(loss_value_p / loss_value - 1).max(), 1e-6)
    for k, v in var_values.items():
        self.assertLess(((v - var_values_p[k]) ** 2).sum(), 1e-12, msg=k)
def __init__(self, model: Callable, nclass: int, **kwargs):
    super().__init__(nclass, **kwargs)
    self.model = model(3, nclass)
    self.opt = objax.optimizer.Momentum(self.model.vars())
    self.model_ema = objax.optimizer.ExponentialMovingAverageModule(self.model, momentum=0.999, debias=True)

    @objax.Function.with_vars(self.model.vars())
    def loss(x, label):
        logit = self.model(x, training=True)
        loss_wd = 0.5 * sum((v.value ** 2).sum() for k, v in self.model.vars().items() if k.endswith('.w'))
        loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
        return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe, 'losses/wd': loss_wd}

    gv = objax.GradValues(loss, self.model.vars())

    @objax.Function.with_vars(self.vars())
    def train_op(progress, x, y):
        g, v = gv(x, y)
        lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
        self.opt(lr, objax.functional.parallel.pmean(g))
        self.model_ema.update_ema()
        return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

    self.predict = objax.Parallel(objax.nn.Sequential([
        objax.ForceArgs(self.model_ema, training=False),
        objax.functional.softmax
    ]))
    self.train_op = objax.Parallel(train_op, reduce=lambda x: x[0])
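# Unlike the earlier constructors, this variant decorates loss and train_op with
# @objax.Function.with_vars, so objax.Parallel can pick up their variable
# collections without an explicit vc argument, and it replaces the
# ema.replace_vars(...) pattern with ExponentialMovingAverageModule wrapped in
# objax.ForceArgs(..., training=False) on the prediction path.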
def test_parallel_weight_decay(self):
    """Parallel weight decay."""
    f = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.nn.Linear(4, 2)])
    fvars = f.vars()

    def loss_fn():
        return 0.5 * sum((v.value ** 2).sum() for k, v in fvars.items() if k.endswith('.w'))

    tensors = fvars.tensors()
    loss_value = loss_fn()
    fvars.assign(tensors)
    self.assertGreater(loss_value, 0)
    fp = objax.Parallel(loss_fn, vc=fvars, reduce=lambda x: x[0])
    with fvars.replicate():
        loss_value_p = fp()
    self.assertLess(abs(loss_value_p / loss_value - 1), 1e-6)
def __init__(self):
    fn = functools.partial(wide_resnet.WideResNet, depth=28, width=6)
    self.model = fn(6, 2)
    model_vars = self.model.vars()
    self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)

    def predict_op(x, y):
        # The model takes the two images and checks if they correspond
        # to the same original image.
        xx = jn.concatenate([jn.abs(x), jn.abs(y)], axis=1)
        return self.model(xx, training=False)

    self.predict = objax.Jit(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
    self.predict_fast = objax.Parallel(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
def test_jit_parallel_reducers(self):
    """JIT parallel reductions."""
    f = objax.nn.Linear(3, 4)
    x = objax.random.normal((96, 3))
    y = f(x)
    zl = []
    for reduce in (lambda x: x, lambda x: x[0], lambda x: x.mean(0), lambda x: x.sum(0)):
        fp = objax.Jit(objax.Parallel(f, reduce=reduce))
        with fp.vars().replicate():
            zl.append(fp(x))
    znone, zfirst, zmean, zsum = zl
    self.assertAlmostEqual(jn.square(jn.array(y.split(8)) - znone).sum(), 0, places=8)
    self.assertAlmostEqual(jn.square(y.split(8)[0] - zfirst).sum(), 0, places=8)
    self.assertAlmostEqual(jn.square(np.mean(y.split(8), 0) - zmean).sum(), 0, places=8)
    self.assertAlmostEqual(jn.square(np.sum(y.split(8), 0) - zsum).sum(), 0, places=8)
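# As the assertions above demonstrate, the reduce callable receives the
# per-device outputs stacked along a new leading device axis (8 here), so for
# the 96-row input split across 8 devices: the identity keeps shape (8, 12, 4),
# x[0] keeps the first device's (12, 4) shard, and mean/sum collapse the device
# axis. The default reducer concatenates along axis 0, restoring (96, 4).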
def test_transform(self):
    def myloss(x):
        return (x ** 2).mean()

    g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(), noise_multiplier=1.,
                                                l2_norm_clip=0.5, microbatch=1)
    self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gvp),
                     'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
                     ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
    self.assertEqual(repr(objax.Jit(gv)),
                     'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
    self.assertEqual(repr(objax.Jit(myloss, vc=objax.VarCollection())),
                     'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
    self.assertEqual(repr(objax.Parallel(gv)),
                     "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
                     " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
    self.assertEqual(repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
                     'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
    self.assertEqual(repr(objax.ForceArgs(gv, training=True, word='hello')),
                     "objax.ForceArgs(module=GradValues, training=True, word='hello')")