Пример #1
0
    def __init__(self, init_n_ego_dict):
        """Load per-task policies, create one display env per ego, pre-build TF graphs, then reset.

        Args:
            init_n_ego_dict: mapping keyed by ego ID, used to init traffic
                (mainly) and ego dynamics. Only its keys are read here; the
                whole mapping is forwarded to ``self.reset`` at the end.
        """
        self.TASK2MODEL = dict(left=LoadPolicy('../utils/models/left/experiment-2021-03-15-16-39-00', 180000),
                               straight=LoadPolicy('../utils/models/straight/experiment-2021-03-15-19-16-13', 175000),
                               right=LoadPolicy('../utils/models/right/experiment-2021-03-15-21-02-51', 195000),)
        self.n_ego_instance = {}
        self.n_ego_dynamics = {}
        self.n_ego_select_index = {}
        # Only the IDs are needed here: the first two characters of an ego ID
        # select its task through NAME2TASK.
        for egoID in init_n_ego_dict:
            self.n_ego_instance[egoID] = CrossroadEnd2end(training_task=NAME2TASK[egoID[:2]],
                                                          mode='testing',
                                                          multi_display=True)

        self.mpp = MultiPathGenerator()
        self.virtual_model = dict(left=EnvironmentModel(training_task='left', mode='selecting'),
                                  straight=EnvironmentModel(training_task='straight', mode='selecting'),
                                  right=EnvironmentModel(training_task='right', mode='selecting'))

        # ------------------build graph for tf.function in advance-----------------------
        # A throwaway, non-display env per task supplies real observations.
        # is_safe is called on a batch of one observation for i in 0..2.
        # The task's policy batch methods are then called on three tiled
        # copies of one observation.
        for task in ['left', 'straight', 'right']:
            env = CrossroadEnd2end(training_task=task, mode='testing', multi_display=False)
            try:
                for i in range(3):
                    obs = env.reset()
                    obs = tf.convert_to_tensor(obs[np.newaxis, :], dtype=tf.float32)
                    self.is_safe(obs, i, task)
                obs = env.reset()
                obs_with_specific_shape = np.tile(obs, (3, 1))
                self.TASK2MODEL[task].run_batch(obs_with_specific_shape)
                self.TASK2MODEL[task].obj_value_batch(obs_with_specific_shape)
            finally:
                # Release the warm-up env even if graph building raises.
                env.close()
        # ------------------build graph for tf.function in advance-----------------------
        self.reset(init_n_ego_dict)
Пример #2
0
 def __init__(self, task, train_exp_dir, ite, logdir=None):
     """Set up a single-task tester around a trained policy, then reset it.

     Args:
         task: task name; selects ``../utils/models/<task>/`` and is passed as
             the task to the env, the env model and the path generator.
         train_exp_dir: experiment directory inside that task folder.
         ite: iteration forwarded to ``LoadPolicy``.
         logdir: optional directory; when given, the three arguments above are
             saved to ``<logdir>/config.json``.
     """
     self.task = task
     checkpoint_dir = '../utils/models/{}/{}'.format(task, train_exp_dir)
     self.policy = LoadPolicy(checkpoint_dir, ite)
     self.env = CrossroadEnd2end(training_task=task, mode='testing')
     self.model = EnvironmentModel(task, mode='selecting')
     self.recorder = Recorder()
     self.episode_counter, self.step_counter = -1, -1
     self.obs = None
     self.stg = MultiPathGenerator()
     self.step_timer, self.ss_timer = TimerStat(), TimerStat()
     self.logdir = logdir
     if logdir is not None:
         run_config = {'task': task, 'train_exp_dir': train_exp_dir, 'ite': ite}
         with open(logdir + '/config.json', 'w', encoding='utf-8') as config_file:
             json.dump(run_config, config_file, ensure_ascii=False, indent=4)
     self.fig = plt.figure(figsize=(8, 8))
     plt.ion()  # matplotlib interactive mode
     self.hist_posi = []
     self.old_index = 0
     self.path_list = self.stg.generate_path(task)

     # Build the tf.function graphs in advance.
     # First, is_safe on a batch of one observation for indices 0, 1 and 2.
     # Then the policy's batch methods on three tiled copies of one observation.
     for warmup_index in range(3):
         single_obs = self.env.reset()
         batch_of_one = tf.convert_to_tensor(single_obs[np.newaxis, :], dtype=tf.float32)
         self.is_safe(batch_of_one, warmup_index)
     batch_of_three = np.tile(self.env.reset(), (3, 1))
     self.policy.run_batch(batch_of_three)
     self.policy.obj_value_batch(batch_of_three)

     self.reset()
