Example #1
0
 def testUnsupportedTransformError(self):
   """Rebatching without the fallback fails on a non-rewritable transform."""
   ds = dataset_ops.Dataset.range(1024).batch(32).apply(sleep.sleep(10))
   with self.assertRaises(errors.InvalidArgumentError):
     # The error only surfaces when the rewritten pipeline is evaluated,
     # so everything below stays inside the assertRaises scope.
     rebatched = distribute._RebatchDataset(
         ds, num_replicas=4, use_fallback=False)
     get_next = self.getNext(rebatched)
     self.evaluate(get_next())
 def testUnsupportedTransformError(self, drop_remainder):
   """Rebatching fails when the pipeline ends in a non-rewritable transform."""
   ds = dataset_ops.Dataset.range(1024).batch(
       32, drop_remainder=drop_remainder).apply(sleep.sleep(10))
   with self.assertRaises(errors.InvalidArgumentError):
     # Evaluation is required to actually trigger the rewrite error.
     rebatched = distribute._RebatchDataset(ds, num_workers=4)
     get_next = self.getNext(rebatched)
     self.evaluate(get_next())
    def testWithUnknownBatchDim(self):
        """Rebatching splits batches of 32 into 8s despite drop_remainder=False."""
        ds = dataset_ops.Dataset.range(1024).batch(
            32, drop_remainder=False).apply(sleep.sleep(10))
        rebatched = distribute._RebatchDataset(ds, num_replicas=4)

        # 1024 elements, rebatched into consecutive minibatches of 8.
        expected = [list(range(start, start + 8))
                    for start in range(0, 1024, 8)]
        self.assertDatasetProduces(rebatched, expected)
Example #4
0
 def testUnsupportedTransformInFlatMapError(self, drop_remainder):
   """Non-fallback rebatching fails even when the transform is in a flat_map."""

   def build_inner(_):
     # Inner pipeline ends in sleep, which rebatching cannot rewrite.
     return dataset_ops.Dataset.range(32).batch(
         32, drop_remainder=drop_remainder).apply(sleep.sleep(10))

   ds = dataset_ops.Dataset.range(2).flat_map(build_inner)
   with self.assertRaises(errors.InvalidArgumentError):
     rebatched = distribute._RebatchDataset(
         ds, num_workers=4, use_fallback=False)
     get_next = self.getNext(rebatched)
     self.evaluate(get_next())
Example #5
0
  def testWithUnknownBatchDim(self):
    """The rebatching fallback rejects a dataset with an unknown batch dim."""
    ds = dataset_ops.Dataset.range(1024).batch(
        32, drop_remainder=False).apply(sleep.sleep(10))

    expected_message = "Cannot use rebatching fallback"
    with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_message):
      # The failure is raised at evaluation time.
      rebatched = distribute._RebatchDataset(ds, num_workers=4)
      get_next = self.getNext(rebatched)
      self.evaluate(get_next())
