def testPredict(self):
    """Tests predict()"""
    dataset = data_fn.input_fn_tfrecord(
        input_pattern=self.data_dir,
        batch_size=self.batch_size,
        mode=tf.estimator.ModeKeys.EVAL,
        feature_type2name=self.feature_type2name,
        feature_name2num=self.feature_name2num,
        input_pipeline_context=None,
    )
    detext_model = model.create_detext_model(self.feature_type2name,
                                             task_type=self.task_type,
                                             **self.deep_match_param)
    predicted_output = train_flow_helper.predict_with_additional_info(
        dataset, detext_model, self.feature_type2name)

    # Every prediction record must carry the score plus the weight/uid/label
    # columns (falling back to the default feature names when not configured).
    expected_keys = [
        train_flow_helper._SCORES,
        self.feature_type2name.get(InputFtrType.WEIGHT_COLUMN_NAME,
                                   Constant()._DEFAULT_WEIGHT_FTR_NAME),
        self.feature_type2name.get(InputFtrType.UID_COLUMN_NAME,
                                   Constant()._DEFAULT_UID_FTR_NAME),
        self.feature_type2name[InputFtrType.LABEL_COLUMN_NAME],
    ]
    for output in predicted_output:
        for expected_key in expected_keys:
            self.assertIn(expected_key, output)
def _input_fn_tfrecord(ctx):
    """Build the eval-mode TFRecord dataset, wiring in the given input pipeline context."""
    dataset_kwargs = dict(
        input_pattern=data_dir,
        batch_size=batch_size,
        mode=tf.estimator.ModeKeys.EVAL,
        feature_type2name=feature_type2name,
        feature_name2num=feature_name2num,
        input_pipeline_context=ctx,
    )
    return data_fn.input_fn_tfrecord(**dataset_kwargs)
def testBinaryClassificationInputFnBuilderTfrecord(self):
    """Test binary classification input reader """
    data_dir = self.binary_cls_data_dir

    feature_type2name = {
        InputFtrType.LABEL_COLUMN_NAME: 'label',
        InputFtrType.SPARSE_FTRS_COLUMN_NAMES: ['sparse_ftrs'],
        InputFtrType.SHALLOW_TOWER_SPARSE_FTRS_COLUMN_NAMES: ['shallow_tower_sparse_ftrs', 'sparse_ftrs']
    }
    feature_name2num = {
        'sparse_ftrs': 20,
        'shallow_tower_sparse_ftrs': 20
    }
    batch_size = 2
    dataset = data_fn.input_fn_tfrecord(input_pattern=data_dir,
                                        batch_size=batch_size,
                                        mode=tf.estimator.ModeKeys.EVAL,
                                        task_type=TaskType.BINARY_CLASSIFICATION,
                                        feature_type2name=feature_type2name,
                                        feature_name2num=feature_name2num
                                        )

    for features, label in dataset:
        # First dimension of data should be batch_size
        for ftr_type, ftr_name_lst in iterate_items_with_list_val(feature_type2name):
            if ftr_type in (InputFtrType.LABEL_COLUMN_NAME, InputFtrType.WEIGHT_COLUMN_NAME, InputFtrType.UID_COLUMN_NAME):
                # BUG FIX: the message was previously attached via a trailing comma
                # (creating a discarded tuple); pass it as the assertLen msg argument
                self.assertLen(ftr_name_lst, 1, f'Length for current ftr type ({ftr_type}) should be 1')
                ftr_name = ftr_name_lst[0]
                self.assertIn(ftr_name, label)
                continue

            for ftr_name in ftr_name_lst:
                self.assertIn(ftr_name, features)
                self.assertEqual(features[ftr_name].shape[0], batch_size)

        # Weight and uid fall back to default feature names when not configured
        weight_ftr_name = feature_type2name.get(InputFtrType.WEIGHT_COLUMN_NAME, constant.Constant()._DEFAULT_WEIGHT_FTR_NAME)
        self.assertAllEqual(tf.shape(label[weight_ftr_name]), [batch_size])

        uid_ftr_name = feature_type2name.get(InputFtrType.UID_COLUMN_NAME, constant.Constant()._DEFAULT_UID_FTR_NAME)
        self.assertAllEqual(tf.shape(label[uid_ftr_name]), [batch_size])

        self.assertAllEqual(label['label'].shape, [batch_size])
        self.assertAllEqual(tf.sparse.to_dense(features['sparse_ftrs']),
                            tf.sparse.to_dense(
                                tf.SparseTensor(indices=[[0, 0], [0, 2], [0, 7], [1, 0], [1, 2], [1, 7]],
                                                values=[1, 0, 7, 1, 0, 7],
                                                dense_shape=[batch_size, self.nums_sparse_ftrs[0]])
                            )
                            )
        # Only check first batch
        break
