def __init__(self, model, *args, **kwargs):
    """Wrap the model's loss, gradients and prediction in JIT-compiled ops.

    Args:
        model: objax module whose variables the compiled ops close over.
        *args, **kwargs: forwarded unchanged to the parent constructor.
    """
    super().__init__(model, *args, **kwargs)
    # JIT-compile the loss over the model's variables.
    fastloss = objax.Jit(self.loss, model.vars())
    # Gradients-and-loss op built on the jitted loss, itself jitted.
    self.gradvals = objax.Jit(objax.GradValues(fastloss, model.vars()), model.vars())
    # Inference op: force training=False so layers like batch norm use eval behavior.
    self.model.predict = objax.Jit(
        objax.ForceArgs(model.__call__, training=False), model.vars())
def _get_optimizer(self, model_vars: VarCollection, optimizer: str):
    """Create an optimizer for the given variables by name.

    Args:
        model_vars: collection of trainable variables to optimize.
        optimizer: one of 'momentum', 'adam' or 'sgd'.

    Returns:
        The optimizer module, wrapped in objax.Jit.

    Raises:
        ValueError: if ``optimizer`` is not a recognized name.
    """
    if optimizer == 'momentum':
        opt = objax.Jit(objax.optimizer.Momentum(model_vars, momentum=0.9))
    elif optimizer == 'adam':
        opt = objax.Jit(objax.optimizer.Adam(model_vars))
    elif optimizer == 'sgd':
        opt = objax.Jit(objax.optimizer.SGD(model_vars))
    else:
        # Name the offending value instead of raising a bare ValueError.
        raise ValueError(f'Unknown optimizer: {optimizer!r}')
    return opt
def test_double_jit(self):
    """Nesting objax.Jit twice must still track variable updates."""
    linear = objax.nn.Linear(3, 3)
    jitted = objax.Jit(objax.Jit(linear))
    inputs = objax.random.normal((64, 3))
    before = jitted(inputs)
    linear.w.assign(linear.w.value + 1)
    shifted = jitted(inputs)
    linear.w.assign(linear.w.value - 1)
    restored = jitted(inputs)
    # Restoring the weights reproduces the original output exactly...
    self.assertAlmostEqual(((before - restored) ** 2).sum(), 0)
    # ...while the shifted weights changed the output.
    self.assertNotEqual(((before - shifted) ** 2).sum(), 0)
def test_jit_kwargs(self):
    """Keyword arguments to a jitted module behave like positional ones, and a
    traced kwarg used in Python control flow raises a concretization error."""
    x = objax.random.normal((64, 3))
    kj = objax.Jit(LinearArgs(3, 3))
    y1 = kj(x, 1)
    y2 = kj(x, some_args=1)
    y3 = kj(x, some_args=2)
    # Same value passed positionally or by keyword gives identical results.
    self.assertEqual(y1.tolist(), y2.tolist())
    # A different value changes the output.
    self.assertNotEqual(y1.tolist(), y3.tolist())
    kj = objax.Jit(LinearTrain(3, 3))
    with self.assertRaises(ConcretizationTypeError):
        kj(x, training=True)
