コード例 #1
0
ファイル: test_mplot3d.py プロジェクト: magnunor/matplotlib
def test_lines_dists():
    """Draw point-to-segment distances from ``proj3d.line2d_seg_dist``.

    A segment from (0, 20) to (30, 150) is plotted together with four
    scattered points. Each point gets an unfilled circle whose radius is its
    distance to the segment, so with an equal aspect ratio every circle
    should just touch the segment. The single-point (scalar) call form is
    also checked against the vectorized form.
    """
    fig, ax = plt.subplots(figsize=(4, 6), subplot_kw=dict(aspect='equal'))

    # The reference segment, given as its two (x, y) endpoints.
    xs = (0, 30)
    ys = (20, 150)
    ax.plot(xs, ys)
    p0, p1 = zip(xs, ys)

    # Query points.
    xs = (0, 0, 20, 30)
    ys = (100, 150, 30, 200)
    ax.scatter(xs, ys)

    # One point at a time exercises the scalar code path; it must agree with
    # the first entry of the vectorized call (third argument is 2 x N).
    single = proj3d.line2d_seg_dist(p0, p1, (xs[0], ys[0]))
    dist = proj3d.line2d_seg_dist(p0, p1, np.array((xs, ys)))
    np.testing.assert_allclose(single, dist[0])

    for x, y, d in zip(xs, ys, dist):
        c = Circle((x, y), d, fill=False)
        ax.add_patch(c)

    ax.set_xlim(-50, 150)
    ax.set_ylim(0, 300)
コード例 #2
0
ファイル: test_mplot3d.py プロジェクト: wangtaogh/matplotlib
def test_lines_dists():
    """Draw point-to-segment distances from ``proj3d.line2d_seg_dist``.

    A segment from (0, 20) to (30, 150) is plotted together with four
    scattered points. Each point gets an unfilled circle whose radius is its
    distance to the segment, so with an equal aspect ratio every circle
    should just touch the segment. The single-point (scalar) call form is
    also checked against the vectorized form.
    """
    fig, ax = plt.subplots(figsize=(4, 6), subplot_kw=dict(aspect='equal'))

    # The reference segment, given as its two (x, y) endpoints.
    xs = (0, 30)
    ys = (20, 150)
    ax.plot(xs, ys)
    p0, p1 = zip(xs, ys)

    # Query points.
    xs = (0, 0, 20, 30)
    ys = (100, 150, 30, 200)
    ax.scatter(xs, ys)

    # One point at a time exercises the scalar code path; it must agree with
    # the first entry of the vectorized call (third argument is 2 x N).
    single = proj3d.line2d_seg_dist(p0, p1, (xs[0], ys[0]))
    dist = proj3d.line2d_seg_dist(p0, p1, np.array((xs, ys)))
    np.testing.assert_allclose(single, dist[0])

    for x, y, d in zip(xs, ys, dist):
        c = Circle((x, y), d, fill=False)
        ax.add_patch(c)

    ax.set_xlim(-50, 150)
    ax.set_ylim(0, 300)
コード例 #3
0
def get_3d_point(xd, yd, ax):
    """Estimate the 3D data point under a 2D position on a 3D axes.

    A 2D position carries no depth, so the depth is borrowed from the
    nearest of the edges returned by ``ax.tunit_edges()`` (the axes'
    projected bounding-box edges): z is interpolated linearly between that
    edge's two endpoints according to the position's distance to each. The
    triple ``(xd, yd, z)`` is then mapped back with
    ``proj3d.inv_transform`` using the axes' projection matrix ``ax.M``.

    Parameters
    ----------
    xd, yd : float
        Position in the axes' projected 2D coordinates (presumably
        ``event.xdata`` / ``event.ydata`` of a mouse event -- confirm with
        callers).
    ax : Axes3D
        A 3D axes providing ``tunit_edges()`` and ``M``.

    Returns
    -------
    The result of ``proj3d.inv_transform(xd, yd, z, ax.M)``, i.e. the
    x, y, z data coordinates.
    """
    p = (xd, yd)
    edges = ax.tunit_edges()
    # Nearest edge. Comparing (distance, index) pairs breaks ties towards the
    # lowest index; a single min() pass is enough, no need to sort.
    _, edgei = min((proj3d.line2d_seg_dist(p0, p1, p), i)
                   for i, (p0, p1) in enumerate(edges))
    (x0, y0, z0), (x1, y1, z1) = edges[edgei]

    # Scale the z value to match: the closer endpoint weighs more.
    d0 = np.hypot(x0 - xd, y0 - yd)
    d1 = np.hypot(x1 - xd, y1 - yd)
    dt = d0 + d1
    if dt == 0:
        # Both endpoints project exactly onto p, so the weights would be
        # 0/0; use the mean depth instead of propagating NaN.
        z = (z0 + z1) / 2
    else:
        z = d1 / dt * z0 + d0 / dt * z1
    return proj3d.inv_transform(xd, yd, z, ax.M)
