예제 #1
0
    def test_step(self):
        """Check `step` before reset, once the game is done, and with an empty command."""
        # Stepping before `reset()` has been called must raise.
        fresh_env = textworld.start(self.game_file)
        with npt.assert_raises(GameNotRunningError):
            fresh_env.step("look")

        # Test sending command when the game is done.
        finished_env = textworld.start(self.game_file)
        finished_env.reset()
        for command in ("quit", "yes"):
            finished_env.step(command)
        with npt.assert_raises(GameNotRunningError):
            finished_env.step("look")

        # Test sending empty command: it only has to go through without raising.
        started_env = textworld.start(self.game_file)
        started_env.reset()
        started_env.step("")
예제 #2
0
    def test_step(self):
        """Check `step` before reset and with an empty command."""
        # Stepping before `reset()` has been called must raise.
        fresh_env = textworld.start(self.game_file)
        with npt.assert_raises(GameNotRunningError):
            fresh_env.step("look")

        # TODO: not supported by Jericho
        # # Test sending command when the game has quit.
        # env = textworld.start(self.game_file)
        # env.reset()
        # env.step("quit")
        # env.step("yes")
        # npt.assert_raises(GameNotRunningError, env.step, "look")

        # Test sending empty command: it only has to go through without raising.
        started_env = textworld.start(self.game_file)
        started_env.reset()
        started_env.step("")
예제 #3
0
def test_winning_game():
    """Replay the Zork 1 walkthrough and check that the game ends won and not lost."""
    step_budget = 1000  # Just in case: guarantees the loop terminates.
    env = textworld.start("./games/zork1.z5")

    # The walkthrough is read from the `solutions` folder next to the game file.
    solution_path = pjoin(env.game_filename, "..", "solutions", "zork1.txt")
    with open(os.path.abspath(solution_path)) as solution:
        commands = solution.readlines()
    agent = textworld.agents.WalkthroughAgent(commands)

    env.seed(1)  # In order for the walkthrough to work.
    game_state = env.reset()

    done = False
    for _ in range(step_budget):
        game_state, _reward, done = env.step(agent.act(game_state, 0, done))
        if done:
            break

    summary = "Done after {} steps. Score {}/{}."
    print(summary.format(game_state.moves, game_state.score, game_state.max_score))
    assert game_state.won
    assert not game_state.lost
    def reset(self):
        """Move on to the next game, restart the TextWorld environment and return the first observation.

        Returns:
            A tuple `(ob, infos)`: `ob` is the feedback of the initial game
            state and `infos` is `self.infos`.
        """
        self.current_game = self._next_game()
        self.infos = {"game_file": os.path.basename(self.current_game)}

        # Close the environment of the previous game before starting a new one.
        if self.textworld_env is not None:
            self.textworld_env.close()

        env = self.textworld_env = textworld.start(self.current_game)

        requested = self.request_infos
        if "admissible_commands" in requested:
            env.activate_state_tracking()

        if "intermediate_reward" in requested:
            env.activate_state_tracking()
            env.compute_intermediate_reward()

        # Static information copied from the game definition.
        game = env.game
        infos = self.infos
        infos["directions_names"] = game.directions_names
        infos["verbs"] = game.verbs
        infos["objects_names"] = game.objects_names
        infos["objects_types"] = game.objects_types
        infos["objects_names_and_types"] = game.objects_names_and_types
        infos["max_score"] = 1

        self.performed_actions = set()
        self.game_state = env.reset()
        ob = self.game_state.feedback
        # NOTE(review): presumably fills `self.infos` with the requested entries - confirm.
        self._update_requested_infos()
        return ob, self.infos
예제 #5
0
def test_html_viewer():
    """Integration test for the visualization service.

    Generates a small game, serves it through `HtmlViewer` on port 8080 and
    checks with a webdriver that the page shows one "node" element per room
    and one "item" element per object whose type is not "I".
    """
    num_nodes = 3
    num_items = 10
    options = textworld.GameOptions()
    options.seeds = 1234
    options.nb_rooms = num_nodes
    options.nb_objects = num_items
    options.quest_length = 3
    options.grammar.theme = "house"
    options.grammar.include_adj = True
    game = textworld.generator.make_game(options)

    game_name = "test_html_viewer_wrapper"
    env = None
    driver = None
    # The viewer and the webdriver are released in `finally` so that a failing
    # assertion does not leave port 8080 busy for the tests that follow.
    try:
        with make_temp_directory(prefix=game_name) as tmpdir:
            game_file = compile_game(game, path=tmpdir)

            env = textworld.start(game_file)
            env = HtmlViewer(env, open_automatically=False, port=8080)
            env.reset()  # Cause rendering to occur.

        # options.binary_location = "/bin/chromium"
        driver = get_webdriver()

        driver.get("http://127.0.0.1:8080")
        nodes = driver.find_elements_by_class_name("node")
        assert len(nodes) == num_nodes
        items = driver.find_elements_by_class_name("item")

        objects = [obj for obj in game.world.objects if obj.type != "I"]
        assert len(items) == len(objects)
    finally:
        try:
            if env is not None:
                env.close()
        finally:
            if driver is not None:
                driver.close()
