def _test_bucket_by_padding(no_padding):
      """Buckets the test dataset and validates the four batches it yields.

      Args:
        no_padding: when True the dataset is built from sparse records and each
          batch is read through `dense_shape`/`values`; otherwise the padded
          dense batch is read through `shape`/`sum()`.
      """
      dataset = build_dataset(sparse=no_padding)
      bucketing = grouping.bucket_by_sequence_length(
          _element_length_fn, boundaries, batch_sizes, no_padding=no_padding)
      dataset = dataset.apply(bucketing)

      next_element = self.getNext(dataset)
      collected = []
      while len(collected) < 4:
        (item,) = self.evaluate(next_element())
        collected.append(item)
      # Exactly four batches are expected; a fifth read must hit end-of-data.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element())

      observed_sizes = []
      observed_lengths = []
      for item in collected:
        dims = item.dense_shape if no_padding else item.shape
        rows, width = dims[0], dims[1]
        observed_sizes.append(rows)
        observed_lengths.append(width)
        # The content-sum check is skipped when executing eagerly.
        if context.executing_eagerly():
          continue
        total = item.values.sum() if no_padding else item.sum()
        # Each batch should sum to one less than rows * width.
        self.assertEqual(total, rows * width - 1)
      self.assertEqual(sum(observed_sizes), sum(batch_sizes))
      self.assertEqual(sorted(batch_sizes), sorted(observed_sizes))
      self.assertEqual(sorted(lengths), sorted(observed_lengths))
# Example no. 2
0
  def testPadToBoundaryNoExtraneousPadding(self):
    """Pads each bucket only up to its boundary minus one.

    All-ones sequences of lengths 1..10 are bucketed with boundaries
    [3, 7, 11] and batch size 2 per bucket, which must produce exactly five
    batches padded to widths 2, 6, 6, 10 and 10.
    """
    boundaries = [3, 7, 11]
    batch_sizes = [2, 2, 2, 2]
    lengths = range(1, 11)

    def element_gen():
      for size in lengths:
        yield ([1] * size,)

    def element_len(element):
      return array_ops.shape(element)[0]

    source = dataset_ops.Dataset.from_generator(
        element_gen, (dtypes.int64,), ([None],))
    bucketed = source.apply(
        grouping.bucket_by_sequence_length(
            element_len, boundaries, batch_sizes,
            pad_to_bucket_boundary=True))
    next_batch, = bucketed.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      batches = [sess.run(next_batch) for _ in range(5)]
      # A sixth read must report end-of-data.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_batch)

    expected_batches = [
        [[1, 0],
         [1, 1]],
        [[1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 0, 0]],
        [[1, 1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1, 1]],
        [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]],
        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
    ]
    for actual, expected in zip(batches, expected_batches):
      self.assertAllEqual(actual, expected)
# Example no. 3
0
        def _test_bucket_by_padding(no_padding):
            """Buckets the test dataset and validates its four batches via a session.

            Args:
                no_padding: when True the dataset holds sparse records and each
                    batch is read via `dense_shape`/`values`; otherwise the
                    padded dense batch is read via `shape`/`sum()`.
            """
            dataset = build_dataset(sparse=no_padding)
            bucketing = grouping.bucket_by_sequence_length(
                _element_length_fn, boundaries, batch_sizes,
                no_padding=no_padding)
            dataset = dataset.apply(bucketing)
            (next_batch, ) = dataset.make_one_shot_iterator().get_next()

            with self.cached_session() as sess:
                collected = [sess.run(next_batch) for _ in range(4)]
                # A fifth read must report end-of-data.
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(next_batch)

            observed_sizes, observed_lengths = [], []
            for item in collected:
                dims = item.dense_shape if no_padding else item.shape
                rows, width = dims[0], dims[1]
                observed_sizes.append(rows)
                observed_lengths.append(width)
                total = item.values.sum() if no_padding else item.sum()
                # Each batch should sum to one less than rows * width.
                self.assertEqual(total, rows * width - 1)
            self.assertEqual(sum(observed_sizes), sum(batch_sizes))
            self.assertEqual(sorted(batch_sizes), sorted(observed_sizes))
            self.assertEqual(sorted(lengths), sorted(observed_lengths))
