コード例 #1
0
    def run(self, category, im_list, output_fn="tmp.txt"):
        """Score images with the discriminator of ``category`` and record the results.

        For every image two discriminator outputs are computed: one for the
        result of ``datasource.get_all_prob`` and one for the one-hot ground
        truth.  Both are stored in a ``Recorder`` backed by
        ``<main_dir>/predictions/<output_fn>``.  The recorder is opened with
        ``restart=False`` and images it already contains are skipped.

        Args:
            category: Category name; selects the checkpoint directory
                ``<main_dir>/<category>/<lr>`` and is passed to
                ``Discriminator.predict``.
            im_list: Iterable of image identifiers understood by
                ``self.datasource``.
            output_fn: File name of the record inside the predictions folder.

        Returns:
            The path of the record file.
        """
        # Load checkpoint
        checkpoint_dir = "{}/{}/{}".format(self.main_dir, category, self.lr)
        checkpoint, _ = get_latest_checkpoint(checkpoint_dir)
        disc = Discriminator(lr=self.lr, checkpoint=checkpoint)

        output_dir = join(self.main_dir, "predictions")
        output_path = join(output_dir, output_fn)
        recorder = Recorder(output_path, restart=False)

        for im in im_list:
            if recorder.contains(im):
                # Single pre-formatted string: prints identically on Python 2 and 3.
                print("{} {} Done.".format(im, recorder.get(im)))
                continue

            img, _ = self.datasource.get_image(im)
            gt, _ = self.datasource.get_ground_truth(im, one_hot=True)
            ap, _ = self.datasource.get_all_prob(im)
            # Use the `category` argument (the one the checkpoint was loaded
            # for) rather than the module-global `args.category`.
            pr_prob = disc.predict(img, ap, category)
            gt_prob = disc.predict(img, gt, category)

            print("{} {} {}".format(im, pr_prob, gt_prob))
            recorder.save(im, [pr_prob, gt_prob])
        recorder.write()
        return output_path
