Example #1
                    help="gym environment to load",
                    default='MiniGrid-MultiRoom-N6-v0')
parser.add_argument("--seed",
                    type=int,
                    help="random seed to generate the environment with",
                    default=-1)
parser.add_argument("--tile_size",
                    type=int,
                    help="size at which to render tiles",
                    default=32)
parser.add_argument('--agent_view',
                    default=False,
                    help="draw the agent sees (partially observable view)",
                    action='store_true')

args = parser.parse_args()

env = gym.make(args.env)

if args.agent_view:
    env = RGBImgPartialObsWrapper(env)
    env = ImgObsWrapper(env)

def key_handler(event):
    # Map a key press in the window to an environment action
    # (this excerpt only wires up a subset of the keys).
    # Spacebar
    if event.key == ' ':
        step(env.actions.toggle)
        return
    if event.key == 'pageup':
        step(env.actions.pickup)
        return
    if event.key == 'pagedown':
        step(env.actions.drop)
        return

    if event.key == 'enter':
        step(env.actions.done)
        return


window = Window('gym_minigrid - ' + args.env)
window.reg_key_handler(key_handler)

reset()

# Blocking event loop
window.show(block=True)
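
The excerpt above calls reset() and step() without defining them. A minimal sketch of what those helpers might look like, assuming the pre-Gymnasium gym API that gym_minigrid targets (env.step returning obs, reward, done, info) and the args, env, and window globals created above:

def redraw(img):
    # Render the full grid unless the partial agent view was requested.
    if not args.agent_view:
        img = env.render('rgb_array', tile_size=args.tile_size)
    window.show_img(img)


def reset():
    # Optionally fix the RNG seed, then start a new episode.
    if args.seed != -1:
        env.seed(args.seed)
    obs = env.reset()
    if hasattr(env, 'mission'):
        window.set_caption(env.mission)
    redraw(obs)


def step(action):
    obs, reward, done, info = env.step(action)
    print('step=%s, reward=%.2f' % (env.step_count, reward))
    if done:
        reset()
    else:
        redraw(obs)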



Example #2
env = gym.make('MiniGrid-Empty-5x5-v0')

env = RGBImgPartialObsWrapper(env)
env = ImgObsWrapper(env)

window = Window('gym_minigrid')
window.reg_key_handler(random_solve)

reset()

# Blocking event loop
window.show(block=True)
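
This snippet registers random_solve as the key handler and calls reset(), but defines neither. A hypothetical implementation of random_solve (only the name appears in the snippet; the behavior is assumed) that takes a random action on every key press; thanks to the wrappers above, the observation is already an RGB image the window can display:

def random_solve(event):
    # Any key press advances the environment by one random action.
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()
    window.show_img(obs)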
Example #3
class Interface:
    """
    User interface for generating counterfactual states and actions.
    For each trajectory:
        The user can use the `a` and `d` keys to go backward and forward in time to a specific state they wish to
        generate a counterfactual explanation from. Once a specific state is found, the user presses `w` to launch the
        counterfactual mode.
        In the counterfactual mode, the user can control the agent using the keyboard. Once satisfied with their
        progress, the user can give up control to the bot to finish the episode by pressing `w`.

        The keys for controlling the agent are summarized here:

        escape: quit the program

        view mode:
            d: next time step
            a: previous time step
            w: select time step and go to counterfactual mode

        counterfactual mode:
            left:          turn counter-clockwise
            right:         turn clockwise
            up:            go forward
            space:         toggle
            pageup or x:   pickup
            pagedown or z: drop
            enter or q:    done (should not be used)
            a:             undo action
            w:             roll out to the end of the episode and move to next episode
    """
    def __init__(self,
                 original_dataset: TrajectoryDataset,
                 counterfactual_dataset: TrajectoryDataset,
                 policy_factory=None):
        self.policy_factory = policy_factory
        self.dataset = original_dataset
        self.counterfactual_dataset = counterfactual_dataset
        self.trajectory_generator = self.dataset.trajectory_generator()
        self.navigator: TrajectoryNavigator = None
        self.window = None
        self.is_counterfactual = False
        self.run()

    def run(self):
        for i, trajectory in enumerate(self.trajectory_generator):
            self.saved = False
            self.is_counterfactual = False
            self.navigator = TrajectoryNavigator(trajectory)
            self.window = Window(f'Trajectory {i}')
            self.window.reg_key_handler(self.key_handler)
            self.reset()
            self.window.show(block=True)
            if not self.saved:
                raise Exception('Continued without saving the trajectory!')

    def redraw(self):
        step: TrajectoryStep = self.navigator.step()
        # if not self.agent_view:
        env = step.state
        img = env.render('rgb_array', tile_size=32)
        # else:
        # img = step.observation['image']
        # TODO later: figure out when to use the observation instead.

        self.window.show_img(img)

    def step(self, action=None):
        if action is None:
            self.navigator.forward()
        else:
            assert isinstance(self.navigator, CounterfactualNavigator)
            self.navigator.forward(action)
        self.redraw()

    def backward(self):
        self.navigator.backward()
        self.redraw()

    def reset(self):
        env = self.navigator.step().state

        if hasattr(env, 'mission'):
            print('Mission: %s' % env.mission)
            self.window.set_caption(env.mission)

        self.redraw()

    def select(self):
        new_navigator = CounterfactualNavigator(
            self.navigator.episode,
            self.navigator.index,
            self.navigator.step(),
            policy_factory=self.policy_factory)
        self.navigator = new_navigator
        self.is_counterfactual = True
        print(
            f'Starting counterfactual trajectory from {self.navigator.index}')
        self.redraw()

    def save_trajectory(self):
        assert isinstance(self.navigator, CounterfactualNavigator)
        self.navigator.store(self.counterfactual_dataset)
        self.saved = True

    def key_handler(self, event):
        print('pressed', event.key)

        if event.key == 'escape':
            self.window.close()
            exit()
            return

        # if event.key == 'backspace':
        #     self.reset()
        #     return
        if self.is_counterfactual:
            if event.key == 'left':
                self.step('left')
                return
            if event.key == 'right':
                self.step('right')
                return
            if event.key == 'up':
                self.step('forward')
                return

            # Spacebar
            if event.key == ' ':
                self.step('toggle')
                return
            if event.key == 'pageup' or event.key == 'x':
                self.step('pickup')
                return
            if event.key == 'pagedown' or event.key == 'z':
                self.step('drop')
                return

            if event.key == 'enter' or event.key == 'q':
                self.step('done')
                return

            if event.key == 'w':
                if self.policy_factory is not None:
                    self.navigator.rollout()
                self.save_trajectory()
                self.window.close()

            if event.key == 'a':
                self.backward()
                return

        if not self.is_counterfactual:
            if event.key == 'd':
                self.step()
                return

            if event.key == 'a':
                self.backward()
                return

            if event.key == 'w':
                self.select()
                return
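
Interface starts its event loop from __init__, so constructing it is enough to launch the tool. A hypothetical invocation sketch; the TrajectoryDataset constructor arguments are assumptions, since that class is project-specific and not shown in this example:

original = TrajectoryDataset('data/original_trajectories')        # assumed constructor
counterfactual = TrajectoryDataset('data/counterfactual_trajectories')
Interface(original, counterfactual, policy_factory=None)  # opens one Window per trajectory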
Example #4
# Imports assumed by this example (not shown in the original excerpt):
import numpy as np

from gym_minigrid.window import Window
# Simple2Dv2 is a project-specific environment class; its import is not shown here.

class SimpleEnv(object):
    def __init__(self, display=False, agent_view=5, map_size=20, roads=1, max_step=100):
        super().__init__()
        self.display = display
        self.map = Simple2Dv2(map_size, map_size, agent_view=agent_view, roads=roads, max_step=max_step)
        self.window = None
        if self.display:
            self.window = Window('GYM_MiniGrid')
            self.window.reg_key_handler(self.key_handler)
            self.window.show(True)
        self.detect_rate = []
        self.rewards = []
        self.step_count = []
        self.old = None
        self.new = None
        self._rewards = []

    def short_term_reward(self):
        # (- manhattan distance / 100) + ( - stay time / 100)
        return self.new["reward"] / 100 - self.map.check_history() / 100

    def long_term_reward(self):
        _extrinsic_reward = self.new["l_reward"]
        _extrinsic_reward = sum(_extrinsic_reward) / len(_extrinsic_reward)
        return _extrinsic_reward

    def step(self, action):
        # Turn left, turn right, move forward
        # forward = 0
        # left = 1
        # right = 2
        self.old = self.map.state()
        self.new, done = self.map.step(action)
        reward = self.short_term_reward()
        if self.display is True:
            self.redraw()
        if done != 0:
            self.detect_rate.append(self.new["l_reward"])
            self.step_count.append(self.map.step_count)
            reward += self.long_term_reward()
            self._rewards.append(reward)
            self.rewards.append(np.mean(self._rewards))
        else:
            self._rewards.append(reward)

        return self.old, self.new, reward, done

    def key_handler(self, event):
        print('pressed', event.key)
        if event.key == 'left':
            self.step(0)
            return
        if event.key == 'right':
            self.step(1)
            return
        if event.key == 'up':
            self.step(2)
            return

    def redraw(self):
        if self.window is not None:
            self.map.render('human')

    def reset_env(self):
        """
        reset environment to the start point
        :return:
        """
        self.map.reset()
        self._rewards = []
        if self.display:
            self.redraw()
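
A short driving loop for SimpleEnv, for illustration only; the action indices follow the key_handler mapping above, and Simple2Dv2 is assumed to be importable from the same project:

env = SimpleEnv(display=False)
for episode in range(5):
    env.reset_env()
    done = 0
    while done == 0:
        action = np.random.randint(3)                  # random action index
        old_state, new_state, reward, done = env.step(action)
    print('episode %d finished after %d steps' % (episode, env.step_count[-1]))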