コード例 #1
0
        def get_xyz_mouse_click(event, ax):
            """
            Convert a mouse event on a 3D axes into data-space coordinates.

            The event position (``event.xdata``, ``event.ydata``) is matched to
            the nearest edge returned by ``ax.tunit_edges()``.  A depth value is
            interpolated between that edge's two endpoints, weighted by the 2D
            distance of the position to each endpoint, and the resulting triple
            is passed through ``proj3d.inv_transform`` together with ``ax.M``.

            Returns an ``(x, y, z)`` tuple, or an empty dict when no position
            can be computed: ``ax.M`` is None, or the event carries no
            ``xdata``/``ydata``.  The empty dict is the "no result" value this
            function already used; unpacking it into three names raises
            ValueError.
            """
            if ax.M is None:
                return {}

            xd, yd = event.xdata, event.ydata
            if xd is None or yd is None:
                # Without a position the distance arithmetic below would fail
                # on None; report "no result" the same way as above instead.
                return {}

            click = (xd, yd)
            edges = ax.tunit_edges()
            # nearest edge: smallest (distance, index) pair
            _, nearest = min(
                (proj3d._line2d_seg_dist(p0, p1, click), i)
                for i, (p0, p1) in enumerate(edges))
            p0, p1 = edges[nearest]

            # scale the z value to match: the endpoint closer to the click
            # contributes more of its depth
            x0, y0, z0 = p0
            x1, y1, z1 = p1
            d0 = sqrt(pow(x0 - xd, 2) + pow(y0 - yd, 2))
            d1 = sqrt(pow(x1 - xd, 2) + pow(y1 - yd, 2))
            dt = d0 + d1
            z = d1 / dt * z0 + d0 / dt * z1
            x, y, z = proj3d.inv_transform(xd, yd, z, ax.M)
            return x, y, z
コード例 #2
0
def test_proj_transform():
    """Check that ``proj3d.inv_transform`` undoes ``proj3d.proj_transform``.

    Ten points whose coordinates are each 0 or 300 are projected with the
    matrix from ``_test_proj_make_M`` and then mapped back; every recovered
    coordinate array must match its input under ``assert_almost_equal``.
    """
    proj_matrix = _test_proj_make_M()
    side = 300.0

    xs = side * np.array([0, 1, 1, 0, 0, 0, 1, 1, 0, 0])
    ys = side * np.array([0, 0, 1, 1, 0, 0, 0, 1, 1, 0])
    zs = side * np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    # forward projection followed by its inverse
    txs, tys, tzs = proj3d.proj_transform(xs, ys, zs, proj_matrix)
    ixs, iys, izs = proj3d.inv_transform(txs, tys, tzs, proj_matrix)

    for recovered, expected in ((ixs, xs), (iys, ys), (izs, zs)):
        np.testing.assert_almost_equal(recovered, expected)
コード例 #3
0
ファイル: test_mplot3d.py プロジェクト: magnunor/matplotlib
def test_proj_transform():
    """Round-trip test for the ``proj3d`` projection helpers.

    Projects ten points (coordinates 0 or 300 on each axis) with
    ``proj3d.proj_transform`` using the matrix built by
    ``_test_proj_make_M``, inverts the result with ``proj3d.inv_transform``
    and asserts that x, y and z each come back almost equal to the input.
    """
    M = _test_proj_make_M()

    scale = 300.0
    original = {
        'x': np.array([0, 1, 1, 0, 0, 0, 1, 1, 0, 0]) * scale,
        'y': np.array([0, 0, 1, 1, 0, 0, 0, 1, 1, 0]) * scale,
        'z': np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) * scale,
    }

    px, py, pz = proj3d.proj_transform(
        original['x'], original['y'], original['z'], M)
    rx, ry, rz = proj3d.inv_transform(px, py, pz, M)

    np.testing.assert_almost_equal(rx, original['x'])
    np.testing.assert_almost_equal(ry, original['y'])
    np.testing.assert_almost_equal(rz, original['z'])
