Ejemplo n.º 1
0
 def predict_mode(self):
     """Toggle between training mode and prediction mode.

     Flips ``self.train_mode``, pushes a warning-level 'switch' entry that
     names the mode just entered to ``self.logger_test``, and resets the
     ``self.t2`` counter to 0.
     """
     entering_predict = self.train_mode is True
     self.train_mode = not entering_predict
     message = '切换到预测模式' if entering_predict else '切换到训练模式'
     self.logger_test.push(UiLogger.Item(UiLogger.LEVEL_WARNING, 'switch', message))
     self.t2 = 0
Ejemplo n.º 2
0
    def __init__(self, root=None):
        """Run the connection dialog, open both handles and start the workers.

        Shows the connection-settings dialog, opens one ``BaseComm`` per
        handle, averages a few readings to obtain the stick centre values,
        builds the main window and launches the reader and parser threads.

        Args:
            root: Optional Tk root hosting the main window. A new ``Tk()`` is
                created when it is None.
        """
        self.init_top = Tk()

        self.port_left = 'COM4'
        self.port_right = 'COM5'
        # Default baud rate. It is assigned BEFORE the dialog runs so that a
        # value stored in self.bps while the dialog is open is not clobbered
        # afterwards (the original reset it to 115200 after the dialog).
        self.bps = 115200

        self.init_bps = StringVar()
        self.init_bps.set('115200')
        self.init_com_left = StringVar()
        self.init_com_left.set(self.port_left)
        self.init_com_right = StringVar()
        self.init_com_right.set(self.port_right)

        self.init_communication()

        self.comm = None
        self.n = 512
        self.select = 24
        # Buffers pre-filled with zeros: n frames of 12 values, and n raw
        # [left, right] samples of 10 values each.
        self.frames = [[0] * 12 for _ in range(self.n)]
        self.raw = [[[0] * 10, [0] * 10] for _ in range(self.n)]

        # Neural network model file (the model itself is loaded by the
        # parser thread, not here).
        self.model_file = 'mc_actions.h5'

        # Open the connections to both handles
        self.comm_left = BaseComm(self.init_com_left.get(), self.bps)
        self.comm_right = BaseComm(self.init_com_right.get(), self.bps)

        # Sensitivity
        self.sensitivity = 0.3

        # Read a few epochs first (original note: this consumes the Init flag)
        print('预读取:')
        for _ in range(10):
            r1 = self.comm_left.read1epoch()
            r2 = self.comm_right.read1epoch()
            print(r1, r2)

        print('请保持手柄水平放置不动')
        # Stick centre values, averaged over `pick` readings: [X, Y] per handle
        self.ave_left, self.ave_right = [0, 0], [0, 0]
        pick = 10
        for _ in range(pick):
            # Left is read before right on every round, as in the original.
            for comm, ave in ((self.comm_left, self.ave_left),
                              (self.comm_right, self.ave_right)):
                data_ctrl = list(map(int, comm.read1epoch()[-4:]))
                ave[0] += data_ctrl[2] / pick
                ave[1] += data_ctrl[1] / pick

        print('初始化遥感中点:', self.ave_left, self.ave_right)

        self.root = root
        if self.root is None:
            self.root = Tk()
        self.root.title("MC手柄")

        self.logger = UiLogger(self.root, height=10, width=32)
        self.logger.logger().grid(row=2, column=1, sticky=W+E)
        self.panel = Label(self.root)
        self.panel.grid(row=1, column=1, sticky=W+E)

        self.lock = threading.Lock()

        # daemon=True replaces the deprecated Thread.setDaemon(True).
        for worker in (self.read_thread, self.parse_thread):
            threading.Thread(target=worker, daemon=True).start()
