Ejemplo n.º 1
0
    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Concatenate ``x`` with its 1..K-step propagations and project.

        Args:
            x: Node feature matrix.
            edge_index: Edge-index ``Tensor`` or ``SparseTensor`` adjacency.
            edge_weight: Optional edge weights.

        Returns:
            ``self.lin`` applied to ``[x, p(x), p(p(x)), ...]`` (``self.K``
            propagation results) concatenated along the last dimension, where
            ``p`` is one call to ``self.propagate``.
        """
        # Optional GCN normalization; no self-loops are added here.
        if self.normalize and isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim),
                improved=False, add_self_loops=False, dtype=x.dtype)
        elif self.normalize and isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim),
                add_self_loops=False, dtype=x.dtype)

        hops = [x]
        for _ in range(self.K):
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            hops.append(self.propagate(edge_index, x=hops[-1],
                                       edge_weight=edge_weight, size=None))
        return self.lin(torch.cat(hops, dim=-1))
Ejemplo n.º 2
0
    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Sum per-hop linear projections of repeatedly propagated features.

        ``self.lins[0]`` is applied to the raw input and ``self.lins[k]`` to
        the features after ``k`` calls of ``self.propagate``; the projections
        are added together.

        Args:
            x: Node feature matrix.
            edge_index: Edge-index ``Tensor`` or ``SparseTensor`` adjacency.
            edge_weight: Optional edge weights.

        Returns:
            The summed projections.
        """
        if self.normalize:
            # GCN normalization without adding self-loops.
            if isinstance(edge_index, Tensor):
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    improved=False, add_self_loops=False, dtype=x.dtype)
            elif isinstance(edge_index, SparseTensor):
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    add_self_loops=False, dtype=x.dtype)

        out = self.lins[0](x)
        h = x
        for hop_lin in self.lins[1:]:
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            h = self.propagate(edge_index, x=h, edge_weight=edge_weight,
                               size=None)
            out += hop_lin.forward(h)
        return out
Ejemplo n.º 3
0
    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Run ``self.K`` steps of teleporting (APPNP-style) propagation.

        Every step computes ``x = (1 - alpha) * propagate(x) + alpha * h``,
        where ``h`` is the input features.

        With ``self.normalize`` the graph is GCN-normalized first. If
        ``self.cached`` is set, the normalized graph is stored for reuse.

        While training with ``self.dropout > 0``, the edge weights are dropped
        out. A fresh mask is drawn from the normalized weights at every step.

        Args:
            x: Node feature matrix.
            edge_index: Edge-index ``Tensor`` or ``SparseTensor`` adjacency.
            edge_weight: Optional edge weights (edge-index input only). Edge
                dropout on an edge-index ``Tensor`` requires weights to exist,
                e.g. produced by normalization.

        Returns:
            The propagated node features.
        """
        if self.normalize:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index,
                        edge_weight,
                        x.size(self.node_dim),
                        False,
                        self.add_self_loops,
                        self.flow,
                        dtype=x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    edge_index, edge_weight = cache[0], cache[1]

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index,
                        edge_weight,
                        x.size(self.node_dim),
                        False,
                        self.add_self_loops,
                        self.flow,
                        dtype=x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        h = x
        # Keep the (normalized) graph so that edge dropout is re-sampled from
        # these weights at each step. Dropping out the previous step's output
        # would compound the masks and the 1 / (1 - p) rescaling: after k
        # steps only ~(1 - p)^k of the edges would survive, scaled by
        # (1 - p)^-k.
        base_edge_index = edge_index
        base_edge_weight = edge_weight
        for k in range(self.K):
            if self.dropout > 0 and self.training:
                if isinstance(base_edge_index, Tensor):
                    assert base_edge_weight is not None
                    edge_weight = F.dropout(base_edge_weight, p=self.dropout)
                else:
                    value = base_edge_index.storage.value()
                    assert value is not None
                    value = F.dropout(value, p=self.dropout)
                    edge_index = base_edge_index.set_value(value,
                                                           layout='coo')

            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.propagate(edge_index,
                               x=x,
                               edge_weight=edge_weight,
                               size=None)
            x = x * (1 - self.alpha)
            x += self.alpha * h

        return x
Ejemplo n.º 4
0
    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        r"""Propagate a stack of per-layer node features over the graph.

        Args:
            x: The input node features of shape :obj:`[num_nodes, num_layers,
                channels]`.
            edge_index: Edge-index ``Tensor`` or ``SparseTensor`` adjacency.
            edge_weight: Optional edge weights.

        Returns:
            The result of ``self.propagate`` on the (optionally GCN-normalized
            and cached) graph.

        Raises:
            ValueError: If ``x`` is not three-dimensional.
        """
        if x.dim() != 3:
            raise ValueError('Feature shape must be [num_nodes, num_layers, '
                             'channels].')

        if self.normalize:
            if isinstance(edge_index, Tensor):
                stored = self._cached_edge_index
                if stored is not None:
                    edge_index, edge_weight = stored[0], stored[1]
                else:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim), False,
                        self.add_self_loops, self.flow, dtype=x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
            elif isinstance(edge_index, SparseTensor):
                stored = self._cached_adj_t
                if stored is not None:
                    edge_index = stored
                else:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim), False,
                        self.add_self_loops, self.flow, dtype=x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        return self.propagate(edge_index, x=x, edge_weight=edge_weight,
                              size=None)
Ejemplo n.º 5
0
    def forward(self,
                x: Tensor,
                x_0: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Propagate ``x`` and mix in the initial features ``x_0``.

        After one ``self.propagate`` step (on the optionally GCN-normalized,
        optionally cached graph), the output is:

        * without ``self.weight2``:
          ``m = (1 - alpha) * x + alpha * x_0`` and
          ``(1 - beta) * m + beta * (m @ weight1)``;
        * with ``self.weight2``: the two terms ``(1 - alpha) * x`` and
          ``alpha * x_0`` each get the same ``beta`` mix, using ``weight1``
          and ``weight2`` respectively, and are summed.
        """
        if self.normalize:
            if isinstance(edge_index, Tensor):
                stored = self._cached_edge_index
                if stored is not None:
                    edge_index, edge_weight = stored[0], stored[1]
                else:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim), False,
                        self.add_self_loops, dtype=x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
            elif isinstance(edge_index, SparseTensor):
                stored = self._cached_adj_t
                if stored is not None:
                    edge_index = stored
                else:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim), False,
                        self.add_self_loops, dtype=x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        x = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)

        if self.weight2 is not None:
            # Separate identity-mapping weights for the two terms.
            agg = (1 - self.alpha) * x
            agg = (1 - self.beta) * agg + self.beta * (agg @ self.weight1)
            init = self.alpha * x_0
            init = (1 - self.beta) * init + self.beta * (init @ self.weight2)
            return agg + init

        # One shared identity-mapping weight for the mixed features.
        mixed = (1 - self.alpha) * x + self.alpha * x_0
        return (1 - self.beta) * mixed + self.beta * (mixed @ self.weight1)
