def main():
    """Load or train the network, then run ``test_inputs`` on it.

    The training samples and their character list are always read via
    ``read_trains()``. If a file exists at ``TRAINED_FPATH`` the network
    state is loaded from it; otherwise the network is trained to
    ``DESIRED_ACCURACY`` and the result is saved to ``TRAINED_FPATH``.
    """
    samples, labels = read_trains()
    network = NeuralNetwork()

    if not os.path.exists(TRAINED_FPATH):
        # No saved state yet: train from scratch and persist the result.
        network.set_trains(samples)
        train(network, DESIRED_ACCURACY)
        network.save_trained_file(TRAINED_FPATH)
    else:
        network.load_trained_file(TRAINED_FPATH)

    test_inputs(network, labels)
class MainWnd(QWidget):
    """Main window of the "Optical Character Recognition" GUI.

    Shows a ``Grid`` widget on the left and, to its right, a vertical
    panel with Next / Previous / Clear / Train buttons and a label that
    displays the character returned by ``test_input`` for the current
    grid contents.
    """

    # Emitted by the background training thread when training has finished
    # and the trained state has been saved.
    train_done = Signal()

    def __init__(self, parent=None):
        """Build the widgets, load training/test data and show the window.

        Missing data files (``IOError``) are tolerated: the corresponding
        collections are left empty and the network is left untrained.
        """
        super(MainWnd, self).__init__(parent)

        self.grid = Grid(self)
        self.grid.move(10, 10)

        # Right-hand control panel, one item per row from top to bottom.
        self.next_g = self._add_button("&Next", self.on_next, 0)
        self.prev_g = self._add_button("&Previous", self.on_prev, 1)
        self.clear_g = self._add_button("&Clear", self.on_clear, 2)
        self.train_g = self._add_button("&Train", self.on_train, 3)

        # Label showing the recognised character.
        self.out_g = QLabel(self)
        self.out_g.setAlignment(Qt.AlignCenter)
        self.out_g.setFont(QFont(None, 30))
        self._place_in_panel(self.out_g, 4)

        try:
            self.trains, self.chs = read_trains()
        except IOError:
            self.trains, self.chs = [], []

        self.n_net = NeuralNetwork()
        try:
            self.n_net.load_trained_file(TRAINED_FPATH)
        except IOError:
            # Best effort: no readable trained file, so the network keeps
            # whatever state NeuralNetwork() starts with.
            pass

        try:
            self.tests = read_tests()
            random.shuffle(self.tests)
        except IOError:
            self.tests = []

        # -1 / None mean "no test selected yet".
        self.test_idx = -1
        self.cur_test = None
        if self.tests:
            self.set_test_idx(0)

        self.train_done.connect(self.on_train_done)

        self.setWindowTitle("Optical Character Recognition")
        self.resize(self.grid.width() + SPACING * 3 + PANEL_ITEM_W,
                    self.grid.height() + SPACING * 2)
        self.show()

    def _place_in_panel(self, widget, row):
        """Move, size and show ``widget`` in panel row ``row`` (0-based)."""
        widget.move(self.grid.width() + SPACING * 2,
                    SPACING * (row + 1) + PANEL_ITEM_H * row)
        widget.resize(PANEL_ITEM_W, PANEL_ITEM_H)
        widget.show()

    def _add_button(self, text, handler, row):
        """Create a push button wired to ``handler`` in panel row ``row``."""
        button = QPushButton(text, self)
        button.clicked.connect(handler)
        self._place_in_panel(button, row)
        return button

    def set_test_idx(self, idx):
        """Select test number ``idx``, load it into the grid and classify it.

        Raises:
            IndexError: if ``idx`` is outside ``0 .. len(self.tests) - 1``.
        """
        if idx < 0 or idx >= len(self.tests):
            # Call form of raise: valid on both Python 2 and Python 3.
            raise IndexError("out of range")
        self.cur_test = self.tests[idx]
        self.test_idx = idx
        # Element 1 of a test entry is what the grid displays
        # (presumably the sample's pixel vector -- confirm in read_tests).
        self.grid.load_vec(self.cur_test[1])
        self.test_grid()

    def on_next(self):
        """Advance to the next test, warning the user if there is none."""
        try:
            self.set_test_idx(self.test_idx + 1)
        except IndexError:
            QMessageBox.warning(self, None, "No more tests.")

    def on_prev(self):
        """Go back to the previous test, warning the user if there is none."""
        try:
            self.set_test_idx(self.test_idx - 1)
        except IndexError:
            QMessageBox.warning(self, None, "No more tests.")

    def on_clear(self):
        """Clear the grid and re-run recognition on the now-empty grid."""
        self.grid.clear()
        self.test_grid()

    def on_train(self):
        """Notify the user, then train the network on a background thread.

        The worker sets the training data, trains to ``DESIRED_ACCURACY``,
        saves the result to ``TRAINED_FPATH`` and emits ``train_done``.
        """
        QMessageBox.warning(self, None, "Training sequence started. May take several minitues to finish.")

        def proc():
            self.n_net.set_trains(self.trains)
            train(self.n_net, DESIRED_ACCURACY)
            self.n_net.save_trained_file(TRAINED_FPATH)
            # Completion is reported through the signal rather than by
            # touching widgets from this worker thread.
            self.train_done.emit()

        # NOTE(review): nothing prevents a second click from starting
        # another training thread on the same network while one is running.
        thread.start_new_thread(proc, ())

    @Slot()
    def on_train_done(self):
        """Tell the user that background training has completed."""
        QMessageBox.warning(self, None, "Training done.")

    def test_grid(self):
        """Classify the current grid contents and show the result.

        Displays "?" when ``test_input`` raises ``ValueError``.
        """
        try:
            ch = test_input(self.n_net, self.grid.get_vec(), self.chs)
        except ValueError:
            ch = "?"
        self.out_g.setText(ch)