예제 #6
0
def test_filter_wrapper():
    """Check that the `Filter` wrapper returns every basic info that was requested."""
    # Make a game for testing purposes.
    options = textworld.GameOptions()
    options.seeds = 1234
    options.nb_rooms = 3
    options.nb_objects = 10
    options.quest_length = 3
    options.grammar.theme = "house"
    options.grammar.include_adj = True
    game = textworld.generator.make_game(options)

    with make_temp_directory(prefix="test_filter_wrapper") as tmpdir:
        options.path = tmpdir
        game_file = compile_game(game, options)
        env = textworld.start(game_file)

        # Switch on every attribute except `extras`, which is not a boolean attribute.
        env_options = EnvInfos()
        boolean_attrs = [name for name in env_options.__slots__ if name != "extras"]
        for name in boolean_attrs:
            setattr(env_options, name, True)

        assert len(env_options) == len(env_options.__slots__) - 1
        assert len(env_options) == len(env_options.basics)

        _, infos = Filter(env, env_options).reset()
        assert all(name in infos for name in env_options.basics)
예제 #7
0
def ensure_gameinfo_file(gamefile, env_seed=42, save_to_file=True):
    """Return the game info for `gamefile`, generating it first if no info file exists.

    If the info file already exists, it is loaded with `load_gameinfo_file`.
    Otherwise the game is started once with TextWorld, the info is extracted
    from the initial game state with `_get_gameinfo_from_gamestate`, passed to
    `load_additional_gameinfo_from_jsonfile` together with the game's JSON
    path, and optionally written to the info file.

    Args:
        gamefile: Path of the game file.
        env_seed: Currently unused by this function.
        save_to_file: Whether freshly generated info is written to the info file.

    Returns:
        The game info. Only freshly generated info gets its `_gamefile` entry
        set here; loaded info is returned as `load_gameinfo_file` produced it.
    """
    # NOTE: individual gamefiles have already been registered,
    # so there is no need for a batch env here.
    print("+++== ensure_gameinfo_file:", gamefile, "...")
    if _gameinfo_file_exists(gamefile):
        return load_gameinfo_file(gamefile)

    print(
        f"NEED TO GENERATE game_info: '{_gameinfo_path_from_gamefile(gamefile)}'",
    )
    print("CURRENT WORKING DIR:", os.getcwd())

    request_qait_infos = textworld.EnvInfos(
        # static infos, don't change during the game:
        game=True,
        verbs=True,
        # The following are specific to the TextWorld version customized for QAIT (all static).
        # UNUSED: location_names, location_nouns, location_adjs
        # TODO: object_names=True, object_nouns=True, object_adjs=True
        extras=["object_locations", "object_attributes", "uuid"])

    _env = textworld.start(gamefile, infos=request_qait_infos)
    try:
        game_state = _env.reset()

        # Example of game_state.keys() from TW 1.3.2 without qait_infos:
        #     ['feedback', 'raw', 'game', 'command_templates', 'verbs', 'entities', 'objective', 'max_score',
        #      'extra.seeds', 'extra.goal', 'extra.ingredients', 'extra.skills', 'extra.entities',
        #      'extra.nb_distractors', 'extra.walkthrough', 'extra.max_score', 'extra.uuid', 'description',
        #      'inventory', 'score', 'moves', 'won', 'lost', '_game_progression', '_facts', '_winning_policy',
        #      'facts', '_last_action', '_valid_actions', 'admissible_commands']
        # e.g. game_state['extra.uuid'] == 'tw-cooking-recipe1+cook+cut+open+drop+go6-xEKyIJpqua0Gsm0q'

        # ? maybe don't filter / keep everything (including dynamic info)
        game_info = _get_gameinfo_from_gamestate(game_state)
    finally:
        # Always release the environment, even if reset or extraction fails.
        _env.close()

    load_additional_gameinfo_from_jsonfile(
        game_info, _gamejson_path_from_gamefile(gamefile))

    if save_to_file:
        print("+++== save_gameinfo_file:", gamefile, game_info.keys())
        _s = _serialize_gameinfo(game_info)
        with open(_gameinfo_path_from_gamefile(gamefile), "w") as infofile:
            infofile.write(_s + '\n')
    # Set after serialization so the entry is not part of the saved file.
    game_info['_gamefile'] = gamefile
    return game_info
예제 #8
0
 def setUp(self):
     """Start an environment with state tracking, intermediate reward and extra infos, then reset it."""
     env = self.env = textworld.start(self.game_file)
     env.activate_state_tracking()
     env.compute_intermediate_reward()
     for extra_info in ("description", "inventory"):
         env.enable_extra_info(extra_info)
     self.game_state = env.reset()
예제 #9
0
def test_limit_wrapper():
    """Check that the `Limit` wrapper reports `done` only after `max_episode_steps` steps."""
    max_episode_steps = 7

    # Make a game for testing purposes.
    options = textworld.GameOptions()
    options.seeds = 1234
    options.nb_rooms = 3
    options.nb_objects = 10
    options.quest_length = 3
    options.grammar.theme = "house"
    options.grammar.include_adj = True
    game = textworld.generator.make_game(options)

    with make_temp_directory(prefix="test_limit_wrapper") as tmpdir:
        options.path = tmpdir
        game_file = compile_game(game, options)

        env = Limit(textworld.start(game_file), max_episode_steps)
        state = env.reset()
        assert state["moves"] == 0

        done = False
        nb_steps = 0
        while nb_steps < max_episode_steps:
            # The episode must not end before the limit is reached.
            assert not done
            state, _score, done = env.step("wait")
            nb_steps += 1
            assert state["moves"] == nb_steps

        assert done
