Ejemplo n.º 1
0
 def _change_map_size(self, figure: Figure, factor):
     """Scale ``figure`` by ``factor`` and resize the hosting Tk widgets to match.

     The figure's size in inches is always scaled.  The embedded matplotlib
     widget, the canvas window item and the scroll region are only updated
     (and the figure redrawn) when both ``self.mpl_canvas`` and ``self.canvas``
     are truthy.
     """
     scaled_inches = [factor * dim for dim in figure.get_size_inches()]
     figure.set_size_inches(scaled_inches)
     # New size in pixels, read back from the figure after resizing.
     width_px, height_px = [dim * figure.dpi for dim in figure.get_size_inches()]
     if not (self.mpl_canvas and self.canvas):
         return
     self.mpl_canvas.config(width=width_px, height=height_px)
     self.canvas.itemconfigure(self.cwid, width=width_px, height=height_px)
     # Scroll region follows the enlarged content; the canvas widget itself
     # is requested at a fixed 200x200.
     self.canvas.config(scrollregion=self.canvas.bbox(Tkconstants.ALL),
                        width=200,
                        height=200)
     figure.canvas.draw()
Ejemplo n.º 2
0
def get_Img_Matrix_Antialiasing(point_pos, w, x_size=128, y_size=128):
    """Render a transformed polyline as an anti-aliased single-channel image.

    :param point_pos: point coordinates; they are transformed with
        ``np.dot(w, point_pos)`` and columns 0 / 1 of the result are plotted
        as x / y.
    :param w: matrix left-multiplied onto ``point_pos``.
    :param x_size: output width in pixels.
    :param y_size: output height in pixels.
    :return: ``uint8`` array of shape ``(height, width)`` holding channel 0
        (red) of the rendering: a white line over a figure whose background
        is black.
    """
    point_pos = np.dot(w, point_pos)
    mydpi = 96
    fig = Figure(figsize=(x_size / mydpi, y_size / mydpi), dpi=mydpi)
    # Axes fill the whole figure, so data coordinates map directly to pixels.
    fig.subplots_adjust(left=0., bottom=0., right=1., top=1., wspace=0, hspace=0)
    fig.set_facecolor('black')
    # ``axisbg`` was removed in matplotlib 2.2; ``facecolor`` replaces it.
    ax = fig.add_subplot(111, facecolor='r')
    canvas = FigureCanvas(fig)

    # y range inverted (y_size at the bottom) to follow image row order.
    ax.axis([0, x_size, y_size, 0])
    ax.axis('off')
    ax.plot(point_pos[:, 0], point_pos[:, 1], c='w')
    canvas.draw()  # render so the Agg pixel buffer is populated

    # Read the Agg buffer directly: ``tostring_rgb`` and binary-mode
    # ``np.fromstring`` are deprecated, and the buffer carries its own shape.
    # Channel 0 of RGBA is the same red channel the RGB string contained.
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[:, :, 0].copy()
Ejemplo n.º 3
0
    def my_draw(self):
        """Render the playground with the agent window and first target outlined.

        Used by DQL_visualization_actions.py to make a video from a sequence
        of actions.  Boxes are read as ``(x0, y0, x1, y1)``: the agent window
        is outlined in red, ``self.targets[0]`` in blue.

        :return: ``uint8`` RGB array of shape ``(height, width, 3)``.
        """
        fig = Figure()
        canvas = FigureCanvas(fig)
        ax = fig.gca()

        ax.imshow(self.image_playground)

        agent = self.agent_window
        ax.add_patch(patches.Rectangle((agent[0], agent[1]),
                                       agent[2] - agent[0],
                                       agent[3] - agent[1],
                                       linewidth=1,
                                       edgecolor='r',
                                       facecolor='none'))

        # Only the first target is drawn.
        target = self.targets[0]
        ax.add_patch(patches.Rectangle((target[0], target[1]),
                                       target[2] - target[0],
                                       target[3] - target[1],
                                       linewidth=1,
                                       edgecolor='b',
                                       facecolor='none'))

        canvas.draw()
        # Read the Agg buffer directly (``tostring_rgb``/``np.fromstring`` are
        # deprecated); its shape is authoritative, and dropping alpha gives RGB.
        rgba = np.asarray(canvas.buffer_rgba())
        return rgba[:, :, :3].copy()
Ejemplo n.º 4
0
    def my_draw(self):
        """Render the playground with the agent window and first target outlined.

        Boxes are read as ``(x0, y0, x1, y1)``: the agent window is outlined
        in red, ``self.targets[0]`` in blue.

        :return: ``uint8`` RGB array of shape ``(height, width, 3)``.
        """
        from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
        from matplotlib.figure import Figure

        fig = Figure()
        canvas = FigureCanvas(fig)
        ax = fig.gca()

        ax.imshow(self.image_playground)

        # Agent window.
        agent = self.agent_window
        ax.add_patch(patches.Rectangle((agent[0], agent[1]),
                                       agent[2] - agent[0],
                                       agent[3] - agent[1],
                                       linewidth=1,
                                       edgecolor='r',
                                       facecolor='none'))

        # Bounding box of the first target object (the only one drawn).
        target = self.targets[0]
        ax.add_patch(patches.Rectangle((target[0], target[1]),
                                       target[2] - target[0],
                                       target[3] - target[1],
                                       linewidth=1,
                                       edgecolor='b',
                                       facecolor='none'))

        canvas.draw()  # render so the Agg pixel buffer is populated
        # ``tostring_rgb``/``np.fromstring`` are deprecated; read the buffer.
        rgba = np.asarray(canvas.buffer_rgba())
        return rgba[:, :, :3].copy()
Ejemplo n.º 5
0
def visualize_cos(feat, labels, epoch):
    """Plot row-normalised 2-D features and class weight directions for one epoch.

    Features are divided by their row-wise L2 norm.  Each class ``0``-``9``
    is scattered in its own colour and annotated at its mean.  The matching
    row of ``head``'s (transposed, row-normalised) weight matrix is drawn as
    a segment from the origin.  The figure is saved as
    ``<imgDir>/cos_epoch=<epoch>.jpg``; the directory is created if needed.
    It is also logged to ``writer`` under the tag ``'NormFace_cos'``.

    Relies on module-level ``head``, ``imgDir``, ``writer`` and ``transforms``.

    :param feat: feature tensor with two columns.
    :param labels: class-label tensor compared against ``0``-``9``.
    :param epoch: epoch number for the caption, file name and log step.
    """
    import os

    colors = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
              '#ff00ff', '#990000', '#999900', '#009900', '#009999']

    fig = Figure(figsize=(6, 6), dpi=100)
    canvas = FigureCanvas(fig)
    ax = fig.gca()

    feat = feat / feat.norm(2, 1).unsqueeze(1).repeat(1, 2)
    feat = feat.data.cpu().numpy()
    labels = labels.data.cpu().numpy()

    weight = head.state_dict()['weight'].t()
    weight = weight / weight.norm(2, 1).unsqueeze(1).repeat(1, 2)
    weight = weight.data.cpu().numpy()

    for cls, color in enumerate(colors):
        mask = labels == cls
        ax.scatter(feat[mask, 0], feat[mask, 1], c=color, s=1)
        ax.text(feat[mask, 0].mean(), feat[mask, 1].mean(), str(cls), color='black', fontsize=12)
        ax.plot([0, weight[cls][0]], [0, weight[cls][1]], linewidth=2, color=color)
    ax.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc='upper right')
    ax.text(0, 0, "epoch=%d" % epoch)
    canvas.draw()
    # Copy the frame before savefig, which re-renders through the same canvas.
    img = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()

    os.makedirs(imgDir, exist_ok=True)
    fig.savefig(imgDir + '/cos_epoch=%d.jpg' % epoch)

    timg = transforms.ToTensor()(img)
    writer.add_image('NormFace_cos', timg, epoch)
def plot_as_image(masking: "Masking") -> "Array":
    """
    Plot layer wise density as bar plot figure.

    :param masking: Masking instance
    :type masking: sparselearning.core.Masking
    :return: Numpy array representing figure (H, W, 3)
    :rtype: np.ndarray
    """
    fig = Figure()
    canvas = FigureCanvas(fig)
    ax = fig.gca()

    density_ll = _get_density_ll(masking)
    # Layers are numbered from 1 along the x axis.
    bin_ll = np.arange(len(density_ll)) + 1
    bar_width = 0.8

    ax.bar(bin_ll, density_ll, bar_width, color="b")

    ax.set_ylabel("Density")
    ax.set_xlabel("Layer Number")

    canvas.draw()  # render so the Agg pixel buffer is populated

    # ``tostring_rgb``/``np.fromstring`` are deprecated; the buffer also
    # carries its own (H, W, 4) shape, so no inches*dpi rounding is involved.
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[:, :, :3].copy()
Ejemplo n.º 7
0
def visualize(feat, labels, epoch):
    """Scatter 2-D features by class for one epoch, save the plot and log it.

    Each class ``0``-``9`` gets its own colour and a text label at its mean.
    The figure is saved to ``<imgDir>/epoch=<epoch>.jpg``; the directory is
    created if needed.  It is also logged to ``writer`` under the tag
    ``'NormFace'``.

    Relies on module-level ``imgDir``, ``writer`` and ``transforms``.

    :param feat: feature tensor whose columns 0 and 1 are plotted as x / y.
    :param labels: class-label tensor compared against ``0``-``9``.
    :param epoch: epoch number for the caption, file name and log step.
    """
    colors = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
              '#ff00ff', '#990000', '#999900', '#009900', '#009999']

    fig = Figure(figsize=(6, 6), dpi=100)
    canvas = FigureCanvas(fig)
    ax = fig.gca()

    feat = feat.data.cpu().numpy()
    labels = labels.data.cpu().numpy()

    for cls, color in enumerate(colors):
        mask = labels == cls
        ax.scatter(feat[mask, 0], feat[mask, 1], c=color, s=1)
        ax.text(feat[mask, 0].mean(), feat[mask, 1].mean(), str(cls), color='black', fontsize=12)
    ax.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc='upper right')
    ax.text(0, 0, "epoch=%d" % epoch)
    canvas.draw()
    # Copy the frame before savefig, which re-renders through the same canvas.
    img = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()

    # exist_ok avoids the check-then-create race of exists()/makedirs().
    os.makedirs(imgDir, exist_ok=True)
    fig.savefig(imgDir + '/epoch=%d.jpg' % epoch)

    timg = transforms.ToTensor()(img)
    writer.add_image('NormFace', timg, epoch)
def getPlotImage(data_x,
                 data_y,
                 cols,
                 title,
                 line_labels,
                 x_label,
                 y_label,
                 ylim=None,
                 legend=1):
    """Plot several series over a shared x axis and return the plot as an RGB array.

    :param data_x: x values shared by every series.
    :param data_y: sequence of y series; one line is drawn per element.
    :param cols: one ``(r, g, b)`` colour per series, components in 0-255.
    :param title: axes title, drawn at font size 10.
    :param line_labels: one label per series; only read when ``legend`` is truthy.
    :param x_label: x axis label.
    :param y_label: y axis label.
    :param ylim: optional ``(bottom, top)`` forwarded to ``Axes.set_ylim``.
    :param legend: when truthy, label the lines and draw a legend.
    :return: ``uint8`` array of shape ``(height, width, 3)`` for a 4.8 x 2.7 in
        figure rendered at 400 dpi.
    """
    # matplotlib expects colour components in [0, 1].
    cols = [(col[0] / 255.0, col[1] / 255.0, col[2] / 255.0) for col in cols]

    fig = Figure(figsize=(4.8, 2.7), dpi=400)
    fig.subplots_adjust(bottom=0.17, right=0.95)
    canvas = FigureCanvas(fig)
    ax = fig.gca()

    for i, datum_y in enumerate(data_y):
        args = {'color': cols[i]}
        if legend:
            args['label'] = line_labels[i]
        ax.plot(data_x, datum_y, **args)

    # Pass the title size locally; assigning plt.rcParams['axes.titlesize']
    # here would change the default for every later plot in the process.
    ax.set_title(title, fontsize=10)
    if legend:
        ax.legend(fancybox=True, framealpha=0.1)
    ax.grid(True)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    if ylim is not None:
        ax.set_ylim(*ylim)

    canvas.draw()
    # ``tostring_rgb``/``np.fromstring`` are deprecated; read the Agg buffer.
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[:, :, :3].copy()
Ejemplo n.º 9
0
    def get_virtual_drone():
        """Render the drone position (and velocity arrow, if known) in 3-D.

        Reads ``d_pos`` and ``d_vel`` from an enclosing or module scope.  Each
        is either ``None`` or indexable by ``'x'``, ``'y'`` and ``'z'``.  The
        velocity arrow is only drawn when a position is also known.

        :return: ``uint8`` RGB array of shape ``(height, width, 3)`` for a
            4 x 3 in figure at 200 dpi.
        """
        fig_scale = 1
        fig = Figure(figsize=(4 * fig_scale, 3 * fig_scale), dpi=200)
        canvas = FigureCanvas(fig)
        # ``fig.gca(projection=...)`` was removed in matplotlib 3.6.
        ax = fig.add_subplot(111, projection='3d')
        ax.set_title("Drone Position (meters)")
        ax.set_ylim([-0.7, 0.7])
        ax.set_xlim([-0.7, 0.7])
        ax.set_zlim([0, 1])

        color = "#00ff00"
        quiver_color = "#0055ff"
        if d_pos is not None:
            ax.scatter([d_pos['x']], [d_pos['y']], [d_pos['z']], color=color)
            if d_vel is not None:
                ax.quiver(d_pos['x'],
                          d_pos['y'],
                          d_pos['z'],
                          d_vel['x'],
                          d_vel['y'],
                          d_vel['z'],
                          length=0.2,
                          color=quiver_color)

        canvas.draw()
        # ``tostring_rgb``/``np.fromstring`` are deprecated; read the Agg buffer.
        rgba = np.asarray(canvas.buffer_rgba())
        return rgba[:, :, :3].copy()
Ejemplo n.º 10
0
def segmented_image(arr):
    """Get segmented image as a numpy array.

    Renders a label map into a 256 x 256 RGB image with one fixed colour per
    class: 0 ``#ffd300``, 1 ``#93cc93``, 2 ``#4970a3``, 3 ``#999999``.

    :param arr: 2-D array of class labels, each in ``{0, 1, 2, 3}``.
    :return: ``uint8`` array of shape ``(height, width, 3)``.
    :raises ValueError: if ``arr`` contains a value outside ``{0, 1, 2, 3}``.
    """
    class_colors = ('#ffd300', '#93cc93', '#4970a3', '#999999')
    if not np.isin(arr, range(len(class_colors))).all():
        raise ValueError("segmented_image: labels must be in {0, 1, 2, 3}")

    # A fixed 4-entry colormap with bin edges at -0.5, 0.5, ..., 3.5 maps
    # label k to colour k.  A colormap built only from the labels present
    # and stretched between their min and max mis-colours gaps: with labels
    # {0, 2, 3}, label 2 received class 3's colour.
    cmap = mpl.colors.ListedColormap(list(class_colors))

    fig = Figure()
    fig.set_size_inches(256 / fig.get_dpi(), 256 / fig.get_dpi())
    # Axes fill the whole figure.
    fig.subplots_adjust(0, 0, 1, 1)
    canvas = FigureCanvas(fig)
    ax = fig.gca()

    # 'nearest' keeps labels categorical: interpolating label values would
    # produce in-between values that land in the wrong colour bin.
    ax.imshow(arr, cmap=cmap, vmin=-0.5, vmax=len(class_colors) - 0.5,
              interpolation='nearest')
    ax.axis('off')
    ax.margins(0, 0)

    canvas.draw()  # render so the Agg pixel buffer is populated

    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[:, :, :3].copy()
Ejemplo n.º 11
0
 def __init__(self, figure: Figure):
     """Open a window sized to ``figure`` in pixels and draw it on a ``Chart``."""
     pixel_size = figure.get_size_inches() * figure.get_dpi()
     super(FigureWindow, self).__init__(size=pixel_size)
     chart = Chart()
     self.content = toga.Box(children=[chart])
     chart.draw(figure)
Ejemplo n.º 12
0
def changeFigureSize(w: float, h: float, cut_from_top: bool = False, cut_from_left: bool = False, fig: Figure = None):
    """Resize a figure to ``w`` x ``h`` inches while keeping its contents' physical size.

    Axes and figure-level texts are repositioned in figure-fraction
    coordinates so their absolute size and offsets are kept.  The gained or
    lost space goes to one side per dimension:

    * vertically: the top when ``cut_from_top`` is true, otherwise the bottom;
    * horizontally: the left when ``cut_from_left`` is true, otherwise the right.

    Text positions are assumed to be figure fractions (the ``Figure.text``
    default).

    :param w: new figure width in inches.
    :param h: new figure height in inches.
    :param cut_from_top: apply the height change at the top instead of the bottom.
    :param cut_from_left: apply the width change at the left instead of the right.
    :param fig: figure to resize; defaults to ``plt.gcf()``.
    """
    if fig is None:
        fig = plt.gcf()
    old_w, old_h = fig.get_size_inches()
    scale_x = old_w / w
    scale_y = old_h / h

    def new_x(x):
        # Keep the distance to the right edge (cut left) or to the left edge.
        return 1 - (1 - x) * scale_x if cut_from_left else x * scale_x

    def new_y(y):
        # Keep the distance to the bottom edge (cut top) or to the top edge.
        return y * scale_y if cut_from_top else 1 - (1 - y) * scale_y

    for axes in fig.axes:
        box = axes.get_position()
        axes.set_position([new_x(box.x0), new_y(box.y0),
                           (box.x1 - box.x0) * scale_x, (box.y1 - box.y0) * scale_y])
    for text in fig.texts:
        x0, y0 = text.get_position()
        text.set_position([new_x(x0), new_y(y0)])
    fig.set_size_inches(w, h, forward=True)
Ejemplo n.º 13
0
def mathTex_to_QPixmap(mathTex, fs):
    """Render a LaTeX-style formula into a QPixmap cropped to the text.

    Inputs:
        mathTex -- the formula to display, in LaTeX (matplotlib text) syntax.
        fs -- the desired font size.
    Output:
        a ``QtGui.QPixmap`` with a transparent background, sized to the text.
    """
    # Transparent figure attached to an Agg canvas, so a renderer is
    # available to measure the text.
    fig = Figure()
    fig.patch.set_facecolor('none')
    fig.set_canvas(FigureCanvas(fig))
    renderer = fig.canvas.get_renderer()

    # A hidden, transparent full-figure axes with the formula at its origin.
    axes = fig.add_axes([0, 0, 1, 1])
    axes.axis('off')
    axes.patch.set_facecolor('none')
    label = axes.text(0, 0, mathTex, ha='left', va='bottom', fontsize=fs)

    # Shrink the figure so its pixel extent equals the text's pixel extent.
    width_in, height_in = fig.get_size_inches()
    fig_extent = fig.get_window_extent(renderer)
    text_extent = label.get_window_extent(renderer)
    fig.set_size_inches(text_extent.width * width_in / fig_extent.width,
                        text_extent.height * height_in / fig_extent.height)

    # Agg's RGBA bytes are wrapped as ARGB32 and then red/blue swapped.
    buf, size = fig.canvas.print_to_buffer()
    wrapped = QtGui.QImage(buf, size[0], size[1], QtGui.QImage.Format_ARGB32)
    return QtGui.QPixmap(QtGui.QImage.rgbSwapped(wrapped))
Ejemplo n.º 14
0
class Render(object):
    """Off-screen Agg renderer that turns queued artists into RGB frames.

    Per frame: call :meth:`new_frame`, queue artists with :meth:`add_artist`,
    then call :meth:`draw` to get the image.
    """

    def __init__(self, figsize=(15, 15), dpi=48):
        """
        :param figsize: figure size in inches, ``(width, height)``.
        :param dpi: dots per inch used for rendering.
        """
        self.figsize = figsize
        self.dpi = dpi
        self.fig = Figure(figsize=figsize, dpi=dpi)
        self.canvas = FigureCanvas(self.fig)
        self.artists = []

    def add_artist(self, artist):
        """Queue ``artist``; it is added to the axes by the next :meth:`draw`."""
        self.artists.append(artist)

    def new_frame(self):
        """Start an empty frame: fresh axes spanning the map, decorations off.

        ``fig.clear()`` already discards every artist of the previous frame,
        so the queue is simply reset.  Calling ``artist.remove()`` on those
        artists is redundant, and matplotlib raises for an artist that was
        queued but never added to an axes.
        """
        self.fig.clear()
        self.ax = self.fig.gca()
        self.artists = []
        self.ax.set_xlim(-0.1, MAP_SIZE + 1.1)
        self.ax.set_ylim(-0.1, MAP_SIZE + 1.1)
        self.ax.axis('off')

    def draw(self):
        """Add the queued artists to the current axes and render the frame.

        :return: ``uint8`` RGB array of shape ``(height, width, 3)``.
        """
        for artist in self.artists:
            self.ax.add_artist(artist)
        self.canvas.draw()  # render so the Agg pixel buffer is populated
        # ``tostring_rgb`` is deprecated; read the Agg buffer and drop alpha.
        rgba = np.asarray(self.canvas.buffer_rgba())
        return rgba[:, :, :3].copy()
Ejemplo n.º 15
0
class Plot:
    """A line plot rendered off-screen and pasted into a region of an image."""

    def __init__(self, size, pos, x_label, x, y_label, y, title):
        """
        :param size: ``(rows, cols)`` of the region the plot occupies in the window.
        :param pos: ``(row, col)`` of that region's top-left corner.
        :param x_label: x axis label.
        :param x: x series.
        :param y_label: y axis label.
        :param y: y series.
        :param title: axes title.
        """
        self.x_series = x
        self.x_label = x_label
        self.y_label = y_label
        self.y_series = y
        self.title = title
        self.size = size
        self.pos = pos
        self.fig = Figure()
        # Holds the last rendered (resized) plot image after draw().
        self.figure = []

    def draw(self, window):
        """Render the plot and write it into ``window`` in place.

        The figure is cleared first.  ``self.fig`` persists between calls, so
        without clearing, every call would stack another line, title and set
        of labels onto the same axes.
        """
        canvas = FigureCanvas(self.fig)
        self.fig.clf()
        ax = self.fig.gca()

        ax.plot(self.x_series, self.y_series)
        ax.set_xlabel(self.x_label)
        ax.set_ylabel(self.y_label)
        ax.set_title(self.title)
        ax.grid(True)

        canvas.draw()  # render so the Agg pixel buffer is populated

        # ``tostring_rgb``/``np.fromstring`` are deprecated; read the buffer
        # and copy so cv2 receives a contiguous RGB array.
        rgb = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
        # cv2.resize takes (width, height), hence the flipped (rows, cols).
        self.figure = cv2.resize(rgb, flip(self.size))
        window[self.pos[0]:self.pos[0] + self.size[0],
               self.pos[1]:self.pos[1] + self.size[1], :] = self.figure
Ejemplo n.º 16
0
def mathTex_to_QPixmap(mathTex, fs):
    """Return a QPixmap of ``mathTex`` drawn by matplotlib at font size ``fs``.

    The figure background is transparent and the figure is shrunk to the
    text's bounding box, so the pixmap is tightly cropped around the formula.
    """
    fig = Figure()
    fig.patch.set_facecolor('none')
    fig.set_canvas(FigureCanvasAgg(fig))
    renderer = fig.canvas.get_renderer()

    text_axes = fig.add_axes([0, 0, 1, 1])
    text_axes.axis('off')
    text_axes.patch.set_facecolor('none')
    artist = text_axes.text(0, 0, mathTex, ha='left', va='bottom', fontsize=fs)

    # Convert the text's pixel extent back to inches using the figure's
    # current inches-per-pixel ratio.
    inch_w, inch_h = fig.get_size_inches()
    fig_box = fig.get_window_extent(renderer)
    txt_box = artist.get_window_extent(renderer)
    new_w = txt_box.width * inch_w / fig_box.width
    new_h = txt_box.height * inch_h / fig_box.height
    fig.set_size_inches(new_w, new_h)

    # Agg returns RGBA bytes: wrap them as ARGB32, then swap red and blue.
    rgba_bytes, (px_w, px_h) = fig.canvas.print_to_buffer()
    swapped = QtGui.QImage.rgbSwapped(
        QtGui.QImage(rgba_bytes, px_w, px_h, QtGui.QImage.Format_ARGB32))
    return QtGui.QPixmap(swapped)
