def test_compressed_raw_data_visitor(self):
    self.data_source = common_pb.DataSource()
    self.data_source.data_source_meta.name = 'fclh_test'
    self.data_source.data_source_meta.partition_num = 1
    self.raw_data_dir = path.join(
            path.dirname(path.abspath(__file__)), "../compressed_raw_data"
        )
    self.kvstore = DBClient('etcd', True)
    self.kvstore.delete_prefix(common.data_source_kvstore_base_dir(
        self.data_source.data_source_meta.name))
    self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
    partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
    self.assertTrue(gfile.Exists(partition_dir))
    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
    manifest_manager.add_raw_data(
            0,
            [dj_pb.RawDataMeta(file_path=path.join(partition_dir, "0-0.idx"),
                               timestamp=timestamp_pb2.Timestamp(seconds=3))],
            True)
    raw_data_options = dj_pb.RawDataOptions(
            raw_data_iter='TF_RECORD',
            compressed_type='GZIP',
            read_ahead_size=1<<20,
            read_batch_size=128
        )
    rdm = raw_data_visitor.RawDataManager(self.kvstore, self.data_source, 0)
    self.assertTrue(rdm.check_index_meta_by_process_index(0))
    rdv = raw_data_visitor.RawDataVisitor(self.kvstore, self.data_source, 0,
                                          raw_data_options)
    expected_index = 0
    for (index, item) in rdv:
        if index > 0 and index % 32 == 0:
            print("{} {}".format(index, item.example_id))
        self.assertEqual(index, expected_index)
        expected_index += 1
    self.assertGreater(expected_index, 0)
def __init__(self, options, kvstore_type, use_mock_etcd=False):
    self._lock = threading.Condition()
    self._options = options
    kvstore = DBClient(kvstore_type, use_mock_etcd)
    pub_dir = self._options.raw_data_publish_dir
    self._publisher = RawDataPublisher(kvstore, pub_dir)
    self._process_pool_executor = \
            concur_futures.ProcessPoolExecutor(
                    options.offload_processor_number
                )
    self._callback_submitter = None
    # pre-fork the sub-processes before launching the grpc client
    self._process_pool_executor.submit(min, 1, 2).result()
    self._id_batch_fetcher = IdBatchFetcher(kvstore, self._options)
    if self._options.role == common_pb.FLRole.Leader:
        private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
        self._psi_rsa_signer = LeaderPsiRsaSigner(
                self._id_batch_fetcher,
                options.batch_processor_options.max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.slow_sign_threshold,
                self._process_pool_executor,
                private_key,
            )
        self._repr = 'leader-' + 'rsa_psi_preprocessor'
    else:
        public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
        self._callback_submitter = concur_futures.ThreadPoolExecutor(1)
        self._psi_rsa_signer = FollowerPsiRsaSigner(
                self._id_batch_fetcher,
                options.batch_processor_options.max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.max_flying_sign_rpc,
                self._options.sign_rpc_timeout_ms,
                self._options.slow_sign_threshold,
                self._options.stub_fanout,
                self._process_pool_executor,
                self._callback_submitter,
                public_key,
                self._options.leader_rsa_psi_signer_addr)
        self._repr = 'follower-' + 'rsa_psi_preprocessor'
    self._sort_run_dumper = SortRunDumper(options)
    self._sort_run_merger = SortRunMerger(
            dj_pb.SortRunMergerOptions(
                merger_name='sort_run_merger_'+\
                        partition_repr(options.partition_id),
                reader_options=dj_pb.RawDataOptions(
                    raw_data_iter=options.writer_options.output_writer,
                    compressed_type=options.writer_options.compressed_type,
                    read_ahead_size=\
                            options.sort_run_merger_read_ahead_buffer,
                    read_batch_size=\
                            options.sort_run_merger_read_batch_size
                ),
                writer_options=options.writer_options,
                output_file_dir=options.output_file_dir,
                partition_id=options.partition_id
            ),
            self._merger_comparator
        )
    self._produce_item_cnt = 0
    self._comsume_item_cnt = 0
    self._started = False
