Example #1
import dgl
import backend as F  # backend-agnostic tensor shim, assuming this snippet comes from DGL's test suite

def test_simple_readout():
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    n1 = F.randn((3, 5))
    n2 = F.randn((4, 5))
    e1 = F.randn((3, 5))
    s1 = F.sum(n1, 0)  # node sums
    s2 = F.sum(n2, 0)
    se1 = F.sum(e1, 0)  # edge sums
    m1 = F.mean(n1, 0)  # node means
    m2 = F.mean(n2, 0)
    me1 = F.mean(e1, 0)  # edge means
    w1 = F.randn((3, ))
    w2 = F.randn((4, ))
    max1 = F.max(n1, 0)
    max2 = F.max(n2, 0)
    maxe1 = F.max(e1, 0)
    ws1 = F.sum(n1 * F.unsqueeze(w1, 1), 0)
    ws2 = F.sum(n2 * F.unsqueeze(w2, 1), 0)
    wm1 = F.sum(n1 * F.unsqueeze(w1, 1), 0) / F.sum(F.unsqueeze(w1, 1), 0)
    wm2 = F.sum(n2 * F.unsqueeze(w2, 1), 0) / F.sum(F.unsqueeze(w2, 1), 0)
    g1.ndata['x'] = n1
    g2.ndata['x'] = n2
    g1.ndata['w'] = w1
    g2.ndata['w'] = w2
    g1.edata['x'] = e1

    assert F.allclose(dgl.sum_nodes(g1, 'x'), s1)
    assert F.allclose(dgl.sum_nodes(g1, 'x', 'w'), ws1)
    assert F.allclose(dgl.sum_edges(g1, 'x'), se1)
    assert F.allclose(dgl.mean_nodes(g1, 'x'), m1)
    assert F.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
    assert F.allclose(dgl.mean_edges(g1, 'x'), me1)
    assert F.allclose(dgl.max_nodes(g1, 'x'), max1)
    assert F.allclose(dgl.max_edges(g1, 'x'), maxe1)

    g = dgl.batch([g1, g2])
    s = dgl.sum_nodes(g, 'x')
    m = dgl.mean_nodes(g, 'x')
    max_bg = dgl.max_nodes(g, 'x')
    assert F.allclose(s, F.stack([s1, s2], 0))
    assert F.allclose(m, F.stack([m1, m2], 0))
    assert F.allclose(max_bg, F.stack([max1, max2], 0))
    ws = dgl.sum_nodes(g, 'x', 'w')
    wm = dgl.mean_nodes(g, 'x', 'w')
    assert F.allclose(ws, F.stack([ws1, ws2], 0))
    assert F.allclose(wm, F.stack([wm1, wm2], 0))
    s = dgl.sum_edges(g, 'x')
    m = dgl.mean_edges(g, 'x')
    max_bg_e = dgl.max_edges(g, 'x')
    assert F.allclose(s, F.stack([se1, F.zeros(5)], 0))
    assert F.allclose(m, F.stack([me1, F.zeros(5)], 0))
    assert F.allclose(max_bg_e, F.stack([maxe1, F.zeros(5)], 0))
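For reference, the weighted readouts asserted above follow sum_i(w_i * x_i) for sum_nodes and sum_i(w_i * x_i) / sum_i(w_i) for mean_nodes. A minimal worked sketch with the PyTorch backend (the graph and numbers are illustrative, not from the test):

    import dgl
    import torch

    g = dgl.DGLGraph()
    g.add_nodes(2)
    g.ndata['x'] = torch.tensor([[1.0], [3.0]])
    g.ndata['w'] = torch.tensor([1.0, 3.0])
    dgl.sum_nodes(g, 'x', 'w')   # tensor([[10.]]) : 1*1 + 3*3
    dgl.mean_nodes(g, 'x', 'w')  # tensor([[2.5]]) : 10 / (1 + 3)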
Example #2
    def forward(self, graph):
        # graph: a dgl graph
        # use node degree as the initial node feature
        h = graph.in_degrees()
        h1 = h.view(-1, 1).float()
        h2 = (h - 3) > 0
        h2 = h2.view(-1, 1).float()
        h3 = 3 / h1  # reciprocal-degree feature; an in-degree of 0 would produce inf here
        h4 = (h - 4) > 0
        h4 = h4.view(-1, 1).float()

        h_ = th.cat((h1, h2, h3, h4), 1)

        h_at = self.gat1(h_, graph)
        h_at = self.gat2(h_at, graph)
        h_at = self.gat3(h_at, graph)

        graph.ndata['h'] = h_at

        # calculate graph representation
        graph_emb = dgl.max_nodes(graph, 'h')
        h_at2 = self.drop_layer1(graph_emb)
        pred = F.sigmoid(self.classify(h_at2))

        return pred, graph_emb, graph.ndata['h']
Example #3
    def forward(self, graph):
        # graph: a dgl graph
        # use node degree as the initial node feature
        # and a binary variable indicating whether the node has a fractional value
        h = graph.in_degrees()
        h1 = h.view(-1, 1).float()
        h2 = (h - 3) > 0
        h2 = h2.view(-1, 1).float()
        h3 = 3 / h1  # reciprocal-degree feature; an in-degree of 0 would produce inf here
        h4 = (h - 4) > 0
        h4 = h4.view(-1, 1).float()

        h_ = th.cat((h1, h2, h3, h4), 1)

        # perform graph convolution and activation function (relu)
        h_co = F.relu(self.conv1(graph, h_))
        h_co = F.relu(self.conv2(graph, h_co))
        graph.ndata['h'] = h_co
        # calculate graph representation
        graph_emb = dgl.max_nodes(graph, 'h')
        h_emb = F.relu(self.layer1(graph_emb))
        h_emb = self.drop_layer1(h_emb)
        pred = F.sigmoid(self.layer2(h_emb))

        return pred
Example #4
    def forward(self, g, h, e, snorm_n, snorm_e):
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        if self.edge_feat:
            e = self.embedding_e(e)

        # Loop over all layers
        for i, conv in enumerate(self.layers):
            # Graph conv layers
            h_t = conv(g, h, e, snorm_n)
            h = h_t

            # Virtual node layer
            if self.virtual_node_layers is not None:
                if i == 0:
                    vn_h = 0
                if i < len(self.virtual_node_layers):
                    vn_h, h = self.virtual_node_layers[i].forward(g, h, vn_h)

        g.ndata['h'] = h

        # Readout layer
        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)
Example #5
    def forward(self, g, h, e, snorm_n, snorm_e):
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        h_init = h

        for i in range(self.layer_count):
            conv = self.layers[i]
            joint = self.joining_layers[i]
            h = conv(g, h, snorm_n)
            h = joint(h_init + h)

        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)
Example #6
    def forward(self, g, h, e, snorm_n, snorm_e):
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        if self.pos_enc_dim > 0:
            h_pos_enc = self.embedding_pos_enc(g.ndata['pos_enc'].to(
                self.device))
            h = h + h_pos_enc
        if self.edge_feat:
            e = self.embedding_e(e)

        for i, conv in enumerate(self.layers):
            h_t = conv(g, h, e, snorm_n)
            h = h_t

        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)
Example #7
    def forward(self, g, h, e, pos_enc=None):

        # input embedding
        if self.pos_enc:
            h = self.embedding_pos_enc(pos_enc)
        else:
            h = self.embedding_h(h)

        # computing the 'pseudo' named tensor which depends on node degrees
        g.ndata['deg'] = g.in_degrees()
        g.apply_edges(self.compute_pseudo)
        pseudo = g.edata['pseudo'].to(self.device).float()

        for i in range(len(self.layers)):
            h = self.layers[i](g, h, self.pseudo_proj[i](pseudo))
        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)
Example #8
    def forward(self, g, h, e, snorm_n, snorm_e):
        h = self.embedding_h(h)
        #h = self.in_feat_dropout(h)
        if self.JK == 'sum':
            h_list = [h]

        for i, conv in enumerate(self.layers):
            h_t = conv(g, h, e, snorm_n)
            if self.gru_enable and i != len(self.layers) - 1:
                h_t = self.gru(h, h_t)
            h = h_t
            if self.JK == 'sum':
                h_list.append(h)

        if self.JK == 'sum':
            # jumping knowledge: sum node representations from every layer
            h = 0
            for h_layer in h_list:
                h = h + h_layer
        # self.JK == 'last' simply keeps the final layer's output
        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes (avoids passing None to the MLP)

        return self.MLP_layer(hg)
Example #9
    def forward(self, feats, bg):
        """Multi-task prediction for a batch of molecules
        Parameters
        ----------
        feats : FloatTensor of shape (N, M0)
            Initial features for all atoms in the batch of molecules
        bg : BatchedDGLGraph
            Batched DGLGraph for processing multiple molecules in parallel, with B the number of molecules
        Returns
        -------
        FloatTensor of shape (B, n_tasks)
            Soft prediction for all tasks on the batch of molecules
        """
        # Update atom features
        for gcn in self.gcn_layers:
            feats = gcn(feats, bg)

        # Compute molecule features from atom features
        bg.ndata[self.atom_data_field] = feats
        bg.ndata[self.atom_weight_field] = self.atom_weighting(feats)
        h_g_sum = dgl.sum_nodes(bg, self.atom_data_field,
                                self.atom_weight_field)
        h_g_max = dgl.max_nodes(bg, self.atom_data_field)
        h_g = torch.cat([h_g_sum, h_g_max], dim=1)

        # Multi-task prediction
        return self.soft_classifier(h_g)
Example #10
    def forward(self, g, h, e):
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        for conv in self.layers:
            # For reduced graphs
            h = conv(g, h, e)
            # For original graphs
            # h = conv(g, h)
        g.ndata['h'] = h
        if self.readout == "sum":
            # For reduced graphs
            hg = dgl.sum_nodes(g, feat='h', weight='weight')

            # For original graphs
            # hg = dgl.sum_nodes(g, feat= 'h')
        elif self.readout == "max":
            # For reduced graphs
            hg = dgl.max_nodes(g, feat='h', weight='weight')

            # For original graphs
            # hg = dgl.max_nodes(g, feat= 'h')
        elif self.readout == "mean":
            # For reduced graphs
            hg = dgl.mean_nodes(g, feat='h', weight='weight')

            # For original graphs
            # hg = dgl.mean_nodes(g, feat= 'h')
        else:
            # For reduced graphs
            hg = dgl.mean_nodes(
                g, feat='h', weight='weight')  # default readout is mean nodes

            # For original graphs
            # hg = dgl.mean_nodes(g, feat= 'h')
        return self.MLP_layer(hg)
Example #11
    def forward(self,
                g,
                h,
                e,
                snorm_n,
                snorm_e,
                mlp=True,
                head=False,
                return_graph=False):
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        for conv in self.layers:
            h = conv(g, h, snorm_n)
        g.ndata['h'] = h

        if return_graph:
            return g

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        if mlp:
            return self.MLP_layer(hg)
        else:
            if head:
                return self.projection_head(hg)
            else:
                return hg
Example #12
    def forward(self, g, h, e, snorm_n, snorm_e):
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        if self.pos_enc_dim > 0:
            h_pos_enc = self.embedding_pos_enc(g.ndata['pos_enc'].to(
                self.device))
            h = h + h_pos_enc
        if self.JK == 'sum':
            h_list = [h]
        if self.edge_feat:
            e = self.embedding_e(e)

        for i, conv in enumerate(self.layers):
            h_t = conv(g, h, e, snorm_n)
            if self.gru_enable and i != len(self.layers) - 1:
                h_t = self.gru(h, h_t)
            h = h_t
            if self.JK == 'sum':
                h_list.append(h)

        if self.JK == 'sum':
            # jumping knowledge: sum node representations from every layer
            h = 0
            for h_layer in h_list:
                h = h + h_layer
        # self.JK == 'last' simply keeps the final layer's output
        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        elif self.readout == "directional_abs":
            eig = g.ndata['eig'][:, 1:2].to(self.device)
            g.ndata['dir'] = h * torch.abs(eig) / torch.sum(torch.abs(eig), dim=1, keepdim=True)
            hg = torch.cat([dgl.mean_nodes(g, 'dir'), dgl.mean_nodes(g, 'h')], dim=1)
        elif self.readout == "directional":
            eig = g.ndata['eig'][:, 1:2].to(self.device)
            g.ndata['dir'] = h * eig / torch.sum(torch.abs(eig), dim=1, keepdim=True)
            hg = torch.cat([torch.abs(dgl.mean_nodes(g, 'dir')), dgl.mean_nodes(g, 'h')], dim=1)
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)
Example #13
    def forward(self, g, h, e, snorm_n, snorm_e):

        # cast features to float for the new dataset
        h = h.float()

        h = self.embedding_lin(h.cuda())
        h_in = h  # for residual connection


        for i in range(self.n_layers):
            h = self.ginlayers[i](g, h, snorm_n)

            # Residual Connection
            if self.residual:
                if self.residual == "gated":
                    z = torch.sigmoid(self.W_g(torch.cat([h, h_in], dim=1)))
                    h = z * h + (torch.ones_like(z) - z) * h_in
                else:
                    h += h_in

        g.ndata['h'] = self.linear_ro(h)
        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.sum_nodes(g, 'h')  # default readout is summation

        score = self.linear_prediction(hg)

        return score
Example #14
    def forward(self, g, h, e, snorm_n, snorm_e):

        # cast features to float for the new dataset
        h = h.float()

        h = self.embedding_lin(h)
        h = self.in_feat_dropout(h)
        for conv in self.layers:
            h_in = h
            h = conv(g, h, snorm_n)
            if self.residual:
                if self.residual == "gated":
                    z = torch.sigmoid(self.W_g(torch.cat([h, h_in], dim=1)))
                    h = z * h + (torch.ones_like(z) - z) * h_in
                else:
                    h += h_in

        g.ndata['h'] = self.linear_ro(h)

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.sum_nodes(g, 'h')  # default readout is summation

        return self.linear_predict(hg)
Example #15
    def forward(self, g, h, e, h_lap_pos_enc=None, h_wl_pos_enc=None):

        # input embedding
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        if self.lap_pos_enc:
            h_lap_pos_enc = self.embedding_lap_pos_enc(h_lap_pos_enc.float())
            h = h + h_lap_pos_enc
        if self.wl_pos_enc:
            h_wl_pos_enc = self.embedding_wl_pos_enc(h_wl_pos_enc)
            h = h + h_wl_pos_enc
        if not self.edge_feat:  # edge feature set to 1
            e = torch.ones(e.size(0), 1).to(self.device)
        e = self.embedding_e(e)

        # convnets
        for conv in self.layers:
            h, e = conv(g, h, e)
        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)
Example #16
    def forward(self, g, h, e, snorm_n, snorm_e):
        h = self.embedding_h(h)

        # computing the 'pseudo' named tensor which depends on node degrees
        us, vs = g.edges()
        # to avoid division by zero when an in-degree is 0, we add a constant 1 to every node degree, denoting a self-loop
        pseudo = [[
            1 / np.sqrt(g.in_degree(us[i]) + 1),
            1 / np.sqrt(g.in_degree(vs[i]) + 1)
        ] for i in range(g.number_of_edges())]
        pseudo = torch.Tensor(pseudo).to(self.device)

        for i in range(len(self.layers)):
            h = self.layers[i](g, h, self.pseudo_proj[i](pseudo), snorm_n)
        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)
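As an aside, the per-edge list comprehension above can be vectorized with tensor indexing. A minimal, behavior-equivalent sketch on an illustrative graph:

    import dgl
    import torch

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2], [1, 2, 0])
    us, vs = g.edges()
    deg = g.in_degrees().float()
    # 1 / sqrt(in_degree + 1) at each endpoint, matching the loop above
    pseudo = torch.stack([(deg[us] + 1.0).rsqrt(), (deg[vs] + 1.0).rsqrt()], dim=1)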
Example #17
    def forward(self, bg, feats):
        """Readout

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of graphs.
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match
              in_feats in initialization

        Returns
        -------
        h_g : FloatTensor of shape (B, 2 * M1)
            * B is the number of graphs in the batch
            * M1 is the input node feature size, which must match
              in_feats in initialization
        """
        h_g_sum = self.weight_and_sum(bg, feats)
        with bg.local_scope():
            bg.ndata['h'] = feats
            h_g_max = dgl.max_nodes(bg, 'h')
        h_g = torch.cat([h_g_sum, h_g_max], dim=1)
        return h_g
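The weighted-sum-plus-max readout above can be reproduced directly with DGL's built-in readout functions. A minimal sketch; the weight field 'w' stands in for the module's learned per-node weighting and is an assumption, not part of the example:

    import dgl
    import torch

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1], [1, 2])
    g.ndata['h'] = torch.randn(3, 8)
    g.ndata['w'] = torch.sigmoid(torch.randn(3, 1))  # hypothetical learned node weights
    h_sum = dgl.sum_nodes(g, 'h', 'w')               # weighted sum readout, shape (1, 8)
    h_max = dgl.max_nodes(g, 'h')                    # max readout, shape (1, 8)
    h_g = torch.cat([h_sum, h_max], dim=1)           # shape (1, 2 * M1) = (1, 16)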
Example #18
    def forward(self, g):
        inputs = g.ndata['feat'].view(-1, 73).float()

        h = self.conv1(g, inputs)
        h = self.relu(h)
        h = self.dropout(h)

        h = self.conv2(g, h)
        h = self.relu(h)
        h = self.dropout(h)

        h = self.conv3(g, h)
        h = self.relu(h)

        g.ndata['h'] = h
        h = dgl.max_nodes(g, 'h')

        # fully connected layers
        h = h.reshape((-1, 73 * 2))
        h = self.fc1(h)
        h = self.sig(h)
        h = self.dropout(h)

        h = self.fc2(h)
        out = self.sig(h)

        return out
Example #19
    def forward(self, g, graph_pooling):
        """
        Forward pass on the graph.
        :param g: The graph
        :param graph_pooling: Binary flag indicating whether the GAT embedding must be pooled (with max) in order to
        produce an output for the whole graph. Otherwise, the outputs are node-dependent.
        :return: prediction of the GAT network
        """

        for l, layer in enumerate(self.embedding_layer[:-1]):
            g = layer(g)
            g.ndata["n_feat"] = torch.relu(g.ndata["n_feat"])
            g.edata["e_feat"] = torch.relu(g.edata["e_feat"])

        last_layer = self.embedding_layer[-1]
        g = last_layer(g)
        g.ndata["n_feat"] = torch.relu(g.ndata["n_feat"])
        g.edata["e_feat"] = torch.relu(g.edata["e_feat"])

        if graph_pooling:
            out = dgl.max_nodes(g, "n_feat")
            for l, layer in enumerate(self.fc_layer):
                out = torch.relu(layer(out))
            out = self.fc_out(out)
            return out

        else:
            for l, layer in enumerate(self.fc_layer):
                g.ndata["n_feat"] = torch.relu(layer(g.ndata["n_feat"]))
            g.ndata["n_feat"] = self.fc_out(g.ndata["n_feat"])

            return g
Example #20
    def forward(self, g):
        h = g.ndata['attr']
        h = h.to(self.device)

        # list of hidden representation at each layer (including input)
        hidden_rep = [h]

        for layer in range(self.num_layers - 1):
            h = self.ginlayers[layer](g, h)
            hidden_rep.append(h)

        score_over_layer = 0

        # perform pooling over all nodes in each graph in every layer
        for layer, h in enumerate(hidden_rep):
            g.ndata['h'] = h
            if self.graph_pooling_type == 'sum':
                pooled_h = dgl.sum_nodes(g, 'h')
            elif self.graph_pooling_type == 'mean':
                pooled_h = dgl.mean_nodes(g, 'h')
            elif self.graph_pooling_type == 'max':
                pooled_h = dgl.max_nodes(g, 'h')
            else:
                raise NotImplementedError()

            score_over_layer += F.dropout(
                self.linears_prediction[layer](pooled_h),
                self.final_dropout,
                training=self.training)

        return score_over_layer
Example #21
    def forward(self, g, h):
        h = self._forward(g, h)
        if self._data_type == 'nc':
            h = self.classifier(h)
            return h
        elif self._data_type in ['gc', 'rg']:
            g.ndata['h'] = h
            if self._readout == "sum":
                hg = dgl.sum_nodes(g, 'h')
            elif self._readout == "max":
                hg = dgl.max_nodes(g, 'h')
            elif self._readout == "mean":
                hg = dgl.mean_nodes(g, 'h')
            else:
                hg = dgl.mean_nodes(g, 'h')
            hg = self.classifier(hg)
            return hg
        elif self._data_type in ['ec']:

            def _edge_feat(edges):
                e = torch.cat([edges.src['h'], edges.dst['h']], dim=1)
                e = self.classifier(e)
                return {'e': e}

            g.ndata['h'] = h
            g.apply_edges(_edge_feat)
            return g.edata['e']
Example #22
    def forward(self, g, node_feats):
        """Computes graph representations out of node features.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feats)
            Input node features, V for the number of nodes.

        Returns
        -------
        graph_feats : float32 tensor of shape (G, graph_feats)
            Graph representations computed. G for the number of graphs.
        """
        node_feats = self.in_project(node_feats)
        if self.activation is not None:
            node_feats = self.activation(node_feats)
        node_feats = self.out_project(node_feats)

        with g.local_scope():
            g.ndata['h'] = node_feats
            if self.mode == 'max':
                graph_feats = dgl.max_nodes(g, 'h')
            elif self.mode == 'mean':
                graph_feats = dgl.mean_nodes(g, 'h')
            elif self.mode == 'sum':
                graph_feats = dgl.sum_nodes(g, 'h')

        return graph_feats
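A minimal, self-contained sketch of the same project-then-pool pattern (the layer sizes and graph are illustrative assumptions):

    import dgl
    import torch
    import torch.nn as nn

    g = dgl.DGLGraph()
    g.add_nodes(4)
    g.add_edges([0, 1, 2], [1, 2, 3])
    node_feats = torch.randn(4, 16)
    # stand-in for in_project / activation / out_project
    project = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 32))
    with g.local_scope():
        g.ndata['h'] = project(node_feats)
        graph_feats = dgl.max_nodes(g, 'h')  # 'max' mode; 'mean' and 'sum' are analogous
    # graph_feats has shape (1, 32): one row per graph (a single graph here)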
Example #23
    def forward(self, bg, feats):
        """Multi-task prediction for a batch of molecules

        Parameters
        ----------
        bg : BatchedDGLGraph
            Batched DGLGraph for processing multiple molecules in parallel, with B the number of molecules
        feats : FloatTensor of shape (N, M0)
            Initial features for all atoms in the batch of molecules

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
            Soft prediction for all tasks on the batch of molecules
        """
        # Update atom features with GNNs
        for gnn in self.gnn_layers:
            feats = gnn(bg, feats)

        # Compute molecule features from atom features
        h_g_sum = self.weighted_sum_readout(bg, feats)

        with bg.local_scope():
            bg.ndata['h'] = feats
            h_g_max = dgl.max_nodes(bg, 'h')

        if not isinstance(bg, BatchedDGLGraph):
            h_g_sum = h_g_sum.unsqueeze(0)
            h_g_max = h_g_max.unsqueeze(0)
        h_g = torch.cat([h_g_sum, h_g_max], dim=1)

        # Multi-task prediction
        return self.soft_classifier(h_g)
Example #24
    def forward(self, bg, feats):

        h_g_sum = self.weight_and_sum(bg, feats)
        with bg.local_scope():
            bg.ndata['h'] = feats
            h_g_max = dgl.max_nodes(bg, 'h')
        h_g = torch.cat([h_g_sum, h_g_max], dim=1)
        return h_g
Example #25
 def forward(self, dgl_data):
     if self.getnode and self.getedge:
         dgl_feat = torch.cat([
             dgl.mean_nodes(dgl_data, 'h'),
             dgl.max_nodes(dgl_data, 'h'),
             dgl.mean_edges(dgl_data, 'h'),
             dgl.max_edges(dgl_data, 'h'),
         ], -1)
     elif self.getnode:
         dgl_feat = torch.cat(
             [dgl.mean_nodes(dgl_data, 'h'),
              dgl.max_nodes(dgl_data, 'h')], -1)
     else:
         dgl_feat = torch.cat(
             [dgl.mean_edges(dgl_data, 'h'),
              dgl.max_edges(dgl_data, 'h')], -1)
     dgl_predict = self.activate(self.weight_node(dgl_feat))
     return dgl_predict
Example #26
 def forward(self, dgl_data):
     dgl_feat, _ = torch.max(
         torch.stack([
             dgl.mean_nodes(dgl_data, 'h'),
             dgl.max_nodes(dgl_data, 'h'),
             dgl.mean_edges(dgl_data, 'h'),
             dgl.max_edges(dgl_data, 'h'),
         ], 2), -1)
     return dgl_feat
Example #27
 def graph_pooling(self, g):
     if self.graph_pooling_type == 'max':
         hg = dgl.max_nodes(g, 'h')
     elif self.graph_pooling_type == 'mean':
         hg = dgl.mean_nodes(g, 'h')
     elif self.graph_pooling_type == 'sum':
         hg = dgl.sum_nodes(g, 'h')
     else:
         # guard against unknown pooling types; 'hg' would otherwise be unbound
         raise NotImplementedError()
     return hg
Example #28
    def forward(self, g, x, e, snorm_n, snorm_e):
        # snorm_n: per-graph normalization factor used when graphs are batched
        # h = self.embedding_h(h)
        # h = self.in_feat_dropout(h)

        h_node = torch.zeros([g.number_of_nodes(), self.node_in_dim]).float().to(self.device)
        h_edge = torch.zeros([g.number_of_edges(), self.h_dim]).float().to(self.device)
        src, dst = g.all_edges()

        for edge_layer, node_layer in zip(self.edge_layers, self.node_layers):
            if self.edge_f:
                if self.dst_f:
                    h_edge = edge_layer(g, src_feat=x[src], dst_feat=x[dst], e_feat=e, h_feat=h_edge, snorm_e=snorm_e)
                    h_node = node_layer(g, src_feat=x[src], dst_feat=x[dst], e_feat=e, h_feat=h_node, snorm_e=snorm_e, n_feat=x)
                else:
                    h_edge = edge_layer(g, src_feat=x[src], e_feat=e, h_feat=h_edge, snorm_e=snorm_e)
                    h_node = node_layer(g, src_feat=x[src], e_feat=e, h_feat=h_node, snorm_e=snorm_e, n_feat=x)
            else:
                if self.dst_f:
                    h_edge = edge_layer(g, src_feat=x[src], dst_feat=x[dst], h_feat=h_edge, snorm_e=snorm_e)
                    h_node = node_layer(g, src_feat=x[src], dst_feat=x[dst], h_feat=h_node, snorm_e=snorm_e, n_feat=x)
                else:
                    h_edge = edge_layer(g, src_feat=x[src], h_feat=h_edge, snorm_e=snorm_e)
                    h_node = node_layer(g, src_feat=x[src], h_feat=h_node, snorm_e=snorm_e, n_feat=x)


        g.edata['h'] = h_edge
        if self.node_update:
            g.ndata['h'] = h_node

        if self.readout == "sum":
            he = dgl.sum_edges(g, 'h')
            hn = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            he = dgl.max_edges(g, 'h')
            hn = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            he = dgl.mean_edges(g, 'h')
            hn = dgl.mean_nodes(g, 'h')
        else:
            he = dgl.mean_edges(g, 'h')  # default readout is mean
            hn = dgl.mean_nodes(g, 'h')

        # graph-level (global) task

        out = self.Global_MLP_layer(torch.cat([he, hn], dim=1))

        # edge-level (transition) task
        edge_out = self.edge_MLPReadout(h_edge)

        return out
Example #29
    def forward(self, graph: dgl.DGLGraph):
        graph.apply_nodes(self.input_node_func)

        for mp_layer in self.mp_layers:
            mp_layer(graph)

        mean_nodes = dgl.mean_nodes(graph, 'feat')
        max_nodes = dgl.max_nodes(graph, 'feat')
        mean_max = torch.cat([mean_nodes, max_nodes], dim=-1)
        return self.output(mean_max)
Example #30
 def readout_fn(readout, graphs, h):
     # 'h' is the name of the node feature field to read out, not a tensor
     if readout == "sum":
         hg = dgl.sum_nodes(graphs, h)
     elif readout == "max":
         hg = dgl.max_nodes(graphs, h)
     elif readout == "mean":
         hg = dgl.mean_nodes(graphs, h)
     else:
         hg = dgl.mean_nodes(graphs, h)  # default readout is mean nodes
     return hg
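A minimal usage sketch for readout_fn on a batched graph (the graphs and the feature field 'h' are illustrative):

    import dgl
    import torch

    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g1.add_edges([0, 1, 2], [1, 2, 0])
    g1.ndata['h'] = torch.randn(3, 4)
    g2 = dgl.DGLGraph()
    g2.add_nodes(2)
    g2.ndata['h'] = torch.randn(2, 4)
    bg = dgl.batch([g1, g2])
    hg = readout_fn("max", bg, 'h')  # shape (2, 4): one row per graph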