  def test_serial_composes(self):
    """Check that data.Serial works inside another data.Serial."""
    dataset = lambda _: ((i, i + 1) for i in range(10))
    serial1 = data.Serial(dataset, data.Shuffle(3))
    batches = data.Serial(serial1, data.Batch(10))
    batch = next(batches())
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (10,))
  def test_train_save_restore_sharded(self):
    """Saves and restores a sharded checkpoint to check for equivalence."""
    if fastmath.local_device_count() < 2:
      return  # multi-accelerator only
    base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
    train_data = data.Serial(lambda _: _very_simple_data(2, 2),
                             data.CountAndSkip('simple_data'))
    task = training.TrainTask(train_data(), tl.L2Loss(), optimizers.Adam(.0001))
    eval_task = training.EvalTask(
        _very_simple_data(2, 2),  # deliberately re-using training data
        [tl.L2Loss()],
        metric_names=['SGD.L2Loss'])
    tmp_dir = self.create_tempdir().full_path

    def _make_model_and_session():
      m = tl.Serial(tl.Dense(2))
      ts = training.Loop(m, [task],
                         eval_tasks=[eval_task],
                         eval_at=lambda step_n: step_n % 2 == 0,
                         output_dir=tmp_dir)
      return m, ts

    _, training_session = _make_model_and_session()
    self.assertEqual(0, training_session.step)
    training_session.run(n_steps=1)
    training_session.save_checkpoint('model')
    _, training_session2 = _make_model_and_session()
    training_session2.run(n_steps=1)
    base.N_WEIGHTS_SHARDS = 1
  def test_serial_with_python(self):
    """Check that data.Serial composes with plain Python map/filter callables."""
    dataset = lambda _: ((i, i + 1) for i in range(10))
    batches = data.Serial(
        dataset,
        lambda g: map(lambda x: (x[0], x[1] + 1), g),
        lambda g: filter(lambda x: x[0] % 2 == 1, g),
        data.Batch(2))
    batch = next(batches())
    self.assertLen(batch, 2)
    (xs, ys) = batch
    # First tuple after filtering is (1, 3) = (1, 2 + 1).
    self.assertEqual(xs[0], 1)
    self.assertEqual(ys[0], 3)
    # Second tuple after filtering is (3, 5).
    self.assertEqual(xs[1], 3)
    self.assertEqual(ys[1], 5)
  def test_count_and_skip(self):
    """Check that CountAndSkip counts examples and skips them on new generators."""
    dataset = lambda _: ((i, i + 1) for i in range(10))
    examples = data.Serial(dataset, data.CountAndSkip('toy_data'))
    ex_generator = examples()
    ex1 = next(ex_generator)
    self.assertEqual(ex1, (0, 1))
    self.assertEqual(data.inputs.data_counters['toy_data'], 1)
    ex2 = next(ex_generator)
    self.assertEqual(ex2, (1, 2))
    self.assertEqual(data.inputs.data_counters['toy_data'], 2)
    ex3 = next(examples())  # new generator, will skip
    self.assertEqual(ex3, (2, 3))
    self.assertEqual(data.inputs.data_counters['toy_data'], 3)
    data.inputs.data_counters['toy_data'] = 0  # reset
    ex4 = next(examples())  # new generator, was reset
    self.assertEqual(ex4, (0, 1))
    self.assertEqual(data.inputs.data_counters['toy_data'], 1)
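  # Sketch (an assumption, not part of the original tests): data.inputs.data_counters
  # behaves like an ordinary module-level dict, so a test that needs isolation could
  # snapshot and restore it around a pipeline run, e.g.:
  #
  #   saved = dict(data.inputs.data_counters)
  #   ...  # run some pipelines that advance the counters
  #   data.inputs.data_counters.clear()
  #   data.inputs.data_counters.update(saved)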
  def test_train_save_restore_dense(self):
    """Saves and restores a checkpoint to check for equivalence."""
    train_data = data.Serial(lambda _: _very_simple_data(),
                             data.CountAndSkip('simple_data'))
    task = training.TrainTask(train_data(), tl.L2Loss(), optimizers.Adam(.0001))
    eval_task = training.EvalTask(
        _very_simple_data(),  # deliberately re-using training data
        [tl.L2Loss()],
        metric_names=['SGD.L2Loss'])
    tmp_dir = self.create_tempdir().full_path

    def _make_model_and_session():
      m = tl.Serial(tl.Dense(1))
      ts = training.Loop(m, [task],
                         eval_tasks=[eval_task],
                         eval_at=lambda step_n: step_n % 2 == 0,
                         output_dir=tmp_dir)
      return m, ts

    model, training_session = _make_model_and_session()
    self.assertEqual(0, training_session.step)
    training_session.run(n_steps=1)
    training_session.save_checkpoint()
    self.assertEqual(data.inputs.data_counters['simple_data'], 2)
    data.inputs.data_counters['simple_data'] = 0  # reset manually
    self.assertEqual(data.inputs.data_counters['simple_data'], 0)  # check

    model2, training_session2 = _make_model_and_session()
    self.assertEqual(data.inputs.data_counters['simple_data'], 2)  # restored

    x = np.ones((8, 1))
    y1 = model(x, rng=fastmath.random.get_prng(0))
    y2 = model2(x, rng=fastmath.random.get_prng(0))
    self.assertEqual(str(y1), str(y2))

    training_session2.run(n_steps=1)
    y1 = model(x, rng=fastmath.random.get_prng(0))
    y2 = model2(x, rng=fastmath.random.get_prng(0))
    self.assertNotEqual(str(y1), str(y2))

    slots1 = training_session._trainer_per_task[0].slots
    slots2 = training_session2._trainer_per_task[0].slots
    np.testing.assert_array_equal(slots1, slots2)
def process_c4_with_span_corruption(spm_path=None,
                                    extra_ids=0,
                                    train=False,
                                    max_length=100,
                                    noise_density=0.15,
                                    mean_noise_span_length=3.0,
                                    seed1=None,
                                    seed2=None):
  """Returns a data pipeline that applies span corruption to C4 text."""
  return data.Serial(
      data.TFDS('c4/en:2.3.0', data_dir=_TESTDATA, keys=('text',), train=train),
      data.SentencePieceTokenize(spm_path=spm_path, extra_ids=extra_ids),
      data.generate_sequential_chunks(max_length=max_length),
      data.generate_random_noise_mask(
          noise_density=noise_density,
          mean_noise_span_length=mean_noise_span_length,
          seed1=seed1,
          seed2=seed2),
      data.consume_noise_mask(vocab_size=32000 + extra_ids),
      data.FilterEmptyExamples(),
      data.AppendValue(val={0: [1], 1: [1]}),
      data.PadToLength(len_map={0: 100, 1: 30}, pad_value={0: 0, 1: 0}),
      data.AddLossWeights(id_to_mask=0),
      data.Batch(batch_size=2))
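# Sketch (not part of the original tests): one way a test might consume the
# pipeline built above. The `spm_path` value (a SentencePiece model file) and
# the exact output shapes are assumptions; they depend on the tokenizer and the
# TFDS test data actually available under _TESTDATA.
#
#   pipeline = process_c4_with_span_corruption(spm_path=spm_path)
#   inputs, targets, weights = next(pipeline())
#   # With Batch(batch_size=2) and PadToLength({0: 100, 1: 30}) above, one
#   # would expect inputs of shape (2, 100) and targets/weights of shape (2, 30).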
  def test_serial(self):
    """Check that a basic data.Serial pipeline shuffles and batches."""
    dataset = lambda _: ((i, i + 1) for i in range(10))
    batches = data.Serial(dataset, data.Shuffle(3), data.Batch(10))
    batch = next(batches())
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (10,))