def forward(self,
            mol_graph: BatchMolGraph,
            features_batch: List[np.ndarray] = None,
            viz_dir: str = None) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
    """
    Encodes a batch of molecular graphs.

    :param mol_graph: A BatchMolGraph representing a batch of molecular graphs.
    :param features_batch: A list of ndarrays containing additional features.
    :param viz_dir: Directory in which to save visualized attention weights.
    :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of
             each molecule. If self.bert_pretraining, instead returns a dict with 'features'
             and 'vocab' predictions.
    """
    if self.use_input_features:
        features_batch = torch.from_numpy(np.stack(features_batch)).float()

        if self.args.cuda:
            features_batch = features_batch.cuda()

        # With features_only, the molecule-level features are the entire encoding.
        if self.features_only:
            return features_batch

    f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components()
    if self.atom_messages:
        a2a = mol_graph.get_a2a()

    if self.args.cuda or next(self.parameters()).is_cuda:
        f_atoms, f_bonds, a2b, b2a, b2revb = (
            f_atoms.cuda(), f_bonds.cuda(), a2b.cuda(), b2a.cuda(), b2revb.cuda())

        if self.atom_messages:
            a2a = a2a.cuda()

    if self.learn_virtual_edges:
        atom1_features, atom2_features = f_atoms[b2a], f_atoms[b2a[b2revb]]  # each num_bonds x atom_fdim
        ve_score = torch.sum(self.lve(atom1_features) * atom2_features, dim=1) + \
            torch.sum(self.lve(atom2_features) * atom1_features, dim=1)
        is_ve_indicator_index = self.atom_fdim  # in current featurization, the first bond feature is 1 or 0 for virtual or not virtual
        num_virtual = f_bonds[:, is_ve_indicator_index].sum()
        # Straight-through estimator: the forward value of the mask is all ones, while
        # gradients flow through ve_score for virtual edges only.
        # Bug fix: allocate the ones on the same device as f_bonds instead of calling
        # .cuda() unconditionally, which crashed CPU-only runs of this branch.
        straight_through_mask = torch.ones(f_bonds.size(0), device=f_bonds.device) + \
            f_bonds[:, is_ve_indicator_index] * (ve_score - ve_score.detach()) / num_virtual  # normalize for grad norm
        straight_through_mask = straight_through_mask.unsqueeze(1).repeat((1, self.hidden_size))  # num_bonds x hidden_size

    # Input
    if self.atom_messages:
        input = self.W_i(f_atoms)  # num_atoms x hidden_size
    else:
        input = self.W_i(f_bonds)  # num_bonds x hidden_size
    message = self.act_func(input)  # num_bonds x hidden_size

    if self.message_attention:
        b2b = mol_graph.get_b2b()  # Warning: this is O(n_atoms^3) when using virtual edges
        if next(self.parameters()).is_cuda:
            b2b = b2b.cuda()
        message_attention_mask = (b2b != 0).float()  # num_bonds x max_num_bonds

    if self.global_attention:
        # Mask so each bond only attends to bonds within its own molecule.
        global_attention_mask = torch.zeros(mol_graph.n_bonds, mol_graph.n_bonds)  # num_bonds x num_bonds
        for start, length in b_scope:
            for i in range(start, start + length):
                global_attention_mask[i, start:start + length] = 1
        if next(self.parameters()).is_cuda:
            global_attention_mask = global_attention_mask.cuda()

    # Message passing
    for depth in range(self.depth - 1):
        if self.undirected:
            message = (message + message[b2revb]) / 2

        if self.learn_virtual_edges:
            message = message * straight_through_mask

        if self.message_attention:
            # TODO: Parallelize attention heads
            nei_message = index_select_ND(message, b2b)
            message = message.unsqueeze(1).repeat((1, nei_message.size(1), 1))  # num_bonds x maxnb x hidden
            attention_scores = [(self.W_ma[i](nei_message) * message).sum(dim=2)
                                for i in range(self.num_heads)]  # num_bonds x maxnb
            # Push padded (non-)neighbors to -inf so softmax assigns them zero weight.
            attention_scores = [attention_scores[i] * message_attention_mask +
                                (1 - message_attention_mask) * (-1e+20)
                                for i in range(self.num_heads)]  # num_bonds x maxnb
            attention_weights = [F.softmax(attention_scores[i], dim=1)
                                 for i in range(self.num_heads)]  # num_bonds x maxnb
            message_components = [nei_message * attention_weights[i].unsqueeze(2).repeat((1, 1, self.hidden_size))
                                  for i in range(self.num_heads)]  # num_bonds x maxnb x hidden
            message_components = [component.sum(dim=1) for component in message_components]  # num_bonds x hidden
            message = torch.cat(message_components, dim=1)  # num_bonds x num_heads * hidden
        elif self.atom_messages:
            nei_a_message = index_select_ND(message, a2a)  # num_atoms x max_num_bonds x hidden
            nei_f_bonds = index_select_ND(f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
            nei_message = torch.cat((nei_a_message, nei_f_bonds), dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
            message = nei_message.sum(dim=1)  # num_atoms x hidden + bond_fdim
        else:
            # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
            # message a_message = sum(nei_a_message) rev_message
            nei_a_message = index_select_ND(message, a2b)  # num_atoms x max_num_bonds x hidden
            a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
            rev_message = message[b2revb]  # num_bonds x hidden
            message = a_message[b2a] - rev_message  # num_bonds x hidden

        for lpm in range(self.layers_per_message - 1):
            message = self.W_h[lpm][depth](message)  # num_bonds x hidden
            message = self.act_func(message)
        message = self.W_h[self.layers_per_message - 1][depth](message)

        if self.normalize_messages:
            message = message / message.norm(dim=1, keepdim=True)

        if self.master_node:
            # master_state = self.W_master_in(self.act_func(nei_message.sum(dim=0))) #try something like this to preserve invariance for master node
            # master_state = self.GRU_master(nei_message.unsqueeze(1))
            # master_state = master_state[-1].squeeze(0) #this actually doesn't preserve order invariance anymore
            # Each bond receives its molecule's mean message through the master node;
            # index 0 is the padding bond, covered by the cached zero vector.
            mol_vecs = [self.cached_zero_vector]
            for start, size in b_scope:
                if size == 0:
                    continue
                mol_vec = message.narrow(0, start, size)
                mol_vec = mol_vec.sum(dim=0) / size
                mol_vecs += [mol_vec for _ in range(size)]
            master_state = self.act_func(self.W_master_in(torch.stack(mol_vecs, dim=0)))  # num_bonds x hidden_size
            message = self.act_func(input + message + self.W_master_out(master_state))  # num_bonds x hidden_size
        else:
            # Skip connection to the initial input transform.
            message = self.act_func(input + message)  # num_bonds x hidden_size

        if self.global_attention:
            attention_scores = torch.matmul(self.W_ga1(message), message.t())  # num_bonds x num_bonds
            attention_scores = attention_scores * global_attention_mask + \
                (1 - global_attention_mask) * (-1e+20)  # num_bonds x num_bonds
            attention_weights = F.softmax(attention_scores, dim=1)  # num_bonds x num_bonds
            attention_hiddens = torch.matmul(attention_weights, message)  # num_bonds x hidden_size
            attention_hiddens = self.act_func(self.W_ga2(attention_hiddens))  # num_bonds x hidden_size
            attention_hiddens = self.dropout_layer(attention_hiddens)  # num_bonds x hidden_size
            message = message + attention_hiddens  # num_bonds x hidden_size

            if viz_dir is not None:
                visualize_bond_attention(viz_dir, mol_graph, attention_weights, depth)

        if self.use_layer_norm:
            message = self.layer_norm(message)

        message = self.dropout_layer(message)  # num_bonds x hidden

    if self.master_node and self.use_master_as_output:
        assert self.hidden_size == self.master_dim
        mol_vecs = []
        for start, size in b_scope:
            if size == 0:
                mol_vecs.append(self.cached_zero_vector)
            else:
                # master_state is identical across a molecule's bonds; take the first.
                mol_vecs.append(master_state[start])
        return torch.stack(mol_vecs, dim=0)

    # Get atom hidden states from message hidden states
    if self.learn_virtual_edges:
        message = message * straight_through_mask
    a2x = a2a if self.atom_messages else a2b
    nei_a_message = index_select_ND(message, a2x)  # num_atoms x max_num_bonds x hidden
    a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
    a_input = torch.cat([f_atoms, a_message], dim=1)  # num_atoms x (atom_fdim + hidden)
    atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
    atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden

    if self.deepset:
        atom_hiddens = self.W_s2s_a(atom_hiddens)
        atom_hiddens = self.act_func(atom_hiddens)
        atom_hiddens = self.W_s2s_b(atom_hiddens)

    if self.bert_pretraining:
        atom_preds = self.W_v(atom_hiddens)[1:]  # num_atoms x vocab/output size (leave out atom padding)

    # Readout
    if self.set2set:
        # Set up sizes
        batch_size = len(a_scope)
        lengths = [length for _, length in a_scope]
        max_num_atoms = max(lengths)

        # Set up memory from atom features
        memory = torch.zeros(batch_size, max_num_atoms, self.hidden_size)  # (batch_size, max_num_atoms, hidden_size)
        for i, (start, size) in enumerate(a_scope):
            memory[i, :size] = atom_hiddens.narrow(0, start, size)
        memory_transposed = memory.transpose(2, 1)  # (batch_size, hidden_size, max_num_atoms)

        # Create mask (1s for atoms, 0s for not atoms)
        mask = create_mask(lengths, cuda=next(self.parameters()).is_cuda)  # (max_num_atoms, batch_size)
        mask = mask.t().unsqueeze(2)  # (batch_size, max_num_atoms, 1)

        # Set up query
        query = torch.ones(1, batch_size, self.hidden_size)  # (1, batch_size, hidden_size)

        # Move to cuda
        if next(self.parameters()).is_cuda:
            memory, memory_transposed, query = memory.cuda(), memory_transposed.cuda(), query.cuda()

        # Run RNN
        for _ in range(self.set2set_iters):
            # Compute attention weights over atoms in each molecule
            query = query.squeeze(0).unsqueeze(2)  # (batch_size, hidden_size, 1)
            dot = torch.bmm(memory, query)  # (batch_size, max_num_atoms, 1)
            dot = dot * mask + (1 - mask) * (-1e+20)  # (batch_size, max_num_atoms, 1)
            attention = F.softmax(dot, dim=1)  # (batch_size, max_num_atoms, 1)

            # Construct next input as attention over memory
            attended = torch.bmm(memory_transposed, attention)  # (batch_size, hidden_size, 1)
            attended = attended.view(1, batch_size, self.hidden_size)  # (1, batch_size, hidden_size)

            # Run RNN for one step
            query, _ = self.set2set_rnn(attended)  # (1, batch_size, hidden_size)

        # Final RNN output is the molecule encodings
        mol_vecs = query.squeeze(0)  # (batch_size, hidden_size)
    else:
        mol_vecs = []
        # TODO: Maybe do this in parallel with masking rather than looping
        for i, (a_start, a_size) in enumerate(a_scope):
            if a_size == 0:
                # Empty molecule: use the fixed zero vector as its encoding.
                mol_vecs.append(self.cached_zero_vector)
            else:
                cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)

                if self.attention:
                    att_w = torch.matmul(self.W_a(cur_hiddens), cur_hiddens.t())
                    att_w = F.softmax(att_w, dim=1)
                    att_hiddens = torch.matmul(att_w, cur_hiddens)
                    att_hiddens = self.act_func(self.W_b(att_hiddens))
                    att_hiddens = self.dropout_layer(att_hiddens)
                    mol_vec = (cur_hiddens + att_hiddens)

                    if viz_dir is not None:
                        visualize_atom_attention(viz_dir, mol_graph.smiles_batch[i], a_size, att_w)
                else:
                    mol_vec = cur_hiddens  # (num_atoms, hidden_size)

                mol_vec = mol_vec.sum(dim=0) / a_size  # mean over the molecule's atoms
                mol_vecs.append(mol_vec)

        mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)

    if self.use_input_features:
        features_batch = features_batch.to(mol_vecs)
        if len(features_batch.shape) == 1:
            # Single molecule: restore the batch dimension before concatenating.
            features_batch = features_batch.view([1, features_batch.shape[0]])
        mol_vecs = torch.cat([mol_vecs, features_batch], dim=1)  # (num_molecules, hidden_size)

    if self.bert_pretraining:
        features_preds = self.W_f(mol_vecs) if hasattr(self, 'W_f') else None
        return {'features': features_preds, 'vocab': atom_preds}

    return mol_vecs  # num_molecules x hidden
