def testProofStateEmbedding(self, predictor_str):
  """Checks proof_state_embedding against a hand-built EmbProofState.

  The goal embedding must agree with predictor.goal_embedding within float
  tolerance; every remaining field of the embedded state must match exactly.
  """
  predictor = self._get_predictor(predictor_str)
  goal = 'a f bool v A y'
  state = predictions.ProofState(goal=goal)
  expected = predictions.EmbProofState(
      goal_emb=predictor.goal_embedding(goal))
  actual = predictor.proof_state_embedding(state)
  # Embeddings are float tensors, so compare them approximately.
  self.assertAllClose(expected.goal_emb, actual.goal_emb)
  # All fields after goal_emb are compared exactly.
  self.assertAllEqual(expected[1:], actual[1:])
def proof_state_from_search(self, node) -> predictions.ProofState:
  """Builds a ProofState from a proof_search_tree.ProofSearchNode.

  Only the conclusion of the node's goal is carried over, rendered as a
  string.
  """
  conclusion = node.goal.conclusion
  return predictions.ProofState(goal=str(conclusion))
def proof_state_from_search(self, node):
  """Returns a ProofState whose goal is always the constant string 'goal'.

  The `node` argument is accepted for interface compatibility but ignored.
  """
  constant_goal = 'goal'
  return predictions.ProofState(goal=constant_goal)
def step(self, node: proof_search_tree.ProofSearchNode,
         premises: proof_assistant_pb2.PremiseSet) -> List[Suggestion]:
  """Generates a list of possible ApplyTactic argument strings from a goal.

  Args:
    node: state of the proof search, starting at current goal.
    premises: Specification of the selection of premises that can be used for
      tactic parameters. Currently we are supporting only a single
      DatabaseSection.

  Returns:
    List of string arguments for HolLight.ApplyTactic function, along with
    scores (Suggestion).
  """
  assert not premises.reference_sets, ('Premise reference sets are not '
                                       'supported.')
  assert len(premises.sections) == 1, ('Premise set must have exactly one '
                                       'section.')
  # TODO(szegedy): If the premise is not specified, we want the whole
  # database to be used. Not sure if -1 or len(database.theorems) would do
  # that or not. Assertion will certainly fail before that.
  # Also we don't have checks on this use case.
  assert premises.sections[0].HasField('before_premise'), ('Premise is '
                                                          'required.')
  fp = premises.sections[0].before_premise
  thm_number = self.thm_index_by_fingerprint.get(fp)
  assert thm_number is not None
  # Sanity check: the index found for the fingerprint must point back at a
  # theorem with that same fingerprint.
  assert theorem_fingerprint.Fingerprint(
      self.theorem_database.theorems[thm_number]) == fp
  # Names of the theorems indexed before the requested premise (logged only).
  thm_names = self.thm_names[:thm_number]
  tf.logging.debug(thm_names)
  # TODO(smloos): update predictor api to accept theorems directly
  proof_state = predictions.ProofState(
      goal=str(normalization_lib.normalize(node.goal).conclusion))
  proof_state_emb = self.predictor.proof_state_embedding(proof_state)
  proof_state_enc = self.predictor.proof_state_encoding(proof_state_emb)
  tf.logging.debug(proof_state_enc)
  tactic_scores = self._compute_tactic_scores(proof_state_enc)

  # Embedding of the empty theorem, reshaped into a batch of one. Its score
  # is the baseline that decides whether a tactic may be applied without
  # any theorem arguments (see pass_no_arguments below).
  empty_emb = self.predictor.thm_embedding('')
  empty_emb_batch = np.reshape(empty_emb, [1, empty_emb.shape[0]])

  enumerated_tactics = enumerate(self.tactics)
  if self.options.asm_meson_only:
    enumerated_tactics = [
        v for v in enumerated_tactics if str(v[1].name) == 'ASM_MESON_TAC'
    ]
    assert enumerated_tactics, (
        'action generator option asm_meson_only requires ASM_MESON_TAC.')

  ranked_closest = self.compute_closest(node.goal, thm_number)
  if ranked_closest:
    tf.logging.info(
        'Cosine closest picked:\n%s', '\n'.join(
            ['%s: %.6f' % (name, score) for score, name in ranked_closest]))

  ret = []
  thm_scores = None
  # TODO(smloos): This computes parameters for all tactics. It should cut off
  # based on the prover BFS options.
  for tactic_id, tactic in enumerated_tactics:
    # Theorem scores are computed for the first tactic only, unless the model
    # conditions parameter scores on the tactic, in which case they are
    # recomputed for every tactic.
    if (thm_scores is None or self.model_architecture ==
        deephol_pb2.ProverOptions.PARAMETERS_CONDITIONED_ON_TAC):
      thm_scores = self._get_theorem_scores(proof_state_enc, thm_number,
                                            tactic_id)
      tf.logging.debug(thm_scores)
      no_params_score = self.predictor.batch_thm_scores(
          proof_state_enc, empty_emb_batch, tactic_id)[0]
      tf.logging.info('Theorem score for empty theorem: %0.2f',
                      no_params_score)

    thm_ranked = sorted(
        zip(thm_scores, self.thm_names),
        reverse=True)[:self.options.max_theorem_parameters]
    # An empty ranking (e.g. max_theorem_parameters == 0 or no theorem
    # scores) has no lowest-ranked entry to compare against; allow the tactic
    # to be applied without arguments instead of failing on thm_ranked[-1].
    pass_no_arguments = (not thm_ranked or
                         thm_ranked[-1][0] < no_params_score)
    thm_ranked = self.add_similar(thm_ranked, ranked_closest)
    tf.logging.info('thm_ranked: %s', str(thm_ranked))
    tactic_str = str(tactic.name)
    try:
      tactic_params = _compute_parameter_string(
          list(tactic.parameter_types), pass_no_arguments, thm_ranked)
      for params_str in tactic_params:
        ret.append(
            Suggestion(
                string=tactic_str + params_str,
                score=tactic_scores[tactic_id]))
    except ValueError as e:
      # A tactic whose parameters cannot be computed is skipped, not fatal.
      tf.logging.warning('Failed to compute parameters for tactic %s: %s',
                         tactic.name, str(e))
  return ret