Beispiel #1
0
def test_seeding(env=None, seed=0):
    """Roll out a seeded environment and return what it produced.

    Args:
        env: environment id (str), a gym.Env instance, or None to use the
            first id returned by ngym.all_envs(). An id is instantiated
            with dt=20.
        seed: seed given to both the environment and its action space.

    Returns:
        Tuple (states, rewards) of numpy arrays gathered over 100 steps
        of randomly sampled actions.

    Raises:
        ValueError: if env is neither a string nor a gym.Env.
    """
    if env is None:
        env = ngym.all_envs()[0]

    if isinstance(env, str):
        env = gym.make(env, dt=20)
    elif not isinstance(env, gym.Env):
        raise ValueError('env must be a string or a gym.Env')

    # Seed the env, reset it, then seed the action sampler, in that order.
    env.seed(seed=seed)
    env.reset()
    env.action_space.seed(seed)

    observations, rewards = [], []
    for _ in range(100):
        obs, reward, finished, _info = env.step(env.action_space.sample())
        observations.append(obs)
        rewards.append(reward)
        if finished:
            env.reset()

    return np.array(observations), np.array(rewards)
Beispiel #2
0
def make_tags():
    """Write the 'tags.rst' and 'wrappers.rst' pages next to this file.

    'tags.rst' gets one section per tag listing the environments that carry
    it (ids found in ENV_IGNORE are left out); 'wrappers.rst' gets one
    autoclass entry per item of ALL_WRAPPERS.
    """
    tag_page = ['Tags\n', '===================================\n\n']

    for tag in sorted(ngym.all_tags()):
        tag_page.append('.. _tag-{:s}:\n\n'.format(tag))
        tag_page.append(tag + '\n--------------------------------\n')
        for env in ngym.all_envs(tag=tag):
            if env in ENV_IGNORE:
                continue
            # NOTE(review): all_envs is not defined in this function;
            # presumably a module-level mapping of env id -> 'module:Class'.
            target = all_envs[env].replace(':', '.')
            tag_page.append('    :class:`{:s} <{:s}>`\n'.format(env, target))
        tag_page.append('\n')

    with open(Path(__file__).parent / 'tags.rst', 'w') as f:
        f.write(''.join(tag_page))

    wrapper_page = ['Wrappers\n', '===================================\n\n']

    for key, val in ALL_WRAPPERS.items():
        wrapper_page.append(key + '\n' + '-' * 50 + '\n')
        # 'module:Class' -> 'module.Class' for the autoclass directive.
        pieces = val.split(':')
        wrapper_page.append(
            '.. autoclass:: ' + pieces[0] + '.' + pieces[1] + '\n')
        wrapper_page.append('    :members:\n')
        wrapper_page.append('    :exclude-members: new_trial\n\n')

    with open(Path(__file__).parent / 'wrappers.rst', 'w') as f:
        f.write(''.join(wrapper_page))
Beispiel #3
0
def test_run(env=None, num_steps=100, verbose=False, **kwargs):
    """Step an environment with random actions and return it.

    Args:
        env: environment id (str), a gym.Env instance, or None to use the
            first id returned by ngym.all_envs().
        num_steps: number of random-action steps to take.
        verbose: when true, print the environment after the rollout.
        **kwargs: forwarded to gym.make when env is given as an id.

    Returns:
        The environment instance that was stepped.

    Raises:
        ValueError: if env is neither a string nor a gym.Env.
    """
    if env is None:
        env = ngym.all_envs()[0]

    if isinstance(env, str):
        env = gym.make(env, **kwargs)
    elif not isinstance(env, gym.Env):
        raise ValueError('env must be a string or a gym.Env')

    env.reset()
    for _ in range(num_steps):
        sampled = env.action_space.sample()
        _obs, _reward, episode_over, _info = env.step(sampled)
        if episode_over:
            env.reset()

    # Report (but do not fail on) tags that are not registered in ngym.
    env_tags = env.metadata.get('tags', [])
    known_tags = ngym.all_tags()
    for tag in env_tags:
        if tag not in known_tags:
            print('Warning: env has tag {:s} not in all_tags'.format(tag))

    if verbose:
        print(env)

    return env