# Example no. 4
0
  def testPadToBoundaryNoExtraneousPadding(self):
    """Pads each bucket only up to its boundary minus one.

    All-ones sequences of lengths 1..10 are bucketed with boundaries
    [3, 7, 11] and batch size 2 per bucket, which must produce exactly five
    batches padded to widths 2, 6, 6, 10 and 10.
    """
    boundaries = [3, 7, 11]
    batch_sizes = [2, 2, 2, 2]
    lengths = range(1, 11)

    def element_gen():
      for size in lengths:
        yield ([1] * size,)

    def element_len(element):
      return array_ops.shape(element)[0]

    source = dataset_ops.Dataset.from_generator(
        element_gen, (dtypes.int64,), ([None],))
    bucketed = source.apply(
        grouping.bucket_by_sequence_length(
            element_len, boundaries, batch_sizes,
            pad_to_bucket_boundary=True))
    next_batch, = dataset_ops.make_one_shot_iterator(bucketed).get_next()

    # The session object itself is not needed: every read goes through
    # self.evaluate while this context is active.
    with self.cached_session():
      batches = [self.evaluate(next_batch) for _ in range(5)]
      # A sixth read must report end-of-data.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_batch)

    expected_batches = [
        [[1, 0],
         [1, 1]],
        [[1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 0, 0]],
        [[1, 1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1, 1]],
        [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]],
        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
    ]
    for actual, expected in zip(batches, expected_batches):
      self.assertAllEqual(actual, expected)
# Example no. 5
0
        def _test_bucket_by_padding(no_padding):
            """Buckets the test dataset and validates the four batches it yields.

            Args:
                no_padding: when True the dataset is built from sparse records
                    and each batch is read via `dense_shape`/`values`;
                    otherwise the padded dense batch is read via
                    `shape`/`sum()`.
            """
            dataset = build_dataset(sparse=no_padding)
            bucketing = grouping.bucket_by_sequence_length(
                _element_length_fn, boundaries, batch_sizes,
                no_padding=no_padding)
            dataset = dataset.apply(bucketing)

            next_element = self.getNext(dataset)
            collected = []
            while len(collected) < 4:
                (item, ) = self.evaluate(next_element())
                collected.append(item)
            # Exactly four batches are expected; a fifth read must fail.
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(next_element())

            observed_sizes = []
            observed_lengths = []
            for item in collected:
                dims = item.dense_shape if no_padding else item.shape
                rows, width = dims[0], dims[1]
                observed_sizes.append(rows)
                observed_lengths.append(width)
                # The content-sum check is skipped when executing eagerly.
                if context.executing_eagerly():
                    continue
                total = item.values.sum() if no_padding else item.sum()
                # Each batch should sum to one less than rows * width.
                self.assertEqual(total, rows * width - 1)
            self.assertEqual(sum(observed_sizes), sum(batch_sizes))
            self.assertEqual(sorted(batch_sizes), sorted(observed_sizes))
            self.assertEqual(sorted(lengths), sorted(observed_lengths))
# Example no. 6
0
    def _test_bucket_by_padding(no_padding):
      """Buckets the test dataset and validates its four batches via a session.

      Args:
        no_padding: when True the dataset holds sparse records and each batch
          is read via `dense_shape`/`values`; otherwise the padded dense batch
          is read via `shape`/`sum()`.
      """
      dataset = build_dataset(sparse=no_padding)
      bucketing = grouping.bucket_by_sequence_length(
          _element_length_fn, boundaries, batch_sizes, no_padding=no_padding)
      dataset = dataset.apply(bucketing)
      (next_batch,) = dataset.make_one_shot_iterator().get_next()

      with self.cached_session() as sess:
        collected = [sess.run(next_batch) for _ in range(4)]
        # A fifth read must report end-of-data.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_batch)

      observed_sizes, observed_lengths = [], []
      for item in collected:
        dims = item.dense_shape if no_padding else item.shape
        rows, width = dims[0], dims[1]
        observed_sizes.append(rows)
        observed_lengths.append(width)
        total = item.values.sum() if no_padding else item.sum()
        # Each batch should sum to one less than rows * width.
        self.assertEqual(total, rows * width - 1)
      self.assertEqual(sum(observed_sizes), sum(batch_sizes))
      self.assertEqual(sorted(batch_sizes), sorted(observed_sizes))
      self.assertEqual(sorted(lengths), sorted(observed_lengths))