コード例 #4
0
 def onPick(event):
     """
     Pick handler: record the picked 3D points and annotate the first one.

     Does nothing unless ``event.artist`` is a ``Path3DCollection`` and
     ``matplotlib_version`` compares not greater than ``'3.0.3'``.

     ``self.datapoint`` is reset to None, then, when ``event.ind`` is not
     empty, to a list with one ``[index, x, y, z]`` row per picked index
     (left empty when ``self.d3`` is false).  In the 3D case ``self.msg`` is
     built from the first row and the axis labels, any previous ``ax.label``
     annotation is removed, and a new annotation is attached at the
     projected position of that point.  The canvas is then redrawn.
     """
     # just 3D picking for the moment
     if not isinstance(event.artist, Path3DCollection):
         return
     # NOTE(review): plain string comparison, i.e. lexicographic ordering of
     # the version text -- confirm matplotlib_version is a version string.
     if matplotlib_version > '3.0.3':
         return

     self.datapoint = None
     if len(event.ind) == 0:
         return

     self.datapoint = []
     if self.d3:
         xs, ys, zs = event.artist._offsets3d  # 2021-05-21
         for n in event.ind:
             self.datapoint.append([n, xs[n], ys[n], zs[n]])

         index, x, y, z = self.datapoint[0]
         self.msg = '{:d}: {}: {:.2f}\n{}: {:.2f}\n{}: {:.2f}'.format(
             index, ax.get_xlabel(),
             x, ax.get_ylabel(),
             y, ax.get_zlabel(),
             z)

         # If we have previously displayed another label, remove it first
         if hasattr(ax, 'label'):
             try:
                 ax.label.remove()
             except Exception:
                 # Best effort: a label that cannot be removed must not
                 # prevent the new one from being shown.
                 pass

         # The annotation is placed in the axes' 2D (projected) coordinates,
         # so the data-space point has to go through the forward projection;
         # inv_transform maps the other way (projected -> data).
         x2, y2, _ = proj3d.proj_transform(x, y, z, ax.get_proj())
         ax.label = ax.annotate(self.msg,
                                xy=(x2, y2),
                                xytext=(0, 20),
                                textcoords='offset points',
                                ha='right',
                                va='bottom',
                                bbox=dict(boxstyle='round,pad=0.5',
                                          alpha=0.5),
                                arrowprops=dict(
                                    arrowstyle='->',
                                    connectionstyle='arc3,rad=0'))
     set_flex()
     ax.figure.canvas.draw()
コード例 #5
0
def get_3d_point(xd, yd, ax):
    """Map a 2D position on a 3D axes to a 3D point.

    The edge from ``ax.tunit_edges()`` with the smallest
    ``proj3d.line2d_seg_dist`` to ``(xd, yd)`` is selected (ties resolved by
    edge index).  A z value is interpolated between the edge's endpoints,
    weighted by the 2D distance from the position to each endpoint, and
    ``proj3d.inv_transform(xd, yd, z, ax.M)`` is returned.
    """
    edges = ax.tunit_edges()
    target = (xd, yd)

    # rank edges by (distance to target, index) and keep the closest one
    ranked = sorted(
        (proj3d.line2d_seg_dist(start, end, target), idx)
        for idx, (start, end) in enumerate(edges))
    closest_idx = ranked[0][1]
    start, end = edges[closest_idx]

    # scale the z value to match: the nearer endpoint gets the larger weight
    x0, y0, z0 = start
    x1, y1, z1 = end
    dist_start = np.hypot(x0 - xd, y0 - yd)
    dist_end = np.hypot(x1 - xd, y1 - yd)
    total = dist_start + dist_end
    depth = dist_end / total * z0 + dist_start / total * z1

    return proj3d.inv_transform(xd, yd, depth, ax.M)
