def compute_grad_pen(self, expert_state, expert_action, policy_state, policy_action, lambda_=10):
    """Gradient penalty on a mixup of expert and policy (graph, action) pairs.

    Expert graph ``i`` and policy graph ``i`` are merged into one graph (both
    keep batch id ``i``).

    * Node-level tensors are concatenated as ``[all expert nodes, all policy
      nodes]``.
    * Policy edge indices are shifted past the expert nodes.
    * Edge attributes, when both inputs have them, follow the same
      ``[expert edges, policy edges]`` order.

    The per-node action "votes" are ``alpha_i`` on the chosen expert node and
    ``1 - alpha_i`` on the chosen policy node of graph ``i``, with random
    ``alpha_i``. They use the same node order as the merged state.

    Args:
        expert_state: Batched expert graphs, iterable as ``(key, value)``
            pairs. Must provide ``batch`` and ``edge_index``.
        expert_action: One action per expert graph. Presumably the index of
            the chosen node *within* its own graph (TODO confirm with caller).
        policy_state: Batched policy graphs, same layout as ``expert_state``.
        policy_action: One action per policy graph, same convention.
        lambda_: Weight of the penalty.

    Returns:
        Scalar tensor ``lambda_ * mean((||d disc / d action||_2 - 1) ** 2)``.
    """
    # Merge node-level tensors: all expert nodes first, then all policy nodes.
    mixup_state = Batch()
    for key, _ in expert_state:
        assert isinstance(key, str), str(key)
        if key in ("edge_index", "edge_attr"):
            continue
        mixup_state[key] = torch.cat([expert_state[key], policy_state[key]])

    num_expert_nodes = expert_state.batch.shape[0]
    mixup_state.edge_index = torch.cat(
        [expert_state.edge_index, policy_state.edge_index + num_expert_nodes],
        dim=1,
    )
    # Keep edge attributes aligned with the merged edge_index instead of
    # silently dropping them.
    expert_edge_attr = getattr(expert_state, "edge_attr", None)
    policy_edge_attr = getattr(policy_state, "edge_attr", None)
    if expert_edge_attr is not None and policy_edge_attr is not None:
        mixup_state.edge_attr = torch.cat([expert_edge_attr, policy_edge_attr], dim=0)

    device = expert_state.batch.device
    alpha = torch.rand(expert_action.size(0)).to(device)
    batch_size = expert_state.batch.max().item() + 1

    # One vote per merged node, in the SAME order as mixup_state's nodes.
    num_policy_nodes = policy_state.batch.shape[0]
    mixup_action = torch.zeros(num_expert_nodes + num_policy_nodes, device=device)
    for i in range(batch_size):
        expert_nodes = (expert_state.batch == i).nonzero(as_tuple=True)[0]
        policy_nodes = (policy_state.batch == i).nonzero(as_tuple=True)[0]
        assert expert_nodes.numel() + policy_nodes.numel()
        # Translate the per-graph action index into a merged-node position.
        mixup_action[expert_nodes[expert_action[i]]] = alpha[i]
        mixup_action[num_expert_nodes + policy_nodes[policy_action[i]]] = 1 - alpha[i]
    mixup_action = mixup_action.view(-1, 1)
    mixup_action.requires_grad = True

    disc = self.forward(mixup_state, mixup_action)
    ones = torch.ones(disc.size()).to(disc.device)

    # Only the gradient w.r.t. the mixed-up action is penalised.
    grad = autograd.grad(
        outputs=disc,
        inputs=[mixup_action],
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
    )[0]
    grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean()
    return grad_pen
def forward(self, data, **kwargs):
    """Down-sample ``data`` and convolve features onto the kept points.

    The layer works in three steps:

    * ``self.sampler`` chooses which points to keep.
    * ``self.neighbour_finder`` links every kept point to neighbours in the
      full cloud.
    * ``self.conv`` computes the features of the kept points.

    Extra keyword arguments are accepted and ignored.

    Returns:
        A fresh ``Batch`` holding ``idx``, ``edge_index``, ``x``, ``pos`` and
        ``batch`` for the kept points, plus whatever ``copy_from_to`` copies
        over from ``data``.
    """
    feats, pos, batch = data.x, data.pos, data.batch

    keep = self.sampler(pos, batch)
    kept_pos = pos[keep]
    kept_batch = batch[keep]

    row, col = self.neighbour_finder(pos, kept_pos, batch_x=batch, batch_y=kept_batch)
    # Presumably col indexes the full cloud and row the kept points
    # (PyG radius/knn convention) — confirm against neighbour_finder.
    edges = torch.stack([col, row], dim=0)

    result = Batch()
    result.idx = keep
    result.edge_index = edges
    # NOTE(review): positions are passed as (kept, full). With
    # edge_index = [col, row] the usual PyG convention is (full, kept) —
    # confirm which order self.conv expects.
    result.x = self.conv(feats, (kept_pos, pos), edges, batch)
    result.pos = kept_pos
    result.batch = kept_batch

    copy_from_to(data, result)
    return result
