def __init__(self, model, *args, **kwargs):
    super().__init__(model, *args, **kwargs)
    fastloss = objax.Jit(self.loss, model.vars())
    self.gradvals = objax.Jit(objax.GradValues(fastloss, model.vars()),
                              model.vars())
    self.model.predict = objax.Jit(
        objax.ForceArgs(model.__call__, training=False), model.vars())
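# A minimal sketch (assumptions: toy model, not the class above) of the
# ForceArgs + Jit pattern used in __init__: forcing training=False before
# tracing bakes inference mode into the jitted predict function, so callers
# never need to pass the keyword themselves.
import objax
import jax.numpy as jn

toy = objax.nn.Sequential([objax.nn.Linear(4, 2), objax.nn.BatchNorm0D(2)])
toy_predict = objax.Jit(objax.ForceArgs(toy, training=False), toy.vars())
y = toy_predict(jn.zeros((8, 4)))  # no training kwarg needed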
def disabled_test_savedmodel_wrn(self):
    model_dir = tempfile.mkdtemp()
    # Make a model and convert it to TF
    model = WideResNet(NCHANNELS, NCLASSES, depth=4, width=1)
    predict_op = objax.Jit(
        objax.nn.Sequential([
            objax.ForceArgs(model, training=False),
            objax.functional.softmax
        ]))
    predict_tf = objax.util.Objax2Tf(predict_op)
    # Save model
    input_shape = (BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE)
    tf.saved_model.save(
        predict_tf,
        model_dir,
        signatures=predict_tf.__call__.get_concrete_function(
            tf.TensorSpec(input_shape, tf.float32)))
    # Load model
    loaded_tf_model = tf.saved_model.load(model_dir)
    loaded_predict_tf_op = loaded_tf_model.signatures['serving_default']
    self.verify_converted_predict_op(
        predict_op,
        lambda x: loaded_predict_tf_op(x)['output_0'],
        shape=input_shape)
    self.verify_converted_predict_op(
        predict_op, lambda x: loaded_tf_model(x), shape=input_shape)
    # Cleanup
    shutil.rmtree(model_dir)
def disabled_test_convert_wrn(self):
    # Make a model
    model = WideResNet(NCHANNELS, NCLASSES, depth=4, width=1)
    # Prediction op without JIT
    predict_op = objax.nn.Sequential(
        [objax.ForceArgs(model, training=False), objax.functional.softmax])
    predict_tf = objax.util.Objax2Tf(predict_op)
    # Compare results
    self.verify_converted_predict_op(
        predict_op, predict_tf,
        shape=(BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE))
    # Prediction op with JIT
    predict_op_jit = objax.Jit(predict_op)
    predict_tf_jit = objax.util.Objax2Tf(predict_op_jit)
    # Compare results
    self.verify_converted_predict_op(
        predict_op_jit, predict_tf_jit,
        shape=(BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE))
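# A minimal sketch of what verify_converted_predict_op presumably does (its
# body is not shown in this excerpt, so this is an assumption): evaluate the
# Objax op and the converted TF op on one random batch and compare outputs.
import numpy as np

def check_outputs_match(objax_op, tf_op, shape, tol=1e-6):
    x = np.random.uniform(size=shape).astype(np.float32)
    np.testing.assert_allclose(objax_op(x), tf_op(x), atol=tol)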
def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
    """Completely standard training. Nothing interesting to see here."""
    super().__init__(nclass, **kwargs)
    self.model = model(1 if mnist else 3, nclass)
    self.opt = objax.optimizer.Momentum(self.model.vars())
    self.model_ema = objax.optimizer.ExponentialMovingAverageModule(
        self.model, momentum=0.999, debias=True)

    @objax.Function.with_vars(self.model.vars())
    def loss(x, label):
        logit = self.model(x, training=True)
        loss_wd = 0.5 * sum((v.value ** 2).sum()
                            for k, v in self.model.vars().items()
                            if k.endswith('.w'))
        loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
        return loss_xe + loss_wd * self.params.weight_decay, {
            'losses/xe': loss_xe,
            'losses/wd': loss_wd
        }

    gv = objax.GradValues(loss, self.model.vars())
    self.gv = gv

    @objax.Function.with_vars(self.vars())
    def train_op(progress, x, y):
        g, v = gv(x, y)
        lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
        lr = lr * jn.clip(progress * 100, 0, 1)  # linear warmup over the first 1% of training
        self.opt(lr, g)
        self.model_ema.update_ema()
        return {'monitors/lr': lr, **v[1]}

    self.predict = objax.Jit(
        objax.nn.Sequential([objax.ForceArgs(self.model_ema, training=False)]))
    self.train_op = objax.Jit(train_op)
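# Illustrative numeric check (not part of the original class) of the learning
# rate schedule in train_op above: jn.clip(progress * 100, 0, 1) is a linear
# warmup over the first 1% of training, and the cosine factor decays from 1
# at progress=0 to cos(7*pi/16) ~= 0.195 at progress=1. base_lr is a
# hypothetical value for self.params.lr.
import numpy as np

base_lr = 0.1
for progress in (0.0, 0.005, 0.01, 0.5, 1.0):
    lr = base_lr * np.cos(progress * (7 * np.pi) / (2 * 8))
    lr *= np.clip(progress * 100, 0, 1)
    print(f'progress={progress:5.3f}  lr={lr:.4f}')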
def test_transform(self):
    def myloss(x):
        return (x ** 2).mean()

    g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(),
                                                noise_multiplier=1.,
                                                l2_norm_clip=0.5,
                                                microbatch=1)
    self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
    self.assertEqual(
        repr(gvp),
        'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
        ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
    self.assertEqual(
        repr(objax.Jit(gv)),
        'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
    self.assertEqual(
        repr(objax.Jit(myloss, vc=objax.VarCollection())),
        'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
    self.assertEqual(
        repr(objax.Parallel(gv)),
        "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
        " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
    self.assertEqual(
        repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
        'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
    self.assertEqual(
        repr(objax.ForceArgs(gv, training=True, word='hello')),
        "objax.ForceArgs(module=GradValues, training=True, word='hello')")
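# A standalone sketch (an assumption, not part of the test above) of what the
# transforms being repr'd actually compute: with an empty VarCollection and
# input_argnums=(0,), GradValues differentiates with respect to the input.
import objax
import jax.numpy as jn

def myloss(x):
    return (x ** 2).mean()

gv = objax.GradValues(myloss, objax.VarCollection(), input_argnums=(0,))
grads, values = gv(jn.array([1., 2., 3.]))
# No variables were given, so grads holds a single entry, the gradient with
# respect to x: d/dx mean(x**2) = 2x/3 = [2/3, 4/3, 2].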
def __init__(self, model: Callable, nclass: int, **kwargs):
    super().__init__(nclass, **kwargs)
    self.model = model(3, nclass)
    self.opt = objax.optimizer.Momentum(self.model.vars())
    self.model_ema = objax.optimizer.ExponentialMovingAverageModule(
        self.model, momentum=0.999, debias=True)

    @objax.Function.with_vars(self.model.vars())
    def loss(x, label):
        logit = self.model(x, training=True)
        loss_wd = 0.5 * sum((v.value ** 2).sum()
                            for k, v in self.model.vars().items()
                            if k.endswith('.w'))
        loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
        return loss_xe + loss_wd * self.params.weight_decay, {
            'losses/xe': loss_xe,
            'losses/wd': loss_wd
        }

    gv = objax.GradValues(loss, self.model.vars())

    @objax.Function.with_vars(self.vars())
    def train_op(progress, x, y):
        g, v = gv(x, y)
        lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
        self.opt(lr, objax.functional.parallel.pmean(g))  # average gradients across devices
        self.model_ema.update_ema()
        return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

    self.predict = objax.Parallel(
        objax.nn.Sequential([
            objax.ForceArgs(self.model_ema, training=False),
            objax.functional.softmax
        ]))
    self.train_op = objax.Parallel(train_op, reduce=lambda x: x[0])
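# Hedged usage sketch (not from the original class): a Parallel op must be
# called while the variables are replicated across devices, and it splits the
# batch along axis 0, so the batch size must be divisible by the device count.
# The tiny Linear model here is a stand-in for the trainer above.
import objax
import jax.numpy as jn

net = objax.nn.Linear(4, 2)
parallel_predict = objax.Parallel(net)
with net.vars().replicate():
    y = parallel_predict(jn.zeros((8, 4)))  # batch of 8 split across devices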
gv = objax.GradValues(loss, model.vars())

@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars())
def train_op(x, y, lr):
    g, v = gv(x, y)
    opt(lr=lr, grads=g)
    return v

train_op = objax.Jit(train_op)
predict = objax.Jit(
    objax.nn.Sequential(
        [objax.ForceArgs(model, training=False), objax.functional.softmax]))

def augment(x):
    if random.random() < .5:
        x = x[:, :, :, ::-1]  # Randomly mirror the batch images left-right (reverse the width axis).
    # Pixel-shift all images in the batch by up to 4 pixels in any direction.
    x_pad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'reflect')
    rx, ry = np.random.randint(0, 8), np.random.randint(0, 8)
    x = x_pad[:, :, rx:rx + 32, ry:ry + 32]
    return x

# Training
print(model.vars())
for epoch in range(30):
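    # The loop body is truncated in this excerpt. A plausible completion (an
    # assumption, not the original code; x_train, y_train, and the batch size
    # of 64 are hypothetical stand-ins): shuffle, augment each batch, and step
    # the optimizer via train_op.
    sel = np.arange(len(x_train))
    np.random.shuffle(sel)
    for it in range(0, len(x_train), 64):
        idx = sel[it:it + 64]
        v = train_op(augment(x_train[idx]), y_train[idx], 4e-3)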
def test_force_args(self):
    # ModelWithArg.__call__(self, x, some_arg1, some_arg2):
    # without forced args it is equivalent to (x + some_arg1 * 5 + some_arg2 * 2)
    x = jn.array([1., 2.])
    model = ModelWithArg()
    np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0), [1., 2.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0), [6., 7.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0), [8., 9.])
    # Ensure that invalid arguments can not be used on the original module
    with self.assertRaises(TypeError):
        model(x, some_arg1=0.0, some_arg2=0.0, wrong_arg_name=0.0)
    with self.assertRaises(TypeError):
        model(x, wrong_arg_name1=0.0, wrong_arg_name2=0.0)
    with self.assertRaises(TypeError):
        model.block1.op1(x, wrong_arg_name=0.0)
    # Set forced args.
    original_signature = inspect.signature(model.block1.op1)
    model.block1.op1 = objax.ForceArgs(model.block1.op1, some_arg=-1.0)
    # At this point the following arguments are forced:
    #   model.block1.op1(..., some_arg=-1.0)
    self.assertEqual(original_signature, inspect.signature(model.block1.op1))
    np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0), [0., 1.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0), [5., 6.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0), [6., 7.])
    # ForceArgs does not allow passing invalid args
    with self.assertRaises(TypeError):
        model.block1.op1(x, wrong_arg_name=1.0)
    # Set forced args on a list element
    model.block1.seq[0] = objax.ForceArgs(model.block1.seq[0], some_arg=-1.0)
    # At this point the following arguments are forced:
    #   model.block1.op1(..., some_arg=-1.0)
    #   model.block1.seq[0](..., some_arg=-1.0)
    np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0), [-1., 0.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0), [3., 4.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0), [4., 5.])
    # Set an invalid arg in forced args
    model.block2.op1 = objax.ForceArgs(model.block2.op1, wrong_arg_name=1.0)
    with self.assertRaises(TypeError):
        model(x, some_arg=0.0)
    # Remove forced args with an invalid name
    objax.ForceArgs.undo(model, wrong_arg_name=objax.ForceArgs.ANY)
    np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0), [-1., 0.])
    # Set a few other forced args with nested ForceArgs
    model.block1 = objax.ForceArgs(model.block1, some_arg1=10.0)
    model.block1 = objax.ForceArgs(model.block1, some_arg2=20.0)
    # At this point the following arguments are forced:
    #   model.block1(..., some_arg1=10.0, some_arg2=20.0)
    #   model.block1.op1(..., some_arg=-1.0)
    #   model.block1.seq[0](..., some_arg=-1.0)
    np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0), [9., 10.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0), [12., 13.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0), [13., 14.])
    # Reset some of the forced args
    objax.ForceArgs.undo(model, some_arg1=30)  # noop because some_arg1=30 is not used
    np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0), [9., 10.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0), [12., 13.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0), [13., 14.])
    objax.ForceArgs.undo(model, some_arg1=10)  # undo some_arg1=10
    np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0), [-1., 0.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0), [3., 4.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0), [4., 5.])
    objax.ForceArgs.undo(model)  # undo all forced args
    np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0), [1., 2.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0), [6., 7.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0), [8., 9.])
    # Try forced args on the root
    model = objax.ForceArgs(model, some_arg1=-1.0)
    np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0), [-4., -3.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0), [-4., -3.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0), [-2., -1.])
    # Reset forced args on the model root
    objax.ForceArgs.undo(model)
    np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0), [1., 2.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0), [6., 7.])
    np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0), [8., 9.])