예제 #1
0
파일: tpn.py 프로젝트: ruyijidan/TPN
    def __init__(
        self,
        in_channels=(1024, 1024),
        mid_channels=(1024, 1024),
        out_channels=2048,
        ds_scales=((1, 1, 1), (1, 1, 1)),
    ):
        """Build the per-level reduction ops and the final fusion conv.

        For every input level ``i`` a ``Downampling`` sublayer is created that
        maps ``in_channels[i]`` to ``mid_channels[i]`` channels with a grouped
        (groups=32) 1x1x1 conv, BatchNorm and Relu enabled,
        ``downsample_position='before'`` and ``downsample_scale=ds_scales[i]``.
        These ops are collected in ``self.ops``. ``self.fusion_conv`` is a
        1x1x1 Conv3D (no bias) from ``sum(mid_channels)`` to ``out_channels``
        channels, followed by BatchNorm and Relu.

        Args:
            in_channels: input channel count of each level.
            mid_channels: output channel count of each level's op; must have
                exactly one entry per level.
            out_channels: channel count produced by ``fusion_conv``.
            ds_scales: pooling scale given to each level's ``Downampling``.

        Raises:
            ValueError: if ``len(mid_channels) != len(in_channels)``.
        """
        super(LevelFusion, self).__init__()

        num_ins = len(in_channels)
        # ``in_dims`` below sums *every* entry of mid_channels, while only
        # ``num_ins`` ops are built. A length mismatch would make the fusion
        # conv expect a channel count the ops do not produce (assuming forward
        # concatenates the op outputs), so reject it up front.
        if len(mid_channels) != num_ins:
            raise ValueError(
                'mid_channels must have one entry per input level: got {} '
                'entries for {} levels'.format(len(mid_channels), num_ins))

        ops = [
            Downampling(in_channels[i],
                        mid_channels[i],
                        kernel_size=(1, 1, 1),
                        stride=(1, 1, 1),
                        padding=(0, 0, 0),
                        bias=False,
                        groups=32,
                        norm=True,
                        activation=True,
                        downsample_position='before',
                        downsample_scale=ds_scales[i])
            for i in range(num_ins)
        ]
        # Build the container once, after all per-level ops exist. Previously
        # it was rebuilt on every loop iteration and never assigned for an
        # empty ``in_channels``.
        self.ops = Sequential(*ops)

        in_dims = np.sum(mid_channels)
        self.fusion_conv = Sequential(
            nn.Conv3D(in_dims, out_channels, 1, 1, 0, bias_attr=False),
            nn.BatchNorm(out_channels), Relu())
예제 #2
0
파일: tpn.py 프로젝트: ruyijidan/TPN
    def __init__(
            self,
            inplanes,
            planes,
            kernel_size=(3, 1, 1),
            stride=(1, 1, 1),
            padding=(1, 0, 0),
            bias=False,
            groups=1,
            norm=False,
            activation=False,
            downsample_position='after',
            downsample_scale=(1, 2, 2),
    ):
        """Create a Conv3D with optional BatchNorm/Relu plus a MaxPool3D.

        Args:
            inplanes: input channels of the conv.
            planes: output channels of the conv (and of the BatchNorm).
            kernel_size, stride, padding, groups: forwarded to ``nn.Conv3D``.
            bias: forwarded to ``nn.Conv3D`` as ``bias_attr``.
            norm: when true, ``self.norm`` is a BatchNorm, otherwise ``None``.
            activation: when true, ``self.relu`` is a Relu, otherwise ``None``.
            downsample_position: must be 'before' or 'after'. It is only
                stored here; presumably forward uses it to decide when pooling
                runs relative to the conv.
            downsample_scale: used as both pool window and pool stride.
        """
        super(Downampling, self).__init__()

        self.conv = nn.Conv3D(inplanes, planes, kernel_size, stride, padding,
                              bias_attr=bias, groups=groups)

        # Optional layers; ``None`` marks a disabled stage.
        if norm:
            self.norm = nn.BatchNorm(planes)
        else:
            self.norm = None
        if activation:
            self.relu = Relu()
        else:
            self.relu = None

        allowed_positions = ['before', 'after']
        assert (downsample_position in allowed_positions)
        self.downsample_position = downsample_position

        # Window == stride == downsample_scale, zero padding, ceil rounding.
        zero_pad = (0, 0, 0)
        self.pool = MaxPool3D(downsample_scale, downsample_scale, zero_pad,
                              ceil_mode=True)
