def start_process(self):
        with self._lock:
            if not self._started:
                self._worker_map = {
                    self._id_batch_fetcher_name(): RoutineWorker(
                    self._id_batch_fetcher_name(),
                    self._id_batch_fetch_fn,
                    self._id_batch_fetch_cond, 5),

                    self._psi_rsa_signer_name(): RoutineWorker(
                    self._psi_rsa_signer_name(),
                    self._psi_rsa_sign_fn,
                    self._psi_rsa_sign_cond, 5),

                    self._sort_run_dumper_name(): RoutineWorker(
                    self._sort_run_dumper_name(),
                    self._sort_run_dump_fn,
                    self._sort_run_dump_cond, 5),

                    self._sort_run_merger_name(): RoutineWorker(
                    self._sort_run_merger_name(),
                    self._sort_run_merge_fn,
                    self._sort_run_merge_cond, 5)
                }
                for _, w in self._worker_map.items():
                    w.start_routine()
                self._started = True
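Every snippet in this listing drives its background work through the same RoutineWorker(name, routine_fn, cond_fn, interval) pattern: a named worker thread periodically re-evaluates cond_fn (every interval seconds, or immediately after wakeup()) and calls routine_fn whenever the condition holds; setup_args() supplies arguments for the next invocation (see Example #17), and stop_routine() shuts the thread down. The class below is only a minimal sketch of the interface inferred from these usages, not the library's actual implementation:

import threading

class RoutineWorker(object):
    """Minimal sketch inferred from the examples; the real class may differ."""
    def __init__(self, name, routine_fn, cond_fn, interval=1):
        self._name = name
        self._routine_fn = routine_fn
        self._cond_fn = cond_fn
        self._interval = interval
        self._cond = threading.Condition()
        self._args = ()
        self._stopped = False
        self._thread = None

    def start_routine(self):
        # Launch the background loop on a daemon thread.
        self._thread = threading.Thread(target=self._run, name=self._name,
                                        daemon=True)
        self._thread.start()

    def setup_args(self, *args):
        # Arguments passed to routine_fn on its next run.
        with self._cond:
            self._args = args

    def wakeup(self):
        # Trigger an early condition check instead of waiting out the interval.
        with self._cond:
            self._cond.notify_all()

    def stop_routine(self):
        with self._cond:
            self._stopped = True
            self._cond.notify_all()
        if self._thread is not None:
            self._thread.join()

    def _run(self):
        while True:
            with self._cond:
                if self._stopped:
                    return
            if self._cond_fn():
                with self._cond:
                    args = self._args
                self._routine_fn(*args)
            with self._cond:
                if not self._stopped:
                    self._cond.wait(self._interval)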
Example #2
 def start(self):
     self._bg_worker = RoutineWorker(
             'portal_master_bg_worker',
             self._data_portal_job_manager.backgroup_task,
             lambda: True, 30
         )
     self._bg_worker.start_routine()
Example #3
 def start_dump_worker(self):
     with self._lock:
         if not self._started:
             assert self._example_id_dump_worker is None
             self._example_id_dump_worker = RoutineWorker(
                 'example_id_dumper', self._dump_example_ids_fn,
                 self._dump_example_ids_cond, 1)
             self._example_id_dump_worker.start_routine()
             self._started = True
Example #4
 def start_dump_worker(self):
     with self._lock:
         if not self._started:
             assert self._data_block_dump_worker is None
             self._data_block_dump_worker = RoutineWorker(
                 'data_block_dumper', self._dump_data_block_fn,
                 self._dump_data_block_cond, 1)
             self._data_block_dump_worker.start_routine()
             self._started = True
Example #5
 def start_fsm_worker(self):
     with self._lock:
         if not self._started:
             assert self._fsm_worker is None
             self._started = True
             self._fsm_worker = RoutineWorker(
                 '{}_fsm_worker'.format(self._data_source_name),
                 self._fsm_routine_fn, self._fsm_routine_cond, 5)
             self._fsm_worker.start_routine()
Example #6
 def __init__(self, options):
     self._options = options
     self._raw_data_batch_fetcher = RawDataBatchFetcher(options)
     self._fetch_worker = RoutineWorker('raw_data_batch_fetcher',
                                        self._raw_data_batch_fetch_fn,
                                        self._raw_data_batch_fetch_cond, 5)
     self._next_part_index = 0
     self._cond = threading.Condition()
     self._fetch_worker.start_routine()
Example #7
 def start_dump_worker(self):
     with self._lock:
         if not self._started:
             assert self._dump_worker is None, \
                 "dumper woker for {} should be None if "\
                 "not started".format(self._repr_str)
             self._dump_worker = RoutineWorker(
                 self._repr_str + '-dump_worker', self._dump_fn,
                 self._dump_cond, 1)
             self._dump_worker.start_routine()
             self._started = True
Example #8
 def start_notify_worker(self):
     with self._lock:
         if not self._started:
             assert self._notify_worker is None, \
                 "notify worker should be None if not started"
             self._notify_worker = RoutineWorker('potral-raw_data-notifier',
                                                 self._raw_data_notify_fn,
                                                 self._raw_data_notify_cond,
                                                 5)
             self._notify_worker.start_routine()
             self._started = True
         self._notify_worker.wakeup()
Example #9
class DataPortalMaster(dp_grpc.DataPortalMasterServiceServicer):
    def __init__(self, portal_name, kvstore, portal_options):
        super(DataPortalMaster, self).__init__()
        self._portal_name = portal_name
        self._kvstore = kvstore
        self._portal_options = portal_options
        self._data_portal_job_manager = DataPortalJobManager(
            self._kvstore,
            self._portal_name,
            self._portal_options.long_running,
            self._portal_options.check_success_tag,
            self._portal_options.single_subfolder,
            self._portal_options.files_per_job_limit,
            start_date=self._portal_options.start_date,
            end_date=self._portal_options.end_date)
        self._bg_worker = None

    def GetDataPortalManifest(self, request, context):
        return self._data_portal_job_manager.get_portal_manifest()

    def RequestNewTask(self, request, context):
        response = dp_pb.NewTaskResponse()
        finished, task = \
            self._data_portal_job_manager.alloc_task(request.rank_id)
        if task is not None:
            if isinstance(task, dp_pb.MapTask):
                response.map_task.MergeFrom(task)
            else:
                assert isinstance(task, dp_pb.ReduceTask)
                response.reduce_task.MergeFrom(task)
        elif not finished:
            response.pending.MergeFrom(empty_pb2.Empty())
        else:
            response.finished.MergeFrom(empty_pb2.Empty())
        return response

    def FinishTask(self, request, context):
        self._data_portal_job_manager.finish_task(request.rank_id,
                                                  request.partition_id,
                                                  request.part_state)
        return common_pb.Status()

    def start(self):
        self._bg_worker = RoutineWorker(
            'portal_master_bg_worker',
            self._data_portal_job_manager.backgroup_task, lambda: True, 30)
        self._bg_worker.start_routine()

    def stop(self):
        if self._bg_worker is not None:
            self._bg_worker.stop_routine()
        self._bg_worker = None
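Example #9 defines the servicer but not how it is served. The following is a minimal, hypothetical driver tying the gRPC server to the background RoutineWorker's lifetime; the add_DataPortalMasterServiceServicer_to_server name is assumed from standard gRPC code generation, and kvstore/portal_options stand in for whatever the caller actually provides:

import grpc
from concurrent import futures

def serve_portal_master(portal_name, kvstore, portal_options, listen_port):
    # Hypothetical wiring; only DataPortalMaster and dp_grpc come from Example #9.
    master = DataPortalMaster(portal_name, kvstore, portal_options)
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=8))
    dp_grpc.add_DataPortalMasterServiceServicer_to_server(master, server)
    server.add_insecure_port('[::]:{}'.format(listen_port))
    master.start()    # launch the portal_master_bg_worker routine
    server.start()
    try:
        server.wait_for_termination()
    finally:
        master.stop()  # stop the background routine before exiting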
Example #10
 def start_process(self):
     with self._cond:
         if not self._started:
             self._worker_map = {
                 'raw_data_batch_fetcher':
                 RoutineWorker('raw_data_batch_fetcher',
                               self._raw_data_batch_fetch_fn,
                               self._raw_data_batch_fetch_cond, 5),
                 'raw_data_partitioner':
                 RoutineWorker('raw_data_partitioner',
                               self._raw_data_part_fn,
                               self._raw_data_part_cond, 5)
             }
             for _, w in self._worker_map.items():
                 w.start_routine()
             self._started = True
Example #11
 def start_routine_workers(self):
     with self._lock:
         if not self._started:
             self._worker_map = {
                 'sync_partition_allocator':
                 RoutineWorker('sync_partition_allocator',
                               self._allocate_sync_partition_fn,
                               self._allocate_sync_partition_cond, 5),
                 'follower_example_id_syncer':
                 RoutineWorker('follower_example_id_syncer',
                               self._sync_follower_example_id_fn,
                               self._sync_follower_example_id_cond, 5),
             }
             for _, w in self._worker_map.items():
                 w.start_routine()
             self._state = _SyncState.ALLOC_SYNC_PARTITION
             self._started = True
Example #12
 def start_routine_workers(self):
     with self._lock:
         if not self._started:
             self._worker_map = {
                 'join_partition_allocator':
                 RoutineWorker('join_partition_allocator',
                               self._allocate_join_partition_fn,
                               self._allocate_join_partition_cond, 5),
                 'leader_example_joiner':
                 RoutineWorker('leader_example_joiner',
                               self._join_leader_example_fn,
                               self._join_leader_example_cond, 5),
                 'data_block_meta_syncer':
                 RoutineWorker('data_block_meta_syncer',
                               self._sync_data_block_meta_fn,
                               self._sync_data_block_meta_cond, 5),
             }
             for _, w in self._worker_map.items():
                 w.start_routine()
             self._state = _JoinState.ALLOC_JOIN_PARTITION
             self._started = True
Example #13
 def start_routine_workers(self):
     with self._lock:
         if not self._started:
             self._worker_map = {
                 self._input_data_ready_sniffer_name():
                 RoutineWorker(self._input_data_ready_sniffer_name(),
                               self._input_data_ready_sniff_fn,
                               self._input_data_ready_sniff_cond, 5),
                 self._repart_executor_name():
                 RoutineWorker(self._repart_executor_name(),
                               self._repart_execute_fn,
                               self._repart_execute_cond, 5),
                 self._committed_datetime_forwarder_name():
                 RoutineWorker(self._committed_datetime_forwarder_name(),
                               self._committed_datetime_forward_fn,
                               self._committed_datetime_forward_cond, 5)
             }
             for _, w in self._worker_map.items():
                 w.start_routine()
             self._started = True
Example #14
 def start_routine_workers(self):
     with self._lock:
         if not self._started:
             self._worker_map = {
                 self._partition_allocator_name():
                 RoutineWorker(self._partition_allocator_name(),
                               self._allocate_new_partition_fn,
                               self._allocate_new_partition_cond, 5),
                 self._producer_name():
                 RoutineWorker(self._producer_name(),
                               self._data_producer_fn,
                               self._data_producer_cond, 5),
                 self._consumer_name():
                 RoutineWorker(self._consumer_name(),
                               self._data_consumer_fn,
                               self._data_consumer_cond, 5),
             }
             for _, w in self._worker_map.items():
                 w.start_routine()
             self._started = True
             self._wakeup_new_partition_allocator()
Example #15
class MasterFSM(object):
    INVALID_PEER_FSM_STATE = {}
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Init] = set([
        common_pb.DataSourceState.Failed, common_pb.DataSourceState.Ready,
        common_pb.DataSourceState.Finished
    ])
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Processing] = set(
        [common_pb.DataSourceState.Failed, common_pb.DataSourceState.Finished])
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Ready] = set(
        [common_pb.DataSourceState.Failed, common_pb.DataSourceState.Init])
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Finished] = set([
        common_pb.DataSourceState.Failed, common_pb.DataSourceState.Init,
        common_pb.DataSourceState.Processing
    ])

    def __init__(self, peer_client, data_source_name, etcd):
        self._lock = threading.Lock()
        self._peer_client = peer_client
        self._data_source_name = data_source_name
        self._master_etcd_key = os.path.join(data_source_name, 'master')
        self._etcd = etcd
        self._init_fsm_action()
        self._data_source = None
        self._sync_data_source()
        assert self._data_source is not None
        self._raw_data_manifest_manager = RawDataManifestManager(
            etcd, self._data_source)
        assert self._data_source is not None
        self._data_source_meta = self._data_source.data_source_meta
        if self._data_source.role == common_pb.FLRole.Leader:
            self._role_repr = "leader"
        else:
            self._role_repr = "follower"
        self._fsm_worker = None
        self._started = False

    def get_mainifest_manager(self):
        return self._raw_data_manifest_manager

    def get_data_source(self):
        with self._lock:
            self._sync_data_source()
            assert self._data_source is not None
            return self._data_source

    def set_failed(self):
        return self.set_state(common_pb.DataSourceState.Failed, None)

    def set_state(self, new_state, origin_state=None):
        with self._lock:
            try:
                self._sync_data_source()
                assert self._data_source is not None
                if self._data_source.state == new_state:
                    return True
                if (origin_state is None
                        or self._data_source.state == origin_state):
                    self._data_source.state = new_state
                    self._update_data_source(self._data_source)
                    return True
                logging.warning(
                    "DataSource: %s failed to set to state: "
                    "%d, origin state mismatch(%d != %d)",
                    self._data_source_name, new_state, origin_state,
                    self._data_source.state)
                return False
            except Exception as e:  # pylint: disable=broad-except
                logging.warning("Faile to set state to %d with exception %s",
                                new_state, e)
                return False
            return True

    def start_fsm_worker(self):
        with self._lock:
            if not self._started:
                assert self._fsm_worker is None
                self._started = True
                self._fsm_worker = RoutineWorker(
                    '{}_fsm_worker'.format(self._data_source_name),
                    self._fsm_routine_fn, self._fsm_routine_cond, 5)
                self._fsm_worker.start_routine()

    def stop_fsm_worker(self):
        tmp_worker = None
        with self._lock:
            if self._fsm_worker is not None:
                tmp_worker = self._fsm_worker
                self._fsm_worker = None
        if tmp_worker is not None:
            tmp_worker.stop_routine()

    def _fsm_routine_fn(self):
        peer_info = self._get_peer_data_source_state()
        if peer_info.status.code != 0:
            logging.error("Failed to get peer state, %s",
                          peer_info.status.error_message)
            return
        with self._lock:
            self._sync_data_source()
            assert self._data_source is not None
            state = self._data_source.state
            if self._fallback_failed_state(peer_info):
                logging.warning(
                        "Self(%s) at state %d, Peer at state %d "\
                        "state invalid! abort data source %s",
                        self._role_repr, state,
                        peer_info.state, self._data_source_name
                    )
            elif state not in self._fsm_driven_handle:
                logging.error("Self(%s) at error state %d for data_source %s",
                              self._role_repr, state, self._data_source_name)
            else:
                state_changed = self._fsm_driven_handle[state](peer_info)
                if state_changed:
                    self._sync_data_source()
                    assert self._data_source is not None
                    new_state = self._data_source.state
                    logging.warning("Self(%s) state changed from %d to %d",
                                    self._role_repr, state, new_state)

    def _fsm_routine_cond(self):
        return True

    def _sync_data_source(self):
        if self._data_source is None:
            raw_data = self._etcd.get_data(self._master_etcd_key)
            if raw_data is None:
                raise ValueError("etcd master key is None for {}".format(
                    self._data_source_name))
            self._data_source = text_format.Parse(raw_data,
                                                  common_pb.DataSource())

    def _init_fsm_action(self):
        self._fsm_driven_handle = {
            common_pb.DataSourceState.UnKnown:
            self._get_fsm_action('unknown'),
            common_pb.DataSourceState.Init:
            self._get_fsm_action('init'),
            common_pb.DataSourceState.Processing:
            self._get_fsm_action('processing'),
            common_pb.DataSourceState.Ready:
            self._get_fsm_action('ready'),
            common_pb.DataSourceState.Finished:
            self._get_fsm_action('finished'),
            common_pb.DataSourceState.Failed:
            self._get_fsm_action('failed')
        }

    def _get_fsm_action(self, action):
        def _not_implement(useless):
            raise NotImplementedError('fsm state action is not implemented')

        name = '_fsm_{}_action'.format(action)
        return getattr(self, name, _not_implement)

    def _fsm_init_action(self, peer_info):
        state_changed = False
        if self._data_source.role == common_pb.FLRole.Leader:
            if peer_info.state == common_pb.DataSourceState.Init:
                state_changed = True
        elif peer_info.state == common_pb.DataSourceState.Processing:
            state_changed = True
        if state_changed:
            self._data_source.state = common_pb.DataSourceState.Processing
            self._update_data_source(self._data_source)
            return True
        return False

    def _fsm_processing_action(self, peer_info):
        if self._all_partition_finished():
            state_changed = False
            if self._data_source.role == common_pb.FLRole.Leader:
                if peer_info.state == common_pb.DataSourceState.Processing:
                    state_changed = True
            elif peer_info.state == common_pb.DataSourceState.Ready:
                state_changed = True
            if state_changed:
                self._data_source.state = common_pb.DataSourceState.Ready
                self._update_data_source(self._data_source)
                return True
        return False

    def _fsm_ready_action(self, peer_info):
        state_changed = False
        if self._data_source.role == common_pb.FLRole.Leader:
            if peer_info.state == common_pb.DataSourceState.Ready:
                state_changed = True
        elif peer_info.state == common_pb.DataSourceState.Finished:
            state_changed = True
        if state_changed:
            self._data_source.state = common_pb.DataSourceState.Finished
            self._update_data_source(self._data_source)
            return True
        return False

    def _fsm_finished_action(self, peer_info):
        return False

    def _fsm_failed_action(self, peer_info):
        if peer_info.state != common_pb.DataSourceState.Failed:
            self._peer_client.AbortDataSource(self._data_source_meta)
        return False

    def _fallback_failed_state(self, peer_info):
        state = self._data_source.state
        if (state in self.INVALID_PEER_FSM_STATE
                and peer_info.state in self.INVALID_PEER_FSM_STATE[state]):
            self._data_source.state = common_pb.DataSourceState.Failed
            self._update_data_source(self._data_source)
            return True
        return False

    def _update_data_source(self, data_source):
        self._data_source = None
        try:
            self._etcd.set_data(self._master_etcd_key,
                                text_format.MessageToString(data_source))
        except Exception as e:
            logging.error("Failed to update data source: %s since "\
                          "exception: %s", self._data_source_name, e)
            raise
        self._data_source = data_source
        logging.debug("Success update to update data source: %s.",
                      self._data_source_name)

    def _get_peer_data_source_state(self):
        return self._peer_client.GetDataSourceState(self._data_source_meta)

    def _all_partition_finished(self):
        all_manifest = self._raw_data_manifest_manager.list_all_manifest()
        assert (len(all_manifest) ==
                self._data_source.data_source_meta.partition_num)
        for manifest in all_manifest.values():
            if manifest.state != dj_pb.RawDataState.Done:
                return False
        return True
Example #16
class ExampleJoinFollower(object):
    def __init__(self, etcd, data_source):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._data_source = data_source
        self._data_block_dump_manager = None
        self._data_block_dump_worker = None
        self._started = False

    def start_create_data_block(self, partition_id):
        with self._lock:
            dump_manager = self._data_block_dump_manager
            if (dump_manager is not None
                    and dump_manager.get_partition_id() != partition_id):
                raise RuntimeError("partition {} is not finished".format(
                    dump_manager.get_partition_id()))
            if dump_manager is None:
                self._data_block_dump_manager = DataBlockDumperManager(
                    self._etcd, self._data_source, partition_id)
                dump_manager = self._data_block_dump_manager
            next_index = dump_manager.get_next_data_block_index()
            return next_index

    def add_synced_data_block_meta(self, meta):
        with self._lock:
            self._check_status(meta.partition_id)
            manager = self._data_block_dump_manager
            return manager.append_synced_data_block_meta(meta)

    def finish_sync_data_block_meta(self, partition_id):
        with self._lock:
            self._check_status(partition_id)
            self._data_block_dump_manager.finish_sync_data_block_meta()
            return not self._data_block_dump_manager.need_dump()

    def get_processing_partition_id(self):
        with self._lock:
            if self._data_block_dump_manager is None:
                return None
            return self._data_block_dump_manager.get_partition_id()

    def reset_dump_partition(self):
        with self._lock:
            if self._data_block_dump_manager is None:
                return
            dump_manager = self._data_block_dump_manager
            partition_id = dump_manager.get_partition_id()
            self._check_status(partition_id)
            if (not dump_manager.data_block_meta_sync_finished()
                    or dump_manager.need_dump()):
                raise RuntimeError(
                    "partition {} is dumpping".format(partition_id))
            self._data_block_dump_manager = None

    def start_dump_worker(self):
        with self._lock:
            if not self._started:
                assert self._data_block_dump_worker is None
                self._data_block_dump_worker = RoutineWorker(
                    'data_block_dumper', self._dump_data_block_fn,
                    self._dump_data_block_cond, 1)
                self._data_block_dump_worker.start_routine()
                self._started = True

    def stop_dump_worker(self):
        dumper = None
        with self._lock:
            if self._data_block_dump_worker is not None:
                dumper = self._data_block_dump_worker
                self._data_block_dump_worker = None
        if dumper is not None:
            dumper.stop_routine()

    def _check_status(self, partition_id):
        if self._data_block_dump_manager is None:
            raise RuntimeError("no partition is processing")
        ptn_id = self._data_block_dump_manager.get_partition_id()
        if partition_id != ptn_id:
            raise RuntimeError("partition id mismatch {} != {}".format(
                partition_id, ptn_id))

    def _dump_data_block_fn(self):
        dump_manager = self._data_block_dump_manager
        assert dump_manager is not None
        if dump_manager.need_dump():
            dump_manager.dump_data_blocks()

    def _dump_data_block_cond(self):
        with self._lock:
            return (self._data_block_dump_manager is not None
                    and self._data_block_dump_manager.need_dump())
Example #17
class TransmitFollower(object):
    class ImplContext(object):
        def __init__(self, partition_id):
            self.partition_id = partition_id

        def get_next_index(self):
            raise NotImplementedError("get_next_index is not Implemented "\
                                      "in base ImplContext")

        def get_dumped_index(self):
            raise NotImplementedError("get_dumped_index is not Implemented "\
                                      "in base ImplContext")

        def add_synced_content(self, sync_ctnt):
            raise NotImplementedError("add_synced_content is not Implemented "\
                                      "in base ImplContext")

        def finish_sync_content(self):
            raise NotImplementedError("finish_sync_content is not Implemented "\
                                      "in base ImplContext")

        def need_dump(self):
            raise NotImplementedError("need_dump is not Implemented "\
                                      "in base ImplContext")

        def make_dumper(self):
            raise NotImplementedError("make_dumper is not Implemented "\
                                      "in base ImplContext")

        def is_sync_content_finished(self):
            raise NotImplementedError("is_sync_content_finished is not "\
                                      "Implemented in base ImplContext")

    def __init__(self, etcd, data_source, repr_str):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._data_source = data_source
        self._repr_str = repr_str
        self._dump_worker = None
        self._impl_ctx = None
        self._started = False

    @metrics.timer(func_name='start_sync_partition',
                   tags={'role': 'transmit_follower'})
    def start_sync_partition(self, partition_id):
        with self._lock:
            if self._impl_ctx is not None and \
                    self._impl_ctx.partition_id != partition_id:
                raise RuntimeError("{} is processing partition {}".format(
                    self._repr_str, self._impl_ctx.partition_id))
            if self._impl_ctx is None:
                self._impl_ctx = self._make_new_impl_ctx(partition_id)
            return self._impl_ctx.get_next_index(), \
                    self._impl_ctx.get_dumped_index()

    @metrics.timer(func_name='add_synced_item',
                   tags={'role': 'transmit_follower'})
    def add_synced_item(self, sync_ctnt):
        with self._lock:
            partition_id = \
                self._extract_partition_id_from_sync_content(sync_ctnt)
            self._check_status(partition_id)
            filled, next_index = self._impl_ctx.add_synced_content(sync_ctnt)
            if filled:
                self._dump_worker.wakeup()
            return filled, next_index, self._impl_ctx.get_dumped_index()

    @metrics.timer(func_name='finish_sync_partition',
                   tags={'role': 'transmit_follower'})
    def finish_sync_partition(self, partition_id):
        with self._lock:
            self._check_status(partition_id)
            self._impl_ctx.finish_sync_content()
            return not self._impl_ctx.need_dump(), \
                    self._impl_ctx.get_dumped_index()

    @metrics.timer(func_name='reset_partition',
                   tags={'role': 'transmit_follower'})
    def reset_partition(self, partition_id):
        with self._lock:
            if not self._check_status(partition_id, False):
                return
            if not self._impl_ctx.is_sync_content_finished() or \
                    self._impl_ctx.need_dump():
                raise RuntimeError("{} is still dumping for partition {}"\
                                   .format(self._repr_str, partition_id))
            self._impl_ctx = None

    def get_processing_partition_id(self):
        with self._lock:
            if self._impl_ctx is None:
                return None
            return self._impl_ctx.partition_id

    def start_dump_worker(self):
        with self._lock:
            if not self._started:
                assert self._dump_worker is None, \
                    "dumper woker for {} should be None if "\
                    "not started".format(self._repr_str)
                self._dump_worker = RoutineWorker(
                    self._repr_str + '-dump_worker', self._dump_fn,
                    self._dump_cond, 1)
                self._dump_worker.start_routine()
                self._started = True

    def stop_dump_worker(self):
        dumper_worker = None
        with self._lock:
            dumper_worker = self._dump_worker
            self._dump_worker = None
        if dumper_worker is not None:
            dumper_worker.stop_routine()

    def _check_status(self, partition_id, raise_exception=True):
        if self._impl_ctx is None:
            if not raise_exception:
                return False
            raise RuntimeError("no partition is processing")
        if self._impl_ctx.partition_id != partition_id:
            if not raise_exception:
                return False
            raise RuntimeError("partition id mismatch {} != {} for {}".format(
                self._impl_ctx.partition_id, partition_id, self._repr_str))
        return True

    def _dump_fn(self, impl_ctx):
        with impl_ctx.make_dumper() as dumper:
            dumper()

    def _dump_cond(self):
        with self._lock:
            if self._impl_ctx is not None and self._impl_ctx.need_dump():
                self._dump_worker.setup_args(self._impl_ctx)
                return True
            return False

    def _make_new_impl_ctx(self, partition_id):
        raise NotImplementedError("_make_new_impl_ctx is not Implemented "\
                                  "in base TransmitFollower")

    def _extract_partition_id_from_sync_content(self, sync_content):
        raise NotImplementedError("_extract_partition_id_from_sync_content "\
                                  "is not Implemented in base TransmitFollower")
Example #18
class RawDataPartitioner(object):
    class OutputFileWriter(object):
        def __init__(self, options, partition_id):
            self._options = options
            self._partition_id = partition_id
            self._process_index = 0
            self._writer = None
            self._dumped_item = 0
            self._output_fpaths = []
            self._output_dir = os.path.join(
                    self._options.output_dir,
                    common.partition_repr(self._partition_id)
                )
            if not gfile.Exists(self._output_dir):
                gfile.MakeDirs(self._output_dir)
            assert gfile.IsDirectory(self._output_dir)

        def append_item(self, index, item):
            writer = self._get_output_writer()
            if self._options.output_builder == 'TF_RECORD':
                writer.write(item.tf_record)
            else:
                assert self._options.output_builder == 'CSV_DICT'
                writer.write(item.csv_record)
            self._dumped_item += 1
            if self._dumped_item >= self._options.output_item_threshold:
                self._finish_writer()
                if self._process_index % 16 == 0:
                    logging.info("Output partition %d dump %d files, "\
                                 "last index %d", self._partition_id,
                                 self._process_index, index)

        def finish(self):
            self._finish_writer()

        def get_output_files(self):
            return self._output_fpaths

        def _get_output_writer(self):
            if self._writer is None:
                self._new_writer()
            return self._writer

        def _new_writer(self):
            assert self._writer is None
            fname = "{:04}-{:08}.rd".format(
                    self._options.partitioner_rank_id,
                    self._process_index
                )
            fpath = os.path.join(self._output_dir, fname)
            self._output_fpaths.append(fpath)
            if self._options.output_builder == 'TF_RECORD':
                self._writer = tf.io.TFRecordWriter(fpath)
            else:
                assert self._options.output_builder == 'CSV_DICT'
                self._writer = CsvDictWriter(fpath)
            self._dumped_item = 0

        def _finish_writer(self):
            if self._writer is not None:
                self._writer.close()
                self._writer = None
            self._dumped_item = 0
            self._process_index += 1

    def __init__(self, options):
        self._options = options
        self._raw_data_batch_fetcher = RawDataBatchFetcher(options)
        self._fetch_worker = RoutineWorker('raw_data_batch_fetcher',
                                           self._raw_data_batch_fetch_fn,
                                           self._raw_data_batch_fetch_cond, 5)
        self._next_part_index = 0
        self._cond = threading.Condition()
        self._fetch_worker.start_routine()

    def partition(self):
        if self._check_finished_tag():
            logging.warning("partition has finished for rank id of parti"\
                            "tioner %d", self._options.partitioner_rank_id)
            return
        next_index = 0
        hint_index = 0
        fetch_finished = False
        fetcher = self._raw_data_batch_fetcher
        writers = [RawDataPartitioner.OutputFileWriter(self._options, pid)
                   for pid in range(self._options.output_partition_num)]
        iter_round = 0
        bp_options = self._options.batch_processor_options
        signal_round_threhold = bp_options.max_flying_item / \
                bp_options.batch_size // 8
        while not fetch_finished:
            fetch_finished, batch, hint_index = \
                    fetcher.fetch_item_batch_by_index(next_index, hint_index)
            iter_round += 1
            if batch is not None:
                for index, item in enumerate(batch):
                    raw_id = item.raw_id
                    partition_id = CityHash32(raw_id) % \
                            self._options.output_partition_num
                    writer = writers[partition_id]
                    writer.append_item(batch.begin_index+index, item)
                next_index = batch.begin_index + len(batch)
                if iter_round % signal_round_threhold == 0:
                    hint_index = self._evict_staless_batch(hint_index,
                                                           next_index-1)
                    logging.info("consumed %d items", next_index-1)
                self._set_next_part_index(next_index)
                self._wakeup_raw_data_fetcher()
            elif not fetch_finished:
                hint_index = self._evict_staless_batch(hint_index,
                                                       next_index-1)
                with self._cond:
                    self._cond.wait(1)
        for partition_id, writer in enumerate(writers):
            writer.finish()
            fpaths = writer.get_output_files()
            logging.info("part %d output %d files by partitioner",
                          partition_id, len(fpaths))
            for fpath in fpaths:
                logging.info("%s", fpath)
            logging.info("-----------------------------------")
        self._dump_finished_tag()
        self._fetch_worker.stop_routine()

    def _evict_staless_batch(self, hint_index, staless_index):
        evict_cnt = self._raw_data_batch_fetcher.evict_staless_item_batch(
                staless_index
            )
        if hint_index <= evict_cnt:
            return 0
        return hint_index-evict_cnt

    def _set_next_part_index(self, next_part_index):
        with self._cond:
            self._next_part_index = next_part_index

    def _get_next_part_index(self):
        with self._cond:
            return self._next_part_index

    def _raw_data_batch_fetch_fn(self):
        next_part_index = self._get_next_part_index()
        fetcher = self._raw_data_batch_fetcher
        for batch in fetcher.make_processor(next_part_index):
            logging.debug("fetch batch begin at %d, len %d. wakeup "\
                          "partitioner", batch.begin_index, len(batch))
            self._wakeup_partitioner()

    def _raw_data_batch_fetch_cond(self):
        next_part_index = self._get_next_part_index()
        return self._raw_data_batch_fetcher.need_process(next_part_index)

    def _wakeup_partitioner(self):
        with self._cond:
            self._cond.notify_all()

    def _wakeup_raw_data_fetcher(self):
        self._fetch_worker.wakeup()

    def _dump_finished_tag(self):
        finished_tag_fpath = self._get_finished_tag_fpath()
        with gfile.GFile(finished_tag_fpath, 'w') as fh:
            fh.write('')

    def _check_finished_tag(self):
        return gfile.Exists(self._get_finished_tag_fpath())

    def _get_finished_tag_fpath(self):
        return os.path.join(
                self._options.output_dir,
                '_SUCCESS.{:08}'.format(self._options.partitioner_rank_id)
            )
Example #19
class MasterFSM(object):
    INVALID_PEER_FSM_STATE = {}
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Init] = set(
            [common_pb.DataSourceState.Failed,
             common_pb.DataSourceState.Ready,
             common_pb.DataSourceState.Finished]
        )
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Processing] = set(
            [common_pb.DataSourceState.Failed,
             common_pb.DataSourceState.Finished]
        )
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Ready] = set(
            [common_pb.DataSourceState.Failed,
             common_pb.DataSourceState.Init]
        )
    INVALID_PEER_FSM_STATE[common_pb.DataSourceState.Finished] = set(
            [common_pb.DataSourceState.Failed,
             common_pb.DataSourceState.Init,
             common_pb.DataSourceState.Processing]
        )

    def __init__(self, peer_client, data_source_name, etcd):
        self._lock = threading.Lock()
        self._peer_client = peer_client
        self._data_source_name = data_source_name
        self._etcd = etcd
        self._init_fsm_action()
        self._data_source = None
        self._sync_data_source()
        assert self._data_source is not None, \
            "data source must not None if sync data source success"
        self._raw_data_manifest_manager = RawDataManifestManager(
                etcd, self._data_source
            )
        self._data_source_meta = self._data_source.data_source_meta
        if self._data_source.role == common_pb.FLRole.Leader:
            self._role_repr = "leader"
        else:
            self._role_repr = "follower"
        self._fsm_worker = None
        self._started = False

    def get_mainifest_manager(self):
        return self._raw_data_manifest_manager

    def get_data_source(self):
        with self._lock:
            self._sync_data_source()
            return self._data_source

    def set_failed(self):
        return self.set_state(common_pb.DataSourceState.Failed, None)

    def set_state(self, new_state, origin_state=None):
        with self._lock:
            try:
                self._sync_data_source()
                if self._data_source.state == new_state:
                    return True
                if (origin_state is None or
                        self._data_source.state == origin_state):
                    self._data_source.state = new_state
                    self._update_data_source(self._data_source)
                    return True
                logging.warning("DataSource: %s failed to set to state: "
                                "%d, origin state mismatch(%d != %d)",
                                self._data_source_name, new_state,
                                origin_state, self._data_source.state)
                return False
            except Exception as e: # pylint: disable=broad-except
                logging.warning("Faile to set state to %d with exception %s",
                                new_state, e)
                return False
            return True

    def start_fsm_worker(self):
        with self._lock:
            if not self._started:
                assert self._fsm_worker is None, \
                    "fsm_woker must be None if FSM is not started"
                self._started = True
                self._fsm_worker = RoutineWorker(
                        '{}_fsm_worker'.format(self._data_source_name),
                        self._fsm_routine_fn,
                        self._fsm_routine_cond, 5
                    )
                self._fsm_worker.start_routine()

    def stop_fsm_worker(self):
        tmp_worker = None
        with self._lock:
            if self._fsm_worker is not None:
                tmp_worker = self._fsm_worker
                self._fsm_worker = None
        if tmp_worker is not None:
            tmp_worker.stop_routine()

    def _fsm_routine_fn(self):
        peer_info = self._get_peer_data_source_status()
        with self._lock:
            self._sync_data_source()
            state = self._data_source.state
            if self._fallback_failed_state(peer_info):
                logging.warning("%s at state %d, Peer at state %d "\
                                "state invalid! abort data source %s",
                                self._role_repr, state,
                                peer_info.state, self._data_source_name)
            elif state not in self._fsm_driven_handle:
                logging.error("%s at error state %d for data_source %s",
                               self._role_repr, state, self._data_source_name)
            else:
                state_changed = self._fsm_driven_handle[state](peer_info)
                if state_changed:
                    self._sync_data_source()
                    new_state = self._data_source.state
                    logging.warning("%s state changed from %d to %d",
                                    self._role_repr, state, new_state)
                    state = new_state
            if state in (common_pb.DataSourceState.Init,
                            common_pb.DataSourceState.Processing):
                self._raw_data_manifest_manager.sub_new_raw_data()

    def _fsm_routine_cond(self):
        return True

    def _sync_data_source(self):
        if self._data_source is None:
            self._data_source = \
                retrieve_data_source(self._etcd, self._data_source_name)

    def _init_fsm_action(self):
        self._fsm_driven_handle = {
             common_pb.DataSourceState.UnKnown:
                 self._get_fsm_action('unknown'),
             common_pb.DataSourceState.Init:
                 self._get_fsm_action('init'),
             common_pb.DataSourceState.Processing:
                 self._get_fsm_action('processing'),
             common_pb.DataSourceState.Ready:
                 self._get_fsm_action('ready'),
             common_pb.DataSourceState.Finished:
                 self._get_fsm_action('finished'),
             common_pb.DataSourceState.Failed:
                 self._get_fsm_action('failed')
        }

    def _get_fsm_action(self, action):
        def _not_implement(useless):
            raise NotImplementedError('fsm state action is not implemented')
        name = '_fsm_{}_action'.format(action)
        return getattr(self, name, _not_implement)

    def _fsm_init_action(self, peer_info):
        state_changed = False
        if self._data_source.role == common_pb.FLRole.Leader:
            if peer_info.state == common_pb.DataSourceState.Init:
                state_changed = True
        elif peer_info.state == common_pb.DataSourceState.Processing:
            state_changed = True
        if state_changed:
            self._data_source.state = common_pb.DataSourceState.Processing
            self._update_data_source(self._data_source)
            return True
        return False

    def _fsm_processing_action(self, peer_info):
        if self._all_partition_finished():
            state_changed = False
            if self._data_source.role == common_pb.FLRole.Leader:
                if peer_info.state == common_pb.DataSourceState.Processing:
                    state_changed = True
            elif peer_info.state == common_pb.DataSourceState.Ready:
                state_changed = True
            if state_changed:
                self._data_source.state = common_pb.DataSourceState.Ready
                self._update_data_source(self._data_source)
                return True
        return False

    def _fsm_ready_action(self, peer_info):
        state_changed = False
        if self._data_source.role == common_pb.FLRole.Leader:
            if peer_info.state == common_pb.DataSourceState.Ready:
                state_changed = True
        elif peer_info.state == common_pb.DataSourceState.Finished:
            state_changed = True
        if state_changed:
            self._data_source.state = common_pb.DataSourceState.Finished
            self._update_data_source(self._data_source)
            return True
        return False

    def _fsm_finished_action(self, peer_info):
        return False

    def _fsm_failed_action(self, peer_info):
        if peer_info.state != common_pb.DataSourceState.Failed:
            request = dj_pb.DataSourceRequest(
                    data_source_meta=self._data_source_meta
                )
            self._peer_client.AbortDataSource(request)
        return False

    def _fallback_failed_state(self, peer_info):
        state = self._data_source.state
        if (state in self.INVALID_PEER_FSM_STATE and
                peer_info.state in self.INVALID_PEER_FSM_STATE[state]):
            self._data_source.state = common_pb.DataSourceState.Failed
            self._update_data_source(self._data_source)
            return True
        return False

    def _update_data_source(self, data_source):
        self._data_source = None
        try:
            commit_data_source(self._etcd, data_source)
        except Exception as e:
            logging.error("Failed to update data source: %s since "\
                          "exception: %s", self._data_source_name, e)
            raise
        self._data_source = data_source
        logging.debug("Success update to update data source: %s.",
                       self._data_source_name)

    def _get_peer_data_source_status(self):
        request = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_meta
            )
        return self._peer_client.GetDataSourceStatus(request)

    def _all_partition_finished(self):
        all_manifest = self._raw_data_manifest_manager.list_all_manifest()
        assert len(all_manifest) == \
                self._data_source.data_source_meta.partition_num, \
            "manifest number should same with partition number"
        for manifest in all_manifest.values():
            if manifest.sync_example_id_rep.state != \
                    dj_pb.SyncExampleIdState.Synced or \
                    manifest.join_example_rep.state != \
                    dj_pb.JoinExampleState.Joined:
                return False
        return True
Example #20
class ExampleIdSyncFollower(object):
    def __init__(self, data_source):
        self._lock = threading.Lock()
        self._data_source = data_source
        self._example_id_dump_manager = None
        self._example_id_dump_worker = None
        self._started = False

    def start_dump_partition(self, partition_id):
        with self._lock:
            if (self._example_id_dump_manager is not None
                    and self._example_id_dump_manager.get_partition_id() !=
                    partition_id):
                ptn_id = self._example_id_dump_manager.get_partition_id()
                raise RuntimeError(
                    "partition {} is not finished".format(ptn_id))
            if self._example_id_dump_manager is None:
                self._example_id_dump_manager = ExampleIdDumperManager(
                    self._data_source, partition_id)
            dump_manager = self._example_id_dump_manager
            next_index = dump_manager.get_next_index()
            return next_index

    def add_synced_example_req(self, synced_example_req):
        with self._lock:
            self._check_status(synced_example_req.partition_id)
            return self._example_id_dump_manager.append_synced_example_req(
                synced_example_req)

    def finish_sync_partition_example(self, partition_id):
        with self._lock:
            self._check_status(partition_id)
            self._example_id_dump_manager.finish_sync_example()
            return not self._example_id_dump_manager.need_dump()

    def reset_dump_partition(self):
        with self._lock:
            if self._example_id_dump_manager is None:
                return
            partition_id = self._example_id_dump_manager.get_partition_id()
            self._check_status(partition_id)
            if (not self._example_id_dump_manager.example_sync_finished()
                    or self._example_id_dump_manager.need_dump()):
                raise RuntimeError(
                    "partition {} is dumpping".format(partition_id))
            self._example_id_dump_manager = None

    def get_processing_partition_id(self):
        with self._lock:
            if self._example_id_dump_manager is None:
                return None
            return self._example_id_dump_manager.get_partition_id()

    def start_dump_worker(self):
        with self._lock:
            if not self._started:
                assert self._example_id_dump_worker is None
                self._example_id_dump_worker = RoutineWorker(
                    'example_id_dumper', self._dump_example_ids_fn,
                    self._dump_example_ids_cond, 1)
                self._example_id_dump_worker.start_routine()
                self._started = True

    def stop_dump_worker(self):
        dumper = None
        with self._lock:
            if self._example_id_dump_worker is not None:
                dumper = self._example_id_dump_worker
                self._example_id_dump_worker = None
        if dumper is not None:
            dumper.stop_routine()

    def _check_status(self, partition_id):
        if self._example_id_dump_manager is None:
            raise RuntimeError("no partition is processing")
        ptn_id = self._example_id_dump_manager.get_partition_id()
        if partition_id != ptn_id:
            raise RuntimeError("partition id mismatch {} != {}".format(
                partition_id, ptn_id))

    def _dump_example_ids_fn(self):
        dump_manager = None
        with self._lock:
            dump_manager = self._example_id_dump_manager
        assert dump_manager is not None
        if dump_manager.need_dump():
            dump_manager.dump_example_ids()

    def _dump_example_ids_cond(self):
        with self._lock:
            return (self._example_id_dump_manager is not None
                    and self._example_id_dump_manager.need_dump())
Example #21
class PortalRawDataNotifier(object):
    class NotifyCtx(object):
        def __init__(self, master_addr):
            self._master_addr = master_addr
            channel = make_insecure_channel(master_addr, ChannelType.INTERNAL)
            self._master_cli = dj_grpc.DataJoinMasterServiceStub(channel)
            self._data_source = None
            self._raw_date_ctl = None
            self._raw_data_updated_datetime = {}

        @property
        def data_source(self):
            if self._data_source is None:
                self._data_source = \
                        self._master_cli.GetDataSource(empty_pb2.Empty())
            return self._data_source

        def get_raw_data_updated_datetime(self, partition_id):
            if partition_id not in self._raw_data_updated_datetime:
                ts = self._raw_date_controller.get_raw_data_latest_timestamp(
                    partition_id)
                if ts.seconds > 3600:
                    ts.seconds -= 3600
                else:
                    ts.seconds = 0
                self._raw_data_updated_datetime[partition_id] = \
                        common.convert_timestamp_to_datetime(
                                common.trim_timestamp_by_hourly(ts)
                            )
            return self._raw_data_updated_datetime[partition_id]

        def add_raw_data(self, partition_id, fpaths, timestamps, end_ts):
            assert len(fpaths) == len(timestamps), \
                "the number of raw data path and timestamp should same"
            if len(fpaths) > 0:
                self._raw_date_controller.add_raw_data(partition_id, fpaths,
                                                       True, timestamps)
                self._raw_data_updated_datetime[partition_id] = end_ts

        @property
        def data_source_master_addr(self):
            return self._master_addr

        @property
        def _raw_date_controller(self):
            if self._raw_date_ctl is None:
                self._raw_date_ctl = RawDataController(self.data_source,
                                                       self._master_cli)
            return self._raw_date_ctl

    def __init__(self, etcd, portal_name, downstream_data_source_masters):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._portal_name = portal_name
        assert len(downstream_data_source_masters) > 0, \
            "PortalRawDataNotifier launched when has master to notify"
        self._master_notify_ctx = {}
        for addr in downstream_data_source_masters:
            self._master_notify_ctx[addr] = \
                    PortalRawDataNotifier.NotifyCtx(addr)
        self._notify_worker = None
        self._started = False

    def start_notify_worker(self):
        with self._lock:
            if not self._started:
                assert self._notify_worker is None, \
                    "notify worker should be None if not started"
                self._notify_worker = RoutineWorker('potral-raw_data-notifier',
                                                    self._raw_data_notify_fn,
                                                    self._raw_data_notify_cond,
                                                    5)
                self._notify_worker.start_routine()
                self._started = True
            self._notify_worker.wakeup()

    def stop_notify_worker(self):
        notify_worker = None
        with self._lock:
            notify_worker = self._notify_worker
            self._notify_worker = None
        if notify_worker is not None:
            notify_worker.stop_routine()

    def _check_partition_num(self, notify_ctx, portal_manifest):
        assert isinstance(notify_ctx, PortalRawDataNotifier.NotifyCtx)
        data_source = notify_ctx.data_source
        ds_partition_num = data_source.data_source_meta.partition_num
        if portal_manifest.output_partition_num % ds_partition_num != 0:
            raise ValueError(
                    "the partition number({}) of down stream data source "\
                    "{} should be divised by output partition of "\
                    "portatl({})".format(ds_partition_num,
                    data_source.data_source_meta.name,
                    portal_manifest.output_partition_num)
                )

    def _add_raw_data_impl(self, notify_ctx, portal_manifest, ds_pid):
        dt = notify_ctx.get_raw_data_updated_datetime(ds_pid) + \
                timedelta(hours=1)
        begin_dt = common.convert_timestamp_to_datetime(
            common.trim_timestamp_by_hourly(portal_manifest.begin_timestamp))
        if dt < begin_dt:
            dt = begin_dt
        committed_dt = common.convert_timestamp_to_datetime(
            portal_manifest.committed_timestamp)
        fpaths = []
        timestamps = []
        ds_ptnum = notify_ctx.data_source.data_source_meta.partition_num
        while dt <= committed_dt:
            for pt_pid in range(ds_pid, portal_manifest.output_partition_num,
                                ds_ptnum):
                fpath = common.encode_portal_hourly_fpath(
                    portal_manifest.output_data_base_dir, dt, pt_pid)
                if gfile.Exists(fpath):
                    fpaths.append(fpath)
                    timestamps.append(common.convert_datetime_to_timestamp(dt))
            if len(fpaths) > 32 or dt == committed_dt:
                break
            dt += timedelta(hours=1)
        notify_ctx.add_raw_data(ds_pid, fpaths, timestamps, dt)
        logging.info("add %d raw data file for partition %d of data "\
                     "source %s. latest updated datetime %s",
                      len(fpaths), ds_pid,
                      notify_ctx.data_source.data_source_meta.name, dt)
        return dt >= committed_dt

    def _notify_one_data_source(self, notify_ctx, portal_manifest):
        assert isinstance(notify_ctx, PortalRawDataNotifier.NotifyCtx)
        try:
            self._check_partition_num(notify_ctx, portal_manifest)
            ds_ptnum = notify_ctx.data_source.data_source_meta.partition_num
            pt_ptnum = portal_manifest.output_partition_num
            add_finished = False
            while not add_finished:
                add_finished = True
                for ds_pid in range(ds_ptnum):
                    if not self._add_raw_data_impl(notify_ctx, portal_manifest,
                                                   ds_pid):
                        add_finished = False
        except Exception as e:  # pylint: disable=broad-except
            logging.error("Failed to notify data source[master-addr: %s] "\
                          "new raw data added, reason %s",
                          notify_ctx.data_source_master_addr, e)

    def _raw_data_notify_fn(self):
        portal_manifest = common.retrieve_portal_manifest(
            self._etcd, self._portal_name)
        for _, notify_ctx in self._master_notify_ctx.items():
            self._notify_one_data_source(notify_ctx, portal_manifest)

    def _raw_data_notify_cond(self):
        return True