def _preprocess_rsa_psi_follower(self):
    """Launch a follower RSA PSI preprocessor per partition and wait for all.

    The RSA public key is read once from ``self._rsa_public_key_path`` and
    handed to every partition's options. All preprocessors are started
    before any of them is waited on.
    """
    with gfile.GFile(self._rsa_public_key_path, 'rb') as key_file:
        public_key_pem = key_file.read()

    partition_num = self._data_source_f.data_source_meta.partition_num
    started = []
    for part_id in range(partition_num):
        batch_options = dj_pb.BatchProcessorOptions(
            batch_size=1024,
            max_flying_item=1 << 14,
        )
        psi_options = dj_pb.RsaPsiPreProcessorOptions(
            preprocessor_name='follower-rsa-psi-processor',
            role=common_pb.FLRole.Follower,
            rsa_key_pem=public_key_pem,
            input_file_paths=[self._psi_raw_data_fpaths_f[part_id]],
            output_file_dir=self._pre_processor_ouput_dir_f,
            raw_data_publish_dir=self._raw_data_pub_dir_f,
            partition_id=part_id,
            leader_rsa_psi_signer_addr=self._rsa_psi_signer_addr,
            offload_processor_number=1,
            max_flying_sign_batch=128,
            max_flying_sign_rpc=64,
            sign_rpc_timeout_ms=100000,
            stub_fanout=2,
            slow_sign_threshold=8,
            sort_run_merger_read_ahead_buffer=1 << 20,
            # Even-numbered partitions get rpc_sync_mode=True, odd ones
            # False, so both settings are exercised.
            rpc_sync_mode=(part_id % 2 == 0),
            rpc_thread_pool_size=16,
            batch_processor_options=batch_options,
        )
        # NOTE(review): the trailing True is presumably the mock-etcd
        # switch -- confirm against RsaPsiPreProcessor's signature.
        psi_processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
            psi_options,
            self._etcd_name,
            self._etcd_addrs,
            self._etcd_base_dir_f,
            True,
        )
        psi_processor.start_process()
        started.append(psi_processor)

    for psi_processor in started:
        psi_processor.wait_for_finished()
def _preprocess_rsa_psi_follower(self):
    """Start one follower RSA PSI preprocessor per partition, then join them.

    The public key at ``self._rsa_public_key_path`` is loaded a single time
    and reused for every partition of the follower data source.
    """
    with gfile.GFile(self._rsa_public_key_path, 'rb') as pem_file:
        follower_key = pem_file.read()

    def build_options(index):
        # Options differ between partitions only in the input file and the
        # partition id.
        return dj_pb.RsaPsiPreProcessorOptions(
            preprocessor_name='follower-rsa-psi-processor',
            role=common_pb.FLRole.Follower,
            rsa_key_pem=follower_key,
            input_file_paths=[self._psi_raw_data_fpaths_f[index]],
            output_file_dir=self._pre_processor_ouput_dir_f,
            raw_data_publish_dir=self._raw_data_pub_dir_f,
            partition_id=index,
            leader_rsa_psi_signer_addr=self._rsa_psi_signer_addr,
            offload_processor_number=1,
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=1024, max_flying_item=1 << 14),
        )

    launched = []
    total = self._data_source_f.data_source_meta.partition_num
    for index in range(total):
        # NOTE(review): the last positional True presumably enables the
        # mock etcd backend -- confirm against RsaPsiPreProcessor.
        worker = rsa_psi_preprocessor.RsaPsiPreProcessor(
            build_options(index), self._etcd_name, self._etcd_addrs,
            self._etcd_base_dir_f, True)
        worker.start_process()
        launched.append(worker)

    # Only block once every partition has been started.
    for worker in launched:
        worker.wait_for_finished()
