class DataPortalJobManager(object):
    """Schedules data-portal jobs and persists their state in etcd.

    State stored in etcd and cached on this object:
      * the portal manifest (``next_job_id``, ``processing_job_id``, input
        and output directories, partition number, portal type);
      * one ``DataPortalJob`` per job id, listing the input files it covers;
      * one ``PortalJobPart`` per (job, partition), tracking the part state
        and the rank that currently owns it.

    A job is split into ``output_partition_num`` parts. Every part goes
    through the map phase (kInit -> kIdMap -> kIdMapped). For Streaming
    portals it then goes through the reduce phase
    (kIdMapped -> kEventTimeReduce -> kEventTimeReduced). Once all parts
    are finished, the job output is published through ``RawDataPublisher``.

    Every public method holds ``self._lock`` while it runs.
    """

    def __init__(self, etcd, portal_name, long_running):
        """Load the portal state from etcd, then resume or start a job.

        Args:
            etcd: client providing ``get_data``/``set_data``. It holds all
                persisted state and is also handed to ``RawDataPublisher``.
            portal_name: portal identifier from which every etcd key of
                this portal is derived.
            long_running: when True, ``backgroup_task`` keeps launching new
                jobs for newly arrived input files, and ``alloc_task`` never
                reports the portal as finished.
        """
        self._lock = threading.Lock()
        self._etcd = etcd
        self._portal_name = portal_name
        self._portal_manifest = None
        self._processing_job = None
        self._sync_portal_manifest()
        # Fail early, with a descriptive message, if etcd holds no manifest
        # for this portal; the lines below dereference it.
        assert self._portal_manifest is not None, \
            'portal manifest of {} not found in etcd'.format(portal_name)
        self._sync_processing_job()
        self._publisher = \
            RawDataPublisher(etcd, self._portal_manifest.raw_data_publish_dir)
        self._long_running = long_running
        # Every file referenced by an already created job counts as consumed
        # and will never be put into a new job.
        self._processed_fpath = set()
        for job_id in range(0, self._portal_manifest.next_job_id):
            job = self._sync_portal_job(job_id)
            assert job is not None and job.job_id == job_id
            self._processed_fpath.update(job.fpaths)
        self._job_part_map = {}
        # Resume: a job left in progress may already be complete, in which
        # case it is finalized (and published) right here.
        if self._portal_manifest.processing_job_id >= 0:
            self._check_processing_job_finished()
        if self._portal_manifest.processing_job_id < 0:
            self._launch_new_portal_job()

    def get_portal_manifest(self):
        """Return the portal manifest, loading it from etcd on first use."""
        with self._lock:
            return self._sync_portal_manifest()

    def alloc_task(self, rank_id):
        """Hand out the next task of the processing job to ``rank_id``.

        Map tasks are allocated first. For Streaming portals, reduce tasks
        are allocated once every part has been mapped.

        Returns:
            ``(finished, task)``. ``task`` is a ``dp_pb.MapTask``, a
            ``dp_pb.ReduceTask`` or None. ``finished`` can only be True for a
            portal that is not long running and has nothing left to do.
        """
        with self._lock:
            self._sync_processing_job()
            if self._processing_job is not None:
                partition_id = self._try_to_alloc_part(rank_id,
                                                       dp_pb.PartState.kInit,
                                                       dp_pb.PartState.kIdMap)
                if partition_id is not None:
                    return False, self._create_map_task(rank_id, partition_id)
                # The reduce phase exists only for Streaming portals and may
                # start only after every part has been mapped.
                if self._all_job_part_mapped() and \
                        (self._portal_manifest.data_portal_type ==
                                dp_pb.DataPortalType.Streaming):
                    partition_id = self._try_to_alloc_part(
                        rank_id, dp_pb.PartState.kIdMapped,
                        dp_pb.PartState.kEventTimeReduce)
                    if partition_id is not None:
                        return False, self._create_reduce_task(partition_id)
                return (not self._long_running
                        and self._all_job_part_finished()), None
            return not self._long_running, None

    def finish_task(self, rank_id, partition_id, part_state):
        """Record that ``rank_id`` finished its task on ``partition_id``.

        The report is applied only if the part is still owned by ``rank_id``
        and is still in ``part_state`` (kIdMap or kEventTimeReduce). Because
        a finished part is released (rank_id -1) and moves to the next
        state, a repeated report is ignored. Afterwards, the job is
        finalized if all of its parts are finished.
        """
        with self._lock:
            processing_job = self._sync_processing_job()
            if processing_job is None:
                return
            job_id = processing_job.job_id
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.rank_id == rank_id and \
                    job_part.part_state == part_state:
                if job_part.part_state == dp_pb.PartState.kIdMap:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kIdMap,
                                          dp_pb.PartState.kIdMapped)
                elif job_part.part_state == dp_pb.PartState.kEventTimeReduce:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kEventTimeReduce,
                                          dp_pb.PartState.kEventTimeReduced)
            self._check_processing_job_finished()

    def backgroup_task(self):
        """Finalize a completed job and, if long running, start the next one.

        The (misspelled) method name is kept for existing callers.
        """
        with self._lock:
            if self._sync_processing_job() is not None:
                self._check_processing_job_finished()
            if self._sync_processing_job() is None and self._long_running:
                self._launch_new_portal_job()

    def _all_job_part_mapped(self):
        """Return True if no part of the processing job is still unmapped."""
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            # Relies on PartState ordering kInit < kIdMap < kIdMapped < ...
            if job_part.part_state <= dp_pb.PartState.kIdMap:
                return False
        return True

    def _all_job_part_finished(self):
        """Return True if every part of the processing job is finished."""
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if not self._is_job_part_finished(job_part):
                return False
        return True

    def _finish_job_part(self, job_id, partition_id, src_state, target_state):
        """Move a part from ``src_state`` to ``target_state``.

        The part's owner is released (rank_id set to -1) and the change is
        persisted to etcd.
        """
        job_part = self._sync_job_part(job_id, partition_id)
        assert job_part is not None and job_part.part_state == src_state
        new_job_part = dp_pb.PortalJobPart()
        new_job_part.MergeFrom(job_part)
        new_job_part.part_state = target_state
        new_job_part.rank_id = -1
        self._update_job_part(new_job_part)

    @staticmethod
    def _stable_fname_hash(fname):
        """Return a process-independent, non-negative hash of ``fname``.

        The builtin ``hash`` of a str is salted per interpreter process
        (PYTHONHASHSEED). A manager that restarts in the middle of a job
        would then assign input files to map partitions differently,
        duplicating or dropping files. CRC-32 yields the same value in every
        process.
        """
        import zlib  # pylint: disable=import-outside-toplevel
        if not isinstance(fname, bytes):
            fname = fname.encode('utf-8')
        # The mask keeps the result unsigned on interpreters where crc32 is
        # signed.
        return zlib.crc32(fname) & 0xffffffff

    def _create_map_task(self, rank_id, partition_id):
        """Build the MapTask for ``partition_id`` of the processing job.

        Each input file goes to exactly one map partition, chosen by a
        stable hash of its basename. ``rank_id`` is not used to build the
        task.
        """
        assert self._processing_job is not None
        job = self._processing_job
        partition_num = self._output_partition_num
        map_fpaths = [
            fpath for fpath in job.fpaths
            if self._stable_fname_hash(path.basename(fpath)) % partition_num
            == partition_id
        ]
        return dp_pb.MapTask(fpaths=map_fpaths,
                             output_base_dir=self._map_output_dir(job.job_id),
                             output_partition_num=partition_num,
                             partition_id=partition_id,
                             part_field=self._get_part_field())

    def _get_part_field(self):
        """Return the record field used for partitioning: 'raw_id' for PSI
        portals, 'example_id' for Streaming portals."""
        portal_manifest = self._sync_portal_manifest()
        if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return 'raw_id'
        assert portal_manifest.data_portal_type == \
                dp_pb.DataPortalType.Streaming
        return 'example_id'

    def _create_reduce_task(self, partition_id):
        """Build the ReduceTask that reduces the map output of
        ``partition_id`` for the processing job."""
        assert self._processing_job is not None
        job_id = self._processing_job.job_id
        return dp_pb.ReduceTask(
            map_base_dir=self._map_output_dir(job_id),
            reduce_base_dir=self._reduce_output_dir(job_id),
            partition_id=partition_id)

    def _try_to_alloc_part(self, rank_id, src_state, target_state):
        """Claim a partition of the processing job for ``rank_id``.

        A partition that is already in ``target_state`` and owned by
        ``rank_id`` is returned unchanged; presumably the worker is asking
        again for a task it was given before. Otherwise, the first partition
        in ``src_state`` is moved to ``target_state`` with ``rank_id`` as its
        owner, and the change is persisted.

        Returns:
            The allocated partition id, or None if nothing can be allocated.
        """
        alloc_partition_id = None
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            part_job = self._sync_job_part(job_id, partition_id)
            if part_job.part_state == src_state and \
                    alloc_partition_id is None:
                alloc_partition_id = partition_id
            if part_job.part_state == target_state and \
                    part_job.rank_id == rank_id:
                alloc_partition_id = partition_id
                break
        if alloc_partition_id is None:
            return None
        part_job = self._job_part_map[alloc_partition_id]
        if part_job.part_state == src_state:
            new_job_part = dp_pb.PortalJobPart(job_id=job_id,
                                               rank_id=rank_id,
                                               partition_id=alloc_partition_id,
                                               part_state=target_state)
            self._update_job_part(new_job_part)
        return alloc_partition_id

    def _sync_portal_job(self, job_id):
        """Read job ``job_id`` from etcd; return None if it is not stored."""
        etcd_key = common.portal_job_etcd_key(self._portal_name, job_id)
        data = self._etcd.get_data(etcd_key)
        if data is not None:
            return text_format.Parse(data, dp_pb.DataPortalJob())
        return None

    def _sync_processing_job(self):
        """Return the processing job, or None if no job is processing.

        The job is re-read from etcd when the cache is empty or holds a
        different job than the manifest's ``processing_job_id``.
        """
        assert self._sync_portal_manifest() is not None
        if self._portal_manifest.processing_job_id < 0:
            self._processing_job = None
        elif self._processing_job is None or \
                (self._processing_job.job_id !=
                    self._portal_manifest.processing_job_id):
            job_id = self._portal_manifest.processing_job_id
            self._processing_job = self._sync_portal_job(job_id)
            assert self._processing_job is not None
        return self._processing_job

    def _update_processing_job(self, job):
        """Persist ``job`` to etcd and cache it as the processing job."""
        # Drop the cache first so a failed write forces a reload from etcd.
        self._processing_job = None
        etcd_key = common.portal_job_etcd_key(self._portal_name, job.job_id)
        self._etcd.set_data(etcd_key, text_format.MessageToString(job))
        self._processing_job = job

    def _sync_portal_manifest(self):
        """Return the cached manifest, reading it from etcd if not cached.

        Returns None if etcd holds no manifest. Once cached, the manifest is
        not re-read. Presumably this manager is the only writer; confirm.
        """
        if self._portal_manifest is None:
            etcd_key = common.portal_etcd_base_dir(self._portal_name)
            data = self._etcd.get_data(etcd_key)
            if data is not None:
                self._portal_manifest = \
                    text_format.Parse(data, dp_pb.DataPortalManifest())
        return self._portal_manifest

    def _update_portal_manifest(self, new_portal_manifest):
        """Persist ``new_portal_manifest`` to etcd and cache it."""
        # Drop the cache first so a failed write forces a reload from etcd.
        self._portal_manifest = None
        etcd_key = common.portal_etcd_base_dir(self._portal_name)
        data = text_format.MessageToString(new_portal_manifest)
        self._etcd.set_data(etcd_key, data)
        self._portal_manifest = new_portal_manifest

    def _launch_new_portal_job(self):
        """Create a job for all input files not covered by an earlier job.

        The job is written to etcd first, and then the manifest is advanced
        (``next_job_id`` + 1, ``processing_job_id`` = new id). Does nothing
        if there are no new files.
        """
        assert self._sync_processing_job() is None
        rest_fpaths = sorted(fpath for fpath in self._list_input_dir()
                             if fpath not in self._processed_fpath)
        if not rest_fpaths:
            logging.info("no file left for portal")
            return
        portal_manifest = self._sync_portal_manifest()
        new_job = dp_pb.DataPortalJob(job_id=portal_manifest.next_job_id,
                                      finished=False,
                                      fpaths=rest_fpaths)
        self._update_processing_job(new_job)
        new_portal_manifest = dp_pb.DataPortalManifest()
        new_portal_manifest.MergeFrom(portal_manifest)
        new_portal_manifest.next_job_id += 1
        new_portal_manifest.processing_job_id = new_job.job_id
        self._update_portal_manifest(new_portal_manifest)
        # These files now belong to a job. Record them so that a later launch
        # (e.g. from backgroup_task in long-running mode) does not put them
        # into another job.
        self._processed_fpath.update(rest_fpaths)
        for partition_id in range(self._output_partition_num):
            self._sync_job_part(new_job.job_id, partition_id)

    def _list_input_dir(self):
        """Return full paths of the entries in ``input_base_dir``.

        If ``input_file_wildcard`` is set, only names matching it (fnmatch)
        are returned.
        """
        input_dir = self._portal_manifest.input_base_dir
        fnames = gfile.ListDirectory(input_dir)
        if len(self._portal_manifest.input_file_wildcard) > 0:
            wildcard = self._portal_manifest.input_file_wildcard
            fnames = [f for f in fnames if fnmatch(f, wildcard)]
        return [path.join(input_dir, f) for f in fnames]

    def _sync_job_part(self, job_id, partition_id):
        """Return the cached part ``partition_id`` of job ``job_id``.

        The part is re-read from etcd if it is not cached, if it was
        invalidated, or if it belongs to another job. If etcd has no entry,
        a part with rank_id -1 is cached in memory only; its part_state is
        the proto default, presumably kInit.
        """
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] is None or \
                self._job_part_map[partition_id].job_id != job_id:
            etcd_key = common.portal_job_part_etcd_key(self._portal_name,
                                                       job_id, partition_id)
            data = self._etcd.get_data(etcd_key)
            if data is None:
                self._job_part_map[partition_id] = dp_pb.PortalJobPart(
                    job_id=job_id, rank_id=-1, partition_id=partition_id)
            else:
                self._job_part_map[partition_id] = \
                    text_format.Parse(data, dp_pb.PortalJobPart())
        return self._job_part_map[partition_id]

    def _update_job_part(self, job_part):
        """Persist ``job_part`` to etcd (skipped if unchanged) and cache it."""
        partition_id = job_part.partition_id
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] != job_part:
            # Invalidate first: if the write fails, the None entry makes
            # _sync_job_part reload the part from etcd.
            self._job_part_map[partition_id] = None
            etcd_key = common.portal_job_part_etcd_key(self._portal_name,
                                                       job_part.job_id,
                                                       partition_id)
            data = text_format.MessageToString(job_part)
            self._etcd.set_data(etcd_key, data)
        self._job_part_map[partition_id] = job_part

    def _check_processing_job_finished(self):
        """Finalize the processing job if all of its parts are finished.

        Finalizing means: mark the job finished in etcd, drop the cached
        job and parts, publish the job output, and reset the manifest's
        ``processing_job_id`` to -1.

        Returns:
            True if the job was finalized, False otherwise.
        """
        if not self._all_job_part_finished():
            return False
        processing_job = self._sync_processing_job()
        if not processing_job.finished:
            finished_job = dp_pb.DataPortalJob()
            finished_job.MergeFrom(processing_job)
            finished_job.finished = True
            self._update_processing_job(finished_job)
        self._processing_job = None
        self._job_part_map = {}
        portal_manifest = self._sync_portal_manifest()
        if portal_manifest.processing_job_id >= 0:
            self._publish_raw_data(portal_manifest.processing_job_id)
            new_portal_manifest = dp_pb.DataPortalManifest()
            new_portal_manifest.MergeFrom(portal_manifest)
            new_portal_manifest.processing_job_id = -1
            self._update_portal_manifest(new_portal_manifest)
        return True

    @property
    def _output_partition_num(self):
        """Number of partitions every job is split into."""
        return self._portal_manifest.output_partition_num

    def _is_job_part_finished(self, job_part):
        """A part is finished after map (PSI) or after reduce (Streaming)."""
        assert self._portal_manifest is not None
        if self._portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return job_part.part_state == dp_pb.PartState.kIdMapped
        return job_part.part_state == dp_pb.PartState.kEventTimeReduced

    def _map_output_dir(self, job_id):
        """Directory holding the map output of job ``job_id``."""
        return common.portal_map_output_dir(
            self._portal_manifest.output_base_dir, job_id)

    def _reduce_output_dir(self, job_id):
        """Directory holding the reduce output of job ``job_id``."""
        return common.portal_reduce_output_dir(
            self._portal_manifest.output_base_dir, job_id)

    def _publish_raw_data(self, job_id):
        """Publish the final output files of job ``job_id``, per partition.

        PSI portals publish the map output; Streaming portals publish the
        reduce output. Only files ending in ``common.RawDataFileSuffix``
        are considered. A missing partition directory publishes an empty
        list.
        """
        portal_manifest = self._sync_portal_manifest()
        is_psi = portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI
        if is_psi:
            output_dir = self._map_output_dir(job_id)
        else:
            output_dir = self._reduce_output_dir(job_id)
        for partition_id in range(self._output_partition_num):
            dpath = path.join(output_dir, common.partition_repr(partition_id))
            fnames = []
            if gfile.Exists(dpath) and gfile.IsDirectory(dpath):
                fnames = [
                    f for f in gfile.ListDirectory(dpath)
                    if f.endswith(common.RawDataFileSuffix)
                ]
            if is_psi:
                self._publish_psi_raw_data(partition_id, dpath, fnames)
            else:
                self._publish_streaming_raw_data(partition_id, dpath, fnames)

    def _publish_streaming_raw_data(self, partition_id, dpath, fnames):
        """Publish merged sort runs of a partition, ordered by their metas.

        The partition is not marked finished here, unlike in the PSI path
        (presumably because a streaming portal keeps publishing data).
        """
        metas = sorted(
            MergedSortRunMeta.decode_sort_run_meta_from_fname(fname)
            for fname in fnames
        )
        fpaths = [
            path.join(dpath, meta.encode_merged_sort_run_fname())
            for meta in metas
        ]
        self._publisher.publish_raw_data(partition_id, fpaths)

    def _publish_psi_raw_data(self, partition_id, dpath, fnames):
        """Publish the PSI map outputs of a partition and mark its raw data
        finished.

        NOTE(review): files are published in directory-listing order,
        whereas the streaming path sorts them. Confirm this is intended.
        """
        metas = [
            RawDataPartitioner.FileMeta.decode_meta_from_fname(fname)
            for fname in fnames
        ]
        fpaths = [
            path.join(dpath, meta.encode_meta_to_fname()) for meta in metas
        ]
        self._publisher.publish_raw_data(partition_id, fpaths)
        self._publisher.finish_raw_data(partition_id)