def disabled_test_savedmodel_wrn(self): model_dir = tempfile.mkdtemp() # Make a model and convert it to TF model = WideResNet(NCHANNELS, NCLASSES, depth=4, width=1) predict_op = objax.Jit( objax.nn.Sequential([ objax.ForceArgs(model, training=False), objax.functional.softmax ])) predict_tf = objax.util.Objax2Tf(predict_op) # Save model input_shape = (BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE) tf.saved_model.save( predict_tf, model_dir, signatures=predict_tf.__call__.get_concrete_function( tf.TensorSpec(input_shape, tf.float32))) # Load model loaded_tf_model = tf.saved_model.load(model_dir) loaded_predict_tf_op = loaded_tf_model.signatures['serving_default'] self.verify_converted_predict_op( predict_op, lambda x: loaded_predict_tf_op(x)['output_0'], shape=input_shape) self.verify_converted_predict_op(predict_op, lambda x: loaded_tf_model(x), shape=input_shape) # Cleanup shutil.rmtree(model_dir)
def test_jit_parallel_bntrain_concat(self):
    """JIT parallel inference (concat reduction) with batch norm in train mode."""
    f = objax.nn.Sequential([objax.nn.Conv2D(3, 4, k=1), objax.nn.BatchNorm2D(4),
                             objax.nn.Conv2D(4, 2, k=1)])
    x = objax.random.normal((64, 3, 16, 16))
    states, values = [], []
    # Snapshot initial tensors so every run starts from the same state.
    tensors = f.vars().tensors()
    # Reference: run the network serially twice, recording outputs and BN stats.
    for it in range(2):
        values.append(f(x, training=True))
        states += [f[1].running_var.value, f[1].running_mean.value]
        f.vars().assign(tensors)
    self.assertEqual(((values[0] - values[1]) ** 2).sum(), 0)
    # BN running stats must have moved away from their initial values.
    self.assertGreater(((states[0] - states[2]) ** 2).sum(), 0)
    self.assertGreater(((states[1] - states[3]) ** 2).sum(), 0)
    # Parallel version: replicate the batch across 8 devices.
    fp = objax.Jit(objax.Parallel(lambda x: f(x, training=True), vc=f.vars()))
    x8 = jn.broadcast_to(x, (8, 64, 3, 16, 16)).reshape((-1, 3, 16, 16))
    tensors = fp.vars().tensors()
    for it in range(2):
        with fp.vars().replicate():
            z = fp(x8).reshape((-1,) + values[it].shape)
        # Parallel BN stats must match the serial reference almost exactly.
        self.assertAlmostEqual(((f[1].running_var.value - states[2 * it]) ** 2).sum(),
                               0, delta=1e-12)
        self.assertAlmostEqual(((f[1].running_mean.value - states[2 * it + 1]) ** 2).sum(),
                               0, delta=1e-12)
        self.assertLess(((z - values[it][None]) ** 2).sum(), 1e-6)
        fp.vars().assign(tensors)
def test_private_gradvalues_clipping(self):
    """Test if the gradient norm is within l2_norm_clip."""
    noise_multiplier = 0  # no noise: isolate the effect of clipping
    acceptable_float_error = 1e-8
    for use_norm_accumulation in [True, False]:
        for microbatch in [1, 10, self.ntrain]:
            for l2_norm_clip in [0, 1e-2, 1e-1, 1.0]:
                gv_priv = objax.Jit(
                    objax.privacy.dpsgd.PrivateGradValues(
                        self.loss, self.model_vars, noise_multiplier,
                        l2_norm_clip, microbatch, batch_axis=(0, 0),
                        use_norm_accumulation=use_norm_accumulation))
                g_priv, v_priv = gv_priv(self.data, self.labels)
                # Get the actual squared norm of the gradient.
                g_normsquared = sum([np.sum(g**2) for g in g_priv])
                self.assertLessEqual(
                    g_normsquared,
                    l2_norm_clip**2 + acceptable_float_error)
                # The reported loss must match the plain (non-private) loss.
                np.testing.assert_allclose(v_priv[0],
                                           self.loss(self.data, self.labels)[0],
                                           atol=1e-7)