def _preprocess_rsa_psi_leader(self):
    """Launch a leader RSA PSI preprocessor per partition and wait for all.

    The RSA private key is read once from ``self._rsa_private_key_path``
    and passed inline (as PEM bytes) to every partition's options.
    """
    with gfile.GFile(self._rsa_private_key_path, 'rb') as key_file:
        private_key_pem = key_file.read()

    partition_num = self._data_source_l.data_source_meta.partition_num
    started = []
    for part_id in range(partition_num):
        batch_options = dj_pb.BatchProcessorOptions(
            batch_size=1024,
            max_flying_item=1 << 14,
        )
        psi_options = dj_pb.RsaPsiPreProcessorOptions(
            preprocessor_name='leader-rsa-psi-processor',
            role=common_pb.FLRole.Leader,
            rsa_key_pem=private_key_pem,
            input_file_paths=[self._psi_raw_data_fpaths_l[part_id]],
            output_file_dir=self._pre_processor_ouput_dir_l,
            raw_data_publish_dir=self._raw_data_pub_dir_l,
            partition_id=part_id,
            offload_processor_number=1,
            max_flying_sign_batch=128,
            stub_fanout=2,
            slow_sign_threshold=8,
            sort_run_merger_read_ahead_buffer=1 << 20,
            batch_processor_options=batch_options,
        )
        # NOTE(review): the trailing True is presumably the mock-etcd
        # switch -- confirm against RsaPsiPreProcessor's signature.
        psi_processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
            psi_options,
            self._etcd_name,
            self._etcd_addrs,
            self._etcd_base_dir_l,
            True,
        )
        psi_processor.start_process()
        started.append(psi_processor)

    # Join only after every partition's preprocessor is running.
    for psi_processor in started:
        psi_processor.wait_for_finished()
def _make_portal_worker(self):
    """Build a ``DataPortalWorker`` and store it on ``self._portal_worker``.

    Both the raw-data reader and the output writer are configured for
    TF_RECORD, with an empty ``compressed_type`` on the reader.
    """
    reader_options = dj_pb.RawDataOptions(
        raw_data_iter="TF_RECORD",
        compressed_type='',
    )
    output_options = dj_pb.WriterOptions(output_writer="TF_RECORD")
    batching = dj_pb.BatchProcessorOptions(
        batch_size=128,
        max_flying_item=300000,
    )
    options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=reader_options,
        writer_options=output_options,
        batch_processor_options=batching,
        merge_buffer_size=4096,
        merger_read_ahead_size=1000000,
    )
    # NOTE(review): positional arguments after the options look like
    # master address, rank id, etcd name, etcd base dir, etcd address and
    # a mock-etcd flag -- confirm against DataPortalWorker's signature.
    self._portal_worker = DataPortalWorker(
        options,
        "localhost:5005",
        0,
        "test_portal_worker_0",
        "portal_worker_0",
        "localhost:2379",
        True,
    )
def _make_portal_worker(self):
    """Build a ``DataPortalWorker`` and store it on ``self._portal_worker``.

    Reader and writer are both TF_RECORD; the reader uses a 1 MiB
    read-ahead (``1 << 20``) and reads 128 items per batch.
    """
    reader_options = dj_pb.RawDataOptions(
        raw_data_iter="TF_RECORD",
        read_ahead_size=1 << 20,
        read_batch_size=128,
    )
    output_options = dj_pb.WriterOptions(output_writer="TF_RECORD")
    batching = dj_pb.BatchProcessorOptions(
        batch_size=128,
        max_flying_item=300000,
    )
    options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=reader_options,
        writer_options=output_options,
        batch_processor_options=batching,
        merger_read_ahead_size=1000000,
        merger_read_batch_size=128,
    )
    # NOTE(review): positional arguments after the options look like
    # master address, rank id, etcd name, etcd base dir, etcd address,
    # user, password and a mock flag -- confirm against DataPortalWorker.
    self._portal_worker = DataPortalWorker(
        options,
        "localhost:5005",
        0,
        "test_portal_worker_0",
        "portal_worker_0",
        "localhost:2379",
        "test_user",
        "test_password",
        True,
    )
