def load(tfrecord_file, meta_data_file, model_spec):
    """Loads data from a TFRecord file and its metadata file.

    Args:
        tfrecord_file: Path to the TFRecord file containing preprocessed
            examples.
        meta_data_file: Path to the JSON metadata file describing the dataset.
        model_spec: Model specification object providing
            `get_name_to_features()` (the TFRecord parsing schema) and
            `select_data_from_record` (record-to-features mapping).

    Returns:
        A `(dataset, meta_data)` tuple: the parsed `tf.data.Dataset` and the
        decoded metadata dict.
    """
    dataset = input_pipeline.single_file_dataset(
        tfrecord_file, model_spec.get_name_to_features())
    # tf.data.AUTOTUNE is the non-experimental alias already used elsewhere
    # in this file; prefer it over tf.data.experimental.AUTOTUNE.
    dataset = dataset.map(model_spec.select_data_from_record,
                          num_parallel_calls=tf.data.AUTOTUNE)

    # GFile supports both local paths and remote filesystems (e.g. GCS).
    with tf.io.gfile.GFile(meta_data_file, 'rb') as reader:
        meta_data = json.load(reader)
    return dataset, meta_data
# Example #2
def create_tfrecord_dataset_pipeline(input_file, max_seq_length, batch_size, input_pipeline_context=None):
    """Builds a batched dataset of BERT-style features from a TFRecord file.

    Args:
        input_file: Path to the TFRecord file to read.
        max_seq_length: Fixed length of each int64 feature vector.
        batch_size: Number of examples per batch (last batch may be smaller).
        input_pipeline_context: Optional input context; when it reports more
            than one input pipeline, the dataset is sharded so each host reads
            a disjoint slice of the data.

    Returns:
        A `tf.data.Dataset` yielding batches of parsed features.
    """
    # All three features share the same fixed-length int64 schema.
    seq_feature = tf.io.FixedLenFeature([max_seq_length], tf.int64)
    name_to_features = {
        key: seq_feature
        for key in ('input_word_ids', 'input_mask', 'input_type_ids')
    }
    dataset = single_file_dataset(input_file, name_to_features)

    # Shard between hosts so each input pipeline reads a distinct subset.
    ctx = input_pipeline_context
    if ctx and ctx.num_input_pipelines > 1:
        dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)

    return dataset.batch(batch_size, drop_remainder=False).prefetch(1024)
# Example #3
def _load(tfrecord_file, meta_data_file, model_spec, is_training=None):
  """Loads a preprocessed dataset and its metadata.

  Args:
    tfrecord_file: Path to the TFRecord file with preprocessed examples.
    meta_data_file: Path to the JSON metadata file (expected to hold 'size').
    model_spec: Spec exposing `get_name_to_features` and
      `select_data_from_record`.
    is_training: Optional bool forwarded to `get_name_to_features`; the
      argument is omitted from the call when None, for specs whose getter
      takes no `is_training` parameter.

  Returns:
    A `(dataset, meta_data)` tuple.
  """
  if is_training is not None:
    name_to_features = model_spec.get_name_to_features(is_training=is_training)
  else:
    name_to_features = model_spec.get_name_to_features()

  records = input_pipeline.single_file_dataset(tfrecord_file, name_to_features)
  records = records.map(
      model_spec.select_data_from_record, num_parallel_calls=tf.data.AUTOTUNE)

  meta_data = file_util.load_json_file(meta_data_file)
  logging.info(
      'Load preprocessed data and metadata from %s and %s '
      'with size: %d', tfrecord_file, meta_data_file, meta_data['size'])
  return records, meta_data