Пример #1
0
    def plot_dynamics_model(self, option_name: str, model: np.ndarray,
                            save_path: Union[Path, str]):
        """Draw one transition-probability heatmap per grid cell.

        The model is averaged over its first and fourth axes (the direction
        dimensions, going by the shape note in the body), which leaves a
        single 2-D map per source cell. Each map is drawn in the subplot
        indexed by that cell. Cells whose entry in the normalized, transposed
        and vertically flipped environment grid is non-zero are skipped.

        Args:
            option_name: Name inserted into the figure title.
            model: Dynamics model, e.g. of shape (4, 19, 19, 4, 19, 19).
            save_path: Output path without extension; ``.html`` and ``.json``
                are appended for the two files that get written.

        Returns:
            The assembled plotly figure.
        """
        side = self.grid_size
        cells = [(row, col) for row in range(side) for col in range(side)]
        fig = self._remove_tick_labels(
            tools.make_subplots(cols=side, rows=side))

        env_grid = self._normalize_matrix(self._get_env_grid())
        occupancy = np.flip(0.99 * env_grid.T, axis=0)

        # Collapse e.g. (4, 19, 19, 4, 19, 19) into (19 * 19, 19, 19).
        map_shape = model.shape[-2:]
        averaged = np.mean(model, axis=(0, 3))
        per_cell_maps = averaged.reshape(np.prod(map_shape), *map_shape)

        for flat_idx, p_map in enumerate(per_cell_maps):
            row, col = cells[flat_idx]
            # Only cells with a zero grid entry get a heatmap.
            if occupancy[row, col] == 0:
                trace = go.Heatmap(z=np.flip(p_map, axis=0), showscale=False)
                fig.append_trace(trace, row, col)

        layout_options = dict(
            title=f'{option_name} dynamics model',
            height=1600,
            width=1600,
        )
        fig['layout'].update(**layout_options)

        plotly.offline.plot(fig, filename=f'{save_path}.html')
        plotlyfig2json(fig=fig, fpath=f'{save_path}.json')

        return fig
Пример #2
0
    def plot_option_value_function(self, option_name: str,
                                   option_state_values: np.ndarray,
                                   save_path: Union[Path, str]):
        """Plot the state values of an option as a row of heatmaps.

        Creates a single-row figure with one subplot per entry along the
        first axis of ``option_state_values`` (treated here as the direction
        axis). Each entry is flipped vertically before being drawn.

        Args:
            option_name: Used verbatim as the figure title.
            option_state_values: Array whose first axis enumerates the
                per-direction value maps; each slice is drawn as a heatmap.
            save_path: Output path without extension; ``.html`` and ``.json``
                are appended for the two files that get written.

        Returns:
            The assembled plotly figure, so callers can reuse or embed it.
        """
        directions = option_state_values.shape[0]
        fig = tools.make_subplots(cols=directions, rows=1)

        # One heatmap per direction; plotly subplot columns are 1-based.
        for direction, state_values in enumerate(option_state_values):
            heatmap = go.Heatmap(
                z=np.flip(state_values, axis=0),
                showscale=False,
                reversescale=False,
            )
            fig.append_trace(heatmap, 1, direction + 1)

        fig['layout'].update(title=f'{option_name}', height=400, width=800)
        plotly.offline.plot(fig, filename=f'{save_path}.html')
        plotlyfig2json(fig=fig, fpath=f'{save_path}.json')

        return fig
Пример #3
0
    def plot_reward_model(self, option_name: str, model: np.ndarray,
                          save_path: Union[Path, str]):
        """Plot a reward model as a single heatmap.

        The model is averaged over its first axis and flipped vertically
        before being drawn.

        Args:
            option_name: Name inserted into the figure title.
            model: Reward model array; its first axis is averaged away
                (presumably the direction axis; confirm against the caller).
            save_path: Output path without extension; ``.html`` and ``.json``
                are appended for the two files that get written.

        Returns:
            The assembled plotly figure.
        """
        mean_reward = np.mean(model, axis=0)
        heatmap = go.Heatmap(z=np.flip(mean_reward, axis=0), showscale=False)
        layout = dict(
            title=f'reward model of "{option_name}"',
            height=1600,
            width=1600,
        )
        fig = go.Figure(data=heatmap, layout=layout)

        plotly.offline.plot(fig, filename=f'{save_path}.html')
        plotlyfig2json(fig=fig, fpath=f'{save_path}.json')

        return fig
