def test_gradvalues_constant(self):
    """Test if constants are preserved."""
    # Toy 1-D linear-regression data.
    features = np.array([[1.0], [3.0], [5.0], [-10.0]])
    targets = np.array([1.0, 2.0, 3.0, 4.0])
    # Model parameters: only w is differentiated; b acts as a constant.
    w = objax.TrainVar(jn.zeros(1))
    b = objax.TrainVar(jn.ones(1))
    module_list = objax.ModuleList([w, b])

    def loss(x, y):
        residual = y - (jn.dot(x, w.value) + b.value)
        return 0.5 * (residual ** 2).mean()

    def assert_grad_tracks_constant(grad_fn):
        # Flipping the sign of b (a non-differentiated constant) must
        # change the gradient reported for w.
        before, _ = grad_fn(features, targets)
        b.assign(-b.value)
        after, _ = grad_fn(features, targets)
        self.assertNotEqual(before[0][0], after[0][0])

    # Plain GradValues must see the updated constant...
    assert_grad_tracks_constant(objax.GradValues(loss, objax.VarCollection({'w': w})))
    # ...and so must the Jit-compiled version.
    assert_grad_tracks_constant(
        objax.Jit(objax.GradValues(loss, objax.VarCollection({'w': w})), module_list.vars()))
def test_gradvalues_linear_and_inputs(self):
    """Test if gradient of inputs and variables has the correct values for linear regression."""
    # Set data
    data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    labels = np.array([1.0, 2.0, 3.0, 4.0])
    # Set model parameters for linear regression.
    w = objax.TrainVar(jn.array([2, 3], jn.float32))
    b = objax.TrainVar(jn.array([1], jn.float32))

    def loss(x, y):
        # Mean squared error of the linear model w.x + b.
        pred = jn.dot(x, w.value) + b.value
        return 0.5 * ((y - pred) ** 2).mean()

    expect_loss = loss(data, labels)
    # Hand-computed reference gradients for the fixed data above.
    expect_gw = [37.25, 69.0]
    expect_gb = [13.75]
    expect_gx = [[4.0, 6.0], [8.5, 12.75], [13.0, 19.5], [2.0, 3.0]]
    expect_gy = [-2.0, -4.25, -6.5, -1.0]
    # input_argnums=(0,): gradient w.r.t. input x comes first, then the variables.
    gv0 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0,))
    g, v = gv0(data, labels)
    self.assertEqual(v[0], expect_loss)
    self.assertEqual(g[0].tolist(), expect_gx)
    self.assertEqual(g[1].tolist(), expect_gw)
    self.assertEqual(g[2].tolist(), expect_gb)
    # input_argnums=(1,): gradient w.r.t. input y comes first.
    gv1 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(1,))
    g, v = gv1(data, labels)
    self.assertEqual(v[0], expect_loss)
    self.assertEqual(g[0].tolist(), expect_gy)
    self.assertEqual(g[1].tolist(), expect_gw)
    self.assertEqual(g[2].tolist(), expect_gb)
    # input_argnums=(0, 1): input gradients are returned in the order given.
    gv01 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(0, 1))
    g, v = gv01(data, labels)
    self.assertEqual(v[0], expect_loss)
    self.assertEqual(g[0].tolist(), expect_gx)
    self.assertEqual(g[1].tolist(), expect_gy)
    self.assertEqual(g[2].tolist(), expect_gw)
    self.assertEqual(g[3].tolist(), expect_gb)
    # input_argnums=(1, 0): reversed order must be respected.
    gv10 = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=(1, 0))
    g, v = gv10(data, labels)
    self.assertEqual(v[0], expect_loss)
    self.assertEqual(g[0].tolist(), expect_gy)
    self.assertEqual(g[1].tolist(), expect_gx)
    self.assertEqual(g[2].tolist(), expect_gw)
    self.assertEqual(g[3].tolist(), expect_gb)
    # variables=None: only input gradients are returned.
    gv10 = objax.GradValues(loss, None, input_argnums=(0, 1))
    g, v = gv10(data, labels)
    self.assertEqual(v[0], expect_loss)
    self.assertEqual(g[0].tolist(), expect_gx)
    self.assertEqual(g[1].tolist(), expect_gy)
def test_gradient_step(var_f, len_f, var_y, N):
    """
    test whether VI with newt's GP and MarkovGP provide the same initial gradient step in the hyperparameters
    """
    x, Y, t, R, y = build_data(N)
    gp_model = initialise_gp_model(var_f, len_f, var_y, x, y, R[0])
    markovgp_model = initialise_markovgp_model(var_f, len_f, var_y, x, y, R[0])
    # Gradients of the inference objective w.r.t. each model's variables.
    gv = objax.GradValues(inf, gp_model.vars())
    gv_markov = objax.GradValues(inf, markovgp_model.vars())
    lr_adam = 0.1
    lr_newton = 1.
    opt = objax.optimizer.Adam(gp_model.vars())
    opt_markov = objax.optimizer.Adam(markovgp_model.vars())
    # One inference + gradient + Adam step for the GP model.
    gp_model.update_posterior()
    gp_grads, gp_value = gv(gp_model, lr=lr_newton)
    gp_loss_ = gp_value[0]
    opt(lr_adam, gp_grads)
    gp_hypers = np.array([gp_model.kernel.temporal_kernel.lengthscale,
                          gp_model.kernel.temporal_kernel.variance,
                          gp_model.kernel.spatial_kernel.lengthscale,
                          gp_model.likelihood.variance])
    print(gp_hypers)
    print(gp_grads)
    # The identical single step for the MarkovGP model.
    markovgp_model.update_posterior()
    markovgp_grads, markovgp_value = gv_markov(markovgp_model, lr=lr_newton)
    markovgp_loss_ = markovgp_value[0]
    opt_markov(lr_adam, markovgp_grads)
    markovgp_hypers = np.array([markovgp_model.kernel.temporal_kernel.lengthscale,
                                markovgp_model.kernel.temporal_kernel.variance,
                                markovgp_model.kernel.spatial_kernel.lengthscale,
                                markovgp_model.likelihood.variance])
    print(markovgp_hypers)
    print(markovgp_grads)
    # Both parameterisations should yield (near) identical first gradients.
    np.testing.assert_allclose(gp_grads[0], markovgp_grads[0], rtol=1e-4)
    np.testing.assert_allclose(gp_grads[1], markovgp_grads[1], rtol=1e-4)
    np.testing.assert_allclose(gp_grads[2], markovgp_grads[2], rtol=1e-4)
