def _test_fuse_conv(self, conv_args, bn_args, relu_args, bn_class, input_size, device="cpu"):
    """Check that fusing a ConvBNRelu removes `bn_class` and preserves outputs.

    The block is run on a few random batches before being switched to eval
    mode. Presumably it is in training mode at that point, so norm-layer
    running statistics get updated. It is then fused out-of-place and
    compared against the unfused model with `run_and_compare`.
    """
    model = bb.ConvBNRelu(
        3,
        6,
        kernel_size=3,
        padding=1,
        conv_args=conv_args,
        bn_args=bn_args,
        relu_args=relu_args,
    )
    model = model.to(device)

    # Warm-up forward passes before switching to eval mode.
    warmup_steps = 3
    step = 0
    while step < warmup_steps:
        batch = torch.rand(input_size).to(device)
        model(batch)
        step += 1

    model.eval()
    fused_model = fuse_utils.fuse_convbnrelu(model, inplace=False)

    # Fusion is out-of-place: the source keeps its norm layer, the copy must not.
    self.assertTrue(_find_modules(model, bn_class))
    self.assertFalse(_find_modules(fused_model, bn_class))

    run_and_compare(model, fused_model, input_size, device)
def test_conv_bn_relu_upsample_empty_input(self):
    """stride=-2 maps 4x4 to 8x8 even when the batch dimension is empty."""
    expected_shape = torch.Size([0, 4, 8, 8])
    upsample_op = bb.ConvBNRelu(4, 4, stride=-2, kernel_size=3, padding=1)
    empty_batch = torch.rand([0, 4, 4, 4])
    self.assertEqual(upsample_op(empty_batch).shape, expected_shape)
def __init__(self, in_channels, out_channels, *, check_out_channels, **kwargs):
    """Build a basic_blocks.ConvBNRelu after checking its width divisor.

    Args:
        in_channels: forwarded to ConvBNRelu.
        out_channels: forwarded to ConvBNRelu.
        check_out_channels: stored on the instance as-is.
        **kwargs: forwarded to ConvBNRelu; must include `width_divisor`
            equal to WIDTH_DIVISOR.
    """
    super().__init__()
    # The dict lookup lives inside the assert on purpose: under `python -O`
    # the whole check, including the lookup, is skipped.
    assert kwargs["width_divisor"] == WIDTH_DIVISOR
    self.check_out_channels = check_out_channels
    conv_block = basic_blocks.ConvBNRelu(in_channels, out_channels, **kwargs)
    self.conv = conv_block
def test_conv_bn_relu_upsample(self):
    """stride=-2 should upsample a 4x4 feature map to 8x8 for a batch of 1.

    This test uses a non-empty batch and a regular (non-grouped) conv.
    """
    op = bb.ConvBNRelu(4, 4, stride=-2, kernel_size=3, padding=1)
    input_size = [1, 4, 4, 4]
    inputs = torch.rand(input_size)
    output = op(inputs)
    self.assertEqual(output.shape, torch.Size([1, 4, 8, 8]))
def test_conv_bn_relu_empty_input(self):
    """A depthwise (groups=4), stride-2 ConvBNRelu handles empty and non-empty batches.

    Both batch sizes must map a 4x4 input to a 2x2 output.

    NOTE(review): an earlier comment here claimed that empty batches are not
    supported for depthwise conv. This test asserts that a batch of 0 does
    work with groups=4; presumably ConvBNRelu works around a lower-level
    limitation. Confirm which statement is current.
    """
    op = bb.ConvBNRelu(4, 4, stride=2, kernel_size=3, padding=1, groups=4)
    for batch in (0, 2):
        output = op(torch.rand([batch, 4, 4, 4]))
        self.assertEqual(output.shape, torch.Size([batch, 4, 2, 2]))
def test_fuse_convbnrelu(self):
    """Fusing an eval-mode ConvBNRelu drops BatchNorm2d and preserves outputs."""
    original = bb.ConvBNRelu(
        3, 6, kernel_size=3, padding=1, bn_args="bn", relu_args="relu"
    )
    original.eval()
    fused = fuse_utils.fuse_convbnrelu(original, inplace=False)

    bn_type = torch.nn.BatchNorm2d
    # Out-of-place fusion: the source model must keep its BN layer.
    self.assertTrue(_find_modules(original, bn_type))
    self.assertFalse(_find_modules(fused, bn_type))

    run_and_compare(original, fused, [2, 3, 7, 7])
def __init__(self):
    """Hold one ConvBNRelu both as an attribute and inside a plain Python list."""
    super().__init__()
    block = bb.ConvBNRelu(3, 6, kernel_size=3, bn_args="bn")
    self.cbr = block
    # The same module object is reachable through both attributes. If this
    # class is a torch.nn.Module (presumably; confirm), the plain list is not
    # a registered container.
    self.cbr_list = [block]