def testUnsupportedTransformError(self):
  """Rebatching with the fallback disabled rejects a dataset using sleep."""
  ds = dataset_ops.Dataset.range(1024).batch(32).apply(sleep.sleep(10))
  with self.assertRaises(errors.InvalidArgumentError):
    rebatched = distribute._RebatchDataset(
        ds, num_replicas=4, use_fallback=False)
    element_fn = self.getNext(rebatched)
    self.evaluate(element_fn())
def testUnsupportedTransformError(self, drop_remainder):
  """Rebatching a dataset containing sleep raises InvalidArgumentError."""
  ds = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=drop_remainder).apply(sleep.sleep(10))
  with self.assertRaises(errors.InvalidArgumentError):
    rebatched = distribute._RebatchDataset(ds, num_workers=4)
    element_fn = self.getNext(rebatched)
    self.evaluate(element_fn())
def testWithUnknownBatchDim(self):
  """Rebatching handles a statically unknown batch dimension."""
  ds = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=False).apply(sleep.sleep(10))
  rebatched = distribute._RebatchDataset(ds, num_replicas=4)
  # 1024 elements, originally batched by 32, come out as minibatches of 8.
  expected = [list(range(start, start + 8)) for start in range(0, 1024, 8)]
  self.assertDatasetProduces(rebatched, expected)
def testUnsupportedTransformInFlatMapError(self, drop_remainder):
  """Sleep inside flat_map defeats rebatching when the fallback is off."""
  ds = dataset_ops.Dataset.range(2).flat_map(
      lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
          32, drop_remainder=drop_remainder).apply(sleep.sleep(10)))
  with self.assertRaises(errors.InvalidArgumentError):
    rebatched = distribute._RebatchDataset(
        ds, num_workers=4, use_fallback=False)
    element_fn = self.getNext(rebatched)
    self.evaluate(element_fn())
def testWithUnknownBatchDim(self):
  """The rebatching fallback rejects a statically unknown batch dimension."""
  ds = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=False).apply(sleep.sleep(10))
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               "Cannot use rebatching fallback"):
    rebatched = distribute._RebatchDataset(ds, num_workers=4)
    element_fn = self.getNext(rebatched)
    self.evaluate(element_fn())
def testWithUnhandledTransformation(self):
  """Rebatching still splits batches when sleep blocks the graph rewrite."""
  ds = dataset_ops.Dataset.range(1024).batch(
      32, drop_remainder=True).apply(sleep.sleep(10))
  rebatched = distribute._RebatchDataset(ds, num_workers=4)
  # Static shapes: original batches of 32, rebatched minibatches of 8.
  self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(ds)])
  self.assertEqual([[8]],
                   [ts.as_list() for ts in _flat_shapes(rebatched)])
  expected = [list(range(start, start + 8)) for start in range(0, 1024, 8)]
  self.assertDatasetProduces(rebatched, expected)
def testBatchSizeIndivisibleByNumWorkers(self):
  """The reshape-based fallback needs batch size divisible by num_workers."""
  # This doesn't work; reshape requires the tensor shape to be exactly
  # divisible by the second dim, and 32 is not divisible by 5.
  ds = dataset_ops.Dataset.range(64).batch(
      32, drop_remainder=True).apply(sleep.sleep(10))
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               "Cannot use rebatching fallback"):
    rebatched = distribute._RebatchDataset(ds, num_workers=5)
    element_fn = self.getNext(rebatched)
    self.evaluate(element_fn())
def testSleep(self):
  """sleep() delays each element by at least the requested microseconds."""
  sleep_microseconds = 100
  ds = dataset_ops.Dataset.range(10).apply(sleep.sleep(sleep_microseconds))
  element_fn = self.getNext(ds)
  start = time.time()
  for expected in range(10):
    self.assertEqual(expected, self.evaluate(element_fn()))
  elapsed = time.time() - start
  # Ten elements => at least 10 * sleep_microseconds of accumulated delay.
  self.assertGreater(elapsed, (10 * sleep_microseconds) / 1e6)
  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(element_fn())
def testSleepCancellation(self):
  """Closing the session cancels a get-next blocked on a very long sleep."""
  sleep_microseconds = int(1e6) * 1000  # 1000 seconds: effectively forever.
  ds = dataset_ops.Dataset.range(1)
  ds = ds.apply(sleep.sleep(sleep_microseconds))
  ds = ds.prefetch(1)
  get_next = self.getNext(ds, requires_initialization=True)
  with self.cached_session() as sess:
    thread = self.checkedThread(
        self.assert_op_cancelled, args=(get_next(),))
    thread.start()
    # Give the background get-next op time to start blocking in the sleep.
    time.sleep(0.2)
    sess.close()
    thread.join()
def testSleepBackgroundCancellation(self):
  """A concatenated long-sleep tail doesn't block consuming earlier elements.

  Builds range(1) followed by a sleeping range(1); with prefetch(1) the
  first element must still be producible even though the background
  prefetch thread is blocked inside the sleep.
  """
  sleep_microseconds = int(1e6) * 1000  # 1000 seconds: effectively forever.
  ds = dataset_ops.Dataset.range(1)
  # Fix: removed a dead `ds_sleep = dataset_ops.Dataset.range(1)` assignment
  # that was immediately overwritten by the line below and never used.
  ds_sleep = ds.apply(sleep.sleep(sleep_microseconds))
  ds = ds.concatenate(ds_sleep)
  ds = ds.prefetch(1)
  get_next = self.getNext(ds, requires_initialization=True)
  with self.cached_session():
    self.assertEqual(self.evaluate(get_next()), 0)
