Example #1
    def forward(self, input_hidden, graphs: dgl.DGLGraph, batch_num_nodes=None):
        # Written against the pre-0.5 DGL API, where batch_num_nodes is a
        # list-valued attribute and send()/recv() are separate calls.
        if batch_num_nodes is None:
            b_num_nodes = graphs.batch_num_nodes
        else:
            b_num_nodes = batch_num_nodes
        h_t = self.input_proj(input_hidden)
        # if the graph has no edges, skip message passing and fall back to the raw node features
        if graphs.number_of_edges() > 0:
            # give all the nodes and edges information about the current query hidden state
            broadcasted_hn = dgl.broadcast_nodes(graphs, h_t)
            graphs.ndata['h_t'] = broadcasted_hn
            broadcasted_he = dgl.broadcast_edges(graphs, h_t)
            graphs.edata['h_t'] = broadcasted_he
            # create a copy of the node and edge states which will be updated for K iterations
            graphs.ndata['F_n_t'] = graphs.ndata['F_n']
            graphs.edata['F_e_t'] = graphs.edata['F_e']

            for _ in range(self.k_update_steps):
                graphs.ndata['s_n'] = self.object_score(torch.cat([graphs.ndata['h_t'], graphs.ndata['F_n_t']], dim=-1))
                graphs.send(message_func=self.io_attention_send)
                graphs.recv(reduce_func=self.io_attention_reduce)
                graphs.ndata['F_n_t'] = graphs.ndata['F_i_tplus1']
                if self.update_relations:
                    graphs.edata['F_e_t'] = graphs.edata['F_e_tplus1']

            io = torch.split(graphs.ndata['F_n_t'], split_size_or_sections=b_num_nodes)
        else:
            io = torch.split(graphs.ndata['F_n'], split_size_or_sections=b_num_nodes)
        io = pad_sequence(io, batch_first=True)
        io_mask = io.sum(dim=-1) != 0

        return io, io_mask
Example #2
    def forward(self, graph, global_attr=None, out_edge_key='h_e'):
        def send_func(edges):
            edges_to_collect = []
            if self._use_edges:
                edges_to_collect.append(edges.data[self.edge_key])
            if self._use_sender_nodes:
                edges_to_collect.append(edges.src[self.node_key])
            if self._use_receiver_nodes:
                edges_to_collect.append(edges.dst[self.node_key])
            if self._use_globals and global_attr is not None:
                # the global attribute was broadcast onto every edge before
                # apply_edges, so each edge already carries a per-edge copy
                expanded_global_attr = edges.data['expanded_global_attr']
                edges_to_collect.append(expanded_global_attr)

            collected_edges = torch.cat(edges_to_collect, dim=-1)

            if self.recurrent:
                return {
                    out_edge_key:
                    self.net(collected_edges, edges.data[out_edge_key])
                }
            else:
                return {out_edge_key: self.net(collected_edges)}

        # broadcast the (possibly batched) global attribute to every edge so
        # the UDF can read a per-edge copy from edata
        if global_attr is not None:
            graph.edata['expanded_global_attr'] = dgl.broadcast_edges(
                graph, global_attr)
        graph.apply_edges(send_func)

        return graph
Example #3
def test_broadcast_edges():
    # test#1: basic
    g0 = dgl.DGLGraph(nx.path_graph(10))
    feat0 = F.randn((40, ))
    ground_truth = F.stack([feat0] * g0.number_of_edges(), 0)
    assert F.allclose(dgl.broadcast_edges(g0, feat0), ground_truth)

    # test#2: batched graph
    g1 = dgl.DGLGraph(nx.path_graph(3))
    g2 = dgl.DGLGraph()
    g3 = dgl.DGLGraph(nx.path_graph(12))
    bg = dgl.batch([g0, g1, g2, g3])
    feat1 = F.randn((40, ))
    feat2 = F.randn((40, ))
    feat3 = F.randn((40, ))
    ground_truth = F.stack(
        [feat0] * g0.number_of_edges() +
        [feat1] * g1.number_of_edges() +
        [feat2] * g2.number_of_edges() +
        [feat3] * g3.number_of_edges(), 0
    )
    assert F.allclose(
        dgl.broadcast_edges(bg, F.stack([feat0, feat1, feat2, feat3], 0)),
        ground_truth)
Example #4
    def forward(self, g, h):
        for l in range(self.num_layers - 1):
            h, _ = self.gat[l](g, h, merge='flatten')
            h = F.elu(h)
        h, e = self.gat[-1](g, h, merge='mean')

        # Graph level prediction
        g.ndata['h'] = h
        h_readout = dgl.mean_nodes(g, 'h')
        h_pred = self.linear_h(h_readout)

        # Edge prediction
        eh = dgl.broadcast_edges(g, h_readout)
        e_fused = torch.cat((eh, e), dim=1)
        e_pred = self.linear_e(e_fused)

        return h_pred, e_pred
