コード例 #1
0
    def __init__(self,
                 in_channels,
                 num_points=(2048, 1024, 512, 256),
                 radius=(0.2, 0.4, 0.8, 1.2),
                 num_samples=(64, 32, 16, 16),
                 sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
                              (128, 128, 256)),
                 fp_channels=((256, 256), (256, 256)),
                 norm_cfg=dict(type='BN2d'),
                 pool_mod='max',
                 use_xyz=True,
                 normalize_xyz=True):
        """Build the set-abstraction (SA) and feature-propagation (FP) stacks.

        One ``PointSAModule`` is created per entry of ``sa_channels`` and one
        ``PointFPModule`` per entry of ``fp_channels``. The FP input widths
        are derived from the SA output widths recorded while building the SA
        stack, consumed from the deepest level backwards.

        Args:
            in_channels (int): Input channel count; the first SA MLP is given
                ``in_channels - 3`` input channels (channels without xyz).
            num_points (tuple[int]): Per-level ``num_point`` passed to
                ``PointSAModule``.
            radius (tuple[float]): Per-level ``radius`` passed to
                ``PointSAModule``.
            num_samples (tuple[int]): Per-level ``num_sample`` passed to
                ``PointSAModule``.
            sa_channels (tuple[tuple[int]]): Per-level MLP widths; the input
                width of each level is prepended automatically.
            fp_channels (tuple[tuple[int]]): Per-level FP MLP widths; the
                concatenated source + skip width is prepended automatically.
            norm_cfg (dict): Forwarded to every ``PointSAModule``.
            pool_mod (str): Either ``'max'`` or ``'avg'``; forwarded to every
                ``PointSAModule``.
            use_xyz (bool): Forwarded to every ``PointSAModule``.
            normalize_xyz (bool): Forwarded to every ``PointSAModule``.
        """
        super().__init__()

        self.num_sa = len(sa_channels)
        self.num_fp = len(fp_channels)

        # All per-level SA settings must describe the same number of levels,
        # and there cannot be more FP levels than SA levels.
        assert len(num_points) == len(radius) == len(num_samples) == len(
            sa_channels)
        assert len(sa_channels) >= len(fp_channels)
        assert pool_mod in ['max', 'avg']

        self.SA_modules = nn.ModuleList()
        prev_width = in_channels - 3  # number of channels without xyz
        # Widths available as skip connections, shallowest first.
        skip_widths = [prev_width]

        for level, level_spec in enumerate(sa_channels):
            mlp_widths = [prev_width, *level_spec]
            sa_module = PointSAModule(
                num_point=num_points[level],
                radius=radius[level],
                num_sample=num_samples[level],
                mlp_channels=mlp_widths,
                norm_cfg=norm_cfg,
                use_xyz=use_xyz,
                pool_mod=pool_mod,
                normalize_xyz=normalize_xyz)
            self.SA_modules.append(sa_module)

            prev_width = mlp_widths[-1]
            skip_widths.append(prev_width)

        self.FP_modules = nn.ModuleList()

        # Start from the two deepest SA outputs and walk back up.
        source_width = skip_widths.pop()
        target_width = skip_widths.pop()
        final_fp_level = len(fp_channels) - 1
        for level, level_spec in enumerate(fp_channels):
            mlp_widths = [source_width + target_width, *level_spec]
            self.FP_modules.append(PointFPModule(mlp_channels=mlp_widths))
            if level == final_fp_level:
                # No further FP level: do not consume another skip width.
                continue
            source_width = mlp_widths[-1]
            target_width = skip_widths.pop()