def test_gradvalues_constant(self): """Test if constants are preserved.""" # Set data ndim = 1 data = np.array([[1.0], [3.0], [5.0], [-10.0]]) labels = np.array([1.0, 2.0, 3.0, 4.0]) # Set model parameters for linear regression. w = objax.TrainVar(jn.zeros(ndim)) b = objax.TrainVar(jn.ones(1)) m = objax.ModuleList([w, b]) def loss(x, y): pred = jn.dot(x, w.value) + b.value return 0.5 * ((y - pred)**2).mean() # We are supposed to see the gradient change after the value of b (the constant) changes. gv = objax.GradValues(loss, objax.VarCollection({'w': w})) g_old, v_old = gv(data, labels) b.assign(-b.value) g_new, v_new = gv(data, labels) self.assertNotEqual(g_old[0][0], g_new[0][0]) # When compile with Jit, we are supposed to see the gradient change after the value of b (the constant) changes. gv = objax.Jit(objax.GradValues(loss, objax.VarCollection({'w': w})), m.vars()) g_old, v_old = gv(data, labels) b.assign(-b.value) g_new, v_new = gv(data, labels) self.assertNotEqual(g_old[0][0], g_new[0][0])
def test_private_gradvalues_compare_nonpriv(self):
    """Test if PrivateGradValues without clipping / noise is the same as non-private GradValues."""
    l2_norm_clip = 1e10  # effectively disables clipping
    noise_multiplier = 0  # disables noise
    for use_norm_accumulation in [True, False]:
        for microbatch in [1, 10, self.ntrain]:
            gv_priv = objax.Jit(
                objax.privacy.dpsgd.PrivateGradValues(
                    self.loss, self.model_vars, noise_multiplier,
                    l2_norm_clip, microbatch, batch_axis=(0, 0),
                    use_norm_accumulation=use_norm_accumulation))
            gv = objax.GradValues(self.loss, self.model_vars)
            g_priv, v_priv = gv_priv(self.data, self.labels)
            g, v = gv(self.data, self.labels)
            # Check the shape of the gradient.
            self.assertEqual(g_priv[0].shape, tuple([self.nclass]))
            self.assertEqual(g_priv[1].shape, tuple([self.ndim, self.nclass]))
            # Check if the private gradient is similar to the non-private gradient.
            np.testing.assert_allclose(g[0], g_priv[0], atol=1e-7)
            np.testing.assert_allclose(g[1], g_priv[1], atol=1e-7)
            np.testing.assert_allclose(v_priv[0],
                                       self.loss(self.data, self.labels)[0],
                                       atol=1e-7)