def _launch_workers(self):
    """Create and start four leader and four follower data-join workers.

    Leader workers listen on ports 4161-4164 and follower workers on
    5161-5164; the worker of each rank is given the address of the
    same-rank worker on the other side as its peer. All workers share one
    ``DataJoinWorkerOptions`` (CSV_DICT input, SORT_RUN_JOINER).
    """
    shared_options = dj_pb.DataJoinWorkerOptions(
        use_mock_etcd=True,
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter='CSV_DICT',
            compressed_type='',
        ),
        example_id_dump_options=dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1,
            example_id_dump_threshold=1024,
        ),
        example_joiner_options=dj_pb.ExampleJoinerOptions(
            example_joiner='SORT_RUN_JOINER',
            min_matching_window=64,
            max_matching_window=256,
            data_block_dump_interval=30,
            data_block_dump_threshold=1000,
        ),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=1024,
            max_flying_item=4096,
        ),
        data_block_builder_options=dj_pb.DataBlockBuilderOptions(
            data_block_builder='CSV_DICT_DATABLOCK_BUILDER',
        ),
    )

    self._worker_addrs_l = [
        'localhost:4161', 'localhost:4162',
        'localhost:4163', 'localhost:4164',
    ]
    self._worker_addrs_f = [
        'localhost:5161', 'localhost:5162',
        'localhost:5163', 'localhost:5164',
    ]
    self._workers_l = []
    self._workers_f = []

    address_pairs = zip(self._worker_addrs_l, self._worker_addrs_f)
    for rank, (addr_l, addr_f) in enumerate(address_pairs):
        port_l = int(addr_l.split(':')[1])
        port_f = int(addr_f.split(':')[1])
        # Services are built leader-first for each rank, keeping the
        # construction order leader, follower, leader, follower, ...
        self._workers_l.append(
            data_join_worker.DataJoinWorkerService(
                port_l, addr_f, self._master_addr_l, rank,
                self._etcd_name, self._etcd_base_dir_l,
                self._etcd_addrs, shared_options))
        self._workers_f.append(
            data_join_worker.DataJoinWorkerService(
                port_f, addr_l, self._master_addr_f, rank,
                self._etcd_name, self._etcd_base_dir_f,
                self._etcd_addrs, shared_options))

    # All leader workers are started before any follower worker.
    for worker in self._workers_l + self._workers_f:
        worker.start()
def _make_portal_worker(self, raw_data_iter, validation_ratio):
    """Build a ``DataPortalWorker`` and store it on ``self._portal_worker``.

    Args:
        raw_data_iter: value forwarded as ``RawDataOptions.raw_data_iter``.
        validation_ratio: value forwarded as
            ``RawDataOptions.validation_ratio``.

    Side effect: sets the ``ETCD_BASE_DIR`` environment variable to
    ``"portal_worker_0"`` before the worker is constructed.
    """
    reader_options = dj_pb.RawDataOptions(
        raw_data_iter=raw_data_iter,
        read_ahead_size=1 << 20,
        read_batch_size=128,
        optional_fields=['label'],
        validation_ratio=validation_ratio,
    )
    output_options = dj_pb.WriterOptions(output_writer="TF_RECORD")
    batching = dj_pb.BatchProcessorOptions(
        batch_size=128,
        max_flying_item=300000,
    )
    options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=reader_options,
        writer_options=output_options,
        batch_processor_options=batching,
        merger_read_ahead_size=1000000,
        merger_read_batch_size=128,
    )

    # NOTE(review): the worker presumably picks its kvstore base dir up
    # from this environment variable -- confirm in DataPortalWorker.
    os.environ['ETCD_BASE_DIR'] = "portal_worker_0"
    self._portal_worker = DataPortalWorker(
        options,
        "localhost:5005",
        0,
        "etcd",
        True,
    )
