Example #1
    def test_get_action(self):
        # Build a CPU-only agent (gpu=-1) sized to the test environment.
        agent = DQN(
            state_shape=self.env.observation_space.shape,
            action_dim=self.env.action_space.n,
            gpu=-1)
        state = self.env.reset()
        # Exercise both exploratory (test=False) and greedy (test=True)
        # action selection on a fresh observation.
        agent.get_action(state, test=False)
        agent.get_action(state, test=True)
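
Both examples reference a `self.env` test fixture that is not shown. A minimal sketch of how that fixture might be set up, assuming the `DQN` class comes from the tf2rl library (whose constructor matches the signatures above) and using CartPole-v0 as a stand-in discrete-action environment; the import path and environment choice are assumptions, not part of the original:

import unittest
import gym
from tf2rl.algos.dqn import DQN  # assumed import path

class TestDQN(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Hypothetical fixture: any gym env with a discrete action space works.
        cls.env = gym.make("CartPole-v0")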
Example #2
    def test_train(self):
        import numpy as np
        from cpprb import ReplayBuffer

        # CPU-only agent with a small replay capacity to keep the test fast.
        agent = DQN(
            state_shape=self.env.observation_space.shape,
            action_dim=self.env.action_space.n,
            memory_capacity=100,
            gpu=-1)
        # Note: the obs_dim/act_dim constructor matches older cpprb
        # releases; recent versions configure buffers via env_dict instead.
        replay_buffer = ReplayBuffer(
            obs_dim=self.env.observation_space.shape,
            act_dim=1,
            size=agent.memory_capacity)

        # Collect 100 environment transitions into the buffer,
        # resetting whenever an episode terminates.
        obs = self.env.reset()
        for _ in range(100):
            action = agent.get_action(obs)
            next_obs, reward, done, _ = self.env.step(action)
            replay_buffer.add(obs=obs, act=action, next_obs=next_obs, rew=reward, done=done)
            if done:
                next_obs = self.env.reset()
            obs = next_obs

        # Run 100 update steps on sampled minibatches; done flags are
        # cast to float so they can mask the bootstrapped targets.
        for _ in range(100):
            samples = replay_buffer.sample(agent.batch_size)
            agent.train(samples["obs"], samples["act"], samples["next_obs"],
                        samples["rew"], np.array(samples["done"], dtype=np.float64))