def test_gradient_step(var_f, len_f, var_y, N):
    """Check MarkovGP and SparseMarkovGP take the same initial gradient step
    in the hyperparameters when the inducing inputs sit near the data (Z=X)."""
    x, y = build_data(N)
    # Inducing inputs: data locations plus a little jitter.
    z = x + np.random.normal(0, .05, x.shape)
    dense = initialise_gp_model(var_f, len_f, var_y, x, y)
    sparse = initialise_sparsemarkovgp_model(var_f, len_f, var_y, x, y, z)

    grad_dense = objax.GradValues(inf, dense.vars())
    grad_sparse = objax.GradValues(inf, sparse.vars())
    lr_adam = 0.1
    lr_newton = 1.
    adam_dense = objax.optimizer.Adam(dense.vars())
    adam_sparse = objax.optimizer.Adam(sparse.vars())

    # One inference + gradient + optimiser step on the dense model.
    dense.update_posterior()
    g_dense, _ = grad_dense(dense, lr=lr_newton)
    adam_dense(lr_adam, g_dense)
    print(np.array([dense.kernel.lengthscale, dense.kernel.variance, dense.likelihood.variance]))
    print(g_dense)

    # The identical step on the sparse model.
    sparse.update_posterior()
    g_sparse, _ = grad_sparse(sparse, lr=lr_newton)
    adam_sparse(lr_adam, g_sparse)
    print(np.array([sparse.kernel.lengthscale, sparse.kernel.variance, sparse.likelihood.variance]))
    print(g_sparse)

    np.testing.assert_allclose(g_dense[0], g_sparse[0], atol=3e-1)
    np.testing.assert_allclose(g_dense[1], g_sparse[1], rtol=5e-2)
    np.testing.assert_allclose(g_dense[2], g_sparse[2], rtol=5e-2)
def test_gradvalues_linear(self):
    """Gradient of the linear-regression loss matches the analytic value."""
    ndim = 2
    features = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    targets = np.array([1.0, 2.0, 3.0, 4.0])
    # Zero-initialised linear model: pred = w.x + b.
    w = objax.TrainVar(jn.zeros(ndim))
    b = objax.TrainVar(jn.zeros(1))

    def loss(x, y):
        residual = y - (jn.dot(x, w.value) + b.value)
        return 0.5 * (residual ** 2).mean()

    grads, values = objax.GradValues(loss, objax.VarCollection({'w': w, 'b': b}))(features, targets)
    self.assertEqual(grads[0].shape, (ndim,))
    self.assertEqual(grads[1].shape, (1,))
    # With w = b = 0 the gradient reduces to -mean(y * x) and -mean(y).
    expected_w = -(features * np.tile(targets, (ndim, 1)).transpose()).mean(0)
    expected_b = np.array([-targets.mean()])
    np.testing.assert_allclose(grads[0], expected_w)
    np.testing.assert_allclose(grads[1], expected_b)
    np.testing.assert_allclose(values[0], loss(features, targets))
def __init__(self):
    """Build the model, the parallel train/eval ops and the summary writer."""
    # Some constants
    total_batch_size = FLAGS.train_device_batch_size * jax.device_count()
    # Linear learning-rate scaling rule, relative to a base batch size of 256.
    self.base_learning_rate = FLAGS.base_learning_rate * total_batch_size / 256
    # Create model
    bn_cls = objax.nn.SyncedBatchNorm2D if FLAGS.use_sync_bn else objax.nn.BatchNorm2D
    self.model = ResNet50(in_channels=3, num_classes=NUM_CLASSES, normalization_fn=bn_cls)
    self.model_vars = self.model.vars()
    print(self.model_vars)
    # Create parallel eval op: per-device results are summed across devices.
    self.evaluate_batch_parallel = objax.Parallel(self.evaluate_batch, self.model_vars,
                                                  reduce=lambda x: x.sum(0))
    # Create parallel training op
    self.optimizer = objax.optimizer.Momentum(self.model_vars, momentum=0.9, nesterov=True)
    self.compute_grads_loss = objax.GradValues(self.loss_fn, self.model_vars)
    self.all_vars = self.model_vars + self.optimizer.vars()
    # reduce=lambda x: x[0]: every device holds the same value, keep the first.
    self.train_op_parallel = objax.Parallel(self.train_op, self.all_vars, reduce=lambda x: x[0])
    # Summary writer
    self.summary_writer = objax.jaxboard.SummaryWriter(os.path.join(FLAGS.model_dir, 'tb'))
