def testProofStateEmbedding(self, predictor_str):
   """Checks proof_state_embedding against a directly computed goal embedding.

   The goal embedding inside the embedded state must be numerically close to
   goal_embedding() of the same goal, and every field after the first must be
   equal to the corresponding field of a freshly built EmbProofState.
   """
   pred = self._get_predictor(predictor_str)
   goal_text = 'a f bool v A y'
   proof_state = predictions.ProofState(goal=goal_text)
   want = predictions.EmbProofState(goal_emb=pred.goal_embedding(goal_text))
   got = pred.proof_state_embedding(proof_state)
   self.assertAllClose(want.goal_emb, got.goal_emb)
   self.assertAllEqual(want[1:], got[1:])
예제 #2
0
 def proof_state_from_search(self, node) -> predictions.ProofState:
   """Builds a ProofState from a proof_search_tree.ProofSearchNode.

   Only the conclusion of the node's goal is carried over, rendered with str().
   """
   conclusion_text = str(node.goal.conclusion)
   return predictions.ProofState(goal=conclusion_text)
예제 #3
0
 def proof_state_from_search(self, node):
   """Returns a fixed ProofState whose goal is the literal string 'goal'.

   The search node is not inspected.
   """
   del node  # Unused.
   return predictions.ProofState(goal='goal')
예제 #4
0
  def step(self, node: proof_search_tree.ProofSearchNode,
           premises: proof_assistant_pb2.PremiseSet) -> List[Suggestion]:
    """Generates a list of possible ApplyTactic argument strings from a goal.

    Each candidate tactic is scored on the normalized goal, and its theorem
    parameters are picked from the theorems listed before the premise named
    by `premises` (plus whatever `add_similar` contributes).

    Args:
      node: state of the proof search, starting at current goal.
      premises: Specification of the selection of premises that can be used for
        tactic parameters. Currently we are supporting only a single
        DatabaseSection, which must set `before_premise`.

    Returns:
      List of string arguments for HolLight.ApplyTactic function, along with
      scores (Suggestion). Tactics whose parameter string cannot be computed
      (ValueError) are logged and skipped.

    Raises:
      AssertionError: if `premises` uses reference sets, does not have exactly
        one section, lacks `before_premise`, names a premise whose fingerprint
        is not indexed, or the indexed theorem's fingerprint does not match.
    """
    assert not premises.reference_sets, ('Premise reference sets are not '
                                         'supported.')
    assert len(premises.sections) == 1, ('Premise set must have exactly one '
                                         'section.')
    # TODO(szegedy): If the premise is not specified, we want the whole
    # database to be used. Not sure if -1 or len(database.theorems) would do
    # that or not. Assertion will certainly fail before that.
    # Also we don't have checks on this use case.
    assert premises.sections[0].HasField('before_premise'), ('Premise is '
                                                             'required.')
    fp = premises.sections[0].before_premise
    thm_number = self.thm_index_by_fingerprint.get(fp)
    assert thm_number is not None
    assert theorem_fingerprint.Fingerprint(
        self.theorem_database.theorems[thm_number]) == fp
    # Only theorems at indices strictly below `thm_number` may be parameters.
    thm_names = self.thm_names[:thm_number]
    tf.logging.debug(thm_names)
    # TODO(smloos): update predictor api to accept theorems directly
    proof_state = predictions.ProofState(
        goal=str(normalization_lib.normalize(node.goal).conclusion))
    proof_state_emb = self.predictor.proof_state_embedding(proof_state)
    proof_state_enc = self.predictor.proof_state_encoding(proof_state_emb)
    tf.logging.debug(proof_state_enc)
    tactic_scores = self._compute_tactic_scores(proof_state_enc)

    # Embedding of the empty theorem, reshaped into a batch of one; its score
    # serves as the "no parameters" baseline below.
    empty_emb = self.predictor.thm_embedding('')
    empty_emb_batch = np.reshape(empty_emb, [1, empty_emb.shape[0]])

    enumerated_tactics = enumerate(self.tactics)
    if self.options.asm_meson_only:
      enumerated_tactics = [
          v for v in enumerated_tactics if str(v[1].name) == 'ASM_MESON_TAC'
      ]
      assert enumerated_tactics, (
          'action generator option asm_meson_only requires ASM_MESON_TAC.')

    ranked_closest = self.compute_closest(node.goal, thm_number)
    if ranked_closest:
      tf.logging.info(
          'Cosine closest picked:\n%s', '\n'.join(
              ['%s: %.6f' % (name, score) for score, name in ranked_closest]))

    ret = []
    thm_scores = None
    # TODO(smloos): This computes parameters for all tactics. It should cut off
    # based on the prover BFS options.
    for tactic_id, tactic in enumerated_tactics:
      # Theorem scores only need recomputing per tactic when the architecture
      # conditions parameters on the tactic; otherwise the first tactic's
      # scores (and baseline) are reused.
      if (thm_scores is None or self.model_architecture ==
          deephol_pb2.ProverOptions.PARAMETERS_CONDITIONED_ON_TAC):
        thm_scores = self._get_theorem_scores(proof_state_enc, thm_number,
                                              tactic_id)
        tf.logging.debug(thm_scores)
        no_params_score = self.predictor.batch_thm_scores(
            proof_state_enc, empty_emb_batch, tactic_id)[0]
        tf.logging.info('Theorem score for empty theorem: %0.2f',
                        no_params_score)

      # Pair scores with the restricted `thm_names` (zip stops at the shorter
      # sequence), so no theorem at or after `thm_number` is selected by score.
      thm_ranked = sorted(
          zip(thm_scores, thm_names),
          reverse=True)[:self.options.max_theorem_parameters]
      # Pass no arguments when even the weakest kept theorem scores below the
      # empty-theorem baseline, or when there are no candidates at all
      # (thm_number == 0 or max_theorem_parameters == 0) instead of raising
      # IndexError on the empty list.
      pass_no_arguments = (not thm_ranked or
                           thm_ranked[-1][0] < no_params_score)
      thm_ranked = self.add_similar(thm_ranked, ranked_closest)

      tf.logging.info('thm_ranked: %s', str(thm_ranked))
      tactic_str = str(tactic.name)
      try:
        tactic_params = _compute_parameter_string(
            list(tactic.parameter_types), pass_no_arguments, thm_ranked)
        for params_str in tactic_params:
          ret.append(
              Suggestion(
                  string=tactic_str + params_str,
                  score=tactic_scores[tactic_id]))
      except ValueError as e:
        tf.logging.warning('Failed to compute parameters for tactic %s: %s',
                           tactic.name, str(e))
    return ret