def parallel_record_records(
    self, task, num_processes, shard_size, transform_fn
):
    """Read the records of an ODPS table shard in parallel and yield them."""
    check_required_kwargs(["project", "access_id", "access_key"], self._kwargs)
    start = task.start
    end = task.end
    # Keep only the component after the "." in the qualified table name.
    table = self._get_odps_table_name(task.shard_name)
    table = table.split(".")[1]
    project = self._kwargs["project"]
    access_id = self._kwargs["access_id"]
    access_key = self._kwargs["access_key"]
    endpoint = self._kwargs.get("endpoint")
    partition = self._kwargs.get("partition", None)
    columns = self._kwargs.get("columns", None)
    pd = ODPSReader(
        access_id=access_id,
        access_key=access_key,
        project=project,
        endpoint=endpoint,
        table=table,
        partition=partition,
        num_processes=num_processes,
        transform_fn=transform_fn,
        columns=columns,
    )
    pd.reset((start, end - start), shard_size)
    shard_count = pd.get_shards_count()
    for i in range(shard_count):
        records = pd.get_records()
        for record in records:
            yield record
    pd.stop()

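A minimal consumption sketch for the generator above, assuming a data-reader object that exposes parallel_record_records; the Task namedtuple, the shard bounds, and the identity transform are illustrative placeholders rather than part of the original code.

from collections import namedtuple

# Hypothetical stand-in for the task object; only the attributes that
# parallel_record_records reads (start, end, shard_name) are provided.
Task = namedtuple("Task", ["start", "end", "shard_name"])


def count_records_example(data_reader, shard_name):
    task = Task(start=0, end=100, shard_name=shard_name)
    return sum(
        1
        for _ in data_reader.parallel_record_records(
            task, num_processes=2, shard_size=25, transform_fn=lambda r: r
        )
    )
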
def test_parallel_read(self):
    def transform(record):
        return float(record[0]) + 1

    start = 0
    end = 100
    shard_size = (end - start) // 4
    pd = ODPSReader(
        access_id=self._access_id,
        access_key=self._access_key,
        project=self._project,
        endpoint=self._endpoint,
        table=self._test_read_table,
        num_processes=2,
        transform_fn=transform,
    )
    results = []
    pd.reset((start, end - start), shard_size)
    shard_count = pd.get_shards_count()
    for i in range(shard_count):
        records = pd.get_records()
        for record in records:
            results.append(record)
    pd.stop()
    self.assertEqual(len(results), 100)

def _get_reader(self, table_name):
    _check_required_kwargs(["project", "access_id", "access_key"], self._kwargs)
    return ODPSReader(
        project=self._kwargs["project"],
        access_id=self._kwargs["access_id"],
        access_key=self._kwargs["access_key"],
        table=table_name,
        endpoint=self._kwargs.get("endpoint"),
        num_processes=self._kwargs.get("num_processes", 1),
    )

def test_write_odps_to_recordio_shards_from_iterator(self):
    reader = ODPSReader(
        self._project,
        self._access_id,
        self._access_key,
        self._endpoint,
        self._test_read_table,
        None,
        4,
        None,
    )
    records_iter = reader.to_iterator(1, 0, 50, 2, False, None)
    with tempfile.TemporaryDirectory() as output_dir:
        write_recordio_shards_from_iterator(
            records_iter,
            ["f" + str(i) for i in range(5)],
            output_dir,
            records_per_shard=50,
        )
        self.assertEqual(len(os.listdir(output_dir)), 5)

def test_read_to_iterator(self):
    reader = ODPSReader(
        self._project,
        self._access_id,
        self._access_key,
        self._endpoint,
        self._test_read_table,
        None,
        4,
        None,
    )
    records_iter = reader.to_iterator(1, 0, 50, 2, False, None)
    records = list(records_iter)
    self.assertEqual(
        len(records),
        6,
        "Unexpected number of batches: %d" % len(records),
    )
    flattened_records = [record for batch in records for record in batch]
    self.assertEqual(
        len(flattened_records),
        220,
        "Unexpected number of total records: %d" % len(flattened_records),
    )

def get_odps_reader(self, table_name):
    return ODPSReader(
        project=self._kwargs["project"],
        access_id=self._kwargs["access_id"],
        access_key=self._kwargs["access_key"],
        table=table_name,
        endpoint=self._kwargs.get("endpoint"),
        partition=self._kwargs.get("partition", None),
        num_processes=self._kwargs.get("num_processes", 1),
        options={
            "odps.options.tunnel.endpoint": self._kwargs.get(
                "tunnel_endpoint", None
            )
        },
    )
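
A minimal end-to-end sketch combining get_odps_reader with the reset / get_shards_count / get_records / stop protocol exercised in parallel_record_records and test_parallel_read above; the table name and shard bounds are placeholders, and a data-reader object with suitably populated self._kwargs is assumed.

def read_table_example(data_reader, table_name, start, end, shard_size):
    # Reader construction mirrors get_odps_reader above; credentials and the
    # endpoint come from the data_reader's own kwargs.
    reader = data_reader.get_odps_reader(table_name)
    results = []
    reader.reset((start, end - start), shard_size)
    for _ in range(reader.get_shards_count()):
        for record in reader.get_records():
            results.append(record)
    reader.stop()
    return results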