def generate_md_atari():
    """Generate Markdown docs for all Atari environments found in rldb."""
    all_entries = rldb.find_all({})
    env_titles = {e['env-title'] for e in all_entries}
    # Alphabetical order keeps the generated docs deterministic.
    atari_envs = sorted(title for title in env_titles if 'atari-' in title)
    _generate_md_atari_main('envs/gym/atari/atari.md', atari_envs)
    _generate_md_atari_batch(atari_envs)
def generate_md_mujoco():
    """Generate Markdown docs for all MuJoCo environments found in rldb.

    Environments are sorted alphabetically so the generated docs are
    deterministic, mirroring generate_md_atari().
    """
    entries = rldb.find_all({})
    envs = set([entry['env-title'] for entry in entries])
    # Fix: the original never sorted, so iteration order over the set was
    # arbitrary and the generated page order could change between runs.
    mujoco_envs = sorted(env for env in envs if 'mujoco-' in env)
    _generate_md_mujoco_main('envs/gym/mujoco/mujoco.md', mujoco_envs)
    _generate_md_mujoco_all_env(mujoco_envs)
def env_barplot(filter, plot_title, filepath):
    """Draw a horizontal bar plot of all scores recorded for one environment.

    Args:
        filter: rldb query dict selecting the entries to plot.
            NOTE(review): the name shadows the `filter` builtin; kept
            unchanged so existing keyword callers are not broken.
        plot_title: Title drawn above the plot.
        filepath: Path the rendered figure is saved to.
    """
    entries = rldb.find_all(filter)

    # Sort entries by score, decreasing (best score at the top of the plot).
    # (The original comment claimed "Dedupe and sort ... increasing", but no
    # deduplication happens here and reverse=True sorts decreasing.)
    sorted_entries = sorted(
        entries,
        key=lambda entry: entry['score'],
        reverse=True,
    )
    labels, scores, colors = entries_to_labels_scores(sorted_entries)

    # Draw bar plot; figure height scales with the number of entries.
    fig, ax = plt.subplots(figsize=(12, int(len(sorted_entries) / 2)))
    bars = ax.barh(labels, scores, color=colors)
    plt.yticks(range(len(scores)), [''] * len(scores))

    # Add title
    plt.title(plot_title)

    # Conversion factor from bbox (display) units to bar data units.
    # Hoisted: the renderer was previously fetched twice for the same bar.
    renderer = fig.canvas.get_renderer()
    bar_extent = bars[0].get_window_extent(renderer=renderer)
    BBOX_TO_BAR_UNIT = bars[0].get_width() / (bar_extent.x1 - bar_extent.x0)

    # Add score labels. (Removed dead `switch_label_align` flag and the
    # unused enumerate index from the original.)
    for bar, entry in zip(bars, sorted_entries):
        set_label(fig, ax, bar, entry, BBOX_TO_BAR_UNIT)
        set_score_text(fig, ax, bar, entry)

    # Add label legends
    set_label_legends(fig, ax, bars, BBOX_TO_BAR_UNIT)

    plt.tight_layout()
    plt.savefig(filepath)
    plt.clf()
def _generate_md_mujoco_main(filepath, mujoco_envs):
    """Write the top-level MuJoCo Markdown page listing the best entry per env."""
    template = get_template('markdown/source/{}'.format(filepath))

    rows = [
        '| Environment | Result | Algorithm | Source |\n',
        '|-------------|--------|-----------|--------|\n',
    ]
    for env_title in mujoco_envs:
        env_entries = rldb.find_all({
            'env-title': env_title,
        })
        # Best-performing algorithm for this environment.
        best = max(env_entries, key=lambda entry: entry['score'])
        source_link = get_best_source_link(best)  # Choose best link
        # env-title[7:] strips the 'mujoco-' prefix for display and the URL.
        rows.append(
            '| [{}](/envs/gym/mujoco/{}) | {} | {} | [{}]({}) |\n'.format(
                best['env-title'][7:],
                best['env-title'][7:],
                best['score'],
                best['algo-nickname'],
                best['source-title'],
                source_link['url'],
            )
        )

    result = populate_template(template, {"table": ''.join(rows)})
    save_file('markdown/docs/{}'.format(filepath), result)