Ejemplo n.º 6
0
    def __call__(self, data):
        """GCN-normalize the graph stored on ``data`` and return ``data``.

        If ``data.edge_index`` is set, it is normalized. The weights used are
        ``data.edge_weight`` when that key exists, else ``data.edge_attr``.
        The result is written back to ``edge_index`` / ``edge_weight``.
        Otherwise ``data.adj_t`` is normalized in place of itself.
        """
        assert data.edge_index is not None or data.adj_t is not None

        if data.edge_index is None:
            data.adj_t = gcn_norm(data.adj_t)
            return data

        edge_attr = data.edge_attr
        weights = data.edge_weight if 'edge_weight' in data else edge_attr
        data.edge_index, data.edge_weight = gcn_norm(
            data.edge_index, weights, data.num_nodes)
        return data
Ejemplo n.º 7
0
    def forward(self, x, edge_index, edge_weight=None):
        """Generalized-PageRank style sum of propagated features.

        Returns ``temp[0] * x + sum_{k=1..K} temp[k] * x_k``, where ``x_k`` is
        ``x`` after ``k`` calls of ``self.propagate`` with the GCN
        normalization computed here. For a ``SparseTensor`` input the
        normalization lives in the adjacency itself and ``norm`` is ``None``.
        """
        if isinstance(edge_index, torch.Tensor):
            edge_index, norm = gcn_norm(
                edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(
                edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
            norm = None

        hidden = x * self.temp[0]
        for hop in range(1, self.K + 1):
            x = self.propagate(edge_index, x=x, norm=norm)
            hidden = hidden + self.temp[hop] * x
        return hidden
Ejemplo n.º 8
0
    def __call__(self, data):
        """GCN-normalize the graph stored on ``data`` and return ``data``.

        Self-loop handling follows ``self.add_self_loops``.

        If ``data`` has ``edge_index``, it is normalized. The weights used are
        ``data.edge_weight`` when present, else ``data.edge_attr``. The
        results are written to ``edge_index`` / ``edge_weight``. Otherwise
        ``data.adj_t`` is normalized.
        """
        assert 'edge_index' in data or 'adj_t' in data

        if 'edge_index' not in data:
            data.adj_t = gcn_norm(data.adj_t,
                                  add_self_loops=self.add_self_loops)
            return data

        edge_attr = data.edge_attr
        weights = data.edge_weight if 'edge_weight' in data else edge_attr
        data.edge_index, data.edge_weight = gcn_norm(
            data.edge_index, weights, data.num_nodes,
            add_self_loops=self.add_self_loops)
        return data
Ejemplo n.º 9
0
    def compute_energy(self, x, edge_index, device):
        """Track Dirichlet energy of the features through SGC propagation.

        The Laplacian used is ``I - A_norm``. ``A_norm`` is the dense form of
        the GCN-normalized graph, computed with ``improved=False``. Self-loop
        handling is left at ``gcn_norm``'s default, presumably adding them;
        confirm against the gcn_norm in use. The identity is float32 on
        ``device``.

        Args:
            x: Node feature matrix.
            edge_index: Edge-index ``Tensor``.
            device: Device on which the identity matrix is created.

        Returns:
            A list of ``self.num_layers + 1`` energies: one for the input
            features and one after each ``self.SGC.propagate`` step
            (``self.SGC.lin`` is applied first when ``self.lin_first``).
        """
        norm_index, norm_weight = gcn_norm(edge_index,
                                           None,
                                           x.size(0),
                                           False,
                                           dtype=x.dtype)
        dense_adj = torch.squeeze(
            to_dense_adj(norm_index, edge_attr=norm_weight), dim=0)
        num_nodes = x.size(0)
        laplacian = torch.eye(
            num_nodes, dtype=torch.float, device=device) - dense_adj

        # Energy of the input features.
        energies = [self.Dirichlet_energy(x, laplacian)]

        if self.lin_first:
            x = self.SGC.lin(x)

        for _ in range(self.num_layers):
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.SGC.propagate(norm_index,
                                   x=x,
                                   edge_weight=norm_weight,
                                   size=None)
            energies.append(self.Dirichlet_energy(x, laplacian))

        return energies
Ejemplo n.º 10
0
    def forward(self,
                x,
                edge_index,
                edge_weight: Optional[torch.Tensor] = None):
        """Propagate ``x`` for ``self.K`` steps, then apply ``self.lin``.

        With ``self.cached`` the propagated features are stored, together with
        the edge count, on the first call and reused on later calls.

        Raises:
            RuntimeError: If cached features exist but ``edge_index`` has a
                different number of edges than when they were cached.
        """
        cache = self._cache
        if cache is not None:
            if edge_index.size(1) != cache[0]:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}. Please disable '
                    'the caching behavior of this layer by removing the '
                    '`cached=True` argument in its constructor.'.format(
                        cache[0], edge_index.size(1)))
            return self.lin(cache[1])

        num_edges = edge_index.size(1)
        # NOTE(review): arguments are passed as (edge_index, num_nodes,
        # edge_weight), an older GCNConv.norm-style order. Confirm this
        # matches the signature of the gcn_norm imported in this file.
        edge_index, norm = gcn_norm(edge_index,
                                    x.size(self.node_dim),
                                    edge_weight,
                                    dtype=x.dtype)

        h = x
        for _ in range(self.K):
            h = self.propagate(edge_index, x=h, norm=norm)

        if self.cached:
            self._cache = (num_edges, h)

        return self.lin(h)
