def enable_widgets(self):
        """Prepare the automatic-annotation widgets for the loaded data.

        The first time through (``self.enabled`` is falsy) the three class
        PlotWidgets and the attribute PlotWidget are wrapped in ``Graph``
        objects and the network combo box is enabled. Every call then runs
        ``setup()`` on the four graphs, reloads the view, refreshes the
        annotate/load buttons, disables and unchecks the deep-representation
        checkbox and forgets any previously chosen deep-representation files.
        """
        if not self.enabled:
            # Wrap the raw plot widgets exactly once; later calls reuse them.
            for attr_name, graph_label in (('class_graph_1', 'Classes #1'),
                                           ('class_graph_2', 'Classes #2'),
                                           ('class_graph_3', 'Classes #3')):
                setattr(self, attr_name,
                        Graph(getattr(self, attr_name),
                              'class',
                              interval_lines=False,
                              label=graph_label))
            self.attribute_graph = Graph(self.attribute_graph,
                                         'attribute',
                                         interval_lines=False)

            self.network_comboBox.setEnabled(True)
            self.enabled = True

        for graph in (self.class_graph_1, self.class_graph_2,
                      self.class_graph_3, self.attribute_graph):
            graph.setup()

        self.reload()
        self.enable_annotate_button()
        self.enable_load_button()

        deep_rep_box = self.deep_rep_checkBox
        deep_rep_box.setEnabled(False)
        deep_rep_box.setChecked(False)
        self.deep_rep_files = []
    def enable_widgets(self):
        """Prepare the label-correction widgets for the loaded data.

        The class PlotWidget is wrapped in a ``Graph`` only on the first call
        (tracked by ``was_enabled_once``). On every call the split/move line
        edits get integer validators accepting 0 .. number_samples + 1, the
        class graph is set up and filled from ``g.windows.windows``, and the
        view is reloaded.
        """
        if not self.was_enabled_once:
            self.class_graph = Graph(self.class_graph, 'class',
                                     interval_lines=False)
            self.was_enabled_once = True

        upper_bound = g.data.number_samples + 1
        for line_edit in (self.split_at_lineEdit,
                          self.move_start_lineEdit,
                          self.move_end_lineEdit):
            # Each line edit gets its own validator instance.
            line_edit.setValidator(QtGui.QIntValidator(0, upper_bound))

        self.class_graph.setup()
        self.class_graph.reload_classes(g.windows.windows)

        self.reload()
Esempio n. 3
0
    def enable_widgets(self):
        """Prepare the retrieval widgets for the loaded data.

        Sets integer validators (0 .. number_samples) on the before/after
        line edits, wraps the class graph, distance graph and distance
        histogram PlotWidgets in ``Graph`` objects on the first call only,
        runs ``setup()`` on all three and reloads the view.
        """
        self.before_lineEdit.setValidator(
            QtGui.QIntValidator(0, g.data.number_samples))
        self.after_lineEdit.setValidator(
            QtGui.QIntValidator(0, g.data.number_samples))

        # enable_widgets may run more than once (e.g. per loaded file). Only
        # wrap the raw PlotWidgets the first time; otherwise an existing
        # Graph would be passed in as the plot widget of a new Graph.
        if not getattr(self, 'was_enabled_once', False):
            self.class_graph = Graph(plot_widget=self.class_graph,
                                     graph_type="class",
                                     label="Classes",
                                     interval_lines=True)
            self.distance_graph = Graph(self.distance_graph,
                                        'data',
                                        label="score",
                                        interval_lines=True,
                                        unit="",
                                        AutoSIPrefix=False,
                                        y_range=(0, 1))
            self.distance_histogram = Graph(self.distance_histogram,
                                            "histogram",
                                            label="histogram",
                                            play_line=True)
            self.was_enabled_once = True

        self.class_graph.setup()
        self.distance_graph.setup()
        self.distance_histogram.setup()

        self.reload()
    def show_graphs(self, dataset, y_axis):
        """Open a maximized dialog with the class graph and one plot per class.

        Plot 0 shows the class windows from ``g.windows.windows``. Plot
        ``i + 1`` shows ``dataset.make_heatmap(i, True)`` titled with
        ``g.classes[i]``, with its left axis labelled ``y_axis``.

        Returns:
            A list containing the class ``Graph`` followed by the dialog.
            NOTE(review): presumably the caller keeps this list so the
            objects are not garbage collected - confirm.
        """
        graphs = []
        # NOTE(review): 9 plot slots = 1 class graph + up to 8 classes (plots
        # are indexed 0 and i + 1 below). Confirm what PlotDialog's second
        # argument means before deriving it from len(g.classes).
        dlg = PlotDialog(None, 9)
        dlg.setWindowTitle("Graph")
        plots = dlg.graph_widgets()

        class_graph = Graph(plots[0], "class", interval_lines=False)
        class_graph.setup()
        class_graph.reload_classes(g.windows.windows)
        graphs.append(class_graph)

        for i, class_name in enumerate(g.classes):
            plot = plots[i + 1]
            plot.setTitle(f'<font size="6"><b>{class_name}</b></font>')
            plot.getAxis('left').setLabel(y_axis)
            plot.plot(dataset.make_heatmap(i, True))

        # Show and register the dialog once, after every plot is filled.
        dlg.showMaximized()
        graphs.append(dlg)
        return graphs