def testWithUnhandledTransformationInFlatMap(self):
  """Fallback rebatching applies inside flat_map when sleep blocks rewrite."""
  ds = dataset_ops.Dataset.range(2).flat_map(
      lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
          32, drop_remainder=True).apply(sleep.sleep(10)))
  rebatched = distribute._RebatchDataset(ds, num_workers=4)
  self.assertEqual([[8]],
                   [ts.as_list() for ts in _flat_shapes(rebatched)])
  # Each of the two flat_map elements is split into four minibatches of 8.
  expected = [
      list(range(start, start + 8))
      for _ in range(2)
      for start in range(0, 32, 8)
  ]
  self.assertDatasetProduces(rebatched, expected)
def __init__(self, metric, endpoint, interval=None, internal=True):
  """PrometheusScrapeStreamIODataset.

  Args:
    metric: Metric name passed through to the scrape op.
    endpoint: Scrape endpoint passed through to the scrape op.
    interval: Delay applied between elements via sleep; defaults to
      1000000 when None (presumably microseconds -- TODO confirm).
    internal: Must be truthy; guards against direct construction.
  """
  with tf.name_scope("PrometheusScrapeStreamIODataset"):
    assert internal
    if interval is None:
      interval = 1000000
    dataset = tf.data.Dataset.range(0, 10, 1)
    dataset = dataset.map(
        lambda i: golang_ops.io_prometheus_scrape(metric, endpoint, i))
    dataset = dataset.apply(sleep(interval))
    self._dataset = dataset
    super(PrometheusScrapeStreamIODataset, self).__init__(
        self._dataset._variant_tensor)  # pylint: disable=protected-access
def testSleep(self):
  """sleep() delays each element by at least the requested microseconds."""
  sleep_microseconds = 100
  dataset = dataset_ops.Dataset.range(10).apply(
      sleep.sleep(sleep_microseconds))
  iterator = dataset_ops.make_initializable_iterator(dataset)
  next_element = iterator.get_next()
  # Fix: dropped the unused `as sess` binding -- this body only uses
  # self.evaluate, which already runs in the cached session.
  with self.cached_session():
    self.evaluate(iterator.initializer)
    start_time = time.time()
    for i in range(10):
      self.assertEqual(i, self.evaluate(next_element))
    end_time = time.time()
    # Ten elements => at least 10 * sleep_microseconds of accumulated delay.
    self.assertGreater(end_time - start_time,
                       (10 * sleep_microseconds) / 1e6)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
def testSleep(self):
  """sleep() delays each element by at least the requested microseconds."""
  sleep_microseconds = 100
  ds = dataset_ops.Dataset.range(10).apply(sleep.sleep(sleep_microseconds))
  iterator = ds.make_initializable_iterator()
  get_next = iterator.get_next()
  with self.cached_session() as sess:
    sess.run(iterator.initializer)
    start = time.time()
    for expected in range(10):
      self.assertEqual(expected, sess.run(get_next))
    elapsed = time.time() - start
    # Ten elements => at least 10 * sleep_microseconds of accumulated delay.
    self.assertGreater(elapsed, (10 * sleep_microseconds) / 1e6)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
def testBatchSizeNotDivisibleByNumReplicas(self):
  """Rebatching 32 across 5 replicas yields sizes [7, 7, 7, 7, 4] per step."""
  ds = dataset_ops.Dataset.range(64).batch(
      32, drop_remainder=True).apply(sleep.sleep(10))
  rebatched = distribute._RebatchDataset(ds, num_replicas=5)
  expected = []
  index = 0
  for _ in range(2):  # Two global batches of 32 elements each.
    # The first four minibatches carry seven elements each...
    for _ in range(4):
      expected.append(list(range(index, index + 7)))
      index += 7
    # ...and the final minibatch carries the remaining four.
    expected.append(list(range(index, index + 4)))
    index += 4
  self.assertDatasetProduces(rebatched, expected)
def slow_branch(dataset):
  """Returns `dataset` slowed down by a long per-element sleep."""
  delay = sleep.sleep(10000)
  return dataset.apply(delay)
def make_dataset(time_us, num_elements):
  """Builds a range dataset of `num_elements` with a `time_us` sleep each."""
  delay = sleep.sleep(time_us)
  return dataset_ops.Dataset.range(num_elements).apply(delay)
def setup_slow_dataset(self):
  """Returns the fast dataset slowed by a per-iteration sleep."""
  ds = self.setup_fast_dataset()
  self.iters = 1000
  # Sleep for 1e-3s per iteration.
  return ds.apply(sleep.sleep(1000))
def make_dataset(time_us, num_elements):
  """Range dataset of `num_elements`; sleeps `time_us` each when positive."""
  ds = dataset_ops.Dataset.range(num_elements)
  if time_us <= 0:
    return ds
  return ds.apply(sleep.sleep(time_us))
def fast_branch(dataset):
  """Returns `dataset` with only a tiny per-element sleep."""
  delay = sleep.sleep(10)
  return dataset.apply(delay)