コード例 #1
0
ファイル: myplt.py プロジェクト: Yamanaka-Lab-TUAT/DNN-NMT
def draw_all(texture, erase_number=True):
    """Plot every row of *texture* as a blue dot in a 360 x 180 x 360 box.

    Draws onto the module-level 3-D axes ``ax`` (defined outside this
    function). The x and z ranges are 0-360, the y range is 0-180, and the
    y and z axes are inverted. Ticks are placed every 90 units.

    Args:
        texture: array read as ``texture[:, 0]``, ``texture[:, 1]`` and
            ``texture[:, 2]``; each row becomes one point.
        erase_number: if true, hide tick labels on all four sides.
    """
    if erase_number:
        hidden_labels = dict(labelbottom=False, labelleft=False,
                             labelright=False, labeltop=False)
        plt.tick_params(**hidden_labels)

    # Private mplot3d setting that controls which box edges carry each axis.
    for axis, order in ((ax.xaxis, (2, 0, 1)),
                        (ax.yaxis, (2, 1, 0)),
                        (ax.zaxis, (2, 2, 2))):
        axis._axinfo['juggled'] = order

    ax.set_xlim(0, 360)
    ax.set_ylim(0, 180)
    ax.invert_yaxis()
    ax.set_zlim(0, 360)
    ax.invert_zaxis()

    # Halve y and translate x and y by 100 data units before projecting.
    positioning = np.diag([1, 0.5, 1, 1])
    positioning[0, 3] = 100
    positioning[1, 3] = 100
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), positioning)

    xs, ys, zs = texture[:, 0], texture[:, 1], texture[:, 2]
    ax.plot(xs, ys, zs, "o", color="blue", ms=2, mew=0.5)

    plt.yticks(range(0, 181, 90))
    plt.xticks(range(0, 361, 90))
    ax.set_zticks(range(0, 361, 90))
コード例 #2
0
ファイル: myplt.py プロジェクト: Yamanaka-Lab-TUAT/DNN-NMT
def draw(texture, c, lab, erase_number=True):
    """Scatter the rows of *texture* that fall inside the 0-90 cube.

    Draws onto the module-level 3-D axes ``ax``. Only rows whose first
    three columns are all strictly below 90 are plotted. Limits are 0-90 on
    every axis, y and z are inverted, and ticks are placed every 30 units.
    A legend is added.

    Args:
        texture: array whose columns 0, 1, 2 are the plotted coordinates.
        c: marker colour.
        lab: legend label for this scatter.
        erase_number: if true, hide tick labels on all four sides.
    """
    if erase_number:
        plt.tick_params(labelbottom=False, labelleft=False,
                        labelright=False, labeltop=False)

    # Private mplot3d setting that controls which box edges carry each axis.
    for axis, order in ((ax.xaxis, (2, 0, 1)),
                        (ax.yaxis, (2, 1, 0)),
                        (ax.zaxis, (2, 2, 2))):
        axis._axinfo['juggled'] = order

    ax.set_xlim(0, 90)
    ax.set_ylim(0, 90)
    ax.invert_yaxis()
    ax.set_zlim(0, 90)
    ax.invert_zaxis()

    # Translate x by -20 data units before projecting.
    shift = np.eye(4, dtype=int)
    shift[0, 3] = -20
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), shift)

    inside = ((texture[:, 0] < 90.)
              & (texture[:, 1] < 90.)
              & (texture[:, 2] < 90.))
    kept = texture[inside]
    ax.scatter(kept[:, 0], kept[:, 1], kept[:, 2],
               color=c, marker='o', label=lab)

    plt.yticks(range(0, 91, 30))
    plt.xticks(range(0, 91, 30))
    ax.set_zticks(range(0, 91, 30))
    ax.legend()
コード例 #3
0
def static_surf(dfs):
    """Overlay one 3-D surface per DataFrame in *dfs* and save it to ``test.png``.

    Each frame is re-indexed on its ``'Unnamed: 0'`` column. Its values are
    rounded to two decimals, NaNs are replaced by 0, and the result is cast
    to ``int`` before plotting. Later surfaces are drawn more opaque. Tick
    labels and the colorbar come from the last frame.

    Args:
        dfs: non-empty iterable of pandas DataFrames that contain an
            ``'Unnamed: 0'`` column.

    Raises:
        ValueError: if *dfs* is empty. Without a surface there is nothing
            to attach the colorbar or tick labels to.
    """
    dfs = list(dfs)
    if not dfs:
        raise ValueError("static_surf() needs at least one DataFrame")

    fig, ax = plt.subplots(subplot_kw=dict(projection='3d'), figsize=(12, 10))

    surf = None
    df = None
    for idx, df in enumerate(dfs):
        df = df.set_index('Unnamed: 0')
        mx, my = np.meshgrid(range(len(df.index)), range(len(df.columns)),
                             indexing='ij')
        z = df.round(decimals=2).fillna(0).astype(int)
        # Opacity grows with the layer index. Clamp it so the sixth and
        # later frames stay inside Matplotlib's valid alpha range [0, 1].
        alpha = min((idx + 0.01) / 5, 1.0)
        surf = ax.plot_surface(mx, my, z, cmap=cm.coolwarm, linewidth=0,
                               alpha=alpha)

    # (An identity get_proj override used to be installed here; it had no
    # effect and was removed.)
    ax.set_xticklabels(df.index)
    ax.set_yticklabels(df.columns)
    ax.set_xlabel('Medium components')
    ax.view_init(60, 35)
    fig.colorbar(surf, shrink=0.5, aspect=5)
    fig.tight_layout()
    fig.savefig('test.png')
コード例 #4
0
def plot_3DElevFrames(DomainTS, SL, TMAX, DuneDomain):
    """Render one 3-D elevation frame per time step and save each to disk.

    For every ``t`` in ``range(len(DomainTS))`` the rows are stacked in this
    order:

    * 3 rows of zero-elevation water;
    * a 6-row beach rising linearly from *SL* to ``BermEl``;
    * ``DuneWidth`` dune rows, ``(DuneDomain[t] + BermEl) * 10``;
    * the interior, ``DomainTS[t] * 10``;
    * 3 more water rows.

    The stack is drawn as a terrain surface, shown, saved to
    ``Output/SimFrames/3D_<t>`` at 150 dpi, and closed.

    Previously every frame was built from index ``TMAX``, so all saved
    frames were identical apart from their label.

    Args:
        DomainTS: per-time-step interior elevations.
        SL: sea level at the base of the beach ramp.
        TMAX: kept for interface compatibility. Frames are now indexed by
            ``t`` rather than by this value.
        DuneDomain: per-time-step dune heights.
    """
    from Barrier3D_Parameters import (BarrierLength, BermEl, DuneWidth)

    # The beach profile and the water rows do not depend on t: build once.
    BW = 6
    Beach = np.zeros([BW, BarrierLength])
    berm = math.ceil(BW * 0.65)
    Beach[berm:BW + 1, :] = BermEl
    add = (BermEl - SL) / berm
    for i in range(berm):
        Beach[i, :] = SL + add * i
    Water = np.zeros([3, BarrierLength])

    for t in range(len(DomainTS)):
        # Construct frame t
        Dunes = [(DuneDomain[t] + BermEl) * 10] * DuneWidth
        Domain = DomainTS[t] * 10
        Domain = np.vstack([Water, Beach, Dunes, Domain, Water])
        Dwid, Dlen = np.shape(Domain)

        fig = plt.figure(figsize=(12, 9))
        ax = fig.add_subplot(111, projection='3d')
        scale = np.diag([1, Dwid / Dlen, 4 / Dlen * 4, 1])
        # Bind ax and scale as defaults. A plain closure would see the loop
        # variables' latest values (a later frame's axes) on a redraw.
        ax.get_proj = lambda ax=ax, scale=scale: np.dot(
            Axes3D.get_proj(ax), scale)

        X, Y = np.meshgrid(np.arange(Dlen), np.arange(Dwid))
        ax.plot_surface(X,
                        Y,
                        Domain,
                        cmap='terrain',
                        alpha=1,
                        vmin=-1.1,
                        vmax=4.0,
                        linewidth=0,
                        shade=True)
        ax.set_zlim(0, 4)

        timestr = 'Time = ' + str(t) + ' yrs'
        ax.set_ylabel(timestr)
        ax.view_init(20, 155)
        plt.subplots_adjust(left=-1.2, right=1.3, top=2.2,
                            bottom=-0.3)  # mostly centered
        plt.show()
        name = 'Output/SimFrames/3D_' + str(t)
        fig.savefig(name, dpi=150)
        # Release the figure. Otherwise every frame stays registered with
        # pyplot and memory grows with the number of frames.
        plt.close(fig)
コード例 #5
0
def axLabel(ax):
    """Set the view, ticks, tick labels, projection scale and x label of *ax*.

    Ticks start at 0 and step by 20 on x and by 10 on y and z, stopping
    before each axis's current upper limit. A tick at ``v`` is labelled
    ``'%d' % (v * 10 / scale)``. The projection is stretched per axis by
    ``upscale * 0.75``, ``upscale * 0.3`` and ``upscale * 0.5``, and
    ``upscale`` is re-read each time the projection is computed. ``scale``
    and ``upscale`` are module-level names defined outside this function.
    """
    ax.view_init(azim=-60, elev=40)

    x_step = 20
    yz_step = x_step / 2
    tick_plan = (
        (ax.xaxis, ax.get_xlim, x_step),
        (ax.yaxis, ax.get_ylim, yz_step),
        (ax.zaxis, ax.get_zlim, yz_step),
    )
    for axis, read_limits, step in tick_plan:
        upper = read_limits()[1]
        positions = np.arange(0, upper, step)
        axis.set_ticks(positions)
        axis.set_ticklabels(['%d' % (pos * 10 / scale) for pos in positions])

    ax.get_proj = lambda: np.dot(
        Axes3D.get_proj(ax),
        np.diag([upscale * 0.75, upscale * 0.3, upscale * 0.5, 1]))
    ax.set_xlabel('\n[µm]', linespacing=3.2)