Beispiel #4
0
def supervised_all():
    """Run the network on each 'supervised' environment and analyze it.

    Environments listed as excluded are skipped entirely; environments
    listed as analysis-exempt are announced but not run or analyzed.
    """
    # Detection-v0: seems to have an error with dt=100.
    # ReachingDelayResponse-v0: has a Box action space.
    excluded = ('Detection-v0', 'ReachingDelayResponse-v0')

    # MotorTiming-v0: cannot make all periods the same length.
    analysis_exempt = ('MotorTiming-v0',)

    for envid in ngym.all_envs(tag='supervised'):
        if envid in excluded:
            continue

        print('Train & analyze env ', envid)

        if envid in analysis_exempt:
            continue

        activity, info, config = supervised_train.run_network(envid)
        ta.analysis_average_activity(activity, info, config)
        ta.analysis_activity_by_condition(activity, info, config)
        ta.analysis_example_units_by_condition(activity, info, config)
        ta.analysis_pca_by_condition(activity, info, config)
Beispiel #5
0
def test_trialenv_all():
    """Check that every TrialEnv can start a trial and exposes env.trial.

    'Combine-v0' and environments that are not ngym.TrialEnv instances are
    skipped. A failing environment is reported on stdout instead of raising,
    and two summary lines are printed at the end.
    """
    success_count = 0
    total_count = 0
    hastrial_count = 0
    for env_name in sorted(ngym.all_envs()):
        if env_name in ['Combine-v0']:
            continue
        env = gym.make(env_name)
        if not isinstance(env, ngym.TrialEnv):
            continue
        total_count += 1

        print('Running env: {:s}'.format(env_name))
        try:
            env.new_trial()
            if env.trial is None:
                print('No self.trial is available after new_trial()')
            else:
                print('Success')
                hastrial_count += 1
            success_count += 1
        except Exception as e:
            # Exception rather than BaseException, so KeyboardInterrupt and
            # SystemExit still stop the sweep instead of being logged as an
            # environment failure.
            print('Failure at running env: {:s}'.format(env_name))
            print(e)

    print('Success {:d}/{:d} envs'.format(success_count, total_count))
    print('{:d}/{:d} envs have self.trial after new_trial'.format(
        hastrial_count, success_count))
Beispiel #6
0
def test_seeding_all():
    """Check that every environment gives identical rollouts for one seed.

    For each environment id, two pairs of rollouts with seed 0 are compared
    for identical states and rewards. A failing environment is reported on
    stdout instead of raising, and a summary line is printed at the end.
    """
    success_count = 0
    total_count = 0
    for env_name in sorted(ngym.all_envs()):
        total_count += 1

        try:
            # Two rounds, each comparing a fresh pair of seeded rollouts.
            for _ in range(2):
                states1, rews1 = test_seeding(env_name, seed=0)
                states2, rews2 = test_seeding(env_name, seed=0)
                assert (states1 == states2).all(), 'states are not identical'
                assert (rews1 == rews2).all(), 'rewards are not identical'

            success_count += 1
        except Exception as e:
            # Exception rather than BaseException, so KeyboardInterrupt and
            # SystemExit still stop the sweep instead of being logged as an
            # environment failure.
            print('Failure at running env: {:s}'.format(env_name))
            print(e)

    print('Success {:d}/{:d} envs'.format(success_count, total_count))
Beispiel #7
0
def make_env_images():
    """Plot an example run of every environment under the 'images' folder.

    Each environment is built with dt=20 and plotted for two trials using an
    all-zero action shaped like a sampled action.
    """
    for env_id in ngym.all_envs():
        env = gym.make(env_id, dt=20)
        null_action = np.zeros_like(env.action_space.sample())
        out_path = os.path.join('images', env_id + '_examplerun')
        ngym.utils.plot_env(
            env, num_trials=2, def_act=null_action, fname=out_path)
        plt.close()
Beispiel #8
0
def test_speed_dataset_all():
    """Run test_speed_dataset on every environment and report the outcome.

    A failing environment is reported on stdout instead of raising.
    """
    for env_name in sorted(ngym.all_envs()):
        print('Running env: {:s}'.format(env_name))
        try:
            test_speed_dataset(env_name)
            print('Success')
        except Exception as e:
            # Exception rather than BaseException, so KeyboardInterrupt and
            # SystemExit still stop the sweep instead of being logged as an
            # environment failure.
            print('Failure at running env: {:s}'.format(env_name))
            print(e)