# Example no. 7
0
    def testRoundRobinBucketizing(self):
        """Round-robin consumers should receive batches from the same bucket.

        Tests a common use case for round robin reads: at each step, all
        consumers should get batches with the same bucket size. The consumers
        are read back through an interleave with `cycle_length=num_consumers`,
        so each consecutive group of `num_consumers` results is one step.
        """
        cluster = self.create_cluster(num_workers=4)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
        ds = ds.shuffle(num_elements)
        low_bucket_max = 30
        mid_bucket_max = 60
        bucket_boundaries = [low_bucket_max, mid_bucket_max]
        batch_size = 10
        num_consumers = 3
        bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
        ds = ds.apply(
            grouping.bucket_by_sequence_length(lambda x: x,
                                               bucket_boundaries,
                                               bucket_batch_sizes,
                                               drop_remainder=True))
        # Group batches into windows of `num_consumers`, keyed on x[1].
        ds = ds.apply(
            grouping.group_by_window(
                lambda x: math_ops.cast(x[1], dtypes.int64),
                lambda _, x: dataset_ops.Dataset.from_tensors(x),
                window_size=num_consumers))
        ds = ds.flat_map(lambda x: x)
        ds = ds.repeat()

        consumers = []
        for consumer_index in range(num_consumers):
            consumers.append(
                self.make_distributed_dataset(ds,
                                              cluster,
                                              job_name="test",
                                              consumer_index=consumer_index,
                                              num_consumers=num_consumers))
        # Use parallel interleave to read from consumers in parallel.
        ds = dataset_ops.Dataset.from_tensor_slices(consumers)
        ds = ds.interleave(lambda x: x.prefetch(num_elements),
                           cycle_length=num_consumers,
                           num_parallel_calls=num_consumers)

        # NOTE(review): this reads `num_rounds` batches in total, not
        # `num_rounds` steps of `num_consumers` batches. Because `ds.repeat()`
        # comes after bucketing and windowing, one pass over 100 elements
        # presumably yields 3 + 3 + 4 = 10 batches, which is not a multiple of
        # `num_consumers`. Steps after the first pass could therefore straddle
        # passes. Confirm before increasing the number of reads.
        num_rounds = 10
        get_next = self.getNext(ds, requires_initialization=True)
        results = []
        for _ in range(num_rounds):
            results.append(self.evaluate(get_next()))

        def get_bucket(elem):
            """Returns how many boundaries are <= elem (the bucket index)."""
            bucket_ind = 0
            while bucket_ind < len(bucket_boundaries
                                   ) and elem >= bucket_boundaries[bucket_ind]:
                bucket_ind += 1
            return bucket_ind

        # `i` already advances in strides of `num_consumers`, so it is the
        # offset of a step's first batch; it must not be scaled again.
        for i in range(0, len(results), num_consumers):
            batches = results[i:i + num_consumers]
            bucket_inds = [get_bucket(batch[0]) for batch in batches]
            for bucket_ind in bucket_inds[1:]:
                self.assertEqual(bucket_inds[0], bucket_ind)
# Example no. 8
0
 def _test_tuple_elements_by_padding(no_padding):
   """Checks the static output shapes of the bucketed test dataset.

   Args:
     no_padding: forwarded both to `build_dataset(sparse=...)` and to the
       bucketing transformation's `no_padding`.
   """
   source = build_dataset(sparse=no_padding)
   bucketed = source.apply(
       grouping.bucket_by_sequence_length(
           element_length_func=_element_length_fn,
           bucket_boundaries=[0, 8],
           bucket_batch_sizes=[2, 2, 2],
           no_padding=no_padding))
   output_shapes = bucketed.output_shapes
   # Component 0 must have two unknown dimensions, component 1 one.
   self.assertEqual([None, None], output_shapes[0].as_list())
   self.assertEqual([None], output_shapes[1].as_list())
# Example no. 9
0
 def _test_tuple_elements_by_padding(no_padding):
   """Asserts that bucketing leaves both components with unknown dimensions.

   Args:
     no_padding: selects the sparse test dataset and disables padding in the
       bucketing transformation.
   """
   base = build_dataset(sparse=no_padding)
   result = base.apply(
       grouping.bucket_by_sequence_length(
           element_length_func=_element_length_fn,
           bucket_boundaries=[0, 8],
           bucket_batch_sizes=[2, 2, 2],
           no_padding=no_padding))
   static_shapes = result.output_shapes
   self.assertEqual([None, None], static_shapes[0].as_list())
   self.assertEqual([None], static_shapes[1].as_list())