예제 #3
0
def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   spatial_stride=1,
                   temporal_stride=1,
                   dilation=1,
                   style='pytorch',
                   inflate_freq=1,
                   inflate_style='3x1x1',
                   nonlocal_freq=1,
                   nonlocal_cfg=None,
                   with_cp=False):
    """Build one residual stage as a ``Sequential`` of ``blocks`` blocks.

    Only the first block gets ``spatial_stride``/``temporal_stride`` and, when
    needed, a projection shortcut (strided 1x1x1 Conv3D + BatchNorm). The
    remaining blocks use stride 1, no shortcut, and take
    ``planes * block.expansion`` input channels.

    Args:
        block: residual block class; must expose ``expansion`` and accept the
            arguments passed below.
        inplanes: input channels of the first block.
        planes: bottleneck width passed to every block.
        blocks: number of blocks in the stage.
        spatial_stride, temporal_stride: strides of the first block.
        dilation: passed to every block.
        style, inflate_style, nonlocal_cfg, with_cp: forwarded to every block.
        inflate_freq: int (broadcast to all blocks) or one entry per block;
            block ``i`` gets ``if_inflate=True`` iff its entry == 1.
        nonlocal_freq: same convention, controlling ``if_nonlocal``.

    Returns:
        ``Sequential`` holding the constructed blocks.
    """
    if isinstance(inflate_freq, int):
        inflate_freq = (inflate_freq,) * blocks
    if isinstance(nonlocal_freq, int):
        nonlocal_freq = (nonlocal_freq,) * blocks
    assert len(inflate_freq) == blocks
    assert len(nonlocal_freq) == blocks

    out_planes = planes * block.expansion

    # The identity branch needs a projection whenever the first block changes
    # the feature-map shape. That happens when it strides in space OR in
    # time, or when it changes the channel count. The original only checked
    # the spatial stride, so a temporal-only stride left the identity
    # unstrided in time. NOTE(review): this assumes the block's residual path
    # applies temporal_stride, as the inflated Bottleneck convs do.
    downsample = None
    if spatial_stride != 1 or temporal_stride != 1 or inplanes != out_planes:
        downsample = Sequential(
            nn.Conv3D(
                inplanes,
                out_planes,
                filter_size=1,
                stride=(temporal_stride, spatial_stride, spatial_stride),
                bias_attr=False),
            nn.BatchNorm(out_planes),
        )

    layers = [
        block(
            inplanes,
            planes,
            spatial_stride,
            temporal_stride,
            dilation,
            downsample,
            style=style,
            if_inflate=(inflate_freq[0] == 1),
            inflate_style=inflate_style,
            if_nonlocal=(nonlocal_freq[0] == 1),
            nonlocal_cfg=nonlocal_cfg,
            with_cp=with_cp)
    ]
    for i in range(1, blocks):
        layers.append(
            block(out_planes,
                  planes,
                  1, 1,
                  dilation,
                  style=style,
                  if_inflate=(inflate_freq[i] == 1),
                  inflate_style=inflate_style,
                  if_nonlocal=(nonlocal_freq[i] == 1),
                  nonlocal_cfg=nonlocal_cfg,
                  with_cp=with_cp))

    return Sequential(*layers)
예제 #4
0
def conv1x3x3(in_planes, out_planes, spatial_stride=1, temporal_stride=1, dilation=1):
    """Return a bias-free ``nn.Conv3D`` with a (1, 3, 3) kernel.

    Spatial padding equals ``dilation``, which keeps the spatial size at
    stride 1 for the dilated 3x3 kernel. The temporal axis has kernel 1 and
    no padding. Strides are ``(temporal_stride, spatial_stride,
    spatial_stride)``.
    """
    kernel = (1, 3, 3)
    strides = (temporal_stride, spatial_stride, spatial_stride)
    pads = (0, dilation, dilation)
    return nn.Conv3D(in_planes, out_planes,
                     filter_size=kernel,
                     stride=strides,
                     padding=pads,
                     dilation=dilation,
                     bias_attr=False)