コード例 #2
0
    def __init__(self,
                 fp_channels=((768, 256, 256), (384, 256, 256),
                              (320, 256, 128), (128, 128, 128, 128)),
                 **kwargs):
        """Build the feature-propagation stack and the pre-segmentation conv.

        Args:
            fp_channels (tuple[tuple[int]]): MLP channel spec for each
                ``PointFPModule``; each entry is passed through unchanged as
                ``mlp_channels``. The last value of the last entry is the
                input width of ``pre_seg_conv``.
            **kwargs: Forwarded to the parent class constructor.
        """
        super(PointNet2Head, self).__init__(**kwargs)

        self.num_fp = len(fp_channels)
        self.FP_modules = nn.ModuleList(
            [PointFPModule(mlp_channels=fp_spec) for fp_spec in fp_channels])

        # https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L40
        seg_in_channels = fp_channels[-1][-1]
        self.pre_seg_conv = ConvModule(
            seg_in_channels,
            self.channels,
            kernel_size=1,
            bias=True,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
コード例 #3
0
ファイル: modules.py プロジェクト: Fawkes7/mmdetection3d
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, with_se=False, normalize=True, eps=0,
                                width_multiplier=1, voxel_resolution_multiplier=1):
    """Build the feature-propagation stages described by ``fp_blocks``.

    Each stage starts with a ``PointFPModule`` and is optionally followed by
    ``num_blocks`` convolution blocks (``SharedMLP`` when no voxel resolution
    is given, otherwise ``PVConv``).

    Args:
        fp_blocks: Iterable of ``(fp_configs, conv_configs)`` pairs.
            ``fp_configs`` is an iterable of output widths for the FP module;
            ``conv_configs`` is either ``None`` or a triple
            ``(out_channels, num_blocks, voxel_resolution)``.
        in_channels (int): Channel count entering the first stage.
        sa_in_channels: Indexable sequence of skip-connection widths; stage
            ``i`` reads ``sa_in_channels[-1 - i]``.
        with_se: Forwarded to ``PVConv``.
        normalize: Forwarded to ``PVConv``.
        eps: Forwarded to ``PVConv``.
        width_multiplier: Scale factor applied (then truncated with ``int``)
            to every configured channel width.
        voxel_resolution_multiplier: Scale factor applied (then truncated
            with ``int``) to every configured voxel resolution.

    Returns:
        tuple: ``(fp_layers, in_channels)`` where ``fp_layers`` is a Python
        list with one ``nn.ModuleList`` per stage and ``in_channels`` is the
        channel count leaving the last stage.
    """
    r, vr = width_multiplier, voxel_resolution_multiplier
    fp_layers = []
    for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
        # Use a dedicated name: rebinding ``fp_blocks`` here would shadow the
        # argument that is currently being iterated.
        stage_modules = []
        fp_out_channels = [int(r * oc) for oc in fp_configs]
        stage_modules.append(
            PointFPModule([in_channels + sa_in_channels[-1 - fp_idx]] + fp_out_channels))
        in_channels = fp_out_channels[-1]

        if conv_configs is not None:
            conv_out_channels, num_blocks, voxel_resolution = conv_configs
            conv_out_channels = int(r * conv_out_channels)
            if voxel_resolution is None:
                block = SharedMLP
            else:
                block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution),
                                          with_se=with_se, normalize=normalize, eps=eps)
            for _ in range(num_blocks):
                stage_modules.append(block(in_channels, conv_out_channels))
                in_channels = conv_out_channels

        fp_layers.append(nn.ModuleList(stage_modules))

    return fp_layers, in_channels
コード例 #4
0
def test_pointnet_fp_module():
    """Check PointFPModule's first MLP layer widths and its output shape.

    The test is skipped when CUDA is unavailable. It loads a point cloud from
    ``tests/data/sunrgbd/points/000001.bin``, builds two differently
    subsampled point sets with 12-channel features (xyz repeated four times),
    and asserts the propagated features have shape ``(1, 16, 50)``.
    """
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')
    from mmdet3d.ops import PointFPModule

    fp_module = PointFPModule(mlp_channels=[24, 16]).cuda()
    assert fp_module.mlps.layer0.conv.in_channels == 24
    assert fp_module.mlps.layer0.conv.out_channels == 16

    xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin',
                      np.float32).reshape((-1, 6))

    # (B, N, 3)
    xyz1 = torch.from_numpy(xyz[0::2, :3]).view(1, -1, 3).cuda()
    # (B, C1, N); derived from xyz1, so it is already on the GPU
    features1 = xyz1.repeat([1, 1, 4]).transpose(1, 2).contiguous()

    # (B, M, 3)
    xyz2 = torch.from_numpy(xyz[1::3, :3]).view(1, -1, 3).cuda()
    # (B, C2, M); derived from xyz2, so it is already on the GPU
    features2 = xyz2.repeat([1, 1, 4]).transpose(1, 2).contiguous()

    fp_features = fp_module(xyz1, xyz2, features1, features2)
    assert fp_features.shape == torch.Size([1, 16, 50])
