def data_loader(database, latent_graph_dict, latent_graph_list, batch_size=4, shuffle=True):
    """Create mini-batches of latent molecular graphs from a database.

    Run on either the training or the validation set.

    Parameters
    ----------
    database : pandas.DataFrame
        Expected columns: "Molecule1", "Molecule2", "Membrane", "molp1",
        "molp2", "molpm", "Temperature" and "2".."16" (target columns)
        -- TODO confirm against the caller's data preparation.
    latent_graph_dict : dict
        Maps a molecule identifier to an index into ``latent_graph_list``.
    latent_graph_list : list
        Pre-computed latent graphs (Data-like objects with ``x``,
        ``edge_index``, ``edge_attr`` attributes).
    batch_size : int
        Samples per mini-batch; the last batch may be smaller.
    shuffle : bool
        If True, samples are drawn in a random order.

    Returns
    -------
    list
        ``[batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature]``
        where the first three are lists of Batch objects and the last two
        are lists of tensors, one entry per mini-batch.
    """
    graph_list_p1, graph_list_p2, graph_list_pm = [], [], []
    target_list = []
    Temperature_list = []

    # Collect per-sample graphs, target tensors and temperatures.
    for itm in range(len(database)):
        # NOTE(review): these are *references* into latent_graph_list, so the
        # .y assignments below mutate the shared graph objects.  If the same
        # molecule occurs in several rows, every earlier appended reference
        # ends up with the last row's y.  Consider cloning the graphs here
        # -- TODO confirm intended behavior.
        graph_p1 = latent_graph_list[latent_graph_dict[database["Molecule1"].loc[itm]]]
        graph_p2 = latent_graph_list[latent_graph_dict[database["Molecule2"].loc[itm]]]
        graph_pm = latent_graph_list[latent_graph_dict[database["Membrane"].loc[itm]]]

        # Molar percentages are stored as graph-level targets (.y).
        graph_p1.y = torch.tensor([float(database["molp1"].loc[itm])])
        graph_p2.y = torch.tensor([float(database["molp2"].loc[itm])])
        graph_pm.y = torch.tensor([float(database["molpm"].loc[itm])])

        graph_list_p1.append(graph_p1)
        graph_list_p2.append(graph_p2)
        graph_list_pm.append(graph_pm)

        # Targets live in columns "2".."16", scaled by 10.
        target = torch.tensor([database[str(i)].loc[itm] * 10 for i in range(2, 17)],
                              dtype=torch.float)
        target_list.append(target)
        Temperature_list.append(database["Temperature"].loc[itm])

    # Sample order: random permutation, or the identity when not shuffling.
    idxs = torch.randperm(len(database)).tolist()
    if not shuffle:
        idxs = sorted(idxs)

    batch_p1, batch_p2, batch_pm = [], [], []
    batch_target = []
    batch_Temperature = []

    for b in range(0, len(database), batch_size):
        # Empty Batch objects that are filled manually (advanced mini-batching).
        B_p1 = Batch()
        B_p2 = Batch()
        B_pm = Batch()

        # Per-attribute stacks, concatenated into the Batch objects below.
        stack_x_p1, stack_x_p2, stack_x_pm = [], [], []
        stack_ei_p1, stack_ei_p2, stack_ei_pm = [], [], []
        stack_ea_p1, stack_ea_p2, stack_ea_pm = [], [], []
        stack_y_p1, stack_y_p2, stack_y_pm = [], [], []
        stack_btc_p1, stack_btc_p2, stack_btc_pm = [], [], []
        stack_target = []
        stack_Temperature = []

        # Running node-count offsets that shift each graph's edge indices so
        # they stay valid after node features are concatenated.
        b_shift_p1 = 0
        b_shift_p2 = 0
        b_shift_pm = 0

        for i in range(batch_size):
            if (b + i) >= len(database):
                break  # last (partial) batch
            idx = idxs[b + i]

            stack_x_p1.append(graph_list_p1[idx].x)
            stack_x_p2.append(graph_list_p2[idx].x)
            stack_x_pm.append(graph_list_pm[idx].x)

            stack_ei_p1.append(graph_list_p1[idx].edge_index + b_shift_p1)
            stack_ei_p2.append(graph_list_p2[idx].edge_index + b_shift_p2)
            stack_ei_pm.append(graph_list_pm[idx].edge_index + b_shift_pm)

            # Advance the offsets by the number of nodes just appended.
            b_shift_p1 += graph_list_p1[idx].x.size()[0]
            b_shift_p2 += graph_list_p2[idx].x.size()[0]
            b_shift_pm += graph_list_pm[idx].x.size()[0]

            stack_ea_p1.append(graph_list_p1[idx].edge_attr)
            stack_ea_p2.append(graph_list_p2[idx].edge_attr)
            stack_ea_pm.append(graph_list_pm[idx].edge_attr)

            stack_y_p1.append(graph_list_p1[idx].y)
            stack_y_p2.append(graph_list_p2[idx].y)
            stack_y_pm.append(graph_list_pm[idx].y)

            # Per-node membership vector: position i of the sample within
            # this mini-batch, repeated for each of the graph's nodes.
            stack_btc_p1.append(torch.full([graph_list_p1[idx].x.size()[0]],
                                           fill_value=int(i), dtype=torch.long))
            stack_btc_p2.append(torch.full([graph_list_p2[idx].x.size()[0]],
                                           fill_value=int(i), dtype=torch.long))
            stack_btc_pm.append(torch.full([graph_list_pm[idx].x.size()[0]],
                                           fill_value=int(i), dtype=torch.long))

            stack_target.append(target_list[idx])
            # BUG FIX: the temperature must follow the same shuffled index as
            # the graphs and targets; the original indexed
            # Temperature_list[b + i], misaligning temperatures when
            # shuffle=True.
            stack_Temperature.append(Temperature_list[idx])

        # Concatenate the per-sample pieces into one Batch per molecule role.
        B_p1.edge_index = torch.cat(stack_ei_p1, dim=1)
        B_p1.x = torch.cat(stack_x_p1, dim=0)
        B_p1.edge_attr = torch.cat(stack_ea_p1, dim=0)
        B_p1.y = torch.cat(stack_y_p1, dim=0)
        B_p1.batch = torch.cat(stack_btc_p1, dim=0)

        B_p2.edge_index = torch.cat(stack_ei_p2, dim=1)
        B_p2.x = torch.cat(stack_x_p2, dim=0)
        B_p2.edge_attr = torch.cat(stack_ea_p2, dim=0)
        B_p2.y = torch.cat(stack_y_p2, dim=0)
        B_p2.batch = torch.cat(stack_btc_p2, dim=0)

        B_pm.edge_index = torch.cat(stack_ei_pm, dim=1)
        B_pm.x = torch.cat(stack_x_pm, dim=0)
        B_pm.edge_attr = torch.cat(stack_ea_pm, dim=0)
        B_pm.y = torch.cat(stack_y_pm, dim=0)
        B_pm.batch = torch.cat(stack_btc_pm, dim=0)

        B_target = torch.stack(stack_target, dim=0)
        B_Temperature = torch.Tensor(stack_Temperature)

        batch_p1.append(B_p1)
        batch_p2.append(B_p2)
        batch_pm.append(B_pm)
        batch_target.append(B_target)
        batch_Temperature.append(B_Temperature)

    return [batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature]