class AutomaticAnnotationController(Controller):
    """Controller for the "Automatic Annotation" tab.

    The user picks a network in ``network_comboBox`` and either runs it on
    the loaded data (:meth:`annotate`) or loads predictions saved earlier
    (:meth:`load_predictions`). The top-3 predicted windows
    (``g.windows.windows_1`` .. ``windows_3``) are drawn in three class
    graphs. The attributes of the selected window are drawn in the attribute
    graph.

    ``current_window`` is an index into ``g.windows.windows_1``; -1 means no
    window is selected.
    """

    def __init__(self, gui):
        super(AutomaticAnnotationController, self).__init__(gui)

        self.window_step = g.settings['segmentationWindowStride']

        self.selected_network = 0  # TODO: Save last selected in settings
        self.current_window = -1
        # Files picked in browse_deep_rep_files(); passed to the Annotator.
        self.deep_rep_files = []
        self.setup_widgets()

    def setup_widgets(self):
        """Load the tab's .ui file, look up its widgets and connect signals."""
        self.load_tab(f'..{sep}ui{sep}automatic_annotation_mode.ui',
                      "Automatic Annotation")
        # ComboBoxes
        self.network_comboBox = self.widget.findChild(QtWidgets.QComboBox,
                                                      "aa_network_comboBox")
        for k in sorted(g.networks.keys()):
            self.network_comboBox.addItem(g.networks[k]['name'])

        self.post_processing_comboBox = self.widget.findChild(
            QtWidgets.QComboBox, "aa_post_processing_comboBox")

        # Buttons
        self.annotate_button = self.widget.findChild(QtWidgets.QPushButton,
                                                     "aa_annotate_button")
        self.annotate_button.clicked.connect(lambda _: self.gui.pause())
        self.annotate_button.clicked.connect(lambda _: self.annotate())

        self.load_predictions_button = self.widget.findChild(
            QtWidgets.QPushButton, "aa_load_prediction_button")
        self.load_predictions_button.clicked.connect(
            lambda _: self.load_predictions())

        # Graphs (raw PlotWidgets; enable_widgets wraps them in Graph)
        self.class_graph_1 = self.widget.findChild(pg.PlotWidget,
                                                   'aa_classGraph')
        self.class_graph_2 = self.widget.findChild(pg.PlotWidget,
                                                   'aa_classGraph_2')
        self.class_graph_3 = self.widget.findChild(pg.PlotWidget,
                                                   'aa_classGraph_3')
        self.attribute_graph = self.widget.findChild(pg.PlotWidget,
                                                     'aa_attributeGraph')

        # Status window
        self.status_window = self.widget.findChild(QtWidgets.QTextEdit,
                                                   'aa_statusWindow')
        self.add_status_message(
            "This mode is for using a Neural Network to annotate Data.")

        # Deep representation widgets
        self.deep_rep_checkBox = self.widget.findChild(QtWidgets.QCheckBox,
                                                       "aa_deep_rep_checkBox")

        self.deep_rep_button = self.widget.findChild(
            QtWidgets.QPushButton, "aa_deep_rep_browse_button")
        self.deep_rep_button.clicked.connect(
            lambda _: self.browse_deep_rep_files())

        # Connect the network selection last. Adding the first item to an
        # empty QComboBox emits currentIndexChanged, and select_network()
        # uses deep_rep_button, deep_rep_checkBox, annotate_button and
        # load_predictions_button, which must already exist by then.
        self.network_comboBox.currentIndexChanged.connect(self.select_network)
        # Apply the initial selection to the dependent widgets once.
        self.select_network(self.network_comboBox.currentIndex())

    def enable_widgets(self):
        """Prepare the widgets for the loaded data.

        On the first call the PlotWidgets are wrapped in ``Graph`` objects and
        the network combo box is enabled. Every call re-runs ``setup()`` on
        all graphs, reloads the view, refreshes the annotate/load buttons and
        resets the deep-representation selection.
        """
        if not self.enabled:
            self.class_graph_1 = Graph(self.class_graph_1,
                                       'class',
                                       interval_lines=False,
                                       label='Classes #1')
            self.class_graph_2 = Graph(self.class_graph_2,
                                       'class',
                                       interval_lines=False,
                                       label='Classes #2')
            self.class_graph_3 = Graph(self.class_graph_3,
                                       'class',
                                       interval_lines=False,
                                       label='Classes #3')
            self.attribute_graph = Graph(self.attribute_graph,
                                         'attribute',
                                         interval_lines=False)

            self.network_comboBox.setEnabled(True)

            self.enabled = True

        self.class_graph_1.setup()
        self.class_graph_2.setup()
        self.class_graph_3.setup()
        self.attribute_graph.setup()

        self.reload()
        self.enable_annotate_button()
        self.enable_load_button()

        self.deep_rep_checkBox.setEnabled(False)
        self.deep_rep_checkBox.setChecked(False)
        self.deep_rep_files = []

    def enable_annotate_button(self):
        """Enable "annotate" only if a network is selected, the tab is
        enabled and fixed-window mode is off."""
        # `and` short-circuits: self.enabled is not read while no network
        # is selected (select_network runs during setup_widgets).
        can_annotate = (self.selected_network > 0
                        and self.enabled
                        and self.fixed_window_mode_enabled in (None, "none"))
        self.annotate_button.setEnabled(bool(can_annotate))

    def enable_load_button(self):
        """Enable "load predictions" only if all three prediction files of
        the selected network exist for the current file.

        The files are looked up in ``g.settings['saveFinishedPath']`` as
        ``<name>_A<annotator id>_N<00..02>.txt``, where ``<name>`` is the
        current file name up to its first dot.
        """
        files_present = False
        if self.selected_network > 0 \
                and self.enabled \
                and self.fixed_window_mode_enabled in (None, "none"):
            directory = g.settings['saveFinishedPath']
            annotator_id = g.networks[self.selected_network]['annotator_id']
            base_name = g.windows.file_name.split('.')[0]
            files_present = all(
                os.path.exists(
                    os.path.join(
                        directory,
                        f"{base_name}_A{annotator_id:0>2}_N{pred_id:0>2}.txt"))
                for pred_id in range(3))
        self.load_predictions_button.setEnabled(files_present)

    def reload(self):
        """Redraw the class graphs at the current frame and, if predictions
        exist, reselect and highlight the window under that frame."""
        frame = self.gui.get_current_frame()
        graphs = [self.class_graph_1, self.class_graph_2, self.class_graph_3]
        for graph in graphs:
            graph.update_frame_lines(play=frame)

        if g.windows is not None \
                and g.windows.windows_1 is not None \
                and len(g.windows.windows_1) > 0:

            windows = [
                g.windows.windows_1, g.windows.windows_2, g.windows.windows_3
            ]
            for graph, window in zip(graphs, windows):
                graph.reload_classes(window)

            self.select_window_by_frame(frame)
            # Refresh the attribute graph even if the index is unchanged:
            # reload() also runs after new predictions replace the windows.
            self.selectWindow(self.current_window)
            self.highlight_class_bar(self.current_window)

    def select_network(self, index):
        """Save the selected network and update the dependent widgets.

        Indices <= 0 are treated as "no network selected". Deep
        representation is only available for networks whose
        ``g.networks[index]['attributes']`` is truthy.
        """
        self.selected_network = index
        attributes = g.networks[index]['attributes'] if index > 0 else False
        self.deep_rep_button.setEnabled(attributes)
        if not attributes:
            self.deep_rep_checkBox.setChecked(False)
            self.deep_rep_checkBox.setEnabled(False)
        elif self.deep_rep_files:
            self.deep_rep_checkBox.setEnabled(True)

        self.enable_annotate_button()
        self.enable_load_button()

    def annotate(self):
        """Start an Annotator for the selected network, with a progress dialog."""
        self.annotate_start_time = time.time()

        self.progress = ProgressDialog(self.gui, "annotating", 6)

        self.annotator = Annotator(self.gui, self.selected_network,
                                   self.deep_rep_checkBox.isChecked(),
                                   self.deep_rep_files)
        self.annotator.progress.connect(self.progress.set_step)
        self.annotator.progress_add.connect(self.progress.advance_step)
        self.annotator.nextstep.connect(self.progress.new_step)
        self.annotator.cancel.connect(lambda _: self.cancel_annotation())
        self.annotator.done.connect(lambda _: self.finish_annotation())

        # Clear the attribute graph while the annotation is running.
        self.attribute_graph.update_attributes(None)

        self.progress.show()
        self.annotator.start()

    def finish_annotation(self):
        """Show the new predictions, report the duration, drop the annotator."""
        self.reload()
        self.time_annotation()

        del self.annotator

    def cancel_annotation(self):
        """Close the progress dialog of a cancelled annotation."""
        self.progress.close()

    def time_annotation(self):
        """Report the time since annotate() started, formatted as H:MM:SS."""
        time_elapsed = int(time.time() - self.annotate_start_time)
        total_minutes, seconds = divmod(time_elapsed, 60)
        hours, minutes = divmod(total_minutes, 60)
        self.add_status_message(
            "The annotation took {}:{:02d}:{:02d}".format(
                hours, minutes, seconds))

    def load_predictions(self):
        """Load saved predictions and retrieval data of the selected network."""
        directory = g.settings['saveFinishedPath']
        annotator_id = g.networks[self.selected_network]['annotator_id']

        g.windows.load_predictions(directory, annotator_id)
        g.retrieval = RetrievalData.load_retrieval(directory, annotator_id)

        self.reload()

    def new_frame(self, frame):
        """Move the frame lines to `frame` and follow the window under it."""
        classgraphs = [
            self.class_graph_1, self.class_graph_2, self.class_graph_3
        ]

        for graph in classgraphs:
            graph.update_frame_lines(play=frame)

        if g.windows is not None \
                and g.windows.windows_1 is not None \
                and len(g.windows.windows_1) > 0:

            window_index = self.class_window_index(frame)
            # Outside every window: use -1 (as select_window_by_frame does)
            # instead of None, which selectWindow cannot compare with 0.
            if window_index is None:
                window_index = -1
            if self.current_window != window_index:
                self.current_window = window_index
                self.selectWindow(window_index)
                self.highlight_class_bar(window_index)

    def class_window_index(self, frame):
        """Return the index of the windows_1 window with start <= frame < end.

        Uses the current frame if `frame` is None; returns None if no window
        contains the frame.
        """
        if frame is None:
            frame = self.gui.get_current_frame()
        for i, window in enumerate(g.windows.windows_1):
            if window[0] <= frame < window[1]:
                return i
        return None

    def select_window_by_frame(self, frame=None):
        """Select the window containing `frame` (default: the current frame).

        ``current_window`` becomes -1 if no window contains the frame.
        selectWindow() is only called when the index actually changes.
        """
        if frame is None:
            frame = self.gui.get_current_frame()
        window_index = self.class_window_index(frame)
        if window_index is None:
            window_index = -1
        if self.current_window != window_index:
            self.current_window = window_index
            self.selectWindow(window_index)

    def selectWindow(self, window_index: int):
        """Select the window at `window_index` and show its attributes.

        Negative indices (no window) are ignored. Start, end and attributes
        are the same in all three predictions, so windows_1 is used.
        """
        if window_index >= 0:
            self.current_window = window_index
            _, _, _, attributes = g.windows.windows_1[self.current_window]
            self.attribute_graph.update_attributes(attributes)

    def highlight_class_bar(self, bar_index):
        """Color the class bars of all three graphs.

        Windows whose last attribute is non-zero are drawn in the error
        color. The bar at `bar_index` is drawn in the selected color; None
        or an index outside windows_1 (e.g. -1) highlights nothing.
        """
        normal_color = 0.5  # gray
        error_color = 200, 100, 100  # gray-ish red
        selected_color = 'y'  # yellow
        selected_error_color = 255, 200, 50  # orange

        windows = g.windows.windows_1
        colors = [
            normal_color if window[3][-1] == 0 else error_color
            for window in windows
        ]

        # -1 must not wrap around to the last bar.
        if bar_index is not None and 0 <= bar_index < len(windows):
            if windows[bar_index][3][-1] == 0:
                colors[bar_index] = selected_color
            else:
                colors[bar_index] = selected_error_color

        self.class_graph_1.color_class_bars(colors)
        self.class_graph_2.color_class_bars(colors)
        self.class_graph_3.color_class_bars(colors)

    def fixed_windows_mode(self, mode: str):
        """Store the fixed-window mode and refresh the dependent buttons."""
        self.fixed_window_mode_enabled = mode

        self.enable_annotate_button()
        self.enable_load_button()

    def get_start_frame(self) -> int:
        """Return the frame after the start of the selected window.

        Falls back to the current frame when no window is selected.
        """
        windows = g.windows.windows_1
        if windows is not None and len(windows) > 0 \
                and 0 <= self.current_window < len(windows):
            return windows[self.current_window][0] + 1
        return self.gui.get_current_frame()

    def browse_deep_rep_files(self):
        """Let the user choose *norm_data.csv files for deep representation.

        The file dialog is filtered to the subject of the current file. The
        subject is the first "_"-separated part of the file name containing
        'S'. Enables and checks the deep-rep checkbox if files were chosen.
        """
        name_parts = g.windows.file_name.split('_')
        # Without a subject part, fall back to no subject filter instead of
        # raising IndexError.
        subject_id = next((part for part in name_parts if 'S' in part), '')
        if subject_id:
            pattern = f'*{subject_id}*norm_data.csv'
        else:
            pattern = '*norm_data.csv'
        paths = QtWidgets.QFileDialog.getOpenFileNames(
            parent=self.gui,
            caption=
            'Please choose annotated files from the same Subject as the current file.',
            directory=g.settings['saveFinishedPath'],
            filter=f'CSV Files ({pattern})',
            initialFilter='')[0]

        self.deep_rep_files = paths

        if paths:
            suffix_len = len('_norm_data.csv')
            file_names = [os.path.split(path)[1][:-suffix_len]
                          for path in paths]
            msg = "Selected files for Deep Representation learning:"
            msg += "".join(f"\n- {file}" for file in file_names)
            self.add_status_message(msg)

            self.deep_rep_checkBox.setEnabled(True)
            self.deep_rep_checkBox.setChecked(True)
        else:
            self.add_status_message(
                "No files selected for Deep Representation learning")
            self.deep_rep_checkBox.setEnabled(False)
            self.deep_rep_checkBox.setChecked(False)
