def parallel_read_records(self, task, num_processes, shard_size, transform_fn):
    """Read the records covered by `task` from an ODPS table using
    parallel worker processes, yielding each (transformed) record."""
    check_required_kwargs(["project", "access_id", "access_key"], self._kwargs)
    start = task.start
    end = task.end
    # The shard name carries "<project>.<table>"; keep only the table part.
    table = self._get_odps_table_name(task.shard_name)
    table = table.split(".")[1]
    project = self._kwargs["project"]
    access_id = self._kwargs["access_id"]
    access_key = self._kwargs["access_key"]
    endpoint = self._kwargs.get("endpoint")
    partition = self._kwargs.get("partition", None)
    columns = self._kwargs.get("columns", None)
    reader = ODPSReader(
        access_id=access_id,
        access_key=access_key,
        project=project,
        endpoint=endpoint,
        table=table,
        partition=partition,
        num_processes=num_processes,
        transform_fn=transform_fn,
        columns=columns,
    )
    # Scan (start offset, record count) in chunks of `shard_size`.
    reader.reset((start, end - start), shard_size)
    shard_count = reader.get_shards_count()
    for _ in range(shard_count):
        records = reader.get_records()
        for record in records:
            yield record
    reader.stop()
def test_parallel_read(self):
    def transform(record):
        return float(record[0]) + 1

    start = 0
    end = 100
    shard_size = (end - start) // 4
    reader = ODPSReader(
        access_id=self._access_id,
        access_key=self._access_key,
        project=self._project,
        endpoint=self._endpoint,
        table=self._test_read_table,
        num_processes=2,
        transform_fn=transform,
    )
    results = []
    reader.reset((start, end - start), shard_size)
    shard_count = reader.get_shards_count()
    for _ in range(shard_count):
        records = reader.get_records()
        for record in records:
            results.append(record)
    reader.stop()
    self.assertEqual(len(results), 100)