コード例 #6
0
ファイル: data_visualizer.py プロジェクト: inspiros/pcmvda
 def _init_axe(self, ax=None, dim=2):
     """Return the axes to draw a ``dim``-dimensional plot on.

     A new subplot is added to ``self.fig`` at position
     ``(1, len(self.axes) + 1, 1)`` and appended to ``self.axes`` when *ax*
     is None or an integer at or past ``len(self.axes)``. ``self.fig`` is
     created from ``self.figure_params`` first if needed. 3-D subplots get a
     uniform ``self.axe3d_scale`` projection scale and a 1-pt black border.
     An ``Axes`` instance is appended to ``self.axes``, and a smaller
     integer selects ``self.axes[ax]``. Either way the reused axes is
     cleared. Every returned axes gets the face colour from
     ``self.ax_params['facecolor']`` and grid lines drawn below the data.

     Raises:
         AssertionError: if ``dim`` is not in (0, 3].
         ValueError: if *ax* is neither an int nor an ``Axes``.
     """
     assert 0 < dim <= 3
     if self.fig is None:
         self.fig = plt.figure(**self.figure_params)

     past_end = isinstance(ax, int) and ax >= len(self.axes)
     if ax is None or past_end:
         slot = (1, len(self.axes) + 1, 1)
         if dim <= 2:
             ax = self.fig.add_subplot(*slot)
         else:
             ax = self.fig.add_subplot(*slot, projection='3d')
             ax.get_proj = lambda: np.dot(
                 Axes3D.get_proj(ax), np.diag([self.axe3d_scale] * 3 + [1]))
             ax.patch.set_edgecolor('black')
             ax.patch.set_linewidth(1)
         self.axes.append(ax)
         plt.rc('grid', linestyle=':', color='black', alpha=0.6)
     elif isinstance(ax, Axes):
         self.axes.append(ax)
         ax.clear()
     elif isinstance(ax, int):
         ax = self.axes[ax]
         ax.clear()
     else:
         raise ValueError(
             f'ax must be either number or Axes object, got {type(ax)}.'
         )

     ax.set_facecolor(self.ax_params['facecolor'])
     ax.set_axisbelow(True)
     return ax
コード例 #7
0
def visualization(CoM_coord, eig_vec, new_coord):
    """Show *eig_vec* as arrows from *CoM_coord* and *new_coord* as red dots.

    The arrows use ``length`` equal to the norm of the whole *eig_vec*
    array, and their head ratio is ``0.2 / length``. All three axes share
    the limits ``[new_coord.min(), new_coord.max()]``. The projection is
    scaled by 0.75 on x and y. The figure is displayed with ``plt.show()``.
    """
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')

    arrow_len = np.linalg.norm(eig_vec)
    tail_x, tail_y, tail_z = CoM_coord
    comp_u, comp_v, comp_w = eig_vec
    ax.quiver(tail_x, tail_y, tail_z, comp_u, comp_v, comp_w,
              pivot='tail', length=arrow_len,
              arrow_length_ratio=0.2 / arrow_len)

    ax.scatter(new_coord[:, 0], new_coord[:, 1], new_coord[:, 2],
               color="r", marker="o", s=50)

    squash = np.diag([0.75, 0.75, 1, 1])
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), squash)

    lo, hi = new_coord.min(), new_coord.max()
    for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_limits([lo, hi])
    for set_label, text in ((ax.set_xlabel, 'x'),
                            (ax.set_ylabel, 'y'),
                            (ax.set_zlabel, 'z')):
        set_label(text)
    plt.show()
コード例 #8
0
ファイル: plot_operators.py プロジェクト: LePingKYXK/symmetry
def visualization(CoM_coord, eig_vec, new_coord):
    """Visualize the principal-axis vectors and the shifted molecule.

    ========================
    parameters:
    CoM_coord:    (x, y, z) tail position shared by all arrows (the center
                  of mass)
    eig_vec:      unpacked into three rows U, V, W, which are passed to
                  ``ax.quiver`` as arrow *direction components*. Per
                  Matplotlib's quiver API they are not end points. Each
                  arrow is also scaled by ``length``, which here is the norm
                  of the whole ``eig_vec`` array.
    new_coord:    shifted molecular coordinates, read as columns
                  ``new_coord[:, 0..2]`` and drawn as red markers

    ------------------------
    All three axes share the limits ``[new_coord.min(), new_coord.max()]``.
    The projection is scaled by 0.75 on x and y, and the figure is shown
    with ``plt.show()``.
    """
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')

    vector_norm = np.linalg.norm(eig_vec)
    origin = tuple(CoM_coord)
    directions = tuple(eig_vec)
    ax.quiver(*origin, *directions, pivot='tail',
              length=vector_norm, arrow_length_ratio=0.2 / vector_norm)

    columns = [new_coord[:, k] for k in range(3)]
    ax.scatter(*columns, color="r", marker="o", s=50)

    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax),
                                 np.diag([0.75, 0.75, 1, 1]))
    lower = new_coord.min()
    upper = new_coord.max()
    ax.set_xlim([lower, upper])
    ax.set_ylim([lower, upper])
    ax.set_zlim([lower, upper])
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    plt.show()
コード例 #9
0
def plot_3DElevTMAX(TMAX, t, SL, DuneDomain, DomainTS):
    """Render the barrier domain at time step ``min(TMAX, t)`` in 3-D.

    Rows are stacked in this order:

    * 3 rows of zero-elevation water;
    * a 6-row beach rising linearly from *SL* to ``BermEl``;
    * ``DuneWidth`` dune rows, ``(DuneDomain[step] + BermEl) * 10``;
    * the interior ``DomainTS[step] * 10`` with negative values set to 0;
    * 3 more water rows.

    The surface is shown and then saved to ``Output/Domain3D`` at 200 dpi.
    Reads the module-level ``BarrierLength``, ``BermEl`` and ``DuneWidth``.
    """
    step = min(TMAX, t)

    # Beach ramp: linear from SL up to the berm row, flat at BermEl after.
    beach_rows = 6
    berm_row = math.ceil(beach_rows * 0.65)
    rise = (BermEl - SL) / berm_row
    Beach = np.zeros([beach_rows, BarrierLength])
    Beach[berm_row:beach_rows + 1, :] = BermEl
    for row in range(berm_row):
        Beach[row, :] = SL + rise * row

    dune_rows = [(DuneDomain[step] + BermEl) * 10] * DuneWidth
    water = np.zeros([3, BarrierLength])
    interior = DomainTS[step] * 10
    interior[interior < 0] = 0
    Domain = np.vstack([water, Beach, dune_rows, interior, water])
    Dwid, Dlen = np.shape(Domain)

    fig = plt.figure(figsize=(12, 9))
    ax = fig.add_subplot(111, projection="3d")
    stretch = np.diag([1, Dwid / Dlen, 4 / Dlen * 3, 1])
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), stretch)

    X, Y = np.meshgrid(np.arange(Dlen), np.arange(Dwid))
    ax.plot_surface(X, Y, Domain, cmap="terrain", alpha=1, vmin=-1.1,
                    vmax=4.0, linewidth=0, shade=True)
    ax.set_zlim(0, 4)

    ax.view_init(10, 155)
    plt.subplots_adjust(left=-1.2, right=1.3, top=2.2,
                        bottom=-0.3)  # mostly centered
    plt.show()
    fig.savefig("Output/Domain3D", dpi=200)
コード例 #10
0
    def barplot3d(data, y_names, x_names, baseline):
        """Draw *data* as a 3-D bar chart and save it as PNG, PDF and EPS.

        Args:
            data: 2-D nested sequence. ``data[r][c]`` is the height of the
                bar in row ``r`` and column ``c``.
            y_names: tick labels for the rows (``len(data)`` ticks).
            x_names: tick labels for the columns (``len(data[0])`` ticks).
            baseline: unused. A reference-plane overlay at this height was
                left commented out in the original code.

        Bars are coloured with RdYlGn after subtracting 15 and normalising
        to [0, 30]. Output goes to ``./predictions/barplot3D.{png,pdf,eps}``
        and the figure is closed afterwards.
        """
        fig = plt.figure(figsize=(10, 8))
        # ``fig.gca(projection='3d')`` stopped accepting keyword arguments
        # in Matplotlib 3.6. ``add_subplot`` creates the same single 3-D axes.
        ax = fig.add_subplot(111, projection='3d')

        x_len = len(x_names)
        y_len = len(y_names)
        # Lower-left corner of each 0.5 x 0.5 bar footprint.
        x, y = np.meshgrid(np.arange(0, x_len, 1) - 0.25,
                           np.arange(0, y_len, 1) - 0.5)
        x = x.flatten()
        y = y.flatten()
        z = np.zeros(x_len * y_len)

        dx = 0.5 * np.ones_like(z)
        dy = dx.copy()
        dz = np.array(data).flatten()

        # ``w_xaxis``/``w_yaxis``/``w_zaxis`` were deprecated aliases of
        # ``xaxis``/``yaxis``/``zaxis`` and were removed in Matplotlib 3.8.
        ax.xaxis.set_ticks(list(range(len(data[0]))))
        ax.xaxis.set_ticklabels(x_names)
        ax.yaxis.set_ticks(list(range(len(data))))
        ax.yaxis.set_ticklabels(y_names)

        ax.set_zlabel('Predictive accuracy (%)')
        ax.zaxis.set_tick_params(labelsize=12)

        # Stretch y so the box aspect follows rows / columns of the data.
        ax.get_proj = lambda: np.dot(
            Axes3D.get_proj(ax), np.diag([1, len(data) / len(data[0]), 1, 1]))

        nrm = mpl.colors.Normalize(0, 30)
        c_range = (np.array(data) - 15).flatten()
        colors = cm.RdYlGn(nrm(c_range))
        ax.bar3d(x, y, z, dx, dy, dz, colors)
        plt.tight_layout()
        fig.savefig('./predictions/barplot3D.png', bbox_inches='tight')
        fig.savefig('./predictions/barplot3D.pdf',
                    format='pdf',
                    transparent=True,
                    bbox_inches='tight')
        fig.savefig('./predictions/barplot3D.eps',
                    format='eps',
                    transparent=True,
                    bbox_inches='tight')
        # Close rather than just clear, so repeated calls do not accumulate
        # open (empty) figures in pyplot.
        plt.close(fig)
