def from_data_list_token(data_list, follow_batch=None):
    """Collate ``data_list`` into one ``Batch``, offsetting only non-negative entries.

    This is close to a copy of PyTorch Geometric's ``Batch.from_data_list``.
    The difference is that, for every non-boolean tensor attribute, only the
    entries that are ``>= 0`` are shifted by the running offset of their key;
    negative entries are left untouched.

    The input ``Data`` objects are not modified: each tensor is cloned before
    the masked offset is written, so collating the same objects repeatedly
    gives the same result every time.

    Args:
        data_list: non-empty list of data objects exposing ``keys``,
            ``num_nodes``, ``__cat_dim__`` and ``__inc__``.
        follow_batch: optional iterable of keys for which an extra
            ``"<key>_batch"`` assignment vector is created.

    Returns:
        The collated ``Batch`` (made contiguous).

    Raises:
        ValueError: if an attribute is neither a tensor, an int nor a float.
    """
    follow_batch = [] if follow_batch is None else follow_batch

    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert "batch" not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch["{}_batch".format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            if torch.is_tensor(item) and item.dtype != torch.bool:
                # Clone first: the masked assignment below is in-place and
                # must not leak the batch offset back into the caller's data.
                item = item.clone()
                mask = item >= 0
                item[mask] = item[mask] + cumsum[key]
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, data[key]))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] += data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                item = torch.full((size,), i, dtype=torch.long)
                batch["{}_batch".format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes,), i, dtype=torch.long)
            batch.batch.append(item)

    # Only the last element is inspected here, as in the code this was
    # copied from.
    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError("Unsupported attribute type {} : {}".format(type(item), item))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def collate_fn_withpad(data_list):
    '''
    Collate a list of data objects into one ``Batch``, zero-padding rank-3
    tensors along dim 1. Modified from PyTorch-Geometric's implementation.

    Attributes for which ``data.__cumsum__(key, value)`` is truthy are shifted
    by the number of nodes seen so far. Rank-3 tensors are padded with zeros
    to the longest size along dim 1 and concatenated along dim 0; the first
    such key also stores the unpadded lengths under ``'tlens'``.

    :param data_list: non-empty list of data objects
    :return: the collated, contiguous ``Batch``
    '''
    all_keys = list(set.union(*[set(sample.keys) for sample in data_list]))
    assert 'batch' not in all_keys

    batch = Batch()
    for name in all_keys:
        batch[name] = []
    batch.batch = []

    node_offset = 0
    for graph_id, sample in enumerate(data_list):
        node_count = sample.num_nodes
        batch.batch.append(torch.full((node_count, ), graph_id, dtype=torch.long))
        for name in sample.keys:
            value = sample[name]
            if sample.__cumsum__(name, value):
                value = value + node_offset
            batch[name].append(value)
        node_offset += node_count

    for name in all_keys:
        first = batch[name][0]
        if torch.is_tensor(first):
            if len(first.shape) == 3:
                tlens = [t.shape[1] for t in batch[name]]
                longest = np.max(tlens)
                # Right-pad every tensor with zeros along dim 1.
                padded = [
                    torch.cat(
                        [t, t.new_zeros(t.shape[0], longest - t.shape[1], t.shape[2])],
                        dim=1)
                    for t in batch[name]
                ]
                batch[name] = torch.cat(padded, dim=0)
                if 'tlens' not in batch.keys:
                    batch['tlens'] = first.new_tensor(tlens, dtype=torch.long)
            else:
                cat_dim = data_list[0].__cat_dim__(name, first)
                batch[name] = torch.cat(batch[name], dim=cat_dim)
        elif isinstance(first, (int, float)):
            batch[name] = torch.tensor(batch[name])
        else:
            raise ValueError('Unsupported attribute type.')

    batch.batch = torch.cat(batch.batch, dim=-1)
    return batch.contiguous()
