Example #1
    def forward(self, cand_batch, mol_tree_batch):
        cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = \
                mol2dgl(cand_batch, mol_tree_batch)

        n_samples = len(cand_graphs)

        cand_graphs = batch(cand_graphs)
        cand_line_graph = line_graph(cand_graphs, backtracking=False)

        n_nodes = cand_graphs.number_of_nodes()
        n_edges = cand_graphs.number_of_edges()

        cand_graphs = self.run(cand_graphs, cand_line_graph,
                               tree_mess_src_edges, tree_mess_tgt_edges,
                               tree_mess_tgt_nodes, mol_tree_batch)

        cand_graphs = unbatch(cand_graphs)
        g_repr = torch.stack(
            [g.get_n_repr()['h'].mean(0) for g in cand_graphs], 0)

        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes
        self.n_edges_total += n_edges
        self.n_passes += 1

        return g_repr
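These examples all revolve around the same API: `dgl.line_graph` turns each edge of the input graph into a node of the line graph, so per-edge features can be treated as node features during message passing. A minimal standalone sketch of that correspondence (assuming DGL >= 0.5; the graph and feature sizes are made up):

import dgl
import torch

# A directed triangle: edge 0 is 0->1, edge 1 is 1->2, edge 2 is 2->0.
g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.edata['h'] = torch.randn(g.num_edges(), 4)

# backtracking=False omits line-graph edges of the form (u->v) -> (v->u).
lg = dgl.line_graph(g, backtracking=False)
lg.ndata['h'] = g.edata['h']  # node i of lg corresponds to edge i of g
print(lg.num_nodes())         # 3, one line-graph node per edge of g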
Example #2
    def _load_graph(self):
        num_graphs = self.label.shape[0]
        graphs = []
        line_graphs = []

        for idx in trange(num_graphs):
            n_atoms = self.N[idx]
            # get all the atomic coordinates of the idx-th molecular graph
            R = self.R[self.N_cumsum[idx]:self.N_cumsum[idx + 1]]
            # calculate pairwise distances between all atoms
            dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
            # keep all edges that don't exceed the cutoff and delete self-loops
            adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(n_atoms,
                                                              dtype=bool)
            adj = adj.tocoo()
            u, v = torch.tensor(adj.row), torch.tensor(adj.col)
            g = dgl_graph((u, v))
            g.ndata['R'] = torch.tensor(R, dtype=torch.float32)
            g.ndata['Z'] = torch.tensor(
                self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]],
                dtype=torch.long)

            # add user-defined features
            if self.edge_funcs is not None:
                for func in self.edge_funcs:
                    g.apply_edges(func)

            graphs.append(g)
            l_g = dgl.line_graph(g, backtracking=False)
            line_graphs.append(l_g)

        return graphs, line_graphs
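The adjacency construction above is compact but easy to misread. A standalone sketch of the same idea with dense NumPy, where random coordinates and a made-up threshold stand in for the dataset's `R` and `self.cutoff`:

import numpy as np
import torch
import dgl

R = np.random.rand(5, 3)  # 5 atoms with 3-D coordinates
cutoff = 0.8

# pairwise distances; keep edges within the cutoff, minus self-loops
dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
adj = (dist <= cutoff) & ~np.eye(5, dtype=bool)

src, dst = np.nonzero(adj)
g = dgl.graph((torch.tensor(src), torch.tensor(dst)))
lg = dgl.line_graph(g, backtracking=False)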
Example #3
    def forward(self, mol_trees):
        mol_tree_batch = batch(mol_trees)

        # Build line graph to prepare for belief propagation
        mol_tree_batch_lg = line_graph(mol_tree_batch, backtracking=False)

        return self.run(mol_tree_batch, mol_tree_batch_lg)
Example #4
    def forward(self, mol_graph):
        mol_graph = mol_graph.local_var()
        line_mol_graph = dgl.line_graph(mol_graph, backtracking=False)

        line_input = self.W_i(mol_graph.edata['x'])
        line_mol_graph.ndata['msg_input'] = line_input
        line_mol_graph.ndata['msg'] = F.relu(line_input)

        # Message passing over the line graph
        for _ in range(self.depth - 1):
            line_mol_graph.update_all(message_func=fn.copy_u('msg', 'msg'),
                                      reduce_func=fn.sum('msg', 'nei_msg'))
            nei_msg = self.W_h(line_mol_graph.ndata['nei_msg'])
            line_mol_graph.ndata['msg'] = F.relu(line_input + nei_msg)

        # Message passing over the raw graph
        mol_graph.edata['msg'] = line_mol_graph.ndata['msg']
        mol_graph.update_all(message_func=fn.copy_e('msg', 'msg'),
                             reduce_func=fn.sum('msg', 'nei_msg'))

        raw_input = torch.cat([mol_graph.ndata['x'], mol_graph.ndata['nei_msg']], dim=1)
        mol_graph.ndata['atom_hiddens'] = self.W_o(raw_input)

        # Readout
        mol_vecs = dgl.mean_nodes(mol_graph, 'atom_hiddens')

        return mol_vecs
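The handoff in the last block works because node i of the line graph is edge i of the raw graph, so the two feature tensors line up by construction. A toy sketch of the same two-stage pattern (hypothetical feature size):

import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 1, 1, 2]), torch.tensor([1, 0, 2, 1])))
lg = dgl.line_graph(g, backtracking=False)
lg.ndata['msg'] = torch.randn(lg.num_nodes(), 8)

# one round of message passing over the line graph
lg.update_all(fn.copy_u('msg', 'msg'), fn.sum('msg', 'nei_msg'))

# hand the per-edge messages back to the raw graph, aggregate per node
g.edata['msg'] = lg.ndata['msg']
g.update_all(fn.copy_e('msg', 'msg'), fn.sum('msg', 'nei_msg'))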
Example #5
    def forward(self, mol_tree_batch):
        # Build line graph to prepare for belief propagation
        mol_tree_batch_lg = dgl.line_graph(mol_tree_batch,
                                           backtracking=False,
                                           shared=True)
        mol_tree_batch_lg._node_frames = mol_tree_batch._edge_frames

        return self.run(mol_tree_batch, mol_tree_batch_lg)
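`_node_frames` and `_edge_frames` are private DGL attributes: the assignment above makes the line graph's node storage the same object as the tree's edge storage, so features written during belief propagation land directly on the original edges. When that aliasing isn't required, a version-safe sketch is to copy the features instead, which is what the `line_graph` wrapper in Example #15 does:

import dgl
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g.edata['h'] = torch.randn(g.num_edges(), 4)

lg = dgl.line_graph(g, backtracking=False)
lg.ndata.update(g.edata)  # a copy; writes to lg.ndata won't alias g.edata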
Example #6
    def forward(self, mol_trees, tree_vec):
        '''
        The training procedure which computes the prediction loss given the
        ground truth tree
        '''
        mol_tree_batch = batch(mol_trees)
        mol_tree_batch_lg = line_graph(mol_tree_batch, backtracking=False)
        n_trees = len(mol_trees)

        return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)
Example #7
    def forward(self, mol_graph):
        mol_line_graph = dgl.line_graph(mol_graph, backtracking=False, shared=True)
        mol_line_graph._node_frames = mol_graph._edge_frames

        mol_graph = self.run(mol_graph, mol_line_graph)

        # TODO: replace with unbatch or readout
        g_repr = mean_nodes(mol_graph, 'h')

        return g_repr
Example #8
    def forward(self, mol_trees):
        mol_tree_batch = batch(mol_trees)
        if torch.cuda.is_available() and not os.getenv('NOCUDA', None):
            mol_tree_batch = mol_tree_batch.to('cuda:0')

        # Build line graph to prepare for belief propagation
        mol_tree_batch_lg = dgl.line_graph(mol_tree_batch,
                                           backtracking=False,
                                           shared=True)

        return self.run(mol_tree_batch, mol_tree_batch_lg)
Example #9
    def forward(self, mol_trees, tree_vec):
        '''
        The training procedure which computes the prediction loss given the
        ground truth tree
        '''
        mol_tree_batch = batch(mol_trees)
        mol_tree_batch_lg = dgl.line_graph(mol_tree_batch,
                                           backtracking=False,
                                           shared=True)
        mol_tree_batch_lg._node_frames = mol_tree_batch._edge_frames
        n_trees = len(mol_trees)

        return self.run(mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec)
Example #10
    def forward(self, cand_batch, mol_tree_batch):
        cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = cand_batch

        cand_line_graph = dgl.line_graph(cand_graphs, backtracking=False, shared=True)
        cand_line_graph._node_frames = cand_graphs._edge_frames

        cand_graphs = self.run(
            cand_graphs, cand_line_graph, tree_mess_src_edges, tree_mess_tgt_edges,
            tree_mess_tgt_nodes, mol_tree_batch)

        g_repr = mean_nodes(cand_graphs, 'h')

        return g_repr
Example #11
    def forward(self, mol_graph):
        n_samples = mol_graph.batch_size

        mol_line_graph = line_graph(mol_graph, backtracking=False, shared=True)

        n_nodes = mol_graph.number_of_nodes()
        n_edges = mol_graph.number_of_edges()

        mol_graph = self.run(mol_graph, mol_line_graph)

        # TODO: replace with unbatch or readout
        g_repr = mean_nodes(mol_graph, 'h')

        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes
        self.n_edges_total += n_edges
        self.n_passes += 1

        return g_repr
Example #12
    def forward(self, tree_graphs):
        device = tree_graphs.device
        if 'x' not in tree_graphs.ndata:
            tree_graphs.ndata['x'] = self.embedding(tree_graphs.ndata['wid'])
        tree_graphs.apply_edges(fn.copy_u('x', 'src_x'))
        tree_graphs = tree_graphs.local_var()

        line_tree_graphs = dgl.line_graph(tree_graphs, backtracking=False)
        line_tree_graphs.ndata.update({
            'src_x': tree_graphs.edata['src_x'],
            'src_x_r': self.W_r(tree_graphs.edata['src_x']),
            # Exploit the fact that the reduce function is a sum of incoming messages,
            # and uncomputed messages are zero vectors.
            'h': torch.zeros(line_tree_graphs.num_nodes(), self.hidden_size).to(device),
            'sum_h': torch.zeros(line_tree_graphs.num_nodes(), self.hidden_size).to(device),
            'sum_gated_h': torch.zeros(line_tree_graphs.num_nodes(), self.hidden_size).to(device)
        })

        # Get the IDs of the root nodes, i.e. the first node of each tree
        root_ids = get_root_ids(tree_graphs)

        for eid in level_order(tree_graphs, root_ids.to(dtype=tree_graphs.idtype)):
            line_tree_graphs.pull(v=eid, message_func=fn.copy_u('h', 'h_nei'),
                                  reduce_func=fn.sum('h_nei', 'sum_h'))
            line_tree_graphs.pull(v=eid, message_func=self.gru_message,
                                  reduce_func=fn.sum('m', 'sum_gated_h'))
            line_tree_graphs.apply_nodes(self.gru_update, v=eid)

        # Readout
        tree_graphs.ndata['h'] = torch.zeros(tree_graphs.num_nodes(), self.hidden_size).to(device)
        tree_graphs.edata['h'] = line_tree_graphs.ndata['h']
        root_ids = root_ids.to(device)
        tree_graphs.pull(v=root_ids.to(dtype=tree_graphs.idtype),
                         message_func=fn.copy_e('h', 'm'),
                         reduce_func=fn.sum('m', 'h'))
        root_vec = torch.cat([
            tree_graphs.ndata['x'][root_ids],
            tree_graphs.ndata['h'][root_ids]
        ], dim=1)
        root_vec = self.W(root_vec)

        return tree_graphs.edata['h'], root_vec
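`get_root_ids` and `level_order` are assumed to come from the surrounding project. The contract the loop relies on is that edge IDs arrive parents-first, one BFS frontier at a time, so each pulled node already has complete incoming messages; a plausible sketch of such a helper on top of DGL's traversal API (note the edge IDs double as node IDs of the line graph):

import dgl

def level_order_sketch(g, root_ids):
    # dgl.bfs_edges_generator yields one tensor of edge IDs per BFS level.
    yield from dgl.bfs_edges_generator(g, root_ids)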
Example #13
    def forward(self, cand_batch, mol_tree_batch):
        cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = cand_batch

        n_samples = len(cand_graphs)

        cand_line_graph = line_graph(cand_graphs,
                                     backtracking=False,
                                     shared=True)

        n_nodes = cand_graphs.number_of_nodes()
        n_edges = cand_graphs.number_of_edges()

        cand_graphs = self.run(cand_graphs, cand_line_graph,
                               tree_mess_src_edges, tree_mess_tgt_edges,
                               tree_mess_tgt_nodes, mol_tree_batch)

        g_repr = mean_nodes(cand_graphs, 'h')

        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes
        self.n_edges_total += n_edges
        self.n_passes += 1

        return g_repr
Example #14
    def decode(self, mol_vec):
        assert mol_vec.shape[0] == 1

        mol_tree = DGLMolTree(None)

        init_hidden = cuda(torch.zeros(1, self.hidden_size))

        root_hidden = torch.cat([init_hidden, mol_vec], 1)
        root_hidden = F.relu(self.W(root_hidden))
        root_score = self.W_o(root_hidden)
        _, root_wid = torch.max(root_score, 1)
        root_wid = root_wid.view(1)

        mol_tree.add_nodes(1)  # root
        mol_tree.nodes[0].data['wid'] = root_wid
        mol_tree.nodes[0].data['x'] = self.embedding(root_wid)
        mol_tree.nodes[0].data['h'] = init_hidden
        mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0]))
        mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
            self.vocab.get_smiles(root_wid))

        stack, trace = [], []
        stack.append((0, self.vocab.get_slots(root_wid)))

        all_nodes = {0: root_node_dict}
        first = True
        new_node_id = 0
        new_edge_id = 0

        for step in range(MAX_DECODE_LEN):
            u, u_slots = stack[-1]
            udata = mol_tree.nodes[u].data
            x = udata['x']
            h = udata['h']

            # Predict stop
            p_input = torch.cat([x, h, mol_vec], 1)
            p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
            backtrack = (p_score.item() < 0.5)

            if not backtrack:
                # Predict next clique.  Note that the prediction may fail due
                # to lack of assemblable components
                mol_tree.add_nodes(1)
                new_node_id += 1
                v = new_node_id
                mol_tree.add_edges(u, v)
                uv = new_edge_id
                new_edge_id += 1

                if first:
                    mol_tree.edata.update({
                        's': cuda(torch.zeros(1, self.hidden_size)),
                        'm': cuda(torch.zeros(1, self.hidden_size)),
                        'r': cuda(torch.zeros(1, self.hidden_size)),
                        'z': cuda(torch.zeros(1, self.hidden_size)),
                        'src_x': cuda(torch.zeros(1, self.hidden_size)),
                        'dst_x': cuda(torch.zeros(1, self.hidden_size)),
                        'rm': cuda(torch.zeros(1, self.hidden_size)),
                        'accum_rm': cuda(torch.zeros(1, self.hidden_size)),
                    })
                    first = False

                mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x']
                # Keeping dst_x at zero is fine: h on the new edge doesn't depend on it.

                # DGL doesn't dynamically maintain a line graph.
                mol_tree_lg = dgl.line_graph(mol_tree,
                                             backtracking=False,
                                             shared=True)

                mol_tree_lg.pull(
                    uv,
                    dec_tree_edge_msg,
                    dec_tree_edge_reduce,
                    self.dec_tree_edge_update.update_zm,
                )
                mol_tree.pull(
                    v,
                    dec_tree_node_msg,
                    dec_tree_node_reduce,
                )

                vdata = mol_tree.nodes[v].data
                h_v = vdata['h']
                q_input = torch.cat([h_v, mol_vec], 1)
                q_score = torch.softmax(self.W_o(torch.relu(self.W(q_input))),
                                        -1)
                _, sort_wid = torch.sort(q_score, 1, descending=True)
                sort_wid = sort_wid.squeeze()

                next_wid = None
                for wid in sort_wid.tolist()[:5]:
                    slots = self.vocab.get_slots(wid)
                    cand_node_dict = create_node_dict(
                        self.vocab.get_smiles(wid))
                    if (have_slots(u_slots, slots)
                            and can_assemble(mol_tree, u, cand_node_dict)):
                        next_wid = wid
                        next_slots = slots
                        next_node_dict = cand_node_dict
                        break

                if next_wid is None:
                    # Failed to add an actual child; v is a spurious
                    # node, so we mark it.
                    vdata['fail'] = cuda(torch.tensor([1]))
                    backtrack = True
                else:
                    next_wid = cuda(torch.tensor([next_wid]))
                    vdata['wid'] = next_wid
                    vdata['x'] = self.embedding(next_wid)
                    mol_tree.nodes_dict[v] = next_node_dict
                    all_nodes[v] = next_node_dict
                    stack.append((v, next_slots))
                    mol_tree.add_edges(v, u)
                    vu = new_edge_id
                    new_edge_id += 1
                    mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data['x']
                    mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data['x']
                    mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data['x']

                    # DGL doesn't dynamically maintain a line graph.
                    mol_tree_lg = dgl.line_graph(mol_tree,
                                                 backtracking=False,
                                                 shared=True)
                    mol_tree_lg.apply_nodes(self.dec_tree_edge_update.update_r,
                                            uv)

            if backtrack:
                if len(stack) == 1:
                    break  # At root, terminate

                pu, _ = stack[-2]
                u_pu = mol_tree.edge_id(u, pu)

                mol_tree_lg.pull(
                    u_pu,
                    dec_tree_edge_msg,
                    dec_tree_edge_reduce,
                    self.dec_tree_edge_update,
                )
                mol_tree.pull(
                    pu,
                    dec_tree_node_msg,
                    dec_tree_node_reduce,
                )
                stack.pop()

        effective_nodes = mol_tree.filter_nodes(
            lambda nodes: nodes.data['fail'] != 1)
        effective_nodes, _ = torch.sort(effective_nodes)
        return mol_tree, all_nodes, effective_nodes
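As the comments in `decode` note, DGL doesn't keep a line graph in sync with its source graph, which is why the method rebuilds `mol_tree_lg` after every structural change. A standalone sketch of that constraint:

import dgl
import torch

g = dgl.graph((torch.tensor([0]), torch.tensor([1])))
lg = dgl.line_graph(g)  # 1 node (edge 0), no line-graph edges

g.add_edges(torch.tensor([1]), torch.tensor([2]))
lg = dgl.line_graph(g)  # must rebuild: now 2 nodes, with edge 0 -> 1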
Example #15
def line_graph(g, backtracking=True, shared=False):
    # Build the line graph and attach the original edge features as
    # node features of the line graph.
    # (Optionally move g to CPU first and the result back to g.device.)
    g2 = dgl.line_graph(g, backtracking, shared)
    g2.ndata.update(g.edata)
    return g2
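A quick usage sketch of this wrapper on a made-up graph: the line graph comes back with the original edge features already attached as node features.

import dgl
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g.edata['w'] = torch.ones(2, 3)

lg = line_graph(g, backtracking=False)
assert torch.equal(lg.ndata['w'], g.edata['w'])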
Example #16
    def forward(self, tree_graphs, tree_vec):
        device = tree_vec.device
        batch_size = tree_graphs.batch_size

        root_ids = get_root_ids(tree_graphs)

        if 'x' not in tree_graphs.ndata:
            tree_graphs.ndata['x'] = self.embedding(tree_graphs.ndata['wid'])
        if 'src_x' not in tree_graphs.edata:
            tree_graphs.apply_edges(fn.copy_u('x', 'src_x'))
        tree_graphs = tree_graphs.local_var()
        tree_graphs.apply_edges(func=lambda edges: {'dst_wid': edges.dst['wid']})

        line_tree_graphs = dgl.line_graph(tree_graphs, backtracking=False, shared=True)
        line_num_nodes = line_tree_graphs.num_nodes()
        line_tree_graphs.ndata.update({
            'src_x_r': self.W_r(line_tree_graphs.ndata['src_x']),
            # Exploit the fact that the reduce function is a sum of incoming messages,
            # and uncomputed messages are zero vectors.
            'h': torch.zeros(line_num_nodes, self.hidden_size).to(device),
            'vec': dgl.broadcast_edges(tree_graphs, tree_vec),
            'sum_h': torch.zeros(line_num_nodes, self.hidden_size).to(device),
            'sum_gated_h': torch.zeros(line_num_nodes, self.hidden_size).to(device)
        })

        # input tensors for stop prediction (p) and label prediction (q)
        pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
        stop_hiddens, stop_targets = [], []

        # Predict root
        pred_hiddens.append(torch.zeros(batch_size, self.hidden_size).to(device))
        pred_targets.append(tree_graphs.ndata['wid'][root_ids.to(device)])
        pred_mol_vecs.append(tree_vec)

        # Traverse the tree and predict on children
        for eid, p in dfs_order(tree_graphs, root_ids.to(dtype=tree_graphs.idtype)):
            eid = eid.to(device)
            p = p.to(device=device, dtype=tree_graphs.idtype)

            # Message passing excluding the target
            line_tree_graphs.pull(v=eid, message_func=fn.copy_u('h', 'h_nei'),
                                  reduce_func=fn.sum('h_nei', 'sum_h'))
            line_tree_graphs.pull(v=eid, message_func=self.gru_message,
                                  reduce_func=fn.sum('m', 'sum_gated_h'))
            line_tree_graphs.apply_nodes(self.gru_update, v=eid)

            # Node aggregation including the target
            # By construction, the edges of the raw graph follow the order of
            # (i1, j1), (j1, i1), (i2, j2), (j2, i2), ... The order of the nodes
            # in the line graph corresponds to the order of the edges in the raw graph.
            eid = eid.long()
            reverse_eid = torch.bitwise_xor(eid, torch.tensor(1).to(device))
            cur_o = line_tree_graphs.ndata['sum_h'][eid] + \
                    line_tree_graphs.ndata['h'][reverse_eid]

            # Gather targets
            mask = (p == torch.tensor(0).to(device))
            pred_list = eid[mask]
            stop_target = torch.tensor(1).to(device) - p

            # Hidden states for stop prediction
            stop_hidden = torch.cat([line_tree_graphs.ndata['src_x'][eid],
                                     cur_o, line_tree_graphs.ndata['vec'][eid]], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_targets.extend(stop_target)

            # Hidden states for clique prediction
            if len(pred_list) > 0:
                pred_mol_vecs.append(line_tree_graphs.ndata['vec'][pred_list])
                pred_hiddens.append(line_tree_graphs.ndata['h'][pred_list])
                pred_targets.append(line_tree_graphs.ndata['dst_wid'][pred_list])

        # Last stop at root
        root_ids = root_ids.to(device)
        cur_x = tree_graphs.ndata['x'][root_ids]
        tree_graphs.edata['h'] = line_tree_graphs.ndata['h']
        tree_graphs.pull(v=root_ids.to(dtype=tree_graphs.idtype),
                         message_func=fn.copy_e('h', 'm'), reduce_func=fn.sum('m', 'cur_o'))
        stop_hidden = torch.cat([cur_x, tree_graphs.ndata['cur_o'][root_ids], tree_vec], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_targets.extend(torch.zeros(batch_size).to(device))

        # Predict next clique
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
        pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
        pred_vecs = F.relu(self.W(pred_vecs))
        pred_scores = self.W_o(pred_vecs)
        pred_targets = torch.cat(pred_targets, dim=0)

        pred_loss = self.pred_loss(pred_scores, pred_targets) / batch_size
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        # Predict stop
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_vecs = F.relu(self.U(stop_hiddens))
        stop_scores = self.U_s(stop_vecs).squeeze()
        stop_targets = torch.Tensor(stop_targets).to(device)

        stop_loss = self.stop_loss(stop_scores, stop_targets) / batch_size
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
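The `eid ^ 1` step leans on the ordering invariant spelled out in the comment: edges come in (i, j), (j, i) pairs, so the reverse of edge e is e XOR 1. A standalone check of that invariant on a toy graph:

import dgl
import torch

src = torch.tensor([0, 1, 1, 2])  # pairs: (0,1),(1,0),(1,2),(2,1)
dst = torch.tensor([1, 0, 2, 1])
g = dgl.graph((src, dst))

eid = torch.arange(g.num_edges())
rev = torch.bitwise_xor(eid, torch.tensor(1))
u, v = g.find_edges(eid)
ru, rv = g.find_edges(rev)
assert torch.equal(u, rv) and torch.equal(v, ru)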