Ejemplo n.º 1
0
    update_cols_list()

    plot_coord_btn.configure(state="normal")
    plot_intensity_hist_btn.configure(state="normal")
    plot_intensity_scat_btn.configure(state="normal")

    plot_coord_btn.bind("<Button-1>", coord_plotter)
    plot_intensity_hist_btn.bind("<Button-1>", intensity_hist_plotter)
    plot_intensity_scat_btn.bind("<Button-1>", intensity_scatter_plotter)


load_thor_btn.bind("<Button-1>", open_thor)

# THE FIGURE AND THE CANVAS IN WHICH IT WILL BE DISPLAYED
fig = Figure(figsize=(8, 5), dpi=80)
ax = fig.subplots()

canvas_frame = Frame(root)
canvas = FigureCanvasTkAgg(figure=fig, master=canvas_frame)
# FigureCanvasTkAgg.show() was deprecated in matplotlib 2.2 and removed in
# 3.0; draw() renders the initial frame in every version that provides
# Figure.subplots() (used above).
canvas.draw()
canvas.get_tk_widget().pack(side=BOTTOM, fill=BOTH, expand=True)
# NOTE(review): NavigationToolbar2TkAgg is the pre-2.2 name of
# NavigationToolbar2Tk and no longer exists in matplotlib >= 3.0; switch this
# line and its import together when upgrading.
toolbar = NavigationToolbar2TkAgg(canvas, canvas_frame)
toolbar.update()
# Same widget as above (get_tk_widget() is the public accessor for the
# private _tkcanvas); re-packed with side=TOP after the toolbar is created.
canvas.get_tk_widget().pack(side=TOP, fill=BOTH, expand=True)
canvas_frame.grid(row=1, column=2)


# Plot coordinates from the THOR catalogue
def coord_plotter(event):
    """Open the THOR coordinate scatter window and block until it is closed.

    ``event`` is the Tk event passed by the ``<Button-1>`` binding; it is
    not used.
    """
    plot_window = CoordPlot(root)
    scatter_win = plot_window.coord_scatter_win
    # wait_window() processes events locally until scatter_win is destroyed.
    scatter_win.wait_window()
Ejemplo n.º 2
0
def quickLook_function(station, year, doy, snr_type, f, e1, e2, minH, maxH,
                       reqAmp, pele, satsel, PkNoise, fortran):
    """
    Quick-look reflector-height analysis for one station and day.

    Finds the SNR file (or tries to create it from RINEX via rinex.conv2snr),
    then, for each of the four geographic quadrants and every satellite arc,
    computes an LSP with g.strip_compute. Arcs passing the QC checks are
    printed as SUCCESS and plotted in colour; rejected arcs are plotted as
    thin gray lines. All four quadrants are shown in one 2x2 figure.

    inputs:
    station name (4 char), year, day of year
    snr_type is the file extension (i.e. 99, 66 etc)
    f is frequency (1, 2, 5), etc
    e1 and e2 are the elevation angle limits in degrees for the LSP
    minH and maxH are the allowed LSP limits in meters
    reqAmp is LSP amplitude significance criterion (only reqAmp[0] is used)
    pele is the elevation angle limits for the polynomial removal.  units: degrees
    satsel is a single satellite number, or None to use every satellite
        returned by g.find_satlist
    PkNoise is the minimum peak-to-noise ratio for an arc to pass QC
    fortran is passed through to rinex.conv2snr
    KL 20may10 pk2noise value is now sent from main function, which can be set online
    KL 20aug07 added fortran boolean

    Returns None. Returns early (after printing a message) when no usable
    SNR file can be found or read.
    """
    # make sure environment variables exist
    g.check_environ_variables()

    webapp = False
    # orbit directories
    g.make_nav_dirs(year)
    # titles in 4 quadrants - for webApp
    titles = ['Northwest', 'Southwest', 'Northeast', 'Southeast']
    # define where the axes are located: (bx, by) index the 2x2 axes array
    # in webapp mode; bz is the plt.subplot(2, 2, n) index otherwise
    bx = [0, 1, 0, 1]
    by = [0, 0, 1, 1]
    bz = [1, 3, 2, 4]

    # various defaults - ones the user doesn't change in this quick Look code
    delTmax = 70
    polyV = 4  # polynomial order for the direct signal
    desiredP = 0.01  # 1 cm precision
    ediff = 2  # this is a QC value, eliminates small arcs
    minNumPts = 20
    # noise region for LSP QC. these are meters
    NReg = [minH, maxH]
    print('noise region', NReg)
    # for quickLook, we use the four geographic quadrants - these are azimuth angles in degrees
    azval = [270, 360, 180, 270, 0, 90, 90, 180]
    naz = int(len(azval) / 2)  # number of azimuth pairs
    pltname = 'temp.png'  # default plot
    requireAmp = reqAmp[0]
    screenstats = True

    # this allows snr file to live in main directory
    # not sure that that is all that useful as I never let that happen
    obsfile = g.define_quick_filename(station, year, doy, snr_type)
    if os.path.isfile(obsfile):
        print('>>>> The snr file exists ', obsfile)
    else:
        print('look for the SNR file elsewhere')
        obsfile, obsfileCmp, snre = g.define_and_xz_snr(
            station, year, doy, snr_type)
        if snre:
            print('file exists on disk')
        else:
            print('>>>> The SNR the file does not exist ', obsfile)
            print('I will try to pick up a RINEX file ')
            print('and translate it for you. This will be GPS only.')
            print('For now I will check all the official archives for you.')
            rate = 'low'
            dec_rate = 0
            archive = 'all'
            rinex.conv2snr(year, doy, station, int(snr_type), 'nav', rate,
                           dec_rate, archive, fortran)
            if os.path.isfile(obsfile):
                print('the SNR file now exists')
            else:
                print(
                    'the RINEX file did not exist, had no SNR data, or failed to convert, so exiting.'
                )
                # Actually exit as announced; previously execution fell
                # through and tried to read the missing file anyway.
                return

    allGood, sat, ele, azi, t, edot, s1, s2, s5, s6, s7, s8, snrE = read_snr_simple(
        obsfile)
    if allGood != 1:
        print(
            'some kind of problem with SNR file, so I am exiting the code politely.'
        )
        return

    minEdataset = np.min(ele)
    print('min elevation angle for this dataset ', minEdataset)
    if minEdataset > (e1 + 0.5):
        print('It looks like the receiver had an elevation mask')
        e1 = minEdataset

    # The satellite list does not depend on the quadrant, so build it once.
    if satsel is None:
        # no satellite was given, so get them all
        satlist = g.find_satlist(f, snrE)
    else:
        satlist = [satsel]

    tt = 'GNSS-IR results: ' + station.upper() + ' Freq:' + str(
        f) + ' ' + str(year) + '/' + str(doy)

    if webapp:
        fig = Figure(figsize=(10, 6), dpi=120)
        axes = fig.subplots(2, 2)
    else:
        plt.figure()
    for a in range(naz):
        if not webapp:
            plt.subplot(2, 2, bz[a])
            plt.title(titles[a])
        az1 = azval[(a * 2)]
        az2 = azval[(a * 2 + 1)]

        for satNu in satlist:
            x, y, Nv, cf, UTCtime, avgAzim, avgEdot, Edot2, delT = g.window_data(
                s1, s2, s5, s6, s7, s8, sat, ele, azi, t, edot, f, az1,
                az2, e1, e2, satNu, polyV, pele, screenstats)
            if Nv > minNumPts:
                maxF, maxAmp, eminObs, emaxObs, riseSet, px, pz = g.strip_compute(
                    x, y, cf, maxH, desiredP, polyV, minH)
                # periodogram values inside the noise region
                nij = pz[(px > NReg[0]) & (px < NReg[1])]
                if len(nij) > 0:
                    Noise = np.mean(nij)
                else:
                    Noise = 1  # made up number
                passed_qc = ((delT < delTmax) & (eminObs < (e1 + ediff))
                             & (emaxObs > (e2 - ediff))
                             & (maxAmp > requireAmp)
                             & (maxAmp / Noise > PkNoise))
                if passed_qc:
                    T = g.nicerTime(UTCtime)
                    print(
                        'SUCCESS Azimuth {0:3.0f} RH {1:6.3f} m, Sat {2:3.0f} Freq {3:3.0f} Amp {4:4.1f} PkNoise {5:3.1f} UTC {6:5s} '
                        .format(avgAzim, maxF, satNu, f, maxAmp,
                                maxAmp / Noise, T))
                    if not webapp:
                        plt.plot(px, pz, linewidth=1.5)
                    else:
                        axes[bx[a], by[a]].plot(px, pz, linewidth=2)
                        axes[bx[a], by[a]].set_title(titles[a])
                elif not webapp:
                    plt.plot(px, pz, 'gray', linewidth=0.5)

        # x-axis label only on the bottom row of the 2x2 grid
        if (a == 3) or (a == 1):
            plt.xlabel('reflector height (m)')
    plt.suptitle(tt, fontsize=12)
    if webapp:
        fig.savefig(pltname, format="png")
    else:
        plt.show()
Ejemplo n.º 3
0
def plot():
    """Show plot of symbol price over time.

    GET renders the symbol form (``plot.html``). POST looks up the submitted
    symbol, plots its recorded transaction prices followed by the current
    quote, and renders ``plotted.html`` with the chart embedded as a base64
    PNG data URI.
    """

    # User reached route via GET (as by clicking a link or via redirect)
    if request.method != "POST":
        return render_template("/plot.html")

    # Ensure symbol was submitted
    requested = request.form.get("symbol")
    if not requested:
        return apology("must provide symbol", 403)

    # Ensure symbol exists. This must happen before the result is indexed;
    # presumably lookup() returns None for an unknown symbol — confirm.
    query = lookup(requested)
    if not query:
        return apology("invalid symbol")
    symbol = query["symbol"]
    price = query["price"]

    # Select symbol_id; a symbol with no row here has never been traded.
    rows = db.execute("SELECT id FROM symbols WHERE symbol = :symbol",
                      symbol=symbol)
    if not rows:
        return apology("no transactions for symbol")
    symbol_id = rows[0]["id"]

    # Fetch price and timestamp in ONE ordered query so the two series stay
    # aligned (two separate unordered SELECTs give no such guarantee).
    data = db.execute(
        "SELECT price, transacted FROM transactions "
        "WHERE symbol_id = :symbol_id ORDER BY transacted",
        symbol_id=symbol_id)

    # Create price and transacted lists, ending with the current quote
    prices = [transaction.get("price") for transaction in data]
    transacted = [transaction.get("transacted") for transaction in data]
    prices.append(price)
    transacted.append(datetime.now().isoformat())

    # Generate figure
    fig = Figure()
    ax = fig.subplots()
    ax.plot(transacted, prices)

    # Format y-axis tick labels as dollar amounts, coloured green
    formatter = ticker.FormatStrFormatter('$%1.2f')
    ax.yaxis.set_major_formatter(formatter)
    for tick in ax.yaxis.get_major_ticks():
        tick.label1.set_color('green')

    # rotates and right aligns the x labels, and moves the bottom of the
    # axes up to make room for them
    fig.autofmt_xdate()

    # create more room for x-axis
    fig.tight_layout()

    # Save figure as temporary buffer
    buf = BytesIO()
    fig.savefig(buf, format="png")

    # Embed results in html output
    encoded = base64.b64encode(buf.getbuffer()).decode("ascii")
    return render_template("plotted.html",
                           title=symbol,
                           plot=f"data:image/png;base64,{encoded}")
Ejemplo n.º 4
0
from tkinter import Tk
from tkinter.constants import *

import numpy
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

from matplotlib.figure import Figure
from matplotlib import pyplot as plt, animation

# Tk root window that hosts the embedded matplotlib canvas.
root = Tk()
# All-zero 24x32 array shown as the initial image.
X = numpy.zeros(shape=(24, 32))
# Enable pyplot interactive mode (the Figure below is built directly, not
# through pyplot).
plt.ion()

fig = Figure(dpi=100, figsize=(4, 4))
ax2 = fig.subplots()
im = ax2.imshow(
    X,
    cmap="nipy_spectral",
    clim=(20, 35),
    interpolation=None,
    animated=True,
)

canvas = FigureCanvasTkAgg(fig, master=root)
canvas.get_tk_widget().grid(column=1, row=0, sticky=NSEW)

# Run the Tk event loop until the window is closed.
root.mainloop()
Ejemplo n.º 5
0
class WindowClient(Frame):
    """Tkinter client window for the artist query server.

    Shows query inputs and buttons, a text box with the latest list reply, a
    listbox of moderator messages and an embedded matplotlib bar chart.
    Incoming messages are taken from ``self.queue`` (handed to a
    ``MessageHandler`` thread at start-up) and processed by ``monitor`` every
    100 ms.
    """

    def __init__(self,
                 master=None,
                 user="",
                 socket_to_server=None,
                 write_obj=None):
        Frame.__init__(self, master)
        self.master = master
        self.username = user
        self.password = None
        self.socket_to_server = socket_to_server
        self.my_writer_obj = write_obj
        # Set once close_connection has run, so CLOSE is sent only once even
        # though both the close button and __del__ call it.
        self._closed = False
        logging.debug("socket to server: %s", self.socket_to_server)

        self.init_window()

        # Queue to store messages from server
        self.queue = Queue()

        # Thread to read messages from server
        receiver_thread = MessageHandler("Message-handler", self.my_writer_obj,
                                         self.queue)
        receiver_thread.start()

        self.monitor()

    # Creation of init_window
    def init_window(self):
        """Create and lay out all widgets and the embedded figure."""

        self.master.title(f"client: {self.username}")

        Label(self, text="Berichten van moderator:",
              fg='black').grid(row=6, column=1, padx=(5, 5))
        self.mod_messages = Listbox(self, height=3)
        self.mod_messages.grid(row=7,
                               column=1,
                               sticky=N + S + E + W,
                               padx=(5, 5),
                               pady=(0, 5))

        self.server_reply_header = StringVar(self)
        self.server_reply_header.set("Placeholderdebolder")

        # textvariable (not text) so the label tracks server_reply_header;
        # text=StringVar would display the variable's Tcl name instead.
        # NOTE(review): this shares grid cell (0, 1) with the "Search artist
        # info by name" button gridded later, which covers it — confirm the
        # intended placement.
        Label(self, textvariable=self.server_reply_header,
              fg='black').grid(row=0, column=1, padx=(5, 5))
        self.server_reply_box = Text(self,
                                     width=20,
                                     height=20,
                                     font=("Times New Roman", 15),
                                     fg='black',
                                     bg='grey',
                                     bd=4)
        self.server_reply_box.grid(row=0,
                                   rowspan=8,
                                   column=2,
                                   sticky=E + W,
                                   padx=(5, 5),
                                   pady=(5, 5))
        self.server_reply_box.configure(state='disabled')

        self.artist_name = Entry(self, width=20)
        self.artist_counrty = Entry(self, width=20)
        self.artist_genre = Entry(self, width=20)

        self.artist_name.grid(row=0,
                              column=0,
                              sticky=E + W,
                              padx=(5, 5),
                              pady=(5, 5))
        self.artist_counrty.grid(row=1,
                                 column=0,
                                 sticky=E + W,
                                 padx=(5, 5),
                                 pady=(5, 0))
        self.artist_genre.grid(row=2,
                               column=0,
                               sticky=E + W,
                               padx=(5, 5),
                               pady=(5, 0))

        # tight_layout on the embedded figure itself (plt.tight_layout() only
        # touched an unrelated pyplot figure); it is re-applied on every draw
        # so rotated bar labels in plot_data are not clipped.
        self.fig = Figure(figsize=(3, 3), tight_layout=True)
        self.a = self.fig.subplots()

        self.canvas = FigureCanvasTkAgg(self.fig, master=self)
        self.canvas.get_tk_widget().grid(row=0,
                                         rowspan=8,
                                         column=3,
                                         sticky=E + W,
                                         padx=(5, 5),
                                         pady=(5, 5))

        Button(self,
               command=lambda: self.sendmessage("Q0", self.artist_name.get()),
               text="Search artist info by name",
               width=30).grid(row=0,
                              column=1,
                              sticky=E + W,
                              padx=(5, 5),
                              pady=(5, 5))
        Button(
            self,
            command=lambda: self.sendmessage("Q1", self.artist_counrty.get()),
            text="Search artist by country",
            width=30).grid(row=1,
                           column=1,
                           sticky=E + W,
                           padx=(5, 5),
                           pady=(5, 5))
        Button(self,
               command=lambda: self.sendmessage("Q2", self.artist_genre.get()),
               text="Search top artists by genre",
               width=30).grid(row=2,
                              column=1,
                              sticky=E + W,
                              padx=(5, 5),
                              pady=(5, 5))
        Button(self,
               command=lambda: self.sendmessage("Q3", "no param"),
               text="Give a histogram of most popular genres",
               width=60).grid(row=4,
                              column=0,
                              columnspan=2,
                              sticky=E + W,
                              padx=(5, 5),
                              pady=(5, 5))
        Button(self,
               command=lambda: self.sendmessage("Q4", "no param"),
               text="Give a histogram of most popular countries",
               width=10).grid(row=5,
                              column=0,
                              columnspan=2,
                              sticky=E + W,
                              padx=(5, 5),
                              pady=(5, 5))
        # Pass the bound method itself; the former `lambda: self.close_connection`
        # only evaluated the method object and never called it.
        Button(self,
               command=self.close_connection,
               text="close connection",
               width=20).grid(row=6,
                              column=0,
                              sticky=E + S + W,
                              pady=(20, 0),
                              padx=(10, 0))

        Grid.rowconfigure(self, 8, weight=1)
        Grid.columnconfigure(self, 3, weight=1)

        self.pack(fill=BOTH, expand=1)

    def __del__(self):
        self.close_connection()

    # When client stops, let the server know
    def close_connection(self):
        """Send CLOSE to the server, close the socket and destroy the window.

        Only the first call has any effect, so the close button and
        ``__del__`` do not both notify the server. Errors while notifying the
        server are logged and do not stop the window from being destroyed.
        """
        # getattr: __del__ may run on an instance whose __init__ never set it.
        if getattr(self, "_closed", False):
            return
        self._closed = True
        try:
            logging.info("Close connection with server...")
            self.my_writer_obj.write("CLOSE\n")
            self.my_writer_obj.flush()
            self.socket_to_server.close()
        except Exception as ex:
            logging.warning("Error while closing connection: %s", ex)

        self.master.destroy()

    # Send a message to the server with query number & param
    def sendmessage(self, query_number, query_param):
        """Write ``username;query_number;query_param`` as one line to the server."""
        try:
            logging.info(f"Sending {query_number} with param {query_param}")
            self.my_writer_obj.write(
                f"{self.username};{query_number};{query_param}\n")
            self.my_writer_obj.flush()

        except Exception as ex:
            print(ex)

    # Function to create the client itself
    @staticmethod
    def create_client(username, socket_server, write_obj):
        """Create a Tk root with a WindowClient and run its main loop."""
        root = Tk()
        root.geometry("1000x500")
        # The Frame registers itself in root's children, keeping it alive.
        WindowClient(root, username, socket_server, write_obj)
        root.mainloop()

    # Write new data to the display list
    def write_to_list(self, data, header):
        """Replace the reply box contents with ``data`` as a numbered list."""
        self.server_reply_header.set(header)
        self.server_reply_box.configure(state='normal')
        self.server_reply_box.delete('1.0', END)
        for index, text in enumerate(data):
            self.server_reply_box.insert(INSERT, f'{index+1}. {text}\n')
        self.server_reply_box.configure(state='disabled')

    # Plot a new graph if user requested
    def plot_data(self, data, header):
        """Redraw the embedded chart as a bar plot of ``data``.

        ``data`` is turned into a DataFrame whose column 0 is used for x and
        column 1 for y (assumes rows of (label, value) pairs — confirm).
        """
        self.a.clear()

        my_df = pd.DataFrame(data)
        sns.barplot(x=0, y=1, data=my_df, ax=self.a)

        self.a.set(yticks=[], title=header)
        self.a.set_xticklabels(self.a.get_xticklabels(),
                               rotation=40,
                               ha="right")
        self.canvas.draw_idle()

    # Monitor the Queue every 100ms to see if server has sent a new message
    def monitor(self):
        """Process every queued server message, then reschedule in 100 ms.

        Draining the whole queue (instead of one message per tick) keeps the
        GUI from lagging when several replies arrive at once.
        """
        while not self.queue.empty():
            self._handle_message(self.queue.get_nowait())

        self.after(100, self.monitor)

    def _handle_message(self, message):
        """Dispatch one queued message on its kind, ``message[0]``.

        Kinds other than 'moderator', 'text_response' and 'plot_response'
        (including 'CLOSE') are ignored.
        """
        # If the message came from the moderator, print it to the log window
        if message[0] == 'moderator':
            self.mod_messages.insert(END, f"> Moderator: {message[1]}")
            logging.info(f"> Moderator: {message[1]}")

        # If the message is for the list window
        elif message[0] == 'text_response':
            pickle_data = jsonpickle.decode(message[1])
            clean_data = json.loads(pickle_data)
            self.write_to_list(clean_data, message[2])
            logging.info(f"Answer server: {clean_data}")

        # If the message is a graph send it to the plotter
        elif message[0] == 'plot_response':
            pickle_data = jsonpickle.decode(message[1])
            self.plot_data(pickle_data, message[2])
