def step_wait(self):
    obs = []
    for e in range(self.nb_env):
        (
            ob,
            self.buf_rews[e],
            self.buf_dones[e],
            self.buf_infos[e],
        ) = self.envs[e].step(self.actions[e])
        if self.buf_dones[e]:
            ob = self.envs[e].reset()
        obs.append(ob)
    obs = listd_to_dlist(obs)
    new_obs = {}
    for k, v in dummy_handle_ob(obs).items():
        if self._is_tensor_key(k):
            new_obs[k] = torch.stack(v)
        else:
            new_obs[k] = v
    self.buf_obs = new_obs
    return (
        self.buf_obs,
        torch.tensor(self.buf_rews),
        torch.tensor(self.buf_dones),
        self.buf_infos,
    )
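# listd_to_dlist is used throughout these methods but is not defined here.
# A minimal sketch of its assumed behavior, inferred from the call sites:
# transpose a list of dicts into a dict of lists (helper name is hypothetical).
def listd_to_dlist_sketch(list_of_dicts):
    # e.g. [{"x": 1}, {"x": 2}] -> {"x": [1, 2]}
    return {k: [d[k] for d in list_of_dicts] for k in list_of_dicts[0]}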
def run(self):
    for epoch_id in self.epoch_ids:
        reward_buf = 0
        for net_path in self.log_dir_helper.network_paths_at_epoch(epoch_id):
            self.network.load_state_dict(
                torch.load(
                    net_path, map_location=lambda storage, loc: storage
                )
            )
            self.network.eval()

            internals = listd_to_dlist(
                [self.network.new_internals(self.device)]
            )
            next_obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
            self.env_mgr.render()
            episode_complete = False
            while not episode_complete:
                obs = next_obs
                with torch.no_grad():
                    actions, _, internals = self.actor.act(
                        self.network, obs, internals
                    )
                next_obs, rewards, terminals, infos = self.env_mgr.step(
                    actions
                )
                self.env_mgr.render()
                next_obs = dtensor_to_dev(next_obs, self.device)

                reward_buf += rewards[0]
                if terminals[0]:
                    episode_complete = True

            print(f"EPOCH_ID: {epoch_id} REWARD: {reward_buf}")
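# dtensor_to_dev is likewise external to these snippets. A minimal sketch of
# its assumed behavior: move every tensor in an observation dict onto the
# target device, leaving non-tensor values untouched (name is hypothetical).
import torch

def dtensor_to_dev_sketch(obs_dict, device):
    return {
        k: v.to(device) if torch.is_tensor(v) else v
        for k, v in obs_dict.items()
    }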
def reset(self):
    obs = []
    for e in range(self.nb_env):
        ob = self.envs[e].reset()
        obs.append(ob)
    obs = listd_to_dlist(obs)
    new_obs = {}
    for k, v in dummy_handle_ob(obs).items():
        if self._is_tensor_key(k):
            new_obs[k] = torch.stack(v)
        else:
            new_obs[k] = v
    self.buf_obs = new_obs
    return self.buf_obs
def read(self):
    exp_list, last_obs, is_weights = self._sample()
    exp_dev_list = [
        self._exp_to_dev(e, self.target_device) for e in exp_list
    ]
    # will be list of dicts, convert to dict of lists
    dict_of_list = listd_to_dlist(exp_dev_list)
    # get next obs
    dict_of_list["next_observation"] = last_obs
    # importance sampling weights
    dict_of_list["importance_sample_weights"] = is_weights
    # return named tuple
    return namedtuple(
        self.__class__.__name__,
        ["importance_sample_weights", "next_observation"] + self._keys,
    )(**dict_of_list)
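# The return statement above builds a namedtuple class on the fly from the
# replay's keys. A self-contained illustration of that pattern; the class
# name and field values below are made up for the example:
from collections import namedtuple

dict_of_list = {
    "importance_sample_weights": [1.0, 0.5],
    "next_observation": {"frame": None},
    "rewards": [0.0, 1.0],
}
Batch = namedtuple("ExperienceReplay", list(dict_of_list.keys()))
batch = Batch(**dict_of_list)
# fields are then read as attributes, e.g. batch.rewards or
# batch.importance_sample_weights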
def act_batch(self, network, batch_obs, batch_terminals, batch_actions,
              internals, device):
    exp_cache = []
    for obs, actions, terminals in zip(
        batch_obs, batch_actions, batch_terminals
    ):
        preds, internals, _ = network(obs, internals)
        exp_cache.append(self._process_exp(preds, actions))

        # where returns a single element tuple with the indexes
        terminal_inds = np.where(terminals)[0]
        for i in terminal_inds:
            for k, v in network.new_internals(device).items():
                internals[k][i] = v
    exp = listd_to_dlist(exp_cache)
    return (
        torch.stack(exp['log_probs']),
        torch.stack(exp['values']),
        torch.stack(exp['entropies']),
    )
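# Quick illustration of the terminal-index lookup used above: np.where on a
# boolean (or 0/1) array returns a one-element tuple of index arrays.
import numpy as np

terminals = np.array([0, 1, 0, 1])
terminal_inds = np.where(terminals)[0]  # array([1, 3])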
def run(self):
    local_step_count = global_step_count = self.initial_step_count
    ep_rewards = torch.zeros(self.nb_env)

    obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
    internals = listd_to_dlist(
        [
            self.network.new_internals(self.device)
            for _ in range(self.nb_env)
        ]
    )
    start_time = time()
    while global_step_count < self.nb_step:
        actions, internals = self.agent.act(self.network, obs, internals)
        next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
        next_obs = dtensor_to_dev(next_obs, self.device)

        self.agent.observe(
            obs,
            rewards.to(self.device).float(),
            terminals.to(self.device).float(),
            infos,
        )
        for i, terminal in enumerate(terminals):
            if terminal:
                for k, v in self.network.new_internals(self.device).items():
                    internals[k][i] = v

        # Perform state updates
        local_step_count += self.nb_env
        global_step_count += self.nb_env * self.world_size
        ep_rewards += rewards.float()
        obs = next_obs

        term_rewards = []
        for i, terminal in enumerate(terminals):
            if terminal:
                for k, v in self.network.new_internals(self.device).items():
                    internals[k][i] = v
                term_rewards.append(ep_rewards[i].item())
                ep_rewards[i].zero_()

        if term_rewards:
            term_reward = np.mean(term_rewards)
            delta_t = time() - start_time
            self.logger.info(
                "RANK: {} "
                "GLOBAL STEP: {} "
                "REWARD: {} "
                "GLOBAL STEP/S: {} "
                "LOCAL STEP/S: {}".format(
                    self.global_rank,
                    global_step_count,
                    term_reward,
                    (global_step_count - self.initial_step_count) / delta_t,
                    (local_step_count - self.initial_step_count) / delta_t,
                )
            )

        # Learn
        if self.agent.is_ready():
            _, _ = self.agent.learn_step(
                self.updater, self.network, next_obs, internals
            )
            self.agent.clear()

            for k, vs in internals.items():
                internals[k] = [v.detach() for v in vs]
def run(self):
    local_step_count = global_step_count = self.initial_step_count
    next_save = self.init_next_save(self.initial_step_count, self.epoch_len)
    prev_step_t = time()
    ep_rewards = torch.zeros(self.nb_env)

    obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
    internals = listd_to_dlist(
        [
            self.network.new_internals(self.device)
            for _ in range(self.nb_env)
        ]
    )
    start_time = time()
    while global_step_count < self.nb_step:
        actions, internals = self.agent.act(self.network, obs, internals)
        next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
        next_obs = dtensor_to_dev(next_obs, self.device)

        self.agent.observe(
            obs,
            rewards.to(self.device).float(),
            terminals.to(self.device).float(),
            infos,
        )
        for i, terminal in enumerate(terminals):
            if terminal:
                for k, v in self.network.new_internals(self.device).items():
                    internals[k][i] = v

        # Perform state updates
        local_step_count += self.nb_env
        global_step_count += self.nb_env * self.world_size
        ep_rewards += rewards.float()
        obs = next_obs

        term_rewards = []
        for i, terminal in enumerate(terminals):
            if terminal:
                for k, v in self.network.new_internals(self.device).items():
                    internals[k][i] = v
                term_rewards.append(ep_rewards[i].item())
                ep_rewards[i].zero_()

        if term_rewards:
            term_reward = np.mean(term_rewards)
            delta_t = time() - start_time
            self.logger.info(
                "RANK: {} "
                "GLOBAL STEP: {} "
                "REWARD: {} "
                "GLOBAL STEP/S: {} "
                "LOCAL STEP/S: {}".format(
                    self.global_rank,
                    global_step_count,
                    term_reward,
                    (global_step_count - self.initial_step_count) / delta_t,
                    (local_step_count - self.initial_step_count) / delta_t,
                )
            )
            self.summary_writer.add_scalar(
                "reward", term_reward, global_step_count
            )

        if global_step_count >= next_save:
            self.saver.save_state_dicts(
                self.network, global_step_count, self.optimizer
            )
            next_save += self.epoch_len

        # Learn
        if self.agent.is_ready():
            loss_dict, metric_dict = self.agent.learn_step(
                self.updater, self.network, next_obs, internals
            )
            total_loss = torch.sum(
                torch.stack(tuple(loss for loss in loss_dict.values()))
            )
            self.agent.clear()

            for k, vs in internals.items():
                internals[k] = [v.detach() for v in vs]

            # write summaries
            cur_step_t = time()
            if cur_step_t - prev_step_t > self.summary_freq:
                self.write_summaries(
                    self.summary_writer,
                    global_step_count,
                    total_loss,
                    loss_dict,
                    metric_dict,
                    self.network.named_parameters(),
                )
                prev_step_t = cur_step_t
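# The total_loss aggregation above sums every entry of a loss dict into a
# single scalar. A self-contained equivalent; the loss names and values are
# illustrative only:
import torch

loss_dict = {"policy_loss": torch.tensor(0.5), "value_loss": torch.tensor(0.2)}
total_loss = torch.sum(torch.stack(tuple(loss_dict.values())))  # tensor(0.7000)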
def run(self, workers, profile=False):
    if profile:
        try:
            from pyinstrument import Profiler
        except ImportError:
            raise ImportError(
                'You must install pyinstrument to use profiling.'
            )
        profiler = Profiler()
        profiler.start()

    # setup queuer
    rollout_queuer = RolloutQueuerAsync(
        workers, self.nb_learn_batch, self.rollout_queue_size
    )
    rollout_queuer.start()

    # initial setup
    global_step_count = self.initial_step_count
    next_save = self.init_next_save(self.initial_step_count, self.epoch_len)
    prev_step_t = time()
    ep_rewards = torch.zeros(self.nb_env)
    start_time = time()

    # loop until total number of steps is reached
    print('{} starting training'.format(self.rank))
    while not self.done(global_step_count):
        self.exp.clear()

        # Get batch from queue
        rollouts, terminal_rewards, terminal_infos = rollout_queuer.get()

        # Iterate forward on batch
        self.exp.write_exps(rollouts)
        # keep a copy of terminals on the cpu, it's faster
        rollout_terminals = torch.stack(self.exp['terminals']).numpy()
        self.exp.to(self.device)
        r = self.exp.read()
        internals = {k: ts[0].unbind(0) for k, ts in r.internals.items()}
        for obs, rewards, terminals in zip(
            r.observations, r.rewards, rollout_terminals
        ):
            _, h_exp, internals = self.actor.act(self.network, obs, internals)
            self.exp.write_actor(h_exp, no_env=True)

            # where returns a single element tuple with the indexes
            terminal_inds = np.where(terminals)[0]
            for i in terminal_inds:
                for k, v in self.network.new_internals(self.device).items():
                    internals[k][i] = v

        # compute loss
        loss_dict, metric_dict = self.learner.compute_loss(
            self.network, self.exp.read(), r.next_observation, internals
        )
        total_loss = torch.sum(
            torch.stack(tuple(loss for loss in loss_dict.values()))
        )

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        # Perform state updates
        global_step_count += (
            self.nb_env * self.nb_learn_batch * len(r.terminals) * self.nb_learners
        )

        # if rank 0, write summaries, save,
        # and send parameters to workers async
        if self.rank == 0:
            # TODO: this could be parallelized, chunk by nb learners
            self.synchronize_worker_parameters(workers, global_step_count)

            # possible save
            if global_step_count >= next_save:
                self.saver.save_state_dicts(
                    self.network, global_step_count, self.optimizer
                )
                next_save += self.epoch_len

            # write reward summaries
            if any(terminal_rewards):
                terminal_rewards = list(
                    filter(lambda x: x is not None, terminal_rewards)
                )
                self.summary_writer.add_scalar(
                    'reward', np.mean(terminal_rewards), global_step_count
                )

            # write infos
            if any(terminal_infos):
                terminal_infos = list(
                    filter(lambda x: x is not None, terminal_infos)
                )
                float_keys = [
                    k for k, v in terminal_infos[0].items() if type(v) == float
                ]
                terminal_infos_dlist = listd_to_dlist(terminal_infos)
                for k in float_keys:
                    self.summary_writer.add_scalar(
                        f'info/{k}',
                        np.mean(terminal_infos_dlist[k]),
                        global_step_count,
                    )

        # write summaries
        cur_step_t = time()
        if cur_step_t - prev_step_t > self.summary_freq:
            print('Rank {} Metrics:'.format(self.rank), rollout_queuer.metrics())
            if self.rank == 0:
                self.write_summaries(
                    self.summary_writer,
                    global_step_count,
                    total_loss,
                    loss_dict,
                    metric_dict,
                    self.network.named_parameters(),
                )
            prev_step_t = cur_step_t

    rollout_queuer.close()
    print('{} stopped training'.format(self.rank))
    if profile:
        profiler.stop()
        print(profiler.output_text(unicode=True, color=True))
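# r.internals above maps each internal-state key to a tensor stacked over
# rollout steps; taking index 0 and calling unbind(0) splits that first step
# into one tensor per environment. A small demo of the slicing; the
# [time, env, hidden] layout and sizes are assumptions for illustration:
import torch

ts = torch.zeros(20, 4, 8)   # 20 rollout steps, 4 envs, hidden size 8
per_env = ts[0].unbind(0)    # tuple of 4 tensors, each of shape (8,)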
def run(self):
    local_step_count = global_step_count = self.initial_step_count
    ep_rewards = torch.zeros(self.nb_env)

    obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
    internals = listd_to_dlist([
        self.network.new_internals(self.device)
        for _ in range(self.nb_env)
    ])
    start_time = time()
    while global_step_count < self.nb_step:
        actions, internals = self.agent.act(self.network, obs, internals)
        next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
        next_obs = dtensor_to_dev(next_obs, self.device)

        self.agent.observe(
            obs,
            rewards.to(self.device).float(),
            terminals.to(self.device).float(),
            infos
        )
        for i, terminal in enumerate(terminals):
            if terminal:
                for k, v in self.network.new_internals(self.device).items():
                    internals[k][i] = v

        # Perform state updates
        local_step_count += self.nb_env
        global_step_count += self.nb_env * self.world_size
        ep_rewards += rewards.float()
        obs = next_obs

        term_rewards = []
        for i, terminal in enumerate(terminals):
            if terminal:
                for k, v in self.network.new_internals(self.device).items():
                    internals[k][i] = v
                term_rewards.append(ep_rewards[i].item())
                ep_rewards[i].zero_()

        if term_rewards:
            term_reward = np.mean(term_rewards)
            delta_t = time() - start_time
            self.logger.info(
                'RANK: {} '
                'GLOBAL STEP: {} '
                'REWARD: {} '
                'GLOBAL STEP/S: {} '
                'LOCAL STEP/S: {}'.format(
                    self.global_rank,
                    global_step_count,
                    term_reward,
                    (global_step_count - self.initial_step_count) / delta_t,
                    (local_step_count - self.initial_step_count) / delta_t
                )
            )

        # Learn
        if self.agent.is_ready():
            loss_dict, metric_dict = self.agent.compute_loss(
                self.network, next_obs, internals
            )
            total_loss = torch.sum(
                torch.stack(tuple(loss for loss in loss_dict.values()))
            )

            self.optimizer.zero_grad()
            total_loss.backward()
            dist.barrier()
            handles = []
            for param in self.network.parameters():
                handles.append(dist.all_reduce(param.grad, async_op=True))
            for handle in handles:
                handle.wait()
            # for param in self.network.parameters():
            #     param.grad.mul_(1. / self.world_size)
            self.optimizer.step()
            self.agent.clear()

            for k, vs in internals.items():
                internals[k] = [v.detach() for v in vs]
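# Context for the all_reduce block above: gradients are summed across ranks,
# and the commented-out rescale would turn that sum into an average. A
# minimal sketch of the averaging variant, assuming torch.distributed has
# already been initialized; the helper name is hypothetical and this is not
# necessarily the project's chosen behavior, since the rescale is disabled
# in the original:
import torch.distributed as dist

def average_gradients_sketch(network, world_size):
    for param in network.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)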