def test_entries_count():
    """Verify number of entries in rldb. This number should match README."""
    all_entries = rldb.find_all({})
    assert len(all_entries) == 3380

    per_source_counts = [
        171,  # A3C
        80,   # ACKTR
        114,  # Ape-X
        171,  # C51
        179,  # DDQN
        245,  # DQN
        56,   # DQN2013
        38,   # DRQN
        301,  # DuDQN
        245,  # Gorila DQN
        291,  # IMPALA
        57,   # IQN
        342,  # NoisyNet
        147,  # PPO
        171,  # Prioritized DQN
        114,  # QR-DQN
        232,  # Rainbow
        171,  # Reactor
        18,   # RND
        49,   # TD3
        21,   # TRPO
        15,   # Trust-PCL
        49,   # OpenAI Baselines cbd21ef
        14,   # OpenAI Baselines ea68f3b
        89,   # RL Baselines Zoo b76641e
    ]
    assert len(all_entries) == sum(per_source_counts)
def test_ape_x_dqn_paper_count():
    """Verify number of entries in Ape-X DQN paper."""
    # Fix: this function was named test_a3c_paper_count, duplicating the real
    # A3C test defined elsewhere in this file. With two same-named functions
    # in one module, the later definition shadows the earlier one and pytest
    # runs only one of them. Renamed to match the paper it actually tests.
    ape_x_dqn_entries = rldb.find_all({
        'source-title': 'Distributed Prioritized Experience Replay',
    })
    assert len(ape_x_dqn_entries) == (0 + 114)
def test_find_all_return_match_all_filter():
    """Make sure that all entries returned by find_all() matches all filters."""
    matched = rldb.find_all({
        'algo-nickname': 'DQN',
        'env-title': 'atari-pong',
    })
    assert all(e['algo-nickname'] == 'DQN' for e in matched)
    assert all(e['env-title'] == 'atari-pong' for e in matched)
def test_repo_openai_baselines_ea68f3b_count():
    """Verify number of entries in OpenAI Baselines."""
    entries = rldb.find_all({
        'source-title': 'OpenAI Baselines ea68f3b',
    })
    expected = sum([
        7,  # TRPO (MPI)
        7,  # PPO2
    ])
    assert len(entries) == expected
def test_drqn_paper_count():
    """Verify number of entries in DRQN paper."""
    entries = rldb.find_all({
        'source-title': 'Deep Recurrent Q-Learning for Partially Observable MDPs',
    })
    expected = sum([
        19,  # DQN (Ours)
        19,  # DRQN
    ])
    assert len(entries) == expected
def test_a3c_paper_count():
    """Verify number of entries in A3C paper."""
    entries = rldb.find_all({
        'source-title': 'Asynchronous Methods for Deep Reinforcement Learning',
    })
    expected = sum([
        57,  # A3C FF 1 day
        57,  # A3C FF
        57,  # A3C LSTM
    ])
    assert len(entries) == expected
def test_trust_pcl_paper_count():
    """Verify number of entries in Trust-PCL paper."""
    entries = rldb.find_all({
        'source-title': 'Trust-PCL: An Off-Policy Trust Region Method for Continuous Control',
    })
    expected = sum([
        5,  # TRPO+GAE
        5,  # TRPO (from Trust-PCL)
        5,  # Trust-PCL
    ])
    assert len(entries) == expected
def test_rnd_paper_count():
    """Verify number of entries in RND paper."""
    entries = rldb.find_all({
        'source-title': 'Exploration by Random Network Distillation',
    })
    expected = sum([
        6,  # Dynamics
        6,  # PPO
        6,  # RND
    ])
    assert len(entries) == expected
def test_rainbow_paper_count():
    """Verify number of entries in Rainbow paper."""
    entries = rldb.find_all({
        'source-title': 'Rainbow: Combining Improvements in Deep Reinforcement Learning',
    })
    expected = sum([
        16,   # DQN
        108,  # Distributional DQN
        108,  # Rainbow
    ])
    assert len(entries) == expected
