Example #1
0
    def forward(self, gl, gr):
        """Score a pair of graphs with cross-graph attention and message passing.

        Each of the ``self.depth`` rounds:

        1. computes ``mu_lr, mu_rl = self.attention(h_left, h_right)``;
        2. sum-aggregates neighbour features on each graph;
        3. concatenates the attention term of the *other* side;
        4. applies the round's module ``ff<idx>`` to both sides.

        Both graphs are then sum-pooled, and the concatenated pooled vectors go
        through ``self.ff``.

        Parameters
        ----------
        gl, gr : DGLGraph
            Left/right graphs with node features in ``ndata["feat"]``. Work is
            done on ``local_var()`` copies, so the caller's graphs keep their
            original features.

        Returns
        -------
        Output of ``self.ff`` on the concatenated graph-level readouts.
        """
        left = gl.local_var()
        right = gr.local_var()
        h_left = left.ndata["feat"]
        h_right = right.ndata["feat"]

        for depth_idx in range(self.depth):
            mu_lr, mu_rl = self.attention(h_left, h_right)
            ffn = getattr(self, "ff%s" % depth_idx)

            left.ndata["x"] = h_left
            right.ndata["x"] = h_right
            # Sum of neighbour features over incoming edges, per graph.
            for graph in (left, right):
                graph.update_all(dgl.function.copy_src(src='x', out='m'),
                                 dgl.function.sum(msg='m', out='x'))

            # Each side is enriched with the attention term of the other side.
            h_left = ffn(torch.cat([left.ndata["x"], mu_rl], dim=-1))
            h_right = ffn(torch.cat([right.ndata["x"], mu_lr], dim=-1))

        left.ndata["x"] = h_left
        right.ndata["x"] = h_right
        pooled = [dgl.sum_nodes(left, "x"), dgl.sum_nodes(right, "x")]
        return self.ff(torch.cat(pooled, dim=-1))
Example #2
0
    def forward(self, graph, node_feature, qs, ally_node_type_index=NODE_ALLY):
        """Mix per-ally Q values into one value per graph in the batch.

        The result is ``sum_allies(|w_i| * q_i) + sum_allies(v_i)``:

        - ``w`` comes from ``self.w_gn`` followed by ``self.w_ff`` (made
          non-negative with ``abs``).
        - ``v`` comes from ``self.v_gn`` followed by ``self.v_ff``.

        Nodes that are not of type ``ally_node_type_index`` contribute zero.

        Parameters
        ----------
        graph : dgl.BatchedDGLGraph
            Batched graph. ``ndata['node_feature']`` is used as scratch space
            and popped again.
        node_feature : tensor
            Node features fed to both the ``w`` and ``v`` networks.
        qs : tensor
            One Q value per ally node (reshaped with ``view(-1, 1)``).
        ally_node_type_index : node-type id, default ``NODE_ALLY``

        Returns
        -------
        tensor
            Flattened per-graph totals.
        """
        assert isinstance(graph, dgl.BatchedDGLGraph)

        num_nodes = graph.number_of_nodes()

        # Mixing weights, kept non-negative.
        w_emb = self.w_gn(graph, node_feature)  # [# nodes x # node_dim]
        w = torch.abs(self.w_ff(graph, w_emb))  # [# nodes x # 1]
        ally_idx = get_filtered_node_index_by_type(graph, ally_node_type_index)
        device = w_emb.device

        # Scatter |w| * q onto the ally rows of an all-node zero tensor, then sum per graph.
        weighted_q = torch.zeros(size=(num_nodes, 1), device=device)
        weighted_q[ally_idx, :] = w[ally_idx, :] * qs.view(-1, 1)
        graph.ndata['node_feature'] = weighted_q
        q_tot = dgl.sum_nodes(graph, 'node_feature')
        graph.ndata.pop('node_feature')

        # State-value term, also restricted to ally nodes.
        v_emb = self.v_gn(graph, node_feature)  # [# nodes x # node_dim]
        v = self.v_ff(graph, v_emb)  # [# nodes x # 1]
        ally_v = torch.zeros(size=(num_nodes, 1), device=device)
        ally_v[ally_idx, :] = v[ally_idx, :]
        graph.ndata['node_feature'] = ally_v
        v_tot = dgl.sum_nodes(graph, 'node_feature')
        graph.ndata.pop('node_feature')

        return (q_tot + v_tot).view(-1)
    def forward(self, g, h, e, snorm_n, snorm_e):
        """Encode node features with GIN layers and read out a graph-level score.

        Parameters
        ----------
        g : DGLGraph
            (Batched) input graph. ``g.ndata['h']`` is overwritten with the
            projected node features used for readout.
        h : tensor
            Input node features. They are cast to float and moved to the
            model's device before embedding.
        e, snorm_e :
            Unused; accepted for interface compatibility.
        snorm_n : tensor
            Forwarded to every GIN layer.

        Returns
        -------
        tensor
            ``self.linear_prediction`` applied to the pooled graph features.
            ``self.readout`` selects ``"max"`` or ``"mean"``; any other value
            (including ``"sum"``) uses summation.
        """
        # modified dtype for new dataset
        h = h.float()

        # Follow the model's device instead of hard-coding ``.cuda()``, which
        # failed on CPU-only machines.
        param = next(self.parameters(), None)
        if param is not None:
            h = h.to(param.device)

        h = self.embedding_lin(h)
        # NOTE: h_in is fixed to the input embedding, so every layer's residual
        # adds the *input* embedding (not the previous layer's output).
        h_in = h

        for i in range(self.n_layers):
            h = self.ginlayers[i](g, h, snorm_n)

            if self.residual:
                if self.residual == "gated":
                    z = torch.sigmoid(self.W_g(torch.cat([h, h_in], dim=1)))
                    h = z * h + (torch.ones_like(z) - z) * h_in
                else:
                    # Out-of-place add: ``h += h_in`` mutates the layer output,
                    # which autograd rejects if the layer saved it for backward.
                    h = h + h_in

        g.ndata['h'] = self.linear_ro(h)
        if self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.sum_nodes(g, 'h')  # "sum" and the default readout

        return self.linear_prediction(hg)
    def forward(self, g, h, e, snorm_n, snorm_e):
        """Embed node features, apply residual conv layers and read out.

        Parameters
        ----------
        g : DGLGraph
            (Batched) input graph. ``g.ndata['h']`` is overwritten with the
            projected node features used for readout.
        h : tensor
            Input node features, cast to float before embedding.
        e, snorm_e :
            Unused; accepted for interface compatibility.
        snorm_n : tensor
            Forwarded to every conv layer.

        Returns
        -------
        tensor
            ``self.linear_predict`` applied to the pooled graph features.
            ``self.readout`` selects ``"max"`` or ``"mean"``; any other value
            (including ``"sum"``) uses summation.
        """
        # modified dtype for new dataset
        h = h.float()

        h = self.embedding_lin(h)
        h = self.in_feat_dropout(h)
        for conv in self.layers:
            h_in = h  # residual source: this layer's input
            h = conv(g, h, snorm_n)
            if self.residual:
                if self.residual == "gated":
                    z = torch.sigmoid(self.W_g(torch.cat([h, h_in], dim=1)))
                    h = z * h + (torch.ones_like(z) - z) * h_in
                else:
                    # Out-of-place add: ``h += h_in`` mutates the conv output,
                    # which autograd rejects if the conv saved it for backward.
                    h = h + h_in

        g.ndata['h'] = self.linear_ro(h)

        if self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            hg = dgl.sum_nodes(g, 'h')  # "sum" and the default readout

        return self.linear_predict(hg)
