def test_train_mnist(self):
  """Train MNIST model (almost) fully, to compare to other implementations.

  Evals for cross-entropy loss and accuracy are run every 50 steps; their
  values are visible in the test log.
  """
  gin.parse_config([
      'batch_fn.batch_size_per_device = 256',
      'batch_fn.eval_batch_size = 256',
  ])

  # Simple two-hidden-layer MLP classifier over flattened images.
  model = tl.Serial(
      tl.Flatten(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(10),
      tl.LogSoftmax(),
  )

  train_stream = itertools.cycle(inputs.inputs('mnist').train_stream(1))
  eval_stream = itertools.cycle(inputs.inputs('mnist').eval_stream(1))

  train_task = training.TrainTask(
      train_stream, tl.CrossEntropyLoss(), adafactor.Adafactor(.02))
  eval_task = training.EvalTask(
      eval_stream,
      [tl.CrossEntropyLoss(), tl.AccuracyScalar()],
      names=['CrossEntropyLoss', 'AccuracyScalar'],
      eval_at=lambda step_n: step_n % 50 == 0,  # evaluate every 50 steps
      eval_N=10)

  loop = training.Loop(model, train_task, evals=eval_task)
  loop.run(n_steps=1000)
  self.assertEqual(loop.current_step, 1000)
def test_c4(self):
  """Smoke test: building C4 inputs with default tokenization doesn't throw."""
  bindings = [
      ('shuffle_and_batch_data.preprocess_fun', inputs.c4_preprocess),
      ('c4_preprocess.max_target_length', 2048),
      ('batch_fn.batch_size_per_device', 8),
      ('batch_fn.eval_batch_size', 8),
      ('batch_fn.max_eval_length', 2048),
      ('batch_fn.buckets', ([2049], [8, 1])),
  ]
  for name, value in bindings:
    gin.bind_parameter(name, value)
  # Just make sure this doesn't throw.
  _ = inputs.inputs('c4', data_dir=_TESTDATA, input_name='targets',
                    target_name='text')
def test_c4(self):
  """Smoke test: building C4 inputs with SentencePiece tokenization doesn't throw."""
  spm_path = os.path.join(_TESTDATA, 'sentencepiece.model')
  bindings = [
      ('shuffle_and_batch_data.preprocess_fun', inputs.c4_preprocess),
      ('c4_preprocess.max_target_length', 2048),
      ('c4_preprocess.tokenization', 'spc'),
      ('c4_preprocess.spm_path', spm_path),
      ('batch_fn.batch_size_per_device', 8),
      ('batch_fn.eval_batch_size', 8),
      ('batch_fn.max_eval_length', 2048),
      ('batch_fn.buckets', ([2049], [8, 1])),
  ]
  for name, value in bindings:
    gin.bind_parameter(name, value)
  # Just make sure this doesn't throw.
  _ = inputs.inputs('c4', data_dir=_TESTDATA, input_name='targets',
                    target_name='text')
def _mnist_dataset():
  """Loads the standard MNIST data set, with per-example weights added.

  NOTE(review): the original docstring claimed caching; no caching decorator
  is visible here — confirm against the full file.
  """
  mnist = inputs.inputs('mnist')
  return _add_weights(mnist)