def start_process(self):
    with self._lock:
        if not self._started:
            self._worker_map = {
                self._id_batch_fetcher_name(): RoutineWorker(
                    self._id_batch_fetcher_name(),
                    self._id_batch_fetch_fn,
                    self._id_batch_fetch_cond, 5),
                self._psi_rsa_signer_name(): RoutineWorker(
                    self._psi_rsa_signer_name(),
                    self._psi_rsa_sign_fn,
                    self._psi_rsa_sign_cond, 5),
                self._sort_run_dumper_name(): RoutineWorker(
                    self._sort_run_dumper_name(),
                    self._sort_run_dump_fn,
                    self._sort_run_dump_cond, 5),
                self._sort_run_merger_name(): RoutineWorker(
                    self._sort_run_merger_name(),
                    self._sort_run_merge_fn,
                    self._sort_run_merge_cond, 5)
            }
            for _, w in self._worker_map.items():
                w.start_routine()
            self._started = True
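# Every snippet in this section delegates its background work to RoutineWorker,
# whose definition never appears here. The class below is a minimal hedged
# sketch of the contract these call sites assume -- RoutineWorker(name,
# routine_fn, cond_fn, interval) plus start_routine/stop_routine/wakeup/
# setup_args -- not the real implementation.
import threading

class RoutineWorkerSketch(object):
    def __init__(self, name, routine_fn, cond_fn, interval):
        self._name = name
        self._routine_fn = routine_fn
        self._cond_fn = cond_fn
        self._interval = interval
        self._cv = threading.Condition()
        self._stopped = False
        self._args = ()
        self._thread = None

    def start_routine(self):
        # run the polling loop on a named daemon thread
        self._thread = threading.Thread(target=self._run, name=self._name)
        self._thread.daemon = True
        self._thread.start()

    def setup_args(self, *args):
        # stash arguments for the next routine_fn invocation
        self._args = args

    def wakeup(self):
        # cut the interval sleep short so routine_fn is reconsidered now
        with self._cv:
            self._cv.notify_all()

    def stop_routine(self):
        with self._cv:
            self._stopped = True
            self._cv.notify_all()
        if self._thread is not None:
            self._thread.join()

    def _run(self):
        while True:
            # run the routine whenever its condition holds, then sleep until
            # either the interval elapses or someone calls wakeup()/stop
            if self._cond_fn():
                self._routine_fn(*self._args)
            with self._cv:
                if self._stopped:
                    return
                self._cv.wait(self._interval)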
class DataPortalMaster(dp_grpc.DataPortalMasterServiceServicer):
    def __init__(self, portal_name, kvstore, portal_options):
        super(DataPortalMaster, self).__init__()
        self._portal_name = portal_name
        self._kvstore = kvstore
        self._portal_options = portal_options
        self._data_portal_job_manager = DataPortalJobManager(
            self._kvstore, self._portal_name,
            self._portal_options.long_running,
            self._portal_options.check_success_tag,
            self._portal_options.single_subfolder,
            self._portal_options.files_per_job_limit,
            start_date=self._portal_options.start_date,
            end_date=self._portal_options.end_date)
        self._bg_worker = None

    def GetDataPortalManifest(self, request, context):
        return self._data_portal_job_manager.get_portal_manifest()

    def RequestNewTask(self, request, context):
        response = dp_pb.NewTaskResponse()
        finished, task = \
            self._data_portal_job_manager.alloc_task(request.rank_id)
        if task is not None:
            if isinstance(task, dp_pb.MapTask):
                response.map_task.MergeFrom(task)
            else:
                assert isinstance(task, dp_pb.ReduceTask)
                response.reduce_task.MergeFrom(task)
        elif not finished:
            response.pending.MergeFrom(empty_pb2.Empty())
        else:
            response.finished.MergeFrom(empty_pb2.Empty())
        return response

    def FinishTask(self, request, context):
        self._data_portal_job_manager.finish_task(request.rank_id,
                                                  request.partition_id,
                                                  request.part_state)
        return common_pb.Status()

    def start(self):
        self._bg_worker = RoutineWorker(
            'portal_master_bg_worker',
            self._data_portal_job_manager.backgroup_task,
            lambda: True, 30)
        self._bg_worker.start_routine()

    def stop(self):
        if self._bg_worker is not None:
            self._bg_worker.stop_routine()
            self._bg_worker = None
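# A hedged client-side counterpart to RequestNewTask above: the server fills
# exactly one of map_task/reduce_task/pending/finished, so a worker can
# dispatch on HasField. `stub`, `request`, and the run_* helpers are
# hypothetical names introduced for illustration, not part of the source.
import time

def poll_for_task(stub, request):
    response = stub.RequestNewTask(request)
    if response.HasField('map_task'):
        run_map_task(response.map_task)        # hypothetical helper
    elif response.HasField('reduce_task'):
        run_reduce_task(response.reduce_task)  # hypothetical helper
    elif response.HasField('pending'):
        time.sleep(5)                          # no task yet; poll again later
    else:
        assert response.HasField('finished')   # the whole job is done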
def start_process(self):
    with self._cond:
        if not self._started:
            self._worker_map = {
                'raw_data_batch_fetcher': RoutineWorker(
                    'raw_data_batch_fetcher',
                    self._raw_data_batch_fetch_fn,
                    self._raw_data_batch_fetch_cond, 5),
                'raw_data_partitioner': RoutineWorker(
                    'raw_data_partitioner',
                    self._raw_data_part_fn,
                    self._raw_data_part_cond, 5)
            }
            for _, w in self._worker_map.items():
                w.start_routine()
            self._started = True
def start_routine_workers(self):
    with self._lock:
        if not self._started:
            self._worker_map = {
                'sync_partition_allocator': RoutineWorker(
                    'sync_partition_allocator',
                    self._allocate_sync_partition_fn,
                    self._allocate_sync_partition_cond, 5),
                'follower_example_id_syncer': RoutineWorker(
                    'follower_example_id_syncer',
                    self._sync_follower_example_id_fn,
                    self._sync_follower_example_id_cond, 5),
            }
            for _, w in self._worker_map.items():
                w.start_routine()
            self._state = _SyncState.ALLOC_SYNC_PARTITION
            self._started = True
def start_routine_workers(self):
    with self._lock:
        if not self._started:
            self._worker_map = {
                'join_partition_allocator': RoutineWorker(
                    'join_partition_allocator',
                    self._allocate_join_partition_fn,
                    self._allocate_join_partition_cond, 5),
                'leader_example_joiner': RoutineWorker(
                    'leader_example_joiner',
                    self._join_leader_example_fn,
                    self._join_leader_example_cond, 5),
                'data_block_meta_syncer': RoutineWorker(
                    'data_block_meta_syncer',
                    self._sync_data_block_meta_fn,
                    self._sync_data_block_meta_cond, 5),
            }
            for _, w in self._worker_map.items():
                w.start_routine()
            self._state = _JoinState.ALLOC_JOIN_PARTITION
            self._started = True
def start_routine_workers(self):
    with self._lock:
        if not self._started:
            self._worker_map = {
                self._input_data_ready_sniffer_name(): RoutineWorker(
                    self._input_data_ready_sniffer_name(),
                    self._input_data_ready_sniff_fn,
                    self._input_data_ready_sniff_cond, 5),
                self._repart_executor_name(): RoutineWorker(
                    self._repart_executor_name(),
                    self._repart_execute_fn,
                    self._repart_execute_cond, 5),
                self._committed_datetime_forwarder_name(): RoutineWorker(
                    self._committed_datetime_forwarder_name(),
                    self._committed_datetime_forward_fn,
                    self._committed_datetime_forward_cond, 5)
            }
            for _, w in self._worker_map.items():
                w.start_routine()
            # set the flag once, after all workers have been started
            self._started = True
def start_routine_workers(self):
    with self._lock:
        if not self._started:
            self._worker_map = {
                self._partition_allocator_name(): RoutineWorker(
                    self._partition_allocator_name(),
                    self._allocate_new_partition_fn,
                    self._allocate_new_partition_cond, 5),
                self._producer_name(): RoutineWorker(
                    self._producer_name(),
                    self._data_producer_fn,
                    self._data_producer_cond, 5),
                self._consumer_name(): RoutineWorker(
                    self._consumer_name(),
                    self._data_consumer_fn,
                    self._data_consumer_cond, 5),
            }
            for _, w in self._worker_map.items():
                w.start_routine()
            self._started = True
            self._wakeup_new_partition_allocator()
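# The start_process/start_routine_workers snippets above all repeat one idiom:
# under a lock, build a name -> RoutineWorker map, start every worker once,
# and flip a _started flag so a second call is a no-op. Below is a hedged
# helper (not in the source) that factors the idiom out; the worker class is
# injected and assumed to behave as sketched earlier.
import threading

class RoutineWorkerGroup(object):
    def __init__(self, worker_cls):
        self._worker_cls = worker_cls
        self._lock = threading.Lock()
        self._worker_map = {}
        self._started = False

    def start_once(self, specs):
        # specs: iterable of (name, routine_fn, cond_fn, interval)
        with self._lock:
            if self._started:
                return
            self._worker_map = {
                name: self._worker_cls(name, fn, cond, interval)
                for name, fn, cond, interval in specs
            }
            for worker in self._worker_map.values():
                worker.start_routine()
            self._started = True

    def stop_all(self):
        # symmetric teardown: detach under the lock, join outside it -- the
        # same shape the stop_* methods in this section use
        with self._lock:
            workers, self._worker_map = self._worker_map, {}
            self._started = False
        for worker in workers.values():
            worker.stop_routine()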
class MasterFSM(object):
    INVALID_PEER_FSM_STATE = {}
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Init] = set([
        common_pb.DataSourceState.Failed,
        common_pb.DataSourceState.Ready,
        common_pb.DataSourceState.Finished
    ])
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Processing] = set([
        common_pb.DataSourceState.Failed,
        common_pb.DataSourceState.Finished
    ])
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Ready] = set([
        common_pb.DataSourceState.Failed,
        common_pb.DataSourceState.Init
    ])
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Finished] = set([
        common_pb.DataSourceState.Failed,
        common_pb.DataSourceState.Init,
        common_pb.DataSourceState.Processing
    ])

    def __init__(self, peer_client, data_source_name, etcd):
        self._lock = threading.Lock()
        self._peer_client = peer_client
        self._data_source_name = data_source_name
        self._master_etcd_key = os.path.join(data_source_name, 'master')
        self._etcd = etcd
        self._init_fsm_action()
        self._data_source = None
        self._sync_data_source()
        assert self._data_source is not None
        self._raw_data_manifest_manager = RawDataManifestManager(
            etcd, self._data_source)
        self._data_source_meta = self._data_source.data_source_meta
        if self._data_source.role == common_pb.FLRole.Leader:
            self._role_repr = "leader"
        else:
            self._role_repr = "follower"
        self._fsm_worker = None
        self._started = False

    def get_mainifest_manager(self):
        return self._raw_data_manifest_manager

    def get_data_source(self):
        with self._lock:
            self._sync_data_source()
            assert self._data_source is not None
            return self._data_source

    def set_failed(self):
        return self.set_state(common_pb.DataSourceState.Failed, None)

    def set_state(self, new_state, origin_state=None):
        with self._lock:
            try:
                self._sync_data_source()
                assert self._data_source is not None
                if self._data_source.state == new_state:
                    return True
                if (origin_state is None or
                        self._data_source.state == origin_state):
                    self._data_source.state = new_state
                    self._update_data_source(self._data_source)
                    return True
                logging.warning("DataSource: %s failed to set to state: "
                                "%d, origin state mismatch(%d != %d)",
                                self._data_source_name, new_state,
                                origin_state, self._data_source.state)
                return False
            except Exception as e:  # pylint: disable=broad-except
                logging.warning("Failed to set state to %d with exception %s",
                                new_state, e)
                return False

    def start_fsm_worker(self):
        with self._lock:
            if not self._started:
                assert self._fsm_worker is None
                self._started = True
                self._fsm_worker = RoutineWorker(
                    '{}_fsm_worker'.format(self._data_source_name),
                    self._fsm_routine_fn,
                    self._fsm_routine_cond, 5)
                self._fsm_worker.start_routine()

    def stop_fsm_worker(self):
        tmp_worker = None
        with self._lock:
            if self._fsm_worker is not None:
                tmp_worker = self._fsm_worker
                self._fsm_worker = None
        if tmp_worker is not None:
            tmp_worker.stop_routine()

    def _fsm_routine_fn(self):
        peer_info = self._get_peer_data_source_state()
        if peer_info.status.code != 0:
            logging.error("Failed to get peer state, %s",
                          peer_info.status.error_message)
            return
        with self._lock:
            self._sync_data_source()
            assert self._data_source is not None
            state = self._data_source.state
            if self._fallback_failed_state(peer_info):
                logging.warning("Self(%s) at state %d, Peer at state %d. "
                                "State invalid! abort data source %s",
                                self._role_repr, state, peer_info.state,
                                self._data_source_name)
            elif state not in self._fsm_driven_handle:
                logging.error("Self(%s) at error state %d for data_source %s",
                              self._role_repr, state, self._data_source_name)
            else:
                state_changed = self._fsm_driven_handle[state](peer_info)
                if state_changed:
                    self._sync_data_source()
                    assert self._data_source is not None
                    new_state = self._data_source.state
                    logging.warning("Self(%s) state changed from %d to %d",
                                    self._role_repr, state, new_state)

    def _fsm_routine_cond(self):
        return True

    def _sync_data_source(self):
        if self._data_source is None:
            raw_data = self._etcd.get_data(self._master_etcd_key)
            if raw_data is None:
                raise ValueError("etcd master key is None for {}".format(
                    self._data_source_name))
            self._data_source = text_format.Parse(raw_data,
                                                  common_pb.DataSource())

    def _init_fsm_action(self):
        self._fsm_driven_handle = {
            common_pb.DataSourceState.UnKnown:
                self._get_fsm_action('unknown'),
            common_pb.DataSourceState.Init:
                self._get_fsm_action('init'),
            common_pb.DataSourceState.Processing:
                self._get_fsm_action('processing'),
            common_pb.DataSourceState.Ready:
                self._get_fsm_action('ready'),
            common_pb.DataSourceState.Finished:
                self._get_fsm_action('finished'),
            common_pb.DataSourceState.Failed:
                self._get_fsm_action('failed')
        }

    def _get_fsm_action(self, action):
        def _not_implement(useless):
            raise NotImplementedError(
                'action for this state is not implemented')
        name = '_fsm_{}_action'.format(action)
        return getattr(self, name, _not_implement)

    def _fsm_init_action(self, peer_info):
        state_changed = False
        if self._data_source.role == common_pb.FLRole.Leader:
            if peer_info.state == common_pb.DataSourceState.Init:
                state_changed = True
        elif peer_info.state == common_pb.DataSourceState.Processing:
            state_changed = True
        if state_changed:
            self._data_source.state = common_pb.DataSourceState.Processing
            self._update_data_source(self._data_source)
            return True
        return False

    def _fsm_processing_action(self, peer_info):
        if self._all_partition_finished():
            state_changed = False
            if self._data_source.role == common_pb.FLRole.Leader:
                if peer_info.state == common_pb.DataSourceState.Processing:
                    state_changed = True
            elif peer_info.state == common_pb.DataSourceState.Ready:
                state_changed = True
            if state_changed:
                self._data_source.state = common_pb.DataSourceState.Ready
                self._update_data_source(self._data_source)
                return True
        return False

    def _fsm_ready_action(self, peer_info):
        state_changed = False
        if self._data_source.role == common_pb.FLRole.Leader:
            if peer_info.state == common_pb.DataSourceState.Ready:
                state_changed = True
        elif peer_info.state == common_pb.DataSourceState.Finished:
            state_changed = True
        if state_changed:
            self._data_source.state = common_pb.DataSourceState.Finished
            self._update_data_source(self._data_source)
            return True
        return False

    def _fsm_finished_action(self, peer_info):
        return False

    def _fsm_failed_action(self, peer_info):
        if peer_info.state != common_pb.DataSourceState.Failed:
            self._peer_client.AbortDataSource(self._data_source_meta)
        return False

    def _fallback_failed_state(self, peer_info):
        state = self._data_source.state
        if (state in self.INVALID_PEER_FSM_STATE and
                peer_info.state in self.INVALID_PEER_FSM_STATE[state]):
            self._data_source.state = common_pb.DataSourceState.Failed
            self._update_data_source(self._data_source)
            return True
        return False

    def _update_data_source(self, data_source):
        self._data_source = None
        try:
            self._etcd.set_data(self._master_etcd_key,
                                text_format.MessageToString(data_source))
        except Exception as e:
            logging.error("Failed to update data source: %s since "
                          "exception: %s", self._data_source_name, e)
            raise
        self._data_source = data_source
        logging.debug("Successfully updated data source: %s.",
                      self._data_source_name)

    def _get_peer_data_source_state(self):
        return self._peer_client.GetDataSourceState(self._data_source_meta)

    def _all_partition_finished(self):
        all_manifest = self._raw_data_manifest_manager.list_all_manifest()
        assert (len(all_manifest) ==
                self._data_source.data_source_meta.partition_num)
        for manifest in all_manifest.values():
            if manifest.state != dj_pb.RawDataState.Done:
                return False
        return True
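# A hedged illustration (not from the source) of the handshake the
# _fsm_*_action methods above encode: the leader advances first, while its
# peer is still level with it; the follower advances only after it sees the
# leader one state ahead. The extra _all_partition_finished() guard on the
# Processing -> Ready step is elided here, and states are plain ints.
INIT, PROCESSING, READY, FINISHED = range(4)
NEXT_STATE = {INIT: PROCESSING, PROCESSING: READY, READY: FINISHED}

def should_advance(is_leader, my_state, peer_state):
    if my_state not in NEXT_STATE:
        return False  # Finished is terminal
    if is_leader:
        return peer_state == my_state          # peers level: leader moves
    return peer_state == NEXT_STATE[my_state]  # leader ahead: follower moves

# The leader moves to Processing while both sides are at Init; the follower
# catches up once it observes the leader at Processing.
assert should_advance(True, INIT, INIT)
assert not should_advance(False, INIT, INIT)
assert should_advance(False, INIT, PROCESSING)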
class ExampleJoinFollower(object):
    def __init__(self, etcd, data_source):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._data_source = data_source
        self._data_block_dump_manager = None
        self._data_block_dump_worker = None
        self._started = False

    def start_create_data_block(self, partition_id):
        with self._lock:
            dump_manager = self._data_block_dump_manager
            if (dump_manager is not None and
                    dump_manager.get_partition_id() != partition_id):
                raise RuntimeError("partition {} is not finished".format(
                    dump_manager.get_partition_id()))
            if dump_manager is None:
                self._data_block_dump_manager = DataBlockDumperManager(
                    self._etcd, self._data_source, partition_id)
                dump_manager = self._data_block_dump_manager
            next_index = dump_manager.get_next_data_block_index()
            return next_index

    def add_synced_data_block_meta(self, meta):
        with self._lock:
            self._check_status(meta.partition_id)
            manager = self._data_block_dump_manager
            return manager.append_synced_data_block_meta(meta)

    def finish_sync_data_block_meta(self, partition_id):
        with self._lock:
            self._check_status(partition_id)
            self._data_block_dump_manager.finish_sync_data_block_meta()
            return not self._data_block_dump_manager.need_dump()

    def get_processing_partition_id(self):
        with self._lock:
            if self._data_block_dump_manager is None:
                return None
            return self._data_block_dump_manager.get_partition_id()

    def reset_dump_partition(self):
        with self._lock:
            if self._data_block_dump_manager is None:
                return
            dump_manager = self._data_block_dump_manager
            partition_id = dump_manager.get_partition_id()
            self._check_status(partition_id)
            if (not dump_manager.data_block_meta_sync_finished() or
                    dump_manager.need_dump()):
                raise RuntimeError(
                    "partition {} is dumping".format(partition_id))
            self._data_block_dump_manager = None

    def start_dump_worker(self):
        with self._lock:
            if not self._started:
                assert self._data_block_dump_worker is None
                self._data_block_dump_worker = RoutineWorker(
                    'data_block_dumper',
                    self._dump_data_block_fn,
                    self._dump_data_block_cond, 1)
                self._data_block_dump_worker.start_routine()
                self._started = True

    def stop_dump_worker(self):
        dumper = None
        with self._lock:
            if self._data_block_dump_worker is not None:
                dumper = self._data_block_dump_worker
                self._data_block_dump_worker = None
        if dumper is not None:
            dumper.stop_routine()

    def _check_status(self, partition_id):
        if self._data_block_dump_manager is None:
            raise RuntimeError("no partition is processing")
        ptn_id = self._data_block_dump_manager.get_partition_id()
        if partition_id != ptn_id:
            raise RuntimeError("partition id mismatch {} != {}".format(
                partition_id, ptn_id))

    def _dump_data_block_fn(self):
        dump_manager = self._data_block_dump_manager
        assert dump_manager is not None
        if dump_manager.need_dump():
            dump_manager.dump_data_blocks()

    def _dump_data_block_cond(self):
        with self._lock:
            return (self._data_block_dump_manager is not None and
                    self._data_block_dump_manager.need_dump())
class TransmitFollower(object):
    class ImplContext(object):
        def __init__(self, partition_id):
            self.partition_id = partition_id

        def get_next_index(self):
            raise NotImplementedError("get_next_index is not implemented "
                                      "in base ImplContext")

        def get_dumped_index(self):
            raise NotImplementedError("get_dumped_index is not implemented "
                                      "in base ImplContext")

        def add_synced_content(self, sync_ctnt):
            raise NotImplementedError("add_synced_content is not implemented "
                                      "in base ImplContext")

        def finish_sync_content(self):
            raise NotImplementedError("finish_sync_content is not "
                                      "implemented in base ImplContext")

        def need_dump(self):
            raise NotImplementedError("need_dump is not implemented "
                                      "in base ImplContext")

        def make_dumper(self):
            raise NotImplementedError("make_dumper is not implemented "
                                      "in base ImplContext")

        def is_sync_content_finished(self):
            raise NotImplementedError("is_sync_content_finished is not "
                                      "implemented in base ImplContext")

    def __init__(self, etcd, data_source, repr_str):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._data_source = data_source
        self._repr_str = repr_str
        self._dump_worker = None
        self._impl_ctx = None
        self._started = False

    @metrics.timer(func_name='start_sync_partition',
                   tags={'role': 'transmit_follower'})
    def start_sync_partition(self, partition_id):
        with self._lock:
            if self._impl_ctx is not None and \
                    self._impl_ctx.partition_id != partition_id:
                raise RuntimeError("{} is processing partition {}".format(
                    self._repr_str, self._impl_ctx.partition_id))
            if self._impl_ctx is None:
                self._impl_ctx = self._make_new_impl_ctx(partition_id)
            return self._impl_ctx.get_next_index(), \
                    self._impl_ctx.get_dumped_index()

    @metrics.timer(func_name='add_synced_item',
                   tags={'role': 'transmit_follower'})
    def add_synced_item(self, sync_ctnt):
        with self._lock:
            partition_id = \
                self._extract_partition_id_from_sync_content(sync_ctnt)
            self._check_status(partition_id)
            filled, next_index = self._impl_ctx.add_synced_content(sync_ctnt)
            if filled:
                self._dump_worker.wakeup()
            return filled, next_index, self._impl_ctx.get_dumped_index()

    @metrics.timer(func_name='finish_sync_partition',
                   tags={'role': 'transmit_follower'})
    def finish_sync_partition(self, partition_id):
        with self._lock:
            self._check_status(partition_id)
            self._impl_ctx.finish_sync_content()
            return not self._impl_ctx.need_dump(), \
                    self._impl_ctx.get_dumped_index()

    @metrics.timer(func_name='reset_partition',
                   tags={'role': 'transmit_follower'})
    def reset_partition(self, partition_id):
        with self._lock:
            if not self._check_status(partition_id, False):
                return
            if not self._impl_ctx.is_sync_content_finished() or \
                    self._impl_ctx.need_dump():
                raise RuntimeError("{} is still dumping for partition {}"
                                   .format(self._repr_str, partition_id))
            self._impl_ctx = None

    def get_processing_partition_id(self):
        with self._lock:
            if self._impl_ctx is None:
                return None
            return self._impl_ctx.partition_id

    def start_dump_worker(self):
        with self._lock:
            if not self._started:
                assert self._dump_worker is None, \
                    "dump worker for {} should be None if not started".format(
                        self._repr_str)
                self._dump_worker = RoutineWorker(
                    self._repr_str + '-dump_worker',
                    self._dump_fn, self._dump_cond, 1)
                self._dump_worker.start_routine()
                self._started = True

    def stop_dump_worker(self):
        dumper_worker = None
        with self._lock:
            dumper_worker = self._dump_worker
            self._dump_worker = None
        if dumper_worker is not None:
            dumper_worker.stop_routine()

    def _check_status(self, partition_id, raise_exception=True):
        if self._impl_ctx is None:
            if not raise_exception:
                return False
            raise RuntimeError("no partition is processing")
        if self._impl_ctx.partition_id != partition_id:
            if not raise_exception:
                return False
            raise RuntimeError("partition id mismatch {} != {} for {}".format(
                self._impl_ctx.partition_id, partition_id, self._repr_str))
        return True

    def _dump_fn(self, impl_ctx):
        with impl_ctx.make_dumper() as dumper:
            dumper()

    def _dump_cond(self):
        with self._lock:
            if self._impl_ctx is not None and self._impl_ctx.need_dump():
                self._dump_worker.setup_args(self._impl_ctx)
                return True
            return False

    def _make_new_impl_ctx(self, partition_id):
        raise NotImplementedError("_make_new_impl_ctx is not implemented "
                                  "in base TransmitFollower")

    def _extract_partition_id_from_sync_content(self, sync_content):
        raise NotImplementedError("_extract_partition_id_from_sync_content "
                                  "is not implemented in base "
                                  "TransmitFollower")
class RawDataPartitioner(object):
    class OutputFileWriter(object):
        def __init__(self, options, partition_id):
            self._options = options
            self._partition_id = partition_id
            self._process_index = 0
            self._writer = None
            self._dumped_item = 0
            self._output_fpaths = []
            self._output_dir = os.path.join(
                self._options.output_dir,
                common.partition_repr(self._partition_id))
            if not gfile.Exists(self._output_dir):
                gfile.MakeDirs(self._output_dir)
            assert gfile.IsDirectory(self._output_dir)

        def append_item(self, index, item):
            writer = self._get_output_writer()
            if self._options.output_builder == 'TF_RECORD':
                writer.write(item.tf_record)
            else:
                assert self._options.output_builder == 'CSV_DICT'
                writer.write(item.csv_record)
            self._dumped_item += 1
            if self._dumped_item >= self._options.output_item_threshold:
                self._finish_writer()
                if self._process_index % 16 == 0:
                    logging.info("Output partition %d dump %d files, "
                                 "last index %d", self._partition_id,
                                 self._process_index, index)

        def finish(self):
            self._finish_writer()

        def get_output_files(self):
            return self._output_fpaths

        def _get_output_writer(self):
            if self._writer is None:
                self._new_writer()
            return self._writer

        def _new_writer(self):
            assert self._writer is None
            fname = "{:04}-{:08}.rd".format(
                self._options.partitioner_rank_id,
                self._process_index)
            fpath = os.path.join(self._output_dir, fname)
            self._output_fpaths.append(fpath)
            if self._options.output_builder == 'TF_RECORD':
                self._writer = tf.io.TFRecordWriter(fpath)
            else:
                assert self._options.output_builder == 'CSV_DICT'
                self._writer = CsvDictWriter(fpath)
            self._dumped_item = 0

        def _finish_writer(self):
            if self._writer is not None:
                self._writer.close()
                self._writer = None
                self._dumped_item = 0
                self._process_index += 1

    def __init__(self, options):
        self._options = options
        self._raw_data_batch_fetcher = RawDataBatchFetcher(options)
        self._fetch_worker = RoutineWorker('raw_data_batch_fetcher',
                                           self._raw_data_batch_fetch_fn,
                                           self._raw_data_batch_fetch_cond,
                                           5)
        self._next_part_index = 0
        self._cond = threading.Condition()
        self._fetch_worker.start_routine()

    def partition(self):
        if self._check_finished_tag():
            logging.warning("partition has finished for rank id of "
                            "partitioner %d",
                            self._options.partitioner_rank_id)
            return
        next_index = 0
        hint_index = 0
        fetch_finished = False
        fetcher = self._raw_data_batch_fetcher
        writers = [RawDataPartitioner.OutputFileWriter(self._options, pid)
                   for pid in range(self._options.output_partition_num)]
        iter_round = 0
        bp_options = self._options.batch_processor_options
        signal_round_threhold = bp_options.max_flying_item / \
                bp_options.batch_size // 8
        while not fetch_finished:
            fetch_finished, batch, hint_index = \
                fetcher.fetch_item_batch_by_index(next_index, hint_index)
            iter_round += 1
            if batch is not None:
                for index, item in enumerate(batch):
                    raw_id = item.raw_id
                    partition_id = CityHash32(raw_id) % \
                            self._options.output_partition_num
                    writer = writers[partition_id]
                    writer.append_item(batch.begin_index + index, item)
                next_index = batch.begin_index + len(batch)
                if iter_round % signal_round_threhold == 0:
                    hint_index = self._evict_staless_batch(hint_index,
                                                           next_index - 1)
                    logging.info("consumed %d items", next_index - 1)
                    self._set_next_part_index(next_index)
                    self._wakeup_raw_data_fetcher()
            elif not fetch_finished:
                hint_index = self._evict_staless_batch(hint_index,
                                                       next_index - 1)
                with self._cond:
                    self._cond.wait(1)
        for partition_id, writer in enumerate(writers):
            writer.finish()
            fpaths = writer.get_output_files()
            logging.info("part %d output %d files by partitioner",
                         partition_id, len(fpaths))
            for fpath in fpaths:
                logging.info("%s", fpath)
            logging.info("-----------------------------------")
        self._dump_finished_tag()
        self._fetch_worker.stop_routine()

    def _evict_staless_batch(self, hint_index, staless_index):
        evict_cnt = self._raw_data_batch_fetcher.evict_staless_item_batch(
            staless_index)
        if hint_index <= evict_cnt:
            return 0
        return hint_index - evict_cnt

    def _set_next_part_index(self, next_part_index):
        with self._cond:
            self._next_part_index = next_part_index

    def _get_next_part_index(self):
        with self._cond:
            return self._next_part_index

    def _raw_data_batch_fetch_fn(self):
        next_part_index = self._get_next_part_index()
        fetcher = self._raw_data_batch_fetcher
        for batch in fetcher.make_processor(next_part_index):
            logging.debug("fetch batch begin at %d, len %d. wakeup "
                          "partitioner", batch.begin_index, len(batch))
            self._wakeup_partitioner()

    def _raw_data_batch_fetch_cond(self):
        next_part_index = self._get_next_part_index()
        return self._raw_data_batch_fetcher.need_process(next_part_index)

    def _wakeup_partitioner(self):
        with self._cond:
            self._cond.notify_all()

    def _wakeup_raw_data_fetcher(self):
        self._fetch_worker.wakeup()

    def _dump_finished_tag(self):
        finished_tag_fpath = self._get_finished_tag_fpath()
        with gfile.GFile(finished_tag_fpath, 'w') as fh:
            fh.write('')

    def _check_finished_tag(self):
        return gfile.Exists(self._get_finished_tag_fpath())

    def _get_finished_tag_fpath(self):
        return os.path.join(
            self._options.output_dir,
            '_SUCCESS.{:08}'.format(self._options.partitioner_rank_id))
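# A hedged sketch of the routing rule used in partition() above: every item
# goes to the writer whose index is CityHash32(raw_id) mod the output
# partition count. CityHash32 is assumed to come from the cityhash package,
# matching the unqualified call in the snippet.
from cityhash import CityHash32

def route(raw_id, output_partition_num):
    return CityHash32(raw_id) % output_partition_num

# Items with equal raw_id always land in the same partition, so downstream
# joiners can process each partition independently.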
class MasterFSM(object):
    INVALID_PEER_FSM_STATE = {}
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Init] = set([
        common_pb.DataSourceState.Failed,
        common_pb.DataSourceState.Ready,
        common_pb.DataSourceState.Finished
    ])
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Processing] = set([
        common_pb.DataSourceState.Failed,
        common_pb.DataSourceState.Finished
    ])
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Ready] = set([
        common_pb.DataSourceState.Failed,
        common_pb.DataSourceState.Init
    ])
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Finished] = set([
        common_pb.DataSourceState.Failed,
        common_pb.DataSourceState.Init,
        common_pb.DataSourceState.Processing
    ])

    def __init__(self, peer_client, data_source_name, etcd):
        self._lock = threading.Lock()
        self._peer_client = peer_client
        self._data_source_name = data_source_name
        self._etcd = etcd
        self._init_fsm_action()
        self._data_source = None
        self._sync_data_source()
        assert self._data_source is not None, \
            "data source must not be None if sync data source succeeds"
        self._raw_data_manifest_manager = RawDataManifestManager(
            etcd, self._data_source)
        self._data_source_meta = self._data_source.data_source_meta
        if self._data_source.role == common_pb.FLRole.Leader:
            self._role_repr = "leader"
        else:
            self._role_repr = "follower"
        self._fsm_worker = None
        self._started = False

    def get_mainifest_manager(self):
        return self._raw_data_manifest_manager

    def get_data_source(self):
        with self._lock:
            self._sync_data_source()
            return self._data_source

    def set_failed(self):
        return self.set_state(common_pb.DataSourceState.Failed, None)

    def set_state(self, new_state, origin_state=None):
        with self._lock:
            try:
                self._sync_data_source()
                if self._data_source.state == new_state:
                    return True
                if (origin_state is None or
                        self._data_source.state == origin_state):
                    self._data_source.state = new_state
                    self._update_data_source(self._data_source)
                    return True
                logging.warning("DataSource: %s failed to set to state: "
                                "%d, origin state mismatch(%d != %d)",
                                self._data_source_name, new_state,
                                origin_state, self._data_source.state)
                return False
            except Exception as e:  # pylint: disable=broad-except
                logging.warning("Failed to set state to %d with exception %s",
                                new_state, e)
                return False

    def start_fsm_worker(self):
        with self._lock:
            if not self._started:
                assert self._fsm_worker is None, \
                    "fsm_worker must be None if FSM is not started"
                self._started = True
                self._fsm_worker = RoutineWorker(
                    '{}_fsm_worker'.format(self._data_source_name),
                    self._fsm_routine_fn,
                    self._fsm_routine_cond, 5)
                self._fsm_worker.start_routine()

    def stop_fsm_worker(self):
        tmp_worker = None
        with self._lock:
            if self._fsm_worker is not None:
                tmp_worker = self._fsm_worker
                self._fsm_worker = None
        if tmp_worker is not None:
            tmp_worker.stop_routine()

    def _fsm_routine_fn(self):
        peer_info = self._get_peer_data_source_status()
        with self._lock:
            self._sync_data_source()
            state = self._data_source.state
            if self._fallback_failed_state(peer_info):
                logging.warning("%s at state %d, Peer at state %d. "
                                "State invalid! abort data source %s",
                                self._role_repr, state, peer_info.state,
                                self._data_source_name)
            elif state not in self._fsm_driven_handle:
                logging.error("%s at error state %d for data_source %s",
                              self._role_repr, state, self._data_source_name)
            else:
                state_changed = self._fsm_driven_handle[state](peer_info)
                if state_changed:
                    self._sync_data_source()
                    new_state = self._data_source.state
                    logging.warning("%s state changed from %d to %d",
                                    self._role_repr, state, new_state)
                    state = new_state
            if state in (common_pb.DataSourceState.Init,
                         common_pb.DataSourceState.Processing):
                self._raw_data_manifest_manager.sub_new_raw_data()

    def _fsm_routine_cond(self):
        return True

    def _sync_data_source(self):
        if self._data_source is None:
            self._data_source = \
                retrieve_data_source(self._etcd, self._data_source_name)

    def _init_fsm_action(self):
        self._fsm_driven_handle = {
            common_pb.DataSourceState.UnKnown:
                self._get_fsm_action('unknown'),
            common_pb.DataSourceState.Init:
                self._get_fsm_action('init'),
            common_pb.DataSourceState.Processing:
                self._get_fsm_action('processing'),
            common_pb.DataSourceState.Ready:
                self._get_fsm_action('ready'),
            common_pb.DataSourceState.Finished:
                self._get_fsm_action('finished'),
            common_pb.DataSourceState.Failed:
                self._get_fsm_action('failed')
        }

    def _get_fsm_action(self, action):
        def _not_implement(useless):
            raise NotImplementedError(
                'action for this state is not implemented')
        name = '_fsm_{}_action'.format(action)
        return getattr(self, name, _not_implement)

    def _fsm_init_action(self, peer_info):
        state_changed = False
        if self._data_source.role == common_pb.FLRole.Leader:
            if peer_info.state == common_pb.DataSourceState.Init:
                state_changed = True
        elif peer_info.state == common_pb.DataSourceState.Processing:
            state_changed = True
        if state_changed:
            self._data_source.state = common_pb.DataSourceState.Processing
            self._update_data_source(self._data_source)
            return True
        return False

    def _fsm_processing_action(self, peer_info):
        if self._all_partition_finished():
            state_changed = False
            if self._data_source.role == common_pb.FLRole.Leader:
                if peer_info.state == common_pb.DataSourceState.Processing:
                    state_changed = True
            elif peer_info.state == common_pb.DataSourceState.Ready:
                state_changed = True
            if state_changed:
                self._data_source.state = common_pb.DataSourceState.Ready
                self._update_data_source(self._data_source)
                return True
        return False

    def _fsm_ready_action(self, peer_info):
        state_changed = False
        if self._data_source.role == common_pb.FLRole.Leader:
            if peer_info.state == common_pb.DataSourceState.Ready:
                state_changed = True
        elif peer_info.state == common_pb.DataSourceState.Finished:
            state_changed = True
        if state_changed:
            self._data_source.state = common_pb.DataSourceState.Finished
            self._update_data_source(self._data_source)
            return True
        return False

    def _fsm_finished_action(self, peer_info):
        return False

    def _fsm_failed_action(self, peer_info):
        if peer_info.state != common_pb.DataSourceState.Failed:
            request = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_meta)
            self._peer_client.AbortDataSource(request)
        return False

    def _fallback_failed_state(self, peer_info):
        state = self._data_source.state
        if (state in self.INVALID_PEER_FSM_STATE and
                peer_info.state in self.INVALID_PEER_FSM_STATE[state]):
            self._data_source.state = common_pb.DataSourceState.Failed
            self._update_data_source(self._data_source)
            return True
        return False

    def _update_data_source(self, data_source):
        self._data_source = None
        try:
            commit_data_source(self._etcd, data_source)
        except Exception as e:
            logging.error("Failed to update data source: %s since "
                          "exception: %s", self._data_source_name, e)
            raise
        self._data_source = data_source
        logging.debug("Successfully updated data source: %s.",
                      self._data_source_name)

    def _get_peer_data_source_status(self):
        request = dj_pb.DataSourceRequest(
            data_source_meta=self._data_source_meta)
        return self._peer_client.GetDataSourceStatus(request)

    def _all_partition_finished(self):
        all_manifest = self._raw_data_manifest_manager.list_all_manifest()
        assert len(all_manifest) == \
                self._data_source.data_source_meta.partition_num, \
            "manifest number should be the same as partition number"
        for manifest in all_manifest.values():
            if manifest.sync_example_id_rep.state != \
                    dj_pb.SyncExampleIdState.Synced or \
                    manifest.join_example_rep.state != \
                    dj_pb.JoinExampleState.Joined:
                return False
        return True
class ExampleIdSyncFollower(object):
    def __init__(self, data_source):
        self._lock = threading.Lock()
        self._data_source = data_source
        self._example_id_dump_manager = None
        self._example_id_dump_worker = None
        self._started = False

    def start_dump_partition(self, partition_id):
        with self._lock:
            if (self._example_id_dump_manager is not None and
                    self._example_id_dump_manager.get_partition_id() !=
                    partition_id):
                ptn_id = self._example_id_dump_manager.get_partition_id()
                raise RuntimeError(
                    "partition {} is not finished".format(ptn_id))
            if self._example_id_dump_manager is None:
                self._example_id_dump_manager = ExampleIdDumperManager(
                    self._data_source, partition_id)
            dump_manager = self._example_id_dump_manager
            next_index = dump_manager.get_next_index()
            return next_index

    def add_synced_example_req(self, synced_example_req):
        with self._lock:
            self._check_status(synced_example_req.partition_id)
            return self._example_id_dump_manager.append_synced_example_req(
                synced_example_req)

    def finish_sync_partition_example(self, partition_id):
        with self._lock:
            self._check_status(partition_id)
            self._example_id_dump_manager.finish_sync_example()
            return not self._example_id_dump_manager.need_dump()

    def reset_dump_partition(self):
        with self._lock:
            if self._example_id_dump_manager is None:
                return
            partition_id = self._example_id_dump_manager.get_partition_id()
            self._check_status(partition_id)
            if (not self._example_id_dump_manager.example_sync_finished() or
                    self._example_id_dump_manager.need_dump()):
                raise RuntimeError(
                    "partition {} is dumping".format(partition_id))
            self._example_id_dump_manager = None

    def get_processing_partition_id(self):
        with self._lock:
            if self._example_id_dump_manager is None:
                return None
            return self._example_id_dump_manager.get_partition_id()

    def start_dump_worker(self):
        with self._lock:
            if not self._started:
                assert self._example_id_dump_worker is None
                self._example_id_dump_worker = RoutineWorker(
                    'example_id_dumper',
                    self._dump_example_ids_fn,
                    self._dump_example_ids_cond, 1)
                self._example_id_dump_worker.start_routine()
                self._started = True

    def stop_dump_worker(self):
        dumper = None
        with self._lock:
            if self._example_id_dump_worker is not None:
                dumper = self._example_id_dump_worker
                self._example_id_dump_worker = None
        if dumper is not None:
            dumper.stop_routine()

    def _check_status(self, partition_id):
        if self._example_id_dump_manager is None:
            raise RuntimeError("no partition is processing")
        ptn_id = self._example_id_dump_manager.get_partition_id()
        if partition_id != ptn_id:
            raise RuntimeError("partition id mismatch {} != {}".format(
                partition_id, ptn_id))

    def _dump_example_ids_fn(self):
        dump_manager = None
        with self._lock:
            dump_manager = self._example_id_dump_manager
        assert dump_manager is not None
        if dump_manager.need_dump():
            dump_manager.dump_example_ids()

    def _dump_example_ids_cond(self):
        with self._lock:
            return (self._example_id_dump_manager is not None and
                    self._example_id_dump_manager.need_dump())
class PortalRawDataNotifier(object):
    class NotifyCtx(object):
        def __init__(self, master_addr):
            self._master_addr = master_addr
            channel = make_insecure_channel(master_addr,
                                            ChannelType.INTERNAL)
            self._master_cli = dj_grpc.DataJoinMasterServiceStub(channel)
            self._data_source = None
            self._raw_date_ctl = None
            self._raw_data_updated_datetime = {}

        @property
        def data_source(self):
            if self._data_source is None:
                self._data_source = \
                    self._master_cli.GetDataSource(empty_pb2.Empty())
            return self._data_source

        def get_raw_data_updated_datetime(self, partition_id):
            if partition_id not in self._raw_data_updated_datetime:
                ts = self._raw_date_controller.get_raw_data_latest_timestamp(
                    partition_id)
                if ts.seconds > 3600:
                    ts.seconds -= 3600
                else:
                    ts.seconds = 0
                self._raw_data_updated_datetime[partition_id] = \
                    common.convert_timestamp_to_datetime(
                        common.trim_timestamp_by_hourly(ts))
            return self._raw_data_updated_datetime[partition_id]

        def add_raw_data(self, partition_id, fpaths, timestamps, end_ts):
            assert len(fpaths) == len(timestamps), \
                "the numbers of raw data paths and timestamps should be equal"
            if len(fpaths) > 0:
                self._raw_date_controller.add_raw_data(partition_id, fpaths,
                                                       True, timestamps)
            self._raw_data_updated_datetime[partition_id] = end_ts

        @property
        def data_source_master_addr(self):
            return self._master_addr

        @property
        def _raw_date_controller(self):
            if self._raw_date_ctl is None:
                self._raw_date_ctl = RawDataController(self.data_source,
                                                       self._master_cli)
            return self._raw_date_ctl

    def __init__(self, etcd, portal_name, downstream_data_source_masters):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._portal_name = portal_name
        assert len(downstream_data_source_masters) > 0, \
            "PortalRawDataNotifier needs at least one data source master " \
            "to notify"
        self._master_notify_ctx = {}
        for addr in downstream_data_source_masters:
            self._master_notify_ctx[addr] = \
                PortalRawDataNotifier.NotifyCtx(addr)
        self._notify_worker = None
        self._started = False

    def start_notify_worker(self):
        with self._lock:
            if not self._started:
                assert self._notify_worker is None, \
                    "notify worker should be None if not started"
                self._notify_worker = RoutineWorker(
                    'portal-raw_data-notifier',
                    self._raw_data_notify_fn,
                    self._raw_data_notify_cond, 5)
                self._notify_worker.start_routine()
                self._started = True
                self._notify_worker.wakeup()

    def stop_notify_worker(self):
        notify_worker = None
        with self._lock:
            notify_worker = self._notify_worker
            self._notify_worker = None
        if notify_worker is not None:
            notify_worker.stop_routine()

    def _check_partition_num(self, notify_ctx, portal_manifest):
        assert isinstance(notify_ctx, PortalRawDataNotifier.NotifyCtx)
        data_source = notify_ctx.data_source
        ds_partition_num = data_source.data_source_meta.partition_num
        if portal_manifest.output_partition_num % ds_partition_num != 0:
            raise ValueError(
                "the output partition number({}) of the portal should be "
                "divisible by the partition number({}) of downstream "
                "data source {}".format(
                    portal_manifest.output_partition_num,
                    ds_partition_num,
                    data_source.data_source_meta.name))

    def _add_raw_data_impl(self, notify_ctx, portal_manifest, ds_pid):
        dt = notify_ctx.get_raw_data_updated_datetime(ds_pid) + \
                timedelta(hours=1)
        begin_dt = common.convert_timestamp_to_datetime(
            common.trim_timestamp_by_hourly(portal_manifest.begin_timestamp))
        if dt < begin_dt:
            dt = begin_dt
        committed_dt = common.convert_timestamp_to_datetime(
            portal_manifest.committed_timestamp)
        fpaths = []
        timestamps = []
        ds_ptnum = notify_ctx.data_source.data_source_meta.partition_num
        while dt <= committed_dt:
            for pt_pid in range(ds_pid, portal_manifest.output_partition_num,
                                ds_ptnum):
                fpath = common.encode_portal_hourly_fpath(
                    portal_manifest.output_data_base_dir, dt, pt_pid)
                if gfile.Exists(fpath):
                    fpaths.append(fpath)
                    timestamps.append(
                        common.convert_datetime_to_timestamp(dt))
            if len(fpaths) > 32 or dt == committed_dt:
                break
            dt += timedelta(hours=1)
        notify_ctx.add_raw_data(ds_pid, fpaths, timestamps, dt)
        logging.info("add %d raw data files for partition %d of data "
                     "source %s. latest updated datetime %s",
                     len(fpaths), ds_pid,
                     notify_ctx.data_source.data_source_meta.name, dt)
        return dt >= committed_dt

    def _notify_one_data_source(self, notify_ctx, portal_manifest):
        assert isinstance(notify_ctx, PortalRawDataNotifier.NotifyCtx)
        try:
            self._check_partition_num(notify_ctx, portal_manifest)
            ds_ptnum = notify_ctx.data_source.data_source_meta.partition_num
            add_finished = False
            while not add_finished:
                add_finished = True
                for ds_pid in range(ds_ptnum):
                    if not self._add_raw_data_impl(notify_ctx,
                                                   portal_manifest, ds_pid):
                        add_finished = False
        except Exception as e:  # pylint: disable=broad-except
            logging.error("Failed to notify data source[master-addr: %s] "
                          "new raw data added, reason %s",
                          notify_ctx.data_source_master_addr, e)

    def _raw_data_notify_fn(self):
        portal_manifest = common.retrieve_portal_manifest(
            self._etcd, self._portal_name)
        for _, notify_ctx in self._master_notify_ctx.items():
            self._notify_one_data_source(notify_ctx, portal_manifest)

    def _raw_data_notify_cond(self):
        return True
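# A hedged sketch of the hourly scan in _add_raw_data_impl above: starting one
# hour after the last notified datetime, walk hour by hour up to the committed
# datetime, and stop early once the batch of discovered paths is large enough
# (32 here, matching the snippet). encode_fpath and exists stand in for
# common.encode_portal_hourly_fpath and gfile.Exists.
from datetime import timedelta

def collect_hourly_batch(last_updated, committed, encode_fpath, exists):
    dt = last_updated + timedelta(hours=1)
    fpaths = []
    while dt <= committed:
        fpath = encode_fpath(dt)
        if exists(fpath):
            fpaths.append(fpath)
        if len(fpaths) > 32 or dt == committed:
            break
        dt += timedelta(hours=1)
    # dt becomes the new "updated" watermark; the caller is done with this
    # partition once dt has reached the committed datetime
    return fpaths, dt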