Example 1
import itertools

from trax import layers as tl
from trax.optimizers import adam
from trax.supervised import training


def _mnist_tasks(head=None):
    """Creates MNIST training and evaluation tasks.

  Args:
    head: Adaptor layer to put before loss and accuracy layers in the tasks.

  Returns:
    A pair (train_task, eval_task) consisting of the MNIST training task and the
    MNIST evaluation task using cross-entropy as loss and accuracy as metric.
  """
    loss = tl.CrossEntropyLoss()
    accuracy = tl.Accuracy()
    if head is not None:
        loss = tl.Serial(head, loss)
        accuracy = tl.Serial(head, accuracy)
    train_task = training.TrainTask(
        itertools.cycle(_mnist_dataset().train_stream(1)),
        loss,
        adam.Adam(0.001),
    )
    eval_task = training.EvalTask(
        itertools.cycle(_mnist_dataset().eval_stream(1)),
        [loss, accuracy],
        n_eval_batches=10,
        metric_names=['CrossEntropy', 'Accuracy'],
    )
    return (train_task, eval_task)
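For context, here is a minimal usage sketch showing how the returned tasks can drive a training loop, mirroring the Loop call in Example 2 below; the dense classifier model is an assumption for illustration, not part of the original examples.

from trax import layers as tl
from trax.supervised import training

# Assumed toy classifier; any model producing 10 per-class activations fits here.
model = tl.Serial(
    tl.Flatten(),
    tl.Dense(512),
    tl.Relu(),
    tl.Dense(10),
    tl.LogSoftmax(),
)

train_task, eval_task = _mnist_tasks()
loop = training.Loop(
    model,
    tasks=[train_task],
    eval_tasks=[eval_task],
    eval_at=lambda step_n: step_n % 20 == 0,
)
loop.run(n_steps=100)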
Example 2
    def test_train_mnist_multitask(self, mock_stdout):
        """Train two-head MNIST model a bit, to compare to other implementations."""
        mnist_model = _build_model(two_heads=True)
        # MNIST classification task.
        (cls_task, cls_eval_task) = _mnist_tasks(head=tl.Select([0], n_in=2))
        # Auxiliary brightness prediction task.
        reg_task = training.TrainTask(
            itertools.cycle(_mnist_brightness_dataset().train_stream(1)),
            tl.Serial(tl.Select([1]), tl.L2Loss()),
            adam.Adam(0.001),
        )
        reg_eval_task = training.EvalTask(
            itertools.cycle(_mnist_brightness_dataset().eval_stream(1)),
            [tl.Serial(tl.Select([1]), tl.L2Loss())],
            n_eval_batches=1,
            metric_names=['L2'],
        )
        training_session = training.Loop(
            mnist_model,
            tasks=[cls_task, reg_task],
            eval_tasks=[cls_eval_task, reg_eval_task],
            eval_at=lambda step_n: step_n % 20 == 0,
            which_task=lambda step_n: step_n % 2,
        )

        training_session.run(n_steps=100)
        self.assertEqual(training_session.step, 100)

        # Assert that we reach at least 80% eval accuracy on MNIST.
        self.assertGreater(_read_metric('Accuracy', mock_stdout), 0.8)
        # Assert that we get below 0.03 brightness prediction error.
        self.assertLess(_read_metric('L2', mock_stdout), 0.03)
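The helper _build_model used above is not shown in these examples. Below is a plausible sketch of the kind of two-headed model it could return; the layer sizes and exact structure are assumptions, chosen only so that tl.Select([0], n_in=2) picks the classification output and tl.Select([1]) picks the brightness output.

from trax import layers as tl

# Hypothetical two-headed MNIST model (illustrative only; not the actual
# _build_model from the original test module).
def _build_two_head_model():
    return tl.Serial(
        tl.Flatten(),
        tl.Dense(512),
        tl.Relu(),
        tl.Branch(
            tl.Serial(tl.Dense(10), tl.LogSoftmax()),  # output 0: class scores
            tl.Dense(1),                               # output 1: predicted brightness
        ),
    )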
Example 3
def _mnist_tasks():
    """Creates MNIST training and evaluation tasks."""
    task = training.TrainTask(
        itertools.cycle(_mnist_dataset().train_stream(1)),
        tl.CrossEntropyLoss(),
        adam.Adam(0.001),
    )
    eval_task = training.EvalTask(
        itertools.cycle(_mnist_dataset().eval_stream(1)),
        (tl.CrossEntropyLoss(), tl.Accuracy()),
        n_eval_batches=10,
        metric_names=('CrossEntropy', 'Accuracy'),
    )
    return (task, eval_task)