def __init__(self,
             in_channels,
             num_points=(2048, 1024, 512, 256),
             radius=(0.2, 0.4, 0.8, 1.2),
             num_samples=(64, 32, 16, 16),
             sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
                          (128, 128, 256)),
             fp_channels=((256, 256), (256, 256)),
             norm_cfg=dict(type='BN2d'),
             pool_mod='max',
             use_xyz=True,
             normalize_xyz=True):
    """Build the set-abstraction (SA) and feature-propagation (FP) stacks.

    Args:
        in_channels (int): Input channel count; 3 is subtracted from it
            to obtain the feature width fed to the first SA layer
            (i.e. the channels without xyz).
        num_points (tuple[int]): Per SA layer, forwarded as ``num_point``.
        radius (tuple[float]): Per SA layer, forwarded as ``radius``.
        num_samples (tuple[int]): Per SA layer, forwarded as
            ``num_sample``.
        sa_channels (tuple[tuple[int]]): Per SA layer, the MLP widths
            that follow the layer's input width.
        fp_channels (tuple[tuple[int]]): Per FP layer, the MLP widths
            that follow the concatenated source + skip width.
        norm_cfg (dict): Forwarded unchanged to every ``PointSAModule``.
        pool_mod (str): ``'max'`` or ``'avg'``; forwarded to every
            ``PointSAModule``.
        use_xyz (bool): Forwarded to every ``PointSAModule``.
        normalize_xyz (bool): Forwarded to every ``PointSAModule``.
    """
    super().__init__()
    self.num_sa = len(sa_channels)
    self.num_fp = len(fp_channels)

    assert (len(num_points) == len(radius) == len(num_samples) ==
            len(sa_channels))
    assert len(sa_channels) >= len(fp_channels)
    assert pool_mod in ['max', 'avg']

    # ---- set-abstraction stack ------------------------------------
    self.SA_modules = nn.ModuleList()
    width = in_channels - 3  # number of channels without xyz
    # Output width of the raw input and of every SA layer, in order;
    # consumed from the end when wiring the FP layers below.
    skip_widths = [width]
    for layer_idx, layer_spec in enumerate(sa_channels):
        mlp_spec = [width, *layer_spec]
        sa_layer = PointSAModule(
            num_point=num_points[layer_idx],
            radius=radius[layer_idx],
            num_sample=num_samples[layer_idx],
            mlp_channels=mlp_spec,
            norm_cfg=norm_cfg,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            normalize_xyz=normalize_xyz)
        self.SA_modules.append(sa_layer)
        width = mlp_spec[-1]
        skip_widths.append(width)

    # ---- feature-propagation stack --------------------------------
    self.FP_modules = nn.ModuleList()
    source_width = skip_widths.pop()
    target_width = skip_widths.pop()
    final_fp_idx = len(fp_channels) - 1
    for fp_idx, fp_spec in enumerate(fp_channels):
        fp_mlp = [source_width + target_width, *fp_spec]
        self.FP_modules.append(PointFPModule(mlp_channels=fp_mlp))
        if fp_idx == final_fp_idx:
            # No further FP layer follows, so no more skip widths are
            # consumed.
            break
        source_width = fp_mlp[-1]
        target_width = skip_widths.pop()
def __init__(self,
             fp_channels=((768, 256, 256), (384, 256, 256),
                          (320, 256, 128), (128, 128, 128, 128)),
             **kwargs):
    """Create the feature-propagation layers and the pre-seg conv.

    Args:
        fp_channels (tuple[tuple[int]]): One entry per FP layer; each
            entry is handed to ``PointFPModule`` as ``mlp_channels``.
            The last value of the last entry is the input width of
            ``pre_seg_conv``.
        **kwargs: Forwarded unchanged to the parent constructor.
    """
    super(PointNet2Head, self).__init__(**kwargs)

    self.num_fp = len(fp_channels)
    fp_layers = [
        PointFPModule(mlp_channels=layer_spec) for layer_spec in fp_channels
    ]
    self.FP_modules = nn.ModuleList(fp_layers)

    # https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L40
    # self.channels / conv_cfg / norm_cfg / act_cfg are read from the
    # instance here; presumably they are set by the parent constructor
    # above -- confirm against the base class.
    last_fp_width = fp_channels[-1][-1]
    self.pre_seg_conv = ConvModule(
        last_fp_width,
        self.channels,
        kernel_size=1,
        bias=True,
        conv_cfg=self.conv_cfg,
        norm_cfg=self.norm_cfg,
        act_cfg=self.act_cfg)
