import copy

import torch
import torch.fx as fx
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval


def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Fuses convolution/BN layers for inference purposes. Will deepcopy your
    model by default, but can modify the model inplace as well.
    """
    patterns = [(nn.Conv1d, nn.BatchNorm1d),
                (nn.Conv2d, nn.BatchNorm2d),
                (nn.Conv3d, nn.BatchNorm3d)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for pattern in patterns:
        for node in new_graph.nodes:
            # `matches_module_pattern` and `replace_node_module` are the
            # graph helpers from torch.fx.experimental.optimization, where
            # this pass is defined.
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                # Fusion needs the BN running statistics; skip layers that
                # do not track them.
                if not bn.track_running_stats:
                    continue
                fused_conv = fuse_conv_bn_eval(conv, bn)
                # Swap in the fused conv, reroute the BN node's users to the
                # conv node, then drop the now-dead BN node from the graph.
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                new_graph.erase_node(node)
    return fx.GraphModule(fx_model, new_graph)
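def _demo_fuse():
    # A minimal usage sketch for fuse() above. The helper name `_demo_fuse`
    # and the ResNet-18 choice are illustrative, not part of the original
    # pass. In eval mode the fused copy should match the original output up
    # to float tolerance.
    import torch
    import torchvision

    model = torchvision.models.resnet18(pretrained=True).eval()
    fused_model = fuse(model)  # deepcopies by default; inplace=True mutates
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        print((model(x) - fused_model(x)).abs().max())  # expect ~1e-6..1e-5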
def test_register_forward_hook(self):
    """
    https://learnml.today/speeding-up-model-with-fusing-batch-normalization-and-convolution-3
    """
    import torch
    import torchvision
    from torch.nn.utils import fusion

    @torch.no_grad()
    def fuse(conv, bn):
        fused = torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            bias=True
        )

        # Set the fused weights: W_fused = W_bn @ W_conv (per output channel).
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fused.weight.copy_(torch.mm(w_bn, w_conv).view(fused.weight.size()))

        # Set the fused bias: b_fused = W_bn @ b_conv + b_bn.
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros(conv.weight.size(0))
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
            torch.sqrt(bn.running_var + bn.eps))
        fused.bias.copy_(w_bn @ b_conv + b_bn)

        return fused

    # Testing: gradient tracking is unnecessary for this inference-only check.
    torch.set_grad_enabled(False)

    x = torch.randn(16, 3, 256, 256)
    pretrained_net = torchvision.models.vgg11_bn(pretrained=True)
    # eval() freezes the BatchNorm running statistics that the fusion reads.
    pretrained_net.eval()
    model = torch.nn.Sequential(
        pretrained_net.features[0],  # Conv2d
        pretrained_net.features[1],  # BatchNorm2d
    )
    f1 = model.forward(x)
    print(model[0].weight.min(), model[0].weight.max(),
          model[0].bias.min(), model[0].bias.max())

    fused = fuse(model[0], model[1])
    print(fused.weight.min(), fused.weight.max(),
          fused.bias.min(), fused.bias.max())
    f2 = fused.forward(x)
    # Max absolute error; a signed mean would let errors cancel out.
    d = (f1 - f2).abs().max().item()
    print("error:", d)

    # Cross-check against the built-in implementation.
    fused_2 = fusion.fuse_conv_bn_eval(model[0], model[1])
    f2 = fused_2.forward(x)
    d = (f1 - f2).abs().max().item()
    print("error:", d)
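# Why the fusion above works (a sketch of the algebra, not project code):
# for a conv output y and per-channel BN statistics (mu, var, gamma, beta),
#     BN(y) = gamma * (y - mu) / sqrt(var + eps) + beta
# Substituting y = W*x + b and letting s = gamma / sqrt(var + eps):
#     BN(Conv(x)) = (s * W) * x + (s * (b - mu) + beta)
# which is exactly a Conv2d with weight s*W (w_bn @ w_conv above) and bias
# s*(b - mu) + beta (w_bn @ b_conv + b_bn above).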
def get_importance(self):
    # Score weights by the magnitude of the effective inference-time values:
    # fold the BatchNorm into the conv first (when present), then apply the
    # pruning mask.
    if self.module.bn_module:
        fused_conv = fuse_conv_bn_eval(self.module.org_module,
                                       self.module.bn_module)
    else:
        fused_conv = self.module.org_module
    return torch.abs(fused_conv.weight * self.module.mask)
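# A minimal sketch of the wrapper shape get_importance() expects. The names
# `org_module`, `bn_module`, and `mask` mirror the attributes used above;
# the SimpleNamespace stand-in is illustrative, not the project's real
# wrapper class.
import torch
import torch.nn as nn
from types import SimpleNamespace
from torch.nn.utils.fusion import fuse_conv_bn_eval

conv = nn.Conv2d(3, 8, 3).eval()
bn = nn.BatchNorm2d(8).eval()
module = SimpleNamespace(org_module=conv, bn_module=bn,
                         mask=torch.ones_like(conv.weight))
# With BN folded in, the importance scores track the effective inference
# weights rather than the raw (pre-BN) conv weights:
importance = torch.abs(fuse_conv_bn_eval(conv, bn).weight * module.mask)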
def parameters_extractor(model, ext_config, result_path="", fuse=False):
    """
    Extracts layer properties, weights & biases and writes the result to .h
    files with the model name under the same path.

    Keyword arguments:
    model -- the model object
    ext_config -- config dictionary containing special parameters
    result_path -- directory in which the generated headers are written
    fuse -- if True, fold each BatchNorm2d into the preceding QuantConv2d
    """
    res_path = None
    global conv2d_counter
    global maxpool2d_counter
    global quantrelu_counter
    global fullyconn_counter
    global pre_layer

    with open(Path(result_path) / "ann_config.hpp", 'w') as file_object:
        file_object.write("#ifndef ANN_CONFIG_H_\n#define ANN_CONFIG_H_\n\n")
        file_object.write("{:<48}{}\n".format("#define DATAWIDTH",
                                              ext_config['DATAWIDTH']))
        file_object.write("{:<48}{}\n".format("#define CLASS_LABEL_BITS",
                                              ext_config['CLASS_LABEL_BITS']))
        file_object.write("{:<48}{}\n\n\n".format("#define SEQUENCE_LENGTH",
                                                  ext_config['SEQUENCE_LENGTH']))

        # Extract feature-layer data
        features_iter = model.features.children()
        conv_layer_list = []
        with tqdm(total=len(model.features),
                  desc='Extracting features parameters') as pbar:
            i = next(features_iter, None)
            while i is not None:
                if isinstance(i, brevitas.nn.quant_conv.QuantConv2d):
                    if fuse:
                        # The BatchNorm2d directly follows the conv; consume
                        # it from the iterator and fold it in.
                        bn = next(features_iter)
                        i = fuse_conv_bn_eval(i, bn)
                        pbar.update()
                        print("Fusing BatchNorm2d with Conv2d layer:{}".format(
                            conv2d_counter))
                    quant_conv2d_parser(i, file_object, ext_config)
                    conv_layer_list.append(i)
                    pre_layer = "CONV2D_{}_".format(conv2d_counter)
                    conv2d_counter += 1
                elif isinstance(i, torch.nn.modules.pooling.MaxPool2d):
                    maxpool2d_parser(i, file_object, ext_config)
                    pre_layer = "MAXPOOL2D_{}_".format(maxpool2d_counter)
                    maxpool2d_counter += 1
                elif isinstance(i, brevitas.nn.quant_activation.QuantReLU):
                    quantReLU_parser(i, file_object, ext_config)
                    pre_layer = "RELU_{}_".format(quantrelu_counter)
                    quantrelu_counter += 1
                else:
                    print("Faced an unknown layer:\n", type(i))
                i = next(features_iter, None)
                pbar.update()

        # Extract classifier-layer data
        for i in tqdm(model.classifier, desc='Extracting classifier parameters'):
            if isinstance(i, brevitas.nn.QuantLinear):
                fullyconn_parser(i, file_object, ext_config)
                pre_layer = "FC_{}_".format(fullyconn_counter)
                fullyconn_counter += 1
            elif isinstance(i, brevitas.nn.quant_activation.QuantReLU):
                quantReLU_parser(i, file_object, ext_config)
                pre_layer = "RELU_{}_".format(quantrelu_counter)
                quantrelu_counter += 1
        file_object.write("#endif")
        res_path = file_object.name

    with open(Path(result_path) / "ann_weight_bias_config.hpp", 'w') as file_object:
        file_object.write(
            "#ifndef ANN_WEIGHT_BIAS_CONFIG_H_\n#define ANN_WEIGHT_BIAS_CONFIG_H_\n")
        file_object.write('#include "ann_config.hpp"\n\n')

        # Extract conv layers' weights & biases
        conv2d_counter = 0
        for i in tqdm(conv_layer_list, desc='Extracting conv layers weight & bias'):
            if FINN_STRUCTURES:
                conv_weight_bias_finn(i, file_object,
                                      int(i.quant_weight_bit_width()),
                                      int(i.quant_bias_bit_width()))
            else:
                conv_weight_bias_array(i, file_object,
                                       int(i.quant_weight_bit_width()),
                                       int(i.quant_bias_bit_width()))
            conv2d_counter += 1

        # Extract linear layers' weights & biases
        fullyconn_counter = 0
        for i in tqdm(model.classifier, desc='Extracting linear layers weight & bias'):
            if isinstance(i, brevitas.nn.quant_linear.QuantLinear):
                if FINN_STRUCTURES:
                    linear_weight_bias_finn(i, file_object,
                                            int(i.quant_weight_bit_width()),
                                            int(i.quant_bias_bit_width()))
                else:
                    linear_weight_bias_array(i, file_object,
                                             int(i.quant_weight_bit_width()),
                                             int(i.quant_bias_bit_width()))
                fullyconn_counter += 1
        file_object.write("#endif")
        res_path = res_path + "\n" + file_object.name

    return res_path
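# A hedged usage sketch of parameters_extractor(). `quant_model` is a
# placeholder for a Brevitas VGG-style network exposing .features
# (QuantConv2d/BatchNorm2d/MaxPool2d/QuantReLU) and .classifier
# (QuantLinear/QuantReLU); the config values are illustrative, only the keys
# mirror what the function reads above.
#
#     ext_config = {
#         'DATAWIDTH': 32,
#         'CLASS_LABEL_BITS': 4,
#         'SEQUENCE_LENGTH': 1024,
#     }
#     headers = parameters_extractor(quant_model, ext_config,
#                                    result_path="generated", fuse=True)
#     print(headers)  # newline-separated paths of the two generated .hpp files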