class CreationWindow(QWidget):
    """Editor window in which the user assembles a sequential neural network.

    Layer widgets are created from the "Layer selection" panel, moved around
    the grid by drag & drop, and chained by clicking one layer's output
    connector and another layer's input connector (in either order).
    ``confirm`` validates the chain, stores it in ``modelList`` (input layer
    first, output layer removed) and emits ``done``.

    Keyboard:
      * Delete enters delete mode: clicking a layer or a line removes it.
      * Escape leaves delete mode and cancels a pending connection.
    """

    # Emitted by confirm() once the model has been validated
    done = Signal()

    # Dataset the network is built for; a pandas DataFrame is treated as
    # alphanumerical data, which forbids convolutional and flatten layers
    dataset = None

    # Pen width (px) of the connection lines
    _LINE_WIDTH = 10

    # In delete mode, a click closer than this (px) to a line removes it
    _LINE_PICK_TOLERANCE = _LINE_WIDTH / 2.0

    # Offset (px) from a connector's top-left corner to where lines attach;
    # presumably half the connector's size — TODO confirm against LayerButton
    _CONNECTOR_OFFSET = 20

    # Layer types, in the order of the selection buttons and of their
    # "images/layer_icon_<index>" icons
    _LAYER_NAMES = ("Input layer", "Output layer", "Dense layer",
                    "Flatten layer", "Convolutional layer")

    # Initializing the window
    def __init__(self, dataset, *args, **kwargs):
        """Build the editor UI.

        Args:
            dataset: data the network will be trained on; only its type is
                inspected here (see ``confirm``).
            *args, **kwargs: forwarded to ``QWidget``.
        """
        super(CreationWindow, self).__init__(*args, **kwargs)

        # Editor state. These are per-instance on purpose: as class
        # attributes the lists were shared by every CreationWindow.
        # Layer widgets currently on the grid
        self.__layerVector = []
        # Connections, each stored as [output connector, input connector]
        self.__lines = []
        # Delete mode (Delete key on, Escape key off)
        self.__deleteMode = False
        # Connecting mode (a line follows the cursor)
        self.__connectingMode = False
        # Connector button that started connecting mode
        self.__connectingLayer = None
        # Whether the layer parameters box is shown
        self.__parametersShown = False
        # Layer whose parameters are being edited
        self.__currentLayer = None
        # Validated model, filled by confirm()
        self.modelList = []

        # Initializing dataset
        self.dataset = dataset

        # Loading fonts
        QFontDatabase.addApplicationFont("fonts/BebasNeue-Light.ttf")

        # Accepting drag & drops (layers are moved by dragging them)
        self.setAcceptDrops(True)

        # Window settings
        self.setFixedSize(1280, 720)
        self.setWindowTitle("Neural network creation")
        palette = QPalette()
        palette.setBrush(QPalette.Background, QPixmap("images/grid"))
        self.setPalette(palette)
        self.setAttribute(Qt.WA_StyledBackground, True)
        self.setAutoFillBackground(True)

        # Mouse tracking so mouseMoveEvent fires without a pressed button
        self.setMouseTracking(True)

        # Stylesheet settings
        styleFile = QFile("stylesheets/creation.qss")
        styleFile.open(QFile.ReadOnly)
        # NOTE(review): under Python 3, str(QByteArray) may produce a
        # "b'...'" literal rather than the decoded text — confirm.
        style = str(styleFile.readAll())
        styleFile.close()
        self.setStyleSheet(style)

        # Network name line edit and its label
        self.networkName = QLineEdit(self)
        self.networkName.setPlaceholderText("Enter neural network name")
        self.__networkNameLabel = QLabel("Neural network name : ", self)
        self.__networkNameLabel.setFont(QFont("BebasNeue", 20, QFont.Bold))
        self.__networkNameLabel.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        self.__networkNameLabel.setBuddy(self.networkName)

        # Reset / Confirm buttons (top right)
        self.__topRightButtons = []
        for text, iconPath, slot in (("Reset", "images/reset_icon", self.reset),
                                     ("Confirm", "images/check_icon",
                                      self.confirm)):
            button = QPushButton(text, self)
            button.setCursor(Qt.PointingHandCursor)
            button.setIcon(QIcon(iconPath))
            button.setIconSize(QSize(35, 35))
            button.setFont(QFont("BebasNeue", 10, QFont.Bold))
            button.clicked.connect(slot)
            self.__topRightButtons.append(button)

        # Go back button (also clears the editor)
        self.goBackButton = QPushButton("Back", self)
        self.goBackButton.setObjectName("retour")
        self.goBackButton.clicked.connect(self.reset)
        self.goBackButton.setCursor(Qt.PointingHandCursor)
        self.goBackButton.setIcon(QIcon("images/goback_icon"))
        self.goBackButton.setIconSize(QSize(30, 30))
        self.goBackButton.setFont(QFont("BebasNeue", 20, QFont.Bold))

        # Layer selection buttons; each click creates a layer of that type
        self.__layerButtons = []
        for index, name in enumerate(self._LAYER_NAMES):
            button = QPushButton(name, self)
            button.setCursor(Qt.PointingHandCursor)
            button.setFont(QFont("BebasNeue", 10, QFont.Bold))
            button.setIcon(QIcon("images/layer_icon_" + str(index)))
            button.setIconSize(QSize(45, 45))
            button.clicked.connect(self.createLayer)
            self.__layerButtons.append(button)

        # Layouts are created without a parent: setLayout() below hands
        # each one to its group box.
        # Top buttons layout settings
        self.__buttonLayout = QHBoxLayout()
        self.__buttonGroupBox = QGroupBox(self)
        self.__buttonGroupBox.setGeometry(780, -15, 500, 120)
        for button in self.__topRightButtons:
            self.__buttonLayout.addWidget(button)
        self.__buttonGroupBox.setLayout(self.__buttonLayout)

        # Network name form layout settings
        self.__networkNameLayout = QFormLayout()
        self.__networkNameGroupBox = QGroupBox(self)
        self.__networkNameGroupBox.setGeometry(300, -15, 480, 120)
        self.__networkNameLayout.addWidget(self.__networkNameLabel)
        self.__networkNameLayout.addWidget(self.networkName)
        self.__networkNameGroupBox.setLayout(self.__networkNameLayout)

        # Layer buttons layout settings
        self.__layerButtonLayout = QVBoxLayout()
        self.__layerButtonGroupBox = QGroupBox("Layer selection", self)
        self.__layerButtonGroupBox.setGeometry(0, -15, 300, 735)
        for button in self.__layerButtons:
            self.__layerButtonLayout.addWidget(button)
        self.__layerButtonLayout.addWidget(self.goBackButton)
        self.__layerButtonGroupBox.setLayout(self.__layerButtonLayout)

        # ---- Parameters window ----
        # Layer output number
        self.__layerOutputLineEdit = QLineEdit(self)
        self.__layerOutputLineEdit.setValidator(QIntValidator(0, 1000, self))
        self.__layerOutputGroupBox = self.__formGroupBox(
            "Output number", self.__layerOutputLineEdit)

        # Activation function (the combo index is what the layer stores)
        self.__activationMenu = QComboBox(self)
        self.__activationMenu.addItems(
            ['Sigmoid', 'Tanh', 'Rectified Linear Unit', 'Softmax'])
        self.__activationGroupBox = self.__formGroupBox(
            "Activation function", self.__activationMenu)

        # Close window button
        self.__closeButton = QPushButton(self)
        self.__closeButton.setObjectName("close")
        self.__closeButton.setCursor(Qt.PointingHandCursor)
        self.__closeButton.setIcon(QIcon("images/close_icon"))
        self.__closeButton.setIconSize(QSize(35, 35))
        self.__closeButton.clicked.connect(self.closeParameters)

        # Accept changes button
        self.__acceptButton = QPushButton(self)
        self.__acceptButton.setObjectName("accept")
        self.__acceptButton.setCursor(Qt.PointingHandCursor)
        self.__acceptButton.setIcon(QIcon("images/accept_icon"))
        self.__acceptButton.setIconSize(QSize(35, 35))
        self.__acceptButton.clicked.connect(self.acceptParameters)

        # Close/Accept buttons layout
        self.__bottomButtonsLayout = QHBoxLayout()
        self.__bottomButtonsGroupBox = QGroupBox(self)
        self.__bottomButtonsLayout.addWidget(self.__closeButton)
        self.__bottomButtonsLayout.addWidget(self.__acceptButton)
        self.__bottomButtonsGroupBox.setLayout(self.__bottomButtonsLayout)

        # Kernel rows (convolutional layers only, hidden by default)
        self.__kernelRowsLineEdit = QLineEdit(self)
        self.__kernelRowsLineEdit.setValidator(QIntValidator(0, 1000, self))
        self.__kernelRowsGroupBox = self.__formGroupBox(
            "Kernel rows", self.__kernelRowsLineEdit)
        self.__kernelRowsGroupBox.hide()

        # Kernel columns (convolutional layers only, hidden by default)
        self.__kernelColumnLineEdit = QLineEdit(self)
        self.__kernelColumnLineEdit.setValidator(QIntValidator(0, 1000, self))
        self.__kernelColumnGroupBox = self.__formGroupBox(
            "Kernel columns", self.__kernelColumnLineEdit)
        self.__kernelColumnGroupBox.hide()

        # Layer parameters group box gathering all of the above
        self.__layerParametersGroupBox = QGroupBox(self)
        self.__layerParametersGroupBox.setObjectName("parameters")
        self.__layerParametersGroupBox.setGeometry(960, 88, 320, 550)
        self.__layerParametersLayout = QVBoxLayout()
        for groupBox in (self.__layerOutputGroupBox, self.__activationGroupBox,
                         self.__kernelRowsGroupBox,
                         self.__kernelColumnGroupBox,
                         self.__bottomButtonsGroupBox):
            self.__layerParametersLayout.addWidget(groupBox)
        self.__layerParametersGroupBox.setLayout(self.__layerParametersLayout)
        self.__layerParametersGroupBox.hide()
        self.__layerParametersGroupBox.raise_()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def __formGroupBox(self, labelText, field):
        """Return a new group box holding a bold caption followed by *field*."""
        label = QLabel(labelText, self)
        label.setFont(QFont("BebasNeue", 20, QFont.Bold))
        layout = QFormLayout()
        layout.addWidget(label)
        layout.addWidget(field)
        groupBox = QGroupBox(self)
        groupBox.setLayout(layout)
        return groupBox

    @staticmethod
    def __isLayer(widget, kind):
        """Return True when *widget*'s caption starts with *kind* (e.g. "Dense")."""
        return widget.text().startswith(kind)

    @staticmethod
    def __intFrom(lineEdit):
        """Return the integer typed in *lineEdit*, or None if it holds none.

        The validator still lets intermediate text (e.g. "") through, so
        this never raises.
        """
        try:
            return int(lineEdit.text())
        except ValueError:
            return None

    @staticmethod
    def __sizeText(value):
        """Return *value* as line-edit text, or "" when it is unset (falsy)."""
        return str(value) if value else ""

    @classmethod
    def __anchor(cls, button):
        """Return where lines attach to connector *button*, in window coords.

        The connector lives inside a layer widget, so the layer's position is
        added to the connector's own position.
        """
        layer = button.parentWidget()
        return QPoint(button.x() + layer.x() + cls._CONNECTOR_OFFSET,
                      button.y() + layer.y() + cls._CONNECTOR_OFFSET)

    @staticmethod
    def __distanceToSegment(px, py, ax, ay, bx, by):
        """Return the Euclidean distance from point P to the segment AB.

        Works for vertical and zero-length segments, unlike a slope/intercept
        test.
        """
        dx = bx - ax
        dy = by - ay
        lengthSquared = float(dx * dx + dy * dy)
        if lengthSquared == 0:
            t = 0.0
        else:
            # Project P onto AB, clamped to the segment's extent
            t = ((px - ax) * dx + (py - ay) * dy) / lengthSquared
            t = max(0.0, min(1.0, t))
        nearestX = ax + t * dx
        nearestY = ay + t * dy
        return ((px - nearestX) ** 2 + (py - nearestY) ** 2) ** 0.5

    def __warn(self, title, text):
        """Show a warning message box with a single OK button."""
        QMessageBox.warning(self, title, text, QMessageBox.Ok)

    def __cancelConnecting(self):
        """Leave connecting mode and erase the line that followed the cursor."""
        self.__connectingMode = False
        self.__connectingLayer = None
        self.update()

    def __setDeleteMode(self, enabled):
        """Switch delete mode on/off and update the window and layer cursors."""
        self.__deleteMode = enabled
        self.setCursor(Qt.CrossCursor if enabled else Qt.ArrowCursor)
        layerCursor = Qt.CrossCursor if enabled else Qt.PointingHandCursor
        for layer in self.__layerVector:
            layer.setCursor(layerCursor)

    def __deleteLayer(self, layer):
        """Remove *layer*, every line attached to it and any state pointing at it."""
        self.__layerVector = [
            other for other in self.__layerVector if other is not layer
        ]
        # Rebuild the line list instead of deleting by index, so removing
        # two lines can never shift the index of the second one.
        remaining = []
        for outputButton, inputButton in self.__lines:
            if outputButton.parentWidget() is layer:
                inputButton.setConnected(False)
            elif inputButton.parentWidget() is layer:
                outputButton.setConnected(False)
            else:
                remaining.append([outputButton, inputButton])
        self.__lines = remaining
        # Forget the layer wherever it is still referenced, otherwise later
        # calls would touch a deleted widget.
        if self.__currentLayer is layer:
            self.closeParameters()
            self.__currentLayer = None
        if (self.__connectingLayer is not None
                and self.__connectingLayer.parentWidget() is layer):
            self.__cancelConnecting()
        layer.deleteLater()
        self.update()

    def __checkLayers(self, layers):
        """Validate the ordered layers (output removed), warning on the first error.

        Returns:
            True when every layer passes, False after a warning was shown.
        """
        tabular = isinstance(self.dataset, pd.DataFrame)
        for index, layer in enumerate(layers):
            previous = layers[index - 1] if index > 0 else None
            following = layers[index + 1] if index + 1 < len(layers) else None
            isConvolutional = self.__isLayer(layer, "Convolutional")
            isFlatten = self.__isLayer(layer, "Flatten")

            if tabular and isConvolutional:
                self.__warn(
                    "Convolutional layer",
                    "Cannot put a convolutional layer. Dataset is alphanumerical")
                return False
            if tabular and isFlatten:
                self.__warn(
                    "Flatten layer",
                    "Cannot put a flatten layer. Dataset is alphanumerical")
                return False
            if (isConvolutional and previous is not None
                    and self.__isLayer(previous, "Dense")):
                self.__warn(
                    "Dense layer",
                    "A dense layer cannot be followed by a convolutional layer")
                return False
            if isFlatten and (previous is None or
                              not self.__isLayer(previous, "Convolutional")):
                self.__warn(
                    "Flatten layer",
                    "A flatten layer must always be preceded by a convolutional layer")
                return False
            # The last layer here is followed by the (removed) output layer,
            # which is neither flatten nor convolutional.
            if isConvolutional and (
                    following is None or
                    not (self.__isLayer(following, "Flatten")
                         or self.__isLayer(following, "Convolutional"))):
                self.__warn(
                    "Convolutional layer",
                    "A convolutional layer must always be followed by a flatten or convolutional layer")
                return False
            if self.__isLayer(layer, "Dense") and not layer.output():
                self.__warn("Layer output number",
                            "Please select an output number for each layer")
                return False
            if isConvolutional and (not layer.kernelColumns
                                    or not layer.kernelRows):
                self.__warn(
                    "Kernel size",
                    "Please select a kernel size for convolutional layers")
                return False
        return True

    # ------------------------------------------------------------------
    # Slots
    # ------------------------------------------------------------------

    @Slot()
    def closeParameters(self):
        """Hide the parameters box and forget the layer being edited."""
        if self.__parametersShown:
            self.__parametersShown = False
            self.__layerParametersGroupBox.hide()
            self.__layerOutputLineEdit.setText("")
            self.__currentLayer = None

    @Slot()
    def acceptParameters(self):
        """Copy the values typed in the parameters box onto the current layer.

        Empty or non-numeric fields leave that setting unchanged. Kernel
        sizes are only applied to convolutional layers, the only ones whose
        kernel fields are shown.
        """
        layer = self.__currentLayer
        if layer is None:
            return
        output = self.__intFrom(self.__layerOutputLineEdit)
        if output is not None:
            layer.setOutput(output)
        layer.setActivation(self.__activationMenu.currentIndex())
        if self.__isLayer(layer, "Convolutional"):
            kernelRows = self.__intFrom(self.__kernelRowsLineEdit)
            if kernelRows is not None:
                layer.kernelRows = kernelRows
            kernelColumns = self.__intFrom(self.__kernelColumnLineEdit)
            if kernelColumns is not None:
                layer.kernelColumns = kernelColumns

    @Slot()
    def layerAction(self):
        """Handle a click on a layer widget.

        In delete mode the layer and its lines are removed. Otherwise the
        parameters box is opened for dense and convolutional layers; input,
        output and flatten layers have no editable parameters.
        """
        layer = self.sender()
        if self.__deleteMode:
            self.__deleteLayer(layer)
            return
        if any(self.__isLayer(layer, kind)
               for kind in ("Input", "Output", "Flatten")):
            return

        self.__currentLayer = layer
        isConvolutional = self.__isLayer(layer, "Convolutional")
        # Convolutional layers expose kernel sizes instead of an output number
        self.__kernelRowsGroupBox.setVisible(isConvolutional)
        self.__kernelColumnGroupBox.setVisible(isConvolutional)
        self.__layerOutputGroupBox.setVisible(not isConvolutional)
        self.__activationGroupBox.show()

        self.__layerOutputLineEdit.setText(str(layer.output()))
        self.__activationMenu.setCurrentIndex(layer.activation())
        if isConvolutional:
            # Show this layer's kernel size, not the previous layer's input
            self.__kernelRowsLineEdit.setText(
                self.__sizeText(getattr(layer, "kernelRows", None)))
            self.__kernelColumnLineEdit.setText(
                self.__sizeText(getattr(layer, "kernelColumns", None)))

        if not self.__parametersShown:
            self.__layerParametersGroupBox.show()
            self.__layerParametersGroupBox.raise_()
            self.__parametersShown = True

    @Slot()
    def reset(self):
        """Delete every layer and line and return the editor to its idle state."""
        for layer in self.__layerVector:
            layer.deleteLater()
        self.__layerVector = []
        self.__lines = []
        self.__connectingMode = False
        self.__connectingLayer = None
        # Clear the flag too: leaving it True would stop layerAction from
        # ever showing the box again.
        self.__layerParametersGroupBox.hide()
        self.__layerOutputLineEdit.setText("")
        self.__parametersShown = False
        self.__currentLayer = None
        self.update()

    @Slot()
    def confirm(self):
        """Validate the layer chain and emit ``done`` if it is a valid model.

        On success ``modelList`` holds the layers from the input layer
        onwards, with the output layer removed. On failure a warning box is
        shown and ``done`` is not emitted.
        """
        self.modelList = []
        if not self.networkName.text():
            self.__warn("Network name", "Enter a name for the neural network")
            return
        if not self.__lines:
            self.__warn("Invalid model", "No layer connected")
            return
        if len(self.__lines) != len(self.__layerVector) - 1:
            self.__warn("Invalid model", "Some layers are unconnected")
            return

        # Map each layer (by id) to the layer its output feeds, and find
        # the input layer.
        successor = {}
        inputLayer = None
        for outputButton, inputButton in self.__lines:
            source = outputButton.parentWidget()
            successor[id(source)] = inputButton.parentWidget()
            if inputLayer is None and self.__isLayer(source, "Input"):
                inputLayer = source
        if inputLayer is None:
            self.__warn("Invalid model", "No input layer found")
            return

        # Follow the chain from the input layer. The walk always terminates:
        # it stops at a layer with no successor (or on a revisit), so layers
        # caught in a detached loop can no longer hang the GUI.
        chain = [inputLayer]
        visited = {id(inputLayer)}
        nextLayer = successor.get(id(inputLayer))
        while nextLayer is not None and id(nextLayer) not in visited:
            chain.append(nextLayer)
            visited.add(id(nextLayer))
            nextLayer = successor.get(id(nextLayer))
        self.modelList = chain

        if not self.__isLayer(self.modelList[-1], "Output"):
            self.__warn("Invalid model", "No output layer found")
            return
        if len(self.modelList) != len(self.__lines) + 1:
            # Some connected layers are not reachable from the input layer
            self.__warn("Invalid model", "Some layers are unconnected")
            return
        del self.modelList[-1]

        if not self.__checkLayers(self.modelList):
            return

        # Emit signal to switch the window
        self.done.emit()

    @Slot()
    def connectLayer(self):
        """Handle a click on a layer's input/output connector.

        The first click on an unconnected connector enters connecting mode.
        The second click completes the line if the connector is unconnected,
        belongs to another layer and is of the opposite kind (input vs
        output). Invalid second clicks are ignored.
        """
        button = self.sender()
        if not self.__connectingMode:
            if not button.isConnected():
                self.__connectingMode = True
                self.__connectingLayer = button
                # Start drawing the line that follows the cursor
                self.update()
            return

        start = self.__connectingLayer
        startKind = start.objectName()
        compatible = {startKind, button.objectName()} == {"input", "output"}
        if (button.isConnected()
                or button.parentWidget() is start.parentWidget()
                or not compatible):
            return

        # Lines are always stored as [output connector, input connector]
        if startKind == "input":
            self.__lines.append([button, start])
        else:
            self.__lines.append([start, button])
        start.setConnected(True)
        button.setConnected(True)
        self.__cancelConnecting()

    @Slot()
    def createLayer(self):
        """Create a layer widget matching the clicked selection button."""
        source = self.sender()
        layer = LayerButton(source.text(), self)
        layer.setFont(QFont("BebasNeue", 20, QFont.Bold))
        layer.setGeometry(400, 220, 180, 80)
        layer.setObjectName("layer")
        layer.setIcon(QIcon(source.icon()))
        layer.setIconSize(QSize(30, 30))
        # An input layer has nothing upstream, an output layer nothing downstream
        if self.__isLayer(layer, "Input"):
            layer.inputButton.hide()
        elif self.__isLayer(layer, "Output"):
            layer.outputButton.hide()
        layer.setCursor(
            Qt.CrossCursor if self.__deleteMode else Qt.PointingHandCursor)
        layer.clicked.connect(self.layerAction)
        layer.inputButton.clicked.connect(self.connectLayer)
        layer.outputButton.clicked.connect(self.connectLayer)
        layer.show()
        self.__layerVector.append(layer)
        # Keep the parameters box above the new layer
        self.__layerParametersGroupBox.raise_()

    # ------------------------------------------------------------------
    # Event handlers
    # ------------------------------------------------------------------

    # Overloading dragEnterEvent method
    def dragEnterEvent(self, e):
        """Accept every drag; dropEvent decides whether to act on it."""
        e.accept()

    # Overloading dropEvent method
    def dropEvent(self, e):
        """Move a dragged layer widget, if it lands inside the grid area."""
        button = e.source()
        # Ignore drags that are not one of our layers (e.g. text dragged out
        # of a line edit), which lack getCursorX/getCursorY.
        if not any(button is layer for layer in self.__layerVector):
            e.ignore()
            return
        # getCursorX/Y are presumably the grab offset inside the layer widget
        newX = e.pos().x() - button.getCursorX()
        newY = e.pos().y() - button.getCursorY()
        if 350 < newX < 1100 and 120 < newY < 640:
            button.move(QPoint(newX, newY))
            e.setDropAction(Qt.MoveAction)
            e.accept()
            # The layer's lines must follow it
            self.update()

    # Overloading keyPressEvent to activate/deactivate deleting mode
    def keyPressEvent(self, ev):
        """Delete enters delete mode; Escape leaves delete and connecting mode."""
        key = ev.key()
        if not self.__deleteMode and key == Qt.Key_Delete:
            self.__setDeleteMode(True)
        elif self.__deleteMode and key == Qt.Key_Escape:
            self.__setDeleteMode(False)
        if self.__connectingMode and key == Qt.Key_Escape:
            self.__cancelConnecting()

    # Overloading mouse press event to unfocus QLineEdit
    def mousePressEvent(self, event):
        """Drop line-edit focus; in delete mode, remove lines under the click."""
        focusedWidget = QApplication.focusWidget()
        if isinstance(focusedWidget, QLineEdit):
            focusedWidget.clearFocus()
        QWidget.mousePressEvent(self, event)
        if not self.__deleteMode:
            return

        click = event.pos()
        # Build a new list rather than removing while iterating
        remaining = []
        for outputButton, inputButton in self.__lines:
            start = self.__anchor(outputButton)
            end = self.__anchor(inputButton)
            distance = self.__distanceToSegment(click.x(), click.y(),
                                                start.x(), start.y(),
                                                end.x(), end.y())
            if distance <= self._LINE_PICK_TOLERANCE:
                outputButton.setConnected(False)
                inputButton.setConnected(False)
            else:
                remaining.append([outputButton, inputButton])
        if len(remaining) != len(self.__lines):
            self.__lines = remaining
            self.update()

    # Overloading mouse move event to change the cursor if its out of frame
    def mouseMoveEvent(self, ev):
        """Adapt the delete cursor to the grid area; cancel connecting outside it."""
        x = ev.pos().x()
        y = ev.pos().y()
        if self.__deleteMode:
            outside = x < 325 or y < 120
            self.setCursor(Qt.ArrowCursor if outside else Qt.CrossCursor)
        if self.__connectingMode and (x < 340 or y < 130):
            self.__cancelConnecting()

    # Drawing lines
    def paintEvent(self, e):
        """Draw every connection, plus the line following the cursor if connecting."""
        painter = QPainter(self)
        painter.setPen(QPen(QColor(145, 18, 9), self._LINE_WIDTH))
        if self.__connectingMode and self.__connectingLayer is not None:
            painter.drawLine(self.__anchor(self.__connectingLayer),
                             self.mapFromGlobal(QCursor.pos()))
            # Keep repainting only while a line follows the cursor. Every
            # other change to the lines calls update() explicitly, so the
            # window no longer repaints in a permanent loop.
            self.update()
        for outputButton, inputButton in self.__lines:
            painter.drawLine(self.__anchor(outputButton),
                             self.__anchor(inputButton))
        painter.end()