def create_pointnet2_fp_modules(fp_blocks,
                                in_channels,
                                sa_in_channels,
                                with_se=False,
                                normalize=True,
                                eps=0,
                                width_multiplier=1,
                                voxel_resolution_multiplier=1):
    """Build the feature-propagation stages described by ``fp_blocks``.

    Args:
        fp_blocks (iterable): One ``(fp_configs, conv_configs)`` pair per
            stage. ``fp_configs`` is an iterable of channel widths for the
            stage's ``PointFPModule``. ``conv_configs`` is either ``None``
            or a triple ``(out_channels, num_blocks, voxel_resolution)``
            describing the blocks appended after it.
        in_channels (int): Channel width entering the first stage.
        sa_in_channels (sequence[int]): Skip-connection widths; stage
            ``i`` reads ``sa_in_channels[-1 - i]``.
        with_se (bool): Forwarded to ``PVConv``.
        normalize (bool): Forwarded to ``PVConv``.
        eps (float): Forwarded to ``PVConv``.
        width_multiplier (float): Scale applied to every configured
            channel width (result truncated with ``int``).
        voxel_resolution_multiplier (float): Scale applied to every
            configured voxel resolution (result truncated with ``int``).

    Returns:
        tuple: ``(fp_layers, in_channels)`` where ``fp_layers`` is a plain
        list holding one ``nn.ModuleList`` per stage and ``in_channels``
        is the channel width leaving the last stage (the input value
        unchanged when ``fp_blocks`` is empty).
    """
    r, vr = width_multiplier, voxel_resolution_multiplier

    fp_layers = []
    for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
        # Modules of the current stage. Deliberately NOT named
        # ``fp_blocks``: rebinding the parameter while it is being
        # iterated only worked because ``enumerate`` already held the
        # iterator.
        stage_modules = []

        out_channels = [int(r * oc) for oc in fp_configs]
        # Skip widths are consumed from the end of ``sa_in_channels``.
        skip_channels = sa_in_channels[-1 - fp_idx]
        stage_modules.append(
            PointFPModule([in_channels + skip_channels] + out_channels))
        in_channels = out_channels[-1]

        if conv_configs is not None:
            out_channels, num_blocks, voxel_resolution = conv_configs
            out_channels = int(r * out_channels)
            if voxel_resolution is None:
                block = SharedMLP
            else:
                block = functools.partial(
                    PVConv,
                    kernel_size=3,
                    resolution=int(vr * voxel_resolution),
                    with_se=with_se,
                    normalize=normalize,
                    eps=eps)
            for _ in range(num_blocks):
                stage_modules.append(block(in_channels, out_channels))
                in_channels = out_channels

        fp_layers.append(nn.ModuleList(stage_modules))

    return fp_layers, in_channels