コード例 #4
0
ファイル: cPlot3D.py プロジェクト: nabihach/bayesact
    def getCoord(self, iPosX, iPosY):
        """Estimate the 3D data point under a 2D position on ``self.m_Axes``.

        The depth of a 2D position is unknown, so it is borrowed from the
        nearest of the edges returned by ``self.m_Axes.tunit_edges()``: z is
        interpolated linearly between that edge's endpoints according to the
        position's distance to each, then ``(iPosX, iPosY, z)`` is mapped
        back with ``proj3d.inv_transform`` and the axes' matrix ``M``.

        :param iPosX: x of the position in the axes' projected 2D
            coordinates (presumably ``event.xdata`` of a mouse event --
            confirm with callers).
        :param iPosY: y of the position, in the same coordinates.
        :return: tuple ``(x, y, z)`` of data coordinates.
        """
        point = (iPosX, iPosY)
        edges = self.m_Axes.tunit_edges()

        # Nearest edge; (distance, index) pairs break ties towards the
        # lowest index, and one min() pass replaces a full sort.
        _, edgei = min((proj3d.line2d_seg_dist(p0, p1, point), i)
                       for i, (p0, p1) in enumerate(edges))
        (x0, y0, z0), (x1, y1, z1) = edges[edgei]

        # Scale the z value to match: the closer endpoint weighs more.
        d0 = NP.hypot(x0 - iPosX, y0 - iPosY)
        d1 = NP.hypot(x1 - iPosX, y1 - iPosY)
        dt = d0 + d1
        if dt == 0:
            # Both endpoints project onto the point: avoid 0/0 -> NaN.
            z = (z0 + z1) / 2
        else:
            z = d1 / dt * z0 + d0 / dt * z1

        x, y, z = proj3d.inv_transform(iPosX, iPosY, z, self.m_Axes.M)
        return (x, y, z)
コード例 #5
0
ファイル: cPlot3D.py プロジェクト: ahcombs/inteRact
    def getCoord(self, iPosX, iPosY):
        """Estimate the 3D data point under a 2D position on ``self.m_Axes``.

        The depth of a 2D position is unknown, so it is borrowed from the
        nearest of the edges returned by ``self.m_Axes.tunit_edges()``: z is
        interpolated linearly between that edge's endpoints according to the
        position's distance to each, then ``(iPosX, iPosY, z)`` is mapped
        back with ``proj3d.inv_transform`` and the axes' matrix ``M``.

        :param iPosX: x of the position in the axes' projected 2D
            coordinates (presumably ``event.xdata`` of a mouse event --
            confirm with callers).
        :param iPosY: y of the position, in the same coordinates.
        :return: tuple ``(x, y, z)`` of data coordinates.
        """
        point = (iPosX, iPosY)
        edges = self.m_Axes.tunit_edges()

        # Nearest edge; (distance, index) pairs break ties towards the
        # lowest index, and one min() pass replaces a full sort.
        _, edgei = min((proj3d.line2d_seg_dist(p0, p1, point), i)
                       for i, (p0, p1) in enumerate(edges))
        (x0, y0, z0), (x1, y1, z1) = edges[edgei]

        # Scale the z value to match: the closer endpoint weighs more.
        d0 = NP.hypot(x0 - iPosX, y0 - iPosY)
        d1 = NP.hypot(x1 - iPosX, y1 - iPosY)
        dt = d0 + d1
        if dt == 0:
            # Both endpoints project onto the point: avoid 0/0 -> NaN.
            z = (z0 + z1) / 2
        else:
            z = d1 / dt * z0 + d0 / dt * z1

        x, y, z = proj3d.inv_transform(iPosX, iPosY, z, self.m_Axes.M)
        return (x, y, z)
