class TestToTensor(TestCase):
    """Checks that the ToTensor wrapper hands back torch tensors as observations."""

    def setUp(self) -> None:
        """Wrap a freshly made CartPole-v0 environment in ToTensor before each test."""
        self.env = ToTensor(gym.make("CartPole-v0"))

    def test_wrapper(self):
        """Observations from both reset() and step() must be torch.Tensor instances."""
        initial_obs = self.env.reset()
        self.assertIsInstance(initial_obs, torch.Tensor)

        # step() is unpacked as exactly four values; only the first (the
        # observation) is inspected, the remaining three are discarded.
        next_obs, _, _, _ = self.env.step(1)
        self.assertIsInstance(next_obs, torch.Tensor)
def setUp(self) -> None:
    """Create the CartPole env, an MLP-backed agent and a VanillaPolicyGradient model.

    The model's hyperparameters are produced by VanillaPolicyGradient's own
    argparse hook instead of being written out as keyword arguments by hand.
    """
    env = ToTensor(gym.make("CartPole-v0"))
    self.env = env
    self.obs_shape = env.observation_space.shape
    self.n_actions = env.action_space.n
    self.net = MLP(self.obs_shape, self.n_actions)
    self.agent = Agent(self.net)

    # add_help=False so the model hook can attach its options to a bare parser.
    model_parser = VanillaPolicyGradient.add_model_specific_args(
        argparse.ArgumentParser(add_help=False)
    )
    cli_args = ["--env", "CartPole-v0", "--batch_size", "32"]
    self.hparams = model_parser.parse_args(cli_args)
    self.model = VanillaPolicyGradient(**vars(self.hparams))
def setUp(self) -> None:
    """Create the CartPole env, agent, discounted experience source, Reinforce model and its dataloader."""
    self.env = ToTensor(gym.make("CartPole-v0"))
    self.obs_shape = self.env.observation_space.shape
    self.n_actions = self.env.action_space.n
    self.net = MLP(self.obs_shape, self.n_actions)
    self.agent = Agent(self.net)
    self.exp_source = DiscountedExperienceSource(self.env, self.agent)

    # Hyperparameters come from Reinforce's own argparse hook rather than
    # being passed to the constructor by hand.
    cli_args = [
        "--env", "CartPole-v0",
        "--batch_size", "32",
        "--gamma", "0.99",
    ]
    base_parser = argparse.ArgumentParser(add_help=False)
    model_parser = Reinforce.add_model_specific_args(base_parser)
    self.hparams = model_parser.parse_args(cli_args)

    self.model = Reinforce(**vars(self.hparams))
    self.rl_dataloader = self.model.train_dataloader()
def setUp(self) -> None:
    """Provide a fresh ToTensor-wrapped CartPole-v0 environment for each test."""
    raw_env = gym.make("CartPole-v0")
    self.env = ToTensor(raw_env)