Ejemplo n.º 17
0
    def render(self, mode='human'):
        """Plot items, users and the active user in 2-D, projected with PCA if needed.

        When embeddings have more than two dimensions, a 2-component PCA is
        fitted on the item embeddings the first time and reused afterwards.
        This keeps successive frames in the same projection.

        :param mode: only ``'rgb_array'`` renders; any other mode does nothing.
        :return: for ``'rgb_array'``, a ``uint8`` RGB array of shape
            ``(height, width, 3)``; otherwise ``None``.
        """
        if mode != 'rgb_array':
            return None

        users_vec = np.array([u.embedding for u in self.users.values()])
        items_vec = np.array([i.embedding for i in self.items.values()])

        if items_vec.shape[1] > 2:
            if not hasattr(self, 'pca'):
                self.pca = PCA(n_components=2)
                self.pca.fit(items_vec)
            items = self.pca.transform(items_vec)
            users = self.pca.transform(users_vec)
        else:
            items, users = items_vec, users_vec

        fig = Figure(figsize=(5, 5))
        canvas = FigureCanvas(fig)
        ax = fig.gca()
        ax.scatter(items[:, 0], items[:, 1], c='green', label='items')
        ax.scatter(users[:, 0], users[:, 1], c='red', label='users')
        x, y = users[self.active_user]
        ax.scatter(x, y, marker='*', c='black', s=20, label='active user')
        # TODO: draw the active user's recommendation history
        # (self.last_actions / self.last_rewards) as line segments coloured by
        # reward; needs care if the item set changes between calls.

        ax.legend()
        ax.axis('off')
        canvas.draw()  # render so the Agg pixel buffer is populated
        # ``tostring_rgb``/``np.fromstring`` are deprecated; read the buffer.
        rgba = np.asarray(canvas.buffer_rgba())
        return rgba[:, :, :3].copy()
Ejemplo n.º 18
0
    def __call__(self, shape, seed=None):
        """Draw random strokes from ``self.strokes`` and return a binary image.

        With probability ``self.fat_stroke_prob``, one stroke is drawn with
        width ``self.fat_stroke_width``.  Otherwise ``int(rng.normal(5, 1))``
        strokes are drawn, each with width
        ``rng.randint(self.min_stroke_width, self.max_stroke_width)``.  The
        rendering is upscaled by 1.25 and centre-cropped to the target size.

        :param shape: ``(width, height)`` of the output in pixels.
        :param seed: seed for ``self.rng``; when ``None`` a time-based seed is used.
        :return: ``float32`` array of shape ``(height, width)`` with values 0 or 1.
        """
        if seed is not None:
            self.rng.seed(seed)
        else:
            self.rng.seed(int(time.time() * 100) % (2**32 - 1))

        target_width, target_height = shape

        fig = Figure()
        canvas = FigureCanvas(fig)
        ax = fig.gca()
        ax.axis('off')

        def central_crop(image, new_width, new_height):
            # PIL's crop returns a new image, so no defensive copy is needed.
            width, height = image.size
            left = (width - new_width) / 2
            top = (height - new_height) / 2
            right = (width + new_width) / 2
            bottom = (height + new_height) / 2
            return image.crop((left, top, right, bottom))

        # RNG call order (choice, then index/width per stroke) is kept as-is so
        # seeded outputs stay reproducible.
        if self.rng.choice([True, False],
                           p=[self.fat_stroke_prob, 1 - self.fat_stroke_prob]):
            stroke = self.strokes[self.rng.randint(0, len(self.strokes))]
            ax.plot(*stroke, color='black', lw=self.fat_stroke_width)
        else:
            n_strokes = int(self.rng.normal(5, 1))
            for _ in range(n_strokes):
                stroke = self.strokes[self.rng.randint(0, len(self.strokes))]
                stroke_width = self.rng.randint(self.min_stroke_width,
                                                self.max_stroke_width)
                ax.plot(*stroke, color='black', lw=stroke_width)

        canvas.draw()  # render so the Agg pixel buffer is populated
        # ``tostring_rgb`` is deprecated; convert the RGBA buffer to RGB
        # (dropping alpha, exactly as tostring_rgb did).
        image = Image.fromarray(np.asarray(canvas.buffer_rgba())).convert('RGB')

        image = image.resize(
            (int(1.25 * target_width), int(1.25 * target_height)))
        image = central_crop(image, target_width, target_height)
        image = image.convert('1')

        return np.array(image).astype('float32')
Ejemplo n.º 19
0
def get_star_metrics(epoch, X, Y, emulated):
    """Render the confusion-matrix plot as a batch of one RGB image.

    :param epoch: unused; kept for interface compatibility.
    :param X: forwarded to ``plotting.plot_confusion_matrix``.
    :param Y: forwarded to ``plotting.plot_confusion_matrix``.
    :param emulated: forwarded to ``plotting.plot_confusion_matrix``.
    :return: ``uint8`` array of shape ``(1, height, width, 3)`` for a 720 x 360
        px figure.
    """
    my_dpi = 72.0
    fig = Figure(figsize=(720 / my_dpi, 360 / my_dpi),
                 dpi=my_dpi,
                 tight_layout=True)
    # Assumes the returned canvas is an Agg canvas that has already drawn
    # ``fig`` -- the same precondition tostring_rgb had.
    canvas = plotting.plot_confusion_matrix(fig, X, Y, emulated)
    # ``tostring_rgb`` is deprecated; read the Agg buffer and add a batch axis.
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[np.newaxis, :, :, :3].copy()
Ejemplo n.º 20
0
    def __init__(self, figure: Figure, title: str, app: toga.App):
        """Build a window that shows ``figure`` on a ``Chart`` above a "Save" button.

        The window is as wide as the figure in pixels and 1.35 times as tall.
        The chart area is exactly the figure's pixel height, presumably
        leaving the extra room for the button row.
        """
        self.figure = figure
        figure_px = figure.get_size_inches() * figure.get_dpi()
        super().__init__(title=title, size=(1, 1.35) * figure_px)
        chart = Chart()

        save_box = toga.Box(
            children=[toga.Button(label="Save", on_press=self.save_figure)]
        )
        chart_box = toga.Box(children=[chart], style=Pack(height=figure_px[1]))
        self.content = toga.Box(
            children=[chart_box, save_box], style=Pack(direction=COLUMN)
        )
        chart.draw(figure)
        self.app = app
Ejemplo n.º 21
0
def create_bounding_box(image_path,
                        image_size,
                        model,
                        *,
                        dpi=120,
                        topk=4,
                        grid_num=5):
    """Run the detector on one image and draw boxes for its most confident grid cells.

    :param image_path: path of the image file.
    :param image_size: side length the image is resized to for the model.
    :param model: detector returning ``(loc, confidence, class)`` outputs.
    :param dpi: resolution of the rendered figure.
    :param topk: number of highest-confidence grid cells to draw.
    :param grid_num: number of grid cells per side; a cell spans
        ``image_size // grid_num`` model pixels.
    :return: ``(image, tensor)``.  ``image`` is the annotated rendering as a
        ``uint8`` ``(height, width, 3)`` array; ``tensor`` is the same picture
        passed through ``totensor``.
    """
    model.eval()

    image = Image.open(image_path)
    # PIL reports size as (width, height).
    im_w, im_h = image.size

    # Add a batch dimension of 1.
    batch = totensor(image.resize([image_size, image_size])).unsqueeze(0)
    if cuda:
        batch = batch.cuda()
    output_loc, output_cnf, output_cls = model(Variable(batch))

    if cuda:
        output_loc = output_loc.cpu()
        output_cnf = output_cnf.cpu()
        output_cls = output_cls.cpu()
    # Drop the batch dimension (batch size is 1).
    output_cls.data.squeeze_(0)
    output_loc.data.squeeze_(0)
    output_cnf = output_cnf.data[0, 0, :, :]
    grids = topk_2d(output_cnf, topk)
    grid_size = image_size // grid_num

    # Reload the image for matplotlib and draw at the image's native size.
    image = mpimg.imread(image_path)
    fig = Figure(figsize=(im_w / dpi, im_h / dpi), dpi=dpi)
    canvas = FigureCanvas(fig)
    ax = fig.gca()
    ax.imshow(image)
    ax.axis('off')

    for grid_x, grid_y in grids:
        # NOTE(review): height is passed before width, matching the values the
        # original call passed -- confirm against get_bbox_points' signature.
        x_0, x_1, y_0, y_1 = get_bbox_points(output_loc.data, grid_x, grid_y,
                                             grid_size, im_h, im_w, image_size)
        # x is the row (vertical) coordinate and y the column, hence y on
        # matplotlib's x axis.
        ax.plot([y_0, y_0, y_1, y_1, y_0], [x_0, x_1, x_1, x_0, x_0])
        ax.text(y_0,
                x_0,
                f"{id_class[find_cls(output_cls[:, grid_x, grid_y])]}",
                bbox={'alpha': 0.5})

    canvas.draw()
    # ``tostring_rgb``/``np.fromstring`` are deprecated; read the Agg buffer.
    image = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
    tensor = totensor(Image.fromarray(image))
    return image, tensor
Ejemplo n.º 22
0
def get_sample_cornerplot(Y_nat, sampled_result):
    """Render a corner plot as a batch of one RGB image.

    :param Y_nat: currently unused (see the note below).
    :param sampled_result: array whose shape must unpack into three dimensions.
    :return: ``uint8`` array of shape ``(1, height, width, 3)`` for a 720 x 360
        px figure.
    """
    # NOTE(review): ``X``, ``Y``, ``emulated``, ``flux_formatting`` and ``bp``
    # are neither parameters nor locals, so they must be module globals or the
    # plotting call raises NameError.  ``Y_nat`` and ``sampled_result`` are
    # never passed on.  Confirm what ``plot_sample_corner`` should receive.
    n_obj, n_sample, reg_dim = sampled_result.shape  # unused; enforces 3-D input
    my_dpi = 72.0
    fig = Figure(figsize=(720 / my_dpi, 360 / my_dpi),
                 dpi=my_dpi,
                 tight_layout=True)
    canvas = plotting.plot_sample_corner(fig, X, Y, emulated, flux_formatting,
                                         bp)
    # ``tostring_rgb`` is deprecated; read the Agg buffer and add a batch axis.
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[np.newaxis, :, :, :3].copy()
Ejemplo n.º 23
0
def plotit(data, outfilename, **kargs):
    """Plot each column of ``data`` as a series on one axes and save it to ``outfilename``.

    Runs unchanged on Python 2 and 3.

    :param data: 1-D array (one series) or 2-D array (one series per column).
    :param outfilename: destination passed to ``print_figure``.
    :keyword title: figure super-title.
    :keyword fmt: matplotlib format string for every series (default ``"+"``).
    :keyword color: colour code appended to ``fmt``.
    :keyword axis_title: sequence of titles.  Entry ``num_rows - 1`` (the last
        series' index) becomes the axes title.
    """
    # "Rows" of the original multi-row layout: one per series/column.
    num_rows = 1 if data.ndim == 1 else data.shape[1]

    fig = Figure()
    if "title" in kargs:
        fig.suptitle(kargs["title"])

    # hspace is the height reserved for white space between subplots
    # (kept from the multi-row layout; see matplotlib's howto FAQ).
    fig.subplots_adjust(hspace=0.40)

    ax = fig.add_subplot(111)
    ax.grid()
    ax.set_ylim(-0.1, 1.1)

    default_labels = ("Strip Metric", "FullPage Metric", "All Strips' Mean")
    fmt = kargs.get("fmt", "+")
    if "color" in kargs:
        fmt += kargs["color"]
    for i in range(num_rows):
        column = data if num_rows == 1 else data[:, i]
        # Series past the three named ones get a generic label; the old label
        # iterator raised StopIteration there.
        label = default_labels[i] if i < len(default_labels) else "Series %d" % i
        ax.plot(column, fmt, label=label)

    if "axis_title" in kargs:
        # Same entry the old code read via its leftover loop variable.
        ax.set_title(kargs["axis_title"][num_rows - 1])

    ax.legend(loc="lower left")

    ax.set_xlabel("Strip Number")
    ax.set_ylabel("Match Metric")

    canvas = FigureCanvasAgg(fig)
    canvas.print_figure(outfilename)
    print("wrote %s" % outfilename)
Ejemplo n.º 24
0
    def Draw_Pixel_Image(self):
        """Create the pixel-map canvas: a 3 x 4 in figure embedded in ``self.Pixel_Graph``.

        Sets ``self.Pixel_Canvas``, ``self.Pixel_Axes`` and ``self.Pixel_Image``.
        ``self.Pixel_Image`` is an all-zero 48 x 16 image created with
        ``animated=True``, presumably for later in-place updates.  A navigation
        toolbar is also attached to ``self.Pixel_Graph``.
        """
        fig = Figure((3.0, 4.0))
        fig.patch.set_visible(False)
        self.Pixel_Canvas = FigureCanvas(fig)
        self.Pixel_Canvas.setParent(self.Pixel_Graph)
        # Qt size setters take ints; inches * dpi yields numpy floats, which
        # modern PyQt/Python reject.
        width_px, height_px = (int(v) for v in fig.get_size_inches() * fig.dpi)
        self.Pixel_Graph.setMinimumSize(width_px, height_px)
        self.Pixel_Axes = fig.add_subplot(111)

        zeros = np.zeros(shape=(48, 16))
        self.Pixel_Image = self.Pixel_Axes.imshow(zeros, animated=True, interpolation='nearest')
        self.Pixel_Canvas.draw()
        # Keep a Python reference to the toolbar instead of an unused local.
        self.Pixel_mpl_toolbar = NavigationToolbar(self.Pixel_Canvas, self.Pixel_Graph)
Ejemplo n.º 25
0
def plt_as_img(x, y):
    """Plot the rows of ``x`` against the rows of ``y`` as markers and return an RGB image.

    ``.T`` makes each row pair one series drawn with ``'o'`` markers.  The
    legend labels at most the first six series ``'0'``..``'5'``.

    :return: ``uint8`` RGB array of shape ``(height, width, 3)`` for a
        3.2 x 2.4 in figure.
    """
    fig = Figure(figsize=(3.2, 2.4))
    canvas = FigureCanvas(fig)
    ax = fig.gca()

    series = ax.plot(x.T, y.T, 'o')
    ax.legend(iter(series), [str(i) for i in range(6)], loc=1)
    ax.axis('on')

    canvas.draw()  # render so the Agg pixel buffer is populated
    # ``tostring_rgb``/``np.fromstring`` are deprecated; the buffer also
    # carries its own shape, so no inches*dpi arithmetic is needed.
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[:, :, :3].copy()
Ejemplo n.º 26
0
class LorenzPlot(FigureCanvas):
    """Qt canvas widget that plots a 3-D trajectory of the Lorenz system.

    The widget is itself the figure canvas.  ``self.canvas`` refers back to
    ``self`` so existing ``plot.canvas.draw()`` callers keep working.
    """

    def __init__(self, *args):
        """Create the figure and its single 3-D axes, then initialise the widget.

        ``*args`` is accepted for backward compatibility and ignored.
        """
        self.fig = Figure()
        # One 3-D axes, reused by every draw_plot() call.
        self.ax = self.fig.add_subplot(111, projection='3d')
        # The Qt/FigureCanvas base must be initialised before any Qt call on
        # this widget; PyQt raises if a Qt method runs on a subclass instance
        # whose base __init__ was never called.
        FigureCanvas.__init__(self, self.fig)
        self.canvas = self

    def resizeEvent(self, ev):
        """Match the figure to the widget's new pixel size and redraw the trajectory."""
        dpi = self.fig.get_dpi()
        self.fig.set_size_inches(self.size().width() / dpi,
                                 self.size().height() / dpi)
        self.draw_plot()
        # Debug output: figure size in pixels next to the widget size.
        print(self.fig.get_size_inches() * self.fig.get_dpi())
        print(self.size())

    def Lorenz(self, w, t, s, r, b):
        """Right-hand side of the Lorenz equations at state ``w = (x, y, z)``, for ``odeint``."""
        x, y, z = w
        return array([s*(y-x), r*x-y-x*z, x*y-b*z])

    def draw_plot(self, s=8.0, r=28.1, b=8/3.0):
        """Integrate from ``(0, 0.8, 0)`` over ``t`` in ``[0, 100)`` (step 0.01) and plot it.

        :param s: the ``s`` coefficient in ``s*(y-x)``.
        :param r: the ``r`` coefficient in ``r*x - y - x*z``.
        :param b: the ``b`` coefficient in ``x*y - b*z``.
        """
        self.s, self.r, self.b = s, r, b

        self.w_0 = array([0., 0.8, 0.])         # initial condition
        self.time = arange(0., 100., 0.01)      # time vector
        # Integrate the system of ordinary differential equations.
        self.trajectory = odeint(self.Lorenz, self.w_0, self.time, args=(self.s, self.r, self.b))

        self.x = self.trajectory[:, 0]
        self.y = self.trajectory[:, 1]
        self.z = self.trajectory[:, 2]

        # Reuse the existing axes: creating a new Axes3D per call stacked axes
        # on the figure (and modern matplotlib no longer auto-adds Axes3D(fig)).
        self.ax.clear()
        self.ax.plot3D(self.x, self.y, self.z)
        self.canvas.draw()
Ejemplo n.º 27
0
def generate_deform_grid(transform,
                         slice_axis,
                         background_image=None,
                         n_bins=20):
    """
    Abandoned.

    Draw coordinate channels of a deformation field as yellow iso-contours,
    optionally over a grey background, and return the rendering.

    :param transform: 3xMxN tensor or numpy array, the first axis are z,y,x coordinates
    :param slice_axis: which axis the slice is taken from a 3d volume,
        if 0, it is taken from z axis, than it is a x-y slice; Similarly, 1 for y, 2 for x
    :param background_image: 1xMxN or 3xMxN tensor or numpy array, drawn in
        grey with values clipped to [0, 1]
    :param n_bins: number of contour levels, evenly spaced over [-1, 1]
    :return: float image in [0, 1], channels first, shape (3, H, W); with
        figsize (M/5, N/5) in at 20 dpi that is about (3, 4*N, 4*M)
    """
    if isinstance(transform, torch.Tensor):
        transform = transform.cpu().numpy()
    if background_image is not None:
        if isinstance(background_image, torch.Tensor):
            background_image = background_image.cpu().numpy()
        assert background_image.shape[1:] == transform.shape[1:]
    # Contour every coordinate channel except index ``2 - slice_axis``.
    left_axis = [0, 1, 2]
    left_axis.remove(2 - slice_axis)
    fig = Figure(figsize=np.array(transform.shape[1:]) / 5, dpi=20)
    ax = fig.add_axes([0, 0, 1, 1], frameon=False)
    ax.set_axis_off()
    ax.axis('equal')
    if background_image is not None:
        ax.imshow(background_image.squeeze(), vmin=0, vmax=1, cmap='gray')

    for axis in left_axis:
        ax.contour(transform[axis, :, :],
                   colors=['yellow'],
                   linewidths=10.0,
                   linestyles='solid',
                   levels=np.linspace(-1, 1, n_bins))
    ax.set_xlim([0, transform.shape[2]])
    # No pyplot call here: plt.autoscale() would act on pyplot's current
    # axes (creating a stray figure), never on this Figure.
    canvas = FigureCanvas(fig)
    canvas.draw()
    # ``tostring_rgb``/``np.fromstring`` are deprecated; read the Agg buffer.
    rgb = np.asarray(canvas.buffer_rgba())[:, :, :3]
    image = rgb / 255

    return np.transpose(image, [2, 0, 1])
Ejemplo n.º 28
0
 def get_pic(self):
     """Load ``self.WAVE_OUTPUT_FILENAME`` and return its waveform plot as an RGB array.

     Also stores a copy of the loaded samples in ``self.samples``.

     :return: ``uint8`` RGB array of shape ``(height, width, 3)``.
     """
     samples, _sample_rate = librosa.load(self.WAVE_OUTPUT_FILENAME)
     self.samples = samples.copy()
     from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
     from matplotlib.figure import Figure
     fig = Figure()
     canvas = FigureCanvas(fig)
     ax = fig.gca()
     ax.plot(samples)
     ax.axis('off')
     canvas.draw()  # render so the Agg pixel buffer is populated
     # ``tostring_rgb``/``np.fromstring`` are deprecated; read the Agg buffer.
     rgba = np.asarray(canvas.buffer_rgba())
     return rgba[:, :, :3].copy()
def plt_as_img(x, y):
    """Plot ``y.T`` against ``x.T`` with a numbered legend and rasterise it.

    Args:
        x: array-like; ``x.T`` is passed as the x data to ``Axes.plot``.
        y: array-like; ``y.T`` is passed as the y data to ``Axes.plot``.

    Returns:
        numpy.ndarray of dtype uint8 and shape (height, width, 3).
    """
    fig = Figure(figsize=(6.4, 4.8))
    canvas = FigureCanvas(fig)
    ax = fig.gca()

    ps = ax.plot(x.T, y.T, '-')
    # One label per plotted line; a fixed count would leave extra lines
    # unlabelled in the legend.
    ax.legend(ps, [str(i) for i in range(len(ps))], loc=1)

    ax.set_xlabel('time [seconds]')
    ax.set_ylabel('accuracy of predicted digit')
    canvas.draw()  # render so the Agg pixel buffer is populated
    # The buffer already has the true (height, width, 4) shape; drop alpha.
    # This avoids np.fromstring (deprecated) and tostring_rgb (removed in
    # Matplotlib 3.10), and any rounding mismatch from inches * dpi.
    image = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
    return image
Ejemplo n.º 30
0
def plot_grad_flow(named_parameters, toFigure=False):
    """
    Plot the mean absolute gradient of every trainable, non-bias parameter.

    Adapted from https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/5

    :param named_parameters: iterable of ``(name, parameter)`` pairs, e.g.
        ``model.named_parameters()`` of an ``nn.Module``.
    :param toFigure: when True, render into an off-screen figure and return
        it as an image; otherwise draw with pyplot and call ``plt.show()``.
    :return: if ``toFigure``, a float array of shape (height, width, 3) with
        values in [0, 1]; otherwise None.
    """
    ave_grads = []
    layers = []
    for n, p in named_parameters:
        # Parameters that took no part in the backward pass have grad None;
        # skip them instead of failing on ``None.abs()``.
        if p.requires_grad and "bias" not in n and p.grad is not None:
            layers.append(n)
            # .item() yields a plain float, so CUDA tensors can be plotted too.
            ave_grads.append(p.grad.abs().mean().item())

    if toFigure:
        # figsize must be a (width, height) pair; a bare scalar raises.
        fig = Figure(figsize=(10, 10), dpi=20)
        ax = fig.add_axes([0, 0, 1, 1])
        # NOTE(review): equal aspect may flatten small gradient values; confirm intended.
        ax.axis('equal')
        # Draw on this figure's axes. plt.* would target pyplot's current
        # figure and leave the returned image without the data.
        ax.plot(ave_grads, alpha=0.3, color="b")
        ax.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
        ax.grid(True)
        canvas = FigureCanvas(fig)
        canvas.draw()
        # RGB pixels from the Agg buffer, scaled to [0, 1].
        image = np.asarray(canvas.buffer_rgba())[:, :, :3] / 255
        return image

    else:
        plt.figure()
        plt.plot(ave_grads, alpha=0.3, color="b")
        plt.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
        plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
        plt.xlim(xmin=0, xmax=len(ave_grads))
        plt.xlabel("Layers")
        plt.ylabel("average gradient")
        plt.title("Gradient flow")
        plt.grid(True)
        plt.show()
Ejemplo n.º 31
0
def plotit(data, outfilename, **kargs):
    """Plot each column of ``data`` in its own subplot and save the figure.

    Subplots are stacked vertically (N rows x 1 column), one per column of a
    2-D array, or a single subplot for 1-D data. The figure height is
    multiplied by the number of subplots.

    Recognised keyword arguments:
        title:      figure-level title passed to ``Figure.suptitle``.
        color:      matplotlib colour code prepended to the ``"+"`` marker.
        axis_title: sequence of per-subplot titles, indexed by subplot.

    The figure is written to ``outfilename`` through an Agg canvas, and a
    confirmation line is printed.
    """
    num_rows = 1 if len(data.shape) == 1 else data.shape[1]

    # davep 02-Oct-2012 ; bump up the size to accommodate multiple rows
    fig = Figure()
    base_width, base_height = fig.get_size_inches()
    fig.set_size_inches((base_width, base_height * num_rows))

    if "title" in kargs:
        fig.suptitle(kargs["title"])

    # Reserve vertical white space between stacked subplots so titles and
    # tick labels do not collide.
    fig.subplots_adjust(hspace=0.40)

    for row in range(num_rows):
        ax = fig.add_subplot(num_rows, 1, row + 1)
        ax.grid()
        column = data if num_rows == 1 else data[:, row]
        fmt = (kargs["color"] if "color" in kargs else "") + "+"
        ax.plot(column, fmt)
        if "axis_title" in kargs:
            ax.set_title(kargs["axis_title"][row])

    canvas = FigureCanvasAgg(fig)
    canvas.print_figure(outfilename)
    print("wrote", outfilename)
