def forward(self, data):
    """Classify a batch of point clouds with a two-level PointNet++-style encoder.

    Each level samples centroids with ``fps``, groups neighbours with
    ``radius`` and runs a bipartite point convolution. A global set
    abstraction and a per-cloud max follow, then three linear layers.

    Returns:
        ``F.log_softmax`` of the final logits over the last dimension.
    """
    pos, batch = data.pos, data.batch

    # Level 1: keep half of the points as centroids (the original note says
    # 512, i.e. 1024-point inputs are assumed) and group within r = 0.2.
    centers = fps(pos, batch, ratio=0.5)
    query_idx, point_idx = radius(pos, pos[centers], 0.2, batch, batch[centers],
                                  max_num_neighbors=64)
    # Transpose into (source = point in `pos`, target = centroid) edges.
    x = F.relu(self.local_sa1(None, (pos, pos[centers]),
                              torch.stack([point_idx, query_idx], dim=0)))
    pos, batch = pos[centers], batch[centers]

    # Level 2: a quarter of the remaining points (original note: 128), r = 0.4.
    centers = fps(pos, batch, ratio=0.25)
    query_idx, point_idx = radius(pos, pos[centers], 0.4, batch, batch[centers],
                                  max_num_neighbors=64)
    x = F.relu(self.local_sa2(x, (pos, pos[centers]),
                              torch.stack([point_idx, query_idx], dim=0)))
    pos, batch = pos[centers], batch[centers]

    # Global features, then max over a fixed 128 points per cloud (the view
    # assumes every cloud contributes exactly 128 centroids).
    x = self.global_sa(torch.cat([x, pos], dim=1))
    x = x.view(-1, 128, self.lin1.in_features).max(dim=1)[0]

    for hidden in (self.lin1, self.lin2):
        x = F.relu(hidden(x))
        x = F.dropout(x, p=0.5, training=self.training)
    return F.log_softmax(self.lin3(x), dim=-1)
def forward(self, data):
    """Classify point clouds with two set-abstraction levels run on the full point set.

    Each level calls ``radius(pos[idx], pos, ...)`` and maps the second index
    through ``idx`` to global point ids. The convolution therefore runs over
    every point, and its output is then restricted to the sampled rows.

    Returns:
        ``F.log_softmax`` of the final logits over the last dimension.
    """
    pos, batch = data.pos, data.batch
    x = None

    # (fps ratio, radius, module); the comments in the original note 512 then 128 points.
    for ratio, r, sa_module in ((0.5, 0.1, self.local_sa1), (0.25, 0.2, self.local_sa2)):
        sampled = fps(pos, batch, ratio=ratio)
        first, second = radius(pos[sampled], pos, r, batch[sampled], batch,
                               max_num_neighbors=64)
        # Second index refers to positions inside `sampled`; lift it to global ids.
        graph = torch.stack([first, sampled[second]], dim=0)
        x = F.relu(sa_module(x, pos, graph))
        x, pos, batch = x[sampled], pos[sampled], batch[sampled]

    x = self.global_sa(torch.cat([x, pos], dim=1))
    # Assumes 128 points per cloud and 1024 global features.
    x = x.view(-1, 128, 1024).max(dim=1)[0]

    x = F.dropout(F.relu(self.lin1(x)), p=0.5, training=self.training)
    x = F.dropout(F.relu(self.lin2(x)), p=0.5, training=self.training)
    return F.log_softmax(self.lin3(x), dim=-1)