Ejemplo n.º 6
0
class EEGFrame(wx.Frame):
    """GUI Frame in which data is plotted. Plots a subplot for every channel.
    Relies on a Timer to retrieve data at a specified interval. Data to be
    displayed is retrieved from a provided DataSource.

    Parameters:
    -----------
        data_source - object that implements the viewer DataSource interface.
        device_info - metadata about the data.
        seconds - how many seconds worth of data to display.
        downsample_factor - how much to compress the data. A factor of 1
            displays the raw data.
        refresh - time in milliseconds; how often to refresh the plots
    """
    def __init__(self,
                 data_source,
                 device_info: DeviceInfo,
                 seconds: int = 5,
                 downsample_factor: int = 2,
                 refresh: int = 500,
                 y_scale=100):
        wx.Frame.__init__(self, None, -1, 'EEG Viewer', size=(800, 550))

        self.data_source = data_source

        self.refresh_rate = refresh
        self.samples_per_second = device_info.fs
        self.records_per_refresh = int(
            (self.refresh_rate / 1000) * self.samples_per_second)

        self.channels = device_info.channels
        self.removed_channels = ['TRG', 'timestamp']
        self.data_indices = self.init_data_indices()

        self.seconds = seconds
        self.downsample_factor = downsample_factor
        self.filter = downsample_filter(downsample_factor, device_info.fs)

        self.autoscale = True
        self.y_min = -y_scale
        self.y_max = y_scale

        self.buffer = self.init_buffer()

        # figure size is in inches.
        self.figure = Figure(figsize=(12, 9),
                             dpi=80,
                             tight_layout={'pad': 0.0})
        # space between axis label and tick labels
        self.yaxis_label_space = 60
        self.yaxis_label_fontsize = 14
        # fixed width font so we can adjust spacing predictably
        self.yaxis_tick_font = 'DejaVu Sans Mono'
        self.yaxis_tick_fontsize = 10

        self.axes = self.init_axes()

        self.canvas = FigureCanvas(self, -1, self.figure)

        self.CreateStatusBar()

        # Toolbar
        self.toolbar = wx.BoxSizer(wx.VERTICAL)

        self.start_stop_btn = wx.Button(self, -1, "Start")

        self.timer = wx.Timer(self)
        self.Bind(wx.EVT_TIMER, self.update_view, self.timer)
        self.Bind(wx.EVT_BUTTON, self.toggle_stream, self.start_stop_btn)

        # Filtered checkbox
        self.sigpro_checkbox = wx.CheckBox(self, label="Filtered")
        self.sigpro_checkbox.SetValue(False)
        self.Bind(wx.EVT_CHECKBOX, self.toggle_filtering_handler,
                  self.sigpro_checkbox)

        # Autoscale checkbox
        self.autoscale_checkbox = wx.CheckBox(self, label="Autoscale")
        self.autoscale_checkbox.SetValue(self.autoscale)
        self.Bind(wx.EVT_CHECKBOX, self.toggle_autoscale_handler,
                  self.autoscale_checkbox)

        # Number of seconds text box
        self.seconds_choices = [2, 5, 10]
        if self.seconds not in self.seconds_choices:
            self.seconds_choices.append(self.seconds)
            self.seconds_choices.sort()
        opts = [str(x) + " seconds" for x in self.seconds_choices]
        self.seconds_input = wx.Choice(self, choices=opts)
        cur_sec_selection = self.seconds_choices.index(self.seconds)
        self.seconds_input.SetSelection(cur_sec_selection)
        self.Bind(wx.EVT_CHOICE, self.seconds_handler, self.seconds_input)

        controls = wx.BoxSizer(wx.HORIZONTAL)
        controls.Add(self.start_stop_btn, 1, wx.ALIGN_CENTER, 0)
        controls.Add(self.sigpro_checkbox, 1, wx.ALIGN_CENTER, 0)
        controls.Add(self.autoscale_checkbox, 1, wx.ALIGN_CENTER, 0)
        # TODO: pull right; currently doesn't do that
        controls.Add(self.seconds_input, 0, wx.ALIGN_RIGHT, 0)

        self.toolbar.Add(controls, 1, wx.ALIGN_CENTER, 0)
        self.init_channel_buttons()

        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.EXPAND)
        sizer.Add(self.toolbar, 0, wx.ALIGN_BOTTOM | wx.ALIGN_CENTER)
        self.SetSizer(sizer)
        self.SetAutoLayout(1)
        self.Fit()

        self.init_data()
        self.started = False
        self.start()

    def init_data_indices(self):
        """List of indices of all channels which will be displayed. By default
        filters out TRG channels and any channels marked as
        removed_channels."""

        return [
            i for i in range(len(self.channels))
            if self.channels[i] not in self.removed_channels
            and 'TRG' not in self.channels[i]
        ]

    def init_buffer(self):
        """Initialize the underlying RingBuffer by pre-allocating empty
        values. Buffer size is determined by the sample frequency and the
        number of seconds to display."""

        buf_size = int(self.samples_per_second * self.seconds)
        empty_val = [0.0 for _x in self.channels]
        buf = RingBuffer(buf_size, pre_allocated=True, empty_value=empty_val)
        return buf

    def init_axes(self):
        """Sets configuration for axes. Creates a subplot for every data
        channel and configures the appropriate labels and tick marks."""

        axes = self.figure.subplots(len(self.data_indices), 1, sharex=True)
        for i, channel in enumerate(self.data_indices):
            ch_name = self.channels[channel]
            axes[i].set_frame_on(False)
            axes[i].set_ylabel(ch_name,
                               rotation=0,
                               labelpad=self.yaxis_label_space,
                               fontsize=self.yaxis_label_fontsize)
            # x-axis shows seconds in 0.5 sec increments
            tick_names = np.arange(0, self.seconds, 0.5)
            ticks = [(self.samples_per_second * sec) / self.downsample_factor
                     for sec in tick_names]
            axes[i].xaxis.set_major_locator(ticker.FixedLocator(ticks))
            axes[i].xaxis.set_major_formatter(
                ticker.FixedFormatter(tick_names))
            axes[i].tick_params(axis='y',
                                which='major',
                                labelsize=self.yaxis_tick_fontsize)
            for tick in axes[i].get_yticklabels():
                tick.set_fontname(self.yaxis_tick_font)
            axes[i].grid()
        return axes

    def reset_axes(self):
        """Clear the data in the GUI."""
        self.figure.clear()
        self.axes = self.init_axes()

    def init_channel_buttons(self):
        """Add buttons for toggling the channels."""

        channel_box = wx.BoxSizer(wx.HORIZONTAL)
        for channel_index in self.data_indices:
            channel = self.channels[channel_index]
            chkbox = wx.CheckBox(self, label=channel, id=channel_index)
            chkbox.SetValue(channel not in self.removed_channels)

            self.Bind(wx.EVT_CHECKBOX, self.toggle_channel, chkbox)
            channel_box.Add(chkbox, 0, wx.ALIGN_CENTER, 0)

        self.toolbar.Add(channel_box, 1, wx.ALIGN_LEFT, 0)

    def toggle_channel(self, event):
        """Remove the provided channel from the display"""
        channel_index = event.GetEventObject().GetId()
        channel = self.channels[channel_index]

        previously_running = self.started
        if self.started:
            self.stop()

        if channel in self.removed_channels:
            self.removed_channels.remove(channel)
        else:
            self.removed_channels.append(channel)
        self.data_indices = self.init_data_indices()
        self.reset_axes()
        self.init_data()
        self.canvas.draw()
        if previously_running:
            self.start()

    def start(self):
        """Start streaming data in the viewer."""
        # update buffer with latest data on (re)start.
        self.update_buffer(fast_forward=True)
        self.timer.Start(self.refresh_rate)
        self.started = True
        self.start_stop_btn.SetLabel("Pause")

    def stop(self):
        """Stop/Pause the viewer."""
        self.timer.Stop()
        self.started = False
        self.start_stop_btn.SetLabel("Start")

    def toggle_stream(self, _event):
        """Toggle data streaming"""
        if self.started:
            self.stop()
        else:
            self.start()

    def toggle_filtering_handler(self, event):
        """Event handler for toggling data filtering."""
        self.with_refresh(self.toggle_filtering)

    def toggle_filtering(self):
        """Toggles data filtering."""
        if self.sigpro_checkbox.GetValue():
            self.filter = stream_filter(self.downsample_factor,
                                        self.samples_per_second)
        else:
            self.filter = downsample_filter(self.downsample_factor,
                                            self.samples_per_second)

    def toggle_autoscale_handler(self, event):
        """Event handler for toggling autoscale"""
        self.with_refresh(self.toggle_autoscale)

    def toggle_autoscale(self):
        """Sets autoscale to checkbox value"""
        self.autoscale = self.autoscale_checkbox.GetValue()

    def seconds_handler(self, event):
        """Event handler for changing seconds"""
        self.with_refresh(self.update_seconds)

    def update_seconds(self):
        """Set the number of seconds worth of data to display from the
        pulldown list."""
        self.seconds = self.seconds_choices[self.seconds_input.GetSelection()]
        self.buffer = self.init_buffer()

    def with_refresh(self, fn):
        """Performs the given action and refreshes the display."""
        previously_running = self.started
        if self.started:
            self.stop()
        fn()
        # re-initialize
        self.reset_axes()
        self.init_data()
        if previously_running:
            self.start()

    def current_data(self):
        """Returns the data as an np array with a row of floats for each
        displayed channel."""

        # array of 'object'; TRG data may be strings
        data = np.array(self.buffer.data)

        # select only data columns and convert to float
        return np.array(data[:, self.data_indices],
                        dtype='float64').transpose()

    def cursor_x(self):
        """Current cursor position (x-axis), accounting for downsampling."""
        return self.buffer.cur // self.downsample_factor

    def init_data(self):
        """Initialize the data."""
        channel_data = self.filter(self.current_data())

        for i, _channel in enumerate(self.data_indices):
            data = channel_data[i].tolist()
            self.axes[i].plot(data, linewidth=0.8)
            # plot cursor
            self.axes[i].axvline(self.cursor_x(), color='r')

    def update_buffer(self, fast_forward=False):
        """Update the buffer with latest data from the datasource and return
        the data. If the datasource does not have the requested number of
        samples, viewer streaming is stopped."""
        try:
            records = self.data_source.next_n(self.records_per_refresh,
                                              fast_forward=fast_forward)
            for row in records:
                self.buffer.append(row)
        except StopIteration:
            self.stop()
            # close the Wx.Frame to shutdown the viewer application
            self.Close()
        except BaseException:
            self.stop()
        return self.buffer.get()

    def update_view(self, _evt):
        """Timer callback: pull new data and refresh every channel plot.

        Runs on each refresh tick. For each displayed channel the trace's
        y-data and the cursor line's x position are updated in place. With
        ``self.autoscale`` the y-bounds follow the visible data and the channel
        label padding is recomputed; otherwise the bounds are fixed to
        ``self.y_min``/``self.y_max``. The canvas is redrawn once at the end.
        """
        self.update_buffer()
        filtered = self.filter(self.current_data())

        for idx, channel in enumerate(self.data_indices):
            values = filtered[idx].tolist()
            axis = self.axes[idx]
            # lines[0] is the trace, lines[1] the cursor (see init_data order)
            axis.lines[0].set_ydata(values)
            axis.lines[1].set_xdata(self.cursor_x())

            if not self.autoscale:
                axis.set_ybound(lower=self.y_min, upper=self.y_max)
                continue

            lowest, highest = min(values), max(values)
            axis.set_ybound(lower=lowest, upper=highest)

            # labelpad is recomputed on every draw so the channel labels stay
            # aligned regardless of the tick label widths.
            label = self.channels[channel]
            ticks = axis.get_yticks()
            # The lowest tick is not displayed, hence index 1 rather than 0.
            padding = self.adjust_padding(int(ticks[1]), int(ticks[-1]))
            axis.set_ylabel(label,
                            rotation=0,
                            labelpad=padding,
                            fontsize=14)

        self.canvas.draw()

    def adjust_padding(self, data_min: int, data_max: int) -> int:
        """Return a y-label pad that keeps the channel labels aligned.

        The widest of the two tick values (counted in characters, so a minus
        sign counts) is compared with a two-character baseline. Each extra
        character reduces the pad by one tick-digit width; narrower labels
        increase it.
        """
        widest = max(len(str(value)) for value in (data_min, data_max))
        # tick labels are assumed to start at 2 characters
        baseline_chars = 2
        # Empirical per-character tick width; revisit if the tick font or
        # font size changes.
        ytick_digit_width = 7
        extra_chars = widest - baseline_chars
        return self.yaxis_label_space - extra_chars * ytick_digit_width
Ejemplo n.º 7
0
    def setup_plot(cls,
                   width=16,
                   height=4,
                   ncols=1,
                   nrows=1,
                   interactive=None,
                   link_dataframes=None,
                   cursor_delta=None,
                   **kwargs):
        """
        Common helper for setting up a matplotlib plot

        :param width: Width of the plot (inches)
        :type width: int or float

        :param height: Height of each subplot (inches)
        :type height: int or float

        :param ncols: Number of plots on a single row
        :type ncols: int

        :param nrows: Number of plots in a single column
        :type nrows: int

        :param link_dataframes: Link the provided dataframes to the axes using
            :func:`lisa.notebook.axis_link_dataframes`
        :type link_dataframes: list(pandas.DataFrame) or None

        :param cursor_delta: Add two vertical lines set with left and right
            clicks, and show the time delta between them in a widget. When
            ``None``, this is enabled only for interactive widget plots.
        :type cursor_delta: bool or None

        :param interactive: If ``True``, use the pyplot API of matplotlib,
            which integrates well with notebooks. However, it can lead to
            memory leaks in scripts generating lots of plots, in which case it
            is better to use the non-interactive API. Defaults to ``True`` when
            running under IPython or Jupyter notebook, ``False`` otherwise.
        :type interactive: bool

        :Keyword arguments: Extra arguments to pass to
          :obj:`matplotlib.figure.Figure.subplots`

        :returns: tuple(matplotlib.figure.Figure, matplotlib.axes.Axes (or an
          array of them if ``nrows`` or ``ncols`` > 1))
        """
        import re

        running_ipython = is_running_ipython()
        if interactive is None:
            interactive = running_ipython

        # Only compare the leading numeric release components. Version strings
        # such as "3.5.0rc1" or "3.8.0.post1" are not plain dot-separated
        # integers and would make a naive int() conversion raise ValueError.
        version_match = re.match(r'\d+(?:\.\d+)*', matplotlib.__version__)
        mpl_version = (
            tuple(int(part) for part in version_match.group(0).split('.'))
            if version_match else None
        )
        if mpl_version is not None and mpl_version <= (3, 0, 3):
            warnings.warn(
                'This version of matplotlib does not allow saving figures from axis created using Figure(), forcing interactive=True'
            )
            interactive = True

        if interactive:
            figure, axes = plt.subplots(ncols=ncols,
                                        nrows=nrows,
                                        figsize=(width, height * nrows),
                                        **kwargs)
        else:
            figure = Figure(figsize=(width, height * nrows))
            axes = figure.subplots(ncols=ncols, nrows=nrows, **kwargs)

        # subplots() returns a single Axes for a 1x1 grid, an array otherwise.
        if isinstance(axes, Iterable):
            ax_list = axes
        else:
            ax_list = [axes]

        use_widgets = interactive and running_ipython

        if link_dataframes:
            if not use_widgets:
                cls.get_logger().error(
                    'Dataframes can only be linked to axes in interactive widget plots'
                )
            else:
                for axis in ax_list:
                    axis_link_dataframes(axis, link_dataframes)

        # Explicit request, or default (None) when widgets are available.
        if cursor_delta or (cursor_delta is None and use_widgets):
            if not use_widgets and cursor_delta is not None:
                cls.get_logger().error(
                    'Cursor delta can only be used in interactive widget plots'
                )
            else:
                for axis in ax_list:
                    axis_cursor_delta(axis)

        # Needed for multirow plots to not overlap with each other
        figure.set_tight_layout(dict(h_pad=3.5))
        return figure, axes
Ejemplo n.º 8
0
def track_analytics(track_id):
    """ Route including usage of graphs created with matplotlib 
    # Main idea:
    # - Don't create static image files to prevent mass build-up of .jpg, .png files in directory
    # - Instead use BytesIO to hold binary stream of data in memory
    # - Follow by extracting bytes from object into plot file and embed in HTML directly
    """
    track_features = helpers.get_features(session["authorization_header"],
                                          track_id)
    track_info = helpers.get_track(session["authorization_header"], track_id)
    required_data = [
        "acousticness", "danceability", "energy", "instrumentalness",
        "liveness", "speechiness", "valence"
    ]
    feature_data = {}
    for feature, data in track_features.items():
        if feature in required_data:
            feature_data[feature] = data
    # Sort data to match in polar axis bar chart
    labels = []
    label_data = []
    for key, value in feature_data.items():
        labels.append(key)
        label_data.append(value)
    # Generate figure without using pyplot (to prevent main thread error)
    fig = Figure()
    # Define number of bars (on polar axis)
    N = len(required_data)
    # Define x coordinates of the bars
    theta = np.linspace(0.0, 2 * np.pi, N, endpoint=False)
    # Create an array object satisfying the specified requirements
    height = np.array(label_data)
    # Define colors of bar chart
    colors = [
        "#1DB954", "#39C269", "#56CB7F", "#72D394", "#8EDCAA", "#AAE5BF",
        "#C7EDD4"
    ]
    # Add subplot to figure object
    ax = fig.add_subplot(111, projection='polar')
    ax.set_facecolor("#757575")
    # Set figure title and design
    title = "Audio Features"
    ax.set_title(title,
                 fontfamily='sans-serif',
                 fontsize='xx-large',
                 fontweight='bold',
                 pad=15.0,
                 color='#ffffff')
    # Edit radius labels
    ax.set_rlim(0.0, 1)
    ax.set_rlabel_position(+15)
    ax.tick_params(axis='both', colors='#ffffff', pad=22, labelsize='large')
    # Add bar chart
    ax.bar(x=theta,
           height=height,
           width=0.8,
           bottom=0.0,
           alpha=0.7,
           tick_label=labels,
           color=colors,
           edgecolor="black")
    # Ensure labels don't go out of Figure view
    fig.tight_layout()
    # Save to temporary buffer
    buf = BytesIO()
    fig.savefig(buf, format='png', facecolor='#757575')
    plot_url_features = base64.b64encode(buf.getbuffer()).decode()
    # b64encode(): Encodes the bytes-like object using Base64 and returns the encoded bytes
    # getvalue(): Returns bytes containing the entire contents of the buffer
    # decode(): Decodes the contents of the binary input file and write the resulting binary data to the output file
    # Image can be shown with following syntax: data:[<mime type>][;charset=<charset>][;base64],<encoded data>
    # eg. <img src="data:image/png;base64,{}">'.format(plot_url)

    # Visualize track analysis data
    track_analysis = helpers.get_analysis(session["authorization_header"],
                                          track_id)
    segments = []
    loudness = []
    for segment in track_analysis["segments"]:
        segments.append(segment["start"])
        loudness.append(segment["loudness_start"])

    # Make a new figure to display track loudness
    fig2 = Figure(figsize=(15, 4))
    # Create plots
    ax1 = fig2.subplots()
    ax1.plot(segments, loudness, color="#1DB954", linewidth=1.1)
    ax1.set_facecolor("#757575")
    # Configure x- and y-limit for the graph
    ax1.set_xlim(left=0.0, right=track_features["duration_ms"] / 1000)
    ax1.set_ylim(bottom=-60.0, top=0.0)
    # Configure ticks and labels
    ax1.set_xticks(
        np.linspace(0.0, track_features["duration_ms"] / 1000, endpoint=False))
    ax1.set_yticklabels([])
    ax1.tick_params(direction="inout", labelrotation=45.0, pad=-0.5)
    ax1.set(xlabel="Time (s)", ylabel="Loudness", title="Track loudness")
    # Draw grid lines
    ax1.grid()
    # Ensure xlabel doesn't go out of Figure view
    fig2.tight_layout()
    # Save to temporary buffer
    buf2 = BytesIO()
    fig2.savefig(buf2, format='png')
    plot_url_loudness = base64.b64encode(buf2.getbuffer()).decode()

    return render_template("track-analytics.html",
                           track_features=plot_url_features,
                           track_features_data=feature_data,
                           tempo=int(track_features["tempo"]),
                           key=track_features["key"],
                           track_loudness=plot_url_loudness,
                           track_info=track_info)
