def test_credit_initial_state():
    """Test initial state for Credit sim via gym.make.

    Builds a default Credit env to obtain a full initial state, subsamples its
    first 100 rows into a new ``State``, and checks that an env created with
    that ``initial_state`` exposes it, returns it from ``reset()``, produces
    observations inside its observation space while stepping, and returns the
    same state again after a second ``reset()``.
    """
    base_env = envs.make("Credit-v0")
    original_state = base_env.reset()
    # Only needed as a source of features/labels; release it right away.
    base_env.close()
    features, labels = original_state["features"], original_state["labels"]
    subsampled_state = wn.credit.State(features[:100], labels[:100])

    env = envs.make("Credit-v0", initial_state=subsampled_state)
    assert np.allclose(env.initial_state.features, subsampled_state.features)
    assert np.allclose(env.initial_state.labels, subsampled_state.labels)

    ob = env.reset()
    assert np.allclose(ob["features"], subsampled_state.features)
    assert np.allclose(ob["labels"], subsampled_state.labels)

    for _ in range(10):
        # Report shapes only when the containment check fails.
        assert env.observation_space.contains(ob), (
            env.observation_space,
            ob["features"].shape,
            ob["labels"].shape,
        )
        a = env.action_space.sample()
        assert env.action_space.contains(a)
        (ob, _reward, done, _info) = env.step(a)
        if done:
            break

    # A second reset must restore the custom initial state.
    ob = env.reset()
    assert np.allclose(ob["features"], subsampled_state.features)
    assert np.allclose(ob["labels"], subsampled_state.labels)
    env.close()
def test_config(spec):
    """Test setting simulator config via gym.make.

    Builds one env with its default config and a second env whose config is a
    copy with ``delta_t`` overridden. Checks that the override reaches the new
    env and that the first env's ``delta_t`` is unchanged.
    """
    base_env = envs.make(spec)
    base_config = base_env.config
    # Snapshot the scalar: base_config *is* base_env.config, so comparing the
    # two attributes later would always be true and could never catch a
    # mutation caused by building the second env.
    base_delta_t = base_config.delta_t
    new_config = dataclasses.replace(base_config, delta_t=-100)
    new_env = envs.make(spec, config=new_config)
    assert base_env.config.delta_t == base_delta_t
    assert new_env.config.delta_t == new_config.delta_t
    base_env.close()
    new_env.close()
def test_credit_config():
    """Check that a custom Credit config passed through gym.make is honoured.

    The env built without a config must keep the default
    ``changeable_features``. The env built with an explicit
    ``wn.credit.Config`` must expose the overridden ones.
    """
    default_changeable = wn.credit.Config().changeable_features
    custom_changeable = np.array([0, 1, 2])

    default_env = envs.make("Credit-v0")
    custom_env = envs.make(
        "Credit-v0",
        config=wn.credit.Config(changeable_features=custom_changeable),
    )

    assert np.allclose(default_env.config.changeable_features, default_changeable)
    assert np.allclose(custom_env.config.changeable_features, custom_changeable)

    for opened in (default_env, custom_env):
        opened.close()
def test_make_with_kwargs():
    """Keyword arguments given to make() override the env's constructor args.

    ``arg1`` is not passed, so it keeps its value of ``"arg1"``. ``arg2`` and
    ``arg3`` are overridden.
    """
    env = envs.make("test.ArgumentEnv-v0", arg2="override_arg2", arg3="override_arg3")
    assert env.spec.id == "test.ArgumentEnv-v0"
    assert isinstance(env.unwrapped, ArgumentEnv)
    assert env.arg1 == "arg1"
    assert env.arg2 == "override_arg2"
    assert env.arg3 == "override_arg3"
    env.close()
def test_random_rollout(spec):
    """Roll out a uniformly random policy in three freshly made envs.

    For up to 10 steps per env, checks that every observation lies in the
    observation space and every sampled action lies in the action space. Stops
    early when the env reports ``done``.
    """
    for _trial in range(3):
        # Build each env only when it is needed, rather than all up front.
        env = envs.make(spec)
        ob = env.reset()
        for _step in range(10):
            assert env.observation_space.contains(ob)
            print("Observation: ", ob)
            # Sample directly instead of via a closure defined in the loop,
            # which would late-bind ``env`` (ruff B023).
            a = env.action_space.sample()
            assert env.action_space.contains(a)
            (ob, _reward, done, _info) = env.step(a)
            if done:
                break
        env.close()
def test_make():
    """Check that make("HIV-v0") returns an env built by envs.ODEEnvBuilder."""
    env = envs.make("HIV-v0")
    assert env.spec.id == "HIV-v0"
    assert isinstance(env.unwrapped, envs.ODEEnvBuilder)
    env.close()