Example #1
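A cached K-hop propagation in the style of torch_geometric's SGConv. Note the asymmetry: with caching off, the linear transform is applied before propagation on every call; with caching on, the raw features are propagated once, cached, and the transform is applied to the cached result.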
    def forward(self, x, edge_index, edge_weight=None):
        """"""
        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}. Please '
                    'disable the caching behavior of this layer by removing '
                    'the `cached=True` argument in its constructor.'.format(
                        self.cached_num_edges, edge_index.size(1)))

        if not self.cached:
            x = self.lin(x)

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = GCNConv.norm(edge_index,
                                            x.size(0),
                                            edge_weight,
                                            dtype=x.dtype)

            for k in range(self.K):
                x = self.propagate(edge_index, x=x, norm=norm)
            self.cached_result = x

        if self.cached:
            x = self.lin(self.cached_result)

        return x
Example #2
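The bare K-hop propagation, with no caching and no linear transform.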
    def forward(self, x, edge_index, edge_weight=None):
        edge_index, norm = GCNConv.norm(edge_index, x.size(0), edge_weight,
                                        dtype=x.dtype)

        for k in range(self.K):
            x = self.propagate(edge_index, x=x, norm=norm)
        return x
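Each snippet in this listing is a forward method cut out of a larger MessagePassing subclass. A minimal sketch of the class they assume is below; the name KHopConv and its constructor are illustrative, and GCNConv.norm is the static normalization helper from torch_geometric 1.x (replaced by gcn_norm in 2.x).

from torch_geometric.nn import GCNConv, MessagePassing


class KHopConv(MessagePassing):  # hypothetical name, for illustration only
    def __init__(self, K):
        super(KHopConv, self).__init__(aggr='add')
        self.K = K
        # The cached variants (Examples #1, #3 and #6) additionally set up
        # self.lin, self.cached, self.cached_result and self.cached_num_edges.

    def message(self, x_j, norm):
        # Scale each neighbor's features by its normalization coefficient.
        return norm.view(-1, 1) * x_j

    def forward(self, x, edge_index, edge_weight=None):
        # Body of Example #2 above.
        edge_index, norm = GCNConv.norm(edge_index, x.size(0), edge_weight,
                                        dtype=x.dtype)
        for k in range(self.K):
            x = self.propagate(edge_index, x=x, norm=norm)
        return x

With this scaffolding, conv = KHopConv(K=3) followed by out = conv(x, edge_index) runs three propagation steps over node features x of shape [num_nodes, channels].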
Example #3
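Structurally identical to Example #1; only the cache-mismatch error message is shorter.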
    def forward(self, x, edge_index, edge_weight=None):
        """"""
        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}'.format(
                        self.cached_num_edges, edge_index.size(1)))

        if not self.cached:
            x = self.lin(x)

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = GCNConv.norm(edge_index,
                                            x.size(0),
                                            edge_weight,
                                            dtype=x.dtype)

            for k in range(self.K):
                x = self.propagate(edge_index, x=x, norm=norm)
            self.cached_result = x

        if self.cached:
            x = self.lin(self.cached_result)

        return x
Example #4
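A single propagation step over stacked per-layer features of shape [num_nodes, num_layers, channels]. Here only the normalized (edge_index, norm) pair is cached, not the propagated features.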
    def forward(self, x, edge_index, edge_weight=None):
        """"""
        # x: [num_nodes, num_layers, channels]
        # edge_index: [2, num_edges]
        # edge_weight: [num_edges]

        if x.dim() != 3:
            raise ValueError('Feature shape must be [num_nodes, num_layers, '
                             'channels].')
        num_nodes, num_layers, channels = x.size()

        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}. Please '
                    'disable the caching behavior of this layer by removing '
                    'the `cached=True` argument in its constructor.'.format(
                        self.cached_num_edges, edge_index.size(1)))

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = GCNConv.norm(edge_index,
                                            x.size(self.node_dim),
                                            edge_weight,
                                            dtype=x.dtype)
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result

        return self.propagate(edge_index, x=x, norm=norm)
Example #5
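K-hop propagation that keeps every intermediate representation and concatenates them along the feature dimension, resembling torch_geometric's TAGConv.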
    def forward(self, x, edge_index, edge_weight=None):
        edge_index, norm = GCNConv.norm(edge_index, x.size(0), edge_weight,
                                        dtype=x.dtype)

        xs = [x]
        for k in range(self.K):
            xs.append(self.propagate(edge_index, x=xs[-1], norm=norm))
        return torch.cat(xs, dim=1)
Example #6
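A cached variant that, unlike Examples #1 and #3, performs no edge-count consistency check and always applies the linear transform after propagation.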
    def forward(self, x, edge_index, edge_weight=None):
        """"""
        if not self.cached or self.cached_result is None:
            edge_index, norm = GCNConv.norm(
                edge_index, x.size(0), edge_weight, dtype=x.dtype)

            for k in range(self.K):
                x = self.propagate(edge_index, x=x, norm=norm)
            self.cached_result = x

        return self.lin(self.cached_result)
Example #7
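Personalized-PageRank-style propagation resembling APPNP: each of the K hops computes x = (1 - alpha) * propagate(x) + alpha * x_0, where x_0 is the layer input held in hidden.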
    def forward(self, x, edge_index, edge_weight=None):
        """"""
        edge_index, norm = GCNConv.norm(
            edge_index, x.size(0), edge_weight, dtype=x.dtype)

        hidden = x
        for k in range(self.K):
            x = self.propagate(edge_index, x=x, norm=norm)
            x = x * (1 - self.alpha)
            x = x + self.alpha * hidden

        return x
Example #8
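Propagation with a learned per-hop retention score, resembling the adaptive adjustment step in DAGNN: all K + 1 hop representations are stacked, each is projected to a scalar score, and a sigmoid-weighted sum combines them.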
    def forward(self, x, edge_index, edge_weight=None):
        edge_index, norm = GCNConv.norm(edge_index, x.size(0), edge_weight,
                                        dtype=x.dtype)

        preds = [x]
        for k in range(self.K):
            x = self.propagate(edge_index, x=x, norm=norm)
            preds.append(x)

        # Stack the K + 1 hop representations: [num_nodes, K + 1, channels].
        pps = torch.stack(preds, dim=1)
        # Project each hop to a scalar score and squash it to (0, 1).
        retain_score = self.proj(pps)
        retain_score = retain_score.squeeze(-1)
        retain_score = torch.sigmoid(retain_score)
        # Weighted sum over hops:
        # [num_nodes, 1, K + 1] @ [num_nodes, K + 1, channels] -> [num_nodes, 1, channels].
        retain_score = retain_score.unsqueeze(1)
        out = torch.matmul(retain_score, pps).squeeze(1)
        return out
Example #9
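Example #4 without the edge-count consistency check; the normalization is computed over x.size(0) nodes.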
    def forward(self, x, edge_index, edge_weight=None):
        """"""
        # x: [num_nodes, num_layers, channels]
        # edge_index: [2, num_edges]
        # edge_weight: [num_edges]

        if x.dim() != 3:
            raise ValueError('Feature shape must be [num_nodes, num_layers, '
                             'channels].')
        num_nodes, num_layers, channels = x.size()

        if not self.cached or self.cached_result is None:
            edge_index, norm = GCNConv.norm(
                edge_index, x.size(0), edge_weight, dtype=x.dtype)
            self.cached_result = edge_index, norm
        edge_index, norm = self.cached_result

        return self.propagate(edge_index, x=x, norm=norm)