# Esempio n. 1 (Example no. 1)
# 0
class ModelWindow(QtWidgets.QMainWindow, Ui_ModelWindow):  #模型界面对应的类
    def __init__(self):
        """Build the model-editing window: bookkeeping, flowchart and attribute tree."""
        super(ModelWindow, self).__init__()
        self.setupUi(self)
        self.global_id = 0  # next unique layer ID; IDs are the nodes of self.net used for checks and code generation
        self.nodes = dict()  # layer name -> layer dict
        self.id_name = dict()  # layer ID -> layer name
        self.name_id = dict()  # layer name -> layer ID
        self.net = nx.DiGraph()  # dependency graph whose nodes are layer IDs
        self.library = fclib.LIBRARY.copy()
        # Register the custom flowchart node classes (the original notes they come from FCNodes.py).
        for node_class, library_path in (
                (CovNode, 'CovNode'),
                (PoolNode, 'PoolNode'),
                (LinearNode, 'LinearNode'),
                (ConcatNode, 'ConcatNode'),
                (Concat1dNode, 'Concat1dNode'),
                (SoftmaxNode, 'SoftmaxNode'),
                (LogSoftmaxNode, 'LogSoftmaxNode'),
                (BachNorm1dNode, 'BachNorm1dNode'),
                (BachNorm2dNode, 'BachNorm2dNode'),
                (AddNode, 'ResAddNode'),
                (IdentityNode, 'IdentityNode'),
        ):
            self.library.addNodeType(node_class, [(library_path, )])
        # layer type code -> node type name passed to Flowchart.createNode
        self.type_name = {
            2: 'Cov2d',
            3: 'Pool2d',
            4: 'Linear',
            5: 'Softmax',
            6: 'LogSoftmax',
            7: 'BachNorm1d',
            8: 'BachNorm2d',
            9: 'Res_Add',
            10: 'Concat2d',
            11: 'Concat1d',
            12: 'Identity'
        }
        self.fc = Flowchart()  # model visualisation flowchart (the "FlowChart" button)
        self.fc.setLibrary(self.library)
        self.outputs = list()  # IDs of the model's output layers
        fc_widget = self.fc.widget()  # fetched once and placed in the layout below
        self.fc_inputs = dict()  # flowchart input name -> out_size of that input layer
        main_widget = QWidget()
        main_layout = QGridLayout()
        main_widget.setLayout(main_layout)
        self.detail = QTreeWidget()  # two-column attribute view of the layers
        self.detail.setColumnCount(2)
        self.detail.setHeaderLabels(["属性", "值"])
        self.root = QTreeWidgetItem(self.detail)
        self.root.setText(0, "所有属性")
        main_layout.addWidget(fc_widget, 0, 0, 1, 2)
        main_layout.addWidget(self.detail, 0, 2)
        self.setCentralWidget(main_widget)

    def add_layer(self):  # "Layer - New" action: open the layer dialog
        """Open a new AddLayerWindow whose submitted data is handled by accept_layer."""
        dialog = self.addlayer_window = AddLayerWindow()
        dialog.datasignal.connect(self.accept_layer)
        dialog.show()

    def modifiey_layer(self):  # "Layer - Modify" action: open the layer dialog
        """Open the layer dialog with the extra Identity layer type available.

        Submitted data goes to accept_layer, which replaces an existing layer
        when the submitted name is already in use.
        """
        dialog = self.addlayer_window = AddLayerWindow()
        dialog.datasignal.connect(self.accept_layer)
        layer_types = dialog.layertype
        layer_types.addItem("恒等层(Identity)")
        dialog.show()

    def clear(self):  # "Layer - Clear" action: discard the current model
        """Reset the window to an empty model.

        Empties the layer bookkeeping (names, IDs, dependency graph, outputs),
        removes every node and every input/output terminal from the flowchart,
        and rebuilds an empty attribute tree.
        """
        self.fc.clear()
        # Flowchart.clear() only closes the inner nodes; the flowchart's own
        # input/output terminals survive it. Remove them as well so that
        # re-adding a layer with the same name does not create a renamed
        # duplicate terminal.
        for terminal_name in list(self.fc.inputs()) + list(self.fc.outputs()):
            self.fc.removeTerminal(terminal_name)
        self.fc_inputs = dict()
        self.id_name = dict()
        self.name_id = dict()
        self.nodes = dict()
        self.net = nx.DiGraph()
        self.global_id = 0
        self.outputs = list()
        self.detail.clear()
        self.detail.setColumnCount(2)
        self.detail.setHeaderLabels(["属性", "值"])
        self.root = QTreeWidgetItem(self.detail)
        self.root.setText(0, "所有属性")

    def to_train(self):  # "Function - Train" action: open the training window
        """Show a new TrainWindow, keeping a reference to it on self."""
        window = self.train_window = TrainWindow()
        window.show()

    def to_test(self):  # "Function - Test" action: open the testing window
        """Show a new TestWindow, keeping a reference to it on self."""
        window = self.test_window = TestWindow()
        window.show()

    def accept_layer(self, data):  # slot for the layer dialog's datasignal
        """Add a new layer, or replace the existing layer with the same name.

        ``data`` is the dict emitted by the layer dialog. Keys read here:
        'type' (int layer code, see ``self.type_name``; 1 is an input layer),
        'name', 'input' (a ';'-separated string of source layer names),
        'para' (dict of layer parameters) and 'isoutput'. On success the dict
        is stored in ``self.nodes`` with 'input' turned into a list of source
        layer IDs and 'ID' added, and the graph and flowchart are updated.

        Returns 0 after showing a warning when the layer is rejected,
        otherwise None.
        """
        inputs = self._resolve_layer_inputs(data)
        if inputs is None:
            return 0
        data['input'] = inputs

        if data['type'] == 9:  # Res_Add: all inputs must share the same output size
            sizes = [self.nodes[self.id_name[i]]['para']['out_size'] for i in inputs]
            if any(size != sizes[0] for size in sizes[1:]):
                QMessageBox.warning(self, "错误", "输入尺寸不匹配")
                return 0

        layer_id = self.name_id.get(data['name'])
        replacing = layer_id is not None
        if replacing:
            # IDs grow in creation order and an input must have a smaller ID
            # than its consumer, which keeps self.net acyclic.
            if any(input_id >= layer_id for input_id in inputs):
                QMessageBox.warning(self, "错误", "输入来自后继层")
                return 0
            former = self.nodes[data['name']]
            if data['type'] == 12:
                data['para']['former_type'] = former['type']
                if not inputs:
                    # an identity layer inherits the first input of the layer it replaces
                    inputs.extend(former['input'][:1])
        if data['type'] != 1 and not inputs:
            # every non-input layer is wired from its first input below
            QMessageBox.warning(self, "错误", "该层缺少输入")
            return 0

        # --- all validation done; mutate the model from here on ---
        if replacing:
            # remove_node also drops the outgoing edges; they are restored
            # from the consumers' 'input' lists at the end of this method.
            self.net.remove_node(layer_id)
            self.net.add_node(layer_id)
            if former['type'] == 1:
                # input layers exist only as flowchart input terminals, not nodes;
                # removing the terminal also drops its connections (rebuilt below)
                self.fc.removeTerminal(data['name'])
                if data['type'] != 1:
                    self.fc_inputs.pop(data['name'], None)
            else:
                self.fc.removeNode(self.fc.nodes()[data['name']])
            if data['name'] in self.fc.outputs() and (data['type'] == 1
                                                      or not data['isoutput']):
                # the replacement is not a flowchart output: drop the stale terminal
                self.fc.removeTerminal(data['name'])
        else:
            layer_id = self.global_id
            self.global_id += 1
            self.name_id[data['name']] = layer_id
            self.id_name[layer_id] = data['name']
            self.net.add_node(layer_id)
        self.net.add_edges_from((source_id, layer_id) for source_id in inputs)
        data['ID'] = layer_id
        self.nodes[data['name']] = data

        if data['type'] == 1:  # input layer: a flowchart input terminal
            if data['name'] not in self.fc.inputs():
                self.fc.addInput(data['name'])
            self.fc_inputs[data['name']] = data['para']['out_size']
            self.fc.setInput(**self.fc_inputs)
            out_terminal = self.fc[data['name']]
        else:
            node = self._create_layer_fc_node(data)
            out_terminal = node.outputs()['dataOut']

        # keep self.outputs free of duplicates and in sync with 'isoutput'
        if data['isoutput']:
            if layer_id not in self.outputs:
                self.outputs.append(layer_id)
        elif layer_id in self.outputs:
            self.outputs.remove(layer_id)

        if replacing:  # reconnect the layers that consumed the replaced one
            for consumer in self.nodes.values():
                if layer_id not in consumer['input']:
                    continue
                consumer_node = self.fc.nodes()[consumer['name']]
                if consumer['type'] in [9, 10, 11]:
                    in_terminal = consumer_node.inputs()[data['name']]
                else:
                    in_terminal = consumer_node.inputs()['dataIn']
                self.fc.connectTerminals(in_terminal, out_terminal)
                self.net.add_edge(layer_id, consumer['ID'])

    def _resolve_layer_inputs(self, data):
        """Translate the ';'-separated source names in data['input'] into layer IDs.

        Input (type 1) and identity (type 12) layers ignore the dialog's
        input string and resolve to an empty list. Returns the list of IDs,
        or None after showing a warning when a name is unknown or its layer
        type may not feed a layer of data['type'].
        """
        if data['type'] in [1, 12]:
            return []
        to_pool = (2, 8, 9, 10)
        to_flat = (4, 7, 11)
        to_conv = (1, 2, 3, 8, 9, 10)
        # layer type -> source layer types it accepts (types absent here accept any)
        allowed = {
            3: to_pool,
            5: to_flat, 6: to_flat, 7: to_flat, 11: to_flat,
            2: to_conv, 8: to_conv, 10: to_conv,
        }.get(data['type'])
        resolved = []
        for name in data['input'].split(";"):
            try:
                source_id = self.name_id[name]
                source_type = self.nodes[name]['type']
            except KeyError:
                QMessageBox.warning(self, "错误",
                                    "输入来自未生成的层:{}".format(name))
                return None
            if allowed is not None and source_type not in allowed:
                QMessageBox.warning(self, "错误",
                                    "输入:{}层类型不合法".format(name))
                return None
            resolved.append(source_id)
        return resolved

    def _create_layer_fc_node(self, data):
        """Create and wire the flowchart node for the non-input layer ``data``.

        Res_Add/Concat2d/Concat1d layers (types 9-11) get one input terminal
        per source layer, named after it; every other type feeds its first
        source into 'dataIn'. When the layer is a model output its 'dataOut'
        is also wired to a flowchart output terminal of the same name.
        Returns the created node.
        """
        first_input = data['input'][0]
        node = self.fc.createNode(
            self.type_name[data['type']],
            name=data['name'],
            pos=(first_input * 120, (data['ID'] - first_input) * 150 - 500))
        node.setPara(data['para'])
        node.setView(self.root)

        def connect_source(source_name, terminal):
            # input layers are flowchart input terminals; other layers are nodes
            if self.nodes[source_name]['type'] == 1:
                source = self.fc[source_name]
            else:
                source = self.fc.nodes()[source_name]['dataOut']
            self.fc.connectTerminals(source, terminal)

        if data['type'] in [9, 10, 11]:
            for source_id in data['input']:
                source_name = self.id_name[source_id]
                node.addInput(source_name)
                connect_source(source_name, node[source_name])
        else:
            connect_source(self.id_name[first_input], node['dataIn'])
        if data['isoutput']:
            if data['name'] not in self.fc.outputs():
                self.fc.addOutput(data['name'])
            self.fc.connectTerminals(node['dataOut'], self.fc[data['name']])
        return node

    def export_file(self):  # export the model as a PyTorch script
        """Export the current model.

        Asks for a target ``.py`` path and runs ``self.check()``. It then writes
        ``<path without extension>.json`` with the layer dicts in topological
        order (the format read by import_file). It also writes the chosen
        ``.py`` file defining ``class Net(nn.Module)``.

        Returns 0 when the user cancels or validation fails, otherwise None.
        """
        filename, _ = QFileDialog.getSaveFileName(self, '导出模型', 'C:\\',
                                                  'Python Files (*.py)')
        if filename is None or filename == "":
            return 0
        if self.check() == 0:
            # check() has already shown a specific warning; do not stack a
            # second, generic one on top of it
            return 0
        order = list(nx.topological_sort(self.net))
        layers = [self.nodes[self.id_name[layer_id]] for layer_id in order]
        return_layers = [layer['name'] for layer in layers if layer['isoutput']]
        if not return_layers:
            # self.outputs and the layers' 'isoutput' flags are tracked
            # separately; refuse to emit a forward() without a return value
            QMessageBox.warning(self, "错误", "模型没有输出")
            return 0

        json_path = os.path.splitext(filename)[0] + ".json"
        with open(json_path, 'w', encoding='utf-8') as f_json:
            json.dump(layers, f_json)

        indent = "    "
        lines = [
            'import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n',
            "class Net(nn.Module):\n",
            indent + "def __init__(self):\n",
            indent * 2 + "super(Net, self).__init__()\n",
        ]
        for layer in layers:  # module construction in __init__
            statement = self._export_init_statement(layer)
            if statement is not None:
                lines.append(indent * 2 + statement + "\n")
        lines.append(indent + "def forward(self" +
                     "".join("," + name for name in self.fc_inputs) + "):\n")
        for layer in layers:  # tensor flow in forward
            statement = self._export_forward_statement(layer)
            if statement is not None:
                lines.append(indent * 2 + statement + "\n")
        lines.append(indent * 2 + "return {}\n".format(",".join(return_layers)))
        with open(filename, 'w', encoding='utf-8') as f:
            f.writelines(lines)

    def _export_init_statement(self, layer):
        """Return the ``__init__`` line (no indent/newline) building ``layer``'s nn module.

        Returns None for layer types that are written purely as tensor
        expressions in ``forward`` (input, softmax, residual add, concat).
        """
        para = layer['para']
        kind = layer['type']
        target = "self." + layer['name'] + " = "

        def keyword_args(allowed_keys):
            # optional keyword arguments in para's order; None values are
            # omitted so PyTorch's default applies
            parts = []
            for key, value in para.items():
                if key not in allowed_keys or value is None:
                    continue
                text = str(value)
                if (key == 'padding_mode' and isinstance(value, str)
                        and value[:1] not in ("'", '"')):
                    # a bare string option must be quoted to be valid Python
                    text = repr(value)
                parts.append("," + key + "=" + text)
            return "".join(parts)

        if kind == 2:
            return (target + "nn.Conv2d(" + str(para['in_size'][0]) + "," +
                    str(para['outchannels']) + "," + str(para['kernel']) +
                    keyword_args(('stride', 'padding', 'dilation', 'bias',
                                  'padding_mode')) + ")")
        if kind == 3:
            if para['type'] == "max":
                args = str(para['kernel']) + keyword_args(('stride', 'padding'))
                return target + "nn.MaxPool2d(" + args + ")"
            if para['type'] == "average":
                args = str(para['kernel']) + keyword_args(('stride', 'padding'))
                return target + "nn.AvgPool2d(" + args + ")"
            args = (str(para['power']) + "," + str(para['kernel']) +
                    keyword_args(('stride', )))
            return target + "nn.LPPool2d(" + args + ")"
        if kind == 4:
            return target + "nn.Linear({},{},bias={})".format(
                para['in_size'], para['out_features'], para['bias'])
        if kind in (7, 8):
            module = "nn.BatchNorm1d" if kind == 7 else "nn.BatchNorm2d"
            features = para['in_size'] if kind == 7 else para['in_size'][0]
            return target + (module + "({},eps={},momentum={},affine={},"
                             "track_running_stats={})").format(
                                 features, para['eps'], para['momentum'],
                                 para['affine'], para['track_running_stats'])
        if kind == 12:
            return target + "nn.Identity()"
        return None

    def _export_forward_statement(self, layer):
        """Return the ``forward`` line (no indent/newline) computing ``layer``.

        The statement assigns to a local variable named after the layer and
        reads the variables of the layers listed in layer['input']. It never
        reassigns a source variable, so a layer feeding several consumers is
        seen unchanged by all of them. Returns None for input layers, which
        are parameters of ``forward``.
        """
        name = layer['name']
        para = layer['para']
        kind = layer['type']
        sources = [self.id_name[source_id] for source_id in layer['input']]
        if kind in (2, 4):  # Conv2d / Linear, optional activation and dropout
            if kind == 4:
                # flatten everything but the first (batch) dimension
                expr = "self.{0}({1}.view({1}.size()[0], -1))".format(
                    name, sources[0])
            else:
                expr = "self.{}({})".format(name, sources[0])
            activation = para['activate']
            if activation != "None":
                namespace = "torch." if activation in ['tanh', 'sigmoid'] else "F."
                expr = "{}{}({})".format(namespace, activation, expr)
            if para['dropout']:
                # F.dropout defaults to training=True; tie it to the module's
                # mode so dropout is disabled after model.eval()
                expr = "F.dropout({}, p={}, training=self.training)".format(
                    expr, para['dropout_radio'])
            return name + " = " + expr
        if kind in (3, 7, 8, 12):  # single-input modules built in __init__
            return "{0} = self.{0}({1})".format(name, sources[0])
        if kind == 5:
            return "{} = F.softmax({}, dim=1)".format(name, sources[0])
        if kind == 6:
            return "{} = F.log_softmax({}, dim=1)".format(name, sources[0])
        if kind == 9:  # residual add
            return name + " = " + "+".join(sources)
        if kind == 10:  # Concat2d: pad each source's dims 1/2 up to this layer's out_size

            def split(total):
                # first side truncates toward zero; the second gets the rest,
                # so both sides always sum to the full difference
                total = int(total)
                first = int(total / 2)
                return first, total - first

            target = para['out_size']
            terms = []
            for source in sources:
                size = self.nodes[source]['para']['out_size']
                pad_up, pad_down = split(target[1] - size[1])
                pad_left, pad_right = split(target[2] - size[2])
                if pad_up or pad_down or pad_left or pad_right:
                    terms.append("F.pad({}, ({}, {}, {}, {}))".format(
                        source, pad_left, pad_right, pad_up, pad_down))
                else:
                    terms.append(source)
            return "{} = torch.cat([{}], 1)".format(name, ",".join(terms))
        if kind == 11:  # Concat1d
            return "{} = torch.cat([{}], 1)".format(name, ",".join(sources))
        return None

    def check(self):  # sanity checks run before export
        """Validate the model before export.

        Shows a warning and returns 0 when no output layer is registered in
        ``self.outputs``. It does the same when any layer's ``para['out_size']``
        (an int, or a sequence of ints) contains a value <= 0. Returns 1 when
        every check passes.
        """
        if not self.outputs:
            QMessageBox.warning(self, "错误", "模型没有输出")
            return 0
        for item in self.nodes.values():
            out_size = item['para']['out_size']
            sizes = (out_size, ) if isinstance(out_size, int) else out_size
            if any(size <= 0 for size in sizes):
                QMessageBox.warning(self, "错误",
                                    "{}层输出尺寸为负数或0".format(item['name']))
                return 0
        return 1

    def import_file(self):  # rebuild the model from a .json file
        """Replace the current model with one loaded from a JSON file.

        The file is expected to hold the list of layer dicts that export_file
        writes, in topological order: each layer is wired to already-imported
        layers, so it must come after its inputs. Returns 0 when the user
        cancels or the file cannot be read, otherwise None.
        """
        filename, _ = QFileDialog.getOpenFileName(self, '导入模型', 'C:\\',
                                                  'JSON Files (*.json)')
        if filename is None or filename == "":
            return 0
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                d_json = json.load(f)
        except (OSError, ValueError):  # ValueError covers JSONDecodeError and decode errors
            QMessageBox.warning(self, "错误", "无法读取模型文件:{}".format(filename))
            return 0
        # the previous model is discarded only once the file parsed successfully
        self.clear()

        def connect_source(source_name, terminal):
            # input layers are flowchart input terminals; other layers are nodes
            if self.nodes[source_name]['type'] == 1:
                source = self.fc[source_name]
            else:
                source = self.fc.nodes()[source_name]['dataOut']
            self.fc.connectTerminals(source, terminal)

        for data in d_json:
            layer_id = data['ID']
            self.id_name[layer_id] = data['name']
            self.name_id[data['name']] = layer_id
            self.nodes[data['name']] = data
            self.net.add_node(layer_id)
            self.net.add_edges_from((source_id, layer_id) for source_id in data['input'])
            if data['type'] == 1:  # input layer: a flowchart input terminal
                if data['name'] not in self.fc.inputs():
                    self.fc.addInput(data['name'])
                self.fc_inputs[data['name']] = data['para']['out_size']
                self.fc.setInput(**self.fc_inputs)
            else:
                first_input = data['input'][0]
                node = self.fc.createNode(
                    self.type_name[data['type']],
                    name=data['name'],
                    pos=(first_input * 120, (layer_id - first_input) * 150 - 500))
                node.setPara(data['para'])
                node.setView(self.root)
                if data['type'] in [9, 10, 11]:  # one named input terminal per source
                    for source_id in data['input']:
                        source_name = self.id_name[source_id]
                        node.addInput(source_name)
                        connect_source(source_name, node[source_name])
                else:
                    connect_source(self.id_name[first_input], node['dataIn'])
                if data['isoutput']:
                    if data['name'] not in self.fc.outputs():
                        self.fc.addOutput(data['name'])
                    self.fc.connectTerminals(node['dataOut'], self.fc[data['name']])
            if data['isoutput']:
                self.outputs.append(layer_id)
        # continue numbering after the imported IDs so new layers do not collide
        self.global_id = max((data['ID'] for data in d_json), default=-1) + 1