コード例 #1
0
    def init_env(self, batch_size):
        """Register the configured games as a Gym environment and create it.

        The requested ``EnvInfos`` and the per-episode step limit depend on
        ``self.config["general"]["training_method"]``. Domain randomization
        (``AlfredDemangler(shuffle=...)``) and, for DAgger, the expert plan are
        only enabled when ``self.train_eval == "train"``.

        Args:
            batch_size: batch size passed to ``textworld.gym.register_games``.

        Returns:
            The environment returned by ``gym.make`` for the registered id.

        Raises:
            NotImplementedError: if the training method is neither "dqn" nor
                "dagger".
        """
        training_method = self.config["general"]["training_method"]
        expert_type = self.config["env"]["expert_type"]
        is_training = self.train_eval == "train"

        if training_method == "dqn":
            # DQN does not use the expert plan.
            expert_plan = False
            max_nb_steps_per_episode = self.config["rl"]["training"]["max_nb_steps_per_episode"]
        elif training_method == "dagger":
            # The expert plan is only requested while training.
            expert_plan = is_training
            max_nb_steps_per_episode = self.config["dagger"]["training"]["max_nb_steps_per_episode"]
        else:
            raise NotImplementedError(
                f"Unsupported training_method {training_method!r}; expected 'dqn' or 'dagger'.")

        infos = textworld.EnvInfos(won=True, admissible_commands=True, expert_type=expert_type,
                                   expert_plan=expert_plan, extras=["gamefile"])

        # Domain randomization is only applied to training environments.
        domain_randomization = self.config["env"]["domain_randomization"]
        if not is_training:
            domain_randomization = False
        alfred_demangler = AlfredDemangler(shuffle=domain_randomization)

        # Register a new Gym environment and launch it.
        env_id = textworld.gym.register_games(self.game_files, infos,
                                              batch_size=batch_size,
                                              asynchronous=True,
                                              max_episode_steps=max_nb_steps_per_episode,
                                              wrappers=[alfred_demangler, AlfredInfos])
        return gym.make(env_id)
コード例 #2
0
def benchmark_gym(gamefile, args):
    """Measure the step throughput of a batched TextWorld Gym environment.

    ``gamefile`` is registered ``args.batch_size`` times. At every step, one
    random admissible command of the first game in the batch is sent to every
    game in the batch. When all games are done, the environment is reset.

    Args:
        gamefile: path of the game to benchmark.
        args: namespace providing ``batch_size``, ``seed``, ``max_steps`` and
            ``verbose``.

    Returns:
        The measured speed, in steps per second.
    """
    infos = textworld.EnvInfos(admissible_commands=True)
    env_id = textworld.gym.register_games([gamefile] * args.batch_size, infos,
                                          batch_size=args.batch_size)
    env = gym.make(env_id)
    print("Using {}".format(env.__class__.__name__))

    rng = np.random.RandomState(args.seed)

    try:
        obs, infos = env.reset()

        if args.verbose:
            print(obs[0])

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        for _ in range(args.max_steps):
            command = rng.choice(infos["admissible_commands"][0])
            obs, _, dones, infos = env.step([command] * args.batch_size)

            if all(dones):
                # Keep reset()'s output: the next command must be drawn from the
                # new episode, not from the terminal state of the finished one.
                obs, infos = env.reset()

            if args.verbose:
                print(obs[0])

        duration = time.perf_counter() - start_time
    finally:
        env.close()

    speed = args.max_steps / duration
    print("Done {:,} steps in {:.2f} secs ({:,.1f} steps/sec)".format(
        args.max_steps, duration, speed))
    return speed
コード例 #3
0
def ensure_gameinfo_file(gamefile, env_seed=42, save_to_file=True):
    """Return the game_info for ``gamefile``, generating it if it is not on disk yet.

    If a game_info file already exists, it is loaded with
    ``load_gameinfo_file``. Otherwise:

    1. The game is started once via ``textworld.start``, requesting
       QAIT-specific static infos.
    2. Those infos are extracted with ``_get_gameinfo_from_gamestate`` and
       extended from the game's JSON file.
    3. If ``save_to_file`` is set, the result is written to the game_info path.

    Args:
        gamefile: path to the TextWorld game file.
        env_seed: not used by this function.
        save_to_file: whether to persist a freshly generated game_info.

    Returns:
        The game_info mapping. A freshly generated one also gets a
        ``'_gamefile'`` entry. That entry is added after saving, so it is not
        in the file.
    """
    # NOTE: individual gamefiles have already been registered, so there is no
    # need for a batch env here.
    print("+++== ensure_gameinfo_file:", gamefile, "...")
    if _gameinfo_file_exists(gamefile):
        return load_gameinfo_file(gamefile)

    gameinfo_path = _gameinfo_path_from_gamefile(gamefile)
    print(f"NEED TO GENERATE game_info: '{gameinfo_path}'")
    print("CURRENT WORKING DIR:", os.getcwd())

    request_qait_infos = textworld.EnvInfos(
        # static infos, don't change during the game:
        game=True,
        verbs=True,
        # The extras are specific to the TextWorld version customized for QAIT
        # (all static). TODO: also request object_names / object_nouns / object_adjs.
        extras=["object_locations", "object_attributes", "uuid"])

    _env = textworld.start(gamefile, infos=request_qait_infos)
    try:
        game_state = _env.reset()
        # ? maybe don't filter / keep everything (including dynamic info)
        game_info = _get_gameinfo_from_gamestate(game_state)
    finally:
        # Close the environment even if extracting the infos fails.
        _env.close()
    load_additional_gameinfo_from_jsonfile(
        game_info, _gamejson_path_from_gamefile(gamefile))

    if save_to_file:
        print("+++== save_gameinfo_file:", gamefile, game_info.keys())
        serialized = _serialize_gameinfo(game_info)
        with open(gameinfo_path, "w") as infofile:
            infofile.write(serialized + '\n')
    game_info['_gamefile'] = gamefile
    return game_info