def test_composite_shape_jitted(n_samples, n_features, n_components): x = objax.random.normal(( n_samples, n_features, ), generator=generator) # create layer transform = CompositeTransform( [MixtureGaussianCDF(n_features, n_components), Logit()]) # forward transformation jit_net = objax.Jit(transform, transform.vars()) z, log_abs_det = jit_net(x) # checks chex.assert_equal_shape([z, x]) chex.assert_shape(log_abs_det, (n_samples, )) # forward transformation z = transform.transform(x) # checks chex.assert_equal_shape([z, x]) # inverse transformation x_approx = transform.inverse(z) # checks chex.assert_equal_shape([x_approx, x])
def test_private_gradvalues_noise(self):
    """Test if the noise std is around expected."""
    runs = 100
    alpha = 0.0001  # significance level for the chi-square tests below
    for use_norm_accumulation in [True, False]:
        for microbatch in [1, 10, self.ntrain]:
            for noise_multiplier in [0.1, 10.0]:
                for l2_norm_clip in [0.01, 0.1]:
                    gv_priv = objax.Jit(
                        objax.privacy.dpsgd.PrivateGradValues(
                            self.loss, self.model_vars, noise_multiplier,
                            l2_norm_clip, microbatch, batch_axis=(0, 0),
                            use_norm_accumulation=use_norm_accumulation))

                    # Repeat the run and collect all gradients.
                    g_privs = []
                    for i in range(runs):
                        g_priv, v_priv = gv_priv(self.data, self.labels)
                        g_privs.append(
                            np.concatenate(
                                [g_n.reshape(-1) for g_n in g_priv]))
                        # Loss value must be unaffected by the added noise.
                        np.testing.assert_allclose(
                            v_priv[0],
                            self.loss(self.data, self.labels)[0],
                            atol=1e-7)
                    g_privs = np.array(g_privs)

                    # Compute empirical std and expected std.
                    std_empirical = np.std(g_privs, axis=0, ddof=1)
                    std_theoretical = l2_norm_clip * noise_multiplier / (
                        self.ntrain // microbatch)

                    # Conduct chi-square test for correct expected standard
                    # deviation.
                    chi2_value = (
                        runs - 1) * std_empirical**2 / std_theoretical**2
                    chi2_cdf = chi2.cdf(chi2_value, runs - 1)
                    self.assertTrue(
                        np.all(alpha <= chi2_cdf)
                        and np.all(chi2_cdf <= 1.0 - alpha))

                    # Conduct chi-square test for incorrect expected standard
                    # deviations: expect failure.
                    chi2_value = (runs - 1) * std_empirical**2 / (
                        1.25 * std_theoretical)**2
                    chi2_cdf = chi2.cdf(chi2_value, runs - 1)
                    self.assertFalse(
                        np.all(alpha <= chi2_cdf)
                        and np.all(chi2_cdf <= 1.0 - alpha))

                    chi2_value = (runs - 1) * std_empirical**2 / (
                        0.75 * std_theoretical)**2
                    chi2_cdf = chi2.cdf(chi2_value, runs - 1)
                    self.assertFalse(
                        np.all(alpha <= chi2_cdf)
                        and np.all(chi2_cdf <= 1.0 - alpha))
def test_jacobian_linear_jit(self):
    """Test if Jacobian is correctly working with JIT."""
    # Jacobian with respect to the model variables.
    jac_lin_vars = objax.Jacobian(self.f_lin, self.f_lin.vars())
    jac_lin_vars_jit = objax.Jit(jac_lin_vars)
    j = jac_lin_vars(self.data)
    j_jit = jac_lin_vars_jit(self.data)
    self.assertEqual(len(j_jit), 2)
    np.testing.assert_allclose(j_jit[0], j[0])
    np.testing.assert_allclose(j_jit[1], j[1])
    # Jacobian with respect to the input only.
    jac_lin_x = objax.Jacobian(self.f_lin, None, input_argnums=(0, ))
    jac_lin_x_jit = objax.Jit(jac_lin_x)
    j = jac_lin_x(self.data)
    j_jit = jac_lin_x_jit(self.data)
    self.assertEqual(len(j_jit), 1)
    np.testing.assert_allclose(j_jit[0], j[0])
def test_trainvar_assign(self):
    """A TrainVar assigned inside a jitted closure keeps its new value."""
    modules = objax.ModuleList([objax.TrainVar(jn.zeros(2))])

    def bump():
        var = modules[0]
        var.assign(var.value + 1)
        return var.value

    # Run the jitted closure once; the update must persist outside JIT.
    objax.Jit(bump, modules.vars())()
    self.assertEqual(modules[0].value.tolist(), [1.0, 1.0])
def test_constant_optimization(self):
    """Variables excluded from the Jit VarCollection are baked in as constants."""
    layer = objax.nn.Linear(3, 4)
    # Empty VarCollection: the layer's weights get captured as compile-time constants.
    frozen = objax.Jit(layer, objax.VarCollection())
    data = objax.random.normal((10, 3))
    self.assertEqual(((layer(data) - frozen(data)) ** 2).sum(), 0)
    # Modify the layer (which was supposed to be constant!)
    layer.b.assign(layer.b.value + 1)
    # 10 samples x 4 outputs, each off by exactly 1 -> squared error is 40.
    self.assertEqual(((layer(data) - frozen(data)) ** 2).sum(), 40)
def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
    """
    Completely standard training. Nothing interesting to see here.
    """
    super().__init__(nclass, **kwargs)
    # 1 input channel for MNIST, 3 for RGB datasets.
    self.model = model(1 if mnist else 3, nclass)
    self.opt = objax.optimizer.Momentum(self.model.vars())
    self.model_ema = objax.optimizer.ExponentialMovingAverageModule(
        self.model, momentum=0.999, debias=True)

    @objax.Function.with_vars(self.model.vars())
    def loss(x, label):
        logit = self.model(x, training=True)
        # Weight decay over weight matrices only (variable names ending in '.w').
        loss_wd = 0.5 * sum(
            (v.value**2).sum() for k, v in self.model.vars().items()
            if k.endswith('.w'))
        loss_xe = objax.functional.loss.cross_entropy_logits(logit, label).mean()
        return loss_xe + loss_wd * self.params.weight_decay, {
            'losses/xe': loss_xe,
            'losses/wd': loss_wd
        }

    gv = objax.GradValues(loss, self.model.vars())
    self.gv = gv

    @objax.Function.with_vars(self.vars())
    def train_op(progress, x, y):
        g, v = gv(x, y)
        # Cosine learning-rate schedule with a linear ramp-up over the first 1% of training.
        lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
        lr = lr * jn.clip(progress * 100, 0, 1)
        self.opt(lr, g)
        self.model_ema.update_ema()
        return {'monitors/lr': lr, **v[1]}

    # Predictions use the EMA weights in eval mode.
    self.predict = objax.Jit(
        objax.nn.Sequential(
            [objax.ForceArgs(self.model_ema, training=False)]))
    self.train_op = objax.Jit(train_op)
def validate(args):
    """Evaluate a pretrained model on args.data and return top-1/top-5 accuracy."""
    model = create_model(args.model, pretrained=True)
    print(f'Created {args.model} model. Validating...')
    # JIT the evaluation step over the model variables.
    eval_step = objax.Jit(
        lambda images, labels: eval_forward(model, images, labels),
        model.vars())
    # Support both a .tar archive and a directory of images.
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data)
    else:
        dataset = Dataset(args.data)
    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=False,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=8,
                           crop_pct=data_config['crop_pct'])
    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, (images, labels) in enumerate(loader):
        # Convert torch tensors to numpy before feeding the jitted step.
        images = images.numpy()
        labels = labels.numpy()
        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += top1_count
        correct_top5 += top5_count
        total_examples += images.shape[0]
        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}] '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()
    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(
        f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
def test_trainvar_and_ref_assign(self):
    """A TrainVar and a TrainRef aliasing it share one storage under JIT."""
    holder = objax.ModuleList([objax.TrainVar(jn.zeros(2))])
    holder.append(objax.TrainRef(holder[0]))

    def double_bump():
        # Increment once through the TrainVar, once through the TrainRef alias.
        holder[0].assign(holder[0].value + 1)
        holder[1].assign(holder[1].value + 1)
        return holder[0].value

    result = objax.Jit(double_bump, holder.vars())()
    # Both increments land on the same underlying value.
    self.assertEqual(result.tolist(), [2.0, 2.0])
    self.assertEqual(holder[0].value.tolist(), [2.0, 2.0])
def test_transform(self):
    """repr() of gradient/JIT/parallel/vectorize transforms is stable and descriptive."""
    def myloss(x):
        return (x ** 2).mean()

    g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
    gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(),
                                                noise_multiplier=1.,
                                                l2_norm_clip=0.5,
                                                microbatch=1)
    # The expected strings below are exact: do not reformat them.
    self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
    self.assertEqual(repr(gvp),
                     'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
                     ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
    self.assertEqual(repr(objax.Jit(gv)),
                     'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
    self.assertEqual(repr(objax.Jit(myloss, vc=objax.VarCollection())),
                     'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
    self.assertEqual(repr(objax.Parallel(gv)),
                     "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
                     " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
    self.assertEqual(repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
                     'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
    self.assertEqual(repr(objax.ForceArgs(gv, training=True, word='hello')),
                     "objax.ForceArgs(module=GradValues, training=True, word='hello')")
def build_train_step_fn(model, loss_fn):
    """Build a train-step function (jitted when the global JIT flag is set).

    Args:
        model: objax module whose variables are trained.
        loss_fn: loss taking (rgb_img_t1, true_dither_t0, true_dither_t1).

    Returns:
        Callable(learning_rate, rgb_img_t1, true_dither_t0, true_dither_t1)
        that applies one optimizer step and returns per-tensor gradient norms.
    """
    gradient_loss = objax.GradValues(loss_fn, model.vars())
    optimiser = objax.optimizer.Adam(model.vars())

    def train_step(learning_rate, rgb_img_t1, true_dither_t0, true_dither_t1):
        grads, _loss = gradient_loss(
            rgb_img_t1, true_dither_t0, true_dither_t1)
        # Clip gradients before the update; threshold comes from global opts.
        grads = u.clip_gradients(grads, theta=opts.gradient_clip)
        optimiser(learning_rate, grads)
        # Report per-tensor gradient norms for monitoring.
        grad_norms = [jnp.linalg.norm(g) for g in grads]
        return grad_norms

    if JIT:
        train_step = objax.Jit(
            train_step, gradient_loss.vars() + optimiser.vars())
    return train_step
def test_trainvar_jit_assign(self):
    """TrainVar side effects inside Grad and Jit compose with correct gradients."""
    # Set data
    ndim = 2
    data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    labels = np.array([1.0, 2.0, 3.0, 4.0])

    # Set model parameters for linear regression.
    w = objax.TrainVar(jn.zeros(ndim))
    b = objax.TrainVar(jn.zeros(1))

    def loss(x, y):
        pred = jn.dot(x, w.value) + b.value
        # Side effects inside the differentiated function.
        b.assign(b.value + 1)
        w.assign(w.value - 1)
        return 0.5 * ((y - pred)**2).mean()

    grad = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}))

    def jloss(wb, x, y):
        # Pure-JAX reference loss without side effects.
        w, b = wb
        pred = jn.dot(x, w) + b
        return 0.5 * ((y - pred)**2).mean()

    def jit_op(x, y):
        g = grad(x, y)
        # More side effects applied after the gradient computation.
        b.assign(b.value * 2)
        w.assign(w.value * 3)
        return g

    jit_op = objax.Jit(jit_op, objax.VarCollection(dict(b=b, w=w)))
    jgrad = jax.grad(jloss)

    jg = jgrad([w.value, b.value], data, labels)
    g = jit_op(data, labels)
    self.assertEqual(g[0].shape, tuple([ndim]))
    self.assertEqual(g[1].shape, tuple([1]))
    np.testing.assert_allclose(g[0], jg[0])
    np.testing.assert_allclose(g[1], jg[1])
    # After one call: w = (0 - 1) * 3, b = (0 + 1) * 2.
    self.assertEqual(w.value.tolist(), [-3., -3.])
    self.assertEqual(b.value.tolist(), [2.])

    jg = jgrad([w.value, b.value], data, labels)
    g = jit_op(data, labels)
    np.testing.assert_allclose(g[0], jg[0])
    np.testing.assert_allclose(g[1], jg[1])
    # After the second call: w = (-3 - 1) * 3, b = (2 + 1) * 2.
    self.assertEqual(w.value.tolist(), [-12., -12.])
    self.assertEqual(b.value.tolist(), [6.])
def validate(args):
    """Runs evaluation and returns top-1 accuracy."""
    model = create_model(args.model, pretrained=True)
    print(f'Created {args.model} model. Validating...')
    # JIT the evaluation step over the model variables.
    eval_step = objax.Jit(
        lambda images, labels: eval_forward(model, images, labels),
        model.vars())
    image_size = model.default_cfg['input_size'][-1]
    # Normalization stats are scaled to the 0-255 pixel range.
    test_ds, num_batches = imagenet_data.load(
        imagenet_data.Split.TEST,
        is_training=False,
        image_size=image_size,
        batch_dims=[args.batch_size],
        chw=True,
        mean=tuple([x * 255 for x in model.default_cfg['mean']]),
        std=tuple([x * 255 for x in model.default_cfg['std']]),
        tfds_data_dir=args.data)
    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, batch in enumerate(test_ds):
        images, labels = batch['images'], batch['labels']
        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]
        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{num_batches}] '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()
    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
          f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