Ejemplo n.º 9
0
class ProfilesFrame(wx.Frame):
    """The main frame of the application.

    Shows the reference profiles of one experiment/block as a grid of images,
    with controls to choose the experiment, block, masking and colour map.
    """

    title = "Reference profile viewer"

    def __init__(self, profiles):
        """Build the frame and draw experiment 0, block 0.

        :param profiles: object providing ``get_profiles(experiment, block)``,
            ``get_n_experiments()`` and ``get_n_blocks(experiment)``
        """
        wx.Frame.__init__(self, None, -1, self.title)

        self.profiles = profiles
        self.data = self.profiles.get_profiles(experiment=0, block=0)
        # Status-bar timer, created on the first flash_status_message call
        self.timeroff = None

        self.create_menu()
        self.create_status_bar()
        self.create_main_panel()
        self.draw_figure()

    def create_menu(self):
        """Create the File (save, exit) and Help (about) menus."""
        self.menubar = wx.MenuBar()

        menu_file = wx.Menu()
        m_expt = menu_file.Append(-1, "&Save plot\tCtrl-S",
                                  "Save plot to file")
        self.Bind(wx.EVT_MENU, self.on_save_plot, m_expt)
        menu_file.AppendSeparator()
        m_exit = menu_file.Append(-1, "E&xit\tCtrl-X", "Exit")
        self.Bind(wx.EVT_MENU, self.on_exit, m_exit)

        menu_help = wx.Menu()
        m_about = menu_help.Append(-1, "&About\tF1", "About the program")
        self.Bind(wx.EVT_MENU, self.on_about, m_about)

        self.menubar.Append(menu_file, "&File")
        self.menubar.Append(menu_help, "&Help")
        self.SetMenuBar(self.menubar)

    def create_main_panel(self):
        """Creates the main panel with all the controls on it:
        * mpl canvas
        * mpl navigation toolbar
        * Control panel for interaction
        """
        self.panel = wx.Panel(self)

        # Create the mpl Figure and FigCanvas objects.
        # 7x5 inches, 100 dots-per-inch
        self.dpi = 100
        self.fig = Figure((7.0, 5.0), dpi=self.dpi)
        self.canvas = FigCanvas(self.panel, -1, self.fig)

        self.set_axes()

        self.expt_selection_label = wx.StaticText(self.panel, -1,
                                                  "Experiment ID: ")
        self.expt_selection = wx.SpinCtrl(
            self.panel, -1, "0", max=self.profiles.get_n_experiments() - 1)
        self.Bind(wx.EVT_SPINCTRL, self.on_spin_expt, self.expt_selection)

        self.block_selection_label = wx.StaticText(self.panel, -1, "Block: ")
        self.block_selection = wx.SpinCtrl(
            self.panel,
            -1,
            "0",
            max=self.profiles.get_n_blocks(experiment=0) - 1)
        self.Bind(wx.EVT_SPINCTRL, self.on_spin_block, self.block_selection)

        self.mask_checkbox = wx.CheckBox(self.panel, -1, "Mask")
        self.mask_checkbox.SetValue(True)
        self.Bind(wx.EVT_CHECKBOX, self.redraw_on_event, self.mask_checkbox)

        self.cmap_choice_label = wx.StaticText(self.panel, -1, "Colourmap: ")
        self.cmap_choice = wx.Choice(
            self.panel, -1, choices=["viridis", "viridis_r", "gray", "gray_r"])
        self.cmap_choice.SetSelection(0)
        self.Bind(wx.EVT_CHOICE, self.redraw_on_event, self.cmap_choice)

        # Create the navigation toolbar, tied to the canvas
        self.toolbar = NavigationToolbar(self.canvas)

        # Layout with box sizers
        self.vbox = wx.BoxSizer(wx.VERTICAL)
        self.vbox.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.GROW)
        self.vbox.Add(self.toolbar, 0, wx.EXPAND)
        self.vbox.AddSpacer(10)

        self.hbox = wx.BoxSizer(wx.HORIZONTAL)
        flags = wx.ALIGN_LEFT | wx.ALL | wx.ALIGN_CENTER_VERTICAL
        self.hbox.Add(self.expt_selection_label, 0, flag=flags)
        self.hbox.Add(self.expt_selection, 0, border=3, flag=flags)
        self.hbox.Add(self.block_selection_label, 0, flag=flags)
        self.hbox.Add(self.block_selection, 0, border=3, flag=flags)
        self.hbox.Add(self.mask_checkbox, 0, border=3, flag=flags)
        self.hbox.Add(self.cmap_choice_label, 0, flag=flags)
        self.hbox.Add(self.cmap_choice, 0, border=3, flag=flags)

        self.vbox.Add(self.hbox, 0, flag=wx.ALIGN_LEFT | wx.TOP)

        self.panel.SetSizer(self.vbox)
        self.vbox.Fit(self)

    def set_axes(self):
        """Create a grid of axes large enough for every profile's subplot.

        ``squeeze=False`` keeps ``self.axes`` a 2-D array even for a single
        row, a single column or a 1x1 grid, so ``self.axes.shape[0]`` and
        ``self.axes[row, col]`` work for every grid shape.
        """
        subplots = [e["subplot"] for e in self.data]
        rows, cols = list(zip(*subplots))
        nrows = max(rows) + 1
        ncols = max(cols) + 1
        self.axes = self.fig.subplots(nrows,
                                      ncols,
                                      sharex=True,
                                      sharey=True,
                                      squeeze=False)

    def create_status_bar(self):
        """Create the frame's status bar."""
        self.statusbar = self.CreateStatusBar()

    def _selected_cmap(self):
        """Return a private copy of the colour map selected in the UI.

        A copy is required because ``set_bad`` mutates the colour map.
        ``matplotlib.colormaps`` is used when available (matplotlib >= 3.5),
        because ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9.
        """
        name = self.cmap_choice.GetString(self.cmap_choice.GetSelection())
        registry = getattr(matplotlib, "colormaps", None)
        if registry is not None:
            return copy.copy(registry[name])
        return copy.copy(matplotlib.cm.get_cmap(name))

    def draw_figure(self):
        """Redraws the figure"""
        final_row_index = self.axes.shape[0] - 1
        # The same colour map is used for every subplot, so build it once.
        cmap = self._selected_cmap()
        # Masked (NaN) pixels are shown in red
        cmap.set_bad(color="red")
        mask_enabled = self.mask_checkbox.IsChecked()

        for profile in self.data:
            # tuple() so the index is a (row, col) pair, never fancy indexing
            subplot = tuple(profile["subplot"])
            ax = self.axes[subplot]
            ax.clear()

            # For now, let's just sum down the first axis. Profiles are stored
            # in ez, ey, ex order, where ex is orthogonal to s1 and s0, ey is
            # orthogonal to s1 and ex, and ez is the axis that is dependent on
            # the direction through the Ewald sphere
            vals2D = profile["data"].sum(axis=0)

            # If any X, Y position is masked down the summed Z stack then mask
            # it in the final image.
            if mask_enabled:
                mask2D = (profile["mask"] - 1).sum(axis=0) != 0
                # NaN can only be stored in a floating-point array
                if not numpy.issubdtype(vals2D.dtype, numpy.floating):
                    vals2D = vals2D.astype(float)
                vals2D[mask2D] = numpy.nan
            ax.imshow(vals2D, cmap=cmap)

            if subplot[0] == final_row_index:
                ax.set_xlabel(f"X (px): {profile['coord'][0]:.1f}")
            if subplot[1] == 0:
                ax.set_ylabel(f"Y (px): {profile['coord'][1]:.1f}")

        if self.data:
            # The block's Z coordinate is taken from the last profile drawn
            self.fig.suptitle(
                f"Block Z (im): {self.data[-1]['coord'][2]:.1f}")
        self.canvas.draw()

    def on_spin_expt(self, event):
        """Load block 0 of the newly selected experiment and redraw."""
        self.expt_selection.Disable()
        try:
            exp_id = self.expt_selection.GetValue()
            self.data = self.profiles.get_profiles(experiment=exp_id, block=0)
            self.block_selection.SetValue(0)
            self.block_selection.SetMax(
                self.profiles.get_n_blocks(experiment=exp_id) - 1)
            self.draw_figure()
        finally:
            # Re-enable the control even if loading or drawing fails
            self.expt_selection.Enable()

    def on_spin_block(self, event):
        """Load the newly selected block of the current experiment and redraw."""
        self.block_selection.Disable()
        try:
            self.data = self.profiles.get_profiles(
                experiment=self.expt_selection.GetValue(),
                block=self.block_selection.GetValue(),
            )
            self.draw_figure()
        finally:
            self.block_selection.Enable()

    def redraw_on_event(self, event):
        """Redraw after the mask checkbox or colour map choice changes."""
        self.draw_figure()

    def on_save_plot(self, event):
        """Ask for a PNG path and save the current figure there."""
        file_choices = "PNG (*.png)|*.png"

        # wx.FD_SAVE works in both classic wxPython and Phoenix (wx.SAVE is
        # gone in Phoenix).
        dlg = wx.FileDialog(
            self,
            message="Save plot as...",
            defaultDir=os.getcwd(),
            defaultFile="reference_profiles.png",
            wildcard=file_choices,
            style=wx.FD_SAVE,
        )
        try:
            if dlg.ShowModal() == wx.ID_OK:
                path = dlg.GetPath()
                self.canvas.print_figure(path, dpi=self.dpi)
                self.flash_status_message("Saved to %s" % path)
        finally:
            # Dialogs are not destroyed automatically; release it explicitly
            dlg.Destroy()

    def on_exit(self, event):
        """Destroy the frame."""
        self.Destroy()

    def on_about(self, event):
        """Show the About dialog."""
        msg = """A reference profile viewer for DIALS:

         * The reference profiles for the selected block of
           images (Z) are displayed along with their X, Y
           pixel positions
         * The display sums slices of the Kabsch space shoebox
           down the e3 direction, hence the view is the
           integration of the profile as it passes through
           the Ewald sphere
         * If any of the slices contains a masked pixel, the
           equivalent pixel in the summed image is also
           showed as masked
         * Save the plot to a file using the File menu
        """
        dlg = wx.MessageDialog(self, msg, "About", wx.OK)
        dlg.ShowModal()
        dlg.Destroy()

    def flash_status_message(self, msg, flash_len_ms=1500):
        """Show *msg* in the status bar and clear it after *flash_len_ms* ms."""
        self.statusbar.SetStatusText(msg)
        if getattr(self, "timeroff", None) is None:
            self.timeroff = wx.Timer(self)
            self.Bind(wx.EVT_TIMER, self.on_flash_status_off, self.timeroff)
        # A single reused timer: restarting it resets the countdown, so an
        # older flash cannot clear a newer message early.
        self.timeroff.Start(flash_len_ms, oneShot=True)

    def on_flash_status_off(self, event):
        """Clear the status bar text."""
        self.statusbar.SetStatusText("")
Ejemplo n.º 10
0
def model(filename, features, dep_var):
    """Render regression results or per-feature scatter plots for a dataset.

    *features* is a string holding a Python literal list of column names,
    parsed with ``ast.literal_eval``. On POST, the form either sets the
    train/test split and retrains (``train_test_split``), or predicts the
    target for comma-separated feature values (``test_features``). Any other
    POST flashes "Invalid Input!" and redirects back. Otherwise (GET) one
    scatter plot per feature against *dep_var* is rendered as base64 PNGs.

    Uses the module-level ``data`` frame and updates the module-level
    ``test_size`` when a new split is posted.
    """
    global data
    global test_size

    feature_names = [name.strip() for name in ast.literal_eval(features)]
    target = data[dep_var]

    sns.set()
    if request.method == "POST":
        form = request.form

        if ("train_test_split" in form
                and re.match(TWO_DIGIT_NUM, form["train_test_split"])):
            # The form carries the training percentage; keep the complement.
            test_size = 1 - (int(form["train_test_split"]) / 100)
            model_fig, metrics_dict, _ = train(feature_names, dep_var,
                                               test_size)
            return render_template(
                "plot.html",
                model_fig=model_fig,
                metrics=metrics_dict,
                filename=os.path.splitext(filename)[0],
            )

        if ("test_features" in form
                and re.match(NUM_LIST, form["test_features"])):
            sample = np.asarray(
                [float(value) for value in form["test_features"].split(",")])
            model_fig, metrics_dict, regressor = train(feature_names, dep_var,
                                                       test_size)
            # A single sample, reshaped to (1, n_features) for predict()
            pred = regressor.predict(sample.reshape(1, -1))
            return render_template(
                "plot.html",
                model_fig=model_fig,
                metrics=metrics_dict,
                filename=os.path.splitext(filename)[0],
                pred=pred,
            )

        flash("Invalid Input!")
        return redirect(request.url)

    encoded_figures = []
    for feature in feature_names:
        x_values = data[feature]
        figure = Figure()
        axis = figure.subplots()
        axis.scatter(x_values, target)
        axis.set_title(f"{feature} vs {dep_var}")
        axis.set_xlabel(feature)
        axis.set_ylabel(dep_var)

        png = BytesIO()
        figure.savefig(png, format="png")
        encoded_figures.append(
            base64.b64encode(png.getbuffer()).decode("ascii"))

    return render_template("plot.html",
                           fig_data=encoded_figures,
                           filename=os.path.splitext(filename)[0])
Ejemplo n.º 11
0
 def __create_raw_canvas(self):
     """Build a 6x2 grid of raw-data axes and embed it in this Tk widget.

     All axes share the x-axis. The canvas widget is placed at the widget's
     top-left corner with a fixed size of 800x500 pixels.
     """
     sns.set(font_scale=1)
     raw_figure = Figure()
     self._ax = raw_figure.subplots(nrows=6, ncols=2, sharex='all')
     self._canvas = FigureCanvasTkAgg(raw_figure, self)
     tk_widget = self._canvas.get_tk_widget()
     tk_widget.place(x=0, y=0, width=800, height=500)
Ejemplo n.º 12
0
def display_quantiles_flask(prediction,
                            target_ts=None,
                            bench_mark_prediction=None,
                            bench_mark_prediction_name=None,
                            start=None):
    """
    Show predictions for input time series against comparison values in a Flask application
    (avoids a memory leak that may occur using pyplot as in th `display_quantiles` function)
    This function is conceived to be used by a Flask application, and it will not make use of pandas
    :param prediction: time series prediction prediction produced by a DeepAR model
    :param target_ts: prediction target time series
    :param bench_mark_prediction: benchmark model prediction
    :param bench_mark_prediction_name: benchmark model name to be shown in legend
    :param start: plot start date
    :return: a <img> HTML5 element containing the plot
    """

    fig = Figure()
    ax = fig.subplots()
    if start is not None:
        # retrieving x-ticks
        if isinstance(start, str) and start != "":
            start_date = datetime.strptime(start, "%Y-%m-%d %H:%M:%S").date()
        elif isinstance(start, datetime):
            start_date = start.date()
        elif isinstance(start, datetime.date):
            starte_date = start
        else:
            print("Enter only string or date as start values")
        x_ticks = [
            start_date + x * timedelta(days=1)
            for x in range(len(prediction['0.5']))
        ]
        ax.set_xticklabels(
            ["{}/{}".format(x_tick.day, x_tick.month) for x_tick in x_ticks])
    if target_ts is not None:
        ax.plot(target_ts[-len(x_ticks):], label='real Adjusted Close')

    # get the quantile values at 10 and 90%
    p10 = np.array(prediction['0.1'], dtype=float)
    p50 = np.array(prediction['0.5'], dtype=float)
    p90 = np.array(prediction['0.9'], dtype=float)

    # fill the 80% confidence interval
    ax.fill_between(range(0, len(p10)),
                    p10,
                    p90,
                    color='y',
                    alpha=0.5,
                    label='80% confidence interval')

    # plot the median prediction line
    ax.plot(p50, label='prediction median')

    # plot benchmark data
    if bench_mark_prediction is not None:
        ax.plot(bench_mark_prediction,
                label=bench_mark_prediction_name,
                color='r')

    # adding legend
    ax.legend()

    # Save it to a temporary buffer.
    buf = BytesIO()
    fig.savefig(buf, format="png")

    # Embed the result in the html output.
    data = base64.b64encode(buf.getbuffer()).decode("ascii")
    data_str = f"<img src='data:image/png;base64,{data}'/>"
    return data_str
Ejemplo n.º 13
0
class Analysis_Panel(wx.Panel):
    """Panel with stacked speech, intensity and pitch plots and a slider.

    The figure has three subplots sharing the x-axis: ``self.axes`` (speech
    signal), ``self.axes_intensity`` (level) and ``self.axes_pitch`` (pitch).
    The panel listens to the pubsub topics "PITCH_CHANGE", "LEVEL_CHANGE",
    "SLIDER_CHANGE" and "PLAY".
    """

    def __init__(self, parent, user_data):
        """ Initialize everything here """
        super(Analysis_Panel, self).__init__(parent, style=wx.BORDER_DOUBLE)
        self.SetBackgroundColour(wx.Colour("White"))
        # Data of whichever pitch/intensity series was set most recently
        self.x_vals, self.y_vals = None, None
        # Speech waveform; None until set_data() is called
        self.speech_x_vals, self.speech_y_vals = None, None
        self.user_data = user_data
        self.figure = Figure(figsize=(1, 1))
        self.axes, self.axes_intensity, self.axes_pitch = self.figure.subplots(
            3, sharex='col')
        self.canvas = FigureCanvas(self, -1, self.figure)
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.figure.subplots_adjust(top=0.98, bottom=0.15)  #adjusted correctly
        self.sld = wx.Slider(self,
                             value=0,
                             minValue=0,
                             maxValue=100,
                             size=(545, -1),
                             style=wx.SL_AUTOTICKS | wx.SL_HORIZONTAL
                             | wx.SL_SELRANGE)
        self.sizer.Add(self.sld, 0, wx.LEFT, 77)
        self.sizer.Add(self.canvas, 1, wx.EXPAND)

        self.SetSizer(self.sizer)
        self.sizer.Fit(self)

        pub.subscribe(self.changePitch, "PITCH_CHANGE")
        pub.subscribe(self.changeLevel, "LEVEL_CHANGE")
        pub.subscribe(self.changeSlider, "SLIDER_CHANGE")
        pub.subscribe(self.playAnimation, "PLAY")

    def changeSlider(self, value):
        """Sweep the slider from 0 to 100 in steps of 1, ~10 ms apart.

        ``value`` is not used. The sweep runs synchronously on the calling
        thread, so the slider is repainted after every step; otherwise the
        intermediate positions would never be shown.
        """
        # Reset to the slider's minimum before sweeping
        self.sld.SetValue(0)
        for _ in range(100):
            self.sld.SetValue(min(self.sld.GetValue() + 1, 100))
            self.sld.Update()
            time.sleep(0.01)

    def changePitch(self, value):
        """Set the pitch axis upper y-limit to *value* and redraw."""
        self.axes_pitch.set_ylim(0, value)
        self.canvas.draw()
        self.canvas.Refresh()

    def changeLevel(self, value):
        """Set the intensity axis upper y-limit to *value* (lower 20) and redraw."""
        self.axes_intensity.set_ylim(20, value)
        self.canvas.draw()
        self.canvas.Refresh()

    def draw_speech_data(self):
        """Redraw the speech waveform and the two red marker lines."""
        self.axes.clear()
        self.axes.set_ylim(-1, 1)
        self.axes.set_xlim(self.start_time, self.end_time)
        self.axes.set_ylabel('Sig.(norm)',
                             fontname="Arial",
                             fontsize=10,
                             labelpad=13)
        self.axes.tick_params(axis='both', which='major', labelsize=8)
        self.axes.plot(self.speech_x_vals,
                       self.speech_y_vals,
                       "b",
                       linewidth=0.5)
        # Two single-point red lines at the first and last samples; self.l
        # is moved by update_line during the play animation.
        self.l, self.v = self.axes.plot(self.speech_x_vals[0],
                                        self.speech_y_vals[0],
                                        self.speech_x_vals[-1],
                                        self.speech_y_vals[-1],
                                        linewidth=0.5,
                                        color='red')
        self.axes.grid(True, color='lightgray')
        self.canvas.draw()
        self.canvas.Refresh()

    def set_data(self, xvals, yvals, start, end):
        """Store the speech waveform and its time range, then draw it."""
        self.speech_x_vals = xvals
        self.speech_y_vals = yvals
        self.start_time = start
        self.end_time = end
        self.draw_speech_data()

    def get_data(self):
        """Return the speech waveform (``(None, None)`` before set_data)."""
        return self.speech_x_vals, self.speech_y_vals

    def draw_pitch_data(self):
        """Redraw the pitch axis from ``self.x_vals``/``self.y_vals``."""
        self.axes_pitch.clear()
        self.axes_pitch.set_ylim(0, 500)
        self.axes_pitch.set_xlim(self.start_time, self.end_time)
        self.axes_pitch.grid(True, color='lightgray')
        self.axes_pitch.set_ylabel('Pitch (norm)',
                                   fontname="Arial",
                                   fontsize=10,
                                   labelpad=9)
        self.axes_pitch.tick_params(axis='both', which='major', labelsize=8)
        self.axes_pitch.set_xlabel('Time (s)', fontname="Arial", fontsize=10)
        self.axes_pitch.plot(self.x_vals, self.y_vals, "g", linewidth=0.5)
        self.canvas.draw()
        self.canvas.Refresh()

    def set_pitch_data(self, xvals, yvals, start, end):
        """Store pitch data and its time range, then draw it."""
        self.x_vals = xvals
        self.y_vals = yvals
        self.start_time = start
        self.end_time = end
        self.draw_pitch_data()

    def draw_intensity_data(self):
        """Redraw the intensity axis from ``self.x_vals``/``self.y_vals``."""
        self.axes_intensity.clear()
        self.axes_intensity.set_ylim(20, 80)
        self.axes_intensity.set_xlim(self.start_time, self.end_time)
        self.axes_intensity.grid(True, color='lightgray')
        self.axes_intensity.set_ylabel('Level (norm)',
                                       fontname="Arial",
                                       fontsize=10,
                                       labelpad=15)
        self.axes_intensity.tick_params(axis='both',
                                        which='major',
                                        labelsize=8)
        self.axes_intensity.plot(self.x_vals, self.y_vals, "b", linewidth=0.5)
        self.axes_intensity.xaxis.set_major_locator(ticker.LinearLocator(6))
        self.canvas.draw()
        self.canvas.Refresh()

    def set_intensity_data(self, xvals, yvals, start, end):
        """Store intensity data and its time range, then draw it."""
        self.x_vals = xvals
        self.y_vals = yvals
        self.start_time = start
        self.end_time = end
        self.draw_intensity_data()

    def update_line(self, num, line):
        """Animation step: move *line* to a vertical at speech sample *num*."""
        i = self.speech_x_vals[num]
        line.set_data([i, i], [self.speech_y_vals[0], self.speech_y_vals[-1]])
        return line

    def playAnimation(self, value):
        """Sweep the red marker line across the speech waveform once.

        ``value`` is not used. Does nothing until speech data has been set
        with set_data(). The animation is kept on ``self.line_anim`` so it is
        not garbage-collected while running.
        """
        if self.speech_x_vals is None:
            return
        self.line_anim = animation.FuncAnimation(fig=self.figure,
                                                 func=self.update_line,
                                                 frames=len(
                                                     self.speech_x_vals),
                                                 init_func=None,
                                                 fargs=(self.l, ),
                                                 repeat=False)