Ejemplo n.º 32
0
class TreeFigure:
    """Non-interactive figure that draws ``root`` with ``tree.TreePlot``.

    The figure height defaults to enough room for one leaf label per row and
    the width defaults to ``height * relwidth``. Most keyword arguments are
    forwarded unchanged to ``tree.TreePlot``. ``xlim`` and ``ylim`` are
    accepted but not used by this class.
    """

    def __init__(self, root, relwidth=0.5, leafpad=1.5, name=None,
                 support=70.0, scaled=True, mark_named=True,
                 leaf_fontsize=10, branch_fontsize=10,
                 branch_width=1, branch_color="black",
                 highlight_support=True,
                 branchlabels=True, leaflabels=True, decorators=None,
                 xoff=0, yoff=0,
                 xlim=None, ylim=None,
                 height=None, width=None):
        self.root = root
        self.relwidth = relwidth
        self.leafpad = leafpad
        self.name = name
        self.support = support
        self.scaled = scaled
        self.mark_named = mark_named
        self.leaf_fontsize = leaf_fontsize
        self.branch_fontsize = branch_fontsize
        self.branch_width = branch_width
        self.branch_color = branch_color
        self.highlight_support = highlight_support
        self.branchlabels = branchlabels
        self.leaflabels = leaflabels
        # Fresh list per instance. A shared mutable default ([]) would leak
        # decorators between TreeFigure instances if anything appended to it.
        self.decorators = [] if decorators is None else decorators
        self.xoff = xoff
        self.yoff = yoff

        nleaves = len(root.leaves())
        self.dpi = 72.0
        # Height (inches) that stacks one padded leaf label per leaf.
        h = height or (nleaves*self.leaf_fontsize*self.leafpad)/self.dpi
        self.height = h
        self.width = width or self.height*self.relwidth
        self.figure = Figure(figsize=(self.width, self.height), dpi=self.dpi)
        self.canvas = FigureCanvas(self.figure)
        self.axes = self.figure.add_axes(
            tree.TreePlot(self.figure, 1,1,1,
                          support=self.support,
                          scaled=self.scaled,
                          mark_named=self.mark_named,
                          leaf_fontsize=self.leaf_fontsize,
                          branch_fontsize=self.branch_fontsize,
                          branch_width=self.branch_width,
                          branch_color=self.branch_color,
                          highlight_support=self.highlight_support,
                          branchlabels=self.branchlabels,
                          leaflabels=self.leaflabels,
                          interactive=False,
                          decorators=self.decorators,
                          xoff=self.xoff, yoff=self.yoff,
                          name=self.name).plot_tree(self.root)
            )
        self.axes.spines["top"].set_visible(False)
        self.axes.spines["left"].set_visible(False)
        self.axes.spines["right"].set_visible(False)
        bottom = self.axes.spines["bottom"]
        # Spine.set_smart_bounds was removed in Matplotlib 3.4; only call it
        # where it still exists so construction works on current versions.
        if hasattr(bottom, "set_smart_bounds"):
            bottom.set_smart_bounds(True)
        self.axes.xaxis.set_ticks_position("bottom")

        # Show every node label in the static rendering.
        for v in self.axes.node2label.values():
            v.set_visible(True)

        self.canvas.draw()
        self.axes.set_position([0.05,0.05,0.95,0.95])

    @property
    def detail(self):
        """The tree axes (alias of ``self.axes``)."""
        return self.axes

    def savefig(self, fname):
        """Save the figure to ``fname``; the format comes from its extension.

        Before the real save, the figure is rendered three times into a
        scratch temp file, calling ``home()`` after each pass.
        NOTE(review): presumably this lets the tree axes settle its limits
        after layout; confirm against ``tree.TreePlot.home``.
        """
        ext = os.path.splitext(fname)[1]
        # No extension -> let matplotlib use its default format instead of
        # failing on format=''.
        fmt = ext[1:].lower() or None
        # The context manager closes the scratch file even if a render fails.
        with tempfile.TemporaryFile() as buf:
            for _ in range(3):
                self.figure.savefig(buf, format=fmt)
                self.home()
                buf.seek(0)
        self.figure.savefig(fname)

    def set_relative_width(self, relwidth):
        """Set the figure width to ``relwidth`` times its current height.

        ``self.relwidth`` and ``self.width`` are updated as well, so a later
        ``autoheight()`` keeps this width instead of restoring the old one.
        """
        w, h = self.figure.get_size_inches()
        self.relwidth = relwidth
        self.width = h*relwidth
        self.figure.set_figwidth(self.width)

    def autoheight(self):
        "adjust figure height to show all leaf labels"
        nleaves = len(self.root.leaves())
        h = (nleaves*self.leaf_fontsize*self.leafpad)/self.dpi
        self.height = h
        self.figure.set_size_inches(self.width, self.height)
        self.axes.set_ylim(-2, nleaves+2)

    def home(self):
        """Delegate to the tree axes' ``home()``."""
        self.axes.home()
Ejemplo n.º 33
0
	return r


# Angles sampled once around the full circle.
phi = numpy.linspace(0, 2 * numpy.pi, 1024)
# Initial superformula parameters for the displayed curve.
m_init, n1_init, n2_init, n3_init = 3, 2, 18, 18


# Polar figure holding a single supershape curve.
fig = Figure((6, 6), dpi = 80)
ax = fig.add_subplot(111, polar = True)
r = supershape_radius(phi, 1, 1, m_init, n1_init, n2_init, n3_init)
lines, = ax.plot(phi, r, lw = 3.)


# Top-level window; its delete-event stops the GTK main loop.
win = Gtk.Window()
win.connect('delete-event', Gtk.main_quit)
win.set_title('SuperShape')

# Embed the Matplotlib canvas and request the figure's pixel size
# (inches * dpi, rounded up).
canvas = FigureCanvasGTK3Agg(fig)
dpi_res = fig.get_dpi()
w, h = (int(numpy.ceil(v * dpi_res)) for v in fig.get_size_inches())
canvas.set_size_request(w, h)
win.add(canvas)

win.show_all()
Gtk.main()
Ejemplo n.º 34
0
class HorizonFrame(wx.Frame):
    """ The main frame of the horizon indicator."""

    def __init__(self, state, title):
        """Create the frame, register it on ``state`` and build the HUD.

        Sets ``state.frame`` to this frame and records ``startTime`` with
        ``time.time()`` once construction has finished.
        """
        self.state = state
        # Initialise the wx.Frame base before initUI() creates child widgets.
        wx.Frame.__init__(self, None, title=title)
        state.frame = self

        # Default telemetry first, then the widgets that display it.
        self.initData()
        self.initUI()
        self.startTime = time.time()

    def initData(self):
        """Reset every telemetry/state attribute shown on the HUD to its default."""
        # Attitude, in degrees
        self.pitch = 0.0
        self.roll = 0.0
        self.yaw = 0.0
        # Previous roll sample, in degrees
        self.oldRoll = 0.0

        # Rates: airspeed (m/s), altitude relative to home (m), climb rate (m/s)
        self.airspeed, self.relAlt, self.climbRate = 0.0, 0.0, 0.0
        # Altitude / time history (separate lists) and peak altitude since startup
        self.altHist, self.timeHist = [], []
        self.altMax = 0.0

        # Heading, 0-360
        self.heading = 0.0

        # Battery
        self.voltage, self.current, self.batRemain = 0.0, 0.0, 0.0

        # Flight mode plus arming/safety state
        self.mode = "UNKNOWN"
        self.armed = ""
        self.safetySwitch = ""

        # Waypoints
        self.currentWP = 0
        self.finalWP = 0
        self.wpDist = 0
        self.nextWPTime = 0
        self.wpBearing = 0

    def initUI(self):
        """Create the timer, event bindings, plot panel and every HUD element.

        Call order matters: rescaleX() and calcFontScaling() set ratio,
        fontSize, leftPos and rightPos, which the create* helpers below read.
        """
        # Create Event Timer and Bindings (wx.Timer.Start interval is in ms)
        self.timer = wx.Timer(self)
        self.Bind(wx.EVT_TIMER, self.on_timer, self.timer)
        self.timer.Start(100)
        self.Bind(wx.EVT_IDLE, self.on_idle)
        self.Bind(wx.EVT_CHAR_HOOK, self.on_KeyPress)

        # Create Panel
        self.panel = wx.Panel(self)
        # Text height as a fraction of the axes' vertical extent; used for font scaling
        self.vertSize = 0.09
        self.resized = False

        # Create Matplotlib Panel
        self.createPlotPanel()

        # Fix Axes - vertical is of length 2, horizontal keeps the same lengthscale
        self.rescaleX()
        self.calcFontScaling()

        # Create Horizon Polygons
        self.createHorizonPolygons()

        # Center Pointer Marker
        self.thick = 0.015
        self.createCenterPointMarker()

        # Pitch Markers
        self.dist10deg = 0.2  # Graph distance per 10 deg
        self.createPitchMarkers()

        # Add Roll, Pitch, Yaw Text
        self.createRPYText()

        # Add Airspeed, Altitude, Climb Rate Text
        self.createAARText()

        # Create Heading Pointer
        self.createHeadingPointer()

        # Create North Pointer
        self.createNorthPointer()

        # Create Battery Bar
        self.batWidth = 0.1
        self.batHeight = 0.2
        self.rOffset = 0.35
        self.createBatteryBar()

        # Create Mode & State Text
        self.createStateText()

        # Create Waypoint Text
        self.createWPText()

        # Create Waypoint Pointer
        self.createWPPointer()

        # Create Altitude History Plot
        self.createAltHistoryPlot()

        # Show Frame
        self.Show(True)
        # NOTE(review): presumably consumed by on_timer/on_idle (not visible here) — confirm
        self.pending = []

    def createPlotPanel(self):
        """Creates the figure and axes for the plotting panel.

        The axes fill the whole figure with the axis frame turned off. The
        canvas sits in a vertical box sizer with proportion 1 and wx.EXPAND,
        so it grows with the frame.
        """
        self.figure = Figure()
        self.axes = self.figure.add_subplot(111)
        self.canvas = FigureCanvas(self, -1, self.figure)
        self.canvas.SetSize(wx.Size(300, 300))
        self.axes.axis("off")
        self.figure.subplots_adjust(left=0, right=1, top=1, bottom=0)
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        # Sizer.Add(window, proportion, flag, border): only EXPAND is needed.
        # wx.ALL is a flag constant, not a border width, so it must not be
        # passed in the border position.
        self.sizer.Add(self.canvas, 1, wx.EXPAND)
        self.SetSizerAndFit(self.sizer)
        self.Fit()

    def rescaleX(self):
        """Match the x-limits to the figure aspect so both axes share one lengthscale.

        y spans [-1, 1]. x spans [-ratio, ratio], where ``ratio`` (also
        stored on ``self``) is the figure's width / height in inches.
        """
        width_in, height_in = self.figure.get_size_inches()
        self.ratio = width_in / float(height_in)
        self.axes.set_xlim(-self.ratio, self.ratio)
        self.axes.set_ylim(-1, 1)

    def calcFontScaling(self):
        """Refresh pixel size, font size and the left/right x-limits from the figure."""
        width_in, height_in = self.figure.get_size_inches()
        dpi = self.figure.dpi
        self.ypx = height_in * dpi
        self.xpx = width_in * dpi
        # Font size scales with the figure's pixel height.
        self.fontSize = self.vertSize * (self.ypx / 2.0)
        self.leftPos, self.rightPos = self.axes.get_xlim()

    def checkReszie(self):
        """Flag ``self.resized`` when the figure's pixel size has changed.

        If ``resized`` is already set, this does nothing. Otherwise it updates
        ``xpx``/``ypx`` from the figure and sets ``resized`` to whether either
        value differs from the previous one.
        """
        if self.resized:
            return
        previous = (self.xpx, self.ypx)
        self.ypx = self.figure.get_size_inches()[1] * self.figure.dpi
        self.xpx = self.figure.get_size_inches()[0] * self.figure.dpi
        self.resized = previous != (self.xpx, self.ypx)

    def createHeadingPointer(self):
        """Add the fixed heading marker: a black triangle with the heading value below it."""
        self.headingTri = patches.RegularPolygon((0.0, 0.80), 3, 0.05, color="k", zorder=4)
        self.axes.add_patch(self.headingTri)
        label_style = dict(
            color="k",
            size=self.fontSize,
            horizontalalignment="center",
            verticalalignment="center",
            zorder=4,
        )
        self.headingText = self.axes.text(0.0, 0.675, "0", **label_style)

    def adjustHeadingPointer(self):
        """Show the current heading in the pointer label at the current font size."""
        label = self.headingText
        label.set_size(self.fontSize)
        label.set_text(str(self.heading))

    def createNorthPointer(self):
        """Add the north marker (triangle plus "N" label); adjustNorthPointer rotates it."""
        self.headingNorthTri = patches.RegularPolygon((0.0, 0.80), 3, 0.05, color="k", zorder=4)
        self.axes.add_patch(self.headingNorthTri)
        label_style = dict(
            color="k",
            size=self.fontSize,
            horizontalalignment="center",
            verticalalignment="center",
            zorder=4,
        )
        self.headingNorthText = self.axes.text(0.0, 0.675, "N", **label_style)

    def adjustNorthPointer(self):
        """Rotate the north marker about the origin by the current heading.

        The "N" label is flipped by 180 degrees while the heading is strictly
        between 90 and 270. It is blanked within 10 degrees of north, where it
        would overlap the fixed heading pointer.
        """
        heading = self.heading
        rotation = mpl.transforms.Affine2D().rotate_deg_around(0.0, 0.0, heading) + self.axes.transData
        self.headingNorthText.set_size(self.fontSize)
        self.headingNorthText.set_transform(rotation)
        self.headingNorthText.set_rotation(heading - 180 if 90 < heading < 270 else heading)
        self.headingNorthTri.set_transform(rotation)
        near_north = heading <= 10.0 or heading >= 350.0
        self.headingNorthText.set_text("" if near_north else "N")

    def toggleWidgets(self, widgets):
        """Invert the visibility of every artist in ``widgets``."""
        for widget in widgets:
            widget.set_visible(not widget.get_visible())

    def createRPYText(self):
        """Add the roll / pitch / yaw readouts stacked near the bottom-left corner."""
        x = self.leftPos + (self.vertSize / 10.0)
        self.rollText = self.axes.text(
            x,
            -0.97 + (2 * self.vertSize) - (self.vertSize / 10.0),
            "Roll:   %.2f" % self.roll,
            color="w",
            size=self.fontSize,
        )
        self.pitchText = self.axes.text(
            x,
            -0.97 + self.vertSize - (0.5 * self.vertSize / 10.0),
            "Pitch: %.2f" % self.pitch,
            color="w",
            size=self.fontSize,
        )
        self.yawText = self.axes.text(x, -0.97, "Yaw:   %.2f" % self.yaw, color="w", size=self.fontSize)
        # Thin black outline around the white text.
        for label in (self.rollText, self.pitchText, self.yawText):
            label.set_path_effects([PathEffects.withStroke(linewidth=1, foreground="k")])

    def updateRPYLocations(self):
        """Re-anchor the roll/pitch/yaw labels to the current left edge and font size."""
        x = self.leftPos + (self.vertSize / 10.0)
        placements = (
            (self.rollText, -0.97 + (2 * self.vertSize) - (self.vertSize / 10.0)),
            (self.pitchText, -0.97 + self.vertSize - (0.5 * self.vertSize / 10.0)),
            (self.yawText, -0.97),
        )
        for label, y in placements:
            label.set_position((x, y))
            label.set_size(self.fontSize)

    def updateRPYText(self):
        """Write the current roll, pitch and yaw (2 decimals) into their labels."""
        readouts = (
            (self.rollText, "Roll:   %.2f", self.roll),
            (self.pitchText, "Pitch: %.2f", self.pitch),
            (self.yawText, "Yaw:   %.2f", self.yaw),
        )
        for label, template, value in readouts:
            label.set_text(template % value)

    def createCenterPointMarker(self):
        """Draw the fixed orange reference marker: two side bars and a centre dot."""
        for left_x in (-0.75, 0.25):
            self.axes.add_patch(
                patches.Rectangle((left_x, -self.thick), 0.5, 2.0 * self.thick, facecolor="orange", zorder=3)
            )
        centre_dot = patches.Circle((0, 0), radius=self.thick, facecolor="orange", edgecolor="none", zorder=3)
        self.axes.add_patch(centre_dot)

    def createHorizonPolygons(self):
        """Add the sky (upper half) and ground (lower half) polygons.

        calcHorizonPoints() later moves their shared edge to follow roll and pitch.
        """
        sky_vertices = [[-1, 0], [-1, 1], [1, 1], [1, 0], [-1, 0]]
        ground_vertices = [[-1, 0], [-1, -1], [1, -1], [1, 0], [-1, 0]]
        self.topPolygon = Polygon(sky_vertices, facecolor="dodgerblue", edgecolor="none")
        self.axes.add_patch(self.topPolygon)
        self.botPolygon = Polygon(ground_vertices, facecolor="brown", edgecolor="none")
        self.axes.add_patch(self.botPolygon)

    def calcHorizonPoints(self):
        """Move the sky/ground boundary to follow the current roll and pitch.

        The boundary is tilted by ``tan(-roll) * ratio`` at each side. It is
        shifted down by ``dist10deg`` per 10 degrees of pitch.
        """
        ydiff = math.tan(math.radians(-self.roll)) * float(self.ratio)
        pitchdiff = self.dist10deg * (self.pitch / 10.0)
        left_edge = (-self.ratio, ydiff - pitchdiff)
        right_edge = (self.ratio, -ydiff - pitchdiff)
        self.topPolygon.set_xy([left_edge, (-self.ratio, 1), (self.ratio, 1), right_edge, left_edge])
        self.botPolygon.set_xy([left_edge, (-self.ratio, -1), (self.ratio, -1), right_edge, left_edge])

    def createPitchMarkers(self):
        """Add the pitch ladder.

        White bars are drawn every 10 degrees from -90 to 90, and numeric
        labels are placed at +-30/60/90 degrees on both sides.
        """
        self.pitchPatches = []
        for step in range(-9, 10):
            width = self.calcPitchMarkerWidth(step)
            bar = patches.Rectangle(
                (-width / 2.0, self.dist10deg * step - (self.thick / 2.0)),
                width,
                self.thick,
                facecolor="w",
                edgecolor="none",
            )
            self.axes.add_patch(bar)
            self.pitchPatches.append(bar)

        # Labels for the +-30/60/90 degree bars, one on each side of the ladder
        self.vertSize = 0.09
        self.pitchLabelsLeft = []
        self.pitchLabelsRight = []
        for degrees in (-90, -60, -30, 30, 60, 90):
            y = (degrees / 10.0) * self.dist10deg
            for x, side in ((-0.55, self.pitchLabelsLeft), (0.55, self.pitchLabelsRight)):
                label = self.axes.text(
                    x,
                    y,
                    str(degrees),
                    color="w",
                    size=self.fontSize,
                    horizontalalignment="center",
                    verticalalignment="center",
                )
                side.append(label)
                label.set_path_effects([PathEffects.withStroke(linewidth=1, foreground="k")])

    def calcPitchMarkerWidth(self, i):
        """Return the bar width for the pitch marker at ``i`` * 10 degrees.

        The zero-pitch bar is 1.5 wide, other multiples of 30 degrees are 0.9,
        and the remaining 10-degree bars are 0.6.
        """
        if i == 0:
            return 1.5
        return 0.9 if i % 3 == 0 else 0.6

    def adjustPitchmarkers(self):
        """Shift the pitch ladder by the current pitch and rotate it by the roll."""
        pitchdiff = self.dist10deg * (self.pitch / 10.0)
        rollRotate = mpl.transforms.Affine2D().rotate_deg_around(0.0, -pitchdiff, self.roll) + self.axes.transData

        for idx, step in enumerate(range(-9, 10)):
            bar = self.pitchPatches[idx]
            width = self.calcPitchMarkerWidth(step)
            bar.set_xy((-width / 2.0, self.dist10deg * step - (self.thick / 2.0) - pitchdiff))
            bar.set_transform(rollRotate)

        # Labels follow the same shift/rotation and pick up the current font size.
        for idx, step in enumerate((-9, -6, -3, 3, 6, 9)):
            y = step * self.dist10deg - pitchdiff
            for label in (self.pitchLabelsLeft[idx], self.pitchLabelsRight[idx]):
                label.set_y(y)
                label.set_size(self.fontSize)
                label.set_rotation(self.roll)
                label.set_transform(rollRotate)

    def createAARText(self):
        """Add right-aligned airspeed / altitude / climb-rate readouts near the bottom-right."""
        x = self.rightPos - (self.vertSize / 10.0)
        self.airspeedText = self.axes.text(
            x,
            -0.97 + (2 * self.vertSize) - (self.vertSize / 10.0),
            "AS:   %.1f m/s" % self.airspeed,
            color="w",
            size=self.fontSize,
            ha="right",
        )
        self.altitudeText = self.axes.text(
            x,
            -0.97 + self.vertSize - (0.5 * self.vertSize / 10.0),
            "ALT: %.1f m   " % self.relAlt,
            color="w",
            size=self.fontSize,
            ha="right",
        )
        self.climbRateText = self.axes.text(
            x,
            -0.97,
            "CR:   %.1f m/s" % self.climbRate,
            color="w",
            size=self.fontSize,
            ha="right",
        )
        for label in (self.airspeedText, self.altitudeText, self.climbRateText):
            label.set_path_effects([PathEffects.withStroke(linewidth=1, foreground="k")])

    def updateAARLocations(self):
        """Re-anchor the airspeed/altitude/climb-rate labels to the right edge and font size."""
        x = self.rightPos - (self.vertSize / 10.0)
        placements = (
            (self.airspeedText, -0.97 + (2 * self.vertSize) - (self.vertSize / 10.0)),
            (self.altitudeText, -0.97 + self.vertSize - (0.5 * self.vertSize / 10.0)),
            (self.climbRateText, -0.97),
        )
        for label, y in placements:
            label.set_position((x, y))
            label.set_size(self.fontSize)

    def updateAARText(self):
        """Write the current airspeed, altitude and climb rate into their labels.

        Uses the same "AS:" / "ALT:" / "CR:" captions that createAARText
        draws initially, so the airspeed caption stays stable across updates.
        """
        self.airspeedText.set_text("AS:   %.1f m/s" % self.airspeed)
        self.altitudeText.set_text("ALT: %.1f m   " % self.relAlt)
        self.climbRateText.set_text("CR:   %.1f m/s" % self.climbRate)

    def createBatteryBar(self):
        """Creates the bar to display current battery percentage.

        Adds, near the top-right of the HUD:
        - a dark-grey outer frame and a green inner bar,
        - the remaining-percentage text below the bar,
        - voltage and current text to the left of the bar.

        The voltage/current positions are the same ones updateBatteryBar()
        applies, so the labels do not jump on the first update.
        """
        self.batOutRec = patches.Rectangle(
            (self.rightPos - (1.3 + self.rOffset) * self.batWidth, 1.0 - (0.1 + 1.0 + (2 * 0.075)) * self.batHeight),
            self.batWidth * 1.3,
            self.batHeight * 1.15,
            facecolor="darkgrey",
            edgecolor="none",
        )
        self.batInRec = patches.Rectangle(
            (self.rightPos - (self.rOffset + 1 + 0.15) * self.batWidth, 1.0 - (0.1 + 1 + 0.075) * self.batHeight),
            self.batWidth,
            self.batHeight,
            facecolor="lawngreen",
            edgecolor="none",
        )
        self.batPerText = self.axes.text(
            self.rightPos - (self.rOffset + 0.65) * self.batWidth,
            1 - (0.1 + 1 + (0.075 + 0.15)) * self.batHeight,
            "%.f" % self.batRemain,
            color="w",
            size=self.fontSize,
            ha="center",
            va="top",
        )
        self.batPerText.set_path_effects([PathEffects.withStroke(linewidth=1, foreground="k")])
        self.voltsText = self.axes.text(
            self.rightPos - (self.rOffset + 1.3 + 0.2) * self.batWidth,
            1 - (0.1 + 0.05) * self.batHeight,
            "%.1f V" % self.voltage,
            color="w",
            size=self.fontSize,
            ha="right",
            va="top",
        )
        self.ampsText = self.axes.text(
            self.rightPos - (self.rOffset + 1.3 + 0.2) * self.batWidth,
            1 - self.vertSize - (0.1 + 0.05 + 0.1) * self.batHeight,
            "%.1f A" % self.current,
            color="w",
            size=self.fontSize,
            ha="right",
            va="top",
        )
        self.voltsText.set_path_effects([PathEffects.withStroke(linewidth=1, foreground="k")])
        self.ampsText.set_path_effects([PathEffects.withStroke(linewidth=1, foreground="k")])

        self.axes.add_patch(self.batOutRec)
        self.axes.add_patch(self.batInRec)

    def updateBatteryBar(self):
        """Reposition the battery widgets and refresh their values and colour.

        For ``batRemain >= 0``, the inner bar's height is proportional to the
        percentage. It is coloured lawngreen above 50%, yellow above 20% and
        red otherwise. A ``batRemain`` of -1 shows a full-height black bar.
        """
        bat_w, bat_h = self.batWidth, self.batHeight

        # Bar and percentage text
        self.batOutRec.set_xy(
            (self.rightPos - (1.3 + self.rOffset) * bat_w, 1.0 - (0.1 + 1.0 + (2 * 0.075)) * bat_h)
        )
        self.batInRec.set_xy(
            (self.rightPos - (self.rOffset + 1 + 0.15) * bat_w, 1.0 - (0.1 + 1 + 0.075) * bat_h)
        )
        self.batPerText.set_position(
            (self.rightPos - (self.rOffset + 0.65) * bat_w, 1 - (0.1 + 1 + (0.075 + 0.15)) * bat_h)
        )
        self.batPerText.set_fontsize(self.fontSize)

        # Voltage / current readouts
        label_x = self.rightPos - (self.rOffset + 1.3 + 0.2) * bat_w
        self.voltsText.set_text("%.1f V" % self.voltage)
        self.ampsText.set_text("%.1f A" % self.current)
        self.voltsText.set_position((label_x, 1 - (0.1 + 0.05) * bat_h))
        self.ampsText.set_position((label_x, 1 - self.vertSize - (0.1 + 0.05 + 0.1) * bat_h))
        for label in (self.voltsText, self.ampsText):
            label.set_fontsize(self.fontSize)

        remaining = self.batRemain
        if remaining >= 0:
            self.batPerText.set_text(int(remaining))
            self.batInRec.set_height(remaining * bat_h / 100.0)
            fraction = remaining / 100.0
            if fraction > 0.5:
                colour = "lawngreen"
            elif fraction > 0.2:
                colour = "yellow"
            else:
                colour = "r"
            self.batInRec.set_facecolor(colour)
        elif remaining == -1:
            self.batInRec.set_height(bat_h)
            self.batInRec.set_facecolor("k")

    def createStateText(self):
        """Add the flight-mode label (grey "UNKNOWN") in the top-left corner."""
        self.modeText = self.axes.text(
            self.leftPos + (self.vertSize / 10.0),
            0.97,
            "UNKNOWN",
            color="grey",
            size=1.5 * self.fontSize,
            ha="left",
            va="top",
        )
        outline = PathEffects.withStroke(linewidth=self.fontSize / 10.0, foreground="black")
        self.modeText.set_path_effects([outline])

    def updateStateText(self):
        """Show the current mode, coloured by arm state.

        The label is red with a yellow outline when ``armed`` is truthy and
        light green when ``armed == False``. Any other value, including the
        initial empty string, gives grey.
        """
        label = self.modeText
        label.set_position((self.leftPos + (self.vertSize / 10.0), 0.97))
        label.set_text(self.mode)
        label.set_size(1.5 * self.fontSize)

        if self.armed:
            label.set_color("red")
            label.set_path_effects(
                [PathEffects.withStroke(linewidth=self.fontSize / 10.0, foreground="yellow")]
            )
            return

        # Equality test (not `not armed`) so "" falls through to the grey case.
        if self.armed == False:  # noqa: E712
            label.set_color("lightgreen")
            label.set_bbox(None)
            label.set_path_effects([PathEffects.withStroke(linewidth=1, foreground="black")])
        else:
            label.set_color("grey")
            label.set_bbox(None)
            label.set_path_effects([PathEffects.withStroke(linewidth=self.fontSize / 10.0, foreground="black")])

    def createWPText(self):
        """Add the waypoint label: current/final waypoint and distance/time to the next one."""
        x = self.leftPos + (1.5 * self.vertSize / 10.0)
        y = 0.97 - (1.5 * self.vertSize) + (0.5 * self.vertSize / 10.0)
        self.wpText = self.axes.text(
            x,
            y,
            "0/0\n(0 m, 0 s)",
            color="w",
            size=self.fontSize,
            ha="left",
            va="top",
        )
        self.wpText.set_path_effects([PathEffects.withStroke(linewidth=1, foreground="black")])

    def updateWPText(self):
        """Refresh the waypoint label's position, size and contents.

        When ``nextWPTime`` is exactly a ``str``, the time is shown as "~".
        """
        label = self.wpText
        label.set_position(
            (self.leftPos + (1.5 * self.vertSize / 10.0), 0.97 - (1.5 * self.vertSize) + (0.5 * self.vertSize / 10.0))
        )
        label.set_size(self.fontSize)
        eta_unknown = type(self.nextWPTime) is str
        if eta_unknown:
            content = "%.f/%.f\n(%.f m, ~ s)" % (self.currentWP, self.finalWP, self.wpDist)
        else:
            content = "%.f/%.f\n(%.f m, %.f s)" % (self.currentWP, self.finalWP, self.wpDist, self.nextWPTime)
        label.set_text(content)

    def createWPPointer(self):
        """Add the waypoint pointer: a lime triangle plus a bearing label beneath it."""
        self.headingWPTri = patches.RegularPolygon((0.0, 0.55), 3, 0.05, facecolor="lime", zorder=4, ec="k")
        self.axes.add_patch(self.headingWPTri)
        label_style = dict(
            color="lime",
            size=self.fontSize,
            horizontalalignment="center",
            verticalalignment="center",
            zorder=4,
        )
        self.headingWPText = self.axes.text(0.0, 0.45, "1", **label_style)
        self.headingWPText.set_path_effects([PathEffects.withStroke(linewidth=1, foreground="k")])

    def adjustWPPointer(self):
        """Rotate the waypoint pointer by (heading - wpBearing) and label it.

        The label shows the relative bearing ``wpBearing - heading``, wrapped
        once into non-negative degrees. The text is flipped by 180 degrees
        while that angle is strictly between 90 and 270.
        """
        self.headingWPText.set_size(self.fontSize)
        rotation = (
            mpl.transforms.Affine2D().rotate_deg_around(0.0, 0.0, -self.wpBearing + self.heading) + self.axes.transData
        )
        self.headingWPText.set_transform(rotation)
        angle = self.wpBearing - self.heading
        if angle < 0:
            angle += 360
        text_rotation = angle - 180 if 90 < angle < 270 else angle
        self.headingWPText.set_rotation(-text_rotation)
        self.headingWPTri.set_transform(rotation)
        self.headingWPText.set_text("%.f" % angle)

    def createAltHistoryPlot(self):
        """Create the altitude-history widget.

        It consists of a translucent grey 0.5 x 0.5 background box, a black
        line for the altitude trace, a black dot for the latest sample, and a
        text label with the current relative altitude in metres. All are
        drawn at zorder 4.
        """
        left = self.leftPos + (self.vertSize / 10.0)
        right = left + 0.5

        self.altHistRect = patches.Rectangle(
            (left, -0.25), 0.5, 0.5,
            facecolor="grey", edgecolor="none", alpha=0.4, zorder=4,
        )
        self.axes.add_patch(self.altHistRect)

        (self.altPlot,) = self.axes.plot([left, right], [0.0, 0.0], color="k", marker=None, zorder=4)
        (self.altMarker,) = self.axes.plot(right, 0.0, marker="o", color="k", zorder=4)

        labelX = self.leftPos + (4 * self.vertSize / 10.0) + 0.5
        self.altText2 = self.axes.text(
            labelX, 0.0, "%.f m" % self.relAlt,
            color="k", size=self.fontSize, ha="left", va="center", zorder=4,
        )

    def updateAltHistory(self, histLim=10.0):
        """Record the current altitude and redraw the altitude-history plot.

        Samples older than ``histLim`` seconds are discarded. The remaining
        history is scaled into the 0.5 x 0.5 box whose lower-left corner is
        at (leftPos + vertSize/10, -0.25):
          * time runs left to right across the box;
          * altitude runs from 0 up to ``self.altMax``, the highest altitude
            seen so far (updated here).
        The marker and text label follow the newest sample.

        Args:
            histLim: length of the displayed history window, in seconds
                (defaults to the previously hard-coded 10 s).
        """
        now = time.time()
        self.altHist.append(self.relAlt)
        self.timeHist.append(now)

        # Keep everything from the first sample newer than the cutoff onward.
        cutoff = now - histLim
        first = next((i for i, t in enumerate(self.timeHist) if t > cutoff), len(self.timeHist) - 1)
        self.altHist = self.altHist[first:]
        self.timeHist = self.timeHist[first:]

        x0 = self.leftPos + (self.vertSize / 10.0)
        yBottom, yTop = -0.25, 0.25
        tmin = min(self.timeHist)
        tmax = max(self.timeHist)
        altMin = 0
        # Keep the altitude scale at the maximum seen for the whole mission.
        self.altMax = max(self.altMax, max(self.altHist))
        altMax = self.altMax

        # A degenerate range collapses the axis instead of dividing by zero.
        mx = 0.5 / (tmax - tmin) if tmax != tmin else 0.0
        my = 0.5 / (altMax - altMin) if altMax != altMin else 0.0

        x = [mx * (t - tmin) + x0 for t in self.timeHist]
        # Clamp into the box to crop extreme noise (e.g. negative altitudes).
        y = [min(max(my * (alt - altMin) + yBottom, yBottom), yTop) for alt in self.altHist]
        latest = y[-1]

        # Display plot
        self.altHistRect.set_x(x0)
        self.altPlot.set_data(x, y)
        # Line2D.set_data expects sequences; scalar data is deprecated in
        # recent matplotlib releases.
        self.altMarker.set_data([x0 + 0.5], [latest])
        self.altText2.set_position((self.leftPos + (4 * self.vertSize / 10.0) + 0.5, latest))
        self.altText2.set_size(self.fontSize)
        self.altText2.set_text("%.f m" % self.relAlt)

    # =============== Event Bindings =============== #
    def on_idle(self, event):
        """Idle handler: re-lay out the whole HUD after a window resize.

        checkReszie() is called first. If ``self.resized`` is then set:
          * every HUD element is rescaled and repositioned;
          * the canvas is redrawn;
          * the flag is cleared.
        The handler always ends with a 50 ms sleep, which blocks the calling
        thread.
        """
        self.checkReszie()

        if self.resized:
            relayout = (
                self.rescaleX,              # window scales
                self.calcFontScaling,
                self.calcHorizonPoints,     # horizon polygons
                self.updateRPYLocations,    # roll, pitch, yaw text
                self.updateAARLocations,    # airspeed, altitude, climb rate
                self.adjustPitchmarkers,
                self.adjustHeadingPointer,  # heading pointer
                self.adjustNorthPointer,    # north pointer
                self.updateBatteryBar,
                self.updateStateText,       # mode and state
                self.updateWPText,
                self.adjustWPPointer,
                self.updateAltHistory,
            )
            for step in relayout:
                step()

            # Push the new layout to the matplotlib canvas.
            self.canvas.draw()
            self.canvas.Refresh()
            self.resized = False

        time.sleep(0.05)

    def on_timer(self, event):
        """Main loop, run on every timer tick.

        If ``state.close_event`` is set, the timer is stopped and the window
        destroyed. Otherwise the tick:
          * applies any pending resize;
          * drains every message waiting on ``state.child_pipe_recv``,
            updating the HUD elements that depend on each message type;
          * redraws the canvas and refreshes the frame.
        """
        state = self.state
        if state.close_event.wait(0.001):
            self.timer.Stop()
            self.Destroy()
            return

        # Apply a pending resize before drawing new data.
        self.checkReszie()
        if self.resized:
            self.on_idle(0)

        pipe = state.child_pipe_recv
        while pipe.poll():
            msg = pipe.recv()
            self.calcFontScaling()

            if isinstance(msg, Attitude):
                # Angles are converted from radians to degrees.
                self.oldRoll = self.roll
                self.pitch = msg.pitch * 180 / math.pi
                self.roll = msg.roll * 180 / math.pi
                self.yaw = msg.yaw * 180 / math.pi
                self.updateRPYText()
                self.calcHorizonPoints()
                self.adjustPitchmarkers()
            elif isinstance(msg, VFR_HUD):
                self.heading = msg.heading
                self.airspeed = msg.airspeed
                self.climbRate = msg.climbRate
                self.updateAARText()
                self.adjustHeadingPointer()
                self.adjustNorthPointer()
            elif isinstance(msg, Global_Position_INT):
                self.relAlt = msg.relAlt
                self.updateAARText()
                self.updateAltHistory()
            elif isinstance(msg, BatteryInfo):
                self.voltage = msg.voltage
                self.current = msg.current
                self.batRemain = msg.batRemain
                self.updateBatteryBar()
            elif isinstance(msg, FlightState):
                self.mode = msg.mode
                self.armed = msg.armState
                self.updateStateText()
            elif isinstance(msg, WaypointInfo):
                self.currentWP = msg.current
                self.finalWP = msg.final
                self.wpDist = msg.currentDist
                self.nextWPTime = msg.nextWPTime
                # Shift negative bearings up by 360 degrees.
                bearing = msg.wpBearing
                self.wpBearing = bearing + 360 if bearing < 0.0 else bearing
                self.updateWPText()
                self.adjustWPPointer()

        # Redraw the matplotlib canvas and the frame.
        self.canvas.draw()
        self.canvas.Refresh()
        self.Refresh()
        self.Update()

    def on_KeyPress(self, event):
        """Handle keyboard shortcuts.

        Up/Down arrows change the distance between pitch markers
        (``dist10deg``) in steps of 0.1; the value is reset to 0.1 if it
        would drop to zero or below. Keys 1-6 toggle groups of HUD widgets:

        1: modeText, wpText
        2: battery bar rectangles and volts/amps/percent text
        3: roll/pitch/yaw text
        4: airspeed/altitude/climb-rate text
        5: altitude-history plot (box, line, marker, label)
        6: heading, north and waypoint pointers with their labels

        The canvas and frame are redrawn after every key press.
        """
        keyCode = event.GetKeyCode()
        if keyCode == wx.WXK_UP:
            self.dist10deg += 0.1
            # print(...) with one argument works on Python 2 and Python 3.
            print("Dist per 10 deg: %.1f" % self.dist10deg)
        elif keyCode == wx.WXK_DOWN:
            self.dist10deg -= 0.1
            if self.dist10deg <= 0:
                self.dist10deg = 0.1
            print("Dist per 10 deg: %.1f" % self.dist10deg)
        else:
            # Toggle Widgets: number keys '1'..'6' (key codes 49..54).
            widgetGroups = {
                ord("1"): ("modeText", "wpText"),
                ord("2"): ("batOutRec", "batInRec", "voltsText", "ampsText", "batPerText"),
                ord("3"): ("rollText", "pitchText", "yawText"),
                ord("4"): ("airspeedText", "altitudeText", "climbRateText"),
                ord("5"): ("altHistRect", "altPlot", "altMarker", "altText2"),
                ord("6"): (
                    "headingTri",
                    "headingText",
                    "headingNorthTri",
                    "headingNorthText",
                    "headingWPTri",
                    "headingWPText",
                ),
            }
            names = widgetGroups.get(keyCode)
            if names is not None:
                self.toggleWidgets([getattr(self, name) for name in names])

        # Update Matplotlib Plot
        self.canvas.draw()
        self.canvas.Refresh()

        self.Refresh()
        self.Update()