Ejemplo n.º 11
0
    def forward(self, x, edge_index, edge_weight=None, params=None):
        """GCN convolution that can run with externally supplied parameters.

        Args:
            x: Node feature matrix; it is multiplied by ``params["weight"]``.
            edge_index: Edge-index ``Tensor`` or ``SparseTensor`` adjacency.
            edge_weight: Optional edge weights.
            params: Optional mapping from parameter name to tensor, e.g.
                adapted weights in a meta-learning inner loop. Defaults to
                this module's own ``named_parameters()``.

        Returns:
            The propagated features. ``params["bias"]`` is added when that
            entry is present.
        """
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get("bias", None)

        if self.normalize:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index,
                        edge_weight,
                        x.size(self.node_dim),
                        self.improved,
                        self.add_self_loops,
                        dtype=x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    edge_index, edge_weight = cache[0], cache[1]

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index,
                        edge_weight,
                        x.size(self.node_dim),
                        self.improved,
                        self.add_self_loops,
                        dtype=x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        x = torch.matmul(x, params["weight"])

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index,
                             x=x,
                             edge_weight=edge_weight,
                             size=None)

        # Check the bias that is actually applied (taken from ``params``),
        # not ``self.bias``. Otherwise a ``params`` mapping without a "bias"
        # entry would reach ``out += None``.
        if bias is not None:
            out += bias

        return out