コード例 #11
0
    def method_scale(self):
        """Stretch ``self.ax`` so each axis is proportional to its scale.

        The largest of ``self.x_scale``, ``self.y_scale`` and
        ``self.z_scale`` maps to 1 and the others are taken relative to it.
        The factors are applied by replacing ``self.ax.get_proj``.
        Reference: https://stackoverflow.com/questions/30223161/matplotlib-mplot3d-how-to-increase-the-size-of-an-axis-stretch-in-a-3d-plo
        """
        raw = (self.x_scale, self.y_scale, self.z_scale)
        largest = max(raw)
        relative = [value / largest for value in raw]
        factors = np.diag(relative + [1])
        self.ax.get_proj = lambda: np.dot(Axes3D.get_proj(self.ax), factors)
コード例 #12
0
 def fix_view(self, scalefactor=1.5):
     """Stretch each axis of ``self.ax`` about the centre of its limits.

     Each axis is scaled by ``scalefactor * range / largest_range``, where
     the ranges come from the current x/y/z limits. The transform is
     centre -> origin, scale, then back to centre, applied in data
     coordinates before the default projection.

     NOTE(review): the original author notes this "preserves constant
     lengths of cube edges but contains a bug". The bug is not identified
     from this code alone. The author also suggests sometimes using
     ``2 * s3`` for the z factor.
     """
     limits = np.array([self.ax.get_xlim(),
                        self.ax.get_ylim(),
                        self.ax.get_zlim()])
     centre = limits.mean(axis=1)
     extent = limits.max(axis=1) - limits.min(axis=1)
     factors = scalefactor * extent / extent.max()

     to_origin = np.eye(4)
     to_origin[:3, 3] = -centre
     stretch = np.diag(np.append(factors, 1.0))
     to_centre = np.eye(4)
     to_centre[:3, 3] = centre

     def _projection():
         base = Axes3D.get_proj(self.ax)
         return base.dot(to_centre).dot(stretch).dot(to_origin)

     self.ax.get_proj = _projection
コード例 #13
0
ファイル: data_visualizer.py プロジェクト: qpmnh/mvda
 def __init_axe(self, ax=None, dim=2):
     """Return the axes to draw on: a new subplot, or *ax* cleared when paused.

     ``self.fig`` is created from ``self.figure_params`` if it is missing.
     When ``self.pausing`` is false, a subplot is added at position
     ``(1, len(self.axes) + 1, 1)``. It is 2-D for ``dim <= 2`` and 3-D for
     ``dim == 3``, with a uniform ``self.axe3d_scale`` projection scale.
     The new axes is not appended to ``self.axes`` here. When pausing, the
     given *ax* is cleared and reused. Grid visibility follows
     ``self.grid``.

     NOTE(review): for ``dim > 3`` no subplot is created and the incoming
     *ax* (possibly None) is used unchanged. Likewise, *ax* must not be None
     while pausing.
     """
     if self.fig is None:
         self.fig = plt.figure(**self.figure_params)

     if self.pausing:
         ax.clear()
     else:
         slot = (1, len(self.axes) + 1, 1)
         if dim <= 2:
             ax = self.fig.add_subplot(*slot)
         elif dim == 3:
             ax = self.fig.add_subplot(*slot, projection='3d')
             ax.get_proj = lambda: np.dot(
                 Axes3D.get_proj(ax), np.diag([self.axe3d_scale] * 3 + [1]))
         plt.rc('grid', linestyle="dotted", color='black', alpha=0.6)

     ax.grid(self.grid)
     return ax
コード例 #14
0
 def _fix_view(scalefactor=1.5, ax=None):
     """Install a projection on *ax* that stretches each axis about its centre.

     Each axis is scaled by ``scalefactor * range / largest_range``, where
     the ranges come from the current x/y/z limits. The transform is
     centre -> origin, scale, then back to centre, applied in data
     coordinates before the default projection. (The author suggests
     sometimes doubling the z factor.)

     *ax* must be supplied despite its ``None`` default. ``None`` fails on
     the first limit lookup.
     """
     lows_highs = np.array([ax.get_xlim(), ax.get_ylim(), ax.get_zlim()])
     mid = np.mean(lows_highs, axis=1)
     span = np.max(lows_highs, axis=1) - np.min(lows_highs, axis=1)
     s = scalefactor * span / np.max(span)

     shift_in = np.identity(4)
     shift_in[:3, 3] = -mid
     scale_m = np.diag([s[0], s[1], s[2], 1.0])
     shift_out = np.identity(4)
     shift_out[:3, 3] = mid

     def projection():
         proj = Axes3D.get_proj(ax)
         return np.dot(np.dot(np.dot(proj, shift_out), scale_m), shift_in)

     ax.get_proj = projection
コード例 #15
0
ファイル: plot.py プロジェクト: r2ufuk/easy21
def plot_q(value, path=None):
    """Draw *value* as a green surface over dealer card (1-10) x player total (1-21).

    The current figure is cleared and reused. Only the end ticks are shown,
    the z range is fixed to [-1, 1], and the box is stretched 10 : 21 : 4
    (x : y : z), with the homogeneous scale entry set to 0.65.

    Args:
        value: surface heights on the grid from
            ``np.meshgrid(dealer, player)``, i.e. shape (21, 10) indexed
            ``[player - 1, dealer - 1]``.
        path: if truthy, save the figure there; otherwise show it.
    """
    plt.clf()
    ax = plt.axes(projection='3d')

    plt.xlabel("Dealer Showing", fontsize=10)
    plt.ylabel("Player Total", fontsize=10)
    plt.xticks(np.arange(1, 11, step=9))
    plt.yticks(np.arange(1, 22, step=20))
    ax.set_zticks([-1, 1])

    # Fix the limits before plotting so plot_surface does not autoscale.
    ax.set_zlim(-1, 1)
    ax.set_ylim(1, 21)
    ax.set_xlim(1, 10)

    dealer_grid, player_grid = np.meshgrid(np.arange(1, 11, 1),
                                           np.arange(1, 22, 1))
    ax.plot_surface(dealer_grid, player_grid, value,
                    cmap="Greens", antialiased=True, rstride=1, cstride=1,
                    lw=0.25, edgecolors="black")

    proportions = np.diag([10, 21, 4, 1.0])
    proportions *= 1.0 / np.max(proportions)
    proportions[3, 3] = 0.65
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), proportions)

    if not path:
        plt.show()
        return
    plt.savefig(path)
コード例 #16
0
def wireplot(ax, Z):
    """Draw *Z* as a black wireframe on the 3-D axes *ax* with a bare look.

    The height drawn at grid point ``(x, y)`` is ``Z[x, y]``, for
    ``0 <= x < Z.shape[0]`` and ``0 <= y < Z.shape[1]``, so *Z* may be
    rectangular. Both axes used to be sized from ``Z.shape[0]``. That raised
    IndexError when ``Z.shape[1] < Z.shape[0]`` and silently dropped
    columns when it was larger; square input is drawn exactly as before.

    The z axis is shown at half scale and its range is fixed to [-1, 1].
    Panes and grid lines are made transparent and all ticks are removed.
    """
    nx, ny = Z.shape[0], Z.shape[1]
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([1, 1, .5, 1]))

    # meshgrid gives X[i, j] = j (< nx) and Y[i, j] = i (< ny), so Z[X, Y]
    # is valid for Z of shape (nx, ny).
    X, Y = np.meshgrid(np.arange(nx), np.arange(ny))
    ax.plot_wireframe(X, Y, Z[X, Y], lw=1, rstride=1, cstride=1, color='k')

    axes3 = (ax.xaxis, ax.yaxis, ax.zaxis)
    # make the panes transparent
    for axis in axes3:
        axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    # make the grid lines transparent (private mplot3d setting)
    for axis in axes3:
        axis._axinfo["grid"]['color'] = (1, 1, 1, 0)

    ax.set_xlim((0, nx - 1))
    ax.set_ylim((0, ny - 1))
    ax.set_zlim((-1, 1))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
コード例 #17
0
ファイル: myplt.py プロジェクト: Yamanaka-Lab-TUAT/DNN-NMT
def draw_voxels(voxel_data, use_alpha=False, erase_number=True):
    """Render channel 0 of *voxel_data* as voxels on the module-level ``ax``.

    A cell is drawn when ``ceil(value) != 0`` for its channel-0 value. For
    non-negative data that means every positive cell (TODO confirm the data
    is non-negative). With *use_alpha*, cells are black with opacity equal
    to the value. Otherwise they are coloured with ``jet`` after dividing
    by the maximum. A colorbar is added on the module-level ``fig``, and
    the maximum value is printed.

    Args:
        voxel_data: array indexed ``[x, y, z, channel]``. Only channel 0 is
            used. Any (nx, ny, nz) grid is accepted; the original code was
            hard-wired to 32 x 16 x 32.
        use_alpha: use value-as-opacity instead of a colormap.
        erase_number: if true, hide the grid and all tick labels.
    """
    if erase_number:
        ax.grid(False)
        plt.tick_params(labelbottom=False,
                        labelleft=False,
                        labelright=False,
                        labeltop=False)

    density = voxel_data[:, :, :, 0]
    nx, ny, nz = density.shape
    ax.set_xlim(0, nx)
    ax.set_ylim(0, ny)
    ax.set_zlim(0, nz)
    # Private mplot3d setting that controls which box edges carry each axis.
    ax.xaxis._axinfo['juggled'] = (2, 0, 1)
    ax.yaxis._axinfo['juggled'] = (2, 1, 0)
    ax.zaxis._axinfo['juggled'] = (2, 2, 2)

    # Scale x by 0.5 and y by 0.25, then translate y by 10 data units.
    aff = np.diag([0.5, 0.25, 1, 1])
    aff[0][3] = 0
    aff[1][3] = 10
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), aff)

    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
    # same dtype.
    voxels = np.ceil(density).astype(bool)
    colors = np.empty(density.shape + (4,), dtype=np.float32)
    vmax = np.max(density)
    print(vmax)
    if use_alpha:
        colors[..., :3] = 0.
        colors[..., 3] = density
    else:
        colors[...] = cm.jet(density / vmax)

    ax_cb = fig.add_axes([0.9, 0.2, 0.025, 0.5])
    norm = Normalize(vmin=0., vmax=vmax)
    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9;
    # ``pyplot.get_cmap`` remains available.
    cmap = plt.get_cmap('binary' if use_alpha else 'jet')
    cbar = ColorbarBase(ax_cb, cmap=cmap, norm=norm, orientation='vertical')
    # NOTE(review): Colorbar.set_clim may not exist in recent Matplotlib
    # releases — verify against the installed version.
    cbar.set_clim(vmin=0., vmax=1.)
    cbar.solids.set(alpha=1)
    ax.voxels(voxels, facecolors=colors)
