# Example no. 1
# 0
def build_padding_config(log2_num_nodes, log2_num_input_tagged_nodes,
                         log2_max_initial_transitions,
                         log2_max_in_tagged_transitions, log2_max_edges,
                         log2_max_tokens):
  """Builds a padding config with power-of-2 sizes.

  Every argument is a base-2 exponent. The matching size in the returned
  config is `2**argument`. Exponents are expected to be non-negative ints,
  because a negative exponent would make `2**x` a float.

  Args:
    log2_num_nodes: Exponent for `num_nodes` of the static max metadata.
    log2_num_input_tagged_nodes: Exponent for `num_input_tagged_nodes` of the
      static max metadata.
    log2_max_initial_transitions: Exponent for `max_initial_transitions`.
    log2_max_in_tagged_transitions: Exponent for `max_in_tagged_transitions`.
    log2_max_edges: Exponent for `max_edges`.
    log2_max_tokens: Exponent for `max_tokens`.

  Returns:
    An `example_definition.GraphBundleWithTokensPaddingConfig` that wraps a
    `graph_bundle.PaddingConfig` (node/transition/edge sizes) together with
    the token padding size.
  """
  return example_definition.GraphBundleWithTokensPaddingConfig(
      bundle_padding=graph_bundle.PaddingConfig(
          static_max_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=2**log2_num_nodes,
              num_input_tagged_nodes=2**log2_num_input_tagged_nodes),
          max_initial_transitions=2**log2_max_initial_transitions,
          max_in_tagged_transitions=2**log2_max_in_tagged_transitions,
          max_edges=2**log2_max_edges),
      max_tokens=2**log2_max_tokens)


# Type alias: a list of (padding config, batch size) pairs.
PaddingAndBatchSizes = List[
    Tuple[example_definition.GraphBundleWithTokensPaddingConfig, int]]

# Smallest padding config: every size is 1. Since 1 == 2**0, it is built
# from all-zero exponents.
TINY_PADDING_CONFIG = build_padding_config(
    log2_num_nodes=0,
    log2_num_input_tagged_nodes=0,
    log2_max_initial_transitions=0,
    log2_max_in_tagged_transitions=0,
    log2_max_edges=0,
    log2_max_tokens=0)


def pad_and_batch_with_rng(
    it, num_devices,
    padding_and_batch_sizes,
    base_rng):
  """Pad and batch according to a collection of sizes.

  Args:
    it: Iterable over individual examples.
    num_devices: Number of devices; determines constant leading batch dimension.
    padding_and_batch_sizes: List of pairs of padding config and per-device