def forward(self, data):
    """Classify a point cloud with graph convolutions plus one FPS/radius downsampling step.

    Pipeline: two graph convs on ``data.edge_index`` (with skip-concatenated
    input features), an FPS + radius bipartite conv (``self.con_int``), graph
    filtering via ``self.filter_adj``, a third graph conv, max pooling and a
    linear classifier.

    Returns:
        ``(log_probs, critical_points)``. NOTE(review): this unpacks two values
        from ``global_max_pool``, so that is presumably a project-specific
        variant that also returns argmax indices (stock PyG returns a single
        tensor). TODO confirm.
    """
    # Per-point input features: normals concatenated with coordinates.
    point_feats = torch.cat([data.norm, data.pos], dim=1)
    x, batch = point_feats, data.batch
    edge_index = data.edge_index
    # Uniform edge weights. The previous code read data.edge_attr here and then
    # immediately overwrote it, so the attribute never had any effect.
    edge_weight = torch.ones((edge_index.size(1),), dtype=x.dtype,
                             device=edge_index.device)

    # First conv on the full point set.
    x = F.dropout(x, training=self.training, p=0.2)
    x = F.relu(self.conv1(x, edge_index, edge_weight))

    # Second conv, with the raw input features concatenated back in.
    x = torch.cat([x, point_feats], dim=1)
    x = F.dropout(x, training=self.training, p=0.2)
    x = F.relu(self.conv2(x, edge_index, edge_weight))

    # Downsample: FPS centroids, radius grouping, bipartite (point -> centroid) conv.
    idx = fps(data.pos, batch, ratio=0.5)
    row, col = radius(data.pos, data.pos[idx], 0.4, batch, batch[idx],
                      max_num_neighbors=64)
    edge_index_int = torch.stack([col, row], dim=0)
    x = self.con_int(x, (data.pos, data.pos[idx]), edge_index_int)
    batch = batch[idx]
    # Presumably restricts the input graph to the sampled nodes (see filter_adj).
    edge_index, edge_weight = self.filter_adj(edge_index, edge_weight, idx,
                                              data.pos.size(0))

    x = torch.cat([x, point_feats[idx]], dim=1)
    x = F.relu(self.conv3(x, edge_index, edge_weight))

    out, critical_points = global_max_pool(x, batch)
    out = F.log_softmax(self.lin1(out), dim=1)
    return out, critical_points
def forward(self, x, pos, batch, norm=None):
    """Downsample with FPS, group by radius, and convolve with the configured operator.

    Args:
        x: point features (may be None if the conv accepts that).
        pos: point coordinates.
        batch: cloud index of each point.
        norm: point normals, used only by ``'PPFConv'``.

    Returns:
        ``((x, pos, batch), idx)``, where ``x``, ``pos`` and ``batch`` are
        restricted to the FPS centroids and ``idx`` holds their indices into
        the input.

    Raises:
        ValueError: if ``self.conv_name`` is not a supported operator.
    """
    # Pool points with FPS, keeping Npt * ratio centroids.
    idx = fps(pos, batch, ratio=self.ratio)
    # Up to self.K points within self.r of each centroid. `row` indexes the
    # centroids (positions within idx); `col` indexes `pos`.
    row, col = radius(pos, pos[idx], self.r, batch, batch[idx],
                      max_num_neighbors=self.K)

    if self.conv_name == 'PointConv':
        # Bipartite form: targets are positions within pos[idx].
        edge_index = torch.stack([col, row], dim=0)
        x = self.conv(x, (pos, pos[idx]), edge_index)
    elif self.conv_name in ('GraphConv', 'PPFConv'):
        # These convs run on the full point set, so edge targets must be global
        # point ids. Using `row` directly would send each centroid's messages
        # to point number `row` instead of to the centroid itself.
        edge_index = torch.stack([col, idx[row]], dim=0)
        if self.conv_name == 'GraphConv':
            x = self.conv(x, edge_index)[idx]
        else:
            x = self.conv(x, pos, norm, edge_index)[idx]
    else:
        # Previously x was returned un-downsampled while pos/batch were downsampled.
        raise ValueError("Unsupported conv_name %r" % (self.conv_name,))

    pos, batch = pos[idx], batch[idx]
    return (x, pos, batch), idx
def forward(self, x, pos, batch):
    """Set abstraction: FPS centroids, radius grouping (<= 64 neighbours), point conv.

    Returns:
        ``(x, pos, batch)`` for the sampled centroids.
    """
    keep = fps(pos, batch, ratio=self.ratio)
    centre_pos, centre_batch = pos[keep], batch[keep]
    centre_i, point_j = radius(pos, centre_pos, self.r, batch, centre_batch,
                               max_num_neighbors=64)
    features = self.conv(x, (pos, centre_pos), torch.stack([point_j, centre_i], dim=0))
    return features, centre_pos, centre_batch
def forward(self, points, features, batch):
    """Group points around FPS anchors and build per-anchor features.

    Each neighbourhood is expressed as anchor-relative offsets scaled by
    ``1 / self.radius``. An encoding per anchor (``self.neighborhood_enc``) is
    broadcast back to its members and concatenated with the offsets and the
    member features, passed through ``fc1``, max-pooled per anchor, and
    projected with ``fc1_global``.

    Returns:
        ``(anchor_points, anchor_features, anchor_batch)``.
    """
    anchor_idx = gnn.fps(x=points, batch=batch, ratio=1 / self.nb_neighbors)
    anchors = points[anchor_idx]
    anchor_batch = batch[anchor_idx]

    # cluster[e]: which anchor edge e belongs to; member[e]: index into `points`.
    cluster, member = gnn.radius(x=points, y=anchors, batch_x=batch,
                                 batch_y=anchor_batch, r=self.radius)
    offsets = (points[member] - anchors[cluster]) / self.radius

    encoding = self.neighborhood_enc(offsets, cluster)
    per_edge = torch.cat([offsets, encoding[cluster], features[member]], dim=1)
    pooled = gnn.global_max_pool(x=F.relu(self.fc1(per_edge)), batch=cluster)
    anchor_feats = torch.cat([encoding, F.relu(self.fc1_global(pooled))], dim=1)
    return anchors, anchor_feats, anchor_batch
