Example #1
0
 def __init__(self, block, num_segments, non_local_cfg=None):
     """Wrap ``block`` together with a ``NonLocal3d`` module.

     Args:
         block: Module to wrap. It must expose
             ``conv3.norm.num_features``, which sets the channel count
             passed to ``NonLocal3d``.
         num_segments (int): Stored as ``self.num_segments``. Presumably
             the number of temporal segments folded into the batch axis;
             TODO confirm against ``forward``.
         non_local_cfg (dict | None): Extra keyword arguments for
             ``NonLocal3d``. ``None`` (the default) means an empty dict.
     """
     super(NL3DWrapper, self).__init__()
     # Build a fresh dict per instance. A ``dict()`` default is one object
     # shared by every wrapper that omits the argument. Because it is kept
     # on ``self``, any later mutation would leak across instances.
     if non_local_cfg is None:
         non_local_cfg = dict()
     self.block = block
     self.non_local_cfg = non_local_cfg
     self.non_local_block = NonLocal3d(self.block.conv3.norm.num_features,
                                       **self.non_local_cfg)
     self.num_segments = num_segments
Example #2
0
    def __init__(self,
                 inplanes,
                 planes,
                 spatial_stride=1,
                 temporal_stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 inflate=True,
                 inflate_style='3x1x1',
                 non_local=False,
                 non_local_cfg=None,
                 conv_cfg=dict(type='Conv3d'),
                 norm_cfg=dict(type='BN3d'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        """Build a two-convolution 3D block, optionally with a non-local module.

        Args:
            inplanes (int): Input channels of ``conv1``.
            planes (int): Output channels of ``conv1``. ``conv2`` outputs
                ``planes * self.expansion`` channels.
            spatial_stride (int): Spatial (H and W) stride of ``conv1``.
                Default: 1.
            temporal_stride (int): Temporal stride of ``conv1``. Default: 1.
            dilation (int): Spatial dilation of ``conv1``. Default: 1.
            downsample (nn.Module | None): Stored as ``self.downsample``;
                not used by this constructor.
            style (str): ``'pytorch'`` or ``'caffe'``. Validated and stored
                only; it does not affect the layers built here.
            inflate (bool): If True both convs use 3x3x3 kernels. Otherwise
                they use 1x3x3 kernels. Default: True.
            inflate_style (str): ``'3x1x1'`` or ``'3x3x3'``. Validated and
                stored only; kernel shapes here depend on ``inflate`` alone.
            non_local (bool): If True, build ``self.non_local_block``, a
                ``NonLocal3d`` sized by ``conv2.norm.num_features``. This
                requires ``norm_cfg`` to produce a norm layer.
            non_local_cfg (dict | None): Keyword arguments for
                ``NonLocal3d``. ``None`` (the default) means an empty dict.
            conv_cfg (dict): Convolution config for both convs.
            norm_cfg (dict): Normalization config for both convs.
            act_cfg (dict): Activation config for ``conv1`` and ``self.relu``.
            with_cp (bool): Stored as ``self.with_cp``; not used by this
                constructor. Presumably it toggles gradient checkpointing in
                ``forward``; TODO confirm.
        """
        super().__init__()
        assert style in ['pytorch', 'caffe']
        assert inflate_style in ['3x1x1', '3x3x3']

        # Build a fresh dict per instance. A ``dict()`` default is one object
        # shared by every block that omits the argument, and it is kept on
        # ``self``. The conv/norm/act configs keep their dict defaults because
        # ``None`` can be a meaningful setting for them (``conv2`` below is
        # built with ``act_cfg=None``). So ``None`` cannot stand in for
        # "use the default" there.
        if non_local_cfg is None:
            non_local_cfg = dict()

        self.inplanes = inplanes
        self.planes = planes
        self.spatial_stride = spatial_stride
        self.temporal_stride = temporal_stride
        self.dilation = dilation
        self.style = style
        self.inflate = inflate
        self.inflate_style = inflate_style
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.with_cp = with_cp
        self.non_local = non_local
        self.non_local_cfg = non_local_cfg

        # All striding happens in conv1; conv2 always uses stride 1.
        self.conv1_stride_s = spatial_stride
        self.conv2_stride_s = 1
        self.conv1_stride_t = temporal_stride
        self.conv2_stride_t = 1

        # Only conv1 is dilated, and only along the two trailing
        # (spatial) kernel axes. Its padding tracks the dilation so the
        # output size depends only on the stride.
        if self.inflate:
            conv1_kernel_size = (3, 3, 3)
            conv1_padding = (1, dilation, dilation)
            conv2_kernel_size = (3, 3, 3)
            conv2_padding = (1, 1, 1)
        else:
            conv1_kernel_size = (1, 3, 3)
            conv1_padding = (0, dilation, dilation)
            conv2_kernel_size = (1, 3, 3)
            conv2_padding = (0, 1, 1)

        self.conv1 = ConvModule(inplanes,
                                planes,
                                conv1_kernel_size,
                                stride=(self.conv1_stride_t,
                                        self.conv1_stride_s,
                                        self.conv1_stride_s),
                                padding=conv1_padding,
                                dilation=(1, dilation, dilation),
                                bias=False,
                                conv_cfg=self.conv_cfg,
                                norm_cfg=self.norm_cfg,
                                act_cfg=self.act_cfg)

        # conv2 gets no activation of its own; the block applies
        # ``self.relu`` separately.
        self.conv2 = ConvModule(planes,
                                planes * self.expansion,
                                conv2_kernel_size,
                                stride=(self.conv2_stride_t,
                                        self.conv2_stride_s,
                                        self.conv2_stride_s),
                                padding=conv2_padding,
                                bias=False,
                                conv_cfg=self.conv_cfg,
                                norm_cfg=self.norm_cfg,
                                act_cfg=None)

        self.downsample = downsample
        self.relu = build_activation_layer(self.act_cfg)

        if self.non_local:
            self.non_local_block = NonLocal3d(self.conv2.norm.num_features,
                                              **self.non_local_cfg)
Example #3
0
def test_nonlocal():
    """Smoke-test the NonLocal modules' constructors and forward shapes.

    The test checks that:

    - ``_NonLocalNd`` rejects an unknown ``mode``.
    - ``_NonLocalNd`` builds both with its default initialization and with
      ``zeros_init=False``.
    - ``NonLocal1d``/``2d``/``3d`` return tensors shaped like their input,
      including with ``sub_sample=True``.
    """

    def _to_device(module, x):
        # NonLocal is only implemented on GPU in parrots. Under parrots,
        # move the module and its input to CUDA when a GPU is available.
        # ``Tensor.cuda()`` returns the tensor itself if it is already
        # on the current GPU, so repeated calls are harmless.
        if torch.__version__ == 'parrots' and torch.cuda.is_available():
            x = x.cuda()
            module.cuda()
        return x

    def _check_sub_sampled(module, pool_cls, kernel_size):
        # With sub_sample=True, ``g`` and ``phi`` are two-element
        # Sequentials whose second element is the expected max-pool.
        for m in (module.g, module.phi):
            assert isinstance(m, nn.Sequential) and len(m) == 2
            assert isinstance(m[1], pool_cls)
            assert m[1].kernel_size == kernel_size

    with pytest.raises(ValueError):
        # An unsupported mode name must be rejected.
        _NonLocalNd(3, mode='unsupport_mode')

    # _NonLocalNd with its default initialization
    _NonLocalNd(3, norm_cfg=dict(type='BN'))
    # _NonLocalNd without zero initialization. The previous
    # ``zeros_init=True`` contradicted this case's stated intent.
    _NonLocalNd(3, norm_cfg=dict(type='BN'), zeros_init=False)

    # NonLocal3d
    imgs = torch.randn(2, 3, 10, 20, 20)
    nonlocal_3d = NonLocal3d(3)
    imgs = _to_device(nonlocal_3d, imgs)
    assert nonlocal_3d(imgs).shape == imgs.shape

    nonlocal_3d = NonLocal3d(3, mode='dot_product')
    assert nonlocal_3d.mode == 'dot_product'
    imgs = _to_device(nonlocal_3d, imgs)
    assert nonlocal_3d(imgs).shape == imgs.shape

    nonlocal_3d = NonLocal3d(3, mode='dot_product', sub_sample=True)
    _check_sub_sampled(nonlocal_3d, nn.MaxPool3d, (1, 2, 2))
    imgs = _to_device(nonlocal_3d, imgs)
    assert nonlocal_3d(imgs).shape == imgs.shape

    # NonLocal2d
    imgs = torch.randn(2, 3, 20, 20)
    nonlocal_2d = NonLocal2d(3)
    imgs = _to_device(nonlocal_2d, imgs)
    assert nonlocal_2d(imgs).shape == imgs.shape

    nonlocal_2d = NonLocal2d(3, mode='dot_product', sub_sample=True)
    _check_sub_sampled(nonlocal_2d, nn.MaxPool2d, (2, 2))
    imgs = _to_device(nonlocal_2d, imgs)
    assert nonlocal_2d(imgs).shape == imgs.shape

    # NonLocal1d
    imgs = torch.randn(2, 3, 20)
    nonlocal_1d = NonLocal1d(3)
    imgs = _to_device(nonlocal_1d, imgs)
    assert nonlocal_1d(imgs).shape == imgs.shape

    nonlocal_1d = NonLocal1d(3, mode='dot_product', sub_sample=True)
    _check_sub_sampled(nonlocal_1d, nn.MaxPool1d, 2)
    imgs = _to_device(nonlocal_1d, imgs)
    assert nonlocal_1d(imgs).shape == imgs.shape
Example #4
0
def test_nonlocal3d():
    """Exercise NonLocal3d in each mode, with and without sub-sampling.

    Every variant must produce an output with the same shape as its input.
    """
    x = torch.randn(2, 3, 10, 20, 20)
    # parrots only ships a GPU implementation of NonLocal. Decide once
    # whether every module and the input should live on CUDA.
    on_gpu = torch.__version__ == 'parrots' and torch.cuda.is_available()
    if on_gpu:
        x = x.cuda()

    def _forward_keeps_shape(module):
        if on_gpu:
            module.cuda()
        assert module(x).shape == x.shape

    # Default construction ('embedded_gaussian' mode).
    _forward_keeps_shape(NonLocal3d(3))

    # Modes whose only extra check is that the mode is recorded.
    for mode_name in ('dot_product', 'concatenation'):
        block = NonLocal3d(3, mode=mode_name)
        assert block.mode == mode_name
        _forward_keeps_shape(block)

    # 'gaussian' mode builds no ``phi`` module.
    block = NonLocal3d(3, mode='gaussian')
    assert not hasattr(block, 'phi')
    assert block.mode == 'gaussian'
    _forward_keeps_shape(block)

    # 'gaussian' with sub-sampling. ``g`` is a two-element Sequential
    # ending in a max-pool, and ``phi`` is a bare max-pool.
    block = NonLocal3d(3, mode='gaussian', sub_sample=True)
    assert isinstance(block.g, nn.Sequential) and len(block.g) == 2
    assert isinstance(block.g[1], nn.MaxPool3d)
    assert block.g[1].kernel_size == (1, 2, 2)
    assert isinstance(block.phi, nn.MaxPool3d)
    _forward_keeps_shape(block)

    # 'dot_product' with sub-sampling. Both ``g`` and ``phi`` are
    # two-element Sequentials ending in a (1, 2, 2) max-pool.
    block = NonLocal3d(3, mode='dot_product', sub_sample=True)
    for branch in (block.g, block.phi):
        assert isinstance(branch, nn.Sequential) and len(branch) == 2
        assert isinstance(branch[1], nn.MaxPool3d)
        assert branch[1].kernel_size == (1, 2, 2)
    _forward_keeps_shape(block)