def test_state_to_csv_from_csv(env: LlvmEnv):
    """Round-trip an environment state through ``to_csv`` / ``from_csv``.

    The episode reward is set by hand so the reward field can be checked
    after deserialization, and the restored state must compare equal to
    the original one.
    """
    env.reset(benchmark="cBench-v0/crc32")
    env.episode_reward = 10

    original_state = env.state
    assert original_state.reward == 10

    csv_text = env.state.to_csv()
    restored_state = CompilerEnvState.from_csv(csv_text)

    assert restored_state.reward == 10
    assert original_state == restored_state
def test_connection_dies_default_reward_negated(env: LlvmEnv):
    """Check the negated default reward produced after the service dies.

    With ``default_negates_returns`` enabled, stepping a dead service is
    expected to give -7.5. That is ``default_value`` (2.5) minus the
    ``episode_reward`` (10) set below.
    """
    env.reward_space = "IrInstructionCount"
    env.reset(benchmark="cBench-v0/crc32")

    env.reward_space.default_negates_returns = True
    env.reward_space.default_value = 2.5
    env.episode_reward = 10

    # Kill the service. For a ManagedConnection this raises ServiceError,
    # because the session started by env.reset() above has not been ended.
    # An UnmanagedConnection does not raise. Tolerate both so the test
    # exercises the default-reward logic rather than connection teardown.
    try:
        env.service.close()
    except ServiceError as e:
        assert "Service exited with returncode " in str(e)

    _, reward, done, _ = env.step(0)
    assert done
    assert reward == -7.5  # negates reward.
def test_state_serialize_deserialize_equality(env: LlvmEnv):
    """Write a state with CompilerEnvStateWriter and read it back unchanged.

    The first state yielded by the reader must carry the same reward and
    compare equal to the state that was written.
    """
    env.reset(benchmark="cbench-v1/crc32")
    env.episode_reward = 10

    expected = env.state
    assert expected.reward == 10

    stream = StringIO()
    writer = CompilerEnvStateWriter(stream)
    writer.write_state(expected)

    # Rewind to the start of the buffer before reading back what was written.
    stream.seek(0)
    reader = CompilerEnvStateReader(stream)
    actual = next(iter(reader))

    assert actual.reward == 10
    assert expected == actual
def test_connection_dies_default_reward(env: LlvmEnv):
    """After the service dies, a step returns the un-negated default reward.

    With ``default_negates_returns`` disabled, the step is expected to
    return ``default_value`` (2.5), even though ``episode_reward`` was set
    to 10.
    """
    env.reward_space = "IrInstructionCount"
    env.reset(benchmark="cbench-v1/crc32")

    env.reward_space.default_negates_returns = False
    env.reward_space.default_value = 2.5
    env.episode_reward = 10

    # Closing the service under a ManagedConnection raises ServiceError,
    # because the session opened by env.reset() above has not been ended.
    # An UnmanagedConnection closes without raising.
    try:
        env.service.close()
    except ServiceError as err:
        assert "Service exited with returncode " in str(err)

    step_result = env.step(0)
    _, reward, done, _ = step_result
    assert done
    assert reward == 2.5