def slider(self, N):
    """Create the scroll slider and the harmonic check boxes, then show the figure.

    The slider runs from 0 to ``len(self.lista_hora) - N`` and reports changes
    to ``self.bar``; the check boxes report clicks to ``self.func``.

    Parameters
    ----------
    N
        Subtracted from ``len(self.lista_hora)`` to obtain the slider maximum
        (presumably the number of samples displayed at once -- confirm against
        ``self.bar``).
    """
    # Thin horizontal bar along the bottom edge of the figure.
    barpos = plt.axes([0.18, 0.01, 0.70, 0.03], facecolor="gray")
    scroll = Slider(barpos, '', 0, len(self.lista_hora) - N, valinit=0)
    scroll.on_changed(self.bar)

    # One box for the fundamental plus 'Arm. 1' .. 'Arm. 11', all ticked.
    rax = plt.axes([0.01, 0.30, 0.12, 0.65])
    labels = ('C. Fund.',) + tuple('Arm. %d' % k for k in range(1, 12))
    check = CheckButtons(rax, labels, (True,) * len(labels))
    check.on_clicked(self.func)

    # Widgets stop responding once they are garbage collected, which happens
    # as soon as this method returns if plt.show() does not block (interactive
    # mode).  Holding them on the instance keeps them alive with the figure.
    self._slider_widget = scroll
    self._check_widget = check

    self.bar(0)  # draw the initial window position
    plt.show()
class button_manager:
    ''' Handles some missing features of matplotlib check buttons:
    keeps exactly one button of a CheckButtons group switched on.

    on init:
        creates the buttons, links them to the button_click routine and
        calls call_on_click with the active index and firsttime=True
    on click:
        maintains the single-button-on state, then calls call_on_click
        with the active index
    '''

    def __init__(self, fig, dim, labels, init, call_on_click):
        '''
        fig:    figure the buttons belong to (only stored)
        dim:    (list) [leftbottom_x, bottom_y, width, height]
        labels: (list) for example ['1','2','3','4','5','6']
        init:   (list) for example [True, False, False, False, False, False]
        call_on_click: callable invoked as call_on_click(index, firsttime=True)
                once here and as call_on_click(index) after every click

        Raises ValueError if no entry of the initial status is True.
        '''
        self.fig = fig
        self.ax = plt.axes(dim)  # lx, by, w, h
        self.init_state = init
        self.call_on_click = call_on_click
        self.button = CheckButtons(self.ax, labels, init)
        self.button.on_clicked(self.button_click)
        self.status = self.button.get_status()
        self.call_on_click(self.status.index(True), firsttime=True)

    def reinit(self):
        ''' Switch back to the initially active button.

        self.status must still describe the currently active button while the
        initial button is toggled: the toggle fires button_click, which turns
        the button recorded in self.status off (or back on when it is the
        initial button itself) and then refreshes self.status.  Overwriting
        self.status with the initial state first made button_click undo the
        toggle, so a different current selection was never reset.
        '''
        self.button.set_active(self.init_state.index(True))

    def button_click(self, event):
        ''' Maintain the one-on state after a button was toggled.

        If the button that was already on is clicked it is simply re-enabled.
        '''
        # Suppress events so the corrective toggle does not re-enter here.
        self.button.eventson = False
        self.button.set_active(
            self.status.index(True))  # turn off old or re-enable if same
        self.button.eventson = True
        self.status = self.button.get_status()
        self.call_on_click(self.status.index(True))
def genCombineResalePrices(yearEnter, monthNum, p_leaseEnter, flatType):
    """Collect the resale prices of every ticked town for one flat type.

    Rows of the module-level ``data`` table are gathered month by month from
    ``yearEnter``/``monthNum`` up to the current month, then narrowed to rows
    whose ``remaining_lease`` exceeds ``int(p_leaseEnter) + 1`` and whose
    ``flat_type`` equals ``flatType``.  The narrowed rows are published as the
    global ``towndata``; the per-town rows of ticked towns are stored in the
    global ``townSelect``.

    Parameters
    ----------
    yearEnter : str
        Four-digit start year.
    monthNum : str
        Start month; only its last two characters are used.
    p_leaseEnter
        Lease threshold, converted with ``int``.
    flatType
        Value compared against the ``flat_type`` field.

    Returns
    -------
    tuple
        ``(resale_prices_combined, townName)``: the price arrays and the
        labels of the ticked towns, in check-box order.
    """
    import numpy as np
    from datetime import datetime
    from matplotlib.widgets import CheckButtons
    global townSelect, towndata

    ## Period: from the entered year/month up to the current month
    month_suffix = monthNum[-2:]
    first_code = yearEnter + month_suffix
    first_key = yearEnter + "-" + month_suffix
    last_code = datetime.now().strftime('%Y%m')
    dataPeriod = data[data['month'] == first_key]
    # Walk the YYYYMM integers; codes whose last two digits are 13..99 are
    # not months and are skipped.
    for code in range(int(first_code) + 1, int(last_code) + 1):
        digits = str(code)
        if int(digits[-2:]) >= 13:
            continue
        month_rows = data[data['month'] == digits[0:4] + "-" + digits[-2:]]
        dataPeriod = np.append(dataPeriod, month_rows)

    ## Lease and flat-type filters
    lease_floor = int(p_leaseEnter) + 1
    long_lease = dataPeriod[dataPeriod['remaining_lease'] > lease_floor]
    towndata = long_lease[long_lease['flat_type'] == flatType]

    ## Town selection, read from a check-box group built from the lines
    line_labels = [str(ln.get_label()) for ln in checkBoxLines]
    line_visibility = [ln.get_visible() for ln in checkBoxLines]
    check = CheckButtons(rax, line_labels, line_visibility)
    checkButtonThemes(check)
    status = check.get_status()

    resale_prices = towndata['resale_price']
    town_resale_prices = np.zeros(len(townList), object)
    for t in townIndex:
        town_resale_prices[t] = resale_prices[towndata['town'] == townList[t]]

    townSelect = np.zeros(len(townList), object)
    townName = []
    resale_prices_combined = []
    for pos, ticked in enumerate(status):
        if not ticked:
            continue
        # Ticked towns are packed from slot 0 upwards; the next free slot is
        # the number of towns collected so far.
        townSelect[len(townName)] = towndata[towndata['town'] == townList[pos]]
        townName.append(line_labels[pos])
        resale_prices_combined.append(town_resale_prices[pos])
    return resale_prices_combined, townName
def genCombineResalePrices(yearEnter, quarterNum, flatType):
    """Split the price and quarter columns of one flat type by town.

    Rows of the module-level ``data`` table are gathered for the keys
    '2007-02', '2007-03', '2007-04' and then for every quarter of each year
    after ``yearEnter`` up to the current year.  The rows matching
    ``flatType`` are published as the global ``towndata``.

    Parameters
    ----------
    yearEnter
        Start year, converted with ``int``; rows are taken from the year
        after it onwards.
    quarterNum
        Not consulted by the selection (it is overwritten with '2' and never
        read afterwards).
    flatType
        Value compared against the ``flat_type`` field.

    Returns
    -------
    tuple
        ``(town_prices, town_quarter)``: object arrays holding, per entry of
        ``townList``, that town's prices and quarters.
    """
    import numpy as np
    from datetime import datetime
    from matplotlib.widgets import CheckButtons
    global townSelect, towndata

    ## Period selection
    first_year = int(yearEnter)
    quarterNum = '2'
    this_year = int(datetime.now().strftime('%Y'))
    # NOTE(review): these three keys use 'YYYY-0N' while the generated keys
    # use 'YYYY-QN' -- confirm which form the 'quarter' field really holds.
    period_keys = ['2007-02', '2007-03', '2007-04']
    for year in range(first_year + 1, this_year + 1):
        # year is always greater than first_year here, so every quarter of
        # every iterated year is taken.
        period_keys.extend(str(year) + "-Q" + str(q) for q in range(1, 5))

    dataPeriod = data[data['quarter'] == period_keys[0]]
    for key in period_keys[1:]:
        dataPeriod = np.append(dataPeriod, data[data['quarter'] == key])
    towndata = dataPeriod[dataPeriod['flat_type'] == flatType]

    ## Check-box group rebuilt from the current line labels and visibility
    line_labels = [str(ln.get_label()) for ln in checkBoxLines]
    line_visibility = [ln.get_visible() for ln in checkBoxLines]
    check = CheckButtons(rax, line_labels, line_visibility)
    checkButtonThemes(check)
    status = check.get_status()

    ## Per-town split
    town_count = len(townList)
    town_prices = np.zeros(town_count, object)
    town_quarter = np.zeros(town_count, object)
    for t in townIndex:
        in_town = towndata['town'] == townList[t]
        town_prices[t] = towndata['price'][in_town]
        town_quarter[t] = towndata['quarter'][in_town]
    return town_prices, town_quarter
class CurrentView(object):
    """Interactive matplotlib viewer for four-channel ``.npy`` recordings.

    ``fastview.npy`` is shown as the overview (``num == -1``) on an x range
    of 0..1.  Chunk ``k`` (0..99) is read from ``data<k>.npy`` and shown on
    ``k/100 .. (k+1)/100``.

    Controls
    --------
    right / left arrow : next / previous chunk
    d                  : back to the overview
    double click       : (overview only) open the chunk under the cursor
    check boxes        : show or hide individual channels
    """

    def __new__(cls, path='../numpydata/'):
        return super(CurrentView, cls).__new__(cls)

    def __init__(self, path='../numpydata/'):
        """Build the figure and load the overview data.

        path: directory prefix of the ``.npy`` files.  It is concatenated
            with the file name as-is, so it must end with a path separator.
        """
        self.path = path
        self.visible_channels = [0, 1, 2, 3]
        self.channels = [0, 1, 2, 3]
        self.num = -1  # default view is the overview ("fast") data
        self.fig, self.ax = plt.subplots()
        self.ax.set_title(str(self.visible_channels))
        plt.subplots_adjust(left=0.25, right=0.9, top=0.9, bottom=0.1)
        self.toolbar = self.fig.canvas.manager.toolbar
        self.rax = plt.axes([0.05, 0.4, 0.1, 0.15])
        self.check = CheckButtons(self.rax, self.channels, [1, 1, 1, 1])
        self.check.on_clicked(self.func)
        self.lines = [self.ax.plot([], [])[0] for _ in self.channels]
        self.fast = np.load(self._data_file('fastview.npy'))
        # Per-channel (max, min) pairs; kept for callers, not used here.
        self.limits = {
            channel: (np.max(self.fast[channel]), np.min(self.fast[channel]))
            for channel in self.channels
        }
        self.lims = self._overview_limits()
        self._show_overview()
        self.ax.set_xlabel('-1')
        self.btn = self.fig.canvas.mpl_connect('key_press_event',
                                               self.on_press)
        self.clk = self.fig.canvas.mpl_connect('button_press_event',
                                               self.on_click)
        self.fig.canvas.draw()

    def _data_file(self, name):
        """Return the location of data file *name* under ``self.path``."""
        return self.path + name

    def _overview_limits(self):
        """Return (min, max) of the overview data over the visible channels."""
        lows = [np.min(self.fast[ch]) for ch in self.visible_channels]
        highs = [np.max(self.fast[ch]) for ch in self.visible_channels]
        return (np.min(lows), np.max(highs))

    def _show_overview(self):
        """Put the overview data on the visible lines and reset both axes."""
        for ch in self.visible_channels:
            self.lines[ch].set_data(np.linspace(0, 1, len(self.fast[ch])),
                                    self.fast[ch])
        self.ax.set_xlim(0, 1)
        self.ax.set_ylim(self.lims)

    def show(self):
        """Hand control to matplotlib's event loop."""
        plt.show()

    def on_modify_visible_channels(self):
        """Apply ``self.visible_channels`` to the title, lines and y range."""
        self.ax.set_title(str(self.visible_channels))
        # With every channel unticked there is nothing to take a minimum of
        # (np.min([]) raises), so the previous y range is kept.
        if self.num == -1 and len(self.visible_channels) > 0:
            self.lims = self._overview_limits()
        for ch in self.channels:
            self.lines[ch].set_visible(bool(ch in self.visible_channels))
        self.ax.set_ylim(self.lims)
        self.fig.canvas.draw()

    def on_press(self, event):
        """Keyboard navigation: right/left step chunks, 'd' shows the overview."""
        sys.stdout.flush()
        if event.key == 'right':
            if self.num + 1 <= 99:
                self.num += 1
                self.load_data()
        elif event.key == 'left':
            if self.num - 1 >= 0:
                self.num -= 1
                self.load_data()
        elif event.key == 'd':
            self.num = -1
            self.load_data()

    def on_click(self, event):
        """Open the chunk under a double click made in the overview."""
        if event.inaxes == self.rax:
            return  # clicks on the check boxes are handled by the widget
        if event.xdata is None or event.ydata is None:
            # The click landed outside every axes: there are no data
            # coordinates to act on or to report.
            return
        if event.dblclick and self.num == -1:
            chunk = int(round(event.xdata * 100)) - 1
            # Keep the chunk inside -1..99, the same range on_press allows,
            # even when the view was panned beyond 0..1.
            self.num = min(max(chunk, -1), 99)
            self.load_data()
        print(
            '%s click: button = %d, x = %d, y = %d, xdata = %f, ydata = %f' %
            ('double' if event.dblclick else 'single', event.button, event.x,
             event.y, event.xdata, event.ydata))

    def func(self, label):
        """Check-box callback: rebuild the visible channel list and redraw."""
        ticked = self.check.get_status()
        indexs = [ch for ch in self.channels if ticked[ch]]
        self.visible_channels = np.array(self.channels)[
            np.array(indexs, dtype=int)]
        self.on_modify_visible_channels()

    def load_data(self):
        """Show the data set selected by ``self.num``.

        The x label records which data set is currently drawn, so a chunk is
        only read from disk when the selection actually changed.  Selecting
        the overview always resets both axis ranges.
        """
        if self.ax.get_xlabel() != str(self.num):
            if self.num == -1:
                self._show_overview()
            else:
                data = np.load(
                    self._data_file('data' + str(self.num) + '.npy'))
                left = self.num / 100.0
                right = (self.num + 1) / 100.0
                for ch in self.visible_channels:
                    self.lines[ch].set_data(
                        np.linspace(left, right, len(data[ch])), data[ch])
                self.ax.set_xlim(left, right)
                self.ax.set_ylim(self.lims)
            self.ax.set_xlabel(str(self.num))
        elif self.num != -1:
            return  # the requested chunk is already on screen
        if self.num == -1:
            self.ax.set_xlim(0, 1)
            self.ax.set_ylim(self.lims)
        self.fig.canvas.draw()
class WaveGraphBase(ABC):
    """Base class for an animated wave plot with slider and check-box controls.

    The figure is split into a graph cell (top) and a control cell (bottom).
    The control cell holds one check-box group and one slider per entry of
    ``slider_data``.  Subclasses implement ``update`` (called by every
    control) and ``animate`` (called for every animation frame).
    """

    def __init__(self, name, granularity, x_range, x_offset, y_range,
                 y_offset, time_factor, line_thickness, waves, slider_data,
                 checkbox_data):
        """Build the figure, the empty wave lines and the controls.

        name: window title
        granularity: number of samples in ``self.x_data``
        x_range, x_offset, y_range, y_offset: define the axis limits
        time_factor, line_thickness: stored for subclasses
        waves: one line is created per wave
        slider_data: sequence of dicts with keys
            "name", "min", "max", "init", "step"
        checkbox_data: sequence of dicts with keys "name", "init"
        """
        self.granularity = granularity
        self.x_range = x_range
        self.x_offset = x_offset
        self.y_range = y_range
        self.y_offset = y_offset
        self.waves = waves
        self.slider_data = slider_data
        self.checkbox_data = checkbox_data
        self.time_factor = time_factor
        self.line_thickness = line_thickness

        self.fig = plt.figure(figsize=(16, 9))
        # The canvas-level set_window_title was removed from recent
        # Matplotlib releases; the manager-level method is used instead.
        self.fig.canvas.manager.set_window_title(name)

        # --- graph cell -----------------------------------------------
        self.main_grid = gridspec.GridSpec(2, 1)
        self.graph_cell = plt.subplot(self.main_grid[0, :])
        self.graph_cell.set(xlim=(-self.x_range - self.x_offset,
                                  self.x_range - self.x_offset),
                            ylim=(-self.y_range - self.y_offset,
                                  self.y_range - self.y_offset))
        # Sample three times the visible x range.
        self.x_data = np.linspace(-3 * self.x_range - 3 * self.x_offset,
                                  3 * self.x_range - 3 * self.x_offset,
                                  self.granularity)
        # One independent list per wave ([[]] * n would repeat one list).
        self.y_data = [[] for _ in range(len(self.waves))]
        # NOTE(review): the line width is fixed at 5 although
        # line_thickness is stored above -- confirm whether it should be
        # applied here.
        self.lines = [
            plt.plot([], [], linewidth=5)[0] for _ in range(len(self.waves))
        ]
        self.patches = self.lines

        # --- control cell: check boxes --------------------------------
        self.control_cell = self.main_grid[1, :]
        self.control_grid = gridspec.GridSpecFromSubplotSpec(
            1, 7, self.control_cell)
        self.checkbox_cell = self.control_grid[0, 0]
        self.checkbox_grid = gridspec.GridSpecFromSubplotSpec(
            1, 1, self.checkbox_cell)
        self.checkboxes = []
        self.checkboxAx = plt.subplot(self.checkbox_grid[0, 0:1])
        self.checkbox = CheckButtons(
            self.checkboxAx, tuple(x["name"] for x in self.checkbox_data),
            tuple(x["init"] for x in self.checkbox_data))
        self.checkbox.on_clicked(self.update)
        self.checkboxes_ticked = self.checkbox.get_status()

        # --- control cell: sliders ------------------------------------
        self.slider_cell = self.control_grid[0, 2:6]
        self.sliders = []
        self.slider_grid = None
        # A grid spec with zero rows is rejected, so the slider grid is only
        # built when there is at least one slider to place.
        if len(self.slider_data) > 0:
            self.slider_grid = gridspec.GridSpecFromSubplotSpec(
                len(self.slider_data), 1, self.slider_cell)
            for row, spec in enumerate(self.slider_data):
                self.sliderAx = plt.subplot(self.slider_grid[row, 0])
                self.slider = Slider(self.sliderAx,
                                     spec["name"],
                                     spec["min"],
                                     spec["max"],
                                     valinit=spec["init"],
                                     valstep=spec["step"])
                self.sliders.append(self.slider)
        for slider in self.sliders:
            slider.on_changed(self.update)

    def init(self):
        """Animation init function: empty every line and return the artists."""
        for line in self.lines:
            line.set_data([], [])
        return self.patches

    def start(self):
        """Create the animation and show the figure."""
        self.animation = animation.FuncAnimation(self.fig,
                                                 self.animate,
                                                 init_func=self.init,
                                                 frames=999999,
                                                 repeat=True,
                                                 interval=20,
                                                 blit=True)
        plt.show()

    @abstractmethod
    def update(self, event=None):
        """Register sliders and checkboxes"""

    @abstractmethod
    def animate(self, i):
        """Animation function"""
