Example #1
 def __init__(self, model, *args, **kwargs):
     super().__init__(model, *args, **kwargs)
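     # JIT-compile the loss and its gradients against the model's variables.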
     fastloss = objax.Jit(self.loss, model.vars())
     self.gradvals = objax.Jit(objax.GradValues(fastloss, model.vars()),
                               model.vars())
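     # Override predict with a jitted call where training=False is always forced.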
     self.model.predict = objax.Jit(
         objax.ForceArgs(model.__call__, training=False), model.vars())
Example #2
File: objax2tf.py Project: google/objax
 def disabled_test_savedmodel_wrn(self):
     model_dir = tempfile.mkdtemp()
     # Make a model and convert it to TF
     model = WideResNet(NCHANNELS, NCLASSES, depth=4, width=1)
     predict_op = objax.Jit(
         objax.nn.Sequential([
             objax.ForceArgs(model, training=False),
             objax.functional.softmax
         ]))
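     # Convert the Objax prediction op into a tf.Module for saving below.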
     predict_tf = objax.util.Objax2Tf(predict_op)
     # Save model
     input_shape = (BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE)
     tf.saved_model.save(
         predict_tf,
         model_dir,
         signatures=predict_tf.__call__.get_concrete_function(
             tf.TensorSpec(input_shape, tf.float32)))
     # Load model
     loaded_tf_model = tf.saved_model.load(model_dir)
     loaded_predict_tf_op = loaded_tf_model.signatures['serving_default']
     self.verify_converted_predict_op(
         predict_op,
         lambda x: loaded_predict_tf_op(x)['output_0'],
         shape=input_shape)
     self.verify_converted_predict_op(predict_op,
                                      lambda x: loaded_tf_model(x),
                                      shape=input_shape)
     # Cleanup
     shutil.rmtree(model_dir)
Example #3
File: objax2tf.py Project: google/objax
 def disabled_test_convert_wrn(self):
     # Make a model
     model = WideResNet(NCHANNELS, NCLASSES, depth=4, width=1)
     # Prediction op without JIT
     predict_op = objax.nn.Sequential(
         [objax.ForceArgs(model, training=False), objax.functional.softmax])
     predict_tf = objax.util.Objax2Tf(predict_op)
     # Compare results
     self.verify_converted_predict_op(predict_op,
                                      predict_tf,
                                      shape=(BATCH_SIZE, NCHANNELS,
                                             IMAGE_SIZE, IMAGE_SIZE))
     # Predict op with JIT
     predict_op_jit = objax.Jit(predict_op)
     predict_tf_jit = objax.util.Objax2Tf(predict_op_jit)
     # Compare results
     self.verify_converted_predict_op(predict_op_jit,
                                      predict_tf_jit,
                                      shape=(BATCH_SIZE, NCHANNELS,
                                             IMAGE_SIZE, IMAGE_SIZE))
Example #4
    def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
        """
        Completely standard training. Nothing interesting to see here.
        """
        super().__init__(nclass, **kwargs)
        self.model = model(1 if mnist else 3, nclass)
        self.opt = objax.optimizer.Momentum(self.model.vars())
        self.model_ema = objax.optimizer.ExponentialMovingAverageModule(
            self.model, momentum=0.999, debias=True)

        @objax.Function.with_vars(self.model.vars())
        def loss(x, label):
            logit = self.model(x, training=True)
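            # L2 weight decay summed over the weight tensors (variable names ending in '.w').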
            loss_wd = 0.5 * sum(
                (v.value**2).sum()
                for k, v in self.model.vars().items() if k.endswith('.w'))
            loss_xe = objax.functional.loss.cross_entropy_logits(logit,
                                                                 label).mean()
            return loss_xe + loss_wd * self.params.weight_decay, {
                'losses/xe': loss_xe,
                'losses/wd': loss_wd
            }

        gv = objax.GradValues(loss, self.model.vars())
        self.gv = gv

        @objax.Function.with_vars(self.vars())
        def train_op(progress, x, y):
            g, v = gv(x, y)
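            # Cosine learning-rate decay; anneals to cos(7*pi/16), about 0.2 of the base rate.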
            lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
            lr = lr * jn.clip(progress * 100, 0, 1)  # linear warmup over the first 1% of training
            self.opt(lr, g)
            self.model_ema.update_ema()
            return {'monitors/lr': lr, **v[1]}

        self.predict = objax.Jit(
            objax.nn.Sequential(
                [objax.ForceArgs(self.model_ema, training=False)]))

        self.train_op = objax.Jit(train_op)
Example #5
File: repr.py Project: utkarshgiri/objax
    def test_transform(self):
        def myloss(x):
            return (x ** 2).mean()

        g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
        gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
        gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(), noise_multiplier=1.,
                                                    l2_norm_clip=0.5, microbatch=1)
        self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
        self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
        self.assertEqual(repr(gvp), 'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
                                    ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
        self.assertEqual(repr(objax.Jit(gv)),
                         'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
        self.assertEqual(repr(objax.Jit(myloss, vc=objax.VarCollection())),
                         'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
        self.assertEqual(repr(objax.Parallel(gv)),
                         "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
                         " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
        self.assertEqual(repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
                         'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
        self.assertEqual(repr(objax.ForceArgs(gv, training=True, word='hello')),
                         "objax.ForceArgs(module=GradValues, training=True, word='hello')")
Example #6
    def __init__(self, model: Callable, nclass: int, **kwargs):
        super().__init__(nclass, **kwargs)
        self.model = model(3, nclass)
        self.opt = objax.optimizer.Momentum(self.model.vars())
        self.model_ema = objax.optimizer.ExponentialMovingAverageModule(
            self.model, momentum=0.999, debias=True)

        @objax.Function.with_vars(self.model.vars())
        def loss(x, label):
            logit = self.model(x, training=True)
            loss_wd = 0.5 * sum(
                (v.value**2).sum()
                for k, v in self.model.vars().items() if k.endswith('.w'))
            loss_xe = objax.functional.loss.cross_entropy_logits(logit,
                                                                 label).mean()
            return loss_xe + loss_wd * self.params.weight_decay, {
                'losses/xe': loss_xe,
                'losses/wd': loss_wd
            }

        gv = objax.GradValues(loss, self.model.vars())

        @objax.Function.with_vars(self.vars())
        def train_op(progress, x, y):
            g, v = gv(x, y)
            lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
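            # Average gradients across devices before the optimizer update.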
            self.opt(lr, objax.functional.parallel.pmean(g))
            self.model_ema.update_ema()
            return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

        self.predict = objax.Parallel(
            objax.nn.Sequential([
                objax.ForceArgs(self.model_ema, training=False),
                objax.functional.softmax
            ]))
        self.train_op = objax.Parallel(train_op, reduce=lambda x: x[0])
Example #7

gv = objax.GradValues(loss, model.vars())


@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars())
def train_op(x, y, lr):
    g, v = gv(x, y)
    opt(lr=lr, grads=g)
    return v


train_op = objax.Jit(train_op)
predict = objax.Jit(
    objax.nn.Sequential(
        [objax.ForceArgs(model, training=False), objax.functional.softmax]))


def augment(x):
    if random.random() < .5:
        x = x[:, :, :, ::-1]  # Randomly mirror the batch images left-right (flip along the width axis)
    # Pixel-shift all images in the batch by up to 4 pixels in any direction.
    x_pad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'reflect')
    rx, ry = np.random.randint(0, 9), np.random.randint(0, 9)  # offsets 0..8 give shifts of up to 4 pixels each way
    x = x_pad[:, :, rx:rx + 32, ry:ry + 32]
    return x


# Training
print(model.vars())
for epoch in range(30):
Example #8
 def test_force_args(self):
     # def __call__(self, x, some_arg1, some_arg2):
     #     # without forces args equivalent to (x + some_arg1 * 5 + some_arg2 * 2)
     x = jn.array([1., 2.])
     model = ModelWithArg()
     np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0),
                                    [1., 2.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0),
                                    [6., 7.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0),
                                    [8., 9.])
     # Ensure that an invalid argument can not be used with the original module
     with self.assertRaises(TypeError):
         model(x, some_arg1=0.0, some_arg2=0.0, wrong_arg_name=0.0)
     with self.assertRaises(TypeError):
         model(x, wrong_arg_name1=0.0, wrong_arg_name2=0.0)
     with self.assertRaises(TypeError):
         model.block1.op1(x, wrong_arg_name=0.0)
     # Set forced args.
     original_signature = inspect.signature(model.block1.op1)
     model.block1.op1 = objax.ForceArgs(model.block1.op1, some_arg=-1.0)
     # At this point following arguments are forced:
     #   model.block1.op1(..., some_arg=-1.0)
     self.assertEqual(original_signature,
                      inspect.signature(model.block1.op1))
     np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0),
                                    [0., 1.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0),
                                    [5., 6.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0),
                                    [6., 7.])
     # ForceArgs does not allow passing invalid args
     with self.assertRaises(TypeError):
         model.block1.op1(x, wrong_arg_name=1.0)
     # Set forced args on a module stored in a list
     model.block1.seq[0] = objax.ForceArgs(model.block1.seq[0],
                                           some_arg=-1.0)
     # At this point following arguments are forced:
     #   model.block1.op1(..., some_arg=-1.0)
     #   model.block1.seq[0](..., some_arg=-1.0)
     np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0),
                                    [-1., 0.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0),
                                    [3., 4.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0),
                                    [4., 5.])
     # Set an invalid arg name in forced args
     model.block2.op1 = objax.ForceArgs(model.block2.op1,
                                        wrong_arg_name=1.0)
     with self.assertRaises(TypeError):
         model(x, some_arg=0.0)
     # Remove forced args with the invalid name
     objax.ForceArgs.undo(model, wrong_arg_name=objax.ForceArgs.ANY)
     np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0),
                                    [-1., 0.])
     # Set few other forces args with nested ForceArgs
     model.block1 = objax.ForceArgs(model.block1, some_arg1=10.0)
     model.block1 = objax.ForceArgs(model.block1, some_arg2=20.0)
     # At this point following arguments are forced:
     #   model.block1(..., some_arg1=10.0, some_arg2=20.0)
     #   model.block1.op1(..., some_arg=-1.0)
     #   model.block1.seq[0](..., some_arg=-1.0)
     np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0),
                                    [9., 10.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0),
                                    [12., 13.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0),
                                    [13., 14.])
     # Resetting some of the forced args
     objax.ForceArgs.undo(
         model, some_arg1=30)  # no-op: some_arg1 was forced to 10, not 30
     np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0),
                                    [9., 10.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0),
                                    [12., 13.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0),
                                    [13., 14.])
     objax.ForceArgs.undo(model, some_arg1=10)  # undo some_arg1=10
     np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0),
                                    [-1., 0.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0),
                                    [3., 4.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0),
                                    [4., 5.])
     objax.ForceArgs.undo(model)  # undo all forced args
     np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0),
                                    [1., 2.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0),
                                    [6., 7.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0),
                                    [8., 9.])
     # Try forced args on the root
     model = objax.ForceArgs(model, some_arg1=-1.0)
     np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0),
                                    [-4., -3.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0),
                                    [-4., -3.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0),
                                    [-2., -1.])
     # Reset forced args on the model root
     objax.ForceArgs.undo(model)
     np.testing.assert_almost_equal(model(x, some_arg1=0.0, some_arg2=0.0),
                                    [1., 2.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=0.0),
                                    [6., 7.])
     np.testing.assert_almost_equal(model(x, some_arg1=1.0, some_arg2=1.0),
                                    [8., 9.])
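
The pattern in all of these examples can be reduced to a minimal, self-contained sketch (the Scale module below is invented for illustration): objax.ForceArgs overrides a keyword argument regardless of what the caller passes, and objax.ForceArgs.undo removes the override again.

    import objax
    import jax.numpy as jn

    class Scale(objax.Module):
        def __call__(self, x, training=True):
            # Pretend the module behaves differently in training and inference mode.
            return x * (2.0 if training else 1.0)

    m = Scale()
    forced = objax.ForceArgs(m, training=False)
    x = jn.ones(3)
    print(m(x, training=True))       # [2. 2. 2.]
    print(forced(x, training=True))  # training is forced to False -> [1. 1. 1.]
    objax.ForceArgs.undo(forced)     # remove all forced args
    print(forced(x, training=True))  # behaves like the plain module -> [2. 2. 2.]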