Esempio n. 1
0
class TestCartPoleEnv(unittest.TestCase):
    """Unit tests for the return types, dimensions and done signal of CartPoleEnv.

    Most tests are skipped with the reason 'Skipping due to cartpole.out'.
    """

    def setUp(self):
        """Create a fresh environment for every test."""
        self.env = CartPoleEnv()

    @unittest.skip('Skipping due to cartpole.out')
    def test_reset_return_type(self):
        """Check if a list is returned"""
        self.assertIsInstance(self.env.reset(), list)

    @unittest.skip('Skipping due to cartpole.out')
    def test_step_return_type(self):
        """Check if a 3-tuple of list, float, and bool is returned"""
        env = CartPoleEnv()
        env.reset()
        obs, reward, done = env.step(action=-1)

        self.assertIsInstance(obs, list)
        self.assertIsInstance(reward, float)
        self.assertIsInstance(done, bool)

    @unittest.skip('Skipping due to cartpole.out')
    def test_obs_dim_return_type(self):
        """Check if an integer is returned"""
        self.assertIsInstance(self.env.obs_dim(), int)

    @unittest.skip('Skipping due to cartpole.out')
    def test_reset_return_dims(self):
        """Check if a 4-dimensional list is returned"""
        self.assertEqual(len(self.env.reset()), 4)

    @unittest.skip('Skipping due to cartpole.out')
    def test_obs_dim_return_value(self):
        """Check if 4 is returned"""
        env = CartPoleEnv()
        env.reset()
        self.assertEqual(env.obs_dim(), 4)

    def test_step_wrong_input(self):
        """Check if assertion is raised with wrong input"""
        with self.assertRaises(AssertionError):
            self.env.step(43892.42)

    @unittest.skip('Skipping due to cartpole.out')
    def test_done_signal_per_episode(self):
        """Check if done signal is triggered at the end of the episode"""
        env = CartPoleEnv()
        env.reset()

        # NOTE(review): the episode length of 500 is inferred from the
        # original `t != 499` check -- confirm against CartPoleEnv.
        # The previous version only ran 10 steps, asserted `done` was False
        # on every one of them and then asserted it was True, so the test
        # could never pass.
        episode_length = 500
        done = False
        for t in range(episode_length):
            _, _, done = env.step(action=-1)
            if t != episode_length - 1:
                # Must be false before the final step of the episode
                self.assertFalse(done)

        # Must be true at the end of the episode
        self.assertTrue(done)
Esempio n. 2
0
class CartPole:
    """Score a linear threshold policy on a CartPoleEnv with custom gravity.

    A candidate solution ``x`` holds five numbers. They are rescaled with
    ``x * 10 - 5`` (so 0 maps to -5 and 1 maps to 5); the first four become
    the observation weights and the fifth the bias.
    """

    # Maximum number of environment steps taken per fitness evaluation.
    MAX_STEPS = 200

    def __init__(self, gravity):
        """Create the environment and override its ``gravity`` attribute.

        Args:
            gravity: value assigned to ``self.env.gravity``.
        """
        self.dim = 5
        self.env = CartPoleEnv()
        self.env.gravity = gravity

    def sigmoid(self, x):
        """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``."""
        return 1 / (1 + np.exp(-x))

    def action(self, observation, x):
        """Return the action (0 or 1) chosen by the policy encoded in ``x``.

        Args:
            observation: sequence of four numbers from the environment.
            x: sequence of five numbers (array, list or tuple); see the
                class docstring for the rescaling.

        Returns:
            1 if the sigmoid of the weighted observation plus bias exceeds
            0.5, otherwise 0.
        """
        # np.asarray lets plain lists/tuples work; `list * 10 - 5` would
        # otherwise repeat the list and then fail on the subtraction.
        params = np.asarray(x) * 10 - 5
        weights = params[:4]
        bias = params[4]
        activation = np.sum(np.asarray(observation) * weights) + bias
        return int(self.sigmoid(activation) > 0.5)

    def fitness(self, x):
        """Run one episode with the policy ``x`` and return its negated reward.

        The episode is capped at ``MAX_STEPS`` steps and stops early when the
        environment reports ``done``. The sum of rewards is negated,
        presumably so that a minimiser can be used -- verify against caller.
        """
        total_reward = 0
        observation = self.env.reset()
        for _ in range(self.MAX_STEPS):
            chosen = self.action(observation, x)
            observation, reward, done, _ = self.env.step(chosen)
            total_reward += reward
            if done:
                break
        return -total_reward

    def __del__(self):
        """Close the environment, if one was successfully created."""
        # __del__ also runs when __init__ raised before `self.env` was set;
        # guard the lookup so cleanup does not raise AttributeError.
        env = getattr(self, "env", None)
        if env is not None:
            env.close()