Ejemplo n.º 35
0
class PlotController(DialogMixin):
    """
        A base class for matplotlib-canvas controllers that sets up the
        widgets and has image exporting functionality.

        Subclasses must implement setup_content().
    """

    # (filter name, glob pattern) pairs offered in the save dialog.
    file_filters = (
        ("Portable Network Graphics (PNG)", "*.png"),
        ("Scalable Vector Graphics (SVG)", "*.svg"),
        ("Portable Document Format (PDF)", "*.pdf"),
    )

    _canvas = None

    @property
    def canvas(self):
        """The GTK figure canvas; figure, canvas and content are (re)built if missing."""
        if self._canvas is None:
            self.setup_figure()
            self.setup_canvas()
            self.setup_content()
        return self._canvas

    # ------------------------------------------------------------
    #      Initialisation and other internals
    # ------------------------------------------------------------
    def __init__(self):
        self._proxies = dict()
        self.setup_figure()
        self.setup_canvas()
        self.setup_content()

    def setup_figure(self):
        """Create the matplotlib figure, coloured after the default GTK style."""
        style = gtk.Style()
        self.figure = Figure(dpi=72, edgecolor=str(style.bg[2]), facecolor=str(style.bg[0]))
        self.figure.subplots_adjust(hspace=0.0, wspace=0.0)

    def setup_canvas(self):
        """Wrap the figure in a GTK canvas widget."""
        self._canvas = FigureCanvasGTK(self.figure)

    def setup_content(self):
        """Populate the figure; must be implemented by subclasses."""
        raise NotImplementedError

    # ------------------------------------------------------------
    #      Update subroutines
    # ------------------------------------------------------------
    def draw(self):
        """Redraw the canvas, logging (not raising) mathtext parse errors."""
        try:
            self.figure.canvas.draw()
            self.fix_after_drawing()
        except ParseFatalException:
            logger.exception("Caught unhandled exception when drawing")

    def fix_after_drawing(self):
        """Hook called after each successful draw; no-op by default."""
        pass  # nothing to fix

    # ------------------------------------------------------------
    #      Graph exporting
    # ------------------------------------------------------------
    def save(self, parent=None, suggest_name="graph", size="auto", num_specimens=1, offset=0.75):
        """
            Displays a save dialog to export an image from the current plot.

            Arguments:
             parent: parent window for the dialog
             suggest_name: file name initially suggested in the dialog
             size: "auto" to start from the first settings.OUTPUT_PRESETS
                   entry, or a "WIDTHxHEIGHT@DPI" string (width and height
                   in pixels)
             num_specimens, offset: not used by this implementation
        """
        # Parse arguments:
        if size == "auto":
            _descr, width, height, dpi = settings.OUTPUT_PRESETS[0]
        else:
            width, height, dpi = map(float, size.replace("@", "x").split("x"))

        # Load gui:
        builder = gtk.Builder()
        builder.add_from_file(resource_filename("pyxrd.specimen", "glade/save_graph_size.glade")) # FIXME move this to this namespace!!
        size_expander = builder.get_object("size_expander")
        cmb_presets = builder.get_object("cmb_presets")

        # Setup combo with presets:
        cmb_store = gtk.ListStore(str, int, int, float)
        for row in settings.OUTPUT_PRESETS:
            cmb_store.append(row)
        cmb_presets.clear()
        cmb_presets.set_model(cmb_store)
        cell = gtk.CellRendererText()
        cmb_presets.pack_start(cell, True)
        cmb_presets.add_attribute(cell, 'text', 0)

        def on_cmb_changed(cmb, *args):
            # Copy the selected preset into the width/height/dpi entries.
            itr = cmb_presets.get_active_iter()
            if itr is None:
                return  # nothing selected
            w, h, d = cmb_store.get(itr, 1, 2, 3)
            entry_w.set_text(str(w))
            entry_h.set_text(str(h))
            entry_dpi.set_text(str(d))
        cmb_presets.connect('changed', on_cmb_changed)

        # Setup input boxes:
        entry_w = builder.get_object("entry_width")
        entry_h = builder.get_object("entry_height")
        entry_dpi = builder.get_object("entry_dpi")
        entry_w.set_text(str(width))
        entry_h.set_text(str(height))
        entry_dpi.set_text(str(dpi))

        # What to do when the user wants to save this:
        def on_accept(dialog):
            # Get the selected file type and name:
            cur_fltr = dialog.get_filter()
            filename = dialog.get_filename()
            # Add the extension of the selected filter if not present yet:
            if cur_fltr is not None:
                fltr_name = cur_fltr.get_name()
                for name, pattern in self.file_filters:
                    if fltr_name == name:
                        extension = pattern[1:]  # "*.png" -> ".png"
                        if not filename.endswith(extension):
                            filename = "%s%s" % (filename, extension)
                        break
            # Get the width, height (pixels) & dpi and convert to inches:
            width = float(entry_w.get_text())
            height = float(entry_h.get_text())
            dpi = float(entry_dpi.get_text())
            i_width, i_height = width / dpi, height / dpi
            # Save it all right!
            self.save_figure(filename, dpi, i_width, i_height)

        # Ask the user where, how and if he wants to save:
        self.run_save_dialog("Save Graph", on_accept, None, parent=parent, suggest_name=suggest_name, extra_widget=size_expander)

    def save_figure(self, filename, dpi, i_width, i_height):
        """
            Save the current plot

            The figure's original DPI and size are restored afterwards, even
            when saving fails.

            Arguments:
             filename: the filename to save to (either .png, .pdf or .svg)
             dpi: Dots-Per-Inch resolution
             i_width: the width in inch
             i_height: the height in inch
        """
        # Get original settings:
        original_dpi = self.figure.get_dpi()
        original_width, original_height = self.figure.get_size_inches()
        try:
            # Set everything according to the user selection:
            self.figure.set_dpi(dpi)
            self.figure.set_size_inches((i_width, i_height))
            self.figure.canvas.draw()  # replot
            bbox_inches = matplotlib.transforms.Bbox.from_bounds(0, 0, i_width, i_height)
            # Save the figure:
            self.figure.savefig(filename, dpi=dpi, bbox_inches=bbox_inches)
        finally:
            # Put everything back the way it was:
            self.figure.set_dpi(original_dpi)
            self.figure.set_size_inches((original_width, original_height))
            self.figure.canvas.draw()  # replot
