Exemplo n.º 1
0
def test_radius(dtype, device):
    """Check ``radius`` with and without batch vectors.

    ``x`` holds the four corners of the square [-1, 1]^2 twice (indices 0-3
    and 4-7) and ``y`` holds two query points. ``max_num_neighbors=4`` caps
    the hits returned per query.
    """
    corners = [[-1, -1], [-1, +1], [+1, +1], [+1, -1]]
    x = tensor(corners + corners, dtype, device)
    y = tensor([[0, 0], [0, 1]], dtype, device)

    batch_x = tensor([0] * 4 + [1] * 4, torch.long, device)
    batch_y = tensor([0, 1], torch.long, device)

    # No batch vectors: every point in x is a candidate for every query.
    unbatched = radius(x, y, 2, max_num_neighbors=4)
    expected_unbatched = {(0, 0), (0, 1), (0, 2), (0, 3),
                          (1, 1), (1, 2), (1, 5), (1, 6)}
    assert to_set(unbatched) == expected_unbatched

    # With batch vectors: query 0 is matched against batch 0 (x[0:4]) and
    # query 1 against batch 1 (x[4:8]).
    batched = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
    expected_batched = {(0, 0), (0, 1), (0, 2), (0, 3), (1, 5), (1, 6)}
    assert to_set(batched) == expected_batched
Exemplo n.º 2
0
def test():
    """Smoke-test ``Convolution`` on a radius graph between two point sets."""
    from torch_cluster import radius
    from e3nn.math import soft_one_hot_linspace

    max_radius = 2.0
    num_basis = 4

    conv = Convolution(
        irreps_node_input='0e + 1e',
        irreps_node_output='0e + 1e',
        irreps_node_attr_input='2x0e',
        irreps_node_attr_output='3x0e',
        irreps_edge_attr='0e + 1e',
        num_edge_scalar_attr=num_basis,
        radial_layers=1,
        radial_neurons=50,
        num_neighbors=3.0,
    )

    # Keep the order of the random draws unchanged (RNG state matters).
    num_in, num_out = 5, 2
    pos_in = torch.randn(num_in, 3)
    pos_out = torch.randn(num_out, 3)

    node_input = torch.randn(num_in, 4)
    node_attr_input = torch.randn(num_in, 2)
    node_attr_output = torch.randn(num_out, 3)

    # radius(x=pos_out, y=pos_in): row 0 indexes pos_in, row 1 indexes pos_out.
    edge_src, edge_dst = radius(pos_out, pos_in, r=max_radius)
    edge_vec = pos_in[edge_src] - pos_out[edge_dst]
    edge_length = edge_vec.norm(dim=1)

    edge_attr = o3.spherical_harmonics([0, 1], edge_vec, True)
    edge_scalar_attr = soft_one_hot_linspace(
        x=edge_length,
        start=0.0,
        end=max_radius,
        number=num_basis,
        basis='smooth_finite',
        cutoff=True,
    )

    conv(node_input, node_attr_input, node_attr_output,
         edge_src, edge_dst, edge_attr, edge_scalar_attr)
Exemplo n.º 3
0
def random_crop_3D(P, F, factor):
    """Crop a point cloud to a ball around a randomly drawn centroid.

    The centroid is drawn uniformly, on a 1/1000 grid, inside the
    axis-aligned bounding box of ``P``. A ball around it starts at radius 0.1
    and grows in steps of 0.2 until ``torch_cluster.radius`` reports
    ``round(len(P) * factor)`` points inside it. Those points and their
    features are returned.

    Args:
        P: Point coordinates indexed along axis 0 (``P.shape[0]`` points).
            It is passed to ``torch.from_numpy``, so presumably a float
            ndarray of shape (N, D) — TODO confirm.
        F: Per-point features, indexed along axis 0 like ``P``.
        factor: Fraction of points to keep. Values above 1 are clamped so
            that all points are kept.

    Returns:
        ``(P_crop, F_crop)``: the rows of ``P`` and ``F`` selected by the
        crop. Both are empty along axis 0 when the requested count rounds
        to zero.
    """
    npoints = P.shape[0]
    # Use a plain Python int: ``np.int`` no longer exists (NumPy >= 1.24), and
    # the value is forwarded as ``max_num_neighbors``. The clamp guarantees
    # the loop below can reach its target (it could never exceed npoints).
    n_points_after_crop = min(int(np.round(npoints * factor)), npoints)
    if n_points_after_crop <= 0:
        # Nothing to keep. Without this guard the loop never runs and
        # ``crop`` would be unbound at the return statement.
        return P[:0], F[:0]

    # Draw the centroid on an integer grid of 1/1000 units.
    points_max = (P.max(axis=0) * 1000).astype(np.int64)
    points_min = (P.min(axis=0) * 1000).astype(np.int64)
    # randint requires high > low; widen flat (degenerate) axes by one step.
    high = np.maximum(points_max, points_min + 1)
    centroid = np.random.randint(low=points_min, high=high, dtype=np.int64)
    centroid = centroid.astype(np.float32) / 1000

    x = torch.from_numpy(P)
    # Give the single query point the same dtype as P so x and y match.
    y = torch.from_numpy(centroid).to(x.dtype).unsqueeze(0)

    rad = 0.1
    inc = 0.2
    npoints_inside_sphere = 0
    while npoints_inside_sphere < n_points_after_crop:
        # Row 1 of the assignment indexes x (i.e. rows of P). At most
        # n_points_after_crop neighbours are returned for the one query.
        _, crop = torch_cluster.radius(x,
                                       y,
                                       rad,
                                       max_num_neighbors=n_points_after_crop)
        npoints_inside_sphere = len(crop)
        # Round to avoid floating-point drift from repeatedly adding 0.2.
        rad = np.round(rad + inc, 1)

    idx = crop.numpy()
    return P[idx], F[idx]