Example #5
def test_broadcast(idtype, g):
    g = g.astype(idtype).to(F.ctx())
    gfeat = F.randn((g.batch_size, 3))

    # Test.0: broadcast_nodes
    g.ndata['h'] = dgl.broadcast_nodes(g, gfeat)
    subg = dgl.unbatch(g)
    for i, sg in enumerate(subg):
        assert F.allclose(
            sg.ndata['h'],
            F.repeat(F.reshape(gfeat[i], (1, 3)), sg.number_of_nodes(), dim=0))

    # Test.1: broadcast_edges
    g.edata['h'] = dgl.broadcast_edges(g, gfeat)
    subg = dgl.unbatch(g)
    for i, sg in enumerate(subg):
        assert F.allclose(
            sg.edata['h'],
            F.repeat(F.reshape(gfeat[i], (1, 3)), sg.number_of_edges(), dim=0))
Example #6
    def forward(self, tree_graphs, tree_vec):
        device = tree_vec.device
        batch_size = tree_graphs.batch_size

        root_ids = get_root_ids(tree_graphs)

        if 'x' not in tree_graphs.ndata:
            tree_graphs.ndata['x'] = self.embedding(tree_graphs.ndata['wid'])
        if 'src_x' not in tree_graphs.edata:
            tree_graphs.apply_edges(fn.copy_u('x', 'src_x'))
        tree_graphs = tree_graphs.local_var()
        tree_graphs.apply_edges(func=lambda edges: {'dst_wid': edges.dst['wid']})

        line_tree_graphs = dgl.line_graph(tree_graphs, backtracking=False, shared=True)
        line_num_nodes = line_tree_graphs.num_nodes()
        line_tree_graphs.ndata.update({
            'src_x_r': self.W_r(line_tree_graphs.ndata['src_x']),
            # Exploit the fact that the reduce function is a sum of incoming messages,
            # and uncomputed messages are zero vectors.
            'h': torch.zeros(line_num_nodes, self.hidden_size).to(device),
            # one copy of tree_vec per tree edge, i.e. per line-graph node
            'vec': dgl.broadcast_edges(tree_graphs, tree_vec),
            'sum_h': torch.zeros(line_num_nodes, self.hidden_size).to(device),
            'sum_gated_h': torch.zeros(line_num_nodes, self.hidden_size).to(device)
        })

        # input tensors for stop prediction (p) and label prediction (q)
        pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
        stop_hiddens, stop_targets = [], []

        # Predict root
        pred_hiddens.append(torch.zeros(batch_size, self.hidden_size).to(device))
        pred_targets.append(tree_graphs.ndata['wid'][root_ids.to(device)])
        pred_mol_vecs.append(tree_vec)

        # Traverse the tree and predict on children
        for eid, p in dfs_order(tree_graphs, root_ids.to(dtype=tree_graphs.idtype)):
            eid = eid.to(device)
            p = p.to(device=device, dtype=tree_graphs.idtype)

            # Message passing excluding the target
            line_tree_graphs.pull(v=eid, message_func=fn.copy_u('h', 'h_nei'),
                                  reduce_func=fn.sum('h_nei', 'sum_h'))
            line_tree_graphs.pull(v=eid, message_func=self.gru_message,
                                  reduce_func=fn.sum('m', 'sum_gated_h'))
            line_tree_graphs.apply_nodes(self.gru_update, v=eid)

            # Node aggregation including the target
            # By construction, the edges of the raw graph follow the order of
            # (i1, j1), (j1, i1), (i2, j2), (j2, i2), ... The order of the nodes
            # in the line graph corresponds to the order of the edges in the raw graph.
            eid = eid.long()
            reverse_eid = torch.bitwise_xor(eid, torch.tensor(1).to(device))
            cur_o = line_tree_graphs.ndata['sum_h'][eid] + \
                    line_tree_graphs.ndata['h'][reverse_eid]

            # Gather targets
            mask = (p == torch.tensor(0).to(device))
            pred_list = eid[mask]
            stop_target = torch.tensor(1).to(device) - p

            # Hidden states for stop prediction
            stop_hidden = torch.cat([line_tree_graphs.ndata['src_x'][eid],
                                     cur_o, line_tree_graphs.ndata['vec'][eid]], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_targets.extend(stop_target)

            # Hidden states for clique prediction
            if len(pred_list) > 0:
                pred_mol_vecs.append(line_tree_graphs.ndata['vec'][pred_list])
                pred_hiddens.append(line_tree_graphs.ndata['h'][pred_list])
                pred_targets.append(line_tree_graphs.ndata['dst_wid'][pred_list])

        # Last stop at root
        root_ids = root_ids.to(device)
        cur_x = tree_graphs.ndata['x'][root_ids]
        tree_graphs.edata['h'] = line_tree_graphs.ndata['h']
        tree_graphs.pull(v=root_ids.to(dtype=tree_graphs.idtype),
                         message_func=fn.copy_e('h', 'm'), reduce_func=fn.sum('m', 'cur_o'))
        stop_hidden = torch.cat([cur_x, tree_graphs.ndata['cur_o'][root_ids], tree_vec], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_targets.extend(torch.zeros(batch_size).to(device))

        # Predict next clique
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
        pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
        pred_vecs = F.relu(self.W(pred_vecs))
        pred_scores = self.W_o(pred_vecs)
        pred_targets = torch.cat(pred_targets, dim=0)

        pred_loss = self.pred_loss(pred_scores, pred_targets) / batch_size
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        # Predict stop
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_vecs = F.relu(self.U(stop_hiddens))
        stop_scores = self.U_s(stop_vecs).squeeze()
        stop_targets = torch.Tensor(stop_targets).to(device)

        stop_loss = self.stop_loss(stop_scores, stop_targets) / batch_size
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()