def forward(self, points, batch):
    """Encode radius neighbourhoods around FPS anchors.

    Returns:
        ``(anchor_points, features, anchor_batch)``, where ``features`` is the
        output of ``self.neighborhood_encoder`` on the anchor-relative offsets
        (scaled by ``1 / self.radius``) grouped by anchor.
    """
    anchor_idx = gnn.fps(x=points, batch=batch, ratio=1 / self.nb_neighbors)
    anchors, anchor_batch = points[anchor_idx], batch[anchor_idx]

    cluster, member = gnn.radius(x=points, y=anchors, batch_x=batch,
                                 batch_y=anchor_batch, r=self.radius)
    offsets = (points[member] - anchors[cluster]) / self.radius
    return anchors, self.neighborhood_encoder(offsets, cluster), anchor_batch
def forward(self, points, batch):
    """PointNet-style encoder over radius neighbourhoods of FPS anchors.

    Anchor-relative offsets (scaled by ``1 / self.radius``) go through
    ``fc1..fc3`` with ReLU, are max-pooled per anchor, then go through
    ``fc1_global..fc3_global`` with ReLU.

    Returns:
        ``(anchor_points, anchor_features, anchor_batch)``.
    """
    anchor_idx = gnn.fps(x=points, batch=batch, ratio=1 / self.nb_neighbors)
    anchors, anchor_batch = points[anchor_idx], batch[anchor_idx]

    cluster, member = gnn.radius(x=points, y=anchors, batch_x=batch,
                                 batch_y=anchor_batch, r=self.radius)
    hidden = (points[member] - anchors[cluster]) / self.radius

    for local_fc in (self.fc1, self.fc2, self.fc3):
        hidden = F.relu(local_fc(hidden))
    pooled = gnn.global_max_pool(x=hidden, batch=cluster)
    for global_fc in (self.fc1_global, self.fc2_global, self.fc3_global):
        pooled = F.relu(global_fc(pooled))
    return anchors, pooled, anchor_batch
def forward(self, x, pos, batch):
    """Run two message branches on a self-radius graph and score the result.

    The graph is ``radius(pos, pos, self.r, ...)`` with at most 64 neighbours
    per point; its output is used directly as ``edge_index``.

    Returns:
        ``(conf, msg, edge_index)``, where ``msg`` concatenates the
        ``cdf1``/``cdf2`` outputs on the last dim and ``conf = self.lin1(msg)``.
    """
    neighbours = radius(pos, pos, self.r, batch, batch, max_num_neighbors=64)
    branches = [self.cdf1(x, pos, neighbours), self.cdf2(x, pos, neighbours)]
    combined = torch.cat(branches, dim=-1)
    return self.lin1(combined), combined, neighbours
