コード例 #1
0
    def testGraphReloading(self, predictor_str):
        """Checks that restoring a second copy of the graph reproduces outputs.

        A new predictor is built from the same checkpoint used by the predictor
        returned from `self._get_predictor`. Goal embeddings, theorem
        embeddings, tactic scores and theorem scores of the two predictors are
        then compared with `assertAllEqual`.

        Args:
          predictor_str: DEFAULT or TAC_DEP; selects which predictor class and
            checkpoint to reload.

        Raises:
          ValueError: If `predictor_str` is neither DEFAULT nor TAC_DEP.
        """
        original = self._get_predictor(predictor_str)
        if predictor_str == DEFAULT:
            reloaded = holparam_predictor.HolparamPredictor(self.default_ckpt)
        elif predictor_str == TAC_DEP:
            reloaded = holparam_predictor.TacDependentPredictor(
                self.tac_dep_ckpt)
        else:
            raise ValueError('Unknown predictor string: %s' % predictor_str)

        def check_same(method_name, message, *args, **kwargs):
            # Calls the same method on both predictors (original first) and
            # requires identical results.
            self.assertAllEqual(
                getattr(original, method_name)(*args, **kwargs),
                getattr(reloaded, method_name)(*args, **kwargs),
                message)

        check_same('batch_goal_embedding',
                   'Reloading the graph should not change goal embeddings.',
                   self.formulas)
        check_same('batch_thm_embedding',
                   'Reloading the graph should not change theorem embeddings.',
                   self.formulas)
        check_same('batch_tactic_scores',
                   'Reloading the graph should not change tactic scores.',
                   self.embeddings)
        check_same('batch_thm_scores',
                   'Reloading the graph should not change theorem scores.',
                   self.embeddings[0],
                   self.embeddings,
                   tactic_id=self.tactic_id)
コード例 #2
0
    def setUpClass(cls):
        """Restores both test predictors once for the whole class.

        Graph restoration is slow, so the resolved checkpoint paths and the
        predictors built from them are stored as class attributes: a
        HolparamPredictor for the default checkpoint and a
        TacDependentPredictor for the tactic-dependent checkpoint.
        """
        default_path = test_util.test_src_dir_path(DEFAULT_TEST_PATH)
        cls.default_ckpt = default_path
        cls.default_predictions = holparam_predictor.HolparamPredictor(
            default_path)

        tac_dep_path = test_util.test_src_dir_path(TAC_DEP_TEST_PATH)
        cls.tac_dep_ckpt = tac_dep_path
        cls.tac_dep_predictions = holparam_predictor.TacDependentPredictor(
            tac_dep_path)
コード例 #3
0
    def setUpClass(cls):
        """Builds shared fixtures once, since restoring the graph is memory-heavy.

        Creates the predictor, records the PAIR_DEFAULT architecture, builds a
        theorem database holding EQ_REFL and EQ_SYM, selects the entry at
        `test_goal_index` as the test theorem and test goal, derives its
        premise set, and creates empty action-generator options.
        """
        cls.predictor = holparam_predictor.HolparamPredictor(
            PREDICTIONS_MODEL_PREFIX)
        cls.model_architecture = deephol_pb2.ProverOptions.PAIR_DEFAULT

        cls.theorem_database = proof_assistant_pb2.TheoremDatabase()
        for thm_name, thm_conclusion in (('EQ_REFL', EQ_REFL),
                                         ('EQ_SYM', EQ_SYM)):
            cls.theorem_database.theorems.add(
                name=thm_name, conclusion=thm_conclusion)

        cls.test_goal_index = 1
        theorems = cls.theorem_database.theorems
        # NOTE(review): test_theorem and test_goal are taken from the same
        # database entry; confirm this is intended.
        cls.test_theorem = theorems[cls.test_goal_index]
        cls.test_goal = theorems[cls.test_goal_index]
        cls.test_premise_set = prover_util.make_premise_set(
            cls.test_theorem, 'default')
        cls.options = deephol_pb2.ActionGeneratorOptions()
コード例 #4
0
ファイル: prover.py プロジェクト: shiyuli/deepmath
def get_predictor(
        options: deephol_pb2.ProverOptions) -> predictions.Predictions:
    """Builds the predictor that matches `options.model_architecture`.

    Args:
      options: Prover options. `model_architecture` selects the predictor
        class; `path_model_prefix`, converted with `str`, is passed to its
        constructor.

    Returns:
      A HolparamPredictor for PAIR_DEFAULT, or a TacDependentPredictor for
      PARAMETERS_CONDITIONED_ON_TAC.

    Raises:
      NotImplementedError: For the history-dependent architectures HIST_AVG,
        HIST_CONV and HIST_ATT.
      AttributeError: For any other architecture value.
    """
    arch = options.model_architecture
    arch_enum = deephol_pb2.ProverOptions
    history_archs = (arch_enum.HIST_AVG, arch_enum.HIST_CONV,
                     arch_enum.HIST_ATT)

    if arch == arch_enum.PAIR_DEFAULT:
        predictor_cls = holparam_predictor.HolparamPredictor
    elif arch == arch_enum.PARAMETERS_CONDITIONED_ON_TAC:
        predictor_cls = holparam_predictor.TacDependentPredictor
    elif arch in history_archs:
        raise NotImplementedError(
            'History-dependent model %s is not supported in the prover.' %
            arch)
    else:
        raise AttributeError(
            'Unknown model architecture in prover options: %s' % arch)

    return predictor_cls(str(options.path_model_prefix))
コード例 #5
0
 def _get_new_predictor(self):
   """Returns a newly constructed HolparamPredictor for `self.checkpoint`.

   Every call invokes the constructor again; nothing is cached here.
   """
   checkpoint_path = self.checkpoint
   return holparam_predictor.HolparamPredictor(checkpoint_path)
コード例 #6
0
  def setUpClass(cls):
    """Restores the default checkpoint once for the whole test class.

    Graph restoration is slow, so the checkpoint path and the predictor built
    from it are stored as class attributes shared by every test.
    """
    super(HolparamPredictorTest, cls).setUpClass()
    checkpoint_path = test_util.test_src_dir_path(DEFAULT_TEST_PATH)
    cls.checkpoint = checkpoint_path
    cls.predictor = holparam_predictor.HolparamPredictor(checkpoint_path)
コード例 #7
0
 def start_bundle(self):
     """Ensures a predictor exists and creates a fresh embedding store.

     Building a HolparamPredictor restores the model from `self.chkpt`, which
     is expensive. Previously this happened on every call. The predictor is now
     built only on the first call (or whenever `self.predictor` is missing or
     None) and reused afterwards. `self.chkpt` and `self.batch_size` are the
     only construction inputs, so the reused predictor is equivalent. The
     embedding store is still recreated on every call, as before.
     """
     # NOTE(review): assumes this is an Apache Beam DoFn, where start_bundle
     # runs once per bundle and one instance may process many bundles. If
     # finish_bundle/teardown releases the predictor (e.g. closes its session),
     # it must also set self.predictor = None so it gets rebuilt here.
     if getattr(self, 'predictor', None) is None:
         logging.info('Initializing the batching predictor...')
         self.predictor = holparam_predictor.HolparamPredictor(
             self.chkpt, max_embedding_batch_size=self.batch_size)
     logging.info('Initializing the embedding store...')
     self.emb_store = embedding_store.TheoremEmbeddingStore(self.predictor)