def _preprocess_rsa_psi_follower(self):
    """Publish follower raw data and run an RSA PSI preprocessor per partition.

    For each partition the matching raw-data file is published under
    ``'follower_rsa_psi_sub_dir'`` and marked finished, then a
    preprocessor subscribed to that directory is started. Once every
    partition has been started, the method blocks until all finish.

    Side effects: sets ``self._follower_rsa_psi_sub_dir`` and overwrites
    the ``ETCD_BASE_DIR`` environment variable with
    ``self.follower_base_dir``.
    """
    with gfile.GFile(self._rsa_public_key_path, 'rb') as key_file:
        public_key_pem = key_file.read()

    self._follower_rsa_psi_sub_dir = 'follower_rsa_psi_sub_dir'
    publisher = raw_data_publisher.RawDataPublisher(
        self._kvstore_f, self._follower_rsa_psi_sub_dir)

    partition_num = self._data_source_f.data_source_meta.partition_num
    started = []
    for part_id in range(partition_num):
        # Publish this partition's single input file and close the
        # partition before the preprocessor is created.
        publisher.publish_raw_data(
            part_id, [self._psi_raw_data_fpaths_f[part_id]])
        publisher.finish_raw_data(part_id)

        psi_options = dj_pb.RsaPsiPreProcessorOptions(
            preprocessor_name='follower-rsa-psi-processor',
            role=common_pb.FLRole.Follower,
            rsa_key_pem=public_key_pem,
            input_file_subscribe_dir=self._follower_rsa_psi_sub_dir,
            output_file_dir=self._pre_processor_ouput_dir_f,
            raw_data_publish_dir=self._raw_data_pub_dir_f,
            partition_id=part_id,
            leader_rsa_psi_signer_addr=self._rsa_psi_signer_addr,
            offload_processor_number=1,
            max_flying_sign_batch=128,
            max_flying_sign_rpc=64,
            sign_rpc_timeout_ms=100000,
            stub_fanout=2,
            slow_sign_threshold=8,
            sort_run_merger_read_ahead_buffer=1 << 20,
            sort_run_merger_read_batch_size=128,
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=1024,
                max_flying_item=1 << 14,
            ),
            input_raw_data=dj_pb.RawDataOptions(
                raw_data_iter='TF_RECORD',
                read_ahead_size=1 << 20,
            ),
            writer_options=dj_pb.WriterOptions(output_writer='CSV_DICT'),
        )

        # Re-set on every iteration, immediately before construction.
        # NOTE(review): the preprocessor presumably reads its base dir
        # from this variable -- confirm in RsaPsiPreProcessor.
        os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
        psi_processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
            psi_options, self.kvstore_type, True)
        psi_processor.start_process()
        started.append(psi_processor)

    for psi_processor in started:
        psi_processor.wait_for_finished()
def _preprocess_rsa_psi_leader(self):
    """Start one leader RSA PSI preprocessor per partition, then join them.

    Unlike variants that pass the key inline, this one hands the
    preprocessor the private key *path* (``rsa_key_file_path``); no key
    file is opened here.
    """
    def build_options(index):
        # Options differ between partitions only in the input file and the
        # partition id.
        return dj_pb.RsaPsiPreProcessorOptions(
            role=common_pb.FLRole.Leader,
            rsa_key_file_path=self._rsa_private_key_path,
            input_file_paths=[self._psi_raw_data_fpaths_l[index]],
            output_file_dir=self._pre_processor_ouput_dir_l,
            raw_data_publish_dir=self._raw_data_pub_dir_l,
            partition_id=index,
            offload_processor_number=1,
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=1024, max_flying_item=1 << 14),
        )

    launched = []
    total = self._data_source_l.data_source_meta.partition_num
    for index in range(total):
        # NOTE(review): the last positional True presumably enables the
        # mock etcd backend -- confirm against RsaPsiPreProcessor.
        worker = rsa_psi_preprocessor.RsaPsiPreProcessor(
            build_options(index), self._etcd_name, self._etcd_addrs,
            self._etcd_base_dir_l, True)
        worker.start_process()
        launched.append(worker)

    # Only block once every partition has been started.
    for worker in launched:
        worker.wait_for_finished()
set_logger()

