示例#1
0
def plot_vfunction_2D(vfunction, env, plot=True, figname="vfunction.pdf", foldername='/plots/', save_figure=True, definition=50) -> None:
    """
    Draw a state-value function over a 2-dimensional observation space as a heat map.

    The bounds of the observation space are sampled on a regular
    ``definition`` x ``definition`` grid and ``vfunction.evaluate`` is called
    once per grid point with a ``(1, 2)`` array ``[[x, y]]``.

    :param vfunction: the value function to plot; must expose ``evaluate(state)``
    :param env: the environment; its ``observation_space`` supplies ``shape``, ``low`` and ``high``
    :param plot: whether the plot should be interactive (forwarded to ``final_show``)
    :param figname: the name of the output file (forwarded to ``final_show``)
    :param foldername: the folder receiving the output file (forwarded to ``final_show``)
    :param save_figure: whether the plot should be saved into a file (forwarded to ``final_show``)
    :param definition: the number of samples along each axis
    :return: nothing
    :raises ValueError: if the observation space is not 2-dimensional
    """
    obs_dim = env.observation_space.shape[0]
    if obs_dim != 2:
        raise ValueError("Observation space dimension {}, should be 2".format(obs_dim))

    heat_map = np.zeros((definition, definition))
    x_min, y_min = env.observation_space.low
    x_max, y_max = env.observation_space.high

    # Both axes are sampled once; the y samples do not depend on x
    x_samples = np.linspace(x_min, x_max, num=definition)
    y_samples = np.linspace(y_min, y_max, num=definition)

    for column, x in enumerate(x_samples):
        for rank_y, y in enumerate(y_samples):
            # imshow draws row 0 at the top, so the largest y goes in the first row
            row = definition - (1 + rank_y)
            heat_map[row, column] = vfunction.evaluate(np.array([[x, y]]))

    plt.figure(figsize=(10, 10))
    plt.imshow(heat_map, cmap="inferno", extent=[x_min, x_max, y_min, y_max], aspect='auto')
    plt.colorbar(label="critic value")
    # Mark the origin of the state space
    plt.scatter([0], [0])
    # Use the dimension names if the space provides them, otherwise "x" and "y"
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, figname, x_label, y_label, "V Function", foldername)
示例#2
0
def plot_qfunction_cont_act(qfunction, action, env, plot=True, figname="qfunction_cont.pdf", foldername='/plots/', save_figure=True, definition=50) -> None:
    """
    Plot a q function using the same action everywhere in the state space

    The bounds of the 2-dimensional observation space are sampled on a regular
    ``definition`` x ``definition`` grid and ``qfunction.evaluate(state, action)``
    is called once per grid point with ``state`` a 1-D array ``[x, y]``.

    :param qfunction: the action value function to be plotted; must expose ``evaluate(state, action)``
    :param action: the action to be plotted
    :param env: the environment; its ``observation_space`` supplies ``shape``, ``low`` and ``high``
    :param plot: whether the plot should be interactive (forwarded to ``final_show``)
    :param figname: the name of the file where to plot the function (forwarded to ``final_show``)
    :param foldername: the name of the folder where to put the file (forwarded to ``final_show``)
    :param save_figure: whether the plot should be saved into a file (forwarded to ``final_show``)
    :param definition: the resolution of the plot (number of samples along each axis)
    :return: nothing
    :raises ValueError: if the observation space is not exactly 2-dimensional
    """
    obs_dim = env.observation_space.shape[0]
    # The bounds are unpacked into exactly two values below and the state is
    # built from (x, y) only, so any dimension other than 2 must be rejected
    # here with an explicit message rather than failing later on unpacking.
    if obs_dim != 2:
        raise ValueError("The observation space dimension is {}, whereas it should be 2".format(obs_dim))

    portrait = np.zeros((definition, definition))
    x_min, y_min = env.observation_space.low
    x_max, y_max = env.observation_space.high

    # Both axes are sampled once; the y samples do not depend on x
    x_values = np.linspace(x_min, x_max, num=definition)
    y_values = np.linspace(y_min, y_max, num=definition)

    for index_x, x in enumerate(x_values):
        for index_y, y in enumerate(y_values):
            state = np.array([x, y])
            # imshow draws row 0 at the top, so the largest y goes in the first row
            portrait[definition - (1 + index_y), index_x] = qfunction.evaluate(state, action)

    plt.figure(figsize=(10, 10))
    plt.imshow(portrait, cmap="inferno", extent=[x_min, x_max, y_min, y_max], aspect='auto')
    plt.colorbar(label="critic value")
    # Add a point at the center
    plt.scatter([0], [0])
    # Use the dimension names if given, otherwise default to "x" and "y"
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, figname, x_label, y_label, "Q Function or current policy", foldername)
示例#3
0
def plot_qfunction_ND(qfunction,
                      policy,
                      env,
                      plot=True,
                      figname="qfunction_cont.pdf",
                      foldername='/plots/',
                      save_figure=True,
                      definition=50) -> None:
    """
    Plot a q function in a N-dimensional state space using a given policy to choose an action everywhere in the state space
    The N-dimensional state space is projected into its first two dimensions.
    A FeatureInverter wrapper should be used to select which features to put first so as to plot them

    For every grid point the first two state components are the grid
    coordinates; each remaining component is drawn uniformly in [-0.5, 0.5)
    with ``random.random() - 0.5``, so the picture is not deterministic unless
    the ``random`` module is seeded by the caller.

    :param qfunction: the action value function to be plotted; must expose ``evaluate(state, action)``
    :param policy: the policy specifying the action to be plotted; must expose ``select_action(state)``
    :param env: the environment; its ``observation_space`` supplies ``shape``, ``low`` and ``high``
    :param plot: whether the plot should be interactive (forwarded to ``final_show``)
    :param figname: the name of the file where to plot the function (forwarded to ``final_show``)
    :param foldername: the name of the folder where to put the file (forwarded to ``final_show``)
    :param save_figure: whether the plot should be saved into a file (forwarded to ``final_show``)
    :param definition: the resolution of the plot (number of samples along each axis)
    :return: nothing
    :raises ValueError: if the observation space has 2 dimensions or fewer
    """
    if env.observation_space.shape[0] <= 2:
        raise ValueError(
            "Observation space dimension {}, should be > 2".format(
                env.observation_space.shape[0]))

    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high
    nb_dims = len(state_min)

    # Both plotted axes are sampled once; the y samples do not depend on x
    x_values = np.linspace(state_min[0], state_max[0], num=definition)
    y_values = np.linspace(state_min[1], state_max[1], num=definition)

    for index_x, x in enumerate(x_values):
        for index_y, y in enumerate(y_values):
            # Fill the dimensions that are not plotted with noise in [-0.5, 0.5)
            padding = [random.random() - 0.5 for _ in range(2, nb_dims)]
            # np.append flattens its inputs: the state is a 1-D array of nb_dims values
            state = np.append(np.array([[x, y]]), padding)
            action = policy.select_action(state)
            # imshow draws row 0 at the top, so the largest y goes in the first row
            portrait[definition - (1 + index_y),
                     index_x] = qfunction.evaluate(state, action)

    plt.figure(figsize=(10, 10))
    plt.imshow(portrait,
               cmap="inferno",
               extent=[state_min[0], state_max[0], state_min[1], state_max[1]],
               aspect='auto')
    plt.colorbar(label="critic value")
    # Add a point at the center
    plt.scatter([0], [0])
    # Only the first two dimensions are plotted, so only the first two names
    # are used: a names list with one entry per dimension must not break the
    # two-value unpacking.
    names = getattr(env.observation_space, "names", ["x", "y"])
    x_label, y_label = names[:2]
    final_show(save_figure, plot, figname, x_label, y_label,
               "Q Function or current policy", foldername)
