コード例 #1
0
def plot_session(session_spec, info_space, session_data):
    '''Plot the session graph for every (a, e, b) entry of session_data.

    Visually there are two panes sharing the x-axis: reward on top, and loss
    below with explore_var drawn on a secondary y-axis overlaid on it. Each
    aeb gets its own palette color, and all of its traces share one
    legendgroup (only the reward trace appears in the legend).

    NaNs in each aeb_df are replaced with 0 in place, so session_data is
    mutated. session_data must be non-empty: the axis and legend layout is
    copied from the last aeb's per-aeb figures. Returns the combined figure
    after passing it to viz.plot.
    '''
    x_col = session_spec['meta'].get('graph_x', 'epi')
    colors = viz.get_palette(len(session_data))
    # 3 subplot rows are created, but row 3's y-axis is later overlaid on
    # row 2, leaving two visible panes.
    fig = viz.tools.make_subplots(rows=3, cols=1, shared_xaxes=True)
    for idx, (a, e, b) in enumerate(session_data):
        group = f'{a}{e}{b}'
        color = colors[idx]
        df = session_data[(a, e, b)]
        df.fillna(0, inplace=True)  # the saved plot cannot contain NaN

        reward_fig = viz.plot_line(
            df, 'reward', graph_x_or(x_col),
            legend_name=group,
            draw=False,
            trace_kwargs={'legendgroup': group, 'line': {'color': color}},
        ) if False else viz.plot_line(
            df, 'reward', x_col,
            legend_name=group,
            draw=False,
            trace_kwargs={'legendgroup': group, 'line': {'color': color}},
        )
        fig.append_trace(reward_fig.data[0], 1, 1)

        loss_fig = viz.plot_line(
            df, ['loss'], x_col,
            y2_col=['explore_var'],
            trace_kwargs={
                'legendgroup': group,
                'showlegend': False,
                'line': {'color': color},
            },
            draw=False,
        )
        fig.append_trace(loss_fig.data[0], 2, 1)  # loss
        fig.append_trace(loss_fig.data[1], 3, 1)  # explore_var

    layout = fig.layout
    layout['xaxis1'].update(title=x_col, zerolinewidth=1)
    # Top pane: reward axis styling from the per-aeb reward figure.
    layout['yaxis1'].update(reward_fig.layout['yaxis'])
    layout['yaxis1'].update(domain=[0.55, 1])
    # Bottom pane: loss on yaxis2, explore_var on yaxis3 overlaid on it.
    layout['yaxis2'].update(loss_fig.layout['yaxis'])
    layout['yaxis2'].update(showgrid=False, domain=[0, 0.45])
    layout['yaxis3'].update(loss_fig.layout['yaxis2'])
    layout['yaxis3'].update(overlaying='y2', anchor='x2')
    layout.update(ps.pick(reward_fig.layout, ['legend']))
    layout.update(title=f'session graph: {session_spec["name"]} t{info_space.get("trial")} s{info_space.get("session")}', width=500, height=600)
    viz.plot(fig)
    return fig
コード例 #2
0
ファイル: analysis.py プロジェクト: xenakas/SLM-Lab
def plot_trial(trial_spec, info_space):
    '''Plot the trial graph: a single pane that, for each aeb, shows the reward
    figure built by build_aeb_reward_fig (mean and error envelope of the
    reward curves across all sessions).

    The aeb keys and the palette size come from session_datas[0]; each aeb
    gets its own color. All per-aeb figures are merged into the first one,
    which is titled, drawn via viz.plot and returned.
    '''
    predir = util.prepath_to_predir(util.get_prepath(trial_spec, info_space))
    session_datas = session_datas_from_file(
        predir, trial_spec, info_space.get('trial'))

    first_session_data = session_datas[0]
    colors = viz.get_palette(len(first_session_data))
    trial_fig = None
    for idx, (a, e, b) in enumerate(first_session_data):
        aeb = (a, e, b)
        rewards_df = gather_aeb_rewards_df(aeb, session_datas)
        aeb_fig = build_aeb_reward_fig(rewards_df, f'{a}{e}{b}', colors[idx])
        if trial_fig is None:
            # The first aeb figure becomes the container for all traces.
            trial_fig = aeb_fig
            continue
        # NOTE(review): relies on fig.data being an extendable list (older
        # plotly); newer plotly would need add_traces — confirm version.
        trial_fig.data.extend(aeb_fig.data)

    trial_fig.layout.update(
        title=f'trial graph: {trial_spec["name"]} t{info_space.get("trial")}',
        width=500,
        height=600,
    )
    viz.plot(trial_fig)
    return trial_fig
