Exemplo n.º 1
0
def fuse_br1_br2c_uff(registry, network):
    """Fuse the br1 branch and the br2c tail of a residual block into one plugin.

    Searches ``network`` for the sub-graph described by ``pattern`` below:

    * A 64-channel input feeds two branches.
    * Branch ``br1`` is conv -> scale.
    * Branch ``br2`` is conv -> scale -> relu -> conv -> scale -> relu -> conv -> scale.
    * The two branch outputs are summed (elementwise SUM) and passed through a
      ReLU to a 256-channel output.

    For every match (numbered from 1) an ``RnRes2Br1Br2c_TRT`` plugin
    (version ``'2'``) is created from:

    * the ``c_br1`` / ``c_br2c`` convolution kernels;
    * the ``s_br1`` / ``s_br2c`` scale and shift weights;
    * five calibration dynamic ranges.

    The plugin is added with inputs ``[input, r_br2b output]`` and named
    ``RES2_BR1_BR2C_<n>``. Every other layer that consumed the unfused output
    tensor is rewired to the plugin's output. The unfused layers themselves
    are not removed by this function.

    Args:
        registry: plugin registry exposing ``get_plugin_creator``.
        network: TensorRT network definition; modified in place.

    Raises:
        Exception: if at least one match exists and the plugin creator is not
            registered, or if ``create_plugin`` returns ``None``.
    """
    pattern = [{"name":"input",  "type":trt.ITensor,               "children":["c_br1", "c_br2a"], "channels":64},
               {"name":"c_br1",  "type":trt.LayerType.CONVOLUTION, "children":"s_br1"},
               {"name":"s_br1",  "type":trt.LayerType.SCALE,       "children":"add"},
               {"name":"c_br2a", "type":trt.LayerType.CONVOLUTION, "children":"s_br2a"},
               {"name":"s_br2a", "type":trt.LayerType.SCALE,       "children":"r_br2a"},
               {"name":"r_br2a", "type":trt.LayerType.ACTIVATION,  "children":"c_br2b", "subtype":trt.ActivationType.RELU},
               {"name":"c_br2b", "type":trt.LayerType.CONVOLUTION, "children":"s_br2b"},
               {"name":"s_br2b", "type":trt.LayerType.SCALE,       "children":"r_br2b"},
               {"name":"r_br2b", "type":trt.LayerType.ACTIVATION,  "children":"c_br2c", "subtype":trt.ActivationType.RELU},
               {"name":"c_br2c", "type":trt.LayerType.CONVOLUTION, "children":"s_br2c"},
               {"name":"s_br2c", "type":trt.LayerType.SCALE,       "children":"add"},
               {"name":"add",    "type":trt.LayerType.ELEMENTWISE, "children":"relu",   "op":trt.ElementWiseOperation.SUM},
               {"name":"relu",   "type":trt.LayerType.ACTIVATION,  "children":"output", "subtype":trt.ActivationType.RELU},
               {"name":"output", "type":trt.ITensor,               "channels":256}]

    matches = ns.search(network, pattern)

    # Look the creator up once. The None check stays inside the loop so a
    # missing creator is only an error when there is something to fuse.
    creator = registry.get_plugin_creator('RnRes2Br1Br2c_TRT', '2', '')

    for match_number, match in enumerate(matches, start=1):
        plugin_name = "RES2_BR1_BR2C_" + str(match_number)

        # Dynamic ranges computed during calibration. The order presumably
        # has to match what the plugin expects; confirm against the plugin.
        dynamic_ranges = np.array([match["input"].get_dynamic_range(),
                                   match["s_br1"].get_output(0).get_dynamic_range(),
                                   match["c_br2c"].get_input(0).get_dynamic_range(),
                                   match["s_br2c"].get_output(0).get_dynamic_range(),
                                   match["output"].get_dynamic_range()], dtype=np.float32)

        # Plugin fields: weight/scale/bias/dynamic_range data.
        fields = trt.PluginFieldCollection()
        fields.append(trt.PluginField("c_br1_w", match["c_br1"].kernel.data, trt.PluginFieldType.FLOAT32))
        fields.append(trt.PluginField("s_br1_s", match["s_br1"].scale.data, trt.PluginFieldType.FLOAT32))
        fields.append(trt.PluginField("s_br1_b", match["s_br1"].shift.data, trt.PluginFieldType.FLOAT32))
        fields.append(trt.PluginField("c_br2c_w", match["c_br2c"].kernel.data, trt.PluginFieldType.FLOAT32))
        fields.append(trt.PluginField("s_br2c_s", match["s_br2c"].scale.data, trt.PluginFieldType.FLOAT32))
        fields.append(trt.PluginField("s_br2c_b", match["s_br2c"].shift.data, trt.PluginFieldType.FLOAT32))
        fields.append(trt.PluginField("dynamic_ranges", memoryview(dynamic_ranges), trt.PluginFieldType.FLOAT32))

        if creator is None:
            raise Exception("Creator for 'RnRes2Br1Br2c_TRT' not found")
        plugin = creator.create_plugin(plugin_name, fields)
        if plugin is None:
            raise Exception("Plugin creation failed")
        logging.info("Plugin creation successful")

        plugin_layer = network.add_plugin_v2([match["input"], match["r_br2b"].get_output(0)], plugin)
        plugin_layer.name = plugin_name

        unfused_output = match["output"]
        fused_output = plugin_layer.get_output(0)
        # get_dynamic_range() is used here as a single magnitude; the
        # symmetric range [-r, r] is copied onto the fused output.
        output_range = unfused_output.get_dynamic_range()
        fused_output.set_dynamic_range(-output_range, output_range)

        # Redirect every consumer of the unfused output (except the new
        # plugin itself) to the fused output.
        for i in range(network.num_layers):
            consumer = network.get_layer(i)
            if consumer.name == plugin_name:
                continue
            for j in range(consumer.num_inputs):
                if consumer.get_input(j) == unfused_output:
                    consumer.set_input(j, fused_output)
