def display_vehicles_attention(cls, agent, sim_surface):
        """
            Overlay the attention given to each vehicle on the simulation surface

        One transparent layer is drawn per attention head and blitted onto sim_surface.
        For every vehicle whose attention reaches cls.MIN_ATTENTION, a circle is drawn at
        agent.env.vehicle when the vehicle is agent.env.vehicle itself, and otherwise a
        line from agent.env.vehicle to that vehicle. Line width, colour desaturation and
        opacity are all derived from the attention value.

        The attention map is cached on the class (cls.state, cls.v_attention) and only
        recomputed when agent.previous_state differs from the cached state.

        :param agent: the agent to be displayed; its previous_state and env are read here
        :param sim_surface: the pygame surface on which the env is rendered
        """
        import pygame
        try:
            state = agent.previous_state
            if (not hasattr(cls, "state")) or (cls.state != state).any():
                cls.v_attention = cls.compute_vehicles_attention(agent, state)
                cls.state = state

            # The number of heads is read from the first entry. An empty attention map has
            # no first entry: indexing it would raise an IndexError that the handler below
            # does not catch, so nothing is drawn in that case.
            if cls.v_attention:
                n_heads = next(iter(cls.v_attention.values())).shape[0]
            else:
                n_heads = 0

            for head in range(n_heads):
                # Per-pixel alpha layer, so that overlapping heads blend on the simulation
                attention_surface = pygame.Surface(sim_surface.get_size(), pygame.SRCALPHA)
                for vehicle, attention in cls.v_attention.items():
                    weight = attention[head]
                    if weight < cls.MIN_ATTENTION:
                        continue
                    width = weight * 5
                    desat = remap(weight, (0, 0.5), (0.7, 1), clip=True)
                    colors = sns.color_palette("dark", desat=desat)
                    # Palette entry 2 for head 0, entry 0 for head 1
                    color = np.array(colors[2-2*head]) * 255
                    # Append the alpha channel to the RGB colour
                    color = (*color, remap(weight, (0, 0.5), (100, 200), clip=True))
                    if vehicle is agent.env.vehicle:
                        pygame.draw.circle(attention_surface, color,
                                           sim_surface.vec2pix(agent.env.vehicle.position),
                                           max(sim_surface.pix(width / 2), 1))
                    else:
                        pygame.draw.line(attention_surface, color,
                                         sim_surface.vec2pix(agent.env.vehicle.position),
                                         sim_surface.vec2pix(vehicle.position),
                                         max(sim_surface.pix(width), 1))
                sim_surface.blit(attention_surface, (0, 0))
        except ValueError as e:
            print("Unable to display vehicles attention", e)
 def compute_vehicles_attention(cls, agent, state):
     """
         Associate each observed vehicle with the attention it receives

     The state is fed to agent.value_net to get the attention matrix and the mask of
     the observation rows. For every row that is not masked, the "x" and "y" features
     are remapped from [-1, 1] to the observation's features_range, shifted by the
     position of agent.env.unwrapped.vehicle when the observation is not absolute and
     the row is not the first one, and then matched to the closest vehicle on the road.

     :param agent: the agent providing the value network, the device and the env
     :param state: the observation, indexed as state[row, feature]
     :return: a dict mapping each matched road vehicle to attention[:, row]
     """
     import torch
     batch = torch.tensor([state], dtype=torch.float).to(agent.device)
     weights = agent.value_net.get_attention_matrix(batch)
     weights = weights.squeeze(0).squeeze(1).detach().cpu().numpy()
     # Only the mask of the split input is needed here
     _, _, mask = agent.value_net.split_input(batch)
     mask = mask.squeeze()

     attention_by_vehicle = {}
     for row in range(state.shape[0]):
         if mask[row]:
             # Masked rows are skipped
             continue
         observation = agent.env.observation
         location = np.array([
             remap(state[row, observation.features.index(axis)],
                   [-1, 1],
                   observation.features_range[axis])
             for axis in ("x", "y")
         ])
         if not observation.absolute and row > 0:
             location += agent.env.unwrapped.vehicle.position
         # The observation row is matched to the road vehicle nearest to its position
         nearest = min(
             agent.env.road.vehicles,
             key=lambda v: np.linalg.norm(v.position - location))
         attention_by_vehicle[nearest] = weights[:, row]
     return attention_by_vehicle
