Exemple #1
0
 def __init__(self, options, kvstore_type, use_mock_etcd=False):
     self._lock = threading.Condition()
     self._options = options
     kvstore = DBClient(kvstore_type, use_mock_etcd)
     pub_dir = self._options.raw_data_publish_dir
     self._publisher = RawDataPublisher(kvstore, pub_dir)
     self._process_pool_executor = \
             concur_futures.ProcessPoolExecutor(
                     options.offload_processor_number
                 )
     self._callback_submitter = None
     # pre fock sub processor before launch grpc client
     self._process_pool_executor.submit(min, 1, 2).result()
     self._id_batch_fetcher = IdBatchFetcher(kvstore, self._options)
     if self._options.role == common_pb.FLRole.Leader:
         private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = LeaderPsiRsaSigner(
             self._id_batch_fetcher,
             options.batch_processor_options.max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.slow_sign_threshold,
             self._process_pool_executor,
             private_key,
         )
         self._repr = 'leader-' + 'rsa_psi_preprocessor'
     else:
         public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
         self._callback_submitter = concur_futures.ThreadPoolExecutor(1)
         self._psi_rsa_signer = FollowerPsiRsaSigner(
             self._id_batch_fetcher,
             options.batch_processor_options.max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.max_flying_sign_rpc,
             self._options.sign_rpc_timeout_ms,
             self._options.slow_sign_threshold, self._options.stub_fanout,
             self._process_pool_executor, self._callback_submitter,
             public_key, self._options.leader_rsa_psi_signer_addr)
         self._repr = 'follower-' + 'rsa_psi_preprocessor'
     self._sort_run_dumper = SortRunDumper(options)
     self._sort_run_merger = SortRunMerger(
             dj_pb.SortRunMergerOptions(
                 merger_name='sort_run_merger_'+\
                             partition_repr(options.partition_id),
                 reader_options=dj_pb.RawDataOptions(
                     raw_data_iter=options.writer_options.output_writer,
                     compressed_type=options.writer_options.compressed_type,
                     read_ahead_size=\
                         options.sort_run_merger_read_ahead_buffer,
                     read_batch_size=\
                         options.sort_run_merger_read_batch_size
                 ),
                 writer_options=options.writer_options,
                 output_file_dir=options.output_file_dir,
                 partition_id=options.partition_id
             ),
             self._merger_comparator
         )
     self._produce_item_cnt = 0
     self._comsume_item_cnt = 0
     self._started = False
 def __init__(self,
              options,
              etcd_name,
              etcd_addrs,
              etcd_base_dir,
              use_mock_etcd=False):
     self._lock = threading.Condition()
     self._options = options
     etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
     pub_dir = self._options.raw_data_publish_dir
     self._publisher = RawDataPublisher(etcd, pub_dir)
     self._process_pool_executor = \
             concur_futures.ProcessPoolExecutor(
                     options.offload_processor_number
                 )
     self._id_batch_fetcher = IdBatchFetcher(etcd, self._options)
     max_flying_item = options.batch_processor_options.max_flying_item
     if self._options.role == common_pb.FLRole.Leader:
         private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = LeaderPsiRsaSigner(
             self._id_batch_fetcher,
             max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.slow_sign_threshold,
             self._process_pool_executor,
             private_key,
         )
         self._repr = 'leader-' + 'rsa_psi_preprocessor'
     else:
         public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = FollowerPsiRsaSigner(
             self._id_batch_fetcher, max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.max_flying_sign_rpc,
             self._options.sign_rpc_timeout_ms,
             self._options.slow_sign_threshold, self._options.stub_fanout,
             self._process_pool_executor, public_key,
             self._options.leader_rsa_psi_signer_addr)
         self._repr = 'follower-' + 'rsa_psi_preprocessor'
     self._sort_run_dumper = SortRunDumper(options)
     self._sort_run_merger = SortRunMerger(
             dj_pb.SortRunMergerOptions(
                 merger_name='sort_run_merger_'+\
                             partition_repr(options.partition_id),
                 reader_options=dj_pb.RawDataOptions(
                     raw_data_iter=options.writer_options.output_writer,
                     compressed_type=options.writer_options.compressed_type,
                     read_ahead_size=\
                         options.sort_run_merger_read_ahead_buffer
                 ),
                 writer_options=options.writer_options,
                 output_file_dir=options.output_file_dir,
                 partition_id=options.partition_id
             ),
             'example_id'
         )
     self._started = False
Exemple #3
0
 def _make_merger_options(self, task):
     return dj_pb.SortRunMergerOptions(
         merger_name="{}-rank_{}".format(task.task_name, self._rank_id),
         reader_options=dj_pb.RawDataOptions(
             raw_data_iter=self._options.writer_options.output_writer,
             compressed_type=self._options.writer_options.compressed_type,
             read_ahead_size=self._options.merger_read_ahead_size,
             read_batch_size=self._options.merger_read_batch_size),
         writer_options=self._options.writer_options,
         output_file_dir=task.reduce_base_dir,
         partition_id=task.partition_id,
     )