コード例 #4
0
ファイル: playthroughs.py プロジェクト: gstrazds/twagents
def start_game_for_playthrough(gamefile,
                               raw_obs_feedback=True,  # don't apply ConsistentFeedbackWrapper
                               passive_oracle_mode=False,  # if True, don't predict next action
                               max_episode_steps=MAX_PLAYTHROUGH_STEPS
                               ):
    """Start a single-game QAIT gym environment for generating a playthrough.

    Args:
        gamefile: the game to play.
        raw_obs_feedback: forwarded to ``QaitGym``.
        passive_oracle_mode: forwarded to ``QaitGym``.
        max_episode_steps: episode step limit of the batch env.

    Returns:
        A ``(env, obs, infos)`` tuple: the batch env (``batch_size=1``) and
        the observations/infos of its first ``reset()``. The word vocab is
        initialized from the ``verbs`` and ``entities`` in those infos.
    """
    vocab = WordVocab(vocab_file=QAIT_VOCAB)
    qait_gym = QaitGym(random_seed=DEFAULT_PTHRU_SEED,
                       raw_obs_feedback=raw_obs_feedback,
                       passive_oracle_mode=passive_oracle_mode)

    wanted_infos = textworld.EnvInfos(
        feedback=True,
        description=True,
        inventory=True,
        location=True,
        entities=True,
        verbs=True,
        facts=True,  # ground-truth facts about the world (this is a training oracle)
        admissible_commands=True,
        game=True,
        extras=["recipe", "uuid"],
    )
    env = qait_gym.make_batch_env(
        [gamefile],
        vocab,  # vocab not really needed by Oracle, just for gym.space
        request_infos=wanted_infos,
        batch_size=1,
        max_episode_steps=max_episode_steps,
    )

    obs, infos = env.reset()
    vocab.init_from_infos_lists(infos['verbs'], infos['entities'])
    return env, obs, infos
コード例 #5
0
ファイル: playgame.py プロジェクト: pvl/CogniTextWorldAgent
 def select_additional_infos(self):
     """Translate ``self._stats["requested_infos"]`` into a ``textworld.EnvInfos``.

     Names found in ``AVAILABLE_INFORMATION.extras`` are appended to the
     ``extras`` list. Any other name is set to ``True`` as an attribute.
     """
     selected = textworld.EnvInfos()
     for name in self._stats["requested_infos"]:
         if name not in AVAILABLE_INFORMATION.extras:
             # Regular boolean info flag.
             setattr(selected, name, True)
             continue
         selected.extras.append(name)
     return selected
コード例 #6
0
def main(args):
    """Build a TextWorld PDDL game from a problem directory and play it by hand.

    Steps:

    1. Read the PDDL domain (``args.domain``), the ``*.twl2`` grammar files,
       and ``initial_state.pddl`` / ``traj_data.json`` from ``args.problem``.
    2. Write the combined game to ``game.tw-pddl`` in that directory.
    3. Run an interactive ``HumanAgent`` loop until the episode is done.
       Typing ``ipdb`` drops into the debugger.
    """
    def _read(path):
        # Read a whole text file and close the handle immediately.
        with open(path) as fh:
            return fh.read()

    GAME_LOGIC = {
        "pddl_domain": _read(args.domain),
        "grammar": "\n".join(
            _read(f) for f in glob.glob("data/textworld_data/logic/*.twl2")),
    }

    # load state and trajectory files
    pddl_file = os.path.join(args.problem, 'initial_state.pddl')
    json_file = os.path.join(args.problem, 'traj_data.json')
    with open(json_file, 'r') as f:
        traj_data = json.load(f)
    GAME_LOGIC['grammar'] = add_task_to_grammar(GAME_LOGIC['grammar'],
                                                traj_data)

    # Dump the game file. Closing it guarantees it is fully flushed before
    # TextWorld reads it back below.
    gamedata = dict(**GAME_LOGIC, pddl_problem=_read(pddl_file))
    gamefile = os.path.join(os.path.dirname(pddl_file), 'game.tw-pddl')
    with open(gamefile, "w") as f:
        json.dump(gamedata, f)

    # register a new Gym environment.
    infos = textworld.EnvInfos(won=True, admissible_commands=True)
    env_id = textworld.gym.register_game(gamefile,
                                         infos,
                                         max_episode_steps=1000000,
                                         wrappers=[AlfredDemangler])

    # reset env
    env = gym.make(env_id)
    obs, infos = env.reset()

    # human agent
    agent = HumanAgent(True)
    agent.reset(env)

    while True:
        print(obs)
        cmd = agent.act(infos, 0, False)

        if cmd == "ipdb":
            from ipdb import set_trace
            set_trace()
            continue

        obs, score, done, infos = env.step(cmd)

        if done:
            print("You won!")
            break
コード例 #7
0
def benchmark(gamefile, args):
    """Measure the steps/sec of a TextWorld environment driven by a built-in agent.

    Args:
        gamefile: path of the game to benchmark.
        args: namespace providing:
            - ``mode``: "random", "random-cmd" or "walkthrough"
            - ``agent_seed``, ``activate_state_tracking``,
              ``compute_intermediate_reward``, ``max_steps`` and ``verbose``

    Returns:
        The measured speed, in steps per second.

    Raises:
        ValueError: if ``args.mode`` is not one of the supported modes.
    """
    infos = textworld.EnvInfos()
    if args.activate_state_tracking or args.mode == "random-cmd":
        infos.admissible_commands = True

    if args.compute_intermediate_reward:
        infos.intermediate_reward = True

    env = textworld.start(gamefile, infos)
    print("Using {}".format(env))

    try:
        if args.mode == "random":
            agent = textworld.agents.NaiveAgent()
        elif args.mode == "random-cmd":
            agent = textworld.agents.RandomCommandAgent(seed=args.agent_seed)
        elif args.mode == "walkthrough":
            agent = textworld.agents.WalkthroughAgent()
        else:
            # Previously this fell through and crashed with UnboundLocalError.
            raise ValueError(
                "Unknown mode {!r}; expected 'random', 'random-cmd' or 'walkthrough'.".format(args.mode))

        agent.reset(env)
        game_state = env.reset()

        if args.verbose:
            env.render()

        reward = 0
        done = False
        nb_resets = 1
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        for _ in range(args.max_steps):
            command = agent.act(game_state, reward, done)
            game_state, reward, done = env.step(command)

            if done:
                # Continue from the new episode's state, not the terminal one.
                game_state = env.reset()
                reward = 0
                nb_resets += 1
                done = False

            if args.verbose:
                env.render()

        duration = time.perf_counter() - start_time
    finally:
        env.close()

    speed = args.max_steps / duration
    msg = "Done {:,} steps in {:.2f} secs ({:,.1f} steps/sec) with {} resets."
    print(msg.format(args.max_steps, duration, speed, nb_resets))
    return speed