# TensorFlow is imported lazily, and eager execution enabled, only when
# either the input iterator or the output builder is TF_RECORD.
needs_tensorflow = (args.input_data_file_iter == 'TF_RECORD'
                    or args.output_builder == 'TF_RECORD')
if needs_tensorflow:
    import tensorflow
    tensorflow.compat.v1.enable_eager_execution()

# Comma-separated CLI value -> list of stripped, non-empty field names.
stripped_fields = (name.strip() for name in args.optional_fields.split(','))
optional_fields = [name for name in stripped_fields if name != '']

reader_options = dj_pb.RawDataOptions(
    raw_data_iter=args.input_data_file_iter,
    compressed_type=args.compressed_type,
    read_ahead_size=args.read_ahead_size,
    read_batch_size=args.read_batch_size,
    optional_fields=optional_fields,
)
output_options = dj_pb.WriterOptions(
    output_writer=args.output_builder,
    compressed_type=args.builder_compressed_type,
)
batching = dj_pb.BatchProcessorOptions(
    batch_size=args.batch_size,
    max_flying_item=-1,
)
portal_worker_options = dp_pb.DataPortalWorkerOptions(
    raw_data_options=reader_options,
    writer_options=output_options,
    batch_processor_options=batching,
    merger_read_ahead_size=args.merger_read_ahead_size,
    merger_read_batch_size=args.merger_read_batch_size,
    # The CLI value is divided by 100 before being stored (presumably a
    # percentage turned into a fraction -- confirm against the arg help).
    memory_limit_ratio=args.memory_limit_ratio / 100,
)

# The last argument is True exactly when the kvstore type is 'mock'.
use_mock_kvstore = (args.kvstore_type == 'mock')
data_portal_worker = DataPortalWorker(
    portal_worker_options,
    args.master_addr,
    args.rank_id,
    args.kvstore_type,
    use_mock_kvstore,
)
data_portal_worker.start()
def setUp(self):
    """Prepare leader/follower data sources, masters and worker options.

    Steps, in order:
      1. Clear any previous state for the test data source in both etcd
         clients and commit fresh leader and follower ``DataSource``
         records (2 partitions).
      2. Start a leader and a follower ``DataJoinMasterService`` that
         point at each other, and poll both until they report the
         ``Processing`` state.
      3. Store the clients, masters and publishers on ``self``.
      4. Remove data-block, example-dump and raw-data directories left
         over on disk.
      5. Build the ``DataJoinWorkerOptions`` shared by the tests.

    Note: the polling loop in step 2 has no timeout; it sleeps 2 seconds
    between attempts until both masters are ``Processing``.
    """
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'

    # NOTE(review): the trailing True presumably selects the mock etcd
    # backend -- confirm against EtcdClient's signature.
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    etcd_l.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))
    etcd_f.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))

    data_source_l = common_pb.DataSource()
    self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
    data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"

    data_source_f = common_pb.DataSource()
    self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"

    # Both sides share identical meta information.
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 2
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_f, data_source_f)

    master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    # Each master listens on its own port and gets the other as peer.
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l,
        etcd_addrs, master_options)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f,
        etcd_addrs, master_options)
    master_f.start()

    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

    # Poll until both masters have moved to the Processing state.
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = master_client_l.GetDataSourceStatus(req_l)
        dss_f = master_client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Processing and \
                dss_f.state == common_pb.DataSourceState.Processing:
            break
        time.sleep(2)

    self.master_client_l = master_client_l
    self.master_client_f = master_client_f
    self.master_addr_l = master_addr_l
    self.master_addr_f = master_addr_f
    self.etcd_l = etcd_l
    self.etcd_f = etcd_f
    self.data_source_l = data_source_l
    self.data_source_f = data_source_f
    self.master_l = master_l
    self.master_f = master_f
    # Stored as the plain string; a stray trailing comma here previously
    # turned the attribute into a 1-tuple.
    self.data_source_name = data_source_name
    self.etcd_name = etcd_name
    self.etcd_addrs = etcd_addrs
    self.etcd_base_dir_l = etcd_base_dir_l
    self.etcd_base_dir_f = etcd_base_dir_f
    self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
        self.etcd_l, self.raw_data_pub_dir_l)
    self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
        self.etcd_f, self.raw_data_pub_dir_f)

    # Remove leftovers from earlier runs (leader dirs first, then follower).
    stale_dirs = (
        data_source_l.data_block_dir,
        data_source_l.example_dumped_dir,
        data_source_l.raw_data_dir,
        data_source_f.data_block_dir,
        data_source_f.example_dumped_dir,
        data_source_f.raw_data_dir,
    )
    for stale_dir in stale_dirs:
        if gfile.Exists(stale_dir):
            gfile.DeleteRecursively(stale_dir)

    self.worker_options = dj_pb.DataJoinWorkerOptions(
        use_mock_etcd=True,
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter='TF_RECORD',
            compressed_type=''),
        example_id_dump_options=dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1,
            example_id_dump_threshold=1024),
        example_joiner_options=dj_pb.ExampleJoinerOptions(
            example_joiner='STREAM_JOINER',
            min_matching_window=64,
            max_matching_window=256,
            data_block_dump_interval=30,
            data_block_dump_threshold=1000),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=512,
            max_flying_item=2048),
        data_block_builder_options=dj_pb.WriterOptions(
            output_writer='TF_RECORD'))
    self.total_index = 1 << 13