예제 #5
0
def conv3x3x3(in_planes,
              out_planes,
              spatial_stride=1,
              temporal_stride=1,
              dilation=1):
    """Return a bias-free 3x3x3 ``nn.Conv3D``.

    ``dilation`` serves as both the dilation and the padding on every axis,
    so the output size equals the input size when all strides are 1.
    Strides are ``(temporal_stride, spatial_stride, spatial_stride)``.
    """
    conv_stride = (temporal_stride, spatial_stride, spatial_stride)
    return nn.Conv3D(
        filter_size=3,
        padding=dilation,
        dilation=dilation,
        stride=conv_stride,
        bias_attr=False,
        num_channels=in_planes,
        num_filters=out_planes,
    )
예제 #6
0
파일: tpn.py 프로젝트: ruyijidan/TPN
    def __init__(
        self,
        inplanes,
        planes,
        downsample_scale=8,
        groups=32,
    ):
        """Create a grouped (3, 1, 1) Conv3D and a first-axis MaxPool3D.

        Args:
            inplanes: input channels of the conv.
            planes: output channels of the conv.
            downsample_scale: pool window and stride along the first axis
                (presumably time); the other two axes use 1.
            groups: channel groups of the conv. This was previously hard-coded
                to 32, which remains the default. Paddle presumably requires
                ``inplanes`` and ``planes`` to be divisible by it.
        """
        super(TemporalModulation, self).__init__()

        # Kernel (3, 1, 1), stride 1, padding (1, 0, 0): mixes only along the
        # first axis and preserves its length.
        self.conv = nn.Conv3D(inplanes,
                              planes, (3, 1, 1), (1, 1, 1), (1, 0, 0),
                              bias_attr=False,
                              groups=groups)
        self.pool = MaxPool3D((downsample_scale, 1, 1),
                              (downsample_scale, 1, 1), (0, 0, 0),
                              ceil_mode=True)
예제 #7
0
파일: tpn.py 프로젝트: ruyijidan/TPN
 def __init__(
     self,
     inplanes,
     planes,
     kernel_size,
     stride,
     padding,
     bias=False,
     groups=1,
 ):
     """Hold a Conv3D, a BatchNorm over ``planes`` channels and a Relu.

     ``kernel_size``, ``stride``, ``padding`` and ``groups`` go straight to
     ``nn.Conv3D``. ``bias`` is passed as its ``bias_attr``.
     """
     super(ConvModule, self).__init__()
     extra_conv_args = dict(bias_attr=bias, groups=groups)
     self.conv = nn.Conv3D(inplanes, planes, kernel_size, stride, padding,
                           **extra_conv_args)
     self.bn = nn.BatchNorm(planes)
     self.relu = Relu()