def __init__(self, model, dataloaders, optim=objax.optimizer.Adam,
             lr_sched=lambda e: 1, log_dir=None, log_suffix='', log_args={},
             early_stop_metric=None):
    """Wire up the model, optimizer, dataloaders and logging for a training run."""
    self.model = model
    self.optimizer = optim(model.vars())
    self.lr_sched = lr_sched
    # A dictionary of dataloaders
    self.dataloaders = dataloaders
    self.epoch = 0
    self.logger = LazyLogger(log_dir, log_suffix, **log_args)
    self.hypers = {}
    self.ckpt = None  # TODO fix model saving
    self.early_stop_metric = early_stop_metric
    self.gradvals = objax.GradValues(self.loss, self.model.vars())
def test_private_gradvalues_compare_nonpriv(self):
    """Test if PrivateGradValues without clipping / noise is the same as non-private GradValues."""
    # With a huge clipping norm and zero noise, DP-SGD degenerates to plain gradients.
    l2_norm_clip = 1e10
    noise_multiplier = 0
    gv = objax.GradValues(self.loss, self.model_vars)
    for use_norm_accumulation in (True, False):
        for microbatch in (1, 10, self.ntrain):
            gv_priv = objax.Jit(objax.privacy.dpsgd.PrivateGradValues(
                self.loss, self.model_vars, noise_multiplier, l2_norm_clip,
                microbatch, batch_axis=(0, 0),
                use_norm_accumulation=use_norm_accumulation))
            g_priv, v_priv = gv_priv(self.data, self.labels)
            g, v = gv(self.data, self.labels)
            # Gradient shapes must match the model parameters.
            self.assertEqual(g_priv[0].shape, (self.nclass,))
            self.assertEqual(g_priv[1].shape, (self.ndim, self.nclass))
            # Private gradients and loss agree with the non-private ones.
            np.testing.assert_allclose(g[0], g_priv[0], atol=1e-7)
            np.testing.assert_allclose(g[1], g_priv[1], atol=1e-7)
            np.testing.assert_allclose(v_priv[0], self.loss(self.data, self.labels)[0], atol=1e-7)
def test_typical_training_loop(self):
    """find_used_variables discovers exactly the variables each closure touches."""
    # Define model and optimizer
    model = DNNet((32, 10), objax.functional.leaky_relu)
    opt = objax.optimizer.Momentum(model.vars(), nesterov=True)
    # Predict op: uses only the model's variables.
    predict_op = lambda x: objax.functional.softmax(model(x, training=False))
    self.assertDictEqual(objax.util.find_used_variables(predict_op), model.vars(scope='model.'))

    # Loss function: also uses only the model's variables.
    def loss(x, label):
        logit = model(x, training=True)
        xe_loss = objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()
        return xe_loss

    self.assertDictEqual(objax.util.find_used_variables(loss), model.vars(scope='model.'))
    # Gradients and loss function
    loss_gv = objax.GradValues(loss, objax.util.find_used_variables(loss))

    def train_op(x, y, learning_rate):
        grads, loss = loss_gv(x, y)
        opt(learning_rate, grads)
        return loss

    # train_op reaches the model vars through loss_gv plus the optimizer state;
    # the reported scopes mirror the closure variable names ('loss_gv.', 'opt.').
    self.assertDictEqual(objax.util.find_used_variables(train_op),
                         {**model.vars(scope='loss_gv.model.'), **opt.vars(scope='opt.')})
def __init__(self, model, *args, **kwargs):
    """Jit-compile the loss, its gradients and the prediction path."""
    super().__init__(model, *args, **kwargs)
    model_vars = model.vars()
    jitted_loss = objax.Jit(self.loss, model_vars)
    self.gradvals = objax.Jit(objax.GradValues(jitted_loss, model_vars), model_vars)
    # Predictions always run with training=False.
    self.model.predict = objax.Jit(objax.ForceArgs(model.__call__, training=False), model_vars)
def test_gradvalues_signature(self):
    """GradValues exposes the signature (gradients of f, values of f)."""
    def f(x: JaxArray, y) -> Tuple[JaxArray, Dict[str, JaxArray]]:
        return (x + y).mean(), {'x': x, 'y': y}

    def df(x: JaxArray, y) -> Tuple[List[JaxArray], Tuple[JaxArray, Dict[str, JaxArray]]]:
        pass  # Signature of the (differential of f, f)

    grad_fn = objax.GradValues(f, objax.VarCollection())
    self.assertEqual(inspect.signature(grad_fn), inspect.signature(df))
def _test_loss_opt(self, loss_name: str, opt_name: str):
    """Given loss and optimizer name, get definitions and run test."""
    model_vars, loss = self._get_loss(loss_name)
    gv = objax.GradValues(loss, model_vars)
    opt = self._get_optimizer(model_vars, opt_name)
    # Learning rate and tolerance are keyed by the loss/optimizer pair.
    key = '{}_{}'.format(loss_name, opt_name)
    self._check_run(gv, opt, loss, self.lrs[key], self.num_steps, self.tolerances[key])
    return model_vars, loss