コード例 #6
0
ファイル: cPlot3D.py プロジェクト: nabihach/bayesact
    def getCoord(self, iPosX, iPosY):
        """Return the ``(x, y, z)`` point for the 2D position (iPosX, iPosY).

        Picks the edge of ``self.m_Axes.tunit_edges()`` closest to the
        position according to ``proj3d.line2d_seg_dist``, interpolates a z
        value between that edge's endpoints by distance, and converts the
        triple with ``proj3d.inv_transform`` and ``self.m_Axes.M``.
        """
        cursor = (iPosX, iPosY)
        edges = self.m_Axes.tunit_edges()

        # order edges by (distance to cursor, index); the first is nearest
        by_distance = sorted(
            (proj3d.line2d_seg_dist(a, b, cursor), n)
            for n, (a, b) in enumerate(edges))
        a, b = edges[by_distance[0][1]]

        # scale the z value to match
        ax0, ay0, az0 = a
        bx1, by1, bz1 = b
        dist_a = NP.hypot(ax0 - iPosX, ay0 - iPosY)
        dist_b = NP.hypot(bx1 - iPosX, by1 - iPosY)
        dist_sum = dist_a + dist_b
        depth = dist_b / dist_sum * az0 + dist_a / dist_sum * bz1

        x, y, z = proj3d.inv_transform(iPosX, iPosY, depth, self.m_Axes.M)
        return (x, y, z)
コード例 #7
0
ファイル: cPlot3D.py プロジェクト: ahcombs/inteRact
    def getCoord(self, iPosX, iPosY):
        """Convert a 2D axes position into an ``(x, y, z)`` tuple.

        The nearest edge from ``self.m_Axes.tunit_edges()`` (measured with
        ``proj3d.line2d_seg_dist``) supplies two endpoints; their z values
        are blended according to the position's distance to each endpoint,
        and the result is mapped through ``proj3d.inv_transform`` using
        ``self.m_Axes.M``.
        """
        position = (iPosX, iPosY)
        unit_edges = self.m_Axes.tunit_edges()

        candidates = []
        for edge_no, (first, second) in enumerate(unit_edges):
            distance = proj3d.line2d_seg_dist(first, second, position)
            candidates.append((distance, edge_no))
        candidates.sort()

        # nearest edge
        nearest_no = candidates[0][1]
        first, second = unit_edges[nearest_no]

        # scale the z value to match
        fx, fy, fz = first
        sx, sy, sz = second
        to_first = NP.hypot(fx - iPosX, fy - iPosY)
        to_second = NP.hypot(sx - iPosX, sy - iPosY)
        both = to_first + to_second
        z_mix = to_second / both * fz + to_first / both * sz

        x, y, z = proj3d.inv_transform(iPosX, iPosY, z_mix, self.m_Axes.M)
        return (x, y, z)
コード例 #8
0
    def _format_coord_3d(self, xd, yd):
        """Format the hover coordinates of the Axes3D plot with relabelled axes.

        The 3D point under ``(xd, yd)`` is recovered from the nearest edge of
        ``self._plt3d.tunit_edges()``.  In the returned text the value shown
        as ``x`` is the recovered y coordinate, ``y`` is the recovered z
        coordinate and ``z`` is the recovered x coordinate.
        """
        cursor = (xd, yd)

        def distance_to_cursor(edge):
            return self._line2d_seg_dist(edge[0], edge[1], cursor)

        # nearest edge
        start, end = min(self._plt3d.tunit_edges(), key=distance_to_cursor)

        # scale the z value to match
        sx, sy, sz = start
        ex, ey, ez = end
        dist_start = np.hypot(sx - xd, sy - yd)
        dist_end = np.hypot(ex - xd, ey - yd)
        dist_total = dist_start + dist_end
        depth = dist_end / dist_total * sz + dist_start / dist_total * ez

        from mpl_toolkits.mplot3d import proj3d
        x, y, z = proj3d.inv_transform(xd, yd, depth, self._plt3d.M)

        # swap coordinates: each label is formatted from a different axis value
        shown_x = self._plt3d.format_xdata(y)
        shown_y = self._plt3d.format_ydata(z)
        shown_z = self._plt3d.format_zdata(x)
        return 'x=%s, y=%s, z=%s' % (shown_x, shown_y, shown_z)
