コード例 #1
0
ファイル: test_reinforce.py プロジェクト: djbyrne/core_rl
    def setUp(self) -> None:
        """Create the environment, network, agent, data pipeline and model fixtures.

        Everything is stored on ``self`` so individual tests can use it.
        """
        # CartPole environment passed through ToTensor; the network is sized
        # from its observation shape and number of discrete actions.
        env = ToTensor(gym.make("CartPole-v0"))
        self.env = env
        self.obs_shape = env.observation_space.shape
        self.n_actions = env.action_space.n

        # MLP network and the Agent constructed from it.
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)

        # Experience stream over 4 episodes; a Mock stands in for the third
        # positional argument. The stream is then wrapped in a DataLoader.
        self.xp_stream = EpisodicExperienceStream(
            self.env, self.agent, Mock(), episodes=4
        )
        self.rl_dataloader = DataLoader(self.xp_stream)

        # Hyperparameters: shared base CLI arguments plus the arguments
        # registered by DQNLightning, parsed from a fixed argument list.
        cli_parser = cli.add_base_args(
            parent=argparse.ArgumentParser(add_help=False)
        )
        cli_parser = DQNLightning.add_model_specific_args(cli_parser)
        self.hparams = cli_parser.parse_args([
            "--algo", "dqn",
            "--warm_start_steps", "500",
            "--episode_length", "100",
        ])

        self.model = ReinforceLightning(self.hparams)
コード例 #2
0
ファイル: test_policy_models.py プロジェクト: djbyrne/core_rl
    def setUp(self) -> None:
        """Parse the VPG hyperparameters and create the Trainer used by the tests."""
        # Shared base CLI arguments plus the arguments registered by VPGLightning.
        cli_parser = cli.add_base_args(
            parent=argparse.ArgumentParser(add_help=False)
        )
        cli_parser = VPGLightning.add_model_specific_args(cli_parser)
        self.hparams = cli_parser.parse_args([
            "--algo", "vpg",
            "--episode_length", "100",
            "--env", "CartPole-v0",
        ])

        # One budget drives both limits: max_epochs is kept equal to max_steps
        # so that training does not stop early.
        step_budget = 100
        self.trainer = pl.Trainer(
            gpus=0,
            max_steps=step_budget,
            max_epochs=step_budget,
            # Only needs to be set to some value; it does not affect training
            # right now.
            val_check_interval=1000,
        )
コード例 #3
0
        model = VPGLightning(hparams)
    else:
        model = DQNLightning(hparams)

    checkpoint_callback = ModelCheckpoint(save_top_k=1,
                                          monitor='avg_reward',
                                          mode='max',
                                          prefix='')

    trainer = pl.Trainer(gpus=hparams.gpus,
                         distributed_backend=hparams.backend,
                         max_steps=hparams.max_steps,
                         max_epochs=hparams.max_steps,
                         val_check_interval=1000,
                         profiler=True,
                         checkpoint_callback=checkpoint_callback)

    trainer.fit(model)
    trainer.test()


if __name__ == '__main__':
    # Start from a help-less parser carrying the shared base arguments, then
    # let each model class register its own arguments (DQN first, then VPG).
    parent_parser = cli.add_base_args(
        parent=argparse.ArgumentParser(add_help=False))
    for lightning_cls in (DQNLightning, VPGLightning):
        parent_parser = lightning_cls.add_model_specific_args(parent_parser)

    # Parse the process command line and hand the result to main().
    args = parent_parser.parse_args()
    main(args)