示例#4
0
def plot_weight_histograms(policy, nb, env_name) -> None:
    """
    Draw a histogram of the values returned by ``get_weight_sample`` for a policy.

    The bins are 0.0005 wide and span the range of the sampled values. The
    figure is handed to ``final_show`` under the name ``dispersion_<nb>.pdf``
    with the folder ``/results/``.

    :param policy: the policy network
    :param nb: a number to allow several such plots through repeated epochs
    :param env_name: the name of the environment
    :return: nothing
    """
    samples = np.array(get_weight_sample(policy, env_name))
    plt.figure(1, figsize=(13, 10))

    step = 0.0005
    lowest = samples.min()
    highest = samples.max()
    # Go one step past the maximum so the largest sample falls inside a bin
    edges = np.arange(lowest, highest + step, step)
    plt.hist(samples, bins=edges)

    figname = 'dispersion_' + str(nb) + '.pdf'
    final_show(True, False, figname, "decision threshold", "count", "decision dispersion", '/results/')
示例#5
0
def plot_normal_histograms(policy, nb, env_name) -> None:
    """
    Draw histograms of the means and standard deviations returned by ``get_normal_sample``.

    Two histograms with 0.0005-wide bins are produced one after the other and
    each is handed to ``final_show`` with the folder ``/results/``: first the
    means as ``dispersion_mu_<nb>.pdf``, then the standard deviations as
    ``dispersion_std_<nb>.pdf``.

    :param policy: the policy network
    :param nb: a number to allow several such plots through repeated epochs
    :param env_name: the name of the environment
    :return: nothing
    """
    sampled_mus, sampled_stds = get_normal_sample(policy, env_name)
    mu_values = np.array(sampled_mus)
    std_values = np.array(sampled_stds)
    plt.figure(1, figsize=(13, 10))

    step = 0.0005

    # Histogram of the means; the extra step keeps the largest sample inside a bin
    mu_edges = np.arange(mu_values.min(), mu_values.max() + step, step)
    std_edges = np.arange(std_values.min(), std_values.max() + step, step)
    plt.hist(mu_values, bins=mu_edges)
    mu_figname = 'dispersion_mu_' + str(nb) + '.pdf'
    final_show(True, False, mu_figname, "mu", "count", "dispersion mu", '/results/')

    # Histogram of the standard deviations
    plt.hist(std_values, bins=std_edges)
    std_figname = 'dispersion_std_' + str(nb) + '.pdf'
    final_show(True, False, std_figname, "variance", "count", "dispersion variance", '/results/')
