示例#1
0
    def setUp(self) -> None:
        """Parse a fixed DQN command line and build the trainer for the tests.

        Stores the parsed namespace on ``self.hparams`` and a ``pl.Trainer``
        configured for a short run on ``self.trainer``.
        """
        # Each helper takes the parser and hands one back, so they are
        # applied in sequence: Trainer flags, base flags, DQN flags.
        parser = argparse.ArgumentParser(add_help=False)
        parser = pl.Trainer.add_argparse_args(parser)
        parser = cli.add_base_args(parent=parser)
        parser = DQN.add_model_specific_args(parser)

        # Flag/value pairs, flattened below into the flat argv list that
        # ``parse_args`` expects.
        flag_values = (
            ("--algo", "dqn"),
            ("--n_steps", "4"),
            ("--warm_start_steps", "100"),
            ("--episode_length", "100"),
            ("--gpus", "0"),
            ("--env", "PongNoFrameskip-v4"),
        )
        argv = [token for pair in flag_values for token in pair]
        self.hparams = parser.parse_args(argv)

        trainer_options = {
            "gpus": self.hparams.gpus,
            "max_steps": 100,
            # Same as max_steps so the epoch limit doesn't stop the run early.
            "max_epochs": 100,
            # Just needs 'some' value; it does not affect training right now.
            "val_check_interval": 1,
            "fast_dev_run": True,
        }
        self.trainer = pl.Trainer(**trainer_options)
示例#2
0
    def setUp(self) -> None:
        """Build the CartPole fixtures: env, net, agent, episodic stream, loader and model."""
        env = ToTensor(gym.make("CartPole-v0"))
        self.env = env
        self.obs_shape = env.observation_space.shape
        self.n_actions = env.action_space.n

        net = MLP(self.obs_shape, self.n_actions)
        self.net = net
        self.agent = Agent(net)

        stream = EpisodicExperienceStream(env, self.agent, Mock(), episodes=4)
        self.xp_stream = stream
        self.rl_dataloader = DataLoader(stream)

        # Each helper takes the parser and hands one back, so they are
        # applied in sequence: base flags first, then the model flags.
        parser = argparse.ArgumentParser(add_help=False)
        parser = cli.add_base_args(parent=parser)
        parser = DQN.add_model_specific_args(parser)

        # Flag/value pairs, flattened into the flat argv list for parsing.
        flag_values = (
            ("--algo", "dqn"),
            ("--warm_start_steps", "500"),
            ("--episode_length", "100"),
            ("--env", "CartPole-v0"),
        )
        argv = [token for pair in flag_values for token in pair]
        self.hparams = parser.parse_args(argv)

        # NOTE(review): the parser is extended with DQN's arguments but the
        # model constructed here is Reinforce -- confirm this is intended.
        self.model = Reinforce(**vars(self.hparams))
    def setUp(self) -> None:
        """Build the CartPole fixtures: env, net, agent, experience source, model and loader."""
        env = ToTensor(gym.make("CartPole-v0"))
        self.env = env
        self.obs_shape = env.observation_space.shape
        self.n_actions = env.action_space.n

        net = MLP(self.obs_shape, self.n_actions)
        self.net = net
        self.agent = Agent(net)
        self.exp_source = DiscountedExperienceSource(env, self.agent)

        # Each helper takes the parser and hands one back, so they are
        # applied in sequence: base flags first, then the model flags.
        parser = argparse.ArgumentParser(add_help=False)
        parser = cli.add_base_args(parent=parser)
        parser = DQN.add_model_specific_args(parser)

        # Flag/value pairs, flattened into the flat argv list for parsing.
        flag_values = (
            ("--algo", "dqn"),
            ("--warm_start_steps", "500"),
            ("--episode_length", "100"),
            ("--env", "CartPole-v0"),
            ("--batch_size", "32"),
            ("--gamma", "0.99"),
        )
        argv = [token for pair in flag_values for token in pair]
        self.hparams = parser.parse_args(argv)

        # NOTE(review): the parser is extended with DQN's arguments but the
        # model constructed here is Reinforce -- confirm this is intended.
        self.model = Reinforce(**vars(self.hparams))

        # The dataloader under test is the one the model itself provides.
        self.rl_dataloader = self.model.train_dataloader()
示例#4
0
 def test_dqn(self):
     """Fit a DQN built with ``num_envs=5`` on the shared trainer.

     No assertion is made on the result; the test passes if ``fit``
     completes without raising.
     """
     self.trainer.fit(DQN(self.hparams.env, num_envs=5))
示例#5
0
    def test_n_step_dqn(self):
        """Fit a DQN using the parsed ``n_steps`` and expect ``fit`` to return 1."""
        n_step_model = DQN(self.hparams.env, n_steps=self.hparams.n_steps)
        fit_status = self.trainer.fit(n_step_model)
        self.assertEqual(fit_status, 1)
示例#6
0
    def test_dqn(self):
        """Fit a DQN built with ``num_envs=5`` and expect ``fit`` to return 1."""
        dqn = DQN(self.hparams.env, num_envs=5)
        self.assertEqual(self.trainer.fit(dqn), 1)