def setUp(self): data_source = common_pb.DataSource() data_source.data_source_meta.name = "milestone-f" data_source.data_source_meta.partition_num = 1 data_source.output_base_dir = "./ds_output" self.raw_data_dir = "./raw_data" self.data_source = data_source self.raw_data_options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD', compressed_type='', optional_fields=['label']) self.example_id_dump_options = dj_pb.ExampleIdDumpOptions( example_id_dump_interval=1, example_id_dump_threshold=1024) self.example_joiner_options = dj_pb.ExampleJoinerOptions( example_joiner='STREAM_JOINER', min_matching_window=32, max_matching_window=128, data_block_dump_interval=30, data_block_dump_threshold=128) if gfile.Exists(self.data_source.output_base_dir): gfile.DeleteRecursively(self.data_source.output_base_dir) if gfile.Exists(self.raw_data_dir): gfile.DeleteRecursively(self.raw_data_dir) self.kvstore = mysql_client.DBClient('test_cluster', 'localhost:2379', 'test_user', 'test_password', 'fedlearner', True) self.kvstore.delete_prefix( common.data_source_kvstore_base_dir( self.data_source.data_source_meta.name)) self.total_raw_data_count = 0 self.total_example_id_count = 0 self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager( self.kvstore, self.data_source) self.g_data_block_index = 0
def setUp(self): data_source = common_pb.DataSource() data_source.data_source_meta.name = "milestone-f" data_source.data_source_meta.partition_num = 1 data_source.data_block_dir = "./data_block" data_source.example_dumped_dir = "./example_id" data_source.raw_data_dir = "./raw_data" self.data_source = data_source self.raw_data_options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD', compressed_type='') self.example_id_dump_options = dj_pb.ExampleIdDumpOptions( example_id_dump_interval=1, example_id_dump_threshold=1024) self.example_joiner_options = dj_pb.ExampleJoinerOptions( example_joiner='STREAM_JOINER', min_matching_window=32, max_matching_window=128, data_block_dump_interval=30, data_block_dump_threshold=128) if gfile.Exists(self.data_source.data_block_dir): gfile.DeleteRecursively(self.data_source.data_block_dir) if gfile.Exists(self.data_source.example_dumped_dir): gfile.DeleteRecursively(self.data_source.example_dumped_dir) if gfile.Exists(self.data_source.raw_data_dir): gfile.DeleteRecursively(self.data_source.raw_data_dir) self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379', 'fedlearner', True) self.etcd.delete_prefix(self.data_source.data_source_meta.name) self.total_raw_data_count = 0 self.total_example_id_count = 0 self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager( self.etcd, self.data_source) self.g_data_block_index = 0
def _preprocess_rsa_psi_follower(self):
    processors = []
    rsa_key_pem = None
    with gfile.GFile(self._rsa_public_key_path, 'rb') as f:
        rsa_key_pem = f.read()
    for partition_id in range(
            self._data_source_f.data_source_meta.partition_num):
        options = dj_pb.RsaPsiPreProcessorOptions(
            preprocessor_name='follower-rsa-psi-processor',
            role=common_pb.FLRole.Follower,
            rsa_key_pem=rsa_key_pem,
            input_file_paths=[self._psi_raw_data_fpaths_f[partition_id]],
            output_file_dir=self._pre_processor_ouput_dir_f,
            raw_data_publish_dir=self._raw_data_pub_dir_f,
            partition_id=partition_id,
            leader_rsa_psi_signer_addr=self._rsa_psi_signer_addr,
            offload_processor_number=1,
            max_flying_sign_batch=128,
            max_flying_sign_rpc=64,
            sign_rpc_timeout_ms=100000,
            stub_fanout=2,
            slow_sign_threshold=8,
            sort_run_merger_read_ahead_buffer=1 << 20,
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=1024, max_flying_item=1 << 14),
            input_raw_data=dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                                read_ahead_size=1 << 20),
            writer_options=dj_pb.WriterOptions(output_writer='TF_RECORD'))
        processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
            options, self._etcd_name, self._etcd_addrs,
            self._etcd_base_dir_f, True)
        processor.start_process()
        processors.append(processor)
    for processor in processors:
        processor.wait_for_finished()
def test_csv_raw_data_visitor(self):
    self.data_source = common_pb.DataSource()
    self.data_source.data_source_meta.name = 'fclh_test'
    self.data_source.data_source_meta.partition_num = 1
    self.raw_data_dir = path.join(path.dirname(path.abspath(__file__)),
                                  "../csv_raw_data")
    self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                       'fedlearner', True)
    self.etcd.delete_prefix(
        common.data_source_etcd_base_dir(
            self.data_source.data_source_meta.name))
    self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
    partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
    self.assertTrue(gfile.Exists(partition_dir))
    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.etcd, self.data_source)
    manifest_manager.add_raw_data(0, [
        dj_pb.RawDataMeta(file_path=path.join(partition_dir,
                                              "test_raw_data.csv"),
                          timestamp=timestamp_pb2.Timestamp(seconds=3))
    ], True)
    raw_data_options = dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                            read_ahead_size=1 << 20)
    rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)
    self.assertTrue(rdm.check_index_meta_by_process_index(0))
    rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source, 0,
                                          raw_data_options)
    expected_index = 0
    for (index, item) in rdv:
        if index > 0 and index % 1024 == 0:
            print("{} {}".format(index, item.raw_id))
        self.assertEqual(index, expected_index)
        expected_index += 1
    self.assertEqual(expected_index, 4999)
