def sample(
        self,
        ptra: ProofTraceActions,
        repl: REPL,
        tries: int,
        conclusion: bool = False,
) -> typing.Optional[Action]:
    """Randomly sample a new, valid action extending `ptra`.

    Up to `tries` times: draw an action token (from CONCLUSION_TOKENS when
    `conclusion` is set, otherwise from NON_PREPARE_TOKENS), sample its left
    and right argument positions with the samplers that token needs, and
    build the action. Candidates already seen by `ptra` or rejected by
    `repl.valid` are discarded.

    Args:
        ptra: trace being extended; sampled positions index its arguments.
        repl: REPL used to check that a candidate action is valid.
        tries: maximum number of sampling attempts.
        conclusion: draw only from CONCLUSION_TOKENS.

    Returns:
        The first unseen, valid action, or None if every attempt failed.
    """
    def theorem():
        return self.sample_theorem(ptra)

    # token name -> (left sampler, right sampler). Tokens without a right
    # operand use None and get position 0 as their right argument.
    samplers = {
        'REFL': (self.sample_term, None),
        'TRANS': (theorem, theorem),
        'MK_COMB': (theorem, theorem),
        'ABS': (theorem, self.sample_term),
        'BETA': (self.sample_term, None),
        'ASSUME': (self.sample_term, None),
        'EQ_MP': (theorem, theorem),
        'DEDUCT_ANTISYM_RULE': (theorem, theorem),
        'INST': (theorem, self.sample_subst),
        'INST_TYPE': (theorem, self.sample_subst_type),
    }
    tokens = list(
        (CONCLUSION_TOKENS if conclusion else NON_PREPARE_TOKENS).values()
    )

    for _ in range(tries):
        action = random.choice(tokens)
        token = INV_PROOFTRACE_TOKENS[action]
        if token not in samplers:
            # Token with no known argument layout: skip it rather than build
            # an action from arguments sampled in a previous attempt.
            continue

        sample_left, sample_right = samplers[token]
        left = sample_left()
        right = 0 if sample_right is None else sample_right()

        a = Action.from_action(
            token,
            ptra.arguments()[left],
            ptra.arguments()[right],
        )
        if ptra.seen(a) or not repl.valid(a):
            continue
        return a

    return None
def prepare(
        self,
        ptra: ProofTraceActions,
) -> Thm:
    """Initialise this REPL from the actions of `ptra`.

    Every PREMISE action is applied to the REPL, and the TARGET action is
    turned into the theorem to prove (if several are present, the last one
    wins).

    Args:
        ptra: trace whose actions are scanned.

    Returns:
        The target theorem built from the TARGET action.

    Raises:
        ValueError: if `ptra` contains no TARGET action.
    """
    target = None
    for a in ptra.actions():
        if a.value == PROOFTRACE_TOKENS['TARGET']:
            target = Thm(
                a.index(),
                self.build_hypothesis(a.left),
                a.right.value,
            )
        if a.value == PROOFTRACE_TOKENS['PREMISE']:
            self.apply(a)

    if target is None:
        # Previously surfaced as an UnboundLocalError on `target`.
        raise ValueError("no TARGET action in ptra")
    return target
def __init__(
        self,
        config: Config,
        l_model: LModel,
        ptra: ProofTraceActions,
        repl: REPL,
        target: Thm,
) -> None:
    """Set up a beam search rooted at `ptra`.

    Runs one language-model inference on the padded trace and seeds the
    beam with a single head holding those predictions and probability 1.0,
    alongside private copies of the trace and REPL.
    """
    super(Beam, self).__init__(config, ptra, repl, target)

    self._l_model = l_model

    index, actions, arguments = self.preprocess_ptra(ptra)
    # Inference only: no gradient tracking.
    with torch.no_grad():
        prd_actions, prd_lefts, prd_rights = self._l_model.infer(
            [index], [actions], [arguments],
        )

    self._ptras = [ptra.copy()]
    self._repls = [repl.copy()]

    # Element [0][0] of each prediction (first batch entry; the meaning of
    # the inner index depends on the model's output layout — TODO confirm).
    first_actions = prd_actions[0][0].cpu()
    first_lefts = prd_lefts[0][0].cpu()
    first_rights = prd_rights[0][0].cpu()
    initial_prob = 1.0
    self._heads = [
        Head(first_actions, first_lefts, first_rights, initial_prob),
    ]
def apply(
        self,
        ptra: ProofTraceActions,
        repl: REPL,
        beta_width: int,
        head_width: int,
) -> typing.List[typing.Tuple[float, Action], ]:
    """Expand this head into its best-scoring successor actions.

    Every combination of the top action tokens (at most `beta_width`, capped
    at len(PROOFTRACE_TOKENS) - len(PREPARE_TOKENS)), the `beta_width` most
    likely left positions and the `beta_width` most likely right positions
    is turned into an action. Combinations pointing past the end of `ptra`,
    already seen by `ptra`, or rejected by `repl.valid` are dropped. The rest
    are scored by this head's value times the three probabilities.

    Returns:
        Up to `head_width` (score, action) pairs, highest score first.
    """
    n_actions = min(
        beta_width,
        len(PROOFTRACE_TOKENS) - len(PREPARE_TOKENS),
    )
    # topk yields (values, indices); values are exp() of the predictions.
    action_probs, action_ids = \
        torch.exp(self._prd_actions.cpu()).topk(n_actions)
    left_probs, left_ids = torch.exp(self._prd_lefts.cpu()).topk(beta_width)
    right_probs, right_ids = \
        torch.exp(self._prd_rights.cpu()).topk(beta_width)

    trace_len = ptra.len()
    scored = []
    for ia in range(n_actions):
        for il in range(beta_width):
            for ir in range(beta_width):
                token_id = action_ids[ia].item()
                left = left_ids[il].item()
                right = right_ids[ir].item()
                if left >= trace_len or right >= trace_len:
                    continue

                candidate = Action.from_action(
                    INV_PROOFTRACE_TOKENS[token_id + len(PREPARE_TOKENS)],
                    ptra.arguments()[left],
                    ptra.arguments()[right],
                )
                if ptra.seen(candidate) or not repl.valid(candidate):
                    continue

                score = self._value * \
                    action_probs[ia].item() * \
                    left_probs[il].item() * \
                    right_probs[ir].item()
                scored.append((score, candidate))

    return sorted(scored, key=lambda c: c[0], reverse=True)[:head_width]
