def _mnist_tasks(head=None):
    """Build the MNIST training task and its matching evaluation task.

    Args:
      head: Optional adaptor layer. When it is not None, both the
        cross-entropy loss layer and the accuracy layer are wrapped as
        `tl.Serial(head, layer)`.

    Returns:
      A 2-tuple `(train_task, eval_task)`:
        * `train_task` optimizes the (possibly wrapped) cross-entropy loss
          with `adam.Adam(0.001)` over a cycled MNIST train stream;
        * `eval_task` reports the (possibly wrapped) cross-entropy and
          accuracy, named 'CrossEntropy' and 'Accuracy', over 10 batches of
          a cycled MNIST eval stream.
    """
    layers = [tl.CrossEntropyLoss(), tl.Accuracy()]
    if head is not None:
        # The same `head` instance is placed in front of both layers.
        layers = [tl.Serial(head, layer) for layer in layers]
    loss_layer, accuracy_layer = layers

    train_task = training.TrainTask(
        itertools.cycle(_mnist_dataset().train_stream(1)),
        loss_layer,
        adam.Adam(0.001),
    )
    # The exact loss-layer object used for training doubles as the first
    # evaluation metric.
    eval_task = training.EvalTask(
        itertools.cycle(_mnist_dataset().eval_stream(1)),
        [loss_layer, accuracy_layer],
        n_eval_batches=10,
        metric_names=['CrossEntropy', 'Accuracy'],
    )
    return train_task, eval_task
def test_train_mnist_multitask(self, mock_stdout):
    """Train two-head MNIST model a bit, to compare to other implementations.

    The classification task reads output 0 of the model (through
    `tl.Select([0], n_in=2)`); the auxiliary brightness-regression task
    reads output 1 (through `tl.Select([1])`). After 100 steps the logged
    metrics are read back from `mock_stdout` and checked against fixed
    thresholds.
    """
    model = _build_model(two_heads=True)

    # MNIST classification task, attached to model output 0.
    classify_train, classify_eval = _mnist_tasks(head=tl.Select([0], n_in=2))

    # Auxiliary brightness prediction task, attached to model output 1.
    def brightness_l2():
        # Built fresh on every call so the train and eval tasks each get
        # their own layer instance.
        return tl.Serial(tl.Select([1]), tl.L2Loss())

    brightness_train = training.TrainTask(
        itertools.cycle(_mnist_brightness_dataset().train_stream(1)),
        brightness_l2(),
        adam.Adam(0.001),
    )
    brightness_eval = training.EvalTask(
        itertools.cycle(_mnist_brightness_dataset().eval_stream(1)),
        [brightness_l2()],
        n_eval_batches=1,
        metric_names=['L2'],
    )

    def should_eval(step_n):
        # Evaluate on every 20th step.
        return step_n % 20 == 0

    def task_index(step_n):
        # Presumably an index into `tasks`: even steps pick classification,
        # odd steps pick brightness regression.
        return step_n % 2

    loop = training.Loop(
        model,
        tasks=[classify_train, brightness_train],
        eval_tasks=[classify_eval, brightness_eval],
        eval_at=should_eval,
        which_task=task_index,
    )

    loop.run(n_steps=100)
    self.assertEqual(loop.step, 100)

    # MNIST eval accuracy must end up above 0.8.
    self.assertGreater(_read_metric('Accuracy', mock_stdout), 0.8)
    # Brightness prediction L2 error must end up below 0.03.
    self.assertLess(_read_metric('L2', mock_stdout), 0.03)
def _mnist_tasks(head=None):
    """Creates MNIST training and evaluation tasks.

    This definition replaces any earlier `_mnist_tasks` in the module, so it
    accepts the same optional `head` argument that callers such as
    `test_train_mnist_multitask` pass (e.g. `head=tl.Select([0], n_in=2)`).

    Args:
      head: Optional adaptor layer. When not None, every loss/metric layer in
        both tasks is wrapped as `tl.Serial(head, layer)`. When None (the
        default), the layers are used unwrapped, exactly as before.

    Returns:
      A pair `(train_task, eval_task)`: the MNIST training task using
      cross-entropy loss with `adam.Adam(0.001)`, and the MNIST evaluation
      task reporting cross-entropy and accuracy (named 'CrossEntropy' and
      'Accuracy') over 10 eval batches.
    """
    def _with_head(layer):
        # Leave the layer untouched when no adaptor is requested, so the
        # head=None path is identical to the original parameterless version.
        return layer if head is None else tl.Serial(head, layer)

    # Train and eval tasks get separate CrossEntropyLoss instances, as in
    # the original version of this helper.
    task = training.TrainTask(
        itertools.cycle(_mnist_dataset().train_stream(1)),
        _with_head(tl.CrossEntropyLoss()),
        adam.Adam(0.001),
    )
    eval_task = training.EvalTask(
        itertools.cycle(_mnist_dataset().eval_stream(1)),
        (_with_head(tl.CrossEntropyLoss()), _with_head(tl.Accuracy())),
        n_eval_batches=10,
        metric_names=('CrossEntropy', 'Accuracy'),
    )
    return (task, eval_task)