def collate(data_list):
    """Collate ``data_list`` into a ``Batch`` using ``default_collate`` per key.

    Every key of the first element is collated across all elements with
    ``default_collate``. A ``batch`` vector is then built holding, for each
    element with a known ``num_nodes``, that many copies of its position in
    ``data_list``.
    """
    batch = Batch()
    batch.batch = []
    for name in data_list[0].keys:
        batch[name] = default_collate([sample[name] for sample in data_list])

    node_counts = [sample.num_nodes for sample in data_list]
    batch.batch.extend(
        torch.full((count, ), graph_id, dtype=torch.long)
        for graph_id, count in enumerate(node_counts)
        if count is not None)
    batch.batch = torch.cat(batch.batch, dim=0)
    return batch
def forward(self, data):
    """Run ``self.conv`` once per scale and concatenate the results.

    When ``self._precompute_multi_scale`` is set, the sample indices and the
    per-scale edge indices are read from ``data`` (attributes
    ``idx_<index>`` and ``edge_index_<index>_<scale>``, ``None`` if missing).
    Otherwise they come from ``self.sampler`` and ``self.neighbour_finder``.

    Returns a new ``Batch`` with ``idx``, ``x`` (per-scale outputs
    concatenated on the last dim), ``pos`` and ``batch`` restricted to
    ``idx``, after ``copy_from_to(data, ...)`` has been applied.
    """
    out = Batch()
    x, pos, batch = data.x, data.pos, data.batch
    use_precomputed = self._precompute_multi_scale

    if use_precomputed:
        idx = getattr(data, "idx_{}".format(self._index), None)
    else:
        idx = self.sampler(pos, batch)
    out.idx = idx

    per_scale = []
    for scale in range(self.neighbour_finder.num_scales):
        if use_precomputed:
            edge_index = getattr(
                data, "edge_index_{}_{}".format(self._index, scale), None)
        else:
            row, col = self.neighbour_finder(
                pos, pos[idx], batch_x=batch, batch_y=batch[idx], scale_idx=scale)
            edge_index = torch.stack([col, row], dim=0)
        per_scale.append(self.conv(x, (pos, pos[idx]), edge_index, batch))

    out.x = torch.cat(per_scale, -1)
    out.pos = pos[idx]
    out.batch = batch[idx]
    copy_from_to(data, out)
    return out
def forward(self, data):
    """Apply ``self.conv`` on neighbourhoods after appending one shadow row.

    One extra row is appended to ``x`` (filled with
    ``self.shadow_features_fill``) and to ``pos`` (filled with
    ``self.shadow_points_fill_``). Neighbour features and neighbour positions
    are gathered with the indices returned by ``self.neighbour_finder``; the
    positions are centred on the original (pre-shadow) points.

    Returns a new ``Batch`` with ``x`` from ``self.conv`` and ``pos`` /
    ``batch`` indexed by the sampler output, after ``copy_from_to``.
    """
    out = Batch()
    x, pos, batch = data.x, data.pos, data.batch

    idx_sampler = self.sampler(pos=pos, x=x, batch=batch)
    idx_neighbour, _ = self.neighbour_finder(pos, pos, batch_x=batch, batch_y=batch)

    # NOTE(review): the shadow row is presumably what out-of-range neighbour
    # indices point at - confirm against the neighbour finder.
    feature_pad = torch.full((1, ) + x.shape[1:], self.shadow_features_fill)
    point_pad = torch.full((1, ) + pos.shape[1:], self.shadow_points_fill_)
    x = torch.cat([x, feature_pad.to(x.device)], dim=0)
    pos = torch.cat([pos, point_pad.to(x.device)], dim=0)

    x_neighbour = x[idx_neighbour]
    # pos[:-1] drops the shadow row again, leaving the original points.
    centres = pos[:-1].unsqueeze(1)
    pos_centered_neighbour = pos[idx_neighbour] - centres

    out.x = self.conv(x, pos, x_neighbour, pos_centered_neighbour,
                      idx_neighbour, idx_sampler)
    out.pos = pos[idx_sampler]
    out.batch = batch[idx_sampler]
    copy_from_to(data, out)
    return out
