class MasterFSM(object):
    """State machine that keeps the local data source in step with its peer.

    The local data source advances Init -> Processing -> Ready -> Finished.
    A RoutineWorker repeatedly calls ``_fsm_routine_fn``, which fetches the
    peer's data source status and, depending on whether this side is the
    leader or the follower, decides whether to advance. If the peer reports
    a state listed as incompatible in ``INVALID_PEER_FSM_STATE``, the local
    data source is moved to Failed instead.

    The public accessors and the FSM routine take ``self._lock`` before
    touching the cached data source.
    """

    # Local state -> peer states that are incompatible with it. Observing one
    # of these peer states makes the FSM fall back to Failed.
    INVALID_PEER_FSM_STATE = {}
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Init] = set(
        [common_pb.DataSourceState.Failed,
         common_pb.DataSourceState.Ready,
         common_pb.DataSourceState.Finished]
    )
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Processing] = set(
        [common_pb.DataSourceState.Failed,
         common_pb.DataSourceState.Finished]
    )
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Ready] = set(
        [common_pb.DataSourceState.Failed,
         common_pb.DataSourceState.Init]
    )
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Finished] = set(
        [common_pb.DataSourceState.Failed,
         common_pb.DataSourceState.Init,
         common_pb.DataSourceState.Processing]
    )

    def __init__(self, peer_client, data_source_name, etcd):
        """Load the data source and prepare (but do not start) the FSM.

        Args:
            peer_client: client for the peer master; this class calls its
                ``GetDataSourceStatus`` and ``AbortDataSource`` methods.
            data_source_name: name used to look up the data source.
            etcd: store handle passed to ``retrieve_data_source``,
                ``commit_data_source`` and ``RawDataManifestManager``.
        """
        self._lock = threading.Lock()
        self._peer_client = peer_client
        self._data_source_name = data_source_name
        self._etcd = etcd
        self._init_fsm_action()
        self._data_source = None
        self._sync_data_source()
        assert self._data_source is not None, \
            "data source must not None if sync data source success"
        self._raw_data_manifest_manager = RawDataManifestManager(
            etcd, self._data_source
        )
        self._data_source_meta = self._data_source.data_source_meta
        if self._data_source.role == common_pb.FLRole.Leader:
            self._role_repr = "leader"
        else:
            self._role_repr = "follower"
        self._fsm_worker = None
        self._started = False

    def get_mainifest_manager(self):
        """Return the RawDataManifestManager for this data source."""
        # NOTE: the misspelled method name is kept unchanged for
        # compatibility with existing callers.
        return self._raw_data_manifest_manager

    def get_data_source(self):
        """Return the cached DataSource, re-reading it first if needed.

        The returned object is the live cached instance, not a copy.
        """
        with self._lock:
            self._sync_data_source()
            return self._data_source

    def set_failed(self):
        """Move the data source to Failed regardless of its current state."""
        return self.set_state(common_pb.DataSourceState.Failed, None)

    def set_state(self, new_state, origin_state=None):
        """Move the data source to ``new_state`` and persist it.

        Args:
            new_state: target ``common_pb.DataSourceState``.
            origin_state: if not None, the transition happens only when the
                current state equals it (a compare-and-set).

        Returns:
            True if the data source is (now) in ``new_state``; False on an
            origin-state mismatch or if syncing/committing raised (the
            exception is logged, not propagated).
        """
        with self._lock:
            try:
                self._sync_data_source()
                if self._data_source.state == new_state:
                    return True
                if (origin_state is None or
                        self._data_source.state == origin_state):
                    self._data_source.state = new_state
                    self._update_data_source(self._data_source)
                    return True
                logging.warning("DataSource: %s failed to set to state: "
                                "%d, origin state mismatch(%d != %d)",
                                self._data_source_name, new_state,
                                origin_state, self._data_source.state)
                return False
            except Exception as e: # pylint: disable=broad-except
                logging.warning("Faile to set state to %d with exception %s",
                                new_state, e)
                return False

    def start_fsm_worker(self):
        """Start the background FSM routine; later calls are no-ops.

        ``stop_fsm_worker`` does not reset ``_started``, so a stopped worker
        is not restarted by calling this again.
        """
        with self._lock:
            if not self._started:
                assert self._fsm_worker is None, \
                    "fsm_woker must be None if FSM is not started"
                self._started = True
                # NOTE(review): 5 is RoutineWorker's last positional
                # argument, presumably the routine interval -- confirm.
                self._fsm_worker = RoutineWorker(
                    '{}_fsm_worker'.format(self._data_source_name),
                    self._fsm_routine_fn,
                    self._fsm_routine_cond, 5
                )
                self._fsm_worker.start_routine()

    def stop_fsm_worker(self):
        """Detach the worker under the lock, then stop it outside the lock."""
        tmp_worker = None
        with self._lock:
            if self._fsm_worker is not None:
                tmp_worker = self._fsm_worker
                self._fsm_worker = None
        # Stopping happens outside the lock; the routine itself acquires
        # self._lock, so stopping while holding it could block.
        if tmp_worker is not None:
            tmp_worker.stop_routine()

    def _fsm_routine_fn(self):
        """Run one FSM tick.

        Query the peer first. If the peer's state is incompatible, fall back
        to Failed. Otherwise run the handler registered for the current local
        state. While the (post-tick) state is Init or Processing, the
        manifest manager's ``sub_new_raw_data`` is also called.
        """
        peer_info = self._get_peer_data_source_status()
        with self._lock:
            self._sync_data_source()
            state = self._data_source.state
            if self._fallback_failed_state(peer_info):
                logging.warning("%s at state %d, Peer at state %d "
                                "state invalid! abort data source %s",
                                self._role_repr, state,
                                peer_info.state, self._data_source_name)
                # The data source was just moved to Failed; refresh `state`
                # so the check below does not subscribe new raw data for a
                # data source that has already been aborted.
                state = self._data_source.state
            elif state not in self._fsm_driven_handle:
                logging.error("%s at error state %d for data_source %s",
                              self._role_repr, state, self._data_source_name)
            else:
                state_changed = self._fsm_driven_handle[state](peer_info)
                if state_changed:
                    self._sync_data_source()
                    new_state = self._data_source.state
                    logging.warning("%s state changed from %d to %d",
                                    self._role_repr, state, new_state)
                    state = new_state
            if state in (common_pb.DataSourceState.Init,
                         common_pb.DataSourceState.Processing):
                self._raw_data_manifest_manager.sub_new_raw_data()

    def _fsm_routine_cond(self):
        """Condition callback handed to RoutineWorker; always True."""
        return True

    def _sync_data_source(self):
        """Reload the data source via ``retrieve_data_source`` if uncached."""
        if self._data_source is None:
            self._data_source = \
                retrieve_data_source(self._etcd, self._data_source_name)

    def _init_fsm_action(self):
        """Build the state -> handler dispatch table."""
        self._fsm_driven_handle = {
            common_pb.DataSourceState.UnKnown:
                self._get_fsm_action('unknown'),
            common_pb.DataSourceState.Init:
                self._get_fsm_action('init'),
            common_pb.DataSourceState.Processing:
                self._get_fsm_action('processing'),
            common_pb.DataSourceState.Ready:
                self._get_fsm_action('ready'),
            common_pb.DataSourceState.Finished:
                self._get_fsm_action('finished'),
            common_pb.DataSourceState.Failed:
                self._get_fsm_action('failed')
        }

    def _get_fsm_action(self, action):
        """Return bound ``_fsm_<action>_action`` or a raising stub.

        No ``_fsm_unknown_action`` is defined here, so the UnKnown state maps
        to the stub, which raises NotImplementedError when invoked.
        """
        def _not_implement(useless):
            raise NotImplementedError('state is not NotImplemented')
        name = '_fsm_{}_action'.format(action)
        return getattr(self, name, _not_implement)

    def _fsm_init_action(self, peer_info):
        """Init -> Processing.

        The leader advances while the peer is still Init. The follower waits
        until the peer has reached Processing.
        """
        state_changed = False
        if self._data_source.role == common_pb.FLRole.Leader:
            if peer_info.state == common_pb.DataSourceState.Init:
                state_changed = True
        elif peer_info.state == common_pb.DataSourceState.Processing:
            state_changed = True
        if state_changed:
            self._data_source.state = common_pb.DataSourceState.Processing
            self._update_data_source(self._data_source)
            return True
        return False

    def _fsm_processing_action(self, peer_info):
        """Processing -> Ready, once every partition is synced and joined.

        The leader advances while the peer is Processing. The follower waits
        until the peer is Ready.
        """
        if self._all_partition_finished():
            state_changed = False
            if self._data_source.role == common_pb.FLRole.Leader:
                if peer_info.state == common_pb.DataSourceState.Processing:
                    state_changed = True
            elif peer_info.state == common_pb.DataSourceState.Ready:
                state_changed = True
            if state_changed:
                self._data_source.state = common_pb.DataSourceState.Ready
                self._update_data_source(self._data_source)
                return True
        return False

    def _fsm_ready_action(self, peer_info):
        """Ready -> Finished.

        The leader advances while the peer is Ready. The follower waits until
        the peer is Finished.
        """
        state_changed = False
        if self._data_source.role == common_pb.FLRole.Leader:
            if peer_info.state == common_pb.DataSourceState.Ready:
                state_changed = True
        elif peer_info.state == common_pb.DataSourceState.Finished:
            state_changed = True
        if state_changed:
            self._data_source.state = common_pb.DataSourceState.Finished
            self._update_data_source(self._data_source)
            return True
        return False

    def _fsm_finished_action(self, peer_info):
        """Finished is terminal; never changes state."""
        return False

    def _fsm_failed_action(self, peer_info):
        """Ask the peer to abort (each tick) until it also reports Failed."""
        if peer_info.state != common_pb.DataSourceState.Failed:
            request = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_meta
            )
            self._peer_client.AbortDataSource(request)
        return False

    def _fallback_failed_state(self, peer_info):
        """Move to Failed if the peer's state conflicts with ours.

        Returns:
            True if the local data source was moved to Failed.
        """
        state = self._data_source.state
        if (state in self.INVALID_PEER_FSM_STATE and
                peer_info.state in self.INVALID_PEER_FSM_STATE[state]):
            self._data_source.state = common_pb.DataSourceState.Failed
            self._update_data_source(self._data_source)
            return True
        return False

    def _update_data_source(self, data_source):
        """Persist ``data_source`` with ``commit_data_source`` and cache it.

        The cache is cleared first. If the commit raises (the exception is
        logged and re-raised), the next ``_sync_data_source`` re-reads the
        stored copy instead of trusting the uncommitted local object.
        """
        self._data_source = None
        try:
            commit_data_source(self._etcd, data_source)
        except Exception as e:
            logging.error("Failed to update data source: %s since "
                          "exception: %s", self._data_source_name, e)
            raise
        self._data_source = data_source
        logging.debug("Success update to update data source: %s.",
                      self._data_source_name)

    def _get_peer_data_source_status(self):
        """Fetch the peer's status for this data source via RPC."""
        request = dj_pb.DataSourceRequest(
            data_source_meta=self._data_source_meta
        )
        return self._peer_client.GetDataSourceStatus(request)

    def _all_partition_finished(self):
        """True iff every partition's manifest is Synced and Joined."""
        all_manifest = self._raw_data_manifest_manager.list_all_manifest()
        assert len(all_manifest) == \
            self._data_source.data_source_meta.partition_num, \
            "manifest number should same with partition number"
        for manifest in all_manifest.values():
            if manifest.sync_example_id_rep.state != \
                    dj_pb.SyncExampleIdState.Synced or \
                    manifest.join_example_rep.state != \
                    dj_pb.JoinExampleState.Joined:
                return False
        return True
