Example #1
0
 def pass_arrow(loc0, loc1):
     """Return a solid, straight arrow with an open ``->`` head.

     The arrow runs from ``loc0`` to ``loc1`` (both converted to tuples).
     The patch is only created here; it is not added to any axes.
     """
     start, end = tuple(loc0), tuple(loc1)
     head = patches.ArrowStyle('->', head_length=5, head_width=5)
     straight = patches.ConnectionStyle("Arc3", rad=0)
     return patches.FancyArrowPatch(start, end,
                                    arrowstyle=head,
                                    connectionstyle=straight,
                                    linestyle='-')
Example #2
0
 def shot_arrow(loc0, loc1):
     """Return a thick red straight arrow with a filled ``-|>`` head.

     The arrow runs from ``loc0`` to ``loc1`` (both converted to tuples) and
     is drawn 2 points wide. It is not added to any axes here.
     """
     return patches.FancyArrowPatch(
         tuple(loc0),
         tuple(loc1),
         arrowstyle=patches.ArrowStyle('-|>', head_length=2, head_width=2),
         connectionstyle=patches.ConnectionStyle("Arc3", rad=0),
         linestyle='-',
         color='red',
         linewidth=2,
     )
    def render(self, mode='human', output_file=None):
        """Visualise the stored episode with matplotlib.

        Args:
            mode: ``'human'`` draws the current robot and human positions
                once and shows the figure. ``'traj'`` draws the robot and
                humans every 4th step (plus the last step) of
                ``self.states``, joined by path segments and annotated with
                time stamps. ``'video'`` animates ``self.states``. In
                ``'video'`` mode, pressing a key toggles pause/resume. When
                pausing, if the robot policy has an ``action_values``
                attribute, a polar heat-map of the current step's action
                values is also shown.
            output_file: Only used by ``'video'``. When not ``None`` the
                animation is saved to this path with the ffmpeg writer
                instead of being displayed.

        Raises:
            NotImplementedError: if ``mode`` is not one of the above.
        """
        from matplotlib import animation
        import matplotlib.pyplot as plt
        # NOTE(review): machine-specific path written into the global
        # rcParams on every call -- confirm it is still wanted.
        plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'

        x_offset = 0.11
        y_offset = 0.11
        # ``plt.cm.get_cmap`` (i.e. ``matplotlib.cm.get_cmap``) was removed in
        # Matplotlib 3.9; ``pyplot.get_cmap`` takes the same (name, lut) args.
        cmap = plt.get_cmap('hsv', 10)
        robot_color = 'yellow'
        goal_color = 'red'
        arrow_color = 'red'
        arrow_style = patches.ArrowStyle("->", head_length=4, head_width=2)

        if mode == 'human':
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.set_xlim(-4, 4)
            ax.set_ylim(-4, 4)
            for human in self.humans:
                human_circle = plt.Circle(human.get_position(),
                                          human.radius,
                                          fill=False,
                                          color='b')
                ax.add_artist(human_circle)
            ax.add_artist(
                plt.Circle(self.robot.get_position(),
                           self.robot.radius,
                           fill=True,
                           color='r'))
            plt.show()
        elif mode == 'traj':
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.tick_params(labelsize=16)
            ax.set_xlim(-5, 5)
            ax.set_ylim(-5, 5)
            ax.set_xlabel('x(m)', fontsize=16)
            ax.set_ylabel('y(m)', fontsize=16)

            robot_positions = [
                self.states[i][0].position for i in range(len(self.states))
            ]
            human_positions = [[
                self.states[i][1][j].position for j in range(len(self.humans))
            ] for i in range(len(self.states))]
            for k in range(len(self.states)):
                # Snapshot circles every 4th step and at the final step.
                if k % 4 == 0 or k == len(self.states) - 1:
                    robot = plt.Circle(robot_positions[k],
                                       self.robot.radius,
                                       fill=True,
                                       color=robot_color)
                    humans = [
                        plt.Circle(human_positions[k][i],
                                   self.humans[i].radius,
                                   fill=False,
                                   color=cmap(i))
                        for i in range(len(self.humans))
                    ]
                    ax.add_artist(robot)
                    for human in humans:
                        ax.add_artist(human)
                # Time annotation.
                # NOTE(review): this tests elapsed *time* rather than k, so
                # when time_step is not a divisor-friendly value the labels
                # may be placed on an earlier snapshot's circles -- confirm.
                global_time = k * self.time_step
                if global_time % 4 == 0 or k == len(self.states) - 1:
                    agents = humans + [robot]
                    times = [
                        plt.text(agents[i].center[0] - x_offset,
                                 agents[i].center[1] - y_offset,
                                 '{:.1f}'.format(global_time),
                                 color='black',
                                 fontsize=14)
                        for i in range(self.human_num + 1)
                    ]
                    for time_label in times:
                        ax.add_artist(time_label)
                # Path segment from the previous step to this one.
                if k != 0:
                    nav_direction = plt.Line2D(
                        (self.states[k - 1][0].px, self.states[k][0].px),
                        (self.states[k - 1][0].py, self.states[k][0].py),
                        color=robot_color,
                        ls='solid')
                    human_directions = [
                        plt.Line2D((self.states[k - 1][1][i].px,
                                    self.states[k][1][i].px),
                                   (self.states[k - 1][1][i].py,
                                    self.states[k][1][i].py),
                                   color=cmap(i),
                                   ls='solid') for i in range(self.human_num)
                    ]
                    ax.add_artist(nav_direction)
                    for human_direction in human_directions:
                        ax.add_artist(human_direction)
            plt.legend([robot], ['Robot'], fontsize=16)
            plt.show()
        elif mode == 'video':
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.tick_params(labelsize=16)
            ax.set_xlim(-6, 6)
            ax.set_ylim(-6, 6)
            ax.set_xlabel('x(m)', fontsize=16)
            ax.set_ylabel('y(m)', fontsize=16)

            # Robot and its goal marker at (0, 4).
            robot_positions = [state[0].position for state in self.states]
            goal = mlines.Line2D([0], [4],
                                 color=goal_color,
                                 marker='*',
                                 linestyle='None',
                                 markersize=15,
                                 label='Goal')
            robot = plt.Circle(robot_positions[0],
                               self.robot.radius,
                               fill=True,
                               color=robot_color)
            ax.add_artist(robot)
            ax.add_artist(goal)
            plt.legend([robot, goal], ['Robot', 'Goal'], fontsize=16)

            # Humans and their index labels.
            human_positions = [[
                state[1][j].position for j in range(len(self.humans))
            ] for state in self.states]
            humans = [
                plt.Circle(human_positions[0][i],
                           self.humans[i].radius,
                           fill=False) for i in range(len(self.humans))
            ]
            human_numbers = [
                plt.text(humans[i].center[0] - x_offset,
                         humans[i].center[1] - y_offset,
                         str(i),
                         color='black',
                         fontsize=12) for i in range(len(self.humans))
            ]
            for i, human in enumerate(humans):
                ax.add_artist(human)
                ax.add_artist(human_numbers[i])

            # Elapsed-time annotation.
            time_text = plt.text(-1, 5, 'Time: {}'.format(0), fontsize=16)
            ax.add_artist(time_text)

            # Heading arrows of length ``radius`` per step. Unicycle: robot
            # only, from theta. Otherwise: robot then every human, from the
            # velocity direction.
            radius = self.robot.radius
            if self.robot.kinematics == 'unicycle':
                orientation = [
                    ((state[0].px, state[0].py),
                     (state[0].px + radius * np.cos(state[0].theta),
                      state[0].py + radius * np.sin(state[0].theta)))
                    for state in self.states
                ]
                orientations = [orientation]
            else:
                orientations = []
                for i in range(self.human_num + 1):
                    orientation = []
                    for state in self.states:
                        if i == 0:
                            agent_state = state[0]
                        else:
                            agent_state = state[1][i - 1]
                        theta = np.arctan2(agent_state.vy, agent_state.vx)
                        orientation.append(
                            ((agent_state.px, agent_state.py),
                             (agent_state.px + radius * np.cos(theta),
                              agent_state.py + radius * np.sin(theta))))
                    orientations.append(orientation)
            arrows = [
                patches.FancyArrowPatch(*orientation[0],
                                        color=arrow_color,
                                        arrowstyle=arrow_style)
                for orientation in orientations
            ]
            for arrow in arrows:
                ax.add_artist(arrow)
            global_step = 0

            def update(frame_num):
                """Move every artist to the positions of step ``frame_num``."""
                nonlocal global_step
                nonlocal arrows
                global_step = frame_num
                robot.center = robot_positions[frame_num]
                for i, human in enumerate(humans):
                    human.center = human_positions[frame_num][i]
                    human_numbers[i].set_position((human.center[0] - x_offset,
                                                   human.center[1] - y_offset))
                # Replace the heading arrows once per frame. This must be
                # outside the per-human loop so it also runs with no humans.
                for arrow in arrows:
                    arrow.remove()
                arrows = [
                    patches.FancyArrowPatch(*orientation[frame_num],
                                            color=arrow_color,
                                            arrowstyle=arrow_style)
                    for orientation in orientations
                ]
                for arrow in arrows:
                    ax.add_artist(arrow)

                time_text.set_text('Time: {:.2f}'.format(frame_num *
                                                         self.time_step))

            def plot_value_heatmap():
                """Print agent states and plot action values at the paused step."""
                assert self.robot.kinematics == 'holonomic'
                for agent in [self.states[global_step][0]
                              ] + self.states[global_step][1]:
                    print(('{:.4f}, ' * 6 + '{:.4f}').format(
                        agent.px, agent.py, agent.gx, agent.gy, agent.vx,
                        agent.vy, agent.theta))
                # A bare figure: creating a rectilinear axes first (as
                # plt.subplots() did) leaves it behind the polar axes on
                # recent Matplotlib, which no longer auto-removes it.
                heat_fig = plt.figure()
                speeds = [0] + self.robot.policy.speeds
                rotations = self.robot.policy.rotations + [np.pi * 2]
                r, th = np.meshgrid(speeds, rotations)
                # NOTE(review): values come from ``self.action_values`` while
                # the caller only checks the policy for ``action_values``.
                z = np.array(self.action_values[global_step %
                                                len(self.states)][1:])
                # Scale to [0, 1]; a constant vector maps to zeros instead of
                # dividing by zero.
                z_span = np.max(z) - np.min(z)
                if z_span > 0:
                    z = (z - np.min(z)) / z_span
                else:
                    z = np.zeros_like(z, dtype=float)
                # One cell per (rotation, speed) pair; the mesh has one more
                # edge than cells along each axis.
                z = np.reshape(z, (len(rotations) - 1, len(speeds) - 1))
                polar = plt.subplot(projection="polar")
                polar.tick_params(labelsize=16)
                mesh = plt.pcolormesh(th, r, z, vmin=0, vmax=1)
                plt.plot(rotations, r, color='k', ls='none')
                plt.grid()
                cbaxes = heat_fig.add_axes([0.85, 0.1, 0.03, 0.8])
                cbar = plt.colorbar(mesh, cax=cbaxes)
                cbar.ax.tick_params(labelsize=16)
                plt.show()

            def on_key_press(event):
                """Toggle pause/resume; show the value heat-map on pause."""
                # ``anim.running`` is True while the animation is playing.
                # Flip it first, then act on the *new* state.
                anim.running = not anim.running
                if anim.running:
                    anim.event_source.start()
                else:
                    anim.event_source.stop()
                    if hasattr(self.robot.policy, 'action_values'):
                        plot_value_heatmap()

            fig.canvas.mpl_connect('key_press_event', on_key_press)
            anim = animation.FuncAnimation(fig,
                                           update,
                                           frames=len(self.states),
                                           interval=self.time_step * 1000)
            anim.running = True

            if output_file is not None:
                ffmpeg_writer = animation.writers['ffmpeg']
                writer = ffmpeg_writer(fps=8,
                                       metadata=dict(artist='Me'),
                                       bitrate=1800)
                anim.save(output_file, writer=writer)
            else:
                plt.show()
        else:
            raise NotImplementedError
Example #4
0
 def carry_arrow(loc0, loc1):
     """Return a dashed straight line patch (no arrow head).

     The line runs from ``loc0`` to ``loc1`` (both converted to tuples).
     The patch is not added to any axes here.
     """
     start = tuple(loc0)
     end = tuple(loc1)
     return patches.FancyArrowPatch(
         start,
         end,
         arrowstyle=patches.ArrowStyle('-'),
         connectionstyle=patches.ConnectionStyle("Arc3", rad=0),
         linestyle='--',
     )