def __collate__(self, n_id):
    """Collate sampled root node ids into a batch of k-hop ego subgraphs.

    Args:
        n_id: Root node indices for this mini-batch (anything accepted by
            ``torch.tensor``).

    Returns:
        A ``Batch`` with the following contents:

        * ``batch`` / ``ptr`` assign every sampled node to its subgraph.
        * ``root_n_id`` indexes the roots among the sampled nodes.
        * Connectivity is stored as ``adj_t`` (when ``self.is_sparse_tensor``)
          or as ``edge_index``.
        * Every other attribute of ``self.data`` is sliced per sampled node
          or per sampled edge. ``y`` is kept only for the roots. Attributes
          of any other size are copied unchanged.
    """
    n_id = torch.tensor(n_id)

    rowptr, col, value = self.adj_t.csr()
    out = torch.ops.torch_sparse.ego_k_hop_sample_adj(
        rowptr, col, n_id, self.depth, self.num_neighbors, self.replace)
    rowptr, col, n_id, e_id, ptr, root_n_id = out

    adj_t = SparseTensor(rowptr=rowptr, col=col,
                         value=value[e_id] if value is not None else None,
                         sparse_sizes=(n_id.numel(), n_id.numel()),
                         is_sorted=True, trust_data=True)

    batch = Batch(batch=torch.ops.torch_sparse.ptr2ind(ptr, n_id.numel()),
                  ptr=ptr)
    batch.root_n_id = root_n_id

    if self.is_sparse_tensor:
        batch.adj_t = adj_t
    else:
        # e_id is re-read from the transposed adjacency's values. This
        # presumably relies on self.adj_t storing edge ids as values in this
        # mode — confirm against the constructor.
        row, col, e_id = adj_t.t().coo()
        batch.edge_index = torch.stack([row, col], dim=0)

    num_nodes = self.data.num_nodes
    num_edges = self.data.num_edges
    for k, v in self.data:
        # 'batch' and 'ptr' are skipped so they cannot overwrite the subgraph
        # assignment computed above.
        if k in ['edge_index', 'adj_t', 'num_nodes', 'batch', 'ptr']:
            continue
        # Check the type first so non-tensor attributes (including a
        # non-tensor 'y') fall through to the copy branch.
        if isinstance(v, Tensor) and v.size(0) == num_nodes:
            # Node-level targets are only kept for the root nodes.
            batch[k] = v[n_id][root_n_id] if k == 'y' else v[n_id]
        elif isinstance(v, Tensor) and v.size(0) == num_edges:
            batch[k] = v[e_id]
        else:
            batch[k] = v

    return batch
def forward(self, data):
    """Down-sample ``data`` and convolve features onto the sampled points.

    When ``self._precompute_multi_scale`` is set, the sampled indices and the
    neighbourhood graph are read from the attributes
    ``index_<self._index>`` and ``edge_index_<self._index>`` of ``data``.
    Otherwise they are computed with ``self.sampler`` and
    ``self.neighbour_finder``.

    Returns:
        A fresh ``Batch`` holding ``idx``, ``edge_index``, ``x``, ``pos`` and
        ``batch`` for the sampled points, plus whatever ``copy_from_to``
        copies over from ``data``.

    Raises:
        ValueError: if multi-scale precomputation is enabled but
            ``index_<self._index>`` is missing (or None) on ``data``.
    """
    batch_obj = Batch()
    x, pos, batch = data.x, data.pos, data.batch
    if self._precompute_multi_scale:
        idx_name = "index_{}".format(self._index)
        idx = getattr(data, idx_name, None)
        edge_index = getattr(data, "edge_index_{}".format(self._index), None)
        # Without this check a missing index turns pos[idx] into pos[None]
        # (an unsqueezed copy of pos) and fails obscurely further down.
        if idx is None:
            raise ValueError(
                "Multi-scale precomputation is enabled but data has no "
                "'{}' attribute".format(idx_name))
    else:
        idx = self.sampler(pos, batch)
        row, col = self.neighbour_finder(pos, pos[idx], batch_x=batch, batch_y=batch[idx])
        edge_index = torch.stack([col, row], dim=0)

    batch_obj.idx = idx
    batch_obj.edge_index = edge_index
    batch_obj.x = self.conv(x, (pos, pos[idx]), edge_index, batch)
    batch_obj.pos = pos[idx]
    batch_obj.batch = batch[idx]
    copy_from_to(data, batch_obj)
    return batch_obj
