def test_and_plot_policy(policy,
                         env,
                         as_goals=True,
                         visualize=True,
                         sampling_res=1,
                         n_traj=1,
                         max_reward=1,
                         itr=0,
                         report=None,
                         center=None,
                         limit=None,
                         bounds=None):
    """Evaluate ``policy`` on ``env``, plot success/time heatmaps and log results.

    The policy is evaluated through ``test_policy``; the returned per-state
    success values and ``avg_time`` values are each rendered with
    ``plot_heatmap`` over the maze layout identified by ``_maze_id``.

    Args:
        policy: policy to evaluate (forwarded to ``test_policy``).
        env: environment, possibly wrapped; the first object in the
            ``wrapped_env`` chain that has a ``_maze_id`` attribute supplies
            the maze layout for the heatmaps.
        as_goals, visualize, sampling_res, n_traj, bounds: forwarded to
            ``test_policy``.
        max_reward: unused; kept only for backward compatibility.
        itr: iteration number, recorded as ``iter`` under the ``Outer_``
            tabular prefix and shown in the report captions.
        report: optional report object; when given, both images are added
            to it with ``add_image``.
        center, limit: forwarded to ``plot_heatmap``.

    Returns:
        ``(mean_rewards, success)``: the means of the total rewards and of
        the success values returned by ``test_policy``.

    Raises:
        AttributeError: if no object in the ``wrapped_env`` chain has
            ``_maze_id``.
    """
    avg_totRewards, avg_success, states, spacing, avg_time = test_policy(
        policy,
        env,
        as_goals,
        visualize,
        sampling_res=sampling_res,
        n_traj=n_traj,
        bounds=bounds)

    # Walk down the wrapper chain until an object exposes ``_maze_id``.
    # Advance from ``obj`` (not ``env``) so that multi-level wrappers
    # terminate instead of re-visiting the first wrapper forever.
    obj = env
    while not hasattr(obj, '_maze_id') and hasattr(obj, 'wrapped_env'):
        obj = obj.wrapped_env
    maze_id = obj._maze_id

    plot_heatmap(avg_success,
                 states,
                 spacing=spacing,
                 show_heatmap=False,
                 maze_id=maze_id,
                 center=center,
                 limit=limit)
    reward_img = save_image()

    # The time heatmap uses an adaptive color range (``adaptive_range=True``).
    plot_heatmap(avg_time,
                 states,
                 spacing=spacing,
                 show_heatmap=False,
                 maze_id=maze_id,
                 center=center,
                 limit=limit,
                 adaptive_range=True)
    time_img = save_image()

    mean_rewards = np.mean(avg_totRewards)
    success = np.mean(avg_success)

    with logger.tabular_prefix('Outer_'):
        logger.record_tabular('iter', itr)
        logger.record_tabular('MeanRewards', mean_rewards)
        logger.record_tabular('Success', success)
    # NOTE: values are only recorded here; this function never dumps the
    # tabular log.

    if report is not None:
        report.add_image(
            reward_img,
            'policy performance\n itr: {} \nmean_rewards: {} \nsuccess: {}'.
            format(itr, mean_rewards, success))
        report.add_image(time_img, 'policy time\n itr: {} \n'.format(itr))
    return mean_rewards, success
Beispiel #2
0
def plot_generator_samples(generator, env=None, size=100, fname=None):
    """Scatter-plot goals sampled from ``generator`` and save the figure.

    Args:
        generator: object whose ``sample_states(size)`` returns a pair; the
            first element is the array of sampled goals (one row per goal).
        env: optional environment. When given, the plot half-width is the
            maximum of the first half of ``env.goal_bounds``; otherwise it is
            the largest absolute coordinate among the sampled goals.
        size: number of goals to sample.
        fname: forwarded to ``save_image``.

    Returns:
        Whatever ``save_image(fname=fname)`` returns.

    Goals with exactly two columns are drawn in a 2-D scatter on the current
    figure; any other width is drawn on a new 3-D figure using the first
    three columns.
    """
    samples, _ = generator.sample_states(size)

    if env is not None:
        half = len(env.goal_bounds) // 2
        limit = np.max(env.goal_bounds[:half])
    else:
        limit = np.max(np.abs(samples))

    if samples.shape[1] != 2:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(samples[:, 0], samples[:, 1], samples[:, 2], s=10)
        ax.axis('equal')
        for set_lim in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
            set_lim(-limit, limit)
    else:
        plt.scatter(samples[:, 0], samples[:, 1], s=10)
        plt.axis('equal')
        plt.xlim(-limit, limit)
        plt.ylim(-limit, limit)

    return save_image(fname=fname)