コード例 #2
0
ファイル: hier_decision.py プロジェクト: idthanm/env_build
class HierarchicalDecision(object):
    def __init__(self, task, train_exp_dir, ite, logdir=None):
        """Build the policy, environment and model, then warm up the TF graphs.

        Args:
            task: Task name; forwarded to the environment, the environment
                model and the path generator, and used in the model path.
            train_exp_dir: Experiment folder under ``../utils/models/<task>/``.
            ite: Second argument handed to ``LoadPolicy`` together with the
                model folder.
            logdir: Optional log directory.  When given, the constructor
                arguments are dumped to ``<logdir>/config.json``.
        """
        self.task = task
        model_dir = '../utils/models/{}/{}'.format(task, train_exp_dir)
        self.policy = LoadPolicy(model_dir, ite)
        self.env = CrossroadEnd2end(training_task=self.task, mode='testing')
        self.model = EnvironmentModel(self.task, mode='selecting')
        self.recorder = Recorder()
        self.episode_counter = -1
        self.step_counter = -1
        self.obs = None
        self.stg = MultiPathGenerator()
        self.step_timer = TimerStat()
        self.ss_timer = TimerStat()

        self.logdir = logdir
        if self.logdir is not None:
            config = {'task': task, 'train_exp_dir': train_exp_dir, 'ite': ite}
            with open(self.logdir + '/config.json', 'w', encoding='utf-8') as config_file:
                json.dump(config, config_file, ensure_ascii=False, indent=4)

        self.fig = plt.figure(figsize=(8, 8))
        plt.ion()
        self.hist_posi = []
        self.old_index = 0
        self.path_list = self.stg.generate_path(self.task)

        # Build the tf.function graphs in advance: call is_safe once for each
        # of the path indices 0..2, then run the batched policy functions on a
        # three-row observation batch.
        for warmup_index in range(3):
            warmup_obs = self.env.reset()
            warmup_obs = tf.convert_to_tensor(warmup_obs[np.newaxis, :], dtype=tf.float32)
            self.is_safe(warmup_obs, warmup_index)
        warmup_batch = np.tile(self.env.reset(), (3, 1))
        self.policy.run_batch(warmup_batch)
        self.policy.obj_value_batch(warmup_batch)

        self.reset()

    def reset(self):
        """Start a new episode and return the first observation.

        Resets the environment, the recorder, the selected-path index and the
        ego position history.  When logging is enabled it also creates
        ``<logdir>/episode<k>/figs`` for the new episode, saves the recorder,
        and, from the second episode on, post-processes the snapshots and
        curves of the episode that just finished.

        Returns:
            The observation returned by ``self.env.reset()``.
        """
        self.obs = self.env.reset()
        self.recorder.reset()
        self.old_index = 0
        self.hist_posi = []

        if self.logdir is None:
            return self.obs

        self.episode_counter += 1
        os.makedirs(self.logdir + '/episode{}/figs'.format(self.episode_counter))
        self.step_counter = -1
        self.recorder.save(self.logdir)

        if self.episode_counter >= 1:
            finished_episode = self.episode_counter - 1
            select_and_rename_snapshots_of_an_episode(self.logdir, finished_episode, 12)
            figs_dir = self.logdir + '/episode{}/figs'.format(finished_episode)
            self.recorder.plot_and_save_ith_episode_curves(finished_episode,
                                                           figs_dir,
                                                           isshow=False)
        return self.obs

    # @tf.function
    # def is_safe(self, obs, path_index):
    #     self.model.ref_path.set_path(path_index)
    #     action = self.policy.run_batch(obs)
    #     veh2veh4real = self.model.ss(obs, action, lam=0.1)
    #     return False if veh2veh4real[0] > 0 else True

    @tf.function
    def is_safe(self, obs, path_index):
        """Roll the model forward five policy steps and report whether any penalty accrued.

        Args:
            obs: Batched observation tensor the rollout starts from (callers
                in this class pass a float32 tensor with a leading batch axis).
            path_index: Index of the path handed to ``self.model.add_traj``.

        Returns:
            False when the accumulated ``veh2veh4real[0]`` over the five
            rollout steps is greater than zero, True otherwise.
        """
        self.model.add_traj(obs, path_index)
        punish = 0.
        for step in range(5):
            action = self.policy.run_batch(obs)
            # Feed the predicted observation back in so each step continues the rollout.
            obs, _, _, _, veh2veh4real, _ = self.model.rollout_out(action)
            # presumably a vehicle-to-vehicle collision penalty for the first
            # batch entry - TODO confirm against EnvironmentModel.rollout_out
            punish += veh2veh4real[0]
        return False if punish > 0 else True

    def safe_shield(self, real_obs, path_index):
        """Return the policy action, or a fixed fallback action if it is judged unsafe.

        Args:
            real_obs: Unbatched observation array for the selected path.
            path_index: Index of the selected path, forwarded to ``is_safe``.

        Returns:
            A tuple ``(action, shield_used)``.  ``action`` is the policy output
            for ``real_obs`` when ``is_safe`` holds; otherwise it is the
            float32 array ``[0., -1.]`` and ``shield_used`` is True.
        """
        batched_obs = tf.convert_to_tensor(real_obs[np.newaxis, :], dtype=tf.float32)

        if self.is_safe(batched_obs, path_index):
            return self.policy.run_batch(batched_obs).numpy()[0], False

        print('SAFETY SHIELD STARTED!')
        fallback_actions = [[[0., -1.]]]
        fallback = np.array(fallback_actions[0], dtype=np.float32).squeeze(0)
        return fallback, True

    def step(self):
        """Pick a path, compute a shielded action, render, record and step the env.

        The path with the smallest value is chosen, but the previously
        selected path is kept unless the new one improves on it by at least
        0.1.

        Returns:
            The ``done`` flag returned by ``self.env.step``.
        """
        self.step_counter += 1
        with self.step_timer:
            # select optimal path: one observation per candidate path
            candidate_obs = []
            for candidate in self.path_list:
                self.env.set_traj(candidate)
                candidate_obs.append(self.env._get_obs())
            stacked_obs = tf.stack(candidate_obs, axis=0)
            path_values = self.policy.obj_value_batch(stacked_obs).numpy()

            # value is to approximate (- sum of reward)
            current_value = path_values[self.old_index]
            best_index = int(np.argmin(path_values))
            best_value = min(path_values)
            if current_value - best_value < 0.1:
                path_index = self.old_index
            else:
                path_index = best_index
            self.old_index = path_index

            self.env.set_traj(self.path_list[path_index])
            self.obs_real = candidate_obs[path_index]

            # obtain safe action
            with self.ss_timer:
                safe_action, is_ss = self.safe_shield(self.obs_real, path_index)
            print('ALL TIME:', self.step_timer.mean, 'ss', self.ss_timer.mean)

        self.render(self.path_list, path_values, path_index)
        self.recorder.record(self.obs_real, safe_action, self.step_timer.mean,
                             path_index, path_values, self.ss_timer.mean, is_ss)
        self.obs, _, done, _ = self.env.step(safe_action)
        return done

    def render(self, traj_list, path_values, path_index):
        """Redraw the crossroad scene on the matplotlib figure.

        Draws, in order: direction arrows, lane markings, stop lines coloured
        by ``self.env.v_light``, the corner lines of the junction, every
        vehicle of ``self.env.all_vehicles`` inside the plot area, the ego
        vehicle and its position history, the future tracking points decoded
        from ``self.obs``, and the candidate trajectories.  When
        ``self.logdir`` is set the frame is saved to
        ``<logdir>/episode<k>/step<nnn>.png``.

        Args:
            traj_list: Candidate paths.  Each item must provide ``path``
                (x values at index 0, y values at index 1) and
                ``find_closest_point``.
            path_values: Values of the candidate paths.  Not used by the
                drawing code; kept for interface compatibility.
            path_index: Index in ``traj_list`` of the selected path.  It is
                drawn opaque, the others with ``alpha=0.3``.
        """
        square_length = CROSSROAD_SIZE
        extension = 40
        lane_width = LANE_WIDTH
        light_line_width = 3
        dotted_line_style = '--'
        solid_line_style = '-'

        plt.cla()
        ax = plt.axes([-0.05, -0.05, 1.1, 1.1])
        # Use a dedicated loop variable so `ax` keeps referring to the axes
        # created above; draw_rotate_rec adds its patches to `ax`.
        for fig_ax in self.fig.get_axes():
            fig_ax.axis('off')
        ax.axis("equal")

        # ----------arrow--------------
        plt.arrow(lane_width / 2, -square_length / 2 - 10, 0, 5, color='b')
        plt.arrow(lane_width / 2, -square_length / 2 - 10 + 5, -0.5, 0, color='b', head_width=1)
        plt.arrow(lane_width * 1.5, -square_length / 2 - 10, 0, 4, color='b', head_width=1)
        plt.arrow(lane_width * 2.5, -square_length / 2 - 10, 0, 5, color='b')
        plt.arrow(lane_width * 2.5, -square_length / 2 - 10 + 5, 0.5, 0, color='b', head_width=1)

        # ----------horizon--------------
        # orange double centre line of the horizontal road
        plt.plot([-square_length / 2 - extension, -square_length / 2], [0.3, 0.3], color='orange')
        plt.plot([-square_length / 2 - extension, -square_length / 2], [-0.3, -0.3], color='orange')
        plt.plot([square_length / 2 + extension, square_length / 2], [0.3, 0.3], color='orange')
        plt.plot([square_length / 2 + extension, square_length / 2], [-0.3, -0.3], color='orange')

        # lane separators (dashed) and road edge (solid, outermost lane)
        for i in range(1, LANE_NUMBER + 1):
            linestyle = dotted_line_style if i < LANE_NUMBER else solid_line_style
            linewidth = 1 if i < LANE_NUMBER else 2
            plt.plot([-square_length / 2 - extension, -square_length / 2], [i * lane_width, i * lane_width],
                     linestyle=linestyle, color='black', linewidth=linewidth)
            plt.plot([square_length / 2 + extension, square_length / 2], [i * lane_width, i * lane_width],
                     linestyle=linestyle, color='black', linewidth=linewidth)
            plt.plot([-square_length / 2 - extension, -square_length / 2], [-i * lane_width, -i * lane_width],
                     linestyle=linestyle, color='black', linewidth=linewidth)
            plt.plot([square_length / 2 + extension, square_length / 2], [-i * lane_width, -i * lane_width],
                     linestyle=linestyle, color='black', linewidth=linewidth)

        # ----------vertical----------------
        # orange double centre line of the vertical road
        plt.plot([0.3, 0.3], [-square_length / 2 - extension, -square_length / 2], color='orange')
        plt.plot([-0.3, -0.3], [-square_length / 2 - extension, -square_length / 2], color='orange')
        plt.plot([0.3, 0.3], [square_length / 2 + extension, square_length / 2], color='orange')
        plt.plot([-0.3, -0.3], [square_length / 2 + extension, square_length / 2], color='orange')

        # lane separators (dashed) and road edge (solid, outermost lane)
        for i in range(1, LANE_NUMBER + 1):
            linestyle = dotted_line_style if i < LANE_NUMBER else solid_line_style
            linewidth = 1 if i < LANE_NUMBER else 2
            plt.plot([i * lane_width, i * lane_width], [-square_length / 2 - extension, -square_length / 2],
                     linestyle=linestyle, color='black', linewidth=linewidth)
            plt.plot([i * lane_width, i * lane_width], [square_length / 2 + extension, square_length / 2],
                     linestyle=linestyle, color='black', linewidth=linewidth)
            plt.plot([-i * lane_width, -i * lane_width], [-square_length / 2 - extension, -square_length / 2],
                     linestyle=linestyle, color='black', linewidth=linewidth)
            plt.plot([-i * lane_width, -i * lane_width], [square_length / 2 + extension, square_length / 2],
                     linestyle=linestyle, color='black', linewidth=linewidth)

        # stop-line colours for the vertical (v) and horizontal (h) directions
        v_light = self.env.v_light
        if v_light == 0:
            v_color, h_color = 'green', 'red'
        elif v_light == 1:
            v_color, h_color = 'orange', 'red'
        elif v_light == 2:
            v_color, h_color = 'red', 'green'
        else:
            v_color, h_color = 'red', 'orange'

        # On every approach the outermost lane segment is always drawn green.
        plt.plot([0, (LANE_NUMBER - 1) * lane_width], [-square_length / 2, -square_length / 2],
                 color=v_color, linewidth=light_line_width)
        plt.plot([(LANE_NUMBER - 1) * lane_width, LANE_NUMBER * lane_width], [-square_length / 2, -square_length / 2],
                 color='green', linewidth=light_line_width)

        plt.plot([-LANE_NUMBER * lane_width, -(LANE_NUMBER - 1) * lane_width], [square_length / 2, square_length / 2],
                 color='green', linewidth=light_line_width)
        plt.plot([-(LANE_NUMBER - 1) * lane_width, 0], [square_length / 2, square_length / 2],
                 color=v_color, linewidth=light_line_width)

        plt.plot([-square_length / 2, -square_length / 2], [0, -(LANE_NUMBER - 1) * lane_width],
                 color=h_color, linewidth=light_line_width)
        plt.plot([-square_length / 2, -square_length / 2], [-(LANE_NUMBER - 1) * lane_width, -LANE_NUMBER * lane_width],
                 color='green', linewidth=light_line_width)

        plt.plot([square_length / 2, square_length / 2], [(LANE_NUMBER - 1) * lane_width, 0],
                 color=h_color, linewidth=light_line_width)
        plt.plot([square_length / 2, square_length / 2], [LANE_NUMBER * lane_width, (LANE_NUMBER - 1) * lane_width],
                 color='green', linewidth=light_line_width)

        # ----------Oblique--------------
        plt.plot([LANE_NUMBER * lane_width, square_length / 2], [-square_length / 2, -LANE_NUMBER * lane_width],
                 color='black', linewidth=2)
        plt.plot([LANE_NUMBER * lane_width, square_length / 2], [square_length / 2, LANE_NUMBER * lane_width],
                 color='black', linewidth=2)
        plt.plot([-LANE_NUMBER * lane_width, -square_length / 2], [-square_length / 2, -LANE_NUMBER * lane_width],
                 color='black', linewidth=2)
        plt.plot([-LANE_NUMBER * lane_width, -square_length / 2], [square_length / 2, LANE_NUMBER * lane_width],
                 color='black', linewidth=2)

        def is_in_plot_area(x, y, tolerance=5):
            """Return True if (x, y) lies inside the drawn area shrunk by `tolerance`."""
            return (-square_length / 2 - extension + tolerance < x < square_length / 2 + extension - tolerance and
                    -square_length / 2 - extension + tolerance < y < square_length / 2 + extension - tolerance)

        def draw_rotate_rec(x, y, a, l, w, c):
            """Add a white rectangle of length l and width w, centred at (x, y), heading a degrees."""
            bottom_left_x, bottom_left_y, _ = rotate_coordination(-l / 2, w / 2, 0, -a)
            ax.add_patch(plt.Rectangle((x + bottom_left_x, y + bottom_left_y), w, l, edgecolor=c,
                                       facecolor='white', angle=-(90 - a), zorder=50))

        def plot_phi_line(x, y, phi, color):
            """Draw a short heading indicator from (x, y) in direction phi (degrees)."""
            line_length = 3
            x_forw, y_forw = x + line_length * cos(phi * pi / 180.), \
                             y + line_length * sin(phi * pi / 180.)
            plt.plot([x, x_forw], [y, y_forw], color=color, linewidth=0.5)

        # plot cars
        for veh in self.env.all_vehicles:
            veh_x = veh['x']
            veh_y = veh['y']
            veh_phi = veh['phi']
            veh_l = veh['l']
            veh_w = veh['w']
            if is_in_plot_area(veh_x, veh_y):
                plot_phi_line(veh_x, veh_y, veh_phi, 'black')
                draw_rotate_rec(veh_x, veh_y, veh_phi, veh_l, veh_w, 'black')

        # NOTE: colouring of the "interested" vehicles by task is disabled.

        # plot ego vehicle
        ego_x = self.env.ego_dynamics['x']
        ego_y = self.env.ego_dynamics['y']
        ego_phi = self.env.ego_dynamics['phi']
        ego_l = self.env.ego_dynamics['l']
        ego_w = self.env.ego_dynamics['w']

        plot_phi_line(ego_x, ego_y, ego_phi, 'fuchsia')
        draw_rotate_rec(ego_x, ego_y, ego_phi, ego_l, ego_w, 'fuchsia')
        self.hist_posi.append((ego_x, ego_y))

        # plot history
        xs = [pos[0] for pos in self.hist_posi]
        ys = [pos[1] for pos in self.hist_posi]
        plt.scatter(np.array(xs), np.array(ys), color='fuchsia', alpha=0.1)

        # plot future data: the slice of self.obs after the ego block holds
        # (num_future_data + 1) tracking entries; the first entry is skipped.
        per_dim = self.env.per_tracking_info_dim
        tracking_info = self.obs[
                        self.env.ego_info_dim:self.env.ego_info_dim + per_dim * (self.env.num_future_data + 1)]
        future_path = tracking_info[per_dim:]
        for i in range(self.env.num_future_data):
            delta_x, delta_y, delta_phi = future_path[i * per_dim:(i + 1) * per_dim]
            path_x, path_y, path_phi = ego_x + delta_x, ego_y + delta_y, ego_phi - delta_phi
            plt.plot(path_x, path_y, 'g.')
            plot_phi_line(path_x, path_y, path_phi, 'g')

        # plot real time traj
        color = ['blue', 'coral', 'darkcyan']
        try:
            for i, item in enumerate(traj_list):
                # Cycle through the palette so more than three paths still draw.
                path_color = color[i % len(color)]
                if i == path_index:
                    plt.plot(item.path[0], item.path[1], color=path_color)
                else:
                    plt.plot(item.path[0], item.path[1], color=path_color, alpha=0.3)
                indexs, points = item.find_closest_point(np.array([ego_x], np.float32), np.array([ego_y], np.float32))
                path_x, path_y, path_phi = points[0][0], points[1][0], points[2][0]
                # NOTE(review): a single point plotted without a marker is
                # likely invisible - confirm whether a marker was intended.
                plt.plot(path_x, path_y, color=path_color)
        except Exception as err:
            # Trajectory drawing stays best-effort, but the failure is reported
            # instead of being silently discarded.
            print('render: failed to plot candidate trajectories:', err)

        # NOTE: the on-figure text overlay (ego state, done/reward info and
        # per-path values) is disabled.

        plt.show()
        plt.pause(0.001)
        if self.logdir is not None:
            plt.savefig(self.logdir + '/episode{}'.format(self.episode_counter) + '/step{:03d}.png'.format(self.step_counter))