def test_pointnet_fp_module():
    """Check PointFPModule's first conv widths and its output shape on GPU.

    Skipped when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        pytest.skip()
    from mmdet3d.ops import PointFPModule

    fp_module = PointFPModule(mlp_channels=[24, 16]).cuda()
    first_conv = fp_module.mlps.layer0.conv
    assert first_conv.in_channels == 24
    assert first_conv.out_channels == 16

    # Each row of the file is read as 6 float32 values; only the first
    # three columns are used below.
    points = np.fromfile('tests/data/sunrgbd/points/000001.bin',
                         np.float32).reshape((-1, 6))

    # Every 2nd row, starting at row 0 -> (B, N, 3)
    xyz1 = torch.from_numpy(points[0::2, :3]).view(1, -1, 3).cuda()
    # Coordinates tiled 4x along the last axis, then channels-first
    # -> (B, C1, N) with C1 = 12
    tiled1 = xyz1.repeat([1, 1, 4])
    features1 = tiled1.transpose(1, 2).contiguous().cuda()

    # Every 3rd row, starting at row 1 -> (B, M, 3)
    xyz2 = torch.from_numpy(points[1::3, :3]).view(1, -1, 3).cuda()
    # -> (B, C2, M) with C2 = 12; C1 + C2 matches mlp_channels[0] == 24
    tiled2 = xyz2.repeat([1, 1, 4])
    features2 = tiled2.transpose(1, 2).contiguous().cuda()

    fp_features = fp_module(xyz1, xyz2, features1, features2)
    expected_shape = torch.Size([1, 16, 50])
    assert fp_features.shape == expected_shape
def __init__(self,
             in_channels,
             num_points=(2048, 1024, 512, 256),
             radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
             num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
             sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
                          ((64, 64, 128), (64, 64, 128), (64, 96, 128)),
                          ((128, 128, 256), (128, 192, 256), (128, 256,
                                                              256))),
             aggregation_channels=(64, 128, 256),
             fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
             fps_sample_range_lists=((-1), (-1), (512, -1)),
             dilated_group=(True, True, True),
             fp_channels=((256, 256), (256, 256)),
             norm_cfg=dict(type='BN2d'),
             sa_cfg=dict(
                 type='PointSAModuleMSG',
                 pool_mod='max',
                 use_xyz=True,
                 normalize_xyz=False)):
    """Build the multi-scale SA layers, aggregation MLPs and FP layers.

    Args:
        in_channels (int): Input channel count; 3 is subtracted from it
            to obtain the feature width fed to the first SA layer
            (i.e. the channels without xyz).
        num_points (tuple[int | None]): Per SA layer, forwarded as
            ``num_point``. A ``None`` entry selects the alternative
            ``PointSAModule`` branch instead of ``sa_cfg``.
        radii (tuple[tuple[float]]): Per SA layer, one radius per scale.
        num_samples (tuple[tuple[int]]): Per SA layer, forwarded as
            ``sample_nums``.
        sa_channels (tuple[tuple[tuple[int]]]): Per SA layer and per
            scale, the MLP widths that follow the layer's input width.
        aggregation_channels (tuple[int | None]): Per SA layer, output
            width of a 1x1 ``Conv1d`` applied to the summed per-scale
            output widths; ``None`` disables it for that layer.
        fps_mods (tuple[str | tuple[str]]): Per SA layer; a bare value
            is wrapped into a one-element list before being forwarded.
        fps_sample_range_lists (tuple[int | tuple[int]]): Per SA layer;
            a bare value is wrapped into a one-element list.
        dilated_group (tuple[bool]): Per SA layer, forwarded unchanged.
        fp_channels (tuple[tuple[int]]): Per FP layer, the MLP widths
            that follow the concatenated source + skip width.
        norm_cfg (dict): Forwarded to every SA module.
        sa_cfg (dict): Forwarded as ``cfg`` to ``build_sa_module``.

    Note:
        The length assertion below requires ``num_points``, ``radii``,
        ``num_samples``, ``sa_channels`` and ``aggregation_channels`` to
        have the same length, which the default ``num_points`` (4
        entries vs. 3) does not satisfy -- callers must override it.
    """
    super().__init__()
    self.num_sa = len(sa_channels)
    self.num_fp = len(fp_channels)

    assert len(num_points) == len(radii) == len(num_samples) == len(
        sa_channels) == len(aggregation_channels)

    self.SA_modules = nn.ModuleList()
    self.aggregation_mlps = nn.ModuleList()
    sa_in_channel = in_channels - 3  # number of channels without xyz
    # Output width of the raw input and of every SA layer, in order;
    # consumed from the end when wiring the FP layers below.
    skip_channel_list = [sa_in_channel]

    for sa_index in range(self.num_sa):
        cur_sa_mlps = list(sa_channels[sa_index])
        sa_out_channel = 0
        # Prepend the input width to each scale's MLP spec and sum the
        # per-scale output widths.
        for radius_index in range(len(radii[sa_index])):
            cur_sa_mlps[radius_index] = [sa_in_channel] + list(
                cur_sa_mlps[radius_index])
            sa_out_channel += cur_sa_mlps[radius_index][-1]

        # A bare (non-tuple) entry such as ``('D-FPS')`` or ``(-1)`` is a
        # scalar, so wrap it into a one-element list.
        fps_mod_entry = fps_mods[sa_index]
        cur_fps_mod = (
            list(fps_mod_entry)
            if isinstance(fps_mod_entry, tuple) else [fps_mod_entry])

        range_entry = fps_sample_range_lists[sa_index]
        cur_fps_sample_range_list = (
            list(range_entry)
            if isinstance(range_entry, tuple) else [range_entry])

        # Identity comparison is the correct test for the None singleton.
        if num_points[sa_index] is not None:
            self.SA_modules.append(
                build_sa_module(
                    num_point=num_points[sa_index],
                    radii=radii[sa_index],
                    sample_nums=num_samples[sa_index],
                    mlp_channels=cur_sa_mlps,
                    fps_mod=cur_fps_mod,
                    fps_sample_range_list=cur_fps_sample_range_list,
                    dilated_group=dilated_group[sa_index],
                    norm_cfg=norm_cfg,
                    cfg=sa_cfg,
                    bias=True))
        else:
            # NOTE(review): ``cur_sa_mlps`` is a flat channel list here
            # only when ``radii[sa_index]`` is empty (the per-scale loop
            # above did not run); in that case ``sa_out_channel`` is
            # still 0. Confirm the intended configuration for this
            # branch.
            self.SA_modules.append(
                build_sa_module(
                    num_point=None,
                    radius=None,
                    num_sample=None,
                    mlp_channels=[sa_in_channel] + cur_sa_mlps,
                    norm_cfg=norm_cfg,
                    cfg=dict(
                        type='PointSAModule',
                        pool_mod=sa_cfg.get('pool_mod'),
                        use_xyz=sa_cfg.get('use_xyz'),
                        normalize_xyz=sa_cfg.get('normalize_xyz'))))

        if aggregation_channels[sa_index] is not None:
            self.aggregation_mlps.append(
                ConvModule(
                    sa_out_channel,
                    aggregation_channels[sa_index],
                    conv_cfg=dict(type='Conv1d'),
                    norm_cfg=dict(type='BN1d'),
                    kernel_size=1,
                    bias=True))
            sa_in_channel = aggregation_channels[sa_index]
        else:
            # Keep the module list index-aligned with SA_modules.
            self.aggregation_mlps.append(None)
            sa_in_channel = sa_out_channel
        skip_channel_list.append(sa_in_channel)

    self.FP_modules = nn.ModuleList()
    fp_source_channel = skip_channel_list.pop()
    fp_target_channel = skip_channel_list.pop()
    last_fp_index = len(fp_channels) - 1
    for fp_index, fp_spec in enumerate(fp_channels):
        cur_fp_mlps = [fp_source_channel + fp_target_channel
                       ] + list(fp_spec)
        self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps))
        if fp_index != last_fp_index:
            # Feed this layer's output and the next-shallower skip width
            # into the following FP layer.
            fp_source_channel = cur_fp_mlps[-1]
            fp_target_channel = skip_channel_list.pop()