def forward(self, data, **kwargs):
    """Pool features into one row per ``self.pool`` output.

    ``x`` and ``pos`` are concatenated on dim 1, passed through ``self.nn``
    and then through ``self.pool(x, batch)``. The returned ``Batch`` has
    zero positions of shape ``(rows, 3)`` and ``batch = arange(rows)``,
    where ``rows`` is the number of pooled rows. ``copy_from_to`` is applied
    before returning.
    """
    out = Batch()
    x, pos, batch = data.x, data.pos, data.batch

    features = torch.cat([x, pos], dim=1)
    pooled = self.pool(self.nn(features), batch)
    rows = pooled.size(0)

    out.x = pooled
    out.pos = pos.new_zeros((rows, 3))
    out.batch = torch.arange(rows, device=batch.device)
    copy_from_to(data, out)
    return out
def forward(self, data, **kwargs):
    """Sample points, find their neighbours and apply ``self.conv`` once.

    Returns a new ``Batch`` holding the sampled ``idx``, the ``edge_index``
    built as ``stack([col, row])`` from ``self.neighbour_finder``, the conv
    output ``x``, and ``pos`` / ``batch`` restricted to ``idx``.
    ``copy_from_to`` is applied before returning.
    """
    out = Batch()
    x, pos, batch = data.x, data.pos, data.batch

    idx = self.sampler(pos, batch)
    row, col = self.neighbour_finder(
        pos, pos[idx], batch_x=batch, batch_y=batch[idx])
    edge_index = torch.stack([col, row], dim=0)

    out.idx = idx
    out.edge_index = edge_index
    # NOTE(review): the position pair is (sampled, all) here - confirm that
    # this is the order self.conv expects.
    out.x = self.conv(x, (pos[idx], pos), edge_index, batch)
    out.pos = pos[idx]
    out.batch = batch[idx]
    copy_from_to(data, out)
    return out
def forward(self, data, **kwargs):
    """Apply ``self.conv`` at every scale of ``self.neighbour_finder``.

    Points are sampled once with ``self.sampler``. For each scale the
    neighbour finder is queried, ``self.conv`` is applied, and the per-scale
    outputs are concatenated on the last dimension.

    Returns a new ``Batch`` with ``idx``, ``x``, and ``pos`` / ``batch``
    restricted to ``idx``, after ``copy_from_to``.
    """
    out = Batch()
    x, pos, batch = data.x, data.pos, data.batch

    idx = self.sampler(pos, batch)
    out.idx = idx

    per_scale = []
    for scale in range(self.neighbour_finder.num_scales):
        row, col = self.neighbour_finder(
            pos, pos[idx], batch_x=batch, batch_y=batch[idx], scale_idx=scale)
        edges = torch.stack([col, row], dim=0)
        per_scale.append(self.conv(x, (pos, pos[idx]), edges, batch))

    out.x = torch.cat(per_scale, -1)
    out.pos = pos[idx]
    out.batch = batch[idx]
    copy_from_to(data, out)
    return out
def forward(self, data, **kwargs):
    """Residual block: ``convs`` branch plus a resized shortcut.

    The input features are kept as the shortcut. ``self.convs(data)``
    produces the main branch, whose ``x`` goes through
    ``self.features_upsample_nn``. If the conv output carries an ``idx``,
    the shortcut is indexed by it (it is ``None`` for an identity block),
    then resized by ``self.shortcut_feature_resize_nn`` and added.

    Returns a new ``Batch`` with the summed ``x`` and the ``pos`` / ``batch``
    of the conv output, after ``copy_from_to``.
    """
    out = Batch()
    residual = data.x  # (N, indim)

    # NOTE(review): the result of features_downsample_nn is never used -
    # self.convs reads data.x, not this tensor. The call is kept so that any
    # side effects of the module are unchanged; confirm the intended wiring.
    self.features_downsample_nn(residual)  # (N, outdim//4)

    data = self.convs(data)  # (N', convdim)
    idx = data.idx
    main = self.features_upsample_nn(data.x)  # (N', outdim)

    if idx is not None:
        residual = residual[idx]  # (N', indim)
    residual = self.shortcut_feature_resize_nn(residual)  # (N', outdim)

    out.x = residual + main
    out.pos = data.pos
    out.batch = data.batch
    copy_from_to(data, out)
    return out
