Example #1
def get_matches(feat_source, feat_target, sym=False):
    """Match each source feature to its nearest target feature; with
    sym=True, keep only mutual (cyclically consistent) matches."""
    matches = knn(feat_target, feat_source, k=1).T  # [n_source, 2]: (source idx, target idx)
    if sym:
        match_inv = knn(feat_source, feat_target, k=1).T
        mask = match_inv[matches[:, 1], 1] == torch.arange(matches.shape[0])
        return matches[mask]
    else:
        return matches
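A minimal usage sketch for the mutual-nearest-neighbor filter above (the feature tensors are made up for illustration, and knn is assumed to be torch_cluster.knn):

import torch
from torch_cluster import knn

feat_source = torch.randn(100, 32)  # hypothetical source descriptors
feat_target = torch.randn(120, 32)  # hypothetical target descriptors

# One match per source point: column 0 is the source index, column 1 the target index.
matches = get_matches(feat_source, feat_target)           # [100, 2]
# Keep only mutual nearest neighbors.
mutual = get_matches(feat_source, feat_target, sym=True)  # [<=100, 2]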
Example #2
    def embedding_correspondences(self, shape_x: Shape, shape_y: Shape, emb_x,
                                  emb_y):
        emb_x = emb_x.to(device_cpu)
        emb_y = emb_y.to(device_cpu)

        ass_x = knn(emb_y, emb_x, k=1).to(device)
        ass_y = knn(emb_x, emb_y, k=1).to(device)

        # Swap the rows of ass_y so both tensors are ordered
        # (index into emb_x, index into emb_y).
        ass_y = torch.index_select(ass_y, 0, my_long_tensor([1, 0]))

        ass_x = ass_x.transpose(0, 1)
        ass_y = ass_y.transpose(0, 1)

        self.ass_to_samples(ass_x, ass_y, shape_x, shape_y)
Example #3
def empirical_estimate(points, num_neighbors):
    ps, batch = points.permute(1, 0).cat_tensors()

    N = ps.shape[0]
    val = knn(ps.contiguous(),
              ps.contiguous(),
              batch_x=batch,
              batch_y=batch,
              k=num_neighbors)
    A = ps[val[1, :]].reshape(N, num_neighbors, 3)
    A = A - A.mean(dim=1, keepdim=True)
    Asqr = A.permute(0, 2, 1).bmm(A)
    sigma = torch.linalg.eigh(Asqr.cpu()).eigenvalues  # symeig was removed; eigh also returns ascending eigenvalues
    w = (sigma[:, 2] * sigma[:, 1]).sqrt()
    val = val[1, :].reshape(N, num_neighbors)
    w, _ = torch.median(w[val].to(ps.device), dim=1, keepdim=True)

    weights = TensorListList()
    bi = 0
    for point_list in points:
        ww = TensorList()
        for p in point_list:
            ww.append(w[batch == bi])
            bi = bi + 1

        weights.append(ww)

    return weights
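The TensorList/TensorListList bookkeeping above comes from the surrounding project; the core estimate is easier to see on a single (N, 3) cloud. A rough single-cloud sketch under that assumption, with torch.linalg.eigh standing in for the removed symeig (both return eigenvalues in ascending order):

import torch
from torch_cluster import knn

def empirical_estimate_single(ps, num_neighbors=16):
    # ps: (N, 3) point cloud; returns a per-point weight of shape (N, 1).
    N = ps.shape[0]
    val = knn(ps, ps, k=num_neighbors)             # [2, N * num_neighbors]
    A = ps[val[1]].reshape(N, num_neighbors, 3)    # k-NN neighborhoods
    A = A - A.mean(dim=1, keepdim=True)            # center each neighborhood
    C = A.permute(0, 2, 1).bmm(A)                  # 3x3 scatter matrices
    sigma = torch.linalg.eigh(C).eigenvalues       # ascending eigenvalues
    w = (sigma[:, 2] * sigma[:, 1]).sqrt()         # geometric mean of the two largest
    idx = val[1].reshape(N, num_neighbors)
    w, _ = torch.median(w[idx], dim=1, keepdim=True)  # median over each neighborhood
    return w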
Example #4
    def forward(self, data, idx):
        r, pos, batch = data.r, data.pos, data.batch
        # Build a 32-NN graph over the point positions.
        row, col = knn(pos, pos, 32, batch, batch)
        edge_index = torch.stack([col, row], dim=0).to(device)
        x1 = F.relu(self.conv1(pos, edge_index))
        x2 = F.relu(self.conv2(x1, edge_index))
        x = self.lin1(x2)
        x = gmp(x, batch)  # global max pooling per graph
        x = self.lin2(x)
        x = self.lin3(x)
        x = self.output(x)
        return F.log_softmax(x, dim=-1)
Example #5
def find_k_neighbor(rep_pts, pts, K, D):
    """
    Find K nearest neighbor points
    :param pts: represent points(B, P, C)
    :param rep_pts: original points (B, N, C)
    :param K:
    :param D: Dilation rate
    :return group_pts: K neighbor points(B, P, K, C)
    """
    device = pts.device
    B, N, C = pts.shape
    _, N_rep, _ = rep_pts.shape
    batch_pts = (torch.arange(0, B * N) // N).to(pts.device).contiguous()
    batch_rep_pts = (torch.arange(0, B * N_rep) // N_rep).to(
        pts.device).contiguous()
    pts = pts.view(-1, C).contiguous()
    rep_pts = rep_pts.view(-1, C).contiguous()
    knn_indices = knn(pts, rep_pts, K * D, batch_pts, batch_rep_pts)
    # Row 0 of the result is the query index ([0, 0, ..., 1, 1, ...]);
    # row 1 holds the neighbor indices we want.
    knn_indices = knn_indices[1]
    group_pts = pts[knn_indices].view(B, N_rep, K * D, C).contiguous()
    rand_col = torch.randint(K * D, (K, ))
    group_pts = group_pts[:, :, rand_col, :]

    return group_pts, knn_indices, rand_col
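A quick smoke test for the dilated grouping above (shapes are illustrative; knn is assumed to be torch_cluster.knn):

import torch

pts = torch.rand(2, 128, 3)                 # B=2 clouds of N=128 points
rep_pts = pts[:, :16, :].contiguous()       # P=16 query points; .contiguous() since the function calls view()
group_pts, knn_idx, rand_col = find_k_neighbor(rep_pts, pts, K=8, D=2)
print(group_pts.shape)                      # torch.Size([2, 16, 8, 3])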
Example #6
def fps_pooling(pos, x, edge_attr, batch=None, k=16, r=0.5, reduce='sum'):
    assert reduce in ['max', 'mean', 'add', 'sum']
    if batch is None:  # batch[idx] below requires an explicit batch vector
        batch = pos.new_zeros(pos.size(0), dtype=torch.long)
    idx = fps(pos, batch, ratio=r)
    i, j = knn(pos, pos[idx], k, batch, batch[idx])
    x = scatter(x[j], i, dim=0, reduce=reduce)
    pos, edge_attr, batch = pos[idx], edge_attr[idx], batch[idx]
    return x, pos, edge_attr, batch
Example #7
def knn_interpolate(x: torch.Tensor,
                    pos_x: torch.Tensor,
                    pos_y: torch.Tensor,
                    batch_x: OptTensor = None,
                    batch_y: OptTensor = None,
                    k: int = 3,
                    num_workers: int = 1):
    r"""The k-NN interpolation from the `"PointNet++: Deep Hierarchical
    Feature Learning on Point Sets in a Metric Space"
    <https://arxiv.org/abs/1706.02413>`_ paper.
    For each point :math:`y` with position :math:`\mathbf{p}(y)`, its
    interpolated features :math:`\mathbf{f}(y)` are given by

    .. math::
        \mathbf{f}(y) = \frac{\sum_{i=1}^k w(x_i) \mathbf{f}(x_i)}{\sum_{i=1}^k
        w(x_i)} \textrm{, where } w(x_i) = \frac{1}{d(\mathbf{p}(y),
        \mathbf{p}(x_i))^2}

    and :math:`\{ x_1, \ldots, x_k \}` denoting the :math:`k` nearest points
    to :math:`y`.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        pos_x (Tensor): Node position matrix
            :math:`\in \mathbb{R}^{N \times d}`.
        pos_y (Tensor): Upsampled node position matrix
            :math:`\in \mathbb{R}^{M \times d}`.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b_x} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each node from :math:`\mathbf{X}` to a specific example.
            (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each node from :math:`\mathbf{Y}` to a specific example.
            (default: :obj:`None`)
        k (int, optional): Number of neighbors. (default: :obj:`3`)
        num_workers (int): Number of workers to use for computation. Has no
            effect in case :obj:`batch_x` or :obj:`batch_y` is not
            :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
    """

    with torch.no_grad():
        assign_index = knn(pos_x,
                           pos_y,
                           k,
                           batch_x=batch_x,
                           batch_y=batch_y,
                           num_workers=num_workers)
        y_idx, x_idx = assign_index[0], assign_index[1]
        diff = pos_x[x_idx] - pos_y[y_idx]
        squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
        weights = 1.0 / torch.clamp(squared_distance, min=1e-16)

    y = scatter_add(x[x_idx] * weights, y_idx, dim=0, dim_size=pos_y.size(0))
    y = y / scatter_add(weights, y_idx, dim=0, dim_size=pos_y.size(0))

    return y
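A usage sketch for the interpolation above: features known on a coarse set pos_x are propagated to a denser set pos_y (random data, purely illustrative; the function's own torch_cluster/torch_scatter imports are assumed in scope):

import torch

pos_x = torch.rand(64, 3)     # coarse positions with known features
x = torch.rand(64, 16)        # features living on pos_x
pos_y = torch.rand(256, 3)    # dense positions to interpolate onto

y = knn_interpolate(x, pos_x, pos_y, k=3)
print(y.shape)                # torch.Size([256, 16])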
Example #8
def edge_loss(p):
    """
    p: BxNx3
    """
    # The batch vector must be a LongTensor on the same device as p.
    batch = torch.arange(p.size(0), device=p.device).repeat_interleave(p.size(1))
    # k=2 because each point's nearest neighbor is itself; take every second
    # entry of row 1 to get the true nearest neighbor.
    nearest_idxs = gnn.knn(p.view(-1, 3), p.view(-1, 3), 2, batch, batch)[1, 1::2]
    edge_len = torch.norm(p.view(-1, 3) - p.view(-1, 3)[nearest_idxs],
                          dim=-1).view(p.size()[:2])
    return edge_len.mean(dim=-1).mean()
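For example, on a random batch (assuming gnn is torch_geometric.nn):

import torch

p = torch.rand(4, 1024, 3)   # B=4 clouds of N=1024 points
loss = edge_loss(p)          # mean nearest-neighbor edge length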
Example #9
    def forward(self, in_layer_data, skip_layer_data):
        in_x, in_pos, in_batch = in_layer_data
        skip_x, skip_pos, skip_batch = skip_layer_data

        row, col = knn(in_pos, skip_pos, self.knn_num, in_batch, skip_batch)
        edge_index = torch.stack([col, row], dim=0)

        x1 = self.point_conv((in_x, skip_x), (in_pos, skip_pos), edge_index)
        pos1, batch1 = skip_pos, skip_batch

        return x1, pos1, batch1
Example #10
File: gconv2d.py  Project: nn4pde/SPINN
 def forward(self, x, y):
     target = torch.stack((x, y), dim=1)
     nodes = torch.vstack((self.points, self.f_points))
     a_index = knn(nodes, target, self.max_nbrs)
     h = self.widths()
     if self.use_pu:
         dnr = self.sph.forward(x, y, nodes, h, 1.0, a_index)
     else:
         dnr = 1.0
     nr = self.sph.forward(x, y, nodes, h, self.u, a_index)
     return nr / dnr
Example #11
 def knn_graph_3d(self, pos, x, batch, k, direction='source2target'):
     """Compute k-nearest neighbors in 3D Euclidean space.
 Targets indicate points and sources indicates their neighbors.
 Edges are directed either source-to-target or target-to-source.
 
 Returns:
   edge_index: [2, M] (i, j) = edge_index[:, e] indicated j is one of i's nearest neighbor.
 """
     target_idx, source_idx = knn(pos, pos, k, batch, batch)
     if direction == 'source2target':
         edge_index = torch.stack([source_idx, target_idx], dim=0)
     else:
         edge_index = torch.stack([target_idx, source_idx], dim=0)
     return edge_index
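The convention relied on here, and in most examples on this page, is that row 0 of the knn output indexes the query set y while row 1 indexes x. A two-point sanity check, assuming torch_cluster.knn:

import torch
from torch_cluster import knn

x = torch.tensor([[0.0], [10.0]])
y = torch.tensor([[0.1]])
print(knn(x, y, k=1))  # tensor([[0], [0]]): row 0 indexes y, row 1 indexes x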
Example #12
    def forward(self, x, pos, batch):
        pos6d = torch.cat([pos, x], dim=-1)
        ratio = (self.num_clusters + 2) / x.shape[0]
        idx = fps(pos6d, batch, ratio=ratio, random_start=False)
        idx = idx[:self.num_clusters]
        # Assign every point to its nearest cluster center (k=1): row 0 of
        # the knn output indexes the points, row 1 the cluster centers.
        edge_index = knn(pos6d[idx], pos6d, 1, batch[idx], batch)
        x = scatter_mean(x.index_select(index=edge_index[0], dim=0),
                         edge_index[1].unsqueeze(1),
                         dim=0)
        pos = scatter_mean(pos.index_select(index=edge_index[0], dim=0),
                           edge_index[1].unsqueeze(1),
                           dim=0)

        return x, pos, batch[idx], edge_index
Example #13
    def forward(self, support_xyz, batch_index, filtered_index, features):
        # knn
        query_xyz = torch.index_select(support_xyz, 0, filtered_index)
        query_batch_index = torch.index_select(batch_index, 0, filtered_index)
        query_features = torch.index_select(features, 0, filtered_index)

        row, col = knn(support_xyz, query_xyz, self.k, batch_index,
                       query_batch_index)
        edge_index = torch.stack([col, row], dim=0)

        # x has shape [N, F_in]
        # edge_index has shape [2, E]
        return self.propagate(edge_index,
                              x=(features, query_features),
                              pos=(support_xyz, query_xyz))  # shape [N, F_out]
Example #14
    def __call__(self, sample):
        keys = sorted([x for x in dir(sample) if 'edge_index' in x])

        num_vertices = [sample.num_nodes] + sample.num_vertices

        for level, key in enumerate(keys):
            if level == len(keys) - 1:
                break

            pos_key = key.replace('edge_index',
                                  'pos').replace('hierarchy_', '')

            subset_points_idx = fps(sample[pos_key],
                                    ratio=sample.num_vertices[level] /
                                    num_vertices[level])

            num_vertices[level + 1] = subset_points_idx.shape[0]
            sample.num_vertices[level] = num_vertices[level + 1]
            sample['pos_' +
                   str(level + 1)] = sample[pos_key][subset_points_idx]
            # Trace index: the nearest coarse vertex for every fine vertex.
            sample[f"hierarchy_trace_index_{level+1}"] = knn(
                sample['pos_' + str(level + 1)], sample[pos_key], 1)[1, :]

        keys = sorted([x for x in dir(sample) if x.startswith('x_')])
        for key in keys:
            delattr(sample, key)

        return sample
Example #15
def knn_interpolate(x, pos_x, pos_y, batch_x=None, batch_y=None, k=3):
    r"""The k-NN interpolation from the `"PointNet++: Deep Hierarchical
    Feature Learning on Point Sets in a Metric Space"
    <https://arxiv.org/abs/1706.02413>`_ paper.
    For each point :obj:`y`, its interpolated features are given by

    .. math::
        f(y) = \frac{\sum_{i=1}^k w_i(y) f_i}{\sum_{i=1}^k w_i(y)}
        \textrm{, where } w_i(y) = \frac{1}{d(y, x_i)^2}.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        pos_x (Tensor): Node position matrix
            :math:`\in \mathbb{R}^{N \times d}`.
        pos_y (Tensor): Upsampled node position matrix
            :math:`\in \mathbb{R}^{M \times d}`.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b_x} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each node from :math:`\mathbf{X}` to a specific example.
            (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each node from :math:`\mathbf{Y}` to a specific example.
            (default: :obj:`None`)
        k (int, optional): Number of neighbors. (default: :obj:`3`)

    :rtype: :class:`Tensor`
    """

    with torch.no_grad():
        y_idx, x_idx = knn(pos_x, pos_y, k, batch_x=batch_x, batch_y=batch_y)
        diff = pos_x[x_idx] - pos_y[y_idx]
        squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
        weights = 1.0 / torch.clamp(squared_distance, min=1e-16)

    y = scatter_add(x[x_idx] * weights, y_idx, dim=0, dim_size=pos_y.size(0))
    y = y / scatter_add(weights, y_idx, dim=0, dim_size=pos_y.size(0))

    return y
Example #16
File: utils.py  Project: felja633/RLLReg
def extract_correspondences_gpu(coords, Rs, ts, device=None):
    coords = Rs @ coords + ts

    M = len(coords)
    ind_nns = []
    dists = []
    if device is not None:
        coords = coords.to(device)
    for ind1 in range(M - 1):
        for ind2 in range(ind1 + 1, M):
            point1 = coords[ind1].permute(1, 0)
            point2 = coords[ind2].permute(1, 0)
            inds_nn = knn(point2, point1, 1)
            d = point1[inds_nn[0, :]] - point2[inds_nn[1, :]]
            d = (d * d).sum(dim=1).sqrt()
            dists.append(d)
            ind_nns.append(inds_nn)

    return dict(indices=TensorList(ind_nns), distances=TensorList(dists))
Example #17
    def build_bipartite_graph(self,
                              pos,
                              batch,
                              ratio,
                              method='radius',
                              r=0.1,
                              k=32,
                              dilation=1):
        '''
        Build a bipartite graph given a pos vector
        :param pos: (torch.Tensor) The position of the points
        :param batch: (torch.Tensor) The batch index for each point
        :param ratio: (float) The sample ratio
        :param method: (str) Specify which method to use, ['radius', 'knn']
        :param r: (float) If 'radius' is adopted, the radius of the support domain
        :param k: (int) If 'knn' is adopted, the number of neighbors
        :param dilation: (int) If 'knn' is adopted, the dilation rate
        :return: (torch.Tensor) The edge index, position and batch of the sampled points
        '''
        assert method in ['radius', 'knn']
        idx = fps(pos, batch, ratio=ratio)
        if method == 'radius':
            row, col = radius(pos,
                              pos[idx],
                              r,
                              batch,
                              batch[idx],
                              max_num_neighbors=k)
        if method == 'knn':
            row, col = knn(pos, pos[idx], k * dilation, batch, batch[idx])
            if dilation > 1:
                n = idx.shape[0]
                index = torch.randint(k * dilation, (n, k),
                                      dtype=torch.long,
                                      device=row.device)
                arange = torch.arange(n, dtype=torch.long, device=row.device)
                arange = arange * (k * dilation)
                index = (index + arange.view(-1, 1)).view(-1)
                row, col = row[index], col[index]

        edge_index = torch.stack([col, row], dim=0)
        return edge_index, pos[idx], batch[idx]
Example #18
    def init_upscale(self, num_knn=3):
        self.ass_array = []
        self.ass_vecs = []
        self.ass_weights = []
        for idx in range(self.scale_idx_len - 1):
            vert_i = self.vert_array[idx].to(device_cpu)
            vert_ip1 = self.vert_array[idx + 1].to(device_cpu)

            ass_curr = knn(vert_i, vert_ip1, num_knn)
            ass_curr = ass_curr[1, :].view(-1, num_knn)
            self.ass_array.append(ass_curr.to(device))  #[n_vert_tp1, num_knn]

            vec_curr = vert_ip1.unsqueeze(1) - vert_i[ass_curr, :]
            self.ass_vecs.append(
                vec_curr.to(device))  #[n_vert_tp1, num_knn, 3]

            weights_curr = 1 / (torch.norm(vec_curr, dim=2, keepdim=True) +
                                1e-5)
            weights_curr = weights_curr / torch.sum(
                weights_curr, dim=1, keepdim=True)
            self.ass_weights.append(
                weights_curr.to(device))  #[n_vert_tp1, num_knn, 1]
Example #19
def process_ply(
    ply_path: str,
    n_patch: int = 100,
    k: int = 2048,
    down_sample: Union[None, int, float] = None,
    cuda: bool = False,
):
    r"""
    Processes PLY file from path, convert to PC
    :return patch-pos + patch-color + patch-kernel
    """
    mesh_data = read_mesh(ply_path)
    print(colorama.Fore.GREEN + "Loaded PLY file %s" % ply_path)
    pos, color = mesh_data.pos, mesh_data.color
    if cuda:
        pos = pos.to(torch.device("cuda:7"))
        color = color.to(pos)
    # down sampling using uniform method
    n_pts = pos.shape[0]
    if down_sample is not None:
        if isinstance(down_sample, int):
            idx = torch.randperm(n_pts)[:down_sample]
            pos, color = pos[idx], color[idx]
        elif isinstance(down_sample, float):
            idx = torch.randperm(n_pts)[:int(down_sample * n_pts)]
            pos, color = pos[idx], color[idx]
    # select N kernels by FPS sample: [N, ]
    patch_kernel = fps(pos, ratio=n_patch / n_pts, random_start=True)
    # select N patches: [N, 2048]
    patches = knn(pos, pos[patch_kernel], k=k, num_workers=32)
    patches = patches[1].reshape(-1, k)
    data_list = [
        # ([N, 2048, 3], [N, 2048, 3], [N, 3])
        Data(color=color[patch], pos=pos[patch], patch_index=patch)
        for patch, pk in zip(patches, patch_kernel)
    ]
    return data_list
Example #20
    def precompute(self, query, support):
        """ Precomputes a data structure that can be used in the transform itself to speed things up
        """
        pos_x, pos_y = query.pos, support.pos
        if hasattr(support, "batch"):
            batch_y = support.batch
        else:
            batch_y = torch.zeros((support.num_nodes,), dtype=torch.long)
        if hasattr(query, "batch"):
            batch_x = query.batch
        else:
            batch_x = torch.zeros((query.num_nodes,), dtype=torch.long)

        with torch.no_grad():
            assign_index = knn(pos_x, pos_y, self.k, batch_x=batch_x, batch_y=batch_y)
            y_idx, x_idx = assign_index
            diff = pos_x[x_idx] - pos_y[y_idx]
            squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
            weights = 1.0 / torch.clamp(squared_distance, min=1e-16)
        normalisation = scatter_add(weights, y_idx, dim=0, dim_size=pos_y.size(0))

        return Data(num_nodes=support.num_nodes, x_idx=x_idx, y_idx=y_idx, weights=weights, normalisation=normalisation)
Example #21
 def find_neighbours(self, x, y, batch_x, batch_y):
     return knn(x, y, self.k, batch_x, batch_y)
Example #22
 def find_neighbours(self, x, y, batch_x, batch_y):
     tmp = knn(x, y, self.k, batch_x, batch_y)
     return tmp
Example #23
    def forward(self, data, final, start):
        x, edge_index, pos, batch = data.x, data.edge_index, data.pos, data.batch

        x_start = x

        # Only street based pooling
        if self.clustering == 'Street':
            batchClusters1 = self.clusters1
            batchCat = self.categories
            batchClusters2 = self.clusters2

            batch_size = torch.max(batch) + 1

            # Divide clusters and categories from different batches
            for i in range(1, batch_size):
                batchClusters1 = torch.cat(
                    (batchClusters1, self.clusters1 + i * self.maxCluster1))
                batchCat = torch.cat((batchCat, self.categories + i * 5))
                batchClusters2 = torch.cat(
                    (batchClusters2, self.clusters2 + i * self.maxCluster2))

            batchCat = batchCat.long()
            data.batch = batchCat

            data2 = data

            # Both pooled branches, max pooling
            data = max_pool(batchClusters1, data)
            x_t, edge_index_t, pos_t, batchCat_t = data.x, data.edge_index, data.pos, data.batch

            data2 = max_pool(batchClusters2, data2)
            x_t2, edge_index_t2, pos_t2, batchCat_t2 = data2.x, data2.edge_index, data2.pos, data2.batch

            edge_index_t, temp = add_self_loops(edge_index_t)
            edge_index_t2, temp = add_self_loops(edge_index_t2)

            # Add coordinates and categories to input
            if self.coords:
                cats = (batchCat % 5).float()
                catsT = (batchCat_t % 5).float()
                catsT2 = (batchCat_t2 % 5).float()

                normPos = pos / torch.max(pos)
                normPos_t = pos_t / torch.max(pos_t)
                normPos_t2 = pos_t2 / torch.max(pos_t2)
                normCat = (cats / 4).view(batchCat.size(0), 1)
                normCat_t = (catsT / 4).view(batchCat_t.size(0), 1)
                normCat_t2 = (catsT2 / 4).view(batchCat_t2.size(0), 1)

                x = torch.cat((x, normPos, normCat), 1)
                x_t = torch.cat((x_t, normPos_t, normCat_t), 1)
                x_t2 = torch.cat((x_t2, normPos_t2, normCat_t2), 1)

            # Perform convolution blocks in all 3 branches
            for i in range(self.layers):
                x_temp = x
                x = self.moduleList1[i](x, edge_index)
                if self.midSkip:
                    x = torch.cat((x, x_temp), 1)
                    if i == 0:
                        bn = self.bn
                    elif i == 1:
                        bn = self.bn2
                    else:
                        bn = self.bn3

                    x = F.relu(bn(self.skipList1[i](x, edge_index)))

            for i in range(self.layers):
                x_ttemp = x_t
                x_t = self.moduleList2[i](x_t, edge_index_t)
                if self.midSkip:
                    x_t = torch.cat((x_t, x_ttemp), 1)
                    if i == 0:
                        bn = self.bn
                    elif i == 1:
                        bn = self.bn2
                    else:
                        bn = self.bn3
                    x_t = F.relu(bn(self.skipList2[i](x_t, edge_index_t)))

            for i in range(self.layers):
                x_ttemp2 = x_t2
                x_t2 = self.moduleList3[i](x_t2, edge_index_t2)
                if self.midSkip:
                    x_t2 = torch.cat((x_t2, x_ttemp2), 1)
                    if i == 0:
                        bn = self.bn
                    elif i == 1:
                        bn = self.bn2
                    else:
                        bn = self.bn3
                    x_t2 = F.relu(bn(self.skipList3[i](x_t2, edge_index_t2)))

            # Calculate the knn weights of both pooled branches on the first
            # and the final pass (the sizes might differ between the two).
            if start or final:
                sorter = torch.argsort(batchCat)
                backsorter = torch.argsort(sorter)

                pos = pos[sorter]
                batchCat = batchCat[sorter]

                pairs = knn(pos_t,
                            pos,
                            self.knn,
                            batch_x=batchCat_t,
                            batch_y=batchCat)
                yIdx, xIdx = pairs
                diff = pos_t[xIdx] - pos[yIdx]
                squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
                weights = 1.0 / torch.clamp(squared_distance, min=1e-16)

                pairs2 = knn(pos_t2,
                             pos,
                             self.knn,
                             batch_x=batchCat_t2,
                             batch_y=batchCat)
                yIdx2, xIdx2 = pairs2
                diff2 = pos_t2[xIdx2] - pos[yIdx2]
                squared_distance2 = (diff2 * diff2).sum(dim=-1, keepdim=True)
                weights2 = 1.0 / torch.clamp(squared_distance2, min=1e-16)

                self.weights = weights
                self.xIdx = xIdx
                self.yIdx = yIdx

                self.weights2 = weights2
                self.xIdx2 = xIdx2
                self.yIdx2 = yIdx2
                self.backSorter = backsorter

            # Unpool pooled branches
            x_t = scatter_add(x_t[self.xIdx] * self.weights,
                              self.yIdx,
                              dim=0,
                              dim_size=pos.size(0))
            x_t = x_t / scatter_add(
                self.weights, self.yIdx, dim=0, dim_size=pos.size(0))

            x_t = x_t[self.backSorter]

            x_t2 = scatter_add(x_t2[self.xIdx2] * self.weights2,
                               self.yIdx2,
                               dim=0,
                               dim_size=pos.size(0))
            x_t2 = x_t2 / scatter_add(
                self.weights2, self.yIdx2, dim=0, dim_size=pos.size(0))

            x_t2 = x_t2[self.backSorter]

            # Input size of final convolution
            if self.skipconv:
                y = torch.cat((x, x_t, x_t2, x_start), 1)
            else:
                y = torch.cat((x, x_t, x_t2), 1)

            # Do final convolution
            y = self.conv_mix(y, edge_index)

            # Add dropout layer
            if self.p != 1:
                y = F.dropout(y, training=self.training, p=self.p)

        return y
Example #24
def fps_max_pooling(pos, x, batch=None, k=16, r=0.5):
    if batch is None:  # batch[idx] below requires an explicit batch vector
        batch = pos.new_zeros(pos.size(0), dtype=torch.long)
    idx = fps(pos, batch, ratio=r)
    i, j = knn(pos, pos[idx], k, batch, batch[idx])
    x = scatter(x[j], i, dim=0, reduce='max')
    pos, batch = pos[idx], batch[idx]
    return x, pos, batch
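A usage sketch on random data (fps and knn are assumed to come from torch_cluster, scatter from torch_scatter):

import torch
from torch_cluster import fps, knn
from torch_scatter import scatter

pos = torch.rand(1024, 3)
x = torch.rand(1024, 32)
batch = torch.zeros(1024, dtype=torch.long)

x_out, pos_out, batch_out = fps_max_pooling(pos, x, batch, k=16, r=0.5)
print(x_out.shape, pos_out.shape)  # roughly [512, 32] and [512, 3]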
Example #25
def losses(pred,
           target,
           target_norms,
           device,
           chamfer=True,
           edge=False,
           norm=False,
           t=2 / 32,
           metrics=True):
    # form vectors of batch indicators
    pred_batch = torch.ones(pred.size()[:2], dtype=torch.int64, device=device)
    pred_batch *= torch.arange(start=0,
                               end=pred.size(0),
                               dtype=torch.int64,
                               device=device).view(-1, 1)
    pred_batch = pred_batch.flatten()

    target_batch = torch.ones(target.size()[:2],
                              dtype=torch.int64,
                              device=device)
    target_batch *= torch.arange(start=0,
                                 end=target.size(0),
                                 dtype=torch.int64,
                                 device=device).view(-1, 1)
    target_batch = target_batch.flatten()

    chamfer_loss = edge_loss = norm_loss = 0.
    precision = recall = fscore = 0.

    # Chamfer Loss
    if chamfer or norm:
        target_pred_nearest = gnn.knn(x=pred.view(-1, 3),
                                      y=target.view(-1, 3),
                                      k=1,
                                      batch_x=pred_batch,
                                      batch_y=target_batch)
        target_pred_nearest = target_pred_nearest.view(2, target.size(0),
                                                       target.size(1), -1)

        pred_target_nearest = gnn.knn(x=target.view(-1, 3),
                                      y=pred.view(-1, 3),
                                      k=1,
                                      batch_x=target_batch,
                                      batch_y=pred_batch)
        pred_target_nearest = pred_target_nearest.view(2, pred.size(0),
                                                       pred.size(1), -1)

        if chamfer:
            pred_target_dist = torch.norm(
                pred.view(-1, 3)[pred_target_nearest[0].view(-1)] -
                target.view(-1, 3)[pred_target_nearest[1].view(-1)],
                dim=-1)
            pred_target_dist = pred_target_dist.view(pred.size(0),
                                                     pred.size(1))

            target_pred_dist = torch.norm(
                target.view(-1, 3)[target_pred_nearest[0].view(-1)] -
                pred.view(-1, 3)[target_pred_nearest[1].view(-1)],
                dim=-1)
            target_pred_dist = target_pred_dist.view(target.size(0),
                                                     target.size(1))

            target_pred_dist_mean = target_pred_dist.mean(dim=-1).mean()
            pred_target_dist_mean = pred_target_dist.mean(dim=-1).mean()

            chamfer_loss = pred_target_dist_mean + target_pred_dist_mean

    if edge or norm:
        pred_k_3 = gnn.knn_graph(x=pred.view(-1, 3), k=3, batch=pred_batch)
        # Edge Loss
        if edge:
            pred_edge_dist = torch.norm(pred.view(-1, 3)[pred_k_3[0, 0::3]] -
                                        pred.view(-1, 3)[pred_k_3[1, 0::3]],
                                        dim=-1)
            pred_edge_dist = pred_edge_dist.view(pred.size(0), pred.size(1))
            edge_loss = pred_edge_dist.mean(dim=-1).mean()

        # Norm Loss
        if norm:
            pred_vec1 = pred.view(-1, 3)[pred_k_3[1, 1::3]] - pred.view(
                -1, 3)[pred_k_3[1, 0::3]]
            pred_vec2 = pred.view(-1, 3)[pred_k_3[1, 2::3]] - pred.view(
                -1, 3)[pred_k_3[1, 0::3]]

            # compute normal vector
            pred_approx_norm = torch.cross(pred_vec1, pred_vec2)
            pred_approx_norm = pred_approx_norm / torch.norm(
                pred_approx_norm, dim=-1).view(-1, 1)

            # cosine dissimilarity of norms
            pred_target_norm_dissim = 1 - torch.mul(
                pred_approx_norm[pred_target_nearest[0].view(-1)],
                target_norms.view(
                    -1, 3)[pred_target_nearest[1].view(-1)]).sum(dim=-1).abs()
            pred_target_norm_dissim = pred_target_norm_dissim.view(
                pred.size(0), pred.size(1))

            target_pred_norm_dissim = 1 - torch.mul(
                target_norms.view(-1, 3)[target_pred_nearest[0].view(-1)],
                pred_approx_norm[target_pred_nearest[1].view(-1)]).sum(
                    dim=-1).abs()
            target_pred_norm_dissim = target_pred_norm_dissim.view(
                target.size(0), target.size(1))

            target_pred_norm_dissim_mean = target_pred_norm_dissim.mean(
                dim=-1).mean()
            pred_target_norm_dissim_mean = pred_target_norm_dissim.mean(
                dim=-1).mean()

            norm_loss = pred_target_norm_dissim_mean + target_pred_norm_dissim_mean

    if metrics and chamfer:
        # compute metrics (these require the Chamfer distances above)
        precision = (target_pred_dist <= t).type(
            torch.float).mean(dim=-1).mean()
        recall = (pred_target_dist <= t).type(torch.float).mean(dim=-1).mean()
        fscore = ((2 * precision * recall) / (precision + recall)).mean()

        # fix nan. values
        precision[precision != precision] = 0.
        recall[recall != recall] = 0.
        fscore[fscore != fscore] = 0.

    losses_all = {'cd': chamfer_loss, 'el': edge_loss, 'nl': norm_loss}
    metrics_all = {'precision': precision, 'recall': recall, 'fscore': fscore}

    return losses_all, metrics_all
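The batched bookkeeping above obscures the core idea; stripped down to two single clouds, a symmetric Chamfer distance built on the same knn call might look like this (a sketch, not the project's API):

import torch
from torch_geometric.nn import knn

def chamfer_distance(a, b):
    # a: (N, 3), b: (M, 3); mean bidirectional nearest-neighbor distance.
    ab = knn(b, a, k=1)   # for each point of a, its nearest point in b
    ba = knn(a, b, k=1)   # for each point of b, its nearest point in a
    d_ab = (a[ab[0]] - b[ab[1]]).norm(dim=-1).mean()
    d_ba = (b[ba[0]] - a[ba[1]]).norm(dim=-1).mean()
    return d_ab + d_ba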