Example #5
0
def plot_base_classification(plt_mtype, plt_mcomb, pred_df, pheno_dict, cdata,
                             args):
    """Draw and save a three-row schematic of a mutation classifier's results.

    Rows (height ratios 1:3:3):

    1. The cohort. A bar is split into wild-type and mutant shares for
       ``plt_mtype``. Samples carrying other ``args.gene`` mutations
       (``plt_mcomb.not_mtype``) are overlaid on it.
    2. The classification task. A diagram of the mutant and wild-type groups
       sits next to violin plots of their mean predicted scores, annotated
       with the AUC.
    3. The same scores, partitioned by overlap with those other mutations.
       They are shown as split violins with one AUC per partition.

    The figure is saved as SVG to
    ``os.path.join(plot_dir, args.gene, "<args.cohort>__base-classification.svg")``
    and then closed. Relies on module-level ``plot_dir``, ``variant_clrs``,
    ``get_fancy_label``, ``get_cohort_label`` and ``calc_auc``.

    Args:
        plt_mtype: Mutation type being plotted. It is used as the row label
            into ``pred_df``, as the key into ``pheno_dict``, and for the
            text labels.
        plt_mcomb: Object whose ``not_mtype`` attribute is passed to
            ``cdata.train_pheno`` to get the "other mutations" status.
        pred_df: Predictions indexed by mutation type with sample columns.
            Each cell is reduced with ``np.mean`` (presumably several scores
            per sample, e.g. one per CV run -- TODO confirm).
        pheno_dict: Maps ``plt_mtype`` to a boolean status per training
            sample. It is assumed to be in ``cdata.get_train_samples()``
            order -- TODO confirm.
        cdata: Cohort object providing ``get_samples()``,
            ``get_train_samples()`` and ``train_pheno()``.
        args: Parsed arguments; ``args.cohort`` and ``args.gene`` are used.
    """
    fig, (coh_ax, clf_ax,
          ovp_ax) = plt.subplots(figsize=(5, 8),
                                 nrows=3,
                                 ncols=1,
                                 gridspec_kw=dict(height_ratios=[1, 3, 3]))

    # One row per training sample:
    #   Value = mean predicted score,
    #   cStat = mutant for plt_mtype,
    #   rStat = carries the other mutations.
    plt_df = pd.DataFrame({
        'Value':
        pred_df.loc[plt_mtype, cdata.get_train_samples()].apply(np.mean),
        'cStat':
        pheno_dict[plt_mtype],
        'rStat':
        np.array(cdata.train_pheno(plt_mcomb.not_mtype))
    })

    # NOTE(review): the numerator counts training samples (plt_df) but the
    # denominator is the whole cohort (get_samples()) -- confirm intended.
    mut_prop = np.sum(plt_df.cStat) / len(cdata.get_samples())
    # Left edge (as a fraction of the bar) of the "other mutations" overlay.
    # It is the share of wild-type samples without them, scaled to the
    # wild-type part of the bar.
    ovlp_prop = np.mean(~plt_df.rStat[~plt_df.cStat]) * (1 - mut_prop)
    mtype_lbl = get_fancy_label(plt_mtype)
    mtype_tbox = get_fancy_label(plt_mtype, phrase_link='\n')

    # The three row axes are only canvases; no ticks or frames are shown.
    for ax in coh_ax, clf_ax, ovp_ax:
        ax.axis('off')

    # --- Row 1: cohort overview -------------------------------------------
    coh_ax.text(0.63,
                1,
                "{}\n({} samples)".format(get_cohort_label(args.cohort),
                                          len(cdata.get_samples())),
                size=12,
                ha='center',
                va='top')

    # Bracket-style arrow from the cohort label down to the bar.
    coh_ax.add_patch(
        ptchs.FancyArrowPatch(posA=(0.63, 0.66),
                              posB=(0.63, 0.52),
                              arrowstyle=ptchs.ArrowStyle('-[',
                                                          lengthB=7.1,
                                                          widthB=119)))

    # The cohort bar spans x = 0.3 .. 0.96. The wild-type share is on the
    # left and the mutant share on the right.
    coh_ax.add_patch(
        ptchs.Rectangle((0.3, 0.28), (1 - mut_prop) * 0.66,
                        0.22,
                        facecolor=variant_clrs['WT'],
                        alpha=0.41,
                        hatch='/',
                        linewidth=1.3,
                        edgecolor='0.51'))
    coh_ax.add_patch(
        ptchs.Rectangle((0.3 + (1 - mut_prop) * 0.66, 0.28),
                        mut_prop * 0.66,
                        0.22,
                        facecolor=variant_clrs['Point'],
                        alpha=0.41,
                        hatch='/',
                        linewidth=1.3,
                        edgecolor='0.51'))

    coh_ax.text(0.28,
                0.39,
                "mutated status for:\n{}".format(mtype_tbox),
                size=11,
                ha='right',
                va='center')

    # Samples with other mutations: a hatched overlay on the bar and a solid
    # box of the same width just below it.
    coh_ax.add_patch(
        ptchs.Rectangle((0.3 + ovlp_prop * 0.66, 0.28),
                        np.mean(plt_df.rStat) * 0.66,
                        0.22,
                        hatch='\\',
                        linewidth=1.3,
                        edgecolor='0.51',
                        facecolor='None'))

    coh_ax.add_patch(
        ptchs.Rectangle((0.3 + ovlp_prop * 0.66, 0.01),
                        np.mean(plt_df.rStat) * 0.66,
                        0.22,
                        alpha=0.83,
                        hatch='\\',
                        linewidth=1.3,
                        edgecolor='0.51',
                        facecolor=variant_clrs['Point']))

    coh_ax.text(0.29 + ovlp_prop * 0.66,
                0.23,
                "{} mutations\nother than {}".format(args.gene, mtype_tbox),
                color=variant_clrs['Point'],
                size=10,
                ha='right',
                va='top')
    # Centred under the solid box (half of its 0.66-scaled width).
    coh_ax.text(0.3 + ovlp_prop * 0.66 + np.mean(plt_df.rStat) * 0.33,
                -0.02,
                "({} samples)".format(np.sum(plt_df.rStat)),
                color=variant_clrs['Point'],
                size=10,
                ha='center',
                va='top')

    # Rows 2 and 3 are each split into a diagram (left 2/3) and a violin
    # plot (right 1/3).
    diag_ax1 = clf_ax.inset_axes(bounds=(0, 0, 0.67, 1))
    vio_ax1 = clf_ax.inset_axes(bounds=(0.67, 0, 0.33, 1))
    diag_ax2 = ovp_ax.inset_axes(bounds=(0, 0, 0.67, 1))
    vio_ax2 = ovp_ax.inset_axes(bounds=(0.67, -0.11, 0.33, 0.97))

    # Both diagrams get an arrow pointing right, towards their violin plot.
    for diag_ax in diag_ax1, diag_ax2:
        diag_ax.axis('off')
        diag_ax.set_aspect('equal')

        diag_ax.add_patch(
            ptchs.FancyArrow(0.85,
                             0.57,
                             dx=0.14,
                             dy=0,
                             width=0.03,
                             length_includes_head=True,
                             head_length=0.06,
                             linewidth=1.7,
                             facecolor='white',
                             edgecolor='black'))

    # --- Row 2: mutant vs. wild-type classification -----------------------
    diag_ax1.add_patch(
        ptchs.Circle((0.5, 0.85),
                     radius=0.14,
                     facecolor=variant_clrs['Point'],
                     alpha=0.41))
    diag_ax1.text(0.5,
                  0.85,
                  "mutant for:\n{}\n({} samples)".format(
                      mtype_tbox, np.sum(plt_df.cStat)),
                  size=8,
                  ha='center',
                  va='center')

    diag_ax1.add_patch(
        ptchs.Circle((0.5, 0.32),
                     radius=0.31,
                     facecolor=variant_clrs['WT'],
                     alpha=0.41))
    diag_ax1.text(0.5,
                  0.32,
                  "wild-type for:\n{}\n({} samples)".format(
                      mtype_tbox, np.sum(~plt_df.cStat)),
                  size=13,
                  ha='center',
                  va='center')

    # Dashed red line separating the two groups, labelled with the counts.
    diag_ax1.text(0.2,
                  0.67,
                  "predict\nmutated\nstatus",
                  color='red',
                  size=12,
                  fontstyle='italic',
                  ha='right',
                  va='center')
    diag_ax1.axhline(y=0.67,
                     xmin=0.23,
                     xmax=0.86,
                     color='red',
                     linestyle='--',
                     linewidth=2.7,
                     alpha=0.83)

    diag_ax1.text(0.82,
                  0.68,
                  "{} (+)".format(np.sum(plt_df.cStat)),
                  color='red',
                  size=9,
                  fontstyle='italic',
                  ha='right',
                  va='bottom')
    diag_ax1.text(0.82,
                  0.655,
                  "{} (\u2212)".format(np.sum(~plt_df.cStat)),
                  color='red',
                  size=9,
                  fontstyle='italic',
                  ha='right',
                  va='top')

    # Wild-type then mutant score distributions on the same axes. The order
    # matters for the get_children() indices used further down.
    sns.violinplot(data=plt_df[~plt_df.cStat],
                   y='Value',
                   ax=vio_ax1,
                   palette=[variant_clrs['WT']],
                   linewidth=0,
                   cut=0)
    sns.violinplot(data=plt_df[plt_df.cStat],
                   y='Value',
                   ax=vio_ax1,
                   palette=[variant_clrs['Point']],
                   linewidth=0,
                   cut=0)

    # y = 113/111 in axes coordinates: just above the top of the violin axes.
    vio_ax1.text(0.5,
                 113 / 111,
                 "AUC: {:.3f}".format(
                     calc_auc(plt_df.Value.values, plt_df.cStat)),
                 color='red',
                 size=14,
                 fontstyle='italic',
                 ha='center',
                 va='bottom',
                 transform=vio_ax1.transAxes)

    # --- Row 3: same scores split by overlap with other mutations ---------
    # Each group is drawn as two half-discs. The left half is "w/o overlap";
    # the right half carries an extra back-hatched layer meaning "w/ overlap".
    diag_ax2.add_patch(
        ptchs.Wedge((0.48, 0.89),
                    0.14,
                    90,
                    270,
                    facecolor=variant_clrs['Point'],
                    alpha=0.41,
                    hatch='/',
                    linewidth=0.8,
                    edgecolor='0.51',
                    clip_on=False))

    diag_ax2.add_patch(
        ptchs.Wedge((0.52, 0.89),
                    0.14,
                    270,
                    90,
                    facecolor=variant_clrs['Point'],
                    alpha=0.41,
                    hatch='/',
                    linewidth=0.8,
                    edgecolor='0.51',
                    clip_on=False))
    diag_ax2.add_patch(
        ptchs.Wedge((0.52, 0.89),
                    0.14,
                    270,
                    90,
                    facecolor='None',
                    edgecolor='0.61',
                    hatch='\\',
                    linewidth=0.8,
                    clip_on=False))

    diag_ax2.text(0.22,
                  0.69,
                  "same classifier\nresults",
                  color='red',
                  size=10,
                  fontstyle='italic',
                  ha='right',
                  va='center')
    diag_ax2.axhline(y=0.69,
                     xmin=0.23,
                     xmax=0.86,
                     color='red',
                     linestyle='--',
                     linewidth=1.3,
                     alpha=0.67)

    diag_ax2.add_patch(
        ptchs.Wedge((0.48, 0.32),
                    0.31,
                    90,
                    270,
                    facecolor=variant_clrs['WT'],
                    alpha=0.41,
                    hatch='/',
                    linewidth=0.8,
                    edgecolor='0.51'))

    diag_ax2.add_patch(
        ptchs.Wedge((0.52, 0.32),
                    0.31,
                    270,
                    90,
                    facecolor=variant_clrs['WT'],
                    alpha=0.41,
                    hatch='/',
                    linewidth=0.8,
                    edgecolor='0.51'))
    diag_ax2.add_patch(
        ptchs.Wedge((0.52, 0.32),
                    0.31,
                    270,
                    90,
                    facecolor='None',
                    edgecolor='0.61',
                    linewidth=0.8,
                    hatch='\\'))

    diag_ax2.text(0.33,
                  0.89,
                  "mutant for:\n{}\nw/o overlap\n({} samps)".format(
                      mtype_tbox, np.sum(plt_df.cStat & ~plt_df.rStat)),
                  size=9,
                  ha='right',
                  va='center')
    diag_ax2.text(0.67,
                  0.89,
                  "mutant for:\n{}\nw/ overlap\n({} samps)".format(
                      mtype_tbox, np.sum(plt_df.cStat & plt_df.rStat)),
                  size=9,
                  ha='left',
                  va='center')

    diag_ax2.text(0.47,
                  0.32,
                  "wild-type for:\n{}\nw/o overlap\n({} samps)".format(
                      mtype_tbox, np.sum(~plt_df.cStat & ~plt_df.rStat)),
                  size=10,
                  ha='right',
                  va='center')
    diag_ax2.text(0.53,
                  0.32,
                  "wild-type for:\n{}\nw/ overlap\n({} samps)".format(
                      mtype_tbox, np.sum(~plt_df.cStat & plt_df.rStat)),
                  size=10,
                  ha='left',
                  va='center')

    # Split violins: each half is one rStat value (False left, True right).
    sns.violinplot(data=plt_df[~plt_df.cStat],
                   x='cStat',
                   y='Value',
                   hue='rStat',
                   palette=[variant_clrs['WT']],
                   hue_order=[False, True],
                   split=True,
                   linewidth=0,
                   cut=0,
                   ax=vio_ax2)
    sns.violinplot(data=plt_df[plt_df.cStat],
                   x='cStat',
                   y='Value',
                   hue='rStat',
                   palette=[variant_clrs['Point']],
                   hue_order=[False, True],
                   split=True,
                   linewidth=0,
                   cut=0,
                   ax=vio_ax2)

    # NOTE(review): assumes seaborn created a legend for the hue; get_legend()
    # returns None otherwise and this line would raise.
    vio_ax2.get_legend().remove()
    diag_ax2.axvline(x=0.5,
                     ymin=-0.03,
                     ymax=1.03,
                     clip_on=False,
                     color=variant_clrs['Point'],
                     linewidth=1.1,
                     alpha=0.81,
                     linestyle=':')

    diag_ax2.text(0.5,
                  -0.05,
                  "partition scored samples according to\noverlap with "
                  "{} mutations\nthat are not {}".format(args.gene, mtype_lbl),
                  color=variant_clrs['Point'],
                  size=10,
                  fontstyle='italic',
                  ha='center',
                  va='top')

    for vio_ax in vio_ax1, vio_ax2:
        vio_ax.set_xticks([])
        vio_ax.set_xticklabels([])
        vio_ax.set_yticklabels([])
        vio_ax.xaxis.label.set_visible(False)
        vio_ax.yaxis.label.set_visible(False)

    # NOTE(review): the indices below assume the violin bodies occupy these
    # positions in get_children(), which depends on the order in which
    # seaborn/matplotlib add artists. This is fragile across library
    # versions -- confirm against the versions in use.
    vio_ax1.get_children()[0].set_alpha(0.41)
    vio_ax1.get_children()[2].set_alpha(0.41)
    for i in [0, 1, 3, 4]:
        vio_ax2.get_children()[i].set_alpha(0.41)

    # Presumably the "w/o overlap" halves: single hatch, lighter edge.
    for i in [0, 3]:
        vio_ax2.get_children()[i].set_linewidth(0.8)
        vio_ax2.get_children()[i].set_hatch('/')
        vio_ax2.get_children()[i].set_edgecolor('0.61')

    # Presumably the "w/ overlap" halves: cross hatch, darker edge.
    for i in [1, 4]:
        vio_ax2.get_children()[i].set_linewidth(1.0)
        vio_ax2.get_children()[i].set_hatch('/\\')
        vio_ax2.get_children()[i].set_edgecolor('0.47')

    # Per-partition titles and AUCs above the split violins.
    vio_ax2.text(0.15,
                 1.1,
                 "{}\nw/o overlap".format(mtype_tbox),
                 color=variant_clrs['Point'],
                 size=10,
                 fontstyle='italic',
                 ha='center',
                 va='bottom',
                 transform=vio_ax2.transAxes)

    vio_ax2.text(0.15,
                 113 / 111,
                 "AUC: {:.3f}".format(
                     calc_auc(plt_df.Value[~plt_df.rStat].values,
                              plt_df.cStat[~plt_df.rStat])),
                 color='red',
                 size=13,
                 fontstyle='italic',
                 ha='center',
                 va='bottom',
                 transform=vio_ax2.transAxes)

    vio_ax2.text(0.85,
                 1.1,
                 "{}\nw/ overlap".format(mtype_tbox),
                 color=variant_clrs['Point'],
                 size=10,
                 fontstyle='italic',
                 ha='center',
                 va='bottom',
                 transform=vio_ax2.transAxes)

    vio_ax2.text(0.85,
                 113 / 111,
                 "AUC: {:.3f}".format(
                     calc_auc(plt_df.Value[plt_df.rStat].values,
                              plt_df.cStat[plt_df.rStat])),
                 color='red',
                 size=13,
                 fontstyle='italic',
                 ha='center',
                 va='bottom',
                 transform=vio_ax2.transAxes)

    plt.tight_layout(pad=-1, h_pad=2.3)
    plt.savefig(os.path.join(plot_dir, args.gene,
                             "{}__base-classification.svg".format(
                                 args.cohort)),
                bbox_inches='tight',
                format='svg')

    plt.close()
    def render(self, mode='human', output_file=None):
        """Animate the stored robot/obstacle episode.

        Only ``mode='video'`` is implemented. Every other mode, including the
        default ``'human'``, raises ``NotImplementedError``.

        The robot, its goal marker at (0, 4) and the numbered obstacles are
        redrawn for each entry of ``self.states``. Pressing a key toggles
        pause/resume. When pausing, if the robot policy has an
        ``action_values`` attribute, a polar heat-map of the current step's
        action values is also shown.

        Args:
            mode: Rendering mode; see above.
            output_file: When not ``None`` the animation is saved to this path
                with Matplotlib's ``'pillow'`` writer instead of being shown.

        Raises:
            NotImplementedError: if ``mode != 'video'``.
        """
        from matplotlib import animation
        import matplotlib.pyplot as plt

        x_offset = 0.11
        y_offset = 0.11
        robot_color = 'yellow'
        goal_color = 'red'

        if mode != 'video':
            raise NotImplementedError

        fig, ax = plt.subplots(figsize=(7, 7))
        ax.tick_params(labelsize=16)
        ax.set_xlim(-6, 6)
        ax.set_ylim(-6, 6)
        ax.set_xlabel('x(m)', fontsize=16)
        ax.set_ylabel('y(m)', fontsize=16)

        # Robot and its goal marker.
        robot_positions = [state[0].position for state in self.states]
        goal = mlines.Line2D([0], [4], color=goal_color, marker='*', linestyle='None', markersize=15, label='Goal')
        robot = plt.Circle(robot_positions[0], self.robot.radius, fill=True, color=robot_color)
        ax.add_artist(robot)
        ax.add_artist(goal)
        plt.legend([robot, goal], ['Robot', 'Goal'], fontsize=16)

        # Obstacles (stored in state[1]) and their index labels.
        obstacle_positions = [[state[1][j].position for j in range(len(self.obstacles))] for state in self.states]
        obstacles = [plt.Circle(obstacle_positions[0][i], self.obstacles[i].radius, fill=False)
                     for i in range(len(self.obstacles))]
        obstacle_numbers = [plt.text(obstacles[i].center[0] - x_offset, obstacles[i].center[1] - y_offset, str(i),
                                     color='black', fontsize=12) for i in range(len(self.obstacles))]
        for i, obstacle in enumerate(obstacles):
            ax.add_artist(obstacle)
            ax.add_artist(obstacle_numbers[i])

        # Elapsed-time annotation.
        time_text = plt.text(-1, 5, 'Time: {}'.format(0), fontsize=16)
        ax.add_artist(time_text)

        global_step = 0

        def update(frame_num):
            """Move every artist to the positions of step ``frame_num``."""
            nonlocal global_step
            global_step = frame_num
            robot.center = robot_positions[frame_num]
            for i, obstacle in enumerate(obstacles):
                obstacle.center = obstacle_positions[frame_num][i]
                obstacle_numbers[i].set_position((obstacle.center[0] - x_offset, obstacle.center[1] - y_offset))

            time_text.set_text('Time: {:.2f}'.format(frame_num * self.time_step))

        def plot_value_heatmap():
            """Print agent states and plot action values at the paused step."""
            assert self.robot.kinematics == 'holonomic'
            for agent in [self.states[global_step][0]] + self.states[global_step][1]:
                print(('{:.4f}, ' * 6 + '{:.4f}').format(agent.px, agent.py, agent.gx, agent.gy,
                                                         agent.vx, agent.vy, agent.theta))
            # A bare figure: creating a rectilinear axes first (as
            # plt.subplots() did) leaves it behind the polar axes on recent
            # Matplotlib, which no longer auto-removes it.
            heat_fig = plt.figure()
            speeds = [0] + self.robot.policy.speeds
            rotations = self.robot.policy.rotations + [np.pi * 2]
            r, th = np.meshgrid(speeds, rotations)
            # NOTE(review): values come from ``self.action_values`` while the
            # caller only checks the policy for ``action_values``.
            z = np.array(self.action_values[global_step % len(self.states)][1:])
            # Scale to [0, 1]; a constant vector maps to zeros instead of
            # dividing by zero.
            z_span = np.max(z) - np.min(z)
            if z_span > 0:
                z = (z - np.min(z)) / z_span
            else:
                z = np.zeros_like(z, dtype=float)
            # One cell per (rotation, speed) pair; the mesh has one more edge
            # than cells along each axis.
            z = np.reshape(z, (len(rotations) - 1, len(speeds) - 1))
            polar = plt.subplot(projection="polar")
            polar.tick_params(labelsize=16)
            mesh = plt.pcolormesh(th, r, z, vmin=0, vmax=1)
            plt.plot(rotations, r, color='k', ls='none')
            plt.grid()
            cbaxes = heat_fig.add_axes([0.85, 0.1, 0.03, 0.8])
            cbar = plt.colorbar(mesh, cax=cbaxes)
            cbar.ax.tick_params(labelsize=16)
            plt.show()

        def on_key_press(event):
            """Toggle pause/resume; show the value heat-map on pause."""
            # ``anim.running`` is True while the animation is playing. Flip it
            # first, then act on the *new* state.
            anim.running = not anim.running
            if anim.running:
                anim.event_source.start()
            else:
                anim.event_source.stop()
                if hasattr(self.robot.policy, 'action_values'):
                    plot_value_heatmap()

        fig.canvas.mpl_connect('key_press_event', on_key_press)
        anim = animation.FuncAnimation(fig, update, frames=len(self.states), interval=self.time_step * 1000)
        anim.running = True

        if output_file is not None:
            pillow_writer = animation.writers['pillow']
            writer = pillow_writer(fps=8, metadata=dict(artist='Me'), bitrate=1800)
            anim.save(output_file, writer=writer)
        else:
            plt.show()
