def build_ds():
  """Builds a `_ChooseFastestBranchDataset` with two identical unbatch branches.

  The input is the scalar 0 repeated 1000 times and grouped into batches of
  100. Both branches are the same function, which applies
  `batching.unbatch()` to its argument.

  Returns:
    The `_ChooseFastestBranchDataset` constructed with `ratio_denominator=10`
    and `num_elements_per_branch=100`.
  """

  def branch(batched):
    return batched.apply(batching.unbatch())

  batched_zeros = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
  candidates = [branch, branch]
  return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
      batched_zeros,
      candidates,
      ratio_denominator=10,
      num_elements_per_branch=100)
    def build_ds():
      """Returns a `_ChooseFastestBranchDataset` over a batched all-zeros input.

      The source dataset holds the scalar 0 repeated 1000 times, batched by
      100. The same unbatching function is supplied twice as the list of
      branches.

      Returns:
        The dataset built with `ratio_denominator=10` and
        `num_elements_per_branch=100`.
      """

      def branch(batched):
        unbatch_fn = batching.unbatch()
        return batched.apply(unbatch_fn)

      source = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
      options = dict(ratio_denominator=10, num_elements_per_branch=100)
      return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
          source, [branch, branch], **options)
    def build_ds(size):
      """Builds a `_ChooseFastestBranchDataset` over `Dataset.range(size)`.

      The two branches apply an identity `map` and a `batch(10)` in opposite
      orders.

      Args:
        size: Value forwarded to `dataset_ops.Dataset.range`.

      Returns:
        The `_ChooseFastestBranchDataset` constructed with `ratio_numerator=10`.
      """

      def branch_0(ds):
        # Map first, then batch.
        mapped = ds.map(lambda x: x)
        return mapped.batch(10)

      def branch_1(ds):
        # Batch first, then map.
        batched = ds.batch(10)
        return batched.map(lambda x: x)

      source = dataset_ops.Dataset.range(size)
      return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
          source, [branch_0, branch_1], ratio_numerator=10)
コード例 #4
0
        def build_ds(size):
            """Returns a `_ChooseFastestBranchDataset` built on a range input.

            Branch 0 maps with the identity function and then batches by 10;
            branch 1 batches by 10 and then maps with the identity function.

            Args:
                size: Value forwarded to `dataset_ops.Dataset.range`.

            Returns:
                The dataset constructed with `ratio_numerator=10`.
            """

            def branch_0(ds):
                return ds.map(lambda x: x).batch(10)

            def branch_1(ds):
                return ds.batch(10).map(lambda x: x)

            range_ds = dataset_ops.Dataset.range(size)
            branches = [branch_0, branch_1]
            return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
                range_ds, branches, ratio_numerator=10)
コード例 #5
0
  def testWithMoreOutputThanInput(self):
    """Checks branches that unbatch a batched input yield 1000 zeros.

    The input is the scalar 0 repeated 1000 times in batches of 100; both
    branches apply `batching.unbatch()`.
    """

    def branch(batched):
      return batched.apply(batching.unbatch())

    batched_zeros = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
    choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        batched_zeros,
        [branch, branch],
        ratio_denominator=100,
        num_elements_per_branch=100)

    expected = [0] * 1000
    self.assertDatasetProduces(choose_fastest, expected_output=expected)
    def build_ds():
      """Builds a `_ChooseFastestBranchDataset` whose branches capture constants.

      The input is `Dataset.range(10)`. Branch 0 adds a captured int64
      constant 1; branch 1 adds a captured int32 constant 1 after casting it
      to int64.

      Returns:
        The dataset constructed with `num_elements_per_branch=3`.
      """
      source = dataset_ops.Dataset.range(10)
      const_64 = constant_op.constant(1, dtypes.int64)
      const_32 = constant_op.constant(1, dtypes.int32)

      def branch_0(ds):
        return ds.map(lambda x: x + const_64)

      def branch_1(ds):
        # The cast happens inside the mapped function, as in branch_0's add.
        return ds.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))

      branches = [branch_0, branch_1]
      return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
          source, branches, num_elements_per_branch=3)
    def build_ds():
      """Returns a `_ChooseFastestBranchDataset` using captured constants.

      Both branches map over `Dataset.range(10)` and add 1 to each element:
      one via a captured int64 constant, the other via a captured int32
      constant that is cast to int64 inside the mapped function.

      Returns:
        The dataset constructed with `num_elements_per_branch=3`.
      """
      range_ds = dataset_ops.Dataset.range(10)
      const_64 = constant_op.constant(1, dtypes.int64)
      const_32 = constant_op.constant(1, dtypes.int32)

      def branch_0(ds):
        return ds.map(lambda x: x + const_64)

      def branch_1(ds):
        return ds.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))

      return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
          range_ds,
          [branch_0, branch_1],
          num_elements_per_branch=3)