Пример #3
0
    def __init__(self, init_n_ego_dict):
        """Create one display env per ego, load per-task policies and env models, then reset.

        Args:
            init_n_ego_dict: mapping keyed by ego ID, used to init traffic
                (mainly) and ego dynamics. Only its keys are read here; the
                whole mapping is forwarded to ``self.reset`` at the end.
        """
        self.TASK2MODEL = dict(
            left=LoadPolicy('../utils/models/left', 100000),
            straight=LoadPolicy('../utils/models/straight', 95000),
            right=LoadPolicy('../utils/models/right', 145000),
        )
        self.n_ego_instance = {}
        self.n_ego_dynamics = {}
        self.n_ego_select_index = {}
        # Only the IDs are needed here: the first two characters of an ego ID
        # select its task through NAME2TASK.
        for egoID in init_n_ego_dict:
            self.n_ego_instance[egoID] = CrossroadEnd2end(
                training_task=NAME2TASK[egoID[:2]], display=True)

        self.mpp = MultiPathGenerator()
        self.virtual_model = dict(
            left=EnvironmentModel(training_task='left', mode='selecting'),
            straight=EnvironmentModel(training_task='straight', mode='selecting'),
            right=EnvironmentModel(training_task='right', mode='selecting'),
        )

        self.reset(init_n_ego_dict)
Пример #4
0
 def __init__(self, exp_dir, iter):
     """Load a trained ``Policy4Toyota`` and its preprocessor from an experiment directory.

     Args:
         exp_dir: directory containing ``config.json`` (training arguments)
             and a ``models`` subdirectory with the saved weights.
         iter: checkpoint iteration passed to ``Policy4Toyota.load_weights``.
             The name shadows the builtin ``iter``; it is kept for caller
             compatibility.
     """
     model_dir = exp_dir + '/models'
     # Use a context manager so the file handle is closed. Decode explicitly
     # as UTF-8 (the JSON interchange encoding) instead of the locale default.
     with open(exp_dir + '/config.json', encoding='utf-8') as f:
         params = json.load(f)
     # Every config key becomes a "-<key>" option defaulting to its saved value.
     # A value given on the command line overrides it, as a string.
     parser = argparse.ArgumentParser()
     for key, val in params.items():
         parser.add_argument("-" + key, default=val)
     # parse_known_args returns the same namespace parse_args would whenever
     # parse_args succeeds. Unlike parse_args, it does not abort (SystemExit)
     # when the hosting script has its own, unrelated command-line flags.
     self.args, _ = parser.parse_known_args()
     # NOTE(review): this env is only used for one reset below and is never
     # closed; confirm whether CrossroadEnd2end.close() is safe to call here.
     env = CrossroadEnd2end(training_task=self.args.env_kwargs_training_task,
                            num_future_data=self.args.env_kwargs_num_future_data)
     self.policy = Policy4Toyota(self.args)
     self.policy.load_weights(model_dir, iter)
     self.preprocessor = Preprocessor((self.args.obs_dim,), self.args.obs_preprocess_type, self.args.reward_preprocess_type,
                                      self.args.obs_scale, self.args.reward_scale, self.args.reward_shift,
                                      gamma=self.args.gamma)
     # NOTE: preprocessor parameters are not loaded from disk here. The
     # original carried a commented-out load_params call with an undefined path.
     # Call both batch methods once on a batch of one real observation
     # (presumably to build their TF graphs ahead of use; TODO confirm).
     init_obs = env.reset()
     self.run_batch(init_obs[np.newaxis, :])
     self.obj_value_batch(init_obs[np.newaxis, :])