Example #2
0
class RsaPsiPreProcessor(object):
    """Runs the RSA-PSI preprocessing pipeline for one partition.

    Four ``RoutineWorker`` stages cooperate:
      1. id batch fetcher -- fetches raw id batches (``IdBatchFetcher``);
      2. psi rsa signer   -- signs the fetched batches. The leader signs
         locally with the RSA private key. The follower is given the public
         key and ``leader_rsa_psi_signer_addr``, presumably to ask the
         leader's signer over RPC;
      3. sort run dumper  -- collects signed items, sorts them by signed id
         and dumps them as sort runs;
      4. sort run merger  -- once dumping is finished, merges all sort runs
         and publishes the merged files through ``RawDataPublisher``.

    ``self._lock`` is a Condition. It guards start/stop and wakes
    ``wait_for_finished`` once merging is done.
    """

    def __init__(self, options, kvstore_type, use_mock_etcd=False):
        self._lock = threading.Condition()
        self._options = options
        kvstore = DBClient(kvstore_type, use_mock_etcd)
        pub_dir = self._options.raw_data_publish_dir
        self._publisher = RawDataPublisher(kvstore, pub_dir)
        self._process_pool_executor = \
                concur_futures.ProcessPoolExecutor(
                        options.offload_processor_number
                    )
        self._callback_submitter = None
        # Submit a trivial task so the pool forks its worker processes now,
        # before any gRPC client is created.
        self._process_pool_executor.submit(min, 1, 2).result()
        self._id_batch_fetcher = IdBatchFetcher(kvstore, self._options)
        if self._options.role == common_pb.FLRole.Leader:
            private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
            self._psi_rsa_signer = LeaderPsiRsaSigner(
                self._id_batch_fetcher,
                options.batch_processor_options.max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.slow_sign_threshold,
                self._process_pool_executor,
                private_key,
            )
            self._repr = 'leader-' + 'rsa_psi_preprocessor'
        else:
            public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
            self._callback_submitter = concur_futures.ThreadPoolExecutor(1)
            self._psi_rsa_signer = FollowerPsiRsaSigner(
                self._id_batch_fetcher,
                options.batch_processor_options.max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.max_flying_sign_rpc,
                self._options.sign_rpc_timeout_ms,
                self._options.slow_sign_threshold, self._options.stub_fanout,
                self._process_pool_executor, self._callback_submitter,
                public_key, self._options.leader_rsa_psi_signer_addr)
            self._repr = 'follower-' + 'rsa_psi_preprocessor'
        self._sort_run_dumper = SortRunDumper(options)
        self._sort_run_merger = SortRunMerger(
                dj_pb.SortRunMergerOptions(
                    merger_name='sort_run_merger_'+\
                                partition_repr(options.partition_id),
                    reader_options=dj_pb.RawDataOptions(
                        raw_data_iter=options.writer_options.output_writer,
                        compressed_type=options.writer_options.compressed_type,
                        read_ahead_size=\
                            options.sort_run_merger_read_ahead_buffer,
                        read_batch_size=\
                            options.sort_run_merger_read_batch_size
                    ),
                    writer_options=options.writer_options,
                    output_file_dir=options.output_file_dir,
                    partition_id=options.partition_id
                ),
                self._merger_comparator
            )
        # Filled by start_process; initialized here so that stopping before
        # starting never hits a missing attribute.
        self._worker_map = {}
        # Items fetched so far vs. items dumped so far; the difference is
        # the number of items in flight (used for back-pressure).
        self._produce_item_cnt = 0
        self._comsume_item_cnt = 0
        self._started = False

    def start_process(self):
        """Create and start the four routine workers (once)."""
        with self._lock:
            if not self._started:
                self._worker_map = {
                    self._id_batch_fetcher_name():
                    RoutineWorker(self._id_batch_fetcher_name(),
                                  self._id_batch_fetch_fn,
                                  self._id_batch_fetch_cond, 5),
                    self._psi_rsa_signer_name():
                    RoutineWorker(self._psi_rsa_signer_name(),
                                  self._psi_rsa_sign_fn,
                                  self._psi_rsa_sign_cond, 5),
                    self._sort_run_dumper_name():
                    RoutineWorker(self._sort_run_dumper_name(),
                                  self._sort_run_dump_fn,
                                  self._sort_run_dump_cond, 5),
                    self._sort_run_merger_name():
                    RoutineWorker(self._sort_run_merger_name(),
                                  self._sort_run_merge_fn,
                                  self._sort_run_merge_cond, 5)
                }
                for _, w in self._worker_map.items():
                    w.start_routine()
                self._started = True

    def stop_routine_workers(self):
        """Stop the routine workers if they are running.

        Safe to call before ``start_process`` or more than once. Only the
        call that flips ``_started`` from True to False stops the workers,
        and it does so outside the lock.
        """
        wait_join = False
        with self._lock:
            if self._started:
                wait_join = True
                self._started = False
        if wait_join:
            for w in self._worker_map.values():
                w.stop_routine()

    def wait_for_finished(self):
        """Block until merging is finished, then shut everything down.

        Shutdown order: stop workers, shut down the executors, clean up
        the fetcher's visitor metadata, then say bye to the signer.
        """
        # Check the flag while holding the condition's lock, so a notify
        # sent between the check and the wait cannot be missed.
        with self._lock:
            while not self._sort_run_merger.is_merged_finished():
                self._lock.wait()
        self.stop_routine_workers()
        self._process_pool_executor.shutdown()
        if self._callback_submitter is not None:
            self._callback_submitter.shutdown()
        self._id_batch_fetcher.cleanup_visitor_meta_data()
        self._bye_for_signer()

    def _bye_for_signer(self):
        """Tell the signer this preprocessor is done.

        Makes up to 60 attempts and sleeps 10s after each failed one. If
        every attempt fails, the whole process is terminated with
        ``os._exit(-1)``.
        """
        for rnd in range(60):
            try:
                self._psi_rsa_signer.say_signer_bye()
                logging.info("Success to say bye to signer at round "\
                             "%d, rsa_psi_preprocessor will exit", rnd)
                return
            except Exception as e:  # pylint: disable=broad-except
                logging.warning("Failed to say bye to signer at "\
                                "round %d, sleep 10s and retry: %s", rnd, e)
            time.sleep(10)
        logging.warning("Give up to say bye to signer after try 60 "\
                        "times, rsa_psi_preprocessor will exit as -1")
        traceback.print_stack()
        os._exit(-1)  # pylint: disable=protected-access

    def _id_batch_fetcher_name(self):
        return self._repr + ':id_batch_fetcher'

    def _wakeup_id_batch_fetcher(self):
        self._worker_map[self._id_batch_fetcher_name()].wakeup()

    def _id_batch_fetch_fn(self):
        """Fetch id batches from where the signer needs them next.

        Wakes the signer after each batch and stops early once
        back-pressure (``_stop_fetch_id``) kicks in.
        """
        next_index = self._psi_rsa_signer.get_next_index_to_fetch()
        for batch in self._id_batch_fetcher.make_processor(next_index):
            logging.debug("%s fetch batch begin at %d, len %d. wakeup %s",
                          self._id_batch_fetcher_name(), batch.begin_index,
                          len(batch), self._psi_rsa_signer_name())
            self._produce_item_cnt += len(batch)
            self._wakeup_psi_rsa_signer()
            if self._stop_fetch_id():
                break

    def _id_batch_fetch_cond(self):
        next_index = self._psi_rsa_signer.get_next_index_to_fetch()
        return self._id_batch_fetcher.need_process(next_index) and \
                not self._stop_fetch_id() and \
                not self._sort_run_dumper.is_dump_finished()

    def _stop_fetch_id(self):
        """Back-pressure check for the fetcher.

        Returns True when at least 5 << 20 items are in flight, or when the
        heap stats report an OOM risk at the 0.80 threshold given the
        signer's extra memory per item.
        """
        total_flying_item = self._produce_item_cnt - self._comsume_item_cnt
        if total_flying_item >= 5 << 20:
            logging.warning("stop fetch id since flying item "\
                            "reach to %d > 5m, produce_item_cnt: %d; "\
                            "consume_item_cnt: %d", total_flying_item,
                            self._produce_item_cnt, self._comsume_item_cnt)
            return True
        potential_mem_incr = total_flying_item * \
                             self._psi_rsa_signer.additional_item_mem_usage()
        if get_heap_mem_stats(None).CheckOomRisk(total_flying_item, 0.80,
                                                 potential_mem_incr):
            logging.warning("stop fetch id since has oom risk for 0.80, "\
                            "flying item reach to %d", total_flying_item)
            return True
        return False

    def _psi_rsa_signer_name(self):
        return self._repr + ':psi_rsa_signer'

    def _wakeup_psi_rsa_signer(self):
        self._worker_map[self._psi_rsa_signer_name()].wakeup()

    def _transmit_signed_batch(self, signed_index):
        """Evict fetched batches up to ``signed_index`` and wake the dumper."""
        evict_batch_cnt = self._id_batch_fetcher.evict_staless_item_batch(
            signed_index)
        self._psi_rsa_signer.update_next_batch_index_hint(evict_batch_cnt)
        self._wakeup_sort_run_dumper()

    def _psi_rsa_sign_fn(self):
        """Sign batches starting at the dumper's next index.

        Progress is handed to the dumper every 16 iterations and once more
        at the end.
        """
        next_index = self._sort_run_dumper.get_next_index_to_dump()
        sign_cnt = 0
        signed_index = None
        for signed_batch in self._psi_rsa_signer.make_processor(next_index):
            sign_cnt += 1
            # The processor may yield None, so only real batches are logged
            # and advance the signed index.
            if signed_batch is not None:
                logging.debug("%s sign batch begin at %d, len %d. wakeup %s",
                              self._psi_rsa_signer_name(),
                              signed_batch.begin_index, len(signed_batch),
                              self._sort_run_dumper_name())
                signed_index = signed_batch.begin_index + len(signed_batch) - 1
            if sign_cnt % 16 == 0:
                self._transmit_signed_batch(signed_index)
        self._transmit_signed_batch(signed_index)

    def _psi_rsa_sign_cond(self):
        next_index = self._sort_run_dumper.get_next_index_to_dump()
        return self._psi_rsa_signer.need_process(next_index) and \
                not self._sort_run_dumper.is_dump_finished()

    def _sort_run_dumper_name(self):
        return self._repr + ':sort_run_dumper'

    def _wakeup_sort_run_dumper(self):
        self._worker_map[self._sort_run_dumper_name()].wakeup()

    def _load_sorted_items_from_rsa_signer(self):
        """Collect consecutive signed items and sort them by item[0].

        Collection stops after about ``max_flying_item // 2`` items, or, if
        that is <= 0, when the signer has no more batches.

        Returns:
            ``(signed_finished, sorted_items, next_index)``.
        """
        sort_run_dumper = self._sort_run_dumper
        rsi_signer = self._psi_rsa_signer
        next_index = sort_run_dumper.get_next_index_to_dump()
        hint_index = None
        items_buffer = []
        signed_finished = False
        total_item_num = 0
        max_flying_item = self._options.batch_processor_options.max_flying_item
        sort_run_size = max_flying_item // 2
        while sort_run_size <= 0 or total_item_num < sort_run_size:
            signed_finished, batch, hint_index = \
                rsi_signer.fetch_item_batch_by_index(next_index, hint_index)
            if batch is None:
                break
            assert next_index == batch.begin_index
            for item in batch:
                items_buffer.append(item)
            next_index += len(batch)
            total_item_num += len(batch)
        sorted_items_buffer = sorted(items_buffer, key=lambda item: item[0])
        return signed_finished, sorted_items_buffer, next_index

    def _sort_run_dump_fn(self):
        """Dump one sorted run of signed items.

        Each item's example id is set to its signed id. Signer batches that
        were consumed are evicted, and the dump is finished once signing is
        finished.
        """
        signed_finished, items_buffer, next_index = \
                self._load_sorted_items_from_rsa_signer()
        sort_run_dumper = self._sort_run_dumper
        if len(items_buffer) > 0:

            def producer(items_buffer):
                for signed_id, item, index in items_buffer:
                    item.set_example_id(signed_id)
                    yield signed_id, index, item

            sort_run_dumper.dump_sort_runs(producer(items_buffer))
        if next_index is not None:
            self._psi_rsa_signer.evict_staless_item_batch(next_index - 1)
        if signed_finished:
            sort_run_dumper.finish_dump_sort_run()
        dump_cnt = len(items_buffer)
        self._comsume_item_cnt += dump_cnt
        del items_buffer
        logging.warning("dump %d item in sort run, and gc %d objects.",
                        dump_cnt, gc.collect())

    def _sort_run_dump_cond(self):
        """Decide whether a dump should run now.

        A dump runs when not yet finished and any of these holds: signing is
        finished, at least 2 << 20 candidates are available, more than
        max_flying_item // 2 candidates are available, or fetching is
        throttled and dumping helps the pipeline move forward.
        """
        sort_run_dumper = self._sort_run_dumper
        rsa_signer = self._psi_rsa_signer
        next_index = sort_run_dumper.get_next_index_to_dump()
        max_flying_item = self._options.batch_processor_options.max_flying_item
        dump_finished = sort_run_dumper.is_dump_finished()
        signed_finished = rsa_signer.get_process_finished()
        flying_item_cnt = rsa_signer.get_flying_item_count()
        flying_begin_index = rsa_signer.get_flying_begin_index()
        dump_cands_num = 0
        if flying_begin_index is not None and next_index is not None and \
                (flying_begin_index <= next_index <
                    flying_begin_index + flying_item_cnt):
            dump_cands_num = flying_item_cnt - (next_index -
                                                flying_begin_index)
        return not dump_finished and \
                (signed_finished or
                 (dump_cands_num >= (2 << 20) or
                  (max_flying_item > 2 and
                    dump_cands_num > max_flying_item // 2)) or
                  self._dump_for_forward(dump_cands_num))

    def _dump_for_forward(self, dump_cands_num):
        """While fetching is throttled, dump if at least half of the flying
        items are ready to be dumped."""
        if self._stop_fetch_id():
            total_flying_item = self._produce_item_cnt - self._comsume_item_cnt
            return dump_cands_num > 0 and \
                    dump_cands_num >= total_flying_item // 2
        return False

    def _sort_run_merger_name(self):
        return self._repr + ':sort_run_merger'

    def _sort_run_merge_fn(self):
        """Merge all dumped sort runs, publish the result and mark merging
        finished."""
        sort_runs = self._sort_run_dumper.get_all_sort_runs()
        input_dir = self._sort_run_dumper.sort_run_dump_dir()
        input_fpaths = [
            os.path.join(input_dir, partition_repr(self._options.partition_id),
                         sort_run.encode_sort_run_fname())
            for sort_run in sort_runs
        ]
        output_fpaths = self._sort_run_merger.merge_sort_runs(input_fpaths)
        self._publisher.publish_raw_data(self._options.partition_id,
                                         output_fpaths)
        self._publisher.finish_raw_data(self._options.partition_id)
        self._sort_run_merger.set_merged_finished()

    def _sort_run_merge_cond(self):
        """Run the merge once dumping is finished.

        After the merge has finished, this wakes ``wait_for_finished``
        instead of running again.
        """
        if self._sort_run_merger.is_merged_finished():
            with self._lock:
                self._lock.notify()
            return False
        return self._sort_run_dumper.is_dump_finished()

    @staticmethod
    def _merger_comparator(a, b):
        """Order merged items by example id."""
        return a.example_id < b.example_id
Example #3
0
class DataPortalJobManager(object):
    """Allocate and track data-portal jobs whose state is stored in etcd.

    A job is a batch of input files split into ``output_partition_num`` job
    parts. Each part is first mapped (``kInit -> kIdMap -> kIdMapped``).
    For streaming portals it is then reduced
    (``kIdMapped -> kEventTimeReduce -> kEventTimeReduced``). Once every
    part is finished, the job output is published and, for long-running
    portals, the next job can be launched.

    The portal manifest, the processing job and the job parts are cached
    locally. Every update clears the cache entry, writes etcd, and then
    refreshes the cache. Public methods acquire ``self._lock``; the private
    helpers are only called with it held (or from ``__init__``).
    """

    def __init__(self, etcd, portal_name, long_running):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._portal_name = portal_name
        self._portal_manifest = None
        self._processing_job = None
        # Everything below dereferences the manifest, so check it right
        # after loading rather than after its first use.
        self._sync_portal_manifest()
        assert self._portal_manifest is not None, \
            "no portal manifest found for {}".format(portal_name)
        self._sync_processing_job()
        self._publisher = \
            RawDataPublisher(etcd, self._portal_manifest.raw_data_publish_dir)
        self._long_running = long_running
        # Files owned by any persisted job (including the one currently
        # being processed) must never be claimed by a new job.
        self._processed_fpath = set()
        for job_id in range(0, self._portal_manifest.next_job_id):
            job = self._sync_portal_job(job_id)
            assert job is not None and job.job_id == job_id
            self._processed_fpath.update(job.fpaths)
        self._job_part_map = {}
        if self._portal_manifest.processing_job_id >= 0:
            self._check_processing_job_finished()
        if self._portal_manifest.processing_job_id < 0:
            self._launch_new_portal_job()

    def get_portal_manifest(self):
        """Return the cached portal manifest, loading it from etcd if unset."""
        with self._lock:
            return self._sync_portal_manifest()

    def alloc_task(self, rank_id):
        """Hand out the next task of the processing job to worker ``rank_id``.

        Map tasks are preferred. Reduce tasks are only handed out for
        streaming portals, once every part has been mapped. A part the
        worker already holds is handed back to it (see ``_try_to_alloc_part``).

        Returns:
            ``(False, task)`` when a map/reduce task was allocated.
            Otherwise ``(done, None)``, where ``done`` is True only for a
            non long-running portal that either has no processing job or
            whose processing job has every part finished.
        """
        with self._lock:
            self._sync_processing_job()
            if self._processing_job is not None:
                partition_id = self._try_to_alloc_part(rank_id,
                                                       dp_pb.PartState.kInit,
                                                       dp_pb.PartState.kIdMap)
                if partition_id is not None:
                    return False, self._create_map_task(rank_id, partition_id)
                if self._all_job_part_mapped() and \
                        (self._portal_manifest.data_portal_type ==
                                dp_pb.DataPortalType.Streaming):
                    partition_id = self._try_to_alloc_part(
                        rank_id, dp_pb.PartState.kIdMapped,
                        dp_pb.PartState.kEventTimeReduce)
                    if partition_id is not None:
                        return False, self._create_reduce_task(
                            rank_id, partition_id)
                return (not self._long_running
                        and self._all_job_part_finished()), None
            return not self._long_running, None

    def finish_task(self, rank_id, partition_id, part_state):
        """Record that worker ``rank_id`` finished ``partition_id``.

        The report is applied only if the part is still held by this worker
        and is in ``part_state``; stale or duplicate reports are ignored.
        Afterwards the processing job is checked for completion.
        """
        with self._lock:
            processing_job = self._sync_processing_job()
            if processing_job is None:
                return
            job_id = processing_job.job_id
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.rank_id == rank_id and \
                    job_part.part_state == part_state:
                if job_part.part_state == dp_pb.PartState.kIdMap:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kIdMap,
                                          dp_pb.PartState.kIdMapped)
                    logging.info("Data portal worker-%d finish map task "\
                                 "for partition %d of job %d",
                                 rank_id, partition_id, job_id)
                elif job_part.part_state == dp_pb.PartState.kEventTimeReduce:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kEventTimeReduce,
                                          dp_pb.PartState.kEventTimeReduced)
                    logging.info("Data portal worker-%d finish reduce task "\
                                 "for partition %d of job %d",
                                 rank_id, partition_id, job_id)
            self._check_processing_job_finished()

    def backgroup_task(self):
        """Close a finished processing job and, if long-running, start the next.

        Presumably invoked periodically by the master (caller not visible
        here). The method name is kept as-is for existing callers.
        """
        with self._lock:
            if self._sync_processing_job() is not None:
                self._check_processing_job_finished()
            if self._sync_processing_job() is None and self._long_running:
                self._launch_new_portal_job()

    def _all_job_part_mapped(self):
        """Return True if every part of the processing job is past ``kIdMap``."""
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.part_state <= dp_pb.PartState.kIdMap:
                return False
        return True

    def _all_job_part_finished(self):
        """Return True if every part of the processing job is finished."""
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if not self._is_job_part_finished(job_part):
                return False
        return True

    def _finish_job_part(self, job_id, partition_id, src_state, target_state):
        """Move a part from ``src_state`` to ``target_state`` and release it."""
        job_part = self._sync_job_part(job_id, partition_id)
        assert job_part is not None and job_part.part_state == src_state
        new_job_part = dp_pb.PortalJobPart()
        new_job_part.MergeFrom(job_part)
        new_job_part.part_state = target_state
        new_job_part.rank_id = -1
        self._update_job_part(new_job_part)

    @staticmethod
    def _fpath_partition_id(fpath, partition_num):
        """Map an input file path to a partition id, stably across processes.

        The builtin ``hash`` of a string is salted per interpreter process
        (PYTHONHASHSEED). A restarted master would then assign the files of
        a half-processed job to different partitions, dropping some files
        and mapping others twice. CRC32 of the UTF-8 bytes does not change.
        """
        import zlib
        if not isinstance(fpath, bytes):
            fpath = fpath.encode('utf-8')
        # Mask keeps the checksum unsigned on interpreters returning signed.
        return (zlib.crc32(fpath) & 0xffffffff) % partition_num

    def _create_map_task(self, rank_id, partition_id):
        """Build the ``MapTask`` for ``partition_id`` of the processing job."""
        assert self._processing_job is not None
        job = self._processing_job
        partition_num = self._output_partition_num
        map_fpaths = [
            fpath for fpath in job.fpaths
            if self._fpath_partition_id(fpath, partition_num) == partition_id
        ]
        task_name = '{}-dp_portal_job_{:08}-part-{:04}-map'.format(
            self._portal_manifest.name, job.job_id, partition_id)
        logging.info("Data portal worker-%d is allocated map task %s for "\
                     "partition %d of job %d. the map task has %d files\n"\
                     "-----------------\n", rank_id, task_name,
                     partition_id, job.job_id, len(map_fpaths))
        for seq, fpath in enumerate(map_fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")
        portal_manifest = self._sync_portal_manifest()
        return dp_pb.MapTask(task_name=task_name,
                             fpaths=map_fpaths,
                             output_base_dir=self._map_output_dir(job.job_id),
                             output_partition_num=partition_num,
                             partition_id=partition_id,
                             part_field=self._get_part_field(),
                             data_portal_type=portal_manifest.data_portal_type)

    def _get_part_field(self):
        """Return the record field used for partitioning by portal type."""
        portal_manifest = self._sync_portal_manifest()
        if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return 'raw_id'
        assert portal_manifest.data_portal_type == \
                dp_pb.DataPortalType.Streaming
        return 'example_id'

    def _create_reduce_task(self, rank_id, partition_id):
        """Build the ``ReduceTask`` for ``partition_id`` of the processing job."""
        assert self._processing_job is not None
        job = self._processing_job
        job_id = job.job_id
        task_name = '{}-dp_portal_job_{:08}-part-{:04}-reduce'.format(
            self._portal_manifest.name, job_id, partition_id)
        logging.info("Data portal worker-%d is allocated reduce task %s for "\
                     "partition %d of job %d. the reduce base dir %s\n"\
                     "-----------------\n", rank_id, task_name,
                     partition_id, job_id, self._reduce_output_dir(job_id))
        return dp_pb.ReduceTask(
            task_name=task_name,
            map_base_dir=self._map_output_dir(job_id),
            reduce_base_dir=self._reduce_output_dir(job_id),
            partition_id=partition_id)

    def _try_to_alloc_part(self, rank_id, src_state, target_state):
        """Pick a partition for ``rank_id`` and move it to ``target_state``.

        A part already in ``target_state`` owned by ``rank_id`` wins, so a
        worker that re-asks gets its own part back. Otherwise the first part
        in ``src_state`` is claimed for the worker.

        Returns:
            The partition id, or None if nothing can be allocated.
        """
        alloc_partition_id = None
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            part_job = self._sync_job_part(job_id, partition_id)
            if part_job.part_state == src_state and \
                    alloc_partition_id is None:
                alloc_partition_id = partition_id
            if part_job.part_state == target_state and \
                    part_job.rank_id == rank_id:
                alloc_partition_id = partition_id
                break
        if alloc_partition_id is None:
            return None
        part_job = self._job_part_map[alloc_partition_id]
        if part_job.part_state == src_state:
            new_job_part = dp_pb.PortalJobPart(job_id=job_id,
                                               rank_id=rank_id,
                                               partition_id=alloc_partition_id,
                                               part_state=target_state)
            self._update_job_part(new_job_part)
        return alloc_partition_id

    def _sync_portal_job(self, job_id):
        """Read job ``job_id`` from etcd; return None if it is absent."""
        etcd_key = common.portal_job_etcd_key(self._portal_name, job_id)
        data = self._etcd.get_data(etcd_key)
        if data is not None:
            return text_format.Parse(data, dp_pb.DataPortalJob())
        return None

    def _sync_processing_job(self):
        """Refresh the cached processing job from the manifest's job id.

        Returns:
            The processing job, or None when the manifest has no job
            (``processing_job_id < 0``).
        """
        # Kept outside the assert: asserts are stripped under ``python -O``
        # and the manifest load must still happen.
        portal_manifest = self._sync_portal_manifest()
        assert portal_manifest is not None
        job_id = portal_manifest.processing_job_id
        if job_id < 0:
            self._processing_job = None
        elif self._processing_job is None or \
                self._processing_job.job_id != job_id:
            self._processing_job = self._sync_portal_job(job_id)
            assert self._processing_job is not None
        return self._processing_job

    def _update_processing_job(self, job):
        """Persist ``job`` to etcd, then cache it as the processing job."""
        self._processing_job = None
        etcd_key = common.portal_job_etcd_key(self._portal_name, job.job_id)
        self._etcd.set_data(etcd_key, text_format.MessageToString(job))
        self._processing_job = job

    def _sync_portal_manifest(self):
        """Return the cached manifest, loading it from etcd if not cached."""
        if self._portal_manifest is None:
            etcd_key = common.portal_etcd_base_dir(self._portal_name)
            data = self._etcd.get_data(etcd_key)
            if data is not None:
                self._portal_manifest = \
                    text_format.Parse(data, dp_pb.DataPortalManifest())
        return self._portal_manifest

    def _update_portal_manifest(self, new_portal_manifest):
        """Persist ``new_portal_manifest`` to etcd, then cache it."""
        self._portal_manifest = None
        etcd_key = common.portal_etcd_base_dir(self._portal_name)
        data = text_format.MessageToString(new_portal_manifest)
        self._etcd.set_data(etcd_key, data)
        self._portal_manifest = new_portal_manifest

    def _launch_new_portal_job(self):
        """Create a job for every input file that no earlier job has claimed.

        Returns:
            True if a new job was launched; False if no unclaimed file was
            found.
        """
        # Kept outside the assert so the sync runs under ``python -O`` too.
        processing_job = self._sync_processing_job()
        assert processing_job is None
        rest_fpaths = sorted(fpath for fpath in self._list_input_dir()
                             if fpath not in self._processed_fpath)
        if not rest_fpaths:
            logging.info("no file left for portal")
            return False
        portal_manifest = self._sync_portal_manifest()
        new_job = dp_pb.DataPortalJob(job_id=portal_manifest.next_job_id,
                                      finished=False,
                                      fpaths=rest_fpaths)
        self._update_processing_job(new_job)
        new_portal_manifest = dp_pb.DataPortalManifest()
        new_portal_manifest.MergeFrom(portal_manifest)
        new_portal_manifest.next_job_id += 1
        new_portal_manifest.processing_job_id = new_job.job_id
        self._update_portal_manifest(new_portal_manifest)
        # These files now belong to a persisted job. Without this, a
        # long-running master would put them into the next job once this
        # one finishes and process them again.
        self._processed_fpath.update(new_job.fpaths)
        for partition_id in range(self._output_partition_num):
            self._sync_job_part(new_job.job_id, partition_id)
        logging.info("Data Portal job %d has launched. %d files will be "\
                     "processed\n------------\n",
                     new_job.job_id, len(new_job.fpaths))
        for seq, fpath in enumerate(new_job.fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")
        return True

    def _list_input_dir(self):
        """Walk the input base dir breadth-first and return matching files.

        A file matches when the manifest's ``input_file_wildcard`` is empty
        or ``fnmatch`` accepts its base name.
        """
        import collections
        all_inputs = []
        wildcard = self._portal_manifest.input_file_wildcard
        # deque.popleft is O(1); slicing a list per directory was O(n).
        pending_dirs = collections.deque([self._portal_manifest.input_base_dir])
        while pending_dirs:
            fdir = pending_dirs.popleft()
            for fname in gfile.ListDirectory(fdir):
                fpath = path.join(fdir, fname)
                if gfile.IsDirectory(fpath):
                    pending_dirs.append(fpath)
                elif len(wildcard) == 0 or fnmatch(fname, wildcard):
                    all_inputs.append(fpath)
        return all_inputs

    def _sync_job_part(self, job_id, partition_id):
        """Return the cached part, reloading it from etcd when missing/stale.

        A part absent from etcd is cached as a fresh, unowned
        (``rank_id=-1``) part; it is not written back here.
        """
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] is None or \
                self._job_part_map[partition_id].job_id != job_id:
            etcd_key = common.portal_job_part_etcd_key(self._portal_name,
                                                       job_id, partition_id)
            data = self._etcd.get_data(etcd_key)
            if data is None:
                self._job_part_map[partition_id] = dp_pb.PortalJobPart(
                    job_id=job_id, rank_id=-1, partition_id=partition_id)
            else:
                self._job_part_map[partition_id] = \
                    text_format.Parse(data, dp_pb.PortalJobPart())
        return self._job_part_map[partition_id]

    def _update_job_part(self, job_part):
        """Persist ``job_part`` to etcd if it changed, then cache it."""
        partition_id = job_part.partition_id
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] != job_part:
            self._job_part_map[partition_id] = None
            etcd_key = common.portal_job_part_etcd_key(self._portal_name,
                                                       job_part.job_id,
                                                       partition_id)
            data = text_format.MessageToString(job_part)
            self._etcd.set_data(etcd_key, data)
        self._job_part_map[partition_id] = job_part

    def _check_processing_job_finished(self):
        """Close the processing job if all of its parts are finished.

        Closing marks the job finished in etcd, publishes its output, and
        clears ``processing_job_id`` in the manifest. Requires a processing
        job to exist.

        Returns:
            True if the job was closed, False if parts are still pending.
        """
        if not self._all_job_part_finished():
            return False
        processing_job = self._sync_processing_job()
        if not processing_job.finished:
            finished_job = dp_pb.DataPortalJob()
            finished_job.MergeFrom(processing_job)
            finished_job.finished = True
            self._update_processing_job(finished_job)
        self._processing_job = None
        self._job_part_map = {}
        portal_manifest = self._sync_portal_manifest()
        if portal_manifest.processing_job_id >= 0:
            self._publish_raw_data(portal_manifest.processing_job_id)
            new_portal_manifest = dp_pb.DataPortalManifest()
            new_portal_manifest.MergeFrom(self._sync_portal_manifest())
            new_portal_manifest.processing_job_id = -1
            self._update_portal_manifest(new_portal_manifest)
        if processing_job is not None:
            logging.info("Data Portal job %d has finished. Processed %d "\
                         "following fpaths\n------------\n",
                         processing_job.job_id, len(processing_job.fpaths))
            for seq, fpath in enumerate(processing_job.fpaths):
                logging.info("%d. %s", seq, fpath)
            logging.info("---------------------------------\n")
        return True

    @property
    def _output_partition_num(self):
        """Number of output partitions, read from the cached manifest."""
        return self._portal_manifest.output_partition_num

    def _is_job_part_finished(self, job_part):
        """PSI parts finish when mapped; other portals after reduce."""
        assert self._portal_manifest is not None
        if self._portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            return job_part.part_state == dp_pb.PartState.kIdMapped
        return job_part.part_state == dp_pb.PartState.kEventTimeReduced

    def _map_output_dir(self, job_id):
        """Directory holding the map output of job ``job_id``."""
        return common.portal_map_output_dir(
            self._portal_manifest.output_base_dir, job_id)

    def _reduce_output_dir(self, job_id):
        """Directory holding the reduce output of job ``job_id``."""
        return common.portal_reduce_output_dir(
            self._portal_manifest.output_base_dir, job_id)

    def _publish_raw_data(self, job_id):
        """Publish every partition's output files of job ``job_id``.

        PSI portals publish the map output; other portals publish the
        reduce output.
        """
        portal_manifest = self._sync_portal_manifest()
        output_dir = None
        if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            output_dir = common.portal_map_output_dir(
                portal_manifest.output_base_dir, job_id)
        else:
            output_dir = common.portal_reduce_output_dir(
                portal_manifest.output_base_dir, job_id)
        for partition_id in range(self._output_partition_num):
            dpath = path.join(output_dir, common.partition_repr(partition_id))
            fnames = []
            if gfile.Exists(dpath) and gfile.IsDirectory(dpath):
                fnames = [
                    f for f in gfile.ListDirectory(dpath)
                    if f.endswith(common.RawDataFileSuffix)
                ]
            publish_fpaths = []
            if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
                publish_fpaths = self._publish_psi_raw_data(
                    partition_id, dpath, fnames)
            else:
                publish_fpaths = self._publish_streaming_raw_data(
                    partition_id, dpath, fnames)
            logging.info("Data Portal Master publish %d file for partition "\
                         "%d of job %d\n----------\n",
                         len(publish_fpaths), partition_id, job_id)
            for seq, fpath in enumerate(publish_fpaths):
                logging.info("%d. %s", seq, fpath)
            logging.info("------------------------------------------\n")

    def _publish_streaming_raw_data(self, partition_id, dpath, fnames):
        """Publish merged sort-run files in sort-run-meta order (no finish)."""
        metas = [
            MergedSortRunMeta.decode_sort_run_meta_from_fname(fname)
            for fname in fnames
        ]
        metas.sort()
        fpaths = [
            path.join(dpath, meta.encode_merged_sort_run_fname())
            for meta in metas
        ]
        self._publisher.publish_raw_data(partition_id, fpaths)
        return fpaths

    def _publish_psi_raw_data(self, partition_id, dpath, fnames):
        """Publish PSI map output and mark the partition's raw data finished."""
        fpaths = [path.join(dpath, fname) for fname in fnames]
        self._publisher.publish_raw_data(partition_id, fpaths)
        self._publisher.finish_raw_data(partition_id)
        return fpaths