def __init__(self, nclass: int, model: Callable, **kwargs):
    """Build the semi-supervised (FixMatch-style) training and eval ops."""
    super().__init__(nclass, **kwargs)
    self.model = model(nin=3, nclass=nclass, **kwargs)
    model_vars = self.model.vars()
    # EMA of the trainable weights, swapped in at evaluation time.
    self.ema = objax.optimizer.ExponentialMovingAverage(model_vars.subset(objax.TrainVar),
                                                        momentum=0.999, debias=False)
    self.opt = objax.optimizer.Momentum(model_vars, momentum=0.9, nesterov=True)

    def loss_function(x, u, y):
        # One forward pass over labeled (x) and flattened unlabeled (u) batches.
        xu = jn.concatenate([x, u.reshape((-1, ) + u.shape[2:])], axis=0)
        logit = self.model(xu, training=True)
        logit_x = logit[:x.shape[0]]
        # Unlabeled logits alternate weak / strong augmentations.
        logit_weak = logit[x.shape[0]::2]
        logit_strong = logit[x.shape[0] + 1::2]
        xe = objax.functional.loss.cross_entropy_logits(logit_x, y).mean()
        # Pseudo-labels from the weak view; used only above the confidence threshold.
        pseudo_labels = objax.functional.stop_gradient(objax.functional.softmax(logit_weak))
        pseudo_mask = (pseudo_labels.max(axis=1) >= self.params.confidence).astype(logit_weak.dtype)
        xeu = objax.functional.loss.cross_entropy_logits_sparse(logit_strong,
                                                                pseudo_labels.argmax(axis=1))
        xeu = (xeu * pseudo_mask).mean()
        # Weight decay over weight matrices only ('.w' variables).
        wd = 0.5 * sum(objax.functional.loss.l2(v.value)
                       for k, v in model_vars.items() if k.endswith('.w'))
        loss = xe + self.params.wu * xeu + self.params.wd * wd
        return loss, {'losses/xe': xe, 'losses/xeu': xeu, 'losses/wd': wd,
                      'monitors/mask': pseudo_mask.mean()}

    gv = objax.GradValues(loss_function, model_vars)

    def train_op(step, x, u, y):
        g, v = gv(x, u, y)
        # Fraction of training completed, in kilo-images.
        fstep = step[0] / (FLAGS.train_kimg << 10)
        # Cosine learning-rate decay.
        lr = self.params.lr * jn.cos(fstep * (7 * jn.pi) / (2 * 8))
        self.opt(lr, objax.functional.parallel.pmean(g))
        self.ema()
        return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

    # Evaluate with the EMA weights swapped in.
    eval_op = self.ema.replace_vars(lambda x: objax.functional.softmax(self.model(x, training=False)))
    self.eval_op = objax.Parallel(eval_op, model_vars + self.ema.vars())
    self.train_op = objax.Parallel(train_op, self.vars(), reduce=lambda x: x[0])
def _test_loss_opt(self, loss_name: str, opt_name: str, override: bool = False):
    """Given loss and optimizer name, get definitions and run test."""
    model_vars, loss = self._get_loss(loss_name)
    gv = objax.GradValues(loss, model_vars)
    opt = self._get_optimizer(model_vars, opt_name)
    # Override runs use a '_override'-suffixed key for all lookups.
    key = '{}_{}'.format(loss_name, opt_name)
    if override:
        key = key + '_override'
    options = self.override_options.get(key) if override else None
    self._check_run(gv, opt, loss, self.lrs[key], self.num_steps, self.tolerances[key], options)
    return model_vars, loss
def test_parallel_train_op(self):
    """Parallel train op."""
    # Small two-class toy problem: two unit-norm cluster centers in 3-D.
    f = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.functional.relu, objax.nn.Linear(4, 2)])
    centers = objax.random.normal((1, 2, 3))
    centers *= objax.functional.rsqrt((centers ** 2).sum(2, keepdims=True))
    x = (objax.random.normal((256, 1, 3), stddev=0.1) + centers).reshape((512, 3))
    y = jn.concatenate([jn.zeros((256, 1), dtype=jn.uint32),
                        jn.ones((256, 1), dtype=jn.uint32)], axis=1)
    y = y.reshape((512,))
    opt = objax.optimizer.Momentum(f.vars())
    all_vars = f.vars('f') + opt.vars('opt')

    def loss(x, y):
        xe = objax.functional.loss.cross_entropy_logits_sparse(f(x), y)
        return xe.mean()

    gv = objax.GradValues(loss, f.vars())

    def train_op(x, y):
        g, v = gv(x, y)
        opt(0.05, g)
        return v

    # Run 10 sequential training steps, record losses/variables, then restore
    # the initial variable state so the parallel run starts from the same point.
    tensors = all_vars.tensors()
    loss_value = np.array([train_op(x, y)[0] for _ in range(10)])
    var_values = {k: v.value for k, v in all_vars.items()}
    all_vars.assign(tensors)
    self.assertGreater(loss_value.min(), 0)

    def train_op_para(x, y):
        # Same step, but gradients and values are averaged across devices.
        g, v = gv(x, y)
        opt(0.05, objax.functional.parallel.pmean(g))
        return objax.functional.parallel.pmean(v)

    fp = objax.Parallel(train_op_para, vc=all_vars, reduce=lambda x: x[0])
    with all_vars.replicate():
        loss_value_p = np.array([fp(x, y)[0] for _ in range(10)])
    var_values_p = {k: v.value for k, v in all_vars.items()}
    # Parallel training must match sequential training: losses and variables.
    self.assertLess(jn.abs(loss_value_p / loss_value - 1).max(), 1e-6)
    for k, v in var_values.items():
        self.assertLess(((v - var_values_p[k]) ** 2).sum(), 1e-12, msg=k)
