# (module-level imports assumed: os, copy, random, pickle, numpy as np,
#  BatchEpisodes, write_summary)
def sample_meta_test(self, policy, task, batch_id, params=None,
                     target_params=None, old_episodes=None):
    # Seed the task queue: one slot per episode in the batch, then one
    # sentinel (None) per worker to signal shutdown.
    for i in range(self.batch_size):
        self.queue.put(i)
    for _ in range(self.num_workers):
        self.queue.put(None)

    episodes = BatchEpisodes(dic_agent_conf=self.dic_agent_conf,
                             old_episodes=old_episodes)
    observations, batch_ids = self.envs.reset()
    dones = [False]
    if params:  # TODO: precise load-parameter logic
        policy.load_params(params)

    while (not all(dones)) or (not self.queue.empty()):
        actions = policy.choose_action(observations)
        # Reshape for multi-intersection environments: one action per intersection.
        actions = np.reshape(actions, (-1, 1))
        new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(actions)
        episodes.append(observations, actions, new_observations, rewards, batch_ids)
        observations, batch_ids = new_observations, new_batch_ids

        # Periodic update of the adapted parameters.
        if (self.step > self.dic_agent_conf['UPDATE_START']
                and self.step % self.dic_agent_conf['UPDATE_PERIOD'] == 0):
            if len(episodes) > self.dic_agent_conf['MAX_MEMORY_LEN']:
                episodes.forget()

            policy.fit(episodes, params=params, target_params=target_params)
            sample_size = min(self.dic_agent_conf['SAMPLE_SIZE'], len(episodes))
            slice_index = random.sample(range(len(episodes)), sample_size)
            params = policy.update_params(episodes,
                                          params=copy.deepcopy(params),
                                          lr_step=self.lr_step,
                                          slice_index=slice_index)
            policy.load_params(params)

            self.lr_step += 1
            self.target_step += 1
            # Periodically sync the target-network parameters.
            if self.target_step == self.dic_agent_conf['UPDATE_Q_BAR_FREQ']:
                target_params = params
                self.target_step = 0

        # Periodic evaluation, checkpointing, and summary logging.
        if (self.step > self.dic_agent_conf['UPDATE_START']
                and self.step % self.dic_agent_conf['TEST_PERIOD'] == 0):
            self.single_test_sample(policy, task, self.test_step, params=params)
            pickle.dump(params,
                        open(os.path.join(self.dic_path['PATH_TO_MODEL'],
                                          'params' + "_" + str(self.test_step) + ".pkl"),
                             'wb'))
            write_summary(self.dic_path, task,
                          self.dic_traffic_env_conf["EPISODE_LEN"], batch_id)
            self.test_step += 1

        self.step += 1

    policy.decay_epsilon(batch_id)
    self.envs.bulk_log()
    return params, target_params, episodes
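# Hypothetical usage sketch (not from the source): how an outer meta-test loop
# might drive sample_meta_test, threading params / target_params / episodes
# across batches so the replay memory and adapted parameters persist.
# `sampler`, `meta_params`, and `num_batches` are assumed names.
def run_meta_test(sampler, policy, task, meta_params, num_batches):
    params, target_params, episodes = meta_params, meta_params, None
    for batch_id in range(num_batches):
        params, target_params, episodes = sampler.sample_meta_test(
            policy, task, batch_id,
            params=params,
            target_params=target_params,
            old_episodes=episodes)  # reuse accumulated experience
    return params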
def sample_period(self, policy, task, batch_id, params=None,
                  target_params=None, old_episodes=None):
    # Seed the task queue: one slot per episode in the batch, then one
    # sentinel (None) per worker to signal shutdown.
    for i in range(self.batch_size):
        self.queue.put(i)
    for _ in range(self.num_workers):
        self.queue.put(None)

    episodes = BatchEpisodes(batch_size=self.batch_size,
                             dic_traffic_env_conf=self.dic_traffic_env_conf,
                             dic_agent_conf=self.dic_agent_conf,
                             old_episodes=old_episodes)
    observations, batch_ids = self.envs.reset()
    dones = [False]
    if params:  # TODO: precise load-parameter logic
        policy.load_params(params)

    while (not all(dones)) or (not self.queue.empty()):
        if self.dic_traffic_env_conf['MODEL_NAME'] == 'MetaDQN':
            # MetaDQN conditions action selection on the task; the original code
            # passed an undefined `task_type` here, so using `task` is an assumption.
            actions = policy.choose_action(observations, task_type=task)
        else:
            actions = policy.choose_action(observations)
        # Reshape for multi-intersection environments: one action per intersection.
        actions = np.reshape(actions, (-1, 1))
        new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(actions)
        episodes.append(observations, actions, new_observations, rewards, batch_ids)
        observations, batch_ids = new_observations, new_batch_ids

        # Periodic parameter update.
        if (self.step > self.dic_agent_conf['UPDATE_START']
                and self.step % self.dic_agent_conf['UPDATE_PERIOD'] == 0):
            if len(episodes) > self.dic_agent_conf['MAX_MEMORY_LEN']:
                # TODO: smarter memory management than dropping the oldest samples
                episodes.forget()

            policy.fit(episodes, params=params, target_params=target_params)
            params = policy.update_params(episodes,
                                          params=copy.deepcopy(params),
                                          lr_step=self.lr_step)
            policy.load_params(params)

            self.lr_step += 1
            self.target_step += 1
            # Periodically sync the target-network parameters.
            if self.target_step == self.dic_agent_conf['UPDATE_Q_BAR_FREQ']:
                target_params = params
                self.target_step = 0

        # Periodic evaluation and checkpointing.
        if (self.step > self.dic_agent_conf['UPDATE_START']
                and self.step % self.dic_agent_conf['TEST_PERIOD'] == 0):
            self.test(policy, task, self.test_step, params=params)
            pickle.dump(params,
                        open(os.path.join(self.dic_path['PATH_TO_MODEL'],
                                          'params' + "_" + str(self.test_step) + ".pkl"),
                             'wb'))
            self.test_step += 1

        self.step += 1

    return params, target_params, episodes
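# Minimal sketch (an assumption, not the repository's actual class) of the
# replay-memory interface that the two samplers above rely on: a constructor
# that can absorb old episodes, append(), __len__(), and forget(). The real
# BatchEpisodes stores richer per-sample data; fields here are illustrative.
class BatchEpisodesSketch:
    def __init__(self, dic_agent_conf, batch_size=None,
                 dic_traffic_env_conf=None, old_episodes=None):
        self.dic_agent_conf = dic_agent_conf
        self.total_samples = list(old_episodes.total_samples) if old_episodes else []

    def append(self, observations, actions, new_observations, rewards, batch_ids):
        # One transition per environment in the batch.
        for o, a, o2, r, b in zip(observations, actions, new_observations,
                                  rewards, batch_ids):
            self.total_samples.append((o, a, o2, r, b))

    def __len__(self):
        return len(self.total_samples)

    def forget(self):
        # Keep only the most recent MAX_MEMORY_LEN transitions.
        self.total_samples = self.total_samples[-self.dic_agent_conf['MAX_MEMORY_LEN']:]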