Example #1: DataPortalMaster, the gRPC servicer that hands out data portal map/reduce tasks
# NOTE: the imports below are a best-effort reconstruction of the
# fedlearner package layout and may differ in your checkout.
from google.protobuf import empty_pb2

from fedlearner.common import common_pb2 as common_pb
from fedlearner.common import data_portal_service_pb2 as dp_pb
from fedlearner.common import data_portal_service_pb2_grpc as dp_grpc
from fedlearner.data_join.routine_worker import RoutineWorker
from fedlearner.data_join.data_portal_job_manager import DataPortalJobManager

class DataPortalMaster(dp_grpc.DataPortalMasterServiceServicer):
    def __init__(self, portal_name, kvstore, portal_options):
        super(DataPortalMaster, self).__init__()
        self._portal_name = portal_name
        self._kvstore = kvstore
        self._portal_options = portal_options
        self._data_portal_job_manager = DataPortalJobManager(
            self._kvstore,
            self._portal_name,
            self._portal_options.long_running,
            self._portal_options.check_success_tag,
            self._portal_options.single_subfolder,
            self._portal_options.files_per_job_limit,
            start_date=self._portal_options.start_date,
            end_date=self._portal_options.end_date)
        self._bg_worker = None

    def GetDataPortalManifest(self, request, context):
        return self._data_portal_job_manager.get_portal_manifest()

    def RequestNewTask(self, request, context):
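        # Allocate a map/reduce task for the worker identified by
        # request.rank_id; reply "pending" while the job still has work in
        # flight and "finished" once all tasks have been handed out.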
        response = dp_pb.NewTaskResponse()
        finished, task = \
            self._data_portal_job_manager.alloc_task(request.rank_id)
        if task is not None:
            if isinstance(task, dp_pb.MapTask):
                response.map_task.MergeFrom(task)
            else:
                assert isinstance(task, dp_pb.ReduceTask)
                response.reduce_task.MergeFrom(task)
        elif not finished:
            response.pending.MergeFrom(empty_pb2.Empty())
        else:
            response.finished.MergeFrom(empty_pb2.Empty())
        return response

    def FinishTask(self, request, context):
        self._data_portal_job_manager.finish_task(request.rank_id,
                                                  request.partition_id,
                                                  request.part_state)
        return common_pb.Status()

    def start(self):
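        # Spawn the background worker that periodically (roughly every 30
        # seconds) drives the job manager's backgroup_task routine.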
        self._bg_worker = RoutineWorker(
            'portal_master_bg_worker',
            self._data_portal_job_manager.backgroup_task, lambda: True, 30)
        self._bg_worker.start_routine()

    def stop(self):
        if self._bg_worker is not None:
            self._bg_worker.stop_routine()
        self._bg_worker = None
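
The sketch below is not part of the source; it shows one way DataPortalMaster
could be hosted as a gRPC service. The port, the worker count and the
add_DataPortalMasterServiceServicer_to_server helper name follow the standard
grpc codegen convention and are assumptions made only for illustration.

from concurrent import futures
import grpc

def serve(portal_name, kvstore, portal_options, port=4032):
    # Build the servicer, start its background job-manager worker, then
    # register it with a plain insecure gRPC server.
    master = DataPortalMaster(portal_name, kvstore, portal_options)
    master.start()
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=8))
    dp_grpc.add_DataPortalMasterServiceServicer_to_server(master, server)
    server.add_insecure_port('[::]:{}'.format(port))
    server.start()
    try:
        server.wait_for_termination()
    finally:
        master.stop()
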
Example #2: RawDataPartitioner, which shards raw data items into per-partition output files by hashing raw_id
# NOTE: the imports below are a best-effort reconstruction; CsvDictWriter and
# RawDataBatchFetcher are assumed to be defined or imported elsewhere in the
# same fedlearner module.
import os
import logging
import threading

import tensorflow as tf
from tensorflow.compat.v1 import gfile
from cityhash import CityHash32

from fedlearner.data_join import common
from fedlearner.data_join.routine_worker import RoutineWorker

class RawDataPartitioner(object):
    class OutputFileWriter(object):
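        # Collects the items routed to a single output partition and rolls
        # over to a new file once output_item_threshold items have been
        # written to the current one.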
        def __init__(self, options, partition_id):
            self._options = options
            self._partition_id = partition_id
            self._process_index = 0
            self._writer = None
            self._dumped_item = 0
            self._output_fpaths = []
            self._output_dir = os.path.join(
                    self._options.output_dir,
                    common.partition_repr(self._partition_id)
                )
            if not gfile.Exists(self._output_dir):
                gfile.MakeDirs(self._output_dir)
            assert gfile.IsDirectory(self._output_dir)

        def append_item(self, index, item):
            writer = self._get_output_writer()
            if self._options.output_builder == 'TF_RECORD':
                writer.write(item.tf_record)
            else:
                assert self._options.output_builder == 'CSV_DICT'
                writer.write(item.csv_record)
            self._dumped_item += 1
            if self._dumped_item >= self._options.output_item_threshold:
                self._finish_writer()
                if self._process_index % 16 == 0:
                    logging.info("Output partition %d dump %d files, "\
                                 "last index %d", self._partition_id,
                                 self._process_index, index)

        def finish(self):
            self._finish_writer()

        def get_output_files(self):
            return self._output_fpaths

        def _get_output_writer(self):
            if self._writer is None:
                self._new_writer()
            return self._writer

        def _new_writer(self):
            assert self._writer is None
            fname = "{:04}-{:08}.rd".format(
                    self._options.partitioner_rank_id,
                    self._process_index
                )
            fpath = os.path.join(self._output_dir, fname)
            self._output_fpaths.append(fpath)
            if self._options.output_builder == 'TF_RECORD':
                self._writer = tf.io.TFRecordWriter(fpath)
            else:
                assert self._options.output_builder == 'CSV_DICT'
                self._writer = CsvDictWriter(fpath)
            self._dumped_item = 0

        def _finish_writer(self):
            if self._writer is not None:
                self._writer.close()
                self._writer = None
            self._dumped_item = 0
            self._process_index += 1

    def __init__(self, options):
        self._options = options
        self._raw_data_batch_fetcher = RawDataBatchFetcher(options)
        self._fetch_worker = RoutineWorker('raw_data_batch_fetcher',
                                           self._raw_data_batch_fetch_fn,
                                           self._raw_data_batch_fetch_cond, 5)
        self._next_part_index = 0
        self._cond = threading.Condition()
        self._fetch_worker.start_routine()

    def partition(self):
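        # Main loop: pull item batches from the fetcher, route each item to
        # an output partition by hashing its raw_id, and let the per-partition
        # writers roll their files; a _SUCCESS tag is dumped at the end so the
        # work is not repeated on restart.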
        if self._check_finished_tag():
            logging.warning("partition has finished for rank id of parti"\
                            "tioner %d", self._options.partitioner_rank_id)
            return
        next_index = 0
        hint_index = 0
        fetch_finished = False
        fetcher = self._raw_data_batch_fetcher
        writers = [RawDataPartitioner.OutputFileWriter(self._options, pid)
                   for pid in range(self._options.output_partition_num)]
        iter_round = 0
        bp_options = self._options.batch_processor_options
        # How often (in fetch rounds) to evict stale batches and log
        # progress; guarded so small configurations never yield zero.
        signal_round_threshold = max(
                1, bp_options.max_flying_item / bp_options.batch_size // 8)
        while not fetch_finished:
            fetch_finished, batch, hint_index = \
                    fetcher.fetch_item_batch_by_index(next_index, hint_index)
            iter_round += 1
            if batch is not None:
                for index, item in enumerate(batch):
                    raw_id = item.raw_id
                    partition_id = CityHash32(raw_id) % \
                            self._options.output_partition_num
                    writer = writers[partition_id]
                    writer.append_item(batch.begin_index+index, item)
                next_index = batch.begin_index + len(batch)
                if iter_round % signal_round_threshold == 0:
                    hint_index = self._evict_staless_batch(hint_index,
                                                           next_index-1)
                    logging.info("consumed %d items", next_index-1)
                self._set_next_part_index(next_index)
                self._wakeup_raw_data_fetcher()
            elif not fetch_finished:
                hint_index = self._evict_staless_batch(hint_index,
                                                       next_index-1)
                with self._cond:
                    self._cond.wait(1)
        for partition_id, writer in enumerate(writers):
            writer.finish()
            fpaths = writer.get_output_files()
            logging.info("part %d output %d files by partitioner",
                          partition_id, len(fpaths))
            for fpath in fpaths:
                logging.info("%s", fpath)
            logging.info("-----------------------------------")
        self._dump_finished_tag()
        self._fetch_worker.stop_routine()

    def _evict_staless_batch(self, hint_index, staless_index):
        evict_cnt = self._raw_data_batch_fetcher.evict_staless_item_batch(
                staless_index
            )
        if hint_index <= evict_cnt:
            return 0
        return hint_index-evict_cnt

    def _set_next_part_index(self, next_part_index):
        with self._cond:
            self._next_part_index = next_part_index

    def _get_next_part_index(self):
        with self._cond:
            return self._next_part_index

    def _raw_data_batch_fetch_fn(self):
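        # Background routine: keep fetching raw-data batches starting from
        # the partitioner's current read position and wake the partitioner
        # whenever a new batch becomes available.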
        next_part_index = self._get_next_part_index()
        fetcher = self._raw_data_batch_fetcher
        for batch in fetcher.make_processor(next_part_index):
            logging.debug("fetch batch begin at %d, len %d. wakeup "\
                          "partitioner", batch.begin_index, len(batch))
            self._wakeup_partitioner()

    def _raw_data_batch_fetch_cond(self):
        next_part_index = self._get_next_part_index()
        return self._raw_data_batch_fetcher.need_process(next_part_index)

    def _wakeup_partitioner(self):
        with self._cond:
            self._cond.notify_all()

    def _wakeup_raw_data_fetcher(self):
        self._fetch_worker.wakeup()

    def _dump_finished_tag(self):
        finished_tag_fpath = self._get_finished_tag_fpath()
        with gfile.GFile(finished_tag_fpath, 'w') as fh:
            fh.write('')

    def _check_finished_tag(self):
        return gfile.Exists(self._get_finished_tag_fpath())

    def _get_finished_tag_fpath(self):
        return os.path.join(
                self._options.output_dir,
                '_SUCCESS.{:08}'.format(self._options.partitioner_rank_id)
            )
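
A minimal sketch (not from the source) of the routing rule partition() relies
on: each record's raw_id is hashed with CityHash32 and reduced modulo
output_partition_num, so every occurrence of the same raw_id lands in the same
output partition. The helper name and the assumption that the cityhash binding
accepts a str are illustrative only.

from cityhash import CityHash32

def partition_id_for(raw_id, output_partition_num):
    # Same rule the partitioner applies per item: stable hash, then modulo.
    return CityHash32(raw_id) % output_partition_num

# With 4 output partitions, every call for the same raw_id returns the same
# index, which is what keeps all records of one id in one partition file.
print(partition_id_for('example_raw_id', 4))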