def test_cpuagent_add_state_correctness(states, grid, expected_results):
    """Check that ``add_state`` leaves the expected number of states.

    The agent is seeded with ``states``, a ``State`` built from ``grid`` is
    added, and the test then verifies both the resulting state count and that
    the last stored state carries ``grid``.

    ``states``, ``grid`` and ``expected_results`` are injected by the test
    harness (presumably a parametrization not visible in this chunk).
    """
    agent = CPUAgent()
    agent.states = states

    candidate = State(grid)
    agent.add_state(candidate)

    assert len(agent.states) == expected_results
    # The most recently stored state is expected to hold the supplied grid.
    last_state = agent.states[-1]
    assert_array_equal(last_state.grid, grid)
def test_cpuagent_get_state_correctness():
    """Check that ``get_state`` returns the state matching the queried grid.

    The agent holds two states (an all-zero grid and one with a trailing 1);
    querying with an all-zero grid must yield a state whose ``grid`` is all
    zeros.
    """
    agent = CPUAgent()
    agent.states = [
        State(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
        State(np.array([0, 0, 0, 0, 0, 0, 0, 0, 1])),
    ]

    # Fresh arrays are used for the query and the expectation so neither
    # shares storage with the grids held by the agent.
    query = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])
    found = agent.get_state(query)
    result = found.grid

    expected = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])
    assert_array_equal(expected, result)
def test_cpuagent_update_state_correctness():
    """Check that ``update_state`` stores the replacement state's values.

    The agent holds two states; a new ``State`` with the same grid as the
    second one has its ``next_states_values`` doubled in place and is passed
    to ``update_state``.  Afterwards the agent's second state must expose the
    same ``next_states_values`` as the replacement.
    """
    agent = CPUAgent()
    agent.states = [
        State(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
        State(np.array([1, 1, 1, 1, 1, 1, 1, 1, 0])),
    ]

    replacement = State(np.array([1, 1, 1, 1, 1, 1, 1, 1, 0]))
    # NOTE(review): if a freshly built State starts with all-zero
    # next_states_values, doubling leaves them unchanged and this test could
    # not tell a working update_state from a no-op -- confirm the initial
    # values and consider a modification that is guaranteed to change them.
    replacement.next_states_values *= 2

    agent.update_state(replacement)

    stored_values = agent.states[1].next_states_values
    assert_array_equal(stored_values, replacement.next_states_values)
def test_cpuagent_serializes_correctly():
    """Check the structure and contents of ``CPUAgent.serialize()``.

    With two states stored, the serialized mapping must contain a ``states``
    entry of length two whose items expose ``grid``, ``next_states_values``
    and ``next_states_transitions`` matching the expectations below.
    """
    agent = CPUAgent()
    agent.states = [
        State(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
        State(np.array([1, 1, 1, 1, 1, 1, 1, 1, 0])),
    ]

    serialized_agent = agent.serialize()

    assert "states" in serialized_agent
    assert len(serialized_agent["states"]) == 2

    # Expected per-state values for each serialized field, in state order.
    # These arrays are built independently of the ones handed to State.
    expected_by_field = {
        "grid": (
            np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]),
            np.array([1, 1, 1, 1, 1, 1, 1, 1, 0]),
        ),
        "next_states_values": (
            np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]),
            np.array([0]),
        ),
        "next_states_transitions": (
            np.arange(9),
            np.array([8]),
        ),
    }

    for field, expected_per_state in expected_by_field.items():
        for index, expected in enumerate(expected_per_state):
            assert_array_equal(serialized_agent["states"][index][field], expected)
def test_cpuagent_has_state_correctness(states, grid, expected_results):
    """Check that ``has_state`` reports the expected membership result.

    The agent is seeded with ``states`` and asked whether it has a ``State``
    built from ``grid``; the answer must equal ``expected_results``.

    ``states``, ``grid`` and ``expected_results`` are injected by the test
    harness (presumably a parametrization not visible in this chunk).
    """
    agent = CPUAgent()
    agent.states = states

    probe = State(grid)
    outcome = agent.has_state(probe)

    assert outcome == expected_results