예제 #8
0
파일: tpn.py 프로젝트: ruyijidan/TPN
    def __init__(self,
                 in_channels=[256, 512, 1024, 2048],
                 out_channels=256,
                 spatial_modulation_config=None,
                 temporal_modulation_config=None,
                 upsampling_config=None,
                 downsampling_config=None,
                 level_fusion_config=None,
                 aux_head_config=None,
                 mode=None):
        """Build the Temporal Pyramid Network neck.

        Args:
            in_channels: list of backbone channel counts, one per level.
            out_channels: channel count produced by each temporal modulation.
            spatial_modulation_config: currently ignored;
                ``SpatialModulation()`` is built with its defaults.
            temporal_modulation_config: mapping with ``'param'`` (kwargs for
                ``TemporalModulation``) and ``'scales'`` (one
                ``downsample_scale`` per level); ``None`` skips these ops.
            upsampling_config: kwargs for ``Upsampling``; ``None`` skips them.
            downsampling_config: mapping with ``'param'`` and ``'scales'``
                for ``Downampling``; ``None`` skips these ops.
            level_fusion_config: kwargs for the second ``LevelFusion``. It is
                required in practice despite the ``None`` default, because
                ``level_fusion_config['out_channels']`` is read unconditionally.
            aux_head_config: kwargs for ``AuxHead`` (``inplanes`` is set to
                ``in_channels[-2]``); ``None`` disables the aux head.
            mode: stored as ``self.mode``.

        The caller's config mappings are not modified; per-op copies are used.
        """
        super(TPN, self).__init__()
        assert isinstance(in_channels, list)
        assert isinstance(out_channels, int)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.mode = mode

        temp_modulation_ops = []
        temp_upsampling_ops = []
        temp_downsampling_ops = []
        for i in range(self.num_ins):
            # Every level is modulated from the last level's width.
            inplanes = in_channels[-1]
            planes = out_channels

            if temporal_modulation_config is not None:
                # Per-level copy so the caller's config is left untouched.
                tm_param = dict(temporal_modulation_config['param'])
                tm_param['downsample_scale'] = (
                    temporal_modulation_config['scales'][i])
                tm_param['inplanes'] = inplanes
                tm_param['planes'] = planes
                temp_modulation_ops.append(TemporalModulation(**tm_param))

            # Up/down-sampling ops are built for every level except the last.
            if i < self.num_ins - 1:
                if upsampling_config is not None:
                    temp_upsampling_ops.append(Upsampling(**upsampling_config))
                if downsampling_config is not None:
                    ds_param = dict(downsampling_config['param'])
                    ds_param['inplanes'] = planes
                    ds_param['planes'] = planes
                    ds_param['downsample_scale'] = downsampling_config['scales']
                    temp_downsampling_ops.append(Downampling(**ds_param))

        # Build each container once, after the loop. This also guarantees the
        # attributes exist: previously upsampling_ops / downsampling_ops were
        # never assigned when there was only one input level.
        self.temporal_modulation_ops = Sequential(*temp_modulation_ops)
        self.upsampling_ops = Sequential(*temp_upsampling_ops)
        self.downsampling_ops = Sequential(*temp_downsampling_ops)

        # NOTE(review): the first fusion op and the spatial modulation use
        # their constructors' defaults instead of the passed configs — confirm
        # this is intended.
        self.level_fusion_op = LevelFusion()
        self.spatial_modulation = SpatialModulation()
        out_dims = level_fusion_config['out_channels']

        # Two pyramids
        self.level_fusion_op2 = LevelFusion(**level_fusion_config)

        # 1x1x1 conv from 2 * out_dims channels to 2048, then BatchNorm + Relu.
        self.pyramid_fusion_op = Sequential(
            nn.Conv3D(out_dims * 2, 2048, 1, 1, 0, bias_attr=False),
            nn.BatchNorm(2048), Relu())

        if aux_head_config is not None:
            aux_param = dict(aux_head_config)
            aux_param['inplanes'] = self.in_channels[-2]
            self.aux_head = AuxHead(**aux_param)
        else:
            self.aux_head = None