SG = nx.DiGraph(tmp)
# Every member must appear as a node, even members that have no edge in the
# day's data (e.g. a player who was offline that day).
for n in nodelist:
    if n in SG.nodes:
        continue
    SG.add_node(n)

# draw the graph
plt.figure(figsize=(10, 10), dpi=500)
plt.title(edge_type[0].upper() + edge_type[1:], fontsize=30)
# Drawing options shared by both layouts; only the layout function differs.
draw_options = {
    'arrowsize': 2,
    'arrowstyle': mpatch.ArrowStyle("-|>", head_length=1, head_width=1),
    'with_labels': False,
    'node_size': 50,
    'nodelist': nodelist,
    'node_color': colorlist,
}
# 'attacks' graphs use a spring layout; every other edge type uses Kamada-Kawai.
draw_graph = nx.draw_spring if edge_type == 'attacks' else nx.draw_kamada_kawai
draw_graph(SG, **draw_options)
    def render(self, mode='video', output_file=None):
        """Visualize the recorded episode stored in ``self.states``.

        Each entry of ``self.states`` is indexed as ``state[0]`` (robot state)
        and ``state[1][j]`` (state of human ``j``).

        Args:
            mode: ``'traj'`` draws a static plot of the whole episode: agent
                outlines every 4th step and at the last step, joined by
                per-step trajectory segments. ``'video'`` builds a
                ``FuncAnimation`` showing agents, heading arrows, social
                zones, cumulative evaluation metrics and, if ``self.trajs``
                is non-empty, predicted future human positions. In video
                mode any key press pauses/resumes the animation; pressing
                ``'a'`` while running also prints policy debug matrices when
                the policy provides them.
            output_file: only used in ``'video'`` mode. If given, the
                animation is saved to this path with the ImageMagick writer
                at 12 fps; otherwise it is shown interactively.

        Raises:
            NotImplementedError: if ``mode`` is neither ``'traj'`` nor
                ``'video'``.
        """
        from matplotlib import animation
        import matplotlib.pyplot as plt
        # plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
        x_offset = 0.2
        y_offset = 0.4
        # ``plt.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
        # 3.9; ``plt.get_cmap(name, lut)`` works on both old and new versions.
        cmap = plt.get_cmap('hsv', 10)
        robot_color = 'black'
        arrow_style = patches.ArrowStyle("->", head_length=4, head_width=2)
        display_numbers = False

        if mode == 'traj':
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.tick_params(labelsize=16)

            # add human start positions and goals
            human_colors = [cmap(i) for i in range(len(self.humans))]
            for i in range(len(self.humans)):
                human = self.humans[i]
                human_goal = mlines.Line2D([human.get_goal_position()[0]], [human.get_goal_position()[1]],
                                           color=human_colors[i],
                                           marker='*', linestyle='None', markersize=15)
                ax.add_artist(human_goal)
                human_start = mlines.Line2D([human.get_start_position()[0]], [human.get_start_position()[1]],
                                            color=human_colors[i],
                                            marker='o', linestyle='None', markersize=15)
                ax.add_artist(human_start)

            robot_positions = [self.states[i][0].position for i in range(len(self.states))]
            human_positions = [[self.states[i][1][j].position for j in range(len(self.humans))]
                               for i in range(len(self.states))]

            for k in range(len(self.states)):
                # draw agent outlines every 4th step and at the final step
                if k % 4 == 0 or k == len(self.states) - 1:
                    robot = plt.Circle(robot_positions[k], self.robot.radius, fill=False, color=robot_color)
                    humans = [plt.Circle(human_positions[k][i], self.humans[i].radius, fill=False, color=cmap(i))
                              for i in range(len(self.humans))]
                    ax.add_artist(robot)
                    for human in humans:
                        ax.add_artist(human)

                # add time annotation next to the most recently drawn outlines
                # NOTE(review): assumes time_step is such that this condition only
                # holds on steps where outlines were drawn above — confirm
                global_time = k * self.time_step
                if global_time % 4 == 0 or k == len(self.states) - 1:
                    agents = humans + [robot]
                    times = [plt.text(agents[i].center[0] - x_offset, agents[i].center[1] - y_offset,
                                      '{:.1f}'.format(global_time),
                                      color='black', fontsize=14) for i in range(self.human_num + 1)]
                    for time in times:
                        ax.add_artist(time)
                # connect consecutive positions with trajectory segments
                if k != 0:
                    nav_direction = plt.Line2D((self.states[k - 1][0].px, self.states[k][0].px),
                                               (self.states[k - 1][0].py, self.states[k][0].py),
                                               color=robot_color, ls='solid')
                    human_directions = [plt.Line2D((self.states[k - 1][1][i].px, self.states[k][1][i].px),
                                                   (self.states[k - 1][1][i].py, self.states[k][1][i].py),
                                                   color=cmap(i), ls='solid')
                                        for i in range(self.human_num)]
                    ax.add_artist(nav_direction)
                    for human_direction in human_directions:
                        ax.add_artist(human_direction)
            plt.legend([robot], ['Robot'], fontsize=16)
            plt.show()
        elif mode == 'video':
            fig, ax = plt.subplots(figsize=(8, 8))
            ax.tick_params(labelsize=12)
            ax.set_xlim(-11, 11)
            ax.set_ylim(-11, 11)
            ax.set_xlabel('x(m)', fontsize=14)
            ax.set_ylabel('y(m)', fontsize=14)
            show_human_start_goal = True
            show_sensor_range = True
            show_eval_info = True
            show_social_zone = True

            # add human start positions and goals
            human_colors = [cmap(i) for i in range(len(self.humans))]
            if show_human_start_goal:
                for i in range(len(self.humans)):
                    human = self.humans[i]
                    human_goal = mlines.Line2D([human.get_goal_position()[0]], [human.get_goal_position()[1]],
                                               color=human_colors[i],
                                               marker='*', linestyle='None', markersize=8)
                    ax.add_artist(human_goal)
                    human_start = mlines.Line2D([human.get_start_position()[0]], [human.get_start_position()[1]],
                                                color=human_colors[i],
                                                marker='o', linestyle='None', markersize=8)
                    ax.add_artist(human_start)
            # add robot start position
            robot_start = mlines.Line2D([self.robot.get_start_position()[0]], [self.robot.get_start_position()[1]],
                                        color=robot_color,
                                        marker='o', linestyle='None', markersize=8, label='Start')
            robot_start_position = [self.robot.get_start_position()[0], self.robot.get_start_position()[1]]
            ax.add_artist(robot_start)
            # add robot and its goal
            robot_positions = [state[0].position for state in self.states]
            goal = mlines.Line2D([self.robot.get_goal_position()[0]], [self.robot.get_goal_position()[1]],
                                 color=robot_color, marker='*', linestyle='None',
                                 markersize=15, label='Goal')
            robot = plt.Circle(robot_positions[0], self.robot.radius, fill=False, color=robot_color)
            ax.add_artist(robot)
            ax.add_artist(goal)
            plt.legend([robot, goal, robot_start], ['Robot', 'Goal', 'Start'], fontsize=14)
            # if show_sensor_range:
            #     sensor_range = plt.Circle(robot_positions[0], self.robot_sensor_range, fill=False, ls='dashed')
            #     ax.add_artist(sensor_range)

            # add humans and their numbers
            human_positions = [[state[1][j].position for j in range(len(self.humans))] for state in self.states]
            humans = [plt.Circle(human_positions[0][i], self.humans[i].radius, fill=False, color=cmap(i))
                      for i in range(len(self.humans))]

            # human index labels are only drawn when display_numbers is True
            if display_numbers:
                human_numbers = [plt.text(humans[i].center[0] - x_offset, humans[i].center[1] + y_offset, str(i),
                                          color='black') for i in range(len(self.humans))]

            for i, human in enumerate(humans):
                ax.add_artist(human)
                if display_numbers:
                    ax.add_artist(human_numbers[i])

            # add time annotation
            time = plt.text(0.5, 0.9, f'Time: {0}', fontsize=16, transform=ax.transAxes,
                            horizontalalignment='center', verticalalignment='center')
            ax.add_artist(time)

            # add evaluation annotation; the labels must match those written in update()
            if show_eval_info:
                eval_text = plt.text(
                    0.6, 0.07,
                    f"Aggregated Time: {0}\nPersonal Zone Violations: {0}\nSocial Zone Violations: {0}\nJerk Cost: {0}",
                    fontsize=12, transform=ax.transAxes, horizontalalignment='left', verticalalignment='center')

            # Cumulative evaluation metrics per frame: running sums, except the
            # minimum separation, which is a running minimum.
            list_aggregated_time = [self.infos[0]['aggregated_time']]
            list_min_separation = [self.infos[0]['min_separation']]
            list_personal_violation_cnt = [self.infos[0]['personal_violation_cnt']]
            list_social_violation_cnt = [self.infos[0]['social_violation_cnt']]
            list_jerk_cost = [self.infos[0]['jerk_cost']]
            for i in range(1, len(self.infos)):
                list_aggregated_time.append(list_aggregated_time[i-1] + self.infos[i]['aggregated_time'])
                list_min_separation.append(min(list_min_separation[i-1], self.infos[i]['min_separation']))
                list_social_violation_cnt.append(list_social_violation_cnt[i-1] + self.infos[i]['social_violation_cnt'])
                list_jerk_cost.append(list_jerk_cost[i-1] + self.infos[i]['jerk_cost'])
                list_personal_violation_cnt.append(list_personal_violation_cnt[i-1] + self.infos[i]['personal_violation_cnt'])

            # visualize attention scores
            # if hasattr(self.robot.policy, 'get_attention_weights'):
            #     attention_scores = [
            #         plt.text(-5.5, 5 - 0.5 * i, 'Human {}: {:.2f}'.format(i + 1, self.attention_weights[0][i]),
            #                  fontsize=16) for i in range(len(self.humans))]

            # Precompute one social-zone patch per agent per step. In this list
            # humans come first and the robot is the LAST entry (index human_num).
            social_zones_all_agents = []

            if show_social_zone:
                for i in range(self.human_num + 1):
                    social_zones = []
                    step_cnt = 0
                    for state in self.states:
                        step_cnt += 1
                        agent_state = state[0] if i == self.human_num else state[1][i]
                        if i == self.human_num:  # robot
                            rect = AgentHeadingRect(agent_state.px, agent_state.py, self.robot.radius,
                                                    agent_state.vx, agent_state.vy, self.robot.kinematics)
                            # NOTE(review): step_cnt was already incremented, so this reads
                            # the info entry AFTER the current state — confirm the offset
                            if step_cnt < len(self.infos) and self.infos[step_cnt]['social_violation_cnt'] > 0:
                                rect.color = 'red'
                        else:
                            rect = AgentHeadingRect(agent_state.px, agent_state.py, self.humans[i].radius,
                                                    agent_state.vx, agent_state.vy, self.humans[i].kinematics)
                        social_zones.append(rect.get_pyplot_rect())
                    social_zones_all_agents.append(social_zones)

            # draw the zones for the first step
            social_zones_drawn = []
            for zones in social_zones_all_agents:
                ax.add_artist(zones[0])
                social_zones_drawn.append(zones[0])

            # Precompute heading arrows per agent per step. Here the robot is
            # index 0 and human j is index j + 1.
            radius = self.robot.radius
            orientations = []
            for i in range(self.human_num + 1):
                orientation = []
                for state in self.states:
                    agent_state = state[0] if i == 0 else state[1][i - 1]
                    if self.robot.kinematics == 'unicycle' and i == 0:
                        # a unicycle robot carries an explicit heading (theta)
                        direction = (
                            (agent_state.px, agent_state.py),
                            (agent_state.px + radius * np.cos(agent_state.theta),
                             agent_state.py + radius * np.sin(agent_state.theta)))
                    else:
                        # otherwise the heading is derived from the velocity vector
                        theta = np.arctan2(agent_state.vy, agent_state.vx)
                        direction = ((agent_state.px, agent_state.py),
                                     (agent_state.px + radius * np.cos(theta),
                                      agent_state.py + radius * np.sin(theta)))
                    orientation.append(direction)
                orientations.append(orientation)
                if i == 0:
                    arrow_color = 'black'
                    arrows = [patches.FancyArrowPatch(*orientation[0], color=arrow_color, arrowstyle=arrow_style)]
                else:
                    arrows.extend(
                        [patches.FancyArrowPatch(*orientation[0], color=human_colors[i - 1], arrowstyle=arrow_style)])
            for arrow in arrows:
                ax.add_artist(arrow)

            global_step = 0

            # predicted future human positions (one shrinking circle per planning step)
            if len(self.trajs) != 0:
                human_future_positions = []
                human_future_circles = []
                for traj in self.trajs:
                    human_future_position = [[tensor_to_joint_state(traj[step+1][0]).human_states[i].position
                                              for step in range(self.robot.policy.planning_depth)]
                                             for i in range(self.human_num)]
                    human_future_positions.append(human_future_position)

                for i in range(self.human_num):
                    circles = []
                    for j in range(self.robot.policy.planning_depth):
                        circle = plt.Circle(human_future_positions[0][i][j], self.humans[0].radius/(1.7+j),
                                            fill=False, color=cmap(i))
                        ax.add_artist(circle)
                        circles.append(circle)
                    human_future_circles.append(circles)

            def update(frame_num):
                """Move every artist to the state recorded at ``frame_num``."""
                nonlocal global_step
                nonlocal arrows
                nonlocal social_zones_drawn
                global_step = frame_num
                robot.center = robot_positions[frame_num]

                for i, human in enumerate(humans):
                    human.center = human_positions[frame_num][i]
                    if display_numbers:
                        human_numbers[i].set_position((human.center[0] - x_offset, human.center[1] + y_offset))
                for arrow in arrows:
                    arrow.remove()
                for zone in social_zones_drawn:  # remove last step's social zones
                    zone.remove()

                # draw social zones for each step
                if show_social_zone:
                    social_zones_drawn = []
                    for i in range(self.human_num + 1):
                        zones = social_zones_all_agents[i]
                        social_zones_drawn.append(zones[frame_num])
                        ax.add_artist(zones[frame_num])

                for i in range(self.human_num + 1):
                    orientation = orientations[i]
                    if i == 0:
                        arrows = [patches.FancyArrowPatch(*orientation[frame_num], color='black',
                                                          arrowstyle=arrow_style)]
                    else:
                        arrows.extend([patches.FancyArrowPatch(*orientation[frame_num], color=human_colors[i - 1],
                                                               arrowstyle=arrow_style)])

                for arrow in arrows:
                    ax.add_artist(arrow)
                    # if hasattr(self.robot.policy, 'get_attention_weights'):
                    #     attention_scores[i].set_text('human {}: {:.2f}'.format(i, self.attention_weights[frame_num][i]))

                time.set_text('Time: {:.2f}'.format(frame_num * self.time_step))

                if show_eval_info:
                    # implicit concatenation avoids the whitespace that a backslash
                    # continuation inside the literal would inject into the label
                    eval_text.set_text(
                        f"Aggregated Time: {list_aggregated_time[frame_num]}\n"
                        f"Personal Zone Violations: {list_personal_violation_cnt[frame_num]}\n"
                        f"Social Zone Violations: {list_social_violation_cnt[frame_num]}\n"
                        f"Jerk Cost: {list_jerk_cost[frame_num]: .3f}")

                if len(self.trajs) != 0:
                    for i, circles in enumerate(human_future_circles):
                        for j, circle in enumerate(circles):
                            circle.center = human_future_positions[global_step][i][j]

            def plot_value_heatmap():
                """Polar heat map of the normalised action values at the current step."""
                if self.robot.kinematics != 'holonomic':
                    print('Kinematics is not holonomic')
                    return
                # for agent in [self.states[global_step][0]] + self.states[global_step][1]:
                #     print(('{:.4f}, ' * 6 + '{:.4f}').format(agent.px, agent.py, agent.gx, agent.gy,
                #                                              agent.vx, agent.vy, agent.theta))

                # when any key is pressed draw the action value plot
                fig, axis = plt.subplots()
                speeds = [0] + self.robot.policy.speeds
                rotations = self.robot.policy.rotations + [np.pi * 2]
                r, th = np.meshgrid(speeds, rotations)
                z = np.array(self.action_values[global_step % len(self.states)][1:])
                z = (z - np.min(z)) / (np.max(z) - np.min(z))
                z = np.reshape(z, (self.robot.policy.rotation_samples, self.robot.policy.speed_samples))
                polar = plt.subplot(projection="polar")
                polar.tick_params(labelsize=16)
                mesh = plt.pcolormesh(th, r, z, vmin=0, vmax=1)
                plt.plot(rotations, r, color='k', ls='none')
                plt.grid()
                cbaxes = fig.add_axes([0.85, 0.1, 0.03, 0.8])
                cbar = plt.colorbar(mesh, cax=cbaxes)
                cbar.ax.tick_params(labelsize=16)
                plt.show()

            def print_matrix_A():
                """Print ``self.As[global_step]`` with row/column indices offset by -1."""
                h, w = self.As[global_step].shape
                print('   ' + ' '.join(['{:>5}'.format(i - 1) for i in range(w)]))
                for i in range(h):
                    print('{:<3}'.format(i-1) + ' '.join(['{:.3f}'.format(self.As[global_step][i][j]) for j in range(w)]))

            def print_feat():
                with np.printoptions(precision=3, suppress=True):
                    print('feat is: ')
                    print(self.feats[global_step])

            def print_X():
                with np.printoptions(precision=3, suppress=True):
                    print('X is: ')
                    print(self.Xs[global_step])

            def on_click(event):
                """Toggle pause on any key; 'a' while running also prints debug data."""
                if anim.running:
                    anim.event_source.stop()
                    if event.key == 'a':
                        if hasattr(self.robot.policy, 'get_matrix_A'):
                            print_matrix_A()
                        if hasattr(self.robot.policy, 'get_feat'):
                            print_feat()
                        if hasattr(self.robot.policy, 'get_X'):
                            print_X()
                        # if hasattr(self.robot.policy, 'action_values'):
                        #    plot_value_heatmap()
                else:
                    anim.event_source.start()
                anim.running ^= True

            fig.canvas.mpl_connect('key_press_event', on_click)
            anim = animation.FuncAnimation(fig, update, frames=len(self.states), interval=self.time_step * 500,
                                           repeat_delay=500)
            anim.running = True

            if output_file is not None:
                # save as video
                # ffmpeg_writer = animation.FFMpegWriter(fps=10, metadata=dict(artist='Me'), bitrate=1800)
                # anim.save(output_file, writer=ffmpeg_writer)

                # save output file as gif; requires ImageMagick's `convert` at this path
                plt.rcParams["animation.convert_path"] = r'/usr/bin/convert'
                anim.save(output_file, writer='imagemagick', fps=12)
            else:
                plt.show()
        else:
            raise NotImplementedError
