def forward(self, data):
    """Two-level graclus-pooling U-Net producing per-node sigmoid scores.

    Encoder:
        1. conv1 at full resolution.
        2. Graclus clustering + max pooling, then conv2.
        3. A second clustering + max pooling, then conv3/conv4.

    Decoder: features are interpolated back one level at a time with
    ``knn_interpolate``, concatenated with that level's skip features and
    refined (conv5 at the middle level). The head is lin1/lin2/out + sigmoid.

    ``data`` is not modified: pooling operates on shallow copies, so
    ``data.x`` still holds the input features after the call.
    """
    import copy

    def pool(cluster, graph, features):
        # Max-pool ``features`` over ``cluster`` using a shallow copy of
        # ``graph`` so assigning ``x`` never leaks into the caller's object.
        graph = copy.copy(graph)
        graph.x = features
        return max_pool(cluster, graph)

    x, edge_index = data.x, data.edge_index
    x = self.conv1(x, edge_index).relu()

    # Level 1: pool the full-resolution graph.
    cluster1 = graclus(edge_index, num_nodes=x.shape[0])
    pooled_1 = pool(cluster1, data, x)
    edge_index_2 = pooled_1.edge_index
    x2 = self.conv2(pooled_1.x, edge_index_2).relu()

    # Level 2: pool again.
    cluster2 = graclus(edge_index_2, num_nodes=x2.shape[0])
    pooled_2 = pool(cluster2, pooled_1, x2)
    edge_index_3 = pooled_2.edge_index
    x3 = self.conv3(pooled_2.x, edge_index_3).relu()
    x3 = self.conv4(x3, edge_index_3).relu()

    # Decoder. The batch vectors keep each k-NN search inside a single graph
    # of a mini-batch; without them neighbours are taken across graphs.
    x3 = knn_interpolate(x3, pooled_2.pos, pooled_1.pos,
                         pooled_2.batch, pooled_1.batch)
    x3 = torch.cat((x2, x3), dim=1)
    x3 = self.conv5(x3, edge_index_2).relu()
    x3 = knn_interpolate(x3, pooled_1.pos, data.pos,
                         pooled_1.batch, data.batch)

    x = torch.cat((x, x3), dim=1)
    x = self.lin1(x).relu()
    x = self.lin2(x).relu()
    return torch.sigmoid(self.out(x))
def forward(self, data):
    """Ten-layer graph conv stack with two pooled side branches fused residually.

    Main path: optional ``pre_lin``, then conv1..conv10, each followed by its
    ``s*`` module, all on the full-resolution graph.

    Side branches, taken after conv3 and after conv6:
        1. Max-pool the current features over one graclus clustering.
        2. Convolve on the coarse graph.
        3. Interpolate back to ``data.pos``.
        4. Add the result to an affine projection of the main features.

    Both branch outputs are added to the main path before conv10. The head is
    lin1/lin2/out + sigmoid.

    ``data`` is not modified: pooling operates on a shallow copy of it.
    """
    import copy

    x, edge_index = data.x, data.edge_index

    def coarse_branch(features, conv, act):
        # Pool a shallow copy so assigning ``x`` does not overwrite the
        # caller's ``data.x``.
        graph = copy.copy(data)
        graph.x = features
        pooled = max_pool(cluster, graph)
        coarse = act(conv(pooled.x, pooled.edge_index))
        # Batch vectors keep the k-NN search inside each graph of a mini-batch.
        return knn_interpolate(coarse, pooled.pos, data.pos,
                               pooled.batch, data.batch)

    x = self.pre_lin(x) if self.masif_descr else x
    x = self.s1(self.conv1(x, edge_index))
    x = self.s2(self.conv2(x, edge_index))
    x = self.s3(self.conv3(x, edge_index))

    # A single clustering of the input graph is shared by both side branches.
    cluster = graclus(edge_index, num_nodes=x.shape[0])
    branch1 = coarse_branch(x, self.interconv1, self.inters1)
    x1 = self.affine1(x) + branch1

    x = self.s4(self.conv4(x, edge_index))
    x = self.s5(self.conv5(x, edge_index))
    x = self.s6(self.conv6(x, edge_index))

    # NOTE(review): this branch reuses ``interconv1`` and ``affine1`` (paired
    # with ``inters2``). If ``interconv2`` / ``affine2`` exist, this looks like
    # a copy-paste slip; confirm before retraining.
    branch2 = coarse_branch(x, self.interconv1, self.inters2)
    x2 = self.affine1(x) + branch2

    x = self.s7(self.conv7(x, edge_index))
    x = self.s8(self.conv8(x, edge_index))
    x = self.s9(self.conv9(x, edge_index))
    x = x + x1 + x2
    x = self.s10(self.conv10(x, edge_index))
    x = self.s11(self.lin1(x))
    x = self.s12(self.lin2(x))
    return torch.sigmoid(self.out(x))
