# Example No. 1
def _meta_path(save_dir):
    """Return the path of the meta file for this split/task shard.

    The filename is derived from the runtime FLAGS (split, task index,
    total task count, casing) via `data_utils.format_filename`.
    """
    meta_name = data_utils.format_filename(
        prefix="meta.{}.doc".format(FLAGS.split),
        suffix="json-{:05d}-of-{:05d}".format(FLAGS.task, FLAGS.num_task),
        uncased=FLAGS.uncased,
    )
    return os.path.join(save_dir, meta_name)
def _tfrecord_path(save_dir):
  """Return the path of the tfrecord file for this split/task shard.

  Mirrors `_meta_path` but with the data/tfrecord naming scheme; all
  varying pieces come from the runtime FLAGS.
  """
  record_name = data_utils.format_filename(
      prefix="data.{}.sent".format(FLAGS.split),
      suffix="tfrecord-{:05d}-of-{:05d}".format(FLAGS.task, FLAGS.num_task),
      uncased=FLAGS.uncased,
  )
  return os.path.join(save_dir, record_name)
# Example No. 3
def get_input_fn(tfrecord_dir,
                 split,
                 max_length,
                 num_hosts=1,
                 uncased=False,
                 use_bfloat16=False,
                 **kwargs):
    """Create Estimator input function.

    Scans every comma-separated directory in `tfrecord_dir` for per-task
    meta json files, merges their example counts and tfrecord filenames
    into a single record_info dict, and returns an `input_fn` closure
    that builds the dataset from the merged file list.
    """

    # Glob pattern that matches every per-task meta json for this split.
    meta_glob = format_filename(prefix="meta.{}".format(split),
                                suffix="json*",
                                uncased=uncased)

    total_examples = 0
    all_filenames = []

    dirs = tfrecord_dir.split(",")
    tf.logging.info("Use the following tfrecord dirs: %s", dirs)

    for dir_idx, cur_dir in enumerate(dirs):
        glob_pattern = os.path.join(cur_dir, meta_glob)
        tf.logging.info("[%d] Record glob: %s", dir_idx, glob_pattern)

        meta_paths = sorted(tf.gfile.Glob(glob_pattern))
        tf.logging.info("[%d] Num of record info path: %d", dir_idx,
                        len(meta_paths))

        # Accumulate this directory's example count and filenames.
        dir_examples = 0
        dir_filenames = []
        for meta_path in meta_paths:
            with tf.io.gfile.GFile(meta_path, "r") as fp:
                meta = json.load(fp)
            dir_examples += meta["num_example"]
            dir_filenames.extend(meta["filenames"])

        # Re-root each recorded filename under the directory it was
        # globbed from, since the meta files may carry stale paths.
        dir_filenames = [
            os.path.join(cur_dir, os.path.basename(fname))
            for fname in dir_filenames
        ]

        tf.logging.info("[Dir %d] Number of chosen batches: %s", dir_idx,
                        dir_examples)
        tf.logging.info("[Dir %d] Number of chosen files: %s", dir_idx,
                        len(dir_filenames))
        tf.logging.debug(dir_filenames)

        # Fold this directory's totals into the global record info.
        total_examples += dir_examples
        all_filenames.extend(dir_filenames)

    tf.logging.info("Total number of batches: %d", total_examples)
    tf.logging.info("Total number of files: %d", len(all_filenames))
    tf.logging.debug(all_filenames)

    record_info = {"num_example": total_examples, "filenames": all_filenames}

    dataset_kwargs = dict(data_files=all_filenames,
                          num_hosts=num_hosts,
                          is_training=split == "train",
                          max_length=max_length,
                          use_bfloat16=use_bfloat16,
                          **kwargs)

    def input_fn(params):
        """Input function wrapper."""
        return get_dataset(params=params, **dataset_kwargs)

    return input_fn, record_info