Пример #1
0
 def __init__(self, task, train_exp_dir, ite, logdir=None):
     self.task = task
     self.policy = LoadPolicy('../utils/models/{}/{}'.format(task, train_exp_dir), ite)
     self.env = CrossroadEnd2end(training_task=self.task, mode='testing')
     self.model = EnvironmentModel(self.task, mode='selecting')
     self.recorder = Recorder()
     self.episode_counter = -1
     self.step_counter = -1
     self.obs = None
     self.stg = MultiPathGenerator()
     self.step_timer = TimerStat()
     self.ss_timer = TimerStat()
     self.logdir = logdir
     if self.logdir is not None:
         config = dict(task=task, train_exp_dir=train_exp_dir, ite=ite)
         with open(self.logdir + '/config.json', 'w', encoding='utf-8') as f:
             json.dump(config, f, ensure_ascii=False, indent=4)
     self.fig = plt.figure(figsize=(8, 8))
     plt.ion()
     self.hist_posi = []
     self.old_index = 0
     self.path_list = self.stg.generate_path(self.task)
     # ------------------build graph for tf.function in advance-----------------------
     for i in range(3):
         obs = self.env.reset()
         obs = tf.convert_to_tensor(obs[np.newaxis, :], dtype=tf.float32)
         self.is_safe(obs, i)
     obs = self.env.reset()
     obs_with_specific_shape = np.tile(obs, (3, 1))
     self.policy.run_batch(obs_with_specific_shape)
     self.policy.obj_value_batch(obs_with_specific_shape)
     # ------------------build graph for tf.function in advance-----------------------
     self.reset()
Пример #2
0
    def __init__(
        self, init_n_ego_dict
    ):  # init_n_ego_dict is used to init traffic (mainly) and ego dynamics
        self.TASK2MODEL = dict(
            left=LoadPolicy('../utils/models/left', 100000),
            straight=LoadPolicy('../utils/models/straight', 95000),
            right=LoadPolicy('../utils/models/right', 145000),
        )
        self.n_ego_instance = {}
        self.n_ego_dynamics = {}
        self.n_ego_select_index = {}
        for egoID, ego_dict in init_n_ego_dict.items():
            self.n_ego_instance[egoID] = CrossroadEnd2end(
                training_task=NAME2TASK[egoID[:2]], display=True)

        self.mpp = MultiPathGenerator()
        self.virtual_model = dict(left=EnvironmentModel(training_task='left',
                                                        mode='selecting'),
                                  straight=EnvironmentModel(
                                      training_task='straight',
                                      mode='selecting'),
                                  right=EnvironmentModel(training_task='right',
                                                         mode='selecting'))

        self.reset(init_n_ego_dict)
Пример #3
0
    def __init__(self, init_n_ego_dict):  # init_n_ego_dict is used to init traffic (mainly) and ego dynamics
        self.TASK2MODEL = dict(left=LoadPolicy('../utils/models/left/experiment-2021-03-15-16-39-00', 180000),
                               straight=LoadPolicy('../utils/models/straight/experiment-2021-03-15-19-16-13', 175000),
                               right=LoadPolicy('../utils/models/right/experiment-2021-03-15-21-02-51', 195000),)
        self.n_ego_instance = {}
        self.n_ego_dynamics = {}
        self.n_ego_select_index = {}
        for egoID, ego_dict in init_n_ego_dict.items():
            self.n_ego_instance[egoID] = CrossroadEnd2end(training_task=NAME2TASK[egoID[:2]],
                                                          mode='testing',
                                                          multi_display=True)

        self.mpp = MultiPathGenerator()
        self.virtual_model = dict(left=EnvironmentModel(training_task='left', mode='selecting'),
                                  straight=EnvironmentModel(training_task='straight', mode='selecting'),
                                  right=EnvironmentModel(training_task='right', mode='selecting'))

        # ------------------build graph for tf.function in advance-----------------------
        for task in ['left', 'straight', 'right']:
            env = CrossroadEnd2end(training_task=task, mode='testing', multi_display=False)
            for i in range(3):
                obs = env.reset()
                obs = tf.convert_to_tensor(obs[np.newaxis, :], dtype=tf.float32)
                self.is_safe(obs, i, task)
            obs = env.reset()
            obs_with_specific_shape = np.tile(obs, (3, 1))
            self.TASK2MODEL[task].run_batch(obs_with_specific_shape)
            self.TASK2MODEL[task].obj_value_batch(obs_with_specific_shape)
            env.close()
        # ------------------build graph for tf.function in advance-----------------------
        self.reset(init_n_ego_dict)
Пример #4
0
    def __init__(
            self,
            training_task,  # 'left', 'straight', 'right'
            num_future_data=0,
            display=False,
            **kwargs):
        metadata = {'render.modes': ['human']}
        self.dynamics = VehicleDynamics()
        self.interested_vehs = None
        self.training_task = training_task
        self.ref_path = ReferencePath(self.training_task, **kwargs)
        self.detected_vehicles = None
        self.all_vehicles = None
        self.ego_dynamics = None
        self.num_future_data = num_future_data
        self.env_model = EnvironmentModel(training_task, num_future_data)
        self.init_state = {}
        self.action_number = 2
        self.exp_v = EXPECTED_V  #TODO: temp
        self.ego_l, self.ego_w = L, W
        self.action_space = gym.spaces.Box(low=-1,
                                           high=1,
                                           shape=(self.action_number, ),
                                           dtype=np.float32)

        self.seed()
        self.v_light = None
        self.step_length = 100  # ms

        self.step_time = self.step_length / 1000.0
        self.init_state = self._reset_init_state()
        self.obs = None
        self.action = None
        self.veh_mode_dict = VEHICLE_MODE_DICT[self.training_task]
        self.veh_num = VEH_NUM[self.training_task]
        self.virtual_red_light_vehicle = False

        self.done_type = 'not_done_yet'
        self.reward_info = None
        self.ego_info_dim = None
        self.per_tracking_info_dim = None
        self.per_veh_info_dim = None
        if not display:
            self.traffic = Traffic(self.step_length,
                                   mode='training',
                                   init_n_ego_dict=self.init_state,
                                   training_task=self.training_task)
            self.reset()
            action = self.action_space.sample()
            observation, _reward, done, _info = self.step(action)
            self._set_observation_space(observation)
            plt.ion()
Пример #5
0
 def __init__(self, task, logdir=None):
     self.task = task
     if self.task == 'left':
         self.policy = LoadPolicy('../utils/models/left', 100000)
     elif self.task == 'right':
         self.policy = LoadPolicy('../utils/models/right', 145000)
     elif self.task == 'straight':
         self.policy = LoadPolicy('../utils/models/straight', 95000)
     self.env = CrossroadEnd2end(training_task=self.task)
     self.model = EnvironmentModel(self.task, mode='selecting')
     self.recorder = Recorder()
     self.episode_counter = -1
     self.step_counter = -1
     self.obs = None
     self.stg = None
     self.step_timer = TimerStat()
     self.ss_timer = TimerStat()
     self.logdir = logdir
     self.fig = plt.figure(figsize=(8, 8))
     plt.ion()
     self.hist_posi = []
     self.reset()
Пример #6
0
 def __init__(self, task, train_exp_dir, ite, logdir=None):
     self.task = task
     self.policy = LoadPolicy(
         '../utils/models/{}/{}'.format(task, train_exp_dir), ite)
     self.env = CrossroadEnd2end(training_task=self.task)
     self.model = EnvironmentModel(self.task, mode='selecting')
     self.recorder = Recorder()
     self.episode_counter = -1
     self.step_counter = -1
     self.obs = None
     self.stg = None
     self.step_timer = TimerStat()
     self.ss_timer = TimerStat()
     self.logdir = logdir
     if self.logdir is not None:
         config = dict(task=task, train_exp_dir=train_exp_dir, ite=ite)
         with open(self.logdir + '/config.json', 'w',
                   encoding='utf-8') as f:
             json.dump(config, f, ensure_ascii=False, indent=4)
     self.fig = plt.figure(figsize=(8, 8))
     plt.ion()
     self.hist_posi = []
     self.reset()
Пример #7
0
    def __init__(self, task):
        self.task = task
        if self.task == 'left':
            self.policy = LoadPolicy('G:\\env_build\\utils\\models\\left', 100000)
        elif self.task == 'right':
            self.policy = LoadPolicy('G:\\env_build\\utils\\models\\right', 145000)
        elif self.task == 'straight':
            self.policy = LoadPolicy('G:\\env_build\\utils\\models\\straight', 95000)

        self.horizon = 25
        self.num_future_data = 0
        self.env = CrossroadEnd2end(training_task=self.task, num_future_data=self.num_future_data)
        self.model = EnvironmentModel(self.task)
        self.obs = self.env.reset()
        self.stg = StaticTrajectoryGenerator_origin(mode='static_traj')
        self.data2plot = []
        self.mpc_cal_timer = TimerStat()
        self.adp_cal_timer = TimerStat()
        self.recorder = Recorder()
