def setUp(self) -> None:
        """Create CartPole fixtures and a ``Reinforce`` model before each test.

        Builds a tensor-wrapped CartPole-v0 env, an MLP-backed ``Agent``, a
        4-episode ``EpisodicExperienceStream`` wrapped in a ``DataLoader`` and a
        ``Reinforce`` model configured from a parsed CLI argument list.
        """
        cartpole = ToTensor(gym.make("CartPole-v0"))
        self.env = cartpole
        self.obs_shape = cartpole.observation_space.shape
        self.n_actions = cartpole.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)
        # Third positional argument is presumably the device; a Mock is used so
        # no real device is required -- confirm against EpisodicExperienceStream.
        self.xp_stream = EpisodicExperienceStream(
            cartpole, self.agent, Mock(), episodes=4
        )
        self.rl_dataloader = DataLoader(self.xp_stream)

        # NOTE(review): these hyper-parameters come from the DQN parser (and
        # "--algo dqn") but are fed to Reinforce -- confirm this is intentional.
        overrides = {
            "--algo": "dqn",
            "--warm_start_steps": "500",
            "--episode_length": "100",
        }
        parser = cli.add_base_args(parent=argparse.ArgumentParser(add_help=False))
        parser = DQN.add_model_specific_args(parser)
        argv = [token for flag_and_value in overrides.items() for token in flag_and_value]
        self.hparams = parser.parse_args(argv)
        self.model = Reinforce(**vars(self.hparams))
 def setUp(self) -> None:
     """Set up a raw CartPole-v0 env, a mocked network and a CPU experience stream."""
     self.env = gym.make("CartPole-v0")
     # The network is a Mock, so no real model is involved in these tests.
     self.net = Mock()
     self.agent = Agent(self.net)
     cpu_device = torch.device('cpu')
     self.xp_stream = EpisodicExperienceStream(
         self.env, self.agent, cpu_device, episodes=4
     )
     self.rl_dataloader = DataLoader(self.xp_stream)
Example #3
0
    def setUp(self) -> None:
        """Build CartPole fixtures and a ``VanillaPolicyGradient`` model.

        The model hyper-parameters are parsed from a short CLI argument list
        selecting the CartPole-v0 env and a batch size of 32.
        """
        env = ToTensor(gym.make("CartPole-v0"))
        self.env = env
        self.obs_shape = env.observation_space.shape
        self.n_actions = env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)

        parser = VanillaPolicyGradient.add_model_specific_args(
            argparse.ArgumentParser(add_help=False)
        )
        cli_args = {"--env": "CartPole-v0", "--batch_size": "32"}
        self.hparams = parser.parse_args(
            [part for item in cli_args.items() for part in item]
        )
        self.model = VanillaPolicyGradient(**vars(self.hparams))
    def setUp(self) -> None:
        """Build CartPole fixtures, a discounted experience source and a ``Reinforce`` model.

        The model is configured from parsed CLI arguments (env, batch size,
        gamma), and its own ``train_dataloader()`` is stored as
        ``self.rl_dataloader``.
        """
        env = ToTensor(gym.make("CartPole-v0"))
        self.env = env
        self.obs_shape = env.observation_space.shape
        self.n_actions = env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)
        self.exp_source = DiscountedExperienceSource(env, self.agent)

        parser = Reinforce.add_model_specific_args(argparse.ArgumentParser(add_help=False))
        settings = {"--env": "CartPole-v0", "--batch_size": "32", "--gamma": "0.99"}
        argv = []
        for flag, value in settings.items():
            argv.extend((flag, value))
        self.hparams = parser.parse_args(argv)
        self.model = Reinforce(**vars(self.hparams))

        self.rl_dataloader = self.model.train_dataloader()
Example #5
0
 def test_base_agent(self):
     """Check that calling the base ``Agent`` yields its action as a list.

     The device is picked at runtime rather than hard-coded to ``"cuda:0"``,
     so the test does not depend on a GPU being present.
     """
     import torch

     # Use the GPU only when one exists; CPU-only CI machines fall back to "cpu".
     device = "cuda:0" if torch.cuda.is_available() else "cpu"
     agent = Agent(self.net)
     action = agent(self.state, device)
     self.assertIsInstance(action, list)
Example #6
0
 def setUp(self) -> None:
     """Create a tensor-wrapped CartPole-v0 env, a mocked network and a 4-episode stream."""
     self.env = ToTensor(gym.make("CartPole-v0"))
     self.net = Mock()
     self.agent = Agent(self.net)
     # A Mock device is enough here; no real hardware placement is exercised.
     stream = EpisodicExperienceStream(
         self.env, self.agent, device=Mock(), episodes=4
     )
     self.xp_stream = stream
     self.rl_dataloader = DataLoader(stream)
 def test_base_agent(self):
     """Check that calling the base ``Agent`` yields its action as an int.

     The device is picked at runtime rather than hard-coded to ``'cuda:0'``,
     so the test does not depend on a GPU being present.
     """
     import torch

     # Use the GPU only when one exists; CPU-only CI machines fall back to 'cpu'.
     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
     agent = Agent(self.net)
     action = agent(self.state, device)
     self.assertIsInstance(action, int)