コード例 #9
0
ファイル: mpl.py プロジェクト: rballester/ttrecipes
    def update(event, labelsize=labelsize):
        """
        Update the focus point from a mouse event and redraw all diagrams.

        Parameters
        ----------
        event
            Mouse event, or ``None`` to redraw without moving the focus.
            A click outside every axes returns immediately.  A click on
            ``axmax`` / ``axmin`` / ``axrand`` / ``axclosest`` moves the
            focus to the estimated maximum, the estimated minimum, a random
            index, or the row of ``Xs`` closest to the current focus.  A
            single click on a diagram moves the focus along that diagram's
            dimensions to the tick nearest the click.  A left / right double
            click on a diagram jumps to the point returned by
            ``tr.core.maximize`` / ``tr.core.minimize`` on the second-order
            centered moments over that diagram's dimensions; a double click
            with any other button leaves the focus unchanged.
        labelsize
            Font size used for the axis labels set on the first draw.

        While ``state['initialized']`` is false the artists are created;
        afterwards they are updated or replaced.  Finally the figure title is
        set to the focus coordinates and value and a redraw is requested.
        """
        if event is not None:
            if event.inaxes is None:  # The user clicked outside any axis
                return
            if event.inaxes == axmax:  # The user clicked a button axis
                for n in range(N):
                    focus[n] = estimated_maximum_point[n]
            elif event.inaxes == axmin:
                for n in range(N):
                    focus[n] = estimated_minimum_point[n]
            elif event.inaxes == axrand:
                for n in range(N):
                    focus[n] = np.random.randint(t.n[n])
            elif event.inaxes == axclosest:
                closest = Xs[
                    np.argmin(np.sum(
                        (Xs - focus[np.newaxis, :])**2, axis=1)), :]
                for n in range(N):
                    focus[n] = closest[n]
            else:  # The user clicked a plot axis
                # Find which subplot was clicked
                clicked_axis = np.where(np.asarray(ax) == event.inaxes)[0][0]
                if event.dblclick:  # A double click brings us to the subspace with maximal variance
                    if event.button == 1:  # Left double click: find maximum
                        optimize = tr.core.maximize
                    elif event.button == 3:  # Right double click: find minimum
                        optimize = tr.core.minimize
                    else:
                        # Other buttons have no action; without this branch
                        # `point` below would be unbound.
                        optimize = None
                    if optimize is not None:
                        tmoments = tr.core.moments(t,
                                                   modes=dims[clicked_axis],
                                                   order=2,
                                                   centered=True,
                                                   normalized=False,
                                                   keepdims=True,
                                                   eps=1e-2,
                                                   rmax=5,
                                                   verbose=False)
                        _, point = optimize(tmoments, rmax=5)
                        for n in range(N):
                            focus[n] = point[n]
                # Closest axis sample to the clicked focus
                elif diagrams[clicked_axis] == "plot":
                    new_x = (np.abs(ticks_list[dims[clicked_axis][0]] -
                                    event.xdata)).argmin()
                    focus[dims[clicked_axis][0]] = new_x
                elif diagrams[clicked_axis] == "image":
                    new_x = (np.abs(ticks_list[dims[clicked_axis][0]] -
                                    event.xdata)).argmin()
                    new_y = (np.abs(ticks_list[dims[clicked_axis][1]] -
                                    event.ydata)).argmin()
                    focus[dims[clicked_axis][0]] = new_x
                    focus[dims[clicked_axis][1]] = new_y
                elif diagrams[clicked_axis] == "surface":
                    from mpl_toolkits.mplot3d import proj3d
                    # Recover data-space x/y from the 2D click: take the
                    # nearest projected edge, interpolate a depth along it,
                    # and invert the projection.
                    xd, yd = event.xdata, event.ydata
                    p = (xd, yd)
                    edges = ax[clicked_axis].tunit_edges()
                    ldists = [(proj3d.line2d_seg_dist(p0, p1, p), i)
                              for i, (p0, p1) in enumerate(edges)]
                    ldists.sort()
                    # nearest edge
                    edgei = ldists[0][1]

                    p0, p1 = edges[edgei]

                    # scale the z value to match
                    x0, y0, z0 = p0
                    x1, y1, z1 = p1
                    d0 = np.hypot(x0 - xd, y0 - yd)
                    d1 = np.hypot(x1 - xd, y1 - yd)
                    dt = d0 + d1
                    z = d1 / dt * z0 + d0 / dt * z1

                    x, y, _ = proj3d.inv_transform(xd, yd, z,
                                                   ax[clicked_axis].M)
                    new_x = (np.abs(ticks_list[dims[clicked_axis][0]] -
                                    x)).argmin()
                    new_y = (np.abs(ticks_list[dims[clicked_axis][1]] -
                                    y)).argmin()
                    focus[dims[clicked_axis][0]] = new_x
                    focus[dims[clicked_axis][1]] = new_y
        for i, index in enumerate(dims):
            # Slice of the tensor through the focus along this diagram's dims
            subspace = list(focus)
            for dim in index:
                subspace[dim] = slice(None)
            values = t[subspace].full()
            if diagrams[i] == "plot":
                a = estimated_minimum
                b = estimated_maximum
                if not state['initialized']:
                    # Fiber plots
                    lines[i], = ax[i].plot(ticks_list[index[0]],
                                           values,
                                           linewidth=2)
                    ax[i].set_xlabel(names[index[0]], fontsize=labelsize)
                    ax[i].set_ylabel(output_name, fontsize=labelsize)
                    ax[i].set_xlim(
                        [ticks_list[index[0]][0], ticks_list[index[0]][-1]])
                    ax[i].set_ylim([(a + b) / 2 - (b - a) / 2 * 1.1,
                                    (a + b) / 2 + (b - a) / 2 * 1.1])
                    # Vertical lines marking the focus
                    vlines[i] = ax[i].axvline(
                        x=ticks_list[index[0]][focus[index[0]]],
                        ymin=0,
                        ymax=1,
                        linewidth=5,
                        color='red',
                        alpha=0.5)
                else:
                    lines[i].set_ydata(values)
                    ax[i].draw_artist(ax[i].patch)
                    if gt_range >= 0:
                        gt_markers[i].remove()
                    plot_fillings[i].remove()
                    # set_xdata expects a sequence; the vertical line has two
                    # points sharing the same x.
                    focus_x = ticks_list[index[0]][focus[index[0]]]
                    vlines[i].set_xdata([focus_x, focus_x])
                if gt_range >= 0:
                    # Detect and show ground-truth points that are not far from this fiber
                    # gt_range = np.ceil(s.shape[index[0]]*gt_factor)
                    point_partial = np.delete(focus, index)
                    dists = np.sqrt(
                        np.sum(np.square(positions_partial[i] - point_partial),
                               axis=1))
                    plot_x = np.asarray(ticks_list[index[0]])[np.asarray(
                        Xs).astype(int)[dists <= gt_range, index[0]]]
                    plot_y = np.asarray(ys)[dists <= gt_range]
                    rgba_colors = np.repeat(np.array([[0, 0, 0.6, 0]]),
                                            len(plot_x),
                                            axis=0)
                    dist_score = np.exp(-np.square(dists[dists <= gt_range]) /
                                        (2 * np.square(gt_range / 6) +
                                         np.finfo(np.float32).eps))
                    # Closer points are more opaque and larger
                    rgba_colors[:, 3] = dist_score
                    gt_markers[i] = ax[i].scatter(plot_x,
                                                  plot_y,
                                                  s=50 * dist_score,
                                                  c=rgba_colors,
                                                  linewidths=0)
                plot_fillings[i] = ax[i].fill_between(
                    lines[i].get_xdata(), (a + b) / 2 - (b - a) / 2 * 1.1,
                    values,
                    alpha=0.1,
                    interpolate=True,
                    color='blue')
            elif diagrams[i] == "image":
                if not state['initialized']:
                    # Colormap given by name: accepted by imshow directly and
                    # avoids the deprecated matplotlib.cm.get_cmap lookup.
                    images[i] = ax[i].imshow(
                        values.T,
                        cmap='pink',
                        origin="lower",
                        vmin=estimated_minimum,
                        vmax=estimated_maximum,
                        aspect='auto',
                        extent=[
                            ticks_list[index[0]][0], ticks_list[index[0]][-1],
                            ticks_list[index[1]][0], ticks_list[index[1]][-1]
                        ])
                    ax[i].set_xlim(
                        [ticks_list[index[0]][0], ticks_list[index[0]][-1]])
                    ax[i].set_ylim(
                        [ticks_list[index[1]][0], ticks_list[index[1]][-1]])
                    ax[i].set_xlabel(names[index[0]], fontsize=labelsize)
                    ax[i].set_ylabel(names[index[1]], fontsize=labelsize)
                    points[i] = ax[i].plot(
                        [ticks_list[index[0]][focus[index[0]]]],
                        [ticks_list[index[1]][focus[index[1]]]],
                        'o',
                        color='red')
                else:
                    # Image 2D plot, using an AxisImage
                    images[i].set_data(values.T)
                    # Point marker for the image (both coordinates passed as
                    # one-element sequences)
                    points[i][0].set_data(
                        [ticks_list[index[0]][focus[index[0]]]],
                        [ticks_list[index[1]][focus[index[1]]]])
                    if gt_range >= 0:
                        gt_markers[i].remove()
                if gt_range >= 0:
                    # Detect and show ground-truth points that are not far
                    # from this image
                    point_partial = np.delete(focus, index)
                    dists = np.sqrt(
                        np.sum(np.square(positions_partial[i] - point_partial),
                               axis=1))
                    plot_x = np.asarray(ticks_list[index[0]])[np.asarray(
                        Xs).astype(int)[dists <= gt_range, index[0]]]

                    plot_y = np.asarray(ticks_list[index[1]])[np.asarray(
                        Xs).astype(int)[dists <= gt_range, index[1]]]
                    rgba_colors = np.repeat(np.array([[0, 0, 0.6, 0]]),
                                            len(plot_x),
                                            axis=0)
                    rgba_colors[:, 3] = np.exp(
                        -np.square(dists[dists <= gt_range]) /
                        (2 * np.square(gt_range / 5) +
                         np.finfo(np.float32).eps))
                    gt_markers[i] = ax[i].scatter(plot_x,
                                                  plot_y,
                                                  s=50,
                                                  c=rgba_colors,
                                                  linewidths=0)
            elif diagrams[i] == "surface":
                # The surface is rebuilt on every call; on updates the
                # previous one is dropped first.
                if state['initialized']:
                    surfaces[i].remove()
                x, y = np.meshgrid(ticks_list[index[0]],
                                   ticks_list[index[1]])
                surfaces[i] = ax[i].plot_surface(
                    x,
                    y,
                    values.T,
                    cmap='YlOrBr',
                    vmin=estimated_minimum,
                    vmax=estimated_maximum,
                    cstride=10,
                    rstride=10)
                if not state['initialized']:
                    ax[i].set_xlabel(names[index[0]],
                                     fontsize=labelsize,
                                     labelpad=15)
                    ax[i].set_ylabel(names[index[1]],
                                     fontsize=labelsize,
                                     labelpad=15)
                    ax[i].set_zlabel(output_name)
                    ax[i].xaxis.set_rotate_label(False)
                    ax[i].yaxis.set_rotate_label(False)
                    # ax[i].zaxis.set_rotate_label(False)
                    ax[i].set_zlim([estimated_minimum, estimated_maximum])
                    points[i] = ax[i].plot(
                        [ticks_list[index[0]][focus[index[0]]]],
                        [ticks_list[index[1]][focus[index[1]]]], [t[focus]],
                        marker='o',
                        color='red',
                        markeredgecolor='black')
                else:
                    points[i][0].set_data(
                        [ticks_list[index[0]][focus[index[0]]]],
                        [ticks_list[index[1]][focus[index[1]]]])
                    points[i][0].set_3d_properties([t[focus]])
                    if gt_range >= 0:
                        gt_markers[i].remove()
                if gt_range >= 0:
                    # Detect and show ground-truth points that are not far
                    # from this image
                    point_partial = np.delete(focus, index)
                    dists = np.sqrt(
                        np.sum(np.square(positions_partial[i] - point_partial),
                               axis=1))
                    plot_x = np.asarray(ticks_list[index[0]])[np.asarray(
                        Xs).astype(int)[dists <= gt_range, index[0]]]
                    plot_y = np.asarray(ticks_list[index[1]])[np.asarray(
                        Xs).astype(int)[dists <= gt_range, index[1]]]
                    plot_z = np.asarray(ys)[dists <= gt_range]
                    rgba_colors = np.repeat(np.array([[0, 0, 0.6, 0]]),
                                            len(plot_x),
                                            axis=0)
                    rgba_colors[:, 3] = np.exp(
                        -np.square(dists[dists <= gt_range]) /
                        (2 * np.square(gt_range / 5) +
                         np.finfo(np.float32).eps))
                    gt_markers[i] = ax[i].scatter(plot_x,
                                                  plot_y,
                                                  plot_z,
                                                  s=50,
                                                  c=rgba_colors,
                                                  linewidths=0,
                                                  depthshade=False)
        state['initialized'] = True
        # Title: focus coordinates and the tensor value at the focus
        point_values = tr.core.indices_to_coordinates(np.asarray([focus]),
                                                      ticks_list)[0]
        point_info = "(" + ", ".join(
            ["{:.3f}".format(point_values[n])
             for n in range(N)]) + ") -> {:.4f}".format(t[focus])
        plt.suptitle("{}".format(point_info))
        # fig.canvas.update()
        fig.canvas.draw_idle()