class RsaPsiPreProcessor(object):
    """Pipeline that produces RSA-PSI raw data for one partition.

    Four routine workers run the stages:

    * the id batch fetcher (``IdBatchFetcher``);
    * the PSI RSA signer, which is ``LeaderPsiRsaSigner`` with the private
      key for the leader role, or ``FollowerPsiRsaSigner`` with the public
      key and the leader signer address otherwise;
    * the sort-run dumper, which writes signed items sorted by signed id;
    * the sort-run merger, which merges the runs and publishes the result.
    """

    def __init__(self,
                 options,
                 etcd_name,
                 etcd_addrs,
                 etcd_base_dir,
                 use_mock_etcd=False):
        self._lock = threading.Condition()
        self._options = options
        etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
        pub_dir = self._options.raw_data_publish_dir
        self._publisher = RawDataPublisher(etcd, pub_dir)
        self._process_pool_executor = \
                concur_futures.ProcessPoolExecutor(
                        options.offload_processor_number
                    )
        self._id_batch_fetcher = IdBatchFetcher(etcd, self._options)
        max_flying_item = options.batch_processor_options.max_flying_item
        if self._options.role == common_pb.FLRole.Leader:
            private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
            self._psi_rsa_signer = LeaderPsiRsaSigner(
                self._id_batch_fetcher,
                max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.slow_sign_threshold,
                self._process_pool_executor,
                private_key,
            )
            self._repr = 'leader-' + 'rsa_psi_preprocessor'
        else:
            public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
            self._psi_rsa_signer = FollowerPsiRsaSigner(
                self._id_batch_fetcher, max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.max_flying_sign_rpc,
                self._options.sign_rpc_timeout_ms,
                self._options.slow_sign_threshold, self._options.stub_fanout,
                self._process_pool_executor, public_key,
                self._options.leader_rsa_psi_signer_addr)
            self._repr = 'follower-' + 'rsa_psi_preprocessor'
        self._sort_run_dumper = SortRunDumper(options)
        self._sort_run_merger = SortRunMerger(
                dj_pb.SortRunMergerOptions(
                    merger_name='sort_run_merger_'+\
                                partition_repr(options.partition_id),
                    reader_options=dj_pb.RawDataOptions(
                        raw_data_iter=options.writer_options.output_writer,
                        compressed_type=options.writer_options.compressed_type,
                        read_ahead_size=\
                            options.sort_run_merger_read_ahead_buffer
                    ),
                    writer_options=options.writer_options,
                    output_file_dir=options.output_file_dir,
                    partition_id=options.partition_id
                ),
                'example_id'
            )
        # Filled by start_process(); initialised here so the attribute
        # always exists, even on an instance that was never started.
        self._worker_map = {}
        self._started = False

    def start_process(self):
        """Create and start the four routine workers (idempotent)."""
        with self._lock:
            if not self._started:
                self._worker_map = {
                    self._id_batch_fetcher_name():
                    RoutineWorker(self._id_batch_fetcher_name(),
                                  self._id_batch_fetch_fn,
                                  self._id_batch_fetch_cond, 5),
                    self._psi_rsa_signer_name():
                    RoutineWorker(self._psi_rsa_signer_name(),
                                  self._psi_rsa_sign_fn,
                                  self._psi_rsa_sign_cond, 5),
                    self._sort_run_dumper_name():
                    RoutineWorker(self._sort_run_dumper_name(),
                                  self._sort_run_dump_fn,
                                  self._sort_run_dump_cond, 5),
                    self._sort_run_merger_name():
                    RoutineWorker(self._sort_run_merger_name(),
                                  self._sort_run_merge_fn,
                                  self._sort_run_merge_cond, 5)
                }
                for _, w in self._worker_map.items():
                    w.start_routine()
                self._started = True

    def stop_routine_workers(self):
        """Stop the routine workers if they are running.

        Only the call that flips ``_started`` from True to False stops the
        workers. Calls on a never-started or already-stopped instance are
        no-ops.
        """
        # Must start False: it previously started True, so every call
        # stopped the workers again and a never-started instance crashed on
        # the missing worker map.
        wait_join = False
        with self._lock:
            if self._started:
                wait_join = True
                self._started = False
        # Stop outside the lock so worker callbacks that take self._lock
        # (e.g. _sort_run_merge_cond) are not blocked meanwhile.
        if wait_join:
            for w in self._worker_map.values():
                w.stop_routine()

    def wait_for_finished(self):
        """Block until the merge is finished, then tear everything down."""
        # Check the flag while holding the condition's lock. A notify from
        # _sort_run_merge_cond can then no longer slip in between the check
        # and wait() and be lost.
        with self._lock:
            while not self._sort_run_merger.is_merged_finished():
                self._lock.wait()
        self.stop_routine_workers()
        self._process_pool_executor.shutdown()
        self._id_batch_fetcher.cleanup_visitor_meta_data()

    def _id_batch_fetcher_name(self):
        return self._repr + ':id_batch_fetcher'

    def _wakeup_id_batch_fetcher(self):
        self._worker_map[self._id_batch_fetcher_name()].wakeup()

    def _id_batch_fetch_fn(self):
        """Fetch id batches from where the signer needs them; wake the signer."""
        next_index = self._psi_rsa_signer.get_next_index_to_fetch()
        for batch in self._id_batch_fetcher.make_processor(next_index):
            logging.debug("%s fetch batch begin at %d, len %d. wakeup %s",
                          self._id_batch_fetcher_name(), batch.begin_index,
                          len(batch), self._psi_rsa_signer_name())
            self._wakeup_psi_rsa_signer()

    def _id_batch_fetch_cond(self):
        next_index = self._psi_rsa_signer.get_next_index_to_fetch()
        return self._id_batch_fetcher.need_process(next_index)

    def _psi_rsa_signer_name(self):
        return self._repr + ':psi_rsa_signer'

    def _wakeup_psi_rsa_signer(self):
        self._worker_map[self._psi_rsa_signer_name()].wakeup()

    def _psi_rsa_sign_fn(self):
        """Sign pending batches, wake the dumper, and evict dumped id batches."""
        next_index = self._sort_run_dumper.get_next_index_to_dump()
        for signed_batch in self._psi_rsa_signer.make_processor(next_index):
            logging.debug("%s sign batch begin at %d, len %d. wakeup %s",
                          self._psi_rsa_signer_name(),
                          signed_batch.begin_index, len(signed_batch),
                          self._sort_run_dumper_name())
            self._wakeup_sort_run_dumper()
        staless_index = self._sort_run_dumper.get_next_index_to_dump() - 1
        evict_batch_cnt = self._id_batch_fetcher.evict_staless_item_batch(
            staless_index)
        self._psi_rsa_signer.update_next_batch_index_hint(evict_batch_cnt)

    def _psi_rsa_sign_cond(self):
        next_index = self._sort_run_dumper.get_next_index_to_dump()
        return self._psi_rsa_signer.need_process(next_index)

    def _sort_run_dumper_name(self):
        return self._repr + ':sort_run_dumper'

    def _wakeup_sort_run_dumper(self):
        self._worker_map[self._sort_run_dumper_name()].wakeup()

    def _load_sorted_items_from_rsa_signer(self):
        """Collect about one sort run of signed items, sorted by signed id.

        Batches are pulled from the signer starting at the dumper's next
        index until at least ``max_flying_item // 4`` items (minimum 1)
        are gathered or no batch is available.

        Returns:
            ``(signed_finished, items, next_index)``: the finished flag from
            the last fetch, the ``(signed_id, item, index)`` tuples sorted by
            signed id, and the index just past the last collected item.
        """
        sort_run_dumper = self._sort_run_dumper
        rsi_signer = self._psi_rsa_signer
        next_index = sort_run_dumper.get_next_index_to_dump()
        hint_index = None
        items_buffer = []
        signed_finished = False
        total_item_num = 0
        max_flying_item = self._options.batch_processor_options.max_flying_item
        # With max_flying_item < 4 the size was 0. The loop then never
        # fetched, so no item was dumped and signed_finished never became
        # True, and the pipeline could not finish.
        sort_run_size = max(1, max_flying_item // 4)
        # NOTE(review): if the signer can report finished=True while batches
        # remain beyond the size limit, the caller finishes the dump early;
        # confirm fetch_item_batch_by_index only reports finished together
        # with batch=None.
        while total_item_num < sort_run_size:
            signed_finished, batch, hint_index = \
                rsi_signer.fetch_item_batch_by_index(next_index, hint_index)
            if batch is None:
                break
            assert next_index == batch.begin_index
            items_buffer.extend(batch)
            next_index += len(batch)
            total_item_num += len(batch)
        sorted_items_buffer = sorted(items_buffer, key=lambda item: item[0])
        return signed_finished, sorted_items_buffer, next_index

    def _sort_run_dump_fn(self):
        """Dump one sorted run, evict dumped signer batches, maybe finish."""
        signed_finished, items_buffer, next_index = \
                self._load_sorted_items_from_rsa_signer()
        sort_run_dumper = self._sort_run_dumper
        if len(items_buffer) > 0:

            def producer(items_buffer):
                # The signed id becomes the item's example id and the sort key.
                for signed_id, item, index in items_buffer:
                    item.set_example_id(signed_id)
                    yield signed_id, index, item

            sort_run_dumper.dump_sort_runs(producer(items_buffer))
        if next_index is not None:
            self._psi_rsa_signer.evict_staless_item_batch(next_index - 1)
        if signed_finished:
            sort_run_dumper.finish_dump_sort_run()

    def _sort_run_dump_cond(self):
        """Dump while unfinished if signing is done or a quarter-run is ready.

        "Ready" means at least ``max_flying_item // 4`` signed items lie in
        the signer's flying window at or after the dumper's next index.
        """
        sort_run_dumper = self._sort_run_dumper
        rsa_signer = self._psi_rsa_signer
        next_index = sort_run_dumper.get_next_index_to_dump()
        max_flying_item = self._options.batch_processor_options.max_flying_item
        dump_finished = sort_run_dumper.is_dump_finished()
        signed_finished = rsa_signer.get_process_finished()
        flying_item_cnt = rsa_signer.get_flying_item_count()
        flying_begin_index = rsa_signer.get_flying_begin_index()
        return not dump_finished and \
                (signed_finished or
                 (flying_begin_index is not None and
                  next_index is not None and
                  (flying_begin_index <= next_index <
                      flying_begin_index + flying_item_cnt) and
                   (flying_item_cnt-(next_index-flying_begin_index) >=
                    max_flying_item // 4)))

    def _sort_run_merger_name(self):
        return self._repr + ':sort_run_merger'

    def _sort_run_merge_fn(self):
        """Merge all dumped sort runs, publish them, and mark merge finished."""
        sort_runs = self._sort_run_dumper.get_all_sort_runs()
        input_dir = self._sort_run_dumper.sort_run_dump_dir()
        input_fpaths = [
            os.path.join(input_dir, partition_repr(self._options.partition_id),
                         sort_run.encode_sort_run_fname())
            for sort_run in sort_runs
        ]
        output_fpaths = self._sort_run_merger.merge_sort_runs(input_fpaths)
        self._publisher.publish_raw_data(self._options.partition_id,
                                         output_fpaths)
        self._publisher.finish_raw_data(self._options.partition_id)
        self._sort_run_merger.set_merged_finished()

    def _sort_run_merge_cond(self):
        """Run merge once dumping is done; afterwards wake wait_for_finished."""
        if self._sort_run_merger.is_merged_finished():
            with self._lock:
                self._lock.notify()
            return False
        return self._sort_run_dumper.is_dump_finished()
class DataPortalJobManager(object):
    def __init__(self,
                 kvstore,
                 portal_name,
                 long_running,
                 check_success_tag,
                 single_subfolder,
                 files_per_job_limit,
                 max_files_per_job=8000):
        """Restore the portal state from the kvstore and start a job if idle.

        Args:
            kvstore: key/value store holding the portal manifest, the portal
                jobs and their job parts.
            portal_name: name of the data portal, used to build kvstore keys.
            long_running: when False, the manager marks itself finished as
                soon as no new job can be launched.
            check_success_tag: only pick up input files whose folder contains
                a ``_SUCCESS`` file.
            single_subfolder: process a single input subfolder per job.
            files_per_job_limit: maximum number of files per job when
                positive; see _list_input_dir for how it interacts with
                ``max_files_per_job``.
            max_files_per_job: cap applied when the number of new files
                exceeds it and ``files_per_job_limit`` is non-positive or
                larger than it.
        """
        self._lock = threading.Lock()
        self._kvstore = kvstore
        self._portal_name = portal_name
        self._check_success_tag = check_success_tag
        self._single_subfolder = single_subfolder
        self._files_per_job_limit = files_per_job_limit
        self._max_files_per_job = max_files_per_job
        # Cached kvstore state, (re)filled by the _sync_* helpers.
        self._portal_manifest = None
        self._processing_job = None
        self._sync_portal_manifest()
        self._sync_processing_job()
        # _sync_processing_job has already asserted the manifest exists.
        assert self._portal_manifest is not None
        self._publisher = RawDataPublisher(
            kvstore, self._portal_manifest.raw_data_publish_dir)
        self._long_running = long_running
        self._finished = False
        # Rebuild the set of consumed input files from every job launched so
        # far; job ids are dense in [0, next_job_id).
        self._processed_fpath = set()
        for job_id in range(self._portal_manifest.next_job_id):
            job = self._sync_portal_job(job_id)
            assert job is not None and job.job_id == job_id
            self._processed_fpath.update(job.fpaths)
        self._job_part_map = {}
        # A job left in progress by a previous run may already be complete.
        if self._portal_manifest.processing_job_id >= 0:
            self._check_processing_job_finished()
        # Nothing in progress: try to start a new job right away.
        if self._portal_manifest.processing_job_id < 0:
            launched = self._launch_new_portal_job()
            if not launched and not self._long_running:
                self._finished = True

    def get_portal_manifest(self):
        with self._lock:
            return self._sync_portal_manifest()

    def alloc_task(self, rank_id):
        with self._lock:
            self._sync_processing_job()
            if self._processing_job is not None:
                partition_id = self._try_to_alloc_part(rank_id,
                                                       dp_pb.PartState.kInit,
                                                       dp_pb.PartState.kIdMap)
                if partition_id is not None:
                    return False, self._create_map_task(rank_id, partition_id)
                if self._all_job_part_mapped() and \
                        (self._portal_manifest.data_portal_type ==
                                dp_pb.DataPortalType.Streaming):
                    partition_id = self._try_to_alloc_part(
                        rank_id, dp_pb.PartState.kIdMapped,
                        dp_pb.PartState.kEventTimeReduce)
                    if partition_id is not None:
                        return False, self._create_reduce_task(
                            rank_id, partition_id)
                return (self._finished and self._all_job_part_finished()), None
            return self._finished, None

    def finish_task(self, rank_id, partition_id, part_state):
        with self._lock:
            processing_job = self._sync_processing_job()
            if processing_job is None:
                return
            job_id = self._processing_job.job_id
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.rank_id == rank_id and \
                    job_part.part_state == part_state:
                if job_part.part_state == dp_pb.PartState.kIdMap:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kIdMap,
                                          dp_pb.PartState.kIdMapped)
                    logging.info("Data portal worker-%d finish map task "\
                                 "for partition %d of job %d",
                                 rank_id, partition_id, job_id)
                elif job_part.part_state == dp_pb.PartState.kEventTimeReduce:
                    self._finish_job_part(job_id, partition_id,
                                          dp_pb.PartState.kEventTimeReduce,
                                          dp_pb.PartState.kEventTimeReduced)
                    logging.info("Data portal worker-%d finish reduce task "\
                                 "for partition %d of job %d",
                                 rank_id, partition_id, job_id)
            self._check_processing_job_finished()

    def backgroup_task(self):
        with self._lock:
            if self._sync_processing_job() is not None:
                self._check_processing_job_finished()
            if self._sync_processing_job() is None:
                success = self._launch_new_portal_job()
                if not success and not self._long_running:
                    self._finished = True

    def _all_job_part_mapped(self):
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if job_part.part_state <= dp_pb.PartState.kIdMap:
                return False
        return True

    def _all_job_part_finished(self):
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = self._processing_job.job_id
        for partition_id in range(self._output_partition_num):
            job_part = self._sync_job_part(job_id, partition_id)
            if not self._is_job_part_finished(job_part):
                return False
        return True

    def _finish_job_part(self, job_id, partition_id, src_state, target_state):
        """Advance a job part from ``src_state`` to ``target_state``.

        The part is copied, released (``rank_id`` reset to -1) and given the
        new state, then written back through _update_job_part.
        """
        current = self._sync_job_part(job_id, partition_id)
        assert current is not None and current.part_state == src_state
        updated = dp_pb.PortalJobPart()
        updated.MergeFrom(current)
        updated.rank_id = -1
        updated.part_state = target_state
        self._update_job_part(updated)

    def _create_map_task(self, rank_id, partition_id):
        """Build the MapTask for ``partition_id`` of the processing job.

        Each input file of the job belongs to exactly one map partition,
        chosen by a checksum of its path modulo the output partition number.
        The checksum must be stable across processes. The builtin ``hash()``
        of a str is salted per interpreter process (PYTHONHASHSEED), so after
        the master restarts and resumes a job, the remaining partitions would
        get a different split of files. Some files would then never be mapped
        and others mapped twice. CRC32 is deterministic, which keeps the
        assignment identical across restarts.

        Args:
            rank_id: worker receiving the task (used for logging).
            partition_id: map partition to build the task for.

        Returns:
            dp_pb.MapTask listing the files to map and where to write output.
        """
        import zlib  # stable, process-independent checksum of file paths

        assert self._processing_job is not None
        job = self._processing_job
        partition_num = self._output_partition_num
        map_fpaths = []
        for fpath in job.fpaths:
            checksum = zlib.crc32(fpath.encode('utf-8'))
            if checksum % partition_num == partition_id:
                map_fpaths.append(fpath)
        task_name = '{}-dp_portal_job_{:08}-part-{:04}-map'.format(
            self._portal_manifest.name, job.job_id, partition_id)
        logging.info("Data portal worker-%d is allocated map task %s for "\
                     "partition %d of job %d. the map task has %d files"\
                     "-----------------\n", rank_id, task_name,
                     partition_id, job.job_id, len(map_fpaths))
        for seq, fpath in enumerate(map_fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")
        manifest = self._sync_portal_manifest()
        return dp_pb.MapTask(task_name=task_name,
                             fpaths=map_fpaths,
                             output_base_dir=self._map_output_dir(job.job_id),
                             output_partition_num=partition_num,
                             partition_id=partition_id,
                             part_field=self._get_part_field(),
                             data_portal_type=manifest.data_portal_type)

    def _get_part_field(self):
        """Return the record field used to partition map output.

        PSI portals partition by ``raw_id``. The only other accepted type,
        Streaming, partitions by ``example_id``.
        """
        portal_type = self._sync_portal_manifest().data_portal_type
        if portal_type != dp_pb.DataPortalType.PSI:
            assert portal_type == dp_pb.DataPortalType.Streaming
            return 'example_id'
        return 'raw_id'

    def _create_reduce_task(self, rank_id, partition_id):
        """Build the event-time ReduceTask for ``partition_id``.

        The task reads the processing job's map output directory and writes
        into its reduce output directory.
        """
        assert self._processing_job is not None
        job_id = self._processing_job.job_id
        task_name = '{}-dp_portal_job_{:08}-part-{:04}-reduce'.format(
            self._portal_manifest.name, job_id, partition_id)
        reduce_dir = self._reduce_output_dir(job_id)
        logging.info("Data portal worker-%d is allocated reduce task %s for "
                     "partition %d of job %d. the reduce base dir %s"
                     "-----------------\n", rank_id, task_name,
                     partition_id, job_id, reduce_dir)
        return dp_pb.ReduceTask(task_name=task_name,
                                map_base_dir=self._map_output_dir(job_id),
                                reduce_base_dir=reduce_dir,
                                partition_id=partition_id)

    def _try_to_alloc_part(self, rank_id, src_state, target_state):
        alloc_partition_id = None
        processing_job = self._sync_processing_job()
        assert processing_job is not None
        job_id = self._processing_job.job_id
        for partition_id in range(self._output_partition_num):
            part_job = self._sync_job_part(job_id, partition_id)
            if part_job.part_state == src_state and \
                    alloc_partition_id is None:
                alloc_partition_id = partition_id
            if part_job.part_state == target_state and \
                    part_job.rank_id == rank_id:
                alloc_partition_id = partition_id
                break
        if alloc_partition_id is None:
            return None
        part_job = self._job_part_map[alloc_partition_id]
        if part_job.part_state == src_state:
            new_job_part = dp_pb.PortalJobPart(job_id=job_id,
                                               rank_id=rank_id,
                                               partition_id=alloc_partition_id,
                                               part_state=target_state)
            self._update_job_part(new_job_part)
        return alloc_partition_id

    def _sync_portal_job(self, job_id):
        """Load portal job ``job_id`` from the kvstore, or None if absent."""
        data = self._kvstore.get_data(
            common.portal_job_kvstore_key(self._portal_name, job_id))
        if data is None:
            return None
        return text_format.Parse(data, dp_pb.DataPortalJob(),
                                 allow_unknown_field=True)

    def _sync_processing_job(self):
        assert self._sync_portal_manifest() is not None
        if self._portal_manifest.processing_job_id < 0:
            self._processing_job = None
        elif self._processing_job is None or \
                (self._processing_job.job_id !=
                    self._portal_manifest.processing_job_id):
            job_id = self._portal_manifest.processing_job_id
            self._processing_job = self._sync_portal_job(job_id)
            assert self._processing_job is not None
        return self._processing_job

    def _update_processing_job(self, job):
        """Persist ``job`` to the kvstore and cache it as the processing job.

        The cache is cleared before the write, so a failed write leaves no
        stale job cached.
        """
        self._processing_job = None
        key = common.portal_job_kvstore_key(self._portal_name, job.job_id)
        serialized = text_format.MessageToString(job)
        self._kvstore.set_data(key, serialized)
        self._processing_job = job

    def _sync_portal_manifest(self):
        if self._portal_manifest is None:
            kvstore_key = common.portal_kvstore_base_dir(self._portal_name)
            data = self._kvstore.get_data(kvstore_key)
            if data is not None:
                self._portal_manifest = \
                    text_format.Parse(data, dp_pb.DataPortalManifest(),
                                      allow_unknown_field=True)
        return self._portal_manifest

    def _update_portal_manifest(self, new_portal_manifest):
        """Persist ``new_portal_manifest`` and cache it.

        The cache is cleared before the kvstore write, so a failed write
        leaves no stale manifest cached.
        """
        self._portal_manifest = None
        key = common.portal_kvstore_base_dir(self._portal_name)
        self._kvstore.set_data(
            key, text_format.MessageToString(new_portal_manifest))
        self._portal_manifest = new_portal_manifest

    def _launch_new_portal_job(self):
        assert self._sync_processing_job() is None
        rest_fpaths = self._list_input_dir()
        if len(rest_fpaths) == 0:
            logging.info("no file left for portal")
            return False
        rest_fpaths.sort()
        portal_mainifest = self._sync_portal_manifest()
        new_job = dp_pb.DataPortalJob(job_id=portal_mainifest.next_job_id,
                                      finished=False,
                                      fpaths=rest_fpaths)
        self._update_processing_job(new_job)
        new_portal_manifest = dp_pb.DataPortalManifest()
        new_portal_manifest.MergeFrom(portal_mainifest)
        new_portal_manifest.next_job_id += 1
        new_portal_manifest.processing_job_id = new_job.job_id
        self._update_portal_manifest(new_portal_manifest)
        for partition_id in range(self._output_partition_num):
            self._sync_job_part(new_job.job_id, partition_id)
        logging.info("Data Portal job %d has lanuched. %d files will be"\
                     "processed\n------------\n",
                     new_job.job_id, len(new_job.fpaths))
        for seq, fpath in enumerate(new_job.fpaths):
            logging.info("%d. %s", seq, fpath)
        logging.info("---------------------------------\n")

        return True

    def _list_dir_helper_oss(self, root):
        """List candidate input entries under an OSS ``root``.

        OSS returns a file multiple times, e.g. listdir('root') returns
        ['folder', 'file1.txt', 'folder/file2.txt'], and then
        listdir('root/folder') returns ['file2.txt']. A single listing of
        ``root`` is therefore used. An entry is kept when its parent has a
        ``_SUCCESS`` entry in that listing, or when it is not a directory.
        """
        entries = {path.join(root, name) for name in gfile.ListDirectory(root)}
        return [
            entry for entry in entries
            if path.join(path.dirname(entry), '_SUCCESS') in entries
            or not gfile.IsDirectory(entry)
        ]

    def _list_dir_helper(self, root):
        """Recursively list entries under ``root``.

        A directory that contains ``_SUCCESS`` is treated as a leaf: all of
        its entries are returned without descending into subdirectories.
        Otherwise subdirectories whose name starts with ``_`` (temporary
        dirs) are skipped and the others are walked recursively.
        """
        names = list(gfile.ListDirectory(root))
        if '_SUCCESS' in names:
            return [path.join(root, name) for name in names]
        found = []
        for name in names:
            full_path = path.join(root, name)
            if not gfile.IsDirectory(full_path):
                found.append(full_path)
            elif not name.startswith('_'):
                found.extend(self._list_dir_helper(full_path))
        return found

    def _list_input_dir(self):
        logging.info("List input directory, it will take some time...")
        root = self._portal_manifest.input_base_dir
        wildcard = self._portal_manifest.input_file_wildcard

        if root.startswith('oss://'):
            all_files = set(self._list_dir_helper_oss(root))
        else:
            all_files = set(self._list_dir_helper(root))

        num_ignored = 0
        num_target_files = 0
        num_new_files = 0
        by_folder = {}
        for fname in all_files:
            splits = path.split(path.relpath(fname, root))
            basename = splits[-1]
            dirnames = splits[:-1]

            # ignore files and dirs starting with _ or .
            # for example: _SUCCESS or ._SUCCESS.crc
            ignore = False
            for name in splits:
                if name.startswith('_') or name.startswith('.'):
                    ignore = True
                    break
            if ignore:
                num_ignored += 1
                continue

            # check wildcard
            if wildcard and not fnmatch(fname, wildcard):
                continue
            num_target_files += 1

            # check success tag
            if self._check_success_tag:
                succ_fname = path.join(root, *dirnames, '_SUCCESS')
                if succ_fname not in all_files:
                    continue

            if fname in self._processed_fpath:
                continue
            num_new_files += 1

            folder = path.join(*dirnames)
            if folder not in by_folder:
                by_folder[folder] = []
            by_folder[folder].append(fname)

        if not by_folder:
            rest_fpaths = []
        elif self._single_subfolder:
            rest_folder, rest_fpaths = sorted(by_folder.items(),
                                              key=lambda x: x[0])[0]
            logging.info(
                'single_subfolder is set. Only process folder %s '
                'in this iteration', rest_folder)
        else:
            rest_fpaths = []
            if (self._files_per_job_limit <= 0 or
                self._files_per_job_limit > self._max_files_per_job) and \
                sum([len(v) for _, v in by_folder.items()]) > \
                    self._max_files_per_job:
                logging.info(
                    "Number of files exceeds limit, processing "
                    "%d per job", self._max_files_per_job)
                self._files_per_job_limit = self._max_files_per_job
            for _, v in sorted(by_folder.items(), key=lambda x: x[0]):
                if self._files_per_job_limit and rest_fpaths and \
                        len(rest_fpaths) + len(v) > self._files_per_job_limit:
                    break
                rest_fpaths.extend(v)

        logging.info(
            'Listing %s: found %d dirs, %d files, %d tmp files ignored, '
            '%d files matching wildcard, %d new files to process. '
            'Processing %d files in this iteration.', root, len(by_folder),
            len(all_files), num_ignored, num_target_files, num_new_files,
            len(rest_fpaths))
        return rest_fpaths

    def _sync_job_part(self, job_id, partition_id):
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] is None or \
                self._job_part_map[partition_id].job_id != job_id:
            kvstore_key = common.portal_job_part_kvstore_key(
                self._portal_name, job_id, partition_id)
            data = self._kvstore.get_data(kvstore_key)
            if data is None:
                self._job_part_map[partition_id] = dp_pb.PortalJobPart(
                    job_id=job_id, rank_id=-1, partition_id=partition_id)
            else:
                self._job_part_map[partition_id] = \
                    text_format.Parse(data, dp_pb.PortalJobPart(),
                                      allow_unknown_field=True)
        return self._job_part_map[partition_id]

    def _update_job_part(self, job_part):
        partition_id = job_part.partition_id
        if partition_id not in self._job_part_map or \
                self._job_part_map[partition_id] != job_part:
            self._job_part_map[partition_id] = None
            kvstore_key = common.portal_job_part_kvstore_key(
                self._portal_name, job_part.job_id, partition_id)
            data = text_format.MessageToString(job_part)
            self._kvstore.set_data(kvstore_key, data)
        self._job_part_map[partition_id] = job_part

    def _check_processing_job_finished(self) -> bool:
        """Close the processing job if every one of its partitions is finished.

        When all job parts are finished, this:
          * marks the job record finished in the kvstore (if not already),
          * adds the job's files to the set of processed inputs,
          * clears the cached job and job parts,
          * publishes the job's output via _publish_raw_data,
          * resets the manifest's ``processing_job_id`` to -1 so a new job
            can be launched.

        Returns:
            True if the job was finished and closed, False if some partition
            is still pending.
        """
        if not self._all_job_part_finished():
            return False
        processing_job = self._sync_processing_job()
        if not processing_job.finished:
            finished_job = dp_pb.DataPortalJob()
            finished_job.MergeFrom(self._processing_job)
            finished_job.finished = True
            self._update_processing_job(finished_job)
        # ``processing_job`` still refers to the record read above; only its
        # job_id and fpaths are used below, and those are unchanged.
        for fpath in processing_job.fpaths:
            self._processed_fpath.add(fpath)
        self._processing_job = None
        self._job_part_map = {}
        portal_mainifest = self._sync_portal_manifest()
        if portal_mainifest.processing_job_id >= 0:
            # Publish first, then clear processing_job_id in the manifest.
            # NOTE(review): if the process stops between these two steps, a
            # restart still sees the job as processing and publishes again —
            # confirm RawDataPublisher tolerates re-publishing the same files.
            self._publish_raw_data(portal_mainifest.processing_job_id)
            new_portal_manifest = dp_pb.DataPortalManifest()
            new_portal_manifest.MergeFrom(self._sync_portal_manifest())
            new_portal_manifest.processing_job_id = -1
            self._update_portal_manifest(new_portal_manifest)
        # Always true here: ``processing_job.finished`` was read above.
        if processing_job is not None:
            logging.info("Data Portal job %d has finished. Processed %d "\
                         "following fpaths\n------------\n",
                         processing_job.job_id, len(processing_job.fpaths))
            for seq, fpath in enumerate(processing_job.fpaths):
                logging.info("%d. %s", seq, fpath)
            logging.info("---------------------------------\n")
        return True

    @property
    def _output_partition_num(self):
        return self._portal_manifest.output_partition_num

    def _is_job_part_finished(self, job_part):
        """Return whether ``job_part`` has reached its portal type's final state.

        For PSI portals the final state is kIdMapped. For all other portal
        types it is kEventTimeReduced.
        """
        assert self._portal_manifest is not None
        is_psi = (self._portal_manifest.data_portal_type ==
                  dp_pb.DataPortalType.PSI)
        final_state = (dp_pb.PartState.kIdMapped if is_psi
                       else dp_pb.PartState.kEventTimeReduced)
        return job_part.part_state == final_state

    def _map_output_dir(self, job_id):
        """Base directory that map tasks of job ``job_id`` write into."""
        base_dir = self._portal_manifest.output_base_dir
        return common.portal_map_output_dir(base_dir, job_id)

    def _reduce_output_dir(self, job_id):
        """Base directory that reduce tasks of job ``job_id`` write into."""
        base_dir = self._portal_manifest.output_base_dir
        return common.portal_reduce_output_dir(base_dir, job_id)

    def _publish_raw_data(self, job_id):
        """Publish the output files of job ``job_id``, one partition at a time.

        PSI jobs publish their map output; other portal types publish their
        reduce output. Only files ending with ``common.RawDataFileSuffix``
        are considered. A partition directory that does not exist yields an
        empty file list, which is still passed to the publisher.
        """
        manifest = self._sync_portal_manifest()
        if manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            output_dir = common.portal_map_output_dir(
                manifest.output_base_dir, job_id)
            publish_partition = self._publish_psi_raw_data
        else:
            output_dir = common.portal_reduce_output_dir(
                manifest.output_base_dir, job_id)
            publish_partition = self._publish_streaming_raw_data
        for partition_id in range(self._output_partition_num):
            dpath = path.join(output_dir, common.partition_repr(partition_id))
            if gfile.Exists(dpath) and gfile.IsDirectory(dpath):
                fnames = [name for name in gfile.ListDirectory(dpath)
                          if name.endswith(common.RawDataFileSuffix)]
            else:
                fnames = []
            published = publish_partition(partition_id, dpath, fnames)
            logging.info("Data Portal Master publish %d file for partition "
                         "%d of streaming job %d\n----------\n",
                         len(published), partition_id, job_id)
            for seq, fpath in enumerate(published):
                logging.info("%d. %s", seq, fpath)
            logging.info("------------------------------------------\n")

    def _publish_streaming_raw_data(self, partition_id, dpath, fnames):
        """Publish one streaming partition's merged sort-run files in order.

        File names are decoded into MergedSortRunMeta objects, sorted, and
        re-encoded into paths under ``dpath``. Unlike the PSI variant, the
        partition is not marked finished here.

        Returns:
            the published file paths, in publishing order.
        """
        ordered_metas = sorted(
            MergedSortRunMeta.decode_sort_run_meta_from_fname(fname)
            for fname in fnames)
        fpaths = [path.join(dpath, meta.encode_merged_sort_run_fname())
                  for meta in ordered_metas]
        self._publisher.publish_raw_data(partition_id, fpaths)
        return fpaths

    def _publish_psi_raw_data(self, partition_id, dpath, fnames):
        fpaths = [path.join(dpath, fname) for fname in fnames]
        self._publisher.publish_raw_data(partition_id, fpaths)
        self._publisher.finish_raw_data(partition_id)
        return fpaths