Example No. 1
    def testFromDatasetId(self):
        """Tests cross-trainer cache with `register_dataset`/`from_dataset_id`."""
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        dataset_id1 = data_service_ops.register_dataset(
            cluster.dispatcher.target, dataset, dataset_id="dataset_id")
        dataset1 = data_service_ops.from_dataset_id(
            processing_mode=data_service_ops.ShardingPolicy.OFF,
            service=cluster.dispatcher.target,
            dataset_id=dataset_id1,
            element_spec=dataset.element_spec,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 1"))
        self.assertDatasetProduces(dataset1.take(10), list(range(10)))

        dataset_id2 = data_service_ops.register_dataset(
            cluster.dispatcher.target, dataset, dataset_id="dataset_id")
        dataset2 = data_service_ops.from_dataset_id(
            processing_mode=data_service_ops.ShardingPolicy.OFF,
            service=cluster.dispatcher.target,
            dataset_id=dataset_id2,
            element_spec=dataset.element_spec,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 2"))
        self.assertDatasetProduces(dataset2.take(10), list(range(10)))
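For reference, the register/consume pattern above maps onto the public `tf.data.experimental.service` API (names as in recent TensorFlow releases). The following is a minimal sketch, not part of the test: the dispatcher address `grpc://localhost:5000` is hypothetical and plays the role of `cluster.dispatcher.target`.

    import tensorflow as tf

    SERVICE = "grpc://localhost:5000"  # hypothetical dispatcher address

    dataset = tf.data.Dataset.range(10_000_000).repeat()

    # Register the dataset once so several trainers can consume it by ID.
    dataset_id = tf.data.experimental.service.register_dataset(SERVICE, dataset)

    # Each trainer reads under the same job_name but with its own trainer_id,
    # so the dispatcher can serve repeated reads from the cross-trainer cache
    # instead of recomputing elements.
    trainer_dataset = tf.data.experimental.service.from_dataset_id(
        processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
        service=SERVICE,
        dataset_id=dataset_id,
        element_spec=dataset.element_spec,
        job_name="job",
        cross_trainer_cache=tf.data.experimental.service.CrossTrainerCache(
            trainer_id="Trainer 1"))
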
Example No. 2
    def testMultipleIterationsForOneDatasetGraphMode(self):
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        dataset1 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 1"))
        # These clients are assumed to be from the same training cluster. Thus, they
        # do not reuse data from the cross-trainer cache.
        output1 = self.getDatasetOutput(dataset1.take(10))
        output1 += self.getDatasetOutput(dataset1.take(10))
        output1 += self.getDatasetOutput(dataset1.take(10))
        self.assertLen(set(output1), 30)

        dataset2 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 2"))
        # These clients reuse some data from the previous clients (not exactly the
        # same data due to client-side buffering).
        output2 = self.getDatasetOutput(dataset2.take(10))
        output2 += self.getDatasetOutput(dataset2.take(10))
        output2 += self.getDatasetOutput(dataset2.take(10))
        self.assertTrue(set(output1) & set(output2))
Example No. 3
    def testSlowClientSkipsData(self):
        cluster = self._create_cluster(num_workers=1,
                                       cross_trainer_cache_size_bytes=500)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        dataset1 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 1"))
        self.assertDatasetProduces(dataset1.take(200), list(range(200)))

        dataset2 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 2"))
        dataset2 = dataset2.take(200)
        output = self.getDatasetOutput(dataset2)
        # When the cache is small, the second trainer can't read the beginning
        # of the dataset. It can still read 200 elements, because the dataset
        # is infinite.
        self.assertGreater(output[0], 0)
        self.assertEqual(self.evaluate(dataset2.cardinality()), 200)
Example No. 4
    def testConcurrentReaders(self):
        # Fetching an element from the dataset triggers prefetches of more
        # elements, one per CPU core, which are placed in the cache. If the
        # number of prefetches exceeds the space available in the cache, the
        # sliding window moves forward past the element just read, defeating
        # the cache because the other trainers no longer see the expected
        # element. Hence the cache size is computed from the number of CPU
        # cores and the element size of 363 bytes; the extra 8 entries are a
        # bit of margin.
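        # For example, on a hypothetical 8-core machine this would request a
        # cache of (8 + 8) * 363 = 5808 bytes.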
        num_cpus = multiprocessing.cpu_count()
        cluster = self._create_cluster(
            num_workers=1, cross_trainer_cache_size_bytes=(num_cpus + 8) * 363)
        num_readers = 20
        num_elements = 50
        dataset = dataset_ops.Dataset.range(10000000).repeat()

        datasets = []
        iterators = []
        for i in range(num_readers):
            distributed_dataset = self.make_distributed_dataset(
                dataset,
                cluster,
                job_name="job",
                cross_trainer_cache=data_service_ops.CrossTrainerCache(
                    trainer_id=f"Trainer {i}"),
                max_outstanding_requests=1)
            iterator = self.getNext(distributed_dataset)
            datasets.append(distributed_dataset)
            iterators.append(iterator)

        for i in range(num_elements):
            # All the readers read the same element in one step.
            for j in range(num_readers):
                self.assertEqual(self.evaluate(iterators[j]()), i)
Example No. 5
    def testConcurrentReaders(self):
        cluster = self._create_cluster(num_workers=1,
                                       cross_trainer_cache_size_bytes=18000)
        num_readers = 20
        num_elements = 50
        dataset = dataset_ops.Dataset.range(10000000).repeat()

        datasets = []
        iterators = []
        for i in range(num_readers):
            distributed_dataset = self.make_distributed_dataset(
                dataset,
                cluster,
                job_name="job",
                cross_trainer_cache=data_service_ops.CrossTrainerCache(
                    trainer_id=f"Trainer {i}"),
                max_outstanding_requests=1)
            iterator = self.getNext(distributed_dataset)
            datasets.append(distributed_dataset)
            iterators.append(iterator)

        for i in range(num_elements):
            # All the readers read the same element in one step.
            for j in range(num_readers):
                self.assertEqual(self.evaluate(iterators[j]()), i)
Example No. 6
    def testDifferentJobNames(self):
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        dataset1 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job1",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 1"))
        self.assertDatasetProduces(dataset1.take(10), list(range(10)))

        dataset2 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job2",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 2"))
        self.assertDatasetProduces(dataset2.take(10), list(range(10)))
Example No. 7
    def testDifferentJobNames(self):
        # TODO(b/221104308): Disallow this use case because it increases RAM usage.
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        dataset1 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job1",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 1"))
        self.assertDatasetProduces(dataset1.take(10), list(range(10)))

        dataset2 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job2",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 2"))
        self.assertDatasetProduces(dataset2.take(10), list(range(10)))
Example No. 8
    def testEnableCrossTrainerCache(self):
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        dataset1 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 1"))
        self.assertDatasetProduces(dataset1.take(10), list(range(10)))

        # The second client reads the same data from the cross-trainer cache.
        dataset2 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 2"))
        self.assertDatasetProduces(dataset2.take(10), list(range(10)))
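The same sharing setup can also be written with the dataset-side `distribute` transformation. A minimal sketch, again assuming a dispatcher at the hypothetical address `grpc://localhost:5000`, with two training jobs that share `job_name` and differ only in `trainer_id`:

    import tensorflow as tf

    SERVICE = "grpc://localhost:5000"  # hypothetical dispatcher address

    def make_trainer_dataset(trainer_id):
      # All trainers share job_name="job"; only trainer_id differs, which lets
      # later trainers be answered from the cross-trainer cache.
      return tf.data.Dataset.range(10_000_000).repeat().apply(
          tf.data.experimental.service.distribute(
              processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
              service=SERVICE,
              job_name="job",
              cross_trainer_cache=tf.data.experimental.service.CrossTrainerCache(
                  trainer_id=trainer_id)))

    # One input pipeline per training job:
    dataset_trainer_1 = make_trainer_dataset("Trainer 1")
    dataset_trainer_2 = make_trainer_dataset("Trainer 2")
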
Example No. 9
    def testSameTrainerID(self):
        # Jobs from the same training cluster do not reuse data from the cache.
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        dataset1 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer ID"))
        self.assertDatasetProduces(dataset1.take(10), list(range(10)))

        dataset2 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer ID"))
        output = self.getDatasetOutput(dataset2.take(10))
        self.assertGreaterEqual(output[0], 10)
Example No. 10
    def testShuffleDataset(self):
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat().shuffle(
            buffer_size=100)
        dataset1 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 1"))
        output1 = self.getDatasetOutput(dataset1.take(10))

        dataset2 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 2"))
        output2 = self.getDatasetOutput(dataset2.take(10))
        self.assertEqual(output1, output2)
Example No. 11
    def testCompressionMismatch(self):
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        dataset1 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 1"))
        self.assertDatasetProduces(dataset1.take(10), list(range(10)))

        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "Data type mismatch"):
            dataset2 = self.make_distributed_dataset(
                dataset,
                cluster,
                job_name="job",
                compression=None,
                cross_trainer_cache=data_service_ops.CrossTrainerCache(
                    trainer_id="Trainer 1"))
            self.getDatasetOutput(dataset2)
Example No. 12
    def testDisallowFiniteDataset(self):
        cluster = self._create_cluster(num_workers=1)
        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "Cross-trainer caching requires the input dataset to be infinite."
        ):
            dataset = self.make_distributed_range_dataset(
                10,
                cluster,
                job_name="job",
                cross_trainer_cache=data_service_ops.CrossTrainerCache(
                    trainer_id="Trainer 1"))
            self.getDatasetOutput(dataset)
Example No. 13
    def testRequiresJobName(self):
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "Cross-trainer caching requires named jobs. Got empty `job_name`."
        ):
            dataset = self.make_distributed_dataset(
                dataset,
                cluster,
                job_name=None,
                cross_trainer_cache=data_service_ops.CrossTrainerCache(
                    trainer_id="Trainer 1"))
            self.getDatasetOutput(dataset)
Example No. 14
    def testDynamicSharding(self):
        cluster = self._create_cluster(num_workers=2)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        dataset1 = self.make_distributed_dataset(
            dataset,
            cluster,
            processing_mode=data_service_ops.ShardingPolicy.DYNAMIC,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 1"))
        output1 = self.getDatasetOutput(dataset1.take(100))

        # The second client reads the same data from the cross-trainer cache.
        dataset2 = self.make_distributed_dataset(
            dataset,
            cluster,
            processing_mode=data_service_ops.ShardingPolicy.DYNAMIC,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 2"))
        output2 = self.getDatasetOutput(dataset2.take(100))
        # Verifies the intersection is non-empty.
        self.assertTrue(set(output1) & set(output2))
Example No. 15
    def testRequiresNonEmptyTrainerID(self):
        cluster = self._create_cluster(num_workers=2)
        dataset = dataset_ops.Dataset.range(10000000).repeat()

        with self.assertRaisesRegex(
                ValueError,
                "tf.data service cross-trainer cache requires a non-empty trainer ID."
        ):
            self.make_distributed_dataset(
                dataset,
                cluster,
                job_name="job",
                cross_trainer_cache=data_service_ops.CrossTrainerCache(
                    trainer_id=None))
Example No. 16
    def testDisallowCoordinatedRead(self):
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "Cross-trainer caching does not support coordinated reads."):
            dataset = self.make_distributed_dataset(
                dataset,
                cluster,
                job_name="job",
                num_consumers=1,
                consumer_index=0,
                cross_trainer_cache=data_service_ops.CrossTrainerCache(
                    trainer_id="Trainer 1"))
            self.getDatasetOutput(dataset)
Example No. 17
    def testRequiresInfiniteDataset(self, range_):
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(range_).map(lambda x: x + 1)
        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "Cross-trainer caching requires the input dataset to be infinite."
        ):
            dataset = dataset.apply(
                data_service_ops.distribute(
                    processing_mode=data_service_ops.ShardingPolicy.OFF,
                    service=cluster.dispatcher.target,
                    job_name="job_name",
                    cross_trainer_cache=data_service_ops.CrossTrainerCache(
                        trainer_id="Trainer ID")))
            self.getDatasetOutput(dataset)
Example No. 18
    def testMultipleIterationsForOneDatasetEagerMode(self):
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        dataset1 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 1"))
        # In eager mode, each iteration creates a new data service job and
        # does not reuse cached data. We disallow this use case.
        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "Cross-trainer caching requires infinite datasets and disallows "
                "multiple iterations of the same dataset."):
            self.getDatasetOutput(dataset1.take(10))
            self.getDatasetOutput(dataset1.take(10))
            self.getDatasetOutput(dataset1.take(10))
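Because re-iterating a cached dataset is rejected in eager mode, a consumer typically keeps a single iterator alive for the whole training run. A hedged sketch, not taken from the test; `distributed_dataset` and `num_steps` are placeholders:

    # Keep one iterator for the lifetime of training instead of re-iterating
    # the dataset, which would count as a new iteration and be rejected.
    iterator = iter(distributed_dataset)
    for _ in range(num_steps):
      element = next(iterator)
      # ... run one training step on `element` ...
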
Example No. 19
    def testSmallCache(self):
        cluster = self._create_cluster(num_workers=1,
                                       cross_trainer_cache_size_bytes=500)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        num_readers = 20

        for i in range(num_readers):
            # Even if the cache is small and may discard old data, each trainer can
            # still read the required number of elements because the input dataset is
            # infinite.
            distributed_dataset = self.make_distributed_dataset(
                dataset,
                cluster,
                job_name="job",
                cross_trainer_cache=data_service_ops.CrossTrainerCache(
                    trainer_id=f"Trainer {i}"))
            output = self.getDatasetOutput(distributed_dataset.take(200))
            self.assertLen(output, 200)
Example No. 20
    def testNamedJobMismatch(self):
        cluster = self._create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(10000000).repeat()
        dataset1 = self.make_distributed_dataset(
            dataset,
            cluster,
            job_name="job",
            cross_trainer_cache=data_service_ops.CrossTrainerCache(
                trainer_id="Trainer 1"))
        self.assertDatasetProduces(dataset1.take(10), list(range(10)))

        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "Existing cross-trainer cache: <enabled>; got <disabled>"):
            dataset2 = self.make_distributed_dataset(dataset,
                                                     cluster,
                                                     job_name="job",
                                                     cross_trainer_cache=None)
            self.getDatasetOutput(dataset2)
Example No. 21
  def testDispatcherRestart(self):
    cluster = self._create_cluster(num_workers=1)
    dataset = dataset_ops.Dataset.range(10000000).repeat()
    distributed_dataset = self.make_distributed_dataset(
        dataset,
        cluster,
        job_name="job",
        cross_trainer_cache=data_service_ops.CrossTrainerCache(
            trainer_id="Trainer 1"))

    get_next = self.getNext(distributed_dataset)
    elements = self._get_next(get_next, 100)
    self.assertEqual(elements, list(range(100)))

    cluster.restart_dispatcher()

    # Dispatcher restart should not affect the workers.
    elements = self._get_next(get_next, 100)
    self.assertEqual(elements, list(range(100, 200)))
Example No. 22
  def testWorkerRestart(self):
    cluster = self._create_cluster(num_workers=1)
    dataset = dataset_ops.Dataset.range(10000000).repeat()
    distributed_dataset = self.make_distributed_dataset(
        dataset,
        cluster,
        job_name="job",
        cross_trainer_cache=data_service_ops.CrossTrainerCache(
            trainer_id="Trainer 1"))

    get_next = self.getNext(distributed_dataset)
    elements = self._get_next(get_next, 100)
    self.assertEqual(elements, list(range(100)))

    cluster.workers[0].restart()

    # Read until we get results from the restarted worker, then read some more.
    while self.evaluate(get_next()) != 0:
      pass

    elements = self._get_next(get_next, 100)
    self.assertEqual(elements, list(range(1, 101)))