コード例 #10
0
        def onPress(event):
            """
            Mouse-press handler.

            On the first call the current x/y limits are stored as the base
            limits, and ``self.d3`` records whether ``ax.get_zlim()`` could
            be called.

            When ``self.d3`` is set and ``matplotlib_version`` compares
            greater than ``'3.0.3'``, the click is converted with
            ``get_xyz_mouse_click``; the point is stored in
            ``self.datapoint`` (index -1), ``self.msg`` is built from it and
            the axis labels, any previous ``ax.label`` annotation is removed
            and a new one is attached at the projected position of the
            point.  If the conversion fails the handler returns silently.

            Otherwise: button 3 clears ``self.month`` / ``self.week`` and
            restores the base x limits; any other press inside ``ax`` stores
            the current limits and the press coordinate for ``self.axis``.
            """
            if self.base_xlim is None:
                self.base_xlim = ax.get_xlim()
                self.base_ylim = ax.get_ylim()
                try:
                    self.base_zlim = ax.get_zlim()
                except Exception:
                    # No usable z limits: treat the axes as 2D
                    self.d3 = False
                else:
                    self.d3 = True

            # NOTE(review): plain string comparison of the version text.
            if self.d3 and matplotlib_version > '3.0.3':
                try:
                    x, y, z = get_xyz_mouse_click(event, ax)
                except Exception:
                    # No coordinates could be derived from this click
                    return
                self.datapoint = [[-1, x, y, z]]
                self.msg = '{:d}: {}: {:.2f}\n{}: {:.2f}\n{}: {:.2f}'.format(
                    self.datapoint[0][0], ax.get_xlabel(),
                    self.datapoint[0][1], ax.get_ylabel(),
                    self.datapoint[0][2], ax.get_zlabel(),
                    self.datapoint[0][3])
                # If we have previously displayed another label, remove it first
                if hasattr(ax, 'label'):
                    try:
                        ax.label.remove()
                    except Exception:
                        # Best effort: a label that cannot be removed must
                        # not prevent the new one from being shown.
                        pass
                # The annotation is placed in the axes' 2D (projected)
                # coordinates, so the data-space point goes through the
                # forward projection; inv_transform maps the other way.
                x2, y2, _ = proj3d.proj_transform(self.datapoint[0][1],
                                                  self.datapoint[0][2],
                                                  self.datapoint[0][3],
                                                  ax.get_proj())
                ax.label = ax.annotate(self.msg,
                                       xy=(x2, y2),
                                       xytext=(0, 20),
                                       textcoords='offset points',
                                       ha='right',
                                       va='bottom',
                                       bbox=dict(boxstyle='round,pad=0.5',
                                                 alpha=0.5),
                                       arrowprops=dict(
                                           arrowstyle='->',
                                           connectionstyle='arc3,rad=0'))
                return

            if event.button == 3:  # reset?
                self.month = None
                self.week = None
                if self.base_xlim is not None:
                    ax.set_xlim(self.base_xlim)
                    set_flex()
                    ax.figure.canvas.draw()
                    return
            if event.inaxes != ax:
                return
            if self.axis == 'x':
                self.cur_xlim = ax.get_xlim()
                self.press = event.xdata
            elif self.axis == 'y':
                self.cur_ylim = ax.get_ylim()
                self.press = event.ydata
            elif self.axis == 'z':
                self.cur_zlim = ax.get_zlim()
                # NOTE(review): relies on the event carrying a `zdata`
                # attribute -- confirm where it is set.
                self.press = event.zdata