Example #1
0
 def __init__(self, options, kvstore_type, use_mock_etcd=False):
     """Assemble the RSA-PSI preprocessing pipeline for one partition.

     Builds, in order: a kvstore client plus raw-data publisher, an offload
     process pool (pre-forked, see below), the id batch fetcher, a
     role-specific RSA signer and the sort-run dumper/merger.

     Args:
         options: preprocessor options. Fields read here include
             ``raw_data_publish_dir``, ``offload_processor_number``, ``role``,
             ``rsa_key_pem``, ``batch_processor_options.max_flying_item``,
             the ``max_flying_sign_*`` / ``sign_rpc_timeout_ms`` /
             ``slow_sign_threshold`` / ``stub_fanout`` signer settings,
             ``leader_rsa_psi_signer_addr`` (follower only),
             ``writer_options``, ``sort_run_merger_read_*``,
             ``output_file_dir`` and ``partition_id``.
         kvstore_type: backend selector forwarded to ``DBClient``.
         use_mock_etcd: forwarded to ``DBClient``.
     """
     self._lock = threading.Condition()
     self._options = options
     kvstore = DBClient(kvstore_type, use_mock_etcd)
     pub_dir = self._options.raw_data_publish_dir
     self._publisher = RawDataPublisher(kvstore, pub_dir)
     self._process_pool_executor = \
             concur_futures.ProcessPoolExecutor(
                     options.offload_processor_number
                 )
     # Only the follower creates a callback submitter (see the else-branch).
     self._callback_submitter = None
     # Pre-fork the worker sub-processes by running a trivial task to
     # completion, so they exist before the follower signer launches its
     # gRPC client below. Keep this ahead of the signer construction.
     self._process_pool_executor.submit(min, 1, 2).result()
     self._id_batch_fetcher = IdBatchFetcher(kvstore, self._options)
     if self._options.role == common_pb.FLRole.Leader:
         # Leader signs ids with the RSA private key.
         private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = LeaderPsiRsaSigner(
             self._id_batch_fetcher,
             options.batch_processor_options.max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.slow_sign_threshold,
             self._process_pool_executor,
             private_key,
         )
         self._repr = 'leader-' + 'rsa_psi_preprocessor'
     else:
         # Follower holds only the public key and sends sign requests to the
         # leader at leader_rsa_psi_signer_addr; a single-thread executor is
         # handed to the signer as its callback submitter.
         public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
         self._callback_submitter = concur_futures.ThreadPoolExecutor(1)
         self._psi_rsa_signer = FollowerPsiRsaSigner(
             self._id_batch_fetcher,
             options.batch_processor_options.max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.max_flying_sign_rpc,
             self._options.sign_rpc_timeout_ms,
             self._options.slow_sign_threshold, self._options.stub_fanout,
             self._process_pool_executor, self._callback_submitter,
             public_key, self._options.leader_rsa_psi_signer_addr)
         self._repr = 'follower-' + 'rsa_psi_preprocessor'
     self._sort_run_dumper = SortRunDumper(options)
     # The merger reads sort runs back with the same writer/compression
     # settings used to write them. NOTE(review): self._merger_comparator is
     # not defined in this block; presumably a method/attribute of the class.
     self._sort_run_merger = SortRunMerger(
             dj_pb.SortRunMergerOptions(
                 merger_name='sort_run_merger_'+\
                             partition_repr(options.partition_id),
                 reader_options=dj_pb.RawDataOptions(
                     raw_data_iter=options.writer_options.output_writer,
                     compressed_type=options.writer_options.compressed_type,
                     read_ahead_size=\
                         options.sort_run_merger_read_ahead_buffer,
                     read_batch_size=\
                         options.sort_run_merger_read_batch_size
                 ),
                 writer_options=options.writer_options,
                 output_file_dir=options.output_file_dir,
                 partition_id=options.partition_id
             ),
             self._merger_comparator
         )
     self._produce_item_cnt = 0
     # (sic) "comsume": the attribute name is kept because other methods of
     # the class may reference it.
     self._comsume_item_cnt = 0
     self._started = False
 def __init__(self, kvstore, portal_name, long_running, check_success_tag):
     """Restore the portal job manager's state from ``kvstore``.

     Loads the portal manifest and the job currently being processed,
     creates the raw-data publisher, gathers every input file claimed by an
     earlier job, and then either re-checks the in-flight job or launches a
     new one.

     Args:
         kvstore: key-value store client holding the portal state.
         portal_name: name of the data portal whose state is managed.
         long_running: stored as ``self._long_running``.
         check_success_tag: stored as ``self._check_success_tag``.
     """
     self._lock = threading.Lock()
     self._kvstore = kvstore
     self._portal_name = portal_name
     self._check_success_tag = check_success_tag
     self._portal_manifest = None
     self._processing_job = None
     self._sync_portal_manifest()
     self._sync_processing_job()
     publish_dir = self._portal_manifest.raw_data_publish_dir
     self._publisher = RawDataPublisher(kvstore, publish_dir)
     self._long_running = long_running
     assert self._portal_manifest is not None
     self._processed_fpath = set()
     for expected_id in range(self._portal_manifest.next_job_id):
         history_job = self._sync_portal_job(expected_id)
         assert history_job is not None and history_job.job_id == expected_id
         self._processed_fpath.update(history_job.fpaths)
     self._job_part_map = {}
     # Two independent checks, not if/else: checking the in-flight job may
     # reset processing_job_id to a negative value, which must then lead to
     # launching a new job.
     if self._portal_manifest.processing_job_id >= 0:
         self._check_processing_job_finished()
     if self._portal_manifest.processing_job_id < 0:
         self._launch_new_portal_job()
 def __init__(self,
              options,
              etcd_name,
              etcd_addrs,
              etcd_base_dir,
              use_mock_etcd=False):
     """Assemble the RSA-PSI preprocessing pipeline backed by etcd.

     Creates the etcd client and raw-data publisher, the offload process
     pool, the id batch fetcher, a leader or follower RSA signer (chosen by
     ``options.role``), and the sort-run dumper/merger. The merger orders
     records by ``'example_id'``.

     Args:
         options: preprocessor options (publish dir, pool size, role, RSA
             PEM key, signer tuning, writer options, partition id, ...).
         etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd: forwarded to
             ``EtcdClient``.
     """
     self._lock = threading.Condition()
     self._options = options
     etcd_client = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                              use_mock_etcd)
     self._publisher = RawDataPublisher(etcd_client,
                                        options.raw_data_publish_dir)
     self._process_pool_executor = concur_futures.ProcessPoolExecutor(
         options.offload_processor_number)
     self._id_batch_fetcher = IdBatchFetcher(etcd_client, options)
     flying_limit = options.batch_processor_options.max_flying_item
     is_leader = options.role == common_pb.FLRole.Leader
     if is_leader:
         # The leader signs with the private half of the RSA key.
         signing_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = LeaderPsiRsaSigner(
             self._id_batch_fetcher, flying_limit,
             options.max_flying_sign_batch, options.slow_sign_threshold,
             self._process_pool_executor, signing_key)
         self._repr = 'leader-' + 'rsa_psi_preprocessor'
     else:
         # The follower holds the public key and calls the leader's signer.
         verify_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = FollowerPsiRsaSigner(
             self._id_batch_fetcher, flying_limit,
             options.max_flying_sign_batch, options.max_flying_sign_rpc,
             options.sign_rpc_timeout_ms, options.slow_sign_threshold,
             options.stub_fanout, self._process_pool_executor, verify_key,
             options.leader_rsa_psi_signer_addr)
         self._repr = 'follower-' + 'rsa_psi_preprocessor'
     self._sort_run_dumper = SortRunDumper(options)
     merger_name = 'sort_run_merger_' + partition_repr(options.partition_id)
     writer_opts = options.writer_options
     reader_opts = dj_pb.RawDataOptions(
         raw_data_iter=writer_opts.output_writer,
         compressed_type=writer_opts.compressed_type,
         read_ahead_size=options.sort_run_merger_read_ahead_buffer)
     merger_opts = dj_pb.SortRunMergerOptions(
         merger_name=merger_name,
         reader_options=reader_opts,
         writer_options=writer_opts,
         output_file_dir=options.output_file_dir,
         partition_id=options.partition_id)
     self._sort_run_merger = SortRunMerger(merger_opts, 'example_id')
     self._started = False
 def __init__(self, etcd, portal_manifest, portal_options):
     """Keep the handles needed to run one data portal.

     Args:
         etcd: etcd client shared with the raw-data publisher.
         portal_manifest: manifest whose ``raw_data_publish_dir`` is where
             output raw data is published.
         portal_options: portal options, stored for later use.
     """
     self._lock = threading.Lock()
     self._etcd = etcd
     self._portal_manifest = portal_manifest
     self._portal_options = portal_options
     publish_dir = portal_manifest.raw_data_publish_dir
     self._publisher = RawDataPublisher(etcd, publish_dir)
     self._started = False
     # Two separate (initially empty) lists, one per direction.
     self._input_ready_datetime, self._output_finished_datetime = [], []
     self._worker_map = {}
 def __init__(self,
              options,
              etcd_name,
              etcd_addrs,
              etcd_base_dir,
              use_mock_etcd=False):
     """Assemble the RSA-PSI pipeline, optionally with synchronous RPCs.

     When ``options.rpc_sync_mode`` is set, a thread pool of
     ``options.rpc_thread_pool_size`` workers is created and handed to the
     follower signer as its last argument; otherwise that argument is None.

     Args:
         options: preprocessor options (publish dir, pool sizes, role, RSA
             PEM key, signer tuning, rpc sync settings, ...).
         etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd: forwarded to
             ``EtcdClient``.
     """
     self._lock = threading.Condition()
     self._options = options
     etcd_client = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                              use_mock_etcd)
     self._publisher = RawDataPublisher(etcd_client,
                                        options.raw_data_publish_dir)
     self._process_pool_executor = concur_futures.ProcessPoolExecutor(
         options.offload_processor_number)
     self._thread_pool_rpc_executor = None
     if options.rpc_sync_mode:
         pool_size = options.rpc_thread_pool_size
         assert pool_size > 0
         self._thread_pool_rpc_executor = \
             concur_futures.ThreadPoolExecutor(pool_size)
     self._id_batch_fetcher = IdBatchFetcher(etcd_client, options)
     flying_limit = options.batch_processor_options.max_flying_item
     is_leader = options.role == common_pb.FLRole.Leader
     if is_leader:
         signing_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = LeaderPsiRsaSigner(
             self._id_batch_fetcher, flying_limit,
             options.max_flying_sign_batch, options.slow_sign_threshold,
             self._process_pool_executor, signing_key)
         self._repr = 'leader-' + 'rsa_psi_preprocessor'
     else:
         verify_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = FollowerPsiRsaSigner(
             self._id_batch_fetcher, flying_limit,
             options.max_flying_sign_batch, options.max_flying_sign_rpc,
             options.sign_rpc_timeout_ms, options.slow_sign_threshold,
             options.stub_fanout, self._process_pool_executor, verify_key,
             options.leader_rsa_psi_signer_addr,
             self._thread_pool_rpc_executor)
         self._repr = 'follower-' + 'rsa_psi_preprocessor'
     self._sort_run_dumper = SortRunDumper(options)
     dump_dir = self._sort_run_dumper.sort_run_dump_dir()
     self._sort_run_merger = SortRunMerger(dump_dir, options)
     self._started = False
 def __init__(self,
              options,
              etcd_name,
              etcd_addrs,
              etcd_base_dir,
              use_mock_etcd=False):
     """Assemble the RSA-PSI pipeline, loading the RSA key from a file.

     The PEM file at ``options.rsa_key_file_path`` is read through
     ``gfile``. The leader parses it as a private key and the follower as a
     public key.

     Args:
         options: preprocessor options (publish dir, pool size, role, key
             file path, leader signer address, ...).
         etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd: forwarded to
             ``EtcdClient``.
     """
     self._lock = threading.Condition()
     self._options = options
     etcd_client = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                              use_mock_etcd)
     self._publisher = RawDataPublisher(etcd_client,
                                        options.raw_data_publish_dir)
     self._process_pool_executor = concur_futures.ProcessPoolExecutor(
         options.offload_processor_number)
     self._id_batch_fetcher = IdBatchFetcher(options)
     flying_limit = options.batch_processor_options.max_flying_item
     is_leader = options.role == common_pb.FLRole.Leader
     # Both roles read the same key file; only the key type parsed differs.
     with gfile.GFile(options.rsa_key_file_path, 'rb') as key_file:
         key_pem = key_file.read()
     if is_leader:
         self._psi_rsa_signer = LeaderPsiRsaSigner(
             self._id_batch_fetcher, flying_limit,
             self._process_pool_executor,
             rsa.PrivateKey.load_pkcs1(key_pem))
         self._repr = 'leader-' + 'rsa_psi_preprocessor'
     else:
         self._psi_rsa_signer = FollowerPsiRsaSigner(
             self._id_batch_fetcher, flying_limit,
             self._process_pool_executor,
             rsa.PublicKey.load_pkcs1(key_pem),
             options.leader_rsa_psi_signer_addr)
         self._repr = 'follower-' + 'rsa_psi_preprocessor'
     self._sort_run_dumper = SortRunDumper(options)
     # NOTE(review): sort_run_dump_dir is passed without being called here;
     # confirm it is a property on this SortRunDumper version, otherwise the
     # merger receives a bound method instead of a directory path.
     self._sort_run_merger = SortRunMerger(
         self._sort_run_dumper.sort_run_dump_dir, options)
     self._worker_map = {}
     self._started = False
 def __init__(self,
              kvstore,
              portal_name,
              long_running,
              check_success_tag,
              single_subfolder,
              files_per_job_limit,
              max_files_per_job=8000,
              start_date=None,
              end_date=None):
     """Restore the portal job manager, with file-selection limits.

     Stores the file-selection settings (subfolder mode, per-job limits and
     an optional start/end date range converted by
     ``convert_to_datetime``), then reloads the manifest and processing job
     from ``kvstore``. It collects files claimed by earlier jobs, and
     re-checks the in-flight job or launches a new one. If no new job could
     be launched and the manager is not long-running, it is marked finished.
     """
     self._lock = threading.Lock()
     self._kvstore = kvstore
     self._portal_name = portal_name
     self._check_success_tag = check_success_tag
     self._single_subfolder = single_subfolder
     self._files_per_job_limit = files_per_job_limit
     self._max_files_per_job = max_files_per_job
     self._start_date = convert_to_datetime(start_date)
     self._end_date = convert_to_datetime(end_date)
     self._portal_manifest = None
     self._processing_job = None
     self._sync_portal_manifest()
     self._sync_processing_job()
     publish_dir = self._portal_manifest.raw_data_publish_dir
     self._publisher = RawDataPublisher(kvstore, publish_dir)
     self._long_running = long_running
     self._finished = False
     assert self._portal_manifest is not None
     self._processed_fpath = set()
     for expected_id in range(self._portal_manifest.next_job_id):
         old_job = self._sync_portal_job(expected_id)
         assert old_job is not None and old_job.job_id == expected_id
         self._processed_fpath.update(old_job.fpaths)
     self._job_part_map = {}
     # Two independent checks, not if/else: checking the in-flight job may
     # reset processing_job_id, after which a new job must be launched.
     if self._portal_manifest.processing_job_id >= 0:
         self._check_processing_job_finished()
     if self._portal_manifest.processing_job_id < 0:
         launched = self._launch_new_portal_job()
         if not (launched or self._long_running):
             self._finished = True
