def __init__(self, options, kvstore_type, use_mock_etcd=False):
    """Wire up the RSA-PSI preprocessing pipeline for one partition.

    Args:
        options: preprocessor options proto; carries role, partition id,
            publish dir, batch/sign limits and writer options.
        kvstore_type: backend type handed to DBClient.
        use_mock_etcd: use an in-memory kvstore (tests).
    """
    self._lock = threading.Condition()
    self._options = options
    kvstore = DBClient(kvstore_type, use_mock_etcd)
    pub_dir = self._options.raw_data_publish_dir
    self._publisher = RawDataPublisher(kvstore, pub_dir)
    self._process_pool_executor = \
        concur_futures.ProcessPoolExecutor(
                options.offload_processor_number
            )
    self._callback_submitter = None
    # pre-fork the sub-processes before any grpc client is launched
    # (forking after grpc initialization is unsafe); submit a trivial
    # task and wait so workers exist now.
    self._process_pool_executor.submit(min, 1, 2).result()
    self._id_batch_fetcher = IdBatchFetcher(kvstore, self._options)
    if self._options.role == common_pb.FLRole.Leader:
        # Leader signs ids with its private RSA key.
        private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
        self._psi_rsa_signer = LeaderPsiRsaSigner(
                self._id_batch_fetcher,
                options.batch_processor_options.max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.slow_sign_threshold,
                self._process_pool_executor,
                private_key,
            )
        self._repr = 'leader-' + 'rsa_psi_preprocessor'
    else:
        # Follower blinds ids with the leader's public key and asks the
        # leader (via rpc) to sign; a single-thread executor serializes
        # rpc callbacks.
        public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
        self._callback_submitter = concur_futures.ThreadPoolExecutor(1)
        self._psi_rsa_signer = FollowerPsiRsaSigner(
                self._id_batch_fetcher,
                options.batch_processor_options.max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.max_flying_sign_rpc,
                self._options.sign_rpc_timeout_ms,
                self._options.slow_sign_threshold,
                self._options.stub_fanout,
                self._process_pool_executor,
                self._callback_submitter, public_key,
                self._options.leader_rsa_psi_signer_addr)
        self._repr = 'follower-' + 'rsa_psi_preprocessor'
    self._sort_run_dumper = SortRunDumper(options)
    # Merger combines the dumped sort runs into the final ordered output.
    self._sort_run_merger = SortRunMerger(
            dj_pb.SortRunMergerOptions(
                merger_name='sort_run_merger_'+\
                            partition_repr(options.partition_id),
                reader_options=dj_pb.RawDataOptions(
                    raw_data_iter=options.writer_options.output_writer,
                    compressed_type=options.writer_options.compressed_type,
                    read_ahead_size=\
                        options.sort_run_merger_read_ahead_buffer,
                    read_batch_size=\
                        options.sort_run_merger_read_batch_size
                ),
                writer_options=options.writer_options,
                output_file_dir=options.output_file_dir,
                partition_id=options.partition_id
            ),
            self._merger_comparator
        )
    # produced/consumed counters drive backpressure in _stop_fetch_id
    # (note: 'comsume' spelling is kept for consistency with other methods).
    self._produce_item_cnt = 0
    self._comsume_item_cnt = 0
    self._started = False
def __init__(self, kvstore, portal_name, long_running, check_success_tag):
    """Restore job-manager state from the kvstore and resume/launch a job.

    Args:
        kvstore: key-value store client holding manifest and job records.
        portal_name: name of the data portal this manager drives.
        long_running: keep waiting for new input when all jobs finish.
        check_success_tag: whether input folders must carry a success tag.
    """
    self._lock = threading.Lock()
    self._kvstore = kvstore
    self._portal_name = portal_name
    self._check_success_tag = check_success_tag
    self._long_running = long_running
    self._portal_manifest = None
    self._processing_job = None
    # Pull persisted state before anything reads the manifest.
    self._sync_portal_manifest()
    self._sync_processing_job()
    assert self._portal_manifest is not None
    self._publisher = RawDataPublisher(
            kvstore, self._portal_manifest.raw_data_publish_dir)
    # Collect every file path already handled by a previous job so it is
    # never scheduled twice.
    self._processed_fpath = set()
    for prev_job_id in range(self._portal_manifest.next_job_id):
        prev_job = self._sync_portal_job(prev_job_id)
        assert prev_job is not None and prev_job.job_id == prev_job_id
        self._processed_fpath.update(prev_job.fpaths)
    self._job_part_map = {}
    # Finish off an interrupted job if one was in flight, then start a
    # fresh one when nothing is being processed.
    if self._portal_manifest.processing_job_id >= 0:
        self._check_processing_job_finished()
    if self._portal_manifest.processing_job_id < 0:
        self._launch_new_portal_job()
def __init__(self, options, etcd_name, etcd_addrs, etcd_base_dir,
             use_mock_etcd=False):
    """Build the RSA-PSI preprocessing pipeline (etcd-backed variant).

    Args:
        options: preprocessor options proto (role, limits, writer options).
        etcd_name/etcd_addrs/etcd_base_dir: etcd connection parameters.
        use_mock_etcd: use an in-memory etcd stand-in (tests).
    """
    self._lock = threading.Condition()
    self._options = options
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
    pub_dir = self._options.raw_data_publish_dir
    self._publisher = RawDataPublisher(etcd, pub_dir)
    self._process_pool_executor = \
        concur_futures.ProcessPoolExecutor(
                options.offload_processor_number
            )
    self._id_batch_fetcher = IdBatchFetcher(etcd, self._options)
    max_flying_item = options.batch_processor_options.max_flying_item
    if self._options.role == common_pb.FLRole.Leader:
        # Leader signs ids with its own private RSA key.
        private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
        self._psi_rsa_signer = LeaderPsiRsaSigner(
                self._id_batch_fetcher, max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.slow_sign_threshold,
                self._process_pool_executor,
                private_key,
            )
        self._repr = 'leader-' + 'rsa_psi_preprocessor'
    else:
        # Follower blinds ids with the leader's public key and signs via rpc.
        public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
        self._psi_rsa_signer = FollowerPsiRsaSigner(
                self._id_batch_fetcher, max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.max_flying_sign_rpc,
                self._options.sign_rpc_timeout_ms,
                self._options.slow_sign_threshold,
                self._options.stub_fanout,
                self._process_pool_executor, public_key,
                self._options.leader_rsa_psi_signer_addr)
        self._repr = 'follower-' + 'rsa_psi_preprocessor'
    self._sort_run_dumper = SortRunDumper(options)
    # Merge the dumped sort runs ordered by the 'example_id' field.
    self._sort_run_merger = SortRunMerger(
            dj_pb.SortRunMergerOptions(
                merger_name='sort_run_merger_'+\
                            partition_repr(options.partition_id),
                reader_options=dj_pb.RawDataOptions(
                    raw_data_iter=options.writer_options.output_writer,
                    compressed_type=options.writer_options.compressed_type,
                    read_ahead_size=\
                        options.sort_run_merger_read_ahead_buffer
                ),
                writer_options=options.writer_options,
                output_file_dir=options.output_file_dir,
                partition_id=options.partition_id
            ),
            'example_id'
        )
    self._started = False
def __init__(self, etcd, portal_manifest, portal_options):
    """Hold portal state and create the raw-data publisher.

    Args:
        etcd: etcd client used by the publisher.
        portal_manifest: manifest proto with the raw_data_publish_dir.
        portal_options: runtime options for this portal instance.
    """
    self._lock = threading.Lock()
    self._etcd = etcd
    self._portal_manifest = portal_manifest
    self._portal_options = portal_options
    self._started = False
    self._worker_map = {}
    # Timestamps gathered while the portal runs.
    self._input_ready_datetime = []
    self._output_finished_datetime = []
    publish_dir = self._portal_manifest.raw_data_publish_dir
    self._publisher = RawDataPublisher(self._etcd, publish_dir)
def __init__(self, options, etcd_name, etcd_addrs, etcd_base_dir,
             use_mock_etcd=False):
    """Build the RSA-PSI preprocessing pipeline (sync-rpc capable variant).

    Args:
        options: preprocessor options proto; `rpc_sync_mode` switches the
            follower signer to a dedicated thread pool for rpc calls.
        etcd_name/etcd_addrs/etcd_base_dir: etcd connection parameters.
        use_mock_etcd: use an in-memory etcd stand-in (tests).
    """
    self._lock = threading.Condition()
    self._options = options
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
    pub_dir = self._options.raw_data_publish_dir
    self._publisher = RawDataPublisher(etcd, pub_dir)
    self._process_pool_executor = \
        concur_futures.ProcessPoolExecutor(
                options.offload_processor_number
            )
    # Only allocated in sync rpc mode; passed through to the follower signer.
    self._thread_pool_rpc_executor = None
    if options.rpc_sync_mode:
        assert options.rpc_thread_pool_size > 0
        self._thread_pool_rpc_executor = concur_futures.ThreadPoolExecutor(
                options.rpc_thread_pool_size)
    self._id_batch_fetcher = IdBatchFetcher(etcd, self._options)
    max_flying_item = options.batch_processor_options.max_flying_item
    if self._options.role == common_pb.FLRole.Leader:
        # Leader signs ids with its own private RSA key.
        private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
        self._psi_rsa_signer = LeaderPsiRsaSigner(
                self._id_batch_fetcher, max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.slow_sign_threshold,
                self._process_pool_executor,
                private_key,
            )
        self._repr = 'leader-' + 'rsa_psi_preprocessor'
    else:
        # Follower blinds ids with the leader's public key and signs via rpc.
        public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
        self._psi_rsa_signer = FollowerPsiRsaSigner(
                self._id_batch_fetcher, max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.max_flying_sign_rpc,
                self._options.sign_rpc_timeout_ms,
                self._options.slow_sign_threshold,
                self._options.stub_fanout,
                self._process_pool_executor, public_key,
                self._options.leader_rsa_psi_signer_addr,
                self._thread_pool_rpc_executor)
        self._repr = 'follower-' + 'rsa_psi_preprocessor'
    self._sort_run_dumper = SortRunDumper(options)
    self._sort_run_merger = SortRunMerger(
            self._sort_run_dumper.sort_run_dump_dir(), self._options)
    self._started = False
def __init__(self, options, etcd_name, etcd_addrs, etcd_base_dir,
             use_mock_etcd=False):
    """Build the RSA-PSI preprocessing pipeline (key-file variant).

    Args:
        options: preprocessor options proto; `rsa_key_file_path` points at
            the PEM-encoded RSA key (private for leader, public for follower).
        etcd_name/etcd_addrs/etcd_base_dir: etcd connection parameters.
        use_mock_etcd: use an in-memory etcd stand-in (tests).
    """
    self._lock = threading.Condition()
    self._options = options
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
    pub_dir = self._options.raw_data_publish_dir
    self._publisher = RawDataPublisher(etcd, pub_dir)
    self._process_pool_executor = \
        concur_futures.ProcessPoolExecutor(
                options.offload_processor_number
            )
    self._id_batch_fetcher = IdBatchFetcher(self._options)
    max_flying_item = options.batch_processor_options.max_flying_item
    # Read the PEM key material once; the role below decides whether it is
    # parsed as a private (leader) or public (follower) key. This replaces
    # two duplicated read branches.
    with gfile.GFile(options.rsa_key_file_path, 'rb') as f:
        rsa_key_pem = f.read()
    if self._options.role == common_pb.FLRole.Leader:
        # Leader signs ids with its own private RSA key.
        private_key = rsa.PrivateKey.load_pkcs1(rsa_key_pem)
        self._psi_rsa_signer = LeaderPsiRsaSigner(
                self._id_batch_fetcher, max_flying_item,
                self._process_pool_executor,
                private_key,
            )
        self._repr = 'leader-' + 'rsa_psi_preprocessor'
    else:
        # Follower blinds ids with the leader's public key and signs via rpc.
        public_key = rsa.PublicKey.load_pkcs1(rsa_key_pem)
        self._psi_rsa_signer = FollowerPsiRsaSigner(
                self._id_batch_fetcher, max_flying_item,
                self._process_pool_executor, public_key,
                self._options.leader_rsa_psi_signer_addr)
        self._repr = 'follower-' + 'rsa_psi_preprocessor'
    self._sort_run_dumper = SortRunDumper(options)
    # BUG FIX: sort_run_dump_dir is a method; the original passed the bound
    # method object itself rather than the directory path it returns. The
    # sibling variant with the same two-argument SortRunMerger signature
    # calls it, so call it here too.
    self._sort_run_merger = SortRunMerger(
            self._sort_run_dumper.sort_run_dump_dir(), self._options)
    self._worker_map = {}
    self._started = False