def testRankingMultitaskInputFnBuilderTfrecord(self):
    """Test additional input from multitask training in eval mode"""
    data_dir = self.ranking_data_dir

    # Test minimum features required for multitask jobs
    feature_type2name = {
        InputFtrType.LABEL_COLUMN_NAME: 'label',
        InputFtrType.QUERY_COLUMN_NAME: 'query',
        InputFtrType.DOC_TEXT_COLUMN_NAMES: ['doc_headline', 'doc_title'],
        InputFtrType.USER_TEXT_COLUMN_NAMES: ['user_headline', 'user_title'],
        InputFtrType.DOC_ID_COLUMN_NAMES: ['doc_headline_id'],
        InputFtrType.USER_ID_COLUMN_NAMES: ['user_headline_id'],
        InputFtrType.DENSE_FTRS_COLUMN_NAMES: ['dense_ftrs'],
        InputFtrType.WEIGHT_COLUMN_NAME: 'weight',
        InputFtrType.TASK_ID_COLUMN_NAME: 'task_id_field'
    }
    feature_name2num = {
        'dense_ftrs': 2
    }
    batch_size = 5
    dataset = data_fn.input_fn_tfrecord(input_pattern=data_dir,
                                        batch_size=batch_size,
                                        mode=tf.estimator.ModeKeys.EVAL,
                                        feature_type2name=feature_type2name,
                                        feature_name2num=feature_name2num)

    for features, label in dataset:
        # First dimension of data should be batch_size
        for ftr_type, ftr_name_lst in iterate_items_with_list_val(feature_type2name):
            if ftr_type in (InputFtrType.LABEL_COLUMN_NAME, InputFtrType.WEIGHT_COLUMN_NAME, InputFtrType.UID_COLUMN_NAME):
                # BUG FIX: the message was previously attached via a trailing comma
                # (creating a discarded tuple); pass it as the assertLen msg argument
                self.assertLen(ftr_name_lst, 1, f'Length for current ftr type ({ftr_type}) should be 1')
                ftr_name = ftr_name_lst[0]
                self.assertIn(ftr_name, label)
                continue

            for ftr_name in ftr_name_lst:
                self.assertIn(ftr_name, features)
                self.assertEqual(features[ftr_name].shape[0], batch_size)

        # Weight and uid fall back to default feature names when not configured
        weight_ftr_name = feature_type2name.get(InputFtrType.WEIGHT_COLUMN_NAME, constant.Constant()._DEFAULT_WEIGHT_FTR_NAME)
        self.assertAllEqual(tf.shape(label[weight_ftr_name]), [batch_size])

        uid_ftr_name = feature_type2name.get(InputFtrType.UID_COLUMN_NAME, constant.Constant()._DEFAULT_UID_FTR_NAME)
        self.assertAllEqual(tf.shape(label[uid_ftr_name]), [batch_size])

        # First dimension of data should be batch_size
        self.assertEqual(label['label'].shape[0], batch_size)

        task_ids = features['task_id_field']
        # Check task_id dimension size
        self.assertEqual(len(task_ids.shape), 1)
        # Check task_id value in the sample data
        for t_id in task_ids:
            self.assertAllEqual(t_id, 5)
def _get_input_fn_common(pattern, batch_size, mode, task_type, feature_type2name, feature_name2num: dict):
    """ Returns the common input function used in DeText training and evaluation"""
    def _input_fn(ctx):
        # The pipeline context is supplied by the caller at dataset-build time
        return input_fn_tfrecord(input_pattern=pattern,
                                 batch_size=batch_size,
                                 mode=mode,
                                 task_type=task_type,
                                 feature_type2name=feature_type2name,
                                 feature_name2num=feature_name2num,
                                 input_pipeline_context=ctx)

    return _input_fn