def fuse_serial_3_conv2dc_onnx(registry, network):
    """Fuse conv br2c + residual add + ReLU into a Serial3Conv2d_TRT plugin.

    Searches ``network`` for the sub-graph described by ``pattern1``:

    * A 512-channel input feeds both the elementwise SUM ``add`` and a chain
      conv -> relu -> conv -> relu -> conv (``c_br2a`` .. ``c_br2c``).
    * The chain's output is summed with the input and passed through a ReLU
      to a 512-channel output.

    Matches are processed in descending order of the last four characters of
    their input tensor's name; ties keep their search order. For each match
    (numbered from 1) a ``Serial3Conv2d_TRT`` plugin (version ``'1'``) is
    built from:

    * the ``c_br2c`` kernel and bias;
    * three calibration dynamic ranges.

    The plugin is added with inputs ``[input, r_br2b output]`` and named
    ``Serial3Conv2d_<n>``. Every other layer that consumed the unfused output
    tensor is rewired to the plugin's output. The unfused layers themselves
    are not removed by this function.

    Args:
        registry: plugin registry exposing ``get_plugin_creator``.
        network: TensorRT network definition; modified in place.

    Raises:
        Exception: if at least one match exists and the plugin creator is not
            registered, or if ``create_plugin`` returns ``None``.
    """
    pattern1 = [{
        "name": "input",
        "type": trt.ITensor,
        "children": ["add", "c_br2a"],
        "channels": 512
    }, {
        "name": "c_br2a",
        "type": trt.LayerType.CONVOLUTION,
        "children": "r_br2a"
    }, {
        "name": "r_br2a",
        "type": trt.LayerType.ACTIVATION,
        "children": "c_br2b",
        "subtype": trt.ActivationType.RELU
    }, {
        "name": "c_br2b",
        "type": trt.LayerType.CONVOLUTION,
        "children": "r_br2b"
    }, {
        "name": "r_br2b",
        "type": trt.LayerType.ACTIVATION,
        "children": "c_br2c",
        "subtype": trt.ActivationType.RELU
    }, {
        "name": "c_br2c",
        "type": trt.LayerType.CONVOLUTION,
        "children": "add"
    }, {
        "name": "add",
        "type": trt.LayerType.ELEMENTWISE,
        "children": "relu",
        "op": trt.ElementWiseOperation.SUM
    }, {
        "name": "relu",
        "type": trt.LayerType.ACTIVATION,
        "children": "output",
        "subtype": trt.ActivationType.RELU
    }, {
        "name": "output",
        "type": trt.ITensor,
        "channels": 512
    }]

    # Descending by the last four characters of the input tensor name.
    # sorted() is stable even with reverse=True, so equal keys keep their
    # search order, the same result the previous in-place bubble sort gave.
    matches = sorted(ns.search(network, pattern1),
                     key=lambda m: m['input'].name[-4:],
                     reverse=True)
    logging.info('matches: %d', len(matches))

    # Look the creator up once. The None check stays inside the loop so a
    # missing creator is only an error when there is something to fuse.
    creator = registry.get_plugin_creator('Serial3Conv2d_TRT', '1', '')

    for match_number, match in enumerate(matches, start=1):
        plugin_name = "Serial3Conv2d_" + str(match_number)

        # Dynamic ranges computed during calibration. The order presumably
        # has to match what the plugin expects; confirm against the plugin.
        dynamic_ranges = np.array(
            [
                match["input"].get_dynamic_range(),
                match["r_br2b"].get_output(0).get_dynamic_range(),
                match["output"].get_dynamic_range(),
            ],
            dtype=np.float32)
        logging.info(dynamic_ranges)

        # Plugin fields. This pattern has no separate scale layer, so
        # "s_br2c_b" carries the c_br2c convolution bias. The field name is
        # presumably kept because the plugin looks it up; verify before
        # renaming.
        fields = trt.PluginFieldCollection()
        fields.append(
            trt.PluginField("c_br2c_w", match["c_br2c"].kernel.data,
                            trt.PluginFieldType.FLOAT32))
        fields.append(
            trt.PluginField("s_br2c_b", match["c_br2c"].bias.data,
                            trt.PluginFieldType.FLOAT32))
        fields.append(
            trt.PluginField("dynamic_ranges", memoryview(dynamic_ranges),
                            trt.PluginFieldType.FLOAT32))

        if creator is None:
            raise Exception("Creator for 'Serial3Conv2d_TRT' not found")
        logging.info("Creating plugin %s", plugin_name)
        plugin = creator.create_plugin(plugin_name, fields)
        if plugin is None:
            raise Exception("Plugin creation failed")
        logging.info("Plugin creation successful")

        plugin_layer = network.add_plugin_v2(
            [match["input"], match["r_br2b"].get_output(0)], plugin)
        plugin_layer.name = plugin_name

        # Output of the residual add + ReLU that the plugin replaces.
        unfused_output = match["output"]
        fused_output = plugin_layer.get_output(0)
        # get_dynamic_range() is used here as a single magnitude; the
        # symmetric range [-r, r] is copied onto the fused output.
        output_range = unfused_output.get_dynamic_range()
        fused_output.set_dynamic_range(-output_range, output_range)

        # Redirect every consumer of the unfused output (except the new
        # plugin itself) to the fused output.
        for i in range(network.num_layers):
            consumer = network.get_layer(i)
            if consumer.name == plugin_name:
                continue
            for input_index in range(consumer.num_inputs):
                if consumer.get_input(input_index) == unfused_output:
                    logging.info('replace: %d', match_number)
                    consumer.set_input(input_index, fused_output)

    logging.info('complete all sub')