Ejemplo n.º 1
0
def collect_traj_for_program(env, program, debug=False):
    """Replay ``program`` token by token in a clone of ``env`` and build a trajectory.

    Args:
      env: environment to replay in. A clone is stepped, not the caller's
        instance.
      program: iterable of tokens; each is mapped to a vocabulary id with
        ``env.de_vocab.lookup``.
      debug: when True, always return a ``(result, debug_info)`` pair.

    Returns:
      On success, an ``agent_factory.Traj`` built from the replayed
      environment. If some token is not among the currently valid actions,
      the result is ``None`` instead.

      With ``debug=True`` the value is ``(traj, None)`` on success. On failure
      it is ``(None, (namespace, actions, program, mapped_action,
      valid_actions))``, describing the state where replay stopped.
    """
    env = env.clone()
    env.use_cache = False
    ob = env.start_ob

    for tk in program:
        valid_actions = list(ob[0].valid_indices)
        mapped_action = env.de_vocab.lookup(tk)
        try:
            # The environment is stepped with the token's position inside the
            # valid-action list, not with its vocabulary id.
            action = valid_actions.index(mapped_action)
        except ValueError:
            # The token is not executable in the current state.
            if debug:
                return None, (env.interpreter.namespace, env.actions, program,
                              mapped_action, valid_actions)
            return None
        ob, _, _, _ = env.step(action)

    traj = agent_factory.Traj(obs=env.obs,
                              actions=env.actions,
                              rewards=env.rewards,
                              context=env.get_context(),
                              env_name=env.name,
                              answer=env.interpreter.result)

    if debug:
        return traj, None
    return traj
Ejemplo n.º 2
0
    def explore_sketch_programs(env, sketch, max_hyp=1000):
        """Search for programs that follow ``sketch`` and earn a final reward of 1.0.

        For each command token in ``sketch``, every hypothesis first takes the
        ``(`` action and then the command action. Hypotheses are then expanded
        breadth-first: each valid action is tried in a fresh clone. After every
        expansion step, at most ``max_hyp`` randomly chosen hypotheses are kept,
        so the search is sampled rather than exhaustive.

        Expansion of a statement stops when one of two things happens: an
        expanded hypothesis had ``)`` as its only valid action, or no
        hypotheses remain. Finally ``<END>`` is taken, and every hypothesis
        whose last reward equals 1.0 is converted with ``traj_to_program``.

        Args:
          env: environment to search from. It is cloned first, so the
            caller's instance is not stepped.
          sketch: sequence of command tokens, looked up in ``env.de_vocab``.
          max_hyp: maximum number of hypotheses kept after each expansion
            step.

        Returns:
          A list of programs in random order. It is empty if no hypothesis
          ends with a reward of 1.0.

        Raises:
          ValueError: if ``(``, a sketch command, or ``<END>`` is not a valid
            action at the point where the sketch requires it.
        """
        env = env.clone()
        env.use_cache = False

        all_hyps = [(env, env.start_ob)]

        for cmd in sketch:
            # First add the left bracket and the command head.
            new_all_hyps = []
            for hyp_env, ob in all_hyps:
                try:
                    # Actions are positions in the valid-index list, not
                    # vocabulary ids.
                    action_index = list(ob[0].valid_indices).index(
                        hyp_env.de_vocab.lookup('('))
                    ob, _, _, _ = hyp_env.step(action_index)
                    action_index = list(ob[0].valid_indices).index(
                        hyp_env.de_vocab.lookup(cmd))
                    ob, _, _, _ = hyp_env.step(action_index)
                except ValueError as err:
                    raise ValueError('This should not happen') from err
                new_all_hyps.append((hyp_env, ob))
            all_hyps = new_all_hyps

            # Then, for every possible action, step a cloned env until the
            # statement is closed. Stop early if every hypothesis dead-ended:
            # otherwise stmt_done could never become True and the loop would
            # spin forever on an empty list.
            stmt_done = False
            while not stmt_done and all_hyps:
                new_all_hyps = []
                for hyp_env, ob in all_hyps:
                    valid_indices = list(ob[0].valid_indices)
                    if not valid_indices:  # dead end for this hypothesis
                        continue
                    for action_index in range(len(valid_indices)):
                        new_env = hyp_env.clone()
                        new_env.use_cache = False
                        new_ob, _, _, _ = new_env.step(action_index)
                        new_all_hyps.append((new_env, new_ob))
                    # Only ')' is valid; presumably the maximum statement
                    # length was reached.
                    # NOTE(review): one forced close ends expansion for every
                    # hypothesis, including ones still open or already closed
                    # earlier. Confirm this is intended.
                    if valid_indices == [hyp_env.de_vocab.lookup(')')]:
                        stmt_done = True
                # Permute positions rather than the list itself.
                # np.random.permutation(list_of_tuples) first converts its
                # argument to an ndarray, which fails on, or reshapes, ragged
                # object entries.
                order = np.random.permutation(len(new_all_hyps))[:max_hyp]
                all_hyps = [new_all_hyps[i] for i in order]

        # Add the end token.
        new_all_hyps = []
        for hyp_env, ob in all_hyps:
            try:
                action_index = list(ob[0].valid_indices).index(
                    hyp_env.de_vocab.lookup('<END>'))
                ob, _, _, _ = hyp_env.step(action_index)
            except ValueError as err:
                raise ValueError('This should not happen') from err
            new_all_hyps.append((hyp_env, ob))
        all_hyps = new_all_hyps

        # Keep only hypotheses whose final reward is 1.0.
        explored_programs = []
        for hyp_env, _ in all_hyps:
            if hyp_env.rewards[-1] == 1.0:
                traj = agent_factory.Traj(obs=hyp_env.obs,
                                          actions=hyp_env.actions,
                                          rewards=hyp_env.rewards,
                                          context=hyp_env.get_context(),
                                          env_name=hyp_env.name,
                                          answer=hyp_env.interpreter.result)
                explored_programs.append(
                    traj_to_program(traj, hyp_env.de_vocab))

        # Permute positions so each program keeps its original type instead of
        # being coerced into a numpy array.
        order = np.random.permutation(len(explored_programs))
        return [explored_programs[i] for i in order]