def test_pointnet_sa_module():
    """Exercise ``PointSAModule`` built through ``build_sa_module``.

    Checks the first MLP layer's channel widths and the forward output
    shapes for ball-query grouping (``radius=0.2``). Checks that
    ``normalize_xyz=True`` is rejected when ``radius`` is None. Checks the
    forward output shapes for kNN grouping (``radius=None``). Skipped
    when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        pytest.skip()
    from mmdet3d.ops import build_sa_module

    points_path = 'tests/data/sunrgbd/points/000001.bin'

    def load_points_and_features():
        # Coordinates as (B, N, 3). Features as (B, C, N), made by tiling
        # the coordinates four times along the last axis and transposing.
        raw = np.fromfile(points_path, np.float32)
        points = torch.from_numpy(raw[..., :3]).view(1, -1, 3).cuda()
        feats = points.repeat([1, 1, 4]).transpose(1, 2).contiguous().cuda()
        return points, feats

    def run_and_check(module):
        points, feats = load_points_and_features()
        out_xyz, out_feats, out_inds = module(points, feats)
        assert out_xyz.shape == torch.Size([1, 16, 3])
        assert out_feats.shape == torch.Size([1, 32, 16])
        assert out_inds.shape == torch.Size([1, 16])

    # Ball-query grouping.
    ball_cfg = dict(
        type='PointSAModule',
        num_point=16,
        radius=0.2,
        num_sample=8,
        mlp_channels=[12, 32],
        norm_cfg=dict(type='BN2d'),
        use_xyz=True,
        pool_mod='max')
    sa_module = build_sa_module(ball_cfg).cuda()

    # Expected first-layer input: 12 configured channels + 3 (use_xyz=True).
    first_conv = sa_module.mlps[0].layer0.conv
    assert first_conv.in_channels == 15
    assert first_conv.out_channels == 32

    run_and_check(sa_module)

    # normalize_xyz needs a radius, so this config must be rejected.
    knn_cfg = dict(
        type='PointSAModule',
        num_point=16,
        radius=None,
        num_sample=8,
        mlp_channels=[12, 32],
        norm_cfg=dict(type='BN2d'),
        use_xyz=True,
        pool_mod='max',
        normalize_xyz=True)
    with pytest.raises(AssertionError):
        build_sa_module(knn_cfg)

    # kNN grouping (radius=None) is accepted once normalize_xyz is off.
    knn_cfg['normalize_xyz'] = False
    sa_module = build_sa_module(knn_cfg).cuda()
    run_and_check(sa_module)
def test_paconv_sa_module():
    """Exercise ``PAConvSAModule`` built through ``build_sa_module``.

    Checks the first PAConv layer's configuration and the forward output
    shapes in two setups. The first is ball-query grouping with
    ``paconv_kernel_input='w_neighbor'``. The second is kNN grouping
    (``radius=None``) with ``paconv_kernel_input='identity'``. Skipped
    when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        pytest.skip()
    from mmdet3d.ops import build_sa_module

    points_path = 'tests/data/sunrgbd/points/000001.bin'

    def load_points_and_features():
        # Coordinates as (B, N, 3). Features as (B, C, N), made by tiling
        # the coordinates four times along the last axis and transposing.
        raw = np.fromfile(points_path, np.float32)
        points = torch.from_numpy(raw[..., :3]).view(1, -1, 3).cuda()
        feats = points.repeat([1, 1, 4]).transpose(1, 2).contiguous().cuda()
        return points, feats

    def run_and_check(module):
        points, feats = load_points_and_features()
        out_xyz, out_feats, out_inds = module(points, feats)
        assert out_xyz.shape == torch.Size([1, 16, 3])
        assert out_feats.shape == torch.Size([1, 32, 16])
        assert out_inds.shape == torch.Size([1, 16])

    # Ball-query grouping with 'w_neighbor' kernel input.
    ball_cfg = dict(
        type='PAConvSAModule',
        num_point=16,
        radius=0.2,
        num_sample=8,
        mlp_channels=[12, 32],
        paconv_num_kernels=[8],
        norm_cfg=dict(type='BN2d'),
        use_xyz=True,
        pool_mod='max',
        paconv_kernel_input='w_neighbor')
    sa_module = build_sa_module(ball_cfg).cuda()

    # Expected: 'w_neighbor' gives twice the 15-channel width that
    # 'identity' gives below.
    first_layer = sa_module.mlps[0].layer0
    assert first_layer.in_channels == 15 * 2
    assert first_layer.out_channels == 32
    assert first_layer.num_kernels == 8

    run_and_check(sa_module)

    # kNN grouping (radius=None) with 'identity' kernel input.
    knn_cfg = dict(
        type='PAConvSAModule',
        num_point=16,
        radius=None,
        num_sample=8,
        mlp_channels=[12, 32],
        paconv_num_kernels=[8],
        norm_cfg=dict(type='BN2d'),
        use_xyz=True,
        pool_mod='max',
        paconv_kernel_input='identity')
    sa_module = build_sa_module(knn_cfg).cuda()
    assert sa_module.mlps[0].layer0.in_channels == 15 * 1

    run_and_check(sa_module)
    def __init__(self,
                 rep_type=None,
                 density=None,
                 seed_feat_dim=None,
                 sa_cfg=dict(type='PointSAModule',
                             pool_mod='max',
                             use_xyz=True,
                             normalize_xyz=True),
                 sa_radius=0.2,
                 sa_num_sample=16,
                 num_seed_points=1024):
        """Build the seed-aggregation SA module and the 1x1 reduction conv.

        Args:
            rep_type (str, optional): Representative-point layout. ``'ray'``
                yields ``density * 6`` representative points. Any other
                value yields ``density ** 3``.
            density (int, optional): Used to derive ``num_rep_points``. It
                must be provided, because ``None`` fails in that arithmetic.
            seed_feat_dim (int, optional): First width of the seed
                aggregation MLP ``[seed_feat_dim, 128, 64, 32]``.
            sa_cfg (dict): Config passed to ``build_sa_module``.
            sa_radius (float): ``radius`` for the seed aggregation module.
            sa_num_sample (int): ``num_sample`` for the seed aggregation
                module.
            num_seed_points (int): ``num_point`` for the seed aggregation
                module.
        """
        super(RepPointRoIExtractor, self).__init__()
        self.rep_type = rep_type
        self.density = density
        if rep_type == 'ray':
            self.num_rep_points = density * 6
        else:
            self.num_rep_points = density**3

        seed_mlp_channels = [seed_feat_dim, 128, 64, 32]
        self.seed_aggregation = build_sa_module(
            cfg=sa_cfg,
            num_point=num_seed_points,
            radius=sa_radius,
            num_sample=sa_num_sample,
            mlp_channels=seed_mlp_channels,
            norm_cfg=dict(type='BN2d'))

        # Input width: num_rep_points times the seed MLP's final width (32).
        self.reduce_dim = torch.nn.Conv1d(self.num_rep_points * 32, 128, 1)
