def __init__(self, name, in_channels, out_channels, stride, dilation):
    """Build the convolution branches and channel-shuffle op for this block.

    Args:
        name: Prefix for every sub-layer name ('/conv1' ... '/conv5',
            '/shuffle' are appended to it).
        in_channels: Channel count of the block input.
        out_channels: Channel count of the block output; each branch is
            built with ``out_channels // 2`` output channels.
        stride: Stride forwarded to the 3x3 grouped convolutions. When it
            equals 1, ``in_channels`` must equal ``out_channels`` and only
            ``self.conv`` is created; otherwise ``self.conv0`` is created
            as well.
        dilation: Dilation forwarded to the 3x3 grouped convolutions; the
            same value is also used as their padding.

    Raises:
        AssertionError: If ``stride == 1`` and
            ``in_channels != out_channels``.
    """
    super(BasicBlock, self).__init__()
    self.g_name = name
    self.in_channels = in_channels
    self.stride = stride

    half = out_channels // 2
    keeps_stride_one = stride == 1

    def pointwise(tag, width_in, width_out):
        # 1x1 conv + BN + ReLU, as provided by slim.conv_bn_relu.
        return slim.conv_bn_relu(name + tag, width_in, width_out, 1)

    def grouped_3x3(tag, width):
        # 3x3 conv + BN with one group per channel (groups == width);
        # padding follows dilation.
        return slim.conv_bn(name + tag, width, width, 3, stride=stride,
                            dilation=dilation, padding=dilation,
                            groups=width)

    if keeps_stride_one:
        assert in_channels == out_channels
        # NOTE(review): conv1 takes only `half` input channels here,
        # presumably because forward() splits the input in two -- confirm.
        first_width = half
    else:
        first_width = in_channels

    self.conv = nn.Sequential(
        pointwise('/conv1', first_width, half),
        grouped_3x3('/conv2', half),
        pointwise('/conv3', half, half),
    )

    if not keeps_stride_one:
        # Second branch, only built when stride != 1; it operates on the
        # full in_channels and is reduced to `half` by conv5.
        self.conv0 = nn.Sequential(
            grouped_3x3('/conv4', in_channels),
            pointwise('/conv5', in_channels, half),
        )

    self.shuffle = slim.channel_shuffle(name + '/shuffle', 2)