def setUp(self): data_source = common_pb.DataSource() data_source.data_source_meta.name = "milestone-f" data_source.data_source_meta.partition_num = 1 data_source.output_base_dir = "./ds_output" self.raw_data_dir = "./raw_data" self.data_source = data_source self.raw_data_options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD', compressed_type='') self.example_id_dump_options = dj_pb.ExampleIdDumpOptions( example_id_dump_interval=1, example_id_dump_threshold=1024) self.example_joiner_options = dj_pb.ExampleJoinerOptions( example_joiner='ATTRIBUTION_JOINER', min_matching_window=32, max_matching_window=51200, max_conversion_delay=interval_to_timestamp("124"), enable_negative_example_generator=True, data_block_dump_interval=32, data_block_dump_threshold=128, negative_sampling_rate=0.8, ) if gfile.Exists(self.data_source.output_base_dir): gfile.DeleteRecursively(self.data_source.output_base_dir) if gfile.Exists(self.raw_data_dir): gfile.DeleteRecursively(self.raw_data_dir) self.kvstore = db_client.DBClient('etcd', True) self.kvstore.delete_prefix( common.data_source_kvstore_base_dir( self.data_source.data_source_meta.name)) self.total_raw_data_count = 0 self.total_example_id_count = 0 self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager( self.kvstore, self.data_source) self.g_data_block_index = 0
def test_raw_data_visitor(self):
    self.data_source = common_pb.DataSource()
    self.data_source.data_source_meta.name = 'fclh_test'
    self.data_source.data_source_meta.partition_num = 1
    self.data_source.raw_data_dir = "./test/compressed_raw_data"
    self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                       'fedlearner', True)
    self.etcd.delete_prefix(self.data_source.data_source_meta.name)
    self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
    partition_dir = os.path.join(self.data_source.raw_data_dir,
                                 common.partition_repr(0))
    self.assertTrue(gfile.Exists(partition_dir))
    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.etcd, self.data_source)
    manifest_manager.add_raw_data(
        0,
        [dj_pb.RawDataMeta(file_path=os.path.join(partition_dir, "0-0.idx"),
                           timestamp=timestamp_pb2.Timestamp(seconds=3))],
        True)
    raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_DATASET',
        compressed_type='GZIP'
    )
    rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)
    self.assertTrue(rdm.check_index_meta_by_process_index(0))
    rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source, 0,
                                          raw_data_options)
    expected_index = 0
    for (index, item) in rdv:
        if index > 0 and index % 32 == 0:
            print("{} {}".format(index, item.example_id))
        self.assertEqual(index, expected_index)
        expected_index += 1
    self.assertGreater(expected_index, 0)
def _preprocess_rsa_psi_leader(self):
    processors = []
    rsa_key_pem = None
    with gfile.GFile(self._rsa_private_key_path, 'rb') as f:
        rsa_key_pem = f.read()
    for partition_id in range(
            self._data_source_l.data_source_meta.partition_num):
        options = dj_pb.RsaPsiPreProcessorOptions(
            preprocessor_name='leader-rsa-psi-processor',
            role=common_pb.FLRole.Leader,
            rsa_key_pem=rsa_key_pem,
            input_file_paths=[self._psi_raw_data_fpaths_l[partition_id]],
            output_file_dir=self._pre_processor_ouput_dir_l,
            raw_data_publish_dir=self._raw_data_pub_dir_l,
            partition_id=partition_id,
            offload_processor_number=1,
            max_flying_sign_batch=128,
            stub_fanout=2,
            slow_sign_threshold=8,
            sort_run_merger_read_ahead_buffer=1 << 20,
            sort_run_merger_read_batch_size=128,
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=1024, max_flying_item=1 << 14),
            input_raw_data=dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                                read_ahead_size=1 << 20),
            writer_options=dj_pb.WriterOptions(output_writer='TF_RECORD'))
        os.environ['ETCD_BASE_DIR'] = self.leader_base_dir
        processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
            options, self.kvstore_type, True)
        processor.start_process()
        processors.append(processor)
    for processor in processors:
        processor.wait_for_finished()
def _create_etcd_based_mock_visitor(self):
    return EtcdBasedMockRawDataVisitor(
        self._etcd,
        dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                             read_ahead_size=134217728),
        '{}-proprocessor-mock-data-source-{:04}'.format(
            self._options.preprocessor_name,
            self._options.partition_id),
        self._options.input_file_subscribe_dir)