Beispiel #9
0
def test_dataset_all():
    """Run test_dataset on every environment and print a summary.

    A failing environment is reported on stdout instead of raising. The
    final lines print the success count and, for comparison, how many
    environments carry the 'supervised' tag.
    """
    success_count = 0
    total_count = 0
    supervised_count = len(ngym.all_envs(tag='supervised'))
    for env_name in sorted(ngym.all_envs()):
        total_count += 1

        print('Running env: {:s}'.format(env_name))
        try:
            test_dataset(env_name)
            print('Success')
            success_count += 1
        except Exception as e:
            # Exception rather than BaseException, so KeyboardInterrupt and
            # SystemExit still stop the sweep instead of being logged as an
            # environment failure.
            print('Failure at running env: {:s}'.format(env_name))
            print(e)

    print('Success {:d}/{:d} envs'.format(success_count, total_count))
    print('Expect {:d} envs to support supervised learning'.format(supervised_count))
Beispiel #10
0
def test_trialenv(env=None, **kwargs):
    """Assert that new_trial() of an environment returns something.

    Args:
        env: environment id (str), a gym.Env instance, or None to use the
            first id returned by ngym.all_envs().
        **kwargs: forwarded to gym.make when env is given as an id.

    Raises:
        ValueError: if env is neither a string nor a gym.Env.
        AssertionError: if new_trial() returns None.
    """
    if env is None:
        env = ngym.all_envs()[0]

    if isinstance(env, str):
        env = gym.make(env, **kwargs)
    elif not isinstance(env, gym.Env):
        raise ValueError('env must be a string or a gym.Env')

    trial_info = env.new_trial()
    assert trial_info is not None, (
        'TrialEnv should return trial info dict ' + str(env))
Beispiel #11
0
def test_all(test_fn):
    """Call test_fn on every environment id and print a summary.

    Args:
        test_fn: callable taking one environment id (str).

    A failing environment is reported on stdout instead of raising.
    """
    success_count = 0
    total_count = 0
    for env_name in sorted(ngym.all_envs()):
        total_count += 1
        # NOTE(review): the message mentions SideBias although test_fn is
        # arbitrary here; left unchanged, confirm the intended wording.
        print('Running env: {:s} Wrapped with SideBias'.format(env_name))
        try:
            test_fn(env_name)
            print('Success')
            success_count += 1
        except Exception as e:
            # Exception rather than BaseException, so KeyboardInterrupt and
            # SystemExit still stop the sweep instead of being logged as an
            # environment failure.
            print('Failure at running env: {:s}'.format(env_name))
            print(e)
        print('')

    print('Success {:d}/{:d} envs'.format(success_count, total_count))
Beispiel #12
0
def test_print_all():
    """Build and print every environment, then print a summary.

    A failing environment is reported on stdout instead of raising.
    """
    success_count = 0
    total_count = 0
    for env_name in sorted(ngym.all_envs()):
        total_count += 1
        print('')
        print('Test printing env: {:s}'.format(env_name))
        try:
            env = gym.make(env_name)
            print(env)
            print('Success')
            success_count += 1
        except Exception as e:
            # Exception rather than BaseException, so KeyboardInterrupt and
            # SystemExit still stop the sweep instead of being logged as an
            # environment failure.
            print('Failure')
            print(e)

    print('Success {:d}/{:d} envs'.format(success_count, total_count))
Beispiel #13
0
def test_run_all(verbose_success=False):
    """Run test_run on every environment and print a summary.

    Args:
        verbose_success: passed to test_run as its verbose flag.

    A failing environment is reported on stdout instead of raising.
    """
    success_count = 0
    total_count = 0
    for env_name in sorted(ngym.all_envs()):
        total_count += 1

        try:
            test_run(env_name, verbose=verbose_success)
            success_count += 1
        except Exception as e:
            # Exception rather than BaseException, so KeyboardInterrupt and
            # SystemExit still stop the sweep instead of being logged as an
            # environment failure.
            print('Failure at running env: {:s}'.format(env_name))
            print(e)

    print('Success {:d}/{:d} envs'.format(success_count, total_count))