Example #9
0
    def render(self, mode=None, output_file=None):

        from matplotlib import animation  # 画动态图
        import matplotlib.pyplot as plt

        x_offset = 0.2
        y_offset = 0.4
        cmap = plt.cm.get_cmap('hsv', 10)  # color map
        robot_color = 'black'
        arrow_style = patches.ArrowStyle("->", head_length=4,
                                         head_width=2)  # 箭头的长度和宽度
        display_numbers = True  # 展示数字

        if mode == 'traj':
            count = len(self.states)
            pics = count // 4
            pic_steps = []
            while count > 0:
                pic_steps.append(count)
                count -= pics

            for pic_step in pic_steps[::-1]:
                fig, ax = plt.subplots(figsize=(8, 8))
                plt.suptitle('Track', fontsize=16)
                ax.tick_params(labelsize=16)
                ax.set_xlim(-9, 9)
                ax.set_ylim(-9, 9)
                ax.set_xlabel('x(m)', fontsize=16)
                ax.set_ylabel('y(m)', fontsize=16)

                # add human start positions and goals
                human_colors = [
                    cmap(i) for i in range(len(self.out_group_agents))
                ]
                for i in range(len(self.out_group_agents)):
                    human = self.out_group_agents[i]
                    human_goal = mlines.Line2D([human.get_goal()[0]],
                                               [human.get_goal()[1]],
                                               color=human_colors[i],
                                               marker='*',
                                               linestyle='None',
                                               markersize=15)
                    ax.add_artist(human_goal)
                    human_start = mlines.Line2D(
                        [human.get_start_position()[0]],
                        [human.get_start_position()[1]],
                        color=human_colors[i],
                        marker='o',
                        linestyle='None',
                        markersize=8)
                    ax.add_artist(human_start)

                # add group start and goal
                group_goal = mlines.Line2D([0], [self.circle_radius],
                                           color='black',
                                           marker='*',
                                           linestyle='None',
                                           markersize=16)
                ax.add_artist(group_goal)
                for member in self.group_members:
                    member_start = mlines.Line2D(
                        [member.get_start_position()[0]],
                        [member.get_start_position()[1]],
                        color='black',
                        marker='o',
                        linestyle='None',
                        markersize=8)
                    ax.add_artist(member_start)

                agent_position = [[
                    self.states[i][j].position for j in range(
                        len(self.group_members + self.out_group_agents))
                ] for i in range(pic_step)]

                for k in range(pic_step):  # k是状态的索引
                    # 先画直线:
                    if k != 0:
                        human_directions = [
                            plt.Line2D(
                                (agent_position[k - 1][i][0],
                                 agent_position[k][i][0]),
                                (agent_position[k - 1][i][1],
                                 agent_position[k][i][1]),
                                color='grey' if i < 3 else human_colors[i - 3],
                                linestyle=':') for i in range(
                                    len(self.group_members +
                                        self.out_group_agents))
                        ]
                        for d in human_directions:
                            ax.add_artist(d)
                    # 画圈圈
                    if k == pic_step - 1:  # 最后一个状态填充
                        for n, agent in enumerate(self.group_members +
                                                  self.out_group_agents):
                            agent = plt.Circle(
                                agent_position[k][n],
                                self.agent_radius,
                                fill=True,
                                color='black' if n < 3 else human_colors[n -
                                                                         3])
                            ax.add_artist(agent)
                            number = plt.text(agent.center[0] - 0.1,
                                              agent.center[1] - 0.1,
                                              str(n),
                                              color='white')
                            ax.add_artist(number)

                    elif k % 4 == 0 or k == len(self.states) - 1:  # 每4个状态
                        for n, agent in enumerate(self.group_members +
                                                  self.out_group_agents):
                            agent = plt.Circle(
                                agent_position[k][n],
                                self.agent_radius,
                                fill=False,
                                color='black' if n < 3 else human_colors[n -
                                                                         3])
                            ax.add_artist(agent)
                robot = plt.Circle([0, 0],
                                   self.agent_radius,
                                   fill=False,
                                   color='black')
                goal = mlines.Line2D([0], [0],
                                     color='black',
                                     marker='*',
                                     linestyle='None',
                                     markersize=16)
                plt.legend([robot, goal], ['Robot', 'Goal'], fontsize=16)
                plt.show()

        elif mode == 'video':
            fig, ax = plt.subplots(figsize=(7, 7))  # 面板大小7,7 fig 表示一窗口 ax 是一个框
            ax.tick_params(labelsize=12)  # 坐标字体大小
            ax.set_xlim(-11, 11)  # -11,11  # 坐标的范围   可用于控制画面比例
            ax.set_ylim(-11, 11)  # -11,11
            ax.set_xlabel('x(m)', fontsize=14)
            ax.set_ylabel('y(m)', fontsize=14)
            show_human_start_goal = True

            # 图例
            circle1 = plt.Circle((1, 1), 0.3, fill=True, color=cmap(4))
            circle2 = plt.Circle((1, 1), 0.3, fill=False, color=cmap(7))

            # 在图上显示组外agent(用human标识)的起始位置和目标位置
            human_colors = [cmap(i) for i in range(len(self.out_group_agents))]
            if show_human_start_goal:  # 展示human初始目标为true时才显示
                for i in range(len(self.out_group_agents)):
                    agent = self.out_group_agents[i]
                    agent_goal = mlines.Line2D([agent.get_goal()[0]],
                                               [agent.get_goal()[1]],
                                               color=human_colors[i],
                                               marker='*',
                                               linestyle='None',
                                               markersize=8)
                    ax.add_artist(agent_goal)
                    human_start = mlines.Line2D(
                        [agent.get_start_position()[0]],
                        [agent.get_start_position()[1]],
                        color=human_colors[i],
                        marker='o',
                        linestyle='None',
                        markersize=4)
                    ax.add_artist(human_start)
            # 设置小组成员的初始位置
            for i in range(len(self.group_members)):
                agent = self.group_members[i]
                agent_start = mlines.Line2D([agent.get_start_position()[0]],
                                            [agent.get_start_position()[1]],
                                            color='black',
                                            marker='o',
                                            linestyle='None',
                                            markersize=4)
                ax.add_artist(agent_start)

            # 手动添加group的goal
            group_goal = mlines.Line2D([0], [self.circle_radius],
                                       color='black',
                                       marker='*',
                                       linestyle='None',
                                       markersize=16)
            ax.add_artist(group_goal)

            group_member_template = plt.Circle((0, 0),
                                               0.3,
                                               fill=False,
                                               color='black')
            # plt.legend([group_member_template, group_goal, circle1, circle2],
            #            ['ROBOT', 'Goal', 'Want', 'truth'], fontsize=14)
            plt.legend([group_member_template, group_goal], ['ROBOT', 'Goal'],
                       fontsize=14)

            # 添加所有的agent
            agent_positions = [[
                state[j].position for j in range(self.out_group_agents_num + 3)
            ] for state in self.states]
            # group_members都是黑色
            group_members = [
                plt.Circle(agent_positions[0][i],
                           self.agent_radius,
                           fill=False,
                           color='black') for i in range(3)
            ]
            humans = [
                plt.Circle(agent_positions[0][i],
                           self.out_group_agents[i - 3].radius,
                           fill=False,
                           color=cmap(i - 3))
                for i in range(3, self.out_group_agents_num + 3)
            ]

            all_agents = group_members + humans
            # 生成并显示数字编号
            numbers = [
                plt.text(all_agents[i].center[0] - x_offset,
                         all_agents[i].center[1] + y_offset,
                         str(i),
                         color='black')
                for i in range(self.out_group_agents_num + 3)
            ]
            for i, agent in enumerate(all_agents):
                ax.add_artist(agent)
                if display_numbers:
                    ax.add_artist(numbers[i])

            # 时间和步骤
            time = plt.text(0.4,
                            0.9,
                            'Time: {}'.format(0),
                            fontsize=16,
                            transform=ax.transAxes)
            ax.add_artist(time)
            step = plt.text(0.1,
                            0.9,
                            'Step: {}'.format(0),
                            fontsize=16,
                            transform=ax.transAxes)
            ax.add_artist(step)

            # 计算朝向,使用箭头显示方向
            radius = self.agent_radius
            orientations = []  # [[某个成员在所有状态下的朝向],[],[],[]...]
            arrows = []
            for i in range(self.out_group_agents_num + 3):
                orientation = []
                for state in self.states:
                    agent_state = state[i]
                    theta = np.arctan2(agent_state.vy, agent_state.vx)
                    direction = ((agent_state.px, agent_state.py),
                                 (agent_state.px + radius * np.cos(theta),
                                  agent_state.py + radius * np.sin(theta)))
                    orientation.append(direction)
                orientations.append(orientation)  # 计算箭头的方向,也是human的朝向
                if i <= 2:
                    arrow_color = 'black'
                    arrows.append(
                        patches.FancyArrowPatch(*orientation[0],
                                                color=arrow_color,
                                                arrowstyle=arrow_style))
                else:
                    arrows.extend([
                        patches.FancyArrowPatch(*orientation[0],
                                                color=human_colors[i - 3],
                                                arrowstyle=arrow_style)
                    ])

            for arrow in arrows:
                ax.add_artist(arrow)
            global_step = 0

            def update(frame_num):  # frame_num 是第多少帧
                nonlocal global_step
                nonlocal arrows
                global_step = frame_num

                for i, agent in enumerate(all_agents):
                    agent.center = agent_positions[frame_num][i]
                    if display_numbers:
                        numbers[i].set_position((agent.center[0] - x_offset,
                                                 agent.center[1] + y_offset))
                for arrow in arrows:
                    arrow.remove()

                arrows = []
                for i in range(self.out_group_agents_num + 3):
                    orientation = orientations[i]
                    if i <= 2:
                        arrow_color = 'black'
                        arrows.append(
                            patches.FancyArrowPatch(*orientation[frame_num],
                                                    color=arrow_color,
                                                    arrowstyle=arrow_style))
                    else:
                        arrows.extend([
                            patches.FancyArrowPatch(*orientation[frame_num],
                                                    color=cmap(i - 3),
                                                    arrowstyle=arrow_style)
                        ])

                for arrow in arrows:
                    ax.add_artist(arrow)

                time.set_text('Time: {:.2f}'.format(frame_num *
                                                    self.time_step))
                step.set_text('Step: {:}'.format(frame_num))

            def plot_value_heatmap():
                if self.robot.kinematics != 'holonomic':
                    print('Kinematics is not holonomic')
                    return

                # when any key is pressed draw the action value plot
                fig, axis = plt.subplots()
                speeds = [0] + self.robot.policy.speeds
                rotations = self.robot.policy.rotations + [np.pi * 2]
                r, th = np.meshgrid(speeds, rotations)
                z = np.array(self.action_values[global_step %
                                                len(self.states)][1:])
                z = (z - np.min(z)) / (np.max(z) - np.min(z))
                z = np.reshape(z, (self.robot.policy.rotation_samples,
                                   self.robot.policy.speed_samples))
                polar = plt.subplot(projection="polar")
                polar.tick_params(labelsize=16)
                mesh = plt.pcolormesh(th, r, z, vmin=0, vmax=1)
                plt.plot(rotations, r, color='k', ls='none')
                plt.grid()
                cbaxes = fig.add_axes([0.85, 0.1, 0.03, 0.8])
                cbar = plt.colorbar(mesh, cax=cbaxes)
                cbar.ax.tick_params(labelsize=16)
                plt.show()

            def print_matrix_A():
                # with np.printoptions(precision=3, suppress=True):
                #     print(self.As[global_step])
                h, w = self.As[global_step].shape
                print('   ' +
                      ' '.join(['{:>5}'.format(i - 1) for i in range(w)]))
                for i in range(h):
                    print('{:<3}'.format(i - 1) + ' '.join([
                        '{:.3f}'.format(self.As[global_step][i][j])
                        for j in range(w)
                    ]))

            def print_feat():
                with np.printoptions(precision=3, suppress=True):
                    print('feat is: ')
                    print(self.feats[global_step])

            def print_X():
                with np.printoptions(precision=3, suppress=True):
                    print('X is: ')
                    print(self.Xs[global_step])

            def on_click(event):
                if anim.running:
                    anim.event_source.stop()
                    print('you pressd the : ', event.key, ' key')
                else:
                    anim.event_source.start()
                anim.running ^= True

            fig.canvas.mpl_connect('key_press_event',
                                   on_click)  # matplotlib的键鼠响应事件绑定
            anim = animation.FuncAnimation(fig,
                                           update,
                                           frames=len(self.states),
                                           interval=self.time_step * 500)
            anim.running = True

            if output_file is not None:
                # save as video
                ffmpeg_writer = animation.FFMpegWriter(
                    fps=10, metadata=dict(artist='Me'), bitrate=1800)
                # writer = ffmpeg_writer(fps=10, metadata=dict(artist='Me'), bitrate=1800)
                anim.save(output_file, writer=ffmpeg_writer)
                # save output file as gif if imagemagic is installed
                # anim.save(output_file, writer='imagemagic', fps=12)
            else:
                name = './' + str(
                    self.out_group_agents_num
                ) + 'p_' + self.group.policy.name + '_seed' + self.seed + '.gif'
                anim.save(name, writer='imagemagic', fps=12)
                plt.show()

        else:
            raise NotImplementedError