Ejemplo n.º 3
0
    def parse_thread(self):
        """Worker loop: refresh the plot, emit input events and run inference.

        Never returns. Each pass sleeps 10 ms and then:
          * about every 5th pass, redraws the plot into ``self.panel``;
          * reads the most recent raw sample under ``self.lock``;
          * sends the right handle's stick offset (relative to its averaged
            centre, scaled by ``self.sensitivity``) to ``ctrl.move`` and maps
            its button value to left mouse down/up;
          * maps ``data_left[6]`` to SPACE and the left stick to W/A/S/D;
          * every 15th pass, feeds the last ``self.select`` frames to the
            model and logs the predicted action.
        """
        # Load the neural network model
        model = load_model(self.model_file)

        t1, t2 = 0, 0

        click = False
        key1 = ctrl.ACTION_NONE
        key2 = ctrl.ACTION_NONE
        jump = ctrl.ACTION_NONE

        start = time.time()

        while True:
            if t1 == 5:
                im = self.draw()
                imp = ImageTk.PhotoImage(image=im)
                self.panel.configure(image=imp)
                # Keep a reference on the widget so the image is not collected.
                self.panel.image = imp
                t1 = 0
            t1 += 1

            time.sleep(0.01)

            # `with` guarantees the lock is released even if indexing fails.
            with self.lock:
                data_left = self.raw[-1][0]
                data_right = self.raw[-1][1]

            # Right handle
            right_ctrl = list(map(int, data_right[-4:]))
            # Centre the RIGHT stick with the RIGHT handle's averaged centre.
            # The original subtracted self.ave_left here, leaving ave_right
            # unused.
            ctrl.move((right_ctrl[2] - self.ave_right[0]) * self.sensitivity,
                      (right_ctrl[1] - self.ave_right[1]) * self.sensitivity)
            if right_ctrl[0] == 0 and click is False:
                ctrl.left_down()
                click = True
            if right_ctrl[0] == 1 and click is True:
                ctrl.left_up()
                click = False
            # Jump: SPACE is held while data_left[6] reads 0
            if data_left[6] == 0 and jump == ctrl.ACTION_NONE:
                jump = ctrl.ACTION_UP
                ctrl.kbd_down(VirtualKeyCode.SPACEBAR)
            if data_left[6] == 1 and jump == ctrl.ACTION_UP:
                jump = ctrl.ACTION_NONE
                ctrl.kbd_up(VirtualKeyCode.SPACEBAR)

            # Left handle: the two middle values of the last four entries
            pos = data_left[-4:][1:3]
            # D key: held while pos[1] > 800
            if pos[1] > 800 and key1 == ctrl.ACTION_NONE:
                key1 = ctrl.ACTION_D
                ctrl.kbd_down(VirtualKeyCode.D_key)
            if pos[1] <= 800 and key1 == ctrl.ACTION_D:
                key1 = ctrl.ACTION_NONE
                ctrl.kbd_up(VirtualKeyCode.D_key)
            # A key: held while pos[1] < 200
            if pos[1] < 200 and key1 == ctrl.ACTION_NONE:
                key1 = ctrl.ACTION_A
                ctrl.kbd_down(VirtualKeyCode.A_key)
            if pos[1] >= 200 and key1 == ctrl.ACTION_A:
                key1 = ctrl.ACTION_NONE
                ctrl.kbd_up(VirtualKeyCode.A_key)

            # W key: held while pos[0] < 200
            if pos[0] < 200 and key2 == ctrl.ACTION_NONE:
                key2 = ctrl.ACTION_W
                ctrl.kbd_down(VirtualKeyCode.W_key)
            if pos[0] >= 200 and key2 == ctrl.ACTION_W:
                key2 = ctrl.ACTION_NONE
                ctrl.kbd_up(VirtualKeyCode.W_key)
            # S key: held while pos[0] > 800
            if pos[0] > 800 and key2 == ctrl.ACTION_NONE:
                key2 = ctrl.ACTION_S
                ctrl.kbd_down(VirtualKeyCode.S_key)
            if pos[0] <= 800 and key2 == ctrl.ACTION_S:
                key2 = ctrl.ACTION_NONE
                ctrl.kbd_up(VirtualKeyCode.S_key)

            # Neural network classification, only once every 15 passes
            t2 += 1
            if t2 == 15:
                t2 = 0
                with self.lock:
                    x = np.array(self.frames[len(self.frames) - self.select:])
                # Flatten the window into a single row for the model.
                x = x.reshape((1, x.size))
                predict = model.predict(x=x)[0]
                predict = predict.tolist()
                res = predict.index(max(predict))
                res = self.ACTIONS[res]
                self.logger.push(UiLogger.Item(UiLogger.LEVEL_INFO, 'predict %.2f' % (time.time() - start), '%s' % res))
