def display_vehicles_attention(cls, agent, sim_surface):
    """
        Overlay the attention paid by the agent to surrounding vehicles on the simulation surface.

        One translucent layer is blitted per attention head. An attended vehicle is linked to the
        ego-vehicle by a line; attention on the ego-vehicle itself is drawn as a circle. Width,
        saturation and opacity all grow with the attention weight. Weights below cls.MIN_ATTENTION
        are skipped.

        The per-vehicle attention is cached on the class (cls.state / cls.v_attention). It is only
        recomputed when agent.previous_state changes.

    :param agent: the agent whose attention is displayed
    :param sim_surface: the pygame surface on which the env is rendered
    """
    import pygame
    try:
        state = agent.previous_state
        # np.array_equal returns False on a shape change (e.g. a different number of observed
        # entities), whereas an element-wise != between mismatched shapes raises.
        if not hasattr(cls, "state") or not np.array_equal(cls.state, state):
            cls.v_attention = cls.compute_vehicles_attention(agent, state)
            # Keep a snapshot, so that in-place updates of previous_state are still detected.
            cls.state = np.copy(state)
        if not cls.v_attention:
            return  # no observed vehicle could be matched: nothing to draw
        n_heads = next(iter(cls.v_attention.values())).shape[0]
        for head in range(n_heads):
            attention_surface = pygame.Surface(sim_surface.get_size(), pygame.SRCALPHA)
            for vehicle, attention in cls.v_attention.items():
                weight = attention[head]
                if weight < cls.MIN_ATTENTION:
                    continue
                width = weight * 5
                desat = remap(weight, (0, 0.5), (0.7, 1), clip=True)
                colors = sns.color_palette("dark", desat=desat)
                # Wrap the palette index so any number of heads gets a valid colour
                # (same colour as before whenever the old index was in range).
                color = np.array(colors[(2 - 2 * head) % len(colors)]) * 255
                color = (*color, remap(weight, (0, 0.5), (100, 200), clip=True))
                if vehicle is agent.env.vehicle:
                    pygame.draw.circle(attention_surface, color,
                                       sim_surface.vec2pix(agent.env.vehicle.position),
                                       max(sim_surface.pix(width / 2), 1))
                else:
                    pygame.draw.line(attention_surface, color,
                                     sim_surface.vec2pix(agent.env.vehicle.position),
                                     sim_surface.vec2pix(vehicle.position),
                                     max(sim_surface.pix(width), 1))
            sim_surface.blit(attention_surface, (0, 0))
    except ValueError as e:
        print("Unable to display vehicles attention", e)
def compute_vehicles_attention(cls, agent, state):
    """
        Associate every observed (non-masked) entity of an observation with the closest road vehicle,
        together with the attention weights it receives.

        The entity's "x"/"y" features are mapped from [-1, 1] back to the observation's
        features_range. When the observation is not absolute, every entity except index 0 is offset
        by the ego-vehicle position (presumably row 0 is already expressed in absolute coordinates;
        confirm against the observation type).

    :param agent: the agent, providing device, value_net (get_attention_matrix, split_input) and env
    :param state: the observation, indexed as state[entity_index, feature_index]
    :return: a dict {vehicle: attention[:, entity_index]}, one weight per attention head
    """
    import torch
    state_t = torch.tensor(np.asarray(state), dtype=torch.float).unsqueeze(0).to(agent.device)
    with torch.no_grad():  # inference only: do not build an autograd graph
        attention = agent.value_net.get_attention_matrix(state_t).squeeze(0).squeeze(1).detach().cpu().numpy()
        _, _, mask = agent.value_net.split_input(state_t)
        # Flatten instead of squeeze(), so a single-entity observation still yields a 1-D mask.
        # Move it to the host once, rather than syncing on every element access.
        mask = mask.reshape(-1).cpu().numpy()

    observation = agent.env.observation
    x_index = observation.features.index("x")
    y_index = observation.features.index("y")
    x_range = observation.features_range["x"]
    y_range = observation.features_range["y"]

    v_attention = {}
    for v_index in range(state.shape[0]):
        if mask[v_index]:
            continue
        v_position = np.array([remap(state[v_index, x_index], [-1, 1], x_range),
                               remap(state[v_index, y_index], [-1, 1], y_range)])
        if not observation.absolute and v_index > 0:
            v_position += agent.env.unwrapped.vehicle.position
        vehicle = min(agent.env.road.vehicles, key=lambda v: np.linalg.norm(v.position - v_position))
        v_attention[vehicle] = attention[:, v_index]
    return v_attention
