def plot_features(ro: StepSequence, policy: Policy):
    """
    Plot all features given the policy and the observation trajectories.

    One subplot is created per feature. If the policy is not a `LinearPolicy`, a message is printed and nothing is
    plotted. If the rollout has no `observations` attribute, the function does nothing.

    :param ro: input rollout
    :param policy: linear policy used during the rollout
    """
    if not isinstance(policy, LinearPolicy):
        print_cbt('Plotting of the feature values is only supports linear policies!', 'y')
        return

    if hasattr(ro, 'observations'):
        # Use recorded time stamps if possible
        t = ro.env_infos.get('t', np.arange(0, ro.length)) if hasattr(ro, 'env_infos') else np.arange(0, ro.length)

        # Recover the features from the observations
        feat_vals = policy.eval_feats(to.from_numpy(ro.observations))
        num_feat = feat_vals.shape[1]

        # Choose the grid layout depending on the number of features
        if num_feat <= 6:
            divisor = 2
        elif num_feat <= 12:
            divisor = 4
        else:
            divisor = 8
        num_cols = int(np.ceil(num_feat / divisor))
        num_rows = int(np.ceil(num_feat / num_cols))

        # squeeze=False always yields a 2-dim array of axes, so axs[i, j] also works for grids that consist of a
        # single row or a single column (e.g. two features result in a 2x1 grid)
        fig, axs = plt.subplots(
            num_rows, num_cols, figsize=(num_cols * 5, num_rows * 3), constrained_layout=True, squeeze=False
        )
        fig.suptitle('Feature values over Time')
        # NOTE(review): matplotlib documents subplots_adjust as incompatible with constrained_layout; confirm which
        # of the two spacing mechanisms is intended here
        plt.subplots_adjust(hspace=.5)
        colors = plt.get_cmap('tab20')(np.linspace(0, 1, num_feat))

        if num_feat == 1:
            # Omit the last observation for simplicity
            axs[0, 0].plot(t, feat_vals[:-1, 0], label=_get_obs_label(ro, 0))
            axs[0, 0].legend()
        else:
            for i in range(num_rows):
                for j in range(num_cols):
                    feat_idx = j + i * num_cols
                    if feat_idx < num_feat:
                        # Omit the last observation for simplicity. The subscript is wrapped in braces such that
                        # indices with more than one digit are rendered completely as subscript.
                        axs[i, j].plot(
                            t, feat_vals[:-1, feat_idx], label=rf'$\phi_{{{feat_idx}}}$', c=colors[feat_idx]
                        )
                        axs[i, j].legend()
                    else:
                        # We might create more subplots than there are features, drop the unused ones
                        axs[i, j].remove()
        plt.show()
def plot_features(ro: StepSequence, policy: Policy):
    """
    Draw every feature of a linear policy over time, using one subplot per feature.

    The features are evaluated on the rollout's observations, and the last entry is dropped before plotting.
    The figure is only created here; this function does not call `plt.show()`.

    :param ro: rollout providing the observations (and the time stamps, if it has a `time` attribute)
    :param policy: policy whose features are evaluated, must be a `LinearPolicy`
    """
    if not isinstance(policy, LinearPolicy):
        print_cbt("Plotting of the feature values is only supports linear policies!", "r")
        return
    if not hasattr(ro, "observations"):
        # Nothing to plot without recorded observations
        return

    # Prefer the rollout's own time stamps and fall back to the step indices; the final entry is cut off either way
    step_indices = np.arange(0, ro.length + 1)
    t = getattr(ro, "time", step_indices)[:-1]

    # Evaluate the policy's features on the whole observation trajectory
    features = policy.eval_feats(to.from_numpy(ro.observations))
    num_feat = features.shape[1]

    # The more features there are, the more rows the grid gets
    if num_feat > 12:
        divisor = 8
    elif num_feat > 6:
        divisor = 4
    else:
        divisor = 2
    num_cols = int(np.ceil(num_feat / divisor))
    num_rows = int(np.ceil(num_feat / num_cols))

    fig_size = (num_cols * 5, num_rows * 3)
    fig, axs = plt.subplots(num_rows, num_cols, figsize=fig_size, tight_layout=True)
    # Make the axes indexable by (row, column), independent of the grid shape
    axs = correct_atleast_2d(np.atleast_2d(axs))
    fig.canvas.manager.set_window_title("Feature Values over Time")
    plt.subplots_adjust(hspace=0.5)
    palette = plt.get_cmap("tab20")(np.linspace(0, 1, num_feat))

    if num_feat == 1:
        only_ax = axs[0, 0]
        only_ax.plot(t, features[:-1, 0], label=_get_obs_label(ro, 0))
        only_ax.legend()
    else:
        # Walk through the grid row by row; the flat cell index doubles as the feature index
        for cell_idx in range(num_rows * num_cols):
            ax = axs[divmod(cell_idx, num_cols)]
            if cell_idx >= num_feat:
                # The grid can hold more cells than there are features
                ax.remove()
                continue
            # The last feature value is left out
            ax.plot(t, features[:-1, cell_idx], c=palette[cell_idx])
            ax.set_ylabel(rf"$\phi_{{{cell_idx}}}$")