class DataPortalJobManager(object):
    """Drive data-portal jobs whose state is persisted in etcd.

    State kept in etcd and cached locally:
      * the portal manifest (key ``common.portal_etcd_base_dir``),
      * one ``DataPortalJob`` per job id (``common.portal_job_etcd_key``),
      * one ``PortalJobPart`` per (job, partition)
        (``common.portal_job_part_etcd_key``).

    On construction the manager reloads that state and remembers every
    input file already claimed by a job. It then either re-checks the job in
    progress or launches a new one. Public methods take ``self._lock``; the
    private helpers do not lock on their own.
    """

    def __init__(self, etcd, portal_name, long_running):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._portal_name = portal_name
        self._portal_manifest = None
        self._processing_job = None
        self._sync_portal_manifest()
        self._sync_processing_job()
        # Check before the first dereference of the manifest below.
        assert self._portal_manifest is not None
        self._publisher = \
            RawDataPublisher(etcd, self._portal_manifest.raw_data_publish_dir)
        self._long_running = long_running
        self._processed_fpath = set()
        for job_id in range(self._portal_manifest.next_job_id):
            job = self._sync_portal_job(job_id)
            assert job is not None and job.job_id == job_id
            self._processed_fpath.update(job.fpaths)
        self._job_part_map = {}
        # Two independent checks, not if/else: finishing the in-flight job
        # resets processing_job_id to -1, after which a new job is launched.
        if self._portal_manifest.processing_job_id >= 0:
            self._check_processing_job_finished()
        if self._portal_manifest.processing_job_id < 0:
            self._launch_new_portal_job()

    def get_portal_manifest(self):
        """Return the portal manifest, loading it from etcd if not cached."""
        with self._lock:
            return self._sync_portal_manifest()

    def alloc_task(self, rank_id):
        """Try to hand a map or reduce task to worker ``rank_id``.

        Returns:
            A ``(finished, task)`` pair. When a task is allocated, ``task``
            is a ``dp_pb.MapTask`` or ``dp_pb.ReduceTask`` and ``finished``
            is False. Otherwise ``task`` is None and ``finished`` is True
            only for a non-long-running manager with no remaining work.
        """
        with self._lock:
            self._sync_processing_job()
            if self._processing_job is None:
                return not self._long_running, None
            partition_id = self._try_to_alloc_part(rank_id,
                                                   dp_pb.PartState.kInit,
                                                   dp_pb.PartState.kIdMap)
            if partition_id is not None:
                return False, self._create_map_task(rank_id, partition_id)
            # Reduce tasks exist only for streaming portals, and only once
            # every partition has been mapped.
            if self._all_job_part_mapped() and \
                    (self._portal_manifest.data_portal_type ==
                     dp_pb.DataPortalType.Streaming):
                partition_id = self._try_to_alloc_part(
                    rank_id, dp_pb.PartState.kIdMapped,
                    dp_pb.PartState.kEventTimeReduce)
                if partition_id is not None:
                    return False, self._create_reduce_task(partition_id)
            return (not self._long_running
                    and self._all_job_part_finished()), None

    def finish_task(self, rank_id, partition_id, part_state):
        """Record that worker ``rank_id`` finished ``partition_id``.

        The report is applied only if the partition is still owned by
        ``rank_id`` and is in ``part_state``; otherwise it is ignored. The
        processing job is then checked for completion.
        """
        with self._lock:
            processing_job = self._sync_processing_job()
            if processing_job is None:
                return
            job_id = processing_job.job_id
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.rank_id == rank_id and \
                    job_part.part_state == part_state:
                if job_part.part_state == dp_pb.PartState.kIdMap:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kIdMap,
                                          dp_pb.PartState.kIdMapped)
                elif job_part.part_state == dp_pb.PartState.kEventTimeReduce:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kEventTimeReduce,
                                          dp_pb.PartState.kEventTimeReduced)
            self._check_processing_job_finished()

    def backgroup_task(self):
        """Periodic housekeeping: close a finished job and, when
        long-running, launch the next one. (Name kept as-is for callers.)
        """
        with self._lock:
            if self._sync_processing_job() is not None:
                self._check_processing_job_finished()
            if self._sync_processing_job() is None and self._long_running:
                self._launch_new_portal_job()

    def _all_job_part_mapped(self):
        """True when every partition of the processing job is past kIdMap."""
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            # Assumes PartState values are ordered (kInit < kIdMap <
            # kIdMapped < ...) — TODO confirm against the proto definition.
            if job_part.part_state <= dp_pb.PartState.kIdMap:
                return False
        return True

    def _all_job_part_finished(self):
        """True when every partition of the processing job is finished."""
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if not self._is_job_part_finished(job_part):
                return False
        return True

    def _finish_job_part(self, job_id, partition_id, src_state, target_state):
        """Move a job part from ``src_state`` to ``target_state``, unowned."""
        job_part = self._sync_job_part(job_id, partition_id)
        assert job_part is not None and job_part.part_state == src_state
        new_job_part = dp_pb.PortalJobPart()
        new_job_part.MergeFrom(job_part)
        new_job_part.part_state = target_state
        new_job_part.rank_id = -1
        self._update_job_part(new_job_part)

    @staticmethod
    def _partition_of_file(fname, partition_num):
        """Map a file basename to a partition id in ``[0, partition_num)``.

        Uses CRC32 instead of the built-in ``hash()``. String hashing is
        salted per interpreter process (PYTHONHASHSEED), so after a master
        restart the remaining partitions of an in-flight job could get a
        different file split, and files could be mapped twice or never.
        """
        import zlib  # local import keeps this class self-contained
        if not isinstance(fname, bytes):
            fname = fname.encode('utf-8')
        return (zlib.crc32(fname) & 0xffffffff) % partition_num

    def _create_map_task(self, rank_id, partition_id):
        """Build the map task for ``partition_id`` of the processing job.

        The task contains the job's files whose basename maps to
        ``partition_id``.
        """
        assert self._processing_job is not None
        job = self._processing_job
        partition_num = self._output_partition_num
        map_fpaths = [
            fpath for fpath in job.fpaths
            if self._partition_of_file(path.basename(fpath),
                                       partition_num) == partition_id
        ]
        return dp_pb.MapTask(fpaths=map_fpaths,
                             output_base_dir=self._map_output_dir(job.job_id),
                             output_partition_num=partition_num,
                             partition_id=partition_id,
                             part_field=self._get_part_field())

    def _get_part_field(self):
        """Field used to partition records: raw_id for PSI, else example_id."""
        portal_manifest = self._sync_portal_manifest()
        if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return 'raw_id'
        assert portal_manifest.data_portal_type == \
                dp_pb.DataPortalType.Streaming
        return 'example_id'

    def _create_reduce_task(self, partition_id):
        """Build the reduce task for ``partition_id`` of the processing job."""
        assert self._processing_job is not None
        job_id = self._processing_job.job_id
        return dp_pb.ReduceTask(
            map_base_dir=self._map_output_dir(job_id),
            reduce_base_dir=self._reduce_output_dir(job_id),
            partition_id=partition_id)

    def _try_to_alloc_part(self, rank_id, src_state, target_state):
        """Pick a partition for ``rank_id``, or return None.

        A partition already in ``target_state`` and owned by ``rank_id`` is
        preferred, so its task is handed out again. Otherwise the first
        partition in ``src_state`` is claimed: it is moved to
        ``target_state`` and assigned to ``rank_id``.
        """
        alloc_partition_id = None
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            part_job = self._sync_job_part(job_id, partition_id)
            if part_job.part_state == src_state and \
                    alloc_partition_id is None:
                alloc_partition_id = partition_id
            if part_job.part_state == target_state and \
                    part_job.rank_id == rank_id:
                alloc_partition_id = partition_id
                break
        if alloc_partition_id is None:
            return None
        part_job = self._job_part_map[alloc_partition_id]
        if part_job.part_state == src_state:
            new_job_part = dp_pb.PortalJobPart(job_id=job_id,
                                               rank_id=rank_id,
                                               partition_id=alloc_partition_id,
                                               part_state=target_state)
            self._update_job_part(new_job_part)
        return alloc_partition_id

    def _sync_portal_job(self, job_id):
        """Load job ``job_id`` from etcd; None if it has no record."""
        etcd_key = common.portal_job_etcd_key(self._portal_name, job_id)
        data = self._etcd.get_data(etcd_key)
        if data is not None:
            return text_format.Parse(data, dp_pb.DataPortalJob())
        return None

    def _sync_processing_job(self):
        """Refresh the cached processing job from the manifest and return it.

        Returns None when the manifest has no job in progress. The job is
        reloaded from etcd only if the cached one has a different id.
        """
        assert self._sync_portal_manifest() is not None
        if self._portal_manifest.processing_job_id < 0:
            self._processing_job = None
        elif self._processing_job is None or \
                (self._processing_job.job_id !=
                    self._portal_manifest.processing_job_id):
            job_id = self._portal_manifest.processing_job_id
            self._processing_job = self._sync_portal_job(job_id)
            assert self._processing_job is not None
        return self._processing_job

    def _update_processing_job(self, job):
        """Persist ``job`` to etcd and cache it as the processing job."""
        # Clear the cache first so a failed write does not leave it holding
        # a value that was never stored.
        self._processing_job = None
        etcd_key = common.portal_job_etcd_key(self._portal_name, job.job_id)
        self._etcd.set_data(etcd_key, text_format.MessageToString(job))
        self._processing_job = job

    def _sync_portal_manifest(self):
        """Return the cached manifest, loading it from etcd on first use."""
        if self._portal_manifest is None:
            etcd_key = common.portal_etcd_base_dir(self._portal_name)
            data = self._etcd.get_data(etcd_key)
            if data is not None:
                self._portal_manifest = \
                    text_format.Parse(data, dp_pb.DataPortalManifest())
        return self._portal_manifest

    def _update_portal_manifest(self, new_portal_manifest):
        """Persist ``new_portal_manifest`` to etcd and cache it."""
        self._portal_manifest = None
        etcd_key = common.portal_etcd_base_dir(self._portal_name)
        data = text_format.MessageToString(new_portal_manifest)
        self._etcd.set_data(etcd_key, data)
        self._portal_manifest = new_portal_manifest

    def _launch_new_portal_job(self):
        """Start a job over every input file not yet claimed by a job.

        Returns:
            True if a job was created and recorded as the processing job,
            False if there was no unprocessed input file.
        """
        assert self._sync_processing_job() is None
        rest_fpaths = sorted(fpath for fpath in self._list_input_dir()
                             if fpath not in self._processed_fpath)
        if not rest_fpaths:
            logging.info("no file left for portal")
            return False
        portal_manifest = self._sync_portal_manifest()
        new_job = dp_pb.DataPortalJob(job_id=portal_manifest.next_job_id,
                                      finished=False,
                                      fpaths=rest_fpaths)
        # The job record is written before the manifest that points at it.
        self._update_processing_job(new_job)
        new_portal_manifest = dp_pb.DataPortalManifest()
        new_portal_manifest.MergeFrom(portal_manifest)
        new_portal_manifest.next_job_id += 1
        new_portal_manifest.processing_job_id = new_job.job_id
        self._update_portal_manifest(new_portal_manifest)
        # Remember the claimed files only after both writes succeeded.
        # Without this, a long-running manager would re-include these files
        # in the next job it launches.
        self._processed_fpath.update(rest_fpaths)
        for partition_id in range(self._output_partition_num):
            self._sync_job_part(new_job.job_id, partition_id)
        return True

    def _list_input_dir(self):
        """List input file paths, filtered by the manifest's wildcard."""
        input_dir = self._portal_manifest.input_base_dir
        fnames = gfile.ListDirectory(input_dir)
        wildcard = self._portal_manifest.input_file_wildcard
        if len(wildcard) > 0:
            fnames = [f for f in fnames if fnmatch(f, wildcard)]
        return [path.join(input_dir, f) for f in fnames]

    def _sync_job_part(self, job_id, partition_id):
        """Return the cached job part, reloading it from etcd if stale.

        A part with no etcd record is cached with ``rank_id=-1`` (unowned)
        but is not written back.
        """
        cached = self._job_part_map.get(partition_id)
        if cached is None or cached.job_id != job_id:
            etcd_key = common.portal_job_part_etcd_key(self._portal_name,
                                                       job_id, partition_id)
            data = self._etcd.get_data(etcd_key)
            if data is None:
                self._job_part_map[partition_id] = dp_pb.PortalJobPart(
                    job_id=job_id, rank_id=-1, partition_id=partition_id)
            else:
                self._job_part_map[partition_id] = \
                    text_format.Parse(data, dp_pb.PortalJobPart())
        return self._job_part_map[partition_id]

    def _update_job_part(self, job_part):
        """Persist ``job_part`` to etcd if it differs from the cached copy."""
        partition_id = job_part.partition_id
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] != job_part:
            # Invalidate first so a failed write leaves no stale cache entry.
            self._job_part_map[partition_id] = None
            etcd_key = common.portal_job_part_etcd_key(self._portal_name,
                                                       job_part.job_id,
                                                       partition_id)
            data = text_format.MessageToString(job_part)
            self._etcd.set_data(etcd_key, data)
        self._job_part_map[partition_id] = job_part

    def _check_processing_job_finished(self):
        """Close the processing job if every partition is finished.

        Marks the job finished, publishes its output, and resets the
        manifest's ``processing_job_id`` to -1.

        Returns:
            True if the job was closed, False if work remains.
        """
        if not self._all_job_part_finished():
            return False
        processing_job = self._sync_processing_job()
        if not processing_job.finished:
            finished_job = dp_pb.DataPortalJob()
            finished_job.MergeFrom(processing_job)
            finished_job.finished = True
            self._update_processing_job(finished_job)
        self._processing_job = None
        self._job_part_map = {}
        portal_manifest = self._sync_portal_manifest()
        if portal_manifest.processing_job_id >= 0:
            self._publish_raw_data(portal_manifest.processing_job_id)
            new_portal_manifest = dp_pb.DataPortalManifest()
            new_portal_manifest.MergeFrom(self._sync_portal_manifest())
            new_portal_manifest.processing_job_id = -1
            self._update_portal_manifest(new_portal_manifest)
        return True

    @property
    def _output_partition_num(self):
        return self._portal_manifest.output_partition_num

    def _is_job_part_finished(self, job_part):
        """PSI parts finish at kIdMapped; streaming parts finish at
        kEventTimeReduced."""
        assert self._portal_manifest is not None
        if self._portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return job_part.part_state == dp_pb.PartState.kIdMapped
        return job_part.part_state == dp_pb.PartState.kEventTimeReduced

    def _map_output_dir(self, job_id):
        return common.portal_map_output_dir(
            self._portal_manifest.output_base_dir, job_id)

    def _reduce_output_dir(self, job_id):
        return common.portal_reduce_output_dir(
            self._portal_manifest.output_base_dir, job_id)

    def _publish_raw_data(self, job_id):
        """Publish each partition's output files of job ``job_id``.

        PSI portals publish map output; streaming portals publish reduce
        output.
        """
        portal_manifest = self._sync_portal_manifest()
        is_psi = portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI
        if is_psi:
            output_dir = common.portal_map_output_dir(
                portal_manifest.output_base_dir, job_id)
        else:
            output_dir = common.portal_reduce_output_dir(
                portal_manifest.output_base_dir, job_id)
        for partition_id in range(self._output_partition_num):
            dpath = path.join(output_dir, common.partition_repr(partition_id))
            fnames = []
            if gfile.Exists(dpath) and gfile.IsDirectory(dpath):
                fnames = [
                    f for f in gfile.ListDirectory(dpath)
                    if f.endswith(common.RawDataFileSuffix)
                ]
            if is_psi:
                self._publish_psi_raw_data(partition_id, dpath, fnames)
            else:
                self._publish_streaming_raw_data(partition_id, dpath, fnames)

    def _publish_streaming_raw_data(self, partition_id, dpath, fnames):
        """Publish merged sort-run files in sorted meta order."""
        metas = sorted(
            MergedSortRunMeta.decode_sort_run_meta_from_fname(fname)
            for fname in fnames
        )
        fpaths = [
            path.join(dpath, meta.encode_merged_sort_run_fname())
            for meta in metas
        ]
        self._publisher.publish_raw_data(partition_id, fpaths)

    def _publish_psi_raw_data(self, partition_id, dpath, fnames):
        """Publish partitioner output files and mark the partition done."""
        metas = [
            RawDataPartitioner.FileMeta.decode_meta_from_fname(fname)
            for fname in fnames
        ]
        fpaths = [
            path.join(dpath, meta.encode_meta_to_fname()) for meta in metas
        ]
        self._publisher.publish_raw_data(partition_id, fpaths)
        self._publisher.finish_raw_data(partition_id)