def __init__(self, kvstore, portal_name, long_running, check_success_tag,
             single_subfolder, files_per_job_limit, max_files_per_job=8000,
             start_date=None, end_date=None):
    """Restore job-manager state and resume or launch a portal job.

    Args:
        kvstore: key-value store client holding manifest and job records.
        portal_name: name of the data portal this manager drives.
        long_running: keep waiting for new input when all jobs finish.
        check_success_tag: whether input folders must carry a success tag.
        single_subfolder: restrict each job to files from one subfolder.
        files_per_job_limit: soft cap on files scheduled per job.
        max_files_per_job: hard cap on files scheduled per job.
        start_date/end_date: optional inclusive date window for input
            selection; parsed by convert_to_datetime (None allowed).
    """
    self._lock = threading.Lock()
    self._kvstore = kvstore
    self._portal_name = portal_name
    self._check_success_tag = check_success_tag
    self._single_subfolder = single_subfolder
    self._files_per_job_limit = files_per_job_limit
    self._max_files_per_job = max_files_per_job
    self._start_date = convert_to_datetime(start_date)
    self._end_date = convert_to_datetime(end_date)
    self._portal_manifest = None
    self._processing_job = None
    # Pull persisted state before anything reads the manifest.
    self._sync_portal_manifest()
    self._sync_processing_job()
    self._publisher = \
        RawDataPublisher(kvstore,
                         self._portal_manifest.raw_data_publish_dir)
    self._long_running = long_running
    self._finished = False
    assert self._portal_manifest is not None
    # Collect every file already handled by a previous job so it is never
    # scheduled twice.
    self._processed_fpath = set()
    for job_id in range(0, self._portal_manifest.next_job_id):
        job = self._sync_portal_job(job_id)
        assert job is not None and job.job_id == job_id
        for fpath in job.fpaths:
            self._processed_fpath.add(fpath)
    self._job_part_map = {}
    # Finish off an interrupted job if one was in flight.
    if self._portal_manifest.processing_job_id >= 0:
        self._check_processing_job_finished()
    if self._portal_manifest.processing_job_id < 0:
        # NOTE(review): this version relies on _launch_new_portal_job
        # returning a truthy value when a job was launched — confirm in the
        # sibling implementation used with this __init__.
        if not self._launch_new_portal_job() and not self._long_running:
            self._finished = True
class DataPortalJobManager(object):
    """Coordinates data-portal jobs over etcd.

    A job is a batch of input files; each job is split into
    output_partition_num parts. Parts move through a small state machine
    (kInit -> kIdMap -> kIdMapped [-> kEventTimeReduce -> kEventTimeReduced
    for streaming portals]). All state is persisted in etcd as text-format
    protobufs, so the manager can be restarted at any point.
    """

    def __init__(self, etcd, portal_name, long_running):
        """Restore persisted state, then resume or launch a portal job."""
        self._lock = threading.Lock()
        self._etcd = etcd
        self._portal_name = portal_name
        self._portal_manifest = None
        self._processing_job = None
        # Pull persisted state before anything reads the manifest.
        self._sync_portal_manifest()
        self._sync_processing_job()
        self._publisher = \
            RawDataPublisher(etcd, self._portal_manifest.raw_data_publish_dir)
        self._long_running = long_running
        assert self._portal_manifest is not None
        # Every file scheduled by a previous job; never schedule twice.
        self._processed_fpath = set()
        for job_id in range(0, self._portal_manifest.next_job_id):
            job = self._sync_portal_job(job_id)
            assert job is not None and job.job_id == job_id
            for fpath in job.fpaths:
                self._processed_fpath.add(fpath)
        self._job_part_map = {}
        # Finish an interrupted job first, then launch a new one if idle.
        if self._portal_manifest.processing_job_id >= 0:
            self._check_processing_job_finished()
        if self._portal_manifest.processing_job_id < 0:
            self._launch_new_portal_job()

    def get_portal_manifest(self):
        """Return the (freshly synced) portal manifest proto."""
        with self._lock:
            return self._sync_portal_manifest()

    def alloc_task(self, rank_id):
        """Allocate the next task for worker `rank_id`.

        Returns (finished, task): task is a MapTask, a ReduceTask, or None;
        finished is True only when no task remains and the manager is not
        long-running.
        """
        with self._lock:
            self._sync_processing_job()
            if self._processing_job is not None:
                # Prefer map work: claim a partition still in kInit.
                partition_id = self._try_to_alloc_part(rank_id,
                                                       dp_pb.PartState.kInit,
                                                       dp_pb.PartState.kIdMap)
                if partition_id is not None:
                    return False, self._create_map_task(rank_id,
                                                        partition_id)
                # Reduce work only starts after every partition is mapped,
                # and only for streaming portals.
                if self._all_job_part_mapped() and \
                        (self._portal_manifest.data_portal_type ==
                            dp_pb.DataPortalType.Streaming):
                    partition_id = self._try_to_alloc_part(
                            rank_id, dp_pb.PartState.kIdMapped,
                            dp_pb.PartState.kEventTimeReduce)
                    if partition_id is not None:
                        return False, self._create_reduce_task(partition_id)
                return (not self._long_running and
                            self._all_job_part_finished()), None
            return not self._long_running, None

    def finish_task(self, rank_id, partition_id, part_state):
        """Mark a partition's in-flight state done if owned by `rank_id`."""
        with self._lock:
            processing_job = self._sync_processing_job()
            if processing_job is None:
                return
            job_id = self._processing_job.job_id
            job_part = self._sync_job_part(job_id, partition_id)
            # Only the worker that owns the partition, in the expected
            # state, may advance it.
            if job_part.rank_id == rank_id and \
                    job_part.part_state == part_state:
                if job_part.part_state == dp_pb.PartState.kIdMap:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kIdMap,
                                          dp_pb.PartState.kIdMapped)
                elif job_part.part_state == dp_pb.PartState.kEventTimeReduce:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kEventTimeReduce,
                                          dp_pb.PartState.kEventTimeReduced)
            self._check_processing_job_finished()

    def backgroup_task(self):
        """Periodic housekeeping: close finished jobs, relaunch when idle."""
        with self._lock:
            if self._sync_processing_job() is not None:
                self._check_processing_job_finished()
            if self._sync_processing_job() is None and self._long_running:
                self._launch_new_portal_job()

    def _all_job_part_mapped(self):
        # True when every partition has passed the id-map stage.
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.part_state <= dp_pb.PartState.kIdMap:
                return False
        return True

    def _all_job_part_finished(self):
        # True when every partition reached its terminal state.
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = self._processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if not self._is_job_part_finished(job_part):
                return False
        return True

    def _finish_job_part(self, job_id, partition_id, src_state, target_state):
        # Transition one partition src_state -> target_state and release
        # its owner (rank_id = -1).
        job_part = self._sync_job_part(job_id, partition_id)
        assert job_part is not None and job_part.part_state == src_state
        new_job_part = dp_pb.PortalJobPart()
        new_job_part.MergeFrom(job_part)
        new_job_part.part_state = target_state
        new_job_part.rank_id = -1
        self._update_job_part(new_job_part)

    def _create_map_task(self, rank_id, partition_id):
        # A map task gets the job files whose basename hashes to this
        # partition.
        assert self._processing_job is not None
        job = self._processing_job
        map_fpaths = []
        for fpath in job.fpaths:
            fname = path.basename(fpath)
            if hash(fname) % self._output_partition_num == partition_id:
                map_fpaths.append(fpath)
        return dp_pb.MapTask(fpaths=map_fpaths,
                             output_base_dir=self._map_output_dir(job.job_id),
                             output_partition_num=self._output_partition_num,
                             partition_id=partition_id,
                             part_field=self._get_part_field())

    def _get_part_field(self):
        # PSI portals partition on raw_id; streaming portals on example_id.
        portal_mainifest = self._sync_portal_manifest()
        if portal_mainifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return 'raw_id'
        assert portal_mainifest.data_portal_type == \
                dp_pb.DataPortalType.Streaming
        return 'example_id'

    def _create_reduce_task(self, partition_id):
        assert self._processing_job is not None
        job = self._processing_job
        job_id = job.job_id
        return dp_pb.ReduceTask(
                map_base_dir=self._map_output_dir(job_id),
                reduce_base_dir=self._reduce_output_dir(job_id),
                partition_id=partition_id)

    def _try_to_alloc_part(self, rank_id, src_state, target_state):
        """Allocate a partition to `rank_id`.

        Prefers a partition already owned by this rank in target_state
        (re-allocation after worker restart); otherwise claims the first
        partition still in src_state. Returns the partition id or None.
        """
        alloc_partition_id = None
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = self._processing_job.job_id
        for partition_id in range(self._output_partition_num):
            part_job = self._sync_job_part(job_id, partition_id)
            if part_job.part_state == src_state and \
                    alloc_partition_id is None:
                alloc_partition_id = partition_id
            if part_job.part_state == target_state and \
                    part_job.rank_id == rank_id:
                # This rank already owns an in-flight partition; hand the
                # same one back.
                alloc_partition_id = partition_id
                break
        if alloc_partition_id is None:
            return None
        part_job = self._job_part_map[alloc_partition_id]
        if part_job.part_state == src_state:
            # Fresh claim: persist ownership and the new state.
            new_job_part = dp_pb.PortalJobPart(
                    job_id=job_id, rank_id=rank_id,
                    partition_id=alloc_partition_id,
                    part_state=target_state)
            self._update_job_part(new_job_part)
        return alloc_partition_id

    def _sync_portal_job(self, job_id):
        # Load one job record from etcd; None when absent.
        etcd_key = common.portal_job_etcd_key(self._portal_name, job_id)
        data = self._etcd.get_data(etcd_key)
        if data is not None:
            return text_format.Parse(data, dp_pb.DataPortalJob())
        return None

    def _sync_processing_job(self):
        # Refresh the cached processing job from the manifest's
        # processing_job_id; None when no job is in flight.
        assert self._sync_portal_manifest() is not None
        if self._portal_manifest.processing_job_id < 0:
            self._processing_job = None
        elif self._processing_job is None or \
                (self._processing_job.job_id !=
                    self._portal_manifest.processing_job_id):
            job_id = self._portal_manifest.processing_job_id
            self._processing_job = self._sync_portal_job(job_id)
            assert self._processing_job is not None
        return self._processing_job

    def _update_processing_job(self, job):
        # Invalidate the cache first so a failed etcd write cannot leave a
        # stale in-memory copy.
        self._processing_job = None
        etcd_key = common.portal_job_etcd_key(self._portal_name, job.job_id)
        self._etcd.set_data(etcd_key, text_format.MessageToString(job))
        self._processing_job = job

    def _sync_portal_manifest(self):
        # Lazily load the manifest from etcd (cached once read).
        if self._portal_manifest is None:
            etcd_key = common.portal_etcd_base_dir(self._portal_name)
            data = self._etcd.get_data(etcd_key)
            if data is not None:
                self._portal_manifest = \
                    text_format.Parse(data, dp_pb.DataPortalManifest())
        return self._portal_manifest

    def _update_portal_manifest(self, new_portal_manifest):
        # Same invalidate-write-restore pattern as _update_processing_job.
        self._portal_manifest = None
        etcd_key = common.portal_etcd_base_dir(self._portal_name)
        data = text_format.MessageToString(new_portal_manifest)
        self._etcd.set_data(etcd_key, data)
        self._portal_manifest = new_portal_manifest

    def _launch_new_portal_job(self):
        """Create a new job from all not-yet-processed input files."""
        assert self._sync_processing_job() is None
        all_fpaths = self._list_input_dir()
        rest_fpaths = []
        for fpath in all_fpaths:
            if fpath not in self._processed_fpath:
                rest_fpaths.append(fpath)
        if len(rest_fpaths) == 0:
            logging.info("no file left for portal")
            return
        rest_fpaths.sort()
        portal_mainifest = self._sync_portal_manifest()
        new_job = dp_pb.DataPortalJob(job_id=portal_mainifest.next_job_id,
                                      finished=False,
                                      fpaths=rest_fpaths)
        # Persist the job before pointing the manifest at it, so a crash
        # in between leaves a recoverable (never dangling) reference.
        self._update_processing_job(new_job)
        new_portal_manifest = dp_pb.DataPortalManifest()
        new_portal_manifest.MergeFrom(portal_mainifest)
        new_portal_manifest.next_job_id += 1
        new_portal_manifest.processing_job_id = new_job.job_id
        self._update_portal_manifest(new_portal_manifest)
        # Materialize a kInit part record for every partition.
        for partition_id in range(self._output_partition_num):
            self._sync_job_part(new_job.job_id, partition_id)

    def _list_input_dir(self):
        # List input files, optionally filtered by the manifest wildcard.
        input_dir = self._portal_manifest.input_base_dir
        fnames = gfile.ListDirectory(input_dir)
        if len(self._portal_manifest.input_file_wildcard) > 0:
            wildcard = self._portal_manifest.input_file_wildcard
            fnames = [f for f in fnames if fnmatch(f, wildcard)]
        return [path.join(input_dir, f) for f in fnames]

    def _sync_job_part(self, job_id, partition_id):
        # Load (or create a default kInit) part record, cached per
        # partition; the cache is refreshed when it belongs to another job.
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] is None or \
                self._job_part_map[partition_id].job_id != job_id:
            etcd_key = common.portal_job_part_etcd_key(self._portal_name,
                                                       job_id, partition_id)
            data = self._etcd.get_data(etcd_key)
            if data is None:
                self._job_part_map[partition_id] = dp_pb.PortalJobPart(
                        job_id=job_id, rank_id=-1,
                        partition_id=partition_id)
            else:
                self._job_part_map[partition_id] = \
                    text_format.Parse(data, dp_pb.PortalJobPart())
        return self._job_part_map[partition_id]

    def _update_job_part(self, job_part):
        # Write-through cache update; skip the etcd write when unchanged.
        partition_id = job_part.partition_id
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] != job_part:
            self._job_part_map[partition_id] = None
            etcd_key = common.portal_job_part_etcd_key(self._portal_name,
                                                       job_part.job_id,
                                                       partition_id)
            data = text_format.MessageToString(job_part)
            self._etcd.set_data(etcd_key, data)
        self._job_part_map[partition_id] = job_part

    def _check_processing_job_finished(self):
        """Close the processing job when all parts finished; publish output.

        Returns True when the job is (now) finished, False otherwise.
        """
        if not self._all_job_part_finished():
            return False
        processing_job = self._sync_processing_job()
        if not processing_job.finished:
            finished_job = dp_pb.DataPortalJob()
            finished_job.MergeFrom(self._processing_job)
            finished_job.finished = True
            self._update_processing_job(finished_job)
        # Drop per-job caches.
        self._processing_job = None
        self._job_part_map = {}
        portal_mainifest = self._sync_portal_manifest()
        if portal_mainifest.processing_job_id >= 0:
            # Publish the job's output, then clear the manifest pointer.
            self._publish_raw_data(portal_mainifest.processing_job_id)
            new_portal_manifest = dp_pb.DataPortalManifest()
            new_portal_manifest.MergeFrom(self._sync_portal_manifest())
            new_portal_manifest.processing_job_id = -1
            self._update_portal_manifest(new_portal_manifest)
        return True

    @property
    def _output_partition_num(self):
        return self._portal_manifest.output_partition_num

    def _is_job_part_finished(self, job_part):
        # Terminal state depends on portal type: PSI stops after id-map,
        # streaming continues through event-time reduce.
        assert self._portal_manifest is not None
        if self._portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return job_part.part_state == dp_pb.PartState.kIdMapped
        return job_part.part_state == dp_pb.PartState.kEventTimeReduced

    def _map_output_dir(self, job_id):
        return common.portal_map_output_dir(
                self._portal_manifest.output_base_dir, job_id)

    def _reduce_output_dir(self, job_id):
        return common.portal_reduce_output_dir(
                self._portal_manifest.output_base_dir, job_id)

    def _publish_raw_data(self, job_id):
        # Publish each partition's output files; map output for PSI
        # portals, reduce output for streaming portals.
        portal_manifest = self._sync_portal_manifest()
        output_dir = None
        if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            output_dir = common.portal_map_output_dir(
                    portal_manifest.output_base_dir, job_id)
        else:
            output_dir = common.portal_reduce_output_dir(
                    portal_manifest.output_base_dir, job_id)
        for partition_id in range(self._output_partition_num):
            dpath = path.join(output_dir,
                              common.partition_repr(partition_id))
            fnames = []
            if gfile.Exists(dpath) and gfile.IsDirectory(dpath):
                fnames = [f for f in gfile.ListDirectory(dpath)
                          if f.endswith(common.RawDataFileSuffix)]
            if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
                self._publish_psi_raw_data(partition_id, dpath, fnames)
            else:
                self._publish_streaming_raw_data(partition_id, dpath, fnames)

    def _publish_streaming_raw_data(self, partition_id, dpath, fnames):
        # Publish merged sort runs in meta order (sorted by meta).
        metas = [MergedSortRunMeta.decode_sort_run_meta_from_fname(fname)
                 for fname in fnames]
        metas.sort()
        fpaths = [path.join(dpath, meta.encode_merged_sort_run_fname())
                  for meta in metas]
        self._publisher.publish_raw_data(partition_id, fpaths)

    def _publish_psi_raw_data(self, partition_id, dpath, fnames):
        # PSI output is final: publish then mark the partition finished.
        metas = [RawDataPartitioner.FileMeta.decode_meta_from_fname(fname)
                 for fname in fnames]
        fpaths = [path.join(dpath, meta.encode_meta_to_fname())
                  for meta in metas]
        self._publisher.publish_raw_data(partition_id, fpaths)
        self._publisher.finish_raw_data(partition_id)