コード例 #6
0
ファイル: mpl.py プロジェクト: rballester/ttrecipes
    def update(event, labelsize=labelsize):
        """Apply a mouse click to the focus point and redraw every panel.

        This is a closure: it reads and mutates state owned by the enclosing
        function (``focus``, ``state``, ``lines``, ``vlines``, ``images``,
        ``points``, ``surfaces``, ``gt_markers``, ``plot_fillings``, ...).

        If ``event`` is given, ``focus`` is updated first:

        * click outside any axes: return immediately, nothing is redrawn;
        * ``axmax`` / ``axmin`` / ``axrand`` / ``axclosest`` buttons: jump to
          the estimated maximum, the estimated minimum, a random index per
          mode, or the row of ``Xs`` closest to the current focus;
        * double click on a panel: the left button jumps to the maximum and
          the right button to the minimum of the order-2 centered moments
          tensor (``tr.core.moments``) over that panel's modes; any other
          button leaves the focus unchanged;
        * single click on a "plot", "image" or "surface" panel: the focus
          moves to the tick(s) nearest to the clicked coordinates.

        Every panel in ``dims`` / ``diagrams`` is then created (first call)
        or updated (later calls), with ground-truth markers when
        ``gt_range >= 0``, and ``fig`` is titled with the focus coordinates
        and the tensor value there.

        :param event: Matplotlib mouse event, or None to only redraw.
        :param labelsize: font size of the axis labels.
        """
        if event is not None:
            if event.inaxes is None:  # The user clicked outside any axis
                return
            if event.inaxes == axmax:  # The user clicked a button axis
                for n in range(N):
                    focus[n] = estimated_maximum_point[n]
            elif event.inaxes == axmin:
                for n in range(N):
                    focus[n] = estimated_minimum_point[n]
            elif event.inaxes == axrand:
                for n in range(N):
                    focus[n] = np.random.randint(t.n[n])
            elif event.inaxes == axclosest:
                closest = Xs[
                    np.argmin(np.sum(
                        (Xs - focus[np.newaxis, :])**2, axis=1)), :]
                for n in range(N):
                    focus[n] = closest[n]
            else:  # The user clicked a plot axis
                # Find which subplot was clicked
                clicked_axis = np.where(np.asarray(ax) == event.inaxes)[0][0]
                clicked_dims = dims[clicked_axis]
                if event.dblclick:
                    # A double click brings us to the subspace with maximal
                    # variance. Only the left (1) and right (3) buttons pick
                    # an extremum; other buttons leave the focus unchanged.
                    if event.button in (1, 3):
                        tmoments = tr.core.moments(t,
                                                   modes=clicked_dims,
                                                   order=2,
                                                   centered=True,
                                                   normalized=False,
                                                   keepdims=True,
                                                   eps=1e-2,
                                                   rmax=5,
                                                   verbose=False)
                        if event.button == 1:  # Left double click: maximum
                            _, point = tr.core.maximize(tmoments, rmax=5)
                        else:  # Right double click: minimum
                            _, point = tr.core.minimize(tmoments, rmax=5)
                        for n in range(N):
                            focus[n] = point[n]
                # Otherwise: closest axis sample to the clicked position
                elif diagrams[clicked_axis] == "plot":
                    focus[clicked_dims[0]] = (np.abs(
                        ticks_list[clicked_dims[0]] - event.xdata)).argmin()
                elif diagrams[clicked_axis] == "image":
                    new_x = (np.abs(ticks_list[clicked_dims[0]] -
                                    event.xdata)).argmin()
                    new_y = (np.abs(ticks_list[clicked_dims[1]] -
                                    event.ydata)).argmin()
                    focus[clicked_dims[0]] = new_x
                    focus[clicked_dims[1]] = new_y
                elif diagrams[clicked_axis] == "surface":
                    from mpl_toolkits.mplot3d import proj3d
                    # Recover the data point under the cursor: estimate the
                    # depth from the nearest projected bounding-box edge,
                    # then un-project with the axes' projection matrix.
                    axis3d = ax[clicked_axis]
                    xd, yd = event.xdata, event.ydata
                    p = (xd, yd)
                    edges = axis3d.tunit_edges()
                    # Nearest edge (ties go to the lowest index)
                    _, edgei = min((proj3d.line2d_seg_dist(p0, p1, p), k)
                                   for k, (p0, p1) in enumerate(edges))
                    (x0, y0, z0), (x1, y1, z1) = edges[edgei]

                    # Scale the z value to match: closer endpoint weighs more
                    d0 = np.hypot(x0 - xd, y0 - yd)
                    d1 = np.hypot(x1 - xd, y1 - yd)
                    dt = d0 + d1
                    if dt == 0:  # Both endpoints under the cursor: no 0/0
                        z = (z0 + z1) / 2
                    else:
                        z = d1 / dt * z0 + d0 / dt * z1

                    x, y, _ = proj3d.inv_transform(xd, yd, z, axis3d.M)
                    new_x = (np.abs(ticks_list[clicked_dims[0]] -
                                    x)).argmin()
                    new_y = (np.abs(ticks_list[clicked_dims[1]] -
                                    y)).argmin()
                    focus[clicked_dims[0]] = new_x
                    focus[clicked_dims[1]] = new_y

        def _nearby_samples(i, index):
            """Select ground-truth samples within ``gt_range`` of the focus.

            Distances are measured between ``positions_partial[i]`` and the
            focus with panel ``i``'s own modes (``index``) removed. Returns
            the boolean row mask, the selected rows of ``Xs`` cast to int,
            and the selected distances.
            """
            point_partial = np.delete(focus, index)
            dists = np.sqrt(
                np.sum(np.square(positions_partial[i] - point_partial),
                       axis=1))
            near = dists <= gt_range
            return near, np.asarray(Xs).astype(int)[near], dists[near]

        def _gt_colors(near_dists, spread):
            """Blue (0, 0, 0.6) RGBA colors with a Gaussian alpha.

            The alpha is exp(-d^2 / (2 * sigma^2)) with
            sigma = gt_range / spread. A float32 epsilon keeps the
            denominator nonzero when ``gt_range`` is 0. Returns
            ``(colors, alpha)``.
            """
            alpha = np.exp(-np.square(near_dists) /
                           (2 * np.square(gt_range / spread) +
                            np.finfo(np.float32).eps))
            colors = np.repeat(np.array([[0, 0, 0.6, 0]]),
                               len(near_dists),
                               axis=0)
            colors[:, 3] = alpha
            return colors, alpha

        for i, index in enumerate(dims):
            # Slice through the focus: free along this panel's modes, fixed
            # at the focus everywhere else.
            subspace = list(focus)
            for dim in index:
                subspace[dim] = slice(None)
            values = t[subspace].full()
            if diagrams[i] == "plot":
                xticks = ticks_list[index[0]]
                focus_x = xticks[focus[index[0]]]
                # Y range: the estimated value range widened by 10% around
                # its center; its bottom is also the fill_between baseline.
                a = estimated_minimum
                b = estimated_maximum
                half = (b - a) / 2 * 1.1
                y_floor = (a + b) / 2 - half
                if not state['initialized']:
                    # Fiber plots
                    lines[i], = ax[i].plot(xticks, values, linewidth=2)
                    ax[i].set_xlabel(names[index[0]], fontsize=labelsize)
                    ax[i].set_ylabel(output_name, fontsize=labelsize)
                    ax[i].set_xlim([xticks[0], xticks[-1]])
                    ax[i].set_ylim([y_floor, (a + b) / 2 + half])
                    # Vertical line marking the focus
                    vlines[i] = ax[i].axvline(x=focus_x,
                                              ymin=0,
                                              ymax=1,
                                              linewidth=5,
                                              color='red',
                                              alpha=0.5)
                else:
                    lines[i].set_ydata(values)
                    ax[i].draw_artist(ax[i].patch)
                    if gt_range >= 0:
                        gt_markers[i].remove()
                    plot_fillings[i].remove()
                    # set_xdata expects a sequence (scalars are deprecated
                    # since Matplotlib 3.7); an axvline has two equal x's.
                    vlines[i].set_xdata([focus_x, focus_x])
                if gt_range >= 0:
                    # Show ground-truth points not far from this fiber
                    near, samples, near_dists = _nearby_samples(i, index)
                    plot_x = np.asarray(xticks)[samples[:, index[0]]]
                    plot_y = np.asarray(ys)[near]
                    rgba_colors, dist_score = _gt_colors(near_dists, 6)
                    gt_markers[i] = ax[i].scatter(plot_x,
                                                  plot_y,
                                                  s=50 * dist_score,
                                                  c=rgba_colors,
                                                  linewidths=0)
                plot_fillings[i] = ax[i].fill_between(
                    lines[i].get_xdata(),
                    y_floor,
                    values,
                    alpha=0.1,
                    interpolate=True,
                    color='blue')
            elif diagrams[i] == "image":
                xticks = ticks_list[index[0]]
                yticks = ticks_list[index[1]]
                # Line2D data must be sequences, hence one-element lists
                focus_x = [xticks[focus[index[0]]]]
                focus_y = [yticks[focus[index[1]]]]
                if not state['initialized']:
                    images[i] = ax[i].imshow(
                        values.T,
                        cmap='pink',
                        origin="lower",
                        vmin=estimated_minimum,
                        vmax=estimated_maximum,
                        aspect='auto',
                        extent=[xticks[0], xticks[-1], yticks[0], yticks[-1]])
                    ax[i].set_xlim([xticks[0], xticks[-1]])
                    ax[i].set_ylim([yticks[0], yticks[-1]])
                    ax[i].set_xlabel(names[index[0]], fontsize=labelsize)
                    ax[i].set_ylabel(names[index[1]], fontsize=labelsize)
                    points[i] = ax[i].plot(focus_x, focus_y, 'o', color='red')
                else:
                    # Image 2D plot, using an AxisImage
                    images[i].set_data(values.T)
                    # Point marker for the image
                    points[i][0].set_data(focus_x, focus_y)
                    if gt_range >= 0:
                        gt_markers[i].remove()
                if gt_range >= 0:
                    # Show ground-truth points not far from this image
                    _, samples, near_dists = _nearby_samples(i, index)
                    plot_x = np.asarray(xticks)[samples[:, index[0]]]
                    plot_y = np.asarray(yticks)[samples[:, index[1]]]
                    rgba_colors, _ = _gt_colors(near_dists, 5)
                    gt_markers[i] = ax[i].scatter(plot_x,
                                                  plot_y,
                                                  s=50,
                                                  c=rgba_colors,
                                                  linewidths=0)
            elif diagrams[i] == "surface":
                xticks = ticks_list[index[0]]
                yticks = ticks_list[index[1]]
                grid_x, grid_y = np.meshgrid(xticks, yticks)
                focus_x = [xticks[focus[index[0]]]]
                focus_y = [yticks[focus[index[1]]]]
                if state['initialized']:
                    # The surface is redrawn from scratch: drop the old one
                    surfaces[i].remove()
                surfaces[i] = ax[i].plot_surface(
                    grid_x,
                    grid_y,
                    values.T,
                    cmap='YlOrBr',
                    vmin=estimated_minimum,
                    vmax=estimated_maximum,
                    cstride=10,
                    rstride=10)
                if not state['initialized']:
                    ax[i].set_xlabel(names[index[0]],
                                     fontsize=labelsize,
                                     labelpad=15)
                    ax[i].set_ylabel(names[index[1]],
                                     fontsize=labelsize,
                                     labelpad=15)
                    ax[i].set_zlabel(output_name)
                    ax[i].xaxis.set_rotate_label(False)
                    ax[i].yaxis.set_rotate_label(False)
                    ax[i].set_zlim([estimated_minimum, estimated_maximum])
                    points[i] = ax[i].plot(focus_x,
                                           focus_y, [t[focus]],
                                           marker='o',
                                           color='red',
                                           markeredgecolor='black')
                else:
                    points[i][0].set_data(focus_x, focus_y)
                    points[i][0].set_3d_properties([t[focus]])
                    if gt_range >= 0:
                        gt_markers[i].remove()
                if gt_range >= 0:
                    # Show ground-truth points not far from this surface
                    near, samples, near_dists = _nearby_samples(i, index)
                    plot_x = np.asarray(xticks)[samples[:, index[0]]]
                    plot_y = np.asarray(yticks)[samples[:, index[1]]]
                    plot_z = np.asarray(ys)[near]
                    rgba_colors, _ = _gt_colors(near_dists, 5)
                    gt_markers[i] = ax[i].scatter(plot_x,
                                                  plot_y,
                                                  plot_z,
                                                  s=50,
                                                  c=rgba_colors,
                                                  linewidths=0,
                                                  depthshade=False)
        state['initialized'] = True
        point_values = tr.core.indices_to_coordinates(np.asarray([focus]),
                                                      ticks_list)[0]
        point_info = "(" + ", ".join(
            ["{:.3f}".format(point_values[i])
             for i in range(N)]) + ") -> {:.4f}".format(t[focus])
        # Title this navigation figure explicitly rather than whichever
        # figure pyplot currently considers active.
        fig.suptitle(point_info)
        fig.canvas.draw_idle()