Example #9
0
class RsaPsiPreProcessor(object):
    """RSA-PSI preprocessor for one partition of raw ids.

    Four ``RoutineWorker`` stages run the pipeline and wake each other up:

    * id batch fetcher: pulls id batches through ``IdBatchFetcher``;
    * psi rsa signer: signs the fetched ids, with ``LeaderPsiRsaSigner``
      (leader role, RSA private key) or ``FollowerPsiRsaSigner`` (follower
      role, RSA public key and ``leader_rsa_psi_signer_addr``);
    * sort run dumper: sorts signed items by signed id and dumps sort runs;
    * sort run merger: merges all sort runs and publishes the merged files
      through ``RawDataPublisher``.

    ``wait_for_finished`` blocks until the merger stage has finished.
    """

    def __init__(self, options, kvstore_type, use_mock_etcd=False):
        # Condition: guards ``_started`` and is notified by the merger stage
        # so that ``wait_for_finished`` can return.
        self._lock = threading.Condition()
        self._options = options
        kvstore = DBClient(kvstore_type, use_mock_etcd)
        pub_dir = self._options.raw_data_publish_dir
        self._publisher = RawDataPublisher(kvstore, pub_dir)
        self._process_pool_executor = \
                concur_futures.ProcessPoolExecutor(
                        options.offload_processor_number
                    )
        self._callback_submitter = None
        # Submit a trivial task so the pool forks its worker processes now,
        # before any gRPC client is launched.
        self._process_pool_executor.submit(min, 1, 2).result()
        self._id_batch_fetcher = IdBatchFetcher(kvstore, self._options)
        if self._options.role == common_pb.FLRole.Leader:
            private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
            self._psi_rsa_signer = LeaderPsiRsaSigner(
                self._id_batch_fetcher,
                options.batch_processor_options.max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.slow_sign_threshold,
                self._process_pool_executor,
                private_key,
            )
            self._repr = 'leader-' + 'rsa_psi_preprocessor'
        else:
            public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
            # Single-thread executor handed to the follower signer; it is
            # shut down in ``wait_for_finished``.
            self._callback_submitter = concur_futures.ThreadPoolExecutor(1)
            self._psi_rsa_signer = FollowerPsiRsaSigner(
                self._id_batch_fetcher,
                options.batch_processor_options.max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.max_flying_sign_rpc,
                self._options.sign_rpc_timeout_ms,
                self._options.slow_sign_threshold, self._options.stub_fanout,
                self._process_pool_executor, self._callback_submitter,
                public_key, self._options.leader_rsa_psi_signer_addr)
            self._repr = 'follower-' + 'rsa_psi_preprocessor'
        self._sort_run_dumper = SortRunDumper(options)
        self._sort_run_merger = SortRunMerger(
                dj_pb.SortRunMergerOptions(
                    merger_name='sort_run_merger_'+\
                                partition_repr(options.partition_id),
                    reader_options=dj_pb.RawDataOptions(
                        raw_data_iter=options.writer_options.output_writer,
                        compressed_type=options.writer_options.compressed_type,
                        read_ahead_size=\
                            options.sort_run_merger_read_ahead_buffer,
                        read_batch_size=\
                            options.sort_run_merger_read_batch_size
                    ),
                    writer_options=options.writer_options,
                    output_file_dir=options.output_file_dir,
                    partition_id=options.partition_id
                ),
                self._merger_comparator
            )
        # Items fetched vs. items dumped; the difference is the number of
        # in-flight items used for back-pressure in ``_stop_fetch_id``.
        self._produce_item_cnt = 0
        self._comsume_item_cnt = 0
        self._started = False

    def start_process(self):
        """Create and start the four routine workers (idempotent)."""
        with self._lock:
            if not self._started:
                self._worker_map = {
                    self._id_batch_fetcher_name():
                    RoutineWorker(self._id_batch_fetcher_name(),
                                  self._id_batch_fetch_fn,
                                  self._id_batch_fetch_cond, 5),
                    self._psi_rsa_signer_name():
                    RoutineWorker(self._psi_rsa_signer_name(),
                                  self._psi_rsa_sign_fn,
                                  self._psi_rsa_sign_cond, 5),
                    self._sort_run_dumper_name():
                    RoutineWorker(self._sort_run_dumper_name(),
                                  self._sort_run_dump_fn,
                                  self._sort_run_dump_cond, 5),
                    self._sort_run_merger_name():
                    RoutineWorker(self._sort_run_merger_name(),
                                  self._sort_run_merge_fn,
                                  self._sort_run_merge_cond, 5)
                }
                for _, w in self._worker_map.items():
                    w.start_routine()
                self._started = True

    def stop_routine_workers(self):
        """Stop every routine worker if they are currently started.

        A no-op when ``start_process`` was never called or the workers were
        already stopped. (Previously ``wait_join`` started out True, so this
        touched the missing ``_worker_map`` or stopped workers twice.)
        """
        wait_join = False
        with self._lock:
            if self._started:
                wait_join = True
                self._started = False
        if wait_join:
            for w in self._worker_map.values():
                w.stop_routine()

    def wait_for_finished(self):
        """Block until the merger finishes, then tear everything down."""
        # Test the flag while holding the condition's lock. Otherwise a
        # notify issued between the test and ``wait()`` would be lost.
        with self._lock:
            while not self._sort_run_merger.is_merged_finished():
                self._lock.wait()
        self.stop_routine_workers()
        self._process_pool_executor.shutdown()
        if self._callback_submitter is not None:
            self._callback_submitter.shutdown()
        self._id_batch_fetcher.cleanup_visitor_meta_data()
        self._bye_for_signer()

    def _bye_for_signer(self):
        """Call ``say_signer_bye`` up to 60 times, 10s apart.

        If every attempt raises, the stack is printed and the process exits
        with ``os._exit(-1)``.
        """
        for rnd in range(60):
            try:
                self._psi_rsa_signer.say_signer_bye()
                logging.info("Success to say bye to signer at round "\
                             "%d, rsa_psi_preprocessor will exit", rnd)
                return
            except Exception as e:  # pylint: disable=broad-except
                logging.warning("Failed to say bye to signer at "\
                                "round %d, sleep 10s and retry: %s", rnd, e)
            time.sleep(10)
        logging.warning("Give up to say bye to signer after try 60 "\
                        "times, rsa_psi_preprocessor will exit as -1")
        traceback.print_stack()
        os._exit(-1)  # pylint: disable=protected-access

    def _id_batch_fetcher_name(self):
        return self._repr + ':id_batch_fetcher'

    def _wakeup_id_batch_fetcher(self):
        self._worker_map[self._id_batch_fetcher_name()].wakeup()

    def _id_batch_fetch_fn(self):
        """Fetch id batches starting where the signer needs them.

        Each batch wakes the signer. Fetching stops early under back-pressure.
        """
        next_index = self._psi_rsa_signer.get_next_index_to_fetch()
        for batch in self._id_batch_fetcher.make_processor(next_index):
            logging.debug("%s fetch batch begin at %d, len %d. wakeup %s",
                          self._id_batch_fetcher_name(), batch.begin_index,
                          len(batch), self._psi_rsa_signer_name())
            self._produce_item_cnt += len(batch)
            self._wakeup_psi_rsa_signer()
            if self._stop_fetch_id():
                break

    def _id_batch_fetch_cond(self):
        next_index = self._psi_rsa_signer.get_next_index_to_fetch()
        return self._id_batch_fetcher.need_process(next_index) and \
                not self._stop_fetch_id() and \
                not self._sort_run_dumper.is_dump_finished()

    def _stop_fetch_id(self):
        """Return True when fetching more ids should pause.

        Pauses at 5 << 20 (~5M) in-flight items, or when
        ``CheckOomRisk(..., 0.80, ...)`` reports a risk. The meaning of the
        0.80 argument is defined by ``get_heap_mem_stats``; presumably it is
        a heap-usage ratio (TODO confirm).
        """
        total_flying_item = self._produce_item_cnt - self._comsume_item_cnt
        if total_flying_item >= 5 << 20:
            logging.warning("stop fetch id since flying item "\
                            "reach to %d > 5m, produce_item_cnt: %d; "\
                            "consume_item_cnt: %d", total_flying_item,
                            self._produce_item_cnt, self._comsume_item_cnt)
            return True
        potential_mem_incr = total_flying_item * \
                             self._psi_rsa_signer.additional_item_mem_usage()
        if get_heap_mem_stats(None).CheckOomRisk(total_flying_item, 0.80,
                                                 potential_mem_incr):
            logging.warning("stop fetch id since has oom risk for 0.80, "\
                            "flying item reach to %d", total_flying_item)
            return True
        return False

    def _psi_rsa_signer_name(self):
        return self._repr + ':psi_rsa_signer'

    def _wakeup_psi_rsa_signer(self):
        self._worker_map[self._psi_rsa_signer_name()].wakeup()

    def _transmit_signed_batch(self, signed_index):
        """Evict fetched batches up to ``signed_index`` and wake the dumper."""
        evict_batch_cnt = self._id_batch_fetcher.evict_staless_item_batch(
            signed_index)
        self._psi_rsa_signer.update_next_batch_index_hint(evict_batch_cnt)
        self._wakeup_sort_run_dumper()

    def _psi_rsa_sign_fn(self):
        """Sign pending batches, handing progress on every 16 batches."""
        next_index = self._sort_run_dumper.get_next_index_to_dump()
        sign_cnt = 0
        signed_index = None
        for signed_batch in self._psi_rsa_signer.make_processor(next_index):
            sign_cnt += 1
            # The processor may yield None; only dereference real batches.
            if signed_batch is not None:
                logging.debug("%s sign batch begin at %d, len %d. wakeup %s",
                              self._psi_rsa_signer_name(),
                              signed_batch.begin_index, len(signed_batch),
                              self._sort_run_dumper_name())
                signed_index = signed_batch.begin_index + len(signed_batch) - 1
            if sign_cnt % 16 == 0:
                self._transmit_signed_batch(signed_index)
        self._transmit_signed_batch(signed_index)

    def _psi_rsa_sign_cond(self):
        next_index = self._sort_run_dumper.get_next_index_to_dump()
        return self._psi_rsa_signer.need_process(next_index) and \
                not self._sort_run_dumper.is_dump_finished()

    def _sort_run_dumper_name(self):
        return self._repr + ':sort_run_dumper'

    def _wakeup_sort_run_dumper(self):
        self._worker_map[self._sort_run_dumper_name()].wakeup()

    def _load_sorted_items_from_rsa_signer(self):
        """Collect signed items for the next sort run.

        Returns:
            ``(signed_finished, items, next_index)``. ``items`` are the
            ``(signed_id, item, index)`` tuples sorted by ``signed_id``, and
            ``next_index`` is the first index not included. About
            ``max_flying_item // 2`` items are taken; when that is <= 0, all
            available batches are taken.
        """
        sort_run_dumper = self._sort_run_dumper
        rsi_signer = self._psi_rsa_signer
        next_index = sort_run_dumper.get_next_index_to_dump()
        hint_index = None
        items_buffer = []
        signed_finished = False
        total_item_num = 0
        max_flying_item = self._options.batch_processor_options.max_flying_item
        sort_run_size = max_flying_item // 2
        while sort_run_size <= 0 or total_item_num < sort_run_size:
            signed_finished, batch, hint_index = \
                rsi_signer.fetch_item_batch_by_index(next_index, hint_index)
            if batch is None:
                break
            assert next_index == batch.begin_index
            items_buffer.extend(batch)
            next_index += len(batch)
            total_item_num += len(batch)
        sorted_items_buffer = sorted(items_buffer, key=lambda item: item[0])
        return signed_finished, sorted_items_buffer, next_index

    def _sort_run_dump_fn(self):
        """Dump one sort run, evict consumed signed batches, and run gc."""
        signed_finished, items_buffer, next_index = \
                self._load_sorted_items_from_rsa_signer()
        sort_run_dumper = self._sort_run_dumper
        if len(items_buffer) > 0:

            def producer(items_buffer):
                for signed_id, item, index in items_buffer:
                    # the signed id replaces the item's example id
                    item.set_example_id(signed_id)
                    yield signed_id, index, item

            sort_run_dumper.dump_sort_runs(producer(items_buffer))
        if next_index is not None:
            self._psi_rsa_signer.evict_staless_item_batch(next_index - 1)
        if signed_finished:
            sort_run_dumper.finish_dump_sort_run()
        dump_cnt = len(items_buffer)
        self._comsume_item_cnt += dump_cnt
        del items_buffer
        logging.warning("dump %d item in sort run, and gc %d objects.",
                        dump_cnt, gc.collect())

    def _sort_run_dump_cond(self):
        """Decide whether a sort run should be dumped now.

        True while dumping is unfinished and signing has finished, or at
        least 2 << 20 (~2M) dump candidates are ready, or more than half of
        ``max_flying_item`` candidates are ready, or fetching is throttled
        (see ``_dump_for_forward``).
        """
        sort_run_dumper = self._sort_run_dumper
        rsa_signer = self._psi_rsa_signer
        next_index = sort_run_dumper.get_next_index_to_dump()
        max_flying_item = self._options.batch_processor_options.max_flying_item
        dump_finished = sort_run_dumper.is_dump_finished()
        signed_finished = rsa_signer.get_process_finished()
        flying_item_cnt = rsa_signer.get_flying_item_count()
        flying_begin_index = rsa_signer.get_flying_begin_index()
        dump_cands_num = 0
        if flying_begin_index is not None and next_index is not None and \
                (flying_begin_index <= next_index <
                    flying_begin_index + flying_item_cnt):
            dump_cands_num = flying_item_cnt - (next_index -
                                                flying_begin_index)
        return not dump_finished and \
                (signed_finished or
                 (dump_cands_num >= (2 << 20) or
                  (max_flying_item > 2 and
                    dump_cands_num > max_flying_item // 2)) or
                  self._dump_for_forward(dump_cands_num))

    def _dump_for_forward(self, dump_cands_num):
        """Under back-pressure, dump once half the in-flight items are ready."""
        if self._stop_fetch_id():
            total_flying_item = self._produce_item_cnt - self._comsume_item_cnt
            return dump_cands_num > 0 and \
                    dump_cands_num >= total_flying_item // 2
        return False

    def _sort_run_merger_name(self):
        return self._repr + ':sort_run_merger'

    def _sort_run_merge_fn(self):
        """Merge all dumped sort runs, publish the output, and mark done."""
        sort_runs = self._sort_run_dumper.get_all_sort_runs()
        input_dir = self._sort_run_dumper.sort_run_dump_dir()
        input_fpaths = [
            os.path.join(input_dir, partition_repr(self._options.partition_id),
                         sort_run.encode_sort_run_fname())
            for sort_run in sort_runs
        ]
        output_fpaths = self._sort_run_merger.merge_sort_runs(input_fpaths)
        self._publisher.publish_raw_data(self._options.partition_id,
                                         output_fpaths)
        self._publisher.finish_raw_data(self._options.partition_id)
        self._sort_run_merger.set_merged_finished()

    def _sort_run_merge_cond(self):
        """Merge once dumping finishes; once merged, wake ``wait_for_finished``."""
        if self._sort_run_merger.is_merged_finished():
            with self._lock:
                self._lock.notify()
            return False
        return self._sort_run_dumper.is_dump_finished()

    @staticmethod
    def _merger_comparator(a, b):
        return a.example_id < b.example_id
Example #10
0
class RsaPsiPreProcessor(object):
    """RSA-PSI preprocessor for one partition, backed by an etcd client.

    Four ``RoutineWorker`` stages run the pipeline and wake each other up:
    id batch fetcher -> psi rsa signer (leader: RSA private key, follower:
    RSA public key and ``leader_rsa_psi_signer_addr``) -> sort run dumper
    -> sort run merger, which publishes the merged files through
    ``RawDataPublisher``. ``wait_for_finished`` blocks until the merger
    stage has finished.
    """

    def __init__(self,
                 options,
                 etcd_name,
                 etcd_addrs,
                 etcd_base_dir,
                 use_mock_etcd=False):
        # Condition: guards ``_started`` and is notified by the merger stage
        # so that ``wait_for_finished`` can return.
        self._lock = threading.Condition()
        self._options = options
        etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
        pub_dir = self._options.raw_data_publish_dir
        self._publisher = RawDataPublisher(etcd, pub_dir)
        self._process_pool_executor = \
                concur_futures.ProcessPoolExecutor(
                        options.offload_processor_number
                    )
        self._id_batch_fetcher = IdBatchFetcher(etcd, self._options)
        max_flying_item = options.batch_processor_options.max_flying_item
        if self._options.role == common_pb.FLRole.Leader:
            private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
            self._psi_rsa_signer = LeaderPsiRsaSigner(
                self._id_batch_fetcher,
                max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.slow_sign_threshold,
                self._process_pool_executor,
                private_key,
            )
            self._repr = 'leader-' + 'rsa_psi_preprocessor'
        else:
            public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
            self._psi_rsa_signer = FollowerPsiRsaSigner(
                self._id_batch_fetcher, max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.max_flying_sign_rpc,
                self._options.sign_rpc_timeout_ms,
                self._options.slow_sign_threshold, self._options.stub_fanout,
                self._process_pool_executor, public_key,
                self._options.leader_rsa_psi_signer_addr)
            self._repr = 'follower-' + 'rsa_psi_preprocessor'
        self._sort_run_dumper = SortRunDumper(options)
        self._sort_run_merger = SortRunMerger(
                dj_pb.SortRunMergerOptions(
                    merger_name='sort_run_merger_'+\
                                partition_repr(options.partition_id),
                    reader_options=dj_pb.RawDataOptions(
                        raw_data_iter=options.writer_options.output_writer,
                        compressed_type=options.writer_options.compressed_type,
                        read_ahead_size=\
                            options.sort_run_merger_read_ahead_buffer
                    ),
                    writer_options=options.writer_options,
                    output_file_dir=options.output_file_dir,
                    partition_id=options.partition_id
                ),
                'example_id'
            )
        self._started = False

    def start_process(self):
        """Create and start the four routine workers (idempotent)."""
        with self._lock:
            if not self._started:
                self._worker_map = {
                    self._id_batch_fetcher_name():
                    RoutineWorker(self._id_batch_fetcher_name(),
                                  self._id_batch_fetch_fn,
                                  self._id_batch_fetch_cond, 5),
                    self._psi_rsa_signer_name():
                    RoutineWorker(self._psi_rsa_signer_name(),
                                  self._psi_rsa_sign_fn,
                                  self._psi_rsa_sign_cond, 5),
                    self._sort_run_dumper_name():
                    RoutineWorker(self._sort_run_dumper_name(),
                                  self._sort_run_dump_fn,
                                  self._sort_run_dump_cond, 5),
                    self._sort_run_merger_name():
                    RoutineWorker(self._sort_run_merger_name(),
                                  self._sort_run_merge_fn,
                                  self._sort_run_merge_cond, 5)
                }
                for _, w in self._worker_map.items():
                    w.start_routine()
                self._started = True

    def stop_routine_workers(self):
        """Stop every routine worker if they are currently started.

        A no-op when ``start_process`` was never called or the workers were
        already stopped. (Previously ``wait_join`` started out True, so this
        touched the missing ``_worker_map`` or stopped workers twice.)
        """
        wait_join = False
        with self._lock:
            if self._started:
                wait_join = True
                self._started = False
        if wait_join:
            for w in self._worker_map.values():
                w.stop_routine()

    def wait_for_finished(self):
        """Block until the merger finishes, then tear everything down."""
        # Test the flag while holding the condition's lock. Otherwise a
        # notify issued between the test and ``wait()`` would be lost.
        with self._lock:
            while not self._sort_run_merger.is_merged_finished():
                self._lock.wait()
        self.stop_routine_workers()
        self._process_pool_executor.shutdown()
        self._id_batch_fetcher.cleanup_visitor_meta_data()

    def _id_batch_fetcher_name(self):
        return self._repr + ':id_batch_fetcher'

    def _wakeup_id_batch_fetcher(self):
        self._worker_map[self._id_batch_fetcher_name()].wakeup()

    def _id_batch_fetch_fn(self):
        """Fetch id batches where the signer needs them; wake the signer."""
        next_index = self._psi_rsa_signer.get_next_index_to_fetch()
        for batch in self._id_batch_fetcher.make_processor(next_index):
            logging.debug("%s fetch batch begin at %d, len %d. wakeup %s",
                          self._id_batch_fetcher_name(), batch.begin_index,
                          len(batch), self._psi_rsa_signer_name())
            self._wakeup_psi_rsa_signer()

    def _id_batch_fetch_cond(self):
        next_index = self._psi_rsa_signer.get_next_index_to_fetch()
        return self._id_batch_fetcher.need_process(next_index)

    def _psi_rsa_signer_name(self):
        return self._repr + ':psi_rsa_signer'

    def _wakeup_psi_rsa_signer(self):
        self._worker_map[self._psi_rsa_signer_name()].wakeup()

    def _psi_rsa_sign_fn(self):
        """Sign pending batches, then evict fetched batches already dumped."""
        next_index = self._sort_run_dumper.get_next_index_to_dump()
        for signed_batch in self._psi_rsa_signer.make_processor(next_index):
            logging.debug("%s sign batch begin at %d, len %d. wakeup %s",
                          self._psi_rsa_signer_name(),
                          signed_batch.begin_index, len(signed_batch),
                          self._sort_run_dumper_name())
            self._wakeup_sort_run_dumper()
        next_dump_index = self._sort_run_dumper.get_next_index_to_dump()
        # The dumper's next index may be None (``_sort_run_dump_cond`` checks
        # for it); in that case nothing is known to be dumped yet.
        if next_dump_index is None:
            return
        staless_index = next_dump_index - 1
        evict_batch_cnt = self._id_batch_fetcher.evict_staless_item_batch(
            staless_index)
        self._psi_rsa_signer.update_next_batch_index_hint(evict_batch_cnt)

    def _psi_rsa_sign_cond(self):
        next_index = self._sort_run_dumper.get_next_index_to_dump()
        return self._psi_rsa_signer.need_process(next_index)

    def _sort_run_dumper_name(self):
        return self._repr + ':sort_run_dumper'

    def _wakeup_sort_run_dumper(self):
        self._worker_map[self._sort_run_dumper_name()].wakeup()

    def _load_sorted_items_from_rsa_signer(self):
        """Collect signed items for the next sort run.

        Returns:
            ``(signed_finished, items, next_index)``. ``items`` are the
            ``(signed_id, item, index)`` tuples sorted by ``signed_id``, and
            ``next_index`` is the first index not included. About
            ``max_flying_item // 4`` items are taken; when that is <= 0
            (``max_flying_item < 4``), all available batches are taken.
            Otherwise nothing would ever be loaded and the dump could never
            finish.
        """
        sort_run_dumper = self._sort_run_dumper
        rsi_signer = self._psi_rsa_signer
        next_index = sort_run_dumper.get_next_index_to_dump()
        hint_index = None
        items_buffer = []
        signed_finished = False
        total_item_num = 0
        max_flying_item = self._options.batch_processor_options.max_flying_item
        sort_run_size = max_flying_item // 4
        while sort_run_size <= 0 or total_item_num < sort_run_size:
            signed_finished, batch, hint_index = \
                rsi_signer.fetch_item_batch_by_index(next_index, hint_index)
            if batch is None:
                break
            assert next_index == batch.begin_index
            items_buffer.extend(batch)
            next_index += len(batch)
            total_item_num += len(batch)
        sorted_items_buffer = sorted(items_buffer, key=lambda item: item[0])
        return signed_finished, sorted_items_buffer, next_index

    def _sort_run_dump_fn(self):
        """Dump one sort run and evict the consumed signed batches."""
        signed_finished, items_buffer, next_index = \
                self._load_sorted_items_from_rsa_signer()
        sort_run_dumper = self._sort_run_dumper
        if len(items_buffer) > 0:

            def producer(items_buffer):
                for signed_id, item, index in items_buffer:
                    # the signed id replaces the item's example id
                    item.set_example_id(signed_id)
                    yield signed_id, index, item

            sort_run_dumper.dump_sort_runs(producer(items_buffer))
        if next_index is not None:
            self._psi_rsa_signer.evict_staless_item_batch(next_index - 1)
        if signed_finished:
            sort_run_dumper.finish_dump_sort_run()

    def _sort_run_dump_cond(self):
        """Dump when signing finished or >= max_flying_item // 4 items are ready."""
        sort_run_dumper = self._sort_run_dumper
        rsa_signer = self._psi_rsa_signer
        next_index = sort_run_dumper.get_next_index_to_dump()
        max_flying_item = self._options.batch_processor_options.max_flying_item
        dump_finished = sort_run_dumper.is_dump_finished()
        signed_finished = rsa_signer.get_process_finished()
        flying_item_cnt = rsa_signer.get_flying_item_count()
        flying_begin_index = rsa_signer.get_flying_begin_index()
        return not dump_finished and \
                (signed_finished or
                 (flying_begin_index is not None and
                  next_index is not None and
                  (flying_begin_index <= next_index <
                      flying_begin_index + flying_item_cnt) and
                   (flying_item_cnt-(next_index-flying_begin_index) >=
                    max_flying_item // 4)))

    def _sort_run_merger_name(self):
        return self._repr + ':sort_run_merger'

    def _sort_run_merge_fn(self):
        """Merge all dumped sort runs, publish the output, and mark done."""
        sort_runs = self._sort_run_dumper.get_all_sort_runs()
        input_dir = self._sort_run_dumper.sort_run_dump_dir()
        input_fpaths = [
            os.path.join(input_dir, partition_repr(self._options.partition_id),
                         sort_run.encode_sort_run_fname())
            for sort_run in sort_runs
        ]
        output_fpaths = self._sort_run_merger.merge_sort_runs(input_fpaths)
        self._publisher.publish_raw_data(self._options.partition_id,
                                         output_fpaths)
        self._publisher.finish_raw_data(self._options.partition_id)
        self._sort_run_merger.set_merged_finished()

    def _sort_run_merge_cond(self):
        """Merge once dumping finishes; once merged, wake ``wait_for_finished``."""
        if self._sort_run_merger.is_merged_finished():
            with self._lock:
                self._lock.notify()
            return False
        return self._sort_run_dumper.is_dump_finished()
Example #11
0
class DataPortalJobManager(object):
    def __init__(self, etcd, portal_name, long_running):
        """Load the data-portal job state for ``portal_name`` from etcd.

        Args:
            etcd: client used to read the portal manifest, jobs and job
                parts (``get_data`` is called on it).
            portal_name: data portal name, used to build the portal job
                etcd keys.
            long_running: when True, ``backgroup_task`` launches a new
                portal job once no job is processing, and ``alloc_task``
                never tells workers they are finished.
        """
        self._lock = threading.Lock()
        self._etcd = etcd
        self._portal_name = portal_name
        self._portal_manifest = None
        self._processing_job = None
        self._sync_portal_manifest()
        self._sync_processing_job()
        # Validate before the manifest is dereferenced below. The assert used
        # to come after ``raw_data_publish_dir`` was already read.
        assert self._portal_manifest is not None
        self._publisher = \
            RawDataPublisher(etcd, self._portal_manifest.raw_data_publish_dir)
        self._long_running = long_running
        # Every file already assigned to a previous portal job. Presumably
        # used to skip those files when a new job is launched (TODO confirm).
        self._processed_fpath = set()
        for job_id in range(self._portal_manifest.next_job_id):
            job = self._sync_portal_job(job_id)
            assert job is not None and job.job_id == job_id
            for fpath in job.fpaths:
                self._processed_fpath.add(fpath)
        self._job_part_map = {}
        # Resume: close the processing job if it is already done, then start
        # a fresh one when none is processing.
        if self._portal_manifest.processing_job_id >= 0:
            self._check_processing_job_finished()
        if self._portal_manifest.processing_job_id < 0:
            self._launch_new_portal_job()

    def get_portal_manifest(self):
        with self._lock:
            return self._sync_portal_manifest()

    def alloc_task(self, rank_id):
        with self._lock:
            self._sync_processing_job()
            if self._processing_job is not None:
                partition_id = self._try_to_alloc_part(rank_id,
                                                       dp_pb.PartState.kInit,
                                                       dp_pb.PartState.kIdMap)
                if partition_id is not None:
                    return False, self._create_map_task(rank_id, partition_id)
                if self._all_job_part_mapped() and \
                        (self._portal_manifest.data_portal_type ==
                                dp_pb.DataPortalType.Streaming):
                    partition_id = self._try_to_alloc_part(
                        rank_id, dp_pb.PartState.kIdMapped,
                        dp_pb.PartState.kEventTimeReduce)
                    if partition_id is not None:
                        return False, self._create_reduce_task(
                            rank_id, partition_id)
                return (not self._long_running
                        and self._all_job_part_finished()), None
            return not self._long_running, None

    def finish_task(self, rank_id, partition_id, part_state):
        with self._lock:
            processing_job = self._sync_processing_job()
            if processing_job is None:
                return
            job_id = self._processing_job.job_id
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.rank_id == rank_id and \
                    job_part.part_state == part_state:
                if job_part.part_state == dp_pb.PartState.kIdMap:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kIdMap,
                                          dp_pb.PartState.kIdMapped)
                    logging.info("Data portal worker-%d finish map task "\
                                 "for partition %d of job %d",
                                 rank_id, partition_id, job_id)
                elif job_part.part_state == dp_pb.PartState.kEventTimeReduce:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kEventTimeReduce,
                                          dp_pb.PartState.kEventTimeReduced)
                    logging.info("Data portal worker-%d finish reduce task "\
                                 "for partition %d of job %d",
                                 rank_id, partition_id, job_id)
            self._check_processing_job_finished()

    def backgroup_task(self):
        with self._lock:
            if self._sync_processing_job() is not None:
                self._check_processing_job_finished()
            if self._sync_processing_job() is None and self._long_running:
                self._launch_new_portal_job()

    def _all_job_part_mapped(self):
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.part_state <= dp_pb.PartState.kIdMap:
                return False
        return True

    def _all_job_part_finished(self):
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = self._processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if not self._is_job_part_finished(job_part):
                return False
        return True

    def _finish_job_part(self, job_id, partition_id, src_state, target_state):
        """Move a job part from ``src_state`` to ``target_state``.

        The part's ``rank_id`` is reset to -1, releasing it from its worker.
        """
        current_part = self._sync_job_part(job_id, partition_id)
        assert current_part is not None and \
                current_part.part_state == src_state
        updated_part = dp_pb.PortalJobPart()
        updated_part.MergeFrom(current_part)
        updated_part.part_state = target_state
        updated_part.rank_id = -1
        self._update_job_part(updated_part)

    def _create_map_task(self, rank_id, partition_id):
        """Build the map task of ``partition_id`` for the processing job.

        Input files are assigned to partitions by CRC32 of the file path
        modulo ``_output_partition_num``. The builtin ``hash`` was used
        before, but ``hash(str)`` is salted per interpreter process
        (PYTHONHASHSEED). A restarted master could therefore hand a
        re-allocated partition a different file set, mapping some files
        twice and others never.

        Args:
            rank_id: worker receiving the task (used for logging).
            partition_id: output partition the task maps.

        Returns:
            A ``dp_pb.MapTask`` describing the files and output location.
        """
        import zlib  # stable, process-independent hash for file assignment
        assert self._processing_job is not None
        job = self._processing_job
        partition_num = self._output_partition_num
        map_fpaths = [
            fpath for fpath in job.fpaths
            if (zlib.crc32(fpath.encode('utf-8')) & 0xffffffff) %
            partition_num == partition_id
        ]
        task_name = '{}-dp_portal_job_{:08}-part-{:04}-map'.format(
            self._portal_manifest.name, job.job_id, partition_id)
        logging.info("Data portal worker-%d is allocated map task %s for "\
                     "partition %d of job %d. the map task has %d files"\
                     "-----------------\n", rank_id, task_name,
                     partition_id, job.job_id, len(map_fpaths))
        for seq, fpath in enumerate(map_fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")
        manifest = self._sync_portal_manifest()
        return dp_pb.MapTask(task_name=task_name,
                             fpaths=map_fpaths,
                             output_base_dir=self._map_output_dir(job.job_id),
                             output_partition_num=partition_num,
                             partition_id=partition_id,
                             part_field=self._get_part_field(),
                             data_portal_type=manifest.data_portal_type)

    def _get_part_field(self):
        """Return the record field used to partition the input data.

        PSI portals partition on ``raw_id`` and Streaming portals on
        ``example_id``. Any other portal type fails the assertion.
        """
        portal_type = self._sync_portal_manifest().data_portal_type
        if portal_type == dp_pb.DataPortalType.PSI:
            return 'raw_id'
        assert portal_type == dp_pb.DataPortalType.Streaming
        return 'example_id'

    def _create_reduce_task(self, rank_id, partition_id):
        """Build the reduce task of ``partition_id`` for the processing job.

        The task reads from the job's map output directory and writes to its
        reduce output directory.
        """
        assert self._processing_job is not None
        job_id = self._processing_job.job_id
        map_dir = self._map_output_dir(job_id)
        reduce_dir = self._reduce_output_dir(job_id)
        task_name = '{}-dp_portal_job_{:08}-part-{:04}-reduce'.format(
            self._portal_manifest.name, job_id, partition_id)
        logging.info("Data portal worker-%d is allocated reduce task %s for "\
                     "partition %d of job %d. the reduce base dir %s"\
                     "-----------------\n", rank_id, task_name,
                     partition_id, job_id, reduce_dir)
        return dp_pb.ReduceTask(task_name=task_name,
                                map_base_dir=map_dir,
                                reduce_base_dir=reduce_dir,
                                partition_id=partition_id)

    def _try_to_alloc_part(self, rank_id, src_state, target_state):
        """Pick a partition of the processing job for worker ``rank_id``.

        A partition already in ``target_state`` and owned by ``rank_id`` is
        returned as-is (presumably so a worker asking again gets its task
        back). Otherwise the lowest partition in ``src_state`` is claimed by
        persisting it as ``target_state`` under ``rank_id``.

        Returns:
            The partition id, or None when no partition qualifies.
        """
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = self._processing_job.job_id
        first_free = None
        owned = None
        for pid in range(self._output_partition_num):
            part = self._sync_job_part(job_id, pid)
            if first_free is None and part.part_state == src_state:
                first_free = pid
            if part.part_state == target_state and part.rank_id == rank_id:
                owned = pid
                break
        chosen = owned if owned is not None else first_free
        if chosen is None:
            return None
        if self._job_part_map[chosen].part_state == src_state:
            self._update_job_part(
                dp_pb.PortalJobPart(job_id=job_id,
                                    rank_id=rank_id,
                                    partition_id=chosen,
                                    part_state=target_state))
        return chosen

    def _sync_portal_job(self, job_id):
        """Load DataPortalJob ``job_id`` from etcd; None if it is absent."""
        data = self._etcd.get_data(
            common.portal_job_etcd_key(self._portal_name, job_id))
        if data is None:
            return None
        return text_format.Parse(data, dp_pb.DataPortalJob())

    def _sync_processing_job(self):
        """Refresh and return the cached processing job (None when idle).

        The manifest's ``processing_job_id`` is authoritative: a negative id
        clears the cache, and a cached job with another id is reloaded from
        etcd (it must exist there).
        """
        assert self._sync_portal_manifest() is not None
        wanted_id = self._portal_manifest.processing_job_id
        if wanted_id < 0:
            self._processing_job = None
            return self._processing_job
        cached = self._processing_job
        if cached is None or cached.job_id != wanted_id:
            self._processing_job = self._sync_portal_job(wanted_id)
            assert self._processing_job is not None
        return self._processing_job

    def _update_processing_job(self, job):
        """Persist ``job`` to etcd, then cache it as the processing job.

        The cache is cleared first, so if the write raises the cached job is
        left as None rather than stale.
        """
        self._processing_job = None
        key = common.portal_job_etcd_key(self._portal_name, job.job_id)
        serialized = text_format.MessageToString(job)
        self._etcd.set_data(key, serialized)
        self._processing_job = job

    def _sync_portal_manifest(self):
        """Return the portal manifest, loading it from etcd on first use.

        Returns None while nothing is cached and etcd holds no manifest.
        """
        if self._portal_manifest is not None:
            return self._portal_manifest
        data = self._etcd.get_data(
            common.portal_etcd_base_dir(self._portal_name))
        if data is not None:
            self._portal_manifest = text_format.Parse(
                data, dp_pb.DataPortalManifest())
        return self._portal_manifest

    def _update_portal_manifest(self, new_portal_manifest):
        """Write ``new_portal_manifest`` to etcd and cache it.

        The cached manifest is None while the write is in flight.
        """
        self._portal_manifest = None
        serialized = text_format.MessageToString(new_portal_manifest)
        self._etcd.set_data(common.portal_etcd_base_dir(self._portal_name),
                            serialized)
        self._portal_manifest = new_portal_manifest

    def _launch_new_portal_job(self):
        """Start a portal job over every input file not processed yet.

        Lists the input dir, drops paths in ``self._processed_fpath``,
        persists a DataPortalJob whose id is the manifest's ``next_job_id``,
        bumps ``next_job_id``, points ``processing_job_id`` at the new job and
        loads its job parts. Logs and returns early when nothing is left.
        """
        assert self._sync_processing_job() is None
        rest_fpaths = sorted(fpath for fpath in self._list_input_dir()
                             if fpath not in self._processed_fpath)
        if not rest_fpaths:
            logging.info("no file left for portal")
            return
        portal_manifest = self._sync_portal_manifest()
        new_job = dp_pb.DataPortalJob(job_id=portal_manifest.next_job_id,
                                      finished=False,
                                      fpaths=rest_fpaths)
        self._update_processing_job(new_job)
        new_portal_manifest = dp_pb.DataPortalManifest()
        new_portal_manifest.MergeFrom(portal_manifest)
        new_portal_manifest.next_job_id += 1
        new_portal_manifest.processing_job_id = new_job.job_id
        self._update_portal_manifest(new_portal_manifest)
        for partition_id in range(self._output_partition_num):
            self._sync_job_part(new_job.job_id, partition_id)
        # The two literals are concatenated, so the trailing space matters.
        logging.info("Data Portal job %d has launched. %d files will be "
                     "processed\n------------\n",
                     new_job.job_id, len(new_job.fpaths))
        for seq, fpath in enumerate(new_job.fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")

    def _list_input_dir(self):
        """Recursively list input files under the manifest's input_base_dir.

        Directories are walked breadth-first. A file is kept when the
        manifest's ``input_file_wildcard`` is empty or the file's basename
        matches it (fnmatch).

        Returns:
            List of matching file paths in breadth-first discovery order.
        """
        from collections import deque
        all_inputs = []
        wildcard = self._portal_manifest.input_file_wildcard
        # deque.popleft is O(1); slicing a list to drop its head copies the
        # whole pending queue each step.
        pending = deque([self._portal_manifest.input_base_dir])
        while pending:
            fdir = pending.popleft()
            for fname in gfile.ListDirectory(fdir):
                fpath = path.join(fdir, fname)
                if gfile.IsDirectory(fpath):
                    pending.append(fpath)
                elif not wildcard or fnmatch(fname, wildcard):
                    all_inputs.append(fpath)
        return all_inputs

    def _sync_job_part(self, job_id, partition_id):
        """Return the cached PortalJobPart of ``partition_id`` in ``job_id``.

        The entry is reloaded from etcd when missing, cleared, or belonging to
        another job. A part absent from etcd is created in memory only, as
        unassigned (rank_id -1); nothing is written back here.
        """
        cached = self._job_part_map.get(partition_id)
        if cached is None or cached.job_id != job_id:
            key = common.portal_job_part_etcd_key(self._portal_name,
                                                  job_id, partition_id)
            data = self._etcd.get_data(key)
            if data is not None:
                cached = text_format.Parse(data, dp_pb.PortalJobPart())
            else:
                cached = dp_pb.PortalJobPart(job_id=job_id, rank_id=-1,
                                             partition_id=partition_id)
            self._job_part_map[partition_id] = cached
        return self._job_part_map[partition_id]

    def _update_job_part(self, job_part):
        """Write ``job_part`` to etcd unless the cached copy equals it.

        Either way the part ends up cached under its partition id.
        """
        partition_id = job_part.partition_id
        unchanged = (partition_id in self._job_part_map and
                     self._job_part_map[partition_id] == job_part)
        if not unchanged:
            self._job_part_map[partition_id] = None
            key = common.portal_job_part_etcd_key(self._portal_name,
                                                  job_part.job_id,
                                                  partition_id)
            self._etcd.set_data(key, text_format.MessageToString(job_part))
        self._job_part_map[partition_id] = job_part

    def _check_processing_job_finished(self):
        """Finalize the processing job once every job part is finished.

        Marks the job finished in etcd, records its input files as processed,
        clears the cached job and parts, publishes the job's output and resets
        the manifest's ``processing_job_id`` to -1.

        Returns:
            False if some part is still unfinished, True otherwise.
        """
        if not self._all_job_part_finished():
            return False
        processing_job = self._sync_processing_job()
        if not processing_job.finished:
            finished_job = dp_pb.DataPortalJob()
            finished_job.MergeFrom(self._processing_job)
            finished_job.finished = True
            self._update_processing_job(finished_job)
        # _launch_new_portal_job filters candidates by _processed_fpath;
        # without recording them here the finished job's files would be
        # picked up and processed again by the next job.
        for fpath in processing_job.fpaths:
            self._processed_fpath.add(fpath)
        self._processing_job = None
        self._job_part_map = {}
        portal_manifest = self._sync_portal_manifest()
        if portal_manifest.processing_job_id >= 0:
            self._publish_raw_data(portal_manifest.processing_job_id)
            new_portal_manifest = dp_pb.DataPortalManifest()
            new_portal_manifest.MergeFrom(self._sync_portal_manifest())
            new_portal_manifest.processing_job_id = -1
            self._update_portal_manifest(new_portal_manifest)
        logging.info("Data Portal job %d has finished. Processed %d "
                     "following fpaths\n------------\n",
                     processing_job.job_id, len(processing_job.fpaths))
        for seq, fpath in enumerate(processing_job.fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")
        return True

    @property
    def _output_partition_num(self):
        """Number of output partitions configured in the portal manifest."""
        manifest = self._portal_manifest
        return manifest.output_partition_num

    def _is_job_part_finished(self, job_part):
        """Tell whether ``job_part`` reached its portal type's final state.

        The final state is kIdMapped for PSI portals and kEventTimeReduced
        for every other portal type.
        """
        assert self._portal_manifest is not None
        is_psi = (self._portal_manifest.data_portal_type ==
                  dp_pb.DataPortalType.PSI)
        done_state = (dp_pb.PartState.kIdMapped if is_psi
                      else dp_pb.PartState.kEventTimeReduced)
        return job_part.part_state == done_state

    def _map_output_dir(self, job_id):
        """Directory holding the map-phase output of job ``job_id``."""
        base_dir = self._portal_manifest.output_base_dir
        return common.portal_map_output_dir(base_dir, job_id)

    def _reduce_output_dir(self, job_id):
        """Directory holding the reduce-phase output of job ``job_id``."""
        base_dir = self._portal_manifest.output_base_dir
        return common.portal_reduce_output_dir(base_dir, job_id)

    def _publish_raw_data(self, job_id):
        """Publish the output files of job ``job_id`` for every partition.

        PSI portals publish the map output directory (and finish each
        partition, see ``_publish_psi_raw_data``); other portals publish the
        reduce output directory ordered by merged sort-run metadata. Only
        files ending in ``common.RawDataFileSuffix`` are considered; a
        missing partition directory publishes an empty list.
        """
        portal_manifest = self._sync_portal_manifest()
        is_psi = (portal_manifest.data_portal_type ==
                  dp_pb.DataPortalType.PSI)
        if is_psi:
            output_dir = common.portal_map_output_dir(
                portal_manifest.output_base_dir, job_id)
        else:
            output_dir = common.portal_reduce_output_dir(
                portal_manifest.output_base_dir, job_id)
        for partition_id in range(self._output_partition_num):
            dpath = path.join(output_dir, common.partition_repr(partition_id))
            fnames = []
            if gfile.Exists(dpath) and gfile.IsDirectory(dpath):
                fnames = [f for f in gfile.ListDirectory(dpath)
                          if f.endswith(common.RawDataFileSuffix)]
            if is_psi:
                publish_fpaths = self._publish_psi_raw_data(
                    partition_id, dpath, fnames)
            else:
                publish_fpaths = self._publish_streaming_raw_data(
                    partition_id, dpath, fnames)
            # The portal may be PSI, so the message must not say "streaming".
            logging.info("Data Portal Master publish %d file for partition "
                         "%d of job %d\n----------\n",
                         len(publish_fpaths), partition_id, job_id)
            for seq, fpath in enumerate(publish_fpaths):
                logging.info("%d. %s", seq, fpath)
            logging.info("------------------------------------------\n")

    def _publish_streaming_raw_data(self, partition_id, dpath, fnames):
        """Publish a Streaming partition's merged sort-run files in order.

        File names are decoded into MergedSortRunMeta, sorted, and re-encoded
        into paths under ``dpath`` before being handed to the publisher.

        Returns:
            The published file paths, in publish order.
        """
        sorted_metas = sorted(
            MergedSortRunMeta.decode_sort_run_meta_from_fname(fname)
            for fname in fnames)
        fpaths = []
        for meta in sorted_metas:
            encoded = meta.encode_merged_sort_run_fname()
            fpaths.append(path.join(dpath, encoded))
        self._publisher.publish_raw_data(partition_id, fpaths)
        return fpaths

    def _publish_psi_raw_data(self, partition_id, dpath, fnames):
        """Publish a PSI partition's files and mark the partition finished.

        Returns:
            The published file paths (``fnames`` joined onto ``dpath``).
        """
        publisher = self._publisher
        fpaths = list(map(lambda fname: path.join(dpath, fname), fnames))
        publisher.publish_raw_data(partition_id, fpaths)
        publisher.finish_raw_data(partition_id)
        return fpaths
class DataPortalJobManager(object):
    """Master-side scheduler of data portal jobs, persisted in a kvstore.

    A portal job snapshots the not-yet-processed input files. Each of the
    manifest's ``output_partition_num`` partitions is handed out as a map
    task and, for Streaming portals, afterwards as a reduce task. Manifest,
    job and job-part state is written to ``kvstore`` and reloaded in
    ``__init__``. All public methods hold ``self._lock``.
    """
    def __init__(self,
                 kvstore,
                 portal_name,
                 long_running,
                 check_success_tag,
                 single_subfolder,
                 files_per_job_limit,
                 max_files_per_job=8000):
        """Load portal state from ``kvstore`` and launch a job if idle.

        Args:
            kvstore: client exposing ``get_data``/``set_data``.
            portal_name: name from which every kvstore key is derived.
            long_running: if False, the manager is marked finished once a
                launch finds no new input.
            check_success_tag: only take files whose folder has _SUCCESS.
            single_subfolder: take only one sub-folder per job.
            files_per_job_limit: per-job file cap (whole folders are taken,
                the first one always); <= 0 means no cap unless the input
                exceeds ``max_files_per_job``.
            max_files_per_job: cap enforced when the input exceeds it.
        """
        self._lock = threading.Lock()
        self._kvstore = kvstore
        self._portal_name = portal_name
        self._check_success_tag = check_success_tag
        self._single_subfolder = single_subfolder
        self._files_per_job_limit = files_per_job_limit
        self._max_files_per_job = max_files_per_job
        self._portal_manifest = None
        self._processing_job = None
        self._sync_portal_manifest()
        # Check before the manifest is dereferenced below.
        assert self._portal_manifest is not None
        self._sync_processing_job()
        self._publisher = RawDataPublisher(
            kvstore, self._portal_manifest.raw_data_publish_dir)
        self._long_running = long_running
        self._finished = False
        # Every file of every job ever launched counts as processed, so a
        # restarted master never schedules it again.
        self._processed_fpath = set()
        for job_id in range(0, self._portal_manifest.next_job_id):
            job = self._sync_portal_job(job_id)
            assert job is not None and job.job_id == job_id
            self._processed_fpath.update(job.fpaths)
        self._job_part_map = {}
        if self._portal_manifest.processing_job_id >= 0:
            self._check_processing_job_finished()
        if self._portal_manifest.processing_job_id < 0:
            if not self._launch_new_portal_job() and not self._long_running:
                self._finished = True

    def get_portal_manifest(self):
        """Return the (cached) portal manifest."""
        with self._lock:
            return self._sync_portal_manifest()

    def alloc_task(self, rank_id):
        """Allocate the next task of the processing job to ``rank_id``.

        Map tasks come first; for Streaming portals reduce tasks are handed
        out once every partition is mapped.

        Returns:
            ``(finished, task)``: ``task`` is a dp_pb.MapTask, a
            dp_pb.ReduceTask or None. ``finished`` is True only when the
            manager is marked finished and, while a job is registered, all
            of its parts are finished.
        """
        with self._lock:
            self._sync_processing_job()
            if self._processing_job is not None:
                partition_id = self._try_to_alloc_part(rank_id,
                                                       dp_pb.PartState.kInit,
                                                       dp_pb.PartState.kIdMap)
                if partition_id is not None:
                    return False, self._create_map_task(rank_id, partition_id)
                if self._all_job_part_mapped() and \
                        (self._portal_manifest.data_portal_type ==
                                dp_pb.DataPortalType.Streaming):
                    partition_id = self._try_to_alloc_part(
                        rank_id, dp_pb.PartState.kIdMapped,
                        dp_pb.PartState.kEventTimeReduce)
                    if partition_id is not None:
                        return False, self._create_reduce_task(
                            rank_id, partition_id)
                return (self._finished and self._all_job_part_finished()), None
            return self._finished, None

    def finish_task(self, rank_id, partition_id, part_state):
        """Record that ``rank_id`` completed its task on ``partition_id``.

        The report is applied only if the part is still owned by ``rank_id``
        and in ``part_state`` (kIdMap -> kIdMapped, kEventTimeReduce ->
        kEventTimeReduced); afterwards the job is finalized if complete.
        """
        with self._lock:
            processing_job = self._sync_processing_job()
            if processing_job is None:
                return
            job_id = self._processing_job.job_id
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.rank_id == rank_id and \
                    job_part.part_state == part_state:
                if job_part.part_state == dp_pb.PartState.kIdMap:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kIdMap,
                                          dp_pb.PartState.kIdMapped)
                    logging.info("Data portal worker-%d finish map task "
                                 "for partition %d of job %d",
                                 rank_id, partition_id, job_id)
                elif job_part.part_state == dp_pb.PartState.kEventTimeReduce:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kEventTimeReduce,
                                          dp_pb.PartState.kEventTimeReduced)
                    logging.info("Data portal worker-%d finish reduce task "
                                 "for partition %d of job %d",
                                 rank_id, partition_id, job_id)
            self._check_processing_job_finished()

    def backgroup_task(self):
        """Periodic hook: finalize a completed job and launch the next one.

        When no new job can be launched and the manager is not long-running,
        it is marked finished.
        """
        with self._lock:
            if self._sync_processing_job() is not None:
                self._check_processing_job_finished()
            if self._sync_processing_job() is None:
                success = self._launch_new_portal_job()
                if not success and not self._long_running:
                    self._finished = True

    def _all_job_part_mapped(self):
        """True when every part of the processing job is past kIdMap."""
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.part_state <= dp_pb.PartState.kIdMap:
                return False
        return True

    def _all_job_part_finished(self):
        """True when every part of the processing job is in its final state."""
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = self._processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if not self._is_job_part_finished(job_part):
                return False
        return True

    def _finish_job_part(self, job_id, partition_id, src_state, target_state):
        """Move a part from ``src_state`` to ``target_state`` and unassign it."""
        job_part = self._sync_job_part(job_id, partition_id)
        assert job_part is not None and job_part.part_state == src_state
        new_job_part = dp_pb.PortalJobPart()
        new_job_part.MergeFrom(job_part)
        new_job_part.part_state = target_state
        new_job_part.rank_id = -1
        self._update_job_part(new_job_part)

    def _create_map_task(self, rank_id, partition_id):
        """Build the MapTask of ``partition_id`` for worker ``rank_id``.

        Input files of the processing job are spread over partitions by the
        CRC32 of their UTF-8 path.
        """
        import zlib
        assert self._processing_job is not None
        job = self._processing_job
        # The builtin hash() of str is salted per process (PYTHONHASHSEED),
        # so a restarted master would assign the files of a half-done job to
        # different partitions, skipping some and repeating others. CRC32 is
        # identical across processes.
        map_fpaths = [
            fpath for fpath in job.fpaths
            if zlib.crc32(fpath.encode('utf-8')) %
            self._output_partition_num == partition_id
        ]
        task_name = '{}-dp_portal_job_{:08}-part-{:04}-map'.format(
            self._portal_manifest.name, job.job_id, partition_id)
        logging.info("Data portal worker-%d is allocated map task %s for "
                     "partition %d of job %d. the map task has %d files"
                     "-----------------\n", rank_id, task_name,
                     partition_id, job.job_id, len(map_fpaths))
        for seq, fpath in enumerate(map_fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")
        manifset = self._sync_portal_manifest()
        return dp_pb.MapTask(task_name=task_name,
                             fpaths=map_fpaths,
                             output_base_dir=self._map_output_dir(job.job_id),
                             output_partition_num=self._output_partition_num,
                             partition_id=partition_id,
                             part_field=self._get_part_field(),
                             data_portal_type=manifset.data_portal_type)

    def _get_part_field(self):
        """Partition field: ``raw_id`` for PSI, ``example_id`` for Streaming."""
        portal_mainifest = self._sync_portal_manifest()
        if portal_mainifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return 'raw_id'
        assert portal_mainifest.data_portal_type == \
                dp_pb.DataPortalType.Streaming
        return 'example_id'

    def _create_reduce_task(self, rank_id, partition_id):
        """Build the ReduceTask of ``partition_id`` for worker ``rank_id``."""
        assert self._processing_job is not None
        job = self._processing_job
        job_id = job.job_id
        task_name = '{}-dp_portal_job_{:08}-part-{:04}-reduce'.format(
            self._portal_manifest.name, job_id, partition_id)
        logging.info("Data portal worker-%d is allocated reduce task %s for "
                     "partition %d of job %d. the reduce base dir %s"
                     "-----------------\n", rank_id, task_name,
                     partition_id, job_id, self._reduce_output_dir(job_id))
        return dp_pb.ReduceTask(
            task_name=task_name,
            map_base_dir=self._map_output_dir(job_id),
            reduce_base_dir=self._reduce_output_dir(job_id),
            partition_id=partition_id)

    def _try_to_alloc_part(self, rank_id, src_state, target_state):
        """Pick a partition of the processing job for ``rank_id``.

        A part already in ``target_state`` owned by ``rank_id`` wins;
        otherwise the lowest part in ``src_state`` is claimed by persisting it
        as ``target_state`` under ``rank_id``. Returns None if none qualifies.
        """
        alloc_partition_id = None
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = self._processing_job.job_id
        for partition_id in range(self._output_partition_num):
            part_job = self._sync_job_part(job_id, partition_id)
            if part_job.part_state == src_state and \
                    alloc_partition_id is None:
                alloc_partition_id = partition_id
            if part_job.part_state == target_state and \
                    part_job.rank_id == rank_id:
                alloc_partition_id = partition_id
                break
        if alloc_partition_id is None:
            return None
        part_job = self._job_part_map[alloc_partition_id]
        if part_job.part_state == src_state:
            new_job_part = dp_pb.PortalJobPart(job_id=job_id,
                                               rank_id=rank_id,
                                               partition_id=alloc_partition_id,
                                               part_state=target_state)
            self._update_job_part(new_job_part)
        return alloc_partition_id

    def _sync_portal_job(self, job_id):
        """Load DataPortalJob ``job_id`` from the kvstore; None if absent."""
        kvstore_key = common.portal_job_kvstore_key(self._portal_name, job_id)
        data = self._kvstore.get_data(kvstore_key)
        if data is not None:
            return text_format.Parse(data,
                                     dp_pb.DataPortalJob(),
                                     allow_unknown_field=True)
        return None

    def _sync_processing_job(self):
        """Refresh the cached processing job from the manifest's job id."""
        assert self._sync_portal_manifest() is not None
        if self._portal_manifest.processing_job_id < 0:
            self._processing_job = None
        elif self._processing_job is None or \
                (self._processing_job.job_id !=
                    self._portal_manifest.processing_job_id):
            job_id = self._portal_manifest.processing_job_id
            self._processing_job = self._sync_portal_job(job_id)
            assert self._processing_job is not None
        return self._processing_job

    def _update_processing_job(self, job):
        """Persist ``job`` and cache it as the processing job."""
        self._processing_job = None
        kvstore_key = common.portal_job_kvstore_key(self._portal_name,
                                                    job.job_id)
        self._kvstore.set_data(kvstore_key, text_format.MessageToString(job))
        self._processing_job = job

    def _sync_portal_manifest(self):
        """Return the manifest, loading it from the kvstore on first use."""
        if self._portal_manifest is None:
            kvstore_key = common.portal_kvstore_base_dir(self._portal_name)
            data = self._kvstore.get_data(kvstore_key)
            if data is not None:
                self._portal_manifest = \
                    text_format.Parse(data, dp_pb.DataPortalManifest(),
                                      allow_unknown_field=True)
        return self._portal_manifest

    def _update_portal_manifest(self, new_portal_manifest):
        """Persist ``new_portal_manifest`` and cache it."""
        self._portal_manifest = None
        kvstore_key = common.portal_kvstore_base_dir(self._portal_name)
        data = text_format.MessageToString(new_portal_manifest)
        self._kvstore.set_data(kvstore_key, data)
        self._portal_manifest = new_portal_manifest

    def _launch_new_portal_job(self):
        """Start a job over the files chosen by ``_list_input_dir``.

        Returns:
            False if there was no input to process, True otherwise.
        """
        assert self._sync_processing_job() is None
        rest_fpaths = self._list_input_dir()
        if len(rest_fpaths) == 0:
            logging.info("no file left for portal")
            return False
        rest_fpaths.sort()
        portal_mainifest = self._sync_portal_manifest()
        new_job = dp_pb.DataPortalJob(job_id=portal_mainifest.next_job_id,
                                      finished=False,
                                      fpaths=rest_fpaths)
        self._update_processing_job(new_job)
        new_portal_manifest = dp_pb.DataPortalManifest()
        new_portal_manifest.MergeFrom(portal_mainifest)
        new_portal_manifest.next_job_id += 1
        new_portal_manifest.processing_job_id = new_job.job_id
        self._update_portal_manifest(new_portal_manifest)
        for partition_id in range(self._output_partition_num):
            self._sync_job_part(new_job.job_id, partition_id)
        # The two literals are concatenated, so the trailing space matters.
        logging.info("Data Portal job %d has launched. %d files will be "
                     "processed\n------------\n",
                     new_job.job_id, len(new_job.fpaths))
        for seq, fpath in enumerate(new_job.fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")

        return True

    def _list_dir_helper_oss(self, root):
        """List ``root`` on OSS without recursing.

        A listed entry is kept if it is not a directory, or if a _SUCCESS
        file is listed next to it.
        """
        # oss returns a file multiple times, e.g. listdir('root') returns
        #   ['folder', 'file1.txt', 'folder/file2.txt']
        # and then listdir('root/folder') returns
        #   ['file2.txt']
        filenames = set(path.join(root, i) for i in gfile.ListDirectory(root))
        res = []
        for fname in filenames:
            succ = path.join(path.dirname(fname), '_SUCCESS')
            if succ in filenames or not gfile.IsDirectory(fname):
                res.append(fname)

        return res

    def _list_dir_helper(self, root):
        """Recursively list files under ``root``.

        Sub-directories starting with '_' are skipped. A directory that
        contains _SUCCESS is not descended into: all of its entries are
        returned as-is.
        """
        filenames = list(gfile.ListDirectory(root))
        # If _SUCCESS is present, we assume there are no subdirs
        if '_SUCCESS' in filenames:
            return [path.join(root, i) for i in filenames]

        res = []
        for basename in filenames:
            fname = path.join(root, basename)
            if gfile.IsDirectory(fname):
                # 'ignore tmp dirs starting with _
                if basename.startswith('_'):
                    continue
                res += self._list_dir_helper(fname)
            else:
                res.append(fname)
        return res

    def _list_input_dir(self):
        """Select the input files for the next portal job.

        Lists ``input_base_dir`` (OSS-aware), then drops files that:
          * have any relative path component starting with '_' or '.';
          * do not match ``input_file_wildcard`` (when set);
          * lack a sibling _SUCCESS (when ``check_success_tag``);
          * were already processed by an earlier job.
        Survivors are grouped by folder. With ``single_subfolder`` only the
        lexicographically first folder is returned; otherwise folders are
        taken in order until ``files_per_job_limit`` would be exceeded (the
        first folder is always taken; a non-positive limit means no limit).

        Returns:
            List of file paths for the next job (possibly empty).
        """
        logging.info("List input directory, it will take some time...")
        root = self._portal_manifest.input_base_dir
        wildcard = self._portal_manifest.input_file_wildcard

        if root.startswith('oss://'):
            all_files = set(self._list_dir_helper_oss(root))
        else:
            all_files = set(self._list_dir_helper(root))

        num_ignored = 0
        num_target_files = 0
        num_new_files = 0
        by_folder = {}
        for fname in all_files:
            # Split into ALL components: path.split() only yields
            # (head, tail), which would let nested '_tmp' or '.hidden'
            # directories slip past the ignore check below.
            parts = path.relpath(fname, root).split(path.sep)
            dirnames = parts[:-1]

            # ignore files and dirs starting with _ or .
            # for example: _SUCCESS or ._SUCCESS.crc
            if any(name.startswith(('_', '.')) for name in parts):
                num_ignored += 1
                continue

            # check wildcard
            if wildcard and not fnmatch(fname, wildcard):
                continue
            num_target_files += 1

            # check success tag
            if self._check_success_tag:
                succ_fname = path.join(root, *dirnames, '_SUCCESS')
                if succ_fname not in all_files:
                    continue

            if fname in self._processed_fpath:
                continue
            num_new_files += 1

            # Files directly under root are grouped under ''.
            folder = path.join(*dirnames) if dirnames else ''
            by_folder.setdefault(folder, []).append(fname)

        if not by_folder:
            rest_fpaths = []
        elif self._single_subfolder:
            rest_folder, rest_fpaths = sorted(by_folder.items(),
                                              key=lambda x: x[0])[0]
            logging.info(
                'single_subfolder is set. Only process folder %s '
                'in this iteration', rest_folder)
        else:
            rest_fpaths = []
            total_files = sum(len(v) for v in by_folder.values())
            if (self._files_per_job_limit <= 0 or
                    self._files_per_job_limit > self._max_files_per_job) and \
                    total_files > self._max_files_per_job:
                logging.info(
                    "Number of files exceeds limit, processing "
                    "%d per job", self._max_files_per_job)
                self._files_per_job_limit = self._max_files_per_job
            for _, v in sorted(by_folder.items(), key=lambda x: x[0]):
                # A non-positive limit means "unlimited"; a truthiness test
                # would treat e.g. -1 as a limit and stop after one folder.
                if self._files_per_job_limit > 0 and rest_fpaths and \
                        len(rest_fpaths) + len(v) > self._files_per_job_limit:
                    break
                rest_fpaths.extend(v)

        logging.info(
            'Listing %s: found %d dirs, %d files, %d tmp files ignored, '
            '%d files matching wildcard, %d new files to process. '
            'Processing %d files in this iteration.', root, len(by_folder),
            len(all_files), num_ignored, num_target_files, num_new_files,
            len(rest_fpaths))
        return rest_fpaths

    def _sync_job_part(self, job_id, partition_id):
        """Return the cached job part, reloading it when missing or stale.

        A part absent from the kvstore is created in memory only, unassigned
        (rank_id -1).
        """
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] is None or \
                self._job_part_map[partition_id].job_id != job_id:
            kvstore_key = common.portal_job_part_kvstore_key(
                self._portal_name, job_id, partition_id)
            data = self._kvstore.get_data(kvstore_key)
            if data is None:
                self._job_part_map[partition_id] = dp_pb.PortalJobPart(
                    job_id=job_id, rank_id=-1, partition_id=partition_id)
            else:
                self._job_part_map[partition_id] = \
                    text_format.Parse(data, dp_pb.PortalJobPart(),
                                      allow_unknown_field=True)
        return self._job_part_map[partition_id]

    def _update_job_part(self, job_part):
        """Persist ``job_part`` unless the cached copy equals it; cache it."""
        partition_id = job_part.partition_id
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] != job_part:
            self._job_part_map[partition_id] = None
            kvstore_key = common.portal_job_part_kvstore_key(
                self._portal_name, job_part.job_id, partition_id)
            data = text_format.MessageToString(job_part)
            self._kvstore.set_data(kvstore_key, data)
        self._job_part_map[partition_id] = job_part

    def _check_processing_job_finished(self):
        """Finalize the processing job once all of its parts are finished.

        Marks it finished, records its files as processed, clears caches,
        publishes its output and resets ``processing_job_id`` to -1.

        Returns:
            False if some part is unfinished, True otherwise.
        """
        if not self._all_job_part_finished():
            return False
        processing_job = self._sync_processing_job()
        if not processing_job.finished:
            finished_job = dp_pb.DataPortalJob()
            finished_job.MergeFrom(self._processing_job)
            finished_job.finished = True
            self._update_processing_job(finished_job)
        for fpath in processing_job.fpaths:
            self._processed_fpath.add(fpath)
        self._processing_job = None
        self._job_part_map = {}
        portal_mainifest = self._sync_portal_manifest()
        if portal_mainifest.processing_job_id >= 0:
            self._publish_raw_data(portal_mainifest.processing_job_id)
            new_portal_manifest = dp_pb.DataPortalManifest()
            new_portal_manifest.MergeFrom(self._sync_portal_manifest())
            new_portal_manifest.processing_job_id = -1
            self._update_portal_manifest(new_portal_manifest)
        logging.info("Data Portal job %d has finished. Processed %d "
                     "following fpaths\n------------\n",
                     processing_job.job_id, len(processing_job.fpaths))
        for seq, fpath in enumerate(processing_job.fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")
        return True

    @property
    def _output_partition_num(self):
        """Number of output partitions configured in the manifest."""
        return self._portal_manifest.output_partition_num

    def _is_job_part_finished(self, job_part):
        """kIdMapped is final for PSI portals, kEventTimeReduced otherwise."""
        assert self._portal_manifest is not None
        if self._portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return job_part.part_state == dp_pb.PartState.kIdMapped
        return job_part.part_state == dp_pb.PartState.kEventTimeReduced

    def _map_output_dir(self, job_id):
        """Directory holding the map output of job ``job_id``."""
        return common.portal_map_output_dir(
            self._portal_manifest.output_base_dir, job_id)

    def _reduce_output_dir(self, job_id):
        """Directory holding the reduce output of job ``job_id``."""
        return common.portal_reduce_output_dir(
            self._portal_manifest.output_base_dir, job_id)

    def _publish_raw_data(self, job_id):
        """Publish job ``job_id``'s output files for every partition.

        PSI portals publish the map output; others publish the reduce output.
        """
        portal_manifest = self._sync_portal_manifest()
        output_dir = None
        if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            output_dir = common.portal_map_output_dir(
                portal_manifest.output_base_dir, job_id)
        else:
            output_dir = common.portal_reduce_output_dir(
                portal_manifest.output_base_dir, job_id)
        for partition_id in range(self._output_partition_num):
            dpath = path.join(output_dir, common.partition_repr(partition_id))
            fnames = []
            if gfile.Exists(dpath) and gfile.IsDirectory(dpath):
                fnames = [
                    f for f in gfile.ListDirectory(dpath)
                    if f.endswith(common.RawDataFileSuffix)
                ]
            publish_fpaths = []
            if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
                publish_fpaths = self._publish_psi_raw_data(
                    partition_id, dpath, fnames)
            else:
                publish_fpaths = self._publish_streaming_raw_data(
                    partition_id, dpath, fnames)
            # The portal may be PSI, so the message must not say "streaming".
            logging.info("Data Portal Master publish %d file for partition "
                         "%d of job %d\n----------\n",
                         len(publish_fpaths), partition_id, job_id)
            for seq, fpath in enumerate(publish_fpaths):
                logging.info("%d. %s", seq, fpath)
            logging.info("------------------------------------------\n")

    def _publish_streaming_raw_data(self, partition_id, dpath, fnames):
        """Publish merged sort-run files ordered by their decoded metadata."""
        metas = [
            MergedSortRunMeta.decode_sort_run_meta_from_fname(fname)
            for fname in fnames
        ]
        metas.sort()
        fpaths = [
            path.join(dpath, meta.encode_merged_sort_run_fname())
            for meta in metas
        ]
        self._publisher.publish_raw_data(partition_id, fpaths)
        return fpaths

    def _publish_psi_raw_data(self, partition_id, dpath, fnames):
        """Publish a PSI partition's files and mark the partition finished."""
        fpaths = [path.join(dpath, fname) for fname in fnames]
        self._publisher.publish_raw_data(partition_id, fpaths)
        self._publisher.finish_raw_data(partition_id)
        return fpaths
class PortalRepartitioner(object):
    """Repartition hourly portal input data and publish the output.

    Three routine workers cooperate:

    * the input-data-ready sniffer scans one day of hours, starting at the
      next required hour, and records hours whose input finish tag exists;
    * the repart executor reduces each ready hour's input, maps it into the
      output, and moves the hour to the output-finished list;
    * the committed-datetime forwarder publishes the gap-free run of
      finished hours starting at the next required hour, and advances the
      manifest's committed timestamp to the last hour of that run.

    ``self._lock`` guards ``_portal_manifest``, ``_started``,
    ``_input_ready_datetime`` and ``_output_finished_datetime``. Both
    datetime lists are kept sorted and free of duplicates, so they can be
    searched with ``bisect``.
    """

    def __init__(self, etcd, portal_manifest, portal_options):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._portal_manifest = portal_manifest
        self._portal_options = portal_options
        self._publisher = RawDataPublisher(
            self._etcd, self._portal_manifest.raw_data_publish_dir)
        self._started = False
        # Sorted, duplicate-free hours whose input is ready but which have
        # not been repartitioned yet.
        self._input_ready_datetime = []
        # Sorted, duplicate-free hours whose output is finished but which
        # are not yet covered by the committed timestamp.
        self._output_finished_datetime = []
        self._worker_map = {}

    def start_routine_workers(self):
        """Create and start the three routine workers; no-op if started."""
        with self._lock:
            if self._started:
                return
            # NOTE(review): the trailing 5 is passed straight to
            # RoutineWorker; presumably a routine interval -- confirm.
            self._worker_map = {
                self._input_data_ready_sniffer_name():
                RoutineWorker(self._input_data_ready_sniffer_name(),
                              self._input_data_ready_sniff_fn,
                              self._input_data_ready_sniff_cond, 5),
                self._repart_executor_name():
                RoutineWorker(self._repart_executor_name(),
                              self._repart_execute_fn,
                              self._repart_execute_cond, 5),
                self._committed_datetime_forwarder_name():
                RoutineWorker(self._committed_datetime_forwarder_name(),
                              self._committed_datetime_forward_fn,
                              self._committed_datetime_forward_cond, 5)
            }
            self._started = True
            for worker in self._worker_map.values():
                worker.start_routine()

    def stop_routine_workers(self):
        """Stop the routine workers, but only if they are currently started.

        The workers are stopped after the lock is released, because their
        routines acquire ``self._lock`` themselves.
        """
        with self._lock:
            # Only stop workers that were actually started; repeated calls
            # or a call before start must not stop them again.
            wait_join = self._started
            self._started = False
        if wait_join:
            for worker in self._worker_map.values():
                worker.stop_routine()

    @classmethod
    def _input_data_ready_sniffer_name(cls):
        return 'input-data-ready-sniffer'

    @classmethod
    def _repart_executor_name(cls):
        return 'repart-executor'

    @classmethod
    def _committed_datetime_forwarder_name(cls):
        return 'committed-datetime-forwarder'

    @staticmethod
    def _sorted_contains(sorted_list, item):
        """Return whether ``item`` occurs in the sorted list ``sorted_list``."""
        idx = bisect.bisect_left(sorted_list, item)
        return idx < len(sorted_list) and sorted_list[idx] == item

    def _ready_input_data_count(self):
        """Return how many input-ready hours are waiting to be processed."""
        with self._lock:
            return len(self._input_ready_datetime)

    def _check_datetime_stale(self, date_time):
        """Return True if ``date_time`` needs no further sniffing.

        An hour is stale when it is at or before the committed hour, or when
        it is already tracked as input-ready or output-finished.
        """
        with self._lock:
            committed_datetime = common.convert_timestamp_to_datetime(
                common.trim_timestamp_by_hourly(
                    self._portal_manifest.committed_timestamp))
            if date_time <= committed_datetime:
                return True
            return self._sorted_contains(self._input_ready_datetime,
                                         date_time) or \
                self._sorted_contains(self._output_finished_datetime,
                                      date_time)

    def _check_input_data_ready(self, date_time):
        """Return whether the input finish tag for ``date_time`` exists."""
        finish_tag = common.encode_portal_hourly_finish_tag(
            self._portal_manifest.input_data_base_dir, date_time)
        return gfile.Exists(finish_tag)

    def _check_output_data_finish(self, date_time):
        """Return whether the output finish tag for ``date_time`` exists."""
        finish_tag = common.encode_portal_hourly_finish_tag(
            self._portal_manifest.output_data_base_dir, date_time)
        return gfile.Exists(finish_tag)

    def _put_input_ready_datetime(self, date_time):
        """Insert ``date_time`` into the sorted input-ready list once."""
        with self._lock:
            idx = bisect.bisect_left(self._input_ready_datetime, date_time)
            if idx == len(self._input_ready_datetime) or \
                    self._input_ready_datetime[idx] != date_time:
                self._input_ready_datetime.insert(idx, date_time)

    def _get_potral_manifest(self):
        """Return the current portal manifest (read under the lock)."""
        with self._lock:
            return self._portal_manifest

    @classmethod
    def _get_required_datetime(cls, portal_manifest):
        """Return the next hour that must be processed.

        This is the later of the manifest's begin hour and the hour after
        its committed hour; both timestamps are trimmed to the hour first.
        """
        committed_datetime = common.convert_timestamp_to_datetime(
            common.trim_timestamp_by_hourly(
                portal_manifest.committed_timestamp))
        begin_datetime = common.convert_timestamp_to_datetime(
            common.trim_timestamp_by_hourly(portal_manifest.begin_timestamp))
        return max(begin_datetime, committed_datetime + timedelta(hours=1))

    def _wakeup_input_data_ready_sniffer(self):
        self._worker_map[self._input_data_ready_sniffer_name()].wakeup()

    def _input_data_ready_sniff_fn(self):
        """Record ready input hours in the next day; wake the executor.

        Scans the hours in ``[required, required + 1 day)``. Every hour that
        is not stale and whose input finish tag exists is added to the
        input-ready list.
        """
        with self._lock:
            date_time = self._get_required_datetime(self._portal_manifest)
        end_datetime = date_time + timedelta(days=1)
        while date_time < end_datetime:
            if not self._check_datetime_stale(date_time) and \
                    self._check_input_data_ready(date_time):
                self._put_input_ready_datetime(date_time)
            date_time += timedelta(hours=1)
        if self._ready_input_data_count() > 0:
            self._wakeup_repart_executor()

    def _input_data_ready_sniff_cond(self):
        """Sniff only while fewer than 24 ready hours are buffered."""
        return self._ready_input_data_count() < 24

    def _get_next_input_ready_datetime(self):
        """Return the earliest ready hour whose output is unfinished.

        Ready hours whose output finish tag already exists are moved to the
        output-finished list along the way. Returns None when no ready hour
        is left.
        """
        while self._ready_input_data_count() > 0:
            with self._lock:
                if not self._input_ready_datetime:
                    break
                date_time = self._input_ready_datetime[0]
            if not self._check_output_data_finish(date_time):
                return date_time
            self._transform_datetime_finished(date_time)
        return None

    def _transform_datetime_finished(self, date_time):
        """Move ``date_time`` from the input-ready to the output-finished list.

        Does nothing if ``date_time`` is not in the input-ready list.
        """
        with self._lock:
            idx = bisect.bisect_left(self._input_ready_datetime, date_time)
            if idx == len(self._input_ready_datetime) or \
                    self._input_ready_datetime[idx] != date_time:
                return
            self._input_ready_datetime.pop(idx)
            idx = bisect.bisect_left(self._output_finished_datetime, date_time)
            if idx == len(self._output_finished_datetime) or \
                    self._output_finished_datetime[idx] != date_time:
                self._output_finished_datetime.insert(idx, date_time)

    def _update_portal_commited_timestamp(self, new_committed_datetime):
        """Commit a manifest copy whose committed timestamp is the new hour.

        Builds a copy of the current manifest, sets its committed timestamp
        to ``new_committed_datetime`` trimmed to the hour, writes it through
        ``common.commit_portal_manifest`` and returns it. The copy is NOT
        installed as ``self._portal_manifest``; the caller does that.
        """
        with self._lock:
            old_committed_datetime = common.convert_timestamp_to_datetime(
                common.trim_timestamp_by_hourly(
                    self._portal_manifest.committed_timestamp))
            assert new_committed_datetime > old_committed_datetime
            new_manifest = common_pb.DataJoinPortalManifest()
            new_manifest.MergeFrom(self._portal_manifest)
        # CopyFrom, not MergeFrom: merging would keep any non-zero field of
        # the old timestamp (e.g. nanos) that is zero in the new one.
        new_manifest.committed_timestamp.CopyFrom(
            common.trim_timestamp_by_hourly(
                common.convert_datetime_to_timestamp(new_committed_datetime)))
        common.commit_portal_manifest(self._etcd, new_manifest)
        return new_manifest

    def _wakeup_repart_executor(self):
        self._worker_map[self._repart_executor_name()].wakeup()

    def _repart_execute_fn(self):
        """Repartition every pending ready hour, oldest first."""
        while True:
            date_time = self._get_next_input_ready_datetime()
            if date_time is None:
                break
            # Take one manifest snapshot so the reducer and the mapper agree.
            manifest = self._get_potral_manifest()
            repart_reducer = portal_reducer.PotralHourlyInputReducer(
                manifest, self._portal_options, date_time)
            repart_mapper = portal_mapper.PotralHourlyOutputMapper(
                manifest, self._portal_options, date_time)
            for item in repart_reducer.make_reducer():
                repart_mapper.map_data(item)
            repart_mapper.finish_map()
            self._transform_datetime_finished(date_time)
            self._wakeup_committed_datetime_forwarder()
            self._wakeup_input_data_ready_sniffer()

    def _repart_execute_cond(self):
        """Run the executor only while some ready hour is pending."""
        with self._lock:
            return len(self._input_ready_datetime) > 0

    def _wakeup_committed_datetime_forwarder(self):
        self._worker_map[self._committed_datetime_forwarder_name()].wakeup()

    def _committed_datetime_forward_fn(self):
        """Publish the gap-free run of finished hours and commit it.

        Starting at the next required hour, collects the consecutive hours
        found in the output-finished list. It publishes every output
        partition file of those hours, commits the last hour as the new
        committed timestamp, and drops the committed hours from the list.
        """
        new_committed_datetime = None
        pub_finfos = {}
        with self._lock:
            output_base_dir = self._portal_manifest.output_data_base_dir
            partition_num = self._portal_manifest.output_partition_num
            required_datetime = self._get_required_datetime(
                self._portal_manifest)
            idx = bisect.bisect_left(self._output_finished_datetime,
                                     required_datetime)
            for date_time in self._output_finished_datetime[idx:]:
                # Only a gap-free run may be committed: stop at the first
                # hour that is not the next expected one.
                if date_time != required_datetime:
                    break
                ts = common.trim_timestamp_by_hourly(
                    common.convert_datetime_to_timestamp(date_time))
                for partition_id in range(partition_num):
                    fpath = common.encode_portal_hourly_fpath(
                        output_base_dir, date_time, partition_id)
                    fpaths, timestamps = pub_finfos.setdefault(
                        partition_id, ([], []))
                    fpaths.append(fpath)
                    timestamps.append(ts)
                new_committed_datetime = date_time
                required_datetime += timedelta(hours=1)
        if new_committed_datetime is None:
            return
        for partition_id, (fpaths, timestamps) in pub_finfos.items():
            self._publisher.publish_raw_data(partition_id, fpaths,
                                             timestamps)
        updated_manifest = \
            self._update_portal_commited_timestamp(new_committed_datetime)
        with self._lock:
            self._portal_manifest = updated_manifest
            # The list is sorted, so all committed hours form a prefix.
            skip_cnt = bisect.bisect_right(self._output_finished_datetime,
                                           new_committed_datetime)
            self._output_finished_datetime = \
                self._output_finished_datetime[skip_cnt:]
        self._wakeup_input_data_ready_sniffer()

    def _committed_datetime_forward_cond(self):
        """Forward only when the next required hour has finished output."""
        with self._lock:
            required_datetime = \
                self._get_required_datetime(self._portal_manifest)
            return self._sorted_contains(self._output_finished_datetime,
                                         required_datetime)