コード例 #8
0
def test_manually_defined_objective():
    """A manually assigned ``game.objective`` is reported by ``env.reset()``."""
    objective = "There's nothing much to do in here."
    M = GameMaker()

    # Create a 'bedroom' room.
    R1 = M.new_room("bedroom")
    M.set_player(R1)

    game = M.build()
    game.objective = objective

    with make_temp_directory(
            prefix="test_manually_defined_objective") as tmpdir:
        game_file = M.compile(tmpdir)

        env = textworld.start(game_file,
                              infos=textworld.EnvInfos(objective=True))
        try:
            state = env.reset()
        finally:
            # Close the environment before the temporary directory holding
            # the compiled game is removed.
            env.close()
        assert state["objective"] == objective
コード例 #9
0
    def __init__(self,
                 gamefile: str,
                 network: tf.keras.Model,
                 cpuct: Optional[float] = 0.4,
                 max_steps: int = 100,
                 temperature: float = 1.0,
                 dnoise: float = 0.5,
                 verbs: List[str] = None):
        """Create the search-tree root and a Gym environment for ``gamefile``.

        Args:
            gamefile: the single game the environment is built for.
            network: model whose ``vocab`` is stored as ``self.vocab``.
            cpuct: stored as ``self.cpuct``.
            max_steps: episode step limit passed to ``register_games``.
            temperature: stored as ``self.temperature``.
            dnoise: stored as ``self.dnoise``.
            verbs: accepted but not used by this constructor.
        """
        # The environment is built for exactly ONE game.
        self.gamefile = gamefile

        # The current node starts out as the root of the tree.
        root = Node(None, None)
        self.current = root
        self.network = network
        self.root = root

        self.cpuct = cpuct
        self.max_score = None  # replaced by infos['max_score'] after reset below
        self.max_steps = max_steps
        self.temperature = temperature
        self.vocab = network.vocab
        self.dnoise = dnoise

        wanted = textworld.EnvInfos(
            admissible_commands=True,
            description=False,
            entities=True,
            has_lost=True,
            has_won=True,
            inventory=False,
            max_score=True,
        )
        self.env = gym.make(textworld.gym.register_games(
            game_files=[gamefile],
            request_infos=wanted,
            max_episode_steps=max_steps,
        ))

        first_obs, first_infos = self.env.reset()
        # NOTE(review): 1210 looks like a hard-coded length of the intro text
        # that precedes the mission in the first observation. TODO confirm.
        intro_len = 1210
        self.mission = first_obs[intro_len:first_obs.find("=")]
        self.root.feedback = FeedbackMeta(first_obs[intro_len:])
        self.max_score = first_infos['max_score']
コード例 #10
0
def generate_bert_triplet_data(games, seed, branching_depth, output_path='data.json'):
    """Collect (state, facts) records from walkthrough replays and dump them to JSON.

    The walkthrough of each game is prefixed with "inventory" if needed. For
    every prefix length ``i`` of it, the game is replayed from a reset:

    - Each walkthrough step produces a record with step id ``(i, 0)``.
    - Then up to ``branching_depth`` random admissible actions are taken,
      with step ids ``(i, j)``, ``j >= 1``.

    Random actions never start with "look", "examine" or "inventory", except
    "examine cookbook".

    Args:
        games: iterable of game file paths.
        seed: seed for the random branching actions.
        branching_depth: maximum number of random actions after each prefix.
        output_path: where the JSON dataset is written (default ``'data.json'``).
    """
    rng = np.random.RandomState(seed)
    dataset = []
    # Ignore the following commands.
    commands_to_ignore = ["look", "examine", "inventory"]

    def _make_record(env, infos, game, step):
        # Build one record from the infos of the step just taken.
        # An extra "inventory" command is sent to get the inventory text.
        # The infos (and `done`) it returns are discarded, so `infos` still
        # describes the previous step.
        description = infos['description']
        inventory = env.step('inventory')[0]
        # NOTE(review): there is no space between the description and
        # "INVENTORY: "; kept unchanged so the dataset format stays the same.
        state = clean_game_state("DESCRIPTION: " + description + "INVENTORY: " +
                                 inventory + infos["extra.recipe"])
        local_facts = process_local_facts(infos['game'], infos['facts'])
        return {
            "game": os.path.basename(game),
            "step": step,
            "state": state,
            "facts": filter_triplets(serialize_facts(local_facts)),
        }

    for game in tqdm(games):
        request_infos = textworld.EnvInfos(admissible_commands=True,
                                           last_action=True,
                                           game=True,
                                           description=True,
                                           entities=True,
                                           facts=True,
                                           extras=["recipe", "walkthrough"])
        env_id = textworld.gym.register_game(game,
                                             request_infos,
                                             max_episode_steps=10000)
        env = gym.make(env_id)
        try:
            _, infos = env.reset()
            walkthrough = infos["extra.walkthrough"]
            # Make sure we start with listing the inventory.
            if not walkthrough or walkthrough[0] != "inventory":
                walkthrough = ["inventory"] + walkthrough

            done = False
            for i in range(len(walkthrough) + 1):
                # Restart from scratch. The reset observation (banner and
                # objective text) is not used.
                _, infos = env.reset()

                # Follow the walkthrough for a bit.
                for cmd in walkthrough[:i]:
                    _, _, done, infos = env.step(cmd)
                    dataset.append(_make_record(env, infos, game, (i, 0)))

                if done:
                    break  # Stop collecting data if game is done.

                if i == 0:
                    continue  # No random commands before 'inventory'

                # Then, take N random actions.
                for j in range(branching_depth):
                    candidates = [
                        c for c in infos["admissible_commands"]
                        if c == "examine cookbook" or c.split()[0] not in commands_to_ignore
                    ]
                    if not candidates:
                        break  # rng.choice() raises ValueError on an empty list.
                    _, _, done, infos = env.step(rng.choice(candidates))
                    if done:
                        break  # Stop collecting data if game is done.
                    dataset.append(_make_record(env, infos, game, (i, j + 1)))
        finally:
            env.close()

    with open(output_path, 'w') as fp:
        json.dump(dataset, fp)
