Example #1
0
 def parallel_record_records(self, task, num_processes, shard_size,
                             transform_fn):
     """Read the rows of *task*'s ODPS table shard in parallel and yield
     them one record at a time.

     Args:
         task: Shard task carrying ``start``/``end`` row offsets and a
             ``shard_name`` that maps to an ODPS table.
         num_processes: Number of reader worker processes.
         shard_size: Number of records per reader shard.
         transform_fn: Callable the reader applies to each record.

     Yields:
         Individual (possibly transformed) records from the table.
     """
     check_required_kwargs(["project", "access_id", "access_key"],
                           self._kwargs)
     start = task.start
     end = task.end
     # The resolved name is "<project>.<table>"; keep only the table part.
     table = self._get_odps_table_name(task.shard_name)
     table = table.split(".")[1]
     reader = ODPSReader(
         access_id=self._kwargs["access_id"],
         access_key=self._kwargs["access_key"],
         project=self._kwargs["project"],
         endpoint=self._kwargs.get("endpoint"),
         table=table,
         partition=self._kwargs.get("partition", None),
         num_processes=num_processes,
         transform_fn=transform_fn,
         columns=self._kwargs.get("columns", None),
     )
     reader.reset((start, end - start), shard_size)
     try:
         for _ in range(reader.get_shards_count()):
             for record in reader.get_records():
                 yield record
     finally:
         # Always stop the reader's worker processes, even when the
         # consumer abandons the generator early (GeneratorExit) or an
         # error occurs mid-read; previously they leaked in both cases.
         reader.stop()
Example #2
0
    def test_parallel_read(self):
        """Parallel read with a transform yields all 100 table records."""

        def add_one(record):
            # Interpret the first field as a float and shift it by one.
            return float(record[0]) + 1

        begin = 0
        finish = 100
        per_shard = (finish - begin) // 4

        reader = ODPSReader(
            access_id=self._access_id,
            access_key=self._access_key,
            project=self._project,
            endpoint=self._endpoint,
            table=self._test_read_table,
            num_processes=2,
            transform_fn=add_one,
        )

        reader.reset((begin, finish - begin), per_shard)
        collected = []
        for _ in range(reader.get_shards_count()):
            collected.extend(reader.get_records())
        reader.stop()

        self.assertEqual(len(collected), 100)
Example #3
0
 def _get_reader(self, table_name):
     """Build an ODPSReader for *table_name* from the stored kwargs.

     Requires ``project``, ``access_id`` and ``access_key`` to be
     present in ``self._kwargs``; ``endpoint`` and ``num_processes``
     are optional (the latter defaults to 1).
     """
     _check_required_kwargs(["project", "access_id", "access_key"],
                            self._kwargs)
     opts = self._kwargs
     reader_args = {
         "project": opts["project"],
         "access_id": opts["access_id"],
         "access_key": opts["access_key"],
         "table": table_name,
         "endpoint": opts.get("endpoint"),
         "num_processes": opts.get("num_processes", 1),
     }
     return ODPSReader(**reader_args)
Example #4
0
 def get_odps_reader(self, table_name):
     """Create an ODPSReader for *table_name*, forwarding the optional
     tunnel endpoint through the reader's options mapping.
     """
     cfg = self._kwargs
     # May be None when no tunnel endpoint was configured.
     tunnel_endpoint = cfg.get("tunnel_endpoint", None)
     return ODPSReader(
         project=cfg["project"],
         access_id=cfg["access_id"],
         access_key=cfg["access_key"],
         table=table_name,
         endpoint=cfg.get("endpoint"),
         partition=cfg.get("partition", None),
         num_processes=cfg.get("num_processes", 1),
         options={"odps.options.tunnel.endpoint": tunnel_endpoint},
     )
Example #5
0
 def test_write_odps_to_recordio_shards_from_iterator(self):
     """Streaming ODPS records into RecordIO shards of 50 records each
     produces the expected number of shard files.
     """
     # NOTE(review): positional construction — argument order assumed to
     # match ODPSReader's signature; confirm against its definition.
     odps_reader = ODPSReader(
         self._project,
         self._access_id,
         self._access_key,
         self._endpoint,
         self._test_read_table,
         None,
         4,
         None,
     )
     record_batches = odps_reader.to_iterator(1, 0, 50, 2, False, None)
     with tempfile.TemporaryDirectory() as out_dir:
         write_recordio_shards_from_iterator(
             record_batches,
             [f"f{i}" for i in range(5)],
             out_dir,
             records_per_shard=50,
         )
         self.assertEqual(len(os.listdir(out_dir)), 5)
Example #6
0
 def test_read_to_iterator(self):
     """to_iterator over the test table yields 6 batches, 220 records."""
     odps_reader = ODPSReader(
         self._project,
         self._access_id,
         self._access_key,
         self._endpoint,
         self._test_read_table,
         None,
         4,
         None,
     )
     batches = list(odps_reader.to_iterator(1, 0, 50, 2, False, None))
     self.assertEqual(len(batches), 6,
                      "Unexpected number of batches: %d" % len(batches))
     # Flatten the batches to count individual records.
     all_records = [rec for batch in batches for rec in batch]
     self.assertEqual(
         len(all_records),
         220,
         "Unexpected number of total records: %d" % len(all_records),
     )