예제 #1
0
def plot_features(ro: StepSequence, policy: Policy):
    """
    Plot all features given the policy and the observation trajectories.

    The features are evaluated on the rollout's observations, every feature dimension is drawn into its own
    subplot, and the figure is displayed via `plt.show()`. Nothing is plotted if the policy is not a
    `LinearPolicy`, if the rollout has no `observations` attribute, or if the policy yields no features.

    :param ro: input rollout
    :param policy: linear policy used during the rollout
    """
    if not isinstance(policy, LinearPolicy):
        print_cbt(
            'Plotting of the feature values is only supports linear policies!',
            'y')
        return

    if not hasattr(ro, 'observations'):
        return

    # Use recorded time stamps if possible
    t = ro.env_infos.get('t', np.arange(0, ro.length)) if hasattr(
        ro, 'env_infos') else np.arange(0, ro.length)

    # Recover the features from the observations
    feat_vals = policy.eval_feats(to.from_numpy(ro.observations))
    num_feat = feat_vals.shape[1]
    if num_feat == 0:
        # Without this guard, the grid computation below would divide by zero
        print_cbt('There are no features to plot!', 'y')
        return

    # Choose the grid layout depending on the number of features
    if num_feat <= 6:
        divisor = 2
    elif num_feat <= 12:
        divisor = 4
    else:
        divisor = 8
    num_cols = int(np.ceil(num_feat / divisor))
    num_rows = int(np.ceil(num_feat / num_cols))

    # squeeze=False always yields a 2-dim array of axes, so the indexing below also works for grids with a
    # single row or a single column (e.g. 2 features result in a 2x1 grid)
    fig, axs = plt.subplots(num_rows,
                            num_cols,
                            figsize=(num_cols * 5, num_rows * 3),
                            constrained_layout=True,
                            squeeze=False)
    fig.suptitle('Feature values over Time')
    # NOTE(review): matplotlib documents constrained layout as incompatible with subplots_adjust; confirm which
    # of the two is actually intended here
    plt.subplots_adjust(hspace=.5)
    colors = plt.get_cmap('tab20')(np.linspace(0, 1, num_feat))

    if num_feat == 1:
        # Omit the last observation for simplicity
        axs[0, 0].plot(t, feat_vals[:-1, 0], label=_get_obs_label(ro, 0))
        axs[0, 0].legend()
    else:
        # Iterating over the flattened axes is row-major, i.e. idx == col + row * num_cols
        for idx, ax in enumerate(axs.flat):
            if idx >= num_feat:
                # We might create more subplots than there are features, drop the unused ones
                ax.remove()
                continue
            # Omit the last observation for simplicity. The triple braces put the complete index into the
            # subscript group, otherwise only the first digit of a multi-digit index would be subscripted.
            ax.plot(t,
                    feat_vals[:-1, idx],
                    label=rf'$\phi_{{{idx}}}$',
                    c=colors[idx])
            ax.legend()
    plt.show()
예제 #2
0
def plot_features(ro: StepSequence, policy: Policy):
    """
    Visualize how the features of a linear policy evolve along a rollout.

    The features are evaluated on the rollout's observations and every feature dimension is drawn into its own
    subplot. The function returns without plotting if the policy is not a `LinearPolicy` or if the rollout has
    no `observations` attribute.

    :param ro: input rollout
    :param policy: linear policy used during the rollout
    """
    # Only linear policies are handled here
    if not isinstance(policy, LinearPolicy):
        print_cbt("Plotting of the feature values is only supports linear policies!", "r")
        return

    # Without observations there is nothing to evaluate the features on
    if not hasattr(ro, "observations"):
        return

    # Prefer the recorded time stamps and fall back to the step indices; the final entry is dropped either way
    time_stamps = getattr(ro, "time", np.arange(0, ro.length + 1))
    t = time_stamps[:-1]

    # Recover the features from the observations
    features = policy.eval_feats(to.from_numpy(ro.observations))
    num_feat = len(range(features.shape[1]))

    # Determine the subplot grid from the number of features
    if num_feat > 12:
        divisor = 8
    elif num_feat > 6:
        divisor = 4
    else:
        divisor = 2
    num_cols = int(np.ceil(num_feat / divisor))
    num_rows = int(np.ceil(num_feat / num_cols))

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 5, num_rows * 3), tight_layout=True)
    # Make sure the axes can be indexed by row and column
    axs = correct_atleast_2d(np.atleast_2d(axs))
    fig.canvas.manager.set_window_title("Feature Values over Time")
    plt.subplots_adjust(hspace=0.5)
    colors = plt.get_cmap("tab20")(np.linspace(0, 1, num_feat))

    if num_feat == 1:
        # A single feature gets a legend entry instead of an axis label
        axs[0, 0].plot(t, features[:-1, 0], label=_get_obs_label(ro, 0))
        axs[0, 0].legend()
        return

    # Walk through the grid in row-major order, k is the flat index of the current cell
    for k in range(num_rows * num_cols):
        row, col = divmod(k, num_cols)
        if k >= num_feat:
            # The grid can contain more cells than there are features
            axs[row, col].remove()
            continue
        # The last observation is omitted for simplicity
        axs[row, col].plot(t, features[:-1, k], c=colors[k])
        axs[row, col].set_ylabel(rf"$\phi_{{{k}}}$")