def create_prover(options: deephol_pb2.ProverOptions) -> Prover:
  """Creates a Prover object, initializing all dependencies.

  Loads the theorem database and tactics named in `options`, then builds an
  action generator. The generator is either the parameter-free Meson generator
  or a predictor-backed ActionGenerator, optionally with precomputed theorem
  embeddings. The function then calls `setup_prover` and combines everything
  into the prover selected by `options.prover`.

  Args:
    options: Prover configuration.

  Returns:
    A BFSProver when `options.prover == 'bfs'`, otherwise a NoBacktrackProver.

  Raises:
    ValueError: If `theorem_embeddings` is set and the number of embeddings
      read from that file differs from the number of theorems in the database.
  """
  theorem_database = io_util.load_theorem_database_from_file(
      str(options.path_theorem_database))
  tactics = io_util.load_tactics_from_file(
      str(options.path_tactics), str(options.path_tactics_replace))
  if options.action_generator_options.asm_meson_no_params_only:
    tf.logging.warning(
        'Note: Using Meson action generator with no parameters.')
    action_gen = action_generator.MesonActionGenerator()
  else:
    predictor = get_predictor(options)
    emb_store = None
    if options.HasField('theorem_embeddings'):
      emb_store = embedding_store.TheoremEmbeddingStore(predictor)
      emb_store.read_embeddings(str(options.theorem_embeddings))
      # Explicit check instead of `assert`: assertions are stripped under
      # `python -O`, which would let a mismatched embeddings file through.
      # A count mismatch suggests the embeddings were not computed for this
      # theorem database.
      num_embeddings = emb_store.thm_embeddings.shape[0]
      num_theorems = len(theorem_database.theorems)
      if num_embeddings != num_theorems:
        raise ValueError(
            'Theorem embeddings file "%s" contains %d embeddings, but the '
            'theorem database has %d theorems.' %
            (options.theorem_embeddings, num_embeddings, num_theorems))
    action_gen = action_generator.ActionGenerator(
        theorem_database, tactics, predictor, options.action_generator_options,
        options.model_architecture, emb_store)
  hol_wrapper = setup_prover(theorem_database)
  tf.logging.info('DeepHOL dependencies initialization complete.')
  if options.prover == 'bfs':
    return BFSProver(options, hol_wrapper, action_gen, theorem_database)
  return NoBacktrackProver(options, hol_wrapper, action_gen, theorem_database)
def _verify_prover_options(prover_options: deephol_pb2.ProverOptions) -> None:
  """Asserts some (incomplete) consistency requirements over prover_options.

  Every violation is reported through `tf.logging.fatal`. Whether that aborts
  the process depends on the logging backend; TF1's `tf.logging.fatal` only
  logs at critical level. The checks therefore avoid crashing on malformed
  input after a problem has already been reported.

  Checks performed:
    * the required path fields are set;
    * `prover` names a supported proof strategy;
    * for history-dependent architectures, `path_emb_model_prefix` is set,
      ends in a checkpoint number, and that number appears in
      `path_model_prefix`.

  Args:
    prover_options: The prover configuration to validate.
  """
  for field_name in [
      'path_tactics', 'path_tactics_replace', 'path_theorem_database',
      'path_model_prefix'
  ]:
    if not prover_options.HasField(field_name):
      tf.logging.fatal('Missing field "%s" in ProverOptions', field_name)
  if prover_options.prover not in ['nobacktrack', 'bfs']:
    tf.logging.fatal('Unsupported proof strategy: "%s"',
                     prover_options.prover)

  history_dependent = [HIST_AVG, HIST_CONV, HIST_ATT]
  if prover_options.model_architecture not in history_dependent:
    return

  if not prover_options.path_emb_model_prefix:
    tf.logging.fatal(
        'History dependent model %s requires embeddings checkpoint '
        'path_emb_model_prefix.',
        deephol_pb2.ProverOptions.ModelArchitecture.Name(
            prover_options.model_architecture))
    # Nothing to cross-check; previously this fell through to the regex below
    # and crashed with a bare StopIteration.
    return

  # Light assertions on file naming conventions for embedding consistency.
  # Embedding checkpoint number should be the end of the input file.
  match = re.search(r'\d+$', prover_options.path_emb_model_prefix)
  if match is None:
    tf.logging.fatal(
        'Embeddings checkpoint path (%s) does not end in a checkpoint number.',
        prover_options.path_emb_model_prefix)
    return
  emb_checkpoint_num = match.group(0)
  if emb_checkpoint_num not in prover_options.path_model_prefix:
    tf.logging.fatal(
        'Embeddings checkpoint number (%s) was not found '
        'in the path of predictions checkpoint (%s), indicating '
        'it was trained with different embeddings.', emb_checkpoint_num,
        prover_options.path_model_prefix)
def cache_embeddings(options: deephol_pb2.ProverOptions):
  """Computes and saves theorem embeddings when the configured file is absent.

  The function does nothing in two cases: when `options.theorem_embeddings`
  is unset, or when the file it names already exists (checked with
  `tf.gfile.Exists`). Otherwise it builds a TheoremEmbeddingStore around the
  predictor returned by `get_predictor(options)`. It then computes embeddings
  for the theorems in `options.path_theorem_database` and writes them to
  `options.theorem_embeddings`.

  Args:
    options: Prover configuration naming the embeddings file and the theorem
      database.
  """
  embeddings_path = str(options.theorem_embeddings)
  if not options.HasField('theorem_embeddings') or tf.gfile.Exists(
      embeddings_path):
    return

  tf.logging.info(
      'theorem_embeddings file "%s" does not exist, computing & saving.',
      embeddings_path)
  predictor = get_predictor(options)
  store = embedding_store.TheoremEmbeddingStore(predictor)
  database_path = str(options.path_theorem_database)
  store.compute_embeddings_for_thms_from_db_file(database_path)
  store.save_embeddings(embeddings_path)
def _sample_bfs_options(prover_options: deephol_pb2.ProverOptions):
  """Samples parameters according to the meta options.

  The function works in place on `prover_options`. For each interval set in
  `prover_options.bfs_options.meta_options`, it calls
  `_sample_from_interval` on that interval. The result overwrites the
  matching BFS option, or
  `action_generator_options.max_theorem_parameters` for the
  `max_theorem_parameters` interval. The function returns early if
  `bfs_options` or its `meta_options` is unset.

  Args:
    prover_options: Prover configuration, modified in place.
  """
  if not prover_options.HasField('bfs_options'):
    return
  bfs_options = prover_options.bfs_options
  if not bfs_options.HasField('meta_options'):
    return
  meta = bfs_options.meta_options

  # Same field order as the original one-by-one checks, so the sequence of
  # _sample_from_interval calls (presumably random draws) is unchanged.
  for field in ('max_top_suggestions', 'max_successful_branches',
                'max_explored_nodes', 'min_successful_branches'):
    if meta.HasField(field):
      setattr(bfs_options, field, _sample_from_interval(getattr(meta, field)))

  if meta.HasField('max_theorem_parameters'):
    prover_options.action_generator_options.max_theorem_parameters = (
        _sample_from_interval(meta.max_theorem_parameters))