class MasterFSM(object):
    """Keep the local data source's lifecycle in step with the peer master.

    The data source is stored in etcd, in protobuf text format, under the key
    ``os.path.join(data_source_name, 'master')``. A RoutineWorker repeatedly
    invokes ``_fsm_routine_fn``, which asks the peer for its data source
    state. Depending on whether this side is leader or follower, it may then
    advance the local state Init -> Processing -> Ready -> Finished. A peer
    state that ``INVALID_PEER_FSM_STATE`` marks as incompatible drives the
    local data source to Failed.
    """

    # For each local state, the peer states that cannot legally coexist with
    # it; seeing any of them forces the local data source to Failed.
    INVALID_PEER_FSM_STATE = {
        common_pb.DataSourceState.Init: {
            common_pb.DataSourceState.Failed,
            common_pb.DataSourceState.Ready,
            common_pb.DataSourceState.Finished,
        },
        common_pb.DataSourceState.Processing: {
            common_pb.DataSourceState.Failed,
            common_pb.DataSourceState.Finished,
        },
        common_pb.DataSourceState.Ready: {
            common_pb.DataSourceState.Failed,
            common_pb.DataSourceState.Init,
        },
        common_pb.DataSourceState.Finished: {
            common_pb.DataSourceState.Failed,
            common_pb.DataSourceState.Init,
            common_pb.DataSourceState.Processing,
        },
    }

    def __init__(self, peer_client, data_source_name, etcd):
        """Read the data source from etcd and set up the (stopped) FSM.

        Args:
            peer_client: peer master client; ``GetDataSourceState`` and
                ``AbortDataSource`` are called on it with the data source
                meta.
            data_source_name: data source name; also the etcd key prefix.
            etcd: client providing ``get_data``/``set_data``.

        Raises:
            ValueError: if the master key is absent from etcd.
        """
        self._lock = threading.Lock()
        self._peer_client = peer_client
        self._data_source_name = data_source_name
        self._master_etcd_key = os.path.join(data_source_name, 'master')
        self._etcd = etcd
        self._init_fsm_action()
        self._data_source = None
        self._sync_data_source()
        assert self._data_source is not None
        self._raw_data_manifest_manager = RawDataManifestManager(
            etcd, self._data_source)
        self._data_source_meta = self._data_source.data_source_meta
        is_leader = self._data_source.role == common_pb.FLRole.Leader
        self._role_repr = "leader" if is_leader else "follower"
        self._fsm_worker = None
        self._started = False

    def get_mainifest_manager(self):
        """Return the raw data manifest manager (name kept for callers)."""
        return self._raw_data_manifest_manager

    def get_data_source(self):
        """Return the live cached DataSource, reloading it if it was dropped."""
        with self._lock:
            self._sync_data_source()
            assert self._data_source is not None
            return self._data_source

    def set_failed(self):
        """Force the data source into Failed from whatever state it is in."""
        return self.set_state(common_pb.DataSourceState.Failed, None)

    def set_state(self, new_state, origin_state=None):
        """Persist ``new_state``, optionally only if currently ``origin_state``.

        Returns:
            True when the data source ends up in ``new_state``. False when
            ``origin_state`` does not match, or when any exception was raised
            (it is logged and swallowed).
        """
        with self._lock:
            try:
                self._sync_data_source()
                assert self._data_source is not None
                current_state = self._data_source.state
                if current_state == new_state:
                    return True
                if origin_state is not None and current_state != origin_state:
                    logging.warning(
                        "DataSource: %s failed to set to state: "
                        "%d, origin state mismatch(%d != %d)",
                        self._data_source_name, new_state, origin_state,
                        current_state)
                    return False
                self._data_source.state = new_state
                self._update_data_source(self._data_source)
                return True
            except Exception as e:  # pylint: disable=broad-except
                logging.warning("Faile to set state to %d with exception %s",
                                new_state, e)
                return False

    def start_fsm_worker(self):
        """Launch the FSM RoutineWorker once; subsequent calls do nothing."""
        with self._lock:
            if self._started:
                return
            assert self._fsm_worker is None
            self._started = True
            # NOTE(review): the trailing 5 is presumably the routine
            # interval expected by RoutineWorker -- confirm.
            self._fsm_worker = RoutineWorker(
                '{}_fsm_worker'.format(self._data_source_name),
                self._fsm_routine_fn, self._fsm_routine_cond, 5)
            self._fsm_worker.start_routine()

    def stop_fsm_worker(self):
        """Detach the worker under the lock and stop it after releasing it."""
        with self._lock:
            worker, self._fsm_worker = self._fsm_worker, None
        if worker is not None:
            worker.stop_routine()

    def _fsm_routine_fn(self):
        """One FSM tick: consult the peer, then fail over or advance."""
        peer_info = self._get_peer_data_source_state()
        if peer_info.status.code != 0:
            logging.error("Failed to get peer state, %s",
                          peer_info.status.error_message)
            return
        with self._lock:
            self._sync_data_source()
            assert self._data_source is not None
            state = self._data_source.state
            if self._fallback_failed_state(peer_info):
                logging.warning(
                    "Self(%s) at state %d, Peer at state %d "
                    "state invalid! abort data source %s",
                    self._role_repr, state, peer_info.state,
                    self._data_source_name)
                return
            handler = self._fsm_driven_handle.get(state)
            if handler is None:
                logging.error("Self(%s) at error state %d for data_source %s",
                              self._role_repr, state, self._data_source_name)
                return
            if not handler(peer_info):
                return
            self._sync_data_source()
            assert self._data_source is not None
            logging.warning("Self(%s) state changed from %d to %d",
                            self._role_repr, state, self._data_source.state)

    def _fsm_routine_cond(self):
        """Condition callback for RoutineWorker; unconditionally True."""
        return True

    def _sync_data_source(self):
        """Populate the cache from etcd when it is empty.

        Raises:
            ValueError: if the master key holds no data.
        """
        if self._data_source is not None:
            return
        raw_data = self._etcd.get_data(self._master_etcd_key)
        if raw_data is None:
            raise ValueError("etcd master key is None for {}".format(
                self._data_source_name))
        self._data_source = text_format.Parse(raw_data, common_pb.DataSource())

    def _init_fsm_action(self):
        """Map every DataSourceState to its ``_fsm_*_action`` handler."""
        state_to_action = (
            (common_pb.DataSourceState.UnKnown, 'unknown'),
            (common_pb.DataSourceState.Init, 'init'),
            (common_pb.DataSourceState.Processing, 'processing'),
            (common_pb.DataSourceState.Ready, 'ready'),
            (common_pb.DataSourceState.Finished, 'finished'),
            (common_pb.DataSourceState.Failed, 'failed'),
        )
        self._fsm_driven_handle = {
            state: self._get_fsm_action(action)
            for state, action in state_to_action
        }

    def _get_fsm_action(self, action):
        """Look up ``_fsm_<action>_action``; fall back to a raising stub."""
        def _not_implement(_peer_info):
            raise NotImplementedError('state is not NotImplemented')
        return getattr(self, '_fsm_{}_action'.format(action), _not_implement)

    def _peer_allows_advance(self, peer_info, leader_waits_for,
                             follower_waits_for):
        """Whether the peer's state permits this side to move forward.

        The leader requires the peer to be in ``leader_waits_for``; the
        follower requires ``follower_waits_for``.
        """
        if self._data_source.role == common_pb.FLRole.Leader:
            return peer_info.state == leader_waits_for
        return peer_info.state == follower_waits_for

    def _advance_to(self, new_state):
        """Set the cached data source to ``new_state`` and persist it."""
        self._data_source.state = new_state
        self._update_data_source(self._data_source)

    def _fsm_init_action(self, peer_info):
        """Init -> Processing (leader: peer Init; follower: peer Processing)."""
        if not self._peer_allows_advance(peer_info,
                                         common_pb.DataSourceState.Init,
                                         common_pb.DataSourceState.Processing):
            return False
        self._advance_to(common_pb.DataSourceState.Processing)
        return True

    def _fsm_processing_action(self, peer_info):
        """Processing -> Ready once all partitions are Done.

        The leader needs the peer in Processing; the follower needs it Ready.
        """
        if not self._all_partition_finished():
            return False
        if not self._peer_allows_advance(peer_info,
                                         common_pb.DataSourceState.Processing,
                                         common_pb.DataSourceState.Ready):
            return False
        self._advance_to(common_pb.DataSourceState.Ready)
        return True

    def _fsm_ready_action(self, peer_info):
        """Ready -> Finished (leader: peer Ready; follower: peer Finished)."""
        if not self._peer_allows_advance(peer_info,
                                         common_pb.DataSourceState.Ready,
                                         common_pb.DataSourceState.Finished):
            return False
        self._advance_to(common_pb.DataSourceState.Finished)
        return True

    def _fsm_finished_action(self, peer_info):
        """Terminal state: nothing to do."""
        return False

    def _fsm_failed_action(self, peer_info):
        """Keep asking the peer to abort until it too reports Failed."""
        if peer_info.state != common_pb.DataSourceState.Failed:
            self._peer_client.AbortDataSource(self._data_source_meta)
        return False

    def _fallback_failed_state(self, peer_info):
        """Switch to Failed when the peer's state conflicts with ours.

        Returns:
            True if the local data source was moved to Failed.
        """
        invalid_peer_states = self.INVALID_PEER_FSM_STATE.get(
            self._data_source.state, ())
        if peer_info.state not in invalid_peer_states:
            return False
        self._advance_to(common_pb.DataSourceState.Failed)
        return True

    def _update_data_source(self, data_source):
        """Write ``data_source`` to etcd and make it the cached copy.

        The cache is emptied first. If the write raises (logged, then
        re-raised), the next sync reloads the stored value.
        """
        self._data_source = None
        try:
            self._etcd.set_data(self._master_etcd_key,
                                text_format.MessageToString(data_source))
        except Exception as e:
            logging.error("Failed to update data source: %s since "
                          "exception: %s", self._data_source_name, e)
            raise
        else:
            self._data_source = data_source
            logging.debug("Success update to update data source: %s.",
                          self._data_source_name)

    def _get_peer_data_source_state(self):
        """Ask the peer master for its view of this data source."""
        return self._peer_client.GetDataSourceState(self._data_source_meta)

    def _all_partition_finished(self):
        """True iff there is one manifest per partition and all are Done."""
        manifests = self._raw_data_manifest_manager.list_all_manifest()
        assert (len(manifests) ==
                self._data_source.data_source_meta.partition_num)
        return all(manifest.state == dj_pb.RawDataState.Done
                   for manifest in manifests.values())