def build_ds():
  """Returns a `_ChooseFastestBranchDataset` over two identical branches.

  The input is the scalar 0 repeated 1000 times and grouped into batches of
  100. The same `branch` function, which applies `batching.unbatch()`, is
  passed twice.
  """

  def branch(batched):
    unbatch_fn = batching.unbatch()
    return batched.apply(unbatch_fn)

  zeros = dataset_ops.Dataset.from_tensors(0)
  batched_zeros = zeros.repeat(1000).batch(100)
  candidates = [branch, branch]
  # NOTE(review): each input batch holds 100 values that `branch` unbatches,
  # while `ratio_denominator` is 10 -- confirm this ratio is intended.
  return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      batched_zeros,
      candidates,
      ratio_denominator=10,
      num_elements_per_branch=100)
def build_ds():
  """Returns a `_ChooseFastestBranchDataset` over two identical branches.

  The input is the scalar 0 repeated 1000 times and grouped into batches of
  100. The same `branch` function, which applies `batching.unbatch()`, is
  passed twice.
  """

  def branch(batched):
    unbatch_fn = batching.unbatch()
    return batched.apply(unbatch_fn)

  zeros = dataset_ops.Dataset.from_tensors(0)
  batched_zeros = zeros.repeat(1000).batch(100)
  candidates = [branch, branch]
  # NOTE(review): each input batch holds 100 values that `branch` unbatches,
  # while `ratio_denominator` is 10 -- confirm this ratio is intended.
  return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      batched_zeros,
      candidates,
      ratio_denominator=10,
      num_elements_per_branch=100)
def build_ds(size):
  """Returns a `_ChooseFastestBranchDataset` over two map/batch orderings.

  Args:
    size: Argument forwarded to `Dataset.range` to create the input dataset.

  Returns:
    The dataset built from `branch_0` (identity map, then batch of 10) and
    `branch_1` (batch of 10, then identity map) with `ratio_numerator=10`.
  """

  def branch_0(ds):
    mapped = ds.map(lambda x: x)
    return mapped.batch(10)

  def branch_1(ds):
    batched = ds.batch(10)
    return batched.map(lambda x: x)

  source = dataset_ops.Dataset.range(size)
  candidates = [branch_0, branch_1]
  return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      source, candidates, ratio_numerator=10)
def build_ds(size):
  """Returns a `_ChooseFastestBranchDataset` over two map/batch orderings.

  Args:
    size: Argument forwarded to `Dataset.range` to create the input dataset.

  Returns:
    The dataset built from `branch_0` (identity map, then batch of 10) and
    `branch_1` (batch of 10, then identity map) with `ratio_numerator=10`.
  """

  def branch_0(ds):
    mapped = ds.map(lambda x: x)
    return mapped.batch(10)

  def branch_1(ds):
    batched = ds.batch(10)
    return batched.map(lambda x: x)

  source = dataset_ops.Dataset.range(size)
  candidates = [branch_0, branch_1]
  return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      source, candidates, ratio_numerator=10)
def testWithMoreOutputThanInput(self):
  """Checks branches that emit more elements than they consume.

  The input is 10 batches of 100 zeros; each branch unbatches them, so the
  expected output is 1000 scalar zeros.
  """

  def branch(batched):
    unbatch_fn = batching.unbatch()
    return batched.apply(unbatch_fn)

  zeros = dataset_ops.Dataset.from_tensors(0)
  batched_zeros = zeros.repeat(1000).batch(100)

  choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      batched_zeros,
      [branch, branch],
      ratio_denominator=100,
      num_elements_per_branch=100)

  expected = [0] * 1000
  self.assertDatasetProduces(choose_fastest, expected_output=expected)
def build_ds():
  """Returns a `_ChooseFastestBranchDataset` whose branches capture constants.

  `branch_0` adds a captured int64 constant to each element of `range(10)`;
  `branch_1` adds a captured int32 constant after casting it to int64.
  """
  source = dataset_ops.Dataset.range(10)
  # Created in this order (int64 first, then int32), before the branches
  # that close over them are invoked.
  one_i64 = constant_op.constant(1, dtypes.int64)
  one_i32 = constant_op.constant(1, dtypes.int32)

  def branch_0(ds):
    return ds.map(lambda elem: elem + one_i64)

  def branch_1(ds):
    return ds.map(lambda elem: elem + math_ops.cast(one_i32, dtypes.int64))

  candidates = [branch_0, branch_1]
  return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      source, candidates, num_elements_per_branch=3)
def build_ds():
  """Returns a `_ChooseFastestBranchDataset` whose branches capture constants.

  `branch_0` adds a captured int64 constant to each element of `range(10)`;
  `branch_1` adds a captured int32 constant after casting it to int64.
  """
  source = dataset_ops.Dataset.range(10)
  # Created in this order (int64 first, then int32), before the branches
  # that close over them are invoked.
  one_i64 = constant_op.constant(1, dtypes.int64)
  one_i32 = constant_op.constant(1, dtypes.int32)

  def branch_0(ds):
    return ds.map(lambda elem: elem + one_i64)

  def branch_1(ds):
    return ds.map(lambda elem: elem + math_ops.cast(one_i32, dtypes.int64))

  candidates = [branch_0, branch_1]
  return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      source, candidates, num_elements_per_branch=3)
