Example #1
    def test_sample_collector_by_number_success(self, mock_env_reset, mock_env_step, mock_envs_reset,
                                                mock_envs_step) -> None:
        dummy_env = Env()
        dummy_env.observation_space = Box(-1, 1, [STATE_DIM])
        dummy_env.action_space = Box(-1, 1, [ACTION_DIM])
        mock_env_reset.return_value = self.dummy_state
        mock_env_step.return_value = (self.dummy_state, self.dummy_reward, self.dummy_done, self.dummy_info)
        dummy_env.reset = mock_env_reset
        dummy_env.step = mock_env_step

        dummy_envs = DummyVectorEnv(N_ENVS, STATE_DIM, ACTION_DIM)
        mock_envs_reset.return_value = self.dummy_states
        mock_envs_step.return_value = (self.dummy_next_states, self.dummy_rewards, self.dummy_dones, {})
        dummy_envs.reset = mock_envs_reset
        dummy_envs.step = mock_envs_step

        dummy_env_container = EnvContainer(dummy_env, dummy_envs)
        mock_envs_reset.assert_called_once_with()  # __init__ of EnvContainer calls reset

        actor: nn.Module = ProbMLPConstantLogStd(STATE_DIM, ACTION_DIM, HIDDEN_DIMS, ACTIVATION, FINAL_LAYER_ACTIVATION, LOG_STD)
        scaler: nn.Module = DummyNet()
        tanh: nn.Module = nn.Tanh()
        action_getter: ActionGetter = ActionGetterModule(actor, scaler)
        sample_collector: SampleCollector = SampleCollectorV0(dummy_env_container, action_getter, N_ENVS * 10, 1)

        array_dict: ArrayDict = sample_collector.collect_samples_by_number()
        self.assertEqual(mock_envs_reset.call_count, 2)
        self.assertEqual(mock_envs_step.call_count, 10)

        collected_states = array_dict.get(ArrayKey.states)
        self.assertTupleEqual(collected_states.shape, (N_ENVS * 10, STATE_DIM))
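
The mocks in this test arrive as method parameters, which implies a stack of patch decorators that the snippet omits (the later examples follow the same pattern). A minimal, self-contained sketch of that wiring is shown below; the patch targets, the STATE_DIM/ACTION_DIM values, and the test name are assumptions rather than code from the original project.

import unittest
from unittest import mock

import numpy as np
from gym import Env
from gym.spaces import Box

STATE_DIM = 4    # placeholder value; the real constant is defined elsewhere
ACTION_DIM = 2   # placeholder value; the real constant is defined elsewhere


class MockedEnvTest(unittest.TestCase):
    # patch decorators are applied bottom-up, so the reset mock is the first parameter
    @mock.patch.object(Env, "step")
    @mock.patch.object(Env, "reset")
    def test_reset_and_step_are_mocked(self, mock_env_reset, mock_env_step):
        dummy_env = Env()
        dummy_env.observation_space = Box(-1, 1, [STATE_DIM])
        dummy_env.action_space = Box(-1, 1, [ACTION_DIM])
        mock_env_reset.return_value = np.zeros(STATE_DIM)
        mock_env_step.return_value = (np.zeros(STATE_DIM), 0.0, False, {})

        dummy_env.reset()
        mock_env_reset.assert_called_once_with()

        dummy_action = np.random.random(ACTION_DIM)
        dummy_env.step(dummy_action)
        mock_env_step.assert_called_once_with(dummy_action)


if __name__ == "__main__":
    unittest.main()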
Example #2
    def test_sample_action_success(self, sample_action, mock_env_reset, mock_env_step, mock_envs_reset,
                                   mock_envs_step):
        # mock
        action_getter: ActionGetter = ActionGetter()
        sample_action.return_value = self.dummy_actions, self.dummy_log_probs
        action_getter.sample_action = sample_action

        dummy_env = Env()
        dummy_env.observation_space = Box(-1, 1, [STATE_DIM])
        dummy_env.action_space = Box(-1, 1, [ACTION_DIM])
        mock_env_reset.return_value = self.dummy_state
        mock_env_step.return_value = (self.dummy_state, self.dummy_reward, self.dummy_done, self.dummy_info)
        dummy_env.reset = mock_env_reset
        dummy_env.step = mock_env_step

        dummy_envs = DummyVectorEnv(N_ENVS, STATE_DIM, ACTION_DIM)
        mock_envs_reset.return_value = self.dummy_states
        mock_envs_step.return_value = (self.dummy_next_states, self.dummy_rewards, self.dummy_dones, {})
        dummy_envs.reset = mock_envs_reset
        dummy_envs.step = mock_envs_step

        dummy_env_container: EnvContainer = EnvContainer(dummy_env, dummy_envs)

        # run
        sample_collector: SampleCollector = SampleCollectorV0(dummy_env_container, action_getter, N_ENVS * 2, 1)
        sample_collector.collect_samples_by_number()

        # assert
        self.assertEqual(sample_action.call_count, 2)
        np.testing.assert_array_equal(sample_action.call_args_list[0][0][0], self.dummy_states)
        np.testing.assert_array_equal(sample_action.call_args_list[1][0][0], self.dummy_next_states)
Example #3
    def __init__(self, nb_actions, nb_space_dims, observation_low, observation_high, window_length, log_interval, filename):
        self.nb_actions = nb_actions
        self.nb_space_dims = nb_space_dims
        self.observation_low = observation_low
        self.observation_high = observation_high
        self.window_length = window_length  # number of consecutive observations fed to the neural network as one input
        self.log_interval = log_interval
        self.filename = filename

        self.env = Env()
        self.env.action_space = spaces.Discrete(self.nb_actions)
        self.env.observation_space = spaces.Box(low=np.ones(self.nb_space_dims) * self.observation_low,
                                                high=np.ones(self.nb_space_dims) * self.observation_high,
                                                dtype=np.float32)

        self.logging = []

        self.model = self.generate_model_fully_connected()
        self.policy = self.generate_policy()
        self.memory = self.generate_memory()
        self.agent = self.generate_agent()

        self.observation_last = None
        self.action = None
        self.observation = None
        self.reward = None
        self.done = None
        self.info = None
        self.metrics = None

        self.agent.training = True
        self.agent.step = 0
        self.agent.episode = 0
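
The constructor above relies on four generate_* helpers that the snippet does not include; the window_length / policy / memory vocabulary suggests a keras-rl style DQN setup. The sketch below shows one plausible shape for those helpers, assuming keras-rl2 with tf.keras and a DQNAgent; the class name, layer sizes, and hyperparameters are illustrative placeholders, not values from the original code.

from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import EpsGreedyQPolicy


class AgentWrapper:  # hypothetical name for the class the __init__ above belongs to
    def generate_model_fully_connected(self):
        # the window of consecutive observations is flattened into a single dense input
        return Sequential([
            Flatten(input_shape=(self.window_length, self.nb_space_dims)),
            Dense(64, activation="relu"),
            Dense(64, activation="relu"),
            Dense(self.nb_actions, activation="linear"),
        ])

    def generate_policy(self):
        return EpsGreedyQPolicy(eps=0.1)

    def generate_memory(self):
        return SequentialMemory(limit=50000, window_length=self.window_length)

    def generate_agent(self):
        agent = DQNAgent(model=self.model, nb_actions=self.nb_actions,
                         memory=self.memory, policy=self.policy,
                         nb_steps_warmup=100, target_model_update=1e-2)
        agent.compile(Adam(learning_rate=1e-3), metrics=["mae"])
        return agent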
Example #4
    def test_dummy_env_return_values(self, mock_env_reset, mock_env_step) -> None:
        dummy_env = Env()
        dummy_env.observation_space = Box(-1, 1, [STATE_DIM])
        dummy_env.action_space = Box(-1, 1, [ACTION_DIM])
        mock_env_reset.return_value = self.dummy_state
        mock_env_step.return_value = (self.dummy_state, self.dummy_reward, self.dummy_done, self.dummy_info)
        dummy_env.reset = mock_env_reset
        dummy_env.step = mock_env_step

        dummy_env.reset()
        mock_env_reset.assert_called_with()
        dummy_action = np.random.random(ACTION_DIM)
        dummy_env.step(dummy_action)
        mock_env_step.assert_called_with(dummy_action)
Example #5
    def setup_playground(self, mocker):
        """Setup of used fixtures"""

        self.observation = 'observation'
        self.next_observation = 'next_observation'
        self.reward = 1.2
        self.handled_reward = 1.7
        self.done = False
        self.info = {'env_info': 'env_info'}

        mocker.patch('gym.Env.render')
        mocker.patch('gym.Env.step',
                     return_value=(self.next_observation, self.reward,
                                   self.done, self.info))
        self.env = Env()

        mocker.patch('learnrl.agent.Agent.remember')
        mocker.patch('learnrl.agent.Agent.learn')
        self.action = 3
        mocker.patch('learnrl.agent.Agent.act', return_value=self.action)
        self.n_agents = 5
        self.agents = [Agent() for _ in range(self.n_agents)]

        self.agent_id = 0
        mocker.patch('learnrl.playground.Playground._get_next_agent',
                     return_value=(self.agents[self.agent_id], self.agent_id))

        def handler_mocker(cls, reward, done, experience, reward_handler,
                           done_handler, logs):
            experience['reward'] = self.handled_reward
            logs['handled_reward'] = self.handled_reward

        mocker.patch('learnrl.playground.Playground._call_handlers',
                     handler_mocker)
        self.playground = Playground(self.env, self.agents)

        self.previous = [{
            'observation': None,
            'action': None,
            'reward': None,
            'done': None,
            'info': None
        } for _ in range(self.n_agents)]
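
setup_playground receives pytest-mock's mocker fixture, so in the original module it is presumably declared as an autouse fixture on a test class. A minimal sketch of that wiring under that assumption follows; the class name and the example test are illustrative only.

import pytest
from gym import Env
from learnrl.agent import Agent
from learnrl.playground import Playground


class TestPlaygroundRun:
    @pytest.fixture(autouse=True)
    def setup_playground(self, mocker):
        """Setup of used fixtures (trimmed-down version of the body above)."""
        self.env = Env()
        self.n_agents = 5
        self.agents = [Agent() for _ in range(self.n_agents)]
        mocker.patch("learnrl.agent.Agent.act", return_value=3)
        self.playground = Playground(self.env, self.agents)

    def test_setup_builds_playground(self):
        # the autouse fixture has already run, so its attributes are available here
        assert isinstance(self.playground, Playground)
        assert len(self.agents) == self.n_agents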
Example #6
    def test_no_turnenv(self):
        """ should return the first agent if env is not a TurnEnv. """
        playground = Playground(Env(), self.agents)
        _, agent_id = playground._get_next_agent('observation')
        check.equal(agent_id, 0)
Example #7
    def setup_playground(self):
        """Setup of used fixtures"""
        self.env = Env()
        self.n_agents = 5
        self.agents = [Agent() for _ in range(self.n_agents)]
        self.playground = Playground(self.env, self.agents)