Esempio n. 1
    def forward(self, gl, gr):
        """Score a pair of graphs with alternating propagation and cross-attention.

        For ``self.depth`` rounds:
        - each graph sums its in-neighbours' ``x`` features (copy_src + sum);
        - the result is concatenated with the context returned by ``self.attention``;
        - the round's feed-forward layer ``self.ff<idx>`` is applied to both sides.

        Finally each graph is sum-pooled and ``self.ff`` is applied to the
        concatenated pair embedding.

        Args:
            gl: left graph (or batch); input node features in ``ndata["feat"]``.
            gr: right graph (or batch); input node features in ``ndata["feat"]``.

        Returns:
            Output of ``self.ff`` on ``[pooled(gl) || pooled(gr)]``.
        """
        import torch  # local import: this snippet's module imports are not visible

        # local_var() returns graph views whose ndata writes do not leak to the caller.
        gl = gl.local_var()
        gr = gr.local_var()
        xl = gl.ndata["feat"]
        xr = gr.ndata["feat"]
        for idx in range(self.depth):
            # Attention contexts are computed from the features before propagation.
            mu_lr, mu_rl = self.attention(xl, xr)
            gl.ndata["x"] = xl
            gr.ndata["x"] = xr
            gl.update_all(dgl.function.copy_src(src='x', out='m'),
                          dgl.function.sum(msg='m', out='x'))
            gr.update_all(dgl.function.copy_src(src='x', out='m'),
                          dgl.function.sum(msg='m', out='x'))

            # Left side receives mu_rl, right side receives mu_lr (as in the original pairing).
            xl = torch.cat([gl.ndata["x"], mu_rl], dim=-1)
            xr = torch.cat([gr.ndata["x"], mu_lr], dim=-1)

            # One layer per round, shared by both sides.
            ff = getattr(self, "ff%s" % idx)
            xl = ff(xl)
            xr = ff(xr)

        gl.ndata["x"] = xl
        gr.ndata["x"] = xr

        xl = dgl.sum_nodes(gl, "x")
        xr = dgl.sum_nodes(gr, "x")
        return self.ff(torch.cat([xl, xr], dim=-1))
Esempio n. 2
    def forward(self, graph, node_feature, qs, ally_node_type_index=NODE_ALLY):
        """Mix per-ally Q-values into a single value per graph of the batch.

        Args:
            graph: batched DGL graph (checked with an assert).
            node_feature: node features fed to the weight and value networks.
            qs: one Q-value per ally node, positionally aligned with the
                indices returned by ``get_filtered_node_index_by_type``.
            ally_node_type_index: node type treated as an ally.

        Returns:
            1-D tensor, one entry per graph: sum of |w| * q over ally nodes plus
            the sum of ``self.v_ff`` outputs over ally nodes.

        Side effects: temporarily writes ``'node_feature'`` into ``graph.ndata``
        and pops it again after each readout.
        """
        assert isinstance(graph, dgl.BatchedDGLGraph)

        # Non-negative mixing weight for every node: [#nodes x 1].
        weight_embedding = self.w_gn(graph, node_feature)
        node_weights = torch.abs(self.w_ff(graph, weight_embedding))
        ally_idx = get_filtered_node_index_by_type(graph, ally_node_type_index)

        dev = weight_embedding.device
        n_nodes = graph.number_of_nodes()

        # Scatter weighted ally Q-values into a full-size buffer (zeros for
        # non-allies) so one node-sum yields the per-graph total.
        weighted_q = torch.zeros(size=(n_nodes, 1), device=dev)
        weighted_q[ally_idx, :] = node_weights[ally_idx, :] * qs.view(-1, 1)
        graph.ndata['node_feature'] = weighted_q
        q_total = dgl.sum_nodes(graph, 'node_feature')
        graph.ndata.pop('node_feature')

        # State-dependent bias, also summed over ally nodes only.
        value_embedding = self.v_gn(graph, node_feature)
        node_values = self.v_ff(graph, value_embedding)
        ally_values = torch.zeros(size=(n_nodes, 1), device=dev)
        ally_values[ally_idx, :] = node_values[ally_idx, :]
        graph.ndata['node_feature'] = ally_values
        v_total = dgl.sum_nodes(graph, 'node_feature')
        graph.ndata.pop('node_feature')

        return (q_total + v_total).view(-1)
    def forward(self, g, h, e, snorm_n, snorm_e):
        """Embed nodes, run the GIN layers with optional residuals, read out, predict.

        Args:
            g: batched DGL graph; ``g.ndata['h']`` is overwritten with the readout input.
            h: input node features (cast to float32).
            e: edge features (unused here).
            snorm_n: per-node normalisation passed to every GIN layer.
            snorm_e: per-edge normalisation (unused here).

        Returns:
            Output of ``self.linear_prediction`` on the graph-level readout.
        """
        #   modified dtype for new dataset
        h = h.float()

        # Move the input to the device of the embedding layer instead of the
        # hard-coded `.cuda()`, which failed on CPU-only runs.
        # (assumes self.embedding_lin has parameters, e.g. nn.Linear)
        device = next(self.embedding_lin.parameters()).device
        h = self.embedding_lin(h.to(device))
        h_in = h  # residual source: the input embedding, shared by all layers

        for i in range(self.n_layers):
            h = self.ginlayers[i](g, h, snorm_n)

            # Residual connection
            if self.residual:
                if self.residual == "gated":
                    z = torch.sigmoid(self.W_g(torch.cat([h, h_in], dim=1)))
                    h = z * h + (torch.ones_like(z) - z) * h_in
                else:
                    # Out-of-place add: the layer output may be saved for backward.
                    h = h + h_in

        g.ndata['h'] = self.linear_ro(h)
        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.sum_nodes(g, 'h')  # default readout is summation

        return self.linear_prediction(hg)
    def forward(self, g, h, e, snorm_n, snorm_e):
        """Embed nodes, run conv layers with optional per-layer residuals, read out, predict.

        Args:
            g: batched DGL graph; ``g.ndata['h']`` is overwritten with the readout input.
            h: input node features (cast to float32).
            e: edge features (unused here).
            snorm_n: per-node normalisation passed to every conv layer.
            snorm_e: per-edge normalisation (unused here).

        Returns:
            Output of ``self.linear_predict`` on the graph-level readout.
        """
        #   modified dtype for new dataset
        h = h.float()

        h = self.embedding_lin(h)
        h = self.in_feat_dropout(h)
        for conv in self.layers:
            h_in = h  # residual source: this layer's input
            h = conv(g, h, snorm_n)
            if self.residual:
                if self.residual == "gated":
                    z = torch.sigmoid(self.W_g(torch.cat([h, h_in], dim=1)))
                    h = z * h + (torch.ones_like(z) - z) * h_in
                else:
                    # Out-of-place add: the layer output may be saved for backward.
                    h = h + h_in

        g.ndata['h'] = self.linear_ro(h)

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.sum_nodes(g, 'h')  # default readout is summation

        return self.linear_predict(hg)
