Пример #1
0
def batch_to_batch3(data,l,j):
    r"""Build a two-sample :class:`SimpleBatch` out of an existing batch.

    Samples ``0`` and ``j`` (along axis 0) are taken from ``data``:

    * for the keys ``'y'`` and ``'grid_size'`` the whole per-sample entry is
      kept (``item[[0, j]]``);
    * for every other key only the entries selected by ``l`` along axis 1 are
      kept (``item[0, l, :]`` and ``item[j, l, :]``). This requires those
      tensors to have at least three dimensions.

    Args:
        data: batched object exposing ``data.keys`` and ``data[key]``.
            Presumably a dense ``SimpleBatch`` whose per-point tensors are
            ``(B, N, C)``; TODO confirm against callers.
        l: index (int, slice, list or index tensor) applied to axis 1 of the
            non-``y``/``grid_size`` tensors.
        j: index of the second sample to keep along axis 0.

    Returns:
        SimpleBatch: contiguous batch whose entries have size 2 along axis 0.
    """
    batch = SimpleBatch()
    # NOTE(review): this records the class of ``data`` itself. If ``data`` is
    # already a batch, this is the batch class rather than the underlying
    # ``Data`` class. Confirm this is intended.
    batch.__data_class__ = data.__class__

    for key in data.keys:
        item = data[key]
        if key in ('y', 'grid_size'):
            # Per-sample entries: keep samples 0 and j unchanged.
            batch[key] = item[[0, j]]
        else:
            # Restrict samples 0 and j to the entries selected by ``l`` on
            # axis 1, then restack them along a new leading axis.
            first = torch.unsqueeze(item[0, l, :], 0)
            other = torch.unsqueeze(item[j, l, :], 0)
            batch[key] = torch.cat((first, other), dim=0)
    return batch.contiguous()
Пример #2
0
def list_to_batch(data):
    r"""Concatenate a list of batch-like objects into one :class:`SimpleBatch`.

    For every key present on ``data[0]``, the tensors ``data[i][key]`` of all
    elements are concatenated along dimension 0 with :func:`torch.cat`.

    Args:
        data: non-empty sequence of objects exposing ``.keys`` and
            ``obj[key]`` with tensor values. Every element is assumed to carry
            all keys of ``data[0]`` (TODO confirm against callers).

    Returns:
        SimpleBatch: contiguous batch holding the concatenated tensors.

    Raises:
        IndexError: if ``data`` is empty.
    """
    first = data[0]
    batch = SimpleBatch()
    # Record the class of the elements, not of the container. ``data`` is a
    # Python list, so ``data.__class__`` would always be ``list``.
    # ``SimpleBatch.from_data_list`` likewise uses ``data_list[0].__class__``.
    batch.__data_class__ = first.__class__

    for key in first.keys:
        batch[key] = torch.cat([element[key] for element in data])
    return batch.contiguous()
    def test_fromlist(self):
        # Two copies of one 100-point sample should be stacked along a new
        # leading batch dimension of size 2.
        num_points = 100
        pos = torch.randn((num_points, 3))
        # Every point carries the same label row [0., 1., ..., 9.].
        y = torch.arange(10, dtype=torch.float).repeat(pos.shape[0], 1)
        sample = Data(pos=pos, y=y)

        batched = SimpleBatch.from_data_list([sample, sample])
        self.assertEqual(batched.pos.size(), (2, 100, 3))
        self.assertEqual(batched.y.size(), (2, 100, 10))
Пример #4
0
def batch_to_batch2(data,random):
    r"""Build a two-sample :class:`SimpleBatch` from samples 0 and 1 of ``data``.

    An index set ``l1`` is obtained from
    ``get_list_random(random, len(data['x'][0]))``. Only the first of its
    three return values is used.

    * For the keys ``'y'`` and ``'grid_size'`` the whole per-sample entry of
      samples 0 and 1 is kept.
    * For every other key only the entries selected by ``l1`` along axis 1
      are kept.

    Args:
        data: batched object exposing ``data.keys`` and ``data[key]``. It must
            contain an ``'x'`` entry. Presumably a dense ``SimpleBatch`` with
            per-point tensors ``(B, N, C)``; TODO confirm.
        random: forwarded unchanged to ``get_list_random``; its meaning
            (seed, flag, ...) is defined there. NOTE: the name shadows the
            stdlib ``random`` module inside this function.

    Returns:
        tuple: ``(batch, l1)``, where ``batch`` is the contiguous two-sample
        batch and ``l1`` the indices used on axis 1.
    """
    batch = SimpleBatch()
    # NOTE(review): records the class of ``data`` itself (a batch class if
    # ``data`` is already a batch). Confirm this is intended.
    batch.__data_class__ = data.__class__

    # get_list_random yields three values; only the first is used below.
    l1, _, _ = get_list_random(random, len(data['x'][0]))

    for key in data.keys:
        item = data[key]
        if key in ('y', 'grid_size'):
            # Per-sample entries: keep samples 0 and 1 unchanged.
            batch[key] = item[[0, 1]]
        else:
            # Restrict samples 0 and 1 to the selected entries on axis 1.
            first = torch.unsqueeze(item[0, l1, :], 0)
            second = torch.unsqueeze(item[1, l1, :], 0)
            batch[key] = torch.cat((first, second), dim=0)
    return batch.contiguous(), l1
Пример #5
0
    def _get_collate_function(conv_type, is_multiscale):
        """Return the collate function matching ``conv_type``.

        Args:
            conv_type: convolution format name. For the multiscale check it is
                compared case-insensitively with
                ``ConvolutionFormat.PARTIAL_DENSE.value``. For the dense check
                it is passed to ``ConvolutionFormatFactory.check_is_dense_format``.
            is_multiscale: whether a multiscale transform is active.

        Returns:
            callable: ``MultiScaleBatch.from_data_list`` for multiscale
            partial-dense data, ``SimpleBatch.from_data_list`` for dense
            formats, and ``torch_geometric``'s ``Batch.from_data_list``
            otherwise.

        Raises:
            NotImplementedError: if ``is_multiscale`` is set for any format
                other than partial dense.
        """
        # The ``from_data_list`` callables are returned directly instead of
        # being wrapped in lambdas. Lambdas cannot be pickled, which breaks
        # their use as a DataLoader ``collate_fn`` when worker processes are
        # started with the "spawn" method.
        if is_multiscale:
            if conv_type.lower() != ConvolutionFormat.PARTIAL_DENSE.value.lower():
                raise NotImplementedError(
                    "MultiscaleTransform is activated and supported only for partial_dense format"
                )
            return MultiScaleBatch.from_data_list

        if ConvolutionFormatFactory.check_is_dense_format(conv_type):
            return SimpleBatch.from_data_list
        return torch_geometric.data.batch.Batch.from_data_list
Пример #6
0
 def __getitem__(self, index):
     """Collate the whole of ``self.datalist`` into one ``SimpleBatch``.

     ``index`` is not used: every call returns a batch built from the full
     ``self.datalist`` via ``SimpleBatch.from_data_list``.
     """
     collate = SimpleBatch.from_data_list
     return collate(self.datalist)
# --- PARTIAL DENSE FORMAT (tail of the example) ---
# NOTE(review): ``data`` is built before this snippet and is not visible here.
# The expected-output comment below suggests a ``Batch`` of 1000 points with a
# flat ``batch`` vector. Confirm against the full example.
print(data)
# Expected output: Batch(batch=[1000], pos=[1000, 3], x=[1000, 3])

# Build the model configured for the PARTIAL_DENSE convolution format.
pointnet = PointNet(OmegaConf.create({'conv_type': 'PARTIAL_DENSE'}))

# Feed the batch ("cpu" presumably selects the device) and run forward.
pointnet.set_input(data, "cpu")
data_out = pointnet.forward()
print(data_out.shape)
# Expected output: torch.Size([1000, 4])

##################### DENSE FORMAT #####################

num_points = 500
# NOTE(review): num_classes is not used by any line of this snippet.
num_classes = 10
input_nc = 3

# One random sample: 3-D positions plus ``input_nc`` feature channels per point.
pos = torch.randn((num_points, 3))
x = torch.randn((num_points, input_nc))

# Stack two copies of the sample into a dense batch (leading dimension 2,
# per the expected output below).
data = Data(pos=pos, x=x)
data = SimpleBatch.from_data_list([data, data])

print(data)
# Expected output: SimpleBatch(pos=[2, 500, 3], x=[2, 500, 3])

# Same model, configured for the DENSE convolution format.
pointnet = PointNet(OmegaConf.create({'conv_type': 'DENSE'}))

pointnet.set_input(data, "cpu")
data_out = pointnet.forward()
print(data_out.shape)
# Expected output: torch.Size([2, 500, 4])