Beispiel #14
0
def _joined_tag_links(tags):
    """Return the tags as markdown links joined by ', '."""
    return ', '.join(
        add_link(tag, tag.lower().replace(' ', '-')) for tag in tags)


def write_doc(write_type):
    """Write the markdown overview of all tasks or of all wrappers.

    Args:
        write_type: 'tasks' writes 'envs.md'; 'wrappers' writes
            'wrappers.md'.

    Raises:
        ValueError: if write_type is neither 'tasks' nor 'wrappers'.

    An item whose documentation cannot be produced is reported on stdout
    and left out of the output entirely.
    """
    all_tags = ngym.all_tags()
    if write_type == 'tasks':
        all_items = ngym.all_envs()
        info_fn = info
        fname = 'envs.md'
        all_items_dict = ngym.envs.ALL_ENVS

    elif write_type == 'wrappers':
        all_items = ngym.all_wrappers()
        info_fn = info_wrapper
        fname = 'wrappers.md'
        all_items_dict = ngym.wrappers.ALL_WRAPPERS
    else:
        raise ValueError(
            "write_type must be 'tasks' or 'wrappers', got {!r}".format(
                write_type))

    is_tasks = write_type == 'tasks'

    string = ''
    names = ''
    counter = 0
    link_dict = dict()
    for name in all_items:
        try:
            # Get information about individual task or wrapper
            info_string = info_fn(name)
            info_string = info_string.replace('\n', '  \n')  # for markdown

            if is_tasks:
                env = gym.make(name)
                # Tags are expected to come last in the info text: cut them
                # off and re-add them as links. Only cut when the marker is
                # actually present (find() returns -1 otherwise).
                ind = info_string.find('Tags')
                if ind != -1:
                    info_string = info_string[:ind]
                info_string += 'Tags: '
                info_string += _joined_tag_links(env.metadata.get('tags', []))
                info_string += '\n\n'
                # Using github's automatic link to section titles
                link = type(env).__name__
            else:
                link = name
            link = link.lower().replace(' ', '-')

            # Link to source code
            source_link = all_items_dict[name].split(':')[0].replace('.', '/')
            section = '___\n\n' + info_string
            section += '[Source]({:s})\n\n'.format(
                SOURCE_ROOT + source_link + '.py')
            name_entry = add_link(name, link) + '\n\n'
        except Exception as e:
            # Exception rather than BaseException, so KeyboardInterrupt and
            # SystemExit still abort the run.
            print('Failure in ', name)
            print(e)
            continue

        # Commit the item only once every piece was built, so a failure
        # never leaves a dangling separator or half a section behind.
        string += section
        names += name_entry
        link_dict[name] = link
        counter += 1

    full_string = '### List of {:d} {:s} implemented\n\n'.format(counter,
                                                                 write_type)
    full_string += names

    if is_tasks:
        full_string += '___\n\nTags: '
        full_string += _joined_tag_links(sorted(all_tags)) + '\n\n'

    full_string += string

    if is_tasks:
        string_tag = '___\n\n### Tags ### \n\n'
        for tag in sorted(all_tags):
            string_tag += '### {:s} \n\n'.format(tag)
            for name in ngym.all_envs(tag=tag):
                # Items that failed above have no link; skip them instead
                # of raising KeyError before the file is written.
                if name not in link_dict:
                    continue
                string_tag += add_link(name, link_dict[name])
                string_tag += '\n\n'
        full_string += string_tag

    with open(fname, 'w') as f:
        f.write('* Under development, details subject to change\n\n')
        f.write(full_string)
Beispiel #15
0
"""Test Dataset for supervised learning.

All tests in this file can be run by running in command line
pytest test_data.py
"""

import pytest

import numpy as np

import gym
import neurogym as ngym

# Collect all environments tagged 'supervised' (supervised-learning envs).
SLENVS = ngym.all_envs(tag='supervised')