def forward(self, data):
    """Apply ``self.conv`` once, with precomputed or freshly built edges.

    When ``self._precompute_multi_scale`` is set, ``idx`` and ``edge_index``
    are read from ``data`` (attributes ``index_<index>`` and
    ``edge_index_<index>``, ``None`` if missing). Otherwise they come from
    ``self.sampler`` and ``self.neighbour_finder``.

    Returns a new ``Batch`` with ``idx``, ``edge_index``, the conv output
    ``x``, and ``pos`` / ``batch`` restricted to ``idx``, after
    ``copy_from_to``.
    """
    out = Batch()
    x, pos, batch = data.x, data.pos, data.batch

    if not self._precompute_multi_scale:
        idx = self.sampler(pos, batch)
        row, col = self.neighbour_finder(
            pos, pos[idx], batch_x=batch, batch_y=batch[idx])
        edge_index = torch.stack([col, row], dim=0)
    else:
        idx = getattr(data, "index_{}".format(self._index), None)
        edge_index = getattr(data, "edge_index_{}".format(self._index), None)

    out.idx = idx
    out.edge_index = edge_index
    out.x = self.conv(x, (pos, pos[idx]), edge_index, batch)
    out.pos = pos[idx]
    out.batch = batch[idx]
    copy_from_to(data, out)
    return out
def _collate_graphs(graphs, ys):
    """Merge ``graphs`` into one ``Batch``, shifting edge indices per graph.

    ``ys`` holds one graph-level tensor per graph; they are concatenated into
    ``Batch.y``. ``Batch.batch`` assigns every node the position of its graph
    within this mini-batch.
    """
    merged = Batch()
    xs, edge_indices, edge_attrs, assignment = [], [], [], []
    node_offset = 0
    for position, graph in enumerate(graphs):
        num_nodes = graph.x.size()[0]
        xs.append(graph.x)
        # Edge indices refer to nodes, so shift by the nodes already stacked.
        edge_indices.append(graph.edge_index + node_offset)
        edge_attrs.append(graph.edge_attr)
        assignment.append(
            torch.full([num_nodes], fill_value=position, dtype=torch.long))
        node_offset += num_nodes

    merged.edge_index = torch.cat(edge_indices, dim=1)
    merged.x = torch.cat(xs, dim=0)
    merged.edge_attr = torch.cat(edge_attrs, dim=0)
    merged.y = torch.cat(ys, dim=0)
    merged.batch = torch.cat(assignment, dim=0)
    return merged


def data_loader(database, latent_graph_dict, latent_graph_list, batch_size=4, shuffle=True):
    """Build mini-batches (``Batch`` objects and tensors) from ``database``.

    For every row the three graphs named in the ``"Molecule1"``,
    ``"Molecule2"`` and ``"Membrane"`` columns are looked up through
    ``latent_graph_dict`` / ``latent_graph_list``. Their graph-level values
    are read from ``"molp1"``, ``"molp2"`` and ``"molpm"``; the target is
    built from columns ``"2"`` .. ``"16"`` (each multiplied by 10) and the
    temperature from ``"Temperature"``.

    Two defects of the previous version are fixed:

    * the temperature of a sample is now taken from the same (possibly
      shuffled) row as its graphs and target;
    * the graph-level value is kept per row instead of being written to
      ``graph.y`` on graphs shared through ``latent_graph_list``, so rows
      that reuse a molecule no longer overwrite each other's value. The
      graphs in ``latent_graph_list`` are no longer modified.

    Args:
        database: table indexed with ``.loc[0 .. len - 1]`` (e.g. a pandas
            DataFrame with a default index) holding the columns above.
        latent_graph_dict: maps a molecule name to its position in
            ``latent_graph_list``.
        latent_graph_list: graphs exposing ``x``, ``edge_index`` and
            ``edge_attr``.
        batch_size: maximum number of rows per batch; the final batch may be
            smaller.
        shuffle: when falsy, rows are batched in their original order.

    Returns:
        ``[batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature]``,
        five lists with one entry per mini-batch.
    """
    n_rows = len(database)
    # (column naming the graph, column holding its graph-level value)
    components = (("Molecule1", "molp1"), ("Molecule2", "molp2"), ("Membrane", "molpm"))

    graph_lists = [[] for _ in components]
    y_lists = [[] for _ in components]
    target_list = []
    temperature_list = []
    for itm in range(n_rows):
        for (graph_col, y_col), graphs, ys in zip(components, graph_lists, y_lists):
            graphs.append(latent_graph_list[latent_graph_dict[database[graph_col].loc[itm]]])
            ys.append(torch.tensor([float(database[y_col].loc[itm])]))
        target_list.append(
            torch.tensor([database[str(i)].loc[itm] * 10 for i in range(2, 17)],
                         dtype=torch.float))
        temperature_list.append(database["Temperature"].loc[itm])

    # Row order used to cut the batches; the permutation is always drawn so
    # that the random number stream is consumed as before.
    idxs = torch.randperm(n_rows).tolist()
    if not shuffle:
        idxs = sorted(idxs)

    batches = [[] for _ in components]
    batch_target = []
    batch_Temperature = []
    for start in range(0, n_rows, batch_size):
        rows = idxs[start:start + batch_size]
        for out, graphs, ys in zip(batches, graph_lists, y_lists):
            out.append(_collate_graphs([graphs[r] for r in rows],
                                       [ys[r] for r in rows]))
        batch_target.append(torch.stack([target_list[r] for r in rows], dim=0))
        batch_Temperature.append(torch.Tensor([temperature_list[r] for r in rows]))

    batch_p1, batch_p2, batch_pm = batches
    return [batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature]