def evaluate_agent(agent, episodes, return_trajectories=False, seed=1):
    """Run the agent greedily in a freshly seeded CartPoleEnv.

    Args:
        agent: callable taking a state and returning ``(policy, _)``, where
            ``policy.logits`` is a tensor; the arg-max over its last
            dimension is used as the action.
        episodes: number of episodes to run.
        return_trajectories: when true, also collect per-episode data.
        seed: value passed to ``env.seed``.

    Returns:
        A list with the summed reward of each episode, or the tuple
        ``(returns, trajectories)`` when ``return_trajectories`` is true.
        Each trajectory is a dict with keys ``states``, ``actions``,
        ``rewards`` and ``terminals``.
    """
    env = CartPoleEnv()
    env.seed(seed)

    returns, trajectories = [], []
    for _ in range(episodes):
        states, actions, rewards = [], [], []
        state, terminal = env.reset(), False
        while not terminal:
            with torch.no_grad():
                policy, _ = agent(state)
                action = policy.logits.argmax(dim=-1)  # Pick action greedily
                next_state, reward, terminal = env.step(action)

                if return_trajectories:
                    # Record the state the action was chosen in, not its
                    # successor, so states[i] and actions[i] stay aligned.
                    states.append(state)
                    actions.append(action)
                rewards.append(reward)
                state = next_state
        returns.append(sum(rewards))
        if return_trajectories:
            # Collect trajectory data (including terminal signal, which may be needed for offline learning)
            # Only the last transition is terminal: 0 for every earlier
            # step and 1 for the final one.
            terminals = torch.cat(
                [torch.zeros(len(rewards) - 1),
                 torch.ones(1)])
            # torch.cat assumes states and actions are tensors with at
            # least one dimension -- TODO confirm against CartPoleEnv.
            trajectories.append(
                dict(states=torch.cat(states),
                     actions=torch.cat(actions),
                     rewards=torch.tensor(rewards, dtype=torch.float32),
                     terminals=terminals))

    return (returns, trajectories) if return_trajectories else returns
Esempio n. 4
0
    def test_step_return_type(self):
        """step() must yield an (observation, reward, done) triple of list, float and bool."""
        cartpole = CartPoleEnv()
        cartpole.reset()
        step_result = cartpole.step(action=-1)
        obs, reward, done = step_result

        # Check each component against its expected type, in order.
        expectations = ((obs, list), (reward, float), (done, bool))
        for value, expected_type in expectations:
            self.assertIsInstance(value, expected_type)
Esempio n. 5
0
    def test_done_signal_per_episode(self):
        """Check if done signal is triggered at the end of the episode"""
        env = CartPoleEnv()
        env.reset()

        # NOTE(review): the episode length of 500 is inferred from the
        # original `t != 499` check -- confirm against CartPoleEnv.
        # The previous version only ran 10 steps, asserted `done` was False
        # on every one of them and then asserted it was True, so the test
        # could never pass.
        episode_length = 500
        done = False
        for t in range(episode_length):
            _, _, done = env.step(action=-1)
            if t != episode_length - 1:
                # Must be false before the final step of the episode
                self.assertFalse(done)

        # Must be true at the end of the episode
        self.assertTrue(done)
Esempio n. 6
0
), False, 0, [], deque(maxlen=args.imitation_replay_size)
pbar = tqdm(range(1, args.steps + 1), unit_scale=1, smoothing=0)
for step in pbar:
    if args.imitation == 'BC':
        # Perform behavioural cloning updates offline
        if step == 1:
            for _ in tqdm(range(args.imitation_epochs), leave=False):
                behavioural_cloning_update(agent, expert_trajectories,
                                           agent_optimiser,
                                           args.imitation_batch_size)
    else:
        # Collect set of trajectories by running policy π in the environment
        policy, value = agent(state)
        action = policy.sample()
        log_prob_action, entropy = policy.log_prob(action), policy.entropy()
        next_state, reward, terminal = env.step(action)
        episode_return += reward
        trajectories.append(
            dict(states=state,
                 actions=action,
                 rewards=torch.tensor([reward], dtype=torch.float32),
                 terminals=torch.tensor([terminal], dtype=torch.float32),
                 log_prob_actions=log_prob_action,
                 old_log_prob_actions=log_prob_action.detach(),
                 values=value,
                 entropies=entropy))
        state = next_state

        if terminal:
            # Store metrics and reset environment
            metrics['train_steps'].append(step)