def testSimple(self):
  """Checks that two identical identity-map branches reproduce the input.

  The input dataset is the slices of `[0, 1, 2, 3, 4]`; the output is
  expected to contain the same values with the input dataset's shapes.
  """
  dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])

  def branch(dataset):
    return dataset.map(lambda x: x)

  choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      dataset, [branch, branch])

  # Use the module-level helper rather than the deprecated, V1-only
  # `Dataset.output_shapes` property so this also works on V2 datasets.
  self.assertDatasetProduces(
      choose_fastest,
      expected_output=[0, 1, 2, 3, 4],
      expected_shapes=dataset_ops.get_legacy_output_shapes(dataset))
def testWithMoreOutputThanInput(self):
  """Checks branches that emit more elements than they consume.

  The input is 10 batches of 100 zeros; each branch unbatches them, so the
  expected output is 1000 scalar zeros.
  """

  def branch(batched):
    unbatch_fn = batching.unbatch()
    return batched.apply(unbatch_fn)

  zeros = dataset_ops.Dataset.from_tensors(0)
  batched_zeros = zeros.repeat(1000).batch(100)

  choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      batched_zeros,
      [branch, branch],
      ratio_denominator=100,
      num_elements_per_branch=100)

  expected = [0] * 1000
  self.assertDatasetProduces(choose_fastest, expected_output=expected)
def testSimple(self):
  """Checks that two identical identity-map branches reproduce the input.

  The input dataset is the slices of `[0, 1, 2, 3, 4]`; the output is
  expected to contain the same values with the input's legacy output shapes.
  """

  def branch(ds):
    return ds.map(lambda elem: elem)

  values = [0, 1, 2, 3, 4]
  source = dataset_ops.Dataset.from_tensor_slices(values)
  choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      source, [branch, branch])

  source_shapes = dataset_ops.get_legacy_output_shapes(source)
  self.assertDatasetProduces(
      choose_fastest,
      expected_output=[0, 1, 2, 3, 4],
      expected_shapes=source_shapes)
def make_benchmark_datasets(self,
                            input_dataset,
                            branch_0,
                            branch_1,
                            ratio_numerator,
                            num_elements_per_branch=None):
  """Builds the two branch datasets and the choose-fastest dataset.

  Args:
    input_dataset: Dataset passed to both branches and to
      `_ChooseFastestBranchDataset`.
    branch_0: Callable applied to `input_dataset` to build the first dataset.
    branch_1: Callable applied to `input_dataset` to build the second dataset.
    ratio_numerator: Forwarded to `_ChooseFastestBranchDataset`.
    num_elements_per_branch: Forwarded to `_ChooseFastestBranchDataset`;
      defaults to None.

  Returns:
    A tuple `(branch_0 output, branch_1 output, choose-fastest dataset)`.
  """
  # Branch datasets are built first, in order, before the combined dataset.
  first_branch_ds = branch_0(input_dataset)
  second_branch_ds = branch_1(input_dataset)

  candidates = [branch_0, branch_1]
  fastest_ds = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      input_dataset,
      candidates,
      ratio_numerator=ratio_numerator,
      num_elements_per_branch=num_elements_per_branch)

  return first_branch_ds, second_branch_ds, fastest_ds
def make_benchmark_datasets(self,
                            input_dataset,
                            branch_0,
                            branch_1,
                            ratio_numerator,
                            num_elements_per_branch=None):
  """Builds the two branch datasets and the choose-fastest dataset.

  Args:
    input_dataset: Dataset passed to both branches and to
      `_ChooseFastestBranchDataset`.
    branch_0: Callable applied to `input_dataset` to build the first dataset.
    branch_1: Callable applied to `input_dataset` to build the second dataset.
    ratio_numerator: Forwarded to `_ChooseFastestBranchDataset`.
    num_elements_per_branch: Forwarded to `_ChooseFastestBranchDataset`;
      defaults to None.

  Returns:
    A tuple `(branch_0 output, branch_1 output, choose-fastest dataset)`.
  """
  # Branch datasets are built first, in order, before the combined dataset.
  first_branch_ds = branch_0(input_dataset)
  second_branch_ds = branch_1(input_dataset)

  candidates = [branch_0, branch_1]
  fastest_ds = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      input_dataset,
      candidates,
      ratio_numerator=ratio_numerator,
      num_elements_per_branch=num_elements_per_branch)

  return first_branch_ds, second_branch_ds, fastest_ds
def testWithPrefetch(self):
  """Checks that element order is preserved when the branches prefetch.

  `branch_0` prefetches 1 element and `branch_1` prefetches 2; the output is
  expected to be 0..99 in order.
  """

  def branch_0(ds):
    return ds.prefetch(1)

  def branch_1(ds):
    return ds.prefetch(2)

  source = dataset_ops.Dataset.range(100)
  candidates = [branch_0, branch_1]
  choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      source, candidates)

  expected = list(range(100))
  self.assertDatasetProduces(choose_fastest, expected_output=expected)
