Example #1
 def test_compressed_raw_data_visitor(self):
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.raw_data_dir = path.join(
             path.dirname(path.abspath(__file__)), "../compressed_raw_data"
         )
     self.kvstore = DBClient('etcd', True)
     self.kvstore.delete_prefix(common.data_source_kvstore_base_dir(self.data_source.data_source_meta.name))
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
     manifest_manager.add_raw_data(
             0, [dj_pb.RawDataMeta(file_path=path.join(partition_dir, "0-0.idx"),
                                   timestamp=timestamp_pb2.Timestamp(seconds=3))],
             True)
     raw_data_options = dj_pb.RawDataOptions(
             raw_data_iter='TF_RECORD',
             compressed_type='GZIP',
             read_ahead_size=1<<20,
             read_batch_size=128
         )
     rdm = raw_data_visitor.RawDataManager(self.kvstore, self.data_source, 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))
     rdv = raw_data_visitor.RawDataVisitor(self.kvstore, self.data_source, 0,
                                           raw_data_options)
     expected_index = 0
     for (index, item) in rdv:
         if index > 0 and index % 32 == 0:
             print("{} {}".format(index, item.example_id))
         self.assertEqual(index, expected_index)
         expected_index += 1
     self.assertGreater(expected_index, 0)
Example #2
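 # Constructor (apparently RsaPsiPreProcessor): wires a RawDataPublisher, a
 # process pool for offloading signing work, an IdBatchFetcher, a role-specific
 # RSA signer (leader uses the private key, follower the public key), and a
 # SortRunDumper plus a SortRunMerger whose reader_options mirror writer_options.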
 def __init__(self, options, kvstore_type, use_mock_etcd=False):
     self._lock = threading.Condition()
     self._options = options
     kvstore = DBClient(kvstore_type, use_mock_etcd)
     pub_dir = self._options.raw_data_publish_dir
     self._publisher = RawDataPublisher(kvstore, pub_dir)
     self._process_pool_executor = \
             concur_futures.ProcessPoolExecutor(
                     options.offload_processor_number
                 )
     self._callback_submitter = None
     # pre-fork sub-processes before launching the gRPC client
     self._process_pool_executor.submit(min, 1, 2).result()
     self._id_batch_fetcher = IdBatchFetcher(kvstore, self._options)
     if self._options.role == common_pb.FLRole.Leader:
         private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = LeaderPsiRsaSigner(
             self._id_batch_fetcher,
             options.batch_processor_options.max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.slow_sign_threshold,
             self._process_pool_executor,
             private_key,
         )
         self._repr = 'leader-' + 'rsa_psi_preprocessor'
     else:
         public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
         self._callback_submitter = concur_futures.ThreadPoolExecutor(1)
         self._psi_rsa_signer = FollowerPsiRsaSigner(
             self._id_batch_fetcher,
             options.batch_processor_options.max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.max_flying_sign_rpc,
             self._options.sign_rpc_timeout_ms,
             self._options.slow_sign_threshold, self._options.stub_fanout,
             self._process_pool_executor, self._callback_submitter,
             public_key, self._options.leader_rsa_psi_signer_addr)
         self._repr = 'follower-' + 'rsa_psi_preprocessor'
     self._sort_run_dumper = SortRunDumper(options)
     self._sort_run_merger = SortRunMerger(
             dj_pb.SortRunMergerOptions(
                 merger_name='sort_run_merger_'+\
                             partition_repr(options.partition_id),
                 reader_options=dj_pb.RawDataOptions(
                     raw_data_iter=options.writer_options.output_writer,
                     compressed_type=options.writer_options.compressed_type,
                     read_ahead_size=\
                         options.sort_run_merger_read_ahead_buffer,
                     read_batch_size=\
                         options.sort_run_merger_read_batch_size
                 ),
                 writer_options=options.writer_options,
                 output_file_dir=options.output_file_dir,
                 partition_id=options.partition_id
             ),
             self._merger_comparator
         )
     self._produce_item_cnt = 0
     self._comsume_item_cnt = 0
     self._started = False
Example #3
 def setUp(self):
     data_source = common_pb.DataSource()
     data_source.data_source_meta.name = "milestone-f"
     data_source.data_source_meta.partition_num = 1
     data_source.output_base_dir = "./ds_output"
     self.raw_data_dir = "./raw_data"
     self.data_source = data_source
     self.raw_data_options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type='',
                                                  optional_fields=['label'])
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1, example_id_dump_threshold=1024)
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         example_joiner='STREAM_JOINER',
         min_matching_window=32,
         max_matching_window=128,
         data_block_dump_interval=30,
         data_block_dump_threshold=128)
     if gfile.Exists(self.data_source.output_base_dir):
         gfile.DeleteRecursively(self.data_source.output_base_dir)
     if gfile.Exists(self.raw_data_dir):
         gfile.DeleteRecursively(self.raw_data_dir)
     self.kvstore = mysql_client.DBClient('test_cluster', 'localhost:2379',
                                          'test_user', 'test_password',
                                          'fedlearner', True)
     self.kvstore.delete_prefix(
         common.data_source_kvstore_base_dir(
             self.data_source.data_source_meta.name))
     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
     self.g_data_block_index = 0
Example #4
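 # Test fixture: data source with data_block/example_id/raw_data directories,
 # uncompressed TF_RECORD RawDataOptions, example-id dump and STREAM_JOINER
 # options, and an etcd-backed RawDataManifestManager.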
 def setUp(self):
     data_source = common_pb.DataSource()
     data_source.data_source_meta.name = "milestone-f"
     data_source.data_source_meta.partition_num = 1
     data_source.data_block_dir = "./data_block"
     data_source.example_dumped_dir = "./example_id"
     data_source.raw_data_dir = "./raw_data"
     self.data_source = data_source
     self.raw_data_options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type='')
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1, example_id_dump_threshold=1024)
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         example_joiner='STREAM_JOINER',
         min_matching_window=32,
         max_matching_window=128,
         data_block_dump_interval=30,
         data_block_dump_threshold=128)
     if gfile.Exists(self.data_source.data_block_dir):
         gfile.DeleteRecursively(self.data_source.data_block_dir)
     if gfile.Exists(self.data_source.example_dumped_dir):
         gfile.DeleteRecursively(self.data_source.example_dumped_dir)
     if gfile.Exists(self.data_source.raw_data_dir):
         gfile.DeleteRecursively(self.data_source.raw_data_dir)
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(self.data_source.data_source_meta.name)
     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     self.g_data_block_index = 0
Example #5
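 # Follower-side RSA-PSI preprocessing: one RsaPsiPreProcessorOptions per
 # partition (CSV_DICT input, TF_RECORD output), each started as an
 # RsaPsiPreProcessor and then waited on until finished.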
 def _preprocess_rsa_psi_follower(self):
     processors = []
     rsa_key_pem = None
     with gfile.GFile(self._rsa_public_key_path, 'rb') as f:
         rsa_key_pem = f.read()
     for partition_id in range(
             self._data_source_f.data_source_meta.partition_num):
         options = dj_pb.RsaPsiPreProcessorOptions(
             preprocessor_name='follower-rsa-psi-processor',
             role=common_pb.FLRole.Follower,
             rsa_key_pem=rsa_key_pem,
             input_file_paths=[self._psi_raw_data_fpaths_f[partition_id]],
             output_file_dir=self._pre_processor_ouput_dir_f,
             raw_data_publish_dir=self._raw_data_pub_dir_f,
             partition_id=partition_id,
             leader_rsa_psi_signer_addr=self._rsa_psi_signer_addr,
             offload_processor_number=1,
             max_flying_sign_batch=128,
             max_flying_sign_rpc=64,
             sign_rpc_timeout_ms=100000,
             stub_fanout=2,
             slow_sign_threshold=8,
             sort_run_merger_read_ahead_buffer=1 << 20,
             batch_processor_options=dj_pb.BatchProcessorOptions(
                 batch_size=1024, max_flying_item=1 << 14),
             input_raw_data=dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                                 read_ahead_size=1 << 20),
             writer_options=dj_pb.WriterOptions(output_writer='TF_RECORD'))
         processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
             options, self._etcd_name, self._etcd_addrs,
             self._etcd_base_dir_f, True)
         processor.start_process()
         processors.append(processor)
     for processor in processors:
         processor.wait_for_finished()
Example #6
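 # Visits a CSV raw-data partition through RawDataVisitor with CSV_DICT
 # RawDataOptions and a 1 MiB read-ahead, asserting contiguous indices and
 # an expected total of 4999 records.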
 def test_csv_raw_data_visitor(self):
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.raw_data_dir = path.join(path.dirname(path.abspath(__file__)),
                                   "../csv_raw_data")
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(
         common.data_source_etcd_base_dir(
             self.data_source.data_source_meta.name))
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     manifest_manager.add_raw_data(0, [
         dj_pb.RawDataMeta(file_path=path.join(partition_dir,
                                               "test_raw_data.csv"),
                           timestamp=timestamp_pb2.Timestamp(seconds=3))
     ], True)
     raw_data_options = dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                             read_ahead_size=1 << 20)
     rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))
     rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source, 0,
                                           raw_data_options)
     expected_index = 0
     for (index, item) in rdv:
         if index > 0 and index % 1024 == 0:
             print("{} {}".format(index, item.raw_id))
         self.assertEqual(index, expected_index)
         expected_index += 1
     self.assertEqual(expected_index, 4999)
Example #7
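 # Fixture for an ATTRIBUTION_JOINER test: uncompressed TF_RECORD raw data,
 # negative example generation at a 0.8 sampling rate, and a kvstore-backed
 # RawDataManifestManager.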
 def setUp(self):
     data_source = common_pb.DataSource()
     data_source.data_source_meta.name = "milestone-f"
     data_source.data_source_meta.partition_num = 1
     data_source.output_base_dir = "./ds_output"
     self.raw_data_dir = "./raw_data"
     self.data_source = data_source
     self.raw_data_options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type='')
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1, example_id_dump_threshold=1024)
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         example_joiner='ATTRIBUTION_JOINER',
         min_matching_window=32,
         max_matching_window=51200,
         max_conversion_delay=interval_to_timestamp("124"),
         enable_negative_example_generator=True,
         data_block_dump_interval=32,
         data_block_dump_threshold=128,
         negative_sampling_rate=0.8,
     )
     if gfile.Exists(self.data_source.output_base_dir):
         gfile.DeleteRecursively(self.data_source.output_base_dir)
     if gfile.Exists(self.raw_data_dir):
         gfile.DeleteRecursively(self.raw_data_dir)
     self.kvstore = db_client.DBClient('etcd', True)
     self.kvstore.delete_prefix(
         common.data_source_kvstore_base_dir(
             self.data_source.data_source_meta.name))
     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
     self.g_data_block_index = 0
Example #8
 def test_raw_data_visitor(self):
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.data_source.raw_data_dir = "./test/compressed_raw_data"
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(self.data_source.data_source_meta.name)
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = os.path.join(self.data_source.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     manifest_manager.add_raw_data(
             0, [dj_pb.RawDataMeta(file_path=os.path.join(partition_dir, "0-0.idx"),
                                   timestamp=timestamp_pb2.Timestamp(seconds=3))],
             True)
     raw_data_options = dj_pb.RawDataOptions(
             raw_data_iter='TF_DATASET',
             compressed_type='GZIP'
         )
     rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))
     rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source, 0,
                                           raw_data_options)
     expected_index = 0
     for (index, item) in rdv:
         if index > 0 and index % 32 == 0:
             print("{} {}".format(index, item.example_id))
         self.assertEqual(index, expected_index)
         expected_index += 1
     self.assertGreater(expected_index, 0)
Example #9
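 # Leader-side RSA-PSI preprocessing: loads the private-key PEM, builds
 # per-partition options (CSV_DICT input, TF_RECORD output, 1 MiB merger
 # read-ahead), and runs an RsaPsiPreProcessor for each partition.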
 def _preprocess_rsa_psi_leader(self):
     processors = []
     rsa_key_pem = None
     with gfile.GFile(self._rsa_private_key_path, 'rb') as f:
         rsa_key_pem = f.read()
     for partition_id in range(
             self._data_source_l.data_source_meta.partition_num):
         options = dj_pb.RsaPsiPreProcessorOptions(
             preprocessor_name='leader-rsa-psi-processor',
             role=common_pb.FLRole.Leader,
             rsa_key_pem=rsa_key_pem,
             input_file_paths=[self._psi_raw_data_fpaths_l[partition_id]],
             output_file_dir=self._pre_processor_ouput_dir_l,
             raw_data_publish_dir=self._raw_data_pub_dir_l,
             partition_id=partition_id,
             offload_processor_number=1,
             max_flying_sign_batch=128,
             stub_fanout=2,
             slow_sign_threshold=8,
             sort_run_merger_read_ahead_buffer=1 << 20,
             sort_run_merger_read_batch_size=128,
             batch_processor_options=dj_pb.BatchProcessorOptions(
                 batch_size=1024, max_flying_item=1 << 14),
             input_raw_data=dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                                 read_ahead_size=1 << 20),
             writer_options=dj_pb.WriterOptions(output_writer='TF_RECORD'))
         os.environ['ETCD_BASE_DIR'] = self.leader_base_dir
         processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
             options, self.kvstore_type, True)
         processor.start_process()
         processors.append(processor)
     for processor in processors:
         processor.wait_for_finished()
Example #10
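 # Builds an EtcdBasedMockRawDataVisitor over CSV_DICT raw data with a 128 MiB
 # read-ahead, named after the preprocessor and subscribed to its input file
 # directory.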
 def _create_etcd_based_mock_visitor(self):
     return EtcdBasedMockRawDataVisitor(
         self._etcd,
         dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                              read_ahead_size=134217728),
         '{}-proprocessor-mock-data-source-{:04}'.format(
             self._options.preprocessor_name, self._options.partition_id),
         self._options.input_file_subscribe_dir)
Example #11
 def __init__(self,
              options,
              etcd_name,
              etcd_addrs,
              etcd_base_dir,
              use_mock_etcd=False):
     self._lock = threading.Condition()
     self._options = options
     etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
     pub_dir = self._options.raw_data_publish_dir
     self._publisher = RawDataPublisher(etcd, pub_dir)
     self._process_pool_executor = \
             concur_futures.ProcessPoolExecutor(
                     options.offload_processor_number
                 )
     self._id_batch_fetcher = IdBatchFetcher(etcd, self._options)
     max_flying_item = options.batch_processor_options.max_flying_item
     if self._options.role == common_pb.FLRole.Leader:
         private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = LeaderPsiRsaSigner(
             self._id_batch_fetcher,
             max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.slow_sign_threshold,
             self._process_pool_executor,
             private_key,
         )
         self._repr = 'leader-' + 'rsa_psi_preprocessor'
     else:
         public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = FollowerPsiRsaSigner(
             self._id_batch_fetcher, max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.max_flying_sign_rpc,
             self._options.sign_rpc_timeout_ms,
             self._options.slow_sign_threshold, self._options.stub_fanout,
             self._process_pool_executor, public_key,
             self._options.leader_rsa_psi_signer_addr)
         self._repr = 'follower-' + 'rsa_psi_preprocessor'
     self._sort_run_dumper = SortRunDumper(options)
     self._sort_run_merger = SortRunMerger(
             dj_pb.SortRunMergerOptions(
                 merger_name='sort_run_merger_'+\
                             partition_repr(options.partition_id),
                 reader_options=dj_pb.RawDataOptions(
                     raw_data_iter=options.writer_options.output_writer,
                     compressed_type=options.writer_options.compressed_type,
                     read_ahead_size=\
                         options.sort_run_merger_read_ahead_buffer
                 ),
                 writer_options=options.writer_options,
                 output_file_dir=options.output_file_dir,
                 partition_id=options.partition_id
             ),
             'example_id'
         )
     self._started = False
Example #12
 def __init__(self, options):
     super(IdBatchFetcher, self).__init__(
             options.batch_processor_options.max_flying_item,
         )
     self._id_visitor = BatchRawDataVisitor(
             options.input_file_paths,
             dj_pb.RawDataOptions(raw_data_iter='CSV_DICT')
         )
     self._batch_size = options.batch_processor_options.batch_size
     self.set_input_finished()
Example #13
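 # Builds SortRunMergerOptions for a reduce task: reader_options mirror the
 # writer's output format and compression so the merger can read back what was
 # written, with configurable read-ahead and batch sizes.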
 def _make_merger_options(self, task):
     return dj_pb.SortRunMergerOptions(
         merger_name="{}-rank_{}".format(task.task_name, self._rank_id),
         reader_options=dj_pb.RawDataOptions(
             raw_data_iter=self._options.writer_options.output_writer,
             compressed_type=self._options.writer_options.compressed_type,
             read_ahead_size=self._options.merger_read_ahead_size,
             read_batch_size=self._options.merger_read_batch_size),
         writer_options=self._options.writer_options,
         output_file_dir=task.reduce_base_dir,
         partition_id=task.partition_id,
     )
Example #14
 def __init__(self, etcd, options):
     super(IdBatchFetcher,
           self).__init__(options.batch_processor_options.max_flying_item, )
     self._id_visitor = MockRawDataVisitor(
         etcd,
         dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                              read_ahead_size=134217728),
         '{}-proprocessor-mock-data-source-{:04}'.format(
             options.preprocessor_name,
             options.partition_id), options.input_file_paths)
     self._batch_size = options.batch_processor_options.batch_size
     self.set_input_finished()
Example #15
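 # Lazily constructs a MockRawDataVisitor over the configured input files and
 # marks input as finished the first time the visitor is created.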
 def _get_id_visitor(self):
     if self._id_visitor is None:
         self._id_visitor = MockRawDataVisitor(
                 self._etcd,
                 dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                      read_ahead_size=134217728),
                 '{}-proprocessor-mock-data-source-{:04}'.format(
                     self._options.preprocessor_name,
                     self._options.partition_id
                 ),
                 self._options.input_file_paths
             )
         self.set_input_finished()
     return self._id_visitor
Example #16
    def _make_portal_worker(self):
        portal_worker_options = dp_pb.DataPortalWorkerOptions(
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter="TF_RECORD",
                                                  compressed_type=''),
            writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=128, max_flying_item=300000),
            merge_buffer_size=4096,
            merger_read_ahead_size=1000000)

        self._portal_worker = DataPortalWorker(portal_worker_options,
                                               "localhost:5005", 0,
                                               "test_portal_worker_0",
                                               "portal_worker_0",
                                               "localhost:2379", True)
Example #17
    def _make_portal_worker(self):
        portal_worker_options = dp_pb.DataPortalWorkerOptions(
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter="TF_RECORD",
                                                  read_ahead_size=1 << 20,
                                                  read_batch_size=128),
            writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=128, max_flying_item=300000),
            merger_read_ahead_size=1000000,
            merger_read_batch_size=128)

        self._portal_worker = DataPortalWorker(portal_worker_options,
                                               "localhost:5005", 0,
                                               "test_portal_worker_0",
                                               "portal_worker_0",
                                               "localhost:2379", "test_user",
                                               "test_password", True)
Example #18
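 # Parameterized fixture: data-source name, joiner name and raw-data cache type
 # are arguments; ExampleJoinerOptions add join_expr, join_key_mapper and
 # negative sampling on top of the usual dump and matching-window settings.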
 def init(self,
          dsname,
          joiner_name,
          version=Version.V1,
          cache_type="memory"):
     data_source = common_pb.DataSource()
     data_source.data_source_meta.name = dsname
     data_source.data_source_meta.partition_num = 1
     data_source.output_base_dir = "%s_ds_output" % dsname
     self.raw_data_dir = "%s_raw_data" % dsname
     self.data_source = data_source
     self.raw_data_options = dj_pb.RawDataOptions(
         raw_data_iter='TF_RECORD',
         compressed_type='',
         raw_data_cache_type=cache_type,
     )
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1, example_id_dump_threshold=1024)
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         example_joiner=joiner_name,
         min_matching_window=32,
         max_matching_window=51200,
         max_conversion_delay=interval_to_timestamp("124"),
         enable_negative_example_generator=True,
         data_block_dump_interval=32,
         data_block_dump_threshold=128,
         negative_sampling_rate=0.8,
         join_expr="example_id",
         join_key_mapper="DEFAULT",
         negative_sampling_filter_expr='',
     )
     if gfile.Exists(self.data_source.output_base_dir):
         gfile.DeleteRecursively(self.data_source.output_base_dir)
     if gfile.Exists(self.raw_data_dir):
         gfile.DeleteRecursively(self.raw_data_dir)
     self.kvstore = db_client.DBClient('etcd', True)
     self.kvstore.delete_prefix(
         common.data_source_kvstore_base_dir(
             self.data_source.data_source_meta.name))
     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
     self.g_data_block_index = 0
     self.version = version
Example #19
    def _launch_portals(self):
        portal_options = dj_pb.DataJoinPotralOptions(
            example_validator=dj_pb.ExampleValidatorOptions(
                example_validator='EXAMPLE_VALIDATOR',
                validate_event_time=True),
            reducer_buffer_size=1024,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type=''),
            use_mock_etcd=True)
        self._portal_l = data_join_portal.DataJoinPortal(
            self._portal_name, self._etcd_name, self._etcd_addrs,
            self._etcd_base_dir_l, portal_options)

        self._portal_f = data_join_portal.DataJoinPortal(
            self._portal_name, self._etcd_name, self._etcd_addrs,
            self._etcd_base_dir_f, portal_options)
        self._portal_l.start()
        self._portal_f.start()
Example #20
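 # One step of a sort-run merge: lazily opens a CsvDictIter on first call, then
 # yields successive items wrapped as SortRunReader.MergeItem until the file is
 # exhausted and StopIteration is raised.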
 def _next_internal(self):
     if not self._finished:
         try:
             item = None
             if self._fiter is None:
                 raw_data_options = \
                     dj_pb.RawDataOptions(raw_data_iter='CSV_DICT')
                 self._fiter = CsvDictIter(raw_data_options)
                 meta = visitor.IndexMeta(0, 0, self._fpath)
                 self._fiter.reset_iter(meta, True)
                 item = self._fiter.get_item()
             else:
                 _, item = next(self._fiter)
             assert item is not None
             return SortRunReader.MergeItem(item, self._reader_index)
         except StopIteration:
             self._finished = True
     raise StopIteration("%s has been iter finished" % self._fpath)
Example #21
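 # Launches four leader and four follower DataJoinWorkerService instances that
 # join CSV_DICT raw data with a SORT_RUN_JOINER and emit CSV_DICT data blocks.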
 def _launch_workers(self):
     worker_options = dj_pb.DataJoinWorkerOptions(
         use_mock_etcd=True,
         raw_data_options=dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                               compressed_type=''),
         example_id_dump_options=dj_pb.ExampleIdDumpOptions(
             example_id_dump_interval=1, example_id_dump_threshold=1024),
         example_joiner_options=dj_pb.ExampleJoinerOptions(
             example_joiner='SORT_RUN_JOINER',
             min_matching_window=64,
             max_matching_window=256,
             data_block_dump_interval=30,
             data_block_dump_threshold=1000),
         batch_processor_options=dj_pb.BatchProcessorOptions(
             batch_size=1024, max_flying_item=4096),
         data_block_builder_options=dj_pb.DataBlockBuilderOptions(
             data_block_builder='CSV_DICT_DATABLOCK_BUILDER'))
     self._worker_addrs_l = [
         'localhost:4161', 'localhost:4162', 'localhost:4163',
         'localhost:4164'
     ]
     self._worker_addrs_f = [
         'localhost:5161', 'localhost:5162', 'localhost:5163',
         'localhost:5164'
     ]
     self._workers_l = []
     self._workers_f = []
     for rank_id in range(4):
         worker_addr_l = self._worker_addrs_l[rank_id]
         worker_addr_f = self._worker_addrs_f[rank_id]
         self._workers_l.append(
             data_join_worker.DataJoinWorkerService(
                 int(worker_addr_l.split(':')[1]), worker_addr_f,
                 self._master_addr_l, rank_id, self._etcd_name,
                 self._etcd_base_dir_l, self._etcd_addrs, worker_options))
         self._workers_f.append(
             data_join_worker.DataJoinWorkerService(
                 int(worker_addr_f.split(':')[1]), worker_addr_l,
                 self._master_addr_f, rank_id, self._etcd_name,
                 self._etcd_base_dir_f, self._etcd_addrs, worker_options))
     for w in self._workers_l:
         w.start()
     for w in self._workers_f:
         w.start()
Example #22
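 # Portal test setup: a DataJoinPortalManifest with 4 input and 8 output
 # partitions, portal options using TF_RECORD RawDataOptions, an hour-aligned
 # timestamp, and generated input data.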
 def _prepare_test(self):
     self._portal_manifest = common_pb.DataJoinPortalManifest(
         name='test_portal',
         input_partition_num=4,
         output_partition_num=8,
         input_data_base_dir='./portal_input',
         output_data_base_dir='./portal_output')
     self._portal_options = dj_pb.DataJoinPotralOptions(
         example_validator=dj_pb.ExampleValidatorOptions(
             example_validator='EXAMPLE_VALIDATOR',
             validate_event_time=True,
         ),
         reducer_buffer_size=128,
         raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD'),
         use_mock_etcd=True)
     self._date_time = common.convert_timestamp_to_datetime(
         common.trim_timestamp_by_hourly(
             common.convert_datetime_to_timestamp(datetime.now())))
     self._generate_input_data()
Example #23
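    # Data portal worker options exercising optional_fields and validation_ratio
    # on RawDataOptions; ETCD_BASE_DIR is set in the environment before the
    # DataPortalWorker is created.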
    def _make_portal_worker(self, raw_data_iter, validation_ratio):
        portal_worker_options = dp_pb.DataPortalWorkerOptions(
            raw_data_options=dj_pb.RawDataOptions(
                raw_data_iter=raw_data_iter,
                read_ahead_size=1 << 20,
                read_batch_size=128,
                optional_fields=['label'],
                validation_ratio=validation_ratio,
            ),
            writer_options=dj_pb.WriterOptions(output_writer="TF_RECORD"),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=128, max_flying_item=300000),
            merger_read_ahead_size=1000000,
            merger_read_batch_size=128)

        os.environ['ETCD_BASE_DIR'] = "portal_worker_0"
        self._portal_worker = DataPortalWorker(portal_worker_options,
                                               "localhost:5005", 0, "etcd",
                                               True)
Example #24
 def _launch_workers(self):
     worker_options = dj_pb.DataJoinWorkerOptions(
         use_mock_etcd=True,
         raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                               compressed_type=''),
         example_id_dump_options=dj_pb.ExampleIdDumpOptions(
             example_id_dump_interval=1, example_id_dump_threshold=1024),
         example_joiner_options=dj_pb.ExampleJoinerOptions(
             example_joiner='STREAM_JOINER',
             min_matching_window=64,
             max_matching_window=256,
             data_block_dump_interval=30,
             data_block_dump_threshold=1000),
         example_id_batch_options=dj_pb.ExampleIdBatchOptions(
             example_id_batch_size=1024, max_flying_example_id=4096))
     self._worker_addrs_l = ['localhost:4161', 'localhost:4162']
     self._worker_addrs_f = ['localhost:5161', 'localhost:5162']
     self._workers_l = []
     self._workers_f = []
     for rank_id in range(2):
         worker_addr_l = self._worker_addrs_l[rank_id]
         worker_addr_f = self._worker_addrs_f[rank_id]
         self._workers_l.append(
             data_join_worker.DataJoinWorkerService(
                 int(worker_addr_l.split(':')[1]), worker_addr_f,
                 self._master_addr_l, rank_id, self._etcd_name,
                 self._etcd_base_dir_l, self._etcd_addrs, worker_options))
         self._workers_f.append(
             data_join_worker.DataJoinWorkerService(
                 int(worker_addr_f.split(':')[1]), worker_addr_l,
                 self._master_addr_f, rank_id, self._etcd_name,
                 self._etcd_base_dir_f, self._etcd_addrs, worker_options))
     for w in self._workers_l:
         w.start()
     for w in self._workers_f:
         w.start()
Example #25
    args = parser.parse_args()
    set_logger()
    if args.input_data_file_iter == 'TF_RECORD' or \
            args.output_builder == 'TF_RECORD':
        import tensorflow
        tensorflow.compat.v1.enable_eager_execution()

    optional_fields = list(
        field for field in map(str.strip, args.optional_fields.split(','))
        if field != '')

    portal_worker_options = dp_pb.DataPortalWorkerOptions(
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter=args.input_data_file_iter,
            compressed_type=args.compressed_type,
            read_ahead_size=args.read_ahead_size,
            read_batch_size=args.read_batch_size,
            optional_fields=optional_fields),
        writer_options=dj_pb.WriterOptions(
            output_writer=args.output_builder,
            compressed_type=args.builder_compressed_type),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=args.batch_size, max_flying_item=-1),
        merger_read_ahead_size=args.merger_read_ahead_size,
        merger_read_batch_size=args.merger_read_batch_size,
        memory_limit_ratio=args.memory_limit_ratio / 100)
    data_portal_worker = DataPortalWorker(portal_worker_options,
                                          args.master_addr, args.rank_id,
                                          args.kvstore_type,
                                          (args.kvstore_type == 'mock'))
    data_portal_worker.start()
Example #26
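 # End-to-end check of DataBlockDumperManager: syncs the follower's data block
 # metas, dumps the leader's data blocks, then verifies metas and TF_RECORD
 # contents match on both sides.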
 def test_data_block_dumper(self):
     self.generate_follower_data_block()
     self.generate_leader_raw_data()
     dbd = data_block_dumper.DataBlockDumperManager(
         self.etcd,
         self.data_source_l,
         0,
         dj_pb.RawDataOptions(raw_data_iter='TF_RECORD'),
         dj_pb.DataBlockBuilderOptions(
             data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
     )
     self.assertEqual(dbd.get_next_data_block_index(), 0)
     for (idx, meta) in enumerate(self.dumped_metas):
         success, next_index = dbd.add_synced_data_block_meta(meta)
         self.assertTrue(success)
         self.assertEqual(next_index, idx + 1)
     self.assertTrue(dbd.need_dump())
     self.assertEqual(dbd.get_next_data_block_index(),
                      len(self.dumped_metas))
     with dbd.make_data_block_dumper() as dumper:
         dumper()
     dbm_f = data_block_manager.DataBlockManager(self.data_source_f, 0)
     dbm_l = data_block_manager.DataBlockManager(self.data_source_l, 0)
     self.assertEqual(dbm_f.get_dumped_data_block_count(),
                      len(self.dumped_metas))
     self.assertEqual(dbm_f.get_dumped_data_block_count(),
                      dbm_l.get_dumped_data_block_count())
     for (idx, meta) in enumerate(self.dumped_metas):
         self.assertEqual(meta.data_block_index, idx)
         self.assertEqual(dbm_l.get_data_block_meta_by_index(idx), meta)
         self.assertEqual(dbm_f.get_data_block_meta_by_index(idx), meta)
         meta_fpth_l = os.path.join(
             self.data_source_l.data_block_dir, common.partition_repr(0),
             common.encode_data_block_meta_fname(
                 self.data_source_l.data_source_meta.name, 0,
                 meta.data_block_index))
         mitr = tf.io.tf_record_iterator(meta_fpth_l)
         meta_l = text_format.Parse(next(mitr), dj_pb.DataBlockMeta())
         self.assertEqual(meta_l, meta)
         meta_fpth_f = os.path.join(
             self.data_source_f.data_block_dir, common.partition_repr(0),
             common.encode_data_block_meta_fname(
                 self.data_source_f.data_source_meta.name, 0,
                 meta.data_block_index))
         mitr = tf.io.tf_record_iterator(meta_fpth_f)
         meta_f = text_format.Parse(next(mitr), dj_pb.DataBlockMeta())
         self.assertEqual(meta_f, meta)
         data_fpth_l = os.path.join(
             self.data_source_l.data_block_dir, common.partition_repr(0),
             common.encode_data_block_fname(
                 self.data_source_l.data_source_meta.name, meta_l))
         for (iidx,
              record) in enumerate(tf.io.tf_record_iterator(data_fpth_l)):
             example = tf.train.Example()
             example.ParseFromString(record)
             feat = example.features.feature
             self.assertEqual(feat['example_id'].bytes_list.value[0],
                              meta.example_ids[iidx])
         self.assertEqual(len(meta.example_ids), iidx + 1)
         data_fpth_f = os.path.join(
             self.data_source_f.data_block_dir, common.partition_repr(0),
             common.encode_data_block_fname(
                 self.data_source_l.data_source_meta.name, meta_f))
         for (iidx,
              record) in enumerate(tf.io.tf_record_iterator(data_fpth_f)):
             example = tf.train.Example()
             example.ParseFromString(record)
             feat = example.features.feature
             self.assertEqual(feat['example_id'].bytes_list.value[0],
                              meta.example_ids[iidx])
         self.assertEqual(len(meta.example_ids), iidx + 1)
Example #27
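 # Fragment: trailing keyword arguments of what appears to be a
 # dj_pb.RsaPsiPreProcessorOptions(...) built from command-line args, followed
 # by role selection and launching the RsaPsiPreProcessor.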
         max_flying_sign_batch=args.max_flying_sign_batch,
         max_flying_sign_rpc=args.max_flying_sign_rpc,
         sign_rpc_timeout_ms=args.sign_rpc_timeout_ms,
         stub_fanout=args.stub_fanout,
         slow_sign_threshold=args.slow_sign_threshold,
         sort_run_merger_read_ahead_buffer=\
             args.sort_run_merger_read_ahead_buffer,
         sort_run_merger_read_batch_size=\
             args.sort_run_merger_read_batch_size,
         batch_processor_options=dj_pb.BatchProcessorOptions(
             batch_size=args.process_batch_size,
             max_flying_item=-1
         ),
         input_raw_data=dj_pb.RawDataOptions(
             raw_data_iter=args.raw_data_iter,
             compressed_type=args.compressed_type,
             read_ahead_size=args.read_ahead_size,
             read_batch_size=args.read_batch_size
         ),
         writer_options=dj_pb.WriterOptions(
             output_writer=args.output_builder,
             compressed_type=args.builder_compressed_type,
         )
     )
 if args.psi_role.upper() == 'LEADER':
     preprocessor_options.role = common_pb.FLRole.Leader
 else:
     assert args.psi_role.upper() == 'FOLLOWER'
     preprocessor_options.role = common_pb.FLRole.Follower
 preprocessor = RsaPsiPreProcessor(preprocessor_options, args.etcd_name,
                                   args.etcd_addrs, args.etcd_base_dir)
 preprocessor.start_process()
Example #28
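    # Shared leader/follower fixture: separate etcd namespaces, committed
    # DataSource protos, RawDataPublishers, and DataJoinWorkerOptions whose
    # RawDataOptions use TF_RECORD with a 1 MiB read-ahead.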
    def setUp(self):
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        etcd_l.delete_prefix(common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.output_base_dir = "./ds_output_l"
        self.raw_data_dir_l = "./raw_data_l"
        data_source_f = common_pb.DataSource()
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.output_base_dir = "./ds_output_f"
        self.raw_data_dir_f = "./raw_data_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)

        self.etcd_l = etcd_l
        self.etcd_f = etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
                self.etcd_l, self.raw_data_pub_dir_l
            )
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
                self.etcd_f, self.raw_data_pub_dir_f
            )
        if gfile.Exists(data_source_l.output_base_dir):
            gfile.DeleteRecursively(data_source_l.output_base_dir)
        if gfile.Exists(self.raw_data_dir_l):
            gfile.DeleteRecursively(self.raw_data_dir_l)
        if gfile.Exists(data_source_f.output_base_dir):
            gfile.DeleteRecursively(data_source_f.output_base_dir)
        if gfile.Exists(self.raw_data_dir_f):
            gfile.DeleteRecursively(self.raw_data_dir_f)

        self.worker_options = dj_pb.DataJoinWorkerOptions(
                use_mock_etcd=True,
                raw_data_options=dj_pb.RawDataOptions(
                    raw_data_iter='TF_RECORD',
                    read_ahead_size=1<<20,
                    read_batch_size=128
                ),
                example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                    example_id_dump_interval=1,
                    example_id_dump_threshold=1024
                ),
                example_joiner_options=dj_pb.ExampleJoinerOptions(
                    example_joiner='STREAM_JOINER',
                    min_matching_window=64,
                    max_matching_window=256,
                    data_block_dump_interval=30,
                    data_block_dump_threshold=1000
                ),
                batch_processor_options=dj_pb.BatchProcessorOptions(
                    batch_size=512,
                    max_flying_item=2048
                ),
                data_block_builder_options=dj_pb.WriterOptions(
                    output_writer='TF_RECORD'
                )
            )

        self.total_index = 1 << 12
Example #29
    parser.add_argument('--portal_reducer_buffer_size',
                        type=int,
                        default=4096,
                        help='the buffer size of portal reducer')
    parser.add_argument('--example_validator',
                        type=str,
                        default='EXAMPLE_VALIDATOR',
                        help='the name of example validator')
    parser.add_argument('--validate_event_time',
                        action='store_true',
                        help='validate the example has event time')
    args = parser.parse_args()
    options = dj_pb.DataJoinPotralOptions(
        example_validator=dj_pb.ExampleValidatorOptions(
            example_validator=args.example_validator,
            validate_event_time=args.validate_event_time),
        reducer_buffer_size=args.portal_reducer_buffer_size,
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter=args.input_data_file_iter,
            compressed_type=args.compressed_type),
        raw_data_publish_dir=args.raw_data_publish_dir,
        use_mock_etcd=args.use_mock_etcd)
    portal_srv = DataJoinPortalService(
        args.listen_port,
        args.data_join_portal_name,
        args.etcd_name,
        args.etcd_addrs,
        args.etcd_base_dir,
        options,
    )
    portal_srv.run()
Example #30
    def setUp(self):
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.data_block_dir = "./data_block_l"
        data_source_l.raw_data_dir = "./raw_data_l"
        data_source_l.example_dumped_dir = "./example_dumped_l"
        data_source_f = common_pb.DataSource()
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.data_block_dir = "./data_block_f"
        data_source_f.raw_data_dir = "./raw_data_f"
        data_source_f.example_dumped_dir = "./example_dumped_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]),
            master_addr_f,
            data_source_name,
            etcd_name,
            etcd_base_dir_l,
            etcd_addrs,
            master_options,
        )
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs, master_options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = master_client_l.GetDataSourceStatus(req_l)
            dss_f = master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
            else:
                time.sleep(2)

        self.master_client_l = master_client_l
        self.master_client_f = master_client_f
        self.master_addr_l = master_addr_l
        self.master_addr_f = master_addr_f
        self.etcd_l = etcd_l
        self.etcd_f = etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.master_l = master_l
        self.master_f = master_f
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
            self.etcd_l, self.raw_data_pub_dir_l)
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
            self.etcd_f, self.raw_data_pub_dir_f)
        if gfile.Exists(data_source_l.data_block_dir):
            gfile.DeleteRecursively(data_source_l.data_block_dir)
        if gfile.Exists(data_source_l.example_dumped_dir):
            gfile.DeleteRecursively(data_source_l.example_dumped_dir)
        if gfile.Exists(data_source_l.raw_data_dir):
            gfile.DeleteRecursively(data_source_l.raw_data_dir)
        if gfile.Exists(data_source_f.data_block_dir):
            gfile.DeleteRecursively(data_source_f.data_block_dir)
        if gfile.Exists(data_source_f.example_dumped_dir):
            gfile.DeleteRecursively(data_source_f.example_dumped_dir)
        if gfile.Exists(data_source_f.raw_data_dir):
            gfile.DeleteRecursively(data_source_f.raw_data_dir)

        self.worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type=''),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='STREAM_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=512, max_flying_item=2048),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='TF_RECORD'))

        self.total_index = 1 << 13