def fuse(self, bn):
    if self.groups != 1:  # grouped conv: conv and bn channel orders don't match
        s = self.conv.weight.shape  # remember original shape: cho, chi/g, k, k

        # shuffle conv weight in the 'cho' dimension to match 'bn'
        x = self.conv.weight.reshape(self.conv.out_channels, -1)  # cho, (chi/g)*k^2
        x = x.reshape(self.groups, self.conv.out_channels // self.groups, -1)  # g, n, (chi/g)*k^2
        x = torch.transpose(x, 1, 0)  # n, g, (chi/g)*k^2
        x = x.reshape(self.conv.out_channels, -1)  # cho, (chi/g)*k^2 but shuffled
        self.conv.weight = nn.Parameter(x.reshape(*s))  # reshape copies; re-wrap as Parameter before re-assigning

        self.conv = fuse_conv_and_bn(self.conv, bn)  # now weights match

        # shuffle conv weight in the 'cho' dimension back to the initial order
        x = self.conv.weight.reshape(self.conv.out_channels, -1)  # cho, (chi/g)*k^2
        x = x.reshape(self.conv.out_channels // self.groups, self.groups, -1)  # n, g, (chi/g)*k^2
        x = torch.transpose(x, 1, 0)  # g, n, (chi/g)*k^2
        x = x.reshape(self.conv.out_channels, -1)  # cho, (chi/g)*k^2, original order
        self.conv.weight = nn.Parameter(x.reshape(*s))

        # the fused bias is produced in the shuffled order too, so permute it back the same way
        b = self.conv.bias.reshape(self.conv.out_channels // self.groups, self.groups)  # n, g
        self.conv.bias = nn.Parameter(torch.transpose(b, 1, 0).reshape(-1))  # back to cho order
    else:  # straightforward case
        self.conv = fuse_conv_and_bn(self.conv, bn)
    return self
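
# For reference, a minimal sketch of the fuse_conv_and_bn() helper assumed
# throughout this file. It folds the BatchNorm affine transform into the conv's
# weight and bias. This mirrors the widely used YOLOv5-style implementation and
# is illustrative; the repo's actual helper may differ in detail.
def fuse_conv_and_bn(conv, bn):
    # fresh conv with identical geometry, but with a bias term to absorb BN's shift
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # fuse weights: w_fused = diag(gamma / sqrt(running_var + eps)) @ w_conv
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # fuse biases: b_fused = w_bn @ b_conv + (beta - gamma * running_mean / sqrt(running_var + eps))
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv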
def fuse(self):  # merge batchnorm into convolution for an inference speed-up
    if not isinstance(self.conv, nn.Conv2d):
        self.conv = self.conv.fuse(self.bn)  # each custom BASE has its own fuse method
    else:
        self.conv = fuse_conv_and_bn(self.conv, self.bn)
    delattr(self, 'bn')  # remove batchnorm
    self.forward = self.fuseforward  # update forward
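
# fuseforward is assumed to be defined on the module; a minimal sketch of what
# it typically looks like (YOLOv5-style, hypothetical for this repo):
#     def fuseforward(self, x):
#         return self.act(self.conv(x))  # bn has been folded into conv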
def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
    print('Fusing layers... ')
    for m in self.model.modules():
        if type(m) is Conv and hasattr(m, 'bn'):
            m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
            delattr(m, 'bn')  # remove batchnorm
            m.forward = m.fuseforward  # update forward
    self.info()
    return self
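
# Typical usage (illustrative; 'model' stands for any module exposing the
# fuse() above):
#     model.eval()   # BN must use running stats; fusion is an inference-only step
#     model.fuse()   # one-time, in-place fusion of Conv2d + BatchNorm2d pairs
#     with torch.no_grad():
#         y = model(x)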
import torch  # standalone test; imports included here
import torchvision


def test_fuse_conv_and_bn():
    x = torch.randn(16, 3, 256, 256)
    rn18 = torchvision.models.resnet18(pretrained=True)  # note: torchvision >= 0.13 prefers weights=...
    rn18.eval()  # BN must be in eval mode so running stats are used
    net = torch.nn.Sequential(rn18.conv1, rn18.bn1)
    with torch.no_grad():
        y1 = net(x)
        fusedconv = fuse_conv_and_bn(net[0], net[1])
        y2 = fusedconv(x)
    d = (y1 - y2).norm().div(y1.norm()).item()
    print('fuse relative error: %.8f' % d)
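
# The relative error should sit at float32 rounding level (roughly 1e-7). In a
# test harness you would assert rather than print, e.g. (sketch):
#     assert d < 1e-6, 'fuse_conv_and_bn() output diverges from conv+bn reference'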
def fuse(self, bn):  # simplest case: plain conv + bn, no channel reordering needed
    self.conv = fuse_conv_and_bn(self.conv, bn)
    return self