Beispiel #3
0
 def _plot_node(self, node, pos, ax, depth=0):
     """
         Recursively plot the edges from a node to its visited children

     Children are spread horizontally over a span of 1 / actions**depth centred on
     the parent, and placed 1 / max_depth below it. The line width of each edge is
     derived from the child's count relative to self.total_count, kept within
     [0.5, 4].

     :param node: the node to expand; node.children is looked up by action index, and
                  children with a zero count are skipped
     :param pos: the [x, y] position of the node in the plot
     :param ax: the axes-like object whose plot method draws the edges
     :param depth: the depth of the node; nothing is drawn past self.max_depth
     """
     if depth > self.max_depth:
         return
     for a in range(self.actions):
         if a not in node.children:
             continue
         child = node.children[a]
         if not child.count:
             continue
         d = 1 / self.actions**depth
         # With a single action, a / (actions - 1) would divide by zero: the only child
         # is centred under its parent instead.
         ratio = a / (self.actions - 1) if self.actions > 1 else 0.5
         pos_child = [pos[0] - d/2 + ratio*d, pos[1] - 1/self.max_depth]
         width = constrain(remap(child.count, (1, self.total_count), (0.5, 4)), 0.5, 4)
         ax.plot([pos[0], pos_child[0]], [pos[1], pos_child[1]], 'k', linewidth=width, solid_capstyle='round')
         self._plot_node(child, pos_child, ax, depth+1)
Beispiel #4
0
    def display(cls, agent, surface, sim_surface=None, display_text=True):
        """
            Display the action-values for the current state

        The agent surface is split into one cell per action, each filled with a colour
        taken from the reversed jet colormap, normalised between 0 and
        1 / (1 - gamma). When a simulation surface is given, the attention over the
        vehicles is drawn on top of it, one transparent layer per attention head.

        :param agent: the DQNAgent to be displayed
        :param surface: the pygame surface on which the agent is displayed
        :param sim_surface: the pygame surface on which the env is rendered
        :param display_text: whether to display the action values as text
        """
        import pygame
        action_values = agent.get_state_action_values(agent.previous_state)
        action_distribution = agent.action_distribution(agent.previous_state)

        cell_size = (surface.get_width() // len(action_values),
                     surface.get_height())
        pygame.draw.rect(surface, cls.BLACK,
                         (0, 0, surface.get_width(), surface.get_height()), 0)

        # The colour scale and the font do not depend on the action: build them once
        cmap = cm.jet_r
        norm = mpl.colors.Normalize(vmin=0,
                                    vmax=1 / (1 - agent.config["gamma"]))
        font = pygame.font.Font(None, 15) if display_text else None

        # Display node value
        for action, value in enumerate(action_values):
            color = cmap(norm(value), bytes=True)
            pygame.draw.rect(
                surface, color,
                (cell_size[0] * action, 0, cell_size[0], cell_size[1]), 0)

            if display_text:
                text = "v={:.2f} / p={:.2f}".format(
                    value, action_distribution[action])
                text = font.render(text, 1, (10, 10, 10), (255, 255, 255))
                surface.blit(text, (cell_size[0] * action, 0))

        if sim_surface:
            try:
                state = agent.previous_state
                # The attention is cached on the class and recomputed on state change
                if (not hasattr(cls, "state")) or (cls.state != state).any():
                    cls.v_attention = cls.compute_vehicles_attention(
                        agent, state)
                    cls.state = state

                # The number of heads is read from the first entry. An empty attention
                # map has no first entry: indexing it would raise an IndexError that the
                # handler below does not catch, so nothing is drawn in that case.
                if cls.v_attention:
                    n_heads = next(iter(cls.v_attention.values())).shape[0]
                else:
                    n_heads = 0

                for head in range(n_heads):
                    attention_surface = pygame.Surface(sim_surface.get_size(),
                                                       pygame.SRCALPHA)
                    for vehicle, attention in cls.v_attention.items():
                        weight = attention[head]
                        if weight < cls.MIN_ATTENTION:
                            continue
                        width = weight * 5
                        desat = remap(weight, (0, 0.5), (0.7, 1),
                                      clip=True)
                        colors = sns.color_palette("dark", desat=desat)
                        # Palette entry 2 for head 0, entry 0 for head 1
                        color = np.array(colors[2 - 2 * head]) * 255
                        # Append the alpha channel to the RGB colour
                        color = (*color,
                                 remap(weight, (0, 0.5), (100, 200),
                                       clip=True))
                        if vehicle is agent.env.vehicle:
                            pygame.draw.circle(
                                attention_surface, color,
                                sim_surface.vec2pix(
                                    agent.env.vehicle.position),
                                max(sim_surface.pix(width / 2), 1))
                        else:
                            pygame.draw.line(
                                attention_surface, color,
                                sim_surface.vec2pix(
                                    agent.env.vehicle.position),
                                sim_surface.vec2pix(vehicle.position),
                                max(sim_surface.pix(width), 1))
                    sim_surface.blit(attention_surface, (0, 0))
            except ValueError as e:
                print("Unable to display vehicles attention", e)