Example #1
class ODPSReader(object):
    def __init__(
        self,
        project,
        access_id,
        access_key,
        endpoint,
        table,
        partition=None,
        num_processes=None,
        options=None,
        transform_fn=None,
        columns=None,
    ):
        """
        Constructs an `ODPSReader` instance.

        Args:
            project: Name of the ODPS project.
            access_id: ODPS user access ID.
            access_key: ODPS user access key.
            endpoint: ODPS cluster endpoint.
            table: ODPS table name.
            partition: ODPS table's partition.
            options: Other options passed to ODPS context.
            num_processes: Number of parallel processes on this worker.
                If `None`, use the number of cores.
            transform_fn: Customized transform function applied to each record.
            columns: List of table column names to read.
        """
        super(ODPSReader, self).__init__()

        if table.find(".") > 0:
            project, table = table.split(".")
        if options is None:
            options = {}
        self._project = project
        self._access_id = access_id
        self._access_key = access_key
        self._endpoint = endpoint
        self._table = table
        self._partition = partition
        self._num_processes = num_processes
        _configure_odps_options(self._endpoint, options)
        self._odps_table = ODPS(
            self._access_id,
            self._access_key,
            self._project,
            self._endpoint,
        ).get_table(self._table)

        self._transform_fn = transform_fn
        self._columns = columns

    def reset(self, shards, shard_size):
        """
        The parallel reader launches multiple worker processes to read
        records from an ODPS table and applies `transform_fn` to each record.
        If `transform_fn` is not set, the transform stage is skipped.

        Each worker process:
        1. gets a shard, a pair (start, count) of row indices in the
            ODPS table, from its index queue
        2. reads the records in that range from the ODPS table
        3. applies `transform_fn` to each record
        4. puts the records into the result queue

        The main process:
        1. calls `reset` to split the input shard into a number of shards
        2. puts shards into the workers' index queues in a round-robin way
        3. calls `get_records` to get transformed data from the result queue
        4. calls `stop` to stop the workers
        """
        self._result_queue = Queue()
        self._index_queues = []
        self._workers = []

        self._shards = []
        self._shard_idx = 0
        self._worker_idx = 0

        for i in range(self._num_processes):
            index_queue = Queue()
            self._index_queues.append(index_queue)

            p = Process(target=self._worker_loop, args=(i, ))
            p.daemon = True
            p.start()
            self._workers.append(p)

        self._create_shards(shards, shard_size)
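        # Pre-dispatch two shard indices per worker so that every worker
        # has a next shard queued while it is reading the current one.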
        for i in range(2 * self._num_processes):
            self._put_index()

    def get_shards_count(self):
        return len(self._shards)

    def get_records(self):
        data = self._result_queue.get()
        self._put_index()
        return data

    def stop(self):
        for q in self._index_queues:
            q.put((None, None))

    def _worker_loop(self, worker_id):
        while True:
            index = self._index_queues[worker_id].get()
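            # `stop` pushes a (None, None) sentinel to terminate the worker.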
            if index[0] is None and index[1] is None:
                break

            records = []
            for record in self.record_generator_with_retry(
                    start=index[0],
                    end=index[0] + index[1],
                    columns=self._columns,
                    transform_fn=self._transform_fn,
            ):
                records.append(record)
            self._result_queue.put(records)

    def _create_shards(self, shards, shard_size):
        start = shards[0]
        count = shards[1]
        m = count // shard_size
        n = count % shard_size

        for i in range(m):
            self._shards.append((start + i * shard_size, shard_size))
        if n != 0:
            self._shards.append((start + m * shard_size, n))

    def _next_worker_id(self):
        cur_id = self._worker_idx
        self._worker_idx += 1
        if self._worker_idx == self._num_processes:
            self._worker_idx = 0
        return cur_id

    def _put_index(self):
        # Dispatch the next shard to a worker's index queue
        # in a round-robin fashion.
        if self._shard_idx < len(self._shards):
            worker_id = self._next_worker_id()
            shard = self._shards[self._shard_idx]
            self._index_queues[worker_id].put(shard)
            self._shard_idx += 1

    def to_iterator(
        self,
        num_workers,
        worker_index,
        batch_size,
        epochs=1,
        shuffle=False,
        columns=None,
        cache_batch_count=None,
        limit=-1,
    ):
        """
        Load slices of ODPS table data (or of a partition if `partition`
        was specified) as a Python generator.
        Args:
            num_workers: Total number of workers in the cluster.
            worker_index: Index of the current worker in the cluster.
            batch_size: Size of a slice.
            epochs: Number of times to repeat the data.
            shuffle: Whether to shuffle the rows.
            columns: The list of columns to load. If `None`,
                use all column names in the ODPS table schema.
            cache_batch_count: Number of batches to cache per download.
            limit: Upper limit on the number of table rows to load.
        """
        if not worker_index < num_workers:
            raise ValueError(
                "index of worker should be less than number of worker")
        if not batch_size > 0:
            raise ValueError("batch_size should be positive")

        table_size = self.get_table_size()
        if 0 < limit < table_size:
            table_size = limit
        if columns is None:
            columns = self._odps_table.schema.names

        if cache_batch_count is None:
            cache_batch_count = self._estimate_cache_batch_count(
                columns=columns, table_size=table_size, batch_size=batch_size)

        large_batch_size = batch_size * cache_batch_count

        overall_items = range(0, table_size, large_batch_size)

        if len(overall_items) < num_workers:
            overall_items = range(0, table_size, int(table_size / num_workers))

        worker_items = list(
            np.array_split(np.asarray(overall_items),
                           num_workers)[worker_index])
        if shuffle:
            random.shuffle(worker_items)
        worker_items_with_epoch = worker_items * epochs

        # The length of `worker_items_with_epoch` is the total number of
        # large batches to read, so the number of worker processes should
        # not exceed it.
        if self._num_processes is None:
            self._num_processes = min(8, len(worker_items_with_epoch))
        else:
            self._num_processes = min(self._num_processes,
                                      len(worker_items_with_epoch))

        if self._num_processes == 0:
            raise ValueError(
                "Total worker number is 0. Please check if table has data.")

        with Executor(max_workers=self._num_processes) as executor:

            futures = queue.Queue()
            # Initialize concurrently running processes according
            # to `num_processes`
            for i in range(self._num_processes):
                range_start = worker_items_with_epoch[i]
                range_end = min(range_start + large_batch_size, table_size)
                future = executor.submit(self.read_batch, range_start,
                                         range_end, columns)
                futures.put(future)

            worker_items_index = self._num_processes

            while not futures.empty():
                if worker_items_index < len(worker_items_with_epoch):
                    range_start = worker_items_with_epoch[worker_items_index]
                    range_end = min(range_start + large_batch_size, table_size)
                    future = executor.submit(self.read_batch, range_start,
                                             range_end, columns)
                    futures.put(future)
                    worker_items_index = worker_items_index + 1

                head_future = futures.get()
                records = head_future.result()
                for i in range(0, len(records), batch_size):
                    yield records[i:i + batch_size]  # noqa: E203

    def read_batch(self, start, end, columns=None, max_retries=3):
        """
        Read the ODPS table in the chosen row range [`start`, `end`) with the
        specified columns `columns`.
        Args:
            start: The row index to start reading.
            end: The row index to end reading.
            columns: The list of columns to read.
            max_retries: The maximum number of retries in case of exceptions.
        Returns:
            A two-dimensional Python list with shape (end - start, len(columns)).
        """
        retry_count = 0
        if columns is None:
            columns = self._odps_table.schema.names
        while retry_count < max_retries:
            try:
                record_gen = self.record_generator(start, end, columns)
                return [record for record in record_gen]
            except Exception as e:
                if retry_count + 1 >= max_retries:
                    raise Exception("Exceeded maximum number of retries")
                logger.warning("ODPS read exception {} for {} in {}. "
                               "Retry count: {}".format(
                                   e, columns, self._table, retry_count))
                time.sleep(5)
                retry_count += 1

    def record_generator_with_retry(self,
                                    start,
                                    end,
                                    columns=None,
                                    max_retries=3,
                                    transform_fn=None):
        """Wrap record_generator with retry to avoid ODPS table read failure
        due to network instability.
        """
        retry_count = 0
        while retry_count < max_retries:
            try:
                for record in self.record_generator(start, end, columns):
                    if transform_fn:
                        record = transform_fn(record)
                    yield record
                break
            except Exception as e:
                if retry_count + 1 >= max_retries:
                    raise Exception("Exceeded maximum number of retries")
                logger.warning("ODPS read exception {} for {} in {}. "
                               "Retry count: {}".format(
                                   e, columns, self._table, retry_count))
                time.sleep(5)
                retry_count += 1

    def record_generator(self, start, end, columns=None):
        """Generate records from an ODPS table
        """
        if columns is None:
            columns = self._odps_table.schema.names
        with self._odps_table.open_reader(partition=self._partition,
                                          reopen=False) as reader:
            for record in reader.read(start=start,
                                      count=end - start,
                                      columns=columns):
                yield [str(record[column]) for column in columns]

    def get_table_size(self, max_retries=3):
        retry_count = 0
        while retry_count < max_retries:
            try:
                with self._odps_table.open_reader(
                        partition=self._partition) as reader:
                    return reader.count
            except Exception as e:
                if retry_count + 1 >= max_retries:
                    raise Exception("Exceeded maximum number of retries")
                logger.warning("ODPS read exception {} to get table size. "
                               "Retry count: {}".format(e, retry_count))
                time.sleep(5)
                retry_count += 1

    def _estimate_cache_batch_count(self, columns, table_size, batch_size):
        """
        Calculate an appropriate cache batch count for downloads from ODPS.
        If the batch size is small, we would repeatedly create HTTP
        connections and download small chunks of data. To read more
        efficiently, we read `batch_size * cache_batch_count` rows at a time.
        Determining a proper `cache_batch_count` is non-trivial; the current
        heuristic is to cap the size of each download.
        """

        sample_size = 10
        max_cache_batch_count = 50
        upper_bound = 20 * 1000000

        if table_size < sample_size:
            return 1

        batch = self.read_batch(start=0, end=sample_size, columns=columns)

        size_sample = _nested_list_size(batch)
        size_per_batch = size_sample * batch_size / sample_size

        # `size_per_batch * cache_batch_count` will not exceed the upper
        # bound, and `cache_batch_count` is always at least 1.
        cache_batch_count_estimate = max(int(upper_bound / size_per_batch), 1)

        return min(cache_batch_count_estimate, max_cache_batch_count)
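
The `reset` / `get_records` / `stop` protocol documented above can be driven as in the following minimal sketch. It assumes the `ODPSReader` class from this example is in scope; the project, credentials, endpoint, and table name below are placeholders, not real values.

reader = ODPSReader(
    project="my_project",        # placeholder project name
    access_id="my_access_id",    # placeholder credentials
    access_key="my_access_key",
    endpoint="http://odps.example.com/api",  # placeholder endpoint
    table="my_table",            # placeholder table name
    num_processes=2,
)

# Split rows [0, 1000) of the table into ten shards of 100 rows each
# and start the worker processes.
reader.reset(shards=(0, 1000), shard_size=100)

# Each `get_records` call returns one transformed shard from the result
# queue; read exactly as many shards as `reset` created.
for _ in range(reader.get_shards_count()):
    records = reader.get_records()
    print(len(records))

# Push the (None, None) sentinel so the worker processes exit.
reader.stop()
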
Example #2
def _read_odps_one_shot(
    project,
    access_id,
    access_key,
    endpoint,
    table,
    partition,
    start,
    end,
    columns,
    max_retries=3,
):
    """
    Read the ODPS table in the chosen row range [`start`, `end`) with the
    specified columns `columns`.

    Args:
        project: The ODPS project.
        access_id: The ODPS user access ID.
        access_key: The ODPS user access key.
        endpoint: The ODPS cluster endpoint.
        table: The ODPS table name.
        partition: The ODPS table's partition. Default is `None` if the
            table is not partitioned.
        start: The row index to start reading.
        end: The row index to end reading.
        columns: The list of columns to read.
        max_retries: The maximum number of retries in case of exceptions.

    Returns:
        A two-dimensional Python list with shape (end - start, len(columns)).
    """
    odps_table = ODPS(access_id, access_key, project, endpoint).get_table(
        table
    )

    retry_count = 0

    while retry_count < max_retries:
        try:
            batch_record = []
            with odps_table.open_reader(
                partition=partition, reopen=True
            ) as reader:
                for record in reader.read(
                    start=start, count=end - start, columns=columns
                ):
                    batch_record.append([record[column] for column in columns])

            return batch_record

        except Exception as e:
            import time

            if retry_count + 1 >= max_retries:
                raise
            logger.warning(
                "ODPS read exception {} for {} in {}. Retry count: {}".format(
                    e, columns, table, retry_count
                )
            )
            time.sleep(5)
            retry_count += 1
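
A call to `_read_odps_one_shot` might look like the sketch below, assuming the function above is in scope; the project, credentials, endpoint, table, and column names are placeholders.

rows = _read_odps_one_shot(
    project="my_project",          # placeholder project name
    access_id="my_access_id",      # placeholder credentials
    access_key="my_access_key",
    endpoint="http://odps.example.com/api",  # placeholder endpoint
    table="my_table",              # placeholder table name
    partition=None,                # e.g. "dt=20200101" for a partitioned table
    start=0,
    end=100,
    columns=["feature", "label"],  # placeholder column names
)
# `rows` is a list of 100 rows, each a list with one value per column.
print(len(rows), len(rows[0]))
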
Example #3
class ODPSReader(object):
    def __init__(
        self,
        project,
        access_id,
        access_key,
        endpoint,
        table,
        partition=None,
        num_processes=None,
        options=None,
    ):
        """
        Constructs an `ODPSReader` instance.

        Args:
            project: Name of the ODPS project.
            access_id: ODPS user access ID.
            access_key: ODPS user access key.
            endpoint: ODPS cluster endpoint.
            table: ODPS table name.
            partition: ODPS table's partition.
            options: Other options passed to ODPS context.
            num_processes: Number of parallel processes on this worker.
                If `None`, use the number of cores.
        """
        super(ODPSReader, self).__init__()

        if table.find(".") > 0:
            project, table = table.split(".")
        if options is None:
            options = {}
        self._project = project
        self._access_id = access_id
        self._access_key = access_key
        self._endpoint = endpoint
        self._table = table
        self._partition = partition
        self._num_processes = num_processes
        _configure_odps_options(self._endpoint, options)
        self._odps_table = ODPS(self._access_id, self._access_key,
                                self._project,
                                self._endpoint).get_table(self._table)

    def to_iterator(
        self,
        num_workers,
        worker_index,
        batch_size,
        epochs=1,
        shuffle=False,
        columns=None,
        cache_batch_count=None,
        limit=-1,
    ):
        """
        Load slices of ODPS table data (or of a partition if `partition`
        was specified) as a Python generator.

        Args:
            num_workers: Total number of workers in the cluster.
            worker_index: Index of the current worker in the cluster.
            batch_size: Size of a slice.
            epochs: Number of times to repeat the data.
            shuffle: Whether to shuffle the rows.
            columns: The list of columns to load. If `None`,
                use all column names in the ODPS table schema.
            cache_batch_count: Number of batches to cache per download.
            limit: Upper limit on the number of table rows to load.
        """
        if not worker_index < num_workers:
            raise ValueError(
                "index of worker should be less than number of worker")
        if not batch_size > 0:
            raise ValueError("batch_size should be positive")

        table_size = self.get_table_size()
        if 0 < limit < table_size:
            table_size = limit
        if columns is None:
            columns = self._odps_table.schema.names

        if cache_batch_count is None:
            cache_batch_count = self._estimate_cache_batch_count(
                columns=columns, table_size=table_size, batch_size=batch_size)

        large_batch_size = batch_size * cache_batch_count

        overall_items = range(0, table_size, large_batch_size)

        if len(overall_items) < num_workers:
            overall_items = range(0, table_size, int(table_size / num_workers))

        worker_items = list(
            np.array_split(np.asarray(overall_items),
                           num_workers)[worker_index])
        if shuffle:
            random.shuffle(worker_items)
        worker_items_with_epoch = worker_items * epochs

        # The length of `worker_items_with_epoch` is the total number of
        # large batches to read, so the number of worker processes should
        # not exceed it.
        if self._num_processes is None:
            self._num_processes = min(8, len(worker_items_with_epoch))
        else:
            self._num_processes = min(self._num_processes,
                                      len(worker_items_with_epoch))

        if self._num_processes == 0:
            raise ValueError(
                "Total worker number is 0. Please check if table has data.")

        with Executor(max_workers=self._num_processes) as executor:

            futures = Queue()
            # Initialize concurrently running processes according
            # to `num_processes`
            for i in range(self._num_processes):
                range_start = worker_items_with_epoch[i]
                range_end = min(range_start + large_batch_size, table_size)
                future = executor.submit(self.read_batch, range_start,
                                         range_end, columns)
                futures.put(future)

            worker_items_index = self._num_processes

            while not futures.empty():
                if worker_items_index < len(worker_items_with_epoch):
                    range_start = worker_items_with_epoch[worker_items_index]
                    range_end = min(range_start + large_batch_size, table_size)
                    future = executor.submit(self.read_batch, range_start,
                                             range_end, columns)
                    futures.put(future)
                    worker_items_index = worker_items_index + 1

                head_future = futures.get()
                records = head_future.result()
                for i in range(0, len(records), batch_size):
                    yield records[i:i + batch_size]  # noqa: E203

    def read_batch(self, start, end, columns=None, max_retries=3):
        """
        Read the ODPS table in the chosen row range [`start`, `end`) with the
        specified columns `columns`.

        Args:
            start: The row index to start reading.
            end: The row index to end reading.
            columns: The list of columns to read.
            max_retries: The maximum number of retries in case of exceptions.

        Returns:
            A two-dimensional Python list with shape (end - start, len(columns)).
        """
        retry_count = 0
        if columns is None:
            columns = self._odps_table.schema.names
        while retry_count < max_retries:
            try:
                batch_record = []
                with self._odps_table.open_reader(partition=self._partition,
                                                  reopen=True) as reader:
                    for record in reader.read(start=start,
                                              count=end - start,
                                              columns=columns):
                        batch_record.append(
                            [record[column] for column in columns])
                return batch_record
            except Exception as e:
                if retry_count + 1 >= max_retries:
                    raise Exception("Exceeded maximum number of retries")
                logger.warning("ODPS read exception {} for {} in {}. "
                               "Retry count: {}".format(
                                   e, columns, self._table, retry_count))
                time.sleep(5)
                retry_count += 1

    def get_table_size(self):
        with self._odps_table.open_reader(partition=self._partition) as reader:
            return reader.count

    def _estimate_cache_batch_count(self, columns, table_size, batch_size):
        """
        Calculate an appropriate cache batch count for downloads from ODPS.
        If the batch size is small, we would repeatedly create HTTP
        connections and download small chunks of data. To read more
        efficiently, we read `batch_size * cache_batch_count` rows at a time.
        Determining a proper `cache_batch_count` is non-trivial; the current
        heuristic is to cap the size of each download.
        """

        sample_size = 10
        max_cache_batch_count = 50
        upper_bound = 20 * 1000000

        if table_size < sample_size:
            return 1

        batch = self.read_batch(start=0, end=sample_size, columns=columns)

        size_sample = _nested_list_size(batch)
        size_per_batch = size_sample * batch_size / sample_size

        # `size_per_batch * cache_batch_count` will not exceed the upper
        # bound, and `cache_batch_count` is always at least 1.
        cache_batch_count_estimate = max(int(upper_bound / size_per_batch), 1)

        return min(cache_batch_count_estimate, max_cache_batch_count)
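
The `to_iterator` generator above can be consumed directly in a loop. A minimal sketch, assuming the `ODPSReader` variant from this example is in scope, placeholder connection values, and a hypothetical `process` function standing in for whatever consumes each batch:

reader = ODPSReader(
    project="my_project",        # placeholder project name
    access_id="my_access_id",    # placeholder credentials
    access_key="my_access_key",
    endpoint="http://odps.example.com/api",  # placeholder endpoint
    table="my_table",            # placeholder table name
)

# Single-worker setup: worker 0 of 1 reads batches of 32 rows for one epoch.
for batch in reader.to_iterator(num_workers=1, worker_index=0, batch_size=32):
    # `batch` is a list of up to 32 rows, each a list of column values.
    process(batch)  # `process` is a hypothetical per-batch handler
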
Example #4
class ODPSReader(object):
    def __init__(
        self,
        project,
        access_id,
        access_key,
        endpoint,
        table,
        partition=None,
        num_processes=None,
        options=None,
        transform_fn=None,
        columns=None,
    ):
        """
        Constructs an `ODPSReader` instance.

        Args:
            project: Name of the ODPS project.
            access_id: ODPS user access ID.
            access_key: ODPS user access key.
            endpoint: ODPS cluster endpoint.
            table: ODPS table name.
            partition: ODPS table's partition.
            options: Other options passed to ODPS context.
            num_processes: Number of parallel processes on this worker.
                If `None`, use the number of cores.
            transform_fn: Customized transform function applied to each record.
            columns: List of table column names to read.
        """
        super(ODPSReader, self).__init__()

        if table.find(".") > 0:
            project, table = table.split(".")
        if options is None:
            options = {}
        self._project = project
        self._access_id = access_id
        self._access_key = access_key
        self._endpoint = endpoint
        self._table = table
        self._partition = partition
        self._num_processes = num_processes
        _configure_odps_options(self._endpoint, options)
        self._odps_table = ODPS(
            self._access_id,
            self._access_key,
            self._project,
            self._endpoint,
        ).get_table(self._table)

        self._transform_fn = transform_fn
        self._columns = columns

    def reset(self, shards, shard_size):
        """
        The parallel reader launches multiple worker processes to read
        records from an ODPS table and applies `transform_fn` to each record.
        If `transform_fn` is not set, the transform stage is skipped.

        Each worker process:
        1. gets a shard, a pair (start, count) of row indices in the
            ODPS table, from its index queue
        2. reads the records in that range from the ODPS table
        3. applies `transform_fn` to each record
        4. puts the records into the result queue

        The main process:
        1. calls `reset` to split the input shard into a number of shards
        2. puts shards into the workers' index queues in a round-robin way
        3. calls `get_records` to get transformed data from the result queue
        4. calls `stop` to stop the workers
        """
        self._result_queue = Queue()
        self._index_queues = []
        self._workers = []

        self._shards = []
        self._shard_idx = 0
        self._worker_idx = 0

        for i in range(self._num_processes):
            index_queue = Queue()
            self._index_queues.append(index_queue)

            p = Process(target=self._worker_loop, args=(i, ))
            p.daemon = True
            p.start()
            self._workers.append(p)

        self._create_shards(shards, shard_size)
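        # Pre-dispatch two shard indices per worker so that every worker
        # has a next shard queued while it is reading the current one.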
        for i in range(2 * self._num_processes):
            self._put_index()

    def get_shards_count(self):
        return len(self._shards)

    def get_records(self):
        data = self._result_queue.get()
        self._put_index()
        return data

    def stop(self):
        for q in self._index_queues:
            q.put((None, None))

    def _worker_loop(self, worker_id):
        while True:
            index = self._index_queues[worker_id].get()
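            # `stop` pushes a (None, None) sentinel to terminate the worker.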
            if index[0] is None and index[1] is None:
                break

            records = []
            for record in self.record_generator_with_retry(
                    start=index[0],
                    end=index[0] + index[1],
                    columns=self._columns,
                    transform_fn=self._transform_fn,
            ):
                records.append(record)
            self._result_queue.put(records)

    def _create_shards(self, shards, shard_size):
        start = shards[0]
        count = shards[1]
        m = count // shard_size
        n = count % shard_size

        for i in range(m):
            self._shards.append((start + i * shard_size, shard_size))
        if n != 0:
            self._shards.append((start + m * shard_size, n))

    def _next_worker_id(self):
        cur_id = self._worker_idx
        self._worker_idx += 1
        if self._worker_idx == self._num_processes:
            self._worker_idx = 0
        return cur_id

    def _put_index(self):
        # Dispatch the next shard to a worker's index queue
        # in a round-robin fashion.
        if self._shard_idx < len(self._shards):
            worker_id = self._next_worker_id()
            shard = self._shards[self._shard_idx]
            self._index_queues[worker_id].put(shard)
            self._shard_idx += 1

    def read_batch(self, start, end, columns=None, max_retries=3):
        """
        Read the ODPS table in the chosen row range [`start`, `end`) with the
        specified columns `columns`.
        Args:
            start: The row index to start reading.
            end: The row index to end reading.
            columns: The list of columns to read.
            max_retries: The maximum number of retries in case of exceptions.
        Returns:
            A two-dimensional Python list with shape (end - start, len(columns)).
        """
        retry_count = 0
        if columns is None:
            columns = self._odps_table.schema.names
        while retry_count < max_retries:
            try:
                record_gen = self.record_generator(start, end, columns)
                return [record for record in record_gen]
            except Exception as e:
                if retry_count + 1 >= max_retries:
                    raise Exception("Exceeded maximum number of retries")
                logger.warning("ODPS read exception {} for {} in {}. "
                               "Retry count: {}".format(
                                   e, columns, self._table, retry_count))
                time.sleep(5)
                retry_count += 1

    def record_generator_with_retry(self,
                                    start,
                                    end,
                                    columns=None,
                                    max_retries=3,
                                    transform_fn=None):
        """Wrap record_generator with retry to avoid ODPS table read failure
        due to network instability.
        """
        retry_count = 0
        while retry_count < max_retries:
            try:
                for record in self.record_generator(start, end, columns):
                    if transform_fn:
                        record = transform_fn(record)
                    yield record
                break
            except Exception as e:
                if retry_count + 1 >= max_retries:
                    raise Exception("Exceeded maximum number of retries")
                logger.warning("ODPS read exception {} for {} in {}. "
                               "Retry count: {}".format(
                                   e, columns, self._table, retry_count))
                time.sleep(5)
                retry_count += 1

    def record_generator(self, start, end, columns=None):
        """Generate records from an ODPS table
        """
        if columns is None:
            columns = self._odps_table.schema.names
        with self._odps_table.open_reader(partition=self._partition,
                                          reopen=False) as reader:
            for record in reader.read(start=start,
                                      count=end - start,
                                      columns=columns):
                yield [str(record[column]) for column in columns]

    def get_table_size(self, max_retries=3):
        retry_count = 0
        while retry_count < max_retries:
            try:
                with self._odps_table.open_reader(
                        partition=self._partition) as reader:
                    return reader.count
            except Exception as e:
                if retry_count + 1 >= max_retries:
                    raise Exception("Exceeded maximum number of retries")
                logger.warning("ODPS read exception {} to get table size. "
                               "Retry count: {}".format(e, retry_count))
                time.sleep(5)
                retry_count += 1

    def _estimate_cache_batch_count(self, columns, table_size, batch_size):
        """
        Calculate an appropriate cache batch count for downloads from ODPS.
        If the batch size is small, we would repeatedly create HTTP
        connections and download small chunks of data. To read more
        efficiently, we read `batch_size * cache_batch_count` rows at a time.
        Determining a proper `cache_batch_count` is non-trivial; the current
        heuristic is to cap the size of each download.
        """

        sample_size = 10
        max_cache_batch_count = 50
        upper_bound = 20 * 1000000

        if table_size < sample_size:
            return 1

        batch = self.read_batch(start=0, end=sample_size, columns=columns)

        size_sample = _nested_list_size(batch)
        size_per_batch = size_sample * batch_size / sample_size

        # `size_per_batch * cache_batch_count` will not exceed the upper
        # bound, and `cache_batch_count` is always at least 1.
        cache_batch_count_estimate = max(int(upper_bound / size_per_batch), 1)

        return min(cache_batch_count_estimate, max_cache_batch_count)
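
For reference, the shard arithmetic in `_create_shards` can be reproduced on its own. The sketch below uses a hypothetical standalone helper, `split_into_shards`, to show how an input range is cut into (start, count) pairs:

def split_into_shards(start, count, shard_size):
    # Same arithmetic as `ODPSReader._create_shards`: full shards of
    # `shard_size` rows plus one trailing shard for the remainder.
    shards = []
    full, remainder = divmod(count, shard_size)
    for i in range(full):
        shards.append((start + i * shard_size, shard_size))
    if remainder:
        shards.append((start + full * shard_size, remainder))
    return shards


# (0, 250) with shard_size=100 -> [(0, 100), (100, 100), (200, 50)]
print(split_into_shards(0, 250, 100))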