コード例 #11
0
ファイル: playgame.py プロジェクト: pvl/CogniTextWorldAgent
import textworld.gym
import time
import tqdm
from termcolor import colored


NB_EPISODES = 10
MAX_EPISODE_STEPS = 100
# 12 * 30 * 60 = 21,600, i.e. 6 hours if the unit is seconds. This was
# previously labelled "12 hours". NOTE(review): confirm the intended budget.
TIMEOUT = 12 * 30 * 60
DISPLAY_GAME = False

# Additional information an agent may request during evaluation. Each group
# is annotated with its handicap level.
AVAILABLE_INFORMATION = textworld.EnvInfos(
    # Handicap 0
    max_score=True,
    has_won=True,
    has_lost=True,
    # Handicap 1
    description=True,
    inventory=True,
    objective=True,
    # Handicap 2
    verbs=True,
    command_templates=True,
    # Handicap 3
    entities=True,
    # Handicap 4
    extras=["recipe"],
    # Handicap 5
    admissible_commands=True,
)


def display(text, type=''):
    """Presumably prints ``text`` to the console, styled according to ``type``.

    NOTE(review): only the start of this function is visible here. The
    visible lines define ``formatproba`` but never use ``text`` or ``type``
    (which shadows the builtin) and never call ``formatproba``. The rest of
    the body appears to be cut off; confirm against the full source.
    """
    def formatproba(data, color='blue'):
        """Format the 5 highest-valued ``(label, value)`` pairs of ``data``.

        Each kept pair becomes a line ``'  {value:4.2f} {label}'``. If there is
        at least one line, a ``color``-colored dashed rule two characters
        wider than the longest line is printed. Printing of the lines
        themselves is not in the visible part. The values are presumably
        probabilities, given the name; TODO confirm.
        """
        # sorted() is ascending by value; reversing yields descending order.
        data = list(reversed(sorted(data, key=lambda x: x[1])))
        data = data[:5]  # keep only the 5 highest-valued entries
        lines = []
        for k,v in data:
            lines.append('  {:4.2f} {}'.format(v,k))
        if lines:
            # Width of the rule: longest formatted line plus 2.
            maxsize = max([len(k) for k in lines])
            print(colored('-'*(maxsize+2), color))
コード例 #12
0
def generate_data(games, seed, branching_depth, output_path='data.json'):
    """Collect unique (state, admissible commands) records and dump them to JSON.

    The walkthrough of each game is prefixed with "inventory" if needed. For
    every prefix length ``i`` of it, the game is replayed from a reset.

    A record is kept for every state not seen before (deduplicated across
    all games):

    - Walkthrough steps get step id ``(i, 0)``.
    - Up to ``branching_depth`` random actions taken afterwards get
      ``(i, j)`` with ``j >= 1``.

    The stored command lists drop "look", "inventory" and "examine ..."
    (except "examine cookbook").

    Args:
        games: iterable of game file paths.
        seed: seed for the random branching actions.
        branching_depth: maximum number of random actions after each prefix.
        output_path: where the JSON dataset is written (default ``'data.json'``).
    """
    rng = np.random.RandomState(seed)
    dataset = []
    seen_states = set()
    # Ignore the following commands when picking random actions.
    commands_to_ignore = ["look", "examine", "inventory"]

    def _keep_command(command):
        # True for commands kept in the recorded admissible-command lists.
        return not ((command.startswith('examine') and command != 'examine cookbook')
                    or command == 'look' or command == 'inventory')

    def _record_if_new(infos, game, step):
        # Append a record for a state not seen before (across all games).
        state = "DESCRIPTION: " + infos['description'] + " INVENTORY: " + infos['inventory']
        state = clean_game_state(state)
        if state in seen_states:
            return
        dataset.append({
            "game": os.path.basename(game),
            "step": step,
            "state": state,
            # A filtered copy: infos['admissible_commands'] itself is left intact.
            "admissible_commands": [ac for ac in infos['admissible_commands'] if _keep_command(ac)],
        })
        seen_states.add(state)

    for game in tqdm(games):
        request_infos = textworld.EnvInfos(admissible_commands=True, last_action=True, game=True,
                                           inventory=True, description=True, entities=True,
                                           facts=True, extras=["recipe", "walkthrough", "goal"])
        env_id = textworld.gym.register_game(game, request_infos, max_episode_steps=10000)
        env = gym.make(env_id)
        try:
            _, infos = env.reset()
            walkthrough = infos["extra.walkthrough"]
            # Make sure we start with listing the inventory.
            if not walkthrough or walkthrough[0] != "inventory":
                walkthrough = ["inventory"] + walkthrough

            done = False
            for i in range(len(walkthrough) + 1):
                # Restart from scratch. The reset observation (banner and
                # objective text) is not used; states come from the infos.
                _, infos = env.reset()

                # Follow the walkthrough for a bit.
                for cmd in walkthrough[:i]:
                    _, _, done, infos = env.step(cmd)
                    _record_if_new(infos, game, (i, 0))

                if done:
                    break  # Stop collecting data if game is done.

                if i == 0:
                    continue  # No random commands before 'inventory'

                # Then, take N random actions.
                for j in range(branching_depth):
                    candidates = [c for c in infos["admissible_commands"]
                                  if c == "examine cookbook" or c.split()[0] not in commands_to_ignore]
                    if not candidates:
                        break  # rng.choice() raises ValueError on an empty list.
                    _, _, done, infos = env.step(rng.choice(candidates))
                    if done:
                        break  # Stop collecting data if game is done.
                    # j + 1 keeps random-action ids distinct from the (i, 0) walkthrough ids.
                    _record_if_new(infos, game, (i, j + 1))
        finally:
            env.close()

    with open(output_path, 'w') as fp:
        json.dump(dataset, fp)
