def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, pos=None, norm=None, face=None, **kwargs):
    """Create a graph data object.

    All positional arguments are optional tensors describing the graph
    (node features, connectivity, targets, positions, normals, faces).
    Extra keyword arguments become additional attributes; ``num_nodes``
    is special-cased and stored as ``__num_nodes__``.

    Raises:
        ValueError: if ``edge_index`` or ``face`` is given with a dtype
            other than ``torch.long``.
    """
    self.x, self.edge_index, self.edge_attr = x, edge_index, edge_attr
    self.y, self.pos, self.norm, self.face = y, pos, norm, face

    for name, value in kwargs.items():
        if name == 'num_nodes':
            # Cached node count lives under a dunder-style attribute.
            self.__num_nodes__ = value
        else:
            self[name] = value

    # Index tensors must be integral; validate `edge_index` first, then `face`.
    for arg_name, tensor in (('edge_index', edge_index), ('face', face)):
        if tensor is not None and tensor.dtype != torch.long:
            raise ValueError(
                f'Argument `{arg_name}` needs to be of type `torch.long` but '
                f'found type `{tensor.dtype}`.')

    if torch_geometric.is_debug_enabled():
        self.debug()
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`."""
    # NOTE(review): `follow_batch` has a mutable default ([]) and is never
    # read in this body — harmless, but worth cleaning up.
    # Every entry must be a MultiScaleData with the same number of scales,
    # otherwise the per-scale collation below would be ragged.
    for data in data_list:
        assert isinstance(data, MultiScaleData)
    num_scales = data_list[0].num_scales
    for data_entry in data_list:
        assert data_entry.num_scales == num_scales, "All data objects should contain the same number of scales"

    # Build multiscale batches: one collated batch per scale, gathering the
    # scale-`scale` sub-data of every example.
    multiscale = []
    for scale in range(num_scales):
        ms_scale = []
        for data_entry in data_list:
            ms_scale.append(data_entry.multiscale[scale])
        multiscale.append(from_data_list_token(ms_scale))

    # Create batch from non multiscale data.
    # NOTE(review): this mutates the *input* Data objects — `multiscale` is
    # deleted from each entry so Batch.from_data_list() only sees the flat
    # attributes. Callers reusing `data_list` afterwards lose that field.
    for data_entry in data_list:
        del data_entry.multiscale
    batch = Batch.from_data_list(data_list)
    batch = MultiScaleBatch.from_data(batch)
    batch.multiscale = multiscale

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch
def from_data_list_token(data_list, follow_batch=[]):
    """Collate a list of data objects into a single batch.

    This is essentially a copy-paste of PyTorch Geometric's
    ``Batch.from_data_list``, with one difference: negative index values
    are *not* incremented by the running offset (only entries ``>= 0``
    are shifted), so negative entries can serve as sentinels.
    """
    # Union of all attribute keys present across the examples.
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert "batch" not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    # Per-key cumulative sizes, used later to split the batch back apart.
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch["{}_batch".format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            if torch.is_tensor(item) and item.dtype != torch.bool:
                # Shift only non-negative entries by the running offset.
                # NOTE(review): this writes through `item` and therefore
                # mutates the tensors of the *input* data objects in place.
                mask = item >= 0
                item[mask] = item[mask] + cumsum[key]
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, data[key]))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] += data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                # Assignment vector mapping each row of `key` to example `i`.
                item = torch.full((size,), i, dtype=torch.long)
                batch["{}_batch".format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes,), i, dtype=torch.long)
            batch.batch.append(item)

    # `num_nodes` here is the value from the *last* example; assumes all
    # examples agree on whether node counts are known.
    if num_nodes is None:
        batch.batch = None

    # Concatenate each attribute list into a single tensor.
    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError("Unsupported attribute type {} : {}".format(type(item), item))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def __produce_bipartite_data_flow__(self, n_id):
    r"""Produces a :obj:`DataFlow` object with a bipartite assignment
    matrix for a given mini-batch :obj:`n_id`."""
    data_flow = DataFlow(n_id, self.flow)

    for l in range(self.num_hops):
        # Sample `size[l]` neighbor edges for the current frontier `n_id`.
        e_id = neighbor_sampler(n_id, self.cumdeg, self.size[l])
        new_n_id = self.edge_index_j.index_select(0, e_id)
        e_id = self.e_assoc[e_id]

        if self.add_self_loops:
            # Keep the seed nodes in the next frontier; `inv` lets us
            # recover their positions after deduplication.
            new_n_id = torch.cat([new_n_id, n_id], dim=0)
            new_n_id, inv = new_n_id.unique(sorted=False, return_inverse=True)
            res_n_id = inv[-n_id.size(0):]
        else:
            new_n_id = new_n_id.unique(sorted=False)
            res_n_id = None

        # Relabel global node ids to local (per-layer) ids via `self.tmp`.
        # NOTE: `self.tmp` is a scratch buffer reused across calls; the two
        # writes below must stay in this order (target side, then source side).
        edges = [None, None]
        edge_index_i = self.edge_index[self.i, e_id]
        if self.add_self_loops:
            edge_index_i = torch.cat([edge_index_i, n_id], dim=0)
        self.tmp[n_id] = torch.arange(n_id.size(0))
        edges[self.i] = self.tmp[edge_index_i]
        edge_index_j = self.edge_index[self.j, e_id]
        if self.add_self_loops:
            edge_index_j = torch.cat([edge_index_j, n_id], dim=0)
        self.tmp[new_n_id] = torch.arange(new_n_id.size(0))
        edges[self.j] = self.tmp[edge_index_j]
        edge_index = torch.stack(edges, dim=0)

        e_id = self.e_id[e_id]
        if self.add_self_loops:
            if self.edge_index_loop.size(1) == self.data.num_nodes:
                # Only set `e_id` if all self-loops were initially passed
                # to the graph.
                e_id = torch.cat([e_id, self.e_id_loop[n_id]])
            else:
                e_id = None
                if torch_geometric.is_debug_enabled():
                    # NOTE(review): missing space between 'DataFlow' and
                    # 'object' in the emitted warning text (string
                    # concatenation bug) — left as-is here.
                    warnings.warn(
                        ('Could not add edge identifiers to the DataFlow'
                         'object due to missing initial self-loops. '
                         'Please make sure that your graph already '
                         'contains self-loops in case you want to use '
                         'edge-conditioned operators.'))

        n_id = new_n_id
        data_flow.append(n_id, res_n_id, e_id, edge_index)

    return data_flow
def from_dict(cls, dictionary):
    r"""Creates a data object whose attributes are taken from a python
    dictionary (key/value pairs become attribute name/value)."""
    instance = cls()
    for name, value in dictionary.items():
        instance[name] = value
    if torch_geometric.is_debug_enabled():
        instance.debug()
    return instance
def forward(self, x, edge_index, pseudo, edge_weight=None):
    """"""
    # Local import so debug-mode validation degrades gracefully when
    # torch_geometric is not installed.
    try:
        import torch_geometric
    except ImportError:
        # for debug mode only
        torch_geometric = None

    # Promote 1-D inputs to 2-D (single feature / coordinate channel).
    x = x.unsqueeze(-1) if x.dim() == 1 else x
    pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo

    if not x.is_cuda:
        warnings.warn(
            "We do not recommend using the non-optimized CPU "
            "version of SplineConv. If possible, please convert "
            "your data to the GPU."
        )

    # Expensive sanity checks, only when debug mode is enabled.
    if torch_geometric is not None and torch_geometric.is_debug_enabled():
        if x.size(1) != self.in_channels:
            raise RuntimeError(
                "Expected {} node features, but found {}".format(
                    self.in_channels, x.size(1)
                )
            )
        if pseudo.size(1) != self.dim:
            raise RuntimeError(
                (
                    "Expected pseudo-coordinate dimensionality of {}, but "
                    "found {}"
                ).format(self.dim, pseudo.size(1))
            )
        min_index, max_index = edge_index.min(), edge_index.max()
        if min_index < 0 or max_index > x.size(0) - 1:
            raise RuntimeError(
                (
                    "Edge indices must lay in the interval [0, {}]"
                    " but found them in the interval [{}, {}]"
                ).format(x.size(0) - 1, min_index, max_index)
            )
        # Spline basis expects pseudo-coordinates normalized to [0, 1].
        min_pseudo, max_pseudo = pseudo.min(), pseudo.max()
        if min_pseudo < 0 or max_pseudo > 1:
            raise RuntimeError(
                (
                    "Pseudo-coordinates must lay in the fixed interval [0, 1]"
                    " but found them in the interval [{}, {}]"
                ).format(min_pseudo, max_pseudo)
            )

    # Map features to the tangent space at the origin before propagating.
    # assumes `self.ball_in` is a hyperbolic manifold object — TODO confirm.
    x = self.ball_in.logmap0(x)
    return self.propagate(edge_index, x=x, pseudo=pseudo, edge_weight=edge_weight)
def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, pos=None, norm=None, face=None, **kwargs):
    """Create a graph data object.

    All positional arguments are optional graph attributes. Extra keyword
    arguments become additional attributes; ``num_nodes`` is special-cased
    and stored as ``__num_nodes__``.
    """
    self.x, self.edge_index, self.edge_attr = x, edge_index, edge_attr
    self.y, self.pos, self.norm, self.face = y, pos, norm, face

    for name, value in kwargs.items():
        if name == 'num_nodes':
            # Cached node count lives under a dunder-style attribute.
            self.__num_nodes__ = value
        else:
            self[name] = value

    if torch_geometric.is_debug_enabled():
        self.debug()
def __init__(self, img=None, y=None, indices=None, mask=None, **kwargs):
    """Create an image-based graph data object.

    ``indices`` selects the pixels/points that act as nodes; the node
    count is derived from its first dimension when given. Extra keyword
    arguments become additional attributes; ``num_nodes`` is special-cased
    and stored as ``__num_nodes__``.
    """
    self.img = img
    self.y = y
    self.indices = indices
    self.mask = mask
    # One node per selected index; unknown when `indices` is absent.
    self.num_nodes = indices.shape[0] if indices is not None else None

    for name, value in kwargs.items():
        if name == 'num_nodes':
            self.__num_nodes__ = value
        else:
            self[name] = value

    if torch_geometric.is_debug_enabled():
        self.debug()
def forward(self, x, edge_index, pseudo):
    """"""
    # Promote 1-D inputs to 2-D (single feature / coordinate channel).
    x = x.unsqueeze(-1) if x.dim() == 1 else x
    pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo

    if not x.is_cuda:
        warnings.warn('We do not recommend using the non-optimized CPU '
                      'version of SplineConv. If possible, please convert '
                      'your data to the GPU.')

    # Expensive sanity checks, only when debug mode is enabled.
    if torch_geometric.is_debug_enabled():
        if x.size(1) != self.in_channels:
            raise RuntimeError(
                'Expected {} node features, but found {}'.format(
                    self.in_channels, x.size(1)))

        if pseudo.size(1) != self.dim:
            raise RuntimeError(
                ('Expected pseudo-coordinate dimensionality of {}, but '
                 'found {}').format(self.dim, pseudo.size(1)))

        min_index, max_index = edge_index.min(), edge_index.max()
        if min_index < 0 or max_index > x.size(0) - 1:
            raise RuntimeError(
                ('Edge indices must lay in the interval [0, {}]'
                 ' but found them in the interval [{}, {}]').format(
                     x.size(0) - 1, min_index, max_index))

        # Spline basis expects pseudo-coordinates normalized to [0, 1].
        min_pseudo, max_pseudo = pseudo.min(), pseudo.max()
        if min_pseudo < 0 or max_pseudo > 1:
            raise RuntimeError(
                ('Pseudo-coordinates must lay in the fixed interval [0, 1]'
                 ' but found them in the interval [{}, {}]').format(
                     min_pseudo, max_pseudo))

    return self.propagate(edge_index, x=x, pseudo=pseudo)
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`."""
    # Union of all attribute keys present across the examples.
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = Batch()

    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    # Per-key running offsets for index-valued attributes.
    cumsum = {}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            # Shift by the running offset (0 on first occurrence of `key`).
            item = data[key] + cumsum.get(key, 0)
            if key in cumsum:
                cumsum[key] += data.__inc__(key, item)
            else:
                cumsum[key] = data.__inc__(key, item)
            batch[key].append(item)

        for key in follow_batch:
            # Assignment vector mapping each row of `key` to example `i`.
            size = data[key].size(data.__cat_dim__(key, data[key]))
            item = torch.full((size, ), i, dtype=torch.long)
            batch['{}_batch'.format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long)
            batch.batch.append(item)

    # NOTE(review): `num_nodes` is the value from the *last* example and is
    # undefined for an empty `data_list` (NameError) — assumes non-empty input.
    if num_nodes is None:
        batch.batch = None

    # Concatenate each attribute list into a single tensor.
    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError('Unsupported attribute type.')

    # Copy custom data functions to batch (does not work yet):
    # if data_list.__class__ != Data:
    #     org_funcs = set(Data.__dict__.keys())
    #     funcs = set(data_list[0].__class__.__dict__.keys())
    #     batch.__custom_funcs__ = funcs.difference(org_funcs)
    #     for func in funcs.difference(org_funcs):
    #         setattr(batch, func, getattr(data_list[0], func))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`."""
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = MultiScaleBatch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    cumsum = {key: 0 for key in keys}
    # Per-scale offsets for keys that hold *lists* of index tensors.
    cumsum4list = {key: [] for key in keys}
    batch.batch = []
    # Per-scale assignment vectors (one list of tensors per example).
    batch.list_batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            if torch.is_tensor(item) and item.dtype != torch.bool:
                item = item + cumsum[key]
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, data[key]))
                # Here particular case for this kind of list of tensors
                # process the neighbors
                # NOTE(review): indentation reconstructed from a mangled
                # source; as written, this branch requires `item` to be both
                # a tensor (outer if) and a list (inner check), so the
                # per-scale offsetting below appears unreachable — confirm
                # against the original upstream file.
                if bool(re.search('(neigh|pool|upsample)', key)):
                    if isinstance(item, list):
                        if (len(cumsum4list[key]) == 0):  # for the first time
                            cumsum4list[key] = torch.zeros(len(item), dtype=torch.long)
                        for j in range(len(item)):
                            if (torch.is_tensor(item[j]) and item[j].dtype != torch.bool):
                                # Shift only positive entries per scale.
                                item[j][item[j] > 0] += cumsum4list[key][j]
                            cumsum4list[key][j] += data["{}_size".format(key)][j]
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] += data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                item = torch.full((size, ), i, dtype=torch.long)
                batch['{}_batch'.format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long)
            batch.batch.append(item)

        # indice of the batch at each scale
        if (data.points is not None):
            list_batch = []
            for j in range(len(data['points'])):
                size = len(data.points[j])
                item = torch.full((size, ), i, dtype=torch.long)
                list_batch.append(item)
            batch.list_batch.append(list_batch)

    # `num_nodes` is the value from the last example; assumes all examples
    # agree on whether node counts are known.
    if num_nodes is None:
        batch.batch = None

    # Concatenate each attribute list; list-valued keys are concatenated
    # column-wise (one tensor per scale).
    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        elif isinstance(item, list):
            item = batch[key]
            res = []
            for j in range(len(item[0])):
                col = [f[j] for f in batch[key]]
                res.append(
                    torch.cat(col, dim=data_list[0].__cat_dim__(key, col)))
            batch[key] = res
        else:
            raise ValueError('{} is an Unsupported attribute type'.format(
                type(item)))

    # Copy custom data functions to batch (does not work yet):
    # if data_list.__class__ != Data:
    #     org_funcs = set(Data.__dict__.keys())
    #     funcs = set(data_list[0].__class__.__dict__.keys())
    #     batch.__custom_funcs__ = funcs.difference(org_funcs)
    #     for func in funcs.difference(org_funcs):
    #         setattr(batch, func, getattr(data_list[0], func))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`."""
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = Batch()
    # Pre-seed batch attributes from the first example's instance dict,
    # skipping dunder-style bookkeeping names.
    for key in data_list[0].__dict__.keys():
        if key[:2] != '__' and key[-2:] != '__':
            batch[key] = None

    batch.__data_class__ = data_list[0].__class__
    for key in keys + ['batch']:
        batch[key] = []

    device = None
    slices = {key: [0] for key in keys}
    cumsum = {key: [0] for key in keys}
    cat_dims = {}
    num_nodes_list = []
    for i, data in enumerate(data_list):
        for key in keys:
            item = data[key]

            # Increase values by `cumsum` value.
            cum = cumsum[key][-1]
            if isinstance(item, Tensor) and item.dtype != torch.bool:
                if not isinstance(cum, int) or cum != 0:
                    item = item + cum
            elif isinstance(item, SparseTensor):
                value = item.storage.value()
                if value is not None and value.dtype != torch.bool:
                    if not isinstance(cum, int) or cum != 0:
                        value = value + cum
                    item = item.set_value(value, layout='coo')
            elif isinstance(item, (int, float)):
                item = item + cum

            # Treat 0-dimensional tensors as 1-dimensional.
            if isinstance(item, Tensor) and item.dim() == 0:
                item = item.unsqueeze(0)

            batch[key].append(item)

            # Gather the size of the `cat` dimension.
            size = 1
            cat_dim = data.__cat_dim__(key, data[key])
            cat_dims[key] = cat_dim
            if isinstance(item, Tensor):
                size = item.size(cat_dim)
                device = item.device
            elif isinstance(item, SparseTensor):
                size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]
                device = item.device()

            slices[key].append(size + slices[key][-1])
            inc = data.__inc__(key, item)
            if isinstance(inc, (tuple, list)):
                inc = torch.tensor(inc)
            cumsum[key].append(inc + cumsum[key][-1])

            if key in follow_batch:
                # Multi-dimensional sizes yield one assignment vector per
                # component (`{key}_{j}_batch`); scalars a single one.
                if isinstance(size, Tensor):
                    for j, size in enumerate(size.tolist()):
                        tmp = f'{key}_{j}_batch'
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size, ), i, dtype=torch.long,
                                       device=device))
                else:
                    tmp = f'{key}_batch'
                    batch[tmp] = [] if i == 0 else batch[tmp]
                    batch[tmp].append(
                        torch.full((size, ), i, dtype=torch.long,
                                   device=device))

        if hasattr(data, '__num_nodes__'):
            num_nodes_list.append(data.__num_nodes__)
        else:
            num_nodes_list.append(None)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long,
                              device=device)
            batch.batch.append(item)

    # Fix initial slice values:
    # NOTE(review): `x - x` always evaluates to 0, so this just resets the
    # first slice entry to zero; it also assumes a non-empty `data_list`
    # (it indexes `slices[key][1]`).
    for key in keys:
        slices[key][0] = slices[key][1] - slices[key][1]

    batch.batch = None if len(batch.batch) == 0 else batch.batch
    batch.__slices__ = slices
    batch.__cumsum__ = cumsum
    batch.__cat_dims__ = cat_dims
    batch.__num_nodes_list__ = num_nodes_list

    # Concatenate each attribute list into a single tensor/SparseTensor.
    ref_data = data_list[0]
    for key in batch.keys:
        items = batch[key]
        item = items[0]
        if isinstance(item, Tensor):
            batch[key] = torch.cat(items, ref_data.__cat_dim__(key, item))
        elif isinstance(item, SparseTensor):
            batch[key] = cat(items, ref_data.__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(items)

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`."""
    keys = [set(data.keys) for data in data_list]
    keys = set.union(*keys)
    # `depth_count` is handled separately (summed across examples below).
    keys.remove('depth_count')
    keys = list(keys)
    depth = max(data.depth_count.shape[0] for data in data_list)
    assert 'batch' not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    # Per-depth running node offsets.
    cumsum = {i: 0 for i in range(depth)}
    # NOTE(review): mixes `th` and `torch` aliases — presumably both are
    # imports of PyTorch at module level; verify.
    depth_count = th.zeros((depth, ), dtype=th.long)
    batch.batch = []
    for i, data in enumerate(data_list):
        edges = data['edge_index']
        # Shift edge endpoints at each depth by the nodes accumulated at
        # the previous depth.
        # NOTE(review): `edges[mask] += ...` mutates the input data's
        # `edge_index` tensor in place.
        for d in range(1, depth):
            mask = data.depth_mask == d
            edges[mask] += cumsum[d - 1]
            cumsum[d - 1] += data.depth_count[d - 1].item()
        batch['edge_index'].append(edges)
        depth_count += data['depth_count']
        for key in data.keys:
            if key == 'edge_index' or key == 'depth_count':
                continue
            item = data[key]
            batch[key].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long)
            batch.batch.append(item)

    # `num_nodes` is the value from the last example; assumes non-empty input.
    if num_nodes is None:
        batch.batch = None

    # Concatenate each attribute list into a single tensor.
    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])

    batch.depth_count = depth_count

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`."""
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    for key in keys + ['batch']:
        batch[key] = []

    slices = {key: [0] for key in keys}
    cumsum = {key: [0] for key in keys}
    cat_dims = {}
    num_nodes_list = []
    mm = 0  # NOTE(review): only used by commented-out debug code.
    for i, data in enumerate(data_list):
        for key in keys:
            item = data[key]

            # Increase values by `cumsum` value.
            cum = cumsum[key][-1]
            # DAGNN-specific keys: only the id column/row is offset.
            # NOTE(review): these two branches assign into `item` in place
            # and therefore modify the *input* Data objects — confirm this
            # is safe across epochs.
            if key in ["bi_layer_index", "bi_layer_parent_index"]:
                item[:, 1] = item[:, 1] + cum
            elif key in ["layer_index", "layer_parent_index"]:
                # keep layer dimension 0 and only increase ids in
                # dimension 1 (we merge layers of all data objects)
                item[1] = item[1] + cum
            elif isinstance(item, Tensor) and item.dtype != torch.bool:
                item = item + cum if cum != 0 else item
            elif isinstance(item, SparseTensor):
                value = item.storage.value()
                if value is not None and value.dtype != torch.bool:
                    value = value + cum if cum != 0 else value
                    item = item.set_value(value, layout='coo')
            elif isinstance(item, (int, float)):
                item = item + cum

            # Treat 0-dimensional tensors as 1-dimensional.
            if isinstance(item, Tensor) and item.dim() == 0:
                item = item.unsqueeze(0)

            batch[key].append(item)

            # Gather the size of the `cat` dimension.
            size = 1
            cat_dim = data.__cat_dim__(key, data[key])
            cat_dims[key] = cat_dim
            if isinstance(item, Tensor):
                size = item.size(cat_dim)
            elif isinstance(item, SparseTensor):
                size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]

            slices[key].append(size + slices[key][-1])
            inc = data.__inc__(key, item)
            if isinstance(inc, (tuple, list)):
                inc = torch.tensor(inc)
            cumsum[key].append(inc + cumsum[key][-1])

            if key in follow_batch:
                # One assignment vector per size component, or a single one.
                if isinstance(size, Tensor):
                    for j, size in enumerate(size.tolist()):
                        tmp = f'{key}_{j}_batch'
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size, ), i, dtype=torch.long))
                else:
                    tmp = f'{key}_batch'
                    batch[tmp] = [] if i == 0 else batch[tmp]
                    batch[tmp].append(
                        torch.full((size, ), i, dtype=torch.long))

        if hasattr(data, '__num_nodes__'):
            num_nodes_list.append(data.__num_nodes__)
        else:
            num_nodes_list.append(None)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long)
            batch.batch.append(item)

    # Fix initial slice values:
    # NOTE(review): `x - x` always evaluates to 0; also assumes a
    # non-empty `data_list` (indexes `slices[key][1]`).
    for key in keys:
        slices[key][0] = slices[key][1] - slices[key][1]

    batch.batch = None if len(batch.batch) == 0 else batch.batch
    batch.__slices__ = slices
    batch.__cumsum__ = cumsum
    batch.__cat_dims__ = cat_dims
    batch.__num_nodes_list__ = num_nodes_list

    # Concatenate each attribute list into a single tensor/SparseTensor.
    ref_data = data_list[0]
    for key in batch.keys:
        items = batch[key]
        item = items[0]
        if isinstance(item, Tensor):
            batch[key] = torch.cat(items, ref_data.__cat_dim__(key, item))
        elif isinstance(item, SparseTensor):
            batch[key] = cat(items, ref_data.__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(items)

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`."""
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            # Shift index-valued tensors by the running offset.
            if torch.is_tensor(item) and item.dtype != torch.bool:
                item = item + cumsum[key]
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, data[key]))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] = cumsum[key] + data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                item = torch.full((size,), i, dtype=torch.long)
                batch['{}_batch'.format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes,), i, dtype=torch.long)
            batch.batch.append(item)

    # `num_nodes` is the value from the last example; assumes non-empty input.
    if num_nodes is None:
        batch.batch = None

    # Concatenate each attribute list into a single tensor.
    for key in batch.keys:
        item = batch[key][0]
        logger.debug(f"key = {key}")
        if torch.is_tensor(item):
            logger.debug(f"batch[{key}]")
            logger.debug(f"item.shape = {item.shape}")
            elem = data_list[0]  # type(elem) = Data or ClevrData
            dim_ = elem.__cat_dim__(key, item)  # basically, which dim we want to concat
            batch[key] = torch.cat(batch[key], dim=dim_)
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`.
    Will exclude any keys given in :obj:`exclude_keys`."""
    # Keys are taken from the *first* example only (assumes homogeneity).
    keys = set(data_list[0].keys)
    keys -= {'num_nodes', 'num_edges'}
    keys -= set(exclude_keys)
    keys = sorted(list(keys))
    assert 'batch' not in keys and 'ptr' not in keys

    batch = cls()
    batch.__num_graphs__ = len(data_list)
    batch.__data_class__ = data_list[0].__class__
    for key in keys + ['batch']:
        batch[key] = []
    batch.ptr = [0]

    device = None
    slices = {key: [0] for key in keys}
    cumsum = {key: [0] for key in keys}
    cat_dims = {}
    num_nodes_list = []
    for i, data in enumerate(data_list):
        for key in keys:
            item = data[key]

            # Increase values by `cumsum` value.
            cum = cumsum[key][-1]
            if isinstance(item, Tensor) and item.dtype != torch.bool:
                if not isinstance(cum, int) or cum != 0:
                    item = item + cum
            elif isinstance(item, SparseTensor):
                value = item.storage.value()
                if value is not None and value.dtype != torch.bool:
                    if not isinstance(cum, int) or cum != 0:
                        value = value + cum
                    item = item.set_value(value, layout='coo')
            elif isinstance(item, (int, float)):
                item = item + cum

            # Gather the size of the `cat` dimension.
            size = 1
            cat_dim = data.__cat_dim__(key, data[key])
            # 0-dimensional tensors have no dimension along which to
            # concatenate, so we set `cat_dim` to `None`.
            if isinstance(item, Tensor) and item.dim() == 0:
                cat_dim = None
            cat_dims[key] = cat_dim

            # Add a batch dimension to items whose `cat_dim` is `None`:
            if isinstance(item, Tensor) and cat_dim is None:
                cat_dim = 0  # Concatenate along this new batch dimension.
                item = item.unsqueeze(0)
                device = item.device
            elif isinstance(item, Tensor):
                size = item.size(cat_dim)
                device = item.device
            elif isinstance(item, SparseTensor):
                size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]
                device = item.device()

            batch[key].append(item)  # Append item to the attribute list.

            slices[key].append(size + slices[key][-1])
            inc = data.__inc__(key, item)
            if isinstance(inc, (tuple, list)):
                inc = torch.tensor(inc)
            cumsum[key].append(inc + cumsum[key][-1])

            if key in follow_batch:
                # One assignment vector per size component, or a single one.
                if isinstance(size, Tensor):
                    for j, size in enumerate(size.tolist()):
                        tmp = f'{key}_{j}_batch'
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size, ), i, dtype=torch.long,
                                       device=device))
                else:
                    tmp = f'{key}_batch'
                    batch[tmp] = [] if i == 0 else batch[tmp]
                    batch[tmp].append(
                        torch.full((size, ), i, dtype=torch.long,
                                   device=device))

        if 'num_nodes' in data:
            num_nodes_list.append(data.num_nodes)
        else:
            num_nodes_list.append(None)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long,
                              device=device)
            batch.batch.append(item)
            # `ptr` records cumulative node counts (CSR-style boundaries).
            batch.ptr.append(batch.ptr[-1] + num_nodes)

    batch.batch = None if len(batch.batch) == 0 else batch.batch
    batch.ptr = None if len(batch.ptr) == 1 else batch.ptr
    batch.__slices__ = slices
    batch.__cumsum__ = cumsum
    batch.__cat_dims__ = cat_dims
    batch.__num_nodes_list__ = num_nodes_list

    # Concatenate each attribute list into a single tensor/SparseTensor.
    ref_data = data_list[0]
    for key in batch.keys:
        items = batch[key]
        item = items[0]
        cat_dim = ref_data.__cat_dim__(key, item)
        cat_dim = 0 if cat_dim is None else cat_dim
        if isinstance(item, Tensor):
            batch[key] = torch.cat(items, cat_dim)
        elif isinstance(item, SparseTensor):
            batch[key] = cat(items, cat_dim)
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(items)

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`.

    Factor-graph specifics: the bipartite graph keeps separate assignment
    vectors for factor nodes (``batch_factors``) and variable nodes
    (``batch_vars``). The ``facStates_to_varIdx`` attribute uses ``-1`` as
    a sentinel; those entries are not offset and are redirected to a
    shared "junk bin" index (``junk_bin_val``) after concatenation.
    """
    # Fix: removed stray `print()` debugging statements that ran for every
    # example of every batch, and pruned dead commented-out debug code.
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = Batch_custom()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    # we have a bipartite graph, so keep track of batches for each set of nodes
    batch.batch_factors = []  # for factor beliefs
    batch.batch_vars = []  # for variable beliefs
    junk_bin_val = 0
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            if torch.is_tensor(item) and item.dtype != torch.bool:
                if key == "facStates_to_varIdx":
                    # Clone first: without this we edit the data for the
                    # next epoch, causing errors. Only non-sentinel (!= -1)
                    # entries are shifted by the running offset.
                    item = item.clone()
                    item[torch.where(item != -1)] = item[torch.where(item != -1)] + cumsum[key]
                else:
                    item = item + cumsum[key]
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, data[key]))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            if key == "facStates_to_varIdx":
                # Track the total increment so the junk bin lands one past
                # the last valid index of the concatenated tensor.
                facStates_to_varIdx_inc = data.__inc__(key, item)
                junk_bin_val += facStates_to_varIdx_inc
                cumsum[key] += facStates_to_varIdx_inc
            else:
                cumsum[key] += data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                item = torch.full((size, ), i, dtype=torch.long)
                batch['{}_batch'.format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long)
            batch.batch.append(item)
            batch.batch_factors.append(torch.full((data.numFactors, ), i, dtype=torch.long))
            batch.batch_vars.append(torch.full((data.numVars, ), i, dtype=torch.long))

    # `num_nodes` is the value from the last example; assumes non-empty input.
    if num_nodes is None:
        batch.batch = None

    # Concatenate each attribute list into a single tensor.
    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key], dim=data_list[0].__cat_dim__(key, item))
            if key == "facStates_to_varIdx":
                # Redirect all sentinel entries to the shared junk bin.
                batch[key][torch.where(batch[key] == -1)] = junk_bin_val
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError('Unsupported attribute type')

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def test_debug():
    # Debug mode is disabled by default.
    assert is_debug_enabled() is False

    # Plain on/off toggling via set_debug().
    set_debug(True)
    assert is_debug_enabled() is True
    set_debug(False)
    assert is_debug_enabled() is False

    assert is_debug_enabled() is False
    # set_debug() also acts as a context manager that restores the
    # previous (disabled) state on exit.
    with set_debug(True):
        assert is_debug_enabled() is True
    assert is_debug_enabled() is False

    assert is_debug_enabled() is False
    # The context manager restores a previously *enabled* state as well.
    set_debug(True)
    assert is_debug_enabled() is True
    with set_debug(False):
        assert is_debug_enabled() is False
    assert is_debug_enabled() is True
    set_debug(False)
    assert is_debug_enabled() is False

    assert is_debug_enabled() is False
    # debug() is the shorthand context manager for temporary enabling.
    with debug():
        assert is_debug_enabled() is True
    assert is_debug_enabled() is False