def setUp(self) -> None:
    """Build the CartPole env, policy network, agent, and model under test."""
    self.env = ToTensor(gym.make("CartPole-v0"))
    self.obs_shape = self.env.observation_space.shape
    self.n_actions = self.env.action_space.n
    self.net = MLP(self.obs_shape, self.n_actions)
    self.agent = Agent(self.net)

    # Parse hyperparameters exactly as the command line would supply them.
    parser = argparse.ArgumentParser(add_help=False)
    parser = VanillaPolicyGradient.add_model_specific_args(parser)
    cli_args = ["--env", "CartPole-v0", "--batch_size", "32"]
    self.hparams = parser.parse_args(cli_args)
    self.model = VanillaPolicyGradient(**vars(self.hparams))
class TestPolicyGradient(TestCase):
    """Unit tests for the VanillaPolicyGradient model."""

    def setUp(self) -> None:
        """Build env/network/agent and construct the model from parsed CLI args."""
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)

        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = cli.add_base_args(parent=parent_parser)
        parent_parser = VanillaPolicyGradient.add_model_specific_args(parent_parser)
        args_list = [
            "--episode_length", "100",
            "--env", "CartPole-v0",
        ]
        self.hparams = parent_parser.parse_args(args_list)
        self.model = VanillaPolicyGradient(**vars(self.hparams))

    def test_loss(self):
        """Test the reinforce loss function returns a tensor."""
        batch_states = torch.rand(32, 4)
        batch_actions = torch.rand(32).long()
        batch_qvals = torch.rand(32)

        loss = self.model.loss(batch_states, batch_actions, batch_qvals)
        self.assertIsInstance(loss, torch.Tensor)

    def test_train_batch(self):
        """Tests that a single batch generates correctly."""
        self.model.n_steps = 4
        self.model.batch_size = 1
        xp_dataloader = self.model.train_dataloader()

        batch = next(iter(xp_dataloader))
        self.assertEqual(len(batch), 3)
        self.assertEqual(len(batch[0]), self.model.batch_size)
        self.assertTrue(isinstance(batch, list))
        # Exact float equality on an accumulated discounted return is brittle
        # (the value depends on float summation order); compare within tolerance.
        self.assertAlmostEqual(self.model.baseline, 3.9403989999999998, places=6)
        self.assertIsInstance(batch[0], torch.Tensor)
        self.assertIsInstance(batch[1], torch.Tensor)
        self.assertIsInstance(batch[2], torch.Tensor)
def setUp(self) -> None:
    """Parse default hyperparameters and build a fast-dev-run Trainer."""
    parser = argparse.ArgumentParser(add_help=False)
    parser = VanillaPolicyGradient.add_model_specific_args(parser)
    self.hparams = parser.parse_args(["--env", "CartPole-v0"])

    self.trainer = pl.Trainer(
        gpus=0,
        max_steps=100,
        # Set this as the same as max steps to ensure that it doesn't stop early.
        max_epochs=100,
        # This just needs 'some' value, does not effect training right now.
        val_check_interval=1,
        fast_dev_run=True,
    )
def test_policy_gradient(self):
    """Smoke test that the policy gradient model runs."""
    # A single fit call on the fast-dev-run trainer; failure raises, so no
    # assertion is needed for this smoke test.
    pg_model = VanillaPolicyGradient(self.hparams.env)
    self.trainer.fit(pg_model)
def test_policy_gradient(self):
    """Smoke test that the policy gradient model runs."""
    pg_model = VanillaPolicyGradient(self.hparams.env)
    fit_result = self.trainer.fit(pg_model)
    # This PL version signals a successful fit by returning 1.
    self.assertEqual(fit_result, 1)