class RsaPsiPreProcessor(object):
    """Runs the RSA-PSI preprocessing pipeline for one raw-data partition.

    Four routine workers form a pipeline:
      id_batch_fetcher -> psi_rsa_signer -> sort_run_dumper -> sort_run_merger
    Each stage wakes the next one up; a Condition (`self._lock`) lets
    wait_for_finished block until the merger signals completion.
    """

    def __init__(self, options, kvstore_type, use_mock_etcd=False):
        """Wire up the pipeline; see class docstring for the stage layout."""
        self._lock = threading.Condition()
        self._options = options
        kvstore = DBClient(kvstore_type, use_mock_etcd)
        pub_dir = self._options.raw_data_publish_dir
        self._publisher = RawDataPublisher(kvstore, pub_dir)
        self._process_pool_executor = \
            concur_futures.ProcessPoolExecutor(
                    options.offload_processor_number
                )
        self._callback_submitter = None
        # pre-fork the sub-processes before any grpc client is launched
        # (forking after grpc initialization is unsafe); submit a trivial
        # task and wait so the workers exist now.
        self._process_pool_executor.submit(min, 1, 2).result()
        self._id_batch_fetcher = IdBatchFetcher(kvstore, self._options)
        if self._options.role == common_pb.FLRole.Leader:
            # Leader signs ids with its own private RSA key.
            private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
            self._psi_rsa_signer = LeaderPsiRsaSigner(
                    self._id_batch_fetcher,
                    options.batch_processor_options.max_flying_item,
                    self._options.max_flying_sign_batch,
                    self._options.slow_sign_threshold,
                    self._process_pool_executor,
                    private_key,
                )
            self._repr = 'leader-' + 'rsa_psi_preprocessor'
        else:
            # Follower blinds ids with the leader's public key and asks the
            # leader to sign via rpc; one thread serializes rpc callbacks.
            public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
            self._callback_submitter = concur_futures.ThreadPoolExecutor(1)
            self._psi_rsa_signer = FollowerPsiRsaSigner(
                    self._id_batch_fetcher,
                    options.batch_processor_options.max_flying_item,
                    self._options.max_flying_sign_batch,
                    self._options.max_flying_sign_rpc,
                    self._options.sign_rpc_timeout_ms,
                    self._options.slow_sign_threshold,
                    self._options.stub_fanout,
                    self._process_pool_executor,
                    self._callback_submitter, public_key,
                    self._options.leader_rsa_psi_signer_addr)
            self._repr = 'follower-' + 'rsa_psi_preprocessor'
        self._sort_run_dumper = SortRunDumper(options)
        self._sort_run_merger = SortRunMerger(
                dj_pb.SortRunMergerOptions(
                    merger_name='sort_run_merger_'+\
                                partition_repr(options.partition_id),
                    reader_options=dj_pb.RawDataOptions(
                        raw_data_iter=options.writer_options.output_writer,
                        compressed_type=options.writer_options.compressed_type,
                        read_ahead_size=\
                            options.sort_run_merger_read_ahead_buffer,
                        read_batch_size=\
                            options.sort_run_merger_read_batch_size
                    ),
                    writer_options=options.writer_options,
                    output_file_dir=options.output_file_dir,
                    partition_id=options.partition_id
                ),
                self._merger_comparator
            )
        # produced/consumed counters drive backpressure in _stop_fetch_id
        # ('comsume' spelling kept; it is referenced across methods).
        self._produce_item_cnt = 0
        self._comsume_item_cnt = 0
        self._started = False

    def start_process(self):
        """Start the four pipeline routine workers (idempotent)."""
        with self._lock:
            if not self._started:
                self._worker_map = {
                    self._id_batch_fetcher_name(): RoutineWorker(
                        self._id_batch_fetcher_name(),
                        self._id_batch_fetch_fn,
                        self._id_batch_fetch_cond, 5),
                    self._psi_rsa_signer_name(): RoutineWorker(
                        self._psi_rsa_signer_name(),
                        self._psi_rsa_sign_fn,
                        self._psi_rsa_sign_cond, 5),
                    self._sort_run_dumper_name(): RoutineWorker(
                        self._sort_run_dumper_name(),
                        self._sort_run_dump_fn,
                        self._sort_run_dump_cond, 5),
                    self._sort_run_merger_name(): RoutineWorker(
                        self._sort_run_merger_name(),
                        self._sort_run_merge_fn,
                        self._sort_run_merge_cond, 5)
                }
                for _, w in self._worker_map.items():
                    w.start_routine()
                self._started = True

    def stop_routine_workers(self):
        """Stop all routine workers.

        NOTE(review): wait_join starts as True and is never set False, so
        workers are always stopped; calling this before start_process would
        hit a missing _worker_map — confirm intended.
        """
        wait_join = True
        with self._lock:
            if self._started:
                wait_join = True
                self._started = False
        if wait_join:
            for w in self._worker_map.values():
                w.stop_routine()

    def wait_for_finished(self):
        """Block until merging finishes, then tear the pipeline down."""
        while not self._sort_run_merger.is_merged_finished():
            with self._lock:
                # Woken by _sort_run_merge_cond's notify when merged.
                self._lock.wait()
        self.stop_routine_workers()
        self._process_pool_executor.shutdown()
        if self._callback_submitter is not None:
            self._callback_submitter.shutdown()
        self._id_batch_fetcher.cleanup_visitor_meta_data()
        self._bye_for_signer()

    def _bye_for_signer(self):
        # Tell the peer signer we are done; retry for up to 10 minutes,
        # then exit hard since the process cannot make further progress.
        for rnd in range(60):
            try:
                self._psi_rsa_signer.say_signer_bye()
                logging.info("Success to say bye to signer at round "\
                             "%d, rsa_psi_preprocessor will exit", rnd)
                return
            except Exception as e: # pylint: disable=broad-except
                logging.warning("Failed to say bye to signer at "\
                                "round %d, sleep 10s and retry", rnd)
                time.sleep(10)
        logging.warning("Give up to say bye to signer after try 60"\
                        "times, rsa_psi_preprocessor will exit as -1")
        traceback.print_stack()
        os._exit(-1) # pylint: disable=protected-access

    def _id_batch_fetcher_name(self):
        return self._repr + ':id_batch_fetcher'

    def _wakeup_id_batch_fetcher(self):
        self._worker_map[self._id_batch_fetcher_name()].wakeup()

    def _id_batch_fetch_fn(self):
        # Stage 1: pull id batches and feed the signer, respecting
        # the backpressure check in _stop_fetch_id.
        next_index = self._psi_rsa_signer.get_next_index_to_fetch()
        for batch in self._id_batch_fetcher.make_processor(next_index):
            logging.debug("%s fetch batch begin at %d, len %d. wakeup %s",
                          self._id_batch_fetcher_name(),
                          batch.begin_index, len(batch),
                          self._psi_rsa_signer_name())
            self._produce_item_cnt += len(batch)
            self._wakeup_psi_rsa_signer()
            if self._stop_fetch_id():
                break

    def _id_batch_fetch_cond(self):
        next_index = self._psi_rsa_signer.get_next_index_to_fetch()
        return self._id_batch_fetcher.need_process(next_index) and \
                not self._stop_fetch_id() and \
                not self._sort_run_dumper.is_dump_finished()

    def _stop_fetch_id(self):
        # Backpressure: pause fetching when too many items are in flight
        # (hard cap ~5M items) or when heap stats predict OOM risk.
        total_flying_item = self._produce_item_cnt - self._comsume_item_cnt
        if total_flying_item >= 5 << 20:
            logging.warning("stop fetch id since flying item "\
                            "reach to %d > 5m, produce_item_cnt: %d; "\
                            "consume_item_cnt: %d", total_flying_item,
                            self._produce_item_cnt, self._comsume_item_cnt)
            return True
        potential_mem_incr = total_flying_item * \
                self._psi_rsa_signer.additional_item_mem_usage()
        if get_heap_mem_stats(None).CheckOomRisk(total_flying_item,
                                                 0.80,
                                                 potential_mem_incr):
            logging.warning("stop fetch id since has oom risk for 0.80, "\
                            "flying item reach to %d", total_flying_item)
            return True
        return False

    def _psi_rsa_signer_name(self):
        return self._repr + ':psi_rsa_signer'

    def _wakeup_psi_rsa_signer(self):
        self._worker_map[self._psi_rsa_signer_name()].wakeup()

    def _transmit_signed_batch(self, signed_index):
        # Evict fetched batches that are fully signed and let the signer
        # update its indexing hint, then hand work to the dumper.
        evict_batch_cnt = self._id_batch_fetcher.evict_staless_item_batch(
                signed_index)
        self._psi_rsa_signer.update_next_batch_index_hint(evict_batch_cnt)
        self._wakeup_sort_run_dumper()

    def _psi_rsa_sign_fn(self):
        # Stage 2: sign batches; flush progress downstream every 16 batches
        # and once more at the end.
        next_index = self._sort_run_dumper.get_next_index_to_dump()
        sign_cnt = 0
        signed_index = None
        for signed_batch in self._psi_rsa_signer.make_processor(next_index):
            logging.debug("%s sign batch begin at %d, len %d. wakeup %s",
                          self._psi_rsa_signer_name(),
                          signed_batch.begin_index, len(signed_batch),
                          self._sort_run_dumper_name())
            sign_cnt += 1
            if signed_batch is not None:
                signed_index = signed_batch.begin_index + \
                        len(signed_batch) - 1
            if sign_cnt % 16 == 0:
                self._transmit_signed_batch(signed_index)
        self._transmit_signed_batch(signed_index)

    def _psi_rsa_sign_cond(self):
        next_index = self._sort_run_dumper.get_next_index_to_dump()
        return self._psi_rsa_signer.need_process(next_index) and \
                not self._sort_run_dumper.is_dump_finished()

    def _sort_run_dumper_name(self):
        return self._repr + ':sort_run_dumper'

    def _wakeup_sort_run_dumper(self):
        self._worker_map[self._sort_run_dumper_name()].wakeup()

    def _load_sorted_items_from_rsa_signer(self):
        """Drain up to half of max_flying_item signed items and sort them.

        Returns (signed_finished, sorted_items, next_index); items are
        sorted by signed id (item[0]) to form one sort run.
        """
        sort_run_dumper = self._sort_run_dumper
        rsi_signer = self._psi_rsa_signer
        next_index = sort_run_dumper.get_next_index_to_dump()
        hint_index = None
        items_buffer = []
        signed_finished = False
        total_item_num = 0
        max_flying_item = self._options.batch_processor_options.max_flying_item
        sort_run_size = max_flying_item // 2
        # sort_run_size <= 0 means "no cap": drain everything available.
        while sort_run_size <= 0 or total_item_num < sort_run_size:
            signed_finished, batch, hint_index = \
                rsi_signer.fetch_item_batch_by_index(next_index, hint_index)
            if batch is None:
                break
            assert next_index == batch.begin_index
            for item in batch:
                items_buffer.append(item)
            next_index += len(batch)
            total_item_num += len(batch)
        sorted_items_buffer = sorted(items_buffer, key=lambda item: item[0])
        return signed_finished, sorted_items_buffer, next_index

    def _sort_run_dump_fn(self):
        # Stage 3: dump one sorted run to disk, then evict the consumed
        # items from the signer and collect garbage.
        signed_finished, items_buffer, next_index = \
            self._load_sorted_items_from_rsa_signer()
        sort_run_dumper = self._sort_run_dumper
        if len(items_buffer) > 0:
            def producer(items_buffer):
                # item tuples are (signed_id, item, index); re-key the item
                # by its signed id before dumping.
                for signed_id, item, index in items_buffer:
                    item.set_example_id(signed_id)
                    yield signed_id, index, item
            sort_run_dumper.dump_sort_runs(producer(items_buffer))
        if next_index is not None:
            self._psi_rsa_signer.evict_staless_item_batch(next_index - 1)
        if signed_finished:
            sort_run_dumper.finish_dump_sort_run()
        dump_cnt = len(items_buffer)
        self._comsume_item_cnt += dump_cnt
        del items_buffer
        logging.warning("dump %d item in sort run, and gc %d objects.",
                        dump_cnt, gc.collect())

    def _sort_run_dump_cond(self):
        # Dump when signing is done, or enough candidates piled up
        # (>= 2M items or more than half of max_flying_item), or when
        # fetching stalled and a dump would free memory.
        sort_run_dumper = self._sort_run_dumper
        rsa_signer = self._psi_rsa_signer
        next_index = sort_run_dumper.get_next_index_to_dump()
        max_flying_item = self._options.batch_processor_options.max_flying_item
        dump_finished = sort_run_dumper.is_dump_finished()
        signed_finished = rsa_signer.get_process_finished()
        flying_item_cnt = rsa_signer.get_flying_item_count()
        flying_begin_index = rsa_signer.get_flying_begin_index()
        dump_cands_num = 0
        if flying_begin_index is not None and next_index is not None and \
                (flying_begin_index <= next_index <
                    flying_begin_index + flying_item_cnt):
            dump_cands_num = flying_item_cnt - \
                    (next_index - flying_begin_index)
        return not dump_finished and \
                (signed_finished or
                 (dump_cands_num >= (2 << 20) or
                  (max_flying_item > 2 and
                    dump_cands_num > max_flying_item // 2)) or
                 self._dump_for_forward(dump_cands_num))

    def _dump_for_forward(self, dump_cands_num):
        # When fetching is throttled, dump early if at least half of the
        # flying items are ready, to unblock the pipeline.
        if self._stop_fetch_id():
            total_flying_item = self._produce_item_cnt - \
                    self._comsume_item_cnt
            return dump_cands_num > 0 and \
                    dump_cands_num >= total_flying_item // 2
        return False

    def _sort_run_merger_name(self):
        return self._repr + ':sort_run_merger'

    def _sort_run_merge_fn(self):
        # Stage 4: merge all sort runs, publish the merged files, then
        # mark the merge finished (wakes wait_for_finished).
        sort_runs = self._sort_run_dumper.get_all_sort_runs()
        input_dir = self._sort_run_dumper.sort_run_dump_dir()
        input_fpaths = [path.join(input_dir,
                                  partition_repr(self._options.partition_id),
                                  sort_run.encode_sort_run_fname())
                        for sort_run in sort_runs]
        output_fpaths = self._sort_run_merger.merge_sort_runs(input_fpaths)
        self._publisher.publish_raw_data(self._options.partition_id,
                                         output_fpaths)
        self._publisher.finish_raw_data(self._options.partition_id)
        self._sort_run_merger.set_merged_finished()

    def _sort_run_merge_cond(self):
        if self._sort_run_merger.is_merged_finished():
            with self._lock:
                # Wake wait_for_finished; no more merging needed.
                self._lock.notify()
            return False
        return self._sort_run_dumper.is_dump_finished()

    @staticmethod
    def _merger_comparator(a, b):
        # Merge order: ascending example_id.
        return a.example_id < b.example_id
class RsaPsiPreProcessor(object):
    """Drive the RSA-PSI preprocessing pipeline for one partition.

    Four RoutineWorker threads form a pipeline:
        id_batch_fetcher -> psi_rsa_signer -> sort_run_dumper
            -> sort_run_merger
    Each stage wakes the next when fresh output is available; the merged
    output files are published through RawDataPublisher.
    """

    def __init__(self, options, etcd_name, etcd_addrs,
                 etcd_base_dir, use_mock_etcd=False):
        self._lock = threading.Condition()
        self._options = options
        etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                          use_mock_etcd)
        pub_dir = self._options.raw_data_publish_dir
        self._publisher = RawDataPublisher(etcd, pub_dir)
        # process pool offloads the CPU-heavy RSA math from the
        # routine-worker threads
        self._process_pool_executor = \
                concur_futures.ProcessPoolExecutor(
                        options.offload_processor_number
                    )
        self._id_batch_fetcher = IdBatchFetcher(etcd, self._options)
        max_flying_item = options.batch_processor_options.max_flying_item
        if self._options.role == common_pb.FLRole.Leader:
            # the leader signs blinded ids locally with its private key
            private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
            self._psi_rsa_signer = LeaderPsiRsaSigner(
                    self._id_batch_fetcher,
                    max_flying_item,
                    self._options.max_flying_sign_batch,
                    self._options.slow_sign_threshold,
                    self._process_pool_executor,
                    private_key,
                )
            self._repr = 'leader-' + 'rsa_psi_preprocessor'
        else:
            # the follower only holds the public key and asks the
            # leader's signer service (over rpc) to sign
            public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
            self._psi_rsa_signer = FollowerPsiRsaSigner(
                    self._id_batch_fetcher,
                    max_flying_item,
                    self._options.max_flying_sign_batch,
                    self._options.max_flying_sign_rpc,
                    self._options.sign_rpc_timeout_ms,
                    self._options.slow_sign_threshold,
                    self._options.stub_fanout,
                    self._process_pool_executor,
                    public_key,
                    self._options.leader_rsa_psi_signer_addr)
            self._repr = 'follower-' + 'rsa_psi_preprocessor'
        self._sort_run_dumper = SortRunDumper(options)
        # merge the dumped sort runs into final output ordered by the
        # 'example_id' field
        self._sort_run_merger = SortRunMerger(
                dj_pb.SortRunMergerOptions(
                    merger_name='sort_run_merger_'+\
                            partition_repr(options.partition_id),
                    reader_options=dj_pb.RawDataOptions(
                        raw_data_iter=options.writer_options.output_writer,
                        compressed_type=options.writer_options.compressed_type,
                        read_ahead_size=\
                            options.sort_run_merger_read_ahead_buffer
                    ),
                    writer_options=options.writer_options,
                    output_file_dir=options.output_file_dir,
                    partition_id=options.partition_id
                ),
                'example_id'
            )
        self._started = False

    def start_process(self):
        """Create and start the four pipeline workers (idempotent)."""
        with self._lock:
            if not self._started:
                # each worker polls its cond every 5 seconds and can
                # also be woken explicitly by the upstream stage
                self._worker_map = {
                    self._id_batch_fetcher_name(): RoutineWorker(
                        self._id_batch_fetcher_name(),
                        self._id_batch_fetch_fn,
                        self._id_batch_fetch_cond, 5),
                    self._psi_rsa_signer_name(): RoutineWorker(
                        self._psi_rsa_signer_name(),
                        self._psi_rsa_sign_fn,
                        self._psi_rsa_sign_cond, 5),
                    self._sort_run_dumper_name(): RoutineWorker(
                        self._sort_run_dumper_name(),
                        self._sort_run_dump_fn,
                        self._sort_run_dump_cond, 5),
                    self._sort_run_merger_name(): RoutineWorker(
                        self._sort_run_merger_name(),
                        self._sort_run_merge_fn,
                        self._sort_run_merge_cond, 5)
                }
                for _, w in self._worker_map.items():
                    w.start_routine()
                self._started = True

    def stop_routine_workers(self):
        """Stop all pipeline workers if the pipeline was started."""
        wait_join = True
        with self._lock:
            if self._started:
                wait_join = True
                self._started = False
        if wait_join:
            for w in self._worker_map.values():
                w.stop_routine()

    def wait_for_finished(self):
        """Block until the sort-run merger reports completion, then
        shut the pipeline down and release pool/visitor resources.

        NOTE(review): the finished flag is read before taking the lock,
        so a single missed notify is possible; this appears to rely on
        _sort_run_merge_cond notifying again on every poll cycle --
        confirm RoutineWorker keeps polling after completion.
        """
        while not self._sort_run_merger.is_merged_finished():
            with self._lock:
                self._lock.wait()
        self.stop_routine_workers()
        self._process_pool_executor.shutdown()
        self._id_batch_fetcher.cleanup_visitor_meta_data()

    def _id_batch_fetcher_name(self):
        """Worker name of the id-batch fetch stage."""
        return self._repr + ':id_batch_fetcher'

    def _wakeup_id_batch_fetcher(self):
        self._worker_map[self._id_batch_fetcher_name()].wakeup()

    def _id_batch_fetch_fn(self):
        """Fetch raw-id batches starting where the signer left off and
        wake the signer for each batch produced."""
        next_index = self._psi_rsa_signer.get_next_index_to_fetch()
        for batch in self._id_batch_fetcher.make_processor(next_index):
            logging.debug("%s fetch batch begin at %d, len %d. wakeup %s",
                          self._id_batch_fetcher_name(),
                          batch.begin_index, len(batch),
                          self._psi_rsa_signer_name())
            self._wakeup_psi_rsa_signer()

    def _id_batch_fetch_cond(self):
        # fetch only when the signer still needs input at its next index
        next_index = self._psi_rsa_signer.get_next_index_to_fetch()
        return self._id_batch_fetcher.need_process(next_index)

    def _psi_rsa_signer_name(self):
        """Worker name of the RSA signing stage."""
        return self._repr + ':psi_rsa_signer'

    def _wakeup_psi_rsa_signer(self):
        self._worker_map[self._psi_rsa_signer_name()].wakeup()

    def _psi_rsa_sign_fn(self):
        """Sign fetched id batches and wake the dumper for each signed
        batch; afterwards evict fetcher batches the dumper has passed."""
        next_index = self._sort_run_dumper.get_next_index_to_dump()
        for signed_batch in self._psi_rsa_signer.make_processor(next_index):
            logging.debug("%s sign batch begin at %d, len %d. wakeup %s",
                          self._psi_rsa_signer_name(),
                          signed_batch.begin_index, len(signed_batch),
                          self._sort_run_dumper_name())
            self._wakeup_sort_run_dumper()
        # everything below the dumper's next index is stale for the
        # fetcher; reclaim it and tell the signer how many batches went
        staless_index = self._sort_run_dumper.get_next_index_to_dump() - 1
        evict_batch_cnt = self._id_batch_fetcher.evict_staless_item_batch(
                staless_index)
        self._psi_rsa_signer.update_next_batch_index_hint(evict_batch_cnt)

    def _psi_rsa_sign_cond(self):
        # sign only when the dumper still needs input at its next index
        next_index = self._sort_run_dumper.get_next_index_to_dump()
        return self._psi_rsa_signer.need_process(next_index)

    def _sort_run_dumper_name(self):
        """Worker name of the sort-run dump stage."""
        return self._repr + ':sort_run_dumper'

    def _wakeup_sort_run_dumper(self):
        self._worker_map[self._sort_run_dumper_name()].wakeup()

    def _load_sorted_items_from_rsa_signer(self):
        """Drain up to one sort run's worth of signed items from the
        signer, sorted by signed id.

        Returns a tuple (signed_finished, sorted_items, next_index).
        """
        sort_run_dumper = self._sort_run_dumper
        rsi_signer = self._psi_rsa_signer
        next_index = sort_run_dumper.get_next_index_to_dump()
        hint_index = None
        items_buffer = []
        signed_finished = False
        total_item_num = 0
        max_flying_item = self._options.batch_processor_options.max_flying_item
        # one sort run is a quarter of the flying-item budget
        sort_run_size = max_flying_item // 4
        while True and total_item_num < sort_run_size:
            signed_finished, batch, hint_index = \
                    rsi_signer.fetch_item_batch_by_index(next_index,
                                                         hint_index)
            if batch is None:
                break
            assert next_index == batch.begin_index
            for item in batch:
                items_buffer.append(item)
            next_index += len(batch)
            total_item_num += len(batch)
        # items are (signed_id, item, index); sort by signed id
        sorted_items_buffer = sorted(items_buffer, key=lambda item: item[0])
        return signed_finished, sorted_items_buffer, next_index

    def _sort_run_dump_fn(self):
        """Dump one sorted run to disk and evict dumped items from the
        signer; finish the dumper when signing is complete."""
        signed_finished, items_buffer, next_index = \
                self._load_sorted_items_from_rsa_signer()
        sort_run_dumper = self._sort_run_dumper
        if len(items_buffer) > 0:
            def producer(items_buffer):
                # stamp the signed id into the item before dumping
                for signed_id, item, index in items_buffer:
                    item.set_example_id(signed_id)
                    yield signed_id, index, item
            sort_run_dumper.dump_sort_runs(producer(items_buffer))
        if next_index is not None:
            self._psi_rsa_signer.evict_staless_item_batch(next_index - 1)
        if signed_finished:
            sort_run_dumper.finish_dump_sort_run()

    def _sort_run_dump_cond(self):
        """Dump when signing has finished, or when at least a quarter of
        the flying-item budget is buffered contiguously from the
        dumper's next index."""
        sort_run_dumper = self._sort_run_dumper
        rsa_signer = self._psi_rsa_signer
        next_index = sort_run_dumper.get_next_index_to_dump()
        max_flying_item = self._options.batch_processor_options.max_flying_item
        dump_finished = sort_run_dumper.is_dump_finished()
        signed_finished = rsa_signer.get_process_finished()
        flying_item_cnt = rsa_signer.get_flying_item_count()
        flying_begin_index = rsa_signer.get_flying_begin_index()
        return not dump_finished and \
                (signed_finished or
                 (flying_begin_index is not None and
                  next_index is not None and
                  (flying_begin_index <= next_index <
                      flying_begin_index + flying_item_cnt) and
                  (flying_item_cnt-(next_index-flying_begin_index) >=
                      max_flying_item // 4)))

    def _sort_run_merger_name(self):
        """Worker name of the sort-run merge stage."""
        return self._repr + ':sort_run_merger'

    def _sort_run_merge_fn(self):
        """Merge all dumped sort runs, publish the merged files and mark
        the merger finished."""
        sort_runs = self._sort_run_dumper.get_all_sort_runs()
        input_dir = self._sort_run_dumper.sort_run_dump_dir()
        input_fpaths = [os.path.join(input_dir,
                                     partition_repr(self._options.partition_id),
                                     sort_run.encode_sort_run_fname())
                        for sort_run in sort_runs]
        output_fpaths = self._sort_run_merger.merge_sort_runs(input_fpaths)
        self._publisher.publish_raw_data(self._options.partition_id,
                                         output_fpaths)
        self._publisher.finish_raw_data(self._options.partition_id)
        self._sort_run_merger.set_merged_finished()

    def _sort_run_merge_cond(self):
        # once merged, wake any thread blocked in wait_for_finished and
        # stop scheduling merge work
        if self._sort_run_merger.is_merged_finished():
            with self._lock:
                self._lock.notify()
            return False
        return self._sort_run_dumper.is_dump_finished()
class DataPortalJobManager(object):
    """Manage the lifecycle of data-portal jobs.

    A job is a snapshot of the not-yet-processed files under the portal
    input directory.  The manager allocates per-partition map (and, for
    streaming portals, reduce) tasks to workers, records per-partition
    progress in etcd, and publishes the job output as raw data once all
    partitions have finished.
    """

    def __init__(self, etcd, portal_name, long_running):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._portal_name = portal_name
        self._portal_manifest = None
        self._processing_job = None
        self._sync_portal_manifest()
        self._sync_processing_job()
        self._publisher = \
                RawDataPublisher(etcd,
                                 self._portal_manifest.raw_data_publish_dir)
        self._long_running = long_running
        assert self._portal_manifest is not None
        # files consumed by previous jobs; a new job only picks up
        # input files not in this set
        self._processed_fpath = set()
        for job_id in range(0, self._portal_manifest.next_job_id):
            job = self._sync_portal_job(job_id)
            assert job is not None and job.job_id == job_id
            for fpath in job.fpaths:
                self._processed_fpath.add(fpath)
        self._job_part_map = {}
        if self._portal_manifest.processing_job_id >= 0:
            self._check_processing_job_finished()
        if self._portal_manifest.processing_job_id < 0:
            self._launch_new_portal_job()

    def get_portal_manifest(self):
        """Return the (cached) portal manifest."""
        with self._lock:
            return self._sync_portal_manifest()

    def alloc_task(self, rank_id):
        """Allocate a task to worker rank_id.

        Returns (finished, task): finished tells the worker no more
        work will ever come; task is a MapTask, a ReduceTask or None.
        """
        with self._lock:
            self._sync_processing_job()
            if self._processing_job is not None:
                # map stage first: grant a kInit partition, or re-grant
                # the partition this worker was already mapping
                partition_id = self._try_to_alloc_part(rank_id,
                                                       dp_pb.PartState.kInit,
                                                       dp_pb.PartState.kIdMap)
                if partition_id is not None:
                    return False, self._create_map_task(rank_id,
                                                        partition_id)
                # reduce stage only after every partition is mapped, and
                # only for streaming portals
                if self._all_job_part_mapped() and \
                        (self._portal_manifest.data_portal_type ==
                            dp_pb.DataPortalType.Streaming):
                    partition_id = self._try_to_alloc_part(
                            rank_id, dp_pb.PartState.kIdMapped,
                            dp_pb.PartState.kEventTimeReduce)
                    if partition_id is not None:
                        return False, self._create_reduce_task(
                                rank_id, partition_id)
                return (not self._long_running and
                            self._all_job_part_finished()), None
            return not self._long_running, None

    def finish_task(self, rank_id, partition_id, part_state):
        """Record that worker rank_id finished its task on partition_id,
        provided it still owns the partition in the expected state."""
        with self._lock:
            processing_job = self._sync_processing_job()
            if processing_job is None:
                return
            job_id = self._processing_job.job_id
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.rank_id == rank_id and \
                    job_part.part_state == part_state:
                if job_part.part_state == dp_pb.PartState.kIdMap:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kIdMap,
                                          dp_pb.PartState.kIdMapped)
                    logging.info("Data portal worker-%d finish map task "\
                                 "for partition %d of job %d",
                                 rank_id, partition_id, job_id)
                elif job_part.part_state == dp_pb.PartState.kEventTimeReduce:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kEventTimeReduce,
                                          dp_pb.PartState.kEventTimeReduced)
                    logging.info("Data portal worker-%d finish reduce task "\
                                 "for partition %d of job %d",
                                 rank_id, partition_id, job_id)
            self._check_processing_job_finished()

    def backgroup_task(self):
        """Periodic housekeeping: close a finished job and, in
        long-running mode, launch the next one."""
        with self._lock:
            if self._sync_processing_job() is not None:
                self._check_processing_job_finished()
            if self._sync_processing_job() is None and self._long_running:
                self._launch_new_portal_job()

    def _all_job_part_mapped(self):
        """True when every partition of the processing job has passed
        the map stage."""
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.part_state <= dp_pb.PartState.kIdMap:
                return False
        return True

    def _all_job_part_finished(self):
        """True when every partition of the processing job reached its
        terminal state."""
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = self._processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if not self._is_job_part_finished(job_part):
                return False
        return True

    def _finish_job_part(self, job_id, partition_id, src_state,
                         target_state):
        """Transition a job partition src_state -> target_state and
        release its worker ownership."""
        job_part = self._sync_job_part(job_id, partition_id)
        assert job_part is not None and job_part.part_state == src_state
        new_job_part = dp_pb.PortalJobPart()
        new_job_part.MergeFrom(job_part)
        new_job_part.part_state = target_state
        new_job_part.rank_id = -1
        self._update_job_part(new_job_part)

    def _create_map_task(self, rank_id, partition_id):
        """Build the MapTask for partition_id: the job's files hashed to
        this partition."""
        assert self._processing_job is not None
        job = self._processing_job
        map_fpaths = []
        for fpath in job.fpaths:
            # files are spread over partitions by path hash
            if hash(fpath) % self._output_partition_num == partition_id:
                map_fpaths.append(fpath)
        task_name = '{}-dp_portal_job_{:08}-part-{:04}-map'.format(
                self._portal_manifest.name, job.job_id, partition_id)
        logging.info("Data portal worker-%d is allocated map task %s for "\
                     "partition %d of job %d. the map task has %d files\n"\
                     "-----------------\n",
                     rank_id, task_name, partition_id, job.job_id,
                     len(map_fpaths))
        for seq, fpath in enumerate(map_fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")
        manifest = self._sync_portal_manifest()
        return dp_pb.MapTask(task_name=task_name,
                             fpaths=map_fpaths,
                             output_base_dir=self._map_output_dir(job.job_id),
                             output_partition_num=self._output_partition_num,
                             partition_id=partition_id,
                             part_field=self._get_part_field(),
                             data_portal_type=manifest.data_portal_type)

    def _get_part_field(self):
        """Field used to partition records: raw_id for PSI portals,
        example_id for streaming portals."""
        portal_manifest = self._sync_portal_manifest()
        if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return 'raw_id'
        assert portal_manifest.data_portal_type == \
                dp_pb.DataPortalType.Streaming
        return 'example_id'

    def _create_reduce_task(self, rank_id, partition_id):
        """Build the ReduceTask for partition_id of the processing job."""
        assert self._processing_job is not None
        job = self._processing_job
        job_id = job.job_id
        task_name = '{}-dp_portal_job_{:08}-part-{:04}-reduce'.format(
                self._portal_manifest.name, job_id, partition_id)
        logging.info("Data portal worker-%d is allocated reduce task %s for "\
                     "partition %d of job %d. the reduce base dir %s"\
                     "-----------------\n", rank_id, task_name,
                     partition_id, job_id, self._reduce_output_dir(job_id))
        return dp_pb.ReduceTask(task_name=task_name,
                                map_base_dir=self._map_output_dir(job_id),
                                reduce_base_dir=self._reduce_output_dir(job_id),
                                partition_id=partition_id)

    def _try_to_alloc_part(self, rank_id, src_state, target_state):
        """Pick a partition for rank_id: prefer one it already owns in
        target_state (re-grant after a worker restart), otherwise the
        first partition still in src_state.  Returns the partition id or
        None when nothing can be allocated."""
        alloc_partition_id = None
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = self._processing_job.job_id
        for partition_id in range(self._output_partition_num):
            part_job = self._sync_job_part(job_id, partition_id)
            if part_job.part_state == src_state and \
                    alloc_partition_id is None:
                alloc_partition_id = partition_id
            if part_job.part_state == target_state and \
                    part_job.rank_id == rank_id:
                # this worker already holds the partition; re-grant it
                alloc_partition_id = partition_id
                break
        if alloc_partition_id is None:
            return None
        part_job = self._job_part_map[alloc_partition_id]
        if part_job.part_state == src_state:
            new_job_part = dp_pb.PortalJobPart(
                    job_id=job_id, rank_id=rank_id,
                    partition_id=alloc_partition_id,
                    part_state=target_state)
            self._update_job_part(new_job_part)
        return alloc_partition_id

    def _sync_portal_job(self, job_id):
        """Load job job_id from etcd; None when it does not exist."""
        etcd_key = common.portal_job_etcd_key(self._portal_name, job_id)
        data = self._etcd.get_data(etcd_key)
        if data is not None:
            return text_format.Parse(data, dp_pb.DataPortalJob())
        return None

    def _sync_processing_job(self):
        """Refresh the cached processing job from the manifest/etcd."""
        assert self._sync_portal_manifest() is not None
        if self._portal_manifest.processing_job_id < 0:
            self._processing_job = None
        elif self._processing_job is None or \
                (self._processing_job.job_id !=
                    self._portal_manifest.processing_job_id):
            job_id = self._portal_manifest.processing_job_id
            self._processing_job = self._sync_portal_job(job_id)
            assert self._processing_job is not None
        return self._processing_job

    def _update_processing_job(self, job):
        """Persist job to etcd; cache is cleared first so a failed write
        cannot leave a stale cache."""
        self._processing_job = None
        etcd_key = common.portal_job_etcd_key(self._portal_name, job.job_id)
        self._etcd.set_data(etcd_key, text_format.MessageToString(job))
        self._processing_job = job

    def _sync_portal_manifest(self):
        """Load the portal manifest from etcd if not cached."""
        if self._portal_manifest is None:
            etcd_key = common.portal_etcd_base_dir(self._portal_name)
            data = self._etcd.get_data(etcd_key)
            if data is not None:
                self._portal_manifest = \
                        text_format.Parse(data, dp_pb.DataPortalManifest())
        return self._portal_manifest

    def _update_portal_manifest(self, new_portal_manifest):
        """Persist the manifest to etcd; cache cleared first as above."""
        self._portal_manifest = None
        etcd_key = common.portal_etcd_base_dir(self._portal_name)
        data = text_format.MessageToString(new_portal_manifest)
        self._etcd.set_data(etcd_key, data)
        self._portal_manifest = new_portal_manifest

    def _launch_new_portal_job(self):
        """Snapshot all not-yet-processed input files into a new job and
        make it the processing job.  No-op when no new file exists."""
        assert self._sync_processing_job() is None
        all_fpaths = self._list_input_dir()
        rest_fpaths = [fpath for fpath in all_fpaths
                       if fpath not in self._processed_fpath]
        if len(rest_fpaths) == 0:
            logging.info("no file left for portal")
            return
        rest_fpaths.sort()
        portal_manifest = self._sync_portal_manifest()
        new_job = dp_pb.DataPortalJob(job_id=portal_manifest.next_job_id,
                                      finished=False,
                                      fpaths=rest_fpaths)
        self._update_processing_job(new_job)
        new_portal_manifest = dp_pb.DataPortalManifest()
        new_portal_manifest.MergeFrom(portal_manifest)
        new_portal_manifest.next_job_id += 1
        new_portal_manifest.processing_job_id = new_job.job_id
        self._update_portal_manifest(new_portal_manifest)
        # create the per-partition state records up front
        for partition_id in range(self._output_partition_num):
            self._sync_job_part(new_job.job_id, partition_id)
        logging.info("Data Portal job %d has launched. %d files will be "\
                     "processed\n------------\n",
                     new_job.job_id, len(new_job.fpaths))
        for seq, fpath in enumerate(new_job.fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")

    def _list_input_dir(self):
        """Walk input_base_dir recursively and return every file path
        matching input_file_wildcard (all files when no wildcard set)."""
        all_inputs = []
        wildcard = self._portal_manifest.input_file_wildcard
        dirs = [self._portal_manifest.input_base_dir]
        while len(dirs) > 0:
            fdir = dirs.pop(0)
            fnames = gfile.ListDirectory(fdir)
            for fname in fnames:
                fpath = path.join(fdir, fname)
                if gfile.IsDirectory(fpath):
                    dirs.append(fpath)
                elif len(wildcard) == 0 or fnmatch(fname, wildcard):
                    all_inputs.append(fpath)
        return all_inputs

    def _sync_job_part(self, job_id, partition_id):
        """Load (or create a default) PortalJobPart for the partition,
        caching it in _job_part_map."""
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] is None or \
                self._job_part_map[partition_id].job_id != job_id:
            etcd_key = common.portal_job_part_etcd_key(self._portal_name,
                                                       job_id, partition_id)
            data = self._etcd.get_data(etcd_key)
            if data is None:
                self._job_part_map[partition_id] = dp_pb.PortalJobPart(
                        job_id=job_id, rank_id=-1,
                        partition_id=partition_id)
            else:
                self._job_part_map[partition_id] = \
                        text_format.Parse(data, dp_pb.PortalJobPart())
        return self._job_part_map[partition_id]

    def _update_job_part(self, job_part):
        """Persist a PortalJobPart to etcd and refresh the cache."""
        partition_id = job_part.partition_id
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] != job_part:
            self._job_part_map[partition_id] = None
            etcd_key = common.portal_job_part_etcd_key(self._portal_name,
                                                       job_part.job_id,
                                                       partition_id)
            data = text_format.MessageToString(job_part)
            self._etcd.set_data(etcd_key, data)
        self._job_part_map[partition_id] = job_part

    def _check_processing_job_finished(self):
        """If every partition of the processing job is done, persist the
        job as finished, publish its output and clear the manifest's
        processing_job_id.  Returns True when no job is processing."""
        if not self._all_job_part_finished():
            return False
        processing_job = self._sync_processing_job()
        if not processing_job.finished:
            finished_job = dp_pb.DataPortalJob()
            finished_job.MergeFrom(self._processing_job)
            finished_job.finished = True
            self._update_processing_job(finished_job)
            # remember the consumed inputs so a later job launched in
            # long-running mode does not re-process the same files
            for fpath in processing_job.fpaths:
                self._processed_fpath.add(fpath)
        self._processing_job = None
        self._job_part_map = {}
        portal_manifest = self._sync_portal_manifest()
        if portal_manifest.processing_job_id >= 0:
            self._publish_raw_data(portal_manifest.processing_job_id)
            new_portal_manifest = dp_pb.DataPortalManifest()
            new_portal_manifest.MergeFrom(self._sync_portal_manifest())
            new_portal_manifest.processing_job_id = -1
            self._update_portal_manifest(new_portal_manifest)
        if processing_job is not None:
            logging.info("Data Portal job %d has finished. Processed %d "\
                         "following fpaths\n------------\n",
                         processing_job.job_id, len(processing_job.fpaths))
            for seq, fpath in enumerate(processing_job.fpaths):
                logging.info("%d. %s", seq, fpath)
            logging.info("---------------------------------\n")
        return True

    @property
    def _output_partition_num(self):
        return self._portal_manifest.output_partition_num

    def _is_job_part_finished(self, job_part):
        """Terminal state is kIdMapped for PSI portals and
        kEventTimeReduced for streaming portals."""
        assert self._portal_manifest is not None
        if self._portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return job_part.part_state == dp_pb.PartState.kIdMapped
        return job_part.part_state == dp_pb.PartState.kEventTimeReduced

    def _map_output_dir(self, job_id):
        return common.portal_map_output_dir(
                self._portal_manifest.output_base_dir, job_id)

    def _reduce_output_dir(self, job_id):
        return common.portal_reduce_output_dir(
                self._portal_manifest.output_base_dir, job_id)

    def _publish_raw_data(self, job_id):
        """Publish the job's per-partition output files: map output for
        PSI portals, reduce output for streaming portals."""
        portal_manifest = self._sync_portal_manifest()
        output_dir = None
        if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            output_dir = common.portal_map_output_dir(
                    portal_manifest.output_base_dir, job_id)
        else:
            output_dir = common.portal_reduce_output_dir(
                    portal_manifest.output_base_dir, job_id)
        for partition_id in range(self._output_partition_num):
            dpath = path.join(output_dir, common.partition_repr(partition_id))
            fnames = []
            if gfile.Exists(dpath) and gfile.IsDirectory(dpath):
                fnames = [f for f in gfile.ListDirectory(dpath)
                          if f.endswith(common.RawDataFileSuffix)]
            publish_fpaths = []
            if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
                publish_fpaths = self._publish_psi_raw_data(
                        partition_id, dpath, fnames)
            else:
                publish_fpaths = self._publish_streaming_raw_data(
                        partition_id, dpath, fnames)
            logging.info("Data Portal Master publish %d file for partition "\
                         "%d of streaming job %d\n----------\n",
                         len(publish_fpaths), partition_id, job_id)
            for seq, fpath in enumerate(publish_fpaths):
                logging.info("%d. %s", seq, fpath)
            logging.info("------------------------------------------\n")

    def _publish_streaming_raw_data(self, partition_id, dpath, fnames):
        """Publish merged sort-run files in meta order (job stays open
        for follow-up data)."""
        metas = [MergedSortRunMeta.decode_sort_run_meta_from_fname(fname)
                 for fname in fnames]
        metas.sort()
        fpaths = [path.join(dpath, meta.encode_merged_sort_run_fname())
                  for meta in metas]
        self._publisher.publish_raw_data(partition_id, fpaths)
        return fpaths

    def _publish_psi_raw_data(self, partition_id, dpath, fnames):
        """Publish PSI map output and mark the partition's raw data as
        complete."""
        fpaths = [path.join(dpath, fname) for fname in fnames]
        self._publisher.publish_raw_data(partition_id, fpaths)
        self._publisher.finish_raw_data(partition_id)
        return fpaths
class DataPortalJobManager(object): def __init__(self, kvstore, portal_name, long_running, check_success_tag, single_subfolder, files_per_job_limit, max_files_per_job=8000): self._lock = threading.Lock() self._kvstore = kvstore self._portal_name = portal_name self._check_success_tag = check_success_tag self._single_subfolder = single_subfolder self._files_per_job_limit = files_per_job_limit self._max_files_per_job = max_files_per_job self._portal_manifest = None self._processing_job = None self._sync_portal_manifest() self._sync_processing_job() self._publisher = \ RawDataPublisher(kvstore, self._portal_manifest.raw_data_publish_dir) self._long_running = long_running self._finished = False assert self._portal_manifest is not None self._processed_fpath = set() for job_id in range(0, self._portal_manifest.next_job_id): job = self._sync_portal_job(job_id) assert job is not None and job.job_id == job_id for fpath in job.fpaths: self._processed_fpath.add(fpath) self._job_part_map = {} if self._portal_manifest.processing_job_id >= 0: self._check_processing_job_finished() if self._portal_manifest.processing_job_id < 0: if not self._launch_new_portal_job() and not self._long_running: self._finished = True def get_portal_manifest(self): with self._lock: return self._sync_portal_manifest() def alloc_task(self, rank_id): with self._lock: self._sync_processing_job() if self._processing_job is not None: partition_id = self._try_to_alloc_part(rank_id, dp_pb.PartState.kInit, dp_pb.PartState.kIdMap) if partition_id is not None: return False, self._create_map_task(rank_id, partition_id) if self._all_job_part_mapped() and \ (self._portal_manifest.data_portal_type == dp_pb.DataPortalType.Streaming): partition_id = self._try_to_alloc_part( rank_id, dp_pb.PartState.kIdMapped, dp_pb.PartState.kEventTimeReduce) if partition_id is not None: return False, self._create_reduce_task( rank_id, partition_id) return (self._finished and self._all_job_part_finished()), None return self._finished, 
None def finish_task(self, rank_id, partition_id, part_state): with self._lock: processing_job = self._sync_processing_job() if processing_job is None: return job_id = self._processing_job.job_id job_part = self._sync_job_part(job_id, partition_id) if job_part.rank_id == rank_id and \ job_part.part_state == part_state: if job_part.part_state == dp_pb.PartState.kIdMap: self._finish_job_part(job_id, partition_id, dp_pb.PartState.kIdMap, dp_pb.PartState.kIdMapped) logging.info("Data portal worker-%d finish map task "\ "for partition %d of job %d", rank_id, partition_id, job_id) elif job_part.part_state == dp_pb.PartState.kEventTimeReduce: self._finish_job_part(job_id, partition_id, dp_pb.PartState.kEventTimeReduce, dp_pb.PartState.kEventTimeReduced) logging.info("Data portal worker-%d finish reduce task "\ "for partition %d of job %d", rank_id, partition_id, job_id) self._check_processing_job_finished() def backgroup_task(self): with self._lock: if self._sync_processing_job() is not None: self._check_processing_job_finished() if self._sync_processing_job() is None: success = self._launch_new_portal_job() if not success and not self._long_running: self._finished = True def _all_job_part_mapped(self): processing_job = self._sync_processing_job() assert processing_job is not None job_id = processing_job.job_id for partition_id in range(self._output_partition_num): job_part = self._sync_job_part(job_id, partition_id) if job_part.part_state <= dp_pb.PartState.kIdMap: return False return True def _all_job_part_finished(self): processing_job = self._sync_processing_job() assert processing_job is not None job_id = self._processing_job.job_id for partition_id in range(self._output_partition_num): job_part = self._sync_job_part(job_id, partition_id) if not self._is_job_part_finished(job_part): return False return True def _finish_job_part(self, job_id, partition_id, src_state, target_state): job_part = self._sync_job_part(job_id, partition_id) assert job_part is not None and 
job_part.part_state == src_state new_job_part = dp_pb.PortalJobPart() new_job_part.MergeFrom(job_part) new_job_part.part_state = target_state new_job_part.rank_id = -1 self._update_job_part(new_job_part) def _create_map_task(self, rank_id, partition_id): assert self._processing_job is not None job = self._processing_job map_fpaths = [] for fpath in job.fpaths: if hash(fpath) % self._output_partition_num == partition_id: map_fpaths.append(fpath) task_name = '{}-dp_portal_job_{:08}-part-{:04}-map'.format( self._portal_manifest.name, job.job_id, partition_id) logging.info("Data portal worker-%d is allocated map task %s for "\ "partition %d of job %d. the map task has %d files"\ "-----------------\n", rank_id, task_name, partition_id, job.job_id, len(map_fpaths)) for seq, fpath in enumerate(map_fpaths): logging.info("%d. %s", seq, fpath) logging.info("---------------------------------\n") manifset = self._sync_portal_manifest() return dp_pb.MapTask(task_name=task_name, fpaths=map_fpaths, output_base_dir=self._map_output_dir(job.job_id), output_partition_num=self._output_partition_num, partition_id=partition_id, part_field=self._get_part_field(), data_portal_type=manifset.data_portal_type) def _get_part_field(self): portal_mainifest = self._sync_portal_manifest() if portal_mainifest.data_portal_type == dp_pb.DataPortalType.PSI: return 'raw_id' assert portal_mainifest.data_portal_type == \ dp_pb.DataPortalType.Streaming return 'example_id' def _create_reduce_task(self, rank_id, partition_id): assert self._processing_job is not None job = self._processing_job job_id = job.job_id task_name = '{}-dp_portal_job_{:08}-part-{:04}-reduce'.format( self._portal_manifest.name, job_id, partition_id) logging.info("Data portal worker-%d is allocated reduce task %s for "\ "partition %d of job %d. 
the reduce base dir %s"\ "-----------------\n", rank_id, task_name, partition_id, job_id, self._reduce_output_dir(job_id)) return dp_pb.ReduceTask( task_name=task_name, map_base_dir=self._map_output_dir(job_id), reduce_base_dir=self._reduce_output_dir(job_id), partition_id=partition_id) def _try_to_alloc_part(self, rank_id, src_state, target_state): alloc_partition_id = None processing_job = self._sync_processing_job() assert processing_job is not None job_id = self._processing_job.job_id for partition_id in range(self._output_partition_num): part_job = self._sync_job_part(job_id, partition_id) if part_job.part_state == src_state and \ alloc_partition_id is None: alloc_partition_id = partition_id if part_job.part_state == target_state and \ part_job.rank_id == rank_id: alloc_partition_id = partition_id break if alloc_partition_id is None: return None part_job = self._job_part_map[alloc_partition_id] if part_job.part_state == src_state: new_job_part = dp_pb.PortalJobPart(job_id=job_id, rank_id=rank_id, partition_id=alloc_partition_id, part_state=target_state) self._update_job_part(new_job_part) return alloc_partition_id def _sync_portal_job(self, job_id): kvstore_key = common.portal_job_kvstore_key(self._portal_name, job_id) data = self._kvstore.get_data(kvstore_key) if data is not None: return text_format.Parse(data, dp_pb.DataPortalJob(), allow_unknown_field=True) return None def _sync_processing_job(self): assert self._sync_portal_manifest() is not None if self._portal_manifest.processing_job_id < 0: self._processing_job = None elif self._processing_job is None or \ (self._processing_job.job_id != self._portal_manifest.processing_job_id): job_id = self._portal_manifest.processing_job_id self._processing_job = self._sync_portal_job(job_id) assert self._processing_job is not None return self._processing_job def _update_processing_job(self, job): self._processing_job = None kvstore_key = common.portal_job_kvstore_key(self._portal_name, job.job_id) 
        # NOTE(review): the two statements below are the tail of
        # _update_processing_job(self, job) — its `def` line is above the
        # visible region. They persist the job to the kvstore and only then
        # publish it on the instance, so readers never see an unpersisted job.
        self._kvstore.set_data(kvstore_key, text_format.MessageToString(job))
        self._processing_job = job

    def _sync_portal_manifest(self):
        """Return the portal manifest, loading it from the kvstore on demand.

        The cached copy is reused once loaded; callers invalidate it by
        setting self._portal_manifest to None (see _update_portal_manifest).
        Returns None if the kvstore has no manifest yet.
        """
        if self._portal_manifest is None:
            kvstore_key = common.portal_kvstore_base_dir(self._portal_name)
            data = self._kvstore.get_data(kvstore_key)
            if data is not None:
                # allow_unknown_field tolerates manifests written by newer
                # code with extra fields.
                self._portal_manifest = \
                    text_format.Parse(data, dp_pb.DataPortalManifest(),
                                      allow_unknown_field=True)
        return self._portal_manifest

    def _update_portal_manifest(self, new_portal_manifest):
        """Persist a new manifest to the kvstore, then update the cache.

        The cache is cleared first so a failed kvstore write leaves the
        next _sync_portal_manifest to re-read authoritative state.
        """
        self._portal_manifest = None
        kvstore_key = common.portal_kvstore_base_dir(self._portal_name)
        data = text_format.MessageToString(new_portal_manifest)
        self._kvstore.set_data(kvstore_key, data)
        self._portal_manifest = new_portal_manifest

    def _launch_new_portal_job(self):
        """Create and persist a new job from the unprocessed input files.

        Returns True if a job was launched, False when no input remains.
        Must only be called when no job is currently processing (asserted).
        """
        assert self._sync_processing_job() is None
        rest_fpaths = self._list_input_dir()
        if len(rest_fpaths) == 0:
            logging.info("no file left for portal")
            return False
        rest_fpaths.sort()
        portal_mainifest = self._sync_portal_manifest()
        new_job = dp_pb.DataPortalJob(job_id=portal_mainifest.next_job_id,
                                      finished=False,
                                      fpaths=rest_fpaths)
        self._update_processing_job(new_job)
        new_portal_manifest = dp_pb.DataPortalManifest()
        new_portal_manifest.MergeFrom(portal_mainifest)
        new_portal_manifest.next_job_id += 1
        new_portal_manifest.processing_job_id = new_job.job_id
        self._update_portal_manifest(new_portal_manifest)
        # Pre-create the per-partition job-part records for the new job.
        for partition_id in range(self._output_partition_num):
            self._sync_job_part(new_job.job_id, partition_id)
        # NOTE(review): "lanuched" is a typo in the log message; left as-is
        # because doc-only edits must not change runtime strings.
        logging.info("Data Portal job %d has lanuched. %d files will be"\
                     "processed\n------------\n",
                     new_job.job_id, len(new_job.fpaths))
        for seq, fpath in enumerate(new_job.fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")
        return True

    def _list_dir_helper_oss(self, root):
        """List files under an OSS root, de-duplicating nested listings.

        A directory entry is kept only if it carries a sibling _SUCCESS tag
        or is actually a plain file; bare directories are dropped because
        their contents also appear in the flattened listing.
        """
        # oss returns a file multiple times, e.g. listdir('root') returns
        # ['folder', 'file1.txt', 'folder/file2.txt']
        # and then listdir('root/folder') returns
        # ['file2.txt']
        filenames = set(path.join(root, i) for i in gfile.ListDirectory(root))
        res = []
        for fname in filenames:
            succ = path.join(path.dirname(fname), '_SUCCESS')
            if succ in filenames or not gfile.IsDirectory(fname):
                res.append(fname)
        return res

    def _list_dir_helper(self, root):
        """Recursively list files under a non-OSS root directory."""
        filenames = list(gfile.ListDirectory(root))
        # If _SUCCESS is present, we assume there are no subdirs
        if '_SUCCESS' in filenames:
            return [path.join(root, i) for i in filenames]
        res = []
        for basename in filenames:
            fname = path.join(root, basename)
            if gfile.IsDirectory(fname):
                # 'ignore tmp dirs starting with _
                if basename.startswith('_'):
                    continue
                res += self._list_dir_helper(fname)
            else:
                res.append(fname)
        return res

    def _list_input_dir(self):
        """Scan the input dir and pick the file paths for the next job.

        Filters out tmp/hidden entries, applies the configured wildcard,
        optionally requires a _SUCCESS tag per folder, and skips files
        already processed. Selection honors _single_subfolder and the
        per-job file-count limits. Returns the chosen list of fpaths.
        """
        logging.info("List input directory, it will take some time...")
        root = self._portal_manifest.input_base_dir
        wildcard = self._portal_manifest.input_file_wildcard
        if root.startswith('oss://'):
            all_files = set(self._list_dir_helper_oss(root))
        else:
            all_files = set(self._list_dir_helper(root))
        num_ignored = 0
        num_target_files = 0
        num_new_files = 0
        by_folder = {}
        for fname in all_files:
            splits = path.split(path.relpath(fname, root))
            basename = splits[-1]
            dirnames = splits[:-1]
            # ignore files and dirs starting with _ or .
            # for example: _SUCCESS or ._SUCCESS.crc
            ignore = False
            for name in splits:
                if name.startswith('_') or name.startswith('.'):
                    ignore = True
                    break
            if ignore:
                num_ignored += 1
                continue
            # check wildcard
            if wildcard and not fnmatch(fname, wildcard):
                continue
            num_target_files += 1
            # check success tag
            if self._check_success_tag:
                succ_fname = path.join(root, *dirnames, '_SUCCESS')
                if succ_fname not in all_files:
                    continue
            if fname in self._processed_fpath:
                continue
            num_new_files += 1
            folder = path.join(*dirnames)
            if folder not in by_folder:
                by_folder[folder] = []
            by_folder[folder].append(fname)
        if not by_folder:
            rest_fpaths = []
        elif self._single_subfolder:
            # Process only the lexicographically first folder per iteration.
            rest_folder, rest_fpaths = sorted(by_folder.items(),
                                              key=lambda x: x[0])[0]
            logging.info(
                'single_subfolder is set. Only process folder %s '
                'in this iteration', rest_folder)
        else:
            rest_fpaths = []
            # Clamp the effective per-job limit when it is unset (<= 0) or
            # exceeds the hard cap, and the total backlog is over the cap.
            if (self._files_per_job_limit <= 0 or
                    self._files_per_job_limit > self._max_files_per_job) and \
                    sum([len(v) for _, v in by_folder.items()]) > \
                    self._max_files_per_job:
                logging.info(
                    "Number of files exceeds limit, processing "
                    "%d per job", self._max_files_per_job)
                self._files_per_job_limit = self._max_files_per_job
            # Take whole folders in sorted order until adding the next
            # folder would exceed the limit (never splitting a folder).
            for _, v in sorted(by_folder.items(), key=lambda x: x[0]):
                if self._files_per_job_limit and rest_fpaths and \
                        len(rest_fpaths) + len(v) > self._files_per_job_limit:
                    break
                rest_fpaths.extend(v)
        logging.info(
            'Listing %s: found %d dirs, %d files, %d tmp files ignored, '
            '%d files matching wildcard, %d new files to process. '
            'Processing %d files in this iteration.',
            root, len(by_folder), len(all_files), num_ignored,
            num_target_files, num_new_files, len(rest_fpaths))
        return rest_fpaths

    def _sync_job_part(self, job_id, partition_id):
        """Return the cached job-part for (job_id, partition_id), loading or
        initializing it from the kvstore when missing or stale.

        A partition with no kvstore record gets a fresh PortalJobPart with
        rank_id=-1 (i.e. not yet assigned to any worker).
        """
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] is None or \
                self._job_part_map[partition_id].job_id != job_id:
            kvstore_key = common.portal_job_part_kvstore_key(
                self._portal_name, job_id, partition_id)
            data = self._kvstore.get_data(kvstore_key)
            if data is None:
                self._job_part_map[partition_id] = dp_pb.PortalJobPart(
                    job_id=job_id, rank_id=-1, partition_id=partition_id)
            else:
                self._job_part_map[partition_id] = \
                    text_format.Parse(data, dp_pb.PortalJobPart(),
                                      allow_unknown_field=True)
        return self._job_part_map[partition_id]

    def _update_job_part(self, job_part):
        """Persist a changed job-part to the kvstore and refresh the cache.

        The cache slot is cleared before the write so a failed write leaves
        it to be re-synced rather than caching unconfirmed state.
        """
        partition_id = job_part.partition_id
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] != job_part:
            self._job_part_map[partition_id] = None
            kvstore_key = common.portal_job_part_kvstore_key(
                self._portal_name, job_part.job_id, partition_id)
            data = text_format.MessageToString(job_part)
            self._kvstore.set_data(kvstore_key, data)
        # NOTE(review): reconstructed from flattened source — this final
        # assignment is placed after the if-block (both paths end with the
        # cache holding job_part); confirm against upstream history.
        self._job_part_map[partition_id] = job_part

    def _check_processing_job_finished(self):
        """Finalize the current job if every partition has finished.

        Marks the job finished in the kvstore, records its fpaths as
        processed, publishes the raw data, resets processing_job_id in the
        manifest, and logs a summary. Returns True when the job is (now)
        finished, False if partitions are still running.
        """
        if not self._all_job_part_finished():
            return False
        processing_job = self._sync_processing_job()
        if not processing_job.finished:
            finished_job = dp_pb.DataPortalJob()
            finished_job.MergeFrom(self._processing_job)
            finished_job.finished = True
            self._update_processing_job(finished_job)
        # Remember every input file of this job so future scans skip them.
        for fpath in processing_job.fpaths:
            self._processed_fpath.add(fpath)
        self._processing_job = None
        self._job_part_map = {}
        portal_mainifest = self._sync_portal_manifest()
        if portal_mainifest.processing_job_id >= 0:
            self._publish_raw_data(portal_mainifest.processing_job_id)
            new_portal_manifest = dp_pb.DataPortalManifest()
            new_portal_manifest.MergeFrom(self._sync_portal_manifest())
            # -1 marks "no job currently processing" in the manifest.
            new_portal_manifest.processing_job_id = -1
            self._update_portal_manifest(new_portal_manifest)
        if processing_job is not None:
            logging.info("Data Portal job %d has finished. Processed %d "\
                         "following fpaths\n------------\n",
                         processing_job.job_id, len(processing_job.fpaths))
            for seq, fpath in enumerate(processing_job.fpaths):
                logging.info("%d. %s", seq, fpath)
            logging.info("---------------------------------\n")
        return True

    @property
    def _output_partition_num(self):
        # Number of output partitions, as declared by the portal manifest.
        return self._portal_manifest.output_partition_num

    def _is_job_part_finished(self, job_part):
        """Return True if job_part reached the terminal state for this
        portal type (kIdMapped for PSI, kEventTimeReduced for streaming)."""
        assert self._portal_manifest is not None
        if self._portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return job_part.part_state == dp_pb.PartState.kIdMapped
        return job_part.part_state == dp_pb.PartState.kEventTimeReduced

    def _map_output_dir(self, job_id):
        """Directory holding the map-stage output of the given job."""
        return common.portal_map_output_dir(
            self._portal_manifest.output_base_dir, job_id)

    def _reduce_output_dir(self, job_id):
        """Directory holding the reduce-stage output of the given job."""
        return common.portal_reduce_output_dir(
            self._portal_manifest.output_base_dir, job_id)

    def _publish_raw_data(self, job_id):
        """Publish each partition's output files of a finished job.

        PSI portals publish map-stage output (and finish the partition);
        streaming portals publish merged reduce-stage sort runs.
        """
        portal_manifest = self._sync_portal_manifest()
        output_dir = None
        if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            output_dir = common.portal_map_output_dir(
                portal_manifest.output_base_dir, job_id)
        else:
            output_dir = common.portal_reduce_output_dir(
                portal_manifest.output_base_dir, job_id)
        for partition_id in range(self._output_partition_num):
            dpath = path.join(output_dir, common.partition_repr(partition_id))
            fnames = []
            if gfile.Exists(dpath) and gfile.IsDirectory(dpath):
                fnames = [
                    f for f in gfile.ListDirectory(dpath)
                    if f.endswith(common.RawDataFileSuffix)
                ]
            publish_fpaths = []
            if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
                publish_fpaths = self._publish_psi_raw_data(
                    partition_id, dpath, fnames)
            else:
                publish_fpaths = self._publish_streaming_raw_data(
                    partition_id, dpath, fnames)
            logging.info("Data Portal Master publish %d file for partition "\
                         "%d of streaming job %d\n----------\n",
                         len(publish_fpaths), partition_id, job_id)
            for seq, fpath in enumerate(publish_fpaths):
                logging.info("%d. %s", seq, fpath)
            logging.info("------------------------------------------\n")

    def _publish_streaming_raw_data(self, partition_id, dpath, fnames):
        """Publish merged sort-run files for one partition, ordered by
        their decoded sort-run metadata. Returns the published paths."""
        metas = [
            MergedSortRunMeta.decode_sort_run_meta_from_fname(fname)
            for fname in fnames
        ]
        metas.sort()
        fpaths = [
            path.join(dpath, meta.encode_merged_sort_run_fname())
            for meta in metas
        ]
        self._publisher.publish_raw_data(partition_id, fpaths)
        return fpaths

    def _publish_psi_raw_data(self, partition_id, dpath, fnames):
        """Publish PSI output files for one partition and mark the
        partition's raw data stream finished. Returns the published paths."""
        fpaths = [path.join(dpath, fname) for fname in fnames]
        self._publisher.publish_raw_data(partition_id, fpaths)
        self._publisher.finish_raw_data(partition_id)
        return fpaths
class PortalRepartitioner(object):
    """Repartitions hourly portal input into hourly output partitions.

    Three RoutineWorkers cooperate: a sniffer that detects ready input
    hours, an executor that runs the reduce/map repartition for each ready
    hour, and a forwarder that advances the committed timestamp in etcd
    once a contiguous prefix of hours has finished. All shared state
    (_input_ready_datetime, _output_finished_datetime, _portal_manifest)
    is guarded by self._lock; both datetime lists are kept sorted and are
    maintained with bisect.
    """

    def __init__(self, etcd, portal_manifest, portal_options):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._portal_manifest = portal_manifest
        self._portal_options = portal_options
        self._publisher = RawDataPublisher(
            self._etcd, self._portal_manifest.raw_data_publish_dir)
        self._started = False
        # Sorted list of hours whose input finish-tag has been seen.
        self._input_ready_datetime = []
        # Sorted list of hours whose repartitioned output is complete.
        self._output_finished_datetime = []
        self._worker_map = {}

    def start_routine_workers(self):
        """Create and start the three routine workers (idempotent)."""
        with self._lock:
            if not self._started:
                self._worker_map = {
                    self._input_data_ready_sniffer_name():
                        RoutineWorker(self._input_data_ready_sniffer_name(),
                                      self._input_data_ready_sniff_fn,
                                      self._input_data_ready_sniff_cond, 5),
                    self._repart_executor_name():
                        RoutineWorker(self._repart_executor_name(),
                                      self._repart_execute_fn,
                                      self._repart_execute_cond, 5),
                    self._committed_datetime_forwarder_name():
                        RoutineWorker(self._committed_datetime_forwarder_name(),
                                      self._committed_datetime_forward_fn,
                                      self._committed_datetime_forward_cond, 5)
                }
                self._started = True
                for _, w in self._worker_map.items():
                    w.start_routine()
                # NOTE(review): this second `self._started = True` is
                # redundant (already set above); kept as-is in this
                # doc-only edit.
                self._started = True

    def stop_routine_workers(self):
        """Stop all routine workers; safe to call more than once.

        NOTE(review): wait_join starts as True, so the join loop runs even
        when the workers were never started — presumably harmless if
        stop_routine is idempotent, but confirm intent.
        """
        wait_join = True
        with self._lock:
            if self._started:
                wait_join = True
                self._started = False
        if wait_join:
            for w in self._worker_map.values():
                w.stop_routine()

    @classmethod
    def _input_data_ready_sniffer_name(cls):
        # Key of the sniffer worker in self._worker_map.
        return 'input-data-ready-sniffer'

    @classmethod
    def _repart_executor_name(cls):
        # Key of the executor worker in self._worker_map.
        return 'repart-executor'

    @classmethod
    def _committed_datetime_forwarder_name(cls):
        # Key of the forwarder worker in self._worker_map.
        return 'committed-datetime-forwarder'

    def _ready_input_data_count(self):
        """Number of input hours currently queued for repartitioning."""
        with self._lock:
            return len(self._input_ready_datetime)

    def _check_datetime_stale(self, date_time):
        """Return True if date_time needs no further sniffing.

        An hour is stale when it is already committed, already queued as
        input-ready, or already recorded as output-finished.
        """
        with self._lock:
            committed_datetime = common.convert_timestamp_to_datetime(
                common.trim_timestamp_by_hourly(
                    self._portal_manifest.committed_timestamp))
            if date_time > committed_datetime:
                idx = bisect.bisect_left(self._input_ready_datetime,
                                         date_time)
                if idx < len(self._input_ready_datetime) and \
                        self._input_ready_datetime[idx] == date_time:
                    return True
                idx = bisect.bisect_left(self._output_finished_datetime,
                                         date_time)
                if idx < len(self._output_finished_datetime) and \
                        self._output_finished_datetime[idx] == date_time:
                    return True
                return False
            return True

    def _check_input_data_ready(self, date_time):
        """True if the hourly finish-tag exists under the input dir."""
        finish_tag = common.encode_portal_hourly_finish_tag(
            self._portal_manifest.input_data_base_dir, date_time)
        return gfile.Exists(finish_tag)

    def _check_output_data_finish(self, date_time):
        """True if the hourly finish-tag exists under the output dir."""
        finish_tag = common.encode_portal_hourly_finish_tag(
            self._portal_manifest.output_data_base_dir, date_time)
        return gfile.Exists(finish_tag)

    def _put_input_ready_datetime(self, date_time):
        """Insert date_time into the sorted ready list if not present."""
        with self._lock:
            idx = bisect.bisect_left(self._input_ready_datetime, date_time)
            if idx == len(self._input_ready_datetime) or \
                    self._input_ready_datetime[idx] != date_time:
                self._input_ready_datetime.insert(idx, date_time)

    def _get_potral_manifest(self):
        # Thread-safe accessor for the current manifest snapshot.
        # (Method name keeps the original's "potral" spelling — public to
        # this class's workers; not renamed in a doc-only edit.)
        with self._lock:
            return self._portal_manifest

    @classmethod
    def _get_required_datetime(cls, portal_manifest):
        """Return the next hour that must be processed: the hour after the
        committed timestamp, or the begin timestamp if that is later."""
        committed_datetime = common.convert_timestamp_to_datetime(
            common.trim_timestamp_by_hourly(
                portal_manifest.committed_timestamp))
        begin_datetime = common.convert_timestamp_to_datetime(
            common.trim_timestamp_by_hourly(portal_manifest.begin_timestamp))
        if begin_datetime >= committed_datetime + timedelta(hours=1):
            return begin_datetime
        return committed_datetime + timedelta(hours=1)

    def _wakeup_input_data_ready_sniffer(self):
        self._worker_map[self._input_data_ready_sniffer_name()].wakeup()

    def _input_data_ready_sniff_fn(self):
        """Scan the next 24 hours from the required datetime, queueing each
        hour whose input finish-tag exists, then wake the executor."""
        committed_datetime = None
        date_time = None
        end_datetime = None
        with self._lock:
            date_time = self._get_required_datetime(self._portal_manifest)
            end_datetime = date_time + timedelta(days=1)
        assert date_time is not None and end_datetime is not None
        while date_time < end_datetime:
            if not self._check_datetime_stale(date_time) and \
                    self._check_input_data_ready(date_time):
                self._put_input_ready_datetime(date_time)
            date_time += timedelta(hours=1)
        if self._ready_input_data_count() > 0:
            self._wakeup_repart_executor()

    def _input_data_ready_sniff_cond(self):
        # Keep sniffing while fewer than a day's worth of hours is queued.
        return self._ready_input_data_count() < 24

    def _get_next_input_ready_datetime(self):
        """Pop through already-finished hours and return the earliest hour
        that still needs repartitioning, or None if the queue drains."""
        while self._ready_input_data_count() > 0:
            date_time = None
            with self._lock:
                if len(self._input_ready_datetime) == 0:
                    break
                date_time = self._input_ready_datetime[0]
            if not self._check_output_data_finish(date_time):
                return date_time
            # Output already exists (e.g. after restart): mark finished
            # and look at the next queued hour.
            self._transform_datetime_finished(date_time)
        return None

    def _transform_datetime_finished(self, date_time):
        """Move date_time from the input-ready list to the
        output-finished list (no-op if it is not input-ready)."""
        with self._lock:
            idx = bisect.bisect_left(self._input_ready_datetime, date_time)
            if idx == len(self._input_ready_datetime) or \
                    self._input_ready_datetime[idx] != date_time:
                return
            self._input_ready_datetime.pop(idx)
            idx = bisect.bisect_left(self._output_finished_datetime,
                                     date_time)
            if idx == len(self._output_finished_datetime) or \
                    self._output_finished_datetime[idx] != date_time:
                self._output_finished_datetime.insert(idx, date_time)

    def _update_portal_commited_timestamp(self, new_committed_datetime):
        """Commit a strictly newer committed timestamp to etcd and return
        the updated manifest (the caller installs it under the lock)."""
        new_manifest = None
        with self._lock:
            old_committed_datetime = common.convert_timestamp_to_datetime(
                common.trim_timestamp_by_hourly(
                    self._portal_manifest.committed_timestamp))
            assert new_committed_datetime > old_committed_datetime
            new_manifest = common_pb.DataJoinPortalManifest()
            new_manifest.MergeFrom(self._portal_manifest)
        assert new_manifest is not None
        new_manifest.committed_timestamp.MergeFrom(
            common.trim_timestamp_by_hourly(
                common.convert_datetime_to_timestamp(new_committed_datetime)))
        common.commit_portal_manifest(self._etcd, new_manifest)
        return new_manifest

    def _wakeup_repart_executor(self):
        self._worker_map[self._repart_executor_name()].wakeup()

    def _repart_execute_fn(self):
        """Repartition every ready hour: reduce the hourly input, map each
        item to its output partition, then record the hour finished and
        wake the forwarder. Wakes the sniffer when the queue drains."""
        while True:
            date_time = self._get_next_input_ready_datetime()
            if date_time is None:
                break
            repart_reducer = portal_reducer.PotralHourlyInputReducer(
                self._get_potral_manifest(), self._portal_options, date_time)
            repart_mapper = portal_mapper.PotralHourlyOutputMapper(
                self._get_potral_manifest(), self._portal_options, date_time)
            for item in repart_reducer.make_reducer():
                repart_mapper.map_data(item)
            repart_mapper.finish_map()
            self._transform_datetime_finished(date_time)
            self._wakeup_committed_datetime_forwarder()
        self._wakeup_input_data_ready_sniffer()

    def _repart_execute_cond(self):
        # Run the executor whenever any input hour is queued.
        with self._lock:
            return len(self._input_ready_datetime) > 0

    def _wakeup_committed_datetime_forwarder(self):
        self._worker_map[self._committed_datetime_forwarder_name()].wakeup()

    def _committed_datetime_forward_fn(self):
        """Advance the committed timestamp over the contiguous prefix of
        finished hours, publishing each hour's per-partition files first."""
        new_committed_datetime = None
        updated = False
        pub_finfos = {}
        with self._lock:
            required_datetime = self._get_required_datetime(
                self._portal_manifest)
            idx = bisect.bisect_left(self._output_finished_datetime,
                                     required_datetime)
            partition_num = self._portal_manifest.output_partition_num
            for date_time in self._output_finished_datetime[idx:]:
                # NOTE(review): as written, this assignment makes the
                # contiguity check below always false (date_time ==
                # required_datetime by construction), so the loop would
                # also commit non-contiguous hours. The check presumably
                # belongs BEFORE the assignment — confirm against the
                # upstream source before changing; kept as-is here.
                required_datetime = date_time
                if date_time != required_datetime:
                    break
                ts = common.trim_timestamp_by_hourly(
                    common.convert_datetime_to_timestamp(date_time))
                for partition_id in range(partition_num):
                    fpath = common.encode_portal_hourly_fpath(
                        self._portal_manifest.output_data_base_dir,
                        date_time, partition_id)
                    if partition_id not in pub_finfos:
                        pub_finfos[partition_id] = ([fpath], [ts])
                    else:
                        pub_finfos[partition_id][0].append(fpath)
                        pub_finfos[partition_id][1].append(ts)
                new_committed_datetime = required_datetime
                required_datetime += timedelta(hours=1)
                updated = True
        if updated:
            # Publish outside the lock, then commit the new timestamp and
            # drop the committed prefix from the finished list.
            for partition_id, (fpaths, timestamps) in pub_finfos.items():
                self._publisher.publish_raw_data(partition_id, fpaths,
                                                 timestamps)
            assert new_committed_datetime is not None
            updated_manifest = \
                self._update_portal_commited_timestamp(
                    new_committed_datetime
                )
            with self._lock:
                self._portal_manifest = updated_manifest
                skip_cnt = 0
                for date_time in self._output_finished_datetime:
                    if date_time <= new_committed_datetime:
                        skip_cnt += 1
                self._output_finished_datetime = \
                    self._output_finished_datetime[skip_cnt:]
            self._wakeup_input_data_ready_sniffer()

    def _committed_datetime_forward_cond(self):
        # Run the forwarder only when the exact next required hour is
        # already in the finished list.
        with self._lock:
            required_datetime = \
                self._get_required_datetime(self._portal_manifest)
            idx = bisect.bisect_left(self._output_finished_datetime,
                                     required_datetime)
            return idx < len(self._output_finished_datetime) and \
                self._output_finished_datetime[idx] == required_datetime