  def testBatchingSchemeMaxLength(self):
    scheme = data_reader.batching_scheme(
        batch_size=20,
        max_length=None,
        min_length_bucket=8,
        length_bucket_step=1.1,
        drop_long_sequences=False)
    self.assertGreater(scheme["max_length"], 10000)

    scheme = data_reader.batching_scheme(
        batch_size=20,
        max_length=None,
        min_length_bucket=8,
        length_bucket_step=1.1,
        drop_long_sequences=True)
    self.assertEqual(scheme["max_length"], 20)

    scheme = data_reader.batching_scheme(
        batch_size=20,
        max_length=15,
        min_length_bucket=8,
        length_bucket_step=1.1,
        drop_long_sequences=True)
    self.assertEqual(scheme["max_length"], 15)

    scheme = data_reader.batching_scheme(
        batch_size=20,
        max_length=15,
        min_length_bucket=8,
        length_bucket_step=1.1,
        drop_long_sequences=False)
    self.assertGreater(scheme["max_length"], 10000)

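The four assertions above pin down how batching_scheme resolves max_length: a None (or 0) value falls back to batch_size, and unless drop_long_sequences is set the effective limit becomes a very large sentinel. The sketch below is not the library source; resolve_max_length and the 10**9 sentinel are assumptions inferred from the assertions (the test only checks that the value exceeds 10000).

def resolve_max_length(batch_size, max_length, drop_long_sequences):
  # None/0 falls back to batch_size; without dropping, use a huge sentinel
  # so that no sequence length is ever excluded.
  max_length = max_length or batch_size
  return max_length if drop_long_sequences else 10**9

assert resolve_max_length(20, None, False) > 10000
assert resolve_max_length(20, None, True) == 20
assert resolve_max_length(20, 15, True) == 15
assert resolve_max_length(20, 15, False) > 10000
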
  def testBatchingSchemeBuckets(self):
    scheme = data_reader.batching_scheme(
        batch_size=128,
        max_length=0,
        min_length_bucket=8,
        length_bucket_step=1.1)
    boundaries, batch_sizes = scheme["boundaries"], scheme["batch_sizes"]
    self.assertEqual(len(boundaries), len(batch_sizes) - 1)
    expected_boundaries = [
        8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 26, 28, 30,
        33, 36, 39, 42, 46, 50, 55, 60, 66, 72, 79, 86, 94, 103, 113, 124
    ]
    self.assertEqual(expected_boundaries, boundaries)
    expected_batch_sizes = [
        16, 12, 12, 8, 8, 8, 8, 8, 8, 6, 6, 6, 6, 4, 4, 4, 4, 4, 3, 3, 3, 3, 2,
        2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1
    ]
    self.assertEqual(expected_batch_sizes, batch_sizes)

    scheme = data_reader.batching_scheme(
        batch_size=128,
        max_length=0,
        min_length_bucket=8,
        length_bucket_step=1.1,
        shard_multiplier=2)
    boundaries, batch_sizes = scheme["boundaries"], scheme["batch_sizes"]
    self.assertAllEqual([bs * 2 for bs in expected_batch_sizes], batch_sizes)
    self.assertEqual(expected_boundaries, boundaries)

    scheme = data_reader.batching_scheme(
        batch_size=128,
        max_length=0,
        min_length_bucket=8,
        length_bucket_step=1.1,
        length_multiplier=2)
    boundaries, batch_sizes = scheme["boundaries"], scheme["batch_sizes"]
    self.assertAllEqual([b * 2 for b in expected_boundaries], boundaries)
    self.assertEqual([max(1, bs // 2)
                      for bs in expected_batch_sizes], batch_sizes)
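
The expected_boundaries above follow a simple multiplicative pattern: start at min_length_bucket and grow by length_bucket_step per bucket, always advancing by at least 1, stopping at the effective maximum length. The sketch below reproduces the list, assuming max_length=0 falls back to batch_size (128); it is an illustration of the pattern, not the library's source.

def bucket_boundaries(min_length, max_length, step):
  x, boundaries = min_length, []
  while x < max_length:
    boundaries.append(x)
    x = max(x + 1, int(x * step))  # grow ~10% per bucket, but at least by 1
  return boundaries

assert bucket_boundaries(8, 128, 1.1) == [
    8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 26, 28, 30,
    33, 36, 39, 42, 46, 50, 55, 60, 66, 72, 79, 86, 94, 103, 113, 124]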
Example #3
# Assumed imports for this snippet: `tf` is the TF1-style API (tf.logging) and
# `data_reader` refers to tensor2tensor's batching utilities.
import tensorflow as tf
from tensor2tensor.utils import data_reader


def batch_fn(dataset,
             training,
             shapes,
             target_names,
             batch_size=32,
             eval_batch_size=32,
             bucket_batch_length=32,
             bucket_max_length=256,
             bucket_min_length=8,
             bucket_length_step=1.1,
             buckets=None):
    """Batching function."""
    del target_names
    # If bucketing is not specified, check if target shapes are variable.
    cur_batch_size = batch_size if training else eval_batch_size
    if buckets is None:
        variable_target_shapes = False
        target_shape = shapes[1]
        for dim in target_shape:
            if dim is None:
                variable_target_shapes = True
        tf.logging.info(
            "Heuristically setting bucketing to %s based on shapes "
            "of target tensors." % variable_target_shapes)
        if variable_target_shapes:
            batch_size_per_token = cur_batch_size * bucket_batch_length
            scheme = data_reader.batching_scheme(batch_size_per_token,
                                                 bucket_max_length,
                                                 bucket_min_length,
                                                 bucket_length_step,
                                                 drop_long_sequences=training)
            buckets = (scheme["boundaries"], scheme["batch_sizes"])

    if buckets:
        tf.logging.info("Bucketing with buckets %s." % str(buckets))

        def example_length(_, target):
            return tf.shape(target)[0]

        boundaries, batch_sizes = buckets
        dataset = dataset.apply(
            tf.data.experimental.bucket_by_sequence_length(
                example_length, boundaries, batch_sizes))
    else:
        dataset = dataset.padded_batch(cur_batch_size, shapes)
    return dataset
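
A hypothetical usage sketch for batch_fn on a toy dataset of variable-length (input, target) pairs; the generator, shapes, and parameter values below are made up for illustration, and the TF1-style APIs used above (tf.logging, tf.data.experimental) are assumed to be available.

def toy_examples():
    # Variable-length integer sequences; the target length drives bucketing.
    for n in range(2, 50):
        yield list(range(n)), list(range(n))

dataset = tf.data.Dataset.from_generator(
    toy_examples,
    output_types=(tf.int64, tf.int64),
    output_shapes=([None], [None]))

batched = batch_fn(dataset,
                   training=True,
                   shapes=([None], [None]),
                   target_names=None,  # unused by batch_fn
                   bucket_batch_length=16,
                   bucket_max_length=64)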