Ejemplo n.º 14
0
def sample_wf(  # noqa: C901
    wf,
    sampler,
    steps,
    writer=None,
    write_figures=False,
    log_dict=None,
    blocks=None,
    *,
    block_size=10,
    equilibrate=True,
):
    r"""Sample a wave function and accumulate expectation values.

    This is a low-level interface, see :func:`~deepqmc.evaluate` for a high-level
    interface. This iterator draws samples from the sampler step by step,
    detects when equilibrium is reached, and from then on calculates local
    energies, groups them into blocks of ``block_size`` steps, and keeps a
    running estimate of the energy computed from the completed blocks.
    Diagnostics are written into the Tensorboard writer if one is given.

    Args:
        wf (:class:`~deepqmc.wf.WaveFunction`): wave function model to be sampled
        sampler (iterator): yields tuples ``(rs, log_psis, _, info)``: electron
            coordinate samples, their log wave-function values, an ignored
            item, and a dict of diagnostics (``info['age']`` is plotted when
            ``write_figures`` is set)
        steps (iterator): yields step indexes; consumed in lockstep with
            ``sampler``. NOTE(review): the first index is assumed to be 0,
            because the distance history is only initialized at ``step == 0``.
        writer (:class:`torch.utils.tensorboard.writer.SummaryWriter`):
            Tensorboard writer
        write_figures (bool): if true (and ``writer`` is given), histograms are
            added to the writer every step: ``log_psis`` and sample ages always,
            local energies once energy accumulation has started, and block
            energies whenever a block is completed
        log_dict (dict-like): if given, its ``'coords'``, ``'E_loc'`` and
            ``'log_psis'`` entries are overwritten with the current step's data
            (as numpy arrays) on every energy-calculating step; it is also
            passed to the ``'sample_plugin'`` plugin, if registered
        blocks (list): used as storage of blocks. If not given, the iterator
            uses a local storage.
        block_size (int): size of a block (a sequence of samples)
        equilibrate (bool or int): if false, local energies are calculated and
            accumulated from the first sampling step, if true equilibrium is
            detected automatically, if integer argument, specifies number of
            equilibration steps

    Yields:
        ``(step, 'eq')`` once, at the step where energy accumulation starts
        (``(0, 'eq')`` when ``equilibrate`` is falsy); then ``(step, energy)``
        on every step from that point on, where ``energy`` is the current
        ``ufloat`` estimate, or ``None`` until the first block is completed.
    """
    blocks = blocks if blocks is not None else []
    # Falsy ``equilibrate`` (False or 0) skips equilibration entirely.
    calculating_energy = not equilibrate
    # Local energies of the block currently being filled.
    buffer = []
    energy = None
    for step, (rs, log_psis, _, info) in zip(steps, sampler):
        if step == 0:
            # Rolling history (5 blocks long) of the mean pairwise distance
            # (presumably between electrons) used for equilibrium detection.
            dist_means = rs.new_zeros(5 * block_size)
            if not equilibrate:
                yield 0, 'eq'
        # Shift the history left by one and append this step's value.
        dist_means[:-1] = dist_means[1:].clone()
        dist_means[-1] = pairwise_self_distance(rs).mean()
        if not calculating_energy:
            # ``type(...) is int`` deliberately excludes True/False.
            if type(equilibrate) is int:
                if step >= equilibrate:
                    calculating_energy = True
            # The history is full once its oldest entry is non-zero
            # (assuming distances are positive).
            elif dist_means[0] > 0:
                # Declare equilibrium when the oldest block of the history
                # fluctuates less than the newest one.
                if dist_means[:block_size].std(
                ) < dist_means[-block_size:].std():
                    calculating_energy = True
            if calculating_energy:
                yield step, 'eq'
        if calculating_energy:
            Es_loc = local_energy(rs, wf, keep_graph=False)[0]
            buffer.append(Es_loc)
            if log_dict is not None:
                log_dict['coords'] = rs.cpu().numpy()
                log_dict['E_loc'] = Es_loc.cpu().numpy()
                log_dict['log_psis'] = log_psis.cpu().numpy()
            if 'sample_plugin' in PLUGINS:
                PLUGINS['sample_plugin'](wf, rs, log_dict)
            if len(buffer) == block_size:
                # Block complete: element-wise (over the batch) mean and
                # standard error of the local energy across the block's steps.
                buffer = torch.stack(buffer)
                block = unp.uarray(
                    buffer.mean(dim=0).cpu(),
                    buffer.std(dim=0).cpu() / np.sqrt(len(buffer)),
                )
                blocks.append(block)
                buffer = []
            # ``buffer`` is empty only right after a block was completed, so
            # the energy estimate is refreshed once per block.
            if not buffer:
                # Assuming Es_loc is 1-D over walkers, blocks_arr is
                # (walkers, n_blocks); the error is the spread over walkers of
                # each walker's block-averaged energy / sqrt(walkers).
                blocks_arr = unp.nominal_values(np.stack(blocks, -1))
                err = blocks_arr.mean(-1).std() / np.sqrt(len(blocks_arr))
                energy = ufloat(blocks_arr.mean(), err)
        if writer:
            if calculating_energy:
                # NOTE(review): ``energy_offset`` is not defined in this
                # function; presumably a module-level value -- confirm.
                writer.add_scalar('E_loc/mean',
                                  Es_loc.mean() - energy_offset, step)
                writer.add_scalar('E_loc/var', Es_loc.var(), step)
                writer.add_scalar('E_loc/min', Es_loc.min(), step)
                writer.add_scalar('E_loc/max', Es_loc.max(), step)
                if not buffer:
                    writer.add_scalar('E/value',
                                      energy.nominal_value - energy_offset,
                                      step)
                    writer.add_scalar('E/error', energy.std_dev, step)
            if write_figures:
                from matplotlib.figure import Figure

                fig = Figure(dpi=300)
                ax = fig.subplots()
                ax.hist(log_psis.cpu(), bins=100)
                writer.add_figure('log_psi', fig, step)
                fig = Figure(dpi=300)
                ax = fig.subplots()
                ax.hist(info['age'], bins=100)
                writer.add_figure('age', fig, step)
                if calculating_energy:
                    fig = Figure(dpi=300)
                    ax = fig.subplots()
                    ax.hist(Es_loc.cpu(), bins=100)
                    writer.add_figure('E_loc', fig, step)
                    if not buffer:
                        fig = Figure(dpi=300)
                        ax = fig.subplots()
                        ax.hist(blocks_arr.flatten(), bins=100)
                        writer.add_figure('E_block', fig, step)
        if calculating_energy:
            yield step, energy
Ejemplo n.º 15
0
class PairPanel(wx.Panel):
    """
    A panel that displays pair plots of the dataframe, coloured by a
    user-selected "hue" column.

    Args:
        df --> pandas dataframe: passed internally for plotting. Despite the
            ``None`` default it is required, since ``__init__`` reads
            ``df.columns``.

    Returns: None
    """
    def __init__(self, parent, df=None):
        wx.Panel.__init__(self, parent)

        self.df = df
        self.available_columns = list(self.df.columns)
        self.hue_columns = self._get_hue_column()

        self.figure = Figure()
        self.axes = self.figure.add_subplot(111)
        self.canvas = FigureCanvas(self, -1, self.figure)

        self.toolbar = NavigationToolbar(self.canvas)

        self.text_hue = wx.StaticText(self, label="Hue:")
        self.dropdown_menu = wx.ComboBox(self,
                                         choices=self.hue_columns,
                                         style=wx.CB_READONLY)
        self.Bind(wx.EVT_COMBOBOX, self.column_selected)

        # Hue selector and matplotlib toolbar share one row under the canvas.
        toolbar_sizer = wx.BoxSizer(wx.HORIZONTAL)
        toolbar_sizer.Add(self.text_hue, 0, wx.ALL | wx.ALIGN_CENTER, 5)
        toolbar_sizer.Add(self.dropdown_menu, 0, wx.ALL | wx.ALIGN_CENTER, 5)
        toolbar_sizer.Add(self.toolbar, 0, wx.ALL, 5)

        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.GROW)
        sizer.Add(toolbar_sizer)
        self.SetSizer(sizer)
        self.Fit()

        # Refresh the hue choices whenever the displayed columns change.
        pub.subscribe(self.update_available_column, "UPDATE_DISPLAYED_COLUMNS")

    def column_selected(self, event):
        """
        Handle a selection in the hue dropdown menu by redrawing the pair
        plot with the chosen column as hue.
        """

        selected_column_id = self.dropdown_menu.GetCurrentSelection()
        if selected_column_id == wx.NOT_FOUND:
            # Nothing selected: indexing with -1 would silently pick the last
            # hue column instead.
            return
        selected_column = self.hue_columns[selected_column_id]

        self.draw_pair(selected_column)

    def draw_pair(self, column_name):
        """
        Draw the pair plot for the given hue column in the panel.

        Note:
            Seaborn's pairplot returns a grid of subplots in its own figure,
            so it cannot be drawn directly into the existing figure.
            Instead, the pairplot is only used as a template (grid size and
            label order): the same number of matplotlib subplots is added to
            this panel's figure and each cell is drawn there with seaborn.
            The template figure is closed afterwards so pyplot does not keep
            accumulating one figure per redraw.

        Args:
            column_name --> string: the name of the column used as hue
        Returns: None
        """
        # seaborn already depends on pyplot; it is only needed here to close
        # the template figure created by sns.pairplot.
        import matplotlib.pyplot as plt

        # Reset plot, clean the axes
        self.figure.clf()

        legend_labels = self.df[column_name].unique()
        legend_title = column_name

        df = prepare_data(self.df[self.available_columns])

        if str(self.df[column_name].dtype) == "object":
            # Categorical hue: use the numeric "<column>_code" column, which
            # prepare_data is expected to add (NOTE(review): confirm).
            column_name += "_code"

        pub.sendMessage("LOG_MESSAGE", log_message="\nReady to plot...")

        pair_plot = None
        try:
            # Produce the pairplot template using seaborn
            pair_plot = sns.pairplot(
                df,
                hue=column_name,
                palette="deep",
                size=1.2,
                diag_kind="kde",
                diag_kws=dict(shade=True),
                plot_kws=dict(s=10),
            )

            # Get the number of rows and columns from the seaborn pairplot
            pp_rows = len(pair_plot.axes)
            pp_cols = len(pair_plot.axes[0])

            # Update axes to the corresponding number of subplots from pairplot
            self.axes = self.figure.subplots(pp_rows, pp_cols)

            # Get the label and plotting order
            xlabels, ylabels = [], []
            for ax in pair_plot.axes[-1, :]:
                xlabel = ax.xaxis.get_label_text()
                xlabels.append(xlabel)
            for ax in pair_plot.axes[:, 0]:
                ylabel = ax.yaxis.get_label_text()
                ylabels.append(ylabel)

            # Setup hue for plots
            hue_values = df[column_name].unique()
            palette = sns.color_palette("muted")  # seaborn "muted" palette
            legend_color = palette.as_hex()

            # Mimic how seaborn produces the pairplot using matplotlib subplots
            for i in range(len(xlabels)):
                for j in range(len(ylabels)):
                    if i != j:
                        # Off-diagonal cells: scatter plot per hue value
                        for num, value in enumerate(hue_values):
                            sns.regplot(
                                x=df[xlabels[i]][df[column_name] == value],
                                y=df[ylabels[j]][df[column_name] == value],
                                data=df,
                                scatter=True,
                                fit_reg=False,
                                ax=self.axes[j, i],
                                scatter_kws={
                                    's': 10,  # Set dot size
                                    'facecolor':
                                    legend_color[num],  # Set dot color
                                })
                    else:
                        # Diagonal cells: distribution plot per hue value
                        for num, value in enumerate(hue_values):
                            sns.kdeplot(
                                df[xlabels[i]][df[column_name] == value],
                                ax=self.axes[j, i],
                                color=legend_color[num],
                                legend=False,
                                shade=True,
                            )

                    # Set plot labels, only on the outer plots to avoid
                    # label overlapping
                    if i == 0:
                        if j == len(xlabels) - 1:
                            # Case for bottom left corner
                            self.axes[j, i].set_xlabel(xlabels[i], fontsize=8)
                        else:
                            self.axes[j, i].set_xlabel("")
                            self.axes[j, i].set_xticklabels(
                                [])  # Turn off tick labels
                        self.axes[j, i].set_ylabel(ylabels[j], fontsize=8)
                    elif j == len(xlabels) - 1:
                        self.axes[j, i].set_xlabel(xlabels[i], fontsize=8)
                        self.axes[j, i].set_ylabel("")
                        self.axes[j, i].set_yticklabels(
                            [])  # Turn off tick labels
                    else:
                        # Hide labels
                        self.axes[j, i].set_xlabel("")
                        self.axes[j, i].set_ylabel("")

                        # Turn off tick labels
                        self.axes[j, i].set_xticklabels([])
                        self.axes[j, i].set_yticklabels([])

            end_message = "Pair plots finished"
            pub.sendMessage("LOG_MESSAGE", log_message=end_message)

            # NOTE(review): the kde artists above are drawn without labels, so
            # these handles may be empty and the legend may not show the hue
            # colours -- confirm.
            handles, _ = self.axes[0, 0].get_legend_handles_labels()
            self.figure.legend(
                handles,
                labels=legend_labels,
                title=legend_title,
                loc='center right',
            )

            self.figure.subplots_adjust(
                left=0.03,  # the left side of the subplots of the figure
                bottom=0.08,  # the bottom of the subplots of the figure
                right=0.93,  # the right side of the subplots of the figure
                top=0.97,  # the top of the subplots of the figure
                wspace=
                0.12,  # the amount of width reserved for space between subplots
                hspace=
                0.12,  # the amount of height reserved for space between subplots
            )

        except ValueError as e:
            # log Error
            _log_message = "\nPair plots failed due to error:\n--> {}".format(
                e)
            pub.sendMessage("LOG_MESSAGE", log_message=_log_message)
        finally:
            # The template figure lives in pyplot's figure manager and is
            # never displayed; close it so redraws do not leak figures.
            if pair_plot is not None:
                plt.close(pair_plot.fig)

        self.canvas.draw()
        self.Refresh()

    def _get_hue_column(self):
        """
        Limit the columns offered for hue selection by filtering out columns
        with too many distinct values.

        The limit is currently 6 distinct values, chosen by the author to
        match the number of distinct colours expected from the seaborn
        palette.

        Args: None

        Returns:
            hue_columns --> list: a list of column headers
        """

        # Restrict hue selection based on distinct values in column
        return [
            column for column in self.available_columns
            if self.df[column].nunique() <= 6
        ]

    def update_available_column(self, available_columns):
        """
        Update the columns used for plotting and refresh the hue dropdown.

        Args:
            available_columns --> list: a list of available column headers

        Returns: None
        """

        self.available_columns = available_columns

        # Update hue column selection
        self.hue_columns = self._get_hue_column()

        # Update dropdown menu
        self.dropdown_menu.Clear()
        for column in self.hue_columns:
            self.dropdown_menu.Append(column)
Ejemplo n.º 16
0
def graficar_dists(dists, valores=None, rango=None, título=None, archivo=None):
    """
    Plot one or more distributions, optionally with a histogram of values,
    and save the figure as a PNG file.

    :param dists: The distribution(s) to plot. Strings are converted with
      ``VarSciPy.de_texto``; ``VarCalib`` objects draw themselves on two
      side-by-side axes.
    :type dists: list[str, VarCalib] | str | VarCalib

    :param valores: Optional numpy array of values; a density-normalised
      histogram of them is drawn on the axes of each non-``VarCalib``
      distribution.
    :type valores: np.ndarray

    :param rango: Optional ``(low, high)`` range to shade under each
      non-``VarCalib`` curve (the order of the bounds does not matter).
    :type rango: tuple

    :param título: The title of the plot, if any. Also used as the file name
      when ``archivo`` is a directory.
    :type título: str

    :param archivo: Where to save the figure: a path ending in ``.png``
      (case-insensitive), or a directory, in which case the file is
      ``<archivo>/<título>.png``. Required: the figure is rendered off-screen
      and is never displayed interactively.
    :type archivo: str

    :raises ValueError: if ``archivo`` is missing, or if ``archivo`` is a
      directory and ``título`` is missing.
    :raises TypeError: if a distribution is of an unrecognised type.
    """

    # Resolve the output path before plotting so a missing argument fails
    # fast with a clear message instead of an opaque TypeError at save time.
    if archivo is None:
        raise ValueError(
            'Hay que especificar `archivo` para guardar el gráfico.')
    if not archivo.lower().endswith('.png'):
        if título is None:
            raise ValueError(
                'Hay que especificar `título` cuando `archivo` es un directorio.'
            )
        archivo = os.path.join(archivo, título + '.png')

    if not isinstance(dists, list):
        dists = [dists]

    n = 100000  # number of points at which each pdf is evaluated

    fig = Figura()
    TelaFigura(fig)

    # Put each distribution in the plot
    for dist in dists:

        if isinstance(dist, VarCalib):
            ejes = fig.subplots(1, 2)

            dist.dibujar(ejes=ejes)

            # If a title was given, use it as the figure title
            if título is not None:
                fig.suptitle(título)

        else:

            if isinstance(dist, str):
                dist = VarSciPy.de_texto(texto=dist)

            if isinstance(dist, VarSciPy):
                # Evaluate the pdf between the 1st and 99th percentiles
                x = np.linspace(dist.percentiles(0.01), dist.percentiles(0.99),
                                n)
                y = dist.fdp(x)
            else:
                raise TypeError(
                    'El tipo de distribución "%s" no se reconoce como distribución aceptada.'
                    % type(dist))

            ejes = fig.add_subplot(111)

            # Draw the distribution
            ejes.plot(x, y, 'b-', lw=2, alpha=0.6)

            # Shade a range, if requested
            if rango is not None:
                if rango[1] < rango[0]:
                    rango = (rango[1], rango[0])
                dentro = (rango[0] <= x) & (x <= rango[1])
                ejes.fill_between(x[dentro],
                                  0,
                                  y[dentro],
                                  color='blue',
                                  alpha=0.2)

            # If values were given, draw a histogram of them
            if valores is not None:
                valores = valores.astype(float)
                # ``density`` replaces ``normed`` (removed in Matplotlib 3.1);
                # it exists in every Matplotlib that has ``Figure.subplots``.
                ejes.hist(valores,
                          density=True,
                          color='green',
                          histtype='stepfilled',
                          alpha=0.2)

            # If a title was given, use it as the axes title
            if título is not None:
                ejes.set_title(título)

    # Save the figure
    valid_archivo(archivo)

    fig.savefig(archivo)
Ejemplo n.º 17
0
def do_it(cfg):
    """Render a fan of windowed sinusoids and save it as ``f"{cfg.dim}.png"``.

    The image is ``cfg.dim`` x ``cfg.dim`` pixels at 96 DPI, with a
    transparent background and no axis decorations. ``cfg.nline`` curves are
    drawn with geometrically spaced frequencies in [0.2, 1]. Each curve is
    multiplied by a Gaussian-times-cosine window and drawn with line width
    ``cfg.line_width``. Its colour is ``cmap(k)`` for an integer ``k``, where
    ``cmap`` comes from the enclosing module.

    Args:
        cfg: configuration object providing ``dim``, ``nline`` and
            ``line_width``.
    """
    dpi = 96
    two_pi = 2 * np.pi
    n_points = 1000
    x_max = 1

    fig = Figure()
    FigureCanvasAgg(fig)
    # A single axes filling the whole canvas, with every decoration hidden.
    ax: Axes = fig.subplots(
        1,
        1,
        subplot_kw=dict(xticks=[], yticks=[]),
        gridspec_kw=dict(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0),
    )
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)
    ax.set_axis_off()
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)

    fig.set_dpi(dpi)
    side = cfg.dim / dpi
    fig.set_size_inches(side, side)

    xs = np.linspace(-x_max, x_max, n_points)

    def window(samples):
        """Gaussian (width 0.6) times a quarter-period cosine on [-1, 1]."""
        assert samples[0] == -1
        assert samples[-1] == 1
        width = 0.6
        power = 1
        gaussian = np.exp(-(samples / width) ** 2)
        cosine = np.cos(two_pi * (samples / 4))
        return gaussian * cosine ** power

    def plot_sinusoid(dx, freq, yscale, alpha, color=None):
        # sin(2*pi*(x - dx)*freq) * yscale, tapered by the window.
        wave = np.sin(two_pi * ((xs - dx) * freq)) * yscale
        ax.plot(xs,
                wave * window(xs),
                alpha=alpha,
                color=color,
                linewidth=cfg.line_width)

    top = "narrow"
    blue = "narrow"
    # When the two settings agree, colour indices run from nline-1 down to 0;
    # otherwise they run from 0 up to nline-1.
    if top == blue:
        color_index, color_step = cfg.nline - 1, -1
    else:
        color_index, color_step = 0, 1

    freqs = np.geomspace(0.2, 1, cfg.nline)
    if top == "wide":
        freqs = freqs[::-1]

    amplitude_exponent = 0  # 0 -> every curve gets amplitude freq**0 == 1

    for freq in freqs:
        plot_sinusoid(0,
                      freq=freq,
                      yscale=freq ** amplitude_exponent,
                      alpha=1,
                      color=cmap(color_index))
        color_index += color_step

    fig.savefig(f"{cfg.dim}.png", transparent=True)
Ejemplo n.º 18
0
class ImageToTensorBoard(ToTensorBoard):
    """Render a figure with a user-supplied plotting function and log it as an image."""

    def __init__(
        self,
        log_dir: str,
        plotting_function: Callable[
            ["matplotlib.figure.Figure", "matplotlib.figure.Axes"], "matplotlib.figure.Figure"
        ],
        name: Optional[str] = None,
        *,
        fig_kw: Optional[Dict[str, Any]] = None,
        subplots_kw: Optional[Dict[str, Any]] = None,
    ):
        """
        :param log_dir: directory in which to store the tensorboard files.
            Can be nested: for example, './logs/my_run/'.
        :param plotting_function: function performing the plotting. It is
            called as ``plotting_function(fig, axes)`` on every run, after the
            axes have been cleared; its return value is ignored.
        :param name: name used in TensorBoard.
        :params fig_kw: keyword arguments to be passed to Figure constructor, e.g. `figsize`.
        :params subplots_kw: keyword arguments to be passed to figure.subplots constructor, e.g.
            `nrows`, `ncols`, `sharex`, `sharey`. If not given (or empty), a
            single Axes is created with ``add_subplot(111)``.
        :raises RuntimeError: if matplotlib is not installed.
        """
        super().__init__(log_dir)
        self.plotting_function = plotting_function
        self.name = name
        self.fig_kw = fig_kw or {}
        self.subplots_kw = subplots_kw or {}

        try:
            from matplotlib.figure import Figure
        except ImportError as err:
            # Chain the original error so the underlying import failure is
            # still visible in the traceback.
            raise RuntimeError(
                "ImageToTensorBoard requires the matplotlib package to be installed"
            ) from err

        self.fig = Figure(**self.fig_kw)
        if self.subplots_kw:
            # May be a single Axes or a numpy array of Axes.
            self.axes = self.fig.subplots(**self.subplots_kw)
        else:
            self.axes = self.fig.add_subplot(111)

    def _clear_axes(self):
        """Clear the figure's axes, whether a single Axes or an array of them."""
        if isinstance(self.axes, np.ndarray):
            for ax in self.axes.flat:
                ax.clear()
        else:
            self.axes.clear()

    def run(self, **unused_kwargs):
        """Redraw the figure, encode it as PNG and write it to TensorBoard."""
        from matplotlib.backends.backend_agg import FigureCanvasAgg

        self._clear_axes()
        self.plotting_function(self.fig, self.axes)
        canvas = FigureCanvasAgg(self.fig)
        canvas.draw()

        # get PNG data from the figure
        with BytesIO() as png_buffer:
            canvas.print_png(png_buffer)
            png_encoded = png_buffer.getvalue()

        # decode_png gives (height, width, channels); [None] adds the batch
        # dimension expected by tf.summary.image.
        image_tensor = tf.io.decode_png(png_encoded)[None]

        # Write to TensorBoard. NOTE(review): ``current_step`` is presumably
        # maintained by the ToTensorBoard base class -- confirm.
        tf.summary.image(self.name, image_tensor, step=self.current_step)