# Example no. 10
0
    def testPadToBoundary(self):
        """Checks pad_to_bucket_boundary batches and the over-long error.

        The generator yields one full batch for each of the first three
        buckets (shuffled), then `batch_sizes[-1]` elements longer than the
        last boundary. The three batches must be padded to `boundary - 1`,
        and the next read must fail with an op error mentioning
        "bucket_boundaries".
        """
        boundaries = [10, 20, 30]
        batch_sizes = [10, 8, 4, 2]
        lengths = [8, 13, 25]

        def element_gen():
            # Produce 1 batch for each bucket, in random order.
            elements = [[1] * length
                        for count, length in zip(batch_sizes[:-1], lengths)
                        for _ in range(count)]
            random.shuffle(elements)
            for el in elements:
                yield (el, )
            # Then elements longer than the largest boundary.
            oversized_len = boundaries[-1] + 5
            for _ in range(batch_sizes[-1]):
                yield ([1] * oversized_len, )

        def element_len(el):
            return array_ops.shape(el)[0]

        dataset = dataset_ops.Dataset.from_generator(
            element_gen, (dtypes.int64, ), ([None], ))
        dataset = dataset.apply(
            grouping.bucket_by_sequence_length(
                element_len,
                boundaries,
                batch_sizes,
                pad_to_bucket_boundary=True))
        get_next = self.getNext(dataset)

        batches = []
        for _ in range(3):
            (batch, ) = self.evaluate(get_next())
            batches.append(batch)
        with self.assertRaisesOpError("bucket_boundaries"):
            self.evaluate(get_next())

        observed_sizes = [b.shape[0] for b in batches]
        observed_lengths = [b.shape[1] for b in batches]
        # Only the first three buckets produce a batch before the error.
        expected_sizes = batch_sizes[:-1]
        self.assertEqual(sum(observed_sizes), sum(expected_sizes))
        self.assertEqual(sorted(expected_sizes), sorted(observed_sizes))
        # Every bucket is padded to its boundary minus one.
        self.assertEqual([bound - 1 for bound in sorted(boundaries)],
                         sorted(observed_lengths))