Esempio n. 5
    def forward(self, g):
        """Classify a heterogeneous graph with 'word' and 'concept' nodes.

        Visible flow:
        - Assign ``self.dropout(...)`` results to the 'word'/'concept' node
          features ``feat`` and to the 'A'/'B'/'C' edge ``weight`` fields.
        - Run ``self.rgcn`` on ``g.ndata['feat']``.
        - Pool ``h`` per node type into ``hg`` inside a local scope.
        - Return ``self.classify(hg)``.

        NOTE(review): this snippet is truncated and does not parse as shown.
        The five ``self.dropout(`` calls have lost their arguments, and the
        ``hg =`` assignment has lost its opening call. The commented-out lines
        below presumably show the dropout inputs (per-type embeddings) — TODO
        restore from the original source.
        """
        g.nodes['word'].data['feat'] = self.dropout(
        g.nodes['concept'].data['feat'] = self.dropout(
        g.edges['A'].data['weight'] = self.dropout(
        g.edges['B'].data['weight'] = self.dropout(
        g.edges['C'].data['weight'] = self.dropout(

        # g.nodes['word'].data['feat'] = self.word_embedding(g.nodes['word'].data['x'])
        # g.nodes['concept'].data['feat'] = self.concept_embedding(g.nodes['concept'].data['x'])
        # g.edges['A'].data['weight'] = self.w_w_embedding(g.edges['A'].data['h'])
        # g.edges['B'].data['weight'] = self.w_c_embedding(g.edges['B'].data['h'])
        # g.edges['C'].data['weight'] = self.c_w_embedding(g.edges['C'].data['h'])

        # Relational GCN over the (heterogeneous) node features.
        h = g.ndata['feat']
        h = self.rgcn(g, h)
        # local_scope(): the 'h' written below is discarded on exit.
        with g.local_scope():
            g.ndata['h'] = h
            # Calculate graph representation by average readout.
            # NOTE(review): the visible 'concept' readout uses dgl.sum_nodes,
            # not a mean — confirm which pooling is intended.
            # hg = 0
            hg =
                g, 'h', ntype='word'), dgl.sum_nodes(g, 'h', ntype='concept')),
            # for ntype in g.ntypes:
            #     hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
            #     r =, 'h', ntype=ntype))
            return self.classify(hg)
Esempio n. 6
    def get_q(self, graph, node_feature, qs, ws=None):
        """Aggregate weighted ally Q-values per graph and cluster, plus a state bias.

        Args:
            graph: batched DGL graph.
            node_feature: per-node features; their per-graph sum feeds ``self.q_b_net``.
            qs: one Q-value per ally node, positionally aligned with the ally indices.
            ws: optional per-ally cluster weights ``[#allies x #clusters]``;
                computed with ``self.get_w`` when omitted.

        Returns:
            Tensor ``[#graphs x #clusters]``: weighted ally Q sums plus the
            state-dependent bias from ``self.q_b_net``.

        Side effects: temporarily writes ``'q'`` and ``'node_feature'`` into
        ``graph.ndata`` and pops both before returning.
        """
        dev = node_feature.device
        ally_idx = get_filtered_node_index_by_type(graph, NODE_ALLY)
        weights = self.get_w(graph, node_feature) if ws is None else ws

        # Full-size per-node buffer; non-ally rows stay zero so they do not
        # contribute to the per-graph sums.
        per_node_q = torch.zeros(size=(graph.number_of_nodes(), self.num_clusters), device=dev)
        per_node_q[ally_idx, :] = qs.view(-1, 1) * weights
        graph.ndata['q'] = per_node_q
        q_sum = dgl.sum_nodes(graph, 'q')

        # Bias from the per-graph sum of all node features.
        graph.ndata['node_feature'] = node_feature
        bias = self.q_b_net(dgl.sum_nodes(graph, 'node_feature'))

        graph.ndata.pop('node_feature')
        graph.ndata.pop('q')
        return q_sum + bias
Esempio n. 7
def test_simple_readout():
    """Check sum/mean/max readouts on one graph and on a two-graph batch.

    ``g1`` has 3 nodes and 3 edges. ``g2`` has 4 nodes and no edges, so its
    rows in the batched edge readouts are expected to be zero vectors.
    """
    g1 = dgl.DGLGraph()
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    # Random inputs, drawn in a fixed order: node feats, edge feats, node weights.
    n1 = F.randn((3, 5))
    n2 = F.randn((4, 5))
    e1 = F.randn((3, 5))
    w1 = F.randn((3, ))
    w2 = F.randn((4, ))

    def weighted_sum(feats, weights):
        return F.sum(feats * F.unsqueeze(weights, 1), 0)

    def weighted_mean(feats, weights):
        return weighted_sum(feats, weights) / F.sum(F.unsqueeze(weights, 1), 0)

    g1.ndata['x'] = n1
    g1.ndata['w'] = w1
    g1.edata['x'] = e1
    g2.ndata['x'] = n2
    g2.ndata['w'] = w2

    # Readouts on g1 alone.
    for got, expected in [
        (dgl.sum_nodes(g1, 'x'), F.sum(n1, 0)),
        (dgl.sum_nodes(g1, 'x', 'w'), weighted_sum(n1, w1)),
        (dgl.sum_edges(g1, 'x'), F.sum(e1, 0)),
        (dgl.mean_nodes(g1, 'x'), F.mean(n1, 0)),
        (dgl.mean_nodes(g1, 'x', 'w'), weighted_mean(n1, w1)),
        (dgl.mean_edges(g1, 'x'), F.mean(e1, 0)),
        (dgl.max_nodes(g1, 'x'), F.max(n1, 0)),
        (dgl.max_edges(g1, 'x'), F.max(e1, 0)),
    ]:
        assert F.allclose(got, expected)

    # Readouts on the batch: one row per graph.
    bg = dgl.batch([g1, g2])
    no_edges = F.zeros(5)
    for got, rows in [
        (dgl.sum_nodes(bg, 'x'), [F.sum(n1, 0), F.sum(n2, 0)]),
        (dgl.mean_nodes(bg, 'x'), [F.mean(n1, 0), F.mean(n2, 0)]),
        (dgl.max_nodes(bg, 'x'), [F.max(n1, 0), F.max(n2, 0)]),
        (dgl.sum_nodes(bg, 'x', 'w'), [weighted_sum(n1, w1), weighted_sum(n2, w2)]),
        (dgl.mean_nodes(bg, 'x', 'w'), [weighted_mean(n1, w1), weighted_mean(n2, w2)]),
        (dgl.sum_edges(bg, 'x'), [F.sum(e1, 0), no_edges]),
        (dgl.mean_edges(bg, 'x'), [F.mean(e1, 0), no_edges]),
        (dgl.max_edges(bg, 'x'), [F.max(e1, 0), no_edges]),
    ]:
        assert F.allclose(got, F.stack(rows, 0))
Esempio n. 8
    def forward(self, graph, node_feature, sub_q_tots):
        """Combine per-group Q totals into one Q per graph, plus a state bias.

        For each group ``i``:
        - mask the node features to the nodes returned by
          ``get_filtered_node_index_by_assignment(graph, i)``;
        - sum the masked features per graph;
        - map the sum through ``self.w``, rectify it with abs or softplus;
        - scale ``sub_q_tots[i]`` by the result and accumulate.

        A bias ``self.v(sum of ally node features)`` is then added. The method
        also prints group statistics.

        NOTE(review): this snippet is truncated and does not parse as shown.
        Several statements have lost parts (flagged inline below) — TODO
        restore from the original source.
        """
        graph.ndata['node_feature'] = node_feature
        device = node_feature.device

        # NOTE(review): len_groups is never appended to in the visible code, so
        # the statistics printed below are computed from an empty list.
        len_groups = []
        q_tot = torch.zeros(graph.batch_size, device=device)
        for i, sub_q_tot in enumerate(sub_q_tots):
            node_indices = get_filtered_node_index_by_assignment(graph, i)


            # 0/1 column mask selecting this group's nodes.
            mask = torch.zeros(size=(node_feature.shape[0], 1), device=device)
            mask[node_indices, :] = 1

            # NOTE(review): the next line lost its left-hand side (presumably
            # `graph.ndata[`), so it is a syntax error as shown.
                'masked_node_feature'] = graph.ndata['node_feature'] * mask
            w_input = dgl.sum_nodes(graph, 'masked_node_feature')
            if self.rectifier == 'abs':
                q_tot = q_tot + torch.abs(self.w(w_input)).view(-1) * sub_q_tot
            elif self.rectifier == 'softplus':
                q_tot = q_tot + F.softplus(
                    self.w(w_input)).view(-1) * sub_q_tot
                # NOTE(review): as written this raise is inside the 'softplus'
                # branch; an `else:` line appears to be missing above it.
                raise RuntimeError("Not implemented rectifier")
            _ = graph.ndata.pop('masked_node_feature')

        # testing
        _ = graph.ndata.pop('node_feature')
        ally_indices = get_filtered_node_index_by_type(graph, NODE_ALLY)
        # NOTE(review): the torch.zeros(...) call below is missing its closing
        # arguments (presumably a device=... and ')').
        _v = torch.zeros(size=(graph.number_of_nodes(), node_feature.shape[1]),
        _v[ally_indices, :] = node_feature[ally_indices, :]
        graph.ndata['node_feature'] = _v
        # testing

        v = self.v(dgl.sum_nodes(graph, 'node_feature')).view(-1)
        q_tot = q_tot + v

        _ = graph.ndata.pop('node_feature')

        len_groups = np.array(len_groups)
        ratio_groups = len_groups / np.sum(len_groups)

        print("Num elements in groups {}".format(len_groups))
        print("Num elements ratio {}".format(ratio_groups))

        # NOTE(review): the two statements below are truncated (open '[' and
        # open '(' with their continuations missing).
        ally_indices = get_filtered_node_index_by_type(graph, NODE_ALLY)
        target_assignment_weight = graph.ndata['normalized_score'][

        print("Average normalized scores {}".format(

        return q_tot
Esempio n. 9
    def _take_action(self, action):
        """Apply one action step to the node-state tensor ``self.x``.

        State codes as used below: 2 marks an undecided node. Nodes set to 1
        count towards the solution. Nodes set to 0 are forced out. (Presumably
        an independent-set style environment — confirm.)

        Steps:
        1. Undecided nodes take their value from ``action``; ``self.t`` advances.
        2. For every node, count in-neighbours with value 1 (copy_src + sum).
           A node at 1 with such a neighbour is reverted to 2.
        3. Undecided nodes next to a 1 are set to 0. At
           ``self.t == self.max_epi_t``, remaining undecided nodes are set to 0.
        4. ``self.epi_t`` is incremented where ``self._check_done()`` is False.
        5. Compute the reward: per-graph count of nodes at 1 minus
           ``self.sol``, plus an optional Hamming term when
           ``hamming_reward_coef > 0`` and ``num_samples == 2``. The reward is
           divided by ``self.max_num_nodes``.

        Returns:
            (reward, next_sol, done)

        NOTE(review): the right-hand side of ``self.g =`` is missing in this
        snippet, so the block does not parse as shown — TODO restore.
        """
        undecided = self.x == 2
        self.x[undecided] = action[undecided]
        self.t += 1

        x1 = (self.x == 1)
        self.g =
        self.g.ndata['h'] = x1.float()
        # Per-node count of in-neighbours currently at 1.
        self.g.update_all(fn.copy_src(src='h', out='m'),
                          fn.sum(msg='m', out='h'))
        x1_deg = self.g.ndata.pop('h')

        ## forgive clashing
        # A node chosen as 1 next to another 1 goes back to undecided.
        clashed = x1 & (x1_deg > 0)
        self.x[clashed] = 2
        x1_deg[clashed] = 0

        # graph clean up
        still_undecided = (self.x == 2)
        self.x[still_undecided & (x1_deg > 0)] = 0

        # fill timeout with zeros
        still_undecided = (self.x == 2)
        timeout = (self.t == self.max_epi_t)
        self.x[still_undecided & timeout] = 0

        done = self._check_done()
        self.epi_t[~done] += 1

        # compute reward and solution
        x1 = (self.x == 1).float()
        node_sol = x1

        h = node_sol
        self.g.ndata['h'] = h
        next_sol = dgl.sum_nodes(self.g, 'h')

        reward = (next_sol - self.sol)

        if self.hamming_reward_coef > 0.0 and self.num_samples == 2:
            # Assumes self.x has one column per sample — TODO confirm.
            xl, xr = self.x.split(1, dim=1)
            undecidedl, undecidedr = undecided.split(1, dim=1)
            hamming_d = torch.abs(xl.float() - xr.float())
            hamming_d[(xl == 2) | (xr == 2)] = 0.0
            # Only nodes decided in this step (in either sample) contribute.
            hamming_d[~undecidedl & ~undecidedr] = 0.0
            self.g.ndata['h'] = hamming_d
            hamming_reward = dgl.sum_nodes(self.g, 'h').expand_as(reward)
            reward += self.hamming_reward_coef * hamming_reward

        reward /= self.max_num_nodes

        return reward, next_sol, done
Esempio n. 10
    def forward(self, g, pos):
        """Pool node states ``'h'`` into three positional embeddings per graph.

        The weight names suggest grandparent (``pos == 0``), parent
        (``pos == 1``) and sibling (``pos == 2``) nodes — confirm with the caller.

        Args:
            g: batched graph whose ``ndata['h']`` holds node states.
            pos: per-node position code compared against 0, 1 and 2.

        Returns:
            ``cat((gp_embed, p_embed, sib_embed), 1)``. The grandparent and
            sibling parts are masked sums divided by the graph's total node
            count. The parent part is the mean over ``pos == 1`` nodes.

        Side effects: leaves ``'a_gp'``, ``'a_p'`` and ``'a_sib'`` in ``g.ndata``.
        """
        # Per-graph node counts as a [#graphs x 1] column on the node-state device.
        normalizer = torch.tensor(g.batch_num_nodes).unsqueeze_(1).float().to(
            g.ndata['h'].device)

        g.ndata['a_gp'] = (pos == 0).float()
        gp_embed = dgl.sum_nodes(g, 'h', 'a_gp') / normalizer
        g.ndata['a_p'] = (pos == 1).float()
        # NOTE(review): a graph with no pos == 1 node divides by zero here.
        p_embed = dgl.mean_nodes(g, 'h', 'a_p')
        g.ndata['a_sib'] = (pos == 2).float()
        sib_embed = dgl.sum_nodes(g, 'h', 'a_sib') / normalizer

        return torch.cat((gp_embed, p_embed, sib_embed), 1)
Esempio n. 11
def test_sum_case1(idtype):
    # NOTE: If you want to update this test case, remember to update the docstring
    #  example too!!!
    ctx = F.ctx()
    g1 = dgl.graph(([0, 1], [1, 0]), idtype=idtype, device=ctx)
    g2 = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=ctx)
    g1.ndata['h'] = F.tensor([1., 2.])
    g2.ndata['h'] = F.tensor([1., 2., 3.])

    bg = dgl.batch([g1, g2])
    bg.ndata['w'] = F.tensor([.1, .2, .1, .5, .2])

    expectations = (
        (dgl.sum_nodes(g1, 'h'), [3.]),              # 1 + 2
        (dgl.sum_nodes(bg, 'h'), [3., 6.]),          # per graph: 1+2, 1+2+3
        (dgl.sum_nodes(bg, 'h', 'w'), [.5, 1.7]),    # .1*1+.2*2, .1*1+.5*2+.2*3
    )
    for got, want in expectations:
        assert F.allclose(F.tensor(want), got)
Esempio n. 12
def test_simple_readout():
    """Check sum/mean node and edge readouts on one graph and on a batch.

    ``g1`` has 3 nodes and 3 edges. ``g2`` has 4 nodes and no edges, so its
    batched edge readouts are expected to be zero vectors.
    """
    g1 = dgl.DGLGraph()
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    # Random inputs, drawn in a fixed order: node feats, edge feats, node weights.
    n1 = th.randn(3, 5)
    n2 = th.randn(4, 5)
    e1 = th.randn(3, 5)
    w1 = th.randn(3)
    w2 = th.randn(4)

    def weighted_sum(feats, weights):
        return (feats * weights[:, None]).sum(0)

    def weighted_mean(feats, weights):
        return weighted_sum(feats, weights) / weights[:, None].sum(0)

    g1.ndata['x'] = n1
    g1.ndata['w'] = w1
    g1.edata['x'] = e1
    g2.ndata['x'] = n2
    g2.ndata['w'] = w2

    # Readouts on g1 alone.
    for got, expected in [
        (dgl.sum_nodes(g1, 'x'), n1.sum(0)),
        (dgl.sum_nodes(g1, 'x', 'w'), weighted_sum(n1, w1)),
        (dgl.sum_edges(g1, 'x'), e1.sum(0)),
        (dgl.mean_nodes(g1, 'x'), n1.mean(0)),
        (dgl.mean_nodes(g1, 'x', 'w'), weighted_mean(n1, w1)),
        (dgl.mean_edges(g1, 'x'), e1.mean(0)),
    ]:
        assert U.allclose(got, expected)

    # Readouts on the batch: one row per graph.
    bg = dgl.batch([g1, g2])
    for got, rows in [
        (dgl.sum_nodes(bg, 'x'), [n1.sum(0), n2.sum(0)]),
        (dgl.mean_nodes(bg, 'x'), [n1.mean(0), n2.mean(0)]),
        (dgl.sum_nodes(bg, 'x', 'w'), [weighted_sum(n1, w1), weighted_sum(n2, w2)]),
        (dgl.mean_nodes(bg, 'x', 'w'), [weighted_mean(n1, w1), weighted_mean(n2, w2)]),
        (dgl.sum_edges(bg, 'x'), [e1.sum(0), th.zeros(5)]),
        (dgl.mean_edges(bg, 'x'), [e1.mean(0), th.zeros(5)]),
    ]:
        assert U.allclose(got, th.stack(rows, 0))
Esempio n. 13
def euclidean_matrix(graphs, dims, readout='sum'):
    '''Returns the pairwise euclidean distance between readout feature from all graphs.

    The graphs are batched and their node field ``'h'`` is pooled into one
    vector per graph. ``dist_matrix`` is then applied separately to each
    slice of that vector delimited by ``dims``.

    Parameters
    ----------
    graphs : list of dgl graphs
    dims : list (or tuple) of int
        graph features are concatenation of features obtained from all iterations, and this variable has
        the individual feature dimensions for the iterations.
    readout : str
        ``'sum'`` or ``'mean'`` pooling of the node features.

    Returns
    -------
    list
        One ``dist_matrix(features, features)`` result per entry of ``dims``.

    Raises
    ------
    ValueError
        If ``readout`` is neither ``'sum'`` nor ``'mean'``.
    '''
    # Validate before doing any work (the original raise was unreachable and
    # its message named the wrong function).
    if readout == 'sum':
        pool = dgl.sum_nodes
    elif readout == 'mean':
        pool = dgl.mean_nodes
    else:
        raise ValueError('Readout for euclidean_matrix shall be either "mean" or "sum"')

    graph_reprs = pool(dgl.batch(graphs), 'h')
    # Slice boundaries [0, d0, d0+d1, ...]; list() also accepts tuples.
    bounds = np.cumsum([0] + list(dims))

    distances = []
    with torch.no_grad():
        for dim_start, dim_end in zip(bounds, bounds[1:]):
            features = graph_reprs[:, dim_start:dim_end]
            matrix = dist_matrix(features, features)
            distances.append(matrix)  # previously computed but never collected
    return distances
Esempio n. 14
def gram_matrix(graphs, dims, dist_fn, readout='sum'):
    '''Wrapper function to compute the gram matrix on graphs in batch, returns list of gram matrices as numpy.array

    The graphs are batched and their node field ``'h'`` is pooled into one
    vector per graph. ``dist_fn`` is then applied separately to each slice
    of that vector delimited by ``dims``.

    Parameters
    ----------
    graphs : list of dgl graphs
    dims : list (or tuple) of int
        graph features are concatenation of features obtained from all iterations, and this variable has
        the individual feature dimensions for the iterations.
    dist_fn : callable
        Called as ``dist_fn(features, features)`` for each slice.
    readout : str
        ``'sum'`` or ``'mean'`` pooling of the node features.

    Returns
    -------
    list
        One matrix per entry of ``dims``. Torch tensors returned by
        ``dist_fn`` are converted to numpy arrays.

    Raises
    ------
    ValueError
        If ``readout`` is neither ``'sum'`` nor ``'mean'``.
    '''
    # Validate first; the original raise was unreachable due to a missing `else:`.
    if readout == 'sum':
        pool = dgl.sum_nodes
    elif readout == 'mean':
        pool = dgl.mean_nodes
    else:
        raise ValueError('Readout for gram_matrix shall be either "mean" or "sum"')

    graph_reprs = pool(dgl.batch(graphs), 'h')
    # Slice boundaries [0, d0, d0+d1, ...]; list() also accepts tuples.
    bounds = np.cumsum([0] + list(dims))

    distances = []
    with torch.no_grad():
        for dim_start, dim_end in zip(bounds, bounds[1:]):
            features = graph_reprs[:, dim_start:dim_end]
            # Named `gram` so the local does not shadow this function.
            gram = dist_fn(features, features)
            if torch.is_tensor(gram):
                gram = gram.cpu().numpy()  # honour the documented numpy return type
            distances.append(gram)  # previously computed but never collected
    return distances
Esempio n. 15
    def forward(self, graph, edge_feat, node_feat, g_repr):
        """One graph-network block: message passing, then a global update.

        Visible flow:
        - Store the inputs on ``graph``.
        - Run ``update_all`` with ``self.graph_message_func`` /
          ``self.graph_reduce_func``.
        - Sum edge and node features per graph.
        - Read back the updated per-edge/per-node features.
        - Return them together with ``self.compute_u_repr(n_comb, e_comb, g_repr)``.

        NOTE(review): this snippet is truncated and does not parse as shown.
        The ``update_all`` call has lost its last argument(s). The two ``for``
        loops have lost their bodies (presumably popping each key from
        ``graph.edata`` / ``graph.ndata``). ``edge_trf_func`` is not used by
        any visible line — TODO restore from the original source.
        """
        # Node apply function; closes over graph and g_repr.
        node_trf_func = lambda x: self.compute_node_repr(
            nodes=x, graph=graph, g_repr=g_repr)

        graph.edata['edge_feat'] = edge_feat
        graph.ndata['node_feat'] = node_feat
        edge_trf_func = lambda x: self.compute_edge_repr(
            edges=x, graph=graph, g_repr=g_repr)

        # NOTE(review): call below is missing its closing argument(s).
        graph.update_all(self.graph_message_func, self.graph_reduce_func,

        e_comb = dgl.sum_edges(graph, 'edge_feat')
        n_comb = dgl.sum_nodes(graph, 'node_feat')

        e_out = graph.edata['edge_feat']
        n_out = graph.ndata['node_feat']

        # Key lists are copied first, presumably so the (missing) loop bodies
        # can remove entries while iterating.
        e_keys = list(graph.edata.keys())
        n_keys = list(graph.ndata.keys())
        for key in e_keys:
        for key in n_keys:

        return e_out, n_out, self.compute_u_repr(n_comb, e_comb, g_repr)
Esempio n. 16
    def forward(self, g, node_feats):
        r"""Computes graph representations out of node representations.

        When ``self.gaussian_expand`` is set:
        - node features are first expanded with ``self.gaussian_histogram``;
        - the pooled result is projected with ``self.to_out``;
        - the optional ``self.activation`` is then applied.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.

        Returns
        -------
        g_feats : float32 tensor of shape (G, node_in_feats)
            Output graph representations. G for the number of graphs in the batch.
            With ``gaussian_expand`` the width is that of ``self.to_out``.
        """
        if self.gaussian_expand:
            node_feats = self.gaussian_histogram(node_feats)

        # local_scope(): the temporary 'h' field is discarded on exit.
        with g.local_scope():
            g.ndata['h'] = node_feats
            g_feats = dgl.sum_nodes(g, 'h')

        if self.gaussian_expand:
            g_feats = self.to_out(g_feats)
            if self.activation is not None:
                g_feats = self.activation(g_feats)

        return g_feats
Esempio n. 17
    def forward(self, bg, feat):
        """Set2Set readout over the node field ``feat`` of a batched graph.

        Each of ``self.processing_steps`` iterations:
        1. Feeds the previous ``q_star`` to ``self.lstm`` to get a query per graph.
        2. Scores every node by the dot product of its features with its
           graph's query.
        3. Applies a softmax over the nodes of each graph.
        4. Takes the attention-weighted node sum ``r``.
        5. Forms the next ``q_star`` by concatenating the query and ``r``.

        Args:
            bg: batched DGL graph; ``bg.ndata['w']`` is left holding the last
                step's attention weights.
            feat: name of the node field to read out.

        Returns:
            ``q_star``: concatenation of the final query and readout, one row per graph.
        """
        batch_size = bg.batch_size
        x = bg.ndata[feat]
        batch_num_nodes = list(bg.batch_num_nodes)

        # Graph id of every node, e.g. [0, 0, 1, 1, 1] for sizes [2, 3]. One
        # call replaces the original quadratic chain of per-graph torch.cat.
        batch = torch.repeat_interleave(
            torch.arange(batch_size, dtype=torch.int64, device=x.device),
            torch.tensor(batch_num_nodes, dtype=torch.int64, device=x.device))

        h = (x.new_zeros((self.num_layers, batch_size, self.in_channels)),
             x.new_zeros((self.num_layers, batch_size, self.in_channels)))
        q_star = x.new_zeros(batch_size, self.out_channels)

        for _ in range(self.processing_steps):
            q, h = self.lstm(q_star.unsqueeze(0), h)
            q = q.view(batch_size, self.in_channels)
            # Attention logit per node: <x_v, q_graph(v)>.
            e = (x * q[batch]).sum(dim=-1, keepdim=True)
            # Softmax restricted to each graph's own contiguous node segment.
            a = torch.cat([softmax(segment, dim=0)
                           for segment in torch.split(e, batch_num_nodes)], dim=0)
            bg.ndata['w'] = a
            r = dgl.sum_nodes(bg, feat, 'w')
            q_star = torch.cat([q, r], dim=-1)

        return q_star
Esempio n. 18
    # Embed node features, run the conv layers, store the result in
    # g.ndata['h'], then either return the graph or a graph-level readout.
    #
    # NOTE(review): this snippet is truncated and does not parse as shown.
    # - The parameter list after `self,` is missing. The body uses g, h,
    #   snorm_n, return_graph, mlp and head, which are presumably parameters
    #   — TODO restore the signature.
    # - The duplicated mean_nodes line under the "mean" branch and the
    #   unreachable lines after `return self.MLP_layer(hg)` suggest that
    #   `else:` lines were lost. As nested here, `mlp=False` falls through
    #   and the method returns None.
    def forward(self,
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        for conv in self.layers:
            h = conv(g, h, snorm_n)
        g.ndata['h'] = h

        if return_graph:
            return g

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        if mlp:
            return self.MLP_layer(hg)
            if head:
                return self.projection_head(hg)
                return hg
Esempio n. 19
    def forward(self, g, node_feats):
        """Computes graph representations out of node features.

        Node features pass through ``self.in_project``, the optional
        ``self.activation`` and ``self.out_project``. They are then pooled per
        graph with the reduction named by ``self.mode``.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feats)
            Input node features, V for the number of nodes.

        Returns
        -------
        graph_feats : float32 tensor of shape (G, graph_feats)
            Graph representations computed. G for the number of graphs.

        Raises
        ------
        ValueError
            If ``self.mode`` is not ``'max'``, ``'mean'`` or ``'sum'``
            (previously this surfaced as an UnboundLocalError).
        """
        readouts = {'max': dgl.max_nodes, 'mean': dgl.mean_nodes, 'sum': dgl.sum_nodes}
        if self.mode not in readouts:
            raise ValueError(
                "Expected mode to be 'max', 'mean' or 'sum', got {!r}".format(self.mode))

        node_feats = self.in_project(node_feats)
        if self.activation is not None:
            node_feats = self.activation(node_feats)
        node_feats = self.out_project(node_feats)

        # local_scope(): the temporary 'h' field is discarded on exit.
        with g.local_scope():
            g.ndata['h'] = node_feats
            graph_feats = readouts[self.mode](g, 'h')

        return graph_feats
    def forward(self, g, h, e):
        """Embed, convolve, read out with node weights, and predict.

        Configured for "reduced" graphs: every conv layer receives the edge
        features ``e``, and the readout is weighted by ``g.ndata['weight']``.
        For the original, unreduced graphs the earlier code instead used
        ``conv(g, h)`` and an unweighted readout.

        Args:
            g: batched DGL graph; ``g.ndata['h']`` is overwritten with the final node states.
            h: input node features, embedded with ``self.embedding_h``.
            e: edge features passed to every conv layer.

        Returns:
            Output of ``self.MLP_layer`` on the graph readout.
        """
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        for conv in self.layers:
            h = conv(g, h, e)
        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, feat='h', weight='weight')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, feat='h', weight='weight')
        else:
            # "mean" and the default readout are both the weighted mean; the
            # lost `else:` previously left hg unbound for unknown readouts.
            hg = dgl.mean_nodes(g, feat='h', weight='weight')
        return self.MLP_layer(hg)
Esempio n. 21
    def forward(self, g, h, e, pos_enc=None):
        """MoNet forward pass with optional positional-encoding input.

        Args:
            g: batched DGL graph with ``edata['pseudo']``. ``ndata['deg']`` and
                ``ndata['h']`` are written here.
            h: input node features; used only when ``self.pos_enc`` is falsy.
            e: edge features (unused here).
            pos_enc: positional encodings; used instead of ``h`` when ``self.pos_enc``.

        Returns:
            Output of ``self.MLP_layer`` on the graph readout.
        """
        # input embedding: positional encoding OR node features. A lost
        # `else:` previously fed the pos-enc embedding through embedding_h.
        if self.pos_enc:
            h = self.embedding_pos_enc(pos_enc)
        else:
            h = self.embedding_h(h)

        # computing the 'pseudo' named tensor which depends on node degrees
        g.ndata['deg'] = g.in_degrees()
        pseudo = g.edata['pseudo'].to(self.device).float()

        for i, layer in enumerate(self.layers):
            h = layer(g, h, self.pseudo_proj[i](pseudo))
        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)
Esempio n. 22
    def forward(self, g, h, e, snorm_n, snorm_e):
        """Embed, run conv layers with optional virtual-node layers, read out, predict.

        Args:
            g: batched DGL graph; ``g.ndata['h']`` is overwritten with the final node states.
            h: input node features, embedded with ``self.embedding_h``.
            e: edge features, embedded with ``self.embedding_e`` only when ``self.edge_feat``.
            snorm_n: per-node normalisation passed to every conv layer.
            snorm_e: unused here.

        Returns:
            Output of ``self.MLP_layer`` on the graph readout.
        """
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        if self.edge_feat:
            e = self.embedding_e(e)

        # Virtual-node state, threaded through the first
        # len(self.virtual_node_layers) conv layers.
        vn_h = 0
        for i, conv in enumerate(self.layers):
            # Graph conv layers
            h = conv(g, h, e, snorm_n)

            # Virtual node layer; call the module (not .forward) so hooks run.
            if self.virtual_node_layers is not None and i < len(self.virtual_node_layers):
                vn_h, h = self.virtual_node_layers[i](g, h, vn_h)

        g.ndata['h'] = h

        # Readout layer
        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)
Esempio n. 23
    def forward(self, g):
        """Predict one value per graph by summing per-node outputs.

        Node embeddings ``node_0`` .. ``node_{n_conv}`` are read from
        ``g.ndata``; presumably they are written by the embedding layer and
        the conv layers. They are concatenated along the feature axis and
        passed through two dense layers. When ``self.atom_ref`` is set, the
        result is shifted by ``g.ndata['e0']``. When ``self.norm`` is set, it is
        de-normalised. Finally it is summed per graph.

        Side effects: writes ``node``, ``res`` (and presumably ``node_*`` /
        ``e0`` via the called layers) into ``g.ndata``.
        """
        import torch  # local import: this snippet's module imports are not visible

        self.embedding_layer(g, "node_0")
        if self.atom_ref is not None:
            self.e0(g, "e0")

        # NOTE(review): the source snippet had an unexplained gap here. If the
        # conv layers expect edge/RBF features to be prepared first, confirm
        # that no setup call is missing.

        for idx in range(self.n_conv):
            self.conv_layers[idx](g, idx + 1)

        # concat multilevel representations
        node_embeddings = tuple(g.ndata["node_%d" % i] for i in range(self.n_conv + 1))
        g.ndata["node"] = torch.cat(node_embeddings, 1)

        node = self.node_dense_layer1(g.ndata["node"])
        node = self.activation(node)
        res = self.node_dense_layer2(node)
        g.ndata["res"] = res

        if self.atom_ref is not None:
            g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]

        if self.norm:
            g.ndata["res"] = g.ndata["res"] * self.std_per_node + self.mean_per_node
        return dgl.sum_nodes(g, "res")
Esempio n. 24
    def forward(self, g, node_feats, g_feats, get_node_weight=False):
        """Update graph features with one round of attention-weighted readout.

        Parameters
        ----------
        g : DGLGraph or BatchedDGLGraph
            Constructed DGLGraphs.
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.
        g_feats : float32 tensor of shape (G, N2)
            Input graph features. G for the number of graphs and N2 for the feature size.
        get_node_weight : bool
            Whether to get the weights of atoms during readout.

        Returns
        -------
        float32 tensor of shape (G, N2)
            Updated graph features.
        float32 tensor of shape (V, 1)
            The weights of nodes in readout. Returned (as the second element
            of a tuple) only when ``get_node_weight`` is True.
        """
        import torch  # local import: this snippet's module imports are not visible

        with g.local_scope():
            # Per-node logit from [relu(feature of the node's graph) || node feature].
            g.ndata['z'] = self.compute_logits(
                torch.cat([dgl.broadcast_nodes(g, F.relu(g_feats)), node_feats], dim=1))
            # Softmax of the logits over the nodes of each graph.
            g.ndata['a'] = dgl.softmax_nodes(g, 'z')
            g.ndata['hv'] = self.project_nodes(node_feats)
            # Attention-weighted node sum per graph.
            context = F.elu(dgl.sum_nodes(g, 'hv', 'a'))

            updated = self.gru(context, g_feats)
            # A lost `else:` previously made this return None when
            # get_node_weight was False.
            if get_node_weight:
                return updated, g.ndata['a']
            return updated
Esempio n. 25
    def forward(self, graph, edge_feat, node_feat, g_repr, edge_hidden, node_hidden, graph_hidden):
        """Graph-network block with recurrent hidden states on edges, nodes and graph.

        Visible flow:
        - Store features and the first slice (``[i][0]``) of each hidden-state
          pair on ``graph``.
        - Run ``update_all`` with the node apply function.
        - Sum edge/node features per graph, and compute the global output and
          hidden state with ``self.compute_u_repr``.
        - Return the updated features, with the hidden states re-wrapped as
          tuples unsqueezed on dim 0.

        NOTE(review): this snippet is truncated and does not parse as shown.
        Both ``for key in ...`` loops have lost their bodies (presumably
        removing each key from ``graph.edata`` / ``graph.ndata``).
        ``edge_trf_func`` is not used by any visible line — TODO restore from
        the original source.
        """

        graph.edata['edge_feat'] = edge_feat
        graph.ndata['node_feat'] = node_feat
        graph.edata['hidden1'] = edge_hidden[0][0]
        graph.ndata['hidden1'] = node_hidden[0][0]
        graph.edata['hidden2'] = edge_hidden[1][0]
        graph.ndata['hidden2'] = node_hidden[1][0]

        node_trf_func = lambda x : self.compute_node_repr(nodes=x, graph=graph, g_repr=g_repr)
        edge_trf_func = lambda x: self.compute_edge_repr(edges=x, graph=graph, g_repr=g_repr)
        graph.update_all(self.graph_message_func, self.graph_reduce_func, node_trf_func)

        e_comb = dgl.sum_edges(graph, 'edge_feat')
        n_comb = dgl.sum_nodes(graph, 'node_feat')

        u_out, u_hidden = self.compute_u_repr(n_comb, e_comb, g_repr, graph_hidden)

        e_feat = graph.edata['edge_feat']
        n_feat = graph.ndata['node_feat']

        # Re-add the leading dim that the [i][0] indexing above removed.
        h_e = (torch.unsqueeze(graph.edata['hidden1'],0),torch.unsqueeze(graph.edata['hidden2'],0))
        h_n =  (torch.unsqueeze(graph.ndata['hidden1'],0),torch.unsqueeze(graph.ndata['hidden2'],0))

        e_keys = list(graph.edata.keys())
        n_keys = list(graph.ndata.keys())
        for key in e_keys:
        for key in n_keys:

        return e_feat, h_e, n_feat, h_n, u_out, u_hidden
Esempio n. 26
    def forward(self, g):
        """GIN graph classifier: pool every layer's node states and sum the scores.

        Visible flow:
        - Read ``g.ndata['attr']``.
        - Apply ``self.num_layers - 1`` GIN layers.
        - For each entry of ``hidden_rep``, pool per graph with
          ``self.graph_pooling_type`` ('sum' | 'mean' | 'max').
        - Accumulate a dropout-ed per-layer score.

        NOTE(review): this snippet is truncated and does not parse as shown.
        - ``h =`` has lost its right-hand side.
        - Nothing visible appends to ``hidden_rep`` inside the layer loop.
        - The ``raise`` under 'max' appears to have lost its ``else:`` line.
        - The ``F.dropout(`` call has lost its arguments.
        TODO restore from the original source.
        """
        h = g.ndata['attr']
        h =

        # list of hidden representation at each layer (including input)
        hidden_rep = [h]

        for layer in range(self.num_layers - 1):
            h = self.ginlayers[layer](g, h)

        score_over_layer = 0

        # perform pooling over all nodes in each graph in every layer
        for layer, h in enumerate(hidden_rep):
            g.ndata['h'] = h
            if self.graph_pooling_type == 'sum':
                pooled_h = dgl.sum_nodes(g, 'h')
            elif self.graph_pooling_type == 'mean':
                pooled_h = dgl.mean_nodes(g, 'h')
            elif self.graph_pooling_type == 'max':
                pooled_h = dgl.max_nodes(g, 'h')
                raise NotImplementedError()

            score_over_layer += F.dropout(

        return score_over_layer
Esempio n. 27
    def forward(self, g, h, e, h_lap_pos_enc=None, h_wl_pos_enc=None):
        """Embed nodes/edges (plus optional positional encodings), convolve, read out, predict.

        Args:
            g: batched DGL graph; ``g.ndata['h']`` is overwritten with the final node states.
            h: input node features, embedded with ``self.embedding_h``.
            e: edge features. When ``self.edge_feat`` is falsy they are
                replaced by a column of ones before ``self.embedding_e``.
            h_lap_pos_enc: Laplacian positional encodings, added when ``self.lap_pos_enc``.
            h_wl_pos_enc: WL positional encodings, added when ``self.wl_pos_enc``.

        Returns:
            Output of ``self.MLP_layer`` on the graph readout.
        """
        # input embedding
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        if self.lap_pos_enc:
            h = h + self.embedding_lap_pos_enc(h_lap_pos_enc.float())
        if self.wl_pos_enc:
            h = h + self.embedding_wl_pos_enc(h_wl_pos_enc)
        if not self.edge_feat:  # edge feature set to 1
            e = torch.ones(e.size(0), 1, device=self.device)
        e = self.embedding_e(e)

        # convnets
        for conv in self.layers:
            h, e = conv(g, h, e)
        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)
Esempio n. 28
    def forward(self, g):
        """Predict one value per graph by summing per-node regressor outputs.

        Visible flow:
        - Reshape ``edata['distance']`` to a column.
        - Optionally compute ``e0``.
        - Run the conv layers.
        - Pass ``g.ndata["node"]`` through two activated dense layers and
          ``self.regressor``.
        - Optionally add ``e0`` and de-normalise.
        - Sum per graph.

        NOTE(review): the ``for idx in range(self.n_conv):`` loop has lost its
        body in this snippet, so the block does not parse as shown — TODO
        restore (presumably a call to the idx-th conv layer).
        """
        # g_list list of molecules

        g.edata['distance'] = g.edata['distance'].reshape(-1, 1)

        if self.atom_ref is not None:
            self.e0(g, "e0")
        for idx in range(self.n_conv):

        atom = self.atom_dense_layer1(g.ndata["node"])
        atom = self.activation(atom)
        atom = self.activation(self.atom_dense_layer2(atom))
        res = self.regressor(atom)

        g.ndata["res"] = res

        if self.atom_ref is not None:
            g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]

        # Undo target normalisation: res * std + mean.
        if self.norm:
            g.ndata["res"] = g.ndata[
                "res"] * self.std_per_atom + self.mean_per_atom
        res = dgl.sum_nodes(g, "res")
        return res
Esempio n. 29
    def forward(self, g, h, e, snorm_n, snorm_e):
        """Embed nodes, run conv + joining layers, read out, predict.

        After every conv layer, the matching joining layer is applied to the
        sum of the initial embedding and the conv output.

        Args:
            g: batched DGL graph; ``g.ndata['h']`` is overwritten with the final node states.
            h: input node features, embedded with ``self.embedding_h``.
            e: edge features (unused here).
            snorm_n: per-node normalisation passed to every conv layer.
            snorm_e: unused here.

        Returns:
            Output of ``self.MLP_layer`` on the graph readout.
        """
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        h_init = h  # added back before every joining layer

        for i in range(self.layer_count):
            h = self.layers[i](g, h, snorm_n)
            h = self.joining_layers[i](h_init + h)

        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)
Esempio n. 30
    def forward(self, g, h, e, snorm_n, snorm_e):
        """MoNet forward pass with degree-based pseudo-coordinates.

        Args:
            g: batched DGL graph; ``g.ndata['h']`` is overwritten with the final node states.
            h: input node features, embedded with ``self.embedding_h``.
            e: edge features (unused here).
            snorm_n: per-node normalisation passed to every MoNet layer.
            snorm_e: unused here.

        Returns:
            Output of ``self.MLP_layer`` on the graph readout.
        """
        h = self.embedding_h(h)

        # Pseudo-coordinates per edge u -> v:
        #   (1 / sqrt(in_deg(u) + 1), 1 / sqrt(in_deg(v) + 1)).
        # The +1 (an implicit self-loop) avoids division by zero for nodes with
        # in-degree 0. Vectorised: the original per-edge Python loop did
        # O(E) interpreter work.
        us, vs = g.edges()
        deg = g.in_degrees().float() + 1
        pseudo = torch.stack([deg[us].rsqrt(), deg[vs].rsqrt()], dim=1).to(self.device)

        for i, layer in enumerate(self.layers):
            h = layer(g, h, self.pseudo_proj[i](pseudo), snorm_n)
        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        return self.MLP_layer(hg)