예제 #10
0
def test_quest_winning_condition_go():
    """Check that a quest made of two "go east" commands is won only after the second one."""
    M = textworld.GameMaker()

    # Layout: R1 -- R2 -- R3, with the player starting in the west room.
    room_names = ("West room", "Center room", "East room")
    west, center, east = [M.new_room(name) for name in room_names]
    M.set_player(west)

    M.connect(west.east, center.west)
    M.connect(center.east, east.west)

    M.set_quest_from_commands(["go east"] * 2)

    game = M.build()
    game_name = "test_quest_winning_condition_go"
    with make_temp_directory(prefix=game_name) as tmpdir:
        game_file = compile_game(game, game_name, games_folder=tmpdir)

        env = textworld.start(game_file)
        env.reset()

        # After the first move the quest is still in progress.
        game_state, _, done = env.step("go east")
        assert not done
        assert not game_state.has_won

        # The second move completes the quest.
        game_state, _, done = env.step("go east")
        assert done
        assert game_state.has_won
예제 #11
0
def run_random_agent(agent, game, max_step=500, nb_episodes=10, verbose=False):
    """Play `nb_episodes` episodes of `game` with `agent` and print the average moves and score.

    Args:
        agent: Object providing `reset(env)` and `act(game_state, reward, done)`.
        game: Path of the game file; its last "/"-separated component is printed.
        max_step: Maximum number of commands sent per episode.
        nb_episodes: Number of episodes to play.
        verbose: If true, print the state and the command at every step.
    """
    env = textworld.start(game)
    print(game.split("/")[-1])

    avg_moves, avg_scores = [], []
    # The environment is closed in `finally` so it is released even when an
    # episode raises.
    try:
        for no_episode in range(nb_episodes):
            agent.reset(env)
            game_state = env.reset()
            if verbose:
                print_state(game_state, 0)
            reward = 0
            done = False
            for no_step in range(max_step):
                command = agent.act(game_state, reward, done)
                game_state, reward, done = env.step(command)
                if verbose:
                    print_command(command)
                    print_state(game_state, no_step)
                if done:
                    break
            print("Episode " + str(no_episode))
            avg_moves.append(game_state.nb_moves)
            avg_scores.append(game_state.score)
    finally:
        env.close()

    print("  \tavg. steps: {:5.1f}; avg. score: {:4.1f} / 1.".format(np.mean(avg_moves), np.mean(avg_scores)))
예제 #12
0
def test_html_viewer():
    """Integration test for the visualization service.

    Generates a small game, serves it through `HtmlViewer` on port 8080 and
    checks with a webdriver that the page shows one "node" element per room
    and one "item" element per object whose type is not "I".
    """
    num_nodes = 3
    num_items = 10
    g_rng.set_seed(1234)
    grammar_flags = {"theme": "house", "include_adj": True}
    game = textworld.generator.make_game(world_size=num_nodes,
                                         nb_objects=num_items,
                                         quest_length=3,
                                         grammar_flags=grammar_flags)

    game_name = "test_html_viewer_wrapper"
    env = None
    driver = None
    # The viewer and the webdriver are released in `finally` so that a failing
    # assertion does not leave port 8080 busy for the tests that follow.
    try:
        with make_temp_directory(prefix=game_name) as tmpdir:
            game_file = compile_game(game, game_name, games_folder=tmpdir)

            env = textworld.start(game_file)
            env = HtmlViewer(env, open_automatically=False, port=8080)
            env.reset()  # Cause rendering to occur.

        # options.binary_location = "/bin/chromium"
        driver = get_webdriver()

        driver.get("http://127.0.0.1:8080")
        nodes = driver.find_elements_by_class_name("node")
        assert len(nodes) == num_nodes
        items = driver.find_elements_by_class_name("item")

        objects = [obj for obj in game.world.objects if obj.type != "I"]
        assert len(items) == len(objects)
    finally:
        try:
            if env is not None:
                env.close()
        finally:
            if driver is not None:
                driver.close()
예제 #13
0
def test_names_disambiguation():
    """Check that objects whose names overlap ("orange" / "tasty orange") are told apart."""
    M = textworld.GameMaker()
    room = M.new_room("room")
    M.set_player(room)

    object_names = ("apple", "orange", "tasty apple", "tasty orange")
    room.add(*[M.new(type="o", name=name) for name in object_names])

    game = M.build()
    game_name = "test_names_disambiguation"
    with make_temp_directory(prefix=game_name) as tmpdir:
        game_file = compile_game(game, game_name, games_folder=tmpdir)
        env = textworld.start(game_file)

        def send(command):
            # Only the resulting game state matters for these checks.
            state, _, _ = env.step(command)
            return state

        # Full names pick exactly the named object.
        env.reset()
        for name in ("tasty apple", "tasty orange"):
            game_state = send("take " + name)
            assert name in game_state.inventory

        # The short name must not grab the longer-named object.
        env.reset()
        game_state = send("take orange")
        assert "tasty orange" not in game_state.inventory
        assert "orange" in game_state.inventory

        game_state = send("take tasty")
        assert "?" in game_state.feedback  # Disambiguation question.
        game_state = send("apple")
        assert "tasty orange" not in game_state.inventory
        assert "tasty apple" in game_state.inventory
        assert "tasty apple" not in game_state.description
