예제 #1
0
    def test_empty(self):
        """Round-trip an all-zero tensor through sparse form and back.

        A (4, 1, 34, 34) tensor of zeros is converted with ``ME.to_sparse``
        and then densified again; the result must have the input's shape.
        """
        zeros = torch.zeros(4, 1, 34, 34)
        expected_shape = zeros.shape

        # The densifier is built with the target shape before sparsifying.
        densify = ME.MinkowskiToDenseTensor(expected_shape)
        restored = densify(ME.to_sparse(zeros))

        self.assertEqual(restored.shape, expected_shape)
예제 #2
0
    def __init__(self, config=None, **kwargs):
        """Set up the dataset, optionally loading 3D MNIST from an HDF5 file.

        The effective configuration is ``default_config()`` of the class,
        updated first with ``config`` and then with ``kwargs``. It is read
        through attribute access, so ``default_config()`` is assumed to
        return a mapping that supports both ``update`` and ``.key`` lookup
        (TODO confirm against the class definition).

        Config entries read here: ``data_root``, ``img_size``, ``split``,
        ``preprocess``, ``data_augmentation``, ``transform`` and
        ``target_transform``.

        Args:
            config: Optional mapping of configuration overrides. ``None``
                (the default) means no overrides.
            **kwargs: Further overrides; applied after ``config``.

        Raises:
            ValueError: If ``data_root`` is None and ``img_size`` is also
                None, or if ``data_root`` is given and ``split`` is not one
                of ``"train"``, ``"valid"`` or ``"test"``.
        """
        self.config = self.__class__.default_config()
        # None replaces the former mutable ``{}`` default argument.
        if config is not None:
            self.config.update(config)
        self.config.update(kwargs)

        if self.config.data_root is None:
            # No data on disk: build an empty dataset of a declared size.
            if self.config.img_size is None:
                raise ValueError(
                    "If data_root not given, the img_size must be specified in the config"
                )
            self.n_images = 0
            self.img_size = self.config.img_size
            self.coords = []
            self.feats = []
            self.n_channels = 1
            self.labels = torch.zeros((0, 1), dtype=torch.long)
            # NOTE(review): ``has_labels`` is only assigned in the HDF5
            # branch below; confirm it is defined elsewhere for this case.

        else:
            # Pick the HDF5 dataset names up front so an unsupported split
            # fails with a clear message instead of an unbound local.
            split = self.config.split
            if split == "train":
                x_key, y_key = "X_train", "y_train"
            elif split in ("valid", "test"):
                # "valid" and "test" both read the test arrays.
                x_key, y_key = "X_test", "y_test"
            else:
                raise ValueError(
                    "Unknown split {!r}; expected 'train', 'valid' or 'test'".format(split)
                )

            # load HDF5 Mnist3d dataset
            dataset_filepath = os.path.join(self.config.data_root, '3Dmnist',
                                            'full_dataset_vectors.h5')
            # ``[:]`` reads the arrays into memory, so the file only needs
            # to stay open for these two reads.
            with h5py.File(dataset_filepath, 'r') as file:
                X = file[x_key][:]
                Y = file[y_key][:]

            self.n_images = int(X.shape[0])
            self.has_labels = True
            self.labels = torch.LongTensor(Y)
            self.img_size = (16, 16, 16)
            self.n_channels = 1
            self.coords = []
            self.feats = []
            # Reshape to (-1, n_channels, 16, 16, 16) before sparsifying.
            images = torch.Tensor(X).float().reshape((
                -1,
                self.n_channels,
            ) + self.img_size)
            images_sparse = ME.to_sparse(images)
            # Store coordinates and features per image index.
            for idx in range(len(images)):
                self.coords.append(images_sparse.coordinates_at(idx))
                self.feats.append(images_sparse.features_at(idx))

        if self.config.preprocess is not None:
            self.coords, self.feats = self.config.preprocess(
                self.coords, self.feats)

        # data augmentation boolean
        self.data_augmentation = self.config.data_augmentation
        if self.data_augmentation:
            # Created empty here; nothing in this method fills it.
            self.augment = []

        # the user can additionally specify a transform in the config
        self.transform = self.config.transform
        self.target_transform = self.config.target_transform