Ejemplo n.º 12
0
    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Run ``self.num_layers`` recursive propagation updates and average.

        The graph is always GCN-normalized here, without self-loops. The input
        gets a new axis at dim -3. Each layer then:

        1. applies a weight (``init_weight`` first, then ``weight[...]``);
        2. calls ``self.propagate``;
        3. adds the dropout-regularized input times ``root_weight[...]``;
        4. adds the bias, if any;
        5. applies ``self.act``, except after the last layer.

        Weight indices collapse to 0 when ``self.shared_weights`` is set. The
        result is averaged over dim -3.
        """
        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim),
                add_self_loops=False, dtype=x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim),
                add_self_loops=False, dtype=x.dtype)

        x = x.unsqueeze(-3)
        out = x
        last_layer = self.num_layers - 1
        for t in range(self.num_layers):
            if t == 0:
                out = out @ self.init_weight
            else:
                out = out @ self.weight[0 if self.shared_weights else t - 1]

            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            out = self.propagate(edge_index, x=out, edge_weight=edge_weight,
                                 size=None)

            layer_idx = 0 if self.shared_weights else t
            root = F.dropout(x, p=self.dropout, training=self.training)
            out += root @ self.root_weight[layer_idx]

            if self.bias is not None:
                out += self.bias[layer_idx]

            if t < last_layer:
                out = self.act(out)

        return out.mean(dim=-3)
Ejemplo n.º 13
0
    def forward(self, x, edge_index, edge_weight=None):
        """Generalized-PageRank style sum of propagated features.

        Returns ``temp[0] * x + sum_{k=1..K} temp[k] * x_k``, where ``x_k`` is
        ``x`` after ``k`` calls of ``self.propagate`` with the GCN
        normalization computed here.
        """
        edge_index, norm = gcn_norm(
            edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)

        hidden = x * self.temp[0]
        for hop in range(1, self.K + 1):
            x = self.propagate(edge_index, x=x, norm=norm)
            hidden = hidden + self.temp[hop] * x
        return hidden
Ejemplo n.º 14
0
    def forward(self, data, train_idx):
        """Label propagation from the training nodes over the normalized graph.

        Training labels are written into a zero matrix ``y``:

        * single-column labels are one-hot encoded;
        * with ``self.mult_bin``, each task gets a 2-column one-hot block;
        * otherwise labels are copied as-is.

        Propagation then repeats ``self.num_iters`` times:
        ``result <- alpha * A^hops result + (1 - alpha) * y``.

        Args:
            data: Object whose ``graph`` dict holds ``num_nodes`` and
                ``edge_index`` and which has a ``label`` tensor.
            train_idx: Indices of nodes whose labels are known.

        Returns:
            An ``[n, out_channels]`` score matrix. With ``mult_bin``, column
            ``t`` holds the second channel of task ``t``'s block.
        """
        n = data.graph['num_nodes']
        edge_index = data.graph['edge_index']

        if isinstance(edge_index, torch.Tensor):
            edge_index, edge_weight = gcn_norm(edge_index, None, n, False)
            row, col = edge_index
            # Build the transposed adjacency (differs for directed graphs).
            adj_t = SparseTensor(row=col, col=row, value=edge_weight,
                                 sparse_sizes=(n, n))
        elif isinstance(edge_index, SparseTensor):
            adj_t = gcn_norm(edge_index, None, n, False)

        device = adj_t.device()
        labels = data.label
        num_tasks = labels.shape[1]

        y = torch.zeros((n, self.out_channels)).to(device)
        if num_tasks == 1:
            one_hot = F.one_hot(labels[train_idx], self.out_channels)
            y[train_idx] = one_hot.squeeze(1).to(y)
        elif self.mult_bin:
            y = torch.zeros((n, 2 * self.out_channels)).to(device)
            for task in range(num_tasks):
                y[train_idx, 2 * task:2 * task + 2] = F.one_hot(
                    labels[train_idx, task], 2).to(y)
        else:
            y[train_idx] = labels[train_idx].to(y.dtype)

        result = y.clone()
        for _ in range(self.num_iters):
            for _ in range(self.hops):
                result = matmul(adj_t, result)
            result *= self.alpha
            result += (1 - self.alpha) * y

        if self.mult_bin:
            # Keep the second channel of each task's binary block.
            output = torch.zeros((n, self.out_channels)).to(result.device)
            for task in range(num_tasks):
                output[:, task] = result[:, 2 * task + 1]
            result = output

        return result
Ejemplo n.º 15
0
    def neighborhood_aggregation(self, x, adj_t):
        """Aggregate neighbour features over ``self.K`` propagation rounds.

        If ``self.aggregator == 'gcn'``, the adjacency is GCN-normalized, with
        self-loops per ``self.add_self_loops``. For any other aggregator,
        ``self.add_self_loops`` only fills the diagonal via
        ``adj_t.set_diag()``.
        """
        if self.aggregator != 'gcn':
            if self.add_self_loops:
                adj_t = adj_t.set_diag()
        else:
            adj_t = gcn_norm(adj_t,
                             num_nodes=x.size(self.node_dim),
                             add_self_loops=self.add_self_loops,
                             dtype=x.dtype)

        for _ in range(self.K):
            x = self.propagate(adj_t, x=x)
        return x
Ejemplo n.º 16
0
    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Simplified graph convolution: K propagation steps, then ``self.lin``.

        If an earlier call cached propagated features, they are reused
        (detached) and ``x`` / the graph are ignored. Otherwise the graph is
        GCN-normalized and ``x`` is propagated ``self.K`` times. The result is
        stored when ``self.cached`` is set.
        """
        cache = self._cached_x
        if cache is not None:
            return self.lin(cache.detach())

        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, self.flow, dtype=x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, self.flow, dtype=x.dtype)

        for _ in range(self.K):
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                               size=None)
            if self.cached:
                self._cached_x = x

        return self.lin(x)