Exemplo n.º 2
0
class TrainingWindow(QWidget):
    """Window that assembles (or reuses) a network and trains it.

    The left panel collects the training parameters (epochs, dataset split
    ratio and mode, learning rate), the right panel plots the training
    accuracy after each epoch, and once training has finished a button opens
    a window with the test accuracy and the confusion matrix.
    """

    # Dataset
    __dataset = None

    # Netwok model
    __modelList = []

    # Start training
    @Slot()
    def startTraining(self):
        """Validate the parameters, train the network and evaluate it.

        Reads the epoch count, learning rate, learning-rate mode and dataset
        split from the parameter widgets, optionally saves the network, runs
        the epoch loop (updating the labels and the accuracy chart after each
        epoch and auto-saving to ``<networkName>_auto``), optionally saves the
        trained network, then evaluates it on the held-out part of the split
        and reveals the confusion-matrix button. Shows a warning and returns
        early when a parameter is missing or cannot be parsed.
        """
        # Validate the form before touching the dataset so an invalid input
        # does not pay for (or consume randomness on) a split.

        # Get epochs number
        if not self.__epochsLineEdit.text():
            QMessageBox.warning(self, "Epochs number",
                                "Please enter the number of epochs",
                                QMessageBox.Ok)
            return
        try:
            epochs = int(self.__epochsLineEdit.text())
        except ValueError:
            # The validator may let partial input (e.g. a lone sign) through.
            QMessageBox.warning(self, "Epochs number",
                                "Please enter the number of epochs",
                                QMessageBox.Ok)
            return

        # Get learning rate value
        if not self.__learningRateLineEdit.text():
            QMessageBox.warning(self, "Learning rate",
                                "Please select a learning rate",
                                QMessageBox.Ok)
            return
        try:
            # Accept a comma as decimal separator as well as a dot
            learning_rate = float(
                self.__learningRateLineEdit.text().replace(",", "."))
        except ValueError:
            # QDoubleValidator accepts "intermediate" text such as "." or
            # "1e" that float() cannot parse.
            QMessageBox.warning(self, "Learning rate",
                                "Please enter a valid learning rate",
                                QMessageBox.Ok)
            return
        if not learning_rate:
            QMessageBox.warning(self, "Learning rate",
                                "The learning rate cannot be equal to zero",
                                QMessageBox.Ok)
            return

        # Get learning rate mode (2 when auto adjustment is checked, else 1)
        mode = 2 if self.__learningRateCheckBox.isChecked() else 1

        # Get split value: combo index 0 -> 0.9, 1 -> 0.8, ... (train share)
        split = 1 - float(
            (self.__datasetSplitComboBox.currentIndex() + 1) / 10.0)

        # Get split method (regular split unless random is selected; the two
        # radio buttons are exclusive and regular is checked by default)
        if self.__datasetSplitRandom.isChecked():
            ((x_train, y_train),
             (x_test,
              y_test)) = self.network.random_split(self.__dataset, split)
        else:
            ((x_train, y_train),
             (x_test,
              y_test)) = self.network.regular_split(self.__dataset, split)

        # Save before training
        ret = QMessageBox.question(
            self, "Network save",
            "Would you like to save the network before the training starts?",
            QMessageBox.Yes | QMessageBox.No)
        if ret == QMessageBox.Yes:
            save_matrix_neural_network(self.network, self.networkName)
            manual_save_model_neural_network(self.network, self.networkName)
            QMessageBox.information(self, "Network save",
                                    "Network successfully saved !",
                                    QMessageBox.Ok)

        # Clearing the graph and sizing the x axis to the epoch count
        self.__xAxis.setRange(0, epochs)
        self.__series.clear()

        # Starting training
        length = len(x_train)
        for i in range(epochs):
            # err is the mean training loss; it is computed but not displayed
            err = 0
            training_accuracy = 0
            l_rate = self.network.Learning_rate_schedule(
                mode, epochs, epochs - i + 1, learning_rate)
            for j in range(length):
                # Forward pass through every layer
                outputs = x_train[j]
                for layer in self.network.layers:
                    outputs = layer.forward_propagation(outputs)
                err += self.network.loss(y_train[j], outputs)
                training_accuracy = training_accuracy + self.network.verification_of_prediction(
                    x_train, y_train, j)
                # Backward pass, updating each layer with the current rate
                error = self.network.loss_prime(y_train[j], outputs)
                for layer in reversed(self.network.layers):
                    error = layer.backward_propagation(error, l_rate)
            err = err / length
            training_accuracy = training_accuracy / float(length)
            self.__epochNumberLabel.setText("Epoch : " + str(i + 1) + "/" +
                                            str(epochs))
            # NOTE(review): this label starts as "Accuracy : " but is
            # rewritten in French here; confirm the intended language.
            self.__trainingAccuracyLabel.setText("Taux de precision : " +
                                                 str(training_accuracy * 100) +
                                                 "%")
            # Appending values to the chart; repaint immediately because the
            # event loop is blocked while this loop runs
            self.__series.append(i, training_accuracy * 100)
            self.__chartView.repaint()
            # Auto saving network
            save_matrix_neural_network(self.network,
                                       self.networkName + "_auto")
            manual_save_model_neural_network(self.network,
                                             self.networkName + "_auto")

        # Saving trained network
        ret = QMessageBox.question(
            self, "Network save",
            "Would you like to save the trained network? ",
            QMessageBox.Yes | QMessageBox.No)
        if ret == QMessageBox.Yes:
            save_matrix_neural_network(self.network, self.networkName)
            manual_save_model_neural_network(self.network, self.networkName)
            QMessageBox.information(self, "Network save",
                                    "Network successfully saved !",
                                    QMessageBox.Ok)

        # Evaluate network and show confusion matrix
        (self.test_accuracy,
         self.matrix) = self.network.evaluate(x_test, y_test)
        self.__confusionMatrixButton.show()

    # Showing the confusion matrix
    @Slot()
    def showStats(self):
        """Open a window with the confusion matrix and the test accuracy.

        Uses ``self.matrix`` and ``self.test_accuracy`` set at the end of
        ``startTraining``. Row and column headers are the ``classes`` keys
        whose value equals the row/column index.
        """
        # Creating matrix window
        self.matrixWindow = QMainWindow()
        self.matrixWindow.setFixedSize(640, 480)

        key_list = list(self.__classes.keys())
        val_list = list(self.__classes.values())

        # Creating matrix table (one extra row and column for the headers)
        self.matrixTable = QTableWidget(
            len(self.matrix) + 1,
            len(self.matrix[0]) + 1)
        for i in range(len(self.matrix)):
            # Name of the class whose index is i
            label = str(key_list[val_list.index(i)])
            self.matrixTable.setItem(i + 1, 0, QTableWidgetItem(label))
            self.matrixTable.setItem(0, i + 1, QTableWidgetItem(label))

        for i in range(len(self.matrix)):
            for j in range(len(self.matrix[0])):
                self.matrixTable.setItem(
                    i + 1, j + 1, QTableWidgetItem(str(self.matrix[i][j])))

        # Printing test accuracy
        self.matrixLabel = QLabel(
            "Test accuracy : " + str(self.test_accuracy * 100) + "%", self)
        self.matrixLabel.setFont(QFont("BebasNeue", 20, QFont.Bold))

        # Matrix window layout
        self.matrixLayout = QVBoxLayout()
        self.matrixLayout.addWidget(self.matrixTable)
        self.matrixLayout.addWidget(self.matrixLabel)

        # Matrix window groupbox
        self.matrixGroupBox = QGroupBox(self.matrixWindow)
        self.matrixGroupBox.setLayout(self.matrixLayout)

        # Showing the matrix window
        self.matrixWindow.setCentralWidget(self.matrixGroupBox)
        self.matrixWindow.show()

    def __init__(self, ds, classes, model, created, *args, **kwargs):
        """Build the network if needed, then lay out the training widgets.

        Args:
            ds: dataset handed to the network's split helpers.
            classes: mapping used by ``showStats`` to label the confusion
                matrix; its values are looked up by row index, so they are
                presumably class indices (TODO confirm against caller).
            model: when ``created`` is true, ``model[0]`` is the list of
                layer widgets from the creation window; otherwise it is an
                already-built network object. ``model[1]`` is the network
                name used when saving.
            created: whether the network must be assembled from ``model[0]``.
        """
        super(TrainingWindow, self).__init__(*args, **kwargs)

        if created:
            # Initialising network
            self.network = Network()
            self.network.use(mean_squared_error, mean_squared_error_prime)
        else:
            self.network = model[0]

        # Getting inputs and outputs
        self.__dataset = ds
        self.__classes = classes
        #fill_missing_values(self.__dataset)
        #min_max_normalize_dataset(self.__dataset)
        ((x_train, y_train),
         (x_test, y_test)) = self.network.regular_split(self.__dataset, 0.5)

        # Getting inputs: a feature count for 2-D data, otherwise the
        # per-sample shape and its first three dimensions.
        # NOTE(review): first/second/third are only defined for non 2-D data,
        # so Flatten/Convolutional layers on 2-D data raise NameError below.
        if len(x_train.shape) == 2:
            inputs = x_train.shape[1]
        else:
            inputs = x_train.shape[1:]
            first = inputs[0]
            second = inputs[1]
            third = inputs[2]

        # Getting expected outputs
        expected_output = y_train.shape[1]

        # Getting network name
        self.networkName = model[1]

        if created:
            # Getting model list (layer widgets, the first one is the input)
            self.__modelList = model[0]
            self.__modelList[0].setOutput(inputs)

            # Activation menu index -> (function, derivative)
            activations = {
                0: (sigmoid, sigmoid_prime),
                1: (tanh, tanh_prime),
                2: (rectified_linear_unit, rectified_linear_unit_prime),
                3: (softmax, softmax_prime),
            }

            for i in range(1, len(self.__modelList)):
                layer_widget = self.__modelList[i]
                previous_widget = self.__modelList[i - 1]
                # Getting the layer name (button text minus its last 6 chars)
                name = layer_widget.text(
                )[:len(layer_widget.text()) - 6]
                # Getting the activation function (None for unknown indices)
                activation, activ_prime = activations.get(
                    layer_widget.activation(), (None, None))
                # Adding layer to the network
                if name == "Dense":
                    if previous_widget.text()[:2] == "Fl":
                        # Right after Flatten: input is the flattened volume
                        n_inputs = first * second * third
                    else:
                        n_inputs = previous_widget.output()
                    self.network.add(
                        FullyConnectedLayer(n_inputs, layer_widget.output()))
                    self.network.add(ActivationLayer(activation, activ_prime))
                elif name == "Flatten":
                    self.network.add(FlattenLayer())
                elif name == "Convolutional":
                    self.network.add(
                        ConvLayer((first, second, third),
                                  (layer_widget.kernelRows,
                                   layer_widget.kernelColumns), 1))
                    self.network.add(ActivationLayer(activation, activ_prime))
                    # Track the reduced size: each dimension shrinks by
                    # (kernel size - 1)
                    first = first - layer_widget.kernelRows + 1
                    second = second - layer_widget.kernelColumns + 1

            # Output layer sized to the number of expected outputs
            self.network.add(
                FullyConnectedLayer(self.__modelList[-1].output(),
                                    expected_output))
            self.network.add(ActivationLayer(sigmoid, sigmoid_prime))

        # Loading Fonts
        QFontDatabase.addApplicationFont("fonts/BebasNeue-Light.ttf")

        # Window Settings
        self.setFixedSize(1280, 720)
        self.setWindowTitle("Training window")
        #background = QPixmap("images/menu")
        #palette = QPalette()
        #palette.setBrush(QPalette.Background, background)
        #self.setAttribute(Qt.WA_StyledBackground, True)
        #self.setPalette(palette)
        self.setAutoFillBackground(True)

        # Stylesheet Settings
        # readAll() returns a QByteArray, and str() of it yields its bytes
        # repr ("b'...'") rather than the file text, so decode explicitly.
        styleFile = QFile("stylesheets/training.qss")
        if styleFile.open(QFile.ReadOnly):
            style = bytes(styleFile.readAll()).decode("utf-8")
            styleFile.close()
            self.setStyleSheet(style)

        # Title Settings
        self.title = QLabel("Training", self)
        self.title.setFont(QFont("BebasNeue", 30, QFont.Bold))
        self.title.setAlignment(Qt.AlignCenter)
        self.title.setGeometry(600, 10, 300, 120)

        # Epochs line edit settings
        self.__epochsLineEdit = QLineEdit(self)
        self.__epochsLineEdit.setValidator(QIntValidator(0, 100000, self))

        # Epochs label settings
        self.__epochsLabel = QLabel("Epoch number", self)
        self.__epochsLabel.setFont(QFont("BebasNeue", 20, QFont.Bold))

        # Learning rate line edit settings
        self.__learningRateLineEdit = QLineEdit(self)
        self.__learningRateLineEdit.setValidator(
            QDoubleValidator(0.0, 1.0, 3, self))

        # Learning rate label settings
        self.__learningRateLabel = QLabel("Learning rate", self)
        self.__learningRateLabel.setFont(QFont("BebasNeue", 20, QFont.Bold))

        # Learning rate checkboxsettings (auto or not)
        self.__learningRateCheckBox = QCheckBox("Auto adjustment", self)
        self.__learningRateCheckBox.setFont(QFont("BebasNeue", 15, QFont.Bold))

        # Dataset split settings label
        self.__datasetSplitLabel = QLabel("Dataset split percentage", self)
        self.__datasetSplitLabel.setFont((QFont("BebasNeue", 20, QFont.Bold)))

        # Dataset split mode buttons
        self.__datasetSplitRegular = QRadioButton("Regular split")
        self.__datasetSplitRandom = QRadioButton("Random split")

        # Dataset split mode buttons groupbox
        self.__datasetSplitModeButtonsLayout = QHBoxLayout(self)
        self.__datasetSplitModeButtonsGroupBox = QGroupBox(self)
        self.__datasetSplitModeButtonsGroupBox.setObjectName("setting")
        self.__datasetSplitModeButtonsLayout.addWidget(
            self.__datasetSplitRegular)
        self.__datasetSplitModeButtonsLayout.addWidget(
            self.__datasetSplitRandom)
        self.__datasetSplitModeButtonsGroupBox.setLayout(
            self.__datasetSplitModeButtonsLayout)
        self.__datasetSplitRegular.setChecked(True)

        # Dataset split combo box settings
        self.__datasetSplitComboBox = QComboBox(self)
        self.__datasetSplitComboBox.addItems(
            ['90% - 10%', '80% - 20%', '70% - 30%', '60% - 40%'])

        # Dataset split form layout settings
        self.__datasetSplitLayout = QFormLayout(self)
        self.__datasetSplitGroupBox = QGroupBox(self)
        self.__datasetSplitGroupBox.setObjectName("setting")
        self.__datasetSplitLayout.addWidget(self.__datasetSplitLabel)
        self.__datasetSplitLayout.addWidget(self.__datasetSplitComboBox)
        self.__datasetSplitGroupBox.setLayout(self.__datasetSplitLayout)

        # Epochs form layout settings
        self.__epochsFormLayout = QFormLayout(self)
        self.__epochsGroupBox = QGroupBox(self)
        self.__epochsGroupBox.setObjectName("setting")
        self.__epochsFormLayout.addWidget(self.__epochsLabel)
        self.__epochsFormLayout.addWidget(self.__epochsLineEdit)
        self.__epochsGroupBox.setLayout(self.__epochsFormLayout)

        # Learning rate form layout settings
        self.__learningRateFormLayout = QFormLayout(self)
        self.__learningRateGroupBox = QGroupBox(self)
        self.__learningRateGroupBox.setObjectName("setting")
        self.__learningRateFormLayout.addWidget(self.__learningRateLabel)
        self.__learningRateFormLayout.addWidget(self.__learningRateCheckBox)
        self.__learningRateFormLayout.addWidget(self.__learningRateLineEdit)
        self.__learningRateGroupBox.setLayout(self.__learningRateFormLayout)

        # Epochs number label
        self.__epochNumberLabel = QLabel("Epoch : ", self)
        self.__epochNumberLabel.setFont((QFont("BebasNeue", 15, QFont.Bold)))

        # Training accuracy label
        self.__trainingAccuracyLabel = QLabel("Accuracy : ", self)
        self.__trainingAccuracyLabel.setFont((QFont("BebasNeue", 15,
                                                    QFont.Bold)))

        # Training stats layout
        self.__trainingStatsLayout = QVBoxLayout(self)
        self.__trainingStatsGroupBox = QGroupBox(self)
        self.__trainingStatsLayout.addWidget(self.__epochNumberLabel)
        self.__trainingStatsLayout.addWidget(self.__trainingAccuracyLabel)
        self.__trainingStatsGroupBox.setLayout(self.__trainingStatsLayout)
        self.__trainingStatsGroupBox.setGeometry(1000, -30, 300, 150)

        # Training button settings
        self.__trainingButton = QPushButton("Start", self)
        self.__trainingButton.setCursor(Qt.PointingHandCursor)
        self.__trainingButton.setFont((QFont("BebasNeue", 30, QFont.Bold)))
        self.__trainingButton.clicked.connect(self.startTraining)

        # Go back button
        self.goBackButton = QPushButton("Back", self)
        self.goBackButton.setObjectName("retour")

        # Customising go back button
        self.goBackButton.setCursor(Qt.PointingHandCursor)
        self.goBackButton.setIcon(QIcon("images/goback_icon"))
        self.goBackButton.setIconSize(QSize(30, 30))
        self.goBackButton.setFont(QFont("BebasNeue", 20, QFont.Bold))

        # Confusion matrix button (hidden until a training run finishes)
        self.__confusionMatrixButton = QPushButton("Show confusion matrix",
                                                   self)
        self.__confusionMatrixButton.setCursor(Qt.PointingHandCursor)
        self.__confusionMatrixButton.setFont((QFont("BebasNeue", 17,
                                                    QFont.Bold)))
        self.__confusionMatrixButton.clicked.connect(self.showStats)
        self.__confusionMatrixButton.setGeometry(420, 20, 250, 80)
        self.__confusionMatrixButton.hide()

        # Parameters group box settings
        self.__parametersGroupBox = QGroupBox("Training parameters", self)
        self.__parametersGroupBox.setObjectName("parameters")
        self.__parametersLayout = QVBoxLayout(self)
        self.__parametersLayout.addWidget(self.__epochsGroupBox)
        self.__parametersLayout.addWidget(self.__datasetSplitGroupBox)
        self.__parametersLayout.addWidget(
            self.__datasetSplitModeButtonsGroupBox)
        self.__parametersLayout.addWidget(self.__learningRateGroupBox)
        self.__parametersLayout.addWidget(self.__trainingButton)
        self.__parametersLayout.addWidget(self.goBackButton)
        self.__parametersGroupBox.setLayout(self.__parametersLayout)
        self.__parametersGroupBox.setGeometry(0, 0, 400, 720)

        # Chart axis settings
        self.__xAxis = QtCharts.QValueAxis()
        self.__xAxis.setRange(0, 5)

        self.__yAxis = QtCharts.QValueAxis()
        self.__yAxis.setRange(0, 100)

        # Chart settings
        self.__series = QtCharts.QLineSeries()
        self.__chart = QtCharts.QChart()
        self.__chart.addAxis(self.__xAxis, Qt.AlignBottom)
        self.__chart.addAxis(self.__yAxis, Qt.AlignLeft)
        self.__chart.addSeries(self.__series)
        self.__series.attachAxis(self.__xAxis)
        self.__series.attachAxis(self.__yAxis)
        self.__chart.setTitle("Accuracy")
        self.__chartView = QtCharts.QChartView(self.__chart)
        self.__chartView.setRenderHint(QPainter.Antialiasing)

        # Chart layout settings
        self.__chartLayout = QVBoxLayout(self)
        self.__chartGroupBox = QGroupBox(self)
        self.__chartGroupBox.setObjectName("chart")
        self.__chartLayout.addWidget(self.__chartView)
        self.__chartGroupBox.setLayout(self.__chartLayout)
        self.__chartGroupBox.setGeometry(390, 100, 900, 600)

        # Update timer settings
        #self.__timer = QTimer(self)
        #self.__timer.timeout.connect(self.autoSave)
        #self.__timer.start(1000)


#app = QApplication(sys.argv)
#window = TrainingWindow()
#window.show()
#app.exec_()