Example #1
    def test_compute_metrics_correct(self):
        """Tests output when logit outputs indicate correct classification."""
        logits, labels_correct = self._create_logits_labels(True)
        logits = training._shard_batch(logits)
        labels_correct = training._shard_batch(labels_correct)

        p_compute_metrics = jax.pmap(utils.compute_metrics, axis_name='batch')
        metrics = p_compute_metrics(logits, labels_correct)
        loss = metrics['loss']
        accuracy = metrics['accuracy']

        with self.subTest(name='loss_type'):
            self.assertIsInstance(loss, jnp.ndarray)

        with self.subTest(name='loss_len'):
            self.assertEqual(loss.size, 1)

        with self.subTest(name='loss_values'):
            # Every per-replica loss value should be non-negative.
            self.assertTrue(jnp.all(loss >= 0))

        with self.subTest(name='accuracy_type'):
            self.assertIsInstance(accuracy, jnp.ndarray)

        with self.subTest(name='accuracy_len'):
            self.assertEqual(accuracy.size, 1)

        with self.subTest(name='accuracy_values'):
            # Every per-replica accuracy should be (almost) exactly 1.0.
            self.assertTrue(jnp.allclose(accuracy, 1.0))
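
Each example shards its inputs with training._shard_batch before calling jax.pmap. That helper is not shown in these snippets; below is a minimal sketch of what such a helper typically looks like, assuming it only splits the leading batch axis across local devices (this implementation is an assumption mirroring flax.training.common_utils.shard, not the source's code).

import jax

# Hypothetical sketch of a _shard_batch helper: reshape each array's leading
# batch axis to (local_device_count, per_device_batch, ...) so that jax.pmap
# can map one slice to each local device.
def _shard_batch(batch):
    device_count = jax.local_device_count()
    return jax.tree_util.tree_map(
        lambda x: x.reshape((device_count, -1) + x.shape[1:]), batch)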
Example #2
    def test_train_one_step(self):
        """Tests training loop over one step."""
        iterator = self._dataset.get_train()
        batch = next(iterator)

        state = jax_utils.replicate(self._state)
        optimizer = jax_utils.replicate(self._optimizer.create(self._model))

        self._rng, step_key = jax.random.split(self._rng)
        batch = training._shard_batch(batch)
        sharded_keys = common_utils.shard_prng_key(step_key)

        p_train_step = jax.pmap(
            functools.partial(training.train_step,
                              learning_rate_fn=self._learning_rate_fn),
            axis_name='batch')
        _, _, loss, gradient_norm = p_train_step(optimizer, batch,
                                                 sharded_keys, state)

        loss = jnp.mean(loss)
        gradient_norm = jax_utils.unreplicate(gradient_norm)

        with self.subTest(name='test_loss_range'):
            self.assertBetween(loss, self._min_loss, self._max_loss)

        with self.subTest(name='test_gradient_norm'):
            self.assertGreaterEqual(gradient_norm, 0)
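
Example #2 mixes two data layouts: the optimizer and model state are replicated (every device holds an identical copy), while the batch and PRNG keys are sharded (every device holds its own slice). A minimal sketch of that split, assuming Flax's jax_utils and common_utils behave as documented; the params pytree is purely illustrative.

import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training import common_utils

# Sharded input: one independent PRNG key per local device, so per-device
# randomness (e.g. dropout masks) differs along the pmap axis.
rng = jax.random.PRNGKey(0)
sharded_keys = common_utils.shard_prng_key(rng)
assert sharded_keys.shape[0] == jax.local_device_count()

# Replicated input: identical copies stacked along a new leading device axis.
params = {'w': jnp.zeros((2, 3))}
replicated = jax_utils.replicate(params)
assert replicated['w'].shape == (jax.local_device_count(), 2, 3)
restored = jax_utils.unreplicate(replicated)  # take one copy back off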
Example #3
    def test_eval_batch(self):
        """Tests model per-batch evaluation function."""
        state = jax_utils.replicate(self._state)
        optimizer = jax_utils.replicate(self._optimizer.create(self._model))

        iterator = self._dataset.get_test()
        batch = next(iterator)
        batch = training._shard_batch(batch)

        metrics = jax.pmap(training._eval_step,
                           axis_name='batch')(optimizer.target, state, batch)

        loss = jnp.mean(metrics['loss'])
        accuracy = jnp.mean(metrics['accuracy'])

        with self.subTest(name='test_eval_batch_loss'):
            self.assertBetween(loss, self._min_loss, self._max_loss)

        with self.subTest(name='test_eval_batch_accuracy'):
            self.assertBetween(accuracy, 0., 1.)
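
None of the snippets show utils.compute_metrics or training._eval_step themselves. Below is a hypothetical sketch of a compute_metrics consistent with the assertions above (loss >= 0, accuracy in [0, 1], identical scalars on every replica); the actual source implementation may differ.

import jax
import jax.numpy as jnp

def compute_metrics(logits, labels):
    # Softmax cross-entropy and accuracy over this device's batch slice.
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    loss = -jnp.mean(jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1))
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    # Average across devices participating in the 'batch' pmap axis so every
    # replica returns the same scalar values.
    return jax.lax.pmean({'loss': loss, 'accuracy': accuracy},
                         axis_name='batch')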