Example #1
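This and the following examples are excerpts from Trax test code. They assume the usual Trax test-module imports, roughly as below (a sketch; _very_simple_data and _TESTDATA are helpers defined in the original test files and not reproduced here):

import numpy as np

from trax import data
from trax import fastmath
from trax import layers as tl
from trax import optimizers
from trax.layers import base
from trax.supervised import training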
def test_serial_composes(self):
    """Check that data.Serial works inside another data.Serial."""
    dataset = lambda _: ((i, i + 1) for i in range(10))
    serial1 = data.Serial(dataset, data.Shuffle(3))
    batches = data.Serial(serial1, data.Batch(10))
    batch = next(batches())
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (10,))
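data.Serial composes generator transforms left to right, so the nested form above behaves exactly like a flat pipeline over the same stages (see also Example #7); a minimal sketch of the equivalence:

# Flat pipeline equivalent to wrapping serial1 in another Serial.
flat = data.Serial(dataset, data.Shuffle(3), data.Batch(10))
assert next(flat())[0].shape == (10,)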
Example #2
def test_train_save_restore_sharded(self):
    """Saves and restores a sharded checkpoint to check for equivalence."""
    if fastmath.local_device_count() < 2:
        return  # multi-accelerator only
    base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
    train_data = data.Serial(lambda _: _very_simple_data(2, 2),
                             data.CountAndSkip('simple_data'))
    task = training.TrainTask(train_data(), tl.L2Loss(),
                              optimizers.Adam(.0001))
    eval_task = training.EvalTask(
        _very_simple_data(2, 2),  # deliberately re-using training data
        [tl.L2Loss()],
        metric_names=['SGD.L2Loss'])
    tmp_dir = self.create_tempdir().full_path

    def _make_model_and_session():
        m = tl.Serial(tl.Dense(2))
        ts = training.Loop(m, [task],
                           eval_tasks=[eval_task],
                           eval_at=lambda step_n: step_n % 2 == 0,
                           output_dir=tmp_dir)
        return m, ts

    _, training_session = _make_model_and_session()
    self.assertEqual(0, training_session.step)
    training_session.run(n_steps=1)
    training_session.save_checkpoint('model')
    _, training_session2 = _make_model_and_session()
    training_session2.run(n_steps=1)
    base.N_WEIGHTS_SHARDS = 1
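base.N_WEIGHTS_SHARDS tells Trax to split model weights across that many accelerators, which is why the test bails out on single-device hosts and restores the default at the end. A sketch of the same guard pattern outside a test, using only what the snippet above already uses:

# Shard only when several accelerators are available, and always
# restore the global default afterwards.
if fastmath.local_device_count() >= 2:
    base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
    try:
        ...  # build the Loop, train, save and restore as above
    finally:
        base.N_WEIGHTS_SHARDS = 1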
Example #3
def test_serial_with_python(self):
    """Check that plain Python map/filter compose inside data.Serial."""
    dataset = lambda _: ((i, i + 1) for i in range(10))
    batches = data.Serial(dataset,
                          lambda g: map(lambda x: (x[0], x[1] + 1), g),
                          lambda g: filter(lambda x: x[0] % 2 == 1, g),
                          data.Batch(2))
    batch = next(batches())
    self.assertLen(batch, 2)
    (xs, ys) = batch
    # First tuple after filtering is (1, 3) = (1, 2 + 1).
    self.assertEqual(xs[0], 1)
    self.assertEqual(ys[0], 3)
    # Second tuple after filtering is (3, 5).
    self.assertEqual(xs[1], 3)
    self.assertEqual(ys[1], 5)
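Any callable mapping a generator to a generator can sit inside data.Serial, so the lambdas above could just as well be named helpers (hypothetical names, shown only for readability):

def add_one_to_label(generator):
    # Same as the map(...) lambda above.
    for x, y in generator:
        yield (x, y + 1)

def keep_odd_inputs(generator):
    # Same as the filter(...) lambda above.
    return (pair for pair in generator if pair[0] % 2 == 1)

batches = data.Serial(dataset, add_one_to_label, keep_odd_inputs,
                      data.Batch(2))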
Example #4
def test_count_and_skip(self):
    """Check that data.CountAndSkip counts yields and skips on restart."""
    dataset = lambda _: ((i, i + 1) for i in range(10))
    examples = data.Serial(dataset, data.CountAndSkip('toy_data'))
    ex_generator = examples()
    ex1 = next(ex_generator)
    self.assertEqual(ex1, (0, 1))
    self.assertEqual(data.inputs.data_counters['toy_data'], 1)
    ex2 = next(ex_generator)
    self.assertEqual(ex2, (1, 2))
    self.assertEqual(data.inputs.data_counters['toy_data'], 2)
    ex3 = next(examples())  # new generator, will skip
    self.assertEqual(ex3, (2, 3))
    self.assertEqual(data.inputs.data_counters['toy_data'], 3)
    data.inputs.data_counters['toy_data'] = 0  # reset
    ex4 = next(examples())  # new generator, was reset
    self.assertEqual(ex4, (0, 1))
    self.assertEqual(data.inputs.data_counters['toy_data'], 1)
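The counter lives in the module-level dict data.inputs.data_counters under the name given to CountAndSkip; this is what lets checkpointing code capture and later restore the input-pipeline position. A minimal sketch:

# Capture the current position, then write it back later so a fresh
# generator skips exactly that many examples.
position = data.inputs.data_counters['toy_data']
data.inputs.data_counters['toy_data'] = position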
Example #5
def test_train_save_restore_dense(self):
    """Saves and restores a checkpoint to check for equivalence."""
    train_data = data.Serial(lambda _: _very_simple_data(),
                             data.CountAndSkip('simple_data'))
    task = training.TrainTask(train_data(), tl.L2Loss(),
                              optimizers.Adam(.0001))
    eval_task = training.EvalTask(
        _very_simple_data(),  # deliberately re-using training data
        [tl.L2Loss()],
        metric_names=['SGD.L2Loss'])
    tmp_dir = self.create_tempdir().full_path

    def _make_model_and_session():
        m = tl.Serial(tl.Dense(1))
        ts = training.Loop(m, [task],
                           eval_tasks=[eval_task],
                           eval_at=lambda step_n: step_n % 2 == 0,
                           output_dir=tmp_dir)
        return m, ts

    model, training_session = _make_model_and_session()
    self.assertEqual(0, training_session.step)
    training_session.run(n_steps=1)
    training_session.save_checkpoint()
    self.assertEqual(data.inputs.data_counters['simple_data'], 2)
    data.inputs.data_counters['simple_data'] = 0  # reset manually
    self.assertEqual(data.inputs.data_counters['simple_data'], 0)  # check
    model2, training_session2 = _make_model_and_session()
    self.assertEqual(data.inputs.data_counters['simple_data'], 2)  # restored

    x = np.ones((8, 1))
    y1 = model(x, rng=fastmath.random.get_prng(0))
    y2 = model2(x, rng=fastmath.random.get_prng(0))
    self.assertEqual(str(y1), str(y2))

    training_session2.run(n_steps=1)
    y1 = model(x, rng=fastmath.random.get_prng(0))
    y2 = model2(x, rng=fastmath.random.get_prng(0))
    self.assertNotEqual(str(y1), str(y2))

    slots1 = training_session._trainer_per_task[0].slots
    slots2 = training_session2._trainer_per_task[0].slots
    np.testing.assert_array_equal(slots1, slots2)
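Restoring above is implicit: constructing a new Loop with the same output_dir picks up the latest checkpoint, round-tripping the weights, the Adam slots, and the 'simple_data' counter. Recent Trax versions also expose an explicit reload on Loop; a hedged sketch, assuming that method:

# Explicit restore from a directory, rather than relying on the
# implicit load in the Loop constructor.
_, fresh_session = _make_model_and_session()
fresh_session.load_checkpoint(directory=tmp_dir)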
Example #6
def process_c4_with_span_corruption(spm_path=None,
                                    extra_ids=0,
                                    train=False,
                                    max_length=100,
                                    noise_density=0.15,
                                    mean_noise_span_length=3.0,
                                    seed1=None,
                                    seed2=None):
    """Builds a span-corruption pretraining pipeline over the C4 dataset."""
    return data.Serial(
        data.TFDS('c4/en:2.3.0',
                  data_dir=_TESTDATA,
                  keys=('text',),
                  train=train),
        data.SentencePieceTokenize(spm_path=spm_path, extra_ids=extra_ids),
        data.generate_sequential_chunks(max_length=max_length),
        data.generate_random_noise_mask(
            noise_density=noise_density,
            mean_noise_span_length=mean_noise_span_length,
            seed1=seed1,
            seed2=seed2),
        data.consume_noise_mask(vocab_size=32000 + extra_ids),
        data.FilterEmptyExamples(),
        data.AppendValue(val={0: [1], 1: [1]}),  # append EOS (token id 1)
        data.PadToLength(len_map={0: 100, 1: 30}, pad_value={0: 0, 1: 0}),
        data.AddLossWeights(id_to_mask=0),
        data.Batch(batch_size=2))
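A hedged usage sketch (the SentencePiece model path is a placeholder): the returned pipeline is a callable whose generator yields (inputs, targets, loss_weights) batches, padded to lengths 100 and 30 and batched in pairs.

pipeline = process_c4_with_span_corruption(
    spm_path='sentencepiece.model',  # placeholder path
    seed1=0, seed2=1)
inputs, targets, weights = next(pipeline())
assert inputs.shape == (2, 100) and targets.shape == (2, 30)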
Example #7
def test_serial(self):
    """Check a basic data.Serial pipeline: shuffle, then batch."""
    dataset = lambda _: ((i, i + 1) for i in range(10))
    batches = data.Serial(dataset, data.Shuffle(3), data.Batch(10))
    batch = next(batches())
    self.assertLen(batch, 2)
    self.assertEqual(batch[0].shape, (10,))
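Shuffle(3) permutes only within a 3-element buffer and Batch(10) stacks ten examples into arrays, so this single batch holds every element of the toy dataset in locally shuffled order:

xs, ys = batch
assert sorted(xs) == list(range(10))     # same inputs, shuffled order
assert sorted(ys) == list(range(1, 11))  # labels track their inputs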