def forward(self, data):
    """Predict a per-graph output from atom types and 3D positions.

    Two graphs are used:

    * local: ``data.edge_index`` with self loops removed;
    * global: all pairs within ``self.cutoff`` from ``radius`` (at most 500
      neighbours per atom), with self loops removed.

    Pairwise distances and two families of angles (index sets from
    ``self.indices``) are expanded by the RBF/SBF modules and their MLPs.
    They are then consumed by ``self.n_layer`` rounds of global layer followed
    by local layer. The second output of every local layer is summed and
    sum-pooled per graph.

    Returns:
        ``global_add_pool(node_sum, batch).view(-1)``. Presumably this is one
        scalar per graph when the local layers emit a single channel (TODO
        confirm).
    """

    def _angle(v1, v2):
        # Angle between vectors via atan2(|v1 x v2|, v1 . v2).
        # dim=-1 is explicit: without it torch.cross uses the FIRST dimension of
        # size 3, which is dim 0 whenever there happen to be exactly 3 triplets.
        cos_term = (v1 * v2).sum(dim=-1)
        sin_term = torch.cross(v1, v2, dim=-1).norm(dim=-1)
        return torch.atan2(sin_term, cos_term)

    x = data.x  # atom indices into self.embeddings
    edge_index = data.edge_index
    pos = data.pos
    batch = data.batch

    # Initialize node embeddings (index_select needs an integer index tensor).
    h = torch.index_select(self.embeddings, 0, x.long())

    # Local layer: given edges without self loops, and their lengths.
    edge_index_l, _ = remove_self_loops(edge_index)
    j_l, i_l = edge_index_l
    dist_l = (pos[i_l] - pos[j_l]).pow(2).sum(dim=-1).sqrt()

    # Global layer: radius graph within self.cutoff, self loops removed.
    row, col = radius(pos, pos, self.cutoff, batch, batch, max_num_neighbors=500)
    edge_index_g = torch.stack([row, col], dim=0)
    edge_index_g, _ = remove_self_loops(edge_index_g)
    j_g, i_g = edge_index_g
    dist_g = (pos[i_g] - pos[j_g]).pow(2).sum(dim=-1).sqrt()

    # Node/edge index sets used to define the two angle families.
    (idx_i_1, idx_j, idx_k, idx_kj, idx_ji,
     idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2) = self.indices(edge_index_l,
                                                               num_nodes=h.size(0))

    # Two-hop angles: between (j - i_1) and (k - j).
    angle_1 = _angle(pos[idx_j] - pos[idx_i_1], pos[idx_k] - pos[idx_j])
    # One-hop angles: between (j1 - i_2) and (j2 - j1).
    angle_2 = _angle(pos[idx_j1] - pos[idx_i_2], pos[idx_j2] - pos[idx_j1])

    # Radial / spherical basis expansions, then their projection MLPs.
    rbf_g = self.rbf_g(dist_g)
    rbf_l = self.rbf_l(dist_l)
    sbf_1 = self.sbf(dist_l, angle_1, idx_kj)  # two-hop
    sbf_2 = self.sbf(dist_l, angle_2, idx_jj)  # one-hop

    rbf_g = self.rbf_g_mlp(rbf_g)
    rbf_l = self.rbf_l_mlp(rbf_l)
    sbf_1 = self.sbf_1_mlp(sbf_1)
    sbf_2 = self.sbf_2_mlp(sbf_2)

    # Message passing: a global layer then a local layer per round; the local
    # layer's second output is accumulated for the readout.
    node_sum = 0
    for layer in range(self.n_layer):
        h = self.global_layers[layer](h, rbf_g, edge_index_g)
        h, t = self.local_layers[layer](h, rbf_l, sbf_1, sbf_2, idx_kj, idx_ji,
                                        idx_jj, idx_ji_2, edge_index_l)
        node_sum += t

    # Readout. NOTE(review): the source had "# return output.view(-1)", which
    # made the method return None; the return is restored here.
    output = global_add_pool(node_sum, batch)
    return output.view(-1)
def forward(self, x, pos, batch):
    """Sample centroids with FPS, group them by radius, and apply the point conv.

    Returns:
        ``(x, pos, batch)`` for the sampled centroids.
    """
    sample = fps(pos, batch, ratio=self.ratio)
    sample_pos = pos[sample]
    sample_batch = batch[sample]
    # TODO: max_num_neighbors is fixed at 64; decide how it should scale with
    # the number of points.
    target, source = radius(pos, sample_pos, self.r, batch, sample_batch,
                            max_num_neighbors=64)
    bipartite_edges = torch.stack([source, target], dim=0)
    out = self.conv(x, (pos, sample_pos), bipartite_edges)
    return out, sample_pos, sample_batch
def find_neighbours(self, x, y, batch_x=None, batch_y=None, scale_idx=0):
    """Radius search at one of the configured scales.

    Delegates to ``radius(x, y, self._radius[scale_idx], batch_x, batch_y,
    max_num_neighbors=self._max_num_neighbors[scale_idx])``.

    Raises:
        ValueError: if ``scale_idx`` is not in ``[0, self.num_scales)``.
    """
    # Check both bounds: a negative index used to pass this check and silently
    # pick a scale counted from the end of the lists.
    if not 0 <= scale_idx < self.num_scales:
        raise ValueError("Scale %i is out of bounds %i" % (scale_idx, self.num_scales))
    return radius(x, y, self._radius[scale_idx], batch_x, batch_y,
                  max_num_neighbors=self._max_num_neighbors[scale_idx])