def setUp(self):
    """Commit fresh leader/follower data sources and build worker options.

    Clears the test data source from both etcd clients, commits new
    leader and follower ``DataSource`` records (2 partitions), creates a
    raw-data publisher for each side, deletes output and raw-data
    directories left on disk, and stores everything on ``self``.
    """
    ds_name = 'test_data_source'
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    base_dir_l = 'byefl_l'
    base_dir_f = 'byefl_f'

    # NOTE(review): the trailing True presumably selects the mock etcd
    # backend -- confirm against EtcdClient's signature.
    kv_l = EtcdClient(etcd_name, etcd_addrs, base_dir_l, True)
    kv_f = EtcdClient(etcd_name, etcd_addrs, base_dir_f, True)
    for kv in (kv_l, kv_f):
        kv.delete_prefix(common.data_source_etcd_base_dir(ds_name))

    # Leader-side data source.
    self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
    self.raw_data_dir_l = "./raw_data_l"
    leader_ds = common_pb.DataSource()
    leader_ds.raw_data_sub_dir = self.raw_data_pub_dir_l
    leader_ds.role = common_pb.FLRole.Leader
    leader_ds.state = common_pb.DataSourceState.Init
    leader_ds.output_base_dir = "./ds_output_l"

    # Follower-side data source.
    self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
    self.raw_data_dir_f = "./raw_data_f"
    follower_ds = common_pb.DataSource()
    follower_ds.role = common_pb.FLRole.Follower
    follower_ds.raw_data_sub_dir = self.raw_data_pub_dir_f
    follower_ds.state = common_pb.DataSourceState.Init
    follower_ds.output_base_dir = "./ds_output_f"

    # One meta record is merged into both data sources.
    shared_meta = common_pb.DataSourceMeta(
        name=ds_name,
        partition_num=2,
        start_time=0,
        end_time=100000000,
    )
    leader_ds.data_source_meta.MergeFrom(shared_meta)
    common.commit_data_source(kv_l, leader_ds)
    follower_ds.data_source_meta.MergeFrom(shared_meta)
    common.commit_data_source(kv_f, follower_ds)

    self.etcd_l = kv_l
    self.etcd_f = kv_f
    self.data_source_l = leader_ds
    self.data_source_f = follower_ds
    self.data_source_name = ds_name
    self.etcd_name = etcd_name
    self.etcd_addrs = etcd_addrs
    self.etcd_base_dir_l = base_dir_l
    self.etcd_base_dir_f = base_dir_f
    self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
        self.etcd_l, self.raw_data_pub_dir_l)
    self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
        self.etcd_f, self.raw_data_pub_dir_f)

    # Remove leftovers from earlier runs (leader dirs first, then follower).
    leftovers = (
        leader_ds.output_base_dir,
        self.raw_data_dir_l,
        follower_ds.output_base_dir,
        self.raw_data_dir_f,
    )
    for leftover in leftovers:
        if gfile.Exists(leftover):
            gfile.DeleteRecursively(leftover)

    reader_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_RECORD',
        read_ahead_size=1 << 20,
        read_batch_size=128,
    )
    dump_options = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=1,
        example_id_dump_threshold=1024,
    )
    joiner_options = dj_pb.ExampleJoinerOptions(
        example_joiner='STREAM_JOINER',
        min_matching_window=64,
        max_matching_window=256,
        data_block_dump_interval=30,
        data_block_dump_threshold=1000,
    )
    batching = dj_pb.BatchProcessorOptions(
        batch_size=512,
        max_flying_item=2048,
    )
    self.worker_options = dj_pb.DataJoinWorkerOptions(
        use_mock_etcd=True,
        raw_data_options=reader_options,
        example_id_dump_options=dump_options,
        example_joiner_options=joiner_options,
        batch_processor_options=batching,
        data_block_builder_options=dj_pb.WriterOptions(
            output_writer='TF_RECORD'),
    )
    self.total_index = 1 << 12