# Example no. 11
0
def bucket_by_sequence_length(element_length_func,
                              bucket_boundaries,
                              bucket_batch_sizes,
                              padded_shapes=None,
                              padding_values=None,
                              pad_to_bucket_boundary=False,
                              no_padding=False):
    """A transformation that buckets elements in a `Dataset` by length.

    Thin wrapper that forwards every argument unchanged to
    `grouping.bucket_by_sequence_length`.

    Elements of the `Dataset` are grouped together by length and then are
    padded and batched. This is useful for sequence tasks in which the
    elements have variable length: grouping together elements that have
    similar lengths reduces the total fraction of padding in a batch, which
    increases training step efficiency.

    Args:
      element_length_func: function from element in `Dataset` to `tf.int32`,
        determines the length of the element, which will determine the bucket
        it goes into.
      bucket_boundaries: `list<int>`, upper length boundaries of the buckets.
      bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be
        `len(bucket_boundaries) + 1`.
      padded_shapes: Nested structure of `tf.TensorShape` to pass to
        `tf.data.Dataset.padded_batch`. If not provided, will use
        `dataset.output_shapes`, which will result in variable length
        dimensions being padded out to the maximum length in each batch.
      padding_values: Values to pad with, passed to
        `tf.data.Dataset.padded_batch`. Defaults to padding with 0.
      pad_to_bucket_boundary: bool, if `False`, will pad dimensions with
        unknown size to maximum length in batch. If `True`, will pad
        dimensions with unknown size to bucket boundary minus 1 (i.e., the
        maximum length in each bucket), and caller must ensure that the source
        `Dataset` does not contain any elements with length longer than
        `max(bucket_boundaries)`.
      no_padding: `bool`, indicates whether to pad the batch features
        (features need to be either of type `tf.SparseTensor` or of same
        shape).

    Returns:
      A `Dataset` transformation function, which can be passed to
      `tf.data.Dataset.apply`.

    Raises:
      ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
        This wrapper performs no validation itself; any such error comes from
        `grouping.bucket_by_sequence_length`.
    """
    return grouping.bucket_by_sequence_length(
        element_length_func,
        bucket_boundaries,
        bucket_batch_sizes,
        padded_shapes,
        padding_values,
        pad_to_bucket_boundary=pad_to_bucket_boundary,
        no_padding=no_padding)
  def testPadToBoundary(self):
    """Checks pad_to_bucket_boundary batches and the over-long error.

    The generator yields one full batch for each of the first three buckets
    (shuffled), then `batch_sizes[-1]` elements longer than the last
    boundary. The three batches must be padded to `boundary - 1`, and the
    next read must fail with an op error mentioning "bucket_boundaries".
    """
    boundaries = [10, 20, 30]
    batch_sizes = [10, 8, 4, 2]
    lengths = [8, 13, 25]

    def element_gen():
      # Produce 1 batch for each bucket, in random order.
      elements = [[1] * length
                  for count, length in zip(batch_sizes[:-1], lengths)
                  for _ in range(count)]
      random.shuffle(elements)
      for el in elements:
        yield (el,)
      # Then elements longer than the largest boundary.
      oversized_len = boundaries[-1] + 5
      for _ in range(batch_sizes[-1]):
        yield ([1] * oversized_len,)

    def element_len(el):
      return array_ops.shape(el)[0]

    dataset = dataset_ops.Dataset.from_generator(
        element_gen, (dtypes.int64,), ([None],))
    dataset = dataset.apply(
        grouping.bucket_by_sequence_length(
            element_len, boundaries, batch_sizes,
            pad_to_bucket_boundary=True))
    get_next = self.getNext(dataset)

    batches = []
    for _ in range(3):
      (batch,) = self.evaluate(get_next())
      batches.append(batch)
    with self.assertRaisesOpError("bucket_boundaries"):
      self.evaluate(get_next())

    observed_sizes = [b.shape[0] for b in batches]
    observed_lengths = [b.shape[1] for b in batches]
    # Only the first three buckets produce a batch before the error.
    expected_sizes = batch_sizes[:-1]
    self.assertEqual(sum(observed_sizes), sum(expected_sizes))
    self.assertEqual(sorted(expected_sizes), sorted(observed_sizes))
    # Every bucket is padded to its boundary minus one.
    self.assertEqual([bound - 1 for bound in sorted(boundaries)],
                     sorted(observed_lengths))
    def testCardinality(self):
        """Bucketing a repeated generator dataset keeps it infinite."""
        boundaries = [3, 7, 11]
        batch_sizes = [2, 2, 2, 2]
        lengths = range(1, 11)

        def element_gen():
            for size in lengths:
                yield ([1] * size, )

        def element_len(element):
            return array_ops.shape(element)[0]

        repeated = dataset_ops.Dataset.from_generator(
            element_gen, (dtypes.int64, ), ([None], )).repeat()
        bucketed = repeated.apply(
            grouping.bucket_by_sequence_length(
                element_len,
                boundaries,
                batch_sizes,
                pad_to_bucket_boundary=True))
        cardinality = self.evaluate(bucketed.cardinality())
        self.assertEqual(cardinality, dataset_ops.INFINITE)
    def testRoundRobinBucketizing(self):
        """Round-robin consumers should receive batches from the same bucket.

        Tests a common use case for round robin reads: at each step, all
        consumers should get batches with the same bucket size. Results are
        read host by host (block_length=replicas_per_consumer_host) and, within
        a host, replica by replica, so each consecutive group of
        `num_consumers` results is one step.
        """
        cluster = self.create_cluster(num_workers=4)
        # Round robin reads can cause slow cluster shutdown.
        data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
        num_elements = 100
        low_bucket_max = 30
        mid_bucket_max = 60
        bucket_boundaries = [low_bucket_max, mid_bucket_max]
        batch_size = 10
        num_consumer_hosts = 3
        replicas_per_consumer_host = 5
        num_consumers = num_consumer_hosts * replicas_per_consumer_host
        bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
        # Set up the dataset that will run on the tf.data workers.
        ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
        ds = ds.shuffle(num_elements)
        ds = ds.repeat()
        ds = ds.apply(
            grouping.bucket_by_sequence_length(lambda x: x,
                                               bucket_boundaries,
                                               bucket_batch_sizes,
                                               drop_remainder=True))
        # Group batches into windows of `num_consumers`, keyed on x[1].
        ds = ds.apply(
            grouping.group_by_window(
                lambda x: math_ops.cast(x[1], dtypes.int64),
                lambda _, x: dataset_ops.Dataset.from_tensors(x),
                window_size=num_consumers))
        ds = ds.flat_map(lambda x: x)

        # Set up the per-consumer-host datasets. During each global step, we pull
        # `replicas_per_consumer_host` batches from each of these datasets.
        host_datasets = []
        for host_index in range(num_consumer_hosts):
            per_replica_datasets = []
            for i in range(replicas_per_consumer_host):
                consumer_index = host_index * replicas_per_consumer_host + i
                per_replica_datasets.append(
                    self.make_distributed_dataset(
                        ds,
                        cluster,
                        job_name="test",
                        consumer_index=consumer_index,
                        num_consumers=num_consumers))
            host_dataset = dataset_ops.Dataset.from_tensor_slices(
                per_replica_datasets)
            host_dataset = host_dataset.interleave(
                lambda x: x,
                cycle_length=len(per_replica_datasets),
                num_parallel_calls=len(per_replica_datasets),
                deterministic=True)
            host_datasets.append(host_dataset)

        # Use parallel interleave to read from host datasets in parallel.
        ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
        ds = ds.interleave(lambda x: x,
                           block_length=replicas_per_consumer_host,
                           cycle_length=len(host_datasets),
                           num_parallel_calls=len(host_datasets),
                           deterministic=True)

        num_rounds = 10
        get_next = self.getNext(ds, requires_initialization=True)
        results = []
        for _ in range(num_rounds * num_consumers):
            results.append(self.evaluate(get_next()))

        def get_bucket(elem):
            """Returns how many boundaries are <= elem (the bucket index)."""
            bucket_ind = 0
            while bucket_ind < len(bucket_boundaries
                                   ) and elem >= bucket_boundaries[bucket_ind]:
                bucket_ind += 1
            return bucket_ind

        # Check that the batches for each step contain elements from the same
        # bucket. `i` already advances in strides of `num_consumers`, so it is
        # the offset of a step's first batch; scaling it by `num_consumers`
        # again would check only the first of the `num_rounds` steps.
        for i in range(0, len(results), num_consumers):
            batches = results[i:i + num_consumers]
            bucket_inds = [get_bucket(batch[0]) for batch in batches]
            for bucket_ind in bucket_inds[1:]:
                self.assertEqual(bucket_inds[0], bucket_ind)
