def test_knn(dtype, device):
    """Check the pairs returned by `knn` for three batching set-ups.

    The reference set `x` holds the four corners of a square twice (once per
    example) and `y` holds two query points. The expected (query, reference)
    pairs are compared as sets, so the order of the returned edges is not
    checked.
    """
    corners = [[-1, -1], [-1, +1], [+1, +1], [+1, -1]]
    x = tensor(corners + corners, dtype, device)
    y = tensor([[1, 0], [-1, 0]], dtype, device)

    unbatched_pairs = {(0, 2), (0, 3), (1, 0), (1, 1)}
    batched_pairs = {(0, 2), (0, 3), (1, 4), (1, 5)}

    # Without batch vectors every query may match any reference point.
    assert to_set(knn(x, y, 2)) == unbatched_pairs

    # With batch vectors a query only matches points of its own example.
    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
    batch_y = tensor([0, 1], torch.long, device)
    assert to_set(knn(x, y, 2, batch_x, batch_y)) == batched_pairs

    # The cosine variant is only exercised for CUDA inputs.
    if x.is_cuda:
        cosine_edges = knn(x, y, 2, batch_x, batch_y, cosine=True)
        assert to_set(cosine_edges) == batched_pairs

    # Skipping a batch: ids 0 and 2 are used, id 1 never appears.
    batch_x = tensor([0, 0, 0, 0, 2, 2, 2, 2], torch.long, device)
    batch_y = tensor([0, 2], torch.long, device)
    assert to_set(knn(x, y, 2, batch_x, batch_y)) == batched_pairs
def test_knn(dtype, device):
    """Check the row/column indices returned by a batched `knn` call.

    For the Euclidean case the two neighbours of each query are sorted before
    comparison, so their relative order is not checked.
    """
    corners = [[-1, -1], [-1, +1], [+1, +1], [+1, -1]]
    x = tensor(corners + corners, dtype, device)
    y = tensor([[1, 0], [-1, 0]], dtype, device)

    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
    batch_y = tensor([0, 1], torch.long, device)

    row, col = knn(x, y, 2, batch_x, batch_y)
    # Two neighbours per query: sort inside each pair, then flatten again.
    sorted_pairs = col.view(-1, 2).sort(dim=-1)[0]
    col = sorted_pairs.view(-1)

    assert row.tolist() == [0, 0, 1, 1]
    assert col.tolist() == [2, 3, 4, 5]

    # The cosine variant is only exercised for CUDA inputs; here the returned
    # order is compared as-is.
    if x.is_cuda:
        row, col = knn(x, y, 2, batch_x, batch_y, cosine=True)
        assert row.tolist() == [0, 0, 1, 1]
        assert col.tolist() == [0, 1, 4, 5]
def point_to_plane(pc1: torch.Tensor, pc2: torch.Tensor, pc1_normals) -> torch.Tensor:
    """
    Calculates the point-to-plane distortion from pc1 -> pc2

    Every point of pc1 is matched to its nearest neighbour in pc2 (searching
    on the first three coordinates only). The xyz error of each match is
    projected onto the corresponding pc1 normal, squared, and averaged.

    See `"Geometric Distortion Metrics for Point Cloud Compression"
    <https://www.merl.com/publications/docs/TR2017-113.pdf>` paper.

    :param pc1: Tensor [bn x d] representing Point cloud 1
    :param pc2: Tensor [bn x d] representing Point cloud 2
    :param pc1_normals: Tensor [bn x 3] representing the normalized surface
        normals for pc1 in x/y/z.
    :return: Tensor shape [] containing the point-to-plane distortion
    """

    def _spatial(cloud):
        # Only the spatial coordinates take part in the neighbour search.
        return cloud[..., :3] if cloud.shape[-1] > 3 else cloud

    queries = _spatial(pc1).contiguous()
    references = _spatial(pc2).contiguous()

    # One nearest neighbour in pc2 for every point of pc1.
    idx_pc1, idx_pc2 = knn(references, queries, 1)

    # Per-match xyz error, as a batch of row vectors.
    xyz_error = (pc2[idx_pc2, :3] - pc1[idx_pc1, :3]).view(-1, 1, 3)

    # NOTE(review): the normals are used in their stored order rather than
    # gathered through idx_pc1; this assumes knn emits exactly one match per
    # pc1 point, in pc1 order - confirm.
    normals = pc1_normals.view(-1, 3, 1)

    # Batched dot product of each error vector with its normal.
    along_normal = torch.bmm(xyz_error, normals)
    return along_normal.pow(2).mean()
