コード例 #1
0
ファイル: prover.py プロジェクト: shiyuli/deepmath
def create_prover(options: deephol_pb2.ProverOptions) -> Prover:
    """Creates a Prover object, initializing all dependencies.

    Loads the theorem database and tactics referenced by `options`, builds an
    action generator (either the parameter-free Meson generator or a
    predictor-backed ActionGenerator, optionally with precomputed theorem
    embeddings), sets up the prover wrapper via `setup_prover`, and then
    instantiates the prover selected by `options.prover`.

    Args:
      options: Prover configuration (paths, action generator options, model
        architecture and prover type).

    Returns:
      A BFSProver when `options.prover == 'bfs'`, otherwise a
      NoBacktrackProver.

    Raises:
      ValueError: If `options.theorem_embeddings` is set and the number of
        embedding rows read does not equal the number of theorems in the
        loaded theorem database.
    """
    theorem_database = io_util.load_theorem_database_from_file(
        str(options.path_theorem_database))
    tactics = io_util.load_tactics_from_file(str(options.path_tactics),
                                             str(options.path_tactics_replace))
    if options.action_generator_options.asm_meson_no_params_only:
        # `warning` is the non-deprecated spelling of `warn`.
        tf.logging.warning(
            'Note: Using Meson action generator with no parameters.')
        action_gen = action_generator.MesonActionGenerator()
    else:
        predictor = get_predictor(options)
        emb_store = None
        if options.HasField('theorem_embeddings'):
            emb_store = embedding_store.TheoremEmbeddingStore(predictor)
            emb_store.read_embeddings(str(options.theorem_embeddings))
            # The store is expected to hold one embedding row per database
            # theorem. Check explicitly instead of using `assert`, which is
            # stripped when Python runs with -O.
            num_embeddings = emb_store.thm_embeddings.shape[0]
            num_theorems = len(theorem_database.theorems)
            if num_embeddings != num_theorems:
                raise ValueError(
                    'Theorem embeddings file %s has %d rows, but the theorem '
                    'database has %d theorems.' %
                    (options.theorem_embeddings, num_embeddings, num_theorems))
        action_gen = action_generator.ActionGenerator(
            theorem_database, tactics, predictor,
            options.action_generator_options, options.model_architecture,
            emb_store)
    hol_wrapper = setup_prover(theorem_database)
    tf.logging.info('DeepHOL dependencies initialization complete.')
    if options.prover == 'bfs':
        return BFSProver(options, hol_wrapper, action_gen, theorem_database)
    return NoBacktrackProver(options, hol_wrapper, action_gen,
                             theorem_database)
コード例 #2
0
    def test_action_generator_theorem_list_parameter_tactic(
            self, use_embedding_store):
        """Checks a THEOREM_LIST tactic gets max_theorem_parameters theorems.

        Builds a database of 2 * max_theorem_parameters + 1 theorems and
        checks that the last suggestion starts with the tactic name. It also
        checks that this suggestion mentions 'THM' exactly
        max_theorem_parameters times.

        Args:
          use_embedding_store: True if the embedding store should be used.
        """
        max_parameters = self.options.max_theorem_parameters
        emb_store = None
        thmlist_param_tactic = deephol_pb2.Tactic(
            name='TAC', parameter_types=[deephol_pb2.Tactic.THEOREM_LIST])
        # More theorems than the parameter limit, so the limit is what
        # bounds the number of theorems appearing in the action.
        theorem_database = proof_assistant_pb2.TheoremDatabase()
        theorem_database.theorems.extend([
            proof_assistant_pb2.Theorem(name='THM%d' % i, conclusion='foo')
            for i in range(2 * max_parameters + 1)
        ])
        if use_embedding_store:
            emb_store = embedding_store.TheoremEmbeddingStore(self.predictor)
            emb_store.compute_embeddings_for_thms_from_db(theorem_database)
        action_gen = action_generator.ActionGenerator(
            theorem_database, [thmlist_param_tactic], self.predictor,
            self.options, self.model_architecture, emb_store)
        # The last theorem in the database is used to build the premise set.
        test_theorem = theorem_database.theorems[2 * max_parameters]
        actions_scores = action_gen.step(
            self.node, prover_util.make_premise_set(test_theorem, 'default'))
        self.assertStartsWith(actions_scores[-1].string, 'TAC')
        self.assertEqual(max_parameters,
                         actions_scores[-1].string.count('THM'))
コード例 #3
0
 def test_action_generator_unknown_parameter_tactic(self):
     """A tactic whose parameter type is UNKNOWN produces no suggestions."""
     tactics = [
         deephol_pb2.Tactic(name='TAC',
                            parameter_types=[deephol_pb2.Tactic.UNKNOWN])
     ]
     generator = action_generator.ActionGenerator(
         self.theorem_database, tactics, self.predictor, self.options,
         self.model_architecture)
     suggestions = generator.step(self.node, self.test_premise_set)
     self.assertEqual(0, len(suggestions))
コード例 #4
0
 def test_action_generator_no_parameter_tactic(self):
     """A parameterless tactic yields exactly one action: its bare name."""
     no_param_tactic = deephol_pb2.Tactic(name='TAC')
     action_gen = action_generator.ActionGenerator(self.theorem_database,
                                                   [no_param_tactic],
                                                   self.predictor,
                                                   self.options,
                                                   self.model_architecture)
     actions_scores = action_gen.step(self.node, self.test_premise_set)
     self.assertEqual(1, len(actions_scores))
     # (expected, actual) order so failure messages label the values right.
     self.assertEqual('TAC', actions_scores[0].string)
コード例 #5
0
 def test_action_generator_theorem_parameter_tactic(self):
     """A THEOREM tactic yields one action whose argument is theorems[0]."""
     tactic = deephol_pb2.Tactic(
         name='TAC', parameter_types=[deephol_pb2.Tactic.THEOREM])
     generator = action_generator.ActionGenerator(
         self.theorem_database, [tactic], self.predictor, self.options,
         self.model_architecture)
     suggestions = generator.step(self.node, self.test_premise_set)
     self.assertEqual(1, len(suggestions))
     theorem_arg = theorem_fingerprint.ToTacticArgument(
         self.theorem_database.theorems[0])
     self.assertEqual('TAC ' + theorem_arg, suggestions[0].string)
コード例 #6
0
    def test_action_generator_hol_light_tactics_sanity_check(self):
        """Sanity-checks the action generator on real HOL Light tactics.

        The test model was trained on these tactics. The generator should
        therefore handle them and suggest EQ_TAC among its actions. Every
        suggestion is logged in ascending score order to ease debugging.
        """
        tactics = load_tactics(HOLLIGHT_TACTICS_TEXTPB_PATH)
        generator = action_generator.ActionGenerator(
            self.theorem_database, tactics, self.predictor, self.options,
            self.model_architecture)
        suggestions = generator.step(self.node, self.test_premise_set)
        ranked = sorted(suggestions, key=lambda suggestion: suggestion.score)
        for tactic_string, tactic_score in ranked:
            tf.logging.info(str(tactic_score) + ': ' + str(tactic_string))
        suggested_strings = [tactic_string for tactic_string, _ in suggestions]
        self.assertIn('EQ_TAC', suggested_strings)