Ejemplo n.º 17
0
    def forward(
        self,
        y: Tensor,
        edge_index: Adj,
        mask: Optional[Tensor] = None,
        edge_weight: OptTensor = None,
        post_step: Callable = lambda y: y.clamp_(0., 1.)
    ) -> Tensor:
        """Propagate (soft) labels over the graph.

        Args:
            y: Label tensor. A ``torch.long`` tensor with one entry per node
                is one-hot encoded to floats first.
            edge_index: Edge-index ``Tensor`` or ``SparseTensor`` adjacency.
                An unweighted input is GCN-normalized (no self-loops); already
                weighted graphs are used as given.
            mask: Optional node mask. Only ``y[mask]`` seeds propagation and
                the other rows start at zero.
            edge_weight: Optional edge weights (edge-index input only).
            post_step: Function applied after each layer. The default clamps
                the result in place to ``[0, 1]``.

        Returns:
            Labels after ``self.num_layers`` updates of
            ``out <- alpha * propagate(out) + (1 - alpha) * out_0``.
        """

        if y.dtype == torch.long and y.size(0) == y.numel():
            y = F.one_hot(y.view(-1)).to(torch.float)

        out = y
        if mask is not None:
            out = torch.zeros_like(y)
            out[mask] = y[mask]

        if isinstance(edge_index, SparseTensor) and not edge_index.has_value():
            edge_index = gcn_norm(edge_index, add_self_loops=False)
        elif isinstance(edge_index, Tensor) and edge_weight is None:
            edge_index, edge_weight = gcn_norm(edge_index,
                                               num_nodes=y.size(0),
                                               add_self_loops=False)

        res = (1 - self.alpha) * out
        for _ in range(self.num_layers):
            # The names here must match the keyword arguments passed to
            # propagate (``x=``): PyG's TorchScript template generation parses
            # this comment.
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            out = self.propagate(edge_index,
                                 x=out,
                                 edge_weight=edge_weight,
                                 size=None)
            out.mul_(self.alpha).add_(res)
            out = post_step(out)

        return out
Ejemplo n.º 18
0
    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Propagate ``x`` once, GCN-normalizing the graph if requested.

        Normalization (``self.normalize``) never adds self-loops. For a
        ``SparseTensor`` adjacency, ``None`` is passed to ``gcn_norm`` instead
        of ``edge_weight``.
        """
        if self.normalize:
            if isinstance(edge_index, Tensor):
                edge_index, edge_weight = gcn_norm(edge_index,
                                                   edge_weight,
                                                   x.size(self.node_dim),
                                                   add_self_loops=False,
                                                   dtype=x.dtype)
            elif isinstance(edge_index, SparseTensor):
                edge_index = gcn_norm(edge_index,
                                      None,
                                      x.size(self.node_dim),
                                      add_self_loops=False,
                                      dtype=x.dtype)

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        return self.propagate(edge_index, x=x, edge_weight=edge_weight,
                              size=None)
Ejemplo n.º 19
0
    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Apply K rounds of GCN-normalized propagation and a linear layer.

        ``self.lin`` runs before propagation when ``self.lin_first`` is set,
        otherwise after it. ``self.bn`` (if truthy) and dropout (if
        ``self.dropout > 0``) are applied to the propagated features before
        that final ``self.lin``.

        With ``self.cached``, propagated features are stored and reused on
        later calls, but only when ``self.lin_first`` is False. Under
        ``lin_first`` the propagation input is ``self.lin(x)``, which depends
        on learnable weights. Reusing it would freeze those weights'
        contribution and hand back a tensor whose autograd graph was already
        consumed by an earlier ``backward()``.

        Args:
            x: Node feature matrix.
            edge_index: Edge-index ``Tensor`` or ``SparseTensor`` adjacency.
            edge_weight: Optional edge weights.

        Returns:
            The transformed node features.
        """
        if self.lin_first:
            x = self.lin(x)

        use_cache = self.cached and not self.lin_first
        cache = self._cached_x if use_cache else None
        if cache is None:
            if isinstance(edge_index, Tensor):
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, dtype=x.dtype)
            elif isinstance(edge_index, SparseTensor):
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim), False,
                    self.add_self_loops, dtype=x.dtype)

            for _ in range(self.K):
                # propagate_type: (x: Tensor, edge_weight: OptTensor)
                x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                                   size=None)
            if use_cache:
                self._cached_x = x
        else:
            x = cache

        if self.bn:
            x = self.bn(x)
        if self.dropout > 0.:
            x = F.dropout(x, p=self.dropout, training=self.training)

        if not self.lin_first:
            x = self.lin(x)

        return x
Ejemplo n.º 20
0
    def neighborhood_aggregation(self, x, adj_t):
        """Propagate ``x`` ``self.K`` times over ``adj_t``, then transform.

        ``x`` is returned unchanged when ``self.K <= 0``. Otherwise:

        1. the adjacency is GCN-normalized without self-loops if
           ``self.normalize``;
        2. its diagonal is set via ``set_diag()`` if ``self.add_self_loops``;
        3. ``self.transform`` is applied to the propagated result.
        """
        if self.K > 0:
            if self.normalize:
                adj_t = gcn_norm(adj_t, add_self_loops=False)
            if self.add_self_loops:
                adj_t = adj_t.set_diag()
            for _ in range(self.K):
                x = self.propagate(adj_t, x=x)
            x = self.transform(x)
        return x
Ejemplo n.º 21
0
    def forward(self, data):
        """Project node features, then smooth them ``self.hops`` times.

        Features from ``data.graph['node_feat']`` go through ``self.lin``. They
        are then multiplied ``self.hops`` times by the GCN-normalized
        adjacency. For edge-index input the adjacency is built transposed:
        row/col are swapped.
        """
        graph = data.graph
        edge_index = graph['edge_index']
        x = self.lin(graph['node_feat'])
        n = graph['num_nodes']

        if isinstance(edge_index, torch.Tensor):
            edge_index, edge_weight = gcn_norm(edge_index, None, n, False,
                                               dtype=x.dtype)
            row, col = edge_index
            adj_t = SparseTensor(row=col, col=row, value=edge_weight,
                                 sparse_sizes=(n, n))
        elif isinstance(edge_index, SparseTensor):
            adj_t = gcn_norm(edge_index, None, n, False, dtype=x.dtype)

        for _ in range(self.hops):
            x = matmul(adj_t, x)

        return x