def __init__(self):
    """Build a WideResNet-based same-image classifier with EMA weights for prediction."""
    fn = functools.partial(wide_resnet.WideResNet, depth=28, width=6)
    # 3*4 input channels: four 3-channel images are stacked along the channel axis.
    self.model = fn(3*4, 2)
    model_vars = self.model.vars()
    self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)

    def predict_op(x, y):
        # The model takes SEVERAL images and checks if they all correspond
        # to the same original image.
        # Guaranteed that the first N-1 all do, the test is if the last does.
        xx = jn.concatenate([jn.abs(x), jn.abs(y)], axis=1)
        return self.model(xx, training=False)

    # Predict with EMA weights swapped in, jitted over model + EMA variables.
    self.predict = objax.Jit(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
def test_jit_parallel_reducers(self):
    """JIT parallel reductions."""
    f = objax.nn.Linear(3, 4)
    x = objax.random.normal((96, 3))
    y = f(x)
    zl = []
    # Try identity, first-shard, mean and sum reducers over the 8 device shards.
    for reduce in (lambda x: x, lambda x: x[0], lambda x: x.mean(0), lambda x: x.sum(0)):
        fp = objax.Jit(objax.Parallel(f, reduce=reduce))
        with fp.vars().replicate():
            zl.append(fp(x))
    znone, zfirst, zmean, zsum = zl
    # Each reducer's output must match the same reduction of the serial result.
    self.assertAlmostEqual(jn.square(jn.array(y.split(8)) - znone).sum(), 0, places=8)
    self.assertAlmostEqual(jn.square(y.split(8)[0] - zfirst).sum(), 0, places=8)
    self.assertAlmostEqual(jn.square(np.mean(y.split(8), 0) - zmean).sum(), 0, places=8)
    self.assertAlmostEqual(jn.square(np.sum(y.split(8), 0) - zsum).sum(), 0, places=8)
def disabled_test_convert_wrn(self):
    """Convert a WideResNet to TF, with and without JIT, and compare predictions."""
    # Make a model
    model = WideResNet(NCHANNELS, NCLASSES, depth=4, width=1)
    # Prediction op without JIT
    predict_op = objax.nn.Sequential(
        [objax.ForceArgs(model, training=False), objax.functional.softmax])
    predict_tf = objax.util.Objax2Tf(predict_op)
    # Compare results
    self.verify_converted_predict_op(predict_op,
                                     predict_tf,
                                     shape=(BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE))
    # Predict op with JIT
    predict_op_jit = objax.Jit(predict_op)
    predict_tf_jit = objax.util.Objax2Tf(predict_op_jit)
    # Compare results
    self.verify_converted_predict_op(predict_op_jit,
                                     predict_tf_jit,
                                     shape=(BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE))
logit = model(x, training=True) return objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean() gv = objax.GradValues(loss, model.vars()) @objax.Function.with_vars(model.vars() + gv.vars() + opt.vars()) def train_op(x, y, lr): g, v = gv(x, y) opt(lr=lr, grads=g) return v train_op = objax.Jit(train_op) predict = objax.Jit( objax.nn.Sequential( [objax.ForceArgs(model, training=False), objax.functional.softmax])) def augment(x): if random.random() < .5: x = x[:, :, :, ::-1] # Flip the batch images about the horizontal axis # Pixel-shift all images in the batch by up to 4 pixels in any direction. x_pad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'reflect') rx, ry = np.random.randint(0, 8), np.random.randint(0, 8) x = x_pad[:, :, rx:rx + 32, ry:ry + 32] return x
# Hyperparameter optimization loop: Newton steps for variational params,
# Adam steps for model hyperparameters.
trainable_vars = model.vars() + inf.vars()
energy = objax.GradValues(inf.energy, trainable_vars)

lr_adam = 0.2
lr_newton = 0.2
iters = 100
opt = objax.optimizer.Adam(trainable_vars)


def train_op():
    inf(model, lr=lr_newton)  # perform inference and update variational params
    dE, E = energy(model)  # compute energy and its gradients w.r.t. hypers
    return dE, E


train_op = objax.Jit(train_op, trainable_vars)

t0 = time.time()
for i in range(1, iters + 1):
    grad, loss = train_op()
    opt(lr_adam, grad)
    print('iter %2d: energy: %1.4f' % (i, loss[0]))
t1 = time.time()
print('optimisation time: %2.2f secs' % (t1 - t0))

# calculate posterior predictive distribution via filtering and smoothing at train & test locations:
print('calculating the posterior predictive distribution ...')
t0 = time.time()
nlpd = model.negative_log_predictive_density(X=t_test, R=r_test, Y=Y_test)
t1 = time.time()
print('prediction time: %2.2f secs' % (t1 - t0))
def loss(x, label):
    # Binary classification: the model outputs a single logit per example.
    return objax.functional.loss.sigmoid_cross_entropy_logits(model(x)[:, 0], label).mean()


gv = objax.GradValues(loss, model.vars())


@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars())
def train_op(x, label):
    g, v = gv(x, label)  # returns gradients, loss
    opt(lr, g)
    return v


