# NOTE: the project-local import paths below are assumed from the surrounding
# repo layout (fedlearner-style); adjust them to the actual module locations.
import logging
import os
import threading
from datetime import timedelta

from cityhash import CityHash32
from google.protobuf import empty_pb2
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import gfile

from fedlearner.common import data_join_service_pb2_grpc as dj_grpc
from fedlearner.proxy.channel import make_insecure_channel, ChannelType
from fedlearner.data_join import common
from fedlearner.data_join.csv_dict_writer import CsvDictWriter
from fedlearner.data_join.raw_data_controller import RawDataController
from fedlearner.data_join.raw_data_visitor import RawDataBatchFetcher
from fedlearner.data_join.routine_worker import RoutineWorker


class PortalRawDataNotifier(object):
    class NotifyCtx(object):
        # per-downstream-master context: caches the grpc stub, the data
        # source meta and the latest notified datetime of each partition
        def __init__(self, master_addr):
            self._master_addr = master_addr
            channel = make_insecure_channel(master_addr, ChannelType.INTERNAL)
            self._master_cli = dj_grpc.DataJoinMasterServiceStub(channel)
            self._data_source = None
            self._raw_data_ctl = None
            self._raw_data_updated_datetime = {}

        @property
        def data_source(self):
            if self._data_source is None:
                self._data_source = \
                        self._master_cli.GetDataSource(empty_pb2.Empty())
            return self._data_source

        def get_raw_data_updated_datetime(self, partition_id):
            if partition_id not in self._raw_data_updated_datetime:
                ts = self._raw_data_controller.get_raw_data_latest_timestamp(
                        partition_id
                    )
                # step back one hour so the first notify round re-checks
                # the latest, possibly still growing, hourly file
                if ts.seconds > 3600:
                    ts.seconds -= 3600
                else:
                    ts.seconds = 0
                self._raw_data_updated_datetime[partition_id] = \
                        common.convert_timestamp_to_datetime(
                                common.trim_timestamp_by_hourly(ts)
                            )
            return self._raw_data_updated_datetime[partition_id]

        def add_raw_data(self, partition_id, fpaths, timestamps, end_ts):
            assert len(fpaths) == len(timestamps), \
                "the number of raw data paths and timestamps should be equal"
            if len(fpaths) > 0:
                self._raw_data_controller.add_raw_data(
                        partition_id, fpaths, True, timestamps
                    )
            self._raw_data_updated_datetime[partition_id] = end_ts

        @property
        def data_source_master_addr(self):
            return self._master_addr

        @property
        def _raw_data_controller(self):
            if self._raw_data_ctl is None:
                self._raw_data_ctl = RawDataController(self.data_source,
                                                       self._master_cli)
            return self._raw_data_ctl

    def __init__(self, etcd, portal_name, downstream_data_source_masters):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._portal_name = portal_name
        assert len(downstream_data_source_masters) > 0, \
            "PortalRawDataNotifier requires at least one data source " \
            "master to notify"
        self._master_notify_ctx = {}
        for addr in downstream_data_source_masters:
            self._master_notify_ctx[addr] = \
                    PortalRawDataNotifier.NotifyCtx(addr)
        self._notify_worker = None
        self._started = False

    def start_notify_worker(self):
        with self._lock:
            if not self._started:
                assert self._notify_worker is None, \
                    "notify worker should be None if not started"
                self._notify_worker = RoutineWorker(
                        'portal-raw-data-notifier',
                        self._raw_data_notify_fn,
                        self._raw_data_notify_cond, 5
                    )
                self._notify_worker.start_routine()
                self._started = True
            self._notify_worker.wakeup()

    def stop_notify_worker(self):
        notify_worker = None
        with self._lock:
            notify_worker = self._notify_worker
            self._notify_worker = None
        if notify_worker is not None:
            notify_worker.stop_routine()

    def _check_partition_num(self, notify_ctx, portal_manifest):
        assert isinstance(notify_ctx, PortalRawDataNotifier.NotifyCtx)
        data_source = notify_ctx.data_source
        ds_partition_num = data_source.data_source_meta.partition_num
        if portal_manifest.output_partition_num % ds_partition_num != 0:
            raise ValueError(
                    "the output partition number({}) of portal should be "
                    "divisible by the partition number({}) of downstream "
                    "data source {}".format(
                        portal_manifest.output_partition_num,
                        ds_partition_num,
                        data_source.data_source_meta.name
                    )
                )

    def _add_raw_data_impl(self, notify_ctx, portal_manifest, ds_pid):
        dt = notify_ctx.get_raw_data_updated_datetime(ds_pid) + \
                timedelta(hours=1)
        begin_dt = common.convert_timestamp_to_datetime(
                common.trim_timestamp_by_hourly(portal_manifest.begin_timestamp)
            )
        if dt < begin_dt:
            dt = begin_dt
        committed_dt = common.convert_timestamp_to_datetime(
                portal_manifest.committed_timestamp
            )
        fpaths = []
        timestamps = []
        ds_ptnum = notify_ctx.data_source.data_source_meta.partition_num
        while dt <= committed_dt:
            # every ds_ptnum-th portal output partition maps to the same
            # downstream data source partition ds_pid
            for pt_pid in range(ds_pid,
                                portal_manifest.output_partition_num,
                                ds_ptnum):
                fpath = common.encode_portal_hourly_fpath(
                        portal_manifest.output_data_base_dir, dt, pt_pid
                    )
                if gfile.Exists(fpath):
                    fpaths.append(fpath)
                    timestamps.append(
                            common.convert_datetime_to_timestamp(dt)
                        )
            if len(fpaths) > 32 or dt == committed_dt:
                break
            dt += timedelta(hours=1)
        notify_ctx.add_raw_data(ds_pid, fpaths, timestamps, dt)
        logging.info("added %d raw data files for partition %d of data "
                     "source %s. latest updated datetime %s",
                     len(fpaths), ds_pid,
                     notify_ctx.data_source.data_source_meta.name, dt)
        return dt >= committed_dt

    def _notify_one_data_source(self, notify_ctx, portal_manifest):
        assert isinstance(notify_ctx, PortalRawDataNotifier.NotifyCtx)
        try:
            self._check_partition_num(notify_ctx, portal_manifest)
            ds_ptnum = notify_ctx.data_source.data_source_meta.partition_num
            add_finished = False
            while not add_finished:
                add_finished = True
                for ds_pid in range(ds_ptnum):
                    if not self._add_raw_data_impl(notify_ctx,
                                                   portal_manifest,
                                                   ds_pid):
                        add_finished = False
        except Exception as e: # pylint: disable=broad-except
            logging.error("Failed to notify data source[master-addr: %s] "
                          "of newly added raw data, reason: %s",
                          notify_ctx.data_source_master_addr, e)

    def _raw_data_notify_fn(self):
        portal_manifest = common.retrieve_portal_manifest(
                self._etcd, self._portal_name
            )
        for _, notify_ctx in self._master_notify_ctx.items():
            self._notify_one_data_source(notify_ctx, portal_manifest)

    def _raw_data_notify_cond(self):
        return True
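
# Usage sketch (hypothetical wiring; `etcd_cli`, the portal name and the
# downstream master addresses are illustrative placeholders, not defined in
# this module):
#
#   notifier = PortalRawDataNotifier(
#           etcd_cli, 'example_portal',
#           ['dsm-a.example.com:50051', 'dsm-b.example.com:50051']
#       )
#   notifier.start_notify_worker()  # RoutineWorker polls every 5 seconds
#   ...
#   notifier.stop_notify_worker()
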
class RawDataPartitioner(object):
    class OutputFileWriter(object):
        def __init__(self, options, partition_id):
            self._options = options
            self._partition_id = partition_id
            self._process_index = 0
            self._writer = None
            self._dumped_item = 0
            self._output_fpaths = []
            self._output_dir = os.path.join(
                    self._options.output_dir,
                    common.partition_repr(self._partition_id)
                )
            if not gfile.Exists(self._output_dir):
                gfile.MakeDirs(self._output_dir)
            assert gfile.IsDirectory(self._output_dir)

        def append_item(self, index, item):
            writer = self._get_output_writer()
            if self._options.output_builder == 'TF_RECORD':
                writer.write(item.tf_record)
            else:
                assert self._options.output_builder == 'CSV_DICT'
                writer.write(item.csv_record)
            self._dumped_item += 1
            # rotate the output file once it reaches the item threshold
            if self._dumped_item >= self._options.output_item_threshold:
                self._finish_writer()
                if self._process_index % 16 == 0:
                    logging.info("Output partition %d dumped %d files, "
                                 "last index %d", self._partition_id,
                                 self._process_index, index)

        def finish(self):
            self._finish_writer()

        def get_output_files(self):
            return self._output_fpaths

        def _get_output_writer(self):
            if self._writer is None:
                self._new_writer()
            return self._writer

        def _new_writer(self):
            assert self._writer is None
            fname = "{:04}-{:08}.rd".format(
                    self._options.partitioner_rank_id,
                    self._process_index
                )
            fpath = os.path.join(self._output_dir, fname)
            self._output_fpaths.append(fpath)
            if self._options.output_builder == 'TF_RECORD':
                self._writer = tf.io.TFRecordWriter(fpath)
            else:
                assert self._options.output_builder == 'CSV_DICT'
                self._writer = CsvDictWriter(fpath)
            self._dumped_item = 0

        def _finish_writer(self):
            if self._writer is not None:
                self._writer.close()
                self._writer = None
                self._dumped_item = 0
                self._process_index += 1

    def __init__(self, options):
        self._options = options
        self._raw_data_batch_fetcher = RawDataBatchFetcher(options)
        self._fetch_worker = RoutineWorker('raw_data_batch_fetcher',
                                           self._raw_data_batch_fetch_fn,
                                           self._raw_data_batch_fetch_cond,
                                           5)
        self._next_part_index = 0
        self._cond = threading.Condition()
        self._fetch_worker.start_routine()

    def partition(self):
        if self._check_finished_tag():
            logging.warning("partition has already finished for rank %d "
                            "of partitioner",
                            self._options.partitioner_rank_id)
            return
        next_index = 0
        hint_index = 0
        fetch_finished = False
        fetcher = self._raw_data_batch_fetcher
        writers = [RawDataPartitioner.OutputFileWriter(self._options, pid)
                   for pid in range(self._options.output_partition_num)]
        iter_round = 0
        bp_options = self._options.batch_processor_options
        # use integer arithmetic and guard against a zero threshold when
        # max_flying_item < 8 * batch_size
        signal_round_threshold = max(
                bp_options.max_flying_item // bp_options.batch_size // 8, 1
            )
        while not fetch_finished:
            fetch_finished, batch, hint_index = \
                    fetcher.fetch_item_batch_by_index(next_index, hint_index)
            iter_round += 1
            if batch is not None:
                for index, item in enumerate(batch):
                    # route each item to an output partition by hashing
                    # its raw_id
                    raw_id = item.raw_id
                    partition_id = CityHash32(raw_id) % \
                            self._options.output_partition_num
                    writer = writers[partition_id]
                    writer.append_item(batch.begin_index + index, item)
                next_index = batch.begin_index + len(batch)
                if iter_round % signal_round_threshold == 0:
                    hint_index = self._evict_stale_batch(hint_index,
                                                         next_index - 1)
                    logging.info("consumed %d items", next_index - 1)
                    self._set_next_part_index(next_index)
                    self._wakeup_raw_data_fetcher()
            elif not fetch_finished:
                hint_index = self._evict_stale_batch(hint_index,
                                                     next_index - 1)
                with self._cond:
                    self._cond.wait(1)
        for partition_id, writer in enumerate(writers):
            writer.finish()
            fpaths = writer.get_output_files()
            logging.info("partition %d output %d files by partitioner",
                         partition_id, len(fpaths))
            for fpath in fpaths:
                logging.info("%s", fpath)
            logging.info("-----------------------------------")
        self._dump_finished_tag()
        self._fetch_worker.stop_routine()

    def _evict_stale_batch(self, hint_index, stale_index):
        evict_cnt = self._raw_data_batch_fetcher.evict_staless_item_batch(
                stale_index
            )
        if hint_index <= evict_cnt:
            return 0
        return hint_index - evict_cnt

    def _set_next_part_index(self, next_part_index):
        with self._cond:
            self._next_part_index = next_part_index

    def _get_next_part_index(self):
        with self._cond:
            return self._next_part_index

    def _raw_data_batch_fetch_fn(self):
        next_part_index = self._get_next_part_index()
        fetcher = self._raw_data_batch_fetcher
        for batch in fetcher.make_processor(next_part_index):
            logging.debug("fetched batch beginning at %d, len %d. wakeup "
                          "partitioner", batch.begin_index, len(batch))
            self._wakeup_partitioner()

    def _raw_data_batch_fetch_cond(self):
        next_part_index = self._get_next_part_index()
        return self._raw_data_batch_fetcher.need_process(next_part_index)

    def _wakeup_partitioner(self):
        with self._cond:
            self._cond.notify_all()

    def _wakeup_raw_data_fetcher(self):
        self._fetch_worker.wakeup()

    def _dump_finished_tag(self):
        # an empty _SUCCESS.<rank> file marks this rank's partitioning done
        finished_tag_fpath = self._get_finished_tag_fpath()
        with gfile.GFile(finished_tag_fpath, 'w') as fh:
            fh.write('')

    def _check_finished_tag(self):
        return gfile.Exists(self._get_finished_tag_fpath())

    def _get_finished_tag_fpath(self):
        return os.path.join(
                self._options.output_dir,
                '_SUCCESS.{:08}'.format(self._options.partitioner_rank_id)
            )
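
# Usage sketch (hypothetical `options` object; the field names mirror the
# attributes this class reads, e.g. output_dir, output_partition_num,
# output_builder, output_item_threshold, partitioner_rank_id,
# batch_processor_options):
#
#   partitioner = RawDataPartitioner(options)
#   partitioner.partition()  # idempotent: returns early if this rank's
#                            # _SUCCESS tag file already exists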