Example #1
    def test_gradvalues_constant(self):
        """Test if constants are preserved."""
        # Set data
        ndim = 1
        data = np.array([[1.0], [3.0], [5.0], [-10.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.zeros(ndim))
        b = objax.TrainVar(jn.ones(1))
        m = objax.ModuleList([w, b])

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            return 0.5 * ((y - pred)**2).mean()

        # The gradient should change after the value of b (the constant) changes.
        gv = objax.GradValues(loss, objax.VarCollection({'w': w}))
        g_old, v_old = gv(data, labels)
        b.assign(-b.value)
        g_new, v_new = gv(data, labels)
        self.assertNotEqual(g_old[0][0], g_new[0][0])

        # When compiled with Jit, the gradient should likewise change after the value of b (the constant) changes.
        gv = objax.Jit(objax.GradValues(loss, objax.VarCollection({'w': w})),
                       m.vars())
        g_old, v_old = gv(data, labels)
        b.assign(-b.value)
        g_new, v_new = gv(data, labels)
        self.assertNotEqual(g_old[0][0], g_new[0][0])
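Both tests above rely on the basic GradValues contract: calling the transform returns a list of gradients (one entry per variable in the collection, plus any inputs selected with input_argnums) together with the outputs of the wrapped function. A minimal, self-contained sketch of that contract (the arrays and names are illustrative, not taken from the test suite):

import jax.numpy as jn
import objax

w = objax.TrainVar(jn.zeros(3))

def loss(x, y):
    return 0.5 * ((y - jn.dot(x, w.value)) ** 2).mean()

gv = objax.GradValues(loss, objax.VarCollection({'w': w}))
grads, values = gv(jn.ones((4, 3)), jn.ones(4))
# grads[0] is d(loss)/dw; values[0] is the scalar value returned by loss.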
Example #2
    def test_gradvalues_linear_and_inputs(self):
        """Test if gradient of inputs and variables has the correct values for linear regression."""
        # Set data
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.array([2, 3], jn.float32))
        b = objax.TrainVar(jn.array([1], jn.float32))

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            return 0.5 * ((y - pred) ** 2).mean()

        expect_loss = loss(data, labels)
        expect_gw = [37.25, 69.0]
        expect_gb = [13.75]
        expect_gx = [[4.0, 6.0], [8.5, 12.75], [13.0, 19.5], [2.0, 3.0]]
        expect_gy = [-2.0, -4.25, -6.5, -1.0]

        gv0 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0,))
        g, v = gv0(data, labels)
        self.assertEqual(v[0], expect_loss)
        self.assertEqual(g[0].tolist(), expect_gx)
        self.assertEqual(g[1].tolist(), expect_gw)
        self.assertEqual(g[2].tolist(), expect_gb)

        gv1 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(1,))
        g, v = gv1(data, labels)
        self.assertEqual(v[0], expect_loss)
        self.assertEqual(g[0].tolist(), expect_gy)
        self.assertEqual(g[1].tolist(), expect_gw)
        self.assertEqual(g[2].tolist(), expect_gb)

        gv01 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0, 1))
        g, v = gv01(data, labels)
        self.assertEqual(v[0], expect_loss)
        self.assertEqual(g[0].tolist(), expect_gx)
        self.assertEqual(g[1].tolist(), expect_gy)
        self.assertEqual(g[2].tolist(), expect_gw)
        self.assertEqual(g[3].tolist(), expect_gb)

        gv10 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(1, 0))
        g, v = gv10(data, labels)
        self.assertEqual(v[0], expect_loss)
        self.assertEqual(g[0].tolist(), expect_gy)
        self.assertEqual(g[1].tolist(), expect_gx)
        self.assertEqual(g[2].tolist(), expect_gw)
        self.assertEqual(g[3].tolist(), expect_gb)

        gv10 = objax.GradValues(loss, None, input_argnums=(0, 1))
        g, v = gv10(data, labels)
        self.assertEqual(v[0], expect_loss)
        self.assertEqual(g[0].tolist(), expect_gx)
        self.assertEqual(g[1].tolist(), expect_gy)
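As the assertions above document, when input_argnums is supplied the returned gradient list starts with the input gradients, in the order the argument indices are listed, followed by the gradients of the variables in the collection. A hedged sketch of unpacking that ordering for the input_argnums=(0, 1) case, reusing the loss, w, b, data and labels defined in the test:

gv = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0, 1))
(g_x, g_y, g_w, g_b), (loss_value,) = gv(data, labels)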
Example #3
def test_gradient_step(var_f, len_f, var_y, N):
    """
    Test whether VI with newt's GP and MarkovGP provide the same initial gradient step in the hyperparameters.
    """

    x, Y, t, R, y = build_data(N)

    gp_model = initialise_gp_model(var_f, len_f, var_y, x, y, R[0])
    markovgp_model = initialise_markovgp_model(var_f, len_f, var_y, x, y, R[0])

    gv = objax.GradValues(inf, gp_model.vars())
    gv_markov = objax.GradValues(inf, markovgp_model.vars())

    lr_adam = 0.1
    lr_newton = 1.
    opt = objax.optimizer.Adam(gp_model.vars())
    opt_markov = objax.optimizer.Adam(markovgp_model.vars())

    gp_model.update_posterior()
    gp_grads, gp_value = gv(gp_model, lr=lr_newton)
    gp_loss_ = gp_value[0]
    opt(lr_adam, gp_grads)
    gp_hypers = np.array([
        gp_model.kernel.temporal_kernel.lengthscale,
        gp_model.kernel.temporal_kernel.variance,
        gp_model.kernel.spatial_kernel.lengthscale,
        gp_model.likelihood.variance
    ])
    print(gp_hypers)
    print(gp_grads)

    markovgp_model.update_posterior()
    markovgp_grads, markovgp_value = gv_markov(markovgp_model, lr=lr_newton)
    markovgp_loss_ = markovgp_value[0]
    opt_markov(lr_adam, markovgp_grads)
    markovgp_hypers = np.array([
        markovgp_model.kernel.temporal_kernel.lengthscale,
        markovgp_model.kernel.temporal_kernel.variance,
        markovgp_model.kernel.spatial_kernel.lengthscale,
        markovgp_model.likelihood.variance
    ])

    print(markovgp_hypers)
    print(markovgp_grads)

    np.testing.assert_allclose(gp_grads[0], markovgp_grads[0], rtol=1e-4)
    np.testing.assert_allclose(gp_grads[1], markovgp_grads[1], rtol=1e-4)
    np.testing.assert_allclose(gp_grads[2], markovgp_grads[2], rtol=1e-4)
Example #4
def test_gradient_step(var_f, len_f, var_y, N):
    """
    Test whether MarkovGP and SparseMarkovGP provide the same initial gradient step in the hyperparameters (Z=X).
    """

    x, y = build_data(N)

    z = x + np.random.normal(0, .05, x.shape)

    gp_model = initialise_gp_model(var_f, len_f, var_y, x, y)
    sparsemarkovgp_model = initialise_sparsemarkovgp_model(
        var_f, len_f, var_y, x, y, z)

    gv = objax.GradValues(inf, gp_model.vars())
    gv_markov = objax.GradValues(inf, sparsemarkovgp_model.vars())

    lr_adam = 0.1
    lr_newton = 1.
    opt = objax.optimizer.Adam(gp_model.vars())
    opt_markov = objax.optimizer.Adam(sparsemarkovgp_model.vars())

    gp_model.update_posterior()
    gp_grads, gp_value = gv(gp_model, lr=lr_newton)
    gp_loss_ = gp_value[0]
    opt(lr_adam, gp_grads)
    gp_hypers = np.array([
        gp_model.kernel.lengthscale, gp_model.kernel.variance,
        gp_model.likelihood.variance
    ])
    print(gp_hypers)
    print(gp_grads)

    sparsemarkovgp_model.update_posterior()
    markovgp_grads, markovgp_value = gv_markov(sparsemarkovgp_model,
                                               lr=lr_newton)
    markovgp_loss_ = markovgp_value[0]
    opt_markov(lr_adam, markovgp_grads)
    markovgp_hypers = np.array([
        sparsemarkovgp_model.kernel.lengthscale,
        sparsemarkovgp_model.kernel.variance,
        sparsemarkovgp_model.likelihood.variance
    ])
    print(markovgp_hypers)
    print(markovgp_grads)

    np.testing.assert_allclose(gp_grads[0], markovgp_grads[0], atol=3e-1)
    np.testing.assert_allclose(gp_grads[1], markovgp_grads[1], rtol=5e-2)
    np.testing.assert_allclose(gp_grads[2], markovgp_grads[2], rtol=5e-2)
Example #5
    def test_gradvalues_linear(self):
        """Test if gradient has the correct value for linear regression."""
        # Set data
        ndim = 2
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.zeros(ndim))
        b = objax.TrainVar(jn.zeros(1))

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            return 0.5 * ((y - pred)**2).mean()

        gv = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}))
        g, v = gv(data, labels)

        self.assertEqual(g[0].shape, tuple([ndim]))
        self.assertEqual(g[1].shape, tuple([1]))

        g_expect_w = -(data * np.tile(labels, (ndim, 1)).transpose()).mean(0)
        g_expect_b = np.array([-labels.mean()])
        np.testing.assert_allclose(g[0], g_expect_w)
        np.testing.assert_allclose(g[1], g_expect_b)
        np.testing.assert_allclose(v[0], loss(data, labels))
Example #6
 def __init__(self):
     # Some constants
     total_batch_size = FLAGS.train_device_batch_size * jax.device_count()
     self.base_learning_rate = FLAGS.base_learning_rate * total_batch_size / 256
     # Create model
     bn_cls = objax.nn.SyncedBatchNorm2D if FLAGS.use_sync_bn else objax.nn.BatchNorm2D
     self.model = ResNet50(in_channels=3,
                           num_classes=NUM_CLASSES,
                           normalization_fn=bn_cls)
     self.model_vars = self.model.vars()
     print(self.model_vars)
     # Create parallel eval op
     self.evaluate_batch_parallel = objax.Parallel(
         self.evaluate_batch, self.model_vars, reduce=lambda x: x.sum(0))
     # Create parallel training op
     self.optimizer = objax.optimizer.Momentum(self.model_vars,
                                               momentum=0.9,
                                               nesterov=True)
     self.compute_grads_loss = objax.GradValues(self.loss_fn,
                                                self.model_vars)
     self.all_vars = self.model_vars + self.optimizer.vars()
     self.train_op_parallel = objax.Parallel(self.train_op,
                                             self.all_vars,
                                             reduce=lambda x: x[0])
     # Summary writer
     self.summary_writer = objax.jaxboard.SummaryWriter(
         os.path.join(FLAGS.model_dir, 'tb'))
Example #7
    def __init__(self,
                 model,
                 dataloaders,
                 optim=objax.optimizer.Adam,
                 lr_sched=lambda e: 1,
                 log_dir=None,
                 log_suffix='',
                 log_args={},
                 early_stop_metric=None):
        # Setup model, optimizer, and dataloaders
        self.model = model  #
        #self.model= objax.Jit(objax.ForceArgs(model,training=True)) #TODO: figure out static nums
        #self.model.predict = objax.Jit(objax.ForceArgs(model.__call__,training=False),model.vars())
        #self.model.predict = objax.ForceArgs(model.__call__,training=False)
        #self._model = model
        #self.model = objax.ForceArgs(model,training=True)
        #self.model.predict = objax.ForceArgs(model.__call__,training=False)
        #self.model = objax.Jit(lambda x, training: model(x,training=training),model.vars(),static_argnums=(1,))
        #self.model = objax.Jit(model,static_argnums=(1,))

        self.optimizer = optim(model.vars())
        self.lr_sched = lr_sched
        self.dataloaders = dataloaders  # A dictionary of dataloaders
        self.epoch = 0

        self.logger = LazyLogger(log_dir, log_suffix, **log_args)
        #self.logger.add_text('ModelSpec','model: {}'.format(model))
        self.hypers = {}
        self.ckpt = None  # copy.deepcopy(self.state_dict()) #TODO fix model saving
        self.early_stop_metric = early_stop_metric
        #fastloss = objax.Jit(self.loss,model.vars())

        self.gradvals = objax.GradValues(self.loss, self.model.vars())
Example #8
    def test_private_gradvalues_compare_nonpriv(self):
        """Test if PrivateGradValues without clipping / noise is the same as non-private GradValues."""
        l2_norm_clip = 1e10
        noise_multiplier = 0

        for use_norm_accumulation in [True, False]:
            for microbatch in [1, 10, self.ntrain]:
                gv_priv = objax.Jit(
                    objax.privacy.dpsgd.PrivateGradValues(
                        self.loss,
                        self.model_vars,
                        noise_multiplier,
                        l2_norm_clip,
                        microbatch,
                        batch_axis=(0, 0),
                        use_norm_accumulation=use_norm_accumulation))
                gv = objax.GradValues(self.loss, self.model_vars)
                g_priv, v_priv = gv_priv(self.data, self.labels)
                g, v = gv(self.data, self.labels)

                # Check the shape of the gradient.
                self.assertEqual(g_priv[0].shape, tuple([self.nclass]))
                self.assertEqual(g_priv[1].shape,
                                 tuple([self.ndim, self.nclass]))

                # Check if the private gradient is similar to the non-private gradient.
                np.testing.assert_allclose(g[0], g_priv[0], atol=1e-7)
                np.testing.assert_allclose(g[1], g_priv[1], atol=1e-7)
                np.testing.assert_allclose(v_priv[0],
                                           self.loss(self.data,
                                                     self.labels)[0],
                                           atol=1e-7)
Example #9
    def test_typical_training_loop(self):
        # Define model and optimizer
        model = DNNet((32, 10), objax.functional.leaky_relu)
        opt = objax.optimizer.Momentum(model.vars(), nesterov=True)

        # Predict op
        predict_op = lambda x: objax.functional.softmax(model(x, training=False))

        self.assertDictEqual(objax.util.find_used_variables(predict_op),
                             model.vars(scope='model.'))

        # Loss function
        def loss(x, label):
            logit = model(x, training=True)
            xe_loss = objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()
            return xe_loss

        self.assertDictEqual(objax.util.find_used_variables(loss),
                             model.vars(scope='model.'))

        # Gradients and loss function
        loss_gv = objax.GradValues(loss, objax.util.find_used_variables(loss))

        def train_op(x, y, learning_rate):
            grads, loss = loss_gv(x, y)
            opt(learning_rate, grads)
            return loss

        self.assertDictEqual(objax.util.find_used_variables(train_op),
                             {**model.vars(scope='loss_gv.model.'), **opt.vars(scope='opt.')})
Example #10
 def __init__(self, model, *args, **kwargs):
     super().__init__(model, *args, **kwargs)
     fastloss = objax.Jit(self.loss, model.vars())
     self.gradvals = objax.Jit(objax.GradValues(fastloss, model.vars()),
                               model.vars())
     self.model.predict = objax.Jit(
         objax.ForceArgs(model.__call__, training=False), model.vars())
Example #11
    def test_gradvalues_signature(self):
        def f(x: JaxArray, y) -> Tuple[JaxArray, Dict[str, JaxArray]]:
            return (x + y).mean(), {'x': x, 'y': y}

        def df(x: JaxArray, y) -> Tuple[List[JaxArray], Tuple[JaxArray, Dict[str, JaxArray]]]:
            pass  # Signature of the (differential of f, f)

        g = objax.GradValues(f, objax.VarCollection())
        self.assertEqual(inspect.signature(g), inspect.signature(df))
Example #12
 def _test_loss_opt(self, loss_name: str, opt_name: str):
     """Given loss and optimizer name, get definitions and run test."""
     model_vars, loss = self._get_loss(loss_name)
     gv = objax.GradValues(loss, model_vars)
     opt = self._get_optimizer(model_vars, opt_name)
     lr = self.lrs['{}_{}'.format(loss_name, opt_name)]
     tolerance = self.tolerances['{}_{}'.format(loss_name, opt_name)]
     self._check_run(gv, opt, loss, lr, self.num_steps, tolerance)
     return model_vars, loss
Example #13
    def __init__(self, nclass: int, model: Callable, **kwargs):
        super().__init__(nclass, **kwargs)
        self.model = model(nin=3, nclass=nclass, **kwargs)
        model_vars = self.model.vars()
        self.ema = objax.optimizer.ExponentialMovingAverage(model_vars.subset(
            objax.TrainVar),
                                                            momentum=0.999,
                                                            debias=False)
        self.opt = objax.optimizer.Momentum(model_vars,
                                            momentum=0.9,
                                            nesterov=True)

        def loss_function(x, u, y):
            xu = jn.concatenate([x, u.reshape((-1, ) + u.shape[2:])], axis=0)
            logit = self.model(xu, training=True)
            logit_x = logit[:x.shape[0]]
            logit_weak = logit[x.shape[0]::2]
            logit_strong = logit[x.shape[0] + 1::2]

            xe = objax.functional.loss.cross_entropy_logits(logit_x, y).mean()

            pseudo_labels = objax.functional.stop_gradient(
                objax.functional.softmax(logit_weak))
            pseudo_mask = (pseudo_labels.max(axis=1) >=
                           self.params.confidence).astype(logit_weak.dtype)
            xeu = objax.functional.loss.cross_entropy_logits_sparse(
                logit_strong, pseudo_labels.argmax(axis=1))
            xeu = (xeu * pseudo_mask).mean()

            wd = 0.5 * sum(
                objax.functional.loss.l2(v.value)
                for k, v in model_vars.items() if k.endswith('.w'))

            loss = xe + self.params.wu * xeu + self.params.wd * wd
            return loss, {
                'losses/xe': xe,
                'losses/xeu': xeu,
                'losses/wd': wd,
                'monitors/mask': pseudo_mask.mean()
            }

        gv = objax.GradValues(loss_function, model_vars)

        def train_op(step, x, u, y):
            g, v = gv(x, u, y)
            fstep = step[0] / (FLAGS.train_kimg << 10)
            lr = self.params.lr * jn.cos(fstep * (7 * jn.pi) / (2 * 8))
            self.opt(lr, objax.functional.parallel.pmean(g))
            self.ema()
            return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

        eval_op = self.ema.replace_vars(
            lambda x: objax.functional.softmax(self.model(x, training=False)))
        self.eval_op = objax.Parallel(eval_op, model_vars + self.ema.vars())
        self.train_op = objax.Parallel(train_op,
                                       self.vars(),
                                       reduce=lambda x: x[0])
Example #14
 def _test_loss_opt(self, loss_name: str, opt_name: str, override: bool = False):
     """Given loss and optimizer name, get definitions and run test."""
     model_vars, loss = self._get_loss(loss_name)
     gv = objax.GradValues(loss, model_vars)
     opt = self._get_optimizer(model_vars, opt_name)
     test_name = '{}_{}'.format(loss_name, opt_name)
     test_name = test_name + '_override' if override else test_name
     lr = self.lrs[test_name]
     tolerance = self.tolerances[test_name]
     options = self.override_options[test_name] if override and test_name in self.override_options else None
     self._check_run(gv, opt, loss, lr, self.num_steps, tolerance, options)
     return model_vars, loss
Example #15
    def test_parallel_train_op(self):
        """Parallel train op."""
        f = objax.nn.Sequential([
            objax.nn.Linear(3, 4), objax.functional.relu,
            objax.nn.Linear(4, 2)
        ])
        centers = objax.random.normal((1, 2, 3))
        centers *= objax.functional.rsqrt((centers**2).sum(2, keepdims=True))
        x = (objax.random.normal((256, 1, 3), stddev=0.1) + centers).reshape(
            (512, 3))
        y = jn.concatenate([
            jn.zeros((256, 1), dtype=jn.uint32),
            jn.ones((256, 1), dtype=jn.uint32)
        ],
                           axis=1)
        y = y.reshape((512, ))
        opt = objax.optimizer.Momentum(f.vars())
        all_vars = f.vars('f') + opt.vars('opt')

        def loss(x, y):
            xe = objax.functional.loss.cross_entropy_logits_sparse(f(x), y)
            return xe.mean()

        gv = objax.GradValues(loss, f.vars())

        def train_op(x, y):
            g, v = gv(x, y)
            opt(0.05, g)
            return v

        tensors = all_vars.tensors()
        loss_value = np.array([train_op(x, y)[0] for _ in range(10)])
        var_values = {k: v.value for k, v in all_vars.items()}
        all_vars.assign(tensors)

        self.assertGreater(loss_value.min(), 0)

        def train_op_para(x, y):
            g, v = gv(x, y)
            opt(0.05, objax.functional.parallel.pmean(g))
            return objax.functional.parallel.pmean(v)

        fp = objax.Parallel(train_op_para, vc=all_vars, reduce=lambda x: x[0])
        with all_vars.replicate():
            loss_value_p = np.array([fp(x, y)[0] for _ in range(10)])
        var_values_p = {k: v.value for k, v in all_vars.items()}

        self.assertLess(jn.abs(loss_value_p / loss_value - 1).max(), 1e-6)
        for k, v in var_values.items():
            self.assertLess(((v - var_values_p[k])**2).sum(), 1e-12, msg=k)
Example #16
def test_gradient_step(var_f, len_f, var_y, N):
    """
    Test whether newt's VI and gpflow's SVGP (Z=X) provide the same initial gradient step in the hyperparameters.
    """

    x, y = build_data(N)

    newt_model = initialise_newt_model(var_f, len_f, var_y, x, y)
    gpflow_model = initialise_gpflow_model(var_f, len_f, var_y, x, y)

    gv = objax.GradValues(inf.energy, newt_model.vars())

    lr_adam = 0.1
    lr_newton = 1.
    opt = objax.optimizer.Adam(newt_model.vars())

    newt_model.update_posterior()
    newt_grads, value = gv(newt_model)  # , lr=lr_newton)
    loss_ = value[0]
    opt(lr_adam, newt_grads)
    newt_hypers = np.array([
        newt_model.kernel.lengthscale, newt_model.kernel.variance,
        newt_model.likelihood.variance
    ])
    print(newt_hypers)
    print(newt_grads)

    adam_opt = tf.optimizers.Adam(lr_adam)
    data = (x, y[:, None])
    with tf.GradientTape() as tape:
        loss = -gpflow_model.elbo(data)
    _vars = gpflow_model.trainable_variables
    gpflow_grads = tape.gradient(loss, _vars)

    loss_fn = gpflow_model.training_loss_closure(data)
    adam_vars = gpflow_model.trainable_variables
    adam_opt.minimize(loss_fn, adam_vars)
    gpflow_hypers = np.array([
        gpflow_model.kernel.lengthscales.numpy()[0],
        gpflow_model.kernel.variance.numpy(),
        gpflow_model.likelihood.variance.numpy()
    ])
    print(gpflow_hypers)
    print(gpflow_grads)

    np.testing.assert_allclose(newt_grads[0], gpflow_grads[0],
                               atol=1e-2)  # use atol since values are so small
    np.testing.assert_allclose(newt_grads[1], gpflow_grads[1], rtol=1e-2)
    np.testing.assert_allclose(newt_grads[2], gpflow_grads[2], rtol=1e-2)
Example #17
def build_train_step_fn(model, loss_fn):
    gradient_loss = objax.GradValues(loss_fn, model.vars())
    optimiser = objax.optimizer.Adam(model.vars())

    def train_step(learning_rate, rgb_img_t1, true_dither_t0, true_dither_t1):
        grads, _loss = gradient_loss(
            rgb_img_t1, true_dither_t0, true_dither_t1)
        grads = u.clip_gradients(grads, theta=opts.gradient_clip)
        optimiser(learning_rate, grads)
        grad_norms = [jnp.linalg.norm(g) for g in grads]
        return grad_norms

    if JIT:
        train_step = objax.Jit(
            train_step, gradient_loss.vars() + optimiser.vars())
    return train_step
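A hedged usage sketch of the factory above; the learning rate and the three batch tensors are placeholders that would come from the surrounding training script:

train_step = build_train_step_fn(model, loss_fn)
grad_norms = train_step(1e-3, rgb_img_t1, true_dither_t0, true_dither_t1)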
Example #18
    def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
        """
        Completely standard training. Nothing interesting to see here.
        """
        super().__init__(nclass, **kwargs)
        self.model = model(1 if mnist else 3, nclass)
        self.opt = objax.optimizer.Momentum(self.model.vars())
        self.model_ema = objax.optimizer.ExponentialMovingAverageModule(
            self.model, momentum=0.999, debias=True)

        @objax.Function.with_vars(self.model.vars())
        def loss(x, label):
            logit = self.model(x, training=True)
            loss_wd = 0.5 * sum(
                (v.value**2).sum()
                for k, v in self.model.vars().items() if k.endswith('.w'))
            loss_xe = objax.functional.loss.cross_entropy_logits(logit,
                                                                 label).mean()
            return loss_xe + loss_wd * self.params.weight_decay, {
                'losses/xe': loss_xe,
                'losses/wd': loss_wd
            }

        gv = objax.GradValues(loss, self.model.vars())
        self.gv = gv

        @objax.Function.with_vars(self.vars())
        def train_op(progress, x, y):
            g, v = gv(x, y)
            lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
            lr = lr * jn.clip(progress * 100, 0, 1)
            self.opt(lr, g)
            self.model_ema.update_ema()
            return {'monitors/lr': lr, **v[1]}

        self.predict = objax.Jit(
            objax.nn.Sequential(
                [objax.ForceArgs(self.model_ema, training=False)]))

        self.train_op = objax.Jit(train_op)
Example #19
    def __init__(self, model: Callable, nclass: int, **kwargs):
        super().__init__(nclass, **kwargs)
        self.model = model(3, nclass)
        model_vars = self.model.vars()
        self.opt = objax.optimizer.Momentum(model_vars)
        self.ema = objax.optimizer.ExponentialMovingAverage(model_vars,
                                                            momentum=0.999,
                                                            debias=True)
        print(model_vars)

        def loss(x, label):
            logit = self.model(x, training=True)
            loss_wd = 0.5 * sum(
                (v.value**2).sum()
                for k, v in model_vars.items() if k.endswith('.w'))
            loss_xe = objax.functional.loss.cross_entropy_logits(logit,
                                                                 label).mean()
            return loss_xe + loss_wd * self.params.weight_decay, {
                'losses/xe': loss_xe,
                'losses/wd': loss_wd
            }

        gv = objax.GradValues(loss, model_vars)

        def train_op(progress, x, y):
            g, v = gv(x, y)
            lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
            self.opt(lr, objax.functional.parallel.pmean(g))
            self.ema()
            return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

        def predict_op(x):
            return objax.functional.softmax(self.model(x, training=False))

        self.predict = objax.Parallel(self.ema.replace_vars(predict_op),
                                      model_vars + self.ema.vars())
        self.train_op = objax.Parallel(train_op,
                                       self.vars(),
                                       reduce=lambda x: x[0])
Example #20
    def test_gradvalues_logistic(self):
        """Test if gradient has the correct value for logistic regression."""
        # Set data
        ndim = 2
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, -1.0, 1.0, -1.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.ones(ndim))

        def loss(x, y):
            xyw = jn.dot(x * np.tile(y, (ndim, 1)).transpose(), w.value)
            return jn.log(jn.exp(-xyw) + 1).mean(0)

        gv = objax.GradValues(loss, objax.VarCollection({'w': w}))
        g, v = gv(data, labels)

        self.assertEqual(g[0].shape, tuple([ndim]))

        xw = np.dot(data, w.value)
        g_expect_w = -(data * np.tile(labels / (1 + np.exp(labels * xw)), (ndim, 1)).transpose()).mean(0)
        np.testing.assert_allclose(g[0], g_expect_w, atol=1e-7)
        np.testing.assert_allclose(v[0], loss(data, labels))
Example #21
    def test_transform(self):
        def myloss(x):
            return (x ** 2).mean()

        g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
        gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
        gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(), noise_multiplier=1.,
                                                    l2_norm_clip=0.5, microbatch=1)
        self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
        self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
        self.assertEqual(repr(gvp), 'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
                                    ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
        self.assertEqual(repr(objax.Jit(gv)),
                         'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
        self.assertEqual(repr(objax.Jit(myloss, vc=objax.VarCollection())),
                         'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
        self.assertEqual(repr(objax.Parallel(gv)),
                         "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
                         " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
        self.assertEqual(repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
                         'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
        self.assertEqual(repr(objax.ForceArgs(gv, training=True, word='hello')),
                         "objax.ForceArgs(module=GradValues, training=True, word='hello')")
Example #22
    def __init__(self, model: Callable, nclass: int, **kwargs):
        super().__init__(nclass, **kwargs)
        self.model = model(3, nclass)
        self.opt = objax.optimizer.Momentum(self.model.vars())
        self.model_ema = objax.optimizer.ExponentialMovingAverageModule(
            self.model, momentum=0.999, debias=True)

        @objax.Function.with_vars(self.model.vars())
        def loss(x, label):
            logit = self.model(x, training=True)
            loss_wd = 0.5 * sum(
                (v.value**2).sum()
                for k, v in self.model.vars().items() if k.endswith('.w'))
            loss_xe = objax.functional.loss.cross_entropy_logits(logit,
                                                                 label).mean()
            return loss_xe + loss_wd * self.params.weight_decay, {
                'losses/xe': loss_xe,
                'losses/wd': loss_wd
            }

        gv = objax.GradValues(loss, self.model.vars())

        @objax.Function.with_vars(self.vars())
        def train_op(progress, x, y):
            g, v = gv(x, y)
            lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
            self.opt(lr, objax.functional.parallel.pmean(g))
            self.model_ema.update_ema()
            return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

        self.predict = objax.Parallel(
            objax.nn.Sequential([
                objax.ForceArgs(self.model_ema, training=False),
                objax.functional.softmax
            ]))
        self.train_op = objax.Parallel(train_op, reduce=lambda x: x[0])
Example #23
    ])


source = jn.linspace(-5, 5, 100).reshape((100, 1))  # (k, 1)
target = jn.sin(source)

print('Standard training.')
net = make_net()
opt = objax.optimizer.Adam(net.vars())


def loss(x, y):
    return ((y - net(x))**2).mean()


gv = objax.GradValues(loss, net.vars())


def train_op():
    g, v = gv(source, target)
    opt(0.01, g)
    return v


train_op = objax.Jit(train_op, gv.vars() + opt.vars())

for i in range(100):
    train_op()

plt.plot(source, net(source), label='prediction')
plt.plot(source, (target - net(source))**2, label='loss')
Example #24
                                          R=r,
                                          Y=Y)
elif model_type == 2:
    model = newt.models.InfiniteHorizonGP(kernel=kern,
                                          likelihood=lik,
                                          X=t,
                                          R=r,
                                          Y=Y)

print('num spatial pts:', nr)
print(model)

inf = newt.inference.VariationalInference(cubature=newt.cubature.Unscented())

trainable_vars = model.vars() + inf.vars()
energy = objax.GradValues(inf.energy, trainable_vars)

lr_adam = 0.2
lr_newton = 0.2
iters = 100
opt = objax.optimizer.Adam(trainable_vars)


def train_op():
    inf(model, lr=lr_newton)  # perform inference and update variational params
    dE, E = energy(model)  # compute energy and its gradients w.r.t. hypers
    return dE, E


train_op = objax.Jit(train_op, trainable_vars)
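The jitted train_op above performs the inference step and returns the energy with its gradients, but leaves the Adam update to the caller. A hedged sketch of the outer loop it is typically driven by, reusing the iters, lr_adam and opt defined in the example (the print format is illustrative):

for i in range(1, iters + 1):
    dE, E = train_op()   # inference step + energy and its gradients w.r.t. hypers
    opt(lr_adam, dE)     # Adam update of the hyperparameters
    print('iter %3d, energy: %1.4f' % (i, E[0]))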
Example #25
def train_model():
    """
    Train the patch similarity function
    """
    global ema, model

    model = Model()

    def loss(x, y):
        """
        K-way contrastive loss as in SimCLR et al.
        The idea is that we should embed x and y so that they are similar
        to each other, and dissimilar from others. To do this we use a
        softmax loss over one dimension to make the values large on the diagonal
        and small off-diagonal.
        """
        a = model.encode(x)
        b = model.decode(y)

        mat = a @ b.T
        return objax.functional.loss.cross_entropy_logits_sparse(
            logits=jn.exp(jn.clip(model.scale.w.value, -2, 4)) * mat,
            labels=np.arange(a.shape[0])).mean()

    ema = objax.optimizer.ExponentialMovingAverage(model.vars(),
                                                   momentum=0.999)
    gv = objax.GradValues(loss, model.vars())

    encode_ema = ema.replace_vars(lambda x: model.encode(x))
    decode_ema = ema.replace_vars(lambda y: model.decode(y))

    def train_op(x, y):
        """
        No one was ever fired for using Adam with 1e-4.
        """
        g, v = gv(x, y)
        opt(1e-4, g)
        ema()
        return v

    opt = objax.optimizer.Adam(model.vars())
    train_op = objax.Jit(train_op, gv.vars() + opt.vars() + ema.vars())

    ys_ = ys_train

    print(ys_.shape)

    xs_ = xs_train.reshape((-1, xs_train.shape[-1]))
    ys_ = ys_.reshape((-1, ys_train.shape[-1]))

    # The model scale trick here is taken from CLIP.
    # Let the model decide how confident to make its own predictions.
    model.scale.w.assign(jn.zeros((1, 1)))

    valid_size = 1000

    print(xs_train.shape)
    # SimCLR likes big batches
    B = 4096
    for it in range(80):
        print()
        ms = []
        for i in range(1000):
            # First batch is smaller, to make training more stable
            bs = [B // 64, B][it > 0]
            batch = np.random.randint(0, len(xs_) - valid_size, size=bs)
            r = train_op(xs_[batch], ys_[batch])

            # This shouldn't happen, but if it does, better to abort early
            if np.isnan(r):
                print("Die on nan")
                print(ms[-100:])
                return
            ms.append(r)

        print('mean', np.mean(ms), 'scale', model.scale.w.value)
        print('loss', loss(xs_[-100:], ys_[-100:]))

        a = encode_ema(xs_[-valid_size:])
        b = decode_ema(ys_[-valid_size:])

        br = b[np.random.permutation(len(b))]

        print('score',
              np.mean(np.sum(a * b, axis=(1)) - np.sum(a * br, axis=(1))),
              np.mean(np.sum(a * b, axis=(1)) > np.sum(a * br, axis=(1))))
    ckpt = objax.io.Checkpoint("saved", keep_ckpts=0)
    ema.replace_vars(lambda: ckpt.save(model.vars(), 0))()
Example #26
def train(opts):
    if opts.ortho_init in ['True', 'False']:
        # TODO: move this into argparse
        opts.ortho_init = opts.ortho_init == 'True'
    if opts.ortho_init not in [True, False]:
        raise Exception("unknown --ortho-init value [%s]" % opts.ortho_init)
    ortho_init = opts.ortho_init  # already normalized to a bool above

    train_frame_pairs = img_utils.parse_frame_pairs(opts.train_tsv)
    test_frame_pairs = img_utils.parse_frame_pairs(opts.test_tsv)

    # init w & b
    wandb_enabled = opts.group is not None
    if wandb_enabled:
        if opts.run is None:
            run = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        else:
            run = opts.run
        wandb.init(project='embedding_the_chickens',
                   group=opts.group, name=run,
                   reinit=True)
        wandb.config.num_models = opts.num_models
        wandb.config.dense_kernel_size = opts.dense_kernel_size
        wandb.config.embedding_dim = opts.embedding_dim
        wandb.config.learning_rate = opts.learning_rate
        wandb.config.ortho_init = opts.ortho_init
        wandb.config.logit_temp = opts.logit_temp

    else:
        print("not using wandb", file=sys.stderr)

    # init model
    embed_net = ensemble_net.EnsembleNet(num_models=opts.num_models,
                                         dense_kernel_size=opts.dense_kernel_size,
                                         embedding_dim=opts.embedding_dim,
                                         seed=opts.seed,
                                         orthogonal_init=ortho_init)

    # make some jitted version of embed_net methods
    # TODO: objax decorator to do this?
    j_embed_net_calc_sims = objax.Jit(embed_net.calc_sims, embed_net.vars())
    j_embed_net_loss = objax.Jit(embed_net.loss, embed_net.vars())

    # run pair of crops through model and calculate optimal pairing
    def labels_for_crops(crops_t0, crops_t1):
        sims = j_embed_net_calc_sims(crops_t0, crops_t1)
        pairing = optimal_pairing.calculate(sims)
        labels = optimal_pairing.to_one_hot_labels(sims, pairing)
        return labels

    # init optimiser
    gradient_loss = objax.GradValues(embed_net.loss, embed_net.vars())
    optimiser = objax.optimizer.Adam(embed_net.vars())
    lr = 1e-3

    # create a jitted training step
    def train_step(crops_t0, crops_t1, labels):
        grads, loss = gradient_loss(crops_t0, crops_t1, labels)
        optimiser(lr, grads)
        return loss

    train_step = objax.Jit(train_step, gradient_loss.vars() + optimiser.vars())

    # run training loop
    for e in range(opts.epochs):
        # shuffle training examples
        orandom.seed(e)
        orandom.shuffle(train_frame_pairs)

        # make pass through training examples
        train_losses = []
        for f0, f1 in tqdm(train_frame_pairs):
            try:
                # load
                crops_t0 = img_utils.load_crops_as_floats(f"{f0}/crops.npy")
                crops_t1 = img_utils.load_crops_as_floats(f"{f1}/crops.npy")
                # calc labels for crops
                labels = labels_for_crops(crops_t0, crops_t1)
                # take update step
                loss = train_step(crops_t0, crops_t1, labels)
                train_losses.append(loss)
            except Exception:
                print("train exception", e, f0, f1, file=sys.stderr)
                traceback.print_exc(file=sys.stderr)

        # eval mean test loss
        test_losses = []
        for f0, f1 in test_frame_pairs:
            try:
                # load
                crops_t0 = img_utils.load_crops_as_floats(f"{f0}/crops.npy")
                crops_t1 = img_utils.load_crops_as_floats(f"{f1}/crops.npy")
                # calc labels for crops
                labels = labels_for_crops(crops_t0, crops_t1)
                # collect loss
                loss = j_embed_net_loss(crops_t0, crops_t1, labels)
                test_losses.append(loss)
            except Exception:
                print("test exception", e, f0, f1, file=sys.stderr)
                traceback.print_exc(file=sys.stderr)

        # log stats
        mean_train_loss = np.mean(train_losses)
        mean_test_loss = np.mean(test_losses)
        print(e, mean_train_loss, mean_test_loss)

        # TODO: or only if train_loss is Nan?
        nan_loss = np.isnan(mean_train_loss) or np.isnan(mean_test_loss)
        if wandb_enabled and not nan_loss:
            wandb.log({'train_loss': np.mean(train_losses)})
            wandb.log({'test_loss': np.mean(test_losses)})

    # close out wandb run
    wandb.join()

    # note: use None value to indicate run failed
    if nan_loss:
        return None
    else:
        return mean_test_loss
Example #27
# ==========================
#  Loss Function
# ==========================
@objax.Function.with_vars(gf_model.vars())
def nll_loss(x):
    return gf_model.score(x)


# =========================
# Optimizer
# =========================
# define the optimizer
opt = objax.optimizer.Adam(gf_model.vars())

# get grad values
gv = objax.GradValues(nll_loss, gf_model.vars())
lr = wandb_logger.config.learning_rate
epochs = wandb_logger.config.epochs
batchsize = wandb_logger.config.batchsize


# define the training operation
@objax.Function.with_vars(gf_model.vars() + opt.vars())
def train_op(x):
    g, v = gv(x)  # returns gradients, loss
    opt(lr, g)
    return v


# This line is optional: it is compiling the code to make it faster.
train_op = objax.Jit(train_op)
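The example ends after compiling train_op; below is a hedged sketch of how the compiled op might be driven, assuming a hypothetical train_data array and reusing the epochs and batchsize read from the wandb config above (np is NumPy):

for epoch in range(epochs):
    perm = np.random.permutation(train_data.shape[0])
    epoch_losses = []
    for it in range(0, train_data.shape[0], batchsize):
        v = train_op(train_data[perm[it:it + batchsize]])  # list of nll_loss outputs
        epoch_losses.append(float(v[0]))
    print('epoch %d: mean nll %.4f' % (epoch, np.mean(epoch_losses)))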
Example #28
 def __init__(self, model, *args, **kwargs):
     super().__init__(model, *args, **kwargs)
     self.loss = objax.Jit(self.loss, model.vars())
     #self.model = objax.Jit(self.model)
      self.gradvals = objax.Jit(
          objax.GradValues(self.loss, model.vars()))  # objax.Jit(objax.GradValues(fastloss, model.vars()), model.vars())
Example #29
lr = 0.0001  # learning rate
batch = 256
epochs = 20

# Model
model = objax.nn.Linear(ndim, 1)
opt = objax.optimizer.SGD(model.vars())
print(model.vars())


# Cross Entropy Loss
def loss(x, label):
    return objax.functional.loss.sigmoid_cross_entropy_logits(model(x)[:, 0], label).mean()


gv = objax.GradValues(loss, model.vars())


def train_op(x, label):
    g, v = gv(x, label)  # returns gradients, loss
    opt(lr, g)
    return v


# This line is optional: it is compiling the code to make it faster.
# gv.vars() contains the model variables.
train_op = objax.Jit(train_op, gv.vars() + opt.vars())

# Training
for epoch in range(epochs):
    # Train
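    # --- Hedged sketch, not part of the original example: the listing is cut off
    # here, so the rest of the loop body is illustrative. `data` and `labels` are
    # hypothetical arrays that the truncated part of the example would define.
    avg_loss = 0
    sel = np.arange(len(data))
    np.random.shuffle(sel)
    for it in range(0, len(data), batch):
        v = train_op(data[sel[it:it + batch]], labels[sel[it:it + batch]])
        avg_loss += float(v[0]) * min(batch, len(data) - it)
    print('Epoch %04d  Loss %.2f' % (epoch + 1, avg_loss / len(data)))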
Example #30
                                         spatial_kernel=kern_space,
                                         z=R[0],
                                         sparse=True,
                                         opt_z=False,
                                         conditional='Full')
inf = newt.inference.VariationalInference()

markov = True

if markov:
    model = newt.models.MarkovGP(kernel=kern, likelihood=lik, X=t, R=R, Y=Y)
    # model = newt.models.MarkovGP(kernel=kern, likelihood=lik, X=X, Y=y)
else:
    model = newt.models.GP(kernel=kern, likelihood=lik, X=X, Y=y)

compute_energy_and_update = objax.GradValues(inf, model.vars())

lr_adam = 0.
lr_newton = 1.
epochs = 2
opt = objax.optimizer.Adam(model.vars())


def train_op():
    model.update_posterior()
    grads, loss_ = compute_energy_and_update(model, lr=lr_newton)
    # print(grads)
    for g, var_name in zip(grads, model.vars().keys()):  # TODO: this gives wrong label to likelihood variance
        print(g, ' w.r.t. ', var_name)
    # print(model.kernel.temporal_kernel.variance)
    opt(lr_adam, grads)