def genLineChartByTown(p_yearEnter, p_quarterEnter, p_flatType):
    """Show the HDB resale price line chart with town and flat-type controls.

    Median prices are read from
    ``data/median-resale-prices-for-registered-applications-by-town-and-flat-type.csv``.
    One line per town is plotted (initially hidden).  A check-box group
    toggles the towns and a radio group switches the flat type.  The call
    blocks in ``plt.show()``.

    Parameters
    ----------
    p_yearEnter
        Start year, converted with ``int``; rows are taken from the year
        after it onwards.
    p_quarterEnter
        Currently without effect: the period selection never consults the
        quarter.
    p_flatType
        Flat type shown first, compared against the ``flat_type`` column.

    Side effects
    ------------
    Sets the module globals ``txt``, ``flatType``, ``checkBoxLines``,
    ``axcolor``, ``towndata`` and (on a check-box click) ``index``.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    from matplotlib.widgets import RadioButtons, CheckButtons
    from datetime import datetime
    from matplotlib import cm
    global txt, flatType, checkBoxLines
    txt = None
    yearEnter = p_yearEnter
    quarterNum = p_quarterEnter
    flatType = p_flatType

    def mainTheme():
        """Apply the chart title and y label; publish the widget colour."""
        global axcolor
        axcolor = 'lightgoldenrodyellow'
        ax.set_title("HDB Resale Flat Price Line Chart", fontsize=30)
        ax.set_ylabel('Resale Price', fontsize=25)

    def get_status():
        """Return the visibility flag of every town line."""
        return [line.get_visible() for line in checkBoxLines]

    def periodRows(yearEnter):
        """Return the rows of ``data`` for the selected period."""
        startYear = int(yearEnter)
        endYear = int(datetime.now().strftime('%Y'))
        # NOTE(review): these three keys use 'YYYY-0N' while the generated
        # keys use 'YYYY-QN' -- confirm which form the file really holds.
        keys = ['2007-02', '2007-03', '2007-04']
        for year in range(startYear + 1, endYear + 1):
            keys.extend(str(year) + "-Q" + str(q) for q in range(1, 5))
        rows = data[data['quarter'] == keys[0]]
        for key in keys[1:]:
            rows = np.append(rows, data[data['quarter'] == key])
        return rows

    def splitByTown(rows):
        """Return per-town object arrays of the prices and quarters in *rows*."""
        prices = rows['price']
        quarter = rows['quarter']
        town_prices = np.zeros(len(townList), object)
        town_quarter = np.zeros(len(townList), object)
        for t in townIndex:
            in_town = rows['town'] == townList[t]
            town_prices[t] = prices[in_town]
            town_quarter[t] = quarter[in_town]
        return town_prices, town_quarter

    def genCombineResalePrices(yearEnter, quarterNum, flatType):
        """Select the rows of *flatType*, publish them as the global
        ``towndata`` and return them split by town.

        *quarterNum* is accepted for interface compatibility only.
        """
        global townSelect, towndata
        dataPeriod = periodRows(yearEnter)
        towndata = dataPeriod[dataPeriod['flat_type'] == flatType]

        ## Check-box group rebuilt from the current lines
        line_labels = [str(line.get_label()) for line in checkBoxLines]
        line_visibility = [line.get_visible() for line in checkBoxLines]
        check = CheckButtons(rax, line_labels, line_visibility)
        checkButtonThemes(check)
        return splitByTown(towndata)

    def updateCheckBox(line_labels, line_visibility, check):
        """Redraw the check-box axes from the current line visibility.

        The three arguments are ignored; labels and flags are re-read from
        ``checkBoxLines``.
        """
        rax.clear()
        line_labels = [str(line.get_label()) for line in checkBoxLines]
        line_visibility = [line.get_visible() for line in checkBoxLines]
        check = CheckButtons(rax, line_labels, line_visibility)
        checkButtonThemes(check)

    def func(label):
        """Check-box callback: toggle the line of the clicked town."""
        global index, towndata, checkBoxLines
        status = check.get_status()
        index = line_labels.index(label)
        checkBoxLines[index].set_visible(
            not checkBoxLines[index].get_visible())
        updateCheckBox(line_labels, line_visibility, check)
        towndata = towndata[towndata['flat_type'] == flatType]
        mainTheme()
        print(status)
        plt.draw()

    def checkButtonThemes(check):
        """Style the check-box rectangles."""
        # NOTE(review): newer Matplotlib releases no longer expose
        # ``rectangles`` on CheckButtons; the styling is skipped there
        # instead of failing.
        for r in getattr(check, 'rectangles', ()):
            r.set_alpha(0.8)
            r.set_width(0.025)
            r.set_edgecolor("k")

    ############## Main Process Start Here ###############
    data = np.genfromtxt(
        'data/median-resale-prices-for-registered-applications-by-town-and-flat-type.csv',
        skip_header=1,
        dtype=[('quarter', 'U7'), ('town', 'U30'), ('flat_type', 'U20'),
               ('price', 'i8')],
        delimiter=",",
        missing_values=['na', '-'],
        filling_values=[0])

    townList = list(set(data['town']))
    townList.sort()
    townIndex = np.arange(0, len(townList))
    colorList = cm.hsv(townIndex / float(max(townIndex + 10)))

    quarterNum = '2'
    initialPeriod = periodRows(yearEnter)
    town_prices, town_quarter = splitByTown(
        initialPeriod[initialPeriod['flat_type'] == flatType])

    fig, ax = plt.subplots(figsize=(19, 15))
    plt.subplots_adjust(left=0.25, bottom=0.25)
    mainTheme()

    # One initially hidden line per town.
    checkBoxLines = np.zeros(len(townList), object)
    print(town_prices[0])
    for l in townIndex:
        checkBoxLines[l], = ax.plot(town_quarter[l],
                                    town_prices[l],
                                    visible=False,
                                    c=colorList[l],
                                    label=townList[l])
        ax.set_xticklabels(town_quarter[l], rotation=90)

    patches = [
        mpatches.Patch(color=color, label=label)
        for label, color in zip(townList, colorList)
    ]
    # 'top right' is not a legend location; 'upper right' is the valid name.
    fig.legend(patches, townList, loc='upper right')

    # Check buttons for all plotted lines with their current visibility
    rax = plt.axes([0.05, 0.25, 0.13, 0.7], facecolor=axcolor)
    line_labels = [str(line.get_label()) for line in checkBoxLines]
    line_visibility = [line.get_visible() for line in checkBoxLines]
    check = CheckButtons(rax, line_labels, line_visibility)
    checkButtonThemes(check)
    check.on_clicked(func)
    status = check.get_status()
    print(status)

    radioax = plt.axes([0.05, 0.08, 0.11, 0.12], facecolor=axcolor)
    radio = RadioButtons(
        radioax,
        ('1-room', '2-room', '3-room', '4-room', '5-room', 'Executive'),
        active=3)

    def flatTypefunc(label):
        """Radio callback: replot every town for the chosen flat type."""
        global flatType
        flatType = label
        town_prices, town_quarter = genCombineResalePrices(
            yearEnter, "2", flatType)
        ax.clear()
        mainTheme()
        status = check.get_status()
        for l in townIndex:
            # A town's line is visible exactly when its box is ticked.
            checkBoxLines[l], = ax.plot(town_quarter[l],
                                        town_prices[l],
                                        visible=bool(status[l]),
                                        c=colorList[l],
                                        label=townList[l])
            ax.set_xticklabels(town_quarter[l], rotation=90)
        fig.canvas.draw_idle()

    radio.on_clicked(flatTypefunc)

    # Called for its side effect: it publishes the global ``towndata`` that
    # the check-box callback reads.
    genCombineResalePrices(yearEnter, quarterNum, flatType)
    plt.show()
class Confusion_Visualizer(): def __init__(self, train_confusion_path, valid_confusion_path, class_names=None): # Load data self.confusion_train = self._load_data(train_confusion_path) self.tfpn_train = confusions_to_TFPN(self.confusion_train) self.confusion_valid = self._load_data(valid_confusion_path) self.tfpn_valid = confusions_to_TFPN(self.confusion_valid) # Check whether dimensions match if self.confusion_train.shape[-1] != self.confusion_valid.shape[-1]: min_length = min(self.confusion_train.shape[-1], self.confusion_valid.shape[-1]) print('Training and Validation data not match (of shape {} and {} respectively), truncated as {}'. format(self.confusion_train.shape, self.confusion_valid.shape, min_length)) self.confusion_train = self.confusion_train[:, :, 0:min_length] self.confusion_valid = self.confusion_valid[:, :, 0:min_length] self.series_length = self.confusion_train.shape[-1] self.class_amount = self.confusion_train.shape[0] # Obtain class labels if class_names is None: self.class_names = [str(i) for i in range(self.class_amount)] elif isinstance(class_names, list) or isinstance(class_names, tuple): if len(class_names) == self.class_amount: self.class_names = class_names else: print('Class names ({}) mismatch with confusion ({})'.format(len(class_names), self.class_amount)) raise ValueError else: print('Class names unsupported (please use list, e.g. 
[''Benign'',''Tumor''])') raise ValueError # Process data to get performance self.plot_data = {} self.performance_train = self._process_data(self.confusion_train) self.performance_valid = self._process_data(self.confusion_valid) # Set initial value for control panel (used in framework) self.ctrl_cursor = 0 self.ctrl_class = 0 self.ctrl_smooth = 1 self.ctrl_trajectory_length = 1 self.ctrl_trajectory_mode = 'FP-FN' self.ctrl_hold_range = False self.ctrl_same_ratio = True # Draw framework self._draw_framework() # Draw control panel self._draw_control_panel() # Draw initial figure self._figure_initialize() def _draw_framework(self): # Used for drawing confusion matrix (heatmap) self.heatmap_fig = plt.figure(1, figsize=(8, 4)) plt.title('Heatmap') self.heatmap_train = plt.subplot(1, 2, 1) self.heatmap_train.set_title('Train') self.heatmap_train.set_xticks(np.arange(self.class_amount)) self.heatmap_train.set_yticks(np.arange(self.class_amount)) self.heatmap_train.set_xticklabels(self.class_names) self.heatmap_train.set_yticklabels(self.class_names) self.heatmap_train.set_ylabel('Ground Truth') self.heatmap_train.set_xlabel('Prediction') self.heatmap_valid = plt.subplot(1, 2, 2) self.heatmap_valid.set_title('Validation') self.heatmap_valid.set_xticks(np.arange(self.class_amount)) self.heatmap_valid.set_yticks(np.arange(self.class_amount)) self.heatmap_valid.set_xticklabels(self.class_names) self.heatmap_valid.set_yticklabels(self.class_names) self.heatmap_valid.set_ylabel('Ground Truth') self.heatmap_valid.set_xlabel('Prediction') # Used for drawing performance self.performance_fig = plt.figure(2, figsize=(10, 6)) plt.title('Performance') self.performance_global_p_r = plt.subplot(2, 2, 1) self.performance_global_p_r.set_title('Global') self.performance_class_p_r = plt.subplot(2, 2, 2) self.performance_class_p_r.set_title('Class: ' + self.class_names[int(self.ctrl_class)]) self.performance_global_a_f = plt.subplot(2, 2, 3) self.performance_class_f = plt.subplot(2, 2, 4) # 
Used for drawing FP/FN, etc. self.trace_fig = plt.figure(3, figsize=(6, 6)) self.trace_plot = plt.subplot(1, 1, 1) self.trace_plot.set_xlabel('FP') self.trace_plot.set_ylabel('FN') self.trace_plot.set_title('Class: ' + self.class_names[int(self.ctrl_class)]) # Used for drawing control panel self.controller = plt.figure(4, figsize=(4, 4)) def _draw_control_panel(self): ax = self.controller.add_axes([0.2, 0.90, 0.5, 0.05]) self.slider_cursor = DiscreteSlider(label='cursor', valmin=0, valmax=self.series_length - 1, ax=ax, increment=1, valinit=self.ctrl_cursor) self.slider_cursor.on_changed(self.controller_change_slider_cursor) ax = self.controller.add_axes([0.2, 0.75, 0.5, 0.05]) self.slider_class = DiscreteSlider(label='class', valmin=0, valmax=self.class_amount - 1, ax=ax, increment=1, valinit=self.ctrl_class) self.slider_class.on_changed(self.controller_change_slider_class) ax = self.controller.add_axes([0.2, 0.60, 0.5, 0.05]) self.slider_smooth = DiscreteSlider(label='smooth', valmin=1, valmax=self.series_length, ax=ax, increment=1, valinit=self.ctrl_smooth) self.slider_smooth.on_changed(self.controller_change_slider_smooth) ax = self.controller.add_axes([0.2, 0.45, 0.5, 0.05]) self.slider_trajectory = DiscreteSlider(label='trajectory', valmin=1, valmax=self.series_length, ax=ax, increment=1, valinit=self.ctrl_trajectory_length) self.slider_trajectory.on_changed(self.controller_change_slider_trajectory) ax = self.controller.add_axes([0.6, 0.1, 0.3, 0.25]) self.radiobuttons_norm = RadioButtons(ax, ('FP-FN', 'FPR-TPR', 'Recl-Prec', 'Spec-Sens'), active=0) self.radiobuttons_norm.on_clicked(self.controller_change_radiobuttons_norm) ax = self.controller.add_axes([0.3, 0.1, 0.3, 0.25]) self.checkbutton_hold = CheckButtons(ax, labels=['Hold on Range', 'Same Ratio'], actives=[False, True]) self.checkbutton_hold.on_clicked(self.controller_change_checkbutton_hold) def _load_data(self, obj): if os.path.isfile(obj): # read directly given a single file [conf_tensor] = 
load_variables(obj) return conf_tensor elif os.path.isdir(obj): # read 0.pkl ~ *.pkl given a directory path file_list = get_file_list(obj) file_amount = len(file_list) conf_list = [] for i in range(file_amount): [conf] = load_variables(os.path.join(obj, str(i) + '.pkl')) conf_list.append(conf) return np.concatenate(np.expand_dims(conf_list, axis=-1), axis=-1) elif isinstance(obj, np.ndarray): # return directly given np array return obj else: print('Invalid confusion source: {}'.format(obj)) raise ValueError # Confusions: c x c x N def _process_data(self, confusions): performance = {} performance['accuracy'] = [] performance['precision'] = [] performance['recall'] = [] performance['f1'] = [] for i in range(confusions.shape[-1]): current_result = confusion_to_all(confusions[:, :, i]) for metric in ['accuracy', 'precision', 'recall', 'f1']: performance[metric].append(current_result[metric]) for metric in ['accuracy', 'precision', 'recall', 'f1']: performance[metric] = np.concatenate(np.expand_dims(performance[metric], axis=-1), axis=-1) # If some performance is nan, it will replaced by 0 for metric in ['precision', 'recall', 'f1']: global_performance = np.mean(performance[metric], axis=0) global_performance[np.isnan(global_performance)] = 0 performance['global_' + metric] = global_performance for metric in ['accuracy', 'precision', 'recall', 'f1']: performance[metric][np.isnan(performance[metric])] = 0 performance['global_accuracy'] = performance['accuracy'] return performance # Figure initialization def _figure_initialize(self): self.element_heatmap_train_text_pool = [] self.element_heatmap_valid_text_pool = [] self._update_figure_heatmap() self.element_performance_global_precision_train = None self.element_performance_global_recall_train = None self.element_performance_global_precision_valid = None self.element_performance_global_recall_valid = None self.element_performance_global_accuracy_train = None self.element_performance_global_f1_train = None 
self.element_performance_global_accuracy_valid = None self.element_performance_global_f1_valid = None self._update_figure_global_performance() self.element_performance_class_precision_train = None self.element_performance_class_recall_train = None self.element_performance_class_precision_valid = None self.element_performance_class_recall_valid = None self.element_performance_class_f1_train = None self.element_performance_class_f1_valid = None self._update_figure_class_performance() self.element_trajectory_train = None self.element_trajectory_valid = None self._update_figure_trajectory() plt.show() # Event for slider cursor changed def controller_change_slider_cursor(self, event): new_value = int(self.slider_cursor.val) if self.ctrl_cursor != new_value: self.ctrl_cursor = new_value self._update_figure_heatmap() self._update_figure_trajectory() # Event for slider class changed def controller_change_slider_class(self, event): new_value = int(self.slider_class.val) if self.ctrl_class != new_value: self.ctrl_class = new_value self._update_figure_class_performance() self._update_figure_trajectory() # Event for slider smooth changed def controller_change_slider_smooth(self, event): new_value = int(self.slider_smooth.val) if self.ctrl_smooth != new_value: self.ctrl_smooth = new_value self._update_figure_global_performance() self._update_figure_class_performance() # Event for slider trajectory (length) changed def controller_change_slider_trajectory(self, event): new_value = int(self.slider_trajectory.val) if self.ctrl_trajectory_length != new_value: self.ctrl_trajectory_length = new_value self._update_figure_trajectory() # Event for radiobuttons changed (FP/FN normalization method) def controller_change_radiobuttons_norm(self, event): new_value = self.radiobuttons_norm.value_selected if self.ctrl_trajectory_mode != new_value: self.ctrl_trajectory_mode = new_value self._update_figure_trajectory() # Event for checkbox changed (hold FP/FN figure's range) def 
controller_change_checkbutton_hold(self, event): new_value = self.checkbutton_hold.get_status() self.ctrl_hold_range = new_value[0] self.ctrl_same_ratio = new_value[1] def _update_data_heatmap(self): ctrl_cursor = self.ctrl_cursor self.plot_data['heatmap_train'] = self.confusion_train[:, :, ctrl_cursor] / np.sum( self.confusion_train[:, :, ctrl_cursor], axis=1, keepdims=True) self.plot_data['heatmap_valid'] = self.confusion_valid[:, :, ctrl_cursor] / np.sum( self.confusion_valid[:, :, ctrl_cursor], axis=1, keepdims=True) def _update_data_global_performance(self): ctrl_smooth = self.ctrl_smooth for metric in ['precision', 'recall', 'f1', 'accuracy']: self.plot_data['performance_global_train_' + metric] = mean_filter_padded( self.performance_train['global_' + metric], ctrl_smooth) self.plot_data['performance_global_valid_' + metric] = mean_filter_padded( self.performance_valid['global_' + metric], ctrl_smooth) def _update_data_class_performance(self): ctrl_class = self.ctrl_class ctrl_smooth = self.ctrl_smooth for metric in ['precision', 'recall', 'f1']: self.plot_data['performance_class_train_' + metric] = mean_filter_padded( self.performance_train[metric][ctrl_class, :], ctrl_smooth) self.plot_data['performance_class_valid_' + metric] = mean_filter_padded( self.performance_valid[metric][ctrl_class, :], ctrl_smooth) def _update_data_trajectory(self): ctrl_cursor = self.ctrl_cursor ctrl_class = self.ctrl_class ctrl_trajectory = self.ctrl_trajectory_length ctrl_trajectory_mode = self.ctrl_trajectory_mode # Desired range: from cursor to (cursor+trajectory) lb = max(ctrl_cursor, 0) ub = min(ctrl_cursor + ctrl_trajectory, self.series_length) train_tp = self.tfpn_train[ctrl_class, 0, lb:ub] train_tn = self.tfpn_train[ctrl_class, 1, lb:ub] train_fp = self.tfpn_train[ctrl_class, 2, lb:ub] train_fn = self.tfpn_train[ctrl_class, 3, lb:ub] valid_tp = self.tfpn_valid[ctrl_class, 0, lb:ub] valid_tn = self.tfpn_valid[ctrl_class, 1, lb:ub] valid_fp = self.tfpn_valid[ctrl_class, 2, 
lb:ub] valid_fn = self.tfpn_valid[ctrl_class, 3, lb:ub] # Measurements refer to: http://www.davidsbatista.net/blog/2018/08/19/NLP_Metrics/ if ctrl_trajectory_mode == 'FP-FN': self.plot_data['trajectory_train_x'] = train_fp self.plot_data['trajectory_train_y'] = train_fn self.plot_data['trajectory_valid_x'] = valid_fp self.plot_data['trajectory_valid_y'] = valid_fn elif ctrl_trajectory_mode == 'FPR-TPR': self.plot_data['trajectory_train_x'] = train_fp / (train_tn + train_fp) self.plot_data['trajectory_train_y'] = train_tp / (train_tp + train_fn) self.plot_data['trajectory_valid_x'] = valid_fp / (valid_tn + valid_fp) self.plot_data['trajectory_valid_y'] = valid_tp / (valid_tp + valid_fn) elif ctrl_trajectory_mode == 'Recl-Prec': self.plot_data['trajectory_train_x'] = train_tp / (train_tp + train_fn) self.plot_data['trajectory_train_y'] = train_tp / (train_tp + train_fp) self.plot_data['trajectory_valid_x'] = valid_tp / (valid_tp + valid_fn) self.plot_data['trajectory_valid_y'] = valid_tp / (valid_tp + valid_fp) elif ctrl_trajectory_mode == 'Spec-Sens': self.plot_data['trajectory_train_x'] = train_tn / (train_tn + train_fp) self.plot_data['trajectory_train_y'] = train_tp / (train_tp + train_fn) self.plot_data['trajectory_valid_x'] = valid_tn / (valid_tn + valid_fp) self.plot_data['trajectory_valid_y'] = valid_tp / (valid_tp + valid_fn) else: raise ValueError pass def _update_figure_heatmap(self): # Fetch data self._update_data_heatmap() heatmap_data_train = self.plot_data['heatmap_train'] heatmap_data_valid = self.plot_data['heatmap_valid'] # Clean former text for item in self.element_heatmap_train_text_pool: item.remove() for item in self.element_heatmap_valid_text_pool: item.remove() self.element_heatmap_train_text_pool = [] self.element_heatmap_valid_text_pool = [] # Write new text & Draw for i in range(self.class_amount): for j in range(self.class_amount): obj = self.heatmap_train.text(j, i, '{:.2f}'.format(heatmap_data_train[i, j]), ha="center", va="center", 
color="w", fontsize=8) self.element_heatmap_train_text_pool.append(obj) self.heatmap_train.imshow(heatmap_data_train, cmap='jet') for i in range(self.class_amount): for j in range(self.class_amount): obj = self.heatmap_valid.text(j, i, '{:.2f}'.format(heatmap_data_valid[i, j]), ha="center", va="center", color="w", fontsize=8) self.element_heatmap_valid_text_pool.append(obj) self.heatmap_valid.imshow(heatmap_data_valid, cmap='jet') self.heatmap_fig.tight_layout() self.heatmap_fig.canvas.draw_idle() def _update_figure_global_performance(self): time_steps = np.arange(self.series_length) self._update_data_global_performance() if self.element_performance_global_precision_train is None: self.element_performance_global_precision_train, = \ self.performance_global_p_r.plot(time_steps, self.plot_data['performance_global_train_precision'], color='red', linestyle='-') else: self.element_performance_global_precision_train.set_ydata( self.plot_data['performance_global_train_precision']) if self.element_performance_global_recall_train is None: self.element_performance_global_recall_train, = \ self.performance_global_p_r.plot(time_steps, self.plot_data['performance_global_train_recall'], color='blue', linestyle='-') else: self.element_performance_global_recall_train.set_ydata(self.plot_data['performance_global_train_recall']) if self.element_performance_global_precision_valid is None: self.element_performance_global_precision_valid, = \ self.performance_global_p_r.plot(time_steps, self.plot_data['performance_global_valid_precision'], color='red', linestyle='--') else: self.element_performance_global_precision_valid.set_ydata( self.plot_data['performance_global_valid_precision']) if self.element_performance_global_recall_valid is None: self.element_performance_global_recall_valid, = \ self.performance_global_p_r.plot(time_steps, self.plot_data['performance_global_valid_recall'], color='blue', linestyle='--') # Legend, only added at first drawing self.performance_global_p_r.legend( 
['pre_train', 'rec_train', 'pre_valid', 'rec_valid']) else: self.element_performance_global_recall_valid.set_ydata(self.plot_data['performance_global_valid_recall']) if self.element_performance_global_accuracy_train is None: self.element_performance_global_accuracy_train, = \ self.performance_global_a_f.plot(time_steps, self.plot_data['performance_global_train_accuracy'], color='red', linestyle='-') else: self.element_performance_global_accuracy_train.set_ydata( self.plot_data['performance_global_train_accuracy']) if self.element_performance_global_f1_train is None: self.element_performance_global_f1_train, = \ self.performance_global_a_f.plot(time_steps, self.plot_data['performance_global_train_f1'], color='blue', linestyle='-') else: self.element_performance_global_f1_train.set_ydata(self.plot_data['performance_global_train_f1']) if self.element_performance_global_accuracy_valid is None: self.element_performance_global_accuracy_valid, = \ self.performance_global_a_f.plot(time_steps, self.plot_data['performance_global_valid_accuracy'], color='red', linestyle='--') else: self.element_performance_global_accuracy_valid.set_ydata( self.plot_data['performance_global_valid_accuracy']) if self.element_performance_global_f1_valid is None: self.element_performance_global_f1_valid, = \ self.performance_global_a_f.plot(time_steps, self.plot_data['performance_global_valid_f1'], color='blue', linestyle='--') # Legend, only added at first drawing self.performance_global_a_f.legend( ['acc_train', 'f1_train', 'acc_valid', 'f1_valid']) else: self.element_performance_global_f1_valid.set_ydata(self.plot_data['performance_global_valid_f1']) self.performance_fig.canvas.draw_idle() def _update_figure_class_performance(self): time_steps = np.arange(self.series_length) self._update_data_class_performance() self.performance_class_p_r.set_title('Class: ' + self.class_names[int(self.ctrl_class)]) if self.element_performance_class_precision_train is None: 
self.element_performance_class_precision_train, = \ self.performance_class_p_r.plot(time_steps, self.plot_data['performance_class_train_precision'], color='red', linestyle='-') else: self.element_performance_class_precision_train.set_ydata( self.plot_data['performance_class_train_precision']) if self.element_performance_class_recall_train is None: self.element_performance_class_recall_train, = \ self.performance_class_p_r.plot(time_steps, self.plot_data['performance_class_train_recall'], color='blue', linestyle='-') else: self.element_performance_class_recall_train.set_ydata(self.plot_data['performance_class_train_recall']) if self.element_performance_class_precision_valid is None: self.element_performance_class_precision_valid, = \ self.performance_class_p_r.plot(time_steps, self.plot_data['performance_class_valid_precision'], color='red', linestyle='--') else: self.element_performance_class_precision_valid.set_ydata( self.plot_data['performance_class_valid_precision']) if self.element_performance_class_recall_valid is None: self.element_performance_class_recall_valid, = \ self.performance_class_p_r.plot(time_steps, self.plot_data['performance_class_valid_recall'], color='blue', linestyle='--') # Legend, only added at first drawing self.performance_class_p_r.legend( ['pre_train', 'rec_train', 'pre_valid', 'rec_valid']) else: self.element_performance_class_recall_valid.set_ydata(self.plot_data['performance_class_valid_recall']) if self.element_performance_class_f1_train is None: self.element_performance_class_f1_train, = \ self.performance_class_f.plot(time_steps, self.plot_data['performance_class_train_f1'], color='blue', linestyle='-') else: self.element_performance_class_f1_train.set_ydata(self.plot_data['performance_class_train_f1']) if self.element_performance_class_f1_valid is None: self.element_performance_class_f1_valid, = \ self.performance_class_f.plot(time_steps, self.plot_data['performance_class_valid_f1'], color='blue', linestyle='--') # Legend, only 
added at first drawing self.performance_class_f.legend( ['f1_train', 'f1_valid']) else: self.element_performance_class_f1_valid.set_ydata(self.plot_data['performance_class_valid_f1']) self.performance_fig.canvas.draw_idle() def _update_figure_trajectory(self): self._update_data_trajectory() self.trace_plot.set_title('Class: ' + self.class_names[int(self.ctrl_class)]) data_length = len(self.plot_data['trajectory_train_x']) interpolation_values = (np.arange(data_length) + 1) / data_length blue_color = color_interpolation(np.array((1, 1, 1)), np.array((0, 0, 1)), interpolation_values) red_color = color_interpolation(np.array((1, 1, 1)), np.array((1, 0, 0)), interpolation_values) # To make the pure blue and pure red showed on legend, we reverse the array if blue_color.shape[0] > 1: blue_color = blue_color[::-1, :] if red_color.shape[0] > 1: red_color = red_color[::-1, :] if self.element_trajectory_train is not None: self.element_trajectory_train.remove() self.element_trajectory_train = self.trace_plot.scatter(self.plot_data['trajectory_train_x'][::-1], self.plot_data['trajectory_train_y'][::-1], c=blue_color) if self.element_trajectory_valid is not None: self.element_trajectory_valid.remove() self.element_trajectory_valid = self.trace_plot.scatter(self.plot_data['trajectory_valid_x'][::-1], self.plot_data['trajectory_valid_y'][::-1], c=red_color) # Adjust axis range if not self.ctrl_hold_range: max_range_x = np.max( np.concatenate((self.plot_data['trajectory_train_x'], self.plot_data['trajectory_valid_x']))) max_range_y = np.max( np.concatenate((self.plot_data['trajectory_train_y'], self.plot_data['trajectory_valid_y']))) max_range = np.max([max_range_x, max_range_y]) if self.ctrl_same_ratio: self.trace_plot.set_aspect('equal') self.trace_plot.set_xlim((0, 1.1 * max_range)) self.trace_plot.set_ylim((0, 1.1 * max_range)) else: self.trace_plot.set_aspect('auto') self.trace_plot.set_xlim((0, 1.1 * max_range_x)) self.trace_plot.set_ylim((0, 1.1 * max_range_y)) # Updata 
axis label if self.ctrl_trajectory_mode == 'FP-FN': self.trace_plot.set_xlabel('FP') self.trace_plot.set_ylabel('FN') elif self.ctrl_trajectory_mode == 'FPR-TPR': self.trace_plot.set_xlabel('FPR') self.trace_plot.set_ylabel('TPR') elif self.ctrl_trajectory_mode == 'Recl-Prec': self.trace_plot.set_xlabel('Recall') self.trace_plot.set_ylabel('Precision') elif self.ctrl_trajectory_mode == 'Spec-Sens': self.trace_plot.set_xlabel('Specificity') self.trace_plot.set_ylabel('Sensitivity ') else: raise ValueError self.trace_plot.legend(['Train', 'Valid']) self.trace_fig.canvas.draw_idle()
class CorrViewer(DataViewer):
    """Plots raw correlation data.

    You need to hold a reference to this object, otherwise it will not work
    in interactive mode.

    Parameters
    ----------
    semilogx : bool
        Whether to plot data with semilogx or not.
    shape : tuple of ints, optional
        Original frame shape. For non-rectangular data you must provide this
        in order to define the k step.
    size : int, optional
        If specified, perform log_averaging of data with the provided size
        parameter. If not given, no averaging is performed.
    norm : int, optional
        Normalization constant used in normalization.
    scale : bool, optional
        Scale constant used in normalization.
    mask : ndarray, optional
        A boolean array indicating which data elements were computed.
    """
    background = None
    variance = None

    def __init__(self, semilogx=True, shape=None, size=None, norm=None,
                 scale=False, mask=None):
        self.norm = norm
        self.scale = scale
        self.semilogx = semilogx
        self.shape = shape
        self.size = size
        self.computed_mask = mask
        if mask is not None:
            # With an explicit mask the k-space extent comes from the mask,
            # not from the data set later in set_data().
            self.kisize, self.kjsize = mask.shape

    def set_norm(self, value):
        """Sets norm parameter.

        The requested value is resolved against the currently loaded data
        through ``_default_norm_from_data``. The checkbox handler installed
        by ``_init_fig`` treats a ``ValueError`` from this call as "this
        combination is not available".
        """
        method = _method_from_data(self.data)
        self.norm = _default_norm_from_data(self.data, method, value)

    def _init_fig(self):
        """Build the base figure and add the normalization check buttons."""
        super()._init_fig()
        self.set_norm(self.norm)

        labels = ("structured", "subtracted", "weighted", "compensated")
        self.cax = plt.axes([0.44, 0.72, 0.2, 0.15])
        self.active = [
            bool(self.norm & NORM_STRUCTURED),
            bool(self.norm & NORM_SUBTRACTED),
            bool((self.norm & NORM_WEIGHTED) == NORM_WEIGHTED),
            bool((self.norm & NORM_COMPENSATED) == NORM_COMPENSATED),
        ]
        self.check = CheckButtons(self.cax, labels, self.active)

        def update(label):
            """Rebuild the norm flags from the check boxes and replot."""
            status = self.check.get_status()
            norm = NORM_STRUCTURED if status[0] else NORM_STANDARD
            optional_flags = (NORM_SUBTRACTED, NORM_WEIGHTED, NORM_COMPENSATED)
            for checked, flag in zip(status[1:], optional_flags):
                if checked:
                    norm |= flag
            try:
                self.set_norm(norm)
            except ValueError:
                # The combination was rejected: undo the click. Widget events
                # are suppressed while toggling back, otherwise set_active()
                # re-enters this handler and the mask/plot are recomputed
                # twice for a single click.
                events_were_on = self.check.eventson
                self.check.eventson = False
                try:
                    self.check.set_active(labels.index(label))
                finally:
                    self.check.eventson = events_were_on
            self.set_mask(int(round(self.kindex.val)), self.angleindex.val,
                          self.sectorindex.val, self.kstep)
            self.plot()

        self.check.on_clicked(update)

    def set_data(self, data, background=None, variance=None):
        """Sets correlation data.

        Parameters
        ----------
        data : tuple
            A data tuple (as computed by ccorr, cdiff, adiff, acorr functions)
        background : tuple or ndarray
            Background data for normalization. For adiff, acorr functions this
            is ndarray, for cdiff, ccorr, it is a tuple of ndarrays.
        variance : tuple or ndarray
            Variance data for normalization. For adiff, acorr functions this
            is ndarray, for cdiff, ccorr, it is a tuple of ndarrays.
        """
        self.data = data
        self.background = background
        self.variance = variance
        # All axes of the first data element except the last one.
        self.kshape = data[0].shape[:-1]
        if self.computed_mask is None:
            self.kisize, self.kjsize = self.kshape

    def _get_avg_data(self):
        """Return ``(t, avg)``: normalized data averaged over axis -2.

        The data are normalized with the current norm/scale/mask, optionally
        log-averaged when ``size`` is set, and then reduced with ``nanmean``
        over the second-to-last axis.
        """
        data = normalize(self.data, self.background, self.variance,
                         norm=self.norm, scale=self.scale, mask=self.mask)
        if self.size is not None:
            t, data = log_average(data, self.size)
        else:
            t = np.arange(data.shape[-1])
        return t, np.nanmean(data, axis=-2)
# Demo: three curves on one axes, with a CheckButtons panel toggling each one.
labels = ['red', 'blue', 'green']
actives = [True, False, False]  # only the first curve/checkbox starts enabled

fig, ax = plt.subplots()
[p1] = ax.plot(x, y1, color='red', label='red')
[p2] = ax.plot(x, y2, color='blue', label='blue', visible=False)
[p3] = ax.plot(x, y3, color='green', label='green', visible=False)
lines = [p1, p2, p3]

# Must run while the plot axes is still the current axes.
plt.axis([-2.5, 12, 0, 40])
plt.subplots_adjust(left=0.25, bottom=0.1, right=0.95, top=0.95)

# Panel rectangle is [left, bottom, width, height] in figure coordinates.
axCheckButton = plt.axes([0.03, 0.4, 0.15, 0.15])
chxbox = CheckButtons(axCheckButton, labels, actives)


def set_visible(label):
    """Flip the visibility of the curve whose checkbox label was clicked."""
    target = lines[labels.index(label)]
    target.set_visible(not target.get_visible())
    plt.draw()


# Keep the connection id so the callback can later be removed with
# chxbox.disconnect(cid).
cid = chxbox.on_clicked(set_visible)
print(chxbox.get_status())
plt.show()
class HeartbeatIntervalFinder(object):
    """Cursor for editing heartbeat signals for use of verification.

    Constructing an instance loads the ECG, seismo and phono traces of one
    recording, clips them to a window around the 2D echo time, and opens a
    matplotlib figure in which the shaded lower/upper bounds can be dragged,
    individual traces hidden or rescaled, and the chosen bounds saved.

    Parameters
    ----------
    files : mapping
        Nested structure indexed as ``files[folder_name][dosage][file_number]``
        whose entries provide ``"file_name"``, ``"echo_time"`` and
        ``"intervals"``.
    folder_name : str
        Key of the folder to work on.
    dosage : int
        Initial dosage key. The Previous/Next buttons cycle it through
        0..40 in steps of 10.
    file_number : int
        Key of the file within the dosage entry.
    area_around_echo_size : number
        Width (in the units of the time axis) of the window kept around the
        echo time.
    use_intervals : bool
        When true the initial bounds come from the stored ``"intervals"``
        entry and bounds are saved automatically before changing dosage.
    preloaded_signal : bool
        When true signals are read from the CSV files under
        ``data/Derived/signals`` instead of through ``hb.load_file_data``.
    save_signal : bool
        When true, ``save_signals`` is run before anything is loaded.
    """

    def __init__(self, files, folder_name="", dosage=0, file_number=1,
                 area_around_echo_size=240, use_intervals=False,
                 preloaded_signal=False, save_signal=False):
        super(HeartbeatIntervalFinder, self).__init__()

        # Load arguments
        self.files = files
        self.folder_name = folder_name
        self.dosage = dosage
        self.file_number = file_number
        self.file_name = files[folder_name][dosage][file_number]["file_name"]
        self.use_intervals = use_intervals
        self.area_around_echo_size = area_around_echo_size
        self.preloaded_signal = preloaded_signal

        # Which bound (if any) is currently being dragged, and whether the
        # view limits must be recomputed on the next redraw.
        self.update_point = None
        self.new_bounds = False

        if save_signal:
            self.save_signals()

        self.load_signals()       # signals and 2D echo time
        self.initialize_bounds()  # original bounds
        self.plot_signals()       # blocks in plt.show()

    def clip_signals(self):
        """Clip all traces to a window around the echo time and filter them.

        The window is ``area_around_echo_size`` wide, centred on the echo
        time when possible and otherwise pushed against the start or end of
        the recording. Recordings shorter than the window are kept whole
        (up to, but excluding, the final sample).
        """
        max_time = max(self.time)
        min_time = min(self.time)
        window = self.area_around_echo_size
        half_window = window / 2

        if (max_time - min_time) < window:
            start_time, stop_time = min_time, max_time
        elif self.echo_time - half_window < min_time:
            start_time, stop_time = min_time, min_time + window
        elif self.echo_time + half_window > max_time:
            start_time, stop_time = max_time - window, max_time
        else:
            start_time = self.echo_time - half_window
            stop_time = self.echo_time + half_window

        interval = range(int(np.searchsorted(self.time, start_time)),
                         int(np.searchsorted(self.time, stop_time)))

        self.interval_near_echo = interval
        self.time = self.time[interval]
        self.signal = self.signal[interval]
        self.seis = self.seis[interval]
        self.phono = self.phono[interval]

        # NOTE(review): the exact semantics of these filters live in `hb`;
        # the 59-61 band presumably targets mains interference -- confirm.
        self.signal = hb.bandpass_filter(time=self.time, signal=self.signal,
                                         freqmin=59, freqmax=61)
        self.seis = hb.bandpass_filter(time=self.time, signal=self.seis,
                                       freqmin=59, freqmax=61)
        self.phono = hb.bandpass_filter(time=self.time, signal=self.phono,
                                        freqmin=59, freqmax=61)

        self.signal = hb.lowpass_filter(time=self.time, signal=self.signal,
                                        cutoff_freq=10)
        self.seis = hb.lowpass_filter(time=self.time, signal=self.seis,
                                      cutoff_freq=50)

    def initialize_bounds(self):
        """Set ``lower_bound``/``upper_bound`` for the current recording.

        Stored intervals are used when ``use_intervals`` is set; otherwise a
        20-unit interval is centred on the echo time, shifted to stay inside
        the recording, or shrunk to the whole recording if it is shorter.
        """
        if self.use_intervals:
            stored = self.files[self.folder_name][self.dosage][self.file_number]["intervals"][1]
            self.lower_bound = stored[0]
            self.upper_bound = stored[1]
            return

        max_time = max(self.time)
        min_time = min(self.time)
        initial_interval_size = 20
        half_interval = initial_interval_size / 2

        if (max_time - min_time) < initial_interval_size:
            self.lower_bound = min_time
            self.upper_bound = max_time
        elif (self.echo_time - half_interval) < min_time:
            self.lower_bound = min_time
            self.upper_bound = min_time + initial_interval_size
        elif (self.echo_time + half_interval) > max_time:
            self.lower_bound = max_time - initial_interval_size
            self.upper_bound = max_time
        else:
            self.lower_bound = self.echo_time - half_interval
            self.upper_bound = self.echo_time + half_interval

    def _reset_view_limits(self):
        """Fit the axes to the time range and ECG amplitude with 10% margins."""
        sig_min = min(self.signal)
        sig_max = max(self.signal)
        time_span = self.time[-1] - self.time[0]
        self.ax.set_xlim(self.time[0] - 0.1 * time_span,
                         self.time[-1] + 0.1 * time_span)
        self.ax.set_ylim(sig_min - 0.1 * (sig_max - sig_min),
                         sig_max + 0.1 * (sig_max - sig_min))

    def _add_vertical_slider(self, rect, label, valmin, valmax, valinit):
        """Create a vertical slider that redraws the traces when moved."""
        slider = Slider(plt.axes(rect), label=label, valmin=valmin,
                        valmax=valmax, valinit=valinit, orientation='vertical')
        slider.on_changed(self.switch_signal)
        slider.valtext.set_visible(False)
        return slider

    def plot_signals(self):
        """Build the figure, widgets and event handlers, then show it."""
        # Create figure
        self.fig, self.ax = plt.subplots()
        self.ax.get_yaxis().set_visible(False)
        self.ax.set_xlabel("Time [s]")

        # Plot ECG, Phono and Seismo
        self.signal_line, = self.ax.plot(self.time, self.signal, label="ECG",
                                         linewidth=0.5, c="b")
        self.seis_line, = self.ax.plot(self.time, self.seis, label="Seismo",
                                       linewidth=0.5, c="r")
        self.phono_line, = self.ax.plot(self.time, self.phono, label="Phono",
                                        linewidth=0.5, c="g")
        self._reset_view_limits()

        # Echo line. axvline's ymin/ymax are axes fractions (0..1), not data
        # values, so the defaults are used to span the full axes height.
        self.echo_line = self.ax.axvline(self.echo_time, label="2D Echo Time",
                                         c="k", linewidth=2)
        self.ax.legend(loc="upper right")

        # Set endpoints
        self.bound_span = self.ax.axvspan(self.lower_bound, self.upper_bound,
                                          facecolor='g', alpha=0.25)
        self.bound_text_height = -0.1
        self.lower_bound_text = self.ax.text(
            self.lower_bound, self.bound_text_height,
            transform=self.ax.get_xaxis_transform(),
            s="Lower Bound\n" + str(self.lower_bound),
            fontsize=12, horizontalalignment='center')
        self.upper_bound_text = self.ax.text(
            self.upper_bound, self.bound_text_height,
            transform=self.ax.get_xaxis_transform(),
            s="Upper Bound\n" + str(self.upper_bound),
            fontsize=12, horizontalalignment='center')

        # Data the cross hairs snap to
        self.x = self.time
        self.y = self.signal

        # Cross hairs
        self.lx = self.ax.axhline(color='k', linewidth=0.2)  # the horiz line
        self.ly = self.ax.axvline(color='k', linewidth=0.2)  # the vert line

        # Folder / dosage / file captions
        left_shift = 0.45
        start = 0.96
        space = 0.04
        self.folder_text = self.ax.text(
            0.01, start, transform=self.ax.transAxes,
            s="Folder: " + self.folder_name,
            fontsize=12, horizontalalignment='left')
        self.dosage_text = self.ax.text(
            0.61 - left_shift, 1.1 - space, transform=self.ax.transAxes,
            s="Dosage: " + str(self.dosage),
            fontsize=12, horizontalalignment='left')
        self.file_name_text = self.ax.text(
            0.01, start - space, transform=self.ax.transAxes,
            s="File: " + self.files[self.folder_name][self.dosage][self.file_number]["file_name"],
            fontsize=12, horizontalalignment='left')

        # Index buttons
        ax_prev = plt.axes([0.575 - left_shift, 0.9, 0.1, 0.075])
        self.bprev = Button(ax_prev, 'Previous')
        self.bprev.on_clicked(self.prev)

        ax_next = plt.axes([0.8 - left_shift, 0.9, 0.1, 0.075])
        self.b_next = Button(ax_next, 'Next')
        self.b_next.on_clicked(self.next)
        self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_move)

        # Save button
        ax_save = plt.axes([0.8, 0.9, 0.1, 0.075])
        self.b_save = Button(ax_save, 'Save')
        self.b_save.on_clicked(self.save_intervals)

        # Line hide buttons
        self.ax.text(1.015, 0.97, transform=self.ax.transAxes,
                     s="Hide Signal:", fontsize=12, horizontalalignment='left')
        ax_switch_signals = plt.axes([0.91, 0.7, 0.07, 0.15])
        self.b_switch_signals = CheckButtons(ax_switch_signals,
                                             ('ECG', 'Seismo', 'Phono'))
        self.b_switch_signals.on_clicked(self.switch_signal)

        # Sliders: amplitude (A) for every trace, vertical offset (H) for the
        # seismo and phono traces. Offsets range over 1.5x the ECG extremes.
        offset_min = 1.5 * min(self.signal)
        offset_max = 1.5 * max(self.signal)
        self.signal_amp_slider = self._add_vertical_slider(
            [0.91, 0.15, 0.01, 0.475], "ECG\nA", 0.01, 10, 1)
        self.seis_height_slider = self._add_vertical_slider(
            [0.93, 0.15, 0.01, 0.475], " Seis\nH", offset_min, offset_max, 0)
        self.seis_amp_slider = self._add_vertical_slider(
            [0.94, 0.15, 0.01, 0.475], "\nA", 0.01, 10, 1)
        self.phono_height_slider = self._add_vertical_slider(
            [0.96, 0.15, 0.01, 0.475], " Phono\nH", offset_min, offset_max, 0)
        self.phono_amp_slider = self._add_vertical_slider(
            [0.97, 0.15, 0.01, 0.475], "A", .01, 10, 1)

        # Clicking actions
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        self.fig.canvas.mpl_connect('button_release_event', self.off_click)

        # Maximize frame
        mng = plt.get_current_fig_manager()
        mng.full_screen_toggle()
        plt.show()

    def switch_signal(self, label):
        """Redraw traces, bounds and (if requested) view limits.

        Used as the callback for the check buttons and every slider, and
        called directly after bounds change. ``label`` is whatever the widget
        passes and is ignored: the state is always re-read from the widgets.
        """
        self.x = self.time
        # After the first redraw the horizontal cross hair sits at the ECG mean.
        self.y = [np.mean(self.signal)] * len(self.time)

        hidden = self.b_switch_signals.get_status()

        self.signal_line.set_data(self.time,
                                  self.signal_amp_slider.val * self.signal)
        self.seis_line.set_data(
            self.time,
            (self.seis_amp_slider.val * self.seis) + self.seis_height_slider.val)
        self.phono_line.set_data(
            self.time,
            (self.phono_amp_slider.val * self.phono) + self.phono_height_slider.val)

        # A checked box hides the trace (by drawing it with zero width).
        traces = (self.signal_line, self.seis_line, self.phono_line)
        for line, is_hidden in zip(traces, hidden):
            line.set_linewidth(0 if is_hidden else 0.5)

        # Update green shaded region
        self.bound_span.remove()
        self.bound_span = self.ax.axvspan(self.lower_bound, self.upper_bound,
                                          facecolor='g', alpha=0.25)
        self.lower_bound_text.set_position((self.lower_bound, self.bound_text_height))
        self.lower_bound_text.set_text("Lower Bound\n" + str(self.lower_bound))
        self.upper_bound_text.set_position((self.upper_bound, self.bound_text_height))
        self.upper_bound_text.set_text("Upper Bound\n" + str(self.upper_bound))

        # Update x and y lim
        if self.new_bounds:
            self._reset_view_limits()
            self.new_bounds = False

        self.fig.canvas.draw()

    def on_click(self, event):
        """Start dragging a bound when the press lands within 2 units of it."""
        threshold = 2
        self.update_point = None
        self.new_bounds = False

        # Only presses inside the main plot count. Comparing against self.ax
        # (rather than the axes class name) works on every matplotlib version
        # and excludes the button/slider axes.
        if event.xdata is None or event.inaxes is not self.ax:
            return

        if abs(self.lower_bound - event.xdata) < threshold:
            grabbed = "lower_bound"
        elif abs(self.upper_bound - event.xdata) < threshold:
            grabbed = "upper_bound"
        else:
            return

        # Highlight the cross hairs while a bound is being dragged.
        for hair in (self.lx, self.ly):
            hair.set_color('g')
            hair.set_linewidth(1)
        self.fig.canvas.draw()
        self.update_point = grabbed

    def off_click(self, event):
        """Drop the dragged bound at the release position and redraw."""
        for hair in (self.lx, self.ly):
            hair.set_color('k')
            hair.set_linewidth(0.2)

        # Ignore releases over widget axes: their xdata is in the widget's
        # own coordinates, not in time units.
        if event.inaxes is self.ax and event.xdata is not None:
            # Snap to 0.1 and keep the bound inside the recording.
            released = min(max(round(event.xdata, 1), self.time[0]),
                           self.time[-1])
            if self.update_point == "lower_bound":
                self.lower_bound = released
            if self.update_point == "upper_bound":
                self.upper_bound = released

            # The bounds may have crossed during the drag.
            lower = min(self.lower_bound, self.upper_bound)
            upper = max(self.lower_bound, self.upper_bound)
            self.lower_bound = lower
            self.upper_bound = upper

        # Update on signal (switch_signal redraws the canvas itself)
        self.switch_signal(self.b_switch_signals.get_status())
        self.fig.canvas.flush_events()

    def mouse_move(self, event):
        """Move the cross hairs to the sample nearest the mouse position."""
        # Only follow the mouse inside the main plot.
        if event.inaxes is not self.ax or event.xdata is None:
            return

        # Lock to closest x coordinate on signal
        indx = min(np.searchsorted(self.x, event.xdata), len(self.x) - 1)
        x = self.x[indx]
        y = self.y[indx]

        # Update the crosshairs; Line2D setters expect sequences.
        self.lx.set_ydata([y, y])
        self.ly.set_xdata([x, x])

        self.ax.figure.canvas.draw()

    def _step_dosage(self, step):
        """Shift the dosage by ``step`` (wrapping within 0..40) and reload."""
        if self.use_intervals:
            self.save_intervals(None)

        self.dosage += step
        if self.dosage > 40:
            self.dosage = 0
        elif self.dosage < 0:
            self.dosage = 40

        self.new_bounds = True
        self.update_plot()

    def next(self, event):
        """Button callback: advance to the next dosage."""
        self._step_dosage(10)

    def prev(self, event):
        """Button callback: go back to the previous dosage."""
        self._step_dosage(-10)

    def save_intervals(self, event):
        """Store the current bounds in ``files`` and pickle the whole mapping."""
        save_filename = "data/Derived/intervals/Interval_Dict_" + self.folder_name
        stored = self.files[self.folder_name][self.dosage][self.file_number]["intervals"][1]
        stored[0] = self.lower_bound
        stored[1] = self.upper_bound

        with open(save_filename + '.pkl', 'wb') as output:
            pickle.dump(self.files, output, pickle.HIGHEST_PROTOCOL)

        print("Saved Intervals")
        # From now on bounds are taken from (and auto-saved to) the store.
        self.use_intervals = True

    def save_signals(self):
        """Load, clip and write every dosage's signals to CSV.

        Raises
        ------
        AssertionError
            If saved signals already exist for a dosage.
        """
        print("Saving Signals...")
        for dosage in self.files[self.folder_name]:
            # Load data
            self.time, self.signal, self.seis, _, self.phono, _ = hb.load_file_data(
                files=self.files, folder_name=self.folder_name,
                dosage=dosage, file_number=self.file_number)

            # Load echo time
            self.echo_time = self.files[self.folder_name][dosage][self.file_number]["echo_time"]

            # Clip signal size about echo time
            self.clip_signals()

            save_file_name = self.folder_name + "_d" + str(dosage)

            # Raised explicitly (not via `assert`) so the overwrite guard
            # still applies when Python runs with -O.
            if os.path.isfile('data/Derived/signals/time_' + save_file_name + '.csv'):
                raise AssertionError(
                    "Saved signals already exist, please delete before saving new signals.")

            # Save signals
            np.savetxt('data/Derived/signals/time_' + save_file_name + '.csv', self.time, delimiter=',')
            np.savetxt('data/Derived/signals/signal_' + save_file_name + '.csv', self.signal, delimiter=',')
            np.savetxt('data/Derived/signals/seis_' + save_file_name + '.csv', self.seis, delimiter=',')
            np.savetxt('data/Derived/signals/phono_' + save_file_name + '.csv', self.phono, delimiter=',')

            print("\tDosage " + str(dosage) + " done")
        print("...Done Saving Signals")

        # Use these saved files from now on
        self.preloaded_signal = True

    def load_signals(self):
        """Load the current dosage's echo time and signals.

        Preloaded signals are read from CSV as saved; otherwise they are
        loaded through ``hb.load_file_data`` and then clipped and filtered.
        """
        self.echo_time = self.files[self.folder_name][self.dosage][self.file_number]["echo_time"]

        if self.preloaded_signal:
            save_file_name = self.folder_name + "_d" + str(self.dosage)
            self.time = np.loadtxt('data/Derived/signals/time_' + save_file_name + '.csv', delimiter=',')
            self.signal = np.loadtxt('data/Derived/signals/signal_' + save_file_name + '.csv', delimiter=',')
            self.seis = np.loadtxt('data/Derived/signals/seis_' + save_file_name + '.csv', delimiter=',')
            self.phono = np.loadtxt('data/Derived/signals/phono_' + save_file_name + '.csv', delimiter=',')
        else:
            self.time, self.signal, self.seis, _, self.phono, _ = hb.load_file_data(
                files=self.files, folder_name=self.folder_name,
                dosage=self.dosage, file_number=self.file_number)
            # Clip signals
            self.clip_signals()

    def update_plot(self):
        """Reload the signals for the current dosage and refresh the figure."""
        # Display loading screen
        self.dosage_text.set_text("Loading: " + str(self.dosage))
        print("Loading: " + str(self.dosage))
        self.fig.canvas.draw()

        # Update captions
        self.folder_text.set_text("Folder: " + self.folder_name)
        self.dosage_text.set_text("Dosage: " + str(self.dosage))
        self.file_name_text.set_text(
            "File: " + self.files[self.folder_name][self.dosage][self.file_number]["file_name"])

        # Load composite signals
        self.load_signals()

        # Update echo line; Line2D.set_xdata expects a sequence.
        self.echo_line.set_xdata([self.echo_time, self.echo_time])

        # Find new original bounds
        self.initialize_bounds()

        # Update lines
        self.switch_signal(self.b_switch_signals.get_status())
        print("done")
class Vis(object): def __init__(self, fn): mf = flopy.modflow.Modflow.load(fn, model_ws=os.path.dirname(fn), load_only=['DIS', 'BAS6', 'UPW']) mf_files = gw_utils.get_mf_files(fn) hds_fn = mf_files['hds'][1] hob_out_fn = mf_files['HOB'][1] self.fn_hobOut = hob_out_fn try: # text file import flopy.utils.formattedfile as ff hds = ff.FormattedHeadFile(hds_fn, precision='single') except: # binary hds = flopy.utils.HeadFile(hds_fn) pass self.mf = mf self.hds = hds self.layout() def layout(self): # show main figure fig, ax = plt.subplots() self.fig = fig self.ax = ax self.all_times = self.hds.get_times() self.curr_time_index = 0 self.curr_layer = 0 self.togle_TS = 0 self.curr_label = 'Grid Elev' mng = plt.get_current_fig_manager() mng.window.showMaximized() # only windows # add buttons ax_Extra = plt.axes([0.5, 0.02, 0.05, 0.05]) self.Extrabn = Button(ax_Extra, 'Extra') self.Extrabn.on_clicked(self.strat_extra_frame) ax_timeseries = plt.axes([0.55, 0.02, 0.05, 0.05]) self.Tsbn = Button(ax_timeseries, 'Hydrograph') self.Tsbn.on_clicked(self.PotTS) axBackward = plt.axes([0.6, 0.02, 0.05, 0.05]) self.Bkbn = Button(axBackward, '<<') self.Bkbn.on_clicked(self.plot_backward) axForward = plt.axes([0.65, 0.02, 0.05, 0.05]) self.Forbn = Button(axForward, '>>') self.Forbn.on_clicked(self.plot_forward) axUP= plt.axes([0.7, 0.02, 0.05, 0.05]) self.upbn = Button(axUP, 'Up') self.upbn.on_clicked(self.changeLayerUp) axDN = plt.axes([0.75, 0.02, 0.05, 0.05]) self.Dnbn = Button(axDN, 'Down') self.Dnbn.on_clicked(self.chaneLayerDn) #axcolor = 'lightgoldenrodyellow' plt.text(0.01, 0.875, "MF Package", fontsize=14, transform=plt.gcf().transFigure) rax = plt.axes([0.01, 0.65, 0.09, 0.22]) radio = RadioButtons(rax, ('Grid Elev', 'Heads', 'Drawdown', 'flow direction', 'Grid Thickness', 'IBOUND', 'STRT', 'HK', 'VK', 'SS', 'SY', 'GWDepth')) radio.on_clicked(self.data_mode) plt.text(0.01, 0.16, "Point Layer", fontsize=14, transform=plt.gcf().transFigure) cax = plt.axes([0.01, 0.001, 0.09, 0.15]) 
self.CheckedPointlabels = [] self.checkPoints = CheckButtons(cax, ['HOBS', 'WELLS', 'GAGES'], actives = (False, False, False)) self.checkPoints.on_clicked(self.plot_points) self.LayerTxt = plt.text(0.8, 0.87, "Layer {}".format(self.curr_layer), fontsize=14, transform=plt.gcf().transFigure) self.TimeTxt = plt.text(0.2, 0.87, "Totim {}".format(self.all_times[self.curr_time_index]), fontsize=14, transform=plt.gcf().transFigure) plt.show() def strat_extra_frame(self, event): self.plot_thickness_trans() def plot_thickness_trans(self): thk = self.mf.dis.thickness.array ib3d = self.mf.bas6.ibound.array ib3d[ib3d != 0] = 1 hk = self.mf.upw.hk.array Total_thickness = (thk * ib3d).sum(axis = 0) trans = (thk * ib3d*hk).sum(axis = 0) ib2 = ib3d.sum(axis=0) ib2 = ib2 / ib2 fig1, ax1 = plt.subplots() s1 = ax1.imshow(np.log10(trans*ib2), cmap = 'jet') ax1.set_title("Log Total Transmissivity") fig1.colorbar(s1, ax = ax1) fig1.canvas.draw() fig2, ax2 = plt.subplots() s2 = ax2.imshow(Total_thickness * ib2, cmap = 'plasma') ax2.set_title("Log Total Thickness") fig2.colorbar(s2, ax=ax2) fig2.canvas.draw() plt.show() xx = 1 pass def plot_points(self, event): from gw_utils import hob_output_to_df, in_hob_to_df # ugly fn = os.path.join(self.mf.model_ws, self.mf.namefile) self.hobin = in_hob_to_df(mfname = fn) self.houout = hob_output_to_df(fn, mf=self.mf) status = self.checkPoints.get_status() for i, label in enumerate(self.checkPoints.labels): if status[i]: self.CheckedPointlabels.append(label.get_text()) for label in self.CheckedPointlabels: if label == 'HOBS': self.axHob = self.ax.plot(self.hobin['col'], self.hobin['row'], marker= '.', markeredgecolor = 'b', picker = 5, linestyle = 'None') self.fig.canvas.mpl_connect('pick_event', self.on_pick_hob) self.fig.canvas.draw() xx = 1 def on_pick_hob(self, event): # make code to select the colosets try: self.AxSelection.remove() except: pass ind = event.ind if len(ind) > 1: datax, datay = event.artist.get_data() datax, datay = [datax[i] for 
i in ind], [datay[i] for i in ind] msx, msy = event.mouseevent.xdata, event.mouseevent.ydata dist = np.sqrt((np.array(datax) - msx) ** 2 + (np.array(datay) - msy) ** 2) close_i = np.argmin(dist) ind = [ind[np.argmin(dist)]] x= datax[close_i] y = datay[close_i] self.AxSelection = self.ax.scatter(x, y, marker='*', c='r', s=200) name = self.hobin.iloc[ind]['name'].values[0] hrow =self.hobin.iloc[ind]['row'].values[0] hcol = self.hobin.iloc[ind]['col'].values[0] curr_hob_in = self.hobin[(self.hobin['row'] == hrow) & (self.hobin['col'] == hcol)] names = curr_hob_in['name'].unique() curr_obs_df = self.houout[self.houout['OBSERVATION NAME'].isin(names)] curr_obs_df = curr_obs_df.drop_duplicates(subset='OBSERVATION NAME', keep='last') curr_hob_in = curr_hob_in.drop_duplicates(subset='name', keep='first') self.plot_head_ts(x, y) tim = curr_hob_in['totim'].values obs_values =curr_obs_df['OBSERVED VALUE'].values sim_values = curr_obs_df['SIMULATED EQUIVALENT'].values maskSim = 0 sim_values[sim_values==maskSim] = np.NaN self.axhydro.scatter(tim, obs_values, c = 'r', label='OBS') self.axhydro.scatter(tim, sim_values, c= 'b', label='SIM') self.axhydro.legend() title = self.axhydro.get_title() if "." 
in names[0]: name = names[0].split(".")[0] else: name = names[0] title = title + "\n" + name kkeys = curr_hob_in['mlay'].values[0].keys() vals = curr_hob_in['mlay'].values[0].values() kval = zip(kkeys, vals) title = title + "\n" for kv in list(kval): title = title + " {} : {}, ".format(kv[0], kv[1]) self.axhydro.set_title(title) event.canvas.draw() event.canvas.flush_events() def get_arr(self): if self.curr_label in ['Heads', 'flow direction']: totim = self.all_times[self.curr_time_index] self.maxDataLayer = self.mf.nlay arr = self.hds.get_data(totim=totim) elif self.curr_label == 'Grid Elev': self.maxDataLayer = self.mf.nlay arr = np.zeros((self.mf.nlay+1, self.mf.nrow, self.mf.ncol)) arr[0,:,:] = self.mf.dis.top.array arr[1:, :, :] = self.mf.dis.botm.array elif self.curr_label =='Drawdown': totim = self.all_times[self.curr_time_index] totim0 = self.all_times[0] self.maxDataLayer = self.mf.nlay arr = self.hds.get_data(totim=totim)-self.hds.get_data(totim=totim0) elif self.curr_label =='HK': arr = self.mf.upw.hk.array elif self.curr_label =='VK': arr = self.mf.upw.vka.array elif self.curr_label =='IBOUND': arr = self.mf.bas6.ibound.array elif self.curr_label == 'STRT': arr = self.mf.bas6.strt.array elif self.curr_label == 'SS': arr = self.mf.upw.ss.array elif self.curr_label == 'SY': arr = self.mf.upw.sy.array elif self.curr_label == 'GWDepth': totim = self.all_times[self.curr_time_index] self.maxDataLayer = self.mf.nlay arr = self.hds.get_data(totim=totim) arr = arr.copy() ttop = self.mf.dis.top.array.copy() for k in range(self.mf.nlay): arr[k, :,:] = ttop - arr[k,:,:] return arr def data_mode(self, label): self.curr_label = label def PotTS(self, event): if self.togle_TS == 0: self.togle_TS = 1 self.Tsbn.color = 'red' cid1 = self.fig.canvas.mpl_connect('button_press_event', lambda event: self.click_hydrograph(event)) print("red") else: self.togle_TS = 0 self.Tsbn.color = 'green' cid2 = self.fig.canvas.mpl_connect('close_event', lambda event: 
self.close_hydro(event)) def close_hydro(self, event): pass def plot_head_ts(self, x, y): Ly = self.mf.dis.delc.array.sum() Lx = self.mf.dis.delr.array.sum() x = Lx * x / self.mf.ncol y = Ly - Ly * y / self.mf.nrow rr_cc1 = plotutil.findrowcolumn((x, y), self.mf.modelgrid.xyedges[0], self.mf.modelgrid.xyedges[1]) self.fighydro, self.axhydro = plt.subplots() elevs = [] pre_elev = 0 for k in range(self.mf.nlay): rr_cc = (k, rr_cc1[0], rr_cc1[1]) ib = self.mf.bas6.ibound.array[k, rr_cc1[0], rr_cc1[1]] # get elevation if k == 0: elev = self.mf.dis.top.array[rr_cc1[0], rr_cc1[1]] pre_elev = elev elevs.append(elev) if ib != 0: elev = self.mf.dis.botm.array[k, rr_cc1[0], rr_cc1[1]] pre_elev = elev elevs.append(elev) else: elevs.append(pre_elev) ts = self.hds.get_ts(rr_cc) ts[:, 1][ts[:, 1] == self.mf.bas6.hnoflo] = np.NaN lb = "Head {}".format(k + 1) self.axhydro.plot(ts[:, 0], ts[:, 1], label=lb) # plot layers if 0: for ielev, elev in enumerate(elevs): ts2 = ts.copy() ts2[:, 1] = elev ax.plot(ts2[:, 0], ts2[:, 1], label="Layer {}".format(ielev)) plt.legend() plt.title("Row = {}, Col= {}".format(rr_cc1[0], rr_cc1[1])) plt.show() def click_hydrograph(self,event): if self.togle_TS==0: return try: self.axHeadLocation.remove() except: pass self.axHeadLocation = self.ax.scatter(event.xdata, event.ydata, color='r') event.canvas.draw() event.canvas.flush_events() ## if 0: Ly = self.mf.dis.delc.array.sum() Lx = self.mf.dis.delr.array.sum() x = Lx * event.xdata/self.mf.ncol y = Ly - Ly * event.ydata / self.mf.nrow self.plot_head_ts(event.xdata,event.ydata) def chaneLayerDn(self, event): self.curr_layer = self.curr_layer + 1 if self.curr_layer > self.mf.nlay-1 : self.curr_layer = self.curr_layer - 1 return False self.update_fig() print(self.curr_layer) def changeLayerUp(self, event): self.curr_layer = self.curr_layer - 1 if self.curr_layer < 0: self.curr_layer = self.curr_layer + 1 return False self.update_fig() print(self.curr_layer) def update_fig(self): arr = self.get_arr() 
#self.fig.suptitle("Totim = {} ".format(totim)) if self.curr_label == 'flow direction': self._plot_dir(arr) else: self._plot_arr(arr) def _plot_dir(self, arr): self.ax.clear() arr = arr[self.curr_layer, :,:] ibound = self.mf.bas6.ibound.array[self.curr_layer, :,:] arr[ibound == 0] = np.NaN # ******************* # second subplot x = np.arange(0, self.mf.ncol, 1) y = np.arange(0, self.mf.nrow, 1) X, Y = np.meshgrid(x, y) dx, dy = np.gradient(arr) dx = dx dy = dy n = -2 color = np.sqrt(((dx - n) / 2)**2 + ((dy - n) / 2)**2) f = lambda x: np.sign(x) * np.log10(1 + np.abs(x)) im = self.ax.quiver(X, Y, f(dx), f(dy), color, scale=0.5, units="xy") #im = self.ax.streamplot(X, Y, dx, dy) self.ax.contour(arr, colors = 'k') self.ax.invert_yaxis() self.LayerTxt.remove() self.LayerTxt = plt.text(0.8, 0.87, "Layer {}".format(self.curr_layer+1), fontsize=14, transform=plt.gcf().transFigure, color='r') self.TimeTxt.remove() self.TimeTxt = plt.text(0.2, 0.87, "Totim {}".format(self.all_times[self.curr_time_index]), fontsize=14, transform=plt.gcf().transFigure) self.fig.canvas.draw() try: self.cax.remove() except: pass self.cax = self.fig.colorbar(im, fraction = 0.02, pad = 0.01, ax= self.ax, orientation = 'vertical' ) def _plot_arr(self, arr): self.ax.clear() arr = arr[self.curr_layer, :,:] ibound = self.mf.bas6.ibound.array[self.curr_layer, :,:] arr[ibound == 0] = np.NaN im = self.ax.imshow(arr) self.ax.contour(arr, colors = 'k') self.LayerTxt.remove() self.LayerTxt = plt.text(0.8, 0.87, "Layer {}".format(self.curr_layer+1), fontsize=14, transform=plt.gcf().transFigure, color='r') self.TimeTxt.remove() self.TimeTxt = plt.text(0.2, 0.87, "Totim {}".format(self.all_times[self.curr_time_index]), fontsize=14, transform=plt.gcf().transFigure) self.fig.canvas.draw() from mpl_toolkits.axes_grid1 import make_axes_locatable divider = make_axes_locatable(self.ax) try: self.cax.remove() del(divider) except: pass self.cax = self.fig.colorbar(im, fraction = 0.02, pad = 0.01, ax= self.ax, 
orientation = 'vertical' ) def plot_forward(self, event): self.curr_time_index = self.curr_time_index + 1 if self.curr_time_index> len(self.all_times)-1: self.curr_time_index = self.curr_time_index - 1 return False self.update_fig() def plot_backward(self, event): self.curr_time_index = self.curr_time_index - 1 if self.curr_time_index < 0: self.curr_time_index = self.curr_time_index - 1 return False self.update_fig()
class Oscillo:
    """Matplotlib-based oscilloscope with optional spectrum and stats windows.

    Args:
        channel_list: channels to acquire; one trace is drawn per channel.
        sampling: requested sampling frequency.
        volt_range: requested input range passed to the backend.
        trigger_level: trigger level, or ``None`` for no trigger.
        trigsource: trigger source (see ``create_system`` for how it is used).
        backend: key into the module-level ``Backends`` mapping.

    Raises:
        ValueError: if ``backend`` is not a key of ``Backends``.
    """

    def __init__(
            self,
            channel_list,
            sampling=5e3,
            volt_range=10.0,
            trigger_level=None,
            trigsource=None,
            backend="daqmx",
    ):
        if backend not in Backends:
            raise ValueError(f"Backend {backend:} is not available.")
        self.backend = backend
        self.channel_list = channel_list
        self.sampling = sampling
        self.volt_range = volt_range
        self.ignore_voltrange_submit = False
        self.trigger_level = trigger_level
        self.trigger_source = trigsource
        self.N = 1024

        plt.ion()
        self.fig = plt.figure()
        # left, bottom, width, height
        self.ax = self.fig.add_axes([0.1, 0.1, 0.7, 0.8])

        self.running = False
        self.last_trigged = 0
        self.ask_pause_acqui = False
        self.paused = False

        # Averaged spectrum state (see clean_spectrum).
        self.freq = None
        self.Pxx = None
        self.N_spectra = 0
        self.fig_spectrum = None
        self.hanning = True
        self.ac = True
        self.power_spectrum = True
        self.spectrum_unit = 1.0
        self.spectrum_unit_str = "V^2/Hz"

        self.system = None
        self.saved_spectra = list()
        self.fig_stats = None

        # Widgets are stacked downwards in the right-hand margin.
        left, bottom = 0.825, 0.825
        width, height = 0.15, 0.075
        vpad = 0.02

        def add_textbox(title, initial, callback):
            nonlocal bottom
            box_ax = self.fig.add_axes([left, bottom, width, height])
            bottom -= height * 1.25 + 2 * vpad
            box = TextBox(box_ax, label="", initial=initial)
            box.on_submit(callback)
            box_ax.text(0, 1.25, title)
            return box

        def add_button(label, callback):
            nonlocal bottom
            btn_ax = self.fig.add_axes([left, bottom, width, height])
            bottom -= height + vpad
            btn = Button(btn_ax, label=label)
            btn.on_clicked(callback)
            return btn

        self.textbox_sampling = add_textbox(
            "Sampling", f"{sampling:.1e}", self.ask_sampling_change)
        self.textbox_triggersource = add_textbox(
            "Trigger", "None", self.ask_trigger_change)
        if trigger_level is not None:
            initial_value = f"{trigger_level:.2f}"
        else:
            initial_value = "1.0"
        self.textbox_triggerlevel = add_textbox(
            "Level", initial_value, self.ask_trigger_change)
        self.textbox_winsize = add_textbox(
            "Win size", f"{self.N:d}", self.ask_winsize_change)
        self.textbox_voltrange = add_textbox(
            "Range", f"{self.volt_range:.1f}", self.ask_voltrange_change)
        self.btn_start_stats = add_button("Stats", self.start_stats)
        self.btn_start_spectrum = add_button("FFT", self.start_spectrum)

    def clean_spectrum(self, *args):
        """Discard the accumulated spectrum so averaging restarts."""
        self.freq = None
        self.Pxx = None
        self.N_spectra = 0

    def start_stats(self, event):
        """Button callback: open one statistics window per channel."""
        if self.fig_stats is None:
            self.fig_stats = dict()
            self.box_mean = dict()
            self.box_std = dict()
            self.box_min = dict()
            self.box_max = dict()
            self.box_freq = dict()

        nbox = 5
        height = 1.0 / (nbox + 1)
        padding = height / 4
        # (label, storage, vertical position in half-heights), top to bottom.
        rows = (
            ("Mean", self.box_mean, 9),
            ("Std", self.box_std, 7),
            ("Min", self.box_min, 5),
            ("Max", self.box_max, 3),
            ("Freq", self.box_freq, 1),
        )
        for chan in self.channel_list:
            if chan in self.fig_stats:
                continue
            fig = plt.figure(figsize=(2, 4))
            self.fig_stats[chan] = fig
            # FigureCanvas.set_window_title was removed from Matplotlib;
            # the figure manager offers the same call on old and new versions.
            manager = fig.canvas.manager
            if manager is not None:
                manager.set_window_title(chan)
            for label, boxes, k in rows:
                box_ax = fig.add_axes(
                    [0.25, k * height / 2, 0.7, height - padding])
                boxes[chan] = TextBox(box_ax, label=label, initial="")

    def start_spectrum(self, *args, **kwargs):
        """Button callback: open the spectrum window and restart averaging."""
        if self.fig_spectrum is None:
            self.fig_spectrum = plt.figure()
            self.ax_spectrum = self.fig_spectrum.add_axes([0.1, 0.1, 0.7, 0.8])

            # Widgets
            ax_hanning = self.fig_spectrum.add_axes([0.825, 0.75, 0.15, 0.15])
            self.checkbox_hanning = CheckButtons(
                ax_hanning, ["Hanning", "AC"], [self.hanning, self.ac])
            self.checkbox_hanning.on_clicked(self.ask_hanning_change)

            ax_spectrum_unit = self.fig_spectrum.add_axes(
                [0.825, 0.51, 0.15, 0.25])
            self.radio_units = RadioButtons(
                ax_spectrum_unit,
                ["V^2/Hz", "V/sq(Hz)", "mV/sq(Hz)", "µV/sq(Hz)", "nV/sq(Hz)"],
            )
            self.radio_units.on_clicked(self.ask_spectrum_units_change)

            ax_restart = self.fig_spectrum.add_axes([0.825, 0.35, 0.15, 0.075])
            self.btn_restart = Button(ax_restart, label="Restart")
            self.btn_restart.on_clicked(self.clean_spectrum)

            ax_save_hold = self.fig_spectrum.add_axes(
                [0.825, 0.25, 0.15, 0.075])
            self.btn_save_hold = Button(ax_save_hold, label="Hold&Save")
            self.btn_save_hold.on_clicked(self.save_hold)
        self.clean_spectrum()

    def save_hold(self, event):
        """Button callback: write the averaged spectrum to HDF5 and keep it.

        Nothing happens until at least one spectrum has been accumulated.
        """
        if self.N_spectra <= 0:
            return
        dt = datetime.now()
        filename = (f"pymanip_oscillo_{dt.year:}-{dt.month}-{dt.day}"
                    f"_{dt.hour}-{dt.minute}-{dt.second}.hdf5")
        bb = self.freq > 0
        # asarray + ellipsis handles both the single-channel array and the
        # multi-channel list of arrays (a list cannot be mask-indexed).
        pxx = np.asarray(self.Pxx)[..., bb] / self.N_spectra
        # Explicit "w": recent h5py opens read-only when no mode is given,
        # which cannot create the file.
        with h5py.File(filename, "w") as f:
            f.attrs["ts"] = dt.timestamp()
            f.attrs["N_spectra"] = self.N_spectra
            f.attrs["sampling"] = self.sampling
            f.attrs["volt_range"] = self.volt_range
            f.attrs["N"] = self.N
            f.create_dataset("freq", data=self.freq[bb])
            f.create_dataset("Pxx", data=pxx)
        # Transposed so a multi-channel spectrum plots one curve per channel
        # (no-op for a single channel).
        self.saved_spectra.append({"freq": self.freq[bb], "Pxx": pxx.T})

    def ask_spectrum_units_change(self, event):
        """Radio callback: select the spectrum unit named by ``event``.

        Unknown labels are ignored.
        """
        # label -> (plot power spectrum?, multiplicative unit factor)
        units = {
            "V^2/Hz": (True, 1.0),
            "V/sq(Hz)": (False, 1.0),
            "mV/sq(Hz)": (False, 1e3),
            "µV/sq(Hz)": (False, 1e6),
            "nV/sq(Hz)": (False, 1e9),
        }
        if event in units:
            self.spectrum_unit_str = event
            self.power_spectrum, self.spectrum_unit = units[event]

    async def winsize_change(self, new_N):
        """Pause, apply a new window size, and resume the acquisition."""
        await self.pause_acqui()
        old_N = self.N
        self.N = new_N
        self.clean_spectrum()
        self.figure_t_axis()
        try:
            self.system.configure_clock(self.sampling, self.N)
        except Exception as e:
            print(e)
            self.N = old_N
            # The time axis was built for new_N; rebuild it for the
            # restored size so it matches the acquired data again.
            self.figure_t_axis()
        await self.restart_acqui()

    def ask_winsize_change(self, label):
        """Text-box callback: schedule a window-size change."""
        try:
            new_N = int(label)
        except ValueError:
            print("winsize must be an integer")
            return
        loop = asyncio.get_event_loop()
        asyncio.ensure_future(self.winsize_change(new_N), loop=loop)

    def ask_trigger_change(self, label):
        """Text-box callback: validate trigger source/level, schedule update."""
        changed = False
        trigger_source = self.textbox_triggersource.text
        possible_trigger_source = self.system.possible_trigger_channels()
        if trigger_source not in possible_trigger_source:
            print(f"{trigger_source:} is not an acceptable trigger source")
            print("Possible values:", possible_trigger_source)
            trigger_source = None
            self.textbox_triggersource.set_val("None")
        if trigger_source == "None":
            trigger_source = None
        if self.trigger_source != trigger_source:
            self.trigger_source = trigger_source
            changed = True

        if self.trigger_source is not None:
            try:
                trigger_level = float(self.textbox_triggerlevel.text)
            except ValueError:
                # Fall back to the previous level; there may be none yet,
                # in which case there is nothing to write back.
                trigger_level = self.trigger_level
                if trigger_level is not None:
                    self.textbox_triggerlevel.set_val(f"{trigger_level:.2f}")
            if trigger_level != self.trigger_level:
                self.trigger_level = trigger_level
                changed = True

        if changed:
            loop = asyncio.get_event_loop()
            asyncio.ensure_future(self.trigger_change(), loop=loop)

    async def pause_acqui(self):
        """Ask the acquisition loop to pause and wait until it has."""
        self.ask_pause_acqui = True
        await self.system.stop()
        while not self.paused:
            await asyncio.sleep(0.5)

    async def restart_acqui(self):
        """Release the pause request and wait until acquisition resumed."""
        self.ask_pause_acqui = False
        while self.paused:
            await asyncio.sleep(0.5)

    async def trigger_change(self):
        """Pause, reconfigure the trigger, and resume the acquisition."""
        print("Awaiting pause")
        await self.pause_acqui()
        print("Acquisition paused")
        if self.trigger_level is not None:
            self.system.configure_trigger(self.trigger_source,
                                          self.trigger_level)
        else:
            self.system.configure_trigger(None)
        self.clean_spectrum()
        await self.restart_acqui()

    async def sampling_change(self):
        """Pause, apply ``self.sampling``, and resume the acquisition.

        When the backend rejects or adjusts the rate, a new change request
        with the backend's value is issued instead.
        """
        await self.pause_acqui()
        try:
            self.system.configure_clock(self.sampling, self.N)
        except Exception:
            print("Invalid sampling frequency")
            self.ask_sampling_change(self.system.samp_clk_max_rate)
            return
        if self.system.sample_rate != self.sampling:
            self.ask_sampling_change(self.system.sample_rate)
            return
        self.figure_t_axis()
        self.clean_spectrum()
        await self.restart_acqui()

    def ask_sampling_change(self, sampling):
        """Text-box callback: parse a sampling rate and schedule the change."""
        try:
            self.sampling = float(sampling)
        except ValueError:
            print("Wrong value:", sampling)
            changed = False
        else:
            changed = True
        self.textbox_sampling.set_val(f"{self.sampling:.1e}")
        if changed:
            loop = asyncio.get_event_loop()
            asyncio.ensure_future(self.sampling_change(), loop=loop)

    def ask_hanning_change(self, label):
        """Check-box callback: read window/AC options, restart averaging."""
        self.hanning, self.ac = self.checkbox_hanning.get_status()
        self.clean_spectrum()

    def figure_t_axis(self):
        """Rebuild ``self.t`` and ``self.unit`` from ``N`` and ``sampling``."""
        self.t = np.arange(self.N) / self.sampling
        if self.t[-1] < 1:
            # Short windows are displayed in milliseconds.
            self.t *= 1000
            self.unit = "[ms]"
        else:
            self.unit = "[s]"

    async def run_gui(self):
        """GUI task: pump Matplotlib events and react to closed windows."""
        while self.running:
            if time.monotonic() - self.last_trigged > self.N / self.sampling:
                self.ax.set_title("Waiting for trigger")
            self.fig.canvas.start_event_loop(0.5)
            await asyncio.sleep(0.05)

            if not plt.fignum_exists(self.fig.number):
                # Main window closed: stop everything.
                self.running = False

            if self.fig_spectrum and not plt.fignum_exists(
                    self.fig_spectrum.number):
                self.fig_spectrum = None
                self.clean_spectrum()

            if self.fig_stats is not None:
                for chan in list(self.fig_stats.keys()):
                    if not plt.fignum_exists(self.fig_stats[chan].number):
                        self.fig_stats.pop(chan)
                        self.box_mean.pop(chan)
                        self.box_std.pop(chan)
                        self.box_min.pop(chan)
                        self.box_max.pop(chan)
                        self.box_freq.pop(chan)
                if not self.fig_stats:
                    self.fig_stats = None

    def ask_voltrange_change(self, new_range):
        """Text-box callback: schedule a voltage-range change."""
        if self.ignore_voltrange_submit:
            # Submit triggered by our own set_val in voltrange_change.
            return
        try:
            new_range = float(new_range)
        except ValueError:
            print("Volt range must be a float")
            return
        loop = asyncio.get_event_loop()
        asyncio.ensure_future(self.voltrange_change(new_range), loop=loop)

    async def voltrange_change(self, new_range):
        """Pause, recreate the system with a new range, and resume."""
        await self.pause_acqui()
        self.volt_range = new_range
        self.clean_spectrum()
        self.create_system()
        await self.restart_acqui()
        actual_range = self.system.actual_ranges[0]
        print("actual_range =", actual_range)
        self.ignore_voltrange_submit = True
        self.textbox_voltrange.set_val(f"{actual_range:.1f}")
        self.ignore_voltrange_submit = False

    def create_system(self):
        """(Re)create the backend system with the current settings."""
        if self.system is not None:
            self.system.close()
        self.system = Backends[self.backend]()
        self.ai_channels = [
            self.system.add_channel(chan,
                                    terminal_config=TC.Diff,
                                    voltage_range=self.volt_range)
            for chan in self.channel_list
        ]
        self.system.configure_clock(self.sampling, self.N)
        self.textbox_sampling.set_val(f"{self.system.sample_rate:.1e}")
        if self.trigger_level is not None:
            # NOTE(review): here trigger_source is used as an index into
            # channel_list, whereas trigger_change passes it through as is.
            trigger_source = self.channel_list[self.trigger_source]
            self.system.configure_trigger(trigger_source,
                                          trigger_level=self.trigger_level)
        else:
            self.system.configure_trigger(None)
        self.figure_t_axis()

    def _plot_traces(self, data):
        """Redraw the time traces (and trigger level line) on the main axes."""
        self.ax.cla()
        nchan = len(self.channel_list)
        if nchan == 1:
            self.ax.plot(self.t, data, "-")
        elif nchan > 1:
            for d in data:
                self.ax.plot(self.t, d, "-")
        if self.trigger_source not in (None, "Ext"):
            self.ax.plot([self.t[0], self.t[-1]],
                         [self.trigger_level] * 2, "g--")
        self.ax.set_xlim([self.t[0], self.t[-1]])
        self.ax.set_title("Trigged!")
        self.ax.set_xlabel("t " + self.unit)

    def _update_spectrum(self, data):
        """Accumulate the spectrum of ``data`` and redraw the spectrum axes."""
        self.ax_spectrum.cla()
        for spectra in self.saved_spectra:
            self.ax_spectrum.loglog(spectra["freq"], spectra["Pxx"], "-")

        single = len(self.channel_list) == 1
        first = self.N_spectra == 0
        if first:
            self.freq = np.fft.fftfreq(self.N, 1.0 / self.sampling)
        # Recomputed on every call instead of being carried over from the
        # call that started the average.
        bb = self.freq > 0
        norm = math.pi * math.sqrt(self.N / self.sampling)
        window = np.hanning(self.N) if self.hanning else np.ones((self.N, ))

        def periodogram(d):
            m = np.mean(d) if self.ac else 0.0
            return np.abs(np.fft.fft((d - m) * window) / norm)**2

        if single:
            if first:
                self.Pxx = periodogram(data)
            else:
                self.Pxx += periodogram(data)
        else:
            if first:
                self.Pxx = [periodogram(d) for d in data]
            else:
                for p, d in zip(self.Pxx, data):
                    p += periodogram(d)
        self.N_spectra += 1

        if self.power_spectrum:

            def process_spec(s):
                return self.spectrum_unit * s
        else:

            def process_spec(s):
                return self.spectrum_unit * np.sqrt(s)

        spectra_list = [self.Pxx] if single else self.Pxx
        for p in spectra_list:
            self.ax_spectrum.loglog(self.freq[bb],
                                    process_spec(p[bb] / self.N_spectra), "-")
        self.ax_spectrum.set_xlabel("f [Hz]")
        self.ax_spectrum.set_ylabel(self.spectrum_unit_str)
        self.ax_spectrum.set_title(f"N = {self.N_spectra:d}")

    def _update_stats(self, data):
        """Refresh the text boxes of every open statistics window."""
        list_data = [data] if len(self.channel_list) == 1 else data
        for chan, d in zip(self.channel_list, list_data):
            if chan not in self.fig_stats:
                continue
            self.box_mean[chan].set_val("{:.5f}".format(np.mean(d)))
            self.box_std[chan].set_val("{:.5f}".format(np.std(d)))
            self.box_min[chan].set_val("{:.5f}".format(np.min(d)))
            self.box_max[chan].set_val("{:.5f}".format(np.max(d)))
            # Frequency of the largest FFT bin of the mean-removed signal.
            ff = np.fft.fftfreq(self.N, 1.0 / self.sampling)
            pp = np.abs(np.fft.fft(d - np.mean(d)))
            self.box_freq[chan].set_val("{:.5f}".format(ff[np.argmax(pp)]))

    async def run_acqui(self):
        """Acquisition task: read data in a loop and update every display."""
        self.create_system()
        try:
            while self.running:
                # Honour pause requests issued by the *_change coroutines.
                while self.ask_pause_acqui and self.running:
                    self.paused = True
                    await asyncio.sleep(0.5)
                self.paused = False
                if not self.running:
                    break

                self.system.start()
                data = await self.system.read()
                await self.system.stop()
                if data is None:
                    continue
                self.last_trigged = self.system.last_read

                self._plot_traces(data)
                if self.fig_spectrum:
                    self._update_spectrum(data)
                if self.fig_stats:
                    self._update_stats(data)
        finally:
            self.system.close()

    def ask_exit(self, *args, **kwargs):
        """Signal handler: request both tasks to stop."""
        self.running = False

    def run(self):
        """Run the GUI and acquisition tasks until the oscilloscope stops."""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # Newer Python no longer creates a loop implicitly.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        self.running = True
        if sys.platform == "win32":
            signal.signal(signal.SIGINT, self.ask_exit)
        else:
            for signame in ("SIGINT", "SIGTERM"):
                loop.add_signal_handler(getattr(signal, signame),
                                        self.ask_exit)

        async def main():
            # gather() is awaited inside the running loop.
            await asyncio.gather(self.run_gui(), self.run_acqui())

        loop.run_until_complete(main())
class NetworkVisualisation(object):
    """Interactive figure for inspecting and editing a small trained network.

    Args:
        units: passed to the network training/loading helpers.
        data_points: number of points for the circle dataset (first
            argument of the spiral dataset helper).
        min_range, max_range: extent of the evaluated data space.
        quality: resolution of the data space (stored as ``precision``).
        dataset: a ``Dataset`` member (``CIRCLE`` or ``SPIRAL``).
        saves_path: where trained networks are loaded from / saved to.
        seed: seed given to ``np.random.seed``.

    Raises:
        ValueError: if ``dataset`` is neither ``CIRCLE`` nor ``SPIRAL``.
    """

    def __init__(self,
                 units,
                 data_points,
                 min_range,
                 max_range,
                 quality,
                 dataset,
                 saves_path=None,
                 seed=1):
        np.random.seed(seed)
        self.precision = quality
        self.min_range = min_range
        self.max_range = max_range
        self.dataset_type = dataset

        if dataset is Dataset.CIRCLE:
            self.dataset = get_circle_dataset(points=data_points,
                                              min_range=min_range,
                                              max_range=max_range,
                                              radius=0.8)
        elif dataset is Dataset.SPIRAL:
            self.dataset = get_spiral_dataset(data_points, classes=2)
        else:
            raise ValueError("Invalid dataset type")

        # Last column holds the label, the others the coordinates.
        self.data_points = self.dataset[:, :-1]
        self.data_labels = self.dataset[:, -1]
        self.data_space, self.dim_data = setup_data_space(
            min_range, max_range, quality)

        # Network Creation
        if saves_path and check_saved_network(units, dataset, saves_path):
            self.network = load_network(units, dataset, saves_path)
        else:
            if dataset is Dataset.CIRCLE:
                self.network = train_network_sigmoid(self.dataset,
                                                     units=units,
                                                     learning_rate=5e-3,
                                                     window_size=1000)
            elif dataset is Dataset.SPIRAL:
                self.network = train_network_softmax(self.dataset,
                                                     units=units,
                                                     learning_rate=1,
                                                     window_size=1000)
            else:
                raise ValueError("Invalid dataset type")
            # NOTE(review): also called when saves_path is None; presumably
            # save_network copes with that — confirm.
            save_network(self.network, dataset, saves_path)

        # Snapshot of the trained parameters, used to mark the original
        # value on each slider.
        self.default_network = {
            name: layer.copy() for name, layer in self.network.items()
        }

        # GUI Visualisation
        self.perceptron1 = 0
        self.is_relu = True
        self.perceptron2 = 0
        self.connection = 0
        self.is_pre_add = False
        self.is_sig = True
        self.all_p1_enabled = set(range(self.network["W1"].shape[1]))
        # While True, widget callbacks do not recompute the plots.
        self.ignore_update = False

        fig = plt.figure(figsize=(13, 6.5))
        self.plot_network(fig)
        self.plot_controls(fig)

    @staticmethod
    def _padded_range(values, padding):
        """Return ``(low, high)`` widened by half the spread plus ``padding``."""
        lowest = values.min()
        highest = values.max()
        margin = (highest - lowest) / 2 + padding
        return lowest - margin, highest + margin

    @staticmethod
    def _mark_default(slider, value):
        """Move the slider's reference line to ``value``.

        A two-point sequence is passed because current Matplotlib rejects a
        scalar in ``Line2D.set_xdata``.
        """
        slider.vline.set_xdata([value, value])

    def _overlay_dataset(self, ax):
        """Scatter the labelled data points on ``ax`` (1: green, 0: red)."""
        outer_points = self.data_points[self.data_labels == 1]
        inner_points = self.data_points[self.data_labels == 0]
        ax.scatter(outer_points[:, 0], outer_points[:, 1],
                   s=3, c="g", alpha=0.5)
        ax.scatter(inner_points[:, 0], inner_points[:, 1],
                   s=3, c="r", alpha=0.5)

    def plot_network(self, fig):
        """Create the 2-D and 3-D plots of both layers on ``fig``."""
        _, out2, out3, out4 = forward(self.data_space,
                                      self.network,
                                      self.dataset_type,
                                      precision=self.precision)

        self.layer1_plot = Plot(fig, (4, 4), (0, 0), (1, 3), out2[:, :, 0],
                                self.min_range, self.max_range)
        self._overlay_dataset(self.layer1_plot.ax)
        self.layer1_3d_plot = Plot3D(fig, (4, 4), (1, 0), (1, 3),
                                     self.precision, out2)

        self.layer2_plot = Plot(fig, (4, 4), (2, 0), (1, 3), out4[:, :, 0],
                                self.min_range, self.max_range)
        self._overlay_dataset(self.layer2_plot.ax)
        self.layer2_3d_plot = Plot3D(fig, (4, 4), (3, 0), (1, 3),
                                     self.precision, out4)

    def plot_controls(self, fig):
        """Create the sliders and check boxes that edit the network."""
        step_size = 0.01
        padding = 5
        net = self.network

        def add_slider(grid, pos, span, label, low, high, init, step,
                       callback):
            slider = Slider(plot_to_grid(fig, grid, pos, span), label,
                            valmin=low, valmax=high, valinit=init,
                            valstep=step)
            slider.on_changed(callback)
            return slider

        # Plot 1 controls
        w1x_lo, w1x_hi = self._padded_range(net["W1"][0], padding)
        w1y_lo, w1y_hi = self._padded_range(net["W1"][1], padding)
        w1b_lo, w1b_hi = self._padded_range(net["b1"], padding)

        self.p1x_slid = add_slider((2, 16), (0, 12), (1, 1), 'P1 x',
                                   w1x_lo, w1x_hi, net["W1"][0, 0],
                                   step_size, self.p1x_changed)
        self.p1y_slid = add_slider((2, 16), (0, 13), (1, 1), 'P1 y',
                                   w1y_lo, w1y_hi, net["W1"][1, 0],
                                   step_size, self.p1y_changed)
        self.p1b_slid = add_slider((24, 16), (0, 14), (7, 1), 'P1 b',
                                   w1b_lo, w1b_hi, net["b1"][0, 0],
                                   step_size, self.p1b_changed)
        self.p1_slid = add_slider((24, 16), (0, 15), (7, 1), 'P1',
                                  0, net["W1"].shape[1] - 1,
                                  self.perceptron1, 1, self.p1_changed)

        p1_opt_ax = plot_to_grid(fig, (24, 16), (8, 14), (3, 2))
        self.p1_opt_buttons = CheckButtons(p1_opt_ax, ["ReLU?", "Enabled?"],
                                           [self.is_relu, True])
        self.p1_opt_buttons.on_clicked(self.p1_options_update)

        # Plot 2 Controls
        w2_lo, w2_hi = self._padded_range(net["W2"], padding)
        # Bias range is symmetric around the current bias.
        w2b_abs = np.abs(net["b2"][0, 0]) + padding
        w2b_min = net["b2"][0, 0] - w2b_abs
        w2b_max = net["b2"][0, 0] + w2b_abs

        self.p2_dim_val_slid = add_slider((2, 16), (1, 12), (1, 1), 'p2 w',
                                          w2_lo, w2_hi, net["W2"][0, 0],
                                          step_size, self.p2_weight_changed)
        self.p2_connection_dim_slid = add_slider(
            (2, 16), (1, 13), (1, 1), 'p2 c',
            0, net["W2"].shape[0] - 1, 0, 1,
            self.p2_connection_dim_changed)
        self.p2b_slid = add_slider((24, 16), (13, 14), (7, 1), 'p2 b',
                                   w2b_min, w2b_max, net["b2"][0, 0],
                                   step_size, self.p2b_changed)

        p2_opt_ax = plot_to_grid(fig, (24, 16), (21, 14), (4, 2))
        self.p2_opt_buttons = CheckButtons(p2_opt_ax,
                                           ["Pre-add?", "Transform?"],
                                           [self.is_pre_add, self.is_sig])
        self.p2_opt_buttons.on_clicked(self.p2_options_update)

    def p1_changed(self, val):
        """Slider callback: select which first-layer perceptron is edited."""
        self.perceptron1 = int(val)
        # Syncing the widgets fires their callbacks; suppress the redraws.
        self.ignore_update = True
        self.update_widgets()
        self.ignore_update = False
        self.update_just_plot1()

    def p1x_changed(self, val):
        """Slider callback: set the x weight of the selected perceptron."""
        self.network["W1"][0, self.perceptron1] = val
        self.update_visuals()

    def p1y_changed(self, val):
        """Slider callback: set the y weight of the selected perceptron."""
        self.network["W1"][1, self.perceptron1] = val
        self.update_visuals()

    def p1b_changed(self, val):
        """Slider callback: set the bias of the selected perceptron."""
        self.network["b1"][0, self.perceptron1] = val
        self.update_visuals()

    def p1_options_update(self, label):
        """Check-box callback for the first-layer options."""
        if label == "ReLU?":
            self.is_relu = not self.is_relu
            self.update_just_plot1()
        elif label == "Enabled?":
            is_enabled = self.p1_opt_buttons.get_status()[1]
            currently_enabled = self.perceptron1 in self.all_p1_enabled
            if is_enabled and not currently_enabled:
                self.all_p1_enabled.add(self.perceptron1)
            elif not is_enabled and currently_enabled:
                # Index of this perceptron among the enabled ones.
                layer1_out = sorted(self.all_p1_enabled).index(
                    self.perceptron1)
                self.layer1_3d_plot.remove_plot(layer1_out)
                self.all_p1_enabled.remove(self.perceptron1)
            self.update_visuals()

    def p2_weight_changed(self, val):
        """Slider callback: set the selected second-layer weight."""
        self.network["W2"][self.connection, 0] = val
        self.update_just_plot2()

    def p2_connection_dim_changed(self, val):
        """Slider callback: select which second-layer connection is edited."""
        self.connection = int(val)
        self.ignore_update = True
        self.p2_dim_val_slid.set_val(self.network["W2"][self.connection, 0])
        self._mark_default(self.p2_dim_val_slid,
                           self.default_network["W2"][self.connection, 0])
        self.ignore_update = False

    def p2b_changed(self, val):
        """Slider callback: set the second-layer bias."""
        self.network["b2"][0, 0] = val
        self.update_just_plot2()

    def p2_options_update(self, label):
        """Check-box callback for the second-layer options."""
        if label == "Transform?":
            self.is_sig = not self.is_sig
        elif label == "Pre-add?":
            self.is_pre_add = not self.is_pre_add
        self.update_just_plot2()

    def show(self):
        """Display the figure."""
        plt.show()

    def update_plot1(self, out1, out2):
        """Refresh the 2-D first-layer plot, hiding it when disabled."""
        if self.perceptron1 not in self.all_p1_enabled:
            self.layer1_plot.set_visible(False)
            return
        self.layer1_plot.set_visible(True)
        layer1_out = sorted(self.all_p1_enabled).index(self.perceptron1)
        source = out2 if self.is_relu else out1
        self.layer1_plot.update(source[:, :, layer1_out])

    def update_3d_plot1(self, out1, out2):
        """Refresh the 3-D first-layer plot when the perceptron is enabled."""
        if self.perceptron1 in self.all_p1_enabled:
            self.layer1_3d_plot.update_all(out2 if self.is_relu else out1)

    def update_plot2(self, out2, out3, out4):
        """Refresh the 2-D second-layer plot for the current options."""
        if self.is_pre_add:
            layer2_data = scale_out2(out2, self.network["W2"],
                                     self.all_p1_enabled, self.perceptron2,
                                     self.precision)
            layer2_data = np.sum(layer2_data, axis=2)
        elif not self.is_sig:
            layer2_data = out3[:, :, 0]
        else:
            layer2_data = out4[:, :, 0]
        self.layer2_plot.update(layer2_data)

    def update_3d_plot2(self, out2, out3, out4):
        """Refresh the 3-D second-layer plot for the current options."""
        if self.is_pre_add:
            layer2_data = scale_out2(out2, self.network["W2"],
                                     self.all_p1_enabled, self.perceptron2,
                                     self.precision)
        elif not self.is_sig:
            layer2_data = out3
        else:
            layer2_data = out4
        self.layer2_3d_plot.update_all(layer2_data)

    def _forward(self):
        """Evaluate the network over the data space with enabled perceptrons."""
        return forward(self.data_space, self.network, self.dataset_type,
                       self.all_p1_enabled, self.precision)

    def update_visuals(self):
        """Recompute the network outputs and refresh every plot."""
        if self.ignore_update:
            return
        out1, out2, out3, out4 = self._forward()
        self.update_plot1_visuals(out1, out2)
        self.update_plot2_visuals(out2, out3, out4)
        plt.draw()

    def update_just_plot1(self):
        """Recompute the network outputs and refresh the first-layer plots."""
        if self.ignore_update:
            return
        out1, out2, _, _ = self._forward()
        self.update_plot1_visuals(out1, out2)
        plt.draw()

    def update_plot1_visuals(self, out1, out2):
        """Refresh both first-layer plots."""
        self.update_plot1(out1, out2)
        self.update_3d_plot1(out1, out2)

    def update_just_plot2(self):
        """Recompute the network outputs and refresh the second-layer plots."""
        if self.ignore_update:
            return
        _, out2, out3, out4 = self._forward()
        self.update_plot2_visuals(out2, out3, out4)
        plt.draw()

    def update_plot2_visuals(self, out2, out3, out4):
        """Refresh both second-layer plots."""
        self.update_plot2(out2, out3, out4)
        self.update_3d_plot2(out2, out3, out4)

    def update_widgets(self):
        """Sync the first-layer widgets with the selected perceptron."""
        p = self.perceptron1
        self.p1b_slid.set_val(self.network["b1"][0, p])
        self.p1x_slid.set_val(self.network["W1"][0, p])
        self.p1y_slid.set_val(self.network["W1"][1, p])

        self._mark_default(self.p1b_slid, self.default_network["b1"][0, p])
        self._mark_default(self.p1x_slid, self.default_network["W1"][0, p])
        self._mark_default(self.p1y_slid, self.default_network["W1"][1, p])

        # Toggle the "Enabled?" box when it disagrees with the actual state.
        enabled = p in self.all_p1_enabled
        if enabled != bool(self.p1_opt_buttons.get_status()[1]):
            self.p1_opt_buttons.set_active(1)
def genResaleFlatPriceGraph(p_yearEnter, p_monthEnter, p_leaseEnter,
                            p_flatType, p_labels):
    """Show an interactive box plot of resale prices per town.

    Reads ``data/resale-flat-prices.csv``, builds a figure with one check box
    per town, a "Year From" slider, a "Min Remain Lease Year" slider and a
    flat-type radio group, and blocks in ``plt.show()`` until the window is
    closed.  Pressing the mouse on a box shows the records within 5000 of the
    clicked price; releasing the button removes that annotation again.

    Args:
        p_yearEnter: four-digit start year as a string (it is concatenated
            with the month to build a ``'YYYY-MM'`` key).
        p_monthEnter: start month as a string; it is left-padded with ``"0"``
            and the last two characters are used.
        p_leaseEnter: minimum remaining lease; converted with ``int()``.
        p_flatType: value compared against the ``flat_type`` column.
        p_labels: accepted for backward compatibility; not used.

    Side effects:
        Sets the module globals ``txt``, ``flatType``, ``checkBoxLines``,
        ``axcolor``, ``townSelect``, ``towndata`` and ``index`` exactly as the
        previous implementation did, and prints debug information to stdout.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Slider, RadioButtons, CheckButtons
    from datetime import datetime, timedelta

    global txt, flatType, checkBoxLines
    txt = None
    yearEnter = p_yearEnter
    # Left-pad so the last two characters are always a two-digit month.
    monthNum = "0" + p_monthEnter
    flatType = p_flatType
    init_sLease = 50

    def mainTheme():
        """Apply the title / axis label and publish the widget colour."""
        global axcolor
        axcolor = 'lightgoldenrodyellow'
        ax.set_title("HDB Resale Flat Price", fontsize=30)
        ax.set_ylabel('Resale Price', fontsize=25)

    def onclick(event):
        """Annotate the records within +/-5000 of the clicked price."""
        global txt
        # Clicks on widgets or outside every axes carry no usable price
        # (xdata / ydata are None outside axes).
        if (event.inaxes is not ax or event.xdata is None
                or event.ydata is None):
            return
        nearestPrice = round(event.ydata, -3)
        if nearestPrice <= 0:
            return
        # Box positions are 1-based; townSelect is filled from index 0.
        boxIndex = int(round(event.xdata, 0) - 1)
        if not 0 <= boxIndex < len(townSelect):
            return
        townDet = townSelect[boxIndex]
        if not isinstance(townDet, np.ndarray):
            return  # slot is still the 0 placeholder: no town drawn there
        minPrice = nearestPrice - 5000
        maxPrice = nearestPrice + 5000
        pointSel = townDet[(townDet['resale_price'] >= minPrice)
                           & (townDet['resale_price'] < maxPrice)]
        if len(pointSel) > 0:
            txt = ax.text(event.xdata, event.ydata, str(pointSel),
                          horizontalalignment='center',
                          verticalalignment='center',
                          bbox=dict(facecolor='red', alpha=0.5))
            fig.canvas.draw()

    def offclick(event):
        """Remove the annotation created by ``onclick`` (if any)."""
        global txt
        if txt is not None:
            txt.remove()
            txt = None  # never try to remove the same artist twice
            fig.canvas.draw()

    def selectTownPrices(status):
        """Collect price arrays and names for every checked town.

        Rebuilds the global ``townSelect`` so that slot ``k`` holds the
        records of the ``k``-th checked town (the ``k+1``-th box).
        """
        global townSelect
        townSelect = np.zeros(len(townList), object)
        townName = []
        resale_prices_combined = []
        for s, checked in enumerate(status):
            if not checked:
                continue
            townRows = towndata[towndata['town'] == townList[s]]
            townSelect[len(townName)] = townRows
            townName.append(line_labels[s])
            resale_prices_combined.append(townRows['resale_price'])
        return resale_prices_combined, townName

    def genCombineResalePrices(yearEnter, monthNum, p_leaseEnter, flatType):
        """Filter the data set and return (price arrays, town names)."""
        global towndata
        # 'YYYY-MM' keys sort lexicographically, so one range mask replaces
        # the former month-by-month np.append loop.
        firstMonth = yearEnter + "-" + monthNum[-2:]
        lastMonth = datetime.now().strftime('%Y-%m')
        months = data['month']
        dataPeriod = data[(months >= firstMonth) & (months <= lastMonth)]

        # NOTE(review): keeps rows with remaining_lease > lease + 1, as the
        # original did; confirm whether "> lease" was intended.
        leaseStart = int(p_leaseEnter) + 1
        leaseData = dataPeriod[dataPeriod['remaining_lease'] > leaseStart]
        towndata = leaseData[leaseData['flat_type'] == flatType]

        # The hidden dummy lines mirror the check-box state.
        status = [line.get_visible() for line in checkBoxLines]
        return selectTownPrices(status)

    def checkButtonThemes(check):
        """Restyle the check-box rectangles."""
        for r in check.rectangles:
            r.set_alpha(0.8)
            r.set_width(0.025)
            r.set_edgecolor("k")

    def updateCheckBox():
        """Redraw the check boxes from the current line visibility."""
        rax.clear()
        currentLabels = [str(line.get_label()) for line in checkBoxLines]
        currentVisibility = [line.get_visible() for line in checkBoxLines]
        checkButtonThemes(
            CheckButtons(rax, currentLabels, currentVisibility))

    def boxplotTheme(bp_dict):
        """Colour the box plot and label medians and the top flier."""
        ## change outline color, fill color and linewidth of the boxes
        for box in bp_dict['boxes']:
            box.set(color='#7570b3', linewidth=2)
            box.set(facecolor='#1b9e77')
        ## change color and linewidth of the whiskers
        for whisker in bp_dict['whiskers']:
            whisker.set(color='#7570b3', linewidth=2)
        ## change color and linewidth of the caps
        for cap in bp_dict['caps']:
            cap.set(color='#7570b3', linewidth=2)
        ## change color and linewidth of the medians
        for median in bp_dict['medians']:
            median.set(color='#b2df8a', linewidth=2)
        ## change the style of fliers and their fill
        for flier in bp_dict['fliers']:
            flier.set(marker='D', color='#e7298a', alpha=0.5)
        print(bp_dict.keys())

        for line in bp_dict['medians']:
            # second vertex = right end of the median line
            x, y = line.get_xydata()[1]
            # overlay median resale_price
            ax.text(x, y, '%.1f' % y, horizontalalignment='center',
                    fontsize=15)

        for line in bp_dict['fliers']:
            ndarray = line.get_xydata()
            if len(ndarray) > 0:
                max_flier_index = ndarray[:, 1].argmax()
                max_flier = ndarray[max_flier_index, 1]
                x = ndarray[max_flier_index, 0]
                print("Flier: " + str(x) + "," + str(max_flier))
                ax.text(x, max_flier, '%.1f' % max_flier,
                        horizontalalignment='center', fontsize=15,
                        color='green')

    def drawBoxplot(resale_prices_combined, townName):
        """Clear the main axes and draw one box per selected town."""
        ax.clear()
        mainTheme()
        if resale_prices_combined:
            bp_dict = ax.boxplot(resale_prices_combined, labels=townName,
                                 patch_artist=True)
            boxplotTheme(bp_dict)

    def func(label):
        """Check-box callback: toggle a town and redraw."""
        global index
        status = check.get_status()
        index = line_labels.index(label)
        checkBoxLines[index].set_visible(
            not checkBoxLines[index].get_visible())
        updateCheckBox()
        resale_prices_combined, townName = selectTownPrices(status)
        drawBoxplot(resale_prices_combined, townName)
        print(status)
        plt.draw()

    def refresh():
        """Re-filter with the current slider / radio values and redraw."""
        uYear = str(sYear.val)
        resale_prices_combined, townName = genCombineResalePrices(
            uYear[0:4], "01", sLease.val, flatType)
        drawBoxplot(resale_prices_combined, townName)
        fig.canvas.draw_idle()

    def update(val):
        """Slider callback."""
        refresh()

    def flatTypefunc(label):
        """Radio-button callback: switch the flat type."""
        global flatType
        flatType = label
        refresh()

    ############## Main Process Start Here ###############
    data = np.genfromtxt('data/resale-flat-prices.csv',
                         skip_header=1,
                         dtype=[('month', 'U7'), ('town', 'U30'),
                                ('flat_type', 'U20'), ('block', 'U4'),
                                ('street_name', 'U50'),
                                ('storey_range', 'U15'),
                                ('floor_area_sqm', 'U3'),
                                ('flat_model', 'U25'),
                                ('lease_commence_date', 'U4'),
                                ('remaining_lease', 'i2'),
                                ('resale_price', 'f8')],
                         delimiter=",",
                         missing_values=['na', '-'],
                         filling_values=[0])

    townList = sorted(set(data['town']))

    years = [int(yearmonthRec[0:4]) for yearmonthRec in set(data['month'])]
    minsYear = min(years)
    maxsYear = max(years)

    # Start the year slider roughly eleven and a half months back.
    date11halfMonthAgo = datetime.today() - timedelta(days=350)
    init_sYear = date11halfMonthAgo.strftime('%Y')

    fig, ax = plt.subplots(figsize=(19, 15))
    plt.subplots_adjust(left=0.25, bottom=0.25)
    mainTheme()

    # Dummy, invisible lines: their visibility flag stores which towns are
    # selected so the box plot can be rebuilt from it later.
    checkBoxLines = np.zeros(len(townList), object)
    for l, town in enumerate(townList):
        checkBoxLines[l], = ax.plot(0, 0, visible=False, lw=2, label=town)

    # Check buttons, one per town, reflecting the line visibility.
    rax = plt.axes([0.05, 0.2, 0.13, 0.7], facecolor=axcolor)
    line_labels = [str(line.get_label()) for line in checkBoxLines]
    line_visibility = [line.get_visible() for line in checkBoxLines]
    check = CheckButtons(rax, line_labels, line_visibility)
    checkButtonThemes(check)
    check.on_clicked(func)
    print(check.get_status())

    axYear = plt.axes([0.5, 0.15, 0.4, 0.03], facecolor=axcolor)
    axLease = plt.axes([0.5, 0.1, 0.4, 0.03], facecolor=axcolor)
    sYear = Slider(axYear, 'Year From', minsYear, maxsYear,
                   valinit=int(init_sYear), valstep=1)
    sLease = Slider(axLease, 'Min Remain Lease Year', 1, 99,
                    valinit=int(init_sLease), valstep=1)

    radioax = plt.axes([0.25, 0.08, 0.11, 0.12], facecolor=axcolor)
    radio = RadioButtons(radioax,
                         ('1 ROOM', '2 ROOM', '3 ROOM', '4 ROOM', '5 ROOM',
                          'EXECUTIVE', 'MULTI-GENE'),
                         active=3)

    sYear.on_changed(update)
    sLease.on_changed(update)
    radio.on_clicked(flatTypefunc)
    fig.canvas.mpl_connect('button_press_event', onclick)
    fig.canvas.mpl_connect('button_release_event', offclick)

    resale_prices_combined, labels = genCombineResalePrices(
        yearEnter, monthNum, p_leaseEnter, flatType)
    if resale_prices_combined:
        bp_dict = ax.boxplot(resale_prices_combined, labels=labels,
                             patch_artist=True)
        boxplotTheme(bp_dict)

    plt.show()
