def test_channel_shuffle():
    """Check channel_shuffle's divisibility guard and the channel permutation it applies."""
    x = torch.randn(1, 24, 56, 56)

    with pytest.raises(AssertionError):
        # num_channels should be divisible by groups
        channel_shuffle(x, 7)

    groups = 3
    num_channels = x.size(1)
    channels_per_group = num_channels // groups
    out = channel_shuffle(x, groups)

    # Shuffling only permutes channels; the tensor shape must be unchanged.
    assert out.shape == x.shape

    # test the output value when groups = 3.
    # Input channel c sits in group c // channels_per_group at position
    # c % channels_per_group. After the shuffle it must appear at output channel
    # (c % channels_per_group) * groups + c // channels_per_group.
    # Comparing whole (batch, H, W) slices per channel checks exactly the same
    # thing as the element-wise loop, without ~75k scalar tensor comparisons.
    for c in range(num_channels):
        c_out = c % channels_per_group * groups + c // channels_per_group
        assert torch.equal(x[:, c], out[:, c_out])
def _inner_forward(x):
    """Run both branches (or split-and-transform) and channel-shuffle the result.

    With ``self.stride > 1`` the full input goes through ``self.branch1`` and
    ``self.branch2``. Otherwise the input is split in half along dim 1: the
    first half passes through untouched and only the second half goes through
    ``self.branch2``. The concatenated result is shuffled with 2 groups.
    """
    if self.stride > 1:
        left = self.branch1(x)
        right = self.branch2(x)
    else:
        left, passthrough = x.chunk(2, dim=1)
        right = self.branch2(passthrough)
    merged = torch.cat((left, right), dim=1)
    return channel_shuffle(merged, 2)
def _inner_forward(x):
    """Apply conv1, transform half of the channels, and channel-shuffle the concat.

    After ``self.conv1`` the feature map is split in two along dim 1. The second
    half runs through expand -> depthwise -> linear convs. The first half runs
    through ``self.branch1`` (after the second half is done, as in the original
    order). Both are concatenated and shuffled with 2 groups.
    """
    features = self.conv1(x)
    head, tail = features.chunk(2, dim=1)
    for layer in (self.expand_conv, self.depthwise_conv, self.linear_conv):
        tail = layer(tail)
    merged = torch.cat((self.branch1(head), tail), dim=1)
    return channel_shuffle(merged, 2)
def _inner_forward(x):
    """Split each input tensor, weight the second halves, and shuffle per tensor.

    ``x`` is iterated as a sequence of tensors. Each is split in two along
    dim 1. All second halves are passed together to
    ``self.cross_resolution_weighting``, then each goes through its matching
    module in ``self.depthwise_convs`` and then ``self.spatial_weighting``.
    Each untouched first half is concatenated with its processed second half,
    and every result is channel-shuffled with 2 groups. Returns a list.
    """
    halves = [t.chunk(2, dim=1) for t in x]
    untouched = [pair[0] for pair in halves]
    weighted = self.cross_resolution_weighting([pair[1] for pair in halves])
    weighted = [conv(t) for t, conv in zip(weighted, self.depthwise_convs)]
    weighted = [gate(t) for t, gate in zip(weighted, self.spatial_weighting)]
    merged = [torch.cat([a, b], dim=1) for a, b in zip(untouched, weighted)]
    return [channel_shuffle(t, 2) for t in merged]