def predict(self, output_dir, input_data_path, metadata_file, checkpoint_path,
            execution_context, schema_params):
    n_records = 0
    n_batch = 0
    # Predict on the dataset
    sharded_dataset_paths, file_level_sharding = shard_input_files(
        input_data_path,
        execution_context[constants.NUM_SHARDS],
        execution_context[constants.SHARD_INDEX])
    if file_level_sharding and len(sharded_dataset_paths) == 0:
        logger.info("No input dataset found, returning...")
        return

    inference_dataset = lambda: input_fn(  # noqa: E731
        input_pattern=','.join(sharded_dataset_paths),
        # DeText uses metadata_path
        metadata_path=self.model_params.metadata_path,
        batch_size=self.model_params.test_batch_size,
        mode=tf.estimator.ModeKeys.EVAL,
        vocab_table=vocab_utils.read_tf_vocab(
            self.model_params.vocab_file, self.model_params.UNK),
        vocab_table_for_id_ftr=vocab_utils.read_tf_vocab(
            self.model_params.vocab_file_for_id_ftr,
            self.model_params.UNK_FOR_ID_FTR),
        feature_names=self.model_params.feature_names,
        CLS=self.model_params.CLS,
        SEP=self.model_params.SEP,
        PAD=self.model_params.PAD,
        PAD_FOR_ID_FTR=self.model_params.PAD_FOR_ID_FTR,
        max_len=self.model_params.max_len,
        min_len=self.model_params.min_len,
        cnn_filter_window_size=max(self.model_params.filter_window_sizes)
        if self.model_params.ftr_ext == 'cnn' else 0)

    self.estimator_based_model = detext_train.get_estimator(
        self.model_params,
        strategy=None,  # local mode
        best_checkpoint=self.best_checkpoint)
    output = self.estimator_based_model.predict(
        inference_dataset, yield_single_examples=False)

    detext_writer = DetextWriter(schema_params=schema_params)
    shard_index = execution_context[constants.SHARD_INDEX]
    output_file = os.path.join(output_dir, "part-{0:05d}.avro".format(shard_index))
    for batch_score in output:
        if n_batch == 0:
            # First batch: create the output file. A GFile opened for writing
            # is not seekable, so report that to the Avro writer.
            with tf.io.gfile.GFile(output_file, 'wb') as f:
                f.seekable = lambda: False
                n_records, n_batch = detext_writer.save_batch(
                    f, batch_score, output_file, n_records, n_batch)
        else:
            # Subsequent batches: append to the existing file. The Avro writer
            # expects a seekable, readable handle positioned at the end.
            with tf.io.gfile.GFile(output_file, 'ab+') as f:
                f.seek(0, 2)
                f.seekable = lambda: True
                f.readable = lambda: True
                n_records, n_batch = detext_writer.save_batch(
                    f, batch_score, output_file, n_records, n_batch)
    logger.info("Inference complete: {} batches, {} records".format(n_batch, n_records))
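
# For context, a minimal sketch of the file-level sharding contract that
# predict() relies on. This is an assumption about shard_input_files'
# behavior, not its actual implementation; the helper name and the
# sharding heuristic below are illustrative only.
def shard_input_files_sketch(input_data_path, num_shards, shard_index):
    """Return (paths, file_level_sharding) for this worker's shard.

    Assumed contract: with multiple input files, each worker takes every
    num_shards-th file, so a worker may receive an empty list (which
    predict() above handles by returning early). With a single file,
    sharding would have to happen at the record level instead.
    """
    files = sorted(tf.io.gfile.glob(input_data_path + '/*'))
    file_level_sharding = len(files) > 1
    if file_level_sharding:
        return files[shard_index::num_shards], True
    return files, False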
def testMultitaskInputFnBuilderTfrecord(self):
    """Test additional input from multitask training in eval mode"""
    res_dir = os.path.dirname(__file__) + '/../resources'

    # Create a vocab table
    vocab_table = vocab_utils.read_tf_vocab(res_dir + '/vocab.txt', '[UNK]')

    # Dataset dir
    data_dir = os.path.join(res_dir, 'train', 'multitask', 'tfrecord')

    # Test the minimum features required for multitask jobs
    feature_names = ('label', 'query', 'doc_field1', 'doc_field2', 'wide_ftrs', 'task_id')
    batch_size = 5

    dataset = data_fn.input_fn(input_pattern=data_dir,
                               metadata_path=None,
                               batch_size=batch_size,
                               mode=tf.estimator.ModeKeys.EVAL,
                               vocab_table=vocab_table,
                               vocab_table_for_id_ftr=vocab_table,
                               feature_names=feature_names,
                               CLS='[CLS]',
                               SEP='[SEP]',
                               PAD='[PAD]',
                               PAD_FOR_ID_FTR='[PAD]',
                               max_len=16,
                               cnn_filter_window_size=1)

    # Make iterator
    iterator = dataset.make_initializable_iterator()
    batch_data = iterator.get_next()

    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
        sess.run([iterator.initializer])
        batch_data_val, = sess.run([batch_data])
        features, label = batch_data_val

        # First dimension of data should be batch_size
        for ftr_name in feature_names:
            if ftr_name != 'label':
                self.assertTrue(ftr_name in features)
                self.assertTrue(features[ftr_name].shape[0] == batch_size)
        self.assertTrue(label['label'].shape[0] == batch_size)

        task_ids = features['task_id']
        # Check task_id dimension size
        self.assertEqual(len(task_ids.shape), 1)
        # Check task_id values in the sample data
        for t_id in task_ids:
            self.assertTrue(t_id in (0, 1))
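
# To make the expected input format concrete, below is a hedged sketch of how
# one multitask record matching the feature names above could be serialized.
# The exact schema of the checked-in resource files may differ; the field
# shapes, values, and output path here are assumptions for illustration.
def write_sample_multitask_tfrecord(path='/tmp/multitask_sample.tfrecord'):
    def _bytes_feature(values):
        return tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[v.encode('utf-8') for v in values]))

    def _float_feature(values):
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    def _int64_feature(values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    # One hypothetical query record with two candidate documents
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': _float_feature([1.0, 0.0]),                  # one label per doc
        'query': _bytes_feature(['software engineer']),
        'doc_field1': _bytes_feature(['senior swe', 'data scientist']),
        'doc_field2': _bytes_feature(['builds systems', 'trains models']),
        'wide_ftrs': _float_feature([0.5, 1.2, 0.7, 0.3]),    # flattened per-doc features
        'task_id': _int64_feature([1]),                       # task this query belongs to
    }))

    with tf.io.TFRecordWriter(path) as writer:
        writer.write(example.SerializeToString())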
def testInputFnBuilderTfrecord(self):
    """Test function input_fn() in eval mode"""
    res_dir = os.path.dirname(__file__) + '/../resources'

    # Create a vocab table
    vocab_table = vocab_utils.read_tf_vocab(res_dir + '/vocab.txt', '[UNK]')

    # Dataset dir
    data_dir = os.path.join(res_dir, 'train', 'dataset')
    input_files = os.path.join(data_dir, '*.tfrecord')

    # Create a dataset: read the schema, then parse and process the records
    feature_names = ('label', 'query', 'doc_completedQuery', 'usr_headline',
                     'usr_skills', 'usr_currTitles', 'usrId_currTitles',
                     'docId_completedQuery', 'wide_ftrs', 'weight')
    batch_size = 2

    dataset = data_fn.input_fn(input_pattern=input_files,
                               metadata_path=None,
                               batch_size=batch_size,
                               mode=tf.estimator.ModeKeys.EVAL,
                               vocab_table=vocab_table,
                               vocab_table_for_id_ftr=vocab_table,
                               feature_names=feature_names,
                               CLS='[CLS]',
                               SEP='[SEP]',
                               PAD='[PAD]',
                               PAD_FOR_ID_FTR='[PAD]',
                               max_len=16,
                               cnn_filter_window_size=1)

    # Make iterator
    iterator = dataset.make_initializable_iterator()
    batch_data = iterator.get_next()

    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
        sess.run([iterator.initializer])
        batch_data_val, = sess.run([batch_data])
        features, label = batch_data_val

        # First dimension of data should be batch_size
        for ftr_name in feature_names:
            if ftr_name != 'label':
                self.assertTrue(ftr_name in features)
                self.assertTrue(features[ftr_name].shape[0] == batch_size)
        self.assertTrue(label['label'].shape[0] == batch_size)

        doc_completedQuery = features['doc_completedQuery']
        docId_completedQuery = features['docId_completedQuery']
        usr_currTitles = features['usr_currTitles']
        usrId_currTitles = features['usrId_currTitles']

        # vocab[PAD] == PAD_ID
        self.assertTrue(doc_completedQuery[0, 0, -1] == self.PAD_ID)
        self.assertTrue(docId_completedQuery[0, 0, -1] == self.PAD_ID)

        # vocab[CLS] == CLS_ID
        self.assertTrue(np.all(doc_completedQuery[0, 0, 0] == self.CLS_ID))
        self.assertTrue(np.all(usr_currTitles[0, 0] == self.CLS_ID))

        # No CLS in id features
        self.assertTrue(np.all(docId_completedQuery[:, :, 0] != self.CLS_ID))

        # In this TFRecord file, docId_completedQuery is populated from
        # doc_completedQuery, so the doc id feature should match the doc text
        # feature except for the added CLS and SEP. Verify this for the first sample.
        for text_arr, id_arr in zip(doc_completedQuery[0], docId_completedQuery[0]):
            self.assertAllEqual(text_arr[text_arr != self.PAD_ID][1:-1],
                                id_arr[id_arr != self.PAD_ID])

        # Likewise, usrId_currTitles is populated from usr_currTitles, so the
        # usr id feature should match the usr text feature except for CLS/SEP
        for text_arr, id_arr in zip(usr_currTitles, usrId_currTitles):
            self.assertAllEqual(text_arr[text_arr != self.PAD_ID][1:-1],
                                id_arr[id_arr != self.PAD_ID])
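
# The id-vs-text assertions above rely on a numpy masking idiom: drop the
# padding with a boolean mask, then slice off the leading CLS and trailing
# SEP. A standalone illustration with made-up ids (the PAD_ID/CLS_ID/SEP_ID
# values below are hypothetical, not DeText's actual vocab):
def _masking_idiom_example():
    import numpy as np

    PAD_ID, CLS_ID, SEP_ID = 0, 1, 2

    # Text feature: [CLS] w1 w2 [SEP] plus padding; id feature: w1 w2 plus padding
    text_arr = np.array([CLS_ID, 7, 8, SEP_ID, PAD_ID, PAD_ID])
    id_arr = np.array([7, 8, PAD_ID, PAD_ID])

    stripped_text = text_arr[text_arr != PAD_ID][1:-1]  # mask padding, drop CLS/SEP
    stripped_id = id_arr[id_arr != PAD_ID]              # mask padding only
    assert np.array_equal(stripped_text, stripped_id)   # both are [7, 8]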