class PeakParameterPlot(object):
    """Interactive figure for tuning peak-filter cutoffs with sliders.

    For every requested region pair ``plot`` draws four ``EdgeHicPlot``
    panels (unfiltered, filtered, ``oe_d`` and ``fdr_d``) and, below them,
    one slider per neighborhood (``d``, ``h``, ``v``, ``l``) for the O/E,
    FDR and mappability cutoffs, an "uncorrected" slider and a
    "Show loops" check button.  Moving a slider writes the new cutoff into
    the matching filter object and refreshes the filtered panels.

    Args:
        peaks: object handed to every ``EdgeHicPlot`` as its data source.
        font_size: value written to ``plt.rcParams['font.size']``.
        observed_init: initial value of the "uncorrected" cutoff.
        oe_*_init / fdr_*_init / mappability_*_init: initial cutoffs per
            neighborhood.
        *_slider_range / *_slider_step, observed_range / observed_step:
            ``(min, max)`` and step of the corresponding sliders.
        **kwargs: forwarded to the unfiltered and filtered ``EdgeHicPlot``.
    """

    # Neighborhood keys, in slider column order.
    _NEIGHBORHOODS = ('d', 'h', 'v', 'l')

    def __init__(self, peaks, font_size=5,
                 observed_init=1.,
                 oe_d_init=1., oe_h_init=1., oe_v_init=1., oe_l_init=1.,
                 fdr_d_init=.1, fdr_h_init=.1, fdr_v_init=.1, fdr_l_init=.1,
                 mappability_d_init=0., mappability_h_init=0.,
                 mappability_v_init=0., mappability_l_init=0.,
                 oe_slider_range=(0, 5), oe_slider_step=0.1,
                 fdr_slider_range=(0, .1), fdr_slider_step=0.001,
                 mappability_slider_range=(0, 1.),
                 mappability_slider_step=0.05,
                 observed_range=(0, 50), observed_step=1,
                 **kwargs):
        self.peaks = peaks
        self.font_size = font_size
        self.hic_args = kwargs

        self.observed_init = observed_init
        self.oe_init = {
            'd': oe_d_init,
            'v': oe_v_init,
            'h': oe_h_init,
            'l': oe_l_init,
        }
        self.fdr_init = {
            'd': fdr_d_init,
            'v': fdr_v_init,
            'h': fdr_h_init,
            'l': fdr_l_init,
        }
        self.mappability_init = {
            'd': mappability_d_init,
            'v': mappability_v_init,
            'h': mappability_h_init,
            'l': mappability_l_init,
        }

        self.fdr_range = fdr_slider_range
        self.fdr_step = fdr_slider_step
        self.oe_range = oe_slider_range
        self.oe_step = oe_slider_step
        self.mappability_range = mappability_slider_range
        self.mappability_step = mappability_slider_step
        self.observed_range = observed_range
        self.observed_step = observed_step

        # Current cutoffs start as copies so the *_init dicts stay untouched.
        self.observed_cutoff = self.observed_init
        self.oe_cutoffs = self.oe_init.copy()
        self.fdr_cutoffs = self.fdr_init.copy()
        self.mappability_cutoffs = self.mappability_init.copy()

        # Widgets, filters and plots are created by plot().
        self.ax_fdr_sliders = {}
        self.fdr_sliders = {}
        self.oe_sliders = {}
        self.ax_oe_sliders = {}
        self.ax_mappability_sliders = {}
        self.mappability_sliders = {}
        self.observed_slider = None

        self.observed_filter = None
        self.fdr_filter = None
        self.mappability_filter = None
        self.oe_filter = None

        self.filtered_plots = []
        self.hic_plots = []
        self.button = None
        self.region_pairs = []
        self.fig = None

    def _cutoff_kwargs(self, prefix, cutoffs):
        """Map ``{'d','h','v','l'}`` cutoffs to filter keyword names.

        The filters name the ``'l'`` neighborhood ``ll``, e.g.
        ``fdr_ll_cutoff``.  Keeping the mapping in one place guarantees that
        construction and later updates address the same attributes.
        """
        return {
            '{}_{}_cutoff'.format(prefix, 'll' if n == 'l' else n): cutoffs[n]
            for n in self._NEIGHBORHOODS
        }

    def _apply_cutoffs(self, peak_filter, prefix, cutoffs):
        """Write every neighborhood cutoff onto ``peak_filter``."""
        for name, value in self._cutoff_kwargs(prefix, cutoffs).items():
            setattr(peak_filter, name, value)

    def plot(self, *regions):
        """Build the figure for ``regions`` and return it.

        Each item of ``regions`` may be a region string, a ``GenomicRegion``,
        a ``(r1, r2)`` pair, a ``(r1, vmax)`` pair (second item numeric) or a
        ``(r1, r2, vmax)`` triple.
        """
        plt.rcParams.update({'font.size': self.font_size})

        # process region pairs
        self.region_pairs = []
        for pair in regions:
            if isinstance(pair, string_types):
                r = as_region(pair)
                pair = (r, r)
            elif isinstance(pair, GenomicRegion):
                pair = (pair, pair)

            try:
                r1, r2, vmax = pair
            except ValueError:
                r1, r2 = pair
                if isinstance(r2, (float, int)):
                    # (region, vmax): same region on both axes
                    vmax = r2
                    r2 = r1
                else:
                    vmax = None
            r1 = as_region(r1)
            r2 = as_region(r2)
            self.region_pairs.append((r1, r2, vmax))

        n_pairs = len(self.region_pairs)

        # One tall row per region pair, then O/E, FDR, mappability slider
        # rows and a final row for the observed slider and check button.
        gs = grd.GridSpec(n_pairs + 4, 4,
                          height_ratios=[10] * n_pairs + [1, 1, 1, 3],
                          wspace=0.3, hspace=0.5)
        self.fig = plt.figure(figsize=(10, n_pairs * 2 + 2), dpi=150)

        # sliders
        inner_observed_gs = grd.GridSpecFromSubplotSpec(
            3, 1, subplot_spec=gs[n_pairs + 3, 0], wspace=0.0, hspace=0.0)
        ax_observed_slider = plt.subplot(inner_observed_gs[0, 0])
        self.observed_slider = Slider(ax_observed_slider, 'uncorrected',
                                      self.observed_range[0],
                                      self.observed_range[1],
                                      valinit=self.observed_init,
                                      valstep=self.observed_step)

        self.ax_oe_sliders = dict()
        self.oe_sliders = dict()
        self.ax_fdr_sliders = dict()
        self.fdr_sliders = dict()
        self.ax_mappability_sliders = dict()
        self.mappability_sliders = dict()
        for i, neighborhood in enumerate(self._NEIGHBORHOODS):
            # O/E
            self.ax_oe_sliders[neighborhood] = plt.subplot(gs[n_pairs + 0, i])
            self.oe_sliders[neighborhood] = Slider(
                self.ax_oe_sliders[neighborhood],
                'O/E {}'.format(neighborhood.upper()),
                self.oe_range[0], self.oe_range[1],
                valinit=self.oe_init[neighborhood],
                valstep=self.oe_step)

            # FDR
            self.ax_fdr_sliders[neighborhood] = plt.subplot(gs[n_pairs + 1, i])
            self.fdr_sliders[neighborhood] = Slider(
                self.ax_fdr_sliders[neighborhood],
                'FDR {}'.format(neighborhood.upper()),
                self.fdr_range[0], self.fdr_range[1],
                valinit=self.fdr_init[neighborhood],
                valstep=self.fdr_step)

            # Mappability
            self.ax_mappability_sliders[neighborhood] = plt.subplot(
                gs[n_pairs + 2, i])
            self.mappability_sliders[neighborhood] = Slider(
                self.ax_mappability_sliders[neighborhood],
                'Map {}'.format(neighborhood.upper()),
                self.mappability_range[0], self.mappability_range[1],
                valinit=self.mappability_init[neighborhood],
                valstep=self.mappability_step)

        # check button
        inner_button_gs = grd.GridSpecFromSubplotSpec(
            1, 2, subplot_spec=gs[n_pairs + 3, 1], wspace=0.0, hspace=0.0)
        ax_button = plt.subplot(inner_button_gs[0, 0])
        self.button = CheckButtons(ax_button, ['Show loops'], [False])

        # filters -- every neighborhood gets its own initial cutoff
        # (the FDR filter previously received h/v/l values crossed over).
        self.observed_filter = ObservedPeakFilter(cutoff=self.observed_init)
        self.oe_filter = EnrichmentPeakFilter(
            **self._cutoff_kwargs('enrichment', self.oe_init))
        self.fdr_filter = FdrPeakFilter(
            **self._cutoff_kwargs('fdr', self.fdr_init))
        self.mappability_filter = MappabilityPeakFilter(
            **self._cutoff_kwargs('mappability', self.mappability_init))

        self.hic_plots = []
        self.filtered_plots = []
        for i, (r1, r2, vmax) in enumerate(self.region_pairs):
            ax_hic = plt.subplot(gs[i, 0])
            ax_filtered = plt.subplot(gs[i, 1])
            ax_oe = plt.subplot(gs[i, 2])
            ax_fdr = plt.subplot(gs[i, 3])

            hic_plot = EdgeHicPlot(self.peaks, ax=ax_hic,
                                   show_colorbar=False, adjust_range=False,
                                   unmappable_color='white', vmax=vmax,
                                   highlight_edges=True, highlight_limit=200,
                                   draw_tick_legend=False, **self.hic_args)
            filtered_plot = EdgeHicPlot(self.peaks, ax=ax_filtered,
                                        show_colorbar=False,
                                        adjust_range=False,
                                        unmappable_color='white', vmax=vmax,
                                        draw_tick_legend=False,
                                        **self.hic_args)
            oe_plot = EdgeHicPlot(self.peaks, plot_field='oe_d',
                                  default_value=0, norm='lin', ax=ax_oe,
                                  show_colorbar=False, colormap='white_red',
                                  adjust_range=False, vmin=1, vmax=4,
                                  log2=False, unmappable_color='white',
                                  draw_tick_legend=False)
            fdr_plot = EdgeHicPlot(self.peaks, plot_field='fdr_d',
                                   default_value=1, norm='lin', ax=ax_fdr,
                                   show_colorbar=False, adjust_range=False,
                                   vmin=0, vmax=0.05, colormap='Greys_r',
                                   unmappable_color='white',
                                   draw_tick_legend=False)

            # Only the "filtered" panel is affected by the sliders.
            filtered_plot.hic_buffer.add_filter(self.observed_filter)
            filtered_plot.hic_buffer.add_filter(self.oe_filter)
            filtered_plot.hic_buffer.add_filter(self.fdr_filter)
            filtered_plot.hic_buffer.add_filter(self.mappability_filter)

            self.hic_plots.append(hic_plot)
            self.filtered_plots.append(filtered_plot)

            hic_plot.plot((r1, r2))
            filtered_plot.plot((r1, r2))
            oe_plot.plot((r1, r2))
            fdr_plot.plot((r1, r2))

            ax_filtered.set_yticklabels([])
            ax_oe.set_yticklabels([])
            ax_fdr.set_yticklabels([])

        for neighborhood in self._NEIGHBORHOODS:
            self.oe_sliders[neighborhood].on_changed(self.update_oe_filter)
            self.fdr_sliders[neighborhood].on_changed(self.update_fdr_filter)
            self.mappability_sliders[neighborhood].on_changed(
                self.update_mappability_filter)
        self.observed_slider.on_changed(self.update_observed_filter)
        self.button.on_clicked(self.refresh_plots)

        self.refresh_plots()

        return self.fig

    def refresh_plots(self, event=None):
        """Re-apply the filters and redraw; also the check-button callback."""
        show_loops = self.button.get_status()[0]
        for i, (r1, r2, vmax) in enumerate(self.region_pairs):
            self.filtered_plots[i].refresh((r1, r2))
            if show_loops:
                self.hic_plots[i].update_highlights(
                    self.filtered_plots[i].hic_buffer._last_matrix)
            else:
                self.hic_plots[i].update_highlights()
        self.fig.canvas.draw()

    def update_observed_filter(self, event):
        """Slider callback: copy the "uncorrected" cutoff to its filter."""
        self.observed_cutoff = self.observed_slider.val
        self.observed_filter.cutoff = self.observed_cutoff
        logger.info("Observed cutoff set to {}".format(self.observed_cutoff))
        self.refresh_plots()

    def update_oe_filter(self, event):
        """Slider callback: copy all O/E cutoffs to the enrichment filter."""
        self.oe_cutoffs = {
            n: self.oe_sliders[n].val for n in self._NEIGHBORHOODS
        }
        self._apply_cutoffs(self.oe_filter, 'enrichment', self.oe_cutoffs)
        logger.info("O/E cutoffs set to: {}".format(self.oe_cutoffs))
        self.refresh_plots()

    def update_fdr_filter(self, event):
        """Slider callback: copy all FDR cutoffs to the FDR filter."""
        self.fdr_cutoffs = {
            n: self.fdr_sliders[n].val for n in self._NEIGHBORHOODS
        }
        self._apply_cutoffs(self.fdr_filter, 'fdr', self.fdr_cutoffs)
        logger.info("FDR cutoffs set to: {}".format(self.fdr_cutoffs))
        self.refresh_plots()

    def update_mappability_filter(self, event):
        """Slider callback: copy all mappability cutoffs to their filter."""
        self.mappability_cutoffs = {
            n: self.mappability_sliders[n].val for n in self._NEIGHBORHOODS
        }
        self._apply_cutoffs(self.mappability_filter, 'mappability',
                            self.mappability_cutoffs)
        logger.info("Mappability cutoffs set to: {}".format(
            self.mappability_cutoffs))
        self.refresh_plots()