Пример #8
0
class HierarchicalDecision(object):
    def __init__(self, task, train_exp_dir, ite, logdir=None):
        self.task = task
        self.policy = LoadPolicy('../utils/models/{}/{}'.format(task, train_exp_dir), ite)
        self.env = CrossroadEnd2end(training_task=self.task, mode='testing')
        self.model = EnvironmentModel(self.task, mode='selecting')
        self.recorder = Recorder()
        self.episode_counter = -1
        self.step_counter = -1
        self.obs = None
        self.stg = MultiPathGenerator()
        self.step_timer = TimerStat()
        self.ss_timer = TimerStat()
        self.logdir = logdir
        if self.logdir is not None:
            config = dict(task=task, train_exp_dir=train_exp_dir, ite=ite)
            with open(self.logdir + '/config.json', 'w', encoding='utf-8') as f:
                json.dump(config, f, ensure_ascii=False, indent=4)
        self.fig = plt.figure(figsize=(8, 8))
        plt.ion()
        self.hist_posi = []
        self.old_index = 0
        self.path_list = self.stg.generate_path(self.task)
        # ------------------build graph for tf.function in advance-----------------------
        for i in range(3):
            obs = self.env.reset()
            obs = tf.convert_to_tensor(obs[np.newaxis, :], dtype=tf.float32)
            self.is_safe(obs, i)
        obs = self.env.reset()
        obs_with_specific_shape = np.tile(obs, (3, 1))
        self.policy.run_batch(obs_with_specific_shape)
        self.policy.obj_value_batch(obs_with_specific_shape)
        # ------------------build graph for tf.function in advance-----------------------
        self.reset()

    def reset(self,):
        self.obs = self.env.reset()
        self.recorder.reset()
        self.old_index = 0
        self.hist_posi = []
        if self.logdir is not None:
            self.episode_counter += 1
            os.makedirs(self.logdir + '/episode{}/figs'.format(self.episode_counter))
            self.step_counter = -1
            self.recorder.save(self.logdir)
            if self.episode_counter >= 1:
                select_and_rename_snapshots_of_an_episode(self.logdir, self.episode_counter-1, 12)
                self.recorder.plot_and_save_ith_episode_curves(self.episode_counter-1,
                                                               self.logdir + '/episode{}/figs'.format(self.episode_counter-1),
                                                               isshow=False)
        return self.obs

    # @tf.function
    # def is_safe(self, obs, path_index):
    #     self.model.ref_path.set_path(path_index)
    #     action = self.policy.run_batch(obs)
    #     veh2veh4real = self.model.ss(obs, action, lam=0.1)
    #     return False if veh2veh4real[0] > 0 else True

    @tf.function
    def is_safe(self, obs, path_index):
        self.model.add_traj(obs, path_index)
        punish = 0.
        for step in range(5):
            action = self.policy.run_batch(obs)
            obs, _, _, _, veh2veh4real, _ = self.model.rollout_out(action)
            punish += veh2veh4real[0]
        return False if punish > 0 else True

    def safe_shield(self, real_obs, path_index):
        action_safe_set = [[[0., -1.]]]
        real_obs = tf.convert_to_tensor(real_obs[np.newaxis, :], dtype=tf.float32)
        obs = real_obs
        if not self.is_safe(obs, path_index):
            print('SAFETY SHIELD STARTED!')
            return np.array(action_safe_set[0], dtype=np.float32).squeeze(0), True
        else:
            return self.policy.run_batch(real_obs).numpy()[0], False

    def step(self):
        self.step_counter += 1
        with self.step_timer:
            obs_list = []
            # select optimal path
            for path in self.path_list:
                self.env.set_traj(path)
                obs_list.append(self.env._get_obs())
            all_obs = tf.stack(obs_list, axis=0)
            path_values = self.policy.obj_value_batch(all_obs).numpy()
            old_value = path_values[self.old_index]
            new_index, new_value = int(np.argmin(path_values)), min(path_values)  # value is to approximate (- sum of reward)
            path_index = self.old_index if old_value - new_value < 0.1 else new_index
            self.old_index = path_index

            self.env.set_traj(self.path_list[path_index])
            self.obs_real = obs_list[path_index]

            # obtain safe action
            with self.ss_timer:
                safe_action, is_ss = self.safe_shield(self.obs_real, path_index)
            print('ALL TIME:', self.step_timer.mean, 'ss', self.ss_timer.mean)
        self.render(self.path_list, path_values, path_index)
        self.recorder.record(self.obs_real, safe_action, self.step_timer.mean,
                             path_index, path_values, self.ss_timer.mean, is_ss)
        self.obs, r, done, info = self.env.step(safe_action)
        return done

    def render(self, traj_list, path_values, path_index):
        square_length = CROSSROAD_SIZE
        extension = 40
        lane_width = LANE_WIDTH
        light_line_width = 3
        dotted_line_style = '--'
        solid_line_style = '-'

        plt.cla()
        ax = plt.axes([-0.05, -0.05, 1.1, 1.1])
        for ax in self.fig.get_axes():
            ax.axis('off')
        ax.axis("equal")

        # ----------arrow--------------
        plt.arrow(lane_width / 2, -square_length / 2 - 10, 0, 5, color='b')
        plt.arrow(lane_width / 2, -square_length / 2 - 10 + 5, -0.5, 0, color='b', head_width=1)
        plt.arrow(lane_width * 1.5, -square_length / 2 - 10, 0, 4, color='b', head_width=1)
        plt.arrow(lane_width * 2.5, -square_length / 2 - 10, 0, 5, color='b')
        plt.arrow(lane_width * 2.5, -square_length / 2 - 10 + 5, 0.5, 0, color='b', head_width=1)

        # ----------horizon--------------

        plt.plot([-square_length / 2 - extension, -square_length / 2], [0.3, 0.3], color='orange')
        plt.plot([-square_length / 2 - extension, -square_length / 2], [-0.3, -0.3], color='orange')
        plt.plot([square_length / 2 + extension, square_length / 2], [0.3, 0.3], color='orange')
        plt.plot([square_length / 2 + extension, square_length / 2], [-0.3, -0.3], color='orange')

        #
        for i in range(1, LANE_NUMBER + 1):
            linestyle = dotted_line_style if i < LANE_NUMBER else solid_line_style
            linewidth = 1 if i < LANE_NUMBER else 2
            plt.plot([-square_length / 2 - extension, -square_length / 2], [i * lane_width, i * lane_width],
                     linestyle=linestyle, color='black', linewidth=linewidth)
            plt.plot([square_length / 2 + extension, square_length / 2], [i * lane_width, i * lane_width],
                     linestyle=linestyle, color='black', linewidth=linewidth)
            plt.plot([-square_length / 2 - extension, -square_length / 2], [-i * lane_width, -i * lane_width],
                     linestyle=linestyle, color='black', linewidth=linewidth)
            plt.plot([square_length / 2 + extension, square_length / 2], [-i * lane_width, -i * lane_width],
                     linestyle=linestyle, color='black', linewidth=linewidth)

        # ----------vertical----------------
        plt.plot([0.3, 0.3], [-square_length / 2 - extension, -square_length / 2], color='orange')
        plt.plot([-0.3, -0.3], [-square_length / 2 - extension, -square_length / 2], color='orange')
        plt.plot([0.3, 0.3], [square_length / 2 + extension, square_length / 2], color='orange')
        plt.plot([-0.3, -0.3], [square_length / 2 + extension, square_length / 2], color='orange')

        #
        for i in range(1, LANE_NUMBER + 1):
            linestyle = dotted_line_style if i < LANE_NUMBER else solid_line_style
            linewidth = 1 if i < LANE_NUMBER else 2
            plt.plot([i * lane_width, i * lane_width], [-square_length / 2 - extension, -square_length / 2],
                     linestyle=linestyle, color='black', linewidth=linewidth)
            plt.plot([i * lane_width, i * lane_width], [square_length / 2 + extension, square_length / 2],
                     linestyle=linestyle, color='black', linewidth=linewidth)
            plt.plot([-i * lane_width, -i * lane_width], [-square_length / 2 - extension, -square_length / 2],
                     linestyle=linestyle, color='black', linewidth=linewidth)
            plt.plot([-i * lane_width, -i * lane_width], [square_length / 2 + extension, square_length / 2],
                     linestyle=linestyle, color='black', linewidth=linewidth)

        v_light = self.env.v_light
        if v_light == 0:
            v_color, h_color = 'green', 'red'
        elif v_light == 1:
            v_color, h_color = 'orange', 'red'
        elif v_light == 2:
            v_color, h_color = 'red', 'green'
        else:
            v_color, h_color = 'red', 'orange'

        plt.plot([0, (LANE_NUMBER - 1) * lane_width], [-square_length / 2, -square_length / 2],
                 color=v_color, linewidth=light_line_width)
        plt.plot([(LANE_NUMBER - 1) * lane_width, LANE_NUMBER * lane_width], [-square_length / 2, -square_length / 2],
                 color='green', linewidth=light_line_width)

        plt.plot([-LANE_NUMBER * lane_width, -(LANE_NUMBER - 1) * lane_width], [square_length / 2, square_length / 2],
                 color='green', linewidth=light_line_width)
        plt.plot([-(LANE_NUMBER - 1) * lane_width, 0], [square_length / 2, square_length / 2],
                 color=v_color, linewidth=light_line_width)

        plt.plot([-square_length / 2, -square_length / 2], [0, -(LANE_NUMBER - 1) * lane_width],
                 color=h_color, linewidth=light_line_width)
        plt.plot([-square_length / 2, -square_length / 2], [-(LANE_NUMBER - 1) * lane_width, -LANE_NUMBER * lane_width],
                 color='green', linewidth=light_line_width)

        plt.plot([square_length / 2, square_length / 2], [(LANE_NUMBER - 1) * lane_width, 0],
                 color=h_color, linewidth=light_line_width)
        plt.plot([square_length / 2, square_length / 2], [LANE_NUMBER * lane_width, (LANE_NUMBER - 1) * lane_width],
                 color='green', linewidth=light_line_width)

        # ----------Oblique--------------
        plt.plot([LANE_NUMBER * lane_width, square_length / 2], [-square_length / 2, -LANE_NUMBER * lane_width],
                 color='black', linewidth=2)
        plt.plot([LANE_NUMBER * lane_width, square_length / 2], [square_length / 2, LANE_NUMBER * lane_width],
                 color='black', linewidth=2)
        plt.plot([-LANE_NUMBER * lane_width, -square_length / 2], [-square_length / 2, -LANE_NUMBER * lane_width],
                 color='black', linewidth=2)
        plt.plot([-LANE_NUMBER * lane_width, -square_length / 2], [square_length / 2, LANE_NUMBER * lane_width],
                 color='black', linewidth=2)

        def is_in_plot_area(x, y, tolerance=5):
            if -square_length / 2 - extension + tolerance < x < square_length / 2 + extension - tolerance and \
                    -square_length / 2 - extension + tolerance < y < square_length / 2 + extension - tolerance:
                return True
            else:
                return False

        def draw_rotate_rec(x, y, a, l, w, c):
            bottom_left_x, bottom_left_y, _ = rotate_coordination(-l / 2, w / 2, 0, -a)
            ax.add_patch(plt.Rectangle((x + bottom_left_x, y + bottom_left_y), w, l, edgecolor=c,
                                       facecolor='white', angle=-(90 - a), zorder=50))

        def plot_phi_line(x, y, phi, color):
            line_length = 3
            x_forw, y_forw = x + line_length * cos(phi * pi / 180.), \
                             y + line_length * sin(phi * pi / 180.)
            plt.plot([x, x_forw], [y, y_forw], color=color, linewidth=0.5)

        # plot cars
        for veh in self.env.all_vehicles:
            veh_x = veh['x']
            veh_y = veh['y']
            veh_phi = veh['phi']
            veh_l = veh['l']
            veh_w = veh['w']
            if is_in_plot_area(veh_x, veh_y):
                plot_phi_line(veh_x, veh_y, veh_phi, 'black')
                draw_rotate_rec(veh_x, veh_y, veh_phi, veh_l, veh_w, 'black')

        # plot_interested vehs
        # for mode, num in self.env.veh_mode_dict.items():
        #     for i in range(num):
        #         veh = self.env.interested_vehs[mode][i]
        #         veh_x = veh['x']
        #         veh_y = veh['y']
        #         veh_phi = veh['phi']
        #         veh_l = veh['l']
        #         veh_w = veh['w']
        #         task2color = {'left': 'b', 'straight': 'c', 'right': 'm'}
        #
        #         if is_in_plot_area(veh_x, veh_y):
        #             plot_phi_line(veh_x, veh_y, veh_phi, 'black')
        #             task = MODE2TASK[mode]
        #             color = task2color[task]
        #             draw_rotate_rec(veh_x, veh_y, veh_phi, veh_l, veh_w, color)

        ego_v_x = self.env.ego_dynamics['v_x']
        ego_v_y = self.env.ego_dynamics['v_y']
        ego_r = self.env.ego_dynamics['r']
        ego_x = self.env.ego_dynamics['x']
        ego_y = self.env.ego_dynamics['y']
        ego_phi = self.env.ego_dynamics['phi']
        ego_l = self.env.ego_dynamics['l']
        ego_w = self.env.ego_dynamics['w']
        ego_alpha_f = self.env.ego_dynamics['alpha_f']
        ego_alpha_r = self.env.ego_dynamics['alpha_r']
        alpha_f_bound = self.env.ego_dynamics['alpha_f_bound']
        alpha_r_bound = self.env.ego_dynamics['alpha_r_bound']
        r_bound = self.env.ego_dynamics['r_bound']

        plot_phi_line(ego_x, ego_y, ego_phi, 'fuchsia')
        draw_rotate_rec(ego_x, ego_y, ego_phi, ego_l, ego_w, 'fuchsia')
        self.hist_posi.append((ego_x, ego_y))

        # plot history
        xs = [pos[0] for pos in self.hist_posi]
        ys = [pos[1] for pos in self.hist_posi]
        plt.scatter(np.array(xs), np.array(ys), color='fuchsia', alpha=0.1)


        # plot future data
        tracking_info = self.obs[
                        self.env.ego_info_dim:self.env.ego_info_dim + self.env.per_tracking_info_dim * (self.env.num_future_data + 1)]
        future_path = tracking_info[self.env.per_tracking_info_dim:]
        for i in range(self.env.num_future_data):
            delta_x, delta_y, delta_phi = future_path[i * self.env.per_tracking_info_dim:
                                                      (i + 1) * self.env.per_tracking_info_dim]
            path_x, path_y, path_phi = ego_x + delta_x, ego_y + delta_y, ego_phi - delta_phi
            plt.plot(path_x, path_y, 'g.')
            plot_phi_line(path_x, path_y, path_phi, 'g')

        delta_, _, _ = tracking_info[:3]
        indexs, points = self.env.ref_path.find_closest_point(np.array([ego_x], np.float32), np.array([ego_y], np.float32))
        path_x, path_y, path_phi = points[0][0], points[1][0], points[2][0]
        # plt.plot(path_x, path_y, 'g.')
        delta_x, delta_y, delta_phi = ego_x - path_x, ego_y - path_y, ego_phi - path_phi

        # plot real time traj
        try:
            color = ['blue', 'coral', 'darkcyan']
            for i, item in enumerate(traj_list):
                if i == path_index:
                    plt.plot(item.path[0], item.path[1], color=color[i])
                else:
                    plt.plot(item.path[0], item.path[1], color=color[i], alpha=0.3)
                indexs, points = item.find_closest_point(np.array([ego_x], np.float32), np.array([ego_y], np.float32))
                path_x, path_y, path_phi = points[0][0], points[1][0], points[2][0]
                plt.plot(path_x, path_y, color=color[i])
        except Exception:
            pass

        # text
        # text_x, text_y_start = -120, 60
        # ge = iter(range(0, 1000, 4))
        # plt.text(text_x, text_y_start - next(ge), 'ego_x: {:.2f}m'.format(ego_x))
        # plt.text(text_x, text_y_start - next(ge), 'ego_y: {:.2f}m'.format(ego_y))
        # plt.text(text_x, text_y_start - next(ge), 'path_x: {:.2f}m'.format(path_x))
        # plt.text(text_x, text_y_start - next(ge), 'path_y: {:.2f}m'.format(path_y))
        # plt.text(text_x, text_y_start - next(ge), 'delta_: {:.2f}m'.format(delta_))
        # plt.text(text_x, text_y_start - next(ge), 'delta_x: {:.2f}m'.format(delta_x))
        # plt.text(text_x, text_y_start - next(ge), 'delta_y: {:.2f}m'.format(delta_y))
        # plt.text(text_x, text_y_start - next(ge), r'ego_phi: ${:.2f}\degree$'.format(ego_phi))
        # plt.text(text_x, text_y_start - next(ge), r'path_phi: ${:.2f}\degree$'.format(path_phi))
        # plt.text(text_x, text_y_start - next(ge), r'delta_phi: ${:.2f}\degree$'.format(delta_phi))
        # plt.text(text_x, text_y_start - next(ge), 'v_x: {:.2f}m/s'.format(ego_v_x))
        # plt.text(text_x, text_y_start - next(ge), 'exp_v: {:.2f}m/s'.format(self.env.exp_v))
        # plt.text(text_x, text_y_start - next(ge), 'v_y: {:.2f}m/s'.format(ego_v_y))
        # plt.text(text_x, text_y_start - next(ge), 'yaw_rate: {:.2f}rad/s'.format(ego_r))
        # plt.text(text_x, text_y_start - next(ge), 'yaw_rate bound: [{:.2f}, {:.2f}]'.format(-r_bound, r_bound))
        #
        # plt.text(text_x, text_y_start - next(ge), r'$\alpha_f$: {:.2f} rad'.format(ego_alpha_f))
        # plt.text(text_x, text_y_start - next(ge), r'$\alpha_f$ bound: [{:.2f}, {:.2f}] '.format(-alpha_f_bound,
        #                                                                                         alpha_f_bound))
        # plt.text(text_x, text_y_start - next(ge), r'$\alpha_r$: {:.2f} rad'.format(ego_alpha_r))
        # plt.text(text_x, text_y_start - next(ge), r'$\alpha_r$ bound: [{:.2f}, {:.2f}] '.format(-alpha_r_bound,
        #                                                                                         alpha_r_bound))
        # if self.env.action is not None:
        #     steer, a_x = self.env.action[0], self.env.action[1]
        #     plt.text(text_x, text_y_start - next(ge),
        #              r'steer: {:.2f}rad (${:.2f}\degree$)'.format(steer, steer * 180 / np.pi))
        #     plt.text(text_x, text_y_start - next(ge), 'a_x: {:.2f}m/s^2'.format(a_x))
        #
        # text_x, text_y_start = 70, 60
        # ge = iter(range(0, 1000, 4))
        #
        # # done info
        # plt.text(text_x, text_y_start - next(ge), 'done info: {}'.format(self.env.done_type))
        #
        # # reward info
        # if self.env.reward_info is not None:
        #     for key, val in self.env.reward_info.items():
        #         plt.text(text_x, text_y_start - next(ge), '{}: {:.4f}'.format(key, val))
        #
        # # indicator for trajectory selection
        # text_x, text_y_start = -18, -70
        # ge = iter(range(0, 1000, 6))
        # if path_values is not None:
        #     for i, value in enumerate(path_values):
        #         if i == path_index:
        #             plt.text(text_x, text_y_start - next(ge), 'Path reward={:.4f}'.format(value[0]), fontsize=14,
        #                      color=color[i], fontstyle='italic')
        #         else:
        #             plt.text(text_x, text_y_start - next(ge), 'Path reward={:.4f}'.format(value[0]), fontsize=12,
        #                      color=color[i], fontstyle='italic')
        plt.show()
        plt.pause(0.001)
        if self.logdir is not None:
            plt.savefig(self.logdir + '/episode{}'.format(self.episode_counter) + '/step{:03d}.png'.format(self.step_counter))