def test_prioritized_dqn_paper_count():
    """Verify number of entries in Prioritized DQN paper."""
    entries = rldb.find_all({
        'source-title': 'Prioritized Experience Replay',
    })
    expected = sum([
        57,  # Proportional Prioritized DDQN
        57,  # Rank Prioritized DQN
        57,  # Rank Prioritized DDQN
    ])
    assert len(entries) == expected
def test_c51_paper_count():
    """Verify number of entries in C51 paper."""
    entries = rldb.find_all({
        'source-title': 'A Distributional Perspective on Reinforcement Learning',
    })
    expected = sum([
        57,  # DQN
        57,  # DDQN
        57,  # C51
    ])
    assert len(entries) == expected
def test_trpo_paper_count():
    """Verify number of entries in TRPO paper."""
    entries = rldb.find_all({
        'source-title': 'Trust Region Policy Optimization',
    })
    expected = sum([
        7,  # TRPO (single path)
        7,  # TRPO (vine)
        7,  # UCC-I
    ])
    assert len(entries) == expected
def test_ppo_paper_count():
    """Verify number of entries in PPO paper."""
    entries = rldb.find_all({
        'source-title': 'Proximal Policy Optimization Algorithm',
    })
    expected = sum([
        49,  # A2C
        49,  # ACER
        49,  # PPO
    ])
    assert len(entries) == expected
def test_readme_entry_count():
    """Test if entry count in README is correct."""
    all_entries = rldb.find_all({})
    # Fix: the original joined readlines() with '\n', doubling every newline
    # in `text`. Reading the file in one call avoids that entirely.
    with open('README.md', 'r') as f:
        text = f.read()
    # Raw string with an escaped dot so '.svg' is matched literally.
    match = re.search(r'entries-(.*)-blue\.svg', text)
    # Fail with a clear message instead of AttributeError when badge missing.
    assert match is not None, 'entries badge not found in README.md'
    assert int(match.group(1)) == len(all_entries)
def test_ddqn_paper_count():
    """Verify number of entries in DDQN paper."""
    entries = rldb.find_all({
        'source-title': 'Deep Reinforcement Learning with Double Q-learning',
    })
    expected = sum([
        57,  # (unlabeled in original tally -- presumably DQN; TODO confirm)
        49,  # DDQN
        57,  # DDQN (tuned)
        8,   # Human
        8,   # Random
    ])
    assert len(entries) == expected
def test_readme_env_count():
    """Test if environment count in README is correct."""
    all_entries = rldb.find_all({})
    all_envs = set([e['env-title'] for e in all_entries])
    # Fix: the original joined readlines() with '\n', doubling every newline
    # in `text`. Reading the file in one call avoids that entirely.
    with open('README.md', 'r') as f:
        text = f.read()
    # Raw string with an escaped dot so '.svg' is matched literally.
    match = re.search(r'environments-(.*)-blue\.svg', text)
    # Fail with a clear message instead of AttributeError when badge missing.
    assert match is not None, 'environments badge not found in README.md'
    assert int(match.group(1)) == len(all_envs)
def test_gorila_dqn_paper_count():
    """Verify number of entries in Gorila DQN paper."""
    entries = rldb.find_all({
        'source-title': 'Massively Parallel Methods for Deep Reinforcement Learning',
    })
    expected = sum([
        49,  # DQN
        98,  # Gorila DQN
        49,  # Human
        49,  # Random
    ])
    assert len(entries) == expected
def test_dueling_dqn_paper_count():
    """Verify number of entries in Dueling DQN paper."""
    entries = rldb.find_all({
        'source-title': 'Dueling Network Architectures for Deep Reinforcement Learning',
    })
    expected = sum([
        114,  # Dueling DQN
        65,   # Human
        114,  # PDD DQN
        8,    # Random
    ])
    assert len(entries) == expected
