def collate(data):
    data_list, parts = [d[0] for d in data], [d[1] for d in data]
    partptr = cluster_dataset.__partptr__

    adj = cat([data.adj for data in data_list], dim=0)

    adj = adj.t()
    adjs = []
    for part in parts:
        start = partptr[part]
        length = partptr[part + 1] - start
        adjs.append(adj.narrow(0, start, length))

    adj = cat(adjs, dim=0).t()
    row, col, value = adj.coo()

    data = cluster_dataset.__data__.__class__()
    data.num_nodes = adj.size(0)
    data.edge_index = torch.stack([row, col], dim=0)
    data.edge_attr = value

    ref = data_list[0]
    keys = ref.keys
    keys.remove('adj')

    for key in keys:
        if ref[key].size(0) != ref.adj.size(0):
            data[key] = ref[key]
        else:
            data[key] = torch.cat([d[key] for d in data_list],
                                  dim=ref.__cat_dim__(key, ref[key]))

    if cluster_dataset.dataset.transform is not None:
        data = cluster_dataset.dataset.transform(data)

    return data
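# Sketch (not part of the original source) of the transpose/narrow trick
# used above, on a dense stand-in; the real code operates on a
# `torch_sparse.SparseTensor`, and the sizes below are invented.
# `narrow` only slices along dim 0, so the code transposes, takes the row
# blocks of the sampled partitions, concatenates, and transposes back,
# which restricts the *columns* to those partitions.
import torch

adj = torch.arange(25).view(5, 5)  # stand-in adjacency over N = 5 nodes
partptr = [0, 2, 5]                # partition 0 = nodes {0,1}, partition 1 = {2,3,4}
parts = [1, 0]                     # sampled partition order

t = adj.t()
blocks = [t.narrow(0, partptr[p], partptr[p + 1] - partptr[p]) for p in parts]
out = torch.cat(blocks, dim=0).t()
print(out)  # columns 2..4 now precede columns 0..1; rows are untouched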
def collate(batch):
    data_list = [data[0] for data in batch]
    parts: List[int] = [data[1] for data in batch]
    partptr = cluster_data.partptr

    adj = cat([data.adj for data in data_list], dim=0)

    adj = adj.t()
    adjs = []
    for part in parts:
        start = partptr[part]
        length = partptr[part + 1] - start
        adjs.append(adj.narrow(0, start, length))

    adj = cat(adjs, dim=0).t()
    row, col, value = adj.coo()

    data = cluster_data.data.__class__()
    data.num_nodes = adj.size(0)
    data.edge_index = torch.stack([row, col], dim=0)
    data.edge_attr = value

    ref = data_list[0]
    keys = ref.keys
    keys.remove('adj')
    for key in keys:
        if ref[key].size(0) != ref.adj.size(0):
            data[key] = ref[key]
        else:
            data[key] = torch.cat([d[key] for d in data_list],
                                  dim=ref.__cat_dim__(key, ref[key]))

    return data
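# Hedged wiring sketch: a closure like the one above is meant to be the
# `collate_fn` of a plain `torch.utils.data.DataLoader` whose items are
# `(sub_data, part_id)` pairs. The stand-in dataset and collate below are
# invented and show only the plumbing, not the real cluster logic.
import torch
from torch.utils.data import DataLoader

def demo_collate(batch):
    subs = [b[0] for b in batch]   # like `data_list` above
    parts = [b[1] for b in batch]  # like `parts` above
    return torch.stack(subs), parts

items = [(torch.randn(3), p) for p in range(4)]  # (sub_data, part_id) pairs
loader = DataLoader(items, batch_size=2, shuffle=True, collate_fn=demo_collate)
subs, parts = next(iter(loader))
print(subs.shape, parts)  # torch.Size([2, 3]) plus the sampled part ids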
def __collate__(self, batch):
    if not isinstance(batch, torch.Tensor):
        batch = torch.tensor(batch)

    N = self.cluster_data.data.num_nodes
    E = self.cluster_data.data.num_edges

    start = self.cluster_data.partptr[batch].tolist()
    end = self.cluster_data.partptr[batch + 1].tolist()
    node_idx = torch.cat([torch.arange(s, e) for s, e in zip(start, end)])

    data = copy.copy(self.cluster_data.data)
    adj, data.adj = self.cluster_data.data.adj, None

    adj = cat([adj.narrow(0, s, e - s) for s, e in zip(start, end)], dim=0)
    adj = adj.index_select(1, node_idx)
    row, col, edge_idx = adj.coo()
    data.edge_index = torch.stack([row, col], dim=0)

    for key, item in data:
        if item.size(0) == N:
            data[key] = item[node_idx]
        # `elif`, not a second `if`: otherwise N == E would re-index twice.
        elif item.size(0) == E:
            data[key] = item[edge_idx]

    return data
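# A small, self-contained sketch of how `__collate__` builds `node_idx`
# from `partptr` for a sampled set of partitions (values are invented):
import torch

partptr = torch.tensor([0, 3, 5, 9])  # three partitions of sizes 3, 2, 4
batch = torch.tensor([2, 0])          # sample partitions 2 and 0

start = partptr[batch].tolist()       # [5, 0]
end = partptr[batch + 1].tolist()     # [9, 3]
node_idx = torch.cat([torch.arange(s, e) for s, e in zip(start, end)])
print(node_idx)  # tensor([5, 6, 7, 8, 0, 1, 2])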
def _collate(
    key: str,
    values: List[Any],
    data_list: List[BaseData],
    stores: List[BaseStorage],
    increment: bool,
) -> Tuple[Any, Any, Any]:

    elem = values[0]

    if isinstance(elem, Tensor):
        # Concatenate a list of `torch.Tensor` along the `cat_dim`.
        # NOTE: We need to take care of incrementing elements appropriately.
        cat_dim = data_list[0].__cat_dim__(key, elem, stores[0])
        if cat_dim is None or elem.dim() == 0:
            values = [value.unsqueeze(0) for value in values]
        slices = cumsum([value.size(cat_dim or 0) for value in values])
        if increment:
            incs = get_incs(key, values, data_list, stores)
            if incs.dim() > 1 or int(incs[-1]) != 0:
                values = [
                    value + inc.to(value.device)
                    for value, inc in zip(values, incs)
                ]
        else:
            incs = None

        if torch.utils.data.get_worker_info() is not None:
            # Write directly into shared memory to avoid an extra copy:
            numel = sum(value.numel() for value in values)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        else:
            out = None

        value = torch.cat(values, dim=cat_dim or 0, out=out)
        return value, slices, incs

    elif isinstance(elem, SparseTensor) and increment:
        # Concatenate a list of `SparseTensor` along the `cat_dim`.
        # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking.
        cat_dim = data_list[0].__cat_dim__(key, elem, stores[0])
        cat_dims = (cat_dim, ) if isinstance(cat_dim, int) else cat_dim
        repeats = [[value.size(dim) for dim in cat_dims] for value in values]
        slices = cumsum(repeats)
        value = cat(values, dim=cat_dim)
        return value, slices, None

    elif isinstance(elem, (int, float)):
        # Convert a list of numerical values to a `torch.Tensor`.
        value = torch.tensor(values)
        if increment:
            incs = get_incs(key, values, data_list, stores)
            if int(incs[-1]) != 0:
                value.add_(incs)
        else:
            incs = None
        slices = torch.arange(len(values) + 1)
        return value, slices, incs

    elif isinstance(elem, Mapping):
        # Recursively collate elements of dictionaries.
        value_dict, slice_dict, inc_dict = {}, {}, {}
        for key in elem.keys():
            value_dict[key], slice_dict[key], inc_dict[key] = _collate(
                key, [v[key] for v in values], data_list, stores, increment)
        return value_dict, slice_dict, inc_dict

    elif (isinstance(elem, Sequence) and not isinstance(elem, str)
          and len(elem) > 0 and isinstance(elem[0], (Tensor, SparseTensor))):
        # Recursively collate elements of lists.
        value_list, slice_list, inc_list = [], [], []
        for i in range(len(elem)):
            value, slices, incs = _collate(key, [v[i] for v in values],
                                           data_list, stores, increment)
            value_list.append(value)
            slice_list.append(slices)
            inc_list.append(incs)
        return value_list, slice_list, inc_list

    else:
        # Otherwise, just return the list of values as-is.
        slices = torch.arange(len(values) + 1)
        return values, slices, None
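# What the `slices` return value encodes: prefix sums of per-element sizes
# along `cat_dim`, so element `i` of the batch can later be sliced back out
# as `value[slices[i]:slices[i + 1]]`. A plain-`torch` sketch of the same
# bookkeeping (the real code uses a `cumsum` helper for this):
import torch

sizes = [3, 5, 2]                             # e.g. nodes per graph
slices = torch.tensor([0] + sizes).cumsum(0)  # tensor([0, 3, 8, 10])
xs = [torch.randn(n, 4) for n in sizes]
value = torch.cat(xs, dim=0)
i = 1
assert torch.equal(value[slices[i]:slices[i + 1]], xs[i])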
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`."""

    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = Batch()
    for key in data_list[0].__dict__.keys():
        if key[:2] != '__' and key[-2:] != '__':
            batch[key] = None

    batch.__data_class__ = data_list[0].__class__
    for key in keys + ['batch']:
        batch[key] = []

    device = None
    slices = {key: [0] for key in keys}
    cumsum = {key: [0] for key in keys}
    cat_dims = {}
    num_nodes_list = []
    for i, data in enumerate(data_list):
        for key in keys:
            item = data[key]

            # Increase values by `cumsum` value.
            cum = cumsum[key][-1]
            if isinstance(item, Tensor) and item.dtype != torch.bool:
                if not isinstance(cum, int) or cum != 0:
                    item = item + cum
            elif isinstance(item, SparseTensor):
                value = item.storage.value()
                if value is not None and value.dtype != torch.bool:
                    if not isinstance(cum, int) or cum != 0:
                        value = value + cum
                    item = item.set_value(value, layout='coo')
            elif isinstance(item, (int, float)):
                item = item + cum

            # Treat 0-dimensional tensors as 1-dimensional.
            if isinstance(item, Tensor) and item.dim() == 0:
                item = item.unsqueeze(0)

            batch[key].append(item)

            # Gather the size of the `cat` dimension.
            size = 1
            cat_dim = data.__cat_dim__(key, data[key])
            cat_dims[key] = cat_dim
            if isinstance(item, Tensor):
                size = item.size(cat_dim)
                device = item.device
            elif isinstance(item, SparseTensor):
                size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]
                device = item.device()

            slices[key].append(size + slices[key][-1])
            inc = data.__inc__(key, item)
            if isinstance(inc, (tuple, list)):
                inc = torch.tensor(inc)
            cumsum[key].append(inc + cumsum[key][-1])

            if key in follow_batch:
                if isinstance(size, Tensor):
                    for j, size in enumerate(size.tolist()):
                        tmp = f'{key}_{j}_batch'
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size, ), i, dtype=torch.long,
                                       device=device))
                else:
                    tmp = f'{key}_batch'
                    batch[tmp] = [] if i == 0 else batch[tmp]
                    batch[tmp].append(
                        torch.full((size, ), i, dtype=torch.long,
                                   device=device))

        if hasattr(data, '__num_nodes__'):
            num_nodes_list.append(data.__num_nodes__)
        else:
            num_nodes_list.append(None)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long,
                              device=device)
            batch.batch.append(item)

    # Fix initial slice values:
    for key in keys:
        # Zero with the same type/shape as the later slice entries.
        slices[key][0] = slices[key][1] - slices[key][1]

    batch.batch = None if len(batch.batch) == 0 else batch.batch
    batch.__slices__ = slices
    batch.__cumsum__ = cumsum
    batch.__cat_dims__ = cat_dims
    batch.__num_nodes_list__ = num_nodes_list

    ref_data = data_list[0]
    for key in batch.keys:
        items = batch[key]
        item = items[0]
        if isinstance(item, Tensor):
            batch[key] = torch.cat(items, ref_data.__cat_dim__(key, item))
        elif isinstance(item, SparseTensor):
            batch[key] = cat(items, ref_data.__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(items)

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
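# Hedged usage sketch (assumes a torch_geometric install whose
# `Batch.from_data_list` behaves like the code above): `edge_index` is a
# node-level key, so `__inc__` shifts the second graph's node ids by the
# first graph's node count, and `batch` maps each node to its graph.
import torch
from torch_geometric.data import Batch, Data

d1 = Data(x=torch.randn(3, 8), edge_index=torch.tensor([[0, 1], [1, 2]]))
d2 = Data(x=torch.randn(2, 8), edge_index=torch.tensor([[0], [1]]))
batch = Batch.from_data_list([d1, d2])
print(batch.edge_index)  # tensor([[0, 1, 3], [1, 2, 4]])
print(batch.batch)       # tensor([0, 0, 0, 1, 1])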
@classmethod
def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`.
    Will exclude any keys given in :obj:`exclude_keys`."""

    keys = set(data_list[0].keys)
    keys -= {'num_nodes', 'num_edges'}
    keys -= set(exclude_keys)
    keys = sorted(list(keys))
    assert 'batch' not in keys and 'ptr' not in keys

    batch = cls()
    batch.__num_graphs__ = len(data_list)
    batch.__data_class__ = data_list[0].__class__
    for key in keys + ['batch']:
        batch[key] = []
    batch.ptr = [0]

    device = None
    slices = {key: [0] for key in keys}
    cumsum = {key: [0] for key in keys}
    cat_dims = {}
    num_nodes_list = []
    for i, data in enumerate(data_list):
        for key in keys:
            item = data[key]

            # Increase values by `cumsum` value.
            cum = cumsum[key][-1]
            if isinstance(item, Tensor) and item.dtype != torch.bool:
                if not isinstance(cum, int) or cum != 0:
                    item = item + cum
            elif isinstance(item, SparseTensor):
                value = item.storage.value()
                if value is not None and value.dtype != torch.bool:
                    if not isinstance(cum, int) or cum != 0:
                        value = value + cum
                    item = item.set_value(value, layout='coo')
            elif isinstance(item, (int, float)):
                item = item + cum

            # Gather the size of the `cat` dimension.
            size = 1
            cat_dim = data.__cat_dim__(key, data[key])
            # 0-dimensional tensors have no dimension along which to
            # concatenate, so we set `cat_dim` to `None`.
            if isinstance(item, Tensor) and item.dim() == 0:
                cat_dim = None
            cat_dims[key] = cat_dim

            # Add a batch dimension to items whose `cat_dim` is `None`:
            if isinstance(item, Tensor) and cat_dim is None:
                cat_dim = 0  # Concatenate along this new batch dimension.
                item = item.unsqueeze(0)
                device = item.device
            elif isinstance(item, Tensor):
                size = item.size(cat_dim)
                device = item.device
            elif isinstance(item, SparseTensor):
                size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]
                device = item.device()

            batch[key].append(item)  # Append item to the attribute list.

            slices[key].append(size + slices[key][-1])
            inc = data.__inc__(key, item)
            if isinstance(inc, (tuple, list)):
                inc = torch.tensor(inc)
            cumsum[key].append(inc + cumsum[key][-1])

            if key in follow_batch:
                if isinstance(size, Tensor):
                    for j, size in enumerate(size.tolist()):
                        tmp = f'{key}_{j}_batch'
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size, ), i, dtype=torch.long,
                                       device=device))
                else:
                    tmp = f'{key}_batch'
                    batch[tmp] = [] if i == 0 else batch[tmp]
                    batch[tmp].append(
                        torch.full((size, ), i, dtype=torch.long,
                                   device=device))

        if 'num_nodes' in data:
            num_nodes_list.append(data.num_nodes)
        else:
            num_nodes_list.append(None)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long,
                              device=device)
            batch.batch.append(item)
            batch.ptr.append(batch.ptr[-1] + num_nodes)

    batch.batch = None if len(batch.batch) == 0 else batch.batch
    batch.ptr = None if len(batch.ptr) == 1 else batch.ptr
    batch.__slices__ = slices
    batch.__cumsum__ = cumsum
    batch.__cat_dims__ = cat_dims
    batch.__num_nodes_list__ = num_nodes_list

    ref_data = data_list[0]
    for key in batch.keys:
        items = batch[key]
        item = items[0]
        cat_dim = ref_data.__cat_dim__(key, item)
        cat_dim = 0 if cat_dim is None else cat_dim
        if isinstance(item, Tensor):
            batch[key] = torch.cat(items, cat_dim)
        elif isinstance(item, SparseTensor):
            batch[key] = cat(items, cat_dim)
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(items)

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch
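# Hedged sketch of the bookkeeping this version adds over the previous
# one: `ptr` records cumulative node counts per graph, and `follow_batch`
# builds an extra `<key>_batch` assignment vector for the named attribute.
# Assumes a torch_geometric install matching the code above.
import torch
from torch_geometric.data import Batch, Data

d1 = Data(x=torch.randn(3, 8), y=torch.randn(2, 4))
d2 = Data(x=torch.randn(2, 8), y=torch.randn(1, 4))
batch = Batch.from_data_list([d1, d2], follow_batch=['y'])
print(batch.ptr)      # tensor([0, 3, 5]): cumulative node counts
print(batch.y_batch)  # tensor([0, 0, 1]): graph id per row of `y`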
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`."""

    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    for key in keys + ['batch']:
        batch[key] = []

    slices = {key: [0] for key in keys}
    cumsum = {key: [0] for key in keys}
    cat_dims = {}
    num_nodes_list = []
    for i, data in enumerate(data_list):
        for key in keys:
            item = data[key]

            # Increase values by `cumsum` value.
            cum = cumsum[key][-1]
            # DAGNN: for (bi-directional) layer indices, keep the layer ids
            # in dimension 0 untouched and only shift the node ids in
            # dimension 1 (the layers of all data objects get merged).
            if key in ['bi_layer_index', 'bi_layer_parent_index']:
                item[:, 1] = item[:, 1] + cum
            elif key in ['layer_index', 'layer_parent_index']:
                item[1] = item[1] + cum
            elif isinstance(item, Tensor) and item.dtype != torch.bool:
                item = item + cum if cum != 0 else item
            elif isinstance(item, SparseTensor):
                value = item.storage.value()
                if value is not None and value.dtype != torch.bool:
                    value = value + cum if cum != 0 else value
                    item = item.set_value(value, layout='coo')
            elif isinstance(item, (int, float)):
                item = item + cum

            # Treat 0-dimensional tensors as 1-dimensional.
            if isinstance(item, Tensor) and item.dim() == 0:
                item = item.unsqueeze(0)

            batch[key].append(item)

            # Gather the size of the `cat` dimension.
            size = 1
            cat_dim = data.__cat_dim__(key, data[key])
            cat_dims[key] = cat_dim
            if isinstance(item, Tensor):
                size = item.size(cat_dim)
            elif isinstance(item, SparseTensor):
                size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]

            slices[key].append(size + slices[key][-1])
            inc = data.__inc__(key, item)
            if isinstance(inc, (tuple, list)):
                inc = torch.tensor(inc)
            cumsum[key].append(inc + cumsum[key][-1])

            if key in follow_batch:
                if isinstance(size, Tensor):
                    for j, size in enumerate(size.tolist()):
                        tmp = f'{key}_{j}_batch'
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size, ), i, dtype=torch.long))
                else:
                    tmp = f'{key}_batch'
                    batch[tmp] = [] if i == 0 else batch[tmp]
                    batch[tmp].append(
                        torch.full((size, ), i, dtype=torch.long))

        if hasattr(data, '__num_nodes__'):
            num_nodes_list.append(data.__num_nodes__)
        else:
            num_nodes_list.append(None)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long)
            batch.batch.append(item)

    # Fix initial slice values:
    for key in keys:
        slices[key][0] = slices[key][1] - slices[key][1]

    batch.batch = None if len(batch.batch) == 0 else batch.batch
    batch.__slices__ = slices
    batch.__cumsum__ = cumsum
    batch.__cat_dims__ = cat_dims
    batch.__num_nodes_list__ = num_nodes_list

    ref_data = data_list[0]
    for key in batch.keys:
        items = batch[key]
        item = items[0]
        if isinstance(item, Tensor):
            batch[key] = torch.cat(items, ref_data.__cat_dim__(key, item))
        elif isinstance(item, SparseTensor):
            batch[key] = cat(items, ref_data.__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(items)

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
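# Toy sketch (values invented) of the DAGNN-specific branch above: for
# `layer_index`-style keys, row 0 holds layer ids and stays untouched,
# while the node ids in row 1 are shifted by the running offset `cum`,
# so layers of all graphs merge but node ids stay unique in the batch.
import torch

cum = 3                                     # nodes seen in earlier graphs
layer_index = torch.tensor([[0, 0, 1, 1],   # layer id per entry
                            [0, 1, 2, 0]])  # node id per entry
layer_index[1] = layer_index[1] + cum
print(layer_index)  # row 0 unchanged; row 1 becomes tensor([3, 4, 5, 3])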