コード例 #1
0
def test_knn(dtype, device):
    """Exercise ``knn`` on two stacked unit squares and two query points.

    Covers the unbatched call, the batched call, the cosine variant (only
    when the tensors live on CUDA) and batch vectors whose ids skip a value.
    """
    # Rows 0-3 and rows 4-7 of x hold the same four square corners.
    corners = [[-1, -1], [-1, +1], [+1, +1], [+1, -1]]
    x = tensor(corners + corners, dtype, device)
    y = tensor([[1, 0], [-1, 0]], dtype, device)

    # Pairs expected whenever the two halves of x are separate examples.
    per_example_pairs = {(0, 2), (0, 3), (1, 4), (1, 5)}

    # No batch vectors: every row of x is a candidate for both queries.
    assert to_set(knn(x, y, 2)) == {(0, 2), (0, 3), (1, 0), (1, 1)}

    batch_x = tensor([0] * 4 + [1] * 4, torch.long, device)
    batch_y = tensor([0, 1], torch.long, device)
    assert to_set(knn(x, y, 2, batch_x, batch_y)) == per_example_pairs

    if x.is_cuda:
        cosine_edges = knn(x, y, 2, batch_x, batch_y, cosine=True)
        assert to_set(cosine_edges) == per_example_pairs

    # Skipping a batch
    batch_x = tensor([0] * 4 + [2] * 4, torch.long, device)
    batch_y = tensor([0, 2], torch.long, device)
    assert to_set(knn(x, y, 2, batch_x, batch_y)) == per_example_pairs
コード例 #2
0
def test_knn(dtype, device):
    """Check the row/column output of batched ``knn`` on two unit squares."""
    # Rows 0-3 and rows 4-7 of x hold the same four square corners.
    corners = [[-1, -1], [-1, +1], [+1, +1], [+1, -1]]
    x = tensor(corners + corners, dtype, device)
    y = tensor([[1, 0], [-1, 0]], dtype, device)

    batch_x = tensor([0] * 4 + [1] * 4, torch.long, device)
    batch_y = tensor([0, 1], torch.long, device)

    row, col = knn(x, y, 2, batch_x, batch_y)
    # The order of the two neighbours of one query is not asserted, so sort
    # each pair before comparing.
    col, _ = col.view(-1, 2).sort(dim=-1)
    col = col.view(-1)

    assert row.tolist() == [0, 0, 1, 1]
    assert col.tolist() == [2, 3, 4, 5]

    if x.is_cuda:
        row, col = knn(x, y, 2, batch_x, batch_y, cosine=True)
        assert row.tolist() == [0, 0, 1, 1]
        # NOTE(review): the cosine expectation differs from the Euclidean one
        # above ([0, 1] instead of [2, 3]) -- confirm this is intended.
        assert col.tolist() == [0, 1, 4, 5]
コード例 #3
0
def point_to_plane(pc1: torch.Tensor, pc2: torch.Tensor,
                   pc1_normals) -> torch.Tensor:
    """Compute the point-to-plane distortion from ``pc1`` to ``pc2``.

    See `"Geometric Distortion Metrics for Point Cloud Compression" <https://www.merl.com/publications/docs/TR2017-113.pdf>` paper.

    Each matched pair contributes the squared projection of its xyz
    difference onto the corresponding ``pc1`` normal; the mean over all
    pairs is returned.

    :param pc1: Tensor [bn x d] representing Point cloud 1
    :param pc2: Tensor [bn x d] representing Point cloud 2
    :param pc1_normals: Tensor [bn x 3] representing the normalized surface normals for pc1 in x/y/z.
    :return: Tensor shape [] containing the point-to-plane distortion
    """
    # The neighbour search only looks at the first three (spatial) channels.
    queries = pc1[..., :3] if pc1.shape[-1] > 3 else pc1
    candidates = pc2[..., :3] if pc2.shape[-1] > 3 else pc2

    # One nearest pc2 candidate (k=1) per pc1 query.
    idx_pc1, idx_pc2 = knn(candidates.contiguous(), queries.contiguous(), 1)

    # Spatial offset between every matched pair.
    offsets = pc2[idx_pc2, :3] - pc1[idx_pc1, :3]

    # Batched (1x3) @ (3x1) product = dot product of offset and normal.
    # NOTE(review): normals are taken in storage order, not gathered with
    # idx_pc1 -- assumes knn yields one match per pc1 point, in order.
    projection = torch.bmm(offsets.view(-1, 1, 3),
                           pc1_normals.view(-1, 3, 1))

    return (projection**2).mean()