def forward(self, data):
    """Set abstraction on an ``(x, pos, batch)`` tuple.

    Samples centroids with FPS (``self.ratio``), groups up to
    ``self.max_num_neighbors`` points within ``self.radius`` of each, and
    applies ``self.conv`` bipartitely.

    Returns:
        An ``(x, pos, batch)`` tuple for the sampled centroids.
    """
    feats, coords, cloud = data
    chosen = fps(coords, cloud, ratio=self.ratio)
    chosen_coords, chosen_cloud = coords[chosen], cloud[chosen]
    centre, neighbour = radius(coords, chosen_coords, self.radius, cloud,
                               chosen_cloud,
                               max_num_neighbors=self.max_num_neighbors)
    feats = self.conv(feats, (coords, chosen_coords),
                      torch.stack([neighbour, centre], dim=0))
    return (feats, chosen_coords, chosen_cloud)
def forward(self, x, pos, batch):
    """Compute Point-wise Rigid Transformation.

    Builds a radius graph over ``pos`` (r = 0.1, at most 64 neighbours per
    point) and applies ``self.conv`` to it.

    Returns:
        The output of ``self.conv(x, pos, edge_index)``.
    """
    # A 6D (position + first 3 feature channels) variant used to be built here
    # but was only consumed by commented-out code; it was removed as dead work.
    row, col = radius(pos, pos, 0.1, batch, batch, max_num_neighbors=64)
    edge_index = torch.stack([col, row], dim=0)
    return self.conv(x, pos, edge_index)
def forward(self, x, pos, batch):
    """Apply ``self.conv`` over a self-radius graph without downsampling.

    Each point collects up to ``self.sample_size`` neighbours within ``self.r``.

    Returns:
        ``(x, pos, batch)``; ``pos`` and ``batch`` are passed through unchanged.
    """
    receivers, senders = radius(pos, pos, self.r, batch, batch,
                                max_num_neighbors=self.sample_size)
    out = self.conv(x, (pos, pos), torch.stack([senders, receivers], dim=0))
    return out, pos, batch
def radius_graph_3d(self, pos, batch, r, direction='source2target', max_num_neighbors=32):
    """Build a radius graph over ``pos`` and order it by message direction.

    Args:
        pos: point coordinates.
        batch: cloud index of each point.
        r: neighbourhood radius.
        direction: ``'source2target'`` gives ``[source, target]`` rows; any
            other value gives ``[target, source]``.
        max_num_neighbors: cap on neighbours per query point. Previously
            hard-coded to 32, which is still the default.

    Returns:
        A 2 x E edge index tensor.
    """
    target_idx, source_idx = radius(pos, pos, r, batch, batch,
                                    max_num_neighbors=max_num_neighbors)
    if direction == 'source2target':
        return torch.stack([source_idx, target_idx], dim=0)
    return torch.stack([target_idx, source_idx], dim=0)
def forward(self, x, pos, norm, batch):
    """Set abstraction that also carries point normals.

    Samples FPS centroids and groups up to 32 points within ``self.r`` of
    each. (A radius graph is used here; a nearest-neighbour graph would also
    work.) ``self.conv`` is applied bipartitely to positions and normals.

    Returns:
        ``(x, pos, norm, batch)`` for the sampled centroids.
    """
    picked = fps(pos, batch, ratio=self.ratio)
    picked_pos, picked_norm, picked_batch = pos[picked], norm[picked], batch[picked]
    tgt, src = radius(pos, picked_pos, self.r, batch, picked_batch,
                      max_num_neighbors=32)
    out = self.conv(x, (pos, picked_pos), (norm, picked_norm),
                    torch.stack([src, tgt], dim=0))
    return out, picked_pos, picked_norm, picked_batch
def forward(self, x, pos, batch):
    """Multi-scale confidence score per point.

    For each ``(r3d, ratio, r6d)`` triple:

    1. average ``x`` over a 6D (position + feature) radius neighbourhood;
    2. pass the averages through ``self.dst`` on a 3D radius graph.

    The per-scale messages are stacked on a new last dim and mapped by
    ``self.lin1`` + sigmoid.

    Returns:
        The first output channel of the sigmoid confidences, i.e. ``conf[:, 0]``.
    """
    pos_ext = torch.cat([pos, x], dim=1)
    msg_list = []
    # `ratio` is not used, but self.ratios stays in the zip because zip stops at
    # the shortest input, so it still bounds the number of scales.
    for r3d, _ratio, r6d in zip(self.r3d, self.ratios, self.r6d):
        row6d, col6d = radius(pos_ext, pos_ext, r6d, batch, batch,
                              max_num_neighbors=64)
        # Mean feature of each point's 6D neighbourhood. dim_size keeps one row
        # per point even if trailing points were to have no neighbours.
        x_centers = scatter_mean(x.index_select(index=col6d, dim=0),
                                 row6d.unsqueeze(1), dim=0, dim_size=x.size(0))

        row3d, col3d = radius(pos, pos, r3d, batch, batch, max_num_neighbors=64)
        edge_index3d = torch.stack([col3d, row3d], dim=0)
        msg = self.dst(x_centers, edge_index3d)
        msg_list.append(msg.unsqueeze(-1))

    msg = torch.cat(msg_list, dim=-1)
    # torch.sigmoid replaces the deprecated F.sigmoid; the values are identical.
    conf = torch.sigmoid(self.lin1(msg))
    return conf[:, 0]
