Ejemplo n.º 1
0
def main():
    """Command-line entry point: obtain a network, then run the input tests.

    Training samples and the character list come from ``read_trains()``.
    When a file already exists at ``TRAINED_FPATH`` the network is loaded
    from it. Otherwise the network is given the training samples, trained
    via ``train(net, DESIRED_ACCURACY)`` and saved to ``TRAINED_FPATH`` so
    later runs can skip training. Finally ``test_inputs`` is called with the
    network and the character list.
    """
    training_set, characters = read_trains()
    net = NeuralNetwork()

    trained_file_exists = os.path.exists(TRAINED_FPATH)
    if not trained_file_exists:
        # No cached network yet: train from scratch and persist the result.
        net.set_trains(training_set)
        train(net, DESIRED_ACCURACY)
        net.save_trained_file(TRAINED_FPATH)
    else:
        net.load_trained_file(TRAINED_FPATH)

    test_inputs(net, characters)
Ejemplo n.º 2
0
class MainWnd(QWidget):
    """Main window of the optical character recognition demo.

    Shows a drawing ``Grid`` on the left and a side panel holding
    Next / Previous / Clear / Train buttons plus a label that displays the
    character ``test_input`` reports for the grid's current contents.
    """

    # Emitted from the training worker thread. The receiving window lives in
    # the GUI thread, so Qt's default (auto) connection queues these and the
    # connected slots run on the GUI thread, where touching widgets is safe.
    train_done = Signal()
    train_failed = Signal(str)

    def __init__(self, parent=None):
        super(MainWnd, self).__init__(parent)

        self.grid = Grid(self)
        self.grid.move(10, 10)

        # Side panel: one widget per row, stacked top to bottom.
        self.next_g = self._add_button("&Next", self.on_next, 0)
        self.prev_g = self._add_button("&Previous", self.on_prev, 1)
        self.clear_g = self._add_button("&Clear", self.on_clear, 2)
        self.train_g = self._add_button("&Train", self.on_train, 3)

        self.out_g = QLabel(self)
        self.out_g.setAlignment(Qt.AlignCenter)
        self.out_g.setFont(QFont(None, 30))
        self._place_panel_item(self.out_g, 4)

        # Missing training data is tolerated: fall back to empty lists.
        try:
            self.trains, self.chs = read_trains()
        except IOError:
            self.trains, self.chs = [], []

        # Best effort: start from a previously saved network if one exists.
        self.n_net = NeuralNetwork()
        try:
            self.n_net.load_trained_file(TRAINED_FPATH)
        except IOError:
            pass

        try:
            self.tests = read_tests()
            random.shuffle(self.tests)
        except IOError:
            self.tests = []
        self.test_idx = -1
        self.cur_test = None
        if self.tests:
            self.set_test_idx(0)

        self.train_done.connect(self.on_train_done)
        self.train_failed.connect(self.on_train_failed)

        self.setWindowTitle("Optical Character Recognition")
        self.resize(self.grid.width() + SPACING * 3 + PANEL_ITEM_W,
                    self.grid.height() + SPACING * 2)
        self.show()

    def _place_panel_item(self, widget, row):
        """Move, size and show *widget* in side-panel slot *row* (0-based)."""
        x = self.grid.width() + SPACING * 2
        y = SPACING * (row + 1) + PANEL_ITEM_H * row
        widget.move(x, y)
        widget.resize(PANEL_ITEM_W, PANEL_ITEM_H)
        widget.show()

    def _add_button(self, text, slot, row):
        """Create a push button wired to *slot*, place it at *row* and return it."""
        button = QPushButton(text, self)
        button.clicked.connect(slot)
        self._place_panel_item(button, row)
        return button

    def set_test_idx(self, idx):
        """Load test number *idx* onto the grid and classify it.

        Raises:
            IndexError: if *idx* is outside ``self.tests``.
        """
        if not 0 <= idx < len(self.tests):
            raise IndexError("out of range")

        self.cur_test = self.tests[idx]
        self.test_idx = idx

        # NOTE(review): assumes each test is a sequence whose item 1 is the
        # vector Grid.load_vec expects -- confirm against read_tests.
        self.grid.load_vec(self.cur_test[1])
        self.test_grid()

    def on_next(self):
        """Advance to the next test, warning when there is none."""
        try:
            self.set_test_idx(self.test_idx + 1)
        except IndexError:
            QMessageBox.warning(self, None, "No more tests.")

    def on_prev(self):
        """Go back to the previous test, warning when there is none."""
        try:
            self.set_test_idx(self.test_idx - 1)
        except IndexError:
            QMessageBox.warning(self, None, "No more tests.")

    def on_clear(self):
        """Clear the grid and re-run recognition on the empty grid."""
        self.grid.clear()
        self.test_grid()

    def on_train(self):
        """Train and save the network in a background thread.

        The Train button is disabled until the run finishes (``train_done``)
        or raises (``train_failed``), so two threads never train the same
        network at the same time.
        """
        import threading
        import traceback

        self.train_g.setEnabled(False)
        QMessageBox.warning(self, None, "Training sequence started. May take several minutes to finish.")

        def proc():
            try:
                self.n_net.set_trains(self.trains)
                train(self.n_net, DESIRED_ACCURACY)
                self.n_net.save_trained_file(TRAINED_FPATH)
            except Exception as exc:
                # Thread boundary: nothing above can handle this, so keep the
                # traceback on stderr and tell the GUI the run is over.
                traceback.print_exc()
                self.train_failed.emit(str(exc))
            else:
                self.train_done.emit()

        # NOTE(review): Next/Previous/Clear still call test_grid() on
        # self.n_net from the GUI thread while this worker is training it.
        worker = threading.Thread(target=proc)
        # Daemon thread: closing the window must not wait for training.
        worker.daemon = True
        worker.start()

    @Slot()
    def on_train_done(self):
        """Report a finished training run and allow another one."""
        self.train_g.setEnabled(True)
        QMessageBox.warning(self, None, "Training done.")

    @Slot(str)
    def on_train_failed(self, message):
        """Report a training run that raised and allow another attempt."""
        self.train_g.setEnabled(True)
        QMessageBox.warning(self, None, "Training failed: %s" % message)

    def test_grid(self):
        """Classify the grid's current vector and show the result.

        ``"?"`` is shown when ``test_input`` raises ValueError.
        """
        try:
            ch = test_input(self.n_net, self.grid.get_vec(), self.chs)
        except ValueError:
            ch = "?"

        self.out_g.setText(ch)