Example #1
0
    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None,
                 sparse=False):
        """ResNet basic block with an optional dynconv mask unit for
        spatially sparse execution.

        Args:
            inplanes: number of input channels.
            planes: number of output channels of both 3x3 convolutions.
            stride: stride of the first convolution (and of the mask unit).
            downsample: optional module applied to the shortcut branch.
            groups: must be 1 — grouped convolution is not supported here.
            base_width: unused by the basic block; kept for interface parity
                with the bottleneck block.
            dilation: must be 1 — dilation is not supported here.
            norm_layer: normalization layer factory; defaults to
                nn.BatchNorm2d when None.
            sparse: if True, attach a dynconv.MaskUnit that predicts a
                spatial execution mask for this block.
        """
        super(BasicBlock, self).__init__()
        assert groups == 1
        assert dilation == 1
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.conv1 = conv3x3(inplanes, planes, stride)
        # BUG FIX: the original hard-coded nn.BatchNorm2d for bn1/bn2,
        # silently ignoring the norm_layer argument resolved just above
        # (the Bottleneck block uses norm_layer correctly). Behavior is
        # unchanged for callers relying on the default.
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        self.sparse = sparse

        if sparse:
            # in the resnet basic block, the first convolution is already
            # strided, so the mask's dilate stride stays 1
            self.masker = dynconv.MaskUnit(channels=inplanes,
                                           stride=stride,
                                           dilate_stride=1)

        self.fast = False
Example #2
0
    def __init__(self, cfg, inp, oup, stride=1, expand_ratio=6, sparse=False):
        """MobileNetV2-style inverted residual block (pw expand -> dw 3x3 ->
        pw-linear project) with an optional dynconv mask unit for spatially
        sparse execution.
        """
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        # a residual shortcut exists only when both the spatial size and the
        # channel count are preserved
        self.identity = stride == 1 and inp == oup
        self.expand_ratio = expand_ratio
        self.sparse = sparse
        print(
            f'Inverted Residual - sparse: {sparse}: inp {inp}, hidden_dim {hidden_dim}, '
            + f'oup {oup}, stride {stride}, expand_ratio {expand_ratio}')

        # optional pointwise expansion stage, omitted when expand_ratio == 1
        expansion = [] if expand_ratio == 1 else [
            nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
            nn.BatchNorm2d(hidden_dim, momentum=BN_MOMENTUM),
            nn.ReLU6(inplace=True),
        ]
        # depthwise 3x3 followed by the linear pointwise projection
        body = [
            nn.Conv2d(hidden_dim,
                      hidden_dim,
                      3,
                      stride,
                      1,
                      groups=hidden_dim,
                      bias=False),
            nn.BatchNorm2d(hidden_dim, momentum=BN_MOMENTUM),
            nn.ReLU6(inplace=True),
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup, momentum=BN_MOMENTUM),
        ]
        self.conv = nn.Sequential(*(expansion + body))

        if sparse:
            # sparse execution requires the shortcut (to fill in skipped
            # positions) and an expansion stage
            assert self.identity
            assert expand_ratio != 1
            self.masker = dynconv.MaskUnit(inp,
                                           stride=stride,
                                           dilate_stride=stride)
        else:
            self.masker = None
Example #3
0
    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None,
                 sparse=False):
        """ResNet bottleneck block (1x1 reduce -> 3x3 -> 1x1 expand) with an
        optional dynconv mask unit for spatially sparse execution.
        """
        super(Bottleneck, self).__init__()
        # norm_layer is always a class/factory, so falsy only when None
        norm_layer = norm_layer or nn.BatchNorm2d
        # channel count of the intermediate 3x3 convolution
        mid_channels = int(planes * (base_width / 64.)) * groups

        print(
            f'Bottleneck - sparse: {sparse}: inp {inplanes}, hidden_dim {mid_channels}, '
            + f'oup {planes * self.expansion}, stride {stride}')

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, mid_channels)
        self.bn1 = norm_layer(mid_channels)
        self.conv2 = conv3x3(mid_channels, mid_channels, stride, groups, dilation)
        self.bn2 = norm_layer(mid_channels)
        self.conv3 = conv1x1(mid_channels, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.sparse = sparse
        self.fast = True

        if sparse:
            self.masker = dynconv.MaskUnit(channels=inplanes,
                                           stride=stride,
                                           dilate_stride=stride)