def forward(self, mol_graph: BatchMolGraph): f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components( ) if self.atom_messages: a2a = mol_graph.get_a2a() input = self.W_i(f_atoms) else: a2a = None input = self.W_i(f_bonds) message = self.act_func(input) return (input, message, f_atoms, f_bonds, a2a, a2b, b2a, b2revb, a_scope)
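# Every forward in this file gathers neighbor messages with an index_select_ND helper.
# A minimal sketch consistent with how it is called here (gather rows of a 2-D source
# tensor according to a 2-D index tensor, keeping row 0 as the padding row):
import torch

def index_select_ND(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Gather rows of `source` for every entry of `index`.

    source: (num_items, hidden) message/feature matrix whose row 0 is padding.
    index:  (num_items, max_num_bonds) integer indices into `source` (0 = no neighbor).
    Returns a (num_items, max_num_bonds, hidden) tensor of gathered neighbors.
    """
    index_size = index.size()                                   # (num_items, max_num_bonds)
    suffix_dim = source.size()[1:]                              # (hidden,)
    final_size = index_size + suffix_dim                        # (num_items, max_num_bonds, hidden)
    target = source.index_select(dim=0, index=index.view(-1))   # flat gather
    return target.view(final_size)                              # restore the neighbor dimension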
def forward(self, mol_graph: BatchMolGraph, features_batch: List[np.ndarray] = None) -> torch.FloatTensor: """ Encodes a batch of molecular graphs. :param mol_graph: A BatchMolGraph representing a batch of molecular graphs. :param features_batch: A list of ndarrays containing additional features. :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule. """ if self.use_input_features: features_batch = torch.from_numpy( np.stack(features_batch)).float().to(self.device) if self.features_only: return features_batch f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components( atom_messages=self.atom_messages) f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.to( self.device), f_bonds.to(self.device), a2b.to(self.device), b2a.to( self.device), b2revb.to(self.device) if self.atom_messages: a2a = mol_graph.get_a2a().to(self.device) # Input if self.atom_messages: input = self.W_i(f_atoms) # num_atoms x hidden_size else: input = self.W_i(f_bonds) # num_bonds x hidden_size message = self.act_func(input) # num_bonds x hidden_size # Message passing for depth in range(self.depth - 1): if self.undirected: message = (message + message[b2revb]) / 2 if self.atom_messages: nei_a_message = index_select_ND( message, a2a) # num_atoms x max_num_bonds x hidden nei_f_bonds = index_select_ND( f_bonds, a2b) # num_atoms x max_num_bonds x bond_fdim nei_message = torch.cat( (nei_a_message, nei_f_bonds), dim=2) # num_atoms x max_num_bonds x hidden + bond_fdim message = nei_message.sum( dim=1) # num_atoms x hidden + bond_fdim else: # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1) # message a_message = sum(nei_a_message) rev_message nei_a_message = index_select_ND( message, a2b) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden rev_message = message[b2revb] # num_bonds x hidden message = a_message[b2a] - rev_message # num_bonds x hidden message = self.W_h(message) message = self.act_func(input + message) # num_bonds x hidden_size message = self.dropout_layer(message) # num_bonds x hidden a2x = a2a if self.atom_messages else a2b nei_a_message = index_select_ND( message, a2x) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden a_input = torch.cat([f_atoms, a_message], dim=1) # num_atoms x (atom_fdim + hidden) atom_hiddens = self.act_func(self.W_o(a_input)) # num_atoms x hidden atom_hiddens = self.dropout_layer(atom_hiddens) # num_atoms x hidden # Readout mol_vecs = [] for i, (a_start, a_size) in enumerate(a_scope): if a_size == 0: mol_vecs.append(self.cached_zero_vector) else: cur_hiddens = atom_hiddens.narrow(0, a_start, a_size) mol_vec = cur_hiddens # (num_atoms, hidden_size) mol_vec = mol_vec.sum(dim=0) / a_size mol_vecs.append(mol_vec) mol_vecs = torch.stack(mol_vecs, dim=0) # (num_molecules, hidden_size) if self.use_input_features: features_batch = features_batch.to(mol_vecs) if len(features_batch.shape) == 1: features_batch = features_batch.view( [1, features_batch.shape[0]]) mol_vecs = torch.cat([mol_vecs, features_batch], dim=1) # (num_molecules, hidden_size) return mol_vecs # num_molecules x hidden
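# The else-branch above implements the directed-message update
# m(a1 -> a2) = [sum_{a0 in nei(a1)} m(a0 -> a1)] - m(a2 -> a1).
# A toy check of that indexing on a 3-atom chain A1-A2-A3, using the index_select_ND
# sketch above; indices follow the BatchMolGraph convention that row 0 is padding.
import torch

hidden = 4
# Directed bonds: 1: A1->A2, 2: A2->A1, 3: A2->A3, 4: A3->A2 (bond 0 is padding).
a2b = torch.tensor([[0, 0], [2, 0], [1, 4], [3, 0]])   # incoming bonds per atom
b2a = torch.tensor([0, 1, 2, 2, 3])                    # source atom of each bond
b2revb = torch.tensor([0, 2, 1, 4, 3])                 # reverse bond of each bond

message = torch.randn(5, hidden)
message[0] = 0.0                                       # padding row must stay zero

nei_a_message = index_select_ND(message, a2b)          # (num_atoms, max_num_bonds, hidden)
a_message = nei_a_message.sum(dim=1)                   # total message arriving at each atom
new_message = a_message[b2a] - message[b2revb]         # subtract the reverse direction

assert torch.allclose(new_message[3], message[1])           # A2->A3 forwards only what came in on A1->A2
assert torch.allclose(new_message[1], torch.zeros(hidden))  # A1 has no other neighbors to forward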
def forward( self, mol_graph: BatchMolGraph, atom_descriptors_batch: List[np.ndarray] = None ) -> torch.FloatTensor: """ Encodes a batch of molecular graphs. :param mol_graph: A :class:`~chemprop.features.featurization.BatchMolGraph` representing a batch of molecular graphs. :param atom_descriptors_batch: A list of numpy arrays containing additional atomic descriptors :return: A PyTorch tensor of shape :code:`(num_molecules, hidden_size)` containing the encoding of each molecule. """ if atom_descriptors_batch is not None: atom_descriptors_batch = [ np.zeros([1, atom_descriptors_batch[0].shape[1]]) ] + atom_descriptors_batch # padding the first with 0 to match the atom_hiddens atom_descriptors_batch = (torch.from_numpy( np.concatenate(atom_descriptors_batch, axis=0)).float().to(self.device)) f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components( atom_messages=self.atom_messages) f_atoms, f_bonds, a2b, b2a, b2revb = ( f_atoms.to(self.device), f_bonds.to(self.device), a2b.to(self.device), b2a.to(self.device), b2revb.to(self.device), ) if self.atom_messages: a2a = mol_graph.get_a2a().to(self.device) # Input if self.atom_messages: input = self.W_i(f_atoms) # num_atoms x hidden_size else: input = self.W_i(f_bonds) # num_bonds x hidden_size message = self.act_func(input) # num_bonds x hidden_size # Message passing for depth in range(self.depth - 1): if self.undirected: message = (message + message[b2revb]) / 2 if self.atom_messages: nei_a_message = index_select_ND( message, a2a) # num_atoms x max_num_bonds x hidden nei_f_bonds = index_select_ND( f_bonds, a2b) # num_atoms x max_num_bonds x bond_fdim nei_message = torch.cat( (nei_a_message, nei_f_bonds), dim=2) # num_atoms x max_num_bonds x hidden + bond_fdim message = nei_message.sum( dim=1) # num_atoms x hidden + bond_fdim else: # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1) # message a_message = sum(nei_a_message) rev_message nei_a_message = index_select_ND( message, a2b) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden rev_message = message[b2revb] # num_bonds x hidden message = a_message[b2a] - rev_message # num_bonds x hidden message = self.W_h(message) message = self.act_func(input + message) # num_bonds x hidden_size message = self.dropout_layer(message) # num_bonds x hidden a2x = a2a if self.atom_messages else a2b nei_a_message = index_select_ND( message, a2x) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden a_input = torch.cat([f_atoms, a_message], dim=1) # num_atoms x (atom_fdim + hidden) atom_hiddens = self.act_func(self.W_o(a_input)) # num_atoms x hidden atom_hiddens = self.dropout_layer(atom_hiddens) # num_atoms x hidden # concatenate the atom descriptors if atom_descriptors_batch is not None: if len(atom_hiddens) != len(atom_descriptors_batch): raise ValueError( f"The number of atoms is different from the length of the extra atom features" ) atom_hiddens = torch.cat( [atom_hiddens, atom_descriptors_batch], dim=1) # num_atoms x (hidden + descriptor size) atom_hiddens = self.atom_descriptors_layer( atom_hiddens) # num_atoms x (hidden + descriptor size) atom_hiddens = self.dropout_layer( atom_hiddens) # num_atoms x (hidden + descriptor size) # Readout mol_vecs = [] for i, (a_start, a_size) in enumerate(a_scope): if a_size == 0: mol_vecs.append(self.cached_zero_vector) else: cur_hiddens = atom_hiddens.narrow(0, a_start, a_size) mol_vec = cur_hiddens # (num_atoms, hidden_size) if self.aggregation == 
"mean": mol_vec = mol_vec.sum(dim=0) / a_size elif self.aggregation == "sum": mol_vec = mol_vec.sum(dim=0) elif self.aggregation == "norm": mol_vec = mol_vec.sum(dim=0) / self.aggregation_norm mol_vecs.append(mol_vec) mol_vecs = torch.stack(mol_vecs, dim=0) # (num_molecules, hidden_size) return mol_vecs # num_molecules x hidden
def forward(self, mol_graph: BatchMolGraph, features_batch: List[np.ndarray] = None) -> torch.FloatTensor: """ Encodes a batch of molecular graphs. :param mol_graph: A BatchMolGraph representing a batch of molecular graphs. :param features_batch: A list of ndarrays containing additional features. :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule. """ if self.use_input_features: features_batch = torch.from_numpy(np.stack(features_batch)).float() if self.args.cuda: features_batch = features_batch.cuda() if self.features_only: return features_batch f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, smiles_batch = mol_graph.get_components( ) print('f_atoms', f_atoms.shape) print(f_atoms[0, :]) print(f_atoms[1, :]) print('f_bonds', f_bonds.shape) print(f_bonds[0, :]) print('a2b', a2b) print('b2a', b2a) print('b2revb', b2revb) print('a_scope', a_scope) print('b_scope', b_scope) if self.atom_messages: a2a = mol_graph.get_a2a() if self.args.cuda or next(self.parameters()).is_cuda: f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda( ), a2b.cuda(), b2a.cuda(), b2revb.cuda() if self.atom_messages: a2a = a2a.cuda() # Input if self.atom_messages: input = self.W_i(f_atoms) # num_atoms x hidden_size else: input = self.W_i(f_bonds) # num_bonds x hidden_size print('smiles_batch', smiles_batch) print('f_bonds', f_bonds.shape) print('\n\n\n\n') # print('input', input.shape) message = self.act_func(input) # num_bonds x hidden_size # print('before', message[0, :2]) # message[0, 0] = message[0, 0] + 0.001 # print('after', message[0, :2]) # print('message', message, message.shape) # print('\n\n') # Message passing for depth in range(self.depth - 1): # GCNN # print('depth', depth) if self.undirected: message = (message + message[b2revb]) / 2 if self.atom_messages: nei_a_message = index_select_ND( message, a2a) # num_atoms x max_num_bonds x hidden nei_f_bonds = index_select_ND( f_bonds, a2b) # num_atoms x max_num_bonds x bond_fdim nei_message = torch.cat( (nei_a_message, nei_f_bonds), dim=2) # num_atoms x max_num_bonds x hidden + bond_fdim message = nei_message.sum( dim=1) # num_atoms x hidden + bond_fdim else: # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1) # message a_message = sum(nei_a_message) rev_message nei_a_message = index_select_ND( message, a2b) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden rev_message = message[b2revb] # num_bonds x hidden message = a_message[b2a] - rev_message # num_bonds x hidden message = self.W_h(message) message = self.act_func(input + message) # num_bonds x hidden_size message = self.dropout_layer(message) # num_bonds x hidden a2x = a2a if self.atom_messages else a2b nei_a_message = index_select_ND( message, a2x) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden a_input = torch.cat( [f_atoms, a_message], dim=1) # num_atoms x (atom_fdim + hidden), e.g. 133 + 1000 atom_hiddens = self.act_func(self.W_o(a_input)) # num_atoms x hidden atom_hiddens = self.dropout_layer(atom_hiddens) # num_atoms x hidden # Readout mol_vecs = [] for i, (a_start, a_size) in enumerate(a_scope): if a_size == 0: mol_vecs.append(self.cached_zero_vector) # MPNEncoder.output_x.append(self.cached_zero_vector) else: cur_hiddens = atom_hiddens.narrow(0, a_start, a_size) mol_vec = cur_hiddens # (num_atoms, hidden_size) mol_vec = mol_vec.sum(dim=0) / a_size # sum of all atom vectors in a molecule / number of atoms mol_vecs.append(mol_vec) #
MPNEncoder.output_x.append(mol_vec) ##################################################### #if len(mol_vecs) < 50: # nnnnvecs = torch.stack(MPNEncoder.output_x, dim=0) # with open('hidden_vector.pkl', 'wb') as f: # pickle.dump(nnnnvecs, f) # print('mol_vecs: ', nnnnvecs) # print('mol_vecs_size: ', nnnnvecs.shape) ##################################################### mol_vecs = torch.stack(mol_vecs, dim=0) # (num_molecules, hidden_size) if self.use_input_features: features_batch = features_batch.to(mol_vecs) if len(features_batch.shape) == 1: features_batch = features_batch.view( [1, features_batch.shape[0]]) mol_vecs = torch.cat([mol_vecs, features_batch], dim=1) # (num_molecules, hidden_size) return mol_vecs # num_molecules x hidden
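# The commented-out block above appears to dump the accumulated molecule vectors for offline
# inspection. A minimal standalone sketch of that idea; the file name follows the comment and
# the helper itself is not part of the active code path.
import pickle
import torch

def dump_hidden_vectors(mol_vecs: torch.Tensor, path: str = "hidden_vector.pkl") -> None:
    """Save a detached CPU copy of the (num_molecules, hidden_size) encodings."""
    with open(path, "wb") as f:
        pickle.dump(mol_vecs.detach().cpu(), f)

# Usage: call dump_hidden_vectors(mol_vecs) after mol_vecs = torch.stack(mol_vecs, dim=0).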
def forward(self, mol_graph: BatchMolGraph, features_batch: List[np.ndarray] = None) -> torch.FloatTensor: """ Encodes a batch of molecular graphs. :param mol_graph: A BatchMolGraph representing a batch of molecular graphs. :param features_batch: A list of ndarrays containing additional features. :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule. """ if self.use_input_features: features_batch = torch.from_numpy(np.stack(features_batch)).float() if self.args.cuda: features_batch = features_batch.cuda() if self.features_only: return features_batch f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components( ) if self.atom_messages: a2a = mol_graph.get_a2a() if self.args.cuda or next(self.parameters()).is_cuda: f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda( ), a2b.cuda(), b2a.cuda(), b2revb.cuda() if self.atom_messages: a2a = a2a.cuda() # Input if self.atom_messages: #print('atom_messages') input = self.W_i(f_atoms) # num_atoms x hidden_size else: input = self.W_i(f_bonds) # num_bonds x hidden_size message = self.act_func(input) # num_bonds x hidden_size save_depth = [] # wei, update for each step # wei, save the information of depth=0 (atomic features) padding = torch.zeros( (f_atoms.size()[0], f_atoms.size()[1] + self.hidden_size)) if self.args.cuda: padding = padding.cuda() padding[:, :f_atoms.size()[1]] = f_atoms a_input = padding atom_hiddens = self.act_func(self.W_o(a_input)) atom_hiddens = self.dropout_layer(atom_hiddens) # num_atoms x hidden save_depth.append(atom_hiddens) #print('mol_graph:', mol_graph.smiles_batch) # Message passing for depth in range(self.depth - 1): ################ save information of one bond distance from the central atom ################ if depth == 0: a2x = a2a if self.atom_messages else a2b nei_a_message = index_select_ND( message, a2x) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden a_input = torch.cat([f_atoms, a_message], dim=1) # num_atoms x (atom_fdim + hidden) atom_hiddens = self.act_func( self.W_o(a_input)) # num_atoms x hidden atom_hiddens = self.dropout_layer( atom_hiddens) # num_atoms x hidden save_depth.append(atom_hiddens) ################ save information of one bond distance from the central atom ################ if self.undirected: message = (message + message[b2revb]) / 2 if self.atom_messages: nei_a_message = index_select_ND( message, a2a) # num_atoms x max_num_bonds x hidden nei_f_bonds = index_select_ND( f_bonds, a2b) # num_atoms x max_num_bonds x bond_fdim nei_message = torch.cat( (nei_a_message, nei_f_bonds), dim=2) # num_atoms x max_num_bonds x hidden + bond_fdim message = nei_message.sum( dim=1) # num_atoms x hidden + bond_fdim else: # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1) # message a_message = sum(nei_a_message) rev_message nei_a_message = index_select_ND( message, a2b) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden rev_message = message[b2revb] # num_bonds x hidden message = a_message[b2a] - rev_message # num_bonds x hidden message = self.W_h(message) message = self.act_func(input + message) # num_bonds x hidden_size message = self.dropout_layer(message) # num_bonds x hidden # wei, save depth a2x = a2a if self.atom_messages else a2b nei_a_message = index_select_ND( message, a2x) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden a_input = torch.cat([f_atoms, a_message], dim=1) # 
num_atoms x (atom_fdim + hidden) atom_hiddens = self.act_func( self.W_o(a_input)) # num_atoms x hidden atom_hiddens = self.dropout_layer( atom_hiddens) # num_atoms x hidden save_depth.append(atom_hiddens) # save information of each depth ############ origin ############ #print('message_after_m_passing:\n', message) #a2x = a2a if self.atom_messages else a2b #nei_a_message = index_select_ND(message, a2x) # num_atoms x max_num_bonds x hidden #a_message = nei_a_message.sum(dim=1) # num_atoms x hidden #a_input = torch.cat([f_atoms, a_message], dim=1) # num_atoms x (atom_fdim + hidden) #atom_hiddens = self.act_func(self.W_o(a_input)) # num_atoms x hidden #atom_hiddens = self.dropout_layer(atom_hiddens) # num_atoms x hidden ############ origin ############ # Readout atomic_vecs_d0 = [] atomic_vecs_d1 = [] atomic_vecs_d2 = [] atomic_vecs_final = [] mol_vecs = [] for i, (a_start, a_size) in enumerate(a_scope): if a_size == 0: mol_vecs.append(self.cached_zero_vector) else: cur_hiddens_d0 = save_depth[0].narrow(0, a_start, a_size) cur_hiddens_d1 = save_depth[1].narrow(0, a_start, a_size) cur_hiddens_d2 = save_depth[2].narrow(0, a_start, a_size) cur_hiddens_final = save_depth[-1].narrow(0, a_start, a_size) cur_hiddens_mol = save_depth[-1].narrow(0, a_start, a_size) atomic_vec_d0 = cur_hiddens_d0 atomic_vec_d0 = self.padding(atomic_vec_d0) # padding atomic_vec_d1 = cur_hiddens_d1 atomic_vec_d1 = self.padding(atomic_vec_d1) # padding atomic_vec_d2 = cur_hiddens_d2 atomic_vec_d2 = self.padding(atomic_vec_d2) # padding atomic_vec_final = cur_hiddens_final atomic_vec_final = self.padding(atomic_vec_final) # padding mol_vec = cur_hiddens_mol # (num_atoms, hidden_size) mol_vec = mol_vec.sum(dim=0) / a_size atomic_vecs_d0.append(atomic_vec_d0) atomic_vecs_d1.append(atomic_vec_d1) atomic_vecs_d2.append(atomic_vec_d2) atomic_vecs_final.append(atomic_vec_final) mol_vecs.append(mol_vec) atomic_vecs_d0 = torch.stack(atomic_vecs_d0, dim=0) atomic_vecs_d1 = torch.stack(atomic_vecs_d1, dim=0) atomic_vecs_d2 = torch.stack(atomic_vecs_d2, dim=0) atomic_vecs_final = torch.stack(atomic_vecs_final, dim=0) mol_vecs = torch.stack(mol_vecs, dim=0) # (num_molecules, hidden_size) if self.args.cuda: atomic_vecs_d0, atomic_vecs_d1, atomic_vecs_d2, atomic_vecs_final, mol_vecs = atomic_vecs_d0.cuda( ), atomic_vecs_d1.cuda(), atomic_vecs_d2.cuda( ), atomic_vecs_final.cuda(), mol_vecs.cuda() #overall_vecs=torch.cat((mol_vecs,mol_vec_molar),dim=0) if self.use_input_features: features_batch = features_batch.to(mol_vecs) if len(features_batch.shape) == 1: features_batch = features_batch.view( [1, features_batch.shape[0]]) mol_vecs = torch.cat( [mol_vecs, features_batch], dim=1) # (num_molecules, num_atoms, hidden_size) return atomic_vecs_d0, atomic_vecs_d1, atomic_vecs_d2, atomic_vecs_final, mol_vecs
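# The per-depth atomic matrices above are passed through self.padding(...) before torch.stack,
# but its definition is not shown here. One plausible reading, assuming it zero-pads each
# molecule's (num_atoms, hidden) matrix to a fixed maximum atom count so molecules of
# different sizes can be stacked (the name, signature, and max size below are hypothetical):
import torch

def pad_atom_matrix(atomic_vec: torch.Tensor, max_num_atoms: int = 100) -> torch.Tensor:
    """Hypothetical stand-in for self.padding; assumes num_atoms <= max_num_atoms."""
    num_atoms, hidden = atomic_vec.shape
    padded = atomic_vec.new_zeros(max_num_atoms, hidden)  # zero rows for missing atoms
    padded[:num_atoms] = atomic_vec
    return padded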
def forward(self, mol_graph: BatchMolGraph, features_batch: List[np.ndarray] = None) -> torch.FloatTensor: """ Encodes a batch of molecular graphs. :param mol_graph: A BatchMolGraph representing a batch of molecular graphs. :param features_batch: A list of ndarrays containing additional features. :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule. """ if self.use_input_features: features_batch = torch.from_numpy(np.stack(features_batch)).float() if self.args.cuda: features_batch = features_batch.cuda() if self.features_only: return features_batch f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components( ) a2a = mol_graph.get_a2a() if self.args.cuda or next(self.parameters()).is_cuda: f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda( ), a2b.cuda(), b2a.cuda(), b2revb.cuda() a2a = a2a.cuda() # Input input = self.W_i(f_atoms) # num_atoms x hidden_size message = self.act_func(input) # num_bonds x hidden_size # Message passing for depth in range(self.depth): nei_a_message = index_select_ND( message, a2a) # num_atoms x max_num_bonds x hidden nei_f_bonds = index_select_ND( f_bonds, a2b) # num_atoms x max_num_bonds x bond_fdim nei_message = torch.cat( (nei_a_message, nei_f_bonds), dim=2) # num_atoms x max_num_bonds x hidden + bond_fdim nei_message = self.W_h( nei_message) # num_atoms x max_num_bonds x hidden nei_message = self.act_func(nei_message) nei_message = self.dropout_layer(nei_message) message = nei_message.sum(dim=1) # num_atoms x hidden # Output step: a_input = torch.cat([f_atoms, message], dim=1) # num_atoms x (atom_fdim + hidden) atom_hiddens = self.act_func(self.W_o(a_input)) # num_atoms x hidden atom_hiddens = self.dropout_layer(atom_hiddens) # num_atoms x hidden # Readout mol_vecs = [] for i, (a_start, a_size) in enumerate(a_scope): if a_size == 0: mol_vecs.append(self.cached_zero_vector) else: cur_hiddens = atom_hiddens.narrow(0, a_start, a_size) mol_vec = cur_hiddens # (num_atoms, hidden_size) mol_vec = mol_vec.sum(dim=0) / a_size mol_vecs.append(mol_vec) mol_vecs = torch.stack(mol_vecs, dim=0) # (num_molecules, hidden_size) if self.use_input_features: features_batch = features_batch.to(mol_vecs) if len(features_batch.shape) == 1: features_batch = features_batch.view( [1, features_batch.shape[0]]) mol_vecs = torch.cat([mol_vecs, features_batch], dim=1) # (num_molecules, hidden_size) return mol_vecs # num_molecules x hidden
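# This variant always message-passes on atoms and therefore always needs a2a. A sketch of one
# way to derive the atom-to-neighbor-atom map from the bond maps, consistent with how a2a is
# indexed here (padding entries map back to the padding atom, so no extra mask is needed):
import torch

def build_a2a(a2b: torch.Tensor, b2a: torch.Tensor) -> torch.Tensor:
    """a2b: (num_atoms, max_num_bonds) incoming bonds per atom (0 = padding).
    b2a: (num_bonds,) source atom of each bond (entry 0 is the padding atom).
    Returns (num_atoms, max_num_bonds) neighbor atom indices, padded with 0."""
    return b2a[a2b]  # bond b into atom v starts at atom b2a[b], i.e. a neighbor of v

# Reusing the toy chain from the earlier sketch:
a2b = torch.tensor([[0, 0], [2, 0], [1, 4], [3, 0]])
b2a = torch.tensor([0, 1, 2, 2, 3])
assert build_a2a(a2b, b2a).tolist() == [[0, 0], [2, 0], [1, 3], [2, 0]]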
def forward( self, mol_graph: BatchMolGraph, features_batch: List[np.ndarray] = None, viz_dir: str = None ) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]: """ Encodes a batch of molecular graphs. :param mol_graph: A BatchMolGraph representing a batch of molecular graphs. :param features_batch: A list of ndarrays containing additional features. :param viz_dir: Directory in which to save visualized attention weights. :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule. """ if self.use_input_features: features_batch = torch.from_numpy(np.stack(features_batch)).float() if self.args.cuda: features_batch = features_batch.cuda() if self.features_only: return features_batch f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components( ) if self.atom_messages: a2a = mol_graph.get_a2a() if self.args.cuda or next(self.parameters()).is_cuda: f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda( ), a2b.cuda(), b2a.cuda(), b2revb.cuda() if self.atom_messages: a2a = a2a.cuda() if self.learn_virtual_edges: atom1_features, atom2_features = f_atoms[b2a], f_atoms[ b2a[b2revb]] # each num_bonds x atom_fdim ve_score = torch.sum( self.lve(atom1_features) * atom2_features, dim=1) + torch.sum( self.lve(atom2_features) * atom1_features, dim=1) is_ve_indicator_index = self.atom_fdim # in current featurization, the first bond feature is 1 or 0 for virtual or not virtual num_virtual = f_bonds[:, is_ve_indicator_index].sum() straight_through_mask = torch.ones( f_bonds.size(0)).cuda() + f_bonds[:, is_ve_indicator_index] * ( ve_score - ve_score.detach()) / num_virtual # normalize for grad norm straight_through_mask = straight_through_mask.unsqueeze(1).repeat( (1, self.hidden_size)) # num_bonds x hidden_size # Input if self.atom_messages: input = self.W_i(f_atoms) # num_atoms x hidden_size else: input = self.W_i(f_bonds) # num_bonds x hidden_size message = self.act_func(input) # num_bonds x hidden_size if self.message_attention: b2b = mol_graph.get_b2b( ) # Warning: this is O(n_atoms^3) when using virtual edges if next(self.parameters()).is_cuda: b2b = b2b.cuda() message_attention_mask = (b2b != 0).float() # num_bonds x max_num_bonds if self.global_attention: global_attention_mask = torch.zeros( mol_graph.n_bonds, mol_graph.n_bonds) # num_bonds x num_bonds for start, length in b_scope: for i in range(start, start + length): global_attention_mask[i, start:start + length] = 1 if next(self.parameters()).is_cuda: global_attention_mask = global_attention_mask.cuda() # Message passing for depth in range(self.depth - 1): if self.undirected: message = (message + message[b2revb]) / 2 if self.learn_virtual_edges: message = message * straight_through_mask if self.message_attention: # TODO: Parallelize attention heads nei_message = index_select_ND(message, b2b) message = message.unsqueeze(1).repeat( (1, nei_message.size(1), 1)) # num_bonds x maxnb x hidden attention_scores = [ (self.W_ma[i](nei_message) * message).sum(dim=2) for i in range(self.num_heads) ] # num_bonds x maxnb attention_scores = [ attention_scores[i] * message_attention_mask + (1 - message_attention_mask) * (-1e+20) for i in range(self.num_heads) ] # num_bonds x maxnb attention_weights = [ F.softmax(attention_scores[i], dim=1) for i in range(self.num_heads) ] # num_bonds x maxnb message_components = [ nei_message * attention_weights[i].unsqueeze(2).repeat( (1, 1, self.hidden_size)) for i in range(self.num_heads) ] # num_bonds x maxnb x hidden message_components = [ 
component.sum(dim=1) for component in message_components ] # num_bonds x hidden message = torch.cat(message_components, dim=1) # num_bonds x num_heads * hidden elif self.atom_messages: nei_a_message = index_select_ND( message, a2a) # num_atoms x max_num_bonds x hidden nei_f_bonds = index_select_ND( f_bonds, a2b) # num_atoms x max_num_bonds x bond_fdim nei_message = torch.cat( (nei_a_message, nei_f_bonds), dim=2) # num_atoms x max_num_bonds x hidden + bond_fdim message = nei_message.sum( dim=1) # num_atoms x hidden + bond_fdim else: # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1) # message a_message = sum(nei_a_message) rev_message nei_a_message = index_select_ND( message, a2b) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden rev_message = message[b2revb] # num_bonds x hidden message = a_message[b2a] - rev_message # num_bonds x hidden for lpm in range(self.layers_per_message - 1): message = self.W_h[lpm][depth](message) # num_bonds x hidden message = self.act_func(message) message = self.W_h[self.layers_per_message - 1][depth](message) if self.normalize_messages: message = message / message.norm(dim=1, keepdim=True) if self.master_node: # master_state = self.W_master_in(self.act_func(nei_message.sum(dim=0))) #try something like this to preserve invariance for master node # master_state = self.GRU_master(nei_message.unsqueeze(1)) # master_state = master_state[-1].squeeze(0) #this actually doesn't preserve order invariance anymore mol_vecs = [self.cached_zero_vector] for start, size in b_scope: if size == 0: continue mol_vec = message.narrow(0, start, size) mol_vec = mol_vec.sum(dim=0) / size mol_vecs += [mol_vec for _ in range(size)] master_state = self.act_func( self.W_master_in(torch.stack( mol_vecs, dim=0))) # num_bonds x hidden_size message = self.act_func( input + message + self.W_master_out(master_state)) # num_bonds x hidden_size else: message = self.act_func(input + message) # num_bonds x hidden_size if self.global_attention: attention_scores = torch.matmul( self.W_ga1(message), message.t()) # num_bonds x num_bonds attention_scores = attention_scores * global_attention_mask + ( 1 - global_attention_mask) * (-1e+20 ) # num_bonds x num_bonds attention_weights = F.softmax(attention_scores, dim=1) # num_bonds x num_bonds attention_hiddens = torch.matmul( attention_weights, message) # num_bonds x hidden_size attention_hiddens = self.act_func( self.W_ga2(attention_hiddens)) # num_bonds x hidden_size attention_hiddens = self.dropout_layer( attention_hiddens) # num_bonds x hidden_size message = message + attention_hiddens # num_bonds x hidden_size if viz_dir is not None: visualize_bond_attention(viz_dir, mol_graph, attention_weights, depth) if self.use_layer_norm: message = self.layer_norm(message) message = self.dropout_layer(message) # num_bonds x hidden if self.master_node and self.use_master_as_output: assert self.hidden_size == self.master_dim mol_vecs = [] for start, size in b_scope: if size == 0: mol_vecs.append(self.cached_zero_vector) else: mol_vecs.append(master_state[start]) return torch.stack(mol_vecs, dim=0) # Get atom hidden states from message hidden states if self.learn_virtual_edges: message = message * straight_through_mask a2x = a2a if self.atom_messages else a2b nei_a_message = index_select_ND( message, a2x) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden a_input = torch.cat([f_atoms, a_message], dim=1) # num_atoms x (atom_fdim + hidden) atom_hiddens = 
self.act_func(self.W_o(a_input)) # num_atoms x hidden atom_hiddens = self.dropout_layer(atom_hiddens) # num_atoms x hidden if self.deepset: atom_hiddens = self.W_s2s_a(atom_hiddens) atom_hiddens = self.act_func(atom_hiddens) atom_hiddens = self.W_s2s_b(atom_hiddens) if self.bert_pretraining: atom_preds = self.W_v(atom_hiddens)[ 1:] # num_atoms x vocab/output size (leave out atom padding) # Readout if self.set2set: # Set up sizes batch_size = len(a_scope) lengths = [length for _, length in a_scope] max_num_atoms = max(lengths) # Set up memory from atom features memory = torch.zeros( batch_size, max_num_atoms, self.hidden_size) # (batch_size, max_num_atoms, hidden_size) for i, (start, size) in enumerate(a_scope): memory[i, :size] = atom_hiddens.narrow(0, start, size) memory_transposed = memory.transpose( 2, 1) # (batch_size, hidden_size, max_num_atoms) # Create mask (1s for atoms, 0s for not atoms) mask = create_mask(lengths, cuda=next( self.parameters()).is_cuda) # (max_num_atoms, batch_size) mask = mask.t().unsqueeze(2) # (batch_size, max_num_atoms, 1) # Set up query query = torch.ones( 1, batch_size, self.hidden_size) # (1, batch_size, hidden_size) # Move to cuda if next(self.parameters()).is_cuda: memory, memory_transposed, query = memory.cuda( ), memory_transposed.cuda(), query.cuda() # Run RNN for _ in range(self.set2set_iters): # Compute attention weights over atoms in each molecule query = query.squeeze(0).unsqueeze( 2) # (batch_size, hidden_size, 1) dot = torch.bmm(memory, query) # (batch_size, max_num_atoms, 1) dot = dot * mask + (1 - mask) * ( -1e+20) # (batch_size, max_num_atoms, 1) attention = F.softmax(dot, dim=1) # (batch_size, max_num_atoms, 1) # Construct next input as attention over memory attended = torch.bmm(memory_transposed, attention) # (batch_size, hidden_size, 1) attended = attended.view( 1, batch_size, self.hidden_size) # (1, batch_size, hidden_size) # Run RNN for one step query, _ = self.set2set_rnn( attended) # (1, batch_size, hidden_size) # Final RNN output is the molecule encodings mol_vecs = query.squeeze(0) # (batch_size, hidden_size) else: mol_vecs = [] # TODO: Maybe do this in parallel with masking rather than looping for i, (a_start, a_size) in enumerate(a_scope): if a_size == 0: mol_vecs.append(self.cached_zero_vector) else: cur_hiddens = atom_hiddens.narrow(0, a_start, a_size) if self.attention: att_w = torch.matmul(self.W_a(cur_hiddens), cur_hiddens.t()) att_w = F.softmax(att_w, dim=1) att_hiddens = torch.matmul(att_w, cur_hiddens) att_hiddens = self.act_func(self.W_b(att_hiddens)) att_hiddens = self.dropout_layer(att_hiddens) mol_vec = (cur_hiddens + att_hiddens) if viz_dir is not None: visualize_atom_attention(viz_dir, mol_graph.smiles_batch[i], a_size, att_w) else: mol_vec = cur_hiddens # (num_atoms, hidden_size) mol_vec = mol_vec.sum(dim=0) / a_size mol_vecs.append(mol_vec) mol_vecs = torch.stack(mol_vecs, dim=0) # (num_molecules, hidden_size) if self.use_input_features: features_batch = features_batch.to(mol_vecs) if len(features_batch.shape) == 1: features_batch = features_batch.view( [1, features_batch.shape[0]]) mol_vecs = torch.cat([mol_vecs, features_batch], dim=1) # (num_molecules, hidden_size) if self.bert_pretraining: features_preds = self.W_f(mol_vecs) if hasattr(self, 'W_f') else None return {'features': features_preds, 'vocab': atom_preds} return mol_vecs # num_molecules x hidden
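# The set2set readout above relies on a create_mask(lengths, cuda) helper plus the
# dot * mask + (1 - mask) * (-1e20) trick to keep padded atom slots out of the softmax.
# A minimal sketch of that masking; the helper's exact signature is an assumption.
import torch
import torch.nn.functional as F

def create_mask(lengths, device=torch.device("cpu")) -> torch.Tensor:
    """Assumed equivalent: (max_num_atoms, batch_size) mask, 1 for real atoms, 0 for padding."""
    mask = torch.zeros(max(lengths), len(lengths), device=device)
    for i, length in enumerate(lengths):
        mask[:length, i] = 1.0
    return mask

lengths = [3, 2]                                    # atoms per molecule in the batch
mask = create_mask(lengths).t().unsqueeze(2)        # (batch_size, max_num_atoms, 1)
dot = torch.randn(len(lengths), max(lengths), 1)    # raw attention scores
dot = dot * mask + (1 - mask) * (-1e20)             # padded slots get very negative scores
attention = F.softmax(dot, dim=1)                   # so they receive ~0 attention weight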
def forward( self, mol_graph: BatchMolGraph, features_batch: List[np.ndarray] = None, viz_dir: str = None ) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]: """ Encodes a batch of molecular graphs. :param mol_graph: A BatchMolGraph representing a batch of molecular graphs. :param features_batch: A list of ndarrays containing additional features. :param viz_dir: Directory in which to save visualized attention weights. :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule. """ if self.use_input_features: features_batch = torch.from_numpy(np.stack(features_batch)).float() if self.args.cuda: features_batch = features_batch.cuda() if self.features_only: return features_batch f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, uid, targets = mol_graph.get_components( ) if self.atom_messages: a2a = mol_graph.get_a2a() if self.args.cuda or next(self.parameters()).is_cuda: f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda( ), a2b.cuda(), b2a.cuda(), b2revb.cuda() if self.atom_messages: a2a = a2a.cuda() if self.learn_virtual_edges: atom1_features, atom2_features = f_atoms[b2a], f_atoms[ b2a[b2revb]] # each num_bonds x atom_fdim ve_score = torch.sum( self.lve(atom1_features) * atom2_features, dim=1) + torch.sum( self.lve(atom2_features) * atom1_features, dim=1) is_ve_indicator_index = self.atom_fdim # in current featurization, the first bond feature is 1 or 0 for virtual or not virtual num_virtual = f_bonds[:, is_ve_indicator_index].sum() straight_through_mask = torch.ones( f_bonds.size(0)).cuda() + f_bonds[:, is_ve_indicator_index] * ( ve_score - ve_score.detach()) / num_virtual # normalize for grad norm straight_through_mask = straight_through_mask.unsqueeze(1).repeat( (1, self.hidden_size)) # num_bonds x hidden_size # Input if self.atom_messages: input = self.W_i(f_atoms) # num_atoms x hidden_size else: input = self.W_i(f_bonds) # num_bonds x hidden_size message = self.act_func(input) # num_bonds x hidden_size if self.message_attention: b2b = mol_graph.get_b2b( ) # Warning: this is O(n_atoms^3) when using virtual edges if next(self.parameters()).is_cuda: b2b = b2b.cuda() message_attention_mask = (b2b != 0).float() # num_bonds x max_num_bonds if self.global_attention: global_attention_mask = torch.zeros( mol_graph.n_bonds, mol_graph.n_bonds) # num_bonds x num_bonds for start, length in b_scope: for i in range(start, start + length): global_attention_mask[i, start:start + length] = 1 if next(self.parameters()).is_cuda: global_attention_mask = global_attention_mask.cuda() # Message passing for depth in range(self.depth - 1): if self.undirected: message = (message + message[b2revb]) / 2 if self.learn_virtual_edges: message = message * straight_through_mask if self.message_attention: # TODO: Parallelize attention heads nei_message = index_select_ND(message, b2b) message = message.unsqueeze(1).repeat( (1, nei_message.size(1), 1)) # num_bonds x maxnb x hidden attention_scores = [ (self.W_ma[i](nei_message) * message).sum(dim=2) for i in range(self.num_heads) ] # num_bonds x maxnb attention_scores = [ attention_scores[i] * message_attention_mask + (1 - message_attention_mask) * (-1e+20) for i in range(self.num_heads) ] # num_bonds x maxnb attention_weights = [ F.softmax(attention_scores[i], dim=1) for i in range(self.num_heads) ] # num_bonds x maxnb message_components = [ nei_message * attention_weights[i].unsqueeze(2).repeat( (1, 1, self.hidden_size)) for i in range(self.num_heads) ] # num_bonds x maxnb x hidden 
message_components = [ component.sum(dim=1) for component in message_components ] # num_bonds x hidden message = torch.cat(message_components, dim=1) # num_bonds x num_heads * hidden elif self.atom_messages: nei_a_message = index_select_ND( message, a2a) # num_atoms x max_num_bonds x hidden nei_f_bonds = index_select_ND( f_bonds, a2b) # num_atoms x max_num_bonds x bond_fdim nei_message = torch.cat( (nei_a_message, nei_f_bonds), dim=2) # num_atoms x max_num_bonds x hidden + bond_fdim message = nei_message.sum( dim=1) # num_atoms x hidden + bond_fdim else: # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1) # message a_message = sum(nei_a_message) rev_message nei_a_message = index_select_ND( message, a2b) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden rev_message = message[b2revb] # num_bonds x hidden message = a_message[b2a] - rev_message # num_bonds x hidden for lpm in range(self.layers_per_message - 1): message = self.W_h[lpm][depth](message) # num_bonds x hidden message = self.act_func(message) message = self.W_h[self.layers_per_message - 1][depth](message) if self.normalize_messages: message = message / message.norm(dim=1, keepdim=True) if self.master_node: # master_state = self.W_master_in(self.act_func(nei_message.sum(dim=0))) #try something like this to preserve invariance for master node # master_state = self.GRU_master(nei_message.unsqueeze(1)) # master_state = master_state[-1].squeeze(0) #this actually doesn't preserve order invariance anymore mol_vecs = [self.cached_zero_vector] for start, size in b_scope: if size == 0: continue mol_vec = message.narrow(0, start, size) mol_vec = mol_vec.sum(dim=0) / size mol_vecs += [mol_vec for _ in range(size)] master_state = self.act_func( self.W_master_in(torch.stack( mol_vecs, dim=0))) # num_bonds x hidden_size message = self.act_func( input + message + self.W_master_out(master_state)) # num_bonds x hidden_size else: message = self.act_func(input + message) # num_bonds x hidden_size if self.global_attention: attention_scores = torch.matmul( self.W_ga1(message), message.t()) # num_bonds x num_bonds attention_scores = attention_scores * global_attention_mask + ( 1 - global_attention_mask) * (-1e+20 ) # num_bonds x num_bonds attention_weights = F.softmax(attention_scores, dim=1) # num_bonds x num_bonds attention_hiddens = torch.matmul( attention_weights, message) # num_bonds x hidden_size attention_hiddens = self.act_func( self.W_ga2(attention_hiddens)) # num_bonds x hidden_size attention_hiddens = self.dropout_layer( attention_hiddens) # num_bonds x hidden_size message = message + attention_hiddens # num_bonds x hidden_size if viz_dir is not None: visualize_bond_attention(viz_dir, mol_graph, attention_weights, depth) # Bond attention during message passing if self.bond_attention: alphas = [ self.act_func(self.W_ap[j](message)) for j in range(self.attention_pooling_heads) ] # num_bonds x hidden_size alphas = [ self.V_ap[j](alphas[j]) for j in range(self.attention_pooling_heads) ] # num_bonds x 1 alphas = [ F.softmax(alphas[j], dim=0) for j in range(self.attention_pooling_heads) ] # num_bonds x 1 alphas = [ alphas[j].squeeze(1) for j in range(self.attention_pooling_heads) ] # num_bonds att_hiddens = [ torch.mul(alphas[j], message.t()).t() for j in range(self.attention_pooling_heads) ] # num_bonds x hidden_size att_hiddens = sum(att_hiddens) / float( self.attention_pooling_heads) # num_bonds x hidden_size message = att_hiddens # att_hiddens is the new message # Create visualizations 
(time-consuming, best only done on test set) if self.attention_viz and depth == self.depth - 2: # Visualize at end of primary message passing phase bond_analysis = dict() for dict_key in range(8): bond_analysis[dict_key] = [] # Loop through the individual graphs in the batch for i, (a_start, a_size) in enumerate(a_scope): if a_size == 0: # Skip over empty graphs continue else: for j in range(self.attention_pooling_heads): atoms = f_atoms[ a_start:a_start + a_size, :].cpu().numpy() # Atom features bonds1 = a2b[a_start:a_start + a_size, :] # a2b bonds2 = b2a[b_scope[i][0]:b_scope[i][0] + b_scope[i][1]] # b2a bonds_dict = { } # Dictionary to keep track of atoms that bonds are connecting for k in range( len(bonds2)): # Collect info from b2a bonds_dict[ k + 1] = bonds2[k].item() - a_start + 1 for k in range(bonds1.shape[0] ): # Collect info from a2b for m in range(bonds1.shape[1]): bond_num = bonds1[ k, m].item() - b_scope[i][0] + 1 if bond_num > 0: bonds_dict[bond_num] = ( bonds_dict[bond_num], k + 1) # Save weights for this graph weights = alphas[j].cpu().data.numpy( )[b_scope[i][0]:b_scope[i][0] + b_scope[i][1]] id_number = uid[i].item() # Graph uid label = targets[i].item() # Graph label viz_dir = self.args.save_dir + '/' + 'bond_attention_visualizations' # Folder os.makedirs( viz_dir, exist_ok=True ) # Only create folder if not already exist label, num_subgraphs, num_type_2_connections, num_type_1_isolated, num_type_2_isolated = visualize_bond_attention_pooling( atoms, bonds_dict, weights, id_number, label, viz_dir) # Write analysis results if not os.path.exists( self.args.save_dir + '/' + 'attention_analysis.txt'): f = open( self.args.save_dir + '/' + 'attention_analysis.txt', 'w') f.write( '# Category Subgraphs Type 2 Bonds Type 1 Isolated ' 'Type 2 Isolated' + '\n') else: f = open( self.args.save_dir + '/' + 'attention_analysis.txt', 'a') f.write( str(label) + ' ' + str(num_subgraphs) + ' ' + str(num_type_2_connections) + ' ' + str(num_type_1_isolated) + ' ' + str(num_type_2_isolated) + '\n') bond_analysis[label].append(num_subgraphs) bond_analysis[label + 2].append(num_type_2_connections) bond_analysis[label + 4].append(num_type_1_isolated) bond_analysis[label + 6].append(num_type_2_isolated) # # Write analysis results # if not os.path.exists(self.args.save_dir + '/' + 'attention_analysis.txt'): # f = open(self.args.save_dir + '/' + 'attention_analysis.txt', 'w') # f.write('# Category 0 Subgraphs Category 1 Subgraphs ' # 'Category 0 Type 2 Bonds Category 1 Type 2 Bonds ' # 'Category 0 Type 1 Isolated Category 1 Type 1 Isolated ' # 'Category 0 Type 2 Isolated Category 1 Type 2 Isolated' '\n') # else: # f = open(self.args.save_dir + '/' + 'attention_analysis.txt', 'a') # f.write(str(np.mean(np.array(bond_analysis[0]))) + ' ' + # str(np.mean(np.array(bond_analysis[1]))) + ' ' + # str(np.mean(np.array(bond_analysis[2]))) + ' ' + # str(np.mean(np.array(bond_analysis[3]))) + ' ' + # str(np.mean(np.array(bond_analysis[4]))) + ' ' + # str(np.mean(np.array(bond_analysis[5]))) + ' ' + # str(np.mean(np.array(bond_analysis[6]))) + ' ' + # str(np.mean(np.array(bond_analysis[7]))) + ' ' + '\n') # f.close() if self.use_layer_norm: message = self.layer_norm(message) message = self.dropout_layer(message) # num_bonds x hidden if self.master_node and self.use_master_as_output: assert self.hidden_size == self.master_dim mol_vecs = [] for start, size in b_scope: if size == 0: mol_vecs.append(self.cached_zero_vector) else: mol_vecs.append(master_state[start]) return torch.stack(mol_vecs, dim=0) # Get atom 
hidden states from message hidden states if self.learn_virtual_edges: message = message * straight_through_mask if self.bond_attention_pooling: nei_a_message = index_select_ND( message, a2b) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden rev_message = message[b2revb] # num_bonds x hidden message = a_message[b2a] - rev_message # num_bonds x hidden message = self.act_func(self.W_o(message)) # num_bonds x hidden bond_hiddens = self.dropout_layer(message) # num_bonds x hidden else: a2x = a2a if self.atom_messages else a2b nei_a_message = index_select_ND( message, a2x) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden a_input = torch.cat([f_atoms, a_message], dim=1) # num_atoms x (atom_fdim + hidden) atom_hiddens = self.act_func( self.W_o(a_input)) # num_atoms x hidden atom_hiddens = self.dropout_layer( atom_hiddens) # num_atoms x hidden if self.deepset: atom_hiddens = self.W_s2s_a(atom_hiddens) atom_hiddens = self.act_func(atom_hiddens) atom_hiddens = self.W_s2s_b(atom_hiddens) if self.bert_pretraining: atom_preds = self.W_v(atom_hiddens)[ 1:] # num_atoms x vocab/output size (leave out atom padding) # Readout if self.set2set: # Set up sizes batch_size = len(a_scope) lengths = [length for _, length in a_scope] max_num_atoms = max(lengths) # Set up memory from atom features memory = torch.zeros( batch_size, max_num_atoms, self.hidden_size) # (batch_size, max_num_atoms, hidden_size) for i, (start, size) in enumerate(a_scope): memory[i, :size] = atom_hiddens.narrow(0, start, size) memory_transposed = memory.transpose( 2, 1) # (batch_size, hidden_size, max_num_atoms) # Create mask (1s for atoms, 0s for not atoms) mask = create_mask(lengths, cuda=next( self.parameters()).is_cuda) # (max_num_atoms, batch_size) mask = mask.t().unsqueeze(2) # (batch_size, max_num_atoms, 1) # Set up query query = torch.ones( 1, batch_size, self.hidden_size) # (1, batch_size, hidden_size) # Move to cuda if next(self.parameters()).is_cuda: memory, memory_transposed, query = memory.cuda( ), memory_transposed.cuda(), query.cuda() # Run RNN for _ in range(self.set2set_iters): # Compute attention weights over atoms in each molecule query = query.squeeze(0).unsqueeze( 2) # (batch_size, hidden_size, 1) dot = torch.bmm(memory, query) # (batch_size, max_num_atoms, 1) dot = dot * mask + (1 - mask) * ( -1e+20) # (batch_size, max_num_atoms, 1) attention = F.softmax(dot, dim=1) # (batch_size, max_num_atoms, 1) # Construct next input as attention over memory attended = torch.bmm(memory_transposed, attention) # (batch_size, hidden_size, 1) attended = attended.view( 1, batch_size, self.hidden_size) # (1, batch_size, hidden_size) # Run RNN for one step query, _ = self.set2set_rnn( attended) # (1, batch_size, hidden_size) # Final RNN output is the molecule encodings mol_vecs = query.squeeze(0) # (batch_size, hidden_size) else: mol_vecs = [] if self.bond_attention_pooling: for i, (b_start, b_size) in enumerate(b_scope): if b_size == 0: mol_vecs.append(self.cached_zero_vector) else: cur_hiddens = bond_hiddens.narrow(0, b_start, b_size) alphas = [ self.act_func(self.W_ap[j](cur_hiddens)) for j in range(self.attention_pooling_heads) ] alphas = [ self.V_ap[j](alphas[j]) for j in range(self.attention_pooling_heads) ] alphas = [ F.softmax(alphas[j], dim=0) for j in range(self.attention_pooling_heads) ] alphas = [ alphas[j].squeeze(1) for j in range(self.attention_pooling_heads) ] att_hiddens = [ torch.mul(alphas[j], cur_hiddens.t()).t() for j in 
range(self.attention_pooling_heads) ] att_hiddens = sum(att_hiddens) / float( self.attention_pooling_heads) mol_vec = att_hiddens mol_vec = mol_vec.sum(dim=0) mol_vecs.append(mol_vec) if self.attention_viz: for j in range(self.attention_pooling_heads): atoms = f_atoms[ a_scope[i][0]:a_scope[i][0] + a_scope[i][1], :].cpu().numpy() bonds1 = a2b[a_scope[i][0]:a_scope[i][0] + a_scope[i][1], :] bonds2 = b2a[b_scope[i][0]:b_scope[i][0] + b_scope[i][1]] bonds_dict = {} for k in range(len(bonds2)): bonds_dict[k + 1] = bonds2[k].item( ) - a_scope[i][0] + 1 for k in range(bonds1.shape[0]): for m in range(bonds1.shape[1]): bond_num = bonds1[ k, m].item() - b_scope[i][0] + 1 if bond_num > 0: bonds_dict[bond_num] = ( bonds_dict[bond_num], k + 1) weights = alphas[j].cpu().data.numpy() id_number = uid[i].item() label = targets[i].item() viz_dir = self.args.save_dir + '/' + 'bond_attention_pooling_visualizations' os.makedirs(viz_dir, exist_ok=True) visualize_bond_attention_pooling( atoms, bonds_dict, weights, id_number, label, viz_dir) else: # TODO: Maybe do this in parallel with masking rather than looping for i, (a_start, a_size) in enumerate(a_scope): if a_size == 0: mol_vecs.append(self.cached_zero_vector) else: cur_hiddens = atom_hiddens.narrow(0, a_start, a_size) if self.attention: att_w = torch.matmul(self.W_a(cur_hiddens), cur_hiddens.t()) att_w = F.softmax(att_w, dim=1) att_hiddens = torch.matmul(att_w, cur_hiddens) att_hiddens = self.act_func(self.W_b(att_hiddens)) att_hiddens = self.dropout_layer(att_hiddens) mol_vec = (cur_hiddens + att_hiddens) if viz_dir is not None: visualize_atom_attention( viz_dir, mol_graph.smiles_batch[i], a_size, att_w) elif self.attention_pooling: alphas = [ self.act_func(self.W_ap[j](cur_hiddens)) for j in range(self.attention_pooling_heads) ] alphas = [ self.V_ap[j](alphas[j]) for j in range(self.attention_pooling_heads) ] alphas = [ F.softmax(alphas[j], dim=0) for j in range(self.attention_pooling_heads) ] alphas = [ alphas[j].squeeze(1) for j in range(self.attention_pooling_heads) ] att_hiddens = [ torch.mul(alphas[j], cur_hiddens.t()).t() for j in range(self.attention_pooling_heads) ] att_hiddens = sum(att_hiddens) / float( self.attention_pooling_heads) mol_vec = att_hiddens if self.attention_viz: for j in range(self.attention_pooling_heads): atoms = f_atoms[a_start:a_start + a_size, :].cpu().numpy() weights = alphas[j].cpu().data.numpy() id_number = uid[i].item() label = targets[i].item() viz_dir = self.args.save_dir + '/' + 'attention_pooling_visualizations' os.makedirs(viz_dir, exist_ok=True) visualize_attention_pooling( atoms, weights, id_number, label, viz_dir) analyze_attention_pooling( atoms, weights, id_number, label, viz_dir) else: mol_vec = cur_hiddens # (num_atoms, hidden_size) mol_vec = mol_vec.sum( dim=0 ) / a_size #TODO: remove a_size division in case of attention_pooling mol_vecs.append(mol_vec) mol_vecs = torch.stack(mol_vecs, dim=0) # (num_molecules, hidden_size) if self.use_input_features: features_batch = features_batch.to(mol_vecs) if len(features_batch.shape) == 1: features_batch = features_batch.view( [1, features_batch.shape[0]]) mol_vecs = torch.cat([mol_vecs, features_batch], dim=1) # (num_molecules, hidden_size) if self.bert_pretraining: features_preds = self.W_f(mol_vecs) if hasattr(self, 'W_f') else None return {'features': features_preds, 'vocab': atom_preds} return mol_vecs # num_molecules x hidden
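# The bond-level and atom-level branches above share one multi-head attention-pooling pattern
# (W_ap, V_ap, softmax over items, average over heads). A standalone sketch of that pattern as
# its own module; layer sizes and the ReLU activation are assumptions, not taken from this code.
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPooling(nn.Module):
    """Multi-head attention pooling over a (num_items, hidden_size) matrix."""

    def __init__(self, hidden_size: int, num_heads: int = 2):
        super().__init__()
        self.W_ap = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_heads)])
        self.V_ap = nn.ModuleList([nn.Linear(hidden_size, 1) for _ in range(num_heads)])
        self.act_func = nn.ReLU()

    def forward(self, hiddens: torch.Tensor) -> torch.Tensor:
        per_head = []
        for W, V in zip(self.W_ap, self.V_ap):
            alpha = F.softmax(V(self.act_func(W(hiddens))), dim=0)  # (num_items, 1) weights
            per_head.append(alpha * hiddens)                        # reweight each item
        att_hiddens = sum(per_head) / len(per_head)                 # average over heads
        return att_hiddens.sum(dim=0)                               # pool to one vector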
def forward(self, mol_graph: BatchMolGraph, features_batch: T.Tensor): f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components( ) if self.atom_messages: a2a = mol_graph.get_a2a() input = self.W_i(f_atoms) else: input = self.W_i(f_bonds) message = self.act_func(input) # Message passing for depth in range(self.depth - 1): if self.undirected: message = (message + message[b2revb]) / 2 if self.atom_messages: nei_a_message = index_select_ND( message, a2a) # num_atoms x max_num_bonds x hidden nei_f_bonds = index_select_ND( f_bonds, a2b) # num_atoms x max_num_bonds x bond_fdim nei_message = T.cat( (nei_a_message, nei_f_bonds), dim=2) # num_atoms x max_num_bonds x hidden + bond_fdim message = nei_message.sum( dim=1) # num_atoms x hidden + bond_fdim else: # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1) # message a_message = sum(nei_a_message) rev_message nei_a_message = index_select_ND( message, a2b) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden rev_message = message[b2revb] # num_bonds x hidden message = a_message[b2a] - rev_message # num_bonds x hidden message = self.W_h(message) message = self.act_func(input + message) # num_bonds x hidden_size message = self.dropout_layer(message) # num_bonds x hidden a2x = a2a if self.atom_messages else a2b nei_a_message = index_select_ND( message, a2x) # num_atoms x max_num_bonds x hidden a_message = nei_a_message.sum(dim=1) # num_atoms x hidden a_input = T.cat([f_atoms, a_message], dim=1) # num_atoms x (atom_fdim + hidden) atom_hiddens = self.act_func(self.W_o(a_input)) # num_atoms x hidden atom_hiddens = self.dropout_layer(atom_hiddens) # num_atoms x hidden # Readout mol_vecs = [] for i, (a_start, a_size) in enumerate(a_scope): if a_size == 0: mol_vecs.append(self.cached_zero_vector) else: cur_hiddens = atom_hiddens.narrow(0, a_start, a_size) mol_vec = cur_hiddens # (num_atoms, hidden_size) mol_vec = mol_vec.sum(dim=0) / a_size #mol_vec = mol_vec.mean(0).values mol_vecs.append(mol_vec) mol_vecs = T.stack(mol_vecs, dim=0) # (num_molecules, hidden_size) return mol_vecs
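# Several of the forwards above include the `undirected` option, which symmetrizes the directed
# messages once per iteration. A toy check of that line, reusing the reverse-bond convention
# from the earlier sketch (bond 0 is padding; bonds (1, 2) and (3, 4) are opposite directions
# of the same chemical bond):
import torch

b2revb = torch.tensor([0, 2, 1, 4, 3])
message = torch.randn(5, 4)
message[0] = 0.0                          # keep the padding row at zero

sym = (message + message[b2revb]) / 2     # the `if self.undirected:` branch
assert torch.allclose(sym[1], sym[2])     # both directions of a bond now carry the same message
assert torch.allclose(sym[3], sym[4])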