def sample_theorem(
        self,
        ptra: ProofTraceActions,
):
    """Sample the position of a theorem to use as an action argument.

    Candidates are the premise positions followed by every position from
    `ptra.prepare_len()` up to `ptra.len()`. The k-th candidate (1-based, in
    that order) is drawn with probability proportional to k, so later
    positions are favoured.

    Returns:
        The sampled position, or 1 when there is no candidate.
    """
    indices = self._premises_indices + \
        list(range(ptra.prepare_len(), ptra.len()))

    # If we don't have any theorem to return (no premise, no action yet) we
    # return an invalid position (1 which is the EMPTY action).
    if not indices:
        return 1

    # Integer cumulative weights 1, 1+2, ..., n(n+1)/2. Scaling the uniform
    # draw by the exact integer total, instead of summing normalised float
    # weights, means the draw can never land past the last bucket (the old
    # float cumsum could end just below 1.0 and index out of range). It also
    # avoids recomputing the normaliser for every weight (was O(n^2)).
    cumulative = numpy.arange(1, len(indices) + 1).cumsum()
    draw = numpy.random.sample(1) * cumulative[-1]
    return indices[cumulative.searchsorted(draw)[0]]
def preprocess_ptra(
        self,
        ptra: ProofTraceActions,
) -> typing.Tuple[int, typing.List[Action], typing.List[Action], ]:
    """Pad `ptra` to the configured sequence length for model inference.

    Returns:
        (index, actions, arguments): `index` is the position of the last
        action; `actions` is a copy of the trace's actions padded with an
        EXTRACT action built on the EMPTY action; `arguments` is a copy of the
        trace's arguments padded with the EMPTY action. Both are padded up to
        `prooftrace_sequence_length`.
    """
    seq_len = self._config.get('prooftrace_sequence_length')

    actions = ptra.actions().copy()
    arguments = ptra.arguments().copy()

    index = len(actions) - 1
    assert index < seq_len

    # Position 1 is expected to hold the trace's EMPTY action.
    empty = ptra.actions()[1]
    assert empty.value == PREPARE_TOKENS['EMPTY']
    extract = Action.from_action('EXTRACT', empty, empty)

    # Multiplying by a non-positive count yields [], so full lists are kept.
    actions.extend([extract] * (seq_len - len(actions)))
    arguments.extend([empty] * (seq_len - len(arguments)))

    return index, actions, arguments
def bootstrap(
        config: Config,
        tokenizer: ProofTraceTokenizer,
        model: Model,
        ground: ProofTraceActions,
        target: Thm,
):
    """Build the root search `Node` for proving `target`.

    Starts a fresh trace that keeps only the prepare actions of `ground`,
    prepares a REPL on it, runs the model once on the padded trace and wraps
    the CPU copies of the action/left/right predictions in a root `Node`.
    """
    trace_name = 'TREE-{}-{}'.format(
        datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
        random.randint(0, 9999),
    )
    prepare_actions = [
        a for a in ground.actions() if a.value in INV_PREPARE_TOKENS
    ]
    ptra = ProofTraceActions(trace_name, prepare_actions)

    repl = REPL(tokenizer)
    repl.prepare(ptra)

    pre_trc, pre_idx = Node.prepare(
        ptra, None, config.get('prooftrace_sequence_length'),
    )
    # The value prediction (4th output) is not used for the root node.
    prd_actions, prd_lefts, prd_rights, _ = model.infer([pre_trc], [pre_idx])

    cpu = torch.device('cpu')
    return Node(
        config,
        None,
        model,
        ground,
        target,
        ptra,
        repl,
        prd_actions[0].to(cpu),
        prd_lefts[0].to(cpu),
        prd_rights[0].to(cpu),
    )
def __init__(
        self,
        config: Config,
        l_model: LModel,
        ptra: ProofTraceActions,
        repl: REPL,
        target: Thm,
) -> None:
    """Create a policy-sampling search driven by `l_model`.

    Keeps its own copies of the trace and REPL (made with their `copy()`
    methods) to work on.
    """
    super(PolicySample, self).__init__(config, ptra, repl, target)

    self._l_model = l_model
    self._ptra, self._repl = ptra.copy(), repl.copy()
def __init__(
        self,
        config: Config,
        ptra: ProofTraceActions,
        repl: REPL,
        target: Thm,
) -> None:
    """Create a random search over copies of the trace and REPL."""
    super(Random, self).__init__(config, ptra, repl, target)

    self._ptra, self._repl = ptra.copy(), repl.copy()
    # Starts as None; presumably updated as the search produces theorems.
    self._last_thm = None
    # The sampler indexes the copied trace, not the caller's.
    self._sampler = RandomSampler(self._ptra)
def __init__(
        self,
        ptra: ProofTraceActions,
) -> None:
    """Index the positions of `ptra` by the value of the action they hold.

    Builds four position lists. The term (value 7), substitution (value 4)
    and type-substitution (value 5) lists always start with position 0. The
    premise list (value 2) only considers positions >= 1.

    NOTE(review): the raw values presumably match PREPARE_TOKENS['TERM'],
    ['SUBST'], ['SUBST_TYPE'] and ['PREMISE'] — confirm against the token
    table.
    """
    actions = ptra.actions()

    terms, substs, subst_types, premises = [0], [0], [0], []
    for position in range(ptra.len()):
        value = actions[position].value
        if value == 7:
            terms.append(position)
        elif value == 4:
            substs.append(position)
        elif value == 5:
            subst_types.append(position)
        elif value == 2 and position >= 1:
            premises.append(position)

    self._term_indices = terms
    self._subst_indices = substs
    self._subst_type_indices = subst_types
    self._premises_indices = premises
def prepare(
        ptra: ProofTraceActions,
        a: Action,
        sequence_length: int,
) -> typing.Tuple[typing.List[Action], int, ]:
    """Build the padded model input for `ptra`, optionally extended by `a`.

    Args:
        ptra: trace whose actions form the prefix of the input.
        a: extra action appended after the trace, or None.
        sequence_length: length the returned action list is padded to.

    Returns:
        (trc, idx): `trc` is the trace's actions, then `a` (if given), then
        one EXTRACT action, then EMPTY padding. `idx` is the position of the
        EXTRACT action. (The annotation used to claim a list for `idx`.)
    """
    trc = ptra.actions().copy()
    if a is not None:
        trc.append(a)

    idx = len(trc)
    # NOTE(review): EXTRACT is appended even when `trc` already has
    # `sequence_length` entries, so the result can exceed that length.
    trc.append(Action.from_action('EXTRACT', None, None))

    empty = Action.from_action('EMPTY', None, None)
    trc.extend([empty] * (sequence_length - len(trc)))

    return trc, idx
def __init__(
        self,
        config: Config,
        l_model: LModel,
        v_model: VModel,
        ptra: ProofTraceActions,
        repl: REPL,
        target: Thm,
) -> None:
    """Create a particle-filter search.

    Reads the filter size and sample size from `config` and starts with
    `prooftrace_search_particle_filter_size` particles. Each particle is a
    dict holding its own copy of the trace ('ptra') and REPL ('repl').
    """
    super(ParticleFilter, self).__init__(config, ptra, repl, target)

    self._l_model = l_model
    self._v_model = v_model

    self._filter_size = config.get('prooftrace_search_particle_filter_size')
    self._sample_size = config.get(
        'prooftrace_search_particle_filter_sample_size',
    )

    particles = []
    for _ in range(self._filter_size):
        particles.append({'ptra': ptra.copy(), 'repl': repl.copy()})
    self._particles = particles