コード例 #3
0
class HierarchicalDecision(object):
    def __init__(self, task, logdir=None):
        """Load the policy trained for ``task`` and set up env, model and figure.

        Args:
            task: One of ``'left'``, ``'right'`` or ``'straight'``.
            logdir: Optional log directory used by ``reset``.

        Raises:
            ValueError: If ``task`` is not one of the supported task names.
                Previously an unknown task left ``self.policy`` unset and
                failed later with an ``AttributeError``.
        """
        # (model folder, second LoadPolicy argument) for every supported task
        policy_checkpoints = {
            'left': ('../utils/models/left', 100000),
            'right': ('../utils/models/right', 145000),
            'straight': ('../utils/models/straight', 95000),
        }
        if task not in policy_checkpoints:
            raise ValueError("unknown task {!r}; expected one of {}".format(
                task, sorted(policy_checkpoints)))

        self.task = task
        self.policy = LoadPolicy(*policy_checkpoints[task])
        self.env = CrossroadEnd2end(training_task=self.task)
        self.model = EnvironmentModel(self.task, mode='selecting')
        self.recorder = Recorder()
        self.episode_counter = -1
        self.step_counter = -1
        self.obs = None
        self.stg = None  # created in reset()
        self.step_timer = TimerStat()
        self.ss_timer = TimerStat()
        self.logdir = logdir
        self.fig = plt.figure(figsize=(8, 8))
        plt.ion()
        self.hist_posi = []
        self.reset()

    def reset(self):
        """Start a new episode and return the first observation.

        Resets the environment, creates a fresh ``MultiPathGenerator``, and
        clears the recorder and the ego position history.  When logging is
        enabled it also creates ``<logdir>/episode<k>`` and saves the recorder.

        Returns:
            The observation returned by ``self.env.reset()``.
        """
        self.obs = self.env.reset()
        self.stg = MultiPathGenerator()
        self.recorder.reset()
        self.hist_posi = []

        if self.logdir is None:
            return self.obs

        self.episode_counter += 1
        episode_dir = self.logdir + '/episode{}'.format(self.episode_counter)
        os.makedirs(episode_dir)
        self.step_counter = -1
        self.recorder.save(self.logdir)
        return self.obs

    def safe_shield(self, real_obs, traj):
        """Return the policy action, or a fixed fallback if a look-ahead is penalised.

        The model is rolled out three times with the policy action and the
        ``veh2veh4real`` values are summed.  A non-zero sum selects the
        fallback action ``[0, -1.0]``; otherwise the policy action for
        ``real_obs`` is returned.

        Args:
            real_obs: Unbatched observation array.
            traj: Trajectory handed to ``self.model.add_traj``.

        Returns:
            The action to execute, as a numpy array.
        """
        batched_obs = tf.convert_to_tensor(real_obs[np.newaxis, :])
        self.model.add_traj(batched_obs, traj)

        total_punishment = 0.0
        for _ in range(3):
            # NOTE(review): the first value returned by rollout_out is
            # discarded, so the policy sees the same observation on every
            # iteration - confirm this is intended.
            action = self.policy.run(batched_obs)
            _, _, _, _, veh2veh4real, _ = self.model.rollout_out(action)
            total_punishment += veh2veh4real

        collision_predicted = total_punishment != 0
        if not collision_predicted:
            return self.policy.run(batched_obs).numpy()[0]

        print('original action will cause collision within three steps!!!')
        # Only one fallback action is defined; the search over several
        # candidate safe actions is disabled.
        action_bound = 1.0
        fallback_action = [[0, -action_bound]]
        return np.array(fallback_action).squeeze(0)

    def step(self):
        """Generate paths, pick the best one, act, render, record and step the env.

        The path whose first value column is largest is selected.  The safety
        shield and the temporary red-light handling are disabled, so the raw
        policy action is executed.

        Returns:
            The ``done`` flag returned by ``self.env.step``.
        """
        self.step_counter += 1
        with self.step_timer:
            path_list = self.stg.generate_path(self.task)

            # select optimal path: one observation per candidate path
            candidate_obs = []
            for candidate in path_list:
                self.env.set_traj(candidate)
                candidate_obs.append(self.env._get_obs())
            stacked_obs = tf.stack(candidate_obs, axis=0)
            path_values = self.policy.values(stacked_obs).numpy().squeeze()
            first_column = path_values[:, 0]
            path_index = int(np.argmax(first_column))

            self.env.set_traj(path_list[path_index])
            self.obs_real = candidate_obs[path_index]

            # safety shield disabled: use the policy action directly
            safe_action = self.policy.run(self.obs_real).numpy()

            print('ALL TIME:', self.step_timer.mean)

        self.render(path_list, path_values, path_index)
        self.recorder.record(self.obs_real, safe_action, self.step_timer.mean,
                             path_index, path_values, self.ss_timer.mean)
        self.obs, _, done, _ = self.env.step(safe_action)
        return done

    def render(self, traj_list, path_values, path_index):
        square_length = CROSSROAD_SIZE
        extension = 40
        lane_width = LANE_WIDTH
        light_line_width = 3
        dotted_line_style = '--'
        solid_line_style = '-'

        plt.cla()
        ax = plt.axes([-0.05, -0.05, 1.1, 1.1])
        for ax in self.fig.get_axes():
            ax.axis('off')
        ax.axis("equal")

        # ----------arrow--------------
        plt.arrow(lane_width / 2, -square_length / 2 - 10, 0, 5, color='b')
        plt.arrow(lane_width / 2,
                  -square_length / 2 - 10 + 5,
                  -0.5,
                  0,
                  color='b',
                  head_width=1)
        plt.arrow(lane_width * 1.5,
                  -square_length / 2 - 10,
                  0,
                  4,
                  color='b',
                  head_width=1)
        plt.arrow(lane_width * 2.5, -square_length / 2 - 10, 0, 5, color='b')
        plt.arrow(lane_width * 2.5,
                  -square_length / 2 - 10 + 5,
                  0.5,
                  0,
                  color='b',
                  head_width=1)

        # ----------horizon--------------

        plt.plot([-square_length / 2 - extension, -square_length / 2],
                 [0.3, 0.3],
                 color='orange')
        plt.plot([-square_length / 2 - extension, -square_length / 2],
                 [-0.3, -0.3],
                 color='orange')
        plt.plot([square_length / 2 + extension, square_length / 2],
                 [0.3, 0.3],
                 color='orange')
        plt.plot([square_length / 2 + extension, square_length / 2],
                 [-0.3, -0.3],
                 color='orange')

        #
        for i in range(1, LANE_NUMBER + 1):
            linestyle = dotted_line_style if i < LANE_NUMBER else solid_line_style
            linewidth = 1 if i < LANE_NUMBER else 2
            plt.plot([-square_length / 2 - extension, -square_length / 2],
                     [i * lane_width, i * lane_width],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)
            plt.plot([square_length / 2 + extension, square_length / 2],
                     [i * lane_width, i * lane_width],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)
            plt.plot([-square_length / 2 - extension, -square_length / 2],
                     [-i * lane_width, -i * lane_width],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)
            plt.plot([square_length / 2 + extension, square_length / 2],
                     [-i * lane_width, -i * lane_width],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)

        # ----------vertical----------------
        plt.plot([0.3, 0.3],
                 [-square_length / 2 - extension, -square_length / 2],
                 color='orange')
        plt.plot([-0.3, -0.3],
                 [-square_length / 2 - extension, -square_length / 2],
                 color='orange')
        plt.plot([0.3, 0.3],
                 [square_length / 2 + extension, square_length / 2],
                 color='orange')
        plt.plot([-0.3, -0.3],
                 [square_length / 2 + extension, square_length / 2],
                 color='orange')

        #
        for i in range(1, LANE_NUMBER + 1):
            linestyle = dotted_line_style if i < LANE_NUMBER else solid_line_style
            linewidth = 1 if i < LANE_NUMBER else 2
            plt.plot([i * lane_width, i * lane_width],
                     [-square_length / 2 - extension, -square_length / 2],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)
            plt.plot([i * lane_width, i * lane_width],
                     [square_length / 2 + extension, square_length / 2],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)
            plt.plot([-i * lane_width, -i * lane_width],
                     [-square_length / 2 - extension, -square_length / 2],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)
            plt.plot([-i * lane_width, -i * lane_width],
                     [square_length / 2 + extension, square_length / 2],
                     linestyle=linestyle,
                     color='black',
                     linewidth=linewidth)

        v_light = self.env.v_light
        if v_light == 0:
            v_color, h_color = 'green', 'red'
        elif v_light == 1:
            v_color, h_color = 'orange', 'red'
        elif v_light == 2:
            v_color, h_color = 'red', 'green'
        else:
            v_color, h_color = 'red', 'orange'

        plt.plot([0, (LANE_NUMBER - 1) * lane_width],
                 [-square_length / 2, -square_length / 2],
                 color=v_color,
                 linewidth=light_line_width)
        plt.plot([(LANE_NUMBER - 1) * lane_width, LANE_NUMBER * lane_width],
                 [-square_length / 2, -square_length / 2],
                 color='green',
                 linewidth=light_line_width)

        plt.plot([-LANE_NUMBER * lane_width, -(LANE_NUMBER - 1) * lane_width],
                 [square_length / 2, square_length / 2],
                 color='green',
                 linewidth=light_line_width)
        plt.plot([-(LANE_NUMBER - 1) * lane_width, 0],
                 [square_length / 2, square_length / 2],
                 color=v_color,
                 linewidth=light_line_width)

        plt.plot([-square_length / 2, -square_length / 2],
                 [0, -(LANE_NUMBER - 1) * lane_width],
                 color=h_color,
                 linewidth=light_line_width)
        plt.plot([-square_length / 2, -square_length / 2],
                 [-(LANE_NUMBER - 1) * lane_width, -LANE_NUMBER * lane_width],
                 color='green',
                 linewidth=light_line_width)

        plt.plot([square_length / 2, square_length / 2],
                 [(LANE_NUMBER - 1) * lane_width, 0],
                 color=h_color,
                 linewidth=light_line_width)
        plt.plot([square_length / 2, square_length / 2],
                 [LANE_NUMBER * lane_width, (LANE_NUMBER - 1) * lane_width],
                 color='green',
                 linewidth=light_line_width)

        # ----------Oblique--------------
        plt.plot([LANE_NUMBER * lane_width, square_length / 2],
                 [-square_length / 2, -LANE_NUMBER * lane_width],
                 color='black',
                 linewidth=2)
        plt.plot([LANE_NUMBER * lane_width, square_length / 2],
                 [square_length / 2, LANE_NUMBER * lane_width],
                 color='black',
                 linewidth=2)
        plt.plot([-LANE_NUMBER * lane_width, -square_length / 2],
                 [-square_length / 2, -LANE_NUMBER * lane_width],
                 color='black',
                 linewidth=2)
        plt.plot([-LANE_NUMBER * lane_width, -square_length / 2],
                 [square_length / 2, LANE_NUMBER * lane_width],
                 color='black',
                 linewidth=2)

        def is_in_plot_area(x, y, tolerance=5):
            if -square_length / 2 - extension + tolerance < x < square_length / 2 + extension - tolerance and \
                    -square_length / 2 - extension + tolerance < y < square_length / 2 + extension - tolerance:
                return True
            else:
                return False

        def draw_rotate_rec(x, y, a, l, w, c):
            bottom_left_x, bottom_left_y, _ = rotate_coordination(
                -l / 2, w / 2, 0, -a)
            ax.add_patch(
                plt.Rectangle((x + bottom_left_x, y + bottom_left_y),
                              w,
                              l,
                              edgecolor=c,
                              facecolor='white',
                              angle=-(90 - a),
                              zorder=50))

        def plot_phi_line(x, y, phi, color):
            line_length = 3
            x_forw, y_forw = x + line_length * cos(phi * pi / 180.), \
                             y + line_length * sin(phi * pi / 180.)
            plt.plot([x, x_forw], [y, y_forw], color=color, linewidth=0.5)

        # plot cars
        for veh in self.env.all_vehicles:
            veh_x = veh['x']
            veh_y = veh['y']
            veh_phi = veh['phi']
            veh_l = veh['l']
            veh_w = veh['w']
            if is_in_plot_area(veh_x, veh_y):
                plot_phi_line(veh_x, veh_y, veh_phi, 'black')
                draw_rotate_rec(veh_x, veh_y, veh_phi, veh_l, veh_w, 'black')

        # plot_interested vehs
        # for mode, num in self.veh_mode_dict.items():
        #     for i in range(num):
        #         veh = self.interested_vehs[mode][i]
        #         veh_x = veh['x']
        #         veh_y = veh['y']
        #         veh_phi = veh['phi']
        #         veh_l = veh['l']
        #         veh_w = veh['w']
        #         task2color = {'left': 'b', 'straight': 'c', 'right': 'm'}
        #
        #         if is_in_plot_area(veh_x, veh_y):
        #             plot_phi_line(veh_x, veh_y, veh_phi, 'black')
        #             task = MODE2TASK[mode]
        #             color = task2color[task]
        #             draw_rotate_rec(veh_x, veh_y, veh_phi, veh_l, veh_w, color, linestyle=':')

        ego_v_x = self.env.ego_dynamics['v_x']
        ego_v_y = self.env.ego_dynamics['v_y']
        ego_r = self.env.ego_dynamics['r']
        ego_x = self.env.ego_dynamics['x']
        ego_y = self.env.ego_dynamics['y']
        ego_phi = self.env.ego_dynamics['phi']
        ego_l = self.env.ego_dynamics['l']
        ego_w = self.env.ego_dynamics['w']
        ego_alpha_f = self.env.ego_dynamics['alpha_f']
        ego_alpha_r = self.env.ego_dynamics['alpha_r']
        alpha_f_bound = self.env.ego_dynamics['alpha_f_bound']
        alpha_r_bound = self.env.ego_dynamics['alpha_r_bound']
        r_bound = self.env.ego_dynamics['r_bound']

        plot_phi_line(ego_x, ego_y, ego_phi, 'fuchsia')
        draw_rotate_rec(ego_x, ego_y, ego_phi, ego_l, ego_w, 'fuchsia')
        self.hist_posi.append((ego_x, ego_y))
        # plot history data
        for hist_x, hist_y in self.hist_posi:
            plt.scatter(hist_x, hist_y, color='fuchsia', alpha=0.1)
        # plot future data
        tracking_info = self.obs[self.env.ego_info_dim:self.env.ego_info_dim +
                                 self.env.per_tracking_info_dim *
                                 (self.env.num_future_data + 1)]
        future_path = tracking_info[self.env.per_tracking_info_dim:]
        for i in range(self.env.num_future_data):
            delta_x, delta_y, delta_phi = future_path[
                i * self.env.per_tracking_info_dim:(i + 1) *
                self.env.per_tracking_info_dim]
            path_x, path_y, path_phi = ego_x + delta_x, ego_y + delta_y, ego_phi - delta_phi
            plt.plot(path_x, path_y, 'g.')
            plot_phi_line(path_x, path_y, path_phi, 'g')

        delta_, _, _ = tracking_info[:3]
        indexs, points = self.env.ref_path.find_closest_point(
            np.array([ego_x], np.float32), np.array([ego_y], np.float32))
        path_x, path_y, path_phi = points[0][0], points[1][0], points[2][0]
        # plt.plot(path_x, path_y, 'g.')
        delta_x, delta_y, delta_phi = ego_x - path_x, ego_y - path_y, ego_phi - path_phi

        # plot real time traj
        try:
            color = ['blue', 'coral', 'darkcyan']
            for i, item in enumerate(traj_list):
                if i == path_index:
                    plt.plot(item.path[0], item.path[1], color=color[i])
                else:
                    plt.plot(item.path[0],
                             item.path[1],
                             color=color[i],
                             alpha=0.3)
                indexs, points = item.find_closest_point(
                    np.array([ego_x], np.float32), np.array([ego_y],
                                                            np.float32))
                path_x, path_y, path_phi = points[0][0], points[1][0], points[
                    2][0]
                plt.plot(path_x, path_y, color=color[i])
        except Exception:
            pass

        # text
        # text_x, text_y_start = -120, 60
        # ge = iter(range(0, 1000, 4))
        # plt.text(text_x, text_y_start - next(ge), 'ego_x: {:.2f}m'.format(ego_x))
        # plt.text(text_x, text_y_start - next(ge), 'ego_y: {:.2f}m'.format(ego_y))
        # plt.text(text_x, text_y_start - next(ge), 'path_x: {:.2f}m'.format(path_x))
        # plt.text(text_x, text_y_start - next(ge), 'path_y: {:.2f}m'.format(path_y))
        # plt.text(text_x, text_y_start - next(ge), 'delta_: {:.2f}m'.format(delta_))
        # plt.text(text_x, text_y_start - next(ge), 'delta_x: {:.2f}m'.format(delta_x))
        # plt.text(text_x, text_y_start - next(ge), 'delta_y: {:.2f}m'.format(delta_y))
        # plt.text(text_x, text_y_start - next(ge), r'ego_phi: ${:.2f}\degree$'.format(ego_phi))
        # plt.text(text_x, text_y_start - next(ge), r'path_phi: ${:.2f}\degree$'.format(path_phi))
        # plt.text(text_x, text_y_start - next(ge), r'delta_phi: ${:.2f}\degree$'.format(delta_phi))
        # plt.text(text_x, text_y_start - next(ge), 'v_x: {:.2f}m/s'.format(ego_v_x))
        # plt.text(text_x, text_y_start - next(ge), 'exp_v: {:.2f}m/s'.format(self.env.exp_v))
        # plt.text(text_x, text_y_start - next(ge), 'v_y: {:.2f}m/s'.format(ego_v_y))
        # plt.text(text_x, text_y_start - next(ge), 'yaw_rate: {:.2f}rad/s'.format(ego_r))
        # plt.text(text_x, text_y_start - next(ge), 'yaw_rate bound: [{:.2f}, {:.2f}]'.format(-r_bound, r_bound))
        #
        # plt.text(text_x, text_y_start - next(ge), r'$\alpha_f$: {:.2f} rad'.format(ego_alpha_f))
        # plt.text(text_x, text_y_start - next(ge), r'$\alpha_f$ bound: [{:.2f}, {:.2f}] '.format(-alpha_f_bound,
        #                                                                                         alpha_f_bound))
        # plt.text(text_x, text_y_start - next(ge), r'$\alpha_r$: {:.2f} rad'.format(ego_alpha_r))
        # plt.text(text_x, text_y_start - next(ge), r'$\alpha_r$ bound: [{:.2f}, {:.2f}] '.format(-alpha_r_bound,
        #                                                                                         alpha_r_bound))
        # if self.env.action is not None:
        #     steer, a_x = self.env.action[0], self.env.action[1]
        #     plt.text(text_x, text_y_start - next(ge),
        #              r'steer: {:.2f}rad (${:.2f}\degree$)'.format(steer, steer * 180 / np.pi))
        #     plt.text(text_x, text_y_start - next(ge), 'a_x: {:.2f}m/s^2'.format(a_x))
        #
        # text_x, text_y_start = 70, 60
        # ge = iter(range(0, 1000, 4))
        #
        # # done info
        # plt.text(text_x, text_y_start - next(ge), 'done info: {}'.format(self.env.done_type))
        #
        # # reward info
        # if self.env.reward_info is not None:
        #     for key, val in self.env.reward_info.items():
        #         plt.text(text_x, text_y_start - next(ge), '{}: {:.4f}'.format(key, val))
        #
        # # indicator for trajectory selection
        # text_x, text_y_start = -18, -70
        # ge = iter(range(0, 1000, 6))
        # if path_values is not None:
        #     for i, value in enumerate(path_values):
        #         if i == path_index:
        #             plt.text(text_x, text_y_start - next(ge), 'Path reward={:.4f}'.format(value[0]), fontsize=14,
        #                      color=color[i], fontstyle='italic')
        #         else:
        #             plt.text(text_x, text_y_start - next(ge), 'Path reward={:.4f}'.format(value[0]), fontsize=12,
        #                      color=color[i], fontstyle='italic')
        plt.show()
        plt.pause(0.001)
        if self.logdir is not None:
            plt.savefig(self.logdir +
                        '/episode{}'.format(self.episode_counter) +
                        '/step{}.pdf'.format(self.step_counter))
