示例#1
0
 def replace(self, model, match):
     """Swap the matched module for an ``nnq.GELU``.

     Every config value of the matched op is forwarded as a keyword
     argument. ``approx_mode`` and ``approx_degree`` are also supplied,
     taken from the node's runtime spec config.
     """
     node = match.node
     op_kwargs = {key: node.op.get_config(key) for key in node.op.configs}
     runtime_config = node.spec.config
     replacement = nnq.GELU(
         **op_kwargs,
         approx_mode=runtime_config.approx_mode,
         approx_degree=runtime_config.approx_degree,
     )
     mod_util.replace_modules(model, node.name, replacement)
示例#2
0
 def replace(self, model, match):
     """Fuse a convolution and the batch norm that consumes it into one module.

     The matched node is the batch norm; its first input node is the
     convolution. Both modules are replaced in ``model`` by the single
     module returned from ``fuse_conv_bn(conv, bn, conv_spec)``. This is
     intended for the (Conv2d, BatchNorm2d) and (Conv3d, BatchNorm3d) pairs.
     NOTE(review): confirm whether transposed convolutions can reach here too.

     Returns:
       A dict mapping the batch norm's module name to ``conv_name + '.bn'``.
       This is presumably where the batch norm now lives inside the fused
       module; confirm against the caller.
     """
     conv_match, bn_match = match.inputs[0].node, match.node
     conv_name, bn_name = conv_match.name, bn_match.name
     conv, bn = conv_match.module, bn_match.module
     conv_bn = fuse_conv_bn(conv, bn, conv_match.spec)
     mod_util.replace_modules(model, [conv_name, bn_name], conv_bn)
     return {bn_name: conv_name + '.bn'}
示例#3
0
    def replace(self, model, match):
        """Replace an average-pooling module with its DPU counterpart.

        ``nn.AvgPool2d`` maps to ``nnq.DPUAvgPool2d`` and
        ``nn.AdaptiveAvgPool2d`` maps to ``nnq.DPUAdaptiveAvgPool2d``. The
        replacement is built from the matched op's config values, passed as
        keyword arguments. Any other module type raises ``KeyError`` from
        the lookup.
        """
        node = match.node
        pool_kwargs = {key: node.op.get_config(key) for key in node.op.configs}

        dpu_cls_for = {
            nn.AvgPool2d: nnq.DPUAvgPool2d,
            nn.AdaptiveAvgPool2d: nnq.DPUAdaptiveAvgPool2d,
        }
        dpu_cls = dpu_cls_for[type(node.module)]
        dpu_pool = dpu_cls(**pool_kwargs)
        mod_util.replace_modules(model, node.name, dpu_pool)
示例#4
0
 def replace(self, model, match):
     """Swap a float convolution module for its ``nnq`` quantized class.

     Supported types are ``nn.Conv2d``, ``nn.Conv3d``,
     ``nn.ConvTranspose2d`` and ``nn.ConvTranspose3d``. The replacement is
     created by the quantized class's ``from_float`` with the node's spec.
     Any other module type raises ``KeyError`` from the lookup.
     """
     quantized_cls_for = {
         nn.Conv2d: nnq.QuantizedConv2d,
         nn.Conv3d: nnq.QuantizedConv3d,
         nn.ConvTranspose2d: nnq.QuantizedConvTranspose2d,
         nn.ConvTranspose3d: nnq.QuantizedConvTranspose3d,
     }
     node = match.node
     float_module = node.module
     quantized_cls = quantized_cls_for[type(float_module)]
     quantized_module = quantized_cls.from_float(float_module, node.spec)
     mod_util.replace_modules(model, node.name, quantized_module)
示例#5
0
 def replace(self, model, match):
     """Swap a float convolution module for its ``nnqat`` counterpart.

     Supported types are ``nn.Conv2d``, ``nn.Conv3d``,
     ``nn.ConvTranspose2d`` and ``nn.ConvTranspose3d``. The replacement is
     created by the target class's ``from_float`` with the node's
     ``qconfig``. Any other module type raises ``KeyError`` from the lookup.
     """
     qat_cls_for = {
         nn.Conv2d: nnqat.QuantizedConv2d,
         nn.Conv3d: nnqat.QuantizedConv3d,
         nn.ConvTranspose2d: nnqat.QuantizedConvTranspose2d,
         nn.ConvTranspose3d: nnqat.QuantizedConvTranspose3d,
     }
     node = match.node
     float_module = node.module
     qat_cls = qat_cls_for[type(float_module)]
     qat_module = qat_cls.from_float(float_module, node.qconfig)
     mod_util.replace_modules(model, node.name, qat_module)
示例#6
0
 def replace(self, model, match):
     """Replace the matched module with ``nnq.LayerNorm.from_float`` of it.

     The node's spec is passed to ``from_float`` alongside the module.
     """
     node = match.node
     quantized_norm = nnq.LayerNorm.from_float(node.module, node.spec)
     mod_util.replace_modules(model, node.name, quantized_norm)
示例#7
0
 def replace(self, model, match):
     """Replace the matched module with an ``nnq.Sigmoid``.

     The new module is built from the node's runtime spec config. It
     receives ``approx_mode``, ``approx_degree`` and ``exp_table_size``
     positionally, in that order.
     """
     rt_cfg = match.node.spec.config
     sigmoid = nnq.Sigmoid(rt_cfg.approx_mode, rt_cfg.approx_degree,
                           rt_cfg.exp_table_size)
     mod_util.replace_modules(model, match.node.name, sigmoid)
示例#8
0
 def replace(self, model, match):
     """Replace the matched module with an ``nnq.DPULeakyReLU``.

     The matched op's config values are forwarded as keyword arguments.
     This is the same convention the other config-driven replacers use
     (e.g. GELU, DPU pooling).
     """
     op = match.node.op
     attrs = {name: op.get_config(name) for name in op.configs}
     # Unpack with ``**`` so the config *values* are passed by name.
     # ``*attrs`` would instead pass the dict's keys (the config names, as
     # strings) as positional arguments.
     relu = nnq.DPULeakyReLU(**attrs)
     mod_util.replace_modules(model, match.node.name, relu)
示例#9
0
 def replace(self, model, match):
     """Replace the matched module with ``nnq.QuantizedLinear.from_float`` of it.

     The node's spec is passed to ``from_float`` alongside the module.
     """
     node = match.node
     quantized_linear = nnq.QuantizedLinear.from_float(node.module, node.spec)
     mod_util.replace_modules(model, node.name, quantized_linear)