예제 #14
0
def test_playing_generated_games():
    """Generate random games and check they stay solvable while a random agent plays them."""
    nb_games = 10
    rng = np.random.RandomState(1234)
    for _ in range(nb_games):

        # Sample game specs (the draw order determines the sampled values).
        world_size = rng.randint(1, 10)
        nb_objects = rng.randint(0, 20)
        quest_length = rng.randint(2, 5)
        quest_breadth = rng.randint(3, 7)
        game_seed = rng.randint(0, 65365)

        with make_temp_directory(prefix="test_play_generated_games") as tmpdir:
            options = textworld.GameOptions()
            options.nb_rooms = world_size
            options.nb_objects = nb_objects
            options.quest_length = quest_length
            options.quest_breadth = quest_breadth
            options.seeds = game_seed
            game_file, _game = textworld.make(options, path=tmpdir)

            # Solve the game using WalkthroughAgent.
            textworld.play(game_file, agent=textworld.agents.WalkthroughAgent(), silent=True)

            # Play the game using RandomAgent and make sure we can always finish the
            # game by following the winning policy.
            env = textworld.start(game_file)

            agent = textworld.agents.RandomCommandAgent()
            agent.reset(env)
            env.compute_intermediate_reward()

            env.seed(4321)
            game_state = env.reset()

            max_steps = 100
            reward, done = 0, False
            for _step in range(max_steps):
                command = agent.act(game_state, reward, done)
                game_state, reward, done = env.step(command)
                progression = game_state._game_progression

                if not done:
                    # Make sure the game can still be solved.
                    winning_policy = progression.winning_policy
                    assert len(winning_policy) > 0
                    assert game_state.state.is_sequence_applicable(winning_policy)
                    continue

                msg = "Finished before playing `max_steps` steps because of command '{}'.".format(command)
                if game_state.has_won:
                    msg += " (winning)"
                    assert len(progression.winning_policy) == 0

                if game_state.has_lost:
                    msg += " (losing)"
                    assert progression.winning_policy is None

                print(msg)
                break
예제 #15
0
def test_disambiguation_questions():
    """Check that extra infos stay available while a disambiguation question is pending.

    "take tasty" is ambiguous between the two objects, so the game answers
    with a question. The description and inventory reported at that point must
    match the ones from before the command, and the follow-up answer "apple"
    must be treated as the answer rather than as a new command.
    """
    M = textworld.GameMaker()
    room = M.new_room("room")
    M.set_player(room)

    tasty_apple = M.new(type="o", name="tasty apple")
    tasty_orange = M.new(type="o", name="tasty orange")
    room.add(tasty_apple, tasty_orange)

    game = M.build()
    # Temp-directory prefix named after this test (it used to carry the name
    # of another test, which made leftover directories misleading).
    game_name = "test_disambiguation_questions"
    with make_temp_directory(prefix=game_name) as tmpdir:
        game_file = _compile_game(game, path=tmpdir)
        env = textworld.start(game_file, EnvInfos(description=True, inventory=True))

        game_state = env.reset()
        previous_inventory = game_state.inventory
        previous_description = game_state.description

        game_state, _, _ = env.step("take tasty")
        assert "?" in game_state.feedback  # Disambiguation question.

        # When there is a question in Inform7, the next string sent to the game
        # will be considered as the answer. We now make sure that asking for
        # extra information like `description` or `inventory` before answering
        # the question works.
        assert game_state.description == previous_description
        assert game_state.inventory == previous_inventory

        # Now answering the question.
        game_state, _, _ = env.step("apple")
        assert "That's not a verb I recognise." not in game_state.feedback
        assert "tasty orange" not in game_state.inventory
        assert "tasty apple" in game_state.inventory
        assert "tasty apple" not in game_state.description
예제 #16
0
 def test_100_sequential_runs(self):
     """Start an environment, play one command and close it, 100 times in a row."""
     # `range(100)` gives the 100 runs the test name promises
     # (`range(1, 100)` only performed 99).
     for _ in range(100):
         env = textworld.start(self.game_file)
         try:
             env.reset()
             game_state, reward, done = env.step('take inventory')
             self.assertIsNotNone(game_state, "Checking gamestate is not None")
             self.assertIsNotNone(reward, "Checking reward is not None")
             self.assertFalse(done, "Checking we don't finish the game by looking at our stuff")
         finally:
             # Release each environment before starting the next one.
             env.close()