Example #5
0
    def forward(self, g):
        """Classify a word/concept heterograph.

        Steps:

        1. Embed the ``data['x']`` of node types ``word`` and ``concept`` and
           store the result as ``data['feat']``.
        2. Embed the ``data['h']`` of edge types ``A``, ``B`` and ``C`` and
           store the result as ``data['weight']``.

           Every embedding passes through ``self.dropout``. These writes
           persist on the caller's graph.
        3. Run the node features through ``self.rgcn``.
        4. Concatenate the per-type node *sums* for ``word`` and ``concept``
           and pass them to ``self.classify``.
        """
        node_embedders = (('word', self.word_embedding),
                          ('concept', self.concept_embedding))
        for ntype, embed in node_embedders:
            node_data = g.nodes[ntype].data
            node_data['feat'] = self.dropout(embed(node_data['x']))

        edge_embedders = (('A', self.w_w_embedding),
                          ('B', self.w_c_embedding),
                          ('C', self.c_w_embedding))
        for etype, embed in edge_embedders:
            edge_data = g.edges[etype].data
            edge_data['weight'] = self.dropout(embed(edge_data['h']))

        node_repr = self.rgcn(g, g.ndata['feat'])
        with g.local_scope():
            g.ndata['h'] = node_repr
            # Graph representation: per-type sum readout, concatenated.
            per_type = [dgl.sum_nodes(g, 'h', ntype=ntype)
                        for ntype in ('word', 'concept')]
            return self.classify(torch.cat(per_type, -1))
Example #6
0
    def get_q(self, graph, node_feature, qs, ws=None):
        """Return per-cluster Q values aggregated over each graph, plus a bias.

        Parameters
        ----------
        graph : batched DGL graph
            ``ndata['q']`` and ``ndata['node_feature']`` are written as scratch
            space and popped again before returning.
        node_feature : tensor
            Node features. They are summed per graph to feed ``self.q_b_net``.
        qs : tensor
            One Q value per ally node (reshaped with ``view(-1, 1)``).
        ws : tensor, optional
            Ally-to-cluster weights ``[#. allies x #. clusters]``. Computed
            with ``self.get_w`` when omitted.

        Returns
        -------
        tensor ``[#. graph x #. clusters]``
        """
        ally_idx = get_filtered_node_index_by_type(graph, NODE_ALLY)
        if ws is None:
            ws = self.get_w(graph, node_feature)  # [#. allies x #. clusters]

        # Scatter weighted ally Qs into an all-node tensor (non-allies stay 0).
        node_q = torch.zeros(size=(graph.number_of_nodes(), self.num_clusters),
                             device=node_feature.device)
        node_q[ally_idx, :] = qs.view(-1, 1) * ws
        graph.ndata['q'] = node_q
        q_per_cluster = dgl.sum_nodes(graph, 'q')  # [#. graph x #. clusters]

        # State-dependent bias from the per-graph sum of node features.
        graph.ndata['node_feature'] = node_feature
        bias = self.q_b_net(dgl.sum_nodes(graph, 'node_feature'))

        graph.ndata.pop('node_feature')
        graph.ndata.pop('q')
        return q_per_cluster + bias  # [#. graph x #. clusters]