def _collate_graphs(graphs, ys):
    """Merge ``graphs`` into one disjoint-union ``Batch`` carrying ``ys`` as ``y``.

    The merged batch is built as follows:

    * Node features and edge attributes are concatenated.
    * Each graph's ``edge_index`` is shifted by the number of nodes that
      precede it.
    * ``y`` is the concatenation of ``ys``.
    * ``batch`` maps every node to the position of its graph in ``graphs``.
    """
    xs, edge_indices, edge_attrs, batch_vectors = [], [], [], []
    node_offset = 0
    for graph_pos, graph in enumerate(graphs):
        num_nodes = graph.x.size(0)
        xs.append(graph.x)
        edge_indices.append(graph.edge_index + node_offset)
        edge_attrs.append(graph.edge_attr)
        batch_vectors.append(
            torch.full([num_nodes], fill_value=graph_pos, dtype=torch.long))
        node_offset += num_nodes

    merged = Batch()
    merged.edge_index = torch.cat(edge_indices, dim=1)
    merged.x = torch.cat(xs, dim=0)
    merged.edge_attr = torch.cat(edge_attrs, dim=0)
    merged.y = torch.cat(ys, dim=0)
    merged.batch = torch.cat(batch_vectors, dim=0)
    return merged


def data_loader(database, latent_graph_dict, latent_graph_list, batch_size=4, shuffle=True):
    """Build mini-batches of (molecule 1, molecule 2, membrane) graph triples.

    Each row of ``database`` names two molecules and a membrane in the
    columns ``Molecule1``, ``Molecule2`` and ``Membrane``. Their graphs are
    looked up as ``latent_graph_list[latent_graph_dict[name]]``. For each row:

    * The graph-level ``y`` values come from ``molp1`` / ``molp2`` /
      ``molpm``.
    * The target is columns ``"2"`` .. ``"16"`` scaled by 10.
    * The temperature comes from ``Temperature``.

    Rows are addressed by position, so ``database`` need not have a default
    0..n-1 index. The shared latent graphs are NOT modified. Per-row ``y``
    values are attached only to the returned batches, so a molecule that
    appears in several rows keeps the correct value in each of them.

    Args:
        database: Table with the columns above (presumably a pandas
            DataFrame — training or validation split).
        latent_graph_dict: Maps a molecule/membrane name to an index into
            ``latent_graph_list``.
        latent_graph_list: Graphs providing ``x``, ``edge_index`` and
            ``edge_attr``.
        batch_size: Rows per mini-batch. The last batch may be smaller.
        shuffle: If true, rows are visited in a random order drawn with
            ``torch.randperm``. Otherwise they are visited in table order.

    Returns:
        ``[batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature]``:
        five lists with one entry per mini-batch.
    """
    num_rows = len(database)

    graphs_p1, graphs_p2, graphs_pm = [], [], []
    ys_p1, ys_p2, ys_pm = [], [], []
    targets, temperatures = [], []
    for pos in range(num_rows):
        graphs_p1.append(latent_graph_list[latent_graph_dict[database["Molecule1"].iloc[pos]]])
        graphs_p2.append(latent_graph_list[latent_graph_dict[database["Molecule2"].iloc[pos]]])
        graphs_pm.append(latent_graph_list[latent_graph_dict[database["Membrane"].iloc[pos]]])
        ys_p1.append(torch.tensor([float(database["molp1"].iloc[pos])]))
        ys_p2.append(torch.tensor([float(database["molp2"].iloc[pos])]))
        ys_pm.append(torch.tensor([float(database["molpm"].iloc[pos])]))
        targets.append(torch.tensor(
            [database[str(i)].iloc[pos] * 10 for i in range(2, 17)], dtype=torch.float))
        temperatures.append(database["Temperature"].iloc[pos])

    # Row visiting order. Every per-row field below is indexed through this
    # same order so the fields stay aligned.
    order = torch.randperm(num_rows).tolist() if shuffle else list(range(num_rows))

    batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature = [], [], [], [], []
    for start in range(0, num_rows, batch_size):
        chunk = order[start:start + batch_size]
        batch_p1.append(_collate_graphs([graphs_p1[k] for k in chunk], [ys_p1[k] for k in chunk]))
        batch_p2.append(_collate_graphs([graphs_p2[k] for k in chunk], [ys_p2[k] for k in chunk]))
        batch_pm.append(_collate_graphs([graphs_pm[k] for k in chunk], [ys_pm[k] for k in chunk]))
        batch_target.append(torch.stack([targets[k] for k in chunk], dim=0))
        batch_Temperature.append(torch.Tensor([temperatures[k] for k in chunk]))

    return [batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature]