Ejemplo n.º 19
0
 def dna_logo(self, save=False, show=False, count=True, ax=None):
     """Draw a DNA sequence logo from this object's per-position frequencies.

     Each position gets a stack of the letters A, G, C, T and '-' (drawn as
     'I'). Each letter's height is its frequency times the position's
     information content, and the tallest letter is on top.

     Args:
         save: ``False`` to skip saving; ``True`` to save to
             ``'<name>_logo.svg'``; a string to save to that path. The file
             is always written in SVG format. When ``ax`` is supplied, the
             figure containing ``ax`` is saved.
         show: if true, call ``plt.tight_layout()`` and ``plt.show()``.
         count: forwarded to ``self._freq`` and ``self.entropy``.
         ax: optional existing Axes to draw into. When given, no new figure
             is created and the x tick labels are position numbers only.
             Otherwise a new 7x2 inch Figure is created, and each tick label
             also shows the tallest (top-stacked) letter at that position.

     Returns:
         The newly created Figure, or ``None`` when ``ax`` was supplied.
     """
     # assumes rows of self._freq(count) are positions and columns follow
     # the A, G, C, T, - order used below -- TODO confirm
     freq = np.array(self._freq(count))
     # Small-sample correction: 2.88539008 == 4 / (2 * ln 2), i.e.
     # (s - 1) / (2 ln2 * n) for s = 5 symbols -- NOTE(review): confirm intent.
     en = 2.88539008 / max(sum(self.count), 1.5)
     # log2(5) minus the per-position entropy (presumably in bits) and the
     # correction, shifted so that no value is negative.
     info = (np.log(5) / np.log(2) - en - self.entropy(count=count))
     if np.min(info) < 0:
         info -= np.min(info)
     height = (freq.T * info).T
     order = ['A', 'G', 'C', 'T', '-']
     # Per position: (letter, height) pairs sorted ascending so the tallest
     # letter is stacked last, i.e. on top.
     height = [
         sorted(list(zip(order, i)), key=lambda x: x[1]) for i in height
     ]
     fp = FontProperties(family="monospace", weight="bold")
     LETTERS = {
         "T": TextPath((-0.395, 0), "T", size=1.35, prop=fp),
         "G": TextPath((-0.395, 0), "G", size=1.35, prop=fp),
         "A": TextPath((-0.395, 0), "A", size=1.35, prop=fp),
         "C": TextPath((-0.395, 0), "C", size=1.35, prop=fp),
         "-": TextPath((-0.395, 0), "I", size=1.35, prop=fp)
     }
     COLOR_SCHEME = {
         'G': 'orange',
         'A': 'red',
         'C': 'blue',
         'T': 'darkgreen',
         '-': 'black'
     }
     if ax is not None:
         drawx = False
         fig = None
     else:
         fig = Figure(figsize=(7, 2))
         ax = fig.subplots()
         drawx = True
     for x, scores in enumerate(height):
         y = 0
         for letter, score in scores:
             text = LETTERS[letter]
             # Scale the glyph vertically to its score and place it at
             # (x + 1, y) in data coordinates, stacking upward.
             t = mpl.transforms.Affine2D().scale(1.2, score) + \
                 mpl.transforms.Affine2D().translate(x+1, y) + ax.transData
             p = PathPatch(text, lw=0, fc=COLOR_SCHEME[letter], transform=t)
             ax.add_artist(p)
             y += score
     x_tick_label = ([
         str(k + 1) + '\n' + i[-1][0] for k, i in enumerate(height)
     ] if drawx else [str(k + 1) for k, i in enumerate(height)])
     ax.set_title('{} Total count: {}'.format(str(self.name),
                                              sum(self.count)),
                  fontsize=6)
     ax.set_xticks(range(1, len(height) + 1), )
     ax.set_xticklabels(x_tick_label)
     ax.set_xlim((0, len(height) + 1))
     ax.set_ylim((0, 2.33))
     ax.tick_params(axis='both', which='both', labelsize=6)
     if save:
         if fig is not None:
             fig.set_tight_layout(True)
             target_fig = fig
         else:
             # fig is None when the caller supplied ``ax``; save the figure
             # that actually contains the logo instead of crashing.
             target_fig = ax.figure
         save = save if isinstance(save,
                                   str) else str(self.name) + '_logo.svg'
         target_fig.savefig(save, format='svg')
     if show:
         # NOTE(review): these pyplot calls act on pyplot's current figure; a
         # Figure created with ``Figure(...)`` above is not registered with
         # pyplot, so this may not display the logo -- confirm.
         plt.tight_layout()
         plt.show()
     return fig
Ejemplo n.º 20
0
def inference():
    """Flask view: classify one MNIST test image and render the result page.

    Reads the ``testId`` query parameter (an index into the MNIST test split)
    and runs the saved ``mnist_cnn.pt`` model on that image on the CPU. It
    then renders ``inference.html`` with the image, a plot of the class
    probabilities, the ground-truth label and the predicted label.

    Returns:
        The rendered template, or an ``(error message, 400)`` tuple when
        ``testId`` is missing, not an integer, or out of range.
    """
    testId = request.args.get('testId', None)
    if testId is None:
        return "ERROR: Invalid testId Parameter", 400

    # A non-numeric testId is a client error, not a server crash.
    try:
        test_id = int(testId)
    except ValueError:
        return "ERROR: Invalid testId Parameter", 400

    dataset = MNIST(config.DATA_DIR, train=False)
    if test_id < 0 or test_id >= dataset.data.shape[0]:
        return "ERROR: Out of Range testId Parameter", 400

    test_img = image_to_base64(get_pil_image_from_dataset(dataset, test_id))

    device = torch.device('cpu')
    model = MNISTNet()
    model.load_state_dict(
        torch.load(join(config.MODEL_DIR, 'mnist_cnn.pt'),
                   map_location=device))
    model.eval()

    normalized_data = MNIST(config.DATA_DIR,
                            train=False,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.1307, ), (0.3081, ))
                            ]))
    with torch.no_grad():
        # Index the dataset rather than reading ``.data``: only __getitem__
        # applies the ToTensor + Normalize transform above. ``.data`` holds
        # the raw, un-normalized pixels.
        data, _ = normalized_data[test_id]
        # Add the batch dimension: (1, 28, 28) -> (1, 1, 28, 28).
        data = data.unsqueeze(0).to(device)

        output = model(data)
        # index of the max log-probability
        pred = output.argmax(dim=1, keepdim=True)

    # The outputs are treated as log-probabilities; exp() yields probabilities.
    output_val = np.exp(output.numpy()).squeeze()
    pred_idx = int(pred.numpy()[0][0])
    gt_label = dataset.classes[int(dataset.targets[test_id])]
    pred_label = dataset.classes[pred_idx]

    fig = Figure(figsize=(10, 10), dpi=100)
    FigureCanvas(fig)  # attach a canvas so the figure can be rendered
    ax = fig.subplots()
    x = np.arange(0, 10, 1)
    ax.set_xticks(x)
    # Set ticks labels for x-axis
    ax.set_xticklabels(dataset.classes, rotation='vertical', fontsize=18)
    ax.plot(output_val)

    inf_img = figure_to_base64(fig)

    return render_template('inference.html',
                           test_id=test_id,
                           test_img=test_img,
                           inf_img=inf_img,
                           gt_label=gt_label,
                           pred_label=pred_label)
Ejemplo n.º 21
0
class Plot:
    """Live matplotlib view of the samples streamed by an HLW8012 sensor.

    ``data_handler`` appends incoming samples to ``self.values`` while
    holding ``self.lock``. ``plot_values`` is the ``FuncAnimation`` callback
    that redraws the figure from a snapshot of those lists.

    After ``init_plot`` runs, ``self.values`` holds six parallel lists:
    0 time in seconds, 1 sensor, 2 avg, 3 integral (the labels used when
    plotting), 4 the displayed value (current, voltage or power) and
    5 the noise level.
    """

    def __init__(self, controller):
        self.controller = controller
        self.lock = Lock()
        self.update_plot = False
        self.label = {'x': 'time in seconds', 'y': ''}
        self.title = ''
        self.ani = None
        self.sensor_config = None
        self.incoming_data = {"start": 0.0, "end": 0.0, "data": 0}

        self.fig = Figure(figsize=(12, 8), dpi=100)
        self.ax1 = self.fig.add_subplot(111)
        # False (not None) means "no noise axes"; plot_values checks for it.
        self.ax2 = False

    def set_animation(self, state):
        """Create the animation on first use and record the running state.

        When ``state`` is truthy, a redraw is requested and the animation's
        event source is (re)started. plot_values stops the event source
        once it sees the animation is no longer running.
        """
        if self.ani is None:
            self.ani = animation.FuncAnimation(
                self.fig,
                self.plot_values,
                interval=self.controller.get_update_rate())
        self.running = state
        if state:
            self.update_plot = True
            self.ani.event_source.start()

    def init_plot(self, type):
        """Reset the collected data and read the plot settings from the GUI.

        The ``type`` argument is not used. The plot type is taken from
        ``controller.gui_settings['plot_type']``. Out-of-range retention and
        compression settings are clamped and written back to the GUI.
        """
        self.stats = ['', '', '', '']
        self.incoming_data["start"] = time.time()
        self.incoming_data["data"] = 0
        self.voltage = 0
        self.current = 0
        self.power = 0
        self.pf = 0
        self.energy = 0
        self.noiseLevel = 0
        self.dimLevel = 0
        # Parallel series: time, sensor, avg, integral, display, noise.
        self.values = [[], [], [], [], [], []]
        self.start_time = 0
        self.max_time = 0
        self.plot_type = self.controller.gui_settings['plot_type']

        # Samples before this index have already been compressed.
        self.compress_start = 0
        # Seconds of history to keep (at least 5).
        self.data_retention = int(
            self.controller.gui_settings['retention'].get())
        if self.data_retention < 5:
            self.data_retention = 5
            self.controller.gui_settings['retention'].set('5')
        # Keep the newest 10% of the history, at most 5 seconds, uncompressed.
        self.data_retention_keep = self.data_retention * 0.1
        if self.data_retention_keep > 5:
            self.data_retention_keep = 5
        # Compression window in seconds: 0 disables compression, any other
        # value is clamped to 10 ms .. 1 s.
        self.data_retention_compress = float(
            self.controller.gui_settings['compress'].get())
        if self.data_retention_compress != 0:
            if self.data_retention_compress < 0.01:
                self.data_retention_compress = 0.01
                self.controller.gui_settings['compress'].set('0.01')
            elif self.data_retention_compress > 1:
                self.data_retention_compress = 1
                self.controller.gui_settings['compress'].set('1')

        self.convert_units = self.controller.gui_settings['convert_units'].get(
        )
        # Raw values are shown in µs unless unit conversion is enabled.
        self.label['y'] = "µs"
        if self.convert_units == 1:
            if self.plot_type == 'I':
                self.label['y'] = "A"
            elif self.plot_type == 'U':
                self.label['y'] = "V"
            elif self.plot_type == 'P':
                self.label['y'] = "W"

        self.fig.clear()
        if self.controller.gui_settings['noise'].get() == 1:
            # Second row of axes for the noise level.
            (self.ax1, self.ax2) = self.fig.subplots(2, 1)
        else:
            self.ax1 = self.fig.subplots(1, 1)
            self.ax2 = False

    def plot_values(self, i):
        """FuncAnimation callback: redraw the axes if new data has arrived.

        ``i`` (the frame index passed by FuncAnimation) is not used.
        """
        try:
            running = self.ani.running
        except AttributeError:
            # No animation yet, or it exposes no ``running`` flag.
            # NOTE(review): set_animation() stores its flag in self.running,
            # while this reads self.ani.running; confirm which one is meant.
            return
        if not running:
            self.ani.event_source.stop()
            return
        if not self.update_plot:
            return

        # Non-blocking acquire: if data_handler holds the lock, skip this frame.
        if not self.lock.acquire(False):
            print("plot_values() could not aquire lock")
            return
        try:
            max_time = self.max_time
            x_left = max_time - self.data_retention
            x_right = max_time
            # Snapshot so that drawing happens without holding the lock.
            values = [series.copy() for series in self.values]
        finally:
            self.lock.release()

        self.update_plot = False
        self.ax1.clear()

        label_x = ''
        if self.sensor_config is not None:
            fmt = " - Imin. {0:}A Imax {1:}A Rshunt {2:} - Calibration U/I/P {3:} {4:} {5:}"
            label_x = fmt.format(self.sensor_config["Imin"],
                                 self.sensor_config["Imax"],
                                 self.sensor_config["Rs"],
                                 *self.sensor_config["UIPc"])

        self.ax1.set_title('HLW8012 - ' + self.title)
        self.ax1.set_xlabel(self.label['x'] + label_x)
        self.ax1.set_ylabel(self.label['y'])
        self.ax1.set_xlim(left=x_left, right=x_right)

        if len(values[0]) == 0:
            return

        # Optional fields are appended only when they carry information.
        fmt = "{0:.2f}V {1:.4f}A {2:.2f}W pf {3:.2f} {4:.3f}kWh"
        if self.noiseLevel != 0:
            fmt = fmt + " noise {5:.3f}"
        fmt = fmt + " data {6:d} ({7:.2f}/s)"
        if self.dimLevel != -1:
            fmt = fmt + " level {8:.1f}%"
        try:
            # Incoming data points per second since init_plot().
            dpps = self.incoming_data["data"] / (
                self.incoming_data["end"] - self.incoming_data["start"])
            self.fig.suptitle(fmt.format(self.voltage, self.current,
                                         self.power, self.pf, self.energy,
                                         self.noiseLevel / 1000.0,
                                         len(values[0]), dpps,
                                         self.dimLevel * 100.0),
                              fontsize=16)
        except Exception:
            # Best effort (e.g. ZeroDivisionError before any timing data):
            # leave the title unchanged.
            pass

        try:
            stats = []
            # Choose 2..4 decimals depending on the magnitude of the integral.
            n = max(values[3]) / 10
            digits = 2
            while n < 1 and digits < 4:
                n = n * 10
                digits = digits + 1
            fmt = '.' + str(digits) + 'f'
            for idx in range(1, 5):
                tmp = (' min/max ' + format(min(values[idx]), fmt) + '/' +
                       format(max(values[idx]), fmt) + ' ' + u"\u2300" +
                       ' ' + format(mean(values[idx]), fmt))
                stats.append(tmp)
            self.stats = stats
        except Exception:
            # Statistics are cosmetic; keep the previous ones on failure.
            pass

        y_max = 0
        if self.controller.get_data_state(0):
            self.ax1.plot(values[0],
                          values[1],
                          'g',
                          label='sensor' + self.stats[0],
                          linewidth=0.1)
            y_max = max(y_max, max(values[1]))
        if self.controller.get_data_state(1):
            self.ax1.plot(values[0],
                          values[2],
                          'b',
                          label='avg' + self.stats[1])
            y_max = max(y_max, max(values[2]))
        if self.controller.get_data_state(2):
            self.ax1.plot(values[0],
                          values[3],
                          'r',
                          label='integral' + self.stats[2])
            y_max = max(y_max, max(values[3]))
        if self.controller.get_data_state(3):
            self.ax1.plot(values[0],
                          values[4],
                          'c',
                          label='display' + self.stats[3])
            y_max = max(y_max, max(values[4]))

        # Dashed reference lines for the sensor's current limits.
        i_min = 0
        if self.sensor_config is not None and self.plot_type == 'I':
            if self.convert_units == 1:
                i_min = self.sensor_config["Imin"]
                self.ax1.hlines(y=self.sensor_config["Imax"],
                                xmin=x_left,
                                xmax=x_right,
                                linestyle='dashed')
            else:
                i_min = self.sensor_config["Ipmax"]
                self.ax1.hlines(y=self.sensor_config["Ipmin"],
                                xmin=x_left,
                                xmax=x_right,
                                linestyle='dashed')

            y_max = max(y_max, i_min)
            self.ax1.hlines(y=i_min,
                            xmin=x_left,
                            xmax=x_right,
                            linestyle='dashed')

        # The controller's y range is a percentage of the top of the data.
        y_min = y_max * 0.98 * self.controller.get_y_range() / 100.0
        y_max = y_max * 1.02
        self.ax1.set_ylim(top=y_max, bottom=y_min)
        self.ax1.legend(loc='lower left')

        if self.ax2 is not False:
            self.ax2.clear()
            self.ax2.set_xlim(left=x_left, right=x_right)
            self.ax2.hlines(y=40,
                            xmin=x_left,
                            xmax=x_right,
                            linestyle='dashed')
            self.ax2.plot(values[0], values[5], 'r', label='noise')

    def get_time(self, value):
        """Convert a raw timestamp to seconds since ``self.start_time``.

        Raw timestamps are divided by 1e6 (presumably microseconds).
        """
        return (value - self.start_time) / 1000000.0

    def clean_old_data(self):
        """Drop samples older than ``data_retention`` seconds.

        The newest sample that is older than the limit is kept, so the
        plotted line still reaches the left edge of the axes.
        ``compress_start`` is shifted by the number of deleted samples so
        it keeps pointing at the same sample.
        """
        times = self.values[0]
        if not times:
            return
        min_time = times[-1] - self.data_retention
        for i in range(len(times) - 1, 0, -1):
            if times[i] < min_time:
                # Indices [0, i) are removed; the compressed prefix shrinks
                # accordingly (to 0 if it was removed completely).
                self.compress_start = max(0, self.compress_start - i)
                for series in self.values:
                    del series[:i]
                break

    def _compress_old_data(self):
        """Fold the oldest uncompressed window into a (min, max) pair.

        The window starts at ``compress_start`` and spans
        ``data_retention_compress`` seconds. It is processed only if it ends
        before the newest ``data_retention_keep`` seconds. For every series
        the window's min is stored at ``compress_start``, its max at
        ``compress_start + 1``, and the rest of the window is deleted.
        At most one window is processed per call. Times are assumed to be
        in ascending order.
        """
        times = self.values[0]
        start = self.compress_start
        keep_time = times[-1] - self.data_retention_keep
        compress_time = times[start] + self.data_retention_compress
        if compress_time > keep_time:
            return
        # Index of the first sample newer than the window.
        end = next((n for n in range(start + 1, len(times))
                    if times[n] > compress_time), None)
        if end is None:
            return
        if end - start < 2:
            # A single sample in the window: there is nothing to fold, and
            # writing a max into start + 1 would overwrite the next,
            # newer sample.
            self.compress_start = start + 1
            return
        for series in self.values:
            window = series[start:end]
            series[start] = min(window)
            series[start + 1] = max(window)
            del series[start + 2:end]
        self.compress_start = start + 2

    def data_handler(self, header, data):
        """Append one packet of samples if it matches the current plot type.

        ``header`` is unpacked into (packet id, output mode, data type,
        voltage, current, power, energy, pf, noise level, dim level). The
        packet is used only if ``chr(header[2])`` equals ``self.plot_type``.
        ``data`` is a flat sequence of groups of four values
        (timestamp, sensor, avg, integral).
        """
        self.incoming_data["end"] = time.time()
        self.incoming_data["data"] = self.incoming_data["data"] + len(data)

        with self.lock:
            if chr(header[2]) != self.plot_type:
                return
            (_packet_id, _output_mode, _data_type, self.voltage, self.current,
             self.power, self.energy, self.pf, self.noiseLevel,
             self.dimLevel) = header
            if self.plot_type == 'I':
                display_value = self.current
            elif self.plot_type == 'U':
                display_value = self.voltage
            elif self.plot_type == 'P':
                display_value = self.power

            self.clean_old_data()

            if self.data_retention_compress != 0 and len(
                    self.values[0]) > 100:
                self._compress_old_data()

            # The first timestamp ever received defines t = 0.
            if self.start_time == 0:
                self.start_time = data[0]

            for pos in range(0, len(data), 4):
                self.values[0].append(self.get_time(data[pos]))
                self.values[1].append(data[pos + 1])
                self.values[2].append(data[pos + 2])
                self.values[3].append(data[pos + 3])
                self.values[4].append(display_value)
                self.values[5].append(self.noiseLevel / 1000.0)

            self.max_time = self.values[0][-1]
            self.update_plot = True
