Beispiel #1
0
 def train(cls, num_frames: int) -> None:
     """Train a fresh DDPG agent and move its logs to the configured directory.

     Instantiates ``cls`` to obtain a device and an environment, builds a
     ``ddpg`` preset using ``Settings.LEARNING_RATE`` for both ``lr_q`` and
     ``lr_pi``, and runs a ``SingleEnvExperiment`` on the environment. After
     training, the experiment writer's log directory is copied into
     ``Settings.FULL_LOG_DIR`` and the original directory is deleted.

     The environment is closed even when any of these steps raises.

     NOTE(review): the first parameter is ``cls``, so this is presumably
     decorated with ``@classmethod`` outside the visible snippet -- confirm.

     Args:
         num_frames: Frame budget forwarded to ``experiment.train``.
     """
     rl_agent = cls()
     try:
         preset = ddpg(device=rl_agent.device,
                       lr_q=Settings.LEARNING_RATE,
                       lr_pi=Settings.LEARNING_RATE)
         experiment = SingleEnvExperiment(preset, rl_agent.env)
         experiment.train(num_frames)
         # NOTE(review): `_writer` is a private attribute of the experiment;
         # this may break across library versions -- confirm there is no
         # public accessor for the log directory.
         default_log_dir = experiment._writer.log_dir
         # Copy first, delete second: the logs are never left in neither place.
         copy_tree(default_log_dir, Settings.FULL_LOG_DIR)
         rmtree(default_log_dir)
     finally:
         # Release the environment regardless of how training ended.
         rl_agent.env.close()
Beispiel #2
0
    def resume_training(cls, path, num_frames: int) -> None:
        """Continue DDPG training from the ``q.pt`` / ``policy.pt`` files in ``path``.

        Builds a new ``ddpg`` preset and ``SingleEnvExperiment``, then
        overwrites the state dicts of the new agent's ``q`` and ``policy``
        models with those of the modules loaded from ``path``. After
        training, the experiment writer's log directory is copied into
        ``Settings.FULL_LOG_DIR`` and the original directory is deleted.

        The environment is closed even when any of these steps raises.

        NOTE(review): the first parameter is ``cls``, so this is presumably
        decorated with ``@classmethod`` outside the visible snippet -- confirm.

        Args:
            path: Directory containing ``q.pt`` and ``policy.pt``.
            num_frames: Frame budget forwarded to ``experiment.train``.
        """
        rl_agent = cls()
        try:
            lr = Settings.LEARNING_RATE
            preset = ddpg(device=rl_agent.device, lr_q=lr, lr_pi=lr)

            # SECURITY: torch.load unpickles the file, which can execute
            # arbitrary code. Only pass a `path` that points at trusted
            # checkpoints.
            # NOTE(review): these files are used as whole modules (`.to`,
            # `.state_dict`). PyTorch >= 2.6 defaults to `weights_only=True`,
            # which may reject such files -- confirm against the pinned
            # torch version.
            q_module = torch.load(os.path.join(path, "q.pt"),
                                  map_location='cpu').to(rl_agent.device)
            policy_module = torch.load(os.path.join(path, "policy.pt"),
                                       map_location='cpu').to(rl_agent.device)

            experiment = SingleEnvExperiment(preset, rl_agent.env)
            # NOTE(review): `_agent` and `_writer` are private attributes of
            # the experiment; this may break across library versions.
            agent = experiment._agent.agent
            # Transplant the saved weights into the freshly built networks.
            agent.q.model.load_state_dict(q_module.state_dict())
            agent.policy.model.load_state_dict(policy_module.state_dict())

            experiment.train(frames=num_frames)
            default_log_dir = experiment._writer.log_dir
            # Copy first, delete second: the logs are never left in neither place.
            copy_tree(default_log_dir, Settings.FULL_LOG_DIR)
            rmtree(default_log_dir)
        finally:
            # Release the environment regardless of how training ended.
            rl_agent.env.close()
def main():
    """Construct a ``SlurmExperiment`` for DDPG, PPO and SAC on five Bullet envs.

    Each preset is built with ``last_frame`` equal to the frame budget
    (1e7), every environment is created on the ``'cuda'`` device, and the
    experiment is given the ``'1080ti-long'`` partition via ``sbatch_args``.

    NOTE(review): the ``SlurmExperiment`` instance is discarded, so job
    submission presumably happens in its constructor -- confirm.
    """
    target_device = 'cuda'
    total_frames = int(1e7)

    # One preset per algorithm, all sharing the same frame horizon.
    presets = [build(last_frame=total_frames) for build in (ddpg, ppo, sac)]

    env_ids = (
        'AntBulletEnv-v0',
        'HalfCheetahBulletEnv-v0',
        'HumanoidBulletEnv-v0',
        'HopperBulletEnv-v0',
        'Walker2DBulletEnv-v0',
    )
    environments = []
    for env_id in env_ids:
        environments.append(GymEnvironment(env_id, target_device))

    slurm_options = {'partition': '1080ti-long'}
    SlurmExperiment(presets,
                    environments,
                    total_frames,
                    sbatch_args=slurm_options)
Beispiel #4
0
 def test_ddpg(self):
     """Validate a CPU DDPG preset built with ``replay_start_size=50``.

     NOTE(review): the small replay start size presumably keeps the test
     short -- confirm against ``self.validate``.
     """
     cpu_preset = ddpg(device='cpu', replay_start_size=50)
     self.validate(cpu_preset)