コード例 #3
0
ファイル: analysis.py プロジェクト: raghu1121/SLM-Lab
def plot_trial(trial_spec, info_space):
    '''Plot the trial graph: a single pane that, for each aeb, shows the reward
    figure built by build_aeb_reward_fig (mean and error envelope of the
    reward curves across all sessions).

    The aeb keys and palette size come from the session data stored under the
    first key of session_datas; each aeb gets its own color. The env's
    max_tick_unit (from trial_spec 'env.0.max_tick_unit') is forwarded to the
    gather/build helpers. All per-aeb figures are merged into the first one,
    which is titled with the session count, drawn via viz.plot and returned.
    '''
    prepath = util.get_prepath(trial_spec, info_space)
    predir, _, _, _, _, _ = util.prepath_split(prepath)
    session_datas = session_datas_from_file(
        predir, trial_spec, info_space.get('trial'))

    # Despite being a "sample", this is deterministically the first key's data.
    first_key = list(session_datas.keys())[0]
    sample_session_data = session_datas[first_key]
    max_tick_unit = ps.get(trial_spec, 'env.0.max_tick_unit')
    colors = viz.get_palette(len(sample_session_data))

    trial_fig = None
    for idx, (a, e, b) in enumerate(sample_session_data):
        aeb = (a, e, b)
        rewards_df = gather_aeb_rewards_df(aeb, session_datas, max_tick_unit)
        aeb_fig = build_aeb_reward_fig(
            rewards_df, f'{a}{e}{b}', colors[idx], max_tick_unit)
        if trial_fig is None:
            # The first aeb figure becomes the container for all traces.
            trial_fig = aeb_fig
            continue
        trial_fig.add_traces(aeb_fig.data)

    trial_fig.layout.update(
        title=f'trial graph: {trial_spec["name"]} t{info_space.get("trial")}, {len(session_datas)} sessions',
        width=500,
        height=600,
    )
    viz.plot(trial_fig)
    return trial_fig
コード例 #4
0
    'bert': 'Qbert',
    'eaquest': 'Seaquest',
    'humanoid': 'RoboschoolHumanoid',
    'humanoidflagrun': 'RoboschoolHumanoidFlagrun',
    'humanoidflagrunharder': 'RoboschoolHumanoidFlagrunHarder',
}
# Legend entries in display order; each is assigned one palette color below.
master_legend_list = [
    'DQN',
    'DDQN+PER',
    'A2C (GAE)',
    'A2C (n-step)',
    'PPO',
    'SAC',
]
# Map each legend entry to its palette color, preserving legend order.
master_palette_dict = {
    legend: color
    for legend, color in zip(
        master_legend_list, viz.get_palette(len(master_legend_list)))
}
# Async SAC is drawn in the same color as SAC.
master_palette_dict['Async SAC'] = master_palette_dict['SAC']


def guard_env_name(env):
    '''Normalize an env name token and map it to its canonical name.

    Any leading/trailing run of '_' and '-' characters is removed, in any
    combination (e.g. '-_bert' -> 'bert', 'bert_-' -> 'bert'). The result is
    then looked up in env_name_map, which repairs truncated names such as
    'bert' -> 'Qbert'. Names without an entry are returned stripped but
    otherwise unchanged.
    '''
    # str.strip treats its argument as a *set* of characters, so '_-' removes
    # any mix of both. The former chained strip('_').strip('-') was
    # order-dependent and left e.g. '-_bert' as '_bert', missing the map.
    env = env.strip('_-')
    return env_name_map.get(env, env)


def get_trial_metrics_scalar(algo, env, data_folder):
    try:
        filepaths = glob(
            f'{data_folder}/{algo}*{env}*/{trial_metrics_scalar_path}')