Beispiel #3
0
def plot_policy_performance(policy,
                            env,
                            horizon,
                            n_samples=200,
                            n_traj=10,
                            fname=None):
    """Plot the policy's success rate on uniformly sampled goals and save it.

    ``n_samples`` goals are drawn uniformly from ``[-b, b]`` per dimension,
    where ``b`` is the first half of ``env.goal_bounds``. Each goal is scored
    with ``evaluate_states`` (key ``'goal_reached'``, aggregator
    ``(np.max, np.mean)``; presumably max over a trajectory then mean over
    ``n_traj`` trajectories — confirm against ``evaluate_states``), and the
    scores are used as point colors (``plasma`` colormap, with colorbar).

    Args:
        policy: policy forwarded to ``evaluate_states``.
        env: environment providing ``dim`` and ``goal_bounds``.
        horizon: rollout horizon forwarded to ``evaluate_states``.
        n_samples: number of goals to sample.
        n_traj: trajectories per goal, forwarded to ``evaluate_states``.
        fname: forwarded to ``save_image``.

    Returns:
        Whatever ``save_image(fname=fname)`` returns.
    """
    dim = env.dim
    bound = env.goal_bounds[:len(env.goal_bounds) // 2]
    sampled_goals = np.random.uniform(-bound, bound, [n_samples, dim])
    limit = np.max(bound)

    scores = evaluate_states(sampled_goals,
                             env,
                             policy,
                             horizon,
                             n_traj=n_traj,
                             key='goal_reached',
                             aggregator=(np.max, np.mean))

    plt.clf()
    point_style = dict(c=scores, s=10, cmap='plasma')

    if dim != 2:
        # Non-2-D goals: new 3-D figure using the first three coordinates.
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        points = ax.scatter(sampled_goals[:, 0],
                            sampled_goals[:, 1],
                            sampled_goals[:, 2],
                            **point_style)
        fig.colorbar(points)
        ax.axis('equal')
        for set_lim in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
            set_lim(-limit, limit)
    else:
        plt.scatter(sampled_goals[:, 0], sampled_goals[:, 1], **point_style)
        plt.colorbar()
        plt.axis('equal')
        plt.xlim(-limit, limit)
        plt.ylim(-limit, limit)

    return save_image(fname=fname)
def plot_policy_means(policy,
                      env,
                      sampling_res=2,
                      report=None,
                      center=None,
                      limit=None):  # only for start envs!
    """Draw the policy's mean actions as a quiver field over the empty states.

    Only meant for start-state environments. For every state returned by
    ``find_empty_spaces`` an observation ``[state, zero padding, goal]`` is
    built (padded to ``env.observation_space.flat_dim``) and passed to
    ``policy.get_actions``. The mean action at each state is drawn as an
    arrow, and ``exp(log_std) * 0.25`` (a scaled standard deviation, not a
    variance) as a translucent ellipse centered on the state. The current
    goal is marked with a red dot.

    Args:
        policy: object with ``get_actions(observations)`` returning
            ``(actions, agent_infos)`` where ``agent_infos`` has ``'mean'``
            and ``'log_std'`` entries.
        env: environment providing ``current_goal`` and
            ``observation_space.flat_dim``.
        sampling_res: forwarded to ``find_empty_spaces``.
        report: optional report; when given, the image is added to it with
            the caption ``'policy mean'``.
        center, limit: accepted for interface compatibility; unused.

    Returns:
        The image returned by ``save_image()``.
    """
    states, _empty_spaces, _spacing = find_empty_spaces(
        env, sampling_res=sampling_res)
    # Column indexing below (``states[:, 0]``) requires an array.
    states = np.asarray(states)
    goal = env.current_goal
    flat_dim = env.observation_space.flat_dim

    observations = [
        np.concatenate([state, [0] * (flat_dim - len(state) - len(goal)), goal])
        for state in states
    ]
    _, agent_infos = policy.get_actions(observations)
    means = np.asarray(agent_infos['mean'])
    stds = [np.exp(log_std) * 0.25 for log_std in agent_infos['log_std']]
    ellipses = [
        patches.Ellipse(state, width=std[0], height=std[1], angle=0)
        for state, std in zip(states, stds)
    ]

    fig = plt.figure()
    ax = fig.add_subplot(111)
    for ellipse in ellipses:
        ax.add_artist(ellipse)
        ellipse.set_alpha(0.2)
    plt.scatter(*goal, color='r', s=100)
    # Arrows are drawn in data units so their length equals the mean action.
    quiver = plt.quiver(states[:, 0],
                        states[:, 1],
                        means[:, 0],
                        means[:, 1],
                        units='xy',
                        angles='xy',
                        scale_units='xy',
                        scale=1)
    plt.quiverkey(quiver,
                  0.8,
                  0.85,
                  1,
                  r'1 Nkg',
                  labelpos='E',
                  coordinates='figure')
    vec_img = save_image()
    if report is not None:
        report.add_image(vec_img, 'policy mean')
    return vec_img
Beispiel #5
0
    def plot_regions_states(self, maze_id=0, report=None):
        """Draw every region shaded by how many states it contains.

        Each region is drawn as a black-edged rectangle from ``min_border``
        to ``max_border``. It is filled with a BuGn color normalized between
        the smallest and largest ``len(region.states)``, and a matching
        colorbar is attached. Axis limits come from ``self.state_bounds``:
        index 0 gives the lower x/y limits and index 1 the upper ones. For
        ``maze_id`` 0 or 11 a fixed set of translucent grey wall rectangles
        is overlaid; any other id draws no walls.

        Args:
            maze_id: which hard-coded wall layout to draw (0 or 11).
            report: optional report; when given, the image is added to it
                with a caption stating the total number of states, and
                nothing is returned.

        Returns:
            The image from ``save_image(fig)`` when ``report`` is None,
            otherwise None.
        """
        fig, ax = plt.subplots()

        states_per_reg = [len(region.states) for region in self.regions]
        states_per_reg_lims = (min(states_per_reg), max(states_per_reg))
        normal = pylab.Normalize(*states_per_reg_lims)
        colors = pylab.cm.BuGn(normal(states_per_reg))

        for region, color in zip(self.regions, colors):
            lengths = region.max_border - region.min_border
            ax.add_patch(
                patches.Rectangle(region.min_border,
                                  *lengths,
                                  fill=True,
                                  edgecolor='k',
                                  facecolor=color))

        cax, _ = cbar.make_axes(ax)
        print("the states-per-region lims are: ", states_per_reg_lims)
        cbar.ColorbarBase(cax, cmap=pylab.cm.BuGn, norm=normal)
        ax.set_xlim(self.state_bounds[0][0], self.state_bounds[1][0])
        ax.set_ylim(self.state_bounds[0][1], self.state_bounds[1][1])

        # Maze walls as (x, y, width, height), drawn in this order.
        if maze_id == 0:
            walls = [
                (-3, -3, 10, 2),
                (-3, -1, 2, 6),
                (-3, 5, 10, 2),
                (5, -1, 2, 6),
                (-1, 1, 4, 2),
            ]
        elif maze_id == 11:
            walls = [
                (-7, 5, 14, 2),
                (5, -5, 2, 10),
                (-7, -7, 14, 2),
                (-7, -5, 2, 10),
                (-1, 1, 6, 2),
                (-3, -3, 2, 6),
                (-1, -3, 4, 2),
            ]
        else:
            walls = []
        for x, y, width, height in walls:
            ax.add_patch(
                patches.Rectangle((x, y),
                                  width,
                                  height,
                                  fill=True,
                                  edgecolor="none",
                                  facecolor='0.4',
                                  alpha=0.3))

        regions_fig = save_image(fig)

        if report is None:
            return regions_fig
        report.add_image(
            regions_fig,
            'States per region\nTotal number of states: {}'.format(
                sum(states_per_reg)))
Beispiel #6
0
    def plot_regions_interest(self, maze_id=0, report=None):
        """Render all regions shaded by their interest value and save the figure.

        Interest values come from ``self.compute_all_interests()``. They are
        mapped through a YlOrRd colormap normalized to their min/max, and a
        matching colorbar is attached. Axis limits come from
        ``self.state_bounds``: index 0 gives the lower x/y limits and index 1
        the upper ones. For ``maze_id`` 0 or 11 a fixed set of translucent
        grey wall rectangles is overlaid; any other id draws no walls.

        Args:
            maze_id: which hard-coded wall layout to draw (0 or 11).
            report: optional report; when given, the image is added to it
                with a caption stating the number of regions, and nothing is
                returned.

        Returns:
            The image from ``save_image(fig)`` when ``report`` is None,
            otherwise None.
        """
        fig, ax = plt.subplots()

        interests = self.compute_all_interests()
        lims = (min(interests), max(interests))
        norm = pylab.Normalize(*lims)
        shades = pylab.cm.YlOrRd(norm(interests))

        for reg, shade in zip(self.regions, shades):
            extent = reg.max_border - reg.min_border
            ax.add_patch(
                patches.Rectangle(reg.min_border,
                                  *extent,
                                  fill=True,
                                  edgecolor='k',
                                  facecolor=shade))

        cax, _ = cbar.make_axes(ax)
        print("the interest lims are: ", lims)
        cbar.ColorbarBase(cax, cmap=pylab.cm.YlOrRd, norm=norm)
        lower, upper = self.state_bounds[0], self.state_bounds[1]
        ax.set_xlim(lower[0], upper[0])
        ax.set_ylim(lower[1], upper[1])

        # Each wall: (lower-left corner, width, height), drawn in this order.
        if maze_id == 0:
            wall_specs = (
                ((-3, -3), 10, 2),
                ((-3, -1), 2, 6),
                ((-3, 5), 10, 2),
                ((5, -1), 2, 6),
                ((-1, 1), 4, 2),
            )
        elif maze_id == 11:
            wall_specs = (
                ((-7, 5), 14, 2),
                ((5, -5), 2, 10),
                ((-7, -7), 14, 2),
                ((-7, -5), 2, 10),
                ((-1, 1), 6, 2),
                ((-3, -3), 2, 6),
                ((-1, -3), 4, 2),
            )
        else:
            wall_specs = ()

        for corner, w, h in wall_specs:
            ax.add_patch(
                patches.Rectangle(corner,
                                  w,
                                  h,
                                  fill=True,
                                  edgecolor="none",
                                  facecolor='0.4',
                                  alpha=0.3))

        image = save_image(fig)

        if report is not None:
            report.add_image(
                image,
                'Interest per region:\nthe number of regions is: {}'.format(
                    len(self.regions)))
            return None
        return image