コード例 #8
0
  def testSimple(self):
    """Checks two identity-map branches reproduce the input elements.

    The expected shapes are taken from the input dataset's `output_shapes`.
    """

    def branch(ds):
      return ds.map(lambda x: x)

    dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
    choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        dataset, [branch, branch])

    expected_shapes = dataset.output_shapes
    self.assertDatasetProduces(
        choose_fastest,
        expected_output=[0, 1, 2, 3, 4],
        expected_shapes=expected_shapes)
コード例 #9
0
    def testWithMoreOutputThanInput(self):
        """Checks that unbatching branches produce 1000 zeros.

        The input holds the scalar 0 repeated 1000 times and batched by 100;
        the same `batching.unbatch()` branch is supplied twice.
        """

        def branch(batched):
            return batched.apply(batching.unbatch())

        source = dataset_ops.Dataset.from_tensors(0).repeat(1000).batch(100)
        branches = [branch, branch]
        choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
            source,
            branches,
            ratio_denominator=100,
            num_elements_per_branch=100)

        self.assertDatasetProduces(
            choose_fastest, expected_output=[0] * 1000)
コード例 #10
0
    def testSimple(self):
        """Checks two identity-map branches reproduce the input elements.

        The expected shapes come from
        `dataset_ops.get_legacy_output_shapes` applied to the input dataset.
        """

        def branch(ds):
            return ds.map(lambda x: x)

        dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
        choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
            dataset, [branch, branch])

        input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
        self.assertDatasetProduces(
            choose_fastest,
            expected_output=[0, 1, 2, 3, 4],
            expected_shapes=input_shapes)
コード例 #11
0
  def make_benchmark_datasets(self,
                              input_dataset,
                              branch_0,
                              branch_1,
                              ratio_numerator,
                              num_elements_per_branch=None):
    """Builds each branch's dataset plus the choose-fastest dataset.

    Args:
      input_dataset: Dataset passed to both branch functions and to
        `_ChooseFastestBranchDataset`.
      branch_0: Callable applied to `input_dataset` to build the first dataset.
      branch_1: Callable applied to `input_dataset` to build the second dataset.
      ratio_numerator: Forwarded to `_ChooseFastestBranchDataset`.
      num_elements_per_branch: Forwarded to `_ChooseFastestBranchDataset`;
        defaults to None.

    Returns:
      A tuple `(ds_0, ds_1, choose_fastest_dataset)`.
    """
    # Apply the branches directly, in order, before wrapping them.
    branch_outputs = [fn(input_dataset) for fn in (branch_0, branch_1)]
    fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        input_dataset,
        [branch_0, branch_1],
        ratio_numerator=ratio_numerator,
        num_elements_per_branch=num_elements_per_branch)

    return branch_outputs[0], branch_outputs[1], fastest
  def make_benchmark_datasets(self,
                              input_dataset,
                              branch_0,
                              branch_1,
                              ratio_numerator,
                              num_elements_per_branch=None):
    """Returns the two branch datasets and a dataset choosing between them.

    Args:
      input_dataset: Dataset handed to each branch function and to
        `_ChooseFastestBranchDataset`.
      branch_0: First branch function; called with `input_dataset`.
      branch_1: Second branch function; called with `input_dataset`.
      ratio_numerator: Passed through to `_ChooseFastestBranchDataset`.
      num_elements_per_branch: Passed through to
        `_ChooseFastestBranchDataset`; defaults to None.

    Returns:
      A 3-tuple: the result of `branch_0`, the result of `branch_1`, and the
      `_ChooseFastestBranchDataset` built from both branch functions.
    """
    first_branch_ds = branch_0(input_dataset)
    second_branch_ds = branch_1(input_dataset)

    branch_fns = [branch_0, branch_1]
    combined = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        input_dataset,
        branch_fns,
        ratio_numerator=ratio_numerator,
        num_elements_per_branch=num_elements_per_branch)

    return first_branch_ds, second_branch_ds, combined
コード例 #13
0
  def testWithPrefetch(self):
    """Checks element order is kept when the branches prefetch.

    The branches prefetch 1 and 2 elements respectively over
    `Dataset.range(100)`; the output must be 0..99 in order.
    """

    def branch_0(ds):
      return ds.prefetch(1)

    def branch_1(ds):
      return ds.prefetch(2)

    source = dataset_ops.Dataset.range(100)
    choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        source, [branch_0, branch_1])

    in_order = list(range(100))
    self.assertDatasetProduces(choose_fastest, expected_output=in_order)