def __init__(self, options, etcd_name, etcd_addrs,
             etcd_base_dir, use_mock_etcd=False):
    self._lock = threading.Condition()
    self._options = options
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
    pub_dir = self._options.raw_data_publish_dir
    self._publisher = RawDataPublisher(etcd, pub_dir)
    self._process_pool_executor = \
            concur_futures.ProcessPoolExecutor(
                    options.offload_processor_number
                )
    self._id_batch_fetcher = IdBatchFetcher(etcd, self._options)
    max_flying_item = options.batch_processor_options.max_flying_item
    if self._options.role == common_pb.FLRole.Leader:
        private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
        self._psi_rsa_signer = LeaderPsiRsaSigner(
                self._id_batch_fetcher, max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.slow_sign_threshold,
                self._process_pool_executor,
                private_key,
            )
        self._repr = 'leader-' + 'rsa_psi_preprocessor'
    else:
        public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
        self._psi_rsa_signer = FollowerPsiRsaSigner(
                self._id_batch_fetcher, max_flying_item,
                self._options.max_flying_sign_batch,
                self._options.max_flying_sign_rpc,
                self._options.sign_rpc_timeout_ms,
                self._options.slow_sign_threshold,
                self._options.stub_fanout,
                self._process_pool_executor,
                public_key,
                self._options.leader_rsa_psi_signer_addr)
        self._repr = 'follower-' + 'rsa_psi_preprocessor'
    self._sort_run_dumper = SortRunDumper(options)
    self._sort_run_merger = SortRunMerger(
            dj_pb.SortRunMergerOptions(
                merger_name='sort_run_merger_'+\
                        partition_repr(options.partition_id),
                reader_options=dj_pb.RawDataOptions(
                    raw_data_iter=options.writer_options.output_writer,
                    compressed_type=options.writer_options.compressed_type,
                    read_ahead_size=\
                            options.sort_run_merger_read_ahead_buffer
                ),
                writer_options=options.writer_options,
                output_file_dir=options.output_file_dir,
                partition_id=options.partition_id
            ),
            'example_id'
        )
    self._started = False
def __init__(self, options):
    super(IdBatchFetcher, self).__init__(
        options.batch_processor_options.max_flying_item,
    )
    self._id_visitor = BatchRawDataVisitor(
        options.input_file_paths,
        dj_pb.RawDataOptions(raw_data_iter='CSV_DICT')
    )
    self._batch_size = options.batch_processor_options.batch_size
    self.set_input_finished()
def _make_merger_options(self, task):
    return dj_pb.SortRunMergerOptions(
        merger_name="{}-rank_{}".format(task.task_name, self._rank_id),
        reader_options=dj_pb.RawDataOptions(
            raw_data_iter=self._options.writer_options.output_writer,
            compressed_type=self._options.writer_options.compressed_type,
            read_ahead_size=self._options.merger_read_ahead_size,
            read_batch_size=self._options.merger_read_batch_size),
        writer_options=self._options.writer_options,
        output_file_dir=task.reduce_base_dir,
        partition_id=task.partition_id,
    )
def __init__(self, etcd, options):
    super(IdBatchFetcher, self).__init__(
        options.batch_processor_options.max_flying_item,
    )
    self._id_visitor = MockRawDataVisitor(
        etcd,
        dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                             read_ahead_size=134217728),
        '{}-proprocessor-mock-data-source-{:04}'.format(
            options.preprocessor_name, options.partition_id),
        options.input_file_paths)
    self._batch_size = options.batch_processor_options.batch_size
    self.set_input_finished()
def _get_id_visitor(self):
    if self._id_visitor is None:
        self._id_visitor = MockRawDataVisitor(
            self._etcd,
            dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                 read_ahead_size=134217728),
            '{}-proprocessor-mock-data-source-{:04}'.format(
                self._options.preprocessor_name,
                self._options.partition_id
            ),
            self._options.input_file_paths
        )
        self.set_input_finished()
    return self._id_visitor
def _make_portal_worker(self):
    portal_worker_options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=dj_pb.RawDataOptions(raw_data_iter="TF_RECORD",
                                              compressed_type=''),
        writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=128, max_flying_item=300000),
        merge_buffer_size=4096,
        merger_read_ahead_size=1000000)
    self._portal_worker = DataPortalWorker(portal_worker_options,
                                           "localhost:5005", 0,
                                           "test_portal_worker_0",
                                           "portal_worker_0",
                                           "localhost:2379", True)
def _make_portal_worker(self):
    portal_worker_options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=dj_pb.RawDataOptions(raw_data_iter="TF_RECORD",
                                              read_ahead_size=1 << 20,
                                              read_batch_size=128),
        writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=128, max_flying_item=300000),
        merger_read_ahead_size=1000000,
        merger_read_batch_size=128)
    self._portal_worker = DataPortalWorker(portal_worker_options,
                                           "localhost:5005", 0,
                                           "test_portal_worker_0",
                                           "portal_worker_0",
                                           "localhost:2379",
                                           "test_user", "test_password",
                                           True)
