def predict(self, output_dir, input_data_path, metadata_file,
                checkpoint_path, execution_context, schema_params):
        n_records = 0
        n_batch = 0
        # Predict on the dataset
        sharded_dataset_paths, file_level_sharding = shard_input_files(
            input_data_path, execution_context[constants.NUM_SHARDS],
            execution_context[constants.SHARD_INDEX])
        if file_level_sharding and len(sharded_dataset_paths) == 0:
            logger.info("No input dataset is found, returning...")
            return

        inference_dataset = lambda: input_fn(  # noqa: E731
            input_pattern=','.join(sharded_dataset_paths),
            # DeText uses metadata_path
            metadata_path=self.model_params.metadata_path,
            batch_size=self.model_params.test_batch_size,
            mode=tf.estimator.ModeKeys.EVAL,
            vocab_table=vocab_utils.read_tf_vocab(self.model_params.vocab_file,
                                                  self.model_params.UNK),
            vocab_table_for_id_ftr=vocab_utils.read_tf_vocab(
                self.model_params.vocab_file_for_id_ftr,
                self.model_params.UNK_FOR_ID_FTR),
            feature_names=self.model_params.feature_names,
            CLS=self.model_params.CLS,
            SEP=self.model_params.SEP,
            PAD=self.model_params.PAD,
            PAD_FOR_ID_FTR=self.model_params.PAD_FOR_ID_FTR,
            max_len=self.model_params.max_len,
            min_len=self.model_params.min_len,
            cnn_filter_window_size=max(self.model_params.filter_window_sizes)
            if self.model_params.ftr_ext == 'cnn' else 0)

        self.estimator_based_model = detext_train.get_estimator(
            self.model_params,
            strategy=None,  # local mode
            best_checkpoint=self.best_checkpoint)
        output = self.estimator_based_model.predict(
            inference_dataset, yield_single_examples=False)
        detext_writer = DetextWriter(schema_params=schema_params)
        shard_index = execution_context[constants.SHARD_INDEX]
        output_file = os.path.join(output_dir,
                                   "part-{0:05d}.avro".format(shard_index))
        for batch_score in output:
            if n_batch == 0:
                # First batch: create the output file. tf.io.gfile.GFile does not
                # implement seekable(), so patch it to report a non-seekable
                # stream for the initial write.
                with tf.io.gfile.GFile(output_file, 'wb') as f:
                    f.seekable = lambda: False
                    n_records, n_batch = detext_writer.save_batch(
                        f, batch_score, output_file, n_records, n_batch)
            else:
                # Later batches: reopen in append mode, seek to the end of the
                # file, and patch seekable()/readable() so the writer can append
                # to the existing output.
                with tf.io.gfile.GFile(output_file, 'ab+') as f:
                    f.seek(0, 2)
                    f.seekable = lambda: True
                    f.readable = lambda: True
                    n_records, n_batch = detext_writer.save_batch(
                        f, batch_score, output_file, n_records, n_batch)
        logger.info("{} batches, e.g. {} records inferenced".format(
            n_batch, n_records))
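
The predict() method above relies on shard_input_files only through its call signature: it takes the input pattern plus the shard count and index from execution_context, and returns this worker's list of paths together with a flag saying whether sharding happened at the file level. Below is a minimal sketch of what such a helper could look like, assuming simple round-robin assignment of files to workers; the name shard_input_files_sketch and the single-file fallback are illustrative assumptions, not the library's implementation.

import tensorflow as tf

def shard_input_files_sketch(input_pattern, num_shards, shard_index):
    # Expand the input pattern and sort for a deterministic assignment.
    files = sorted(tf.io.gfile.glob(input_pattern))
    if len(files) <= 1:
        # A single file cannot be split at file level; the real helper would
        # fall back to record-level sharding inside the file (not sketched here).
        return files, False
    # Round-robin: worker i takes files i, i + num_shards, i + 2 * num_shards, ...
    # A worker can end up with no files, which is exactly the empty-list case
    # predict() checks before returning early.
    return files[shard_index::num_shards], True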
Example #2
    def testMultitaskInputFnBuilderTfrecord(self):
        """Test additional input from multitask training in eval mode"""
        res_dir = os.path.dirname(__file__) + '/../resources'

        # create a vocab table
        vocab_table = vocab_utils.read_tf_vocab(res_dir + '/vocab.txt', '[UNK]')

        # dataset dir
        data_dir = os.path.join(res_dir, 'train', 'multitask', 'tfrecord')

        # test minimum features required for multitask jobs
        feature_names = ('label', 'query', 'doc_field1', 'doc_field2', 'wide_ftrs', 'task_id')

        batch_size = 5
        dataset = data_fn.input_fn(input_pattern=data_dir,
                                   metadata_path=None,
                                   batch_size=batch_size,
                                   mode=tf.estimator.ModeKeys.EVAL,
                                   vocab_table=vocab_table,
                                   vocab_table_for_id_ftr=vocab_table,
                                   feature_names=feature_names,
                                   CLS='[CLS]',
                                   SEP='[SEP]',
                                   PAD='[PAD]',
                                   PAD_FOR_ID_FTR='[PAD]',
                                   max_len=16,
                                   cnn_filter_window_size=1)

        # Make iterator
        iterator = dataset.make_initializable_iterator()
        batch_data = iterator.get_next()

        with tf.Session() as sess:
            sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
            sess.run([iterator.initializer])
            batch_data_val, = sess.run([batch_data])
            features, label = batch_data_val

            # First dimension of data should be batch_size
            for ftr_name in feature_names:
                if ftr_name != 'label':
                    self.assertTrue(ftr_name in features)
                    self.assertTrue(features[ftr_name].shape[0] == batch_size)

            self.assertTrue(label['label'].shape[0] == batch_size)

            task_ids = features['task_id']

            # Check task_id dimension size
            self.assertEqual(len(task_ids.shape), 1)

            # Check task_id value in the sample data
            for t_id in task_ids:
                self.assertTrue(t_id in (0, 1))
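
For reference, here is a hypothetical TFRecord record carrying the minimum multitask features listed in the test above. The per-feature encodings chosen here (bytes lists for text, float lists for label and wide_ftrs, an int64 list for task_id) are assumptions about the sample data, not taken from the library.

import tensorflow as tf

def make_multitask_example():
    def bytes_ftr(*values):
        return tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[v.encode() for v in values]))

    def float_ftr(*values):
        return tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))

    def int_ftr(*values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

    return tf.train.Example(features=tf.train.Features(feature={
        'label': float_ftr(1.0, 0.0),                # one relevance label per document
        'query': bytes_ftr('machine learning'),
        'doc_field1': bytes_ftr('ml engineer', 'data scientist'),
        'doc_field2': bytes_ftr('team a', 'team b'),
        'wide_ftrs': float_ftr(0.1, 0.2, 0.3, 0.4),
        'task_id': int_ftr(0),                       # which task this query belongs to
    }))

# Records like this, serialized with SerializeToString() and written via
# tf.io.TFRecordWriter, are what the test reads back in batches of 5.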
Example #3
    def testInputFnBuilderTfrecord(self):
        """Test function input_fn_builder() in eval mode"""
        res_dir = os.path.dirname(__file__) + '/../resources'

        # create a vocab table
        vocab_table = vocab_utils.read_tf_vocab(res_dir + '/vocab.txt',
                                                '[UNK]')

        # dataset dir
        data_dir = os.path.join(res_dir, 'train', 'dataset')
        input_files = os.path.join(data_dir, '*.tfrecord')

        # Create the dataset: read the schema, then parse and process the records
        feature_names = ('label', 'query', 'doc_completedQuery',
                         'usr_headline', 'usr_skills', 'usr_currTitles',
                         'usrId_currTitles', 'docId_completedQuery',
                         'wide_ftrs', 'weight')

        batch_size = 2
        dataset = data_fn.input_fn(input_pattern=input_files,
                                   metadata_path=None,
                                   batch_size=batch_size,
                                   mode=tf.estimator.ModeKeys.EVAL,
                                   vocab_table=vocab_table,
                                   vocab_table_for_id_ftr=vocab_table,
                                   feature_names=feature_names,
                                   CLS='[CLS]',
                                   SEP='[SEP]',
                                   PAD='[PAD]',
                                   PAD_FOR_ID_FTR='[PAD]',
                                   max_len=16,
                                   cnn_filter_window_size=1)

        # Make iterator
        iterator = dataset.make_initializable_iterator()
        batch_data = iterator.get_next()

        with tf.Session() as sess:
            sess.run(
                [tf.global_variables_initializer(),
                 tf.tables_initializer()])
            sess.run([iterator.initializer])
            batch_data_val, = sess.run([batch_data])
            features, label = batch_data_val

            # First dimension of data should be batch_size
            for ftr_name in feature_names:
                if ftr_name != 'label':
                    self.assertTrue(ftr_name in features)
                    self.assertTrue(features[ftr_name].shape[0] == batch_size)

            self.assertTrue(label['label'].shape[0] == batch_size)

            doc_completedQuery = features['doc_completedQuery']
            docId_completedQuery = features['docId_completedQuery']
            usr_currTitles = features['usr_currTitles']
            usrId_currTitles = features['usrId_currTitles']

            # vocab[PAD] == PAD_ID
            self.assertTrue(doc_completedQuery[0, 0, -1] == self.PAD_ID)
            self.assertTrue(docId_completedQuery[0, 0, -1] == self.PAD_ID)

            # vocab[CLS] == CLS_ID
            self.assertTrue(np.all(doc_completedQuery[0, 0, 0] == self.CLS_ID))
            self.assertTrue(np.all(usr_currTitles[0, 0] == self.CLS_ID))

            # No CLS in id feature
            self.assertTrue(
                np.all(docId_completedQuery[:, :, 0] != self.CLS_ID))

            # In this TFRecord file, docId_completedQuery is populated from doc_completedQuery,
            # so the doc id feature should match the doc text feature except for the
            # added CLS and SEP tokens. Verify this for the first sample.
            for text_arr, id_arr in zip(doc_completedQuery[0],
                                        docId_completedQuery[0]):
                self.assertAllEqual(text_arr[text_arr != self.PAD_ID][1:-1],
                                    id_arr[id_arr != self.PAD_ID])

            # In this TFRecord file, usrId_currTitles is populated from usr_currTitles,
            # so the usr id feature should match the usr text feature except for the
            # added CLS and SEP tokens.
            for text_arr, id_arr in zip(usr_currTitles, usrId_currTitles):
                self.assertAllEqual(text_arr[text_arr != self.PAD_ID][1:-1],
                                    id_arr[id_arr != self.PAD_ID])
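
The two assertAllEqual loops above compress a lot into one line; here is a tiny NumPy illustration of the same check, using made-up token ids (CLS=1, SEP=2, PAD=0) rather than the fixture's real values. Stripping padding from the text feature and then dropping its leading CLS and trailing SEP should reproduce the padding-stripped id feature.

import numpy as np

CLS_ID, SEP_ID, PAD_ID = 1, 2, 0  # placeholder ids, not the test fixture's values

text_arr = np.array([CLS_ID, 7, 8, 9, SEP_ID, PAD_ID, PAD_ID])  # [CLS] w1 w2 w3 [SEP] pad pad
id_arr = np.array([7, 8, 9, PAD_ID, PAD_ID])                    # same tokens, no CLS/SEP

# Both sides reduce to [7, 8, 9]: drop padding, then CLS/SEP on the text side.
np.testing.assert_array_equal(text_arr[text_arr != PAD_ID][1:-1],
                              id_arr[id_arr != PAD_ID])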