예제 #1
0
    def slider(self, N):
        """Add the position slider and the series check boxes, then show the figure.

        N: window length subtracted from ``len(self.lista_hora)`` to give the
            slider's maximum (presumably so a window of N samples always fits —
            TODO confirm against ``self.bar``).

        The slider calls ``self.bar`` when its value changes, and the check
        boxes call ``self.func`` with the clicked label. The boxes are
        'C. Fund.' and 'Arm. 1' .. 'Arm. 11', all initially ticked.
        ``self.bar(0)`` draws the initial state before ``plt.show()``.
        """
        barpos = plt.axes([0.18, 0.01, 0.70, 0.03], facecolor="gray")
        # Keep the widgets referenced from the instance: matplotlib widgets stop
        # responding once garbage-collected, which happens to locals as soon as
        # this method returns when plt.show() does not block.
        self._slider = Slider(barpos, '', 0, len(self.lista_hora) - N, valinit=0)
        self._slider.on_changed(self.bar)

        labels = ('C. Fund.',) + tuple('Arm. %d' % k for k in range(1, 12))
        rax = plt.axes([0.01, 0.30, 0.12, 0.65])
        self._check = CheckButtons(rax, labels, (True,) * len(labels))
        self._check.on_clicked(self.func)

        self.bar(0)
        plt.show()
예제 #2
0
class button_manager:
    '''Radio-button behaviour on top of a matplotlib ``CheckButtons`` widget.

    Plain check buttons let any number of boxes be ticked; this wrapper keeps
    exactly one ticked.

    on init:
        creates the buttons on a new axes, links them to ``button_click`` and
        calls ``call_on_click(active_index, firsttime=True)``
    on click:
        keeps a single button on and calls ``call_on_click(active_index)``
    '''

    def __init__(self, fig, dim, labels, init, call_on_click):
        '''
        fig: (Figure)   stored on the instance; not otherwise used here
        dim: (list)     [left_x, bottom_y, width, height], passed to plt.axes
        labels: (list)  for example ['1','2','3','4','5','6']
        init: (list)    for example [True, False, False, False, False, False];
                        must contain a True (``.index(True)`` is used)
        call_on_click:  callable taking the active index, plus the keyword
                        ``firsttime=True`` on the initial call
        '''
        self.fig = fig
        self.ax = plt.axes(dim)  # lx, by, w, h
        # Copy so later changes to the caller's list do not alter reinit().
        self.init_state = list(init)
        self.call_on_click = call_on_click
        self.button = CheckButtons(self.ax, labels, init)
        self.button.on_clicked(self.button_click)
        self.status = self.button.get_status()
        self.call_on_click(self.status.index(True), firsttime=True)

    def reinit(self):
        '''Restore the initial selection and report it via ``call_on_click``.'''
        old_idx = self.status.index(True)
        self.status = list(self.init_state)
        # Toggle the currently-on button OFF. The resulting click callback
        # (button_click) then toggles the button recorded in self.status --
        # the initial one -- back ON and reports it.
        self.button.set_active(old_idx)

    def button_click(self, event):
        '''Keep exactly one button on after a click.

        event: the value passed by ``CheckButtons.on_clicked`` (the clicked
            label); unused.

        When this runs, the widget has already toggled the clicked button.
        Toggling the button recorded in ``self.status`` then either turns the
        previous selection off (another button was clicked) or turns the
        clicked button back on (the active button was clicked).
        '''
        self.button.eventson = False  # don't re-enter this callback
        try:
            self.button.set_active(self.status.index(True))
        finally:
            self.button.eventson = True
        self.status = self.button.get_status()
        self.call_on_click(self.status.index(True))