def collate(data_list):
    """Collate graphs that may carry level-2 and level-3 structures.

    ``edge_index`` is shifted by the running node count. When present,
    ``edge_index_2`` / ``edge_index_3`` are shifted by the running level-2 /
    level-3 counts (taken as ``assignment_k[1].max() + 1``), and the
    ``assignment_*`` tensors are shifted row-wise by the offsets of the two
    levels they connect. All remaining keys of the first element are
    appended unchanged. ``batch``, ``batch_2`` and ``batch_3`` hold the graph
    id of every level-1/2/3 entry.

    Returns the collated, contiguous ``Batch``.
    """
    keys = data_list[0].keys
    assert 'batch' not in keys

    batch = Batch()
    for name in keys:
        batch[name] = []
    batch.batch = []
    if 'edge_index_2' in keys:
        batch.batch_2 = []
    if 'edge_index_3' in keys:
        batch.batch_3 = []

    keys.remove('edge_index')
    offset_keys = (
        'edge_index_2', 'assignment_2', 'edge_index_3', 'assignment_3',
        'assignment_2to3')
    plain_keys = [name for name in keys if name not in offset_keys]

    # Sizes are deliberately kept across iterations, as in the original.
    offset_1 = offset_2 = offset_3 = 0
    size_1 = size_2 = size_3 = 0
    for graph_id, data in enumerate(data_list):
        for name in plain_keys:
            batch[name].append(data[name])

        size_1 = data.num_nodes
        batch.edge_index.append(data.edge_index + offset_1)
        batch.batch.append(torch.full((size_1, ), graph_id, dtype=torch.long))

        if 'edge_index_2' in data:
            size_2 = data.assignment_2[1].max().item() + 1
            batch.edge_index_2.append(data.edge_index_2 + offset_2)
            shift = torch.tensor([[offset_1], [offset_2]])
            batch.assignment_2.append(data.assignment_2 + shift)
            batch.batch_2.append(torch.full((size_2, ), graph_id, dtype=torch.long))

        if 'edge_index_3' in data:
            size_3 = data.assignment_3[1].max().item() + 1
            batch.edge_index_3.append(data.edge_index_3 + offset_3)
            shift = torch.tensor([[offset_1], [offset_3]])
            batch.assignment_3.append(data.assignment_3 + shift)
            batch.batch_3.append(torch.full((size_3, ), graph_id, dtype=torch.long))

        if 'assignment_2to3' in data:
            assert 'edge_index_2' in data and 'edge_index_3' in data
            shift = torch.tensor([[offset_2], [offset_3]])
            batch.assignment_2to3.append(data.assignment_2to3 + shift)

        offset_1 += size_1
        offset_2 += size_2
        offset_3 += size_3

    assignment_vectors = ['batch', 'batch_2', 'batch_3']
    for name in [k for k in batch.keys if k not in assignment_vectors]:
        batch[name] = torch.cat(batch[name], dim=data_list[0].cat_dim(name))

    batch.batch = torch.cat(batch.batch, dim=-1)
    if 'batch_2' in batch:
        batch.batch_2 = torch.cat(batch.batch_2, dim=-1)
    if 'batch_3' in batch:
        batch.batch_3 = torch.cat(batch.batch_3, dim=-1)

    return batch.contiguous()