def init(self, dsname, joiner_name, version=Version.V1, cache_type="memory"): data_source = common_pb.DataSource() data_source.data_source_meta.name = dsname data_source.data_source_meta.partition_num = 1 data_source.output_base_dir = "%s_ds_output" % dsname self.raw_data_dir = "%s_raw_data" % dsname self.data_source = data_source self.raw_data_options = dj_pb.RawDataOptions( raw_data_iter='TF_RECORD', compressed_type='', raw_data_cache_type=cache_type, ) self.example_id_dump_options = dj_pb.ExampleIdDumpOptions( example_id_dump_interval=1, example_id_dump_threshold=1024) self.example_joiner_options = dj_pb.ExampleJoinerOptions( example_joiner=joiner_name, min_matching_window=32, max_matching_window=51200, max_conversion_delay=interval_to_timestamp("124"), enable_negative_example_generator=True, data_block_dump_interval=32, data_block_dump_threshold=128, negative_sampling_rate=0.8, join_expr="example_id", join_key_mapper="DEFAULT", negative_sampling_filter_expr='', ) if gfile.Exists(self.data_source.output_base_dir): gfile.DeleteRecursively(self.data_source.output_base_dir) if gfile.Exists(self.raw_data_dir): gfile.DeleteRecursively(self.raw_data_dir) self.kvstore = db_client.DBClient('etcd', True) self.kvstore.delete_prefix( common.data_source_kvstore_base_dir( self.data_source.data_source_meta.name)) self.total_raw_data_count = 0 self.total_example_id_count = 0 self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager( self.kvstore, self.data_source) self.g_data_block_index = 0 self.version = version
def _launch_portals(self):
    portal_options = dj_pb.DataJoinPotralOptions(
        example_validator=dj_pb.ExampleValidatorOptions(
            example_validator='EXAMPLE_VALIDATOR',
            validate_event_time=True),
        reducer_buffer_size=1024,
        raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                              compressed_type=''),
        use_mock_etcd=True)
    self._portal_l = data_join_portal.DataJoinPortal(
        self._portal_name, self._etcd_name, self._etcd_addrs,
        self._etcd_base_dir_l, portal_options)
    self._portal_f = data_join_portal.DataJoinPortal(
        self._portal_name, self._etcd_name, self._etcd_addrs,
        self._etcd_base_dir_f, portal_options)
    self._portal_l.start()
    self._portal_f.start()
def _next_internal(self):
    if not self._finished:
        try:
            item = None
            if self._fiter is None:
                raw_data_options = \
                        dj_pb.RawDataOptions(raw_data_iter='CSV_DICT')
                self._fiter = CsvDictIter(raw_data_options)
                meta = visitor.IndexMeta(0, 0, self._fpath)
                self._fiter.reset_iter(meta, True)
                item = self._fiter.get_item()
            else:
                _, item = next(self._fiter)
            assert item is not None
            return SortRunReader.MergeItem(item, self._reader_index)
        except StopIteration:
            self._finished = True
    raise StopIteration("%s has been iter finished" % self._fpath)
def _launch_workers(self):
    worker_options = dj_pb.DataJoinWorkerOptions(
        use_mock_etcd=True,
        raw_data_options=dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                              compressed_type=''),
        example_id_dump_options=dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1, example_id_dump_threshold=1024),
        example_joiner_options=dj_pb.ExampleJoinerOptions(
            example_joiner='SORT_RUN_JOINER', min_matching_window=64,
            max_matching_window=256, data_block_dump_interval=30,
            data_block_dump_threshold=1000),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=1024, max_flying_item=4096),
        data_block_builder_options=dj_pb.DataBlockBuilderOptions(
            data_block_builder='CSV_DICT_DATABLOCK_BUILDER'))
    self._worker_addrs_l = [
        'localhost:4161', 'localhost:4162',
        'localhost:4163', 'localhost:4164'
    ]
    self._worker_addrs_f = [
        'localhost:5161', 'localhost:5162',
        'localhost:5163', 'localhost:5164'
    ]
    self._workers_l = []
    self._workers_f = []
    for rank_id in range(4):
        worker_addr_l = self._worker_addrs_l[rank_id]
        worker_addr_f = self._worker_addrs_f[rank_id]
        self._workers_l.append(
            data_join_worker.DataJoinWorkerService(
                int(worker_addr_l.split(':')[1]),
                worker_addr_f, self._master_addr_l, rank_id,
                self._etcd_name, self._etcd_base_dir_l,
                self._etcd_addrs, worker_options))
        self._workers_f.append(
            data_join_worker.DataJoinWorkerService(
                int(worker_addr_f.split(':')[1]),
                worker_addr_l, self._master_addr_f, rank_id,
                self._etcd_name, self._etcd_base_dir_f,
                self._etcd_addrs, worker_options))
    for w in self._workers_l:
        w.start()
    for w in self._workers_f:
        w.start()
def _prepare_test(self):
    self._portal_manifest = common_pb.DataJoinPortalManifest(
        name='test_portal',
        input_partition_num=4,
        output_partition_num=8,
        input_data_base_dir='./portal_input',
        output_data_base_dir='./portal_output')
    self._portal_options = dj_pb.DataJoinPotralOptions(
        example_validator=dj_pb.ExampleValidatorOptions(
            example_validator='EXAMPLE_VALIDATOR',
            validate_event_time=True,
        ),
        reducer_buffer_size=128,
        raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD'),
        use_mock_etcd=True)
    self._date_time = common.convert_timestamp_to_datetime(
        common.trim_timestamp_by_hourly(
            common.convert_datetime_to_timestamp(datetime.now())))
    self._generate_input_data()