예제 #17
0
def test_quest_winning_condition():
    """For each chainable rule, build a one-action quest and check that performing it wins the game."""
    g_rng.set_seed(2018)
    map_ = make_small_map(n_rooms=5, possible_door_states=["open"])
    world = World.from_map(map_)

    # Examine, look and inventory shouldn't be used for chaining.
    skipped_prefixes = ("look", "inventory", "examine")

    for rule in KnowledgeBase.default().rules.values():
        if rule.name.startswith(skipped_prefixes):
            continue

        # Sample a chain of depth one that uses exactly this rule.
        options = ChainingOptions()
        options.backward = True
        options.max_depth = 1
        options.create_variables = True
        options.rules_per_depth = [[rule]]
        options.restricted_types = {"r"}
        chain = sample_quest(world.state, options)
        assert len(chain.actions) > 0, rule.name
        event = Event(chain.actions)
        quest = Quest(win_events=[event])

        # Set the initial state required for the quest.
        tmp_world = World.from_facts(chain.initial_state.facts)
        game = make_game_with(tmp_world, [quest], make_grammar({}))

        if tmp_world.player_room is None:
            # Randomly place the player in the world since
            # the action doesn't care about where the player is.
            tmp_world.set_player_room()

        prefix = "test_quest_winning_condition_" + rule.name.replace("/", "_")
        with make_temp_directory(prefix=prefix) as tmpdir:
            game_file = _compile_game(game, path=tmpdir)

            env = textworld.start(game_file)
            env.reset()

            # Looking around must not complete the quest.
            game_state, _, done = env.step("look")
            assert not done
            assert not game_state.won

            # The quest's single command must complete it.
            game_state, _, done = env.step(event.commands[0])
            assert done
            assert game_state.won
예제 #18
0
def test_game_with_infinite_max_score():
    """Check scoring for a game with a repeatable quest.

    The first quest (holding the statue) is optional, repeatable and worth 10
    points; the second one (statue at the museum) is worth 0 points. The
    expected scores below grow by 10 on turns where the statue is held.
    """
    M = textworld.GameMaker()
    museum = M.new_room("Museum")

    statue = M.new(type="o", name="golden statue")
    pedestal = M.new(type="s", name="pedestal")

    pedestal.add(statue)
    museum.add(pedestal)

    M.set_player(museum)

    M.quests = [
        Quest(win_events=[
            Event(conditions=[M.new_fact('in', statue, M.inventory)]),
        ],
              reward=10,
              optional=True,
              repeatable=True),
        Quest(win_events=[
            Event(conditions=[M.new_fact('at', statue, museum)]),
        ],
              reward=0)
    ]
    M.set_walkthrough(["take statue from pedestal", "look", "drop statue"])

    game = M.build()
    game_name = "test_game_with_infinite_max_score"
    with make_temp_directory(prefix=game_name) as tmpdir:
        game_file = _compile_game(game, path=tmpdir)
        env = textworld.start(game_file)
        state = env.reset()
        # Score increases for each turn the player hold the statue.
        # This comparison used to be a bare expression whose result was
        # discarded; it is now actually asserted.
        assert state.max_score == np.inf

        state, score, done = env.step("look")
        assert not done
        assert score == 0

        state, score, done = env.step("take statue")
        assert score == 10

        state, score, done = env.step("wait")
        assert score == 20

        # Asking for the score is expected to leave it unchanged.
        state, score, done = env.step("score")
        assert score == 20
        assert "You have so far scored 20 points," in state.feedback

        state, score, done = env.step("wait")
        assert score == 30

        state, score, done = env.step("drop statue")
        assert done
        assert score == 30

        assert "a total of 30 points," in state.feedback
예제 #19
0
def test_play_generated_games():
    """Generate random games and check they stay solvable while a random agent plays them."""
    nb_games = 10
    rng = np.random.RandomState(1234)
    for _ in range(nb_games):

        # Sample game specs (the draw order determines the sampled values).
        world_size = rng.randint(1, 10)
        nb_objects = rng.randint(0, 20)
        quest_length = rng.randint(1, 10)
        game_seed = rng.randint(0, 65365)
        grammar_flags = {}  # Default grammar.

        with make_temp_directory(prefix="test_play_generated_games") as tmpdir:
            game_file, _game = textworld.make(world_size, nb_objects, quest_length, grammar_flags,
                                              seed=game_seed, games_dir=tmpdir)

            # Solve the game using WalkthroughAgent.
            textworld.play(game_file, agent=textworld.agents.WalkthroughAgent(), silent=True)

            # Play the game using RandomAgent and make sure we can always finish the
            # game by following the winning policy.
            env = textworld.start(game_file)

            agent = textworld.agents.RandomCommandAgent()
            agent.reset(env)
            env.compute_intermediate_reward()

            env.seed(4321)
            game_state = env.reset()

            max_steps = 100
            reward, done = 0, False
            for _step in range(max_steps):
                command = agent.act(game_state, reward, done)
                game_state, reward, done = env.step(command)
                progression = game_state._game_progression

                if not done:
                    # Make sure the game can still be solved.
                    winning_policy = progression.winning_policy
                    assert len(winning_policy) > 0
                    assert game_state.state.is_sequence_applicable(winning_policy)
                    continue

                msg = "Finished before playing `max_steps` steps."
                if game_state.has_won:
                    msg += " (winning)"
                    assert progression.winning_policy == []

                if game_state.has_lost:
                    msg += " (losing)"
                    assert progression.winning_policy is None

                print(msg)
                break
