Example 1
    def _download(self):
        from torch_geometric.data.dataset import files_exist, makedirs

        # Skip the download if both the raw files and the raw split files are
        # already present on disk.
        if files_exist(self.raw_paths) and files_exist(
                self.raw_split_file_paths):  # pragma: no cover
            return

        makedirs(self.raw_dir)
        self.download()
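
For context, the files_exist and makedirs helpers imported above come from torch_geometric; the sketch below only approximates their behavior in older releases and is not the library's exact source.

import os
import os.path as osp

def files_exist(files):
    # True only when the list is non-empty and every listed path exists.
    return len(files) != 0 and all(osp.exists(f) for f in files)

def makedirs(path):
    # Create the directory (and any missing parents) if it is not there yet.
    os.makedirs(osp.expanduser(osp.normpath(path)), exist_ok=True)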
Example 2
 def _process(self):
     # 'reprocess' forces regeneration even when processed files already
     # exist on disk.
     if self.reprocess:
         self.process()
     if files_exist(self.processed_paths):  # pragma: no cover
         return
     print('Processing...')
     makedirs(self.processed_dir)
     self.process()
     print('Done!')
Example 3
 def raw_file_names(self):
     # Requires: from os import listdir; from os.path import isfile, join
     if files_exist(self.processed_paths):  # pragma: no cover
         return
     # Collect every file (but not subdirectory) in the raw-graphs folder.
     onlyfiles = [
         f for f in listdir(self.raw_graphs_path)
         if isfile(join(self.raw_graphs_path, f))
     ]
     self.raw_files = onlyfiles
     return onlyfiles
Example 4
    def process(self):
        # Requires: import shutil, torch
        # If a pre-built raw file is available, reuse it directly as the
        # processed file.
        if files_exist(self.raw_paths):
            shutil.copyfile(self.raw_paths[0], self.processed_paths[0])
            return

        # Otherwise generate 'num_per_class' samples of each of the two
        # classes and collate them into a single storage object.
        data_list = []
        for i in range(self.num_per_class):
            data_list.append(self.gen_class1())
            data_list.append(self.gen_class2())

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
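
The (data, slices) pair saved above is the format that torch_geometric's InMemoryDataset reloads in its constructor. The sketch below shows only that loading step for a hypothetical subclass; the name ExampleDataset is illustrative, the remaining hooks (raw_file_names, processed_file_names, process) are omitted, and the torch.load pattern matches older torch_geometric releases.

import torch
from torch_geometric.data import InMemoryDataset

class ExampleDataset(InMemoryDataset):
    # Hypothetical subclass, shown only to illustrate how the saved tuple is
    # read back; the other dataset hooks are omitted for brevity.
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # Reload the tuple that process() stored with torch.save().
        self.data, self.slices = torch.load(self.processed_paths[0])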
Example 5
 def __init__(self,
              root,
              categories,
              single_dual_nodes,
              undirected_dual_edges,
              primal_features_from_dual_features=False,
              train=True,
              prevent_nonmanifold_edges=True,
              num_augmentations=1,
              vertices_scale_mean=None,
              vertices_scale_var=None,
              edges_flip_fraction=None,
              slide_vertices_fraction=None,
              return_sample_indices=False,
              transform=None,
              pre_transform=None,
              pre_filter=None):
     assert (isinstance(categories, list))
     if (len(categories) == 0):
         # Extract data from all the valid categories.
         self.__categories = self.valid_categories
     else:
         for _category in categories:
             assert (_category in self.valid_categories)
         self.__categories = sorted(categories)
     self.__num_augmentations = num_augmentations
     self.__vertices_scale_mean = vertices_scale_mean
     self.__vertices_scale_var = vertices_scale_var
     self.__edges_flip_fraction = edges_flip_fraction
     self.__slide_vertices_fraction = slide_vertices_fraction
     self.__single_dual_nodes = single_dual_nodes
     self.__undirected_dual_edges = undirected_dual_edges
     self.__primal_features_from_dual_features = (
         primal_features_from_dual_features)
     self.__prevent_nonmanifold_edges = prevent_nonmanifold_edges
     self.__split = 'train' if train else 'test'
     self.__return_sample_indices = return_sample_indices
     self.processed_file_names_train = []
     self.processed_file_names_test = []
     # Store input parameters.
     self.__input_parameters = {
         k: v
         for k, v in locals().items() if (k[0] != '_' and k != 'self')
     }
     # Do not insert the parameter 'return_sample_indices' in the input
     # parameters, as it is only used for data access and does not affect the
     # features of the dataset.
     self.__input_parameters.pop('return_sample_indices')
     self.__input_parameters['categories'] = self.__categories
     self.__input_parameters['root'] = osp.abspath(root)
     super(CosegDualPrimal, self).__init__(root, transform, pre_transform,
                                           pre_filter)
     # If the processed data will be loaded from disk rather than recomputed,
     # check that the parameters stored with the processed data on disk match
     # the input parameters of the current dataset.
     if (files_exist(self.processed_paths)):
         # Load parameter file of the previously-saved preprocessed data.
         dataset_parameters_filename = osp.join(
             self.processed_dir,
             f'processed_data_params_{self.__split}.pkl')
         try:
             with open(dataset_parameters_filename, 'rb') as f:
                 previous_dataset_params = pkl.load(f)
         except IOError:
             raise IOError(
                 "Unable to open preprocessed-data parameter file "
                 f"'{dataset_parameters_filename}'. Exiting.")
         assert (previous_dataset_params.keys(
         ) == self.__input_parameters.keys()), (
             "The current dataset and the processed one at "
             f"'{self.processed_dir} should have the same list of possible "
             "input parameters, but they do not.")
         if (previous_dataset_params != self.__input_parameters):
             # The two datasets are still compatible if the only difference
             # is in the categories, and those of the current dataset are a
             # subset of those of the previous dataset. The same applies to the
             # number of augmentations, if the augmentation parameters match:
             # in this case, as long as the current dataset has a number of
             # augmentations at most equal to that of the previous dataset,
             # it is possible to keep using the previous one, taking only as
             # many augmented versions as specified in the current dataset.
             different_params = set(
                 k for k in previous_dataset_params.keys() if
                 previous_dataset_params[k] != self.__input_parameters[k])
             are_parameters_compatible = False
             if (len(different_params) == 1):
                 if ('categories' in different_params):
                     are_parameters_compatible = set(
                         self.__input_parameters['categories']).issubset(
                             previous_dataset_params['categories'])
                 elif ('num_augmentations' in different_params):
                     are_parameters_compatible = (
                         self.__input_parameters['num_augmentations'] <=
                         previous_dataset_params['num_augmentations'])
             if (not are_parameters_compatible):
                 raise KeyError(
                     "Trying to use preprocessed data at "
                     f"'{self.processed_dir}', but the parameters with "
                     "which these data were generated do not match the "
                     "input parameters of the current dataset. The "
                     f"parameters that differ are {different_params}. "
                     "Either delete the preprocessed data, specify a "
                     "different root folder, or change the input parameters "
                     "of the current dataset.")
Example 6
 def _process(self):
     # Delete any previously processed files so that the parent class
     # regenerates them from scratch.
     if files_exist(self.processed_paths):
         for _path in self.processed_paths:
             os.remove(_path)
     super(_BaseDataset, self)._process()
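
Because Example 6 deletes the processed files unconditionally, the parent _process regenerates the dataset on every instantiation. If regeneration should instead be opt-in, a flag-guarded variant along the lines of the sketch below could be used; the force_reprocess attribute is a hypothetical name, not part of the example or of torch_geometric.

 def _process(self):
     # Hypothetical variant: only wipe the processed files when the caller
     # explicitly requested a rebuild via 'force_reprocess'.
     if self.force_reprocess and files_exist(self.processed_paths):
         for _path in self.processed_paths:
             os.remove(_path)
     super(_BaseDataset, self)._process()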