Ejemplo n.º 1
0
# Example 1: load the "l&r" baseline runs and plot them against the other
# pusher baselines (final puck / hand distance), then save the figure.
lr_baseline_dirs = [
    vitchyr_base_dir + "papers/nips2018/06-06-lr-baseline-pusher-0.2-range/",
]
lr = plot.load_exps(
    lr_baseline_dirs,
    plot.filter_by_flat_params(dict(rdim='16--lr-1e-3')),
    suppress_output=True,
)
plot.tag_exps(lr, "name", "l&r")

# Experiment groups are concatenated in this order: ours, oracle, HER, L&R,
# DSAE. (An earlier variant put `lr` first and used method_order=[4, 0, 1, 3, 2].)
pusher_baseline_exps = ours + oracle + her + lr + dsae
plot.comparison(
    pusher_baseline_exps,
    ["Final  puck_distance Mean", "Final  hand_distance Mean"],
    vary=["name"],
    smooth=plot.padded_ma_filter(10),
    ylim=(0.1, 0.28),
    xlim=(0, 500000),  # an earlier variant used (0, 250000)
    figsize=(6, 4),
)

plt.gca().xaxis.set_major_formatter(plt.FuncFormatter(format_func))
plt.xlabel("Timesteps")
plt.ylabel("")  # alternative label: "Final Distance to Goal"
plt.title("Visual Pusher Baselines")
# An empty label list means the legend shows no entries. Labels previously
# used: [our_method_name, "DSAE", "HER", "Oracle", "L&R"].
plt.legend([])

plt.tight_layout()
pusher_baselines_pdf = output_dir + "pusher_baselines.pdf"
plt.savefig(pusher_baselines_pdf)
print("File saved to", pusher_baselines_pdf)
Ejemplo n.º 2
0
# Example 2: online vs. offline ablation plot for the visual pusher.
# "Offline" trials are those with rdim 250 whose `should_train_vae`
# schedule function is `never_train`.
offline_pusher_criteria = {
    'rdim': 250,
    'algo_kwargs.should_train_vae.$function':
        'railrl.torch.vae.vae_schedules.never_train',
}
offline_pusher = dp.get_trials(
    pusher_dir,
    criteria=offline_pusher_criteria,
)

# Curve label -> trials; OrderedDict keeps "Online" before "Offline".
pusher_ablation_series = OrderedDict([
    ("Online", online_pusher),
    ("Offline", offline_pusher),
])

plt.figure(figsize=(6, 5))
plot.plot_trials(
    pusher_ablation_series,
    y_keys="Final  sum_distance Mean",
    x_key="Number of env steps total",
    process_time_series=plot.padded_ma_filter(100),
)

plt.gca().xaxis.set_major_formatter(plt.FuncFormatter(format_func))
plt.xlabel("Timesteps")
plt.ylabel("Final Distance to Goal")
plt.title("Visual Pusher, Online Ablation")

# Legend anchored below the axes (upper-center of the legend at
# axes coords (0.49, -0.2)). `lgnd` keeps a handle to the legend object.
lgnd = plt.legend(
    ["Online", "Offline"],
    bbox_to_anchor=(0.49, -0.2),
    loc="upper center",
    ncol=4,
    handlelength=1,
)

plt.tight_layout()
pusher_ablation_pdf = output_dir + "pusher_online_ablation.pdf"
plt.savefig(pusher_ablation_pdf)
print("File saved to", pusher_ablation_pdf)
Ejemplo n.º 3
0
    })
# Example 3: online vs. offline ablation plot for the visual reacher.
# "Offline" trials are those whose `should_train_vae` schedule function is
# `never_train`.
offline_reacher_criteria = {
    'algo_kwargs.should_train_vae.$function':
        'railrl.torch.vae.vae_schedules.never_train',
}
offline_reacher = dp.get_trials(
    reacher_dir,
    criteria=offline_reacher_criteria,
)

# Curve label -> trials; OrderedDict keeps "Online" before "Offline".
reacher_ablation_series = OrderedDict([
    ("Online", online_reacher),
    ("Offline", offline_reacher),
])

plt.figure(figsize=(6, 5))
plot.plot_trials(
    reacher_ablation_series,
    y_keys="Final  distance Mean",
    x_key="Number of env steps total",
    # Smoothing via plot.padded_ma_filter; `avg_only_from_left` presumably
    # limits the window to earlier points -- TODO confirm in `plot`.
    process_time_series=plot.padded_ma_filter(10, avg_only_from_left=True),
)

plt.gca().xaxis.set_major_formatter(plt.FuncFormatter(format_func))
plt.xlabel("Timesteps")
plt.ylabel("Final Distance to Goal")
plt.title("Visual Reacher Online Ablation")

# Legend anchored below the axes (upper-center of the legend at
# axes coords (0.49, -0.2)).
plt.legend(
    ["Online", "Offline"],
    bbox_to_anchor=(0.49, -0.2),
    loc="upper center",
    ncol=4,
    handlelength=1,
)

plt.tight_layout()
reacher_ablation_pdf = output_dir + "reacher_online_ablation.pdf"
plt.savefig(reacher_ablation_pdf)
print("File saved to", reacher_ablation_pdf)
Ejemplo n.º 4
0
#     '/home/vitchyr/git/rlkit/data/doodads3/05-12-sawyer-reach-vae-rl-reproduce-2/',
#     criteria={
#         'replay_kwargs.fraction_resampled_goals_are_env_goals': 0.5,
#         'replay_kwargs.fraction_goals_are_rollout_goals': 0.2,
#     }
# )


# Example 4: plot final distance for the state-based baselines and the VAE
# method, save the figure as reach.jpg, then display it.
y_keys = ['Final  distance Mean']

# Curve label -> trials to plot.
reach_trials_by_label = {
    'State - HER TD3': state_her_td3,
    'State - TDM DDPG': state_tdm_ddpg,
    'VAE - HER TD3': vae_trials,
    # 'VAE - TD3': vae_td3_trials,  # currently excluded
}
plot_trials(
    reach_trials_by_label,
    y_keys=y_keys,
    process_time_series=padded_ma_filter(3),
    # x_key=x_key,  # default x-axis key is used
)

plt.xlabel('Number of Environment Steps Total')
plt.ylabel('Final distance to Goal')
reach_plot_path = (
    '/home/vitchyr/git/railrl/experiments/vitchyr/nips2018/plots/reach.jpg'
)
plt.savefig(reach_plot_path)
plt.show()

# plt.savefig("/home/ashvin/data/s3doodad/media/plots/pusher2d.pdf")