def test_playing_generated_games():
    """Generate random games and check they stay solvable while a random agent plays them."""
    nb_games = 10
    rng = np.random.RandomState(1234)
    for _ in range(nb_games):

        # Sample game specs (the draw order determines the sampled values).
        world_size = rng.randint(1, 10)
        nb_objects = rng.randint(0, 20)
        quest_depth = rng.randint(2, 5)
        quest_breadth = rng.randint(3, 7)
        game_seed = rng.randint(0, 65365)

        with make_temp_directory(prefix="test_play_generated_games") as tmpdir:
            options = textworld.GameOptions()
            options.path = tmpdir
            options.nb_rooms = world_size
            options.nb_objects = nb_objects
            options.chaining.max_depth = quest_depth
            options.chaining.max_breadth = quest_breadth
            options.seeds = game_seed
            game_file, _game = textworld.make(options)

            # Solve the game using WalkthroughAgent.
            textworld.play(game_file, agent=textworld.agents.WalkthroughAgent(), silent=True)

            # Play the game using RandomAgent and make sure we can always finish the
            # game by following the winning policy.
            env = textworld.start(game_file)
            env.infos.policy_commands = True
            env.infos.game = True

            agent = textworld.agents.RandomCommandAgent()
            agent.reset(env)

            env.seed(4321)
            game_state = env.reset()

            max_steps = 100
            reward, done = 0, False
            for _step in range(max_steps):
                command = agent.act(game_state, reward, done)
                game_state, reward, done = env.step(command)

                if done:
                    # The episode is over: start a new one and keep playing.
                    assert game_state._winning_policy is None
                    game_state = env.reset()
                    reward, done = 0, False

                # Make sure the game can still be solved.
                policy = game_state._winning_policy
                assert len(policy) > 0
                tracked_state = game_state._game_progression.state
                assert tracked_state.is_sequence_applicable(policy)
    def test_game_random_agent(self):
        """Smoke test: send five commands chosen by a `RandomCommandAgent`."""
        env = textworld.start(self.game_file)
        agent = textworld.agents.RandomCommandAgent()
        agent.reset(env)

        game_state, reward, done = env.reset(), 0, False
        remaining = 5
        while remaining > 0:
            game_state, reward, done = env.step(agent.act(game_state, reward, done))
            remaining -= 1
예제 #22
0
    def test_simultaneous_runs(self):
        """Keep 99 environments alive at once and check that the last one still responds."""
        envs = []
        try:
            for _ in range(1, 100):
                env = textworld.start(self.game_file)
                # Register before resetting so the env is closed even if reset fails.
                envs.append(env)
                env.reset()

            game_state, reward, done = envs[-1].step('take inventory')
            self.assertIsNotNone(game_state, "Checking gamestate is not None")
            self.assertIsNotNone(reward, "Checking reward is not None")
            self.assertFalse(done, "Checking we don't finish the game by looking at our stuff")
        finally:
            # Release every environment that was started, pass or fail.
            for env in envs:
                env.close()
예제 #23
0
파일: gdqn.py 프로젝트: zhangabner/KG-DQN
    def __init__(self, game, params):
        """Set up the environment, replay buffer, model, optimizer and epsilon schedule.

        Args:
            game: Game passed to `textworld.start`.
            params: Dict of hyper-parameters. The keys read here are
                'num_episodes', 'update_frequency', 'replay_buffer_type',
                'replay_buffer_size', 'preload_weights', 'preload_file', 'lr',
                'num_frames', 'batch_size', 'gamma', 'rho', 'scheduler_type',
                'e_final' and, for the exponential schedule, 'e_decay'.
                The dict is modified in place: 'vocab_size' is added to it.
        """
        self.num_episodes = params['num_episodes']
        self.state = StateNAction()

        self.update_freq = params['update_frequency']
        # The run name concatenates every parameter value whose key does not contain 'file'.
        self.filename = 'kgdqn_' + '_'.join(
            [str(v) for k, v in params.items() if 'file' not in str(k)])
        logging.basicConfig(filename='logs/' + self.filename + '.log',
                            filemode='w')
        # The message needs a placeholder; without one the parameters were
        # never written to the log.
        logging.warning("Parameters: %s", params)

        self.env = textworld.start(game)
        self.params = params

        # NOTE: no replay buffer is created for any other 'replay_buffer_type'.
        if params['replay_buffer_type'] == 'priority':
            self.replay_buffer = GraphPriorityReplayBuffer(
                params['replay_buffer_size'])
        elif params['replay_buffer_type'] == 'standard':
            self.replay_buffer = GraphReplayBuffer(
                params['replay_buffer_size'])

        params['vocab_size'] = len(self.state.vocab_drqa)

        self.model = KGDQN(params, self.state.all_actions).cuda()

        # When preloading, the freshly built model is replaced by the saved one.
        if self.params['preload_weights']:
            self.model = torch.load(self.params['preload_file'])['model']
        # model = nn.DataParallel(model)

        self.optimizer = optim.Adam(self.model.parameters(), lr=params['lr'])

        self.env.compute_intermediate_reward()
        self.env.activate_state_tracking()

        self.num_frames = params['num_frames']
        self.batch_size = params['batch_size']
        self.gamma = params['gamma']

        self.losses = []
        self.all_rewards = []
        self.completion_steps = []

        # priority fraction
        self.rho = params['rho']

        # NOTE: no scheduler is created for any other 'scheduler_type'.
        if params['scheduler_type'] == 'exponential':
            self.e_scheduler = ExponentialSchedule(self.num_frames,
                                                   params['e_decay'],
                                                   params['e_final'])
        elif params['scheduler_type'] == 'linear':
            self.e_scheduler = LinearSchedule(self.num_frames,
                                              params['e_final'])