def test_dqn_paper_count():
    """Verify number of entries in DQN paper."""
    entries = rldb.find_all({
        'source-title': 'Human-level control through deep reinforcement learning',
    })
    expected = sum([
        49,  # Best Linear Learner
        49,  # Contingency
        49,  # DQN
        49,  # Human
        49,  # Random
    ])
    assert len(entries) == expected
def test_readme_source_count():
    """Test if source count in README is correct."""
    all_entries = rldb.find_all({})
    all_sources = set([e['source-title'] for e in all_entries])
    # Fix: the original joined readlines() with '\n', doubling every newline
    # in `text`. Reading the file in one call avoids that entirely.
    with open('README.md', 'r') as f:
        text = f.read()
    # Raw strings with escaped dots so '.svg' is matched literally.
    papers_match = re.search(r'papers-(.*)-blue\.svg', text)
    repos_match = re.search(r'repos-(.*)-blue\.svg', text)
    # Fail with clear messages instead of AttributeError when badges missing.
    assert papers_match is not None, 'papers badge not found in README.md'
    assert repos_match is not None, 'repos badge not found in README.md'
    nb_papers = int(papers_match.group(1))
    nb_repos = int(repos_match.group(1))
    assert nb_papers + nb_repos == len(all_sources)
def test_noisynet_paper_count():
    """Verify number of entries in NoisyNet paper."""
    entries = rldb.find_all({
        'source-title': 'Noisy Networks for Exploration',
    })
    expected = sum([
        57,  # A3C
        57,  # DQN
        57,  # DuDQN
        57,  # NoisyNet A3C
        57,  # NoisyNet DQN
        57,  # NoisyNet DuDQN
    ])
    assert len(entries) == expected
def test_impala_paper_count():
    """Verify number of entries in IMPALA paper."""
    entries = rldb.find_all({
        'source-title': 'IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures',
    })
    expected = sum([
        57,  # Atari-57 IMPALA (deep)
        57,  # Atari-57 IMPALA (deep, multitask)
        57,  # Atari-57 IMPALA (shallow)
        30,  # DMLab-30 Experts
        30,  # DMLab-30 Human
        30,  # DMLab-30 IMPALA
        30,  # DMLab-30 Random
    ])
    assert len(entries) == expected
def _generate_md_mujoco_single_env(filepath, env_title):
    """Generate a Markdown file for a single MuJoCo environment."""
    template = get_template('markdown/source/{}'.format(filepath))

    # Collect this environment's entries, best score first.
    env_entries = rldb.find_all({
        'env-title': env_title,
    })
    env_entries.sort(key=lambda e: e['score'], reverse=True)

    result = populate_template(template, {'table': entries_to_table(env_entries)})
    save_file('markdown/docs/{}'.format(filepath), result)
def test_td3_paper_count():
    """Verify number of entries in TD3 paper."""
    entries = rldb.find_all({
        'source-title': 'Addressing Function Approximation Error in Actor-Critic Methods',
    })
    expected = sum([
        7,  # ACKTR
        7,  # DDPG
        7,  # Our DDPG
        7,  # PPO
        7,  # SAC
        7,  # TD3
        7,  # TRPO
    ])
    assert len(entries) == expected
def test_repo_openai_baselines_cbd21ef_count():
    """Verify number of entries in OpenAI Baselines."""
    entries = rldb.find_all({
        'source-title': 'OpenAI Baselines cbd21ef',
    })
    expected = sum([
        7,  # A2C
        7,  # ACER
        7,  # ACKTR
        7,  # DQN
        7,  # PPO2
        7,  # PPO2 (MPI)
        7,  # TRPO (MPI)
    ])
    assert len(entries) == expected
def test_repo_rl_baselines_zoo_count():
    """Verify number of entries in RL Baselines Zoo."""
    entries = rldb.find_all({
        'source-title': 'RL Baselines Zoo b76641e',
    })
    expected = sum([
        12,  # A2C
        11,  # ACER
        12,  # ACKTR
        3,   # DDPG
        12,  # DQN
        27,  # PPO2
        12,  # SAC
    ])
    assert len(entries) == expected