def _make_portal_worker(self, raw_data_iter, validation_ratio):
    portal_worker_options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter=raw_data_iter,
            read_ahead_size=1 << 20,
            read_batch_size=128,
            optional_fields=['label'],
            validation_ratio=validation_ratio,
        ),
        writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=128, max_flying_item=300000),
        merger_read_ahead_size=1000000,
        merger_read_batch_size=128)
    os.environ['ETCD_BASE_DIR'] = "portal_worker_0"
    self._portal_worker = DataPortalWorker(portal_worker_options,
                                           "localhost:5005", 0,
                                           "etcd", True)
def _launch_workers(self):
    worker_options = dj_pb.DataJoinWorkerOptions(
        use_mock_etcd=True,
        raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                              compressed_type=''),
        example_id_dump_options=dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1, example_id_dump_threshold=1024),
        example_joiner_options=dj_pb.ExampleJoinerOptions(
            example_joiner='STREAM_JOINER', min_matching_window=64,
            max_matching_window=256, data_block_dump_interval=30,
            data_block_dump_threshold=1000),
        example_id_batch_options=dj_pb.ExampleIdBatchOptions(
            example_id_batch_size=1024, max_flying_example_id=4096))
    self._worker_addrs_l = ['localhost:4161', 'localhost:4162']
    self._worker_addrs_f = ['localhost:5161', 'localhost:5162']
    self._workers_l = []
    self._workers_f = []
    for rank_id in range(2):
        worker_addr_l = self._worker_addrs_l[rank_id]
        worker_addr_f = self._worker_addrs_f[rank_id]
        self._workers_l.append(
            data_join_worker.DataJoinWorkerService(
                int(worker_addr_l.split(':')[1]),
                worker_addr_f, self._master_addr_l, rank_id,
                self._etcd_name, self._etcd_base_dir_l,
                self._etcd_addrs, worker_options))
        self._workers_f.append(
            data_join_worker.DataJoinWorkerService(
                int(worker_addr_f.split(':')[1]),
                worker_addr_l, self._master_addr_f, rank_id,
                self._etcd_name, self._etcd_base_dir_f,
                self._etcd_addrs, worker_options))
    for w in self._workers_l:
        w.start()
    for w in self._workers_f:
        w.start()
    args = parser.parse_args()
    set_logger()
    if args.input_data_file_iter == 'TF_RECORD' or \
            args.output_builder == 'TF_RECORD':
        import tensorflow
        tensorflow.compat.v1.enable_eager_execution()
    optional_fields = list(
        field for field in map(str.strip, args.optional_fields.split(','))
        if field != '')
    portal_worker_options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter=args.input_data_file_iter,
            compressed_type=args.compressed_type,
            read_ahead_size=args.read_ahead_size,
            read_batch_size=args.read_batch_size,
            optional_fields=optional_fields),
        writer_options=dj_pb.WriterOptions(
            output_writer=args.output_builder,
            compressed_type=args.builder_compressed_type),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=args.batch_size,
            max_flying_item=-1),
        merger_read_ahead_size=args.merger_read_ahead_size,
        merger_read_batch_size=args.merger_read_batch_size,
        memory_limit_ratio=args.memory_limit_ratio / 100)
    data_portal_worker = DataPortalWorker(portal_worker_options,
                                          args.master_addr, args.rank_id,
                                          args.kvstore_type,
                                          (args.kvstore_type == 'mock'))
    data_portal_worker.start()