예제 #24
0
def test_take_all_and_variants():
    """Check "take all" and its variants, first in an empty room and then with two balls."""
    M = textworld.GameMaker()

    # Empty room.
    room = M.new_room("room")
    M.set_player(room)

    game = M.build()
    game_name = "test_take_all_and_variants"
    with make_temp_directory(prefix=game_name) as tmpdir:
        game_file = compile_game(game, game_name, games_folder=tmpdir)
        env = textworld.start(game_file)
        env.reset()

        # With nothing to take, every verb/quantifier pair gets the same reply.
        verbs = ["take", "get", "pick up"]
        quantifiers = ["all", "everything", "each"]
        for verb, quantifier in itertools.product(verbs, quantifiers):
            game_state, _, _ = env.step("{} {}".format(verb, quantifier))
            feedback = game_state.feedback.strip()
            assert feedback == "You have to be more specific!"

    # Multiple objects to take.
    room.add(M.new(type="o", name="red ball"),
             M.new(type="o", name="blue ball"))

    game = M.build()
    game_name = "test_take_all_and_variants2"
    with make_temp_directory(prefix=game_name) as tmpdir:
        game_file = compile_game(game, game_name, games_folder=tmpdir)
        env = textworld.start(game_file)
        env.reset()

        game_state, _, _ = env.step("take all ball")
        for expected in ("red ball: Taken.", "blue ball: Taken."):
            assert expected in game_state.feedback
        for ball in ("red ball", "blue ball"):
            assert ball in game_state.inventory
예제 #25
0
    def play_game(self,
                  epoch,
                  game,
                  epi_idx,
                  dictionary,
                  max_step,
                  train=True):
        """Play one episode of `game` with random admissible commands, optionally training.

        Args:
            epoch: Epoch number, only used in the final summary line.
            game: Path of the game file.
            epi_idx: Episode index under which transitions are stored in `self.replay`.
            dictionary: Dictionary passed to `update_dictionary` and `train_batch`.
            max_step: Maximum number of commands sent during the episode.
            train: If true, store every transition and train a batch every 4 steps.
        """
        game_name = os.path.basename(game)
        print('[PLAY GAME] ', game_name)
        env = textworld.start(game)

        # The environment is closed in `finally` so it is released even when
        # playing or training raises.
        try:
            self.reset(env)  # tells the agent a new run is starting.
            game_state = env.reset()  # Start new run.
            print(game_state)

            total_reward = 0
            reward = 0
            done = False
            for t in range(max_step):
                command = self.random_admissible_act(
                    game_state)  # TODO : random_admissible_act for test
                print(game_state.admissible_commands)
                print('>> ', command)

                next_game_state, reward, done = env.step(command)
                print(next_game_state)
                total_reward += reward

                if train:
                    self.update_dictionary(dictionary, game_state, next_game_state)
                    self.replay.put(epi_idx, str(game_state), command, reward,
                                    done, str(next_game_state))

                    # Every 4 steps, DRQN should learn from a new experience batch.
                    # Idea: making sure the batch built at this point always includes
                    # the 4 most recent experiences might improve things a little.
                    if (t + 1) % 4 == 0:
                        self.train_batch(dictionary)

                game_state = next_game_state

                if done:
                    break

            # Tell the agent the run is done.
            self.finish(game_state, reward, done)

            msg = "#{:2d}. {}:\t {:3d} steps; score: {:2d}"
            msg = msg.format(epoch, game_name, game_state.nb_moves, total_reward)
            print(msg)
        finally:
            env.close()
예제 #26
0
def evaluate(agent, game, args):
    """Let `agent` play `game` for at most `args.nb_steps` steps.

    Whenever a game finishes, the environment is reset and play continues,
    unless the game was won with the maximum score, in which case the
    evaluation stops early.

    Args:
        agent: Object providing `reset(env)` and `act(game_state, score, done)`.
        game: Game handed to `textworld.start`.
        args: Object with an `nb_steps` attribute (maximum number of steps).

    Returns:
        A tuple `(step, nb_losts, highscore, max_score, elapsed)`: the number
        of the last step played (0 if none), how many games were lost, the
        best score seen, the maximum score reported by the initial game state,
        and the elapsed wall-clock time in seconds.
    """
    env = textworld.start(game)

    # Make sure the environment is released even if the evaluation fails.
    try:
        log.debug("Using {}".format(env.__class__.__name__))
        agent.reset(env)

        start_time = time.time()
        game_state = env.reset()
        log.debug("Environment reset.\n{}\n".format(env.render(mode="text")))

        max_score = game_state.max_score
        nb_losts = 0
        highscore = 0
        score = 0
        done = False
        step = 0  # Keeps the return value defined when no step is played.

        for step in range(1, args.nb_steps + 1):
            action = agent.act(game_state, score, done)
            game_state, score, done = env.step(action)

            msg = ("{:5d}. Time: {:9.2f}\tScore: {:3d}\tMove: {:5d}\t"
                   "Action: {:20s}")
            msg = msg.format(step,
                             time.time() - start_time, game_state.score,
                             game_state.nb_moves, action)
            log.info(msg)
            log.debug(env.render(mode="text"))

            if done:
                highscore = max(score, highscore)

                if game_state.won:
                    if highscore == max_score:
                        break  # No reason to play that game more.
                elif game_state.lost:
                    nb_losts += 1
                else:
                    # A finished game that is neither won nor lost is
                    # unexpected: report it without aborting the evaluation.
                    log.warning(
                        "Games should either end with a win or a fail.")

                # Replay the game in the hope of achieving a better score.
                game_state = env.reset()
                log.debug("Environment reset.\n{}\n".format(
                    env.render(mode="text")))
    finally:
        env.close()

    # Keep highest score.
    highscore = max(score, highscore)

    return step, nb_losts, highscore, max_score, time.time() - start_time
