def test_shard_input_file_with_more_shards(self):
    input_file_pattern = os.path.join(self._base_dir, "*.tfrecord")
    shard_files, indicator = shard_input_files(input_file_pattern, 20, 1)
    expected_files = [os.path.join(self._base_dir, '1.tfrecord')]
    self.assertAllEqual(shard_files, expected_files)
    self.assertTrue(indicator)

    shard_files, indicator = shard_input_files(input_file_pattern, 20, 19)
    self.assertEqual(len(shard_files), 0)
    self.assertTrue(indicator)
def predict(self,
            output_dir,
            input_data_path,
            metadata_file,
            checkpoint_path,
            execution_context,
            schema_params):
    n_records = 0
    n_batch = 0
    # Predict on the dataset
    sharded_dataset_paths, file_level_sharding = shard_input_files(
        input_data_path,
        execution_context[constants.NUM_SHARDS],
        execution_context[constants.SHARD_INDEX])
    if file_level_sharding and len(sharded_dataset_paths) == 0:
        logger.info("No input dataset is found, returning...")
        return

    inference_dataset = lambda: input_fn(  # noqa: E731
        input_pattern=','.join(sharded_dataset_paths),
        # DeText uses metadata_path
        metadata_path=self.model_params.metadata_path,
        batch_size=self.model_params.test_batch_size,
        mode=tf.estimator.ModeKeys.EVAL,
        vocab_table=vocab_utils.read_tf_vocab(
            self.model_params.vocab_file, self.model_params.UNK),
        vocab_table_for_id_ftr=vocab_utils.read_tf_vocab(
            self.model_params.vocab_file_for_id_ftr,
            self.model_params.UNK_FOR_ID_FTR),
        feature_names=self.model_params.feature_names,
        CLS=self.model_params.CLS,
        SEP=self.model_params.SEP,
        PAD=self.model_params.PAD,
        PAD_FOR_ID_FTR=self.model_params.PAD_FOR_ID_FTR,
        max_len=self.model_params.max_len,
        min_len=self.model_params.min_len,
        cnn_filter_window_size=max(self.model_params.filter_window_sizes)
        if self.model_params.ftr_ext == 'cnn' else 0)

    self.estimator_based_model = detext_train.get_estimator(
        self.model_params,
        strategy=None,  # local mode
        best_checkpoint=self.best_checkpoint)
    output = self.estimator_based_model.predict(
        inference_dataset, yield_single_examples=False)

    detext_writer = DetextWriter(schema_params=schema_params)
    shard_index = execution_context[constants.SHARD_INDEX]
    output_file = os.path.join(output_dir,
                               "part-{0:05d}.avro".format(shard_index))
    for batch_score in output:
        if n_batch == 0:
            # First batch: create the shard file from scratch.
            with tf.io.gfile.GFile(output_file, 'wb') as f:
                f.seekable = lambda: False
                n_records, n_batch = detext_writer.save_batch(
                    f, batch_score, output_file, n_records, n_batch)
        else:
            # Subsequent batches: append to the existing shard file.
            with tf.io.gfile.GFile(output_file, 'ab+') as f:
                f.seek(0, 2)
                f.seekable = lambda: True
                f.readable = lambda: True
                n_records, n_batch = detext_writer.save_batch(
                    f, batch_score, output_file, n_records, n_batch)
    logger.info("{} batches, i.e. {} records inferred".format(
        n_batch, n_records))
def _get_assigned_files(input_data_path, num_shards, shard_index):
    """
    Get the files assigned to this shard.

    :param input_data_path: input directory or file pattern.
    :param num_shards: total number of shards.
    :param shard_index: the index of this worker.
    :return: a list of assigned file names.
    """
    assigned_files, sample_level_shard = shard_input_files(
        input_data_path, num_shards, shard_index)
    assert not sample_level_shard, \
        "Sample-level sharding is not supported; " \
        "the number of files must be >= the number of workers"
    return assigned_files
def test_shard_input_files_with_wrong_params(self):
    with self.assertRaises(AssertionError):
        shard_input_files(self._base_dir, 1, 2)
    with self.assertRaises(AssertionError):
        shard_input_files(self._base_dir, -1, -2)
    with self.assertRaises(tf.errors.NotFoundError):
        shard_input_files(os.path.join(self._base_dir, "nowhere/nofile"), 3, 2)
def predict(self,
            output_dir,
            input_data_path,
            metadata_file,
            checkpoint_path,
            execution_context,
            schema_params):
    n_records = 0
    n_batch = 0
    # Predict on the dataset
    sharded_dataset_paths, file_level_sharding = shard_input_files(
        input_data_path,
        execution_context[constants.NUM_SHARDS],
        execution_context[constants.SHARD_INDEX])
    if file_level_sharding and len(sharded_dataset_paths) == 0:
        logger.info("No input dataset is found, returning...")
        return

    inference_dataset = input_fn_tfrecord(
        input_pattern=','.join(sharded_dataset_paths),
        batch_size=self.model_params.test_batch_size,
        mode=tf.estimator.ModeKeys.EVAL,
        feature_type2name=self.model_params.feature_type2name,
        feature_name2num=self.model_params.feature_name2num,
        task_type=self.model_params.task_type)

    self.model = train_model_helper.load_model_with_ckpt(
        parsing_utils.HParams(**asdict(self.model_params)),
        self.best_checkpoint)
    output = train_flow_helper.predict_with_additional_info(
        inference_dataset,
        self.model,
        self.model_params.feature_type2name)

    detext_writer = DetextWriter(schema_params=schema_params)
    shard_index = execution_context[constants.SHARD_INDEX]
    output_file = os.path.join(output_dir, f"part-{shard_index:05d}.avro")
    for batch_score in output:
        if n_batch == 0:
            # First batch: create the shard file from scratch.
            with tf.io.gfile.GFile(output_file, 'wb') as f:
                f.seekable = lambda: False
                n_records, n_batch = detext_writer.save_batch(
                    f, batch_score, output_file, n_records, n_batch)
        else:
            # Subsequent batches: append to the existing shard file.
            with tf.io.gfile.GFile(output_file, 'ab+') as f:
                f.seek(0, 2)
                f.seekable = lambda: True
                f.readable = lambda: True
                n_records, n_batch = detext_writer.save_batch(
                    f, batch_score, output_file, n_records, n_batch)
    logger.info(f"{n_batch} batches, i.e. {n_records} records inferred")
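# A hedged, illustrative sketch only: reading the per-shard inference output
# back for inspection. It assumes the part-*.avro files written above are
# standard Avro container files and that fastavro is available; neither
# assumption comes from the original code, and the local `open` call would
# need tf.io.gfile for remote storage.
import glob

import fastavro


def read_inference_output(output_dir):
    """Yield prediction records from every part-*.avro file in output_dir."""
    for path in sorted(glob.glob(os.path.join(output_dir, "part-*.avro"))):
        with open(path, "rb") as fo:
            for record in fastavro.reader(fo):
                yield record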
def test_shard_input_file_with_filename_pattern(self):
    input_file_pattern = os.path.join(self._base_dir, "*.tfrecord")
    shard_files, indicator = shard_input_files(input_file_pattern, 3, 1)
    expected_files = [os.path.join(self._base_dir, f'{i}.tfrecord')
                      for i in range(1, 10, 3)]
    self.assertAllEqual(shard_files, expected_files)
    self.assertFalse(indicator)
def test_shard_input_files_with_directory(self):
    shard_files, _ = shard_input_files(self._base_dir, 2, 0)
    expected_files = [os.path.join(self._base_dir, f'{i}.avro')
                      for i in range(10)]
    self.assertAllEqual(shard_files, expected_files)
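# A minimal sketch (not the project's implementation) of the assignment
# behavior the tests above imply: matching files are sorted and dealt out
# round-robin as files[shard_index::num_shards], and the second return value
# is True when there are more shards than files, i.e. callers must fall back
# to sample-level sharding. Treat this as an illustration under those
# assumptions, not as the actual shard_input_files.
def _shard_input_files_sketch(input_pattern, num_shards, shard_index):
    assert num_shards > 0 and 0 <= shard_index < num_shards
    files = sorted(tf.io.gfile.glob(input_pattern))
    if not files:
        raise tf.errors.NotFoundError(
            None, None, f"No input files match {input_pattern}")
    sample_level = num_shards > len(files)
    return files[shard_index::num_shards], sample_level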
def per_record_input_fn(input_path, metadata_file, num_shards, shard_index,
                        batch_size, data_format, custom_input_fn=None):
    """
    Input function for a per-record dataset. In this dataset, the records are
    individual examples; the batch size adds an extra dimension on top of that.

    :param input_path: input directory or file pattern.
    :param metadata_file: tensor metadata file.
    :param num_shards: number of shards.
    :param shard_index: the index of this worker.
    :param batch_size: batch size.
    :param data_format: tfrecord or avro.
    :param custom_input_fn: full name "package.module.fn" of an external custom input_fn.
    :return: a batched dataset.
    """
    if data_format == constants.TFRECORD:
        logger.info("using {} dataset".format(constants.TFRECORD))

        def build_features(tensors):
            """
            Create features from metadata, used to deserialize the tfrecord.

            :param tensors: list of metadata for all tensors.
            :return: tfrecord features.
            """
            tf_features = {}
            for feature in tensors:
                if feature.isSparse:
                    # If this is a sparse tensor, we process indices and values separately.
                    # Note in the metadata, we don't see _indices and _values,
                    # only the feature name.
                    tf_features[feature.name] = tf.io.SparseFeature(
                        index_key=f"{feature.name}_{DatasetMetadata.INDICES}",
                        value_key=f"{feature.name}_{DatasetMetadata.VALUES}",
                        dtype=DatasetMetadata.map_int(feature.dtype),
                        size=_unpack_one_element_list(feature.shape))
                else:
                    tf_features[feature.name] = tf.io.FixedLenFeature(
                        shape=feature.shape,
                        dtype=DatasetMetadata.map_int(feature.dtype))
            return tf_features

        def map_fn(serialized, feature_tensors, label_tensors):
            """
            Deserialize TF records to features. This is done after batching
            since we are using tf.io.parse_example.

            :param serialized: serialized TF records.
            :param feature_tensors: list of feature tensors.
            :param label_tensors: list of label tensors.
            :return: (features, labels) tuple where each of them is a map.
            """
            tensors = feature_tensors + label_tensors
            tf_features = build_features(tensors)
            example = tf.io.parse_example(serialized, tf_features,
                                          example_names=None, name=None)
            # Then split features from labels
            return _splits_label_and_features(example, label_tensors)

        # Get shard input files
        input_filename_pattern = _convert_dir_to_filename_pattern(
            input_path, constants.TFRECORD_GLOB_PATTERN)
        input_files, _ = shard_input_files(input_filename_pattern,
                                           num_shards, shard_index)
        # Get metadata
        feature_tensors, label_tensors = _get_features_and_labels_info(
            metadata_file)
        # Batching
        dataset = tf.data.TFRecordDataset(input_files).batch(
            batch_size, drop_remainder=False)
        # Deserialize to features
        dataset = dataset.map(partial(map_fn,
                                      feature_tensors=feature_tensors,
                                      label_tensors=label_tensors),
                              num_parallel_calls=16)
    elif custom_input_fn is not None:
        logger.info("loading {} dataset by {}".format(data_format,
                                                      custom_input_fn))
        import importlib
        module_name, fn_name = custom_input_fn.rsplit('.', 1)
        dataset_module = importlib.import_module(module_name)
        dataset = getattr(dataset_module, fn_name)(input_path, metadata_file,
                                                   num_shards, shard_index,
                                                   batch_size, data_format)
    else:
        raise Exception("Unknown data format: {}".format(data_format))
    return dataset
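# Hypothetical usage sketch (the paths, shard counts, and batch size below
# are placeholders, not from the original source): build the per-record
# dataset for one worker and log the tensor names of a single batch.
def _per_record_input_fn_example():
    """Hypothetical usage sketch for per_record_input_fn."""
    dataset = per_record_input_fn(
        input_path="/tmp/training_data",
        metadata_file="/tmp/metadata/tensor_metadata.json",
        num_shards=4,
        shard_index=0,
        batch_size=256,
        data_format=constants.TFRECORD)
    for features, labels in dataset.take(1):
        logger.info("feature tensors: %s, label tensors: %s",
                    list(features.keys()), list(labels.keys()))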
def per_entity_grouped_input_fn(input_path, metadata_file, num_shards,
                                shard_index, batch_size, data_format,
                                entity_name, custom_input_fn=None):
    """
    Input function for a per-entity grouped dataset. In this dataset, the
    records are grouped by entity Id; each feature is a vector except the
    entity Id, which is a scalar. The batch size adds an extra dimension on
    top of that.

    :param input_path: input directory or file pattern.
    :param metadata_file: tensor metadata file.
    :param num_shards: number of shards.
    :param shard_index: the index of this worker.
    :param batch_size: batch size.
    :param data_format: tfrecord or avro.
    :param entity_name: the name of the entity by which the records are grouped.
    :param custom_input_fn: full name "package.module.fn" of an external custom input_fn.
    :return: a batched dataset.
    """
    if data_format == constants.TFRECORD:
        logger.info(f"using {constants.TFRECORD} dataset")

        # Build features
        def build_features(tensors, entity_name):
            """
            Create features from metadata, used to deserialize the tfrecord.

            :param tensors: list of metadata for all tensors.
            :param entity_name: entity by which the records are grouped.
            :return: a tuple of (context_features, sequence_features).
            """
            sequence_features = dict()
            context_features = dict()
            for tensor in tensors:
                tensor_dtype = DatasetMetadata.map_int(tensor.dtype)
                if tensor.name == entity_name:
                    # The entity_name column is a scalar
                    context_features[entity_name] = tf.io.FixedLenFeature(
                        shape=[], dtype=tensor_dtype)
                else:
                    if tensor.isSparse:
                        # If this is a sparse tensor, we process indices and values separately.
                        # Note in the metadata, we don't see _indices and _values,
                        # only the feature name.
                        indices_name = f"{tensor.name}_{DatasetMetadata.INDICES}"
                        values_name = f"{tensor.name}_{DatasetMetadata.VALUES}"
                        sequence_features[indices_name] = tf.io.VarLenFeature(
                            dtype=tf.int64)
                        sequence_features[values_name] = tf.io.VarLenFeature(
                            dtype=tensor_dtype)
                    else:
                        context_features[tensor.name] = tf.io.VarLenFeature(
                            dtype=tensor_dtype)
            if len(sequence_features) == 0:
                sequence_features = None
            if len(context_features) == 0:
                context_features = None
            return context_features, sequence_features

        def map_fn(serialized, feature_tensors, label_tensors, entity_name):
            """
            Map a serialized tfrecord to a dict of tensors.

            :param serialized: serialized tfrecord.
            :param feature_tensors: list of metadata for features.
            :param label_tensors: list of metadata for labels.
            :param entity_name: entity by which the records are grouped.
            :return: a (features, labels) tuple.
            """
            tensors = feature_tensors + label_tensors
            context_features, sequence_features = build_features(
                tensors, entity_name)
            example = tf.io.parse_sequence_example(
                serialized,
                context_features=context_features,
                sequence_features=sequence_features,
                example_names=None,
                name=None)
            # Split features from labels
            context, sequence = example[0], example[1]
            sequence.update(context)
            return _splits_label_and_features(sequence, label_tensors)

        # Get shard input files
        input_filename_pattern = _convert_dir_to_filename_pattern(
            input_path, constants.TFRECORD_GLOB_PATTERN)
        input_files, _ = shard_input_files(input_filename_pattern,
                                           num_shards, shard_index)
        # Get metadata
        feature_tensors, label_tensors = _get_features_and_labels_info(
            metadata_file)
        # Check if entity_name is one of the features
        feature_names = [x.name for x in feature_tensors]
        if entity_name not in feature_names:
            raise ValueError(
                f"entity name {entity_name} is not found among the features")
        # Batching
        dataset = tf.data.TFRecordDataset(input_files).batch(
            batch_size, drop_remainder=False)
        # Deserialize to features
        dataset = dataset.map(partial(map_fn,
                                      feature_tensors=feature_tensors,
                                      label_tensors=label_tensors,
                                      entity_name=entity_name),
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
    elif custom_input_fn:
        logger.info(f"loading {data_format} dataset by {custom_input_fn}")
        import importlib
        module_name, fn_name = custom_input_fn.rsplit('.', 1)
        dataset_module = importlib.import_module(module_name)
        dataset = getattr(dataset_module, fn_name)(input_path, metadata_file,
                                                   num_shards, shard_index,
                                                   batch_size, entity_name,
                                                   data_format)
    else:
        raise Exception(f"Unknown data format: {data_format}")
    return dataset
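# A companion sketch for the grouped variant, again with placeholder paths
# and a hypothetical entity column name ("memberId" is assumed, not taken
# from the original source).
def _per_entity_grouped_input_fn_example():
    """Hypothetical usage sketch for per_entity_grouped_input_fn."""
    grouped_dataset = per_entity_grouped_input_fn(
        input_path="/tmp/grouped_training_data",
        metadata_file="/tmp/metadata/tensor_metadata.json",
        num_shards=4,
        shard_index=0,
        batch_size=64,
        data_format=constants.TFRECORD,
        entity_name="memberId")
    for features, labels in grouped_dataset.take(1):
        logger.info("grouped feature tensors: %s, label tensors: %s",
                    list(features.keys()), list(labels.keys()))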