def testDifferentFunctions(self):
  """Checks two different branches that both yield batches of 10.

  `branch_0` maps then batches; `branch_1` batches then maps. The expected
  output is `range(100)` split into ten consecutive lists of ten.
  """

  def branch_0(ds):
    mapped = ds.map(lambda x: x)
    return mapped.batch(10)

  def branch_1(ds):
    batched = ds.batch(10)
    return batched.map(lambda x: x)

  source = dataset_ops.Dataset.range(100)
  choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      source, [branch_0, branch_1], ratio_numerator=10)

  # Ten batches: [0..9], [10..19], ..., [90..99].
  expected = [list(range(start, start + 10)) for start in range(0, 100, 10)]
  self.assertDatasetProduces(choose_fastest, expected_output=expected)
def testWithRepeatBeforeAndAfter(self):
  """Checks `repeat` applied both to the input and to the result.

  The input is the scalar 0 repeated 10 times; the choose-fastest dataset is
  itself repeated 10 times, and ten batches of ten zeros are expected.
  """

  def branch_0(ds):
    mapped = ds.map(lambda x: x)
    return mapped.batch(10)

  def branch_1(ds):
    batched = ds.batch(10)
    return batched.map(lambda x: x)

  source = dataset_ops.Dataset.from_tensors(0).repeat(10)
  single_pass = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      source, [branch_0, branch_1], ratio_numerator=10)
  choose_fastest = single_pass.repeat(10)

  expected = [[0] * 10 for _ in range(10)]
  self.assertDatasetProduces(choose_fastest, expected_output=expected)
def testErrorWithRepeat(self):
  """Checks the error raised when a branch repeats its input.

  Each branch calls `repeat(10)` on a single-element input. The assertion
  expects an `InvalidArgumentError` with the WrapperIterator message, passing
  `expected_error_iter=2` to `assertDatasetProduces`.
  """

  def branch(ds):
    return ds.repeat(10)

  source = dataset_ops.Dataset.from_tensors(0)
  choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      source,
      [branch, branch],
      ratio_denominator=10,
      num_elements_per_branch=10)

  expected_failure = (
      errors.InvalidArgumentError,
      "Cannot create more than one WrapperIterator per WrapperDataset.")
  self.assertDatasetProduces(
      choose_fastest,
      expected_error=expected_failure,
      expected_error_iter=2)
def make_benchmark_datasets(self):
  """Builds map-then-batch, batch-then-map and choose-fastest datasets.

  The shared input is `range(1000**2)` repeated indefinitely. Both branches
  add 1 to every element and use a batch size of 100.

  Returns:
    A tuple `(map-then-batch dataset, batch-then-map dataset,
    choose-fastest dataset)`.
  """

  def branch_0(ds):
    incremented = ds.map(lambda x: x + 1)
    return incremented.batch(100)

  def branch_1(ds):
    batched = ds.batch(100)
    return batched.map(lambda x: x + 1)

  source = dataset_ops.Dataset.range(1000**2).repeat()

  # Standalone branch datasets are built first, in order.
  map_then_batch = branch_0(source)
  batch_then_map = branch_1(source)

  fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      source, [branch_0, branch_1], ratio_numerator=100)
  return map_then_batch, batch_then_map, fastest
def testCaptureSimple(self):
  """Checks branches that capture constants defined outside them.

  `branch_0` adds a captured int64 constant to each element of `range(10)`;
  `branch_1` adds a captured int32 constant after casting it to int64. The
  expected output is 1..10.
  """
  source = dataset_ops.Dataset.range(10)
  # Created in this order (int64 first, then int32), before the branches
  # that close over them are invoked.
  one_i64 = constant_op.constant(1, dtypes.int64)
  one_i32 = constant_op.constant(1, dtypes.int32)

  def branch_0(ds):
    return ds.map(lambda elem: elem + one_i64)

  def branch_1(ds):
    return ds.map(lambda elem: elem + math_ops.cast(one_i32, dtypes.int64))

  choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      source, [branch_0, branch_1])

  expected = list(range(1, 11))
  self.assertDatasetProduces(choose_fastest, expected_output=expected)
def make_dataset():
  """Returns a `_ChooseFastestBranchDataset` using the same branch twice.

  `dataset` and `branch` are not defined in this function; they are resolved
  from the enclosing scope at call time.
  """
  candidates = [branch, branch]
  chosen = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      dataset,
      candidates,
      ratio_denominator=100,
      num_elements_per_branch=10)
  return chosen
def make_dataset():
  """Returns a `_ChooseFastestBranchDataset` using the same branch twice.

  `dataset` and `branch` are not defined in this function; they are resolved
  from the enclosing scope at call time.
  """
  candidates = [branch, branch]
  chosen = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      dataset,
      candidates,
      ratio_denominator=100,
      num_elements_per_branch=10)
  return chosen