# Example no. 15
0
    def testBucketSparse(self, param_drop_remainder):  # pylint: disable=g-doc-args
        """Buckets sparse records (`no_padding=True`) and compares every batch.

        Record `i` holds `range(i + 1)` for `i` in [min_len, max_len), i.e.
        the input is [0], [0, 1], [0, 1, 2], ..., [0, ..., max_len - 1].
        Records are bucketed by length and batched with
        `batch_size` < `bucket_size`. When `param_drop_remainder` is set, only
        the batch starting at each bucket's first record is expected.
        """
        min_len = 0
        max_len = 100
        batch_size = 7
        bucket_size = 10

        def _build_dataset():
            """Sparse dataset whose i-th record is `range(i + 1)`."""
            records = [range(i + 1) for i in range(min_len, max_len)]

            def generator_fn():
                for rec in records:
                    yield _format_record(rec, sparse=True)

            return dataset_ops.Dataset.from_generator(
                generator=generator_fn,
                output_types=_get_record_type(sparse=True)).map(
                    _to_sparse_tensor)

        def _compute_expected_batches(drop_remainder):
            """Builds the set of (indices, values) tuples expected per batch."""
            expected = set()
            for bucket_start in range(min_len, max_len, bucket_size):
                # With drop_remainder only the batch at offset 0 remains;
                # otherwise there is one batch per `batch_size` stride.
                offsets = ([0] if drop_remainder else
                           range(0, bucket_size, batch_size))
                for offset in offsets:
                    first = bucket_start + offset
                    stop = min(first + batch_size, bucket_start + bucket_size)
                    indices = []
                    values = []
                    for length in range(first, stop):
                        row = length - first
                        for val in range(length + 1):
                            indices.append((row, val))
                            values.append(val)
                    expected.add((tuple(indices), tuple(values)))
            return expected

        def _compute_batches(dataset):
            """Drains `dataset`, collecting each batch as (indices, values)."""
            next_batch = self.getNext(dataset)
            collected = set()
            with self.assertRaises(errors.OutOfRangeError):
                while True:
                    output = self.evaluate(next_batch())
                    indices = tuple(tuple(idx) for idx in output.indices)
                    collected.add((indices, tuple(output.values)))
            return collected

        dataset = _build_dataset()
        boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
        dataset = dataset.apply(
            grouping.bucket_by_sequence_length(
                _element_length_fn,
                boundaries, [batch_size] * (len(boundaries) + 1),
                no_padding=True,
                drop_remainder=param_drop_remainder))
        actual = _compute_batches(dataset)
        expected = _compute_expected_batches(param_drop_remainder)
        self.assertEqual(actual, expected)
