class Looper:
    """Actor-critic training coordinator.

    Each epoch: roll out episodes with the current policy, fit the critic
    to normalized rewards-to-go, then take an actor improvement step using
    critic-derived baselines. Progress is plotted to Visdom.
    """

    def __init__(self, env="LunarLanderContinuous-v2", gamma=0.99):
        """Build the policy, environment, trajectory store, and plotter.

        Args:
            env: Gym environment id, used for both the policy and rollouts.
            gamma: Discount factor forwarded to the trajectory store.
        """
        self.policy = Policy(env_id=env)
        self.env = gym.make(env)
        self.runs = Runs(gamma=gamma)
        self.plotter = VisdomPlotter(env_name=env)
        self.device_cpu = torch.device("cpu")
        # Prefer CUDA when available; keep the boolean flag for callers.
        self.use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_gpu else "cpu")

    def loop(self, epochs=1, show_every=50, show_for=10, sample_count=1,
             lr=5e-3, batch_size=4096):
        """Run the full sample -> critic -> actor cycle for `epochs` iterations.

        Args:
            epochs: Number of training iterations.
            show_every: Demo cadence (demo currently disabled, see TODO).
            show_for: Episodes per demo (unused while demo is disabled).
            sample_count: Episodes to roll out per epoch.
            lr: Learning rate for both critic and actor updates.
            batch_size: Minibatch size for both critic and actor updates.
        """
        for e in trange(epochs):
            if e % show_every == 0 and e > 0:
                # TODO(review): demonstration intentionally disabled;
                # re-enable once rendering is stable.
                # self.policy.demonstrate(ep_count=show_for)
                pass
            reward = self.generate_samples(num_ep=sample_count)
            self.plotter.plot_line('reward per episode', 'reward',
                                   'avg reward when generating samples',
                                   e, reward)
            loss = self.estimate_return(lr=lr, batch_size=batch_size)
            self.plotter.plot_line('loss per batch', 'loss',
                                   'avg loss when training critic', e, loss)
            self.improve_policy(lr=lr, batch_size=batch_size)

    def generate_samples(self, num_ep=1, render=False):
        """Roll out `num_ep` episodes, recording transitions into self.runs.

        Args:
            num_ep: Number of episodes to sample.
            render: Render the environment after each step when True.

        Returns:
            Average undiscounted episode return (0.0 when num_ep == 0).
        """
        self.runs.reset()
        total_reward = 0.0
        # Rollouts need no gradient tracking.
        with torch.no_grad():
            for _ in range(num_ep):
                done = False
                ob = self.env.reset()
                while not done:
                    # Leading batch dimension for the policy; take the
                    # single action back out.
                    action = self.policy.sample_action(observations=ob[None])[0]
                    ob_, r, done, _ = self.env.step(action)
                    total_reward += r
                    self.runs.add_step(state=ob, action=action, reward=r,
                                       next_state=ob_, done=done)
                    ob = ob_
                    if render:
                        self.env.render()
        # Guard the num_ep == 0 case instead of raising ZeroDivisionError.
        return total_reward / num_ep if num_ep else 0.0

    def estimate_return(self, lr=5e-3, batch_size=4096):
        """Fit the critic to normalized rewards-to-go.

        Also refreshes the per-state baseline cache from the freshly
        trained critic so the actor update can use it.

        Returns:
            Average critic training loss per batch.
        """
        self.runs.compute_rewards()
        rtg = self.runs.get_normalized_rtg()
        loss = self.policy.improve_critic(data_loader=CriticLoader(dframe=rtg),
                                          lr=lr, batch_size=batch_size)
        self.runs.compute_baseline_dict(critic=self.policy.critic,
                                        batch_size=batch_size)
        return loss

    def improve_policy(self, lr=5e-3, batch_size=4096):
        """Take one actor improvement step over all recorded runs,
        using the baselines computed in estimate_return()."""
        loader = ActorLoader(dframe=self.runs.all_runs,
                             baseline_dict=self.runs.baseline)
        self.policy.improve_actor(data_loader=loader, lr=lr,
                                  batch_size=batch_size)