Code example #1
    def setUp(self) -> None:
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)
        self.xp_stream = EpisodicExperienceStream(self.env,
                                                  self.agent,
                                                  Mock(),
                                                  episodes=4)
        self.rl_dataloader = DataLoader(self.xp_stream)

        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = cli.add_base_args(parent=parent_parser)
        parent_parser = DQN.add_model_specific_args(parent_parser)
        args_list = [
            "--algo",
            "dqn",
            "--warm_start_steps",
            "500",
            "--episode_length",
            "100",
        ]
        self.hparams = parent_parser.parse_args(args_list)
        self.model = Reinforce(**vars(self.hparams))
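The setUp fragments in these examples assume standard unittest scaffolding plus a handful of imports. A minimal sketch of that surrounding context follows; the pl_bolts import paths are assumptions based on that project's RL module layout and may differ between versions.

import argparse
from unittest import TestCase
from unittest.mock import Mock

import gym
import pytorch_lightning as pl
from torch.utils.data import DataLoader

# assumed pl_bolts layout; adjust to the installed version
from pl_bolts.models.rl.common import cli
from pl_bolts.models.rl.common.agents import Agent
from pl_bolts.models.rl.common.experience import EpisodicExperienceStream
from pl_bolts.models.rl.common.networks import MLP
from pl_bolts.models.rl.common.wrappers import ToTensor
from pl_bolts.models.rl.dqn_model import DQN
from pl_bolts.models.rl.reinforce_model import Reinforce


class TestReinforce(TestCase):  # illustrative class name
    def setUp(self) -> None:
        ...  # body as in code example #1 above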
Code example #2
    def setUp(self) -> None:
        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = pl.Trainer.add_argparse_args(parent_parser)
        parent_parser = cli.add_base_args(parent=parent_parser)
        parent_parser = DQN.add_model_specific_args(parent_parser)
        args_list = [
            "--algo",
            "dqn",
            "--n_steps",
            "4",
            "--warm_start_steps",
            "100",
            "--episode_length",
            "100",
            "--gpus",
            "0",
            "--env",
            "PongNoFrameskip-v4",
        ]
        self.hparams = parent_parser.parse_args(args_list)

        self.trainer = pl.Trainer(
            gpus=self.hparams.gpus,
            max_steps=100,
            max_epochs=100,  # set equal to max_steps so training doesn't stop early
            val_check_interval=1,  # just needs 'some' value; it does not affect training right now
            fast_dev_run=True,
        )
Code example #3
def cli_main():
    parser = argparse.ArgumentParser(add_help=False)

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = cli.add_base_args(parser)
    parser = DQN.add_model_specific_args(parser)
    args = parser.parse_args()

    model = DQN(**vars(args))

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)
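A typical entry point for the cli_main function above would be the following; the script name and flag values in the comment are illustrative only.

if __name__ == "__main__":
    # e.g. python dqn_model.py --algo dqn --env PongNoFrameskip-v4 --gpus 0
    cli_main()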
Code example #4
    def setUp(self) -> None:
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)

        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = cli.add_base_args(parent=parent_parser)
        parent_parser = PolicyGradient.add_model_specific_args(parent_parser)
        args_list = [
            "--episode_length",
            "100",
            "--env",
            "CartPole-v0",
        ]
        self.hparams = parent_parser.parse_args(args_list)
        self.model = PolicyGradient(**vars(self.hparams))
Code example #5
    def setUp(self) -> None:
        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = cli.add_base_args(parent=parent_parser)
        parent_parser = PolicyGradient.add_model_specific_args(parent_parser)
        args_list = [
            "--algo", "PolicyGradient", "--episode_length", "100", "--env",
            "CartPole-v0"
        ]
        self.hparams = parent_parser.parse_args(args_list)

        self.trainer = pl.Trainer(
            gpus=0,
            max_steps=100,
            max_epochs=100,  # set equal to max_steps so training doesn't stop early
            val_check_interval=1000,  # just needs 'some' value; it does not affect training right now
            fast_dev_run=True,
        )
Code example #6
    def setUp(self) -> None:
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)
        self.exp_source = DiscountedExperienceSource(self.env, self.agent)

        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = cli.add_base_args(parent=parent_parser)
        parent_parser = DQN.add_model_specific_args(parent_parser)
        args_list = [
            "--algo", "dqn", "--warm_start_steps", "500", "--episode_length",
            "100", "--env", "CartPole-v0", "--batch_size", "32", "--gamma",
            "0.99"
        ]
        self.hparams = parent_parser.parse_args(args_list)
        self.model = Reinforce(**vars(self.hparams))

        self.rl_dataloader = self.model.train_dataloader()
Code example #7
def cli_main():
    parser = argparse.ArgumentParser(add_help=False)

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = cli.add_base_args(parser)
    parser = VanillaPolicyGradient.add_model_specific_args(parser)
    args = parser.parse_args()

    model = VanillaPolicyGradient(**vars(args))

    # save checkpoints based on avg_reward
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True
    )

    seed_everything(123)
    trainer = pl.Trainer.from_argparse_args(
        args, deterministic=True, checkpoint_callback=checkpoint_callback
    )
    trainer.fit(model)
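The period argument of ModelCheckpoint and the checkpoint_callback argument of Trainer shown above belong to older PyTorch Lightning releases. A rough equivalent on recent versions (a sketch, assuming PyTorch Lightning 1.5 or later):

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    save_top_k=1, monitor="avg_reward", mode="max", every_n_epochs=1, verbose=True
)
trainer = pl.Trainer.from_argparse_args(
    args, deterministic=True, callbacks=[checkpoint_callback]
)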
Code example #8
            "episodes": self.episode_count,
            "episode_steps": self.episode_steps,
            "epsilon": self.agent.epsilon,
        }

        return OrderedDict({
            "loss": loss,
            "avg_reward": self.avg_reward,
            "log": log,
            "progress_bar": status,
        })


# todo: convert to CLI func and add test
if __name__ == "__main__":
    parser = argparse.ArgumentParser(add_help=False)

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = cli.add_base_args(parser)
    parser = DoubleDQN.add_model_specific_args(parser)
    args = parser.parse_args()

    model = DoubleDQN(**vars(args))

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)
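For reference, returning "log" and "progress_bar" keys from training_step, as in code example #8, is a pre-1.0 PyTorch Lightning convention; both keys were removed in 1.0, where the same metrics are reported through self.log. A sketch of the 1.0+ style (compute_loss is a hypothetical helper, not part of the original code):

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        self.log("train_loss", loss)
        self.log("avg_reward", self.avg_reward, prog_bar=True)  # shown in the progress bar
        return loss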