Ejemplo n.º 22
0
    def forward(self, data):
        """Run the initial-residual conv stack and classify nodes.

        Returns:
            A tuple ``(z, log_probs)``:

            * ``z``: node representations after the last conv layer (before
              the final dropout and ``self.lins[1]``);
            * ``log_probs``: ``log_softmax`` of ``self.lins[1]``'s output.
        """
        data = T.ToSparseTensor()(data)
        x, adj_t = data.x, data.adj_t
        # gcn_norm on a SparseTensor returns a single normalized SparseTensor.
        adj_t = gcn_norm(adj_t)
        x = F.dropout(x, self.dropout, training=self.training)
        x = x_0 = self.lins[0](x).relu()

        for conv in self.convs:
            x = F.dropout(x, self.dropout, training=self.training)
            # Pass the whole normalized adjacency: indexing it (adj_t[0])
            # would select a single row instead of the graph.
            x = conv(x=x, x_0=x_0, edge_index=adj_t)
            x = x.relu()
        z = x
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.lins[1](x)

        return z, x.log_softmax(dim=-1)
Ejemplo n.º 23
0
    def forward(self,
                x,
                edge_index,
                edge_weight: Optional[torch.Tensor] = None):
        """Teleporting propagation for ``self.K`` steps.

        Each step computes
        ``x <- (1 - alpha) * propagate(x) + alpha * x_input``.
        """
        # NOTE(review): arguments are passed as (edge_index, num_nodes,
        # edge_weight), an older GCNConv.norm-style order. Confirm this
        # matches the signature of the gcn_norm imported in this file.
        edge_index, norm = gcn_norm(edge_index,
                                    x.size(self.node_dim),
                                    edge_weight,
                                    dtype=x.dtype)

        x_input = x
        for _ in range(self.K):
            x = self.propagate(edge_index, x=x, norm=norm)
            x = x * (1 - self.alpha) + self.alpha * x_input

        return x
Ejemplo n.º 24
0
    def forward(self, W: torch.FloatTensor, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """GCN convolution with an externally supplied weight matrix ``W``.

        Args:
            W: Weight matrix; features are transformed as ``x @ W``.
            x: Node feature matrix.
            edge_index: Edge indices. The normalization below unpacks an
                ``(edge_index, edge_weight)`` pair, so an edge-index
                ``Tensor`` is assumed (TODO confirm a ``SparseTensor`` is
                never passed).
            edge_weight: Optional edge weights.

        Returns:
            Result of ``self.propagate`` over the (optionally normalized and
            cached) graph.
        """
        if self.normalize:
            cache = self._cached_edge_index
            if cache is None:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, dtype=x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
            else:
                # Reuse the normalized graph. Without this branch a populated
                # cache made the layer propagate over the raw, unnormalized
                # edges.
                edge_index, edge_weight = cache[0], cache[1]

        x = torch.matmul(x, W)

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                             size=None)

        return out
Ejemplo n.º 25
0
 def forward(self, x, adj_t, edge_types):
     """Calculates embeddings.

     The model runs three relational convolutions over the GCN-normalized
     adjacency (no self-loops added). After each one, rows along dim 0 are
     split into four contiguous node-type blocks, in this order:

     * genes (the remaining rows);
     * ``num_dis_nodes`` rows;
     * ``num_comp_nodes`` rows;
     * pathway rows.

     Each block goes through its own normalization module, and the blocks
     are concatenated back. ReLU follows the first two convolutions, and
     ``self.drop`` is applied to the final output.
     """
     # Genes are whatever rows remain after the other node types.
     num_genes = (x.shape[-2] - self.num_dis_nodes - self.num_comp_nodes -
                  self.num_pathways)
     dis_end = num_genes + self.num_dis_nodes
     comp_end = dis_end + self.num_comp_nodes

     def norm_per_type(h, norm_gene, norm_dis, norm_comp, norm_path):
         # Normalize each node-type block separately, then re-stack them.
         return torch.cat(
             (
                 norm_gene(h[:num_genes]),
                 norm_dis(h[num_genes:dis_end]),
                 norm_comp(h[dis_end:comp_end]),
                 norm_path(h[comp_end:]),
             ),
             0,
         )

     adj_t = gcn_norm(adj_t, num_nodes=x.size(-2), add_self_loops=False)

     h = F.relu(self.conv1(self.lin1(x), adj_t, edge_types))
     h = norm_per_type(h, self.normg1, self.normd1, self.normc1, self.normp1)

     h = F.relu(self.conv2(h, adj_t, edge_types))
     h = norm_per_type(h, self.normg2, self.normd2, self.normc2, self.normp2)

     h = self.conv3(h, adj_t, edge_types)
     h = norm_per_type(h, self.normg3, self.normd3, self.normc3, self.normp3)

     return self.drop(h)