Ejemplo n.º 22
0
class App_ui(QtWidgets.QMainWindow):
    """Main window: load a CSV file, show it in a table and plot its columns.

    Widgets referenced as attributes (``mplvl``, ``mplwindow``, the combo
    boxes, text boxes, spin boxes, table view and menu actions) are expected
    to be created by the ``.ui`` file loaded in ``__init__``.
    """

    def __init__(self):
        super(App_ui, self).__init__()  # Call the inherited classes __init__ method
        uic.loadUi('./UI/App_UI_20201227.ui', self)  # Load the .ui file

        plt.style.use('ggplot')
        self.fig = Figure()
        self.dataframe = pd.DataFrame()  # Initialize an empty data frame
        self.dataframe_keys = (self.dataframe.columns)
        self.fname = ''
        self.file_size = None
        self.file_rows, self.file_cols = None, None

        self.canvas = FigureCanvas(self.fig)
        self.mplvl.addWidget(self.canvas)
        # squeeze=False keeps a 2-D array of axes, hence the ax[0,0] indexing.
        self.ax = self.fig.subplots(nrows=1, ncols=1, squeeze=False)
        self.fig.tight_layout()
        self.canvas.draw()

        self.toolbar = NavigationToolbar(self.canvas,
                self.mplwindow, coordinates=True)
        self.mplvl.addWidget(self.toolbar)

        self.show()  # Show the GUI

        btn_open_file = self.actionOpen_File
        btn_open_file.triggered.connect(self.OpenFileDialog)

        self.cmbx_x_axis_data = self.comboBox_x_axis_data
        self.cmbx_y_axis_data = self.comboBox_y_axis_data

        self.textBox_filePath    = self.textEdit_filePath
        self.textBox_fileSize    = self.textEdit_fileSize
        self.textBox_fileColumns = self.textEdit_fileColumns
        self.textBox_fileRows    = self.textEdit_fileRows

        self.spinBox_xTicksRotation = self.spinBox_x_ticks_rotation
        self.spinBox_xTicksRotation.valueChanged.connect(lambda: self.rotate_axis_ticks("x", self.spinBox_xTicksRotation.value()))
        self.spinBox_yTicksRotation = self.spinBox_y_ticks_rotation
        self.spinBox_yTicksRotation.valueChanged.connect(lambda: self.rotate_axis_ticks("y", self.spinBox_yTicksRotation.value()))

        self.table_data    = self.tableView_dataTable

        btn_plot_file = self.actionCreate_Plot
        btn_plot_file.triggered.connect(self.plot_mpl)

    def rotate_axis_ticks(self, plot_axis, rotation):
        """Rotate the tick labels of ``plot_axis`` ('x' or 'y') and redraw."""
        self.ax[0,0].tick_params(axis=plot_axis, labelrotation=rotation)
        self.fig.canvas.draw()
        print(plot_axis, rotation)

    def add_columnNames_to_comboBox(self, combo_Box, df):
        """Fill ``combo_Box`` with 'index' followed by the columns of ``df``."""
        self.dataframe_keys = (df.columns)
        combo_Box.addItem('index')
        combo_Box.addItems(self.dataframe_keys)

    def remove_columnNames_from_comboBox(self, combo_Box):
        """Remove all entries from ``combo_Box``."""
        combo_Box.clear()

    def OpenFileDialog(self):
        """Ask for a CSV file, load it and refresh the table and file info.

        If the dialog is cancelled, the previously loaded data (if any) is
        kept and the combo boxes are refilled from it.
        """
        self.remove_columnNames_from_comboBox(self.cmbx_x_axis_data)
        self.remove_columnNames_from_comboBox(self.cmbx_y_axis_data)
        open_file_dialog = QFileDialog.getOpenFileName(self, 'Open file', "", "CSV files (*.csv)")
        if open_file_dialog == ('', ''):
            print(f'>>Info: No file was Selected: {(self.fname)}')
        else:
            self.fname = open_file_dialog[0]
            print(f'>>Info: Selected File: {(self.fname)}')
            self.file_size = str(os.path.getsize(self.fname)/1e6)  # in MB, string
            self.dataframe = read_data(self.fname, pd.DataFrame()).data()
            self.file_rows, self.file_cols = self.dataframe.shape
            model = pandasModel(self.dataframe)
            self.table_data.setModel(model)

        self.textBox_filePath.setText(self.fname)
        self.textBox_fileSize.setText(self.file_size)

        self.add_columnNames_to_comboBox(self.cmbx_x_axis_data, self.dataframe)
        self.add_columnNames_to_comboBox(self.cmbx_y_axis_data, self.dataframe)

        self.textBox_fileColumns.setText(str(self.file_cols))
        self.textBox_fileRows.setText(str(self.file_rows))

    def _column_data(self, combo_box):
        """Return the data selected in ``combo_box`` for plotting.

        'index' selects the DataFrame index (as a list); any other entry is
        looked up as a column name. Returns ``nan`` when no data is loaded.
        """
        data = self.dataframe
        if len(data) == 0:
            return nan
        selection = combo_box.currentText()
        if selection == 'index':
            return list(data.index.values)
        return data[selection]

    def set_x_data(self):
        """Store the data selected in the x-axis combo box in ``x_axis_data``."""
        self.x_axis_data = self._column_data(self.cmbx_x_axis_data)

    def set_y_data(self):
        """Store the data selected in the y-axis combo box in ``y_axis_data``."""
        self.y_axis_data = self._column_data(self.cmbx_y_axis_data)

    def plot_mpl(self):
        """Plot the selected y data against the selected x data."""
        self.set_x_data()
        self.set_y_data()
        ax = self.ax[0,0]
        ax.cla()
        x_axis_label_txt = self.cmbx_x_axis_data.currentText()
        y_axis_label_txt = self.cmbx_y_axis_data.currentText()
        ax.plot(self.x_axis_data, self.y_axis_data, 'g', label=y_axis_label_txt)
        ax.set_xlabel(x_axis_label_txt)
        ax.set_ylabel(y_axis_label_txt)
        ax.legend(loc='best')
        ax.grid(True, which='both')
        self.fig.tight_layout()
        self.canvas.draw()
        print('>>Info: New plot was created')
Ejemplo n.º 23
0
    print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))
    print(" The button you used were: %s %s" %
          (eclick.button, erelease.button))


def toggle_selector(event):
    """Key handler for the rectangle selector.

    'q'/'Q' deactivates an active selector and 'a'/'A' re-activates an
    inactive one. The selector is stored on this function as
    ``toggle_selector.RS`` and is only accessed for those keys.
    """
    print(" Key pressed.")
    pressed = event.key
    if pressed in ("Q", "q"):
        if toggle_selector.RS.active:
            print(" RectangleSelector deactivated.")
            toggle_selector.RS.set_active(False)
    elif pressed in ("A", "a"):
        if not toggle_selector.RS.active:
            print(" RectangleSelector activated.")
            toggle_selector.RS.set_active(True)


# Create a single Axes on `fig` (assumed to be created earlier in the script).
current_ax = fig.subplots()
# Number of samples per curve. With a large N the speed-up from blitting
# becomes noticeable while dragging the selector.
N = 100000
x = np.linspace(0.0, 10.0, N)

# Three overlapping sinusoids with decreasing opacity to select regions from.
current_ax.plot(x, +np.sin(0.2 * np.pi * x), lw=3.5, c="b",
                alpha=0.7)
current_ax.plot(x, +np.cos(0.2 * np.pi * x), lw=3.5, c="r", alpha=0.5)
current_ax.plot(x, -np.sin(0.2 * np.pi * x), lw=3.5, c="g", alpha=0.3)

# Header for the "(x1, y1) --> (x2, y2)" lines printed by line_select_callback.
print("\n      click  -->  release")

# drawtype is 'box' or 'line' or 'none'
toggle_selector.RS = RectangleSelector(
    current_ax,
    line_select_callback,
    drawtype="box",
Ejemplo n.º 24
0
    def _build_figures(self):
        """Construct the matplotlib Figures used for rendering.

        For every ``name, conf`` in ``self.conf`` a Figure of
        ``conf["width"] x conf["height"]`` pixels is created. Its subplots
        are laid out with ``conf["fig"]["subplots"]`` and configured by
        calling the Axes/Figure methods named in the ``"subplots_conf"``,
        ``"subplots_common"`` and ``"fig_conf"`` mappings with the given
        keyword arguments. Each subplot gets a plotting callable: a wrapper
        around the Axes method named in ``conf["plot"]``, or
        ``conf["plot_function"]`` if no ``"plot"`` entry exists.

        Returns:
            dict: ``{name: {"fig": Figure, "axes": OrderedDict, "image": None}}``
            where ``axes`` maps each subplot key to
            ``{"ax": Axes, "plot_fn": callable, "artist": None}``.

        Raises:
            ValueError: if a configured method name does not exist on the
                Axes or Figure object.
        """
        figs = dict()
        dpi = 100.0

        for name, conf in self.conf.items():
            fconf = conf["fig"]
            width, height = conf["width"], conf["height"]

            # width/height are in pixels; figsize is in inches.
            fig = Figure(figsize=(width / dpi, height / dpi), dpi=dpi)
            # Binds a canvas to the figure; the object itself is not needed here.
            FigureCanvas(fig)
            # squeeze=False always yields a 2-D array of Axes; flatten it row
            # by row so it lines up with the subplot configurations.
            axes_grid = fig.subplots(**fconf["subplots"], squeeze=False)
            axes = [ax for row in axes_grid for ax in row]

            subplots_conf = fconf["subplots_conf"]
            assert len(axes) == len(subplots_conf), (
                "figure '{}' has {} axes but {} subplot configurations".format(
                    name, len(axes), len(subplots_conf)))

            axes_dict = collections.OrderedDict()
            for ax, (subplot, subconf) in zip(axes, subplots_conf.items()):

                # Subplot-specific settings first, then the shared ones.
                for methods in (subconf, fconf["subplots_common"]):
                    for fname, kwargs in methods.items():
                        try:
                            f = getattr(ax, fname)
                        except AttributeError as err:
                            raise ValueError(
                                "matplotlib.axes.Axes does not have a method "
                                "'{}' specified for figure '{}', subplot "
                                "'{}'".format(fname, name, subplot)) from err
                        f(**kwargs)

                if "plot" in conf:
                    plot_conf = conf["plot"]
                    try:
                        p = getattr(ax, plot_conf["method"])
                    except AttributeError as err:
                        raise ValueError(
                            "matplotlib.axes.Axes does not have plot method "
                            "'{}' specified for figure '{}'".format(
                                plot_conf["method"], name)) from err

                    # Takes (ax, kwargs, env); ax and env are ignored because
                    # p is already bound to this subplot's Axes. The defaults
                    # freeze p/pkwargs for this loop iteration.
                    def plot_fn(ax,
                                kwargs,
                                env,
                                p=p,
                                pkwargs=plot_conf["kwargs"]):
                        return p(**kwargs, **pkwargs)
                else:
                    plot_fn = conf["plot_function"]

                # Remember the subplot key, axes and plot method
                axes_dict[subplot] = dict(ax=ax, plot_fn=plot_fn, artist=None)

            # Configure the figure options
            for fname, kwargs in fconf["fig_conf"].items():
                try:
                    f = getattr(fig, fname)
                except AttributeError as err:
                    raise ValueError(
                        "matplotlib.figure.Figure does not have a method '{}' "
                        "specified for figure '{}'".format(fname, name)) from err
                f(**kwargs)

            figs[name] = dict(fig=fig, axes=axes_dict, image=None)

        return figs
Ejemplo n.º 25
0
class spotWindow(QDialog):
    """Dialog for tuning the spot-detection parameters of each channel.

    ``params`` is a list of four per-channel lists: enhancement, number of
    threshold classes, selected threshold index and (min, max) spot size.
    Missing entries are filled with defaults. The lists are updated in place
    whenever the parameters are applied.
    """

    def __init__(self, input_folder, params, parent=None):
        super(spotWindow, self).__init__(parent)
        self.input_folder = input_folder
        self.params = params
        # Locate the mask and input image (and the slice to crop them to)
        # from the saved segmentation results of this condition.
        _, cond = os.path.split(input_folder)
        save_folder = os.path.join(input_folder, 'result_segmentation')
        props = utils_postprocessing.load_morpho_params(save_folder, cond)
        props = {key: props[key][0] for key in props}
        mask_file = props['mask_file']
        path_to_mask = os.path.join(input_folder, mask_file)
        # np.float was removed in NumPy 1.24; the builtin float is equivalent.
        self.mask = imread(path_to_mask)[props['slice']].astype(float)
        input_file = props['input_file']
        path_to_file = os.path.join(input_folder, input_file)
        self.img = imread(path_to_file).astype(float)
        # A single 2-D image is treated as one channel.
        if len(self.img.shape) == 2:
            self.img = np.expand_dims(self.img, 0)
        self.img = np.array([img[props['slice']] for img in self.img])
        self.n_channels = self.img.shape[0]

        # Defaults: enhancement, classes, threshold index, (min, max) size.
        params_default = [
            0.8, 2, 0, (2, self.img.shape[1] * self.img.shape[2])
        ]
        for i, p in enumerate(self.params):
            # A whole parameter may be None: create one slot per channel.
            if p is None:
                p = [None] * self.n_channels
                self.params[i] = p
            # Replace unset per-channel entries with the default.
            for ch in range(len(p)):
                if p[ch] is None or p[ch] == (None, None):
                    p[ch] = params_default[i]

        # create window
        self.initUI()
        self.updateParamsAndFigure()

    def initUI(self):
        """Build the figure canvas, the parameter widgets and the layout."""
        self.figure = Figure(figsize=(10, 2.5), dpi=100)
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)

        # Empty axes; updateParamsAndFigure() draws the actual content.
        self.figure.clear()
        axs = self.figure.subplots(nrows=1, ncols=4)
        self.figure.subplots_adjust(top=0.95,
                                    right=0.95,
                                    left=0.2,
                                    bottom=0.25)
        for i in [0, 1, 3]:
            axs[i].axis('off')
        axs[2].set_xlabel('Fluo')
        axs[2].ticklabel_format(axis="x", style="sci", scilimits=(2, 2))
        axs[2].set_ylabel('Counts')
        axs[2].ticklabel_format(axis="y", style="sci", scilimits=(0, 2))
        self.canvas.draw()

        self.channel = QSpinBox()
        self.channel.setMaximum(self.n_channels - 1)
        self.channel.valueChanged.connect(self.updateChannel)
        self.channel.setAlignment(Qt.AlignRight)
        ch = self.channel.value()

        self.enhancement = QDoubleSpinBox()
        self.enhancement.setMinimum(0)
        self.enhancement.setMaximum(1)
        self.enhancement.setSingleStep(0.05)
        self.enhancement.setValue(self.params[0][ch])
        self.enhancement.setAlignment(Qt.AlignRight)

        self.nClasses = QSpinBox()
        self.nClasses.setMinimum(2)
        self.nClasses.setValue(self.params[1][ch])
        self.nClasses.valueChanged.connect(self.updatenThrChoice)
        self.nClasses.setAlignment(Qt.AlignRight)

        self.nThr = QSpinBox()
        self.nThr.setValue(self.params[2][ch])
        self.nThr.setAlignment(Qt.AlignRight)

        self.minSize = QSpinBox()
        self.minSize.setMaximum(self.img.shape[1] * self.img.shape[2])
        self.minSize.setValue(self.params[3][ch][0])
        self.minSize.setAlignment(Qt.AlignRight)

        self.maxSize = QSpinBox()
        self.maxSize.setMaximum(self.img.shape[1] * self.img.shape[2])
        # Use the configured maximum (as updateChannel does); otherwise the
        # first updateParamsAndFigure() would overwrite it with the full size.
        self.maxSize.setValue(self.params[3][ch][1])
        self.maxSize.setAlignment(Qt.AlignRight)

        applyButton = QPushButton('Apply params')
        applyButton.clicked.connect(self.updateParamsAndFigure)

        endButton = QPushButton('UPDATE AND RETURN PARAMS')
        endButton.clicked.connect(self.on_clicked)

        lay = QGridLayout(self)
        lay.addWidget(NavigationToolbar(self.canvas, self), 0, 0, 1, 2)
        lay.addWidget(self.canvas, 1, 0, 1, 2)
        lay.addWidget(QLabel('Current channel'), 2, 0, 1, 1)
        lay.addWidget(self.channel, 2, 1, 1, 1)
        lay.addWidget(QLabel('Enhancement'), 3, 0, 1, 1)
        lay.addWidget(self.enhancement, 3, 1, 1, 1)
        lay.addWidget(QLabel('Expected classes for thresholding'), 4, 0, 1, 1)
        lay.addWidget(self.nClasses, 4, 1, 1, 1)
        lay.addWidget(QLabel('Selected threshold'), 5, 0, 1, 1)
        lay.addWidget(self.nThr, 5, 1, 1, 1)
        lay.addWidget(QLabel('Minimum spot size'), 6, 0, 1, 1)
        lay.addWidget(self.minSize, 6, 1, 1, 1)
        lay.addWidget(QLabel('Maximum spot size'), 7, 0, 1, 1)
        lay.addWidget(self.maxSize, 7, 1, 1, 1)
        lay.addWidget(applyButton, 8, 0, 1, 2)
        lay.addWidget(endButton, 9, 0, 1, 2)

        self.setWindowTitle(self.input_folder)
        QApplication.setStyle('Macintosh')

    def updatenThrChoice(self):
        """Limit the selectable threshold index to ``nClasses - 2``."""
        self.nThr.setMaximum(self.nClasses.value() - 2)

    def updateChannel(self):
        """Load the stored parameters of the selected channel and redraw."""
        ch = self.channel.value()

        self.enhancement.setValue(self.params[0][ch])
        self.nClasses.setValue(self.params[1][ch])
        self.nThr.setValue(self.params[2][ch])
        self.minSize.setValue(self.params[3][ch][0])
        self.maxSize.setValue(self.params[3][ch][1])

        self.updateParamsAndFigure()

    def updateParamsAndFigure(self):
        """Run spot detection with the widget values and redraw the figure.

        The values used are stored in ``self.params`` for the current
        channel. Note that this sets matplotlib rc parameters globally.
        """
        from matplotlib import rc
        rc('font', size=8)
        rc('font', family='Arial')
        rc('pdf', fonttype=42)
        self.updatenThrChoice()

        ch = self.channel.value()
        enhancement = self.enhancement.value()
        nclasses = self.nClasses.value()
        nThr = self.nThr.value()
        sizelims = (self.minSize.value(), self.maxSize.value())
        dict_, enhanced, thrs, objects = utils_image.detect_peaks(
            self.img[ch],
            self.mask,
            enhancement=enhancement,
            nclasses=nclasses,
            nThr=nThr,
            sizelims=sizelims)

        ### update the values
        self.params[0][ch] = enhancement
        self.params[1][ch] = nclasses
        self.params[2][ch] = nThr
        self.params[3][ch] = sizelims

        ### update the plot
        self.figure.clear()
        axs = self.figure.subplots(nrows=1, ncols=4)
        self.figure.subplots_adjust(top=0.9, right=1., left=0.,
                                    bottom=0.2)
        for i in [0, 1, 3]:
            axs[i].axis('off')
        axs[2].set_xlabel('Fluo')
        axs[2].ticklabel_format(axis="x", style="sci", scilimits=(2, 2))
        axs[2].set_ylabel('Counts')
        axs[2].ticklabel_format(axis="y", style="sci", scilimits=(0, 2))
        axs[0].set_title('Input image')
        axs[1].set_title('Enhanced image')
        axs[2].set_title('Histogram')
        axs[3].set_title('Segmented spots: %d' % len(dict_['centroid']))
        axs[2].set_yscale('log')

        # Contrast-stretch both images to their 0.3-99.7 percentile range.
        axs[0].imshow(self.img[ch],
                      cmap='magma',
                      vmin=np.percentile(self.img[ch], 0.3),
                      vmax=np.percentile(self.img[ch], 99.7))
        axs[1].imshow(enhanced,
                      cmap='magma',
                      vmin=np.percentile(enhanced, 0.3),
                      vmax=np.percentile(enhanced, 99.7))
        # Histogram of enhanced values inside the mask, with every threshold
        # drawn as a red line and the selected one marked by a star.
        n, _, _ = axs[2].hist(enhanced[self.mask > 0], bins=100)
        for thr in thrs:
            axs[2].plot([thr, thr], [0, np.max(n)], '-r')
        axs[2].plot([thrs[nThr]], [np.max(n)], '*r', ms=10)
        axs[3].imshow(objects, cmap='gray')
        for coords, area in zip(dict_['centroid'], dict_['area']):
            # Circle with the same area as the detected spot.
            circle = mpatches.Circle((coords[1], coords[0]),
                                     radius=np.sqrt(area / np.pi),
                                     fc=(1, 0, 0, 0.5),
                                     ec=(1, 0, 0, 1),
                                     linewidth=2)
            axs[3].add_patch(circle)
        self.canvas.draw()

    @QtCore.pyqtSlot()
    def on_clicked(self):
        """Close the dialog with an Accepted result."""
        self.accept()
Ejemplo n.º 26
0
def visualize_image_attr_multiple(attr: ndarray,
                                  original_image: Union[None, ndarray],
                                  methods: List[str],
                                  signs: List[str],
                                  titles: Union[None, List[str]] = None,
                                  fig_size: Tuple[int, int] = (8, 6),
                                  use_pyplot: bool = True,
                                  **kwargs: Any):
    r"""
    Visualizes attribution using multiple visualization methods displayed
    in a 1 x k grid, where k is the number of desired visualizations.

    Args:

        attr (numpy.array): Numpy array corresponding to attributions to be
                    visualized. Shape must be in the form (H, W, C), with
                    channels as last dimension. Shape must also match that of
                    the original image if provided.
        original_image (numpy.array, optional):  Numpy array corresponding to
                    original image. Shape must be in the form (H, W, C), with
                    channels as the last dimension. Image can be provided either
                    with values in range 0-1 or 0-255. This is a necessary
                    argument for any visualization method which utilizes
                    the original image.
        methods (list of strings): List of strings of length k, defining method
                        for each visualization. Each method must be a valid
                        string argument for method to visualize_image_attr.
        signs (list of strings): List of strings of length k, defining signs for
                        each visualization. Each sign must be a valid
                        string argument for sign to visualize_image_attr.
        titles (list of strings, optional):  List of strings of length k, providing
                    a title string for each plot. If None is provided, no titles
                    are added to subplots.
                    Default: None
        fig_size (tuple, optional): Size of figure created.
                    Default: (8, 6)
        use_pyplot (boolean, optional): If true, uses pyplot to create and show
                    figure and displays the figure after creating. If False,
                    uses Matplotlib object oriented API and simply returns a
                    figure object without showing.
                    Default: True.
        **kwargs (Any, optional): Any additional arguments which will be passed
                    to every individual visualization. Such arguments include
                    `show_colorbar`, `alpha_overlay`, `cmap`, etc.


    Returns:
        2-element tuple of **figure**, **axis**:
        - **figure** (*matplotlib.pyplot.figure*):
                    Figure object on which the visualizations are created.
        - **axis** (sequence of *matplotlib.pyplot.axis*):
                    The k axes, one per visualization, in the order of
                    ``methods``. For k == 1 this is a one-element list.

    Raises:
        AssertionError: If ``methods`` and ``signs`` differ in length, or if
                    ``titles`` is given and its length differs from ``methods``.

    Examples::

        >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
        >>> # and returns an Nx10 tensor of class probabilities.
        >>> net = ImageClassifier()
        >>> ig = IntegratedGradients(net)
        >>> # Computes integrated gradients for class 3 for a given image .
        >>> attribution, delta = ig.attribute(orig_image, target=3)
        >>> # Displays original image and heat map visualization of
        >>> # computed attributions side by side.
        >>> _ = visualize_image_attr_multiple(attribution, orig_image,
        >>>                     ["original_image", "heat_map"], ["all", "positive"])
    """
    # Validate with explicit raises (not ``assert``) so the checks survive
    # ``python -O``; AssertionError is kept so existing callers still match.
    if len(methods) != len(signs):
        raise AssertionError("Methods and signs array lengths must match.")
    if titles is not None and len(methods) != len(titles):
        raise AssertionError("If titles list is given, length must "
                             "match that of methods list.")

    if use_pyplot:
        plt_fig = plt.figure(figsize=fig_size)
    else:
        plt_fig = Figure(figsize=fig_size)
    plt_axis = plt_fig.subplots(1, len(methods))

    # With a single column subplots() returns a bare Axes rather than an
    # array; wrap it so it can be indexed like the multi-plot case.
    if len(methods) == 1:
        plt_axis = [plt_axis]

    for i, (method, sign) in enumerate(zip(methods, signs)):
        visualize_image_attr(attr,
                             original_image=original_image,
                             method=method,
                             sign=sign,
                             plt_fig_axis=(plt_fig, plt_axis[i]),
                             use_pyplot=False,
                             title=titles[i] if titles else None,
                             **kwargs)
    plt_fig.tight_layout()
    if use_pyplot:
        plt.show()
    return plt_fig, plt_axis