def forward(
        self, x: Union[Tensor, PairTensor],
        batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:
    """Build a k-NN graph in the learned coordinate space and propagate.

    Args:
        x: Either one feature tensor (used for both sides) or a pair of
            feature tensors. The first entry must be 2-dimensional.
        batch: Optional batch vector, or a pair of batch vectors matching
            the two entries of :obj:`x`.

    Returns:
        The three values ``(self.lin_p(out), edge_index, edge_weight)``.
        NOTE(review): the signature is annotated ``-> Tensor`` although a
        3-tuple is returned.
    """
    is_bipartite: bool = True
    if isinstance(x, Tensor):
        x: PairTensor = (x, x)
        is_bipartite = False
    assert x[0].dim() == 2, 'Static graphs not supported in `GravNetConv`.'

    batch_pair: PairOptTensor = (None, None)
    if isinstance(batch, Tensor):
        batch_pair = (batch, batch)
    elif isinstance(batch, tuple):
        assert batch is not None
        batch_pair = (batch[0], batch[1])

    # Features that get propagated, and coordinates used for the k-NN search.
    feats_src: Tensor = self.lin_h(x[0])
    coords_src: Tensor = self.lin_s(x[0])
    coords_dst: Tensor = self.lin_s(x[1]) if is_bipartite else coords_src

    edge_index = knn(coords_src, coords_dst, self.k, batch_pair[0],
                     batch_pair[1], num_workers=self.num_workers)

    # Row 1 of edge_index indexes coords_src, row 0 indexes coords_dst.
    sq_dist = (coords_src[edge_index[1]] -
               coords_dst[edge_index[0]]).pow(2).sum(-1)
    edge_weight = torch.exp(-10. * sq_dist)  # 10 gives a better spread

    # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=(feats_src, None),
                         edge_weight=edge_weight,
                         size=(coords_src.size(0), coords_dst.size(0)))

    return self.lin_p(out), edge_index, edge_weight