Пример #5
0
    def __init__(self, task, model_root='G:\\env_build\\utils\\models'):
        """Set up the policy, env, env model and helpers for one task.

        Args:
            task: one of ``'left'``, ``'right'`` or ``'straight'``.
            model_root: directory holding one model folder per task. The
                default is the original hard-coded location, so existing
                callers load exactly the same checkpoints.

        Raises:
            ValueError: if ``task`` is not one of the supported tasks.
                Previously an unknown task silently left ``self.policy`` unset.
        """
        self.task = task
        # Checkpoint iteration loaded for each supported task.
        task2ite = dict(left=100000, right=145000, straight=95000)
        if self.task not in task2ite:
            raise ValueError('Unknown task {!r}; expected one of {}'.format(
                self.task, sorted(task2ite)))
        self.policy = LoadPolicy('{}\\{}'.format(model_root, self.task), task2ite[self.task])

        self.horizon = 25
        self.num_future_data = 0
        self.env = CrossroadEnd2end(training_task=self.task, num_future_data=self.num_future_data)
        self.model = EnvironmentModel(self.task)
        self.obs = self.env.reset()  # initial observation
        self.stg = StaticTrajectoryGenerator_origin(mode='static_traj')
        self.data2plot = []
        self.mpc_cal_timer = TimerStat()
        self.adp_cal_timer = TimerStat()
        self.recorder = Recorder()
Пример #6
0
 def __init__(self, task, logdir=None):
     """Set up a single-task tester with a fixed checkpoint per task, then reset it.

     Args:
         task: one of ``'left'``, ``'right'`` or ``'straight'``. It selects
             ``../utils/models/<task>`` and is passed as the task to the env
             and the env model.
         logdir: optional directory, stored on ``self.logdir``.

     Raises:
         ValueError: if ``task`` is not one of the supported tasks.
             Previously an unknown task silently left ``self.policy`` unset.
     """
     self.task = task
     # Checkpoint iteration loaded for each supported task.
     task2ite = dict(left=100000, right=145000, straight=95000)
     if self.task not in task2ite:
         raise ValueError('Unknown task {!r}; expected one of {}'.format(
             self.task, sorted(task2ite)))
     self.policy = LoadPolicy('../utils/models/{}'.format(self.task), task2ite[self.task])
     self.env = CrossroadEnd2end(training_task=self.task)
     self.model = EnvironmentModel(self.task, mode='selecting')
     self.recorder = Recorder()
     self.episode_counter = -1
     self.step_counter = -1
     self.obs = None
     self.stg = None
     self.step_timer = TimerStat()
     self.ss_timer = TimerStat()
     self.logdir = logdir
     self.fig = plt.figure(figsize=(8, 8))
     plt.ion()  # matplotlib interactive mode
     self.hist_posi = []
     self.reset()
Пример #7
0
 def __init__(self, task, train_exp_dir, ite, logdir=None):
     """Create a single-task tester for a trained policy, then reset it.

     Args:
         task: task name; selects ``../utils/models/<task>/`` and is passed as
             the task to the env and the env model.
         train_exp_dir: experiment directory inside that task folder.
         ite: iteration forwarded to ``LoadPolicy``.
         logdir: optional directory; when given, the three arguments above are
             saved to ``<logdir>/config.json``.
     """
     self.task = task
     checkpoint_dir = '../utils/models/{}/{}'.format(task, train_exp_dir)
     self.policy = LoadPolicy(checkpoint_dir, ite)
     self.env = CrossroadEnd2end(training_task=task)
     self.model = EnvironmentModel(task, mode='selecting')
     self.recorder = Recorder()
     self.episode_counter, self.step_counter = -1, -1
     self.obs, self.stg = None, None
     self.step_timer, self.ss_timer = TimerStat(), TimerStat()
     self.logdir = logdir
     if logdir is not None:
         run_config = {'task': task, 'train_exp_dir': train_exp_dir, 'ite': ite}
         with open(logdir + '/config.json', 'w', encoding='utf-8') as config_file:
             json.dump(run_config, config_file, ensure_ascii=False, indent=4)
     self.fig = plt.figure(figsize=(8, 8))
     plt.ion()  # matplotlib interactive mode
     self.hist_posi = []
     self.reset()