Example #7
0
def test_simple_readout():
    """Check node/edge sum, mean and max readouts on single and batched graphs.

    ``g1`` is a 3-node cycle with node and edge features. ``g2`` has 4 nodes
    and no edges, so its edge readouts in the batch must be all-zero rows.
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    # Random inputs, drawn in the same order as before.
    n1 = F.randn((3, 5))
    n2 = F.randn((4, 5))
    e1 = F.randn((3, 5))
    w1 = F.randn((3, ))
    w2 = F.randn((4, ))

    def weighted_sum(feat, weight):
        return F.sum(feat * F.unsqueeze(weight, 1), 0)

    def weighted_mean(feat, weight):
        return weighted_sum(feat, weight) / F.sum(F.unsqueeze(weight, 1), 0)

    # Reference values computed directly with the backend.
    s1, s2 = F.sum(n1, 0), F.sum(n2, 0)
    m1, m2 = F.mean(n1, 0), F.mean(n2, 0)
    max1, max2 = F.max(n1, 0), F.max(n2, 0)
    ws1, ws2 = weighted_sum(n1, w1), weighted_sum(n2, w2)
    wm1, wm2 = weighted_mean(n1, w1), weighted_mean(n2, w2)
    se1, me1, maxe1 = F.sum(e1, 0), F.mean(e1, 0), F.max(e1, 0)

    g1.ndata['x'] = n1
    g2.ndata['x'] = n2
    g1.ndata['w'] = w1
    g2.ndata['w'] = w2
    g1.edata['x'] = e1

    # Single graph.
    single_graph_cases = [
        (dgl.sum_nodes(g1, 'x'), s1),
        (dgl.sum_nodes(g1, 'x', 'w'), ws1),
        (dgl.sum_edges(g1, 'x'), se1),
        (dgl.mean_nodes(g1, 'x'), m1),
        (dgl.mean_nodes(g1, 'x', 'w'), wm1),
        (dgl.mean_edges(g1, 'x'), me1),
        (dgl.max_nodes(g1, 'x'), max1),
        (dgl.max_edges(g1, 'x'), maxe1),
    ]
    for actual, expected in single_graph_cases:
        assert F.allclose(actual, expected)

    # Batched graph: one row per member graph; edgeless g2 reads out zeros.
    g = dgl.batch([g1, g2])
    zero_row = F.zeros(5)
    batched_cases = [
        (dgl.sum_nodes(g, 'x'), [s1, s2]),
        (dgl.mean_nodes(g, 'x'), [m1, m2]),
        (dgl.max_nodes(g, 'x'), [max1, max2]),
        (dgl.sum_nodes(g, 'x', 'w'), [ws1, ws2]),
        (dgl.mean_nodes(g, 'x', 'w'), [wm1, wm2]),
        (dgl.sum_edges(g, 'x'), [se1, zero_row]),
        (dgl.mean_edges(g, 'x'), [me1, zero_row]),
        (dgl.max_edges(g, 'x'), [maxe1, zero_row]),
    ]
    for actual, rows in batched_cases:
        assert F.allclose(actual, F.stack(rows, 0))
Example #8
0
    def forward(self, graph, node_feature, sub_q_tots):
        """Mix group-wise Q values into one Q_tot per graph.

        For each group ``i``:

        1. Select its nodes with ``get_filtered_node_index_by_assignment(graph, i)``.
        2. Sum their features per graph.
        3. Map the sum through ``self.w``.
        4. Rectify the result with ``abs`` or ``softplus``, depending on
           ``self.rectifier``.
        5. Multiply by ``sub_q_tots[i]``.

        A state value ``self.v`` of the per-graph sum of *ally* node features
        is then added.

        Group statistics are emitted with ``logging`` at DEBUG level. They used
        to be ``print``-ed on every forward call.

        Parameters
        ----------
        graph : batched DGL graph
            ``ndata`` is used as scratch space; the keys written here are
            popped again.
        node_feature : tensor ``[#. nodes x feature dim]``
        sub_q_tots : sequence of tensors
            One per-graph Q value per group.

        Returns
        -------
        tensor ``[graph.batch_size]``

        Raises
        ------
        RuntimeError
            If ``self.rectifier`` is neither ``'abs'`` nor ``'softplus'``
            (checked only when there is at least one group).
        """
        import logging
        logger = logging.getLogger(__name__)

        device = node_feature.device
        graph.ndata['node_feature'] = node_feature

        len_groups = []
        q_tot = torch.zeros(graph.batch_size, device=device)
        for i, sub_q_tot in enumerate(sub_q_tots):
            node_indices = get_filtered_node_index_by_assignment(graph, i)
            len_groups.append(len(node_indices))

            # 0/1 column mask selecting this group's nodes.
            mask = torch.zeros(size=(node_feature.shape[0], 1), device=device)
            mask[node_indices, :] = 1

            graph.ndata['masked_node_feature'] = graph.ndata['node_feature'] * mask
            w_input = dgl.sum_nodes(graph, 'masked_node_feature')
            if self.rectifier == 'abs':
                weight = torch.abs(self.w(w_input))
            elif self.rectifier == 'softplus':
                weight = F.softplus(self.w(w_input))
            else:
                raise RuntimeError("Not implemented rectifier")
            q_tot = q_tot + weight.view(-1) * sub_q_tot
            graph.ndata.pop('masked_node_feature')

        # State value computed from ally nodes only (others zeroed).
        graph.ndata.pop('node_feature')
        ally_indices = get_filtered_node_index_by_type(graph, NODE_ALLY)
        ally_feature = torch.zeros(
            size=(graph.number_of_nodes(), node_feature.shape[1]), device=device)
        ally_feature[ally_indices, :] = node_feature[ally_indices, :]
        graph.ndata['node_feature'] = ally_feature
        v = self.v(dgl.sum_nodes(graph, 'node_feature')).view(-1)
        graph.ndata.pop('node_feature')
        q_tot = q_tot + v

        if logger.isEnabledFor(logging.DEBUG):
            len_groups = np.array(len_groups)
            # max(..., 1) avoids 0/0 when no group contains any node.
            ratio_groups = len_groups / max(int(np.sum(len_groups)), 1)
            logger.debug("Num elements in groups %s", len_groups)
            logger.debug("Num elements ratio %s", ratio_groups)
            target_assignment_weight = graph.ndata['normalized_score'][ally_indices]
            logger.debug("Average normalized scores %s",
                         target_assignment_weight.mean(0))

        return q_tot
Example #9
0
    def _take_action(self, action):
        """Apply ``action`` to the undecided entries of ``self.x`` and score the step.

        Encoding of ``self.x`` as used below: 2 = undecided, 1 = selected,
        0 = not selected. This looks like a maximal-independent-set style
        construction (adjacent selections are treated as clashes); that reading
        is inferred from the logic, so confirm against the environment's docs.

        Side effects: mutates ``self.x``, ``self.t``, ``self.epi_t`` and moves
        ``self.g`` to ``self.device``. ``self.sol`` is read here but not updated.

        Returns
        -------
        reward : tensor
            ``(next_sol - self.sol)``, plus an optional Hamming bonus between
            two samples, divided by ``self.max_num_nodes``.
        next_sol : tensor
            Per-graph count of entries with ``self.x == 1``.
        done :
            Result of ``self._check_done()``.
        """
        # Only undecided entries take the proposed action.
        undecided = self.x == 2
        self.x[undecided] = action[undecided]
        self.t += 1

        # For every node, sum the "selected" flags over its incoming edges
        # (copy_src + sum), i.e. how many selected predecessors it has.
        x1 = (self.x == 1)
        self.g = self.g.to(self.device)
        self.g.ndata['h'] = x1.float()
        self.g.update_all(fn.copy_src(src='h', out='m'),
                          fn.sum(msg='m', out='h'))
        x1_deg = self.g.ndata.pop('h')

        ## forgive clashing
        # A selected node with a selected predecessor is reset to undecided
        # (instead of being penalised); zeroing its count keeps the clean-up
        # below from forcing it to 0.
        clashed = x1 & (x1_deg > 0)
        self.x[clashed] = 2
        x1_deg[clashed] = 0

        # graph clean up
        # Undecided nodes next to a selected node can no longer be selected.
        still_undecided = (self.x == 2)
        self.x[still_undecided & (x1_deg > 0)] = 0

        # fill timeout with zeros
        # When the step budget is reached, every remaining undecided entry is
        # set to 0.
        still_undecided = (self.x == 2)
        timeout = (self.t == self.max_epi_t)
        self.x[still_undecided & timeout] = 0

        # Episodes that are not done yet get their step counter advanced.
        done = self._check_done()
        self.epi_t[~done] += 1

        # compute reward and solution
        # Solution value = number of selected nodes per graph.
        x1 = (self.x == 1).float()
        node_sol = x1

        h = node_sol
        self.g.ndata['h'] = h
        next_sol = dgl.sum_nodes(self.g, 'h')
        self.g.ndata.pop('h')

        reward = (next_sol - self.sol)

        # Optional diversity bonus when exactly two samples are run side by
        # side (self.x is split along dim 1). It counts nodes where the two
        # samples disagree, ignoring nodes where either is still undecided and
        # nodes that were already decided in both before this step.
        if self.hamming_reward_coef > 0.0 and self.num_samples == 2:
            xl, xr = self.x.split(1, dim=1)
            undecidedl, undecidedr = undecided.split(1, dim=1)
            hamming_d = torch.abs(xl.float() - xr.float())
            hamming_d[(xl == 2) | (xr == 2)] = 0.0
            hamming_d[~undecidedl & ~undecidedr] = 0.0
            self.g.ndata['h'] = hamming_d
            hamming_reward = dgl.sum_nodes(self.g, 'h').expand_as(reward)
            self.g.ndata.pop('h')
            reward += self.hamming_reward_coef * hamming_reward

        # Normalise the reward by the maximum graph size.
        reward /= self.max_num_nodes

        return reward, next_sol, done
Example #10
0
    def forward(self, g, pos):
        """Pool ``g.ndata['h']`` over three node groups selected by ``pos``.

        ``pos`` holds a code per node: 0, 1 or 2. The code names them
        grandparent / parent / sibling (presumably tree positions; confirm with
        the caller). The pooled vectors are:

        - code 0 and code 2: weighted sum divided by the graph's total node
          count;
        - code 1: weighted mean (``dgl.mean_nodes`` with weights, i.e. divided
          by the number of code-1 nodes).

        The 0/1 masks are left on ``g`` as ``a_gp``, ``a_p`` and ``a_sib``.

        Returns
        -------
        tensor
            ``(gp_embed, p_embed, sib_embed)`` concatenated along dim 1.
        """
        # One total node count per graph, as a float column on pos' device.
        node_counts = torch.tensor(g.batch_num_nodes).unsqueeze_(1).float().to(
            pos.device)

        def pool_group(mask_key, code, readout):
            g.ndata[mask_key] = (pos == code).float()
            return readout(g, 'h', mask_key)

        gp_embed = pool_group('a_gp', 0, dgl.sum_nodes) / node_counts
        p_embed = pool_group('a_p', 1, dgl.mean_nodes)
        sib_embed = pool_group('a_sib', 2, dgl.sum_nodes) / node_counts

        return torch.cat((gp_embed, p_embed, sib_embed), 1)
Example #11
0
def test_sum_case1(idtype):
    """Check ``dgl.sum_nodes`` on a single graph and on a weighted batch."""
    # NOTE: If you want to update this test case, remember to update the
    # corresponding docstring example too!!!
    ctx = F.ctx()
    g1 = dgl.graph(([0, 1], [1, 0]), idtype=idtype, device=ctx)
    g1.ndata['h'] = F.tensor([1., 2.])
    g2 = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=ctx)
    g2.ndata['h'] = F.tensor([1., 2., 3.])

    bg = dgl.batch([g1, g2])
    bg.ndata['w'] = F.tensor([.1, .2, .1, .5, .2])

    # Unweighted: g1 -> 1 + 2 = 3; g2 -> 1 + 2 + 3 = 6.
    assert F.allclose(F.tensor([3.]), dgl.sum_nodes(g1, 'h'))
    assert F.allclose(F.tensor([3., 6.]), dgl.sum_nodes(bg, 'h'))
    # Weighted: 1*.1 + 2*.2 = .5 and 1*.1 + 2*.5 + 3*.2 = 1.7.
    assert F.allclose(F.tensor([.5, 1.7]), dgl.sum_nodes(bg, 'h', 'w'))
Example #12
0
def test_simple_readout():
    """Check node/edge sum and mean readouts on single and batched graphs.

    ``g1`` is a 3-node cycle with node and edge features. ``g2`` has 4 nodes
    and no edges, so its edge readouts in the batch must be all-zero rows.
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    # Random inputs, drawn in the same order as before.
    n1, n2, e1 = th.randn(3, 5), th.randn(4, 5), th.randn(3, 5)
    w1, w2 = th.randn(3), th.randn(4)

    def weighted_sum(feat, weight):
        return (feat * weight[:, None]).sum(0)

    def weighted_mean(feat, weight):
        return weighted_sum(feat, weight) / weight[:, None].sum(0)

    # Reference values.
    s1, s2 = n1.sum(0), n2.sum(0)  # node sums
    m1, m2 = n1.mean(0), n2.mean(0)  # node means
    se1, me1 = e1.sum(0), e1.mean(0)  # edge sum / mean
    ws1, ws2 = weighted_sum(n1, w1), weighted_sum(n2, w2)
    wm1, wm2 = weighted_mean(n1, w1), weighted_mean(n2, w2)

    g1.ndata['x'] = n1
    g2.ndata['x'] = n2
    g1.ndata['w'] = w1
    g2.ndata['w'] = w2
    g1.edata['x'] = e1

    # Single graph.
    single_graph_cases = [
        (dgl.sum_nodes(g1, 'x'), s1),
        (dgl.sum_nodes(g1, 'x', 'w'), ws1),
        (dgl.sum_edges(g1, 'x'), se1),
        (dgl.mean_nodes(g1, 'x'), m1),
        (dgl.mean_nodes(g1, 'x', 'w'), wm1),
        (dgl.mean_edges(g1, 'x'), me1),
    ]
    for actual, expected in single_graph_cases:
        assert U.allclose(actual, expected)

    # Batched graph: one row per member graph; edgeless g2 reads out zeros.
    g = dgl.batch([g1, g2])
    zero_row = th.zeros(5)
    batched_cases = [
        (dgl.sum_nodes(g, 'x'), [s1, s2]),
        (dgl.mean_nodes(g, 'x'), [m1, m2]),
        (dgl.sum_nodes(g, 'x', 'w'), [ws1, ws2]),
        (dgl.mean_nodes(g, 'x', 'w'), [wm1, wm2]),
        (dgl.sum_edges(g, 'x'), [se1, zero_row]),
        (dgl.mean_edges(g, 'x'), [me1, zero_row]),
    ]
    for actual, rows in batched_cases:
        assert U.allclose(actual, th.stack(rows, 0))
Example #13
0
def euclidean_matrix(graphs, dims, readout='sum'):
    '''Return the pairwise euclidean distances between graph readouts, per iteration.

    graphs : list of dgl graphs
        Every graph must carry node features under ``'h'``.
    dims : sequence of int
        Graph features are the concatenation of the features obtained from all
        iterations. This variable holds the individual feature width of each
        iteration. Any sequence (list, tuple, array) is accepted.
    readout : {'sum', 'mean'}
        Node-to-graph pooling.

    Returns a list with one numpy array per entry of ``dims``. Each array is
    ``dist_matrix(features, features)`` for that slice of the graph
    representations.

    Raises ValueError if ``readout`` is neither ``'sum'`` nor ``'mean'``. This
    is checked before the graphs are batched.
    '''
    if readout not in ('sum', 'mean'):
        raise ValueError(
            'Readout for euclidean_matrix shall be either "mean" or "sum"')

    # Column offsets of each iteration's block; ``[0, *dims]`` also works for
    # tuples (``[0] + dims`` only accepted lists).
    offsets = np.cumsum([0, *dims])

    distances = []
    with torch.no_grad():
        batched = dgl.batch(graphs)
        if readout == 'sum':
            graph_reprs = dgl.sum_nodes(batched, 'h')
        else:
            graph_reprs = dgl.mean_nodes(batched, 'h')

        for dim_start, dim_end in zip(offsets[:-1], offsets[1:]):
            features = graph_reprs[:, dim_start:dim_end]
            matrix = dist_matrix(features, features)
            distances.append(matrix.cpu().numpy())

    return distances
Example #14
0
def gram_matrix(graphs, dims, dist_fn, readout='sum'):
    '''Compute one gram matrix per iteration on a batch of graphs.

    Returns the list of gram matrices as numpy.array.

    graphs : list of dgl graphs
        Every graph must carry node features under ``'h'``.
    dims : sequence of int
        Graph features are the concatenation of the features obtained from all
        iterations. This variable holds the individual feature width of each
        iteration. Any sequence (list, tuple, array) is accepted.
    dist_fn : callable
        Called as ``dist_fn(features, features)`` on each slice. It must
        return a tensor.
    readout : {'sum', 'mean'}
        Node-to-graph pooling.

    Raises ValueError if ``readout`` is neither ``'sum'`` nor ``'mean'``. This
    is checked before the graphs are batched.
    '''
    if readout not in ('sum', 'mean'):
        raise ValueError('Readout for gram_matrix shall be either "mean" or "sum"')

    # Column offsets of each iteration's block; ``[0, *dims]`` also works for
    # tuples (``[0] + dims`` only accepted lists).
    offsets = np.cumsum([0, *dims])

    distances = []
    with torch.no_grad():
        batched = dgl.batch(graphs)
        if readout == 'sum':
            graph_reprs = dgl.sum_nodes(batched, 'h')
        else:
            graph_reprs = dgl.mean_nodes(batched, 'h')

        for dim_start, dim_end in zip(offsets[:-1], offsets[1:]):
            features = graph_reprs[:, dim_start:dim_end]
            # Named ``gram`` so it no longer shadows this function.
            gram = dist_fn(features, features)
            distances.append(gram.cpu().numpy())

    return distances
Example #15
0
    def forward(self, graph, edge_feat, node_feat, g_repr):
        """Run one edge -> node -> global update round on ``graph``.

        Steps:

        1. Store ``edge_feat`` / ``node_feat`` on the graph.
        2. Update edges with ``self.compute_edge_repr``.
        3. Pass messages with ``self.graph_message_func`` /
           ``self.graph_reduce_func``, then update nodes with
           ``self.compute_node_repr``.
        4. Sum-pool the resulting edge and node features for
           ``self.compute_u_repr``.

        NOTE: before returning, *every* field in ``graph.edata`` and
        ``graph.ndata`` is popped, including fields the caller set earlier.

        Returns
        -------
        (edge features, node features, output of ``self.compute_u_repr``)
        """
        graph.edata['edge_feat'] = edge_feat
        graph.ndata['node_feat'] = node_feat

        def edge_update(edges):
            return self.compute_edge_repr(edges=edges, graph=graph, g_repr=g_repr)

        def node_update(nodes):
            return self.compute_node_repr(nodes=nodes, graph=graph, g_repr=g_repr)

        graph.apply_edges(edge_update)
        graph.update_all(self.graph_message_func, self.graph_reduce_func,
                         node_update)

        e_comb = dgl.sum_edges(graph, 'edge_feat')
        n_comb = dgl.sum_nodes(graph, 'node_feat')
        e_out = graph.edata['edge_feat']
        n_out = graph.ndata['node_feat']

        # Clear all edge and node fields (see NOTE above).
        for frame in (graph.edata, graph.ndata):
            for key in list(frame.keys()):
                frame.pop(key)

        return e_out, n_out, self.compute_u_repr(n_comb, e_comb, g_repr)
Example #16
0
    def forward(self, g, node_feats):
        r"""Sum-pool node representations into graph representations.

        When ``self.gaussian_expand`` is set:

        - node features are first expanded with ``self.gaussian_histogram``;
        - the pooled result is passed through ``self.to_out``;
        - ``self.activation`` is then applied, if it is not None.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs. Its features are untouched
            (``local_scope``).
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.

        Returns
        -------
        g_feats : float32 tensor of shape (G, ...)
            Output graph representations. G for the number of graphs in the batch.
        """
        expanded = self.gaussian_histogram(node_feats) if self.gaussian_expand else node_feats

        with g.local_scope():
            g.ndata['h'] = expanded
            g_feats = dgl.sum_nodes(g, 'h')

        if not self.gaussian_expand:
            return g_feats

        g_feats = self.to_out(g_feats)
        if self.activation is None:
            return g_feats
        return self.activation(g_feats)
    def forward(self, bg, feat):
        """Set2Set-style iterative attention readout over a batched graph.

        Parameters
        ----------
        bg : batched DGLGraph
            Uses:

            - ``bg.batch_size``;
            - ``bg.batch_num_nodes``, iterated as the per-graph node counts;
            - ``bg.ndata[feat]``.

            The attention weights of the last step are left in
            ``bg.ndata['w']``.
        feat : str
            Name of the node feature to pool.

        Returns
        -------
        tensor
            ``q_star`` after ``self.processing_steps`` steps. Each step
            concatenates the LSTM query ``q`` (batch_size, in_channels) with
            the attention readout ``r``. With zero steps it is a zero tensor
            of shape (batch_size, self.out_channels).
        """
        batch_size = bg.batch_size
        x = bg.ndata[feat]
        batch_num_nodes = bg.batch_num_nodes

        # Graph id of every node, e.g. counts [2, 3] -> [0, 0, 1, 1, 1].
        # One repeat_interleave on x's device replaces the per-graph
        # torch.cat loop, which re-copied the growing index for every graph
        # (quadratic in the batch size) and always built it on the CPU.
        counts = torch.as_tensor(batch_num_nodes, dtype=torch.int64,
                                 device=x.device)
        batch = torch.repeat_interleave(
            torch.arange(counts.numel(), device=x.device), counts)

        h = (x.new_zeros((self.num_layers, batch_size, self.in_channels)),
             x.new_zeros((self.num_layers, batch_size, self.in_channels)))
        q_star = x.new_zeros(batch_size, self.out_channels)

        for _ in range(self.processing_steps):
            q, h = self.lstm(q_star.unsqueeze(0), h)
            q = q.view(batch_size, self.in_channels)
            # Score of every node against the query of its own graph.
            e = (x * q[batch]).sum(dim=-1, keepdim=True)
            # Softmax over nodes, separately within each graph.
            a = torch.cat([softmax(scores, dim=0)
                           for scores in torch.split(e, batch_num_nodes)],
                          dim=0)
            bg.ndata['w'] = a
            r = dgl.sum_nodes(bg, feat, 'w')
            q_star = torch.cat([q, r], dim=-1)

        return q_star
Example #18
0
    def forward(self,
                g,
                h,
                e,
                snorm_n,
                snorm_e,
                mlp=True,
                head=False,
                return_graph=False):
        """Embed, convolve and read out a graph, with several output modes.

        Parameters
        ----------
        g : DGLGraph
            ``g.ndata['h']`` is overwritten with the final node features.
        h : tensor
            Input node features.
        e, snorm_e :
            Unused; accepted for interface compatibility.
        snorm_n : tensor
            Forwarded to every conv layer.
        mlp : bool
            Return ``self.MLP_layer(hg)`` (checked before ``head``).
        head : bool
            When ``mlp`` is False, return ``self.projection_head(hg)``.
        return_graph : bool
            Return ``g`` right after the node features are stored.

        Returns
        -------
        ``g``, or the graph readout ``hg`` (optionally transformed).
        ``self.readout`` selects ``"sum"`` or ``"max"``; anything else uses the
        mean.
        """
        h = self.in_feat_dropout(self.embedding_h(h))
        for conv in self.layers:
            h = conv(g, h, snorm_n)
        g.ndata['h'] = h

        if return_graph:
            return g

        if self.readout == "sum":
            pool = dgl.sum_nodes
        elif self.readout == "max":
            pool = dgl.max_nodes
        else:  # "mean" and the default readout
            pool = dgl.mean_nodes
        hg = pool(g, 'h')

        if mlp:
            return self.MLP_layer(hg)
        return self.projection_head(hg) if head else hg
Example #19
0
    def forward(self, g, node_feats):
        """Computes graph representations out of node features.

        Node features go through ``self.in_project``, then the optional
        ``self.activation``, then ``self.out_project``. They are then pooled
        per graph according to ``self.mode``.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs. Its features are untouched
            (``local_scope``).
        node_feats : float32 tensor of shape (V, node_feats)
            Input node features, V for the number of nodes.

        Returns
        -------
        graph_feats : float32 tensor of shape (G, graph_feats)
            Graph representations computed. G for the number of graphs.

        Raises
        ------
        ValueError
            If ``self.mode`` is not ``'max'``, ``'mean'`` or ``'sum'``.
            Previously such a mode fell through and raised UnboundLocalError.
        """
        if self.mode == 'max':
            readout = dgl.max_nodes
        elif self.mode == 'mean':
            readout = dgl.mean_nodes
        elif self.mode == 'sum':
            readout = dgl.sum_nodes
        else:
            raise ValueError(
                "Expect mode to be 'max', 'mean' or 'sum', got {}".format(self.mode))

        node_feats = self.in_project(node_feats)
        if self.activation is not None:
            node_feats = self.activation(node_feats)
        node_feats = self.out_project(node_feats)

        with g.local_scope():
            g.ndata['h'] = node_feats
            graph_feats = readout(g, 'h')

        return graph_feats
    def forward(self, g, h, e):
        """Embed, convolve and read out a reduced graph using node weights.

        Every readout is weighted by the node field ``'weight'``. Nothing here
        sets that field, so it must already be present on ``g``.
        ``self.readout`` selects ``"sum"``, ``"max"`` or ``"mean"``; any other
        value falls back to the weighted mean. ``g.ndata['h']`` is overwritten.

        For original (non-reduced) graphs the convs would be called as
        ``conv(g, h)`` and the readouts without ``weight``.

        Returns
        -------
        ``self.MLP_layer`` applied to the graph readout.
        """
        h = self.in_feat_dropout(self.embedding_h(h))
        for conv in self.layers:
            h = conv(g, h, e)  # reduced graphs take edge features
        g.ndata['h'] = h

        if self.readout == "sum":
            weighted_readout = dgl.sum_nodes
        elif self.readout == "max":
            weighted_readout = dgl.max_nodes
        else:  # "mean" and the default readout
            weighted_readout = dgl.mean_nodes
        hg = weighted_readout(g, feat='h', weight='weight')
        return self.MLP_layer(hg)
Example #21
0
    def forward(self, g, h, e, pos_enc=None):
        """MoNet-style forward pass with degree-based edge pseudo-coordinates.

        Parameters
        ----------
        g : DGLGraph
            ``g.ndata['deg']``, ``g.edata['pseudo']`` and ``g.ndata['h']`` are
            written.
        h : tensor
            Node features. Ignored when ``self.pos_enc`` is set, in which case
            ``pos_enc`` is embedded instead.
        e :
            Unused.
        pos_enc : tensor, optional

        Returns
        -------
        ``self.MLP_layer`` applied to the graph readout. ``self.readout``
        selects ``"sum"`` or ``"max"``; anything else uses the mean.
        """
        # Input embedding.
        h = self.embedding_pos_enc(pos_enc) if self.pos_enc else self.embedding_h(h)

        # Edge pseudo-coordinates are computed by self.compute_pseudo from the
        # in-degrees stored on the nodes.
        g.ndata['deg'] = g.in_degrees()
        g.apply_edges(self.compute_pseudo)
        pseudo = g.edata['pseudo'].to(self.device).float()

        for i, layer in enumerate(self.layers):
            h = layer(g, h, self.pseudo_proj[i](pseudo))
        g.ndata['h'] = h

        if self.readout == "sum":
            pool = dgl.sum_nodes
        elif self.readout == "max":
            pool = dgl.max_nodes
        else:  # "mean" and the default readout
            pool = dgl.mean_nodes
        return self.MLP_layer(pool(g, 'h'))
Example #22
0
    def forward(self, g, h, e, snorm_n, snorm_e):
        """Conv stack with optional virtual-node layers, then graph readout.

        After conv layer ``i``, if ``self.virtual_node_layers`` is set and has
        an ``i``-th entry, that layer updates both the virtual-node state
        (starting at 0) and the node features.

        Returns
        -------
        ``self.MLP_layer`` applied to the graph readout. ``self.readout``
        selects ``"sum"`` or ``"max"``; anything else uses the mean.
        ``g.ndata['h']`` is overwritten.
        """
        h = self.in_feat_dropout(self.embedding_h(h))
        if self.edge_feat:
            e = self.embedding_e(e)

        vn_layers = self.virtual_node_layers
        vn_h = 0  # virtual-node state before the first layer
        for i, conv in enumerate(self.layers):
            h = conv(g, h, e, snorm_n)
            if vn_layers is not None and i < len(vn_layers):
                vn_h, h = vn_layers[i].forward(g, h, vn_h)

        g.ndata['h'] = h

        if self.readout == "sum":
            pool = dgl.sum_nodes
        elif self.readout == "max":
            pool = dgl.max_nodes
        else:  # "mean" and the default readout
            pool = dgl.mean_nodes
        return self.MLP_layer(pool(g, 'h'))
Example #23
0
    def forward(self, g):
        """Multi-level interaction network producing one value per graph.

        Steps:

        1. Node embeddings are written to ``"node_0"``.
        2. Conv layer ``idx`` is called with level ``idx + 1``; per the read
           below, it presumably writes ``"node_<idx+1>"``.
        3. All levels are concatenated and passed through two dense layers.
        4. The optional atom reference ``"e0"`` is added, and the optional
           ``std/mean_per_node`` de-normalisation is applied.
        5. The result is sum-pooled per graph.

        Writes ``g.ndata["node"]`` and ``g.ndata["res"]``.
        """
        self.embedding_layer(g, "node_0")
        if self.atom_ref is not None:
            self.e0(g, "e0")
        self.rbf_layer(g)
        self.edge_embedding_layer(g)

        for level in range(1, self.n_conv + 1):
            self.conv_layers[level - 1](g, level)

        # Concatenate the representations of every level 0..n_conv.
        g.ndata["node"] = th.cat(
            [g.ndata["node_%d" % level] for level in range(self.n_conv + 1)], 1)

        hidden = self.activation(self.node_dense_layer1(g.ndata["node"]))
        g.ndata["res"] = self.node_dense_layer2(hidden)

        if self.atom_ref is not None:
            g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]
        if self.norm:
            g.ndata["res"] = (g.ndata["res"] * self.std_per_node
                              + self.mean_per_node)
        return dgl.sum_nodes(g, "res")