def forward(self, plist):
    """Two-level GNN forward pass.

    Parameters
    ----------
    plist : sequence of length 4
        ``(p1, p2, pm, Temperature)`` — three Batch-like graph objects
        (two molecules and a membrane) carrying ``x``, ``edge_index``,
        ``edge_attr``, ``y`` and ``batch`` attributes, plus a temperature
        tensor.  Presumably ``y`` holds the per-graph molar percentage
        with one value per sample -- TODO confirm against data_loader.

    Returns
    -------
    torch.Tensor
        Output of ``self.mlp_last`` applied to the second-level global
        features, one row per sample in the batch.
    """
    # Unpack the three molecular graph batches for the first-level GNN.
    p1, p2, pm, Temperature = plist
    x_p1, ei_p1, ea_p1, u_p1, btc_p1 = p1.x, p1.edge_index, p1.edge_attr, p1.y, p1.batch
    x_p2, ei_p2, ea_p2, u_p2, btc_p2 = p2.x, p2.edge_index, p2.edge_attr, p2.y, p2.batch
    x_pm, ei_pm, ea_pm, u_pm, btc_pm = pm.x, pm.edge_index, pm.edge_attr, pm.y, pm.batch

    # Embed node and edge features; the same encoders are shared by all
    # three graphs.
    enc_x_p1 = self.encoding_node_1(x_p1)
    enc_x_p2 = self.encoding_node_1(x_p2)
    enc_x_pm = self.encoding_node_1(x_pm)
    enc_ea_p1 = self.encoding_edge_1(ea_p1)
    enc_ea_p2 = self.encoding_edge_1(ea_p2)
    enc_ea_pm = self.encoding_edge_1(ea_pm)

    # Empty first-level global ("u") feature vectors, one row per sample,
    # initialised to the constant 0.1.
    u1 = torch.full(size=(self.batch_size, NO_GRAPH_FEATURES_ONE), fill_value=0.1, dtype=torch.float).to(device)
    u2 = torch.full(size=(self.batch_size, NO_GRAPH_FEATURES_ONE), fill_value=0.1, dtype=torch.float).to(device)
    um = torch.full(size=(self.batch_size, NO_GRAPH_FEATURES_ONE), fill_value=0.1, dtype=torch.float).to(device)

    # First-level message passing: self.no_mp_one rounds of the shared
    # MetaLayer, building up molecular (graph-level) features in u1/u2/um.
    for _ in range(self.no_mp_one):
        enc_x_p1, enc_ea_p1, u1 = self.meta1(x=enc_x_p1, edge_index=ei_p1, edge_attr=enc_ea_p1, u=u1, batch=btc_p1)
        enc_x_p2, enc_ea_p2, u2 = self.meta1(x=enc_x_p2, edge_index=ei_p2, edge_attr=enc_ea_p2, u=u2, batch=btc_p2)
        enc_x_pm, enc_ea_pm, um = self.meta1(x=enc_x_pm, edge_index=ei_pm, edge_attr=enc_ea_pm, u=um, batch=btc_pm)

    # --- GRAPH GENERATION, SECOND LEVEL ---
    # The first-level graph features become node features of a small
    # 3-node graph (p1, p2, membrane) per sample.
    u1 = self.encoding_node_2(u1)
    u2 = self.encoding_node_2(u2)
    um = self.encoding_node_2(um)

    # Fresh Batch object for the second-level graphs.
    nu_Batch = Batch()

    # Edge indices for the second level.  Nodes are (0=p1, 1=p2, 2=pm);
    # edges are bidirectional p1<->pm and p2<->pm — no p1<->p2 edge.
    nu_ei = []
    temp_ei = torch.tensor([[0, 2, 1, 2], [2, 0, 2, 1]], dtype=torch.long)
    for b in range(self.batch_size):
        # +b*3 because each second-level graph has exactly 3 nodes.
        nu_ei.append(temp_ei + b * 3)
    nu_Batch.edge_index = torch.cat(nu_ei, dim=1)

    # Edge features from the first-level "u"s (graph-level y values):
    # ratios of molecule-to-membrane percentages.
    p1_div_pm = u_p1 / u_pm
    p2_div_pm = u_p2 / u_pm
    #p1_div_p2 = u_p1 / (u_p2-1)

    # Pair each ratio with the temperature: shape (batch_size, 2) after
    # the transpose -- assumes Temperature and the ratios are 1-D of
    # length batch_size, TODO confirm.
    concat_T_p1pm = torch.transpose(torch.stack([Temperature, p1_div_pm]), 0, 1)
    concat_T_p2pm = torch.transpose(torch.stack([Temperature, p2_div_pm]), 0, 1)

    # Encode the second-level edge features.
    concat_T_p1pm = self.encoding_edge_2(concat_T_p1pm)
    concat_T_p2pm = self.encoding_edge_2(concat_T_p2pm)

    nu_ea = []
    for b in range(self.batch_size):
        temp_ea = []
        # Each feature is appended twice because the edges are bidirectional
        # (same attribute for both directions).
        temp_ea.append(concat_T_p1pm[b])
        temp_ea.append(concat_T_p1pm[b])
        temp_ea.append(concat_T_p2pm[b])
        temp_ea.append(concat_T_p2pm[b])
        nu_ea.append(torch.stack(temp_ea, dim=0))
    nu_Batch.edge_attr = torch.cat(nu_ea, dim=0)

    # Node features of the second-level graphs: the encoded first-level
    # global features, in the fixed order (p1, p2, pm).
    nu_x = []
    for b in range(self.batch_size):
        temp_x = []
        temp_x.append(u1[b])
        temp_x.append(u2[b])
        temp_x.append(um[b])
        nu_x.append(torch.stack(temp_x, dim=0))
    nu_Batch.x = torch.cat(nu_x, dim=0)

    # Second-level global features, initialised to 0.1 (moved to device
    # together with the rest of the Batch below).
    gamma = torch.full(size=(self.batch_size, NO_GRAPH_FEATURES_TWO), fill_value=0.1, dtype=torch.float)
    nu_Batch.y = gamma

    # Batch membership vector: 3 nodes per sample.
    nu_btc = []
    for b in range(self.batch_size):
        nu_btc.append(torch.full(size=[3], fill_value=(int(b)), dtype=torch.long))
    nu_Batch.batch = torch.cat(nu_btc, dim=0)

    nu_Batch = nu_Batch.to(device)

    # --- MESSAGE PASSING, LEVEL 2 ---
    for _ in range(self.no_mp_two):
        nu_Batch.x, nu_Batch.edge_attr, nu_Batch.y = self.meta3(
            x=nu_Batch.x, edge_index=nu_Batch.edge_index,
            edge_attr=nu_Batch.edge_attr, u=nu_Batch.y,
            batch=nu_Batch.batch)

    # Final read-out from the second-level global features.
    c = self.mlp_last(nu_Batch.y)

    return c