Example #10
0
    def render(self, mode='human'):
        """Draw one frame of the scene on ``self.render_axis`` and pause briefly.

        The frame shows the goal, the robot, heading arrows for the robot and
        every human, the robot's field-of-view lines, and the humans coloured
        by visibility (green: visible to the robot, red: not visible). All
        artists added here are displayed with ``plt.pause(0.1)`` and then
        removed again, so every call draws its frame from scratch.

        :param mode: accepted for interface compatibility; not read here.
        """
        import matplotlib.pyplot as plt
        import matplotlib.lines as mlines
        from matplotlib import patches

        plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'

        robot_color = 'yellow'
        goal_color = 'red'
        arrow_color = 'red'
        arrow_style = patches.ArrowStyle("->", head_length=4, head_width=2)

        def calcFOVLineEndPoint(ang, point, extendFactor):
            """Rotate the 2-D vector ``point`` by ``ang`` radians and scale it.

            Returns ``[x, y, 1]`` (homogeneous form). ``point`` itself is not
            modified. Choose ``extendFactor`` big enough that the end point of
            the FOV line falls outside the xlim/ylim of the figure.
            """
            FOVLineRot = np.array([[np.cos(ang), -np.sin(ang), 0],
                                   [np.sin(ang), np.cos(ang), 0], [0, 0, 1]])
            # Work on a homogeneous copy instead of appending to the caller's list.
            homogeneous = np.array([point[0], point[1], 1])
            newPoint = np.matmul(FOVLineRot, homogeneous)
            # increase the distance between the line start point and the end point
            return [extendFactor * newPoint[0], extendFactor * newPoint[1], 1]

        ax = self.render_axis
        artists = []  # every artist added in this call; removed at the end

        # add goal
        goal = mlines.Line2D([self.robot.gx], [self.robot.gy],
                             color=goal_color,
                             marker='*',
                             linestyle='None',
                             markersize=15,
                             label='Goal')
        ax.add_artist(goal)
        artists.append(goal)

        # add robot
        robotX, robotY = self.robot.get_position()

        robot = plt.Circle((robotX, robotY),
                           self.robot.radius,
                           fill=True,
                           color=robot_color)
        ax.add_artist(robot)
        artists.append(robot)

        plt.legend([robot, goal], ['Robot', 'Goal'], fontsize=16)

        # Heading arrows. The robot uses its own heading for unicycle kinematics
        # and its velocity direction otherwise; humans use their velocity
        # direction. All arrows are self.robot.radius long.
        radius = self.robot.radius

        robot_theta = self.robot.theta if self.robot.kinematics == 'unicycle' else np.arctan2(
            self.robot.vy, self.robot.vx)

        arrowStartEnd = [
            ((robotX, robotY), (robotX + radius * np.cos(robot_theta),
                                robotY + radius * np.sin(robot_theta)))
        ]

        for human in self.humans:
            theta = np.arctan2(human.vy, human.vx)
            arrowStartEnd.append(
                ((human.px, human.py), (human.px + radius * np.cos(theta),
                                        human.py + radius * np.sin(theta))))

        arrows = [
            patches.FancyArrowPatch(*arrow,
                                    color=arrow_color,
                                    arrowstyle=arrow_style)
            for arrow in arrowStartEnd
        ]
        for arrow in arrows:
            ax.add_artist(arrow)
            artists.append(arrow)

        # Robot field of view: two dashed lines at +/- half the FOV angle around
        # the robot's heading, both starting at the robot centre.
        FOVAng = self.robot_fov / 2

        startPointX = robotX
        startPointY = robotY
        endPointX = robotX + radius * np.cos(robot_theta)
        endPointY = robotY + radius * np.sin(robot_theta)

        for ang in (FOVAng, -FOVAng):
            FOVLine = mlines.Line2D([0, 0], [0, 0], linestyle='--')
            # Translate the heading vector to the origin, rotate it by +/- the
            # half FOV and stretch it so the line leaves the visible area.
            FOVEndPoint = calcFOVLineEndPoint(
                ang, [endPointX - startPointX, endPointY - startPointY],
                20. / self.robot.radius)
            FOVLine.set_xdata(
                np.array([startPointX, startPointX + FOVEndPoint[0]]))
            FOVLine.set_ydata(
                np.array([startPointY, startPointY + FOVEndPoint[1]]))
            ax.add_artist(FOVLine)
            artists.append(FOVLine)

        # add humans and change their colour based on visibility
        human_circles = [
            plt.Circle(human.get_position(), human.radius, fill=False)
            for human in self.humans
        ]

        for human, circle in zip(self.humans, human_circles):
            ax.add_artist(circle)
            artists.append(circle)

            # green: visible; red: invisible
            if self.detect_visible(self.robot, human, robot1=True):
                circle.set_color(c='g')
            else:
                circle.set_color(c='r')

        plt.pause(0.1)
        # Remove everything drawn in this call so the next frame starts clean.
        # A cheaper approach would be to add the artists once and only update
        # their data (e.g. with draw_artist) on later frames.
        for item in artists:
            item.remove()
Example #11
0
import names
import networkx as nx
import migration_trends
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cmx
import matplotlib.patches as mp
import statistics


# Shared arrow style: matplotlib's filled "Fancy" arrow.
MY_ARROWS = mp.ArrowStyle(
  "Fancy",
  head_length=10,
  head_width=5,
  tail_width=.4,
)
# Display-name overrides looked up by `correct`: raw place name -> label.
CORRECT = {
  'Matanuska-Susitna': 'Matsu',
  'Out': 'Out Of State',
}

def correct(text):
  """Return the display label for *text*.

  Names that have an entry in CORRECT are replaced by that entry; any other
  value is returned unchanged.
  """
  return CORRECT.get(text, text)

def flow_dict(start_year, end_year):
  """Build a nested ``{origin: {destination: total}}`` table.

  Both key levels use the first word of each name in ``names.PLACE_LIST``.
  For every ordered pair of distinct places the value is the sum of
  ``migration_trends.get_net_migration_list(start_year, end_year, to, frm)``.
  NOTE(review): the (destination, origin) argument order is kept as-is;
  confirm it against migration_trends.
  """
  places = names.PLACE_LIST
  table = {}
  for origin in places:
    row = {}
    for dest in places:
      if origin == dest:
        continue
      yearly = migration_trends.get_net_migration_list(
        start_year, end_year, dest, origin)
      row[dest.split()[0]] = sum(yearly)
    table[origin.split()[0]] = row
  return table