コード例 #18
0
    def draw(self):
        """Decode the next clutter frames from ``self.data_bytes`` and save a 3-D plot.

        Byte layout read from ``self.index`` onward:

        * three 1064-byte frames are skipped;
        * each of the next three frames contributes four rows of 128 samples;
        * row ``p`` of a frame starts at byte ``36 + 256 * p`` within it;
        * each sample is an unsigned little-endian 2-byte integer.

        Afterwards ``self.index`` has advanced by six frames and
        ``self.pic_index`` by one. The 12 x 128 samples are drawn with
        ``plot_trisurf`` (x = row, y = sample index, z = value) and saved
        as a numbered PNG in a fixed output directory.
        """
        frame_bytes = 1064
        # Skip the first 3 frames; the following 3 frames hold the clutter.
        self.index += frame_bytes * 3

        port_dor = []
        for _frame in range(3):
            for port in range(4):
                row_start = self.index + 36 + 256 * port
                port_dor.append([
                    int.from_bytes(
                        self.data_bytes[row_start + 2 * j:row_start + 2 * j + 2],
                        byteorder='little', signed=False)
                    for j in range(128)
                ])
            self.index += frame_bytes

        # Flatten the 12 x 128 grid into (row, sample, value) point lists.
        z = np.array(port_dor).ravel()
        x = np.repeat(np.arange(12), 128)
        y = np.tile(np.arange(128), 12)

        fig = plt.figure(figsize=(12, 8))
        # Full-figure 3-D axes, matching the rect that ``Axes3D(fig)`` used.
        # Building Axes3D(fig) directly relies on it adding itself to the
        # figure, which newer Matplotlib releases no longer do, so the
        # saved image came out empty.
        ax = fig.add_axes([0, 0, 1, 1], projection='3d')
        # These tick spacings used to be set on a stray ``plt.gca()`` axes
        # that was never saved; apply them to the plotted axes instead.
        ax.xaxis.set_major_locator(plt.MultipleLocator(1))
        ax.yaxis.set_major_locator(plt.MultipleLocator(20))
        ax.set_zlim3d([0, 1300])
        ax.set_xlim([0, 12])
        ax.set_ylim([0, 140])
        ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax),
                                     np.diag([0.5, 1, 0.7, 0.7]))
        ax.plot_trisurf(x, y, z)
        ax.view_init(33, -26)

        self.pic_index += 1
        pic_path = "E:/学习/论文/else/自己处理后的数据/杂波图/外协/杂波图{:0>6}.png".format(
            self.pic_index)
        fig.savefig(pic_path)
        plt.close(fig)
コード例 #19
0
# 2-D contour of the hub-height field, saved to 2dplot.png.
fig, ax = plt.subplots()
N = 50  # number of contour levels
plot = ax.contourf(UhubMagMesh, N, extend='both')
cb = fig.colorbar(plot, drawedges=False)
# Remove colorbar outline
cb.outline.set_linewidth(0)
plt.tight_layout()
plt.savefig('2dplot.png', dpi=1000, transparent=True, bbox_inches='tight')

# Three stacked contour slices (z offsets 5, 90, 215) on one 3-D axes.
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import colors as mcolors
fig = plt.figure()
# ``fig.gca(projection='3d')`` stopped accepting keyword arguments in
# Matplotlib 3.6; ``add_subplot`` creates the same single 3-D axes.
ax2 = fig.add_subplot(111, projection='3d')
alpha = 0.5
# Shrink x and y to 0.65 and stretch z to 1.3 in the projection.
ax2.get_proj = lambda: np.dot(Axes3D.get_proj(ax2),
                              np.diag([0.65, 0.65, 1.3, 1]))

plot = ax2.contourf(xv, yv, UgroundMagMesh, N, zdir='z', offset=5, alpha=1)
ax2.contourf(xv, yv, UhubMagMesh, N, zdir='z', offset=90, alpha=1)
ax2.contourf(xv, yv, U3MagMesh, N, zdir='z', offset=215, alpha=1)

# Adjust the limits, ticks and view angle
ax2.set_zlim(5, 215)
ax2.set_xticks(np.arange(0, 3001, 1000))
ax2.set_yticks(np.arange(0, 3001, 1000))
ax2.view_init(35, -120)
cbaxes = fig.add_axes([0.8, 0.1, 0.02, 0.8])
cb = plt.colorbar(plot, cax=cbaxes, drawedges=False, extend='both')
# Remove colorbar outline
cb.outline.set_linewidth(0)
コード例 #20
0
 def short_proj():
     """Return the default 3-D projection of ``self.axe`` post-multiplied by ``scale``.

     ``self`` and ``scale`` come from the enclosing scope, which is not
     visible here.
     """
     base = Axes3D.get_proj(self.axe)
     return base.dot(scale)