def forward(self, data):
    """Two set-abstraction levels, global SA, max-pool and a 3-layer classifier head.

    NOTE(review): ``radius`` is called as ``radius(pos[idx], pos, r,
    batch[idx], batch, 48)`` and its result is used directly as the edge
    index. Only ``pos``/``batch`` (not ``x``) are subsampled after each level,
    which suggests an older radius/SA API. TODO confirm against the pinned
    library version.

    Returns:
        ``F.log_softmax`` of the final logits over the last dimension.
    """
    pos, batch = data.pos, data.batch
    feats = None

    # (fps ratio, radius, module); the original notes say 512 then 128 points.
    for keep_ratio, r, sa_module in ((0.5, 0.1, self.local_sa1),
                                     (0.25, 0.2, self.local_sa2)):
        kept = fps(pos, batch, ratio=keep_ratio)
        graph = radius(pos[kept], pos, r, batch[kept], batch, 48)  # 48 = neighbour cap
        feats = F.relu(sa_module(feats, pos, graph))
        pos, batch = pos[kept], batch[kept]

    feats = self.global_sa(torch.cat([feats, pos], dim=1))
    feats = feats.view(-1, 128, 1024).max(dim=1)[0]

    feats = F.dropout(F.relu(self.lin1(feats)), p=0.5, training=self.training)
    feats = F.dropout(F.relu(self.lin2(feats)), p=0.5, training=self.training)
    return F.log_softmax(self.lin3(feats), dim=-1)
def forward(self, data):
    """Sample / group / PointNet step on an ``(x, pos, batch)`` tuple.

    Returns:
        ``(x, pos, batch)`` for the ``self.sample_ratio`` FPS centroids.
    """
    x, pos, batch = data

    # Sample centroids.
    seeds = fps(pos, batch, ratio=self.sample_ratio)
    seed_pos, seed_batch = pos[seeds], batch[seeds]

    # Group: bipartite graph from neighbours (within self.radius) to seeds.
    seed_of_edge, point_of_edge = radius(pos, seed_pos, self.radius, batch,
                                         seed_batch,
                                         max_num_neighbors=self.max_num_neighbors)
    groups = torch.stack([point_of_edge, seed_of_edge], dim=0)

    # Apply PointNet over each group.
    return self.point_conv(x, (pos, seed_pos), groups), seed_pos, seed_batch
def forward(self, x, pos, batch):
    """Downsample with FPS and aggregate dense neighbours into sparse samples.

    Returns:
        ``(x, pos, batch)`` for the sampled (sparse) points.
    """
    keep = gnn.fps(pos, batch, self.ratio)
    sparse_pos, sparse_batch = pos[keep], batch[keep]

    sparse_idx, dense_idx = gnn.radius(pos, sparse_pos, self.radius, batch,
                                       sparse_batch, max_num_neighbors=64)
    # Row 0 = dense source, row 1 = sparse target. With PyG's default
    # 'source_to_target' flow, messages travel dense -> sparse (a -> b means a
    # is in N(b)). NOTE(review): this assumes point_conv keeps the default flow.
    edge_index = torch.stack((dense_idx, sparse_idx), dim=0)

    x = self.point_conv(x, (pos, sparse_pos), edge_index)
    return x, sparse_pos, sparse_batch