# Example no. 16
0
  def testBucketSparse(self):
    """Buckets sparse records (`no_padding=True`) and compares every batch.

    Record `i` holds `range(i + 1)` for `i` in [min_len, max_len), i.e. the
    input is [0], [0, 1], [0, 1, 2], ..., [0, ..., max_len - 1]. Records are
    bucketed by length and batched with `batch_size` < `bucket_size`, so each
    bucket yields one batch per `batch_size` stride.
    """
    min_len = 0
    max_len = 100
    batch_size = 7
    bucket_size = 10

    def _build_dataset():
      """Sparse dataset whose i-th record is `range(i + 1)`."""
      records = [range(i + 1) for i in range(min_len, max_len)]

      def generator_fn():
        for rec in records:
          yield _format_record(rec, sparse=True)

      return dataset_ops.Dataset.from_generator(
          generator=generator_fn,
          output_types=_get_record_type(sparse=True)).map(_to_sparse_tensor)

    def _compute_expected_batches():
      """Builds the set of (indices, values) tuples expected per batch."""
      expected = set()
      for bucket_start in range(min_len, max_len, bucket_size):
        for offset in range(0, bucket_size, batch_size):
          first = bucket_start + offset
          stop = min(first + batch_size, bucket_start + bucket_size)
          indices, values = [], []
          for length in range(first, stop):
            for val in range(length + 1):
              indices.append((length - first, val))
              values.append(val)
          expected.add((tuple(indices), tuple(values)))
      return expected

    def _compute_batches(dataset):
      """Drains `dataset` in a session, collecting (indices, values) pairs."""
      next_batch = dataset.make_one_shot_iterator().get_next()
      collected = set()
      with self.cached_session() as sess:
        with self.assertRaises(errors.OutOfRangeError):
          while True:
            output = sess.run(next_batch)
            indices = tuple(tuple(idx) for idx in output.indices)
            collected.add((indices, tuple(output.values)))
      return collected

    dataset = _build_dataset()
    boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
    dataset = dataset.apply(grouping.bucket_by_sequence_length(
        _element_length_fn,
        boundaries,
        [batch_size] * (len(boundaries) + 1),
        no_padding=True))
    actual = _compute_batches(dataset)
    expected = _compute_expected_batches()
    self.assertEqual(actual, expected)
    def _test_bucket_by_padding(no_padding):
      """Buckets with `drop_remainder=True` and checks the emitted batches.

      Relies on closure variables from the enclosing test (not visible here):
      `boundaries`, `batch_sizes`, `lengths`, `n_expected_batches` and the
      expected tables `expected_sums`, `expected_batch_sizes` and
      `expected_lengths`.

      Args:
        no_padding: when True the dataset holds sparse records and batches are
          inspected via `dense_shape`/`values`; otherwise via `shape`/`sum()`.
      """
      dataset = build_dataset(sparse=no_padding)
      dataset = dataset.apply(
          grouping.bucket_by_sequence_length(
              _element_length_fn,
              boundaries,
              batch_sizes,
              no_padding=no_padding,
              drop_remainder=True))

      get_next = self.getNext(dataset)
      batches = []
      for _ in range(n_expected_batches):
        batch, = self.evaluate(get_next())
        batches.append(batch)

      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

      generated_lengths = []

      # <seq-length>: <total-sum>, starting at 0 for every expected length.
      # Keyed on every entry of `lengths`, so the per-length checks below never
      # miss a key because of a shorter sibling list.
      generated_sums = {length: 0 for length in lengths}

      # <seq-length>: [<batch_size>, ...]
      generated_batch_sizes = {length: [] for length in lengths}

      for batch in batches:
        shape = batch.dense_shape if no_padding else batch.shape
        length = shape[1]
        generated_lengths.append(length)

        batch_size = shape[0]
        # An unexpected length is recorded rather than raising KeyError; the
        # final sequence-length assertion then reports it descriptively.
        generated_batch_sizes.setdefault(length, []).append(batch_size)

        batch_sum = batch.values.sum() if no_padding else batch.sum()
        generated_sums[length] = generated_sums.get(length, 0) + batch_sum

      for seq_len in lengths:
        # Make sure the sum of the batch contents is correct for the
        # individual sequence lengths.
        self.assertEqual(
            generated_sums[seq_len], expected_sums[seq_len],
            "Tensor sums did not match! "
            "expected: {}, generated: {}".format(expected_sums, generated_sums))

        # Make sure the individual batch sizes are generated as expected.
        self.assertEqual(
            sorted(generated_batch_sizes[seq_len]),
            sorted(expected_batch_sizes[seq_len]),
            "Batch-sizes did not match! "
            "expected: {}, generated: {}".format(
                sorted(expected_batch_sizes[seq_len]),
                sorted(generated_batch_sizes[seq_len])))

      # Make sure the generated sequence lengths appear as often as expected.
      self.assertEqual(
          sorted(generated_lengths), sorted(expected_lengths),
          "The generated sequence lengths did not match! "
          "expected: {}, generated: {}".format(
              sorted(expected_lengths), sorted(generated_lengths)))