Exemplo n.º 4
0
def _variance_estimator_sparse(r, pos, f, batch_idx):
    """Sum squared differences of ``f`` over radius-``r`` neighbourhoods.

    For each point ``i`` in ``pos``, this returns ``sum_j (f[j] - f[i]) ** 2``
    over the neighbours ``j`` reported by
    ``radius(pos, pos, r, batch_x=batch_idx, batch_y=batch_idx)``. Points are
    only matched within the same batch entry. The sum is NOT divided by the
    neighbour count. The squared differences are computed under
    ``torch.no_grad()``.

    NOTE(review): ``radius`` may cap the neighbours per point (the
    torch_cluster wrapper defaults to 32) — confirm which ``radius`` is in
    scope here.
    """
    with torch.no_grad():
        # Row 0 indexes the query points, row 1 their neighbours.
        center_idx, neighbor_idx = radius(pos,
                                          pos,
                                          r,
                                          batch_x=batch_idx,
                                          batch_y=batch_idx)
        sq_diff = (f[neighbor_idx] - f[center_idx]) ** 2
    return scatter_add(sq_diff, center_idx, dim=0, dim_size=pos.size(0))
Exemplo n.º 5
0
def radius(x,
           y,
           r,
           batch_x=None,
           batch_y=None,
           max_num_neighbors=32,
           num_workers=1):
    r"""For every element of :obj:`y`, find the points of :obj:`x` that lie
    within distance :obj:`r`. This is a thin wrapper around
    :func:`torch_cluster.radius`.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
        r (float): The radius.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node to a specific example. (default: :obj:`None`)
        max_num_neighbors (int, optional): Upper bound on the number of
            neighbors returned per element of :obj:`y`. (default: :obj:`32`)
        num_workers (int): Number of workers used for the computation. Has
            no effect if :obj:`batch_x` or :obj:`batch_y` is not
            :obj:`None`, or if the input lies on the GPU. (default: :obj:`1`)

    Returns:
        A two-row index tensor: row 0 indexes :obj:`y` and row 1 indexes the
        matching points in :obj:`x`.

    Raises:
        ImportError: If :obj:`torch_cluster` is not available.

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_geometric.nn import radius

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        assign_index = radius(x, y, 1.5, batch_x, batch_y)
        # (y, x) pairs within 1.5: (0, 0), (0, 1), (1, 2), (1, 3)
    """
    if torch_cluster is not None:
        return torch_cluster.radius(x, y, r, batch_x, batch_y,
                                    max_num_neighbors, num_workers)
    raise ImportError('`radius` requires `torch-cluster`.')
Exemplo n.º 6
0
def test_radius(dtype, device):
    """Check batched ``radius`` against the expected coalesced index pairs.

    ``x`` is the square [-1, 1]^2 repeated in two batches (indices 0-3 and
    4-7). Query 0 belongs to batch 0 and query 1 to batch 1.
    """
    square = [[-1, -1], [-1, +1], [+1, +1], [+1, -1]]
    x = tensor(square * 2, dtype, device)
    y = tensor([[0, 0], [0, 1]], dtype, device)

    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
    batch_y = tensor([0, 1], torch.long, device)

    assign_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
    expected = [
        [0, 0, 0, 0, 1, 1],  # indices into y
        [0, 1, 2, 3, 5, 6],  # indices into x
    ]
    assert coalesce(assign_index).tolist() == expected