def from_data_list_token(data_list, follow_batch=[]):
    """Collate ``Data`` objects into a ``Batch``, leaving negative indices alone.

    This is mostly a copy-paste of PyTorch Geometric's
    ``Batch.from_data_list`` with one difference: tensor entries that are
    negative (sentinel values, e.g. -1 for "no token") are NOT shifted by the
    running offset; only entries ``>= 0`` are incremented.

    :param data_list: list of ``torch_geometric.data.Data``-like objects.
    :param follow_batch: keys for which an extra ``<key>_batch`` assignment
        vector is created.
    :return: a contiguous ``Batch``.
    """
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert "batch" not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch["{}_batch".format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            if torch.is_tensor(item) and item.dtype != torch.bool:
                # BUGFIX: clone before the masked add. The original code did
                # `item[mask] = item[mask] + cumsum[key]` directly on
                # `data[key]`, mutating the caller's Data objects in place —
                # collating the same data_list twice (e.g. a second epoch)
                # would double-shift the indices.
                item = item.clone()
                mask = item >= 0
                item[mask] = item[mask] + cumsum[key]
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, data[key]))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] += data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                item = torch.full((size,), i, dtype=torch.long)
                batch["{}_batch".format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes,), i, dtype=torch.long)
            batch.batch.append(item)

    # If the last graph had no node count, assume none do (mirrors upstream PyG).
    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key],
                                   dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError("Unsupported attribute type {} : {}".format(type(item), item))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def collate_fn_withpad(data_list):
    '''
    Modified based on PyTorch-Geometric's implementation.

    Collates a list of Data objects into a Batch, offsetting index-valued
    attributes by the running node count and zero-padding any rank-3 tensor
    attributes to a common size along dim 1 before concatenation.

    :param data_list: list of ``torch_geometric.data.Data``-like objects
        (must support ``__cumsum__`` / ``__cat_dim__`` — older PyG API).
    :return: a contiguous ``Batch``.
    '''
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys
    batch = Batch()
    for key in keys:
        batch[key] = []
    batch.batch = []
    # Running node-count offset applied to index-like attributes.
    cumsum = 0
    for i, data in enumerate(data_list):
        num_nodes = data.num_nodes
        # Graph-membership vector: node j of graph i maps to value i.
        batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long))
        for key in data.keys:
            item = data[key]
            # __cumsum__ decides whether this attribute holds node indices
            # that must be shifted by the nodes seen so far.
            item = item + cumsum if data.__cumsum__(key, item) else item
            batch[key].append(item)
        cumsum += num_nodes
    for key in keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            if (len(item.shape) == 3):
                # Rank-3 attributes may differ in their middle (time) dim:
                # zero-pad each to the max length, then cat along dim 0.
                tlens = [x.shape[1] for x in batch[key]]
                maxtlens = np.max(tlens)
                to_cat = []
                for x in batch[key]:
                    to_cat.append(
                        torch.cat([
                            x,
                            x.new_zeros(x.shape[0], maxtlens - x.shape[1],
                                        x.shape[2])
                        ],
                                  dim=1))
                batch[key] = torch.cat(to_cat, dim=0)
                # Record the original (pre-padding) lengths once.
                # NOTE(review): only the FIRST rank-3 key's lengths are kept;
                # presumably all rank-3 attributes share lengths — verify.
                if 'tlens' not in batch.keys:
                    batch['tlens'] = item.new_tensor(tlens, dtype=torch.long)
            else:
                batch[key] = torch.cat(batch[key],
                                       dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError('Unsupported attribute type.')
    batch.batch = torch.cat(batch.batch, dim=-1)
    return batch.contiguous()
def collate(data_list):
    """Collate Data objects carrying a 3-level hierarchy into one Batch.

    Besides the usual node-level ``edge_index``/``batch``, this handles
    optional level-2 and level-3 structures (``edge_index_2``,
    ``assignment_2``, ``edge_index_3``, ``assignment_3``,
    ``assignment_2to3``), each shifted by its own running entity count.

    :param data_list: list of Data objects; all are assumed to expose the
        same keys as ``data_list[0]``.
    :return: a contiguous ``Batch``.
    """
    keys = data_list[0].keys
    assert 'batch' not in keys
    batch = Batch()
    for key in keys:
        batch[key] = []
    batch.batch = []
    if 'edge_index_2' in keys:
        batch.batch_2 = []
    if 'edge_index_3' in keys:
        batch.batch_3 = []
    # Hierarchy keys get bespoke offset handling below; exclude them (and
    # edge_index) from the generic append loop.
    keys.remove('edge_index')
    props = [
        'edge_index_2', 'assignment_2', 'edge_index_3', 'assignment_3',
        'assignment_2to3'
    ]
    keys = [x for x in keys if x not in props]
    # cumsum_k = entities of level k seen so far; N_k = count in current graph.
    cumsum_1 = N_1 = cumsum_2 = N_2 = cumsum_3 = N_3 = 0
    for i, data in enumerate(data_list):
        for key in keys:
            batch[key].append(data[key])
        N_1 = data.num_nodes
        batch.edge_index.append(data.edge_index + cumsum_1)
        batch.batch.append(torch.full((N_1, ), i, dtype=torch.long))
        if 'edge_index_2' in data:
            # Level-2 entity count inferred from the assignment's target row.
            N_2 = data.assignment_2[1].max().item() + 1
            batch.edge_index_2.append(data.edge_index_2 + cumsum_2)
            # assignment_2 row 0 holds node ids, row 1 level-2 ids: shift each
            # by its own running offset.
            batch.assignment_2.append(data.assignment_2 +
                                      torch.tensor([[cumsum_1], [cumsum_2]]))
            batch.batch_2.append(torch.full((N_2, ), i, dtype=torch.long))
        if 'edge_index_3' in data:
            N_3 = data.assignment_3[1].max().item() + 1
            batch.edge_index_3.append(data.edge_index_3 + cumsum_3)
            batch.assignment_3.append(data.assignment_3 +
                                      torch.tensor([[cumsum_1], [cumsum_3]]))
            batch.batch_3.append(torch.full((N_3, ), i, dtype=torch.long))
        if 'assignment_2to3' in data:
            assert 'edge_index_2' in data and 'edge_index_3' in data
            batch.assignment_2to3.append(
                data.assignment_2to3 +
                torch.tensor([[cumsum_2], [cumsum_3]]))
        # NOTE(review): cumsum_2/3 are incremented even for graphs lacking
        # the level-2/3 keys, re-adding the stale N_2/N_3 from a previous
        # graph — harmless only if all graphs share the same keys; verify.
        cumsum_1 += N_1
        cumsum_2 += N_2
        cumsum_3 += N_3
    keys = [x for x in batch.keys if x not in ['batch', 'batch_2', 'batch_3']]
    for key in keys:
        # Older PyG API: cat_dim (no underscores).
        batch[key] = torch.cat(batch[key], dim=data_list[0].cat_dim(key))
    batch.batch = torch.cat(batch.batch, dim=-1)
    if 'batch_2' in batch:
        batch.batch_2 = torch.cat(batch.batch_2, dim=-1)
    if 'batch_3' in batch:
        batch.batch_3 = torch.cat(batch.batch_3, dim=-1)
    return batch.contiguous()
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`.

    Depth-aware variant: edge indices are offset per depth level using
    ``data.depth_mask`` / ``data.depth_count``, and the per-level node
    counts are accumulated into ``batch.depth_count``.
    """
    keys = [set(data.keys) for data in data_list]
    keys = set.union(*keys)
    # depth_count is aggregated separately, not concatenated like other keys.
    keys.remove('depth_count')
    keys = list(keys)
    depth = max(data.depth_count.shape[0] for data in data_list)
    assert 'batch' not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    # cumsum[d] = total nodes seen so far at depth level d.
    cumsum = {i: 0 for i in range(depth)}
    depth_count = th.zeros((depth, ), dtype=th.long)
    batch.batch = []
    for i, data in enumerate(data_list):
        # BUGFIX: clone before the in-place `edges[mask] += ...`. The
        # original shifted `data['edge_index']` itself, corrupting the
        # caller's Data objects — collating the same data_list twice
        # (e.g. a second epoch) would double-shift the edge indices.
        edges = data['edge_index'].clone()
        for d in range(1, depth):
            # Edges whose depth_mask is d point at nodes of level d-1;
            # shift them by the nodes already emitted at that level.
            mask = data.depth_mask == d
            edges[mask] += cumsum[d - 1]
            cumsum[d - 1] += data.depth_count[d - 1].item()
        batch['edge_index'].append(edges)
        depth_count += data['depth_count']

        for key in data.keys:
            if key == 'edge_index' or key == 'depth_count':
                continue
            item = data[key]
            batch[key].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long)
            batch.batch.append(item)

    # If the last graph had no node count, assume none do (mirrors upstream PyG).
    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key],
                                   dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
    batch.depth_count = depth_count

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`.

    Subgraph-aware variant: index-valued keys (edge indices, subgraph node /
    edge / path ids, cycle ids, ...) are shifted by key-specific running
    offsets so that references remain valid after concatenation.
    """

    def _increment(data, key):
        """Offset contributed by `data` to subsequent graphs' `key` entries.

        Returns 0 for keys that carry no cross-graph indices. This replaces
        the original duplicated `if key in cumsum` / `else` ladders, which
        spelled out every rule twice (once with `+=`, once with `=`).
        """
        if key in ('edge_index', 'subgraphs_nodes'):
            return data['x_numeric'].shape[0]          # shifted by node count
        if key in ('subgraphs_edges', 'cycles_edge_index',
                   'edges_connectivity_ids'):
            return data['edge_attr_numeric'].shape[0]  # shifted by edge count
        if key == 'subgraphs_connectivity':
            return data['subgraphs_nodes'].shape[0]
        if key in ('subgraphs_nodes_path_batch', 'subgraphs_edges_path_batch'):
            return data[key].max() + 1                 # next free path id
        if key == 'subgraphs_global_path_batch':
            return 1                                   # one global path per graph
        if key == 'cycles_id':
            # Guard: an empty cycles_id tensor has no max().
            if data['cycles_id'].shape[0] > 0:
                return data['cycles_id'].max() + 1
            return 0
        return 0

    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = Batch()
    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    cumsum = {}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            # Shift by the running offset (0 when none accumulated yet),
            # then grow the offset for the next graph.
            item = data[key] + cumsum.get(key, 0)
            cumsum[key] = cumsum.get(key, 0) + _increment(data, key)
            batch[key].append(item)

        for key in follow_batch:
            size = data[key].size(data.__cat_dim__(key, data[key]))
            item = torch.full((size, ), i, dtype=torch.long)
            batch['{}_batch'.format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long)
            batch.batch.append(item)

    # If the last graph had no node count, assume none do (mirrors upstream PyG).
    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            if key not in ['subgraphs_connectivity', 'edges_connectivity_ids']:
                batch[key] = torch.cat(
                    batch[key], dim=data_list[0].__cat_dim__(key, item))
            else:
                # Connectivity tensors are concatenated along the last dim.
                batch[key] = torch.cat(batch[key], dim=-1)
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError('Unsupported attribute type.')

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`.
    """
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            # Shift index-valued tensors by the running offset; bool masks
            # carry no indices and are left untouched.
            if torch.is_tensor(item) and item.dtype != torch.bool:
                item = item + cumsum[key]
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, data[key]))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] = cumsum[key] + data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                item = torch.full((size,), i, dtype=torch.long)
                batch['{}_batch'.format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes,), i, dtype=torch.long)
            batch.batch.append(item)

    # If the last graph had no node count, assume none do (mirrors upstream PyG).
    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        logger.debug(f"key = {key}")
        if torch.is_tensor(item):
            logger.debug(f"batch[{key}]")
            logger.debug(f"item.shape = {item.shape}")
            elem = data_list[0]  # type(elem) = Data or ClevrData
            # __cat_dim__ tells us along which dim this key concatenates.
            dim_ = elem.__cat_dim__(key, item)
            batch[key] = torch.cat(batch[key], dim=dim_)
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        else:
            # BUGFIX (consistency): previously unsupported types were silently
            # skipped, leaving batch[key] as a raw Python list; fail loudly
            # like the sibling collate functions in this file do.
            raise ValueError('Unsupported attribute type {} : {}'.format(type(item), item))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()