def batch_to_batch3(data, l, j):
    r"""Build a two-sample :class:`SimpleBatch` out of an existing batch.

    For every key of ``data`` the entries belonging to sample ``0`` and
    sample ``j`` are kept and stored in a fresh batch.

    Args:
        data: Batch-like object exposing a ``keys`` attribute and
            ``data[key]`` lookup.
        l: Index applied to the second dimension of every item that is not
            ``'y'`` or ``'grid_size'`` (presumably a list or tensor of point
            indices -- confirm against callers).
        j: Position, along the first dimension, of the second sample to keep.

    Returns:
        Whatever ``SimpleBatch.contiguous()`` returns for the new batch.
    """
    batch = SimpleBatch()
    batch.__data_class__ = data.__class__
    # ``keys`` is read as an attribute here, not called.
    for key in data.keys:
        item = data[key]
        if key in ('y', 'grid_size'):
            # These entries are only sub-selected along the first dimension.
            batch[key] = item[[0, j]]
        else:
            # Pick the ``l`` entries of samples 0 and j, then put the two
            # selections next to each other along a new leading dimension.
            # NOTE(review): assumes items are at least 3-dimensional, e.g.
            # (batch, points, channels) -- confirm.
            batch[key] = torch.stack((item[0, l, :], item[j, l, :]), dim=0)
    return batch.contiguous()
def list_to_batch(data):
    r"""Concatenate a sequence of data objects into one :class:`SimpleBatch`.

    For each key of the first element, the values of all elements are
    concatenated with :func:`torch.cat` (along its default dimension) and
    stored under the same key in a new batch.

    Args:
        data: Non-empty, indexable sequence of objects that expose a ``keys``
            attribute and ``obj[key]`` lookup returning tensors.

    Returns:
        Whatever ``SimpleBatch.contiguous()`` returns for the new batch.

    Raises:
        IndexError: If ``data`` is an empty list, since the keys are taken
            from ``data[0]``.
    """
    batch = SimpleBatch()
    # NOTE(review): this records the class of the container ``data`` itself,
    # not the class of its elements -- confirm this is intended.
    batch.__data_class__ = data.__class__
    # ``keys`` is read as an attribute here, not called.
    for key in data[0].keys:
        # Every key is handled the same way: plain concatenation.
        batch[key] = torch.cat([sample[key] for sample in data])
    return batch.contiguous()
def test_fromlist(self):
    """Batch two copies of one sample and check the resulting tensor sizes."""
    nb_points = 100
    pos = torch.randn((nb_points, 3))

    # One row ``0..9`` per point, stored as floats.
    target_row = list(range(10))
    y = torch.tensor([target_row] * pos.shape[0], dtype=torch.float)

    sample = Data(pos=pos, y=y)
    batch = SimpleBatch.from_data_list([sample, sample])

    expected_sizes = (
        (batch.pos, (2, 100, 3)),
        (batch.y, (2, 100, 10)),
    )
    for tensor, size in expected_sizes:
        self.assertEqual(tensor.size(), size)
def batch_to_batch2(data, random):
    r"""Build a two-sample :class:`SimpleBatch` from the first two samples.

    An index ``l1`` is obtained from ``get_list_random`` and used to
    sub-select the second dimension of every item except ``'y'`` and
    ``'grid_size'``; only samples ``0`` and ``1`` of ``data`` are kept.

    Args:
        data: Batch-like object exposing a ``keys`` attribute and
            ``data[key]`` lookup; it must contain an ``'x'`` entry.
        random: Forwarded unchanged as the first argument of
            ``get_list_random``. (The name shadows the stdlib module inside
            this function; it is kept for backward compatibility.)

    Returns:
        A tuple ``(batch, l1)``: the result of ``SimpleBatch.contiguous()``
        and the index that was used for the sub-selection.
    """
    batch = SimpleBatch()
    batch.__data_class__ = data.__class__
    # ``get_list_random`` yields three values; only the first is used here.
    l1, _, _ = get_list_random(random, len(data['x'][0]))
    # ``keys`` is read as an attribute here, not called.
    for key in data.keys:
        item = data[key]
        if key in ('y', 'grid_size'):
            # These entries are only sub-selected along the first dimension.
            batch[key] = item[[0, 1]]
        else:
            # Pick the ``l1`` entries of samples 0 and 1, then put the two
            # selections next to each other along a new leading dimension.
            # NOTE(review): assumes items are at least 3-dimensional, e.g.
            # (batch, points, channels) -- confirm.
            batch[key] = torch.stack((item[0, l1, :], item[1, l1, :]), dim=0)
    return batch.contiguous(), l1
def _get_collate_function(conv_type, is_multiscale):
    """Pick the callable that turns a list of samples into a batch.

    Args:
        conv_type: Name of the convolution format (compared
            case-insensitively in the multiscale case).
        is_multiscale: Whether multiscale batching is requested.

    Returns:
        A one-argument callable wrapping the ``from_data_list`` constructor
        of ``MultiScaleBatch``, ``SimpleBatch`` or the torch_geometric
        ``Batch``.

    Raises:
        NotImplementedError: If ``is_multiscale`` is set and ``conv_type``
            is not the partial-dense format.
    """
    if not is_multiscale:
        # Single-scale: the choice depends only on dense vs. non-dense.
        if ConvolutionFormatFactory.check_is_dense_format(conv_type):
            return lambda datalist: SimpleBatch.from_data_list(datalist)
        return lambda datalist: torch_geometric.data.batch.Batch.from_data_list(datalist)

    # Multiscale batching exists for the partial-dense format only.
    requested_format = conv_type.lower()
    partial_dense_format = ConvolutionFormat.PARTIAL_DENSE.value.lower()
    if requested_format != partial_dense_format:
        raise NotImplementedError(
            "MultiscaleTransform is activated and supported only for partial_dense format"
        )
    return lambda datalist: MultiScaleBatch.from_data_list(datalist)
def __getitem__(self, index):
    """Return ``self.datalist`` collated into a single ``SimpleBatch``.

    ``index`` is accepted but not used: every index yields a batch built
    from the full ``self.datalist``.
    """
    samples = self.datalist
    return SimpleBatch.from_data_list(samples)
print(data) #Batch(batch=[1000], pos=[1000, 3], x=[1000, 3]) pointnet = PointNet(OmegaConf.create({'conv_type': 'PARTIAL_DENSE'})) pointnet.set_input(data, "cpu") data_out = pointnet.forward() print(data_out.shape) # torch.Size([1000, 4]) ##################### DENSE FORMAT ##################### num_points = 500 num_classes = 10 input_nc = 3 pos = torch.randn((num_points, 3)) x = torch.randn((num_points, input_nc)) data = Data(pos=pos, x=x) data = SimpleBatch.from_data_list([data, data]) print(data) #SimpleBatch(pos=[2, 500, 3], x=[2, 500, 3]) pointnet = PointNet(OmegaConf.create({'conv_type': 'DENSE'})) pointnet.set_input(data, "cpu") data_out = pointnet.forward() print(data_out.shape) #torch.Size([2, 500, 4])