def fuse(self, bn):
    if self.groups != 1:  # grouped conv: conv weight order and bn weight order don't match
        s = self.conv.weight.shape  # remember original shape: cho, chi, k, k
        # shuffle conv weight in 'cho' dimension to match 'bn'
        x = self.conv.weight.reshape(self.conv.out_channels, -1)  # cho, chi*k^2
        x = x.reshape(self.groups, self.cho // self.groups, -1)  # g, n, chi*k^2
        x = torch.transpose(x, 1, 0)  # n, g, chi*k^2
        x = x.reshape(self.conv.out_channels, -1)  # cho, chi*k^2 but shuffled
        self.conv.weight = nn.Parameter(x.reshape(*s))  # reshape copies, re-assign as Parameter
        self.conv = fuse_conv_and_bn(self.conv, bn)  # now weights match
        # shuffle conv weight in 'cho' dimension back to the initial order
        x = self.conv.weight.reshape(self.conv.out_channels, -1)  # cho, chi*k^2
        x = x.reshape(self.cho // self.groups, self.groups, -1)  # n, g, chi*k^2
        x = torch.transpose(x, 1, 0)  # g, n, chi*k^2
        x = x.reshape(self.conv.out_channels, -1)  # cho, chi*k^2
        self.conv.weight = nn.Parameter(x.reshape(*s))  # reshape copies, re-assign as Parameter
    else:  # straightforward case
        self.conv = fuse_conv_and_bn(self.conv, bn)
    return self
def fuse(self):  # merge batchnorm and convolution for inference speed-up
    if not isinstance(self.conv, nn.Conv2d):
        self.conv = self.conv.fuse(self.bn)  # each custom BASE module has its own fuse method
    else:
        self.conv = fuse_conv_and_bn(self.conv, self.bn)
    delattr(self, 'bn')  # remove batchnorm
    self.forward = self.fuseforward  # update forward
def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
    print('Fusing layers... ')
    for m in self.model.modules():
        if type(m) is Conv and hasattr(m, 'bn'):
            m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
            delattr(m, 'bn')  # remove batchnorm
            m.forward = m.fuseforward  # update forward
    self.info()
    return self
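# Hedged usage sketch (not from the source): how the model-level fuse() above is
# typically called once after loading weights and before inference.
# 'DetectionModel', 'model.yaml' and 'weights.pt' are hypothetical placeholders
# for whatever model class and checkpoint the repo actually uses.
model = DetectionModel('model.yaml')                # hypothetical constructor
model.load_state_dict(torch.load('weights.pt'))     # hypothetical checkpoint
model.eval()
model.fuse()                                        # folds Conv2d + BatchNorm2d in place
with torch.no_grad():
    y = model(torch.randn(1, 3, 640, 640))          # fused forward: same outputs, fewer ops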
def test_fuse_conv_and_bn():
    x = torch.randn(16, 3, 256, 256)
    rn18 = torchvision.models.resnet18(pretrained=True)
    rn18.eval()
    net = torch.nn.Sequential(rn18.conv1, rn18.bn1)
    y1 = net(x)
    fusedconv = fuse_conv_and_bn(net[0], net[1])
    y2 = fusedconv(x)
    d = (y1 - y2).norm().div(y1.norm()).item()
    print('fuse relative error: %.8f' % d)
def fuse(self, bn):
    self.conv = fuse_conv_and_bn(self.conv, bn)
    return self
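# Minimal sketch of what a fuse_conv_and_bn(conv, bn) helper usually does; the repo's
# actual implementation may differ in details, so treat this as an assumption.
# In eval mode, BatchNorm computes y = gamma * (x - mean) / sqrt(var + eps) + beta,
# which folds into the preceding conv as
#   W' = diag(gamma / sqrt(var + eps)) @ W
#   b' = beta + (b - mean) * gamma / sqrt(var + eps)
import torch
import torch.nn as nn

def fuse_conv_and_bn(conv, bn):
    fused = nn.Conv2d(conv.in_channels, conv.out_channels,
                      kernel_size=conv.kernel_size, stride=conv.stride,
                      padding=conv.padding, dilation=conv.dilation,
                      groups=conv.groups, bias=True).requires_grad_(False).to(conv.weight.device)

    # fold the bn scale into the conv weight
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.running_var + bn.eps)))
    fused.weight.copy_(torch.mm(w_bn, w_conv).view(fused.weight.shape))

    # fold the bn shift into the conv bias
    b_conv = torch.zeros(conv.out_channels, device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fused.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fused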