Example #4
0
    def __init__(self,
                 num_classes,
                 bbox_coder,
                 train_cfg=None,
                 test_cfg=None,
                 vote_module_cfg=None,
                 vote_aggregation_cfg=None,
                 pred_layer_cfg=None,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 center_loss_mse=None,
                 dir_class_loss=None,
                 dir_res_loss=None,
                 size_class_loss=None,
                 size_res_loss=None,
                 semantic_loss=None,
                 iou_loss=None):
        """Build losses, the bbox coder, the vote/aggregation modules and head.

        Args:
            num_classes (int): Number of object classes.
            bbox_coder (dict): Config passed to ``build_bbox_coder``. The
                built coder supplies ``num_sizes`` and ``num_dir_bins``.
            train_cfg (dict, optional): Stored as ``self.train_cfg``.
            test_cfg (dict, optional): Stored as ``self.test_cfg``.
            vote_module_cfg (dict): Keyword arguments for ``VoteModule``.
                Must contain ``'gt_per_seed'``.
            vote_aggregation_cfg (dict): Config passed to
                ``build_sa_module``. Must contain ``'num_point'``.
            pred_layer_cfg (dict): Keyword arguments for
                ``BaseConvBboxHead``. It is unpacked with ``**``, so the
                ``None`` default raises ``TypeError``.
            conv_cfg (dict): Not used by this constructor.
            norm_cfg (dict): Not used by this constructor.
            objectness_loss, dir_class_loss, dir_res_loss, size_res_loss
                (dict): Always passed to ``build_loss``.
            size_class_loss, semantic_loss, center_loss, center_loss_mse
                (dict, optional): Built only when given. Otherwise this
                constructor does not set the corresponding attribute.
            iou_loss (dict, optional): Built when given. Otherwise
                ``self.iou_loss`` is set to ``None``.
        """
        super(VoteHead, self).__init__()
        self.num_classes = num_classes
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = vote_module_cfg['gt_per_seed']
        self.num_proposal = vote_aggregation_cfg['num_point']

        self.objectness_loss = build_loss(objectness_loss)
        self.dir_res_loss = build_loss(dir_res_loss)
        self.dir_class_loss = build_loss(dir_class_loss)
        self.size_res_loss = build_loss(size_res_loss)
        # Optional losses: assigned only when a config is provided. Only
        # iou_loss gets an explicit None fallback.
        if size_class_loss is not None:
            self.size_class_loss = build_loss(size_class_loss)
        if semantic_loss is not None:
            self.semantic_loss = build_loss(semantic_loss)
        if iou_loss is not None:
            self.iou_loss = build_loss(iou_loss)
        else:
            self.iou_loss = None
        if center_loss is not None:
            self.center_loss = build_loss(center_loss)
        if center_loss_mse is not None:
            self.center_loss_mse = build_loss(center_loss_mse)

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_sizes = self.bbox_coder.num_sizes
        self.num_dir_bins = self.bbox_coder.num_dir_bins

        self.vote_module = VoteModule(**vote_module_cfg)
        self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
        self.fp16_enabled = False

        # Bbox classification and regression
        # Output widths come from _get_cls_out_channels and
        # _get_reg_out_channels, which are defined elsewhere in the class.
        self.conv_pred = BaseConvBboxHead(
            **pred_layer_cfg,
            num_cls_out_channels=self._get_cls_out_channels(),
            num_reg_out_channels=self._get_reg_out_channels())
    def __init__(self,
                 in_channels,
                 num_points=(2048, 1024, 512, 256),
                 radius=(0.2, 0.4, 0.8, 1.2),
                 num_samples=(64, 32, 16, 16),
                 sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
                              (128, 128, 256)),
                 fp_channels=((256, 256), (256, 256)),
                 norm_cfg=dict(type='BN2d'),
                 sa_cfg=dict(
                     type='PointSAModule',
                     pool_mod='max',
                     use_xyz=True,
                     normalize_xyz=True)):
        """Build the set-abstraction (SA) and feature-propagation (FP) stacks.

        Args:
            in_channels (int): Input point width including the 3 xyz
                channels. ``in_channels - 3`` is the first SA stage's
                feature width.
            num_points, radius, num_samples (tuple): Per-SA-stage
                ``num_point``, ``radius`` and ``num_sample``. They must
                have the same length as ``sa_channels``.
            sa_channels (tuple[tuple[int]]): MLP widths of each SA stage,
                excluding the stage's input width (prepended here).
            fp_channels (tuple[tuple[int]]): MLP widths of each FP stage,
                excluding its input width. There must be no more entries
                than in ``sa_channels``.
            norm_cfg (dict): Passed to every SA module.
            sa_cfg (dict): Config passed to ``build_sa_module`` for every
                SA stage.
        """
        super().__init__()
        self.num_sa = len(sa_channels)
        self.num_fp = len(fp_channels)

        assert len(num_points) == len(radius) == len(num_samples) == len(
            sa_channels)
        assert len(sa_channels) >= len(fp_channels)

        self.SA_modules = nn.ModuleList()
        sa_in_channel = in_channels - 3  # number of channels without xyz
        # Feature widths: raw input first, then each SA stage's output.
        skip_channel_list = [sa_in_channel]

        for sa_index in range(self.num_sa):
            # Prepend this stage's input width to its configured MLP.
            cur_sa_mlps = list(sa_channels[sa_index])
            cur_sa_mlps = [sa_in_channel] + cur_sa_mlps
            sa_out_channel = cur_sa_mlps[-1]

            self.SA_modules.append(
                build_sa_module(
                    num_point=num_points[sa_index],
                    radius=radius[sa_index],
                    num_sample=num_samples[sa_index],
                    mlp_channels=cur_sa_mlps,
                    norm_cfg=norm_cfg,
                    cfg=sa_cfg))
            skip_channel_list.append(sa_out_channel)
            sa_in_channel = sa_out_channel

        self.FP_modules = nn.ModuleList()

        # The first FP stage combines the deepest SA output (source) with
        # the width recorded just before it (target). Each later stage
        # uses the previous FP output as source and the next shallower
        # recorded width as target. Its MLP input is source + target.
        fp_source_channel = skip_channel_list.pop()
        fp_target_channel = skip_channel_list.pop()
        for fp_index in range(len(fp_channels)):
            cur_fp_mlps = list(fp_channels[fp_index])
            cur_fp_mlps = [fp_source_channel + fp_target_channel] + cur_fp_mlps
            self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps))
            if fp_index != len(fp_channels) - 1:
                fp_source_channel = cur_fp_mlps[-1]
                fp_target_channel = skip_channel_list.pop()
