def forward(self, p):
    """Encode a point cloud into plane and/or grid feature maps.

    For every entry of ``self.plane_type`` among ``'xz'``, ``'xy'``,
    ``'yz'`` and ``'grid'`` the points are normalised and converted to
    cell indices. The per-point network is then run with local pooling
    between blocks, and the resulting per-point codes are scattered into
    one feature map per enabled entry.

    Args:
        p: point tensor; must be 3-dimensional because its size is
            unpacked into three values. Presumably (batch, points, 3)
            -- TODO confirm against the caller.

    Returns:
        dict keyed by the enabled names (``'grid'``, ``'xz'``, ``'xy'``,
        ``'yz'``) holding the output of ``generate_grid_features`` /
        ``generate_plane_features`` for that name.
    """
    # Unpacking doubles as a rank check: anything but a 3-D tensor fails here.
    _n_batch, _n_points, _n_dims = p.size()

    axis_planes = ('xz', 'xy', 'yz')

    # Per-point normalised coordinates and cell indices, used by pool_local.
    coord, index = {}, {}
    for plane in axis_planes:
        if plane not in self.plane_type:
            continue
        coord[plane] = normalize_coordinate(
            p.clone(), plane=plane, padding=self.padding)
        index[plane] = coordinate2index(coord[plane], self.reso_plane)
    if 'grid' in self.plane_type:
        coord['grid'] = normalize_3d_coordinate(
            p.clone(), padding=self.padding)
        index['grid'] = coordinate2index(
            coord['grid'], self.reso_grid, coord_type='3d')

    # Optional positional encoding in front of the first linear layer.
    embedded = self.pe(p) if self.pos_encoding else p
    hidden = self.fc_pos(embedded)

    # The first block sees the raw embedding; every later block sees the
    # embedding concatenated with its locally pooled counterpart.
    hidden = self.blocks[0](hidden)
    for block in self.blocks[1:]:
        pooled = self.pool_local(coord, index, hidden)
        hidden = block(torch.cat([hidden, pooled], dim=2))

    c = self.fc_c(hidden)

    # Insertion order (grid first, then planes) matches the original code.
    fea = {}
    if 'grid' in self.plane_type:
        fea['grid'] = self.generate_grid_features(p, c)
    for plane in axis_planes:
        if plane in self.plane_type:
            fea[plane] = self.generate_plane_features(p, c, plane=plane)
    return fea
def generate_grid_features(self, p, c):
    """Scatter per-point features into a dense 3-D feature volume.

    Args:
        p: point tensor; only ``p.size(0)`` and a clone passed to
            ``normalize_3d_coordinate`` are used here.
        c: per-point feature tensor; it is permuted with ``(0, 2, 1)``,
            so it must be 3-dimensional. Presumably
            (batch, points, c_dim) -- TODO confirm.

    Returns:
        Tensor reshaped to
        ``(p.size(0), self.c_dim, reso_grid, reso_grid, reso_grid)``,
        passed through ``self.unet3d`` when that attribute is not None.
    """
    n_batch = p.size(0)
    side = self.reso_grid

    # Cell index of every point inside the flattened side**3 volume.
    voxel_index = coordinate2index(
        normalize_3d_coordinate(p.clone(), padding=self.padding),
        side,
        coord_type='3d',
    )

    # Zero-initialised target; its dtype/device follow ``c`` (new_zeros).
    flat_volume = c.new_zeros(n_batch, self.c_dim, side ** 3)
    flat_volume = scatter_mean(
        c.permute(0, 2, 1), voxel_index, out=flat_volume)
    volume = flat_volume.reshape(n_batch, self.c_dim, side, side, side)

    if self.unet3d is None:
        return volume
    return self.unet3d(volume)
