def forward(self, p):
    """Encode a point set into plane and/or grid feature maps.

    Args:
        p: point tensor. It must have exactly three dimensions, because its
            size is unpacked into three values (presumably
            (batch, num_points, dim) -- confirm against the caller).

    Returns:
        dict with one entry per enabled key of ``self.plane_type`` among
        'grid', 'xz', 'xy' and 'yz'. The 'grid' entry comes from
        ``generate_grid_features`` and the plane entries from
        ``generate_plane_features``.
    """
    # Kept purely as a rank check: the unpack raises unless p is 3-D.
    _batch, _num_points, _dim = p.size()

    # Per-point normalized coordinates and flattened cell indices, consumed
    # by the local pooling step below.
    coord = {}
    index = {}
    for plane in ('xz', 'xy', 'yz'):
        if plane in self.plane_type:
            coord[plane] = normalize_coordinate(
                p.clone(), plane=plane, padding=self.padding)
            index[plane] = coordinate2index(coord[plane], self.reso_plane)
    if 'grid' in self.plane_type:
        coord['grid'] = normalize_3d_coordinate(p.clone(), padding=self.padding)
        index['grid'] = coordinate2index(
            coord['grid'], self.reso_grid, coord_type='3d')

    # Optionally run the positional encoder before the first linear layer.
    encoder_input = self.pe(p) if self.pos_encoding else p
    net = self.fc_pos(encoder_input)

    net = self.blocks[0](net)
    for block in self.blocks[1:]:
        # Concatenate each point feature with its locally pooled feature
        # along dim 2 before feeding the next block.
        pooled = self.pool_local(coord, index, net)
        net = block(torch.cat([net, pooled], dim=2))

    c = self.fc_c(net)

    fea = {}
    if 'grid' in self.plane_type:
        fea['grid'] = self.generate_grid_features(p, c)
    for plane in ('xz', 'xy', 'yz'):
        if plane in self.plane_type:
            fea[plane] = self.generate_plane_features(p, c, plane=plane)
    return fea
def sample_grid_feature(self, p, c):
    """Sample the feature volume ``c`` at the query points ``p``.

    Args:
        p: query points, passed (as a clone) to ``normalize_3d_coordinate``.
        c: feature volume handed to ``F.grid_sample`` as its input
            (presumably 5-D (batch, channels, D, H, W) -- confirm).

    Returns:
        The sampled features with the two trailing singleton axes squeezed
        away.
    """
    # Normalize to the range (0, 1), as stated by the original comment.
    unit_coords = normalize_3d_coordinate(p.clone(), padding=self.padding)
    # Insert two singleton axes after the second axis and force float dtype.
    unit_coords = unit_coords[:, :, None, None].float()
    # Rescale from (0, 1) to the (-1, 1) range used for the sampling grid.
    vgrid = 2.0 * unit_coords - 1.0

    # Actually trilinear interpolation if mode = 'bilinear'.
    sampled = F.grid_sample(
        c,
        vgrid,
        padding_mode='border',
        align_corners=True,
        mode=self.sample_mode,
    )
    sampled = sampled.squeeze(-1)
    sampled = sampled.squeeze(-1)
    return sampled
def generate_grid_features(self, p, c):
    """Scatter per-point features into a dense cubic feature grid.

    Args:
        p: input points; only ``p.size(0)`` and the normalized coordinates
            derived from a clone of ``p`` are used.
        c: per-point features, permuted with ``(0, 2, 1)`` before scattering
            (presumably (batch, num_points, c_dim) -- confirm).

    Returns:
        Tensor reshaped to (p.size(0), c_dim, reso_grid, reso_grid,
        reso_grid), passed through ``self.unet3d`` when that attribute is
        not None.
    """
    batch = p.size(0)
    reso = self.reso_grid

    unit_coords = normalize_3d_coordinate(p.clone(), padding=self.padding)
    cell_index = coordinate2index(unit_coords, reso, coord_type='3d')

    # Scatter grid features from points: average the features that land in
    # the same flattened cell, writing into a zero-initialized buffer.
    flat_grid = c.new_zeros(batch, self.c_dim, reso**3)
    channels_first = c.permute(0, 2, 1)
    flat_grid = scatter_mean(channels_first, cell_index, out=flat_grid)

    # Unflatten the last axis into a cube.
    cube_shape = (reso, reso, reso)
    fea_grid = flat_grid.reshape(batch, self.c_dim, *cube_shape)

    if self.unet3d is None:
        return fea_grid
    return self.unet3d(fea_grid)
def forward(self, p):
    """Encode two interleaved scans separately and jointly.

    The batch is split by stride two along dim 0: even entries are the first
    scan and odd entries the second. A third point set joins both along
    dim 1. ``torch.cat`` needs matching sizes on the remaining dims, so an
    odd batch size is expected to fail there.

    Args:
        p: point tensor. It must have exactly three dimensions, because its
            size is unpacked into three values (presumably
            (batch, num_points, dim) -- confirm against the caller).

    Returns:
        dict that always holds ``self.unet3d`` under 'unet3d'. When 'grid'
        is in ``self.plane_type`` it also holds 'latent_1', 'latent_2' and
        'latent_12', each produced by ``generate_grid_features``.
    """
    # Kept purely as a rank check: the unpack raises unless p is 3-D.
    _batch, _num_points, _dim = p.size()

    first_scan = p[::2]
    second_scan = p[1::2]
    both_scans = torch.cat([first_scan, second_scan], dim=1)

    # Acquire the grid index for each point of each point set.
    coord = {}
    index = {}
    if 'grid' in self.plane_type:
        keyed_scans = (
            ('grid_1', first_scan),
            ('grid_2', second_scan),
            ('grid_12', both_scans),
        )
        for key, points in keyed_scans:
            coord[key] = normalize_3d_coordinate(
                points.clone(), padding=self.padding)
            index[key] = coordinate2index(
                coord[key], self.reso_grid, coord_type='3d')

    def encode(points, index_key):
        # Shared per-scan encoder. The index is looked up lazily inside the
        # loop, so a missing key only matters when there is more than one
        # block.
        net = self.fc_pos(points)
        net = self.blocks[0](net)
        for block in self.blocks[1:]:
            pooled = self.pool_local(index[index_key], net)
            net = torch.cat([net, pooled], dim=2)
            net = block(net)
        return self.fc_c(net)

    c1 = encode(first_scan, 'grid_1')
    c2 = encode(second_scan, 'grid_2')
    c12 = encode(both_scans, 'grid_12')

    fea = {'unet3d': self.unet3d}
    if 'grid' in self.plane_type:
        fea['latent_1'] = self.generate_grid_features(first_scan, c1)
        fea['latent_2'] = self.generate_grid_features(second_scan, c2)
        fea['latent_12'] = self.generate_grid_features(both_scans, c12)
    return fea