Пример #9
0
class HierarchicalDecision(object):
    def __init__(self, task, logdir=None):
        self.task = task
        if self.task == 'left':
            self.policy = LoadPolicy('../utils/models/left', 100000)
        elif self.task == 'right':
            self.policy = LoadPolicy('../utils/models/right', 145000)
        elif self.task == 'straight':
            self.policy = LoadPolicy('../utils/models/straight', 95000)
        self.env = CrossroadEnd2end(training_task=self.task)
        self.model = EnvironmentModel(self.task, mode='selecting')
        self.recorder = Recorder()
        self.episode_counter = -1
        self.step_counter = -1
        self.obs = None
        self.stg = None
        self.step_timer = TimerStat()
        self.ss_timer = TimerStat()
        self.logdir = logdir
        self.fig = plt.figure(figsize=(8, 8))
        plt.ion()
        self.hist_posi = []
        self.reset()

    def reset(self, ):
        self.obs = self.env.reset()
        self.stg = MultiPathGenerator()
        self.recorder.reset()
        self.hist_posi = []
        if self.logdir is not None:
            self.episode_counter += 1
            os.makedirs(self.logdir +
                        '/episode{}'.format(self.episode_counter))
            self.step_counter = -1
            self.recorder.save(self.logdir)
        return self.obs

    def safe_shield(self, real_obs, traj):
        action_bound = 1.0
        action_safe_set = [[[0, -action_bound]]]
        real_obs = tf.convert_to_tensor(real_obs[np.newaxis, :])
        obs = real_obs
        self.model.add_traj(obs, traj)
        total_punishment = 0.0
        for step in range(3):
            action = self.policy.run(obs)
            _, _, _, _, veh2veh4real, _ = self.model.rollout_out(action)
            total_punishment += veh2veh4real
        if total_punishment != 0:
            print('original action will cause collision within three steps!!!')
            for safe_action in action_safe_set:
                obs = real_obs
                total_punishment = 0
                # for step in range(1):
                #     obs, veh2veh4real = self.model.safety_calculation(obs, safe_action)
                #     total_punishment += veh2veh4real
                #     if veh2veh4real != 0:   # collide
                #         break
                # if total_punishment == 0:
                #     print('found the safe action', safe_action)
                #     safe_action = np.array(safe_action)
                #     break
                # else:
                #     print('still collide')
                #     safe_action = self.policy.run(real_obs).numpy().squeeze(0)
                return np.array(safe_action).squeeze(0)
        else:
            safe_action = self.policy.run(real_obs).numpy()[0]
            return safe_action

    def step(self):
        self.step_counter += 1
        with self.step_timer:
            path_list = self.stg.generate_path(self.task)
            obs_list = []

            # select optimal path
            for path in path_list:
                self.env.set_traj(path)
                obs_list.append(self.env._get_obs())
            all_obs = tf.stack(obs_list, axis=0)
            path_values = self.policy.values(all_obs).numpy().squeeze()
            path_index = int(np.argmax(path_values[:, 0]))

            self.env.set_traj(path_list[path_index])
            self.obs_real = obs_list[path_index]

            # obtain safe action
            # with self.ss_timer:
            #     safe_action = self.safe_shield(self.obs_real, path_list[path_index])
            safe_action = self.policy.run(self.obs_real).numpy()

            print('ALL TIME:', self.step_timer.mean)
        # deal with red light temporally
        # if self.env.v_light != 0 and -25 > self.env.ego_dynamics['y'] > -35 and self.env.training_task != 'right':
        #     scaled_steer = 0.
        #     if self.env.ego_dynamics['v_x'] == 0.0:
        #         scaled_a_x = 0.33
        #     else:
        #         scaled_a_x = np.random.uniform(-0.6, -0.4)
        #     safe_action = np.array([scaled_steer, scaled_a_x], dtype=np.float32)
        self.render(path_list, path_values, path_index)
        self.recorder.record(self.obs_real, safe_action, self.step_timer.mean,
                             path_index, path_values, self.ss_timer.mean)
        self.obs, r, done, info = self.env.step(safe_action)
        return done

    def render(self, traj_list, path_values, path_index):
        square_length = CROSSROAD_SIZE
        extension = 40
        lane_width = LANE_WIDTH
        light_line_width = 3
        dotted_line_style = '--'
        solid_line_style = '-'

        plt.cla()
        ax = plt.axes([-0.05, -0.05, 1.1, 1.1])
        for ax in self.fig.get_axes():
            ax.axis('off')
        ax.axis("equal")

        # ----------arrow--------------
        plt.arrow(lane_width / 2, -square_length / 2 - 10, 0, 5, color='b')
        plt.arrow(lane_width / 2,
                  -square_length / 2 - 10 + 5,
                  -0.5,
                  0,
                  color='b',
                  head_width=1)
        plt.arrow(lane_width * 1.5,
                  -square_length / 2 - 10,
                  0,
                  4,
                  color='b',
                  head_width=1)
        plt.arrow(lane_width * 2.5, -square_length / 2 - 10, 0, 5, color='b')
        plt.arrow(lane_width * 2.5,
                  -square_length / 2 - 10 + 5,
                  0.5,
                  0,
                  color='b',
                  head_width=1)

        # ----------horizon--------------

        plt.plot([-square_length / 2 - extension, -square_length / 2],
                 [0.3, 0.3],
                 color='orange')
        plt.plot([-square_length / 2 - extension, -square_length / 2],
                 [-0.3, -0.3],
                 color='orange')
        plt.plot([square_length / 2 + extension, square_length / 2],
                 [0.3, 0.3],
                 color='orange')
        plt.plot([square_length / 2 + extension, square_length / 2],
                 [-0.3, -0.3],
                 color='orange')

        #
        for i in range(1, LANE_NUMBER + 1):
            linestyle = dotted_line_style if i < LANE_NUMBER else solid_line_style
            linewidth = 1 if i < LANE_NUMBER else 2
            plt.plot([-square_length / 2 - extension, -square_length / 2],
                     [i * lane_width, i * lane_width],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)
            plt.plot([square_length / 2 + extension, square_length / 2],
                     [i * lane_width, i * lane_width],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)
            plt.plot([-square_length / 2 - extension, -square_length / 2],
                     [-i * lane_width, -i * lane_width],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)
            plt.plot([square_length / 2 + extension, square_length / 2],
                     [-i * lane_width, -i * lane_width],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)

        # ----------vertical----------------
        plt.plot([0.3, 0.3],
                 [-square_length / 2 - extension, -square_length / 2],
                 color='orange')
        plt.plot([-0.3, -0.3],
                 [-square_length / 2 - extension, -square_length / 2],
                 color='orange')
        plt.plot([0.3, 0.3],
                 [square_length / 2 + extension, square_length / 2],
                 color='orange')
        plt.plot([-0.3, -0.3],
                 [square_length / 2 + extension, square_length / 2],
                 color='orange')

        #
        for i in range(1, LANE_NUMBER + 1):
            linestyle = dotted_line_style if i < LANE_NUMBER else solid_line_style
            linewidth = 1 if i < LANE_NUMBER else 2
            plt.plot([i * lane_width, i * lane_width],
                     [-square_length / 2 - extension, -square_length / 2],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)
            plt.plot([i * lane_width, i * lane_width],
                     [square_length / 2 + extension, square_length / 2],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)
            plt.plot([-i * lane_width, -i * lane_width],
                     [-square_length / 2 - extension, -square_length / 2],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)
            plt.plot([-i * lane_width, -i * lane_width],
                     [square_length / 2 + extension, square_length / 2],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)

        v_light = self.env.v_light
        if v_light == 0:
            v_color, h_color = 'green', 'red'
        elif v_light == 1:
            v_color, h_color = 'orange', 'red'
        elif v_light == 2:
            v_color, h_color = 'red', 'green'
        else:
            v_color, h_color = 'red', 'orange'

        plt.plot([0, (LANE_NUMBER - 1) * lane_width],
                 [-square_length / 2, -square_length / 2],
                 color=v_color,
                 linewidth=light_line_width)
        plt.plot([(LANE_NUMBER - 1) * lane_width, LANE_NUMBER * lane_width],
                 [-square_length / 2, -square_length / 2],
                 color='green',
                 linewidth=light_line_width)

        plt.plot([-LANE_NUMBER * lane_width, -(LANE_NUMBER - 1) * lane_width],
                 [square_length / 2, square_length / 2],
                 color='green',
                 linewidth=light_line_width)
        plt.plot([-(LANE_NUMBER - 1) * lane_width, 0],
                 [square_length / 2, square_length / 2],
                 color=v_color,
                 linewidth=light_line_width)

        plt.plot([-square_length / 2, -square_length / 2],
                 [0, -(LANE_NUMBER - 1) * lane_width],
                 color=h_color,
                 linewidth=light_line_width)
        plt.plot([-square_length / 2, -square_length / 2],
                 [-(LANE_NUMBER - 1) * lane_width, -LANE_NUMBER * lane_width],
                 color='green',
                 linewidth=light_line_width)

        plt.plot([square_length / 2, square_length / 2],
                 [(LANE_NUMBER - 1) * lane_width, 0],
                 color=h_color,
                 linewidth=light_line_width)
        plt.plot([square_length / 2, square_length / 2],
                 [LANE_NUMBER * lane_width, (LANE_NUMBER - 1) * lane_width],
                 color='green',
                 linewidth=light_line_width)

        # ----------Oblique--------------
        plt.plot([LANE_NUMBER * lane_width, square_length / 2],
                 [-square_length / 2, -LANE_NUMBER * lane_width],
                 color='black',
                 linewidth=2)
        plt.plot([LANE_NUMBER * lane_width, square_length / 2],
                 [square_length / 2, LANE_NUMBER * lane_width],
                 color='black',
                 linewidth=2)
        plt.plot([-LANE_NUMBER * lane_width, -square_length / 2],
                 [-square_length / 2, -LANE_NUMBER * lane_width],
                 color='black',
                 linewidth=2)
        plt.plot([-LANE_NUMBER * lane_width, -square_length / 2],
                 [square_length / 2, LANE_NUMBER * lane_width],
                 color='black',
                 linewidth=2)

        def is_in_plot_area(x, y, tolerance=5):
            if -square_length / 2 - extension + tolerance < x < square_length / 2 + extension - tolerance and \
                    -square_length / 2 - extension + tolerance < y < square_length / 2 + extension - tolerance:
                return True
            else:
                return False

        def draw_rotate_rec(x, y, a, l, w, c):
            bottom_left_x, bottom_left_y, _ = rotate_coordination(
                -l / 2, w / 2, 0, -a)
            ax.add_patch(
                plt.Rectangle((x + bottom_left_x, y + bottom_left_y),
                              w,
                              l,
                              edgecolor=c,
                              facecolor='white',
                              angle=-(90 - a),
                              zorder=50))

        def plot_phi_line(x, y, phi, color):
            line_length = 3
            x_forw, y_forw = x + line_length * cos(phi * pi / 180.), \
                             y + line_length * sin(phi * pi / 180.)
            plt.plot([x, x_forw], [y, y_forw], color=color, linewidth=0.5)

        # plot cars
        for veh in self.env.all_vehicles:
            veh_x = veh['x']
            veh_y = veh['y']
            veh_phi = veh['phi']
            veh_l = veh['l']
            veh_w = veh['w']
            if is_in_plot_area(veh_x, veh_y):
                plot_phi_line(veh_x, veh_y, veh_phi, 'black')
                draw_rotate_rec(veh_x, veh_y, veh_phi, veh_l, veh_w, 'black')

        # plot_interested vehs
        # for mode, num in self.veh_mode_dict.items():
        #     for i in range(num):
        #         veh = self.interested_vehs[mode][i]
        #         veh_x = veh['x']
        #         veh_y = veh['y']
        #         veh_phi = veh['phi']
        #         veh_l = veh['l']
        #         veh_w = veh['w']
        #         task2color = {'left': 'b', 'straight': 'c', 'right': 'm'}
        #
        #         if is_in_plot_area(veh_x, veh_y):
        #             plot_phi_line(veh_x, veh_y, veh_phi, 'black')
        #             task = MODE2TASK[mode]
        #             color = task2color[task]
        #             draw_rotate_rec(veh_x, veh_y, veh_phi, veh_l, veh_w, color, linestyle=':')

        ego_v_x = self.env.ego_dynamics['v_x']
        ego_v_y = self.env.ego_dynamics['v_y']
        ego_r = self.env.ego_dynamics['r']
        ego_x = self.env.ego_dynamics['x']
        ego_y = self.env.ego_dynamics['y']
        ego_phi = self.env.ego_dynamics['phi']
        ego_l = self.env.ego_dynamics['l']
        ego_w = self.env.ego_dynamics['w']
        ego_alpha_f = self.env.ego_dynamics['alpha_f']
        ego_alpha_r = self.env.ego_dynamics['alpha_r']
        alpha_f_bound = self.env.ego_dynamics['alpha_f_bound']
        alpha_r_bound = self.env.ego_dynamics['alpha_r_bound']
        r_bound = self.env.ego_dynamics['r_bound']

        plot_phi_line(ego_x, ego_y, ego_phi, 'fuchsia')
        draw_rotate_rec(ego_x, ego_y, ego_phi, ego_l, ego_w, 'fuchsia')
        self.hist_posi.append((ego_x, ego_y))
        # plot history data
        for hist_x, hist_y in self.hist_posi:
            plt.scatter(hist_x, hist_y, color='fuchsia', alpha=0.1)
        # plot future data
        tracking_info = self.obs[self.env.ego_info_dim:self.env.ego_info_dim +
                                 self.env.per_tracking_info_dim *
                                 (self.env.num_future_data + 1)]
        future_path = tracking_info[self.env.per_tracking_info_dim:]
        for i in range(self.env.num_future_data):
            delta_x, delta_y, delta_phi = future_path[
                i * self.env.per_tracking_info_dim:(i + 1) *
                self.env.per_tracking_info_dim]
            path_x, path_y, path_phi = ego_x + delta_x, ego_y + delta_y, ego_phi - delta_phi
            plt.plot(path_x, path_y, 'g.')
            plot_phi_line(path_x, path_y, path_phi, 'g')

        delta_, _, _ = tracking_info[:3]
        indexs, points = self.env.ref_path.find_closest_point(
            np.array([ego_x], np.float32), np.array([ego_y], np.float32))
        path_x, path_y, path_phi = points[0][0], points[1][0], points[2][0]
        # plt.plot(path_x, path_y, 'g.')
        delta_x, delta_y, delta_phi = ego_x - path_x, ego_y - path_y, ego_phi - path_phi

        # plot real time traj
        try:
            color = ['blue', 'coral', 'darkcyan']
            for i, item in enumerate(traj_list):
                if i == path_index:
                    plt.plot(item.path[0], item.path[1], color=color[i])
                else:
                    plt.plot(item.path[0],
                             item.path[1],
                             color=color[i],
                             alpha=0.3)
                indexs, points = item.find_closest_point(
                    np.array([ego_x], np.float32), np.array([ego_y],
                                                            np.float32))
                path_x, path_y, path_phi = points[0][0], points[1][0], points[
                    2][0]
                plt.plot(path_x, path_y, color=color[i])
        except Exception:
            pass

        # text
        # text_x, text_y_start = -120, 60
        # ge = iter(range(0, 1000, 4))
        # plt.text(text_x, text_y_start - next(ge), 'ego_x: {:.2f}m'.format(ego_x))
        # plt.text(text_x, text_y_start - next(ge), 'ego_y: {:.2f}m'.format(ego_y))
        # plt.text(text_x, text_y_start - next(ge), 'path_x: {:.2f}m'.format(path_x))
        # plt.text(text_x, text_y_start - next(ge), 'path_y: {:.2f}m'.format(path_y))
        # plt.text(text_x, text_y_start - next(ge), 'delta_: {:.2f}m'.format(delta_))
        # plt.text(text_x, text_y_start - next(ge), 'delta_x: {:.2f}m'.format(delta_x))
        # plt.text(text_x, text_y_start - next(ge), 'delta_y: {:.2f}m'.format(delta_y))
        # plt.text(text_x, text_y_start - next(ge), r'ego_phi: ${:.2f}\degree$'.format(ego_phi))
        # plt.text(text_x, text_y_start - next(ge), r'path_phi: ${:.2f}\degree$'.format(path_phi))
        # plt.text(text_x, text_y_start - next(ge), r'delta_phi: ${:.2f}\degree$'.format(delta_phi))
        # plt.text(text_x, text_y_start - next(ge), 'v_x: {:.2f}m/s'.format(ego_v_x))
        # plt.text(text_x, text_y_start - next(ge), 'exp_v: {:.2f}m/s'.format(self.env.exp_v))
        # plt.text(text_x, text_y_start - next(ge), 'v_y: {:.2f}m/s'.format(ego_v_y))
        # plt.text(text_x, text_y_start - next(ge), 'yaw_rate: {:.2f}rad/s'.format(ego_r))
        # plt.text(text_x, text_y_start - next(ge), 'yaw_rate bound: [{:.2f}, {:.2f}]'.format(-r_bound, r_bound))
        #
        # plt.text(text_x, text_y_start - next(ge), r'$\alpha_f$: {:.2f} rad'.format(ego_alpha_f))
        # plt.text(text_x, text_y_start - next(ge), r'$\alpha_f$ bound: [{:.2f}, {:.2f}] '.format(-alpha_f_bound,
        #                                                                                         alpha_f_bound))
        # plt.text(text_x, text_y_start - next(ge), r'$\alpha_r$: {:.2f} rad'.format(ego_alpha_r))
        # plt.text(text_x, text_y_start - next(ge), r'$\alpha_r$ bound: [{:.2f}, {:.2f}] '.format(-alpha_r_bound,
        #                                                                                         alpha_r_bound))
        # if self.env.action is not None:
        #     steer, a_x = self.env.action[0], self.env.action[1]
        #     plt.text(text_x, text_y_start - next(ge),
        #              r'steer: {:.2f}rad (${:.2f}\degree$)'.format(steer, steer * 180 / np.pi))
        #     plt.text(text_x, text_y_start - next(ge), 'a_x: {:.2f}m/s^2'.format(a_x))
        #
        # text_x, text_y_start = 70, 60
        # ge = iter(range(0, 1000, 4))
        #
        # # done info
        # plt.text(text_x, text_y_start - next(ge), 'done info: {}'.format(self.env.done_type))
        #
        # # reward info
        # if self.env.reward_info is not None:
        #     for key, val in self.env.reward_info.items():
        #         plt.text(text_x, text_y_start - next(ge), '{}: {:.4f}'.format(key, val))
        #
        # # indicator for trajectory selection
        # text_x, text_y_start = -18, -70
        # ge = iter(range(0, 1000, 6))
        # if path_values is not None:
        #     for i, value in enumerate(path_values):
        #         if i == path_index:
        #             plt.text(text_x, text_y_start - next(ge), 'Path reward={:.4f}'.format(value[0]), fontsize=14,
        #                      color=color[i], fontstyle='italic')
        #         else:
        #             plt.text(text_x, text_y_start - next(ge), 'Path reward={:.4f}'.format(value[0]), fontsize=12,
        #                      color=color[i], fontstyle='italic')
        plt.show()
        plt.pause(0.001)
        if self.logdir is not None:
            plt.savefig(self.logdir +
                        '/episode{}'.format(self.episode_counter) +
                        '/step{}.pdf'.format(self.step_counter))