コード例 #21
0
    def plot(self, item, itemHistory=None, **kwargs):
        """Plot one view of the propagation history in ``self.history``.

        Args:
            item: which view to draw.
                "coe": classical orbital elements, one line per maneuver
                segment (segments split at ``self.history.maneuverIdxs``).
                "secularCoe": the first five secular elements.
                "3d-trajectory": trajectory in km, with a black dot at the
                start of each maneuver segment.
                "orbitalEnergy": Earth- and Moon-relative specific energy.
                "energyUsage": solar-panel, thruster and other-device power,
                plus battery energy on a twin axis.
                "singleItem": *itemHistory*, plotted per maneuver segment.
            itemHistory: series to plot for "singleItem" (required there).
            **kwargs:
                useDate (bool): use ``self.history.datetime`` as the time
                    axis instead of elapsed days. Default False.
                ax: for "3d-trajectory", an existing 3-D axes to draw on.
                    When given and not None, no figure is created and no
                    axis styling is applied.

        Returns:
            The 3-D axes for "3d-trajectory", ``None`` otherwise.

        Raises:
            ValueError: for "singleItem" when *itemHistory* is None. The
                previous test ``np.isscalar(None)`` is always False, so the
                check never fired and the call failed later on indexing.
        """
        useDate = kwargs.get("useDate", False)

        # Raw strings: "\o" and "\O" are invalid escapes (SyntaxWarning on
        # recent Pythons). The resulting text is unchanged.
        titles = ["a", "e", "i", r"$\omega$", r"$\Omega$", r"$\nu$"]
        ylabels = ["[km]", "", "[°]", "[°]", "[°]", "[°]"]

        def maneuver_slices():
            # One slice per segment between consecutive maneuver indices.
            idxs = self.history.maneuverIdxs
            return [slice(idxs[k], idxs[k + 1]) for k in range(len(idxs) - 1)]

        def to_plot_units(values, i):
            # Element 0 (a) -> km; elements 2+ (angles) -> degrees;
            # element 1 (e) is plotted as-is.
            if i == 0:
                return values / 1000
            if i >= 2:
                return values * 180 / np.pi
            return values

        if item == "coe":
            # PLOTTING CLASSICAL ORBITAL ELEMENTS
            timeAxis = self.history.datetime if useDate else self.history.t / 60 / 60 / 24
            fig, axes = plt.subplots(3, 2, figsize=(10, 8), sharex=True)
            slices = maneuver_slices()
            for i in range(6):
                cell = axes[i // 2, i % 2]
                for seg in slices:
                    cell.plot(timeAxis[seg],
                              to_plot_units(self.history.coe[seg, i], i))
                cell.set_title(titles[i] + " " + ylabels[i])

                if useDate:
                    fig.autofmt_xdate()
                    cell.xaxis.set_major_formatter(
                        mdates.DateFormatter('%Y-%m-%d'))
                elif i in (4, 5):
                    cell.set_xlabel("Tiempo [días]")
                # a and e keep scientific notation and offsets; the angle
                # panels are shown as plain numbers.
                keep_sci = i in (0, 1)
                formatter = cell.yaxis.get_major_formatter()
                formatter.set_scientific(keep_sci)
                formatter.set_useOffset(keep_sci)
                # Positional: the keyword ``b=`` was removed in Matplotlib 3.7.
                cell.grid(True)

        if item == "secularCoe":
            # PLOTTING SECULAR CLASSICAL ORBITAL ELEMENTS (no true anomaly)
            timeAxis = self.history.datetime if useDate else self.history.tSecular / 60 / 60 / 24
            fig, axes = plt.subplots(3, 2, figsize=(10, 8), sharex=True)
            for i in range(5):
                cell = axes[i // 2, i % 2]
                cell.plot(timeAxis,
                          to_plot_units(self.history.secularCoe[:, i], i))
                cell.set_title(titles[i] + " " + ylabels[i])

                if useDate:
                    fig.autofmt_xdate()
                    cell.xaxis.set_major_formatter(
                        mdates.DateFormatter('%Y-%m-%d'))
                formatter = cell.yaxis.get_major_formatter()
                formatter.set_scientific(False)
                formatter.set_useOffset(False)
                cell.grid(True)

        if item == "3d-trajectory":
            axExtern = kwargs.get("ax")
            ax = axExtern
            if ax is None:
                fig = plt.figure(figsize=(10, 10))
                ax = fig.add_subplot(111, projection='3d')

            slices = maneuver_slices()
            markers = np.zeros([len(slices), 3])
            for k, seg in enumerate(slices):
                ax.plot3D(self.history.r[seg, 0] / 1000,
                          self.history.r[seg, 1] / 1000,
                          self.history.r[seg, 2] / 1000,
                          linewidth=1)
                markers[k, :] = self.history.r[seg.start, :] / 1000
            ax.plot3D(markers[:, 0], markers[:, 1], markers[:, 2], "k.")

            if axExtern is None:
                auxiliary.set_axes_equal(ax)
                ax.set_aspect("equal")
                ax.get_proj = lambda: np.dot(
                    Axes3D.get_proj(ax), np.diag([1.2, 1.2, 1.2, 1]))

                # Light-grey z = 0 plane spanning the current x/y limits.
                xx, yy = np.meshgrid(ax.get_xlim(), ax.get_ylim())
                z = np.array([[0, 0], [0, 0]])
                ax.plot_surface(xx, yy, z, alpha=0.3, color="lightgray")

                ax.set_title("Satellite Trajectory [km]")
                ax.set_xlabel("X [km]")
                ax.set_ylabel("Y [km]")
                ax.set_zlabel("Z [km]")
            return ax

        if item == "orbitalEnergy":
            fig, ax = plt.subplots(figsize=(10, 4))
            # Build in one pass: np.append inside the loop copied the whole
            # array on every iteration.
            moonDistances = np.array([
                np.linalg.norm(models.lunarPositionAlmanac2013(date)
                               - self.history.r[num])
                for num, date in enumerate(self.history.datetime)
            ])

            kinetic = np.linalg.norm(self.history.v, axis=1)**2 / 2
            earthEnergy = kinetic - constants.mu_E / np.linalg.norm(
                self.history.r, axis=1)
            moonEnergy = kinetic - constants.mu_M / moonDistances
            ax.plot(self.history.datetime, earthEnergy, label="Earth Energy")
            ax.plot(self.history.datetime, moonEnergy, label="Moon Energy")
            fig.legend()
            fig.autofmt_xdate()
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
            plt.grid()

        if item == "energyUsage":
            timeAxis = self.history.datetime if useDate else self.history.t / 60 / 60 / 24
            fig, ax = plt.subplots(figsize=(10, 4))
            dt = np.diff(self.history.t)
            PSolarPanels = np.diff(self.history.energy["solar panels"]) / dt
            PThruster = np.diff(self.history.energy["thruster"]) / dt
            POtherDevices = np.ones((len(self.history.t))) * (
                self.spacecraft.solarPanels.nominalPower) * 0.6
            EBattery = self.history.energy["battery"] / 60 / 60
            ax.plot(timeAxis[:-1],
                    PSolarPanels,
                    linewidth=1,
                    label="Potencia Paneles Solares")
            ax.plot(timeAxis[:-1],
                    PThruster,
                    linewidth=1,
                    label="Potencia Propulsor")
            ax.plot(timeAxis,
                    POtherDevices,
                    linewidth=1,
                    label="Potencia Otros Dispositivos")
            ax.set_ylabel("Potencia [W]")

            if useDate:
                fig.autofmt_xdate()
                ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
            else:
                ax.set_xlabel("Tiempo [días]")
            ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
            ax.yaxis.get_major_formatter().set_scientific(False)
            ax.yaxis.get_major_formatter().set_useOffset(False)

            # Battery energy on a twin y-axis, with one merged legend.
            ax2 = ax.twinx()
            ax2.plot(timeAxis,
                     EBattery,
                     linewidth=1,
                     label="Energía Baterías",
                     color="purple")
            ax2.set_ylabel("Energía [Wh]")
            h1, l1 = ax.get_legend_handles_labels()
            h2, l2 = ax2.get_legend_handles_labels()
            ax2.legend(h1 + h2, l1 + l2)
            ax2.set_ylim([-2, self.spacecraft.battery.energy + 2])
            ax2.grid()

        if item == "singleItem":
            if itemHistory is None:
                raise ValueError("History Data not specified.")
            timeAxis = self.history.datetime if useDate else self.history.t / 60 / 60 / 24
            fig, ax = plt.subplots(figsize=(10, 4))
            for seg in maneuver_slices():
                ax.plot(timeAxis[seg], itemHistory[seg], linewidth=1)

            if useDate:
                fig.autofmt_xdate()
                ax.xaxis.set_major_formatter(
                    mdates.DateFormatter('%Y-%m-%d'))
                ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
                ax.yaxis.get_major_formatter().set_scientific(False)
                ax.yaxis.get_major_formatter().set_useOffset(False)
            else:
                ax.set_xlabel("Time [days]")
            plt.grid()
            mplcursors.cursor(hover=True)
 def modified_proj():
     """Return the default projection of ``gl.ax1`` post-multiplied by ``scale``.

     ``gl`` and ``scale`` are names from the surrounding scope, which is not
     visible here.
     """
     base = Axes3D.get_proj(gl.ax1)
     return base.dot(scale)
コード例 #23
0
def varMesh3Arrays(crossDict, refAnt, timeStep=0.4, pltLen=10,
                   numFrames=200, pltStart=1):
    """
    Plots mesh between phase of three dict items with real distance
    between them.

    Each frame is a 3-D ``plot_trisurf`` of phase (z, degrees) against
    distance from the reference antenna (x) and time (y).  Frames are either
    saved as ``./for-animation/animationframe<i>.png`` or shown live.

    Parameters
    ----------
    crossDict : dictionary
        Must include distance items.  Keys starting with ``'d'`` are
        distances; every other key is a complex series whose angle is
        plotted.  Only keys containing ``str(refAnt)`` are used.
    refAnt : int
        Which antenna is being used as a reference.
    timeStep: float (default: 0.4)
        Seconds between data points.
    pltLen : int (default: 10)
        Length of plot (number of samples/frame).
    numFrames : int (default: 200)
        Length of the animation in frames.
    pltStart : int (default 1)
        Which index in the array to start plotting at.

    Notes
    -----
    Reads and increments the module-level ``fignum`` counter.  The live-plot
    option is only offered when ``numFrames < 10``; otherwise frames are
    always saved to file.
    """
    global fignum

    filepath = './for-animation/'
    ans = 'n'  # default: save to file

    if numFrames < 10:
        # An empty answer selects the advertised default [n].
        ans = input("Save to file [n] or show live plot [y]: ").strip().lower() or 'n'

    if ans == 'n':
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        print('Will save files to {}'.format(filepath))
    elif ans != 'y':
        print('Please enter a valid option (y/n) next time.')
        print('Terminating.')
        return

    plotDict = {key: value for key, value in crossDict.items()
                if str(refAnt) in key}
    plotDictSortedKeys = sorted(plotDict)

    # x values: 0 for the reference antenna, then every distance item
    # (keys starting with 'd' whose second character is not 'T').
    # Strings are compared with ==/!=; `is` tests object identity.
    xValues = np.repeat(0, pltLen)  # Reference distance
    for key in plotDictSortedKeys:
        if key[0] == 'd' and key[1] != 'T':
            xValues = np.append(xValues, np.repeat(plotDict[key], pltLen))

    plt.rcParams.update({'font.size': 18})

    for i in range(numFrames):
        # z values: 0 for the reference, then each series' phase in degrees
        # over the current window of pltLen samples.
        zValues = np.repeat(0, pltLen)  # Reference phase
        for key in plotDictSortedKeys:
            if key[0] != 'd':
                window = plotDict[key][pltStart:pltStart + pltLen]
                zValues = np.append(zValues, np.angle(window) * 180 / np.pi)

        # y values: the same time axis repeated for the reference and the
        # three other series.
        t = np.linspace(pltStart, pltStart + pltLen, pltLen) * timeStep
        yValues = np.hstack((t, t, t, t))

        fig = plt.figure(fignum, figsize=(21, 5))
        ax = fig.add_subplot(111, projection='3d')
        # In np.diag, the first 3 args are x_aspect, y_aspect, z_aspect
        # (normalized 0-1).  `ax` is bound as a default argument so every
        # figure keeps projecting through its own axes instead of whichever
        # axes the loop variable refers to later.
        ax.get_proj = lambda ax=ax: np.dot(Axes3D.get_proj(ax),
                                           np.diag([0.4, 1, 0.95, 1]))
        surf = ax.plot_trisurf(xValues, yValues, zValues, cmap=cm.inferno,
                               linewidth=0)
        fig.colorbar(surf,
                     fraction=0.046,
                     pad=0.04,
                     label='(deg)',
                     orientation='horizontal')
        surf.set_clim(vmin=-180, vmax=180)
        ax.set_xlabel('\nDistance from A{}[m]'.format(refAnt), linespacing=3.2)
        ax.set_zlabel('\nPhase from A{}[deg]'.format(refAnt), linespacing=3.2)
        ax.set_ylabel('\nTime[s]', linespacing=3.2)
        ax.set_zlim(-180, 180)
        ax.set_zticks([-100, 0, 100])
        ax.set_xticks([0, 20, 40, 60])

        ax.view_init(azim=-35, elev=35)
        fig.tight_layout()
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0.12)

        if ans == 'n':
            plt.savefig(os.path.join(filepath, 'animationframe{}.png'.format(i)),
                        bbox_inches='tight', pad_inches=0)
            print('Fignum: {}'.format(fignum))
            fignum += 1
            # Close (not just clear) the figure so hundreds of saved frames
            # do not accumulate as open figures in memory.
            plt.close(fig)
        else:  # ans == 'y'
            plt.show(block=False)
            print('Fignum: {}'.format(fignum))
            fignum += 1

        pltStart += 1
コード例 #24
0
def visualize_voxels_original(image, voxels, transform, save=False):
    """Scatter-plot occupied voxels in 3-D, coloured by the image pixel they map back to.

    Parameters
    ----------
    image : image object
        Must provide ``getpixel((x, y))`` and ``size``; pixels are assumed to
        be tuples of 0-255 channel values (TODO confirm for the caller).
    voxels : array
        Indexed as ``voxels[x][y]``; entries equal to 1 are occupied.  Only
        x/y indices in ``range(0, 192, STRIDE)`` are visited.
    transform : dict
        ``'A'`` (square matrix) and ``'b'`` (offset) of the image-to-voxel
        affine map; its inverse maps voxels back into image coordinates.
    save : bool
        If True, write ``voxels_corner.jpg`` and ``voxels_side.jpg``;
        otherwise show the plot interactively.
    """
    fig = plt.figure(figsize=(8, 8))
    fig.subplots_adjust(top=1, bottom=0, left=0, right=1)

    ax = fig.add_subplot(111, projection='3d')

    ###### Scaling Section #######
    x_scale = 4.0
    y_scale = 1.0
    z_scale = 4.0

    # Normalize the axis scales to a max of 1, keeping the homogeneous entry at 1.
    scale = np.diag([x_scale, y_scale, z_scale, 1.0])
    scale = scale * (1.0 / scale.max())
    scale[3, 3] = 1.0

    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), scale)
    ##############################

    Ainv = np.linalg.inv(transform['A'])
    b = transform['b']

    xs, ys, zs, colors = [], [], [], []
    for x in range(0, 192, STRIDE):
        for y in range(0, 192, STRIDE):
            for z in np.argwhere(voxels[x][y] == 1).T[0]:
                # Map the voxel back into image coordinates.
                inv_coords = Ainv.dot(np.subtract(np.array([x, y, z]), b.T)[0])
                xt, yt = int(inv_coords[0]), int(inv_coords[1])
                color = np.array(list(image.getpixel((xt, yt)))) / 255.0

                xs.append(xt)
                ys.append(image.size[1] - yt)  # flip image rows so "up" is +y
                zs.append(z)
                colors.append(color)

    # Shift z-coordinates of voxels to have 0 mean, then clip negative
    # z-coordinates to 0.  Skipped when no voxel was found (avoids a
    # division by zero).
    if zs:
        avg_z = sum(zs) / float(len(zs))
        zs = [max(z - avg_z, 0) for z in zs]

    # Add base layer (z = 0) using the original image.  The layer is collected
    # in scan order and prepended reversed in one step, which gives the same
    # point order as inserting each point at the front, without O(n^2) inserts.
    base_xs, base_ys, base_colors = [], [], []
    for x in range(0, image.size[0], STRIDE):
        for y in range(0, image.size[1], STRIDE):
            base_colors.append(np.array(list(image.getpixel((x, y)))) / 255.0)
            base_xs.append(x)
            base_ys.append(image.size[1] - y)
    xs = base_xs[::-1] + xs
    ys = base_ys[::-1] + ys
    zs = [0] * len(base_xs) + zs
    colors = base_colors[::-1] + colors

    # The voxel depth (zs) is drawn on the plot's y axis and the image height
    # (ys) on its z axis.
    ax.scatter(xs=xs, ys=zs, zs=ys, color=colors, s=5)
    ax.set_ylim(0, 100)
    plt.axis('off')

    if save:
        # Used to crop whitespace
        bbox = fig.bbox_inches.from_bounds(0, 1, 5, 7)

        ax.view_init(elev=20, azim=40)
        plt.savefig("voxels_corner.jpg", bbox_inches=bbox)
        ax.view_init(elev=0, azim=0)
        plt.savefig("voxels_side.jpg", bbox_inches=bbox)
    else:
        ax.view_init(elev=20, azim=40)
        plt.show()
