def __init__(self, width_multiplier):
    """Build the ShuffleNetV2-style backbone plus PAF/heatmap heads.

    Args:
        width_multiplier: one of {0.25, 0.33, 0.5, 1.0, 1.5, 2.0};
            selects the per-stage channel widths.

    Raises:
        KeyError: if ``width_multiplier`` is not a supported value.
    """
    super(Network, self).__init__()
    # Per-stage output channels for each supported width multiplier:
    # (stage2, stage3, stage4, conv5).
    width_config = {
        0.25: (24, 48, 96, 512),
        0.33: (32, 64, 128, 512),
        0.5: (48, 96, 192, 1024),
        1.0: (116, 232, 464, 1024),
        1.5: (176, 352, 704, 1024),
        2.0: (244, 488, 976, 2048),
    }
    width_config = width_config[width_multiplier]
    in_channels = 24
    # Entries are either ready-made modules, or stage specs of the form
    # (out_channels, stride, dilation, num_blocks, type).
    self.network_config = [
        g_name('data/bn', nn.BatchNorm2d(3)),
        slim.conv_bn_relu('stage1/conv', 3, in_channels, 3, 2, 1),
        g_name('stage1/pool', nn.MaxPool2d(3, 2, 0, ceil_mode=True)),
        (width_config[0], 2, 1, 4, 'b'),
        (width_config[1], 1, 1, 8, 'b'),  # x16
        (width_config[2], 1, 1, 4, 'b'),  # x32
        slim.conv_bn_relu('conv5', width_config[2], width_config[3], 1)
    ]
    # Part-affinity-field (38 ch) and keypoint-heatmap (19 ch) heads.
    self.paf = nn.Conv2d(width_config[3], 38, 1)
    self.heatmap = nn.Conv2d(width_config[3], 19, 1)

    self.network = []
    for i, config in enumerate(self.network_config):
        if isinstance(config, nn.Module):
            self.network.append(config)
            continue
        out_channels, stride, dilation, num_blocks, stage_type = config
        # BUG FIX: `downsample` was only assigned when stride == 2, so
        # stride-1 stages silently reused the value leaked from the
        # previous loop iteration (and a leading stride-1 stage would
        # raise NameError).  A stage's first block needs the shortcut
        # (downsample) branch whenever it strides OR changes channel
        # count; for the configs above this evaluates identically to
        # the leaked value, so behavior is preserved.
        downsample = stride == 2 or in_channels != out_channels
        stage_prefix = 'stage_{}'.format(i - 1)
        blocks = [
            BasicBlock(stage_prefix + '_1', in_channels, out_channels,
                       stride, downsample, dilation)
        ]
        # Use a distinct index: the original shadowed the outer `i`.
        for j in range(1, num_blocks):
            blocks.append(
                BasicBlock(stage_prefix + '_{}'.format(j + 1),
                           out_channels, out_channels, 1, False, dilation))
        self.network += [nn.Sequential(*blocks)]
        in_channels = out_channels
    self.network = nn.Sequential(*self.network)

    # Kaiming-uniform init for every conv/linear weight; zero biases.
    for name, m in self.named_modules():
        if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            nn.init.kaiming_uniform_(m.weight, mode='fan_in')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
def __init__(self, name, in_channels, out_channels, stride, downsample,
             dilation):
    """ShuffleNetV2-style basic unit.

    Args:
        name: prefix used for naming the child layers.
        in_channels: number of channels entering the block.
        out_channels: number of channels leaving the block; the residual
            branch operates on out_channels // 2.
        stride: stride of the depthwise 3x3 convolutions.
        downsample: whether the block has a shortcut (conv0) path; when
            False and stride == 1, in_channels must equal out_channels.
        dilation: dilation (and matching padding) of the depthwise convs.
    """
    super(BasicBlock, self).__init__()
    self.g_name = name
    self.in_channels = in_channels
    self.stride = stride
    self.downsample = downsample
    channels = out_channels // 2
    # The original duplicated the whole residual Sequential in two
    # branches that differed only in conv1's input channel count;
    # deduplicated here by selecting that single value up front.
    if not self.downsample and self.stride == 1:
        # Identity path: the input is expected to be split channel-wise,
        # so the residual branch sees only `channels` inputs and the
        # block cannot change the overall width.
        assert in_channels == out_channels
        conv1_in = channels
    else:
        conv1_in = in_channels
    # Residual branch: 1x1 -> depthwise 3x3 -> 1x1.
    self.conv = nn.Sequential(
        slim.conv_bn_relu(name + '/conv1', conv1_in, channels, 1),
        slim.conv_bn(name + '/conv2', channels, channels, 3,
                     stride=stride, dilation=dilation, padding=dilation,
                     groups=channels),
        slim.conv_bn_relu(name + '/conv3', channels, channels, 1),
    )
    # Shortcut branch: depthwise 3x3 -> 1x1.
    # NOTE(review): built even when downsample is False — presumably
    # unused by forward() in that case, but it still allocates
    # parameters; verify against forward().
    self.conv0 = nn.Sequential(
        slim.conv_bn(name + '/conv4', in_channels, in_channels, 3,
                     stride=stride, dilation=dilation, padding=dilation,
                     groups=in_channels),
        slim.conv_bn_relu(name + '/conv5', in_channels, channels, 1),
    )
    self.shuffle = slim.channel_shuffle(name + '/shuffle', 2)