Example #1
    def forward(self, mol_graph):
        fatoms, fbonds, agraph, bgraph, scope = mol_graph
        fatoms = create_var(fatoms)
        fbonds = create_var(fbonds)
        agraph = create_var(agraph)
        bgraph = create_var(bgraph)

        binput = self.W_i(fbonds)
        message = nn.ReLU()(binput)

        for _ in range(self.depth - 1):
            nei_message = index_select_ND(message, 0, bgraph)
            nei_message = nei_message.sum(dim=1)
            nei_message = self.W_h(nei_message)
            message = nn.ReLU()(binput + nei_message)

        nei_message = index_select_ND(message, 0, agraph)
        nei_message = nei_message.sum(dim=1)
        ainput = torch.cat([fatoms, nei_message], dim=1)
        atom_hiddens = nn.ReLU()(self.W_o(ainput))

        mol_vecs = []
        for st, le in scope:
            mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
            mol_vecs.append(mol_vec)

        mol_vecs = torch.stack(mol_vecs, dim=0)
        return mol_vecs
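All of these snippets rely on two small helpers, create_var and index_select_ND, that are not shown here. Below is a minimal sketch of what they could look like, consistent with how they are called above (index_select_ND gathers rows of a 2-D tensor using an index tensor of arbitrary shape, with row 0 acting as padding); the actual implementations in the source repository may differ.

import torch

def create_var(tensor, requires_grad=None):
    # Sketch: move the tensor to the GPU when one is available
    # (older PyTorch code wrapped it in a Variable here).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = tensor.to(device)
    if requires_grad is not None:
        tensor.requires_grad_(requires_grad)
    return tensor

def index_select_ND(source, dim, index):
    # Gather rows of `source` along `dim` (0 in every call above) with an
    # index tensor of any shape; the output has shape index.shape + source.shape[1:].
    suffix = source.size()[1:]
    out = source.index_select(dim, index.reshape(-1))
    return out.view(index.size() + suffix)

With these in place, index_select_ND(message, 0, bgraph) in the loop above yields a (num_bonds, MAX_NB, hidden_size) tensor of incoming messages, which is then summed over the neighbor dimension.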
Example #2
    def forward(self, fnode, fmess, node_graph, mess_graph, scope):
        fnode = create_var(fnode)
        fmess = create_var(fmess)
        node_graph = create_var(node_graph)
        mess_graph = create_var(mess_graph)
        messages = create_var(torch.zeros(mess_graph.size(0),
                                          self.hidden_size))

        fnode = self.embedding(fnode)
        fmess = index_select_ND(fnode, 0, fmess)
        messages = self.GRU(messages, fmess, mess_graph)

        mess_nei = index_select_ND(messages, 0, node_graph)
        node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        node_vecs = self.outputNN(node_vecs)

        max_len = max([x for _, x in scope])
        batch_vecs = []
        for st, le in scope:
            cur_vecs = node_vecs[st:st + le]
            cur_vecs = F.pad(cur_vecs, (0, 0, 0, max_len - le))
            batch_vecs.append(cur_vecs)

        tree_vecs = torch.stack(batch_vecs, dim=0)
        return tree_vecs, messages
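The F.pad call above pads only along the node dimension, so every tree in the batch ends up with max_len rows before stacking. A self-contained toy illustration of that padding pattern (sizes chosen arbitrarily):

import torch
import torch.nn.functional as F

node_vecs = torch.randn(7, 4)        # 7 node vectors from two trees, hidden size 4
scope = [(0, 3), (3, 4)]             # (start, length) of each tree's node block
max_len = max(le for _, le in scope)

batch_vecs = []
for st, le in scope:
    cur_vecs = node_vecs[st:st + le]
    cur_vecs = F.pad(cur_vecs, (0, 0, 0, max_len - le))  # pad rows, not features
    batch_vecs.append(cur_vecs)

tree_vecs = torch.stack(batch_vecs, dim=0)
print(tree_vecs.shape)               # torch.Size([2, 4, 4])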
Example #3
    def forward(self, fnode, fmess, node_graph, mess_graph, scope):
        fnode = create_var(fnode)
        fmess = create_var(fmess)
        node_graph = create_var(node_graph)
        mess_graph = create_var(mess_graph)
        messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

        fnode = self.embedding(fnode)

        fmess = index_select_ND(fnode, 0, fmess)
        messages = self.GRU(messages, fmess, mess_graph)

        mess_nei = index_select_ND(messages, 0, node_graph)
        node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        node_vecs = self.outputNN(node_vecs)

        batch_vecs = []
        for st, _ in scope:
            cur_vecs = node_vecs[st]  # Root is the first node
            batch_vecs.append(cur_vecs)

        tree_vecs = torch.stack(batch_vecs, dim=0)
        return tree_vecs, messages
Example #4
    def forward(self, fatoms, fbonds, agraph, bgraph, scope, tree_message): #tree_message[0] == vec(0)
        fatoms = create_var(fatoms)
        fbonds = create_var(fbonds)
        agraph = create_var(agraph)
        bgraph = create_var(bgraph)

        binput = self.W_i(fbonds)
        graph_message = F.relu(binput)

        for _ in range(self.depth - 1):
            message = torch.cat([tree_message, graph_message], dim=0)
            nei_message = index_select_ND(message, 0, bgraph)
            nei_message = nei_message.sum(dim=1) #assuming tree_message[0] == vec(0)
            nei_message = self.W_h(nei_message)
            graph_message = F.relu(binput + nei_message)

        message = torch.cat([tree_message,graph_message], dim=0)
        nei_message = index_select_ND(message, 0, agraph)
        nei_message = nei_message.sum(dim=1)
        ainput = torch.cat([fatoms, nei_message], dim=1)
        atom_hiddens = F.relu(self.W_o(ainput))
        
        mol_vecs = []
        for st,le in scope:
            mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
            mol_vecs.append(mol_vec)

        mol_vecs = torch.stack(mol_vecs, dim=0)
        return mol_vecs
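Here tree_message is concatenated in front of graph_message, so the indices stored in agraph and bgraph address a single shared table: entries smaller than tree_message.size(0) pick tree messages, the rest pick bond messages, and index 0 is reserved for the zero vector (hence the comments about tree_message[0] == vec(0)). A toy illustration of that shared index space (the shapes are assumptions):

import torch

tree_message = torch.zeros(3, 8)   # index 0 kept as the zero vector
graph_message = torch.ones(5, 8)   # bond messages
message = torch.cat([tree_message, graph_message], dim=0)  # shape (8, 8)

# Index 1 addresses a tree message, index 3 + 2 = 5 addresses graph_message[2];
# index 0 contributes nothing when neighbor messages are summed.
print(message[1].sum().item(), message[5].sum().item())  # 0.0 8.0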
Example #5
    def forward(self, fnode, fmess, node_graph, mess_graph, scope):
        fnode = create_var(fnode)
        fmess = create_var(fmess)
        node_graph = create_var(node_graph)
        mess_graph = create_var(mess_graph)
        messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

        fnode = self.embedding(fnode)
        fmess1 = index_select_ND(fnode, 0, fmess[:, 0])
        fmess2 = self.E_pos(fmess[:, 1])
        fmess = self.inputNN( torch.cat([fmess1,fmess2], dim=-1) )
        messages = self.GRU(messages, fmess, mess_graph)

        mess_nei = index_select_ND(messages, 0, node_graph)
        node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        node_vecs = self.outputNN(node_vecs)

        batch_vecs = []
        for st, _ in scope:
            cur_vecs = node_vecs[st]  # Root is the first node
            batch_vecs.append(cur_vecs)

        tree_vecs = torch.stack(batch_vecs, dim=0)
        return tree_vecs, messages
Example #6
    def encode_node(self, tree_tensors, hatom, node_idx):
        """ Return the node embeddings learned by the MPN, given the tree tensors, the
        learned atom embeddings, and the indices of the nodes to be encoded.
        """
        hnode, hmess, agraph, bgraph = self.embed_tree(tree_tensors, hatom)
        hnode = index_select_ND(hnode, 0, node_idx)
        agraph = index_select_ND(agraph, 0, node_idx)
        hnode, _ = self.mpn(hnode, hmess, agraph, bgraph, self.depthT,
                            self.W_t, self.outputNode)
        return hnode
Example #7
    def embed_tree(self, tree_tensors, hatom):
        fnode, fmess, agraph, bgraph, cgraph, dgraph, _ = tree_tensors
        finput = self.embedding(fnode)

        hnode = index_select_ND(hatom, 0, dgraph).sum(dim=1)
        hnode = self.W_i(torch.cat([finput, hnode], dim=-1))

        hmess1 = hnode.index_select(index=fmess[:, 0], dim=0)
        hmess2 = index_select_ND(hatom, 0, cgraph).sum(dim=1)
        hmess = self.W_g(torch.cat([hmess1, hmess2], dim=-1))

        return hnode, hmess, agraph, bgraph
Example #8
    def forward(self, tree_tensors, graph_tensors, orders):
        tensors = self.embed_graph(graph_tensors)
        hatom, _ = self.mpn(*tensors, self.depthG, self.W_a, self.outputAtom)
        hatom[0, :] = hatom[0, :] * 0

        tensors = self.embed_tree(tree_tensors, hatom)
        hnode, _ = self.mpn(*tensors, self.depthT, self.W_t, self.outputNode)
        hnode[0, :] = hnode[0, :] * 0

        revise_nodes = [[edge[1] for edge in order] for order in orders]
        revise_nodes = create_pad_tensor(revise_nodes).to(device).long()
        embedding = index_select_ND(hnode, 0, revise_nodes).sum(dim=1)

        return embedding, hnode, hatom
Example #9
    def embed_tree(self, tree_tensors, hatom):
        """ Prepare the embeddings for tree message passing.
        Incoprate the learned embeddings for atoms into the tree node embeddings
        
        Args:
            tree_tensors: The data of junction tree
            hatom: The learned atom embeddings through graph message passing 
        
        """

        fnode, fmess, agraph, bgraph, cgraph, dgraph, _ = tree_tensors
        finput = self.embedding(fnode)

        # combine atom embeddings with node embeddings
        hnode = index_select_ND(hatom, 0, dgraph).sum(dim=1)
        hnode = self.W_i(torch.cat([finput, hnode], dim=-1))

        # combine atom embeddings with edge embeddings
        hmess1 = hnode.index_select(index=fmess[:, 0], dim=0)
        hmess2 = index_select_ND(hatom, 0, cgraph).sum(dim=1)
        hmess = self.W_g(torch.cat([hmess1, hmess2], dim=-1))

        return hnode, hmess, agraph, bgraph
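The dgraph and cgraph gathers above pool atom embeddings into tree nodes and tree edges: each row lists the atom indices attached to one node (or edge), with 0 as padding, and the gathered rows are summed. The meaning of dgraph here is inferred from how it is used; a toy version of the pooling pattern:

import torch

hatom = torch.tensor([[0., 0.],     # row 0: padding atom, kept at zero
                      [1., 1.],
                      [2., 2.],
                      [3., 3.]])
dgraph = torch.tensor([[1, 2, 0],   # tree node 0 covers atoms 1 and 2
                       [3, 0, 0]])  # tree node 1 covers atom 3

pooled = hatom.index_select(0, dgraph.view(-1)).view(2, 3, 2).sum(dim=1)
print(pooled)  # tensor([[3., 3.], [3., 3.]])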
Example #10
    def forward(self, holder, depth):
        fnode = create_var(holder[0])
        fmess = create_var(holder[1])
        node_graph = create_var(holder[2])
        mess_graph = create_var(holder[3])
        scope = holder[4]

        fnode = self.embedding(fnode)
        x = index_select_ND(fnode, 0, fmess)
        h = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

        mask = torch.ones(h.size(0), 1)
        mask[0] = 0  #first vector is padding
        mask = create_var(mask)

        for _ in range(depth):
            h_nei = index_select_ND(h, 0, mess_graph)
            h = GRU(x, h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h = h * mask

        mess_nei = index_select_ND(h, 0, node_graph)
        node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        root_vecs = [node_vecs[st] for st, le in scope]
        return torch.stack(root_vecs, dim=0)
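The mask above (like the explicit hnode[0, :] * 0 in Example #8) pins the padding row back to zero after every update, so the zero entries used as padding in mess_graph and node_graph contribute nothing to the neighbor sums. In isolation:

import torch

h = torch.randn(4, 3)
mask = torch.ones(h.size(0), 1)
mask[0] = 0        # row 0 is the padding message
h = h * mask       # broadcasts across the feature dimension
print(h[0])        # tensor([0., 0., 0.])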
Example #11
    def mpn(self, hnode, hmess, agraph, bgraph, depth, W_m, W_n):
        """ Returns the node embeddings and message embeddings learned through message passing networks

        Args:
            hnode: initial node embeddings
            hmess: initial message embeddings
            agraph: message adjacency matrix for nodes. (`agraph[i, j] = k` means message k is an incoming message of node i; 0 is used for padding.)
            bgraph: message adjacency matrix for messages. (`bgraph[i, j] = k` means message k is an incoming message of message i; 0 is used for padding.)
            depth: depth of message passing
            W_m, W_n: functions used in message passing
            
        """
        messages = MPNN(hmess, bgraph, W_m, depth, self.hidden_size)
        mess_nei = index_select_ND(messages, 0, agraph)
        node_vecs = torch.cat((hnode, mess_nei.sum(dim=1)), dim=-1)
        node_vecs = W_n(node_vecs)
        return node_vecs, messages
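The MPNN helper called here is not shown in these examples. The sketch below is an assumption about its shape rather than the repository's implementation: depth rounds of gathering each message's incoming messages through bgraph, applying W_m, and re-activating, with row 0 kept as the zero/padding message (hidden_size is accepted only to mirror the call site).

import torch

def index_select_ND(source, dim, index):
    # same gather helper sketched after Example #1
    return source.index_select(dim, index.reshape(-1)).view(index.size() + source.size()[1:])

def MPNN(fmess, bgraph, W_m, depth, hidden_size):
    # fmess:  (num_messages, hidden_size) initial message embeddings
    # bgraph: (num_messages, MAX_NB) indices of incoming messages, 0 = padding
    assert fmess.size(1) == hidden_size
    mask = torch.ones(fmess.size(0), 1, device=fmess.device)
    mask[0] = 0
    messages = fmess
    for _ in range(depth - 1):
        nei_message = index_select_ND(messages, 0, bgraph).sum(dim=1)
        messages = torch.relu(fmess + W_m(nei_message)) * mask
    return messages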
Example #12
    def forward(self, h, x, mess_graph):
        mask = torch.ones(h.size(0), 1)
        mask[0] = 0 #first vector is padding
        mask = create_var(mask)
        for _ in range(self.depth):
            h_nei = index_select_ND(h, 0, mess_graph)
            sum_h = h_nei.sum(dim=1)
            z_input = torch.cat([x, sum_h], dim=1)
            z = torch.sigmoid(self.W_z(z_input))

            r_1 = self.W_r(x).view(-1, 1, self.hidden_size)
            r_2 = self.U_r(h_nei)
            r = torch.sigmoid(r_1 + r_2)
            
            gated_h = r * h_nei
            sum_gated_h = gated_h.sum(dim=1)
            h_input = torch.cat([x, sum_gated_h], dim=1)
            pre_h = torch.tanh(self.W_h(h_input))
            h = (1.0 - z) * sum_h + z * pre_h
            h = h * mask

        return h
Example #13
    def forward(self, atom_feature_matrix, bond_feature_matrix,
                atom_adjacency_graph, bond_adjacency_graph, scope, tree_mess):
        """
        Description: Implements the forward pass for encoding the candidate molecular subgraphs, for the graph decoding step. (Section 2.5)

        Args:
            atom_feature_matrix: torch.tensor (shape: num_atoms x ATOM_FEATURE_DIM)
                Matrix of atom features for all the atoms, over all the molecules, in the dataset.

            bond_feature_matrix: torch.tensor (shape: num_bonds x ATOM_FEATURE_DIM + BOND_FEATURE_DIM)
                Matrix of bond features for all the bonds, over all the molecules, in the dataset.

            atom_adjacency_graph: torch.tensor (shape: num_atoms x MAX_NUM_NEIGHBORS(=6))
                For every atom across the training dataset, this atom_graph gives the bond idxs of all the bonds
                in which it is present. An atom can at most be present in MAX_NUM_NEIGHBORS(= 6) bonds.

            bond_adjacency_graph: torch.tensor (shape: num_bonds x MAX_NUM_NEIGHBORS(=6))
                For every bond across the training dataset, this bond_graph gives the idxs of the incoming
                bond/tree messages (offset by the number of tree messages) that feed its message update.

            scope: List[Tuple(int, int)]
                List of tuples of (start_idx, num_atoms). Used to extract the atom features for a
                particular molecule in the dataset, from the atom_feature_matrix.

            tree_mess: torch.tensor (shape: num_tree_messages x hidden_size)
                Hidden message vectors from the tree decoder, prepended to the bond messages during message passing.

        Returns:
            mol_vecs: torch.tensor (shape: num_candidate_subgraphs x hidden_size)
                The encoding of all the candidate subgraphs for scoring purposes. (Section 2.5)

        """
        # create PyTorch Variables
        atom_feature_matrix = create_var(atom_feature_matrix)
        bond_feature_matrix = create_var(bond_feature_matrix)
        atom_adjacency_graph = create_var(atom_adjacency_graph)
        bond_adjacency_graph = create_var(bond_adjacency_graph)

        static_messages = self.W_i(bond_feature_matrix)

        # apply ReLU activation for timestep, t = 0
        graph_message = nn.ReLU()(static_messages)

        # implement message passing for timesteps, t = 1 to T (depth)
        for timestep in range(self.depth - 1):

            message = torch.cat([tree_mess, graph_message], dim=0)

            # obtain messages from all the "inward edges"
            neighbor_message_vecs = index_select_ND(message, 0,
                                                    bond_adjacency_graph)

            # sum up all the "inward edge" message vectors
            neighbor_message_vecs_sum = neighbor_message_vecs.sum(dim=1)

            # multiply with the weight matrix for the hidden layer
            neighbor_message = self.W_h(neighbor_message_vecs_sum)

            # message at timestep t + 1
            graph_message = nn.ReLU()(static_messages + neighbor_message)

        # neighbor message vectors for each node from the message matrix
        message = torch.cat([tree_mess, graph_message], dim=0)

        # neighbor message for each atom
        neighbor_message_vecs = index_select_ND(message, 0,
                                                atom_adjacency_graph)

        # sum up the neighbor messages for each atom
        neighbor_message_atom_matrix = neighbor_message_vecs.sum(dim=1)

        # concatenate atom feature vector and neighbor hidden message vector
        atom_input_matrix = torch.cat(
            [atom_feature_matrix, neighbor_message_atom_matrix], dim=1)

        atom_hidden_layer_synaptic_input = nn.ReLU()(
            self.W_o(atom_input_matrix))

        # list to store the corresponding molecule vectors for each molecule
        mol_vecs = []
        for start_idx, num_atoms in scope:
            # equivalent to: atom_hidden_layer_synaptic_input.narrow(0, start_idx, num_atoms).sum(dim=0) / num_atoms
            mol_vec = atom_hidden_layer_synaptic_input[start_idx:start_idx + num_atoms].mean(dim=0)
            mol_vecs.append(mol_vec)

        mol_vecs = torch.stack(mol_vecs, dim=0)
        return mol_vecs
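The commented-out narrow/sum/divide line and the slice-plus-mean line above are two equivalent ways to average each molecule's atom vectors (Example #1 uses the former). A quick check of the equivalence:

import torch

atom_hiddens = torch.arange(12.).view(6, 2)   # 6 atom vectors, hidden size 2
scope = [(0, 2), (2, 4)]                      # (start_idx, num_atoms) per molecule

a = torch.stack([atom_hiddens.narrow(0, st, le).sum(dim=0) / le for st, le in scope])
b = torch.stack([atom_hiddens[st:st + le].mean(dim=0) for st, le in scope])
assert torch.allclose(a, b)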
Example #14
    def forward(self, node_wid_list, edge_node_idx_list, node_message_graph, mess_adjacency_graph, scope):
        """
        Args:
            node_wid_list: torch.LongTensor (shape: num_nodes)
                The list of wids i.e. idx of the corresponding cluster vocabulary item, for each node.

            edge_node_idx_list: torch.LongTensor() (shape: num_edges)
                The list of idx of the initial node of each edge.

            node_message_graph: torch.LongTensor (shape: num_nodes x MAX_NUM_NEIGHBORS)
                For each node, the list of idxs of all "inward hidden edge message vectors" for purposes of node feature aggregation.

            mess_adjacency_graph: torch.LongTensor (shape: num_edges x MAX_NUM_NEIGHBORS)
                For each edge, the list of idxs of all "inward hidden edge message vectors", for purposes of message aggregation in the GRU.

            scope: List[Tuple(int, int)]
                The list of tuples (start_idx, num_nodes), used to segregate the node features of a particular junction-tree.

        Returns:
            tree_vecs: torch.tensor (shape: batch_size x hidden_size)
                The hidden vectors for the root nodes, of all the junction-trees, across the entire dataset.

        """
        # create PyTorch variables
        node_wid_list = create_var(node_wid_list)
        edge_node_idx_list = create_var(edge_node_idx_list)
        node_message_graph = create_var(node_message_graph)
        mess_adjacency_graph = create_var(mess_adjacency_graph)

        # hidden vectors for all the edges
        messages = create_var(torch.zeros(mess_adjacency_graph.size(0), self.hidden_size))

        # obtain node feature embedding
        node_feature_embeddings = self.embedding(node_wid_list)

        # for each edge obtain the embedding for the initial node
        initial_node_features = index_select_ND(node_feature_embeddings, 0, edge_node_idx_list)

        # obtain the hidden vectors for all the edges using GRU
        messages = self.GRU(messages, initial_node_features, mess_adjacency_graph)

        # for each node, obtain all the neighboring message vectors
        node_neighbor_mess_vecs = index_select_ND(messages, 0, node_message_graph)

        # for each node, sum up all the neighboring message vectors
        node_neighbor_mess_vecs_sum = node_neighbor_mess_vecs.sum(dim=1)

        # for each node, concatenate the node embedding feature and the sum of hidden neighbor message vectors
        node_vecs_synaptic_input = torch.cat([node_feature_embeddings, node_neighbor_mess_vecs_sum], dim=-1)

        # apply the neural network layer
        node_vecs = self.outputNN(node_vecs_synaptic_input)

        # list to store feature vectors of the root node, for all the junction-trees, across the entire dataset
        root_vecs = []

        for start_idx, _ in scope:
            # root node is the first node in the list of nodes of a junction-tree by design
            root_vec = node_vecs[start_idx]
            root_vecs.append(root_vec)

        # stack the root tensors to form a 2-D tensor
        tree_vecs = torch.stack(root_vecs, dim=0)
        return tree_vecs, messages
Example #15
    def forward(self, atom_layer_input, bond_layer_input, atom_adjacency_graph, atom_bond_adjacency_graph, bond_atom_adjacency_graph):
        """
        Args:
            atom_layer_input: torch.tensor (shape: num_atoms x atom_feature_dim)
                The matrix containing feature vectors, for all the atoms, across the entire batch.
                * atom_feature_dim = len(ELEM_LIST) + 6 + 5 + 4 + 1

            bond_layer_input: torch.tensor (shape: num_bonds x bond_feature_dim)
                The matrix containing feature vectors, for all the bonds, across the entire batch.
                * bond_feature_dim = 5 + 6

            atom_adjacency_graph: torch.tensor (shape: num_atoms x MAX_NUM_NEIGHBORS(=6))
                For each atom, across the entire batch, the idxs of neighboring atoms.

            atom_bond_adjacency_graph: torch.tensor(shape: num_atoms x MAX_NUM_NEIGHBORS(=6))
                For each atom, across the entire batch, the idxs of all the bonds, in which it is the initial atom.

            bond_atom_adjacency_graph: torch.tensor (shape: num_bonds x 2)
                For each bond, across the entire batch, the idxs of the 2 atoms of which the bond is composed.
        """
        # implement edge gate computation
        edge_gate_x = torch.index_select(input=atom_layer_input, dim=0, index=bond_atom_adjacency_graph[:, 0])
        edge_gate_y = torch.index_select(input=atom_layer_input, dim=0, index=bond_atom_adjacency_graph[:, 1])

        assert bond_layer_input.shape[0] == edge_gate_x.shape[0]
        assert bond_layer_input.shape[0] == edge_gate_y.shape[0]

        edge_gate_synaptic_input = self.A(bond_layer_input) + self.B(edge_gate_x) + self.C(edge_gate_y)

        # apply sigmoid activation for computing edge gates
        edge_gates = torch.sigmoid(edge_gate_synaptic_input)

        # implement batch normalization for bond/edge features
        # edge_gate_synaptic_input = self.bn_bond_features(edge_gate_synaptic_input)

        # apply ReLU activation for computing new bond features
        # add residual
        # if self.bond_feature_dim == 0:
        #     bond_layer_output = F.relu(edge_gate_synaptic_input) + bond_layer_input
        # else:
        #     bond_layer_output = F.relu(edge_gate_synaptic_input)

        bond_layer_output = F.relu(edge_gate_synaptic_input)

        # implement node features computation

        # for each atom, aggregate the features vectors of neighboring atoms
        atom_neighbor_features_tensor = self.V(index_select_ND(atom_layer_input, 0, atom_adjacency_graph))

        # for each atom, get the edge gates for the corresponding neighbor atom features
        atom_neighbor_edge_gates_tensor = index_select_ND(edge_gates, 0, atom_bond_adjacency_graph)

        assert(atom_neighbor_edge_gates_tensor.shape == atom_neighbor_features_tensor.shape)

        # for each atom, multiply the edge gates with corresponding neighbor atom feature vectors
        atom_neighbor_message_tensor = atom_neighbor_edge_gates_tensor * atom_neighbor_features_tensor

        atom_neighbor_message_sum = atom_neighbor_message_tensor.sum(dim=1)

        assert(atom_neighbor_message_sum.shape[0] == atom_layer_input.shape[0])

        atom_features_synaptic_input = self.U(atom_layer_input) + atom_neighbor_message_sum

        # implement batch normalization
        # atom_features_synaptic_input = self.bn_atom_features(atom_features_synaptic_input)

        # apply ReLU activation for computing new atom features
        # add residual
        # if self.atom_feature_dim == 0:
        #     atom_layer_output = F.relu(atom_features_synaptic_input) + atom_layer_input
        # else:
        #     atom_layer_output = F.relu(atom_features_synaptic_input)

        atom_layer_output = F.relu(atom_features_synaptic_input)

        return atom_layer_output, bond_layer_output


Example #16
    def forward(self, cand_batch, tree_mess):
        fatoms, fbonds = [], []
        in_bonds, all_bonds = [], []
        mess_dict, all_mess = {}, [create_var(torch.zeros(self.hidden_size))]  # Ensure index 0 is vec(0)
        total_atoms = 0
        scope = []

        for e, vec in tree_mess.items():
            mess_dict[e] = len(all_mess)
            all_mess.append(vec)

        for mol, all_nodes, ctr_node in cand_batch:
            n_atoms = mol.GetNumAtoms()
            ctr_bid = ctr_node.idx

            for atom in mol.GetAtoms():
                fatoms.append(atom_features(atom))
                in_bonds.append([])

            for bond in mol.GetBonds():
                a1 = bond.GetBeginAtom()
                a2 = bond.GetEndAtom()
                x = a1.GetIdx() + total_atoms
                y = a2.GetIdx() + total_atoms
                #Here x_nid,y_nid could be 0
                x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
                x_bid = all_nodes[x_nid - 1].idx if x_nid > 0 else -1
                y_bid = all_nodes[y_nid - 1].idx if y_nid > 0 else -1

                bfeature = bond_features(bond)

                b = len(all_mess) + len(all_bonds)  # bond idx offset by len(all_mess)
                all_bonds.append((x, y))
                fbonds.append(torch.cat([fatoms[x], bfeature], 0))
                in_bonds[y].append(b)

                b = len(all_mess) + len(all_bonds)
                all_bonds.append((y, x))
                fbonds.append(torch.cat([fatoms[y], bfeature], 0))
                in_bonds[x].append(b)

                if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
                    if (x_bid, y_bid) in mess_dict:
                        mess_idx = mess_dict[(x_bid, y_bid)]
                        in_bonds[y].append(mess_idx)
                    if (y_bid, x_bid) in mess_dict:
                        mess_idx = mess_dict[(y_bid, x_bid)]
                        in_bonds[x].append(mess_idx)

            scope.append((total_atoms, n_atoms))
            total_atoms += n_atoms

        total_bonds = len(all_bonds)
        total_mess = len(all_mess)
        fatoms = torch.stack(fatoms, 0)
        fbonds = torch.stack(fbonds, 0)
        agraph = torch.zeros(total_atoms, MAX_NB).long()
        bgraph = torch.zeros(total_bonds, MAX_NB).long()
        tree_message = torch.stack(all_mess, dim=0)

        for a in range(total_atoms):
            for i, b in enumerate(in_bonds[a]):
                agraph[a, i] = b

        for b1 in range(total_bonds):
            x, y = all_bonds[b1]
            for i, b2 in enumerate(in_bonds[x]):  # b2 is offset by len(all_mess)
                if b2 < total_mess or all_bonds[b2 - total_mess][0] != y:
                    bgraph[b1, i] = b2

        fatoms = create_var(fatoms)
        fbonds = create_var(fbonds)
        agraph = create_var(agraph)
        bgraph = create_var(bgraph)

        binput = self.W_i(fbonds)
        graph_message = nn.ReLU()(binput)

        for _ in range(self.depth - 1):
            message = torch.cat([tree_message, graph_message], dim=0)
            nei_message = index_select_ND(message, 0, bgraph)
            nei_message = nei_message.sum(dim=1)
            nei_message = self.W_h(nei_message)
            graph_message = nn.ReLU()(binput + nei_message)

        message = torch.cat([tree_message, graph_message], dim=0)
        nei_message = index_select_ND(message, 0, agraph)
        nei_message = nei_message.sum(dim=1)
        ainput = torch.cat([fatoms, nei_message], dim=1)
        atom_hiddens = nn.ReLU()(self.W_o(ainput))

        mol_vecs = []
        for st, le in scope:
            mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
            mol_vecs.append(mol_vec)

        mol_vecs = torch.stack(mol_vecs, dim=0)
        return mol_vecs
Example #17
    def mpn(self, hnode, hmess, agraph, bgraph, depth, W_m, W_n):
        messages = MPNN(hmess, bgraph, W_m, depth, self.hidden_size)
        mess_nei = index_select_ND(messages, 0, agraph)
        node_vecs = torch.cat((hnode, mess_nei.sum(dim=1)), dim=-1)
        node_vecs = W_n(node_vecs)
        return node_vecs, messages
Example #18
    def encode_node(self, tree_tensors, hatom, node_idx):
        hnode, hmess, agraph, bgraph = self.embed_tree(tree_tensors, hatom)
        hnode = index_select_ND(hnode, 0, node_idx)
        agraph = index_select_ND(agraph, 0, node_idx)
        hnode, _ = self.mpn(hnode, hmess, agraph, bgraph, self.depthT, self.W_t, self.outputNode)
        return hnode