Example #24
0
    def forward(self, g, node_feats, g_feats, get_node_weight=False):
        """One attention-based graph readout step followed by a GRU update.

        Parameters
        ----------
        g : DGLGraph or BatchedDGLGraph
            Constructed DGLGraphs. Fields written here are discarded
            (``local_scope``).
        node_feats : float32 tensor of shape (V, N1)
            Input node features. V for the number of nodes and N1 for the feature size.
        g_feats : float32 tensor of shape (G, N2)
            Input graph features. G for the number of graphs and N2 for the feature size.
        get_node_weight : bool
            Whether to also return the per-node attention weights.

        Returns
        -------
        float32 tensor of shape (G, N2)
            Updated graph features, ``self.gru(context, g_feats)``.
        float32 tensor of shape (V, 1)
            The per-node softmax weights used in the readout. Only returned
            when ``get_node_weight`` is True.
        """
        with g.local_scope():
            # Each node sees its graph's (ReLU-ed) feature next to its own.
            graph_ctx = dgl.broadcast_nodes(g, F.relu(g_feats))
            g.ndata['z'] = self.compute_logits(torch.cat([graph_ctx, node_feats], dim=1))
            # Softmax of the logits over the nodes of each graph.
            g.ndata['a'] = dgl.softmax_nodes(g, 'z')
            g.ndata['hv'] = self.project_nodes(node_feats)
            context = F.elu(dgl.sum_nodes(g, 'hv', 'a'))

            updated = self.gru(context, g_feats)
            return (updated, g.ndata['a']) if get_node_weight else updated