Example #6
0
    def __init__(self,
                 num_dims,
                 num_classes,
                 primitive_mode,
                 train_cfg=None,
                 test_cfg=None,
                 vote_module_cfg=None,
                 vote_aggregation_cfg=None,
                 feat_channels=(128, 128),
                 upper_thresh=100.0,
                 surface_thresh=0.5,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 semantic_reg_loss=None,
                 semantic_cls_loss=None,
                 init_cfg=None):
        """Build losses, the primitive flag branch, voting and prediction.

        Args:
            num_dims (int): Dimension of the primitive semantic
                information. It contributes ``num_dims`` output channels.
            num_classes (int): Number of classes. It contributes
                ``num_classes`` output channels.
            primitive_mode (str): One of ``'z'``, ``'xy'`` or ``'line'``
                (checked by assertion).
            train_cfg (dict, optional): Stored as ``self.train_cfg``.
            test_cfg (dict, optional): Stored as ``self.test_cfg``.
            vote_module_cfg (dict): Keyword arguments for ``VoteModule``.
                Must contain ``'gt_per_seed'``, ``'in_channels'`` and
                ``'conv_channels'``.
            vote_aggregation_cfg (dict): Config for ``build_sa_module``.
                Must contain ``'num_point'`` and ``'mlp_channels'``, and
                ``mlp_channels[0]`` must equal
                ``vote_module_cfg['in_channels']``.
            feat_channels (tuple[int]): Widths of the ConvModules in
                ``self.conv_pred`` that precede the output conv.
            upper_thresh (float): Stored as ``self.upper_thresh``.
            surface_thresh (float): Stored as ``self.surface_thresh``.
            conv_cfg (dict): Conv config for the ConvModules built here.
            norm_cfg (dict): Norm config for the ConvModules built here.
            objectness_loss, center_loss, semantic_reg_loss,
            semantic_cls_loss (dict): Passed to ``build_loss``.
            init_cfg (dict, optional): Forwarded to the parent constructor.
        """
        super(PrimitiveHead, self).__init__(init_cfg=init_cfg)
        assert primitive_mode in ['z', 'xy', 'line']
        # The dimension of primitive semantic information.
        self.num_dims = num_dims
        self.num_classes = num_classes
        self.primitive_mode = primitive_mode
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = vote_module_cfg['gt_per_seed']
        self.num_proposal = vote_aggregation_cfg['num_point']
        self.upper_thresh = upper_thresh
        self.surface_thresh = surface_thresh

        self.objectness_loss = build_loss(objectness_loss)
        self.center_loss = build_loss(center_loss)
        self.semantic_reg_loss = build_loss(semantic_reg_loss)
        self.semantic_cls_loss = build_loss(semantic_cls_loss)

        # The aggregation MLP must start at the vote module's input width.
        assert vote_aggregation_cfg['mlp_channels'][0] == vote_module_cfg[
            'in_channels']

        # Primitive existence flag prediction
        # A ConvModule halves the last vote-module conv width, then a
        # 1x1 Conv1d produces 2 output channels.
        self.flag_conv = ConvModule(
            vote_module_cfg['conv_channels'][-1],
            vote_module_cfg['conv_channels'][-1] // 2,
            1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=True,
            inplace=True)
        self.flag_pred = torch.nn.Conv1d(
            vote_module_cfg['conv_channels'][-1] // 2, 2, 1)

        self.vote_module = VoteModule(**vote_module_cfg)
        self.vote_aggregation = build_sa_module(vote_aggregation_cfg)

        # Prediction trunk: one ConvModule per feat_channels entry,
        # starting from the aggregation MLP's output width.
        prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
        conv_pred_list = list()
        for k in range(len(feat_channels)):
            conv_pred_list.append(
                ConvModule(
                    prev_channel,
                    feat_channels[k],
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    inplace=True))
            prev_channel = feat_channels[k]
        self.conv_pred = nn.Sequential(*conv_pred_list)

        # Output width is 3 + num_dims + num_classes. The leading 3 is
        # presumably a 3D center offset; confirm against forward/loss code.
        conv_out_channel = 3 + num_dims + num_classes
        self.conv_pred.add_module('conv_out',
                                  nn.Conv1d(prev_channel, conv_out_channel, 1))
Example #7
0
    def __init__(self,
                 num_classes,
                 suface_matching_cfg,
                 line_matching_cfg,
                 bbox_coder,
                 train_cfg=None,
                 test_cfg=None,
                 gt_per_seed=1,
                 num_proposal=256,
                 feat_channels=(128, 128),
                 primitive_feat_refine_streams=2,
                 primitive_refine_channels=[128, 128, 128],
                 upper_thresh=100.0,
                 surface_thresh=0.5,
                 line_thresh=0.5,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 dir_class_loss=None,
                 dir_res_loss=None,
                 size_class_loss=None,
                 size_res_loss=None,
                 semantic_loss=None,
                 cues_objectness_loss=None,
                 cues_semantic_loss=None,
                 proposal_objectness_loss=None,
                 primitive_center_loss=None):
        """Build losses, primitive matchers, aggregation stacks and bbox head.

        Args:
            num_classes (int): Number of classes. It contributes
                ``num_classes`` channels to the final prediction conv.
            suface_matching_cfg (dict): Config for ``build_sa_module`` for
                surface-center matching. Its ``mlp_channels[-1]`` sets the
                matching feature width. NOTE: the misspelled name is part
                of the keyword interface; do not rename.
            line_matching_cfg (dict): Config for ``build_sa_module`` for
                line-center matching. Its ``mlp_channels[-1]`` must equal
                the surface config's.
            bbox_coder (dict): Read directly for ``'with_rot'``,
                ``'num_dir_bins'`` and ``'num_sizes'``, and passed to
                ``build_bbox_coder``.
            train_cfg (dict, optional): Stored as ``self.train_cfg``.
            test_cfg (dict, optional): Stored as ``self.test_cfg``.
            gt_per_seed, num_proposal, upper_thresh, surface_thresh,
            line_thresh: Stored on ``self`` unchanged.
            feat_channels (tuple[int]): Not used by this constructor.
            primitive_feat_refine_streams (int): Number of ConvModules in
                each of the surface and line aggregation stacks.
            primitive_refine_channels (list[int]): Widths of the
                ConvModules in ``self.bbox_pred`` (only read, never
                mutated).
            conv_cfg (dict): Conv config for the ConvModules built here.
            norm_cfg (dict): Norm config for the ConvModules built here.
            objectness_loss ... primitive_center_loss (dict): Each passed
                to ``build_loss``.
        """
        super(H3DBboxHead, self).__init__()
        self.num_classes = num_classes
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = gt_per_seed
        self.num_proposal = num_proposal
        self.with_angle = bbox_coder['with_rot']
        self.upper_thresh = upper_thresh
        self.surface_thresh = surface_thresh
        self.line_thresh = line_thresh

        self.objectness_loss = build_loss(objectness_loss)
        self.center_loss = build_loss(center_loss)
        self.dir_class_loss = build_loss(dir_class_loss)
        self.dir_res_loss = build_loss(dir_res_loss)
        self.size_class_loss = build_loss(size_class_loss)
        self.size_res_loss = build_loss(size_res_loss)
        self.semantic_loss = build_loss(semantic_loss)

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_sizes = self.bbox_coder.num_sizes
        self.num_dir_bins = self.bbox_coder.num_dir_bins

        self.cues_objectness_loss = build_loss(cues_objectness_loss)
        self.cues_semantic_loss = build_loss(cues_semantic_loss)
        self.proposal_objectness_loss = build_loss(proposal_objectness_loss)
        self.primitive_center_loss = build_loss(primitive_center_loss)

        # Both matchers must produce the same feature width, because the
        # layers below are sized from the surface matcher's output.
        assert suface_matching_cfg['mlp_channels'][-1] == \
            line_matching_cfg['mlp_channels'][-1]

        # surface center matching
        self.surface_center_matcher = build_sa_module(suface_matching_cfg)
        # line center matching
        self.line_center_matcher = build_sa_module(line_matching_cfg)

        # Compute the matching scores
        matching_feat_dims = suface_matching_cfg['mlp_channels'][-1]
        self.matching_conv = ConvModule(matching_feat_dims,
                                        matching_feat_dims,
                                        1,
                                        padding=0,
                                        conv_cfg=conv_cfg,
                                        norm_cfg=norm_cfg,
                                        bias=True,
                                        inplace=True)
        self.matching_pred = nn.Conv1d(matching_feat_dims, 2, 1)

        # Compute the semantic matching scores
        self.semantic_matching_conv = ConvModule(matching_feat_dims,
                                                 matching_feat_dims,
                                                 1,
                                                 padding=0,
                                                 conv_cfg=conv_cfg,
                                                 norm_cfg=norm_cfg,
                                                 bias=True,
                                                 inplace=True)
        self.semantic_matching_pred = nn.Conv1d(matching_feat_dims, 2, 1)

        # Surface feature aggregation
        # Collect ConvModules in a temporary list, then wrap them in
        # nn.Sequential. A plain list attribute would not register them
        # as submodules.
        self.surface_feats_aggregation = list()
        for k in range(primitive_feat_refine_streams):
            self.surface_feats_aggregation.append(
                ConvModule(matching_feat_dims,
                           matching_feat_dims,
                           1,
                           padding=0,
                           conv_cfg=conv_cfg,
                           norm_cfg=norm_cfg,
                           bias=True,
                           inplace=True))
        self.surface_feats_aggregation = nn.Sequential(
            *self.surface_feats_aggregation)

        # Line feature aggregation
        # Same list-then-Sequential pattern as the surface stack above.
        self.line_feats_aggregation = list()
        for k in range(primitive_feat_refine_streams):
            self.line_feats_aggregation.append(
                ConvModule(matching_feat_dims,
                           matching_feat_dims,
                           1,
                           padding=0,
                           conv_cfg=conv_cfg,
                           norm_cfg=norm_cfg,
                           bias=True,
                           inplace=True))
        self.line_feats_aggregation = nn.Sequential(
            *self.line_feats_aggregation)

        # surface center(6) + line center(12)
        prev_channel = 18 * matching_feat_dims
        self.bbox_pred = nn.ModuleList()
        for k in range(len(primitive_refine_channels)):
            self.bbox_pred.append(
                ConvModule(prev_channel,
                           primitive_refine_channels[k],
                           1,
                           padding=0,
                           conv_cfg=conv_cfg,
                           norm_cfg=norm_cfg,
                           bias=True,
                           inplace=False))
            prev_channel = primitive_refine_channels[k]

        # Final object detection
        # Objectness scores (2), center residual (3),
        # heading class+residual (num_heading_bin*2), size class +
        # residual(num_size_cluster*4), plus num_classes channels.
        # The bin/size counts are read from the bbox_coder config dict.
        conv_out_channel = (2 + 3 + bbox_coder['num_dir_bins'] * 2 +
                            bbox_coder['num_sizes'] * 4 + self.num_classes)
        self.bbox_pred.append(nn.Conv1d(prev_channel, conv_out_channel, 1))
    def __init__(self,
                 in_channels,
                 num_points=(2048, 1024, 512, 256),
                 radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
                 num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
                 sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
                              ((64, 64, 128), (64, 64, 128), (64, 96, 128)),
                              ((128, 128, 256), (128, 192, 256), (128, 256,
                                                                  256))),
                 aggregation_channels=(64, 128, 256),
                 fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
                 fps_sample_range_lists=((-1), (-1), (512, -1)),
                 dilated_group=(True, True, True),
                 out_indices=(2, ),
                 norm_cfg=dict(type='BN2d'),
                 sa_cfg=dict(type='PointSAModuleMSG',
                             pool_mod='max',
                             use_xyz=True,
                             normalize_xyz=False),
                 init_cfg=None):
        """Build multi-scale-grouping SA stages and their aggregation MLPs.

        Each SA stage builds one MLP per radius. The stage output width
        ``sa_out_channel`` is the sum of the per-radius MLPs' last widths.
        When ``aggregation_channels[i]`` is not None, a 1x1 ``ConvModule``
        maps that width to ``aggregation_channels[i]``, which becomes the
        next stage's input width. Otherwise the stage output width is used
        directly and ``None`` is stored in ``self.aggregation_mlps``.

        Args:
            in_channels (int): Input point width including the 3 xyz
                channels.
            num_points, radii, num_samples, sa_channels,
            aggregation_channels (tuple): Per-stage settings. They must all
                have the same length.
            fps_mods (tuple): Per-stage FPS mode(s). A bare string such as
                ``('D-FPS')`` is not a tuple (no comma), so it is wrapped
                in a one-element list.
            fps_sample_range_lists (tuple): Per-stage FPS sample ranges.
                Non-tuple entries are wrapped the same way.
            dilated_group (tuple[bool]): Per-stage ``dilated_group`` flag.
            out_indices (tuple[int]): Stored. Every index must be smaller
                than the number of SA stages.
            norm_cfg (dict): Passed to every SA module.
            sa_cfg (dict): Config passed to ``build_sa_module``.
            init_cfg (dict, optional): Forwarded to the parent constructor.

        Note:
            With the default arguments, ``num_points`` has four entries
            while ``radii``, ``num_samples``, ``sa_channels`` and
            ``aggregation_channels`` have three. The length assertion
            therefore fails unless these are overridden consistently.
        """
        super().__init__(init_cfg=init_cfg)
        self.num_sa = len(sa_channels)
        self.out_indices = out_indices
        assert max(out_indices) < self.num_sa
        assert len(num_points) == len(radii) == len(num_samples) == len(
            sa_channels) == len(aggregation_channels)

        self.SA_modules = nn.ModuleList()
        self.aggregation_mlps = nn.ModuleList()
        sa_in_channel = in_channels - 3  # number of channels without xyz

        for sa_index in range(self.num_sa):
            # Prefix each per-radius MLP with this stage's input width and
            # sum their output widths.
            cur_sa_mlps = list(sa_channels[sa_index])
            sa_out_channel = 0
            for radius_index in range(len(radii[sa_index])):
                cur_sa_mlps[radius_index] = [sa_in_channel] + list(
                    cur_sa_mlps[radius_index])
                sa_out_channel += cur_sa_mlps[radius_index][-1]

            if isinstance(fps_mods[sa_index], tuple):
                cur_fps_mod = list(fps_mods[sa_index])
            else:
                cur_fps_mod = list([fps_mods[sa_index]])

            if isinstance(fps_sample_range_lists[sa_index], tuple):
                cur_fps_sample_range_list = list(
                    fps_sample_range_lists[sa_index])
            else:
                cur_fps_sample_range_list = list(
                    [fps_sample_range_lists[sa_index]])

            self.SA_modules.append(
                build_sa_module(
                    num_point=num_points[sa_index],
                    radii=radii[sa_index],
                    sample_nums=num_samples[sa_index],
                    mlp_channels=cur_sa_mlps,
                    fps_mod=cur_fps_mod,
                    fps_sample_range_list=cur_fps_sample_range_list,
                    dilated_group=dilated_group[sa_index],
                    norm_cfg=norm_cfg,
                    cfg=sa_cfg,
                    bias=True))

            cur_aggregation_channel = aggregation_channels[sa_index]
            if cur_aggregation_channel is None:
                self.aggregation_mlps.append(None)
                sa_in_channel = sa_out_channel
            else:
                self.aggregation_mlps.append(
                    ConvModule(sa_out_channel,
                               cur_aggregation_channel,
                               conv_cfg=dict(type='Conv1d'),
                               norm_cfg=dict(type='BN1d'),
                               kernel_size=1,
                               bias=True))
                sa_in_channel = cur_aggregation_channel
    def __init__(self,
                 in_channels,
                 num_points=(2048, 1024, 512, 256),
                 radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
                 num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
                 sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
                              ((64, 64, 128), (64, 64, 128), (64, 96, 128)),
                              ((128, 128, 256), (128, 192, 256), (128, 256,
                                                                  256))),
                 aggregation_channels=(64, 128, 256),
                 fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
                 fps_sample_range_lists=((-1), (-1), (512, -1)),
                 dilated_group=(True, True, True),
                 fp_channels=((256, 256), (256, 256)),
                 norm_cfg=dict(type='BN2d'),
                 sa_cfg=dict(type='PointSAModuleMSG',
                             pool_mod='max',
                             use_xyz=True,
                             normalize_xyz=False)):
        """Build MSG set-abstraction stages, aggregation MLPs and FP stages.

        A stage whose ``num_points`` entry is not None uses the module type
        in ``sa_cfg``, with one MLP per radius. A stage whose entry is None
        is built as a group-all ``PointSAModule`` (``num_point``,
        ``radius`` and ``num_sample`` all None) that reuses ``pool_mod``,
        ``use_xyz`` and ``normalize_xyz`` from ``sa_cfg``. Its MLP is
        ``[stage input width] + sa_channels[i]``. That is a flat list of
        widths only when ``radii[i]`` is empty (no per-radius prefixing).

        After each stage, ``aggregation_channels[i]`` (if not None) adds a
        1x1 ``ConvModule``. The resulting width is both the next stage's
        input and the value recorded for the FP skip connections.

        Args:
            in_channels (int): Input point width including the 3 xyz
                channels.
            num_points, radii, num_samples, sa_channels,
            aggregation_channels (tuple): Per-stage settings. They must all
                have the same length.
            fps_mods, fps_sample_range_lists (tuple): Per-stage FPS
                settings. Non-tuple entries (e.g. ``('D-FPS')``, which is
                just a string) are wrapped in a one-element list.
            dilated_group (tuple[bool]): Per-stage ``dilated_group`` flag
                for MSG stages.
            fp_channels (tuple[tuple[int]]): MLP widths of each FP stage,
                excluding its input width.
            norm_cfg (dict): Passed to every SA module.
            sa_cfg (dict): Config passed to ``build_sa_module`` for MSG
                stages.

        Note:
            With the default arguments, ``num_points`` has four entries
            while the other per-stage tuples have three, so the length
            assertion fails unless they are overridden consistently.
        """
        super().__init__()
        self.num_sa = len(sa_channels)
        self.num_fp = len(fp_channels)

        assert len(num_points) == len(radii) == len(num_samples) == len(
            sa_channels) == len(aggregation_channels)

        self.SA_modules = nn.ModuleList()
        self.aggregation_mlps = nn.ModuleList()
        sa_in_channel = in_channels - 3  # number of channels without xyz
        # Widths fed to each stage (raw input first), used by FP below.
        skip_channel_list = [sa_in_channel]

        for sa_index in range(self.num_sa):
            # Prefix each per-radius MLP with this stage's input width and
            # sum their output widths.
            cur_sa_mlps = list(sa_channels[sa_index])
            sa_out_channel = 0
            for radius_index in range(len(radii[sa_index])):
                cur_sa_mlps[radius_index] = [sa_in_channel] + list(
                    cur_sa_mlps[radius_index])
                sa_out_channel += cur_sa_mlps[radius_index][-1]

            if isinstance(fps_mods[sa_index], tuple):
                cur_fps_mod = list(fps_mods[sa_index])
            else:
                cur_fps_mod = list([fps_mods[sa_index]])

            if isinstance(fps_sample_range_lists[sa_index], tuple):
                cur_fps_sample_range_list = list(
                    fps_sample_range_lists[sa_index])
            else:
                cur_fps_sample_range_list = list(
                    [fps_sample_range_lists[sa_index]])
            if num_points[sa_index] is not None:
                # Multi-scale grouping stage.
                self.SA_modules.append(
                    build_sa_module(
                        num_point=num_points[sa_index],
                        radii=radii[sa_index],
                        sample_nums=num_samples[sa_index],
                        mlp_channels=cur_sa_mlps,
                        fps_mod=cur_fps_mod,
                        fps_sample_range_list=cur_fps_sample_range_list,
                        dilated_group=dilated_group[sa_index],
                        norm_cfg=norm_cfg,
                        cfg=sa_cfg,
                        bias=True))
            else:
                # Group-all stage with a single MLP. The per-radius loop
                # above does not run for the flat-MLP layout, so
                # sa_out_channel would stay 0. Take the width from the
                # MLP's last entry instead.
                group_all_mlp = [sa_in_channel] + cur_sa_mlps
                sa_out_channel = group_all_mlp[-1]
                self.SA_modules.append(
                    build_sa_module(
                        num_point=None,
                        radius=None,
                        num_sample=None,
                        mlp_channels=group_all_mlp,
                        norm_cfg=norm_cfg,
                        cfg=dict(type='PointSAModule',
                                 pool_mod=sa_cfg.get('pool_mod'),
                                 use_xyz=sa_cfg.get('use_xyz'),
                                 normalize_xyz=sa_cfg.get('normalize_xyz'))))
            if aggregation_channels[sa_index] is not None:
                self.aggregation_mlps.append(
                    ConvModule(sa_out_channel,
                               aggregation_channels[sa_index],
                               conv_cfg=dict(type='Conv1d'),
                               norm_cfg=dict(type='BN1d'),
                               kernel_size=1,
                               bias=True))
                sa_in_channel = aggregation_channels[sa_index]
            else:
                self.aggregation_mlps.append(None)
                sa_in_channel = sa_out_channel
            skip_channel_list.append(sa_in_channel)

        self.FP_modules = nn.ModuleList()
        # The first FP stage combines the deepest recorded width (source)
        # with the one before it (target). Later stages use the previous
        # FP output as source and pop the next shallower width as target.
        fp_source_channel = skip_channel_list.pop()
        fp_target_channel = skip_channel_list.pop()
        for fp_index in range(len(fp_channels)):
            cur_fp_mlps = list(fp_channels[fp_index])
            cur_fp_mlps = [fp_source_channel + fp_target_channel] + cur_fp_mlps
            self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps))
            if fp_index != len(fp_channels) - 1:
                fp_source_channel = cur_fp_mlps[-1]
                fp_target_channel = skip_channel_list.pop()