def flow_dict_with_lists(start_year, end_year, area_dct):
def draw_neural_net(ax,
                    left=0.0,
                    right=1.,
                    bottom=0.0,
                    top=1.,
                    layerSizes=(2, 3, 1),
                    inputPrefix="x",
                    outputPrefix=r"\hat{y}_{m}",
                    inLayerPrefix="I",
                    outLayerPrefix="O",
                    hiddenLayerPrefix="H",
                    inNodePrefix="i",
                    otherNodePrefix=r"z_{m}\rightarrow a_{m}",
                    biasNodePrefix=r"b_{m}",
                    weights=None,
                    biases=None,
                    epoch="",
                    loss="",
                    hideInOutPutNodes=False,
                    hideBias=False,
                    showLayerIndex=True,
                    inputOutputColor="blue",
                    nodeColor="lightgreen",
                    biasNodeColor="lightcyan",
                    edgeColor="black",
                    biasEdgeColor="gray",
                    weightsColor="green",
                    biasColor="purple",
                    nodeFontSize=15,
                    edgeFontSize=10,
                    edgeWidth=1):
    '''
    Draw a neural network cartoon using matplotlib.

    Everything is drawn on ``ax``, not on pyplot's "current" axes, so the
    function also works when ``ax`` is not the active axes.

    :usage:
        >>> fig = plt.figure(figsize=(12, 12))
        >>> draw_neural_net(fig.gca(), .1, .9, .1, .9, [4, 7, 2])

    :parameters:
        - ax : matplotlib.axes.Axes
            The axes on which to plot the cartoon (get e.g. by plt.gca())
        - left : float
            The center of the leftmost node(s) will be placed here (moved
            right if needed to make room for the input labels)
        - right : float
            The center of the rightmost node(s) will be placed here (moved
            left if needed to make room for the output labels)
        - bottom : float
            The center of the bottommost node(s) will be placed here
        - top : float
            The center of the topmost node(s) will be placed here
        - layerSizes : sequence of int
            Number of nodes per layer, input layer first; at least two layers
        - inputPrefix : string or list of strings
            Prefix of input; a prefix p will show as $p_i$ where i is a
            1-based index. A list is used verbatim, one label per input node.
        - outputPrefix : string or list of strings
            Prefix of output; every letter "m" in the prefix is replaced by
            the 1-based output index and the result is wrapped in $...$.
            A list is used verbatim, one label per output node.
        - inLayerPrefix : string
            Name shown above the input layer ("" for none)
        - outLayerPrefix : string
            Name shown above the output layer ("" for none)
        - hiddenLayerPrefix : string
            Name shown above hidden layer n as $H_n$ ("" for none)
        - inNodePrefix : string
            Text inside input nodes; a prefix p will show as $p_i$ where i is
            an index ("" for no text)
        - otherNodePrefix : string or list of list of strings
            Text inside all nodes except the input and bias nodes.
            If string, it is reused for all nodes; to get automatic indexing
            include "_{m}" (must use 'm').
            If list of list, the outer list has one entry per non-input layer
            (an extra leading entry for the input layer is also accepted; it
            only influences the node size) and each inner list has one string
            per node of that layer.
            NB! use raw strings when including latex math notation
        - biasNodePrefix : string
            Text in bias nodes; to get automatic indexing include "_{m}"
            (must use 'm'), which numbers the bias nodes 1, 2, ... from left
            to right. NB! use raw strings when including latex math notation
        - weights : None or list of numpy.array of strings or floats
            One array per pair of consecutive layers; weights[n][m, o] is the
            label of the edge from node m of layer n to node o of layer n + 1.
            Strings may include latex math notation; numbers are rounded to
            4 decimals.
        - biases : None or list of lists of strings or floats
            biases[n][o] is the label of the edge from bias node n to node o
            of layer n + 1; numbers are rounded to 4 decimals.
        - epoch : int or string
            The epoch number, shown as "Steps: <epoch>" ("" to omit)
        - loss : float or string
            The value of the Loss/Cost function ("" to omit); numbers are
            rounded to 6 decimals
        - hideInOutPutNodes : True/False
            Hide input and output nodes and their arrows (e.g., when drawing
            only a neuron)
        - hideBias : True/False
            Hide bias nodes and edges
        - showLayerIndex : True/False
            Whether to show layer names
        - inputOutputColor : valid MatPlotLib color name
            Color of input and output arrows
        - nodeColor : valid MatPlotLib color name
            Background color of layer nodes
        - biasNodeColor : valid MatPlotLib color name
            Background color of bias nodes
        - edgeColor : valid MatPlotLib color name
            Color of the edges between layer nodes
        - biasEdgeColor : valid MatPlotLib color name
            Color of the edges from bias nodes
        - weightsColor : valid MatPlotLib color name
            Color of weight text
        - biasColor : valid MatPlotLib color name
            Color of bias text
        - nodeFontSize : int
            Fontsize of text inside and next to nodes
        - edgeFontSize : int
            Fontsize of edge text
        - edgeWidth : int
            Width of edge lines

    :raises:
        ValueError if layerSizes has fewer than two layers.
    '''
    n_layers = len(layerSizes)
    if n_layers < 2:
        # hSpacing below divides by (n_layers - 1).
        raise ValueError("layerSizes must contain at least two layers")
    vSpacing = (top - bottom) / float(max(layerSizes))
    hSpacing = (right - left) / float(n_layers - 1)

    # Labels written left of the input arrows.
    inputLabels = inputPrefix
    if not isinstance(inputLabels, list):
        inputLabels = [
            r'${}_{}$'.format(inputPrefix, m + 1) for m in range(layerSizes[0])
        ]

    # Labels written right of the output arrows; each "m" becomes the index.
    outputLabels = outputPrefix
    if not isinstance(outputLabels, list):
        outputLabels = [
            r'${}$'.format(outputPrefix.replace("m", str(m + 1)))
            for m in range(layerSizes[-1])
        ]

    # Node text: hidden[n][m] is the text of node m in layer n.
    hidden = otherNodePrefix
    if not isinstance(hidden, list):
        hidden = [[
            r'${}$'.format(otherNodePrefix.format(m=m + 1))
            if "{" in otherNodePrefix else otherNodePrefix
            for m in range(layerSizes[n])
        ] for n in range(n_layers)]
    elif len(hidden) == n_layers - 1:
        # Documented list form omits the input layer; pad it so hidden[n]
        # lines up with layer n (input-node text is set separately below).
        hidden = [[""] * layerSizes[0]] + hidden

    widthLimit = ax.get_window_extent(
    ).width  # widthLimit for figsize=(12,12)=669.6
    nodeLetterWidth = 0.0007 * nodeFontSize * 669.6 / widthLimit
    edgeLetterWidth = 0.0007 * edgeFontSize * 669.6 / widthLimit
    # Nodes are big enough for the longest node text, but at least hSpacing/8.
    nodeRadius = max(
        hSpacing / 8.,
        (max(lenMathString(max(x, key=lenMathString)) for x in hidden) + 1) *
        nodeLetterWidth / 2)
    biasRadius = max(hSpacing / 12.,
                     (lenMathString(biasNodePrefix) + 1) / 2 * nodeLetterWidth)

    nodePlusArrow = 0 if hideInOutPutNodes else 2 * nodeRadius
    # adjust left, right, bottom, top and spacing to fit input/output text
    inPad = (lenMathString(max(inputLabels, key=lenMathString)) +
             2) * nodeLetterWidth
    left = max(nodePlusArrow + inPad, left)

    outPad = (lenMathString(max(outputLabels, key=lenMathString)) +
              2) * nodeLetterWidth
    right = min(1.0 - nodePlusArrow - outPad, right)

    bottom = max(nodeRadius, bottom)
    top = min(1.0 - nodeRadius, top)

    vSpacing = (top - bottom) / float(max(layerSizes))
    hSpacing = (right - left) / float(n_layers - 1)

    # Gap between an input/output node and its outside label.
    inputOutputPad = 0 if hideInOutPutNodes else nodeRadius * 2

    def layerTop(layer_size):
        # y of the topmost node of a layer centred vertically in [bottom, top].
        return vSpacing * (layer_size - 1) / 2. + (top + bottom) / 2.

    # Input-Arrows: short horizontal arrows ending at each input node.
    if not hideInOutPutNodes:
        layer_top_0 = layerTop(layerSizes[0])
        for m in range(layerSizes[0]):
            xhead = left - nodeRadius
            yhead = layer_top_0 - m * vSpacing
            ax.annotate(
                "",
                xy=(xhead, yhead),
                xytext=(xhead - nodeRadius, yhead),
                xycoords='data',
                arrowprops=dict(arrowstyle=mpatches.ArrowStyle("simple",
                                                               head_length=0.4,
                                                               head_width=0.4),
                                color=inputOutputColor,
                                lw=edgeWidth))

    # Nodes
    for n, layer_size in enumerate(layerSizes):
        layer_top = layerTop(layer_size)
        x_node = n * hSpacing + left
        for m in range(layer_size):
            y_node = layer_top - m * vSpacing
            circle = mpatches.Circle((x_node, y_node),
                                     nodeRadius,
                                     facecolor=nodeColor,
                                     edgecolor='k',
                                     zorder=4)
            txt = hidden[n][m]
            x_label = x_node - lenMathString(txt) * nodeLetterWidth / 2
            y_label = y_node - 0.01
            layerTxt = ""

            if n == 0:
                # Input layer: label on the left, optional text in the node.
                if inLayerPrefix != "":
                    layerTxt = '${}$'.format(inLayerPrefix)
                ax.text(left - inputOutputPad - inPad,
                        y_node - nodeLetterWidth,
                        inputLabels[m],
                        fontsize=nodeFontSize,
                        zorder=2)
                txt = r'${}_{}$'.format(
                    inNodePrefix, m +
                    1) if inNodePrefix != "" else inNodePrefix
                if not hideInOutPutNodes:
                    ax.add_artist(circle)
                    x_label = x_node - lenMathString(txt) * nodeLetterWidth / 2
                    ax.text(x_label,
                            y_label,
                            txt,
                            fontsize=nodeFontSize,
                            zorder=8,
                            color='k')
            elif n == n_layers - 1:
                # Output layer: label on the right.
                if outLayerPrefix != "":
                    layerTxt = r"${}$".format(outLayerPrefix)
                ax.text(right + inputOutputPad,
                        y_node - nodeLetterWidth,
                        outputLabels[m],
                        fontsize=nodeFontSize)
                if not hideInOutPutNodes:
                    ax.add_artist(circle)
                    ax.text(x_label,
                            y_label,
                            txt,
                            fontsize=nodeFontSize,
                            zorder=8,
                            color='k')
            else:
                if hiddenLayerPrefix != "":
                    layerTxt = '$' + hiddenLayerPrefix + '_{' + str(n) + '}$'
                ax.add_artist(circle)
                ax.text(x_label,
                        y_label,
                        txt,
                        fontsize=nodeFontSize,
                        zorder=8,
                        color='k')
            if showLayerIndex and m == 0:
                ax.text(x_node,
                        y_node +
                        max(vSpacing / 8. + 0.01 * vSpacing, nodeRadius + 0.01),
                        layerTxt,
                        zorder=8,
                        fontsize=nodeFontSize)

    # Bias-Nodes: one between each pair of consecutive drawn layers.
    if not hideBias:
        # With hidden output nodes the last bias node is not drawn either.
        skip = 2 if hideInOutPutNodes else 1
        y_bias = top - 0.005
        for n in range(n_layers - skip):
            x_bias = (n + 0.5) * hSpacing + left
            circle = mpatches.Circle((x_bias, y_bias),
                                     biasRadius,
                                     facecolor=biasNodeColor,
                                     edgecolor='k',
                                     zorder=4)
            ax.add_artist(circle)
            txt = biasNodePrefix
            if "{" in biasNodePrefix:
                # Number the bias nodes 1, 2, ... from left to right.
                txt = r'${}$'.format(biasNodePrefix.format(m=n + 1))
            ax.text(x_bias - 0.015,
                    y_bias - 0.01,
                    txt,
                    fontsize=nodeFontSize,
                    zorder=8,
                    color='k')

    # Edges between nodes (with optional weight labels)
    edgeArrowStyle = mpatches.ArrowStyle(
        "->",
        head_length=10 * min(0.2, hSpacing / 5.),
        head_width=10 * min(0.1, hSpacing / 5.))
    for n, (layer_size_a,
            layer_size_b) in enumerate(zip(layerSizes[:-1], layerSizes[1:])):
        layer_top_a = layerTop(layer_size_a)
        layer_top_b = layerTop(layer_size_b)
        xm = n * hSpacing + left
        xo = (n + 1) * hSpacing + left
        for m in range(layer_size_a):
            ym = layer_top_a - m * vSpacing
            for o in range(layer_size_b):
                yo = layer_top_b - o * vSpacing
                ax.annotate("",
                            xy=(xo, yo),
                            xytext=(xm, ym),
                            xycoords='data',
                            arrowprops=dict(arrowstyle=edgeArrowStyle,
                                            shrinkB=nodeRadius * widthLimit,
                                            color=edgeColor,
                                            lw=edgeWidth))
                if weights is not None:
                    rot_mo_rad = np.arctan((yo - ym) / (xo - xm))
                    rot_mo_deg = rot_mo_rad * 180. / np.pi
                    label = weights[n][m, o]
                    if label != "":
                        if isinstance(label, numbers.Number):
                            label = round(label, 4)
                        label = "${}$".format(label)

                    # Place the label just past the source node, along the edge.
                    delta_x = max(vSpacing / 8., nodeRadius + 0.001)
                    delta_y = delta_x * abs(np.tan(rot_mo_rad))
                    epsilon = 0.01 / abs(np.cos(rot_mo_rad))
                    xm1 = xm + delta_x
                    if yo > ym:
                        label_skew = edgeLetterWidth * abs(np.sin(rot_mo_rad))
                        ym1 = ym + label_skew + delta_y + epsilon
                    elif yo < ym:
                        label_skew = lenMathString(
                            label) * edgeLetterWidth * abs(np.sin(rot_mo_rad))
                        ym1 = ym - label_skew - delta_y + epsilon
                    else:
                        ym1 = ym + epsilon

                    ax.text(xm1,
                            ym1,
                            label,
                            rotation=rot_mo_deg,
                            fontsize=edgeFontSize,
                            color=weightsColor,
                            zorder=10)

    # Edges between bias and nodes
    if not hideBias:
        # NOTE(review): bias nodes are drawn at top - 0.005 but their edges
        # start at top + 0.005; confirm this offset is intended.
        y_bias = top + 0.005
        for n, layer_size_b in enumerate(layerSizes[1:]):
            if hideInOutPutNodes and n == n_layers - 2:
                continue
            layer_top_b = layerTop(layer_size_b)
            x_bias = (n + 0.5) * hSpacing + left
            xo = left + (n + 1) * hSpacing
            for o in range(layer_size_b):
                yo = layer_top_b - o * vSpacing
                line = plt.Line2D([x_bias, xo], [y_bias, yo],
                                  c=biasEdgeColor,
                                  lw=edgeWidth)
                ax.add_artist(line)
                if biases is not None:
                    rot_bo_rad = np.arctan((yo - y_bias) / (xo - x_bias))
                    rot_bo_deg = rot_bo_rad * 180. / np.pi
                    label = biases[n][o]
                    if isinstance(label, numbers.Number):
                        label = round(label, 4)
                    label = "${}$".format(label)

                    # Place the label just before the target node.
                    label_skew = len(label) * edgeLetterWidth * abs(
                        np.sin(rot_bo_rad))
                    delta_x = max(vSpacing / 8., nodeRadius + 0.001)
                    delta_y = delta_x * abs(np.tan(rot_bo_rad))
                    epsilon = 0.01 / abs(np.cos(rot_bo_rad))

                    ax.text(xo - delta_x,
                            yo - label_skew + delta_y + epsilon,
                            label,
                            rotation=rot_bo_deg,
                            fontsize=edgeFontSize,
                            color=biasColor)

    # Output-Arrows: short horizontal arrows leaving each output node.
    if not hideInOutPutNodes:
        layer_top_0 = layerTop(layerSizes[-1])
        for m in range(layerSizes[-1]):
            xtail = right + nodeRadius
            ytail = layer_top_0 - m * vSpacing
            ax.annotate(
                "",
                fontsize=nodeFontSize,
                xy=(xtail + nodeRadius, ytail),
                xytext=(xtail, ytail),
                xycoords='data',
                arrowprops=dict(arrowstyle=mpatches.ArrowStyle("simple",
                                                               head_length=0.4,
                                                               head_width=0.4),
                                color=inputOutputColor,
                                lw=edgeWidth),
                zorder=0)

    # Record the epoch and loss below the network.
    if isinstance(epoch, numbers.Real):
        epoch = round(epoch, 6)
    if isinstance(loss, numbers.Real):
        loss = round(loss, 6)
    txt = ""
    if epoch != "":
        txt = "Steps: {}".format(epoch)
    if loss != "":
        txt = "{}    Loss: {}".format(txt, loss)
    ax.text(left + (right - left) / 3.,
            bottom - 0.005 * vSpacing,
            txt,
            fontsize=nodeFontSize)
Example #13
0
    def render(self, mode='human', output_file=None):
        """Visualise the recorded multi-agent episode with matplotlib.

        Args:
            mode: ``'traj'`` draws a static plot of every agent's start,
                goal and trajectory (outlines every 4th state plus the final
                one); ``'video'`` animates the episode. Any other value --
                including the default ``'human'`` -- raises
                ``NotImplementedError``.
            output_file: In ``'video'`` mode, path of the video file written
                with ``animation.FFMpegWriter``. When ``None`` the animation
                is shown interactively instead.

        Raises:
            NotImplementedError: if ``mode`` is neither ``'traj'`` nor
                ``'video'``.
        """
        from matplotlib import animation
        import matplotlib.pyplot as plt
        # plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'

        # Offsets (data units) of the per-agent text labels from the centre.
        x_offset = 0.2
        y_offset = 0.4
        # ``plt.get_cmap`` instead of ``plt.cm.get_cmap``: the latter was
        # removed in Matplotlib 3.9. Agent i is drawn with colour cmap(i).
        cmap = plt.get_cmap('hsv', 10)
        display_numbers = True

        def add_start_goal_markers(axes, markersize):
            """Mark every agent's goal with a star and its start with a dot."""
            for i, agent in enumerate(self.agents):
                goal = agent.get_goal_position()
                start = agent.get_start_position()
                axes.add_artist(
                    mlines.Line2D([goal[0]], [goal[1]],
                                  color=cmap(i),
                                  marker='*',
                                  linestyle='None',
                                  markersize=markersize))
                axes.add_artist(
                    mlines.Line2D([start[0]], [start[1]],
                                  color=cmap(i),
                                  marker='o',
                                  linestyle='None',
                                  markersize=markersize))

        if mode == 'traj':
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.tick_params(labelsize=16)
            ax.set_xlim(-5, 5)
            ax.set_ylim(-5, 5)
            ax.set_xlabel('x(m)', fontsize=16)
            ax.set_ylabel('y(m)', fontsize=16)

            add_start_goal_markers(ax, markersize=15)

            num_agents = len(self.agents)
            # agent_positions[k][i] is the position of agent i in state k.
            agent_positions = [[state[j].position for j in range(num_agents)]
                               for state in self.states]
            last_k = len(self.states) - 1

            for k, positions in enumerate(agent_positions):
                # outline every agent every 4th state and at the final state
                if k % 4 == 0 or k == last_k:
                    for i, position in enumerate(positions):
                        ax.add_artist(
                            plt.Circle(position,
                                       self.agents[i].radius,
                                       fill=False,
                                       color=cmap(i)))

                # Time label at each agent's *current* position (the snapshot
                # outlines above may be from an earlier state).
                # NOTE(review): exact float modulo -- only fires reliably when
                # 4 is an exact binary multiple of time_step (e.g. 0.25).
                global_time = k * self.time_step
                if global_time % 4 == 0 or k == last_k:
                    for position in positions:
                        ax.text(position[0] - x_offset,
                                position[1] - y_offset,
                                '{:.1f}'.format(global_time),
                                color='black',
                                fontsize=14)

                # segment from the previous state to this one for each agent
                if k != 0:
                    prev_state, state = self.states[k - 1], self.states[k]
                    for i in range(num_agents):
                        ax.add_artist(
                            plt.Line2D((prev_state[i].px, state[i].px),
                                       (prev_state[i].py, state[i].py),
                                       color=cmap(i),
                                       ls='solid'))
            plt.show()

        elif mode == 'video':
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.tick_params(labelsize=12)
            ax.set_xlim(-11, 11)
            ax.set_ylim(-11, 11)
            ax.set_xlabel('x(m)', fontsize=14)
            ax.set_ylabel('y(m)', fontsize=14)

            add_start_goal_markers(ax, markersize=8)

            num_agents = len(self.agents)
            agent_positions = [[state[j].position for j in range(num_agents)]
                               for state in self.states]

            # agent outlines, moved in place by update()
            agents = [
                plt.Circle(agent_positions[0][i],
                           self.agents[i].radius,
                           fill=False,
                           color=cmap(i)) for i in range(num_agents)
            ]
            for circle in agents:
                ax.add_artist(circle)

            # ``ax.text`` already attaches each label to ``ax``.
            agent_numbers = []
            if display_numbers:
                agent_numbers = [
                    ax.text(circle.center[0] - x_offset,
                            circle.center[1] + y_offset,
                            str(i),
                            color='black') for i, circle in enumerate(agents)
                ]

            # clock in axes coordinates
            time_text = ax.text(0.4,
                                0.9,
                                'Time: {}'.format(0),
                                fontsize=16,
                                transform=ax.transAxes)

            def update(frame_num):
                """Move every agent (and its number) to ``frame_num``."""
                for i, circle in enumerate(agents):
                    circle.center = agent_positions[frame_num][i]
                    if display_numbers:
                        agent_numbers[i].set_position(
                            (circle.center[0] - x_offset,
                             circle.center[1] + y_offset))
                time_text.set_text('Time: {:.2f}'.format(frame_num *
                                                         self.time_step))

            def on_click(event):
                """Pause/resume the animation on any key press."""
                if anim.running:
                    anim.event_source.stop()
                else:
                    anim.event_source.start()
                anim.running ^= True

            fig.canvas.mpl_connect('key_press_event', on_click)
            anim = animation.FuncAnimation(fig,
                                           update,
                                           frames=len(self.states),
                                           interval=self.time_step * 500)
            anim.running = True

            if output_file is not None:
                # save as video
                ffmpeg_writer = animation.FFMpegWriter(
                    fps=10, metadata=dict(artist='Me'), bitrate=1800)
                anim.save(output_file, writer=ffmpeg_writer)
            else:
                plt.show()
        else:
            raise NotImplementedError
