def collect_traj_for_program(env, program, debug=False):
    '''Replay `program` on a clone of `env` and return the resulting trajectory.

    The passed-in `env` is not stepped: a clone is made, its cache is turned
    off, and every token of `program` is applied to the clone in order.

    Args:
        env: environment exposing `clone()`, `start_ob`, `de_vocab.lookup()`,
            `step()` and the `obs` / `actions` / `rewards` / `name` /
            `interpreter` attributes read below.
        program: iterable of tokens; each one is mapped through
            `env.de_vocab.lookup` before being matched against the currently
            valid actions.
        debug: when true, a 2-tuple is returned instead of a single value
            (see Returns).

    Returns:
        With `debug=False`: the `agent_factory.Traj` for the replayed
        program, or `None` if some token is not a valid action at its step.
        With `debug=True`: `(traj, None)` on success, or `(None, info)` on
        failure where `info` is the tuple `(interpreter namespace, actions
        taken so far, program, mapped id of the offending token, valid
        action ids at that step)`.
    '''
    env = env.clone()
    env.use_cache = False
    ob = env.start_ob

    for tk in program:
        valid_actions = list(ob[0].valid_indices)
        mapped_action = env.de_vocab.lookup(tk)
        try:
            # env.step() takes the *position* inside valid_indices, not the
            # vocab id itself.
            action = valid_actions.index(mapped_action)
        except ValueError:
            # list.index raises ValueError when the token is not currently
            # allowed, i.e. the program cannot be replayed on this env.
            if debug:
                return None, (env.interpreter.namespace, env.actions, program,
                              mapped_action, valid_actions)
            return None
        ob, _, _, _ = env.step(action)

    traj = agent_factory.Traj(obs=env.obs,
                              actions=env.actions,
                              rewards=env.rewards,
                              context=env.get_context(),
                              env_name=env.name,
                              answer=env.interpreter.result)
    if debug:
        return traj, None
    return traj
def _step_on_token(env, ob, token):
    '''Step `env` with the valid action whose vocab id matches `token`.

    Raises ValueError if `token` is not among `ob[0].valid_indices`.
    Returns the observation produced by the step.
    '''
    try:
        action_index = list(ob[0].valid_indices).index(
            env.de_vocab.lookup(token))
        next_ob, _, _, _ = env.step(action_index)
    except ValueError as err:
        raise ValueError('This should not happen') from err
    return next_ob


def _shuffled(items, limit=None):
    '''Return the elements of `items` in random order, at most `limit` of them.

    Indices are permuted instead of the items themselves so NumPy never tries
    to coerce the items (tuples, token lists, ...) into an ndarray.
    '''
    order = np.random.permutation(len(items))[:limit]
    return [items[i] for i in order]


def explore_sketch_programs(env, sketch, max_hyp=1000):
    '''provide a sketch, find all the executable programs fit that sketch

    A clone of `env` (with caching disabled) is expanded breadth-first.  For
    each command in `sketch` the tokens '(' and the command itself are
    forced, then every valid action is tried on a cloned environment until
    some hypothesis can only close the statement with ')'.  After each
    expansion round the hypotheses are shuffled and truncated to `max_hyp`.
    Finally '<END>' is forced on every surviving hypothesis.

    Args:
        env: starting environment; it is cloned, not stepped.
        sketch: iterable of command tokens, one per statement.
        max_hyp: maximum number of hypotheses kept after each expansion round.

    Returns:
        A randomly ordered list with `traj_to_program(traj, env.de_vocab)`
        for every finished hypothesis whose last reward equals 1.0.  The list
        is empty when no hypothesis survives (including when all of them hit
        a dead end in the middle of a statement).

    Raises:
        ValueError: if '(', a sketch command, or '<END>' is not a valid
            action at the point where it has to be applied.
    '''
    root_env = env.clone()
    root_env.use_cache = False

    all_hyps = [(root_env, root_env.start_ob)]
    for cmd in sketch:
        # first add the left bracket and the cmd head
        opened = []
        for hyp_env, ob in all_hyps:
            ob = _step_on_token(hyp_env, ob, '(')
            ob = _step_on_token(hyp_env, ob, cmd)
            opened.append((hyp_env, ob))
        all_hyps = opened

        # then for every possible action, we use a cloned env and step that
        # until the closure of this stmt
        stmt_done = False
        while not stmt_done:
            expanded = []
            for hyp_env, ob in all_hyps:
                valid_indices = list(ob[0].valid_indices)
                if not valid_indices:
                    # this is a dead end for current env
                    continue
                # Actions are positions inside valid_indices, so the bare
                # index is what gets passed to step().
                for action_index in range(len(valid_indices)):
                    child_env = hyp_env.clone()
                    child_env.use_cache = False
                    child_ob, _, _, _ = child_env.step(action_index)
                    expanded.append((child_env, child_ob))
                if valid_indices == [hyp_env.de_vocab.lookup(')')]:
                    # maximum stmt length is reached
                    stmt_done = True

            if not expanded:
                # Every hypothesis dead-ended.  Without this guard the loop
                # would spin forever because stmt_done could never be set.
                return []
            all_hyps = _shuffled(expanded, max_hyp)

    # add then end token
    finished_envs = []
    for hyp_env, ob in all_hyps:
        _step_on_token(hyp_env, ob, '<END>')
        finished_envs.append(hyp_env)

    # pruning based on the result
    explored_programs = []
    for hyp_env in finished_envs:
        if hyp_env.rewards[-1] == 1.0:
            traj = agent_factory.Traj(obs=hyp_env.obs,
                                      actions=hyp_env.actions,
                                      rewards=hyp_env.rewards,
                                      context=hyp_env.get_context(),
                                      env_name=hyp_env.name,
                                      answer=hyp_env.interpreter.result)
            explored_programs.append(traj_to_program(traj, hyp_env.de_vocab))

    return _shuffled(explored_programs)