Example #25
0
    def forward(self, graph, edge_feat, node_feat, g_repr, edge_hidden, node_hidden, graph_hidden):
        """Recurrent edge -> node -> global update round on ``graph``.

        Hidden states: entry ``[k][0]`` of ``edge_hidden`` / ``node_hidden``
        (k = 0, 1) is stored as ``hidden1`` / ``hidden2``. The stored values
        are returned with a leading dimension of 1.

        Edges are updated with ``self.compute_edge_repr``. Messages are passed
        with ``self.graph_message_func`` / ``self.graph_reduce_func`` and
        nodes updated with ``self.compute_node_repr``. The sum-pooled edge and
        node features then feed ``self.compute_u_repr`` together with
        ``graph_hidden``.

        NOTE: before returning, *every* field in ``graph.edata`` and
        ``graph.ndata`` is popped, including fields the caller set earlier.

        Returns
        -------
        (e_feat, h_e, n_feat, h_n, u_out, u_hidden)
        """
        graph.edata['edge_feat'] = edge_feat
        graph.ndata['node_feat'] = node_feat
        for slot, idx in (('hidden1', 0), ('hidden2', 1)):
            graph.edata[slot] = edge_hidden[idx][0]
            graph.ndata[slot] = node_hidden[idx][0]

        def node_update(nodes):
            return self.compute_node_repr(nodes=nodes, graph=graph, g_repr=g_repr)

        def edge_update(edges):
            return self.compute_edge_repr(edges=edges, graph=graph, g_repr=g_repr)

        graph.apply_edges(edge_update)
        graph.update_all(self.graph_message_func, self.graph_reduce_func, node_update)

        e_comb = dgl.sum_edges(graph, 'edge_feat')
        n_comb = dgl.sum_nodes(graph, 'node_feat')
        u_out, u_hidden = self.compute_u_repr(n_comb, e_comb, g_repr, graph_hidden)

        e_feat = graph.edata['edge_feat']
        n_feat = graph.ndata['node_feat']

        def with_leading_dim(frame):
            return tuple(torch.unsqueeze(frame[slot], 0)
                         for slot in ('hidden1', 'hidden2'))

        h_e = with_leading_dim(graph.edata)
        h_n = with_leading_dim(graph.ndata)

        # Clear all edge and node fields (see NOTE above).
        for frame in (graph.edata, graph.ndata):
            for key in list(frame.keys()):
                frame.pop(key)

        return e_feat, h_e, n_feat, h_n, u_out, u_hidden
