def load(tfrecord_file, meta_data_file, model_spec):
  """Loads data from a TFRecord file and its metadata file.

  Args:
    tfrecord_file: Path to the TFRecord file of preprocessed examples.
    meta_data_file: Path to the JSON metadata file.
    model_spec: Spec object providing `get_name_to_features()` (the feature
      schema used to parse records) and `select_data_from_record` (maps a
      parsed record to model inputs).

  Returns:
    A `(dataset, meta_data)` tuple: the parsed `tf.data.Dataset` and the
    metadata dict loaded from `meta_data_file`.
  """
  dataset = input_pipeline.single_file_dataset(
      tfrecord_file, model_spec.get_name_to_features())
  # Use tf.data.AUTOTUNE (tf.data.experimental.AUTOTUNE is deprecated and the
  # sibling _load() in this file already uses the non-experimental name).
  dataset = dataset.map(model_spec.select_data_from_record,
                        num_parallel_calls=tf.data.AUTOTUNE)
  # Read via tf.io.gfile so remote filesystems (e.g. GCS paths) work too.
  with tf.io.gfile.GFile(meta_data_file, 'rb') as reader:
    meta_data = json.load(reader)
  return dataset, meta_data
def create_tfrecord_dataset_pipeline(input_file, max_seq_length, batch_size,
                                     input_pipeline_context=None):
  """Creates a batched dataset pipeline from a single TFRecord file.

  Args:
    input_file: Path to the TFRecord file to read.
    max_seq_length: Length of each of the fixed-size int64 feature vectors.
    batch_size: Number of examples per batch (last batch may be smaller).
    input_pipeline_context: Optional context describing multi-host input
      pipelines; when present with more than one pipeline, the dataset is
      sharded across hosts.

  Returns:
    A batched, prefetching `tf.data.Dataset`.
  """
  # All three BERT-style inputs share the same fixed-length int64 schema.
  feature_keys = ('input_word_ids', 'input_mask', 'input_type_ids')
  name_to_features = {
      key: tf.io.FixedLenFeature([max_seq_length], tf.int64)
      for key in feature_keys
  }
  dataset = single_file_dataset(input_file, name_to_features)

  # Shard the dataset between hosts when several input pipelines are active.
  ctx = input_pipeline_context
  if ctx and ctx.num_input_pipelines > 1:
    dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)

  dataset = dataset.batch(batch_size, drop_remainder=False)
  return dataset.prefetch(1024)
def _load(tfrecord_file, meta_data_file, model_spec, is_training=None):
  """Loads a preprocessed dataset and its metadata.

  Args:
    tfrecord_file: Path to the TFRecord file of preprocessed examples.
    meta_data_file: Path to the JSON metadata file; expected to contain a
      'size' entry used for logging.
    model_spec: Spec object providing `get_name_to_features` and
      `select_data_from_record`.
    is_training: Optional bool forwarded to `get_name_to_features`; when
      None, the schema is requested without the argument.

  Returns:
    A `(dataset, meta_data)` tuple: the parsed `tf.data.Dataset` and the
    metadata dict loaded from `meta_data_file`.
  """
  # Only forward `is_training` when the caller supplied it, so specs whose
  # `get_name_to_features` takes no such argument keep working.
  kwargs = {} if is_training is None else {'is_training': is_training}
  name_to_features = model_spec.get_name_to_features(**kwargs)

  dataset = input_pipeline.single_file_dataset(tfrecord_file, name_to_features)
  dataset = dataset.map(
      model_spec.select_data_from_record,
      num_parallel_calls=tf.data.AUTOTUNE)

  meta_data = file_util.load_json_file(meta_data_file)
  logging.info(
      'Load preprocessed data and metadata from %s and %s '
      'with size: %d', tfrecord_file, meta_data_file, meta_data['size'])
  return dataset, meta_data