def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
    """Upsample features onto the skip-connection points and mix them.

    Features ``x`` located at ``pos`` are interpolated onto ``pos_skip`` with
    ``self.k`` neighbours. When skip features are present they are appended
    along dim 1, and the result goes through ``self.nn``.

    Returns:
        ``(features, pos_skip, batch_skip)``.
    """
    upsampled = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
    features = upsampled if x_skip is None else torch.cat([upsampled, x_skip], dim=1)
    return self.nn(features), pos_skip, batch_skip
def predict_original_samples(self, batch, conv_type, output):
    """Upsample the network output back onto every raw point of each sample.

    Arguments:
        batch -- processed batch; its ``_pred`` attribute is set to ``output``
            (reshaped to [B, N, L] for DENSE convolutions)
        conv_type -- type of convolution (DENSE, PARTIAL_DENSE, etc...)
        output -- output predicted by the model

    Returns:
        dict mapping each sample's filename (``get_filename`` of
        ``self.test_dataset[0]``) to an array whose rows are the raw point
        position followed by the predicted label (argmax over the last
        dimension of the 3-NN interpolated prediction).
    """
    results = {}
    num_sample = BaseDataset.get_num_samples(batch, conv_type)
    if conv_type == "DENSE":
        output = output.reshape(num_sample, -1, output.shape[-1])  # [B,N,L]
    setattr(batch, "_pred", output)

    for b in range(num_sample):
        dataset = self.test_dataset[0]
        sample_id = batch.sampleid[b]
        raw_pos = dataset.get_raw(sample_id).pos.to(output.device)
        predicted = BaseDataset.get_sample(batch, "_pred", b, conv_type)
        origin_ids = BaseDataset.get_sample(batch, SaveOriginalPosId.KEY, b, conv_type)
        # Spread predictions from the kept (subsampled) points to all raw points.
        full_prediction = knn_interpolate(predicted, raw_pos[origin_ids], raw_pos, k=3)
        labels = full_prediction.max(1)[1].unsqueeze(-1)
        columns = (raw_pos.cpu().numpy(), labels.cpu().numpy())
        results[dataset.get_filename(sample_id)] = np.hstack(columns)
    return results
def test_knn_interpolate():
    # Two clouds of three source points each: batch 0 carries positive values,
    # batch 1 negative ones.
    features = torch.Tensor([[1], [10], [100], [-1], [-10], [-100]])
    source_pos = torch.Tensor([[-1, 0], [0, 0], [1, 0], [-2, 0], [0, 0], [2, 0]])
    target_pos = torch.Tensor([[-1, -1], [1, 1], [-2, -2], [2, 2]])
    source_batch = torch.tensor([0, 0, 0, 1, 1, 1])
    target_batch = torch.tensor([0, 0, 1, 1])

    result = knn_interpolate(features, source_pos, target_pos,
                             source_batch, target_batch, k=2)

    # e.g. target (-1, -1): nearest are (-1, 0) at squared distance 1 and
    # (0, 0) at squared distance 2 -> (1 * 1 + 10 * 0.5) / 1.5 = 4.
    expected = [[4], [70], [-4], [-70]]
    assert result.tolist() == expected
def forward(self, data):
    """Tuple-in / tuple-out feature-propagation step.

    Args:
        data: 6-tuple ``(x, pos, batch, x_skip, pos_skip, batch_skip)``.

    Returns:
        3-tuple ``(features, pos_skip, batch_skip)``. ``features`` is
        ``self.nn`` applied to ``x`` interpolated onto ``pos_skip``
        (``self.k`` neighbours), concatenated with ``x_skip`` when present.
    """
    x, pos, batch, x_skip, pos_skip, batch_skip = data
    features = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
    if x_skip is not None:
        features = torch.cat([features, x_skip], dim=1)
    return (self.nn(features), pos_skip, batch_skip)