def _plot_node(self, node, pos, ax, depth=0):
    """
        Recursively draw the edges from a tree node to its visited children.

        The children of a node share a horizontal span of 1 / actions**depth centred on the node,
        one slot per action, and sit one level (1 / max_depth) lower. Children with a zero count are
        skipped. Edge width grows from 0.5 to 4 with the child's count relative to self.total_count.

    :param node: the node to draw; its children are looked up by action in node.children
    :param pos: [x, y] plot coordinates of the node
    :param ax: the matplotlib axes to draw on
    :param depth: depth of node in the tree; recursion stops beyond self.max_depth
    """
    if depth > self.max_depth or self.actions < 1:
        return
    d = 1 / self.actions**depth
    dy = 1 / max(self.max_depth, 1)  # avoid dividing by zero when max_depth == 0
    for a in range(self.actions):
        if a not in node.children:
            continue
        child = node.children[a]
        if not child.count:
            continue
        # With a single action there is one slot: centre the child under its parent
        # instead of dividing by zero.
        slot = a / (self.actions - 1) if self.actions > 1 else 0.5
        pos_child = [pos[0] - d / 2 + slot * d, pos[1] - dy]
        width = constrain(remap(child.count, (1, self.total_count), (0.5, 4)), 0.5, 4)
        ax.plot([pos[0], pos_child[0]], [pos[1], pos_child[1]], 'k', linewidth=width, solid_capstyle='round')
        self._plot_node(child, pos_child, ax, depth + 1)
def display(cls, agent, surface, sim_surface=None, display_text=True):
    """
        Display the action-values for the current state

        Each action gets a cell coloured by its value, on a jet_r scale from 0 to 1 / (1 - gamma),
        optionally annotated with the value and the action probability. When a simulation surface is
        given and the value network exposes get_attention_matrix, the attention over surrounding
        vehicles is overlaid on it as well.

    :param agent: the DQNAgent to be displayed
    :param surface: the pygame surface on which the agent is displayed
    :param sim_surface: the pygame surface on which the env is rendered
    :param display_text: whether to display the action values as text
    """
    import pygame
    action_values = agent.get_state_action_values(agent.previous_state)
    action_distribution = agent.action_distribution(agent.previous_state)

    cell_size = (surface.get_width() // len(action_values), surface.get_height())
    pygame.draw.rect(surface, cls.BLACK, (0, 0, surface.get_width(), surface.get_height()), 0)

    # Display node value. The colour scale and font are loop-invariant, so build them once.
    cmap = cm.jet_r
    norm = mpl.colors.Normalize(vmin=0, vmax=1 / (1 - agent.config["gamma"]))
    font = pygame.font.Font(None, 15) if display_text else None
    for action, value in enumerate(action_values):
        color = cmap(norm(value), bytes=True)
        pygame.draw.rect(surface, color, (cell_size[0] * action, 0, cell_size[0], cell_size[1]), 0)
        if display_text:
            text = "v={:.2f} / p={:.2f}".format(value, action_distribution[action])
            text = font.render(text, 1, (10, 10, 10), (255, 255, 255))
            surface.blit(text, (cell_size[0] * action, 0))

    # Vehicle attention only exists for networks exposing an attention matrix; other value
    # networks would otherwise crash here with an AttributeError.
    if sim_surface and hasattr(getattr(agent, "value_net", None), "get_attention_matrix"):
        try:
            state = agent.previous_state
            # np.array_equal returns False on a shape change instead of raising like element-wise !=
            if not hasattr(cls, "state") or not np.array_equal(cls.state, state):
                cls.v_attention = cls.compute_vehicles_attention(agent, state)
                cls.state = np.copy(state)  # snapshot, so in-place state updates are detected
            if cls.v_attention:  # an empty dict has no head count: nothing to draw
                n_heads = next(iter(cls.v_attention.values())).shape[0]
                for head in range(n_heads):
                    attention_surface = pygame.Surface(sim_surface.get_size(), pygame.SRCALPHA)
                    for vehicle, attention in cls.v_attention.items():
                        weight = attention[head]
                        if weight < cls.MIN_ATTENTION:
                            continue
                        width = weight * 5
                        desat = remap(weight, (0, 0.5), (0.7, 1), clip=True)
                        colors = sns.color_palette("dark", desat=desat)
                        # Wrap the palette index so any number of heads gets a valid colour.
                        color = np.array(colors[(2 - 2 * head) % len(colors)]) * 255
                        color = (*color, remap(weight, (0, 0.5), (100, 200), clip=True))
                        if vehicle is agent.env.vehicle:
                            pygame.draw.circle(attention_surface, color,
                                               sim_surface.vec2pix(agent.env.vehicle.position),
                                               max(sim_surface.pix(width / 2), 1))
                        else:
                            pygame.draw.line(attention_surface, color,
                                             sim_surface.vec2pix(agent.env.vehicle.position),
                                             sim_surface.vec2pix(vehicle.position),
                                             max(sim_surface.pix(width), 1))
                    sim_surface.blit(attention_surface, (0, 0))
        except ValueError as e:
            print("Unable to display vehicles attention", e)