Ejemplo n.º 36
0
class MplView(FigureCanvas, BaseView):
    """
    Base class for matplotlib based views. This handles graph canvas setup, toolbar initialisation
    and figure save options. Subclass for your own graph-specific views.

    Ultimately, this is a QWidget (as well as a FigureCanvasAgg, etc.).
    """
    is_floatable_view = True
    is_mpl_toolbar_enabled = True

    def __init__(self, parent, width=5, height=4, dpi=100, **kwargs):
        """Create a figure with one axes and embed it as a Qt widget.

        Args:
            parent: owner stored as ``self.v``; its ``views`` attribute is
                used as the Qt parent widget.
            width, height: initial figure size in inches.
            dpi: figure resolution.
            **kwargs: accepted but unused here.
        """
        self.v = parent

        self.fig = Figure(figsize=(width, height), dpi=dpi)
        self.ax = self.fig.add_subplot(111)

        # NOTE(review): plots placeholder data in every new view; looks like
        # leftover demo code -- confirm subclasses clear/redraw the axes.
        self.ax.plot([1, 2, 3, 4])

        # Hide the top/right spines and keep ticks on the bottom/left only.
        self.ax.spines['top'].set_visible(False)
        self.ax.spines['right'].set_visible(False)
        self.ax.get_xaxis().tick_bottom()
        self.ax.get_yaxis().tick_left()

        FigureCanvas.__init__(self, self.fig)

        self.setParent(parent.views)

        FigureCanvas.setSizePolicy(self,
                                   QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        # Install navigation handler; we need to provide a Qt interface that can handle multiple
        # plots in a window under separate tabs
        self.navigation = MplNavigationHandler(self)

    def generate(self):
        """Hook for subclasses to build their plot; does nothing here."""
        pass

    def saveAsImage(self, settings): # Size, dots per metre (for print), resample (redraw) image
        """Ask for a file name and save the figure with the print size and
        DPI taken from ``settings``; the on-screen size is restored
        afterwards, even if saving fails."""
        filename, _ = QFileDialog.getSaveFileName(self, 'Save current figure', '',  "Tagged Image File Format (*.tif);;\
                                                                                     Portable Document File (*.pdf);;\
                                                                                     Encapsulated Postscript File (*.eps);;\
                                                                                     Scalable Vector Graphics (*.svg);;\
                                                                                     Portable Network Graphics (*.png)")

        if filename:
            size = settings.get_print_size('in')
            dpi = settings.get_dots_per_inch()
            prev_size = self.fig.get_size_inches()
            try:
                self.fig.set_size_inches(*size)
                self.fig.savefig(filename, dpi=dpi)
            finally:
                self.fig.set_size_inches(*prev_size)
                self.redraw()

    def redraw(self):
        """Force the canvas to repaint by nudging the widget size."""
        #FIXME: Ugly hack to refresh the canvas
        self.resize(self.size() - QSize(1, 1))
        self.resize(self.size() + QSize(1, 1))

    def resizeEvent(self, e):
        FigureCanvas.resizeEvent(self, e)

    def get_text_bbox_screen_coords(self, t):
        """Return the 2x2 corner points of text artist ``t`` in display coordinates."""
        bbox = t.get_window_extent(self.get_renderer())
        return bbox.get_points()

    def get_text_bbox_data_coords(self, t):
        """Return the 2x2 corner points of text artist ``t`` in ``self.ax`` data coordinates."""
        bbox = t.get_window_extent(self.get_renderer())
        axbox = bbox.transformed(self.ax.transData.inverted())
        return axbox.get_points()

    def extend_limits(self, a, b):
        """Extend the limits ``a`` so they also cover the box ``b``.

        Args:
            a: ``[[xmin, xmax], [ymin, ymax]]`` limits to extend.
            b: 2x2 array of corners ``[[x0, y0], [x1, y1]]`` such as the
               ``get_text_bbox_*_coords`` results (must support ``b[:, 0]``).

        Returns:
            New ``[[xmin, xmax], [ymin, ymax]]`` lists; ``a`` is not modified.
        """
        xlim, ylim = list(a[0]), list(a[1])
        bx, by = b[:, 0], b[:, 1]

        xlim[0] = min(xlim[0], bx[0])
        xlim[1] = max(xlim[1], bx[1])

        ylim[0] = min(ylim[0], by[0])
        ylim[1] = max(ylim[1], by[1])

        return [xlim, ylim]
Ejemplo n.º 37
0
class BackendMatplotlib(BackendBase.BackendBase):
    """Base class for Matplotlib backend without a FigureCanvas.

    For interactive on screen plot, see :class:`BackendMatplotlibQt`.

    See :class:`BackendBase.BackendBase` for public API documentation.
    """

    def __init__(self, plot, parent=None):
        """Create the figure with a left axes and a twinned right axes.

        Args:
            plot: forwarded to BackendBase.BackendBase.__init__.
            parent: forwarded to BackendBase.BackendBase.__init__.
        """
        super(BackendMatplotlib, self).__init__(plot, parent)

        # matplotlib is handling keep aspect ratio at draw time
        # When keep aspect ratio is on, and one changes the limits and
        # ask them *before* next draw has been performed he will get the
        # limits without applying keep aspect ratio.
        # This attribute is used to ensure consistent values returned
        # when getting the limits at the expense of a replot
        self._dirtyLimits = True

        self.fig = Figure()
        self.fig.set_facecolor("w")

        self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
        self.ax2 = self.ax.twinx()
        self.ax2.set_label("right")

        # critical for picking!!!!
        self.ax2.set_zorder(0)
        self.ax2.set_autoscaley_on(True)
        self.ax.set_zorder(1)
        # this works but the figure color is left
        # ax is drawn above ax2, so it gets no background for ax2 to show
        # through. Axes.set_axis_bgcolor was removed in matplotlib 2.2;
        # prefer set_facecolor and fall back only on releases lacking it.
        if hasattr(self.ax, 'set_facecolor'):
            self.ax.set_facecolor('none')
        else:
            self.ax.set_axis_bgcolor('none')
        self.fig.sca(self.ax)

        self._overlays = set()
        self._background = None

        self._colormaps = {}

        self._graphCursor = tuple()
        self.matplotlibVersion = matplotlib.__version__

        self.setGraphXLimits(0., 100.)
        self.setGraphYLimits(0., 100., axis='right')
        self.setGraphYLimits(0., 100., axis='left')

        self._enableAxis('right', False)

    # Add methods

    def addCurve(self, x, y, legend,
                 color, symbol, linewidth, linestyle,
                 yaxis,
                 xerror, yerror, z, selectable,
                 fill):
        """Add a curve (or a per-point coloured scatter) and return its artists.

        If ``color`` is an array with one colour per point, the data is drawn
        as a scatter plot. A line is drawn through it only if ``linestyle``
        is set; that line uses the first point's colour. Otherwise a regular
        line plot is drawn. Error bars are added first so that they sit
        behind the curve.

        Args:
            yaxis: 'left' or 'right'; 'right' draws on (and enables) ax2.
            xerror, yerror: optional error-bar data (may be None).
            z: zorder applied to every created artist.
            selectable: if true, artists are pickable with a 3-point tolerance.
            fill: if true, fill between 1e-8 and the curve.

        Returns:
            A Container holding all created artists.
        """
        for parameter in (x, y, legend, color, symbol, linewidth, linestyle,
                          yaxis, z, selectable, fill):
            assert parameter is not None
        assert yaxis in ('left', 'right')

        # An RGBA colour given as integers is converted to floats in [0, 1].
        # numpy.float64 replaces the numpy.float alias removed in NumPy 1.24.
        if (len(color) == 4 and
                type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
            color = numpy.array(color, dtype=numpy.float64) / 255.

        if yaxis == "right":
            axes = self.ax2
            self._enableAxis("right", True)
        else:
            axes = self.ax

        picker = 3 if selectable else None

        artists = []  # All the artists composing the curve

        # First add errorbars if any so they are behind the curve
        if xerror is not None or yerror is not None:
            if hasattr(color, 'dtype') and len(color) == len(x):
                errorbarColor = 'k'
            else:
                errorbarColor = color

            # On Debian 7 at least, Nx1 array yerr does not seems supported
            if (yerror is not None and yerror.ndim == 2 and
                    yerror.shape[1] == 1 and len(x) != 1):
                yerror = numpy.ravel(yerror)

            errorbars = axes.errorbar(x, y, label=legend,
                                      xerr=xerror, yerr=yerror,
                                      linestyle=' ', color=errorbarColor)
            artists += list(errorbars.get_children())

        if hasattr(color, 'dtype') and len(color) == len(x):
            # scatter plot
            if color.dtype not in [numpy.float32, numpy.float64]:
                actualColor = color / 255.
            else:
                actualColor = color

            if linestyle not in ["", " ", None]:
                # scatter plot with an actual line ...
                # we need to assign a color ...
                curveList = axes.plot(x, y, label=legend,
                                      linestyle=linestyle,
                                      color=actualColor[0],
                                      linewidth=linewidth,
                                      picker=picker,
                                      marker=None)
                artists += list(curveList)

            scatter = axes.scatter(x, y,
                                   label=legend,
                                   color=actualColor,
                                   marker=symbol,
                                   picker=picker)
            artists.append(scatter)

            if fill:
                artists.append(axes.fill_between(
                    x, 1.0e-8, y, facecolor=actualColor[0], linestyle=''))

        else:  # Curve
            curveList = axes.plot(x, y,
                                  label=legend,
                                  linestyle=linestyle,
                                  color=color,
                                  linewidth=linewidth,
                                  marker=symbol,
                                  picker=picker)
            artists += list(curveList)

            if fill:
                artists.append(
                    axes.fill_between(x, 1.0e-8, y,
                                      facecolor=color, linewidth=0))

        for artist in artists:
            artist.set_zorder(z)

        return Container(artists)

    def addImage(self, data, legend,
                 origin, scale, z,
                 selectable, draggable,
                 colormap):
        """Add an image to the main axes and return the image artist.

        Args:
            data: 2D array mapped through ``colormap``, or a 3D array whose
                last axis holds colour channels (presumably RGB/RGBA),
                displayed as-is.
            legend: image name; the artist label is ``"__IMAGE__" + legend``.
            origin: (x, y) data coordinates of the image origin.
            scale: (sx, sy) pixel size in data units. Negative values are
                allowed; the extent is reordered so that min <= max.
            z: zorder of the image.
            selectable, draggable: the image is pickable if either is true.
            colormap: dict read only for 2D data. Keys used:
                'name' (or 'colors' when name is None), 'normalization'
                ('linear' or starting with 'log'), 'autoscale', 'vmin'
                and 'vmax'.

        Returns:
            The AxesImage (or ModestImage) added to ``self.ax``.
        """
        # Non-uniform image
        # http://wiki.scipy.org/Cookbook/Histograms
        # Non-linear axes
        # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib
        import re

        for parameter in (data, legend, origin, scale, z,
                          selectable, draggable):
            assert parameter is not None

        h, w = data.shape[0:2]
        xmin = origin[0]
        xmax = xmin + scale[0] * w
        if scale[0] < 0.:
            xmin, xmax = xmax, xmin
        ymin = origin[1]
        ymax = ymin + scale[1] * h
        if scale[1] < 0.:
            ymin, ymax = ymax, ymin
        extent = (xmin, xmax, ymax, ymin)

        picker = (selectable or draggable)

        # Debian 7 specific support
        # No transparent colormap with matplotlib < 1.2.0
        # Add support for transparent colormap for uint8 data with
        # colormap with 256 colors, linear norm, [0, 255] range
        # Versions are compared numerically: a string comparison would
        # consider e.g. "1.10" older than "1.2".
        versionMatch = re.match(r'(\d+)\.(\d+)', matplotlib.__version__)
        if (versionMatch is not None and
                tuple(int(part) for part in versionMatch.groups()) < (1, 2)):
            if (len(data.shape) == 2 and colormap['name'] is None and
                    'colors' in colormap):
                colors = numpy.asarray(colormap['colors'])
                # Transparent if any alpha value (column 3) is not 255.
                if (colors.shape[-1] == 4 and
                        not numpy.all(numpy.equal(colors[..., 3], 255))):
                    # This is a transparent colormap
                    if (colors.shape == (256, 4) and
                            colormap['normalization'] == 'linear' and
                            not colormap['autoscale'] and
                            colormap['vmin'] == 0 and
                            colormap['vmax'] == 255 and
                            data.dtype == numpy.uint8):
                        # Supported case, convert data to RGBA
                        data = colors[data.reshape(-1)].reshape(
                            data.shape + (4,))
                    else:
                        _logger.warning(
                            'matplotlib %s does not support transparent '
                            'colormap.', matplotlib.__version__)

        # ModestImage is used above 5e5 pixels (presumably for display
        # speed), but custom origin/scale is not properly handled by it
        # for the time being.
        if tuple(origin) != (0., 0.) or tuple(scale) != (1., 1.):
            imageClass = AxesImage
        elif (data.shape[0] * data.shape[1]) > 5.0e5:
            imageClass = ModestImage
        else:
            imageClass = AxesImage

        # the normalization can be a source of time waste
        # Two possibilities, we receive data or a ready to show image
        if len(data.shape) == 3:
            # RGB(A) image, shown as-is
            # TODO: Possibility to mirror the image
            # in case of pixmaps just setting
            # extend = (xmin, xmax, ymax, ymin)
            # instead of (xmin, xmax, ymin, ymax)
            image = imageClass(self.ax,
                               label="__IMAGE__" + legend,
                               interpolation='nearest',
                               picker=picker,
                               zorder=z)

        else:
            assert colormap is not None

            if colormap['name'] is not None:
                cmap = self.__getColormap(colormap['name'])
            else:  # No name, use custom colors
                if 'colors' not in colormap:
                    raise ValueError(
                        'addImage: colormap no name nor list of colors.')
                colors = numpy.array(colormap['colors'], copy=True)
                assert len(colors.shape) == 2
                assert colors.shape[-1] in (3, 4)
                if colors.dtype == numpy.uint8:
                    # Convert to float in [0., 1.]
                    colors = colors.astype(numpy.float32) / 255.
                cmap = ListedColormap(colors)

            if colormap['normalization'].startswith('log'):
                vmin, vmax = None, None
                if not colormap['autoscale']:
                    if colormap['vmin'] > 0.:
                        vmin = colormap['vmin']
                    if colormap['vmax'] > 0.:
                        vmax = colormap['vmax']

                    if vmin is None or vmax is None:
                        _logger.warning('Log colormap with negative bounds, ' +
                                        'changing bounds to positive ones.')
                    elif vmin > vmax:
                        _logger.warning('Colormap bounds are inverted.')
                        vmin, vmax = vmax, vmin

                # Set unset/negative bounds to positive bounds
                if vmin is None or vmax is None:
                    finiteData = data[numpy.isfinite(data)]
                    posData = finiteData[finiteData > 0]
                    if vmax is None:
                        # 1. as an ultimate fallback
                        vmax = posData.max() if posData.size > 0 else 1.
                    if vmin is None:
                        vmin = posData.min() if posData.size > 0 else vmax
                    if vmin > vmax:
                        vmin = vmax

                norm = LogNorm(vmin, vmax)

            else:  # Linear normalization
                if colormap['autoscale']:
                    finiteData = data[numpy.isfinite(data)]
                    if finiteData.size > 0:
                        vmin = finiteData.min()
                        vmax = finiteData.max()
                    else:
                        # No finite value to scale on (min() would raise).
                        vmin, vmax = 0., 1.
                else:
                    vmin = colormap['vmin']
                    vmax = colormap['vmax']
                    if vmin > vmax:
                        _logger.warning('Colormap bounds are inverted.')
                        vmin, vmax = vmax, vmin

                norm = Normalize(vmin, vmax)

            image = imageClass(self.ax,
                               label="__IMAGE__" + legend,
                               interpolation='nearest',
                               cmap=cmap,
                               extent=extent,
                               picker=picker,
                               zorder=z,
                               norm=norm)

        if image.origin == 'upper':
            image.set_extent((xmin, xmax, ymax, ymin))
        else:
            image.set_extent((xmin, xmax, ymin, ymax))
        image.set_data(data)

        self.ax.add_artist(image)

        return image

    def addItem(self, x, y, legend, shape, color, fill, overlay, z):
        """Add a shape item to the main axes and return its artist.

        Args:
            x, y: coordinates of the shape. For 'hline'/'vline' a sequence
                is accepted and only its last value is used.
            legend: artist label. The 'rectangle' shape does not set it.
            shape: one of 'line', 'hline', 'vline', 'rectangle', 'polygon',
                'polylines'.
            color: matplotlib colour of the item.
            fill: for 'rectangle' and 'polygon', draw a hatch pattern
                instead of leaving the shape empty.
            overlay: if true, the item is animated and tracked in
                ``self._overlays``.
            z: zorder of the item.

        Raises:
            NotImplementedError: for an unsupported ``shape``.
        """
        # numpy.asarray avoids copies like copy=False did, but does not raise
        # under NumPy >= 2 when x/y are plain lists.
        xView = numpy.asarray(x)
        yView = numpy.asarray(y)

        if shape == "line":
            item = self.ax.plot(x, y, label=legend, color=color,
                                linestyle='-', marker=None)[0]

        elif shape == "hline":
            if hasattr(y, "__len__"):
                y = y[-1]
            item = self.ax.axhline(y, label=legend, color=color)

        elif shape == "vline":
            if hasattr(x, "__len__"):
                x = x[-1]
            item = self.ax.axvline(x, label=legend, color=color)

        elif shape == 'rectangle':
            # Bounding rectangle of all given points, NaNs ignored.
            xMin = numpy.nanmin(xView)
            xMax = numpy.nanmax(xView)
            yMin = numpy.nanmin(yView)
            yMax = numpy.nanmax(yView)
            w = xMax - xMin
            h = yMax - yMin
            item = Rectangle(xy=(xMin, yMin),
                             width=w,
                             height=h,
                             fill=False,
                             color=color)
            if fill:
                item.set_hatch('.')

            self.ax.add_patch(item)

        elif shape in ('polygon', 'polylines'):
            xView = xView.reshape(1, -1)
            yView = yView.reshape(1, -1)
            item = Polygon(numpy.vstack((xView, yView)).T,
                           closed=(shape == 'polygon'),
                           fill=False,
                           label=legend,
                           color=color)
            if fill and shape == 'polygon':
                item.set_hatch('/')

            self.ax.add_patch(item)

        else:
            raise NotImplementedError("Unsupported item shape %s" % shape)

        item.set_zorder(z)

        if overlay:
            item.set_animated(True)
            self._overlays.add(item)

        return item

    def addMarker(self, x, y, legend, text, color,
                  selectable, draggable,
                  symbol, constraint, overlay):
        """Add a marker to the plot and return the created matplotlib artist.

        Depending on which coordinates are given, the marker is:

        - a point drawn with ``symbol`` (both ``x`` and ``y`` given),
        - a vertical line at ``x`` (only ``x`` given),
        - a horizontal line at ``y`` (only ``y`` given).

        :param x: X position of the marker, or None
        :param y: Y position of the marker, or None
        :param str legend: Marker name; used as the artist label with a
                           "__MARKER__" prefix
        :param text: Optional text shown next to the marker. The text artist
                     is stored on the returned artist as ``_infoText``.
        :param color: Matplotlib color of the marker and of its text
        :param bool selectable: If True (or draggable), enables picking
        :param bool draggable: If True (or selectable), enables picking
        :param symbol: Matplotlib marker symbol used for point markers
        :param constraint: Not used by this implementation
        :param bool overlay: If True, the artist is animated and added to
                             the overlay set
        :return: The created matplotlib artist
        :raises RuntimeError: If both ``x`` and ``y`` are None
        """
        legend = "__MARKER__" + legend

        if x is not None and y is not None:
            # Point marker: no connecting line, only the symbol.
            # Keep the last Line2D returned by plot().
            line = self.ax.plot(x, y, label=legend,
                                linestyle=" ",
                                color=color,
                                marker=symbol,
                                markersize=10.)[-1]

            if text is not None:
                # Round-trip the point through display coordinates and back;
                # only ytmp is used below (xtmp is unused).
                xtmp, ytmp = self.ax.transData.transform_point((x, y))
                inv = self.ax.transData.inverted()
                xtmp, ytmp = inv.transform_point((xtmp, ytmp))

                if symbol is None:
                    valign = 'baseline'
                else:
                    # With a symbol, the text hangs below the point and is
                    # indented, presumably to stay clear of the symbol.
                    valign = 'top'
                    text = "  " + text

                line._infoText = self.ax.text(x, ytmp, text,
                                              color=color,
                                              horizontalalignment='left',
                                              verticalalignment=valign)

        elif x is not None:
            line = self.ax.axvline(x, label=legend, color=color)
            if text is not None:
                text = " " + text
                # Place the text near the top of the left Y axis range,
                # 0.5% of that range below the upper bound.
                ymin, ymax = self.getGraphYLimits(axis='left')
                delta = abs(ymax - ymin)
                if ymin > ymax:
                    ymax = ymin
                ymax -= 0.005 * delta
                line._infoText = self.ax.text(x, ymax, text,
                                              color=color,
                                              horizontalalignment='left',
                                              verticalalignment='top')

        elif y is not None:
            line = self.ax.axhline(y, label=legend, color=color)

            if text is not None:
                text = " " + text
                # Place the text right-aligned near the right end of the
                # X range, 0.5% of that range inside the upper bound.
                xmin, xmax = self.getGraphXLimits()
                delta = abs(xmax - xmin)
                if xmin > xmax:
                    xmax = xmin
                xmax -= 0.005 * delta
                line._infoText = self.ax.text(xmax, y, text,
                                              color=color,
                                              horizontalalignment='right',
                                              verticalalignment='top')

        else:
            raise RuntimeError('A marker must at least have one coordinate')

        if selectable or draggable:
            # Enable matplotlib picking on the marker (picker value 5)
            line.set_picker(5)

        if overlay:
            line.set_animated(True)
            self._overlays.add(line)

        return line

    # Remove methods

    def remove(self, item):
        """Remove an artist from the plot.

        Markers may carry an attached text artist in ``_infoText``. That text
        is removed first and the attribute is reset to None. The artist is
        also dropped from the overlay set if it was registered there.

        :param item: The matplotlib artist to remove
        """
        carriesText = hasattr(item, "_infoText")
        if carriesText:
            attachedText = item._infoText
            attachedText.remove()
            item._infoText = None

        self._overlays.discard(item)
        item.remove()

    # Interaction methods

    def setGraphCursor(self, flag, color, linewidth, linestyle):
        """Enable or disable the crosshair cursor.

        When enabled, a hidden, animated horizontal/vertical line pair is
        created and stored in ``self._graphCursor``. When disabled, that pair
        (if any) is removed and ``self._graphCursor`` is reset to an empty
        tuple.

        :param bool flag: True to enable, False to disable
        :param color: Color of the cursor lines
        :param linewidth: Width of the cursor lines
        :param linestyle: Matplotlib line style of the cursor lines
        """
        if flag:
            lineh = self.ax.axhline(
                self.ax.get_ybound()[0], visible=False, color=color,
                linewidth=linewidth, linestyle=linestyle)
            lineh.set_animated(True)

            linev = self.ax.axvline(
                self.ax.get_xbound()[0], visible=False, color=color,
                linewidth=linewidth, linestyle=linestyle)
            linev.set_animated(True)

            self._graphCursor = lineh, linev
        elif self._graphCursor:
            # Test truthiness rather than "is not None": once disabled, the
            # cursor is reset to an empty tuple, and unpacking it on a second
            # disable would raise ValueError.
            lineh, linev = self._graphCursor
            lineh.remove()
            linev.remove()
            self._graphCursor = tuple()

    # Active curve

    def setActiveCurve(self, curve, active, color=None):
        """Highlight a curve with ``color`` or restore its original colors.

        When activating, each Line2D/LineCollection child gets ``color``, and
        each PathCollection child gets ``color`` as face and edge color. The
        color in use before the first activation is saved on the artist as
        ``_initialColor``. When deactivating, saved colors are restored
        (PathCollection edges are restored to the saved face colors) and the
        attribute is deleted.

        :param curve: Artist whose children are recolored
        :param bool active: True to highlight, False to restore
        :param color: Highlight color, used only when ``active`` is True
        """
        for artist in curve.get_children():
            if active:
                # Save the color only on the first activation: re-activating
                # an already active curve must not overwrite the saved
                # original color with the highlight color, or the original
                # could never be restored.
                if isinstance(artist, (Line2D, LineCollection)):
                    if not hasattr(artist, '_initialColor'):
                        artist._initialColor = artist.get_color()
                    artist.set_color(color)
                elif isinstance(artist, PathCollection):
                    if not hasattr(artist, '_initialColor'):
                        artist._initialColor = artist.get_facecolors()
                    artist.set_facecolors(color)
                    artist.set_edgecolors(color)
                else:
                    _logger.warning(
                        'setActiveCurve ignoring artist %s', str(artist))
            else:
                if hasattr(artist, '_initialColor'):
                    if isinstance(artist, (Line2D, LineCollection)):
                        artist.set_color(artist._initialColor)
                    elif isinstance(artist, PathCollection):
                        artist.set_facecolors(artist._initialColor)
                        artist.set_edgecolors(artist._initialColor)
                    else:
                        _logger.info(
                            'setActiveCurve ignoring artist %s', str(artist))
                    del artist._initialColor

    # Misc.

    def getWidgetHandle(self):
        """Return the widget that embeds the plot: the figure's canvas."""
        figure = self.fig
        return figure.canvas

    def _enableAxis(self, axis, flag=True):
        """Show or hide one of the two Y axes.

        :param str axis: 'left' for the main axes, 'right' for the twin axes
        :param bool flag: True to show (the default), False to hide
        """
        assert axis in ('right', 'left')
        if axis == 'right':
            target = self.ax2
        else:
            target = self.ax
        target.get_yaxis().set_visible(flag)

    def replot(self):
        """Base implementation: performs no rendering.

        Subclasses override this to actually draw. This base version only
        marks the limits as up to date and hides the right Y axis when it
        holds no line.
        """
        self._dirtyLimits = False
        # TODO images, markers? scatter plot? move in remove?
        # The right Y axis only supports curves for now, so it is hidden
        # as soon as it does not hold any line.
        rightAxisHasLines = bool(self.ax2.lines)
        if not rightAxisHasLines:
            self._enableAxis('right', False)

    def saveGraph(self, fileName, fileFormat, dpi):
        """Save the figure, then mark the plot as dirty.

        :param fileName: Path, or a file-like object such as a StringIO
        :param str fileFormat: Output format passed to ``savefig``
        :param dpi: Resolution, or None to use matplotlib's default
        """
        saveOptions = {'format': fileFormat}
        if dpi is not None:
            saveOptions['dpi'] = dpi
        self.fig.savefig(fileName, **saveOptions)
        self._plot._setDirtyPlot()

    # Graph labels

    def setGraphTitle(self, title):
        """Set the title shown above the main axes."""
        mainAxes = self.ax
        mainAxes.set_title(title)

    def setGraphXLabel(self, label):
        """Set the label of the X axis of the main axes."""
        mainAxes = self.ax
        mainAxes.set_xlabel(label)

    def setGraphYLabel(self, label, axis):
        """Set the Y label of the 'left' axes; any other value targets the
        right (twin) axes."""
        if axis == 'left':
            self.ax.set_ylabel(label)
        else:
            self.ax2.set_ylabel(label)

    # Graph limits

    def resetZoom(self, dataMargins):
        """Reset the limits of auto-scaled axes so that the data fits.

        Axes that are not in auto-scale mode (as reported by ``self._plot``)
        keep their current limits. When the data aspect ratio is kept, the
        limits are widened along one direction to match the figure ratio.

        :param dataMargins: Margins forwarded to
                            ``_utils.addMarginsToLimits``
        """
        xAuto = self._plot.isXAxisAutoScale()
        yAuto = self._plot.isYAxisAutoScale()

        if not xAuto and not yAuto:
            _logger.debug("Nothing to autoscale")
        else:  # Some axes to autoscale
            # Current limits, used to restore non auto-scaled axes below
            xLimits = self.getGraphXLimits()
            yLimits = self.getGraphYLimits(axis='left')
            y2Limits = self.getGraphYLimits(axis='right')

            # Get data range (defaults to [1, 100] when there is no data)
            ranges = self._plot.getDataRange()
            xmin, xmax = (1., 100.) if ranges.x is None else ranges.x
            ymin, ymax = (1., 100.) if ranges.y is None else ranges.y
            if ranges.yright is None:
                ymin2, ymax2 = None, None
            else:
                ymin2, ymax2 = ranges.yright

            # Add margins around data inside the plot area
            newLimits = list(_utils.addMarginsToLimits(
                dataMargins,
                self.ax.get_xscale() == 'log',
                self.ax.get_yscale() == 'log',
                xmin, xmax, ymin, ymax, ymin2, ymax2))

            if self.isKeepDataAspectRatio():
                # Compute bbox with figure aspect ratio
                figW, figH = self.fig.get_size_inches()
                figureRatio = figH / figW

                xSpan = xmax - xmin
                ySpan = ymax - ymin
                if xSpan != 0:
                    dataRatio = ySpan / xSpan
                elif ySpan != 0:
                    # Zero-width x range (e.g. a single x value): avoid a
                    # ZeroDivisionError; only the x range needs widening.
                    dataRatio = float('inf')
                else:
                    # Single point: no aspect-ratio correction is possible
                    dataRatio = figureRatio

                if dataRatio < figureRatio:
                    # Increase y range
                    ycenter = 0.5 * (newLimits[3] + newLimits[2])
                    yrange = xSpan * figureRatio
                    newLimits[2] = ycenter - 0.5 * yrange
                    newLimits[3] = ycenter + 0.5 * yrange

                elif dataRatio > figureRatio:
                    # Increase x range
                    xcenter = 0.5 * (newLimits[1] + newLimits[0])
                    xrange_ = ySpan / figureRatio
                    newLimits[0] = xcenter - 0.5 * xrange_
                    newLimits[1] = xcenter + 0.5 * xrange_

            self.setLimits(*newLimits)

            # Restore the limits of the axes that are not auto-scaled
            if not xAuto and yAuto:
                self.setGraphXLimits(*xLimits)
            elif xAuto and not yAuto:
                if y2Limits is not None:
                    self.setGraphYLimits(
                        y2Limits[0], y2Limits[1], axis='right')
                if yLimits is not None:
                    self.setGraphYLimits(yLimits[0], yLimits[1], axis='left')

    def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
        """Set X, left Y and (optionally) right Y limits at once.

        Each pair of bounds may be given in any order. Y limits are applied
        in decreasing order when the Y axis is inverted. Keeping the aspect
        ratio, if enabled, is left to matplotlib.
        """
        self._dirtyLimits = True
        self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))

        inverted = self.isYAxisInverted()

        def _yBounds(first, second):
            low, high = min(first, second), max(first, second)
            return (high, low) if inverted else (low, high)

        hasRightLimits = y2min is not None and y2max is not None
        if hasRightLimits:
            self.ax2.set_ylim(*_yBounds(y2min, y2max))

        self.ax.set_ylim(*_yBounds(ymin, ymax))

    def getGraphXLimits(self):
        """Return the (min, max) bounds of the X axis.

        With keep-aspect-ratio enabled and pending limit changes, replot()
        is called first so the returned bounds account for those changes.
        """
        needsRefresh = self._dirtyLimits and self.isKeepDataAspectRatio()
        if needsRefresh:
            self.replot()
        bounds = self.ax.get_xbound()
        return bounds

    def setGraphXLimits(self, xmin, xmax):
        """Set the X axis limits; the two bounds may come in any order."""
        self._dirtyLimits = True
        lower = min(xmin, xmax)
        upper = max(xmin, xmax)
        self.ax.set_xlim(lower, upper)

    def getGraphYLimits(self, axis):
        """Return the (min, max) bounds of the 'left' or 'right' Y axis.

        Returns None when the corresponding axes is not visible. With
        keep-aspect-ratio enabled and pending limit changes, replot() is
        called first so the bounds account for those changes.
        """
        assert axis in ('left', 'right')
        if axis == 'right':
            axes = self.ax2
        else:
            axes = self.ax

        if axes.get_visible():
            if self._dirtyLimits and self.isKeepDataAspectRatio():
                self.replot()
            return axes.get_ybound()
        return None

    def setGraphYLimits(self, ymin, ymax, axis):
        """Set the limits of the 'right' Y axis, or of the left one otherwise.

        The bounds may come in any order. When the data aspect ratio is kept,
        the X range is rescaled first (around its center) by the same factor
        as the Y range. matplotlib keeps the limits of the shared axis when
        keeping the aspect ratio, so adjusting X beforehand means it has no
        Y correction left to make.
        """
        if axis == 'right':
            axes = self.ax2
        else:
            axes = self.ax

        if ymax < ymin:
            low, high = ymax, ymin
        else:
            low, high = ymin, ymax
        self._dirtyLimits = True

        if self.isKeepDataAspectRatio():
            curXMin, curXMax = axes.get_xbound()
            curYMin, curYMax = axes.get_ybound()
            scaledXRange = ((curXMax - curXMin) * (high - low) /
                            (curYMax - curYMin))
            xMiddle = 0.5 * (curXMin + curXMax)
            axes.set_xlim(xMiddle - 0.5 * scaledXRange,
                          xMiddle + 0.5 * scaledXRange)

        if self.isYAxisInverted():
            axes.set_ylim(high, low)
        else:
            axes.set_ylim(low, high)

    # Graph axes

    def setXAxisLogarithmic(self, flag):
        """Switch the X scale of both axes between 'log' and 'linear'."""
        scale = 'log' if flag else 'linear'
        for axes in (self.ax2, self.ax):
            axes.set_xscale(scale)

    def setYAxisLogarithmic(self, flag):
        """Switch the Y scale of both axes between 'log' and 'linear'."""
        scale = 'log' if flag else 'linear'
        for axes in (self.ax2, self.ax):
            axes.set_yscale(scale)

    def setYAxisInverted(self, flag):
        """Point the main Y axis downward (truthy flag) or upward."""
        alreadyInverted = self.ax.yaxis_inverted()
        if alreadyInverted == bool(flag):
            return
        self.ax.invert_yaxis()

    def isYAxisInverted(self):
        """Tell whether the main Y axis points downward."""
        mainAxes = self.ax
        return mainAxes.yaxis_inverted()

    def isKeepDataAspectRatio(self):
        """Tell whether the main axes keeps a 1:1 data aspect ratio."""
        aspect = self.ax.get_aspect()
        return bool(aspect == 1.0 or aspect == 'equal')

    def setKeepDataAspectRatio(self, flag):
        """Apply an aspect of 1.0 (truthy flag) or 'auto' to both axes."""
        aspect = 1.0 if flag else 'auto'
        for axes in (self.ax, self.ax2):
            axes.set_aspect(aspect)

    def setGraphGrid(self, which):
        """Show grid lines for ``which`` ('major', 'minor', 'both') or none.

        :param which: Grid selection passed to matplotlib, or None to hide
                      the grid
        """
        # Start from a clean state: hide both major and minor grid lines
        self.ax.grid(False, which='both')
        if which is None:
            return
        self.ax.grid(True, which=which)

    # colormap

    def getSupportedColormaps(self):
        """Return the colormaps of the base backend, followed by the names
        in matplotlib's ``cm.datad`` in sorted order."""
        baseColormaps = super(BackendMatplotlib, self).getSupportedColormaps()
        matplotlibColormaps = sorted(cm.datad)
        return baseColormaps + matplotlibColormaps

    def __getColormap(self, name):
        """Return the matplotlib colormap called ``name``.

        Lookup order:

        1. the colormaps defined here ('red', 'green', 'blue', 'temperature',
           'reversed gray'), created on the first call,
        2. an attribute of ``MPLColormap`` with that name,
        3. matplotlib's own registered colormaps.

        :raises ValueError: If matplotlib does not know the colormap
        """
        if not self._colormaps:  # Lazy initialization of own colormaps
            cdict = {'red': ((0.0, 0.0, 0.0),
                             (1.0, 1.0, 1.0)),
                     'green': ((0.0, 0.0, 0.0),
                               (1.0, 0.0, 0.0)),
                     'blue': ((0.0, 0.0, 0.0),
                              (1.0, 0.0, 0.0))}
            self._colormaps['red'] = LinearSegmentedColormap(
                'red', cdict, 256)

            cdict = {'red': ((0.0, 0.0, 0.0),
                             (1.0, 0.0, 0.0)),
                     'green': ((0.0, 0.0, 0.0),
                               (1.0, 1.0, 1.0)),
                     'blue': ((0.0, 0.0, 0.0),
                              (1.0, 0.0, 0.0))}
            self._colormaps['green'] = LinearSegmentedColormap(
                'green', cdict, 256)

            cdict = {'red': ((0.0, 0.0, 0.0),
                             (1.0, 0.0, 0.0)),
                     'green': ((0.0, 0.0, 0.0),
                               (1.0, 0.0, 0.0)),
                     'blue': ((0.0, 0.0, 0.0),
                              (1.0, 1.0, 1.0))}
            self._colormaps['blue'] = LinearSegmentedColormap(
                'blue', cdict, 256)

            # Temperature as defined in spslut
            cdict = {'red': ((0.0, 0.0, 0.0),
                             (0.5, 0.0, 0.0),
                             (0.75, 1.0, 1.0),
                             (1.0, 1.0, 1.0)),
                     'green': ((0.0, 0.0, 0.0),
                               (0.25, 1.0, 1.0),
                               (0.75, 1.0, 1.0),
                               (1.0, 0.0, 0.0)),
                     'blue': ((0.0, 1.0, 1.0),
                              (0.25, 1.0, 1.0),
                              (0.5, 0.0, 0.0),
                              (1.0, 0.0, 0.0))}
            # but limited to 256 colors for a faster display (of the colorbar)
            self._colormaps['temperature'] = LinearSegmentedColormap(
                'temperature', cdict, 256)

            # reversed gray
            cdict = {'red':     ((0.0, 1.0, 1.0),
                                 (1.0, 0.0, 0.0)),
                     'green':   ((0.0, 1.0, 1.0),
                                 (1.0, 0.0, 0.0)),
                     'blue':    ((0.0, 1.0, 1.0),
                                 (1.0, 0.0, 0.0))}

            self._colormaps['reversed gray'] = LinearSegmentedColormap(
                'yerg', cdict, 256)

        if name in self._colormaps:
            return self._colormaps[name]
        elif hasattr(MPLColormap, name):  # viridis and sister colormaps
            return getattr(MPLColormap, name)
        else:
            # matplotlib built-in
            getCmap = getattr(cm, 'get_cmap', None)
            if getCmap is not None:
                return getCmap(name)
            # matplotlib.cm.get_cmap was removed in matplotlib 3.9: fall back
            # to the colormap registry (available since matplotlib 3.5).
            import matplotlib
            try:
                return matplotlib.colormaps[name]
            except KeyError:
                # Keep the exception type cm.get_cmap raised for unknown names
                raise ValueError('Unknown colormap: %s' % name)

    # Data <-> Pixel coordinates conversion

    def dataToPixel(self, x, y, axis):
        """Convert data coordinates to pixel (display) coordinates.

        :param x: X data coordinate(s): a scalar or a 1D sequence
        :param y: Y data coordinate(s), with the same length as ``x``
        :param str axis: 'right' to use the right Y axis, anything else for
                         the left one
        :return: (xPixel, yPixel), scalars or arrays matching the input
        """
        ax = self.ax2 if axis == "right" else self.ax

        # Arrange points as rows of (x, y): shape (2,) for scalars, (N, 2) for
        # sequences. transform_point((xs, ys)) would accept sequences of
        # length 2 but pair up the wrong values, and reject any other length.
        points = numpy.transpose((x, y))
        pixels = ax.transData.transform(points)
        # .T turns (N, 2) back into (2, N); it is a no-op for a single point.
        xPixel, yPixel = pixels.T
        return xPixel, yPixel

    def pixelToData(self, x, y, axis, check):
        """Convert pixel (display) coordinates to data coordinates.

        :param x: X pixel coordinate
        :param y: Y pixel coordinate
        :param str axis: 'right' to use the right Y axis, otherwise the left
        :param bool check: If True, return None for points outside the
                           current plot limits
        :return: (x, y) in data coordinates, or None (see ``check``)
        """
        if axis == "right":
            axes = self.ax2
        else:
            axes = self.ax

        toData = axes.transData.inverted()
        dataX, dataY = toData.transform_point((x, y))

        if not check:
            return dataX, dataY

        xmin, xmax = self.getGraphXLimits()
        ymin, ymax = self.getGraphYLimits(axis=axis)
        outside = (dataX > xmax or dataX < xmin or
                   dataY > ymax or dataY < ymin)
        if outside:
            return None  # (x, y) is out of plot area
        return dataX, dataY

    def getPlotBoundsInPixels(self):
        """Return (left, bottom, width, height) of the main axes in pixels.

        The window extent is converted to inches with the figure's inverse
        dpi transform, then scaled back by the figure dpi. The values are
        floats, not rounded to int.
        """
        windowExtent = self.ax.get_window_extent()
        inchBox = windowExtent.transformed(self.fig.dpi_scale_trans.inverted())
        dpi = self.fig.dpi
        left, bottom, width, height = inchBox.bounds
        return (left * dpi, bottom * dpi, width * dpi, height * dpi)