コード例 #13
0
 def infos_to_request(self) -> textworld.EnvInfos:
     """Return the infos to request from the environment: admissible commands only."""
     requested = textworld.EnvInfos(admissible_commands=True)
     return requested
コード例 #14
0
 def select_additional_infos(self):
     """Wrap the reply of ``self._send("select_additional_infos")`` in an ``EnvInfos``.

     The reply is expected to be a mapping of ``EnvInfos`` keyword
     arguments. What ``_send`` talks to is not visible here.
     """
     reply = self._send("select_additional_infos")
     return textworld.EnvInfos(**reply)
コード例 #15
0
ファイル: train.py プロジェクト: xingdi-eric-yuan/qait_public
from textworld.gym import register_game, make_batch2
from agent import Agent
import generic
import reward_helper
import game_generator
import evaluate
from query import process_facts

# Information requested from every game. The location_*/object_* flags and
# the extras presumably require the QAIT-customized TextWorld build; confirm
# before upgrading textworld.
request_infos = textworld.EnvInfos(
    # text observations
    description=True, inventory=True, last_action=True,
    # vocabulary and entity names
    verbs=True,
    location_names=True, location_nouns=True, location_adjs=True,
    object_names=True, object_nouns=True, object_adjs=True,
    # ground-truth world state
    facts=True, game=True,
    admissible_commands=True,
    extras=["object_locations", "object_attributes", "uuid"],
)


def train(data_path):

    time_1 = datetime.datetime.now()
    agent = Agent()

    # visdom
コード例 #16
0
def collect_data_from_game(gamefile):
    """Replay ``gamefile``'s walkthrough and record knowledge-graph snapshots.

    The walkthrough is prefixed with "inventory" if needed, then with
    "restart". At every walkthrough step, a record with step id ``(i, 0)``
    holds the local, seen and full fact graphs.

    Then every admissible command is tried on a copy of the environment and
    recorded with step id ``(i, j)``, ``j >= 1``. Excluded are the next
    walkthrough command and commands starting with "look"/"examine"/
    "inventory" ("examine cookbook" is kept).

    Returns:
        ``(gamefile, dataset)`` where ``dataset`` is the list of records.
    """
    # Ignore the following commands.
    commands_to_ignore = ["look", "examine", "inventory"]

    env_infos = textworld.EnvInfos(description=True,
                                   location=True,
                                   facts=True,
                                   last_action=True,
                                   admissible_commands=True,
                                   game=True,
                                   extras=["walkthrough"])
    env = textworld.start(gamefile, env_infos)

    dataset = []
    try:
        infos = env.reset()
        walkthrough = infos["extra.walkthrough"]

        # Make sure we start with listing the inventory.
        if not walkthrough or walkthrough[0] != "inventory":
            walkthrough = ["inventory"] + walkthrough

        # Add 'restart' command as a way to indicate the beginning of the game.
        walkthrough = ["restart"] + walkthrough

        done = False
        facts_seen = set()
        for i, cmd in enumerate(walkthrough):
            last_facts = facts_seen
            if i > 0:  # != "restart"
                infos, _, done = env.step(cmd)

            facts_local = process_local_obs_facts(infos["game"], infos["facts"],
                                                  infos["last_action"], cmd)
            facts_seen = process_facts(last_facts, infos["game"], infos["facts"],
                                       infos["last_action"], cmd)
            facts_full = process_fully_obs_facts(infos["game"], infos["facts"])

            dataset.append({
                "game": os.path.basename(gamefile),
                "step": (i, 0),
                "action": cmd.lower(),
                "graph_local": sorted(serialize_facts(facts_local)),
                "graph_seen": sorted(serialize_facts(facts_seen)),
                "graph_full": sorted(serialize_facts(facts_full)),
            })

            if done:
                break  # Stop collecting data if game is done.

            # Then, try all admissible commands.
            # NOTE(review): when i is the last walkthrough index this condition
            # is always False (it also guards walkthrough[i + 1] against
            # IndexError), so no alternatives are tried after the final step.
            # Confirm this is intended.
            commands = [
                c for c in infos["admissible_commands"] if
                ((c == "examine cookbook" or c.split()[0] not in commands_to_ignore
                  ) and (i + 1) != len(walkthrough) and c != walkthrough[i + 1])
            ]

            for j, cmd_ in enumerate(commands, start=1):
                env_ = env.copy()
                # Branch-local names: the exploratory step must not overwrite
                # the walkthrough's `infos` / `done`.
                infos_, _, _ = env_.step(cmd_)

                facts_local_ = process_local_obs_facts(infos_["game"],
                                                       infos_["facts"],
                                                       infos_["last_action"], cmd_)
                facts_seen_ = process_facts(facts_seen, infos_["game"],
                                            infos_["facts"], infos_["last_action"],
                                            cmd_)
                facts_full_ = process_fully_obs_facts(infos_["game"],
                                                      infos_["facts"])

                dataset.append({
                    "game": os.path.basename(gamefile),
                    "step": (i, j),
                    "action": cmd_.lower(),
                    "graph_local": sorted(serialize_facts(facts_local_)),
                    "graph_seen": sorted(serialize_facts(facts_seen_)),
                    "graph_full": sorted(serialize_facts(facts_full_)),
                })
    finally:
        env.close()

    return gamefile, dataset