Example #6
0
  def testWithUnhandledTransformation(self):
    """Rebatching splits static batches of 32 into minibatches of 8.

    The trailing `sleep` transform is not rewritten, but with
    `drop_remainder=True` the batch dimension is statically known, so the
    rebatch succeeds and the flat shape shrinks from [32] to [8].
    """
    dataset = dataset_ops.Dataset.range(1024).batch(
        32, drop_remainder=True).apply(sleep.sleep(10))
    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
    self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
    self.assertEqual([[8]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

    # list(range(...)) replaces the identity comprehension [k for k in ...].
    expected_output = [list(range(i, i + 8)) for i in range(0, 1024, 8)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #7
0
  def testBatchSizeIndivisibleByNumWorkers(self):
    """Fallback rebatching fails when the batch size isn't divisible by workers."""
    # The reshape-based fallback requires the batch dimension to be exactly
    # divisible by the worker count; 32 cannot be split across 5 workers.
    ds = dataset_ops.Dataset.range(64).batch(
        32, drop_remainder=True).apply(sleep.sleep(10))

    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "Cannot use rebatching fallback"):
      rebatched = distribute._RebatchDataset(ds, num_workers=5)
      get_next = self.getNext(rebatched)
      self.evaluate(get_next())
Example #8
0
 def testSleep(self):
   """sleep() delays each element by at least the requested microseconds."""
   delay_us = 100
   num_elements = 10
   dataset = dataset_ops.Dataset.range(num_elements).apply(
       sleep.sleep(delay_us))
   get_next = self.getNext(dataset)

   start = time.time()
   for expected in range(num_elements):
     self.assertEqual(expected, self.evaluate(get_next()))
   elapsed = time.time() - start

   # Total wall time must cover the per-element delay (converted to seconds).
   self.assertGreater(elapsed, (num_elements * delay_us) / 1e6)
   # The dataset is exhausted after all elements are produced.
   with self.assertRaises(errors.OutOfRangeError):
     self.evaluate(get_next())
Example #9
0
 def testSleep(self):
     """Ten elements with a 100us sleep each take at least 1ms in total."""
     sleep_us = 100
     dataset = dataset_ops.Dataset.range(10).apply(sleep.sleep(sleep_us))
     next_element = self.getNext(dataset)

     start_time = time.time()
     for i in range(10):
         self.assertEqual(i, self.evaluate(next_element()))
     duration = time.time() - start_time

     # Per-element delay is in microseconds; compare against seconds.
     minimum_seconds = (10 * sleep_us) / 1e6
     self.assertGreater(duration, minimum_seconds)

     with self.assertRaises(errors.OutOfRangeError):
         self.evaluate(next_element())
Example #10
0
    def testSleepCancellation(self):
        """Closing the session cancels an in-flight, very long sleep.

        The sleep is ~1000 seconds, so the background thread can only
        finish if closing the session cancels the pending GetNext op.
        """
        sleep_microseconds = int(1e6) * 1000
        ds = dataset_ops.Dataset.range(1)
        ds = ds.apply(sleep.sleep(sleep_microseconds))
        ds = ds.prefetch(1)
        get_next = self.getNext(ds, requires_initialization=True)

        with self.cached_session() as sess:
            # The thread blocks inside the sleeping GetNext and asserts
            # that the op is cancelled rather than completing normally.
            thread = self.checkedThread(self.assert_op_cancelled,
                                        args=(get_next(), ))
            thread.start()
            # Give the thread time to enter the blocking call before
            # closing the session.
            time.sleep(0.2)
            sess.close()
            thread.join()
Example #11
0
    def testSleepBackgroundCancellation(self):
        """The element before a ~1000s sleep is still produced via prefetch.

        The pipeline is `range(1)` concatenated with a dataset that sleeps
        for a very long time per element. Prefetching lets the first
        element be returned while the prefetch thread blocks in the sleep.
        """
        ds = dataset_ops.Dataset.range(1)

        # Build the sleeping branch directly from `ds`; the original code
        # created a separate Dataset.range(1) for `ds_sleep` and then
        # immediately overwrote it (dead assignment).
        sleep_microseconds = int(1e6) * 1000
        ds_sleep = ds.apply(sleep.sleep(sleep_microseconds))

        ds = ds.concatenate(ds_sleep)
        ds = ds.prefetch(1)

        get_next = self.getNext(ds, requires_initialization=True)

        with self.cached_session():
            self.assertEqual(self.evaluate(get_next()), 0)
Example #12
0
  def testWithUnhandledTransformationInFlatMap(self):
    """Rebatching rewrites batches that are produced inside a flat_map."""
    dataset = dataset_ops.Dataset.range(2).flat_map(
        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
            32, drop_remainder=True).apply(sleep.sleep(10)))
    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)

    self.assertEqual([[8]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

    # Each of the two flat_map elements is a batch of 32, which rebatching
    # splits into 4 minibatches of 8 elements each — eight lists in total.
    expected_output = [
        [k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
        for _ in range(2) for i in range(0, 32, 8)]  # 4 minibatches per outer element
    self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #13
0
    def __init__(self, metric, endpoint, interval=None, internal=True):
        """PrometheusScrapeStreamIODataset.

        Args:
          metric: Metric identifier passed to `golang_ops.io_prometheus_scrape`.
          endpoint: Endpoint address passed to the scrape op.
          interval: Delay between scrapes given to `sleep`; defaults to
            1000000 when None (presumably microseconds, i.e. one second —
            TODO confirm against the sleep op's units).
          internal: Must be truthy; asserted to discourage direct
            construction of this class.
        """
        with tf.name_scope("PrometheusScrapeStreamIODataset"):
            assert internal

            # Default scrape interval when the caller does not provide one.
            interval = 1000000 if interval is None else interval

            # NOTE(review): the underlying range is fixed at 10 steps, so
            # this scrapes at most 10 times as written.
            dataset = tf.data.Dataset.range(0, 10, 1)
            dataset = dataset.map(
                lambda i: golang_ops.io_prometheus_scrape(metric, endpoint, i))
            # Throttle successive scrapes by sleeping between elements.
            dataset = dataset.apply(sleep(interval))

            self._dataset = dataset
            super(PrometheusScrapeStreamIODataset,
                  self).__init__(self._dataset._variant_tensor)  # pylint: disable=protected-access
Example #14
0
  def testSleep(self):
    """sleep() delays each element; total iteration time reflects the delay.

    Uses an initializable iterator and `self.evaluate` to drive it, then
    checks that 10 elements with a 100us per-element sleep take at least
    1ms of wall time in total.
    """
    sleep_microseconds = 100
    dataset = dataset_ops.Dataset.range(10).apply(
        sleep.sleep(sleep_microseconds))
    iterator = dataset_ops.make_initializable_iterator(dataset)
    next_element = iterator.get_next()

    # `sess` was bound but never used (self.evaluate drives the cached
    # session), so the context manager no longer binds it.
    with self.cached_session():
      self.evaluate(iterator.initializer)
      start_time = time.time()
      for i in range(10):
        self.assertEqual(i, self.evaluate(next_element))
      end_time = time.time()
      self.assertGreater(end_time - start_time, (10 * sleep_microseconds) / 1e6)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
Example #15
0
  def testSleep(self):
    """sleep() delays each element when driven via an explicit Session.

    TF1-style variant: uses `dataset.make_initializable_iterator` and
    `sess.run` directly instead of the test-base evaluate helpers.
    """
    sleep_microseconds = 100
    dataset = dataset_ops.Dataset.range(10).apply(
        sleep.sleep(sleep_microseconds))
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(iterator.initializer)
      start_time = time.time()
      for i in range(10):
        self.assertEqual(i, sess.run(next_element))
      end_time = time.time()
      # 10 elements * 100us each => at least 1ms of wall time.
      self.assertGreater(end_time - start_time, (10 * sleep_microseconds) / 1e6)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
    def testBatchSizeNotDivisibleByNumReplicas(self):
        """Rebatching 32 across 5 replicas yields uneven minibatches.

        32 is not divisible by 5, so each original batch of 32 is split into
        four minibatches of 7 followed by one minibatch of 4 (4*7 + 4 == 32).
        """
        dataset = dataset_ops.Dataset.range(64).batch(
            32, drop_remainder=True).apply(sleep.sleep(10))

        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)

        expected_output = []
        i = 0
        for _ in range(2):  # number of steps (64 elements / batches of 32)
            # First four minibatches have seven elements each.
            for _ in range(4):
                # list(range(...)) replaces the identity comprehension.
                expected_output.append(list(range(i, i + 7)))
                i += 7
            # The last minibatch holds the remaining four elements.
            expected_output.append(list(range(i, i + 4)))
            i += 4
        self.assertDatasetProduces(rebatched_dataset, expected_output)
 def slow_branch(dataset):
   """Returns `dataset` with a 10000us (10ms) per-element sleep added."""
   return dataset.apply(sleep.sleep(10000))
 def make_dataset(time_us, num_elements):
   """Returns a range dataset of `num_elements` sleeping `time_us` per element."""
   return dataset_ops.Dataset.range(num_elements).apply(sleep.sleep(time_us))
 def setup_slow_dataset(self):
   """Builds the fast dataset and slows it down with a per-element sleep."""
   dataset = self.setup_fast_dataset()
   # Fewer iterations than the fast benchmark to keep runtime bounded.
   self.iters = 1000
   # sleep for 1e-3s per iteration
   return dataset.apply(sleep.sleep(1000))
 def make_dataset(time_us, num_elements):
     """Range dataset of `num_elements`; sleeps `time_us` per element if > 0."""
     base = dataset_ops.Dataset.range(num_elements)
     # Guard clause: a non-positive delay means no sleep transform at all.
     if time_us <= 0:
         return base
     return base.apply(sleep.sleep(time_us))
 def fast_branch(dataset):
   """Returns `dataset` with a short 10us per-element sleep added."""
   return dataset.apply(sleep.sleep(10))
 def fast_branch(dataset):
   """Adds a minimal 10us per-element sleep to `dataset`."""
   return dataset.apply(sleep.sleep(10))
 def make_dataset(time_us, num_elements):
   """Builds a range dataset of `num_elements` with a `time_us` sleep each."""
   return dataset_ops.Dataset.range(num_elements).apply(sleep.sleep(time_us))
 def slow_branch(dataset):
   """Adds a 10000us (10ms) per-element sleep to `dataset`."""
   return dataset.apply(sleep.sleep(10000))