示例#1
0
class MasterFSM(object):
    """Finite state machine that keeps a data source in step with its peer.

    The data source record is loaded with ``retrieve_data_source`` and
    written back with ``commit_data_source``.  ``_fsm_routine_fn`` (handed to
    a ``RoutineWorker``) fetches the peer's status and moves the local state
    along Init -> Processing -> Ready -> Finished.  At each step the leader
    advances while the peer still reports the leader's current state, and
    the follower advances once the peer already reports the next state.  A
    peer state listed in ``INVALID_PEER_FSM_STATE`` for the current local
    state forces the local data source to Failed.
    """

    # Local state -> peer states that are treated as inconsistent with it.
    # Seeing one of them makes _fallback_failed_state() mark the local data
    # source as Failed.
    INVALID_PEER_FSM_STATE = {
        common_pb.DataSourceState.Init: {
            common_pb.DataSourceState.Failed,
            common_pb.DataSourceState.Ready,
            common_pb.DataSourceState.Finished,
        },
        common_pb.DataSourceState.Processing: {
            common_pb.DataSourceState.Failed,
            common_pb.DataSourceState.Finished,
        },
        common_pb.DataSourceState.Ready: {
            common_pb.DataSourceState.Failed,
            common_pb.DataSourceState.Init,
        },
        common_pb.DataSourceState.Finished: {
            common_pb.DataSourceState.Failed,
            common_pb.DataSourceState.Init,
            common_pb.DataSourceState.Processing,
        },
    }

    def __init__(self, peer_client, data_source_name, etcd):
        """Load the data source and prepare (but do not start) the worker.

        Args:
            peer_client: client used for the GetDataSourceStatus and
                AbortDataSource calls to the peer.
            data_source_name: name used to look the data source up and to
                label log lines and the worker.
            etcd: handle passed to retrieve_data_source, commit_data_source
                and RawDataManifestManager.
        """
        self._lock = threading.Lock()
        self._peer_client = peer_client
        self._data_source_name = data_source_name
        self._etcd = etcd
        self._init_fsm_action()
        self._data_source = None
        self._sync_data_source()
        assert self._data_source is not None, \
            "data source must not None if sync data source success"
        self._raw_data_manifest_manager = RawDataManifestManager(
                etcd, self._data_source
            )
        self._data_source_meta = self._data_source.data_source_meta
        # Only used to label log lines.
        if self._data_source.role == common_pb.FLRole.Leader:
            self._role_repr = "leader"
        else:
            self._role_repr = "follower"
        self._fsm_worker = None
        self._started = False

    def get_mainifest_manager(self):
        """Return the RawDataManifestManager created in __init__."""
        return self._raw_data_manifest_manager

    def get_data_source(self):
        """Return the cached data source, reloading it first if needed."""
        with self._lock:
            self._sync_data_source()
            return self._data_source

    def set_failed(self):
        """Unconditionally move the data source to Failed; see set_state."""
        return self.set_state(common_pb.DataSourceState.Failed, None)

    def set_state(self, new_state, origin_state=None):
        """Move the data source to ``new_state``.

        Args:
            new_state: target DataSourceState value.
            origin_state: when not None, the change is applied only if the
                current state equals this value.

        Returns:
            True if the data source already was in ``new_state`` or has been
            updated.  False if ``origin_state`` did not match or if any
            exception was raised; the exception is logged, not propagated.
        """
        with self._lock:
            try:
                self._sync_data_source()
                current = self._data_source.state
                if current == new_state:
                    return True
                if origin_state is None or current == origin_state:
                    self._data_source.state = new_state
                    self._update_data_source(self._data_source)
                    return True
                logging.warning("DataSource: %s failed to set to state: "
                                "%d, origin state mismatch(%d != %d)",
                                self._data_source_name, new_state,
                                origin_state, current)
                return False
            except Exception as e: # pylint: disable=broad-except
                logging.warning("Faile to set state to %d with exception %s",
                                new_state, e)
                return False

    def start_fsm_worker(self):
        """Create and start the FSM RoutineWorker; no-op once started."""
        with self._lock:
            if self._started:
                return
            assert self._fsm_worker is None, \
                "fsm_woker must be None if FSM is not started"
            self._started = True
            # The trailing 5 is forwarded to RoutineWorker unchanged --
            # presumably the routine interval in seconds; confirm against
            # RoutineWorker.
            self._fsm_worker = RoutineWorker(
                    '{}_fsm_worker'.format(self._data_source_name),
                    self._fsm_routine_fn,
                    self._fsm_routine_cond, 5
                )
            self._fsm_worker.start_routine()

    def stop_fsm_worker(self):
        """Detach the worker under the lock, then stop it outside the lock.

        NOTE(review): ``_started`` stays True, so a later start_fsm_worker()
        does nothing.  Confirm that this one-shot behaviour is intended.
        """
        with self._lock:
            worker = self._fsm_worker
            self._fsm_worker = None
        if worker is not None:
            worker.stop_routine()

    def _fsm_routine_fn(self):
        """Run one FSM step using the peer's current status."""
        # The peer status is fetched before the lock is taken.
        peer_info = self._get_peer_data_source_status()
        with self._lock:
            self._sync_data_source()
            state = self._data_source.state
            if self._fallback_failed_state(peer_info):
                logging.warning("%s at state %d, Peer at state %d "
                                "state invalid! abort data source %s",
                                self._role_repr, state,
                                peer_info.state, self._data_source_name)
                # The data source is Failed now.  Refresh `state` so the
                # stale pre-failure value cannot trigger a raw data
                # subscription below.
                self._sync_data_source()
                state = self._data_source.state
            elif state not in self._fsm_driven_handle:
                logging.error("%s at error state %d for data_source %s",
                              self._role_repr, state, self._data_source_name)
            elif self._fsm_driven_handle[state](peer_info):
                self._sync_data_source()
                new_state = self._data_source.state
                logging.warning("%s state changed from %d to %d",
                                self._role_repr, state, new_state)
                state = new_state
            if state in (common_pb.DataSourceState.Init,
                         common_pb.DataSourceState.Processing):
                self._raw_data_manifest_manager.sub_new_raw_data()

    def _fsm_routine_cond(self):
        """Condition callback for the RoutineWorker; always True."""
        return True

    def _sync_data_source(self):
        """Reload the data source only when the cached copy is missing."""
        if self._data_source is None:
            self._data_source = \
                retrieve_data_source(self._etcd, self._data_source_name)

    def _init_fsm_action(self):
        """Build the state -> handler table used by _fsm_routine_fn."""
        self._fsm_driven_handle = {
            common_pb.DataSourceState.UnKnown:
                self._get_fsm_action('unknown'),
            common_pb.DataSourceState.Init:
                self._get_fsm_action('init'),
            common_pb.DataSourceState.Processing:
                self._get_fsm_action('processing'),
            common_pb.DataSourceState.Ready:
                self._get_fsm_action('ready'),
            common_pb.DataSourceState.Finished:
                self._get_fsm_action('finished'),
            common_pb.DataSourceState.Failed:
                self._get_fsm_action('failed'),
        }

    def _get_fsm_action(self, action):
        """Return the bound ``_fsm_<action>_action`` method.

        When no such method exists, a stub that raises NotImplementedError
        when called is returned instead.
        """
        def _not_implement(useless):
            raise NotImplementedError('state is not NotImplemented')
        name = '_fsm_{}_action'.format(action)
        return getattr(self, name, _not_implement)

    def _advance_when_peer_at(self, peer_info, leader_peer_state,
                              follower_peer_state, next_state):
        """Persist ``next_state`` if the peer is where this role expects it.

        The leader compares the peer state with ``leader_peer_state``; any
        other role compares it with ``follower_peer_state``.

        Returns:
            True if the state was changed and committed, else False.
        """
        if self._data_source.role == common_pb.FLRole.Leader:
            expected_peer_state = leader_peer_state
        else:
            expected_peer_state = follower_peer_state
        if peer_info.state != expected_peer_state:
            return False
        self._data_source.state = next_state
        self._update_data_source(self._data_source)
        return True

    def _fsm_init_action(self, peer_info):
        """Init -> Processing; returns True if the state was changed."""
        return self._advance_when_peer_at(
                peer_info,
                common_pb.DataSourceState.Init,
                common_pb.DataSourceState.Processing,
                common_pb.DataSourceState.Processing
            )

    def _fsm_processing_action(self, peer_info):
        """Processing -> Ready once every partition is finished.

        Returns True if the state was changed.
        """
        if not self._all_partition_finished():
            return False
        return self._advance_when_peer_at(
                peer_info,
                common_pb.DataSourceState.Processing,
                common_pb.DataSourceState.Ready,
                common_pb.DataSourceState.Ready
            )

    def _fsm_ready_action(self, peer_info):
        """Ready -> Finished; returns True if the state was changed."""
        return self._advance_when_peer_at(
                peer_info,
                common_pb.DataSourceState.Ready,
                common_pb.DataSourceState.Finished,
                common_pb.DataSourceState.Finished
            )

    def _fsm_finished_action(self, peer_info):
        """Finished is terminal here: nothing to do, state never changes."""
        return False

    def _fsm_failed_action(self, peer_info):
        """Ask the peer to abort while it is not Failed yet.

        The local state is left untouched, so this always returns False.
        """
        if peer_info.state != common_pb.DataSourceState.Failed:
            request = dj_pb.DataSourceRequest(
                    data_source_meta=self._data_source_meta
                )
            self._peer_client.AbortDataSource(request)
        return False

    def _fallback_failed_state(self, peer_info):
        """Mark the data source Failed if the peer state is invalid.

        Returns:
            True if the peer state is listed in INVALID_PEER_FSM_STATE for
            the current local state and Failed was committed, else False.
        """
        invalid_peer_states = self.INVALID_PEER_FSM_STATE.get(
                self._data_source.state
            )
        if (invalid_peer_states is None or
                peer_info.state not in invalid_peer_states):
            return False
        self._data_source.state = common_pb.DataSourceState.Failed
        self._update_data_source(self._data_source)
        return True

    def _update_data_source(self, data_source):
        """Commit ``data_source`` and make it the cached copy.

        The cache is cleared before the commit, so if the commit raises the
        next _sync_data_source() reloads the record instead of reusing the
        uncommitted in-memory one.

        Raises:
            Exception: whatever commit_data_source raised, after logging.
        """
        self._data_source = None
        try:
            commit_data_source(self._etcd, data_source)
        except Exception as e:
            logging.error("Failed to update data source: %s since "
                          "exception: %s", self._data_source_name, e)
            raise
        self._data_source = data_source
        logging.debug("Success update to update data source: %s.",
                      self._data_source_name)

    def _get_peer_data_source_status(self):
        """Return the peer's GetDataSourceStatus response."""
        request = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_meta
            )
        return self._peer_client.GetDataSourceStatus(request)

    def _all_partition_finished(self):
        """Return True if every manifest is both Synced and Joined."""
        all_manifest = self._raw_data_manifest_manager.list_all_manifest()
        assert len(all_manifest) == \
                self._data_source.data_source_meta.partition_num, \
            "manifest number should same with partition number"
        return all(
            manifest.sync_example_id_rep.state ==
            dj_pb.SyncExampleIdState.Synced and
            manifest.join_example_rep.state ==
            dj_pb.JoinExampleState.Joined
            for manifest in all_manifest.values()
        )