def forward(self, plist):
    """Two-level GNN forward pass over (molecule 1, molecule 2, membrane) triples.

    Level one encodes each molecular graph. It then runs ``self.meta1`` for
    ``self.no_mp_one`` rounds, giving one graph-level vector per molecule
    (``u1``, ``u2``, ``um``).

    Level two builds one 3-node graph per sample:

    * Nodes are ``(u1, u2, um)`` with local ids 0, 1, 2.
    * Bidirectional edges link each molecule to the membrane (0<->2 and
      1<->2). There is no 0<->1 edge.
    * Edge features are the encoded ``(Temperature, y_molecule / y_membrane)``
      pairs.

    ``self.meta3`` runs ``self.no_mp_two`` rounds on these graphs, and
    ``self.mlp_last`` maps the resulting graph-level vectors to the output.

    Args:
        plist: ``[p1, p2, pm, Temperature]``. The first three are batched
            graphs with ``x``, ``edge_index``, ``edge_attr``, ``y`` (one value
            per graph) and ``batch``. ``Temperature`` is presumably a 1-D
            tensor with one entry per sample.

    Returns:
        Output of ``self.mlp_last`` applied to the second-level graph
        features, one row per sample.
    """
    p1, p2, pm, Temperature = plist
    # Take the number of samples from the data rather than self.batch_size,
    # so a smaller final mini-batch works too.
    num_graphs = Temperature.size(0)

    x_p1, ei_p1, ea_p1, u_p1, btc_p1 = p1.x, p1.edge_index, p1.edge_attr, p1.y, p1.batch
    x_p2, ei_p2, ea_p2, u_p2, btc_p2 = p2.x, p2.edge_index, p2.edge_attr, p2.y, p2.batch
    x_pm, ei_pm, ea_pm, u_pm, btc_pm = pm.x, pm.edge_index, pm.edge_attr, pm.y, pm.batch

    # --- FIRST LEVEL: embed node and edge features of every molecular graph.
    enc_x_p1 = self.encoding_node_1(x_p1)
    enc_x_p2 = self.encoding_node_1(x_p2)
    enc_x_pm = self.encoding_node_1(x_pm)
    enc_ea_p1 = self.encoding_edge_1(ea_p1)
    enc_ea_p2 = self.encoding_edge_1(ea_p2)
    enc_ea_pm = self.encoding_edge_1(ea_pm)

    # Graph-level feature vectors, initialised to a constant 0.1.
    u1 = torch.full(size=(num_graphs, NO_GRAPH_FEATURES_ONE), fill_value=0.1, dtype=torch.float).to(device)
    u2 = torch.full(size=(num_graphs, NO_GRAPH_FEATURES_ONE), fill_value=0.1, dtype=torch.float).to(device)
    um = torch.full(size=(num_graphs, NO_GRAPH_FEATURES_ONE), fill_value=0.1, dtype=torch.float).to(device)

    # First-level message passing builds the per-molecule features.
    for _ in range(self.no_mp_one):
        enc_x_p1, enc_ea_p1, u1 = self.meta1(x=enc_x_p1, edge_index=ei_p1, edge_attr=enc_ea_p1, u=u1, batch=btc_p1)
        enc_x_p2, enc_ea_p2, u2 = self.meta1(x=enc_x_p2, edge_index=ei_p2, edge_attr=enc_ea_p2, u=u2, batch=btc_p2)
        enc_x_pm, enc_ea_pm, um = self.meta1(x=enc_x_pm, edge_index=ei_pm, edge_attr=enc_ea_pm, u=um, batch=btc_pm)

    # --- SECOND LEVEL: one 3-node graph (p1, p2, pm) per sample.
    u1 = self.encoding_node_2(u1)
    u2 = self.encoding_node_2(u2)
    um = self.encoding_node_2(um)

    nu_Batch = Batch()

    # Edges 0->2, 2->0, 1->2, 2->1 (no p1-p2 connection). Each sample's
    # copy is shifted by 3, the node count of one second-level graph.
    temp_ei = torch.tensor([[0, 2, 1, 2], [2, 0, 2, 1]], dtype=torch.long)
    nu_Batch.edge_index = torch.cat([temp_ei + 3 * b for b in range(num_graphs)], dim=1)

    # Edge features: (temperature, molecule y / membrane y), then encoded.
    p1_div_pm = u_p1 / u_pm
    p2_div_pm = u_p2 / u_pm
    concat_T_p1pm = self.encoding_edge_2(torch.transpose(torch.stack([Temperature, p1_div_pm]), 0, 1))
    concat_T_p2pm = self.encoding_edge_2(torch.transpose(torch.stack([Temperature, p2_div_pm]), 0, 1))
    # Per sample [p1pm, p1pm, p2pm, p2pm]: each feature twice, matching the
    # bidirectional edge order above.
    nu_Batch.edge_attr = torch.stack(
        [concat_T_p1pm, concat_T_p1pm, concat_T_p2pm, concat_T_p2pm], dim=1
    ).flatten(0, 1)

    # Nodes per sample in the order (u1, u2, um).
    nu_Batch.x = torch.stack([u1, u2, um], dim=1).flatten(0, 1)

    # Second-level graph features, initialised to a constant 0.1.
    nu_Batch.y = torch.full(size=(num_graphs, NO_GRAPH_FEATURES_TWO), fill_value=0.1, dtype=torch.float)

    # Every second-level graph owns exactly 3 consecutive nodes.
    nu_Batch.batch = torch.arange(num_graphs, dtype=torch.long).repeat_interleave(3)
    nu_Batch = nu_Batch.to(device)

    # --- Second-level message passing.
    for _ in range(self.no_mp_two):
        nu_Batch.x, nu_Batch.edge_attr, nu_Batch.y = self.meta3(
            x=nu_Batch.x, edge_index=nu_Batch.edge_index, edge_attr=nu_Batch.edge_attr,
            u=nu_Batch.y, batch=nu_Batch.batch)

    c = self.mlp_last(nu_Batch.y)
    return c
