def list_to_batch(data):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.

    Every attribute name listed in ``data[0].keys`` is gathered from all
    samples and concatenated along dimension 0 with :func:`torch.cat`.

    Args:
        data (list): Non-empty list of sample objects supporting ``obj.keys``
            and ``obj[key]`` access. ``data[0]`` is indexed unconditionally,
            so an empty list raises ``IndexError``.

    Returns:
        SimpleBatch: The populated batch, after ``.contiguous()``.
    """
    batch = SimpleBatch()
    # NOTE(review): this records the class of the *list* itself (``list``),
    # not of its elements; presumably ``data[0].__class__`` was intended.
    # Kept unchanged -- confirm what SimpleBatch expects before changing.
    batch.__data_class__ = data.__class__
    for key in data[0].keys:
        # All keys (including 'y' and 'grid_size') are combined the same way,
        # so no per-key branching is needed.
        batch[key] = torch.cat([sample[key] for sample in data])
    return batch.contiguous()
def batch_to_batch3(data, l, j):
    r"""Builds a two-sample batch from rows ``0`` and ``j`` of an existing batch.

    For the keys ``'y'`` and ``'grid_size'`` the rows ``[0, j]`` are taken
    whole. For every other key the tensor is indexed as ``item[row, l, :]``
    for ``row`` in ``(0, j)``, and the two results are stacked along a new
    leading dimension, so ``l`` selects positions along dimension 1.

    Args:
        data: Batch-like object exposing ``data.keys`` and ``data[key]``.
            Tensors for keys other than 'y'/'grid_size' must have at least
            three dimensions.
        l: Index (int, slice, list or tensor) applied to dimension 1.
        j: Row index of the second sample to keep; the first is always row 0.

    Returns:
        SimpleBatch: The two-sample batch, after ``.contiguous()``.
    """
    batch = SimpleBatch()
    batch.__data_class__ = data.__class__
    for key in data.keys:
        item = data[key]
        if key in ('y', 'grid_size'):
            batch[key] = item[[0, j]]
        else:
            # Same result as concatenating the two rows after unsqueezing
            # each along dim 0.
            batch[key] = torch.stack((item[0, l, :], item[j, l, :]), dim=0)
    return batch.contiguous()
def batch_to_batch2(data, random):
    r"""Builds a two-sample batch from rows ``0`` and ``1`` of an existing batch.

    The positions kept along dimension 1 come from ``get_list_random``. The
    slicing is exactly that of :func:`batch_to_batch3` with ``j=1``.

    Args:
        data: Batch-like object exposing ``data.keys`` and ``data[key]``. It
            must contain an ``'x'`` entry; ``len(data['x'][0])`` (the size of
            dimension 1 of ``x``) is passed to ``get_list_random``.
        random: Forwarded as the first argument of ``get_list_random``; its
            meaning is defined there.

    Returns:
        tuple: ``(batch, l1)``. ``batch`` is the contiguous two-sample batch;
        ``l1`` is the dimension-1 selection that was applied.
    """
    # get_list_random yields three selections; only the first is used here.
    l1, _, _ = get_list_random(random, len(data['x'][0]))
    return batch_to_batch3(data, l1, 1), l1