Esempio n. 6
0
class RetrievalController(Controller):
    """Controller for the "Retrieval" tab.

    Once retrieval data (``g.retrieval``) is loaded, the user picks a query
    class and a distance metric and retrieves a list of suggested windows.
    The user then accepts or rejects the suggestions one at a time.
    ``retrieved_list`` holds dicts with "index", "range" and "value" keys;
    its first entry is the suggestion under review.
    """

    def __init__(self, gui):
        super(RetrievalController, self).__init__(gui)

        self.retrieved_list = []
        self.not_filtered_range = (0, 1)  # Used for the distance histogram
        # Set once enable_widgets() has wrapped the PlotWidgets in Graphs.
        self.was_enabled_once = False

        self.setup_widgets()

    def setup_widgets(self):
        """Load the tab's .ui file, look up its widgets and connect signals."""
        self.load_tab(f'..{sep}ui{sep}retrieval_mode.ui', "Retrieval")

        # ----Retrieval preparation----
        self.query_comboBox: QtWidgets.QComboBox = self.widget.findChild(
            QtWidgets.QComboBox, "qbcComboBox")
        self.query_comboBox.addItems(g.classes)
        self.query_comboBox.currentIndexChanged.connect(self.change_query)

        self.metric_comboBox = self.widget.findChild(QtWidgets.QComboBox,
                                                     "distanceComboBox")
        self.metric_comboBox.addItems(["cosine", "bce"])

        self.retrieve_button = self.widget.findChild(QtWidgets.QPushButton,
                                                     "retrievePushButton")
        self.retrieve_button.clicked.connect(lambda _: self.retrieve_list())

        self.none_button = self.widget.findChild(QtWidgets.QPushButton,
                                                 "nonePushButton")
        self.none_button.clicked.connect(lambda _: self.set_to_none())

        # ----Retrieval video settings----
        # TODO: add these settings to the settings file to save/load them
        self.loop_checkBox = self.widget.findChild(QtWidgets.QCheckBox,
                                                   "loopCheckBox")
        self.before_lineEdit = self.widget.findChild(QtWidgets.QLineEdit,
                                                     "beforeLineEdit")
        self.after_lineEdit = self.widget.findChild(QtWidgets.QLineEdit,
                                                    "afterLineEdit")

        # ----Attribute buttons----
        self.attribute_buttons = [
            QtWidgets.QCheckBox(text) for text in g.attributes
        ]
        layout2 = self.widget.findChild(QtWidgets.QGroupBox,
                                        "attributesGroupBox").layout()

        for button in self.attribute_buttons:
            button.setEnabled(False)
            button.toggled.connect(
                lambda _: self.move_buttons(layout2, self.attribute_buttons))
            button.clicked.connect(lambda _: self.change_attributes())
        self.move_buttons(layout2, self.attribute_buttons)

        # ----Retrieval Buttons----
        self.accept_button = self.widget.findChild(QtWidgets.QPushButton,
                                                   "acceptPushButton")
        self.accept_button.clicked.connect(lambda _: self.accept_suggestion())

        self.reject_button = self.widget.findChild(QtWidgets.QPushButton,
                                                   "rejectPushButton")
        self.reject_button.clicked.connect(lambda _: self.reject_suggestion())

        self.reject_all_button = self.widget.findChild(QtWidgets.QPushButton,
                                                       "rejectAllPushButton")
        self.reject_all_button.clicked.connect(
            lambda _: self.reject_all_suggestions())

        # ----Classgraph-----------
        self.class_graph = self.widget.findChild(pg.PlotWidget, 'classGraph')
        self.distance_graph = self.widget.findChild(pg.PlotWidget,
                                                    'distanceGraph')
        self.distance_histogram = self.widget.findChild(
            pg.PlotWidget, 'distanceHistogram')

        # ----Status windows-------
        self.status_window = self.widget.findChild(QtWidgets.QTextEdit,
                                                   'statusWindow')
        self.add_status_message("New retrieval mode. This is a WIP")

    def enable_widgets(self):
        """Prepare the widgets for the loaded data.

        Sets integer validators (0 .. number_samples) on the before/after
        line edits, wraps the three PlotWidgets in ``Graph`` objects on the
        first call only, runs ``setup()`` on them and reloads the view.
        """
        self.before_lineEdit.setValidator(
            QtGui.QIntValidator(0, g.data.number_samples))
        self.after_lineEdit.setValidator(
            QtGui.QIntValidator(0, g.data.number_samples))

        # enable_widgets may run more than once (e.g. per loaded file);
        # wrapping again would put a Graph inside another Graph.
        if not self.was_enabled_once:
            self.class_graph = Graph(plot_widget=self.class_graph,
                                     graph_type="class",
                                     label="Classes",
                                     interval_lines=True)
            self.distance_graph = Graph(self.distance_graph,
                                        'data',
                                        label="score",
                                        interval_lines=True,
                                        unit="",
                                        AutoSIPrefix=False,
                                        y_range=(0, 1))
            self.distance_histogram = Graph(self.distance_histogram,
                                            "histogram",
                                            label="histogram",
                                            play_line=True)
            self.was_enabled_once = True

        self.class_graph.setup()
        self.distance_graph.setup()
        self.distance_histogram.setup()

        self.reload()

    def reload(self):
        """Redraw the graphs and enable widgets for the current state.

        Widgets are usable only when ``g.retrieval`` is loaded and the
        fixed-window mode is off (None/"none") or "retrieval". Retrieving
        and the metric choice additionally need "retrieval" mode. The
        review widgets need a non-empty ``retrieved_list``. Its first entry
        is then shown in the class graph, distance graph and histogram.
        """
        current_frame = self.gui.get_current_frame()
        self.class_graph.reload_classes(g.windows.windows)
        self.class_graph.update_frame_lines(-1000, -1000, current_frame)
        self.distance_graph.update_frame_lines(-1000, -1000, current_frame)
        self.distance_graph.update_plot(None)
        self.distance_histogram.update_histogram(None, None)

        mode = self.fixed_window_mode_enabled
        if g.retrieval is None or mode not in (None, "none", "retrieval"):
            self.metric_comboBox.setEnabled(False)
            self.none_button.setEnabled(False)
            self.retrieve_button.setEnabled(False)
            self.query_comboBox.setEnabled(False)
            self._set_review_widgets_enabled(False)
            self._disable_attribute_buttons()
            return

        self.none_button.setEnabled(True)
        retrieval_mode = mode == "retrieval"
        self.retrieve_button.setEnabled(retrieval_mode)
        self.metric_comboBox.setEnabled(retrieval_mode)

        if not self.retrieved_list:
            # Another query class may still have suggestions left, so the
            # query combo box is left enabled.
            self._set_review_widgets_enabled(False)
            self._disable_attribute_buttons()
            return

        self.query_comboBox.setEnabled(True)
        self._set_review_widgets_enabled(True)

        # ---- Attribute Buttons ----
        annotation_index = self.get_annotation_index()
        if annotation_index is None:
            # No prediction window contains this suggestion's midpoint.
            self._disable_attribute_buttons()
        else:
            window = g.windows.windows_1[annotation_index]
            for button, checked in zip(self.attribute_buttons, window[3]):
                button.setChecked(checked)
                button.setEnabled(True)

        # ---- Class Graph ----
        suggestion = self.retrieved_list[0]
        start, end = suggestion["range"]
        self.class_graph.update_frame_lines(start, end)
        self.highlight_class_bar(suggestion["index"])

        # ---- Distance Graph ----
        distances = np.zeros((g.data.number_samples, ))
        for item in self.retrieved_list:
            item_start, item_end = item["range"]
            distances[item_start:item_end] = item["value"]
        self.distance_graph.update_plot(distances)
        self.distance_graph.update_frame_lines(start, end)

        # ---- Distance Histogram ----
        x_values, y_values = self._distance_histogram(
            [item["value"] for item in self.retrieved_list])
        self.distance_histogram.update_histogram(x_values, y_values,
                                                 self.not_filtered_range)
        self.distance_histogram.update_frame_lines(play=suggestion["value"])

    def _set_review_widgets_enabled(self, enabled):
        """Enable/disable the loop/context settings and accept/reject buttons."""
        for widget in (self.loop_checkBox, self.before_lineEdit,
                       self.after_lineEdit, self.accept_button,
                       self.reject_button, self.reject_all_button):
            widget.setEnabled(enabled)

    def _disable_attribute_buttons(self):
        """Disable every attribute checkbox."""
        for button in self.attribute_buttons:
            button.setEnabled(False)

    @staticmethod
    def _distance_histogram(distances):
        """Histogram `distances` over [0, 1] with 1000 bins, trimmed to the
        populated range.

        Returns ``(bin_edges, counts)`` with ``len(bin_edges) ==
        len(counts) + 1``. The bins containing the smallest and the largest
        distance are kept.
        """
        counts, edges = np.histogram(distances, bins=1000, range=(0, 1))
        min_x = min(distances)
        max_x = max(distances)
        # Keep one edge below min_x / above max_x. Clamp at 0: if min_x <= 0
        # (or max_x >= 1) no edge lies beyond it, and a -1 offset would slice
        # from the wrong end of the arrays.
        discard_min = max(0, int(np.count_nonzero(edges < min_x)) - 1)
        discard_max = max(0, int(np.count_nonzero(edges > max_x)) - 1)
        return (edges[discard_min:len(edges) - discard_max],
                counts[discard_min:len(counts) - discard_max])

    @staticmethod
    def _line_edit_int(line_edit):
        """Return the integer typed in `line_edit`, or 0 if it is empty.

        QIntValidator lets intermediate input such as an empty field
        through, which int() would reject.
        """
        try:
            return int(line_edit.text())
        except ValueError:
            return 0

    def new_frame(self, frame):
        """Move the play lines to `frame`.

        When looping is on, restart playback at the loop start once the
        frame leaves the current suggestion plus its before/after context.
        """
        if self.retrieved_list and self.loop_checkBox.isChecked():
            start, end = self.retrieved_list[0]["range"]
            start = max(0, start - self._line_edit_int(self.before_lineEdit))
            end = min(g.data.number_samples - 1,
                      end + self._line_edit_int(self.after_lineEdit))
            if not (start <= frame < end):
                frame = start
                self.gui.playback_controller.set_start_frame()
        self.distance_graph.update_frame_lines(play=frame)
        self.class_graph.update_frame_lines(play=frame)

    def get_start_frame(self):
        """Return the playback start frame for the current suggestion.

        For forward playback this is the suggestion start minus the "before"
        context. Otherwise it is the end plus the "after" context, minus one.
        Returns 0 when there is no suggestion.
        """
        if not self.retrieved_list:
            return 0
        s, e = self.retrieved_list[0]["range"]
        if self.gui.playback_controller.speed > 0:
            return max(0, s - self._line_edit_int(self.before_lineEdit))
        return min(g.data.number_samples - 1,
                   e + self._line_edit_int(self.after_lineEdit)) - 1

    def move_buttons(self, layout: QtWidgets.QGridLayout, buttons: list):
        """Moves all the buttons in a layout

        Checked radio/checkbox buttons get moved to the left
        Unchecked buttons get moved to the right

        Arguments:
        ----------
        layout : QGridLayout
            the layout on which the buttons should be
        buttons : list
            a list of QRadioButtons or QCheckBox buttons, that should be moved in the layout
        """
        for i, button in enumerate(buttons):
            column = 0 if button.isChecked() else 2
            layout.addWidget(button, i + 1, column)

    def change_attributes(self):
        """Save the checked attribute buttons to the current suggestion's window."""
        attributes = [
            1 if button.isChecked() else 0 for button in self.attribute_buttons
        ]
        g.windows.change_window(self.retrieved_list[0]["index"],
                                attributes=attributes,
                                save=True)
        self.class_graph.reload_classes(g.windows.windows)
        self.highlight_class_bar(self.retrieved_list[0]["index"])

    def highlight_class_bar(self, bar_index):
        """Color the class bars using the colors from the base Controller."""
        colors = Controller.highlight_class_bar(self, bar_index)
        self.class_graph.color_class_bars(colors)

    def set_to_none(self):
        """After confirmation, reset every retrieval window to the last class
        with only the last attribute set, and enter "retrieval" mode."""
        message = \
            (f"This action will enable fixed-window-size mode.\n"
             f"Any labeling done up to this point will be discarded "
             f"and all windows will be reset to {g.classes[-1]}.\n"
             f"Some features in other modes will be disabled until fixed-window-size mode is disabled again.\n"
             f"In this mode you can choose to accept or reject suggestions for a query class.")

        revision_mode_warning = QtWidgets.QMessageBox.question(
            self.gui, 'Start fixed-window-size mode?', message,
            QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.Cancel,
            QtWidgets.QMessageBox.Cancel)

        if revision_mode_warning == QtWidgets.QMessageBox.Yes:
            error_attr = [0] * len(g.attributes)
            error_attr[-1] = 1
            intervals = [
                g.retrieval.__range__(i) for i in range(len(g.retrieval))
            ]
            g.windows.windows = [(s, e, len(g.classes) - 1, error_attr)
                                 for (s, e) in intervals]
            g.windows.make_backup()
            self.gui.fixed_windows_mode("retrieval")

    def disable_fixed_window_mode(self):
        """After confirmation, switch the fixed-window mode back to "none"."""
        message = (
            "This will disable fixed-window-size mode.\n"
            "Make sure you are finished everything you need to do in this mode.\n"
            "Next time you activate revision mode your unsaved progress will be lost"
        )

        revision_mode_warning = QtWidgets.QMessageBox.question(
            self.gui, 'Stop revision mode?', message,
            QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.Cancel,
            QtWidgets.QMessageBox.Cancel)

        if revision_mode_warning == QtWidgets.QMessageBox.Yes:
            self.gui.fixed_windows_mode("none")

    def fixed_windows_mode(self, mode: str):
        """Store the fixed-window mode and rewire the "none" button for it."""
        self.fixed_window_mode_enabled = mode

        if mode == "retrieval":
            self.none_button.clicked.disconnect()
            self.none_button.clicked.connect(
                lambda _: self.disable_fixed_window_mode())
            self.none_button.setText("Stop retrieval mode")
        elif mode is None or mode == "none":
            self.none_button.clicked.disconnect()
            self.none_button.clicked.connect(lambda _: self.set_to_none())
            self.none_button.setText("Set all to None")
            self.none_button.setEnabled(True)

            self.retrieved_list = []
        else:
            self.none_button.setEnabled(False)

        self.reload()

    def retrieve_list(self):
        """Run the predictions with the chosen metric and rebuild the list."""
        distance = self.metric_comboBox.currentText()

        g.retrieval.predict_classes_from_attributes(distance)
        g.retrieval.predict_attribute_reps(distance)
        g.retrieval.reset_filter()
        self.change_query(self.query_comboBox.currentIndex())
        self.reload()

    def change_query(self, class_index: int):
        """Retrieve the suggestions for `class_index`.

        ``not_filtered_range`` is updated from the unfiltered values before
        filter_not_none_class() is applied. It is reset to (0, 1) when
        nothing is retrieved.
        """
        self.retrieved_list = g.retrieval.retrieve_list(class_index)
        values = [item["value"] for item in self.retrieved_list]
        if values:
            values.sort()
            self.not_filtered_range = (values[0], values[-1])
        else:
            self.not_filtered_range = (0, 1)
        self.retrieved_list = g.retrieval.filter_not_none_class(
            self.retrieved_list, class_index)
        self.reload()

    def accept_suggestion(self):
        """Label the current suggestion's window with the query class."""
        annotation_index = self.get_annotation_index(0)
        if annotation_index is not None:
            _, _, _, a = g.windows.windows_1[annotation_index]
        else:
            # No prediction window covers the suggestion; start from the GUI
            # state, which change_attributes() below applies anyway.
            a = [1 if button.isChecked() else 0
                 for button in self.attribute_buttons]
        index = self.retrieved_list[0]["index"]
        s, e = self.retrieved_list[0]["range"]
        c = self.query_comboBox.currentIndex()
        g.windows.windows[index] = (s, e, c, a)
        g.retrieval.remove_suggestion(self.retrieved_list[0], None)
        self.change_attributes()  # Change attributes as seen on gui
        self.retrieved_list.pop(0)
        self.retrieved_list = g.retrieval.prioritize_neighbors(
            self.retrieved_list, index)
        self.reload()

    def reject_suggestion(self):
        """Drop the current suggestion for the selected query class."""
        class_index = self.query_comboBox.currentIndex()
        g.retrieval.remove_suggestion(self.retrieved_list[0], class_index)
        self.retrieved_list.pop(0)
        self.reload()

    def reject_all_suggestions(self):
        """Drop all remaining suggestions for the selected query class."""
        class_index = self.query_comboBox.currentIndex()
        g.retrieval.remove_suggestion(None, class_index)
        self.retrieved_list = []
        self.reload()

    def get_annotation_index(self, retrieval_index=0):
        """Calculates the networks window index based on the retrieval window index

        Since the network merges its prediction windows when the top3 predictions match it has fewer windows
        than the retrieval mode which keeps the windows at the networks output size.
        Therefore this method has to be used to convert retrieval index to the automatic annotation index.

        Returns the index of the windows_1 window containing the midpoint of
        the retrieval window, or None if no window contains it.
        """
        s, e = self.retrieved_list[retrieval_index]["range"]
        m = (s + e) / 2
        for i, window in enumerate(g.windows.windows_1):
            if window[0] <= m < window[1]:
                return i
        return None