def forward(self, plist):
    """Two-level GNN forward pass.

    Level one runs ``self.no_mp_one`` rounds of ``self.meta1`` on each of the
    three molecular graph batches. Level two builds, for every sample, a
    3-node graph (p1, p2, pm) with bidirectional p1-pm and p2-pm edges, runs
    ``self.no_mp_two`` rounds of ``self.meta3`` on it and maps the resulting
    graph-level features through ``self.mlp_last``.

    The number of samples is taken from ``Temperature`` rather than from
    ``self.batch_size``, so a smaller final mini-batch is handled as well.

    Args:
        plist: ``(p1, p2, pm, Temperature)``. The first three expose ``x``,
            ``edge_index``, ``edge_attr``, ``y`` and ``batch``.
            ``Temperature`` is stacked with ``y`` ratios below, so it must
            have the same shape as ``p1.y`` (assumed one value per sample).

    Returns:
        Output of ``self.mlp_last`` on the level-two graph features.
    """
    # Here we instantiate the three molecular graphs for the first level GNN
    p1, p2, pm, Temperature = plist
    x_p1, ei_p1, ea_p1, u_p1, btc_p1 = p1.x, p1.edge_index, p1.edge_attr, p1.y, p1.batch
    x_p2, ei_p2, ea_p2, u_p2, btc_p2 = p2.x, p2.edge_index, p2.edge_attr, p2.y, p2.batch
    x_pm, ei_pm, ea_pm, u_pm, btc_pm = pm.x, pm.edge_index, pm.edge_attr, pm.y, pm.batch

    # Actual number of samples in this mini-batch (the last one may be short).
    num_graphs = Temperature.shape[0]

    # Embed the node and edge features
    enc_x_p1 = self.encoding_node_1(x_p1)
    enc_x_p2 = self.encoding_node_1(x_p2)
    enc_x_pm = self.encoding_node_1(x_pm)
    enc_ea_p1 = self.encoding_edge_1(ea_p1)
    enc_ea_p2 = self.encoding_edge_1(ea_p2)
    enc_ea_pm = self.encoding_edge_1(ea_pm)

    # Initial graph-level features for level one
    u1 = torch.full(size=(num_graphs, NO_GRAPH_FEATURES_ONE), fill_value=0.1,
                    dtype=torch.float, device=device)
    u2 = torch.full(size=(num_graphs, NO_GRAPH_FEATURES_ONE), fill_value=0.1,
                    dtype=torch.float, device=device)
    um = torch.full(size=(num_graphs, NO_GRAPH_FEATURES_ONE), fill_value=0.1,
                    dtype=torch.float, device=device)

    # First-level message passing; the molecular features are built here
    for _ in range(self.no_mp_one):
        enc_x_p1, enc_ea_p1, u1 = self.meta1(
            x=enc_x_p1, edge_index=ei_p1, edge_attr=enc_ea_p1, u=u1, batch=btc_p1)
        enc_x_p2, enc_ea_p2, u2 = self.meta1(
            x=enc_x_p2, edge_index=ei_p2, edge_attr=enc_ea_p2, u=u2, batch=btc_p2)
        enc_x_pm, enc_ea_pm, um = self.meta1(
            x=enc_x_pm, edge_index=ei_pm, edge_attr=enc_ea_pm, u=um, batch=btc_pm)

    # --- GRAPH GENERATION SECOND LEVEL
    # Encode the nodes of the second level
    u1 = self.encoding_node_2(u1)
    u2 = self.encoding_node_2(u2)
    um = self.encoding_node_2(um)

    nu_Batch = Batch()
    graph_ids = torch.arange(num_graphs, dtype=torch.long)

    # Edges per sample: p1<->pm and p2<->pm (no connection between p1 and
    # p2). Each sample owns 3 consecutive nodes, hence the offset of 3.
    temp_ei = torch.tensor([[0, 2, 1, 2], [2, 0, 2, 1]], dtype=torch.long)
    nu_Batch.edge_index = (temp_ei.repeat(1, num_graphs)
                           + 3 * graph_ids.repeat_interleave(4))

    # Edge features from the first level graph-level inputs
    p1_div_pm = u_p1 / u_pm
    p2_div_pm = u_p2 / u_pm

    # Concatenate the temperature and molecular percentages, then encode
    concat_T_p1pm = torch.transpose(torch.stack([Temperature, p1_div_pm]), 0, 1)
    concat_T_p2pm = torch.transpose(torch.stack([Temperature, p2_div_pm]), 0, 1)
    concat_T_p1pm = self.encoding_edge_2(concat_T_p1pm)
    concat_T_p2pm = self.encoding_edge_2(concat_T_p2pm)

    # Every feature is repeated once because the edges are bidirectional;
    # rows come out as (p1, p1, p2, p2) for sample 0, then sample 1, ...
    nu_Batch.edge_attr = torch.stack(
        [concat_T_p1pm, concat_T_p1pm, concat_T_p2pm, concat_T_p2pm],
        dim=1).flatten(0, 1)

    # Nodes of the second-level graphs: (p1, p2, pm) for each sample in turn
    nu_Batch.x = torch.stack([u1, u2, um], dim=1).flatten(0, 1)

    # Initial graph-level features of the second level
    nu_Batch.y = torch.full(size=(num_graphs, NO_GRAPH_FEATURES_TWO),
                            fill_value=0.1, dtype=torch.float)

    # Node-to-sample assignment: 3 nodes per sample
    nu_Batch.batch = graph_ids.repeat_interleave(3)
    nu_Batch = nu_Batch.to(device)

    # --- MESSAGE PASSING LVL 2
    for _ in range(self.no_mp_two):
        nu_Batch.x, nu_Batch.edge_attr, nu_Batch.y = self.meta3(
            x=nu_Batch.x, edge_index=nu_Batch.edge_index,
            edge_attr=nu_Batch.edge_attr, u=nu_Batch.y, batch=nu_Batch.batch)

    c = self.mlp_last(nu_Batch.y)
    return c