Ejemplo n.º 4
0
class MCHandle:
    """Tk application that turns two serial handles into mouse/keyboard input.

    A reader thread keeps the newest ``n`` samples in ``raw`` and ``frames``;
    a parser thread converts the latest sample into ``ctrl`` events, redraws
    the plot and periodically classifies the recent frames with a saved model.
    """

    ACTION_NONE = '无动作'
    ACTION_FORWARD = '前进'
    ACTION_JUMP = '起跳'
    ACTION_DOWN = '下降'
    ACTION_HIT = '打击'
    ACTION_PUT = '放置'
    # Index in this list == index of the model's output unit.
    ACTIONS = [ACTION_NONE, ACTION_FORWARD, ACTION_JUMP, ACTION_DOWN, ACTION_HIT, ACTION_PUT]

    def __init__(self, root=None):
        """Run the connection dialog, open both handles and start the workers.

        Args:
            root: Optional Tk root hosting the main window. A new ``Tk()`` is
                created when it is None.
        """
        self.init_top = Tk()

        self.port_left = 'COM4'
        self.port_right = 'COM5'
        # Default baud rate. Assigned BEFORE the dialog runs so the value
        # stored by init_communication_ok() is not overwritten afterwards
        # (the original reset it to 115200 after the dialog had closed).
        self.bps = 115200

        self.init_bps = StringVar()
        self.init_bps.set('115200')
        self.init_com_left = StringVar()
        self.init_com_left.set(self.port_left)
        self.init_com_right = StringVar()
        self.init_com_right.set(self.port_right)

        # Blocks in the dialog's mainloop until the dialog is destroyed.
        self.init_communication()

        self.comm = None
        self.n = 512
        self.select = 24
        # Buffers pre-filled with zeros: n frames of 12 values, and n raw
        # [left, right] samples of 10 values each.
        self.frames = [[0] * 12 for _ in range(self.n)]
        self.raw = [[[0] * 10, [0] * 10] for _ in range(self.n)]

        # Neural network model file (loaded inside parse_thread)
        self.model_file = 'mc_actions.h5'

        # Open the connections to both handles
        self.comm_left = BaseComm(self.init_com_left.get(), self.bps)
        self.comm_right = BaseComm(self.init_com_right.get(), self.bps)

        # Sensitivity: scale factor applied to the stick offset in parse_thread
        self.sensitivity = 0.3

        # Read a few epochs first (original note: this consumes the Init flag)
        print('预读取:')
        for _ in range(10):
            r1 = self.comm_left.read1epoch()
            r2 = self.comm_right.read1epoch()
            print(r1, r2)

        print('请保持手柄水平放置不动')
        # Stick centre values, averaged over `pick` readings: [X, Y] per handle
        self.ave_left, self.ave_right = [0, 0], [0, 0]
        pick = 10
        for _ in range(pick):
            # Left is read before right on every round, as in the original.
            for comm, ave in ((self.comm_left, self.ave_left),
                              (self.comm_right, self.ave_right)):
                data_ctrl = list(map(int, comm.read1epoch()[-4:]))
                ave[0] += data_ctrl[2] / pick
                ave[1] += data_ctrl[1] / pick

        print('初始化遥感中点:', self.ave_left, self.ave_right)

        self.root = root
        if self.root is None:
            self.root = Tk()
        self.root.title("MC手柄")

        self.logger = UiLogger(self.root, height=10, width=32)
        self.logger.logger().grid(row=2, column=1, sticky=W+E)
        self.panel = Label(self.root)
        self.panel.grid(row=1, column=1, sticky=W+E)

        self.lock = threading.Lock()

        # daemon=True replaces the deprecated Thread.setDaemon(True).
        for worker in (self.read_thread, self.parse_thread):
            threading.Thread(target=worker, daemon=True).start()

    def init_communication(self):
        """Build the connection-settings dialog and run its event loop."""
        top = self.init_top
        frame = LabelFrame(top, text="连接设置")
        fields = (
            ("左手柄", self.init_com_left),
            ("右手柄", self.init_com_right),
            ("波特率", self.init_bps),
        )
        for row, (caption, variable) in enumerate(fields, start=1):
            Label(frame, text=caption).grid(row=row, column=1)
            Entry(frame, textvariable=variable).grid(row=row, column=2)
        frame.grid(row=1, columnspan=3, column=1)

        Button(top, text="测试", command=self.init_communication_test).grid(row=2, column=1, sticky=W + E)
        Button(top, text="刷新", command=self.init_communication_refresh).grid(row=2, column=2, sticky=W + E)
        Button(top, text="确定", command=self.init_communication_ok).grid(row=2, column=3, sticky=W + E)
        top.mainloop()

    def init_communication_ok(self):
        """Validate the dialog values, test both handles and close the dialog.

        Shows an error box and leaves the dialog open when the baud rate is
        not an integer or when the handle test does not pass.
        """
        try:
            bps = int(self.init_bps.get())
        except ValueError:
            messagebox.showerror("错误", '数值错误!')
            return
        self.bps = bps
        self.port_left = self.init_com_left.get()
        self.port_right = self.init_com_right.get()
        if self.init_communication_test(show=False) is False:
            messagebox.showerror("错误", '手柄测试不通过!')
            return
        self.init_top.destroy()

    def mainloop(self):
        """Run the main window's Tk event loop."""
        self.root.mainloop()

    def init_communication_test(self, show=True):
        """Open each configured port and run ``BaseComm.test`` on it.

        Args:
            show: When True, pop up an error box for each handle that fails.

        Returns:
            True when both handles pass, False otherwise (including when the
            baud rate field is not an integer; the original returned None in
            that case).
        """
        try:
            bps = int(self.init_bps.get())
        except ValueError:
            messagebox.showerror("错误", '数值错误!')
            return False
        res = True
        checks = (
            ('测试左手柄', self.init_com_left, '测试左手柄失败'),
            ('测试右手柄', self.init_com_right, '测试右手柄失败'),
        )
        for banner, port_var, failure in checks:
            print(banner)
            comm = BaseComm(port_var.get(), bps)
            try:
                if not comm.test():
                    if show is True:
                        messagebox.showerror("错误", failure)
                    res = False
            finally:
                # Always release the port, even if test() raises.
                comm.close()
        return res

    def init_communication_refresh(self):
        """Placeholder for the dialog's refresh button; does nothing."""
        pass

    def draw(self):
        """Render the buffered frames as a 12-channel line plot.

        Returns:
            A white RGB image, ``self.n`` pixels wide and 192 pixels high,
            with one polyline per channel and a red vertical marker
            ``self.select`` pixels from the right edge.
        """
        width = 1
        height = 32
        colors = [
            'red', 'orange', 'yellow', 'green', 'cyan', 'blue', 'purple',
            'red', 'orange', 'yellow', 'green', 'cyan', 'blue', 'purple',
        ]

        size = (width * self.n, height * 6)
        im = Image.new("RGB", size, color='white')
        draw = ImageDraw.Draw(im)
        baseline = size[1] / 2
        # Work on one list object for the whole pass; the reader thread
        # rebinds self.frames when it trims the buffer.
        frames = self.frames
        for i in range(self.n - 2):
            for j in range(12):
                draw.line((width * i, frames[i][j] + baseline,
                           width * (i + 1), frames[i + 1][j] + baseline), fill=colors[j])
        sx = size[0] - width * self.select
        draw.line((sx, 0, sx, size[1]), fill='red')
        return im

    def _store_sample(self, data_left, data_right):
        """Append one left/right sample to both buffers, keeping the newest n."""
        # Network input frame: first six values of each handle.
        ann = list(data_left[0:6])
        ann.extend(data_right[0:6])
        with self.lock:
            self.raw.append([data_left, data_right])
            if len(self.raw) > self.n:
                # Drop only the OLDEST entries. The original `[1:-1]` also
                # discarded the sample that had just been appended.
                self.raw = self.raw[len(self.raw) - self.n:]
            self.frames.append(ann)
            if len(self.frames) > self.n:
                self.frames = self.frames[len(self.frames) - self.n:]

    # Second thread: reads samples from the handles
    def read_thread(self):
        """Worker loop: read both handles every 10 ms and buffer the sample."""
        while True:
            time.sleep(0.01)
            data_left = self.comm_left.read1epoch()
            data_right = self.comm_right.read1epoch()
            self._store_sample(data_left, data_right)

    # Third thread: interprets the data
    def parse_thread(self):
        """Worker loop: refresh the plot, emit input events and run inference.

        Never returns. Each pass sleeps 10 ms and then:
          * about every 5th pass, redraws the plot into ``self.panel``;
          * reads the most recent raw sample under ``self.lock``;
          * sends the right handle's stick offset (relative to its averaged
            centre, scaled by ``self.sensitivity``) to ``ctrl.move`` and maps
            its button value to left mouse down/up;
          * maps ``data_left[6]`` to SPACE and the left stick to W/A/S/D;
          * every 15th pass, feeds the last ``self.select`` frames to the
            model and logs the predicted action.
        """
        # Load the neural network model
        model = load_model(self.model_file)

        t1, t2 = 0, 0

        click = False
        key1 = ctrl.ACTION_NONE
        key2 = ctrl.ACTION_NONE
        jump = ctrl.ACTION_NONE

        start = time.time()

        while True:
            if t1 == 5:
                im = self.draw()
                imp = ImageTk.PhotoImage(image=im)
                self.panel.configure(image=imp)
                # Keep a reference on the widget so the image is not collected.
                self.panel.image = imp
                t1 = 0
            t1 += 1

            time.sleep(0.01)

            with self.lock:
                data_left = self.raw[-1][0]
                data_right = self.raw[-1][1]

            # Right handle
            right_ctrl = list(map(int, data_right[-4:]))
            # Centre the RIGHT stick with the RIGHT handle's averaged centre.
            # The original subtracted self.ave_left here, leaving ave_right
            # unused.
            ctrl.move((right_ctrl[2] - self.ave_right[0]) * self.sensitivity,
                      (right_ctrl[1] - self.ave_right[1]) * self.sensitivity)
            if right_ctrl[0] == 0 and click is False:
                ctrl.left_down()
                click = True
            if right_ctrl[0] == 1 and click is True:
                ctrl.left_up()
                click = False
            # Jump: SPACE is held while data_left[6] reads 0
            if data_left[6] == 0 and jump == ctrl.ACTION_NONE:
                jump = ctrl.ACTION_UP
                ctrl.kbd_down(VirtualKeyCode.SPACEBAR)
            if data_left[6] == 1 and jump == ctrl.ACTION_UP:
                jump = ctrl.ACTION_NONE
                ctrl.kbd_up(VirtualKeyCode.SPACEBAR)

            # Left handle: the two middle values of the last four entries
            pos = data_left[-4:][1:3]
            # D key: held while pos[1] > 800
            if pos[1] > 800 and key1 == ctrl.ACTION_NONE:
                key1 = ctrl.ACTION_D
                ctrl.kbd_down(VirtualKeyCode.D_key)
            if pos[1] <= 800 and key1 == ctrl.ACTION_D:
                key1 = ctrl.ACTION_NONE
                ctrl.kbd_up(VirtualKeyCode.D_key)
            # A key: held while pos[1] < 200
            if pos[1] < 200 and key1 == ctrl.ACTION_NONE:
                key1 = ctrl.ACTION_A
                ctrl.kbd_down(VirtualKeyCode.A_key)
            if pos[1] >= 200 and key1 == ctrl.ACTION_A:
                key1 = ctrl.ACTION_NONE
                ctrl.kbd_up(VirtualKeyCode.A_key)

            # W key: held while pos[0] < 200
            if pos[0] < 200 and key2 == ctrl.ACTION_NONE:
                key2 = ctrl.ACTION_W
                ctrl.kbd_down(VirtualKeyCode.W_key)
            if pos[0] >= 200 and key2 == ctrl.ACTION_W:
                key2 = ctrl.ACTION_NONE
                ctrl.kbd_up(VirtualKeyCode.W_key)
            # S key: held while pos[0] > 800
            if pos[0] > 800 and key2 == ctrl.ACTION_NONE:
                key2 = ctrl.ACTION_S
                ctrl.kbd_down(VirtualKeyCode.S_key)
            if pos[0] <= 800 and key2 == ctrl.ACTION_S:
                key2 = ctrl.ACTION_NONE
                ctrl.kbd_up(VirtualKeyCode.S_key)

            # Neural network classification, only once every 15 passes
            t2 += 1
            if t2 == 15:
                t2 = 0
                with self.lock:
                    x = np.array(self.frames[len(self.frames) - self.select:])
                # Flatten the window into a single row for the model.
                x = x.reshape((1, x.size))
                predict = model.predict(x=x)[0]
                predict = predict.tolist()
                res = predict.index(max(predict))
                res = self.ACTIONS[res]
                self.logger.push(UiLogger.Item(UiLogger.LEVEL_INFO, 'predict %.2f' % (time.time() - start), '%s' % res))