class Label_Correction_Controller(Controller):
    def __init__(self, gui):
        super().__init__(gui)

        # False until enable_widgets has wrapped the class plot in a Graph.
        self.was_enabled_once = False
        # Index into g.windows.windows; starts at -1, which selectWindow
        # interprets as counting from the end of the window list.
        self.current_window = -1

        self.setup_widgets()

    def setup_widgets(self):
        """Load the Label Correction tab and wire its widgets to this controller.

        Loads ``label_correction_mode.ui`` through ``load_tab`` (tab title
        "Label Correction"), stores the widgets this controller uses on
        ``self`` (looked up by their object names in the .ui file), creates
        one radio button per entry of ``g.classes`` and one check box per entry
        of ``g.attributes`` (all created disabled), and connects the widgets'
        signals to the handler methods of this class.
        """
        self.load_tab(f'..{sep}ui{sep}label_correction_mode.ui',
                      "Label Correction")
        # ----Labels----
        self.current_window_label = self.widget.findChild(
            QtWidgets.QLabel, "lc_current_window_label")

        # ----Scrollbars----
        self.scrollBar = self.widget.findChild(QtWidgets.QScrollBar,
                                               "lc_scrollBar")
        # Moving the scrollbar selects the window whose index is its value.
        self.scrollBar.valueChanged.connect(self.selectWindow)

        # ----LineEdits----
        # Frame-number inputs. The "set to ..." buttons below fill them with
        # (frame index + 1); enable_widgets installs their integer validators.
        self.split_at_lineEdit = self.widget.findChild(QtWidgets.QLineEdit,
                                                       "lc_split_at_lineEdit")
        self.move_start_lineEdit = self.widget.findChild(
            QtWidgets.QLineEdit, "lc_move_start_lineEdit")
        self.move_end_lineEdit = self.widget.findChild(QtWidgets.QLineEdit,
                                                       "lc_move_end_lineEdit")

        # Display-only fields showing the current window's start/end (+1).
        self.start_lineEdit = self.widget.findChild(QtWidgets.QLineEdit,
                                                    "lc_start_lineEdit")
        self.end_lineEdit = self.widget.findChild(QtWidgets.QLineEdit,
                                                  "lc_end_lineEdit")

        # ----Buttons----
        # The `lambda _:` wrappers swallow the `checked` argument that
        # QAbstractButton.clicked passes, so handlers are called without it.
        self.merge_previous_button = self.widget.findChild(
            QtWidgets.QPushButton, "lc_merge_previous_button")
        self.merge_previous_button.clicked.connect(
            lambda _: self.merge_previous())
        self.merge_next_button = self.widget.findChild(QtWidgets.QPushButton,
                                                       "lc_merge_next_button")
        self.merge_next_button.clicked.connect(lambda _: self.merge_next())
        self.merge_all_button = self.widget.findChild(QtWidgets.QPushButton,
                                                      "lc_merge_all_button")
        self.merge_all_button.clicked.connect(
            lambda _: self.merge_all_adjacent())

        self.split_at_button = self.widget.findChild(QtWidgets.QPushButton,
                                                     "lc_split_at_button")
        self.split_at_button.clicked.connect(lambda _: self.split())
        self.move_start_button = self.widget.findChild(QtWidgets.QPushButton,
                                                       "lc_move_start_button")
        self.move_start_button.clicked.connect(lambda _: self.move_start())
        self.move_end_button = self.widget.findChild(QtWidgets.QPushButton,
                                                     "lc_move_end_button")
        self.move_end_button.clicked.connect(lambda _: self.move_end())

        # "Set to frame" buttons copy the current frame (+1) into an input.
        self.set_to_frame_split_button = self.widget.findChild(
            QtWidgets.QPushButton, "lc_set_frame_split_button")
        self.set_to_frame_split_button.clicked.connect(
            lambda _: self.split_at_lineEdit.setText(
                str(self.gui.get_current_frame() + 1)))
        self.set_to_frame_start_button = self.widget.findChild(
            QtWidgets.QPushButton, "lc_set_frame_start_button")
        self.set_to_frame_start_button.clicked.connect(
            lambda _: self.move_start_lineEdit.setText(
                str(self.gui.get_current_frame() + 1)))
        self.set_to_frame_end_button = self.widget.findChild(
            QtWidgets.QPushButton, "lc_set_frame_end_button")
        self.set_to_frame_end_button.clicked.connect(
            lambda _: self.move_end_lineEdit.setText(
                str(self.gui.get_current_frame() + 1)))
        # "Set to start/end" buttons copy the current window's bound (+1).
        self.set_to_start_button = self.widget.findChild(
            QtWidgets.QPushButton, "lc_set_start_button")
        self.set_to_start_button.clicked.connect(
            lambda _: self.move_start_lineEdit.setText(
                str(g.windows.windows[self.current_window][0] + 1)))
        self.set_to_end_button = self.widget.findChild(QtWidgets.QPushButton,
                                                       "lc_set_end_button")
        self.set_to_end_button.clicked.connect(
            lambda _: self.move_end_lineEdit.setText(
                str(g.windows.windows[self.current_window][1] + 1)))

        # NOTE(review): the object name contains a double underscore; it
        # presumably matches the .ui file, so only change both together.
        self.window_by_frame_button = self.widget.findChild(
            QtWidgets.QPushButton, "lc_window__by_frame_button")
        self.window_by_frame_button.clicked.connect(
            lambda _: self.select_window_by_frame())

        # ----Class buttons----
        # Created in code, one radio button per class name.
        self.classButtons = [
            QtWidgets.QRadioButton(text) for text in g.classes
        ]
        self.class_layout = self.widget.findChild(QtWidgets.QGroupBox,
                                                  "classesGroupBox").layout()

        # `toggled` re-sorts the layout (checked -> column 0, unchecked ->
        # column 2); `clicked` is not emitted by setChecked(), so only user
        # clicks write the class back to the current window.
        for button in self.classButtons:
            button.setEnabled(False)
            button.toggled.connect(lambda _: self.move_buttons(
                self.class_layout, self.classButtons))
            button.clicked.connect(lambda _: self.changeClass())
        self.move_buttons(self.class_layout, self.classButtons)

        # ----Attribute buttons----
        # One check box per attribute name; same toggled/clicked split as above.
        self.attributeButtons = [
            QtWidgets.QCheckBox(text) for text in g.attributes
        ]
        layout2 = self.widget.findChild(QtWidgets.QGroupBox,
                                        "attributesGroupBox").layout()

        for button in self.attributeButtons:
            button.setEnabled(False)
            button.toggled.connect(
                lambda _: self.move_buttons(layout2, self.attributeButtons))
            button.clicked.connect(lambda _: self.changeAttributes())
        self.move_buttons(layout2, self.attributeButtons)

        # ----Classgraph-----------
        # Plain plot widget here; enable_widgets wraps it in a Graph once.
        self.class_graph = self.widget.findChild(pg.PlotWidget,
                                                 'lc_classGraph')

        # ----Status windows-------
        self.status_window = self.widget.findChild(QtWidgets.QTextEdit,
                                                   'lc_statusWindow')
        self.add_status_message("Here you can correct wrong Labels.")

    def enable_widgets(self):
        """Prepare the Label Correction widgets for the currently loaded data.

        On the first call the plain ``lc_classGraph`` plot widget is wrapped in
        a ``Graph``. On every call the three frame-number inputs get an integer
        validator for ``0 .. g.data.number_samples + 1``, the class graph is
        set up, and ``reload()`` refreshes the class bars and window selection.
        """
        if not self.was_enabled_once:
            self.class_graph = Graph(self.class_graph,
                                     'class',
                                     interval_lines=False)
            self.was_enabled_once = True

        highest_frame = g.data.number_samples + 1
        for line_edit in (self.split_at_lineEdit, self.move_start_lineEdit,
                          self.move_end_lineEdit):
            # One validator instance per line edit, as before.
            line_edit.setValidator(QtGui.QIntValidator(0, highest_frame))

        self.class_graph.setup()
        # reload() starts by calling class_graph.reload_classes() itself, so a
        # separate reload_classes() call here only redrew the bars twice.
        self.reload()

    def reload(self):
        """reloads all window information
        
        called when switching to label correction mode
        """
        # print("reloading LCC")
        self.class_graph.reload_classes(g.windows.windows)

        self.update_frame_lines(self.gui.get_current_frame())

        if g.windows is not None and len(g.windows.windows) > 0:
            self.set_enabled(True)
            self.select_window_by_frame()
            self.selectWindow(self.current_window)

        else:
            self.set_enabled(False)

    def set_enabled(self, enable: bool):
        """Turns the Widgets of Label Correction Mode on or off based on the enable parameter
        
        Arguments:
        ----------
        enable : bool
            If True and widgets were disabled, the widgets get enabled.
            If False and widgets were enabled, the widgets get disabled.
            Otherwise does nothing.
        ----------
        
        """
        # print("lcm.set_enabled:",\
        #      "\n\t self.enabled:",self.enabled,\
        #      "\n\t enable:",enable,\
        #      "\n\t revision:",self.revision_mode_enabled)
        if self.revision_mode_enabled:
            self.split_at_lineEdit.setEnabled(False)
            self.move_start_lineEdit.setEnabled(False)
            self.move_end_lineEdit.setEnabled(False)

            self.merge_previous_button.setEnabled(False)
            self.merge_next_button.setEnabled(False)
            self.merge_all_button.setEnabled(False)

            self.split_at_button.setEnabled(False)
            self.move_start_button.setEnabled(False)
            self.move_end_button.setEnabled(False)

            self.set_to_frame_split_button.setEnabled(False)
            self.set_to_frame_start_button.setEnabled(False)
            self.set_to_frame_end_button.setEnabled(False)
            self.set_to_start_button.setEnabled(False)
            self.set_to_end_button.setEnabled(False)
        else:
            self.split_at_lineEdit.setEnabled(enable)
            self.move_start_lineEdit.setEnabled(enable)
            self.move_end_lineEdit.setEnabled(enable)

            self.merge_previous_button.setEnabled(enable)
            self.merge_next_button.setEnabled(enable)
            self.merge_all_button.setEnabled(enable)

            self.split_at_button.setEnabled(enable)
            self.move_start_button.setEnabled(enable)
            self.move_end_button.setEnabled(enable)

            self.set_to_frame_split_button.setEnabled(enable)
            self.set_to_frame_start_button.setEnabled(enable)
            self.set_to_frame_end_button.setEnabled(enable)
            self.set_to_start_button.setEnabled(enable)
            self.set_to_end_button.setEnabled(enable)

        if not (self.enabled == enable):
            # Only reason why it might be disabled is that there were no windows
            # Therefore setting the current window to 0 as this mode is enabled
            # as soon as there is at least one window
            self.enabled = enable
            for button in self.classButtons:
                button.setEnabled(enable)
            for button in self.attributeButtons:
                button.setEnabled(enable)

            self.window_by_frame_button.setEnabled(enable)

            self.scrollBar.setEnabled(enable)

    def selectWindow(self, window_index: int):
        """Selects the window at window_index"""
        if window_index >= 0:
            self.current_window = window_index
        else:
            self.current_window = len(g.windows.windows) + window_index

        self.scrollBar.setRange(0, len(g.windows.windows) - 1)
        self.scrollBar.setValue(self.current_window)

        window = g.windows.windows[self.current_window]
        self.current_window_label.setText("Current Window: " +
                                          str(self.current_window + 1) + "/" +
                                          str(len(g.windows.windows)))
        self.start_lineEdit.setText(str(window[0] + 1))
        self.end_lineEdit.setText(str(window[1] + 1))

        self.classButtons[window[2]].setChecked(True)
        for button, checked in zip(self.attributeButtons, window[3]):
            button.setChecked(checked)

        if self.revision_mode_enabled:
            # print(window_index, len(g.windows.windows_1))
            top_buttons = [
                g.windows.windows_1[window_index][2],
                g.windows.windows_2[window_index][2],
                g.windows.windows_3[window_index][2]
            ]
            for i, name in enumerate(g.classes):
                if i == top_buttons[0]:
                    self.classButtons[i].setText(name + " (#1)")
                elif i == top_buttons[1]:
                    self.classButtons[i].setText(name + " (#2)")
                elif i == top_buttons[2]:
                    self.classButtons[i].setText(name + " (#3)")
                else:
                    self.classButtons[i].setText(name)

        self.highlight_class_bar(window_index)

    def highlight_class_bar(self, bar_index):
        """Recolour the class-graph bars using the colours that
        ``Controller.highlight_class_bar`` computes for ``bar_index``."""
        self.class_graph.color_class_bars(
            Controller.highlight_class_bar(self, bar_index))

    def new_frame(self, frame):
        self.update_frame_lines(frame)
        window_index = self.class_window_index(frame)
        if self.enabled and (self.current_window != window_index):
            self.current_window = window_index
            self.highlight_class_bar(window_index)
            self.selectWindow(window_index)

    def update_frame_lines(self, play=None):
        self.class_graph.update_frame_lines(play=play)

    def select_window_by_frame(self, frame=None):
        """Selects the Window around based on the current Frame shown
        
        """
        if frame is None:
            frame = self.gui.get_current_frame()
        window_index = self.class_window_index(frame)
        if window_index is None:
            window_index = -1
        # if the old and new index is the same do nothing.
        if self.current_window != window_index:
            self.current_window = window_index
            # self.reload()
            self.selectWindow(window_index)
        else:
            self.current_window = window_index

    def mergeable(self, window_index_a: int, window_index_b: int) -> bool:
        """Checks whether two windows can be merged
        
        window_index_a should be smaller than window_index_b
        """
        if (window_index_a + 1
                == window_index_b) and (window_index_a >= 0) and (
                    window_index_b < len(g.windows.windows)):
            window_a = g.windows.windows[window_index_a]
            window_b = g.windows.windows[window_index_b]
            if window_a[2] == window_b[2]:
                a_and_b = [a == b for a, b in zip(window_a[3], window_b[3])]
                return reduce(lambda a, b: a and b, a_and_b)
        return False

    def merge(self,
              window_index_a: int,
              window_index_b: int,
              check_mergeable=True,
              reload=True):
        """Tries to merge two windows"""
        if not check_mergeable or self.mergeable(window_index_a,
                                                 window_index_b):
            window_b = g.windows.windows[window_index_b]
            g.windows.change_window(window_index_a,
                                    end=window_b[1],
                                    save=False)
            g.windows.delete_window(window_index_b, save=True)
            if self.current_window == len(g.windows.windows):
                self.current_window -= 1

            if reload:
                self.reload()

    def merge_all_adjacent(self):
        """Tries to merge all mergeable adjacent windows"""
        for i in range(len(g.windows.windows)):
            while self.mergeable(i, i + 1):
                self.merge(i, i + 1, False, False)
        self.reload()

    def merge_previous(self):
        """Tries to merge the current window with the previous"""

        if self.current_window == 0:
            self.add_status_message(
                "Can't merge the first window with a previous window.")
        else:
            self.merge(self.current_window - 1, self.current_window)

    def merge_next(self):
        """Tries to merge the current window with the next"""

        if self.current_window == len(g.windows.windows) - 1:
            self.add_status_message(
                "Can't merge the last window with a following window.")
        else:
            self.merge(self.current_window, self.current_window + 1)

    def split(self):
        """Splits the current window into two windows at a specified frame"""
        split_point = self.split_at_lineEdit.text()
        if split_point != '':
            split_point = int(self.split_at_lineEdit.text()) - 1
            window = g.windows.windows[self.current_window]
            if window[0] + 25 < split_point < window[1] - 25:
                g.windows.insert_window(self.current_window, window[0],
                                        split_point, window[2], window[3],
                                        False)
                g.windows.change_window(self.current_window + 1,
                                        start=split_point,
                                        save=True)
                # self.gui.reloadClasses()
                self.reload()
            else:
                self.add_status_message(
                    "The splitting point should be inside the current window")

    def move_start(self):
        """Moves the start frame of the current window to a specified frame
        
        Moves the end of the previous window too.
        """
        start_new = self.move_start_lineEdit.text()
        if start_new != '':
            if self.current_window > 0:
                window_previous = g.windows.windows[self.current_window - 1]
                window_current = g.windows.windows[self.current_window]
                start_new = int(self.move_start_lineEdit.text()) - 1
                if window_previous[0] + 50 < start_new:
                    if start_new < window_current[1] - 50:
                        g.windows.change_window(self.current_window - 1,
                                                end=start_new,
                                                save=False)
                        g.windows.change_window(self.current_window,
                                                start=start_new,
                                                save=True)
                        # self.gui.reloadClasses()
                        self.reload()
                    else:
                        self.add_status_message(
                            "A window can't start after it ended.")
                else:
                    self.add_status_message(
                        "A window can't start before a previous window.")
            else:
                self.add_status_message(
                    "You can't move the start point of the first window.")

    def move_end(self):
        """Moves the end frame of the current window to a specified frame
        
        Moves the start of the next window too.
        """
        end_new = self.move_end_lineEdit.text()
        if end_new != '':

            window_current = g.windows.windows[self.current_window]
            end_new = int(self.move_end_lineEdit.text())

            if window_current[0] + 50 < end_new:
                if self.current_window < len(g.windows.windows) - 1:
                    window_next = g.windows.windows[self.current_window + 1]
                    if end_new < window_next[1] - 50:
                        g.windows.change_window(self.current_window,
                                                end=end_new,
                                                save=False)
                        g.windows.change_window(self.current_window + 1,
                                                start=end_new,
                                                save=True)
                        # self.gui.reloadClasses()
                        self.reload()
                    else:
                        self.add_status_message(
                            "A window can't end after a following window ends."
                        )
                else:
                    if end_new <= g.data.number_samples:
                        g.windows.change_window(self.current_window,
                                                end=end_new,
                                                save=True)
                        # self.gui.reloadClasses()
                        self.reload()
                    else:
                        self.add_status_message(
                            "A window can't end after the end of the data.")
            else:
                self.add_status_message(
                    "A window can't end before if started.")

    def changeClass(self):
        for i, button in enumerate(self.classButtons):
            if button.isChecked():
                g.windows.change_window(self.current_window,
                                        class_index=i,
                                        save=True)
        # self.reload()
        self.class_graph.reload_classes(g.windows.windows)
        self.highlight_class_bar(self.current_window)

    def changeAttributes(self):
        """Looks which Attribute buttons are checked and saves that to the current window"""

        attributes = []
        for button in self.attributeButtons:
            if button.isChecked():
                attributes.append(1)
            else:
                attributes.append(0)
        g.windows.change_window(self.current_window,
                                attributes=attributes,
                                save=True)
        # self.reload()
        self.class_graph.reload_classes(g.windows.windows)
        self.highlight_class_bar(self.current_window)

    def move_buttons(self, layout: QtWidgets.QGridLayout, buttons: list):
        """Moves all the buttons in a layout
        
        Checked radio/checkbox buttons get moved to the left
        Unchecked buttons get moved to the right
        
        Arguments:
        ----------
        layout : QGridLayout
            the layout on which the buttons should be
        buttons : list
            a list of QRadioButtons or QCheckBox buttons, that should be moved in the layout
        """

        for i, button in enumerate(buttons):
            if button.isChecked():
                layout.addWidget(button, i + 1, 0)
            else:
                layout.addWidget(button, i + 1, 2)

    def get_start_frame(self) -> int:
        """returns the start of the current window"""
        if len(g.windows.windows) > 0:
            return g.windows.windows[self.current_window][0] + 1
        return 1

    def revision_mode(self, enable: bool):
        self.revision_mode_enabled = enable

        if not enable:
            for i, name in enumerate(g.classes):
                self.classButtons[i].setText(name)

        self.reload()