예제 #27
0
    def test_description(self):
        """Check that `description` agrees with the feedback of "look"."""
        # A game state from an environment started with the default arguments
        # is expected to refuse access to `description`.
        plain_env = textworld.start(self.game_file)
        initial_state = plain_env.reset()
        npt.assert_raises(ExtraInfosIsMissingError, getattr, initial_state, "description")

        # The remaining checks use `self.env` / `self.game_state`, which are
        # presumably prepared by the test fixture.
        state, _, _ = self.env.step("look")
        assert state.feedback.strip() == self.game_state.description.strip()
        assert state.feedback.strip() == state.description.strip()
        for command in ("go east", "look"):
            state, _, _ = self.env.step(command)
        assert state.feedback.strip() == state.description.strip()

        # End the game.
        for command in ("insert carrot into chest", "close chest"):
            state, _, _ = self.env.step(command)
        assert state.description != ""
    def test_game_walkthrough_agent(self):
        """Check the walkthrough agent replays the main quest and ends the game."""
        agent = textworld.agents.WalkthroughAgent()
        env = textworld.start(self.game_file)
        expected_commands = self.game.main_quest.commands
        agent.reset(env)
        state = env.reset()

        reward, done = 0, False
        for expected in expected_commands:
            # The game must still be running while quest commands remain.
            self.assertFalse(done, 'walkthrough finished game too early')
            issued = agent.act(state, reward, done)
            self.assertEqual(expected, issued,
                             "Walkthrough agent issued unexpected command")
            state, reward, done = env.step(issued)

        # Once every quest command has been sent, the game has to be over.
        self.assertTrue(done, 'Walkthrough did not finish the game')
예제 #29
0
def benchmark(gamefile, args):
    """Measure how many steps per second an agent can play on `gamefile`.

    The agent plays `args.max_steps` steps; whenever the game finishes the
    environment is reset and play goes on from the fresh game state.

    Args:
        gamefile: Game handed to `textworld.start`.
        args: Object with the attributes `mode` ("random", "random-cmd" or
            "walkthrough"), `agent_seed`, `activate_state_tracking`,
            `compute_intermediate_reward`, `max_steps` and `verbose`.

    Returns:
        The measured speed, in steps per second.

    Raises:
        ValueError: If `args.mode` is not one of the supported modes.
    """
    infos = textworld.EnvInfos()
    if args.activate_state_tracking or args.mode == "random-cmd":
        infos.admissible_commands = True

    if args.compute_intermediate_reward:
        infos.intermediate_reward = True

    env = textworld.start(gamefile, infos)

    # Make sure the environment is released however the benchmark ends.
    try:
        print("Using {}".format(env))

        if args.mode == "random":
            agent = textworld.agents.NaiveAgent()
        elif args.mode == "random-cmd":
            agent = textworld.agents.RandomCommandAgent(seed=args.agent_seed)
        elif args.mode == "walkthrough":
            agent = textworld.agents.WalkthroughAgent()
        else:
            # Fail with a clear message instead of an unbound `agent` below.
            raise ValueError("Unknown mode: {!r}".format(args.mode))

        agent.reset(env)
        game_state = env.reset()

        if args.verbose:
            env.render()

        reward = 0
        done = False
        nb_resets = 1
        start_time = time.time()
        for _ in range(args.max_steps):
            command = agent.act(game_state, reward, done)
            game_state, reward, done = env.step(command)

            if done:
                # Continue from the fresh state rather than the finished one.
                game_state = env.reset()
                nb_resets += 1
                done = False

            if args.verbose:
                env.render()

        duration = time.time() - start_time
    finally:
        env.close()

    speed = args.max_steps / duration
    msg = "Done {:,} steps in {:.2f} secs ({:,.1f} steps/sec) with {} resets."
    print(msg.format(args.max_steps, duration, speed, nb_resets))
    return speed
예제 #30
0
def test_manually_defined_objective():
    """Check that an objective assigned by hand is reported by the environment."""
    M = GameMaker()

    # Create a 'bedroom' room.
    R1 = M.new_room("bedroom")
    M.set_player(R1)

    # Build the game and overwrite its objective with a custom text.
    game = M.build()
    game.objective = "There's nothing much to do in here."

    with make_temp_directory(
            prefix="test_manually_defined_objective") as tmpdir:
        game_file = M.compile(tmpdir)

        # Request the objective so it is available in the returned state.
        env = textworld.start(game_file,
                              infos=textworld.EnvInfos(objective=True))
        state = env.reset()
        assert state["objective"] == "There's nothing much to do in here."