Ejemplo n.º 5
0
    def __init__(self, root=None):
        """Run the connection dialog, build the trainer UI and start the workers.

        Args:
            root: Optional Tk root hosting the main window. A new ``Tk()`` is
                created when it is None.
        """
        self.init_top = Tk()

        self.port_left = 'COM4'
        self.port_right = 'COM5'
        # Default baud rate. It is assigned BEFORE the dialog runs so that a
        # value stored in self.bps while the dialog is open is not clobbered
        # afterwards (the original reset it to 115200 after the dialog).
        self.bps = 115200

        self.init_bps = StringVar()
        self.init_bps.set('115200')
        self.init_com_left = StringVar()
        self.init_com_left.set(self.port_left)
        self.init_com_right = StringVar()
        self.init_com_right.set(self.port_right)

        self.init_communication()

        self.comm = None
        self.n = 512
        self.select = 24
        # Buffers pre-filled with zeros: n frames of 12 values, and n raw
        # [left, right] samples of 10 values each.
        self.frames = [[0] * 12 for _ in range(self.n)]
        self.raw = [[[0] * 10, [0] * 10] for _ in range(self.n)]

        # Neural network model file (the model itself is built or loaded by
        # the parser thread, not here).
        self.model_file = 'mc_actions.h5'

        self.comm_left = BaseComm(self.init_com_left.get(), self.bps)
        self.comm_right = BaseComm(self.init_com_right.get(), self.bps)

        self.root = root
        if self.root is None:
            self.root = Tk()
        self.root.title("MC手柄训练器")

        self.panel = Label(self.root)
        self.panel.pack(side=TOP, expand=1, fill=X)

        # Button row: mode switch, one toggle per action, save, status label.
        frame = Frame(self.root)
        buttons = (
            ('切换模式', self.predict_mode),
            ('前进', self.action_forward),
            ('上跳', self.action_jump),
            ('下降', self.action_down),
            ('打击', self.action_hit),
            ('放置', self.action_put),
            ('无动作', self.action_none),
            ('保存模型', self.save_model),
        )
        for column, (caption, handler) in enumerate(buttons, start=1):
            Button(frame, text=caption, command=handler).grid(row=1, column=column, sticky=W + E)
        Label(frame, text='正在训练:').grid(row=1, column=9, sticky=W + E)

        self.var_training = StringVar()
        self.var_training.set('...')
        Label(frame, textvariable=self.var_training).grid(row=1, column=10, sticky=W + E)

        frame.pack(side=BOTTOM, expand=1, fill=X)

        self.logger_test = UiLogger(self.root, title='程序日志', simplify=False, height=10)
        self.logger_test.logger().pack(side=BOTTOM, expand=1, fill=X)

        self.lock = threading.Lock()

        self.training = self.ACTION_NONE
        self.will_save_model = False
        self.train_mode = True

        # Pass counters used by the parser thread (redraw / train-predict).
        self.t1 = 0
        self.t2 = 0

        # daemon=True replaces the deprecated Thread.setDaemon(True).
        for worker in (self.read_thread, self.parse_thread):
            threading.Thread(target=worker, daemon=True).start()