def forward(self, data):
    """Graph conv network with two graclus pooling levels and k-NN unpooling.

    1. Optional ``pre_lin``, then conv1..conv3 on the full graph.
    2. Pool, then conv4..conv6 on the first coarse graph.
    3. Pool again, then conv7 on the second coarse graph.
    4. Interpolate back to the first coarse level and apply conv8.
    5. Interpolate back to full resolution and apply conv9/conv10.
    6. Head: lin1/lin2/out + sigmoid.

    Every conv/linear layer is followed by its ``s*`` module.
    ``data`` is not modified: pooling operates on shallow copies.
    """
    import copy

    def pool(cluster, graph, features):
        # Max-pool ``features`` over ``cluster`` using a shallow copy of
        # ``graph`` so assigning ``x`` never leaks into the caller's object.
        graph = copy.copy(graph)
        graph.x = features
        return max_pool(cluster, graph)

    x, edge_index = data.x, data.edge_index
    x = self.pre_lin(x) if self.masif_descr else x
    x = self.s1(self.conv1(x, edge_index))
    x = self.s2(self.conv2(x, edge_index))
    x = self.s3(self.conv3(x, edge_index))

    # First coarse level.
    cluster1 = graclus(edge_index, num_nodes=x.shape[0])
    inter1 = pool(cluster1, data, x)
    edge_index = inter1.edge_index
    x = self.s4(self.conv4(inter1.x, edge_index))
    x = self.s5(self.conv5(x, edge_index))
    x = self.s6(self.conv6(x, edge_index))

    # Second coarse level.
    cluster2 = graclus(edge_index, num_nodes=x.shape[0])
    inter2 = pool(cluster2, inter1, x)
    x = self.s7(self.conv7(inter2.x, inter2.edge_index))

    # Unpool. Batch vectors keep each k-NN search inside one graph of a
    # mini-batch.
    x = knn_interpolate(x, inter2.pos, inter1.pos, inter2.batch, inter1.batch)
    x = self.s8(self.conv8(x, edge_index))  # edge_index is still level 1 here
    x = knn_interpolate(x, inter1.pos, data.pos, inter1.batch, data.batch)

    edge_index = data.edge_index
    x = self.s9(self.conv9(x, edge_index))
    x = self.s10(self.conv10(x, edge_index))
    x = self.s11(self.lin1(x))
    x = self.s12(self.lin2(x))
    return torch.sigmoid(self.out(x))
def forward(self, data):
    """Three-scale graph network mixing learned and averaged coarse features.

    Two graclus clusterings of the input graph are computed once. Each is used
    twice:
        - ``avg_pool`` builds downscaled copies of the *input* graph and
          features (``downsample_1``, ``downsample_2``).
        - ``max_pool`` pools the learned features (``inter1``, ``inter2``).

    At each coarse level the learned features are concatenated with an affine
    projection of the averaged input features and convolved. The decoder
    interpolates back up with ``knn_interpolate`` and concatenates skips, then
    applies conv4/conv5 and a lin1/lin2/out + sigmoid head.

    ``data`` is not modified: pooling of learned features uses shallow copies.
    """
    import copy

    def pool(cluster, graph, features):
        # Max-pool ``features`` over ``cluster`` using a shallow copy of
        # ``graph`` so assigning ``x`` never leaks into the caller's object.
        graph = copy.copy(graph)
        graph.x = features
        return max_pool(cluster, graph)

    x, edge_index_1 = data.x, data.edge_index

    # Downscaled copies of the input graph (averaged input features).
    cluster1 = graclus(edge_index_1, num_nodes=x.shape[0])
    downsample_1 = avg_pool(cluster1, data)
    edge_index_2 = downsample_1.edge_index
    cluster2 = graclus(edge_index_2, num_nodes=downsample_1.x.shape[0])
    downsample_2 = avg_pool(cluster2, downsample_1)
    edge_index_3 = downsample_2.edge_index

    x = self.s1(self.conv1(x, edge_index_1))

    # The same clusters feed avg_pool and max_pool, so coarse node orders match.
    inter1 = pool(cluster1, data, x)
    x2 = torch.cat((self.affine1(downsample_1.x), inter1.x), dim=1)
    x2 = self.s2(self.conv2(x2, edge_index_2))

    inter2 = pool(cluster2, inter1, x2)
    x3 = torch.cat((self.affine2(downsample_2.x), inter2.x), dim=1)
    x3 = self.s3(self.conv3(x3, edge_index_3))

    # Decoder. Batch vectors keep each k-NN search inside one graph of a
    # mini-batch.
    x3 = knn_interpolate(x3, downsample_2.pos, downsample_1.pos,
                         downsample_2.batch, downsample_1.batch)
    x2 = torch.cat((x2, x3), dim=1)
    x2 = knn_interpolate(x2, downsample_1.pos, data.pos,
                         downsample_1.batch, data.batch)
    x = torch.cat((x, x2), dim=1)

    x = self.s4(self.conv4(x, edge_index_1))
    x = self.s5(self.conv5(x, edge_index_1))
    x = self.s6(self.lin1(x))
    x = self.s7(self.lin2(x))
    return torch.sigmoid(self.out(x))