Example #14
0
    def render(self, mode='human', output_file=None):
        """Visualise the recorded robot/crowd episode with matplotlib.

        Args:
            mode: ``'human'`` draws the current robot and human positions;
                ``'traj'`` draws a static plot of the episode (snapshots every
                4th state plus the final one); ``'video'`` animates the
                episode with heading arrows, optional attention colouring
                (``self.attention_weights``) and GRU-usage bar charts
                (``self.step_list``).
            output_file: In ``'video'`` mode, path of the video written with
                the ffmpeg writer. When ``None`` the animation is shown.

        Raises:
            NotImplementedError: if ``mode`` is not one of the above.

        NOTE(review): ``plt.style.use('dark_background')`` and the
        ``animation.ffmpeg_path`` rcParam are set process-wide and persist
        after this call returns.
        """
        from matplotlib import animation
        import matplotlib as mpl
        import matplotlib.pyplot as plt
        plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'

        use_dark = True

        def num2color(values, cmap):
            """Map ``values`` to colours of the colormap named ``cmap``,
            normalised linearly between their minimum and maximum."""
            norm = mpl.colors.Normalize(vmin=np.min(values),
                                        vmax=np.max(values))
            colormap = plt.get_cmap(cmap)
            return [colormap(norm(val)) for val in values]

        # Offsets (data units) of the text labels from each agent's centre.
        x_offset = 0.11
        y_offset = 0.11
        # ``plt.get_cmap`` replaces ``plt.cm.get_cmap`` / ``mpl.cm.get_cmap``,
        # which were removed in Matplotlib 3.9.
        cmap = plt.get_cmap('hsv', 10)
        robot_color = 'yellow'
        goal_color = 'red'
        arrow_color = 'red'
        arrow_style = patches.ArrowStyle("->", head_length=5, head_width=2)

        if mode == 'human':
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.set_xlim(-4, 4)
            ax.set_ylim(-4, 4)
            for human in self.humans:
                human_circle = plt.Circle(human.get_position(),
                                          human.radius,
                                          fill=True,
                                          color='b')
                ax.add_artist(human_circle)
            ax.add_artist(
                plt.Circle(self.robot.get_position(),
                           self.robot.radius,
                           fill=True,
                           color='r'))
            plt.show()
        elif mode == 'traj':
            if use_dark:
                plt.style.use('dark_background')
            fig, ax = plt.subplots(figsize=(7, 7))

            ax.tick_params(labelsize=16)
            ax.set_xlim(-5, 5)
            ax.set_ylim(-5, 5)
            ax.set_xlabel('x(m)', fontsize=16)
            ax.set_ylabel('y(m)', fontsize=16)

            num_humans = len(self.humans)
            robot_positions = [state[0].position for state in self.states]
            human_positions = [[
                state[1][j].position for j in range(num_humans)
            ] for state in self.states]
            last_k = len(self.states) - 1

            for k in range(len(self.states)):
                # snapshot of robot and humans every 4th state and at the end
                if k % 4 == 0 or k == last_k:
                    robot = plt.Circle(robot_positions[k],
                                       self.robot.radius,
                                       fill=True,
                                       color=robot_color)
                    ax.add_artist(robot)
                    for i in range(num_humans):
                        ax.add_artist(
                            plt.Circle(human_positions[k][i],
                                       self.humans[i].radius,
                                       fill=True,
                                       color=cmap(i)))

                # Time label at every agent's *current* position (humans
                # first, then the robot); the snapshot above may be stale.
                # NOTE(review): exact float modulo -- only fires reliably when
                # 4 is an exact binary multiple of time_step (e.g. 0.25).
                global_time = k * self.time_step
                if global_time % 4 == 0 or k == last_k:
                    for position in human_positions[k] + [robot_positions[k]]:
                        ax.text(position[0] - x_offset,
                                position[1] - y_offset,
                                '{:.1f}'.format(global_time),
                                color='black',
                                fontsize=14)

                # segment from the previous state to this one per agent
                if k != 0:
                    prev_state, state = self.states[k - 1], self.states[k]
                    ax.add_artist(
                        plt.Line2D((prev_state[0].px, state[0].px),
                                   (prev_state[0].py, state[0].py),
                                   color=robot_color,
                                   ls='solid'))
                    for i in range(num_humans):
                        ax.add_artist(
                            plt.Line2D(
                                (prev_state[1][i].px, state[1][i].px),
                                (prev_state[1][i].py, state[1][i].py),
                                color=cmap(i),
                                ls='solid'))
            plt.legend([robot], ['Robot'], fontsize=16)
            plt.show()
        elif mode == 'video':
            use_rate_statastic = True
            if use_dark:
                plt.style.use('dark_background')
            if use_rate_statastic:
                # main scene, two GRU-usage bar charts and an attention bar
                fig = plt.figure()
                ax = plt.axes([0.12, 0.23, 0.55, 0.7])
                ax2 = plt.axes([0.82, 0.6, 0.15, 0.3])
                ax3 = plt.axes([0.82, 0.15, 0.15, 0.3])
                ax4 = plt.axes([0.08, 0.08, 0.6, 0.05])
                fontsize = 12
            else:
                fig, ax = plt.subplots(figsize=(7, 7))
                ax4 = None  # no axes reserved for the attention colour bar
                fontsize = 16
            ax.tick_params(labelsize=fontsize)
            ax.set_xlim(-6, 6)
            ax.set_ylim(-6, 6)
            ax.set_xlabel('x(m)', fontsize=12)
            ax.set_ylabel('y(m)', fontsize=12)

            # add robot and its goal
            robot_positions = [state[0].position for state in self.states]
            goal = mlines.Line2D([0], [4],
                                 color=goal_color,
                                 marker='*',
                                 linestyle='None',
                                 markersize=15,
                                 label='Goal')
            robot = plt.Circle(robot_positions[0],
                               self.robot.radius,
                               fill=True,
                               color=robot_color,
                               label='Robot')
            ax.add_artist(robot)
            ax.add_artist(goal)
            ax.legend([robot, goal], ['Robot', 'Goal'],
                      fontsize=12,
                      loc='upper left')

            # add humans and their numbers (``ax.text`` attaches to ``ax``)
            num_humans = len(self.humans)
            human_positions = [[
                state[1][j].position for j in range(num_humans)
            ] for state in self.states]
            humans = [
                plt.Circle(human_positions[0][i],
                           self.humans[i].radius,
                           fill=True,
                           color='lime') for i in range(num_humans)
            ]
            for human in humans:
                ax.add_artist(human)
            human_numbers = [
                ax.text(human.center[0] - x_offset,
                        human.center[1] - y_offset,
                        str(i),
                        color='white',
                        fontsize=fontsize) for i, human in enumerate(humans)
            ]

            # add time annotation
            time_text = ax.text(-1, 5, 'Time: {}'.format(0), fontsize=16)

            # attention visualisation: human colours follow their weights
            show_txt_att = False
            cmap_name = 'winter'
            attention_scores = []
            if self.attention_weights is not None:
                if show_txt_att:
                    attention_scores = [
                        ax.text(-5.5,
                                5 - 0.5 * i,
                                'Human {}: {:.2f}'.format(
                                    i + 1, self.attention_weights[0][0][i]),
                                fontsize=10) for i in range(num_humans)
                    ]
                if ax4 is not None:
                    # horizontal colour bar for the normalised weights
                    attention_cmap = plt.get_cmap(cmap_name)
                    colors = attention_cmap(
                        np.linspace(0, 1, attention_cmap.N))
                    ax4.imshow([colors], extent=[-7, 6, 0, 1])
                    # NOTE(review): labels are applied to matplotlib's
                    # auto-placed ticks (no explicit set_xticks) -- confirm
                    # they line up with the intended values.
                    ax4.set_xticklabels([
                        '0.0', '0.1', '0.25', '0.4', '0.55', '0.7', '0.85',
                        '1.0'
                    ],
                                        fontsize=10)
                    ax4.set_yticks([])

            # heading arrows: one per tracked agent, each of length `radius`
            radius = self.robot.radius

            def heading_segment(agent_state, theta):
                """Return ((x0, y0), (x1, y1)) from the agent along theta."""
                return ((agent_state.px, agent_state.py),
                        (agent_state.px + radius * np.cos(theta),
                         agent_state.py + radius * np.sin(theta)))

            if self.robot.kinematics == 'unicycle':
                # only the robot, pointing along its own heading
                orientations = [[
                    heading_segment(state[0], state[0].theta)
                    for state in self.states
                ]]
            else:
                # robot (index 0) and every human, pointing along velocity
                orientations = []
                for i in range(num_humans + 1):
                    track = [
                        state[0] if i == 0 else state[1][i - 1]
                        for state in self.states
                    ]
                    orientations.append([
                        heading_segment(s, np.arctan2(s.vy, s.vx))
                        for s in track
                    ])

            def make_arrows(frame_num):
                """Create, attach and return the heading arrows of a frame."""
                new_arrows = [
                    patches.FancyArrowPatch(*orientation[frame_num],
                                            color=arrow_color,
                                            arrowstyle=arrow_style)
                    for orientation in orientations
                ]
                for arrow in new_arrows:
                    ax.add_artist(arrow)
                return new_arrows

            arrows = make_arrows(0)

            # Running GRU-usage totals across frames; each step_list entry is
            # [0, 1-egru, 2-egru, 3-egru].
            egru_total = 0
            egru_counts = [0, 0, 0]
            if use_rate_statastic:
                # bar 0 is padding, hidden by the x-limits below
                x_bar = np.arange(4)

                # ax2: GRU usage of the current step
                ax2.set_xlabel('Number of GRUs', fontsize=10)
                ax2.set_ylabel('Usage rate (%)', fontsize=10)
                ax2.set_title("Each step", fontsize=10)
                bar_Step_EGRU = ax2.bar(x_bar, [0, 0, 0, 0])
                ax2.set_xlim(0.5, 3.5)
                ax2.set_ylim(0, 1)
                # ax3: cumulative usage rate over all frames shown so far
                ax3.set_xlabel('Number of GRUs', fontsize=10)
                ax3.set_ylabel('Usage rate (%)', fontsize=10)
                ax3.set_title("Total steps", fontsize=10)
                bar_Rate_EGRU = ax3.bar(x_bar, [0, 0, 0, 0])
                ax3.set_xlim(0.5, 3.5)
                ax3.set_ylim(0., 100)
                # bar colours are static, so set them once
                for idx, color in ((1, '#F6BB36'), (2, '#25A1FA'),
                                   (3, '#8FE37C')):
                    bar_Step_EGRU[idx].set_color(color)
                    bar_Rate_EGRU[idx].set_color(color)
            global_step = 0

            def update(frame_num):
                """Advance every artist to ``frame_num``; returns the bar
                rectangles when the GRU statistics are shown."""
                nonlocal global_step, arrows, egru_total
                global_step = frame_num
                robot.center = robot_positions[frame_num]
                for i, human in enumerate(humans):
                    human.center = human_positions[frame_num][i]
                    human_numbers[i].set_position((human.center[0] - x_offset,
                                                   human.center[1] - y_offset))

                # replace the heading arrows once per frame
                for arrow in arrows:
                    arrow.remove()
                arrows = make_arrows(frame_num)

                # presumably row 0 holds the robot's attention over humans
                if self.attention_weights is not None and humans:
                    weight = self.attention_weights[frame_num][0, :]
                    colors = num2color(weight, cmap_name)
                    for i, human in enumerate(humans):
                        human.set_color(colors[i])
                        if show_txt_att:
                            attention_scores[i].set_text(
                                'human {}: {:.2f}'.format(i, weight[i]))

                time_text.set_text('Time: {:.2f}'.format(frame_num *
                                                         self.time_step))
                if not use_rate_statastic:
                    return None

                if frame_num >= len(self.step_list):
                    # past the recorded steps: restart the running statistics
                    egru_total = 0.0
                    egru_counts[:] = [0.0, 0.0, 0.0]
                    step_egrus = [0, 0, 0, 0]
                else:
                    step_egrus = self.step_list[frame_num]
                egru_total += sum(step_egrus)
                for n in range(3):
                    egru_counts[n] += step_egrus[n + 1]
                if egru_total == 0:
                    rate_egrus = [0, 0.0, 0.0, 0.0]
                else:
                    rate_egrus = [0] + [
                        count / egru_total * 100 for count in egru_counts
                    ]

                for s_rect, r_rect, s, r in zip(bar_Step_EGRU, bar_Rate_EGRU,
                                                step_egrus, rate_egrus):
                    s_rect.set_height(s)
                    r_rect.set_height(r)
                return list(bar_Step_EGRU) + list(bar_Rate_EGRU)

            def plot_value_heatmap():
                """Print agent states at the current frame and plot the
                robot's action values as a polar heatmap."""
                assert self.robot.kinematics == 'holonomic'
                for agent in [self.states[global_step][0]
                              ] + self.states[global_step][1]:
                    print(('{:.4f}, ' * 6 + '{:.4f}').format(
                        agent.px, agent.py, agent.gx, agent.gy, agent.vx,
                        agent.vy, agent.theta))
                # A single polar axes; a separate plt.subplot call would leave
                # an empty cartesian axes underneath on Matplotlib >= 3.6.
                heat_fig, polar = plt.subplots(
                    subplot_kw={'projection': 'polar'})
                speeds = [0] + self.robot.policy.speeds
                rotations = np.hstack([self.robot.policy.rotations, np.pi * 2])
                r, th = np.meshgrid(speeds, rotations)
                z = np.array(self.action_values[global_step %
                                                len(self.states)][1:])
                # normalise to [0, 1]; a constant map would divide by zero
                z_range = np.max(z) - np.min(z)
                z = (z - np.min(z)) / z_range if z_range else np.zeros_like(z)
                z = np.reshape(z, (len(rotations) - 1, len(speeds) - 1))
                polar.tick_params(labelsize=16)
                mesh = polar.pcolormesh(th, r, z, vmin=0, vmax=1)
                polar.plot(rotations, r, color='k', ls='none')
                polar.grid()
                cbaxes = heat_fig.add_axes([0.85, 0.1, 0.03, 0.8])
                cbar = plt.colorbar(mesh, cax=cbaxes)
                cbar.ax.tick_params(labelsize=16)
                plt.show()

            def on_click(event):
                """Pause/resume on key press; when pausing, also show the
                action-value heatmap if the policy exposes one."""
                if anim.running:
                    anim.event_source.stop()
                    # update the flag before the (possibly blocking) plot
                    anim.running = False
                    if hasattr(self.robot.policy, 'action_values'):
                        plot_value_heatmap()
                else:
                    anim.event_source.start()
                    anim.running = True

            fig.canvas.mpl_connect('key_press_event', on_click)
            anim = animation.FuncAnimation(fig,
                                           update,
                                           frames=len(self.states),
                                           interval=self.time_step * 800)
            anim.running = True

            if output_file is not None:
                ffmpeg_writer = animation.writers['ffmpeg']
                writer = ffmpeg_writer(fps=10,
                                       metadata=dict(artist='Me'),
                                       bitrate=2000)
                anim.save(output_file, writer=writer)
            else:
                plt.show()
        else:
            raise NotImplementedError
    def render(self, mode="video", output_file=None):
        from matplotlib import animation
        import matplotlib.pyplot as plt

        # plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
        x_offset = 0.2
        y_offset = 0.4
        cmap = plt.cm.get_cmap("hsv", 10)
        robot_color = "black"
        arrow_style = patches.ArrowStyle("->", head_length=4, head_width=2)
        display_numbers = True

        if mode == "traj":
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.tick_params(labelsize=16)
            ax.set_xlim(-5, 5)
            ax.set_ylim(-5, 5)
            ax.set_xlabel("x(m)", fontsize=16)
            ax.set_ylabel("y(m)", fontsize=16)

            # add human start positions and goals
            human_colors = [cmap(i) for i in range(len(self.humans))]
            for i in range(len(self.humans)):
                human = self.humans[i]
                human_goal = mlines.Line2D(
                    [human.get_goal_position()[0]],
                    [human.get_goal_position()[1]],
                    color=human_colors[i],
                    marker="*",
                    linestyle="None",
                    markersize=15,
                )
                ax.add_artist(human_goal)
                human_start = mlines.Line2D(
                    [human.get_start_position()[0]],
                    [human.get_start_position()[1]],
                    color=human_colors[i],
                    marker="o",
                    linestyle="None",
                    markersize=15,
                )
                ax.add_artist(human_start)

            robot_positions = [
                self.states[i][0].position for i in range(len(self.states))
            ]
            human_positions = [
                [self.states[i][1][j].position for j in range(len(self.humans))]
                for i in range(len(self.states))
            ]

            for k in range(len(self.states)):
                if k % 4 == 0 or k == len(self.states) - 1:
                    robot = plt.Circle(
                        robot_positions[k],
                        self.robot.radius,
                        fill=False,
                        color=robot_color,
                    )
                    humans = [
                        plt.Circle(
                            human_positions[k][i],
                            self.humans[i].radius,
                            fill=False,
                            color=cmap(i),
                        )
                        for i in range(len(self.humans))
                    ]
                    ax.add_artist(robot)
                    for human in humans:
                        ax.add_artist(human)

                # add time annotation
                global_time = k * self.time_step
                if global_time % 4 == 0 or k == len(self.states) - 1:
                    agents = humans + [robot]
                    times = [
                        plt.text(
                            agents[i].center[0] - x_offset,
                            agents[i].center[1] - y_offset,
                            "{:.1f}".format(global_time),
                            color="black",
                            fontsize=14,
                        )
                        for i in range(self.human_num + 1)
                    ]
                    for time in times:
                        ax.add_artist(time)
                if k != 0:
                    nav_direction = plt.Line2D(
                        (self.states[k - 1][0].px, self.states[k][0].px),
                        (self.states[k - 1][0].py, self.states[k][0].py),
                        color=robot_color,
                        ls="solid",
                    )
                    human_directions = [
                        plt.Line2D(
                            (self.states[k - 1][1][i].px, self.states[k][1][i].px),
                            (self.states[k - 1][1][i].py, self.states[k][1][i].py),
                            color=cmap(i),
                            ls="solid",
                        )
                        for i in range(self.human_num)
                    ]
                    ax.add_artist(nav_direction)
                    for human_direction in human_directions:
                        ax.add_artist(human_direction)
            plt.legend([robot], ["Robot"], fontsize=16)
            plt.show()
        elif mode == "video":
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.tick_params(labelsize=12)
            ax.set_xlim(-11, 11)
            ax.set_ylim(-11, 11)
            ax.set_xlabel("x(m)", fontsize=14)
            ax.set_ylabel("y(m)", fontsize=14)
            show_human_start_goal = False

            # add human start positions and goals
            human_colors = [cmap(i) for i in range(len(self.humans))]
            if show_human_start_goal:
                for i in range(len(self.humans)):
                    human = self.humans[i]
                    human_goal = mlines.Line2D(
                        [human.get_goal_position()[0]],
                        [human.get_goal_position()[1]],
                        color=human_colors[i],
                        marker="*",
                        linestyle="None",
                        markersize=8,
                    )
                    ax.add_artist(human_goal)
                    human_start = mlines.Line2D(
                        [human.get_start_position()[0]],
                        [human.get_start_position()[1]],
                        color=human_colors[i],
                        marker="o",
                        linestyle="None",
                        markersize=8,
                    )
                    ax.add_artist(human_start)
            # add robot start position
            robot_start = mlines.Line2D(
                [self.robot.get_start_position()[0]],
                [self.robot.get_start_position()[1]],
                color=robot_color,
                marker="o",
                linestyle="None",
                markersize=8,
            )
            robot_start_position = [
                self.robot.get_start_position()[0],
                self.robot.get_start_position()[1],
            ]
            ax.add_artist(robot_start)
            # add robot and its goal
            robot_positions = [state[0].position for state in self.states]
            goal = mlines.Line2D(
                [self.robot.get_goal_position()[0]],
                [self.robot.get_goal_position()[1]],
                color=robot_color,
                marker="*",
                linestyle="None",
                markersize=15,
                label="Goal",
            )
            robot = plt.Circle(
                robot_positions[0], self.robot.radius, fill=False, color=robot_color
            )
            # sensor_range = plt.Circle(robot_positions[0], self.robot_sensor_range, fill=False, ls='dashed')
            ax.add_artist(robot)
            ax.add_artist(goal)
            plt.legend([robot, goal], ["Robot", "Goal"], fontsize=14)

            # add humans and their numbers
            human_positions = [
                [state[1][j].position for j in range(len(self.humans))]
                for state in self.states
            ]
            humans = [
                plt.Circle(
                    human_positions[0][i],
                    self.humans[i].radius,
                    fill=False,
                    color=cmap(i),
                )
                for i in range(len(self.humans))
            ]

            # disable showing human numbers
            if display_numbers:
                human_numbers = [
                    plt.text(
                        humans[i].center[0] - x_offset,
                        humans[i].center[1] + y_offset,
                        str(i),
                        color="black",
                    )
                    for i in range(len(self.humans))
                ]

            for i, human in enumerate(humans):
                ax.add_artist(human)
                if display_numbers:
                    ax.add_artist(human_numbers[i])

            # add time annotation
            time = plt.text(
                0.4, 0.9, "Time: {}".format(0), fontsize=16, transform=ax.transAxes
            )
            ax.add_artist(time)

            # visualize attention scores
            # if hasattr(self.robot.policy, 'get_attention_weights'):
            #     attention_scores = [
            #         plt.text(-5.5, 5 - 0.5 * i, 'Human {}: {:.2f}'.format(i + 1, self.attention_weights[0][i]),
            #                  fontsize=16) for i in range(len(self.humans))]

            # compute orientation in each step and use arrow to show the direction
            radius = self.robot.radius
            orientations = []
            for i in range(self.human_num + 1):
                orientation = []
                for state in self.states:
                    agent_state = state[0] if i == 0 else state[1][i - 1]
                    if self.robot.kinematics == "unicycle" and i == 0:
                        direction = (
                            (agent_state.px, agent_state.py),
                            (
                                agent_state.px + radius * np.cos(agent_state.theta),
                                agent_state.py + radius * np.sin(agent_state.theta),
                            ),
                        )
                    else:
                        theta = np.arctan2(agent_state.vy, agent_state.vx)
                        direction = (
                            (agent_state.px, agent_state.py),
                            (
                                agent_state.px + radius * np.cos(theta),
                                agent_state.py + radius * np.sin(theta),
                            ),
                        )
                    orientation.append(direction)
                orientations.append(orientation)
                if i == 0:
                    arrow_color = "black"
                    arrows = [
                        patches.FancyArrowPatch(
                            *orientation[0], color=arrow_color, arrowstyle=arrow_style
                        )
                    ]
                else:
                    arrows.extend(
                        [
                            patches.FancyArrowPatch(
                                *orientation[0],
                                color=human_colors[i - 1],
                                arrowstyle=arrow_style
                            )
                        ]
                    )

            for arrow in arrows:
                ax.add_artist(arrow)
            global_step = 0

            if len(self.trajs) != 0:
                human_future_positions = []
                human_future_circles = []
                for traj in self.trajs:
                    human_future_position = [
                        [
                            tensor_to_joint_state(traj[step + 1][0])
                            .human_states[i]
                            .position
                            for step in range(self.robot.policy.planning_depth)
                        ]
                        for i in range(self.human_num)
                    ]
                    human_future_positions.append(human_future_position)

                for i in range(self.human_num):
                    circles = []
                    for j in range(self.robot.policy.planning_depth):
                        circle = plt.Circle(
                            human_future_positions[0][i][j],
                            self.humans[0].radius / (1.7 + j),
                            fill=False,
                            color=cmap(i),
                        )
                        ax.add_artist(circle)
                        circles.append(circle)
                    human_future_circles.append(circles)

            def update(frame_num):
                nonlocal global_step
                nonlocal arrows
                global_step = frame_num
                robot.center = robot_positions[frame_num]

                for i, human in enumerate(humans):
                    human.center = human_positions[frame_num][i]
                    if display_numbers:
                        human_numbers[i].set_position(
                            (human.center[0] - x_offset, human.center[1] + y_offset)
                        )
                for arrow in arrows:
                    arrow.remove()

                for i in range(self.human_num + 1):
                    orientation = orientations[i]
                    if i == 0:
                        arrows = [
                            patches.FancyArrowPatch(
                                *orientation[frame_num],
                                color="black",
                                arrowstyle=arrow_style
                            )
                        ]
                    else:
                        arrows.extend(
                            [
                                patches.FancyArrowPatch(
                                    *orientation[frame_num],
                                    color=cmap(i - 1),
                                    arrowstyle=arrow_style
                                )
                            ]
                        )

                for arrow in arrows:
                    ax.add_artist(arrow)
                    # if hasattr(self.robot.policy, 'get_attention_weights'):
                    #     attention_scores[i].set_text('human {}: {:.2f}'.format(i, self.attention_weights[frame_num][i]))

                time.set_text("Time: {:.2f}".format(frame_num * self.time_step))

                if len(self.trajs) != 0:
                    for i, circles in enumerate(human_future_circles):
                        for j, circle in enumerate(circles):
                            circle.center = human_future_positions[global_step][i][j]

            def plot_value_heatmap():
                """Draw a polar heat map of the robot's action values at ``global_step``.

                Only holonomic kinematics are supported. For any other kinematics
                a message is printed and nothing is drawn. The action values for
                the current frame are min-max normalised to [0, 1]. They are then
                laid out on a (rotation, speed) grid built from the policy's
                ``rotations`` and ``speeds`` and shown with a colour bar.
                """
                if self.robot.kinematics != "holonomic":
                    print("Kinematics is not holonomic")
                    return

                policy = self.robot.policy
                speeds = [0] + policy.speeds
                # The extra 2*pi boundary closes the mesh around the full circle.
                rotations = policy.rotations + [np.pi * 2]
                r, th = np.meshgrid(speeds, rotations)

                # The first entry is excluded from the grid. Presumably it is an
                # action outside the (rotation, speed) grid; confirm with the
                # policy that fills self.action_values.
                z = np.array(self.action_values[global_step % len(self.states)][1:])
                z_min, z_max = np.min(z), np.max(z)
                z_span = z_max - z_min
                # A flat value map would otherwise divide by zero and give NaNs.
                if z_span > 0:
                    z = (z - z_min) / z_span
                else:
                    z = np.zeros_like(z, dtype=float)
                z = np.reshape(z, (policy.rotation_samples, policy.speed_samples))

                # Create the polar axes directly, so no stray Cartesian axes is
                # left underneath it.
                fig = plt.figure()
                polar = fig.add_subplot(projection="polar")
                polar.tick_params(labelsize=16)
                mesh = polar.pcolormesh(th, r, z, vmin=0, vmax=1)
                polar.plot(rotations, r, color="k", ls="none")
                polar.grid()
                cbaxes = fig.add_axes([0.85, 0.1, 0.03, 0.8])
                cbar = fig.colorbar(mesh, cax=cbaxes)
                cbar.ax.tick_params(labelsize=16)
                plt.show()

            def print_matrix_A():
                """Print the policy matrix ``A`` for the current frame as a labelled table.

                Each row and column is labelled with its index minus one, so
                the first row and first column are labelled -1. Entries are
                printed with three decimals.
                """
                matrix = self.As[global_step]
                n_rows, n_cols = matrix.shape
                header_cells = " ".join("{:>5}".format(col - 1) for col in range(n_cols))
                print("   " + header_cells)
                for row in range(n_rows):
                    row_label = "{:<3}".format(row - 1)
                    cells = " ".join("{:.3f}".format(matrix[row][col]) for col in range(n_cols))
                    print(row_label + cells)

            def print_feat():
                """Print the policy features recorded for the current frame (3 decimals, no sci notation)."""
                with np.printoptions(suppress=True, precision=3):
                    print("feat is: ")
                    print(self.feats[global_step])

            def print_X():
                """Print the policy ``X`` value recorded for the current frame (3 decimals, no sci notation)."""
                with np.printoptions(suppress=True, precision=3):
                    print("X is: ")
                    print(self.Xs[global_step])

            def on_click(event):
                """Pause or resume the animation on any key press.

                If the key is pressed while the animation is playing, playback
                stops. When that key is ``a``, the handler also prints whichever
                diagnostics the policy exposes (``get_matrix_A``, ``get_feat``,
                ``get_X``) for the current frame. Any key pressed while paused
                resumes playback.
                """
                if not anim.running:
                    anim.event_source.start()
                else:
                    anim.event_source.stop()
                    if event.key == "a":
                        diagnostics = (
                            ("get_matrix_A", print_matrix_A),
                            ("get_feat", print_feat),
                            ("get_X", print_X),
                        )
                        for probe, dump in diagnostics:
                            if hasattr(self.robot.policy, probe):
                                dump()
                anim.running = not anim.running

            fig.canvas.mpl_connect("key_press_event", on_click)
            anim = animation.FuncAnimation(
                fig, update, frames=len(self.states), interval=self.time_step * 500
            )
            anim.running = True

            if output_file is not None:
                # save as video
                ffmpeg_writer = animation.FFMpegWriter(
                    fps=10, metadata=dict(artist="Me"), bitrate=1800
                )
                # writer = ffmpeg_writer(fps=10, metadata=dict(artist='Me'), bitrate=1800)
                anim.save(output_file, writer=ffmpeg_writer)

                # save output file as gif if imagemagic is installed
                # anim.save(output_file, writer='imagemagic', fps=12)
            else:
                plt.show()
        else:
            raise NotImplementedError