import logging
import os
import threading

import tensorflow.compat.v1 as tf
from cityhash import CityHash32
from google.protobuf import empty_pb2
from tensorflow.compat.v1 import gfile

# Project-internal imports; the module paths below follow the fedlearner
# layout and may need adjusting for other checkouts. RawDataBatchFetcher
# (used by RawDataPartitioner) is defined alongside the partitioner in the
# original module and is assumed to be in scope here.
from fedlearner.common import common_pb2 as common_pb
from fedlearner.common import data_portal_service_pb2 as dp_pb
from fedlearner.common import data_portal_service_pb2_grpc as dp_grpc
from fedlearner.data_join import common
from fedlearner.data_join.csv_dict_writer import CsvDictWriter
from fedlearner.data_join.data_portal_job_manager import DataPortalJobManager
from fedlearner.data_join.routine_worker import RoutineWorker


class DataPortalMaster(dp_grpc.DataPortalMasterServiceServicer):
    """gRPC servicer that hands out map/reduce tasks to portal workers."""

    def __init__(self, portal_name, kvstore, portal_options):
        super(DataPortalMaster, self).__init__()
        self._portal_name = portal_name
        self._kvstore = kvstore
        self._portal_options = portal_options
        self._data_portal_job_manager = DataPortalJobManager(
                self._kvstore, self._portal_name,
                self._portal_options.long_running,
                self._portal_options.check_success_tag,
                self._portal_options.single_subfolder,
                self._portal_options.files_per_job_limit,
                start_date=self._portal_options.start_date,
                end_date=self._portal_options.end_date
            )
        self._bg_worker = None

    def GetDataPortalManifest(self, request, context):
        return self._data_portal_job_manager.get_portal_manifest()

    def RequestNewTask(self, request, context):
        # Allocate a task for the requesting worker; the response carries
        # exactly one of map_task, reduce_task, pending (retry later) or
        # finished.
        response = dp_pb.NewTaskResponse()
        finished, task = \
                self._data_portal_job_manager.alloc_task(request.rank_id)
        if task is not None:
            if isinstance(task, dp_pb.MapTask):
                response.map_task.MergeFrom(task)
            else:
                assert isinstance(task, dp_pb.ReduceTask)
                response.reduce_task.MergeFrom(task)
        elif not finished:
            response.pending.MergeFrom(empty_pb2.Empty())
        else:
            response.finished.MergeFrom(empty_pb2.Empty())
        return response

    def FinishTask(self, request, context):
        self._data_portal_job_manager.finish_task(request.rank_id,
                                                  request.partition_id,
                                                  request.part_state)
        return common_pb.Status()

    def start(self):
        # Kick the job manager's background routine every 30 seconds.
        self._bg_worker = RoutineWorker(
                'portal_master_bg_worker',
                self._data_portal_job_manager.backgroup_task,
                lambda: True, 30)
        self._bg_worker.start_routine()

    def stop(self):
        if self._bg_worker is not None:
            self._bg_worker.stop_routine()
            self._bg_worker = None
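# A minimal serving sketch, not part of the original module: it assumes the
# standard protoc-generated registration helper for the service above
# (add_DataPortalMasterServiceServicer_to_server) and an illustrative
# listen port and thread-pool size.
def serve_portal_master(portal_name, kvstore, portal_options, listen_port):
    import grpc
    from concurrent import futures
    master = DataPortalMaster(portal_name, kvstore, portal_options)
    master.start()  # launch the background job-manager routine
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=8))
    dp_grpc.add_DataPortalMasterServiceServicer_to_server(master, server)
    server.add_insecure_port('[::]:{}'.format(listen_port))
    server.start()
    try:
        server.wait_for_termination()
    finally:
        master.stop()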
class RawDataPartitioner(object):
    """Partitions raw data items into output_partition_num buckets keyed
    by CityHash32(raw_id); each bucket is written under its own partition
    directory."""

    class OutputFileWriter(object):
        def __init__(self, options, partition_id):
            self._options = options
            self._partition_id = partition_id
            self._process_index = 0
            self._writer = None
            self._dumped_item = 0
            self._output_fpaths = []
            self._output_dir = os.path.join(
                    self._options.output_dir,
                    common.partition_repr(self._partition_id)
                )
            if not gfile.Exists(self._output_dir):
                gfile.MakeDirs(self._output_dir)
            assert gfile.IsDirectory(self._output_dir)

        def append_item(self, index, item):
            writer = self._get_output_writer()
            if self._options.output_builder == 'TF_RECORD':
                writer.write(item.tf_record)
            else:
                assert self._options.output_builder == 'CSV_DICT'
                writer.write(item.csv_record)
            self._dumped_item += 1
            # Roll over to a fresh output file once the current one is full.
            if self._dumped_item >= self._options.output_item_threshold:
                self._finish_writer()
                if self._process_index % 16 == 0:
                    logging.info("Output partition %d dump %d files, "
                                 "last index %d", self._partition_id,
                                 self._process_index, index)

        def finish(self):
            self._finish_writer()

        def get_output_files(self):
            return self._output_fpaths

        def _get_output_writer(self):
            if self._writer is None:
                self._new_writer()
            return self._writer

        def _new_writer(self):
            assert self._writer is None
            fname = "{:04}-{:08}.rd".format(
                    self._options.partitioner_rank_id,
                    self._process_index
                )
            fpath = os.path.join(self._output_dir, fname)
            self._output_fpaths.append(fpath)
            if self._options.output_builder == 'TF_RECORD':
                self._writer = tf.io.TFRecordWriter(fpath)
            else:
                assert self._options.output_builder == 'CSV_DICT'
                self._writer = CsvDictWriter(fpath)
            self._dumped_item = 0

        def _finish_writer(self):
            if self._writer is not None:
                self._writer.close()
                self._writer = None
            self._dumped_item = 0
            self._process_index += 1

    def __init__(self, options):
        self._options = options
        self._raw_data_batch_fetcher = RawDataBatchFetcher(options)
        # Background routine that prefetches raw data batches for the
        # partition loop below.
        self._fetch_worker = RoutineWorker('raw_data_batch_fetcher',
                                           self._raw_data_batch_fetch_fn,
                                           self._raw_data_batch_fetch_cond,
                                           5)
        self._next_part_index = 0
        self._cond = threading.Condition()
        self._fetch_worker.start_routine()

    def partition(self):
        if self._check_finished_tag():
            logging.warning("partition has finished for rank id of "
                            "partitioner %d",
                            self._options.partitioner_rank_id)
            return
        next_index = 0
        hint_index = 0
        fetch_finished = False
        fetcher = self._raw_data_batch_fetcher
        writers = [RawDataPartitioner.OutputFileWriter(self._options, pid)
                   for pid in range(self._options.output_partition_num)]
        iter_round = 0
        bp_options = self._options.batch_processor_options
        # Evict consumed batches and wake the fetcher once every few
        # rounds; the max(1, ...) guards against a zero threshold when
        # max_flying_item is small relative to batch_size.
        signal_round_threshold = max(
                1, bp_options.max_flying_item // bp_options.batch_size // 8
            )
        while not fetch_finished:
            fetch_finished, batch, hint_index = \
                    fetcher.fetch_item_batch_by_index(next_index, hint_index)
            iter_round += 1
            if batch is not None:
                for index, item in enumerate(batch):
                    raw_id = item.raw_id
                    partition_id = CityHash32(raw_id) % \
                            self._options.output_partition_num
                    writer = writers[partition_id]
                    writer.append_item(batch.begin_index + index, item)
                next_index = batch.begin_index + len(batch)
                if iter_round % signal_round_threshold == 0:
                    hint_index = self._evict_staless_batch(hint_index,
                                                           next_index - 1)
                    logging.info("consumed %d items", next_index - 1)
                    self._set_next_part_index(next_index)
                    self._wakeup_raw_data_fetcher()
            elif not fetch_finished:
                # No batch ready yet; evict what we can and wait for the
                # fetcher to signal new data.
                hint_index = self._evict_staless_batch(hint_index,
                                                       next_index - 1)
                with self._cond:
                    self._cond.wait(1)
        for partition_id, writer in enumerate(writers):
            writer.finish()
            fpaths = writer.get_output_files()
            logging.info("part %d output %d files by partitioner",
                         partition_id, len(fpaths))
            for fpath in fpaths:
                logging.info("%s", fpath)
            logging.info("-----------------------------------")
        self._dump_finished_tag()
        self._fetch_worker.stop_routine()

    def _evict_staless_batch(self, hint_index, staless_index):
        evict_cnt = self._raw_data_batch_fetcher.evict_staless_item_batch(
                staless_index
            )
        if hint_index <= evict_cnt:
            return 0
        return hint_index - evict_cnt

    def _set_next_part_index(self, next_part_index):
        with self._cond:
            self._next_part_index = next_part_index

    def _get_next_part_index(self):
        with self._cond:
            return self._next_part_index

    def _raw_data_batch_fetch_fn(self):
        next_part_index = self._get_next_part_index()
        fetcher = self._raw_data_batch_fetcher
        for batch in fetcher.make_processor(next_part_index):
            logging.debug("fetch batch begin at %d, len %d. wakeup "
                          "partitioner", batch.begin_index, len(batch))
            self._wakeup_partitioner()

    def _raw_data_batch_fetch_cond(self):
        next_part_index = self._get_next_part_index()
        return self._raw_data_batch_fetcher.need_process(next_part_index)

    def _wakeup_partitioner(self):
        with self._cond:
            self._cond.notify_all()

    def _wakeup_raw_data_fetcher(self):
        self._fetch_worker.wakeup()

    def _dump_finished_tag(self):
        # An empty _SUCCESS.<rank_id> file marks this partitioner as done.
        finished_tag_fpath = self._get_finished_tag_fpath()
        with gfile.GFile(finished_tag_fpath, 'w') as fh:
            fh.write('')

    def _check_finished_tag(self):
        return gfile.Exists(self._get_finished_tag_fpath())

    def _get_finished_tag_fpath(self):
        return os.path.join(
                self._options.output_dir,
                '_SUCCESS.{:08}'.format(self._options.partitioner_rank_id)
            )
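# A minimal usage sketch, assuming an `options` object that carries the
# fields referenced above (output_dir, output_partition_num,
# partitioner_rank_id, output_builder, output_item_threshold, and
# batch_processor_options with batch_size/max_flying_item). The bucket
# assignment is deterministic, since CityHash32(raw_id) %
# output_partition_num always sends the same raw_id to the same partition,
# and the _SUCCESS.<rank_id> tag makes a rerun of an already-finished
# partitioner a no-op.
def run_partitioner(options):
    partitioner = RawDataPartitioner(options)
    partitioner.partition()  # blocks until every input batch is consumed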