# This line is optional: it is compiling the code to make it faster.
train_op = objax.Jit(train_op)

# Training
for epoch in range(epochs):
    # Train: sample random minibatches with replacement.
    avg_loss = 0
    for it in range(0, train.image.shape[0], batch):
        sel = np.random.randint(size=(batch,), low=0, high=train.image.shape[0])
        avg_loss += float(train_op(train.image[sel], train.label[sel])[0]) * batch
    avg_loss /= it + batch

    # Eval: count correct predictions over sequential test batches.
    accuracy = 0
    for it in range(0, test.image.shape[0], batch):
        x, y = test.image[it: it + batch], test.label[it: it + batch]
        accuracy += (np.round(objax.functional.sigmoid(model(x)))[:, 0] == y).sum()
def __init__(self, model, *args, **kwargs):
    """Initialize the trainer and JIT-compile its loss and gradient ops.

    Args:
        model: objax module providing the variables to differentiate.
        *args, **kwargs: forwarded unchanged to the parent constructor.
    """
    super().__init__(model, *args, **kwargs)
    # JIT-compile the loss over the model's variables.
    self.loss = objax.Jit(self.loss, model.vars())
    # Gradients-and-loss op built on the already-jitted loss.
    # (Removed stale commented-out experiments that jitted self.model / a
    # nonexistent `fastloss` — dead code that only confused readers.)
    self.gradvals = objax.Jit(objax.GradValues(self.loss, model.vars()))
