Code example #1
0
 def build_dataset_test(self):
     """Build the un-augmented ShapeNetSeg "TEST" split and a loader over it.

     Settings come from ``self.opt``. Results are stored on the instance:

     * ``self.dataset_test`` -- the "TEST" split for category ``self.opt.cat``
       with ``data_augmentation_Z_rotation`` and ``random_translation`` off.
     * ``self.dataloader_test`` -- a non-shuffled DataLoader using
       ``self.opt.batch_size`` and ``int(self.opt.workers)`` workers.
     * ``self.len_dataset_test`` -- ``len()`` of the dataset (items, not batches).
     """
     opt = self.opt
     # NOTE(review): knn and anisotropic_scaling are not passed, so whatever
     # defaults ShapeNetSeg defines apply here -- confirm that is intended.
     test_set = dataset_shapenet.ShapeNetSeg(
         mode="TEST",
         normalization=opt.normalization,
         class_choice=opt.cat,
         data_augmentation_Z_rotation=False,
         # Same value as the training builder; presumably unused while the
         # Z rotation above is disabled.
         data_augmentation_Z_rotation_range=40,
         npoints=opt.number_points,
         random_translation=False,
     )
     self.dataset_test = test_set
     self.dataloader_test = torch.utils.data.DataLoader(
         test_set,
         batch_size=opt.batch_size,
         shuffle=False,
         num_workers=int(opt.workers),
         # NOTE(review): drop_last=True never yields the trailing partial
         # batch, so up to batch_size - 1 test items are skipped although
         # len_dataset_test still counts them -- confirm this is intended.
         drop_last=True,
     )
     self.len_dataset_test = len(test_set)
Code example #2
0
 def build_dataset_test_for_matching(self):
     """Build the single-shape "TEST" split used for matching at inference.

     Results are stored on the instance:

     * ``self.dataset_test`` -- "TEST" split for ``self.opt.cat`` with
       ``self.opt.number_points_eval`` points, ``get_single_shape=True`` and
       knn, sampling, Z rotation, anisotropic scaling and random translation
       all switched off.
     * ``self.dataloader_test`` -- DataLoader with batch size 1, no shuffling,
       one worker.
     * ``self.len_dataset_test`` -- ``len()`` of the dataset.
     * ``self.parts`` -- ``part_category[self.opt.cat]`` of the *training*
       dataset.

     Raises:
         AttributeError: if ``self.dataset_train`` has not been built yet.
     """
     opt = self.opt
     matching_set = dataset_shapenet.ShapeNetSeg(
         mode="TEST",
         knn=False,
         normalization=opt.normalization,
         class_choice=opt.cat,
         npoints=opt.number_points_eval,
         data_augmentation_Z_rotation=False,
         anisotropic_scaling=False,
         sample=False,
         random_translation=False,
         get_single_shape=True,
     )
     self.dataset_test = matching_set
     # One item per batch, in dataset order.
     self.dataloader_test = torch.utils.data.DataLoader(
         matching_set, batch_size=1, shuffle=False, num_workers=1, drop_last=False
     )
     self.len_dataset_test = len(matching_set)
     # NOTE(review): part labels come from the training dataset, so a train
     # builder (e.g. build_dataset_train, which sets self.dataset_train) must
     # have run before this method.
     self.parts = self.dataset_train.part_category[opt.cat]
Code example #3
0
 def build_dataset_train(self):
     """Build the augmented ShapeNetSeg training split and a shuffled loader.

     Settings come from ``self.opt``. Results are stored on the instance:

     * ``self.dataset_train`` -- split ``self.opt.mode`` for category
       ``self.opt.cat`` with ``data_augmentation_Z_rotation`` (range 40; units
       are defined by ShapeNetSeg) and ``random_translation`` switched on, and
       knn / anisotropic scaling taken from ``self.opt``.
     * ``self.dataloader_train`` -- a shuffled DataLoader using
       ``self.opt.batch_size`` and ``int(self.opt.workers)`` workers.
     * ``self.len_dataset`` -- ``len()`` of the dataset (items, not batches).
     """
     opt = self.opt
     train_set = dataset_shapenet.ShapeNetSeg(
         mode=opt.mode,
         knn=opt.knn,
         num_neighbors=opt.num_neighbors,
         normalization=opt.normalization,
         class_choice=opt.cat,
         data_augmentation_Z_rotation=True,
         data_augmentation_Z_rotation_range=40,
         anisotropic_scaling=opt.anisotropic_scaling,
         npoints=opt.number_points,
         random_translation=True,
     )
     self.dataset_train = train_set
     self.dataloader_train = torch.utils.data.DataLoader(
         train_set,
         batch_size=opt.batch_size,
         shuffle=True,
         num_workers=int(opt.workers),
         # Every yielded batch has exactly batch_size items; the trailing
         # partial batch is skipped.
         drop_last=True,
     )
     self.len_dataset = len(train_set)
Code example #4
0
import extension.chamfer_python as chamf_python

# Chamfer-distance function from extension.chamfer_python, bound to a short
# module-level name (presumably the pure-Python implementation -- confirm).
distChamfer = chamf_python.distChamfer
# ========================================================== #

# Load the Cycle Consistency Model:
# build the trainer's matching datasets and its network, then call
# network.eval() (presumably torch.nn.Module.eval(), i.e. inference mode).
# NOTE(review): this rebinds `trainer` from the name it is looked up on
# (presumably the imported trainer module) to the Trainer instance, so the
# module is no longer reachable under that name afterwards.
trainer = trainer.Trainer(opt)
# Keep this call before build_dataset_test_for_matching(): the version of that
# method shown in this file reads self.dataset_train.part_category, which this
# call presumably sets -- confirm.
trainer.build_dataset_train_for_matching()
trainer.build_dataset_test_for_matching()
trainer.build_network()
trainer.network.eval()

# Load our memorization dataset.
# "MEMORIZE" split built with get_single_shape=True and
# npoints=opt.number_points.
# NOTE(review): class_choice is hard-coded to "Chair" instead of following
# opt.cat -- confirm this is intended.
dataset = dataset_shapenet.ShapeNetSeg(mode="MEMORIZE",
                                       class_choice="Chair",
                                       npoints=opt.number_points,
                                       get_single_shape=True)
# No DataLoader is used; items are read directly by index in
# get_dataset_item() below.
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
#                                          shuffle=False, num_workers=int(opt.workers),
#                                          drop_last=True)
# Number of items in the memorization dataset (despite the "_test" suffix);
# get_dataset_item() uses it for its index bounds check.
len_dataset_test = len(dataset)


def get_dataset_item(i):
    if i >= len_dataset_test or i < 0:
        return None

    elem = dataset[i]

    points = elem[0][:, :3]
    normals = elem[0][:, 3:6]