def from_data_list(data_list, follow_batch=None):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects, offsetting
    :obj:`edge_index` per depth level.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`.

    Entries of :obj:`edge_index` selected by ``data.depth_mask == d`` are
    shifted by the accumulated ``depth_count[d - 1]`` of the preceding
    elements. ``depth_count`` itself is summed over all elements.

    The inputs are not modified: :obj:`edge_index` is cloned before the
    in-place offset, so the same objects can be collated repeatedly. The
    :obj:`follow_batch` vectors are now actually filled; previously they
    stayed empty and the final concatenation failed with ``IndexError``.

    NOTE(review): every element is indexed up to the maximum depth found in
    ``data_list`` - presumably all ``depth_count`` vectors share one length;
    confirm against the dataset.
    """
    follow_batch = [] if follow_batch is None else follow_batch

    keys = [set(data.keys) for data in data_list]
    keys = set.union(*keys)
    keys.remove('depth_count')
    keys = list(keys)
    depth = max(data.depth_count.shape[0] for data in data_list)
    assert 'batch' not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    cumsum = {i: 0 for i in range(depth)}
    depth_count = th.zeros((depth, ), dtype=th.long)
    batch.batch = []
    for i, data in enumerate(data_list):
        # Clone: the masked "+=" below is in-place and would otherwise
        # accumulate offsets in the caller's edge_index on every call.
        edges = data['edge_index'].clone()
        for d in range(1, depth):
            mask = data.depth_mask == d
            edges[mask] += cumsum[d - 1]
            cumsum[d - 1] += data.depth_count[d - 1].item()
        batch['edge_index'].append(edges)
        depth_count += data['depth_count']

        for key in data.keys:
            if key == 'edge_index' or key == 'depth_count':
                continue
            batch[key].append(data[key])

        for key in follow_batch:
            size = data[key].size(data.__cat_dim__(key, data[key]))
            batch['{}_batch'.format(key)].append(
                torch.full((size, ), i, dtype=torch.long))

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long)
            batch.batch.append(item)

    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key],
                                   dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])

    batch.depth_count = depth_count

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`.

    Every attribute is shifted by the running offset of its key (``0`` for
    keys without one) before being appended. The offset of a key grows, per
    element, by the amount given in the table below; keys that are not in
    the table never receive an offset."""
    # How much each index-like key advances per element. Evaluated lazily so
    # that only the attributes needed for the current key are touched.
    # ``None`` means "leave the offset as it is".
    steps = {
        'edge_index':
            lambda d: d['x_numeric'].shape[0],
        'subgraphs_nodes':
            lambda d: d['x_numeric'].shape[0],
        'subgraphs_edges':
            lambda d: d['edge_attr_numeric'].shape[0],
        'subgraphs_connectivity':
            lambda d: d['subgraphs_nodes'].shape[0],
        'subgraphs_nodes_path_batch':
            lambda d: d['subgraphs_nodes_path_batch'].max() + 1,
        'subgraphs_edges_path_batch':
            lambda d: d['subgraphs_edges_path_batch'].max() + 1,
        'subgraphs_global_path_batch':
            lambda d: 1,
        'cycles_id':
            lambda d: (d['cycles_id'].max() + 1
                       if d['cycles_id'].shape[0] > 0 else None),
        'cycles_edge_index':
            lambda d: d['edge_attr_numeric'].shape[0],
        'edges_connectivity_ids':
            lambda d: d['edge_attr_numeric'].shape[0],
    }
    # Keys always concatenated along the last dimension.
    last_dim_keys = ['subgraphs_connectivity', 'edges_connectivity_ids']

    all_keys = list(set.union(*[set(data.keys) for data in data_list]))
    assert 'batch' not in all_keys

    batch = Batch()
    for key in all_keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    cumsum = {}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key] + cumsum.get(key, 0)

            step_fn = steps.get(key)
            step = None if step_fn is None else step_fn(data)
            if step is not None:
                if key in cumsum:
                    cumsum[key] += step
                else:
                    cumsum[key] = step

            batch[key].append(item)

        for key in follow_batch:
            size = data[key].size(data.__cat_dim__(key, data[key]))
            batch['{}_batch'.format(key)].append(
                torch.full((size, ), i, dtype=torch.long))

        num_nodes = data.num_nodes
        if num_nodes is not None:
            batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long))

    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            if key in last_dim_keys:
                batch[key] = torch.cat(batch[key], dim=-1)
            else:
                batch[key] = torch.cat(
                    batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError('Unsupported attribute type.')

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`.

    Non-boolean tensors are shifted by the running sum of
    ``data.__inc__(key, item)`` for their key, and the cumulative sizes of
    every key are recorded in ``batch.__slices__``."""
    all_keys = list(set.union(*[set(data.keys) for data in data_list]))
    assert 'batch' not in all_keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in all_keys}

    for key in all_keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    cumsum = {key: 0 for key in all_keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            is_tensor = torch.is_tensor(item)
            if is_tensor and item.dtype != torch.bool:
                item = item + cumsum[key]

            size = item.size(data.__cat_dim__(key, data[key])) if is_tensor else 1
            slices = batch.__slices__[key]
            slices.append(size + slices[-1])

            cumsum[key] = cumsum[key] + data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                batch['{}_batch'.format(key)].append(
                    torch.full((size,), i, dtype=torch.long))

        num_nodes = data.num_nodes
        if num_nodes is not None:
            batch.batch.append(torch.full((num_nodes,), i, dtype=torch.long))

    if num_nodes is None:
        batch.batch = None

    reference = data_list[0]  # decides the concatenation dimension per key
    for key in batch.keys:
        item = batch[key][0]
        logger.debug(f"key = {key}")
        if torch.is_tensor(item):
            logger.debug(f"batch[{key}]")
            logger.debug(f"item.shape = {item.shape}")
            cat_dim = reference.__cat_dim__(key, item)
            batch[key] = torch.cat(batch[key], dim=cat_dim)
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(batch[key])

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()