예제 #9
0
    def __init__(self,
                 inplanes,
                 planes,
                 spatial_stride=1,
                 temporal_stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 if_inflate=True,
                 inflate_style='3x1x1',
                 if_nonlocal=True,
                 nonlocal_cfg=None,
                 with_cp=False):
        """Bottleneck block for a 3D ResNet.

        Creates conv1, conv2 and conv3 (1x1x1, ``planes`` ->
        ``planes * self.expansion``), one BatchNorm per conv, a Relu, and
        stores the optional ``downsample`` shortcut.

        With style "pytorch" the strides go on conv2; with "caffe" they go on
        conv1. When ``if_inflate`` is true, ``inflate_style='3x1x1'`` gives
        conv1 a (3, 1, 1) kernel and conv2 a (1, 3, 3) kernel, while
        ``'3x3x3'`` gives conv1 a 1x1x1 kernel and conv2 a 3x3x3 kernel. When
        it is false, conv1 is 1x1x1, conv2 is (1, 3, 3), and neither strides
        along the first (temporal) axis.

        ``if_nonlocal`` and ``nonlocal_cfg`` are accepted but not used here.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert inflate_style in ['3x1x1', '3x3x3']
        self.inplanes = inplanes
        self.planes = planes

        # Which conv carries the (spatial, temporal) stride depends on style.
        if style == 'pytorch':
            self.conv1_stride, self.conv1_stride_t = 1, 1
            self.conv2_stride, self.conv2_stride_t = spatial_stride, temporal_stride
        else:
            self.conv1_stride, self.conv1_stride_t = spatial_stride, temporal_stride
            self.conv2_stride, self.conv2_stride_t = 1, 1

        s1, t1 = self.conv1_stride, self.conv1_stride_t
        s2, t2 = self.conv2_stride, self.conv2_stride_t
        conv2_dilation = (1, dilation, dilation)

        if not if_inflate:
            # Non-inflated: no temporal kernel extent and no temporal stride.
            self.conv1 = nn.Conv3D(
                inplanes, planes,
                filter_size=1,
                stride=(1, s1, s1),
                bias_attr=False)
            self.conv2 = nn.Conv3D(
                planes, planes,
                filter_size=(1, 3, 3),
                stride=(1, s2, s2),
                padding=(0, dilation, dilation),
                dilation=conv2_dilation,
                bias_attr=False)
        elif inflate_style == '3x1x1':
            self.conv1 = nn.Conv3D(
                inplanes, planes,
                filter_size=(3, 1, 1),
                stride=(t1, s1, s1),
                padding=(1, 0, 0),
                bias_attr=False)
            self.conv2 = nn.Conv3D(
                planes, planes,
                filter_size=(1, 3, 3),
                stride=(t2, s2, s2),
                padding=(0, dilation, dilation),
                dilation=conv2_dilation,
                bias_attr=False)
        else:  # inflate_style == '3x3x3'
            self.conv1 = nn.Conv3D(
                inplanes, planes,
                filter_size=1,
                stride=(t1, s1, s1),
                bias_attr=False)
            self.conv2 = nn.Conv3D(
                planes, planes,
                filter_size=3,
                stride=(t2, s2, s2),
                padding=(1, dilation, dilation),
                dilation=conv2_dilation,
                bias_attr=False)

        expanded = planes * self.expansion
        self.bn1 = nn.BatchNorm(planes)
        self.bn2 = nn.BatchNorm(planes)
        self.conv3 = nn.Conv3D(planes, expanded, filter_size=1, bias_attr=False)
        self.bn3 = nn.BatchNorm(expanded)
        self.relu = Relu()
        self.downsample = downsample
        # Attribute names keep the original "tride" spelling; other code may
        # read them under these names.
        self.spatial_tride = spatial_stride
        self.temporal_tride = temporal_stride
        self.dilation = dilation
        self.with_cp = with_cp
예제 #10
0
    def __init__(self,
                 depth=50,
                 pretrained=None,
                 pretrained2d=True,
                 num_stages=4,
                 spatial_strides=(1, 2, 2, 2),
                 temporal_strides=(1, 1, 1, 1),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 conv1_kernel_t=5,
                 conv1_stride_t=2,
                 pool1_kernel_t=1,
                 pool1_stride_t=2,
                 style='pytorch',
                 frozen_stages=-1,
                 inflate_freq=(1, 1, 1, 1),
                 inflate_stride=(1, 1, 1, 1),
                 inflate_style='3x1x1',
                 nonlocal_stages=(-1,),
                 nonlocal_freq=(0, 1, 1, 0),
                 nonlocal_cfg=None,
                 bn_eval=False,
                 bn_frozen=False,
                 partial_bn=False,
                 with_cp=False):
        """Build the 3D ResNet stem and residual stages.

        The stem is conv1 / bn1 / relu / maxpool. Residual stages are built
        with ``make_res_layer`` and registered as sublayers ``layer1`` ...
        ``layerN``; their names are kept in ``self.res_layers``.

        Args:
            depth: key into ``self.arch_settings``, which yields
                ``(block, stage_blocks)``; raises KeyError if missing.
            pretrained, pretrained2d: only stored by the code shown here.
            num_stages: number of residual stages to build (1-4).
            spatial_strides, temporal_strides, dilations: per-stage values;
                each must have exactly ``num_stages`` entries.
            out_indices: stored; every index must be < ``num_stages``.
            conv1_kernel_t, conv1_stride_t: temporal kernel / stride of the
                stem conv (temporal padding is ``(conv1_kernel_t - 1) // 2``).
            pool1_kernel_t, pool1_stride_t: temporal kernel / stride of the
                stem max pool.
            style: forwarded to every stage ('pytorch' or 'caffe').
            frozen_stages, bn_eval, bn_frozen, partial_bn: only stored by the
                code shown here; presumably used by other methods.
            inflate_freq: per-stage inflation pattern; an int is broadcast to
                all stages. Each stage's entry is passed to make_res_layer,
                which itself accepts an int or one value per block.
            inflate_stride: accepted but not used by the code shown here.
            inflate_style: forwarded to every stage.
            nonlocal_stages: stage indices that receive ``nonlocal_cfg``. The
                default ``(-1,)`` matches no stage.
            nonlocal_freq: per-stage non-local pattern; an int is broadcast.
            nonlocal_cfg: forwarded only to stages listed in nonlocal_stages.
            with_cp: forwarded to every stage.
        """
        super(ResNet_SlowFast, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for resnet'.format(depth))
        self.depth = depth
        self.pretrained = pretrained
        self.pretrained2d = pretrained2d
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.spatial_strides = spatial_strides
        self.temporal_strides = temporal_strides
        self.dilations = dilations
        assert len(spatial_strides) == len(temporal_strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        # Broadcast a single int to one entry per stage.
        self.inflate_freqs = inflate_freq if not isinstance(inflate_freq, int) else (inflate_freq,) * num_stages
        self.inflate_style = inflate_style
        self.nonlocal_stages = nonlocal_stages
        self.nonlocal_freqs = nonlocal_freq if not isinstance(nonlocal_freq, int) else (nonlocal_freq,) * num_stages
        self.nonlocal_cfg = nonlocal_cfg
        self.bn_eval = bn_eval
        self.bn_frozen = bn_frozen
        self.partial_bn = partial_bn
        self.with_cp = with_cp

        # Keep only the block counts of the stages actually built.
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        # Running channel count; updated after each stage is built.
        self.inplanes = 64

        # Stem: 3 -> 64 channels, 7x7 spatial kernel with spatial stride 2.
        self.conv1 = nn.Conv3D(
            3, 64, filter_size=(conv1_kernel_t, 7, 7), stride=(conv1_stride_t, 2, 2),
            padding=((conv1_kernel_t - 1) // 2, 3, 3), bias_attr=False)
        self.bn1 = BatchNorm(64)#Batch_Norm3D(64)

        self.relu = Relu()
        self.maxpool = MaxPool3D(kernel_size=(pool1_kernel_t, 3, 3), stride=(pool1_stride_t, 2, 2),
                                 padding=(pool1_kernel_t // 2, 1, 1))

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            spatial_stride = spatial_strides[i]
            temporal_stride = temporal_strides[i]
            dilation = dilations[i]
            # Bottleneck width doubles every stage: 64, 128, 256, 512.
            planes = 64 * 2 ** i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                spatial_stride=spatial_stride,
                temporal_stride=temporal_stride,
                dilation=dilation,
                style=self.style,
                inflate_freq=self.inflate_freqs[i],
                inflate_style=self.inflate_style,
                nonlocal_freq=self.nonlocal_freqs[i],
                nonlocal_cfg=self.nonlocal_cfg if i in self.nonlocal_stages else None,
                with_cp=with_cp)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
            # Register under an explicit name so the stage is a tracked sublayer.
            self.add_sublayer(layer_name, res_layer)
            self.res_layers.append(layer_name)

        # Channel width of the last built stage.
        self.feat_dim = self.block.expansion * 64 * 2 ** (
                len(self.stage_blocks) - 1)