예제 #1
0
def test_create_malformed():
    """Creating an environment from a malformed ID must fail with status 400.

    The client is expected to raise ``gym_http_client.ServerError`` exposing
    the server's error text (``message``) and HTTP status (``status_code``).
    """
    client = gym_http_client.Client(get_remote_base())
    try:
        client.env_create('bad string')
    except gym_http_client.ServerError as e:
        assert 'malformed environment ID' in e.message
        assert e.status_code == 400
    else:
        # Without this branch the test would silently pass when env_create
        # accepts the bad ID instead of raising.
        raise AssertionError('env_create accepted a malformed environment ID')
예제 #2
0
def test_create_malformed():
    """A malformed environment ID is rejected with a 400 ServerError."""
    remote = gym_http_client.Client(get_remote_base())
    try:
        remote.env_create('bad string')
    except gym_http_client.ServerError as err:
        # The server's error text must name the problem explicitly.
        assert 'malformed environment ID' in err.message
        assert err.status_code == 400
    else:
        # Reaching this point means env_create did not raise at all.
        assert False


# @with_server
# def test_missing_API_key():
#    client = gym_http_client.Client(get_remote_base())
#    cur_key = os.environ.get('OPENAI_GYM_API_KEY')
#    os.environ['OPENAI_GYM_API_KEY'] = ''
#    try:
#        print 'UPLOADING'
#        print cur_key
#        client.upload('tmp')
#        print '*****'
#    except requests.HTTPError, e:
#        assert e.response.status_code == 400
#    else:
#        assert False
#    finally:
#        if cur_key:
#            os.environ['OPENAI_GYM_API_KEY'] = cur_key
예제 #3
0
def test_step():
    """One step of CartPole-v0 yields a well-typed (obs, reward, done, info)."""
    client = gym_http_client.Client(get_remote_base())
    instance_id = client.env_create('CartPole-v0')
    observation, reward, done, info = client.env_step(instance_id, 1)
    # The CartPole-v0 observation is expected to have four components.
    assert len(observation) == 4
    # isinstance instead of exact type equality, so that subclasses
    # (e.g. a numpy float64 reward) are still accepted.
    assert isinstance(reward, float)
    assert isinstance(done, bool)
    assert isinstance(info, dict)
예제 #4
0
def test_action_space_contains():
    """CartPole-v0's action space has n == 2: 0 and 1 are members, 2 is not."""
    client = gym_http_client.Client(get_remote_base())
    env = client.env_create('CartPole-v0')
    space = client.env_action_space_info(env)
    assert space['n'] == 2
    # (action, expected membership) pairs, checked in order.
    expectations = ((0, True), (1, True), (2, False))
    for action, expected in expectations:
        assert client.env_action_space_contains(env, action) == expected
예제 #5
0
def test_observation_space_box():
    """CartPole-v0 reports a one-dimensional Box observation space of size 4."""
    client = gym_http_client.Client(get_remote_base())
    env = client.env_create('CartPole-v0')
    info = client.env_observation_space_info(env)
    assert info['name'] == 'Box'
    shape = info['shape']
    assert len(shape) == 1
    assert shape[0] == 4
    # Both bound vectors must match the single dimension's size.
    for bound in ('low', 'high'):
        assert len(info[bound]) == 4
예제 #6
0
def test_reset():
    """env_reset returns the initial observation for different env types."""
    client = gym_http_client.Client(get_remote_base())

    # CartPole-v0: the initial observation has four components.
    cartpole = client.env_create('CartPole-v0')
    first_obs = client.env_reset(cartpole)
    assert len(first_obs) == 4

    # FrozenLake-v0: the initial observation is expected to be 0.
    frozen_lake = client.env_create('FrozenLake-v0')
    first_obs = client.env_reset(frozen_lake)
    assert first_obs == 0