class plt_one_addpt_onclick:
    """Interactive single-feature regression plot.

    Shows the data as a scatter plot, lets the user add points by clicking,
    runs linear or logistic regression from a button and can overlay the
    0.5 decision threshold via a check box.

    Args:
        x: 1-D feature values.
        y: labels; entries equal to 1 are drawn as "malignant", entries
            equal to 0 as "benign".
        w: initial weight(s); deep-copied so the caller's object is kept.
        b: initial bias.
        logistic: ``True`` for logistic regression, ``False`` for linear.
    """

    def __init__(self, x, y, w, b, logistic=True):
        self.logistic = logistic
        pos = y == 1
        neg = y == 0

        fig, ax = plt.subplots(1, 1, figsize=(8, 4))
        fig.canvas.toolbar_visible = False
        fig.canvas.header_visible = False
        fig.canvas.footer_visible = False

        plt.subplots_adjust(bottom=0.25)  # room for the widgets below
        ax.scatter(x[pos], y[pos], marker='x', s=80, c='red',
                   label="malignant")
        ax.scatter(x[neg], y[neg], marker='o', s=100, label="benign",
                   facecolors='none', edgecolors=dlblue, lw=3)
        ax.set_ylim(-0.05, 1.1)
        xlim = ax.get_xlim()
        ax.set_xlim(xlim[0], xlim[1] * 2)  # leave space for added points
        ax.set_ylabel('y')
        ax.set_xlabel('Tumor Size')
        self.alegend = ax.legend(loc='lower right')
        if self.logistic:
            ax.set_title("Example of Logistic Regression on Categorical Data")
        else:
            ax.set_title("Example of Linear Regression on Categorical Data")

        ax.text(0.65, 0.8, "[Click to add data points]", size=10,
                transform=ax.transAxes)

        axcalc = plt.axes([0.1, 0.05, 0.38, 0.075])  #l,b,w,h
        axthresh = plt.axes([0.5, 0.05, 0.38, 0.075])  #l,b,w,h
        self.tlist = []  # artists of the threshold overlay

        self.fig = fig
        self.ax = [ax, axcalc, axthresh]
        self.x = x
        self.y = y
        self.w = copy.deepcopy(w)
        self.b = b
        f_wb = np.matmul(self.x.reshape(-1, 1), self.w) + self.b
        if self.logistic:
            self.aline = self.ax[0].plot(self.x, sigmoid(f_wb), color=dlblue)
            self.bline = self.ax[0].plot(self.x, f_wb, color=dlorange, lw=1)
        else:
            # Linear mode shows the raw linear prediction, matching what
            # calc_linear draws (it was passed through sigmoid before).
            self.aline = self.ax[0].plot(self.x, f_wb, color=dlblue)

        self.cid = fig.canvas.mpl_connect('button_press_event', self.add_data)
        if self.logistic:
            self.bcalc = Button(axcalc, 'Run Logistic Regression (click)',
                                color=dlblue)
            self.bcalc.on_clicked(self.calc_logistic)
        else:
            self.bcalc = Button(axcalc, 'Run Linear Regression (click)',
                                color=dlblue)
            self.bcalc.on_clicked(self.calc_linear)
        self.bthresh = CheckButtons(
            axthresh, ('Toggle 0.5 threshold (after regression)', ))
        self.bthresh.on_clicked(self.thresh)
        self.resize_sq(self.bthresh)

    def add_data(self, event):
        """Mouse callback: add a point with label 1 if y > 0.5, else 0."""
        if event.inaxes == self.ax[0]:
            x_coord = event.xdata
            y_coord = event.ydata

            if y_coord > 0.5:
                self.ax[0].scatter(x_coord, 1, marker='x', s=80, c='red')
                self.y = np.append(self.y, 1)
            else:
                self.ax[0].scatter(x_coord, 0, marker='o', s=100,
                                   facecolors='none', edgecolors=dlblue,
                                   lw=3)
                self.y = np.append(self.y, 0)
            self.x = np.append(self.x, x_coord)
        self.fig.canvas.draw()

    def calc_linear(self, event):
        """Button callback: run gradient descent and animate the line."""
        if self.bthresh.get_status()[0]:
            self.remove_thresh()
        # Growing iteration counts animate the fit converging.
        for it in [1, 1, 1, 1, 1, 2, 4, 8, 16, 32, 64, 128, 256]:
            self.w, self.b, _ = gradient_descent(self.x.reshape(-1, 1),
                                                 self.y.reshape(-1, 1),
                                                 self.w.reshape(-1, 1),
                                                 self.b, 0.01, it,
                                                 logistic=False, lambda_=0,
                                                 verbose=False)
            self.aline[0].remove()
            self.alegend.remove()
            y_hat = np.matmul(self.x.reshape(-1, 1), self.w) + self.b
            self.aline = self.ax[0].plot(
                self.x, y_hat, color=dlblue,
                label=f"y = {np.squeeze(self.w):0.2f}x+({self.b:0.2f})")
            self.alegend = self.ax[0].legend(loc='lower right')
            time.sleep(0.3)
            self.fig.canvas.draw()
        if self.bthresh.get_status()[0]:
            self.draw_thresh()
            self.fig.canvas.draw()

    def calc_logistic(self, event):
        """Button callback: run logistic regression and animate the curves."""
        if self.bthresh.get_status()[0]:
            self.remove_thresh()
        for it in [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
            self.w, self.b, _ = gradient_descent(self.x.reshape(-1, 1),
                                                 self.y.reshape(-1, 1),
                                                 self.w.reshape(-1, 1),
                                                 self.b, 0.1, it,
                                                 logistic=True, lambda_=0,
                                                 verbose=False)
            self.aline[0].remove()
            self.bline[0].remove()
            self.alegend.remove()
            xlim = self.ax[0].get_xlim()
            x_hat = np.linspace(*xlim, 30)
            f_wb = np.matmul(x_hat.reshape(-1, 1), self.w) + self.b
            y_hat = sigmoid(f_wb)
            self.aline = self.ax[0].plot(x_hat, y_hat, color=dlblue,
                                         label="y = sigmoid(z)")
            self.bline = self.ax[0].plot(
                x_hat, f_wb, color=dlorange, lw=1,
                label=f"z = {np.squeeze(self.w):0.2f}x+({self.b:0.2f})")
            self.alegend = self.ax[0].legend(loc='lower right')
            time.sleep(0.3)
            self.fig.canvas.draw()
        if self.bthresh.get_status()[0]:
            self.draw_thresh()
            self.fig.canvas.draw()

    def thresh(self, event):
        """Check-box callback: show or hide the threshold overlay."""
        if self.bthresh.get_status()[0]:
            self.draw_thresh()
        else:
            self.remove_thresh()

    def draw_thresh(self):
        """Shade both sides of the x position where the model crosses 0.5."""
        ws = np.squeeze(self.w)
        if np.any(ws == 0):
            # Zero slope (e.g. before any regression): no crossing point
            # exists and the division below would yield inf/nan.
            self.tlist = []
            return
        xp5 = -self.b / ws if self.logistic else (0.5 - self.b) / ws
        ylim = self.ax[0].get_ylim()
        xlim = self.ax[0].get_xlim()
        a = self.ax[0].fill_between([xlim[0], xp5], [ylim[1], ylim[1]],
                                    alpha=0.2, color=dlblue)
        b = self.ax[0].fill_between([xp5, xlim[1]], [ylim[1], ylim[1]],
                                    alpha=0.2, color=dldarkred)
        c = self.ax[0].annotate("Malignant", xy=[xp5, 0.5], xycoords='data',
                                xytext=[30, 5], textcoords='offset points')
        d = FancyArrowPatch(
            posA=(xp5, 0.5),
            posB=(xp5 + 1.5, 0.5),
            color=dldarkred,
            arrowstyle='simple, head_width=5, head_length=10, tail_width=0.0',
        )
        self.ax[0].add_artist(d)

        e = self.ax[0].annotate("Benign", xy=[xp5, 0.5], xycoords='data',
                                xytext=[-70, 5], textcoords='offset points',
                                ha='left')
        f = FancyArrowPatch(
            posA=(xp5, 0.5),
            posB=(xp5 - 1.5, 0.5),
            color=dlblue,
            arrowstyle='simple, head_width=5, head_length=10, tail_width=0.0',
        )
        self.ax[0].add_artist(f)
        self.tlist = [a, b, c, d, e, f]

        self.fig.canvas.draw()

    def remove_thresh(self):
        """Remove the threshold overlay artists and redraw."""
        for artist in self.tlist:
            artist.remove()
        # Forget removed artists so a second call cannot remove them again.
        self.tlist = []
        self.fig.canvas.draw()

    def resize_sq(self, bcid):
        """Triple the height of the first check box and refit its cross."""
        h = bcid.rectangles[0].get_height()
        bcid.rectangles[0].set_height(3 * h)

        ymax = bcid.rectangles[0].get_bbox().y1
        ymin = bcid.rectangles[0].get_bbox().y0

        # The two diagonal lines of the check mark span the new box height.
        bcid.lines[0][0].set_ydata([ymax, ymin])
        bcid.lines[0][1].set_ydata([ymin, ymax])
class TypogaDraw:
    """Check-box driven viewer for the Typoga high-score files.

    One check-box group selects which game modes to show (one figure per
    mode), a second group selects which score columns to plot as bar charts.
    """

    title = "Typoga Scores"
    randFilePath = "scores/highScore_random.txt"
    leadFilePath = "scores/highScore_phrases.txt"
    progFilePath = "scores/highScore_programming.txt"

    def __init__(self):
        # Set Window Title and Size
        plt.figure(num=self.title, figsize=(3, 6))

    def checkForScoreFiles(self):
        """Exit if no high-score file exists, else reset widget state."""
        scoreFiles = (self.randFilePath, self.leadFilePath, self.progFilePath)
        if not any(os.path.exists(path) for path in scoreFiles):
            print("You should first play the game.\n Run ./typoga.sh")
            raise SystemExit(0)

        self.checkButGO = None
        # Was a second assignment to checkButGO; the score-option widget
        # is the one that was left uninitialized.
        self.checkButSO = None
        self.pos = 1

    def checkBoxesInit(self):
        """Initialize check boxes and click events."""
        # Check which score files exist.
        self.checkForScoreFiles()

        # Initialize CheckBox Menu for Game Options
        raxGO = plt.axes([0.02, 0.5, 1, 0.5], frameon=False)
        self.checkButGO = CheckButtons(raxGO, GameOptions,
                                       (False, False, False))
        self.checkButGO.on_clicked(self.handleUserSelection)

        # Initialize CheckBox Menu for Score Options
        raxSO = plt.axes([0.02, 0.05, 1, 0.5], frameon=False)
        self.checkButSO = CheckButtons(
            raxSO, SocreOptions,
            (False, False, False, False, False, False, False, False))
        self.checkButSO.on_clicked(self.handleUserSelection)

    def handleUserSelection(self, label):
        """Redraw or close each game-mode figure after any check-box click."""
        scoreFiles = (self.randFilePath, self.leadFilePath, self.progFilePath)
        for optionIndex, path in enumerate(scoreFiles):
            # plt.figure(num=...) creates the figure or makes it current.
            fig = plt.figure(num=GameOptions[optionIndex], figsize=(6, 8))
            if self.checkButGO.get_status()[optionIndex]:
                fig.clear()
                self.parseFile(path)
                fig.show()
            else:
                plt.close(fig)

    @staticmethod
    def _readScoreRows(pfile):
        """Return the ':'-separated fields of each score line in ``pfile``.

        The first three lines are skipped as header and blank lines are
        ignored.  The file is closed before returning.
        """
        with open(pfile, "r") as hsf:
            lines = hsf.readlines()
        rows = []
        for scoreLine in lines[3:]:  # skip header
            scoreLine = scoreLine.strip()  # remove \n
            if scoreLine:
                rows.append(scoreLine.split(":"))
        return rows

    def parseFile(self, pfile):
        """Parse ``pfile`` and plot every selected score column."""
        if not os.path.exists(pfile):
            print("File %s does not exist" % pfile)
            return

        rows = self._readScoreRows(pfile)

        # Reset plot Position
        self.pos = 1

        for i, selected in enumerate(self.checkButSO.get_status()):
            if selected:
                # Field 0 is not a score; option i lives in field i + 1.
                scores = [float(row[i + 1]) for row in rows]
                self.plotData(scores, i)

    def plotData(self, scores, index):
        """Draw one bar-chart subplot for the score option ``index``."""
        scores = list(scores)  # work on a copy; a padding value is appended
        if len(scores) == 0:
            scores.append(0)

        # Calculate mean value before appending 0 value at the end
        mean = np.mean(scores)

        # Append zero to get space to print mean value
        scores.append(0)

        x = np.arange(1, len(scores) + 1, 1)

        # Divide plot based on user selection
        rows = self.checkButSO.get_status().count(True)
        plt.subplot(rows, 1, self.pos)
        self.pos += 1

        plt.ylabel(SocreOptions[index])
        plt.xticks(x)

        # Draw bar plot
        plt.bar(x, scores)

        # Add text legends to each bar (not to the appended padding zero)
        for i, score in enumerate(scores[:-1]):
            plt.text(i + 0.8, score / 2, ("%.02f" % score), fontsize=8,
                     color='black')

        # Draw orange bar for the highest value
        max_val = np.argmax(scores)
        plt.bar(max_val + 1, scores[max_val], color='orange')

        # Draw mean line
        plt.plot([0, len(scores)], [mean, mean], 'black')
        plt.text(len(scores), mean, ("%.02f" % mean), fontsize=10,
                 color='black')

    def run(self):
        """Build the check boxes and enter the matplotlib main loop."""
        self.checkBoxesInit()
        #show all
        plt.show()
class LoaderUI(object):
    """Modal matplotlib dialog for opening a file and selecting rows.

    The constructor builds the figure and blocks in ``plt.show`` until the
    window is closed.  Afterwards ``success`` tells whether "Load Selected"
    completed; ``spe_file``, ``selected_data`` and ``pol_angle`` then hold
    the loaded file, the chosen row range and the entered angle.
    """

    def __init__(self):
        self.full_sensor_data = None
        self.selected_data = None
        self.spe_file = None
        self.selector = None
        self.pol_angle = None
        self.success = False

        # The style sheet lives next to this module.
        style_dir = os.path.dirname(__file__)
        filename = os.path.join(style_dir, 'style', 'custom-wa.mplstyle')
        plt.style.use(filename)

        fig = plt.figure()
        fig.set_size_inches(20, 12, forward=True)
        fig.canvas.set_window_title('Load Data')

        grid_shape = (16, 28)

        # Image panels
        self.full_sensor_ax = plt.subplot2grid(grid_shape, (0, 0),
                                               colspan=13, rowspan=13)
        self.selected_ax = plt.subplot2grid(grid_shape, (0, 15),
                                            colspan=13, rowspan=4)

        # Widget axes
        axopen = plt.subplot2grid(grid_shape, (14, 4), colspan=4)
        axload = plt.subplot2grid(grid_shape, (14, 20), colspan=4)
        axfull_lambda = plt.subplot2grid(grid_shape, (11, 16), colspan=4)
        axpix_min = plt.subplot2grid(grid_shape, (7, 16), colspan=4)
        axpix_max = plt.subplot2grid(grid_shape, (7, 23), colspan=4)
        axrefresh = plt.subplot2grid(grid_shape, (11, 23), colspan=4)
        axpol_angle = plt.subplot2grid(grid_shape, (9, 23), colspan=4)

        bload = Button(axload, 'Load Selected', color='0.25',
                       hovercolor='0.3')
        bload.on_clicked(self._load_callback)

        bopen = Button(axopen, 'Open File', color='0.25', hovercolor='0.3')
        bopen.on_clicked(self._open_callback)

        self.chk_full_lambda = CheckButtons(axfull_lambda, ['Full Lambda'],
                                            [True])
        self.chk_full_lambda.on_clicked(self._full_lambda_callback)

        self.ypix_min = TextBox(axpix_min, 'Sel. Lower \nbound ', '0',
                                color='0.25', hovercolor='0.3')
        self.ypix_max = TextBox(axpix_max, 'Sel. Upper \nbound ', '0',
                                color='0.25', hovercolor='0.3')

        self.refresh_selection = Button(axrefresh, 'Refresh Selection',
                                        color='0.25', hovercolor='0.3')
        self.refresh_selection.on_clicked(self._refresh_selection_callback)

        self.txt_pol_angle = TextBox(axpol_angle, 'Pol. Angle \n (deg) ',
                                     '0', color='0.25', hovercolor='0.3')

        # Install the selector matching the initial check-box state.
        self._full_lambda_callback(None)

        plt.show(block=True)

    def _rect_select_callback(self, eclick, erelease):
        """Rectangle-selector callback; the selection is not used yet."""
        return None

    def _span_select_callback(self, ymin, ymax):
        """Select rows ``floor(ymin)``..``ceil(ymax)`` of the loaded image."""
        if self.full_sensor_data is None:
            # Nothing loaded yet: there is nothing to slice.
            return
        ymin = int(np.floor(ymin))
        ymax = int(np.ceil(ymax))
        self.selected_data = self.full_sensor_data[ymin:ymax + 1, :]
        self.ypix_min.set_val(str(ymin))
        self.ypix_max.set_val(str(ymax))
        self._image_selected_data(ymin, ymax)

    def _load_callback(self, event):
        """Accept the current selection and close the dialog."""
        if self.spe_file is None or self.selected_data is None:
            print(
                'Invalid loader state: must have loaded file and selected data'
            )
            return
        try:
            pol_angle = float(self.txt_pol_angle.text)
        except ValueError:
            # Keep the dialog open instead of raising inside the GUI loop.
            print('Invalid polarization angle: must be a number')
            return
        self.pol_angle = pol_angle
        self.success = True
        plt.close()
        mpl.rcParams.update(mpl.rcParamsDefault)

    def _open_callback(self, event):
        """Ask for a file, load it and display its first frame."""
        # A hidden Tk root is required for the file dialog.
        root = tk.Tk()
        root.withdraw()
        try:
            filename = filedialog.askopenfilename()
        finally:
            root.destroy()

        # Cancelling returns an empty value; compare by truthiness, not by
        # identity with a string literal.
        if not filename:
            return
        self.spe_file = spe_loader.load_from_files([filename])
        self.full_sensor_data = self.spe_file.data[0][0]
        self._image_full_sensor_data()

    def _full_lambda_callback(self, event):
        """Swap the selector tool according to the "Full Lambda" box."""
        if self.selector is not None:
            self.selector.set_visible(False)

        chk_status = self.chk_full_lambda.get_status()
        if chk_status[0]:
            # Full rows: vertical span selection only.
            self.selector = SpanSelector(self.full_sensor_ax,
                                         self._span_select_callback,
                                         direction='vertical',
                                         minspan=5,
                                         useblit=True,
                                         span_stays=False,
                                         rectprops=dict(facecolor='red',
                                                        alpha=0.2))
        else:
            self.selector = RectangleSelector(
                self.full_sensor_ax,
                self._rect_select_callback,
                drawtype='box',
                useblit=True,
                button=[1, 3],  # don't use middle button
                minspanx=1,
                minspany=0.1,
                spancoords='data',
                interactive=True)

    def _refresh_selection_callback(self, event):
        """Re-apply the selection typed into the two bound text boxes."""
        try:
            ymin = int(self.ypix_min.text)
            ymax = int(self.ypix_max.text)
        except ValueError:
            print('Invalid selection bounds: must be integers')
            return
        self._span_select_callback(ymin, ymax)

    def _image_full_sensor_data(self):
        """Show the full frame and re-create the selector on the axes."""
        self.full_sensor_ax.clear()
        self._full_lambda_callback(None)
        self.full_sensor_ax.imshow(self.full_sensor_data, aspect='auto')
        self.full_sensor_ax.set_title('Source Data')

    def _image_selected_data(self, ymin, ymax):
        """Show the selected rows with a summary title."""
        self.selected_ax.clear()
        self.selected_ax.imshow(self.selected_data, aspect='auto')
        title = 'Selected: {0} rows ({1} -> {2})\n' \
                '{3} wavelengths ({4:.3f} -> {5:.3f})'.format(
                    self.selected_data.shape[0], ymin, ymax,
                    self.selected_data.shape[1],
                    self.spe_file.wavelength[0],
                    self.spe_file.wavelength[-1])
        self.selected_ax.set_title(title)
class Curator:
    """
    Interactive matplotlib viewer for curating extracted ROI threads.

    Shows the full frame, a zoomed window around the current ROI and the
    normalised time series of the current thread, together with widgets to
    scrub through time, adjust the display range, step between threads and
    label each thread as 'keep' or 'trash'. Labels are stored in a JSON file
    under ``e.root + 'extractor-objects/curate.json'``.

    Parameters
    ----------
    e : extractor
        extractor object containing a full set of infilled threads and
        time series. The attributes ``spool``, ``timeseries``, ``im``,
        ``root`` and ``t`` are read from it.
    window : int
        Size of the window around the ROI shown in the sub image.

    Attributes
    ----------
    ind : int
        thread indexing
    t : int
        time point currently displayed
    min : int
        min of image data (for setting ranges)
    max : int
        max of image data (for setting ranges)
    curate : dict
        Maps ``str(thread index)`` to 'seen', 'keep' or 'trash'; the key
        'last' stores the index that was displayed when the log was written.
    pointstate : int
        0: show only the current ROI, 1: ROIs in the same Z, 2: all ROIs.
    show_settings : int
        0: all, 1: unlabelled, 2: kept, 3: trashed threads are visited by
        next/previous.
    showmip : int
        0: single Z plane, 1: maximum intensity projection.
    """

    def __init__(self, e, window=100):
        # get info from extractors
        self.s = e.spool
        self.timeseries = e.timeseries
        self.tf = e.im
        self.tf.t = 0
        self.window = window

        ## num neurons
        self.numneurons = len(self.s.threads)

        self.path = e.root + 'extractor-objects/curate.json'
        self.ind = 0
        try:
            with open(self.path) as f:
                self.curate = json.load(f)
            self.ind = int(self.curate['last'])
        except (OSError, ValueError, KeyError, TypeError):
            # Missing/unreadable file, malformed JSON (JSONDecodeError is a
            # ValueError) or no usable 'last' entry: start a fresh log.
            self.curate = {}
            self.ind = 0
            self.curate['0'] = 'seen'

        # internal state: whether to display single ROI, ROI in Z, or all ROIs
        self.pointstate = 0
        self.show_settings = 0
        self.showmip = 0

        ## index for which time point to display
        self.t = 0

        ### First frame of the first thread
        self.update_im()

        ## Display range
        self.min = np.min(self.im)
        self.max = np.max(self.im)

        ## maximum t
        self.tmax = e.t

        self.restart()
        atexit.register(self.log_curate)

    # ------------------------------------------------------------------
    # small private helpers shared by restart/update_* methods
    # ------------------------------------------------------------------
    def _position(self):
        """Return the position of the current thread at the current time."""
        return self.s.threads[self.ind].get_position_t(self.t)

    def _neighbour_positions(self):
        """Return the positions to draw in blue, or None when pointstate is 0."""
        if self.pointstate == 1:
            return self.s.get_positions_t_z(self.t, self._position()[0])
        if self.pointstate == 2:
            return self.s.get_positions_t(self.t)
        return None

    def _title_text(self):
        """Return the figure title for the current thread and Z plane."""
        return ('Series=' + str(self.ind) + ', Z=' +
                str(int(self._position()[0])))

    def _normalised_trace(self):
        """Return the current thread's time series rescaled by its min/max."""
        trace = self.timeseries[:, self.ind]
        lowest = np.min(trace)
        return (trace - lowest) / (np.max(trace) - lowest)

    def restart(self):
        """Build the figure, all subplots and all widgets, then show it."""
        ## Figure to display
        self.fig = plt.figure()

        ## grid object for complicated subplot handing
        self.grid = plt.GridSpec(4, 2, wspace=0.1, hspace=0.2)

        ### First subplot: whole image with red dot over ROI
        self.ax1 = plt.subplot(self.grid[:3, 0])
        plt.subplots_adjust(bottom=0.4)
        self.img1 = self.ax1.imshow(self.get_im_display(),
                                    cmap='gray',
                                    vmin=0,
                                    vmax=1)

        # plotting for multiple points (nothing drawn when pointstate == 0)
        positions = self._neighbour_positions()
        if positions is not None:
            self.point1 = self.ax1.scatter(positions[:, 2],
                                           positions[:, 1],
                                           c='b',
                                           s=10)
        self.thispoint = self.ax1.scatter(self._position()[2],
                                          self._position()[1],
                                          c='r',
                                          s=10)
        plt.axis('off')

        ### Second subplot: some window around the ROI
        plt.subplot(self.grid[:3, 1])
        plt.subplots_adjust(bottom=0.4)
        self.subim, self.offset = subaxis(self.im, self._position(),
                                          self.window)
        self.img2 = plt.imshow(self.get_subim_display(),
                               cmap='gray',
                               vmin=0,
                               vmax=1)
        self.point2 = plt.scatter(self.window / 2 + self.offset[0],
                                  self.window / 2 + self.offset[1],
                                  c='r',
                                  s=40)
        self.title = self.fig.suptitle(self._title_text())
        plt.axis("off")

        ### Third subplot: plotting the timeseries
        self.timeax = plt.subplot(self.grid[3, :])
        plt.subplots_adjust(bottom=0.4)
        self.timeplot, = self.timeax.plot(self._normalised_trace())
        plt.axis("off")

        ### Axis for scrolling through t
        self.tr = plt.axes([0.2, 0.15, 0.3, 0.03],
                           facecolor='lightgoldenrodyellow')
        self.s_tr = Slider(self.tr,
                           'Timepoint',
                           0,
                           self.tmax - 1,
                           valinit=0,
                           valstep=1)
        self.s_tr.on_changed(self.update_t)

        ### Axis for setting min/max range
        self.minr = plt.axes([0.2, 0.2, 0.3, 0.03],
                             facecolor='lightgoldenrodyellow')
        self.sminr = Slider(self.minr,
                            'R Min',
                            0,
                            np.max(self.im),
                            valinit=self.min,
                            valstep=1)
        self.maxr = plt.axes([0.2, 0.25, 0.3, 0.03],
                             facecolor='lightgoldenrodyellow')
        self.smaxr = Slider(self.maxr,
                            'R Max',
                            0,
                            np.max(self.im) * 4,
                            valinit=self.max,
                            valstep=1)
        self.sminr.on_changed(self.update_mm)
        self.smaxr.on_changed(self.update_mm)

        ### Axis for buttons for next/previous time series
        self.axprev = plt.axes([0.62, 0.20, 0.1, 0.075])
        self.axnext = plt.axes([0.75, 0.20, 0.1, 0.075])
        self.bnext = Button(self.axnext, 'Next')
        self.bnext.on_clicked(self.next)
        self.bprev = Button(self.axprev, 'Previous')
        self.bprev.on_clicked(self.prev)

        #### Axis for button for display
        self.pointsax = plt.axes([0.75, 0.10, 0.1, 0.075])
        self.pointsbutton = RadioButtons(self.pointsax,
                                         ('Single', 'Same Z', 'All'))
        self.pointsbutton.set_active(self.pointstate)
        self.pointsbutton.on_clicked(self.update_pointstate)

        #### Axis for whether to display MIP on left
        self.mipax = plt.axes([0.62, 0.10, 0.1, 0.075])
        self.mipbutton = RadioButtons(self.mipax, ('Single Z', 'MIP'))
        self.mipbutton.set_active(self.showmip)
        self.mipbutton.on_clicked(self.update_mipstate)

        ### Axis for button to keep
        self.keepax = plt.axes([0.87, 0.20, 0.075, 0.075])
        self.keep_button = CheckButtons(self.keepax, ['Keep', 'Trash'],
                                        [False, False])
        self.keep_button.on_clicked(self.keep)

        ### Axis to determine which ones to show
        self.showax = plt.axes([0.87, 0.10, 0.075, 0.075])
        self.showbutton = RadioButtons(
            self.showax, ('All', 'Unlabelled', 'Kept', 'Trashed'))
        self.showbutton.set_active(self.show_settings)
        self.showbutton.on_clicked(self.show)

        plt.show()

    ## Attempting to get autosave when instance gets deleted, not working right now TODO
    def __del__(self):
        # __del__ also runs on instances whose __init__ failed part-way, where
        # path/curate may not exist yet; the atexit hook is the primary save.
        try:
            self.log_curate()
        except (AttributeError, OSError):
            pass

    def update_im(self):
        """Load the image for the current time point (single Z plane or MIP)."""
        if self.showmip:
            self.im = np.max(self.tf.get_t(self.t), axis=0)
        else:
            self.im = self.tf.get_tbyf(self.t, int(self._position()[0]))

    def get_im_display(self):
        """Return the full image rescaled with the current min/max range."""
        return (self.im - self.min) / (self.max - self.min)

    def get_subim_display(self):
        """Return the ROI sub image rescaled with the current min/max range."""
        return (self.subim - self.min) / (self.max - self.min)

    def update_figures(self):
        """Refresh both images, the point markers and the title, then redraw."""
        self.subim, self.offset = subaxis(self.im, self._position(),
                                          self.window)
        self.img1.set_data(self.get_im_display())

        positions = self._neighbour_positions()
        if positions is not None:
            self.point1.set_offsets(
                np.array([positions[:, 2], positions[:, 1]]).T)
        self.thispoint.set_offsets(
            [self._position()[2],
             self._position()[1]])
        plt.axis('off')

        self.img2.set_data(self.get_subim_display())
        self.point2.set_offsets([
            self.window / 2 + self.offset[0],
            self.window / 2 + self.offset[1]
        ])
        self.title.set_text(self._title_text())
        plt.draw()

    def update_timeseries(self):
        """Replace the plotted trace with the current thread's time series."""
        self.timeplot.set_ydata(self._normalised_trace())
        plt.draw()

    def update_t(self, val):
        """Slider callback: display time point ``val``."""
        # Update index for t
        self.t = val
        # update image for t
        self.update_im()
        self.update_figures()

    def update_mm(self, val):
        """Slider callback: take the display range from the two range sliders."""
        self.min = self.sminr.val
        self.max = self.smaxr.val
        self.update_figures()

    def _refresh_after_index_change(self):
        """Redraw everything that depends on ``self.ind`` and mark it seen."""
        self.update_im()
        self.update_figures()
        self.update_timeseries()
        self.update_buttons()
        self.update_curate()

    def next(self, event):
        """Button callback: move to the next thread allowed by the filter."""
        self.set_index_next()
        self._refresh_after_index_change()

    def prev(self, event):
        """Button callback: move to the previous thread allowed by the filter."""
        self.set_index_prev()
        self._refresh_after_index_change()

    def log_curate(self):
        """Write the curation labels and the current index to the JSON file."""
        self.curate['last'] = self.ind
        with open(self.path, 'w') as fp:
            json.dump(self.curate, fp)

    def keep(self, event):
        """Check-button callback: record 'keep'/'trash' for the current thread.

        When zero or both boxes are ticked, every ticked box is toggled off
        again (``set_active`` re-enters this callback); a label is only
        stored when exactly one box is ticked.
        """
        status = self.keep_button.get_status()
        if np.sum(status) != 1:
            for i, ticked in enumerate(status):
                if ticked:
                    self.keep_button.set_active(i)
        elif status[0]:
            self.curate[str(self.ind)] = 'keep'
        elif status[1]:
            self.curate[str(self.ind)] = 'trash'

    def update_buttons(self):
        """Make the Keep/Trash boxes reflect the stored label of the thread."""
        curr = self.keep_button.get_status()
        future = [False] * len(curr)
        label = self.curate.get(str(self.ind))
        if label == 'keep':
            future[0] = True
        elif label == 'trash':
            future[1] = True
        for i, (now, wanted) in enumerate(zip(curr, future)):
            if now != wanted:
                self.keep_button.set_active(i)

    def show(self, label):
        """Radio-button callback: choose which threads next/previous visit."""
        d = {'All': 0, 'Unlabelled': 1, 'Kept': 2, 'Trashed': 3}
        self.show_settings = d[label]

    def _is_skipped(self, index):
        """Return True if thread ``index`` is hidden by the current filter."""
        if self.show_settings == 0:
            return False
        label = self.curate.get(str(index))
        if self.show_settings == 1:
            return label in ('keep', 'trash')
        if self.show_settings == 2:
            return label != 'keep'
        return label != 'trash'

    def _step_index(self, step):
        """Advance ``self.ind`` by ``step`` until a visible thread is found.

        The index is wrapped modulo ``numneurons`` *before* every label
        lookup, so threads on the other side of the wrap-around are found.
        After ``numneurons`` skipped threads the search gives up, which
        leaves the index one step away from where it started.
        """
        total = self.numneurons
        self.ind = (self.ind + step) % total
        skipped = 0
        while skipped != total and self._is_skipped(self.ind):
            self.ind = (self.ind + step) % total
            skipped += 1

    def set_index_prev(self):
        """Move ``self.ind`` to the previous thread allowed by the filter."""
        self._step_index(-1)

    def set_index_next(self):
        """Move ``self.ind`` to the next thread allowed by the filter."""
        self._step_index(1)

    def update_curate(self):
        """Mark the current thread as 'seen' unless it already has a label."""
        if self.curate.get(str(self.ind)) not in ['keep', 'seen', 'trash']:
            self.curate[str(self.ind)] = 'seen'

    def update_pointstate(self, label):
        """Radio-button callback: choose which ROIs are drawn on the image."""
        d = {
            'Single': 0,
            'Same Z': 1,
            'All': 2,
        }
        self.pointstate = d[label]
        self.update_point1()
        self.update_figures()

    def update_point1(self):
        """Clear the full-image axes and redraw image and point markers."""
        self.ax1.clear()
        self.img1 = self.ax1.imshow(self.get_im_display(),
                                    cmap='gray',
                                    vmin=0,
                                    vmax=1)
        plt.axis('off')

        positions = self._neighbour_positions()
        if positions is not None:
            self.point1 = self.ax1.scatter(positions[:, 2],
                                           positions[:, 1],
                                           c='b',
                                           s=10)
        elif self.pointstate == 0:
            self.point1 = None
        self.thispoint = self.ax1.scatter(self._position()[2],
                                          self._position()[1],
                                          c='r',
                                          s=10)
        plt.axis('off')

    def update_mipstate(self, label):
        """Radio-button callback: switch between single Z plane and MIP."""
        d = {
            'Single Z': 0,
            'MIP': 1,
        }
        self.showmip = d[label]
        self.update_im()
        self.update_figures()
class ScanData:
    '''
    Collect raw data for energy scan experiments and provides methods to
    average them.

    Attributes
    ----------
    label : list (str)
        Labels for graphs.

    idx : list (str)
        Scan indexes.

    raw_imp : pandas DataFrame
        Collect imported raw data. For a scan index ``i`` the energy column
        is ``'E' + i`` and the data column is ``i``.

    energy : array
        Common energy scale for average of scans.

    lines : list (2d line objects)
        Collect line objects of raw scans for plotting (coloured, used for
        choosing scans).

    blines : list (2d line objects)
        Collect line objects of raw scans for plotting (dark, shown together
        with the average).

    avg_ln : 2d line object
        Line object of average of scans.

    plab : list (str)
        Collect the list of labels plotted in graphs for choosing scans.

    checkbx : CheckButtons obj
        Widget for choosing plots.

    aver : array
        Data average of selected scans from raw_imp.
        If ScanData is a reference one, aver contain normalized data by
        reference.

    dtype : str
        Identifies the data collected, used for graph labelling:
        sigma+, sigma- for XMCD
        CR, CL for XNCD
        H-, H+ for XNXD
        LH, LV for XNLD
        It is not set by this class and must be assigned before
        man_aver_e_scans is used in interactive mode.

    chsn_scns : list (str)
        Chosen scans for the analysis. Holds scan labels in interactive
        mode and scan indexes in non-interactive mode.

    pe_av : float
        Value of spectra at pre-edge energy. It is obtained from averaging
        data in a defined energy range centered at pre-edge energy.

    pe_av_int : float
        Pre-edge value obtained from linear interpolation considering
        pre-edge and post-edge energies.

    bsl : Univariate spline object
        Spline interpolation of ArpLS baseline.

    norm : array
        Averaged data normalized by value at pre-edge energy.

    norm_int : array
        Averaged data normalized by interpolated pre-edge value.

    ej : float
        Edge-jump value.

    ej_norm : float
        Edge-jump value normalized by value at pre-edge energy.

    ej_int : float
        Edge-jump value computed with interpolated pre-edge value.

    ej_norm_int : float
        Edge-jump computed and normalized by interpolated pre-edge value.

    Methods
    -------
    man_aver_e_scans(guiobj, enrg)
        Manage the choice of scans to be averaged and return the average of
        selected scans.

    aver_e_scans()
        Perform the average of data scans.

    check_but(label)
        When check buttons are checked switch the visibility of
        corresponding line.

    averbut(event)
        When average button is pressed calls aver_e_scans to compute
        average on selected scans.

    reset(event)
        Reset graph to starting conditions.

    finish_but(event)
        Close the figure on pressing the button Finish.

    edge_norm(guiobj, enrg, e_edge, e_pe, e_poste, pe_rng)
        Normalize energy scan data by value at pre-edge energy and compute
        edge-jump.
    '''

    def __init__(self):
        '''
        Initialize attributes label, idx, and raw_imp.
        '''
        self.label = []
        self.idx = []
        self.raw_imp = pd.DataFrame()

    def _visible_labels(self):
        '''Return the labels of the scan lines that are currently visible.'''
        return [line.get_label() for line in self.lines if line.get_visible()]

    def _style_checkboxes(self, bxh=0.05):
        '''
        Resize check boxes to squares of side bxh and restyle their labels.

        The box/cross restyling relies on CheckButtons.rectangles and
        CheckButtons.lines, which newer matplotlib releases no longer
        provide; when they are missing only the labels are restyled.
        '''
        boxes = getattr(self.checkbx, 'rectangles', None)
        crosses = getattr(self.checkbx, 'lines', None)
        if boxes is not None and crosses is not None:
            for box, cross in zip(boxes, crosses):
                box.set(height=bxh)
                box.set(width=bxh)
                x0, y0 = box.get_xy()
                # Redraw the two strokes of the cross on the resized box
                cross[0].set_xdata([x0, x0 + bxh])
                cross[0].set_ydata([y0, y0 + bxh])
                cross[1].set_xdata([x0 + bxh, x0])
                cross[1].set_ydata([y0, y0 + bxh])

        for text in self.checkbx.labels:
            text.set(fontsize='medium')
            text.set_verticalalignment('center')
            text.set_horizontalalignment('left')

    def man_aver_e_scans(self, guiobj, enrg):
        '''
        Manage the choice of scans to be averaged and compute the average
        of selected scans.

        Parameters
        ----------
        guiobj : GUI object
            Provides GUI dialogs. The attributes ``interactive`` and, in
            interactive mode, ``infile_ref`` are read.

        enrg : array
            Energy values at which average is calculated.

        Return
        ------
        Set class attributes:

        aver : array
            Average values of the chosen scans.

        chsn_scns : list
            Chosen scans for the analysis (for log purpose): labels in
            interactive mode, scan indexes otherwise.
        '''
        self.energy = enrg
        self.chsn_scns = []
        self.aver = 0

        if guiobj.interactive:  # Interactive choose of scans
            fig, ax = plt.subplots(figsize=(10, 6))
            fig.subplots_adjust(right=0.75)

            ax.set_xlabel('E (eV)')
            ax.set_ylabel(self.dtype)

            if guiobj.infile_ref:
                fig.suptitle('Choose reference sample scans')
            else:
                fig.suptitle('Choose sample scans')

            # lines contain colored lines for choose
            # blines contain dark lines to be showed with average
            self.lines = []
            self.blines = []

            # Populate lists with line objs: show lines and not blines
            for i, lbl in zip(self.idx, self.label):
                e_col = 'E' + i
                self.lines.append(
                    ax.plot(self.raw_imp[e_col], self.raw_imp[i],
                            label=lbl)[0])
                self.blines.append(
                    ax.plot(self.raw_imp[e_col],
                            self.raw_imp[i],
                            color='dimgrey',
                            visible=False)[0])

            # Initialize chsn_scns and average line with all scans and
            # set it invisible
            self.chsn_scns = self._visible_labels()
            self.aver_e_scans()
            self.avg_ln, = ax.plot(self.energy,
                                   self.aver,
                                   color='red',
                                   lw=2,
                                   visible=False)

            # Create box for checkbutton
            chax = fig.add_axes([0.755, 0.32, 0.24, 0.55], facecolor='0.95')
            self.plab = [str(line.get_label()) for line in self.lines]
            visibility = [line.get_visible() for line in self.lines]
            self.checkbx = CheckButtons(chax, self.plab, visibility)
            self._style_checkboxes()
            self.checkbx.on_clicked(self.check_but)

            # Create box for average, reset and finish buttons
            averbox = fig.add_axes([0.77, 0.2, 0.08, 0.08])
            bnaver = Button(averbox, 'Average')
            bnaver.on_clicked(self.averbut)
            rstbox = fig.add_axes([0.89, 0.2, 0.08, 0.08])
            bnrst = Button(rstbox, 'Reset')
            bnrst.on_clicked(self.reset)
            finbox = fig.add_axes([0.82, 0.07, 0.12, 0.08])
            bnfinish = Button(finbox, 'Finish')
            bnfinish.on_clicked(self.finish_but)

            ax.legend()
            plt.show()

            # If nothing has been chosen fall back to the visible scans
            if not self.chsn_scns:
                self.chsn_scns = self._visible_labels()
            # Recompute so that aver always matches chsn_scns, also when
            # check buttons were toggled without pressing Average
            self.aver_e_scans()
        else:
            # Not-interactive mode: all scans except 'Dummy Scans' are
            # evaluated; the scan number of each one is stored
            for lbl, scan in zip(self.label, self.idx):
                if 'Dummy' not in lbl:
                    self.chsn_scns.append(scan)
            self.aver_e_scans()

    def check_but(self, label):
        '''
        When check buttons are checked switch the visibility of
        corresponding line.

        Also update self.chsn_scns with labels of visible scans.
        '''
        line = self.lines[self.plab.index(label)]
        line.set_visible(not line.get_visible())

        # Update chsn_scns
        self.chsn_scns = self._visible_labels()

        plt.draw()

    def averbut(self, event):
        '''
        When average button is pressed calls aver_e_scans to compute
        average on selected scans.

        Update self.chsn_scns and self.aver.
        '''
        # Initialize list of chosen scans
        self.chsn_scns = []

        # Set visible only chosen scans in blines and append to chsn_scns
        for line, bline in zip(self.lines, self.blines):
            if line.get_visible():
                line.set(visible=False)
                self.chsn_scns.append(line.get_label())
                bline.set(visible=True)

        self.aver_e_scans()

        # Update average line and make it visible
        self.avg_ln.set_ydata(self.aver)
        self.avg_ln.set(visible=True)
        plt.draw()

    def aver_e_scans(self):
        '''
        Perform the average of the chosen data scans.

        Takes no arguments: reads self.idx, self.label, self.raw_imp,
        self.chsn_scns and self.energy.

        Returns
        -------
        Nothing; the average is stored in self.aver.

        Notes
        -----
        A scan is averaged when either its label or its scan index is
        listed in self.chsn_scns, because the interactive mode stores
        labels while the non-interactive mode stores scan indexes.

        To compute the average the common energy scale self.energy is
        used. The first point of every scan is dropped, then each scan is
        interpolated with a linear spline (k=1 and s=0 in
        itp.UnivariateSpline) and evaluated along the common energy scale.
        The interpolated data are eventually averaged.
        '''
        chosen = self.chsn_scns
        intrp = []

        for i, lbl in zip(self.idx, self.label):
            if lbl in chosen or i in chosen:
                # chosen data
                x = self.raw_imp['E' + i][1:]
                y = self.raw_imp[i][1:]

                # Compute linear spline interpolation
                y_int = itp.UnivariateSpline(x, y, k=1, s=0)
                # Evaluate interpolation of scan data on common energy
                # scale and append to previous interpolations
                intrp.append(y_int(self.energy))

        # Average all interpolated scans
        self.aver = np.average(intrp, axis=0)

    def reset(self, event):
        '''
        Reset graph for choosing scans to starting conditions.

        Show all scans and set checked all buttons.
        '''
        # Clear graph
        self.avg_ln.set(visible=False)

        # Re-check every unchecked box (this fires check_but for each one)
        for i, checked in enumerate(self.checkbx.get_status()):
            if not checked:
                self.checkbx.set_active(i)

        # Show all spectra
        for line, bline in zip(self.lines, self.blines):
            line.set(visible=True)
            bline.set(visible=False)

        plt.draw()

    def finish_but(self, event):
        '''
        Close the figure on pressing the button Finish.
        '''
        plt.close()

    def edge_norm(self, guiobj, enrg, e_edge, e_pe, e_poste, pe_rng):
        '''
        Normalize energy scan data by the value at pre-edge energy.
        Also compute the edge jump defined as the difference between the
        value at the edge and pre-edge energies respectively.

        These computations are implemented also considering baseline.
        If linear baseline is selected edge jump is computed considering
        the height of data at edge energy from the straight line passing
        through pre-edge and post-edge data.
        If asymmetrically reweighted penalized least squares baseline is
        selected the edge jump is calculated as the distance at edge
        energy between the averaged spectrum and baseline.

        Parameters
        ----------
        guiobj : GUI object
            Provides GUI dialogs; ``bsl_int`` selects the ArpLS baseline
            (self.bsl) instead of the linear one.

        enrg : array
            Energy values of scan.

        e_edge : float
            Edge energy value.

        e_pe : float
            Pre-edge energy value.

        e_poste : float
            Post-edge energy value, used for the linear baseline.

        pe_rng : int
            Number of points constituting the semi-width of energy range
            centered at e_pe.

        Returns
        -------
        Set class attributes:

        pe_av : float
            value at pre-edge energy.

        pe_av_int : float
            baseline value at edge energy.

        norm : array
            self.aver scan normalized by value at pre-edge energy.

        norm_int : array
            Averaged data normalized by interpolated pre-edge value.

        ej : float
            edge-jump value.

        ej_norm : float
            edge-jump value normalized by value at pre-edge energy.

        ej_int : float
            edge-jump value computed with interpolated pre-edge value.

        ej_norm_int : float
            edge-jump computed and normalized by interpolated pre-edge
            value.

        Notes
        -----
        To reduce noise effects the value of scan at pre-edge energy is
        obtained computing an average over 2 * pe_rng + 1 points centered
        at the point nearest to the e_pe pre-edge energy; the window is
        truncated at the ends of the scan.
        The value of scan at edge energy is obtained by cubic spline
        interpolation of data (itp.UnivariateSpline with k=3 and s=0).
        '''
        # Index of the nearest element to pre-edge energy
        pe_idx = np.argmin(np.abs(enrg - e_pe))

        # Left and right extremes of energy range for pre-edge average.
        # The left extreme is clamped: a negative start would be read as
        # an offset from the end and give an empty slice.
        lpe_idx = max(0, int(pe_idx - pe_rng))
        rpe_idx = int(pe_idx + pe_rng + 1)

        # Average of values for computation of pre-edge
        self.pe_av = np.average(self.aver[lpe_idx:rpe_idx])

        # Cubic spline interpolation of energy scan
        y_int = itp.UnivariateSpline(enrg, self.aver, k=3, s=0)
        # value at edge energy from interpolation
        y_edg = y_int(e_edge)

        # Edge-jumps computations - no baseline
        self.ej = y_edg - self.pe_av
        self.ej_norm = self.ej / self.pe_av

        # Normalization by pre-edge value
        self.norm = self.aver / self.pe_av

        # Edge-jumps computations - consider baseline
        if guiobj.bsl_int:
            # ArpLS baseline evaluated at edge energy
            self.pe_av_int = self.bsl(e_edge)
        else:
            # Linear baseline through pre-edge and post-edge points,
            # evaluated at edge energy
            x = [e_pe, e_poste]
            y = [y_int(e_pe), y_int(e_poste)]
            self.pe_av_int = lin_interpolate(x, y, e_edge)

        # Normalization by interpolated pre-edge value
        self.norm_int = self.aver / self.pe_av_int
        self.ej_int = y_edg - self.pe_av_int
        self.ej_norm_int = self.ej_int / self.pe_av_int
class InoHeartbeatVerifier(object): """ GUI for editing heartbeat signal labels for INO case """ def __init__(self, composite_peaks, index=0, folder_name="", dosage="", file_name="", interval_number=""): super(InoHeartbeatVerifier, self).__init__() # Save Composite Peaks self.composite_peaks = composite_peaks self.index = index self.folder_name = folder_name self.dosage = dosage self.file_name = file_name self.interval_number = interval_number self.update_point = None # Plot signals self.plot_signals() def plot_signals(self): # Create figure self.fig, self.ax = plt.subplots() # Load composite signals self.time, self.signal, self.seis, self.phono = self.composite_peaks.composites[ self.index] # Plot ECG, Phono and Seismo self.signal_line, = self.ax.plot(self.signal, linewidth=1, c="b", label="ECG") self.seis_line, = self.ax.plot(self.seis, '--', linewidth=0.5, c='r', label="Seis") self.phono_line, = self.ax.plot(self.phono, '--', linewidth=0.5, c='g', label="Phono") self.ax.set_xlim(0, len(self.signal)) sig_min = min(self.signal) sig_max = max(self.signal) self.ax.set_ylim(sig_min - 0.1 * (sig_max - sig_min), sig_max + 0.1 * (sig_max - sig_min)) plt.legend(loc='upper right') # Q Peak self.q_point = self.ax.scatter( self.composite_peaks.Q.data[self.index], self.signal[self.composite_peaks.Q.data[self.index]], c='#ff7f0e') self.q_text = self.ax.text( self.composite_peaks.Q.data[self.index], self.signal[self.composite_peaks.Q.data[self.index]] + 0.2, "Q", fontsize=9, horizontalalignment='center') # QM Seismo self.qm_seis_point = self.ax.scatter( self.composite_peaks.QM_seis.data[self.index], self.seis[self.composite_peaks.QM_seis.data[self.index]], c='#d62728') self.qm_seis_text = self.ax.text( self.composite_peaks.QM_seis.data[self.index], self.seis[self.composite_peaks.QM_seis.data[self.index]] + 0.2, "QM Seis", fontsize=9, horizontalalignment='center') # QM Phono self.qm_phono_point = self.ax.scatter( self.composite_peaks.QM_phono.data[self.index], 
self.phono[self.composite_peaks.QM_phono.data[self.index]], c='#8c564b') self.qm_phono_text = self.ax.text( self.composite_peaks.QM_phono.data[self.index], self.phono[self.composite_peaks.QM_phono.data[self.index]] + 0.2, "QM Phono", fontsize=9, horizontalalignment='center') # Initalize axes and data points self.x = range(len(self.signal)) self.y = self.signal # Cross hairs self.lx = self.ax.axhline(color='k', linewidth=0.2) # the horiz line self.ly = self.ax.axvline(color='k', linewidth=0.2) # the vert line # Add data left_shift = 0.45 start = 0.96 space = 0.04 self.ax.text(0.01, start, transform=self.ax.transAxes, s="Folder: " + self.folder_name, fontsize=12, horizontalalignment='left') self.ax.text(0.01, start - space, transform=self.ax.transAxes, s="Dosage: " + str(self.dosage), fontsize=12, horizontalalignment='left') self.ax.text(0.01, start - 2 * space, transform=self.ax.transAxes, s="File: " + self.file_name, fontsize=12, horizontalalignment='left') self.ax.text(0.01, start - 3 * space, transform=self.ax.transAxes, s="File #: " + str(self.interval_number), fontsize=12, horizontalalignment='left') self.i_text = self.ax.text(0.60 - left_shift, 1.1 - space, transform=self.ax.transAxes, s="Composite: " + str(self.index + 1) + "/" + str(len(self.composite_peaks.composites)), fontsize=12, horizontalalignment='left') # Add Intervals start_left = 0.575 shift_left = 0.10 qm_seis = str( round( 1 / (self.time[self.composite_peaks.QM_seis.data[self.index]] - self.time[self.composite_peaks.Q.data[self.index]]), 2)) qm_phono = str( round( 1 / (self.time[self.composite_peaks.QM_phono.data[self.index]] - self.time[self.composite_peaks.Q.data[self.index]]), 2)) self.qm_text = self.ax.text(start_left, 0.91, horizontalalignment='center', transform=self.fig.transFigure, s="1/(E-M)ino\nSeis: " + qm_seis + " Hz" + "\nPhono: " + qm_phono + " Hz") # Add index buttons ax_prev = plt.axes([0.575 - left_shift, 0.9, 0.1, 0.075]) self.bprev = Button(ax_prev, 'Previous') 
self.bprev.on_clicked(self.prev) ax_next = plt.axes([0.8 - left_shift, 0.9, 0.1, 0.075]) self.b_next = Button(ax_next, 'Next') self.b_next.on_clicked(self.next) self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_move) # Add Save Button ax_save = plt.axes([0.8, 0.9, 0.1, 0.075]) self.b_save = Button(ax_save, 'Save') self.b_save.on_clicked(self.save) # Add Line buttons self.ax.text(-0.13, 0.97, transform=self.ax.transAxes, s="Snap on to:", fontsize=12, horizontalalignment='left') # left, bottom, width, height ax_switch_signals = plt.axes([0.02, 0.7, 0.07, 0.15]) self.b_switch_signals = RadioButtons(ax_switch_signals, ('ECG', 'Seismo', 'Phono')) for c in self.b_switch_signals.circles: c.set_radius(0.05) self.b_switch_signals.on_clicked(self.switch_signal) self.fig.canvas.mpl_connect('button_press_event', self.on_click) self.fig.canvas.mpl_connect('button_release_event', self.off_click) # Add Line hide buttons self.ax.text(1.015, 0.97, transform=self.ax.transAxes, s="Hide Signal:", fontsize=12, horizontalalignment='left') ax_hide_signals = plt.axes([0.91, 0.7, 0.07, 0.15]) self.b_hide_signals = CheckButtons(ax_hide_signals, ('ECG', 'Seismo', 'Phono')) self.b_hide_signals.on_clicked(self.switch_signal) # Add Sliders self.signal_amp_slider = Slider(plt.axes([0.91, 0.15, 0.01, 0.475]), label="ECG\nA", valmin=0.01, valmax=10, valinit=1, orientation='vertical') self.signal_amp_slider.on_changed(self.switch_signal) self.signal_amp_slider.valtext.set_visible(False) self.seis_height_slider = Slider(plt.axes([0.93, 0.15, 0.01, 0.475]), label=" Seis\nH", valmin=1.5 * min(self.signal), valmax=1.5 * max(self.signal), valinit=0, orientation='vertical') self.seis_height_slider.on_changed(self.switch_signal) self.seis_height_slider.valtext.set_visible(False) self.seis_amp_slider = Slider(plt.axes([0.94, 0.15, 0.01, 0.475]), label="\nA", valmin=0.01, valmax=10, valinit=1, orientation='vertical') self.seis_amp_slider.on_changed(self.switch_signal) 
self.seis_amp_slider.valtext.set_visible(False) self.phono_height_slider = Slider(plt.axes([0.96, 0.15, 0.01, 0.475]), label=" Phono\nH", valmin=1.5 * min(self.signal), valmax=1.5 * max(self.signal), valinit=0, orientation='vertical') self.phono_height_slider.on_changed(self.switch_signal) self.phono_height_slider.valtext.set_visible(False) self.phono_amp_slider = Slider(plt.axes([0.97, 0.15, 0.01, 0.475]), label="A", valmin=.01, valmax=10, valinit=1, orientation='vertical') self.phono_amp_slider.on_changed(self.switch_signal) self.phono_amp_slider.valtext.set_visible(False) # Maximize frame mng = plt.get_current_fig_manager() mng.full_screen_toggle() plt.show() def switch_signal(self, label): # Update Lines self.signal_line.set_data(range(len(self.signal)), self.signal_amp_slider.val * self.signal) self.seis_line.set_data(range(len(self.signal)), (self.seis_amp_slider.val * self.seis) + self.seis_height_slider.val) self.phono_line.set_data(range(len(self.signal)), (self.phono_amp_slider.val * self.phono) + self.phono_height_slider.val) # Q Peaks self.q_point.set_offsets( (self.composite_peaks.Q.data[self.index], self.signal_amp_slider.val * self.signal[self.composite_peaks.Q.data[self.index]])) self.q_text.set_position( (self.composite_peaks.Q.data[self.index], self.signal_amp_slider.val * self.signal[self.composite_peaks.Q.data[self.index]] + 0.2)) # QM Seismo self.qm_seis_point.set_offsets( (self.composite_peaks.QM_seis.data[self.index], self.seis_amp_slider.val * self.seis[self.composite_peaks.QM_seis.data[self.index]] + self.seis_height_slider.val)) self.qm_seis_text.set_position( (self.composite_peaks.QM_seis.data[self.index], self.seis_amp_slider.val * self.seis[self.composite_peaks.QM_seis.data[self.index]] + 0.2 + self.seis_height_slider.val)) # QM Phono self.qm_phono_point.set_offsets( (self.composite_peaks.QM_phono.data[self.index], self.phono_amp_slider.val * self.phono[self.composite_peaks.QM_phono.data[self.index]] + self.phono_height_slider.val)) 
self.qm_phono_text.set_position( (self.composite_peaks.QM_phono.data[self.index], self.phono_amp_slider.val * self.phono[self.composite_peaks.QM_phono.data[self.index]] + 0.2 + self.phono_height_slider.val)) # Update Data qm_seis = str( round( 1 / (self.time[self.composite_peaks.QM_seis.data[self.index]] - self.time[self.composite_peaks.Q.data[self.index]]), 2)) qm_phono = str( round( 1 / (self.time[self.composite_peaks.QM_phono.data[self.index]] - self.time[self.composite_peaks.Q.data[self.index]]), 2)) self.qm_text.set_text("1/(E-M)ino\nSeis: " + qm_seis + " Hz" + "\nPhono: " + qm_phono + " Hz") # Update Cross-hairs label = self.b_switch_signals.value_selected self.x = range(len(self.signal)) if label == 'ECG': self.y = self.signal_amp_slider.val * self.signal if label == 'Seismo': self.y = self.seis_amp_slider.val * self.seis + self.seis_height_slider.val if label == 'Phono': self.y = self.phono_amp_slider.val * self.phono + self.phono_height_slider.val # Update Hidden signals hide_label = self.b_hide_signals.get_status() if hide_label[0]: # ECG self.signal_line.set_linewidth(0) else: self.signal_line.set_linewidth(0.5) if hide_label[1]: # Seismo self.seis_line.set_linewidth(0) else: self.seis_line.set_linewidth(0.5) if hide_label[2]: # Phono self.phono_line.set_linewidth(0) else: self.phono_line.set_linewidth(0.5) self.fig.canvas.draw() def off_click(self, event): self.lx.set_color('k') self.ly.set_color('k') self.lx.set_linewidth(0.2) self.ly.set_linewidth(0.2) if event.xdata is not None: if self.update_point == "Q": self.q_point.set_offsets( (int(event.xdata), self.signal_amp_slider.val * self.signal[int(event.xdata)])) self.q_text.set_position( (int(event.xdata), self.signal_amp_slider.val * self.signal[int(event.xdata)] + 0.2)) self.composite_peaks.Q.data[self.index] = int(event.xdata) if self.update_point == "QM Seismo": self.qm_seis_point.set_offsets( (int(event.xdata), self.seis_amp_slider.val * self.seis[int(event.xdata)] + self.seis_height_slider.val)) 
self.qm_seis_text.set_position( (int(event.xdata), self.seis_amp_slider.val * self.seis[int(event.xdata)] + 0.2 + self.seis_height_slider.val)) self.composite_peaks.QM_seis.data[self.index] = int( event.xdata) if self.update_point == "QM Phono": self.qm_phono_point.set_offsets( (int(event.xdata), self.phono_amp_slider.val * self.phono[int(event.xdata)] + self.phono_height_slider.val)) self.qm_phono_text.set_position( (int(event.xdata), self.phono_amp_slider.val * self.phono[int(event.xdata)] + 0.2 + self.phono_height_slider.val)) self.composite_peaks.QM_phono.data[self.index] = int( event.xdata) # Update Data qm_seis = str( round( 1 / (self.time[self.composite_peaks.QM_seis.data[self.index]] - self.time[self.composite_peaks.Q.data[self.index]]), 2)) qm_phono = str( round( 1 / (self.time[self.composite_peaks.QM_phono.data[self.index]] - self.time[self.composite_peaks.Q.data[self.index]]), 2)) self.qm_text.set_text("1/(E-M)ino\nSeis: " + qm_seis + " Hz" + "\nPhono: " + qm_phono + " Hz") self.fig.canvas.draw() self.fig.canvas.flush_events() def on_click(self, event): threshold = 40 current_signal = self.b_switch_signals.value_selected self.update_point = None if event.xdata is not None: if current_signal == 'ECG': if abs(self.composite_peaks.Q.data[self.index] - event.xdata) < threshold: self.lx.set_color('#ff7f0e') self.ly.set_color('#ff7f0e') self.lx.set_linewidth(1) self.ly.set_linewidth(1) self.fig.canvas.draw() self.update_point = "Q" if current_signal == 'Seismo': if abs(self.composite_peaks.QM_seis.data[self.index] - event.xdata) < threshold: self.lx.set_color('#d62728') self.ly.set_color('#d62728') self.lx.set_linewidth(1) self.ly.set_linewidth(1) self.fig.canvas.draw() self.update_point = "QM Seismo" if current_signal == 'Phono': if abs(self.composite_peaks.QM_phono.data[self.index] - event.xdata) < threshold: self.lx.set_color('#8c564b') self.ly.set_color('#8c564b') self.lx.set_linewidth(1) self.ly.set_linewidth(1) self.fig.canvas.draw() self.update_point = 
"QM Phono" def mouse_move(self, event): # If nothing happened do nothing if not event.inaxes: return # Update x data point x = event.xdata # Lock to closest x coordinate on signal indx = min(np.searchsorted(self.x, x), len(self.x) - 1) x = self.x[indx] y = self.y[indx] # Update the crosshairs self.lx.set_ydata(y) self.ly.set_xdata(x) # Draw everything self.ax.figure.canvas.draw() def next(self, event): self.index += 1 if self.index > len(self.composite_peaks.composites) - 1: self.index = 0 self.update_plot() def prev(self, event): self.index -= 1 if self.index < 0: self.index = len(self.composite_peaks.composites) - 1 self.update_plot() def save(self, event): # Get File Name save_file_name = "composites_" + self.folder_name + "_" + self.file_name + "_d" + str( self.dosage) # Save self.composite_peaks.save("data/Derived/composites/" + save_file_name) print("Saved") def update_plot(self): # Update index self.i_text.set_text("Composite: " + str(self.index + 1) + "/" + str(len(self.composite_peaks.composites))) # Load composite signals self.time, self.signal, self.seis, self.phono = self.composite_peaks.composites[ self.index] # Update cross hairs self.switch_signal(None) # Plot ECG, Phono and Seismo self.signal_line.set_data(range(len(self.signal)), self.signal) self.seis_line.set_data(range(len(self.seis)), self.seis) self.phono_line.set_data(range(len(self.phono)), self.phono) self.ax.set_xlim(0, len(self.signal)) # Q Peaks self.q_point.set_offsets( (self.composite_peaks.Q.data[self.index], self.signal[self.composite_peaks.Q.data[self.index]])) self.q_text.set_position( (self.composite_peaks.Q.data[self.index], self.signal[self.composite_peaks.Q.data[self.index]] + 0.2)) # QM Seismo self.qm_seis_point.set_offsets( (self.composite_peaks.QM_seis.data[self.index], self.seis[self.composite_peaks.QM_seis.data[self.index]])) self.qm_seis_text.set_position( (self.composite_peaks.QM_seis.data[self.index], self.seis[self.composite_peaks.QM_seis.data[self.index]] + 0.2)) # QM 
Phono self.qm_phono_point.set_offsets( (self.composite_peaks.QM_phono.data[self.index], self.phono[self.composite_peaks.QM_phono.data[self.index]])) self.qm_phono_text.set_position( (self.composite_peaks.QM_phono.data[self.index], self.phono[self.composite_peaks.QM_phono.data[self.index]] + 0.2)) self.fig.canvas.draw()
class CrossMark:
    """Interactive tool to mark matching cross points on two images.

    The reference image is shown on the left and the image to be corrected
    on the right. Each "cross" is a pair of points, one per image, stored at
    the same index of ``cross_list[0]`` (reference) and ``cross_list[1]``
    (image). A right click moves the currently selected cross inside the
    clicked image; the ``+`` / ``-`` buttons change the selection.
    """

    def __init__(self, img_name, imgref_name):
        """Load both images, build the figure and its widgets, and show it.

        Args:
            img_name: path of the image to be corrected.
            imgref_name: path of the reference image.

        Raises:
            IOError: if either image cannot be read by OpenCV.
        """
        self.img = self._load_rgb(img_name)
        self.imgr = self._load_rgb(imgref_name)
        fig = plt.figure(str('coross mark'))
        self.sfigs = []
        self.sfigs += [plt.subplot(1, 2, 1)]
        plt.imshow(self.imgr)
        self.sfigs += [plt.subplot(1, 2, 2)]
        plt.imshow(self.img)

        # (left edge, label, callback) for the button row, created right to
        # left exactly as before.
        button_specs = (
            (0.8, 'Save', self.save),
            (0.7, 'Add', self.add),
            (0.6, 'Del', self.delete),
            (0.5, 'Rot', self.rotate_im),
            (0.4, '+', self.plus),
            (0.3, '-', self.minus),
        )
        # Widgets stay responsive only while something references them, so
        # keep them on the instance instead of in throw-away locals.
        self._buttons = []
        for left, label, callback in button_specs:
            button = Button(plt.axes([left, 0.05, 0.1, 0.075]), label)
            button.on_clicked(callback)
            self._buttons.append(button)

        axchk = plt.axes([0.2, 0.90, 0.2, 0.1])
        self.chk_colors = CheckButtons(axchk, ('Show_pts', ), actives=[False])
        self.chk_colors.on_clicked(self.redraw)
        self._click_cid = fig.canvas.mpl_connect('button_press_event',
                                                 self.onclick)
        self.objects_hdls = None
        self.cross_list = [[], []]
        self.selected_obj_ind = -1
        self.draw_objs()
        plt.show()

    @staticmethod
    def _load_rgb(path):
        """Read *path* with OpenCV and return the pixels in RGB order."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError('cannot read image: {}'.format(path))
        return bgr[:, :, ::-1]

    def _clamp_selection(self, ind):
        """Return *ind* limited to ``-1`` (nothing selected) .. last cross."""
        last = len(self.cross_list[0]) - 1
        return max(-1, min(ind, last))

    def plus(self, event):
        """Select the next cross, stopping at the last one."""
        # The upper bound is the last valid index; clamping to len() would
        # leave an index that delete()/onclick() cannot use.
        self.selected_obj_ind = self._clamp_selection(
            self.selected_obj_ind + 1)
        print('self.selected_obj_ind', self.selected_obj_ind)
        self.draw_objs()
        plt.draw()

    def minus(self, event):
        """Select the previous cross; ``-1`` means nothing is selected."""
        self.selected_obj_ind = max((-1, self.selected_obj_ind - 1))
        print('self.selected_obj_ind', self.selected_obj_ind)
        self.draw_objs()
        plt.draw()

    def rotate_im(self, event):
        """Rotate the right-hand image and move its crosses along with it."""
        self.img = cv2.transpose(cv2.flip(self.img, 1))
        # Uses the shape of the already rotated image.
        self.cross_list[1] = [(y, self.img.shape[0] - x - 1)
                              for x, y in self.cross_list[1]]
        self.sfigs[1].cla()
        self.sfigs[1].imshow(self.img)
        self.draw_objs()
        plt.draw()

    def redraw(self, event):
        """Check-button callback: redraw all markers."""
        self.draw_objs()
        plt.draw()

    def draw_objs(self, selected=-1):
        """Remove the previously drawn markers and draw the current ones.

        ``selected`` is accepted for backward compatibility and is unused;
        the highlighted cross is taken from ``self.selected_obj_ind``.
        """
        if self.objects_hdls is not None:
            for h in self.objects_hdls:
                try:
                    h[0].remove()
                except ValueError:
                    # The artist is already gone, e.g. after Axes.cla() in
                    # rotate_im().
                    pass
        self.objects_hdls = []
        h = self.objects_hdls
        if self.chk_colors.get_status()[0]:
            pts_r = pick_colors_pts()
            pts = convert_points(np.array(self.cross_list[0]),
                                 np.array(self.cross_list[1]), pts_r)
            h.append(self.sfigs[0].plot(pts_r[:, 0], pts_r[:, 1], '+g'))
            h.append(self.sfigs[1].plot(pts[:, 0], pts[:, 1], '+g'))
        self.sfigs[0].set_title('{}/{}'.format(self.selected_obj_ind + 1,
                                               len(self.cross_list[0])))
        for find, subp in enumerate(self.sfigs):
            for i, cr in enumerate(self.cross_list[find]):
                # The selected cross is red, all others blue.
                h.append(
                    subp.plot(cr[0], cr[1],
                              '+r' if self.selected_obj_ind == i else '+b'))

    def get_closest(self, x, y, ax_ind):
        """Return the index of the cross nearest to ``(x, y)``.

        Distance is the Manhattan distance in data coordinates of image
        ``ax_ind`` (0 = reference, 1 = image). Returns ``-1`` when no cross
        is closer than 30.
        """
        max_ind = -1
        max_val = 30  # minimal distance to accept click
        for ind, cr in enumerate(self.cross_list[ax_ind]):
            d = abs(cr[0] - x) + abs(cr[1] - y)
            if d < max_val:
                max_val = d
                max_ind = ind
        return max_ind

    def delete(self, event):
        """Remove the selected cross from both images and clear selection."""
        if self.selected_obj_ind != -1:
            for i in [0, 1]:
                self.cross_list[i].pop(self.selected_obj_ind)
            self.selected_obj_ind = -1
            self.draw_objs()
            plt.draw()

    def save(self, event):
        """Fit a colour matrix from the marked points, save and preview it.

        The matrix is the least-squares solution mapping colours picked from
        the image onto colours picked from the reference; it is pickled to
        ``color_mat.pkl`` in the current working directory.
        """
        plt.figure()
        plt.subplot(1, 3, 1)
        plt.title('ref')
        plt.imshow(self.imgr)
        pts_r = pick_colors_pts()
        plt.plot(pts_r[:, 0], pts_r[:, 1], '+')
        plt.subplot(1, 3, 2)
        plt.title('im corrected')
        pts = convert_points(np.array(self.cross_list[0]),
                             np.array(self.cross_list[1]), pts_r)
        plt.plot(pts[:, 0], pts[:, 1], '+')
        col_r = pick_colors(self.imgr, pts_r)
        col_im = pick_colors(self.img, pts)
        CM = np.transpose(np.linalg.lstsq(col_im, col_r, rcond=None)[0])
        with open('color_mat.pkl', 'wb') as fd:
            pickle.dump(CM, fd)
        im1_corr = apply_mat(self.img, CM)
        plt.imshow(im1_corr)
        plt.subplot(1, 3, 3)
        plt.title('im1')
        plt.plot(pts[:, 0], pts[:, 1], '+')
        plt.imshow(self.img)
        plt.show()

    def add(self, event):
        """Append a new cross, placed at ``(-1, -1)`` in both images."""
        for i in [0, 1]:
            self.cross_list[i].append((-1, -1))
        self.draw_objs()
        plt.draw()

    def onclick(self, event):
        """Right click inside an image: move the selected cross there."""
        ind = self.sfigs.index(
            event.inaxes) if event.inaxes in self.sfigs else -1
        x, y = event.xdata, event.ydata
        if event.button == 3 and ind > -1 and self.selected_obj_ind > -1:
            self.cross_list[ind][self.selected_obj_ind] = (x, y)
            self.draw_objs()
            plt.draw()
def plot_selected(event):
    """Plot the hysteresis datasets whose check box is ticked and export them.

    Button callback. ``hyst_data`` is read as consecutive column pairs; pair
    ``i`` is selected by entry ``i`` of the ``chk_btn`` status list. The
    selected pairs are collected into a new DataFrame, plotted in a new
    figure (each dataset shifted along x by ``i * H_step``) and written to
    ``hyst_out_cut.xlsx``.

    Args:
        event: the matplotlib event passed by the widget (unused).
    """
    chk_btn_status = chk_btn.get_status()
    print(chk_btn_status)
    print('Plotting selected datasets...')

    # construct DataFrame containing only selected (checkboxed) datasets;
    # integer division keeps this valid on Python 3 as well as Python 2
    selected = []
    for i in range(len(hyst_data.columns) // 2):
        if chk_btn_status[i]:
            x = hyst_data.iloc[:, 2 * i]
            y = hyst_data.iloc[:, 2 * i + 1]
            selected.append(pd.DataFrame([x, y]).T)
    # concatenate once instead of growing the frame inside the loop
    hyst_data_cut = pd.concat(selected, axis=1) if selected else pd.DataFrame()
    hyst_labels_cut = hyst_data_cut.columns.values
    n_selected = len(hyst_data_cut.columns) // 2

    print(hyst_data_cut)
    print('A total of' + ' ' + str(n_selected) +
          ' datasets in DataFrame were selected,')
    print('generated a new DataFrame with column names:')
    print(hyst_labels_cut)

    # plot figure for the dataframe
    fig1, ax1 = plt.subplots(figsize=(9, 9))
    fig1.tight_layout(pad=4.0, w_pad=0.5, h_pad=0.5)
    plt.subplots_adjust(left=0.15, bottom=0.1, wspace=0.0, hspace=0.0)
    # set widget window title
    fig1.canvas.set_window_title('hysteresis plotter' + ' ' + version_name)
    # plot version info in the description
    plt.figtext(0.90, 0.97, version_name, size=10)

    # define datapoint marker colors; 10 samples as before, more only when
    # more than 10 datasets are selected (the iterators used to run dry)
    n_colors = max(10, n_selected)
    colors = iter(plt.cm.inferno(np.linspace(0.3, 0.8, n_colors)))
    mfcolors = iter(plt.cm.plasma(np.linspace(0.1, 1, n_colors)))

    # plot every selected dataset from the DataFrame
    for i in range(n_selected):
        j = 2 * i
        # generate dataset label (column name without its first character)
        hyst_label_cut = hyst_labels_cut[j][1:]
        # get dataset to plot
        x = hyst_data_cut.iloc[:, j]
        y = hyst_data_cut.iloc[:, j + 1]
        plt.plot(x + i * H_step, y, 'o', color=next(colors),
                 mfc=next(mfcolors), markersize=6, label=hyst_label_cut,
                 visible=True)
    if n_selected:
        # one legend call covers all labelled datasets
        plt.legend(loc='upper left', frameon=True, fontsize=10,
                   title='Anneal Time')

    # format plot
    custom_axis_formater(plot_title, plot_x_label, plot_y_label, xmin, xmax,
                         ymin, ymax, xprec, yprec)
    plt.show()

    # export dataframe to xlsx
    hyst_data_cut.to_excel("hyst_out_cut.xlsx")
class CellAutomatonGameGUI(object):
    """
    Matplotlib control panel for the cell-automaton game.

    Widgets offered to the user:
        1. N = # of creatures
        2. P = infection probability
        3. K = quarantine parameter
        4. L = generation (iteration) from which the quarantine applies
        5. animation's speed
        6. pause | play | reset buttons
    """

    def __init__(self, game):
        self.game = game
        # parameters start from the module-level defaults
        self.N = N
        self.P = P
        self.K = K
        self.L = L
        # created later by set_all() / start()
        self.fig = None
        self.ax = None
        self.animation = None

    def __set_widgets(self):
        """ build every slider, check box, button and label owned by the gui """
        hover = None
        face = 'white'
        left = 0.25
        top = 0.2
        width = 0.6
        height = 0.021
        gap = height + 0.01

        def slider_rect(row):
            # [left, bottom, width, height] of the slider in the given row
            return [left, top - row * gap, width, height]

        # parameter sliders
        n_axes = plt.axes(slider_rect(0), facecolor=face)
        self.n_slider = Slider(n_axes,
                               'N',
                               1.0,
                               int(self.game.get_size() / 2),
                               valinit=self.N,
                               valstep=1.0,
                               valfmt='%0.0f')

        p_axes = plt.axes(slider_rect(1), facecolor=face)
        self.p_slider = Slider(p_axes, 'P', 0.0, 1.0, valinit=self.P)

        # K and L are hidden until the quarantine box is ticked
        self.k_slider_loc = plt.axes(slider_rect(2), facecolor=face)
        self.k_slider = Slider(self.k_slider_loc,
                               'K',
                               1.0,
                               8.0,
                               valinit=self.K,
                               valstep=1.0,
                               valfmt='%0.0f')
        self.k_slider_loc.set_visible(False)

        self.l_slider_loc = self.fig.add_axes(slider_rect(3), facecolor=face)
        self.l_slider = Slider(self.l_slider_loc,
                               'L',
                               0.0,
                               1000.0,
                               valinit=0,
                               valstep=20.0,
                               valfmt='%0.0f')
        self.l_slider_loc.set_visible(False)

        # quarantine option menu
        self.right_menu_x_loc = 0.025
        options_axes = plt.axes(
            [self.right_menu_x_loc, top - 4 * gap, 0.18, 0.15])
        self.options_button = CheckButtons(options_axes, ['apply\nquarantine'])

        # optional "save data" button
        self.get_stat_button = None
        if ALLOW_SAVE_DATA:
            stat_axes = plt.axes(
                [self.right_menu_x_loc, top + 15 * gap, 0.12, 0.04])
            self.get_stat_button = Button(stat_axes,
                                          'save data',
                                          color=face,
                                          hovercolor=hover)

        # animation speed read-out and its +/- buttons
        speed_y = 0.92
        plt.text(0.80, speed_y, 'speed: ', transform=self.fig.transFigure)
        self.speed_box = plt.text(0.88,
                                  speed_y,
                                  '',
                                  transform=self.fig.transFigure)
        faster_axes = plt.axes([0.94, speed_y, 0.02, 0.03])
        self.speed_up_button = Button(faster_axes, '+', hovercolor=hover)
        slower_axes = plt.axes([0.96, speed_y, 0.02, 0.03])
        self.speed_down_button = Button(slower_axes, '-', hovercolor=hover)

        # play / pause / reset
        play_axes = plt.axes([0.8, 0.025, 0.1, 0.04])
        self.play_button = Button(play_axes,
                                  'play',
                                  color=face,
                                  hovercolor=hover)
        pause_axes = plt.axes([0.69, 0.025, 0.1, 0.04])
        self.pause_button = Button(pause_axes,
                                   'pause',
                                   color=face,
                                   hovercolor=hover)
        reset_axes = plt.axes([0.56, 0.025, 0.12, 0.04])
        self.reset_button = Button(reset_axes,
                                   'reset',
                                   color=face,
                                   hovercolor=hover)

    def __set_p(self, e):
        """ slider callback: store the current P slider value """
        self.P = self.p_slider.val

    def __set_n(self, e):
        """ slider callback: store the current N slider value as an int """
        self.N = int(self.n_slider.val)

    def __set_k(self, e):
        """ callback: show/enable the K slider with the quarantine box and
        store K (0 while the box is not ticked) """
        quarantine_on = self.options_button.get_status()[0]
        self.k_slider_loc.set_visible(quarantine_on)
        self.k_slider.set_active(quarantine_on)
        if quarantine_on:
            self.K = int(self.k_slider.val)
        else:
            self.K = 0

    def __set_l(self, e):
        """ callback: show/enable the L slider with the quarantine box and
        store L (None while the box is not ticked) """
        quarantine_on = self.options_button.get_status()[0]
        self.l_slider_loc.set_visible(quarantine_on)
        self.l_slider.set_active(quarantine_on)
        if quarantine_on:
            self.L = int(self.l_slider.val)
        else:
            self.L = None

    def __reset_button_on_click(self, e):
        """ button callback: stop the animation and rebuild the game with the
        current parameters """
        self.animation.stop()
        self.game.build(self.N, self.P, self.K, self.L)

    def set_all(self):
        """ create the figure and widgets, then wire widgets to callbacks """
        self.fig, self.ax = plt.subplots()
        plt.subplots_adjust(left=0.25, bottom=0.25)
        self.__set_widgets()

        slider_handlers = (
            (self.p_slider, self.__set_p),
            (self.n_slider, self.__set_n),
            (self.k_slider, self.__set_k),
            (self.l_slider, self.__set_l),
        )
        for slider, handler in slider_handlers:
            slider.on_changed(handler)

        # ticking the quarantine box refreshes L first, then K
        for handler in (self.__set_l, self.__set_k):
            self.options_button.on_clicked(handler)
        self.reset_button.on_clicked(self.__reset_button_on_click)

    def start(self):
        """ create the statistics helpers, build the game and run the
        animation """
        # calculate game's statistics for each time step
        stats = GameStatistics(self.game)

        # consumers of the statistics
        ShowStatistics(stats, self.fig, x=self.right_menu_x_loc)
        StatAccumulator(stats, self.get_stat_button)
        if SHOW_ONLINE_GRAPH:
            OnlineGraph(stats)

        self.game.build(self.N, self.P, self.K, self.L)
        self.animation = CellAnimation(self.pause_button, self.play_button,
                                       self.speed_up_button,
                                       self.speed_down_button, self.speed_box,
                                       self.fig, self.ax)
        self.animation.start(self.game)
# +---------------------------------------------+ # | Show Figure(determin whether or not)(GUI) | # +---------------------------------------------+ correlate_threshold = 0.01 * 100 z_score_lag_threshold = 2.5 z_score_log_threshold = 2.5 depmax_diff = 0.9 if ((wasserstein_metric * 100).max() >= correlate_threshold or abs(zScoreLag).max() >= z_score_lag_threshold or (abs(zScoreLog).max() >= z_score_log_threshold and objects_depmax_log.max() - objects_depmax_log.min() > depmax_diff)): plt.show() # plt.show() DELET_BOOL_VALUE = check.get_status() # get the status of checkbuttons print(DELET_BOOL_VALUE) # plt.show() # # +-------------------------------+ # # | TEST: PLOT the sequencer data | # # +-------------------------------+ # fig2 = plt.figure(4, figsize=(7.5, 5)) # fig2_ax3 = plt.subplot2grid((1, 4), (0, 2), rowspan=1, colspan=2) # sb.heatmap(wasserstein_metric*100, annot=True, fmt='.1f') # # plt.gca().invert_yaxis() # invert the y axis # ticks_labels = [str(coun+1) for coun in range(num_ray)] # fig2_ax3.xaxis.tick_top() # plt.yticks(range(num_ray), ticks_labels) # plt.xticks(range(num_ray), ticks_labels)
class DefacingInterface(BaseReviewInterface):
    """Custom interface to rate the quality of defacing in an MRI scan"""

    def __init__(self,
                 fig,
                 axes,
                 issue_list=cfg.defacing_default_issue_list,
                 next_button_callback=None,
                 quit_button_callback=None,
                 processing_choice_callback=None,
                 map_key_to_callback=None):
        """Build the review widgets on top of the base interface.

        Args:
            fig: figure handed to ``BaseReviewInterface``.
            axes: axes handed to ``BaseReviewInterface``.
            issue_list: labels of the issue check boxes.
            next_button_callback: called to advance to the next subject.
            quit_button_callback: called to quit the review.
            processing_choice_callback: called with the label of the radio
                button the user selects.
            map_key_to_callback: optional dict mapping a lower-case key name
                to a zero-argument callable.

        Raises:
            ValueError: if ``map_key_to_callback`` is neither None nor a dict.
        """
        super().__init__(fig, axes, next_button_callback, quit_button_callback)

        self.issue_list = issue_list

        self.prev_axis = None
        self.prev_ax_pos = None
        self.zoomed_in = False

        self.next_button_callback = next_button_callback
        self.quit_button_callback = quit_button_callback
        self.processing_choice_callback = processing_choice_callback

        if map_key_to_callback is None:
            self.map_key_to_callback = {}  # empty
        elif isinstance(map_key_to_callback, dict):
            self.map_key_to_callback = map_key_to_callback
        else:
            raise ValueError('map_key_to_callback must be a dict')

        self.add_checkboxes()
        self.add_process_options()

        # include all the non-data axes here (so they wont be zoomed-in).
        # event.inaxes is an Axes, so the radio widget must contribute its
        # .ax; listing the widget itself never matched anything.
        self.unzoomable_axes = [
            self.checkbox.ax, self.text_box.ax, self.bt_next.ax,
            self.bt_quit.ax, self.radio_bt_vis_type.ax
        ]

        # this list of artists to be populated later
        # makes to handy to clean them all
        self.data_handles = []

    def add_checkboxes(self):
        """
        Checkboxes offer the ability to select multiple tags such as Motion,
        Ghosting Aliasing etc, instead of one from a list of mutual exclusive
        rating options (such as Good, Bad, Error etc).
        """

        ax_checkbox = plt.axes(cfg.position_checkbox_t1_mri,
                               facecolor=cfg.color_rating_axis)
        # initially de-activating all
        check_box_status = [False] * len(self.issue_list)
        self.checkbox = CheckButtons(ax_checkbox,
                                     labels=self.issue_list,
                                     actives=check_box_status)
        self.checkbox.on_clicked(self.save_issues)
        for txt_lbl in self.checkbox.labels:
            txt_lbl.set(color=cfg.text_option_color, fontweight='normal')

        for rect in self.checkbox.rectangles:
            rect.set_width(cfg.checkbox_rect_width)
            rect.set_height(cfg.checkbox_rect_height)

        # lines is a list of n crosses, each cross (x) defined by a tuple of lines
        for x_line1, x_line2 in self.checkbox.lines:
            x_line1.set_color(cfg.checkbox_cross_color)
            x_line2.set_color(cfg.checkbox_cross_color)

        # Position of the "pass" box among the boxes actually displayed, i.e.
        # within self.issue_list (identical to the default list when no
        # custom list is given). None when the list has no pass entry.
        try:
            self._index_pass = list(self.issue_list).index(
                cfg.defacing_pass_indicator)
        except ValueError:
            self._index_pass = None

    def add_process_options(self):
        """Add the radio buttons that select the visualization type."""

        ax_radio = plt.axes(cfg.position_radio_bt_t1_mri,
                            facecolor=cfg.color_rating_axis)
        self.radio_bt_vis_type = RadioButtons(ax_radio,
                                              cfg.vis_choices_defacing,
                                              active=None,
                                              activecolor='orange')
        self.radio_bt_vis_type.on_clicked(self.processing_choice_callback)
        for txt_lbl in self.radio_bt_vis_type.labels:
            txt_lbl.set(color=cfg.text_option_color, fontweight='normal')

        for circ in self.radio_bt_vis_type.circles:
            circ.set(radius=0.06)

    def save_issues(self, label):
        """
        Update the rating

        This function is called whenever set_active() happens on any label,
        if checkbox.eventson is True.

        Ticking the pass label clears every other box; ticking any other
        label clears the pass box.
        """

        # NOTE(review): the label is compared with cfg.visual_qc_pass_indicator
        # while the pass index is looked up via cfg.defacing_pass_indicator -
        # confirm both name the same label.
        if label == cfg.visual_qc_pass_indicator:
            self.clear_checkboxes(except_pass=True)
        else:
            self.clear_pass_only_if_on()

        self.fig.canvas.draw_idle()

    def clear_checkboxes(self, except_pass=False):
        """Clears all checkboxes.

        if except_pass=True,
            does not clear the checkbox at the pass index
        """

        cbox_statuses = self.checkbox.get_status()
        for index, this_cbox_active in enumerate(cbox_statuses):
            if except_pass and index == self._index_pass:
                continue
            # if it was selected already, toggle it.
            if this_cbox_active:
                # not calling checkbox.set_active() as it calls the callback
                #   self.save_issues() each time, if eventson is True
                self._toggle_visibility_checkbox(index)

    def clear_pass_only_if_on(self):
        """Clear pass checkbox only"""

        if self._index_pass is None:
            return  # the displayed list has no pass box

        cbox_statuses = self.checkbox.get_status()
        if cbox_statuses[self._index_pass]:
            self._toggle_visibility_checkbox(self._index_pass)

    def _toggle_visibility_checkbox(self, index):
        """toggles the visibility of a given checkbox"""

        l1, l2 = self.checkbox.lines[index]
        l1.set_visible(not l1.get_visible())
        l2.set_visible(not l2.get_visible())

    def get_ratings(self):
        """Returns the final set of checked ratings"""

        cbox_statuses = self.checkbox.get_status()
        user_ratings = [
            self.checkbox.labels[idx].get_text()
            for idx, this_cbox_active in enumerate(cbox_statuses)
            if this_cbox_active
        ]

        return user_ratings

    def allowed_to_advance(self):
        """
        Method to ensure work is done for current iteration,
        before allowing the user to advance to next subject.

        Returns True only when at least one checkbox is checked.
        """

        return any(self.checkbox.get_status())

    def reset_figure(self):
        """Resets the figure to prepare it for display of next subject."""

        self.clear_data()
        self.clear_checkboxes()
        self.clear_radio_buttons()
        self.clear_notes_annot()

    def clear_data(self):
        """clearing all data/image handles"""

        if self.data_handles:
            for artist in self.data_handles:
                artist.remove()
            # resetting it
            self.data_handles = []

    def clear_notes_annot(self):
        """clearing notes and annotations"""

        self.text_box.set_val(cfg.textbox_initial_text)
        # text is matplotlib artist
        self.annot_text.remove()

    def clear_radio_buttons(self):
        """Clears the radio button"""

        # enabling default rating encourages lazy advancing without review
        # self.radio_bt_rating.set_active(cfg.index_freesurfer_default_rating)
        for index, label in enumerate(self.radio_bt_vis_type.labels):
            if label.get_text() == self.radio_bt_vis_type.value_selected:
                # paint the selected circle in the background color
                self.radio_bt_vis_type.circles[index].set_facecolor(
                    cfg.color_rating_axis)
                break
        self.radio_bt_vis_type.value_selected = None

    def on_mouse(self, event):
        """Callback for mouse events."""

        # any click outside the widget axes first restores a zoomed axis
        if self.prev_axis is not None:
            if event.inaxes not in self.unzoomable_axes:
                self.prev_axis.set_position(self.prev_ax_pos)
                self.prev_axis.set_zorder(0)
                self.prev_axis.patch.set_alpha(0.5)
                self.zoomed_in = False

        # right or double click to zoom in to any axis
        if (event.button in [3] or event.dblclick) and \
                (event.inaxes is not None) and \
                event.inaxes not in self.unzoomable_axes:
            self.prev_ax_pos = event.inaxes.get_position()
            event.inaxes.set_position(cfg.zoomed_position)
            event.inaxes.set_zorder(1)  # bring forth
            event.inaxes.set_facecolor('black')  # black
            event.inaxes.patch.set_alpha(1.0)  # opaque
            self.zoomed_in = True
            self.prev_axis = event.inaxes

        self.fig.canvas.draw_idle()

    def on_keyboard(self, key_in):
        """Callback to handle keyboard shortcuts to rate and advance."""

        # ignore keyboard key_in when mouse within Notes textbox
        if key_in.inaxes == self.text_box.ax or key_in.key is None:
            return

        key_pressed = key_in.key.lower()
        if key_pressed in ['right', ' ', 'space']:
            self.next_button_callback()
        elif key_pressed in ['ctrl+q', 'q+ctrl']:
            self.quit_button_callback()
        elif key_pressed in self.map_key_to_callback:
            # notice parentheses at the end
            self.map_key_to_callback[key_pressed]()
        else:
            # NOTE(review): the shortcut table and index come from the T1 MRI
            # issue list, not from this interface's issue list - confirm the
            # indices line up with the defacing check boxes.
            if key_pressed in cfg.abbreviation_t1_mri_default_issue_list:
                checked_label = cfg.abbreviation_t1_mri_default_issue_list[
                    key_pressed]
                self.checkbox.set_active(
                    cfg.t1_mri_default_issue_list.index(checked_label))

        self.fig.canvas.draw_idle()
class ImageEditor(BaseEditor):
    """Interactive editor for a region of interest and its skeletal curves.

    Holds an editable RGBA copy of the image, the region of interest, a
    lazily computed distance field and travel field, a start (basal) point,
    target (tip) points, and one path per target.
    """

    def __init__(self,
                 image,
                 range_of_interest,
                 mask_func,
                 estimate_p_start,
                 estimate_p_target,
                 resize_rate=1.0,
                 travel_threshold=0.0,
                 wave_coef=2.0,
                 step_width=1.0):
        '''
        Args:
            image (np.ndarray(uint8)): raw image, shape=(H, W, 3); converted
                from BGR to RGB for editing and display
            range_of_interest (np.ndarray(bool)): mask combined with the
                output of ``mask_func`` using ``&``
                (presumably shape=(H, W) - confirm against callers)
            mask_func (func): mask estimator
            estimate_p_start (func): basal point estimator
            estimate_p_target (func): tip point estimator
            resize_rate (float, optional): forwarded to both estimators;
                paths returned by ``retrieve`` are divided by it.
                Defaults to 1.
            travel_threshold (float, optional): Boundary dist, Defaults to 0.
            wave_coef (float, optional): wave speed coefficient. Defaults to 2.
            step_width (float, optional): backward step width. Defaults to 1.
        '''
        self.image_size = tuple(image.shape[:2])
        # RGB image plus a fully opaque alpha channel
        self.image_edit = np.concatenate([
            cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
            np.full((*image.shape[:2], 1), 255, dtype=np.uint8)
        ],
                                         axis=2)
        # NOTE(review): mask_func receives the BGR input here but the RGB
        # channels of image_edit in reset()/update() - confirm intended.
        self.roi_raw = range_of_interest & mask_func(image)
        self.roi_edit = self.roi_raw
        self.mask_func = mask_func
        self.estimate_p_start = estimate_p_start
        self.estimate_p_target = estimate_p_target
        self.extensions = []
        self.extensions += self.estimate_p_start.func.extensions
        self.extensions += self.estimate_p_target.func.extensions
        self.extensions += self.mask_func.extensions
        self._resize_rate = resize_rate
        self._wave_coef = wave_coef
        self._travel_threshold = travel_threshold
        self._step_width = step_width
        # both fields are computed on first access
        self._update_dist_flag = self._update_travel_flag = True

    def _trace_paths(self, travel_field):
        """Estimate one path per target point on *travel_field* (not None)."""
        # cells closer to the boundary than the threshold get a high cost
        travel_field[self.dist_field < self.travel_threshold] = 1e5
        for idx, p_target in enumerate(self.p_target):
            self.is_converged[idx], self.path[idx] = estimate_path(
                travel_field, self.p_start, p_target, self.step_width)

    def estimate(self):
        """Estimate start point, target points and paths without any GUI."""
        # NOTE(review): the 4-channel image_edit is passed here while the
        # button callbacks pass image_edit[..., :3] - confirm intended.
        self.p_start = self.estimate_p_start(self.image_edit,
                                             self.roi_edit,
                                             self.dist_field,
                                             resize_rate=self._resize_rate)
        self.p_start_init = np.array(self.p_start)
        self.endpoint_num, self.p_target = self.estimate_p_target(
            self.image_edit,
            self.roi_edit,
            self.dist_field,
            self.travel_field,
            resize_rate=self._resize_rate)
        self.p_target_init = np.array(self.p_target)
        self.is_converged = [False] * self.endpoint_num
        self.path = [np.array([]) for _ in range(self.endpoint_num)]
        travel_field = self.travel_field
        if travel_field is None:
            return
        self._trace_paths(travel_field)

    def retrieve(self):
        """Return ``[converged, path / resize_rate]`` for every target."""
        return [[_check, _path / self._resize_rate]
                for _check, _path in zip(self.is_converged, self.path)]

    def launch(self):
        """Build the figure with all widgets and show it."""
        if not hasattr(self, "path"):
            self.estimate()
        self.fig = Figure(figsize=(15, 6))
        self.fig.create_grid((1, 3), wspace=0.0, width_ratios=(2, 1, 2))
        self.fig[1].create_grid((4, 3),
                                wspace=0.0,
                                width_ratios=(1, 4.5, 1),
                                hspace=0.05,
                                height_ratios=(5, 3.5, 4, 7))

        # command buttons
        ax_command = self.fig[1][0, 1]
        ax_command.create_grid((5, 1), hspace=0.0)
        self.button_reset = Button(ax_command[0], 'reset')
        self.button_reset.label.set_fontsize(15)
        self.button_reset.on_clicked(self.reset)
        self.button_start = Button(ax_command[1], 'estimate basal point')
        self.button_start.label.set_fontsize(12)
        self.button_start.on_clicked(self._estimate_start)
        self.button_target = Button(ax_command[2], 'estimate tip point(s)')
        self.button_target.label.set_fontsize(12)
        self.button_target.on_clicked(self._estimate_target)
        self.button_path = Button(ax_command[3], 'estimate skeletal curve')
        self.button_path.label.set_fontsize(10)
        self.button_path.on_clicked(self._estimate_path)
        self.button_end = Button(ax_command[4], 'end')
        self.button_end.label.set_fontsize(15)
        self.button_end.on_clicked(self.terminate)
        self.fig._fig.canvas.mpl_connect('close_event', self.terminate)

        # mouse interaction
        ax_action = self.fig[1][1, 1]
        ax_action.create_grid((2, 1), hspace=0.0, height_ratios=(5, 2))
        self.fig._fig.canvas.mpl_connect('button_press_event',
                                         self._on_click_set)
        self.fig._fig.canvas.mpl_connect('motion_notify_event',
                                         self._on_click_set)
        self._set_interactive_tool(self.fig[0], ax_action[0], ax_action[1])

        # field overlays and parameter sliders
        ax_fmm_param = self.fig[1][2, 1]
        ax_fmm_param.create_grid((4, 1),
                                 hspace=0.0,
                                 height_ratios=(2, 1, 1, 1))
        self.button_field = CheckButtons(ax_fmm_param[0],
                                         ["distance", "travel"])
        for _label in self.button_field.labels:
            _label.set_fontsize(12)
        self.button_field.on_clicked(self.render)
        self.slider_wave = Slider(ax_fmm_param[1],
                                  r'$\alpha$',
                                  0.0,
                                  5.0,
                                  valinit=self._wave_coef,
                                  valfmt="%.1f")
        self.slider_dist = Slider(ax_fmm_param[2],
                                  r'$\theta_T$',
                                  -5.0,
                                  5.0,
                                  valinit=self._travel_threshold,
                                  valfmt="%.1f")
        self.slider_step = Slider(ax_fmm_param[3],
                                  r'$\delta$',
                                  0.1,
                                  5.0,
                                  valinit=self._step_width,
                                  valfmt="%.1f")
        # the step slider has no callback: its value is only read when a
        # path is estimated
        self.slider_wave.on_changed(self._on_click_slider)
        self.slider_dist.on_changed(self._on_click_slider)
        self.sliders = [self.slider_wave, self.slider_dist, self.slider_step]

        # extensions; the first one gets double height
        ax_extension = self.fig[1][3, 1]
        ratios = [1] * len(self.extensions)
        if ratios:  # no extensions at all must not raise IndexError
            ratios[0] = 2
        self._set_extension(ax_extension, height_ratios=ratios)

        # left panel: editable extracted area
        self.fig[0].add_zoom_func()
        rect = plt.Rectangle((0, 0),
                             self.image_size[1],
                             self.image_size[0],
                             fc='w',
                             ec='gray',
                             hatch='++',
                             alpha=0.5,
                             zorder=-10)
        self.fig[0].add_patch(rect)
        self.im_image_left = self.fig[0].plot_matrix(np.zeros(
            (*self.image_size, 4)),
                                                     picker=True,
                                                     flip_axis=True,
                                                     colorbar=False)
        self.fig[0].set_xlim([0, self.image_size[1]])
        self.fig[0].set_ylim([self.image_size[0], 0])
        self.fig[0].set_aspect("equal", "box")
        self.fig[0].set_title("extracted area "
                              "(ROI+HSV, left click: write, "
                              "left click + shift key: erase)")
        self.im_start, = self.fig[0].plot(self.p_start[1],
                                          self.p_start[0],
                                          "ro",
                                          markersize=10)
        self.im_target = [None] * self.endpoint_num
        for _, _p_target in enumerate(self.p_target):
            if _p_target is None:
                # unknown target: hidden marker, moved to the end of the
                # path when one exists
                self.im_target[_], = self.fig[0].plot(0,
                                                      0,
                                                      "*",
                                                      markersize=10)
                self.im_target[_].set_visible(False)
                if len(self.path[_]) > 0:
                    self.im_target[_].set_data(*self.path[_][-1][::-1])
                    self.im_target[_].set_visible(True)
            else:
                self.im_target[_], = self.fig[0].plot(_p_target[1],
                                                      _p_target[0],
                                                      "*",
                                                      markersize=10)

        # right panel: original image
        self.fig[2].add_zoom_func()
        self.im_image_right = self.fig[2].plot_matrix(self.image_edit[...,
                                                                      :3],
                                                      picker=True,
                                                      flip_axis=True,
                                                      colorbar=False)
        self.fig[2].set_xlim([0, self.image_size[1]])
        self.fig[2].set_ylim([self.image_size[0], 0])
        self.fig[2].set_aspect("equal", "box")
        self.fig[2].get_xaxis().set_visible(False)
        self.fig[2].get_yaxis().set_visible(False)
        self.fig[2].set_title("original image")

        # one path line per target in both panels
        self.im_path_left = [None] * self.endpoint_num
        self.im_path_right = [None] * self.endpoint_num
        for _ in range(self.endpoint_num):
            self.im_path_left[_], = self.fig[0].plot([], [],
                                                     color="pink",
                                                     marker="+",
                                                     picker=True)
            self.im_path_right[_], = self.fig[2].plot([], [],
                                                      color="pink",
                                                      marker="+",
                                                      picker=True)
        self.reset()
        self.fig.show()

    def reset(self, event=None):
        """Restore ROI, extensions, sliders, points and paths, then redraw."""
        self.roi_edit = self.roi_raw & self.mask_func(self.image_edit[..., :3])
        for ext in self.extensions:
            ext.reset(self, event)
        for sld in self.sliders:
            sld.reset()
        self.path = [np.array([]) for _ in range(self.endpoint_num)]
        self.is_converged = [False] * self.endpoint_num
        self.p_start = np.array(self.p_start_init)
        self.p_target = np.array(self.p_target_init)
        self._update_dist_flag = True
        self._update_travel_flag = True
        self.update()

    def update(self, use_mask_func=False):
        """Write the ROI into the alpha channel and redraw.

        Args:
            use_mask_func (bool): recompute the ROI with ``mask_func`` first.
        """
        if use_mask_func:
            mask = self.mask_func(self.image_edit[..., :3])
            self.roi_edit = self.roi_raw & mask
        # pixels outside the ROI become transparent
        self.image_edit[~self.roi_edit, 3] = 0
        self.image_edit[self.roi_edit, 3] = 255
        self.im_image_left.set_data(self.image_edit[::-1])
        self._update_dist_flag = True
        self._update_travel_flag = True
        self.render()

    def render(self, event=None):
        """Redraw markers, paths and the optional field contours."""
        self.im_start.set_data(self.p_start[1], self.p_start[0])
        for _id, _p_target in enumerate(self.p_target):
            if _p_target is None:
                continue  # unknown target: leave its marker untouched
            self.im_target[_id].set_data(_p_target[1], _p_target[0])
        for _i, _path in enumerate(self.path):
            if _path.size > 0:
                self.im_path_left[_i].set_data(_path[:, 1], _path[:, 0])
                self.im_path_right[_i].set_data(_path[:, 1], _path[:, 0])

        show_dist, show_travel = self.button_field.get_status()[:2]

        # im_dist / im_travel may not exist before the first field access
        if show_dist and (getattr(self, "im_dist", None) is None
                          or self._update_dist_flag):
            # the property already returns a copy (or None); wrapping it in
            # np.array first would hide the None case
            dist_field = self.dist_field
            if dist_field is not None:
                dist_field[dist_field < 0] = 0
                self.im_dist = self.fig[0].plot_matrix(dist_field,
                                                       contour=True,
                                                       flip_axis=False,
                                                       levels=100,
                                                       colorbar=False)
        if getattr(self, "im_dist", None) is not None:
            for _ in self.im_dist.collections:
                _.set_visible(show_dist)

        if show_travel and (getattr(self, "im_travel", None) is None
                            or self._update_travel_flag):
            travel_field = self.travel_field
            if travel_field is not None:
                self.im_travel = self.fig[0].plot_matrix(travel_field,
                                                         contour=True,
                                                         flip_axis=False,
                                                         levels=100,
                                                         colorbar=False)
        if getattr(self, "im_travel", None) is not None:
            for _ in self.im_travel.collections:
                _.set_visible(show_travel)
        self.fig._fig.canvas.draw_idle()

    def terminate(self, event=None):
        """Close the figure."""
        self.fig.close()

    @property
    def wave_coef(self):
        """Slider value once the GUI exists, else the constructor value."""
        if hasattr(self, "slider_wave"):
            return self.slider_wave.val
        else:
            return self._wave_coef

    @property
    def travel_threshold(self):
        """Slider value once the GUI exists, else the constructor value."""
        if hasattr(self, "slider_dist"):
            return self.slider_dist.val
        else:
            return self._travel_threshold

    @property
    def step_width(self):
        """Slider value once the GUI exists, else the constructor value."""
        if hasattr(self, "slider_step"):
            return self.slider_step.val
        else:
            return self._step_width

    @property
    def dist_field(self):
        """Copy of the distance field of the ROI, recomputed when flagged.

        Recomputing removes the stale contour plot and marks the travel
        field as outdated.
        """
        if self._update_dist_flag or (not hasattr(self, "_dist_field")):
            self._dist_field = calc_distance(self.roi_edit)
            self._update_dist_flag = False
            if hasattr(self, "im_dist") and (self.im_dist is not None):
                for _ in self.im_dist.collections:
                    _.remove()
            self.im_dist = None
            self._update_travel_flag = True
        if self._dist_field is None:
            return None
        else:
            return np.array(self._dist_field)

    @property
    def travel_field(self):
        """Copy of the travel field, recomputed when flagged.

        Recomputing removes the stale contour plot.
        """
        if self._update_travel_flag or (not hasattr(self, "_travel_field")):
            self._travel_field = calc_travel(self.dist_field, self.p_start,
                                             self.wave_coef,
                                             self.travel_threshold)
            self._update_travel_flag = False
            if hasattr(self, "im_travel") and (self.im_travel is not None):
                for _ in self.im_travel.collections:
                    _.remove()
            self.im_travel = None
        if self._travel_field is None:
            return None
        else:
            return np.array(self._travel_field)

    def _on_click_slider(self, event):
        """Slider callback: the travel field depends on the slider values."""
        self._update_travel_flag = True
        self.render()

    def _on_click_set(self, event):
        """Press/motion callback: drag the start point or a target point."""
        # only react inside the editing area, without modifier keys
        if not (self.area_bbox.x0 < event.x < self.area_bbox.x1):
            return
        if not (self.area_bbox.y0 < event.y < self.area_bbox.y1):
            return
        if event.key is not None:
            return
        if event.button is None:
            # button released: the drag is over
            self.drag_id = None
            return

        def _position(ax):
            # display (pixel) coordinates of a marker
            coords = ax.get_path().vertices
            return ax.get_transform().transform(coords)

        p_mouse = np.array([event.x, event.y])
        p_mouse_data = np.array([event.ydata, event.xdata])
        if event.button == 1 and self.drag_id is None:
            # grab a marker within 10 pixels; -1 stands for the start point
            if np.linalg.norm(_position(self.im_start) - p_mouse) < 10.0:
                self.drag_id = -1
            else:
                for _id in range(self.endpoint_num):
                    if np.linalg.norm(
                            _position(self.im_target[_id]) - p_mouse) < 10.0:
                        self.drag_id = _id
                        break
        if self.drag_id == -1:
            self.p_start = p_mouse_data
            self._update_travel_flag = True
        elif self.drag_id is not None:
            self.p_target[self.drag_id] = p_mouse_data
            self.im_target[self.drag_id].set_data(
                self.p_target[self.drag_id][1],
                self.p_target[self.drag_id][0])
        self.render()

    def _estimate_start(self, event=None):
        """Button callback: re-estimate the basal point."""
        self.p_start = self.estimate_p_start(self.image_edit[..., :3],
                                             self.roi_edit,
                                             self.dist_field,
                                             p_start_now=self.p_start,
                                             resize_rate=self._resize_rate)
        self._update_travel_flag = True
        self.render()

    def _estimate_target(self, event=None):
        """Button callback: re-estimate the tip point(s)."""
        self.endpoint_num, self.p_target = self.estimate_p_target(
            self.image_edit[..., :3],
            self.roi_edit,
            self.dist_field,
            self.travel_field,
            p_target_now=self.p_target,
            resize_rate=self._resize_rate)
        self.render()

    def _estimate_path(self, event=None):
        """Button callback: re-estimate the path to every target."""
        travel_field = self.travel_field
        if travel_field is None:
            # no field available: clear every path
            for _, _p_target in enumerate(self.p_target):
                self.is_converged[_], self.path[_] = False, np.array([])
        else:
            self._trace_paths(travel_field)
        self.render()
class LyaGUI(SimpleFitGUI):
    """
    This is a GUI to perform model-independent measurements of Ly-alpha line
    profiles (primarily interesting for emission). It measures the following
    characteristics, when present:

    - Equivalent Width
    - Red and blue peak velocity and separation
    - Red and blue peak flux relative to continuum, and their ratio
    - Red and blue integrated flux, and their ratio
    - Peak-to-valley flux ratios.
    - Asymmetry (crookedity?) of red peak
    """
    galaxy_name = ''
    data = None
    interp = None
    errs = None
    # mask = None
    smoothed = None
    num_realizations = 1000
    z = 0.
    # The dict-valued attributes below are class-level fallbacks only;
    # __init__ rebinds a fresh dict per instance so that two GUIs never
    # share measurements or plot artists.
    summary_dict = {}
    cenwave = 1215.67 * (1 + z)  # rest value; __init__ updates it for z
    _current_peak = 'Red'
    _peaks = np.array(['Red', 'Blue', 'Valley'])
    _peak_on = np.array([True, True, True])
    _colors = {'Red': 'tab:red', 'Blue': 'tab:blue', 'Valley': '0.7'}
    _extra_plots = {}
    _velocities = {}  # Just for use in plots
    # Percentiles reported for every Monte Carlo distribution.
    _percentiles = (2.5, 16, 50, 84, 97.5)

    def __init__(self, summaryfile=None, inwave=None, indata=None,
                 inerrs=None, inmask=None, smooth=None):
        """Load a spectrum, interpolate over masked pixels and build the GUI.

        Parameters
        ----------
        summaryfile : str, optional
            YAML summary file; when given, its contents replace the
            ``in*`` arrays (see `open_summary`).
        inwave, indata, inerrs, inmask : array-like, optional
            Wavelength, flux, flux error and boolean mask (True = masked).
            ``inmask=None`` means nothing is masked; ``inerrs=None``
            disables the Monte Carlo perturbations.
        smooth : {'ll', 'box', None}
            ``'ll'`` runs ``KernelReg``, ``'box'`` a 3-pixel boxcar; any
            other value stores the (masked) data as ``smoothed``.

        Raises
        ------
        TypeError
            If no flux or wavelength array is available.
        """
        self.data = indata
        self.wave = inwave
        self.errs = inerrs
        self.mask = inmask
        if summaryfile:
            self.open_summary(summaryfile)
        if self.data is None or self.wave is None:
            raise TypeError('LyaGUI needs a spectrum: pass `summaryfile` '
                            'or both `inwave` and `indata`.')

        # Per-instance state (the class-level dicts would be shared).
        self.summary_dict = {}
        self._extra_plots = {}
        self._velocities = {}
        self.cenwave = self.refwave

        # Work on a float copy so NaN can be stored and the caller's array
        # is left untouched.
        self.data = np.array(self.data, dtype=float)
        if self.mask is None:
            self.mask = np.zeros(self.data.shape, dtype=bool)
        else:
            self.mask = np.asarray(self.mask, dtype=bool)

        # Interpolate masked areas
        self.data[self.mask] = np.nan
        self.nonnanidx = np.where(~self.mask)[0]
        self.interp = np.interp(self.wave, self.wave[self.nonnanidx],
                                self.data[self.nonnanidx])
        if self.errs is None:
            self.interr = None
        else:
            self.interr = np.interp(self.wave, self.wave[self.nonnanidx],
                                    self.errs[self.nonnanidx])

        if smooth == 'll':
            lle = KernelReg(self.interp, self.wave, 'c', bw=[10])
            self.smoothed, _ = lle.fit()
        elif smooth == 'box':
            # Smooth the interpolated flux (no NaNs) and keep its length.
            self.smoothed = np.convolve(self.interp, np.ones(3) / 3,
                                        mode='same')
        else:
            self.smoothed = self.data
        self._build_plot()

    def open_summary(self, summary_file):
        """Read redshift, name, Ly-alpha arrays and continuum fit from YAML.

        The continuum parameters are looked up under ``cont_fit_params``
        and, failing that, ``continuum_fit_params``.
        """
        with open(summary_file) as f:
            # NOTE(security): full_load is less restrictive than safe_load;
            # only open summary files from trusted sources. Switch to
            # yaml.safe_load if the files hold plain YAML types only.
            summary = yaml.full_load(f)
        lya = summary['transitions']['Ly alpha']
        self.z = summary['redshift']
        self.galaxy_name = summary['galaxyname']
        self.data = np.array(lya['data'])
        self.errs = np.array(lya['errs'])
        self.wave = np.array(lya['wave'])
        self.mask = np.array(lya['mask'])
        try:
            self._cfp = lya['cont_fit_params']
        except KeyError:
            self._cfp = lya['continuum_fit_params']
        self._cont = self.wave * self._cfp['slope'] + self._cfp['intercept']

    def _build_plot(self):
        """Create the figure, the spectrum plot and all widgets."""
        self.fig = plt.figure(figsize=(8, 4))
        self.ax = self.fig.add_axes([0.08, 0.08, 0.8, 0.8])
        self.chax = self.fig.add_axes([0.90, 0.55, 0.09, 0.3], frameon=False)
        self.chax.set_title('Components \npresent', size='x-small')
        self.rax = self.fig.add_axes([0.90, 0.25, 0.09, 0.2], frameon=False)
        self.rax.set_title('Current \ncomponent', size='x-small')
        self.okax = self.fig.add_axes([0.9, 0.05, 0.09, 0.08])
        self.ok_button = Button(self.okax, 'Done')
        self.ok_button.on_clicked(self._ok_clicked)
        self.reax = self.fig.add_axes([0.9, 0.15, 0.09, 0.08])
        self.re_button = Button(self.reax, 'Reset')
        self.re_button.on_clicked(self._reset_clicked)
        for ax in self.chax, self.rax:
            # Booleans, not the legacy 'off' string (which is truthy).
            ax.tick_params(length=0, labelleft=False, labelbottom=False)

        self.ax.plot(self.wave, self.data, 'k-', drawstyle='steps-mid')
        if self.interp is not None:
            self.ax.plot(self.wave, self.interp, '-', color='C1', zorder=0)
        self.ax.axhline(0, ls='-', color='k', lw=1)
        self.ax.axvline(1215.67 * (1 + self.z), ls=':', color='k', lw=.8)
        self.ax.axhline(1, ls=':', color='k', lw=.8)

        # Insert check buttons
        clabels = self._peaks
        cons = self._peak_on
        self.check = CheckButtons(self.chax, clabels, cons)
        self.check.on_clicked(self._check_clicked)
        rlabels = clabels[np.where(cons)]
        self.radio = RadioButtons(self.rax, rlabels)
        # _radio_clicked is intentionally left unconnected, as before.
        self._selector = SpanSelector(
            self.ax,
            self._onselect,
            'horizontal',
            minspan=self._wave_diffs.mean() * 5,
        )
        self.ax.set_xlabel('Wavelength')
        self.ax.set_ylabel('Normalized flux')

    def _set_selector_color(self, color):
        """Set the span selector's face colour on old and new matplotlib."""
        selector = self._selector
        if hasattr(selector, 'set_props'):
            selector.set_props(facecolor=color)
        else:
            selector.rectprops['facecolor'] = color

    def _set_span(self, key, xmin, xmax, **kwargs):
        """Shade [xmin, xmax] for one component, replacing its old span."""
        previous = self._extra_plots.pop(key, None)
        if previous is not None:
            previous.remove()
        self._extra_plots[key] = self.ax.axvspan(xmin, xmax, **kwargs)

    def _pct(self, values):
        """Return the standard percentile summary of a distribution."""
        return np.percentile(values, self._percentiles)

    @staticmethod
    def _err_triplet(pct):
        """Turn a 5-percentile summary into (lower err, median, upper err)."""
        return pct[2] - pct[1], pct[2], pct[3] - pct[2]

    def _onselect(self, xmin, xmax):
        """SpanSelector callback: measure the currently selected component."""
        self.fit_active(xmin, xmax)

    def _ok_clicked(self, event):
        """'Done' callback: save and mark the red-peak velocities, if any."""
        self.save_summary()
        red_summary = self.summary_dict.get('Red')
        red_vel = self._velocities.get('Red')
        # Nothing to mark until the red peak has been measured.
        if red_summary is not None and red_vel is not None:
            marks = (red_summary['vpeak'][2], red_vel['v05'],
                     red_vel['v50'], red_vel['v95'])
            for velocity in marks:
                self.ax.axvline(v_to_wl(velocity, self.refwave),
                                ls='--', color='0.5')
        plt.draw()

    def _reset_clicked(self, event):
        """'Reset' callback: drop all measurements and shaded spans."""
        self.summary_dict = {}
        for a in self._extra_plots.values():
            a.remove()
        # Forget the removed artists so a second reset does not try to
        # remove them again.
        self._extra_plots.clear()
        self._velocities.clear()
        self.fig.canvas.draw_idle()
        print('You presset RESET and are now back to scratch.')

    def _radio_clicked(self, event):
        """Colour the selector after the chosen component label."""
        self._set_selector_color(self._colors[event])

    def _check_clicked(self, event):
        """Rebuild the radio buttons from the components ticked as present."""
        print('Event: ', event)
        self.rax.remove()
        self.rax = self.fig.add_axes([0.90, 0.25, 0.09, 0.2], frameon=False)
        self.rax.set_title('Current \ncomponent', size='x-small')
        self._peak_on = np.array(self.check.get_status())
        rlabels = self._peaks[np.where(self.check.get_status())]
        self.radio = RadioButtons(self.rax, rlabels)
        plt.draw()

    def measure_flux(self, xmin, xmax, iters=1):
        """Measure line properties between ``xmin`` and ``xmax``.

        The first iteration uses the interpolated data as is; every further
        iteration adds Gaussian noise drawn from ``self.errs``. Without
        errors only the unperturbed measurement is made. The median
        v05/v50/v95 velocities are stored in ``_velocities`` under the
        currently selected radio label.

        Returns
        -------
        tuple of lists
            ``(fluxes, vmaxs, vmins, fwhms, bhms, rhms, asymmetry, fmin,
            fmax, asymGronKo)``, one entry per iteration.
        """
        sel = (self.wave > xmin) & (self.wave < xmax)
        vel = wl_to_v(self.wave[sel], self.refwave)
        dat = self.interp[sel]
        dwave = self._wave_diffs[sel]  # hoisted: the property recomputes
        if self.errs is None:
            sigma = None
            iters = 1
        else:
            sigma = np.asarray(self.errs)[sel]

        fluxes, vmaxs, vmins, fwhms = [], [], [], []
        bhms, rhms, asymmetry, asymGronKo = [], [], [], []
        fmax, fmin = [], []
        v05, v50, v95 = [], [], []
        for i in range(iters):
            if i == 0:
                pertdata = dat
            else:
                pertdata = dat + np.random.normal(scale=sigma)
            excess = pertdata - 1  # flux above the normalised continuum
            ipeak = pertdata.argmax()

            fluxes.append((pertdata * dwave).sum())
            vmaxs.append(vel[ipeak])
            vmins.append(vel[pertdata.argmin()])

            cumflux = np.cumsum(excess)
            ftot = cumflux.max()
            cumfrac = cumflux / ftot
            q05 = vel[np.absolute(cumfrac - 0.05).argmin()]
            q50 = vel[np.absolute(cumfrac - 0.50).argmin()]
            q95 = vel[np.absolute(cumfrac - 0.95).argmin()]
            asymmetry.append((q95 - q50) / (q50 - q05))
            fpeak = cumflux[ipeak]
            asymGronKo.append((ftot - fpeak) / fpeak)

            fwidx = np.where(excess > excess.max() / 2)[0]
            bhm = vel[fwidx.min()]
            rhm = vel[fwidx.max()]
            bhms.append(bhm)
            rhms.append(rhm)
            fwhms.append(rhm - bhm)
            fmax.append(excess.max())
            fmin.append(excess.min())
            v05.append(q05)
            v50.append(q50)
            v95.append(q95)

        self._velocities[self.radio.value_selected] = {
            'v05': np.median(v05),
            'v50': np.median(v50),
            'v95': np.median(v95),
        }
        return (fluxes, vmaxs, vmins, fwhms, bhms, rhms, asymmetry, fmin,
                fmax, asymGronKo)

    def _perturbed_integrals(self, values, errors, xmin, xmax, iters):
        """Integrate ``values`` over wavelength, once plain and then noisy.

        ``values`` is zeroed in place outside [xmin, xmax]; missing limits
        default to the extremes of the ranges of the active components.
        With ``errors`` of None only the unperturbed integral is returned.
        """
        ranges = [self.summary_dict[i]['range'] for i in self._peaks_active]
        therange = np.array(ranges).flatten()
        if xmin is None:
            xmin = therange.min()
        if xmax is None:
            xmax = therange.max()
        if xmin:
            values[self.wave < xmin] = 0
        if xmax:
            values[self.wave > xmax] = 0
        if errors is None:
            iters = 1
        dwave = self._wave_diffs
        integrals = []
        for i in range(iters):
            if i == 0:
                pertdata = values
            else:
                # NOTE(review): noise is added over the whole array, also
                # outside [xmin, xmax] - confirm that this is intended.
                pertdata = values + np.random.normal(
                    scale=np.absolute(errors))
            integrals.append((pertdata * dwave).sum())
        return integrals

    def absolute_flux(self, xmin=None, xmax=None, iters=1):
        """Store and return the percentiles of the continuum-scaled flux."""
        # TODO Write sane defaults for xmin and xmax, should go via velocity.
        afarray = (self.interp - 1) * self._cont
        aferrar = None if self.interr is None else self.interr * self._cont
        afs = self._perturbed_integrals(afarray, aferrar, xmin, xmax, iters)
        self.summary_dict['AbsFlux'] = self._pct(afs)
        return self.summary_dict['AbsFlux']

    def equivalent_width(self, xmin=None, xmax=None, iters=1000):
        """Store and return the percentiles of the Ly-alpha equivalent width."""
        ewarray = self.interp - 1
        ews = self._perturbed_integrals(ewarray, self.interr, xmin, xmax,
                                        iters)
        print(np.std(ews))
        self.summary_dict['EW_lya'] = self._pct(ews)
        return self.summary_dict['EW_lya']

    def fit_red(self, xmin, xmax):
        """Measure the red peak in [xmin, xmax] and return its summary."""
        self._set_selector_color('tab:red')
        self._set_span('redspan', xmin, xmax, color=self._colors['Red'],
                       alpha=.5, zorder=0)
        flux, vpeak, vmin, fwhm, bhms, rhms, A_qs, fmin, fmax, Agks = \
            self.measure_flux(xmin, xmax, iters=1000)
        self.summary_dict['Red'] = {
            'range': (xmin, xmax),
            'flux': self._pct(flux),
            'fmax': self._pct(fmax),
            'vpeak': self._pct(vpeak),
            'fwhm': self._pct(fwhm),
            'blue_at_half_width': self._pct(bhms),
            'red_at_half_width': self._pct(rhms),
            'Asymmetry': self._pct(A_qs),
            'Asym_Gr_Ko': self._pct(Agks),
        }
        print('Fitting red peak')
        print(xmin, xmax)
        return self.summary_dict['Red']

    def fit_blue(self, xmin, xmax):
        """Measure the blue peak in [xmin, xmax] and return its summary."""
        self._set_span('bluespan', xmin, xmax, color=self._colors['Blue'],
                       alpha=.5, zorder=0)
        flux, vpeak, vmin, fwhm, bhms, rhms, A_qs, fmin, fmax, Agks = \
            self.measure_flux(xmin, xmax, iters=1000)
        self.summary_dict['Blue'] = {
            'range': (xmin, xmax),
            'flux': self._pct(flux),
            'fmax': self._pct(fmax),
            'vpeak': self._pct(vpeak),
            'fwhm': self._pct(fwhm),
            'blue_at_half_width': self._pct(bhms),
            'red_at_half_width': self._pct(rhms),
        }
        print("The following has been added/overwritten" +
              " in the summary_dict['Blue'].")
        return self.summary_dict['Blue']

    def fit_valley(self, xmin, xmax):
        """Measure the valley in [xmin, xmax] and return its summary."""
        flux, vpeak, vmin, fwhm, bhms, rhms, A_qs, fmin, fmax, Agks = \
            self.measure_flux(xmin, xmax, iters=1000)
        # alpha must lie in [0, 1]; the former value 5. is rejected by
        # matplotlib.
        self._set_span('Valley', xmin, xmax, color=self._colors['Valley'],
                       alpha=.5)
        self.summary_dict['Valley'] = {
            'range': (xmin, xmax),
            'minflux': self._pct(fmin),
            'vmin': self._pct(vmin),
        }
        print('Fitting valley')
        print("The following has been added/overwritten" +
              " in the summary_dict['Valley'].")
        return self.summary_dict['Valley']

    def fit_active(self, xmin, xmax):
        """Dispatch to the fit routine of the selected radio button."""
        if self.radio.value_selected == 'Blue':
            self.fit_blue(xmin, xmax)
        elif self.radio.value_selected == 'Red':
            self.fit_red(xmin, xmax)
        elif self.radio.value_selected == 'Valley':
            self.fit_valley(xmin, xmax)

    fitfuncs = {'Red': fit_red, 'Blue': fit_blue, 'Valley': fit_valley}

    def save_summary(self):
        # TODO: Implement something.
        pass

    def save_summary_table(self, path="summarytable.ecsv"):
        """Write the measurements of the active components as an ECSV table.

        Absolute flux and equivalent width are computed first when they are
        missing from ``summary_dict``. Each quantity is written as (lower
        error, median, upper error), taken from the 16th, 50th and 84th
        percentiles.

        Returns the written ``Table``.
        """
        d = self.summary_dict
        # Make sure absolute flux is measured (TODO make better)
        if "AbsFlux" not in d:
            d["AbsFlux"] = self.absolute_flux(iters=1000)
        if "EW_lya" not in d:
            self.equivalent_width(iters=1000)

        triplet = self._err_triplet
        active = self._peaks_active
        # Create interim output dictionary
        outdict = {}
        if "Red" in active:
            red = d['Red']
            outdict.update({
                'fwhm_red': triplet(red['fwhm']),
                'f_red': triplet(red['fmax']),
                'l_red': triplet(red['flux']),
                'A_red': triplet(red['Asymmetry']),
                'A_GK': triplet(red['Asym_Gr_Ko']),
                'v_red': triplet(red['vpeak']),
            })
        if "Blue" in active:
            blue = d['Blue']
            outdict.update({
                'fwhm_blue': triplet(blue['fwhm']),
                'f_blue': triplet(blue['fmax']),
                'l_blue': triplet(blue['flux']),
                'v_blue': triplet(blue['vpeak']),
            })
        if "Valley" in active:
            valley = d['Valley']
            outdict.update({
                'v_valley': triplet(valley['vmin']),
                'f_valley': triplet(valley['minflux']),
            })
        outdict.update({
            'AbsFlux': triplet(d['AbsFlux']),
            'EW_Lya': triplet(d["EW_lya"]),
        })

        # Now make it a dataframe
        outframe = pd.DataFrame.from_dict(outdict)
        outframe.set_index(pd.Index(['Low', 'Median', 'High'],
                                    name=self.galaxy_name),
                           inplace=True)
        outframe = outframe.T
        # Now make it a Table
        outtable = Table.from_pandas(outframe.reset_index())
        outtable.meta['identifier'] = self.galaxy_name
        outtable.write(path, format='ascii.ecsv')
        return outtable

    @property
    def _wave_diffs(self):
        """Pixel widths in wavelength; the last width is repeated."""
        diffs = self.wave[1:] - self.wave[:-1]
        diffs = np.append(diffs, diffs[-1])
        return diffs

    @property
    def _peaks_active(self):
        """Labels of the components currently ticked as present."""
        return self._peaks[self._peak_on]

    @property
    def refwave(self):
        """Ly-alpha wavelength (1215.67) scaled by (1 + z)."""
        return 1215.67 * (1 + self.z)

    @property
    def ref_wl(self):
        """Same value as `refwave`."""
        return 1215.67 * (1 + self.z)

    def __call__(self):
        """Show the figure."""
        self.fig.show()
