# Example 1
# 0
    def sample_meta_test(self,
                         policy,
                         task,
                         batch_id,
                         params=None,
                         target_params=None,
                         old_episodes=None):
        """Run one meta-test sampling pass for *task*, adapting params online.

        Seeds the worker queue, rolls out episodes from ``self.envs`` with
        *policy*, and — after the ``UPDATE_START`` warm-up — periodically
        fits the policy and refreshes *params*. Every ``TEST_PERIOD`` steps
        the current parameters are evaluated via ``single_test_sample`` and
        pickled under ``PATH_TO_MODEL``.

        Args:
            policy: agent exposing ``choose_action`` / ``fit`` /
                ``update_params`` / ``load_params`` / ``decay_epsilon``.
            task: task identifier, forwarded to ``single_test_sample`` and
                ``write_summary``.
            batch_id: index of the current sampling batch.
            params: optional parameter set loaded into *policy* up front.
            target_params: optional target-network parameters.
            old_episodes: previous episodes used to seed the replay memory.

        Returns:
            Tuple ``(params, target_params, episodes)`` after the rollout.
        """
        # One work item per batch element, then one sentinel (None) per
        # worker so each worker knows when to stop.
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)
        episodes = BatchEpisodes(dic_agent_conf=self.dic_agent_conf,
                                 old_episodes=old_episodes)
        observations, batch_ids = self.envs.reset()
        dones = [False]
        if params:  # todo precise load parameter logic
            policy.load_params(params)

        # Keep stepping until every env is done AND the work queue is drained.
        while (not all(dones)) or (not self.queue.empty()):
            actions = policy.choose_action(observations)
            # for multi_intersection: envs expect a column vector of actions
            actions = np.reshape(actions, (-1, 1))
            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
                actions)
            episodes.append(observations, actions, new_observations, rewards,
                            batch_ids)
            observations, batch_ids = new_observations, new_batch_ids

            # Periodic parameter update once the warm-up phase is over.
            if self.step > self.dic_agent_conf[
                    'UPDATE_START'] and self.step % self.dic_agent_conf[
                        'UPDATE_PERIOD'] == 0:
                if len(episodes) > self.dic_agent_conf['MAX_MEMORY_LEN']:
                    episodes.forget()

                policy.fit(episodes,
                           params=params,
                           target_params=target_params)
                # Sample a random minibatch of episode indices for the
                # parameter update (capped at the current memory size).
                sample_size = min(self.dic_agent_conf['SAMPLE_SIZE'],
                                  len(episodes))
                slice_index = random.sample(range(len(episodes)), sample_size)
                params = policy.update_params(episodes,
                                              params=copy.deepcopy(params),
                                              lr_step=self.lr_step,
                                              slice_index=slice_index)

                policy.load_params(params)

                self.lr_step += 1
                self.target_step += 1
                # Refresh the target network every UPDATE_Q_BAR_FREQ updates.
                if self.target_step == self.dic_agent_conf[
                        'UPDATE_Q_BAR_FREQ']:
                    target_params = params
                    self.target_step = 0

            # Periodic evaluation + checkpoint of the current parameters.
            if self.step > self.dic_agent_conf[
                    'UPDATE_START'] and self.step % self.dic_agent_conf[
                        'TEST_PERIOD'] == 0:
                self.single_test_sample(policy,
                                        task,
                                        self.test_step,
                                        params=params)
                # Context manager guarantees the checkpoint file is closed
                # even if pickling raises (original leaked the handle).
                params_path = os.path.join(
                    self.dic_path['PATH_TO_MODEL'],
                    'params' + "_" + str(self.test_step) + ".pkl")
                with open(params_path, 'wb') as f:
                    pickle.dump(params, f)
                write_summary(self.dic_path, task,
                              self.dic_traffic_env_conf["EPISODE_LEN"],
                              batch_id)

                self.test_step += 1
            self.step += 1

        policy.decay_epsilon(batch_id)
        self.envs.bulk_log()
        return params, target_params, episodes
