class Interface:
    """
    User interface for generating counterfactual states and actions.

    For each trajectory:
    The user can use the `a` and `d` keys to go backward and forward in time
    to a specific state they wish to generate a counterfactual explanation
    from. Once a specific state is found, the user presses `w` to launch the
    counterfactual mode.
    In the counterfactual mode, the user can control the agent using the
    keyboard. Once satisfied with their progress, the user can give up
    control to the bot to finish the episode by pressing `w` (if no
    ``policy_factory`` was given, `w` saves the trajectory without a
    rollout).

    The keys for controlling the agent are summarized here:

    escape: quit the program

    view mode:
        d: next time step
        a: previous time step
        w: select time step and go to counterfactual mode

    counterfactual mode:
        left:          turn counter-clockwise
        right:         turn clockwise
        up:            go forward
        space:         toggle
        pageup or x:   pickup
        pagedown or z: drop
        enter or q:    done (should not use)
        a:             undo action
        w:             roll out to the end of the episode and move to the
                       next episode
    """

    # Keys that drive the agent directly in counterfactual mode, mapped to
    # the action name passed to ``CounterfactualNavigator.forward``.
    _COUNTERFACTUAL_ACTION_KEYS = {
        'left': 'left',
        'right': 'right',
        'up': 'forward',
        ' ': 'toggle',
        'pageup': 'pickup',
        'x': 'pickup',
        'pagedown': 'drop',
        'z': 'drop',
        'enter': 'done',
        'q': 'done',
    }

    def __init__(self, original_dataset: TrajectoryDataset,
                 counterfactual_dataset: TrajectoryDataset,
                 policy_factory=None):
        """Set up the interface and immediately start the interactive loop.

        Args:
            original_dataset: dataset whose trajectories are browsed.
            counterfactual_dataset: dataset the edited trajectories are
                stored into.
            policy_factory: optional factory handed to
                ``CounterfactualNavigator``; when given, `w` in
                counterfactual mode rolls the episode out to the end.
        """
        self.policy_factory = policy_factory
        self.dataset = original_dataset
        self.counterfactual_dataset = counterfactual_dataset
        self.trajectory_generator = self.dataset.trajectory_generator()
        self.navigator: TrajectoryNavigator = None
        self.window = None
        self.is_counterfactual = False
        self.saved = False
        # NOTE: the constructor blocks here until every trajectory is done.
        self.run()

    def run(self):
        """Show each trajectory from the dataset in its own window.

        Raises:
            RuntimeError: if a trajectory's window returns before its
                counterfactual trajectory was saved (via `w`).
        """
        for i, trajectory in enumerate(self.trajectory_generator):
            self.saved = False
            self.is_counterfactual = False
            self.navigator = TrajectoryNavigator(trajectory)
            self.window = Window(f'Trajectory {i}')
            self.window.reg_key_handler(self.key_handler)
            self.reset()
            # Presumably blocks until the window is closed (by `w` in
            # counterfactual mode, or manually by the user).
            self.window.show(block=True)
            if not self.saved:
                raise RuntimeError('Continued without saving the trajectory!')

    def redraw(self):
        """Render the navigator's current state into the window."""
        step: TrajectoryStep = self.navigator.step()
        env = step.state
        img = env.render('rgb_array', tile_size=32)
        # TODO later: figure out when to use the agent observation
        # (step.observation['image']) instead of the full render.
        self.window.show_img(img)

    def step(self, action=None):
        """Advance one time step and redraw.

        Args:
            action: ``None`` in view mode (replay the next recorded step);
                otherwise an action name, which is only valid in
                counterfactual mode.
        """
        if action is None:
            self.navigator.forward()
        else:
            assert isinstance(self.navigator, CounterfactualNavigator)
            self.navigator.forward(action)
        self.redraw()

    def backward(self):
        """Go back one time step (undo in counterfactual mode) and redraw."""
        self.navigator.backward()
        self.redraw()

    def reset(self):
        """Show the mission (when the env has one) and draw the first frame."""
        env = self.navigator.step().state
        if hasattr(env, 'mission'):
            print('Mission: %s' % env.mission)
            self.window.set_caption(env.mission)
        self.redraw()

    def select(self):
        """Branch a counterfactual trajectory off the current time step."""
        new_navigator = CounterfactualNavigator(
            self.navigator.episode,
            self.navigator.index,
            self.navigator.step(),
            policy_factory=self.policy_factory)
        self.navigator = new_navigator
        self.is_counterfactual = True
        print(
            f'Starting counterfactual trajectory from {self.navigator.index}')
        self.redraw()

    def save_trajectory(self):
        """Store the counterfactual trajectory and mark it as saved."""
        assert isinstance(self.navigator, CounterfactualNavigator)
        self.navigator.store(self.counterfactual_dataset)
        self.saved = True

    def key_handler(self, event):
        """Dispatch a key press to the handler for the current mode.

        Args:
            event: key-press event from the window; only ``event.key`` is
                read (strings such as 'left', ' ', 'pageup').
        """
        import sys

        print('pressed', event.key)
        key = event.key
        if key == 'escape':
            self.window.close()
            # sys.exit rather than the site-injected exit(), which is meant
            # only for the interactive shell and may be missing.
            sys.exit()
        if self.is_counterfactual:
            self._handle_counterfactual_key(key)
        else:
            self._handle_view_key(key)

    def _handle_view_key(self, key):
        """Handle a key while browsing the original trajectory."""
        if key == 'd':
            self.step()
        elif key == 'a':
            self.backward()
        elif key == 'w':
            self.select()

    def _handle_counterfactual_key(self, key):
        """Handle a key while manually controlling the agent."""
        action = self._COUNTERFACTUAL_ACTION_KEYS.get(key)
        if action is not None:
            self.step(action)
        elif key == 'a':
            # Undo the most recent counterfactual action.
            self.backward()
        elif key == 'w':
            self._finish_counterfactual()

    def _finish_counterfactual(self):
        """Optionally roll out with the policy, save, and close the window."""
        if self.policy_factory is not None:
            self.navigator.rollout()
        self.save_trajectory()
        # Closing the window is expected to let `run` move on to the next
        # trajectory.
        self.window.close()
accuracy = torch.cat((accuracy, correct), 0) update += 1 # Print logs if update % args.log_interval == 0: if args.visualize: # Visualize last frame of last sample from gym_minigrid.window import Window window = Window('gym_minigrid - ' + args.env) images = images.transpose(1,2) images = images.transpose(2,3) print(images[-1].shape) print(label) window.show_img(images[-1]) input() window.close() duration = int(time.time() - start_time) header = ["Update", "Time", "Loss", "Accuracy"] acc = torch.mean(accuracy) data = [update, duration, sum(losses) / len(losses), acc] losses = [] over = (acc >= 0.9999) accuracy = torch.tensor([]).to(device) txt_logger.info( "U {} | T {} | L {:.3f} | A {:.4f}" .format(*data)) if status["update"] == 0: