示例#1
0
def list_to_batch(data):
    r"""Construct a batch object by concatenating a list of data objects.

    Every key present on ``data[0]`` is gathered from all elements of
    ``data`` and joined with ``torch.cat`` along dimension 0. No new batch
    dimension is inserted, so each per-sample tensor presumably already
    carries a leading dimension to concatenate on (TODO confirm with callers).

    Args:
        data: non-empty python list of :class:`torch_geometric.data.Data`-like
            objects exposing ``keys`` and item access by key. Every element
            must provide every key that ``data[0]`` has.

    Returns:
        SimpleBatch: the concatenated batch, made contiguous.

    Raises:
        IndexError: if ``data`` is empty.
    """
    batch = SimpleBatch()
    # Record the class of the individual samples; ``data.__class__`` would
    # always be ``list`` here, which is not a usable data class.
    batch.__data_class__ = data[0].__class__

    for key in data[0].keys:
        batch[key] = torch.cat([sample[key] for sample in data])
    return batch.contiguous()
示例#2
0
def batch_to_batch3(data, l, j):
    r"""Build a two-sample batch from samples ``0`` and ``j`` of ``data``.

    For ``y`` and ``grid_size``, rows ``0`` and ``j`` are copied as-is. For
    every other key, those two samples are additionally restricted to the
    positions selected by ``l`` along the second dimension.

    Args:
        data: batch-like object exposing ``keys`` and item access by key
            (presumably a :class:`SimpleBatch` — confirm). Tensors for keys
            other than ``y``/``grid_size`` must support ``item[sample, l, :]``,
            i.e. be at least 3-D (assumed (batch, points, features), TODO
            confirm).
        l: any index accepted by tensor indexing, applied to the second
            dimension of the non-``y``/``grid_size`` tensors.
        j: index of the second sample to keep; sample ``0`` is always first.

    Returns:
        SimpleBatch: the two-sample batch, made contiguous.
    """
    batch = SimpleBatch()
    # NOTE(review): ``data`` is itself a batch, so this records the batch's
    # class rather than a per-sample class — confirm this is intended.
    batch.__data_class__ = data.__class__

    for key in data.keys:
        item = data[key]
        if key in ('y', 'grid_size'):
            batch[key] = item[[0, j]]
        else:
            # stack is cat of unsqueezed slices: a new leading dimension of
            # size 2 holding sample 0 followed by sample j.
            batch[key] = torch.stack((item[0, l, :], item[j, l, :]), dim=0)
    return batch.contiguous()
示例#3
0
def batch_to_batch2(data,random):
    r"""Build a two-sample batch from samples ``0`` and ``1`` of ``data``.

    An index selection ``l1`` is obtained from
    ``get_list_random(random, len(data['x'][0]))``. For ``y`` and
    ``grid_size``, rows ``0`` and ``1`` are copied as-is. For every other
    key, those two samples are also restricted to ``l1`` along the second
    dimension.

    Args:
        data: batch-like object exposing ``keys`` and item access by key
            (presumably a :class:`SimpleBatch` — confirm). It must contain
            ``'x'``. Tensors for keys other than ``y``/``grid_size`` must
            support ``item[sample, l1, :]``, i.e. be at least 3-D (assumed
            (batch, points, features), TODO confirm).
        random: forwarded unchanged to ``get_list_random``; its meaning is
            defined there (note: the name shadows the stdlib ``random``
            module inside this function).

    Returns:
        tuple: ``(batch, l1)`` — the two-sample batch (made contiguous) and
        the index selection that was applied.
    """
    batch = SimpleBatch()
    # NOTE(review): ``data`` is itself a batch, so this records the batch's
    # class rather than a per-sample class — confirm this is intended.
    batch.__data_class__ = data.__class__

    # get_list_random returns three selections; only the first is used.
    l1, _, _ = get_list_random(random, len(data['x'][0]))

    for key in data.keys:
        item = data[key]
        if key in ('y', 'grid_size'):
            batch[key] = item[[0, 1]]
        else:
            # stack is cat of unsqueezed slices: a new leading dimension of
            # size 2 holding sample 0 followed by sample 1.
            batch[key] = torch.stack((item[0, l1, :], item[1, l1, :]), dim=0)
    return batch.contiguous(), l1