Ejemplo n.º 38
0
class MainPlotController(object):
    """
        A controller for the main plot canvas.
        Sets up the widgets and has image exporting functionality.
    """

    file_filters = ("Portable Network Graphics (PNG)", "*.png"), \
                   ("Scalable Vector Graphics (SVG)", "*.svg"), \
                   ("Portable Document Format (PDF)", "*.pdf")

    # Canvas widget, created by __init__ or lazily by the ``canvas`` property
    _canvas = None
    # Callbacks given to __init__, kept so that a lazy (re)creation of the
    # canvas can pass them on to setup_content()
    _status_callback = None
    _marker_callback = None

    @property
    def canvas(self):
        """The GTK figure canvas, (re)created on access when missing."""
        if self._canvas is None:
            self.setup_figure()
            self.setup_canvas()
            # setup_content() requires both callbacks: calling it without
            # arguments (as this path used to) raised a TypeError.
            self.setup_content(self._status_callback, self._marker_callback)
        return self._canvas

    # ------------------------------------------------------------
    #      View integration getters
    # ------------------------------------------------------------
    def get_toolbar_widget(self, window):
        """Return a navigation toolbar bound to the canvas and ``window``."""
        return NavigationToolbar(self.canvas, window)

    def get_canvas_widget(self):
        """Return the canvas widget to embed in the view."""
        return self.canvas

    # ------------------------------------------------------------
    #      Initialization and other internals
    # ------------------------------------------------------------
    def __init__(self, status_callback, marker_callback, *args, **kwargs):
        """
            status_callback: handed to the MotionTracker
            marker_callback: handed to the ClickCatcher
            Any extra positional or keyword arguments are ignored.
        """
        self._status_callback = status_callback
        self._marker_callback = marker_callback
        self.setup_layout_cache()
        self.setup_figure()
        self.setup_canvas()
        self.setup_content(status_callback, marker_callback)

    def setup_layout_cache(self):
        """Initialize the layout state used when positioning the plot."""
        self.position_setup = PositionSetup()
        self.labels = list()
        self.marker_lbls = list()
        self._proxies = dict()
        self.scale = 1.0
        self.stats = False
        self._last_pos = None

    def setup_figure(self):
        """Create the matplotlib figure (72 dpi, white background)."""
        self.figure = Figure(dpi=72, facecolor="#FFFFFF", linewidth=0)
        self.figure.subplots_adjust(hspace=0.0, wspace=0.0)

    def setup_canvas(self):
        """Create the GTK canvas wrapping the figure."""
        self._canvas = FigureCanvasGTK(self.figure)

    def setup_content(self, status_callback, marker_callback):
        """Create the main subplot and connect the canvas events."""
        # Create subplot and add it to the figure:
        self.plot = Subplot(self.figure, 211, facecolor=(1.0, 1.0, 1.0, 0.0))
        self.plot.set_autoscale_on(False)
        self.figure.add_axes(self.plot)

        # Connect events:
        self.canvas.mpl_connect('draw_event', self.fix_after_drawing)
        self.canvas.mpl_connect('resize_event', self.fix_after_drawing)

        self.mtc = MotionTracker(self, status_callback)
        self.cc = ClickCatcher(self, marker_callback)

    # ------------------------------------------------------------
    #      Update methods
    # ------------------------------------------------------------
    def draw(self):
        """Fix the layout, remember it, and redraw the canvas."""
        self._last_pos = self.fix_before_drawing()
        self.figure.canvas.draw()

    def fix_after_drawing(self, *args):
        """
            Event handler (draw/resize): recompute the layout and redraw
            only when it changed since the last recorded layout.
            Always returns False.
        """
        _new_pos = self.fix_before_drawing()
        changed = _new_pos != self._last_pos
        # Record the new layout before redrawing, so that a draw event
        # handled during the redraw sees it as already up to date.
        self._last_pos = _new_pos
        if changed:
            self.figure.canvas.draw()

        return False

    def fix_before_drawing(self, *args):
        """
            Fixes alignment issues due to longer labels or smaller windows
            Is executed after an initial draw event, since we can then retrieve
            the actual label dimensions and shift/resize the plot accordingly.

            Returns the position setup as a string, or False when no renderer
            is available or the canvas is not realized yet.
        """
        renderer = get_renderer(self.figure)
        if not renderer or not self._canvas.get_realized():
            return False

        # Fix left side for wide specimen labels:
        if len(self.labels) > 0:
            bbox = self._get_joint_bbox(self.labels, renderer)
            if bbox is not None:
                self.position_setup.left = self.position_setup.default_left + bbox.width
        # Fix top for high marker labels:
        if len(self.marker_lbls) > 0:
            bbox = self._get_joint_bbox([ label for label, flag, _ in self.marker_lbls if flag ], renderer)
            if bbox is not None:
                self.position_setup.top = self.position_setup.default_top - bbox.height
        # Fix bottom for x-axis label:
        bottom_label = self.plot.axis["bottom"].label
        if bottom_label is not None:
            bbox = self._get_joint_bbox([bottom_label], renderer)
            if bbox is not None:
                self.position_setup.bottom = self.position_setup.default_bottom + (bbox.ymax - bbox.ymin) * 2.0 # somehow we need this?

        # Calculate new plot position & set it:
        plot_pos = self.position_setup.position
        self.plot.set_position(plot_pos)

        # Adjust specimen label position
        for label in self.labels:
            label.set_x(plot_pos[0] - 0.025)

        # Adjust marker label position
        for label, flag, y_offset in self.marker_lbls:
            if flag:
                newy = plot_pos[1] + plot_pos[3] + y_offset - 0.025
                label.set_y(newy)

        _new_pos = self.position_setup.to_string()
        return _new_pos

    def update(self, clear=False, project=None, specimens=None):
        """
            Updates the entire plot with the given information.
        """
        if clear: self.plot.cla()

        if project and specimens:
            self.labels, self.marker_lbls = plot_specimens(
                self.plot, self.position_setup, self.cc,
                project, specimens
            )
            # get mixtures for the selected specimens:
            plot_mixtures(self.plot, project, [ mixture for mixture in project.mixtures if any(specimen in mixture.specimens for specimen in specimens) ])

        update_axes(
            self.plot, self.position_setup,
            project, specimens
        )

        self.draw()

    # ------------------------------------------------------------
    #      Plot position and size calculations
    # ------------------------------------------------------------
    def _get_joint_bbox(self, container, renderer):
        """
            Return the union of the bounding boxes of the texts in
            ``container``, in relative figure coordinates, or None when
            ``container`` is empty or an extent cannot be computed.
        """
        bboxes = []
        try:
            for text in container:
                bbox = text.get_window_extent(renderer=renderer)
                # the figure transform goes from relative coords->pixels and we
                # want the inverse of that (Bbox.inverse_transformed was
                # removed in matplotlib 3.5; this is its exact equivalent)
                bboxi = bbox.transformed(self.figure.transFigure.inverted())
                bboxes.append(bboxi)
        except (RuntimeError, ValueError):
            logger.exception("Caught unhandled exception when joining boundig boxes")
            return None # don't continue
        # this is the bbox that bounds all the bboxes, again in relative
        # figure coords
        if len(bboxes) > 0:
            bbox = transforms.Bbox.union(bboxes)
            return bbox
        else:
            return None

    # ------------------------------------------------------------
    #      Graph exporting
    # ------------------------------------------------------------
    def save(self, parent=None, current_name="graph", size="auto", num_specimens=1, offset=0.75):
        """
            Displays a save dialog to export an image from the current plot.

            size: "auto" for the first output preset, otherwise a
                  "WIDTHxHEIGHT@DPI" string
        """
        # Parse arguments:
        width, height = 0, 0
        if size == "auto":
            descr, width, height, dpi = settings.OUTPUT_PRESETS[0]
        else:
            width, height, dpi = list(map(float, size.replace("@", "x").split("x")))

        # Load gui:
        builder = Gtk.Builder()
        builder.add_from_file(resource_filename("pyxrd.specimen", "glade/save_graph_size.glade")) # FIXME move this to this namespace!!
        size_expander = builder.get_object("size_expander")
        cmb_presets = builder.get_object("cmb_presets")

        # Setup combo with presets:
        cmb_store = Gtk.ListStore(str, int, int, float)
        for row in settings.OUTPUT_PRESETS:
            cmb_store.append(row)
        cmb_presets.clear()
        cmb_presets.set_model(cmb_store)
        cell = Gtk.CellRendererText()
        cmb_presets.pack_start(cell, True)
        cmb_presets.add_attribute(cell, 'text', 0)
        def on_cmb_changed(cmb, *args):
            # Copy the selected preset into the entry boxes
            itr = cmb.get_active_iter()
            w, h, d = cmb_store.get(itr, 1, 2, 3)
            entry_w.set_text(str(w))
            entry_h.set_text(str(h))
            entry_dpi.set_text(str(d))
        cmb_presets.connect('changed', on_cmb_changed)

        # Setup input boxes:
        entry_w = builder.get_object("entry_width")
        entry_h = builder.get_object("entry_height")
        entry_dpi = builder.get_object("entry_dpi")
        entry_w.set_text(str(width))
        entry_h.set_text(str(height))
        entry_dpi.set_text(str(dpi))

        # What to do when the user wants to save this:
        def on_accept(dialog):
            # Get the width, height & dpi
            width = float(entry_w.get_text())
            height = float(entry_h.get_text())
            dpi = float(entry_dpi.get_text())
            # Pixel sizes -> inches
            i_width, i_height = width / dpi, height / dpi
            # Save it all right!
            self.save_figure(dialog.filename, dpi, i_width, i_height)

        # Ask the user where, how and if he wants to save:
        DialogFactory.get_save_dialog(
            "Save Graph", parent=parent,
            filters=self.file_filters, current_name=current_name,
            extra_widget=size_expander
        ).run(on_accept)

    def save_figure(self, filename, dpi, i_width, i_height):
        """
            Save the current plot

            Arguments:
             filename: the filename to save to (either .png, .pdf or .svg)
             dpi: Dots-Per-Inch resolution
             i_width: the width in inch
             i_height: the height in inch

            The on-screen dpi and size are restored afterwards, even when
            saving fails.
        """
        # Get original settings:
        original_dpi = self.figure.get_dpi()
        original_width, original_height = self.figure.get_size_inches()
        try:
            # Set everything according to the user selection:
            self.figure.set_dpi(dpi)
            self.figure.set_size_inches((i_width, i_height))
            self.figure.canvas.draw() # replot
            bbox_inches = matplotlib.transforms.Bbox.from_bounds(0, 0, i_width, i_height)
            # Save the figure:
            self.figure.savefig(filename, dpi=dpi, bbox_inches=bbox_inches)
        finally:
            # Put everything back the way it was:
            self.figure.set_dpi(original_dpi)
            self.figure.set_size_inches((original_width, original_height))
            self.figure.canvas.draw() # replot

    pass # end of class