コード例 #17
0
def collect_data_from_game(gamefile, seed, branching_depth):
    """Replay ``gamefile``'s walkthrough and record graph-update targets.

    The walkthrough is prefixed with "inventory" if needed, then with
    "restart". Every walkthrough step yields a record with step id ``(i, 0)``
    holding:

    - the tokenized observation and the action;
    - the sorted add/delete graph commands between the previous and new
      "seen" fact sets;
    - both fact sets.

    From a copy of the environment at that point, up to ``branching_depth``
    random commands are then taken, each recorded with step id ``(i, j)``,
    ``j >= 1``. Excluded are the next walkthrough command and commands
    starting with "look"/"examine"/"inventory" ("examine cookbook" is kept).

    Args:
        gamefile: path to the TextWorld game.
        seed: seed for the random branching commands.
        branching_depth: maximum number of random commands per branch.

    Returns:
        ``(gamefile, dataset)`` where ``dataset`` is the list of records.
    """
    # NOTE(review): the 'en' shortcut name was removed in spaCy 3; confirm
    # the pinned spaCy version still accepts it.
    tokenizer = spacy.load('en', disable=['ner', 'parser', 'tagger'])
    rng = np.random.RandomState(seed)

    # Ignore the following commands.
    commands_to_ignore = ["look", "examine", "inventory"]

    env_infos = textworld.EnvInfos(description=True,
                                   location=True,
                                   facts=True,
                                   last_action=True,
                                   admissible_commands=True,
                                   game=True,
                                   extras=["walkthrough"])
    env = textworld.start(gamefile, env_infos)
    env = textworld.envs.wrappers.Filter(env)

    dataset = []
    try:
        obs, infos = env.reset()
        walkthrough = infos["extra.walkthrough"]

        # Make sure we start with listing the inventory.
        if not walkthrough or walkthrough[0] != "inventory":
            walkthrough = ["inventory"] + walkthrough

        # Add 'restart' command as a way to indicate the beginning of the game.
        walkthrough = ["restart"] + walkthrough

        done = False
        facts_seen = set()
        for i, cmd in enumerate(walkthrough):
            last_facts = facts_seen
            if i > 0:  # != "restart"
                obs, _, done, infos = env.step(cmd)

            facts_seen = process_facts(last_facts, infos["game"], infos["facts"],
                                       infos["last_action"], cmd)

            dataset.append({
                "game": os.path.basename(gamefile),
                "step": (i, 0),
                "observation": preproc(obs, tokenizer=tokenizer),
                "previous_action": cmd.lower(),
                "target_commands": sorted(
                    gen_graph_commands(facts_seen - last_facts, cmd="add") +
                    gen_graph_commands(last_facts - facts_seen, cmd="delete")),
                "previous_graph_seen": sorted(serialize_facts(last_facts)),
                "graph_seen": sorted(serialize_facts(facts_seen)),
            })

            if done:
                break  # Stop collecting data if game is done.

            # Fork the current game & seen facts. The branch uses its own
            # obs_/infos_/done_ so the walkthrough's variables stay untouched.
            env_ = env.copy()
            facts_seen_ = facts_seen
            infos_ = infos

            # Then, take N random actions.
            for j in range(1, branching_depth + 1):
                # NOTE(review): when i is the last walkthrough index this
                # condition is always False (it also guards walkthrough[i + 1]),
                # so no branching happens after the final step. Confirm intended.
                commands = [
                    c for c in infos_["admissible_commands"]
                    if ((c == "examine cookbook"
                         or c.split()[0] not in commands_to_ignore) and
                        (i + 1) != len(walkthrough) and c != walkthrough[i + 1])
                ]

                if not commands:
                    break

                cmd_ = rng.choice(commands)
                obs_, _, done_, infos_ = env_.step(cmd_)

                if done_:
                    break  # Stop collecting data if game is done.

                last_facts_ = facts_seen_
                facts_seen_ = process_facts(last_facts_, infos_["game"],
                                            infos_["facts"], infos_["last_action"],
                                            cmd_)

                dataset.append({
                    "game": os.path.basename(gamefile),
                    "step": (i, j),
                    "observation": preproc(obs_, tokenizer=tokenizer),
                    "previous_action": cmd_.lower(),
                    "target_commands": sorted(
                        gen_graph_commands(facts_seen_ - last_facts_, cmd="add") +
                        gen_graph_commands(last_facts_ - facts_seen_, cmd="delete")),
                    "previous_graph_seen": sorted(serialize_facts(last_facts_)),
                    "graph_seen": sorted(serialize_facts(facts_seen_)),
                })
    finally:
        env.close()

    return gamefile, dataset