コード例 #5
0
    def __init__(self,
                 in_channels,
                 num_points=(2048, 1024, 512, 256),
                 radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
                 num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
                 sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
                              ((64, 64, 128), (64, 64, 128), (64, 96, 128)),
                              ((128, 128, 256), (128, 192, 256), (128, 256,
                                                                  256))),
                 aggregation_channels=(64, 128, 256),
                 fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
                 fps_sample_range_lists=((-1), (-1), (512, -1)),
                 dilated_group=(True, True, True),
                 fp_channels=((256, 256), (256, 256)),
                 norm_cfg=dict(type='BN2d'),
                 sa_cfg=dict(type='PointSAModuleMSG',
                             pool_mod='max',
                             use_xyz=True,
                             normalize_xyz=False)):
        """Build the SA modules, aggregation MLPs and FP modules.

        For every SA level one module is created through ``build_sa_module``
        and one entry is appended to ``self.aggregation_mlps`` (a 1x1
        ``Conv1d`` ``ConvModule``, or ``None`` when no aggregation width is
        configured). FP modules are then built from the recorded per-level
        output widths, consumed from the deepest level backwards.

        Note:
            ``num_points``, ``radii``, ``num_samples``, ``sa_channels`` and
            ``aggregation_channels`` must all have the same length. The
            default ``num_points`` has four entries while the other defaults
            have three, so callers have to pass consistent values.

        Args:
            in_channels (int): Input channel count; the first SA level is
                given ``in_channels - 3`` input channels (channels without
                xyz).
            num_points (tuple[int | None]): Per-level point count. ``None``
                selects the plain ``PointSAModule`` branch for that level.
            radii (tuple[tuple[float]]): Per-level grouping radii.
            num_samples (tuple[tuple[int]]): Per-level sample counts, passed
                as ``sample_nums``.
            sa_channels (tuple[tuple[tuple[int]]]): Per-level, per-radius MLP
                widths; the level input width is prepended to each.
            aggregation_channels (tuple[int | None]): Per-level output width
                of the aggregation conv, or ``None`` to skip it.
            fps_mods (tuple): Per-level FPS mode; a tuple is converted to a
                list, any other value is wrapped in a one-element list.
            fps_sample_range_lists (tuple): Per-level sample ranges,
                normalised to a list the same way as ``fps_mods``.
            dilated_group (tuple[bool]): Per-level flag forwarded to
                ``build_sa_module``.
            fp_channels (tuple[tuple[int]]): Per-level FP MLP widths.
            norm_cfg (dict): Forwarded to ``build_sa_module``.
            sa_cfg (dict): Passed as ``cfg`` to ``build_sa_module``; in the
                ``num_points is None`` branch only its ``pool_mod``,
                ``use_xyz`` and ``normalize_xyz`` entries are reused.
        """
        super().__init__()
        self.num_sa = len(sa_channels)
        self.num_fp = len(fp_channels)

        assert len(num_points) == len(radii) == len(num_samples) == len(
            sa_channels) == len(aggregation_channels)

        self.SA_modules = nn.ModuleList()
        self.aggregation_mlps = nn.ModuleList()
        sa_in_channel = in_channels - 3  # number of channels without xyz
        skip_channel_list = [sa_in_channel]

        for sa_index in range(self.num_sa):
            cur_sa_mlps = list(sa_channels[sa_index])
            # Output width of this level: sum of the last width of every
            # per-radius MLP.
            sa_out_channel = 0
            for radius_index in range(len(radii[sa_index])):
                cur_sa_mlps[radius_index] = [sa_in_channel] + list(
                    cur_sa_mlps[radius_index])
                sa_out_channel += cur_sa_mlps[radius_index][-1]

            # A parenthesised single value such as ('D-FPS') or (-1) is not a
            # tuple, so scalars are wrapped into one-element lists here.
            fps_mod = fps_mods[sa_index]
            if isinstance(fps_mod, tuple):
                cur_fps_mod = list(fps_mod)
            else:
                cur_fps_mod = [fps_mod]

            fps_sample_range = fps_sample_range_lists[sa_index]
            if isinstance(fps_sample_range, tuple):
                cur_fps_sample_range_list = list(fps_sample_range)
            else:
                cur_fps_sample_range_list = [fps_sample_range]

            if num_points[sa_index] is not None:
                self.SA_modules.append(
                    build_sa_module(
                        num_point=num_points[sa_index],
                        radii=radii[sa_index],
                        sample_nums=num_samples[sa_index],
                        mlp_channels=cur_sa_mlps,
                        fps_mod=cur_fps_mod,
                        fps_sample_range_list=cur_fps_sample_range_list,
                        dilated_group=dilated_group[sa_index],
                        norm_cfg=norm_cfg,
                        cfg=sa_cfg,
                        bias=True))
            else:
                # NOTE(review): sa_out_channel is only accumulated in the
                # per-radius loop above, so it stays 0 when radii[sa_index]
                # is empty, and cur_sa_mlps holds nested lists when it is
                # not. Confirm which configuration this branch expects.
                self.SA_modules.append(
                    build_sa_module(
                        num_point=None,
                        radius=None,
                        num_sample=None,
                        mlp_channels=[sa_in_channel] + cur_sa_mlps,
                        norm_cfg=norm_cfg,
                        cfg=dict(type='PointSAModule',
                                 pool_mod=sa_cfg.get('pool_mod'),
                                 use_xyz=sa_cfg.get('use_xyz'),
                                 normalize_xyz=sa_cfg.get('normalize_xyz'))))

            if aggregation_channels[sa_index] is not None:
                self.aggregation_mlps.append(
                    ConvModule(sa_out_channel,
                               aggregation_channels[sa_index],
                               conv_cfg=dict(type='Conv1d'),
                               norm_cfg=dict(type='BN1d'),
                               kernel_size=1,
                               bias=True))
                sa_in_channel = aggregation_channels[sa_index]
            else:
                # Keep the list aligned with SA_modules even without a conv.
                self.aggregation_mlps.append(None)
                sa_in_channel = sa_out_channel
            skip_channel_list.append(sa_in_channel)

        self.FP_modules = nn.ModuleList()
        # Start from the two deepest recorded widths and walk back up.
        fp_source_channel = skip_channel_list.pop()
        fp_target_channel = skip_channel_list.pop()
        last_fp_index = len(fp_channels) - 1
        for fp_index, fp_spec in enumerate(fp_channels):
            cur_fp_mlps = [fp_source_channel + fp_target_channel] + list(
                fp_spec)
            self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps))
            if fp_index != last_fp_index:
                fp_source_channel = cur_fp_mlps[-1]
                fp_target_channel = skip_channel_list.pop()