def test_train_mnist(self):
    """Run a 1000-step MNIST training session and check the step counter.

    The network is a flatten layer, two 512-unit dense/ReLU pairs, and a
    10-unit dense layer followed by log-softmax. Cross-entropy loss and
    accuracy are evaluated whenever the step number is a multiple of 50;
    those values appear in the test log, which lets the run be compared
    against other implementations.
    """

    def _is_eval_step(step_n):
        # Evaluation is triggered on every 50th step.
        return step_n % 50 == 0

    layers = [
        tl.Flatten(),
        tl.Dense(512),
        tl.Relu(),
        tl.Dense(512),
        tl.Relu(),
        tl.Dense(10),
        tl.LogSoftmax(),
    ]
    classifier = tl.Serial(*layers)

    # Both streams are cycled so the loop never runs out of batches.
    train_batches = itertools.cycle(_mnist_dataset().train_stream(1))
    train_task = training.TrainTask(
        train_batches,
        tl.CrossEntropyLoss(),
        adafactor.Adafactor(.02),
    )

    eval_batches = itertools.cycle(_mnist_dataset().eval_stream(1))
    eval_metrics = [tl.CrossEntropyLoss(), tl.Accuracy()]
    evaluation = training.EvalTask(
        eval_batches,
        eval_metrics,
        n_eval_batches=10,
    )

    loop = training.Loop(
        classifier,
        [train_task],
        eval_tasks=[evaluation],
        eval_at=_is_eval_step,
    )
    loop.run(n_steps=1000)

    self.assertEqual(loop.step, 1000)
def test_train_mnist(self):
    """Run a 1000-step MNIST training session and check the step counter.

    gin is first configured with a per-device batch size and an eval batch
    size of 256 for ``batch_fn``. The network is a flatten layer, two
    512-unit dense/ReLU pairs, and a 10-unit dense layer followed by
    log-softmax. Cross-entropy loss and scalar accuracy are evaluated
    whenever the step number is a multiple of 50; those values appear in
    the test log, which lets the run be compared against other
    implementations.
    """

    def _is_eval_step(step_n):
        # Evaluation is triggered on every 50th step.
        return step_n % 50 == 0

    batch_bindings = [
        'batch_fn.batch_size_per_device = 256',
        'batch_fn.eval_batch_size = 256',
    ]
    gin.parse_config(batch_bindings)

    layers = [
        tl.Flatten(),
        tl.Dense(512),
        tl.Relu(),
        tl.Dense(512),
        tl.Relu(),
        tl.Dense(10),
        tl.LogSoftmax(),
    ]
    classifier = tl.Serial(*layers)

    # Both streams are cycled so the loop never runs out of batches.
    train_batches = itertools.cycle(_mnist_dataset().train_stream(1))
    train_task = training.TrainTask(
        train_batches,
        tl.CrossEntropyLoss(),
        adafactor.Adafactor(.02),
    )

    eval_batches = itertools.cycle(_mnist_dataset().eval_stream(1))
    eval_metrics = [tl.CrossEntropyLoss(), tl.AccuracyScalar()]
    evaluation = training.EvalTask(
        eval_batches,
        eval_metrics,
        names=['CrossEntropyLoss', 'AccuracyScalar'],
        eval_at=_is_eval_step,
        eval_N=10,
    )

    loop = training.Loop(classifier, train_task, eval_task=evaluation)
    loop.run(n_steps=1000)

    self.assertEqual(loop.current_step(), 1000)
def test_mnist(self) -> None:
    """Train on MNIST through ``TraxTrainer`` and check the reported step.

    Loads the 'mnist' data from ``TestMnist.tfds_dir``, loads a model via
    ``get_model`` with ``num_classes=10``, trains for ``self.epochs`` with
    cross-entropy loss and an Adafactor optimizer, and asserts that the
    returned session's ``current_step`` equals ``self.epochs``.
    """

    def _emit_metrics_at(step_n):
        # Passed as the metric emission predicate: true on every 50th step.
        return step_n % 50 == 0

    mnist_trainer = TraxTrainer()
    mnist_trainer.load_data('mnist', tfds_dir=TestMnist.tfds_dir)
    mnist_trainer.load_model(get_model, False, num_classes=10)

    # Arguments are bound in the same order the call evaluates them.
    n_epochs = self.epochs
    output_dir = TestMnist.model_dir
    reported_metrics = [tl.CrossEntropyLoss(), tl.Accuracy()]
    loss_layer = tl.CrossEntropyLoss()
    optimizer = adafactor.Adafactor(.02)

    session = mnist_trainer.train(
        epochs=n_epochs,
        model_dir=output_dir,
        metric_emit_freq=_emit_metrics_at,
        metrics=reported_metrics,
        loss=loss_layer,
        optimizer=optimizer,
        callbacks=None,
        save_directory=None,
    )

    self.assertEqual(session.current_step, self.epochs)