コード例 #18
0
def main():
    """Convert an ALFRED PDDL task to TextWorld facts and play it interactively.

    Reads the PDDL domain/problem paths from a module-level ``args`` object
    (``args.domain`` and ``args.problem``); ``args`` is not a parameter here.

    What the reachable part of this function does:
      1. Loads the PDDL task and turns its initial-state elements into
         ``Proposition`` objects (``facts``).
      2. Renames ALFRED variables/predicates to TextWorld-style ones and
         collects the variables by name (``variables``).
      3. Builds a ``GameLogic`` from the PDDL domain plus every ``*.twl2`` file
         under ``data/textworld_data/logic/``, creates a ``Game`` from the
         PDDL problem, and runs an endless interactive ``HumanAgent`` loop on a
         ``TextWorldEnv``.

    NOTE(review): the interactive ``while True`` loop never breaks. Every
    statement after it is unreachable as written, including the World/grammar
    generation and the second env loop (which uses a hard-coded, user-specific
    path). That tail is the only consumer of ``facts``/``variables`` from
    steps 1-2.
    """
    task = load_task(args.domain, args.problem)

    # Map each PDDL object name to its declared type name.
    name2type = {o.name: o.type_name for o in task.objects}

    def _atom2proposition(atom):
        # Convert one element of the PDDL initial state into a Proposition.
        # Returns None for equality atoms and for the "total-cost" fluent.
        # Any other element type also yields None, by falling off the end.
        if isinstance(atom, fast_downward.translate.pddl.conditions.Atom):
            if atom.predicate == "=":
                return None

            return Proposition(
                atom.predicate,
                [Variable(arg, name2type[arg]) for arg in atom.args])

        elif isinstance(atom,
                        fast_downward.translate.pddl.f_expression.Assign):
            if atom.fluent.symbol == "total-cost":
                return None

            # The assigned value itself becomes the predicate name. Names made
            # only of digits are mapped to "connected" in
            # _convert_proposition below.
            #name = "{}_{}".format(atom.fluent.symbol, atom.expression.value)
            name = "{}".format(atom.expression.value)
            return Proposition(
                name,
                [Variable(arg, name2type[arg]) for arg in atom.fluent.args])

    facts = [_atom2proposition(atom) for atom in task.init]
    facts = list(filter(None, facts))  # Drop the skipped (None) entries.

    def _convert_variable(variable):
        # Map an ALFRED variable to a TextWorld-typed Variable. Most cases are
        # decided by the part of the name before the first "_"; the rest fall
        # back to the PDDL type.
        # NOTE(review): the one-letter types presumably follow TextWorld's
        # knowledge-base codes (f=food, c=container, s=supporter, o=object,
        # r=room, P=player). Confirm against the loaded KB.
        # Unrecognized variables are printed and returned unchanged.
        if variable.name == "agent1":
            return Variable("P")

        elif variable.name.split("_", 1)[0] in [
                "apple", "tomato", "potato", "bread", "lettuce", "egg"
        ]:
            return Variable(variable.name, "f")

        elif variable.name.split("_", 1)[0] in [
                "garbagecan", "cabinet", "container", "fridge", "microwave",
                "sink"
        ]:
            return Variable(variable.name, "c")

        elif variable.name.split("_", 1)[0] in ["tabletop", "stoveburner"]:
            return Variable(variable.name, "s")

        elif variable.name.split("_", 1)[0] in [
                "bowl", "pot", "plate", "mug", "fork", "knife", "pan", "spoon"
        ]:
            return Variable(variable.name, "o")

        elif variable.type == "location":
            return Variable(variable.name, "r")

        elif variable.type == "receptacle":
            return Variable(variable.name, "c")

        elif variable.type in ["otype", "rtype"]:
            return variable

        print("Unknown variable:", variable)
        return variable

    # "I" is used as the holder of carried objects ("holds" -> in(obj, I)),
    # presumably the inventory. "P" is the player (agent1 maps to it).
    I = Variable("I")
    P = Variable("P")

    def _exists(name, *arguments):
        # Return True if some fact in `facts` has predicate `name` and its
        # arguments match `arguments` position by position. A None argument
        # acts as a wildcard. zip() compares only the overlapping prefix.
        for f in facts:
            if f.name != name:
                continue

            if all(v1 is None or v1 == v2
                   for v1, v2 in zip(arguments, f.arguments)):
                return True

        return False

    def _convert_proposition(proposition):
        # Translate an ALFRED predicate into a TextWorld one. Returning None
        # means the fact is dropped. Arguments are converted with
        # _convert_variable first.
        proposition = Proposition(
            proposition.name,
            [_convert_variable(a) for a in proposition.arguments])

        if proposition.name == "atlocation":
            return Proposition("at", (P, proposition.arguments[-1]))

        elif proposition.name == "receptacleatlocation":
            return Proposition("at", proposition.arguments)

        elif proposition.name == "objectatlocation":
            # NOTE(review): _exists scans the *unconverted* facts, because the
            # comprehension below runs before `facts` is rebound. Here,
            # arguments[0] has already been converted. If Variable equality
            # compares types, this test may never succeed; verify.
            # Also confirm the intent: the "at" fact is kept only when the
            # object IS inside some receptacle.
            if _exists("inreceptacle", proposition.arguments[0], None):
                return Proposition("at", proposition.arguments)
            else:
                return None

        elif proposition.name == "inreceptacle":
            # Objects on an "s"-typed receptacle are "on" it; all others are "in".
            if proposition.arguments[-1].type == "s":
                return Proposition("on", proposition.arguments)

            return Proposition("in", proposition.arguments)

        elif proposition.name == "opened":
            return Proposition("open", proposition.arguments)

        elif proposition.name == "not_opened":
            return Proposition("closed", proposition.arguments)

        elif proposition.name == "holds":
            return Proposition("in", (proposition.arguments[0], I))

        elif proposition.name in ["openable", "checked", "full"]:
            return None  # TODO: support those attributes/states.

        elif proposition.name in ["objecttype", "receptacletype"]:
            return None

        elif str.isdigit(proposition.name):
            # Digit-named facts come from numeric fluents (see _atom2proposition).
            return Proposition("connected", proposition.arguments)

        print("Unknown fact:", proposition)
        return proposition

    facts = [_convert_proposition(f) for f in facts]
    # NOTE(review): `facts` is a lazy filter object at this point. If
    # clean_alfred_facts returns an iterator rather than a list, the dict
    # comprehension below consumes it. Confirm.
    facts = filter(None, facts)
    facts = clean_alfred_facts(facts)

    # Every variable appearing in the converted facts, keyed by name.
    variables = {v.name: v for p in facts for v in p.arguments}

    # from textworld.generator.data import KnowledgeBase
    # textworld.render.visualize(State(KnowledgeBase.default().logic, facts), True)

    # Build the game logic from the PDDL domain plus all TWL2 files on disk
    # (path relative to the current working directory).
    import glob
    logic = GameLogic()
    logic.load_domain(args.domain)
    for f in glob.glob("data/textworld_data/logic/*.twl2"):
        logic.import_twl2(f)

    state = State.from_pddl(logic, args.problem)
    game = Game(state, quests=[])
    # Replace each entity's name with one derived from its id
    # (presumably a human-readable name; see _demangle_alfred_name).
    for info in game.infos.values():
        info.name = _demangle_alfred_name(info.id)

    from pprint import pprint
    from textworld.envs.tw2 import TextWorldEnv
    from textworld.agents import HumanAgent

    infos = textworld.EnvInfos(admissible_commands=True)
    env = TextWorldEnv(infos)
    env.load(game)

    agent = HumanAgent(True)
    agent.reset(env)

    # Interactive loop with two debug commands:
    #   "STATE" prints the current cleaned facts, sorted;
    #   "ipdb"  drops into the debugger.
    # Any other command is executed, and the effects it reports are printed in yellow.
    # NOTE(review): there is no `break`, so this loop only ends via an exception.
    obs = env.reset()
    while True:
        #pprint(obs)
        print(obs.feedback)
        cmd = agent.act(obs, 0, False)
        if cmd == "STATE":
            print("\n".join(sorted(map(str, clean_alfred_facts(obs._facts)))))
            continue

        elif cmd == "ipdb":
            from ipdb import set_trace
            set_trace()
            continue

        obs, _, _ = env.step(cmd)
        print(
            colored(
                "\n".join(sorted(map(str, clean_alfred_facts(obs.effects)))),
                "yellow"))

    # NOTE(review): everything below is unreachable (the loop above never
    # breaks). It is kept as scratch code for World/grammar-based generation.
    from ipdb import set_trace
    set_trace()

    options = textworld.GameOptions()
    options.path = "tw_games/test.z8"
    options.force_recompile = True

    from ipdb import set_trace
    set_trace()

    world = World.from_facts(facts, kb=options._kb)

    # Keep names and descriptions that were manually provided.
    # NOTE(review): `used_names` and `var_infos` are never used.
    used_names = set()
    for k, var_infos in game.infos.items():
        if k in variables:
            game.infos[k].name = variables[k].name

    # Use text grammar to generate name and description.
    import numpy as np
    from textworld.generator import Grammar
    grammar = Grammar(options.grammar,
                      rng=np.random.RandomState(options.seeds["grammar"]))
    game.change_grammar(grammar)
    game.metadata["desc"] = "Generated with textworld.GameMaker."

    # NOTE(review): hard-coded, machine-specific path.
    path = "/home/macote/src/TextWorld/textworld/generator/data/logic/look.twl2"
    with open(path) as f:
        document = f.read()

    actions, grammar = _parse_and_convert(document, rule_name="start2")

    # compile_game(game, options)
    game.grammar = grammar

    env = TextWorldEnv()
    env.load(game=game)
    state = env.reset()

    while True:
        print(state.feedback)
        cmd = input("> ")
        state, _, _ = env.step(cmd)

    from ipdb import set_trace
    set_trace()