class MasterFSM(object):
    """Finite state machine that keeps a data source in step with its peer.

    The data source record is stored as a text-format ``common_pb.DataSource``
    under the etcd key ``<data_source_name>/master``.  ``_fsm_routine_fn``
    (handed to a ``RoutineWorker``) fetches the peer's state and moves the
    local state along Init -> Processing -> Ready -> Finished.  At each step
    the leader advances while the peer still reports the leader's current
    state, and the follower advances once the peer already reports the next
    state.  A peer state listed in ``INVALID_PEER_FSM_STATE`` for the current
    local state forces the local data source to Failed.
    """

    # Local state -> peer states that are treated as inconsistent with it.
    # Seeing one of them makes _fallback_failed_state() mark the local data
    # source as Failed.
    INVALID_PEER_FSM_STATE = {
        common_pb.DataSourceState.Init: {
            common_pb.DataSourceState.Failed,
            common_pb.DataSourceState.Ready,
            common_pb.DataSourceState.Finished,
        },
        common_pb.DataSourceState.Processing: {
            common_pb.DataSourceState.Failed,
            common_pb.DataSourceState.Finished,
        },
        common_pb.DataSourceState.Ready: {
            common_pb.DataSourceState.Failed,
            common_pb.DataSourceState.Init,
        },
        common_pb.DataSourceState.Finished: {
            common_pb.DataSourceState.Failed,
            common_pb.DataSourceState.Init,
            common_pb.DataSourceState.Processing,
        },
    }

    def __init__(self, peer_client, data_source_name, etcd):
        """Load the data source and prepare (but do not start) the worker.

        Args:
            peer_client: client used for the GetDataSourceState and
                AbortDataSource calls to the peer.
            data_source_name: name used to build the etcd key and to label
                log lines and the worker.
            etcd: client providing get_data / set_data; also passed to
                RawDataManifestManager.

        Raises:
            ValueError: if nothing is stored under the master etcd key.
        """
        self._lock = threading.Lock()
        self._peer_client = peer_client
        self._data_source_name = data_source_name
        self._master_etcd_key = os.path.join(data_source_name, 'master')
        self._etcd = etcd
        self._init_fsm_action()
        self._data_source = None
        self._sync_data_source()
        assert self._data_source is not None
        self._raw_data_manifest_manager = RawDataManifestManager(
            etcd, self._data_source)
        self._data_source_meta = self._data_source.data_source_meta
        # Only used to label log lines.
        if self._data_source.role == common_pb.FLRole.Leader:
            self._role_repr = "leader"
        else:
            self._role_repr = "follower"
        self._fsm_worker = None
        self._started = False

    def get_mainifest_manager(self):
        """Return the RawDataManifestManager created in __init__."""
        return self._raw_data_manifest_manager

    def get_data_source(self):
        """Return the cached data source, reloading it first if needed."""
        with self._lock:
            self._sync_data_source()
            assert self._data_source is not None
            return self._data_source

    def set_failed(self):
        """Unconditionally move the data source to Failed; see set_state."""
        return self.set_state(common_pb.DataSourceState.Failed, None)

    def set_state(self, new_state, origin_state=None):
        """Move the data source to ``new_state``.

        Args:
            new_state: target DataSourceState value.
            origin_state: when not None, the change is applied only if the
                current state equals this value.

        Returns:
            True if the data source already was in ``new_state`` or has been
            updated.  False if ``origin_state`` did not match or if any
            exception was raised; the exception is logged, not propagated.
        """
        with self._lock:
            try:
                self._sync_data_source()
                assert self._data_source is not None
                current = self._data_source.state
                if current == new_state:
                    return True
                if origin_state is None or current == origin_state:
                    self._data_source.state = new_state
                    self._update_data_source(self._data_source)
                    return True
                logging.warning(
                    "DataSource: %s failed to set to state: "
                    "%d, origin state mismatch(%d != %d)",
                    self._data_source_name, new_state, origin_state,
                    current)
                return False
            except Exception as e:  # pylint: disable=broad-except
                logging.warning("Faile to set state to %d with exception %s",
                                new_state, e)
                return False

    def start_fsm_worker(self):
        """Create and start the FSM RoutineWorker; no-op once started."""
        with self._lock:
            if self._started:
                return
            assert self._fsm_worker is None
            self._started = True
            # The trailing 5 is forwarded to RoutineWorker unchanged --
            # presumably the routine interval in seconds; confirm against
            # RoutineWorker.
            self._fsm_worker = RoutineWorker(
                '{}_fsm_worker'.format(self._data_source_name),
                self._fsm_routine_fn, self._fsm_routine_cond, 5)
            self._fsm_worker.start_routine()

    def stop_fsm_worker(self):
        """Detach the worker under the lock, then stop it outside the lock.

        NOTE(review): ``_started`` stays True, so a later start_fsm_worker()
        does nothing.  Confirm that this one-shot behaviour is intended.
        """
        with self._lock:
            worker = self._fsm_worker
            self._fsm_worker = None
        if worker is not None:
            worker.stop_routine()

    def _fsm_routine_fn(self):
        """Run one FSM step using the peer's current state.

        The step is skipped (after logging) when the peer response carries a
        non-zero status code.
        """
        # The peer state is fetched before the lock is taken.
        peer_info = self._get_peer_data_source_state()
        if peer_info.status.code != 0:
            logging.error("Failed to get peer state, %s",
                          peer_info.status.error_message)
            return
        with self._lock:
            self._sync_data_source()
            assert self._data_source is not None
            state = self._data_source.state
            if self._fallback_failed_state(peer_info):
                logging.warning(
                    "Self(%s) at state %d, Peer at state %d "
                    "state invalid! abort data source %s",
                    self._role_repr, state,
                    peer_info.state, self._data_source_name)
            elif state not in self._fsm_driven_handle:
                logging.error("Self(%s) at error state %d for data_source %s",
                              self._role_repr, state, self._data_source_name)
            elif self._fsm_driven_handle[state](peer_info):
                self._sync_data_source()
                assert self._data_source is not None
                new_state = self._data_source.state
                logging.warning("Self(%s) state changed from %d to %d",
                                self._role_repr, state, new_state)

    def _fsm_routine_cond(self):
        """Condition callback for the RoutineWorker; always True."""
        return True

    def _sync_data_source(self):
        """Reload the data source from etcd only when the cache is empty.

        Raises:
            ValueError: if nothing is stored under the master etcd key.
        """
        if self._data_source is not None:
            return
        raw_data = self._etcd.get_data(self._master_etcd_key)
        if raw_data is None:
            raise ValueError("etcd master key is None for {}".format(
                self._data_source_name))
        self._data_source = text_format.Parse(raw_data,
                                              common_pb.DataSource())

    def _init_fsm_action(self):
        """Build the state -> handler table used by _fsm_routine_fn."""
        self._fsm_driven_handle = {
            common_pb.DataSourceState.UnKnown:
                self._get_fsm_action('unknown'),
            common_pb.DataSourceState.Init:
                self._get_fsm_action('init'),
            common_pb.DataSourceState.Processing:
                self._get_fsm_action('processing'),
            common_pb.DataSourceState.Ready:
                self._get_fsm_action('ready'),
            common_pb.DataSourceState.Finished:
                self._get_fsm_action('finished'),
            common_pb.DataSourceState.Failed:
                self._get_fsm_action('failed'),
        }

    def _get_fsm_action(self, action):
        """Return the bound ``_fsm_<action>_action`` method.

        When no such method exists, a stub that raises NotImplementedError
        when called is returned instead.
        """
        def _not_implement(useless):
            raise NotImplementedError('state is not NotImplemented')

        name = '_fsm_{}_action'.format(action)
        return getattr(self, name, _not_implement)

    def _advance_when_peer_at(self, peer_info, leader_peer_state,
                              follower_peer_state, next_state):
        """Persist ``next_state`` if the peer is where this role expects it.

        The leader compares the peer state with ``leader_peer_state``; any
        other role compares it with ``follower_peer_state``.

        Returns:
            True if the state was changed and written back, else False.
        """
        if self._data_source.role == common_pb.FLRole.Leader:
            expected_peer_state = leader_peer_state
        else:
            expected_peer_state = follower_peer_state
        if peer_info.state != expected_peer_state:
            return False
        self._data_source.state = next_state
        self._update_data_source(self._data_source)
        return True

    def _fsm_init_action(self, peer_info):
        """Init -> Processing; returns True if the state was changed."""
        return self._advance_when_peer_at(
            peer_info,
            common_pb.DataSourceState.Init,
            common_pb.DataSourceState.Processing,
            common_pb.DataSourceState.Processing)

    def _fsm_processing_action(self, peer_info):
        """Processing -> Ready once every partition is finished.

        Returns True if the state was changed.
        """
        if not self._all_partition_finished():
            return False
        return self._advance_when_peer_at(
            peer_info,
            common_pb.DataSourceState.Processing,
            common_pb.DataSourceState.Ready,
            common_pb.DataSourceState.Ready)

    def _fsm_ready_action(self, peer_info):
        """Ready -> Finished; returns True if the state was changed."""
        return self._advance_when_peer_at(
            peer_info,
            common_pb.DataSourceState.Ready,
            common_pb.DataSourceState.Finished,
            common_pb.DataSourceState.Finished)

    def _fsm_finished_action(self, peer_info):
        """Finished is terminal here: nothing to do, state never changes."""
        return False

    def _fsm_failed_action(self, peer_info):
        """Ask the peer to abort while it is not Failed yet.

        The local state is left untouched, so this always returns False.
        """
        if peer_info.state != common_pb.DataSourceState.Failed:
            self._peer_client.AbortDataSource(self._data_source_meta)
        return False

    def _fallback_failed_state(self, peer_info):
        """Mark the data source Failed if the peer state is invalid.

        Returns:
            True if the peer state is listed in INVALID_PEER_FSM_STATE for
            the current local state and Failed was written back, else False.
        """
        invalid_peer_states = self.INVALID_PEER_FSM_STATE.get(
            self._data_source.state)
        if (invalid_peer_states is None
                or peer_info.state not in invalid_peer_states):
            return False
        self._data_source.state = common_pb.DataSourceState.Failed
        self._update_data_source(self._data_source)
        return True

    def _update_data_source(self, data_source):
        """Write ``data_source`` to etcd and make it the cached copy.

        The cache is cleared before the write, so if the write raises the
        next _sync_data_source() reloads the record from etcd instead of
        reusing the unwritten in-memory one.

        Raises:
            Exception: whatever the etcd write raised, after logging.
        """
        self._data_source = None
        try:
            self._etcd.set_data(self._master_etcd_key,
                                text_format.MessageToString(data_source))
        except Exception as e:
            logging.error("Failed to update data source: %s since "
                          "exception: %s", self._data_source_name, e)
            raise
        self._data_source = data_source
        logging.debug("Success update to update data source: %s.",
                      self._data_source_name)

    def _get_peer_data_source_state(self):
        """Return the peer's GetDataSourceState response."""
        return self._peer_client.GetDataSourceState(self._data_source_meta)

    def _all_partition_finished(self):
        """Return True if every manifest is in RawDataState.Done."""
        all_manifest = self._raw_data_manifest_manager.list_all_manifest()
        assert (len(all_manifest) ==
                self._data_source.data_source_meta.partition_num)
        return all(manifest.state == dj_pb.RawDataState.Done
                   for manifest in all_manifest.values())