Example #1
def generate_md_atari():
    entries = rldb.find_all({})
    envs = set([entry['env-title'] for entry in entries])
    atari_envs = [env for env in envs if 'atari-' in env]
    atari_envs.sort()
    _generate_md_atari_main('envs/gym/atari/atari.md', atari_envs)
    _generate_md_atari_batch(atari_envs)
Example #2
def generate_md_mujoco():
    entries = rldb.find_all({})
    envs = set([entry['env-title'] for entry in entries])
    mujoco_envs = [env for env in envs if 'mujoco-' in env]

    _generate_md_mujoco_main('envs/gym/mujoco/mujoco.md', mujoco_envs)
    _generate_md_mujoco_all_env(mujoco_envs)
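Both generators key off the 'env-title' prefix to split environments by domain. A minimal sketch of that naming convention, assuming titles follow a 'domain-name' pattern ('atari-pong' appears in the tests below; 'mujoco-ant' is an assumed example value):

# Sketch of the assumed '<domain>-<name>' convention behind the
# 'atari-' / 'mujoco-' prefix filters ('mujoco-ant' is an assumed value).
env_title = 'mujoco-ant'
domain, name = env_title.split('-', 1)
print(domain)  # mujoco
print(name)    # ant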
Example #3
def env_barplot(filter, plot_title, filepath):
    entries = rldb.find_all(filter)

    # Sort entries by score, descending (highest score first)
    sorted_entries = sorted(
        entries,
        key=lambda entry: entry['score'],
        reverse=True,
    )
    labels, scores, colors = entries_to_labels_scores(sorted_entries)

    # Draw bar plot
    fig, ax = plt.subplots(figsize=(12, int(len(sorted_entries) / 2)))
    bars = ax.barh(labels, scores, color=colors)
    plt.yticks(range(len(scores)), [''] * len(scores))

    # Add title
    plt.title(plot_title)

    # Add score labels
    bar_x0 = bars[0].get_window_extent(renderer=fig.canvas.get_renderer()).x0
    bar_x1 = bars[0].get_window_extent(renderer=fig.canvas.get_renderer()).x1
    BBOX_TO_BAR_UNIT = bars[0].get_width() / (bar_x1 - bar_x0)
    switch_label_align = False  # If true, put score label outside of barplot
    for i, (bar, entry) in enumerate(zip(bars, sorted_entries)):
        set_label(fig, ax, bar, entry, BBOX_TO_BAR_UNIT)
        set_score_text(fig, ax, bar, entry)

    # Add label legends
    set_label_legends(fig, ax, bars, BBOX_TO_BAR_UNIT)

    plt.tight_layout()
    plt.savefig(filepath)
    # plt.show()
    plt.clf()
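A hedged usage sketch for env_barplot(): the filter dict follows the rldb.find_all() format used throughout these examples ('atari-pong' appears in the test suite below); the plot title and output path are assumed illustrative values.

# Usage sketch (not from the source): the filter value is taken from the
# tests below, the title and filename are assumed for illustration.
env_barplot(
    filter={'env-title': 'atari-pong'},
    plot_title='atari-pong scores',
    filepath='atari-pong.png',
)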
Example #4
def _generate_md_mujoco_main(filepath, mujoco_envs):
    # Get template
    template = get_template('markdown/source/{}'.format(filepath))

    table = ''
    table += '| Environment | Result | Algorithm | Source |\n'
    table += '|-------------|--------|-----------|--------|\n'

    for env_title in mujoco_envs:
        entries = rldb.find_all({
            'env-title': env_title,
        })
        entries.sort(key=lambda entry: entry['score'], reverse=True)
        entry = entries[0]  # Best Performing algorithm
        source_link = get_best_source_link(entry)  # Choose best link
        table += '| [{}](/envs/gym/mujoco/{}) | {} | {} | [{}]({}) |\n'.format(
            entry['env-title'][7:],
            entry['env-title'][7:],
            entry['score'],
            entry['algo-nickname'],
            entry['source-title'],
            source_link['url'],
        )

    result = populate_template(template, {"table": table})
    save_file('markdown/docs/{}'.format(filepath), result)
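The [7:] slice strips the 7-character 'mujoco-' prefix so only the bare environment name appears in the link text and path. A small self-contained sketch, using 'mujoco-ant' as an assumed example title:

# 'mujoco-' is 7 characters long, so [7:] keeps only the environment name
# ('mujoco-ant' is an assumed example value).
env_title = 'mujoco-ant'
print(env_title[7:])  # ant
# The loop above would then emit a row shaped like:
# | [ant](/envs/gym/mujoco/ant) | <score> | <algo-nickname> | [<source-title>](<url>) |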
Example #5
def test_entries_count():
    """Verify number of entries in rldb. This number should match README."""
    all_entries = rldb.find_all({})

    assert len(all_entries) == 3380
    assert len(all_entries) == (
        0 + 171  # A3C
        + 80  # ACKTR
        + 114  # Ape-X
        + 171  # C51
        + 179  # DDQN
        + 245  # DQN
        + 56  # DQN2013
        + 38  # DRQN
        + 301  # DuDQN
        + 245  # Gorila DQN
        + 291  # IMPALA
        + 57  # IQN
        + 342  # NoisyNet
        + 147  # PPO
        + 171  # Prioritized DQN
        + 114  # QR-DQN
        + 232  # Rainbow
        + 171  # Reactor
        + 18  # RND
        + 49  # TD3
        + 21  # TRPO
        + 15  # Trust-PCL
        + 49  # OpenAI Baselines cbd21ef
        + 14  # OpenAI Baselines ea68f3b
        + 89  # RL Baselines Zoo b76641e
    )
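If this total ever drifts, a per-source breakdown makes the mismatch easy to locate. A minimal sketch, assuming only rldb.find_all() and the 'source-title' field used throughout these tests:

from collections import Counter

import rldb

# Recompute the per-source counts that the hand-written sum above encodes.
counts = Counter(entry['source-title'] for entry in rldb.find_all({}))
for source_title, count in sorted(counts.items()):
    print('{}: {}'.format(source_title, count))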
Example #6
def test_ape_x_dqn_paper_count():
    """Verify number of entries in Ape-X DQN paper."""
    ape_x_dqn_entries = rldb.find_all({
        'source-title':
        'Distributed Prioritized Experience Replay',
    })

    assert len(ape_x_dqn_entries) == (0 + 114)
Example #7
def test_find_all_return_match_all_filter():
    """Make sure that all entries returned by find_all() matches all filters."""
    all_entries = rldb.find_all({
        'algo-nickname': 'DQN',
        'env-title': 'atari-pong',
    })

    for entry in all_entries:
        assert entry['algo-nickname'] == 'DQN'
        assert entry['env-title'] == 'atari-pong'
Example #8
def test_repo_openai_baselines_ea68f3b_count():
    """Verify number of entries in OpenAI Baselines."""
    baselines_ea68f3b_entries = rldb.find_all({
        'source-title':
        'OpenAI Baselines ea68f3b',
    })

    assert len(baselines_ea68f3b_entries) == (
        0 + 7  # TRPO (MPI)
        + 7  # PPO2
    )
Example #9
def test_drqn_paper_count():
    """Verify number of entries in DRQN paper."""
    drqn_entries = rldb.find_all({
        'source-title':
        'Deep Recurrent Q-Learning for Partially Observable MDPs',
    })

    assert len(drqn_entries) == (
        0 + 19  # DQN (Ours)
        + 19  # DRQN
    )
Example #10
def test_a3c_paper_count():
    """Verify number of entries in A3C paper."""
    a3c_entries = rldb.find_all({
        'source-title':
        'Asynchronous Methods for Deep Reinforcement Learning',
    })

    assert len(a3c_entries) == (
        0 + 57  # A3C FF 1 day
        + 57  # A3C FF
        + 57  # A3C LSTM
    )
Example #11
def test_trust_pcl_paper_count():
    """Verify number of entries in Trust-PCL paper."""
    trust_pcl_entries = rldb.find_all({
        'source-title':
        'Trust-PCL: An Off-Policy Trust Region Method for Continuous Control',
    })

    assert len(trust_pcl_entries) == (
        0 + 5  # TRPO+GAE
        + 5  # TRPO (from Trust-PCL)
        + 5  # Trust-PCL
    )
Example #12
def test_rnd_paper_count():
    """Verify number of entries in RND paper."""
    rnd_entries = rldb.find_all({
        'source-title':
        'Exploration by Random Network Distillation',
    })

    assert len(rnd_entries) == (
        0 + 6  # Dynamics
        + 6  # PPO
        + 6  # RND
    )
Example #13
def test_rainbow_paper_count():
    """Verify number of entries in Rainbow paper."""
    rainbow_entries = rldb.find_all({
        'source-title':
        'Rainbow: Combining Improvements in Deep Reinforcement Learning',
    })

    assert len(rainbow_entries) == (
        0 + 16  # DQN
        + 108  # Distributional DQN
        + 108  # Rainbow
    )
Example #14
def test_prioritized_dqn_paper_count():
    """Verify number of entries in Prioritized DQN paper."""
    prioritized_dqn_entries = rldb.find_all({
        'source-title':
        'Prioritized Experience Replay',
    })

    assert len(prioritized_dqn_entries) == (
        0 + 57  # Proportional Prioritized DDQN
        + 57  # Rank Prioritized DQN
        + 57  # Rank Prioritized DDQN
    )
Example #15
def test_c51_paper_count():
    """Verify number of entries in C51 paper."""
    c51_entries = rldb.find_all({
        'source-title':
        'A Distributional Perspective on Reinforcement Learning',
    })

    assert len(c51_entries) == (
        0 + 57  # DQN
        + 57  # DDQN
        + 57  # C51
    )
Example #16
def test_trpo_paper_count():
    """Verify number of entries in TRPO paper."""
    trpo_entries = rldb.find_all({
        'source-title':
        'Trust Region Policy Optimization',
    })

    assert len(trpo_entries) == (
        0 + 7  # TRPO (single path)
        + 7  # TRPO (vine)
        + 7  # UCC-I
    )
Example #17
def test_ppo_paper_count():
    """Verify number of entries in PPO paper."""
    ppo_entries = rldb.find_all({
        'source-title':
        'Proximal Policy Optimization Algorithm',
    })

    assert len(ppo_entries) == (
        0 + 49  # A2C
        + 49  # ACER
        + 49  # PPO
    )
Example #18
def test_readme_entry_count():
    """
    Test if entry count in README is correct.
    """
    all_entries = rldb.find_all({})

    with open('README.md', 'r') as f:
        text = f.read()
        nb_entries = re.search('entries-(.*)-blue.svg', text).group(1)

    assert int(nb_entries) == len(all_entries)
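The regex expects a shields.io-style badge in README.md. A self-contained sketch of the match against an assumed badge line (3380 is the entry count asserted in the test above):

import re

# Assumed README badge line in the shields.io format the test parses.
text = '![entries](https://img.shields.io/badge/entries-3380-blue.svg)'
nb_entries = re.search('entries-(.*)-blue.svg', text).group(1)
print(nb_entries)  # 3380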
Example #19
def test_ddqn_paper_count():
    """Verify number of entries in DDQN paper."""
    ddqn_entries = rldb.find_all({
        'source-title':
        'Deep Reinforcement Learning with Double Q-learning',
    })

    assert len(ddqn_entries) == (
        0 + 57 + 49  # DDQN
        + 57  # DDQN (tuned)
        + 8  # Human
        + 8  # Random
    )
Example #20
def test_readme_env_count():
    """
    Test if environment count in README is correct.
    """
    all_entries = rldb.find_all({})
    all_envs = set([e['env-title'] for e in all_entries])

    with open('README.md', 'r') as f:
        text = f.read()
        nb_envs = re.search('environments-(.*)-blue.svg', text).group(1)

    assert int(nb_envs) == len(all_envs)
Example #21
def test_gorila_dqn_paper_count():
    """Verify number of entries in Gorila DQN paper."""
    gorila_dqn_entries = rldb.find_all({
        'source-title':
        'Massively Parallel Methods for Deep Reinforcement Learning',
    })

    assert len(gorila_dqn_entries) == (
        0 + 49  # DQN
        + 98  # Gorila DQN
        + 49  # Human
        + 49  # Random
    )
Example #22
def test_dueling_dqn_paper_count():
    """Verify number of entries in Dueling DQN paper."""
    dudqn_entries = rldb.find_all({
        'source-title':
        'Dueling Network Architectures for Deep Reinforcement Learning',
    })

    assert len(dudqn_entries) == (
        0 + 114  # Dueling DQN
        + 65  # Human
        + 114  # PDD DQN
        + 8  # Random
    )
Example #23
def test_dqn_paper_count():
    """Verify number of entries in DQN paper."""
    dqn_entries = rldb.find_all({
        'source-title':
        'Human-level control through deep reinforcement learning',
    })

    assert len(dqn_entries) == (
        0 + 49  # Best Linear Learner
        + 49  # Contingency
        + 49  # DQN
        + 49  # Human
        + 49  # Random
    )
Example #24
def test_readme_source_count():
    """
    Test if source count in README is correct.
    """
    all_entries = rldb.find_all({})
    all_sources = set([e['source-title'] for e in all_entries])

    with open('README.md', 'r') as f:
        text = f.read()
        nb_papers = re.search('papers-(.*)-blue.svg', text).group(1)
        nb_repos = re.search('repos-(.*)-blue.svg', text).group(1)

    assert int(nb_papers) + int(nb_repos) == len(all_sources)
Example #25
def test_noisynet_paper_count():
    """Verify number of entries in NoisyNet paper."""
    noisynet_entries = rldb.find_all({
        'source-title':
        'Noisy Networks for Exploration',
    })

    assert len(noisynet_entries) == (
        0 + 57  # A3C
        + 57  # DQN
        + 57  # DuDQN
        + 57  # NoisyNet A3C
        + 57  # NoisyNet DQN
        + 57  # NoisyNet DuDQN
    )
Example #26
def test_impala_paper_count():
    """Verify number of entries in IMPALA paper."""
    impala_entries = rldb.find_all({
        'source-title':
        'IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures',
    })

    assert len(impala_entries) == (
        0 + 57  # Atari-57 IMPALA (deep)
        + 57  # Atari-57 IMPALA (deep, multitask)
        + 57  # Atari-57 IMPALA (shallow)
        + 30  # DMLab-30 Experts
        + 30  # DMLab-30 Human
        + 30  # DMLab-30 IMPALA
        + 30  # DMLab-30 Random
    )
Example #27
def _generate_md_mujoco_single_env(filepath, env_title):
    """Generate a Markdown file for a single MuJoCo environment."""
    # Get template
    template = get_template('markdown/source/{}'.format(filepath))

    # Parse rldb
    entries = rldb.find_all({
        'env-title': env_title,
    })
    entries.sort(key=lambda entry: entry['score'], reverse=True)
    feed_dict = {
        'table': entries_to_table(entries),
    }

    result = populate_template(template, feed_dict)
    save_file('markdown/docs/{}'.format(filepath), result)
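A hedged sketch of how a caller might drive this helper for every MuJoCo environment; the per-environment filepath convention is inferred from the '/envs/gym/mujoco/<name>' links built in _generate_md_mujoco_main and is not confirmed by the source:

# Assumed driver loop (usage sketch, not from the source).
for env_title in mujoco_envs:
    _generate_md_mujoco_single_env(
        'envs/gym/mujoco/{}.md'.format(env_title[7:]),
        env_title,
    )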
Example #28
def test_td3_paper_count():
    """Verify number of entries in TD3 paper."""
    td3_entries = rldb.find_all({
        'source-title':
        'Addressing Function Approximation Error in Actor-Critic Methods',
    })

    assert len(td3_entries) == (
        0 + 7  # ACKTR
        + 7  # DDPG
        + 7  # Our DDPG
        + 7  # PPO
        + 7  # SAC
        + 7  # TD3
        + 7  # TRPO
    )
Example #29
def test_repo_openai_baselines_cbd21ef_count():
    """Verify number of entries in OpenAI Baselines."""
    baselines_cbd21ef_entries = rldb.find_all({
        'source-title':
        'OpenAI Baselines cbd21ef',
    })

    assert len(baselines_cbd21ef_entries) == (
        0 + 7  # A2C
        + 7  # ACER
        + 7  # ACKTR
        + 7  # DQN
        + 7  # PPO2
        + 7  # PPO2 (MPI)
        + 7  # TRPO (MPI)
    )
Example #30
def test_repo_rl_baselines_zoo_count():
    """Verify number of entries in RL Baselines Zoo."""
    rl_baselines_zoo_entries = rldb.find_all({
        'source-title':
        'RL Baselines Zoo b76641e',
    })

    assert len(rl_baselines_zoo_entries) == (
        0 + 12  # A2C
        + 11  # ACER
        + 12  # ACKTR
        + 3  # DDPG
        + 12  # DQN
        + 27  # PPO2
        + 12  # SAC
    )