예제 #1
0
  def test_train_and_eval(self, secondary_loss):
    """Runs `tf_ranking_libsvm.train_and_eval` end to end on a small LIBSVM file.

    Writes `LIBSVM_DATA` to a temp file, uses that file as the train,
    validation and test input, and trains for 10 steps with
    "pairwise_logistic_loss" as the primary loss plus `secondary_loss`.

    Args:
      secondary_loss: Value for the `secondary_loss` flag. When truthy it is
        also used as the name of the output subdirectory under the temp dir;
        when falsy (e.g. None) the temp dir itself is the output directory.
    """
    data_dir = tf.compat.v1.test.get_temp_dir()
    data_file = os.path.join(data_dir, "libvsvm.txt")
    # Start from a fresh data file in case an earlier run left one behind.
    if tf.io.gfile.exists(data_file):
      tf.io.gfile.remove(data_file)

    with open(data_file, "wt") as writer:
      writer.write(LIBSVM_DATA)

    # With no secondary loss this is `data_dir` itself, so removing it in the
    # cleanup below also removes the data file.
    output_dir = os.path.join(data_dir, secondary_loss or "")

    try:
      with flagsaver.flagsaver(
          train_path=data_file,
          vali_path=data_file,
          test_path=data_file,
          output_dir=output_dir,
          loss="pairwise_logistic_loss",
          secondary_loss=secondary_loss,
          num_train_steps=10,
          list_size=10,
          group_size=2,
          num_features=100):
        tf_ranking_libsvm.train_and_eval()
    finally:
      # Clean up even when training raises, so output from a failed run does
      # not linger in the shared temp directory for later runs to pick up.
      if tf.io.gfile.exists(output_dir):
        tf.io.gfile.rmtree(output_dir)
      # Only still present when `output_dir` is a subdirectory of `data_dir`.
      if tf.io.gfile.exists(data_file):
        tf.io.gfile.remove(data_file)
예제 #2
0
    def test_train_and_eval(self):
        """Runs `tf_ranking_libsvm.train_and_eval` end to end on a small LIBSVM file.

        Writes `LIBSVM_DATA` to a temp file, uses that file as the train,
        validation and test input, and trains for 10 steps with the temp
        directory itself as the output directory.

        Side effects: sets `self._data_file` and `self._output_dir`.
        """
        data_dir = tf.test.get_temp_dir()
        data_file = os.path.join(data_dir, "libvsvm.txt")
        # Start from a fresh data file in case an earlier run left one behind.
        if tf.gfile.Exists(data_file):
            tf.gfile.Remove(data_file)

        with open(data_file, "wt") as writer:
            writer.write(LIBSVM_DATA)

        self._data_file = data_file
        # The output directory is the temp directory itself, so deleting it in
        # the cleanup below also removes the data file.
        self._output_dir = data_dir

        try:
            with flagsaver.flagsaver(train_path=self._data_file,
                                     vali_path=self._data_file,
                                     test_path=self._data_file,
                                     output_dir=self._output_dir,
                                     num_train_steps=10,
                                     list_size=10,
                                     group_size=2,
                                     num_features=100):
                tf_ranking_libsvm.train_and_eval()
        finally:
            # Clean up even when training raises, so output from a failed run
            # does not linger in the temp directory for later runs to pick up.
            if tf.gfile.Exists(self._output_dir):
                tf.gfile.DeleteRecursively(self._output_dir)