Пример #4
0
    def plot_models_sf_plotly(self, config, current_obs, successor_features,
                              option_name, option_id):
        """Plot the successor features of an option, one heatmap per state.

        For every (observation, successor feature) pair a copy of the
        normalized environment grid is made. Cells whose grid value is zero
        are overwritten with the normalized successor features, and the cell
        returned by ``self.wrapper.obs_to_state`` for the observation is set
        to 1 (the top of the colorscale, drawn red). Pairs whose subplot
        position is non-zero in the doubly flipped grid are skipped.

        Args:
            config: Provides ``expanded_obs_size`` (target shape for each
                successor feature) and ``images_path`` (output directory).
            current_obs: Iterable of observations, one per subplot.
            successor_features: Iterable of successor features, iterated in
                lockstep with ``current_obs``.
            option_name: Name inserted into the figure title.
            option_id: Index into ``self.colors``; also used in the output
                file names.

        Returns:
            The assembled plotly figure.
        """
        subplot_idx = [(i, j) for i in range(self.grid_size)
                       for j in range(self.grid_size)]
        fig = tools.make_subplots(cols=self.grid_size, rows=self.grid_size)
        fig = self._remove_tick_labels(fig)
        full_grid = 0.99 * self._normalize_matrix(self._get_env_grid())

        # Loop invariants: full_grid is never modified below (only copies
        # are), so the flipped view, the zero-cell indices and the
        # colorscale are computed once instead of once per subplot.
        flipped_grid = np.flip(np.flip(full_grid, axis=1), axis=0)
        free_cells = np.where(full_grid == 0)
        colorscale = [[0., 'white'], [0.0999, self.colors[option_id]],
                      [0.1, 'grey'], [0.999, 'black'], [1., 'red']]

        for i, (current_state,
                sfeat) in enumerate(zip(current_obs, successor_features)):
            row, col = subplot_idx[i]
            if flipped_grid[row, col] != 0:
                continue

            sfeat = np.reshape(sfeat, config.expanded_obs_size)
            sfeat = 0.99 * self._normalize_matrix(sfeat)

            # Overlay the successor features on the zero-valued grid cells
            # in one indexed assignment instead of a per-cell Python loop.
            grid = full_grid.copy()
            grid[free_cells] = np.asarray(sfeat)[free_cells]
            grid[self.wrapper.obs_to_state(current_state)] = 1

            heatmap = go.Heatmap(
                z=np.flip(grid, axis=0),
                showscale=False,
                colorscale=colorscale,
            )
            fig.append_trace(heatmap, row, col)

        fig['layout'].update(title=f'{option_name} SF', height=800, width=800)

        plotlyfig2json(
            fig,
            f'{config.images_path}/learn_models/option_{option_id}_sf.json')
        plotly.offline.plot(
            fig,
            filename=
            f'{config.images_path}/learn_models/option_{option_id}_sf.html')

        return fig
Пример #5
0
    def plot_goals_plotly(self, config):
        """Plot the goal reward map of every option in a one-row figure.

        For each option the goal reward map is reshaped to
        ``config.expanded_obs_size``, zeroed on zero-valued grid cells that
        are not in ``self.valid_cells``, normalized, and overlaid on the
        zero-valued cells of the normalized environment grid.

        The stored reward maps are left untouched: zeroing happens on a
        private copy.

        Args:
            config: Provides ``nb_options``, ``expanded_obs_size`` and
                ``intrinsic_motivation.goal_reward_maps``.

        Returns:
            The assembled plotly figure (nothing is written to disk here).
        """
        fig = tools.make_subplots(cols=config.nb_options, rows=1)
        full_grid = 0.99 * self._normalize_matrix(self._get_env_grid())
        axis_template = dict(showgrid=False,
                             zeroline=False,
                             showticklabels=False,
                             ticks='')

        # The zero-valued grid cells do not depend on the option, so they
        # are located once and reused for every subplot.
        free_index = np.where(full_grid == 0)
        free_cells = np.array(free_index).T

        for o in range(config.nb_options):
            # np.reshape may return a view of the stored reward map; copy it
            # so the zeroing below cannot modify the configuration's data.
            goals = np.array(
                np.reshape(config.intrinsic_motivation.goal_reward_maps[o],
                           config.expanded_obs_size))

            # Ensure that invalid cells do not affect normalization.
            for j, k in free_cells:
                if (j, k) not in self.valid_cells:
                    goals[j, k] = 0
            goals = 0.099 * self._normalize_matrix(goals)

            # Merge the full grid with the goals heatmap.
            grid = full_grid.copy()
            grid[free_index] = np.asarray(goals)[free_index]

            heatmap = go.Heatmap(
                z=np.flip(grid, axis=0),
                showscale=False,
                colorscale=[[0., 'white'], [0.1, self.colors[o]],
                            [0.10001, 'grey'], [1., 'black']],
            )
            fig.append_trace(heatmap, 1, o + 1)
            fig['layout'][f'xaxis{o + 1}'].update(axis_template)
            fig['layout'][f'yaxis{o + 1}'].update(axis_template)

        fig['layout'].update(title='Goals',
                             width=300 * config.nb_options,
                             height=500)

        return fig