def _test_env(env):
    """Check that ngym.Dataset yields batches of the requested size for env.

    A Dataset is built with dt=100, batch_size=32 and seq_len=40, and two
    batches are drawn; the first two axes of both inputs and targets must
    equal (seq_len, batch_size).
    """
    seq_len = 40
    batch_size = 32
    dataset = ngym.Dataset(
        env,
        env_kwargs={'dt': 100},
        batch_size=batch_size,
        seq_len=seq_len,
    )
    for _ in range(2):
        batch_inputs, batch_target = dataset()
        assert batch_inputs.shape[0] == seq_len
        assert batch_inputs.shape[1] == batch_size
        assert batch_target.shape[0] == seq_len
        assert batch_target.shape[1] == batch_size
Beispiel #16
0
def main():
    """Regenerate the example images and the envs/tags/wrappers .rst pages.

    Writes 'envs.rst', 'tags.rst' and 'wrappers.rst' to the current working
    directory after calling make_env_images().
    """
    make_env_images()
    string = 'Environments\n'
    string += '===================================\n\n'

    for key, val in sorted(ALL_ENVS.items()):
        string += key + '\n'+'-'*50+'\n'
        string += '.. autoclass:: ' + val.split(':')[0] + '.' + val.split(':')[1] + '\n'
        string += '    :members:\n'
        string += '    :exclude-members: new_trial\n\n'

        env = gym.make(key)
        # Add paper
        paper_name = env.metadata.get('paper_name', '')
        paper_link = env.metadata.get('paper_link', '')
        if paper_name:
            string += '    Reference paper\n'
            paper_name = paper_name.replace('\n', ' ')
            string += '        `{:s} <{:s}>`__\n\n'.format(paper_name, paper_link)

        # Add tags. Joining (instead of appending ', ' and slicing off the
        # last two characters) keeps the 'Tags' header intact when an
        # environment has no tags.
        string += '    Tags\n'
        string += ', '.join(
            '        :ref:`tag-{:s}`'.format(tag)
            for tag in env.metadata.get('tags', []))
        string += '\n\n'

        # Add image (only referenced when the file already exists)
        string += '    Sample run\n'
        image_path = os.path.join('images', key+'_examplerun.png')
        if os.path.isfile(image_path):
            string += ' '*8 + '.. image:: {:s}\n'.format(image_path)
            string += ' '*12 + ':width: 600\n\n'

    with open('envs.rst', 'w') as f:
        f.write(string)

    string = 'Tags\n'
    string += '===================================\n\n'

    all_tags = ngym.all_tags()

    for tag in sorted(all_tags):
        string += '.. _tag-{:s}:\n\n'.format(tag)
        string += tag + '\n--------------------------------\n'
        for env in ngym.all_envs(tag=tag):
            string += '    :class:`{:s} <{:s}>`\n'.format(env, ALL_ENVS[env].replace(':', '.'))
        string += '\n'
    with open('tags.rst', 'w') as f:
        f.write(string)

    string = 'Wrappers\n'
    string += '===================================\n\n'

    for key, val in ALL_WRAPPERS.items():
        string += key + '\n' + '-' * 50 + '\n'
        string += '.. autoclass:: ' + val.split(':')[0] + '.' + val.split(':')[1] + '\n'
        string += '    :members:\n'
        string += '    :exclude-members: new_trial\n\n'

    with open('wrappers.rst', 'w') as f:
        f.write(string)
Beispiel #17
0
"""

import pytest

import numpy as np

import gym
import neurogym as ngym

# psychopy is optional: record whether it can be imported.
try:
    import psychopy
    _have_psychopy = True
except ImportError as e:
    _have_psychopy = False

# Environment ids for this module; the psychopy flag follows the import
# probe above, and contrib and collections are both requested.
ENVS = ngym.all_envs(psychopy=_have_psychopy, contrib=True, collections=True)


def test_run(env=None, num_steps=100, verbose=False, **kwargs):
    """Test if one environment can at least be run."""
    if env is None:
        env = ngym.all_envs()[0]

    if isinstance(env, str):
        env = gym.make(env, **kwargs)
    else:
        if not isinstance(env, gym.Env):
            raise ValueError('env must be a string or a gym.Env')

    env.reset()
    for stp in range(num_steps):