args = parser.parse_args()

# TensorFlow is only imported when eager mode was requested.
if args.tf_eager_mode:
    import tensorflow
    tensorflow.compat.v1.enable_eager_execution()

# Every option below is forwarded verbatim from the parsed CLI arguments.
reader_options = dj_pb.RawDataOptions(
    raw_data_iter=args.raw_data_iter,
    compressed_type=args.compressed_type,
    read_ahead_size=args.read_ahead_size,
)
joiner_options = dj_pb.ExampleJoinerOptions(
    example_joiner=args.example_joiner,
    min_matching_window=args.min_matching_window,
    max_matching_window=args.max_matching_window,
    data_block_dump_interval=args.data_block_dump_interval,
    data_block_dump_threshold=args.data_block_dump_threshold,
)
dump_options = dj_pb.ExampleIdDumpOptions(
    example_id_dump_interval=args.example_id_dump_interval,
    example_id_dump_threshold=args.example_id_dump_threshold,
)
batching = dj_pb.BatchProcessorOptions(
    batch_size=args.example_id_batch_size,
    max_flying_item=args.max_flying_example_id,
)
builder_options = dj_pb.DataBlockBuilderOptions(
    data_block_builder=args.data_block_builder,
)
worker_options = dj_pb.DataJoinWorkerOptions(
    use_mock_etcd=args.use_mock_etcd,
    raw_data_options=reader_options,
    example_joiner_options=joiner_options,
    example_id_dump_options=dump_options,
    batch_processor_options=batching,
    data_block_builder_options=builder_options,
)