def loss(x, y):
    # Mean squared error of the network prediction.
    return ((y - net(x))**2).mean()


gv = objax.GradValues(loss, net.vars())


def train_op():
    # One SGD-style step at a fixed learning rate of 0.01.
    g, v = gv(source, target)
    opt(0.01, g)
    return v


train_op = objax.Jit(train_op, gv.vars() + opt.vars())

for i in range(100):
    train_op()

# Visualize prediction, pointwise loss and target after training.
plt.plot(source, net(source), label='prediction')
plt.plot(source, (target - net(source))**2, label='loss')
plt.plot(source, target, label='target')
plt.legend()
plt.show()

print('MAML training')
# Fresh network and optimizer for the MAML phase.
net = make_net()
opt = objax.optimizer.Adam(net.vars())
return objax.functional.loss.cross_entropy_logits(logit, label).mean() gv = objax.GradValues(loss, model.vars()) @objax.Function.with_vars(model.vars() + gv.vars() + opt.vars() + model_ema.vars()) def train_op(x, xl): g, v = gv(x, xl) # returns gradients, loss opt(lr, g) model_ema.update_ema() return v train_op = objax.Jit(train_op) # Compile train_op to make it run faster. predict = objax.Jit(model_ema) # Training print(model.vars()) print(f'Visualize results with: tensorboard --logdir "{logdir}"') print( "Disclaimer: This code demonstrates the DNNet class. For SOTA accuracy use a CNN instead." ) with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard: for epoch in range(num_train_epochs): # Train one epoch summary = Summary() loop = trange(0, train_size, batch,