コード例 #25
0
ファイル: 3d_plot.py プロジェクト: Weibo-Hu/CFD-Post
# %% 3D PSD
# Load the streamwise locations, frequency axis and PSD table; FPSD rows are
# aligned with freq and its columns with xval (see the plotting loop below).
var = 'p'
xval = np.loadtxt(pathSL + 'FWPSD_x.dat', delimiter=' ')
freq = np.loadtxt(pathSL + 'FWPSD_freq.dat', delimiter=' ')
FPSD = np.loadtxt(pathSL + var + '_FWPSD_psd.dat', delimiter=' ')
# Drop the first frequency bin (presumably f = 0 -- TODO confirm).
freq = freq[1:]
FPSD = FPSD[1:, :]
newx = [-10.0, 2.0, 3.0, 5.0, 9.0, 10.0]

fig = plt.figure(figsize=(7.0, 4.0))
plt.rcParams['grid.color'] = 'gray'
plt.rcParams['grid.linestyle'] = 'dotted'
ax = fig.add_subplot(111, projection='3d')
# Apply tick sizing to the 3-D axes itself; calling plt.tick_params before
# the axes existed targeted a separate default 2-D axes.
ax.tick_params(labelsize=numsize)
# Stretch the projection (x by 0.8, y by 1.2).  `ax` is bound as a default
# argument so later cells that rebind the name do not affect this figure.
ax.get_proj = lambda ax=ax: np.dot(Axes3D.get_proj(ax),
                                   np.diag([0.8, 1.2, 1.0, 1.0]))
for x0 in newx:
    # Locate the stored column for this x/delta_0.  The grid is read from
    # text, so compare with a tolerance rather than exact float equality.
    matches = np.flatnonzero(np.isclose(xval, x0))
    if matches.size == 0:
        raise ValueError('x = {} not found in FWPSD_x.dat'.format(x0))
    ind = matches[0]
    xloc = x0 * np.ones(np.shape(freq))
    ax.plot(freq, xloc, FPSD[:, ind], zdir='z', linewidth=1.5)