Example #26
0
    def forward(self, g):
        """GIN forward pass that scores every layer's pooled representation.

        Pipeline:

        - The input ``g.ndata['attr']`` (moved to ``self.device``) and the
          outputs of the first ``num_layers - 1`` GIN layers are each pooled
          per graph according to ``self.graph_pooling_type``.
        - Each pooled vector goes through its own ``linears_prediction``
          entry and dropout.
        - The per-layer scores are summed.

        ``g.ndata['h']`` is overwritten.

        Raises
        ------
        NotImplementedError
            For a pooling type other than ``'sum'``, ``'mean'`` or ``'max'``.
        """
        h = g.ndata['attr'].to(self.device)

        # Representation after each layer, including the input.
        hidden_rep = [h]
        for layer_idx in range(self.num_layers - 1):
            h = self.ginlayers[layer_idx](g, h)
            hidden_rep.append(h)

        score_over_layer = 0
        for layer_idx, rep in enumerate(hidden_rep):
            g.ndata['h'] = rep
            pooling = self.graph_pooling_type
            if pooling == 'sum':
                pooled_h = dgl.sum_nodes(g, 'h')
            elif pooling == 'mean':
                pooled_h = dgl.mean_nodes(g, 'h')
            elif pooling == 'max':
                pooled_h = dgl.max_nodes(g, 'h')
            else:
                raise NotImplementedError()

            layer_score = self.linears_prediction[layer_idx](pooled_h)
            score_over_layer += F.dropout(layer_score,
                                          self.final_dropout,
                                          training=self.training)

        return score_over_layer
