Exemplo n.º 1
0
class Interface:
    """
    User interface for generating counterfactual states and actions.

    For each trajectory:
        The user can use the `a` and `d` keys to go backward and forward in
        time to a specific state they wish to generate a counterfactual
        explanation from. Once a specific state is found, the user presses
        `w` to launch the counterfactual mode.
        In the counterfactual mode, the user can control the agent using the
        keyboard. Once satisfied with their progress, the user presses `w`
        to store the counterfactual trajectory; if a `policy_factory` was
        supplied, the navigator first rolls the episode out.

    The keys for controlling the agent are summarized here:

        escape: quit the program

        view mode:
            d: next time step
            a: previous time step
            w: select time step and go to counterfactual mode

        counterfactual mode:
            left:          turn counter-clockwise
            right:         turn clockwise
            up:            go forward
            space:         toggle
            pageup or x:   pickup
            pagedown or z: drop
            enter or q:    done (should not use)
            a:             undo action
            w:             roll out (only when a policy factory is set),
                           save the trajectory and close the window so the
                           next episode is shown

    Note that constructing an `Interface` immediately starts the interactive
    loop (`run()` is called from `__init__`).
    """

    # Keys that translate directly into an agent action name while in
    # counterfactual mode. Several keys may map to the same action.
    _ACTION_KEYS = {
        'left': 'left',
        'right': 'right',
        'up': 'forward',
        ' ': 'toggle',  # spacebar
        'pageup': 'pickup',
        'x': 'pickup',
        'pagedown': 'drop',
        'z': 'drop',
        'enter': 'done',
        'q': 'done',
    }

    def __init__(self,
                 original_dataset: TrajectoryDataset,
                 counterfactual_dataset: TrajectoryDataset,
                 policy_factory=None):
        """
        Store the datasets and start the interactive session.

        :param original_dataset: dataset whose `trajectory_generator()`
            yields the trajectories to browse.
        :param counterfactual_dataset: dataset the counterfactual
            trajectories are stored into.
        :param policy_factory: optional factory forwarded to the
            `CounterfactualNavigator`; when it is None no rollout is
            performed before saving.
        """
        self.policy_factory = policy_factory
        self.dataset = original_dataset
        self.counterfactual_dataset = counterfactual_dataset
        self.trajectory_generator = self.dataset.trajectory_generator()
        self.navigator: TrajectoryNavigator = None
        self.window = None
        self.is_counterfactual = False
        # Initialised here so the attribute always exists, even before the
        # first trajectory has been loaded by run().
        self.saved = False
        self.run()

    def run(self):
        """
        Show one blocking window per trajectory until the generator is
        exhausted.

        :raises RuntimeError: if a window is closed without the
            counterfactual trajectory having been saved.
        """
        for i, trajectory in enumerate(self.trajectory_generator):
            # Every trajectory starts unsaved and in view mode.
            self.saved = False
            self.is_counterfactual = False
            self.navigator = TrajectoryNavigator(trajectory)
            self.window = Window(f'Trajectory {i}')
            self.window.reg_key_handler(self.key_handler)
            self.reset()
            # Blocks until the window is closed (by the `w` handler in
            # counterfactual mode, or by the user).
            self.window.show(block=True)
            if not self.saved:
                raise RuntimeError('Continued without saving the trajectory!')

    def redraw(self):
        """Render the state of the navigator's current step in the window."""
        step: TrajectoryStep = self.navigator.step()
        env = step.state
        # TODO later: figure out when to show the agent's observation
        # (step.observation['image']) instead of the rendered state.
        img = env.render('rgb_array', tile_size=32)
        self.window.show_img(img)

    def step(self, action=None):
        """
        Advance the navigator by one step and redraw.

        :param action: action name to apply. When None the navigator simply
            moves forward; otherwise the navigator must be a
            `CounterfactualNavigator`, which receives the action.
        """
        if action is None:
            self.navigator.forward()
        else:
            assert isinstance(self.navigator, CounterfactualNavigator)
            self.navigator.forward(action)
        self.redraw()

    def backward(self):
        """Move the navigator one step back and redraw."""
        self.navigator.backward()
        self.redraw()

    def reset(self):
        """Print and caption the mission (if the state has one), then redraw."""
        env = self.navigator.step().state

        if hasattr(env, 'mission'):
            print('Mission: %s' % env.mission)
            self.window.set_caption(env.mission)

        self.redraw()

    def select(self):
        """
        Switch to counterfactual mode, branching from the current step.

        The current navigator is replaced by a `CounterfactualNavigator`
        built from its episode, index and current step.
        """
        new_navigator = CounterfactualNavigator(
            self.navigator.episode,
            self.navigator.index,
            self.navigator.step(),
            policy_factory=self.policy_factory)
        self.navigator = new_navigator
        self.is_counterfactual = True
        print(
            f'Starting counterfactual trajectory from {self.navigator.index}')
        self.redraw()

    def save_trajectory(self):
        """Store the counterfactual trajectory and mark it as saved."""
        assert isinstance(self.navigator, CounterfactualNavigator)
        self.navigator.store(self.counterfactual_dataset)
        self.saved = True

    def key_handler(self, event):
        """
        Dispatch a key-press event according to the current mode.

        :param event: key event exposing the pressed key as `event.key`.
        :raises SystemExit: when `escape` is pressed, after closing the
            window.
        """
        print('pressed', event.key)
        key = event.key

        if key == 'escape':
            # `exit()` is an interactive helper injected by the `site`
            # module and is not guaranteed to exist; sys.exit() is the
            # supported way to terminate a program.
            import sys
            self.window.close()
            sys.exit()

        if self.is_counterfactual:
            self._handle_counterfactual_key(key)
        else:
            self._handle_view_key(key)

    def _handle_counterfactual_key(self, key):
        """Handle a key press while the user controls the agent."""
        action = self._ACTION_KEYS.get(key)
        if action is not None:
            self.step(action)
        elif key == 'w':
            # Optionally roll the episode out, then persist it and close the
            # window so that run() can move on to the next trajectory.
            if self.policy_factory is not None:
                self.navigator.rollout()
            self.save_trajectory()
            self.window.close()
        elif key == 'a':
            # Undo the last action.
            self.backward()

    def _handle_view_key(self, key):
        """Handle a key press while browsing the original trajectory."""
        if key == 'd':
            self.step()
        elif key == 'a':
            self.backward()
        elif key == 'w':
            self.select()
Exemplo n.º 2
0
    accuracy = torch.cat((accuracy, correct), 0)
    update += 1

    # Print logs
    if update % args.log_interval == 0:
        if args.visualize:
            # Visualize last frame of last sample
            from gym_minigrid.window import Window
            window = Window('gym_minigrid - ' + args.env)
            images = images.transpose(1,2)
            images = images.transpose(2,3)
            print(images[-1].shape)
            print(label)
            window.show_img(images[-1])
            input()
            window.close()

        duration = int(time.time() - start_time)

        header = ["Update", "Time", "Loss", "Accuracy"]
        acc = torch.mean(accuracy)
        data = [update, duration, sum(losses) / len(losses), acc]
        losses = []
        over = (acc >= 0.9999)
        accuracy = torch.tensor([]).to(device)

        txt_logger.info(
            "U {} | T {} | L {:.3f} | A {:.4f}"
            .format(*data))

        if status["update"] == 0: