示例#1
0
    def plot_dynamics_model(self, option_name: str, model: np.ndarray,
                            save_path: Union[Path, str]):
        """ Plots probability transition matrices for each state as subplots.

        The model is averaged over its first and fourth axes (the direction
        dimensions) for simpler visualization, which leaves one 2-D map per
        source cell. Cells whose entry in the transposed, vertically flipped
        environment grid is non-zero are skipped and keep an empty subplot.

        Args:
            option_name: Name shown in the figure title.
            model: Dynamics model; its last two axes give the shape of each
                per-state map, e.g. ``(4, 19, 19, 4, 19, 19)``.
            save_path: Output path without extension; ``.html`` and ``.json``
                are appended for the two files that get written.

        Returns:
            The figure holding one heatmap per plotted state.
        """
        fig = tools.make_subplots(cols=self.grid_size, rows=self.grid_size)
        fig = self._remove_tick_labels(fig)
        full_grid = 0.99 * self._normalize_matrix(self._get_env_grid()).T
        full_grid = np.flip(full_grid, axis=0)

        # Turn (4, 19, 19, 4, 19, 19) into (19*19, 19, 19)
        dim = model.shape[-2:]
        model = np.mean(model, axis=(0, 3)).reshape(np.prod(dim), *dim)

        for i, p_map in enumerate(model):
            # Row-major position of the i-th state inside the square grid.
            row, col = divmod(i, self.grid_size)
            if full_grid[row, col] != 0:
                continue

            heatmap = go.Heatmap(
                z=np.flip(p_map, axis=0),
                showscale=False,
            )
            # Subplot coordinates start at (1, 1), whereas (row, col) is
            # 0-based; without the shift a free cell in the first row or
            # column would be rejected and all maps would be off by one.
            fig.append_trace(heatmap, row + 1, col + 1)

        fig['layout'].update(title=f'{option_name} dynamics model',
                             height=1600,
                             width=1600)
        plotly.offline.plot(fig, filename=f'{save_path}.html')
        plotlyfig2json(fig=fig, fpath=f'{save_path}.json')

        return fig
示例#2
0
    def plot_option_value_function(self, option_name: str,
                                   option_state_values: np.ndarray,
                                   save_path: Union[Path, str]):
        """ Plots V of an agent associated with a particular mini-grid.

        Creates a figure with a single row of subplots, one heatmap for each
        entry along the first axis of ``option_state_values`` (one per
        direction). Each map is flipped vertically before plotting.

        Args:
            option_name: Name used as the figure title.
            option_state_values: State values; the first axis indexes the
                direction and each slice is drawn as one heatmap.
            save_path: Output path without extension; ``.html`` and ``.json``
                are appended for the two files that get written.

        Returns:
            The assembled figure, for consistency with the other model
            plotting methods.
        """
        directions = option_state_values.shape[0]
        fig = tools.make_subplots(cols=directions, rows=1)

        # One heatmap per direction, laid out left to right
        for direction, state_values in enumerate(option_state_values):
            heatmap = go.Heatmap(
                z=np.flip(state_values, axis=0),
                showscale=False,
                reversescale=False,
            )
            # Subplot columns are 1-based, hence the shift
            fig.append_trace(heatmap, 1, direction + 1)

        fig['layout'].update(title=f'{option_name}', height=400, width=800)
        plotly.offline.plot(fig, filename=f'{save_path}.html')
        plotlyfig2json(fig=fig, fpath=f'{save_path}.json')

        return fig
示例#3
0
    def plot_qs_plotly(self, action_values_options, current_states_flat,
                       config, ep):
        """ Saves an annotated value/policy heatmap for every learned option.

        For each option (all entries of ``config.options`` except the last
        ``config.action_size`` ones) the action values are reshaped onto the
        grid, reduced to state values with a max over the last axis, and
        merged with the environment grid. Every cell is annotated with the
        symbol of the action picked by the option's target policy. The
        figure is written as json and html under
        ``{config.images_path}/learn_options``.

        Args:
            action_values_options: Action values, one entry per option.
            current_states_flat: Not used by this method.
            config: Provides ``options``, ``action_size``,
                ``expanded_obs_size`` and ``images_path``.
            ep: Episode identifier embedded in the output file names.

        Returns:
            An empty list.
        """
        cell_coords = [(row, col) for row in range(self.grid_size)
                       for col in range(self.grid_size)]
        full_grid = 0.99 * self._normalize_matrix(self._get_env_grid())

        # Index arrays of the occupied and the free cells; the grid does not
        # change between options, so look them up once.
        wall_cells = np.nonzero(full_grid != 0)
        free_cells = np.nonzero(full_grid == 0)

        learned_options = config.options[:-config.action_size]
        for option_id, (option_instance, q) in enumerate(
                zip(learned_options, action_values_options)):

            q_shape = (*config.expanded_obs_size, config.action_size + 1)
            q = np.reshape(q, q_shape)
            v = np.max(q, axis=-1)

            # Walls carry high q-values, so blank them out before normalizing
            v[wall_cells] = 0
            v = 0.099 * self._normalize_matrix(v)

            # Policy annotations: one action symbol per grid cell
            symbols = [
                self.unicode_actions[
                    option_instance.target_policy(q[row, col])[0]]
                for row, col in cell_coords
            ]
            symbols = np.array(symbols).reshape(self.dim).tolist()

            # Overlay the state values onto the free cells of the grid
            grid = full_grid.copy()
            grid[free_cells] = v[free_cells]

            # Build the annotated heatmap
            heatmap = ff.create_annotated_heatmap(
                np.flip(grid, axis=0),
                showscale=False,
                annotation_text=symbols[::-1],
                font_colors=['black'],
                colorscale=[[0., 'white'], [0.1, self.colors[option_id]],
                            [0.10001, 'grey'], [1., 'black']],
            )

            # Enlarge the annotation font so that the policy is readable
            for annotation in heatmap.layout.annotations:
                annotation.font.size = 25

            # Persist as json (for the dashboard) and html (for a quick view)
            path = f'{config.images_path}/learn_options'
            os.makedirs(f'{path}/json', exist_ok=True)
            os.makedirs(f'{path}/html', exist_ok=True)
            plotlyfig2json(heatmap,
                           fpath=f'{path}/json/qs_{option_id}_{ep}.json')
            plotly.offline.plot(
                heatmap, filename=f'{path}/html/qs_{option_id}_{ep}.html')

        return []
示例#4
0
    def plot_reward_model(self, option_name: str, model: np.ndarray,
                          save_path: Union[Path, str]):
        """ Plots the reward model of an option as a single heatmap.

        The model is averaged over its first axis and flipped vertically
        before it is drawn.

        Args:
            option_name: Name inserted into the figure title.
            model: Reward model; its first axis is averaged out.
            save_path: Output path without extension; ``.html`` and ``.json``
                are appended for the two files that get written.

        Returns:
            The figure containing the heatmap.
        """
        mean_reward = np.mean(model, axis=0)
        heatmap = go.Heatmap(
            z=np.flip(mean_reward, axis=0),
            showscale=False,
        )
        layout = dict(
            title=f'reward model of "{option_name}"',
            height=1600,
            width=1600,
        )
        fig = go.Figure(data=heatmap, layout=layout)

        plotly.offline.plot(fig, filename=f'{save_path}.html')
        plotlyfig2json(fig=fig, fpath=f'{save_path}.json')

        return fig
示例#5
0
    def plot_models_sf_plotly(self, config, current_obs, successor_features,
                              option_name, option_id):
        """ Plots the successor features of every state as a grid of subplots.

        For each (observation, successor feature) pair the features are
        reshaped to ``config.expanded_obs_size``, normalized and written into
        the free cells of the environment grid; the cell of the current state
        is set to 1 so that it stands out. Pairs whose position is non-zero in
        the grid flipped along both axes are skipped. The figure is saved as
        json and html under ``{config.images_path}/learn_models``.

        Args:
            config: Provides ``expanded_obs_size`` and ``images_path``.
            current_obs: Observations, one per subplot, in row-major order.
            successor_features: Successor features matching ``current_obs``.
            option_name: Name shown in the figure title.
            option_id: Index used for the option color and the file names.

        Returns:
            The assembled figure.
        """
        fig = tools.make_subplots(cols=self.grid_size, rows=self.grid_size)
        fig = self._remove_tick_labels(fig)
        full_grid = 0.99 * self._normalize_matrix(self._get_env_grid())

        # Neither of these depends on the state, so compute them once.
        flipped_grid = np.flip(np.flip(full_grid, axis=1), axis=0)
        free_cells = np.nonzero(full_grid == 0)

        for i, (current_state,
                sfeat) in enumerate(zip(current_obs, successor_features)):
            # Row-major position of the i-th state inside the square grid.
            row, col = divmod(i, self.grid_size)
            if flipped_grid[row, col] != 0:
                continue

            sfeat = np.reshape(sfeat, config.expanded_obs_size)
            # NOTE(review): the colorscale below switches to grey at 0.1,
            # while this scaling lets features reach 0.99; plot_qs_plotly
            # scales by 0.099 for a similar colorscale - confirm the factor.
            sfeat = 0.99 * self._normalize_matrix(sfeat)

            # Overlay the features onto the free cells and mark the state
            grid = full_grid.copy()
            grid[free_cells] = sfeat[free_cells]
            grid[self.wrapper.obs_to_state(current_state)] = 1

            heatmap = go.Heatmap(
                z=np.flip(grid, axis=0),
                showscale=False,
                colorscale=[[0., 'white'], [0.0999, self.colors[option_id]],
                            [0.1, 'grey'], [0.999, 'black'], [1., 'red']],
            )
            # Subplot coordinates start at (1, 1), whereas (row, col) is
            # 0-based; without the shift a plotted cell in the first row or
            # column would be rejected and all maps would be off by one.
            fig.append_trace(heatmap, row + 1, col + 1)

        fig['layout'].update(title=f'{option_name} SF', height=800, width=800)

        plotlyfig2json(
            fig,
            f'{config.images_path}/learn_models/option_{option_id}_sf.json')
        plotly.offline.plot(
            fig,
            filename=
            f'{config.images_path}/learn_models/option_{option_id}_sf.html')

        return fig