예제 #3
0
    def genCombineResalePrices(yearEnter, monthNum, p_leaseEnter, flatType):
        """Return resale prices and names of the towns ticked in the check boxes.

        Rows of ``data`` are kept when ``month`` lies between
        ``yearEnter``-``monthNum[-2:]`` and the current month,
        ``remaining_lease`` exceeds ``int(p_leaseEnter) + 1``, and
        ``flat_type`` equals ``flatType``.

        Reads ``data``, ``checkBoxLines``, ``rax``, ``townList``, ``townIndex``
        and ``checkButtonThemes`` from an enclosing/module scope.

        Side effects: sets the globals ``towndata`` and ``townSelect``, and
        creates a new CheckButtons on ``rax`` (themed, no callback connected).

        Returns:
            (resale_prices_combined, townName): one entry per ticked town,
            in check-box order.

        NOTE(review): a second ``genCombineResalePrices`` with a different
        signature follows this definition; if both live in the same scope, the
        later one replaces this one -- confirm which is intended.
        """
        import numpy as np
        from datetime import datetime
        from matplotlib.widgets import CheckButtons
        global townSelect, towndata

        ## Rows from the requested month ('YYYY-MM') up to the current month
        dateStart = yearEnter + monthNum[-2:]
        firstStartDate = yearEnter + "-" + monthNum[-2:]
        dateEnd = datetime.now().strftime('%Y%m')
        pieces = [data[data['month'] == firstStartDate]]
        for i in range(int(dateStart) + 1, int(dateEnd) + 1):
            month = i % 100
            # The integer range also walks YYYY00 and YYYY13..YYYY99; skip them.
            if 1 <= month <= 12:
                pieces.append(data[data['month'] == "%d-%02d" % (i // 100, month)])
        # One concatenate instead of np.append per month, which would copy the
        # accumulated array on every iteration.
        dataPeriod = np.concatenate(pieces)

        ## Flats whose remaining lease is above the threshold
        # NOTE(review): the threshold is int(p_leaseEnter) + 1 with a strict '>',
        # i.e. stricter than "lease > p_leaseEnter" -- confirm the offset.
        leaseStart = int(p_leaseEnter) + 1
        leaseData = dataPeriod[dataPeriod['remaining_lease'] > leaseStart]
        flatTypeData = leaseData[leaseData['flat_type'] == flatType]
        towndata = flatTypeData

        ## Towns currently ticked in the check boxes
        line_labels = [str(line.get_label()) for line in checkBoxLines]
        line_visibility = [line.get_visible() for line in checkBoxLines]
        check = CheckButtons(rax, line_labels, line_visibility)
        checkButtonThemes(check)
        status = check.get_status()

        resale_prices = flatTypeData['resale_price']
        town_resale_prices = np.zeros(len(townList), object)
        for t in townIndex:
            town_resale_prices[t] = resale_prices[towndata['town'] ==
                                                  townList[t]]

        townSelect = np.zeros(len(townList), object)
        townName = []
        resale_prices_combined = []
        labelIndex = 0  # next free slot in townSelect
        for s, ticked in enumerate(status):
            if ticked:
                townSelect[labelIndex] = towndata[towndata['town'] ==
                                                  townList[s]]
                labelIndex += 1
                townName.append(line_labels[s])
                resale_prices_combined.append(town_resale_prices[s])
        return resale_prices_combined, townName
    def genCombineResalePrices(yearEnter, quarterNum, flatType):
        """Return per-town prices and quarters for ``flatType``.

        Selects rows of ``data`` whose ``quarter`` runs from Q2 of
        ``yearEnter`` through Q4 of the current year, in chronological order,
        and whose ``flat_type`` equals ``flatType``.

        quarterNum: accepted for compatibility; the start quarter is always Q2.

        Reads ``data``, ``checkBoxLines``, ``rax``, ``townList``, ``townIndex``
        and ``checkButtonThemes`` from an enclosing/module scope.
        Side effects: sets the global ``towndata``; creates a CheckButtons on
        ``rax``.

        Returns:
            (town_prices, town_quarter): object arrays indexed like ``townList``.
        """
        import numpy as np
        from datetime import datetime
        from matplotlib.widgets import CheckButtons
        global townSelect, towndata

        ## Rows from yearEnter-Q2 up to the current year
        startYear = int(yearEnter)
        startQuarter = 2  # quarterNum is deliberately not used
        endYear = datetime.now().year
        pieces = []
        for year in range(startYear, endYear + 1):
            for q in range(1, 5):
                if year == startYear and q < startQuarter:
                    continue
                # Quarter labels have the form 'YYYY-QN', e.g. '2007-Q2'.
                pieces.append(data[data['quarter'] == '%d-Q%d' % (year, q)])
        # One concatenate instead of repeated np.append (quadratic copying).
        dataPeriod = np.concatenate(pieces) if pieces else data[:0]
        towndata = dataPeriod[dataPeriod['flat_type'] == flatType]

        ## Check boxes for the town lines
        # NOTE(review): every call adds another CheckButtons on ``rax`` without
        # connecting a callback; kept for compatibility with existing callers.
        line_labels = [str(line.get_label()) for line in checkBoxLines]
        line_visibility = [line.get_visible() for line in checkBoxLines]
        check = CheckButtons(rax, line_labels, line_visibility)
        checkButtonThemes(check)

        prices = towndata['price']
        quarter = towndata['quarter']
        town_prices = np.zeros(len(townList), object)
        town_quarter = np.zeros(len(townList), object)
        for t in townIndex:
            in_town = towndata['town'] == townList[t]
            town_prices[t] = prices[in_town]
            town_quarter[t] = quarter[in_town]
        return town_prices, town_quarter
예제 #5
0
class CurrentView(object):
    """Interactive matplotlib viewer for 4-channel data stored as ``.npy`` files.

    ``num == -1`` is the "fast view": every channel of ``fastview.npy`` is
    plotted over x in [0, 1]. ``num == n`` (0..99) shows ``data<n>.npy``
    plotted over [n/100, (n+1)/100].

    Controls:
      * check boxes (left) choose the visible channels;
      * 'right' / 'left' keys step to the next / previous segment;
      * 'd' returns to the fast view;
      * double-clicking the fast view opens the segment under the cursor.
    """

    def __new__(cls, path='../numpydata/'):
        # ``path`` is consumed by __init__; accepted here so CurrentView(path)
        # keeps working with this override.
        return super(CurrentView, cls).__new__(cls)

    def __init__(self, path='../numpydata/'):
        """Build the figure and controls and plot ``fastview.npy``.

        path: directory containing ``fastview.npy`` and ``data<n>.npy``.
        """
        self.path = path
        self.visible_channels = [0, 1, 2, 3]
        self.channels = [0, 1, 2, 3]

        self.num = -1  # -1 selects the fast view
        self.fig, self.ax = plt.subplots()
        self.ax.set_title(str(self.visible_channels))
        plt.subplots_adjust(left=0.25, right=0.9, top=0.9, bottom=0.1)
        self.toolbar = self.fig.canvas.manager.toolbar

        self.rax = plt.axes([0.05, 0.4, 0.1, 0.15])
        self.check = CheckButtons(self.rax, self.channels, [1, 1, 1, 1])
        self.check.on_clicked(self.func)

        # One line per channel, indexed by channel number.
        self.lines = [self.ax.plot([], [])[0] for _ in self.channels]

        self.fast = np.load(self._data_file('fastview.npy'))

        # Per-channel (max, min) of the fast view; not used within this class.
        self.limits = {
            channel: (np.max(self.fast[channel]), np.min(self.fast[channel]))
            for channel in self.channels
        }
        self.lims = self._fast_limits(self.visible_channels)

        for ch in self.visible_channels:
            self.lines[ch].set_data(np.linspace(0, 1, len(self.fast[ch])),
                                    self.fast[ch])
        self.ax.set_xlim(0, 1)
        self.ax.set_ylim(self.lims)

        self.ax.set_xlabel('-1')  # the x-label records the displayed segment
        self.btn = self.fig.canvas.mpl_connect('key_press_event',
                                               self.on_press)
        self.clk = self.fig.canvas.mpl_connect('button_press_event',
                                               self.on_click)
        self.fig.canvas.draw()

    def _data_file(self, name):
        """Return the path of ``name`` inside the data directory."""
        import os
        return os.path.join(self.path, name)

    def _fast_limits(self, channels):
        """Return (min, max) of the fast-view data over ``channels`` (non-empty)."""
        return (np.min([np.min(self.fast[ch]) for ch in channels]),
                np.max([np.max(self.fast[ch]) for ch in channels]))

    def show(self):
        plt.show()

    def on_modify_visible_channels(self):
        """Sync line visibility, title and y-limits with ``visible_channels``."""
        self.ax.set_title(str(self.visible_channels))
        # Applies in both fast and segment views.
        for ch in self.channels:
            self.lines[ch].set_visible(ch in self.visible_channels)
        # With no channel ticked there is nothing to take limits from (np.min of
        # an empty list raises); keep the previous limits.
        if len(self.visible_channels) > 0:
            self.lims = self._fast_limits(self.visible_channels)
            self.ax.set_ylim(self.lims)
        self.fig.canvas.draw()

    def on_press(self, event):
        """Keys: 'right'/'left' step through segments 0..99, 'd' = fast view."""
        sys.stdout.flush()
        if event.key == 'right':
            if self.num + 1 <= 99:
                self.num += 1
                self.load_data()
        elif event.key == 'left':
            if self.num - 1 >= 0:
                self.num -= 1
                self.load_data()
        elif event.key == 'd':
            self.num = -1
            self.load_data()

    def on_click(self, event):
        """Double-click in the fast view opens the segment under the cursor."""
        # Outside any axes xdata/ydata are None; the check-box axes are handled
        # by the CheckButtons widget.
        if event.inaxes is None or event.inaxes is self.rax:
            return
        if event.dblclick and self.num == -1:
            # Segment n covers [n/100, (n+1)/100] (see load_data).
            self.num = min(max(int(np.floor(event.xdata * 100)), 0), 99)
            self.load_data()

        print('%s click: button = %d, x = %d, y = %d, xdata = %f, ydata = %f'
              % ('double' if event.dblclick else 'single', event.button,
                 event.x, event.y, event.xdata, event.ydata))

    def func(self, label):
        """CheckButtons callback: recompute the visible channels and redraw."""
        status = self.check.get_status()
        # Keep a plain list, as set in __init__ (also keeps the title format).
        self.visible_channels = [ch for ch in self.channels if status[ch]]
        self.on_modify_visible_channels()

    def load_data(self):
        """Show the fast view (num == -1) or segment ``num``, if not already shown."""
        if self.ax.get_xlabel() != str(self.num):
            # Update every channel, not just the visible ones, so a channel
            # re-enabled later does not show stale data.
            if self.num == -1:
                for ch in self.channels:
                    self.lines[ch].set_data(
                        np.linspace(0, 1, len(self.fast[ch])), self.fast[ch])
                self.ax.set_xlim(0, 1)
            else:
                segment = np.load(self._data_file('data' + str(self.num) + '.npy'))
                lo, hi = self.num / 100, (self.num + 1) / 100
                for ch in self.channels:
                    self.lines[ch].set_data(
                        np.linspace(lo, hi, len(segment[ch])), segment[ch])
                self.ax.set_xlim(lo, hi)

            self.ax.set_ylim(self.lims)
            self.ax.set_xlabel(self.num)
            self.fig.canvas.draw()

        if self.ax.get_xlabel() == str(self.num) and self.num == -1:
            # Already on the fast view: reset any zoom/pan.
            self.ax.set_xlim(0, 1)
            self.ax.set_ylim(self.lims)
            self.fig.canvas.draw()
class WaveGraphBase(ABC):
    """Base class for an animated wave plot with check-box and slider controls.

    The figure has a graph row (top) and a control row (bottom) holding one
    column of check boxes and a stack of sliders. Every check-box click and
    slider change calls ``update``; ``start`` drives ``animate`` through a
    ``FuncAnimation``. Subclasses implement ``update`` and ``animate``.
    """

    def __init__(self, name, granularity, x_range, x_offset, y_range, y_offset,
                 time_factor, line_thickness, waves, slider_data,
                 checkbox_data):
        """Create the figure, one empty line per wave, and the controls.

        Args:
            name: window title.
            granularity: number of samples in ``self.x_data``.
            x_range, x_offset: the visible x-interval is
                [-x_range - x_offset, x_range - x_offset].
            y_range, y_offset: the visible y-interval, defined the same way.
            time_factor: stored on the instance for subclasses.
            line_thickness: line width of every wave line.
            waves: one line is plotted per element.
            slider_data: sequence of dicts with keys "name", "min", "max",
                "init" and "step"; one slider each.
            checkbox_data: sequence of dicts with keys "name" and "init".
        """
        self.granularity = granularity
        self.x_range = x_range
        self.x_offset = x_offset
        self.y_range = y_range
        self.y_offset = y_offset
        self.waves = waves
        self.slider_data = slider_data
        self.checkbox_data = checkbox_data
        self.time_factor = time_factor
        self.line_thickness = line_thickness

        self.fig = plt.figure(figsize=(16, 9))
        # The window title belongs to the figure manager
        # (FigureCanvasBase.set_window_title was removed in Matplotlib 3.6);
        # non-GUI backends may have no manager.
        manager = getattr(self.fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(name)

        self.main_grid = gridspec.GridSpec(2, 1)
        self.graph_cell = plt.subplot(self.main_grid[0, :])
        self.graph_cell.set(xlim=(-self.x_range - self.x_offset,
                                  self.x_range - self.x_offset),
                            ylim=(-self.y_range - self.y_offset,
                                  self.y_range - self.y_offset))

        # Samples cover [-3*(x_range + x_offset), 3*(x_range - x_offset)],
        # wider than the visible window.
        self.x_data = np.linspace(-3 * self.x_range - 3 * self.x_offset,
                                  3 * self.x_range - 3 * self.x_offset,
                                  self.granularity)
        # One independent list per wave ([[]] * n would share a single list).
        self.y_data = [[] for _ in range(len(self.waves))]

        self.lines = [
            self.graph_cell.plot([], [], linewidth=self.line_thickness)[0]
            for _ in range(len(self.waves))
        ]
        # init() returns this list of artists (blitting is enabled in start()).
        self.patches = self.lines

        self.control_cell = self.main_grid[1, :]
        self.control_grid = gridspec.GridSpecFromSubplotSpec(
            1, 7, self.control_cell)

        self.checkbox_cell = self.control_grid[0, 0]
        self.checkbox_grid = gridspec.GridSpecFromSubplotSpec(
            1, 1, self.checkbox_cell)
        self.checkboxes = []
        self.checkboxAx = plt.subplot(self.checkbox_grid[0, 0:1])
        self.checkbox = CheckButtons(
            self.checkboxAx, tuple(x["name"] for x in self.checkbox_data),
            tuple(x["init"] for x in self.checkbox_data))
        self.checkbox.on_clicked(self.update)
        self.checkboxes_ticked = self.checkbox.get_status()

        self.slider_cell = self.control_grid[0, 2:6]
        self.slider_grid = gridspec.GridSpecFromSubplotSpec(
            len(self.slider_data), 1, self.slider_cell)
        self.sliders = []
        for i, spec in enumerate(self.slider_data):
            # self.sliderAx / self.slider end up referring to the last slider.
            self.sliderAx = plt.subplot(self.slider_grid[i, 0])
            self.slider = Slider(self.sliderAx,
                                 spec["name"],
                                 spec["min"],
                                 spec["max"],
                                 valinit=spec["init"],
                                 valstep=spec["step"])
            self.sliders.append(self.slider)
        for slider in self.sliders:
            slider.on_changed(self.update)

    def init(self):
        """FuncAnimation init_func: clear every line and return them."""
        for line in self.lines:
            line.set_data([], [])
        return self.patches

    def start(self):
        """Run the animation (20 ms interval, blitting) and show the figure."""
        self.animation = animation.FuncAnimation(self.fig,
                                                 self.animate,
                                                 init_func=self.init,
                                                 frames=999999,
                                                 repeat=True,
                                                 interval=20,
                                                 blit=True)
        plt.show()

    @abstractmethod
    def update(self, event=None):
        """Register sliders and checkboxes"""

    @abstractmethod
    def animate(self, i):
        """Animation function"""
def genLineChartByTown(p_yearEnter, p_quarterEnter, p_flatType):
    """Show an interactive line chart of HDB median resale prices per town.

    Loads
    ``data/median-resale-prices-for-registered-applications-by-town-and-flat-type.csv``
    and plots one line per town (initially hidden) for ``p_flatType``. The
    quarters run from Q2 of ``p_yearEnter`` up to the current year. Two
    controls are added:
      * check boxes (left) that show/hide individual towns;
      * radio buttons (bottom left) that re-plot for another flat type.

    p_quarterEnter: accepted for compatibility; the start quarter is always Q2.

    Module-level side effects: sets the globals ``txt``, ``flatType``,
    ``checkBoxLines``, ``axcolor``, ``towndata`` and ``index``.
    Blocks in ``plt.show()``.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    from matplotlib.widgets import RadioButtons, CheckButtons
    from datetime import datetime
    from matplotlib import cm

    global txt, flatType, checkBoxLines
    txt = None
    yearEnter = p_yearEnter
    quarterNum = '2'  # start quarter; p_quarterEnter is not used
    flatType = p_flatType

    def mainTheme():
        """Set the chart title/y-label and the global axcolor."""
        global axcolor
        axcolor = 'lightgoldenrodyellow'
        title = "HDB Resale Flat Price Line Chart"
        ax.set_title(title, fontsize=30)
        ax.set_ylabel('Resale Price', fontsize=25)

    def get_status():
        """Return the visibility of every town line, in townList order."""
        return [line.get_visible() for line in checkBoxLines]

    def selectPeriod(startYear, flatType, startQuarter=2):
        """Rows of ``data`` for ``flatType`` from startYear-Q<startQuarter> to now.

        Rows are returned in chronological quarter order.
        """
        endYear = datetime.now().year
        pieces = []
        for year in range(startYear, endYear + 1):
            for q in range(1, 5):
                if year == startYear and q < startQuarter:
                    continue
                # Quarter labels have the form 'YYYY-QN', e.g. '2007-Q2'.
                pieces.append(data[data['quarter'] == '%d-Q%d' % (year, q)])
        dataPeriod = np.concatenate(pieces) if pieces else data[:0]
        return dataPeriod[dataPeriod['flat_type'] == flatType]

    def splitByTown(rows):
        """Split the price/quarter columns into object arrays indexed like townList."""
        town_prices = np.zeros(len(townList), object)
        town_quarter = np.zeros(len(townList), object)
        for t in townIndex:
            in_town = rows['town'] == townList[t]
            town_prices[t] = rows['price'][in_town]
            town_quarter[t] = rows['quarter'][in_town]
        return town_prices, town_quarter

    def genCombineResalePrices(yearEnter, quarterNum, flatType):
        """Return (town_prices, town_quarter) for flatType; sets global towndata."""
        global towndata
        towndata = selectPeriod(int(yearEnter), flatType, int(quarterNum))
        return splitByTown(towndata)

    def func(label):
        """CheckButtons callback: show/hide the clicked town's line."""
        global index
        # The widget has already toggled its own box when this runs.
        status = check.get_status()
        index = line_labels.index(label)
        checkBoxLines[index].set_visible(status[index])
        print(status)
        plt.draw()

    def checkButtonThemes(check):
        """Style the check-box squares."""
        # NOTE(review): CheckButtons.rectangles was removed in newer Matplotlib
        # releases (3.9); confirm the targeted version.
        for r in check.rectangles:
            r.set_alpha(0.8)
            r.set_width(0.025)
            r.set_edgecolor("k")

    def flatTypefunc(label):
        """RadioButtons callback: re-plot every town for the chosen flat type."""
        global flatType
        flatType = label
        town_prices, town_quarter = genCombineResalePrices(
            yearEnter, quarterNum, flatType)
        ax.clear()
        mainTheme()
        status = check.get_status()
        for l in townIndex:
            checkBoxLines[l], = ax.plot(town_quarter[l],
                                        town_prices[l],
                                        visible=bool(status[l]),
                                        c=colorList[l],
                                        label=townList[l])
        ax.tick_params(axis='x', labelrotation=90)
        fig.canvas.draw_idle()

    ##############            Main Process Start Here            ###############

    data = np.genfromtxt(
        'data/median-resale-prices-for-registered-applications-by-town-and-flat-type.csv',
        skip_header=1,
        dtype=[('quarter', 'U7'), ('town', 'U30'), ('flat_type', 'U20'),
               ('price', 'i8')],
        delimiter=",",
        missing_values=['na', '-'],
        filling_values=[0])

    townList = sorted(set(data['town']))
    townIndex = np.arange(0, len(townList))
    colorList = cm.hsv(townIndex / float(max(townIndex + 10)))

    town_prices, town_quarter = genCombineResalePrices(yearEnter, quarterNum,
                                                       flatType)

    fig, ax = plt.subplots(figsize=(19, 15))
    plt.subplots_adjust(left=0.25, bottom=0.25)
    mainTheme()

    checkBoxLines = np.zeros(len(townList), object)
    print(town_prices[0])
    for l in townIndex:
        checkBoxLines[l], = ax.plot(town_quarter[l],
                                    town_prices[l],
                                    visible=False,
                                    c=colorList[l],
                                    label=townList[l])
    # Categorical x-axis: rotate the quarter labels instead of overwriting them
    # with one town's quarter list.
    ax.tick_params(axis='x', labelrotation=90)

    patches = [
        mpatches.Patch(color=color, label=label)
        for label, color in zip(townList, colorList)
    ]
    fig.legend(patches, townList, loc='upper right')

    # Check buttons mirroring the current visibility of every town line.
    rax = plt.axes([0.05, 0.25, 0.13, 0.7], facecolor=axcolor)
    line_labels = [str(line.get_label()) for line in checkBoxLines]
    line_visibility = get_status()
    check = CheckButtons(rax, line_labels, line_visibility)
    checkButtonThemes(check)
    check.on_clicked(func)
    status = check.get_status()
    print(status)

    radioax = plt.axes([0.05, 0.08, 0.11, 0.12], facecolor=axcolor)
    # NOTE(review): active=3 pre-selects '4-room' whatever p_flatType is.
    radio = RadioButtons(
        radioax,
        ('1-room', '2-room', '3-room', '4-room', '5-room', 'Executive'),
        active=3)
    radio.on_clicked(flatTypefunc)

    plt.show()
예제 #8
0
class Confusion_Visualizer():
    """Interactive viewer for a series of confusion matrices (train vs. validation).

    Opens four matplotlib figures:

    1. row-normalised confusion heatmaps at the selected step,
    2. global and per-class precision / recall / F1 / accuracy curves,
    3. a per-class trajectory scatter plot (FP-FN, FPR-TPR, ...),
    4. a control panel whose widgets drive the other three.

    Confusion tensors are indexed ``[ground_truth, prediction, step]``
    (c x c x N), matching the heatmap axis labels.
    """

    #   Trajectory modes offered by the radio buttons, mapped to their (x, y) axis labels
    _TRAJECTORY_AXIS_LABELS = {
        'FP-FN': ('FP', 'FN'),
        'FPR-TPR': ('FPR', 'TPR'),
        'Recl-Prec': ('Recall', 'Precision'),
        'Spec-Sens': ('Specificity', 'Sensitivity '),
    }

    def __init__(self, train_confusion_path, valid_confusion_path, class_names=None):
        """Load both confusion series, compute metrics and open the figures.

        Args:
            train_confusion_path: Training confusions: a single file readable by
                ``load_variables``, a directory of ``0.pkl`` .. ``N-1.pkl``, or
                an ndarray (see ``_load_data``).
            valid_confusion_path: Validation confusions, same accepted forms.
            class_names: Optional list/tuple with one label per class;
                defaults to ``'0'``, ``'1'``, ...

        Raises:
            ValueError: If ``class_names`` has the wrong type or length, or a
                confusion source is not recognised.
        """
        #   Load data
        self.confusion_train = self._load_data(train_confusion_path)
        self.confusion_valid = self._load_data(valid_confusion_path)

        #   Check whether dimensions match; if not, truncate both to the shorter series
        if self.confusion_train.shape[-1] != self.confusion_valid.shape[-1]:
            min_length = min(self.confusion_train.shape[-1], self.confusion_valid.shape[-1])
            print('Training and Validation data not match (of shape {} and {} respectively), truncated as {}'.
                  format(self.confusion_train.shape, self.confusion_valid.shape, min_length))
            self.confusion_train = self.confusion_train[:, :, 0:min_length]
            self.confusion_valid = self.confusion_valid[:, :, 0:min_length]
        self.series_length = self.confusion_train.shape[-1]
        self.class_amount = self.confusion_train.shape[0]

        #   TFPN is derived after truncation so it covers exactly the steps displayed
        self.tfpn_train = confusions_to_TFPN(self.confusion_train)
        self.tfpn_valid = confusions_to_TFPN(self.confusion_valid)

        #   Obtain class labels
        if class_names is None:
            self.class_names = [str(i) for i in range(self.class_amount)]
        elif isinstance(class_names, (list, tuple)):
            if len(class_names) != self.class_amount:
                raise ValueError('Class names ({}) mismatch with confusion ({})'.format(
                    len(class_names), self.class_amount))
            self.class_names = class_names
        else:
            raise ValueError("Class names unsupported (please use list, e.g. ['Benign', 'Tumor'])")

        #   Process data to get performance
        self.plot_data = {}
        self.performance_train = self._process_data(self.confusion_train)
        self.performance_valid = self._process_data(self.confusion_valid)

        #   Set initial value for control panel (used in framework and widgets)
        self.ctrl_cursor = 0
        self.ctrl_class = 0
        self.ctrl_smooth = 1
        self.ctrl_trajectory_length = 1
        self.ctrl_trajectory_mode = 'FP-FN'
        self.ctrl_hold_range = False
        self.ctrl_same_ratio = True

        #   Draw framework
        self._draw_framework()

        #   Draw control panel
        self._draw_control_panel()

        #   Draw initial figure
        self._figure_initialize()

    def _draw_framework(self):
        """Create the four figures and their (still empty) axes."""
        #   Figure-level titles use suptitle: plt.title() here would create an extra
        #   full-size axes underneath the subplots (matplotlib >= 3.6 no longer
        #   removes overlapping axes in plt.subplot).

        #   Used for drawing confusion matrix (heatmap)
        self.heatmap_fig = plt.figure(1, figsize=(8, 4))
        self.heatmap_fig.suptitle('Heatmap')
        self.heatmap_train = plt.subplot(1, 2, 1)
        self._setup_heatmap_axes(self.heatmap_train, 'Train')
        self.heatmap_valid = plt.subplot(1, 2, 2)
        self._setup_heatmap_axes(self.heatmap_valid, 'Validation')

        #   Used for drawing performance
        self.performance_fig = plt.figure(2, figsize=(10, 6))
        self.performance_fig.suptitle('Performance')
        self.performance_global_p_r = plt.subplot(2, 2, 1)
        self.performance_global_p_r.set_title('Global')
        self.performance_class_p_r = plt.subplot(2, 2, 2)
        self.performance_class_p_r.set_title('Class: ' + self.class_names[int(self.ctrl_class)])
        self.performance_global_a_f = plt.subplot(2, 2, 3)
        self.performance_class_f = plt.subplot(2, 2, 4)

        #   Used for drawing FP/FN, etc.
        self.trace_fig = plt.figure(3, figsize=(6, 6))
        self.trace_plot = plt.subplot(1, 1, 1)
        self.trace_plot.set_xlabel('FP')
        self.trace_plot.set_ylabel('FN')
        self.trace_plot.set_title('Class: ' + self.class_names[int(self.ctrl_class)])

        #   Used for drawing control panel
        self.controller = plt.figure(4, figsize=(4, 4))

    def _setup_heatmap_axes(self, ax, title):
        """Give a heatmap axes its title, per-class ticks and GT/prediction labels."""
        ticks = np.arange(self.class_amount)
        ax.set_title(title)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(self.class_names)
        ax.set_yticklabels(self.class_names)
        ax.set_ylabel('Ground Truth')
        ax.set_xlabel('Prediction')

    def _draw_control_panel(self):
        """Create the sliders, radio buttons and check buttons on the controller figure.

        Initial widget values are taken from the ``ctrl_*`` attributes so the
        panel and the internal state start out in agreement.
        """
        ax = self.controller.add_axes([0.2, 0.90, 0.5, 0.05])
        self.slider_cursor = DiscreteSlider(label='cursor', valmin=0, valmax=self.series_length - 1,
                                            ax=ax, increment=1, valinit=self.ctrl_cursor)
        self.slider_cursor.on_changed(self.controller_change_slider_cursor)

        ax = self.controller.add_axes([0.2, 0.75, 0.5, 0.05])
        self.slider_class = DiscreteSlider(label='class', valmin=0, valmax=self.class_amount - 1,
                                           ax=ax, increment=1, valinit=self.ctrl_class)
        self.slider_class.on_changed(self.controller_change_slider_class)

        ax = self.controller.add_axes([0.2, 0.60, 0.5, 0.05])
        self.slider_smooth = DiscreteSlider(label='smooth', valmin=1, valmax=self.series_length,
                                            ax=ax, increment=1, valinit=self.ctrl_smooth)
        self.slider_smooth.on_changed(self.controller_change_slider_smooth)

        ax = self.controller.add_axes([0.2, 0.45, 0.5, 0.05])
        self.slider_trajectory = DiscreteSlider(label='trajectory', valmin=1, valmax=self.series_length,
                                                ax=ax, increment=1, valinit=self.ctrl_trajectory_length)
        self.slider_trajectory.on_changed(self.controller_change_slider_trajectory)

        ax = self.controller.add_axes([0.6, 0.1, 0.3, 0.25])
        modes = tuple(self._TRAJECTORY_AXIS_LABELS)
        self.radiobuttons_norm = RadioButtons(ax, modes, active=modes.index(self.ctrl_trajectory_mode))
        self.radiobuttons_norm.on_clicked(self.controller_change_radiobuttons_norm)

        ax = self.controller.add_axes([0.3, 0.1, 0.3, 0.25])
        self.checkbutton_hold = CheckButtons(ax, labels=['Hold on Range', 'Same Ratio'],
                                             actives=[self.ctrl_hold_range, self.ctrl_same_ratio])
        self.checkbutton_hold.on_clicked(self.controller_change_checkbutton_hold)

    def _load_data(self, obj):
        """Return a c x c x N confusion tensor from ``obj``.

        ``obj`` may be an ndarray (returned unchanged), a path to a single file
        readable by ``load_variables``, or a directory of ``0.pkl``, ``1.pkl``,
        ... whose matrices are stacked along a new last axis.

        Raises:
            ValueError: If ``obj`` is none of the above.
        """
        #   Check ndarray first: os.path.isfile() raises TypeError for an array,
        #   which made this branch unreachable when it came last.
        if isinstance(obj, np.ndarray):
            return obj
        if os.path.isfile(obj):
            #   read directly given a single file
            [conf_tensor] = load_variables(obj)
            return conf_tensor
        if os.path.isdir(obj):
            #   read 0.pkl ~ *.pkl given a directory path
            #   NOTE(review): assumes the directory holds exactly 0.pkl .. (n-1).pkl,
            #   where n = len(get_file_list(obj)) — confirm get_file_list semantics.
            file_amount = len(get_file_list(obj))
            conf_list = []
            for i in range(file_amount):
                [conf] = load_variables(os.path.join(obj, str(i) + '.pkl'))
                conf_list.append(conf)
            return np.concatenate(np.expand_dims(conf_list, axis=-1), axis=-1)
        raise ValueError('Invalid confusion source: {}'.format(obj))

    #   Confusions: c x c x N
    def _process_data(self, confusions):
        """Compute per-step metrics plus class-averaged ('global_*') series.

        NOTE(review): assumes ``confusion_to_all`` returns per-class arrays for
        precision/recall/f1 and a scalar accuracy — confirm.
        """
        performance = {}
        performance['accuracy'] = []
        performance['precision'] = []
        performance['recall'] = []
        performance['f1'] = []
        for i in range(confusions.shape[-1]):
            current_result = confusion_to_all(confusions[:, :, i])
            for metric in ['accuracy', 'precision', 'recall', 'f1']:
                performance[metric].append(current_result[metric])
        #   Stack the per-step results along a new last (step) axis
        for metric in ['accuracy', 'precision', 'recall', 'f1']:
            performance[metric] = np.concatenate(np.expand_dims(performance[metric], axis=-1), axis=-1)
        #   Global = mean over classes (axis 0). If any class is NaN at a step the
        #   mean is NaN there, and NaNs are replaced by 0.
        for metric in ['precision', 'recall', 'f1']:
            global_performance = np.mean(performance[metric], axis=0)
            global_performance[np.isnan(global_performance)] = 0
            performance['global_' + metric] = global_performance
        for metric in ['accuracy', 'precision', 'recall', 'f1']:
            performance[metric][np.isnan(performance[metric])] = 0
        performance['global_accuracy'] = performance['accuracy']

        return performance

    #   Figure initialization
    def _figure_initialize(self):
        """Create every plot element once, then show the figures."""
        self.element_heatmap_train_image = None
        self.element_heatmap_valid_image = None
        self.element_heatmap_train_text_pool = []
        self.element_heatmap_valid_text_pool = []
        self._update_figure_heatmap()

        self.element_performance_global_precision_train = None
        self.element_performance_global_recall_train = None
        self.element_performance_global_precision_valid = None
        self.element_performance_global_recall_valid = None
        self.element_performance_global_accuracy_train = None
        self.element_performance_global_f1_train = None
        self.element_performance_global_accuracy_valid = None
        self.element_performance_global_f1_valid = None
        self._update_figure_global_performance()

        self.element_performance_class_precision_train = None
        self.element_performance_class_recall_train = None
        self.element_performance_class_precision_valid = None
        self.element_performance_class_recall_valid = None
        self.element_performance_class_f1_train = None
        self.element_performance_class_f1_valid = None
        self._update_figure_class_performance()

        self.element_trajectory_train = None
        self.element_trajectory_valid = None
        self._update_figure_trajectory()

        plt.show()

    #   Event for slider cursor changed
    def controller_change_slider_cursor(self, event):
        new_value = int(self.slider_cursor.val)
        if self.ctrl_cursor != new_value:
            self.ctrl_cursor = new_value
            self._update_figure_heatmap()
            self._update_figure_trajectory()

    #   Event for slider class changed
    def controller_change_slider_class(self, event):
        new_value = int(self.slider_class.val)
        if self.ctrl_class != new_value:
            self.ctrl_class = new_value
            self._update_figure_class_performance()
            self._update_figure_trajectory()

    #   Event for slider smooth changed
    def controller_change_slider_smooth(self, event):
        new_value = int(self.slider_smooth.val)
        if self.ctrl_smooth != new_value:
            self.ctrl_smooth = new_value
            self._update_figure_global_performance()
            self._update_figure_class_performance()

    #   Event for slider trajectory (length) changed
    def controller_change_slider_trajectory(self, event):
        new_value = int(self.slider_trajectory.val)
        if self.ctrl_trajectory_length != new_value:
            self.ctrl_trajectory_length = new_value
            self._update_figure_trajectory()

    #   Event for radiobuttons changed (FP/FN normalization method)
    def controller_change_radiobuttons_norm(self, event):
        new_value = self.radiobuttons_norm.value_selected
        if self.ctrl_trajectory_mode != new_value:
            self.ctrl_trajectory_mode = new_value
            self._update_figure_trajectory()

    #   Event for checkbox changed (hold FP/FN figure's range / same x-y ratio)
    def controller_change_checkbutton_hold(self, event):
        new_value = self.checkbutton_hold.get_status()
        self.ctrl_hold_range = new_value[0]
        self.ctrl_same_ratio = new_value[1]
        #   Redraw so the new setting takes effect immediately rather than on
        #   the next unrelated control change
        self._update_figure_trajectory()

    def _update_data_heatmap(self):
        """Row-normalise the train/valid confusion at the cursor step into plot_data."""
        ctrl_cursor = self.ctrl_cursor
        #   Rows with no samples give 0/0 -> NaN (rendered as 'nan'); silence the warning
        with np.errstate(divide='ignore', invalid='ignore'):
            for split, confusion in (('train', self.confusion_train), ('valid', self.confusion_valid)):
                step = confusion[:, :, ctrl_cursor]
                self.plot_data['heatmap_' + split] = step / np.sum(step, axis=1, keepdims=True)

    def _update_data_global_performance(self):
        """Smooth the class-averaged metric series with the current window size."""
        ctrl_smooth = self.ctrl_smooth
        for metric in ['precision', 'recall', 'f1', 'accuracy']:
            self.plot_data['performance_global_train_' + metric] = mean_filter_padded(
                self.performance_train['global_' + metric], ctrl_smooth)
            self.plot_data['performance_global_valid_' + metric] = mean_filter_padded(
                self.performance_valid['global_' + metric], ctrl_smooth)

    def _update_data_class_performance(self):
        """Smooth the selected class's metric series with the current window size."""
        ctrl_class = self.ctrl_class
        ctrl_smooth = self.ctrl_smooth
        for metric in ['precision', 'recall', 'f1']:
            self.plot_data['performance_class_train_' + metric] = mean_filter_padded(
                self.performance_train[metric][ctrl_class, :], ctrl_smooth)
            self.plot_data['performance_class_valid_' + metric] = mean_filter_padded(
                self.performance_valid[metric][ctrl_class, :], ctrl_smooth)

    def _update_data_trajectory(self):
        """Compute the trajectory x/y series for the selected class and window.

        The window covers steps ``[cursor, cursor + trajectory_length)``, clipped
        to the series. TFPN rows are indexed 0=TP, 1=TN, 2=FP, 3=FN.

        Raises:
            ValueError: For an unknown trajectory mode.
        """
        ctrl_cursor = self.ctrl_cursor
        ctrl_class = self.ctrl_class
        ctrl_trajectory = self.ctrl_trajectory_length
        ctrl_trajectory_mode = self.ctrl_trajectory_mode

        #   Desired range: from cursor to (cursor+trajectory)
        lb = max(ctrl_cursor, 0)
        ub = min(ctrl_cursor + ctrl_trajectory, self.series_length)

        train_tp = self.tfpn_train[ctrl_class, 0, lb:ub]
        train_tn = self.tfpn_train[ctrl_class, 1, lb:ub]
        train_fp = self.tfpn_train[ctrl_class, 2, lb:ub]
        train_fn = self.tfpn_train[ctrl_class, 3, lb:ub]

        valid_tp = self.tfpn_valid[ctrl_class, 0, lb:ub]
        valid_tn = self.tfpn_valid[ctrl_class, 1, lb:ub]
        valid_fp = self.tfpn_valid[ctrl_class, 2, lb:ub]
        valid_fn = self.tfpn_valid[ctrl_class, 3, lb:ub]

        #   Measurements refer to: http://www.davidsbatista.net/blog/2018/08/19/NLP_Metrics/
        #   Ratios may be 0/0 -> NaN; that is expected, so the warning is suppressed
        #   (axis limits are computed NaN-safely in _update_figure_trajectory).
        with np.errstate(divide='ignore', invalid='ignore'):
            if ctrl_trajectory_mode == 'FP-FN':
                self.plot_data['trajectory_train_x'] = train_fp
                self.plot_data['trajectory_train_y'] = train_fn
                self.plot_data['trajectory_valid_x'] = valid_fp
                self.plot_data['trajectory_valid_y'] = valid_fn
            elif ctrl_trajectory_mode == 'FPR-TPR':
                self.plot_data['trajectory_train_x'] = train_fp / (train_tn + train_fp)
                self.plot_data['trajectory_train_y'] = train_tp / (train_tp + train_fn)
                self.plot_data['trajectory_valid_x'] = valid_fp / (valid_tn + valid_fp)
                self.plot_data['trajectory_valid_y'] = valid_tp / (valid_tp + valid_fn)
            elif ctrl_trajectory_mode == 'Recl-Prec':
                self.plot_data['trajectory_train_x'] = train_tp / (train_tp + train_fn)
                self.plot_data['trajectory_train_y'] = train_tp / (train_tp + train_fp)
                self.plot_data['trajectory_valid_x'] = valid_tp / (valid_tp + valid_fn)
                self.plot_data['trajectory_valid_y'] = valid_tp / (valid_tp + valid_fp)
            elif ctrl_trajectory_mode == 'Spec-Sens':
                self.plot_data['trajectory_train_x'] = train_tn / (train_tn + train_fp)
                self.plot_data['trajectory_train_y'] = train_tp / (train_tp + train_fn)
                self.plot_data['trajectory_valid_x'] = valid_tn / (valid_tn + valid_fp)
                self.plot_data['trajectory_valid_y'] = valid_tp / (valid_tp + valid_fn)
            else:
                raise ValueError('Unknown trajectory mode: {}'.format(ctrl_trajectory_mode))

    def _update_figure_heatmap(self):
        """Redraw both heatmaps (image and per-cell text) for the current cursor."""
        #   Fetch data
        self._update_data_heatmap()

        self.element_heatmap_train_image, self.element_heatmap_train_text_pool = self._draw_heatmap(
            self.heatmap_train, self.plot_data['heatmap_train'],
            self.element_heatmap_train_image, self.element_heatmap_train_text_pool)
        self.element_heatmap_valid_image, self.element_heatmap_valid_text_pool = self._draw_heatmap(
            self.heatmap_valid, self.plot_data['heatmap_valid'],
            self.element_heatmap_valid_image, self.element_heatmap_valid_text_pool)

        self.heatmap_fig.tight_layout()

        self.heatmap_fig.canvas.draw_idle()

    def _draw_heatmap(self, ax, data, old_image, old_texts):
        """Replace ``old_image`` / ``old_texts`` on ``ax`` with a fresh heatmap of ``data``.

        The previous image is removed before drawing so repeated updates do not
        stack images on the axes. Returns the new ``(image, texts)`` pair.
        """
        #   Clean former text & image
        for item in old_texts:
            item.remove()
        if old_image is not None:
            old_image.remove()

        #   Write new text (column j is x, row i is y) & draw
        texts = []
        for i in range(self.class_amount):
            for j in range(self.class_amount):
                texts.append(ax.text(j, i, '{:.2f}'.format(data[i, j]),
                                     ha="center", va="center", color="w",
                                     fontsize=8))
        image = ax.imshow(data, cmap='jet')
        return image, texts

    @staticmethod
    def _plot_or_update_line(line, ax, x, y, color, linestyle):
        """Plot ``y`` against ``x`` on ``ax`` if ``line`` is None, else update its y data.

        Returns the (possibly newly created) line.
        """
        if line is None:
            line, = ax.plot(x, y, color=color, linestyle=linestyle)
        else:
            line.set_ydata(y)
        return line

    def _update_figure_global_performance(self):
        """Draw (first call) or refresh the class-averaged metric curves."""
        time_steps = np.arange(self.series_length)
        self._update_data_global_performance()
        data = self.plot_data

        #   Legends are attached only once, when the curves are first created
        first_draw_p_r = self.element_performance_global_recall_valid is None
        first_draw_a_f = self.element_performance_global_f1_valid is None

        ax = self.performance_global_p_r
        self.element_performance_global_precision_train = self._plot_or_update_line(
            self.element_performance_global_precision_train, ax, time_steps,
            data['performance_global_train_precision'], 'red', '-')
        self.element_performance_global_recall_train = self._plot_or_update_line(
            self.element_performance_global_recall_train, ax, time_steps,
            data['performance_global_train_recall'], 'blue', '-')
        self.element_performance_global_precision_valid = self._plot_or_update_line(
            self.element_performance_global_precision_valid, ax, time_steps,
            data['performance_global_valid_precision'], 'red', '--')
        self.element_performance_global_recall_valid = self._plot_or_update_line(
            self.element_performance_global_recall_valid, ax, time_steps,
            data['performance_global_valid_recall'], 'blue', '--')
        if first_draw_p_r:
            ax.legend(['pre_train', 'rec_train', 'pre_valid', 'rec_valid'])
        ax.relim()
        ax.autoscale_view()

        ax = self.performance_global_a_f
        self.element_performance_global_accuracy_train = self._plot_or_update_line(
            self.element_performance_global_accuracy_train, ax, time_steps,
            data['performance_global_train_accuracy'], 'red', '-')
        self.element_performance_global_f1_train = self._plot_or_update_line(
            self.element_performance_global_f1_train, ax, time_steps,
            data['performance_global_train_f1'], 'blue', '-')
        self.element_performance_global_accuracy_valid = self._plot_or_update_line(
            self.element_performance_global_accuracy_valid, ax, time_steps,
            data['performance_global_valid_accuracy'], 'red', '--')
        self.element_performance_global_f1_valid = self._plot_or_update_line(
            self.element_performance_global_f1_valid, ax, time_steps,
            data['performance_global_valid_f1'], 'blue', '--')
        if first_draw_a_f:
            ax.legend(['acc_train', 'f1_train', 'acc_valid', 'f1_valid'])
        ax.relim()
        ax.autoscale_view()

        self.performance_fig.canvas.draw_idle()

    def _update_figure_class_performance(self):
        """Draw (first call) or refresh the selected class's metric curves.

        Axes are re-limited after each update because switching class can move
        the curves outside the range autoscaled for the first class.
        """
        time_steps = np.arange(self.series_length)
        self._update_data_class_performance()
        data = self.plot_data

        self.performance_class_p_r.set_title('Class: ' + self.class_names[int(self.ctrl_class)])

        #   Legends are attached only once, when the curves are first created
        first_draw_p_r = self.element_performance_class_recall_valid is None
        first_draw_f = self.element_performance_class_f1_valid is None

        ax = self.performance_class_p_r
        self.element_performance_class_precision_train = self._plot_or_update_line(
            self.element_performance_class_precision_train, ax, time_steps,
            data['performance_class_train_precision'], 'red', '-')
        self.element_performance_class_recall_train = self._plot_or_update_line(
            self.element_performance_class_recall_train, ax, time_steps,
            data['performance_class_train_recall'], 'blue', '-')
        self.element_performance_class_precision_valid = self._plot_or_update_line(
            self.element_performance_class_precision_valid, ax, time_steps,
            data['performance_class_valid_precision'], 'red', '--')
        self.element_performance_class_recall_valid = self._plot_or_update_line(
            self.element_performance_class_recall_valid, ax, time_steps,
            data['performance_class_valid_recall'], 'blue', '--')
        if first_draw_p_r:
            ax.legend(['pre_train', 'rec_train', 'pre_valid', 'rec_valid'])
        ax.relim()
        ax.autoscale_view()

        ax = self.performance_class_f
        self.element_performance_class_f1_train = self._plot_or_update_line(
            self.element_performance_class_f1_train, ax, time_steps,
            data['performance_class_train_f1'], 'blue', '-')
        self.element_performance_class_f1_valid = self._plot_or_update_line(
            self.element_performance_class_f1_valid, ax, time_steps,
            data['performance_class_valid_f1'], 'blue', '--')
        if first_draw_f:
            ax.legend(['f1_train', 'f1_valid'])
        ax.relim()
        ax.autoscale_view()

        self.performance_fig.canvas.draw_idle()

    @staticmethod
    def _axis_upper_limit(values):
        """Return 1.1 x the largest finite value in ``values``, or 1.0 if there is none.

        Ratio modes can produce NaN (0/0) and count modes can be all zero; either
        would make set_xlim/set_ylim fail or collapse the axis to a single point.
        """
        values = np.asarray(values, dtype=float)
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return 1.0
        peak = finite.max()
        return 1.1 * peak if peak > 0 else 1.0

    def _update_figure_trajectory(self):
        """Redraw the train/valid trajectory scatter for the current class and window."""
        self._update_data_trajectory()
        train_x = self.plot_data['trajectory_train_x']
        train_y = self.plot_data['trajectory_train_y']
        valid_x = self.plot_data['trajectory_valid_x']
        valid_y = self.plot_data['trajectory_valid_y']

        self.trace_plot.set_title('Class: ' + self.class_names[int(self.ctrl_class)])

        #   Colour ramps between white (1, 1, 1) and pure blue / red.
        #   NOTE(review): color_interpolation is assumed to blend by these weights — confirm.
        data_length = len(train_x)
        interpolation_values = (np.arange(data_length) + 1) / data_length
        blue_color = color_interpolation(np.array((1, 1, 1)), np.array((0, 0, 1)), interpolation_values)
        red_color = color_interpolation(np.array((1, 1, 1)), np.array((1, 0, 0)), interpolation_values)

        #   Reverse the colours so pure blue / pure red come first and are shown on the legend
        if blue_color.shape[0] > 1:
            blue_color = blue_color[::-1, :]
        if red_color.shape[0] > 1:
            red_color = red_color[::-1, :]

        if self.element_trajectory_train is not None:
            self.element_trajectory_train.remove()
        self.element_trajectory_train = self.trace_plot.scatter(train_x[::-1], train_y[::-1], c=blue_color)

        if self.element_trajectory_valid is not None:
            self.element_trajectory_valid.remove()
        self.element_trajectory_valid = self.trace_plot.scatter(valid_x[::-1], valid_y[::-1], c=red_color)

        #   Adjust axis range (left untouched while 'Hold on Range' is checked)
        if not self.ctrl_hold_range:
            all_x = np.concatenate((train_x, valid_x))
            all_y = np.concatenate((train_y, valid_y))
            if self.ctrl_same_ratio:
                upper = self._axis_upper_limit(np.concatenate((all_x, all_y)))
                self.trace_plot.set_aspect('equal')
                self.trace_plot.set_xlim((0, upper))
                self.trace_plot.set_ylim((0, upper))
            else:
                self.trace_plot.set_aspect('auto')
                self.trace_plot.set_xlim((0, self._axis_upper_limit(all_x)))
                self.trace_plot.set_ylim((0, self._axis_upper_limit(all_y)))

        #   Update axis labels
        if self.ctrl_trajectory_mode not in self._TRAJECTORY_AXIS_LABELS:
            raise ValueError('Unknown trajectory mode: {}'.format(self.ctrl_trajectory_mode))
        xlabel, ylabel = self._TRAJECTORY_AXIS_LABELS[self.ctrl_trajectory_mode]
        self.trace_plot.set_xlabel(xlabel)
        self.trace_plot.set_ylabel(ylabel)

        self.trace_plot.legend(['Train', 'Valid'])

        self.trace_fig.canvas.draw_idle()
예제 #9
0
class CorrViewer(DataViewer):
    """Plots raw correlation data. You need to hold a reference to this object,
    otherwise it will not work in interactive mode.

    Parameters
    ----------
    semilogx : bool
        Whether to plot data with semilogx or not.
    shape : tuple of ints, optional
        Original frame shape. For non-rectangular data you must provide this
        to define the k step.
    size : int, optional
        If specified, perform log_averaging of data with provided size parameter.
        If not given, no averaging is performed.
    norm : int, optional
        Normalization flags used in normalization. They are resolved against
        the data (via ``_default_norm_from_data``) when the figure is initialized.
    scale : bool, optional
        Whether to scale the data during normalization (passed to ``normalize``).
    mask : ndarray, optional
        A boolean array indicating which data elements were computed.
    """
    background = None
    variance = None

    def __init__(self,
                 semilogx=True,
                 shape=None,
                 size=None,
                 norm=None,
                 scale=False,
                 mask=None):
        self.norm = norm
        self.scale = scale
        self.semilogx = semilogx
        self.shape = shape
        self.size = size
        self.computed_mask = mask
        if mask is not None:
            # With a mask, the k-space extent comes from it; otherwise set_data
            # derives it from the data shape.
            self.kisize, self.kjsize = mask.shape

    def set_norm(self, value):
        """Sets norm parameter.

        Requires ``set_data`` to have been called. May raise ``ValueError`` for
        an unsupported flag combination (the check-button callback in
        ``_init_fig`` catches it and reverts the click).
        """
        method = _method_from_data(self.data)
        self.norm = _default_norm_from_data(self.data, method, value)

    def _init_fig(self):
        """Extend the base figure with check buttons for the normalization flags."""
        super()._init_fig()

        self.set_norm(self.norm)

        self.cax = plt.axes([0.44, 0.72, 0.2, 0.15])

        labels = ("structured", "subtracted", "weighted", "compensated")

        # ``&`` binds more loosely than ``==`` in Python, so the masked value
        # must be parenthesized before comparing against a multi-bit flag.
        self.active = [
            bool(self.norm & NORM_STRUCTURED),
            bool(self.norm & NORM_SUBTRACTED),
            bool((self.norm & NORM_WEIGHTED) == NORM_WEIGHTED),
            bool((self.norm & NORM_COMPENSATED) == NORM_COMPENSATED),
        ]

        self.check = CheckButtons(self.cax, labels, self.active)

        def update(label):
            """Rebuild the norm flags from the check boxes and replot."""
            index = labels.index(label)
            status = self.check.get_status()

            norm = NORM_STRUCTURED if status[0] else NORM_STANDARD
            if status[1]:
                norm = norm | NORM_SUBTRACTED
            if status[2]:
                norm = norm | NORM_WEIGHTED
            if status[3]:
                norm = norm | NORM_COMPENSATED
            try:
                self.set_norm(norm)
            except ValueError:
                # Unsupported combination: toggle the clicked box back. With
                # events enabled, CheckButtons.set_active fires this callback again.
                self.check.set_active(index)
            self.set_mask(int(round(self.kindex.val)), self.angleindex.val,
                          self.sectorindex.val, self.kstep)
            self.plot()

        self.check.on_clicked(update)

    def set_data(self, data, background=None, variance=None):
        """Sets correlation data.

        Parameters
        ----------
        data : tuple
            A data tuple (as computed by ccorr, cdiff, adiff, acorr functions)
        background : tuple or ndarray
            Background data for normalization. For adiff, acorr functions this
            is ndarray, for cdiff,ccorr, it is a tuple of ndarrays.
        variance : tuple or ndarray
            Variance data for normalization. For adiff, acorr functions this
            is ndarray, for cdiff,ccorr, it is a tuple of ndarrays.
        """
        self.data = data
        self.background = background
        self.variance = variance
        self.kshape = data[0].shape[:-1]
        if self.computed_mask is None:
            self.kisize, self.kjsize = self.kshape

    def _get_avg_data(self):
        """Normalize the data and average it over the selected elements.

        Returns ``(t, avg)``. ``t`` is the time axis: log-averaged when ``size``
        is set, otherwise ``arange`` over the last axis. ``avg`` is the
        NaN-ignoring mean over the second-to-last axis.
        NOTE(review): ``self.mask`` is presumably set by ``set_mask`` in the base
        class — confirm.
        """
        data = normalize(self.data,
                         self.background,
                         self.variance,
                         norm=self.norm,
                         scale=self.scale,
                         mask=self.mask)

        if self.size is not None:
            t, data = log_average(data, self.size)
        else:
            t = np.arange(data.shape[-1])

        return t, np.nanmean(data, axis=-2)
# Demo: toggle the visibility of three line plots with matplotlib CheckButtons.
# NOTE(review): x, y1, y2 and y3 must already be defined; this snippet does
# not create them.
fig, ax = plt.subplots()
p1, = ax.plot(x, y1, color='red', label='red')
p2, = ax.plot(x, y2, color='blue', label='blue', visible=False)
p3, = ax.plot(x, y3, color='green', label='green', visible=False)
lines = [p1, p2, p3]

plt.axis([-2.5, 12, 0, 40])
# Leave room on the left for the check-box panel.
plt.subplots_adjust(left=0.25, bottom=0.1, right=0.95, top=0.95)

labels = ['red', 'blue', 'green']
# Initial check states mirror each line's visibility, so they cannot drift
# apart from the plot.
actives = [line.get_visible() for line in lines]
axCheckButton = plt.axes([0.03, 0.4, 0.15,
                          0.15])  # left, bottom, width, height
chxbox = CheckButtons(axCheckButton, labels, actives)


def set_visible(label):
    """CheckButtons callback: toggle the line whose label was clicked."""
    line = lines[labels.index(label)]
    line.set_visible(not line.get_visible())
    # Redraw this demo's figure, not whichever figure happens to be current;
    # draw_idle coalesces repeated requests into one render.
    fig.canvas.draw_idle()


cid = chxbox.on_clicked(set_visible)
# The connection id can later be passed to chxbox.disconnect(cid).

print(chxbox.get_status())

# Boxes can also be toggled programmatically, e.g. chxbox.set_active(1).

plt.show()
예제 #11
0
class HeartbeatIntervalFinder(object):
    """
    Interactive cursor for editing heartbeat-signal intervals used for verification.

    Plots the ECG, seismo and phono signals of one recording
    (``files[folder_name][dosage][file_number]``) around its 2D echo time.
    The user can drag the lower/upper interval bounds with the mouse, step
    through dosages with the Previous/Next buttons, hide or rescale the
    individual signals, and pickle the edited intervals to disk.
    """

    # A mouse press within this distance (same units as ``self.time``) of a
    # bound starts dragging that bound.
    BOUND_PICK_THRESHOLD = 2

    # Width of the default interval placed around the echo time when saved
    # intervals are not used.
    INITIAL_INTERVAL_SIZE = 20

    def __init__(self, files,
                 folder_name="",
                 dosage=0,
                 file_number=1,
                 area_around_echo_size=240,
                 use_intervals=False,
                 preloaded_signal=False,
                 save_signal=False):
        """
        Parameters
        ----------
        files : dict
            Nested mapping indexed as ``files[folder_name][dosage][file_number]``.
            Each record provides ``"file_name"`` and ``"echo_time"``. When
            intervals are used or saved, it also provides ``"intervals"``
            (bounds stored at index 1).
        folder_name : str
            Folder key of the recording to display.
        dosage : int
            Initial dosage key; Previous/Next cycle through 0, 10, ..., 40.
        file_number : int
            File key within the folder/dosage record.
        area_around_echo_size : float
            Width of the time window kept around the echo time when raw
            signals are clipped.
        use_intervals : bool
            Start from the saved ``"intervals"`` instead of a default window.
        preloaded_signal : bool
            Load previously clipped signals from the CSVs in
            ``data/Derived/signals``.
        save_signal : bool
            First clip and save the signals of every dosage of the folder.
        """
        super(HeartbeatIntervalFinder, self).__init__()
        # Load Arguments
        self.files = files
        self.folder_name = folder_name
        self.dosage = dosage
        self.file_number = file_number
        self.file_name = files[folder_name][dosage][file_number]["file_name"]

        self.use_intervals = use_intervals
        self.area_around_echo_size = area_around_echo_size
        self.preloaded_signal = preloaded_signal
        self.update_point = None      # which bound is being dragged, if any
        self.new_bounds = False       # refit axis limits on the next redraw

        # Save signals
        if save_signal:
            self.save_signals()

        # Load Signals and 2D Echo Time
        self.load_signals()

        # Determine Original bounds
        self.initialize_bounds()

        # Plot signals
        self.plot_signals()

    # ------------------------------------------------------------------ helpers

    def _file_entry(self):
        """Return the ``files`` record of the current folder, dosage and file number."""
        return self.files[self.folder_name][self.dosage][self.file_number]

    @staticmethod
    def _signal_path(kind, save_file_name):
        """Return the CSV path of one saved signal (kind: time/signal/seis/phono)."""
        return 'data/Derived/signals/' + kind + '_' + save_file_name + '.csv'

    @staticmethod
    def _centered_window(center, min_time, max_time, size):
        """Return ``(start, end)`` of a window of width ``size`` centred on ``center``.

        The window is shifted, not shrunk, to stay inside
        ``[min_time, max_time]``. If that range is narrower than ``size``, the
        whole range is returned.
        """
        if (max_time - min_time) < size:
            return min_time, max_time
        if center - size / 2 < min_time:
            return min_time, min_time + size
        if center + size / 2 > max_time:
            return max_time - size, max_time
        return center - size / 2, center + size / 2

    def _reset_limits(self):
        """Fit the x/y limits to ``time``/``signal`` with a 10% margin on each side."""
        t_span = self.time[-1] - self.time[0]
        sig_min = min(self.signal)
        sig_max = max(self.signal)
        sig_span = sig_max - sig_min
        self.ax.set_xlim(self.time[0] - 0.1 * t_span, self.time[-1] + 0.1 * t_span)
        self.ax.set_ylim(sig_min - 0.1 * sig_span, sig_max + 0.1 * sig_span)

    def _add_slider(self, rect, label, valmin, valmax, valinit):
        """Create a vertical slider in a new axes at ``rect`` that redraws the signals."""
        slider = Slider(plt.axes(rect),
                        label=label,
                        valmin=valmin,
                        valmax=valmax,
                        valinit=valinit,
                        orientation='vertical')
        slider.on_changed(self.switch_signal)
        slider.valtext.set_visible(False)
        return slider

    # --------------------------------------------------------------- signal I/O

    def clip_signals(self):
        """Clip the raw signals to the window around the echo time, then filter them."""
        self._clip_to_echo_window()
        self._filter_signals()

    def _clip_to_echo_window(self):
        """Keep only the samples inside an ``area_around_echo_size`` window around the echo.

        Assumes ``self.time`` is sorted ascending (required by ``searchsorted``).
        Stores the kept index range in ``self.interval_near_echo``.
        """
        start_time, end_time = self._centered_window(self.echo_time,
                                                     np.min(self.time),
                                                     np.max(self.time),
                                                     self.area_around_echo_size)
        # side="right" makes the window inclusive of its end time. Previously
        # the sample at the end time (e.g. the final sample when the window
        # ended at max_time) was dropped.
        start = int(np.searchsorted(self.time, start_time))
        end = int(np.searchsorted(self.time, end_time, side="right"))
        interval = range(start, end)

        self.interval_near_echo = interval
        self.time = self.time[interval]
        self.signal = self.signal[interval]
        self.seis = self.seis[interval]
        self.phono = self.phono[interval]

    def _filter_signals(self):
        """Apply ``hb`` filters: 59-61 Hz band filter on all signals, then low-pass ECG/seismo."""
        # NOTE(review): the 59-61 Hz band surrounds mains frequency; confirm
        # whether hb.bandpass_filter passes or rejects this band.
        self.signal = hb.bandpass_filter(time=self.time,
                                         signal=self.signal,
                                         freqmin=59,
                                         freqmax=61)

        self.seis = hb.bandpass_filter(time=self.time,
                                       signal=self.seis,
                                       freqmin=59,
                                       freqmax=61)

        self.phono = hb.bandpass_filter(time=self.time,
                                        signal=self.phono,
                                        freqmin=59,
                                        freqmax=61)

        self.signal = hb.lowpass_filter(time=self.time,
                                        signal=self.signal,
                                        cutoff_freq=10)

        self.seis = hb.lowpass_filter(time=self.time,
                                      signal=self.seis,
                                      cutoff_freq=50)

    def initialize_bounds(self):
        """Set ``lower_bound``/``upper_bound`` from saved intervals or around the echo time."""
        if self.use_intervals:
            saved = self._file_entry()["intervals"][1]
            self.lower_bound = saved[0]
            self.upper_bound = saved[1]
        else:
            self.lower_bound, self.upper_bound = self._centered_window(
                self.echo_time,
                np.min(self.time),
                np.max(self.time),
                self.INITIAL_INTERVAL_SIZE)

    # ----------------------------------------------------------------- plotting

    def plot_signals(self):
        """Create the interactive figure (signals, bounds, controls) and show it."""
        # Create figure
        self.fig, self.ax = plt.subplots()

        self.ax.get_yaxis().set_visible(False)
        self.ax.set_xlabel("Time [s]")

        # Plot ECG, Phono and Seismo
        self.signal_line, = self.ax.plot(self.time, self.signal, label="ECG", linewidth=0.5, c="b")
        self.seis_line, = self.ax.plot(self.time, self.seis, label="Seismo", linewidth=0.5, c="r")
        self.phono_line, = self.ax.plot(self.time, self.phono, label="Phono", linewidth=0.5, c="g")

        self._reset_limits()

        # Echo line. axvline's ymin/ymax are fractions of the axes height, not
        # data values. The defaults (0, 1) span the full height; the signal
        # values passed before could push the line completely off the axes.
        self.echo_line = self.ax.axvline(self.echo_time,
                                         label="2D Echo Time", c="k", linewidth=2)
        self.ax.legend(loc="upper right")

        # Shaded region between the bounds.
        self.bound_span = self.ax.axvspan(self.lower_bound,
                                          self.upper_bound,
                                          facecolor='g', alpha=0.25)

        # Bound labels below the x axis (x in data coords, y in axes coords).
        self.bound_text_height = -0.1
        self.lower_bound_text = self.ax.text(self.lower_bound, self.bound_text_height,
                                             transform=self.ax.get_xaxis_transform(),
                                             s="Lower Bound\n" + str(self.lower_bound),
                                             fontsize=12, horizontalalignment='center')
        self.upper_bound_text = self.ax.text(self.upper_bound, self.bound_text_height,
                                             transform=self.ax.get_xaxis_transform(),
                                             s="Upper Bound\n" + str(self.upper_bound),
                                             fontsize=12, horizontalalignment='center')

        # Points the cross-hair snaps to (see mouse_move).
        self.x = self.time
        self.y = self.signal

        # Cross hairs
        self.lx = self.ax.axhline(color='k', linewidth=0.2)  # the horiz line
        self.ly = self.ax.axvline(color='k', linewidth=0.2)  # the vert line

        # Record information (axes coordinates)
        left_shift = 0.45
        start = 0.96
        space = 0.04
        self.folder_text = self.ax.text(0.01, start, transform=self.ax.transAxes,
                                        s="Folder: " + self.folder_name,
                                        fontsize=12, horizontalalignment='left')
        self.dosage_text = self.ax.text(0.61 - left_shift, 1.1 - space, transform=self.ax.transAxes,
                                        s="Dosage: " + str(self.dosage),
                                        fontsize=12, horizontalalignment='left')
        self.file_name_text = self.ax.text(0.01, start - space, transform=self.ax.transAxes,
                                           s="File: " + self._file_entry()["file_name"],
                                           fontsize=12, horizontalalignment='left')

        # Dosage navigation buttons
        ax_prev = plt.axes([0.575 - left_shift, 0.9, 0.1, 0.075])
        self.bprev = Button(ax_prev, 'Previous')
        self.bprev.on_clicked(self.prev)

        ax_next = plt.axes([0.8 - left_shift, 0.9, 0.1, 0.075])
        self.b_next = Button(ax_next, 'Next')
        self.b_next.on_clicked(self.next)

        self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_move)

        # Add Save Button
        ax_save = plt.axes([0.8, 0.9, 0.1, 0.075])
        self.b_save = Button(ax_save, 'Save')
        self.b_save.on_clicked(self.save_intervals)

        # Line hide check boxes (checked = hidden, see switch_signal)
        self.ax.text(1.015, 0.97, transform=self.ax.transAxes,
                     s="Hide Signal:", fontsize=12, horizontalalignment='left')
        ax_switch_signals = plt.axes([0.91, 0.7, 0.07, 0.15])
        self.b_switch_signals = CheckButtons(ax_switch_signals, ('ECG', 'Seismo', 'Phono'))

        self.b_switch_signals.on_clicked(self.switch_signal)

        # Amplitude ("A") and vertical offset ("H") sliders
        self.signal_amp_slider = self._add_slider([0.91, 0.15, 0.01, 0.475],
                                                  "ECG\nA", 0.01, 10, 1)
        self.seis_height_slider = self._add_slider([0.93, 0.15, 0.01, 0.475],
                                                   "   Seis\nH",
                                                   1.5 * min(self.signal),
                                                   1.5 * max(self.signal), 0)
        self.seis_amp_slider = self._add_slider([0.94, 0.15, 0.01, 0.475],
                                                "\nA", 0.01, 10, 1)
        self.phono_height_slider = self._add_slider([0.96, 0.15, 0.01, 0.475],
                                                    "    Phono\nH",
                                                    1.5 * min(self.signal),
                                                    1.5 * max(self.signal), 0)
        self.phono_amp_slider = self._add_slider([0.97, 0.15, 0.01, 0.475],
                                                 "A", .01, 10, 1)

        # Add Clicking Actions
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        self.fig.canvas.mpl_connect('button_release_event', self.off_click)

        # Maximize frame
        mng = plt.get_current_fig_manager()
        mng.full_screen_toggle()

        plt.show()

    def switch_signal(self, label):
        """Redraw the lines, the bound span/labels and (if requested) the limits.

        This is the callback of the hide check boxes and of every slider, so
        ``label`` may be a check-box label or a slider value. It is ignored;
        the check-box status is read instead.
        """
        self.x = self.time
        # NOTE(review): after any control change the cross-hair y is pinned to
        # the ECG mean (initially it follows self.signal); confirm intended.
        self.y = [np.mean(self.signal)] * len(self.time)

        hidden = self.b_switch_signals.get_status()
        self.signal_line.set_data(self.time, self.signal_amp_slider.val * self.signal)
        self.seis_line.set_data(self.time, (self.seis_amp_slider.val * self.seis) + self.seis_height_slider.val)
        self.phono_line.set_data(self.time, (self.phono_amp_slider.val * self.phono) + self.phono_height_slider.val)

        # A checked box hides its line by giving it zero width.
        for line, is_hidden in zip((self.signal_line, self.seis_line, self.phono_line), hidden):
            line.set_linewidth(0 if is_hidden else 0.5)

        # Update green shaded region
        self.bound_span.remove()
        self.bound_span = self.ax.axvspan(self.lower_bound, self.upper_bound, facecolor='g', alpha=0.25)

        self.lower_bound_text.set_position((self.lower_bound, self.bound_text_height))
        self.lower_bound_text.set_text("Lower Bound\n" + str(self.lower_bound))

        self.upper_bound_text.set_position((self.upper_bound, self.bound_text_height))
        self.upper_bound_text.set_text("Upper Bound\n" + str(self.upper_bound))

        # Refit x and y limits after a dosage change
        if self.new_bounds:
            self._reset_limits()
            self.new_bounds = False

        self.fig.canvas.draw()

    # ------------------------------------------------------------ mouse events

    def on_click(self, event):
        """Start dragging a bound when a press lands in the signal axes near one."""
        self.update_point = None
        self.new_bounds = False

        # Only presses in the signal axes count; widget axes (buttons, sliders,
        # check boxes) report xdata in their own coordinates. Comparing the axes
        # object replaces a check on the class-name string "AxesSubplot", which
        # is not the class name reported by newer matplotlib releases.
        if event.inaxes is not self.ax or event.xdata is None:
            return

        if abs(self.lower_bound - event.xdata) < self.BOUND_PICK_THRESHOLD:
            self.update_point = "lower_bound"
        elif abs(self.upper_bound - event.xdata) < self.BOUND_PICK_THRESHOLD:
            self.update_point = "upper_bound"
        else:
            return

        # Highlight the cross-hair while a bound is being dragged.
        self.lx.set_color('g')
        self.ly.set_color('g')

        self.lx.set_linewidth(1)
        self.ly.set_linewidth(1)

        self.fig.canvas.draw()

    def off_click(self, event):
        """Finish a drag: restore the cross-hair and commit the dragged bound."""
        self.lx.set_color('k')
        self.ly.set_color('k')

        self.lx.set_linewidth(0.2)
        self.ly.set_linewidth(0.2)

        if event.xdata is not None:
            # Commit only releases over the signal axes; elsewhere xdata is in
            # a widget's coordinates and would produce a meaningless bound.
            if event.inaxes is self.ax:
                if self.update_point == "lower_bound":
                    self.lower_bound = max(self.time[0], round(event.xdata, 1))
                elif self.update_point == "upper_bound":
                    self.upper_bound = min(self.time[-1], round(event.xdata, 1))

            # Keep the bounds ordered even if one was dragged past the other.
            lower = min(self.lower_bound, self.upper_bound)
            upper = max(self.lower_bound, self.upper_bound)
            self.lower_bound = lower
            self.upper_bound = upper

            # Update on signal
            self.switch_signal(self.b_switch_signals.get_status())

        self.update_point = None
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

    def mouse_move(self, event):
        """Move the cross-hair to the first data point at or after the cursor's x."""
        # Ignore motion outside the signal axes (including widget axes).
        if event.inaxes is not self.ax:
            return

        indx = min(np.searchsorted(self.x, event.xdata), len(self.x) - 1)
        x = self.x[indx]
        y = self.y[indx]

        # set_xdata/set_ydata expect sequences; scalars are deprecated in newer matplotlib.
        self.lx.set_ydata([y, y])
        self.ly.set_xdata([x, x])

        # Coalesce redraws while the mouse moves.
        self.ax.figure.canvas.draw_idle()

    # --------------------------------------------------------------- navigation

    def next(self, event):
        """Advance to the next dosage (0 -> 10 -> ... -> 40 -> 0), saving intervals first if used."""
        if self.use_intervals:
            self.save_intervals(None)
        self.dosage += 10
        if self.dosage > 40:
            self.dosage = 0

        self.new_bounds = True
        self.update_plot()

    def prev(self, event):
        """Go back to the previous dosage (0 -> 40 -> ... -> 0), saving intervals first if used."""
        if self.use_intervals:
            self.save_intervals(None)
        self.dosage -= 10
        if self.dosage < 0:
            self.dosage = 40

        self.new_bounds = True
        self.update_plot()

    # ------------------------------------------------------------------- saving

    def save_intervals(self, event):
        """Store the current bounds in ``files`` and pickle the whole mapping.

        Writes ``data/Derived/intervals/Interval_Dict_<folder_name>.pkl`` and
        switches to saved intervals from then on. ``event`` is unused.
        """
        save_filename = "data/Derived/intervals/Interval_Dict_" + self.folder_name
        bounds = self._file_entry()["intervals"][1]
        bounds[0] = self.lower_bound
        bounds[1] = self.upper_bound

        with open(save_filename + '.pkl', 'wb') as output:
            pickle.dump(self.files, output, pickle.HIGHEST_PROTOCOL)
        print("Saved Intervals")
        self.use_intervals = True

    def save_signals(self):
        """Clip, filter and save the signals of every dosage of the folder as CSV.

        Afterwards signals are loaded from these files (``preloaded_signal``).
        The signal attributes are left holding the last dosage processed.
        """
        print("Saving Signals...")
        for dosage in self.files[self.folder_name]:
            # Load data
            self.time, self.signal, self.seis, _, self.phono, _ = hb.load_file_data(files=self.files,
                                                                                    folder_name=self.folder_name,
                                                                                    dosage=dosage,
                                                                                    file_number=self.file_number)

            # Load Echo Time
            self.echo_time = self.files[self.folder_name][dosage][self.file_number]["echo_time"]

            # Clip signal size about echo time
            self.clip_signals()

            save_file_name = self.folder_name + "_d" + str(dosage)
            # NOTE(review): this guard is skipped under ``python -O``.
            assert not os.path.isfile(self._signal_path("time", save_file_name)), "Saved signals already exist, please delete before saving new signals."

            np.savetxt(self._signal_path("time", save_file_name), self.time, delimiter=',')
            np.savetxt(self._signal_path("signal", save_file_name), self.signal, delimiter=',')
            np.savetxt(self._signal_path("seis", save_file_name), self.seis, delimiter=',')
            np.savetxt(self._signal_path("phono", save_file_name), self.phono, delimiter=',')

            print("\tDosage " + str(dosage) + " done")

        print("...Done Saving Signals")

        # Use these saved files from now on
        self.preloaded_signal = True

    def load_signals(self):
        """Load the echo time and signals of the current dosage (CSV if preloaded, else raw + clip)."""
        self.echo_time = self._file_entry()["echo_time"]

        if self.preloaded_signal:
            save_file_name = self.folder_name + "_d" + str(self.dosage)
            self.time = np.loadtxt(self._signal_path("time", save_file_name), delimiter=',')
            self.signal = np.loadtxt(self._signal_path("signal", save_file_name), delimiter=',')
            self.seis = np.loadtxt(self._signal_path("seis", save_file_name), delimiter=',')
            self.phono = np.loadtxt(self._signal_path("phono", save_file_name), delimiter=',')
        else:
            self.time, self.signal, self.seis, _, self.phono, _ = hb.load_file_data(files=self.files,
                                                                                    folder_name=self.folder_name,
                                                                                    dosage=self.dosage,
                                                                                    file_number=self.file_number)
            self.clip_signals()

    def update_plot(self):
        """Reload the current dosage and refresh every artist that depends on it."""
        # Show a loading message immediately, since loading can take a while.
        self.dosage_text.set_text("Loading: " + str(self.dosage))
        print("Loading: " + str(self.dosage))
        self.fig.canvas.draw()

        # Keep file_name in sync with the newly selected dosage.
        self.file_name = self._file_entry()["file_name"]

        self.folder_text.set_text("Folder: " + self.folder_name)
        self.dosage_text.set_text("Dosage: " + str(self.dosage))
        self.file_name_text.set_text("File: " + self.file_name)

        self.load_signals()

        # set_xdata expects a sequence; axvline data is [x, x].
        self.echo_line.set_xdata([self.echo_time, self.echo_time])

        self.initialize_bounds()

        self.switch_signal(self.b_switch_signals.get_status())
        print("done")
예제 #12
0
파일: MF_Vis.py 프로젝트: aymanalz/gw_utils
class Vis(object):
    def __init__(self, fn):
        """Load a MODFLOW model and its head output, then build the viewer window.

        Parameters
        ----------
        fn : str
            Path passed to ``flopy.modflow.Modflow.load`` (the MODFLOW name
            file); its directory is used as the model workspace.
        """
        mf = flopy.modflow.Modflow.load(fn, model_ws=os.path.dirname(fn), load_only=['DIS', 'BAS6', 'UPW'])
        mf_files = gw_utils.get_mf_files(fn)
        hds_fn = mf_files['hds'][1]
        hob_out_fn = mf_files['HOB'][1]
        self.fn_hobOut = hob_out_fn
        try:  # text (formatted) head file
            import flopy.utils.formattedfile as ff
            hds = ff.FormattedHeadFile(hds_fn, precision='single')
        except Exception:
            # Not a readable formatted file: fall back to the binary reader.
            # (A bare ``except:`` would also swallow KeyboardInterrupt/SystemExit.)
            hds = flopy.utils.HeadFile(hds_fn)

        self.mf = mf
        self.hds = hds
        self.layout()

    def layout(self):
        """Build the main map window and show it.

        Creates the figure, the navigation buttons, the "MF Package"
        data-mode radio buttons, the "Point Layer" check boxes and the
        layer/time labels, and initializes the viewer state.
        """
        # show main figure
        fig, ax = plt.subplots()
        self.fig = fig
        self.ax = ax
        self.all_times = self.hds.get_times()

        self.curr_time_index = 0
        self.curr_layer = 0
        self.togle_TS = 0
        self.curr_label = 'Grid Elev'

        # window.showMaximized() only exists on Qt-based backends; skip it
        # elsewhere instead of raising AttributeError.
        window = getattr(plt.get_current_fig_manager(), 'window', None)
        if hasattr(window, 'showMaximized'):
            window.showMaximized()

        # add  buttons
        ax_Extra = plt.axes([0.5, 0.02, 0.05, 0.05])
        self.Extrabn = Button(ax_Extra, 'Extra')
        self.Extrabn.on_clicked(self.strat_extra_frame)

        ax_timeseries = plt.axes([0.55, 0.02, 0.05, 0.05])
        self.Tsbn = Button(ax_timeseries, 'Hydrograph')
        self.Tsbn.on_clicked(self.PotTS)

        axBackward = plt.axes([0.6, 0.02, 0.05, 0.05])
        self.Bkbn = Button(axBackward, '<<')
        self.Bkbn.on_clicked(self.plot_backward)

        axForward = plt.axes([0.65, 0.02, 0.05, 0.05])
        self.Forbn = Button(axForward, '>>')
        self.Forbn.on_clicked(self.plot_forward)

        axUP = plt.axes([0.7, 0.02, 0.05, 0.05])
        self.upbn = Button(axUP, 'Up')
        self.upbn.on_clicked(self.changeLayerUp)

        axDN = plt.axes([0.75, 0.02, 0.05, 0.05])
        self.Dnbn = Button(axDN, 'Down')
        self.Dnbn.on_clicked(self.chaneLayerDn)

        plt.text(0.01, 0.875, "MF Package", fontsize=14, transform=plt.gcf().transFigure)
        rax = plt.axes([0.01, 0.65, 0.09, 0.22])
        # Keep a reference on self: matplotlib widgets stop responding once the
        # widget object is no longer referenced (e.g. when plt.show() returns
        # immediately in interactive mode).
        self.radio = RadioButtons(rax, ('Grid Elev', 'Heads', 'Drawdown', 'flow direction', 'Grid Thickness', 'IBOUND', 'STRT', 'HK', 'VK', 'SS', 'SY', 'GWDepth'))
        self.radio.on_clicked(self.data_mode)

        plt.text(0.01, 0.16, "Point Layer", fontsize=14, transform=plt.gcf().transFigure)
        cax = plt.axes([0.01, 0.001, 0.09, 0.15])
        self.CheckedPointlabels = []
        self.checkPoints = CheckButtons(cax, ['HOBS', 'WELLS', 'GAGES'], actives=(False, False, False))
        self.checkPoints.on_clicked(self.plot_points)

        self.LayerTxt = plt.text(0.8, 0.87, "Layer {}".format(self.curr_layer), fontsize=14,
                                 transform=plt.gcf().transFigure)

        self.TimeTxt = plt.text(0.2, 0.87, "Totim {}".format(self.all_times[self.curr_time_index]), fontsize=14,
                                transform=plt.gcf().transFigure)

        plt.show()

    def strat_extra_frame(self, event):
        self.plot_thickness_trans()


    def plot_thickness_trans(self):
        """Show maps of log10 total transmissivity and total thickness over active cells.

        Both maps sum per-layer values (thickness, and thickness * HK) over the
        active layers (IBOUND != 0). Columns with no active layer are NaN, so
        they appear blank.
        """
        thk = self.mf.dis.thickness.array
        # Build a 0/1 activity mask without writing into the array returned by
        # flopy. The old code overwrote it in place, and that array may be
        # shared with the model.
        active = (self.mf.bas6.ibound.array != 0).astype(int)
        hk = self.mf.upw.hk.array

        total_thickness = (thk * active).sum(axis=0)
        trans = (thk * active * hk).sum(axis=0)
        # 1 where any layer is active, NaN elsewhere (avoids the 0/0 warning of ib2 / ib2).
        column_mask = np.where(active.sum(axis=0) > 0, 1.0, np.nan)

        fig1, ax1 = plt.subplots()
        s1 = ax1.imshow(np.log10(trans * column_mask), cmap='jet')
        ax1.set_title("Log Total Transmissivity")
        fig1.colorbar(s1, ax=ax1)
        fig1.canvas.draw()

        fig2, ax2 = plt.subplots()
        s2 = ax2.imshow(total_thickness * column_mask, cmap='plasma')
        # The thickness map is linear, not log-scaled.
        ax2.set_title("Total Thickness")
        fig2.colorbar(s2, ax=ax2)
        fig2.canvas.draw()

        plt.show()

    def plot_points(self, event):
        """Redraw the point layers selected in the "Point Layer" check boxes.

        Re-reads the HOB input/output tables (``self.hobin``/``self.houout``).
        When 'HOBS' is checked, it draws the HOB locations as pickable markers
        wired to ``on_pick_hob``. ``event`` (the clicked label) is unused; the
        full check-box status is read instead.
        """
        from gw_utils import hob_output_to_df, in_hob_to_df  # ugly
        fn = os.path.join(self.mf.model_ws, self.mf.namefile)

        self.hobin = in_hob_to_df(mfname=fn)
        self.houout = hob_output_to_df(fn, mf=self.mf)

        # Rebuild the selection from the current status. Appending to the old
        # list made labels (and HOB plots) pile up with every click.
        status = self.checkPoints.get_status()
        self.CheckedPointlabels = [label.get_text()
                                   for label, checked in zip(self.checkPoints.labels, status)
                                   if checked]

        # Remove HOB markers from a previous call, so unchecking hides them and
        # re-checking does not stack duplicates.
        for line in getattr(self, 'axHob', None) or []:
            try:
                line.remove()
            except (ValueError, NotImplementedError):
                pass  # already removed, e.g. by an ax.clear() redraw
        self.axHob = []

        if 'HOBS' in self.CheckedPointlabels:
            self.axHob = self.ax.plot(self.hobin['col'], self.hobin['row'], marker='.',
                                      markeredgecolor='b', picker=5, linestyle='None')
            # Connect the pick handler once. Reconnecting on every click made a
            # single pick run on_pick_hob once per earlier click.
            if getattr(self, '_hob_pick_cid', None) is None:
                self._hob_pick_cid = self.fig.canvas.mpl_connect('pick_event', self.on_pick_hob)

        self.fig.canvas.draw()

    def on_pick_hob(self, event):
        # make code to select the colosets
        try:
            self.AxSelection.remove()
        except:
            pass

        ind = event.ind
        if len(ind) > 1:
            datax, datay = event.artist.get_data()
            datax, datay = [datax[i] for i in ind], [datay[i] for i in ind]
            msx, msy = event.mouseevent.xdata, event.mouseevent.ydata
            dist = np.sqrt((np.array(datax) - msx) ** 2 + (np.array(datay) - msy) ** 2)
            close_i = np.argmin(dist)
            ind = [ind[np.argmin(dist)]]
        x= datax[close_i]
        y = datay[close_i]

        self.AxSelection = self.ax.scatter(x, y,
                                           marker='*', c='r',
                                           s=200)

        name = self.hobin.iloc[ind]['name'].values[0]
        hrow =self.hobin.iloc[ind]['row'].values[0]
        hcol = self.hobin.iloc[ind]['col'].values[0]

        curr_hob_in = self.hobin[(self.hobin['row'] == hrow) & (self.hobin['col'] == hcol)]
        names = curr_hob_in['name'].unique()

        curr_obs_df = self.houout[self.houout['OBSERVATION NAME'].isin(names)]
        curr_obs_df = curr_obs_df.drop_duplicates(subset='OBSERVATION NAME', keep='last')
        curr_hob_in = curr_hob_in.drop_duplicates(subset='name', keep='first')

        self.plot_head_ts(x, y)

        tim = curr_hob_in['totim'].values
        obs_values =curr_obs_df['OBSERVED VALUE'].values
        sim_values = curr_obs_df['SIMULATED EQUIVALENT'].values

        maskSim = 0
        sim_values[sim_values==maskSim] = np.NaN

        self.axhydro.scatter(tim, obs_values, c = 'r', label='OBS')
        self.axhydro.scatter(tim, sim_values, c= 'b', label='SIM')
        self.axhydro.legend()
        title = self.axhydro.get_title()
        if "." in names[0]:
            name = names[0].split(".")[0]
        else:
            name = names[0]
        title = title + "\n" + name
        kkeys = curr_hob_in['mlay'].values[0].keys()
        vals = curr_hob_in['mlay'].values[0].values()

        kval = zip(kkeys, vals)
        title = title + "\n"
        for kv in list(kval):
            title = title + " {} : {}, ".format(kv[0], kv[1])
        self.axhydro.set_title(title)

        event.canvas.draw()
        event.canvas.flush_events()

    def get_arr(self):
        if self.curr_label in ['Heads', 'flow direction']:
            totim = self.all_times[self.curr_time_index]
            self.maxDataLayer = self.mf.nlay
            arr = self.hds.get_data(totim=totim)

        elif self.curr_label == 'Grid Elev':
            self.maxDataLayer = self.mf.nlay
            arr = np.zeros((self.mf.nlay+1, self.mf.nrow, self.mf.ncol))
            arr[0,:,:] = self.mf.dis.top.array
            arr[1:, :, :] = self.mf.dis.botm.array

        elif self.curr_label =='Drawdown':
            totim = self.all_times[self.curr_time_index]
            totim0 = self.all_times[0]
            self.maxDataLayer = self.mf.nlay
            arr = self.hds.get_data(totim=totim)-self.hds.get_data(totim=totim0)

        elif self.curr_label =='HK':
            arr = self.mf.upw.hk.array

        elif self.curr_label =='VK':
            arr = self.mf.upw.vka.array

        elif self.curr_label =='IBOUND':
            arr = self.mf.bas6.ibound.array

        elif self.curr_label == 'STRT':
            arr = self.mf.bas6.strt.array

        elif self.curr_label == 'SS':
            arr = self.mf.upw.ss.array

        elif self.curr_label == 'SY':
            arr = self.mf.upw.sy.array

        elif self.curr_label == 'GWDepth':
            totim = self.all_times[self.curr_time_index]
            self.maxDataLayer = self.mf.nlay
            arr = self.hds.get_data(totim=totim)
            arr = arr.copy()
            ttop = self.mf.dis.top.array.copy()
            for k in range(self.mf.nlay):
                arr[k, :,:] = ttop - arr[k,:,:]


        return arr

    def data_mode(self, label):
        self.curr_label = label


    def PotTS(self, event):
        if self.togle_TS == 0:
            self.togle_TS = 1
            self.Tsbn.color = 'red'
            cid1 = self.fig.canvas.mpl_connect('button_press_event', lambda event: self.click_hydrograph(event))
            print("red")

        else:
            self.togle_TS = 0
            self.Tsbn.color = 'green'
            cid2 = self.fig.canvas.mpl_connect('close_event', lambda event: self.close_hydro(event))

    def close_hydro(self, event):
        pass

    def plot_head_ts(self, x, y, plot_layers=False):
        """Open a figure with the simulated head time series of every layer at one cell.

        Parameters
        ----------
        x, y : float
            Column/row position in map (image) coordinates. It is converted to
            model coordinates before the cell is located.
        plot_layers : bool, optional
            Also draw a horizontal line for each layer elevation. Defaults to
            False, which matches the previously disabled behaviour (that dead
            branch referenced an undefined ``ax``).

        Sets ``self.fighydro``/``self.axhydro`` to the new figure/axes.
        """
        Ly = self.mf.dis.delc.array.sum()
        Lx = self.mf.dis.delr.array.sum()
        # Map column/row -> model x/y; model y grows upward, image rows downward.
        x = Lx * x / self.mf.ncol
        y = Ly - Ly * y / self.mf.nrow

        rr_cc1 = plotutil.findrowcolumn((x, y), self.mf.modelgrid.xyedges[0],
                                        self.mf.modelgrid.xyedges[1])
        row, col = rr_cc1[0], rr_cc1[1]
        self.fighydro, self.axhydro = plt.subplots()

        elevs = []
        pre_elev = 0
        ts = None
        for k in range(self.mf.nlay):
            rr_cc = (k, row, col)
            ib = self.mf.bas6.ibound.array[k, row, col]
            # Layer elevations: the model top, then each active layer's bottom
            # (inactive layers repeat the previous elevation).
            if k == 0:
                elev = self.mf.dis.top.array[row, col]
                pre_elev = elev
                elevs.append(elev)

            if ib != 0:
                elev = self.mf.dis.botm.array[k, row, col]
                pre_elev = elev
                elevs.append(elev)
            else:
                elevs.append(pre_elev)

            ts = self.hds.get_ts(rr_cc)
            # Blank out no-flow heads (np.NaN was removed in NumPy 2.0).
            ts[:, 1][ts[:, 1] == self.mf.bas6.hnoflo] = np.nan

            lb = "Head {}".format(k + 1)
            self.axhydro.plot(ts[:, 0], ts[:, 1], label=lb)

        # plot layers
        if plot_layers and ts is not None:
            for ielev, elev in enumerate(elevs):
                ts2 = ts.copy()
                ts2[:, 1] = elev
                self.axhydro.plot(ts2[:, 0], ts2[:, 1], label="Layer {}".format(ielev))

        self.axhydro.legend()
        self.axhydro.set_title("Row = {}, Col= {}".format(row, col))
        plt.show()

    def click_hydrograph(self, event):
        """Mouse-press callback: mark the clicked cell and plot its heads.

        The click is ignored when the hydrograph mode is off
        (``self.togle_TS == 0``), when it did not land in the map axes
        ``self.ax`` (e.g. on a button or the colorbar), or when it has no data
        coordinates.  Otherwise the previous red marker is replaced and
        ``self.plot_head_ts`` is called with the clicked position.
        """
        if self.togle_TS == 0:
            return
        if (event.inaxes is not self.ax
                or event.xdata is None or event.ydata is None):
            return
        # Best effort: there is no marker on the first click, and the old one
        # may already be gone after self.ax.clear().
        try:
            self.axHeadLocation.remove()
        except Exception:
            pass
        self.axHeadLocation = self.ax.scatter(event.xdata, event.ydata, color='r')
        event.canvas.draw()
        event.canvas.flush_events()

        self.plot_head_ts(event.xdata, event.ydata)


    def chaneLayerDn(self, event):
        """Button callback: move to the next layer (index + 1) and redraw.

        Returns False, leaving ``self.curr_layer`` unchanged, when already on
        the last layer; otherwise redraws, prints the new index and returns
        None.
        """
        target = self.curr_layer + 1
        if target >= self.mf.nlay:
            return False
        self.curr_layer = target
        self.update_fig()
        print(self.curr_layer)

    def changeLayerUp(self, event):
        """Button callback: move to the previous layer (index - 1) and redraw.

        Returns False, leaving ``self.curr_layer`` unchanged, when already on
        layer 0; otherwise redraws, prints the new index and returns None.
        """
        target = self.curr_layer - 1
        if target < 0:
            return False
        self.curr_layer = target
        self.update_fig()
        print(self.curr_layer)

    def update_fig(self):
        """Redraw the map for the current selection.

        Fetches the data with ``self.get_arr()``, then renders it with
        ``self._plot_dir`` when the selected quantity is ``'flow direction'``
        and with ``self._plot_arr`` for anything else.
        """
        data = self.get_arr()
        wants_arrows = self.curr_label == 'flow direction'
        renderer = self._plot_dir if wants_arrows else self._plot_arr
        renderer(data)

    def _plot_dir(self, arr):
        """Draw flow-direction arrows and head contours for the current layer.

        Parameters
        ----------
        arr : array
            Indexed as ``arr[layer, row, col]``; it is not modified.

        Cells whose ``ibound`` is 0 are masked with NaN.  Arrows follow the
        negative head gradient (high head -> low head).  The x component is
        the derivative along columns and the y component the derivative along
        rows.  Arrow lengths are log-compressed and coloured with a
        magnitude-like value shown in the colorbar.
        """
        self.ax.clear()
        # Float copy: NaN masking must not write into the caller's array and
        # would fail on an integer array.
        head = np.array(arr[self.curr_layer, :, :], dtype=float)
        ibound = self.mf.bas6.ibound.array[self.curr_layer, :, :]
        head[ibound == 0] = np.nan

        X, Y = np.meshgrid(np.arange(0, self.mf.ncol, 1),
                           np.arange(0, self.mf.nrow, 1))
        # np.gradient returns the derivative along axis 0 (rows, y) first and
        # along axis 1 (columns, x) second; using the first as the x component
        # transposes the arrows.
        d_row, d_col = np.gradient(head)
        # Colour values: same expression as before, offset n = -2.
        n = -2
        color = np.sqrt(((d_row - n) / 2) ** 2 + ((d_col - n) / 2) ** 2)

        def compress(v):
            # signed log scaling: keeps the sign, shrinks large magnitudes
            return np.sign(v) * np.log10(1 + np.abs(v))

        # angles="xy": directions are interpreted in data coordinates, so they
        # stay correct after invert_yaxis() below.
        im = self.ax.quiver(X, Y, compress(-d_col), compress(-d_row), color,
                            scale=0.5, units="xy", angles="xy")

        self.ax.contour(head, colors='k')
        self.ax.invert_yaxis()

        # Draw the labels on self.fig explicitly: plt.gcf() can be a
        # hydrograph figure opened by plot_head_ts.
        self.LayerTxt.remove()
        self.LayerTxt = self.fig.text(0.8, 0.87, "Layer {}".format(self.curr_layer + 1),
                                      fontsize=14, transform=self.fig.transFigure, color='r')
        self.TimeTxt.remove()
        self.TimeTxt = self.fig.text(0.2, 0.87, "Totim {}".format(self.all_times[self.curr_time_index]),
                                     fontsize=14, transform=self.fig.transFigure)

        try:
            self.cax.remove()
        except Exception:
            # best effort: no colorbar yet (first draw) or already removed
            pass
        self.cax = self.fig.colorbar(im, fraction=0.02, pad=0.01, ax=self.ax,
                                     orientation='vertical')
        # Draw last so the new colorbar is rendered as well.
        self.fig.canvas.draw()

    def _plot_arr(self, arr):
        """Draw the current layer of ``arr`` as an image with black contours.

        Parameters
        ----------
        arr : array
            Indexed as ``arr[layer, row, col]``; it is not modified.

        Cells whose ``ibound`` is 0 are masked with NaN.  The layer / time
        labels and the colorbar are replaced, then the canvas is redrawn.
        """
        self.ax.clear()
        # Float copy: NaN masking must not write into the caller's array and
        # would fail on an integer array.  np.nan (np.NaN was removed in NumPy 2.0).
        layer = np.array(arr[self.curr_layer, :, :], dtype=float)
        ibound = self.mf.bas6.ibound.array[self.curr_layer, :, :]
        layer[ibound == 0] = np.nan
        im = self.ax.imshow(layer)

        self.ax.contour(layer, colors='k')

        # Draw the labels on self.fig explicitly: plt.gcf() can be a
        # hydrograph figure opened by plot_head_ts.
        self.LayerTxt.remove()
        self.LayerTxt = self.fig.text(0.8, 0.87, "Layer {}".format(self.curr_layer + 1),
                                      fontsize=14, transform=self.fig.transFigure, color='r')
        self.TimeTxt.remove()
        self.TimeTxt = self.fig.text(0.2, 0.87, "Totim {}".format(self.all_times[self.curr_time_index]),
                                     fontsize=14, transform=self.fig.transFigure)

        from mpl_toolkits.axes_grid1 import make_axes_locatable
        # NOTE(review): the returned divider was never used, but
        # make_axes_locatable may attach an axes locator to self.ax; the call
        # is kept so the layout does not change.
        make_axes_locatable(self.ax)
        try:
            self.cax.remove()
        except Exception:
            # best effort: no colorbar yet (first draw) or already removed
            pass
        self.cax = self.fig.colorbar(im, fraction=0.02, pad=0.01, ax=self.ax,
                                     orientation='vertical')
        # Draw last so the new colorbar is rendered as well.
        self.fig.canvas.draw()

    def plot_forward(self, event):
        """Button callback: step to the next output time and redraw.

        Returns False, leaving ``self.curr_time_index`` unchanged, when the
        last entry of ``self.all_times`` is already shown; otherwise None.
        """
        following = self.curr_time_index + 1
        if following >= len(self.all_times):
            return False
        self.curr_time_index = following
        self.update_fig()

    def plot_backward(self, event):
        """Button callback: step to the previous output time and redraw.

        Returns False, leaving ``self.curr_time_index`` unchanged, when the
        first time is already shown; otherwise redraws and returns None.
        Previously the boundary branch decremented again (index drifted to -2,
        and a later step showed the *last* time via negative indexing).
        """
        previous = self.curr_time_index - 1
        if previous < 0:
            return False
        self.curr_time_index = previous
        self.update_fig()
예제 #13
0
class Oscillo:
    """Interactive multi-channel oscilloscope built on matplotlib widgets.

    The main figure shows the last acquired window of ``N`` samples for
    every channel in ``channel_list``.  Text boxes on the right change:

    - the sampling frequency,
    - the trigger source and level,
    - the window size,
    - the voltage range.

    The "Stats" button opens one small figure per channel with
    mean/std/min/max and the frequency of the largest FFT peak.  The "FFT"
    button opens a figure with the running average of the spectra.

    ``run()`` drives two coroutines on one asyncio event loop:

    - ``run_gui``: matplotlib event processing and figure bookkeeping;
    - ``run_acqui``: acquisition and plotting.

    Reconfiguration coroutines pause the acquisition through the
    ``ask_pause_acqui`` / ``paused`` flags.
    """

    def __init__(
        self,
        channel_list,
        sampling=5e3,
        volt_range=10.0,
        trigger_level=None,
        trigsource=None,
        backend="daqmx",
    ):
        """Create the main figure and its control widgets.

        Parameters
        ----------
        channel_list : list
            Channel names passed to the backend's ``add_channel``.
        sampling : float
            Requested sampling frequency (samples per second; the time axis
            is ``arange(N) / sampling`` seconds).
        volt_range : float
            Requested voltage range passed to ``add_channel``.
        trigger_level : float or None
            Trigger level; ``None`` disables triggering.
        trigsource : int, str or None
            Trigger source: an index into ``channel_list`` or a source name
            accepted by the backend.
        backend : str
            Key of the acquisition backend in the module-level ``Backends``.

        Raises
        ------
        ValueError
            If ``backend`` is not a key of ``Backends``.
        """
        if backend not in Backends:
            raise ValueError(f"Backend {backend:} is not available.")
        self.backend = backend
        self.channel_list = channel_list
        self.sampling = sampling
        self.volt_range = volt_range
        # Set while the voltrange box is updated programmatically, so that its
        # submit callback does not start another change.
        self.ignore_voltrange_submit = False
        self.trigger_level = trigger_level
        self.trigger_source = trigsource
        self.N = 1024
        plt.ion()
        self.fig = plt.figure()
        # left, bottom, width, height
        self.ax = self.fig.add_axes([0.1, 0.1, 0.7, 0.8])
        self.running = False
        self.last_trigged = 0
        self.ask_pause_acqui = False
        self.paused = False

        # Running spectrum sum (see _accumulate_spectrum)
        self.freq = None
        self.Pxx = None
        self.N_spectra = 0
        self.fig_spectrum = None
        self.hanning = True
        self.ac = True
        self.power_spectrum = True
        self.spectrum_unit = 1.0
        self.spectrum_unit_str = "V^2/Hz"
        self.system = None
        self.saved_spectra = list()

        self.fig_stats = None

        # Configure widgets
        left, bottom = 0.825, 0.825
        width, height = 0.15, 0.075
        vpad = 0.02

        ax_sampling = self.fig.add_axes([left, bottom, width, height])
        bottom -= height * 1.25 + 2 * vpad
        self.textbox_sampling = TextBox(ax_sampling,
                                        label="",
                                        initial=f"{sampling:.1e}")
        self.textbox_sampling.on_submit(self.ask_sampling_change)
        _ = ax_sampling.text(0, 1.25, "Sampling")

        ax_enable_trigger = self.fig.add_axes([left, bottom, width, height])
        bottom -= height * 1.25 + 2 * vpad
        self.textbox_triggersource = TextBox(ax_enable_trigger,
                                             label="",
                                             initial="None")
        self.textbox_triggersource.on_submit(self.ask_trigger_change)
        _ = ax_enable_trigger.text(0, 1.25, "Trigger")

        ax_triggerlevel = self.fig.add_axes([left, bottom, width, height])
        bottom -= height * 1.25 + 2 * vpad
        if trigger_level is not None:
            initial_value = f"{trigger_level:.2f}"
        else:
            initial_value = "1.0"
        self.textbox_triggerlevel = TextBox(ax_triggerlevel,
                                            label="",
                                            initial=initial_value)
        self.textbox_triggerlevel.on_submit(self.ask_trigger_change)
        _ = ax_triggerlevel.text(0, 1.25, "Level")

        ax_winsize = self.fig.add_axes([left, bottom, width, height])
        bottom -= height * 1.25 + 2 * vpad
        self.textbox_winsize = TextBox(ax_winsize,
                                       label="",
                                       initial=f"{self.N:d}")
        self.textbox_winsize.on_submit(self.ask_winsize_change)
        _ = ax_winsize.text(0, 1.25, "Win size")

        ax_voltrange = self.fig.add_axes([left, bottom, width, height])
        bottom -= height * 1.25 + 2 * vpad
        self.textbox_voltrange = TextBox(ax_voltrange,
                                         label="",
                                         initial=f"{self.volt_range:.1f}")
        self.textbox_voltrange.on_submit(self.ask_voltrange_change)
        _ = ax_voltrange.text(0, 1.25, "Range")

        ax_start_stats = self.fig.add_axes([left, bottom, width, height])
        bottom -= height + vpad
        self.btn_start_stats = Button(ax_start_stats, label="Stats")
        self.btn_start_stats.on_clicked(self.start_stats)

        ax_start_spectrum = self.fig.add_axes([left, bottom, width, height])
        bottom -= height + vpad
        self.btn_start_spectrum = Button(ax_start_spectrum, label="FFT")
        self.btn_start_spectrum.on_clicked(self.start_spectrum)

    def clean_spectrum(self, *args):
        """Discard the running spectrum sum (extra arguments are ignored)."""
        self.freq = None
        self.Pxx = None
        self.N_spectra = 0

    def start_stats(self, event):
        """Open a statistics figure for every channel that has none yet.

        Each figure holds five read-only-looking TextBoxes (Mean, Std, Min,
        Max, Freq) that ``run_acqui`` refreshes after every acquisition.
        """
        if self.fig_stats is None:
            self.fig_stats = dict()
            self.box_mean = dict()
            self.box_std = dict()
            self.box_min = dict()
            self.box_max = dict()
            self.box_freq = dict()
        nbox = 5
        height = 1.0 / (nbox + 1)
        padding = height / 4
        for chan in self.channel_list:
            if chan not in self.fig_stats:
                self.fig_stats[chan] = plt.figure(figsize=(2, 4))
                # FigureCanvasBase.set_window_title was removed in
                # matplotlib 3.6; the manager method works on all versions.
                self.fig_stats[chan].canvas.manager.set_window_title(chan)

                ax_mean = self.fig_stats[chan].add_axes(
                    [0.25, 9 * height / 2, 0.7, height - padding])
                self.box_mean[chan] = TextBox(ax_mean,
                                              label="Mean",
                                              initial="")

                ax_std = self.fig_stats[chan].add_axes(
                    [0.25, 7 * height / 2, 0.7, height - padding])
                self.box_std[chan] = TextBox(ax_std, label="Std", initial="")

                ax_min = self.fig_stats[chan].add_axes(
                    [0.25, 5 * height / 2, 0.7, height - padding])
                self.box_min[chan] = TextBox(ax_min, label="Min", initial="")

                ax_max = self.fig_stats[chan].add_axes(
                    [0.25, 3 * height / 2, 0.7, height - padding])
                self.box_max[chan] = TextBox(ax_max, label="Max", initial="")

                ax_freq = self.fig_stats[chan].add_axes(
                    [0.25, height / 2, 0.7, height - padding])
                self.box_freq[chan] = TextBox(ax_freq,
                                              label="Freq",
                                              initial="")

    def start_spectrum(self, *args, **kwargs):
        """Open the spectrum figure (once) and restart the spectrum average."""
        if self.fig_spectrum is None:
            self.fig_spectrum = plt.figure()
            self.ax_spectrum = self.fig_spectrum.add_axes([0.1, 0.1, 0.7, 0.8])

            # Widgets
            ax_hanning = self.fig_spectrum.add_axes([0.825, 0.75, 0.15, 0.15])
            self.checkbox_hanning = CheckButtons(ax_hanning, ["Hanning", "AC"],
                                                 [self.hanning, self.ac])
            self.checkbox_hanning.on_clicked(self.ask_hanning_change)

            ax_spectrum_unit = self.fig_spectrum.add_axes(
                [0.825, 0.51, 0.15, 0.25])
            self.radio_units = RadioButtons(
                ax_spectrum_unit,
                ["V^2/Hz", "V/sq(Hz)", "mV/sq(Hz)", "µV/sq(Hz)", "nV/sq(Hz)"],
            )
            self.radio_units.on_clicked(self.ask_spectrum_units_change)

            ax_restart = self.fig_spectrum.add_axes([0.825, 0.35, 0.15, 0.075])
            self.btn_restart = Button(ax_restart, label="Restart")
            self.btn_restart.on_clicked(self.clean_spectrum)

            ax_save_hold = self.fig_spectrum.add_axes(
                [0.825, 0.25, 0.15, 0.075])
            self.btn_save_hold = Button(ax_save_hold, label="Hold&Save")
            self.btn_save_hold.on_clicked(self.save_hold)

        self.clean_spectrum()

    def save_hold(self, event):
        """Save the averaged spectrum to HDF5 and keep it on the spectrum plot.

        Does nothing before at least one spectrum has been accumulated.  The
        file ``pymanip_oscillo_<date>_<time>.hdf5`` is written in the current
        directory.  ``Pxx`` holds the positive-frequency average in V^2/Hz,
        with one row per channel when several channels are acquired.
        """
        if self.N_spectra <= 0:
            return
        dt = datetime.now()
        filename = (f"pymanip_oscillo_{dt.year:}-{dt.month}-{dt.day}"
                    f"_{dt.hour}-{dt.minute}-{dt.second}.hdf5")
        bb = self.freq > 0
        freq = self.freq[bb]
        # self.Pxx is one array for a single channel but a list of arrays for
        # several; asarray + [..., bb] handles both shapes.
        Pxx = np.asarray(self.Pxx)[..., bb] / self.N_spectra
        # Explicit "w": since h5py 3.0 the default mode is read-only and
        # cannot create the file.
        with h5py.File(filename, "w") as f:
            f.attrs["ts"] = dt.timestamp()
            f.attrs["N_spectra"] = self.N_spectra
            f.attrs["sampling"] = self.sampling
            f.attrs["volt_range"] = self.volt_range
            f.attrs["N"] = self.N
            f.create_dataset("freq", data=freq)
            f.create_dataset("Pxx", data=Pxx)
        self.saved_spectra.append({
            "freq": freq,
            "Pxx": Pxx
        })

    def ask_spectrum_units_change(self, event):
        """RadioButtons callback: select the unit of the spectrum plot.

        ``event`` is the selected label; unknown labels are ignored.
        "V^2/Hz" plots power; the other labels plot the square root (amplitude)
        multiplied by the unit factor (1, 1e3, 1e6 or 1e9).
        """
        power_spectrum_dict = {
            "V^2/Hz": True,
            "V/sq(Hz)": False,
            "mV/sq(Hz)": False,
            "µV/sq(Hz)": False,
            "nV/sq(Hz)": False,
        }
        spectrum_unit_dict = {
            "V^2/Hz": 1.0,
            "V/sq(Hz)": 1.0,
            "mV/sq(Hz)": 1e3,
            "µV/sq(Hz)": 1e6,
            "nV/sq(Hz)": 1e9,
        }
        if event in power_spectrum_dict:
            self.spectrum_unit_str = event
            self.power_spectrum = power_spectrum_dict[event]
            self.spectrum_unit = spectrum_unit_dict[event]

    async def winsize_change(self, new_N):
        """Pause acquisition, switch to ``new_N`` samples per window, resume.

        If the backend rejects the new size, the error is printed and the
        previous size is restored and re-applied (best effort).  The time axis
        is rebuilt afterwards, so it always matches ``self.N``.
        """
        await self.pause_acqui()
        old_N = self.N
        self.N = new_N
        try:
            self.system.configure_clock(self.sampling, self.N)
        except Exception as e:
            print(e)
            self.N = old_N
            try:
                self.system.configure_clock(self.sampling, self.N)
            except Exception as e2:
                print(e2)
        self.clean_spectrum()
        self.figure_t_axis()
        await self.restart_acqui()

    def ask_winsize_change(self, label):
        """TextBox callback: schedule a window-size change.

        Non-integer or non-positive values are reported and ignored (an empty
        window would leave the acquisition paused forever).
        """
        try:
            new_N = int(label)
        except ValueError:
            print("winsize must be an integer")
            return
        if new_N <= 0:
            print("winsize must be a positive integer")
            return
        loop = asyncio.get_event_loop()
        asyncio.ensure_future(self.winsize_change(new_N), loop=loop)

    def ask_trigger_change(self, label):
        """TextBox callback shared by the trigger source and level boxes.

        Reads both boxes, validates the source against the backend's
        ``possible_trigger_channels()`` ("None" disables triggering), and
        schedules ``trigger_change`` when source or level actually changed.
        """
        changed = False
        trigger_source = self.textbox_triggersource.text
        if trigger_source == "None":
            trigger_source = None
        else:
            possible_trigger_source = self.system.possible_trigger_channels()
            if trigger_source not in possible_trigger_source:
                print(f"{trigger_source:} is not an acceptable trigger source")
                print("Possible values:", possible_trigger_source)
                trigger_source = None
                self.textbox_triggersource.set_val("None")
        if self.trigger_source != trigger_source:
            self.trigger_source = trigger_source
            changed = True
        if self.trigger_source is not None:
            try:
                trigger_level = float(self.textbox_triggerlevel.text)
            except ValueError:
                trigger_level = self.trigger_level
            val_str = f"{trigger_level:.2f}"
            self.textbox_triggerlevel.set_val(val_str)
            if trigger_level != self.trigger_level:
                self.trigger_level = trigger_level
                changed = True
        if changed:
            loop = asyncio.get_event_loop()
            asyncio.ensure_future(self.trigger_change(), loop=loop)

    async def pause_acqui(self):
        """Ask run_acqui to pause, stop the system, and wait until it is paused."""
        self.ask_pause_acqui = True
        await self.system.stop()
        while not self.paused:
            await asyncio.sleep(0.5)

    async def restart_acqui(self):
        """Release the pause request and wait until run_acqui resumes."""
        self.ask_pause_acqui = False
        while self.paused:
            await asyncio.sleep(0.5)

    async def trigger_change(self):
        """Apply the current trigger settings while acquisition is paused."""
        print("Awaiting pause")
        await self.pause_acqui()
        print("Acquisition paused")
        self._configure_trigger()
        self.clean_spectrum()
        await self.restart_acqui()

    async def sampling_change(self):
        """Apply ``self.sampling`` while acquisition is paused.

        On failure, or when the backend picked a different rate, a new change
        is requested with the backend's rate.  That follow-up change resumes
        acquisition.
        """
        await self.pause_acqui()
        try:
            self.system.configure_clock(self.sampling, self.N)
        except Exception:
            print("Invalid sampling frequency")
            self.ask_sampling_change(self.system.samp_clk_max_rate)
            return
        if self.system.sample_rate != self.sampling:
            self.ask_sampling_change(self.system.sample_rate)
            return
        self.figure_t_axis()
        self.clean_spectrum()
        await self.restart_acqui()

    def ask_sampling_change(self, sampling):
        """TextBox callback: parse ``sampling`` and schedule the change."""
        try:
            self.sampling = float(sampling)
            changed = True
        except ValueError:
            print("Wrong value:", sampling)
            changed = False
        self.textbox_sampling.set_val(f"{self.sampling:.1e}")
        if changed:
            loop = asyncio.get_event_loop()
            asyncio.ensure_future(self.sampling_change(), loop=loop)

    def ask_hanning_change(self, label):
        """CheckButtons callback: update Hanning/AC flags and restart averaging."""
        self.hanning, self.ac = self.checkbox_hanning.get_status()
        self.clean_spectrum()

    def figure_t_axis(self):
        """Build the time axis ``self.t`` (ms if the window is under 1 s)."""
        self.t = np.arange(self.N) / self.sampling
        if self.t[-1] < 1:
            self.t *= 1000
            self.unit = "[ms]"
        else:
            self.unit = "[s]"

    async def run_gui(self):
        """GUI coroutine: process matplotlib events and track closed figures.

        Stops the whole program (``self.running = False``) when the main
        figure is closed.  Forgets the spectrum or statistics state when those
        figures are closed.
        """
        while self.running:
            if time.monotonic() - self.last_trigged > self.N / self.sampling:
                self.ax.set_title("Waiting for trigger")
            self.fig.canvas.start_event_loop(0.5)
            await asyncio.sleep(0.05)
            if not plt.fignum_exists(self.fig.number):
                self.running = False
            if self.fig_spectrum and not plt.fignum_exists(
                    self.fig_spectrum.number):
                self.fig_spectrum = None
                self.clean_spectrum()
            if self.fig_stats is not None:
                for chan in list(self.fig_stats):
                    if not plt.fignum_exists(self.fig_stats[chan].number):
                        self.fig_stats.pop(chan)
                        self.box_mean.pop(chan)
                        self.box_std.pop(chan)
                        self.box_min.pop(chan)
                        self.box_max.pop(chan)
                        self.box_freq.pop(chan)
                if not self.fig_stats:
                    self.fig_stats = None

    def ask_voltrange_change(self, new_range):
        """TextBox callback: schedule a voltage-range change (unless muted)."""
        if not self.ignore_voltrange_submit:
            try:
                new_range = float(new_range)
            except ValueError:
                print("Volt range must be a float")
                return
            loop = asyncio.get_event_loop()
            asyncio.ensure_future(self.voltrange_change(new_range), loop=loop)

    async def voltrange_change(self, new_range):
        """Recreate the acquisition system with ``new_range`` volts.

        The box is then updated to the range actually selected by the backend.
        """
        await self.pause_acqui()
        self.volt_range = new_range
        self.clean_spectrum()
        self.create_system()
        await self.restart_acqui()
        actual_range = self.system.actual_ranges[0]
        print("actual_range =", actual_range)
        # Mute the submit callback triggered by set_val.
        self.ignore_voltrange_submit = True
        self.textbox_voltrange.set_val(f"{actual_range:.1f}")
        self.ignore_voltrange_submit = False

    def _configure_trigger(self):
        """Send the current trigger settings to ``self.system``.

        ``self.trigger_source`` may be an index into ``channel_list`` (as
        accepted by ``__init__``) or a source name (as entered in the GUI).
        Triggering is disabled when either the source or the level is None.
        """
        source = self.trigger_source
        if isinstance(source, int):
            source = self.channel_list[source]
        if source is not None and self.trigger_level is not None:
            self.system.configure_trigger(source,
                                          trigger_level=self.trigger_level)
        else:
            self.system.configure_trigger(None)

    def create_system(self):
        """(Re)create the backend system with the current channel settings."""
        if self.system is not None:
            self.system.close()
        self.system = Backends[self.backend]()
        self.ai_channels = list()
        for chan in self.channel_list:
            ai_chan = self.system.add_channel(chan,
                                              terminal_config=TC.Diff,
                                              voltage_range=self.volt_range)
            self.ai_channels.append(ai_chan)
        self.system.configure_clock(self.sampling, self.N)
        self.textbox_sampling.set_val(f"{self.system.sample_rate:.1e}")
        self._configure_trigger()
        self.figure_t_axis()

    def _plot_signal(self, data):
        """Redraw the main axes with the latest window of every channel."""
        self.ax.cla()
        if len(self.channel_list) == 1:
            self.ax.plot(self.t, data, "-")
        elif len(self.channel_list) > 1:
            for d in data:
                self.ax.plot(self.t, d, "-")
        # Dashed trigger-level line, except for no trigger and the "Ext" source.
        if (self.trigger_source not in (None, "Ext")
                and self.trigger_level is not None):
            self.ax.plot([self.t[0], self.t[-1]],
                         [self.trigger_level] * 2, "g--")
        self.ax.set_xlim([self.t[0], self.t[-1]])
        self.ax.set_title("Trigged!")
        self.ax.set_xlabel("t " + self.unit)

    def _accumulate_spectrum(self, data):
        """Add the periodogram of ``data`` to the running spectrum sum.

        The first call after ``clean_spectrum`` (``N_spectra == 0``) fixes
        the frequency axis, window and normalisation for the whole average and
        starts a new sum.  ``self.Pxx`` is one array for a single channel and
        a list of arrays (one per channel) otherwise.
        """
        if self.N_spectra == 0:
            self.freq = np.fft.fftfreq(self.N, 1.0 / self.sampling)
            self._spec_norm = math.pi * math.sqrt(self.N / self.sampling)
            if self.hanning:
                self._spec_window = np.hanning(self.N)
            else:
                self._spec_window = np.ones((self.N, ))

        def periodogram(d):
            # With "AC" checked, the window mean is removed before the FFT.
            m = np.mean(d) if self.ac else 0.0
            return np.abs(
                np.fft.fft((d - m) * self._spec_window) / self._spec_norm)**2

        single = len(self.channel_list) == 1
        if self.N_spectra == 0:
            if single:
                self.Pxx = periodogram(data)
            else:
                self.Pxx = [periodogram(d) for d in data]
            self.N_spectra = 1
        else:
            if single:
                self.Pxx += periodogram(data)
            else:
                for p, d in zip(self.Pxx, data):
                    p += periodogram(d)
            self.N_spectra += 1

    def _plot_spectrum(self):
        """Redraw the spectrum axes: held spectra, then the running average.

        Power units scale by ``spectrum_unit``; amplitude units plot
        ``spectrum_unit * sqrt(P)``.  Held spectra get the same conversion so
        every curve matches the axis label.
        """
        if self.power_spectrum:

            def process_spec(s):
                return self.spectrum_unit * s

        else:

            def process_spec(s):
                return self.spectrum_unit * np.sqrt(s)

        self.ax_spectrum.cla()
        for spectra in self.saved_spectra:
            # Multi-channel held Pxx is (channels, freqs): transpose so that
            # matplotlib draws one curve per channel.
            self.ax_spectrum.loglog(spectra["freq"],
                                    process_spec(np.transpose(spectra["Pxx"])),
                                    "-")
        bb = self.freq > 0
        if len(self.channel_list) == 1:
            self.ax_spectrum.loglog(
                self.freq[bb],
                process_spec(self.Pxx[bb] / self.N_spectra),
                "-",
            )
        else:
            for p in self.Pxx:
                self.ax_spectrum.loglog(
                    self.freq[bb],
                    process_spec(p[bb] / self.N_spectra), "-")
        self.ax_spectrum.set_xlabel("f [Hz]")
        self.ax_spectrum.set_ylabel(self.spectrum_unit_str)
        self.ax_spectrum.set_title(f"N = {self.N_spectra:d}")

    def _update_stats(self, data):
        """Refresh the statistics boxes of every open per-channel figure."""
        if len(self.channel_list) == 1:
            list_data = [data]
        else:
            list_data = data
        ff = np.fft.fftfreq(self.N, 1.0 / self.sampling)
        for chan, d in zip(self.channel_list, list_data):
            if chan not in self.fig_stats:
                continue
            self.box_mean[chan].set_val("{:.5f}".format(np.mean(d)))
            self.box_std[chan].set_val("{:.5f}".format(np.std(d)))
            self.box_min[chan].set_val("{:.5f}".format(np.min(d)))
            self.box_max[chan].set_val("{:.5f}".format(np.max(d)))
            # Frequency of the largest FFT peak of the mean-removed signal
            pp = np.abs(np.fft.fft(d - np.mean(d)))
            self.box_freq[chan].set_val("{:.5f}".format(ff[np.argmax(pp)]))

    async def run_acqui(self):
        """Acquisition coroutine: read one window per iteration and plot it.

        Creates the system, then loops while ``self.running``.  While a
        reconfiguration requests a pause, it idles with ``self.paused = True``.
        The system is always closed when the loop ends.
        """
        self.create_system()
        try:
            while self.running:
                while self.ask_pause_acqui and self.running:
                    self.paused = True
                    await asyncio.sleep(0.5)
                self.paused = False
                if not self.running:
                    break
                self.system.start()
                data = await self.system.read()
                await self.system.stop()
                if data is None:
                    continue
                self.last_trigged = self.system.last_read
                self._plot_signal(data)
                if self.fig_spectrum:
                    self._accumulate_spectrum(data)
                    self._plot_spectrum()
                if self.fig_stats:
                    self._update_stats(data)
        finally:
            self.system.close()

    def ask_exit(self, *args, **kwargs):
        """Signal handler: stop both coroutines at their next iteration."""
        self.running = False

    def run(self):
        """Run the GUI and acquisition coroutines until the main figure closes.

        SIGINT (and SIGTERM outside Windows) request a clean exit.
        """
        loop = asyncio.get_event_loop()
        self.running = True
        if sys.platform == "win32":
            signal.signal(signal.SIGINT, self.ask_exit)
        else:
            for signame in ("SIGINT", "SIGTERM"):
                loop.add_signal_handler(getattr(signal, signame),
                                        self.ask_exit)

        loop.run_until_complete(
            asyncio.gather(self.run_gui(), self.run_acqui()))
예제 #14
0
class NetworkVisualisation(object):
    """Interactive matplotlib explorer for a small two-layer network.

    The network is a mapping holding the arrays ``W1``, ``b1``, ``W2`` and
    ``b2``. It is either loaded from ``saves_path`` (when a matching save
    exists) or trained on a toy 2-D dataset. The figure shows the outputs of
    both layers over the input grid as 2-D maps, with the dataset overlaid,
    and as 3-D surfaces. Sliders and check buttons let the user edit
    individual weights/biases and see the effect immediately.
    """

    def __init__(self,
                 units,
                 data_points,
                 min_range,
                 max_range,
                 quality,
                 dataset,
                 saves_path=None,
                 seed=1):
        """Build the dataset, obtain a network and draw the figure.

        Args:
            units: hidden-unit count, forwarded to the load/train helpers.
            data_points: number of dataset points to generate.
            min_range, max_range: extent of the plotted input space.
            quality: stored as ``self.precision`` and passed to
                ``setup_data_space``/``forward`` (presumably the grid
                resolution -- confirm).
            dataset: ``Dataset.CIRCLE`` or ``Dataset.SPIRAL``.
            saves_path: optional location of cached networks. A matching
                saved network is loaded instead of retraining.
            seed: seed for ``np.random``, making the dataset reproducible.

        Raises:
            ValueError: if ``dataset`` is neither CIRCLE nor SPIRAL.
        """
        np.random.seed(seed)

        self.precision = quality
        self.min_range = min_range
        self.max_range = max_range
        self.dataset_type = dataset
        if dataset is Dataset.CIRCLE:
            self.dataset = get_circle_dataset(points=data_points,
                                              min_range=min_range,
                                              max_range=max_range,
                                              radius=0.8)
        elif dataset is Dataset.SPIRAL:
            self.dataset = get_spiral_dataset(data_points, classes=2)
        else:
            # ValueError is still an Exception, so existing handlers work.
            raise ValueError("Invalid dataset type")
        # Last column holds the class label, the others the coordinates.
        self.data_points = self.dataset[:, :-1]
        self.data_labels = self.dataset[:, -1]
        self.data_space, self.dim_data = setup_data_space(
            min_range, max_range, quality)

        # Network creation: reuse a saved network when possible.
        if saves_path and check_saved_network(units, dataset, saves_path):
            self.network = load_network(units, dataset, saves_path)
        else:
            if dataset is Dataset.CIRCLE:
                self.network = train_network_sigmoid(self.dataset,
                                                     units=units,
                                                     learning_rate=5e-3,
                                                     window_size=1000)
            else:  # Dataset.SPIRAL; anything else was rejected above.
                self.network = train_network_softmax(self.dataset,
                                                     units=units,
                                                     learning_rate=1,
                                                     window_size=1000)
            # NOTE(review): called even when saves_path is None; presumably
            # save_network tolerates that -- confirm.
            save_network(self.network, dataset, saves_path)
        # Untouched copy, used to mark the original value on each slider.
        self.default_network = {
            name: layer.copy()
            for name, layer in self.network.items()
        }

        # GUI state
        self.perceptron1 = 0  # first-layer unit currently being edited
        self.is_relu = True  # show layer-1 output after the activation
        self.perceptron2 = 0
        self.connection = 0  # row of W2 currently being edited
        self.is_pre_add = False
        self.is_sig = True
        self.all_p1_enabled = set(range(self.network["W1"].shape[1]))
        # Set while widgets are synced programmatically so that the
        # resulting slider callbacks do not trigger expensive redraws.
        self.ignore_update = False

        fig = plt.figure(figsize=(13, 6.5))
        self.plot_network(fig)
        self.plot_controls(fig)

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _set_slider_marker(slider, value):
        """Move a slider's reference line (``slider.vline``) to ``value``.

        The marker is a vertical line with two x-coordinates. Recent
        matplotlib versions require a sequence for ``Line2D.set_xdata``, so
        both ends are set explicitly instead of passing a bare scalar.
        """
        slider.vline.set_xdata([value, value])

    @staticmethod
    def _padded_range(values, padding):
        """Return ``(lo, hi)`` extending values' min/max by half their spread + padding."""
        lo = values.min()
        hi = values.max()
        margin = (hi - lo) / 2 + padding
        return lo - margin, hi + margin

    def _scatter_dataset(self, ax, outer_points, inner_points):
        """Overlay label-1 points in green and label-0 points in red on ``ax``."""
        for points, colour in ((outer_points, "g"), (inner_points, "r")):
            ax.scatter(points[:, 0], points[:, 1], s=3, c=colour, alpha=0.5)

    def _forward(self):
        """Evaluate the network on the input grid with the enabled layer-1 units."""
        return forward(self.data_space, self.network, self.dataset_type,
                       self.all_p1_enabled, self.precision)

    # ------------------------------------------------------------------
    # figure construction
    # ------------------------------------------------------------------
    def plot_network(self, fig):
        """Create the 2-D and 3-D plots for layer 1 and layer 2 on ``fig``."""
        _, out2, out3, out4 = forward(self.data_space,
                                      self.network,
                                      self.dataset_type,
                                      precision=self.precision)
        outer_points = self.data_points[self.data_labels == 1]
        inner_points = self.data_points[self.data_labels == 0]

        self.layer1_plot = Plot(fig, (4, 4), (0, 0), (1, 3), out2[:, :, 0],
                                self.min_range, self.max_range)
        self._scatter_dataset(self.layer1_plot.ax, outer_points, inner_points)

        self.layer1_3d_plot = Plot3D(fig, (4, 4), (1, 0), (1, 3),
                                     self.precision, out2)

        self.layer2_plot = Plot(fig, (4, 4), (2, 0), (1, 3), out4[:, :, 0],
                                self.min_range, self.max_range)
        self._scatter_dataset(self.layer2_plot.ax, outer_points, inner_points)

        self.layer2_3d_plot = Plot3D(fig, (4, 4), (3, 0), (1, 3),
                                     self.precision, out4)

    def plot_controls(self, fig):
        """Create the sliders and check buttons and connect their callbacks."""
        step_size = 0.01
        padding = 5

        # Plot 1 controls: slider ranges derived from the initial weights.
        w1x_lo, w1x_hi = self._padded_range(self.network["W1"][0], padding)
        w1y_lo, w1y_hi = self._padded_range(self.network["W1"][1], padding)
        w1b_lo, w1b_hi = self._padded_range(self.network["b1"], padding)

        p1x_ax = plot_to_grid(fig, (2, 16), (0, 12), (1, 1))
        self.p1x_slid = Slider(p1x_ax,
                               'P1 x',
                               valmin=w1x_lo,
                               valmax=w1x_hi,
                               valinit=self.network["W1"][0, 0],
                               valstep=step_size)
        self.p1x_slid.on_changed(self.p1x_changed)

        p1y_ax = plot_to_grid(fig, (2, 16), (0, 13), (1, 1))
        self.p1y_slid = Slider(p1y_ax,
                               'P1 y',
                               valmin=w1y_lo,
                               valmax=w1y_hi,
                               valinit=self.network["W1"][1, 0],
                               valstep=step_size)
        self.p1y_slid.on_changed(self.p1y_changed)

        p1b_ax = plot_to_grid(fig, (24, 16), (0, 14), (7, 1))
        self.p1b_slid = Slider(p1b_ax,
                               'P1 b',
                               valmin=w1b_lo,
                               valmax=w1b_hi,
                               valinit=self.network["b1"][0, 0],
                               valstep=step_size)
        self.p1b_slid.on_changed(self.p1b_changed)

        p1_ax = plot_to_grid(fig, (24, 16), (0, 15), (7, 1))
        self.p1_slid = Slider(p1_ax,
                              'P1',
                              valmin=0,
                              valmax=self.network["W1"].shape[1] - 1,
                              valinit=self.perceptron1,
                              valstep=1)
        self.p1_slid.on_changed(self.p1_changed)

        p1_opt_ax = plot_to_grid(fig, (24, 16), (8, 14), (3, 2))
        self.p1_opt_buttons = CheckButtons(p1_opt_ax, ["ReLU?", "Enabled?"],
                                           [self.is_relu, True])
        self.p1_opt_buttons.on_clicked(self.p1_options_update)

        # Plot 2 controls
        w2_lo, w2_hi = self._padded_range(self.network["W2"], padding)

        w2b_abs = np.abs(self.network["b2"][0, 0]) + padding
        w2b_min = self.network["b2"][0, 0] - w2b_abs
        w2b_max = self.network["b2"][0, 0] + w2b_abs

        p2_weight_val_ax = plot_to_grid(fig, (2, 16), (1, 12), (1, 1))
        self.p2_dim_val_slid = Slider(p2_weight_val_ax,
                                      'p2 w',
                                      valmin=w2_lo,
                                      valmax=w2_hi,
                                      valinit=self.network["W2"][0, 0],
                                      valstep=step_size)
        self.p2_dim_val_slid.on_changed(self.p2_weight_changed)

        p2_connection_dim_ax = plot_to_grid(fig, (2, 16), (1, 13), (1, 1))
        self.p2_connection_dim_slid = Slider(
            p2_connection_dim_ax,
            'p2 c',
            valmin=0,
            valmax=self.network["W2"].shape[0] - 1,
            valinit=0,
            valstep=1)
        self.p2_connection_dim_slid.on_changed(self.p2_connection_dim_changed)

        p2b_ax = plot_to_grid(fig, (24, 16), (13, 14), (7, 1))
        self.p2b_slid = Slider(p2b_ax,
                               'p2 b',
                               valmin=w2b_min,
                               valmax=w2b_max,
                               valinit=self.network["b2"][0, 0],
                               valstep=step_size)
        self.p2b_slid.on_changed(self.p2b_changed)

        p2_opt_ax = plot_to_grid(fig, (24, 16), (21, 14), (4, 2))
        self.p2_opt_buttons = CheckButtons(p2_opt_ax,
                                           ["Pre-add?", "Transform?"],
                                           [self.is_pre_add, self.is_sig])
        self.p2_opt_buttons.on_clicked(self.p2_options_update)

    # ------------------------------------------------------------------
    # widget callbacks
    # ------------------------------------------------------------------
    def p1_changed(self, val):
        """Select layer-1 unit ``val``, sync its widgets and redraw plot 1."""
        self.perceptron1 = int(val)
        self.ignore_update = True
        self.update_widgets()
        self.ignore_update = False

        self.update_just_plot1()

    def p1x_changed(self, val):
        """Set the x-input weight of the selected layer-1 unit."""
        self.network["W1"][0, self.perceptron1] = val
        self.update_visuals()

    def p1y_changed(self, val):
        """Set the y-input weight of the selected layer-1 unit."""
        self.network["W1"][1, self.perceptron1] = val
        self.update_visuals()

    def p1b_changed(self, val):
        """Set the bias of the selected layer-1 unit."""
        self.network["b1"][0, self.perceptron1] = val
        self.update_visuals()

    def p1_options_update(self, label):
        """Handle the "ReLU?" / "Enabled?" check buttons of plot 1."""
        if label == "ReLU?":
            self.is_relu = not self.is_relu
            self.update_just_plot1()
        elif label == "Enabled?":
            is_enabled = self.p1_opt_buttons.get_status()[1]
            if is_enabled and self.perceptron1 not in self.all_p1_enabled:
                self.all_p1_enabled.add(self.perceptron1)
            elif not is_enabled and self.perceptron1 in self.all_p1_enabled:
                # 3-D surfaces are ordered by enabled unit index.
                layer1_out = sorted(self.all_p1_enabled).index(
                    self.perceptron1)
                self.layer1_3d_plot.remove_plot(layer1_out)
                self.all_p1_enabled.remove(self.perceptron1)

            self.update_visuals()

    def p2_weight_changed(self, val):
        """Set the selected W2 connection weight and redraw plot 2."""
        self.network["W2"][self.connection, 0] = val

        self.update_just_plot2()

    def p2_connection_dim_changed(self, val):
        """Select W2 row ``val`` and sync the weight slider to it."""
        self.connection = int(val)
        self.ignore_update = True
        self.p2_dim_val_slid.set_val(self.network["W2"][self.connection, 0])
        self._set_slider_marker(self.p2_dim_val_slid,
                                self.default_network["W2"][self.connection, 0])
        self.ignore_update = False

    def p2b_changed(self, val):
        """Set the layer-2 bias and redraw plot 2."""
        self.network["b2"][0, 0] = val
        self.update_just_plot2()

    def p2_options_update(self, label):
        """Handle the "Transform?" / "Pre-add?" check buttons of plot 2."""
        if label == "Transform?":
            self.is_sig = not self.is_sig
        elif label == "Pre-add?":
            self.is_pre_add = not self.is_pre_add

        self.update_just_plot2()

    def show(self):
        """Display the figure (blocks until the window is closed)."""
        plt.show()

    # ------------------------------------------------------------------
    # plot updates
    # ------------------------------------------------------------------
    def update_plot1(self, out1, out2):
        """Show the selected unit's pre-/post-activation map, or hide it if disabled."""
        if self.perceptron1 in self.all_p1_enabled:
            self.layer1_plot.set_visible(True)
            layer1_out = sorted(self.all_p1_enabled).index(self.perceptron1)
            source = out2 if self.is_relu else out1
            self.layer1_plot.update(source[:, :, layer1_out])
        else:
            self.layer1_plot.set_visible(False)

    def update_3d_plot1(self, out1, out2):
        """Refresh the layer-1 3-D surfaces when the selected unit is enabled."""
        if self.perceptron1 in self.all_p1_enabled:
            self.layer1_3d_plot.update_all(out2 if self.is_relu else out1)

    def update_plot2(self, out2, out3, out4):
        """Refresh the layer-2 2-D map according to the plot-2 options."""
        if self.is_pre_add:
            layer2_data = scale_out2(out2, self.network["W2"],
                                     self.all_p1_enabled, self.perceptron2,
                                     self.precision)
            layer2_data = np.sum(layer2_data, axis=2)
        elif not self.is_sig:
            layer2_data = out3[:, :, 0]
        else:
            layer2_data = out4[:, :, 0]
        self.layer2_plot.update(layer2_data)

    def update_3d_plot2(self, out2, out3, out4):
        """Refresh the layer-2 3-D surfaces according to the plot-2 options."""
        if self.is_pre_add:
            layer2_data = scale_out2(out2, self.network["W2"],
                                     self.all_p1_enabled, self.perceptron2,
                                     self.precision)
        elif not self.is_sig:
            layer2_data = out3
        else:
            layer2_data = out4
        self.layer2_3d_plot.update_all(layer2_data)

    def update_visuals(self):
        """Recompute the network and redraw both layers (unless suppressed)."""
        if not self.ignore_update:
            out1, out2, out3, out4 = self._forward()
            self.update_plot1_visuals(out1, out2)
            self.update_plot2_visuals(out2, out3, out4)
            plt.draw()

    def update_just_plot1(self):
        """Recompute the network and redraw layer 1 only (unless suppressed)."""
        if not self.ignore_update:
            out1, out2, _, _ = self._forward()
            self.update_plot1_visuals(out1, out2)
            plt.draw()

    def update_plot1_visuals(self, out1, out2):
        """Refresh the 2-D and 3-D layer-1 plots."""
        self.update_plot1(out1, out2)
        self.update_3d_plot1(out1, out2)

    def update_just_plot2(self):
        """Recompute the network and redraw layer 2 only (unless suppressed)."""
        if not self.ignore_update:
            _, out2, out3, out4 = self._forward()
            self.update_plot2_visuals(out2, out3, out4)
            plt.draw()

    def update_plot2_visuals(self, out2, out3, out4):
        """Refresh the 2-D and 3-D layer-2 plots."""
        self.update_plot2(out2, out3, out4)
        self.update_3d_plot2(out2, out3, out4)

    def update_widgets(self):
        """Sync the plot-1 sliders and "Enabled?" box to ``self.perceptron1``."""
        p = self.perceptron1
        self.p1b_slid.set_val(self.network["b1"][0, p])
        self.p1x_slid.set_val(self.network["W1"][0, p])
        self.p1y_slid.set_val(self.network["W1"][1, p])

        self._set_slider_marker(self.p1b_slid, self.default_network["b1"][0, p])
        self._set_slider_marker(self.p1x_slid, self.default_network["W1"][0, p])
        self._set_slider_marker(self.p1y_slid, self.default_network["W1"][1, p])

        # Toggle the check box only when it disagrees with the real state.
        enabled = p in self.all_p1_enabled
        if enabled != self.p1_opt_buttons.get_status()[1]:
            self.p1_opt_buttons.set_active(1)
예제 #15
0
def genResaleFlatPriceGraph(p_yearEnter, p_monthEnter, p_leaseEnter,
                            p_flatType, p_labels):
    """Show an interactive box plot of HDB resale flat prices per town.

    Reads ``data/resale-flat-prices.csv`` and opens a matplotlib window with:

    * a box plot of resale prices, one box per town ticked in the check-box
      list on the left;
    * "Year From" / "Min Remain Lease Year" sliders and a flat-type radio
      selector that re-filter the data;
    * press-and-hold on the plot to show the transactions of the box under
      the cursor within +/-5,000 of the clicked price.

    Args:
        p_yearEnter: first year of the initial period, as a string.
        p_monthEnter: first month of the initial period, as a string; it is
            left-padded with "0" and its last two characters are used.
        p_leaseEnter: rows need ``remaining_lease > int(p_leaseEnter) + 1``.
        p_flatType: initial flat type, e.g. ``"4 ROOM"``.
        p_labels: unused; the box labels are recomputed from the ticked towns.

    Side effects: sets the module globals ``txt``, ``flatType``,
    ``checkBoxLines``, ``axcolor``, ``townSelect``, ``towndata`` and
    ``index``, and blocks in ``plt.show()``.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider, RadioButtons, CheckButtons
    from datetime import datetime, timedelta

    global txt, flatType, checkBoxLines
    txt = None
    yearEnter = p_yearEnter
    # Left-pad the month; only its last two characters are used later.
    monthNum = "0" + p_monthEnter
    flatType = p_flatType

    init_sLease = 50

    def mainTheme():
        """(Re)apply title and y label; needed again after every ``ax.clear()``."""
        global axcolor
        axcolor = 'lightgoldenrodyellow'
        ax.set_title("HDB Resale Flat Price", fontsize=30)
        ax.set_ylabel('Resale Price', fontsize=25)

    def boxplotTheme(bp_dict):
        """Style the artists returned by ``ax.boxplot`` and annotate values."""
        # Outline color, fill color and linewidth of the boxes.
        for box in bp_dict['boxes']:
            box.set(color='#7570b3', linewidth=2)
            box.set(facecolor='#1b9e77')

        for whisker in bp_dict['whiskers']:
            whisker.set(color='#7570b3', linewidth=2)

        for cap in bp_dict['caps']:
            cap.set(color='#7570b3', linewidth=2)

        for median in bp_dict['medians']:
            median.set(color='#b2df8a', linewidth=2)

        for flier in bp_dict['fliers']:
            flier.set(marker='D', color='#e7298a', alpha=0.5)

        print(bp_dict.keys())

        # Write the median price just above each median line.
        for line in bp_dict['medians']:
            x, y = line.get_xydata()[1]  # right end of the median line
            ax.text(x,
                    y,
                    '%.1f' % y,
                    horizontalalignment='center',
                    fontsize=15)

        # Label the highest flier of every box.
        for line in bp_dict['fliers']:
            ndarray = line.get_xydata()
            if len(ndarray) > 0:
                max_flier_index = ndarray[:, 1].argmax()
                max_flier = ndarray[max_flier_index, 1]
                x = ndarray[max_flier_index, 0]
                print("Flier: " + str(x) + "," + str(max_flier))
                ax.text(x,
                        max_flier,
                        '%.1f' % max_flier,
                        horizontalalignment='center',
                        fontsize=15,
                        color='green')

    def checkButtonThemes(check):
        """Style the check-box squares where the matplotlib version exposes them."""
        # CheckButtons.rectangles was removed in matplotlib 3.9; skip the
        # cosmetic styling there instead of crashing.
        for r in getattr(check, "rectangles", ()):
            r.set_alpha(0.8)
            r.set_width(0.025)
            r.set_edgecolor("k")

    def get_status():
        """Return the tick state of every town, kept on the hidden proxy lines."""
        return [line.get_visible() for line in checkBoxLines]

    def selectTowns(rows, status):
        """Split ``rows`` per ticked town.

        Rebuilds the global ``townSelect`` (one array per shown box, in box
        order, used by ``onclick``) and returns ``(price arrays, town names)``
        for the box plot.
        """
        global townSelect
        resale_prices = rows['resale_price']
        townSelect = np.zeros(len(townList), object)
        townName = []
        resale_prices_combined = []
        for s, shown in enumerate(status):
            if shown:
                inTown = rows['town'] == townList[s]
                townSelect[len(townName)] = rows[inTown]
                townName.append(line_labels[s])
                resale_prices_combined.append(resale_prices[inTown])
        return resale_prices_combined, townName

    def genCombineResalePrices(yearEnter, monthNum, p_leaseEnter, flatType):
        """Filter by period/lease/flat type and return data for ticked towns."""
        global towndata
        # Period: from the given year-month up to the current month.
        # 'YYYY-MM' strings sort chronologically, so one range test replaces
        # a per-month scan of the whole table.
        firstStartDate = yearEnter + "-" + monthNum[-2:]
        lastDate = datetime.now().strftime('%Y-%m')
        months = data['month']
        dataPeriod = data[(months >= firstStartDate) & (months <= lastDate)]
        # Remaining lease must exceed p_leaseEnter + 1 years.
        leaseStart = int(p_leaseEnter) + 1
        leaseData = dataPeriod[dataPeriod['remaining_lease'] > leaseStart]
        towndata = leaseData[leaseData['flat_type'] == flatType]
        return selectTowns(towndata, get_status())

    def drawBoxplot(resale_prices_combined, labels):
        """Clear the main axes and draw one themed box per ticked town."""
        ax.clear()
        mainTheme()
        if resale_prices_combined:
            bp_dict = ax.boxplot(resale_prices_combined,
                                 labels=labels,
                                 patch_artist=True)
            boxplotTheme(bp_dict)

    def removeText():
        """Remove the transaction popup, if one is shown."""
        global txt
        if txt is not None:
            txt.remove()
            txt = None

    def onclick(event):
        """Show transactions of the clicked box within +/-5,000 of the price."""
        global txt
        # Widget axes report their own coordinates; clicks outside any
        # axes report None.
        if event.inaxes is not ax or event.xdata is None or event.ydata is None:
            return
        nearestPrice = round(event.ydata, -3)
        if nearestPrice <= 0:
            return
        boxIndex = int(round(event.xdata, 0)) - 1
        if not 0 <= boxIndex < len(townSelect):
            return
        townDet = townSelect[boxIndex]
        if not isinstance(townDet, np.ndarray):
            return  # no box at this position
        minPrice = nearestPrice - 5000
        maxPrice = nearestPrice + 5000
        pointSel = townDet[(townDet['resale_price'] >= minPrice)
                           & (townDet['resale_price'] < maxPrice)]
        if len(pointSel) > 0:
            removeText()
            txt = ax.text(event.xdata,
                          event.ydata,
                          str(pointSel),
                          horizontalalignment='center',
                          verticalalignment='center',
                          bbox=dict(facecolor='red', alpha=0.5))
        fig.canvas.draw()

    def offclick(event):
        """Hide the popup created by ``onclick`` when the button is released."""
        if txt is not None:
            removeText()
            fig.canvas.draw()

    def func(label):
        """CheckButtons callback: toggle ``label``'s town and redraw."""
        global index
        index = line_labels.index(label)
        line = checkBoxLines[index]
        line.set_visible(not line.get_visible())
        status = get_status()
        resale_prices_combined, townName = selectTowns(towndata, status)
        drawBoxplot(resale_prices_combined, townName)
        print(status)
        plt.draw()

    def refresh():
        """Re-filter with the current slider/radio values and redraw."""
        uYear = str(sYear.val)
        resale_prices_combined, labels = genCombineResalePrices(
            uYear[0:4], "01", sLease.val, flatType)
        drawBoxplot(resale_prices_combined, labels)
        fig.canvas.draw_idle()

    def update(val):
        """Slider callback."""
        refresh()

    def flatTypefunc(label):
        """Radio-button callback: switch the flat type and redraw."""
        global flatType
        flatType = label
        refresh()

    ##############            Main Process Start Here            ###############

    data = np.genfromtxt('data/resale-flat-prices.csv',
                         skip_header=1,
                         dtype=[('month', 'U7'), ('town', 'U30'),
                                ('flat_type', 'U20'), ('block', 'U4'),
                                ('street_name', 'U50'),
                                ('storey_range', 'U15'),
                                ('floor_area_sqm', 'U3'),
                                ('flat_model', 'U25'),
                                ('lease_commence_date', 'U4'),
                                ('remaining_lease', 'i2'),
                                ('resale_price', 'f8')],
                         delimiter=",",
                         missing_values=['na', '-'],
                         filling_values=[0])
    townList = sorted(set(data['town']))

    years = [int(month[0:4]) for month in set(data['month'])]
    minsYear = min(years)
    maxsYear = max(years)

    # Default "Year From": the year of the date ~11.5 months ago.
    init_sYear = (datetime.today() - timedelta(days=350)).strftime('%Y')

    fig, ax = plt.subplots(figsize=(19, 15))
    plt.subplots_adjust(left=0.25, bottom=0.25)
    mainTheme()

    # Hidden dummy lines, one per town, whose visibility records whether the
    # town is ticked; the box plot is built from that state.
    checkBoxLines = np.zeros(len(townList), object)
    for l, town in enumerate(townList):
        checkBoxLines[l], = ax.plot(0,
                                    0,
                                    visible=False,
                                    lw=2,
                                    label=town)

    # A single CheckButtons widget; it keeps its own tick state in sync with
    # the proxy lines, so it is never re-created.
    rax = plt.axes([0.05, 0.2, 0.13, 0.7], facecolor=axcolor)

    line_labels = [str(line.get_label()) for line in checkBoxLines]
    check = CheckButtons(rax, line_labels, get_status())
    checkButtonThemes(check)

    check.on_clicked(func)
    print(check.get_status())

    axYear = plt.axes([0.5, 0.15, 0.4, 0.03], facecolor=axcolor)
    axLease = plt.axes([0.5, 0.1, 0.4, 0.03], facecolor=axcolor)

    sYear = Slider(axYear,
                   'Year From',
                   minsYear,
                   maxsYear,
                   valinit=int(init_sYear),
                   valstep=1)
    sLease = Slider(axLease,
                    'Min Remain Lease Year',
                    1,
                    99,
                    valinit=int(init_sLease),
                    valstep=1)

    radioax = plt.axes([0.25, 0.08, 0.11, 0.12], facecolor=axcolor)
    radio = RadioButtons(radioax, ('1 ROOM', '2 ROOM', '3 ROOM', '4 ROOM',
                                   '5 ROOM', 'EXECUTIVE', 'MULTI-GENE'),
                         active=3)

    sYear.on_changed(update)
    sLease.on_changed(update)
    radio.on_clicked(flatTypefunc)

    fig.canvas.mpl_connect('button_press_event', onclick)
    fig.canvas.mpl_connect('button_release_event', offclick)

    resale_prices_combined, labels = genCombineResalePrices(
        yearEnter, monthNum, p_leaseEnter, flatType)

    if resale_prices_combined:
        bp_dict = ax.boxplot(resale_prices_combined,
                             labels=labels,
                             patch_artist=True)
        boxplotTheme(bp_dict)

    plt.show()
예제 #16
0
class PeakParameterPlot(object):
    """Interactive matplotlib viewer for tuning peak filter cutoffs.

    :meth:`plot` draws one row of four panels per region pair, plus sliders
    for every cutoff and a 'Show loops' check button. The four panels are:

    * the raw Hi-C map with edges highlighted;
    * the same map with the observed, O/E, FDR and mappability filters applied;
    * the ``oe_d`` map;
    * the ``fdr_d`` map.

    Moving a slider updates the matching filter and re-filters every row.
    Ticking the check button highlights the entries of the filtered matrix
    on the raw map.
    """

    # Neighborhood keys, in the order used for sliders, dicts and log output.
    _NEIGHBORHOODS = ('d', 'h', 'v', 'l')
    # Suffix of each neighborhood in the filters' keyword/attribute names:
    # the 'l' neighborhood is spelled 'll' (e.g. fdr_ll_cutoff).
    _FILTER_SUFFIXES = {'d': 'd', 'h': 'h', 'v': 'v', 'l': 'll'}

    def __init__(self, peaks, font_size=5, observed_init=1.,
                 oe_d_init=1., oe_h_init=1., oe_v_init=1., oe_l_init=1.,
                 fdr_d_init=.1, fdr_h_init=.1, fdr_v_init=.1, fdr_l_init=.1,
                 mappability_d_init=0., mappability_h_init=0.,
                 mappability_v_init=0., mappability_l_init=0.,
                 oe_slider_range=(0, 5), oe_slider_step=0.1,
                 fdr_slider_range=(0, .1), fdr_slider_step=0.001,
                 mappability_slider_range=(0, 1.), mappability_slider_step=0.05,
                 observed_range=(0, 50), observed_step=1,
                 **kwargs):
        """Store initial cutoffs and slider settings; nothing is drawn yet.

        :param peaks: peak data passed to every EdgeHicPlot
        :param font_size: matplotlib font size applied in :meth:`plot`
        :param observed_init: initial uncorrected (observed) cutoff
        :param oe_*_init: initial O/E cutoff per neighborhood (d, h, v, l)
        :param fdr_*_init: initial FDR cutoff per neighborhood
        :param mappability_*_init: initial mappability cutoff per neighborhood
        :param *_slider_range: (min, max) of the matching sliders
        :param *_slider_step: step of the matching sliders
        :param observed_range: (min, max) of the uncorrected slider
        :param observed_step: step of the uncorrected slider
        :param kwargs: forwarded to the raw and filtered EdgeHicPlots
        """
        self.peaks = peaks
        self.font_size = font_size
        self.hic_args = kwargs

        self.observed_init = observed_init

        self.oe_init = {
            'd': oe_d_init,
            'v': oe_v_init,
            'h': oe_h_init,
            'l': oe_l_init,
        }

        self.fdr_init = {
            'd': fdr_d_init,
            'v': fdr_v_init,
            'h': fdr_h_init,
            'l': fdr_l_init,
        }

        self.mappability_init = {
            'd': mappability_d_init,
            'v': mappability_v_init,
            'h': mappability_h_init,
            'l': mappability_l_init,
        }

        self.fdr_range = fdr_slider_range
        self.fdr_step = fdr_slider_step
        self.oe_range = oe_slider_range
        self.oe_step = oe_slider_step
        self.mappability_range = mappability_slider_range
        self.mappability_step = mappability_slider_step
        self.observed_range = observed_range
        self.observed_step = observed_step

        # current cutoffs start as (independent) copies of the initial ones
        self.observed_cutoff = self.observed_init
        self.oe_cutoffs = self.oe_init.copy()
        self.fdr_cutoffs = self.fdr_init.copy()
        self.mappability_cutoffs = self.mappability_init.copy()

        self.ax_fdr_sliders = {}
        self.fdr_sliders = {}
        self.oe_sliders = {}
        self.ax_oe_sliders = {}
        self.ax_mappability_sliders = {}
        self.mappability_sliders = {}
        self.observed_slider = None

        self.observed_filter = None
        self.fdr_filter = None
        self.mappability_filter = None
        self.oe_filter = None

        self.filtered_plots = []
        self.hic_plots = []

        self.button = None
        self.region_pairs = []

        self.fig = None

    @classmethod
    def _filter_kwargs(cls, prefix, cutoffs):
        """Map ``{neighborhood: cutoff}`` to filter keyword/attribute names.

        ``_filter_kwargs('fdr', {'d': .1, 'h': .2, 'v': .3, 'l': .4})`` gives
        ``{'fdr_d_cutoff': .1, 'fdr_h_cutoff': .2, 'fdr_v_cutoff': .3,
        'fdr_ll_cutoff': .4}``.  Filter construction and slider updates both
        go through this, so neighborhoods can no longer be crossed.
        """
        return {'{}_{}_cutoff'.format(prefix, cls._FILTER_SUFFIXES[n]): cutoffs[n]
                for n in cls._NEIGHBORHOODS}

    @classmethod
    def _read_sliders(cls, sliders):
        """Return ``{neighborhood: slider.val}`` in d, h, v, l order."""
        return {n: sliders[n].val for n in cls._NEIGHBORHOODS}

    @staticmethod
    def _apply_cutoffs(peak_filter, filter_kwargs):
        """Set every ``name: value`` of *filter_kwargs* as an attribute of *peak_filter*."""
        for name, value in filter_kwargs.items():
            setattr(peak_filter, name, value)

    @staticmethod
    def _parse_region_pair(pair):
        """Normalize one :meth:`plot` argument to ``(region1, region2, vmax)``."""
        if isinstance(pair, string_types):
            r = as_region(pair)
            pair = (r, r)
        elif isinstance(pair, GenomicRegion):
            pair = (pair, pair)

        try:
            r1, r2, vmax = pair
        except ValueError:
            r1, r2 = pair
            # (region, number) means: region against itself with this vmax
            if isinstance(r2, (float, int)):
                vmax = r2
                r2 = r1
            else:
                vmax = None
        return as_region(r1), as_region(r2), vmax

    def plot(self, *regions):
        """Build the interactive figure and return it.

        :param regions: each entry is a region string, a GenomicRegion,
                        ``(r1, r2)``, ``(r, vmax)`` or ``(r1, r2, vmax)``;
                        a single region is shown against itself
        :return: the matplotlib Figure
        """
        plt.rcParams.update({'font.size': self.font_size})

        self.region_pairs = [self._parse_region_pair(pair) for pair in regions]
        n_rows = len(self.region_pairs)

        # One tall row per region pair, then three slider rows (O/E, FDR,
        # mappability) and a bottom row for the uncorrected slider and button.
        gs = grd.GridSpec(n_rows + 4, 4,
                          height_ratios=[10] * n_rows + [1, 1, 1, 3],
                          wspace=0.3, hspace=0.5)

        self.fig = plt.figure(figsize=(10, n_rows * 2 + 2), dpi=150)

        # uncorrected (observed) slider
        inner_observed_gs = grd.GridSpecFromSubplotSpec(3, 1,
                                                        subplot_spec=gs[n_rows + 3, 0],
                                                        wspace=0.0, hspace=0.0)
        ax_observed_slider = plt.subplot(inner_observed_gs[0, 0])
        self.observed_slider = Slider(ax_observed_slider, 'uncorrected',
                                      self.observed_range[0], self.observed_range[1],
                                      valinit=self.observed_init, valstep=self.observed_step)

        # per-neighborhood sliders, one grid column per neighborhood
        self.ax_oe_sliders = dict()
        self.oe_sliders = dict()
        self.ax_fdr_sliders = dict()
        self.fdr_sliders = dict()
        self.ax_mappability_sliders = dict()
        self.mappability_sliders = dict()
        for i, neighborhood in enumerate(self._NEIGHBORHOODS):
            # O/E
            self.ax_oe_sliders[neighborhood] = plt.subplot(gs[n_rows + 0, i])
            self.oe_sliders[neighborhood] = Slider(self.ax_oe_sliders[neighborhood],
                                                   'O/E {}'.format(neighborhood.upper()),
                                                   self.oe_range[0], self.oe_range[1],
                                                   valinit=self.oe_init[neighborhood],
                                                   valstep=self.oe_step)
            # FDR
            self.ax_fdr_sliders[neighborhood] = plt.subplot(gs[n_rows + 1, i])
            self.fdr_sliders[neighborhood] = Slider(self.ax_fdr_sliders[neighborhood],
                                                    'FDR {}'.format(neighborhood.upper()),
                                                    self.fdr_range[0], self.fdr_range[1],
                                                    valinit=self.fdr_init[neighborhood],
                                                    valstep=self.fdr_step)
            # Mappability
            self.ax_mappability_sliders[neighborhood] = plt.subplot(gs[n_rows + 2, i])
            self.mappability_sliders[neighborhood] = Slider(self.ax_mappability_sliders[neighborhood],
                                                            'Map {}'.format(neighborhood.upper()),
                                                            self.mappability_range[0],
                                                            self.mappability_range[1],
                                                            valinit=self.mappability_init[neighborhood],
                                                            valstep=self.mappability_step)

        # check button
        inner_button_gs = grd.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[n_rows + 3, 1],
                                                      wspace=0.0, hspace=0.0)
        ax_button = plt.subplot(inner_button_gs[0, 0])
        self.button = CheckButtons(ax_button, ['Show loops'], [False])

        # Filters are built from the same neighborhood -> name mapping the
        # slider callbacks use, so every initial cutoff lands on its own
        # neighborhood (the FDR filter previously received h/v/l crossed).
        self.observed_filter = ObservedPeakFilter(cutoff=self.observed_init)
        self.oe_filter = EnrichmentPeakFilter(**self._filter_kwargs('enrichment', self.oe_init))
        self.fdr_filter = FdrPeakFilter(**self._filter_kwargs('fdr', self.fdr_init))
        self.mappability_filter = MappabilityPeakFilter(
            **self._filter_kwargs('mappability', self.mappability_init))

        self.hic_plots = []
        self.filtered_plots = []
        for i, (r1, r2, vmax) in enumerate(self.region_pairs):
            ax_hic = plt.subplot(gs[i, 0])
            ax_filtered = plt.subplot(gs[i, 1])
            ax_oe = plt.subplot(gs[i, 2])
            ax_fdr = plt.subplot(gs[i, 3])

            hic_plot = EdgeHicPlot(self.peaks,
                                   ax=ax_hic, show_colorbar=False, adjust_range=False,
                                   unmappable_color='white', vmax=vmax,
                                   highlight_edges=True, highlight_limit=200,
                                   draw_tick_legend=False,
                                   **self.hic_args)
            filtered_plot = EdgeHicPlot(self.peaks,
                                        ax=ax_filtered, show_colorbar=False, adjust_range=False,
                                        unmappable_color='white', vmax=vmax,
                                        draw_tick_legend=False,
                                        **self.hic_args)
            oe_plot = EdgeHicPlot(self.peaks, plot_field='oe_d', default_value=0, norm='lin',
                                  ax=ax_oe, show_colorbar=False, colormap='white_red', adjust_range=False,
                                  vmin=1, vmax=4, log2=False, unmappable_color='white',
                                  draw_tick_legend=False,)
            fdr_plot = EdgeHicPlot(self.peaks, plot_field='fdr_d', default_value=1, norm='lin',
                                   ax=ax_fdr, show_colorbar=False, adjust_range=False,
                                   vmin=0, vmax=0.05, colormap='Greys_r', unmappable_color='white',
                                   draw_tick_legend=False)

            # only the filtered panel is affected by the cutoff sliders
            for peak_filter in (self.observed_filter, self.oe_filter,
                                self.fdr_filter, self.mappability_filter):
                filtered_plot.hic_buffer.add_filter(peak_filter)

            self.hic_plots.append(hic_plot)
            self.filtered_plots.append(filtered_plot)

            hic_plot.plot((r1, r2))
            filtered_plot.plot((r1, r2))
            oe_plot.plot((r1, r2))
            fdr_plot.plot((r1, r2))

            ax_filtered.set_yticklabels([])
            ax_oe.set_yticklabels([])
            ax_fdr.set_yticklabels([])

        for neighborhood in self._NEIGHBORHOODS:
            self.oe_sliders[neighborhood].on_changed(self.update_oe_filter)
            self.fdr_sliders[neighborhood].on_changed(self.update_fdr_filter)
            self.mappability_sliders[neighborhood].on_changed(self.update_mappability_filter)

        self.observed_slider.on_changed(self.update_observed_filter)

        self.button.on_clicked(self.refresh_plots)
        self.refresh_plots()
        return self.fig

    def refresh_plots(self, event=None):
        """Re-filter every region pair and redraw; *event* is ignored."""
        show_loops = self.button.get_status()[0]
        for (r1, r2, _), filtered_plot, hic_plot in zip(self.region_pairs,
                                                        self.filtered_plots,
                                                        self.hic_plots):
            filtered_plot.refresh((r1, r2))

            if show_loops:
                # highlight what survived filtering on the raw map
                hic_plot.update_highlights(filtered_plot.hic_buffer._last_matrix)
            else:
                hic_plot.update_highlights()

        self.fig.canvas.draw()

    def update_observed_filter(self, event):
        """Slider callback: apply the uncorrected cutoff and refresh."""
        self.observed_cutoff = self.observed_slider.val
        self.observed_filter.cutoff = self.observed_cutoff
        logger.info("Observed cutoff set to %s", self.observed_cutoff)
        self.refresh_plots()

    def update_oe_filter(self, event):
        """Slider callback: apply all four O/E cutoffs and refresh."""
        self.oe_cutoffs = self._read_sliders(self.oe_sliders)
        self._apply_cutoffs(self.oe_filter, self._filter_kwargs('enrichment', self.oe_cutoffs))
        logger.info("O/E cutoffs set to: %s", self.oe_cutoffs)
        self.refresh_plots()

    def update_fdr_filter(self, event):
        """Slider callback: apply all four FDR cutoffs and refresh."""
        self.fdr_cutoffs = self._read_sliders(self.fdr_sliders)
        self._apply_cutoffs(self.fdr_filter, self._filter_kwargs('fdr', self.fdr_cutoffs))
        logger.info("FDR cutoffs set to: %s", self.fdr_cutoffs)
        self.refresh_plots()

    def update_mappability_filter(self, event):
        """Slider callback: apply all four mappability cutoffs and refresh."""
        self.mappability_cutoffs = self._read_sliders(self.mappability_sliders)
        self._apply_cutoffs(self.mappability_filter,
                            self._filter_kwargs('mappability', self.mappability_cutoffs))
        logger.info("Mappability cutoffs set to: %s", self.mappability_cutoffs)
        self.refresh_plots()
예제 #17
0
class plt_one_addpt_onclick:
    """Interactive plot: click to add labelled points, then fit a model.

    Clicks inside the main axes add a sample. A click above y=0.5 adds a
    malignant one (y=1, red 'x'); otherwise it adds a benign one (y=0,
    hollow circle).

    The button runs an animated gradient-descent fit, logistic or linear
    depending on ``logistic``. The check box toggles shading of the regions
    on either side of the 0.5 decision threshold.
    """

    def __init__(self, x, y, w, b, logistic=True):
        """
        Args:
          x (ndarray): feature values (tumor size); reshaped to a column for the fit
          y (ndarray): labels, 1 = malignant, 0 = benign
          w (ndarray): initial weight(s); a deep copy is kept
          b (scalar):  initial bias
          logistic (bool): fit logistic regression if True, linear otherwise
        """
        self.logistic = logistic
        pos = y == 1
        neg = y == 0

        fig, ax = plt.subplots(1, 1, figsize=(8, 4))
        # hide the interactive-backend chrome (toolbar, header, footer)
        fig.canvas.toolbar_visible = False
        fig.canvas.header_visible = False
        fig.canvas.footer_visible = False

        # leave room below the plot for the button and the check box
        plt.subplots_adjust(bottom=0.25)
        ax.scatter(x[pos],
                   y[pos],
                   marker='x',
                   s=80,
                   c='red',
                   label="malignant")
        ax.scatter(x[neg],
                   y[neg],
                   marker='o',
                   s=100,
                   label="benign",
                   facecolors='none',
                   edgecolors=dlblue,
                   lw=3)
        ax.set_ylim(-0.05, 1.1)
        xlim = ax.get_xlim()
        # double the x range so there is space to click new points
        ax.set_xlim(xlim[0], xlim[1] * 2)
        ax.set_ylabel('y')
        ax.set_xlabel('Tumor Size')
        self.alegend = ax.legend(loc='lower right')
        if self.logistic:
            ax.set_title("Example of Logistic Regression on Categorical Data")
        else:
            ax.set_title("Example of Linear Regression on Categorical Data")

        ax.text(0.65,
                0.8,
                "[Click to add data points]",
                size=10,
                transform=ax.transAxes)

        axcalc = plt.axes([0.1, 0.05, 0.38, 0.075])  # l, b, w, h
        axthresh = plt.axes([0.5, 0.05, 0.38, 0.075])  # l, b, w, h
        # artists drawn by draw_thresh (removed again by remove_thresh)
        self.tlist = []

        self.fig = fig
        self.ax = [ax, axcalc, axthresh]
        self.x = x
        self.y = y
        self.w = copy.deepcopy(w)
        self.b = b
        f_wb = np.matmul(self.x.reshape(-1, 1), self.w) + self.b
        if self.logistic:
            self.aline = self.ax[0].plot(self.x, sigmoid(f_wb), color=dlblue)
            self.bline = self.ax[0].plot(self.x, f_wb, color=dlorange, lw=1)
        else:
            # NOTE(review): the linear model is initially drawn through
            # sigmoid() too, while calc_linear plots w*x + b -- confirm
            # whether this is intended.
            self.aline = self.ax[0].plot(self.x, sigmoid(f_wb), color=dlblue)

        self.cid = fig.canvas.mpl_connect('button_press_event', self.add_data)
        if self.logistic:
            self.bcalc = Button(axcalc,
                                'Run Logistic Regression (click)',
                                color=dlblue)
            self.bcalc.on_clicked(self.calc_logistic)
        else:
            self.bcalc = Button(axcalc,
                                'Run Linear Regression (click)',
                                color=dlblue)
            self.bcalc.on_clicked(self.calc_linear)
        self.bthresh = CheckButtons(
            axthresh, ('Toggle 0.5 threshold (after regression)', ))
        self.bthresh.on_clicked(self.thresh)
        self.resize_sq(self.bthresh)

    def add_data(self, event):
        """Mouse-press callback: add a point where the main axes was clicked."""
        # clicks on the button/check-box axes are ignored
        if event.inaxes == self.ax[0]:
            x_coord = event.xdata
            y_coord = event.ydata

            # the click height decides the label; the point snaps to y=1 or y=0
            if y_coord > 0.5:
                self.ax[0].scatter(x_coord, 1, marker='x', s=80, c='red')
                self.y = np.append(self.y, 1)
            else:
                self.ax[0].scatter(x_coord,
                                   0,
                                   marker='o',
                                   s=100,
                                   facecolors='none',
                                   edgecolors=dlblue,
                                   lw=3)
                self.y = np.append(self.y, 0)
            self.x = np.append(self.x, x_coord)
        self.fig.canvas.draw()

    def calc_linear(self, event):
        """Button callback: animate a linear-regression fit of the points."""
        if self.bthresh.get_status()[0]:
            self.remove_thresh()
        # growing iteration counts so early steps are visible in the animation
        for it in [1, 1, 1, 1, 1, 2, 4, 8, 16, 32, 64, 128, 256]:
            self.w, self.b, _ = gradient_descent(self.x.reshape(-1, 1),
                                                 self.y.reshape(-1, 1),
                                                 self.w.reshape(-1, 1),
                                                 self.b,
                                                 0.01,
                                                 it,
                                                 logistic=False,
                                                 lambda_=0,
                                                 verbose=False)
            self.aline[0].remove()
            self.alegend.remove()
            y_hat = np.matmul(self.x.reshape(-1, 1), self.w) + self.b
            self.aline = self.ax[0].plot(
                self.x,
                y_hat,
                color=dlblue,
                label=f"y = {np.squeeze(self.w):0.2f}x+({self.b:0.2f})")
            self.alegend = self.ax[0].legend(loc='lower right')
            time.sleep(0.3)
            self.fig.canvas.draw()
        if self.bthresh.get_status()[0]:
            self.draw_thresh()
            self.fig.canvas.draw()

    def calc_logistic(self, event):
        """Button callback: animate a logistic-regression fit of the points."""
        if self.bthresh.get_status()[0]:
            self.remove_thresh()
        # growing iteration counts so early steps are visible in the animation
        for it in [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
            self.w, self.b, _ = gradient_descent(self.x.reshape(-1, 1),
                                                 self.y.reshape(-1, 1),
                                                 self.w.reshape(-1, 1),
                                                 self.b,
                                                 0.1,
                                                 it,
                                                 logistic=True,
                                                 lambda_=0,
                                                 verbose=False)
            self.aline[0].remove()
            self.bline[0].remove()
            self.alegend.remove()
            xlim = self.ax[0].get_xlim()
            x_hat = np.linspace(*xlim, 30)
            y_hat = sigmoid(np.matmul(x_hat.reshape(-1, 1), self.w) + self.b)
            self.aline = self.ax[0].plot(x_hat,
                                         y_hat,
                                         color=dlblue,
                                         label="y = sigmoid(z)")
            f_wb = np.matmul(x_hat.reshape(-1, 1), self.w) + self.b
            self.bline = self.ax[0].plot(
                x_hat,
                f_wb,
                color=dlorange,
                lw=1,
                label=f"z = {np.squeeze(self.w):0.2f}x+({self.b:0.2f})")
            self.alegend = self.ax[0].legend(loc='lower right')
            time.sleep(0.3)
            self.fig.canvas.draw()
        if self.bthresh.get_status()[0]:
            self.draw_thresh()
            self.fig.canvas.draw()

    def thresh(self, event):
        """Check-box callback: show or hide the threshold shading."""
        if self.bthresh.get_status()[0]:
            self.draw_thresh()
        else:
            self.remove_thresh()

    def draw_thresh(self):
        """Shade both sides of the 0.5 decision threshold and label them."""
        ws = np.squeeze(self.w)
        # x where the model output crosses 0.5: z = 0 for logistic
        # (sigmoid(0) = 0.5), w*x + b = 0.5 for linear
        xp5 = -self.b / ws if self.logistic else (0.5 - self.b) / ws
        ylim = self.ax[0].get_ylim()
        xlim = self.ax[0].get_xlim()
        a = self.ax[0].fill_between([xlim[0], xp5], [ylim[1], ylim[1]],
                                    alpha=0.2,
                                    color=dlblue)
        b = self.ax[0].fill_between([xp5, xlim[1]], [ylim[1], ylim[1]],
                                    alpha=0.2,
                                    color=dldarkred)
        c = self.ax[0].annotate("Malignant",
                                xy=[xp5, 0.5],
                                xycoords='data',
                                xytext=[30, 5],
                                textcoords='offset points')
        d = FancyArrowPatch(
            posA=(xp5, 0.5),
            posB=(xp5 + 1.5, 0.5),
            color=dldarkred,
            arrowstyle='simple, head_width=5, head_length=10, tail_width=0.0',
        )
        self.ax[0].add_artist(d)

        e = self.ax[0].annotate("Benign",
                                xy=[xp5, 0.5],
                                xycoords='data',
                                xytext=[-70, 5],
                                textcoords='offset points',
                                ha='left')
        f = FancyArrowPatch(
            posA=(xp5, 0.5),
            posB=(xp5 - 1.5, 0.5),
            color=dlblue,
            arrowstyle='simple, head_width=5, head_length=10, tail_width=0.0',
        )
        self.ax[0].add_artist(f)
        self.tlist = [a, b, c, d, e, f]

        self.fig.canvas.draw()

    def remove_thresh(self):
        """Remove the artists drawn by draw_thresh (if any) and redraw."""
        for artist in self.tlist:
            artist.remove()
        # Forget the removed artists: removing an already-removed artist
        # raises, so a second call must find nothing left to remove.
        self.tlist = []
        self.fig.canvas.draw()

    def resize_sq(self, bcid):
        """Make the first check box of *bcid* three times as tall.

        Uses the rectangle/line based CheckButtons internals (``rectangles``
        and ``lines``).  On matplotlib versions that no longer provide them
        the widget is left unchanged instead of raising AttributeError.
        """
        rects = getattr(bcid, 'rectangles', None)
        if not rects:
            return
        box = rects[0]
        box.set_height(3 * box.get_height())

        bbox = box.get_bbox()
        ymin, ymax = bbox.y0, bbox.y1

        lines = getattr(bcid, 'lines', None)
        if not lines:
            return
        # stretch both strokes of the check mark over the enlarged box
        lines[0][0].set_ydata([ymax, ymin])
        lines[0][1].set_ydata([ymin, ymax])
예제 #18
0
class TypogaDraw:
    """Matplotlib front-end that plots the Typoga high-score files.

    One check-button menu selects which game modes (score files) to show and
    a second one selects which score columns.  Every click redraws one
    figure per selected game mode, with one bar subplot per selected column.
    """

    title = "Typoga Scores"
    randFilePath = "scores/highScore_random.txt"
    leadFilePath = "scores/highScore_phrases.txt"
    progFilePath = "scores/highScore_programming.txt"

    def __init__(self):
        # Set window title and size of the control-panel figure
        plt.figure(num=self.title, figsize=(3, 6))

    def checkForScoreFiles(self):
        """Exit the program unless at least one high-score file exists.

        Also resets both check-button handles and the subplot position.
        """
        paths = (self.randFilePath, self.leadFilePath, self.progFilePath)
        if not any(os.path.exists(path) for path in paths):
            print("You should first play the game.\n Run ./typoga.sh")
            # SystemExit directly: the `exit` builtin is only added by `site`
            raise SystemExit(0)

        # Previously checkButGO was reset twice and checkButSO never.
        self.checkButGO = None
        self.checkButSO = None

        self.pos = 1

    def checkBoxesInit(self):
        """Create both check-button menus and connect them to the handler."""
        # Exits the program if there is nothing to show.
        self.checkForScoreFiles()

        # Game-options menu (upper half of the control figure), all unchecked
        raxGO = plt.axes([0.02, 0.5, 1, 0.5], frameon=False)
        self.checkButGO = CheckButtons(raxGO, GameOptions,
                                       (False,) * len(GameOptions))
        self.checkButGO.on_clicked(self.handleUserSelection)

        # Score-options menu (lower half), all unchecked
        raxSO = plt.axes([0.02, 0.05, 1, 0.5], frameon=False)
        self.checkButSO = CheckButtons(raxSO, SocreOptions,
                                       (False,) * len(SocreOptions))
        self.checkButSO.on_clicked(self.handleUserSelection)

    def handleUserSelection(self, label):
        """Redraw the per-game score figures after any check-box click.

        *label* (the clicked entry) is unused; the full state of both menus
        is read instead.
        """
        paths = (self.randFilePath, self.leadFilePath, self.progFilePath)
        for i, path in enumerate(paths):
            # figures are keyed by game-option name, so clicks reuse them
            fig = plt.figure(num=GameOptions[i], figsize=(6, 8))
            if self.checkButGO.get_status()[i]:
                fig.clear()
                self.parseFile(path)
                fig.show()
            else:
                plt.close()

    def parseFile(self, pfile):
        """Parse the score file *pfile* and plot every selected score column.

        The first three lines are a header.  Every other non-blank line is
        ``:``-separated, and field ``i + 1`` holds the value for score
        option ``i``.
        """
        if not os.path.exists(pfile):
            print("File %s does not exist" % pfile)
            return

        with open(pfile, "r") as hsf:
            lines = hsf.readlines()[3:]  # skip header

        # Split each score line once; blank lines (e.g. trailing ones) carry
        # no score and used to crash with IndexError.
        rows = [line.strip().split(":") for line in lines if line.strip()]

        # Reset plot position
        self.pos = 1
        for i, selected in enumerate(self.checkButSO.get_status()):
            if selected:
                scores = [float(row[i + 1]) for row in rows]
                self.plotData(scores, i)

    def plotData(self, scores, index):
        """Draw a bar subplot of *scores* for score option *index*.

        The best score is highlighted in orange and a horizontal line marks
        the mean.  Note that *scores* is extended in place.
        """
        if len(scores) == 0:
            scores.append(0)

        # Mean of the real scores, computed before the spacer is appended
        mean = np.mean(scores)
        # Trailing zero bar leaves room to print the mean value
        scores.append(0)

        x = np.arange(1, len(scores) + 1, 1)

        # one subplot row per selected score option
        rows = self.checkButSO.get_status().count(True)
        plt.subplot(rows, 1, self.pos)
        self.pos += 1

        plt.ylabel(SocreOptions[index])
        plt.xticks(x)

        plt.bar(x, scores)
        # label every real bar (not the spacer) at half its height
        for i, value in enumerate(scores[:-1]):
            plt.text(i + 0.8, value / 2, ("%.02f" % value),
                     fontsize=8,
                     color='black')

        # re-draw the highest bar in orange
        max_val = np.argmax(scores)
        plt.bar(max_val + 1, scores[max_val], color='orange')

        # mean line with its value printed at the right end
        plt.plot([0, len(scores)], [mean, mean], 'black')
        plt.text(len(scores),
                 mean, ("%.02f" % mean),
                 fontsize=10,
                 color='black')

    def run(self):
        """Build the menus and enter the matplotlib event loop."""
        self.checkBoxesInit()
        plt.show()
예제 #19
0
class LoaderUI(object):
    """Matplotlib dialog for opening an SPE file and selecting sensor rows.

    The constructor builds the window and blocks in ``plt.show`` until the
    window is closed.  Afterwards ``success`` is True only if 'Load Selected'
    was pressed with a file loaded and rows selected; ``spe_file``,
    ``selected_data`` and ``pol_angle`` then hold the result.
    """

    def __init__(self):
        self.full_sensor_data = None
        self.selected_data = None
        self.spe_file = None
        self.selector = None
        self.pol_angle = None
        self.success = False

        style_dir = os.path.dirname(__file__)
        plt.style.use(os.path.join(style_dir, 'style', 'custom-wa.mplstyle'))

        fig = plt.figure()
        fig.set_size_inches(20, 12, forward=True)
        # FigureCanvasBase.set_window_title was removed in matplotlib 3.6;
        # the figure manager's method exists on old and new versions.
        manager = getattr(fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title('Load Data')

        grid_shape = (16, 28)
        # Layout: image axes plus open, load, refresh buttons and text boxes
        self.full_sensor_ax = plt.subplot2grid(grid_shape, (0, 0),
                                               colspan=13,
                                               rowspan=13)
        self.selected_ax = plt.subplot2grid(grid_shape, (0, 15),
                                            colspan=13,
                                            rowspan=4)
        axopen = plt.subplot2grid(grid_shape, (14, 4), colspan=4)
        axload = plt.subplot2grid(grid_shape, (14, 20), colspan=4)
        axfull_lambda = plt.subplot2grid(grid_shape, (11, 16), colspan=4)
        axpix_min = plt.subplot2grid(grid_shape, (7, 16), colspan=4)
        axpix_max = plt.subplot2grid(grid_shape, (7, 23), colspan=4)
        axrefresh = plt.subplot2grid(grid_shape, (11, 23), colspan=4)
        axpol_angle = plt.subplot2grid(grid_shape, (9, 23), colspan=4)

        bload = Button(axload, 'Load Selected', color='0.25', hovercolor='0.3')
        bload.on_clicked(self._load_callback)
        bopen = Button(axopen, 'Open File', color='0.25', hovercolor='0.3')
        bopen.on_clicked(self._open_callback)

        self.chk_full_lambda = CheckButtons(axfull_lambda, ['Full Lambda'],
                                            [True])
        self.chk_full_lambda.on_clicked(self._full_lambda_callback)

        self.ypix_min = TextBox(axpix_min,
                                'Sel. Lower \nbound ',
                                '0',
                                color='0.25',
                                hovercolor='0.3')

        self.ypix_max = TextBox(axpix_max,
                                'Sel. Upper \nbound ',
                                '0',
                                color='0.25',
                                hovercolor='0.3')

        self.refresh_selection = Button(axrefresh,
                                        'Refresh Selection',
                                        color='0.25',
                                        hovercolor='0.3')
        self.refresh_selection.on_clicked(self._refresh_selection_callback)

        self.txt_pol_angle = TextBox(axpol_angle,
                                     'Pol. Angle \n (deg) ',
                                     '0',
                                     color='0.25',
                                     hovercolor='0.3')

        # create the initial selector for the checked 'Full Lambda' state
        self._full_lambda_callback(None)

        plt.show(block=True)

    def _rect_select_callback(self, eclick, erelease):
        """RectangleSelector callback (used when 'Full Lambda' is unchecked).

        Rectangle selections are currently ignored; nothing is stored.
        """

    def _span_select_callback(self, ymin, ymax):
        """Select sensor rows ``ymin..ymax`` (inclusive, snapped outwards).

        Ignored until a file has been opened.  A span starting before row 0
        is clamped to row 0 instead of wrapping to the end of the array.
        """
        if self.full_sensor_data is None:
            return
        ymin = max(int(np.floor(ymin)), 0)
        ymax = int(np.ceil(ymax))
        self.selected_data = self.full_sensor_data[ymin:ymax + 1, :]
        self.ypix_min.set_val(str(ymin))
        self.ypix_max.set_val(str(ymax))
        self._image_selected_data(ymin, ymax)

    def _load_callback(self, event):
        """Accept the selection, close the window and restore rcParams."""
        if self.spe_file is None or self.selected_data is None:
            print(
                'Invalid loader state: must have loaded file and selected data'
            )
            return
        try:
            self.pol_angle = float(self.txt_pol_angle.text)
        except ValueError:
            print('Invalid polarization angle: {!r}'.format(self.txt_pol_angle.text))
            return
        self.success = True
        plt.close()
        # undo the custom style applied in __init__
        mpl.rcParams.update(mpl.rcParamsDefault)

    def _open_callback(self, event):
        """Ask for a file with a Tk dialog and show ``data[0][0]`` of it."""
        # hidden Tk root that only hosts the file dialog
        root = tk.Tk()
        root.withdraw()
        try:
            filename = filedialog.askopenfilename()
        finally:
            root.destroy()
        # Cancelling yields an empty string (on some Tk builds an empty
        # tuple); `is not ''` tested identity, not emptiness.
        if not filename:
            return
        self.spe_file = spe_loader.load_from_files([filename])
        self.full_sensor_data = self.spe_file.data[0][0]
        self._image_full_sensor_data()

    def _full_lambda_callback(self, event):
        """(Re)create the selector on the source image.

        'Full Lambda' checked: vertical span selector (whole rows).
        Unchecked: rectangle selector.
        """
        if self.selector is not None:
            # hide *and* deactivate the old selector so it stops handling events
            self.selector.set_visible(False)
            self.selector.set_active(False)
        chk_status = self.chk_full_lambda.get_status()
        if chk_status[0]:
            self.selector = SpanSelector(self.full_sensor_ax,
                                         self._span_select_callback,
                                         direction='vertical',
                                         minspan=5,
                                         useblit=True,
                                         span_stays=False,
                                         rectprops=dict(facecolor='red',
                                                        alpha=0.2))
        else:
            self.selector = RectangleSelector(
                self.full_sensor_ax,
                self._rect_select_callback,
                drawtype='box',
                useblit=True,
                button=[1, 3],  # don't use middle button
                minspanx=1,
                minspany=0.1,
                spancoords='data',
                interactive=True)

    def _refresh_selection_callback(self, event):
        """Re-apply the row bounds typed into the two text boxes."""
        ymin = int(self.ypix_min.text)
        ymax = int(self.ypix_max.text)
        self._span_select_callback(ymin, ymax)

    def _image_full_sensor_data(self):
        """Redraw the full source image (clearing also drops the selector)."""
        self.full_sensor_ax.clear()
        self._full_lambda_callback(None)
        self.full_sensor_ax.imshow(self.full_sensor_data, aspect='auto')
        self.full_sensor_ax.set_title('Source Data')

    def _image_selected_data(self, ymin, ymax):
        """Show the selected rows with their row and wavelength ranges."""
        self.selected_ax.clear()
        self.selected_ax.imshow(self.selected_data, aspect='auto')
        title = 'Selected: {0} rows ({1} -> {2})\n' \
                '{3} wavelengths ({4:.3f} -> {5:.3f})'.format(self.selected_data.shape[0], ymin, ymax,
                                                     self.selected_data.shape[1], self.spe_file.wavelength[0],
                                                     self.spe_file.wavelength[-1])
        self.selected_ax.set_title(title)
예제 #20
0
class Curator:
    """
    matplotlib display of scrolling image data for curating threads.

    For the current thread the figure shows the full frame with the thread's
    ROI marked in red (optionally with other ROIs in blue), a window around
    the ROI, and the thread's min/max-normalized time series. Widgets scroll
    through time, set the display range, step between threads, and label a
    thread 'keep' or 'trash'. Labels (plus the last viewed thread under the
    key 'last') are written as JSON to
    ``e.root + 'extractor-objects/curate.json'`` at interpreter exit.

    Parameters
    ----------
    e : extractor
        extractor object containing a full set of infilled threads and time
        series; ``e.spool``, ``e.timeseries``, ``e.im``, ``e.root`` and
        ``e.t`` are used.
    window : int, optional
        Size of the window around the ROI passed to ``subaxis`` for the
        second panel (default 100).

    Attributes
    ----------
    ind : int
        thread indexing

    min : number
        min of image data (for setting ranges)

    max : number
        max of image data (for setting ranges)

    curate : dict
        Maps str(thread index) to 'seen', 'keep' or 'trash'; 'last' holds
        the last viewed thread index.
    """
    def __init__(self, e, window=100):
        # get info from extractors
        self.s = e.spool
        self.timeseries = e.timeseries
        self.tf = e.im
        self.tf.t = 0
        self.window = window
        ## num neurons
        self.numneurons = len(self.s.threads)

        self.path = e.root + 'extractor-objects/curate.json'
        # Resume a previous session when its file can be read. Only I/O and
        # decoding problems start a fresh session; a file that merely lacks
        # 'last' keeps its saved labels.
        try:
            with open(self.path) as f:
                self.curate = json.load(f)
            self.ind = int(self.curate.get('last', 0))
        except (OSError, ValueError, TypeError, AttributeError):
            self.curate = {'0': 'seen'}
            self.ind = 0
        # A stale index (e.g. the spool now has fewer threads) would make
        # self.s.threads[self.ind] raise in update_im below.
        if not 0 <= self.ind < self.numneurons:
            self.ind = 0

        # Internal display state:
        #   pointstate: ROIs drawn on the left panel (0 single, 1 same Z, 2 all)
        #   show_settings: threads visited by next/prev
        #                  (0 all, 1 unlabelled, 2 kept, 3 trashed)
        #   showmip: 0 single z-plane, 1 max projection over z
        self.pointstate = 0
        self.show_settings = 0
        self.showmip = 0

        ## index for which time point to display
        self.t = 0

        ### First frame of the current thread
        self.update_im()

        ## Display range
        self.min = np.min(self.im)
        self.max = np.max(self.im)

        ## maximum t
        self.tmax = e.t

        self.restart()
        atexit.register(self.log_curate)

    def restart(self):
        """Build the figure, its three panels and all widgets, then call plt.show()."""
        ## Figure to display
        self.fig = plt.figure()

        ## grid object for complicated subplot handing
        self.grid = plt.GridSpec(4, 2, wspace=0.1, hspace=0.2)

        ### First subplot: whole image with red dot over ROI
        self.ax1 = plt.subplot(self.grid[:3, 0])
        plt.subplots_adjust(bottom=0.4)
        self.img1 = self.ax1.imshow(self.get_im_display(),
                                    cmap='gray',
                                    vmin=0,
                                    vmax=1)
        self._draw_ax1_points()
        self.ax1.axis('off')

        ### Second subplot: some window around the ROI
        plt.subplot(self.grid[:3, 1])
        plt.subplots_adjust(bottom=0.4)

        self.subim, self.offset = subaxis(self.im, self._position(),
                                          self.window)

        self.img2 = plt.imshow(self.get_subim_display(),
                               cmap='gray',
                               vmin=0,
                               vmax=1)
        self.point2 = plt.scatter(self.window / 2 + self.offset[0],
                                  self.window / 2 + self.offset[1],
                                  c='r',
                                  s=40)

        self.title = self.fig.suptitle(self._title_text())
        plt.axis("off")

        ### Third subplot: plotting the timeseries
        self.timeax = plt.subplot(self.grid[3, :])
        plt.subplots_adjust(bottom=0.4)
        self.timeplot, = self.timeax.plot(self._normalized_timeseries())
        plt.axis("off")

        ### Axis for scrolling through t
        self.tr = plt.axes([0.2, 0.15, 0.3, 0.03],
                           facecolor='lightgoldenrodyellow')
        self.s_tr = Slider(self.tr,
                           'Timepoint',
                           0,
                           self.tmax - 1,
                           valinit=0,
                           valstep=1)
        self.s_tr.on_changed(self.update_t)

        ### Axis for setting min/max range
        self.minr = plt.axes([0.2, 0.2, 0.3, 0.03],
                             facecolor='lightgoldenrodyellow')
        self.sminr = Slider(self.minr,
                            'R Min',
                            0,
                            np.max(self.im),
                            valinit=self.min,
                            valstep=1)
        self.maxr = plt.axes([0.2, 0.25, 0.3, 0.03],
                             facecolor='lightgoldenrodyellow')
        self.smaxr = Slider(self.maxr,
                            'R Max',
                            0,
                            np.max(self.im) * 4,
                            valinit=self.max,
                            valstep=1)
        self.sminr.on_changed(self.update_mm)
        self.smaxr.on_changed(self.update_mm)

        ### Axis for buttons for next/previous time series
        #where the buttons are, and their locations
        self.axprev = plt.axes([0.62, 0.20, 0.1, 0.075])
        self.axnext = plt.axes([0.75, 0.20, 0.1, 0.075])
        self.bnext = Button(self.axnext, 'Next')
        self.bnext.on_clicked(self.next)
        self.bprev = Button(self.axprev, 'Previous')
        self.bprev.on_clicked(self.prev)

        #### Axis for button for display
        self.pointsax = plt.axes([0.75, 0.10, 0.1, 0.075])
        self.pointsbutton = RadioButtons(self.pointsax,
                                         ('Single', 'Same Z', 'All'))
        self.pointsbutton.set_active(self.pointstate)
        self.pointsbutton.on_clicked(self.update_pointstate)

        #### Axis for whether to display MIP on left
        self.mipax = plt.axes([0.62, 0.10, 0.1, 0.075])
        self.mipbutton = RadioButtons(self.mipax, ('Single Z', 'MIP'))
        self.mipbutton.set_active(self.showmip)
        self.mipbutton.on_clicked(self.update_mipstate)

        ### Axis for button to keep
        self.keepax = plt.axes([0.87, 0.20, 0.075, 0.075])
        self.keep_button = CheckButtons(self.keepax, ['Keep', 'Trash'],
                                        [False, False])
        self.keep_button.on_clicked(self.keep)

        ### Axis to determine which ones to show
        self.showax = plt.axes([0.87, 0.10, 0.075, 0.075])
        self.showbutton = RadioButtons(
            self.showax, ('All', 'Unlabelled', 'Kept', 'Trashed'))
        self.showbutton.set_active(self.show_settings)
        self.showbutton.on_clicked(self.show)

        plt.show()

    ## Attempting to get autosave when instance gets deleted, not working right now TODO
    def __del__(self):
        # Best-effort autosave. atexit.register(self.log_curate) in __init__
        # holds a bound method (and thus a reference to self), so this
        # normally cannot run before interpreter exit anyway. Skip when
        # __init__ failed before the state log_curate needs existed.
        if hasattr(self, 'curate') and hasattr(self, 'path'):
            self.log_curate()

    # ----- small helpers -------------------------------------------------

    def _position(self):
        """Return the current thread's position at time ``self.t``.

        Index 0 is used as the z-plane, 1 as the row (plotted y) and 2 as
        the column (plotted x).
        """
        return self.s.threads[self.ind].get_position_t(self.t)

    def _title_text(self):
        """Figure title: current series index and its z-plane."""
        return ('Series=' + str(self.ind) + ', Z=' +
                str(int(self._position()[0])))

    def _context_points(self):
        """Positions of the blue context ROIs for the current pointstate.

        Returns None for pointstate 0 ('Single'), the ROIs in the current
        thread's z-plane for 1 ('Same Z') and all ROIs for 2 ('All').
        """
        if self.pointstate == 1:
            return self.s.get_positions_t_z(self.t, self._position()[0])
        if self.pointstate == 2:
            return self.s.get_positions_t(self.t)
        return None

    def _draw_ax1_points(self):
        """Scatter the context ROIs (blue) and the current ROI (red) on ax1."""
        pts = self._context_points()
        if pts is None:
            self.point1 = None
        else:
            self.point1 = self.ax1.scatter(pts[:, 2], pts[:, 1], c='b', s=10)
        pos = self._position()
        self.thispoint = self.ax1.scatter(pos[2], pos[1], c='r', s=10)

    def _normalized_timeseries(self):
        """Current thread's time series min/max-normalized (0 at min, 1 at max)."""
        trace = self.timeseries[:, self.ind]
        lo = np.min(trace)
        return (trace - lo) / (np.max(trace) - lo)

    def _refresh_after_index_change(self):
        """Redraw everything and sync the widgets after self.ind changed."""
        self.update_im()
        self.update_figures()
        self.update_timeseries()
        self.update_buttons()
        self.update_curate()

    def _skipped(self, label):
        """True if a thread with curate ``label`` is hidden by show_settings."""
        if self.show_settings == 0:
            return False
        if self.show_settings == 1:
            return label in ('keep', 'trash')
        if self.show_settings == 2:
            return label != 'keep'
        return label != 'trash'

    def _step_index(self, step):
        """Move self.ind by ``step`` to the next thread not hidden by
        show_settings, wrapping around and giving up after one full cycle."""
        # Wrap on every step so the labels stored under '0'..str(n-1) are
        # actually consulted when the search crosses either end.
        ind = (self.ind + step) % self.numneurons
        counter = 0
        while (self._skipped(self.curate.get(str(ind)))
               and counter != self.numneurons):
            ind = (ind + step) % self.numneurons
            counter += 1
        self.ind = ind

    # ----- image / figure updates ----------------------------------------

    def update_im(self):
        """Load self.im: the current thread's z-plane, or the MIP over z."""
        if self.showmip:
            self.im = np.max(self.tf.get_t(self.t), axis=0)
        else:
            self.im = self.tf.get_tbyf(self.t, int(self._position()[0]))

    def get_im_display(self):
        """Full image scaled so that [min, max] maps onto [0, 1]."""
        return (self.im - self.min) / (self.max - self.min)

    def get_subim_display(self):
        """Sub-image scaled so that [min, max] maps onto [0, 1]."""
        return (self.subim - self.min) / (self.max - self.min)

    def update_figures(self):
        """Redraw both image panels, markers and title for the current state."""
        self.subim, self.offset = subaxis(self.im, self._position(),
                                          self.window)
        self.img1.set_data(self.get_im_display())

        pts = self._context_points()
        if pts is not None:
            self.point1.set_offsets(np.array([pts[:, 2], pts[:, 1]]).T)
        pos = self._position()
        self.thispoint.set_offsets([pos[2], pos[1]])
        # plt.axis() acts on pyplot's *current* axes, which after restart()
        # is the last widget axes created, so target ax1 explicitly.
        self.ax1.axis('off')

        self.img2.set_data(self.get_subim_display())
        self.point2.set_offsets([
            self.window / 2 + self.offset[0], self.window / 2 + self.offset[1]
        ])
        self.title.set_text(self._title_text())
        plt.draw()

    def update_timeseries(self):
        """Replot the normalized time series of the current thread."""
        self.timeplot.set_ydata(self._normalized_timeseries())
        plt.draw()

    # ----- widget callbacks ----------------------------------------------

    def update_t(self, val):
        """Timepoint slider callback."""
        # Slider values are floats even with valstep=1; store an int so the
        # time index can be used to index per-timepoint data.
        self.t = int(round(val))
        self.update_im()
        self.update_figures()

    def update_mm(self, val):
        """R Min / R Max slider callback: change the display range."""
        self.min = self.sminr.val
        self.max = self.smaxr.val
        self.update_figures()

    def next(self, event):
        """'Next' button callback."""
        self.set_index_next()
        self._refresh_after_index_change()

    def prev(self, event):
        """'Previous' button callback."""
        self.set_index_prev()
        self._refresh_after_index_change()

    def log_curate(self):
        """Write the labels and the current index ('last') to self.path."""
        self.curate['last'] = self.ind
        with open(self.path, 'w') as fp:
            json.dump(self.curate, fp)

    def keep(self, event):
        """Keep/Trash check-button callback.

        With exactly one box checked, records 'keep' or 'trash' for the
        current thread. Otherwise every checked box is toggled off; note
        that CheckButtons.set_active re-fires this callback.
        """
        status = self.keep_button.get_status()
        if np.sum(status) != 1:
            for i, checked in enumerate(status):
                if checked:
                    self.keep_button.set_active(i)
        elif status[0]:
            self.curate[str(self.ind)] = 'keep'
        elif status[1]:
            self.curate[str(self.ind)] = 'trash'

    def update_buttons(self):
        """Make the Keep/Trash boxes reflect the current thread's label."""
        curr = self.keep_button.get_status()
        label = self.curate.get(str(self.ind))
        future = [False] * len(curr)
        if label == 'keep':
            future[0] = True
        elif label == 'trash':
            future[1] = True

        for i, (now, wanted) in enumerate(zip(curr, future)):
            if now != wanted:
                self.keep_button.set_active(i)

    def show(self, label):
        """Radio callback choosing which threads next/prev visit."""
        d = {'All': 0, 'Unlabelled': 1, 'Kept': 2, 'Trashed': 3}
        self.show_settings = d[label]

    def set_index_prev(self):
        """Step back to the previous thread allowed by show_settings."""
        self._step_index(-1)

    def set_index_next(self):
        """Step forward to the next thread allowed by show_settings."""
        self._step_index(1)

    def update_curate(self):
        """Mark the current thread 'seen' unless it already has a label."""
        if self.curate.get(str(self.ind)) not in ['keep', 'seen', 'trash']:
            self.curate[str(self.ind)] = 'seen'

    def update_pointstate(self, label):
        """Radio callback choosing which ROIs are drawn on the left panel."""
        d = {
            'Single': 0,
            'Same Z': 1,
            'All': 2,
        }
        self.pointstate = d[label]
        self.update_point1()
        self.update_figures()

    def update_point1(self):
        """Rebuild ax1 (image plus markers) after a pointstate change."""
        self.ax1.clear()
        self.img1 = self.ax1.imshow(self.get_im_display(),
                                    cmap='gray',
                                    vmin=0,
                                    vmax=1)
        self._draw_ax1_points()
        # Target ax1 explicitly: plt.axis() would act on pyplot's current
        # axes, which is not ax1 once the widgets have been created.
        self.ax1.axis('off')

    def update_mipstate(self, label):
        """Radio callback switching between single z-plane and MIP."""
        d = {
            'Single Z': 0,
            'MIP': 1,
        }
        self.showmip = d[label]

        self.update_im()
        self.update_figures()
예제 #21
0
class ScanData:
    '''
    Collect raw data for energy scan experiments and provides methods to
    average them.

    Attributes
    ----------
    label : list (str)
        Labels for graphs.

    idx : list (str)
        Scan indexes

    raw_imp : pandas DataFrame
        Collect imported raw data.

    energy : array
        common energy scale for average of scans.

    lines : list (2d line objects)
        Collect lines object of raw scan for plotting.

    blines : list (2d line objects)
        Dark-grey copies of the raw scans, shown together with the average.

    avg_ln : 2d line objects
        Lines object of average of scans.

    plab : list (str)
        Collect the list of labels plotted in graphs for choosing scans.

    checkbx : CheckButtons obj
        Widget for choosing plots.

    aver : array
        Data average of selected scans from raw_imp.
        If ScanData is a reference one, aver contain normalized data by
        reference.

    dtype : str
        Identifies the data collected, used for graph labelling:
        sigma+, sigma- for XMCD
        CR, CL for XNCD
        H-, H+ for XNXD
        LH, LV fot XNLD

    chsn_scns : list (str)
        Labels of chosen scans for the analysis.

    pe_av : float
        Value of spectra at pre-edge energy. It is obtained from
        averaging data in an defined energy range centered at pre-edge
        energy.

    pe_av_int : float
        Baseline value at edge energy: from linear interpolation between
        pre-edge and post-edge energies, or from the ArpLS baseline bsl.

    bsl : Univariate spline object
        spline interpolation of ArpLS baseline

    norm : array
        Averaged data normalized by value at pre-edge energy.

    norm_int : array
        Averaged data normalized by interpolated pre-edge value.

    ej : float
        Edge-jump value.

    ej_norm : float
        edge-jump value normalized by value at pre-edge energy.

    ej_int : float
        edge-jump value computed with interpolated pre-edge value.

    ej_norm_int : float
        edge-jump computed and normalized by interpolated pre-edge
        value.

    Methods
    -------
    man_aver_e_scans(guiobj, enrg)
        Manage the choice of scans to be averaged and return the average
        of selected scans.

    aver_e_scans()
        Perform the average of the chosen data scans.

    check_but(label)
        When check buttons are checked switch the visibility of
        corresponding line.

    averbut(event)
        When average button is pressed calls aver_e_scans to compute
        average on selected scans.

    reset(event)
        Reset graph to starting conditions.

    finish_but(event)
        Close the figure on pressing the button Finish.

    edge_norm(guiobj, enrg, e_edge, e_pe, e_poste, pe_rng)
        Normalize energy scan data by value at pre-edge energy and
        compute edge-jump.
    '''
    def __init__(self):
        '''
        Initialize attributes label, idx, and raw_imp.
        '''
        self.label = []
        self.idx = []
        self.raw_imp = pd.DataFrame()

    def man_aver_e_scans(self, guiobj, enrg):
        '''
        Manage the choice of scans to be averaged and return the average
        of selected scans.

        Parameters
        ----------
        guiobj : GUI object
            Provides GUI dialogs.

        enrg : array
            Energy values at which average is calculated.

        Return
        ------
        Set class attributes:
        aver : array
            Average values of the chosen scans.

        chsn_scns : list
            Labels of chosen scans for the analysis (for log purpose).
        '''
        self.energy = enrg
        self.chsn_scns = []
        self.aver = 0

        if guiobj.interactive:  # Interactive choose of scans
            fig, ax = plt.subplots(figsize=(10, 6))
            fig.subplots_adjust(right=0.75)

            ax.set_xlabel('E (eV)')
            ax.set_ylabel(self.dtype)

            if guiobj.infile_ref:
                fig.suptitle('Choose reference sample scans')
            else:
                fig.suptitle('Choose sample scans')
            # Initialize list which will contains line obj of scans
            # lines contain colored lines for choose
            # blines contain dark lines to be showed with average
            self.lines = []
            self.blines = []
            # Populate list with line objs
            for i, lbl in zip(self.idx, self.label):
                e_col = 'E' + i
                # Show lines and not blines
                self.lines.append(
                    ax.plot(self.raw_imp[e_col], self.raw_imp[i],
                            label=lbl)[0])
                self.blines.append(
                    ax.plot(self.raw_imp[e_col],
                            self.raw_imp[i],
                            color='dimgrey',
                            visible=False)[0])
            # Initialize chsn_scs and average line with all scans and
            # set it invisible
            self.chsn_scns = [
                line.get_label() for line in self.lines if line.get_visible()
            ]
            self.aver_e_scans()
            self.avg_ln, = ax.plot(self.energy,
                                   self.aver,
                                   color='red',
                                   lw=2,
                                   visible=False)
            # Create box for checkbutton
            chax = fig.add_axes([0.755, 0.32, 0.24, 0.55], facecolor='0.95')
            self.plab = [str(line.get_label()) for line in self.lines]
            visibility = [line.get_visible() for line in self.lines]
            self.checkbx = CheckButtons(chax, self.plab, visibility)
            # Customizations of checkbuttons
            # NOTE(review): CheckButtons.rectangles / .lines were removed in
            # recent matplotlib releases; this block needs an older version.
            rxy = []
            bxh = 0.05
            for r in self.checkbx.rectangles:
                r.set(height=bxh)
                r.set(width=bxh)
                rxy.append(r.get_xy())
            for i in range(len(rxy)):
                self.checkbx.lines[i][0].set_xdata(
                    [rxy[i][0], rxy[i][0] + bxh])
                self.checkbx.lines[i][0].set_ydata(
                    [rxy[i][1], rxy[i][1] + bxh])
                self.checkbx.lines[i][1].set_xdata(
                    [rxy[i][0] + bxh, rxy[i][0]])
                self.checkbx.lines[i][1].set_ydata(
                    [rxy[i][1], rxy[i][1] + bxh])

            for l in self.checkbx.labels:
                l.set(fontsize='medium')
                l.set_verticalalignment('center')
                l.set_horizontalalignment('left')

            self.checkbx.on_clicked(self.check_but)

            # Create box for average reset and finish buttons
            averbox = fig.add_axes([0.77, 0.2, 0.08, 0.08])
            bnaver = Button(averbox, 'Average')
            bnaver.on_clicked(self.averbut)
            rstbox = fig.add_axes([0.89, 0.2, 0.08, 0.08])
            bnrst = Button(rstbox, 'Reset')
            bnrst.on_clicked(self.reset)
            finbox = fig.add_axes([0.82, 0.07, 0.12, 0.08])
            bnfinish = Button(finbox, 'Finish')
            bnfinish.on_clicked(self.finish_but)

            ax.legend()
            plt.show()

            # Make aver match the final selection. chsn_scns always holds
            # the latest choice (check_but, averbut and reset all update it)
            # whereas aver is only refreshed by averbut; without this step,
            # deselecting scans and closing without pressing Average kept
            # the average of all scans while logging the subset.
            if not self.chsn_scns:
                self.chsn_scns = [
                    line.get_label() for line in self.lines
                    if line.get_visible()
                ]
            self.aver_e_scans()
        else:
            # Not-interactive mode: all scans except 'Dummy Scans' are
            # evaluated. Store labels, not scan indexes: aver_e_scans
            # matches chsn_scns against self.label.
            self.chsn_scns = [lbl for lbl in self.label if 'Dummy' not in lbl]
            self.aver_e_scans()

    def check_but(self, label):
        '''
        When check buttons are checked switch the visibility of
        corresponding line.
        Also update self.chsn_scns with labels of visible scans.
        '''
        index = self.plab.index(label)
        self.lines[index].set_visible(not self.lines[index].get_visible())
        # Update chsn_scns
        self.chsn_scns = [
            line.get_label() for line in self.lines if line.get_visible()
        ]
        plt.draw()

    def averbut(self, event):
        '''
        When average button is pressed calls aver_e_scans to compute
        average on selected scans.
        Update self.chsn_scns and self.aver.
        '''
        # Initialize list of chosed scans
        self.chsn_scns = []
        # Swap each chosen (visible) coloured line for its dark copy and
        # record it as chosen
        for line, bline in zip(self.lines, self.blines):
            if line.get_visible():
                line.set(visible=False)
                self.chsn_scns.append(line.get_label())
                bline.set(visible=True)

        self.aver_e_scans()

        # Update average line and make it visible
        self.avg_ln.set_ydata(self.aver)
        self.avg_ln.set(visible=True)

        plt.draw()

    def aver_e_scans(self):
        '''
        Perform the average of the chosen data scans on the common energy
        scale.

        Uses self.energy as energy scale and self.chsn_scns (labels of the
        chosen scans) to select the scans of self.raw_imp to be averaged.

        Return
        ------
        Set class attribute:
        aver : array
            Average of the chosen scans evaluated on self.energy.

        Notes
        -----
        The first data row of every scan is skipped. Each chosen scan is
        interpolated with a linear spline (k=1 and s=0 in
        itp.UnivariateSpline) and evaluated along the common energy scale.
        The interpolated data are eventually averaged.
        '''
        chosen = set(self.chsn_scns)
        intrp = []

        for i, lbl in zip(self.idx, self.label):
            if lbl in chosen:
                # chosen data
                x = self.raw_imp['E' + i][1:]
                y = self.raw_imp[i][1:]

                # Compute linear spline interpolation
                y_int = itp.UnivariateSpline(x, y, k=1, s=0)
                # Evaluate interpolation of scan data on energy scale
                # and append to previous interpolations
                intrp.append(y_int(self.energy))

        # Average all inteprolated scans
        self.aver = np.average(intrp, axis=0)

    def reset(self, event):
        '''
        Reset graph for choosing scans to starting conditions.
        Show all scans, set checked all buttons and select all scans.
        '''
        # Clear graph
        self.avg_ln.set(visible=False)

        for i, stat in enumerate(self.checkbx.get_status()):
            if not stat:
                self.checkbx.set_active(i)
        # Show all spectra
        for line, bline in zip(self.lines, self.blines):
            line.set(visible=True)
            bline.set(visible=False)
        # Every scan is shown again, hence every scan is selected; the
        # check_but callbacks fired above may have left a partial list.
        self.chsn_scns = [line.get_label() for line in self.lines]

        plt.draw()

    def finish_but(self, event):
        '''
        Close the figure on pressing the button Finish.
        '''
        plt.close()

    def edge_norm(self, guiobj, enrg, e_edge, e_pe, e_poste, pe_rng):
        '''
        Normalize energy scan data by the value at pre-edge energy.
        Also compute the  energy jump defined as the difference between
        the value at the edge and pre-edge energies respectively.

        This computations are implemented also considering baseline.
        If linear baseline is selected edge jump is computed considering
        the the height of data at edge energy from the stright line
        passing from pre-edge and post edge data.
        If asymmetrically reweighted penalized least squares baseline is
        selected the edge jump is calculated considering as the distance
        at edge energy between the averaged spectrum and baseline.

        Parameters
        ----------
        guiobj: GUI object
            Provides GUI dialogs.

        enrg : array
            Energy values of scan.

        e_edge : float
            Edge energy value.

        e_pe : float
            Pre-edge energy value.

        e_poste : float
            Post-edge energy value (used by the linear baseline).

        pe_rng : int
            Number of points constituting the semi-width of energy range
            centered at e_pe.

        Returns
        -------
        Set class attributes:
        pe_av : float
            value at pre-edge energy.

        pe_av_int : float
            Baseline value at edge energy.

        norm : array
            self.aver scan normalized by value at pre-edge energy.

        norm_int : array
            Averaged data normalized by interpolated pre-edge value.

        ej : float
            edge-jump value.

        ej_norm : float
            edge-jump value normalized by value at pre-edge energy.

        ej_int : float
            edge-jump value computed with interpolated pre-edge value.

        ej_norm_int : float
            edge-jump computed and normalized by interpolated pre-edge
            value.

        Notes
        -----
        To reduce noise effects the value of scan at pre-edge energy is
        obtained computing an average over the points within pe_rng of the
        point nearest to e_pe (truncated at the start of the data).
        The value of scan at edge energy is obtained by cubic spline
        interpolation of data (itp.UnivariateSpline with k=3 and s=0).
        '''
        # Index of the nearest element to pre-edge energy
        pe_idx = np.argmin(np.abs(enrg - e_pe))
        # Left and right extremes of energy range for pre-edge average.
        # Clamp the left one: a negative slice start counts from the end of
        # the array and would select a wrong (often empty) range.
        lpe_idx = max(int(pe_idx - pe_rng), 0)
        rpe_idx = int(pe_idx + pe_rng + 1)

        # Average of values for computation of pre-edge
        self.pe_av = np.average(self.aver[lpe_idx:rpe_idx])

        # Cubic spline interpolation of energy scan
        y_int = itp.UnivariateSpline(enrg, self.aver, k=3, s=0)
        # value at edge energy from interpolation
        y_edg = y_int(e_edge)

        # Edge-jumps computations - no baseline
        self.ej = y_edg - self.pe_av
        self.ej_norm = self.ej / self.pe_av
        # Normalization by pre-edge value
        self.norm = self.aver / self.pe_av

        # Edge-jumps computations - consider baseline
        if guiobj.bsl_int:
            # ArpLS baseline evaluated at edge energy
            self.pe_av_int = self.bsl(e_edge)
        else:
            # Linear baseline through pre- and post-edge points, evaluated
            # at edge energy
            x = [e_pe, e_poste]
            y = [y_int(e_pe), y_int(e_poste)]
            self.pe_av_int = lin_interpolate(x, y, e_edge)

        # Normalization by interpolated pre-edge value
        self.norm_int = self.aver / self.pe_av_int

        self.ej_int = y_edg - self.pe_av_int
        self.ej_norm_int = self.ej_int / self.pe_av_int
예제 #22
0
class InoHeartbeatVerifier(object):
    """
    GUI for editing heartbeat signal labels for INO case.

    One composite beat is shown at a time: the ECG trace with its Q marker,
    the seismo trace with its QM marker and the phono trace with its QM
    marker.  To move a marker, select its trace under "Snap on to:", press
    the mouse within 40 samples of the marker inside the plot and release at
    the new position; the truncated sample index is written back into
    ``composite_peaks.<Q|QM_seis|QM_phono>.data[index]``.

    Parameters
    ----------
    composite_peaks : object
        Must provide ``composites`` (a sequence whose items unpack to
        ``(time, signal, seis, phono)``), ``Q``, ``QM_seis`` and ``QM_phono``
        (each with an indexable ``data`` holding one sample index per
        composite) and a ``save(path)`` method.
    index : int, optional
        Composite displayed first.
    folder_name, dosage, file_name, interval_number : optional
        Metadata shown in the figure; folder, file and dosage also form the
        save file name.
    """
    def __init__(self,
                 composite_peaks,
                 index=0,
                 folder_name="",
                 dosage="",
                 file_name="",
                 interval_number=""):

        super(InoHeartbeatVerifier, self).__init__()
        # Save Composite Peaks
        self.composite_peaks = composite_peaks
        self.index = index

        self.folder_name = folder_name
        self.dosage = dosage
        self.file_name = file_name
        self.interval_number = interval_number
        # Tag of the marker grabbed by the last mouse press ("Q",
        # "QM Seismo" or "QM Phono"), or None when nothing is grabbed.
        self.update_point = None

        # Plot signals (plot_signals ends with plt.show())
        self.plot_signals()

    def plot_signals(self):
        """Build the figure, markers, widgets and event hooks, then show it."""
        # Create figure
        self.fig, self.ax = plt.subplots()

        # Load composite signals
        self.time, self.signal, self.seis, self.phono = self.composite_peaks.composites[
            self.index]

        # Plot ECG, Phono and Seismo
        self.signal_line, = self.ax.plot(self.signal,
                                         linewidth=1,
                                         c="b",
                                         label="ECG")
        self.seis_line, = self.ax.plot(self.seis,
                                       '--',
                                       linewidth=0.5,
                                       c='r',
                                       label="Seis")
        self.phono_line, = self.ax.plot(self.phono,
                                        '--',
                                        linewidth=0.5,
                                        c='g',
                                        label="Phono")

        self.ax.set_xlim(0, len(self.signal))

        sig_min = min(self.signal)
        sig_max = max(self.signal)

        # 10 % head-room above and below the ECG range
        self.ax.set_ylim(sig_min - 0.1 * (sig_max - sig_min),
                         sig_max + 0.1 * (sig_max - sig_min))
        plt.legend(loc='upper right')

        cp = self.composite_peaks

        # Q Peak
        q_idx = cp.Q.data[self.index]
        self.q_point = self.ax.scatter(q_idx, self.signal[q_idx], c='#ff7f0e')
        self.q_text = self.ax.text(q_idx,
                                   self.signal[q_idx] + 0.2,
                                   "Q",
                                   fontsize=9,
                                   horizontalalignment='center')

        # QM Seismo
        qm_seis_idx = cp.QM_seis.data[self.index]
        self.qm_seis_point = self.ax.scatter(qm_seis_idx,
                                             self.seis[qm_seis_idx],
                                             c='#d62728')
        self.qm_seis_text = self.ax.text(qm_seis_idx,
                                         self.seis[qm_seis_idx] + 0.2,
                                         "QM Seis",
                                         fontsize=9,
                                         horizontalalignment='center')

        # QM Phono
        qm_phono_idx = cp.QM_phono.data[self.index]
        self.qm_phono_point = self.ax.scatter(qm_phono_idx,
                                              self.phono[qm_phono_idx],
                                              c='#8c564b')
        self.qm_phono_text = self.ax.text(qm_phono_idx,
                                          self.phono[qm_phono_idx] + 0.2,
                                          "QM Phono",
                                          fontsize=9,
                                          horizontalalignment='center')

        # Initalize axes and data points used by the cross-hair
        self.x = range(len(self.signal))
        self.y = self.signal

        # Cross hairs
        self.lx = self.ax.axhline(color='k', linewidth=0.2)  # the horiz line
        self.ly = self.ax.axvline(color='k', linewidth=0.2)  # the vert line

        # Add data (axes-relative coordinates)
        left_shift = 0.45
        start = 0.96
        space = 0.04
        self.ax.text(0.01,
                     start,
                     transform=self.ax.transAxes,
                     s="Folder: " + self.folder_name,
                     fontsize=12,
                     horizontalalignment='left')
        self.ax.text(0.01,
                     start - space,
                     transform=self.ax.transAxes,
                     s="Dosage: " + str(self.dosage),
                     fontsize=12,
                     horizontalalignment='left')
        self.ax.text(0.01,
                     start - 2 * space,
                     transform=self.ax.transAxes,
                     s="File: " + self.file_name,
                     fontsize=12,
                     horizontalalignment='left')
        self.ax.text(0.01,
                     start - 3 * space,
                     transform=self.ax.transAxes,
                     s="File #: " + str(self.interval_number),
                     fontsize=12,
                     horizontalalignment='left')
        self.i_text = self.ax.text(0.60 - left_shift,
                                   1.1 - space,
                                   transform=self.ax.transAxes,
                                   s="Composite: " + str(self.index + 1) +
                                   "/" +
                                   str(len(self.composite_peaks.composites)),
                                   fontsize=12,
                                   horizontalalignment='left')

        # Add Intervals (figure-relative coordinates)
        self.qm_text = self.ax.text(0.575,
                                    0.91,
                                    horizontalalignment='center',
                                    transform=self.fig.transFigure,
                                    s=self._interval_text())

        # Add index buttons
        ax_prev = plt.axes([0.575 - left_shift, 0.9, 0.1, 0.075])
        self.bprev = Button(ax_prev, 'Previous')
        self.bprev.on_clicked(self.prev)

        ax_next = plt.axes([0.8 - left_shift, 0.9, 0.1, 0.075])
        self.b_next = Button(ax_next, 'Next')
        self.b_next.on_clicked(self.next)

        self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_move)

        # Add Save Button
        ax_save = plt.axes([0.8, 0.9, 0.1, 0.075])
        self.b_save = Button(ax_save, 'Save')
        self.b_save.on_clicked(self.save)

        # Add Line buttons
        self.ax.text(-0.13,
                     0.97,
                     transform=self.ax.transAxes,
                     s="Snap on to:",
                     fontsize=12,
                     horizontalalignment='left')
        # left, bottom, width, height
        ax_switch_signals = plt.axes([0.02, 0.7, 0.07, 0.15])
        self.b_switch_signals = RadioButtons(ax_switch_signals,
                                             ('ECG', 'Seismo', 'Phono'))
        # ``RadioButtons.circles`` was deprecated in Matplotlib 3.7 and later
        # removed; only shrink the markers when the attribute still exists.
        for c in getattr(self.b_switch_signals, 'circles', ()):
            c.set_radius(0.05)

        self.b_switch_signals.on_clicked(self.switch_signal)

        self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        self.fig.canvas.mpl_connect('button_release_event', self.off_click)

        # Add Line hide buttons
        self.ax.text(1.015,
                     0.97,
                     transform=self.ax.transAxes,
                     s="Hide Signal:",
                     fontsize=12,
                     horizontalalignment='left')
        ax_hide_signals = plt.axes([0.91, 0.7, 0.07, 0.15])
        self.b_hide_signals = CheckButtons(ax_hide_signals,
                                           ('ECG', 'Seismo', 'Phono'))

        self.b_hide_signals.on_clicked(self.switch_signal)

        # Add Sliders (A = amplitude factor, H = vertical offset)
        self.signal_amp_slider = Slider(plt.axes([0.91, 0.15, 0.01, 0.475]),
                                        label="ECG\nA",
                                        valmin=0.01,
                                        valmax=10,
                                        valinit=1,
                                        orientation='vertical')
        self.signal_amp_slider.on_changed(self.switch_signal)
        self.signal_amp_slider.valtext.set_visible(False)

        self.seis_height_slider = Slider(plt.axes([0.93, 0.15, 0.01, 0.475]),
                                         label="   Seis\nH",
                                         valmin=1.5 * min(self.signal),
                                         valmax=1.5 * max(self.signal),
                                         valinit=0,
                                         orientation='vertical')
        self.seis_height_slider.on_changed(self.switch_signal)
        self.seis_height_slider.valtext.set_visible(False)

        self.seis_amp_slider = Slider(plt.axes([0.94, 0.15, 0.01, 0.475]),
                                      label="\nA",
                                      valmin=0.01,
                                      valmax=10,
                                      valinit=1,
                                      orientation='vertical')
        self.seis_amp_slider.on_changed(self.switch_signal)
        self.seis_amp_slider.valtext.set_visible(False)

        self.phono_height_slider = Slider(plt.axes([0.96, 0.15, 0.01, 0.475]),
                                          label="    Phono\nH",
                                          valmin=1.5 * min(self.signal),
                                          valmax=1.5 * max(self.signal),
                                          valinit=0,
                                          orientation='vertical')
        self.phono_height_slider.on_changed(self.switch_signal)
        self.phono_height_slider.valtext.set_visible(False)

        self.phono_amp_slider = Slider(plt.axes([0.97, 0.15, 0.01, 0.475]),
                                       label="A",
                                       valmin=.01,
                                       valmax=10,
                                       valinit=1,
                                       orientation='vertical')
        self.phono_amp_slider.on_changed(self.switch_signal)
        self.phono_amp_slider.valtext.set_visible(False)

        # Maximize frame
        mng = plt.get_current_fig_manager()
        mng.full_screen_toggle()

        plt.show()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _markers(self):
        """Return ``(tag, trace label, peak, point, text, colour)`` per marker."""
        cp = self.composite_peaks
        return (
            ("Q", 'ECG', cp.Q, self.q_point, self.q_text, '#ff7f0e'),
            ("QM Seismo", 'Seismo', cp.QM_seis, self.qm_seis_point,
             self.qm_seis_text, '#d62728'),
            ("QM Phono", 'Phono', cp.QM_phono, self.qm_phono_point,
             self.qm_phono_text, '#8c564b'),
        )

    def _scaled_trace(self, name):
        """Return trace ``name`` ('ECG', 'Seismo', 'Phono') as currently displayed.

        ECG is scaled by its amplitude slider only; Seismo and Phono are
        scaled by their amplitude slider and shifted by their height slider.
        """
        if name == 'ECG':
            return self.signal_amp_slider.val * self.signal
        if name == 'Seismo':
            return (self.seis_amp_slider.val * self.seis +
                    self.seis_height_slider.val)
        return (self.phono_amp_slider.val * self.phono +
                self.phono_height_slider.val)

    @staticmethod
    def _place_peak(point, text, trace, idx):
        """Move a marker and its label (0.2 above it) to sample ``idx`` of ``trace``."""
        y = trace[idx]
        point.set_offsets((idx, y))
        text.set_position((idx, y + 0.2))

    @staticmethod
    def _format_rate(dt):
        """Format ``1 / dt`` rounded to 2 decimals; a zero interval gives 'inf'."""
        if dt == 0:
            return 'inf'
        return str(round(1 / dt, 2))

    def _interval_text(self):
        """Return the "1/(E-M)ino" label: 1/(t_QM - t_Q) for seismo and phono."""
        cp = self.composite_peaks
        t_q = self.time[cp.Q.data[self.index]]
        qm_seis = self._format_rate(self.time[cp.QM_seis.data[self.index]] - t_q)
        qm_phono = self._format_rate(
            self.time[cp.QM_phono.data[self.index]] - t_q)
        return ("1/(E-M)ino\nSeis: " + qm_seis + " Hz" +
                "\nPhono: " + qm_phono + " Hz")

    def _sample_index(self, xdata):
        """Truncate an x data coordinate to a valid sample index of the composite."""
        return min(max(int(xdata), 0), len(self.signal) - 1)

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------
    def switch_signal(self, label):
        """
        Redraw everything that depends on widget state.

        Shared callback of the radio buttons, check buttons and sliders;
        ``label`` is ignored because every value is read back from the
        widgets themselves.
        """
        xs = range(len(self.signal))
        traces = {name: self._scaled_trace(name)
                  for name in ('ECG', 'Seismo', 'Phono')}

        # Update lines (all three share the ECG sample axis)
        self.signal_line.set_data(xs, traces['ECG'])
        self.seis_line.set_data(xs, traces['Seismo'])
        self.phono_line.set_data(xs, traces['Phono'])

        # Markers follow the scaled / shifted traces
        for _, trace_name, peak, point, text, _ in self._markers():
            self._place_peak(point, text, traces[trace_name],
                             peak.data[self.index])

        # Update Data
        self.qm_text.set_text(self._interval_text())

        # Cross-hair snaps to the trace selected under "Snap on to:"
        self.x = xs
        selected = self.b_switch_signals.value_selected
        if selected in traces:
            self.y = traces[selected]

        # Hide / show traces; toggling visibility keeps each line's own width
        hidden = self.b_hide_signals.get_status()
        lines = (self.signal_line, self.seis_line, self.phono_line)
        for line, is_hidden in zip(lines, hidden):
            line.set_visible(not is_hidden)

        self.fig.canvas.draw()

    def off_click(self, event):
        """
        Mouse-release handler: drop the grabbed marker at the release position.

        The cross-hair returns to its idle style.  The marker only moves when
        the release happens inside the plot axes; releasing over a widget
        (whose ``xdata`` is in that widget's own coordinates) or outside the
        figure leaves the labels untouched.
        """
        self.lx.set_color('k')
        self.ly.set_color('k')

        self.lx.set_linewidth(0.2)
        self.ly.set_linewidth(0.2)

        if (self.update_point is not None and event.inaxes is self.ax
                and event.xdata is not None):
            idx = self._sample_index(event.xdata)
            for tag, trace_name, peak, point, text, _ in self._markers():
                if tag == self.update_point:
                    self._place_peak(point, text,
                                     self._scaled_trace(trace_name), idx)
                    peak.data[self.index] = idx

            # Update Data
            self.qm_text.set_text(self._interval_text())

        # The drag is over either way
        self.update_point = None

        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

    def on_click(self, event):
        """
        Mouse-press handler: grab the marker of the selected trace if near.

        A marker is grabbed when the press is inside the plot axes and within
        ``threshold`` samples of it; the cross-hair is then highlighted in
        the marker's colour until release.
        """
        threshold = 40
        self.update_point = None

        # Presses on widgets report xdata in the widget's own axes
        if event.inaxes is not self.ax or event.xdata is None:
            return

        current_signal = self.b_switch_signals.value_selected
        for tag, trace_name, peak, _, _, colour in self._markers():
            if (trace_name == current_signal and
                    abs(peak.data[self.index] - event.xdata) < threshold):
                self.lx.set_color(colour)
                self.ly.set_color(colour)

                self.lx.set_linewidth(1)
                self.ly.set_linewidth(1)

                self.fig.canvas.draw()

                self.update_point = tag

    def mouse_move(self, event):
        """Snap the cross-hair to the nearest sample of the selected trace."""
        # Ignore motion outside the plot axes (including widget axes)
        if event.inaxes is not self.ax:
            return

        # Lock to closest x coordinate on signal
        indx = min(np.searchsorted(self.x, event.xdata), len(self.x) - 1)
        x = self.x[indx]
        y = self.y[indx]

        # Update the crosshairs; sequences (not scalars) because scalar
        # arguments to set_xdata/set_ydata are deprecated since Matplotlib 3.7
        self.lx.set_ydata([y, y])
        self.ly.set_xdata([x, x])

        # Draw everything
        self.ax.figure.canvas.draw()

    def next(self, event):
        """'Next' button: show the following composite, wrapping to the first."""
        self.index = (self.index + 1) % len(self.composite_peaks.composites)
        self.update_plot()

    def prev(self, event):
        """'Previous' button: show the preceding composite, wrapping to the last."""
        self.index = (self.index - 1) % len(self.composite_peaks.composites)
        self.update_plot()

    def save(self, event):
        """'Save' button: save the (edited) composite peaks under data/Derived/composites/."""
        # Get File Name
        save_file_name = "composites_" + self.folder_name + "_" + self.file_name + "_d" + str(
            self.dosage)

        # Save
        self.composite_peaks.save("data/Derived/composites/" + save_file_name)
        print("Saved")

    def update_plot(self):
        """Load composite ``self.index`` and redraw it with the current widget state."""
        # Update index
        self.i_text.set_text("Composite: " + str(self.index + 1) + "/" +
                             str(len(self.composite_peaks.composites)))

        # Load composite signals
        self.time, self.signal, self.seis, self.phono = self.composite_peaks.composites[
            self.index]

        self.ax.set_xlim(0, len(self.signal))

        # Redraws traces, markers, intervals and the cross-hair source with
        # the current slider scaling (and draws the canvas)
        self.switch_signal(None)
예제 #23
0
class CrossMark:
    """
    Interactive tool for marking corresponding points ("crosses") on a
    reference image and a second image, and fitting a colour matrix from
    them.

    The left subplot shows the reference image (``imgref_name``), the right
    one the image to correct (``img_name``).  Each cross exists once per
    image: ``cross_list[0]`` holds the reference positions and
    ``cross_list[1]`` the matching positions on the other image.

    Buttons: Add (append a new pair at (-1, -1)), + / - (select the next /
    previous pair), Rot (rotate the right image and its crosses by 90
    degrees), Del (delete the selected pair), Save (fit the colour matrix,
    pickle it to ``color_mat.pkl`` and show the result).  A right click
    inside a subplot moves the selected cross of that image to the click.
    """

    def __init__(self, img_name, imgref_name):
        # cv2.imread returns BGR; reversing the channel axis gives RGB for imshow
        self.img = cv2.imread(img_name)[:, :, ::-1]
        self.imgr = cv2.imread(imgref_name)[:, :, ::-1]

        fig = plt.figure('coross mark')
        # sfigs[0]: reference image, sfigs[1]: image to be corrected
        self.sfigs = [plt.subplot(1, 2, 1)]
        plt.imshow(self.imgr)
        self.sfigs.append(plt.subplot(1, 2, 2))
        plt.imshow(self.img)

        # Matplotlib widgets only stay responsive while the caller holds a
        # reference to them, so every button is kept on ``self``.
        self.buttons = {}
        for left, label, callback in ((0.8, 'Save', self.save),
                                      (0.7, 'Add', self.add),
                                      (0.6, 'Del', self.delete),
                                      (0.5, 'Rot', self.rotate_im),
                                      (0.4, '+', self.plus),
                                      (0.3, '-', self.minus)):
            button = Button(plt.axes([left, 0.05, 0.1, 0.075]), label)
            button.on_clicked(callback)
            self.buttons[label] = button

        axchk = plt.axes([0.2, 0.90, 0.2, 0.1])
        self.chk_colors = CheckButtons(axchk, ('Show_pts', ), actives=[False])
        self.chk_colors.on_clicked(self.redraw)

        self.cid = fig.canvas.mpl_connect('button_press_event', self.onclick)

        self.objects_hdls = None    # artists drawn by draw_objs, removed on redraw
        self.cross_list = [[], []]  # [crosses on reference, crosses on image]
        self.selected_obj_ind = -1  # -1 means no pair is selected

        self.draw_objs()
        plt.show()

    def _step_selection(self, step):
        """Move the selected pair by ``step``, clamped to [-1, len - 1]; return it."""
        last = len(self.cross_list[0]) - 1
        self.selected_obj_ind = max(-1, min(self.selected_obj_ind + step, last))
        return self.selected_obj_ind

    def plus(self, event):
        """'+' button: select the next cross pair (stops at the last one)."""
        self._step_selection(1)
        print('self.selected_obj_ind', self.selected_obj_ind)
        self.draw_objs()
        plt.draw()

    def minus(self, event):
        """'-' button: select the previous cross pair (down to -1 = none)."""
        self._step_selection(-1)
        print('self.selected_obj_ind', self.selected_obj_ind)
        self.draw_objs()
        plt.draw()

    def rotate_im(self, event):
        """'Rot' button: rotate the right image 90 degrees counter-clockwise.

        The crosses on that image are remapped so they stay on the same
        pixels: (x, y) -> (y, new_height - x - 1).
        """
        self.img = cv2.transpose(cv2.flip(self.img, 1))
        self.cross_list[1] = [(y, self.img.shape[0] - x - 1)
                              for x, y in self.cross_list[1]]

        self.sfigs[1].cla()
        self.sfigs[1].imshow(self.img)
        self.draw_objs()
        plt.draw()

    def redraw(self, event):
        """'Show_pts' check box callback: redraw the crosses."""
        self.draw_objs()
        plt.draw()

    def draw_objs(self, selected=-1):
        """Remove previously drawn crosses and draw the current ones.

        The selected pair is drawn in red, the others in blue.  With
        'Show_pts' ticked, the points from ``pick_colors_pts`` (and their
        positions converted onto the right image) are drawn in green.
        ``selected`` is not used.
        """
        if self.objects_hdls is not None:
            for h in self.objects_hdls:
                try:
                    h[0].remove()
                except ValueError:
                    # Artist already gone (e.g. its axes was cleared)
                    pass
        self.objects_hdls = []

        h = self.objects_hdls

        if self.chk_colors.get_status()[0]:
            pts_r = pick_colors_pts()
            pts = convert_points(np.array(self.cross_list[0]),
                                 np.array(self.cross_list[1]), pts_r)
            h.append(self.sfigs[0].plot(pts_r[:, 0], pts_r[:, 1], '+g'))
            h.append(self.sfigs[1].plot(pts[:, 0], pts[:, 1], '+g'))

        self.sfigs[0].set_title('{}/{}'.format(self.selected_obj_ind + 1,
                                               len(self.cross_list[0])))

        for find, subp in enumerate(self.sfigs):
            for i, cr in enumerate(self.cross_list[find]):
                h.append(
                    subp.plot(cr[0], cr[1],
                              '+r' if self.selected_obj_ind == i else '+b'))

    def get_closest(self, x, y, ax_ind):
        """Return the index of the cross on image ``ax_ind`` closest to (x, y).

        Distance is Manhattan; only crosses closer than 30 px count, and -1
        is returned when there is none.  Not called from within this class.
        """
        max_ind = -1
        max_val = 30  # minimal distance to accept click
        for ind, cr in enumerate(self.cross_list[ax_ind]):
            d = abs(cr[0] - x) + abs(cr[1] - y)
            if d < max_val:
                max_val = d
                max_ind = ind
        return max_ind

    def delete(self, event):
        """'Del' button: remove the selected pair from both images."""
        if self.selected_obj_ind != -1:
            for i in [0, 1]:
                self.cross_list[i].pop(self.selected_obj_ind)
            self.selected_obj_ind = -1
            self.draw_objs()
            plt.draw()

    def save(self, event):
        """'Save' button: fit the colour matrix, pickle it and show the result.

        Colours sampled at ``pick_colors_pts()`` on the reference and at the
        converted points on the image are related by a least-squares fit;
        its transpose is written to ``color_mat.pkl`` and applied to the image.
        """
        plt.figure()
        plt.subplot(1, 3, 1)
        plt.title('ref')
        plt.imshow(self.imgr)

        pts_r = pick_colors_pts()
        plt.plot(pts_r[:, 0], pts_r[:, 1], '+')

        plt.subplot(1, 3, 2)
        plt.title('im corrected')
        pts = convert_points(np.array(self.cross_list[0]),
                             np.array(self.cross_list[1]), pts_r)

        plt.plot(pts[:, 0], pts[:, 1], '+')
        col_r = pick_colors(self.imgr, pts_r)
        col_im = pick_colors(self.img, pts)
        CM = np.transpose(np.linalg.lstsq(col_im, col_r, rcond=None)[0])

        with open('color_mat.pkl', 'wb') as fd:
            pickle.dump(CM, fd)
        im1_corr = apply_mat(self.img, CM)
        plt.imshow(im1_corr)

        plt.subplot(1, 3, 3)
        plt.title('im1')
        plt.plot(pts[:, 0], pts[:, 1], '+')
        plt.imshow(self.img)
        plt.show()

    def add(self, event):
        """'Add' button: append a new pair, placed at (-1, -1) on both images."""
        for i in [0, 1]:
            self.cross_list[i].append((-1, -1))
        self.draw_objs()
        plt.draw()

    def onclick(self, event):
        """Right click in a subplot: move the selected cross of that image there."""
        ind = self.sfigs.index(
            event.inaxes) if event.inaxes in self.sfigs else -1
        x, y = event.xdata, event.ydata
        if event.button == 3 and ind > -1 and self.selected_obj_ind > -1:
            self.cross_list[ind][self.selected_obj_ind] = (x, y)
            self.draw_objs()
            plt.draw()
예제 #24
0
def plot_selected(event):
    """
    Check-button callback: plot and export the hysteresis datasets that are ticked.

    ``hyst_data`` is read as consecutive column pairs: pair ``i`` occupies
    columns ``2*i`` and ``2*i + 1`` and is included when check box ``i`` of
    ``chk_btn`` is ticked.  Each selected pair is plotted in a new figure
    (first column on x, shifted by ``i * H_step``; second column on y),
    labelled with the first column's name minus its first character, and
    the selection is written to ``hyst_out_cut.xlsx``.

    Reads module-level globals: ``chk_btn``, ``hyst_data``, ``H_step``,
    ``version_name`` and the axis-formatting settings passed to
    ``custom_axis_formater``.
    """
    chk_btn_status = chk_btn.get_status()
    print(chk_btn_status)
    print('Plotting selected datasets...')

    # construct DataFrame containing only selected (checkboxed) datasets;
    # ``//`` keeps the pair count an int on Python 3
    n_pairs = len(hyst_data.columns) // 2
    selected = [hyst_data.iloc[:, 2 * i:2 * i + 2]
                for i in range(n_pairs) if chk_btn_status[i]]
    hyst_data_cut = (pd.concat(selected, axis=1) if selected
                     else pd.DataFrame())

    hyst_labels_cut = hyst_data_cut.columns.values
    n_selected = len(hyst_data_cut.columns) // 2

    print(hyst_data_cut)
    print('A total of' + ' ' + str(n_selected) + ' datasets in DataFrame were selected,')
    print('generated a new DataFrame with column names:')
    print(hyst_labels_cut)

    # plot figure for the dataframe
    fig1, ax1 = plt.subplots(figsize=(9, 9))
    fig1.tight_layout(pad=4.0, w_pad=0.5, h_pad=0.5)
    plt.subplots_adjust(left=0.15, bottom=0.1, wspace=0.0, hspace=0.0)

    # set widget window title; FigureCanvas.set_window_title was removed in
    # Matplotlib 3.6, the figure manager carries the title instead
    window_title = 'hysteresis plotter' + ' ' + version_name
    manager = getattr(fig1.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(window_title)
    elif hasattr(fig1.canvas, 'set_window_title'):
        fig1.canvas.set_window_title(window_title)

    # plot version info in the description
    plt.figtext(0.90, 0.97, version_name, size=10)

    # define datapoint markers colors; at least 10 steps (as before) but
    # enough for every selected dataset so next() never runs dry
    n_colors = max(n_selected, 10)
    colors = iter(plt.cm.inferno(np.linspace(0.3, 0.8, n_colors)))
    mfcolors = iter(plt.cm.plasma(np.linspace(0.1, 1, n_colors)))

    # collect plot objects here
    hyst_plots_cut = []

    # plot each selected pair, offset horizontally by i * H_step
    for i in range(n_selected):
        j, k = 2 * i, 2 * i + 1

        # generate dataset label
        hyst_label_cut = hyst_labels_cut[j][1:]

        # get dataset to plot
        x = hyst_data_cut.iloc[:, j]
        y = hyst_data_cut.iloc[:, k]

        hyst_plots_cut.extend(
            plt.plot(x + i * H_step, y, 'o', color=next(colors),
                     mfc=next(mfcolors), markersize=6, label=hyst_label_cut,
                     visible=True))

    plt.legend(loc='upper left', frameon=True, fontsize=10, title='Anneal Time')

    # format plot
    custom_axis_formater(plot_title, plot_x_label, plot_y_label, xmin, xmax, ymin, ymax, xprec, yprec)
    plt.show()

    # export dataframe to xlsx
    hyst_data_cut.to_excel("hyst_out_cut.xlsx")
예제 #25
0
class CellAutomatonGameGUI(object):
    """
    GUI to control:
    1. N = # of creatures
    2. P = infection probability
    3. K = quarantine parameter
    4. L = generation (iteration) from which the quarantine applies
    5. animation's speed
    6. pause | play | reset buttons

    Usage: call set_all() to build the figure and its widgets, then start().
    """
    def __init__(self, game):
        """
        game: the simulation object; this class calls its get_size() and
              build(N, P, K, L) methods.

        The initial N, P, K and L values are read from module-level globals
        of the same names.
        """
        self.game = game
        self.N = N
        self.P = P
        self.K = K
        self.L = L

        # future members: fig/ax are created by set_all(),
        # animation is created by start()
        self.fig, self.ax = None, None
        self.animation = None

    def __set_widgets(self):
        """ set all sliders and buttons that are in the gui's responsibility """
        hcolor = None
        axcolor = 'white'
        slider_x_loc = 0.25
        slider_y_loc = 0.2
        slider_width = 0.6
        slider_height = 0.021
        gap = slider_height + 0.01

        # [left, bottom, width, height]
        # parameters sliders
        self.n_slider = Slider(plt.axes(
            [slider_x_loc, slider_y_loc, slider_width, slider_height],
            facecolor=axcolor),
                               'N',
                               1.0,
                               int(self.game.get_size() / 2),
                               valinit=self.N,
                               valstep=1.0,
                               valfmt='%0.0f')
        self.p_slider = Slider(plt.axes(
            [slider_x_loc, slider_y_loc - gap, slider_width, slider_height],
            facecolor=axcolor),
                               'P',
                               0.0,
                               1.0,
                               valinit=self.P)
        # K and L sliders start hidden; __set_k / __set_l show them only
        # while the quarantine checkbox is ticked
        self.k_slider_loc = plt.axes(
            [slider_x_loc, slider_y_loc - 2 * gap, slider_width, slider_height],
            facecolor=axcolor)
        self.k_slider = Slider(self.k_slider_loc,
                               'K',
                               1.0,
                               8.0,
                               valinit=self.K,
                               valstep=1.0,
                               valfmt='%0.0f')
        self.k_slider_loc.set_visible(False)
        self.l_slider_loc = self.fig.add_axes(
            [slider_x_loc, slider_y_loc - 3 * gap, slider_width, slider_height],
            facecolor=axcolor)
        self.l_slider = Slider(self.l_slider_loc,
                               'L',
                               0.0,
                               1000.0,
                               valinit=0,
                               valstep=20.0,
                               valfmt='%0.0f')
        self.l_slider_loc.set_visible(False)

        # quarantine option menu
        self.right_menu_x_loc = 0.025
        self.options_button = CheckButtons(
            plt.axes(
                [self.right_menu_x_loc, slider_y_loc - 4 * gap, 0.18, 0.15]),
            ['apply\nquarantine'])
        self.get_stat_button = None
        if ALLOW_SAVE_DATA:
            self.get_stat_button = Button(plt.axes(
                [self.right_menu_x_loc, slider_y_loc + 15 * gap, 0.12, 0.04]),
                                          'save data',
                                          color=axcolor,
                                          hovercolor=hcolor)

        # control animation's speed
        y_axis_speed = 0.92
        plt.text(0.80, y_axis_speed, 'speed: ', transform=self.fig.transFigure)
        self.speed_box = plt.text(0.88,
                                  y_axis_speed,
                                  '',
                                  transform=self.fig.transFigure)
        self.speed_up_button = Button(plt.axes(
            [0.94, y_axis_speed, 0.02, 0.03]),
                                      '+',
                                      hovercolor=hcolor)
        self.speed_down_button = Button(plt.axes(
            [0.96, y_axis_speed, 0.02, 0.03]),
                                        '-',
                                        hovercolor=hcolor)

        # control animation buttons
        self.play_button = Button(plt.axes([0.8, 0.025, 0.1, 0.04]),
                                  'play',
                                  color=axcolor,
                                  hovercolor=hcolor)
        self.pause_button = Button(plt.axes([0.69, 0.025, 0.1, 0.04]),
                                   'pause',
                                   color=axcolor,
                                   hovercolor=hcolor)
        self.reset_button = Button(plt.axes([0.56, 0.025, 0.12, 0.04]),
                                   'reset',
                                   color=axcolor,
                                   hovercolor=hcolor)

    def __set_p(self, e):
        """ on-click function: change P slider value """
        self.P = self.p_slider.val

    def __set_n(self, e):
        """ on-click function: change N slider value """
        self.N = int(self.n_slider.val)

    def __set_k(self, e):
        """ on-click function: change K slider value

        K is 0 while the quarantine checkbox is unticked.
        """
        check = self.options_button.get_status()[0]
        self.k_slider_loc.set_visible(check)
        self.k_slider.set_active(check)
        self.K = int(self.k_slider.val) if check else 0

    def __set_l(self, e):
        """ on-click function: change L slider value

        L is None while the quarantine checkbox is unticked.
        """
        check = self.options_button.get_status()[0]
        self.l_slider_loc.set_visible(check)
        self.l_slider.set_active(check)
        self.L = int(self.l_slider.val) if check else None

    def __reset_button_on_click(self, e):
        """ on-click function: pause and reset the game with current parameters """
        self.animation.stop()
        self.game.build(self.N, self.P, self.K, self.L)

    def set_all(self):
        """ create gui's visible elements and attach them to event-functions """
        self.fig, self.ax = plt.subplots()
        plt.subplots_adjust(left=0.25, bottom=0.25)
        self.__set_widgets()

        self.p_slider.on_changed(self.__set_p)
        self.n_slider.on_changed(self.__set_n)
        self.k_slider.on_changed(self.__set_k)
        self.l_slider.on_changed(self.__set_l)

        # the single quarantine checkbox toggles both the K and L sliders
        self.options_button.on_clicked(self.__set_l)
        self.options_button.on_clicked(self.__set_k)
        self.reset_button.on_clicked(self.__reset_button_on_click)

    def start(self):
        """ create all entities and start the animation

        Requires set_all() to have been called first: it uses self.fig,
        self.ax and the buttons created there. Nothing here calls
        plt.show() directly.
        """
        # calculate game's statistics for each time step
        gs = GameStatistics(self.game)
        # follow statistics
        ShowStatistics(gs, self.fig, x=self.right_menu_x_loc)
        StatAccumulator(gs, self.get_stat_button)
        if SHOW_ONLINE_GRAPH:
            OnlineGraph(gs)

        self.game.build(self.N, self.P, self.K, self.L)
        self.animation = CellAnimation(self.pause_button, self.play_button,
                                       self.speed_up_button,
                                       self.speed_down_button, self.speed_box,
                                       self.fig, self.ax)
        self.animation.start(self.game)
예제 #27
0
class DefacingInterface(BaseReviewInterface):
    """Custom interface to rate the quality of defacing in an MRI scan"""
    def __init__(self,
                 fig,
                 axes,
                 issue_list=cfg.defacing_default_issue_list,
                 next_button_callback=None,
                 quit_button_callback=None,
                 processing_choice_callback=None,
                 map_key_to_callback=None):
        """Constructor

        Parameters
        ----------
        fig, axes
            Forwarded to ``BaseReviewInterface.__init__``.
        issue_list : sequence of str
            Labels shown as checkboxes.
        next_button_callback, quit_button_callback : callable
            Called without arguments from the keyboard shortcuts
            (right/space and ctrl+q respectively).
        processing_choice_callback : callable
            Attached to the visualization-type radio buttons.
        map_key_to_callback : dict or None
            Extra keyboard shortcuts: lower-cased key name -> callable
            taking no arguments.

        Raises
        ------
        ValueError
            If ``map_key_to_callback`` is neither None nor a dict.
        """

        super().__init__(fig, axes, next_button_callback, quit_button_callback)

        self.issue_list = issue_list

        self.prev_axis = None
        self.prev_ax_pos = None
        self.zoomed_in = False
        self.next_button_callback = next_button_callback
        self.quit_button_callback = quit_button_callback
        self.processing_choice_callback = processing_choice_callback
        if map_key_to_callback is None:
            self.map_key_to_callback = {}  # empty
        elif isinstance(map_key_to_callback, dict):
            self.map_key_to_callback = map_key_to_callback
        else:
            raise ValueError('map_key_to_callback must be a dict')

        self.add_checkboxes()
        self.add_process_options()
        # include all the non-data axes here (so they won't be zoomed-in).
        # on_mouse compares these entries with event.inaxes, so each one must
        # be an Axes: hence the radio buttons' .ax, not the widget itself.
        # (text_box, bt_next and bt_quit are presumably created by the base
        # class -- they are not created in this class)
        self.unzoomable_axes = [
            self.checkbox.ax, self.text_box.ax, self.bt_next.ax,
            self.bt_quit.ax, self.radio_bt_vis_type.ax
        ]

        # artists to be populated later; keeping them in one list makes it
        # handy to remove them all at once (see clear_data)
        self.data_handles = list()

    def add_checkboxes(self):
        """
        Checkboxes offer the ability to select multiple tags such as Motion,
        Ghosting Aliasing etc, instead of one from a list of mutual exclusive
        rating options (such as Good, Bad, Error etc).

        Also records the position of the "pass" label among the displayed
        labels in self._index_pass (None when issue_list has no such label).
        """

        ax_checkbox = plt.axes(cfg.position_checkbox_t1_mri,
                               facecolor=cfg.color_rating_axis)
        # initially de-activating all
        check_box_status = [False] * len(self.issue_list)
        self.checkbox = CheckButtons(ax_checkbox,
                                     labels=self.issue_list,
                                     actives=check_box_status)
        self.checkbox.on_clicked(self.save_issues)
        for txt_lbl in self.checkbox.labels:
            txt_lbl.set(color=cfg.text_option_color, fontweight='normal')

        for rect in self.checkbox.rectangles:
            rect.set_width(cfg.checkbox_rect_width)
            rect.set_height(cfg.checkbox_rect_height)

        # lines is a list of n crosses, each cross (x) defined by a tuple of lines
        for x_line1, x_line2 in self.checkbox.lines:
            x_line1.set_color(cfg.checkbox_cross_color)
            x_line2.set_color(cfg.checkbox_cross_color)

        # index into the labels actually displayed (a custom issue_list may
        # differ from the default one)
        if cfg.defacing_pass_indicator in self.issue_list:
            self._index_pass = self.issue_list.index(
                cfg.defacing_pass_indicator)
        else:
            self._index_pass = None

    def add_process_options(self):
        """Add radio buttons choosing the visualization type; clicks are
        forwarded to processing_choice_callback."""

        ax_radio = plt.axes(cfg.position_radio_bt_t1_mri,
                            facecolor=cfg.color_rating_axis)
        self.radio_bt_vis_type = RadioButtons(ax_radio,
                                              cfg.vis_choices_defacing,
                                              active=None,
                                              activecolor='orange')
        self.radio_bt_vis_type.on_clicked(self.processing_choice_callback)
        for txt_lbl in self.radio_bt_vis_type.labels:
            txt_lbl.set(color=cfg.text_option_color, fontweight='normal')

        for circ in self.radio_bt_vis_type.circles:
            circ.set(radius=0.06)

    def save_issues(self, label):
        """
        Update the rating

        This function is called whenever set_active() happens on any label,
            if checkbox.eventson is True.

        Clicking the pass label clears every other checkbox; clicking any
        other label clears the pass checkbox.
        """

        if label == cfg.defacing_pass_indicator:
            self.clear_checkboxes(except_pass=True)
        else:
            self.clear_pass_only_if_on()

        self.fig.canvas.draw_idle()

    def clear_checkboxes(self, except_pass=False):
        """Clears all checkboxes.

        if except_pass=True,
            does not clear the checkbox for cfg.defacing_pass_indicator
        """

        cbox_statuses = self.checkbox.get_status()
        for index, this_cbox_active in enumerate(cbox_statuses):
            if except_pass and index == self._index_pass:
                continue
            # if it was selected already, toggle it.
            if this_cbox_active:
                # not calling checkbox.set_active() as it calls the callback
                #   self.save_issues() each time, if eventson is True
                self._toggle_visibility_checkbox(index)

    def clear_pass_only_if_on(self):
        """Clear pass checkbox only (no-op if issue_list has no pass label)"""

        if self._index_pass is None:
            return
        if self.checkbox.get_status()[self._index_pass]:
            self._toggle_visibility_checkbox(self._index_pass)

    def _toggle_visibility_checkbox(self, index):
        """toggles the visibility of a given checkbox's cross lines directly,
        bypassing set_active() and therefore the on_clicked callbacks"""

        l1, l2 = self.checkbox.lines[index]
        l1.set_visible(not l1.get_visible())
        l2.set_visible(not l2.get_visible())

    def get_ratings(self):
        """Returns the final set of checked ratings (label texts)"""

        cbox_statuses = self.checkbox.get_status()
        user_ratings = [
            self.checkbox.labels[idx].get_text()
            for idx, this_cbox_active in enumerate(cbox_statuses)
            if this_cbox_active
        ]

        return user_ratings

    def allowed_to_advance(self):
        """
        Method to ensure work is done for current iteration,
        before allowing the user to advance to next subject.

        Returns True only if at least one checkbox is checked.
        """

        return any(self.checkbox.get_status())

    def reset_figure(self):
        "Resets the figure to prepare it for display of next subject."

        self.clear_data()
        self.clear_checkboxes()
        self.clear_radio_buttons()
        self.clear_notes_annot()

    def clear_data(self):
        """clearing all data/image handles"""

        if self.data_handles:
            for artist in self.data_handles:
                artist.remove()
            # resetting it
            self.data_handles = list()

    def clear_notes_annot(self):
        """clearing notes and annotations"""

        self.text_box.set_val(cfg.textbox_initial_text)
        # text is matplotlib artist
        self.annot_text.remove()

    def clear_radio_buttons(self):
        """Clears the radio button"""

        # enabling default rating encourages lazy advancing without review
        # self.radio_bt_rating.set_active(cfg.index_freesurfer_default_rating)
        for index, label in enumerate(self.radio_bt_vis_type.labels):
            if label.get_text() == self.radio_bt_vis_type.value_selected:
                self.radio_bt_vis_type.circles[index].set_facecolor(
                    cfg.color_rating_axis)
                break
        self.radio_bt_vis_type.value_selected = None

    def on_mouse(self, event):
        """Callback for mouse events.

        Any click outside the widget axes restores a previously zoomed axis;
        a right click or double click on a data axis zooms into it.
        """

        if self.prev_axis is not None:
            if event.inaxes not in self.unzoomable_axes:
                self.prev_axis.set_position(self.prev_ax_pos)
                self.prev_axis.set_zorder(0)
                self.prev_axis.patch.set_alpha(0.5)
                self.zoomed_in = False

        # right or double click to zoom in to any axis
        if ((event.button in [3] or event.dblclick)
                and event.inaxes is not None
                and event.inaxes not in self.unzoomable_axes):
            self.prev_ax_pos = event.inaxes.get_position()
            event.inaxes.set_position(cfg.zoomed_position)
            event.inaxes.set_zorder(1)  # bring forth
            event.inaxes.set_facecolor('black')  # black
            event.inaxes.patch.set_alpha(1.0)  # opaque
            self.zoomed_in = True
            self.prev_axis = event.inaxes

        self.fig.canvas.draw_idle()

    def on_keyboard(self, key_in):
        """Callback to handle keyboard shortcuts to rate and advance."""

        # ignore keyboard key_in when mouse within Notes textbox
        if key_in.inaxes == self.text_box.ax or key_in.key is None:
            return

        key_pressed = key_in.key.lower()
        if key_pressed in ['right', ' ', 'space']:
            self.next_button_callback()
        elif key_pressed in ['ctrl+q', 'q+ctrl']:
            self.quit_button_callback()
        elif key_pressed in self.map_key_to_callback:
            # notice parentheses at the end
            self.map_key_to_callback[key_pressed]()
        elif key_pressed in cfg.abbreviation_t1_mri_default_issue_list:
            # NOTE(review): abbreviations come from the T1 MRI config; confirm
            # whether a defacing-specific abbreviation map should be used.
            checked_label = cfg.abbreviation_t1_mri_default_issue_list[
                key_pressed]
            # toggle by position among the labels actually displayed here;
            # labels not shown in this interface are ignored
            if checked_label in self.issue_list:
                self.checkbox.set_active(self.issue_list.index(checked_label))

        self.fig.canvas.draw_idle()
class ImageEditor(BaseEditor):
    """Interactive matplotlib editor for the basal point, tip point(s) and
    skeletal curve(s) of an image.

    The left panel shows the extracted area with a draggable basal marker
    (red circle) and tip markers (stars); the right panel shows the original
    image. Both panels draw the estimated curve(s). Buttons re-run each
    estimator, and sliders tune alpha (wave coefficient), theta_T (travel
    threshold) and delta (step width).
    """
    def __init__(self,
                 image,
                 range_of_interest,
                 mask_func,
                 estimate_p_start,
                 estimate_p_target,
                 resize_rate=1.0,
                 travel_threshold=0.0,
                 wave_coef=2.0,
                 step_width=1.0):
        '''
        Args:
            image (np.ndarray(uint8)): raw image in BGR order (converted with
                cv2.COLOR_BGR2RGB), shape=(H, W, 3)
            range_of_interest (np.ndarray(bool)): region mask, shape=(H, W);
                combined with mask_func(image) by logical AND
            mask_func (func): mask estimator; must expose ``extensions``
            estimate_p_start (func): basal point estimator; its ``func``
                attribute must expose ``extensions`` (presumably a
                functools.partial -- confirm)
            estimate_p_target (func): tip point estimator; same requirement
            resize_rate (float, optional): scale forwarded to the estimators
                and divided out again in retrieve(). Defaults to 1.
            travel_threshold (float, optional): Boundary dist, Defaults to 0.
            wave_coef (float, optional): wave speed coefficient. Defaults to 2.
            step_width (float, optional): backward step width. Defaults to 1.
        '''
        self.image_size = tuple(image.shape[:2])
        # RGBA copy of the image; update() rewrites the alpha channel from
        # the current ROI
        self.image_edit = np.concatenate([
            cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
            np.full((*image.shape[:2], 1), 255, dtype=np.uint8)
        ],
                                         axis=2)
        # NOTE(review): mask_func receives the BGR image here but the RGB
        # slice of image_edit in reset()/update() -- confirm intended.
        self.roi_raw = range_of_interest & mask_func(image)
        self.roi_edit = self.roi_raw

        self.mask_func = mask_func
        self.estimate_p_start = estimate_p_start
        self.estimate_p_target = estimate_p_target
        self.extensions = []
        self.extensions += self.estimate_p_start.func.extensions
        self.extensions += self.estimate_p_target.func.extensions
        self.extensions += self.mask_func.extensions

        self._resize_rate = resize_rate
        self._wave_coef = wave_coef
        self._travel_threshold = travel_threshold
        self._step_width = step_width

        self._update_dist_flag = self._update_travel_flag = True

    def estimate(self):
        """Run the basal/tip point estimators and the path estimation once.

        Sets p_start, p_target (plus the *_init copies reset() restores),
        endpoint_num, is_converged and path.
        """
        # NOTE(review): the estimators get the RGBA image here but the RGB
        # slice in _estimate_start/_estimate_target -- confirm intended.
        self.p_start = self.estimate_p_start(self.image_edit,
                                             self.roi_edit,
                                             self.dist_field,
                                             resize_rate=self._resize_rate)
        self.p_start_init = np.array(self.p_start)
        self.endpoint_num, self.p_target = self.estimate_p_target(
            self.image_edit,
            self.roi_edit,
            self.dist_field,
            self.travel_field,
            resize_rate=self._resize_rate)
        self.p_target_init = np.array(self.p_target)
        self.is_converged = [False] * self.endpoint_num
        self.path = [np.array([]) for _ in range(self.endpoint_num)]
        travel_field = self.travel_field
        if travel_field is None:
            return
        # large travel value where the distance is below the threshold
        # (presumably to keep paths away from those pixels)
        travel_field[self.dist_field < self.travel_threshold] = 1e5
        for idx, p_target in enumerate(self.p_target):
            self.is_converged[idx], self.path[idx] = estimate_path(
                travel_field, self.p_start, p_target, self.step_width)

    def retrieve(self):
        """Return [converged_flag, path] pairs, paths divided by resize_rate."""
        return [[_check, _path / self._resize_rate]
                for _check, _path in zip(self.is_converged, self.path)]

    def launch(self):
        """Build the figure, widgets and artists, then show the figure.

        Runs estimate() first if it has not been run yet.
        """
        if not hasattr(self, "path"):
            self.estimate()
        # id of the marker being dragged: -1 for the basal point, a tip
        # index, or None; read by _on_click_set
        self.drag_id = None
        self.fig = Figure(figsize=(15, 6))
        self.fig.create_grid((1, 3), wspace=0.0, width_ratios=(2, 1, 2))
        self.fig[1].create_grid((4, 3),
                                wspace=0.0,
                                width_ratios=(1, 4.5, 1),
                                hspace=0.05,
                                height_ratios=(5, 3.5, 4, 7))

        ax_command = self.fig[1][0, 1]
        ax_command.create_grid((5, 1), hspace=0.0)
        self.button_reset = Button(ax_command[0], 'reset')
        self.button_reset.label.set_fontsize(15)
        self.button_reset.on_clicked(self.reset)
        self.button_start = Button(ax_command[1], 'estimate basal point')
        self.button_start.label.set_fontsize(12)
        self.button_start.on_clicked(self._estimate_start)
        self.button_target = Button(ax_command[2], 'estimate tip point(s)')
        self.button_target.label.set_fontsize(12)
        self.button_target.on_clicked(self._estimate_target)
        self.button_path = Button(ax_command[3], 'estimate skeletal curve')
        self.button_path.label.set_fontsize(10)
        self.button_path.on_clicked(self._estimate_path)
        self.button_end = Button(ax_command[4], 'end')
        self.button_end.label.set_fontsize(15)
        self.button_end.on_clicked(self.terminate)
        self.fig._fig.canvas.mpl_connect('close_event', self.terminate)

        ax_action = self.fig[1][1, 1]
        ax_action.create_grid((2, 1), hspace=0.0, height_ratios=(5, 2))
        self.fig._fig.canvas.mpl_connect('button_press_event',
                                         self._on_click_set)
        self.fig._fig.canvas.mpl_connect('motion_notify_event',
                                         self._on_click_set)
        self._set_interactive_tool(self.fig[0], ax_action[0], ax_action[1])

        ax_fmm_param = self.fig[1][2, 1]
        ax_fmm_param.create_grid((4, 1),
                                 hspace=0.0,
                                 height_ratios=(2, 1, 1, 1))
        self.button_field = CheckButtons(ax_fmm_param[0],
                                         ["distance", "travel"])
        for _label in self.button_field.labels:
            _label.set_fontsize(12)
        self.button_field.on_clicked(self.render)
        self.slider_wave = Slider(ax_fmm_param[1],
                                  r'$\alpha$',
                                  0.0,
                                  5.0,
                                  valinit=self._wave_coef,
                                  valfmt="%.1f")
        self.slider_dist = Slider(ax_fmm_param[2],
                                  r'$\theta_T$',
                                  -5.0,
                                  5.0,
                                  valinit=self._travel_threshold,
                                  valfmt="%.1f")
        self.slider_step = Slider(ax_fmm_param[3],
                                  r'$\delta$',
                                  0.1,
                                  5.0,
                                  valinit=self._step_width,
                                  valfmt="%.1f")
        # the step slider has no callback: step_width is only read when a
        # path is estimated
        self.slider_wave.on_changed(self._on_click_slider)
        self.slider_dist.on_changed(self._on_click_slider)
        self.sliders = [self.slider_wave, self.slider_dist, self.slider_step]

        ax_extension = self.fig[1][3, 1]
        ratios = [1] * len(self.extensions)
        ratios[0] = 2
        self._set_extension(ax_extension, height_ratios=ratios)

        self.fig[0].add_zoom_func()
        rect = plt.Rectangle((0, 0),
                             self.image_size[1],
                             self.image_size[0],
                             fc='w',
                             ec='gray',
                             hatch='++',
                             alpha=0.5,
                             zorder=-10)
        self.fig[0].add_patch(rect)
        self.im_image_left = self.fig[0].plot_matrix(np.zeros(
            (*self.image_size, 4)),
                                                     picker=True,
                                                     flip_axis=True,
                                                     colorbar=False)
        self.fig[0].set_xlim([0, self.image_size[1]])
        self.fig[0].set_ylim([self.image_size[0], 0])
        self.fig[0].set_aspect("equal", "box")
        self.fig[0].set_title("extracted area "
                              "(ROI+HSV, left click: write, "
                              "left click + shift key: erase)")
        self.im_start, = self.fig[0].plot(self.p_start[1],
                                          self.p_start[0],
                                          "ro",
                                          markersize=10)
        self.im_target = [None] * self.endpoint_num
        for idx, p_target in enumerate(self.p_target):
            if p_target is None:
                # no tip estimated: hidden marker, shown at the path end if
                # a path exists
                self.im_target[idx], = self.fig[0].plot(0, 0, "*",
                                                        markersize=10)
                self.im_target[idx].set_visible(False)
                if len(self.path[idx]) > 0:
                    self.im_target[idx].set_data(*self.path[idx][-1][::-1])
                    self.im_target[idx].set_visible(True)
            else:
                self.im_target[idx], = self.fig[0].plot(p_target[1],
                                                        p_target[0],
                                                        "*",
                                                        markersize=10)

        self.fig[2].add_zoom_func()
        self.im_image_right = self.fig[2].plot_matrix(self.image_edit[..., :3],
                                                      picker=True,
                                                      flip_axis=True,
                                                      colorbar=False)
        self.fig[2].set_xlim([0, self.image_size[1]])
        self.fig[2].set_ylim([self.image_size[0], 0])
        self.fig[2].set_aspect("equal", "box")
        self.fig[2].get_xaxis().set_visible(False)
        self.fig[2].get_yaxis().set_visible(False)
        self.fig[2].set_title("original image")

        self.im_path_left = [None] * self.endpoint_num
        self.im_path_right = [None] * self.endpoint_num
        for idx in range(self.endpoint_num):
            self.im_path_left[idx], = self.fig[0].plot([], [],
                                                       color="pink",
                                                       marker="+",
                                                       picker=True)
            self.im_path_right[idx], = self.fig[2].plot([], [],
                                                        color="pink",
                                                        marker="+",
                                                        picker=True)
        self.reset()
        self.fig.show()

    def reset(self, event=None):
        """Restore ROI, extensions, sliders, points and paths to their
        initial state and redraw. event is ignored (button callback)."""
        self.roi_edit = self.roi_raw & self.mask_func(self.image_edit[..., :3])
        for ext in self.extensions:
            ext.reset(self, event)
        for sld in self.sliders:
            sld.reset()
        self.path = [np.array([]) for _ in range(self.endpoint_num)]
        self.is_converged = [False] * self.endpoint_num
        self.p_start = np.array(self.p_start_init)
        self.p_target = np.array(self.p_target_init)
        self._update_dist_flag = True
        self._update_travel_flag = True
        self.update()

    def update(self, use_mask_func=False):
        """Apply the ROI to the alpha channel, invalidate cached fields and
        redraw. With use_mask_func=True the ROI is recomputed first."""
        if use_mask_func:
            mask = self.mask_func(self.image_edit[..., :3])
            self.roi_edit = self.roi_raw & mask
        self.image_edit[~self.roi_edit, 3] = 0
        self.image_edit[self.roi_edit, 3] = 255
        self.im_image_left.set_data(self.image_edit[::-1])
        self._update_dist_flag = True
        self._update_travel_flag = True
        self.render()

    def render(self, event=None):
        """Redraw markers, curves and the optional distance/travel contours.

        event is ignored; it allows use as a widget callback.
        """
        self.im_start.set_data(self.p_start[1], self.p_start[0])
        for idx, p_target in enumerate(self.p_target):
            # launch() tolerates missing tip points (None); leave their
            # markers as launch() placed them instead of indexing None
            if p_target is None:
                continue
            self.im_target[idx].set_data(p_target[1], p_target[0])

        for idx, path in enumerate(self.path):
            if path.size > 0:
                xs, ys = path[:, 1], path[:, 0]
            else:
                # clear stale curves, e.g. after reset() emptied the paths
                xs, ys = [], []
            self.im_path_left[idx].set_data(xs, ys)
            self.im_path_right[idx].set_data(xs, ys)

        show_dist, show_travel = self.button_field.get_status()
        if show_dist and (self.im_dist is None or self._update_dist_flag):
            # the property returns a fresh copy (or None) and, when stale,
            # removes the previous contour
            dist_field = self.dist_field
            if dist_field is not None:
                dist_field[dist_field < 0] = 0
                self.im_dist = self.fig[0].plot_matrix(dist_field,
                                                       contour=True,
                                                       flip_axis=False,
                                                       levels=100,
                                                       colorbar=False)
        if self.im_dist is not None:
            for coll in self.im_dist.collections:
                coll.set_visible(show_dist)

        if show_travel and (self.im_travel is None
                            or self._update_travel_flag):
            travel_field = self.travel_field
            if travel_field is not None:
                self.im_travel = self.fig[0].plot_matrix(travel_field,
                                                         contour=True,
                                                         flip_axis=False,
                                                         levels=100,
                                                         colorbar=False)
        if self.im_travel is not None:
            for coll in self.im_travel.collections:
                coll.set_visible(show_travel)
        self.fig._fig.canvas.draw_idle()

    def terminate(self, event=None):
        """Close the editor figure (button / close-event callback)."""
        self.fig.close()

    @property
    def wave_coef(self):
        """Current alpha: slider value once launched, else the init value."""
        if hasattr(self, "slider_wave"):
            return self.slider_wave.val
        else:
            return self._wave_coef

    @property
    def travel_threshold(self):
        """Current theta_T: slider value once launched, else the init value."""
        if hasattr(self, "slider_dist"):
            return self.slider_dist.val
        else:
            return self._travel_threshold

    @property
    def step_width(self):
        """Current delta: slider value once launched, else the init value."""
        if hasattr(self, "slider_step"):
            return self.slider_step.val
        else:
            return self._step_width

    @property
    def dist_field(self):
        """Copy of the (lazily recomputed) distance field of roi_edit, or None.

        Recomputing removes the stale contour and marks the travel field
        stale as well.
        """
        if self._update_dist_flag or (not hasattr(self, "_dist_field")):
            self._dist_field = calc_distance(self.roi_edit)
            self._update_dist_flag = False
            if hasattr(self, "im_dist") and (self.im_dist is not None):
                for coll in self.im_dist.collections:
                    coll.remove()
            self.im_dist = None
            self._update_travel_flag = True
        if self._dist_field is None:
            return None
        else:
            return np.array(self._dist_field)

    @property
    def travel_field(self):
        """Copy of the (lazily recomputed) travel field from p_start, or None.

        Recomputing removes the stale contour.
        """
        if self._update_travel_flag or (not hasattr(self, "_travel_field")):
            self._travel_field = calc_travel(self.dist_field, self.p_start,
                                             self.wave_coef,
                                             self.travel_threshold)
            self._update_travel_flag = False
            if hasattr(self, "im_travel") and (self.im_travel is not None):
                for coll in self.im_travel.collections:
                    coll.remove()
            self.im_travel = None
        if self._travel_field is None:
            return None
        else:
            return np.array(self._travel_field)

    def _on_click_slider(self, event):
        """Slider callback: invalidate the travel field and redraw."""
        self._update_travel_flag = True
        self.render()

    def _on_click_set(self, event):
        """Mouse callback: drag the basal or a tip marker with the left
        button. Events outside area_bbox or with a modifier key are ignored;
        motion without a pressed button ends the drag."""
        if not (self.area_bbox.x0 < event.x < self.area_bbox.x1):
            return
        if not (self.area_bbox.y0 < event.y < self.area_bbox.y1):
            return
        if event.key is not None:
            return
        if event.button is None:
            self.drag_id = None
            return

        def _position(artist):
            # marker vertices in display (pixel) coordinates
            coords = artist.get_path().vertices
            return artist.get_transform().transform(coords)

        p_mouse = np.array([event.x, event.y])
        p_mouse_data = np.array([event.ydata, event.xdata])
        if event.button == 1 and self.drag_id is None:
            # pick the marker within 10 pixels of the cursor
            if np.linalg.norm(_position(self.im_start) - p_mouse) < 10.0:
                self.drag_id = -1
            else:
                for _id in range(self.endpoint_num):
                    if np.linalg.norm(
                            _position(self.im_target[_id]) - p_mouse) < 10.0:
                        self.drag_id = _id
                        break
        if self.drag_id == -1:
            self.p_start = p_mouse_data
            self._update_travel_flag = True
        elif self.drag_id is not None:
            self.p_target[self.drag_id] = p_mouse_data
            self.im_target[self.drag_id].set_data(
                self.p_target[self.drag_id][1], self.p_target[self.drag_id][0])
        self.render()

    def _estimate_start(self, event=None):
        """Button callback: re-estimate the basal point and redraw."""
        self.p_start = self.estimate_p_start(self.image_edit[..., :3],
                                             self.roi_edit,
                                             self.dist_field,
                                             p_start_now=self.p_start,
                                             resize_rate=self._resize_rate)
        self._update_travel_flag = True
        self.render()

    def _estimate_target(self, event=None):
        """Button callback: re-estimate the tip point(s) and redraw."""
        # NOTE(review): if endpoint_num changes here, the marker/path lists
        # created in launch() no longer match -- confirm the estimator keeps
        # the count stable.
        self.endpoint_num, self.p_target = self.estimate_p_target(
            self.image_edit[..., :3],
            self.roi_edit,
            self.dist_field,
            self.travel_field,
            p_target_now=self.p_target,
            resize_rate=self._resize_rate)
        self.render()

    def _estimate_path(self, event=None):
        """Button callback: estimate a path to every tip point and redraw.

        Without a travel field every path is cleared and marked unconverged.
        """
        travel_field = self.travel_field
        if travel_field is None:
            for idx in range(len(self.p_target)):
                self.is_converged[idx], self.path[idx] = False, np.array([])
        else:
            travel_field[self.dist_field < self.travel_threshold] = 1e5
            for idx, p_target in enumerate(self.p_target):
                self.is_converged[idx], self.path[idx] = estimate_path(
                    travel_field, self.p_start, p_target, self.step_width)
        self.render()
예제 #29
0
class LyaGUI(SimpleFitGUI):
    """ This is a GUI to perform model-independent measurements of Ly-alpha line
    profiles (primarily interesting for emission).

    It measures the following characteristics, when present:
        - Equivalent Width
        - Red and blue peak velocity and separation
        - Red and blue peak flux relative to continuum, and their ratio
        - Red and blue integrated flux, and their ratio
        - Peak-to-valley flux ratios.
        - Asymmetry (crookedity?) of red peak

    Workflow: tick the components present ('Red', 'Blue', 'Valley') with
    the check buttons, pick the current one with the radio buttons and drag
    a horizontal span over the spectrum. The span is measured with
    ``measure_flux`` and the 2.5/16/50/84/97.5 percentiles of every
    quantity are stored in ``summary_dict[<component>]``.
    """
    galaxy_name = ''
    data = None
    interp = None
    interr = None
    errs = None
    # mask = None
    smoothed = None
    num_realizations = 1000
    z = 0.
    # Class-level default only; __init__ gives every instance its own dict
    # so separate GUIs never share (or overwrite) each other's results.
    summary_dict = {}
    cenwave = 1215.67 * (1 + z)

    _current_peak = 'Red'
    _peaks = np.array(['Red', 'Blue', 'Valley'])
    _peak_on = np.array([True, True, True])
    _colors = {'Red': 'tab:red', 'Blue': 'tab:blue', 'Valley': '0.7'}
    _extra_plots = {}  # per-instance copy made in __init__
    _velocities = {}  # Just for use in plots; per-instance copy made in __init__
    # Percentiles reported for every Monte-Carlo measured quantity.
    _percentiles = (2.5, 16, 50, 84, 97.5)

    def __init__(self,
                 summaryfile=None,
                 inwave=None,
                 indata=None,
                 inerrs=None,
                 inmask=None,
                 smooth=None):
        """Load the spectrum, optionally smooth it and build the figure.

        Args:
            summaryfile: path to a YAML summary; when given, its 'Ly alpha'
                transition overrides ``inwave``/``indata``/``inerrs``/
                ``inmask``.
            inwave, indata, inerrs: wavelength, flux and flux-error arrays.
            inmask: boolean mask, True for pixels to blank out and
                interpolate over.
            smooth: 'll' (local-linear kernel regression), 'box' (3-pixel
                boxcar) or None (no smoothing; ``smoothed`` is ``data``).
        """
        # Per-instance mutable state (the class attributes are only defaults).
        self.summary_dict = {}
        self._extra_plots = {}
        self._velocities = {}
        self._peak_on = np.array(self._peak_on)
        self.data = indata
        self.wave = inwave
        self.errs = inerrs
        self.mask = inmask
        if summaryfile:
            self.open_summary(summaryfile)
        if (self.mask is not None and self.data is not None
                and self.wave is not None):
            # Interpolate masked areas
            self._interpolate_masked()
        smooth_input = self.interp if self.interp is not None else self.data
        if smooth == 'll':
            lle = KernelReg(smooth_input, self.wave, 'c', bw=[10])
            mean, _ = lle.fit()
            self.smoothed = mean
        elif smooth == 'box':
            # mode='same' keeps the smoothed spectrum on the wavelength grid
            # (the default 'full' mode returns two extra samples).
            self.smoothed = np.convolve(smooth_input,
                                        np.ones(3) / 3,
                                        mode='same')
        else:
            self.smoothed = self.data
        self._build_plot()

    def _interpolate_masked(self):
        """Blank masked pixels and interpolate over them.

        Sets ``data`` to a float copy with masked pixels set to NaN,
        ``nonnanidx`` to the unmasked indices, ``interp`` to the flux
        linearly interpolated over the unmasked pixels and, when errors are
        available, ``interr`` likewise for the errors.
        """
        # Coerce to bool so a 0/1 integer mask masks instead of fancy-indexing.
        self.mask = np.asarray(self.mask, dtype=bool)
        self.wave = np.asarray(self.wave, dtype=float)
        # Float copy: integer input could not hold NaN, and the caller's
        # array must not be modified.
        self.data = np.array(self.data, dtype=float)
        self.data[self.mask] = np.nan
        self.nonnanidx = np.where(~self.mask)[0]
        good_wave = self.wave[self.nonnanidx]
        self.interp = np.interp(self.wave, good_wave,
                                self.data[self.nonnanidx])
        if self.errs is not None:
            self.interr = np.interp(self.wave, good_wave,
                                    np.asarray(self.errs)[self.nonnanidx])

    def open_summary(self, summary_file):
        """Read redshift, name, spectrum and continuum fit from a YAML summary.

        The spectrum is taken from ``summary['transitions']['Ly alpha']``;
        the continuum parameters may be stored under either
        'cont_fit_params' or 'continuum_fit_params'.
        """
        with open(summary_file) as f:
            # yaml.load() without an explicit Loader raises on PyYAML >= 6.
            # NOTE(review): FullLoader is not meant for untrusted files; use
            # yaml.safe_load if the summaries only hold plain data.
            summary = yaml.full_load(f)
        lya = summary['transitions']['Ly alpha']
        self.z = summary['redshift']
        self.galaxy_name = summary['galaxyname']
        self.data = np.array(lya['data'])
        self.errs = np.array(lya['errs'])
        self.wave = np.array(lya['wave'])
        self.mask = np.array(lya['mask'])
        if 'cont_fit_params' in lya:
            self._cfp = lya['cont_fit_params']
        else:
            self._cfp = lya['continuum_fit_params']
        self._cont = self.wave * self._cfp['slope'] + self._cfp['intercept']

    def _build_plot(self):
        """Create the figure, the buttons and the span selector."""
        self.fig = plt.figure(figsize=(8, 4))
        self.ax = self.fig.add_axes([0.08, 0.08, 0.8, 0.8])
        self.chax = self.fig.add_axes([0.90, 0.55, 0.09, 0.3], frameon=False)
        self.chax.set_title('Components \npresent', size='x-small')
        self.rax = self.fig.add_axes([0.90, 0.25, 0.09, 0.2], frameon=False)
        self.rax.set_title('Current \ncomponent', size='x-small')
        self.okax = self.fig.add_axes([0.9, 0.05, 0.09, 0.08])
        self.ok_button = Button(self.okax, 'Done')
        self.ok_button.on_clicked(self._ok_clicked)
        self.reax = self.fig.add_axes([0.9, 0.15, 0.09, 0.08])
        self.re_button = Button(self.reax, 'Reset')
        self.re_button.on_clicked(self._reset_clicked)
        for ax in self.chax, self.rax:
            # Booleans, not 'off': the 'on'/'off' strings are no longer
            # understood and any non-empty string is truthy.
            ax.tick_params(length=0, labelleft=False, labelbottom=False)
        self.ax.plot(self.wave, self.data, 'k-', drawstyle='steps-mid')
        if self.interp is not None:
            self.ax.plot(self.wave, self.interp, '-', color='C1', zorder=0)
        self.ax.axhline(0, ls='-', color='k', lw=1)
        self.ax.axvline(self.refwave, ls=':', color='k', lw=.8)
        self.ax.axhline(1, ls=':', color='k', lw=.8)
        # Insert check buttons
        clabels = self._peaks
        cons = self._peak_on
        self.check = CheckButtons(self.chax, clabels, cons)
        self.check.on_clicked(self._check_clicked)
        rlabels = clabels[np.where(cons)]
        self.radio = RadioButtons(self.rax, rlabels)
        # self.radio.on_clicked(self._radio_clicked)
        self._selector = SpanSelector(
            self.ax,
            self._onselect,
            'horizontal',
            minspan=self._wave_diffs.mean() * 5,
        )
        self.ax.set_xlabel('Wavelength')
        self.ax.set_ylabel('Normalized flux')

    def _set_selector_color(self, color):
        """Set the face colour of the span selector's rectangle.

        matplotlib >= 3.5 provides ``set_props`` (``rectprops`` was
        deprecated there and later removed); older versions only expose the
        ``rectprops`` dict, which is updated as before.
        """
        if hasattr(self._selector, 'set_props'):
            self._selector.set_props(facecolor=color)
        else:
            self._selector.rectprops['facecolor'] = color

    def _replace_extra_plot(self, key, artist):
        """Store ``artist`` under ``key``, removing any artist stored there."""
        old = self._extra_plots.pop(key, None)
        if old is not None:
            old.remove()
        self._extra_plots[key] = artist

    def _onselect(self, xmin, xmax):
        """SpanSelector callback: measure the selected wavelength range."""
        self.fit_active(xmin, xmax)

    def _ok_clicked(self, event):
        """'Done' callback: save, then mark the red-peak velocities."""
        self.save_summary()
        if 'Red' not in self.summary_dict or 'Red' not in self._velocities:
            print('No red peak has been measured yet; nothing to mark.')
            plt.draw()
            return
        red_velocities = {
            'red_vpeak': self.summary_dict['Red']['vpeak'][2],
            'red_v05': self._velocities['Red']['v05'],
            'red_v50': self._velocities['Red']['v50'],
            'red_v95': self._velocities['Red']['v95'],
        }
        for key, velocity in red_velocities.items():
            # Tracked in _extra_plots so Reset removes them and a second
            # 'Done' replaces instead of stacking them.
            line = self.ax.axvline(v_to_wl(velocity, self.refwave),
                                   ls='--',
                                   color='0.5')
            self._replace_extra_plot(key, line)
        plt.draw()

    def _reset_clicked(self, event):
        """'Reset' callback: forget all measurements and remove overlays."""
        self.summary_dict = {}
        self._velocities = {}
        for a in self._extra_plots.values():
            a.remove()
        # Forget the removed artists; removing an artist twice raises.
        self._extra_plots.clear()
        plt.draw()
        print('You presset RESET and are now back to scratch.')

    def _radio_clicked(self, event):
        """Radio-button callback: colour the selector like the component."""
        self._set_selector_color(self._colors[event])

    def _check_clicked(self, event):
        """Check-button callback: rebuild the radio buttons for active peaks."""
        print('Event: ', event)
        self.rax.remove()
        self.rax = self.fig.add_axes([0.90, 0.25, 0.09, 0.2], frameon=False)
        self.rax.set_title('Current \ncomponent', size='x-small')
        self._peak_on = np.array(self.check.get_status())
        rlabels = self._peaks[self._peak_on]
        self.radio = RadioButtons(self.rax, rlabels)
        plt.draw()

    def measure_flux(self, xmin, xmax, iters=1):
        """Measure line properties in the window ``xmin < wave < xmax``.

        Realisation 0 uses the interpolated data as is; each of the other
        ``iters - 1`` realisations adds Gaussian noise with standard
        deviation ``|errs|``. Without errors only one realisation is made.
        The medians of the 5/50/95 % cumulative-excess velocities are stored
        in ``_velocities[<selected radio component>]`` for plotting.

        Returns:
            Tuple of per-realisation lists ``(fluxes, vmaxs, vmins, fwhms,
            bhms, rhms, asymmetry, fmin, fmax, asymGronKo)``. ``bhms``,
            ``rhms`` and ``fwhms`` are NaN for realisations in which no
            pixel exceeds half of the peak excess over the continuum.

        Raises:
            ValueError: if no pixel lies inside the window.
        """
        idx = np.where((self.wave > xmin) & (self.wave < xmax))
        wav = self.wave[idx]
        if wav.size == 0:
            raise ValueError('No pixels between %s and %s.' % (xmin, xmax))
        vel = wl_to_v(wav, self.refwave)
        dat = self.interp[idx]  # - 1
        wdiffs = self._wave_diffs[idx]
        if self.errs is None:
            iters = 1
            sigma = None
        else:
            # abs() as in absolute_flux/equivalent_width; a negative scale
            # would make np.random.normal raise.
            sigma = np.absolute(np.asarray(self.errs)[idx])
        fluxes, vmaxs, vmins, fwhms = [], [], [], []
        bhms, rhms, asymmetry, asymGronKo = [], [], [], []
        fmax, fmin = [], []
        v05, v50, v95 = [], [], []
        for i in range(iters):
            # Only draw noise for the perturbed realisations (drawing it
            # unconditionally also crashed when errs is None).
            pertdata = dat if i == 0 else dat + np.random.normal(scale=sigma)
            excess = pertdata - 1
            # fluxes.append((excess * wdiffs).sum())
            fluxes.append((pertdata * wdiffs).sum())
            ipeak = pertdata.argmax()
            vmaxs.append(vel[ipeak])
            vmins.append(vel[pertdata.argmin()])
            cumflux = np.cumsum(excess)
            ftot = cumflux.max()
            cumfrac = cumflux / ftot
            q05 = vel[np.absolute(cumfrac - 0.05).argmin()]
            q50 = vel[np.absolute(cumfrac - 0.50).argmin()]
            q95 = vel[np.absolute(cumfrac - 0.95).argmin()]
            asymmetry.append((q95 - q50) / (q50 - q05))
            fpeak = cumflux[ipeak]
            asymGronKo.append((ftot - fpeak) / fpeak)
            above_half = np.where(excess > excess.max() / 2)[0]
            if above_half.size:
                bhm, rhm = vel[above_half.min()], vel[above_half.max()]
            else:
                # No pixel above half the peak excess (e.g. a window lying
                # entirely below the continuum): half-max edges undefined.
                bhm = rhm = np.nan
            bhms.append(bhm)
            rhms.append(rhm)
            fmax.append(excess.max())
            fmin.append(excess.min())
            fwhms.append(rhm - bhm)
            v05.append(q05)
            v95.append(q95)
            v50.append(q50)
        self._velocities[self.radio.value_selected] = {
            'v05': np.median(v05),
            'v50': np.median(v50),
            'v95': np.median(v95),
        }
        return fluxes, vmaxs, vmins, fwhms, bhms, rhms, asymmetry, fmin, fmax, asymGronKo

    def _integration_window(self, xmin, xmax):
        """Fill missing window limits from the ranges of the active peaks.

        Raises KeyError if a limit is needed and an active component has not
        been measured yet.
        """
        if xmin is None or xmax is None:
            ranges = [self.summary_dict[i]['range'] for i in self._peaks_active]
            therange = np.array(ranges).flatten()
            if xmin is None:
                xmin = therange.min()
            if xmax is None:
                xmax = therange.max()
        return xmin, xmax

    def _perturbed_integrals(self, values, errors, xmin, xmax, iters):
        """Integrate ``values`` over wavelength inside ``[xmin, xmax]``.

        Realisation 0 integrates ``values`` as is; the other ``iters - 1``
        add Gaussian noise with standard deviation ``|errors|``. Pixels
        outside the window are zeroed in both values and errors, so they
        contribute neither signal nor noise. With ``errors`` None a single
        realisation is made. Returns the list of integrals.
        """
        outside = (self.wave < xmin) | (self.wave > xmax)
        values = np.array(values, dtype=float)
        values[outside] = 0
        if errors is None:
            iters = 1
        else:
            errors = np.absolute(errors)
            errors[outside] = 0
        wdiffs = self._wave_diffs
        integrals = []
        for i in range(iters):
            if i == 0:
                pertdata = values
            else:
                pertdata = values + np.random.normal(scale=errors)
            integrals.append((pertdata * wdiffs).sum())
        return integrals

    def absolute_flux(self, xmin=None, xmax=None, iters=1):
        """Monte-Carlo continuum-subtracted flux, ``(interp - 1) * _cont``.

        Missing limits default to the span of the active components'
        ranges. The percentiles are stored in ``summary_dict['AbsFlux']``
        and returned.
        """
        # TODO Write sane defaults for xmin and xmax, should go via velocity.
        xmin, xmax = self._integration_window(xmin, xmax)
        afarray = (self.interp - 1) * self._cont
        aferrar = None
        if self.errs is not None and self.interr is not None:
            aferrar = self.interr * self._cont
        afs = self._perturbed_integrals(afarray, aferrar, xmin, xmax, iters)
        self.summary_dict['AbsFlux'] = np.percentile(afs, self._percentiles)
        return self.summary_dict['AbsFlux']  # np.atleast_1d(afs)

    def equivalent_width(self, xmin=None, xmax=None, iters=1000):
        """Monte-Carlo equivalent width, the integral of ``interp - 1``.

        Missing limits default to the span of the active components'
        ranges. Prints the standard deviation of the realisations, stores
        the percentiles in ``summary_dict['EW_lya']`` and returns them.
        """
        xmin, xmax = self._integration_window(xmin, xmax)
        ewerrar = self.interr if self.errs is not None else None
        ews = self._perturbed_integrals(self.interp - 1, ewerrar, xmin, xmax,
                                        iters)
        print(np.std(ews))
        self.summary_dict['EW_lya'] = np.percentile(ews, self._percentiles)
        return self.summary_dict['EW_lya']

    def fit_red(self, xmin, xmax):
        """Measure the red peak in ``[xmin, xmax]``; return its summary."""
        self._set_selector_color('tab:red')
        self._replace_extra_plot(
            'redspan',
            self.ax.axvspan(xmin, xmax, color=self._colors['Red'], alpha=.5,
                            zorder=0))
        flux, vpeak, vmin, fwhm, bhms, rhms, A_qs, fmin, fmax, Agks = \
            self.measure_flux(xmin, xmax, iters=self.num_realizations)
        pct = self._percentiles
        self.summary_dict['Red'] = {
            'range': (xmin, xmax),
            'flux': np.percentile(flux, pct),
            'fmax': np.percentile(fmax, pct),
            'vpeak': np.percentile(vpeak, pct),
            'fwhm': np.percentile(fwhm, pct),
            'blue_at_half_width': np.percentile(bhms, pct),
            'red_at_half_width': np.percentile(rhms, pct),
            'Asymmetry': np.percentile(A_qs, pct),
            'Asym_Gr_Ko': np.percentile(Agks, pct),
        }
        print('Fitting red peak')
        print(xmin, xmax)
        return self.summary_dict['Red']

    def fit_blue(self, xmin, xmax):
        """Measure the blue peak in ``[xmin, xmax]``; return its summary."""
        self._replace_extra_plot(
            'bluespan',
            self.ax.axvspan(xmin, xmax, color=self._colors['Blue'], alpha=.5,
                            zorder=0))
        flux, vpeak, vmin, fwhm, bhms, rhms, A_qs, fmin, fmax, Agks = \
            self.measure_flux(xmin, xmax, iters=self.num_realizations)
        pct = self._percentiles
        self.summary_dict['Blue'] = {
            'range': (xmin, xmax),
            'flux': np.percentile(flux, pct),
            'fmax': np.percentile(fmax, pct),
            'vpeak': np.percentile(vpeak, pct),
            'fwhm': np.percentile(fwhm, pct),
            'blue_at_half_width': np.percentile(bhms, pct),
            'red_at_half_width': np.percentile(rhms, pct),
        }
        print("The following has been added/overwritten" +
              " in the summary_dict['Blue'].")
        return self.summary_dict['Blue']

    def fit_valley(self, xmin, xmax):
        """Measure the inter-peak valley in ``[xmin, xmax]``; return summary."""
        flux, vpeak, vmin, fwhm, bhms, rhms, A_qs, fmin, fmax, Agks = \
            self.measure_flux(xmin, xmax, iters=self.num_realizations)
        # alpha must lie in [0, 1]; the former 5. made matplotlib raise.
        self._replace_extra_plot(
            'Valley',
            self.ax.axvspan(xmin, xmax, color=self._colors['Valley'],
                            alpha=.5, zorder=0))
        pct = self._percentiles
        self.summary_dict['Valley'] = {
            'range': (xmin, xmax),
            'minflux': np.percentile(fmin, pct),
            'vmin': np.percentile(vmin, pct)
        }
        print('Fitting valley')
        print("The following has been added/overwritten" +
              " in the summary_dict['Valley'].")
        return self.summary_dict['Valley']

    def fit_active(self, xmin, xmax):
        """Measure ``[xmin, xmax]`` as the component selected in the radio."""
        if self.radio.value_selected == 'Blue':
            self.fit_blue(xmin, xmax)
        elif self.radio.value_selected == 'Red':
            self.fit_red(xmin, xmax)
        elif self.radio.value_selected == 'Valley':
            self.fit_valley(xmin, xmax)

    fitfuncs = {'Red': fit_red, 'Blue': fit_blue, 'Valley': fit_valley}

    def save_summary(self):
        # TODO: Implement something.
        pass

    def save_summary_table(self, path="summarytable.ecsv"):
        """Write the measurements of the active components to ECSV.

        Every quantity becomes a row with columns 'Low' (median minus 16th
        percentile), 'Median' and 'High' (84th percentile minus median).
        Absolute flux and equivalent width are measured first when missing.
        Returns the written astropy Table.
        """
        d = self.summary_dict

        def spread(p):
            # Percentiles are (2.5, 16, 50, 84, 97.5): indices 1, 2, 3.
            return p[2] - p[1], p[2], p[3] - p[2]

        # Make sure absolute flux is measured (TODO make better)
        if "AbsFlux" not in d:
            d["AbsFlux"] = self.absolute_flux(iters=self.num_realizations)
        # equivalent_width() stores 'EW_lya' (this used to test the typo
        # 'EW_lua', so the EW was recomputed on every call).
        if "EW_lya" not in d:
            self.equivalent_width(iters=self.num_realizations)
        active = self._peaks_active
        # Create interim output dictionary
        outdict = {}
        if "Red" in active:
            red = d['Red']
            outdict.update({
                'fwhm_red': spread(red['fwhm']),
                'f_red': spread(red['fmax']),
                'l_red': spread(red['flux']),
                'A_red': spread(red['Asymmetry']),
                'A_GK': spread(red['Asym_Gr_Ko']),
                'v_red': spread(red['vpeak']),
            })
        if "Blue" in active:
            blue = d['Blue']
            outdict.update({
                'fwhm_blue': spread(blue['fwhm']),
                'f_blue': spread(blue['fmax']),
                'l_blue': spread(blue['flux']),
                'v_blue': spread(blue['vpeak']),
            })
        if "Valley" in active:
            valley = d['Valley']
            outdict.update({
                'v_valley': spread(valley['vmin']),
                'f_valley': spread(valley['minflux']),
            })
        outdict.update({'AbsFlux': spread(d['AbsFlux']),
                        'EW_Lya': spread(d["EW_lya"])})
        # Now make it a dataframe
        outframe = pd.DataFrame.from_dict(outdict)
        outframe.set_index(pd.Index(['Low', 'Median', 'High'],
                                    name=self.galaxy_name),
                           inplace=True)
        outframe = outframe.T
        # Now make it a Table
        outtable = Table.from_pandas(outframe.reset_index())
        outtable.meta['identifier'] = self.galaxy_name
        outtable.write(path, format='ascii.ecsv')
        return outtable

    @property
    def _wave_diffs(self):
        """Pixel widths; the last pixel reuses the previous width."""
        diffs = self.wave[1:] - self.wave[:-1]
        diffs = np.append(diffs, diffs[-1])
        return diffs

    @property
    def _peaks_active(self):
        """Names of the components currently ticked as present."""
        return self._peaks[self._peak_on]

    @property
    def refwave(self):
        """Observed-frame Ly-alpha wavelength, 1215.67 * (1 + z)."""
        return 1215.67 * (1 + self.z)

    @property
    def ref_wl(self):
        """Alias of ``refwave``."""
        return 1215.67 * (1 + self.z)

    def __call__(self):
        self.fig.show()
예제 #30
0
class EGG_Control_Panel:
    def __init__(self, Egg):

        # =============================================================================
        #         Initialize Egg
        # =============================================================================

        self.Egg = Egg

        # =============================================================================
        #         Initialize Figure and Main Axis
        # =============================================================================

        self.figure = plt.figure(num=self.Egg.EGG_figure_Name +
                                 " Control Panel",
                                 figsize=[5.5, 8],
                                 clear=True)
        self.main_ax = self.figure.add_axes([0, 0, 1, 1],
                                            label="main_ax",
                                            facecolor='white')

        self.main_ax.tick_params(
            axis='x',  # changes apply to the x-axis
            which='both',  # both major and minor ticks are affected
            bottom=False,  # ticks along the bottom edge are off
            top=False,  # ticks along the top edge are off
            left=False,
            right=False,
            labelbottom=False)  # labels along the bottom edge are off

        self.main_ax.tick_params(
            axis='y',  # changes apply to the x-axis
            which='both',  # both major and minor ticks are affected
            left=False,
            right=False,
            labelleft=False)  # labels along the bottom edge are off

        self.main_ax.margins(x=0)
        self.main_ax.set_xlim([0, 1])
        self.main_ax.set_ylim([0, 1])

        # =============================================================================
        #       Create Buttons
        # =============================================================================

        Button_width = .2
        Button_height = .06
        Edge_width = 0.05
        Button_spacing = 0.03
        Top_gap = 0
        Bottom_Gap = .5
        Left_gap = 0
        Right_gap = 0

        Button_count = 0

        Button_Box = patches.Rectangle(
            (Edge_width / 2, Edge_width + Bottom_Gap), 1 - Edge_width,
            1 - Bottom_Gap - 1.5 * Edge_width)
        #        Button_Box = patches.Rectangle((0.15, 0.5), .25, .25)
        self.main_ax.add_patch(Button_Box)

        max_buttons_per_column = (int)(math.floor(
            (1 + Button_spacing - 2 * Edge_width - Top_gap - Bottom_Gap) /
            (Button_height + Button_spacing)))

        max_buttons_per_row = (int)(math.floor(
            (1 + Button_spacing - 2 * Edge_width - Left_gap - Right_gap) /
            (Button_width + Button_spacing)))

        def Button_params():
            column_number = math.floor(Button_count / max_buttons_per_column)
            row_number = Button_count - max_buttons_per_column * column_number

            left = Edge_width + Left_gap + (
                (Button_width + Button_spacing) * column_number)
            bottom = 1 - Edge_width - Top_gap - Button_height - (
                Button_height + Button_spacing) * row_number

            if Button_count > max_buttons_per_column * max_buttons_per_row:
                raise Exception("TOO MANY BUTTONS!!!")

            return [left, bottom, Button_width, Button_height]

        self.Reset_Button_ax = self.figure.add_axes(Button_params(),
                                                    label="Reset_Button_ax",
                                                    facecolor='white')
        Button_count += 1

        self.Reset_Button = Button(
            self.Reset_Button_ax,
            label="Reset",
        )
        self.Reset_Button.on_clicked(self.Egg.Reset_GUI(Reset_Tone=True))

        self.Play_Button_ax = self.figure.add_axes(Button_params(),
                                                   label="Play_Button_ax",
                                                   facecolor='white')
        Button_count += 1
        self.Play_Button = Button(
            self.Play_Button_ax,
            label="Play",
        )
        self.Play_Button.on_clicked(self.Egg.Play)

        self.Loop_Button_ax = self.figure.add_axes(Button_params(),
                                                   label="Loop_Button_ax",
                                                   facecolor='white')
        Button_count += 1
        self.Loop_Button = Button(
            self.Loop_Button_ax,
            label="Loop",
        )
        self.Loop_Button.on_clicked(self.Egg.Loop)

        self.Stop_Button_ax = self.figure.add_axes(Button_params(),
                                                   label="Stop_Button_ax",
                                                   facecolor='white')
        Button_count += 1
        self.Stop_Button = Button(
            self.Stop_Button_ax,
            label="Stop",
        )
        self.Stop_Button.on_clicked(self.Egg.Stop)

        # =============================================================================
        #       Create Misc. Box
        # =============================================================================

        Misc_Box = patches.Rectangle((Edge_width / 2, Edge_width / 2),
                                     1 - Edge_width,
                                     Bottom_Gap,
                                     color='g')
        self.main_ax.add_patch(Misc_Box)

        self.Plot_Toggle_ax = self.figure.add_axes(
            [Edge_width, Edge_width, .25, .18],
            label="Plot_Toggle_ax",
            facecolor='white')
        self.Plot_Toggle = CheckButtons(ax=self.Plot_Toggle_ax,
                                        labels=[
                                            "Plot Wave", "Plot Savgol",
                                            "Plot Linear", "Plot Selectors"
                                        ],
                                        actives=[True, True, True, True])

        def Update_Plot_Toggle(val):
            Toggle_Bools = self.Plot_Toggle.get_status()
            self.Egg.Plot_Wave = Toggle_Bools[0]
            self.Egg.Plot_Savgol = Toggle_Bools[1]
            self.Egg.Plot_Linear = Toggle_Bools[2]
            self.Egg.Plot_Selectors = Toggle_Bools[3]
            self.Egg.Update_Canvas()(0)

        self.Plot_Toggle.on_clicked(Update_Plot_Toggle)

        self.Savgol_Mode_Select_ax = self.figure.add_axes(
            [Edge_width, Edge_width + .25, .25, .18],
            label="Savgol_Mode_Select_ax",
            facecolor='white')
        self.Savgol_Mode_Select = RadioButtons(self.Savgol_Mode_Select_ax,
                                               labels=[
                                                   "wrap",
                                                   "mirror",
                                                   "nearest",
                                                   "constant",
                                               ])

        def Update_Savgol_Mode_Select(val):
            active_mode = self.Savgol_Mode_Select.value_selected
            self.Egg.Update_Savgol_Mode(active_mode)

        self.Savgol_Mode_Select.on_clicked(Update_Savgol_Mode_Select)