def visualize_policy_error(qf, env, args):
    """Plot a model-predicted rollout against the real environment rollout.

    A ``NumpyModelExtractor`` is built from ``qf`` (using ``args.cheat`` and
    ``num_steps_left=args.tau``).  For ``args.H`` steps an action is drawn
    from a ``RandomPolicy``; the model is advanced from its own previous
    prediction while ``env`` is advanced with the same action.  Both
    trajectories are drawn on a single figure, one colour per observation
    dimension (dashed line = predicted, solid line = actual), and the
    figure is shown with ``plt.show()``.

    NOTE(review): another function with this exact name is defined later in
    the file and would replace this one at import time -- confirm which
    definition is meant to be live.

    :param qf: object handed to ``NumpyModelExtractor``.
    :param env: environment exposing ``reset``, ``step``, ``action_space``
        and ``observation_space``.
    :param args: object with the attributes ``cheat``, ``tau`` and ``H``.
    """
    dynamics = NumpyModelExtractor(qf, args.cheat, num_steps_left=args.tau)
    random_policy = RandomPolicy(env.action_space)
    true_obs = env.reset()
    # Both rollouts start from the same initial observation.
    model_obs = true_obs

    model_trace, true_trace = [], []
    for _step in range(args.H):
        # Record the states *before* stepping so index t holds time t.
        model_trace.append(model_obs.copy())
        true_trace.append(true_obs.copy())

        # The action is chosen from the real observation and applied to both
        # the model and the environment.
        action, _unused = random_policy.get_action(true_obs)
        model_obs = dynamics.next_state(model_obs, action)
        true_obs = env.step(action)[0]

    model_trace = np.array(model_trace)
    true_trace = np.array(true_trace)
    steps = np.arange(args.H)

    n_dims = env.observation_space.low.size
    # Map each dimension index onto a distinct hue.
    color_map = cm.ScalarMappable(
        norm=colors.Normalize(vmin=0, vmax=n_dims),
        cmap=cm.hsv,
    )
    # (trajectory, line style, label template), drawn in this order per dim.
    curves = (
        (model_trace, '--', 'Predicted, Dim {}'),
        (true_trace, '-', 'Actual, Dim {}'),
    )
    for dim in range(n_dims):
        rgba = color_map.to_rgba(dim)
        for trace, line_style, label_template in curves:
            plt.plot(
                steps,
                trace[:, dim],
                line_style,
                label=label_template.format(dim),
                color=rgba,
            )
    plt.xlabel("Time Steps")
    plt.ylabel("Observation Value")
    plt.legend(loc='best')
    plt.show()
def visualize_policy_error(model, env, horizon):
    """Plot a model-predicted rollout against the real environment rollout.

    For ``horizon`` steps an action is drawn from a ``RandomPolicy``.  The
    value returned by ``get_np_prediction(model, predicted_state, action)``
    is added to the running predicted state, while ``env`` is advanced with
    the same action.  Two figures are then shown:

    1. All dimensions together: predicted (dashed) vs. actual (solid)
       values on top, per-dimension absolute error below.
    2. One subplot per dimension with predicted/actual values on the left
       axis and the absolute error (red dots) on a twin right axis.

    :param model: object handed to ``get_np_prediction``.
    :param env: environment exposing ``reset``, ``step``, ``action_space``
        and ``observation_space``.
    :param horizon: number of time steps to roll out.
    """
    policy = RandomPolicy(env.action_space)
    actual_state = env.reset()

    predicted_states = []
    actual_states = []

    # Start from a private copy: the predicted state is updated in place
    # below, and without the copy that update would also modify the array
    # returned by ``env.reset()`` (which the environment may still own).
    predicted_state = actual_state.copy()
    for _ in range(horizon):
        # Record the states *before* stepping so index t holds time t.
        predicted_states.append(predicted_state.copy())
        actual_states.append(actual_state.copy())

        action, _unused = policy.get_action(actual_state)
        delta = get_np_prediction(model, predicted_state, action)
        predicted_state += delta
        actual_state = env.step(action)[0]

    predicted_states = np.array(predicted_states)
    actual_states = np.array(actual_states)
    times = np.arange(horizon)

    num_state_dims = env.observation_space.low.size
    dims = list(range(num_state_dims))
    # Map each dimension index onto a distinct hue.
    norm = colors.Normalize(vmin=0, vmax=num_state_dims)
    mapper = cm.ScalarMappable(norm=norm, cmap=cm.hsv)

    # Plot the predicted and actual values
    plt.subplot(2, 1, 1)
    for dim in dims:
        plt.plot(
            times,
            predicted_states[:, dim],
            '--',
            label='Predicted, Dim {}'.format(dim),
            color=mapper.to_rgba(dim),
        )
        plt.plot(
            times,
            actual_states[:, dim],
            '-',
            label='Actual, Dim {}'.format(dim),
            color=mapper.to_rgba(dim),
        )
    plt.xlabel("Time Steps")
    plt.ylabel("Observation Value")
    plt.legend(loc='best')

    # Plot the predicted and actual value errors
    plt.subplot(2, 1, 2)
    for dim in dims:
        plt.plot(
            times,
            np.abs(predicted_states[:, dim] - actual_states[:, dim]),
            '-',
            label='Dim {}'.format(dim),
            color=mapper.to_rgba(dim),
        )
    plt.xlabel("Time Steps")
    plt.ylabel("|Predicted - Actual| - Absolute Error")
    plt.legend(loc='best')
    plt.show()

    # Second figure: a grid with at most 5 rows, one cell per dimension.
    # Clamp to at least one row so an empty observation space does not cause
    # a division by zero when computing the column count.
    nrows = max(1, min(5, num_state_dims))
    ncols = math.ceil(num_state_dims / nrows)
    fig = plt.figure()
    for dim in dims:
        ax = fig.add_subplot(nrows, ncols, dim+1)
        ax.plot(
            times,
            predicted_states[:, dim],
            '--',
            label='Predicted, Dim {}'.format(dim),
        )
        ax.plot(
            times,
            actual_states[:, dim],
            '-',
            label='Actual, Dim {}'.format(dim),
        )
        ax.set_ylabel("Observation Value")
        ax.set_xlabel("Time Steps")
        ax.set_title("Dim {}".format(dim))
        # The error gets its own y-axis so its scale does not squash the
        # value curves.
        ax_error = ax.twinx()
        ax_error.plot(
            times,
            np.abs(predicted_states[:, dim] - actual_states[:, dim]),
            '.',
            label='Error, Dim {}'.format(dim),
            color='r',
        )
        ax_error.set_ylabel("Error", color='r')
        ax_error.tick_params('y', colors='r')
        ax.legend(loc='best')
    plt.show()