Ejemplo n.º 39
0
class plotTableWindow(QMainWindow, Ui_plotTableWindow):
    """
    This class provides the view to manage the plots of the tables. It inherits from the Ui_plotTableWindow which is 
    a dialog class built by QTDesinger. This dialog contains two principal panel. The right panel is used to embed a Matplotlib figure canvas. 
    The left panel shows a form where the user can select the data to plot, give a name for each plot, design the label, etc.

    Attributes:
    - fig, Matplotlib Figure
    - canvas, Matplotlib FigureCanvas
    - axes, Matplitlib axes
    - plots, a dictionary of "plot" objects. When a new plot is added to the canvas, a new object containing its corresponding attributes 
    will be added to this dictionary. The keys of this dictionary are the labels of the plots.
    - view_tables, a dictionary with the widgets which contains a table (ascii table, votable or settable) in the main window. 
    This dictionary will be helpful for populating the table and the column combo-boxes  and for accessing the table data.
    """

    def __init__(self, view_tables, currentTable, parent=None):
        super(plotTableWindow, self).__init__(parent)
        self.setupUi(self)
        self.plots = {}
        self.view_tables = view_tables

        # Adding a empty plot figure
        # self.dpi = 100
        self.fig = Figure((6.0, 4.0), dpi=100)
        self.canvas = FigureCanvas(self.fig)
        self.canvas.setParent(self.plot_frame)
        self.axes = self.fig.add_subplot(111)
        mpl_toolbar = NavigationToolbar(self.canvas, self.plot_frame)
        self.plot_frame_Layout.addWidget(self.canvas)
        self.plot_frame_Layout.addWidget(mpl_toolbar)

        self.canvas.mpl_connect("pick_event", self.onpick)
        self.canvas.mpl_connect("button_press_event", self.onbuttonpress)

        # Generating  table list comboBox
        lineEdit = QLineEdit()
        lineEdit.setAlignment(Qt.AlignRight)
        lineEdit.setReadOnly(True)
        self.tableList.setLineEdit(lineEdit)
        for key, val in view_tables.iteritems():
            self.tableList.addItem(key, QVariant(key))

        i = self.tableList.findText(QString(currentTable))
        if i != -1:
            self.tableList.setCurrentIndex(i)

        # Generating  X/Y Axis list comboBox

        columnTitles = self.view_tables[currentTable].getColumnTitles()
        numericColumns = self.view_tables[currentTable].getNumericColumns()
        for index, title in enumerate(columnTitles):
            if numericColumns[index]:
                self.XAxis.addItem(title, QVariant(index))
                self.YAxis.addItem(title, QVariant(index))
                self.XError.addItem(title, QVariant(index))
                self.YError.addItem(title, QVariant(index))

        # Connecting in case a new table is open
        self.connect(self, SIGNAL("opentable"), self.updateTableList)
        # Connecting buttons
        self.connect(self.addButton, SIGNAL("clicked()"), self.addPlot)
        self.connect(self.delPlotButton, SIGNAL("clicked()"), self.delPlot)
        self.connect(self.plotButton, SIGNAL("clicked()"), self.plot)
        self.connect(self.showLegendBox, SIGNAL("stateChanged(int)"), self.enableLegend)
        self.connect(self.modifyPlotsBox, SIGNAL("stateChanged(int)"), self.enableModifyPlots)
        self.connect(self.xerrorCheck, SIGNAL("stateChanged(int)"), self.enableXErrorBars)
        self.connect(self.yerrorCheck, SIGNAL("stateChanged(int)"), self.enableYErrorBars)
        # Connecting tableList
        self.connect(self.tableList, SIGNAL("currentIndexChanged(int)"), self.updateColumnsList)

        # Connecting legend function
        self.connect(self.positionCombo, SIGNAL("currentIndexChanged(int)"), self.enablePosition)
        # Connection "save to table" button
        self.connect(self.saveButton, SIGNAL("clicked()"), self.updateTables)

        self.positionCombo.setCurrentIndex(1)  # Best
        self.textSizeSpin.setValue(12)
        self.ncolSpin.setValue(1)

    def onbuttonpress(self, event):
        """
        When the onbuttonpress event is raised, this function add the points corresponding to the point clicked by the user, 
        in case the "insert points" radio button is selected
        """
        if self.modifyPlotsBox.isChecked() and self.insertPointsRadio.isChecked():
            # if not self.picked:
            self.saveButton.setEnabled(True)
            # the click locations
            x = event.xdata
            y = event.ydata
            table = unicode(self.tableList.itemData(self.tableList.currentIndex()).toString())
            if len(self.plotList.selectedItems()) == 1:
                plot_selected = self.plotList.selectedItems()[0]

                if plot_selected != None:
                    plabel = plot_selected.text()
                    lines = self.axes.get_lines()
                    for l in lines:
                        label = l.get_label()
                        print plabel, label
                        if str(plabel) == str(label):
                            xdata = l.get_xdata()
                            ydata = l.get_ydata()
                            xs = numpy.concatenate((xdata, [x]))

                            ys = numpy.concatenate((ydata, [y]))
                            l.set_xdata(xs)
                            l.set_ydata(ys)
                            self.fig.canvas.draw()
                            self.plots[str(plabel)].currentX = xs.tolist()
                            self.plots[str(plabel)].currentY = ys.tolist()

    def onpick(self, event):
        """
        When the onpick event is raised, this function delete the points picked by the user, 
        in case the "delete points" radio button is selected
        """
        if self.modifyPlotsBox.isChecked() and self.deletePointsRadio.isChecked():

            thisline = event.artist
            xdata = thisline.get_xdata()
            ydata = thisline.get_ydata()
            label = thisline.get_label()
            print "label", label

            x = event.mouseevent.xdata
            y = event.mouseevent.ydata
            if x == None or y == None:
                return

            dx = numpy.array(abs(x - xdata[event.ind]), dtype=float)
            dy = numpy.array(abs(y - ydata[event.ind]), dtype=float)
            canvasSize = self.fig.get_size_inches()
            rangeX = abs(self.axes.get_xlim()[1] - self.axes.get_xlim()[0])
            rangeY = abs(self.axes.get_ylim()[1] - self.axes.get_ylim()[0])
            # Calculating limit distance on X
            # The limit  distance will be 0.01 inch
            limX = (0.05 * rangeX) / canvasSize[0]
            # Calculating limit  distance on Y
            # The limit  distance will be 0.01 inch
            limY = (0.05 * rangeY) / canvasSize[1]

            distances = numpy.hypot(dx, dy)
            indmin = distances.argmin()
            distanX = dx[indmin]
            distanY = dy[indmin]

            if distanX < limX and distanY < limY and str(label) in self.plots.keys():
                ind = event.ind[indmin]
                x = numpy.delete(xdata, ind)
                y = numpy.delete(ydata, ind)
                thisline.set_xdata(x)
                thisline.set_ydata(y)
                self.plots[str(label)].currentX = x.tolist()
                self.plots[str(label)].currentY = y.tolist()
                self.saveButton.setEnabled(True)
                self.canvas.draw()

    def enablePosition(self, position):
        if position == 0:
            self.xposSpin.setEnabled(True)
            self.yposSpin.setEnabled(True)

        else:
            self.xposSpin.setEnabled(False)
            self.yposSpin.setEnabled(False)

    def addPlot(self):

        if self.plotLabel.text() != "" and self.plotLabel.text() not in self.plots.keys():
            p = QPalette()
            p.setColor(QPalette.Base, QColor(255, 255, 255))
            self.plotLabel.setPalette(p)
            plabel = unicode(self.plotLabel.text())
            self.plotLabel.setText("")
            self.plotList.addItem(plabel)
            xcol = self.XAxis.itemData(self.XAxis.currentIndex()).toInt()[0]
            ycol = self.YAxis.itemData(self.YAxis.currentIndex()).toInt()[0]

            xcol = int(xcol)
            ycol = int(ycol)
            table = unicode(self.tableList.itemData(self.tableList.currentIndex()).toString())

            xaxis = self.view_tables[table].getColumn(xcol)
            yaxis = self.view_tables[table].getColumn(ycol)

            xerror = None
            yerror = None
            if self.xerrorCheck.checkState() == Qt.Checked:
                xerrcol = self.XError.itemData(self.XError.currentIndex()).toInt()[0]
                xerrcol = int(xerrcol)
                xerror = self.view_tables[table].getColumn(xerrcol)
            if self.yerrorCheck.checkState() == Qt.Checked:
                yerrcol = self.YError.itemData(self.YError.currentIndex()).toInt()[0]
                yerrcol = int(yerrcol)
                yerror = self.view_tables[table].getColumn(yerrcol)

            newplot = plot(plabel, xaxis, yaxis, table, xcol, ycol, xerror, yerror)
            self.plots[plabel] = newplot

        else:

            p = QPalette()
            p.setColor(QPalette.Base, QColor(255, 0, 0))
            self.plotLabel.setPalette(p)

    def editPlot(self):
        if len(self.plotList.selectedItems()) == 1:
            plot_selected = self.plotList.selectedItems()[0]

            if plot_selected != None:
                plabel = unicode(plot_selected.text())

            Dlg = editPlotDlg(self.plots[plabel], self)
            if Dlg.exec_():
                self.plots[plabel].label = Dlg.plotLabel.text()
                self.plots[plabel].color = Dlg.colorList.currentText()
                self.plots[plabel].style = Dlg.styleList.currentText()
                return

    def delPlot(self):
        if len(self.plotList.selectedItems()) == 1:
            plot_selected = self.plotList.selectedItems()[0]

            if plot_selected != None:
                plabel = plot_selected.text()
                del self.plots[unicode(plabel)]
                item = self.plotList.findItems(plabel, Qt.MatchExactly)

                if len(item) == 1:
                    r = self.plotList.row(item[0])
                    i = self.plotList.takeItem(r)
                    del i
            self.plot()

    def plot(self):
        self.axes.clear()

        for key, value in self.plots.iteritems():
            x = value.xAxis
            y = value.yAxis
            l = unicode(value.label)
            style = value.style
            color = value.color
            xerror = value.xError
            yerror = value.yError
            # After replot, the original data are recovered
            self.plots[l].currentX = x
            self.plots[l].currentY = y
            if xerror != None or yerror != None:
                line = self.axes.errorbar(x, y, xerr=xerror, yerr=yerror, fmt=None, label=l, color="b", ecolor="r")
            # else:

            line = self.axes.plot(x, y, marker="o", linestyle="none", label=l, picker=5)[0]

            if style != None:
                line.set_linestyle(STYLE[unicode(style)])
            if color != None:
                line.set_color(COLOR[unicode(color)])

        # Build the legend
        if self.legendFrame.isEnabled():
            pos = None
            if self.xposSpin.isEnabled():
                xpos = self.xposSpin.value()
                ypos = self.yposSpin.value()
                pos = (xpos, ypos)
            else:
                pos = unicode(self.positionCombo.currentText())

            title = self.titleLine.text()
            ncol = self.ncolSpin.value()
            textsize = self.textSizeSpin.value()
            if self.fancyBox.checkState() == Qt.Unchecked:
                fancy = False
            else:
                fancy = True

            if self.shadowBox.checkState() == Qt.Unchecked:
                shadow = False
            else:
                shadow = True

            self.axes.legend(loc=pos, title=title, ncol=ncol, fancybox=fancy, shadow=shadow, prop={"size": textsize})

        # Zoom out a bit to make bigger the plot screen (white part), in order the user can add points at the end or beginning of the plot
        xlim = self.axes.get_xlim()
        ylim = self.axes.get_ylim()
        offsetX = (xlim[1] - xlim[0]) / 10
        offsetY = (ylim[1] - ylim[0]) / 10
        self.axes.set_xlim(xlim[0] - offsetX, xlim[1] + offsetX)
        self.axes.set_ylim(ylim[0] - offsetY, ylim[1] + offsetY)

        self.canvas.draw()

    def updateTableList(self):
        print "updatetablelist"

    def updateColumnsList(self, index):
        item = self.tableList.itemData(index, Qt.UserRole).toString()
        if item:

            currentTable = unicode(item)
            columnTitles = self.view_tables[currentTable].getColumnTitles()
            numericColumns = self.view_tables[currentTable].getNumericColumns()
            self.XAxis.clear()
            self.YAxis.clear()
            self.XError.clear()
            self.YError.clear()
            for index, title in enumerate(columnTitles):
                if numericColumns[index]:
                    self.XAxis.addItem(title, QVariant(index))
                    self.YAxis.addItem(title, QVariant(index))
                    self.XError.addItem(title, QVariant(index))
                    self.YError.addItem(title, QVariant(index))

    def enableLegend(self, state):
        if state == Qt.Checked:
            self.legendFrame.setEnabled(True)
        else:
            self.legendFrame.setEnabled(False)

    def enableModifyPlots(self, state):
        if state == Qt.Checked:
            self.modifyPlotsFrame.setEnabled(True)
        else:
            self.modifyPlotsFrame.setEnabled(False)

    def enableXErrorBars(self, state):
        if state == Qt.Checked:
            self.XError.setEnabled(True)
        else:
            self.XError.setEnabled(False)

    def enableYErrorBars(self, state):
        if state == Qt.Checked:
            self.YError.setEnabled(True)
        else:
            self.YError.setEnabled(False)

    def updateTables(self):

        for key, value in self.plots.iteritems():
            tablename = value.tablename
            self.view_tables[tablename].updateColumn(value.xcol, value.currentX)
            self.view_tables[tablename].updateColumn(value.ycol, value.currentY)