コード例 #14
0
  def testDifferentFunctions(self):
    """Checks map-then-batch and batch-then-map branches give the same batches.

    Over `Dataset.range(100)` the expected output is ten consecutive batches
    of ten elements each.
    """

    def branch_0(ds):
      return ds.map(lambda x: x).batch(10)

    def branch_1(ds):
      return ds.batch(10).map(lambda x: x)

    source = dataset_ops.Dataset.range(100)
    choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        source, [branch_0, branch_1], ratio_numerator=10)

    # Batches [0..9], [10..19], ..., [90..99].
    expected_batches = [
        list(range(start, start + 10)) for start in range(0, 100, 10)
    ]
    self.assertDatasetProduces(choose_fastest, expected_output=expected_batches)
コード例 #15
0
  def testWithRepeatBeforeAndAfter(self):
    """Checks `repeat` applied both to the input and to the chosen dataset.

    The input is the scalar 0 repeated 10 times; the choose-fastest dataset
    is itself repeated 10 times, and ten batches of ten zeros are expected.
    """

    def branch_0(ds):
      return ds.map(lambda x: x).batch(10)

    def branch_1(ds):
      return ds.batch(10).map(lambda x: x)

    repeated_input = dataset_ops.Dataset.from_tensors(0).repeat(10)
    choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        repeated_input, [branch_0, branch_1], ratio_numerator=10).repeat(10)

    expected_batches = [[0] * 10 for _ in range(10)]
    self.assertDatasetProduces(choose_fastest, expected_output=expected_batches)
コード例 #16
0
  def testErrorWithRepeat(self):
    """Checks that branches which `repeat` their input raise an error.

    Expects `errors.InvalidArgumentError` with the WrapperIterator message,
    asserted with `expected_error_iter=2`.
    """

    def branch(ds):
      return ds.repeat(10)

    single_element = dataset_ops.Dataset.from_tensors(0)
    choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        single_element,
        [branch, branch],
        ratio_denominator=10,
        num_elements_per_branch=10)

    expected_error = (
        errors.InvalidArgumentError,
        "Cannot create more than one WrapperIterator per WrapperDataset.")
    self.assertDatasetProduces(
        choose_fastest,
        expected_error=expected_error,
        expected_error_iter=2)
  def make_benchmark_datasets(self):
    """Builds map/batch benchmark datasets in both orders plus a chooser.

    The input is `Dataset.range(1000**2)` repeated with no count argument.
    One pipeline maps `x + 1` then batches by 100; the other batches by 100
    then maps `x + 1`.

    Returns:
      A tuple `(map_batch_dataset, batch_map_dataset, choose_fastest_dataset)`
      where the last is a `_ChooseFastestBranchDataset` over both branch
      functions with `ratio_numerator=100`.
    """

    def branch_0(ds):
      return ds.map(lambda x: x + 1).batch(100)

    def branch_1(ds):
      return ds.batch(100).map(lambda x: x + 1)

    source = dataset_ops.Dataset.range(1000**2).repeat()

    map_then_batch = branch_0(source)
    batch_then_map = branch_1(source)
    fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        source, [branch_0, branch_1], ratio_numerator=100)
    return map_then_batch, batch_then_map, fastest
コード例 #18
0
  def testCaptureSimple(self):
    """Checks branches that capture external constants produce 1..10.

    Over `Dataset.range(10)`, one branch adds a captured int64 constant 1 and
    the other adds a captured int32 constant 1 cast to int64.
    """
    source = dataset_ops.Dataset.range(10)

    const_64 = constant_op.constant(1, dtypes.int64)
    const_32 = constant_op.constant(1, dtypes.int32)

    def branch_0(ds):
      return ds.map(lambda x: x + const_64)

    def branch_1(ds):
      return ds.map(lambda x: x + math_ops.cast(const_32, dtypes.int64))

    choose_fastest = optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
        source, [branch_0, branch_1])

    # Each input element 0..9 shifted up by one.
    shifted = [value + 1 for value in range(10)]
    self.assertDatasetProduces(choose_fastest, expected_output=shifted)
コード例 #19
0
 def make_dataset():
   """Returns a `_ChooseFastestBranchDataset` using `branch` twice.

   `dataset` and `branch` are free names resolved from the enclosing scope,
   which is not visible in this snippet.
   """
   branches = [branch, branch]
   return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
       dataset,
       branches,
       ratio_denominator=100,
       num_elements_per_branch=10)
コード例 #20
0
 def make_dataset():
     """Builds the choose-fastest dataset from the enclosing scope's names.

     Both `dataset` and `branch` are looked up in the surrounding scope (not
     shown here); `branch` is supplied as both candidate branches.
     """
     settings = dict(ratio_denominator=100, num_elements_per_branch=10)
     return optimization._ChooseFastestBranchDataset(  # pylint: disable=protected-access
         dataset, [branch, branch], **settings)