Example #1
import torch
from torch_geometric.data import Batch


def data_loader(database, latent_graph_dict, latent_graph_list, batch_size=4, shuffle=True):
    # -- This loader creates data batches from a database; run it on either the training or the validation set.

    # Establish empty lists
    graph_list_p1 = []
    graph_list_p2 = []
    graph_list_pm = []
    target_list = []
    Temperature_list = []

    # Iterate through the database and look up the latent graphs of the molecules, as well as the target tensors
    for itm in range(len(database)):
        graph_p1 = latent_graph_list[latent_graph_dict[database["Molecule1"].loc[itm]]]
        graph_p2 = latent_graph_list[latent_graph_dict[database["Molecule2"].loc[itm]]]
        graph_pm = latent_graph_list[latent_graph_dict[database["Membrane"].loc[itm]]]

        graph_p1.y = torch.tensor([float(database["molp1"].loc[itm])])
        graph_p2.y = torch.tensor([float(database["molp2"].loc[itm])])
        graph_pm.y = torch.tensor([float(database["molpm"].loc[itm])])

        graph_list_p1.append(graph_p1)
        graph_list_p2.append(graph_p2)
        graph_list_pm.append(graph_pm)

        target = torch.tensor([database[str(i)].loc[itm]*10 for i in range(2, 17)], dtype=torch.float)
        target_list.append(target)

        Temperature_list.append(database["Temperature"].loc[itm])

    # Generate a shuffled integer list that then produces distinct batches
    idxs = torch.randperm(len(database)).tolist()
    if not shuffle:
        idxs = sorted(idxs)

    # Generate the batches
    batch_p1 = []
    batch_p2 = []
    batch_pm = []
    batch_target = []
    batch_Temperature = []

    for b in range(0, len(database), batch_size):
        # Creating empty Batch objects
        B_p1 = Batch()
        B_p2 = Batch()
        B_pm = Batch()

        # Creating empty lists that will be concatenated later (advanced minibatching)
        stack_x_p1, stack_x_p2, stack_x_pm = [], [], []
        stack_ei_p1, stack_ei_p2, stack_ei_pm = [], [], []
        stack_ea_p1, stack_ea_p2, stack_ea_pm = [], [], []
        stack_y_p1, stack_y_p2, stack_y_pm = [], [], []
        stack_btc_p1, stack_btc_p2, stack_btc_pm = [], [], []

        stack_target = []
        stack_Temperature = []

        # Creating the batch shifts that offset the edge indices
        b_shift_p1 = 0
        b_shift_p2 = 0
        b_shift_pm = 0
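        # Example of the shift: if the first graph in the batch has 5 nodes, the second
        # graph's edge_index is offset by 5 so its edges point at rows 5.. of the
        # concatenated node feature matrix.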

        for i in range(batch_size):
            if (b+i) < len(database):
                stack_x_p1.append(graph_list_p1[idxs[b + i]].x)
                stack_x_p2.append(graph_list_p2[idxs[b + i]].x)
                stack_x_pm.append(graph_list_pm[idxs[b + i]].x)

                stack_ei_p1.append(graph_list_p1[idxs[b + i]].edge_index+b_shift_p1)
                stack_ei_p2.append(graph_list_p2[idxs[b + i]].edge_index+b_shift_p2)
                stack_ei_pm.append(graph_list_pm[idxs[b + i]].edge_index+b_shift_pm)

                # Updating the shifts
                b_shift_p1 += graph_list_p1[idxs[b + i]].x.size()[0]
                b_shift_p2 += graph_list_p2[idxs[b + i]].x.size()[0]
                b_shift_pm += graph_list_pm[idxs[b + i]].x.size()[0]

                stack_ea_p1.append(graph_list_p1[idxs[b + i]].edge_attr)
                stack_ea_p2.append(graph_list_p2[idxs[b + i]].edge_attr)
                stack_ea_pm.append(graph_list_pm[idxs[b + i]].edge_attr)

                stack_y_p1.append(graph_list_p1[idxs[b + i]].y)
                stack_y_p2.append(graph_list_p2[idxs[b + i]].y)
                stack_y_pm.append(graph_list_pm[idxs[b + i]].y)

                # The batch vector maps every node of the i-th graph to entry i of the mini-batch
                stack_btc_p1.append(
                    torch.full([graph_list_p1[idxs[b + i]].x.size()[0]], fill_value=int(i), dtype=torch.long))
                stack_btc_p2.append(
                    torch.full([graph_list_p2[idxs[b + i]].x.size()[0]], fill_value=int(i), dtype=torch.long))
                stack_btc_pm.append(
                    torch.full([graph_list_pm[idxs[b + i]].x.size()[0]], fill_value=int(i), dtype=torch.long))

                stack_target.append(target_list[idxs[b + i]])
                # Index through idxs so the temperature stays aligned with the shuffled graphs
                stack_Temperature.append(Temperature_list[idxs[b + i]])

        B_p1.edge_index = torch.cat(stack_ei_p1, dim=1)
        B_p1.x = torch.cat(stack_x_p1, dim=0)
        B_p1.edge_attr = torch.cat(stack_ea_p1, dim=0)
        B_p1.y = torch.cat(stack_y_p1, dim=0)
        B_p1.batch = torch.cat(stack_btc_p1, dim=0)

        B_p2.edge_index = torch.cat(stack_ei_p2, dim=1)
        B_p2.x = torch.cat(stack_x_p2, dim=0)
        B_p2.edge_attr = torch.cat(stack_ea_p2, dim=0)
        B_p2.y = torch.cat(stack_y_p2, dim=0)
        B_p2.batch = torch.cat(stack_btc_p2, dim=0)

        B_pm.edge_index = torch.cat(stack_ei_pm, dim=1)
        B_pm.x = torch.cat(stack_x_pm, dim=0)
        B_pm.edge_attr = torch.cat(stack_ea_pm, dim=0)
        B_pm.y = torch.cat(stack_y_pm, dim=0)
        B_pm.batch = torch.cat(stack_btc_pm, dim=0)

        B_target = torch.stack(stack_target, dim=0)

        B_Temperature = torch.Tensor(stack_Temperature)

        # Appending batches
        batch_p1.append(B_p1)
        batch_p2.append(B_p2)
        batch_pm.append(B_pm)
        batch_target.append(B_target)
        batch_Temperature.append(B_Temperature)

    # Return
    return [batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature]
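
A minimal usage sketch, assuming database is a pandas DataFrame with the columns referenced above and that latent_graph_dict / latent_graph_list map molecule names to precomputed torch_geometric Data objects; the surrounding setup is not part of the source:

batches = data_loader(database, latent_graph_dict, latent_graph_list, batch_size=4, shuffle=True)
batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature = batches

# Iterate over the pre-built mini-batches
for B_p1, B_p2, B_pm, B_target, B_T in zip(*batches):
    print(B_p1.x.shape, B_target.shape, B_T.shape)
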
Example #2
    def forward(self, plist):
        # Here we instantiate the three molecular graphs for the first level GNN
        p1, p2, pm, Temperature = plist
        x_p1, ei_p1, ea_p1, u_p1, btc_p1 = p1.x, p1.edge_index, p1.edge_attr, p1.y, p1.batch
        x_p2, ei_p2, ea_p2, u_p2, btc_p2 = p2.x, p2.edge_index, p2.edge_attr, p2.y, p2.batch
        x_pm, ei_pm, ea_pm, u_pm, btc_pm = pm.x, pm.edge_index, pm.edge_attr, pm.y, pm.batch
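        # Note: u_p1, u_p2 and u_pm are the .y attributes set in the data loader above,
        # i.e. the molecular percentages; they are reused below as second-level edge features.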

        # Embed the node and edge features
        enc_x_p1 = self.encoding_node_1(x_p1)
        enc_x_p2 = self.encoding_node_1(x_p2)
        enc_x_pm = self.encoding_node_1(x_pm)

        enc_ea_p1 = self.encoding_edge_1(ea_p1)
        enc_ea_p2 = self.encoding_edge_1(ea_p2)
        enc_ea_pm = self.encoding_edge_1(ea_pm)

        # Create the initial graph-level feature vectors for the first-level molecular graphs
        u1 = torch.full(size=(self.batch_size, NO_GRAPH_FEATURES_ONE),
                        fill_value=0.1,
                        dtype=torch.float).to(device)
        u2 = torch.full(size=(self.batch_size, NO_GRAPH_FEATURES_ONE),
                        fill_value=0.1,
                        dtype=torch.float).to(device)
        um = torch.full(size=(self.batch_size, NO_GRAPH_FEATURES_ONE),
                        fill_value=0.1,
                        dtype=torch.float).to(device)

        # Now the first-level rounds of message passing are performed and the molecular features are constructed
        for _ in range(self.no_mp_one):
            enc_x_p1, enc_ea_p1, u1 = self.meta1(x=enc_x_p1,
                                                 edge_index=ei_p1,
                                                 edge_attr=enc_ea_p1,
                                                 u=u1,
                                                 batch=btc_p1)
            enc_x_p2, enc_ea_p2, u2 = self.meta1(x=enc_x_p2,
                                                 edge_index=ei_p2,
                                                 edge_attr=enc_ea_p2,
                                                 u=u2,
                                                 batch=btc_p2)
            enc_x_pm, enc_ea_pm, um = self.meta1(x=enc_x_pm,
                                                 edge_index=ei_pm,
                                                 edge_attr=enc_ea_pm,
                                                 u=um,
                                                 batch=btc_pm)
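        # The same meta1 block (shared weights) is applied to all three molecular graphs.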

        # --- GRAPH GENERATION SECOND LEVEL

        # Encode the nodes of the second level
        u1 = self.encoding_node_2(u1)
        u2 = self.encoding_node_2(u2)
        um = self.encoding_node_2(um)

        # Instantiate new Batch object for second level graph
        nu_Batch = Batch()

        # Create edge indices for the second level (no connection between p1 and p2)
        nu_ei = []
        temp_ei = torch.tensor([[0, 2, 1, 2], [2, 0, 2, 1]], dtype=torch.long)
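        # Node 0 is molecule p1, node 1 is molecule p2, node 2 is the membrane pm;
        # temp_ei connects each molecule to the membrane with bidirectional edges.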
        for b in range(self.batch_size):
            nu_ei.append(
                temp_ei + b * 3
            )  # offset by b * 3 because each second-level graph has 3 nodes
        nu_Batch.edge_index = torch.cat(nu_ei, dim=1)

        # Create the edge features for the second-level graphs from the first-level u's (the molecular percentages stored in .y)
        p1_div_pm = u_p1 / u_pm
        p2_div_pm = u_p2 / u_pm
        #p1_div_p2 = u_p1 / (u_p2-1)

        # Concatenate the temperature and molecular percentages
        concat_T_p1pm = torch.transpose(torch.stack([Temperature, p1_div_pm]),
                                        0, 1)
        concat_T_p2pm = torch.transpose(torch.stack([Temperature, p2_div_pm]),
                                        0, 1)

        # Encode the edge features
        concat_T_p1pm = self.encoding_edge_2(concat_T_p1pm)
        concat_T_p2pm = self.encoding_edge_2(concat_T_p2pm)

        nu_ea = []
        for b in range(self.batch_size):
            temp_ea = []
            # Appending twice because of bidirectional edges
            temp_ea.append(concat_T_p1pm[b])
            temp_ea.append(concat_T_p1pm[b])
            temp_ea.append(concat_T_p2pm[b])
            temp_ea.append(concat_T_p2pm[b])
            nu_ea.append(torch.stack(temp_ea, dim=0))
        nu_Batch.edge_attr = torch.cat(nu_ea, dim=0)

        # Create new nodes in the batch
        nu_x = []
        for b in range(self.batch_size):
            temp_x = []
            temp_x.append(u1[b])
            temp_x.append(u2[b])
            temp_x.append(um[b])
            nu_x.append(torch.stack(temp_x, dim=0))
        nu_Batch.x = torch.cat(nu_x, dim=0)
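        # Stacking in the order (u1, u2, um) makes nodes 3b, 3b+1, 3b+2 line up with the
        # edge indices defined above (0 = p1, 1 = p2, 2 = pm).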

        # Create new graph level target
        gamma = torch.full(size=(self.batch_size, NO_GRAPH_FEATURES_TWO),
                           fill_value=0.1,
                           dtype=torch.float)
        nu_Batch.y = gamma
        # Create new batch
        nu_btc = []
        for b in range(self.batch_size):
            nu_btc.append(
                torch.full(size=[3], fill_value=(int(b)), dtype=torch.long))
        nu_Batch.batch = torch.cat(nu_btc, dim=0)

        nu_Batch = nu_Batch.to(device)
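        # nu_Batch now holds batch_size second-level graphs, each with three nodes
        # (p1, p2, membrane) and four directed edges.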

        # --- MESSAGE PASSING LVL 2
        for _ in range(self.no_mp_two):
            nu_Batch.x, nu_Batch.edge_attr, nu_Batch.y = self.meta3(
                x=nu_Batch.x,
                edge_index=nu_Batch.edge_index,
                edge_attr=nu_Batch.edge_attr,
                u=nu_Batch.y,
                batch=nu_Batch.batch)

        c = self.mlp_last(nu_Batch.y)

        return c
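
A hedged training-loop sketch showing how the forward pass above might be driven with the batches from Example #1's data_loader. The class name SolubilityGNN, the optimizer, and the loss are assumptions (not part of the source), and every mini-batch is assumed to hold exactly batch_size samples, since forward() relies on self.batch_size.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SolubilityGNN().to(device)  # hypothetical module wrapping the forward() above
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature = data_loader(
    database, latent_graph_dict, latent_graph_list, batch_size=4, shuffle=True)

for B_p1, B_p2, B_pm, B_target, B_T in zip(batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature):
    optimizer.zero_grad()
    pred = model([B_p1.to(device), B_p2.to(device), B_pm.to(device), B_T.to(device)])
    loss = loss_fn(pred, B_target.to(device))  # assumes mlp_last outputs one value per target entry
    loss.backward()
    optimizer.step()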