class EGG_Control_Panel:
    """Control-panel figure with playback buttons and plot options for an Egg.

    The panel gets its own figure with a grid of push buttons (Reset, Play,
    Loop, Stop) in the upper box and, in the lower box, check buttons that
    toggle the plotted layers plus radio buttons for the Savgol mode.
    """

    def __init__(self, Egg):
        """Build the figure and connect every widget to ``Egg``.

        ``Egg`` must provide ``EGG_figure_Name``, ``Reset_GUI``, ``Play``,
        ``Loop``, ``Stop``, ``Update_Canvas``, ``Update_Savgol_Mode`` and
        the ``Plot_*`` flags written by the check buttons.
        """
        # =============================================================================
        # Initialize Egg
        # =============================================================================
        self.Egg = Egg

        # =============================================================================
        # Initialize Figure and Main Axis
        # =============================================================================
        self.figure = plt.figure(num=self.Egg.EGG_figure_Name +
                                 " Control Panel",
                                 figsize=[5.5, 8],
                                 clear=True)
        self.main_ax = self.figure.add_axes([0, 0, 1, 1],
                                            label="main_ax",
                                            facecolor='white')
        self.main_ax.tick_params(
            axis='x',  # changes apply to the x-axis
            which='both',  # both major and minor ticks are affected
            bottom=False,  # ticks along the bottom edge are off
            top=False,  # ticks along the top edge are off
            left=False,
            right=False,
            labelbottom=False)  # labels along the bottom edge are off
        self.main_ax.tick_params(
            axis='y',  # changes apply to the y-axis
            which='both',  # both major and minor ticks are affected
            left=False,
            right=False,
            labelleft=False)  # labels along the left edge are off
        self.main_ax.margins(x=0)
        self.main_ax.set_xlim([0, 1])
        self.main_ax.set_ylim([0, 1])

        # =============================================================================
        # Create Buttons
        # =============================================================================
        # All sizes are fractions of the figure.
        Button_width = .2
        Button_height = .06
        Edge_width = 0.05
        Button_spacing = 0.03
        Top_gap = 0
        Bottom_Gap = .5
        Left_gap = 0
        Right_gap = 0
        Button_count = 0

        Button_Box = patches.Rectangle(
            (Edge_width / 2, Edge_width + Bottom_Gap), 1 - Edge_width,
            1 - Bottom_Gap - 1.5 * Edge_width)
        self.main_ax.add_patch(Button_Box)

        # Buttons fill a column from top to bottom, then the next column.
        max_buttons_per_column = int(math.floor(
            (1 + Button_spacing - 2 * Edge_width - Top_gap - Bottom_Gap) /
            (Button_height + Button_spacing)))
        max_buttons_per_row = int(math.floor(
            (1 + Button_spacing - 2 * Edge_width - Left_gap - Right_gap) /
            (Button_width + Button_spacing)))

        def Button_params():
            """Return [left, bottom, width, height] for the next button."""
            # Button_count is zero-based, so a count equal to the capacity
            # already lies outside the grid.
            if Button_count >= max_buttons_per_column * max_buttons_per_row:
                raise Exception("TOO MANY BUTTONS!!!")
            column_number = Button_count // max_buttons_per_column
            row_number = Button_count - max_buttons_per_column * column_number
            left = Edge_width + Left_gap + (
                (Button_width + Button_spacing) * column_number)
            bottom = 1 - Edge_width - Top_gap - Button_height - (
                Button_height + Button_spacing) * row_number
            return [left, bottom, Button_width, Button_height]

        self.Reset_Button_ax = self.figure.add_axes(Button_params(),
                                                    label="Reset_Button_ax",
                                                    facecolor='white')
        Button_count += 1
        self.Reset_Button = Button(
            self.Reset_Button_ax,
            label="Reset",
        )
        # NOTE(review): Reset_GUI is called here and its return value is
        # registered as the callback - presumably it returns a handler, as
        # Update_Canvas() does below; confirm.
        self.Reset_Button.on_clicked(self.Egg.Reset_GUI(Reset_Tone=True))

        self.Play_Button_ax = self.figure.add_axes(Button_params(),
                                                   label="Play_Button_ax",
                                                   facecolor='white')
        Button_count += 1
        self.Play_Button = Button(
            self.Play_Button_ax,
            label="Play",
        )
        self.Play_Button.on_clicked(self.Egg.Play)

        self.Loop_Button_ax = self.figure.add_axes(Button_params(),
                                                   label="Loop_Button_ax",
                                                   facecolor='white')
        Button_count += 1
        self.Loop_Button = Button(
            self.Loop_Button_ax,
            label="Loop",
        )
        self.Loop_Button.on_clicked(self.Egg.Loop)

        self.Stop_Button_ax = self.figure.add_axes(Button_params(),
                                                   label="Stop_Button_ax",
                                                   facecolor='white')
        Button_count += 1
        self.Stop_Button = Button(
            self.Stop_Button_ax,
            label="Stop",
        )
        self.Stop_Button.on_clicked(self.Egg.Stop)

        # =============================================================================
        # Create Misc. Box
        # =============================================================================
        Misc_Box = patches.Rectangle((Edge_width / 2, Edge_width / 2),
                                     1 - Edge_width,
                                     Bottom_Gap,
                                     color='g')
        self.main_ax.add_patch(Misc_Box)

        self.Plot_Toggle_ax = self.figure.add_axes(
            [Edge_width, Edge_width, .25, .18],
            label="Plot_Toggle_ax",
            facecolor='white')
        self.Plot_Toggle = CheckButtons(ax=self.Plot_Toggle_ax,
                                        labels=[
                                            "Plot Wave", "Plot Savgol",
                                            "Plot Linear", "Plot Selectors"
                                        ],
                                        actives=[True, True, True, True])

        def Update_Plot_Toggle(val):
            """Copy the check-button states to the Egg and redraw it."""
            Toggle_Bools = self.Plot_Toggle.get_status()
            self.Egg.Plot_Wave = Toggle_Bools[0]
            self.Egg.Plot_Savgol = Toggle_Bools[1]
            self.Egg.Plot_Linear = Toggle_Bools[2]
            self.Egg.Plot_Selectors = Toggle_Bools[3]
            # Update_Canvas() returns the redraw callable, invoked with 0.
            self.Egg.Update_Canvas()(0)

        self.Plot_Toggle.on_clicked(Update_Plot_Toggle)

        self.Savgol_Mode_Select_ax = self.figure.add_axes(
            [Edge_width, Edge_width + .25, .25, .18],
            label="Savgol_Mode_Select_ax",
            facecolor='white')
        self.Savgol_Mode_Select = RadioButtons(self.Savgol_Mode_Select_ax,
                                               labels=[
                                                   "wrap",
                                                   "mirror",
                                                   "nearest",
                                                   "constant",
                                               ])

        def Update_Savgol_Mode_Select(val):
            """Forward the selected Savgol mode to the Egg."""
            active_mode = self.Savgol_Mode_Select.value_selected
            self.Egg.Update_Savgol_Mode(active_mode)

        self.Savgol_Mode_Select.on_clicked(Update_Savgol_Mode_Select)