コード例 #19
0
def main():
    """Play a TextWorld game with the agent selected on the command line.

    Options come from ``build_parser()``. ``args.game`` is the game to load,
    ``args.mode`` picks the agent (``random``, ``random-cmd``, ``human`` or
    ``walkthrough``), and ``args.max_steps`` caps the episode length; a value
    <= 0 means no cap.

    Output:
      * Rendered game text goes through the ``utterances`` logger, to
        ``args.utts_fn`` if given and to stdout otherwise.
      * When ``args.state_fn`` is given and output is shown (human mode or
        verbose), one JSON object per step with the serialized facts,
        inventory and description goes through the ``states`` logger to that
        file.

    The handlers attached here are removed and closed before returning (also
    on error), so output files are closed. Calling ``main()`` again does not
    duplicate log lines.

    Raises:
        ValueError: if ``args.mode`` is not one of the supported agent modes.
    """
    args = build_parser().parse_args()
    if args.very_verbose:
        args.verbose = args.very_verbose

    # (logger, handler) pairs to detach and close on exit.
    attached = []

    state_logger = None
    if args.state_fn:
        # FileHandler owns its file and closes it in close(), unlike a
        # StreamHandler wrapped around a bare open().
        state_wf = logging.FileHandler(args.state_fn, mode="w")
        state_wf.setLevel(logging.INFO)
        state_logger = logging.getLogger("states")
        state_logger.setLevel(logging.INFO)
        state_logger.addHandler(state_wf)
        attached.append((state_logger, state_wf))

    if args.utts_fn:
        utts_wf = logging.FileHandler(args.utts_fn, mode="w")
    else:
        utts_wf = logging.StreamHandler(sys.stdout)
    # Rendered game text carries its own newlines.
    utts_wf.terminator = ''
    utts_wf.setLevel(logging.INFO)
    utts_logger = logging.getLogger("utterances")
    utts_logger.setLevel(logging.INFO)
    utts_logger.addHandler(utts_wf)
    attached.append((utts_logger, utts_wf))

    def _log_state(env):
        """Log one JSON line with the env's current facts, inventory and description."""
        state_logger.info(
            json.dumps({
                'facts':
                [Proposition.serialize(prop) for prop in env.state.facts],
                'inventory':
                env.state.inventory.replace('\n\n', '\n'),
                'description':
                env.state.description.replace('\n\n', '\n'),
            }))

    try:
        request_infos = textworld.EnvInfos(inventory=True,
                                           description=True,
                                           facts=True)
        env = textworld.start(args.game, infos=request_infos)
        try:
            if args.mode == "random":
                agent = textworld.agents.NaiveAgent(seed=args.seed)
            elif args.mode == "random-cmd":
                # TODO get rid of redundancy
                agent = textworld.agents.RandomCommandAgent(seed=args.seed)
            elif args.mode == "human":
                agent = textworld.agents.HumanAgent()
            elif args.mode == "walkthrough":
                agent = textworld.agents.WalkthroughAgent()
            else:
                # Without this branch `agent` would be unbound below.
                raise ValueError("Unknown agent mode: {!r}".format(args.mode))

            agent.reset(env)
            if args.viewer is not None:
                from textworld.envs.wrappers import HtmlViewer
                env = HtmlViewer(env, port=args.viewer)

            show_output = args.mode == "human" or args.verbose

            if args.mode == "human" or args.very_verbose:
                utts_logger.info("Using {}.\n".format(env.__class__.__name__))

            game_state = env.reset()
            if show_output:
                contents = env.render(mode="text")
                # Remove the `TEXTWORLD` heading (first 22 lines of the intro).
                contents = '\n'.join(contents.split('\n')[22:])
                contents = contents.replace('\n\n', '\n')
                utts_logger.info(contents)
                if state_logger is not None:
                    _log_state(env)

            reward = 0
            done = False
            steps = (range(args.max_steps)
                     if args.max_steps > 0 else itertools.count())
            for _ in steps:
                command = agent.act(game_state, reward, done)
                game_state, reward, done = env.step(command)

                if show_output:
                    contents = env.render(mode="text")
                    contents = contents.replace('\n\n', '\n').replace('\n\n', '\n')
                    utts_logger.info(contents)
                    if state_logger is not None:
                        _log_state(env)

                if done:
                    break
        finally:
            # Close the environment even if the agent or the game raised.
            env.close()

        # The handler terminator is '', so end the summary line explicitly.
        utts_logger.info("Done after {} steps. Score {}/{}.\n".format(
            game_state.moves, game_state.score, game_state.max_score))
    finally:
        for logger, handler in attached:
            logger.removeHandler(handler)
            handler.close()