Ejemplo n.º 6
0
    def parse_thread(self):
        """Worker loop: train or predict on the recent frames and update the UI.

        Loads ``self.model_file``; when that raises ``OSError`` a new dense
        network is built instead. Never returns. On every pass it mirrors
        ``self.training`` into the status label, periodically redraws the
        plot, and whenever ``self.t2`` reaches 5 either trains one batch on
        the last ``self.select`` frames (training mode) or logs a prediction
        for them (prediction mode). The model is saved when
        ``self.will_save_model`` is set.
        """
        # Build or load the model
        try:
            model = load_model(self.model_file)
        except OSError:
            print("Can't find", self.model_file)
            model = Sequential()
            model.add(Dense(self.select * 12, activation='tanh', input_dim=self.select * 12))
            for factor in (24, 32, 48, 32, 24, 12):
                model.add(Dense(self.select * factor, activation='tanh'))
            model.add(Dense(self.select, activation='tanh'))
            model.add(Dense(6, activation='softmax'))

            # NOTE(review): softmax over 6 one-hot classes is normally paired
            # with categorical_crossentropy; kept as-is to preserve training
            # behaviour - confirm the intent.
            model.compile(loss='binary_crossentropy', optimizer='adam')

        start = time.time()

        while True:
            self.var_training.set(self.training)

            # The original also copied the newest raw sample into locals that
            # were never used; that dead code has been removed.
            if self.t1 == 5:
                im = self.draw()
                imp = ImageTk.PhotoImage(image=im)
                self.panel.configure(image=imp)
                # Keep a reference on the widget so the image is not collected.
                self.panel.image = imp
                self.t1 = 0
            self.t1 += 1

            # Training step
            if self.t2 == 5 and self.train_mode is True:
                with self.lock:
                    x = np.array(self.frames[len(self.frames) - self.select:])
                # Flatten the window into a single row for the model.
                x = x.reshape((1, x.size))
                # One-hot target for the action currently being trained.
                one = [0] * 6
                one[self.ACTIONS.index(self.training)] = 1
                y = np.array(one).reshape((1, 6))
                self.t2 = 0
                res = model.train_on_batch(x=x, y=y)
                self.logger_test.push(UiLogger.Item(UiLogger.LEVEL_INFO, 'training', '%s' % res))

            self.t2 += 1

            if self.will_save_model is True:
                print('保存模型...')
                # `with` releases the lock even if saving fails, so the reader
                # thread cannot be left blocked on it.
                with self.lock:
                    model.save(self.model_file)
                    self.will_save_model = False

            # Prediction mode
            if self.t2 == 5 and self.train_mode is False:
                self.t2 = 0
                with self.lock:
                    x = np.array(self.frames[len(self.frames) - self.select:])
                x = x.reshape((1, x.size))
                predict = model.predict(x=x)[0]
                predict = predict.tolist()
                res = predict.index(max(predict))
                res = self.ACTIONS[res]
                self.logger_test.push(UiLogger.Item(UiLogger.LEVEL_INFO, 'predict %.2f' % (time.time() - start), '%s' % res))