コード例 #4
0
    def forward(
            self, x: Union[Tensor, PairTensor],
            batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:
        """Connect points by kNN in the ``lin_s`` space and propagate.

        A single tensor ``x`` (or ``batch``) is used for both sides of the
        graph; a tuple is treated as a bipartite pair.

        Returns the tuple ``(self.lin_p(out), edge_index, edge_weight)``.
        NOTE(review): the ``-> Tensor`` annotation does not match this
        3-tuple return value.
        """
        is_bipartite: bool = True
        if isinstance(x, Tensor):
            is_bipartite = False
            x: PairTensor = (x, x)
        assert x[0].dim() == 2, 'Static graphs not supported in `GravNetConv`.'

        # Normalise ``batch`` into a (batch of x[0], batch of x[1]) pair.
        batch_pair: PairOptTensor = (None, None)
        if isinstance(batch, Tensor):
            batch_pair = (batch, batch)
        elif isinstance(batch, tuple):
            assert batch is not None
            batch_pair = (batch[0], batch[1])

        # Features that get propagated along the edges.
        feat_src: Tensor = self.lin_h(x[0])

        # Coordinates used to build the graph; the second projection is
        # only needed when the two sides differ.
        space_src: Tensor = self.lin_s(x[0])
        if is_bipartite:
            space_dst: Tensor = self.lin_s(x[1])
        else:
            space_dst: Tensor = space_src

        edge_index = knn(space_src, space_dst, self.k, batch_pair[0],
                         batch_pair[1], num_workers=self.num_workers)

        # Row 1 indexes space_src, row 0 indexes space_dst.
        delta = space_src[edge_index[1]] - space_dst[edge_index[0]]
        sq_dist = delta.pow(2).sum(-1)
        edge_weight = torch.exp(-10. * sq_dist)  # 10 gives a better spread

        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=(feat_src, None),
                             edge_weight=edge_weight,
                             size=(space_src.size(0), space_dst.size(0)))

        return self.lin_p(out), edge_index, edge_weight
    def forward(
            self,
            x: Union[Tensor, PairTensor],
            batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:
        """Build a kNN graph on the input features and propagate over it.

        A single tensor ``x`` (or ``batch``) is used for both sides of the
        graph; a tuple is treated as a pair. Returns the result of
        ``self.propagate`` on the freshly built ``edge_index``.
        """
        if isinstance(x, Tensor):
            x: PairTensor = (x, x)
        assert x[0].dim() == 2, (
            'Static graphs not supported in `DynamicEdgeConv`.')

        # Normalise ``batch`` into a (batch of x[0], batch of x[1]) pair.
        batch_pair: PairOptTensor = (None, None)
        if isinstance(batch, Tensor):
            batch_pair = (batch, batch)
        elif isinstance(batch, tuple):
            assert batch is not None
            batch_pair = (batch[0], batch[1])

        edge_index = knn(x[0], x[1], self.k, batch_pair[0], batch_pair[1],
                         num_workers=self.num_workers)

        # propagate_type: (x: PairTensor)
        return self.propagate(edge_index, x=x, size=None)
コード例 #6
0
    def forward(
            self,
            x: Union[Tensor, PairTensor],
            batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:
        """Embed the inputs, connect them by kNN and aggregate messages.

        Args:
            x: Either one feature tensor, used for both sides of the graph,
                or a pair of tensors (bipartite input).
            batch: Optional batch vector, or pair of batch vectors, passed to
                ``knn`` together with the projected coordinates.

        Returns:
            ``self.lin_out`` applied to the propagated messages.

        Raises:
            ValueError: If ``x[0]`` is not two-dimensional.
            RuntimeError: If a region holds fewer than ``self.k`` candidate
                points. Without a batch vector, all points form one region.
        """
        is_bipartite: bool = True
        if isinstance(x, Tensor):
            x: PairTensor = (x, x)
            is_bipartite = False

        if x[0].dim() != 2:
            raise ValueError("Static graphs not supported in 'GravNetConv'")

        b: PairOptTensor = (None, None)
        if isinstance(batch, Tensor):
            b = (batch, batch)
        elif isinstance(batch, tuple):
            assert batch is not None
            b = (batch[0], batch[1])

        # embed the inputs before message passing
        msg_activations = self.lin_p(x[0])

        # transform to the space dimension to build the graph
        s_l: Tensor = self.lin_s(x[0])
        s_r: Tensor = self.lin_s(x[1]) if is_bipartite else s_l

        # Fail with a clear message when a region cannot supply k neighbors.
        # The neighbors are drawn from s_l, so its batch vector is the one
        # that matters.
        batch_x = b[0]
        if batch_x is None:
            # No batch vector was given (the default): `torch.unique(None)`
            # would raise a TypeError, so compare the total point count
            # instead. An empty input has no region to check, which mirrors
            # the batched case.
            too_few = 0 < s_l.size(0) < self.k
        else:
            counts = torch.unique(batch_x, return_counts=True)[1]
            too_few = bool((counts < self.k).any())
        if too_few:
            raise RuntimeError(
                f"Not enough elements in a region to perform the k-nearest neighbors. Current k-value={self.k}"
            )

        # Flip the two rows so that row 0 indexes s_l and row 1 indexes s_r.
        edge_index = knn(s_l, s_r, self.k, b[0], b[1]).flip([0])

        edge_weight = (s_l[edge_index[0]] - s_r[edge_index[1]]).pow(2).sum(-1)
        edge_weight = torch.exp(-10.0 *
                                edge_weight)  # 10 gives a better spread

        # message passing
        out = self.propagate(edge_index,
                             x=(msg_activations, None),
                             edge_weight=edge_weight,
                             size=(s_l.size(0), s_r.size(0)))

        return self.lin_out(out)
コード例 #7
0
    def forward(
            self,
            x: Union[Tensor, PairTensor],
            batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:
        """Embed, connect by kNN, aggregate and expose the graph.

        A single tensor ``x`` (or ``batch``) is used for both sides of the
        graph; a tuple is treated as a bipartite pair.

        Returns the tuple ``(self.lin_out(out), A, msg_activations)`` where
        ``A`` is the dense weighted adjacency matrix built on the CPU.
        NOTE(review): the ``-> Tensor`` annotation does not match this
        3-tuple return value.
        """
        is_bipartite: bool = True
        if isinstance(x, Tensor):
            is_bipartite = False
            x: PairTensor = (x, x)

        if x[0].dim() != 2:
            raise ValueError("Static graphs not supported in 'GravNetConv'")

        # Normalise ``batch`` into a (batch of x[0], batch of x[1]) pair.
        batch_pair: PairOptTensor = (None, None)
        if isinstance(batch, Tensor):
            batch_pair = (batch, batch)
        elif isinstance(batch, tuple):
            assert batch is not None
            batch_pair = (batch[0], batch[1])

        # Features that are sent along the edges.
        msg_activations = self.lin_p(x[0])

        # Coordinates in which the neighbour graph is built.
        space_src: Tensor = self.lin_s(x[0])
        if is_bipartite:
            space_dst: Tensor = self.lin_s(x[1])
        else:
            space_dst: Tensor = space_src

        # After the flip, row 0 indexes space_src and row 1 space_dst.
        neighbours = knn(space_src, space_dst, self.k, batch_pair[0],
                         batch_pair[1])
        edge_index = neighbours.flip([0])

        delta = space_src[edge_index[0]] - space_dst[edge_index[1]]
        sq_dist = delta.pow(2).sum(-1)
        edge_weight = torch.exp(-10.0 * sq_dist)  # 10 gives a better spread

        # Dense adjacency matrix of the graph, kept for LRP purposes.
        dense_adj = to_dense_adj(edge_index.to("cpu"),
                                 edge_attr=edge_weight.to("cpu"))
        A = dense_adj[0]

        # message passing
        out = self.propagate(edge_index,
                             x=(msg_activations, None),
                             edge_weight=edge_weight,
                             size=(space_src.size(0), space_dst.size(0)))

        return self.lin_out(out), A, msg_activations
コード例 #8
0
def knn(x: Tensor,
        y: Tensor,
        k: int,
        batch_x: OptTensor = None,
        batch_y: OptTensor = None,
        cosine: bool = False,
        num_workers: int = 1) -> Tensor:
    r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
    :obj:`x`.

    This is a thin wrapper that forwards all arguments, in order, to
    :obj:`torch_cluster.knn` and returns its result unchanged.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
        k (int): The number of neighbors.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node of :obj:`x` to a specific example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node of :obj:`y` to a specific example. (default: :obj:`None`)
        cosine (boolean, optional): If :obj:`True`, will use the cosine
            distance instead of euclidean distance to find nearest neighbors.
            (default: :obj:`False`)
        num_workers (int): Number of workers to use for computation. Has no
            effect in case :obj:`batch_x` or :obj:`batch_y` is not
            :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)

    Returns:
        The index tensor produced by :obj:`torch_cluster.knn`. Presumably it
        has two rows, the first indexing :obj:`y` and the second indexing
        :obj:`x` -- confirm against the :obj:`torch_cluster` documentation.

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_geometric.nn import knn

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        assign_index = knn(x, y, 2, batch_x, batch_y)
    """
    return torch_cluster.knn(x, y, k, batch_x, batch_y, cosine, num_workers)
コード例 #9
0
def test_radius(dtype, device):
    """Check the exact index tensor returned by batched ``knn``.

    NOTE(review): despite its name, this test only calls ``knn``.
    """
    # Rows 0-3 and rows 4-7 of x hold the same four square corners.
    corners = [[-1, -1], [-1, +1], [+1, +1], [+1, -1]]
    x = tensor(corners + corners, dtype, device)
    y = tensor([[1, 0], [-1, 0]], dtype, device)

    batch_x = tensor([0] * 4 + [1] * 4, torch.long, device)
    batch_y = tensor([0, 1], torch.long, device)

    expected_rows = [0, 0, 1, 1]
    expected_cols = [2, 3, 4, 5]
    out = knn(x, y, 2, batch_x, batch_y)
    assert out.tolist() == [expected_rows, expected_cols]
コード例 #10
0
    def forward(
            self,
            x: Union[Tensor, PairTensor],
            batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:
        # type: (Tensor, OptTensor) -> Tensor  # noqa
        # type: (PairTensor, Optional[PairTensor]) -> Tensor  # noqa
        """Connect points by kNN in the ``lin_s`` space and aggregate.

        A single tensor ``x`` (or ``batch``) is used for both sides of the
        graph; a tuple is treated as a bipartite pair. Returns
        ``self.lin_out1(x[1]) + self.lin_out2(out)`` where ``out`` is the
        propagated message tensor.
        """
        is_bipartite: bool = True
        if isinstance(x, Tensor):
            is_bipartite = False
            x: PairTensor = (x, x)

        if x[0].dim() != 2:
            raise ValueError("Static graphs not supported in 'GravNetConv'")

        # Normalise ``batch`` into a (batch of x[0], batch of x[1]) pair.
        batch_pair: PairOptTensor = (None, None)
        if isinstance(batch, Tensor):
            batch_pair = (batch, batch)
        elif isinstance(batch, tuple):
            assert batch is not None
            batch_pair = (batch[0], batch[1])

        # Features that get propagated along the edges.
        feat_src: Tensor = self.lin_h(x[0])

        # Coordinates used to build the graph; the second projection is
        # only needed when the two sides differ.
        space_src: Tensor = self.lin_s(x[0])
        if is_bipartite:
            space_dst: Tensor = self.lin_s(x[1])
        else:
            space_dst: Tensor = space_src

        # After the flip, row 0 indexes space_src and row 1 space_dst.
        neighbours = knn(space_src, space_dst, self.k, batch_pair[0],
                         batch_pair[1])
        edge_index = neighbours.flip([0])

        delta = space_src[edge_index[0]] - space_dst[edge_index[1]]
        sq_dist = delta.pow(2).sum(-1)
        edge_weight = torch.exp(-10. * sq_dist)  # 10 gives a better spread

        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
        out = self.propagate(edge_index,
                             x=(feat_src, None),
                             edge_weight=edge_weight,
                             size=(space_src.size(0), space_dst.size(0)))

        return self.lin_out1(x[1]) + self.lin_out2(out)
コード例 #11
0
def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
    r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
    :obj:`x`.

    This is a thin wrapper that forwards all arguments, in order, to
    :obj:`torch_cluster.knn` and returns its result unchanged.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
        k (int): The number of neighbors.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node of :obj:`x` to a specific example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node of :obj:`y` to a specific example. (default: :obj:`None`)
        cosine (boolean, optional): If :obj:`True`, will use the cosine
            distance instead of euclidean distance to find nearest neighbors.
            (default: :obj:`False`)

    Raises:
        ImportError: If the module-level :obj:`torch_cluster` name is
            :obj:`None`.

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_geometric.nn import knn

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        assign_index = knn(x, y, 2, batch_x, batch_y)
    """
    # `torch_cluster` is a module-level name; presumably it is set to None
    # when the optional package could not be imported -- confirm at the
    # import site.
    if torch_cluster is None:
        raise ImportError('`knn` requires `torch-cluster`.')

    return torch_cluster.knn(x, y, k, batch_x, batch_y, cosine)
コード例 #12
0
def one_way_matching_distortion(pc1: torch.Tensor,
                                pc2: torch.Tensor) -> torch.Tensor:
    """Compute the one-way matching distortion from ``pc1`` to ``pc2``.

    See `"Dynamic Polygon Clouds: Representation and Compression for VR/AR" <https://arxiv.org/abs/1610.00402>` paper.

    Points are matched on their first three channels, but the error is
    taken over all channels of the matched rows.

    :param pc1: Tensor representing Point cloud 1
    :param pc2: Tensor representing Point cloud 2
    :return: Tensor shape [] containing the matching distortion value.
    """
    # The neighbour search only looks at the first three (spatial) channels.
    queries = pc1[..., :3] if pc1.shape[-1] > 3 else pc1
    candidates = pc2[..., :3] if pc2.shape[-1] > 3 else pc2

    # One nearest pc2 candidate (k=1) per pc1 query.
    idx_pc1, idx_pc2 = knn(candidates.contiguous(), queries.contiguous(), 1)

    # Squared error summed per matched pair, then averaged over the pairs.
    difference = pc2[idx_pc2, ...] - pc1[idx_pc1, ...]
    per_pair_error = (difference**2).sum(dim=-1)

    return per_pair_error.mean()
コード例 #13
0
    def forward(
            self,
            x: Union[Tensor, PairTensor],
            batch: Union[OptTensor, Optional[PairTensor]] = None) -> Tensor:
        # type: (Tensor, OptTensor) -> Tensor  # noqa
        # type: (PairTensor, Optional[PairTensor]) -> Tensor  # noqa
        """Build a kNN graph on the input features and propagate over it.

        A single tensor ``x`` (or ``batch``) is used for both sides of the
        graph; a tuple is treated as a pair. Returns the result of
        ``self.propagate`` on the freshly built, row-flipped ``edge_index``.
        """
        if isinstance(x, Tensor):
            x: PairTensor = (x, x)

        if x[0].dim() != 2:
            raise ValueError("Static graphs not supported in DynamicEdgeConv")

        # Normalise ``batch`` into a (batch of x[0], batch of x[1]) pair.
        batch_pair: PairOptTensor = (None, None)
        if isinstance(batch, Tensor):
            batch_pair = (batch, batch)
        elif isinstance(batch, tuple):
            assert batch is not None
            batch_pair = (batch[0], batch[1])

        # Swap the two rows of the knn result before propagating.
        neighbours = knn(x[0], x[1], self.k, batch_pair[0], batch_pair[1])
        edge_index = neighbours.flip([0])

        # propagate_type: (x: PairTensor)
        return self.propagate(edge_index, x=x, size=None)
コード例 #14
0
def stroke_renderer(curve_points: T.Tensor, locations: T.Tensor, colors: T.Tensor, widths: T.Tensor,
                    H: int, W: int, K: int, canvas_color: float):
    """
    Rasterise a set of brushstroke curves into an image.
    Follows Alg. 1 of https://arxiv.org/pdf/2103.17185.pdf.

    Args:
        curve_points (tensor): Sampled points of every curve, shape [N, S, 2].
        locations (tensor): Position of every curve, shape [N, 2]. Clamped to [0, W] x [0, H] before the search.
        colors (tensor): Color of every curve, shape [N, 3]. Clamped to [0, 1].
        widths (tensor): Width of every curve, shape [N, 1]. Exponentiated before use.
        H (int): Height of the canvas.
        W (int): Width of the canvas.
        K (int): Number of brushstrokes considered per pixel, see Sec. C.2 of the paper (Arxiv version).
        canvas_color (float): Scalar multiplied into an all-ones canvas to form the background.
    Returns:
        (tensor): The rendered canvas, shape [H, W, 3].
    """
    colors = T.clamp(colors, 0., 1.)
    widths = T.exp(widths)
    loc_x, loc_y = T.split(locations, [1, 1], dim=-1)
    locations = T.cat((T.clamp(loc_x, 0, W), T.clamp(loc_y, 0, H)), dim=1)

    device = curve_points.device
    _, S, _ = curve_points.shape

    # Coarse grid with one fifth of the canvas resolution along each axis.
    coarse_rows = T.linspace(0., float(H), int(H // 5)).to(device)
    coarse_cols = T.linspace(0., float(W), int(W // 5)).to(device)
    grid_y, grid_x = T.meshgrid(coarse_rows, coarse_cols)
    coarse_grid = T.stack([grid_x, grid_y], dim=-1)  # [H // 5, W // 5, 2]

    # Indices of the K nearest brushstrokes for every coarse grid cell.
    nearest = knn(locations, coarse_grid.view(-1, 2), k=K)[1]

    # Upsample the index map to full resolution with nearest-neighbour
    # interpolation (channels-first for the resize, channels-last after).
    nearest = nearest.view(len(coarse_rows), len(coarse_cols), -1).permute(2, 0, 1)
    nearest = TF.resize(nearest, size=(H, W), interpolation=TF.InterpolationMode.NEAREST)
    nearest = nearest.permute(1, 2, 0)
    flat_ids = nearest.flatten()

    # Per-pixel parameters of the K selected strokes.
    stroke_points = curve_points[flat_ids].view(H, W, K, S, 2)
    stroke_colors = colors[flat_ids].view(H, W, K, 3)
    stroke_widths = widths[flat_ids].view(H, W, K, 1)

    # Full-resolution pixel grid, broadcastable against [H, W, K, S - 1, 2].
    full_rows = T.linspace(0., float(H), H).to(device)
    full_cols = T.linspace(0., float(W), W).to(device)
    pix_y, pix_x = T.meshgrid(full_rows, full_cols)
    pixels = T.stack([pix_x, pix_y], dim=-1)[:, :, None, None, :]  # [H, W, 1, 1, 2]

    # Consecutive curve samples form the line segments of each stroke.
    seg_start = stroke_points[:, :, :, :-1, :]  # [H, W, K, S - 1, 2]
    seg_end = stroke_points[:, :, :, 1:, :]  # [H, W, K, S - 1, 2]
    seg_vec = seg_end - seg_start
    start_to_pixel = pixels - seg_start

    # Projection parameter of every pixel onto every segment: the dot
    # product of segment and start->pixel vector over the squared segment
    # length (the epsilon guards against zero-length segments).
    t = T.sum(seg_vec * start_to_pixel, dim=-1) / (
            T.sum(seg_vec ** 2, dim=-1) + 1e-8)

    # Outside [0, 1] the projection falls off the segment, so clip.
    t = T.clamp(t, 0., 1.)

    # Closest point on each segment - [H, W, K, S - 1, 2]
    closest = seg_start + t[..., None] * seg_vec

    # Squared distance from every pixel to that point - [H, W, K, S - 1]
    sq_dist = T.sum((pixels - closest) ** 2, dim=-1)

    # Minimum over segments, then over the K candidate strokes.
    D_per_strokes = T.amin(sq_dist, dim=-1)  # [H, W, K]
    D = T.amin(D_per_strokes, dim=-1)  # [H, W]

    # Sharp softmax over inverse distances weights the K candidates.
    ranking = F.softmax(100000. * (1.0 / (1e-8 + D_per_strokes)), dim=-1)  # [H, W, K]
    I_colors = T.einsum('hwnf,hwn->hwf', stroke_colors, ranking)  # [H, W, 3]
    bs = T.einsum('hwnf,hwn->hwf', stroke_widths, ranking)  # [H, W, 1]
    bs_mask = T.sigmoid(bs - D[..., None])  # AOE of each brush stroke
    canvas = T.ones_like(I_colors) * canvas_color
    return I_colors * bs_mask + (1 - bs_mask) * canvas  # HxWx3
コード例 #15
0
    def forward(self, pc1, pc2, pc1_batch, pc2_batch):
        """Predict an update for ``pc1`` from ``pc1`` and ``pc2``.

        Both clouds are centered jointly on their first three channels,
        neighbour graphs are built on per-cloud centered copies, and the
        convolution stack produces a residual that is added to ``pc1``.

        Side effect: ``pc1`` and ``pc2`` are modified in place. The joint
        centering of their first three channels is not undone on the inputs
        themselves, only on the returned tensor.

        Args:
            pc1: First point cloud, one row per point.
            pc2: Second point cloud, one row per point.
            pc1_batch: Batch vector for the rows of ``pc1``.
            pc2_batch: Batch vector for the rows of ``pc2``.

        Returns:
            ``pc1 + x`` with the joint centroid added back to the first
            three channels, where ``x`` is the output of ``self.conv7``.
        """
        # Center both point clouds around origin
        pcs = torch.cat([pc1, pc2])[..., :3]
        centroids = torch.mean(pcs, dim=0)
        pc1[..., :3] -= centroids
        pc2[..., :3] -= centroids

        # Center each cloud on its own mean for the neighbour search only.
        # The second mean is taken over pc2; it was previously computed from
        # pc1, which shifted pc2 by the wrong offset.
        pc1_centroids = torch.mean(pc1, dim=0)
        pc2_centroids = torch.mean(pc2, dim=0)
        pc1 -= pc1_centroids
        pc2 -= pc2_centroids
        edge_index1 = knn_graph(pc1, k=self.k, batch=pc1_batch)
        edge_index2 = knn(pc2,
                          pc1,
                          self.k,
                          batch_x=pc2_batch,
                          batch_y=pc1_batch)  # TODO: Sort out batches
        # Undo the per-cloud centering before the convolutions.
        pc1 += pc1_centroids
        pc2 += pc2_centroids

        net1 = self.conv1(pc1, edge_index1, pc2, edge_index2)
        # NOTE(review): net2 repeats the conv1 call with identical
        # arguments -- confirm whether a different layer was intended.
        net2 = self.conv1(pc1, edge_index1, pc2, edge_index2)

        # Rebuild the graph in the feature space of net1.
        edge_index = knn_graph(net1, k=self.k, batch=pc1_batch)
        net3 = self.conv2(net1, edge_index)

        nets = torch.cat([net1, net2, net3], dim=-1)
        out7 = self.conv3(nets)

        # Per-example global feature, broadcast back to every pc1 point.
        global_features = global_max_pool(out7, pc1_batch)
        expand_x = global_features[pc1_batch]

        concat = torch.cat([expand_x, net1, net2, net3],
                           dim=-1)  # dim = 1024 + 64 + 64 + 256 = 1408

        x = self.conv4(concat)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)

        # Residual update; restore the joint centroid on the output only.
        out = pc1 + x
        out[..., :3] += centroids

        return out