import torch
import torch.nn as nn
from torchvision import ops


def test_conv(Base, nruns=100, device='cuda'):
    # basic sanity check: a custom convolution module should produce the same
    # spatial output dimensions as the standard module, so that convolution
    # types can be swapped conveniently
    chi, cho, k, s = 8, 32, 3, 1

    x = torch.randn(16, chi, 512, 512)
    conv = Base(chi, cho, k, s, autopad(k))
    conv_ = nn.Conv2d(chi, cho, k, s, autopad(k))

    if 'cuda' in device:
        assert torch.cuda.is_available()
        conv.cuda().train()
        conv_.cuda().train()
        x = x.cuda()

        if torch.backends.cudnn.benchmark:
            # warm-up iterations are needed for a fair comparison,
            # since cuDNN autotunes its kernels during the first calls
            print('benchmark warm up...')
            for _ in range(50):
                _ = conv(x)
    else:
        conv.cpu().train()
        conv_.cpu().train()
        nruns = 1

    p = count_params(conv)
    p_ = count_params(conv_)
    # relative parameter count shown in parentheses, w.r.t. nn.Conv2d
    print(f'Number of parameters: {p} ({p / p_ * 100:.2f}%)')

    # ensure the same output shape as the standard module
    out = conv(x)
    out_ = conv_(x)
    assert out.shape == out_.shape, f'Shape mismatch, should be {out_.shape} but is {out.shape}'

    # g0 = torch.randn_like(out)
    # performance test without loading features/targets,
    # because that would add a significant amount of overhead to the timing
    start = time_synchronized()
    for _ in range(nruns):
        out = conv(x)
        for param in conv.parameters():
            param.grad = None
        out.mean().backward()  # out.backward(g0)
    end = time_synchronized()

    print(f'Forward + Backward time: {(end - start) * 1000 / nruns:.3f}ms')
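

# The helpers used by test_conv (autopad, count_params, time_synchronized) are
# assumed to be defined elsewhere in this module; minimal sketches of what they
# are expected to do (hypothetical, the real implementations may differ):
import time


def autopad(k, p=None):
    # 'same' padding for stride 1: pad by half the kernel size
    if p is None:
        p = k // 2 if isinstance(k, int) else tuple(x // 2 for x in k)
    return p


def count_params(module):
    # total number of parameters in a module
    return sum(p.numel() for p in module.parameters())


def time_synchronized():
    # wait for pending CUDA kernels to finish before taking the wall-clock time
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()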
 def __init__(self,
              chi,
              cho,
              k,
              s=1,
              p=None,
              dilation=1,
              groups=1,
              bias=True):
     super().__init__()
     # decreases complexity, but the expressible kernel space is limited:
     # the k x k kernel is factorized into a (k, 1) and a (1, k) convolution
     k = k if isinstance(k, tuple) else (k, k)
     s = s if isinstance(s, tuple) else (s, s)
     # resolve the padding before expanding it to a tuple, so autopad still
     # sees p=None when no padding was given
     p = autopad(k, p)
     p = p if isinstance(p, (tuple, list)) else (p, p)
     self.conv1 = nn.Conv2d(chi,
                            chi, (k[0], 1), (s[0], 1), (p[0], 0),
                            dilation=dilation,
                            groups=groups,
                            bias=True)
     self.conv2 = nn.Conv2d(chi,
                            cho, (1, k[1]), (1, s[1]), (0, p[1]),
                            dilation=dilation,
                            groups=groups,
                            bias=bias)
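
 # Plausible forward pass for this factorized convolution (a sketch; the
 # module's actual forward is defined elsewhere): the (k, 1) and the (1, k)
 # convolution are applied in sequence, approximating a full k x k kernel.
 def forward(self, x):
     return self.conv2(self.conv1(x))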
 def __init__(self,
              chi,
              cho,
              k,
              s=1,
              p=None,
              dilation=1,
              groups=8,
              bias=True):
     super().__init__()  # typically groups are 2, 4, 8, 16
     self.cho = cho
     self.groups = groups
     p = autopad(k, p)
     # decreases complexity; the idea behind grouped convolutions is that
     # the correlation between feature channels is sparse anyway, and here
     # we are even sparser, since we only allow intra-group correlation.
     # Grouped convolution is also used for 1x1 convolutions (see ShuffleNet),
     # which are then called pointwise grouped convolutions.
     self.conv = nn.Conv2d(chi,
                           cho,
                           k,
                           s,
                           p,
                           dilation=dilation,
                           groups=groups,
                           bias=bias)
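
 # Rough parameter arithmetic (illustration only): a grouped convolution stores
 # k * k * (chi / groups) * cho weights instead of k * k * chi * cho. For the
 # test_conv defaults (chi=8, cho=32, k=3) and groups=8 this gives
 # 3*3*1*32 + 32 = 320 parameters vs. 3*3*8*32 + 32 = 2336 for nn.Conv2d (~13.7%).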
 def __init__(self,
              chi,
              cho,
              k,
              s=1,
              p=None,
              dilation=1,
              groups=1,
              bias=True):
     super().__init__()
     k = k if isinstance(k, tuple) else (k, k)
     p = autopad(k, p)
     self.conv = ops.DeformConv2d(chi,
                                  cho,
                                  k,
                                  s,
                                  p,
                                  dilation=dilation,
                                  groups=groups,
                                  bias=bias)
     # for every pixel of the convolution output, the offset branch has to
     # predict an (x, y) offset (2 channels) for each of the k[0] * k[1]
     # kernel positions in every offset group, hence groups * 2 * k[0] * k[1]
     # output channels and the same kernel size, stride and padding as the
     # deformable convolution itself!
     self.offset = nn.Conv2d(chi,
                             groups * 2 * k[0] * k[1],
                             k,
                             s,
                             p,
                             dilation=dilation,
                             groups=groups,
                             bias=True)
     self._init_offset()
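
 # Hedged sketch of the remaining pieces (the actual implementations live
 # elsewhere in this module): the offset branch is commonly zero-initialized so
 # that training starts from the regular sampling grid, and the forward pass
 # feeds the predicted offsets into torchvision's DeformConv2d.
 def _init_offset(self):
     nn.init.constant_(self.offset.weight, 0.)
     nn.init.constant_(self.offset.bias, 0.)

 def forward(self, x):
     return self.conv(x, self.offset(x))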
 def __init__(self,
              chi,
              cho,
              k,
              s=1,
              p=None,
              dilation=1,
              groups=1,
              bias=True):
     super().__init__()
     # the paper stresses the importance of the bias term!
     k = k if isinstance(k, tuple) else (k, k)
     s = s if isinstance(s, tuple) else (s, s)
     # resolve the padding before expanding it to a tuple, so autopad still
     # sees p=None when no padding was given
     p = autopad(k, p)
     p = p if isinstance(p, (tuple, list)) else (p, p)
     # lateral, kernel: C x 1 x 1
     self.conv1 = nn.Conv2d(chi, cho, 1, 1, 0, groups=1, bias=True)
     # vertical, kernel: 1 x Y x 1
     self.conv2 = nn.Conv2d(cho,
                            cho, (k[0], 1), (s[0], 1), (p[0], 0),
                            dilation=dilation,
                            groups=cho,
                            bias=True)
     # horizontal, kernel: 1 x 1 x X;
     # the last convolution can omit its bias, e.g. if a batch norm follows anyway
     self.conv3 = nn.Conv2d(cho,
                            cho, (1, k[1]), (1, s[1]), (0, p[1]),
                            dilation=dilation,
                            groups=cho,
                            bias=bias)
 def __init__(self,
              chi,
              cho,
              k,
              s=1,
              p=None,
              dilation=1,
              groups=1,
              bias=True):
     super().__init__()
     p = autopad(k, p)
     # decreases complexity, so smaller networks can be made wider;
     # each filter solely has access to a single input channel,
     # and the number of channels is kept unchanged in this first step
     self.conv1 = nn.Conv2d(chi,
                            chi,
                            k,
                            s,
                            p,
                            dilation=dilation,
                            groups=chi,
                            bias=True)
     # learn channel-wise (inter-group) correlation with 1x1 convolutions
     self.conv2 = nn.Conv2d(chi,
                            cho,
                            1,
                            1,
                            0,
                            dilation=dilation,
                            groups=1,
                            bias=bias)
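
 # Plausible forward pass (sketch) and rough parameter arithmetic: depthwise
 # plus pointwise needs about k*k*chi + chi*cho weights instead of k*k*chi*cho.
 # For the test_conv defaults (chi=8, cho=32, k=3): (3*3*8 + 8) + (8*32 + 32)
 # = 368 parameters vs. 2336 for nn.Conv2d (~15.8%).
 def forward(self, x):
     return self.conv2(self.conv1(x))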
 def __init__(self,
              c1,
              c2,
              k=1,
              s=1,
              p=None,
              g=1,
              act=True):  # ch_in, ch_out, kernel, stride, padding, groups
     super(Conv, self).__init__()
     self.conv = nn.Conv2d(c1,
                           c2,
                           k,
                           s,
                           autopad(k, p),
                           groups=g,
                           bias=False)
     self.bn = nn.BatchNorm2d(c2)
     self.act = nn.ReLU() if act is True else (
         act if isinstance(act, nn.Module) else nn.Identity())
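
 # Typical forward for such a Conv block (a sketch; the actual module may also
 # define a fused variant that skips the batch norm after weight fusion):
 def forward(self, x):
     return self.act(self.bn(self.conv(x)))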
 def __init__(self,
              chi,
              cho,
              k=1,
              s=1,
              p=None,
              d=1,
              g=GROUPS,
              act=True,
              affine=True):
     super().__init__()
     if chi % g != 0 or cho % g != 0:
         print(f'Channels chi={chi} or cho={cho} not divisible by '
               f'groups={g}; falling back to groups=1')
         g = 1
     p = autopad(k, p)
     self.conv = BASE(chi, cho, k, s, p, dilation=d, groups=g, bias=False)
     self.bn = nn.BatchNorm2d(cho, affine=affine)
     self.act = nn.ReLU() if act is True else (
         act if isinstance(act, nn.Module) else nn.Identity())
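

# Hypothetical usage example: comparing nn.Conv2d against itself is a quick
# self-check of test_conv and should report ~100% of the parameters and
# matching output shapes.
if __name__ == '__main__':
    test_conv(nn.Conv2d,
              nruns=10,
              device='cuda' if torch.cuda.is_available() else 'cpu')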