Example #1
 def __init__(self, model, *args, **kwargs):
     super().__init__(model, *args, **kwargs)
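     # JIT-compile the loss, its gradient function, and an inference-only predict call
     # against the model's variable collection.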
     fastloss = objax.Jit(self.loss, model.vars())
     self.gradvals = objax.Jit(objax.GradValues(fastloss, model.vars()),
                               model.vars())
     self.model.predict = objax.Jit(
         objax.ForceArgs(model.__call__, training=False), model.vars())
Example #2
 def _get_optimizer(self, model_vars: VarCollection, optimizer: str):
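      # Each optimizer is wrapped in objax.Jit so its update step runs as compiled code.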
     if optimizer == 'momentum':
         opt = objax.Jit(objax.optimizer.Momentum(model_vars, momentum=0.9))
     elif optimizer == 'adam':
         opt = objax.Jit(objax.optimizer.Adam(model_vars))
     elif optimizer == 'sgd':
         opt = objax.Jit(objax.optimizer.SGD(model_vars))
     else:
         raise ValueError
     return opt
Example #3
File: jit.py Project: jmarrietar/objax
 def test_double_jit(self):
     k = objax.nn.Linear(3, 3)
     kj = objax.Jit(objax.Jit(k))
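     # Even a doubly-jitted module must still reflect later updates to its variables.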
     x = objax.random.normal((64, 3))
     y1 = kj(x)
     k.w.assign(k.w.value + 1)
     y2 = kj(x)
     k.w.assign(k.w.value - 1)
     y3 = kj(x)
     self.assertAlmostEqual(((y1 - y3)**2).sum(), 0)
     self.assertNotEqual(((y1 - y2)**2).sum(), 0)
Example #4
File: jit.py Project: jmarrietar/objax
 def test_jit_kwargs(self):
     x = objax.random.normal((64, 3))
     kj = objax.Jit(LinearArgs(3, 3))
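     # Positional and keyword forms of the same argument should give identical jitted results.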
     y1 = kj(x, 1)
     y2 = kj(x, some_args=1)
     y3 = kj(x, some_args=2)
     self.assertEqual(y1.tolist(), y2.tolist())
     self.assertNotEqual(y1.tolist(), y3.tolist())
     kj = objax.Jit(LinearTrain(3, 3))
     with self.assertRaises(ConcretizationTypeError):
         kj(x, training=True)
Example #5
File: objax2tf.py Project: google/objax
 def disabled_test_savedmodel_wrn(self):
     model_dir = tempfile.mkdtemp()
     # Make a model and convert it to TF
     model = WideResNet(NCHANNELS, NCLASSES, depth=4, width=1)
     predict_op = objax.Jit(
         objax.nn.Sequential([
             objax.ForceArgs(model, training=False),
             objax.functional.softmax
         ]))
     predict_tf = objax.util.Objax2Tf(predict_op)
     # Save model
     input_shape = (BATCH_SIZE, NCHANNELS, IMAGE_SIZE, IMAGE_SIZE)
     tf.saved_model.save(
         predict_tf,
         model_dir,
         signatures=predict_tf.__call__.get_concrete_function(
             tf.TensorSpec(input_shape, tf.float32)))
     # Load model
     loaded_tf_model = tf.saved_model.load(model_dir)
     loaded_predict_tf_op = loaded_tf_model.signatures['serving_default']
     self.verify_converted_predict_op(
         predict_op,
         lambda x: loaded_predict_tf_op(x)['output_0'],
         shape=input_shape)
     self.verify_converted_predict_op(predict_op,
                                      lambda x: loaded_tf_model(x),
                                      shape=input_shape)
     # Cleanup
     shutil.rmtree(model_dir)
Example #6
    def test_jit_parallel_bntrain_concat(self):
        """JIT parallel inference (concat reduction) with batch norm in train mode."""
        f = objax.nn.Sequential([objax.nn.Conv2D(3, 4, k=1), objax.nn.BatchNorm2D(4), objax.nn.Conv2D(4, 2, k=1)])
        x = objax.random.normal((64, 3, 16, 16))
        states, values = [], []
        tensors = f.vars().tensors()
        for it in range(2):
            values.append(f(x, training=True))
            states += [f[1].running_var.value, f[1].running_mean.value]
        f.vars().assign(tensors)

        self.assertEqual(((values[0] - values[1]) ** 2).sum(), 0)
        self.assertGreater(((states[0] - states[2]) ** 2).sum(), 0)
        self.assertGreater(((states[1] - states[3]) ** 2).sum(), 0)

        fp = objax.Jit(objax.Parallel(lambda x: f(x, training=True), vc=f.vars()))
        x8 = jn.broadcast_to(x, (8, 64, 3, 16, 16)).reshape((-1, 3, 16, 16))
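        # Replicate the batch 8x so that, when Parallel splits it across devices,
        # each device sees a copy of the original 64 examples.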
        tensors = fp.vars().tensors()
        for it in range(2):
            with fp.vars().replicate():
                z = fp(x8).reshape((-1,) + values[it].shape)
            self.assertAlmostEqual(((f[1].running_var.value - states[2 * it]) ** 2).sum(), 0, delta=1e-12)
            self.assertAlmostEqual(((f[1].running_mean.value - states[2 * it + 1]) ** 2).sum(), 0, delta=1e-12)
            self.assertLess(((z - values[it][None]) ** 2).sum(), 1e-6)
        fp.vars().assign(tensors)
Example #7
 def test_private_gradvalues_clipping(self):
     """Test if the gradient norm is within l2_norm_clip."""
     noise_multiplier = 0
     acceptable_float_error = 1e-8
     for use_norm_accumulation in [True, False]:
         for microbatch in [1, 10, self.ntrain]:
             for l2_norm_clip in [0, 1e-2, 1e-1, 1.0]:
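                  # noise_multiplier is 0, so private gradients are only clipped;
                  # check that the clipped gradient norm stays within l2_norm_clip.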
                 gv_priv = objax.Jit(
                     objax.privacy.dpsgd.PrivateGradValues(
                         self.loss,
                         self.model_vars,
                         noise_multiplier,
                         l2_norm_clip,
                         microbatch,
                         batch_axis=(0, 0),
                         use_norm_accumulation=use_norm_accumulation))
                 g_priv, v_priv = gv_priv(self.data, self.labels)
                 # Get the actual squared norm of the gradient.
                 g_normsquared = sum([np.sum(g**2) for g in g_priv])
                 self.assertLessEqual(
                     g_normsquared,
                     l2_norm_clip**2 + acceptable_float_error)
                 np.testing.assert_allclose(v_priv[0],
                                            self.loss(
                                                self.data, self.labels)[0],
                                            atol=1e-7)
Example #8
    def test_gradvalues_constant(self):
        """Test if constants are preserved."""
        # Set data
        ndim = 1
        data = np.array([[1.0], [3.0], [5.0], [-10.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.zeros(ndim))
        b = objax.TrainVar(jn.ones(1))
        m = objax.ModuleList([w, b])

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            return 0.5 * ((y - pred)**2).mean()

        # We are supposed to see the gradient change after the value of b (the constant) changes.
        gv = objax.GradValues(loss, objax.VarCollection({'w': w}))
        g_old, v_old = gv(data, labels)
        b.assign(-b.value)
        g_new, v_new = gv(data, labels)
        self.assertNotEqual(g_old[0][0], g_new[0][0])

        # When compiled with Jit, we should still see the gradient change after the value of b (the constant) changes.
        gv = objax.Jit(objax.GradValues(loss, objax.VarCollection({'w': w})),
                       m.vars())
        g_old, v_old = gv(data, labels)
        b.assign(-b.value)
        g_new, v_new = gv(data, labels)
        self.assertNotEqual(g_old[0][0], g_new[0][0])
Example #9
    def test_private_gradvalues_compare_nonpriv(self):
        """Test if PrivateGradValues without clipping / noise is the same as non-private GradValues."""
        l2_norm_clip = 1e10
        noise_multiplier = 0

        for use_norm_accumulation in [True, False]:
            for microbatch in [1, 10, self.ntrain]:
                gv_priv = objax.Jit(
                    objax.privacy.dpsgd.PrivateGradValues(
                        self.loss,
                        self.model_vars,
                        noise_multiplier,
                        l2_norm_clip,
                        microbatch,
                        batch_axis=(0, 0),
                        use_norm_accumulation=use_norm_accumulation))
                gv = objax.GradValues(self.loss, self.model_vars)
                g_priv, v_priv = gv_priv(self.data, self.labels)
                g, v = gv(self.data, self.labels)

                # Check the shape of the gradient.
                self.assertEqual(g_priv[0].shape, tuple([self.nclass]))
                self.assertEqual(g_priv[1].shape,
                                 tuple([self.ndim, self.nclass]))

                # Check if the private gradient is similar to the non-private gradient.
                np.testing.assert_allclose(g[0], g_priv[0], atol=1e-7)
                np.testing.assert_allclose(g[1], g_priv[1], atol=1e-7)
                np.testing.assert_allclose(v_priv[0],
                                           self.loss(self.data,
                                                     self.labels)[0],
                                           atol=1e-7)
Example #10
def test_composite_shape_jitted(n_samples, n_features, n_components):

    x = objax.random.normal((
        n_samples,
        n_features,
    ), generator=generator)

    # create layer
    transform = CompositeTransform(
        [MixtureGaussianCDF(n_features, n_components),
         Logit()])

    # forward transformation
    jit_net = objax.Jit(transform, transform.vars())
    z, log_abs_det = jit_net(x)

    # checks
    chex.assert_equal_shape([z, x])
    chex.assert_shape(log_abs_det, (n_samples, ))

    # forward transformation
    z = transform.transform(x)

    # checks
    chex.assert_equal_shape([z, x])

    # inverse transformation
    x_approx = transform.inverse(z)

    # checks
    chex.assert_equal_shape([x_approx, x])
Example #11
    def test_private_gradvalues_noise(self):
        """Test if the noise std is around expected."""
        runs = 100
        alpha = 0.0001
        for use_norm_accumulation in [True, False]:
            for microbatch in [1, 10, self.ntrain]:
                for noise_multiplier in [0.1, 10.0]:
                    for l2_norm_clip in [0.01, 0.1]:
                        gv_priv = objax.Jit(
                            objax.privacy.dpsgd.PrivateGradValues(
                                self.loss,
                                self.model_vars,
                                noise_multiplier,
                                l2_norm_clip,
                                microbatch,
                                batch_axis=(0, 0),
                                use_norm_accumulation=use_norm_accumulation))
                        # Repeat the run and collect all gradients.
                        g_privs = []
                        for i in range(runs):
                            g_priv, v_priv = gv_priv(self.data, self.labels)
                            g_privs.append(
                                np.concatenate(
                                    [g_n.reshape(-1) for g_n in g_priv]))
                            np.testing.assert_allclose(v_priv[0],
                                                       self.loss(
                                                           self.data,
                                                           self.labels)[0],
                                                       atol=1e-7)
                        g_privs = np.array(g_privs)

                        # Compute empirical std and expected std.
                        std_empirical = np.std(g_privs, axis=0, ddof=1)
                        std_theoretical = l2_norm_clip * noise_multiplier / (
                            self.ntrain // microbatch)

                        # Conduct chi-square test for correct expected standard
                        # deviation.
                        chi2_value = (
                            runs - 1) * std_empirical**2 / std_theoretical**2
                        chi2_cdf = chi2.cdf(chi2_value, runs - 1)
                        self.assertTrue(
                            np.all(alpha <= chi2_cdf)
                            and np.all(chi2_cdf <= 1.0 - alpha))

                        # Conduct chi-square test for incorrect expected standard
                        # deviations: expect failure.
                        chi2_value = (runs - 1) * std_empirical**2 / (
                            1.25 * std_theoretical)**2
                        chi2_cdf = chi2.cdf(chi2_value, runs - 1)
                        self.assertFalse(
                            np.all(alpha <= chi2_cdf)
                            and np.all(chi2_cdf <= 1.0 - alpha))

                        chi2_value = (runs - 1) * std_empirical**2 / (
                            0.75 * std_theoretical)**2
                        chi2_cdf = chi2.cdf(chi2_value, runs - 1)
                        self.assertFalse(
                            np.all(alpha <= chi2_cdf)
                            and np.all(chi2_cdf <= 1.0 - alpha))
Example #12
    def test_jacobian_linear_jit(self):
        """Test if Jacobian is corrrectly working with JIT."""
        jac_lin_vars = objax.Jacobian(self.f_lin, self.f_lin.vars())
        jac_lin_vars_jit = objax.Jit(jac_lin_vars)
        j = jac_lin_vars(self.data)
        j_jit = jac_lin_vars_jit(self.data)

        self.assertEqual(len(j_jit), 2)
        np.testing.assert_allclose(j_jit[0], j[0])
        np.testing.assert_allclose(j_jit[1], j[1])

        jac_lin_x = objax.Jacobian(self.f_lin, None, input_argnums=(0, ))
        jac_lin_x_jit = objax.Jit(jac_lin_x)
        j = jac_lin_x(self.data)
        j_jit = jac_lin_x_jit(self.data)

        self.assertEqual(len(j_jit), 1)
        np.testing.assert_allclose(j_jit[0], j[0])
Example #13
File: jit.py Project: qingliaowu/objax
    def test_trainvar_assign(self):
        m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])

        def increase():
            m[0].assign(m[0].value + 1)
            return m[0].value

        jit_increase = objax.Jit(increase, m.vars())
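        # Calling the jitted function must write the in-place assignment back to the TrainVar.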
        jit_increase()
        self.assertEqual(m[0].value.tolist(), [1., 1.])
Example #14
File: jit.py Project: srxzr/objax
    def test_constant_optimization(self):
        m = objax.nn.Linear(3, 4)
        jit_constant = objax.Jit(m, objax.VarCollection())
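        # With an empty VarCollection, Jit captures the module's weights as compile-time
        # constants, so later changes to m are invisible to jit_constant.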

        x = objax.random.normal((10, 3))
        self.assertEqual(((m(x) - jit_constant(x)) ** 2).sum(), 0)

        # Modify m (which was supposed to be constant!)
        m.b.assign(m.b.value + 1)
        self.assertEqual(((m(x) - jit_constant(x)) ** 2).sum(), 40)
Example #15
    def __init__(self, model: Callable, nclass: int, mnist=False, **kwargs):
        """
        Completely standard training. Nothing interesting to see here.
        """
        super().__init__(nclass, **kwargs)
        self.model = model(1 if mnist else 3, nclass)
        self.opt = objax.optimizer.Momentum(self.model.vars())
        self.model_ema = objax.optimizer.ExponentialMovingAverageModule(
            self.model, momentum=0.999, debias=True)

        @objax.Function.with_vars(self.model.vars())
        def loss(x, label):
            logit = self.model(x, training=True)
            loss_wd = 0.5 * sum(
                (v.value**2).sum()
                for k, v in self.model.vars().items() if k.endswith('.w'))
            loss_xe = objax.functional.loss.cross_entropy_logits(logit,
                                                                 label).mean()
            return loss_xe + loss_wd * self.params.weight_decay, {
                'losses/xe': loss_xe,
                'losses/wd': loss_wd
            }

        gv = objax.GradValues(loss, self.model.vars())
        self.gv = gv

        @objax.Function.with_vars(self.vars())
        def train_op(progress, x, y):
            g, v = gv(x, y)
            lr = self.params.lr * jn.cos(progress * (7 * jn.pi) / (2 * 8))
            lr = lr * jn.clip(progress * 100, 0, 1)
            self.opt(lr, g)
            self.model_ema.update_ema()
            return {'monitors/lr': lr, **v[1]}

        self.predict = objax.Jit(
            objax.nn.Sequential(
                [objax.ForceArgs(self.model_ema, training=False)]))
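        # predict runs the EMA model with training=False; train_op below is JIT-compiled as well.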

        self.train_op = objax.Jit(train_op)
Example #16
def validate(args):
    model = create_model(args.model, pretrained=True)
    print(f'Created {args.model} model. Validating...')

    eval_step = objax.Jit(
        lambda images, labels: eval_forward(model, images, labels),
        model.vars())
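    # eval_step is compiled against the model variables and returns per-batch top-1 / top-5 correct counts.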

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data)
    else:
        dataset = Dataset(args.data)

    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=False,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=8,
                           crop_pct=data_config['crop_pct'])

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, (images, labels) in enumerate(loader):
        images = images.numpy()
        labels = labels.numpy()

        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += top1_count
        correct_top5 += top5_count
        total_examples += images.shape[0]

        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}]  '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(
        f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
Example #17
File: jit.py Project: qingliaowu/objax
    def test_trainvar_and_ref_assign(self):
        m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])
        m.append(objax.TrainRef(m[0]))
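        # m[1] is a TrainRef aliasing m[0]; assigning through either updates the same underlying variable.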

        def increase():
            m[0].assign(m[0].value + 1)
            m[1].assign(m[1].value + 1)
            return m[0].value

        jit_increase = objax.Jit(increase, m.vars())
        v = jit_increase()
        self.assertEqual(v.tolist(), [2., 2.])
        self.assertEqual(m[0].value.tolist(), [2., 2.])
Example #18
File: repr.py Project: utkarshgiri/objax
    def test_transform(self):
        def myloss(x):
            return (x ** 2).mean()

        g = objax.Grad(myloss, variables=objax.VarCollection(), input_argnums=(0,))
        gv = objax.GradValues(myloss, variables=objax.VarCollection(), input_argnums=(0,))
        gvp = objax.privacy.dpsgd.PrivateGradValues(myloss, objax.VarCollection(), noise_multiplier=1.,
                                                    l2_norm_clip=0.5, microbatch=1)
        self.assertEqual(repr(g), 'objax.Grad(f=myloss, input_argnums=(0,))')
        self.assertEqual(repr(gv), 'objax.GradValues(f=myloss, input_argnums=(0,))')
        self.assertEqual(repr(gvp), 'objax.privacy.dpsgd.gradient.PrivateGradValues(f=myloss, noise_multiplier=1.0,'
                                    ' l2_norm_clip=0.5, microbatch=1, batch_axis=(0,))')
        self.assertEqual(repr(objax.Jit(gv)),
                         'objax.Jit(f=objax.GradValues(f=myloss, input_argnums=(0,)), static_argnums=None)')
        self.assertEqual(repr(objax.Jit(myloss, vc=objax.VarCollection())),
                         'objax.Jit(f=objax.Function(f=myloss), static_argnums=None)')
        self.assertEqual(repr(objax.Parallel(gv)),
                         "objax.Parallel(f=objax.GradValues(f=myloss, input_argnums=(0,)),"
                         " reduce=concatenate(*, axis=0), axis_name='device', static_argnums=None)")
        self.assertEqual(repr(objax.Vectorize(myloss, vc=objax.VarCollection())),
                         'objax.Vectorize(f=objax.Function(f=myloss), batch_axis=(0,))')
        self.assertEqual(repr(objax.ForceArgs(gv, training=True, word='hello')),
                         "objax.ForceArgs(module=GradValues, training=True, word='hello')")
Example #19
def build_train_step_fn(model, loss_fn):
    gradient_loss = objax.GradValues(loss_fn, model.vars())
    optimiser = objax.optimizer.Adam(model.vars())

    def train_step(learning_rate, rgb_img_t1, true_dither_t0, true_dither_t1):
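        # One optimization step: compute gradients, clip them, apply Adam, and return per-tensor gradient norms.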
        grads, _loss = gradient_loss(
            rgb_img_t1, true_dither_t0, true_dither_t1)
        grads = u.clip_gradients(grads, theta=opts.gradient_clip)
        optimiser(learning_rate, grads)
        grad_norms = [jnp.linalg.norm(g) for g in grads]
        return grad_norms

    if JIT:
        train_step = objax.Jit(
            train_step, gradient_loss.vars() + optimiser.vars())
    return train_step
Example #20
    def test_trainvar_jit_assign(self):
        # Set data
        ndim = 2
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.zeros(ndim))
        b = objax.TrainVar(jn.zeros(1))

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            b.assign(b.value + 1)
            w.assign(w.value - 1)
            return 0.5 * ((y - pred)**2).mean()

        grad = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}))

        def jloss(wb, x, y):
            w, b = wb
            pred = jn.dot(x, w) + b
            return 0.5 * ((y - pred)**2).mean()

        def jit_op(x, y):
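            # Compute the Objax gradient, then apply extra in-place variable updates;
            # jgrad below is the plain JAX reference for comparison.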
            g = grad(x, y)
            b.assign(b.value * 2)
            w.assign(w.value * 3)
            return g

        jit_op = objax.Jit(jit_op, objax.VarCollection(dict(b=b, w=w)))
        jgrad = jax.grad(jloss)

        jg = jgrad([w.value, b.value], data, labels)
        g = jit_op(data, labels)
        self.assertEqual(g[0].shape, tuple([ndim]))
        self.assertEqual(g[1].shape, tuple([1]))
        np.testing.assert_allclose(g[0], jg[0])
        np.testing.assert_allclose(g[1], jg[1])
        self.assertEqual(w.value.tolist(), [-3., -3.])
        self.assertEqual(b.value.tolist(), [2.])

        jg = jgrad([w.value, b.value], data, labels)
        g = jit_op(data, labels)
        np.testing.assert_allclose(g[0], jg[0])
        np.testing.assert_allclose(g[1], jg[1])
        self.assertEqual(w.value.tolist(), [-12., -12.])
        self.assertEqual(b.value.tolist(), [6.])
Example #21
def validate(args):
    model = create_model(args.model, pretrained=True)
    print(f'Created {args.model} model. Validating...')

    eval_step = objax.Jit(
        lambda images, labels: eval_forward(model, images, labels),
        model.vars())

    """Runs evaluation and returns top-1 accuracy."""
    image_size = model.default_cfg['input_size'][-1]
    test_ds, num_batches = imagenet_data.load(
        imagenet_data.Split.TEST,
        is_training=False,
        image_size=image_size,
        batch_dims=[args.batch_size],
        chw=True,
        mean=tuple([x * 255 for x in model.default_cfg['mean']]),
        std=tuple([x * 255 for x in model.default_cfg['std']]),
        tfds_data_dir=args.data)

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, batch in enumerate(test_ds):
        images, labels = batch['images'], batch['labels']
        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]

        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{num_batches}]  '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
          f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
Example #22
     def __init__(self):
         fn = functools.partial(wide_resnet.WideResNet, depth=28, width=6)
         self.model = fn(3*4,2)
         
         model_vars = self.model.vars()
         self.ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999, debias=True)
 
         
         def predict_op(x,y):
             # The model takes SEVERAL images and checks if they all correspond
             # to the same original image.
             # Guaranteed that the first N-1 all do, the test is if the last does.
             xx = jn.concatenate([jn.abs(x),
                                  jn.abs(y)],
                                 axis=1)
             return self.model(xx, training=False)
         
         self.predict = objax.Jit(self.ema.replace_vars(predict_op), model_vars + self.ema.vars())
Example #23
 def test_jit_parallel_reducers(self):
     """JIT parallel reductions."""
     f = objax.nn.Linear(3, 4)
     x = objax.random.normal((96, 3))
     y = f(x)
     zl = []
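     # Try four reductions: per-device outputs (no reduction), first device only, mean, and sum.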
     for reduce in (lambda x: x,
                    lambda x: x[0],
                    lambda x: x.mean(0),
                    lambda x: x.sum(0)):
         fp = objax.Jit(objax.Parallel(f, reduce=reduce))
         with fp.vars().replicate():
             zl.append(fp(x))
     znone, zfirst, zmean, zsum = zl
     self.assertAlmostEqual(jn.square(jn.array(y.split(8)) - znone).sum(), 0, places=8)
     self.assertAlmostEqual(jn.square(y.split(8)[0] - zfirst).sum(), 0, places=8)
     self.assertAlmostEqual(jn.square(np.mean(y.split(8), 0) - zmean).sum(), 0, places=8)
     self.assertAlmostEqual(jn.square(np.sum(y.split(8), 0) - zsum).sum(), 0, places=8)
Example #24
File: objax2tf.py Project: google/objax
 def disabled_test_convert_wrn(self):
     # Make a model
     model = WideResNet(NCHANNELS, NCLASSES, depth=4, width=1)
     # Prediction op without JIT
     predict_op = objax.nn.Sequential(
         [objax.ForceArgs(model, training=False), objax.functional.softmax])
     predict_tf = objax.util.Objax2Tf(predict_op)
     # Compare results
     self.verify_converted_predict_op(predict_op,
                                      predict_tf,
                                      shape=(BATCH_SIZE, NCHANNELS,
                                             IMAGE_SIZE, IMAGE_SIZE))
     # Predict op with JIT
     predict_op_jit = objax.Jit(predict_op)
     predict_tf_jit = objax.util.Objax2Tf(predict_op_jit)
     # Compare results
     self.verify_converted_predict_op(predict_op_jit,
                                      predict_tf_jit,
                                      shape=(BATCH_SIZE, NCHANNELS,
                                             IMAGE_SIZE, IMAGE_SIZE))
Example #25
    logit = model(x, training=True)
    return objax.functional.loss.cross_entropy_logits_sparse(logit,
                                                             label).mean()


gv = objax.GradValues(loss, model.vars())


@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars())
def train_op(x, y, lr):
    g, v = gv(x, y)
    opt(lr=lr, grads=g)
    return v


train_op = objax.Jit(train_op)
predict = objax.Jit(
    objax.nn.Sequential(
        [objax.ForceArgs(model, training=False), objax.functional.softmax]))


def augment(x):
    if random.random() < .5:
        x = x[:, :, :, ::-1]  # Flip the batch images horizontally (mirror along the width axis)
    # Pixel-shift all images in the batch by up to 4 pixels in any direction.
    x_pad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'reflect')
    rx, ry = np.random.randint(0, 8), np.random.randint(0, 8)
    x = x_pad[:, :, rx:rx + 32, ry:ry + 32]
    return x

Example #26
trainable_vars = model.vars() + inf.vars()
energy = objax.GradValues(inf.energy, trainable_vars)

lr_adam = 0.2
lr_newton = 0.2
iters = 100
opt = objax.optimizer.Adam(trainable_vars)


def train_op():
    inf(model, lr=lr_newton)  # perform inference and update variational params
    dE, E = energy(model)  # compute energy and its gradients w.r.t. hypers
    return dE, E


train_op = objax.Jit(train_op, trainable_vars)

t0 = time.time()
for i in range(1, iters + 1):
    grad, loss = train_op()
    opt(lr_adam, grad)
    print('iter %2d: energy: %1.4f' % (i, loss[0]))
t1 = time.time()
print('optimisation time: %2.2f secs' % (t1 - t0))

# calculate posterior predictive distribution via filtering and smoothing at train & test locations:
print('calculating the posterior predictive distribution ...')
t0 = time.time()
nlpd = model.negative_log_predictive_density(X=t_test, R=r_test, Y=Y_test)
t1 = time.time()
print('prediction time: %2.2f secs' % (t1 - t0))
Example #27
def loss(x, label):
    return objax.functional.loss.sigmoid_cross_entropy_logits(model(x)[:, 0], label).mean()


gv = objax.GradValues(loss, model.vars())


@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars())
def train_op(x, label):
    g, v = gv(x, label)  # returns gradients, loss
    opt(lr, g)
    return v


# This line is optional: it is compiling the code to make it faster.
train_op = objax.Jit(train_op)

# Training
for epoch in range(epochs):
    # Train
    avg_loss = 0
    for it in range(0, train.image.shape[0], batch):
        sel = np.random.randint(size=(batch,), low=0, high=train.image.shape[0])
        avg_loss += float(train_op(train.image[sel], train.label[sel])[0]) * batch
    avg_loss /= it + batch

    # Eval
    accuracy = 0
    for it in range(0, test.image.shape[0], batch):
        x, y = test.image[it: it + batch], test.label[it: it + batch]
        accuracy += (np.round(objax.functional.sigmoid(model(x)))[:, 0] == y).sum()
Example #28
 def __init__(self, model, *args, **kwargs):
     super().__init__(model, *args, **kwargs)
     self.loss = objax.Jit(self.loss, model.vars())
     #self.model = objax.Jit(self.model)
     self.gradvals = objax.Jit(objax.GradValues(self.loss, model.vars()))
Example #29

def loss(x, y):
    return ((y - net(x))**2).mean()


gv = objax.GradValues(loss, net.vars())


def train_op():
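    # One training step on the fixed (source, target) batch: compute MSE gradients and apply them with lr 0.01.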
    g, v = gv(source, target)
    opt(0.01, g)
    return v


train_op = objax.Jit(train_op, gv.vars() + opt.vars())

for i in range(100):
    train_op()

plt.plot(source, net(source), label='prediction')
plt.plot(source, (target - net(source))**2, label='loss')
plt.plot(source, target, label='target')
plt.legend()
plt.show()

print('MAML training')
net = make_net()
opt = objax.optimizer.Adam(net.vars())

Example #30
    return objax.functional.loss.cross_entropy_logits(logit, label).mean()


gv = objax.GradValues(loss, model.vars())


@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars() +
                          model_ema.vars())
def train_op(x, xl):
    g, v = gv(x, xl)  # returns gradients, loss
    opt(lr, g)
    model_ema.update_ema()
    return v


train_op = objax.Jit(train_op)  # Compile train_op to make it run faster.
predict = objax.Jit(model_ema)

# Training
print(model.vars())
print(f'Visualize results with: tensorboard --logdir "{logdir}"')
print(
    "Disclaimer: This code demonstrates the DNNet class. For SOTA accuracy use a CNN instead."
)
with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
    for epoch in range(num_train_epochs):
        # Train one epoch
        summary = Summary()
        loop = trange(0,
                      train_size,
                      batch,