worker_srv = DataJoinWorkerService(
    args.listen_port,
    args.peer_addr,
    args.master_addr,
    args.rank_id,
    args.etcd_name,
    args.etcd_base_dir,
    args.etcd_addrs,
    worker_options,
)
worker_srv.run()
input_file_paths=all_fpaths, output_dir=args.output_dir, output_partition_num=args.output_partition_num, raw_data_options=dj_pb.RawDataOptions( raw_data_iter=args.raw_data_iter, compressed_type=args.compressed_type, read_ahead_size=args.read_ahead_size, read_batch_size=args.read_batch_size ), writer_options=dj_pb.WriterOptions( output_writer=args.output_builder, compressed_type=args.builder_compressed_type, ), partitioner_rank_id=args.partitioner_rank_id, batch_processor_options=dj_pb.BatchProcessorOptions( batch_size=4096, max_flying_item=-1 ) ) partitioner = RawDataPartitioner(partitioner_options, args.part_field, args.etcd_name, args.etcd_addrs, args.etcd_base_dir) logging.info("RawDataPartitioner %s of rank %d launched", partitioner_options.partitioner_name, partitioner_options.partitioner_rank_id) partitioner.start_process() partitioner.wait_for_finished() logging.info("RawDataPartitioner %s of rank %d finished", partitioner_options.partitioner_name, partitioner_options.partitioner_rank_id)
def _launch_workers(self):
    """Create and start four leader and four follower data-join workers.

    Leader workers read TF_RECORD and write CSV_DICT data blocks; follower
    workers read CSV_DICT and write TF_RECORD. Leader workers listen on
    ports 4161-4164 and follower workers on 5161-5164, each paired with
    the same-rank worker on the other side.

    Side effect: ``ETCD_BASE_DIR`` is overwritten before each service is
    constructed and is left at ``self.follower_base_dir`` afterwards.
    """
    def build_options(input_iter, output_writer):
        # The two sides differ only in input iterator and output writer.
        return dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(
                raw_data_iter=input_iter,
                read_ahead_size=1 << 20,
                read_batch_size=128,
            ),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1,
                example_id_dump_threshold=1024,
            ),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='SORT_RUN_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000,
            ),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=1024,
                max_flying_item=4096,
            ),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer=output_writer,
            ),
        )

    leader_options = build_options('TF_RECORD', 'CSV_DICT')
    follower_options = build_options('CSV_DICT', 'TF_RECORD')

    self._worker_addrs_l = [
        'localhost:4161', 'localhost:4162',
        'localhost:4163', 'localhost:4164',
    ]
    self._worker_addrs_f = [
        'localhost:5161', 'localhost:5162',
        'localhost:5163', 'localhost:5164',
    ]
    self._workers_l = []
    self._workers_f = []

    address_pairs = zip(self._worker_addrs_l, self._worker_addrs_f)
    for rank, (addr_l, addr_f) in enumerate(address_pairs):
        port_l = int(addr_l.split(':')[1])
        port_f = int(addr_f.split(':')[1])

        # NOTE(review): each service presumably reads its kvstore base dir
        # from ETCD_BASE_DIR at construction time, hence the switch right
        # before each constructor call -- confirm in DataJoinWorkerService.
        os.environ['ETCD_BASE_DIR'] = self.leader_base_dir
        self._workers_l.append(
            data_join_worker.DataJoinWorkerService(
                port_l, addr_f, self._master_addr_l, rank,
                self.kvstore_type, leader_options))

        os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
        self._workers_f.append(
            data_join_worker.DataJoinWorkerService(
                port_f, addr_l, self._master_addr_f, rank,
                self.kvstore_type, follower_options))

    # All leader workers are started before any follower worker.
    for worker in self._workers_l + self._workers_f:
        worker.start()
if CityHash32(os.path.basename(fpath)) % partitioner_num == \ args.partitioner_rank_id] logging.info("Partitioner of rank id %d will process %d/%d "\ "input files", args.partitioner_rank_id, len(all_fpaths), origin_file_num) partitioner_options = dj_pb.RawDataPartitionerOptions( partitioner_name=args.partitioner_name, input_file_paths=all_fpaths, output_dir=args.output_dir, output_partition_num=args.output_partition_num, raw_data_options=dj_pb.RawDataOptions( raw_data_iter=args.raw_data_iter, compressed_type=args.compressed_type, read_ahead_size=args.read_ahead_size), output_builder=args.output_builder, output_item_threshold=args.output_item_threshold, partitioner_rank_id=args.partitioner_rank_id, batch_processor_options=dj_pb.BatchProcessorOptions( batch_size=args.raw_data_batch_size, max_flying_item=args.max_flying_raw_data)) partitioner = RawDataPartitioner(partitioner_options, args.etcd_name, args.etcd_addrs, args.etcd_base_dir) logging.info("RawDataPartitioner %s of rank %d launched", partitioner_options.partitioner_name, partitioner_options.partitioner_rank_id) partitioner.start_process() partitioner.wait_for_finished() logging.info("RawDataPartitioner %s of rank %d finished", partitioner_options.partitioner_name, partitioner_options.partitioner_rank_id)