예제 #1
0
def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Fold each BatchNorm layer into the convolution that feeds it.

    The model is symbolically traced with ``torch.fx``. Every
    ``ConvNd -> BatchNormNd`` pair (N = 1, 2, 3) is folded into a single
    convolution built by ``fuse_conv_bn_eval``, provided that:

    * the BatchNorm tracks running statistics, and
    * the convolution's output is consumed by that BatchNorm only.

    Intended for inference. NOTE(review): ``fuse_conv_bn_eval`` is
    presumably the torch helper, which requires eval mode -- confirm.

    Args:
        model: module to fuse.
        inplace: when False (default), ``model`` is deep-copied before
            tracing. When True, no copy is made, so ``model`` may be modified.

    Returns:
        A new ``fx.GraphModule`` built from the rewritten graph.
    """
    conv_bn_patterns = [
        (nn.Conv1d, nn.BatchNorm1d),
        (nn.Conv2d, nn.BatchNorm2d),
        (nn.Conv3d, nn.BatchNorm3d),
    ]
    source = model if inplace else copy.deepcopy(model)
    traced = fx.symbolic_trace(source)
    module_table = dict(traced.named_modules())
    graph = copy.deepcopy(traced.graph)

    for conv_bn in conv_bn_patterns:
        for bn_node in graph.nodes:
            if not matches_module_pattern(conv_bn, bn_node, module_table):
                continue
            conv_node = bn_node.args[0]
            # Folding would alter what any other consumer of the conv sees.
            if len(conv_node.users) > 1:
                continue
            conv_module = module_table[conv_node.target]
            bn_module = module_table[bn_node.target]
            # Without running statistics there is nothing to fold.
            if not bn_module.track_running_stats:
                continue
            replace_node_module(conv_node, module_table,
                                fuse_conv_bn_eval(conv_module, bn_module))
            # Route the BN's consumers to the (now fused) conv, then drop BN.
            bn_node.replace_all_uses_with(conv_node)
            graph.erase_node(bn_node)
    return fx.GraphModule(traced, graph)
예제 #2
0
  def test_register_forward_hook(self):
    """
    Check that a hand-written Conv2d+BatchNorm2d fusion matches both the
    unfused pair and ``torch.nn.utils.fusion.fuse_conv_bn_eval``.

    Reference:
    https://learnml.today/speeding-up-model-with-fusing-batch-normalization-and-convolution-3

    NOTE(review): despite its name, this test does not exercise forward
    hooks. It downloads pretrained VGG11-BN weights, so it needs network
    access (or a warm torchvision cache).
    """
    import torch
    import torchvision
    from torch.nn.utils import fusion

    @torch.no_grad()
    def fuse(conv, bn):
      """Return a new Conv2d equal to ``bn(conv(x))`` with BN in eval mode."""
      # Copy every geometry argument so grouped/dilated convs fuse correctly;
      # without ``groups`` the fused weight shape would not even match.
      fused = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,
        padding_mode=conv.padding_mode,
      ).to(device=conv.weight.device, dtype=conv.weight.dtype)

      # Weight: scale each output channel by gamma / sqrt(running_var + eps).
      w_conv = conv.weight.clone().view(conv.out_channels, -1)
      w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
      fused.weight.copy_(torch.mm(w_bn, w_conv).view(fused.weight.size()))

      # Bias: a bias-free conv behaves like one with a zero bias; build it
      # on the same device/dtype as the BN statistics.
      if conv.bias is not None:
        b_conv = conv.bias
      else:
        b_conv = torch.zeros_like(bn.running_mean)

      b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
        torch.sqrt(bn.running_var + bn.eps))
      fused.bias.copy_(w_bn @ b_conv + b_bn)
      return fused

    def max_abs_err(a, b):
      """Largest element-wise deviation (a signed mean lets errors cancel)."""
      return (a - b).abs().max().item()

    # Scope gradient disabling to this test. torch.set_grad_enabled(False)
    # would leave autograd off for every test that runs afterwards.
    with torch.no_grad():
      x = torch.randn(16, 3, 256, 256)
      # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13 in
      # favour of ``weights=``; kept for compatibility with older versions.
      pretrained_net = torchvision.models.vgg11_bn(pretrained=True)
      # eval(): BN must use its running statistics for fusion to be exact.
      pretrained_net.eval()
      model = torch.nn.Sequential(
        pretrained_net.features[0],
        pretrained_net.features[1]
      )
      f1 = model(x)
      print(model[0].weight.min(), model[0].weight.max(),
            model[0].bias.min(), model[0].bias.max())

      fused = fuse(model[0], model[1])
      print(fused.weight.min(), fused.weight.max(),
            fused.bias.min(), fused.bias.max())

      f2 = fused(x)
      err = max_abs_err(f1, f2)
      print("error:", err)
      assert torch.allclose(f1, f2, rtol=1e-3, atol=1e-3), err

      fused_2 = fusion.fuse_conv_bn_eval(model[0], model[1])
      f3 = fused_2(x)
      err = max_abs_err(f1, f3)
      print("error:", err)
      assert torch.allclose(f1, f3, rtol=1e-3, atol=1e-3), err
예제 #3
0
 def get_importance(self):
     """
     Return the element-wise importance ``|W * mask|`` of the wrapped conv.

     If ``self.module.bn_module`` is truthy, that BatchNorm is first folded
     into ``self.module.org_module`` with ``fuse_conv_bn_eval``, and ``W``
     is the fused weight. Otherwise ``W`` is ``org_module``'s own weight.
     ``mask`` is ``self.module.mask``.
     """
     wrapper = self.module
     bn = wrapper.bn_module
     if not bn:
         weight_source = wrapper.org_module
     else:
         weight_source = fuse_conv_bn_eval(wrapper.org_module, bn)
     return torch.abs(weight_source.weight * wrapper.mask)
예제 #4
0
def parameters_extractor(model, ext_config, result_path="", fuse=False):
    """
    Extract layer properties, weights and biases into C++ header files.

    Two headers are written into ``result_path``:

    * ``ann_config.hpp``: the global ``#define``s taken from ``ext_config``,
      followed by the per-layer configuration emitted by the layer parsers
      for ``model.features`` and ``model.classifier``.
    * ``ann_weight_bias_config.hpp``: weight/bias data for every exported
      convolution and quantized linear layer.

    Keyword arguments:
    model -- the model object; its ``features`` and ``classifier``
             containers are walked.
    ext_config -- dictionary of special parameters. The keys ``'DATAWIDTH'``,
                  ``'CLASS_LABEL_BITS'`` and ``'SEQUENCE_LENGTH'`` are read
                  here, and the whole dict is forwarded to the layer parsers.
    result_path -- directory in which both headers are created (default: the
                   current directory).
    fuse -- when True, a ``QuantConv2d`` directly followed by a layer that
            carries BatchNorm running statistics is folded with it via
            ``fuse_conv_bn_eval`` before export. A convolution with no such
            follower is exported unfused.

    Returns the paths of the two written files, joined by a newline.

    Side effects: reads and updates the module-level layer counters and
    ``pre_layer``. ``conv2d_counter`` and ``fullyconn_counter`` are reset to
    0 before the weight/bias pass.
    """
    res_path = None
    global conv2d_counter
    global maxpool2d_counter
    global quantrelu_counter
    global fullyconn_counter
    global pre_layer
    with open(Path(result_path) / "ann_config.hpp", 'w') as file_object:
        file_object.write("#ifndef ANN_CONFIG_H_\n#define ANN_CONFIG_H_\n\n")
        file_object.write("{:<48}{}\n".format("#define DATAWIDTH",
                                              ext_config['DATAWIDTH']))
        file_object.write("{:<48}{}\n".format("#define CLASS_LABEL_BITS",
                                              ext_config['CLASS_LABEL_BITS']))
        file_object.write("{:<48}{}\n\n\n".format(
            "#define SEQUENCE_LENGTH", ext_config['SEQUENCE_LENGTH']))

        # Extract Features layers Data.
        # The children are materialised so the layer after a convolution can
        # be inspected before deciding whether to fuse it.
        feature_layers = list(model.features.children())
        conv_layer_list = []
        with tqdm(total=len(model.features),
                  desc='Extracting features parameters') as pbar:
            idx = 0
            while idx < len(feature_layers):
                i = feature_layers[idx]
                idx += 1
                if isinstance(i, brevitas.nn.quant_conv.QuantConv2d):
                    if fuse:
                        bn = (feature_layers[idx]
                              if idx < len(feature_layers) else None)
                        # Only fold a follower that has BatchNorm running
                        # statistics. Any other layer (or none) is left to be
                        # exported on its own in the next iteration.
                        if getattr(bn, "running_var", None) is not None:
                            # NOTE(review): fuse_conv_bn_eval presumably
                            # requires both layers in eval mode -- confirm
                            # callers invoke model.eval() first.
                            i = fuse_conv_bn_eval(i, bn)
                            idx += 1  # the BatchNorm was absorbed; skip it
                            pbar.update()
                            print("Fusing BatchNorm2d with Conv2d layer:{}".format(
                                conv2d_counter))
                        else:
                            print("No BatchNorm after Conv2d layer:{}, "
                                  "exporting it unfused".format(conv2d_counter))
                    quant_conv2d_parser(i, file_object, ext_config)
                    conv_layer_list.append(i)
                    pre_layer = "CONV2D_{}_".format(conv2d_counter)
                    conv2d_counter += 1
                elif isinstance(i, torch.nn.modules.pooling.MaxPool2d):
                    maxpool2d_parser(i, file_object, ext_config)
                    pre_layer = "MAXPOOL2D_{}_".format(maxpool2d_counter)
                    maxpool2d_counter += 1
                elif isinstance(i, brevitas.nn.quant_activation.QuantReLU):
                    quantReLU_parser(i, file_object, ext_config)
                    pre_layer = "RELU_{}_".format(quantrelu_counter)
                    quantrelu_counter += 1
                else:
                    print("Faced an Unknown layer:\n", type(i))
                pbar.update()

        # Extract classifier layers Data
        for i in tqdm(model.classifier,
                      desc='Extracting classifer parameters'):
            if isinstance(i, brevitas.nn.QuantLinear):
                fullyconn_parser(i, file_object, ext_config)
                pre_layer = "FC_{}_".format(fullyconn_counter)
                fullyconn_counter += 1
            elif isinstance(i, brevitas.nn.quant_activation.QuantReLU):
                quantReLU_parser(i, file_object, ext_config)
                pre_layer = "RELU_{}_".format(quantrelu_counter)
                quantrelu_counter += 1
        # A C/C++ source file must end with a newline.
        file_object.write("#endif\n")
        res_path = file_object.name

    with open(Path(result_path) / "ann_weight_bias_config.hpp",
              'w') as file_object:
        file_object.write(
            "#ifndef ANN_WEIGHT_BIAS_CONFIG_H_\n#define ANN_WEIGHT_BIAS_CONFIG_H_\n"
        )
        file_object.write('#include "ann_config.hpp"\n\n')
        # Extract Conv layers Weight & Bias. The global counter is restarted
        # so the writers presumably number layers as the config pass did.
        conv2d_counter = 0
        for i in tqdm(conv_layer_list,
                      desc='Extracting conv layers weight & bias'):
            if FINN_STRUCTURES:
                conv_weight_bias_finn(i, file_object,
                                      int(i.quant_weight_bit_width()),
                                      int(i.quant_bias_bit_width()))
            else:
                conv_weight_bias_array(i, file_object,
                                       int(i.quant_weight_bit_width()),
                                       int(i.quant_bias_bit_width()))
            conv2d_counter += 1

        # Extract linear layers Weight & Bias
        fullyconn_counter = 0
        for i in tqdm(model.classifier,
                      desc='Extracting linear layers weight & bias'):
            if isinstance(i, brevitas.nn.quant_linear.QuantLinear):
                if FINN_STRUCTURES:
                    linear_weight_bias_finn(i, file_object,
                                            int(i.quant_weight_bit_width()),
                                            int(i.quant_bias_bit_width()))
                else:
                    linear_weight_bias_array(i, file_object,
                                             int(i.quant_weight_bit_width()),
                                             int(i.quant_bias_bit_width()))
                fullyconn_counter += 1
        # A C/C++ source file must end with a newline.
        file_object.write("#endif\n")
        res_path = res_path + "\n" + file_object.name

    return res_path