def generate_plane_features(self, p, c, plane='xz'):
    """Scatter per-point features onto one 2-D feature plane.

    Args:
        p: point tensor; only ``p.size(0)`` and a clone passed to
            ``normalize_coordinate`` are used here.
        c: per-point feature tensor; it is permuted with ``(0, 2, 1)``,
            so it must be 3-dimensional. Presumably
            (batch, points, c_dim) -- TODO confirm.
        plane: plane identifier forwarded to ``normalize_coordinate``
            (the callers in this file pass ``'xz'``, ``'xy'`` or ``'yz'``).

    Returns:
        Tensor reshaped to
        ``(p.size(0), self.c_dim, reso_plane, reso_plane)``, passed
        through ``self.unet`` when that attribute is not None.
    """
    n_batch = p.size(0)
    side = self.reso_plane

    # Cell index of every point inside the flattened side**2 plane.
    cell_index = coordinate2index(
        normalize_coordinate(p.clone(), plane=plane, padding=self.padding),
        side,
    )

    # Zero-initialised target; its dtype/device follow ``c`` (new_zeros).
    flat_plane = c.new_zeros(n_batch, self.c_dim, side ** 2)
    flat_plane = scatter_mean(c.permute(0, 2, 1), cell_index, out=flat_plane)
    plane_map = flat_plane.reshape(n_batch, self.c_dim, side, side)

    # Optional U-Net post-processing of the plane features.
    if self.unet is None:
        return plane_map
    return self.unet(plane_map)
def generate_dynamic_plane_features(self, p, c, normal_feature,
                                    basis_normalizer_matrix):
    """Scatter per-point features onto a plane given by a basis matrix.

    Before scattering, ``normal_feature`` is added to every point's
    feature vector (it is unsqueezed at dim 2 and broadcast against the
    permuted ``c``).

    Args:
        p: point tensor; only ``p.size(0)`` and a clone passed to
            ``normalize_dynamic_plane_coordinate`` are used here.
        c: per-point feature tensor; it is permuted with ``(0, 2, 1)``,
            so it must be 3-dimensional. Presumably
            (batch, points, c_dim) -- TODO confirm.
        normal_feature: per-plane feature added to ``c``; presumably
            (batch, c_dim) so that the broadcast works -- TODO confirm.
        basis_normalizer_matrix: forwarded unchanged to
            ``normalize_dynamic_plane_coordinate``.

    Returns:
        Tensor reshaped to
        ``(p.size(0), self.c_dim, reso_plane, reso_plane)``, passed
        through ``self.unet`` when that attribute is not None.
    """
    n_batch = p.size(0)
    side = self.reso_plane

    # Project the points onto the plane; the original comment states the
    # result is normalised to the range (0, 1) -- not verifiable from here.
    projected = normalize_dynamic_plane_coordinate(
        p.clone(), basis_normalizer_matrix, padding=self.padding)
    cell_index = coordinate2index(projected, side)

    # Zero-initialised target; its dtype/device follow ``c`` (new_zeros).
    flat_plane = c.new_zeros(n_batch, self.c_dim, side ** 2)

    # Channels-first point features, conditioned on the plane's feature.
    conditioned = c.permute(0, 2, 1) + normal_feature.unsqueeze(2)

    flat_plane = scatter_mean(conditioned, cell_index, out=flat_plane)
    plane_map = flat_plane.reshape(n_batch, self.c_dim, side, side)

    # Optional U-Net post-processing of the plane features.
    if self.unet is None:
        return plane_map
    return self.unet(plane_map)
