Example #1
    def setUp(self) -> None:
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)

        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = VanillaPolicyGradient.add_model_specific_args(parent_parser)
        args_list = [
            "--env", "CartPole-v0",
            "--batch_size", "32"
        ]
        self.hparams = parent_parser.parse_args(args_list)
        self.model = VanillaPolicyGradient(**vars(self.hparams))
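
These snippets omit their imports. A minimal set that would make them runnable could look like the sketch below; the exact module paths are an assumption and have moved around between pl_bolts releases.

# Imports assumed by the snippets on this page (module paths are a guess
# and may differ between pl_bolts versions).
import argparse
from unittest import TestCase

import gym
import torch
import pytorch_lightning as pl

from pl_bolts.models.rl import VanillaPolicyGradient
from pl_bolts.models.rl.common import cli
from pl_bolts.models.rl.common.agents import Agent
from pl_bolts.models.rl.common.networks import MLP
from pl_bolts.models.rl.common.wrappers import ToTensor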

Example #2
class TestPolicyGradient(TestCase):
    def setUp(self) -> None:
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)

        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = cli.add_base_args(parent=parent_parser)
        parent_parser = VanillaPolicyGradient.add_model_specific_args(
            parent_parser)
        args_list = [
            "--episode_length",
            "100",
            "--env",
            "CartPole-v0",
        ]
        self.hparams = parent_parser.parse_args(args_list)
        self.model = VanillaPolicyGradient(**vars(self.hparams))

    def test_loss(self):
        """Test the reinforce loss function"""

        batch_states = torch.rand(32, 4)
        batch_actions = torch.rand(32).long()
        batch_qvals = torch.rand(32)

        loss = self.model.loss(batch_states, batch_actions, batch_qvals)

        self.assertIsInstance(loss, torch.Tensor)

    def test_train_batch(self):
        """Tests that a single batch generates correctly"""

        self.model.n_steps = 4
        self.model.batch_size = 1
        xp_dataloader = self.model.train_dataloader()

        batch = next(iter(xp_dataloader))
        self.assertEqual(len(batch), 3)
        self.assertEqual(len(batch[0]), self.model.batch_size)
        self.assertTrue(isinstance(batch, list))
        self.assertEqual(self.model.baseline, 3.9403989999999998)
        self.assertIsInstance(batch[0], torch.Tensor)
        self.assertIsInstance(batch[1], torch.Tensor)
        self.assertIsInstance(batch[2], torch.Tensor)
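
As test_train_batch checks, the dataloader yields a list of three tensors per batch, which the loss above consumes. A rough usage sketch, mirroring only the calls exercised by these tests; the variable names are mine, and the three entries are assumed to be states, actions and discounted returns in the order the loss expects.

# Illustrative sketch, not taken from the library source.
model = VanillaPolicyGradient("CartPole-v0")
states, actions, qvals = next(iter(model.train_dataloader()))  # list of three tensors
loss = model.loss(states, actions, qvals)                      # returns a torch.Tensor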

Example #3
    def setUp(self) -> None:
        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = VanillaPolicyGradient.add_model_specific_args(
            parent_parser)
        args_list = ["--env", "CartPole-v0"]
        self.hparams = parent_parser.parse_args(args_list)

        self.trainer = pl.Trainer(
            gpus=0,
            max_steps=100,
            max_epochs=100,  # same as max_steps to ensure training doesn't stop early
            val_check_interval=1,  # just needs 'some' value; does not affect training right now
            fast_dev_run=True,
        )
Example #4
    def test_policy_gradient(self):
        """Smoke test that the policy gradient model runs"""
        model = VanillaPolicyGradient(self.hparams.env)
        self.trainer.fit(model)

Example #5
    def test_policy_gradient(self):
        """Smoke test that the policy gradient model runs"""
        model = VanillaPolicyGradient(self.hparams.env)
        result = self.trainer.fit(model)

        self.assertEqual(result, 1)
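
Outside the unittest harness, the same smoke run takes only a few lines. A minimal sketch, using just the arguments the tests above pass (the import path is an assumption and may differ between pl_bolts versions):

# Standalone smoke run (sketch): build the model from an env id, as the tests
# above do, and fit it for a single fast_dev_run sanity pass.
import pytorch_lightning as pl
from pl_bolts.models.rl import VanillaPolicyGradient  # assumed import path

model = VanillaPolicyGradient("CartPole-v0")
trainer = pl.Trainer(gpus=0, fast_dev_run=True)
trainer.fit(model)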