Ejemplo n.º 27
0
def genchart(option, df):
    """ Generating graphs based on user selected options """

    if option == "timeline":
        year_list = df.columns[2:-1]

        # configuring canvas and plot area
        fig = Figure(figsize=(18, 10), dpi=100)
        canvas = FigureCanvasAgg(fig)
        ax = fig.subplots()

        # generating the basic line chart
        chart = ax.plot(df[year_list].sum().index, df[year_list].sum().values,
                        'k-o')

        # improving chart appearance
        ax.set_title("Australia  Historical Migration [From 1945 to 2018]\n", {
            'fontsize': 18,
            "color": "k"
        },
                     loc='center')
        ax.set_ylabel("Number of Migrants\n", fontsize=16, color='k')
        ax.set_xlabel("Year", fontsize=16, color='k')
        ax.set_yticks(np.arange(0, 300000, 20000))
        ax.tick_params(axis='x', labelrotation=90)
        ax.tick_params(axis='both', labelsize=12, colors='k')

        # adding patches for each party-in-power period
        max_mig = df[year_list].sum(axis=0).max()
        min_mig = df[year_list].sum(axis=0).min()

        patch1 = Rectangle((0, min_mig),
                           2,
                           max_mig - min_mig,
                           color='red',
                           alpha=.3)
        patch2 = Rectangle((2, min_mig),
                           24,
                           max_mig - min_mig,
                           color='blue',
                           alpha=.3)
        patch3 = Rectangle((26, min_mig),
                           3,
                           max_mig - min_mig,
                           color='red',
                           alpha=.3)
        patch4 = Rectangle((29, min_mig),
                           8,
                           max_mig - min_mig,
                           color='blue',
                           alpha=.3)
        patch5 = Rectangle((37, min_mig),
                           13,
                           max_mig - min_mig,
                           color='red',
                           alpha=.3)
        patch6 = Rectangle((50, min_mig),
                           11,
                           max_mig - min_mig,
                           color='blue',
                           alpha=.3)
        patch7 = Rectangle((61, min_mig),
                           6,
                           max_mig - min_mig,
                           color='red',
                           alpha=.3)
        patch8 = Rectangle((67, min_mig),
                           5,
                           max_mig - min_mig,
                           color='blue',
                           alpha=.3)
        ax.add_patch(patch1)
        ax.add_patch(patch2)
        ax.add_patch(patch3)
        ax.add_patch(patch4)
        ax.add_patch(patch5)
        ax.add_patch(patch6)
        ax.add_patch(patch7)
        ax.add_patch(patch8)

        # adding the chart legend
        ax.legend([patch1, patch2], ['Labor in Power', 'Liberal in Power'],
                  loc='upper left',
                  fontsize='x-large')

        # adding total number of migrants to the chart
        total = df['sum'].sum()
        ax.annotate("Total Number of Migrants: {:,.0f}".format(total),
                    (45, 15000),
                    fontsize=14,
                    fontweight='bold')

        # Saving the chart in memory
        buf = BytesIO()
        fig.savefig(buf, format="png")
        infograph = base64.b64encode(buf.getbuffer()).decode("ascii")

    if option == "continent":

        # configuring canvas and plot area
        fig = Figure(figsize=(18, 10), dpi=100)
        canvas = FigureCanvasAgg(fig)
        ax = fig.subplots()

        # Generating the pie chart
        chart = ax.pie(
            df.groupby('continent').sum()['sum'].values,
            colors=["#999999", "#666666", "#e6e600", "#008080", "#808080"],
            explode=[.04, .04, 0, 0, .04],
            autopct="%1.1f%%",
            pctdistance=1.1,
            wedgeprops=dict(width=0.3, edgecolor='w'),
            textprops=dict(size=16, weight='bold'))

        # adding title to the chart
        ax.set_title("\nMigration per Continent \nAustralia [1945 - 2018]",
                     fontsize=18,
                     fontweight='bold')

        # adding legend to the chart
        leg_hand = []
        for continent in df.groupby('continent').sum().index:
            leg_hand.append("{}: {:,.0f}".format(
                continent,
                df.groupby('continent').sum().at[continent, 'sum']))
        ax.legend(leg_hand, loc='center', fontsize='xx-large')

        # Saving the chart in memory
        buf = BytesIO()
        fig.savefig(buf, format="png")
        infograph = base64.b64encode(buf.getbuffer()).decode("ascii")

    if option == "region":

        fig = Figure(figsize=(18, 10), dpi=100)
        canvas = FigureCanvasAgg(fig)
        ax = fig.subplots()

        # generate a horizontal barchart
        chart = ax.barh(
            df.groupby('region').sum().sort_values('sum',
                                                   axis=0,
                                                   ascending=False).index,
            df.groupby('region').sum().sort_values('sum',
                                                   axis=0,
                                                   ascending=False)['sum'],
            height=.5,
            color=[
                "#008080", "#e6e600", 'k', 'k', 'k', 'k', 'k', 'k', 'k', 'k',
                'k'
            ],
        )

        # improving the chart appearance
        ax.set_title("Migrants per World Region; Australia [1945 - 2018]",
                     fontsize=18)
        ax.set_xlabel("\nNumber of Migrants", fontsize=16, color="k")
        ax.set_ylabel("World Region", fontsize=16, color='k')
        ax.set_xticks(np.arange(0, 3500000, 250000))
        ax.tick_params(axis='both', labelsize=12, colors='k')

        # adding percent for each bar
        for bar in chart.patches:
            ax.text(1.01 * bar.get_width(), bar.get_y() + .3 * bar.get_height(), \
            '{:.1%}'.format(bar.get_width() / df['sum'].sum()), dict(fontsize=15, fontweight='bold'))

        # Saving the chart in memory
        buf = BytesIO()
        fig.savefig(buf, format="png")
        infograph = base64.b64encode(buf.getbuffer()).decode("ascii")

    if option == "topten":

        fig = Figure(figsize=(18, 10), dpi=100)
        canvas = FigureCanvasAgg(fig)
        ax = fig.subplots()

        chart = ax.barh(
            df['sum'].sort_values(ascending=False).head(10).index,
            df['sum'].sort_values(ascending=False).head(10).values,
            height=.5,
            color=[
                "#008080", "#e6e600", 'k', 'k', 'k', 'k', 'k', 'k', 'k', 'k'
            ],
        )

        # improving chart appearance
        ax.set_title(
            "Top 10 Countries; Migration to Australia [1945 - 2018]\n",
            dict(fontsize=16))
        ax.set_xlabel("\nNumber of Migrants", fontsize=16, color='k')
        ax.set_ylabel("Top 10 Countries", fontsize=16, color='k')
        ax.set_xticks(np.arange(0, 2750000, 250000))
        ax.tick_params(axis='both', labelsize=14, colors='k')

        # adding bar chart percent
        for bar in chart.patches:
            ax.text(1.01 * bar.get_width(), bar.get_y() + 0.3 * bar.get_height(), \
            "{:.1%}".format(bar.get_width() / df['sum'].sum()), dict(fontsize=16, fontweight='bold', color='k'))

        # Saving the chart in memory
        buf = BytesIO()
        fig.savefig(buf, format="png")
        infograph = base64.b64encode(buf.getbuffer()).decode("ascii")

    if option == "Asia" or option == 'Europe' or option == "Africa" or option == "America":

        fig = Figure(figsize=(18, 8), dpi=100)
        canvas = FigureCanvasAgg(fig)
        ax = fig.subplots()

        chart = ax.barh(
            df[df['continent'] == option]["sum"].sort_values(
                ascending=False).head(5).index,
            df[df['continent'] == option]["sum"].sort_values(
                ascending=False).head(5).values,
            height=.3,
            color=["#b30000", "k", 'k', 'k', 'k'],
        )

        # improving chart appearance
        ax.set_title(
            "Top Five Countries in {}; Migration to Australia [1945 - 2018]\n".
            format(option), dict(fontsize=18))
        ax.set_xlabel("\nNumber of Migrants", fontsize=16, color='k')
        ax.set_ylabel("Top Five Countries\n", fontsize=16, color='k')

        MAX = df[df['continent'] == option]["sum"].sort_values(
            ascending=False).head(5).values.max()
        ax.set_xticks(
            np.arange(
                0, MAX + MAX / 10,
                int(str(int(MAX))[0]) * (10**len(str(int(MAX / 10))) / 10)))

        ax.tick_params(axis='both', labelsize=14, colors='k')

        # adding bar chart percent
        for bar in chart.patches:
            ax.text(1.01 * bar.get_width(), bar.get_y() + 0.3 * bar.get_height(), \
            "{:.1%}".format(bar.get_width() / df['sum'].sum()), dict(fontsize=16, fontweight='bold', color='k'))

        # Saving the chart in memory
        buf = BytesIO()
        fig.savefig(buf, format="png")
        infograph = base64.b64encode(buf.getbuffer()).decode("ascii")

    if option == "map":
        """ generating a leaflet map using folium library"""

        # loading geojson overlay
        with open("world.json") as data:
            wjson = json.load(data)

        year_list = df.columns[2:-1]
        # creating dataframe required to used during parsing geojson overlay
        map_df = df[year_list].sum(axis=1).to_frame().reset_index()
        # just for more readable legend on the map
        map_df[0] = map_df[0] / 1000000

        MigMap = folium.Map(location=[30, 10],
                            zoom_start=1.5,
                            tiles='OpenStreetMap')

        folium.Choropleth(
            geo_data=wjson,
            name='choropleth',
            data=map_df,
            columns=['country', 0],
            key_on='feature.properties.name',
            fill_color='YlOrRd',
            fill_opacity=0.8,
            line_opacity=0.2,
            nan_fill_color='white',
            nan_fill_opacity=0.3,
            legend_name='Migration to Australia from 1945 to 2018 (million)',
            bins=[0, 0.15, 0.3, 0.5, 1, 2.5]).add_to(MigMap)

        return MigMap._repr_html_()
        #return MigMap._repr_html_() #it does the same as above line of code

    return f"data:image/png;base64,{infograph}"
Ejemplo n.º 28
0
def visualize_image_attr(
        attr: ndarray,
        original_image: Union[None, ndarray] = None,
        method: str = "heat_map",
        sign: str = "absolute_value",
        plt_fig_axis: Union[None, Tuple[figure, axis]] = None,
        outlier_perc: Union[int, float] = 2,
        cmap: Union[None, str] = None,
        alpha_overlay: float = 0.5,
        show_colorbar: bool = False,
        title: Union[None, str] = None,
        fig_size: Tuple[int, int] = (6, 6),
        use_pyplot: bool = True,
):
    r"""
    Visualizes attribution for a given image by normalizing attribution values
    of the desired sign (positive, negative, absolute value, or all) and displaying
    them using the desired mode in a matplotlib figure.

    Args:

        attr (numpy.array): Numpy array corresponding to attributions to be
                    visualized. Shape must be in the form (H, W, C), with
                    channels as last dimension. Shape must also match that of
                    the original image if provided.
        original_image (numpy.array, optional):  Numpy array corresponding to
                    original image. Shape must be in the form (H, W, C), with
                    channels as the last dimension. Image can be provided either
                    with float values in range 0-1 or int values between 0-255.
                    This is a necessary argument for any visualization method
                    which utilizes the original image.
                    Default: None
        method (string, optional): Chosen method for visualizing attribution.
                    Supported options are:

                    1. `heat_map` - Display heat map of chosen attributions

                    2. `blended_heat_map` - Overlay heat map over greyscale
                       version of original image. Parameter alpha_overlay
                       corresponds to alpha of heat map.

                    3. `original_image` - Only display original image.

                    4. `masked_image` - Mask image (pixel-wise multiply)
                       by normalized attribution values.

                    5. `alpha_scaling` - Sets alpha channel of each pixel
                       to be equal to normalized attribution value.
                    Default: `heat_map`
        sign (string, optional): Chosen sign of attributions to visualize. Supported
                    options are:

                    1. `positive` - Displays only positive pixel attributions.

                    2. `absolute_value` - Displays absolute value of
                       attributions.

                    3. `negative` - Displays only negative pixel attributions.

                    4. `all` - Displays both positive and negative attribution
                       values. This is not supported for `masked_image` or
                       `alpha_scaling` modes, since signed information cannot
                       be represented in these modes.
                    Default: `absolute_value`
        plt_fig_axis (tuple, optional): Tuple of matplotlib.pyplot.figure and axis
                    on which to visualize. If None is provided, then a new figure
                    and axis are created.
                    Default: None
        outlier_perc (float or int, optional): Top attribution values which
                    correspond to a total of outlier_perc percentage of the
                    total attribution are set to 1 and scaling is performed
                    using the minimum of these values. For sign=`all`, outliers
                    and scale value are computed using absolute value of
                    attributions.
                    Default: 2
        cmap (string, optional): String corresponding to desired colormap for
                    heatmap visualization. This defaults to "Reds" for negative
                    sign, "Blues" for absolute value, "Greens" for positive sign,
                    and a spectrum from red to green for all. Note that this
                    argument is only used for visualizations displaying heatmaps.
                    Default: None
        alpha_overlay (float, optional): Alpha to set for heatmap when using
                    `blended_heat_map` visualization mode, which overlays the
                    heat map over the greyscaled original image.
                    Default: 0.5
        show_colorbar (boolean, optional): Displays colorbar for heatmap below
                    the visualization. If given method does not use a heatmap,
                    then a colormap axis is created and hidden. This is
                    necessary for appropriate alignment when visualizing
                    multiple plots, some with colorbars and some without.
                    Default: False
        title (string, optional): Title string for plot. If None, no title is
                    set.
                    Default: None
        fig_size (tuple, optional): Size of figure created.
                    Default: (6,6)
        use_pyplot (boolean, optional): If true, uses pyplot to create and show
                    figure and displays the figure after creating. If False,
                    uses Matplotlib object oriented API and simply returns a
                    figure object without showing.
                    Default: True.

    Returns:
        2-element tuple of **figure**, **axis**:
        - **figure** (*matplotlib.pyplot.figure*):
                    Figure object on which visualization
                    is created. If plt_fig_axis argument is given, this is the
                    same figure provided.
        - **axis** (*matplotlib.pyplot.axis*):
                    Axis object on which visualization
                    is created. If plt_fig_axis argument is given, this is the
                    same axis provided.

    Raises:
        AssertionError: If ``original_image`` is missing for a method other
                    than `heat_map`, if sign `all` is combined with
                    `masked_image` or `alpha_scaling`, or if the sign/method
                    is not one of the handled values.

    Examples::

        >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
        >>> # and returns an Nx10 tensor of class probabilities.
        >>> net = ImageClassifier()
        >>> ig = IntegratedGradients(net)
        >>> # Computes integrated gradients for class 3 for a given image .
        >>> attribution, delta = ig.attribute(orig_image, target=3)
        >>> # Displays blended heat map visualization of computed attributions.
        >>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map")
    """
    # Create plot if figure, axis not provided
    if plt_fig_axis is not None:
        plt_fig, plt_axis = plt_fig_axis
    else:
        if use_pyplot:
            plt_fig, plt_axis = plt.subplots(figsize=fig_size)
        else:
            plt_fig = Figure(figsize=fig_size)
            plt_axis = plt_fig.subplots()

    # Resolve the method name once instead of at every comparison below.
    viz_method = ImageVisualizationMethod[method]

    if original_image is not None:
        # Images given in 0-1 range are rescaled to 0-255 before display.
        if np.max(original_image) <= 1.0:
            original_image = _prepare_image(original_image * 255)
    elif viz_method != ImageVisualizationMethod.heat_map:
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise AssertionError(
            "Original Image must be provided for any visualization other than heatmap.")

    # Remove ticks and tick labels from plot.
    plt_axis.xaxis.set_ticks_position("none")
    plt_axis.yaxis.set_ticks_position("none")
    plt_axis.set_yticklabels([])
    plt_axis.set_xticklabels([])
    # Pass visibility positionally: the ``b=`` keyword was renamed to
    # ``visible`` in Matplotlib 3.5 and is rejected by newer releases.
    plt_axis.grid(False)

    heat_map = None
    # Show original image
    if viz_method == ImageVisualizationMethod.original_image:
        # imshow cannot display an (H, W, 1) array; drop the channel axis.
        if len(original_image.shape) > 2 and original_image.shape[2] == 1:
            original_image = np.squeeze(original_image, axis=2)
        plt_axis.imshow(original_image)
    else:
        # Choose appropriate signed attributions and normalize.
        norm_attr = _normalize_image_attr(attr, sign, outlier_perc)
        viz_sign = VisualizeSign[sign]

        # Set default colormap and bounds based on sign.
        if viz_sign == VisualizeSign.all:
            default_cmap = LinearSegmentedColormap.from_list(
                "RdWhGn", ["red", "white", "green"])
            vmin, vmax = -1, 1
        elif viz_sign == VisualizeSign.positive:
            default_cmap = "Greens"
            vmin, vmax = 0, 1
        elif viz_sign == VisualizeSign.negative:
            default_cmap = "Reds"
            vmin, vmax = 0, 1
        elif viz_sign == VisualizeSign.absolute_value:
            default_cmap = "Blues"
            vmin, vmax = 0, 1
        else:
            raise AssertionError("Visualize Sign type is not valid.")
        cmap = cmap if cmap is not None else default_cmap

        # Show appropriate image visualization.
        if viz_method == ImageVisualizationMethod.heat_map:
            heat_map = plt_axis.imshow(norm_attr,
                                       cmap=cmap,
                                       vmin=vmin,
                                       vmax=vmax)
        elif viz_method == ImageVisualizationMethod.blended_heat_map:
            # Greyscale background from the channel mean, heat map on top.
            plt_axis.imshow(np.mean(original_image, axis=2), cmap="gray")
            heat_map = plt_axis.imshow(norm_attr,
                                       cmap=cmap,
                                       vmin=vmin,
                                       vmax=vmax,
                                       alpha=alpha_overlay)
        elif viz_method == ImageVisualizationMethod.masked_image:
            if viz_sign == VisualizeSign.all:
                raise AssertionError(
                    "Cannot display masked image with both positive and negative "
                    "attributions, choose a different sign option.")
            plt_axis.imshow(
                _prepare_image(original_image * np.expand_dims(norm_attr, 2)))
        elif viz_method == ImageVisualizationMethod.alpha_scaling:
            if viz_sign == VisualizeSign.all:
                raise AssertionError(
                    "Cannot display alpha scaling with both positive and negative "
                    "attributions, choose a different sign option.")
            # Append the scaled attribution map as an extra (alpha) channel.
            plt_axis.imshow(
                np.concatenate(
                    [
                        original_image,
                        _prepare_image(np.expand_dims(norm_attr, 2) * 255),
                    ],
                    axis=2,
                ))
        else:
            raise AssertionError("Visualize Method type is not valid.")

    # Add colorbar. If given method is not a heatmap and no colormap is relevant,
    # then a colormap axis is created and hidden. This is necessary for appropriate
    # alignment when visualizing multiple plots, some with heatmaps and some
    # without.
    if show_colorbar:
        axis_separator = make_axes_locatable(plt_axis)
        colorbar_axis = axis_separator.append_axes("bottom",
                                                   size="5%",
                                                   pad=0.1)
        if heat_map is not None:
            plt_fig.colorbar(heat_map,
                             orientation="horizontal",
                             cax=colorbar_axis)
        else:
            colorbar_axis.axis("off")
    if title:
        plt_axis.set_title(title)

    if use_pyplot:
        plt.show()

    return plt_fig, plt_axis
Ejemplo n.º 29
0
# -*- coding: utf-8 -*-
# https://matplotlib.org/3.2.1/gallery/pie_and_polar_charts/pie_features.html

import sys

from matplotlib.backends.backend_qt5agg import FigureCanvas
from matplotlib.figure import Figure
from PyQt5.QtWidgets import QApplication

app = QApplication(sys.argv)

# A bare Figure (no pyplot) wrapped in a Qt canvas widget; the widget is
# resized and shown before the pie is added to the figure.
fig = Figure(figsize=(8, 6))
canvas = FigureCanvas(fig)
canvas.resize(640, 480)
canvas.show()

# Slice data. Wedges are laid out counter-clockwise starting at 90 degrees;
# only the second wedge ("Hogs") is pulled out of the pie.
labels = ("Frogs", "Hogs", "Dogs", "Logs")
sizes = [15, 30, 45, 10]
explode = (0, 0.1, 0, 0)

ax1 = fig.subplots()
ax1.pie(sizes,
        explode=explode,
        labels=labels,
        autopct="%1.1f%%",
        shadow=True,
        startangle=90)
# Equal aspect ratio keeps the pie circular rather than elliptical.
ax1.axis("equal")