def forward(self, p):
    """Encode points onto three canonical planes plus learned dynamic planes.

    Planes 0, 1 and 2 are the axis-aligned ``'xz'``, ``'xy'`` and ``'yz'``
    planes. Every further plane (``self.num_channels - 3`` of them) is
    predicted from the input by ``self.plane_params``.

    Side effects:
        * ``self.device`` is set to the device of ``p``. It used to be
          hard-coded to ``'cuda'``, which fails on CPU-only hosts and
          when the input is on a different GPU.
        * ``self.plane_parameters`` is set to a
          ``(batch_size, 3 + num_dynamic_planes, 3)`` tensor: the 3x3
          identity rows followed by the L2-normalised predicted plane
          parameters (read elsewhere for the similarity loss, per the
          original comment).

    Args:
        p: point tensor; must be 3-dimensional because its size is
            unpacked into three values. Presumably (batch, points, 3)
            -- TODO confirm against the caller.

    Returns:
        dict with one ``'plane{l}'`` feature map for every
        ``l in range(self.num_channels)`` plus ``'c_mat'``, the output of
        ``ChangeBasis``.

    Raises:
        ValueError: if ``self.num_channels`` leaves no dynamic plane
            (``num_channels <= 3``).
    """
    batch_size, _n_points, _n_dims = p.size()

    canonical_planes = ('xz', 'xy', 'yz')
    num_canonical = len(canonical_planes)
    num_planes = self.num_channels
    num_dynamic_planes = num_planes - num_canonical

    # Fail before running any sub-network. Without a dynamic plane,
    # torch.stack below would receive an empty list.
    if num_dynamic_planes <= 0:
        raise ValueError(
            f"Number of channels is {self.num_channels}, no point in using Hybrid approach"
        )

    # Follow the input's device instead of assuming a GPU is present.
    self.device = p.device

    # Optional positional encoding shared by both input heads.
    point_input = self.pe(p) if self.pos_encoding else p
    net = self.fc_pos(point_input)
    net_pl = self.fc_plane_net(point_input)

    # Predict each dynamic plane and lift it to the hidden dimension.
    normal_fea = []
    normal_fea_hdim = {}
    for l in range(num_dynamic_planes):
        plane_param = self.plane_params[l](self.actvn(net_pl))
        normal_fea.append(plane_param)
        normal_fea_hdim['dynamic_plane{}'.format(l)] = \
            self.plane_params_hdim[l](plane_param)

    # plane parameter (batch_size x num_dynamic_planes x 3)
    raw_plane_parameters = torch.stack(normal_fea, dim=1)
    # change of basis and normalizer matrix (concatenated)
    C_mat = ChangeBasis(raw_plane_parameters, device=self.device)

    # acquire the index for each point
    coord = {}
    index = {}
    for l in range(num_planes):
        key = 'plane{}'.format(l)
        if l < num_canonical:
            coord[key] = normalize_coordinate(
                p.clone(), plane=canonical_planes[l], padding=self.padding)
        else:
            coord[key] = normalize_dynamic_plane_coordinate(
                p.clone(), C_mat[:, l - num_canonical],
                padding=self.padding)
        index[key] = coordinate2index(coord[key], self.reso_plane)

    # The first block sees the raw embedding; every later block sees the
    # embedding concatenated with its locally pooled counterpart.
    net = self.blocks[0](net)
    for block in self.blocks[1:]:
        pooled = self.pool_local(coord, index, net)
        net = block(torch.cat([net, pooled], dim=2))

    c = self.fc_c(net)

    fea = {}
    for l in range(num_planes):
        key = 'plane{}'.format(l)
        if l < num_canonical:
            fea[key] = self.generate_plane_features(
                p, c, plane=canonical_planes[l])
        else:
            dynamic_plane_id = l - num_canonical
            fea[key] = self.generate_dynamic_plane_features(
                p, c,
                normal_fea_hdim['dynamic_plane{}'.format(dynamic_plane_id)],
                C_mat[:, dynamic_plane_id])
    fea['c_mat'] = C_mat

    # Normalize plane params for similarity loss calculation
    unit_normals = raw_plane_parameters.reshape(
        batch_size * num_dynamic_planes, 3)
    unit_normals = unit_normals / torch.norm(
        unit_normals, p=2, dim=1, keepdim=True)
    unit_normals = unit_normals.view(batch_size, num_dynamic_planes, 3)

    # Canonical plane normals: one 3x3 identity per batch element.
    canonical_plane_parameters = torch.eye(
        3, device=self.device).unsqueeze(0).repeat(batch_size, 1, 1)

    # Concatenate canonical plane normals and dynamic plane parameters
    self.plane_parameters = torch.cat(
        [canonical_plane_parameters, unit_normals], dim=1)

    return fea