def forward(
        self, x: Union[Tensor, PairTensor],
        batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:
    """Recompute the k-NN graph from the current features and propagate.

    Args:
        x: Either one feature tensor (used for both sides) or a pair of
            feature tensors. The first entry must be 2-dimensional.
        batch: Optional batch vector, or a pair of batch vectors matching
            the two entries of :obj:`x`.

    Returns:
        The result of ``self.propagate`` on the freshly built graph.
    """
    if isinstance(x, Tensor):
        x: PairTensor = (x, x)
    assert x[0].dim() == 2, \
        'Static graphs not supported in `DynamicEdgeConv`.'

    batch_pair: PairOptTensor = (None, None)
    if isinstance(batch, Tensor):
        batch_pair = (batch, batch)
    elif isinstance(batch, tuple):
        assert batch is not None
        batch_pair = (batch[0], batch[1])

    # For every row of x[1], look up its self.k nearest rows of x[0].
    neighbours = knn(x[0], x[1], self.k, batch_pair[0], batch_pair[1],
                     num_workers=self.num_workers)

    # propagate_type: (x: PairTensor)
    return self.propagate(neighbours, x=x, size=None)
def forward(
        self, x: Union[Tensor, PairTensor],
        batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:
    """Build a k-NN graph in the learned coordinate space and propagate.

    Args:
        x: Either one feature tensor (used for both sides) or a pair of
            feature tensors. The first entry must be 2-dimensional.
        batch: Optional batch vector, or a pair of batch vectors matching
            the two entries of :obj:`x`. When omitted, all points are
            treated as belonging to one region.

    Returns:
        ``self.lin_out`` applied to the propagated messages.

    Raises:
        ValueError: If the first entry of :obj:`x` is not 2-dimensional.
        RuntimeError: If any region holds fewer than ``self.k`` points.
    """
    is_bipartite: bool = True
    if isinstance(x, Tensor):
        x: PairTensor = (x, x)
        is_bipartite = False

    if x[0].dim() != 2:
        raise ValueError("Static graphs not supported in 'GravNetConv'")

    b: PairOptTensor = (None, None)
    if isinstance(batch, Tensor):
        b = (batch, batch)
    elif isinstance(batch, tuple):
        assert batch is not None
        b = (batch[0], batch[1])

    # embed the inputs before message passing
    msg_activations = self.lin_p(x[0])

    # transform to the space dimension to build the graph
    s_l: Tensor = self.lin_s(x[0])
    s_r: Tensor = self.lin_s(x[1]) if is_bipartite else s_l

    # Refuse to perform knn when a region has fewer than k candidate points.
    # The batch vector is optional: calling torch.unique on None would raise
    # a TypeError, so unbatched input is checked as a single region instead.
    batch_src = b[0]
    if batch_src is None:
        num_points = s_l.size(0)
        # Empty input has no region to check, matching the batched case
        # where an empty batch vector yields no counts.
        too_few_points = 0 < num_points < self.k
    else:
        region_sizes = torch.unique(batch_src, return_counts=True)[1]
        too_few_points = bool((region_sizes < self.k).any())
    if too_few_points:
        raise RuntimeError(
            f"Not enough elements in a region to perform the k-nearest neighbors. Current k-value={self.k}"
        )

    # flip swaps the two rows so that row 0 indexes s_l and row 1 indexes s_r
    edge_index = knn(s_l, s_r, self.k, b[0], b[1]).flip([0])

    edge_weight = (s_l[edge_index[0]] - s_r[edge_index[1]]).pow(2).sum(-1)
    edge_weight = torch.exp(-10.0 * edge_weight)  # 10 gives a better spread

    # message passing
    out = self.propagate(edge_index, x=(msg_activations, None),
                         edge_weight=edge_weight,
                         size=(s_l.size(0), s_r.size(0)))

    return self.lin_out(out)
def forward(
        self, x: Union[Tensor, PairTensor],
        batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:
    """Run the layer and also expose the dense graph it built.

    Args:
        x: Either one feature tensor (used for both sides) or a pair of
            feature tensors. The first entry must be 2-dimensional.
        batch: Optional batch vector, or a pair of batch vectors matching
            the two entries of :obj:`x`.

    Returns:
        The three values ``(self.lin_out(out), A, msg_activations)`` where
        ``A`` is the first dense adjacency matrix produced by
        ``to_dense_adj`` from CPU copies of the edges and their weights, and
        ``msg_activations`` is ``self.lin_p(x[0])``.
        NOTE(review): the signature is annotated ``-> Tensor`` although a
        3-tuple is returned.

    Raises:
        ValueError: If the first entry of :obj:`x` is not 2-dimensional.
    """
    is_bipartite: bool = True
    if isinstance(x, Tensor):
        x: PairTensor = (x, x)
        is_bipartite = False

    if x[0].dim() != 2:
        raise ValueError("Static graphs not supported in 'GravNetConv'")

    batch_pair: PairOptTensor = (None, None)
    if isinstance(batch, Tensor):
        batch_pair = (batch, batch)
    elif isinstance(batch, tuple):
        assert batch is not None
        batch_pair = (batch[0], batch[1])

    # Embedding of the inputs that is sent along the edges.
    embedded = self.lin_p(x[0])

    # Coordinates in the learned space used to build the graph.
    coords_src: Tensor = self.lin_s(x[0])
    coords_dst: Tensor = self.lin_s(x[1]) if is_bipartite else coords_src

    # flip swaps the rows: row 0 then indexes coords_src, row 1 coords_dst.
    edge_index = knn(coords_src, coords_dst, self.k, batch_pair[0],
                     batch_pair[1]).flip([0])

    sq_dist = (coords_src[edge_index[0]] -
               coords_dst[edge_index[1]]).pow(2).sum(-1)
    edge_weight = torch.exp(-10.0 * sq_dist)  # 10 gives a better spread

    # Dense adjacency matrix of the graph, returned for LRP purposes.
    adjacency = to_dense_adj(edge_index.to("cpu"),
                             edge_attr=edge_weight.to("cpu"))[0]

    # Message passing.
    out = self.propagate(edge_index, x=(embedded, None),
                         edge_weight=edge_weight,
                         size=(coords_src.size(0), coords_dst.size(0)))

    return self.lin_out(out), adjacency, embedded
def knn(x: Tensor, y: Tensor, k: int, batch_x: OptTensor = None,
        batch_y: OptTensor = None, cosine: bool = False,
        num_workers: int = 1) -> Tensor:
    r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
    :obj:`x`.

    This is a thin wrapper that forwards all arguments, in order, to
    :obj:`torch_cluster.knn`.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
        k (int): The number of neighbors.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node to a specific example. (default: :obj:`None`)
        cosine (boolean, optional): If :obj:`True`, will use the cosine
            distance instead of euclidean distance to find nearest neighbors.
            (default: :obj:`False`)
        num_workers (int): Number of workers to use for computation. Has no
            effect in case :obj:`batch_x` or :obj:`batch_y` is not
            :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)

    :rtype: :class:`LongTensor`

    The result is whatever :obj:`torch_cluster.knn` returns; presumably a
    two-row index tensor whose first row indexes :obj:`y` and whose second
    row indexes :obj:`x` (verify against the torch-cluster documentation).

    .. code-block:: python

        import torch
        from torch_geometric.nn import knn

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        assign_index = knn(x, y, 2, batch_x, batch_y)
    """
    return torch_cluster.knn(x, y, k, batch_x, batch_y, cosine, num_workers)
def test_radius(dtype, device):
    """Check the exact index tensor returned by a batched `knn` call.

    NOTE(review): despite its name this test exercises `knn`, not a radius
    search.
    """
    corners = [[-1, -1], [-1, +1], [+1, +1], [+1, -1]]
    x = tensor(corners + corners, dtype, device)
    y = tensor([[1, 0], [-1, 0]], dtype, device)

    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
    batch_y = tensor([0, 1], torch.long, device)

    # First row: index into y; second row: matched index into x.
    expected = [[0, 0, 1, 1], [2, 3, 4, 5]]
    assert knn(x, y, 2, batch_x, batch_y).tolist() == expected
def forward(
        self, x: Union[Tensor, PairTensor],
        batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:
    # type: (Tensor, OptTensor) -> Tensor  # noqa
    # type: (PairTensor, Optional[PairTensor]) -> Tensor  # noqa
    """Build a k-NN graph in the learned coordinate space and propagate.

    Args:
        x: Either one feature tensor (used for both sides) or a pair of
            feature tensors. The first entry must be 2-dimensional.
        batch: Optional batch vector, or a pair of batch vectors matching
            the two entries of :obj:`x`.

    Returns:
        ``self.lin_out1(x[1]) + self.lin_out2(out)`` where ``out`` is the
        propagated result.

    Raises:
        ValueError: If the first entry of :obj:`x` is not 2-dimensional.
    """
    is_bipartite: bool = True
    if isinstance(x, Tensor):
        x: PairTensor = (x, x)
        is_bipartite = False

    if x[0].dim() != 2:
        raise ValueError("Static graphs not supported in 'GravNetConv'")

    batch_pair: PairOptTensor = (None, None)
    if isinstance(batch, Tensor):
        batch_pair = (batch, batch)
    elif isinstance(batch, tuple):
        assert batch is not None
        batch_pair = (batch[0], batch[1])

    # Features that get propagated, and coordinates used for the k-NN search.
    feats_src: Tensor = self.lin_h(x[0])
    coords_src: Tensor = self.lin_s(x[0])
    coords_dst: Tensor = self.lin_s(x[1]) if is_bipartite else coords_src

    # flip swaps the rows: row 0 then indexes coords_src, row 1 coords_dst.
    edge_index = knn(coords_src, coords_dst, self.k, batch_pair[0],
                     batch_pair[1]).flip([0])

    sq_dist = (coords_src[edge_index[0]] -
               coords_dst[edge_index[1]]).pow(2).sum(-1)
    edge_weight = torch.exp(-10. * sq_dist)  # 10 gives a better spread

    # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=(feats_src, None),
                         edge_weight=edge_weight,
                         size=(coords_src.size(0), coords_dst.size(0)))

    return self.lin_out1(x[1]) + self.lin_out2(out)
def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
    r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
    :obj:`x`.

    This is a thin wrapper that forwards all arguments, in order, to
    :obj:`torch_cluster.knn`.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
        k (int): The number of neighbors.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node to a specific example. (default: :obj:`None`)
        cosine (boolean, optional): If :obj:`True`, will use the cosine
            distance instead of euclidean distance to find nearest neighbors.
            (default: :obj:`False`)

    :rtype: :class:`LongTensor`

    Raises:
        ImportError: If the module-level :obj:`torch_cluster` name is
            :obj:`None`.

    .. code-block:: python

        import torch
        from torch_geometric.nn import knn

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        assign_index = knn(x, y, 2, batch_x, batch_y)
    """
    # torch_cluster is None when the package is unavailable (presumably set
    # by a guarded import at module level - confirm).
    if torch_cluster is None:
        raise ImportError('`knn` requires `torch-cluster`.')

    return torch_cluster.knn(x, y, k, batch_x, batch_y, cosine)
def one_way_matching_distortion(pc1: torch.Tensor, pc2: torch.Tensor) -> torch.Tensor:
    """
    Calculates the one way matching distortion from pc1 -> pc2

    Every point of pc1 is matched to its nearest neighbour in pc2 (searching
    on the first three coordinates only); the squared error over all feature
    columns is summed per match and averaged over the matches.

    See `"Dynamic Polygon Clouds: Representation and Compression for VR/AR"
    <https://arxiv.org/abs/1610.00402>` paper.

    :param pc1: Tensor representing Point cloud 1
    :param pc2: Tensor representing Point cloud 2
    :return: Tensor shape [] containing the matching distortion value.
    """

    def _spatial(cloud):
        # Only the spatial coordinates take part in the neighbour search.
        return cloud[..., :3] if cloud.shape[-1] > 3 else cloud

    queries = _spatial(pc1).contiguous()
    references = _spatial(pc2).contiguous()

    # One nearest neighbour in pc2 for every point of pc1.
    idx_pc1, idx_pc2 = knn(references, queries, 1)

    # The error itself is taken over every column, not just xyz.
    matched_diff = pc2[idx_pc2, ...] - pc1[idx_pc1, ...]
    per_match_sse = (matched_diff ** 2).sum(dim=-1)
    return per_match_sse.mean()
def forward(
        self, x: Union[Tensor, PairTensor],
        batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:
    # type: (Tensor, OptTensor) -> Tensor  # noqa
    # type: (PairTensor, Optional[PairTensor]) -> Tensor  # noqa
    """Recompute the k-NN graph from the current features and propagate.

    Args:
        x: Either one feature tensor (used for both sides) or a pair of
            feature tensors. The first entry must be 2-dimensional.
        batch: Optional batch vector, or a pair of batch vectors matching
            the two entries of :obj:`x`.

    Returns:
        The result of ``self.propagate`` on the freshly built graph.

    Raises:
        ValueError: If the first entry of :obj:`x` is not 2-dimensional.
    """
    if isinstance(x, Tensor):
        x: PairTensor = (x, x)

    if x[0].dim() != 2:
        raise ValueError("Static graphs not supported in DynamicEdgeConv")

    batch_pair: PairOptTensor = (None, None)
    if isinstance(batch, Tensor):
        batch_pair = (batch, batch)
    elif isinstance(batch, tuple):
        assert batch is not None
        batch_pair = (batch[0], batch[1])

    # For every row of x[1], look up its self.k nearest rows of x[0];
    # flip then swaps the two index rows.
    neighbours = knn(x[0], x[1], self.k, batch_pair[0], batch_pair[1])
    edge_index = neighbours.flip([0])

    # propagate_type: (x: PairTensor)
    return self.propagate(edge_index, x=x, size=None)
def stroke_renderer(curve_points: T.Tensor, locations: T.Tensor, colors: T.Tensor, widths: T.Tensor,
                    H: int, W: int, K: int, canvas_color: float):
    """
    Renders the given brushstroke parameters onto a canvas.
    See Alg. 1 in https://arxiv.org/pdf/2103.17185.pdf.

    Args:
        curve_points (tensor): Points specifying the curves that will be rendered on the canvas, shape [N, S, 2].
        locations (tensor): Location of each curve, shape [N, 2]. Clamped to [0, W] (x) and [0, H] (y).
        colors (tensor): Color of each curve, shape [N, 3]. Clamped to [0, 1].
        widths (tensor): Width of each curve, shape [N, 1]. Exponentiated before use.
        H (int): Height of the canvas.
        W (int): Width of the canvas.
        K (int): Number of brushstrokes to consider for each pixel, see Sec. C.2 of the paper (Arxiv version).
        canvas_color (float): Background value of the canvas; an all-ones canvas is multiplied by it.

    Returns:
        (tensor): The rendered canvas, shape [H, W, 3].
    """
    colors = T.clamp(colors, 0., 1.)
    coord_x, coord_y = T.split(locations, [1, 1], dim=-1)
    coord_x = T.clamp(coord_x, 0, W)
    coord_y = T.clamp(coord_y, 0, H)
    locations = T.cat((coord_x, coord_y), dim=1)
    widths = T.exp(widths)

    device = curve_points.device
    # Unpacking also checks that curve_points has exactly three dimensions.
    _, S, _ = curve_points.shape

    # Coarse grid with one point per 5 pixels along each axis.
    t_H = T.linspace(0., float(H), int(H // 5)).to(device)
    t_W = T.linspace(0., float(W), int(W // 5)).to(device)
    P_y, P_x = T.meshgrid(t_H, t_W)
    P = T.stack([P_x, P_y], dim=-1)  # [H // 5, W // 5, 2]

    # Find nearest brushstrokes' indices for every coarse grid cell
    indices = knn(locations, P.view(-1, 2), k=K)[1]

    # Resize the KNN index tensor to full resolution
    indices = indices.view(len(t_H), len(t_W), -1)
    indices = indices.permute(2, 0, 1)
    indices = TF.resize(indices, size=(H, W), interpolation=TF.InterpolationMode.NEAREST)
    indices = indices.permute(1, 2, 0)  # [H, W, K]
    flat_indices = indices.flatten()

    # locations of points sampled from curves
    canvas_with_nearest_Bs = curve_points[flat_indices].view(H, W, K, S, 2)
    # colors of curves
    canvas_with_nearest_Bs_colors = colors[flat_indices].view(H, W, K, 3)
    # brush size
    canvas_with_nearest_Bs_bs = widths[flat_indices].view(H, W, K, 1)

    # Now create full-size canvas
    t_H = T.linspace(0., float(H), H).to(device)
    t_W = T.linspace(0., float(W), W).to(device)
    P_y, P_x = T.meshgrid(t_H, t_W)
    P_full = T.stack([P_x, P_y], dim=-1)  # [H, W, 2]

    # Compute distance from every pixel on canvas to each (among nearest ones) line segment between points from curves.
    # Consecutive curve points form the segments: points 0..S-2 are the starts, points 1..S-1 are the ends.
    canvas_with_nearest_Bs_a = canvas_with_nearest_Bs[:, :, :, :-1, :]  # start points of each line segment
    canvas_with_nearest_Bs_b = canvas_with_nearest_Bs[:, :, :, 1:, :]  # end points of each line segment
    canvas_with_nearest_Bs_b_a = canvas_with_nearest_Bs_b - canvas_with_nearest_Bs_a  # [H, W, K, S - 1, 2]
    P_full_canvas_with_nearest_Bs_a = P_full[:, :, None, None, :] - canvas_with_nearest_Bs_a  # [H, W, K, S - 1, 2]

    # find the projection of grid points on curves
    # first find the projections of a grid point on each line segment of a curve
    # numerator is the dot product between two vectors
    # the first vector is the line segments. the second vector is the sample points -> grid
    t = T.sum(canvas_with_nearest_Bs_b_a * P_full_canvas_with_nearest_Bs_a, dim=-1) / (
        T.sum(canvas_with_nearest_Bs_b_a ** 2, dim=-1) + 1e-8)
    # if t value is outside [0, 1], then the nearest point on the line does not lie on the segment, so clip values of t
    t = T.clamp(t, 0., 1.)

    # compute closest points on each line segment, which are the projections on each segment - [H, W, K, S - 1, 2]
    closest_points_on_each_line_segment = canvas_with_nearest_Bs_a + t[..., None] * canvas_with_nearest_Bs_b_a

    # compute the (squared) distance from every pixel to the closest point on each line segment - [H, W, K, S - 1]
    dist_to_closest_point_on_line_segment = T.sum(
        (P_full[..., None, None, :] - closest_points_on_each_line_segment) ** 2, dim=-1)

    # and distance to the nearest bezier curve.
    D_per_strokes = T.amin(dist_to_closest_point_on_line_segment, dim=-1)  # [H, W, K]
    D = T.amin(D_per_strokes, dim=-1)  # [H, W]

    # Finally render curves on a canvas to obtain image.
    # The large factor sharpens the softmax over the K candidate strokes.
    I_NNs_B_ranking = F.softmax(100000. * (1.0 / (1e-8 + D_per_strokes)), dim=-1)  # [H, W, K]
    I_colors = T.einsum('hwnf,hwn->hwf', canvas_with_nearest_Bs_colors, I_NNs_B_ranking)  # [H, W, 3]
    bs = T.einsum('hwnf,hwn->hwf', canvas_with_nearest_Bs_bs, I_NNs_B_ranking)  # [H, W, 1]
    bs_mask = T.sigmoid(bs - D[..., None])  # AOE of each brush stroke
    canvas = T.ones_like(I_colors) * canvas_color
    I = I_colors * bs_mask + (1 - bs_mask) * canvas
    return I  # HxWx3
def forward(self, pc1, pc2, pc1_batch, pc2_batch):
    """Predict an updated version of ``pc1`` from the pair ``(pc1, pc2)``.

    The first three columns of both clouds are centred on their joint
    centroid, neighbourhood graphs are built, the convolution stack
    ``conv1`` .. ``conv7`` is applied, and the result is added to ``pc1``
    before the joint centroid is restored on the first three columns of
    the output.

    NOTE: ``pc1`` and ``pc2`` are modified in place. The joint centroid is
    subtracted from their first three columns and is only added back to the
    returned tensor, not to the inputs.

    Args:
        pc1: First point cloud, one row per point.
        pc2: Second point cloud, one row per point.
        pc1_batch: Batch vector for the rows of ``pc1``.
        pc2_batch: Batch vector for the rows of ``pc2``.

    Returns:
        ``pc1 + x`` with the joint centroid added back to the first three
        columns, where ``x`` is the output of ``conv7``.
    """
    # Center both point clouds around origin
    pcs = torch.cat([pc1, pc2])[..., :3]
    centroids = torch.mean(pcs, dim=0)
    pc1[..., :3] -= centroids
    pc2[..., :3] -= centroids

    pc1_centroids = torch.mean(pc1, dim=0)
    # NOTE(review): this is computed from pc1, not pc2. It looks like a
    # copy-paste slip, but it is kept because changing it alters the
    # cross-cloud neighbourhoods - confirm the intent.
    pc2_centroids = torch.mean(pc1, dim=0)

    # Temporary shift, undone right after the graphs are built.
    pc1 -= pc1_centroids
    pc2 -= pc2_centroids

    edge_index1 = knn_graph(pc1, k=self.k, batch=pc1_batch)
    edge_index2 = knn(pc2, pc1, self.k, batch_x=pc2_batch, batch_y=pc1_batch)
    # TODO: Sort out batches

    pc1 += pc1_centroids
    pc2 += pc2_centroids

    net1 = self.conv1(pc1, edge_index1, pc2, edge_index2)
    # NOTE(review): same layer and same arguments as net1 - confirm whether
    # a different layer was intended.
    net2 = self.conv1(pc1, edge_index1, pc2, edge_index2)

    # Second graph, built in the feature space produced by conv1.
    edge_index = knn_graph(net1, k=self.k, batch=pc1_batch)
    net3 = self.conv2(net1, edge_index)

    nets = torch.cat([net1, net2, net3], dim=-1)
    out7 = self.conv3(nets)

    # Pool one feature vector per example, then broadcast it back to every
    # point of that example.
    global_features = global_max_pool(out7, pc1_batch)
    expand_x = global_features[pc1_batch]

    # Original author's note: dim = 1024 + 64 + 64 + 256 = 1408
    concat = torch.cat([expand_x, net1, net2, net3], dim=-1)

    x = self.conv4(concat)
    x = self.conv5(x)
    x = self.conv6(x)
    x = self.conv7(x)

    out = pc1 + x
    out[..., :3] += centroids
    return out