Ejemplo n.º 7
0
class MCHandleTrainer:
    """Tk application for interactively training the handle action classifier.

    A reader thread keeps the newest ``n`` samples in ``raw`` and ``frames``;
    a parser thread trains the model on (or predicts from) the most recent
    frames, labelled by the action button the user has toggled on.
    """

    ACTION_NONE = '无动作'
    ACTION_FORWARD = '前进'
    ACTION_JUMP = '起跳'
    ACTION_DOWN = '下降'
    ACTION_HIT = '打击'
    ACTION_PUT = '放置'
    # Index in this list == index of the model's output unit.
    ACTIONS = [ACTION_NONE, ACTION_FORWARD, ACTION_JUMP, ACTION_DOWN, ACTION_HIT, ACTION_PUT]

    def __init__(self, root=None):
        """Run the connection dialog, build the trainer UI and start the workers.

        Args:
            root: Optional Tk root hosting the main window. A new ``Tk()`` is
                created when it is None.
        """
        self.init_top = Tk()

        self.port_left = 'COM4'
        self.port_right = 'COM5'
        # Default baud rate. Assigned BEFORE the dialog runs so the value
        # stored by init_communication_ok() is not overwritten afterwards
        # (the original reset it to 115200 after the dialog had closed).
        self.bps = 115200

        self.init_bps = StringVar()
        self.init_bps.set('115200')
        self.init_com_left = StringVar()
        self.init_com_left.set(self.port_left)
        self.init_com_right = StringVar()
        self.init_com_right.set(self.port_right)

        # Blocks in the dialog's mainloop until the dialog is destroyed.
        self.init_communication()

        self.comm = None
        self.n = 512
        self.select = 24
        # Buffers pre-filled with zeros: n frames of 12 values, and n raw
        # [left, right] samples of 10 values each.
        self.frames = [[0] * 12 for _ in range(self.n)]
        self.raw = [[[0] * 10, [0] * 10] for _ in range(self.n)]

        # Neural network model file (built or loaded inside parse_thread)
        self.model_file = 'mc_actions.h5'

        self.comm_left = BaseComm(self.init_com_left.get(), self.bps)
        self.comm_right = BaseComm(self.init_com_right.get(), self.bps)

        self.root = root
        if self.root is None:
            self.root = Tk()
        self.root.title("MC手柄训练器")

        self.panel = Label(self.root)
        self.panel.pack(side=TOP, expand=1, fill=X)

        # Button row: mode switch, one toggle per action, save, status label.
        frame = Frame(self.root)
        buttons = (
            ('切换模式', self.predict_mode),
            ('前进', self.action_forward),
            ('上跳', self.action_jump),
            ('下降', self.action_down),
            ('打击', self.action_hit),
            ('放置', self.action_put),
            ('无动作', self.action_none),
            ('保存模型', self.save_model),
        )
        for column, (caption, handler) in enumerate(buttons, start=1):
            Button(frame, text=caption, command=handler).grid(row=1, column=column, sticky=W + E)
        Label(frame, text='正在训练:').grid(row=1, column=9, sticky=W + E)

        self.var_training = StringVar()
        self.var_training.set('...')
        Label(frame, textvariable=self.var_training).grid(row=1, column=10, sticky=W + E)

        frame.pack(side=BOTTOM, expand=1, fill=X)

        self.logger_test = UiLogger(self.root, title='程序日志', simplify=False, height=10)
        self.logger_test.logger().pack(side=BOTTOM, expand=1, fill=X)

        self.lock = threading.Lock()

        self.training = self.ACTION_NONE
        self.will_save_model = False
        self.train_mode = True

        # Pass counters used by parse_thread (redraw / train-predict).
        self.t1 = 0
        self.t2 = 0

        # daemon=True replaces the deprecated Thread.setDaemon(True).
        for worker in (self.read_thread, self.parse_thread):
            threading.Thread(target=worker, daemon=True).start()

    def predict_mode(self):
        """Toggle between training and prediction mode and reset ``t2``."""
        entering_predict = self.train_mode is True
        self.train_mode = not entering_predict
        message = '切换到预测模式' if entering_predict else '切换到训练模式'
        self.logger_test.push(UiLogger.Item(UiLogger.LEVEL_WARNING, 'switch', message))
        self.t2 = 0

    def _toggle_training(self, action):
        """Select ``action`` as the training label, or clear it if already set."""
        self.training = self.ACTION_NONE if self.training == action else action

    def action_forward(self):
        """Toggle the 'forward' training label."""
        self._toggle_training(self.ACTION_FORWARD)

    def action_jump(self):
        """Toggle the 'jump' training label."""
        self._toggle_training(self.ACTION_JUMP)

    def action_down(self):
        """Toggle the 'down' training label."""
        self._toggle_training(self.ACTION_DOWN)

    def action_hit(self):
        """Toggle the 'hit' training label."""
        self._toggle_training(self.ACTION_HIT)

    def action_put(self):
        """Toggle the 'put' training label."""
        self._toggle_training(self.ACTION_PUT)

    def action_none(self):
        """Reset the training label to 'no action'."""
        self.training = self.ACTION_NONE

    def save_model(self):
        """Ask the parser thread to save the model on its next pass."""
        self.will_save_model = True

    def init_communication(self):
        """Build the connection-settings dialog and run its event loop."""
        top = self.init_top
        frame = LabelFrame(top, text="连接设置")
        fields = (
            ("左手柄", self.init_com_left),
            ("右手柄", self.init_com_right),
            ("波特率", self.init_bps),
        )
        for row, (caption, variable) in enumerate(fields, start=1):
            Label(frame, text=caption).grid(row=row, column=1)
            Entry(frame, textvariable=variable).grid(row=row, column=2)
        frame.grid(row=1, columnspan=3, column=1)

        Button(top, text="测试", command=self.init_communication_test).grid(row=2, column=1, sticky=W+E)
        Button(top, text="刷新", command=self.init_communication_refresh).grid(row=2, column=2, sticky=W+E)
        Button(top, text="确定", command=self.init_communication_ok).grid(row=2, column=3, sticky=W+E)
        top.mainloop()

    def init_communication_ok(self):
        """Validate the dialog values, test both handles and close the dialog.

        Shows an error box and leaves the dialog open when the baud rate is
        not an integer or when the handle test does not pass.
        """
        try:
            bps = int(self.init_bps.get())
        except ValueError:
            messagebox.showerror("错误", '数值错误!')
            return
        self.bps = bps
        self.port_left = self.init_com_left.get()
        self.port_right = self.init_com_right.get()
        if self.init_communication_test(show=False) is False:
            messagebox.showerror("错误", '手柄测试不通过!')
            return
        self.init_top.destroy()

    def mainloop(self):
        """Run the main window's Tk event loop."""
        self.root.mainloop()

    def init_communication_test(self, show=True):
        """Open each configured port and run ``BaseComm.test`` on it.

        Args:
            show: When True, pop up an error box for each handle that fails.

        Returns:
            True when both handles pass, False otherwise (including when the
            baud rate field is not an integer; the original returned None in
            that case).
        """
        try:
            bps = int(self.init_bps.get())
        except ValueError:
            messagebox.showerror("错误", '数值错误!')
            return False
        res = True
        checks = (
            ('测试左手柄', self.init_com_left, '测试左手柄失败'),
            ('测试右手柄', self.init_com_right, '测试右手柄失败'),
        )
        for banner, port_var, failure in checks:
            print(banner)
            comm = BaseComm(port_var.get(), bps)
            try:
                if not comm.test():
                    if show is True:
                        messagebox.showerror("错误", failure)
                    res = False
            finally:
                # Always release the port, even if test() raises.
                comm.close()
        return res

    def init_communication_refresh(self):
        """Placeholder for the dialog's refresh button; does nothing."""
        pass

    # Read one sample from a single handle
    def read_data(self, comm: "BaseComm", q: "queue.Queue"):
        """Read one epoch from ``comm`` and put it on ``q``."""
        q.put(comm.read1epoch())

    def _store_sample(self, data_left, data_right):
        """Append one left/right sample to both buffers, keeping the newest n."""
        # Network input frame: first six values of each handle.
        ann = list(data_left[0:6])
        ann.extend(data_right[0:6])
        with self.lock:
            self.raw.append([data_left, data_right])
            if len(self.raw) > self.n:
                # Drop only the OLDEST entries. The original `[1:-1]` also
                # discarded the sample that had just been appended.
                self.raw = self.raw[len(self.raw) - self.n:]
            self.frames.append(ann)
            if len(self.frames) > self.n:
                self.frames = self.frames[len(self.frames) - self.n:]

    # Second thread: reads samples from the handles
    def read_thread(self):
        """Worker loop: read both handles in parallel and buffer the sample.

        Every 10 ms two short-lived daemon threads read one epoch each. If
        either has not delivered within its 5 s join timeout, a warning is
        printed and the round is skipped.
        """
        while True:
            time.sleep(0.01)
            q_left = queue.Queue()
            q_right = queue.Queue()
            readers = [
                threading.Thread(target=self.read_data, args=(self.comm_left, q_left), daemon=True),
                threading.Thread(target=self.read_data, args=(self.comm_right, q_right), daemon=True),
            ]
            for reader in readers:
                reader.start()
            for reader in readers:
                reader.join(5)
            if q_left.empty() or q_right.empty():
                print('WARING: 数据读取失败!')
                continue
            self._store_sample(q_left.get(), q_right.get())

    def parse_thread(self):
        """Worker loop: train or predict on the recent frames and update the UI.

        Loads ``self.model_file``; when that raises ``OSError`` a new dense
        network is built instead. Never returns. On every pass it mirrors
        ``self.training`` into the status label, periodically redraws the
        plot, and whenever ``self.t2`` reaches 5 either trains one batch on
        the last ``self.select`` frames (training mode) or logs a prediction
        for them (prediction mode). The model is saved when
        ``self.will_save_model`` is set.
        """
        # Build or load the model
        try:
            model = load_model(self.model_file)
        except OSError:
            print("Can't find", self.model_file)
            model = Sequential()
            model.add(Dense(self.select * 12, activation='tanh', input_dim=self.select * 12))
            for factor in (24, 32, 48, 32, 24, 12):
                model.add(Dense(self.select * factor, activation='tanh'))
            model.add(Dense(self.select, activation='tanh'))
            model.add(Dense(6, activation='softmax'))

            # NOTE(review): softmax over 6 one-hot classes is normally paired
            # with categorical_crossentropy; kept as-is to preserve training
            # behaviour - confirm the intent.
            model.compile(loss='binary_crossentropy', optimizer='adam')

        start = time.time()

        while True:
            self.var_training.set(self.training)

            # The original also copied the newest raw sample into locals that
            # were never used; that dead code has been removed.
            if self.t1 == 5:
                im = self.draw()
                imp = ImageTk.PhotoImage(image=im)
                self.panel.configure(image=imp)
                # Keep a reference on the widget so the image is not collected.
                self.panel.image = imp
                self.t1 = 0
            self.t1 += 1

            # Training step
            if self.t2 == 5 and self.train_mode is True:
                with self.lock:
                    x = np.array(self.frames[len(self.frames) - self.select:])
                # Flatten the window into a single row for the model.
                x = x.reshape((1, x.size))
                # One-hot target for the action currently being trained.
                one = [0] * 6
                one[self.ACTIONS.index(self.training)] = 1
                y = np.array(one).reshape((1, 6))
                self.t2 = 0
                res = model.train_on_batch(x=x, y=y)
                self.logger_test.push(UiLogger.Item(UiLogger.LEVEL_INFO, 'training', '%s' % res))

            self.t2 += 1

            if self.will_save_model is True:
                print('保存模型...')
                # `with` releases the lock even if saving fails, so the reader
                # thread cannot be left blocked on it.
                with self.lock:
                    model.save(self.model_file)
                    self.will_save_model = False

            # Prediction mode
            if self.t2 == 5 and self.train_mode is False:
                self.t2 = 0
                with self.lock:
                    x = np.array(self.frames[len(self.frames) - self.select:])
                x = x.reshape((1, x.size))
                predict = model.predict(x=x)[0]
                predict = predict.tolist()
                res = predict.index(max(predict))
                res = self.ACTIONS[res]
                self.logger_test.push(UiLogger.Item(UiLogger.LEVEL_INFO, 'predict %.2f' % (time.time() - start), '%s' % res))

    def draw(self):
        """Render the buffered frames as a 12-channel line plot.

        Returns:
            A white RGB image, ``self.n`` pixels wide and 192 pixels high,
            with one polyline per channel and a red vertical marker
            ``self.select`` pixels from the right edge.
        """
        width = 1
        height = 32
        colors = [
            'red', 'orange', 'yellow', 'green', 'cyan', 'blue', 'purple',
            'red', 'orange', 'yellow', 'green', 'cyan', 'blue', 'purple',
        ]

        size = (width * self.n, height * 6)
        im = Image.new("RGB", size, color='white')
        draw = ImageDraw.Draw(im)
        baseline = size[1] / 2
        # Work on one list object for the whole pass; the reader thread
        # rebinds self.frames when it trims the buffer.
        frames = self.frames
        for i in range(self.n - 2):
            for j in range(12):
                draw.line((width * i, frames[i][j] + baseline,
                           width * (i + 1), frames[i + 1][j] + baseline), fill=colors[j])
        sx = size[0] - width * self.select
        draw.line((sx, 0, sx, size[1]), fill='red')
        return im