示例#1
0
    def process(self):
        """Read the raw TU dataset, pre-process it with Ray actor pools, and save it.

        Reads the raw files via ``read_tu_data``. If ``pre_filter`` and/or
        ``pre_transform`` are set, all graphs are materialized once with
        ``self.get``. They are then filtered (keeping graphs whose filter
        result is truthy) and/or transformed through a ``ray.util.ActorPool``
        built from ``self.pool_size`` entries. The result is collated back
        into ``self.data`` / ``self.slices``. The ``(data, slices)`` pair is
        written to ``self.processed_paths[0]`` with ``torch.save``.

        Raises:
            ValueError: if pre-processing is requested and ``self.pool_size``
                is less than 1 (an empty pool would produce no results).
        """
        self.data, self.slices = read_tu_data(self.raw_dir, self.name)

        if self.pre_filter is not None or self.pre_transform is not None:
            if self.pool_size < 1:
                raise ValueError(
                    f'pool_size must be at least 1, got {self.pool_size}')

            # Materialize the graphs once and run filter + transform on the
            # same list. Collating and re-splitting in between is wasted work,
            # and re-reading via get() could hit stale cached items.
            data_list = [self.get(idx) for idx in range(len(self))]

            # NOTE(review): each pool holds `pool_size` references to the same
            # handle; if that handle is a single Ray actor, calls serialize on
            # it and no parallelism is gained -- confirm intended setup.
            if self.pre_filter is not None:
                filter_pool = ray.util.ActorPool(
                    [self.pre_filter for _ in range(self.pool_size)])
                # ActorPool.map yields results in submission order, so each
                # mask lines up with the graph at the same position.
                mask_list = list(
                    filter_pool.map(lambda a, v: a.remote(v), data_list))
                data_list = [
                    data for data, keep in zip(data_list, mask_list) if keep
                ]

            if self.pre_transform is not None:
                transform_pool = ray.util.ActorPool(
                    [self.pre_transform for _ in range(self.pool_size)])
                # The pool only submits work to idle entries, so no manual
                # chunking by pool_size is needed; results stay in order.
                data_list = list(
                    transform_pool.map(lambda a, v: a.remote(v), data_list))

            self.data, self.slices = self.collate(data_list)
            # get() may cache separated items in self._data_list; clear it so
            # later lookups reflect the re-collated data.
            self._data_list = None

        torch.save((self.data, self.slices), self.processed_paths[0])
示例#2
0
    def process(self):
        """Read the raw TU dataset, apply optional pre-processing, and save it.

        Reads the raw files via ``read_tu_data``. If ``pre_filter`` and/or
        ``pre_transform`` are set, the graphs are materialized once with
        ``self.get``. They are filtered (keeping those for which
        ``pre_filter`` returns a truthy value), then transformed, then
        collated back into ``self.data`` / ``self.slices``. The resulting
        ``(data, slices)`` pair is written to ``self.processed_paths[0]``
        with ``torch.save``.
        """
        self.data, self.slices = read_tu_data(self.raw_dir, self.name)

        if self.pre_filter is not None or self.pre_transform is not None:
            # Materialize once: re-reading through get() after an
            # intermediate collate could return stale cached items.
            data_list = [self.get(idx) for idx in range(len(self))]

            if self.pre_filter is not None:
                data_list = [data for data in data_list if self.pre_filter(data)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(data) for data in data_list]

            self.data, self.slices = self.collate(data_list)
            # get() may cache separated items in self._data_list; clear it so
            # later lookups reflect the re-collated data.
            self._data_list = None

        torch.save((self.data, self.slices), self.processed_paths[0])
示例#3
0
    def process(self):
        """Load the raw TU files, optionally pre-filter/pre-transform, and persist.

        The tuple returned by ``read_tu_data`` is unpacked into ``self.data``,
        ``self.slices`` and ``sizes``. When ``pre_filter`` or
        ``pre_transform`` is configured, every graph is pulled out with
        ``self.get``, kept only if ``pre_filter`` is truthy for it, passed
        through ``pre_transform``, and collated back into ``self.data`` /
        ``self.slices``. The ``self._data_list`` cache is then cleared.
        Finally ``(self.data, self.slices, sizes)`` is written to
        ``self.processed_paths[0]`` via ``torch.save``.
        """
        self.data, self.slices, sizes = read_tu_data(self.raw_dir, self.name)

        keep, transform = self.pre_filter, self.pre_transform
        if not (keep is None and transform is None):
            graphs = [self.get(i) for i in range(len(self))]
            if keep is not None:
                graphs = list(filter(keep, graphs))
            if transform is not None:
                graphs = list(map(transform, graphs))
            self.data, self.slices = self.collate(graphs)
            # Drop the per-item cache so it does not outlive the re-collated data.
            self._data_list = None

        torch.save((self.data, self.slices, sizes), self.processed_paths[0])