# Example no. 18
0
        def _test_bucket_by_padding(no_padding):
            """Buckets with `drop_remainder=True` and checks the emitted batches.

            Relies on closure variables from the enclosing test (not visible
            here): `boundaries`, `batch_sizes`, `lengths`,
            `n_expected_batches` and the expected tables `expected_sums`,
            `expected_batch_sizes` and `expected_lengths`.

            Args:
                no_padding: when True the dataset holds sparse records and
                    batches are inspected via `dense_shape`/`values`;
                    otherwise via `shape`/`sum()`.
            """
            dataset = build_dataset(sparse=no_padding)
            dataset = dataset.apply(
                grouping.bucket_by_sequence_length(_element_length_fn,
                                                   boundaries,
                                                   batch_sizes,
                                                   no_padding=no_padding,
                                                   drop_remainder=True))

            get_next = self.getNext(dataset)
            batches = []
            for _ in range(n_expected_batches):
                batch, = self.evaluate(get_next())
                batches.append(batch)

            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(get_next())

            generated_lengths = []

            # <seq-length>: <total-sum>, starting at 0 for every expected
            # length. Keyed on every entry of `lengths`, so the per-length
            # checks below never miss a key because of a shorter sibling list.
            generated_sums = {length: 0 for length in lengths}

            # <seq-length>: [<batch_size>, ...]
            generated_batch_sizes = {length: [] for length in lengths}

            for batch in batches:
                shape = batch.dense_shape if no_padding else batch.shape
                length = shape[1]
                generated_lengths.append(length)

                batch_size = shape[0]
                # An unexpected length is recorded rather than raising
                # KeyError; the final sequence-length assertion then reports
                # it descriptively.
                generated_batch_sizes.setdefault(length, []).append(batch_size)

                batch_sum = batch.values.sum() if no_padding else batch.sum()
                generated_sums[length] = generated_sums.get(length,
                                                            0) + batch_sum

            for seq_len in lengths:
                # Make sure the sum of the batch contents is correct for the
                # individual sequence lengths.
                self.assertEqual(
                    generated_sums[seq_len], expected_sums[seq_len],
                    "Tensor sums did not match! "
                    "expected: {}, generated: {}".format(
                        expected_sums, generated_sums))

                # Make sure the individual batch sizes are generated as expected.
                self.assertEqual(
                    sorted(generated_batch_sizes[seq_len]),
                    sorted(expected_batch_sizes[seq_len]),
                    "Batch-sizes did not match! "
                    "expected: {}, generated: {}".format(
                        sorted(expected_batch_sizes[seq_len]),
                        sorted(generated_batch_sizes[seq_len])))

            # Make sure the generated sequence lengths appear as often as expected.
            self.assertEqual(
                sorted(generated_lengths), sorted(expected_lengths),
                "The generated sequence lengths did not match! "
                "expected: {}, generated: {}".format(
                    sorted(expected_lengths), sorted(generated_lengths)))