Example #1
import torch.nn as nn

# Assumed context (hedged): `nl` is the repository's module of sharable
# (mask-based) layers, e.g. imported as `import models.layers as nl`, and
# `View` is the repo's small tensor-reshape module; both are defined elsewhere.


def make_layers(cfg, network_width_multiplier, batch_norm=False, groups=1):
    layers = []
    in_channels = 3

    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            # The first conv (on the 3-channel RGB input) never uses channel
            # groups; later convs honor the `groups` argument.
            conv2d = nl.SharableConv2d(in_channels,
                                       int(v * network_width_multiplier),
                                       kernel_size=3,
                                       padding=1,
                                       bias=False,
                                       groups=1 if in_channels == 3 else groups)

            if batch_norm:
                layers += [
                    conv2d,
                    nn.BatchNorm2d(int(v * network_width_multiplier)),
                    nn.ReLU(inplace=True)
                ]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = int(v * network_width_multiplier)

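    # Classifier head: flatten the 512-channel 7x7 feature map (224x224
    # input) and apply two sharable fully-connected layers with dropout.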
    layers += [
        View(-1,
             int(512 * network_width_multiplier) * 7 * 7),
        nl.SharableLinear(
            int(512 * network_width_multiplier) * 7 * 7,
            int(4096 * network_width_multiplier)),
        nn.ReLU(True),
        # Dropout() is needed for 224x224 inputs
        nn.Dropout(),
        nl.SharableLinear(int(4096 * network_width_multiplier),
                          int(4096 * network_width_multiplier)),
        nn.ReLU(True),
        nn.Dropout()
    ]

    return nn.Sequential(*layers)
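
A minimal usage sketch (hedged: the VGG-16 cfg list and the multiplier value below are illustrative assumptions, not taken from the source):

# Hypothetical: build a width-scaled, batch-normalized VGG-16 body.
# 'M' entries request max pooling; numbers are base channel counts.
vgg16_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
             512, 512, 512, 'M', 512, 512, 512, 'M']
features = make_layers(vgg16_cfg, network_width_multiplier=1.0,
                       batch_norm=True)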
Example #2
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nl.SharableConv2d(in_planes,
                             out_planes,
                             kernel_size=1,
                             stride=stride,
                             bias=False)
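
Hedged usage note (values are illustrative): this helper typically builds the projection on a residual shortcut when the channel count or stride changes:

# Hypothetical: project 64 channels to 128 while halving spatial size,
# as a residual downsample branch would.
shortcut_proj = conv1x1(64, 128, stride=2)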
Example #3
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nl.SharableConv2d(in_planes,
                             out_planes,
                             kernel_size=3,
                             stride=stride,
                             padding=dilation,
                             groups=groups,
                             bias=False,
                             dilation=dilation)
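
Hedged usage note (values are illustrative): padding is tied to dilation, so a stride-1 call preserves spatial size even when the kernel is dilated:

# Hypothetical: stride-1, dilation-2 conv keeps HxW because padding == dilation
# (effective kernel size 5, padding 2).
dilated = conv3x3(64, 64, dilation=2)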
Example #4
    def __init__(self,
                 block,
                 layers,
                 dataset_history,
                 dataset2num_classes,
                 network_width_multiplier,
                 shared_layer_info,
                 num_classes=1000,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.network_width_multiplier = network_width_multiplier
        self.shared_layer_info = shared_layer_info
        self.inplanes = int(64 * network_width_multiplier)
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nl.SharableConv2d(3,
                                       self.inplanes,
                                       kernel_size=7,
                                       stride=2,
                                       padding=3,
                                       bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Plane counts must be integers; cast after scaling, matching
        # self.inplanes above (a fractional multiplier would otherwise
        # produce float channel counts and break conv creation).
        self.layer1 = self._make_layer(block,
                                       int(network_width_multiplier * 64),
                                       layers[0])
        self.layer2 = self._make_layer(block,
                                       int(network_width_multiplier * 128),
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       int(network_width_multiplier * 256),
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       int(network_width_multiplier * 512),
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.datasets, self.classifiers = dataset_history, nn.ModuleList()
        self.dataset2num_classes = dataset2num_classes
        # The default self.classifier for ImageNet is deliberately omitted
        # here; it is added manually in packnet_imagenet_main.py.

        if self.datasets:
            self._reconstruct_classifiers()

        for m in self.modules():
            if isinstance(m, nl.SharableConv2d):
                # Sharable conv weights start from a small Gaussian.
                nn.init.normal_(m.weight, 0, 0.001)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
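
A hedged construction sketch: the depth layout follows ResNet-18 conventions, and `BasicBlock` plus the empty bookkeeping values below are assumptions inferred from the constructor signature, not confirmed by the source:

# Hypothetical: a width-doubled ResNet-18 backbone with no tasks
# registered yet (so _reconstruct_classifiers is skipped).
model = ResNet(block=BasicBlock,
               layers=[2, 2, 2, 2],      # ResNet-18 depth layout
               dataset_history=[],        # no datasets seen so far
               dataset2num_classes={},
               network_width_multiplier=2.0,
               shared_layer_info={})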