Example #27
0
    def forward(self, g, h, e, h_lap_pos_enc=None, h_wl_pos_enc=None):
        """Graph-transformer forward pass with optional positional encodings.

        Node embeddings get the embedded Laplacian and/or WL positional
        encodings added when ``self.lap_pos_enc`` / ``self.wl_pos_enc`` are
        set. Without ``self.edge_feat`` every edge feature is replaced by the
        constant 1 before embedding. Each layer updates both ``h`` and ``e``.

        Returns
        -------
        ``self.MLP_layer`` applied to the graph readout. ``self.readout``
        selects ``"sum"`` or ``"max"``; anything else uses the mean.
        ``g.ndata['h']`` is overwritten.
        """
        h = self.in_feat_dropout(self.embedding_h(h))
        if self.lap_pos_enc:
            h = h + self.embedding_lap_pos_enc(h_lap_pos_enc.float())
        if self.wl_pos_enc:
            h = h + self.embedding_wl_pos_enc(h_wl_pos_enc)

        if not self.edge_feat:
            # edge feature set to 1
            e = torch.ones(e.size(0), 1).to(self.device)
        e = self.embedding_e(e)

        for conv in self.layers:
            h, e = conv(g, h, e)
        g.ndata['h'] = h

        if self.readout == "sum":
            pool = dgl.sum_nodes
        elif self.readout == "max":
            pool = dgl.max_nodes
        else:  # "mean" and the default readout
            pool = dgl.mean_nodes
        return self.MLP_layer(pool(g, 'h'))