Ejemplo n.º 40
0
class MyFrame(wx.Frame):
    def __init__(self, *args, **kwds):
        """
        Build the PyPlotter main frame.

        Creates the menus, status bar, toolbar, the scrolled ``viewer`` panel,
        the navigation/scaling controls and their event bindings, then embeds
        a Matplotlib figure canvas in ``viewer`` with two axes.

        All arguments are forwarded to ``wx.Frame.__init__``; the frame style
        is always overridden with ``wx.DEFAULT_FRAME_STYLE``.
        """
        # begin wxGlade: MyFrame.__init__
        kwds["style"] = wx.DEFAULT_FRAME_STYLE
        wx.Frame.__init__(self, *args, **kwds)

        # Menu Bar
        # Menu ids (10, 11, 22, 30) are bound to handlers further below.
        self.frame_1_menubar = wx.MenuBar()
        wxglade_tmp_menu = wx.Menu()
        wxglade_tmp_menu.Append(10, "load 4D file", "", wx.ITEM_NORMAL)
        wxglade_tmp_menu.Append(11, "load python data", "", wx.ITEM_NORMAL)
        self.frame_1_menubar.Append(wxglade_tmp_menu, "File")
        wxglade_tmp_menu = wx.Menu()
        wxglade_tmp_menu.Append(22, "Channel Properties", "", wx.ITEM_NORMAL)
        self.frame_1_menubar.Append(wxglade_tmp_menu, "Edit")
        wxglade_tmp_menu = wx.Menu()
        wxglade_tmp_menu.Append(30, "tft", "", wx.ITEM_NORMAL)
        self.frame_1_menubar.Append(wxglade_tmp_menu, "Execute")
        self.SetMenuBar(self.frame_1_menubar)
        # Menu Bar end
        self.frame_1_statusbar = self.CreateStatusBar(1, 0)

        # Tool Bar
        # Text-only tools; ids 1-4 are bound to go/redraw/plotloadeddata/tftframe below.
        self.frame_1_toolbar = wx.ToolBar(
            self, -1, style=wx.TB_HORIZONTAL | wx.TB_DOCKABLE | wx.TB_TEXT | wx.TB_NOICONS
        )
        self.SetToolBar(self.frame_1_toolbar)
        self.frame_1_toolbar.AddLabelTool(3, "PlotData", wx.NullBitmap, wx.NullBitmap, wx.ITEM_NORMAL, "", "")
        self.frame_1_toolbar.AddLabelTool(1, "RePlot", wx.NullBitmap, wx.NullBitmap, wx.ITEM_NORMAL, "", "")
        self.frame_1_toolbar.AddLabelTool(2, "ReDraw", wx.NullBitmap, wx.NullBitmap, wx.ITEM_NORMAL, "", "")
        self.frame_1_toolbar.AddSeparator()
        self.frame_1_toolbar.AddLabelTool(4, "TFT", wx.NullBitmap, wx.NullBitmap, wx.ITEM_CHECK, "", "")
        # Tool Bar end
        # Scrolled panel that hosts the Matplotlib canvas.
        self.viewer = wx.ScrolledWindow(self, -1, style=wx.TAB_TRAVERSAL)
        # Time-window length selector; settimewin strips the "sec" suffix.
        # NOTE(review): wx.CB_DROPDOWN appears twice in the style (harmless).
        self.combo_box_1 = wx.ComboBox(
            self,
            -1,
            choices=[".1sec", ".5sec", "1sec", "5sec", "10sec", "20sec", "30sec", "60sec"],
            style=wx.CB_DROPDOWN | wx.CB_DROPDOWN | wx.CB_READONLY,
        )
        # Pan buttons: left/right/up/down (images loaded from the working directory).
        self.bitmap_button_2 = wx.BitmapButton(self, -1, wx.Bitmap("left.png", wx.BITMAP_TYPE_ANY))
        self.bitmap_button_1 = wx.BitmapButton(self, -1, wx.Bitmap("right.png", wx.BITMAP_TYPE_ANY))
        self.bitmap_button_3 = wx.BitmapButton(self, -1, wx.Bitmap("up.png", wx.BITMAP_TYPE_ANY))
        self.bitmap_button_4 = wx.BitmapButton(self, -1, wx.Bitmap("down.png", wx.BITMAP_TYPE_ANY))
        # slider_3: time position; slider_1: number of channels; slider_2: amplitude.
        self.slider_3 = wx.Slider(
            self, -1, 0, 0, 1, style=wx.SL_HORIZONTAL | wx.SL_AUTOTICKS | wx.SL_LABELS | wx.SL_SELRANGE
        )
        self.slider_1 = wx.Slider(self, -1, 1, 0, 10, style=wx.SL_VERTICAL | wx.SL_AUTOTICKS | wx.SL_LABELS)
        self.button_1 = wx.ToggleButton(self, -1, "AutoScale")
        self.slider_2 = wx.Slider(self, -1, 1, 1, 100, style=wx.SL_VERTICAL | wx.SL_AUTOTICKS | wx.SL_LABELS)

        self.__set_properties()
        self.__do_layout()

        self.Bind(wx.EVT_MENU, self.loaddatagui, id=10)
        self.Bind(wx.EVT_MENU, self.loadpythondata, id=11)
        self.Bind(wx.EVT_MENU, self.chproperties, id=22)
        self.Bind(wx.EVT_MENU, self.tftwindow, id=30)
        self.Bind(wx.EVT_TOOL, self.plotloadeddata, id=3)
        self.Bind(wx.EVT_TOOL, self.go, id=1)
        self.Bind(wx.EVT_TOOL, self.redraw, id=2)
        self.Bind(wx.EVT_TOOL, self.tftframe, id=4)
        self.Bind(wx.EVT_COMBOBOX, self.settimewin, self.combo_box_1)
        self.Bind(wx.EVT_BUTTON, self.panleft, self.bitmap_button_2)
        self.Bind(wx.EVT_BUTTON, self.panright, self.bitmap_button_1)
        self.Bind(wx.EVT_BUTTON, self.panup, self.bitmap_button_3)
        self.Bind(wx.EVT_BUTTON, self.pandown, self.bitmap_button_4)
        self.Bind(wx.EVT_COMMAND_SCROLL_THUMBTRACK, self.settimepnt, self.slider_3)
        self.Bind(wx.EVT_COMMAND_SCROLL_PAGEDOWN, self.settimepnt, self.slider_3)
        self.Bind(wx.EVT_COMMAND_SCROLL_PAGEUP, self.settimepnt, self.slider_3)
        self.Bind(wx.EVT_COMMAND_SCROLL_LINEDOWN, self.settimepnt, self.slider_3)
        self.Bind(wx.EVT_COMMAND_SCROLL_LINEUP, self.settimepnt, self.slider_3)
        self.Bind(wx.EVT_COMMAND_SCROLL_THUMBRELEASE, self.settimepnt, self.slider_3)
        self.Bind(wx.EVT_COMMAND_SCROLL_PAGEDOWN, self.numchans, self.slider_1)
        self.Bind(wx.EVT_COMMAND_SCROLL_ENDSCROLL, self.numchans, self.slider_1)
        self.Bind(wx.EVT_COMMAND_SCROLL_PAGEUP, self.numchans, self.slider_1)
        self.Bind(wx.EVT_COMMAND_SCROLL_THUMBRELEASE, self.numchans, self.slider_1)
        self.Bind(wx.EVT_TOGGLEBUTTON, self.autoscalehandler, self.button_1)
        self.Bind(wx.EVT_COMMAND_SCROLL_THUMBTRACK, self.amplitudescale, self.slider_2)
        self.Bind(wx.EVT_COMMAND_SCROLL_PAGEDOWN, self.amplitudescale, self.slider_2)
        self.Bind(wx.EVT_COMMAND_SCROLL_PAGEUP, self.amplitudescale, self.slider_2)
        self.Bind(wx.EVT_COMMAND_SCROLL_LINEDOWN, self.amplitudescale, self.slider_2)
        self.Bind(wx.EVT_COMMAND_SCROLL_BOTTOM, self.amplitudescale, self.slider_2)
        self.Bind(wx.EVT_COMMAND_SCROLL_LINEUP, self.amplitudescale, self.slider_2)
        self.Bind(wx.EVT_COMMAND_SCROLL_THUMBRELEASE, self.amplitudescale, self.slider_2)
        # end wxGlade

        # Placeholder "event" passed to the handlers called directly below.
        event = 0
        self.dpi = 50

        # Axes rectangles as [left, bottom, width, height] in figure fractions
        # (the form add_axes takes). rect1 is the main trace axes, rect2 the
        # lower axes.
        # NOTE(review): rect1 spans 0.2-0.99 and rect2 spans 0.1-0.3
        # vertically, so they overlap; confirm that is intended.
        left, width = 0.05, 0.8
        rect1 = [left, 0.2, width, 0.79]
        rect2 = [left, 0.1, width, 0.2]

        self.fig = Figure(None, dpi=self.dpi)
        # self.fig = Figure((8.0, 20.0), dpi=self.dpi)
        self.canvas = FigCanvas(self.viewer, -1, self.fig)
        # redraw only resizes/refreshes the canvas, so it is safe to call
        # before the axes exist.
        self.redraw(event)
        # ~ self.axes = self.fig.add_subplot(111,axisbg='#ababab')
        # ~ self.axes2 = self.fig.add_subplot(111,axisbg='#ababab', sharex=self.axes)

        self.axes = self.fig.add_axes(rect1, axisbg="#ababab")
        #####self.axes.yaxis.set_visible(False)
        self.axes.set_frame_on(False)
        self.axes2 = self.fig.add_axes(rect2, axisbg="#ababab")  # , sharex=self.axes)

        self.axes2.axis("on")

        # ax2 = fig.add_axes(rect2, axisbg=axescolor, sharex=ax1)
        # self.axes.axis('off')

        event = 0
        # self.loaddata(event)
        # NOTE(review): the commented call below unpacks 4 values from
        # retrievepythondata, while loadpythondata unpacks 3; confirm which
        # matches the current helper.
        #        fn = '/home/danc/python/data/E0053/E0053_EEGSSPLR.pym'
        #        self.trials = 10
        #        #fn = '/home/danc/python/data/E0052/sim.pym'
        #        self.origdata, self.timeaxes, self.chlabels, self.srate = retrievepythondata(fn)
        # Fit the canvas to the frame, leaving 100 px for the controls.
        siz = self.GetSize()
        self.canvas.SetClientSize((siz[0] - 1, siz[1] - 100))
        # zoomin/pan are defined outside the visible part of this class.
        self.zoomin(event)
        self.pan(event)

        # ~
        # ~ from pylab import *
        # ~ try:
        # ~ import Image
        # ~ except ImportError, exc:
        # ~ raise SystemExit("PIL must be installed to run this example")
        # ~
        # lena = Image.open('AlbinoPython.jpg')
        # ~ dpi = rcParams['figure.dpi']
        # ~ figsize = lena.size[0]/dpi, lena.size[1]/dpi
        # ~
        # ~ figure(figsize=figsize)
        # ~ ax = axes([0,0,1,1], frameon=False)
        # ~ ax.set_axis_off()
        # self.axes.imshow(lena, origin='lower')

    # ~
    # show()

    def __set_properties(self):
        """
        Apply the initial widget properties (wxGlade-generated section).

        Sets the title, size, status text, toolbar spacing, colours and
        tooltips. The pan buttons and the channel/amplitude sliders start
        disabled; setupwidgets enables them once data has been loaded.
        """
        # begin wxGlade: MyFrame.__set_properties
        self.SetTitle("PyPlotter")
        self.SetSize((600, 600))
        self.frame_1_statusbar.SetStatusWidths([-1])
        # statusbar fields
        frame_1_statusbar_fields = ["PyPlotter"]
        for i in range(len(frame_1_statusbar_fields)):
            self.frame_1_statusbar.SetStatusText(frame_1_statusbar_fields[i], i)
        self.frame_1_toolbar.SetToolPacking(5)
        self.frame_1_toolbar.SetToolSeparation(5)
        # Realize() must follow the tool additions/spacing changes above.
        self.frame_1_toolbar.Realize()
        self.viewer.SetBackgroundColour(wx.Colour(143, 143, 188))
        self.viewer.SetScrollRate(10, 10)
        self.combo_box_1.SetMinSize((87, 27))
        self.combo_box_1.SetSelection(0)
        self.bitmap_button_2.SetMinSize((55, 55))
        self.bitmap_button_2.Enable(False)
        self.bitmap_button_1.SetMinSize((55, 55))
        self.bitmap_button_1.Enable(False)
        self.bitmap_button_3.SetMinSize((55, 55))
        self.bitmap_button_3.Enable(False)
        self.bitmap_button_4.SetMinSize((55, 55))
        self.bitmap_button_4.Enable(False)
        self.slider_3.SetBackgroundColour(wx.Colour(216, 216, 191))
        self.slider_3.SetToolTipString("Time Slider (sec)")
        self.slider_3.SetFocus()
        self.slider_1.SetToolTipString("Number of channels")
        self.slider_1.Enable(False)
        self.button_1.SetBackgroundColour(wx.Colour(143, 143, 188))
        # AutoScale starts pressed (auto scaling on).
        self.button_1.SetValue(1)
        self.slider_2.SetToolTipString("Amplitude")
        self.slider_2.Enable(False)
        # end wxGlade

    def __do_layout(self):
        """
        Lay out the frame (wxGlade-generated section).

        The vertical sizer_1 holds the ``viewer`` panel (proportion 3) above
        the horizontal sizer_2, which holds the time-window combo, the four
        pan buttons, the time slider, the channel slider, the AutoScale toggle
        and the amplitude slider.
        """
        # begin wxGlade: MyFrame.__do_layout
        sizer_1 = wx.BoxSizer(wx.VERTICAL)
        sizer_2 = wx.BoxSizer(wx.HORIZONTAL)
        sizer_1.Add(self.viewer, 3, wx.ALL | wx.EXPAND | wx.ALIGN_BOTTOM, 0)
        sizer_2.Add(self.combo_box_1, 0, wx.ALIGN_CENTER_HORIZONTAL | wx.ALIGN_CENTER_VERTICAL, 0)
        sizer_2.Add(self.bitmap_button_2, 0, wx.ALIGN_CENTER_VERTICAL, 0)
        sizer_2.Add(self.bitmap_button_1, 0, wx.ALIGN_CENTER_VERTICAL, 0)
        sizer_2.Add(self.bitmap_button_3, 0, wx.ALIGN_CENTER_VERTICAL, 0)
        sizer_2.Add(self.bitmap_button_4, 0, wx.ALIGN_CENTER_VERTICAL, 0)
        sizer_2.Add(self.slider_3, 6, wx.ALL | wx.EXPAND | wx.ALIGN_CENTER_HORIZONTAL | wx.ALIGN_CENTER_VERTICAL, 0)
        sizer_2.Add(self.slider_1, 1, wx.ALL | wx.EXPAND, 0)
        sizer_2.Add(self.button_1, 0, 0, 0)
        sizer_2.Add(self.slider_2, 0, wx.ALL | wx.EXPAND | wx.ALIGN_CENTER_VERTICAL, 0)
        sizer_1.Add(sizer_2, 0, wx.EXPAND | wx.ALIGN_BOTTOM, 0)
        self.SetSizer(sizer_1)
        self.Layout()
        self.Centre()
        # end wxGlade
        # ~ self.dpi = 100
        # ~ self.fig = Figure((8.0, 20.0), dpi=self.dpi)
        # ~ self.canvas = FigCanvas(self.viewer, -1, self.fig)
        # ~ sizer_1.Add(self.frame_1_toolbar, 5, wx.LEFT | wx.EXPAND)
        # ~ self.redraw(event)
        # ~ self.axes = self.fig.add_subplot(111,axisbg='#ababab')

    def passdata(data):  # , samplerate, channellabels):
        import sys

        print sys.argv
        import getopt

        opts, extraparams = getopt.getopt(sys.argv[1:])
        # ~ data = np.random.randn(10000,10)
        # ~ self.origdata = data
        # ~ self.sr = np.float32(100.0)
        # ~ sp = 1/frame.sr
        # ~ self.timeaxes = np.arange(0,sp*1000, sp)
        # ~ self.chlabels = np.arange(0,244)

    def go(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `go' "
        self.draw_figure(event)
        # self.redraw(event)

    def clear(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `clear' not implemented"
        print self.GetSize()
        print self.fig.get_size_inches()
        x = self.canvas.GetClientSize()
        print x
        # self.canvas.SetClientSize((x[0]/2,x[1]/2))
        x = self.GetSize()
        self.canvas.SetClientSize((x[0] / 0.8, x[1] / 0.9))
        # self.canvas.Destroy()
        # self.fig.clear()

    def redraw(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `redraw' not implemented"
        siz = self.GetSize()
        self.canvas.SetClientSize((siz[0] - 1, siz[1] - 100))
        self.canvas.Refresh(eraseBackground=True)

    def amplitudescale(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `amwxplitudescale' not implemented"
        print self.slider_2.GetValue()
        dmin = self.origdata.min()
        dmax = self.origdata.max()
        print dmin, dmax
        self.step = abs(dmin + dmax / 10) / self.slider_2.GetValue()
        print "step2", self.step
        # self.go(event)
        # ~ self.numchans(event)
        # return step/self.slider_2.GetValue()
        # self.settimewin(event)

    def numchans(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `numchans' not implemented"
        self.slider_1.SetRange(0, np.size(self.origdata, 1))
        self.data = self.origdata[:, 0 : self.slider_1.GetValue()]
        self.label2plot = self.chlabels[0 : self.slider_1.GetValue()]
        # self.settimewin(event)
        try:
            self.go(event)
        except AttributeError:
            print "error"

    def settimewin(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `settimewin' not implemented"
        # ~ self.startsec = self.slider_3.GetValue() * int(self.combo_box_1.GetValue().strip('sec'))
        # ~ self.endsec = self.startsec + int(self.combo_box_1.GetValue().strip('sec'))
        self.startsec = self.slider_3.GetValue()
        self.endsec = self.startsec + np.float32(self.combo_box_1.GetValue().strip("sec"))
        print "pp", self.startsec, self.endsec
        self.indstart = np.argmin(abs(self.startsec - self.timeaxes))
        self.indend = np.argmin(abs(self.endsec + -self.timeaxes))
        print "ind", self.indstart, self.indend
        self.canvas.Update()
        self.go(event)
        # self.go(event)

    def settimepnt(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `settimepnt' not implemented"
        self.settimewin(event)
        # ~ self.timepnt = self.slider_3.GetValue()
        # ~ self.indstart = argmin(abs(self.timepnt - self.timeaxes))
        # ~ self.indend = argmin(abs(self.timepnt+ - self.timeaxes))
        # ~
        # ~ print self.indstart,

    def setupwidgets(self):
        self.slider_1.Enable(enable=True)
        self.slider_2.Enable(enable=True)
        self.bitmap_button_1.Enable(True)
        self.bitmap_button_2.Enable(True)
        self.bitmap_button_3.Enable(True)
        self.bitmap_button_4.Enable(True)
        self.combo_box_1.Enable(True)
        x = []
        for n in range(0, np.size(self.chlabels)):
            x.append(str(n))
        # self.combo_box_1.SetItems(unicode(self.chlabels.tolist()))
        # self.combo_box_1.SetItems(x)
        print self.timeaxes[0], self.timeaxes[-1]
        self.slider_3.SetRange(self.timeaxes[0], self.timeaxes[-1])

    def panleft(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `panleft' not implemented"
        event.Skip()

    def panright(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `panright' not implemented"
        event.Skip()

    def panup(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `panup' not implemented"

        try:
            self.e = self.s
            self.s = self.e - self.slider_1.GetValue()
        except AttributeError:
            print "setting"
            self.e = self.slider_1.GetValue()
            self.s = self.e + self.e
        if self.s < 0:
            print "at start of channel set"
            self.s = 0
            self.e = self.s + self.slider_1.GetValue()

        self.data = self.origdata[:, self.s : self.e]
        print "du", self.s, self.e
        self.label2plot = self.chlabels[self.s : self.e]
        self.go(event)

    def pandown(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `pandown' not implemented"
        try:
            self.s = self.e
            self.e = self.s + self.slider_1.GetValue()
        except AttributeError:
            print "setting"
            self.s = self.slider_1.GetValue()
            self.e = self.s + self.s
        if self.e > np.size(self.origdata, 1):
            print "at end of channel set"
            self.e = np.size(self.origdata, 1)
            self.s = self.e - self.slider_1.GetValue()

        self.data = self.origdata[:, self.s : self.e]
        print "dd", self.s, self.e
        self.label2plot = self.chlabels[self.s : self.e]
        self.go(event)

    def loaddatagui(self, event):  # wxGlade: MyFrame.<event_handler>

        self.datapath = "/home/danc/python/data/0611/0611piez/e,rfhp1.0Hz,COH"
        import os

        print "Event handler `loaddatagui' not implemented"
        dlg = wx.FileDialog(self, "Select an Data file", os.getcwd(), "", "*", wx.OPEN)
        if dlg.ShowModal() == wx.ID_OK:
            self.datapath = dlg.GetPath()
            dlg.Destroy()

        # ~ dlg = wx.MessageDialog(self, 'First you need to load MRI data file', 'MRI file error', wx.OK|wx.ICON_INFORMATION)
        # ~ dlg.ShowModal()
        # ~ dlg.Destroy()

        # from gui import pysel
        # pysel.start()
        # p#rint 'done', pysel.fnlist
        self.origdata, self.timeaxes, self.chlabels, self.srate = retrievepdf(self.datapath)  # self.datapath)
        self.data = self.origdata
        self.chproperties(event)
        self.plotloadeddata(event)
        # chpropertywin.loadchannels(event)

    def plotloadeddata(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `plotdata' "
        # print pysel.fnlist
        # self.origdata, self.timeaxes, self.chlabels = getdata()
        self.data = self.origdata
        self.numchans(event)
        self.amplitudescale(event)
        self.settimewin(event)
        # print 'labels', label2plot
        self.setupwidgets()
        self.go(event)
        # self.numchans(event)
        print "step", self.step

    def loadpythondata(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `loadpythondata' not implemented"
        import os

        print "Event handler `loaddatagui' not implemented"
        dlg = wx.FileDialog(
            self,
            "Select a python file(s)",
            os.getcwd(),
            "",
            wildcard="Data File (*.pym)|*.pym|Dipole Report(*.drf)|*.drf",
        )
        if dlg.ShowModal() == wx.ID_OK:
            self.datapath = dlg.GetPath()
            dlg.Destroy()

        self.origdata, self.timeaxes, self.chlabels = retrievepythondata(self.datapath)  # self.datapath)

    def autoscalehandler(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `autoscalehandler' not implemented"
        if self.button_1.GetValue() == 0:
            print "manual scaling"
            scalewin.Show()

        else:
            print "auto scaling"
            scalewin.Hide()

    def tftinit(self, event):  # wxGlade: MyFrame.<event_handler>
        # if self.frame_1_toolbar.GetToolState(4) == True: #depressed
        from meg import timef

        t = timef.initialize()
        self.redraw(event)
        print self.tftch, self.trials, self.srate
        # print 's', np.shape(self.origdata[self.indstart:self.indend,19]),np.shape(self.data[self.indstart:self.indend,19])
        print "inds", self.indstart, "inde", self.indend
        # dif = np.size(self.data[self.indstart:self.indend:19],0); print 'dif', dif, self.indend, self.indstart
        dif = self.indend - self.indstart
        print "dif", dif, self.indend, self.indstart
        print "shape of data", np.shape(self.origdata[self.indstart : self.indend, self.tftch])
        t.calc(
            data=self.origdata[self.indstart : self.indend, self.tftch],
            trials=self.trials,
            srate=self.srate,
            frames=dif / self.trials,
            freqrange=[3.0, 100],
            cycles=[2, 0.5],
        )
        # self.axes.plot(self.data[self.indstart:self.indend:,i]+inc, color=[0,0,0])
        # self.axes2.imshow(abs(t.tmpallallepochs))#, aspect=6,extent=(int(t.timevals[0]), int(t.timevals[-1]), int(t.freqrange[1]), int(t.freqrange[0])));colorbar();show()
        # self.axes2.imshow(abs(t.tmpallallepochs))#,aspect = 1, extent=(int(t.timevals[0]), int(t.timevals[-1]), int(t.freqrange[1]), int(t.freqrange[0])))
        self.axes2.imshow(
            abs(t.tmpallallepochs),
            extent=(int(t.timevals[0]), int(t.timevals[-1]), int(t.freqrange[1]), int(t.freqrange[0])),
        )
        print "tftshape", np.shape(t.tmpallallepochs)
        # self.redraw(event)
        # self.axes2.update()
        self.canvas.draw()

        print "Event handler `tftinit' not implemented"
        event.Skip()

    def tftframe(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `tftframe' not implemented"
        tftframe.Show()

    def tftwindow(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `tftwindow' not implemented"
        tftframe.Show()

    def chproperties(self, event):  # wxGlade: MyFrame.<event_handler>
        print "Event handler `chproperties' not implemented"
        chpropertywin.Show()
        chpropertywin.list_box_channels.SetItems(frame.chlabels)
        for i in frame.chlabels:
            index = chpropertywin.list_ctrl_1.InsertStringItem(sys.maxint, str(i))
            chpropertywin.list_ctrl_1.SetStringItem(index, 1, i)
            # self.list_ctrl_1.SetItem(i)
            # self.list_ctrl_1.SetItem(frame.chlabels)
            chpropertywin.chcolorlist = np.tile((0, 0, 0), (len(frame.chlabels), 1))  # set all channels to black

    # end of wxGlade event handlers for MyFrame (the class continues below)
    # -----------------------------------------------------
    def pan(self, event):
        import time

        def on_press(event):
            self.clicktime = time.time()
            #            self.releasetime = False
            self.holdcheck = "yes"
            # on_move(event)

        #            try:
        #                if self.clicktime
        def on_release(event):
            self.releasetime = time.time()
            self.holdcheck = "no"
            print "you released", event.button, event.xdata, event.ydata
            self.canvas.draw()

        def on_move(event):
            try:
                if self.holdcheck == "yes":
                    print "your pointer is at", event.xdata, event.ydata
                    x1, x2 = self.axes.get_xlim()
                    y1, y2 = self.axes.get_ylim()
                    print x1, x2, y1, y2, "::::", x2 - x1, y2 - y1
                    self.axes.set_xlim(event.xdata + (x2 - x1) / 2, event.xdata - (x2 - x1) / 2)
                    self.axes.set_ylim(event.ydata + (y2 - y1) / 2, event.ydata - (y2 - y1) / 2)

            except AttributeError:
                pass
                # print 'no click yet'

        cid = self.fig.canvas.mpl_connect("button_press_event", on_press)
        rid = self.fig.canvas.mpl_connect("button_release_event", on_release)
        mousemove = self.fig.canvas.mpl_connect("motion_notify_event", on_move)

    def zoomin(self, event):
        import time

        def on_press(event):

            try:
                self.lasttime = self.curtime
            except AttributeError:
                self.lasttime = 0

            self.curtime = time.time()
            if np.round(self.curtime, 2) != np.round(self.lasttime, 2):
                tdiff = self.curtime - self.lasttime
                print "time diff", tdiff
                print "you pressed", event.button, event.xdata, event.ydata
                if tdiff < 0.25:
                    tdiff = 0
                    x1, x2 = self.axes.get_xlim()
                    y1, y2 = self.axes.get_ylim()
                    print x1, x2, y1, y2
                    # ~ self.axes.set_xlim(x1/2,x2/2)
                    # ~ self.axes.set_ylim(y1/2,y2/2)
                    # self.axes.set_xlim(event.xdata-(x2-x1)/2, event.xdata+(x2-x1)/2)
                    if event.button == 1:  # zoom in
                        self.axes.set_xlim(event.xdata - (x2 - x1) / 4, event.xdata + (x2 - x1) / 4)
                        self.axes.set_ylim(event.ydata - (y2 - y1) / 4, event.ydata + (y2 - y1) / 4)
                    if event.button == 3:  # zoom out
                        self.axes.set_xlim(event.xdata - (x2 - x1) * 1, event.xdata + (x2 - x1) * 1)
                        self.axes.set_ylim(event.ydata - (y2 - y1) * 1, event.ydata + (y2 - y1) * 1)

                    # self.axes.set_xlim(np.mean(event.xdata,x1), np.mean(event.xdata,x2))
                    # self.axes.set_ylim(np.mean(event.ydata,y1), np.mean(event.ydata,y2))

                    self.canvas.draw()

        cid = self.fig.canvas.mpl_connect("button_press_event", on_press)

    def axeslabels(self, data):  # , chlabels, timelabels):
        """Return placeholder ``(channel_labels, time_labels)`` for ``data``.

        Channel labels are the column indices ``0 .. ncols - 1`` of ``data``.
        The time labels are a fixed placeholder list of strings.
        """
        ncols = np.size(data, 1)
        return np.arange(ncols), ["100", "200", "300"]

    def draw_figure(self, event):

        # ~ try:
        # ~ count = count + 1
        # ~ except UnboundLocalError:
        # ~ count = 0
        # ~ if not hasattr(self, 'subplot'):
        # ~ self.axes = self.fig.add_subplot(111,axisbg='#ababab')
        # ~ self._resizeflag = False
        # ~ #self.axes.bar(left=10,height=100,width=100,align='center',alpha=0.44,picker=5)
        # ~ self.dpi = 50
        # ~ self.fig = Figure( None, dpi=self.dpi )
        # ~ self.canvas = FigCanvas(self.viewer, -1, self.fig)
        self.redraw(event)
        # ~ self.axes = self.fig.add_subplot(111,axisbg='#ababab')
        # self.canvas.Update()
        # self.canvas.UpdateRegion()
        # self.canvas.Refresh()
        self.axes.clear()
        # ~ self.canvas.Close()
        # self.canvas.gui_repaint()
        self.canvas.Show()
        # ~ print 'count',count

        # self.axes.axis('off')
        self.axes.grid("on")
        inc = 0
        print "drawshapestate", np.shape(self.origdata), np.shape(self.data), np.shape(
            self.timeaxes
        ), self.indstart, self.indend

        for i in range(0, np.size(self.data, 1))[::-1]:
            # self.axes.plot(self.origdata[self.indstart:self.indend:,i]+inc)
            colur = chpropertywin.chcolorlist[i] / 256.0

            # ~ if i < 2:
            # ~ colur = (0,0,1)
            # ~ else:
            # ~ colur = (0,0,0)
            self.axes.plot(self.data[self.indstart : self.indend :, i] + inc, color=colur)
            if self.frame_1_toolbar.GetToolState(4) == True:  # depressed
                # print 's', np.shape(self.origdata[self.indstart:self.indend,19]),np.shape(self.data[self.indstart:self.indend,19])
                # t.calc(data=self.origdata[self.indstart:self.indend,19], trials=10, srate=290.64,frames=dif/10, freqrange=[5.0,70], cycles=[2, .5])
                pass
                # self.tftinit(event)
                # self.axes2.imshow(abs(t.tmpallallepochs),aspect = 7, extent=(int(t.timevals[0]), int(t.timevals[-1]), int(t.freqrange[1]), int(t.freqrange[0])))
            # self.axes.plot(self.data[:,i]+inc)
            self.axes.text(0, inc, self.label2plot[i], color=[1, 0, 0])
            inc = self.step / 2 + inc
        # self.axes.update()
        # self.axes.set_xlim((self.indstart,self.indend))
        # self.canvas.Update()

    def _onIdle(self, evt):
        if self._resizeflag:
            self._resizeflag = False
            self._SetSize()
            print "idle"

    def _onSize(self, event):
        self._resizeflag = True

    def _SetSize(self):
        pixels = tuple(self.GetClientSize())
        self.SetSize(pixels)
        self.canvas.SetSize(pixels)
        self.fig.set_size_inches(float(pixels[0]) / self.fig.get_dpi(), float(pixels[1]) / self.fig.get_dpi())
Ejemplo n.º 41
0
class TreeFigure:
    """Non-interactive matplotlib figure of a tree, for saving and printing.

    Builds a ``tree.TreePlot`` for *root* on a new ``Figure`` and draws it
    immediately. Unless *height* is given, the figure height comes from the
    number of leaves.

    Args:
        root: tree root node; must provide ``leaves()``.
        relwidth (float): figure width as a fraction of its height, used when
          *width* is not given.
        leafpad (float): vertical room per leaf, in multiples of
          *leaf_fontsize*; used when *height* is not given.
        decorators (list): passed through to ``tree.TreePlot``. Defaults to a
          new empty list per instance.
        xlim, ylim: accepted for API compatibility but currently ignored.
        height, width (float): explicit figure size in inches.
        The remaining keyword arguments are forwarded to ``tree.TreePlot``.
    """

    def __init__(self, root, relwidth=0.5, leafpad=1.5, name=None,
                 support=70.0, scaled=True, mark_named=True,
                 leaf_fontsize=10, branch_fontsize=10,
                 branch_width=1, branch_color="black",
                 highlight_support=True,
                 branchlabels=True, leaflabels=True, decorators=None,
                 xoff=0, yoff=0,
                 xlim=None, ylim=None,
                 height=None, width=None):
        if decorators is None:
            # Use a new list per instance. A shared ``[]`` default would leak
            # changes made through one figure (e.g. by TreePlot) into all others.
            decorators = []
        self.root = root
        self.relwidth = relwidth
        self.leafpad = leafpad
        self.name = name
        self.support = support
        self.scaled = scaled
        self.mark_named = mark_named
        self.leaf_fontsize = leaf_fontsize
        self.branch_fontsize = branch_fontsize
        self.branch_width = branch_width
        self.branch_color = branch_color
        self.highlight_support = highlight_support
        self.branchlabels = branchlabels
        self.leaflabels = leaflabels
        self.decorators = decorators
        self.xoff = xoff
        self.yoff = yoff

        nleaves = len(root.leaves())
        self.dpi = 72.0
        # Default height in inches: leafpad font-heights per leaf.
        self.height = height or (nleaves * self.leaf_fontsize * self.leafpad) / self.dpi
        self.width = width or self.height * self.relwidth
        self.figure = Figure(figsize=(self.width, self.height), dpi=self.dpi)
        self.canvas = FigureCanvas(self.figure)
        self.axes = self.figure.add_axes(
            tree.TreePlot(self.figure, 1, 1, 1,
                          support=self.support,
                          scaled=self.scaled,
                          mark_named=self.mark_named,
                          leaf_fontsize=self.leaf_fontsize,
                          branch_fontsize=self.branch_fontsize,
                          branch_width=self.branch_width,
                          branch_color=self.branch_color,
                          highlight_support=self.highlight_support,
                          branchlabels=self.branchlabels,
                          leaflabels=self.leaflabels,
                          interactive=False,
                          decorators=self.decorators,
                          xoff=self.xoff, yoff=self.yoff,
                          name=self.name).plot_tree(self.root)
            )
        for side in ("top", "left", "right"):
            self.axes.spines[side].set_visible(False)
        bottom = self.axes.spines["bottom"]
        # Spine.set_smart_bounds was removed in Matplotlib 3.4; skip it there.
        if hasattr(bottom, "set_smart_bounds"):
            bottom.set_smart_bounds(True)
        self.axes.xaxis.set_ticks_position("bottom")

        for label in self.axes.node2label.values():
            label.set_visible(True)

        self.canvas.draw()
        self.axes.set_position([0.05, 0.05, 0.95, 0.95])

    @property
    def detail(self):
        """The tree axes (the ``TreePlot`` added to the figure)."""
        return self.axes

    def savefig(self, fname, format="pdf"):
        """Save the figure to *fname* (a path or file-like object) in *format*."""
        self.figure.savefig(fname, format=format)

    def set_relative_width(self, relwidth):
        """Set the figure width to *relwidth* times its current height."""
        w, h = self.figure.get_size_inches()
        # Keep the stored size in sync; autoheight() resizes from self.width.
        self.relwidth = relwidth
        self.width = h * relwidth
        self.figure.set_figwidth(self.width)

    def autoheight(self):
        "adjust figure height to show all leaf labels"
        nleaves = len(self.root.leaves())
        h = (nleaves * self.leaf_fontsize * self.leafpad) / self.dpi
        self.height = h
        self.figure.set_size_inches(self.width, self.height)
        self.axes.set_ylim(-2, nleaves + 2)

    def home(self):
        """Reset the tree axes view (delegates to ``self.axes.home()``)."""
        self.axes.home()

    def render_multipage(self, outfile, pagesize=(8.5, 11.0),
                         dims=None, border=0.393701, landscape=False):
        """
        Create a multi-page PDF document where the figure is cut into
        multiple pages. Used for printing large figures.

        Args:
            outfile (string): The path to the output file.
            pagesize (sequence): Two floats. Page size of each individual page
              in inches. Defaults to 8.5 x 11.0.
            dims (sequence): Two floats. The dimensions of the final figure in
              inches. Defaults to the current size of the figure.
            border (float): The amount of overlap (in inches) between adjacent
              pages, to make taping them together easier. Defaults to
              0.393701 (1 cm).
            landscape (bool): Whether each page is in landscape orientation.
              Defaults to False.

        Returns:
            tuple: ``(number_of_pages, scale_factor)``. The figure itself
            (``self.width``, ``self.height`` and the figure size) is left
            scaled by ``scale_factor``.
        """
        pgwidth, pgheight = (pagesize[1], pagesize[0]) if landscape else pagesize
        if dims:
            self.width, self.height = dims[0], dims[1]
        else:
            self.width, self.height = self.figure.get_size_inches()
        if self.width > pgwidth - 2 * border:
            scalefact = min(
                (self.width - (self.width / pgwidth - 1) * border * 2) / self.width,
                (self.height - (self.height / pgheight - 1) * border * 2) / self.height)
        else:
            scalefact = 1.0

        self.width *= scalefact  # inches
        self.height *= scalefact
        self.figure.set_size_inches([self.width, self.height])

        # From here on, every length is in PDF points (72 per inch). Convert
        # *border* before it is used for page offsets. Otherwise the first column
        # would start at an inch value misread as points, unlike later columns.
        dwidth = self.width * 72.0
        dheight = self.height * 72.0
        pgwidth = pgwidth * 72
        pgheight = pgheight * 72
        border = border * 72

        buf = StringIO()
        self.savefig(buf)

        output = PdfFileWriter()
        # Each column starts with a page spanning [0, pgheight]. Later pages and
        # columns advance by one page minus *border*, so neighbours overlap by
        # *border*.
        upper = border
        right = pgwidth
        left = 0
        pgnum = 0

        while upper < dheight:
            # Re-read the rendered PDF so every output page has its own media box.
            buf.seek(0)
            page = PdfFileReader(buf).getPage(0)
            box = page.mediaBox
            upper += pgheight - border
            lower = upper - pgheight
            box.setUpperRight((right, upper))
            box.setUpperLeft((left, upper))
            box.setLowerRight((right, lower))
            box.setLowerLeft((left, lower))
            output.addPage(page)
            pgnum += 1
            if upper >= dheight and right < dwidth:
                # Column finished; start the next one to the right.
                upper = border
                right += pgwidth - border
                left = right - pgwidth

        # Open only once rendering succeeded, and always close the handle.
        with open(outfile, "wb") as fh:
            output.write(fh)
        return pgnum, scalefact