def _train(self):
    batch_idxes = np.arange(self.num_tasks)

    gt.start()

    for epoch in gt.timed_for(
            trange(self._start_epoch, self.num_epochs),
            save_itrs=True,
    ):
        # Distribute the evaluation. We ship the params of each needed
        # network to the remote path collector.
        params_list = []
        for net in self.policy.networks:
            params_list.append(ptu.state_dict_cpu(net))

        self.path_collector.set_policy_params(params_list)

        # Submit asynchronous evaluation on the training goals in chunks
        # of at most self.num_workers goals.
        evaluation_train_obj_id_list = []
        count = 0
        while count < len(self.train_goals):
            if len(self.train_goals) - count < self.num_workers:
                evaluation_obj_id = self.path_collector.async_evaluate(
                    self.train_goals[count:])
                count = len(self.train_goals)
            else:
                evaluation_obj_id = self.path_collector.async_evaluate(
                    self.train_goals[count:count + self.num_workers])
                count += self.num_workers
            evaluation_train_obj_id_list.extend(evaluation_obj_id)

        assert len(evaluation_train_obj_id_list) == len(
            self.train_goals
        ), f'{len(evaluation_train_obj_id_list)}, {len(self.train_goals)}'

        # Submit asynchronous evaluation on the within-distribution goals.
        evaluation_wd_obj_id_list = []
        count = 0
        while count < len(self.wd_goals):
            if len(self.wd_goals) - count < self.num_workers:
                evaluation_obj_id = self.path_collector.async_evaluate(
                    self.wd_goals[count:])
                count = len(self.wd_goals)
            else:
                evaluation_obj_id = self.path_collector.async_evaluate(
                    self.wd_goals[count:count + self.num_workers])
                count += self.num_workers
            evaluation_wd_obj_id_list.extend(evaluation_obj_id)

        assert len(evaluation_wd_obj_id_list) == len(self.wd_goals)

        # evaluation_ood_obj_id_list = []
        # count = 0
        # while count < len(self.ood_goals):
        #     if len(self.ood_goals) - count < self.num_workers:
        #         evaluation_obj_id = self.path_collector.async_evaluate(
        #             self.ood_goals[count:])
        #         count = len(self.ood_goals)
        #     else:
        #         evaluation_obj_id = self.path_collector.async_evaluate(
        #             self.ood_goals[count:count + self.num_workers])
        #         count += self.num_workers
        #     evaluation_ood_obj_id_list.extend(evaluation_obj_id)

        # assert len(evaluation_ood_obj_id_list) == len(self.ood_goals)

        gt.stamp('set_up_evaluation', unique=False)

        train_batch_obj_id = self.train_buffer.sample_training_data(
            batch_idxes)

        for _ in trange(self.num_train_loops_per_epoch):
            train_raw_batch = ray.get(train_batch_obj_id)

            gt.stamp('sample_training_data', unique=False)

            # Kick off the data sampling job for the next loop here so
            # that it runs while we train on the current batch.
            train_batch_obj_id = self.train_buffer.sample_training_data(
                batch_idxes)

            gt.stamp('set_up_sampling', unique=False)

            train_data = self.construct_training_batch(train_raw_batch)

            gt.stamp('construct_training_batch', unique=False)

            self.policy.train(train_data)

            gt.stamp('training', unique=False)

        eval_train_returns = ray.get(evaluation_train_obj_id_list)

        self.avg_train_episode_returns = [
            item[0] for item in eval_train_returns
        ]
        self.final_train_achieved = [
            item[1] for item in eval_train_returns
        ]
        self.train_avg_returns = np.mean(self.avg_train_episode_returns)

        eval_wd_returns = ray.get(evaluation_wd_obj_id_list)

        self.avg_wd_episode_returns = [item[0] for item in eval_wd_returns]
        self.final_wd_achieved = [item[1] for item in eval_wd_returns]
        self.wd_avg_returns = np.mean(self.avg_wd_episode_returns)

        # eval_ood_returns = ray.get(evaluation_ood_obj_id_list)
        # self.avg_ood_episode_returns = [item[0] for item in eval_ood_returns]
        # self.final_ood_achieved = [item[1] for item in eval_ood_returns]
        # self.ood_avg_returns = np.mean(self.avg_ood_episode_returns)

        gt.stamp('evaluation', unique=False)

        self._end_epoch(epoch)
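
# The two while-loops above repeat the same chunking logic (and the
# commented-out OOD block would repeat it a third time). Below is a minimal
# sketch of a helper that could factor it out, assuming only what the code
# above already shows: path_collector.async_evaluate(goals) accepts an
# arbitrary slice of goals and returns a list of ray object ids. The helper
# name is hypothetical and not part of the original code.
def submit_async_evaluation(path_collector, goals, num_workers):
    obj_id_list = []
    # Submit the goals in chunks of at most num_workers, mirroring the
    # slicing done by the while-loops in _train above.
    for start in range(0, len(goals), num_workers):
        obj_id_list.extend(
            path_collector.async_evaluate(goals[start:start + num_workers]))
    assert len(obj_id_list) == len(goals)
    return obj_id_list

# Usage inside _train would then reduce each while-loop to a single call, e.g.
# evaluation_train_obj_id_list = submit_async_evaluation(
#     self.path_collector, self.train_goals, self.num_workers)
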
def _train(self):
    # Fill the replay buffer to a minimum before training starts.
    if self.min_num_steps_before_training > self.replay_buffer.num_steps_can_sample():
        init_expl_paths = self.expl_data_collector.collect_new_paths(
            self.trainer.policy,
            self.max_path_length,
            self.min_num_steps_before_training,
            discard_incomplete_paths=False,
        )
        self.replay_buffer.add_paths(init_expl_paths)
        self.expl_data_collector.end_epoch(-1)

    for epoch in gt.timed_for(
            trange(self._start_epoch, self.num_epochs),
            save_itrs=True,
    ):
        # To evaluate the policy remotely,
        # we're shipping the policy params to the remote evaluator.
        # This can be made more efficient,
        # but it is currently extremely cheap due to the small network size.
        pol_state_dict = ptu.state_dict_cpu(self.trainer.policy)

        remote_eval_obj_id = self.remote_eval_data_collector.async_collect_new_paths.remote(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
            deterministic_pol=True,
            pol_state_dict=pol_state_dict)

        gt.stamp('remote evaluation submit')

        for _ in range(self.num_train_loops_per_epoch):
            new_expl_paths = self.expl_data_collector.collect_new_paths(
                self.trainer.policy,
                self.max_path_length,
                self.num_expl_steps_per_train_loop,
                discard_incomplete_paths=False,
                optimistic_exploration=self.optimistic_exp_hp['should_use'],
                optimistic_exploration_kwargs=dict(
                    policy=self.trainer.policy,
                    qfs=[self.trainer.qf1, self.trainer.qf2],
                    hyper_params=self.optimistic_exp_hp))

            gt.stamp('exploration sampling', unique=False)

            self.replay_buffer.add_paths(new_expl_paths)

            gt.stamp('data storing', unique=False)

            for _ in range(self.num_trains_per_train_loop):
                train_data = self.replay_buffer.random_batch(self.batch_size)
                self.trainer.train(train_data)

            gt.stamp('training', unique=False)

        # Wait for eval to finish.
        ray.get([remote_eval_obj_id])

        gt.stamp('remote evaluation wait')

        self._end_epoch(epoch)
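
# A minimal, self-contained sketch of the submit/overlap/wait pattern used
# above, built only on the public ray API: calling .remote() on a ray actor
# method returns an object id immediately, so the evaluation rollouts run
# concurrently with local training until ray.get blocks on the result.
# ToyEvaluator below is an illustrative stand-in, not the repo's
# RemoteEvalDataCollector.
import time

import ray


@ray.remote
class ToyEvaluator:
    def collect(self, num_steps):
        # Stand-in for rolling out the (deterministic) policy remotely.
        time.sleep(1.0)
        return {'num_steps': num_steps, 'avg_return': 0.0}


if __name__ == '__main__':
    ray.init()
    evaluator = ToyEvaluator.remote()
    # Submit: returns an object id immediately, evaluation runs in the actor.
    eval_obj_id = evaluator.collect.remote(1000)
    # ... local exploration and gradient steps would run here ...
    # Wait: block until the remote evaluation has finished.
    eval_stats = ray.get(eval_obj_id)
    print(eval_stats)
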
def _train(self):
    gt.reset()

    # ----------------------- Imitation phase --------------------------------
    # Fill the replay buffer to a minimum before training starts.
    # Here, we use the tiMe policy as the sampler to collect
    # self.min_num_steps_before_training transitions, which are used to
    # adapt the policy so that it matches the super Q function.
    init_paths, inferred_mdp = self.tiMe_data_collector.collect_new_paths(
        self.max_path_length,
        self.min_num_steps_before_training,
        discard_incomplete_paths=False,
    )
    self.replay_buffer.add_paths(init_paths)
    self.expl_data_collector.end_epoch(-1)

    # Imitation: train the policy network using the collected transitions.
    for _ in trange(self.num_pre_train):
        train_data = self.replay_buffer.random_batch(self.batch_size)
        self.trainer.train_to_imitate(train_data)

    self.replay_buffer.reset()

    # -------------------------------------------------------------------------
    self.trainer.update_inferred_mdp_target_policy(inferred_mdp)
    # -------------------------------------------------------------------------

    init_expl_paths = self.expl_data_collector.collect_new_paths(
        self.trainer.policy,
        self.max_path_length,
        self.min_num_steps_before_training,
        discard_incomplete_paths=False,
    )
    self.replay_buffer.add_paths(init_expl_paths)
    self.expl_data_collector.end_epoch(-1)

    for epoch in gt.timed_for(
            trange(self._start_epoch, self.num_epochs),
            save_itrs=True,
    ):
        # To evaluate the policy remotely,
        # we're shipping the policy params to the remote evaluator.
        # This can be made more efficient,
        # but it is currently extremely cheap due to the small network size.
        pol_state_dict = ptu.state_dict_cpu(self.trainer.policy)

        remote_eval_obj_id = self.remote_eval_data_collector.async_collect_new_paths.remote(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
            deterministic_pol=True,
            pol_state_dict=pol_state_dict)

        gt.stamp('remote evaluation submit')

        for _ in range(self.num_train_loops_per_epoch):
            new_expl_paths = self.expl_data_collector.collect_new_paths(
                self.trainer.policy,
                self.max_path_length,
                self.num_expl_steps_per_train_loop,
                discard_incomplete_paths=False,
                optimistic_exploration=self.optimistic_exp_hp['should_use'],
                optimistic_exploration_kwargs=dict(
                    policy=self.trainer.policy,
                    qfs=[self.trainer.qf1, self.trainer.qf2],
                    hyper_params=self.optimistic_exp_hp))

            gt.stamp('exploration sampling', unique=False)

            self.replay_buffer.add_paths(new_expl_paths)

            gt.stamp('data storing', unique=False)

            for _ in range(self.num_trains_per_train_loop):
                # Alternate between updating the first Q function and
                # updating the second Q function together with the policy.
                train_data = self.replay_buffer.random_batch(self.batch_size)
                self.trainer.train_qf1(train_data)

                train_data = self.replay_buffer.random_batch(self.batch_size)
                self.trainer.train_qf2_policy(train_data)

            gt.stamp('training', unique=False)

        # Wait for eval to finish.
        ray.get([remote_eval_obj_id])

        gt.stamp('remote evaluation wait')

        self._end_epoch(epoch)
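
# train_to_imitate is repo-specific and not shown here. Below is a minimal
# sketch of one plausible behaviour-cloning update it could perform during
# the imitation phase: regressing the policy's action on the action stored in
# the sampled batch. The policy/optimizer arguments and batch keys are
# illustrative assumptions, not the original implementation.
import torch
import torch.nn.functional as F


def imitation_update(policy, policy_optimizer, batch):
    obs = torch.as_tensor(batch['observations'], dtype=torch.float32)
    target_actions = torch.as_tensor(batch['actions'], dtype=torch.float32)
    # Deterministic forward pass through the policy and a simple MSE
    # (behaviour-cloning) loss against the actions in the batch.
    predicted_actions = policy(obs)
    loss = F.mse_loss(predicted_actions, target_actions)
    policy_optimizer.zero_grad()
    loss.backward()
    policy_optimizer.step()
    return loss.item()

# Usage (hypothetical): inside the pre-training loop above, each call
# self.trainer.train_to_imitate(train_data) would roughly correspond to
# imitation_update(self.trainer.policy, policy_optimizer, train_data).
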