Ejemplo n.º 26
0
    def __norm__(self, x, edge_index,
                 edge_weight: Optional[torch.Tensor] = None):
        """Return the normalized ``(edge_index, edge_weight)`` pair.

        With ``self.cached`` the result is stored together with the original
        edge count, and reused on later calls.

        Raises:
            RuntimeError: If a cached result exists but ``edge_index`` has a
                different number of edges than when it was cached.
        """
        cache = self._cache
        if cache is None:
            num_edges = edge_index.size(1)
            # NOTE(review): arguments are passed as (edge_index, num_nodes,
            # edge_weight), an older GCNConv.norm-style order. Confirm this
            # matches the signature of the gcn_norm imported in this file.
            edge_index, edge_weight = gcn_norm(edge_index,
                                               x.size(self.node_dim),
                                               edge_weight, dtype=x.dtype)
            if self.cached:
                self._cache = (num_edges, edge_index, edge_weight)
            return edge_index, edge_weight

        if edge_index.size(1) != cache[0]:
            raise RuntimeError(
                'Cached {} number of edges, but found {}. Please disable '
                'the caching behavior of this layer by removing the '
                '`cached=True` argument in its constructor.'.format(
                    cache[0], edge_index.size(1)))
        return cache[1:]
Ejemplo n.º 27
0
    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
        """Multi-aggregator, multi-basis graph convolution.

        ``self.bases_lin`` produces per-node basis features, and each
        aggregator in ``self.aggregators`` aggregates them. Per-head
        weightings from ``self.comb_lin`` combine the aggregated bases into
        ``self.out_channels`` outputs; ``self.bias`` is added when present.

        Graph preprocessing (stored for reuse when ``self.cached``):

        * if ``"symnorm"`` is an aggregator, the graph is GCN-normalized and
          the weights are passed to message passing as ``symnorm_weight``;
        * otherwise, if ``self.add_self_loops`` is set, self-loops are added.

        Args:
            x: Node feature matrix.
            edge_index: Edge-index ``Tensor`` or ``SparseTensor`` adjacency.

        Returns:
            A tensor reshaped to ``[-1, self.out_channels]``.
        """
        symnorm_weight: OptTensor = None

        if "symnorm" in self.aggregators:
            if isinstance(edge_index, Tensor):
                stored = self._cached_edge_index
                if stored is not None:
                    edge_index, symnorm_weight = stored
                else:
                    edge_index, symnorm_weight = gcn_norm(  # yapf: disable
                        edge_index,
                        None,
                        num_nodes=x.size(self.node_dim),
                        improved=False,
                        add_self_loops=self.add_self_loops)
                    if self.cached:
                        self._cached_edge_index = (edge_index, symnorm_weight)

            elif isinstance(edge_index, SparseTensor):
                stored = self._cached_adj_t
                if stored is not None:
                    edge_index = stored
                else:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index,
                        None,
                        num_nodes=x.size(self.node_dim),
                        improved=False,
                        add_self_loops=self.add_self_loops)
                    if self.cached:
                        self._cached_adj_t = edge_index

        elif self.add_self_loops:
            if isinstance(edge_index, Tensor):
                stored = self._cached_edge_index
                if not self.cached or stored is None:
                    edge_index, _ = add_remaining_self_loops(edge_index)
                    if self.cached:
                        self._cached_edge_index = (edge_index, None)
                else:
                    edge_index = stored[0]

            elif isinstance(edge_index, SparseTensor):
                stored = self._cached_adj_t
                if not self.cached or stored is None:
                    edge_index = fill_diag(edge_index, 1.0)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = stored

        # [num_nodes, (out_channels // num_heads) * num_bases]
        bases = self.bases_lin(x)
        # [num_nodes, num_heads * num_bases * num_aggrs]
        weightings = self.comb_lin(x)

        # [num_nodes, num_aggregators, (out_channels // num_heads) * num_bases]
        # propagate_type: (x: Tensor, symnorm_weight: OptTensor)
        aggregated = self.propagate(edge_index,
                                    x=bases,
                                    symnorm_weight=symnorm_weight,
                                    size=None)

        num_aggrs = len(self.aggregators)
        head_channels = self.out_channels // self.num_heads
        weightings = weightings.view(-1, self.num_heads,
                                     self.num_bases * num_aggrs)
        aggregated = aggregated.view(-1, num_aggrs * self.num_bases,
                                     head_channels)

        # [num_nodes, num_heads, out_channels // num_heads], heads flattened.
        out = torch.matmul(weightings, aggregated).view(-1, self.out_channels)

        if self.bias is not None:
            out += self.bias

        return out