예제 #7
0
def test_monitor_start_close_upload():
    """Record a one-step monitored CartPole-v0 run into 'tmp' and upload it."""
    # OPENAI_GYM_API_KEY must be set; otherwise the test is invalid
    # (presumably because the upload needs it).
    assert os.environ.get('OPENAI_GYM_API_KEY')

    client = gym_http_client.Client(get_remote_base())
    env = client.env_create('CartPole-v0')

    client.env_monitor_start(env, 'tmp', force=True)
    client.env_reset(env)
    client.env_step(env, 1)
    client.env_monitor_close(env)

    client.upload('tmp')
예제 #8
0
def test_observation_space_contains():
    """CartPole-v0's Box space matches partial and full space descriptions."""
    client = gym_http_client.Client(get_remote_base())
    env = client.env_create('CartPole-v0')
    assert client.env_observation_space_info(env)['name'] == 'Box'
    # Each description is checked in turn: name only, shape only, both.
    candidates = [
        {"name": "Box"},
        {"shape": (4, )},
        {"name": "Box", "shape": (4, )},
    ]
    for params in candidates:
        assert client.env_observation_space_contains(env, params)
예제 #9
0
def test_bad_instance_id():
    '''Test all methods that use instance_id with an invalid ID.

    Every call must raise ``requests.HTTPError`` whose response status is 400.
    '''
    client = gym_http_client.Client(get_remote_base())
    # Single-argument methods are referenced directly; lambdas are kept only
    # where extra arguments must be bound.
    try_these = [
        client.env_reset,
        lambda x: client.env_step(x, 1),
        client.env_action_space_info,
        client.env_observation_space_info,
        lambda x: client.env_monitor_start(x, directory='tmp', force=True),
        client.env_monitor_close,
    ]
    for call in try_these:
        try:
            call('bad_id')
        except requests.HTTPError as e:
            assert e.response.status_code == 400
        else:
            # Report which call wrongly accepted the bad ID.
            raise AssertionError(
                '%r accepted an invalid instance_id' % (call,))
예제 #10
0
def test_action_space_discrete():
    """CartPole-v0 reports a Discrete action space with two actions."""
    client = gym_http_client.Client(get_remote_base())
    info = client.env_action_space_info(client.env_create('CartPole-v0'))
    assert info['name'] == 'Discrete'
    assert info['n'] == 2
예제 #11
0
def test_create_valid():
    """A newly created environment appears in env_list_all()."""
    client = gym_http_client.Client(get_remote_base())
    new_id = client.env_create('CartPole-v0')
    listed = client.env_list_all()
    assert new_id in listed
예제 #12
0
def test_create_malformed():
    """Creating an environment from a malformed ID must fail with status 400.

    The client is expected to surface the failure as ``requests.HTTPError``.
    """
    client = gym_http_client.Client(get_remote_base())
    try:
        client.env_create('bad string')
    except requests.HTTPError as e:
        assert e.response.status_code == 400
    else:
        # Without this branch the test would silently pass when env_create
        # accepts the bad ID instead of raising.
        raise AssertionError('env_create accepted a malformed environment ID')
예제 #13
0
def test_action_space_sample():
    """Actions sampled from CartPole-v0's action space lie in [0, 2)."""
    client = gym_http_client.Client(get_remote_base())
    sampled = client.env_action_space_sample(client.env_create('CartPole-v0'))
    assert 0 <= sampled and sampled < 2
예제 #14
0
def test_create_destroy():
    """An environment is listed after creation and no longer after env_close."""
    client = gym_http_client.Client(get_remote_base())
    env = client.env_create('CartPole-v0')

    def is_listed():
        # Re-query the server each time so the close is observed.
        return env in client.env_list_all()

    assert is_listed()
    client.env_close(env)
    assert not is_listed()
예제 #15
0
 def __init__(self, env_id, remote_base='http://127.0.0.1:5000'):
     """Create a gym_http_client.Client and an environment instance on it.

     Args:
         env_id: ID of the environment to create (e.g. 'CartPole-v0').
         remote_base: Base URL passed to gym_http_client.Client. Defaults
             to the previously hard-coded local address, so existing
             callers are unaffected.
     """
     self.client = gym_http_client.Client(remote_base)
     self.instance_id = self.client.env_create(env_id)