def forward(self, primal_graph_batch, dual_graph_batch, pooling_log):
    r"""Invert a dual-primal edge-pooling step.

    Args:
        primal_graph_batch (torch_geometric.data.batch.Batch): Pooled primal
            graphs to unpool.
        dual_graph_batch (torch_geometric.data.batch.Batch): Pooled dual
            graphs to unpool.
        pooling_log (pd_mesh_net.nn.pool.PoolingInfo): Information recorded
            while pooling that is required to undo it.

    Returns:
        new_primal_graph_batch (torch_geometric.data.batch.Batch): Primal
            graphs after unpooling.
        new_dual_graph_batch (torch_geometric.data.batch.Batch): Dual graphs
            after unpooling.
        new_primal_edge_to_dual_node_idx_batch (dict): Association between
            primal-graph edges and dual-graph nodes after unpooling, taken
            unchanged from ``pooling_log``.
    """
    # Primal graph: restore the pre-pooling connectivity. Each old primal
    # node takes the feature of the pooled node it was merged into.
    primal_map = pooling_log.old_primal_node_to_new_one
    primal_out = Batch(batch=pooling_log.old_primal_graph_batch)
    primal_out.edge_index = pooling_log.old_primal_edge_index
    primal_out.x = primal_graph_batch.x[primal_map]

    # Dual graph: the old->new node mapping is required.
    dual_map = pooling_log.old_dual_node_to_new_one
    assert dual_map is not None, (
        "The input pooling log does not contain the mapping from dual "
        "nodes before pooling to dual nodes after pooling. Please set the "
        "argument `return_old_dual_node_to_new_dual_node` to True in the "
        "`DualPrimalEdgePooling` layer.")
    dual_out = Batch(batch=pooling_log.old_dual_graph_batch)
    dual_out.edge_index = pooling_log.old_dual_edge_index

    # Every old dual node starts with the shared learnable feature. Nodes
    # that have a pooled counterpart (map entry != -1) then copy that
    # counterpart's feature.
    dual_feats = self.new_dual_feature.repeat(len(dual_map), 1)
    has_counterpart = dual_map != -1
    dual_feats[has_counterpart] = dual_graph_batch.x[dual_map[has_counterpart]]
    dual_out.x = dual_feats

    edge_to_dual_node = pooling_log.old_primal_edge_to_dual_node_index
    return (primal_out, dual_out, edge_to_dual_node)