Ejemplo n.º 28
0
        adj_t = torch.load(path)
    else:
        path_sym = dataset.root + '/mag240m/paper_to_paper_symmetric.pt'
        if osp.exists(path_sym):
            adj_t = torch.load(path_sym)
        else:
            edge_index = dataset.edge_index('paper', 'cites', 'paper')
            edge_index = torch.from_numpy(edge_index)
            adj_t = SparseTensor(row=edge_index[0],
                                 col=edge_index[1],
                                 sparse_sizes=(dataset.num_papers,
                                               dataset.num_papers),
                                 is_sorted=True)
            adj_t = adj_t.to_symmetric()
            torch.save(adj_t, path_sym)
        adj_t = gcn_norm(adj_t, add_self_loops=True)
        torch.save(adj_t, path)
    print(f'Done! [{time.perf_counter() - t:.2f}s]')

    train_idx = dataset.get_idx_split('train')
    valid_idx = dataset.get_idx_split('valid')
    test_idx = dataset.get_idx_split('test')
    num_features = dataset.num_paper_features

    pbar = tqdm(total=args.num_layers * (num_features // 128))
    pbar.set_description('Pre-processing node features')

    for j in range(0, num_features, 128):  # Run spmm in chunks...
        x = dataset.paper_feat[:, j:min(j + 128, num_features)]
        x = torch.from_numpy(x.astype(np.float32))
Ejemplo n.º 29
0
    def forward(self,
                x: Tensor,
                x_0: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        r"""Propagate ``x`` over the graph and mix in ``x_0`` and a projection.

        With ``h = (1 - alpha) * propagate(x) + alpha * x_0`` the result is
        ``(1 - beta) * h + beta * (h @ weight1)``. When ``weight2`` is set,
        the propagated part ``p`` and the residual part ``r = alpha * x_0``
        are projected separately and summed::

            (1 - beta) * p + beta * (p @ weight1)
            + (1 - beta) * r + beta * (r @ weight2)

        If ``self.normalize`` is set, ``edge_index`` (and, for a dense index
        tensor, ``edge_weight``) is first normalized with ``gcn_norm``
        (``improved=False``, ``add_self_loops=self.add_self_loops``). With
        ``self.cached`` the normalized result is stored, and once stored it
        is reused on every later call instead of the arguments passed in.

        Args:
            x: Features propagated over the graph.
            x_0: Residual features. Only its leading rows, as many as the
                propagated output has, are used.
            edge_index: Edge indices as a ``Tensor``, or a ``SparseTensor``
                adjacency.
            edge_weight: Optional per-edge weights, forwarded to ``gcn_norm``
                and ``propagate``.

        Returns:
            The combined output tensor.
        """
        if self.normalize:
            if isinstance(edge_index, Tensor):
                cached_pair = self._cached_edge_index
                if cached_pair is not None:
                    # Reuse the normalization computed on an earlier call.
                    edge_index, edge_weight = cached_pair[0], cached_pair[1]
                else:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index,
                        edge_weight,
                        num_nodes=x.size(self.node_dim),
                        improved=False,
                        add_self_loops=self.add_self_loops,
                        dtype=x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)

            elif isinstance(edge_index, SparseTensor):
                cached_adj = self._cached_adj_t
                if cached_adj is not None:
                    # Reuse the normalized adjacency from an earlier call.
                    edge_index = cached_adj
                else:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index,
                        edge_weight,
                        num_nodes=x.size(self.node_dim),
                        improved=False,
                        add_self_loops=self.add_self_loops,
                        dtype=x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index

        # NOTE(review): keep this comment's exact format; it looks like it is
        # read by PyG's TorchScript tooling to type `propagate` - confirm.
        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        h = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)

        # Keep (1 - alpha) of the propagated signal in place, and take alpha
        # of the leading rows of x_0 that line up with the propagated rows.
        h.mul_(1 - self.alpha)
        x_0 = self.alpha * x_0[:h.size(0)]

        # torch.addmm(M, A, B, beta=b, alpha=a) == b * M + a * (A @ B), so
        # every call below yields (1 - beta) * M + beta * (M @ W).
        keep = 1. - self.beta
        if self.weight2 is not None:
            out = torch.addmm(h, h, self.weight1, beta=keep, alpha=self.beta)
            out += torch.addmm(x_0, x_0, self.weight2, beta=keep,
                               alpha=self.beta)
        else:
            h.add_(x_0)
            out = torch.addmm(h, h, self.weight1, beta=keep, alpha=self.beta)

        return out
Ejemplo n.º 30
0
    t = time.perf_counter()
    print('Reading adjacency matrix...', end=' ', flush=True)
    path = f'{dataset.dir}/paper_to_paper_symmetric.pt'
    if osp.exists(path):
        adj_t = torch.load(path)
    else:
        edge_index = dataset.edge_index('paper', 'cites', 'paper')
        edge_index = torch.from_numpy(edge_index)
        adj_t = SparseTensor(
            row=edge_index[0], col=edge_index[1],
            sparse_sizes=(dataset.num_papers, dataset.num_papers),
            is_sorted=True)
        adj_t = adj_t.to_symmetric()
        torch.save(adj_t, path)
    adj_t = gcn_norm(adj_t, add_self_loops=False)
    if args.low_memory:
        adj_t = adj_t.to(torch.half)
    print(f'Done! [{time.perf_counter() - t:.2f}s]')

    train_idx = dataset.get_idx_split('train')
    valid_idx = dataset.get_idx_split('valid')
    test_idx = dataset.get_idx_split('test')

    y_train = torch.from_numpy(dataset.paper_label[train_idx]).to(torch.long)
    y_valid = torch.from_numpy(dataset.paper_label[valid_idx]).to(torch.long)

    model = LabelPropagation(args.num_layers, args.alpha)

    N, C = dataset.num_papers, dataset.num_classes