def test_data_block_dumper(self):
    self.generate_follower_data_block()
    self.generate_leader_raw_data()
    dbd = data_block_dumper.DataBlockDumperManager(
        self.etcd, self.data_source_l, 0,
        dj_pb.RawDataOptions(raw_data_iter='TF_RECORD'),
        dj_pb.DataBlockBuilderOptions(
            data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
    )
    self.assertEqual(dbd.get_next_data_block_index(), 0)
    for (idx, meta) in enumerate(self.dumped_metas):
        success, next_index = dbd.add_synced_data_block_meta(meta)
        self.assertTrue(success)
        self.assertEqual(next_index, idx + 1)
    self.assertTrue(dbd.need_dump())
    self.assertEqual(dbd.get_next_data_block_index(), len(self.dumped_metas))
    with dbd.make_data_block_dumper() as dumper:
        dumper()
    dbm_f = data_block_manager.DataBlockManager(self.data_source_f, 0)
    dbm_l = data_block_manager.DataBlockManager(self.data_source_l, 0)
    self.assertEqual(dbm_f.get_dumped_data_block_count(),
                     len(self.dumped_metas))
    self.assertEqual(dbm_f.get_dumped_data_block_count(),
                     dbm_l.get_dumped_data_block_count())
    for (idx, meta) in enumerate(self.dumped_metas):
        self.assertEqual(meta.data_block_index, idx)
        self.assertEqual(dbm_l.get_data_block_meta_by_index(idx), meta)
        self.assertEqual(dbm_f.get_data_block_meta_by_index(idx), meta)
        meta_fpth_l = os.path.join(
            self.data_source_l.data_block_dir, common.partition_repr(0),
            common.encode_data_block_meta_fname(
                self.data_source_l.data_source_meta.name,
                0, meta.data_block_index))
        mitr = tf.io.tf_record_iterator(meta_fpth_l)
        meta_l = text_format.Parse(next(mitr), dj_pb.DataBlockMeta())
        self.assertEqual(meta_l, meta)
        meta_fpth_f = os.path.join(
            self.data_source_f.data_block_dir, common.partition_repr(0),
            common.encode_data_block_meta_fname(
                self.data_source_f.data_source_meta.name,
                0, meta.data_block_index))
        mitr = tf.io.tf_record_iterator(meta_fpth_f)
        meta_f = text_format.Parse(next(mitr), dj_pb.DataBlockMeta())
        self.assertEqual(meta_f, meta)
        data_fpth_l = os.path.join(
            self.data_source_l.data_block_dir, common.partition_repr(0),
            common.encode_data_block_fname(
                self.data_source_l.data_source_meta.name, meta_l))
        for (iidx, record) in enumerate(
                tf.io.tf_record_iterator(data_fpth_l)):
            example = tf.train.Example()
            example.ParseFromString(record)
            feat = example.features.feature
            self.assertEqual(feat['example_id'].bytes_list.value[0],
                             meta.example_ids[iidx])
        self.assertEqual(len(meta.example_ids), iidx + 1)
        data_fpth_f = os.path.join(
            self.data_source_f.data_block_dir, common.partition_repr(0),
            common.encode_data_block_fname(
                self.data_source_l.data_source_meta.name, meta_f))
        for (iidx, record) in enumerate(
                tf.io.tf_record_iterator(data_fpth_f)):
            example = tf.train.Example()
            example.ParseFromString(record)
            feat = example.features.feature
            self.assertEqual(feat['example_id'].bytes_list.value[0],
                             meta.example_ids[iidx])
        self.assertEqual(len(meta.example_ids), iidx + 1)
        max_flying_sign_batch=args.max_flying_sign_batch,
        max_flying_sign_rpc=args.max_flying_sign_rpc,
        sign_rpc_timeout_ms=args.sign_rpc_timeout_ms,
        stub_fanout=args.stub_fanout,
        slow_sign_threshold=args.slow_sign_threshold,
        sort_run_merger_read_ahead_buffer=\
            args.sort_run_merger_read_ahead_buffer,
        sort_run_merger_read_batch_size=\
            args.sort_run_merger_read_batch_size,
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=args.process_batch_size,
            max_flying_item=-1
        ),
        input_raw_data=dj_pb.RawDataOptions(
            raw_data_iter=args.raw_data_iter,
            compressed_type=args.compressed_type,
            read_ahead_size=args.read_ahead_size,
            read_batch_size=args.read_batch_size
        ),
        writer_options=dj_pb.WriterOptions(
            output_writer=args.output_builder,
            compressed_type=args.builder_compressed_type,
        )
    )
    if args.psi_role.upper() == 'LEADER':
        preprocessor_options.role = common_pb.FLRole.Leader
    else:
        assert args.psi_role.upper() == 'FOLLOWER'
        preprocessor_options.role = common_pb.FLRole.Follower
    preprocessor = RsaPsiPreProcessor(preprocessor_options, args.etcd_name,
                                      args.etcd_addrs, args.etcd_base_dir)
    preprocessor.start_process()
