class AnalysisWindow(QMainWindow):
    """Main window for browsing and analysing saved tracking results.

    Left pane: the list of loaded tracking files plus buttons to add, remove
    and step through them. Right pane: a "Plot:" tab bar, a "Crop #:" tab bar,
    a "Tracking Parameters" button and a stacked widget holding one page of
    controls per plot type (page 0: tail, page 1: body). Most user actions
    are forwarded to ``self.controller``.
    """

    def __init__(self, parent, controller):
        """Build all widgets and show the window.

        Args:
            parent: Accepted for call compatibility; currently unused (the
                window is created without a Qt parent).
            controller: Object that handles user actions
                (``switch_tracking_file``, ``change_plot_type``,
                ``change_crop``, ``track_bouts``, ...) and provides the
                initial ``smoothing_window_width``, ``threshold`` and
                ``min_width`` values shown in the tail page.
        """
        # create window
        QMainWindow.__init__(self)
        self.setAttribute(Qt.WA_DeleteOnClose)
        self.setWindowTitle("Tracking Analysis")
        self.setGeometry(100, 200, 10, 10)

        # set controller
        self.controller = controller

        # Data produced by the analysis methods below. They start as None so
        # the "is data available?" guards work before anything is computed.
        self.smoothed_abs_deriv_abs_angle_array = None
        self.speed_array = None
        self.head_angle_array = None
        self.tail_angle_array = None

        # create main widget & layout
        self.main_widget = QWidget(self)
        self.main_widget.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding)
        self.main_layout = QGridLayout(self.main_widget)

        # create left widget & layout
        self.left_widget = QWidget(self)
        self.main_layout.addWidget(self.left_widget, 0, 0)
        self.left_layout = QVBoxLayout(self.left_widget)
        self.left_layout.setAlignment(Qt.AlignTop)

        # create list of tracking items
        self.tracking_list_items = []
        self.tracking_list = QListWidget(self)
        self.tracking_list.currentRowChanged.connect(self.controller.switch_tracking_file)
        self.left_layout.addWidget(self.tracking_list)

        # Create the tracking list button row. The layout gets no parent widget:
        # QLayout(parent) tries to become parent's top-level layout, and a
        # QMainWindow already has one. addLayout() below reparents it instead.
        self.tracking_list_buttons = QHBoxLayout()
        self.left_layout.addLayout(self.tracking_list_buttons)
        self.add_tracking_button = self._add_tracking_list_button(
            '+', self.controller.select_and_open_tracking_files, "Add tracking file.")
        self.remove_tracking_button = self._add_tracking_list_button(
            '-', self.controller.remove_tracking_file, "Remove selected tracking file.")
        self.prev_tracking_button = self._add_tracking_list_button(
            '<', self.controller.prev_tracking_file, "Switch to previous tracking file.")
        self.next_tracking_button = self._add_tracking_list_button(
            '>', self.controller.next_tracking_file, "Switch to next tracking file.")

        # create right widget & layout
        self.right_widget = QWidget(self)
        self.right_widget.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding)
        self.main_layout.addWidget(self.right_widget, 0, 1)
        self.right_layout = QVBoxLayout(self.right_widget)
        self.right_layout.setAlignment(Qt.AlignTop)
        self.right_layout.setSpacing(5)

        # tab bars selecting the plot type and the crop number
        self.plot_tabs_widget = self._add_tab_bar_row("Plot:", self.controller.change_plot_type)
        self.crop_tabs_widget = self._add_tab_bar_row("Crop #:", self.controller.change_crop)

        # right-aligned button row
        button_layout = QHBoxLayout()
        button_layout.addStretch(1)
        self.right_layout.addLayout(button_layout)
        self.show_tracking_params_button = QPushButton('Tracking Parameters', self)
        self.show_tracking_params_button.setMinimumHeight(30)
        self.show_tracking_params_button.clicked.connect(self.controller.show_tracking_params)
        button_layout.addWidget(self.show_tracking_params_button)

        # stacked widget: page 0 = tail controls, page 1 = body controls
        # (update_plot relies on this order)
        self.stacked_widget = QStackedWidget(self)
        self.stacked_widget.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding)
        self.right_layout.addWidget(self.stacked_widget)
        self.create_tail_tracking_widget(self.stacked_widget)
        self.create_body_tracking_widget(self.stacked_widget)

        self.setCentralWidget(self.main_widget)

        # set window titlebar buttons
        self.setWindowFlags(Qt.CustomizeWindowHint | Qt.WindowCloseButtonHint
                            | Qt.WindowMinimizeButtonHint | Qt.WindowMaximizeButtonHint)

        self.show()

    def _add_tracking_list_button(self, text, slot, tooltip):
        """Create a button in the tracking-list button row and return it."""
        button = QPushButton(text)
        button.clicked.connect(slot)
        button.setToolTip(tooltip)
        self.tracking_list_buttons.addWidget(button)
        return button

    def _add_tab_bar_row(self, label_text, slot):
        """Append a "<label> <stretch> <tabs>" row to the right pane; return the QTabBar."""
        row_layout = QHBoxLayout()
        self.right_layout.addLayout(row_layout)
        label = QLabel()
        label.setText(label_text)
        row_layout.addWidget(label)
        row_layout.addStretch(1)
        tab_bar = QTabBar()
        tab_bar.setDrawBase(False)
        tab_bar.setExpanding(False)
        tab_bar.currentChanged.connect(slot)
        row_layout.addWidget(tab_bar)
        return tab_bar

    def _add_action_button(self, layout, text, slot):
        """Add a fixed-size (min height 30, max width 100) push button to ``layout``."""
        button = QPushButton(text, self)
        button.setMinimumHeight(30)
        button.setMaximumWidth(100)
        button.clicked.connect(slot)
        layout.addWidget(button)
        return button

    def _add_param_box(self, layout, label_text, value):
        """Add a label plus a small QLineEdit pre-filled with ``str(value)``; return the edit."""
        label = QLabel()
        label.setText(label_text)
        layout.addWidget(label)
        box = QLineEdit(self)
        box.setMinimumHeight(20)
        box.setMaximumWidth(40)
        box.setText(str(value))
        layout.addWidget(box)
        return box

    @staticmethod
    def _remove_all_tabs(tab_bar):
        """Remove every tab from ``tab_bar``, last tab first."""
        for index in range(tab_bar.count() - 1, -1, -1):
            tab_bar.removeTab(index)

    def update_plot(self, array, plot_type, extra_tracking=None, keep_xlim=True):
        """Switch to the page for ``plot_type`` and plot ``array`` on it.

        ``plot_type`` "tail" selects page 0 and "body" selects page 1. Any
        other value is ignored.

        NOTE(review): ``self.plot_window`` is never assigned in this class.
        It is presumably set externally (e.g. by the controller); confirm.
        """
        print("Updating plot")
        if plot_type == "tail":
            self.stacked_widget.setCurrentIndex(0)
            self.plot_window.plot_tail_angle_array(array, extra_tracking=extra_tracking, keep_xlim=keep_xlim)
        elif plot_type == "body":
            self.stacked_widget.setCurrentIndex(1)
            self.plot_window.plot_heading_angle_array(array, keep_xlim=keep_xlim)

    def switch_tracking_item(self, row_number):
        """Select list row ``row_number`` and rebuild the plot/crop tabs.

        The tabs are rebuilt from ``controller.tracking_params[row_number]``.
        Signals are blocked while the tabs are rebuilt, so the controller's
        ``change_plot_type`` / ``change_crop`` slots are not triggered.
        """
        print("Switching tracking item")
        tracking_params = self.controller.tracking_params[row_number]
        self.change_selected_tracking_row(row_number)

        # add plot tabs
        self.plot_tabs_widget.blockSignals(True)
        try:
            self._remove_all_tabs(self.plot_tabs_widget)
            if tracking_params['type'] == "freeswimming":
                if tracking_params['track_tail']:
                    self.plot_tabs_widget.addTab("Tail")
                self.plot_tabs_widget.addTab("Body")
                if tracking_params['track_eyes']:
                    self.plot_tabs_widget.addTab("Eyes")
            else:
                self.plot_tabs_widget.addTab("Tail")
        finally:
            self.plot_tabs_widget.blockSignals(False)

        # add crop tabs, labelled 1..n
        self.crop_tabs_widget.blockSignals(True)
        try:
            self._remove_all_tabs(self.crop_tabs_widget)
            for i in range(len(tracking_params['crop_params'])):
                self.crop_tabs_widget.addTab("{}".format(i + 1))
        finally:
            self.crop_tabs_widget.blockSignals(False)

    def add_tracking_item(self, item_name):
        """Append an entry named ``item_name`` to the tracking file list."""
        print("Adding tracking item")
        self.tracking_list_items.append(QListWidgetItem(item_name, self.tracking_list))

    def change_selected_tracking_row(self, row_number):
        """Highlight ``row_number`` without emitting currentRowChanged.

        Blocking signals keeps ``controller.switch_tracking_file`` from being
        called again.
        """
        self.tracking_list.blockSignals(True)
        try:
            self.tracking_list.setCurrentRow(row_number)
        finally:
            self.tracking_list.blockSignals(False)

    def create_tail_tracking_widget(self, parent_widget):
        """Build the tail-analysis page and add it to ``parent_widget`` (a QStackedWidget).

        The page contains:
        - "Track Bouts" and "Track Freq" buttons;
        - a "Show smoothed derivative" checkbox;
        - smoothing-window, threshold and min-width edits, pre-filled from the
          controller.
        """
        # create tail tab widget & layout
        tail_tab_widget = QWidget()
        tail_tab_layout = QVBoxLayout(tail_tab_widget)

        # create button layout for tail tab
        bottom_tail_button_layout = QVBoxLayout()
        bottom_tail_button_layout.addStretch(1)
        tail_tab_layout.addLayout(bottom_tail_button_layout)

        # The lambdas drop the `checked` argument that clicked() passes.
        # Bout and frequency tracking are done by the controller.
        self._add_action_button(bottom_tail_button_layout, 'Track Bouts',
                                lambda: self.controller.track_bouts())
        self._add_action_button(bottom_tail_button_layout, 'Track Freq',
                                lambda: self.controller.track_freqs())

        # add checkbox for switching plots
        self.smoothed_deriv_checkbox = QCheckBox("Show smoothed derivative")
        self.smoothed_deriv_checkbox.toggled.connect(
            lambda: self.show_smoothed_deriv(self.smoothed_deriv_checkbox))
        bottom_tail_button_layout.addWidget(self.smoothed_deriv_checkbox)

        # add param labels & textboxes
        self.smoothing_window_param_box = self._add_param_box(
            bottom_tail_button_layout, "Smoothing window:", self.controller.smoothing_window_width)
        self.threshold_param_box = self._add_param_box(
            bottom_tail_button_layout, "Threshold:", self.controller.threshold)
        self.min_width_param_box = self._add_param_box(
            bottom_tail_button_layout, "Min width:", self.controller.min_width)

        parent_widget.addWidget(tail_tab_widget)

    def create_body_tracking_widget(self, parent_widget):
        """Build the body page ("Track Pos" button and "Show speed" checkbox) and add it to ``parent_widget``."""
        # create head tab widget & layout
        head_tab_widget = QWidget()
        head_tab_layout = QVBoxLayout(head_tab_widget)

        # create button layout for head tab
        bottom_head_button_layout = QHBoxLayout()
        head_tab_layout.addLayout(bottom_head_button_layout)

        self._add_action_button(bottom_head_button_layout, 'Track Pos',
                                lambda: self.track_position())

        # The checkbox is stored on self because the toggle handler reads
        # self.speed_checkbox. As a local only, toggling raised AttributeError.
        self.speed_checkbox = QCheckBox("Show speed")
        self.speed_checkbox.toggled.connect(lambda: self.show_speed(self.speed_checkbox))
        bottom_head_button_layout.addWidget(self.speed_checkbox)

        parent_widget.addWidget(head_tab_widget)

    def create_crops(self, parent_layout):
        """Create a QTabWidget of crops in ``parent_layout``, one crop per saved crop.

        NOTE(review): legacy code. ``self.change_crop``, ``self.create_crop``,
        ``self.crop_tabs_widgets`` and ``self.crop_tabs_layouts`` are not
        defined in this class. Confirm whether this method is still used.
        """
        crop_tabs_widget = QTabWidget()
        crop_tabs_widget.currentChanged.connect(self.change_crop)
        crop_tabs_widget.setElideMode(Qt.ElideLeft)
        crop_tabs_layout = QVBoxLayout(crop_tabs_widget)
        parent_layout.addWidget(crop_tabs_widget)
        self.crop_tabs_widgets.append(crop_tabs_widget)
        self.crop_tabs_layouts.append(crop_tabs_layout)
        n_crops = len(self.controller.tracking_params[self.controller.curr_tracking_num]['crop_params'])
        for _ in range(n_crops):
            self.create_crop()

    def clear_crops(self):
        """Reset per-crop bookkeeping and remove all crop tabs."""
        self.crop_tab_layouts = [[]]
        self.crop_tab_widgets = [[]]
        self.plot_tab_layouts = [{'tail': [], 'eyes': [], 'body': []}]
        self.plot_tab_widgets = [{'tail': [], 'eyes': [], 'body': []}]
        self.plot_tabs_widgets = [[]]
        self.plot_tabs_layouts = [[]]
        self.head_angle_arrays = []
        self.tail_angle_arrays = []

        # Use the tab bar's real tab count. self.n_crops was read here before
        # ever being assigned, which raised AttributeError on the first call.
        for c in range(self.crop_tabs_widget.count() - 1, -1, -1):
            self.crop_tabs_widget.removeTab(c)

        self.n_crops = 0
        self.current_crop = -1

    def show_smoothed_deriv(self, checkbox):
        """Plot the smoothed derivative when ``checkbox`` is checked, else the raw tail-end angles.

        Does nothing until the smoothed derivative has been computed.
        NOTE(review): ``self.tail_canvas``, ``self.bouts`` and the peak/freq
        attributes are not created in this class. Confirm their source.
        """
        # `is not None`: `array != None` is elementwise for numpy arrays and
        # makes `if` raise "truth value ... is ambiguous".
        if self.smoothed_abs_deriv_abs_angle_array is not None:
            if checkbox.isChecked():
                self.tail_canvas.plot_tail_angle_array(self.smoothed_abs_deriv_abs_angle_array, self.bouts, keep_limits=True)
            else:
                self.tail_canvas.plot_tail_angle_array(self.tail_end_angle_array[self.current_crop], self.bouts,
                                                       self.peak_maxes_y, self.peak_maxes_x, self.peak_mins_y,
                                                       self.peak_mins_x, self.freqs, keep_limits=True)

    def show_speed(self, checkbox):
        """Plot speed when ``checkbox`` is checked, else the head angle array.

        Does nothing until a speed array exists.
        NOTE(review): ``self.head_canvas`` is not created in this class.
        """
        if self.speed_array is not None:
            if checkbox.isChecked():
                self.head_canvas.plot_head_array(self.speed_array, keep_limits=True)
            else:
                self.head_canvas.plot_head_array(self.head_angle_array, keep_limits=True)

    def load_data(self, data_path=None):
        """Load saved tracking data from ``data_path`` and plot it per crop.

        If ``data_path`` is None, a folder dialog asks the user for the
        directory. Cancelling the dialog returns without loading anything.

        NOTE(review): ``self.create_crop``, ``self.plot_canvases`` and
        ``self.tail_canvases`` are not defined in this class. This looks like
        legacy code; confirm whether it is still reachable.
        """
        if data_path is None:
            # ask the user to select a directory
            self.path = str(QFileDialog.getExistingDirectory(self, 'Open folder'))
            if not self.path:
                # getExistingDirectory returns an empty string on cancel
                return
        else:
            self.path = data_path

        # load saved tracking data
        (self.tail_coords_array, self.spline_coords_array, self.heading_angle_array,
         self.body_position_array, self.eye_coords_array, self.params) = an.open_saved_data(self.path)

        if self.params is None:
            return

        # Calculate tail angles. Reset first so a file without tail tracking
        # neither crashes below nor reuses the previous file's arrays.
        self.tail_angle_array = None
        self.tail_end_angle_array = None
        if self.params['type'] == "freeswimming" and self.params['track_tail']:
            self.tail_angle_array = an.get_freeswimming_tail_angles(
                self.tail_coords_array, self.heading_angle_array, self.body_position_array)
        elif self.params['type'] == "headfixed":
            self.tail_angle_array = an.get_headfixed_tail_angles(
                self.tail_coords_array, self.params['tail_direction'])

        if self.tail_angle_array is not None:
            # presumably the mean angle of the last 3 tail points (per the original comment)
            self.tail_end_angle_array = an.get_tail_end_angles(self.tail_angle_array, num_to_average=3)

        # clear crops
        self.clear_crops()

        # create one crop per saved crop and plot its data
        n_crops_total = len(self.params['crop_params'])
        for k in range(n_crops_total):
            self.create_crop()

            # plot heading angle
            if self.heading_angle_array is not None:
                self.plot_canvases[k].plot_heading_angle_array(self.heading_angle_array[k])

            # plot tail angle
            if self.tail_angle_array is not None:
                self.tail_canvases[k].plot_tail_angle_array(self.tail_end_angle_array[k])

    def track_position(self):
        """Compute the smoothed |d|angle|| curve and the speed array.

        The smoothing window is read from the smoothing-window box.

        NOTE(review):
        - The guard tests ``head_angle_array``, but nothing in this class
          assigns it (``load_data`` sets ``heading_angle_array``).
        - The derivative is computed from ``tail_angle_array``.
        Confirm which arrays are intended.
        """
        if self.head_angle_array is not None:
            # get params
            self.smoothing_window_width = int(self.smoothing_window_param_box.text())
            window = self.smoothing_window_width

            abs_angle_array = np.abs(self.tail_angle_array)
            abs_deriv_abs_angle_array = np.abs(np.gradient(abs_angle_array))
            # moving average over `window` samples; 'valid' drops the edges
            self.smoothed_abs_deriv_abs_angle_array = np.convolve(
                abs_deriv_abs_angle_array, np.ones((window,)) / window, mode='valid')

            # only the speed array is kept from the position history
            _, _, self.speed_array = an.get_position_history(self.path, plot=False)

    def fileQuit(self):
        """Close the window."""
        self.close()

    def closeEvent(self, event):
        """On close, ask the controller to close everything it manages."""
        self.controller.close_all()

    def resizeEvent(self, re):
        """Defer to the default QMainWindow resize handling."""
        QMainWindow.resizeEvent(self, re)