def run_once(self, ):
    """Run one search rollout and merge its result into the rollout store.

    Steps:
    - Pick a random rollout directory and load its newest `.rollout` file.
    - Replay its positive trace except for the last `gamma` actions, with
      gamma drawn from GAMMAS.
    - Let the configured search ('beam' or 'mcts') try to finish the proof.
    - If the search added at least one action, publish statistics and merge
      the outcome into a newly written rollout file. Older rollout files of
      that directory are then deleted.
    """
    info = self._wrk.fetch(self._device, False)
    if info is not None:
        self.update(info['config'])

    for m in self._model.modules():
        self._model.modules()[m].eval()

    assert os.path.isdir(self._rollout_dir)

    rdirs = [
        os.path.join(self._rollout_dir, d)
        for d in os.listdir(self._rollout_dir)
        if os.path.isdir(os.path.join(self._rollout_dir, d))
    ]
    rdir = random.choice(rdirs)
    # Rollout file names start with a timestamp, so reverse order puts the
    # newest first. The dot is escaped so only `.rollout` files match.
    rfiles = sorted([
        os.path.join(rdir, f)
        for f in os.listdir(rdir) if re.search("\\.rollout$", f)
    ], reverse=True)

    if len(rfiles) == 0:
        return

    path = rfiles[0]
    with gzip.open(path, 'rb') as f:
        base = pickle.load(f)

    gamma = random.choice(GAMMAS)

    ground = base.positive()
    name = base.name()

    ptra = ProofTraceActions(
        'BEAM-{}-{}'.format(
            datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
            random.randint(0, 9999),
        ),
        [
            ground.actions()[i] for i in range(ground.len())
            if ground.actions()[i].value in INV_PREPARE_TOKENS
        ],
        [
            ground.arguments()[i] for i in range(ground.len())
            if ground.actions()[i].value in INV_PREPARE_TOKENS
        ],
    )
    repl = REPL(self._tokenizer)
    target = repl.prepare(ptra)

    # Replay all but the last `gamma` actions of the ground proof.
    gamma = min(ground.action_len(), gamma)
    gamma_len = ground.action_len() - gamma
    offset = ground.prepare_len() + gamma_len

    for i in range(gamma_len):
        assert ground.prepare_len() + i < ground.len() - 1
        pos = ground.prepare_len() + i

        action = ground.actions()[pos]
        argument = ground.arguments()[pos]

        thm = repl.apply(action)

        action._index = thm.index()
        argument._index = thm.index()

        ptra.append(action, argument)

    search = None
    if self._config.get('prooftrace_search_type') == 'beam':
        search = Beam(self._config, self._model, ptra, repl, target)
    if self._config.get('prooftrace_search_type') == 'mcts':
        search = MCTS(self._config, self._model, ptra, repl, target)
    assert search is not None

    Log.out(
        "ROLLOUT START", {
            'name': name,
            'gamma': gamma,
            'prepare_length': ground.prepare_len(),
            'action_length': ground.action_len(),
            'length': ground.len(),
        })
    Log.out("TARGET", {
        'name': name,
        'summary': ground.summary(offset),
    })

    rollout = None
    proven = False
    # `ptra` keeps the replayed trace until the search returns one, so the
    # statistics below stay defined even if no step runs (depth == 0).

    depth = self._config.get('prooftrace_search_depth')
    if self._config.get('prooftrace_search_type') == 'beam':
        depth = gamma * 2

    for i in range(depth):
        step_start = time.time()
        done, ptra, proven = search.step(i == (depth - 1), offset)
        step_end = time.time()
        Log.out(
            'STEP', {
                'i': i,
                'done': done,
                'proven': proven,
                'gamma': gamma,
                'time': "{:.2f}".format(step_end - step_start),
            })
        if done:
            if proven:
                rollout = Rollout(name, [ptra], [])
            else:
                rollout = Rollout(name, [], [ptra])
            break
        if (step_end - step_start) > \
                self._config.get('prooftrace_search_step_timeout'):
            rollout = Rollout(name, [], [ptra])
            break

    if rollout is None:
        # The search used its whole depth without finishing: record the
        # outcome instead of later merging None into `base`.
        if proven:
            rollout = Rollout(name, [ptra], [])
        else:
            rollout = Rollout(name, [], [ptra])

    demo_length = (ptra.len() - (ground.prepare_len() + gamma_len))

    Log.out(
        "ROLLOUT END", {
            'name': name,
            'proven': proven,
            'gamma': gamma,
            'demo_length': demo_length,
        })
    Log.out("PTRA", {
        'name': name,
        'summary': ptra.summary(offset),
    })

    if demo_length > 0:
        info = {
            'rll_cnt': 1,
            'pos_cnt': 1 if proven else 0,
            'neg_cnt': 0 if proven else 1,
        }
        if proven:
            info['demo_len'] = demo_length

        # Publish the statistics.
        self._wrk.publish(info)

        # Finally merge and store the new rollout. Write to a temporary file
        # first, then rename it into place.
        base.merge(rollout)

        now = datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f")
        # Integer bound: random.randint rejects float arguments (10e9) on
        # Python 3.12+.
        rnd = random.randint(0, 10 ** 10)
        tmp_path = os.path.join(rdir, "{}_{}.tmp".format(now, rnd))
        fnl_path = os.path.join(rdir, "{}_{}.rollout".format(now, rnd))
        with gzip.open(tmp_path, 'wb') as f:
            pickle.dump(base, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.rename(tmp_path, fnl_path)

        del base
        del rollout

        # Remove older rollout files; another worker may have removed them
        # already.
        if len(rfiles) > 1:
            for p in rfiles[1:]:
                try:
                    os.remove(p)
                except FileNotFoundError:
                    pass

        Log.out("MERGE WRITE", {
            'name': name,
            'path': fnl_path,
        })
def reset(
        self,
        gamma: float,
        fixed_gamma: int,
) -> typing.Tuple[int, typing.List[Action], typing.List[Action]]:
    """Start a new episode on a random ground trace and observe it.

    Picks a trace file whose length fits in the sequence, builds a run with
    its prepare actions and prepares a REPL on it. Then, with probability
    `gamma`, it replays a prefix of the ground proof:
    - if `fixed_gamma > 0`, between 1 and min(fixed_gamma, action_len)
      ground actions are left for the agent;
    - otherwise, a uniformly random number in [0, action_len) of actions
      is replayed.

    Returns:
        The observation of the new run (see `observation`).

    Raises:
        ValueError: if no trace file fits in the sequence length.
    """
    self._ground = None
    self._run = None
    self._repl = None
    self._target = None
    self._alpha = 0
    self._gamma_len = 0
    self._match_count = 0

    # Usable trace files encode a length (`_<len>_<n>.actions`) that fits in
    # a sequence. A uniform choice among them gives the same distribution as
    # the former rejection loop. Unlike that loop, it cannot crash on a name
    # without a length or spin forever when nothing fits.
    eligible = []
    for path in self._trace_files:
        match = re.search("_(\\d+)_(\\d+)\\.actions$", path)
        if match is not None and \
                int(match.group(1)) <= self._sequence_length:
            eligible.append(path)
    if not eligible:
        raise ValueError("no trace file fits in prooftrace_sequence_length")

    with gzip.open(random.choice(eligible), 'rb') as f:
        self._ground = pickle.load(f)

    self._run = ProofTraceActions(
        'REPL-{}-{}'.format(
            datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
            random.randint(0, 9999),
        ),
        [
            self._ground.actions()[i] for i in range(self._ground.len())
            if self._ground.actions()[i].value in INV_PREPARE_TOKENS
        ],
        [
            self._ground.arguments()[i] for i in range(self._ground.len())
            if self._ground.actions()[i].value in INV_PREPARE_TOKENS
        ],
    )

    self._repl = REPL(self._tokenizer)
    self._target = self._repl.prepare(self._run)

    # GAMMA Initialization.
    if gamma > 0.0 and random.random() < gamma:
        if fixed_gamma > 0:
            self._gamma_len = self._ground.action_len() - \
                random.randrange(
                    1, min(fixed_gamma, self._ground.action_len()) + 1
                )
        else:
            self._gamma_len = random.randrange(0, self._ground.action_len())

    for i in range(self._gamma_len):
        assert self._ground.prepare_len() + i < self._ground.len() - 1
        pos = self._ground.prepare_len() + i

        action = self._ground.actions()[pos]
        argument = self._ground.arguments()[pos]

        thm = self._repl.apply(action)

        action._index = thm.index()
        argument._index = thm.index()

        self._run.append(action, argument)

    return self.observation()
def search():
    """Command-line entry point: run a proof search on each dataset trace.

    Loads the config (with optional CLI overrides), the tokenizer and the
    language model. Then, for every train or test trace (shortest first):
    - if `prooftrace_search_fixed_gamma` is > 0, replay the ground proof
      except for its last `fixed_gamma` actions;
    - run the configured search ('beam' or 'policy_sample') on the rest,
      logging every step.

    Let the horizon be `fixed_gamma` when positive, else the ground proof
    length. The search runs for at most twice the horizon steps, and never
    past the sequence length. From step `horizon` on, it is asked for
    conclusion actions only.
    """
    parser = argparse.ArgumentParser(description="")

    parser.add_argument(
        'config_path',
        type=str,
        help="path to the config file",
    )
    parser.add_argument(
        '--dataset_size',
        type=str,
        help="config override",
    )
    parser.add_argument(
        '--load_dir',
        type=str,
        help="config override",
    )
    parser.add_argument(
        '--device',
        type=str,
        help="config override",
    )
    parser.add_argument(
        '--train',
        type=str2bool,
        help="search training set",
    )

    args = parser.parse_args()

    config = Config.from_file(args.config_path)

    if args.device is not None:
        config.override('device', args.device)
    if args.dataset_size is not None:
        config.override(
            'prooftrace_dataset_size',
            args.dataset_size,
        )
    if args.load_dir is not None:
        config.override(
            'prooftrace_load_dir',
            os.path.expanduser(args.load_dir),
        )

    train = False
    if args.train is not None:
        train = args.train

    dataset_root = os.path.join(
        os.path.expanduser(config.get('prooftrace_dataset_dir')),
        config.get('prooftrace_dataset_size'),
    )
    dataset_dir = os.path.join(
        dataset_root, 'train_traces' if train else 'test_traces',
    )
    assert os.path.isdir(dataset_dir)

    files = [
        os.path.join(dataset_dir, f)
        for f in os.listdir(dataset_dir)
        if os.path.isfile(os.path.join(dataset_dir, f))
    ]

    with gzip.open(
            os.path.join(dataset_root, 'traces.tokenizer'), 'rb',
    ) as f:
        tokenizer = pickle.load(f)

    cases = []
    for p in files:
        match = re.search("_(\\d+)_(\\d+)\\.actions$", p)
        if match is None:
            continue
        ptra_len = int(match.group(1))
        cases.append((p, ptra_len))

    Log.out("Loaded ProofTraceActions", {
        'cases': len(cases),
    })

    l_model = LModel(config).load()

    cases = sorted(cases, key=lambda c: c[1])
    fixed_gamma = config.get('prooftrace_search_fixed_gamma')

    for case_path, _ in cases:
        with gzip.open(case_path, 'rb') as f:
            ground = pickle.load(f)

        ptra = ProofTraceActions(
            'SEARCH-{}-{}'.format(
                datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
                random.randint(0, 9999),
            ),
            [
                ground.actions()[i] for i in range(ground.len())
                if ground.actions()[i].value in INV_PREPARE_TOKENS
            ],
            [
                ground.arguments()[i] for i in range(ground.len())
                if ground.actions()[i].value in INV_PREPARE_TOKENS
            ],
        )
        repl = REPL(tokenizer)
        target = repl.prepare(ptra)

        offset = 0
        horizon = ground.action_len()
        # `> 0` everywhere: the depth logic used `!= 0` while the replay used
        # `> 0`, so a negative value produced a zero-step search.
        if fixed_gamma > 0:
            horizon = fixed_gamma
            gamma_len = max(ground.action_len() - fixed_gamma, 0)
            offset = ground.prepare_len() + gamma_len

            for j in range(gamma_len):
                assert ground.prepare_len() + j < ground.len() - 1
                pos = ground.prepare_len() + j

                action = ground.actions()[pos]
                argument = ground.arguments()[pos]

                thm = repl.apply(action)

                action._index = thm.index()
                argument._index = thm.index()

                ptra.append(action, argument)

        Log.out(
            "TARGET", {
                'name': ground.name(),
                'prepare_length': ground.prepare_len(),
                'action_length': ground.action_len(),
                'summary': ground.summary(offset),
                'theorem': target.thm_string(False, True),
            })

        search = None
        if config.get('prooftrace_search_type') == 'beam':
            search = Beam(config, l_model, ptra, repl, target)
        if config.get('prooftrace_search_type') == 'policy_sample':
            search = PolicySample(config, l_model, ptra, repl, target)
        assert search is not None

        depth = min(
            config.get('prooftrace_sequence_length') - ground.prepare_len(),
            2 * horizon,
        )

        for i in range(depth):
            # Second half of the budget: conclusion actions only. The
            # fixed-gamma branch used to test `i >= 2 * fixed_gamma`, which
            # `i` never reaches since depth <= 2 * fixed_gamma.
            conclusion = (i >= horizon)

            step_start = time.time()
            done, ptra, proved = search.step(offset, conclusion)
            step_end = time.time()
            Log.out(
                'STEP', {
                    'i': i,
                    'done': done,
                    'proved': proved,
                    'time': "{:.2f}".format(step_end - step_start),
                    'summary': ptra.summary(offset),
                })
            if done:
                if proved:
                    Log.out("DEMONSTRATED", {
                        'theorem': target.thm_string(False, True),
                    })
                break

        Log.out("FINISH", {
            'summary': ptra.summary(offset),
        })
        # NOTE(review): no 'random' search is constructed above, so this
        # branch is unreachable today (the assert fires first).
        if config.get('prooftrace_search_type') == 'random' \
                and search.last_thm() is not None:
            Log.out("GENERATED", {
                'theorem': search.last_thm().thm_string(False, True)
            })
def run_once(self, ):
    """Run one policy rollout and merge its result into the rollout store.

    Steps:
    - Load the newest `.rollout` file of a random rollout directory.
    - Run the configured search ('beam' or 'policy_sample') from the prepare
      actions of its positive trace, for at most
      min(sequence_length - prepare_len, 2 * action_len) steps.
    - If the search added at least one action, publish statistics and merge
      the outcome into a newly written rollout file. Older rollout files of
      that directory are then deleted.
    """
    info = self._wrk.fetch(self._device, False)
    if info is not None:
        self.update(info['config'])

    for m in self._model.modules():
        self._model.modules()[m].eval()

    assert os.path.isdir(self._rollout_dir)

    rdirs = [
        os.path.join(self._rollout_dir, d)
        for d in os.listdir(self._rollout_dir)
        if os.path.isdir(os.path.join(self._rollout_dir, d))
    ]
    rdir = random.choice(rdirs)
    # Rollout file names start with a timestamp, so reverse order puts the
    # newest first. The dot is escaped so only `.rollout` files match.
    rfiles = sorted([
        os.path.join(rdir, f)
        for f in os.listdir(rdir) if re.search("\\.rollout$", f)
    ], reverse=True)

    if len(rfiles) == 0:
        return

    path = rfiles[0]
    with gzip.open(path, 'rb') as f:
        base = pickle.load(f)

    ground = base.positive()
    name = base.name()

    ptra = ProofTraceActions(
        'ROLLOUT-{}-{}'.format(
            datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
            random.randint(0, 9999),
        ),
        [
            ground.actions()[i] for i in range(ground.len())
            if ground.actions()[i].value in INV_PREPARE_TOKENS
        ],
        [
            ground.arguments()[i] for i in range(ground.len())
            if ground.actions()[i].value in INV_PREPARE_TOKENS
        ],
    )
    repl = REPL(self._tokenizer)
    target = repl.prepare(ptra)

    search = None
    if self._config.get('prooftrace_search_type') == 'beam':
        search = Beam(
            self._config, self._model, ptra, repl, target,
        )
    if self._config.get('prooftrace_search_type') == 'policy_sample':
        search = PolicySample(
            self._config, self._model, ptra, repl, target,
        )
    assert search is not None

    depth = self._config.get('prooftrace_sequence_length') - \
        ground.prepare_len()
    if 2 * ground.action_len() < depth:
        depth = 2 * ground.action_len()

    Log.out(
        "ROLLOUT START", {
            'name': name,
            'prepare_length': ground.prepare_len(),
            'action_length': ground.action_len(),
            'depth': depth,
        })

    proved = False
    # `ptra` keeps the initial trace until the search returns one, so the
    # statistics below stay defined even if no step runs (depth <= 0).

    for i in range(depth):
        step_start = time.time()
        done, ptra, proved = search.step()
        step_end = time.time()
        Log.out(
            'STEP', {
                'i': i,
                'done': done,
                'proved': proved,
                'time': "{:.2f}".format(step_end - step_start),
            })
        if done:
            break
        # Hard-coded 20s step budget; the config key
        # 'prooftrace_search_step_timeout' is not used here.
        if (step_end - step_start) > 20:
            break

    if proved:
        rollout = Rollout(name, [ptra], [])
    else:
        rollout = Rollout(name, [], [ptra])

    demo_length = ptra.action_len()
    demo_delta = ptra.action_len() - ground.action_len()

    Log.out(
        "ROLLOUT END", {
            'name': name,
            'proved': proved,
            'demo_length': demo_length,
            'demo_delta': demo_delta
        })

    if proved:
        Log.out("PTRA", {
            'name': name,
            'summary': ptra.summary(),
        })

    if demo_length > 0:
        info = {
            'rll_cnt': 1,
            'pos_cnt': 1 if proved else 0,
            'neg_cnt': 0 if proved else 1,
        }
        if proved:
            info['demo_len'] = demo_length
            info['demo_dlt'] = demo_delta

        # Publish the statistics.
        self._wrk.publish(info)

        # Finally merge and store the new rollout. Write to a temporary file
        # first, then rename it into place.
        base.merge(rollout)

        now = datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f")
        # Integer bound: random.randint rejects float arguments (10e9) on
        # Python 3.12+.
        rnd = random.randint(0, 10 ** 10)
        tmp_path = os.path.join(rdir, "{}_{}.tmp".format(now, rnd))
        fnl_path = os.path.join(rdir, "{}_{}.rollout".format(now, rnd))
        with gzip.open(tmp_path, 'wb') as f:
            pickle.dump(base, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.rename(tmp_path, fnl_path)

        del base
        del rollout

        # Remove older rollout files; another worker may have removed them
        # already.
        if len(rfiles) > 1:
            for p in rfiles[1:]:
                try:
                    os.remove(p)
                except FileNotFoundError:
                    pass

        Log.out("MERGE WRITE", {
            'name': name,
            'path': fnl_path,
        })
class Env:
    """Proof-search environment built on ground-truth proof traces.

    - `reset` starts an episode from a random ground trace, optionally
      replaying a prefix of its proof.
    - `step` applies one agent action through the REPL and returns
      (observation, rewards, done, info).
    - `explore` turns model predictions into an action, optionally using the
      alpha/beta oracles.
    """

    def __init__(
            self,
            config: Config,
            test: bool,
    ) -> None:
        """Load the trace file list and tokenizer for the train/test split.

        Args:
            config: source of sequence length, device and dataset location.
            test: use `test_traces` instead of `train_traces`.
        """
        self._sequence_length = config.get('prooftrace_sequence_length')
        self._device = torch.device(config.get('device'))

        dataset_root = os.path.join(
            os.path.expanduser(config.get('prooftrace_dataset_dir')),
            config.get('prooftrace_dataset_size'),
        )
        if test:
            dataset_dir = os.path.join(dataset_root, 'test_traces')
        else:
            dataset_dir = os.path.join(dataset_root, 'train_traces')
        assert os.path.isdir(dataset_dir)

        self._trace_files = [
            os.path.join(dataset_dir, f)
            for f in os.listdir(dataset_dir)
            if (os.path.isfile(os.path.join(dataset_dir, f)) and
                re.search("\\.actions$", f) is not None)
        ]
        # Files whose name encodes a trace length (`_<len>_<n>.actions`) that
        # fits in a sequence; these are the only ones `reset` can use.
        self._eligible_trace_files = []
        for path in self._trace_files:
            match = re.search("_(\\d+)_(\\d+)\\.actions$", path)
            if match is not None and \
                    int(match.group(1)) <= self._sequence_length:
                self._eligible_trace_files.append(path)

        with gzip.open(
                os.path.join(dataset_root, 'traces.tokenizer'), 'rb',
        ) as f:
            self._tokenizer = pickle.load(f)

        self._ground = None
        self._run = None
        self._repl = None
        self._target = None
        self._alpha = 0
        self._gamma_len = 0
        self._match_count = 0

    def reset(
            self,
            gamma: float,
            fixed_gamma: int,
    ) -> typing.Tuple[int, typing.List[Action], typing.List[Action]]:
        """Start a new episode on a random ground trace and observe it.

        With probability `gamma`, a prefix of the ground proof is replayed:
        - if `fixed_gamma > 0`, between 1 and min(fixed_gamma, action_len)
          ground actions are left for the agent;
        - otherwise, a uniformly random number in [0, action_len) of
          actions is replayed.

        Raises:
            ValueError: if no trace file fits in the sequence length (this
                used to loop forever).
        """
        self._ground = None
        self._run = None
        self._repl = None
        self._target = None
        self._alpha = 0
        self._gamma_len = 0
        self._match_count = 0

        if not self._eligible_trace_files:
            raise ValueError(
                "no trace file fits in prooftrace_sequence_length")
        # A uniform choice among eligible files: same distribution as the
        # former rejection loop, without crashing on names lacking a length.
        with gzip.open(
                random.choice(self._eligible_trace_files), 'rb',
        ) as f:
            self._ground = pickle.load(f)

        self._run = ProofTraceActions(
            'REPL-{}-{}'.format(
                datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
                random.randint(0, 9999),
            ),
            [
                self._ground.actions()[i] for i in range(self._ground.len())
                if self._ground.actions()[i].value in INV_PREPARE_TOKENS
            ],
            [
                self._ground.arguments()[i]
                for i in range(self._ground.len())
                if self._ground.actions()[i].value in INV_PREPARE_TOKENS
            ],
        )

        self._repl = REPL(self._tokenizer)
        self._target = self._repl.prepare(self._run)

        # GAMMA Initialization.
        if gamma > 0.0 and random.random() < gamma:
            if fixed_gamma > 0:
                self._gamma_len = self._ground.action_len() - \
                    random.randrange(
                        1, min(fixed_gamma, self._ground.action_len()) + 1
                    )
            else:
                self._gamma_len = random.randrange(
                    0, self._ground.action_len())

        for i in range(self._gamma_len):
            assert self._ground.prepare_len() + i < self._ground.len() - 1
            pos = self._ground.prepare_len() + i

            action = self._ground.actions()[pos]
            argument = self._ground.arguments()[pos]

            thm = self._repl.apply(action)

            action._index = thm.index()
            argument._index = thm.index()

            self._run.append(action, argument)

        return self.observation()

    def observation(
            self,
    ) -> typing.Tuple[int, typing.List[Action], typing.List[Action], ]:
        """Return (run length, padded actions, padded arguments).

        Actions get one EXTRACT action (when there is room) and are then
        padded with EMPTY actions; arguments are padded with EMPTY. Both are
        padded to the sequence length.
        """
        actions = self._run.actions().copy()
        arguments = self._run.arguments().copy()

        # If the length already matches, this is a final observation, so no
        # EXTRACT is appended. That's fine because this observation never
        # reaches the agent.
        if len(actions) < self._sequence_length:
            actions.append(Action.from_action('EXTRACT', None, None))

        # Finally we always return actions with the same length.
        empty = Action.from_action('EMPTY', None, None)
        while len(actions) < self._sequence_length:
            actions.append(empty)
        while len(arguments) < self._sequence_length:
            arguments.append(empty)

        return (self._run.len(), actions, arguments)

    def alpha_oracle(
            self,
    ) -> typing.Tuple[typing.Optional[torch.Tensor], int]:
        """Return the next ground action that is playable from the run.

        Scans the ground actions for one the run has not seen whose left and
        right arguments the run has already seen. It is encoded as a (1, 3)
        int64 tensor on the env device: [action value - len(PREPARE_TOKENS),
        run hashes() entry of the left argument, run hashes() entry of the
        right argument]. Every call increments the oracle use count.

        Returns:
            (actions, 0), or (None, 0) when no such action exists.
        """
        self._alpha += 1

        for i in range(self._ground.prepare_len(), self._ground.len()):
            a = self._ground.actions()[i]
            if (not self._run.seen(a)) and \
                    self._run.seen(a.left) and \
                    self._run.seen(a.right):
                assert 0 <= a.value - len(PREPARE_TOKENS)
                assert a.value < len(PROOFTRACE_TOKENS)

                actions = torch.tensor([[
                    a.value - len(PREPARE_TOKENS),
                    self._run.hashes()[a.left.hash()],
                    self._run.hashes()[a.right.hash()],
                ]], dtype=torch.int64).to(self._device)

                return actions, 0

        # We may reach this point as final actions are sometimes repeated at
        # the end of prooftraces.
        return None, 0

    def beta_oracle(
            self,
            prd_actions: torch.Tensor,
            prd_lefts: torch.Tensor,
            prd_rights: torch.Tensor,
            beta_width: int,
            beta_size: int,
    ) -> typing.Tuple[torch.Tensor, int]:
        """Pick the `beta_size` best predicted actions, favouring valid ones.

        Every combination of the `beta_width` most likely action tokens,
        left positions and right positions is scored by the product of their
        probabilities (exp of the predictions). Combinations that are in
        range, unseen and accepted by the REPL get +1.0, so they rank ahead
        of the others.

        Returns:
            (actions, frame_count):
            - `actions` is an int64 tensor of [action, left, right] rows on
              the env device;
            - `frame_count` is the number of candidates checked with
              `repl.valid`.
        """
        top_actions = torch.exp(prd_actions).topk(beta_width)
        top_lefts = torch.exp(prd_lefts).topk(beta_width)
        top_rights = torch.exp(prd_rights).topk(beta_width)

        out = []
        frame_count = 0

        for ia in range(beta_width):
            for il in range(beta_width):
                for ir in range(beta_width):
                    action = top_actions[1][ia].item()
                    assert action >= 0
                    assert action < len(PROOFTRACE_TOKENS) - len(
                        PREPARE_TOKENS)
                    left = top_lefts[1][il].item()
                    right = top_rights[1][ir].item()
                    prob = top_actions[0][ia].item() * \
                        top_lefts[0][il].item() * \
                        top_rights[0][ir].item()

                    if left >= self._run.len() or right >= self._run.len():
                        out.append(([action, left, right], prob))
                        continue

                    a = Action.from_action(
                        INV_PROOFTRACE_TOKENS[action + len(PREPARE_TOKENS)],
                        self._run.arguments()[left],
                        self._run.arguments()[right],
                    )

                    if self._run.seen(a):
                        out.append(([action, left, right], prob))
                        continue

                    frame_count += 1
                    if not self._repl.valid(a):
                        out.append(([action, left, right], prob))
                        continue

                    out.append(([action, left, right], prob + 1.0))

        out = sorted(out, key=lambda o: o[1], reverse=True)

        actions = [out[i][0] for i in range(beta_size)]

        return \
            torch.tensor(actions, dtype=torch.int64).to(self._device), \
            frame_count

    def explore(
            self,
            prd_actions: torch.Tensor,
            prd_lefts: torch.Tensor,
            prd_rights: torch.Tensor,
            alpha: float,
            beta: float,
            beta_width: int,
    ) -> typing.Tuple[torch.Tensor, int]:
        """Choose the next action from model predictions.

        The first applicable option is used:
        1. With probability `alpha`, and only if the alpha oracle has not
           been called yet this episode, play the ground-truth oracle
           action (when one exists).
        2. With probability `beta`, play the best beta-oracle action.
        3. Otherwise sample the action, left and right independently from
           the categorical distributions exp(predictions).

        Returns:
            (actions, frame_count) with `actions` an [action, left, right]
            row tensor (shape (1, 3) assuming 1-D prediction vectors).
        """
        # ALPHA Oracle.
        if alpha > 0.0 and random.random() < alpha and self._alpha == 0:
            actions, frame_count = self.alpha_oracle()
            if actions is not None:
                return actions, frame_count

        # BETA Oracle.
        if beta > 0.0 and random.random() < beta:
            return self.beta_oracle(
                prd_actions, prd_lefts, prd_rights, beta_width, 1,
            )

        # Sampling.
        actions = torch.cat((
            Categorical(
                torch.exp(prd_actions)).sample().unsqueeze(0).unsqueeze(1),
            Categorical(
                torch.exp(prd_lefts)).sample().unsqueeze(0).unsqueeze(1),
            Categorical(
                torch.exp(prd_rights)).sample().unsqueeze(0).unsqueeze(1),
        ), dim=1)

        return actions, 0

    def step(
            self,
            action: typing.Tuple[int, int, int],
            step_reward_prob: float,
            match_reward_prob: float,
            gamma: float,
            fixed_gamma: int,
    ) -> typing.Tuple[
        typing.Tuple[int, typing.List[Action]],
        typing.Tuple[float, float, float],
        bool,
        typing.Dict[str, int],
    ]:
        """Apply one agent action; return (observation, rewards, done, info).

        `action` is (token - len(PREPARE_TOKENS), left position, right
        position). The episode ends with zero rewards if a position is out
        of range, the action was already seen, or the REPL rejects it.

        Otherwise the action is appended to the run and rewarded:
        - a step reward of 1.0 with probability `step_reward_prob`;
        - a match reward of 1.0, replacing the step reward, with probability
          `match_reward_prob` when the ground trace contains the action;
        - a final reward of 10.0 when the action proves the target.

        The episode also ends when the run reaches the sequence length. On
        episode end the returned observation is that of a fresh episode
        started with `gamma` and `fixed_gamma`.
        """
        assert self._ground is not None
        assert self._run is not None

        def finish(rewards, done, info):
            if done:
                observation = self.reset(gamma, fixed_gamma)
            else:
                observation = self.observation()
            return observation, rewards, done, info

        def illegal(label):
            # End the episode with no reward after an unplayable action. The
            # info dict is built before `finish` resets the episode.
            Log.out(label, self._episode_summary())
            return finish(
                (0.0, 0.0, 0.0),
                True,
                {
                    'match_count': self._match_count,
                    'run_length': self._run.action_len() - self._gamma_len,
                },
            )

        if action[1] >= self._run.len() or action[2] >= self._run.len():
            return illegal("DONE ILLEGAL[overflow]")

        action = Action.from_action(
            INV_PROOFTRACE_TOKENS[action[0] + len(PREPARE_TOKENS)],
            self._run.arguments()[action[1]],
            self._run.arguments()[action[2]],
        )

        if self._run.seen(action):
            return illegal("DONE ILLEGAL[seen]")

        try:
            thm = self._repl.apply(action)
        except (FusionException, REPLException, TypeException):
            return illegal("DONE ILLEGAL[fusion]")

        action._index = thm.index()
        argument = self._run.build_argument(
            thm.concl(),
            thm.hyp(),
            thm.index(),
        )
        self._run.append(action, argument)

        step_reward = 0.0
        match_reward = 0.0
        final_reward = 0.0
        done = False
        info = {}

        if step_reward_prob > 0.0 and random.random() < step_reward_prob:
            step_reward = 1.0

        if self._ground.seen(action):
            self._match_count += 1
            if match_reward_prob > 0.0 and \
                    random.random() < match_reward_prob:
                match_reward = 1.0
                step_reward = 0.0

        if self._target.thm_string(True) == thm.thm_string(True):
            final_reward = 10.0
            done = True
            info['demo_length'] = min(
                self._run.action_len(),
                self._ground.action_len(),
            ) - self._gamma_len
            info['demo_delta'] = \
                self._run.action_len() - self._ground.action_len()
            Log.out("DEMONSTRATED", self._episode_summary())

        if self._run.len() >= self._sequence_length:
            done = True
            Log.out("DONE LENGTH ", self._episode_summary())

        if done:
            info['match_count'] = self._match_count
            info['run_length'] = self._run.action_len() - self._gamma_len

        return finish((step_reward, match_reward, final_reward), done, info)

    def _episode_summary(self) -> typing.Dict[str, typing.Any]:
        """Return the episode description attached to end-of-episode logs."""
        return {
            'ground_length': self._ground.action_len(),
            'gamma_length': self._gamma_len,
            'run_length': self._run.action_len() - self._gamma_len,
            'name': self._ground.name(),
        }
def search():
    """Command-line entry point: run a proof search on the test traces.

    For every test trace (shortest first), replays all but the last four
    ground actions. It then runs the configured search ('beam' or 'mcts')
    on the remaining proof, stopping at the first step that reports done.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        'config_path', type=str, help="path to the config file",
    )
    parser.add_argument('--dataset_size', type=str, help="config override")
    parser.add_argument('--load_dir', type=str, help="config override")
    parser.add_argument('--device', type=str, help="config override")
    args = parser.parse_args()

    config = Config.from_file(args.config_path)
    if args.device is not None:
        config.override('device', args.device)
    if args.dataset_size is not None:
        config.override('prooftrace_dataset_size', args.dataset_size)
    if args.load_dir is not None:
        config.override(
            'prooftrace_load_dir', os.path.expanduser(args.load_dir),
        )

    root = os.path.join(
        os.path.expanduser(config.get('prooftrace_dataset_dir')),
        config.get('prooftrace_dataset_size'),
    )
    dataset_dir = os.path.join(root, 'test_traces')
    assert os.path.isdir(dataset_dir)

    files = [
        os.path.join(dataset_dir, entry)
        for entry in os.listdir(dataset_dir)
        if os.path.isfile(os.path.join(dataset_dir, entry))
    ]

    with gzip.open(os.path.join(root, 'traces.tokenizer'), 'rb') as f:
        tokenizer = pickle.load(f)

    cases = []
    for p in files:
        match = re.search("_(\\d+)_(\\d+)\\.actions$", p)
        if match is not None:
            cases.append((p, int(match.group(1))))

    Log.out("Loaded ProofTraceActions", {'cases': len(cases)})

    model = SearchModel(config).load()

    # Shortest traces first.
    cases.sort(key=lambda c: c[1])

    # Number of final ground actions left for the search to rediscover.
    fixed_gamma = 4

    for case_path, _ in cases:
        with gzip.open(case_path, 'rb') as f:
            ground = pickle.load(f)

        ptra = ProofTraceActions(
            'BEAM-{}-{}'.format(
                datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
                random.randint(0, 9999),
            ),
            [
                ground.actions()[j] for j in range(ground.len())
                if ground.actions()[j].value in INV_PREPARE_TOKENS
            ],
            [
                ground.arguments()[j] for j in range(ground.len())
                if ground.actions()[j].value in INV_PREPARE_TOKENS
            ],
        )
        repl = REPL(tokenizer)
        target = repl.prepare(ptra)

        offset = 0
        if fixed_gamma > 0:
            gamma_len = max(ground.action_len() - fixed_gamma, 0)
            offset = ground.prepare_len() + gamma_len

            for k in range(gamma_len):
                pos = ground.prepare_len() + k
                assert pos < ground.len() - 1

                replayed = ground.actions()[pos]
                replayed_arg = ground.arguments()[pos]

                thm = repl.apply(replayed)
                replayed._index = thm.index()
                replayed_arg._index = thm.index()

                ptra.append(replayed, replayed_arg)

        Log.out("TARGET", {
            'name': ground.name(),
            'prepare_length': ground.prepare_len(),
            'length': ground.action_len(),
            'summary': ground.summary(offset),
        })

        search_type = config.get('prooftrace_search_type')
        search = None
        if search_type == 'beam':
            search = Beam(config, model, ptra, repl, target)
        if search_type == 'mcts':
            search = MCTS(config, model, ptra, repl, target)
        assert search is not None

        depth = config.get('prooftrace_search_depth')
        if search_type == 'beam':
            depth = fixed_gamma * 2

        for _ in range(depth):
            done, ptra, proved = search.step(False, offset)
            if done:
                break
def run_once(self, ):
    """Evaluate the current model with beam search for every gamma in GAMMAS.

    For each gamma:
    - sample `self._test_gamma_size` test traces;
    - replay all but the last `gamma` actions of each one;
    - run a beam search for at most `gamma` steps.

    Publishes, per gamma, the proven fraction under key `gamma_<g>`. Also
    publishes `demo_len`: the summed demonstration length of proven
    searches, divided by the total number of searches.
    """
    info = self._tst.fetch(self._device, False)
    if info is not None:
        self.update(info['config'])

    self._modules['E'].eval()
    self._modules['T'].eval()
    self._modules['PH'].eval()
    self._modules['VH'].eval()

    model = BeamModel(self._config, self._modules)

    assert os.path.isdir(self._dataset_dir)
    files = [
        os.path.join(self._dataset_dir, f)
        for f in os.listdir(self._dataset_dir)
        if os.path.isfile(os.path.join(self._dataset_dir, f))
    ]

    # Test cases are trace files named `..._<len>_<n>.actions`; every gamma
    # samples from the same pool.
    traces = [
        p for p in files
        if re.search("_(\\d+)_(\\d+)\\.actions$", p) is not None
    ]
    cases = {gamma: list(traces) for gamma in GAMMAS}

    info = {
        'demo_len': 0.0,
    }
    for gamma in GAMMAS:
        # NOTE(review): raises ValueError if fewer than _test_gamma_size
        # traces are available.
        cases[gamma] = random.sample(cases[gamma], self._test_gamma_size)
        info['gamma_{}'.format(gamma)] = 0.0

    for gamma in GAMMAS:
        for c in cases[gamma]:
            with gzip.open(c, 'rb') as f:
                ground = pickle.load(f)

            ptra = ProofTraceActions(
                'BEAM-{}-{}'.format(
                    datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
                    random.randint(0, 9999),
                ),
                [
                    ground.actions()[i] for i in range(ground.len())
                    if ground.actions()[i].value in INV_PREPARE_TOKENS
                ],
                [
                    ground.arguments()[i] for i in range(ground.len())
                    if ground.actions()[i].value in INV_PREPARE_TOKENS
                ],
            )
            repl = REPL(self._tokenizer)
            target = repl.prepare(ptra)

            # Replay all but the last `gamma` ground actions.
            gamma_len = max(ground.action_len() - gamma, 0)
            offset = ground.prepare_len() + gamma_len

            for i in range(gamma_len):
                assert ground.prepare_len() + i < ground.len() - 1
                pos = ground.prepare_len() + i

                action = ground.actions()[pos]
                argument = ground.arguments()[pos]

                thm = repl.apply(action)

                action._index = thm.index()
                argument._index = thm.index()

                ptra.append(action, argument)

            Log.out(
                "TARGET", {
                    'name': ground.name(),
                    'prepare_length': ground.prepare_len(),
                    'length': ground.action_len(),
                })

            beam = Beam(self._config, model, ptra, repl, target)

            proven = False
            # `ptra` keeps the replayed trace until the beam returns one, so
            # `demo_length` stays defined even when gamma is 0.
            for i in range(gamma):
                step_start = time.time()
                done, ptra, proven = beam.step(i == (gamma - 1), offset)
                step_end = time.time()
                if done:
                    break
                if (step_end - step_start) > \
                        self._config.get('prooftrace_search_step_timeout'):
                    break

            demo_length = (ptra.len() - (ground.prepare_len() + gamma_len))

            Log.out(
                "DONE", {
                    'name': ground.name(),
                    'proven': proven,
                    'gamma': gamma,
                    'demo_length': demo_length,
                })

            if proven:
                info['gamma_{}'.format(gamma)] += \
                    1.0 / self._test_gamma_size
                info['demo_len'] += \
                    demo_length / (self._test_gamma_size * len(GAMMAS))

    self._tst.publish(info)