def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
    """Feature-propagation step: coarse features -> dense skip points -> MLP.

    Returns:
        ``(features, pos_skip, batch_skip)``.
    """
    # Interpolate the coarse features onto the dense (skip) positions.
    dense = gnn.knn_interpolate(
        x, pos, pos_skip, batch_x=batch, batch_y=batch_skip, k=self.k
    )
    # Append the original dense features, if any, along the channel dimension.
    combined = dense if x_skip is None else torch.cat((dense, x_skip), dim=1)
    return self.mlp(combined), pos_skip, batch_skip
def forward(self, x, x_sub, pos, pos_sub, batch=None, batch_sub=None):
    """Fuse high-res features with low-res features interpolated up to them.

    ``x_sub`` is first transformed by ``self.mlp_sub``, interpolated (3
    neighbours) from ``pos_sub`` onto ``pos``, and added to ``self.mlp(x)``.
    """
    projected_sub = self.mlp_sub(x_sub)
    upsampled = knn_interpolate(
        projected_sub, pos_sub, pos, k=3, batch_x=batch_sub, batch_y=batch
    )
    return self.mlp(x) + upsampled
def forward(self, data, backsampling):
    """Optionally convolve, then move ``data`` onto ``backsampling``'s points.

    ``data`` is modified in place and returned:
        - its features are interpolated (``self.k`` neighbours) onto
          ``backsampling.pos``;
        - its ``pos``, ``edge_index``, ``edge_attr`` and ``batch`` are replaced
          by those of ``backsampling``.
    """
    if self.conv_layer:
        data.x = F.elu(self.conva(data.x, data.edge_index, data.edge_attr))
    data.x = knn_interpolate(data.x, data.pos, backsampling.pos,
                             data.batch, backsampling.batch, k=self.k)
    for name in ("pos", "edge_index", "edge_attr", "batch"):
        setattr(data, name, getattr(backsampling, name))
    return data
def forward(self, x_hr, pos_hr, batch_hr, x_lr, pos_lr, batch_lr):
    """Append to each high-res point the features of its nearest low-res point.

    On both resolutions the coordinates are concatenated to the features.
    Where ``x_hr`` is None, only ``pos_hr`` is used. Each high-res point then
    receives the (features + coordinates) of its single nearest low-res point
    (``k=1``).

    Returns:
        ``(features, zeros, batch_hr)``, where ``zeros`` has shape
        (num_hr_points, 3) and is returned in place of ``pos_hr``.
        NOTE(review): the zero positions and hard-coded width 3 are not
        explained here; confirm what the caller expects.
    """
    hr_feats = pos_hr if x_hr is None else torch.cat([x_hr, pos_hr], dim=1)
    lr_feats = torch.cat([x_lr, pos_lr], dim=1)
    lr_feats = knn_interpolate(lr_feats, pos_lr, pos_hr, batch_lr, batch_hr, k=1)
    merged = torch.cat([hr_feats, lr_feats], dim=1)
    return merged, pos_hr.new_zeros((merged.size(0), 3)), batch_hr
def __call__(self, query, support, precomputed: Data = None):
    """Interpolate features from the query resolution onto the support positions.

    Args:
        query: data structure holding the low-resolution data (``pos`` and
            ``x``).
        support: data structure holding the positions to interpolate onto.
        precomputed: optional structure carrying ``x_idx``, ``y_idx``,
            ``weights`` and ``normalisation``. When given (and truthy), the
            result is a weighted scatter-sum of ``query.x`` instead of a fresh
            k-NN search.

    Returns:
        torch.Tensor: interpolated features, one row per support point.

    Raises:
        ValueError: if ``precomputed.num_nodes`` differs from the number of
            support points.
    """
    if precomputed:
        num_points = support.pos.size(0)
        if num_points != precomputed.num_nodes:
            raise ValueError("Precomputed indices do not match with the data given to the transform")
        x = query.x
        weighted = x[precomputed.x_idx] * precomputed.weights
        y = scatter_add(weighted, precomputed.y_idx, dim=0, dim_size=num_points)
        return y / precomputed.normalisation

    x, pos = query.x, query.pos
    pos_support = support.pos
    # Fallback batch vectors put every point in a single cloud. They are created
    # on the positions' device: a default (CPU) tensor cannot be mixed with GPU
    # positions inside knn_interpolate.
    if hasattr(support, "batch"):
        batch_support = support.batch
    else:
        batch_support = torch.zeros(
            (support.num_nodes,), dtype=torch.long, device=pos_support.device
        )
    if hasattr(query, "batch"):
        batch = query.batch
    else:
        batch = torch.zeros((query.num_nodes,), dtype=torch.long, device=pos.device)
    return knn_interpolate(x, pos, pos_support, batch, batch_support, k=self.k)