def forward(self, p):
    """Encode points onto ``self.num_channels`` learned dynamic planes.

    Every plane is predicted from the input by ``self.plane_params``;
    ``ChangeBasis`` turns the stacked predictions into the matrices used
    to project the points onto each plane.

    Side effects:
        * ``self.device`` is set to the device of ``p``. It used to be
          hard-coded to ``'cpu'``, so ``ChangeBasis`` was told to use the
          CPU even when the input lived on a GPU. For CPU inputs the
          behaviour is unchanged.
        * ``self.plane_parameters`` is set to the L2-normalised plane
          parameters with shape ``(batch_size, num_planes, 3)`` (read
          elsewhere for the similarity loss, per the original comment).

    Args:
        p: point tensor; must be 3-dimensional because its size is
            unpacked into three values. Presumably (batch, points, 3)
            -- TODO confirm against the caller.

    Returns:
        dict with one ``'plane{l}'`` feature map per plane plus
        ``'c_mat'``, the output of ``ChangeBasis``.
    """
    batch_size, _n_points, _n_dims = p.size()

    # Follow the input's device instead of pinning everything to the CPU.
    self.device = p.device

    # Optional positional encoding shared by both input heads.
    point_input = self.pe(p) if self.pos_encoding else p
    net = self.fc_pos(point_input)
    net_pl = self.fc_plane_net(point_input)

    # Predict each plane and lift it to the hidden dimension.
    normal_fea = []
    normal_fea_hdim = {}
    for l in range(self.num_channels):
        plane_param = self.plane_params[l](self.actvn(net_pl))
        normal_fea.append(plane_param)
        normal_fea_hdim['plane{}'.format(l)] = \
            self.plane_params_hdim[l](plane_param)

    # plane parameter (batch_size x L x 3)
    raw_plane_parameters = torch.stack(normal_fea, dim=1)
    # change of basis and normalizer matrix (concatenated)
    C_mat = ChangeBasis(raw_plane_parameters, device=self.device)
    num_planes = C_mat.size()[1]

    # acquire the index for each point
    coord = {}
    index = {}
    for l in range(num_planes):
        key = 'plane{}'.format(l)
        coord[key] = normalize_dynamic_plane_coordinate(
            p.clone(), C_mat[:, l], padding=self.padding)
        index[key] = coordinate2index(coord[key], self.reso_plane)

    # The first block sees the raw embedding; every later block sees the
    # embedding concatenated with its locally pooled counterpart.
    net = self.blocks[0](net)
    for block in self.blocks[1:]:
        pooled = self.pool_local(coord, index, net)
        net = block(torch.cat([net, pooled], dim=2))

    c = self.fc_c(net)

    fea = {}
    for l in range(num_planes):
        key = 'plane{}'.format(l)
        fea[key] = self.generate_dynamic_plane_features(
            p, c, normal_fea_hdim[key], C_mat[:, l])
    fea['c_mat'] = C_mat

    # Normalize plane params for similarity loss calculation
    unit_normals = raw_plane_parameters.reshape(batch_size * num_planes, 3)
    unit_normals = unit_normals / torch.norm(
        unit_normals, p=2, dim=1, keepdim=True)
    self.plane_parameters = unit_normals.view(batch_size, num_planes, 3)

    return fea
def forward(self, p):
    """Encode two interleaved scans separately and jointly on a 3-D grid.

    The batch is split along dim 0 into even-indexed entries (first scan)
    and odd-indexed entries (second scan). A third input concatenates
    both along dim 1. Each of the three inputs goes through the same
    point network and is scattered into a grid feature volume.

    NOTE(review): the concatenation along dim 1 needs both halves to have
    the same size along dim 0, so an odd batch size looks unsupported --
    confirm with the data loader.
    NOTE(review): grid indices are only created when ``'grid'`` is in
    ``self.plane_type``, but the pooling loop looks them up
    unconditionally.

    Args:
        p: point tensor; must be 3-dimensional because its size is
            unpacked into three values.

    Returns:
        dict holding ``'unet3d'`` (the ``self.unet3d`` attribute itself)
        and, when ``'grid'`` is enabled, ``'latent_1'``, ``'latent_2'``
        and ``'latent_12'`` from ``generate_grid_features``.
    """
    # Unpacking doubles as a rank check: anything but a 3-D tensor fails here.
    _n_batch, _n_points, _n_dims = p.size()

    p1, p2 = p[::2], p[1::2]
    p12 = torch.cat([p1, p2], dim=1)

    # acquire the index for each point (first, second, first + second)
    coord, index = {}, {}
    if 'grid' in self.plane_type:
        for tag, scan in (('grid_1', p1), ('grid_2', p2), ('grid_12', p12)):
            coord[tag] = normalize_3d_coordinate(
                scan.clone(), padding=self.padding)
            index[tag] = coordinate2index(
                coord[tag], self.reso_grid, coord_type='3d')

    def encode(scan, tag):
        # Shared point network. The index is looked up lazily inside the
        # loop, exactly where the original code read it.
        hidden = self.fc_pos(scan)
        hidden = self.blocks[0](hidden)
        for block in self.blocks[1:]:
            pooled = self.pool_local(index[tag], hidden)
            hidden = block(torch.cat([hidden, pooled], dim=2))
        return self.fc_c(hidden)

    c1 = encode(p1, 'grid_1')      # first scan
    c2 = encode(p2, 'grid_2')      # second scan
    c12 = encode(p12, 'grid_12')   # first + second scan

    fea = {'unet3d': self.unet3d}
    if 'grid' in self.plane_type:
        fea['latent_1'] = self.generate_grid_features(p1, c1)
        fea['latent_2'] = self.generate_grid_features(p2, c2)
        fea['latent_12'] = self.generate_grid_features(p12, c12)
    return fea