Example #28
0
    def forward(self, g):
        """SchNet-style forward pass producing one value per graph.

        Steps:

        1. Reshape edge distances to a column.
        2. Run the embedding, optional atom-reference (``"e0"``), RBF and
           ``self.n_conv`` conv layers on ``g``.
        3. Map ``g.ndata["node"]`` through two activated dense layers and the
           regressor.
        4. Add ``"e0"`` and de-normalise with ``std/mean_per_atom`` when
           configured.
        5. Sum-pool per graph.

        Writes ``g.edata['distance']`` and ``g.ndata["res"]``.
        """
        # g_list list of molecules
        g.edata['distance'] = g.edata['distance'].reshape(-1, 1)

        self.embedding_layer(g)
        if self.atom_ref is not None:
            self.e0(g, "e0")
        self.rbf_layer(g)
        for idx in range(self.n_conv):
            self.conv_layers[idx](g)

        hidden = self.activation(self.atom_dense_layer1(g.ndata["node"]))
        hidden = self.activation(self.atom_dense_layer2(hidden))
        res = self.regressor(hidden)

        if self.atom_ref is not None:
            res = res + g.ndata["e0"]
        if self.norm:
            res = res * self.std_per_atom + self.mean_per_atom

        g.ndata["res"] = res
        return dgl.sum_nodes(g, "res")
Example #29
0
    def forward(self, g, h, e, snorm_n, snorm_e):
        """Conv stack where each layer's output is joined with the input embedding.

        For layer ``i``: ``h = joining_layers[i](h_init + layers[i](g, h, snorm_n))``,
        where ``h_init`` is the (dropped-out) input embedding.

        Returns
        -------
        ``self.MLP_layer`` applied to the graph readout. ``self.readout``
        selects ``"sum"`` or ``"max"``; anything else uses the mean.
        ``g.ndata['h']`` is overwritten.
        """
        h = self.in_feat_dropout(self.embedding_h(h))
        h_init = h

        for i in range(self.layer_count):
            conv_out = self.layers[i](g, h, snorm_n)
            h = self.joining_layers[i](h_init + conv_out)

        g.ndata['h'] = h

        if self.readout == "sum":
            pool = dgl.sum_nodes
        elif self.readout == "max":
            pool = dgl.max_nodes
        else:  # "mean" and the default readout
            pool = dgl.mean_nodes
        return self.MLP_layer(pool(g, 'h'))
Example #30
0
    def forward(self, g, h, e, snorm_n, snorm_e):
        """MoNet forward pass with degree-based edge pseudo-coordinates.

        For each edge ``u -> v`` the pseudo-coordinate is
        ``[1 / sqrt(in_deg(u) + 1), 1 / sqrt(in_deg(v) + 1)]``. The ``+ 1``
        acts like a self-loop, so nodes with in-degree 0 do not divide by
        zero. The coordinates are projected per layer by ``self.pseudo_proj``.

        Returns
        -------
        ``self.MLP_layer`` applied to the graph readout. ``self.readout``
        selects ``"sum"`` or ``"max"``; anything else uses the mean.
        ``g.ndata['h']`` is overwritten.
        """
        h = self.embedding_h(h)

        # One vectorised degree lookup instead of two Python-level
        # ``g.in_degree`` calls per edge. A graph without edges now yields a
        # (0, 2) tensor instead of a shapeless empty one.
        us, vs = g.edges()
        in_deg = g.in_degrees().float()
        pseudo = torch.stack(
            (1 / torch.sqrt(in_deg[us] + 1), 1 / torch.sqrt(in_deg[vs] + 1)),
            dim=1,
        ).to(self.device)

        for i in range(len(self.layers)):
            h = self.layers[i](g, h, self.pseudo_proj[i](pseudo), snorm_n)
        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        else:
            hg = dgl.mean_nodes(g, 'h')  # "mean" and the default readout

        return self.MLP_layer(hg)