Example #1
    def __init__(self, in_channels, out_channels, dim, kernel_size,
                 is_open_spline=True, degree=1, aggr='mean', root_weight=True,
                 bias=True, **kwargs):
        super(SplineConv, self).__init__(aggr=aggr, **kwargs)

        if spline_basis is None:
            raise ImportError('`SplineConv` requires `torch-spline-conv`.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dim = dim
        self.degree = degree

        kernel_size = torch.tensor(repeat(kernel_size, dim), dtype=torch.long)
        self.register_buffer('kernel_size', kernel_size)

        is_open_spline = repeat(is_open_spline, dim)
        is_open_spline = torch.tensor(is_open_spline, dtype=torch.uint8)
        self.register_buffer('is_open_spline', is_open_spline)

        K = kernel_size.prod().item()
        self.weight = Parameter(torch.Tensor(K, in_channels, out_channels))

        if root_weight:
            self.root = Parameter(torch.Tensor(in_channels, out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
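
A minimal usage sketch for the constructor above, assuming torch_geometric.nn.SplineConv with torch-spline-conv installed. With dim=2 and a scalar kernel_size=5, repeat() broadcasts the size to both dimensions, so the weight holds K = 5 * 5 = 25 kernel matrices:

from torch_geometric.nn import SplineConv

conv = SplineConv(in_channels=16, out_channels=32, dim=2, kernel_size=5)
# kernel_size is repeated to [5, 5], so K = 25
assert conv.weight.size() == (25, 16, 32)  # (K, in_channels, out_channels)
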
Example #2
def voxel_grid(
    pos: Tensor,
    size: Union[float, List[float], Tensor],
    batch: Optional[Tensor] = None,
    start: Optional[Union[float, List[float], Tensor]] = None,
    end: Optional[Union[float, List[float], Tensor]] = None,
) -> Tensor:
    r"""Voxel grid pooling from the, *e.g.*, `Dynamic Edge-Conditioned Filters
    in Convolutional Networks on Graphs <https://arxiv.org/abs/1704.02901>`_
    paper, which overlays a regular grid of user-defined size over a point
    cloud and clusters all points within the same voxel.

    Args:
        pos (Tensor): Node position matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times D}`.
        size (float or [float] or Tensor): Size of a voxel (in each dimension).
        batch (LongTensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0,
            \ldots,B-1\}}^N`, which assigns each node to a specific example.
            (default: :obj:`None`)
        start (float or [float] or Tensor, optional): Start coordinates of the
            grid (in each dimension). If set to :obj:`None`, will be set to the
            minimum coordinates found in :attr:`pos`. (default: :obj:`None`)
        end (float or [float] or Tensor, optional): End coordinates of the grid
            (in each dimension). If set to :obj:`None`, will be set to the
            maximum coordinates found in :attr:`pos`. (default: :obj:`None`)

    :rtype: :class:`LongTensor`
    """

    if grid_cluster is None:
        raise ImportError('`voxel_grid` requires `torch-cluster`.')

    pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
    num_nodes, dim = pos.size()

    size = size.tolist() if torch.is_tensor(size) else size
    start = start.tolist() if torch.is_tensor(start) else start
    end = end.tolist() if torch.is_tensor(end) else end

    size, start, end = repeat(size, dim), repeat(start, dim), repeat(end, dim)

    if batch is None:
        batch = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device)

    pos = torch.cat([pos, batch.unsqueeze(-1).type_as(pos)], dim=-1)
    size = size + [1]
    start = None if start is None else start + [0]
    end = None if end is None else end + [batch.max().item()]

    size = torch.tensor(size, dtype=pos.dtype, device=pos.device)
    if start is not None:
        start = torch.tensor(start, dtype=pos.dtype, device=pos.device)
    if end is not None:
        end = torch.tensor(end, dtype=pos.dtype, device=pos.device)

    return grid_cluster(pos, size, start, end)
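
A minimal usage sketch, assuming torch-cluster is installed and voxel_grid is imported from torch_geometric.nn. A scalar size is broadcast to every dimension by repeat():

import torch
from torch_geometric.nn import voxel_grid

pos = torch.tensor([[0.1, 0.1], [0.2, 0.3], [2.1, 2.2]])
cluster = voxel_grid(pos, size=1.0)
# The first two points fall into the same unit voxel and share a cluster
# id; the third point receives a different one.
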
Example #3
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        dim: int,
        kernel_size: Union[int, List[int]],
        is_open_spline: bool = True,
        degree: int = 1,
        aggr: str = 'mean',
        root_weight: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(aggr=aggr, **kwargs)

        if spline_basis is None:
            raise ImportError("'SplineConv' requires 'torch-spline-conv'")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dim = dim
        self.degree = degree
        self.root_weight = root_weight

        kernel_size = torch.tensor(repeat(kernel_size, dim), dtype=torch.long)
        self.register_buffer('kernel_size', kernel_size)

        is_open_spline = repeat(is_open_spline, dim)
        is_open_spline = torch.tensor(is_open_spline, dtype=torch.uint8)
        self.register_buffer('is_open_spline', is_open_spline)

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.K = kernel_size.prod().item()

        if in_channels[0] > 0:
            self.weight = Parameter(
                torch.Tensor(self.K, in_channels[0], out_channels))
        else:
            self.weight = torch.nn.parameter.UninitializedParameter()
            self._hook = self.register_forward_pre_hook(
                self.initialize_parameters)

        if root_weight:
            self.lin = Linear(in_channels[1],
                              out_channels,
                              bias=False,
                              weight_initializer='uniform')

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
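
The in_channels[0] <= 0 branch above uses PyTorch's lazy-initialization pattern: the weight stays an UninitializedParameter until the first forward pass reveals the input feature dimension. A self-contained sketch of that pattern (class and names here are illustrative, not part of SplineConv):

import torch
from torch import nn

class LazyLinearSketch(nn.Module):
    def __init__(self, out_channels: int):
        super().__init__()
        self.out_channels = out_channels
        self.weight = nn.parameter.UninitializedParameter()
        self._hook = self.register_forward_pre_hook(self._init_params)

    @staticmethod
    def _init_params(module, inputs):
        # Materialize the weight once the input size is known, then
        # remove the hook so this only runs on the first forward pass.
        if isinstance(module.weight, nn.parameter.UninitializedParameter):
            in_channels = inputs[0].size(-1)
            module.weight.materialize((in_channels, module.out_channels))
            nn.init.xavier_uniform_(module.weight)
        module._hook.remove()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight

x = torch.randn(8, 16)
assert LazyLinearSketch(4)(x).size() == (8, 4)
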
Example #4
    def __init__(self, in_channels, hidden_channels, out_channels, depth,
                 pool_ratios=0.5, sum_res=True, act=F.relu):
        super(GraphUNet, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        self.sum_res = sum_res

        channels = hidden_channels

        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(conv(in_channels, channels, improved=True))
        for i in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
            self.down_convs.append(conv(channels, channels, improved=True))

        in_channels = channels if sum_res else 2 * channels

        self.up_convs = torch.nn.ModuleList()
        for i in range(depth - 1):
            self.up_convs.append(conv(in_channels, channels, improved=True))
        self.up_convs.append(conv(in_channels, out_channels, improved=True))

        self.reset_parameters()
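
A hypothetical usage sketch, assuming the class above is torch_geometric.nn.GraphUNet. pool_ratios is broadcast by repeat() to one ratio per pooling level, padding with the last value:

from torch_geometric.nn import GraphUNet

model = GraphUNet(in_channels=16, hidden_channels=32, out_channels=7,
                  depth=3, pool_ratios=[0.8, 0.5])
# repeat([0.8, 0.5], 3) -> [0.8, 0.5, 0.5]: one TopKPooling ratio per level.
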
Example #5
    def __init__(self,
                 data,
                 size,
                 num_hops,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 bipartite=True,
                 add_self_loops=False,
                 flow='source_to_target'):

        self.data = data
        self.size = repeat(size, num_hops)
        self.num_hops = num_hops
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.bipartite = bipartite
        self.add_self_loops = add_self_loops
        self.flow = flow

        assert flow in ['source_to_target', 'target_to_source']
        self.i, self.j = (0, 1) if flow == 'target_to_source' else (1, 0)

        self.edge_index_i, self.e_assoc = data.edge_index[self.i].sort()
        self.edge_index_j = data.edge_index[self.j, self.e_assoc]
        deg = degree(self.edge_index_i, data.num_nodes, dtype=torch.long)
        self.cumdeg = torch.cat([deg.new_zeros(1), deg.cumsum(0)])

        self.tmp = torch.empty(data.num_nodes, dtype=torch.long)
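
The sort/degree/cumsum triple above builds a CSR-style index: after sorting edges by the node on the gathered side, cumdeg[v]:cumdeg[v + 1] slices the neighbors of node v out of edge_index_j. A small sketch of that lookup on a toy graph (names as above):

import torch
from torch_geometric.utils import degree

edge_index = torch.tensor([[0, 2, 1],   # sources
                           [1, 0, 0]])  # targets
i, j = 1, 0  # flow='source_to_target' gathers sources per target
edge_index_i, e_assoc = edge_index[i].sort()
edge_index_j = edge_index[j, e_assoc]
deg = degree(edge_index_i, 3, dtype=torch.long)
cumdeg = torch.cat([deg.new_zeros(1), deg.cumsum(0)])
print(edge_index_j[cumdeg[0]:cumdeg[1]])  # sources of node 0: tensor([2, 1])
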
Example #6
    def __init__(self, data, size, num_hops, batch_size=1, shuffle=False,
                 drop_last=False, bipartite=True, add_self_loops=False,
                 flow='source_to_target'):

        self.data = data
        self.size = repeat(size, num_hops)
        self.num_hops = num_hops
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.bipartite = bipartite
        self.add_self_loops = add_self_loops
        self.flow = flow

        self.edge_index = data.edge_index
        self.e_id = torch.arange(self.edge_index.size(1))
        if bipartite and add_self_loops:
            tmp = segregate_self_loops(self.edge_index, self.e_id)
            self.edge_index, self.e_id, self.edge_index_loop = tmp[:3]
            self.e_id_loop = self.e_id.new_full((data.num_nodes, ), -1)
            self.e_id_loop[tmp[2][0]] = tmp[3]

        assert flow in ['source_to_target', 'target_to_source']
        self.i, self.j = (0, 1) if flow == 'target_to_source' else (1, 0)

        edge_index_i, self.e_assoc = self.edge_index[self.i].sort()
        self.edge_index_j = self.edge_index[self.j, self.e_assoc]
        deg = degree(edge_index_i, data.num_nodes, dtype=torch.long)
        self.cumdeg = torch.cat([deg.new_zeros(1), deg.cumsum(0)])

        self.tmp = torch.empty(data.num_nodes, dtype=torch.long)
Example #7
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        depth: int = 1,
        pool_ratios: float = 0.5,
        act=F.relu,
        variational: bool = True,
    ):
        super(VariationalGraphEncoder, self).__init__()
        assert depth >= 1
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        self.variational = variational

        self.down_convs = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.down_convs.append(GCNConv(in_channels, hidden_channels, improved=True))
        for i in range(depth):
            self.pools.append(TopKPooling(hidden_channels, self.pool_ratios[i]))
            self.down_convs.append(
                GCNConv(hidden_channels, hidden_channels, improved=True)
            )

        self.conv_mu = GCNConv(hidden_channels, out_channels)
        if self.variational:
            self.conv_logstd = GCNConv(hidden_channels, out_channels)

        self.reset_parameters()
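
When variational=True, the two heads conv_mu and conv_logstd are typically consumed downstream by a VGAE-style reparameterization step; a hedged sketch of that step (not part of the class above):

import torch

def reparameterize(mu: torch.Tensor, logstd: torch.Tensor) -> torch.Tensor:
    # z = mu + sigma * eps, with eps ~ N(0, I)
    return mu + torch.randn_like(logstd) * logstd.exp()
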
Example #8
    def __init__(self, in_channels, hidden_channels, out_channels, depth,
                 params={},
                 pool_ratios=0.5, sum_res=True, act=F.relu):
        super(HillGraphUNet, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.undirected_graphs = params['undirected_graphs']
        self.n_attention_heads = params['n_attention_heads']
        #self.attention_concat = params['attention_concat']
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        self.sum_res = sum_res
        self.use_batchnorm = params['use_batchnorm']
        channels = hidden_channels

        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_bn = torch.nn.ModuleList() # batch norm after each GCN block

        # Select between GCNConv, GATConv (graph attention), and
        # BidirectionalGraphConv.
        if params['bidirectional_graph_conv']:
            graph_conv = lambda in_c, out_c: BidirectionalGraphConv(
                in_c, out_c, n_attention_heads=self.n_attention_heads,
                concat=True)
        elif self.n_attention_heads == 0:
            graph_conv = lambda in_c, out_c: GCNBlock(in_c, channels, out_c)
        else:
            # GATConv concatenates all heads by default, so each of the
            # n_attention_heads heads outputs out_c // n_attention_heads
            # channels; concatenated, they yield out_c channels in total.
            # n_attention_heads must therefore divide out_c.
            graph_conv = lambda in_c, out_c: GATConv(
                in_c, out_c // self.n_attention_heads,
                heads=self.n_attention_heads)

        self.down_convs.append(graph_conv(in_channels, channels))
        self.down_bn.append(torch.nn.BatchNorm1d(channels))

        for i in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
            self.down_convs.append(graph_conv(channels, channels))
            self.down_bn.append(torch.nn.BatchNorm1d(channels))

        in_channels = channels if sum_res else 2 * channels

        self.up_convs = torch.nn.ModuleList()
        self.up_bn = torch.nn.ModuleList()
        for i in range(depth - 1):
            self.up_convs.append(graph_conv(in_channels, channels))
            self.up_bn.append(torch.nn.BatchNorm1d(channels))

        self.up_convs.append(graph_conv(in_channels, out_channels))
        self.up_bn.append(torch.nn.BatchNorm1d(out_channels))

        self.reset_parameters()
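
A small shape check for the head arithmetic in the comment above, assuming torch_geometric.nn.GATConv (which concatenates heads by default): n heads of out_c // n channels each yield out_c channels in total:

import torch
from torch_geometric.nn import GATConv

out_c, n_heads = 64, 4
conv = GATConv(16, out_c // n_heads, heads=n_heads)
x = torch.randn(10, 16)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
assert conv(x, edge_index).size() == (10, out_c)
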
Example #9
def voxel_grid(pos, batch, size, start=None, end=None):
    pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
    num_nodes, dim = pos.size()

    size, start, end = repeat(size, dim), repeat(start, dim), repeat(end, dim)

    pos = torch.cat([pos, batch.unsqueeze(-1).type_as(pos)], dim=-1)
    size = size + [1]
    start = None if start is None else start + [0]
    end = None if end is None else end + [batch.max().item()]

    size = torch.tensor(size, dtype=pos.dtype, device=pos.device)
    if start is not None:
        start = torch.tensor(start, dtype=pos.dtype, device=pos.device)
    if end is not None:
        end = torch.tensor(end, dtype=pos.dtype, device=pos.device)

    return grid_cluster(pos, size, start, end)
Example #10
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int,
                 dim: int,
                 kernel_size: Union[int, List[int]],
                 is_open_spline: bool = True,
                 degree: int = 1,
                 aggr: str = 'mean',
                 root_weight: bool = True,
                 bias: bool = True,
                 **kwargs):  # yapf: disable
        super(SplineConv, self).__init__(aggr=aggr, **kwargs)

        if spline_basis is None:
            raise ImportError('`SplineConv` requires `torch-spline-conv`.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dim = dim
        self.degree = degree

        kernel_size = torch.tensor(repeat(kernel_size, dim), dtype=torch.long)
        self.register_buffer('kernel_size', kernel_size)

        is_open_spline = repeat(is_open_spline, dim)
        is_open_spline = torch.tensor(is_open_spline, dtype=torch.uint8)
        self.register_buffer('is_open_spline', is_open_spline)

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        K = kernel_size.prod().item()
        self.weight = Parameter(torch.Tensor(K, in_channels[0], out_channels))

        if root_weight:
            self.root = Parameter(torch.Tensor(in_channels[1], out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
Example #11
    def __init__(self,
                 in_channels,
                 out_channels,
                 dim,
                 kernel_size,
                 is_open_spline=True,
                 degree=1,
                 norm=True,
                 root_weight=True,
                 bias=True):
        super(SplineConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.degree = degree
        self.norm = norm

        kernel_size = torch.tensor(repeat(kernel_size, dim), dtype=torch.long)
        self.register_buffer('kernel_size', kernel_size)

        is_open_spline = repeat(is_open_spline, dim)
        is_open_spline = torch.tensor(is_open_spline, dtype=torch.uint8)
        self.register_buffer('is_open_spline', is_open_spline)

        K = kernel_size.prod().item()
        self.weight = Parameter(torch.Tensor(K, in_channels, out_channels))

        if root_weight:
            self.root = Parameter(torch.Tensor(in_channels, out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
Example #12
    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 depth,
                 pool_ratios=0.5,
                 sum_res=True,
                 act=F.relu,
                 num_relations=4):
        super(GraphRUNet, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        self.sum_res = sum_res

        channels = hidden_channels

        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        # self.down_convs.append(
        #     FastRGCNConv(in_channels, channels, num_relations=num_relations))
        self.down_convs.append(GCNConv(in_channels, channels))
        for i in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels))

        in_channels = channels if sum_res else 2 * channels

        self.up_convs = torch.nn.ModuleList()
        for i in range(depth - 1):
            self.up_convs.append(GCNConv(in_channels, channels))
        self.up_convs.append(GCNConv(in_channels, out_channels))

        self.reset_parameters()
Example #13
def test_repeat():
    assert repeat(None, length=4) is None
    assert repeat(4, length=4) == [4, 4, 4, 4]
    assert repeat([2, 3, 4], length=4) == [2, 3, 4, 4]
    assert repeat([1, 2, 3, 4], length=4) == [1, 2, 3, 4]
    assert repeat([1, 2, 3, 4, 5], length=4) == [1, 2, 3, 4]
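
A minimal sketch of a repeat() consistent with these assertions (the real torch_geometric helper may differ in detail): None passes through, scalars are broadcast, short lists are padded with their last element, and long lists are truncated:

def repeat(src, length):
    if src is None:
        return None
    if not isinstance(src, (list, tuple)):  # scalars (int, float, bool)
        return [src] * length
    src = list(src)
    if len(src) > length:                   # truncate overly long inputs
        return src[:length]
    # pad short inputs by repeating the last element
    return src + [src[-1]] * (length - len(src))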