Exemplo n.º 1
0
    def test_train_and_eval_with_document_interactions(self):
        """Smoke-test `train_and_eval` with document interactions enabled.

        Writes ten copies of `ELWC` to a TFRecord file and points both the
        train and eval path flags at it. It then runs
        `tf_ranking_tfrecord.train_and_eval()` with
        `use_document_interactions=True` and `listwise_inference=True`.
        The test passes if the call completes without raising.
        """
        # Use a fresh per-test directory that the test framework removes.
        # The previous shared `get_temp_dir()` location was cleaned only after
        # a successful run. A `model` directory left by an earlier run that
        # raised would then be reused by the next run.
        tmp_dir = self.create_tempdir()
        data_file = os.path.join(tmp_dir, "elwc.tfrecord")

        # Every record is the same example, so serialize it only once.
        serialized_elwc = ELWC.SerializeToString()
        with tf.io.TFRecordWriter(data_file) as writer:
            for _ in range(10):
                writer.write(serialized_elwc)

        model_dir = os.path.join(tmp_dir, "model")

        with flagsaver.flagsaver(train_path=data_file,
                                 eval_path=data_file,
                                 data_format="example_list_with_context",
                                 model_dir=model_dir,
                                 num_train_steps=10,
                                 listwise_inference=True,
                                 use_document_interactions=True,
                                 group_size=1,
                                 weights_feature_name="doc_weight"):
            tf_ranking_tfrecord.train_and_eval()

    def test_train_and_eval(self, listwise_inference):
        """Smoke-test `train_and_eval` on a small ELWC TFRecord file.

        Writes ten copies of `ELWC` to a TFRecord file in a fresh temporary
        directory and uses it as both the training and evaluation data. It
        then runs `tf_ranking_tfrecord.train_and_eval()` under the given
        flags. The test passes if the call completes without raising.

        Args:
          listwise_inference: Value forwarded to the `listwise_inference`
            flag. NOTE(review): presumably supplied by a parameterized-test
            decorator outside this view; confirm.
        """
        # `create_tempdir()` returns a new, empty directory, so there is no
        # stale data file to delete before writing.
        tmp_dir = self.create_tempdir()
        data_file = os.path.join(tmp_dir, "elwc.tfrecord")

        # Every record is the same example, so serialize it only once.
        serialized_elwc = ELWC.SerializeToString()
        with tf.io.TFRecordWriter(data_file) as writer:
            for _ in range(10):
                writer.write(serialized_elwc)

        model_dir = os.path.join(tmp_dir, "model")

        with flagsaver.flagsaver(train_path=data_file,
                                 eval_path=data_file,
                                 data_format="example_list_with_context",
                                 model_dir=model_dir,
                                 num_train_steps=10,
                                 listwise_inference=listwise_inference,
                                 group_size=1,
                                 weights_feature_name="doc_weight"):
            tf_ranking_tfrecord.train_and_eval()