ax.ticklabel_format(axis="z", style="sci", scilimits=(-2, 2))
ax.zaxis.offsetText.set_fontsize(numsize)
ax.set_xscale('symlog')
ax.set_xlabel(r'$f$', fontsize=textsize)
ax.set_ylabel(r'$x/\delta_0$', fontsize=textsize)
ax.set_zlabel(r'$\mathcal{P}$', fontsize=textsize)
コード例 #26
0
def func_draw(_, dargs, aw, ax):
    """Redraw one frame: the rendered line image into ``aw`` and the 3-D scene into ``ax``.

    The first argument is ignored (presumably the frame index passed by a
    matplotlib animation callback -- TODO confirm).  ``dargs`` supplies the
    camera pose (``eye``, ``r``) and all drawing options.  ``aw`` receives the
    image produced by ``uv_lines2img``.  ``ax`` receives map lines, the
    coordinate axes, the camera with its focal rectangle and footprint on the
    map, and the ``BIRD`` camera (presumably bird's-eye) with its footprint.
    Both axes are cleared first.  At the end, ``ax.get_proj`` is overridden to
    apply the product of the ``AFS`` 4x4 matrices.
    """
    eye = dargs['eye']
    r = dargs['r']
    uv_lines = dargs['uv_lines']  # NOTE(review): overwritten below before any use
    CONTEXT = dargs['CONTEXT']
    BIRD = dargs['BIRD']
    WM_LINES_MAP = dargs['WM_LINES_MAP']
    MAP_SIZE = dargs['MAP_SIZE']
    LINE_WIDTH = dargs['LINE_WIDTH']
    FIGSIZE = dargs['FIGSIZE']  # NOTE(review): not used in this function
    FIGPOS = dargs['FIGPOS']
    AFS = dargs['AFS']
    WM_LINES_MAP = dargs['WM_LINES_MAP']  # NOTE(review): duplicate of the read above
    EYE_Y = dargs['EYE_Y']
    IS_TV = dargs['IS_TV']  # NOTE(review): not used in this function
    SHOWN_MAP = dargs['SHOWN_MAP']
    SHOWN_IMG = dargs['SHOWN_IMG']
    IS_REAR = dargs['IS_REAR']
    D_EYE_CAM_F = dargs['D_EYE_CAM_F']
    D_EYE_CAM_R = dargs['D_EYE_CAM_R']

    # Front and rear camera poses.  The rear rotation is r composed with
    # theta2r([0, pi, 0]) (presumably a half-turn about y).  Each camera is
    # offset from `eye` by D_EYE_CAM_* along the third column of its rotation.
    r_f = r
    r_r = r @ theta2r(np.array([0, pi, 0]))
    eye_f = eye + r_f @ np.array([0, 0, D_EYE_CAM_F])
    eye_r = eye + r_r @ np.array([0, 0, D_EYE_CAM_R])

    # Everything below uses the rear camera when IS_REAR is set, else the front one.
    eye, r = (eye_f, r_f) if not IS_REAR else (eye_r, r_r)

    uv_lines = wm_lines2uv_lines(WM_LINES_MAP, eye, r, CONTEXT)  #  camera view

    # additional draw line
    #append_lines  = np.array([
    #np.array([[0, 0],[640, 480]]),
    #])
    #uv_lines = np.append(uv_lines, append_lines, axis=0)

    # Transformed ("TV") version of the camera-view lines, using EYE_Y and BIRD.
    tv = uv_lines2tvs_fixed(uv_lines, EYE_Y, BIRD, CONTEXT)

    # Render either the TV lines or the raw camera-view lines into the image.
    uv_lines_img = tv if SHOWN_IMG == 'TV' else uv_lines
    img = uv_lines2img(uv_lines_img, CONTEXT, LINE_WIDTH)

    # NOTE(review): aw.cla() below may reset these tick-label settings --
    # confirm they still take effect.
    aw.tick_params(labelbottom=False,
                   labelleft=False,
                   labelright=False,
                   labeltop=False)
    aw.cla()
    ax.cla()
    aw.set_position(FIGPOS[0])
    ax.set_position(FIGPOS[1])
    aw.imshow(img, cmap='gray', vmin=0, vmax=255, interpolation='none')
    ax.set_axis_off()

    # plot map
    # Choose which world-space lines to draw: back-projected TV lines,
    # back-projected camera lines, or the raw map lines.
    if SHOWN_MAP == 'TV':
        EYE_ZERO = np.array([0, BIRD[1], 0])
        R_ZERO = theta2r(np.array([pi / 2, 0, 0]))
        wm_lines = uv_lines2wm_lines(tv, EYE_ZERO, R_ZERO, CONTEXT)
        if len(wm_lines) != 0:
            wm_lines = wm_lines + np.array([BIRD[0], 0, BIRD[2]])
    elif SHOWN_MAP == 'CAMERA':
        wm_lines = uv_lines2wm_lines_fixed(uv_lines, EYE_Y, CONTEXT)
    else:
        wm_lines = WM_LINES_MAP
    if len(wm_lines) == 0:
        lines = np.array([])  # nothing to draw; the loop below is a no-op
    else:
        if SHOWN_MAP in ['CAMERA', 'TV']:
            # Rotate both endpoints of each line by r (in place), then
            # translate by the camera's x/z position.
            for i in range(len(wm_lines)):
                wm_lines[i] = np.array([
                    np.dot(r, wm_lines[i][0]),
                    np.dot(r, wm_lines[i][1]),
                ])
            wm_lines = wm_lines + np.array([eye[0], 0, eye[2]])
        # One (3, 2) array per line: x, y (forced to 0, i.e. on the ground
        # plane) and z of its two endpoints, so ax.plot(*line) draws a segment.
        lines = np.array((
            wm_lines[:, 0, 0],
            wm_lines[:, 1, 0],  # x
            [0] * len(wm_lines),
            [0] * len(wm_lines),  # y
            wm_lines[:, 0, 2],
            wm_lines[:, 1, 2],  # z
        )).T.reshape(len(wm_lines), 3, 2)
    for line in lines:
        ax.plot(*line, "-", c='red', linewidth=0.5)

    # plot axis
    # Axis lengths: x = MAP_SIZE[0], y = half the smaller map side, z = MAP_SIZE[1].
    lims = (np.array(
        (MAP_SIZE[0], min(MAP_SIZE) / 2, MAP_SIZE[1]))).astype(int)
    plot_axis(ax, lims)

    # plot eye, axis of c
    # Camera position, its foot on the y = 0 plane, and its three local axes
    # (via c2w, presumably camera-to-world) drawn with length MAP_SIZE/10.
    ax.plot(*zip(eye), "o")
    ax.plot([eye[0]], [0], [eye[2]], "o")
    l = min(MAP_SIZE) / 10
    ax.plot(*list(zip(eye, c2w(np.array([l, 0, 0]), eye, r))), '-', c='green')
    ax.plot(*list(zip(eye, c2w(np.array([0, l, 0]), eye, r))), '-', c='green')
    ax.plot(*list(zip(eye, c2w(np.array([0, 0, l]), eye, r))), '-', c='green')

    # UV_VS
    # Four image-plane vertices (indexed 0-3 below) used for both cameras.
    UV_VS = get_UV_VS(CONTEXT)

    # plot rectangle of F
    wf_vs = []  # 4 verticies of focused surface (4, 3)
    for uv in UV_VS:
        cf_v = uv2cf(uv, CONTEXT)
        wf_v = c2w(cf_v, eye, r)
        wf_vs.append(wf_v)
    plot_rectangle(ax, wf_vs)

    # plot rectangle of map, lines from eye to rectangle of map
    # 4 verticies of projected region on map
    wm_vs = np.array([uv2wm(uv, eye, r, CONTEXT) for uv in UV_VS])

    if w2c(wm_vs[3], eye, r)[2] >= 0:  # top edge of view is on map
        plot_rectangle(ax, wm_vs)
        for wm_v in wm_vs:
            ax.plot(*list(zip(eye, wm_v)), '--', c='000000')
    else:  # top edge of view is at infinity (behind of camera)
        # convert wm_vs[2'], wm_vs[3'] -> wm_vs[2], wm_vs[3]
        #        2'______________________________ 3'        |
        #          \                            /           |
        #           \                          /            |
        #            \                        /             |
        #             \        camera        /              |
        #              \         /\         /               |
        #               \      /    \      /                |
        #                \   /        \   /                 |
        #                 \/____________\/                  |
        #               - 0              1 -                |
        #            -           ||           -             |
        #         -             \||/             -          |
        #      -                 \/                 -       |
        # 3 -                  front                   - 2  |
        #                                                   |
        # Three edges of the clipped footprint are drawn (the 2-3 far edge is
        # not).  The dashed rays to vertices 2/3 go to the focal-rectangle
        # corners wf_vs[2]/wf_vs[3] rather than to the map.
        wm_vs = get_wm_vs_infinity(eye, r, CONTEXT, max(MAP_SIZE))
        ax.plot(*list(zip(wm_vs[0], wm_vs[1])), '-', c='000000')
        ax.plot(*list(zip(wm_vs[1], wm_vs[2])), '-', c='000000')
        ax.plot(*list(zip(wm_vs[3], wm_vs[0])), '-', c='000000')
        ax.plot(*list(zip(eye, wm_vs[0])), '--', c='000000')
        ax.plot(*list(zip(eye, wm_vs[1])), '--', c='000000')
        ax.plot(*list(zip(eye, wf_vs[2])), '--', c='000000')
        ax.plot(*list(zip(eye, wf_vs[3])), '--', c='000000')

    # plot eye_b
    # todo
    # Pose of the BIRD camera derived from the current camera pose.
    eye_b, r_b = eye_r2eye_r_b(eye, r, BIRD)
    ax.plot(*zip(eye_b), "o")
    ax.plot([eye_b[0]], [0], [eye_b[2]], "o")
    #ax.plot([eye_b[0], wm_vs[0][0]], [0, 0], [eye_b[2], wm_vs[0][2]])
    # plot rectangle of bird (4 verticies of bird (4, 3))
    wm_bvs = np.array([uv2wm(uv, eye_b, r_b, CONTEXT) for uv in UV_VS])
    plot_rectangle(ax, wm_bvs)
    # plot lines from eye_b to rectangle of bird
    for wm_bv in wm_bvs:
        ax.plot(*list(zip(eye_b, wm_bv)), '--', c='000000')
    # BIRD camera's local axes, twice as long as the main camera's.
    l = min(MAP_SIZE) / 5
    c = 'green'
    ax.plot(*list(zip(eye_b, c2w(np.array([l, 0, 0]), eye_b, r_b))), '-', c=c)
    ax.plot(*list(zip(eye_b, c2w(np.array([0, l, 0]), eye_b, r_b))), '-', c=c)
    ax.plot(*list(zip(eye_b, c2w(np.array([0, 0, l]), eye_b, r_b))), '-', c=c)

    # affine
    # Compose all AFS matrices (left to right, starting from identity) and
    # append the result to ax's default projection.
    AF = reduce(lambda res, x: np.dot(res, x), AFS, np.eye(4))
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), AF)
コード例 #27
0
def _plot_joint_chain(ax, xs, ys, zs, joint_ids):
    """Draw one 3-D polyline on *ax* through the joints listed in *joint_ids*."""
    ax.plot([xs[j] for j in joint_ids],
            [ys[j] for j in joint_ids],
            [zs[j] for j in joint_ids])


def main():
    """Lift one hard-coded 16-joint 2-D pose to 3-D with a pretrained SemGCN and plot it.

    Steps:
    - Load the Human3.6M 3-D dataset for its skeleton and camera parameters.
    - Build a SemGCN from the skeleton adjacency and restore a checkpoint on CPU.
    - Normalize the 2-D keypoints with camera ``['S1'][0]``'s resolution.
    - Predict root-relative 3-D joints and rotate them into world coordinates.
    - Show a 2-D scatter of the normalized input and a 3-D skeleton plot.
    """
    dataset_path = "./data/data_3d_h36m.npz"    # load data
    from common.h36m_dataset import Human36mDataset
    dataset = Human36mDataset(dataset_path)
    dataset = read_3d_data(dataset)
    cudnn.benchmark = True
    device = torch.device("cpu")
    from models.sem_gcn import SemGCN
    from common.graph_utils import adj_mx_from_skeleton
    p_dropout = None
    adj = adj_mx_from_skeleton(dataset.skeleton())
    model_pos = SemGCN(adj, 128, num_layers=4, p_dropout=p_dropout,
                       nodes_group=dataset.skeleton().joints_group()).to(device)
    ckpt_path = "./checkpoint/pretrained/ckpt_semgcn_nonlocal_sh.pth.tar"
    ckpt = torch.load(ckpt_path, map_location='cpu')
    model_pos.load_state_dict(ckpt['state_dict'], False)  # strict=False
    model_pos.eval()

    # ============ Added code ==============
    # One person's 2-D keypoints, as output by the project's 2-D processing code.
    inputs_2d = [[483.0, 450], [503, 450], [503, 539], [496, 622], [469, 450], [462, 546], [469, 622], [483, 347],
                 [483, 326], [489, 264], [448, 347], [448, 408], [441, 463], [517, 347], [524, 408], [538, 463]]

    # # Detection result of the OpenPose test sample
    # inputs_2d = [[86.0, 137], [99, 128], [94, 127], [97, 110], [89, 105], [102, 129], [116, 116], [99, 110],
    #              [105, 93], [117, 69], [147, 63], [104, 93], [89, 69], [82, 38], [89, 139], [94, 140]]

    inputs_2d = np.array(inputs_2d)
    # To make the pose upright when the source data is upside-down:
    # inputs_2d[:, 1] = np.max(inputs_2d[:, 1]) - inputs_2d[:, 1]

    cam = dataset.cameras()['S1'][0]    # camera parameters
    # Normalize the 2-D coordinates to the camera's screen resolution.
    inputs_2d[..., :2] = normalize_screen_coordinates(inputs_2d[..., :2], w=cam['res_w'], h=cam['res_h'])

    # Scatter the normalized 2-D keypoints, labelled with their joint index.
    print(inputs_2d)    # normalized 2-D keypoint coordinates
    d_x = inputs_2d[:, 0]
    d_y = inputs_2d[:, 1]
    plt.figure()
    plt.scatter(d_x, d_y)
    for i in range(inputs_2d.shape[0]):
        plt.annotate(i, (d_x[i], d_y[i]))

    # Predict the 3-D joints.
    inputs_2d = torch.tensor(inputs_2d, dtype=torch.float32)
    # NOTE(review): the indexing below treats the output as (batch, joints, 3);
    # confirm SemGCN accepts this un-batched (16, 2) input.
    with torch.no_grad():
        outputs_3d = model_pos(inputs_2d).cpu()
    # Remove global offset (make joints relative to joint 0).  Done out of
    # place: the in-place form reads from memory it is writing to.
    outputs_3d = outputs_3d - outputs_3d[:, :1, :]
    prediction = outputs_3d.detach().numpy()[0]     # first (only) sample
    # Invert camera transformation.  The choice of R and t matters little
    # here; it depends on the selected camera (some subjects have no t).
    prediction = camera_to_world(prediction, R=cam['orientation'], t=0)
    # Shift z up by its minimum so all z coordinates are non-negative.
    prediction[:, 2] -= np.min(prediction[:, 2])
    print('prediction')
    print(prediction)   # 3-D coordinates to be plotted
    plt.figure()
    ax = plt.subplot(111, projection='3d')
    o_x = prediction[:, 0]
    o_y = prediction[:, 1]
    o_z = prediction[:, 2]
    print(o_x)
    print(o_y)
    print(o_z)
    ax.scatter(o_x, o_y, o_z)

    # Skeleton polylines as sequences of joint indices (drawn in this order).
    skeleton_chains = (
        (9, 8, 7, 10, 11, 12),
        (7, 0, 4, 5, 6),
        (0, 1, 2, 3),
        (7, 13, 14, 15),
    )
    for chain in skeleton_chains:
        _plot_joint_chain(ax, o_x, o_y, o_z, chain)

    # Ticks every 0.5; scaling x and y by 0.5 makes the z axis appear twice
    # as long as the others.
    from matplotlib.ticker import MultipleLocator
    major_locator = MultipleLocator(0.5)
    ax.xaxis.set_major_locator(major_locator)
    ax.yaxis.set_major_locator(major_locator)
    ax.zaxis.set_major_locator(major_locator)
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([0.5, 0.5, 1, 1]))

    plt.show()