def forward(self, x, pos, batch):
    """Multi-scale grouping: one set of FPS centroids, several radii, concatenated features.

    ``self.sample_points`` is treated as the number of centroids per point
    cloud. For each radius ``self.r_list[i]`` (neighbour cap
    ``self.group_sample_size[i]``), ``self.conv_list[i]`` is applied and the
    outputs are concatenated on dim 1.

    Returns:
        ``(features, pos, batch)`` for the sampled centroids.
    """
    # fps takes a single ratio for the whole batch. Dividing by len(pos) alone
    # (the total over all clouds) gave each of B clouds only sample_points / B
    # centroids, so the ratio is scaled by the number of clouds. This is exact
    # for equally sized clouds. NOTE(review): confirm sample_points is meant
    # per cloud; for batch size 1 nothing changes.
    num_clouds = int(batch.max()) + 1
    idx = fps(pos, batch, self.sample_points * num_clouds / len(pos))
    centre_pos, centre_batch = pos[idx], batch[idx]

    x_list = []
    for i, r in enumerate(self.r_list):
        row, col = radius(pos, centre_pos, r, batch, centre_batch,
                          max_num_neighbors=self.group_sample_size[i])
        edge_index = torch.stack([col, row], dim=0)
        x_list.append(self.conv_list[i](x, (pos, centre_pos), edge_index))
    return torch.cat(x_list, 1), centre_pos, centre_batch
def forward(self, x, pos, batch):
    """PointNet++ set abstraction: sampling, grouping and PointNet layers.

    Returns:
        ``(x, pos, batch)`` for the sampled centroids.
    """
    # Sampling: farthest point sampling keeps self.ratio of the points.
    anchors = fps(pos, batch, ratio=self.ratio)
    anchor_pos = pos[anchors]
    anchor_batch = batch[anchors]

    # Grouping: up to 64 points within self.r of each anchor.
    anchor_e, point_e = radius(pos, anchor_pos, self.r, batch, anchor_batch,
                               max_num_neighbors=64)

    # PointNet: aggregate each group into its anchor.
    out = self.conv(x, (pos, anchor_pos), torch.stack([point_e, anchor_e], dim=0))
    return out, anchor_pos, anchor_batch
def forward(self, pos, batch):
    """Classify point clouds with a single bipartite conv and max pooling.

    Returns:
        ``self.classifier`` applied to the pooled per-cloud features.
    """
    # FPS returns indices (not a boolean mask) of the kept 25% of points.
    centres = fps(pos, batch, ratio=0.25)
    centre_pos = pos[centres]

    # Edges from points within 0.3 to their centre (library default neighbour cap).
    centre_e, point_e = radius(pos, centre_pos, 0.3, batch, batch[centres])
    features = self.conv(pos, centre_pos, torch.stack([point_e, centre_e], dim=0))

    return self.classifier(global_max_pool(features, batch[centres]))
def forward(self, x, pos, batch):
    """Downsample by averaging features over 6D (feature + position) neighbourhoods.

    Centroids are chosen by FPS in the concatenated ``[x, pos]`` space. Each
    centroid's new feature is the mean of ``x`` over up to 64 points within
    ``self.r`` in that space.

    Returns:
        ``(x_centers, pos, batch)`` for the sampled centroids.
    """
    pos6d = torch.cat([x, pos], dim=-1)
    idx = fps(pos6d, batch, ratio=self.ratio)
    row6d, col6d = radius(pos6d, pos6d[idx], self.r, batch, batch[idx],
                          max_num_neighbors=64)
    # dim_size pins the output to one row per centroid instead of inferring it
    # from the largest index present in row6d.
    x_centers = scatter_mean(x.index_select(index=col6d, dim=0), row6d.unsqueeze(1),
                             dim=0, dim_size=idx.size(0))
    return x_centers, pos[idx], batch[idx]
def forward(self, x, pos, batch):
    """Set abstraction: FPS centroids, radius grouping (<= 64 neighbours), point conv.

    The debug ``print`` calls that fired on every forward pass were removed,
    which makes this method match the other set-abstraction modules. The
    computation is unchanged.

    Returns:
        ``(x, pos, batch)`` for the sampled centroids.
    """
    idx = fps(pos, batch, ratio=self.ratio)
    row, col = radius(pos, pos[idx], self.r, batch, batch[idx], max_num_neighbors=64)
    edge_index = torch.stack([col, row], dim=0)
    x = self.conv(x, (pos, pos[idx]), edge_index)
    return x, pos[idx], batch[idx]
def find_neighbours(self, x, y, batch_x=None, batch_y=None):
    """Find neighbours of ``y`` in ``x`` with the backend matching ``self._conv_type``.

    * message passing: ``radius`` (edge-list output);
    * dense / partial dense: ``tp.ball_query`` with ``mode=self._conv_type``.

    Raises:
        NotImplementedError: for any other conv type.
    """
    conv_type = self._conv_type
    if conv_type == ConvolutionFormat.MESSAGE_PASSING.value:
        return radius(x, y, self._radius, batch_x, batch_y,
                      max_num_neighbors=self._max_num_neighbors)
    # The old test was `conv_type == DENSE.value or PARTIAL_DENSE.value`. Its
    # second operand is a bare truthy value, so every unknown type was routed
    # here and the NotImplementedError below could never fire.
    if conv_type in (ConvolutionFormat.DENSE.value,
                     ConvolutionFormat.PARTIAL_DENSE.value):
        return tp.ball_query(self._radius, self._max_num_neighbors, x, y,
                             mode=self._conv_type, batch_x=batch_x, batch_y=batch_y)
    raise NotImplementedError