# Run the Qt event loop; its return code becomes the process exit status.
exit_status = app.exec_()
sys.exit(exit_status)
Ejemplo n.º 30
0
def _default_plot_numpy(x_data,
                        y_data,
                        fig=None,
                        ax=None,
                        theory_func=None,
                        theory_args=(),
                        theory_kw=None,
                        theory_x_data=None,
                        theory_y_data=None,
                        subplot_kw=None,
                        line_kw=None,
                        theory_name='Theory',
                        fit_func=None,
                        plot_type='line'):
    """Plot measured data, optionally with a theory/fit curve, on a Figure.

    Args:
        x_data, y_data: Measured data, drawn with ``ax.plot`` or ``ax.scatter``.
        fig: Figure to draw on; a new ``Figure()`` is created when None.
        ax: Axes to draw on. When None, the first axes of ``fig`` is reused,
            or one is created with ``fig.subplots(subplot_kw=subplot_kw)``.
        theory_func: Callable evaluated as
            ``theory_func(x_data, *theory_args, **theory_kw)`` and drawn dashed.
        theory_args, theory_kw: Extra positional/keyword args for theory_func.
        theory_x_data, theory_y_data: A precomputed theory curve, drawn dashed
            when both are given. Unless ``subplot_kw`` contains ``'xlim'``, the
            x limits are then padded by 10% around ``x_data``.
        subplot_kw: Axes properties (e.g. ``xlabel``, ``ylabel``, ``xscale``,
            ``yscale``). If its x label matches the existing axes but its y
            label differs, the data is drawn on a twinned y axis.
        line_kw: Keyword arguments for every plotted line/scatter.
        theory_name: Legend label of the theory curve.
        fit_func: If given, fitted to the data with ``curve_fit`` (presumably
            ``scipy.optimize.curve_fit``); the fitted parameters are printed and
            the fit is drawn as the theory curve.
        plot_type: ``'line'`` or ``'scatter'``.

    Returns:
        ``(fig, ax)``; ``ax`` is the twinned axes when one was created.

    Raises:
        ValueError: If ``plot_type`` is not ``'line'`` or ``'scatter'``.
    """
    # Fail before touching the figure so a bad call leaves no stray axes.
    if plot_type not in ('line', 'scatter'):
        raise ValueError(
            f'Plot type {plot_type} is unavailable. Only "line" and "scatter" are implemented'
        )

    # None sentinels instead of shared mutable default dicts.
    theory_kw = {} if theory_kw is None else theory_kw
    subplot_kw = {} if subplot_kw is None else subplot_kw
    line_kw = {} if line_kw is None else line_kw

    if fig is None:
        fig = Figure()
    if ax is None:
        ax = fig.axes[0] if fig.axes else fig.subplots(subplot_kw=subplot_kw)

    # Same x quantity but a different y quantity: draw on a twinned y axis.
    if ('xlabel' in subplot_kw and 'ylabel' in subplot_kw
            and subplot_kw['xlabel'] == ax.get_xlabel()
            and subplot_kw['ylabel'] != ax.get_ylabel()):
        ax = ax.twinx()
        ax.set_xlabel(subplot_kw['xlabel'])
        ax.set_ylabel(subplot_kw['ylabel'])
        if 'xscale' in subplot_kw:
            ax.set_xscale(subplot_kw['xscale'])
        if 'yscale' in subplot_kw:
            ax.set_yscale(subplot_kw['yscale'])
        # Give the twinned data a distinct default colour; an explicit
        # 'color' in line_kw now takes precedence instead of raising a
        # duplicate-keyword TypeError.
        # NOTE(review): ``cmap`` is not defined in this function — presumably
        # a module-level colormap elsewhere in the file; confirm.
        line_kw = {'color': cmap(1), **line_kw}

    if plot_type == 'line':
        measured_handle = ax.plot(x_data, y_data, **line_kw)[0]
    else:
        measured_handle = ax.scatter(x_data, y_data, **line_kw)

    if fit_func is not None:
        theory_args, _ = curve_fit(fit_func, x_data, y_data)
        theory_func = fit_func
        print(f'Fit params: {theory_args}')

    theory_handle = None
    if theory_func:
        theory_handle = ax.plot(x_data,
                                theory_func(x_data, *theory_args, **theory_kw),
                                linestyle='dashed',
                                **line_kw)[0]

    if theory_x_data is not None and theory_y_data is not None:
        theory_handle = ax.plot(theory_x_data, theory_y_data,
                                linestyle='dashed', **line_kw)[0]
        if 'xlim' not in subplot_kw:
            x_min, x_max = min(x_data), max(x_data)
            ax.set_xlim(x_min - abs(x_min) * 0.1, x_max + abs(x_max) * 0.1)

    # Pass handles explicitly: Matplotlib's default legend handle order
    # (lines before collections vs. insertion order) differs between
    # versions, which used to swap the labels of scatter plots.
    if theory_handle is not None:
        ax.legend([measured_handle, theory_handle], ['Measured', theory_name])

    prettifyPlot(fig=fig)
    return fig, ax
Ejemplo n.º 31
0
class FourierDemoFrame(wx.Frame):
    """wx frame demonstrating a Gaussian-windowed cosine and its spectrum.

    Two matplotlib axes are stacked on a wx canvas: the top one shows the
    frequency-domain waveform X(f), the bottom one the time-domain waveform
    x(t) = a * cos(2*pi*f0*t) * exp(-pi*t**2).  The frequency ``f0`` and the
    amplitude ``a`` live in two ``Param`` objects that can be changed with
    the sliders below the canvas or by dragging either waveform.
    """

    def __init__(self, *args, **kwargs):
        wx.Frame.__init__(self, *args, **kwargs)
        panel = wx.Panel(self)

        # create the GUI elements
        self.createCanvas(panel)
        self.createSliders(panel)

        # Place them in a vertical sizer for the layout.  wx.EXPAND already
        # stretches each slider row across the full width, so an alignment
        # flag has no effect here; combining wx.ALIGN_CENTER with wx.EXPAND
        # in a box sizer also trips a sizer-flag assertion in wxWidgets 3.1+
        # (wxPython 4.1+), so only EXPAND and the border flags are used.
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(self.canvas, 1, wx.EXPAND)
        sizer.Add(self.frequencySliderGroup.sizer, 0,
                  wx.EXPAND | wx.ALL, border=5)
        sizer.Add(self.amplitudeSliderGroup.sizer, 0,
                  wx.EXPAND | wx.ALL, border=5)
        panel.SetSizer(sizer)

    def createCanvas(self, parent):
        """Create the figure and canvas, the two Params and the initial plots."""
        self.lines = []
        self.figure = Figure()
        self.canvas = FigureCanvas(parent, -1, self.figure)
        # Route the mouse events to the drag handlers below.
        self.canvas.mpl_connect('button_press_event', self.mouseDown)
        self.canvas.mpl_connect('motion_notify_event', self.mouseMotion)
        self.canvas.mpl_connect('button_release_event', self.mouseUp)
        # Drag state: '' (no drag), 'frequency' or 'time'.
        self.state = ''
        # (x0, y0, f0 at press, A at press); filled in by mouseDown.
        self.mouseInfo = (None, None, None, None)
        self.f0 = Param(2., minimum=0., maximum=6.)
        self.A = Param(1., minimum=0.01, maximum=2.)
        self.createPlots()

        # Not sure I like having two params attached to the same Knob,
        # but that is what we have here... it works but feels kludgy -
        # although maybe it's not too bad since the knob changes both params
        # at the same time (both f0 and A are affected during a drag)
        self.f0.attach(self)
        self.A.attach(self)

    def createSliders(self, panel):
        """Create one slider group per Param (frequency and amplitude)."""
        self.frequencySliderGroup = SliderGroup(
            panel,
            label='Frequency f0:',
            param=self.f0)
        self.amplitudeSliderGroup = SliderGroup(panel, label=' Amplitude a:',
                                                param=self.A)

    def mouseDown(self, evt):
        """Start a drag if the press hits one of the two waveforms.

        Records the press position together with the current f0 (floored at
        0.1 so the 'time' drag never divides by zero) and A.
        """
        if self.lines[0].contains(evt)[0]:
            self.state = 'frequency'
        elif self.lines[1].contains(evt)[0]:
            self.state = 'time'
        else:
            self.state = ''
        self.mouseInfo = (evt.xdata, evt.ydata,
                          max(self.f0.value, .1),
                          self.A.value)

    def mouseMotion(self, evt):
        """Rescale A and f0 in proportion to the drag since mouseDown.

        Vertical motion scales the amplitude by y / y0.  In the frequency
        plot horizontal motion scales f0 by x / x0; in the time plot it
        scales the period by x / x0, i.e. f0 becomes f0Init * x0 / x.
        """
        if self.state == '':
            return
        x, y = evt.xdata, evt.ydata
        if x is None:  # outside the axes
            return
        x0, y0, f0Init, AInit = self.mouseInfo
        # Every update is relative to the press position.  A press that had
        # no data coordinates (None, e.g. a hit on a line just outside the
        # axes) or that sat exactly on an axis (0) gives no usable reference
        # and would raise below, so the drag is ignored instead.
        if not x0 or not y0:
            return
        # Pass self as the knob so only the other attached knobs are notified
        # here; the redraw happens through the f0 update below.
        self.A.set(AInit + (AInit * (y - y0) / y0), self)
        if self.state == 'frequency':
            self.f0.set(f0Init + (f0Init * (x - x0) / x0))
        elif self.state == 'time':
            # (x - x0) / x0 == -1 means x == 0: infinite frequency.
            if (x - x0) / x0 != -1.:
                self.f0.set(1. / (1. / f0Init + (1. / f0Init * (x - x0) / x0)))

    def mouseUp(self, evt):
        """End any drag in progress."""
        self.state = ''

    def createPlots(self):
        # This method creates the subplots, waveforms and labels.
        # Later, when the waveforms or sliders are dragged, only the
        # waveform data will be updated (not here, but below in setKnob).
        self.subplot1, self.subplot2 = self.figure.subplots(2)
        x1, y1, x2, y2 = self.compute(self.f0.value, self.A.value)
        color = (1., 0., 0.)
        self.lines += self.subplot1.plot(x1, y1, color=color, linewidth=2)
        self.lines += self.subplot2.plot(x2, y2, color=color, linewidth=2)
        # Set some plot attributes
        self.subplot1.set_title(
            "Click and drag waveforms to change frequency and amplitude",
            fontsize=12)
        self.subplot1.set_ylabel("Frequency Domain Waveform X(f)", fontsize=8)
        self.subplot1.set_xlabel("frequency f", fontsize=8)
        self.subplot2.set_ylabel("Time Domain Waveform x(t)", fontsize=8)
        self.subplot2.set_xlabel("time t", fontsize=8)
        self.subplot1.set_xlim([-6, 6])
        self.subplot1.set_ylim([0, 1])
        self.subplot2.set_xlim([-2, 2])
        self.subplot2.set_ylim([-2, 2])
        self.subplot1.text(0.05, .95,
                           r'$X(f) = \mathcal{F}\{x(t)\}$',
                           verticalalignment='top',
                           transform=self.subplot1.transAxes)
        self.subplot2.text(0.05, .95,
                           r'$x(t) = a \cdot \cos(2\pi f_0 t) e^{-\pi t^2}$',
                           verticalalignment='top',
                           transform=self.subplot2.transAxes)

    def compute(self, f0, A):
        """Return ``(f, X, t, x)`` sampled for frequency f0 and amplitude A.

        x(t) = A cos(2 pi f0 t) exp(-pi t^2) on t in [-2, 2), and its
        analytic Fourier transform
        X(f) = A/2 (exp(-pi (f - f0)^2) + exp(-pi (f + f0)^2)) on f in [-6, 6).
        """
        f = np.arange(-6., 6., 0.02)
        t = np.arange(-2., 2., 0.01)
        x = A * np.cos(2 * np.pi * f0 * t) * np.exp(-np.pi * t ** 2)
        X = A / 2 * \
            (np.exp(-np.pi * (f - f0) ** 2) + np.exp(-np.pi * (f + f0) ** 2))
        return f, X, t, x

    def setKnob(self, value):
        """Redraw both waveforms from the current f0 and A.

        This frame is attached to both Params in createCanvas, presumably so
        that Param.set calls this on a value change — confirm in Param.
        """
        # Note, we ignore value arg here and just go by state of the params
        x1, y1, x2, y2 = self.compute(self.f0.value, self.A.value)
        # update the data of the two waveforms
        self.lines[0].set(xdata=x1, ydata=y1)
        self.lines[1].set(xdata=x2, ydata=y2)
        # make the canvas draw its contents again with the new data
        self.canvas.draw()
Ejemplo n.º 32
0
class VelociraptorGui(tk.Tk):
    def __init__(self, parent):
        # NOTE(review): tk.Tk's first positional parameter is ``screenName``
        # (a display name), not a parent widget; ``parent`` is forwarded there
        # unchanged, which is only harmless while callers pass None.
        super().__init__(parent)
        self.parent = parent
        self._initialize()
        # Closing the window calls self.quit.
        self.protocol('WM_DELETE_WINDOW', self.quit)
        self.option_add('*tearOff', False)

    def _initialize(self):
        """Setup the GUI widgets.

        +----------------------------------+
        | Add button   Edit button         |
        +----------------------------------+
        |                     |            |
        |                     | Graph view |
        |                     |            |
        |     Rides view      |------------|
        |                     |            |
        |                     | Stats view |
        |                     |            |
        +----------------------------------+

        """
        self.buttonbox()
        self._init_rides_view()
        self._init_graph_view()
        self._init_stats_view()
        self.grid_columnconfigure(0, weight=1)
        self.grid_columnconfigure(1, weight=1)
        self.grid_rowconfigure(1, weight=1)
        self.focus_set()

    def buttonbox(self):
        """Build the toolbar row: Add/Edit/Delete buttons and a year picker."""
        box = ttk.Frame(self)

        # Action buttons, packed left to right in this order.
        actions = (
            ('Ajouter', self.add_ride),
            ('Modifier', self.edit_ride),
            ('Effacer', self.del_ride),
        )
        for caption, handler in actions:
            ttk.Button(box, text=caption, command=handler).pack(side=tk.LEFT)

        # Read-only year selector on the right; selecting a year triggers
        # change_year, which refreshes the table, graphs and statistics.
        self.year = tk.StringVar()
        combo = ttk.Combobox(box, textvariable=self.year, width=10)
        self.year_combo = combo
        combo.bind('<<ComboboxSelected>>', self.change_year)
        combo.pack(side=tk.RIGHT)
        combo.state(['readonly'])

        box.grid(column=0, row=0, columnspan=2, sticky='ew', ipadx=5, ipady=5)

    def _init_rides_view(self):
        """Create the scrollable rides table and fill it with stored rides.

        The table sits in column 0, rows 1-2 of the root window.  The year
        selector defaults to the current calendar year; once the data is
        loaded it switches to the most recent year that has rides, and only
        then is the table filled, so the rows match the year displayed.
        """
        colnames = ['id', 'Date', 'Distance (km)', 'Durée (h)',
                    'Vitesse (km/h)', 'Commentaire', 'url']
        self.rides_container = ttk.Frame(self)
        self.rides_container.grid(column=0, row=1, rowspan=2, sticky='ewns')
        self.rides_view = ttk.Treeview(self.rides_container, columns=colnames,
                                       selectmode='browse')
        # Make rides view resizable.
        self.rides_container.grid_columnconfigure(0, weight=1)
        self.rides_container.grid_rowconfigure(0, weight=1)
        self.rides_view.grid(column=0, row=0, sticky='ewns')

        # Add scrollbars
        vsb = ttk.Scrollbar(self.rides_container, orient='vertical',
                            command=self.rides_view.yview)
        hsb = ttk.Scrollbar(self.rides_container, orient='horizontal',
                            command=self.rides_view.xview)
        self.rides_view.configure(xscrollcommand=hsb.set,
                                  yscrollcommand=vsb.set)
        hsb.grid(column=0, row=1, sticky='ew')
        vsb.grid(column=1, row=0, sticky='ns')

        # Adjust columns: hide the tree column '#0' and make every data
        # column at least as wide as its heading.  A single Font instance is
        # enough for all the measurements.
        heading_font = tkfont.Font()
        self.rides_view.column('#0', width=0, stretch=False)
        for col in colnames:
            self.rides_view.heading(col, text=col)
            width = heading_font.measure(col) + 10
            self.rides_view.column(col, minwidth=width, width=width)
        id_width = heading_font.measure('9999') + 10
        self.rides_view.column('id', width=id_width, minwidth=id_width,
                               anchor=tk.CENTER)
        date_width = 100
        self.rides_view.column('Date', width=date_width, minwidth=date_width,
                               anchor=tk.CENTER)
        comment_width = 160
        self.rides_view.column('Commentaire', width=comment_width,
                               minwidth=comment_width)
        for col in ['Distance (km)', 'Durée (h)', 'Vitesse (km/h)']:
            self.rides_view.column(col, anchor=tk.E)

        # Bind double click events
        self.rides_view.bind('<Double-1>', self.edit_ride)

        # Populate the view with data.  The year must be chosen BEFORE the
        # table is filled: previously the table (and the viewable_rides that
        # the graph and stats views read) was built for the current calendar
        # year and the combobox was switched afterwards, so they disagreed
        # whenever the current year had no rides.
        self.year.set(str(datetime.datetime.now().year))
        self.load_data()
        if self.years:
            # years is sorted newest first.
            self.year.set(self.years[0])
        self.update_rides_view()

    def change_year(self, event):
        self.update_rides_view()
        self.update_graph_view()
        self.update_stats()

    def load_data(self):
        """Reload every stored ride and refresh the list of selectable years.

        ``self.years`` holds each distinct year that has at least one ride,
        most recent first; the same list becomes the combobox's values.
        """
        self.rides = bike.read_db_file(year='all')
        # A set comprehension feeds sorted() directly; the intermediate
        # list(set(...)) copy was unnecessary.
        self.years = sorted({ride['timestamp'].year for ride in self.rides},
                            reverse=True)
        self.year_combo['values'] = self.years

    def update_rides_view(self):
        self.viewable_rides = [ride for ride in self.rides if
                ride['timestamp'].year == int(self.year.get())]
        self.rides_view.delete(*self.rides_view.get_children())
        for ride in self.viewable_rides:
            self.rides_view.insert('', 'end', values=format_ride(ride))

    def get_graph_data(self):
        cumsum = list(itertools.accumulate(ride['distance'] for ride in
                        self.viewable_rides))
        dates = [ride['timestamp'] for ride in self.viewable_rides]
        speeds = [ride['distance'] / ride['duration'] for ride in
                self.viewable_rides]
        return cumsum, dates, speeds

    def update_graph_view(self):
        cumsum, dates, speeds = self.get_graph_data()
        self.ax1.clear()
        self.ax2.clear()
        self.ax1.plot(dates, cumsum)
        self.ax2.plot(dates, speeds)
        self.ax1.set_ylabel('distance (km)')
        self.ax2.set_ylabel('vitesse (km/h)')
        self.fig.autofmt_xdate()
        self.graph_view.draw()
        
    def _init_graph_view(self):
        """Create the two stacked graphs in column 1, row 1 of the window.

        The top axes plots the cumulative distance and the bottom one the
        speed of each ride; both share the date x axis.  Selecting the
        'ggplot' style changes matplotlib's global style settings, not just
        this figure's.
        """
        matplotlib.style.use('ggplot')
        self.fig = Figure(figsize=(4, 4), tight_layout=True)
        cumsum, dates, speeds = self.get_graph_data()
        self.ax1, self.ax2 = self.fig.subplots(2, sharex=True)

        panels = ((self.ax1, cumsum, 'distance (km)'),
                  (self.ax2, speeds, 'vitesse (km/h)'))
        for axis, values, _label in panels:
            axis.plot(dates, values)
        for axis, _values, label in panels:
            axis.set_ylabel(label)
        self.fig.autofmt_xdate()

        canvas = FigureCanvasTkAgg(self.fig, master=self)
        self.graph_view = canvas
        canvas.draw()
        canvas.get_tk_widget().grid(column=1, row=1, sticky='nsew')

    def update_stats(self):
        """Show the summary statistics of the rides of the selected year."""
        stats = bike.get_stats(self.viewable_rides)
        summary_lines = (
            'Distance totale : {:.1f} km\n'.format(stats['tot_distance']),
            'Durée totale : {:.1f} h\n'.format(stats['tot_duration']),
            'Distance moyenne : {:.1f} km\n'.format(stats['mean_distance']),
            'Durée moyenne : {:.1f} h\n'.format(stats['mean_duration']),
            'Vitesse moyenne : {:.2f} km/h\n'.format(stats['speed']),
        )
        self.stats_text.set(''.join(summary_lines))

    def _init_stats_view(self):
        """Create the statistics label (column 1, row 2) and fill it once."""
        self.stats_text = tk.StringVar()
        label = tk.Label(self, textvariable=self.stats_text,
                         justify=tk.LEFT, anchor=tk.NW)
        self.stats_view = label
        # NOTE(review): without sticky= the label presumably shrinks to its
        # text and is centred in the cell, making anchor=tk.NW ineffective.
        label.grid(column=1, row=2)
        self.update_stats()

    def add_ride(self):
        """Open the ride dialog in creation mode; reload the views if saved.

        The views are refreshed only when the dialog's ``result`` is truthy.
        """
        dialog = RideDetailDialog(self, 'Ajouter une randonnée')
        saved = dialog.result
        dialog.destroy()
        if not saved:
            return
        self.load_data()
        self.update_rides_view()
        self.update_graph_view()
        self.update_stats()

    def edit_ride(self, event=None):
        try:
            iid = self.rides_view.selection()[0]
        except IndexError:
            return
        rideid = int(self.rides_view.item(iid, 'values')[0])
        ride = self.rides[rideid]
        dialog = RideDetailDialog(self, 'Modifier une randonnée', ride=ride)
        result = dialog.result
        dialog.destroy()
        if result:
            self.load_data()
            self.update_rides_view()
            self.update_graph_view()
            self.update_stats()

    def del_ride(self):
        try:
            iid = self.rides_view.selection()[0]
        except IndexError:
            return
        rideid = int(self.rides_view.item(iid, 'values')[0])
        self.rides.pop(rideid)
        bike.update_db(self.rides)
        self.load_data()
        self.update_rides_view()
        self.update_graph_view()
        self.update_stats()