Пример #10
0
class CrossroadEnd2end(gym.Env):
    def __init__(
            self,
            training_task,  # 'left', 'straight', 'right'
            num_future_data=0,
            display=False,
            **kwargs):
        metadata = {'render.modes': ['human']}
        self.dynamics = VehicleDynamics()
        self.interested_vehs = None
        self.training_task = training_task
        self.ref_path = ReferencePath(self.training_task, **kwargs)
        self.detected_vehicles = None
        self.all_vehicles = None
        self.ego_dynamics = None
        self.num_future_data = num_future_data
        self.env_model = EnvironmentModel(training_task, num_future_data)
        self.init_state = {}
        self.action_number = 2
        self.exp_v = EXPECTED_V  #TODO: temp
        self.ego_l, self.ego_w = L, W
        self.action_space = gym.spaces.Box(low=-1,
                                           high=1,
                                           shape=(self.action_number, ),
                                           dtype=np.float32)

        self.seed()
        self.v_light = None
        self.step_length = 100  # ms

        self.step_time = self.step_length / 1000.0
        self.init_state = self._reset_init_state()
        self.obs = None
        self.action = None
        self.veh_mode_dict = VEHICLE_MODE_DICT[self.training_task]
        self.veh_num = VEH_NUM[self.training_task]
        self.virtual_red_light_vehicle = False

        self.done_type = 'not_done_yet'
        self.reward_info = None
        self.ego_info_dim = None
        self.per_tracking_info_dim = None
        self.per_veh_info_dim = None
        if not display:
            self.traffic = Traffic(self.step_length,
                                   mode='training',
                                   init_n_ego_dict=self.init_state,
                                   training_task=self.training_task)
            self.reset()
            action = self.action_space.sample()
            observation, _reward, done, _info = self.step(action)
            self._set_observation_space(observation)
            plt.ion()

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self, **kwargs):  # kwargs include three keys
        self.ref_path = ReferencePath(self.training_task, **kwargs)
        self.init_state = self._reset_init_state()
        self.traffic.init_traffic(self.init_state)
        self.traffic.sim_step()
        ego_dynamics = self._get_ego_dynamics([
            self.init_state['ego']['v_x'], self.init_state['ego']['v_y'],
            self.init_state['ego']['r'], self.init_state['ego']['x'],
            self.init_state['ego']['y'], self.init_state['ego']['phi']
        ], [
            0, 0, self.dynamics.vehicle_params['miu'],
            self.dynamics.vehicle_params['miu']
        ])
        self._get_all_info(ego_dynamics)
        self.obs = self._get_obs()
        self.action = None
        self.reward_info = None
        self.done_type = 'not_done_yet'
        if np.random.random() > 0.9:
            self.virtual_red_light_vehicle = True
        else:
            self.virtual_red_light_vehicle = False
        return self.obs

    def close(self):
        del self.traffic

    def step(self, action):
        self.action = self._action_transformation_for_end2end(action)
        reward, self.reward_info = self.compute_reward(self.obs, self.action)
        next_ego_state, next_ego_params = self._get_next_ego_state(self.action)
        ego_dynamics = self._get_ego_dynamics(next_ego_state, next_ego_params)
        self.traffic.set_own_car(dict(ego=ego_dynamics))
        self.traffic.sim_step()
        all_info = self._get_all_info(ego_dynamics)
        self.obs = self._get_obs()
        self.done_type, done = self._judge_done()
        self.reward_info.update({'final_rew': reward})
        all_info.update({
            'reward_info': self.reward_info,
            'ref_index': self.ref_path.ref_index
        })
        return self.obs, reward, done, all_info

    def _set_observation_space(self, observation):
        self.observation_space = convert_observation_to_space(observation)
        return self.observation_space

    def _get_ego_dynamics(self, next_ego_state, next_ego_params):
        out = dict(
            v_x=next_ego_state[0],
            v_y=next_ego_state[1],
            r=next_ego_state[2],
            x=next_ego_state[3],
            y=next_ego_state[4],
            phi=next_ego_state[5],
            l=self.ego_l,
            w=self.ego_w,
            alpha_f=next_ego_params[0],
            alpha_r=next_ego_params[1],
            miu_f=next_ego_params[2],
            miu_r=next_ego_params[3],
        )
        miu_f, miu_r = out['miu_f'], out['miu_r']
        F_zf, F_zr = self.dynamics.vehicle_params[
            'F_zf'], self.dynamics.vehicle_params['F_zr']
        C_f, C_r = self.dynamics.vehicle_params[
            'C_f'], self.dynamics.vehicle_params['C_r']
        alpha_f_bound, alpha_r_bound = 3 * miu_f * F_zf / C_f, 3 * miu_r * F_zr / C_r
        r_bound = miu_r * self.dynamics.vehicle_params['g'] / (
            abs(out['v_x']) + 1e-8)

        l, w, x, y, phi = out['l'], out['w'], out['x'], out['y'], out['phi']

        def cal_corner_point_of_ego_car():
            x0, y0, a0 = rotate_and_shift_coordination(l / 2, w / 2, 0, -x, -y,
                                                       -phi)
            x1, y1, a1 = rotate_and_shift_coordination(l / 2, -w / 2, 0, -x,
                                                       -y, -phi)
            x2, y2, a2 = rotate_and_shift_coordination(-l / 2, w / 2, 0, -x,
                                                       -y, -phi)
            x3, y3, a3 = rotate_and_shift_coordination(-l / 2, -w / 2, 0, -x,
                                                       -y, -phi)
            return (x0, y0), (x1, y1), (x2, y2), (x3, y3)

        Corner_point = cal_corner_point_of_ego_car()
        out.update(
            dict(alpha_f_bound=alpha_f_bound,
                 alpha_r_bound=alpha_r_bound,
                 r_bound=r_bound,
                 Corner_point=Corner_point))

        return out

    def _get_all_info(
        self, ego_dynamics
    ):  # used to update info, must be called every timestep before _get_obs
        # to fetch info
        self.all_vehicles = self.traffic.n_ego_vehicles[
            'ego']  # coordination 2
        self.ego_dynamics = ego_dynamics  # coordination 2
        self.v_light = self.traffic.v_light

        # all_vehicles
        # dict(x=x, y=y, v=v, phi=a, l=length,
        #      w=width, route=route)

        all_info = dict(all_vehicles=self.all_vehicles,
                        ego_dynamics=self.ego_dynamics,
                        v_light=self.v_light)
        return all_info

    def _judge_done(self):
        """
        :return:
         1: bad done: collision
         2: bad done: break_road_constrain
         3: good done: task succeed
         4: not done
        """
        if self.traffic.collision_flag:
            return 'collision', 1
        if self._break_road_constrain():
            return 'break_road_constrain', 1
        elif self._deviate_too_much():
            return 'deviate_too_much', 1
        elif self._break_stability():
            return 'break_stability', 1
        elif self._break_red_light():
            return 'break_red_light', 1
        elif self._is_achieve_goal():
            return 'good_done', 1
        else:
            return 'not_done_yet', 0

    def _deviate_too_much(self):
        delta_y, delta_phi, delta_v = self.obs[self.
                                               ego_info_dim:self.ego_info_dim +
                                               3]
        return True if abs(delta_y) > 15 else False

    def _break_road_constrain(self):
        results = list(
            map(lambda x: judge_feasible(*x, self.training_task),
                self.ego_dynamics['Corner_point']))
        return not all(results)

    def _break_stability(self):
        alpha_f, alpha_r, miu_f, miu_r = self.ego_dynamics['alpha_f'], self.ego_dynamics['alpha_r'], \
                                         self.ego_dynamics['miu_f'], self.ego_dynamics['miu_r']
        alpha_f_bound, alpha_r_bound = self.ego_dynamics[
            'alpha_f_bound'], self.ego_dynamics['alpha_r_bound']
        r_bound = self.ego_dynamics['r_bound']
        # if -alpha_f_bound < alpha_f < alpha_f_bound \
        #         and -alpha_r_bound < alpha_r < alpha_r_bound and \
        #         -r_bound < self.ego_dynamics['r'] < r_bound:
        if -r_bound < self.ego_dynamics['r'] < r_bound:
            return False
        else:
            return True

    def _break_red_light(self):
        return True if self.v_light != 0 and self.ego_dynamics[
            'y'] > -CROSSROAD_SIZE / 2 and self.training_task != 'right' else False

    def _is_achieve_goal(self):
        x = self.ego_dynamics['x']
        y = self.ego_dynamics['y']
        if self.training_task == 'left':
            return True if x < -CROSSROAD_SIZE / 2 - 10 and 0 < y < LANE_NUMBER * LANE_WIDTH else False
        elif self.training_task == 'right':
            return True if x > CROSSROAD_SIZE / 2 + 10 and -LANE_NUMBER * LANE_WIDTH < y < 0 else False
        else:
            assert self.training_task == 'straight'
            return True if y > CROSSROAD_SIZE / 2 + 10 and 0 < x < LANE_NUMBER * LANE_WIDTH else False

    def _action_transformation_for_end2end(self, action):  # [-1, 1]
        action = np.clip(action, -1.05, 1.05)
        steer_norm, a_x_norm = action[0], action[1]
        scaled_steer = 0.4 * steer_norm
        scaled_a_x = 2.25 * a_x_norm - 0.75  # [-3, 1.5]
        scaled_action = np.array([scaled_steer, scaled_a_x], dtype=np.float32)
        return scaled_action

    def _get_next_ego_state(self, trans_action):
        current_v_x = self.ego_dynamics['v_x']
        current_v_y = self.ego_dynamics['v_y']
        current_r = self.ego_dynamics['r']
        current_x = self.ego_dynamics['x']
        current_y = self.ego_dynamics['y']
        current_phi = self.ego_dynamics['phi']
        steer, a_x = trans_action
        state = np.array([[
            current_v_x, current_v_y, current_r, current_x, current_y,
            current_phi
        ]],
                         dtype=np.float32)
        action = np.array([[steer, a_x]], dtype=np.float32)
        next_ego_state, next_ego_params = self.dynamics.prediction(
            state, action, 10)
        next_ego_state, next_ego_params = next_ego_state.numpy(
        )[0], next_ego_params.numpy()[0]
        next_ego_state[0] = next_ego_state[0] if next_ego_state[0] >= 0 else 0.
        next_ego_state[-1] = deal_with_phi(next_ego_state[-1])
        return next_ego_state, next_ego_params

    def _get_obs(self, exit_='D'):
        ego_x = self.ego_dynamics['x']
        ego_y = self.ego_dynamics['y']
        ego_phi = self.ego_dynamics['phi']
        ego_v_x = self.ego_dynamics['v_x']

        vehs_vector = self._construct_veh_vector_short(exit_)
        ego_vector = self._construct_ego_vector_short()
        tracking_error = self.ref_path.tracking_error_vector(
            np.array([ego_x], dtype=np.float32),
            np.array([ego_y], dtype=np.float32),
            np.array([ego_phi], dtype=np.float32),
            np.array([ego_v_x], dtype=np.float32),
            self.num_future_data).numpy()[0]
        self.per_tracking_info_dim = 3

        vector = np.concatenate((ego_vector, tracking_error, vehs_vector),
                                axis=0)
        # vector = self.convert_vehs_to_rela(vector)

        return vector

    # def convert_vehs_to_rela(self, obs_abso):
    #     ego_infos, tracking_infos, veh_infos = obs_abso[:self.ego_info_dim], \
    #                                            obs_abso[self.ego_info_dim:self.ego_info_dim + self.per_tracking_info_dim * (
    #                                                      self.num_future_data + 1)], \
    #                                            obs_abso[self.ego_info_dim + self.per_tracking_info_dim * (
    #                                                        self.num_future_data + 1):]
    #     ego_vx, ego_vy, ego_r, ego_x, ego_y, ego_phi = ego_infos
    #     ego = np.array([ego_x, ego_y, 0, 0]*int(len(veh_infos)/self.per_veh_info_dim), dtype=np.float32)
    #     vehs_rela = veh_infos - ego
    #     out = np.concatenate((ego_infos, tracking_infos, vehs_rela), axis=0)
    #     return out
    #
    # def convert_vehs_to_abso(self, obs_rela):
    #     ego_infos, tracking_infos, veh_rela = obs_rela[:self.ego_info_dim], \
    #                                            obs_rela[self.ego_info_dim:self.ego_info_dim + self.per_tracking_info_dim * (
    #                                                    self.num_future_data + 1)], \
    #                                            obs_rela[self.ego_info_dim + self.per_tracking_info_dim * (
    #                                                    self.num_future_data + 1):]
    #     ego_vx, ego_vy, ego_r, ego_x, ego_y, ego_phi = ego_infos
    #     ego = np.array([ego_x, ego_y, 0, 0]*int(len(veh_rela)/self.per_veh_info_dim), dtype=np.float32)
    #     vehs_abso = veh_rela + ego
    #     out = np.concatenate((ego_infos, tracking_infos, vehs_abso), axis=0)
    #     return out

    def _construct_ego_vector_short(self):
        ego_v_x = self.ego_dynamics['v_x']
        ego_v_y = self.ego_dynamics['v_y']
        ego_r = self.ego_dynamics['r']
        ego_x = self.ego_dynamics['x']
        ego_y = self.ego_dynamics['y']
        ego_phi = self.ego_dynamics['phi']
        ego_feature = [ego_v_x, ego_v_y, ego_r, ego_x, ego_y, ego_phi]
        self.ego_info_dim = 6
        return np.array(ego_feature, dtype=np.float32)

    def _construct_veh_vector_short(self, exit_='D'):
        ego_x = self.ego_dynamics['x']
        ego_y = self.ego_dynamics['y']
        v_light = self.v_light
        vehs_vector = []

        name_settings = dict(D=dict(do='1o',
                                    di='1i',
                                    ro='2o',
                                    ri='2i',
                                    uo='3o',
                                    ui='3i',
                                    lo='4o',
                                    li='4i'),
                             R=dict(do='2o',
                                    di='2i',
                                    ro='3o',
                                    ri='3i',
                                    uo='4o',
                                    ui='4i',
                                    lo='1o',
                                    li='1i'),
                             U=dict(do='3o',
                                    di='3i',
                                    ro='4o',
                                    ri='4i',
                                    uo='1o',
                                    ui='1i',
                                    lo='2o',
                                    li='2i'),
                             L=dict(do='4o',
                                    di='4i',
                                    ro='1o',
                                    ri='1i',
                                    uo='2o',
                                    ui='2i',
                                    lo='3o',
                                    li='3i'))

        name_setting = name_settings[exit_]

        def filter_interested_vehicles(vs, task):
            dl, du, dr, rd, rl, ru, ur, ud, ul, lu, lr, ld = [], [], [], [], [], [], [], [], [], [], [], []
            for v in vs:
                route_list = v['route']
                start = route_list[0]
                end = route_list[1]
                if start == name_setting['do'] and end == name_setting['li']:
                    dl.append(v)
                elif start == name_setting['do'] and end == name_setting['ui']:
                    du.append(v)
                elif start == name_setting['do'] and end == name_setting['ri']:
                    dr.append(v)

                elif start == name_setting['ro'] and end == name_setting['di']:
                    rd.append(v)
                elif start == name_setting['ro'] and end == name_setting['li']:
                    rl.append(v)
                elif start == name_setting['ro'] and end == name_setting['ui']:
                    ru.append(v)

                elif start == name_setting['uo'] and end == name_setting['ri']:
                    ur.append(v)
                elif start == name_setting['uo'] and end == name_setting['di']:
                    ud.append(v)
                elif start == name_setting['uo'] and end == name_setting['li']:
                    ul.append(v)

                elif start == name_setting['lo'] and end == name_setting['ui']:
                    lu.append(v)
                elif start == name_setting['lo'] and end == name_setting['ri']:
                    lr.append(v)
                elif start == name_setting['lo'] and end == name_setting['di']:
                    ld.append(v)
            if self.training_task != 'right':
                if (v_light != 0 and ego_y < -CROSSROAD_SIZE/2) \
                        or (self.virtual_red_light_vehicle and ego_y < -CROSSROAD_SIZE/2):
                    dl.append(
                        dict(x=LANE_WIDTH / 2,
                             y=-CROSSROAD_SIZE / 2 + 2.5,
                             v=0.,
                             phi=90,
                             l=5,
                             w=2.5,
                             route=None))
                    du.append(
                        dict(x=LANE_WIDTH * 1.5,
                             y=-CROSSROAD_SIZE / 2 + 2.5,
                             v=0.,
                             phi=90,
                             l=5,
                             w=2.5,
                             route=None))

            # fetch veh in range
            dl = list(
                filter(
                    lambda v: v['x'] > -CROSSROAD_SIZE / 2 - 10 and v['y'] >
                    ego_y - 2, dl))  # interest of left straight
            du = list(
                filter(
                    lambda v: ego_y - 2 < v['y'] < CROSSROAD_SIZE / 2 + 10 and
                    v['x'] < ego_x + 5, du))  # interest of left straight

            dr = list(
                filter(
                    lambda v: v['x'] < CROSSROAD_SIZE / 2 + 10 and v['y'] >
                    ego_y, dr))  # interest of right

            rd = rd  # not interest in case of traffic light
            rl = rl  # not interest in case of traffic light
            ru = list(
                filter(
                    lambda v: v['x'] < CROSSROAD_SIZE / 2 + 10 and v['y'] <
                    CROSSROAD_SIZE / 2 + 10, ru))  # interest of straight

            if task == 'straight':
                ur = list(
                    filter(
                        lambda v: v['x'] < ego_x + 7 and ego_y < v['y'] <
                        CROSSROAD_SIZE / 2 + 10, ur))  # interest of straight
            elif task == 'right':
                ur = list(
                    filter(
                        lambda v: v['x'] < CROSSROAD_SIZE / 2 + 10 and v['y'] <
                        CROSSROAD_SIZE / 2, ur))  # interest of right
            ud = list(
                filter(
                    lambda v: max(ego_y - 2, -CROSSROAD_SIZE / 2) < v[
                        'y'] < CROSSROAD_SIZE / 2 and ego_x > v['x'],
                    ud))  # interest of left
            ul = list(
                filter(
                    lambda v: -CROSSROAD_SIZE / 2 - 10 < v['x'] < ego_x and v[
                        'y'] < CROSSROAD_SIZE / 2, ul))  # interest of left

            lu = lu  # not interest in case of traffic light
            lr = list(
                filter(
                    lambda v: -CROSSROAD_SIZE / 2 - 10 < v['x'] <
                    CROSSROAD_SIZE / 2 + 10, lr))  # interest of right
            ld = ld  # not interest in case of traffic light

            # sort
            dl = sorted(dl, key=lambda v: (v['y'], -v['x']))
            du = sorted(du, key=lambda v: v['y'])
            dr = sorted(dr, key=lambda v: (v['y'], v['x']))

            ru = sorted(ru, key=lambda v: (-v['x'], v['y']), reverse=True)

            if task == 'straight':
                ur = sorted(ur, key=lambda v: v['y'])
            elif task == 'right':
                ur = sorted(ur, key=lambda v: (-v['y'], v['x']), reverse=True)

            ud = sorted(ud, key=lambda v: v['y'])
            ul = sorted(ul, key=lambda v: (-v['y'], -v['x']), reverse=True)

            lr = sorted(lr, key=lambda v: -v['x'])

            # slice or fill to some number
            def slice_or_fill(sorted_list, fill_value, num):
                if len(sorted_list) >= num:
                    return sorted_list[:num]
                else:
                    while len(sorted_list) < num:
                        sorted_list.append(fill_value)
                    return sorted_list

            mode2fillvalue = dict(dl=dict(x=LANE_WIDTH / 2,
                                          y=-(CROSSROAD_SIZE / 2 + 30),
                                          v=0,
                                          phi=90,
                                          w=2.5,
                                          l=5,
                                          route=('1o', '4i')),
                                  du=dict(x=LANE_WIDTH * 1.5,
                                          y=-(CROSSROAD_SIZE / 2 + 30),
                                          v=0,
                                          phi=90,
                                          w=2.5,
                                          l=5,
                                          route=('1o', '3i')),
                                  dr=dict(x=LANE_WIDTH * (LANE_NUMBER - 0.5),
                                          y=-(CROSSROAD_SIZE / 2 + 30),
                                          v=0,
                                          phi=90,
                                          w=2.5,
                                          l=5,
                                          route=('1o', '2i')),
                                  ru=dict(x=(CROSSROAD_SIZE / 2 + 15),
                                          y=LANE_WIDTH * (LANE_NUMBER - 0.5),
                                          v=0,
                                          phi=180,
                                          w=2.5,
                                          l=5,
                                          route=('2o', '3i')),
                                  ur=dict(x=-LANE_WIDTH / 2,
                                          y=(CROSSROAD_SIZE / 2 + 20),
                                          v=0,
                                          phi=-90,
                                          w=2.5,
                                          l=5,
                                          route=('3o', '2i')),
                                  ud=dict(x=-LANE_WIDTH * 1.5,
                                          y=(CROSSROAD_SIZE / 2 + 20),
                                          v=0,
                                          phi=-90,
                                          w=2.5,
                                          l=5,
                                          route=('3o', '1i')),
                                  ul=dict(x=-LANE_WIDTH * (LANE_NUMBER - 0.5),
                                          y=(CROSSROAD_SIZE / 2 + 20),
                                          v=0,
                                          phi=-90,
                                          w=2.5,
                                          l=5,
                                          route=('3o', '4i')),
                                  lr=dict(x=-(CROSSROAD_SIZE / 2 + 20),
                                          y=-LANE_WIDTH * 1.5,
                                          v=0,
                                          phi=0,
                                          w=2.5,
                                          l=5,
                                          route=('4o', '2i')))

            tmp = OrderedDict()
            for mode, num in VEHICLE_MODE_DICT[task].items():
                tmp[mode] = slice_or_fill(eval(mode), mode2fillvalue[mode],
                                          num)

            return tmp

        list_of_interested_veh_dict = []
        self.interested_vehs = filter_interested_vehicles(
            self.all_vehicles, self.training_task)
        for part in list(self.interested_vehs.values()):
            list_of_interested_veh_dict.extend(part)

        for veh in list_of_interested_veh_dict:
            veh_x, veh_y, veh_v, veh_phi = veh['x'], veh['y'], veh['v'], veh[
                'phi']
            vehs_vector.extend([veh_x, veh_y, veh_v, veh_phi])
        self.per_veh_info_dim = 4
        return np.array(vehs_vector, dtype=np.float32)

    def recover_orig_position_fn(self, transformed_x, transformed_y, x, y,
                                 d):  # x, y, d are used to transform
        # coordination
        transformed_x, transformed_y, _ = rotate_coordination(
            transformed_x, transformed_y, 0, -d)
        orig_x, orig_y = shift_coordination(transformed_x, transformed_y, -x,
                                            -y)
        return orig_x, orig_y

    def _reset_init_state(self):
        if self.training_task == 'left':
            random_index = int(np.random.random() * (900 + 500)) + 700
        elif self.training_task == 'straight':
            random_index = int(np.random.random() * (1200 + 500)) + 700
        else:
            random_index = int(np.random.random() * (420 + 500)) + 700

        x, y, phi = self.ref_path.indexs2points(random_index)
        # v = 7 + 6 * np.random.random()
        v = EXPECTED_V * np.random.random()
        if self.training_task == 'left':
            routeID = 'dl'
        elif self.training_task == 'straight':
            routeID = 'du'
        else:
            assert self.training_task == 'right'
            routeID = 'dr'
        return dict(ego=dict(
            v_x=v,
            v_y=0,
            r=0,
            x=x.numpy(),
            y=y.numpy(),
            phi=phi.numpy(),
            l=self.ego_l,
            w=self.ego_w,
            routeID=routeID,
        ))

    def compute_reward(self, obs, action):
        obses, actions = obs[np.newaxis, :], action[np.newaxis, :]
        reward, _, _, _, _, reward_dict = \
            self.env_model.compute_rewards(obses, actions)
        for k, v in reward_dict.items():
            reward_dict[k] = v.numpy()[0]
        return reward.numpy()[0], reward_dict

    def render(self, mode='human'):
        if mode == 'human':
            # plot basic map
            square_length = CROSSROAD_SIZE
            extension = 40
            lane_width = LANE_WIDTH
            light_line_width = 3
            dotted_line_style = '--'
            solid_line_style = '-'

            plt.cla()
            ax = plt.axes([-0.05, -0.05, 1.1, 1.1])
            ax.axis("equal")
            # ax.add_patch(plt.Rectangle((-square_length / 2 - extension, -square_length / 2 - extension),
            #                            square_length + 2 * extension, square_length + 2 * extension, edgecolor='black',
            #                            facecolor='none', linewidth=2))

            # ----------arrow--------------
            plt.arrow(lane_width / 2, -square_length / 2 - 10, 0, 5, color='b')
            plt.arrow(lane_width / 2,
                      -square_length / 2 - 10 + 5,
                      -0.5,
                      0,
                      color='b',
                      head_width=1)
            plt.arrow(lane_width * 1.5,
                      -square_length / 2 - 10,
                      0,
                      4,
                      color='b',
                      head_width=1)
            plt.arrow(lane_width * 2.5,
                      -square_length / 2 - 10,
                      0,
                      5,
                      color='b')
            plt.arrow(lane_width * 2.5,
                      -square_length / 2 - 10 + 5,
                      0.5,
                      0,
                      color='b',
                      head_width=1)

            # ----------horizon--------------

            plt.plot([-square_length / 2 - extension, -square_length / 2],
                     [0.3, 0.3],
                     color='orange')
            plt.plot([-square_length / 2 - extension, -square_length / 2],
                     [-0.3, -0.3],
                     color='orange')
            plt.plot([square_length / 2 + extension, square_length / 2],
                     [0.3, 0.3],
                     color='orange')
            plt.plot([square_length / 2 + extension, square_length / 2],
                     [-0.3, -0.3],
                     color='orange')

            #
            for i in range(1, LANE_NUMBER + 1):
                linestyle = dotted_line_style if i < LANE_NUMBER else solid_line_style
                linewidth = 1 if i < LANE_NUMBER else 2
                plt.plot([-square_length / 2 - extension, -square_length / 2],
                         [i * lane_width, i * lane_width],
                         linestyle=linestyle,
                         color='black',
                         linewidth=linewidth)
                plt.plot([square_length / 2 + extension, square_length / 2],
                         [i * lane_width, i * lane_width],
                         linestyle=linestyle,
                         color='black',
                         linewidth=linewidth)
                plt.plot([-square_length / 2 - extension, -square_length / 2],
                         [-i * lane_width, -i * lane_width],
                         linestyle=linestyle,
                         color='black',
                         linewidth=linewidth)
                plt.plot([square_length / 2 + extension, square_length / 2],
                         [-i * lane_width, -i * lane_width],
                         linestyle=linestyle,
                         color='black',
                         linewidth=linewidth)

            # ----------vertical----------------
            plt.plot([0.3, 0.3],
                     [-square_length / 2 - extension, -square_length / 2],
                     color='orange')
            plt.plot([-0.3, -0.3],
                     [-square_length / 2 - extension, -square_length / 2],
                     color='orange')
            plt.plot([0.3, 0.3],
                     [square_length / 2 + extension, square_length / 2],
                     color='orange')
            plt.plot([-0.3, -0.3],
                     [square_length / 2 + extension, square_length / 2],
                     color='orange')

            #
            for i in range(1, LANE_NUMBER + 1):
                linestyle = dotted_line_style if i < LANE_NUMBER else solid_line_style
                linewidth = 1 if i < LANE_NUMBER else 2
                plt.plot([i * lane_width, i * lane_width],
                         [-square_length / 2 - extension, -square_length / 2],
                         linestyle=linestyle,
                         color='black',
                         linewidth=linewidth)
                plt.plot([i * lane_width, i * lane_width],
                         [square_length / 2 + extension, square_length / 2],
                         linestyle=linestyle,
                         color='black',
                         linewidth=linewidth)
                plt.plot([-i * lane_width, -i * lane_width],
                         [-square_length / 2 - extension, -square_length / 2],
                         linestyle=linestyle,
                         color='black',
                         linewidth=linewidth)
                plt.plot([-i * lane_width, -i * lane_width],
                         [square_length / 2 + extension, square_length / 2],
                         linestyle=linestyle,
                         color='black',
                         linewidth=linewidth)

            # ----------stop line--------------
            # plt.plot([0, 2 * lane_width], [-square_length / 2, -square_length / 2],
            #          color='black')
            # plt.plot([-2 * lane_width, 0], [square_length / 2, square_length / 2],
            #          color='black')
            # plt.plot([-square_length / 2, -square_length / 2], [0, -2 * lane_width],
            #          color='black')
            # plt.plot([square_length / 2, square_length / 2], [2 * lane_width, 0],
            #          color='black')
            v_light = self.v_light
            if v_light == 0:
                v_color, h_color = 'green', 'red'
            elif v_light == 1:
                v_color, h_color = 'orange', 'red'
            elif v_light == 2:
                v_color, h_color = 'red', 'green'
            else:
                v_color, h_color = 'red', 'orange'

            plt.plot([0, (LANE_NUMBER - 1) * lane_width],
                     [-square_length / 2, -square_length / 2],
                     color=v_color,
                     linewidth=light_line_width)
            plt.plot(
                [(LANE_NUMBER - 1) * lane_width, LANE_NUMBER * lane_width],
                [-square_length / 2, -square_length / 2],
                color='green',
                linewidth=light_line_width)

            plt.plot(
                [-LANE_NUMBER * lane_width, -(LANE_NUMBER - 1) * lane_width],
                [square_length / 2, square_length / 2],
                color='green',
                linewidth=light_line_width)
            plt.plot([-(LANE_NUMBER - 1) * lane_width, 0],
                     [square_length / 2, square_length / 2],
                     color=v_color,
                     linewidth=light_line_width)

            plt.plot([-square_length / 2, -square_length / 2],
                     [0, -(LANE_NUMBER - 1) * lane_width],
                     color=h_color,
                     linewidth=light_line_width)
            plt.plot(
                [-square_length / 2, -square_length / 2],
                [-(LANE_NUMBER - 1) * lane_width, -LANE_NUMBER * lane_width],
                color='green',
                linewidth=light_line_width)

            plt.plot([square_length / 2, square_length / 2],
                     [(LANE_NUMBER - 1) * lane_width, 0],
                     color=h_color,
                     linewidth=light_line_width)
            plt.plot(
                [square_length / 2, square_length / 2],
                [LANE_NUMBER * lane_width, (LANE_NUMBER - 1) * lane_width],
                color='green',
                linewidth=light_line_width)

            # ----------Oblique--------------
            plt.plot([LANE_NUMBER * lane_width, square_length / 2],
                     [-square_length / 2, -LANE_NUMBER * lane_width],
                     color='black',
                     linewidth=2)
            plt.plot([LANE_NUMBER * lane_width, square_length / 2],
                     [square_length / 2, LANE_NUMBER * lane_width],
                     color='black',
                     linewidth=2)
            plt.plot([-LANE_NUMBER * lane_width, -square_length / 2],
                     [-square_length / 2, -LANE_NUMBER * lane_width],
                     color='black',
                     linewidth=2)
            plt.plot([-LANE_NUMBER * lane_width, -square_length / 2],
                     [square_length / 2, LANE_NUMBER * lane_width],
                     color='black',
                     linewidth=2)

            def is_in_plot_area(x, y, tolerance=5):
                if -square_length / 2 - extension + tolerance < x < square_length / 2 + extension - tolerance and \
                        -square_length / 2 - extension + tolerance < y < square_length / 2 + extension - tolerance:
                    return True
                else:
                    return False

            def draw_rotate_rec(x, y, a, l, w, color, linestyle='-'):
                RU_x, RU_y, _ = rotate_coordination(l / 2, w / 2, 0, -a)
                RD_x, RD_y, _ = rotate_coordination(l / 2, -w / 2, 0, -a)
                LU_x, LU_y, _ = rotate_coordination(-l / 2, w / 2, 0, -a)
                LD_x, LD_y, _ = rotate_coordination(-l / 2, -w / 2, 0, -a)
                ax.plot([RU_x + x, RD_x + x], [RU_y + y, RD_y + y],
                        color=color,
                        linestyle=linestyle)
                ax.plot([RU_x + x, LU_x + x], [RU_y + y, LU_y + y],
                        color=color,
                        linestyle=linestyle)
                ax.plot([LD_x + x, RD_x + x], [LD_y + y, RD_y + y],
                        color=color,
                        linestyle=linestyle)
                ax.plot([LD_x + x, LU_x + x], [LD_y + y, LU_y + y],
                        color=color,
                        linestyle=linestyle)

            def plot_phi_line(x, y, phi, color):
                line_length = 5
                x_forw, y_forw = x + line_length * cos(phi*pi/180.),\
                                 y + line_length * sin(phi*pi/180.)
                plt.plot([x, x_forw], [y, y_forw], color=color, linewidth=0.5)

            # plot cars
            for veh in self.all_vehicles:
                veh_x = veh['x']
                veh_y = veh['y']
                veh_phi = veh['phi']
                veh_l = veh['l']
                veh_w = veh['w']
                if is_in_plot_area(veh_x, veh_y):
                    plot_phi_line(veh_x, veh_y, veh_phi, 'black')
                    draw_rotate_rec(veh_x, veh_y, veh_phi, veh_l, veh_w,
                                    'black')

            # plot_interested vehs
            for mode, num in self.veh_mode_dict.items():
                for i in range(num):
                    veh = self.interested_vehs[mode][i]
                    veh_x = veh['x']
                    veh_y = veh['y']
                    veh_phi = veh['phi']
                    veh_l = veh['l']
                    veh_w = veh['w']
                    task2color = {'left': 'b', 'straight': 'c', 'right': 'm'}

                    if is_in_plot_area(veh_x, veh_y):
                        plot_phi_line(veh_x, veh_y, veh_phi, 'black')
                        task = MODE2TASK[mode]
                        color = task2color[task]
                        draw_rotate_rec(veh_x,
                                        veh_y,
                                        veh_phi,
                                        veh_l,
                                        veh_w,
                                        color,
                                        linestyle=':')

            # plot own car
            # dict(v_x=ego_dict['v_x'],
            #      v_y=ego_dict['v_y'],
            #      r=ego_dict['r'],
            #      x=ego_dict['x'],
            #      y=ego_dict['y'],
            #      phi=ego_dict['phi'],
            #      l=ego_dict['l'],
            #      w=ego_dict['w'],
            #      Corner_point=self.cal_corner_point_of_ego_car(ego_dict)
            #      alpha_f_bound=alpha_f_bound,
            #      alpha_r_bound=alpha_r_bound,
            #      r_bound=r_bound)

            ego_v_x = self.ego_dynamics['v_x']
            ego_v_y = self.ego_dynamics['v_y']
            ego_r = self.ego_dynamics['r']
            ego_x = self.ego_dynamics['x']
            ego_y = self.ego_dynamics['y']
            ego_phi = self.ego_dynamics['phi']
            ego_l = self.ego_dynamics['l']
            ego_w = self.ego_dynamics['w']
            ego_alpha_f = self.ego_dynamics['alpha_f']
            ego_alpha_r = self.ego_dynamics['alpha_r']
            alpha_f_bound = self.ego_dynamics['alpha_f_bound']
            alpha_r_bound = self.ego_dynamics['alpha_r_bound']
            r_bound = self.ego_dynamics['r_bound']

            plot_phi_line(ego_x, ego_y, ego_phi, 'red')
            draw_rotate_rec(ego_x, ego_y, ego_phi, ego_l, ego_w, 'red')

            # plot future data
            tracking_info = self.obs[self.ego_info_dim:self.ego_info_dim +
                                     self.per_tracking_info_dim *
                                     (self.num_future_data + 1)]
            future_path = tracking_info[self.per_tracking_info_dim:]
            for i in range(self.num_future_data):
                delta_x, delta_y, delta_phi = future_path[
                    i * self.per_tracking_info_dim:(i + 1) *
                    self.per_tracking_info_dim]
                path_x, path_y, path_phi = ego_x + delta_x, ego_y + delta_y, ego_phi - delta_phi
                plt.plot(path_x, path_y, 'g.')
                plot_phi_line(path_x, path_y, path_phi, 'g')

            delta_, _, _ = tracking_info[:3]
            ax.plot(self.ref_path.path[0], self.ref_path.path[1], color='g')
            indexs, points = self.ref_path.find_closest_point(
                np.array([ego_x], np.float32), np.array([ego_y], np.float32))
            path_x, path_y, path_phi = points[0][0], points[1][0], points[2][0]
            plt.plot(path_x, path_y, 'g.')
            delta_x, delta_y, delta_phi = ego_x - path_x, ego_y - path_y, ego_phi - path_phi

            # plot real time traj
            # try:
            #     color = ['b', 'lime']
            #     for i, item in enumerate(real_time_traj):
            #         if i == path_index:
            #             plt.plot(item.path[0], item.path[1], color=color[i], alpha=1.0)
            #         else:
            #             plt.plot(item.path[0], item.path[1], color=color[i], alpha=0.3)
            #         indexs, points = item.find_closest_point(np.array([ego_x], np.float32), np.array([ego_y], np.float32))
            #         path_x, path_y, path_phi = points[0][0], points[1][0], points[2][0]
            #         plt.plot(path_x, path_y,  color=color[i])
            # except Exception:
            #     pass

            # for j, item_point in enumerate(self.real_path.feature_points_all):
            #     for k in range(len(item_point)):
            #         plt.scatter(item_point[k][0], item_point[k][1], c='g')

            # text
            text_x, text_y_start = -110, 60
            ge = iter(range(0, 1000, 4))
            plt.text(text_x, text_y_start - next(ge),
                     'ego_x: {:.2f}m'.format(ego_x))
            plt.text(text_x, text_y_start - next(ge),
                     'ego_y: {:.2f}m'.format(ego_y))
            plt.text(text_x, text_y_start - next(ge),
                     'path_x: {:.2f}m'.format(path_x))
            plt.text(text_x, text_y_start - next(ge),
                     'path_y: {:.2f}m'.format(path_y))
            plt.text(text_x, text_y_start - next(ge),
                     'delta_: {:.2f}m'.format(delta_))
            plt.text(text_x, text_y_start - next(ge),
                     'delta_x: {:.2f}m'.format(delta_x))
            plt.text(text_x, text_y_start - next(ge),
                     'delta_y: {:.2f}m'.format(delta_y))
            plt.text(text_x, text_y_start - next(ge),
                     r'ego_phi: ${:.2f}\degree$'.format(ego_phi))
            plt.text(text_x, text_y_start - next(ge),
                     r'path_phi: ${:.2f}\degree$'.format(path_phi))
            plt.text(text_x, text_y_start - next(ge),
                     r'delta_phi: ${:.2f}\degree$'.format(delta_phi))

            plt.text(text_x, text_y_start - next(ge),
                     'v_x: {:.2f}m/s'.format(ego_v_x))
            plt.text(text_x, text_y_start - next(ge),
                     'exp_v: {:.2f}m/s'.format(self.exp_v))
            plt.text(text_x, text_y_start - next(ge),
                     'v_y: {:.2f}m/s'.format(ego_v_y))
            plt.text(text_x, text_y_start - next(ge),
                     'yaw_rate: {:.2f}rad/s'.format(ego_r))
            plt.text(
                text_x, text_y_start - next(ge),
                'yaw_rate bound: [{:.2f}, {:.2f}]'.format(-r_bound, r_bound))

            plt.text(text_x, text_y_start - next(ge),
                     r'$\alpha_f$: {:.2f} rad'.format(ego_alpha_f))
            plt.text(
                text_x, text_y_start - next(ge),
                r'$\alpha_f$ bound: [{:.2f}, {:.2f}] '.format(
                    -alpha_f_bound, alpha_f_bound))
            plt.text(text_x, text_y_start - next(ge),
                     r'$\alpha_r$: {:.2f} rad'.format(ego_alpha_r))
            plt.text(
                text_x, text_y_start - next(ge),
                r'$\alpha_r$ bound: [{:.2f}, {:.2f}] '.format(
                    -alpha_r_bound, alpha_r_bound))
            if self.action is not None:
                steer, a_x = self.action[0], self.action[1]
                plt.text(
                    text_x, text_y_start - next(ge),
                    r'steer: {:.2f}rad (${:.2f}\degree$)'.format(
                        steer, steer * 180 / np.pi))
                plt.text(text_x, text_y_start - next(ge),
                         'a_x: {:.2f}m/s^2'.format(a_x))

            text_x, text_y_start = 80, 60
            ge = iter(range(0, 1000, 4))

            # done info
            plt.text(text_x, text_y_start - next(ge),
                     'done info: {}'.format(self.done_type))

            # reward info
            if self.reward_info is not None:
                for key, val in self.reward_info.items():
                    plt.text(text_x, text_y_start - next(ge),
                             '{}: {:.4f}'.format(key, val))

            # indicator for trajectory selection
            # text_x, text_y_start = -25, -65
            # ge = iter(range(0, 1000, 6))
            # if traj_return is not None:
            #     for i, value in enumerate(traj_return):
            #         if i==path_index:
            #             plt.text(text_x, text_y_start-next(ge), 'track_error={:.4f}, collision_risk={:.4f}'.format(value[0], value[1]), fontsize=14, color=color[i], fontstyle='italic')
            #         else:
            #             plt.text(text_x, text_y_start-next(ge), 'track_error={:.4f}, collision_risk={:.4f}'.format(value[0], value[1]), fontsize=12, color=color[i], fontstyle='italic')

            plt.show()
            plt.pause(0.001)

    def set_traj(self, trajectory):
        """set the real trajectory to reconstruct observation"""
        self.ref_path = trajectory