# Example 1
    def __init__(self, width_multiplier):
        """Build a ShuffleNetV2-style backbone with PAF/heatmap heads.

        Args:
            width_multiplier: one of the keys of ``width_config`` below
                (0.25, 0.33, 0.5, 1.0, 1.5, 2.0); selects the per-stage
                channel widths. Raises KeyError for any other value.
        """
        super(Network, self).__init__()
        # Per-multiplier channel widths: (stage2, stage3, stage4, conv5).
        width_config = {
            0.25: (24, 48, 96, 512),
            0.33: (32, 64, 128, 512),
            0.5: (48, 96, 192, 1024),
            1.0: (116, 232, 464, 1024),
            1.5: (176, 352, 704, 1024),
            2.0: (244, 488, 976, 2048),
        }
        width_config = width_config[width_multiplier]
        in_channels = 24

        # Entries are either ready-made modules or stage descriptors:
        # (out_channels, stride, dilation, num_blocks, stage_type).
        self.network_config = [
            g_name('data/bn', nn.BatchNorm2d(3)),
            slim.conv_bn_relu('stage1/conv', 3, in_channels, 3, 2, 1),
            g_name('stage1/pool', nn.MaxPool2d(3, 2, 0, ceil_mode=True)),
            (width_config[0], 2, 1, 4, 'b'),
            (width_config[1], 1, 1, 8, 'b'),  # x16
            (width_config[2], 1, 1, 4, 'b'),  # x32
            slim.conv_bn_relu('conv5', width_config[2], width_config[3], 1)
        ]
        # Prediction heads: 38-channel `paf` and 19-channel `heatmap`
        # 1x1 convs on top of the conv5 features.
        self.paf = nn.Conv2d(width_config[3], 38, 1)
        self.heatmap = nn.Conv2d(width_config[3], 19, 1)
        self.network = []
        for i, config in enumerate(self.network_config):
            if isinstance(config, nn.Module):
                self.network.append(config)
                continue
            out_channels, stride, dilation, num_blocks, stage_type = config
            # BUG FIX: `downsample` was previously assigned only when
            # stride == 2, so stride-1 stages silently reused the stale
            # value from the prior iteration (and the first stage would
            # raise NameError if its stride were 1). The first block of a
            # stage needs the downsample/projection path whenever it
            # changes resolution OR channel count.
            downsample = stride == 2 or in_channels != out_channels
            stage_prefix = 'stage_{}'.format(i - 1)
            blocks = [
                BasicBlock(stage_prefix + '_1', in_channels, out_channels,
                           stride, downsample, dilation)
            ]
            # `j` (not `i`) so the inner index no longer shadows the
            # outer enumerate index used for `stage_prefix`.
            for j in range(1, num_blocks):
                blocks.append(
                    BasicBlock(stage_prefix + '_{}'.format(j + 1),
                               out_channels, out_channels, 1, False, dilation))
            self.network += [nn.Sequential(*blocks)]

            in_channels = out_channels
        self.network = nn.Sequential(*self.network)

        # Kaiming-initialize every linear/conv weight; zero the biases.
        for name, m in self.named_modules():
            if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d)):
                nn.init.kaiming_uniform_(m.weight, mode='fan_in')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
# Example 2
 def __init__(self, name, in_channels, out_channels, stride, downsample,
              dilation):
     """Shuffle-unit basic block.

     When ``downsample`` is set or ``stride != 1`` the block builds two
     parallel paths (``conv`` and a projection shortcut ``conv0``);
     otherwise it builds the single ``conv`` path and requires
     ``in_channels == out_channels``. Each path ends at
     ``out_channels // 2`` channels; the halves are presumably
     recombined and shuffled in ``forward`` (not shown here).
     """
     super(BasicBlock, self).__init__()
     self.g_name = name
     self.in_channels = in_channels
     self.stride = stride
     self.downsample = downsample
     half = out_channels // 2

     def depthwise(suffix, ch):
         # 3x3 depthwise conv+BN shared by both paths (groups == channels).
         return slim.conv_bn(name + suffix,
                             ch,
                             ch,
                             3,
                             stride=stride,
                             dilation=dilation,
                             padding=dilation,
                             groups=ch)

     if self.downsample or self.stride != 1:
         # Projection unit: main path compresses in_channels -> half ...
         self.conv = nn.Sequential(
             slim.conv_bn_relu(name + '/conv1', in_channels, half, 1),
             depthwise('/conv2', half),
             slim.conv_bn_relu(name + '/conv3', half, half, 1),
         )
         # ... and the shortcut path also maps in_channels -> half.
         self.conv0 = nn.Sequential(
             depthwise('/conv4', in_channels),
             slim.conv_bn_relu(name + '/conv5', in_channels, half, 1),
         )
     else:
         # Plain unit: operates on one half of a channel split.
         assert in_channels == out_channels
         self.conv = nn.Sequential(
             slim.conv_bn_relu(name + '/conv1', half, half, 1),
             depthwise('/conv2', half),
             slim.conv_bn_relu(name + '/conv3', half, half, 1),
         )
     self.shuffle = slim.channel_shuffle(name + '/shuffle', 2)