def test_gradient_step(var_f, len_f, var_y, N):
    """
    test whether newt's VI and gpflow's SVGP (Z=X) provide the same initial gradient step in the hyperparameters
    """
    x, y = build_data(N)
    newt_model = initialise_newt_model(var_f, len_f, var_y, x, y)
    gpflow_model = initialise_gpflow_model(var_f, len_f, var_y, x, y)
    # Gradients of the variational energy w.r.t. newt's hyperparameters.
    gv = objax.GradValues(inf.energy, newt_model.vars())
    lr_adam = 0.1
    lr_newton = 1.
    opt = objax.optimizer.Adam(newt_model.vars())
    # One inference + gradient + Adam step for the newt model.
    newt_model.update_posterior()
    newt_grads, value = gv(newt_model)  # , lr=lr_newton)
    loss_ = value[0]
    opt(lr_adam, newt_grads)
    newt_hypers = np.array([newt_model.kernel.lengthscale,
                            newt_model.kernel.variance,
                            newt_model.likelihood.variance])
    print(newt_hypers)
    print(newt_grads)
    # Matching step with gpflow: gradient of the negative ELBO, then one Adam step.
    adam_opt = tf.optimizers.Adam(lr_adam)
    data = (x, y[:, None])
    with tf.GradientTape() as tape:
        loss = -gpflow_model.elbo(data)
    _vars = gpflow_model.trainable_variables
    gpflow_grads = tape.gradient(loss, _vars)
    loss_fn = gpflow_model.training_loss_closure(data)
    adam_vars = gpflow_model.trainable_variables
    adam_opt.minimize(loss_fn, adam_vars)
    gpflow_hypers = np.array([gpflow_model.kernel.lengthscales.numpy()[0],
                              gpflow_model.kernel.variance.numpy(),
                              gpflow_model.likelihood.variance.numpy()])
    print(gpflow_hypers)
    print(gpflow_grads)
    np.testing.assert_allclose(newt_grads[0], gpflow_grads[0], atol=1e-2)  # use atol since values are so small
    np.testing.assert_allclose(newt_grads[1], gpflow_grads[1], rtol=1e-2)
    np.testing.assert_allclose(newt_grads[2], gpflow_grads[2], rtol=1e-2)
def build_train_step_fn(model, loss_fn):
    """Return a (possibly jitted) training step that applies a clipped Adam
    update and reports per-tensor gradient norms."""
    grad_loss = objax.GradValues(loss_fn, model.vars())
    adam = objax.optimizer.Adam(model.vars())

    def train_step(learning_rate, rgb_img_t1, true_dither_t0, true_dither_t1):
        grads, _ = grad_loss(rgb_img_t1, true_dither_t0, true_dither_t1)
        grads = u.clip_gradients(grads, theta=opts.gradient_clip)
        adam(learning_rate, grads)
        # Gradient norms are returned for monitoring.
        return [jnp.linalg.norm(g) for g in grads]

    if JIT:
        train_step = objax.Jit(train_step, grad_loss.vars() + adam.vars())
    return train_step
def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
    """ Completely standard training. Nothing interesting to see here. """
    super().__init__(nclass, **kwargs)
    # MNIST is single-channel; everything else is RGB.
    self.model = model(1 if mnist else 3, nclass)
    self.opt = objax.optimizer.Momentum(self.model.vars())
    # EMA copy of the model, used for prediction.
    self.model_ema = objax.optimizer.ExponentialMovingAverageModule(self.model,
                                                                    momentum=0.999, debias=True)

    @objax.Function.with_vars(self.model.vars())
    def loss(x, label):
        logit = self.model(x, training=True)
        # Weight decay over weight matrices only ('.w' variables).
        loss_wd = 0.5 * sum((v.value ** 2).sum()
                            for k, v in self.model.vars().items() if k.endswith('.w'))
        loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
        return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe,
                                                              'losses/wd': loss_wd}

    gv = objax.GradValues(loss, self.model.vars())
    self.gv = gv

    @objax.Function.with_vars(self.vars())
    def train_op(progress, x, y):
        g, v = gv(x, y)
        # Cosine learning-rate decay with a short linear warm-up.
        lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
        lr = lr * jn.clip(progress * 100, 0, 1)
        self.opt(lr, g)
        self.model_ema.update_ema()
        return {'monitors/lr': lr, **v[1]}

    # Predictions run against the EMA model with training=False.
    self.predict = objax.Jit(objax.nn.Sequential([objax.ForceArgs(self.model_ema, training=False)]))
    self.train_op = objax.Jit(train_op)