コード例 #4
0
ファイル: mpc_ipopt.py プロジェクト: idthanm/env_build
class HierarchicalMpc(object):
    def __init__(self, task):
        self.task = task
        if self.task == 'left':
            self.policy = LoadPolicy('G:\\env_build\\utils\\models\\left', 100000)
        elif self.task == 'right':
            self.policy = LoadPolicy('G:\\env_build\\utils\\models\\right', 145000)
        elif self.task == 'straight':
            self.policy = LoadPolicy('G:\\env_build\\utils\\models\\straight', 95000)

        self.horizon = 25
        self.num_future_data = 0
        self.env = CrossroadEnd2end(training_task=self.task, num_future_data=self.num_future_data)
        self.model = EnvironmentModel(self.task)
        self.obs = self.env.reset()
        self.stg = StaticTrajectoryGenerator_origin(mode='static_traj')
        self.data2plot = []
        self.mpc_cal_timer = TimerStat()
        self.adp_cal_timer = TimerStat()
        self.recorder = Recorder()

    def reset(self):
        """Begin a new episode and return its first observation.

        In order: resets the environment, builds a new static trajectory
        generator, resets the recorder and calls its ``save('.')``, and
        empties the buffer of data collected for plotting.

        Returns:
            The observation returned by ``self.env.reset()``.
        """
        first_obs = self.env.reset()
        self.obs = first_obs
        self.stg = StaticTrajectoryGenerator_origin(mode='static_traj')

        recorder = self.recorder
        recorder.reset()
        recorder.save('.')

        self.data2plot = []
        return first_obs

    def convert_vehs_to_abso(self, obs_rela):
        ego_infos, tracking_infos, veh_rela = obs_rela[:6], \
                                              obs_rela[6:6 + 3 * (1 + self.num_future_data)],\
                                              obs_rela[6 + 3 * (1 + self.num_future_data):]
        ego_vx, ego_vy, ego_r, ego_x, ego_y, ego_phi = ego_infos
        ego = np.array([ego_x, ego_y, 0, 0] * int(len(veh_rela) / 4), dtype=np.float32)
        vehs_abso = veh_rela + ego
        out = np.concatenate((ego_infos, tracking_infos, vehs_abso), axis=0)
        return out

    def step(self):
        """Advance one control step while comparing the MPC and ADP decisions.

        For every candidate trajectory produced by ``self.stg`` an MPC problem
        is solved and the learned policy's value is evaluated. Each method
        selects the trajectory with the highest return value. The ADP action
        is the one applied to the environment; the MPC result is only
        recorded and plotted for comparison.

        Returns:
            The ``done`` flag returned by ``self.env.step``.
        """
        traj_list, _ = self.stg.generate_traj(self.task, self.obs)
        ADP_traj_return_value, MPC_traj_return_value = [], []
        action_total = []
        state_total = []

        with self.mpc_cal_timer:
            for ref_index, _ in enumerate(traj_list):
                mpc = ModelPredictiveControl(self.horizon, self.task, self.num_future_data, ref_index)
                # Stacked column vector handed to the solver: the ego+tracking
                # values followed by two zeros, repeated `horizon` times, then
                # one more ego+tracking copy.
                # NOTE(review): presumably the solver's initial guess with the
                # two zeros standing for the controls -- confirm against
                # ModelPredictiveControl.mpc_solver.
                ego_and_tracking = list(self.obs[:6 + 3 * (1 + self.num_future_data)])
                state_all = np.array(
                    (ego_and_tracking + [0, 0]) * self.horizon + ego_and_tracking).reshape((-1, 1))
                state, control, _, g_all, cost = mpc.mpc_solver(
                    list(self.convert_vehs_to_abso(self.obs)), state_all)
                state_total.append(state)
                if np.any(g_all < -1):
                    # Some constraint value is below -1: treat the solve as
                    # failed and use a fixed fallback action instead.
                    print('optimization fail')
                    mpc_action = np.array([0., -1.])
                else:
                    mpc_action = control[0]

                # Costs are negated so that "larger is better" for both methods.
                MPC_traj_return_value.append(-cost.squeeze().tolist())
                action_total.append(mpc_action)

            MPC_traj_return_value = np.array(MPC_traj_return_value, dtype=np.float32)
            MPC_path_index = np.argmax(MPC_traj_return_value)
            MPC_action = action_total[MPC_path_index]

        with self.adp_cal_timer:
            for trajectory in traj_list:
                self.env.set_traj(trajectory)
                obs = self.env._get_obs()[np.newaxis, :]
                traj_value = self.policy.values(obs)
                ADP_traj_return_value.append(traj_value.numpy().squeeze().tolist())

            # Only the first column of the policy's value output is compared.
            ADP_traj_return_value = np.array(ADP_traj_return_value, dtype=np.float32)[:, 0]
            ADP_path_index = np.argmax(ADP_traj_return_value)
            if np.amax(ADP_traj_return_value) == np.amin(ADP_traj_return_value):
                # All candidates are valued equally: defer to the MPC choice.
                ADP_path_index = MPC_path_index
            self.env.set_traj(traj_list[ADP_path_index])
            self.obs_real = self.env._get_obs()
            ADP_action = self.policy.run(self.obs_real).numpy()

        adp_time_ms = self.adp_cal_timer.mean * 1000
        mpc_time_ms = self.mpc_cal_timer.mean * 1000
        self.recorder.record_compare(self.obs, ADP_action, MPC_action, adp_time_ms, mpc_time_ms,
                                     ADP_path_index, MPC_path_index, 'both')

        self.data2plot.append(dict(obs=self.obs,
                                   ADP_action=ADP_action,
                                   MPC_action=MPC_action,
                                   ADP_path_index=ADP_path_index,
                                   MPC_path_index=MPC_path_index,
                                   mpc_time=mpc_time_ms,
                                   ))

        self.obs, _, done, _ = self.env.step(ADP_action)
        self.render(traj_list, ADP_traj_return_value, ADP_path_index, MPC_traj_return_value, MPC_path_index, method='ADP')

        # Overlay columns 3 and 4 of the chosen MPC state sequence as red
        # stars (presumably the predicted x/y positions -- confirm).
        state = state_total[MPC_path_index]
        preview_steps = range(1, self.horizon - 1)
        plt.plot([state[i][3] for i in preview_steps], [state[i][4] for i in preview_steps], 'r*')
        plt.pause(0.001)

        return done

    def render(self, traj_list, ADP_traj_return_value, ADP_path_index, MPC_traj_return_value, MPC_path_index, method='ADP'):
        """Redraw the crossroad scene on the current matplotlib figure.

        Draws, in this order: the road layout and stop lines coloured by the
        current traffic light phase, all surrounding vehicles inside the plot
        area, the ego vehicle, the future tracking points taken from
        ``self.obs``, the candidate trajectories, and several text panels
        (ego dynamics, done/reward info, per-path values for ADP and MPC).

        Args:
            traj_list: Candidate trajectories; each item must expose ``path``
                (indexed with 0 and 1 for plotting) and
                ``find_closest_point``.
            ADP_traj_return_value: Per-trajectory values shown in the ADP
                text panel, or None to skip the panel entries.
            ADP_path_index: Index of the trajectory chosen by ADP.
            MPC_traj_return_value: Per-trajectory values shown in the MPC
                text panel, or None to skip the panel entries.
            MPC_path_index: Index of the trajectory chosen by MPC.
            method: 'ADP' or 'MPC'; decides which chosen index is drawn
                opaque. Any other value draws no trajectories at all.
        """
        square_length = CROSSROAD_SIZE
        extension = 40
        lane_width = LANE_WIDTH
        light_line_width = 3
        dotted_line_style = '--'
        solid_line_style = '-'

        # Start from a clean axes covering the intersection plus `extension`
        # on every side.
        plt.cla()
        plt.title("Crossroad")
        ax = plt.axes(xlim=(-square_length / 2 - extension, square_length / 2 + extension),
                      ylim=(-square_length / 2 - extension, square_length / 2 + extension))
        plt.axis("equal")
        plt.axis('off')

        # ax.add_patch(plt.Rectangle((-square_length / 2, -square_length / 2),
        #                            square_length, square_length, edgecolor='black', facecolor='none'))
        # Outer frame around the whole drawn area.
        ax.add_patch(plt.Rectangle((-square_length / 2 - extension, -square_length / 2 - extension),
                                   square_length + 2 * extension, square_length + 2 * extension, edgecolor='black',
                                   facecolor='none'))

        # ----------arrow--------------
        # Direction arrows drawn on the three lanes below the intersection.
        plt.arrow(lane_width / 2, -square_length / 2 - 10, 0, 5, color='b')
        plt.arrow(lane_width / 2, -square_length / 2 - 10 + 5, -0.5, 0, color='b', head_width=1)
        plt.arrow(lane_width * 1.5, -square_length / 2 - 10, 0, 5, color='b', head_width=1)
        plt.arrow(lane_width * 2.5, -square_length / 2 - 10, 0, 5, color='b')
        plt.arrow(lane_width * 2.5, -square_length / 2 - 10 + 5, 0.5, 0, color='b', head_width=1)

        # ----------horizon--------------
        # Centre line of the left and right road arms.
        plt.plot([-square_length / 2 - extension, -square_length / 2], [0, 0], color='black')
        plt.plot([square_length / 2 + extension, square_length / 2], [0, 0], color='black')

        # Lane lines of the left and right arms: dashed between lanes, solid
        # for the outermost line.
        for i in range(1, LANE_NUMBER + 1):
            linestyle = dotted_line_style if i < LANE_NUMBER else solid_line_style
            plt.plot([-square_length / 2 - extension, -square_length / 2], [i * lane_width, i * lane_width],
                     linestyle=linestyle, color='black')
            plt.plot([square_length / 2 + extension, square_length / 2], [i * lane_width, i * lane_width],
                     linestyle=linestyle, color='black')
            plt.plot([-square_length / 2 - extension, -square_length / 2], [-i * lane_width, -i * lane_width],
                     linestyle=linestyle, color='black')
            plt.plot([square_length / 2 + extension, square_length / 2], [-i * lane_width, -i * lane_width],
                     linestyle=linestyle, color='black')

        # ----------vertical----------------
        # Centre line of the bottom and top road arms.
        plt.plot([0, 0], [-square_length / 2 - extension, -square_length / 2], color='black')
        plt.plot([0, 0], [square_length / 2 + extension, square_length / 2], color='black')

        # Lane lines of the bottom and top arms, same dashed/solid rule.
        for i in range(1, LANE_NUMBER + 1):
            linestyle = dotted_line_style if i < LANE_NUMBER else solid_line_style
            plt.plot([i * lane_width, i * lane_width], [-square_length / 2 - extension, -square_length / 2],
                     linestyle=linestyle, color='black')
            plt.plot([i * lane_width, i * lane_width], [square_length / 2 + extension, square_length / 2],
                     linestyle=linestyle, color='black')
            plt.plot([-i * lane_width, -i * lane_width], [-square_length / 2 - extension, -square_length / 2],
                     linestyle=linestyle, color='black')
            plt.plot([-i * lane_width, -i * lane_width], [square_length / 2 + extension, square_length / 2],
                     linestyle=linestyle, color='black')

        # Map the light phase to stop-line colours: v_color is used on the
        # bottom/top stop lines, h_color on the left/right ones.
        v_light = self.env.v_light
        if v_light == 0:
            v_color, h_color = 'green', 'red'
        elif v_light == 1:
            v_color, h_color = 'orange', 'red'
        elif v_light == 2:
            v_color, h_color = 'red', 'green'
        else:
            v_color, h_color = 'red', 'orange'

        # Stop lines. On every arm the outermost lane segment is always drawn
        # green (presumably the right-turn lane -- confirm).
        plt.plot([0, (LANE_NUMBER - 1) * lane_width], [-square_length / 2, -square_length / 2],
                 color=v_color, linewidth=light_line_width)
        plt.plot([(LANE_NUMBER - 1) * lane_width, LANE_NUMBER * lane_width], [-square_length / 2, -square_length / 2],
                 color='green', linewidth=light_line_width)

        plt.plot([-LANE_NUMBER * lane_width, -(LANE_NUMBER - 1) * lane_width], [square_length / 2, square_length / 2],
                 color='green', linewidth=light_line_width)
        plt.plot([-(LANE_NUMBER - 1) * lane_width, 0], [square_length / 2, square_length / 2],
                 color=v_color, linewidth=light_line_width)

        plt.plot([-square_length / 2, -square_length / 2], [0, -(LANE_NUMBER - 1) * lane_width],
                 color=h_color, linewidth=light_line_width)
        plt.plot([-square_length / 2, -square_length / 2], [-(LANE_NUMBER - 1) * lane_width, -LANE_NUMBER * lane_width],
                 color='green', linewidth=light_line_width)

        plt.plot([square_length / 2, square_length / 2], [(LANE_NUMBER - 1) * lane_width, 0],
                 color=h_color, linewidth=light_line_width)
        plt.plot([square_length / 2, square_length / 2], [LANE_NUMBER * lane_width, (LANE_NUMBER - 1) * lane_width],
                 color='green', linewidth=light_line_width)

        # ----------Oblique--------------
        # Diagonal edges joining the outer lane lines of adjacent arms.
        plt.plot([LANE_NUMBER * lane_width, square_length / 2], [-square_length / 2, -LANE_NUMBER * lane_width],
                 color='black')
        plt.plot([LANE_NUMBER * lane_width, square_length / 2], [square_length / 2, LANE_NUMBER * lane_width],
                 color='black')
        plt.plot([-LANE_NUMBER * lane_width, -square_length / 2], [-square_length / 2, -LANE_NUMBER * lane_width],
                 color='black')
        plt.plot([-LANE_NUMBER * lane_width, -square_length / 2], [square_length / 2, LANE_NUMBER * lane_width],
                 color='black')

        def is_in_plot_area(x, y, tolerance=5):
            # True when (x, y) lies strictly inside the drawn area shrunk by
            # `tolerance` on every side.
            if -square_length / 2 - extension + tolerance < x < square_length / 2 + extension - tolerance and \
                    -square_length / 2 - extension + tolerance < y < square_length / 2 + extension - tolerance:
                return True
            else:
                return False

        def draw_rotate_rec(x, y, a, l, w, color, linestyle='-'):
            # Draw the outline of an l-by-w rectangle centred at (x, y): the
            # four corners are passed through rotate_coordination with -a and
            # then joined edge by edge.
            RU_x, RU_y, _ = rotate_coordination(l / 2, w / 2, 0, -a)
            RD_x, RD_y, _ = rotate_coordination(l / 2, -w / 2, 0, -a)
            LU_x, LU_y, _ = rotate_coordination(-l / 2, w / 2, 0, -a)
            LD_x, LD_y, _ = rotate_coordination(-l / 2, -w / 2, 0, -a)
            ax.plot([RU_x + x, RD_x + x], [RU_y + y, RD_y + y], color=color, linestyle=linestyle)
            ax.plot([RU_x + x, LU_x + x], [RU_y + y, LU_y + y], color=color, linestyle=linestyle)
            ax.plot([LD_x + x, RD_x + x], [LD_y + y, RD_y + y], color=color, linestyle=linestyle)
            ax.plot([LD_x + x, LU_x + x], [LD_y + y, LU_y + y], color=color, linestyle=linestyle)

        def plot_phi_line(x, y, phi, color):
            # Short heading indicator starting at (x, y); phi is converted
            # from degrees to radians before taking cos/sin.
            line_length = 5
            x_forw, y_forw = x + line_length * cos(phi * pi / 180.), \
                             y + line_length * sin(phi * pi / 180.)
            plt.plot([x, x_forw], [y, y_forw], color=color, linewidth=0.5)

        # plot cars
        for veh in self.env.all_vehicles:
            veh_x = veh['x']
            veh_y = veh['y']
            veh_phi = veh['phi']
            veh_l = veh['l']
            veh_w = veh['w']
            if is_in_plot_area(veh_x, veh_y):
                plot_phi_line(veh_x, veh_y, veh_phi, 'black')
                draw_rotate_rec(veh_x, veh_y, veh_phi, veh_l, veh_w, 'black')

        # plot_interested vehs
        # for mode, num in self.veh_mode_dict.items():
        #     for i in range(num):
        #         veh = self.interested_vehs[mode][i]
        #         veh_x = veh['x']
        #         veh_y = veh['y']
        #         veh_phi = veh['phi']
        #         veh_l = veh['l']
        #         veh_w = veh['w']
        #         task2color = {'left': 'b', 'straight': 'c', 'right': 'm'}
        #
        #         if is_in_plot_area(veh_x, veh_y):
        #             plot_phi_line(veh_x, veh_y, veh_phi, 'black')
        #             task = MODE2TASK[mode]
        #             color = task2color[task]
        #             draw_rotate_rec(veh_x, veh_y, veh_phi, veh_l, veh_w, color, linestyle=':')

        # Ego state read from the environment; used for drawing and for the
        # text panel below.
        ego_v_x = self.env.ego_dynamics['v_x']
        ego_v_y = self.env.ego_dynamics['v_y']
        ego_r = self.env.ego_dynamics['r']
        ego_x = self.env.ego_dynamics['x']
        ego_y = self.env.ego_dynamics['y']
        ego_phi = self.env.ego_dynamics['phi']
        ego_l = self.env.ego_dynamics['l']
        ego_w = self.env.ego_dynamics['w']
        ego_alpha_f = self.env.ego_dynamics['alpha_f']
        ego_alpha_r = self.env.ego_dynamics['alpha_r']
        alpha_f_bound = self.env.ego_dynamics['alpha_f_bound']
        alpha_r_bound = self.env.ego_dynamics['alpha_r_bound']
        r_bound = self.env.ego_dynamics['r_bound']

        plot_phi_line(ego_x, ego_y, ego_phi, 'fuchsia')
        draw_rotate_rec(ego_x, ego_y, ego_phi, ego_l, ego_w, 'fuchsia')

        # plot future data
        # The tracking segment follows the ego segment in self.obs; its first
        # per_tracking_info_dim values are the current tracking info and the
        # rest are the future points.
        tracking_info = self.obs[
                        self.env.ego_info_dim:self.env.ego_info_dim + self.env.per_tracking_info_dim * (self.env.num_future_data + 1)]
        future_path = tracking_info[self.env.per_tracking_info_dim:]
        for i in range(self.env.num_future_data):
            delta_x, delta_y, delta_phi = future_path[i * self.env.per_tracking_info_dim:
                                                      (i + 1) * self.env.per_tracking_info_dim]
            path_x, path_y, path_phi = ego_x + delta_x, ego_y + delta_y, ego_phi - delta_phi
            plt.plot(path_x, path_y, 'g.')
            plot_phi_line(path_x, path_y, path_phi, 'g')

        # Closest point on the environment's reference path and the ego's
        # offsets from it; these values feed the text panel below.
        delta_, _, _ = tracking_info[:3]
        indexs, points = self.env.ref_path.find_closest_point(np.array([ego_x], np.float32), np.array([ego_y], np.float32))
        path_x, path_y, path_phi = points[0][0], points[1][0], points[2][0]
        # plt.plot(path_x, path_y, 'g.')
        delta_x, delta_y, delta_phi = ego_x - path_x, ego_y - path_y, ego_phi - path_phi

        # plot real time traj
        # The trajectory chosen by `method` is drawn opaque, the others with
        # alpha=0.3. Note that path_x/path_y/path_phi are overwritten here, so
        # the text panel shows the closest point of the last trajectory drawn.
        # NOTE(review): any exception raised in this block is silently
        # ignored (e.g. more trajectories than entries in `color`).
        try:
            # `color` is also used by the per-path text panels further down.
            color = ['blue', 'coral', 'cyan']
            for i, item in enumerate(traj_list):
                if method == 'ADP':
                    if i == ADP_path_index:
                        plt.plot(item.path[0], item.path[1], color=color[i])
                    else:
                        plt.plot(item.path[0], item.path[1], color=color[i], alpha=0.3)
                    indexs, points = item.find_closest_point(np.array([ego_x], np.float32), np.array([ego_y], np.float32))
                    path_x, path_y, path_phi = points[0][0], points[1][0], points[2][0]
                    plt.plot(path_x, path_y, color=color[i])
                elif method == 'MPC':
                    if i == MPC_path_index:
                        plt.plot(item.path[0], item.path[1], color=color[i])
                    else:
                        plt.plot(item.path[0], item.path[1], color=color[i], alpha=0.3)
                    indexs, points = item.find_closest_point(np.array([ego_x], np.float32), np.array([ego_y], np.float32))
                    path_x, path_y, path_phi = points[0][0], points[1][0], points[2][0]
                    plt.plot(path_x, path_y, color=color[i])
        except Exception:
            pass

        # plot ego dynamics
        # `ge` yields increasing vertical offsets so each text line is placed
        # 4 units below the previous one.
        text_x, text_y_start = -120, 60
        ge = iter(range(0, 1000, 4))
        plt.text(text_x, text_y_start - next(ge), 'ego_x: {:.2f}m'.format(ego_x))
        plt.text(text_x, text_y_start - next(ge), 'ego_y: {:.2f}m'.format(ego_y))
        plt.text(text_x, text_y_start - next(ge), 'path_x: {:.2f}m'.format(path_x))
        plt.text(text_x, text_y_start - next(ge), 'path_y: {:.2f}m'.format(path_y))
        plt.text(text_x, text_y_start - next(ge), 'delta_: {:.2f}m'.format(delta_))
        plt.text(text_x, text_y_start - next(ge), 'delta_x: {:.2f}m'.format(delta_x))
        plt.text(text_x, text_y_start - next(ge), 'delta_y: {:.2f}m'.format(delta_y))
        plt.text(text_x, text_y_start - next(ge), r'ego_phi: ${:.2f}\degree$'.format(ego_phi))
        plt.text(text_x, text_y_start - next(ge), r'path_phi: ${:.2f}\degree$'.format(path_phi))
        plt.text(text_x, text_y_start - next(ge), r'delta_phi: ${:.2f}\degree$'.format(delta_phi))
        plt.text(text_x, text_y_start - next(ge), 'v_x: {:.2f}m/s'.format(ego_v_x))
        plt.text(text_x, text_y_start - next(ge), 'exp_v: {:.2f}m/s'.format(self.env.exp_v))
        plt.text(text_x, text_y_start - next(ge), 'v_y: {:.2f}m/s'.format(ego_v_y))
        plt.text(text_x, text_y_start - next(ge), 'yaw_rate: {:.2f}rad/s'.format(ego_r))
        plt.text(text_x, text_y_start - next(ge), 'yaw_rate bound: [{:.2f}, {:.2f}]'.format(-r_bound, r_bound))

        plt.text(text_x, text_y_start - next(ge), r'$\alpha_f$: {:.2f} rad'.format(ego_alpha_f))
        plt.text(text_x, text_y_start - next(ge), r'$\alpha_f$ bound: [{:.2f}, {:.2f}] '.format(-alpha_f_bound,
                                                                                                alpha_f_bound))
        plt.text(text_x, text_y_start - next(ge), r'$\alpha_r$: {:.2f} rad'.format(ego_alpha_r))
        plt.text(text_x, text_y_start - next(ge), r'$\alpha_r$ bound: [{:.2f}, {:.2f}] '.format(-alpha_r_bound,
                                                                                                alpha_r_bound))
        # The last applied action is shown only once the environment has one.
        if self.env.action is not None:
            steer, a_x = self.env.action[0], self.env.action[1]
            plt.text(text_x, text_y_start - next(ge),
                     r'steer: {:.2f}rad (${:.2f}\degree$)'.format(steer, steer * 180 / np.pi))
            plt.text(text_x, text_y_start - next(ge), 'a_x: {:.2f}m/s^2'.format(a_x))

        # Second text column on the right-hand side.
        text_x, text_y_start = 70, 60
        ge = iter(range(0, 1000, 4))

        # done info
        plt.text(text_x, text_y_start - next(ge), 'done info: {}'.format(self.env.done_type))

        # reward info
        if self.env.reward_info is not None:
            for key, val in self.env.reward_info.items():
                plt.text(text_x, text_y_start - next(ge), '{}: {:.4f}'.format(key, val))

        # indicator for ADP trajectory selection
        # The chosen path is printed in a larger font than the others.
        # NOTE(review): the label reads 'Path cost' but the values come from
        # the *_traj_return_value arguments -- confirm the sign convention.
        text_x, text_y_start = 18, -70
        ge = iter(range(0, 1000, 6))
        plt.text(text_x+10, text_y_start - next(ge), 'ADP', fontsize=14, color='r', fontstyle='italic')
        if ADP_traj_return_value is not None:
            for i, value in enumerate(ADP_traj_return_value):
                if i == ADP_path_index:
                    plt.text(text_x, text_y_start - next(ge), 'Path cost={:.4f}'.format(value), fontsize=14,
                             color=color[i], fontstyle='italic')
                else:
                    plt.text(text_x, text_y_start - next(ge), 'Path cost={:.4f}'.format(value), fontsize=12,
                             color=color[i], fontstyle='italic')

        # indicator for MPC trajectory selection (same layout as the ADP panel)
        text_x, text_y_start = -36, -70
        ge = iter(range(0, 1000, 6))
        plt.text(text_x+10, text_y_start - next(ge), 'MPC', fontsize=14, color='r', fontstyle='italic')
        if MPC_traj_return_value is not None:
            for i, value in enumerate(MPC_traj_return_value):
                if i == MPC_path_index:
                    plt.text(text_x, text_y_start - next(ge), 'Path cost={:.4f}'.format(value), fontsize=14,
                             color=color[i], fontstyle='italic')
                else:
                    plt.text(text_x, text_y_start - next(ge), 'Path cost={:.4f}'.format(value), fontsize=12,
                             color=color[i], fontstyle='italic')
        # Brief pause so the GUI event loop can refresh the figure.
        plt.pause(0.001)