# Example 2
# 0
    def sample_period(self,
                      policy,
                      task,
                      batch_id,
                      params=None,
                      target_params=None,
                      old_episodes=None):
        """Run one periodic sampling/training pass for *task*.

        Seeds the worker queue, rolls out episodes from ``self.envs`` with
        *policy*, and — after the ``UPDATE_START`` warm-up — periodically
        fits the policy and refreshes *params*. Every ``TEST_PERIOD`` steps
        the current parameters are evaluated via ``self.test`` and pickled
        under ``PATH_TO_MODEL``.

        Args:
            policy: agent exposing ``choose_action`` / ``fit`` /
                ``update_params`` / ``load_params``.
            task: task identifier, forwarded to ``self.test``.
            batch_id: index of the current sampling batch (unused here
                beyond the signature shared with the other samplers).
            params: optional parameter set loaded into *policy* up front.
            target_params: optional target-network parameters.
            old_episodes: previous episodes used to seed the replay memory.

        Returns:
            Tuple ``(params, target_params, episodes)`` after the rollout.
        """
        # One work item per batch element, then one sentinel (None) per
        # worker so each worker knows when to stop.
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)
        episodes = BatchEpisodes(
            batch_size=self.batch_size,
            dic_traffic_env_conf=self.dic_traffic_env_conf,
            dic_agent_conf=self.dic_agent_conf,
            old_episodes=old_episodes)
        observations, batch_ids = self.envs.reset()
        dones = [False]
        if params:  # todo precise load parameter logic
            policy.load_params(params)

        # Keep stepping until every env is done AND the work queue is drained.
        while (not all(dones)) or (not self.queue.empty()):

            if self.dic_traffic_env_conf['MODEL_NAME'] == 'MetaDQN':
                # NOTE(review): the original passed an undefined name
                # ``task_type`` here (NameError at runtime); the *task*
                # argument looks like the intended value — confirm against
                # MetaDQN.choose_action.
                actions = policy.choose_action(observations,
                                               task_type=task)
            else:
                actions = policy.choose_action(observations)
            # for multi_intersection: envs expect a column vector of actions
            actions = np.reshape(actions, (-1, 1))
            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
                actions)
            episodes.append(observations, actions, new_observations, rewards,
                            batch_ids)
            observations, batch_ids = new_observations, new_batch_ids

            # Periodic parameter update once the warm-up phase is over.
            if self.step > self.dic_agent_conf[
                    'UPDATE_START'] and self.step % self.dic_agent_conf[
                        'UPDATE_PERIOD'] == 0:
                if len(episodes) > self.dic_agent_conf['MAX_MEMORY_LEN']:
                    #TODO
                    episodes.forget()

                policy.fit(episodes,
                           params=params,
                           target_params=target_params)
                params = policy.update_params(episodes,
                                              params=copy.deepcopy(params),
                                              lr_step=self.lr_step)
                policy.load_params(params)

                self.lr_step += 1
                self.target_step += 1
                # Refresh the target network every UPDATE_Q_BAR_FREQ updates.
                if self.target_step == self.dic_agent_conf[
                        'UPDATE_Q_BAR_FREQ']:
                    target_params = params
                    self.target_step = 0

            # Periodic evaluation + checkpoint of the current parameters.
            if self.step > self.dic_agent_conf[
                    'UPDATE_START'] and self.step % self.dic_agent_conf[
                        'TEST_PERIOD'] == 0:
                self.test(policy, task, self.test_step, params=params)
                # Context manager guarantees the checkpoint file is closed
                # even if pickling raises (original leaked the handle).
                params_path = os.path.join(
                    self.dic_path['PATH_TO_MODEL'],
                    'params' + "_" + str(self.test_step) + ".pkl")
                with open(params_path, 'wb') as f:
                    pickle.dump(params, f)

                self.test_step += 1
            self.step += 1

        return params, target_params, episodes