def forward(self,
            mol_graph: BatchMolGraph,
            features_batch: List[np.ndarray] = None,
            viz_dir: str = None) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
    """
    Encodes a batch of molecular graphs.

    This variant additionally receives uids and targets from the batch graph and supports
    bond attention, bond attention pooling, and attention-weight visualization.

    :param mol_graph: A BatchMolGraph representing a batch of molecular graphs.
    :param features_batch: A list of ndarrays containing additional features.
    :param viz_dir: Directory in which to save visualized attention weights.
                    NOTE(review): the attention_viz branches below reassign viz_dir to a
                    subdirectory of self.args.save_dir, clobbering this argument.
    :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of
             each molecule. If self.bert_pretraining, instead returns a dict with 'features'
             and 'vocab' predictions.
    """
    if self.use_input_features:
        features_batch = torch.from_numpy(np.stack(features_batch)).float()

        if self.args.cuda:
            features_batch = features_batch.cuda()

        # With features_only, the molecule-level features are the entire encoding.
        if self.features_only:
            return features_batch

    f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, uid, targets = mol_graph.get_components()
    if self.atom_messages:
        a2a = mol_graph.get_a2a()

    if self.args.cuda or next(self.parameters()).is_cuda:
        f_atoms, f_bonds, a2b, b2a, b2revb = (
            f_atoms.cuda(), f_bonds.cuda(), a2b.cuda(), b2a.cuda(), b2revb.cuda())

        if self.atom_messages:
            a2a = a2a.cuda()

    if self.learn_virtual_edges:
        atom1_features, atom2_features = f_atoms[b2a], f_atoms[b2a[b2revb]]  # each num_bonds x atom_fdim
        ve_score = torch.sum(self.lve(atom1_features) * atom2_features, dim=1) + \
            torch.sum(self.lve(atom2_features) * atom1_features, dim=1)
        is_ve_indicator_index = self.atom_fdim  # in current featurization, the first bond feature is 1 or 0 for virtual or not virtual
        num_virtual = f_bonds[:, is_ve_indicator_index].sum()
        # Straight-through estimator: the forward value of the mask is all ones, while
        # gradients flow through ve_score for virtual edges only.
        # Bug fix: allocate the ones on the same device as f_bonds instead of calling
        # .cuda() unconditionally, which crashed CPU-only runs of this branch.
        straight_through_mask = torch.ones(f_bonds.size(0), device=f_bonds.device) + \
            f_bonds[:, is_ve_indicator_index] * (ve_score - ve_score.detach()) / num_virtual  # normalize for grad norm
        straight_through_mask = straight_through_mask.unsqueeze(1).repeat((1, self.hidden_size))  # num_bonds x hidden_size

    # Input
    if self.atom_messages:
        input = self.W_i(f_atoms)  # num_atoms x hidden_size
    else:
        input = self.W_i(f_bonds)  # num_bonds x hidden_size
    message = self.act_func(input)  # num_bonds x hidden_size

    if self.message_attention:
        b2b = mol_graph.get_b2b()  # Warning: this is O(n_atoms^3) when using virtual edges
        if next(self.parameters()).is_cuda:
            b2b = b2b.cuda()
        message_attention_mask = (b2b != 0).float()  # num_bonds x max_num_bonds

    if self.global_attention:
        # Mask so each bond only attends to bonds within its own molecule.
        global_attention_mask = torch.zeros(mol_graph.n_bonds, mol_graph.n_bonds)  # num_bonds x num_bonds
        for start, length in b_scope:
            for i in range(start, start + length):
                global_attention_mask[i, start:start + length] = 1
        if next(self.parameters()).is_cuda:
            global_attention_mask = global_attention_mask.cuda()

    # Message passing
    for depth in range(self.depth - 1):
        if self.undirected:
            message = (message + message[b2revb]) / 2

        if self.learn_virtual_edges:
            message = message * straight_through_mask

        if self.message_attention:
            # TODO: Parallelize attention heads
            nei_message = index_select_ND(message, b2b)
            message = message.unsqueeze(1).repeat((1, nei_message.size(1), 1))  # num_bonds x maxnb x hidden
            attention_scores = [(self.W_ma[i](nei_message) * message).sum(dim=2)
                                for i in range(self.num_heads)]  # num_bonds x maxnb
            # Push padded (non-)neighbors to -inf so softmax assigns them zero weight.
            attention_scores = [attention_scores[i] * message_attention_mask +
                                (1 - message_attention_mask) * (-1e+20)
                                for i in range(self.num_heads)]  # num_bonds x maxnb
            attention_weights = [F.softmax(attention_scores[i], dim=1)
                                 for i in range(self.num_heads)]  # num_bonds x maxnb
            message_components = [nei_message * attention_weights[i].unsqueeze(2).repeat((1, 1, self.hidden_size))
                                  for i in range(self.num_heads)]  # num_bonds x maxnb x hidden
            message_components = [component.sum(dim=1) for component in message_components]  # num_bonds x hidden
            message = torch.cat(message_components, dim=1)  # num_bonds x num_heads * hidden
        elif self.atom_messages:
            nei_a_message = index_select_ND(message, a2a)  # num_atoms x max_num_bonds x hidden
            nei_f_bonds = index_select_ND(f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
            nei_message = torch.cat((nei_a_message, nei_f_bonds), dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
            message = nei_message.sum(dim=1)  # num_atoms x hidden + bond_fdim
        else:
            # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
            # message a_message = sum(nei_a_message) rev_message
            nei_a_message = index_select_ND(message, a2b)  # num_atoms x max_num_bonds x hidden
            a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
            rev_message = message[b2revb]  # num_bonds x hidden
            message = a_message[b2a] - rev_message  # num_bonds x hidden

        for lpm in range(self.layers_per_message - 1):
            message = self.W_h[lpm][depth](message)  # num_bonds x hidden
            message = self.act_func(message)
        message = self.W_h[self.layers_per_message - 1][depth](message)

        if self.normalize_messages:
            message = message / message.norm(dim=1, keepdim=True)

        if self.master_node:
            # master_state = self.W_master_in(self.act_func(nei_message.sum(dim=0))) #try something like this to preserve invariance for master node
            # master_state = self.GRU_master(nei_message.unsqueeze(1))
            # master_state = master_state[-1].squeeze(0) #this actually doesn't preserve order invariance anymore
            # Each bond receives its molecule's mean message through the master node;
            # index 0 is the padding bond, covered by the cached zero vector.
            mol_vecs = [self.cached_zero_vector]
            for start, size in b_scope:
                if size == 0:
                    continue
                mol_vec = message.narrow(0, start, size)
                mol_vec = mol_vec.sum(dim=0) / size
                mol_vecs += [mol_vec for _ in range(size)]
            master_state = self.act_func(self.W_master_in(torch.stack(mol_vecs, dim=0)))  # num_bonds x hidden_size
            message = self.act_func(input + message + self.W_master_out(master_state))  # num_bonds x hidden_size
        else:
            # Skip connection to the initial input transform.
            message = self.act_func(input + message)  # num_bonds x hidden_size

        if self.global_attention:
            attention_scores = torch.matmul(self.W_ga1(message), message.t())  # num_bonds x num_bonds
            attention_scores = attention_scores * global_attention_mask + \
                (1 - global_attention_mask) * (-1e+20)  # num_bonds x num_bonds
            attention_weights = F.softmax(attention_scores, dim=1)  # num_bonds x num_bonds
            attention_hiddens = torch.matmul(attention_weights, message)  # num_bonds x hidden_size
            attention_hiddens = self.act_func(self.W_ga2(attention_hiddens))  # num_bonds x hidden_size
            attention_hiddens = self.dropout_layer(attention_hiddens)  # num_bonds x hidden_size
            message = message + attention_hiddens  # num_bonds x hidden_size

            if viz_dir is not None:
                visualize_bond_attention(viz_dir, mol_graph, attention_weights, depth)

        # Bond attention during message passing
        if self.bond_attention:
            alphas = [self.act_func(self.W_ap[j](message))
                      for j in range(self.attention_pooling_heads)]  # num_bonds x hidden_size
            alphas = [self.V_ap[j](alphas[j])
                      for j in range(self.attention_pooling_heads)]  # num_bonds x 1
            # Softmax over ALL bonds in the batch (dim=0), not per molecule.
            alphas = [F.softmax(alphas[j], dim=0)
                      for j in range(self.attention_pooling_heads)]  # num_bonds x 1
            alphas = [alphas[j].squeeze(1)
                      for j in range(self.attention_pooling_heads)]  # num_bonds
            att_hiddens = [torch.mul(alphas[j], message.t()).t()
                           for j in range(self.attention_pooling_heads)]  # num_bonds x hidden_size
            att_hiddens = sum(att_hiddens) / float(self.attention_pooling_heads)  # num_bonds x hidden_size
            message = att_hiddens  # att_hiddens is the new message

            # Create visualizations (time-consuming, best only done on test set)
            if self.attention_viz and depth == self.depth - 2:  # Visualize at end of primary message passing phase
                # bond_analysis buckets: label (0/1) offset by 0/2/4/6 for the four stats.
                # NOTE(review): populated below, but the aggregate summary write that
                # consumed it is currently commented out at the end of this branch.
                bond_analysis = dict()
                for dict_key in range(8):
                    bond_analysis[dict_key] = []
                # Loop through the individual graphs in the batch
                for i, (a_start, a_size) in enumerate(a_scope):
                    if a_size == 0:  # Skip over empty graphs
                        continue
                    else:
                        for j in range(self.attention_pooling_heads):
                            atoms = f_atoms[a_start:a_start + a_size, :].cpu().numpy()  # Atom features
                            bonds1 = a2b[a_start:a_start + a_size, :]  # a2b
                            bonds2 = b2a[b_scope[i][0]:b_scope[i][0] + b_scope[i][1]]  # b2a
                            bonds_dict = {}  # Dictionary to keep track of atoms that bonds are connecting
                            for k in range(len(bonds2)):  # Collect info from b2a
                                bonds_dict[k + 1] = bonds2[k].item() - a_start + 1
                            for k in range(bonds1.shape[0]):  # Collect info from a2b
                                for m in range(bonds1.shape[1]):
                                    bond_num = bonds1[k, m].item() - b_scope[i][0] + 1
                                    if bond_num > 0:
                                        bonds_dict[bond_num] = (bonds_dict[bond_num], k + 1)
                            # Save weights for this graph
                            weights = alphas[j].cpu().data.numpy()[b_scope[i][0]:b_scope[i][0] + b_scope[i][1]]
                            id_number = uid[i].item()  # Graph uid
                            label = targets[i].item()  # Graph label
                            viz_dir = self.args.save_dir + '/' + 'bond_attention_visualizations'  # Folder
                            os.makedirs(viz_dir, exist_ok=True)  # Only create folder if not already exist
                            label, num_subgraphs, num_type_2_connections, num_type_1_isolated, num_type_2_isolated = \
                                visualize_bond_attention_pooling(atoms, bonds_dict, weights, id_number, label, viz_dir)

                            # Write analysis results.
                            # Bug fix: the file handle was opened in a nested loop and never
                            # closed; use a context manager with the same header/append logic.
                            analysis_path = self.args.save_dir + '/' + 'attention_analysis.txt'
                            write_header = not os.path.exists(analysis_path)
                            with open(analysis_path, 'w' if write_header else 'a') as f:
                                if write_header:
                                    f.write('# Category Subgraphs Type 2 Bonds Type 1 Isolated '
                                            'Type 2 Isolated' + '\n')
                                f.write(str(label) + ' ' + str(num_subgraphs) + ' ' +
                                        str(num_type_2_connections) + ' ' + str(num_type_1_isolated) +
                                        ' ' + str(num_type_2_isolated) + '\n')
                            bond_analysis[label].append(num_subgraphs)
                            bond_analysis[label + 2].append(num_type_2_connections)
                            bond_analysis[label + 4].append(num_type_1_isolated)
                            bond_analysis[label + 6].append(num_type_2_isolated)

                # # Write analysis results
                # if not os.path.exists(self.args.save_dir + '/' + 'attention_analysis.txt'):
                #     f = open(self.args.save_dir + '/' + 'attention_analysis.txt', 'w')
                #     f.write('# Category 0 Subgraphs Category 1 Subgraphs '
                #             'Category 0 Type 2 Bonds Category 1 Type 2 Bonds '
                #             'Category 0 Type 1 Isolated Category 1 Type 1 Isolated '
                #             'Category 0 Type 2 Isolated Category 1 Type 2 Isolated' '\n')
                # else:
                #     f = open(self.args.save_dir + '/' + 'attention_analysis.txt', 'a')
                # f.write(str(np.mean(np.array(bond_analysis[0]))) + ' ' +
                #         str(np.mean(np.array(bond_analysis[1]))) + ' ' +
                #         str(np.mean(np.array(bond_analysis[2]))) + ' ' +
                #         str(np.mean(np.array(bond_analysis[3]))) + ' ' +
                #         str(np.mean(np.array(bond_analysis[4]))) + ' ' +
                #         str(np.mean(np.array(bond_analysis[5]))) + ' ' +
                #         str(np.mean(np.array(bond_analysis[6]))) + ' ' +
                #         str(np.mean(np.array(bond_analysis[7]))) + ' ' + '\n')
                # f.close()

        if self.use_layer_norm:
            message = self.layer_norm(message)

        message = self.dropout_layer(message)  # num_bonds x hidden

    if self.master_node and self.use_master_as_output:
        assert self.hidden_size == self.master_dim
        mol_vecs = []
        for start, size in b_scope:
            if size == 0:
                mol_vecs.append(self.cached_zero_vector)
            else:
                # master_state is identical across a molecule's bonds; take the first.
                mol_vecs.append(master_state[start])
        return torch.stack(mol_vecs, dim=0)

    # Get atom hidden states from message hidden states
    if self.learn_virtual_edges:
        message = message * straight_through_mask

    if self.bond_attention_pooling:
        # Keep per-bond hidden states (one final directed-bond update) for bond-level pooling.
        nei_a_message = index_select_ND(message, a2b)  # num_atoms x max_num_bonds x hidden
        a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
        rev_message = message[b2revb]  # num_bonds x hidden
        message = a_message[b2a] - rev_message  # num_bonds x hidden
        message = self.act_func(self.W_o(message))  # num_bonds x hidden
        bond_hiddens = self.dropout_layer(message)  # num_bonds x hidden
    else:
        a2x = a2a if self.atom_messages else a2b
        nei_a_message = index_select_ND(message, a2x)  # num_atoms x max_num_bonds x hidden
        a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
        a_input = torch.cat([f_atoms, a_message], dim=1)  # num_atoms x (atom_fdim + hidden)
        atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
        atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden

    if self.deepset:
        atom_hiddens = self.W_s2s_a(atom_hiddens)
        atom_hiddens = self.act_func(atom_hiddens)
        atom_hiddens = self.W_s2s_b(atom_hiddens)

    if self.bert_pretraining:
        atom_preds = self.W_v(atom_hiddens)[1:]  # num_atoms x vocab/output size (leave out atom padding)

    # Readout
    if self.set2set:
        # Set up sizes
        batch_size = len(a_scope)
        lengths = [length for _, length in a_scope]
        max_num_atoms = max(lengths)

        # Set up memory from atom features
        memory = torch.zeros(batch_size, max_num_atoms, self.hidden_size)  # (batch_size, max_num_atoms, hidden_size)
        for i, (start, size) in enumerate(a_scope):
            memory[i, :size] = atom_hiddens.narrow(0, start, size)
        memory_transposed = memory.transpose(2, 1)  # (batch_size, hidden_size, max_num_atoms)

        # Create mask (1s for atoms, 0s for not atoms)
        mask = create_mask(lengths, cuda=next(self.parameters()).is_cuda)  # (max_num_atoms, batch_size)
        mask = mask.t().unsqueeze(2)  # (batch_size, max_num_atoms, 1)

        # Set up query
        query = torch.ones(1, batch_size, self.hidden_size)  # (1, batch_size, hidden_size)

        # Move to cuda
        if next(self.parameters()).is_cuda:
            memory, memory_transposed, query = memory.cuda(), memory_transposed.cuda(), query.cuda()

        # Run RNN
        for _ in range(self.set2set_iters):
            # Compute attention weights over atoms in each molecule
            query = query.squeeze(0).unsqueeze(2)  # (batch_size, hidden_size, 1)
            dot = torch.bmm(memory, query)  # (batch_size, max_num_atoms, 1)
            dot = dot * mask + (1 - mask) * (-1e+20)  # (batch_size, max_num_atoms, 1)
            attention = F.softmax(dot, dim=1)  # (batch_size, max_num_atoms, 1)

            # Construct next input as attention over memory
            attended = torch.bmm(memory_transposed, attention)  # (batch_size, hidden_size, 1)
            attended = attended.view(1, batch_size, self.hidden_size)  # (1, batch_size, hidden_size)

            # Run RNN for one step
            query, _ = self.set2set_rnn(attended)  # (1, batch_size, hidden_size)

        # Final RNN output is the molecule encodings
        mol_vecs = query.squeeze(0)  # (batch_size, hidden_size)
    else:
        mol_vecs = []
        if self.bond_attention_pooling:
            # Attention-weighted sum over each molecule's bond hidden states.
            for i, (b_start, b_size) in enumerate(b_scope):
                if b_size == 0:
                    mol_vecs.append(self.cached_zero_vector)
                else:
                    cur_hiddens = bond_hiddens.narrow(0, b_start, b_size)
                    alphas = [self.act_func(self.W_ap[j](cur_hiddens))
                              for j in range(self.attention_pooling_heads)]
                    alphas = [self.V_ap[j](alphas[j])
                              for j in range(self.attention_pooling_heads)]
                    alphas = [F.softmax(alphas[j], dim=0)
                              for j in range(self.attention_pooling_heads)]
                    alphas = [alphas[j].squeeze(1)
                              for j in range(self.attention_pooling_heads)]
                    att_hiddens = [torch.mul(alphas[j], cur_hiddens.t()).t()
                                   for j in range(self.attention_pooling_heads)]
                    att_hiddens = sum(att_hiddens) / float(self.attention_pooling_heads)
                    mol_vec = att_hiddens
                    # Sum (no division by size) -- the alphas already form a convex combination.
                    mol_vec = mol_vec.sum(dim=0)
                    mol_vecs.append(mol_vec)

                    if self.attention_viz:
                        for j in range(self.attention_pooling_heads):
                            atoms = f_atoms[a_scope[i][0]:a_scope[i][0] + a_scope[i][1], :].cpu().numpy()
                            bonds1 = a2b[a_scope[i][0]:a_scope[i][0] + a_scope[i][1], :]
                            bonds2 = b2a[b_scope[i][0]:b_scope[i][0] + b_scope[i][1]]
                            bonds_dict = {}
                            for k in range(len(bonds2)):
                                bonds_dict[k + 1] = bonds2[k].item() - a_scope[i][0] + 1
                            for k in range(bonds1.shape[0]):
                                for m in range(bonds1.shape[1]):
                                    bond_num = bonds1[k, m].item() - b_scope[i][0] + 1
                                    if bond_num > 0:
                                        bonds_dict[bond_num] = (bonds_dict[bond_num], k + 1)
                            weights = alphas[j].cpu().data.numpy()
                            id_number = uid[i].item()
                            label = targets[i].item()
                            viz_dir = self.args.save_dir + '/' + 'bond_attention_pooling_visualizations'
                            os.makedirs(viz_dir, exist_ok=True)
                            visualize_bond_attention_pooling(atoms, bonds_dict, weights, id_number, label, viz_dir)
        else:
            # TODO: Maybe do this in parallel with masking rather than looping
            for i, (a_start, a_size) in enumerate(a_scope):
                if a_size == 0:
                    # Empty molecule: use the fixed zero vector as its encoding.
                    mol_vecs.append(self.cached_zero_vector)
                else:
                    cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)

                    if self.attention:
                        att_w = torch.matmul(self.W_a(cur_hiddens), cur_hiddens.t())
                        att_w = F.softmax(att_w, dim=1)
                        att_hiddens = torch.matmul(att_w, cur_hiddens)
                        att_hiddens = self.act_func(self.W_b(att_hiddens))
                        att_hiddens = self.dropout_layer(att_hiddens)
                        mol_vec = (cur_hiddens + att_hiddens)

                        if viz_dir is not None:
                            visualize_atom_attention(viz_dir, mol_graph.smiles_batch[i], a_size, att_w)
                    elif self.attention_pooling:
                        alphas = [self.act_func(self.W_ap[j](cur_hiddens))
                                  for j in range(self.attention_pooling_heads)]
                        alphas = [self.V_ap[j](alphas[j])
                                  for j in range(self.attention_pooling_heads)]
                        alphas = [F.softmax(alphas[j], dim=0)
                                  for j in range(self.attention_pooling_heads)]
                        alphas = [alphas[j].squeeze(1)
                                  for j in range(self.attention_pooling_heads)]
                        att_hiddens = [torch.mul(alphas[j], cur_hiddens.t()).t()
                                       for j in range(self.attention_pooling_heads)]
                        att_hiddens = sum(att_hiddens) / float(self.attention_pooling_heads)
                        mol_vec = att_hiddens

                        if self.attention_viz:
                            for j in range(self.attention_pooling_heads):
                                atoms = f_atoms[a_start:a_start + a_size, :].cpu().numpy()
                                weights = alphas[j].cpu().data.numpy()
                                id_number = uid[i].item()
                                label = targets[i].item()
                                viz_dir = self.args.save_dir + '/' + 'attention_pooling_visualizations'
                                os.makedirs(viz_dir, exist_ok=True)
                                visualize_attention_pooling(atoms, weights, id_number, label, viz_dir)
                                analyze_attention_pooling(atoms, weights, id_number, label, viz_dir)
                    else:
                        mol_vec = cur_hiddens  # (num_atoms, hidden_size)

                    mol_vec = mol_vec.sum(dim=0) / a_size  # TODO: remove a_size division in case of attention_pooling
                    mol_vecs.append(mol_vec)

        mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)

    if self.use_input_features:
        features_batch = features_batch.to(mol_vecs)
        if len(features_batch.shape) == 1:
            # Single molecule: restore the batch dimension before concatenating.
            features_batch = features_batch.view([1, features_batch.shape[0]])
        mol_vecs = torch.cat([mol_vecs, features_batch], dim=1)  # (num_molecules, hidden_size)

    if self.bert_pretraining:
        features_preds = self.W_f(mol_vecs) if hasattr(self, 'W_f') else None
        return {'features': features_preds, 'vocab': atom_preds}

    return mol_vecs  # num_molecules x hidden