示例#6
0
def plot_trajectory(batch, env, nb, save_figure=True) -> None:
    """
    Plot the set of trajectories stored into a batch

    Each episode of ``batch.episodes`` is converted to x and y coordinates with
    ``episode_to_traj`` and drawn as a scatter plot whose colour encodes the
    step index (1 to ``len(episode.state_pool)``). The figure is handed to
    ``final_show`` under the name ``trajectory_<nb>.pdf`` with the folder
    ``/plots/``.

    :param batch: the source batch
    :param env: the environment where the batch was built
    :param nb: a number, to save several similar plots
    :param save_figure: whether the plot should be saved
    :return: nothing
    :raises ValueError: if the observation space has fewer than 2 dimensions
    """
    obs_dim = env.observation_space.shape[0]
    if obs_dim < 2:
        raise ValueError(
            "Observation space of dimension {}, should be at least 2".format(
                obs_dim))

    # Use the dimension names if given otherwise default to "x" and "y".
    # Spaces with more than two dimensions are accepted, so only the first two
    # names are kept: a names list with one entry per dimension must not break
    # the two-value unpacking.
    names = getattr(env.observation_space, "names", ["x", "y"])
    x_label, y_label = names[:2]

    for episode in batch.episodes:
        x, y = episode_to_traj(episode)
        # Colour each point by its step index within the episode
        plt.scatter(x, y, c=range(1, len(episode.state_pool) + 1), s=3)
    figname = 'trajectory_' + str(nb) + '.pdf'
    final_show(save_figure, False, figname, x_label, y_label, "Trajectory",
               '/plots/')