Example #10
0
    def __init__(self,
                 num_classes,
                 bbox_coder,
                 train_cfg=None,
                 test_cfg=None,
                 vote_moudule_cfg=None,
                 vote_aggregation_cfg=None,
                 feat_channels=(128, 128),
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 dir_class_loss=None,
                 dir_res_loss=None,
                 size_class_loss=None,
                 size_res_loss=None,
                 semantic_loss=None):
        """Build losses, the bbox coder, voting/aggregation and the pred head.

        Args:
            num_classes (int): Number of classes. It contributes
                ``num_classes`` channels to the output conv.
            bbox_coder (dict): Config passed to ``build_bbox_coder``. The
                built coder supplies ``num_sizes`` and ``num_dir_bins``.
            train_cfg (dict, optional): Stored as ``self.train_cfg``.
            test_cfg (dict, optional): Stored as ``self.test_cfg``.
            vote_moudule_cfg (dict): Keyword arguments for ``VoteModule``.
                Must contain ``'gt_per_seed'`` and ``'in_channels'``.
                NOTE: the misspelled name is part of the keyword
                interface; do not rename.
            vote_aggregation_cfg (dict): Config for ``build_sa_module``.
                Must contain ``'num_point'`` and ``'mlp_channels'``, and
                ``mlp_channels[0]`` must equal the vote module's
                ``in_channels``.
            feat_channels (tuple[int]): Widths of the ConvModules in
                ``self.conv_pred`` before the output conv.
            conv_cfg (dict): Conv config for those ConvModules.
            norm_cfg (dict): Norm config for those ConvModules.
            objectness_loss ... semantic_loss (dict): Each passed to
                ``build_loss``.
        """
        super(VoteHead, self).__init__()
        self.num_classes = num_classes
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = vote_moudule_cfg['gt_per_seed']
        self.num_proposal = vote_aggregation_cfg['num_point']

        self.objectness_loss = build_loss(objectness_loss)
        self.center_loss = build_loss(center_loss)
        self.dir_class_loss = build_loss(dir_class_loss)
        self.dir_res_loss = build_loss(dir_res_loss)
        self.size_class_loss = build_loss(size_class_loss)
        self.size_res_loss = build_loss(size_res_loss)
        self.semantic_loss = build_loss(semantic_loss)

        # The aggregation MLP must start at the vote module's input width.
        assert vote_aggregation_cfg['mlp_channels'][0] == vote_moudule_cfg[
            'in_channels']

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_sizes = self.bbox_coder.num_sizes
        self.num_dir_bins = self.bbox_coder.num_dir_bins

        self.vote_module = VoteModule(**vote_moudule_cfg)
        self.vote_aggregation = build_sa_module(vote_aggregation_cfg)

        # Prediction trunk: one ConvModule per feat_channels entry,
        # starting from the aggregation MLP's output width.
        prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
        conv_pred_list = list()
        for k in range(len(feat_channels)):
            conv_pred_list.append(
                ConvModule(
                    prev_channel,
                    feat_channels[k],
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    inplace=True))
            prev_channel = feat_channels[k]
        self.conv_pred = nn.Sequential(*conv_pred_list)

        # Objectness scores (2), center residual (3),
        # heading class+residual (num_dir_bins*2),
        # size class+residual(num_sizes*4), plus num_classes channels
        conv_out_channel = (2 + 3 + self.num_dir_bins * 2 +
                            self.num_sizes * 4 + num_classes)
        self.conv_pred.add_module('conv_out',
                                  nn.Conv1d(prev_channel, conv_out_channel, 1))