def test_compute_metrics_correct(self):
    """Tests output when logit outputs indicate correct classification.

    `self._create_logits_labels(True)` is expected (per this test's intent) to
    return logits that classify every label correctly. Both arrays are sharded
    with `training._shard_batch` and fed to `utils.compute_metrics` under
    `jax.pmap`. The test then checks that:
      * `loss` is a single-element `jnp.ndarray` whose entries are all >= 0;
      * `accuracy` is a single-element `jnp.ndarray` whose entries are all 1.0.
    """
    logits, labels_correct = self._create_logits_labels(True)
    logits = training._shard_batch(logits)
    labels_correct = training._shard_batch(labels_correct)

    p_compute_metrics = jax.pmap(utils.compute_metrics, axis_name='batch')
    metrics = p_compute_metrics(logits, labels_correct)
    loss = metrics['loss']
    accuracy = metrics['accuracy']

    with self.subTest(name='loss_type'):
        self.assertIsInstance(loss, jnp.ndarray)
    with self.subTest(name='loss_len'):
        # NOTE(review): pmap adds a leading device axis, so size == 1 assumes a
        # single local device — confirm this matches the test environment.
        self.assertEqual(loss.size, 1)
    with self.subTest(name='loss_values'):
        # Previously `loss.all()` reduced to a bool ("all nonzero"), which is
        # always >= 0, so negative losses slipped through. Check the values.
        self.assertTrue(
            bool(jnp.all(loss >= 0)), msg=f'loss has negative entries: {loss}')
    with self.subTest(name='accuracy_type'):
        self.assertIsInstance(accuracy, jnp.ndarray)
    with self.subTest(name='accuracy_Len'):
        self.assertEqual(accuracy.size, 1)
    with self.subTest(name='accuracy_values'):
        # Previously `accuracy.all()` was True for any nonzero accuracy (True
        # == 1.0), so e.g. 0.5 passed. Require every entry to be ~1.0.
        self.assertTrue(
            bool(jnp.allclose(accuracy, 1.0)),
            msg=f'accuracy is not 1.0: {accuracy}')
def test_train_one_step(self):
    """Runs one pmapped training step and sanity-checks its outputs.

    One batch is drawn from the training split. The model state and a freshly
    created optimizer are replicated across devices, and the test RNG is split
    to obtain this step's key. `training.train_step` (with the learning-rate
    schedule bound via `functools.partial`) is then run under `jax.pmap`.

    The device-averaged loss must lie between `self._min_loss` and
    `self._max_loss`, and the unreplicated gradient norm must be non-negative.
    """
    train_batch = next(self._dataset.get_train())
    replicated_state = jax_utils.replicate(self._state)
    fresh_optimizer = self._optimizer.create(self._model)
    replicated_optimizer = jax_utils.replicate(fresh_optimizer)

    # Advance the stored RNG so this step consumes its own key.
    self._rng, step_key = jax.random.split(self._rng)
    device_batch = training._shard_batch(train_batch)
    device_keys = common_utils.shard_prng_key(step_key)

    step_fn = functools.partial(
        training.train_step, learning_rate_fn=self._learning_rate_fn)
    parallel_step = jax.pmap(step_fn, axis_name='batch')
    _, _, per_device_loss, per_device_grad_norm = parallel_step(
        replicated_optimizer, device_batch, device_keys, replicated_state)

    mean_loss = jnp.mean(per_device_loss)
    grad_norm = jax_utils.unreplicate(per_device_grad_norm)

    with self.subTest(name='test_loss_range'):
        self.assertBetween(mean_loss, self._min_loss, self._max_loss)
    with self.subTest(name='test_gradient_norm'):
        self.assertGreaterEqual(grad_norm, 0)
def test_eval_batch(self):
    """Evaluates one test-split batch with the pmapped eval step.

    The model state and a freshly created optimizer are replicated across
    devices. One batch is drawn from the test split and sharded, and
    `training._eval_step` is run under `jax.pmap` on the optimizer's target.

    The device-averaged loss must lie between `self._min_loss` and
    `self._max_loss`, and the averaged accuracy must lie between 0 and 1.
    """
    replicated_state = jax_utils.replicate(self._state)
    replicated_optimizer = jax_utils.replicate(
        self._optimizer.create(self._model))
    eval_batch = training._shard_batch(next(self._dataset.get_test()))

    parallel_eval = jax.pmap(training._eval_step, axis_name='batch')
    eval_metrics = parallel_eval(
        replicated_optimizer.target, replicated_state, eval_batch)

    mean_loss = jnp.mean(eval_metrics['loss'])
    mean_accuracy = jnp.mean(eval_metrics['accuracy'])

    with self.subTest(name='test_eval_batch_loss'):
        self.assertBetween(mean_loss, self._min_loss, self._max_loss)
    with self.subTest(name='test_eval_batch_accuracy'):
        self.assertBetween(mean_accuracy, 0., 1.)