Example #1
    def forward(self, graph, x):
        edge_index = graph.edge_index
        dim = x.shape[1]
        edge_msg = x[edge_index[1]]  # if edge_attr is None else x[edge_index[1]] + edge_attr
        edge_msg = self.act(edge_msg) + self.eps

        if self.aggr == "softmax_sg":
            h = mul_edge_softmax(graph, self.beta * edge_msg).T
            h = edge_msg * h
        elif self.aggr == "softmax":
            h = mul_edge_softmax(graph, edge_msg).T
            h = edge_msg * h
        elif self.aggr == "powermean":
            deg = graph.degrees()
            h = edge_msg.pow(self.t) / deg[edge_index[0]].unsqueeze(-1)
        else:
            raise NotImplementedError

        h = torch.zeros_like(x).scatter_add_(
            dim=0, index=edge_index[0].unsqueeze(-1).repeat(1, dim), src=h)
        if self.aggr == "powermean":
            h = h.pow(1.0 / self.p)
        if self.use_msg_norm:
            h = self.message_norm(x, h)
        h = self.mlp(h)
        return h
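The behaviour assumed for mul_edge_softmax throughout these examples is a softmax over the edge scores grouped by destination node (edge_index[0], the index used in the scatter_add_ above); note that the return orientation differs between call sites ((H, E) transposed with .T here, (E, H) indexed with [:, i] in Example #6). A minimal scatter-based sketch of that per-node softmax, written against plain PyTorch rather than cogdl's actual implementation:

import torch

def edge_softmax_sketch(edge_index, edge_score, num_nodes):
    # Softmax of edge_score over the edges sharing a destination node
    # (edge_index[0]); edge_score is (E,) or (E, H).
    row = edge_index[0]
    score = edge_score if edge_score.dim() > 1 else edge_score.unsqueeze(-1)
    score = score - score.max()  # global shift for numerical stability
    exp = score.exp()  # (E, H)
    denom = torch.zeros(num_nodes, exp.shape[1], dtype=exp.dtype, device=exp.device)
    denom.scatter_add_(0, row.unsqueeze(-1).expand_as(exp), exp)  # (N, H) per-node sums
    out = exp / (denom[row] + 1e-16)
    return out if edge_score.dim() > 1 else out.squeeze(-1)

edge_index = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
att = edge_softmax_sketch(edge_index, torch.randn(5), num_nodes=3)
# att[0] + att[1] == 1, att[2] == 1, att[3] + att[4] == 1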
Example #2
    def forward(self, x, edge_index, edge_attr=None):
        device = x.device
        dim = x.shape[1]
        num_nodes = x.shape[0]
        edge_msg = x[edge_index[1]]  # if edge_attr is None else x[edge_index[1]] + edge_attr
        edge_msg = self.act(edge_msg) + self.eps

        if self.aggr == "softmax_sg":
            h = mul_edge_softmax(edge_index, self.beta * edge_msg, shape=(num_nodes, num_nodes))
            h = edge_msg * h
        elif self.aggr == "softmax":
            h = mul_edge_softmax(edge_index, edge_msg, shape=(num_nodes, num_nodes))
            h = edge_msg * h
        elif self.aggr == "powermean":
            deg = spmm(
                indices=edge_index,
                values=torch.ones(edge_index.shape[1], device=device),  # keep values on the same device as the graph
                b=torch.ones(num_nodes).unsqueeze(-1).to(device),
            ).view(-1)
            h = edge_msg.pow(self.t) / deg[edge_index[0]].unsqueeze(-1)
        else:
            raise NotImplementedError

        h = torch.zeros_like(x).scatter_add_(dim=0, index=edge_index[0].unsqueeze(-1).repeat(1, dim), src=h)
        if self.aggr == "powermean":
            h = h.pow(1.0 / self.p)
        if self.use_msg_norm:
            h = self.message_norm(x, h)
        h = self.mlp(h)
        return h
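The spmm call in the "powermean" branch computes node degrees by multiplying a sparse 0/1 adjacency by a vector of ones; assuming the sparse matrix is built with rows taken from edge_index[0] (the same index the aggregation scatters into), the result is just a count of destination indices. A tiny check of that equivalence on a toy graph, in plain PyTorch:

import torch

edge_index = torch.tensor([[0, 0, 1, 2, 2, 2],
                           [1, 2, 0, 0, 1, 3]])
num_nodes = 4
# scatter count over edge_index[0]: how many messages each node aggregates
deg = torch.zeros(num_nodes).scatter_add_(0, edge_index[0], torch.ones(edge_index.shape[1]))
print(deg)                                                  # tensor([2., 1., 3., 0.])
print(torch.bincount(edge_index[0], minlength=num_nodes))   # tensor([2, 1, 3, 0])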
Example #3
    def forward(self, x, edge_index):
        num_nodes = x.shape[0]
        device = x.device

        h = self.activation(torch.matmul(x, self.weight) + self.bias)

        h = h.split(self.factor_dim, dim=-1)
        h = torch.cat([dt.unsqueeze(0) for dt in h], dim=0)
        norm = h.pow(2).sum(dim=-1).sqrt().unsqueeze(-1)

        # multi-channel softmax: faster
        h_normed = h / norm  # (K, N, d)
        h_src = h_dst = h_normed.permute(1, 0, 2)  # (N, K, d)
        add_shape = h.shape  # (K, N, d)

        for _ in range(self.iterations):
            src_edge_attr = h_dst[edge_index[0]] * h_src[edge_index[1]]
            src_edge_attr = src_edge_attr.sum(dim=-1) / self.tau  # shape: (E, K)
            edge_attr_softmax = mul_edge_softmax(edge_index, src_edge_attr, shape=(num_nodes, num_nodes))  # shape: (E, K)
            edge_attr_softmax = edge_attr_softmax.t().unsqueeze(-1)  # shape: (K, E, 1)

            dst_edge_attr = h_src.index_select(0, edge_index[1]).permute(1, 0, 2)  # shape: (E, K, d) -> (K, E, d)
            dst_edge_attr = dst_edge_attr * edge_attr_softmax
            edge_index_ = edge_index[0].unsqueeze(-1).unsqueeze(0).repeat(self.K, 1, h.shape[-1])
            node_attr = torch.zeros(add_shape).to(device).scatter_add_(1, edge_index_, dst_edge_attr)  # (K, N, d)
            node_attr = node_attr + h_normed
            node_attr_norm = node_attr.pow(2).sum(-1).sqrt().unsqueeze(-1)  # shape: (K, N, 1)
            node_attr = (node_attr / node_attr_norm).permute(1, 0, 2)  # shape: (N, K, d)
            h_dst = node_attr

        h_dst = h_dst.reshape(num_nodes, -1)

        # Calculate the softmax of each channel separately
        # h_src = h_dst = h / norm  # (K, N, d)
        #
        # for _ in range(self.iterations):
        #     for i in range(self.K):
        #         h_attr = h_dst[i]
        #         edge_attr = h_attr[edge_index[0]] * h_src[i][edge_index[1]]
        #
        #         edge_attr = edge_attr.sum(-1)/self.tau
        #         edge_attr = edge_softmax(edge_index, edge_attr, shape=(num_nodes, num_nodes))
        #
        #         node_attr = spmm(edge_index, edge_attr, h_src[i])
        #
        #         node_attr = node_attr + h_src[i]
        #         h_src[i] = node_attr / node_attr.pow(2).sum(-1).sqrt().unsqueeze(-1)
        #
        # h_dst = h_dst.permute(1, 0, 2).reshape(num_nodes, -1)

        return h_dst
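Example #3 appears to be a disentangled-style multi-channel layer: the projected features are split into K channels of factor_dim dimensions each, every channel is L2-normalised, and a few rounds of attention-weighted neighbour routing update them. A toy sketch of the split-and-normalise step only (N, K, factor_dim are made-up sizes; F.normalize stands in for the manual norm division used above):

import torch
import torch.nn.functional as F

N, K, factor_dim = 5, 4, 8
h = torch.randn(N, K * factor_dim)
h = torch.stack(h.split(factor_dim, dim=-1), dim=0)  # (K, N, factor_dim), same as cat + unsqueeze
h_normed = F.normalize(h, p=2, dim=-1)               # unit L2 norm per channel
print(h_normed.pow(2).sum(-1))                       # ~1.0 everywhere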
Example #4
File: gat.py Project: znsoftm/cogdl
    def forward(self, graph, x):
        h = torch.matmul(x, self.W).view(-1, self.nhead, self.out_features)
        # h: N * H * d
        h[torch.isnan(h)] = 0.0

        edge_index = graph.edge_index
        # Self-attention on the nodes - Shared attention mechanism
        h_l = (self.a_l * h).sum(dim=-1)[edge_index[0, :]]
        h_r = (self.a_r * h).sum(dim=-1)[edge_index[1, :]]
        edge_attention = self.leakyrelu(h_l + h_r)
        # edge_e: E * H
        edge_attention = mul_edge_softmax(graph, edge_attention)
        num_edges = graph.num_edges
        num_nodes = graph.num_nodes

        with graph.local_graph():
            if self.fast_mode:
                edge_attention = edge_attention.view(-1)
                edge_attention = self.dropout(edge_attention)

                edge_index = edge_index.view(-1)
                edge_index = edge_index.unsqueeze(0).repeat(self.nhead, 1)
                add_num = torch.arange(0, self.nhead * num_nodes, num_nodes).view(-1, 1).to(edge_index.device)
                edge_index = edge_index + add_num
                edge_index = edge_index.split((num_edges, num_edges), dim=1)

                row, col = edge_index
                row = row.reshape(-1)
                col = col.reshape(-1)
                edge_index = torch.stack([row, col])

                graph.edge_index = edge_index
                graph.edge_weight = edge_attention
                h_prime = spmm(
                    graph,
                    h.permute(1, 0, 2).reshape(num_nodes * self.nhead, -1))
                assert not torch.isnan(h_prime).any()
                h_prime = h_prime.split([num_nodes] * self.nhead)
            else:
                edge_attention = self.dropout(edge_attention)
                h_prime = []
                h = h.permute(1, 0, 2).contiguous()
                for i in range(self.nhead):
                    edge_weight = edge_attention[i]
                    graph.edge_weight = edge_weight
                    hidden = h[i]
                    assert not torch.isnan(hidden).any()
                    h_prime.append(spmm(graph, hidden))
        if self.residual:
            res = self.residual(x)
        else:
            res = 0

        if self.concat:
            out = torch.cat(h_prime, dim=1) + res
        else:
            out = sum(h_prime) / self.nhead + res
        return out
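The fast_mode branch avoids a per-head loop by shifting node ids by head * num_nodes and running one sparse matmul over a block-diagonal adjacency. An equivalent, slightly more direct construction of the shifted edge index (toy numbers; not the cogdl code itself):

import torch

num_nodes, nhead = 3, 2
edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])
offsets = torch.arange(0, nhead * num_nodes, num_nodes).view(-1, 1, 1)  # (H, 1, 1)
stacked = edge_index.unsqueeze(0) + offsets          # (H, 2, E): head h lives on nodes [h*N, (h+1)*N)
block_edge_index = torch.cat(list(stacked), dim=1)   # (2, H * E)
print(block_edge_index)
# tensor([[0, 1, 2, 3, 4, 5],
#         [1, 2, 0, 4, 5, 3]])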
Example #5
    def forward(self, x, edge):
        N = x.size()[0]
        h = torch.matmul(x, self.W).view(-1, self.nhead, self.out_features)
        # h: N * H * d
        if torch.isnan(self.W.data).any():
            # print("NaN in Graph Attention, ", self.nhead)
            h[torch.isnan(h)] = 0

        # Self-attention on the nodes - Shared attention mechanism
        h_l = (self.a_l * h).sum(dim=-1)[edge[0, :]]
        h_r = (self.a_r * h).sum(dim=-1)[edge[1, :]]
        edge_attention = self.leakyrelu(h_l + h_r)
        # edge_e: E * H
        edge_attention = mul_edge_softmax(edge, edge_attention, shape=(N, N))

        num_edges = edge.shape[1]
        num_nodes = x.shape[0]

        if self.fast_mode:
            edge_attention = edge_attention.view(-1)
            edge_attention = self.dropout(edge_attention)

            edge_index = edge.view(-1)
            edge_index = edge_index.unsqueeze(0).repeat(self.nhead, 1)
            add_num = torch.arange(0, self.nhead * num_nodes,
                                   num_nodes).view(-1, 1).to(edge_index.device)
            edge_index = edge_index + add_num
            edge_index = edge_index.split((num_edges, num_edges), dim=1)

            row, col = edge_index
            row = row.reshape(-1)
            col = col.reshape(-1)
            edge_index = torch.stack([row, col])

            h_prime = spmm(
                edge_index, edge_attention,
                h.permute(1, 0, 2).reshape(num_nodes * self.nhead, -1))
            assert not torch.isnan(h_prime).any()
            h_prime = h_prime.split([num_nodes] * self.nhead)
        else:
            h_prime = []
            h = h.permute(1, 0, 2).contiguous()
            for i in range(self.nhead):
                edge_weight = edge_attention[:, i]
                hidden = h[i]
                assert not torch.isnan(hidden).any()
                h_prime.append(spmm(edge, edge_weight, hidden))

        if self.residual:
            res = self.residual(x)
        else:
            res = 0

        if self.concat:
            # intermediate layer: concatenate the heads
            out = torch.cat(h_prime, dim=1) + res
        else:
            # final layer: average the heads
            out = sum(h_prime) / self.nhead + res
        return out
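The closing branch combines the heads: concatenation in hidden layers yields an (N, H * d) output, averaging in the last layer yields (N, d). A shapes-only illustration with random stand-in tensors:

import torch

num_nodes, nhead, d = 4, 3, 5
h_prime = [torch.randn(num_nodes, d) for _ in range(nhead)]
concat_out = torch.cat(h_prime, dim=1)   # (4, 15): hidden layers keep every head
mean_out = sum(h_prime) / nhead          # (4, 5): the last layer averages heads
print(concat_out.shape, mean_out.shape)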
Example #6
    def forward(self, graph, x):
        h = torch.matmul(x, self.W).view(-1, self.nhead, self.out_features)
        h[torch.isnan(h)] = 0.0

        row, col = graph.edge_index
        # Self-attention on the nodes - Shared attention mechanism
        h_l = (self.a_l * h).sum(dim=-1)[row]
        h_r = (self.a_r * h).sum(dim=-1)[col]
        edge_attention = self.leakyrelu(h_l + h_r)
        # edge_attention: E * H
        edge_attention = mul_edge_softmax(graph, edge_attention)
        edge_attention = self.dropout(edge_attention)

        if check_mh_spmm() and next(self.parameters()).device.type != "cpu":
            if self.nhead > 1:
                h_prime = mh_spmm(graph, edge_attention, h)
                out = h_prime.view(h_prime.shape[0], -1)
            else:
                edge_weight = edge_attention.view(-1)
                with graph.local_graph():
                    graph.edge_weight = edge_weight
                    out = spmm(graph, h.squeeze(1))
        else:
            with graph.local_graph():
                h_prime = []
                h = h.permute(1, 0, 2).contiguous()
                for i in range(self.nhead):
                    edge_weight = edge_attention[:, i]
                    graph.edge_weight = edge_weight
                    hidden = h[i]
                    assert not torch.isnan(hidden).any()
                    h_prime.append(spmm(graph, hidden))
            out = torch.cat(h_prime, dim=1)

        if self.residual:
            res = self.residual(x)
            out += res
        return out
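The a_l / a_r split used in Examples #4-#6 is the usual GAT decomposition of the attention vector, a^T [h_i || h_j] = a_l . h_i + a_r . h_j, so the two per-node terms are computed once and only gathered per edge. A toy reproduction of the scoring step (random stand-ins for the layer parameters; LeakyReLU slope assumed to be 0.2; edge_index[0] assumed to hold the aggregation targets, as in the scatter calls above):

import torch
import torch.nn.functional as F

N, H, d = 4, 2, 8
h = torch.randn(N, H, d)                      # projected features, one row per head
a_l = torch.randn(1, H, d)
a_r = torch.randn(1, H, d)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
h_l = (a_l * h).sum(dim=-1)[edge_index[0]]    # (E, H): destination-side term
h_r = (a_r * h).sum(dim=-1)[edge_index[1]]    # (E, H): source-side term
scores = F.leaky_relu(h_l + h_r, negative_slope=0.2)
print(scores.shape)                           # torch.Size([4, 2])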