def __init__(self, model: Callable, nclass: int, **kwargs):
    """Build parallel (multi-device) training and prediction ops."""
    super().__init__(nclass, **kwargs)
    self.model = model(3, nclass)
    model_vars = self.model.vars()
    self.opt = objax.optimizer.Momentum(model_vars)
    # EMA of the weights, swapped in for prediction.
    self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)
    print(model_vars)

    def loss(x, label):
        logit = self.model(x, training=True)
        # Weight decay over weight matrices only ('.w' variables).
        loss_wd = 0.5 * sum((v.value ** 2).sum()
                            for k, v in model_vars.items() if k.endswith('.w'))
        loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
        return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe,
                                                              'losses/wd': loss_wd}

    gv = objax.GradValues(loss, model_vars)

    def train_op(progress, x, y):
        g, v = gv(x, y)
        # Cosine learning-rate schedule; gradients averaged across devices.
        lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
        self.opt(lr, objax.functional.parallel.pmean(g))
        self.ema()
        return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

    def predict_op(x):
        return objax.functional.softmax(self.model(x, training=False))

    # Prediction uses the EMA weights.
    self.predict = objax.Parallel(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
    self.train_op = objax.Parallel(train_op, self.vars(), reduce=lambda x: x[0])
def test_gradvalues_logistic(self):
    """Test if gradient has the correct value for logistic regression."""
    ndim = 2
    features = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    targets = np.array([1.0, -1.0, 1.0, -1.0])
    w = objax.TrainVar(jn.ones(ndim))

    def loss(x, y):
        # Logistic loss: mean(log(1 + exp(-y * x.w))).
        xyw = jn.dot(x * np.tile(y, (ndim, 1)).transpose(), w.value)
        return jn.log(jn.exp(-xyw) + 1).mean(0)

    g, v = objax.GradValues(loss, objax.VarCollection({'w': w}))(features, targets)
    self.assertEqual(g[0].shape, (ndim,))
    # Analytic gradient: -mean(y * x / (1 + exp(y * x.w))).
    xw = np.dot(features, w.value)
    expected = -(features * np.tile(targets / (1 + np.exp(targets * xw)),
                                    (ndim, 1)).transpose()).mean(0)
    np.testing.assert_allclose(g[0], expected, atol=1e-7)
    np.testing.assert_allclose(v[0], loss(features, targets))
def test_transform(self):
    """repr() of each transform reports its configuration."""
    def myloss(x):
        return (x ** 2).mean()

    g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(),
                                                noise_multiplier=1., l2_norm_clip=0.5, microbatch=1)
    self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gvp),
                     'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
                     ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
    # Nested transforms include the repr of the wrapped transform.
    self.assertEqual(repr(objax.Jit(gv)),
                     'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
    self.assertEqual(repr(objax.Jit(myloss, vc=objax.VarCollection())),
                     'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
    self.assertEqual(repr(objax.Parallel(gv)),
                     "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
                     " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
    self.assertEqual(repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
                     'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
    self.assertEqual(repr(objax.ForceArgs(gv, training=True, word='hello')),
                     "objax.ForceArgs(module=GradValues, training=True, word='hello')")
def __init__(self, model: Callable, nclass: int, **kwargs):
    """Build parallel training and prediction ops with an EMA model."""
    super().__init__(nclass, **kwargs)
    self.model = model(3, nclass)
    self.opt = objax.optimizer.Momentum(self.model.vars())
    # EMA copy of the model, used for prediction.
    self.model_ema = objax.optimizer.ExponentialMovingAverageModule(self.model,
                                                                    momentum=0.999, debias=True)

    @objax.Function.with_vars(self.model.vars())
    def loss(x, label):
        logit = self.model(x, training=True)
        # Weight decay over weight matrices only ('.w' variables).
        loss_wd = 0.5 * sum((v.value ** 2).sum()
                            for k, v in self.model.vars().items() if k.endswith('.w'))
        loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
        return loss_xe + loss_wd * self.params.weight_decay, {'losses/xe': loss_xe,
                                                              'losses/wd': loss_wd}

    gv = objax.GradValues(loss, self.model.vars())

    @objax.Function.with_vars(self.vars())
    def train_op(progress, x, y):
        g, v = gv(x, y)
        # Cosine learning-rate decay; gradients averaged across devices.
        lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
        self.opt(lr, objax.functional.parallel.pmean(g))
        self.model_ema.update_ema()
        return objax.functional.parallel.pmean({'monitors/lr': lr, **v[1]})

    # Prediction: EMA model, training=False, softmax output.
    self.predict = objax.Parallel(objax.nn.Sequential([
        objax.ForceArgs(self.model_ema, training=False), objax.functional.softmax]))
    self.train_op = objax.Parallel(train_op, reduce=lambda x: x[0])
]) source = jn.linspace(-5, 5, 100).reshape((100, 1)) # (k, 1) target = jn.sin(source) print('Standard training.') net = make_net() opt = objax.optimizer.Adam(net.vars()) def loss(x, y): return ((y - net(x))**2).mean() gv = objax.GradValues(loss, net.vars()) def train_op(): g, v = gv(source, target) opt(0.01, g) return v train_op = objax.Jit(train_op, gv.vars() + opt.vars()) for i in range(100): train_op() plt.plot(source, net(source), label='prediction') plt.plot(source, (target - net(source))**2, label='loss')
R=r, Y=Y) elif model_type == 2: model = newt.models.InfiniteHorizonGP(kernel=kern, likelihood=lik, X=t, R=r, Y=Y) print('num spatial pts:', nr) print(model) inf = newt.inference.VariationalInference(cubature=newt.cubature.Unscented()) trainable_vars = model.vars() + inf.vars() energy = objax.GradValues(inf.energy, trainable_vars) lr_adam = 0.2 lr_newton = 0.2 iters = 100 opt = objax.optimizer.Adam(trainable_vars) def train_op(): inf(model, lr=lr_newton) # perform inference and update variational params dE, E = energy(model) # compute energy and its gradients w.r.t. hypers return dE, E train_op = objax.Jit(train_op, trainable_vars)
def train_model():
    """ Train the patch similarity function """
    global ema, model
    model = Model()

    def loss(x, y):
        """
        K-way contrastive loss as in SimCLR et al.

        The idea is that we should embed x and y so that
        they are similar to each other, and dis-similar from others. To do this
        we have a softmax loss over one dimension to make the values large on
        the diagonal and small off-diagonal.
        """
        a = model.encode(x)
        b = model.decode(y)
        mat = a @ b.T
        return objax.functional.loss.cross_entropy_logits_sparse(
            logits=jn.exp(jn.clip(model.scale.w.value, -2, 4)) * mat,
            labels=np.arange(a.shape[0])).mean()

    ema = objax.optimizer.ExponentialMovingAverage(model.vars(), momentum=0.999)
    gv = objax.GradValues(loss, model.vars())
    # Evaluation paths use the EMA weights.
    encode_ema = ema.replace_vars(lambda x: model.encode(x))
    decode_ema = ema.replace_vars(lambda y: model.decode(y))

    def train_op(x, y):
        """ No one was ever fired for using Adam with 1e-4. """
        g, v = gv(x, y)
        opt(1e-4, g)
        ema()
        return v

    # NOTE: opt is defined after train_op; train_op only captures it by closure.
    opt = objax.optimizer.Adam(model.vars())
    train_op = objax.Jit(train_op, gv.vars() + opt.vars() + ema.vars())

    ys_ = ys_train
    print(ys_.shape)
    # Flatten the leading (example, patch) dimensions into one batch axis.
    xs_ = xs_train.reshape((-1, xs_train.shape[-1]))
    ys_ = ys_.reshape((-1, ys_train.shape[-1]))

    # The model scale trick here is taken from CLIP.
    # Let the model decide how confident to make its own predictions.
    model.scale.w.assign(jn.zeros((1, 1)))

    valid_size = 1000
    print(xs_train.shape)

    # SimCLR likes big batches
    B = 4096
    for it in range(80):
        print()
        ms = []
        for i in range(1000):
            # First batch is smaller, to make training more stable
            bs = [B // 64, B][it > 0]
            batch = np.random.randint(0, len(xs_) - valid_size, size=bs)
            r = train_op(xs_[batch], ys_[batch])
            # This shouldn't happen, but if it does, better to abort early
            if np.isnan(r):
                print("Die on nan")
                print(ms[-100:])
                return
            ms.append(r)
        print('mean', np.mean(ms), 'scale', model.scale.w.value)
        print('loss', loss(xs_[-100:], ys_[-100:]))
        # Held-out sanity check: paired embeddings should score higher
        # than randomly permuted ones.
        a = encode_ema(xs_[-valid_size:])
        b = decode_ema(ys_[-valid_size:])
        br = b[np.random.permutation(len(b))]
        print('score',
              np.mean(np.sum(a * b, axis=(1)) - np.sum(a * br, axis=(1))),
              np.mean(np.sum(a * b, axis=(1)) > np.sum(a * br, axis=(1))))
    # Save the EMA weights once training completes.
    # NOTE(review): original collapsed source is ambiguous about whether the
    # checkpoint save sits inside the epoch loop — placed after it here; verify.
    ckpt = objax.io.Checkpoint("saved", keep_ckpts=0)
    ema.replace_vars(lambda: ckpt.save(model.vars(), 0))()
def train(opts):
    """Train the ensemble embedding network on frame pairs.

    Returns the final epoch's mean test loss, or None if training diverged
    (NaN train or test loss).
    """
    # argparse delivers --ortho-init as a string; normalise it to a bool.
    # TODO: move this to argparse (e.g. a proper str->bool type).
    if opts.ortho_init in ['True', 'False']:
        opts.ortho_init = opts.ortho_init == 'True'
    if opts.ortho_init not in [True, False]:
        raise Exception("unknown --ortho-init value [%s]" % opts.ortho_init)
    # BUG FIX: opts.ortho_init is already a bool at this point, so the old
    # `ortho_init = opts.ortho_init == 'True'` was always False and the flag
    # was silently ignored.
    ortho_init = opts.ortho_init

    train_frame_pairs = img_utils.parse_frame_pairs(opts.train_tsv)
    test_frame_pairs = img_utils.parse_frame_pairs(opts.test_tsv)

    # init w & b logging (enabled iff a group was given)
    wandb_enabled = opts.group is not None
    if wandb_enabled:
        if opts.run is None:
            run = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        else:
            run = opts.run
        wandb.init(project='embedding_the_chickens', group=opts.group,
                   name=run, reinit=True)
        wandb.config.num_models = opts.num_models
        wandb.config.dense_kernel_size = opts.dense_kernel_size
        wandb.config.embedding_dim = opts.embedding_dim
        wandb.config.learning_rate = opts.learning_rate
        wandb.config.ortho_init = opts.ortho_init
        wandb.config.logit_temp = opts.logit_temp
    else:
        print("not using wandb", file=sys.stderr)

    # init model
    embed_net = ensemble_net.EnsembleNet(num_models=opts.num_models,
                                         dense_kernel_size=opts.dense_kernel_size,
                                         embedding_dim=opts.embedding_dim,
                                         seed=opts.seed,
                                         orthogonal_init=ortho_init)

    # jitted versions of the embed_net methods used below
    # TODO: objax decorator to do this?
    j_embed_net_calc_sims = objax.Jit(embed_net.calc_sims, embed_net.vars())
    j_embed_net_loss = objax.Jit(embed_net.loss, embed_net.vars())

    def labels_for_crops(crops_t0, crops_t1):
        # Run the crop pair through the model and derive one-hot labels from
        # the optimal pairing of the similarity matrix.
        sims = j_embed_net_calc_sims(crops_t0, crops_t1)
        pairing = optimal_pairing.calculate(sims)
        return optimal_pairing.to_one_hot_labels(sims, pairing)

    # init optimiser
    gradient_loss = objax.GradValues(embed_net.loss, embed_net.vars())
    optimiser = objax.optimizer.Adam(embed_net.vars())
    lr = 1e-3

    # create a jitted training step
    def train_step(crops_t0, crops_t1, labels):
        grads, loss = gradient_loss(crops_t0, crops_t1, labels)
        optimiser(lr, grads)
        return loss
    train_step = objax.Jit(train_step, gradient_loss.vars() + optimiser.vars())

    # run training loop
    for e in range(opts.epochs):
        # shuffle training examples (deterministically per epoch)
        orandom.seed(e)
        orandom.shuffle(train_frame_pairs)

        # make pass through training examples
        train_losses = []
        for f0, f1 in tqdm(train_frame_pairs):
            try:
                crops_t0 = img_utils.load_crops_as_floats(f"{f0}/crops.npy")
                crops_t1 = img_utils.load_crops_as_floats(f"{f1}/crops.npy")
                labels = labels_for_crops(crops_t0, crops_t1)
                loss = train_step(crops_t0, crops_t1, labels)
                train_losses.append(loss)
            except Exception:
                # best effort: report (epoch, pair) and keep going
                print("train exception", e, f0, f1, file=sys.stderr)
                traceback.print_exc(file=sys.stderr)

        # eval mean test loss
        test_losses = []
        for f0, f1 in test_frame_pairs:
            try:
                crops_t0 = img_utils.load_crops_as_floats(f"{f0}/crops.npy")
                crops_t1 = img_utils.load_crops_as_floats(f"{f1}/crops.npy")
                labels = labels_for_crops(crops_t0, crops_t1)
                loss = j_embed_net_loss(crops_t0, crops_t1, labels)
                test_losses.append(loss)
            except Exception:
                print("test exception", e, f0, f1, file=sys.stderr)
                traceback.print_exc(file=sys.stderr)

        # log stats
        mean_train_loss = np.mean(train_losses)
        mean_test_loss = np.mean(test_losses)
        print(e, mean_train_loss, mean_test_loss)
        # TODO: or only if train_loss is Nan?
        nan_loss = np.isnan(mean_train_loss) or np.isnan(mean_test_loss)
        if wandb_enabled and not nan_loss:
            # reuse the already-computed means instead of recomputing them
            wandb.log({'train_loss': mean_train_loss})
            wandb.log({'test_loss': mean_test_loss})

    # close out wandb run — only if one was actually started
    if wandb_enabled:
        wandb.join()

    # note: use None value to indicate run failed
    if nan_loss:
        return None
    else:
        return mean_test_loss
# ========================== # Loss Function # ========================== @objax.Function.with_vars(gf_model.vars()) def nll_loss(x): return gf_model.score(x) # ========================= # Optimizer # ========================= # define the optimizer opt = objax.optimizer.Adam(gf_model.vars()) # get grad values gv = objax.GradValues(nll_loss, gf_model.vars()) lr = wandb_logger.config.learning_rate epochs = wandb_logger.config.epochs batchsize = wandb_logger.config.batchsize # define the training operation @objax.Function.with_vars(gf_model.vars() + opt.vars()) def train_op(x): g, v = gv(x) # returns gradients, loss opt(lr, g) return v # This line is optional: it is compiling the code to make it faster. train_op = objax.Jit(train_op)
def __init__(self, model, *args, **kwargs):
    """Jit-compile the loss and its gradient computation."""
    super().__init__(model, *args, **kwargs)
    self.loss = objax.Jit(self.loss, model.vars())
    self.gradvals = objax.Jit(objax.GradValues(self.loss, model.vars()))
lr = 0.0001 # learning rate batch = 256 epochs = 20 # Model model = objax.nn.Linear(ndim, 1) opt = objax.optimizer.SGD(model.vars()) print(model.vars()) # Cross Entropy Loss def loss(x, label): return objax.functional.loss.sigmoid_cross_entropy_logits(model(x)[:, 0], label).mean() gv = objax.GradValues(loss, model.vars()) def train_op(x, label): g, v = gv(x, label) # returns gradients, loss opt(lr, g) return v # This line is optional: it is compiling the code to make it faster. # gv.vars() contains the model variables. train_op = objax.Jit(train_op, gv.vars() + opt.vars()) # Training for epoch in range(epochs): # Train
spatial_kernel=kern_space, z=R[0], sparse=True, opt_z=False, conditional='Full') inf = newt.inference.VariationalInference() markov = True if markov: model = newt.models.MarkovGP(kernel=kern, likelihood=lik, X=t, R=R, Y=Y) # model = newt.models.MarkovGP(kernel=kern, likelihood=lik, X=X, Y=y) else: model = newt.models.GP(kernel=kern, likelihood=lik, X=X, Y=y) compute_energy_and_update = objax.GradValues(inf, model.vars()) lr_adam = 0. lr_newton = 1. epochs = 2 opt = objax.optimizer.Adam(model.vars()) def train_op(): model.update_posterior() grads, loss_ = compute_energy_and_update(model, lr=lr_newton) # print(grads) for g, var_name in zip(grads, model.vars().keys()): # TODO: this gives wrong label to likelihood variance print(g, ' w.r.t. ', var_name) # print(model.kernel.temporal_kernel.variance) opt(lr_adam, grads)