def setUp(self):
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    etcd_l.delete_prefix(common.data_source_etcd_base_dir(data_source_name))
    etcd_f.delete_prefix(common.data_source_etcd_base_dir(data_source_name))
    data_source_l = common_pb.DataSource()
    self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
    data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.output_base_dir = "./ds_output_l"
    self.raw_data_dir_l = "./raw_data_l"
    data_source_f = common_pb.DataSource()
    self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.output_base_dir = "./ds_output_f"
    self.raw_data_dir_f = "./raw_data_f"
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 2
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_f, data_source_f)
    self.etcd_l = etcd_l
    self.etcd_f = etcd_f
    self.data_source_l = data_source_l
    self.data_source_f = data_source_f
    self.data_source_name = data_source_name
    self.etcd_name = etcd_name
    self.etcd_addrs = etcd_addrs
    self.etcd_base_dir_l = etcd_base_dir_l
    self.etcd_base_dir_f = etcd_base_dir_f
    self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
        self.etcd_l, self.raw_data_pub_dir_l
    )
    self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
        self.etcd_f, self.raw_data_pub_dir_f
    )
    if gfile.Exists(data_source_l.output_base_dir):
        gfile.DeleteRecursively(data_source_l.output_base_dir)
    if gfile.Exists(self.raw_data_dir_l):
        gfile.DeleteRecursively(self.raw_data_dir_l)
    if gfile.Exists(data_source_f.output_base_dir):
        gfile.DeleteRecursively(data_source_f.output_base_dir)
    if gfile.Exists(self.raw_data_dir_f):
        gfile.DeleteRecursively(self.raw_data_dir_f)
    self.worker_options = dj_pb.DataJoinWorkerOptions(
        use_mock_etcd=True,
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter='TF_RECORD',
            read_ahead_size=1<<20,
            read_batch_size=128
        ),
        example_id_dump_options=dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1,
            example_id_dump_threshold=1024
        ),
        example_joiner_options=dj_pb.ExampleJoinerOptions(
            example_joiner='STREAM_JOINER',
            min_matching_window=64,
            max_matching_window=256,
            data_block_dump_interval=30,
            data_block_dump_threshold=1000
        ),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=512,
            max_flying_item=2048
        ),
        data_block_builder_options=dj_pb.WriterOptions(
            output_writer='TF_RECORD'
        )
    )
    self.total_index = 1 << 12
                        type=int, default=4096,
                        help='the buffer size of portal reducer')
    parser.add_argument('--example_validator', type=str,
                        default='EXAMPLE_VALIDATOR',
                        help='the name of example validator')
    parser.add_argument('--validate_event_time', action='store_true',
                        help='validate the example has event time')
    args = parser.parse_args()
    options = dj_pb.DataJoinPotralOptions(
        example_validator=dj_pb.ExampleValidatorOptions(
            example_validator=args.example_validator,
            validate_event_time=args.validate_event_time),
        reducer_buffer_size=args.portal_reducer_buffer_size,
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter=args.input_data_file_iter,
            compressed_type=args.compressed_type),
        raw_data_publish_dir=args.raw_data_publish_dir,
        use_mock_etcd=args.use_mock_etcd)
    portal_srv = DataJoinPortalService(
        args.listen_port, args.data_join_portal_name,
        args.etcd_name, args.etcd_addrs, args.etcd_base_dir,
        options,
    )
    portal_srv.run()
def setUp(self):
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    etcd_l.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))
    etcd_f.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))
    data_source_l = common_pb.DataSource()
    self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
    data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"
    data_source_f = common_pb.DataSource()
    self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 2
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_f, data_source_f)
    master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l,
        etcd_addrs, master_options,
    )
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f,
        etcd_addrs, master_options)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = master_client_l.GetDataSourceStatus(req_l)
        dss_f = master_client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Processing and \
                dss_f.state == common_pb.DataSourceState.Processing:
            break
        else:
            time.sleep(2)
    self.master_client_l = master_client_l
    self.master_client_f = master_client_f
    self.master_addr_l = master_addr_l
    self.master_addr_f = master_addr_f
    self.etcd_l = etcd_l
    self.etcd_f = etcd_f
    self.data_source_l = data_source_l
    self.data_source_f = data_source_f
    self.master_l = master_l
    self.master_f = master_f
    self.data_source_name = data_source_name
    self.etcd_name = etcd_name
    self.etcd_addrs = etcd_addrs
    self.etcd_base_dir_l = etcd_base_dir_l
    self.etcd_base_dir_f = etcd_base_dir_f
    self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
        self.etcd_l, self.raw_data_pub_dir_l)
    self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
        self.etcd_f, self.raw_data_pub_dir_f)
    if gfile.Exists(data_source_l.data_block_dir):
        gfile.DeleteRecursively(data_source_l.data_block_dir)
    if gfile.Exists(data_source_l.example_dumped_dir):
        gfile.DeleteRecursively(data_source_l.example_dumped_dir)
    if gfile.Exists(data_source_l.raw_data_dir):
        gfile.DeleteRecursively(data_source_l.raw_data_dir)
    if gfile.Exists(data_source_f.data_block_dir):
        gfile.DeleteRecursively(data_source_f.data_block_dir)
    if gfile.Exists(data_source_f.example_dumped_dir):
        gfile.DeleteRecursively(data_source_f.example_dumped_dir)
    if gfile.Exists(data_source_f.raw_data_dir):
        gfile.DeleteRecursively(data_source_f.raw_data_dir)
    self.worker_options = dj_pb.DataJoinWorkerOptions(
        use_mock_etcd=True,
        raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                              compressed_type=''),
        example_id_dump_options=dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1, example_id_dump_threshold=1024),
        example_joiner_options=dj_pb.ExampleJoinerOptions(
            example_joiner='STREAM_JOINER', min_matching_window=64,
            max_matching_window=256, data_block_dump_interval=30,
            data_block_dump_threshold=1000),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=512, max_flying_item=2048),
        data_block_builder_options=dj_pb.WriterOptions(
            output_writer='TF_RECORD'))
    self.total_index = 1 << 13