class AnalysisWindow(QMainWindow):
    def __init__(self, parent, controller):
        """Build the analysis window and show it.

        Layout: a left column with the list of opened tracking files plus
        add / remove / previous / next buttons, and a right column with the
        plot-type and crop tab bars, a "Tracking Parameters" button and a
        stacked widget holding the tail page (index 0) and the body page
        (index 1).

        Args:
            parent: Accepted for interface compatibility; not used (the
                window is created without a Qt parent).
            controller: Object owning the tracking data. Its methods are
                connected as slots here (``switch_tracking_file``,
                ``select_and_open_tracking_files``, ``remove_tracking_file``,
                ``prev_tracking_file``, ``next_tracking_file``,
                ``change_plot_type``, ``change_crop``,
                ``show_tracking_params``), and the tail page reads its
                initial parameter values from it.
        """
        # create window
        QMainWindow.__init__(self)

        self.setAttribute(Qt.WA_DeleteOnClose)
        self.setWindowTitle("Tracking Analysis")
        self.setGeometry(100, 200, 10, 10)

        # set controller
        self.controller = controller

        # create main widget & layout
        self.main_widget = QWidget(self)
        self.main_widget.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding)
        self.main_layout = QGridLayout(self.main_widget)

        # create left widget & layout
        self.left_widget = QWidget(self)
        self.main_layout.addWidget(self.left_widget, 0, 0)

        self.left_layout = QVBoxLayout(self.left_widget)
        self.left_layout.setAlignment(Qt.AlignTop)

        # create list of tracking items
        self.tracking_list_items = []
        self.tracking_list = QListWidget(self)
        self.tracking_list.currentRowChanged.connect(self.controller.switch_tracking_file)
        self.left_layout.addWidget(self.tracking_list)

        # create tracking list buttons. The layout must be created without a
        # parent: QHBoxLayout(self) would try to install it as the main
        # window's own layout (QMainWindow already has one), leaving it
        # parented to the window, and addLayout() rejects a layout that
        # already has a parent -- the buttons would then float unmanaged.
        self.tracking_list_buttons = QHBoxLayout()
        self.left_layout.addLayout(self.tracking_list_buttons)

        self.add_tracking_button = QPushButton('+')
        self.add_tracking_button.clicked.connect(self.controller.select_and_open_tracking_files)
        self.add_tracking_button.setToolTip("Add tracking file.")
        self.tracking_list_buttons.addWidget(self.add_tracking_button)

        self.remove_tracking_button = QPushButton('-')
        self.remove_tracking_button.clicked.connect(self.controller.remove_tracking_file)
        self.remove_tracking_button.setToolTip("Remove selected tracking file.")
        self.tracking_list_buttons.addWidget(self.remove_tracking_button)

        self.prev_tracking_button = QPushButton('<')
        self.prev_tracking_button.clicked.connect(self.controller.prev_tracking_file)
        self.prev_tracking_button.setToolTip("Switch to previous tracking file.")
        self.tracking_list_buttons.addWidget(self.prev_tracking_button)

        self.next_tracking_button = QPushButton('>')
        self.next_tracking_button.clicked.connect(self.controller.next_tracking_file)
        self.next_tracking_button.setToolTip("Switch to next tracking file.")
        self.tracking_list_buttons.addWidget(self.next_tracking_button)

        # create right widget & layout
        self.right_widget = QWidget(self)
        self.right_widget.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding)
        self.main_layout.addWidget(self.right_widget, 0, 1)

        self.right_layout = QVBoxLayout(self.right_widget)
        self.right_layout.setAlignment(Qt.AlignTop)
        self.right_layout.setSpacing(5)

        # row: "Plot:" label, stretch, plot-type tab bar
        plot_horiz_layout = QHBoxLayout()
        self.right_layout.addLayout(plot_horiz_layout)

        plot_type_label = QLabel()
        plot_type_label.setText("Plot:")
        plot_horiz_layout.addWidget(plot_type_label)
        plot_horiz_layout.addStretch(1)

        # tab bar for choosing the plot type (tabs are filled in later)
        self.plot_tabs_widget = QTabBar()
        self.plot_tabs_widget.setDrawBase(False)
        self.plot_tabs_widget.setExpanding(False)
        self.plot_tabs_widget.currentChanged.connect(self.controller.change_plot_type)
        plot_horiz_layout.addWidget(self.plot_tabs_widget)

        # row: "Crop #:" label, stretch, crop tab bar
        crop_horiz_layout = QHBoxLayout()
        self.right_layout.addLayout(crop_horiz_layout)

        crop_type_label = QLabel()
        crop_type_label.setText("Crop #:")
        crop_horiz_layout.addWidget(crop_type_label)
        crop_horiz_layout.addStretch(1)

        # tab bar for choosing the crop number (tabs are filled in later)
        self.crop_tabs_widget = QTabBar()
        self.crop_tabs_widget.setDrawBase(False)
        self.crop_tabs_widget.setExpanding(False)
        self.crop_tabs_widget.currentChanged.connect(self.controller.change_crop)
        crop_horiz_layout.addWidget(self.crop_tabs_widget)

        # row: right-aligned action buttons
        button_layout = QHBoxLayout()
        button_layout.addStretch(1)
        self.right_layout.addLayout(button_layout)

        self.show_tracking_params_button = QPushButton('Tracking Parameters', self)
        self.show_tracking_params_button.setMinimumHeight(30)
        self.show_tracking_params_button.clicked.connect(self.controller.show_tracking_params)
        button_layout.addWidget(self.show_tracking_params_button)

        # stacked widget: page 0 = tail tracking, page 1 = body tracking
        self.stacked_widget = QStackedWidget(self)
        self.stacked_widget.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding)
        self.right_layout.addWidget(self.stacked_widget)

        self.create_tail_tracking_widget(self.stacked_widget)
        self.create_body_tracking_widget(self.stacked_widget)

        self.setCentralWidget(self.main_widget)

        # set window titlebar buttons
        self.setWindowFlags(Qt.CustomizeWindowHint | Qt.WindowCloseButtonHint | Qt.WindowMinimizeButtonHint | Qt.WindowMaximizeButtonHint)

        self.show()

    def update_plot(self, array, plot_type, extra_tracking=None, keep_xlim=True):
        """Switch the stacked widget to the page for ``plot_type`` and redraw.

        ``"tail"`` shows page 0 and plots ``array`` as tail angles (forwarding
        ``extra_tracking``); ``"body"`` shows page 1 and plots ``array`` as
        heading angles. Any other ``plot_type`` is silently ignored.

        NOTE(review): ``self.plot_window`` is not assigned in the visible part
        of this class -- confirm it is set before this method is called.
        """
        print("Updating plot")

        if plot_type == "tail":
            page_index = 0
        elif plot_type == "body":
            page_index = 1
        else:
            return

        self.stacked_widget.setCurrentIndex(page_index)

        if page_index == 0:
            self.plot_window.plot_tail_angle_array(array, extra_tracking=extra_tracking, keep_xlim=keep_xlim)
        else:
            self.plot_window.plot_heading_angle_array(array, keep_xlim=keep_xlim)

    def switch_tracking_item(self, row_number):
        """Select tracking file ``row_number`` and rebuild the plot/crop tab bars.

        The row is selected in the tracking list, then the plot-type tab bar
        is repopulated from the file's tracking params (free-swimming:
        optional "Tail", "Body", optional "Eyes"; otherwise just "Tail") and
        the crop tab bar gets one tab per entry in ``crop_params``, labelled
        from 1. Signals on both tab bars are blocked while they are rebuilt
        so the connected controller slots are not fired for the intermediate
        tab changes.

        Args:
            row_number: Index into ``self.controller.tracking_params`` (and
                into the tracking list).

        Raises:
            KeyError: if ``tracking_params`` lacks a key read here; the tab
                bars' signal-blocking state is restored before it propagates.
        """
        print("Switching tracking item")

        tracking_params = self.controller.tracking_params[row_number]

        self.change_selected_tracking_row(row_number)

        # rebuild plot-type tabs; try/finally guarantees the tab bar is not
        # left with its signals blocked if a params key is missing
        was_blocked = self.plot_tabs_widget.blockSignals(True)
        try:
            # remove existing tabs from the end so indices stay valid
            for i in range(self.plot_tabs_widget.count() - 1, -1, -1):
                self.plot_tabs_widget.removeTab(i)

            # add plot tabs
            if tracking_params['type'] == "freeswimming":
                if tracking_params['track_tail']:
                    self.plot_tabs_widget.addTab("Tail")

                self.plot_tabs_widget.addTab("Body")

                if tracking_params['track_eyes']:
                    self.plot_tabs_widget.addTab("Eyes")
            else:
                self.plot_tabs_widget.addTab("Tail")
        finally:
            self.plot_tabs_widget.blockSignals(was_blocked)

        # rebuild crop tabs, with the same signal-blocking guarantee
        was_blocked = self.crop_tabs_widget.blockSignals(True)
        try:
            for i in range(self.crop_tabs_widget.count() - 1, -1, -1):
                self.crop_tabs_widget.removeTab(i)

            # add crop tabs, labelled 1..n_crops
            n_crops = len(tracking_params['crop_params'])
            for i in range(n_crops):
                self.crop_tabs_widget.addTab("{}".format(i + 1))
        finally:
            self.crop_tabs_widget.blockSignals(was_blocked)

    def add_tracking_item(self, item_name):
        """Append an entry labelled ``item_name`` to the tracking list.

        Constructing the QListWidgetItem with the list as its parent is what
        inserts it into the widget; the item is also kept in
        ``self.tracking_list_items``.
        """
        print("Adding tracking item")
        new_item = QListWidgetItem(item_name, self.tracking_list)
        self.tracking_list_items.append(new_item)

    def change_selected_tracking_row(self, row_number):
        """Select ``row_number`` in the tracking list without emitting signals.

        Used for programmatic selection changes so that ``currentRowChanged``
        (connected to the controller's ``switch_tracking_file``) is not
        re-triggered. The list's previous signal-blocking state is restored
        afterwards -- also if ``setCurrentRow`` raises -- so a caller that had
        already blocked the list's signals keeps them blocked.

        Args:
            row_number: Row to make current.
        """
        was_blocked = self.tracking_list.blockSignals(True)
        try:
            self.tracking_list.setCurrentRow(row_number)
        finally:
            self.tracking_list.blockSignals(was_blocked)

    def create_tail_tracking_widget(self, parent_widget):
        """Build the tail-tracking page and add it to ``parent_widget``.

        The page is a single vertical column (preceded by a stretch) with:
        "Track Bouts" and "Track Freq" buttons (calling the controller's
        ``track_bouts`` / ``track_freqs``), a "Show smoothed derivative"
        checkbox (calling ``self.show_smoothed_deriv``), and three labelled
        line edits initialised from the controller's
        ``smoothing_window_width``, ``threshold`` and ``min_width``.

        The checkbox and line edits are stored on ``self`` as
        ``smoothed_deriv_checkbox``, ``smoothing_window_param_box``,
        ``threshold_param_box`` and ``min_width_param_box``.

        Args:
            parent_widget: Container exposing ``addWidget`` (the stacked
                widget in ``__init__``).
        """
        page = QWidget()
        page_layout = QVBoxLayout(page)

        # one vertical column of controls; the leading stretch pushes the
        # controls towards the bottom of the page
        column = QVBoxLayout()
        column.addStretch(1)
        page_layout.addLayout(column)

        def add_action_button(caption, on_click):
            # fixed-size push button appended to the column
            button = QPushButton(caption, self)
            button.setMinimumHeight(30)
            button.setMaximumWidth(100)
            button.clicked.connect(on_click)
            column.addWidget(button)

        def add_param_box(caption, initial_value):
            # caption label followed by a narrow line edit holding the value
            caption_label = QLabel()
            caption_label.setText(caption)
            column.addWidget(caption_label)

            box = QLineEdit(self)
            box.setMinimumHeight(20)
            box.setMaximumWidth(40)
            box.setText(str(initial_value))
            column.addWidget(box)
            return box

        # the lambdas look the controller method up at click time
        add_action_button('Track Bouts', lambda: self.controller.track_bouts())
        add_action_button('Track Freq', lambda: self.controller.track_freqs())

        # checkbox for switching between plots
        self.smoothed_deriv_checkbox = QCheckBox("Show smoothed derivative")
        self.smoothed_deriv_checkbox.toggled.connect(lambda: self.show_smoothed_deriv(self.smoothed_deriv_checkbox))
        column.addWidget(self.smoothed_deriv_checkbox)

        self.smoothing_window_param_box = add_param_box("Smoothing window:", self.controller.smoothing_window_width)
        self.threshold_param_box = add_param_box("Threshold:", self.controller.threshold)
        self.min_width_param_box = add_param_box("Min width:", self.controller.min_width)

        parent_widget.addWidget(page)

    def create_body_tracking_widget(self, parent_widget):
        """Build the body-tracking page and add it to ``parent_widget``.

        The page holds one horizontal row with a "Track Pos" button (calls
        ``self.track_position``) and a "Show speed" checkbox (calls
        ``self.show_speed``). The checkbox is stored as
        ``self.speed_checkbox``.

        Args:
            parent_widget: Container exposing ``addWidget`` (the stacked
                widget in ``__init__``, where this is the second page).
        """
        # create head tab widget & layout
        head_tab_widget = QWidget()
        head_tab_layout = QVBoxLayout(head_tab_widget)

        # create button row for head tab
        bottom_head_button_layout = QHBoxLayout()
        head_tab_layout.addLayout(bottom_head_button_layout)

        # add buttons
        track_position_button = QPushButton('Track Pos', self)
        track_position_button.setMinimumHeight(30)
        track_position_button.setMaximumWidth(100)
        track_position_button.clicked.connect(lambda:self.track_position())
        bottom_head_button_layout.addWidget(track_position_button)

        # add checkbox for switching plots. It must live on self: the toggled
        # handler looks it up as self.speed_checkbox when the signal fires.
        self.speed_checkbox = QCheckBox("Show speed")
        self.speed_checkbox.toggled.connect(lambda:self.show_speed(self.speed_checkbox))
        bottom_head_button_layout.addWidget(self.speed_checkbox)

        parent_widget.addWidget(head_tab_widget)

    def create_crops(self, parent_layout):
        """Add a crop QTabWidget to ``parent_layout`` and populate it.

        The new tab widget, and a QVBoxLayout installed on it, are appended
        to ``self.crop_tabs_widgets`` / ``self.crop_tabs_layouts``. Then
        ``self.create_crop()`` is called once per entry in the current
        tracking file's ``crop_params``.

        NOTE(review): ``self.change_crop``, ``self.create_crop``,
        ``self.crop_tabs_widgets`` and ``self.crop_tabs_layouts`` are not
        defined in the visible part of this class (``__init__`` wires crops
        through a QTabBar and ``controller.change_crop`` instead); this looks
        like legacy code -- confirm before relying on it.
        """
        tabs = QTabWidget()
        tabs.currentChanged.connect(self.change_crop)
        tabs.setElideMode(Qt.ElideLeft)
        tabs_layout = QVBoxLayout(tabs)
        parent_layout.addWidget(tabs)

        self.crop_tabs_widgets.append(tabs)
        self.crop_tabs_layouts.append(tabs_layout)

        controller = self.controller
        current_params = controller.tracking_params[controller.curr_tracking_num]

        for _ in range(len(current_params['crop_params'])):
            self.create_crop()

    def clear_crops(self):
        """Reset per-crop bookkeeping and remove the existing crop tabs.

        Re-initialises the crop/plot tab widget and layout containers (each
        plot container starts as one dict with empty 'tail', 'eyes' and
        'body' lists), empties the stored head/tail angle array lists,
        removes tabs ``n_crops - 1`` down to ``0`` from
        ``self.crop_tabs_widget``, then sets ``n_crops = 0`` and
        ``current_crop = -1``.

        NOTE(review): this resets ``crop_tab_widgets`` / ``crop_tab_layouts``
        while ``create_crops`` appends to ``crop_tabs_widgets`` /
        ``crop_tabs_layouts`` -- confirm which names are intended. Also,
        ``self.n_crops`` is not initialised in ``__init__``.
        """
        self.crop_tab_layouts = [[]]
        self.crop_tab_widgets = [[]]

        def empty_plot_slots():
            # a fresh dict (with fresh lists) on every call
            return {plot_kind: [] for plot_kind in ('tail', 'eyes', 'body')}

        self.plot_tab_layouts = [empty_plot_slots()]
        self.plot_tab_widgets = [empty_plot_slots()]
        self.plot_tabs_widgets = [[]]
        self.plot_tabs_layouts = [[]]

        self.head_angle_arrays = []
        self.tail_angle_arrays = []

        # remove from the last tab down so lower indices stay valid
        for tab_index in reversed(range(self.n_crops)):
            self.crop_tabs_widget.removeTab(tab_index)

        self.n_crops = 0
        self.current_crop = -1

    def show_smoothed_deriv(self, checkbox):
        """Redraw the tail plot as smoothed derivative or tail-end angles.

        When ``checkbox`` is checked, plots the smoothed absolute derivative
        with the bouts; otherwise plots the current crop's tail-end angles
        with bouts, peak coordinates and frequencies. Does nothing until a
        smoothed derivative has been computed (e.g. by ``track_position``).

        Args:
            checkbox: The checkbox whose ``isChecked()`` selects the plot.
        """
        # getattr: the checkbox can be toggled before anything was computed.
        # `is None` rather than `!= None`: the value is a numpy array, and `!=`
        # would compare element-wise, making the `if` ambiguous (ValueError).
        smoothed = getattr(self, 'smoothed_abs_deriv_abs_angle_array', None)
        if smoothed is None:
            return

        if checkbox.isChecked():
            self.tail_canvas.plot_tail_angle_array(smoothed, self.bouts, keep_limits=True)
        else:
            self.tail_canvas.plot_tail_angle_array(self.tail_end_angle_array[self.current_crop], self.bouts, self.peak_maxes_y, self.peak_maxes_x, self.peak_mins_y, self.peak_mins_x, self.freqs, keep_limits=True)

    def show_speed(self, checkbox):
        """Redraw the head plot as speed (checked) or head angle (unchecked).

        Does nothing until ``self.speed_array`` has been computed by
        ``track_position``.

        Args:
            checkbox: The checkbox whose ``isChecked()`` selects the plot.
        """
        # getattr: the checkbox can be toggled before tracking has run.
        # `is None` rather than `!= None`: if the value is an array, `!=`
        # compares element-wise and the result has no single truth value.
        speed_array = getattr(self, 'speed_array', None)
        if speed_array is None:
            return

        if checkbox.isChecked():
            self.head_canvas.plot_head_array(speed_array, keep_limits=True)
        else:
            self.head_canvas.plot_head_array(self.head_angle_array, keep_limits=True)

    def load_data(self, data_path=None):
        """Load saved tracking data and plot it for every saved crop.

        Sets ``self.path`` and the arrays/params returned by
        ``an.open_saved_data``. When params are present, computes tail angles
        (free-swimming with tail tracking, or head-fixed), recreates the
        crops and plots each crop's heading and tail-end angles.
        ``self.tail_angle_array`` and ``self.tail_end_angle_array`` are set to
        None when the loaded file has no tail data.

        Args:
            data_path: Folder containing the saved tracking data. When None,
                the user picks a folder in a directory dialog; if the dialog
                is cancelled nothing is loaded and the current state is kept.
        """
        if data_path is None:
            # ask the user to select a directory
            data_path = str(QFileDialog.getExistingDirectory(self, 'Open folder'))
            if not data_path:
                # dialog cancelled (empty path) -- keep what is loaded now
                return
        self.path = data_path

        # load saved tracking data
        (self.tail_coords_array, self.spline_coords_array,
         self.heading_angle_array, self.body_position_array,
         self.eye_coords_array, self.params) = an.open_saved_data(self.path)

        if self.params is None:
            return

        # reset first, so a file without tail data neither crashes below nor
        # reuses the tail angles of a previously loaded file
        self.tail_angle_array = None
        self.tail_end_angle_array = None

        # calculate tail angles
        if self.params['type'] == "freeswimming" and self.params['track_tail']:
            self.tail_angle_array = an.get_freeswimming_tail_angles(self.tail_coords_array, self.heading_angle_array, self.body_position_array)
        elif self.params['type'] == "headfixed":
            self.tail_angle_array = an.get_headfixed_tail_angles(self.tail_coords_array, self.params['tail_direction'])

        if self.tail_angle_array is not None:
            # average angle of the last few points of the tail
            self.tail_end_angle_array = an.get_tail_end_angles(self.tail_angle_array, num_to_average=3)

        # clear crops
        self.clear_crops()

        # recreate each saved crop and plot its data
        n_crops_total = len(self.params['crop_params'])

        for k in range(n_crops_total):
            # create a crop
            self.create_crop()

            # plot heading angle
            if self.heading_angle_array is not None:
                self.plot_canvases[k].plot_heading_angle_array(self.heading_angle_array[k])

            # plot tail angle
            if self.tail_angle_array is not None:
                self.tail_canvases[k].plot_tail_angle_array(self.tail_end_angle_array[k])

    # def track_bouts(self):
    #     if self.tail_angle_array != None:
    #         # get params
    #         self.smoothing_window_width = int(self.smoothing_window_param_box.text())
    #         self.threshold = float(self.threshold_param_box.text())
    #         self.min_width = int(self.min_width_param_box.text())

    #         # get smoothed derivative
    #         abs_angle_array = np.abs(self.tail_end_angle_array[self.current_crop])
    #         deriv_abs_angle_array = np.gradient(abs_angle_array)
    #         abs_deriv_abs_angle_array = np.abs(deriv_abs_angle_array)
    #         normpdf = scipy.stats.norm.pdf(range(-int(self.smoothing_window_width/2),int(self.smoothing_window_width/2)),0,3)
    #         self.smoothed_abs_deriv_abs_angle_array =  np.convolve(abs_deriv_abs_angle_array,  normpdf/np.sum(normpdf),mode='valid')

    #         # calculate bout periods
    #         self.bouts = an.contiguous_regions(self.smoothed_abs_deriv_abs_angle_array > self.threshold)

    #         # remove bouts that don't have the minimum bout length
    #         for i in range(self.bouts.shape[0]-1, -1, -1):
    #             if self.bouts[i, 1] - self.bouts[i, 0] < self.min_width:
    #                 self.bouts = np.delete(self.bouts, (i), 0)

    #         # update plot
    #         self.smoothed_deriv_checkbox.setChecked(False)
    #         self.tail_canvas.plot_tail_angle_array(self.tail_end_angle_array[self.current_crop], self.bouts, keep_limits=True)

    # def track_freqs(self):
    #     if self.bouts != None:
    #         # initiate bout maxima & minima coord lists
    #         self.peak_maxes_y = []
    #         self.peak_maxes_x = []
    #         self.peak_mins_y = []
    #         self.peak_mins_x = []

    #         # initiate instantaneous frequency array
    #         self.freqs = np.zeros(self.tail_angle_array.shape[0])

    #         for i in range(self.bouts.shape[0]):
    #             # get local maxima & minima
    #             peak_max, peak_min = peakdetect.peakdet(self.tail_end_angle_array[self.current_crop][self.bouts[i, 0]:self.bouts[i, 1]], 0.02)

    #             # change local coordinates (relative to the start of the bout) to global coordinates
    #             peak_max[:, 0] += self.bouts[i, 0]
    #             peak_min[:, 0] += self.bouts[i, 0]

    #             # add to the bout maxima & minima coord lists
    #             self.peak_maxes_y += list(peak_max[:, 1])
    #             self.peak_maxes_x += list(peak_max[:, 0])
    #             self.peak_mins_y += list(peak_min[:, 1])
    #             self.peak_mins_x += list(peak_min[:, 0])

    #         # calculate instantaneous frequencies
    #         for i in range(len(self.peak_maxes_x)-1):
    #             self.freqs[self.peak_maxes_x[i]:self.peak_maxes_x[i+1]] = 1.0/(self.peak_maxes_x[i+1] - self.peak_maxes_x[i])

    #         # update plot
    #         self.smoothed_deriv_checkbox.setChecked(False)
    #         self.tail_canvas.plot_tail_angle_array(self.tail_end_angle_array[self.current_crop], self.bouts, self.peak_maxes_y, self.peak_maxes_x, self.peak_mins_y, self.peak_mins_x, self.freqs, keep_limits=True)

    def track_position(self):
        """Compute the smoothed tail-angle derivative and the speed trace.

        Reads the smoothing window width from its text box (also stored in
        ``self.smoothing_window_width``). Stores the moving average (flat
        window, 'valid' mode) of the absolute gradient of ``|tail angle|`` in
        ``self.smoothed_abs_deriv_abs_angle_array``, and the speed returned
        by ``an.get_position_history`` in ``self.speed_array``. Does nothing
        if ``self.head_angle_array`` is unset or None.

        NOTE(review): ``load_data`` stores ``heading_angle_array``, not
        ``head_angle_array``; confirm which attribute this guard should test.

        Raises:
            ValueError: if the smoothing-window text is not an integer.
        """
        # `is None` rather than `!= None`: an array compared with `!=` gives
        # an element-wise result whose truth value is ambiguous
        if getattr(self, 'head_angle_array', None) is None:
            return

        # get params
        self.smoothing_window_width = int(self.smoothing_window_param_box.text())
        window = self.smoothing_window_width

        abs_angle_array = np.abs(self.tail_angle_array)
        deriv_abs_angle_array = np.gradient(abs_angle_array)
        abs_deriv_abs_angle_array = np.abs(deriv_abs_angle_array)
        # moving average over `window` samples
        self.smoothed_abs_deriv_abs_angle_array = np.convolve(abs_deriv_abs_angle_array, np.ones((window,)) / window, mode='valid')

        # only the speed trace is kept; the position histories are unused here
        _, _, self.speed_array = an.get_position_history(self.path, plot=False)

    def fileQuit(self):
        """Close this window (menu/shortcut "quit" action)."""
        self.close()

    def closeEvent(self, event):
        """Qt close handler: delegate to ``controller.close_all()``.

        ``event`` is left untouched, so Qt's default handling of the close
        event applies.
        """
        self.controller.close_all()

    def resizeEvent(self, re):
        """Qt resize handler; defers entirely to the QMainWindow implementation."""
        super(AnalysisWindow, self).resizeEvent(re)