コード例 #28
0
def _velocity_line_stride(stepsize: float, s: np.ndarray) -> int:
    """Return the index stride that spaces velocity lines about *stepsize* metres apart.

    *s* is the trajectory's distance column; its first spacing ``s[1] - s[0]``
    is taken as the (assumed constant) step.  The result is at least 1.
    """
    return max(int(np.round(stepsize / (s[1] - s[0]))), 1)


def result_plots(plot_opts: dict, w_veh_opt: float, w_veh_real: float,
                 refline: np.ndarray, bound1: np.ndarray, bound2: np.ndarray,
                 trajectory: np.ndarray) -> None:
    """
    Created by:
    Alexander Heilmeier

    Documentation:
    This function plots several figures containing relevant trajectory information after trajectory optimization.

    Inputs:
    plot_opts:      dict containing the information which figures to plot ("raceline", "raceline_curv",
                    "racetraj_vel_3d" (+ "racetraj_vel_3d_stepsize" in m), "spline_normals")
    w_veh_opt:      vehicle width used during optimization
    w_veh_real:     real vehicle width
    refline:        contains the reference line coordinates [x_m, y_m]
    bound1:         first track boundary (mostly right) [x_m, y_m]
    bound2:         second track boundary (mostly left) [x_m, y_m]
    trajectory:     trajectory data [s_m, x_m, y_m, psi_rad, kappa_radpm, vx_mps, ax_mps2]
    """

    if plot_opts["raceline"]:
        # calculate vehicle boundary points (including safety margin in vehicle width)
        normvec_normalized_opt = trajectory_planning_helpers.calc_normal_vectors.\
            calc_normal_vectors(trajectory[:, 3])

        traj_xy = trajectory[:, 1:3]
        offset_opt = normvec_normalized_opt * w_veh_opt / 2
        offset_real = normvec_normalized_opt * w_veh_real / 2

        veh_bound1_virt = traj_xy + offset_opt
        veh_bound2_virt = traj_xy - offset_opt
        veh_bound1_real = traj_xy + offset_real
        veh_bound2_real = traj_xy - offset_real

        # Arrow from the 1st to the 4th reference point shows driving direction.
        point1_arrow = refline[0]
        point2_arrow = refline[3]
        vec_arrow = point2_arrow - point1_arrow

        # plot track including optimized path
        plt.figure()
        plt.plot(refline[:, 0], refline[:, 1], "k--", linewidth=0.7)
        for bound in (veh_bound1_virt, veh_bound2_virt):
            plt.plot(bound[:, 0], bound[:, 1], "b", linewidth=0.5)
        for bound in (veh_bound1_real, veh_bound2_real):
            plt.plot(bound[:, 0], bound[:, 1], "c", linewidth=0.5)
        plt.plot(bound1[:, 0], bound1[:, 1], "k-", linewidth=0.7)
        plt.plot(bound2[:, 0], bound2[:, 1], "k-", linewidth=0.7)
        plt.plot(trajectory[:, 1], trajectory[:, 2], "r-", linewidth=0.7)
        plt.grid()
        ax = plt.gca()
        ax.arrow(point1_arrow[0],
                 point1_arrow[1],
                 vec_arrow[0],
                 vec_arrow[1],
                 head_width=7.0,
                 head_length=7.0,
                 fc='g',
                 ec='g')
        ax.set_aspect("equal", "datalim")
        plt.xlabel("east in m")
        plt.ylabel("north in m")
        plt.show()

    if plot_opts["raceline_curv"]:
        # plot curvature profile
        plt.figure()
        plt.plot(trajectory[:, 0], trajectory[:, 4])
        plt.grid()
        plt.xlabel("distance in m")
        plt.ylabel("curvature in rad/m")
        plt.show()

    if plot_opts["racetraj_vel_3d"]:
        scale_x = 1.0
        scale_y = 1.0
        scale_z = 0.3  # scale z axis such that it does not appear stretched

        # create 3d plot (Figure.gca(projection=...) was removed in
        # Matplotlib 3.6; add_subplot works on old and new versions)
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        # recast get_proj function to use scaling factors for the axes
        ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax),
                                     np.diag([scale_x, scale_y, scale_z, 1.0]))

        # plot raceline and boundaries
        ax.plot(refline[:, 0], refline[:, 1], "k--", linewidth=0.7)
        ax.plot(bound1[:, 0], bound1[:, 1], 0.0, "k-", linewidth=0.7)
        ax.plot(bound2[:, 0], bound2[:, 1], 0.0, "k-", linewidth=0.7)
        ax.plot(trajectory[:, 1], trajectory[:, 2], "r-", linewidth=0.7)

        ax.grid()
        try:
            ax.set_aspect("equal")
        except NotImplementedError:
            # Matplotlib 3.1-3.5 reject 'equal' on 3D axes; keep the default aspect.
            pass
        ax.set_xlabel("east in m")
        ax.set_ylabel("north in m")

        # plot velocity profile in 3D
        ax.plot(trajectory[:, 1],
                trajectory[:, 2],
                trajectory[:, 5],
                color="k")
        ax.set_zlabel("velocity in m/s")

        # plot vertical lines visualizing acceleration and deceleration zones
        ind_stepsize = _velocity_line_stride(plot_opts["racetraj_vel_3d_stepsize"],
                                             trajectory[:, 0])

        for cur_ind in range(0, trajectory.shape[0] - 1, ind_stepsize):
            x_cur = trajectory[cur_ind, 1]
            y_cur = trajectory[cur_ind, 2]
            # line height equals the velocity at this point
            z_tmp = [0.0, trajectory[cur_ind, 5]]

            # color depends on acceleration: green = accelerating,
            # red = decelerating, gray = constant speed
            acc = trajectory[cur_ind, 6]
            if acc > 0.0:
                col = "g"
            elif acc < 0.0:
                col = "r"
            else:
                col = "gray"

            ax.plot([x_cur, x_cur], [y_cur, y_cur], z_tmp, color=col)

        plt.show()

    if plot_opts["spline_normals"]:
        # plot normals as segments connecting corresponding boundary points
        plt.figure()
        for p1, p2 in zip(bound1, bound2):
            temp = np.vstack((p1, p2))
            plt.plot(temp[:, 0], temp[:, 1], "k-", linewidth=0.7)
        plt.grid()
        ax = plt.gca()
        ax.set_aspect("equal", "datalim")
        plt.xlabel("east in m")
        plt.ylabel("north in m")
        plt.show()
コード例 #29
0
 def short_proj():
     """Return the default ``Axes3D`` projection of ``ax`` multiplied (``np.dot``) by ``scale``."""
     projection = Axes3D.get_proj(ax)
     return np.dot(projection, scale)
コード例 #30
0
import sys
from typing import Iterator, List, Tuple, Union

import matplotlib.pyplot as plt  # type: ignore
import numpy as np  # type: ignore
from matplotlib import animation  # type: ignore
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D  # type: ignore

# Card values: the ace counts as 1 (see hand_value) and the three face cards count 10.
CARDS = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10])


def _scaled_proj(axes, aspect):
    """Return a ``get_proj`` replacement that post-multiplies *axes*' projection by ``np.diag(aspect)``.

    The axes object is captured as an argument, so the projection stays tied
    to *axes* even if the module-level name it was assigned to is rebound.
    """
    scale = np.diag(aspect)
    return lambda: np.dot(Axes3D.get_proj(axes), scale)


fig = plt.figure()
ace_ax = fig.add_subplot(1, 2, 2, projection="3d")
ace_ax.get_proj = _scaled_proj(ace_ax, [1.2, 1.3, 0.2, 1])

ax = fig.add_subplot(1, 2, 1, projection="3d")
ax.get_proj = _scaled_proj(ax, [1.3, 1.3, 0.2, 1])


def hand_value(cards: np.ndarray) -> Tuple[int, bool, bool]:
    """return the hand value, blackjack boolean, usable ace boolean"""

    if 1 in cards:
        if len(cards) == 2 and 10 in cards:
            return 21, False, True  # 21 and blackjack
        if cards.sum() - 1 > 10:
            return cards.sum(), False, False

        if cards.sum() + 10 == 21:
            return 21, False, True  # non blackjack 21, usable ace