def forward(self, x, pos, batch):
    """Set abstraction with separate sampling ratios for training and evaluation.

    Returns:
        ``(x, pos, batch, x_idx, y_idx)``, where ``x``/``pos``/``batch`` are
        for the sampled centroids and ``x_idx``/``y_idx`` are the raw
        ball-query indices (points / centroids).
    """
    keep_ratio = self.ratio_train if self.training else self.ratio_test
    centroids = fps(pos, batch, ratio=keep_ratio)
    centroid_pos, centroid_batch = pos[centroids], batch[centroids]

    # Ball query: up to 128 points within self.r of each centroid.
    y_idx, x_idx = radius(pos, centroid_pos, self.r, batch, centroid_batch,
                          max_num_neighbors=128)
    out = self.conv(x, (pos, centroid_pos), torch.stack([x_idx, y_idx], dim=0))
    return out, centroid_pos, centroid_batch, x_idx, y_idx
def forward(self, x, pos, batch):
    """Pool points into 6D (position + feature) FPS clusters by averaging.

    FPS runs deterministically (``random_start=False``) in ``[pos, x]`` space.
    Each centroid gathers up to 128 points within ``self.r6d``, and both the
    features and the positions are replaced by their neighbourhood means.

    Returns:
        ``((x, pos, batch), edge_index)``, where ``edge_index`` row 0 indexes
        input points and row 1 indexes centroids.
    """
    joint = torch.cat([pos, x], dim=-1)
    seeds = fps(joint, batch, ratio=self.ratio, random_start=False)
    seed_e, point_e = radius(joint, joint[seeds], self.r6d, batch, batch[seeds],
                             max_num_neighbors=128)
    edge_index = torch.stack([point_e, seed_e], dim=0)

    src, dst = edge_index[0], edge_index[1]
    groups = dst.unsqueeze(1)
    mean_x = scatter_mean(x.index_select(index=src, dim=0), groups, dim=0)
    mean_pos = scatter_mean(pos.index_select(index=src, dim=0), groups, dim=0)
    return (mean_x, mean_pos, batch[seeds]), edge_index
def build_bipartite_graph(self, pos, batch, ratio, method='radius', r=0.1, k=32, dilation=1):
    '''
    Build a bipartite graph given a pos vector
    :param pos: (torch.Tensor) The position of the points
    :param batch: (torch.Tensor) The batch index for each point
    :param ratio: (float) The sample ratio
    :param method: (str) Specify which method to use, ['radius', 'knn']
    :param r: (float) If 'radius' is adopted, the radius of the support domain
    :param k: (int) Max neighbours per sampled point ('radius'), or the number
        of neighbours kept per sampled point ('knn')
    :param dilation: (int) If 'knn' is adopted, the dilation rate
    :return: (torch.Tensor) The edge index (row 0: input point, row 1: sampled
        point), position and batch of the sampled points
    :raises ValueError: if method is not 'radius' or 'knn'
    '''
    # An explicit check instead of `assert`: asserts are stripped under -O,
    # which used to surface later as an UnboundLocalError on `row`.
    if method not in ('radius', 'knn'):
        raise ValueError("method must be 'radius' or 'knn', got %r" % (method,))

    idx = fps(pos, batch, ratio=ratio)
    if method == 'radius':
        row, col = radius(pos, pos[idx], r, batch, batch[idx], max_num_neighbors=k)
    else:
        row, col = knn(pos, pos[idx], k * dilation, batch, batch[idx])
        if dilation > 1:
            # Keep k of the k * dilation candidates per sampled point.
            # NOTE(review): draws are random with replacement, and the code
            # assumes knn returned exactly k * dilation contiguous rows per
            # sampled point. Confirm both are intended.
            n = idx.shape[0]
            index = torch.randint(k * dilation, (n, k), dtype=torch.long,
                                  device=row.device)
            arange = torch.arange(n, dtype=torch.long, device=row.device)
            arange = arange * (k * dilation)
            index = (index + arange.view(-1, 1)).view(-1)
            row, col = row[index], col[index]

    edge_index = torch.stack([col, row], dim=0)
    return edge_index, pos[idx], batch[idx]