Example #1
    def test_train_mnist(self):
        """Train MNIST model (almost) fully, to compare to other implementations.

        Evals for cross-entropy loss and accuracy are run every 50 steps;
        their values are visible in the test log.
        """
        gin.parse_config([
            'batch_fn.batch_size_per_device = 256',
            'batch_fn.eval_batch_size = 256',
        ])

        # A simple MLP classifier: flatten the image, two ReLU hidden layers
        # of width 512, then log-softmax over the 10 digit classes.
        mnist_model = tl.Serial(
            tl.Flatten(),
            tl.Dense(512),
            tl.Relu(),
            tl.Dense(512),
            tl.Relu(),
            tl.Dense(10),
            tl.LogSoftmax(),
        )
        # Training task: an endlessly cycled MNIST batch stream, cross-entropy
        # loss, and the Adafactor optimizer with learning rate 0.02.
        task = training.TrainTask(
            itertools.cycle(inputs.inputs('mnist').train_stream(1)),
            tl.CrossEntropyLoss(), adafactor.Adafactor(.02))
        # Eval task: cross-entropy and accuracy on 10 eval batches, run at
        # every step divisible by 50.
        evals = training.EvalTask(
            itertools.cycle(inputs.inputs('mnist').eval_stream(1)),
            [tl.CrossEntropyLoss(), tl.AccuracyScalar()],
            names=['CrossEntropyLoss', 'AccuracyScalar'],
            eval_at=lambda step_n: step_n % 50 == 0,
            eval_N=10)

        training_session = training.Loop(mnist_model, task, evals=evals)
        training_session.run(n_steps=1000)
        self.assertEqual(training_session.current_step, 1000)
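This test is excerpted from a larger Trax test module, so its imports are not shown. A minimal sketch of the module-level imports it assumes follows; the paths match older Trax releases (for example, the inputs module later moved under trax.data), so treat them as assumptions rather than exact locations.

# Assumed imports for the MNIST training test above (module paths may differ
# between Trax versions).
import itertools

import gin

from trax import layers as tl
from trax.optimizers import adafactor
from trax.supervised import inputs
from trax.supervised import training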
Example #2
    def test_c4(self):
        gin.bind_parameter('shuffle_and_batch_data.preprocess_fun',
                           inputs.c4_preprocess)
        gin.bind_parameter('c4_preprocess.max_target_length', 2048)

        gin.bind_parameter('batch_fn.batch_size_per_device', 8)
        gin.bind_parameter('batch_fn.eval_batch_size', 8)
        gin.bind_parameter('batch_fn.max_eval_length', 2048)
        gin.bind_parameter('batch_fn.buckets', ([2049], [8, 1]))

        # Just make sure this doesn't throw.
        _ = inputs.inputs('c4',
                          data_dir=_TESTDATA,
                          input_name='targets',
                          target_name='text')
Example #3
    def test_c4(self):
        gin.bind_parameter('shuffle_and_batch_data.preprocess_fun',
                           inputs.c4_preprocess)
        gin.bind_parameter('c4_preprocess.max_target_length', 2048)
        # Tokenize with SentencePiece, using the model file from the test data.
        gin.bind_parameter('c4_preprocess.tokenization', 'spc')
        gin.bind_parameter('c4_preprocess.spm_path',
                           os.path.join(_TESTDATA, 'sentencepiece.model'))

        gin.bind_parameter('batch_fn.batch_size_per_device', 8)
        gin.bind_parameter('batch_fn.eval_batch_size', 8)
        gin.bind_parameter('batch_fn.max_eval_length', 2048)
        gin.bind_parameter('batch_fn.buckets', ([2049], [8, 1]))

        # Just make sure this doesn't throw.
        _ = inputs.inputs('c4',
                          data_dir=_TESTDATA,
                          input_name='targets',
                          target_name='text')
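Both C4 tests also depend on module-level setup that the snippets omit: the gin and os imports, the Trax inputs module, and a _TESTDATA constant pointing at a directory of checked-in test data. A rough sketch of that setup is below, with the _TESTDATA path used only as a placeholder. Note that the batch_fn.buckets binding pairs length boundaries with per-bucket batch sizes, so with ([2049], [8, 1]) sequences of up to 2049 tokens are batched 8 at a time and anything longer falls back to a batch of 1.

# Assumed module-level setup for the two C4 tests above. The _TESTDATA value
# is a placeholder; the real test module points it at its own test data dir.
import os

import gin

from trax.supervised import inputs

_TESTDATA = os.path.join(os.path.dirname(__file__), 'testdata')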
Example #4
def _mnist_dataset():
    """Loads (and caches) the standard MNIST data set."""
    # _add_weights is a helper defined elsewhere in the test module (not shown here).
    return _add_weights(inputs.inputs('mnist'))