def predict(self, output_dir, input_data_path, metadata_file, checkpoint_path, execution_context, schema_params):
    """Runs inference on this shard's input data and writes the scores as avro.

    Loads the best checkpoint, scores the sharded TFRecord input with the DeText
    model, and appends each scored batch to ``part-<shard_index>.avro`` under
    ``output_dir``.

    :param output_dir: directory where the shard's avro output file is written
    :param input_data_path: input TFRecord path/pattern to be sharded
    :param metadata_file: unused here; part of the caller's interface
    :param checkpoint_path: unused here; the model is restored from self.best_checkpoint
    :param execution_context: dict providing NUM_SHARDS and SHARD_INDEX
    :param schema_params: schema configuration forwarded to DetextWriter
    """
    n_records = 0
    n_batch = 0
    # Predict on the dataset
    sharded_dataset_paths, file_level_sharding = shard_input_files(input_data_path,
                                                                   execution_context[constants.NUM_SHARDS],
                                                                   execution_context[constants.SHARD_INDEX])
    if file_level_sharding and len(sharded_dataset_paths) == 0:
        logger.info("No input dataset is found, returning...")
        return

    inference_dataset = input_fn_tfrecord(input_pattern=','.join(sharded_dataset_paths),  # noqa: E731
                                          batch_size=self.model_params.test_batch_size,
                                          mode=tf.estimator.ModeKeys.EVAL,
                                          feature_type2name=self.model_params.feature_type2name,
                                          feature_name2num=self.model_params.feature_name2num,
                                          task_type=self.model_params.task_type)

    self.model = train_model_helper.load_model_with_ckpt(
        parsing_utils.HParams(**asdict(self.model_params)),
        self.best_checkpoint)
    output = train_flow_helper.predict_with_additional_info(inference_dataset,
                                                            self.model,
                                                            self.model_params.feature_type2name)
    detext_writer = DetextWriter(schema_params=schema_params)

    shard_index = execution_context[constants.SHARD_INDEX]
    output_file = os.path.join(output_dir, f"part-{shard_index:05d}.avro")
    for batch_score in output:
        if n_batch == 0:
            # First batch: create the file. GFile objects are monkey-patched
            # because the avro writer probes seekable()/readable() —
            # NOTE(review): presumably required by DetextWriter's underlying
            # avro library; confirm before removing.
            with tf.io.gfile.GFile(output_file, 'wb') as f:
                f.seekable = lambda: False
                n_records, n_batch = detext_writer.save_batch(f, batch_score, output_file,
                                                              n_records, n_batch)
        else:
            # Subsequent batches: append to the existing file, positioned at EOF
            with tf.io.gfile.GFile(output_file, 'ab+') as f:
                f.seek(0, 2)
                f.seekable = lambda: True
                f.readable = lambda: True
                n_records, n_batch = detext_writer.save_batch(f, batch_score, output_file,
                                                              n_records, n_batch)
    # BUG FIX: message previously read "{n_batch} records, e.g. {n_records} records";
    # n_batch counts batches and n_records counts records
    logger.info(f"{n_batch} batches, {n_records} records inferenced")
def testClassificationInputFnBuilderTfrecord(self):
    """Test classification input reader in eval mode"""
    data_dir = self.cls_data_dir

    feature_type2name = {
        InputFtrType.LABEL_COLUMN_NAME: 'label',
        InputFtrType.DOC_TEXT_COLUMN_NAMES: ['query_text'],
        InputFtrType.USER_TEXT_COLUMN_NAMES: ['user_headline'],
        InputFtrType.DENSE_FTRS_COLUMN_NAMES: 'dense_ftrs',
    }
    feature_name2num = {
        'dense_ftrs': 8
    }
    batch_size = 2
    dataset = data_fn.input_fn_tfrecord(input_pattern=data_dir,
                                        batch_size=batch_size,
                                        mode=tf.estimator.ModeKeys.EVAL,
                                        task_type=TaskType.CLASSIFICATION,
                                        feature_type2name=feature_type2name,
                                        feature_name2num=feature_name2num)

    for features, label in dataset:
        # First dimension of data should be batch_size
        for ftr_type, ftr_name_lst in iterate_items_with_list_val(feature_type2name):
            if ftr_type in (InputFtrType.LABEL_COLUMN_NAME, InputFtrType.WEIGHT_COLUMN_NAME, InputFtrType.UID_COLUMN_NAME):
                # BUG FIX: the message was previously attached via a trailing comma
                # (creating a discarded tuple); pass it as the assertLen msg argument
                self.assertLen(ftr_name_lst, 1, f'Length for current ftr type ({ftr_type}) should be 1')
                ftr_name = ftr_name_lst[0]
                self.assertIn(ftr_name, label)
                continue

            for ftr_name in ftr_name_lst:
                self.assertIn(ftr_name, features)
                self.assertEqual(features[ftr_name].shape[0], batch_size)

        # Weight and uid fall back to default feature names when not configured
        weight_ftr_name = feature_type2name.get(InputFtrType.WEIGHT_COLUMN_NAME, constant.Constant()._DEFAULT_WEIGHT_FTR_NAME)
        self.assertAllEqual(tf.shape(label[weight_ftr_name]), [batch_size])

        uid_ftr_name = feature_type2name.get(InputFtrType.UID_COLUMN_NAME, constant.Constant()._DEFAULT_UID_FTR_NAME)
        self.assertAllEqual(tf.shape(label[uid_ftr_name]), [batch_size])

        self.assertAllEqual(label['label'].shape, [batch_size])