def conv(self, x, pos, pos_skip, batch, batch_skip, *args):
    """Interpolate ``x`` from ``pos`` onto ``pos_skip`` with ``self.k`` neighbours.

    Any extra positional arguments are accepted and ignored.
    """
    neighbours = self.k
    return knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=neighbours)
def forward(self, x):
    """Segment a batch of point clouds.

    Pipeline:
        1. Align the input with ``self.transform_net``.
        2. Encode on three resolutions: the full cloud, then two
           farthest-point-sampled subsets (ratio 0.375, then 1/3).
        3. Decode back with ``knn_interpolate`` plus skip connections.
        4. Map to per-point outputs with conv4/conv5/dp3/conv6.

    Args:
        x: point coordinates, treated below as (batch_size, 3, num_points).

    Returns:
        Output of ``self.conv6``; presumably (batch_size, seg_num_all,
        num_points). TODO confirm against the layer definitions.
    """
    batch_size = x.size(0)
    device = x.device

    # ``fps`` / ``knn_interpolate`` work on flat (batch_size * n_pts, C) tensors
    # with a batch vector. Always transpose before reshaping: a plain reshape
    # between (C, n_pts) and (n_pts, C) layouts would interleave values of
    # different points and channels.
    def to_flat(t):
        # (batch_size, C, n_pts) -> (batch_size * n_pts, C), batch-major rows.
        return t.transpose(1, 2).reshape(-1, t.size(1))

    def from_flat(t):
        # (batch_size * n_pts, C) batch-major -> (batch_size, C, n_pts).
        return t.reshape(batch_size, -1, t.size(1)).transpose(1, 2).contiguous()

    def batch_vector(n_pts):
        # Batch id per flat row, matching the batch-major order of to_flat.
        return torch.arange(batch_size, device=device).repeat_interleave(n_pts)

    # Input alignment: (batch_size, 3, num_points) @ (batch_size, 3, 3).
    x0 = get_graph_feature(
        x, x, k=self.k, dilation=False, dilation_type=self.dilation
    )  # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
    t = self.transform_net(x0)  # (batch_size, 3, 3)
    x = torch.bmm(x.transpose(2, 1), t).transpose(2, 1)
    pos1 = x

    # Level 1: full resolution.
    x1 = self.edge1(x, pos1)
    pos1_flat = to_flat(pos1)
    bs1 = batch_vector(x1.size(2))

    # Level 2: farthest point sampling 2048 -> 768.
    idx = fps(pos1_flat, batch=bs1, ratio=0.375)
    pos2_flat = pos1_flat[idx]
    pos2 = from_flat(pos2_flat)
    x2 = self.edge2(from_flat(to_flat(x1)[idx]), pos2)
    bs2 = batch_vector(x2.size(2))

    # Level 3 (bottleneck): farthest point sampling 768 -> 384.
    idx = fps(pos2_flat, batch=bs2, ratio=1 / 3)
    pos3_flat = pos2_flat[idx]
    pos3 = from_flat(pos3_flat)
    x3 = self.edge3(from_flat(to_flat(x2)[idx]), pos3)
    bs3 = batch_vector(x3.size(2))

    # Decoder: level 3 -> level 2, concatenated with the level-2 skip features.
    x4 = knn_interpolate(to_flat(x3), pos3_flat, pos2_flat, batch_x=bs3, batch_y=bs2)
    x4 = torch.cat((from_flat(x4), x2), 1)
    x4 = self.edge4(x4, pos2)

    # Decoder: level 2 -> level 1, concatenated with the level-1 skip features.
    x5 = knn_interpolate(to_flat(x4), pos2_flat, pos1_flat, batch_x=bs2, batch_y=bs1)
    x5 = torch.cat((from_flat(x5), x1), 1)
    x5 = self.edge5(x5, pos1)
    x5 = self.edge6(x5, pos1)

    x = self.conv4(x5)
    x = self.conv5(x)
    x = self.dp3(x)
    x = self.conv6(x)
    return x
def knn_interpolate(self, pos, k):
    """Return a new ``Prediction`` at ``pos``.

    Its features are this prediction's ``x`` interpolated from ``self.pos``
    using ``k`` neighbours.
    """
    interpolated = knn_interpolate(self.x, self.pos, pos, k=k)
    return Prediction(pos, interpolated)