def _setUpDataSource(self):
    """Create and commit the leader and follower test data sources.

    Clears any kvstore entries left under the test data source name in
    both stores, builds one ``DataSource`` per role sharing the same
    meta (4 partitions, time range 0..100000000), records the raw data
    directories on ``self`` and commits each data source to its kvstore.
    """
    source_name = 'test_data_source'
    self._data_source_name = source_name
    for store in (self._kvstore_l, self._kvstore_f):
        store.delete_prefix(common.data_source_kvstore_base_dir(source_name))

    # Meta shared by both roles; it is merged (copied) into each source.
    shared_meta = common_pb.DataSourceMeta(
        name=source_name,
        partition_num=4,
        start_time=0,
        end_time=100000000,
    )

    leader = common_pb.DataSource(
        role=common_pb.FLRole.Leader,
        state=common_pb.DataSourceState.Init,
        output_base_dir="./ds_output_l",
        raw_data_sub_dir="./raw_data_sub_dir_l",
    )
    follower = common_pb.DataSource(
        role=common_pb.FLRole.Follower,
        state=common_pb.DataSourceState.Init,
        output_base_dir="./ds_output_f",
        raw_data_sub_dir="./raw_data_sub_dir_f",
    )
    self._data_source_l = leader
    self._raw_data_dir_l = "./raw_data_l"
    self._data_source_f = follower
    self._raw_data_dir_f = "./raw_data_f"

    leader.data_source_meta.MergeFrom(shared_meta)
    follower.data_source_meta.MergeFrom(shared_meta)
    common.commit_data_source(self._kvstore_l, leader)
    common.commit_data_source(self._kvstore_f, follower)
def tearDown(self):
    """Remove the files and kvstore entries created by the fixture.

    Deletes the leader/follower output and raw data directories when
    they exist, then clears each kvstore under the data source's own
    name so the committed data source does not leak into later tests.
    """
    stale_dirs = (
        self.data_source_l.output_base_dir,
        self.raw_data_dir_l,
        self.data_source_f.output_base_dir,
        self.raw_data_dir_f,
    )
    for stale_dir in stale_dirs:
        if gfile.Exists(stale_dir):
            gfile.DeleteRecursively(stale_dir)
    # The kvstore prefix is derived from the data source name (as every
    # other fixture here does), not from the DB client's base dir; using
    # the base dir left the committed data source entries behind.
    self.kvstore_f.delete_prefix(
        common.data_source_kvstore_base_dir(
            self.data_source_f.data_source_meta.name))
    self.kvstore_l.delete_prefix(
        common.data_source_kvstore_base_dir(
            self.data_source_l.data_source_meta.name))
def setUp(self):
    """Prepare a clean single-partition "milestone-f" stream-joiner fixture.

    Builds the data source and the option messages used by the tests,
    removes leftover output/raw data directories, connects the kvstore
    client, clears the data source's kvstore prefix and resets the
    counters the tests accumulate into.
    """
    source = common_pb.DataSource()
    source.data_source_meta.name = "milestone-f"
    source.data_source_meta.partition_num = 1
    source.output_base_dir = "./ds_output"
    self.data_source = source
    self.raw_data_dir = "./raw_data"

    self.raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_RECORD',
        compressed_type='',
        optional_fields=['label'],
    )
    self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=1,
        example_id_dump_threshold=1024,
    )
    self.example_joiner_options = dj_pb.ExampleJoinerOptions(
        example_joiner='STREAM_JOINER',
        min_matching_window=32,
        max_matching_window=128,
        data_block_dump_interval=30,
        data_block_dump_threshold=128,
    )

    # Start from an empty filesystem state.
    for leftover in (source.output_base_dir, self.raw_data_dir):
        if gfile.Exists(leftover):
            gfile.DeleteRecursively(leftover)

    self.kvstore = mysql_client.DBClient(
        'test_cluster', 'localhost:2379', 'test_user',
        'test_password', 'fedlearner', True)
    kv_prefix = common.data_source_kvstore_base_dir(
        source.data_source_meta.name)
    self.kvstore.delete_prefix(kv_prefix)

    self.total_raw_data_count = 0
    self.total_example_id_count = 0
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
    self.g_data_block_index = 0
def setUp(self):
    """Prepare a clean "milestone-f" fixture for the attribution joiner.

    Creates a single-partition data source, the raw data / example id
    dump / example joiner option messages (negative example generation
    enabled), wipes leftover directories and kvstore entries, and
    resets the counters used by the tests.
    """
    self.raw_data_dir = "./raw_data"
    self.data_source = common_pb.DataSource(
        output_base_dir="./ds_output",
        data_source_meta=common_pb.DataSourceMeta(
            name="milestone-f", partition_num=1),
    )

    self.raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_RECORD', compressed_type='')
    self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=1, example_id_dump_threshold=1024)
    joiner_settings = dict(
        example_joiner='ATTRIBUTION_JOINER',
        min_matching_window=32,
        max_matching_window=51200,
        max_conversion_delay=interval_to_timestamp("124"),
        enable_negative_example_generator=True,
        data_block_dump_interval=32,
        data_block_dump_threshold=128,
        negative_sampling_rate=0.8,
    )
    self.example_joiner_options = dj_pb.ExampleJoinerOptions(
        **joiner_settings)

    # Drop anything a previous run left on disk.
    output_dir = self.data_source.output_base_dir
    if gfile.Exists(output_dir):
        gfile.DeleteRecursively(output_dir)
    if gfile.Exists(self.raw_data_dir):
        gfile.DeleteRecursively(self.raw_data_dir)

    self.kvstore = db_client.DBClient('etcd', True)
    source_name = self.data_source.data_source_meta.name
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(source_name))

    self.total_raw_data_count = 0
    self.total_example_id_count = 0
    self.g_data_block_index = 0
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
def test_csv_raw_data_visitor(self):
    # Registers the checked-in CSV file as raw data for partition 0 and
    # walks it with a CSV_DICT RawDataVisitor, checking that indices are
    # consecutive from 0 and that exactly 4999 items are visited.
    source = common_pb.DataSource()
    source.data_source_meta.name = 'fclh_test'
    source.data_source_meta.partition_num = 1
    self.data_source = source

    test_dir = path.dirname(path.abspath(__file__))
    self.raw_data_dir = path.join(test_dir, "../csv_raw_data")

    self.kvstore = DBClient("etcd", True)
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(source.data_source_meta.name))
    self.assertEqual(source.data_source_meta.partition_num, 1)

    # The CSV fixture directory must already exist next to the tests.
    partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
    self.assertTrue(gfile.Exists(partition_dir))

    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
    csv_meta = dj_pb.RawDataMeta(
        file_path=path.join(partition_dir, "test_raw_data.csv"),
        timestamp=timestamp_pb2.Timestamp(seconds=3),
    )
    manifest_manager.add_raw_data(0, [csv_meta], True)

    raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='CSV_DICT',
        read_ahead_size=1 << 20,
    )
    rdm = raw_data_visitor.RawDataManager(
        self.kvstore, self.data_source, 0)
    self.assertTrue(rdm.check_index_meta_by_process_index(0))

    rdv = raw_data_visitor.RawDataVisitor(
        self.kvstore, self.data_source, 0, raw_data_options)
    visited = 0
    for index, item in rdv:
        # Progress output every 1024 items.
        if index > 0 and index % 1024 == 0:
            print("{} {}".format(index, item.raw_id))
        self.assertEqual(index, visited)
        visited += 1
    self.assertEqual(visited, 4999)
def tearDown(self):
    """Delete the fixture's directories and its kvstore entries.

    Removes the data source output directory and the raw data directory
    when present, then clears the kvstore prefix derived from the data
    source name.
    """
    for used_dir in (self.data_source.output_base_dir, self.raw_data_dir):
        if gfile.Exists(used_dir):
            gfile.DeleteRecursively(used_dir)
    source_name = self.data_source.data_source_meta.name
    kv_prefix = common.data_source_kvstore_base_dir(source_name)
    self.kvstore.delete_prefix(kv_prefix)
def setUp(self):
    """Prepare the 'fclh_test' data source with an empty partition dir.

    Connects the kvstore client, clears the data source's kvstore
    prefix, recreates the partition-0 directory under ``./raw_data``
    and builds the manifest manager used by the tests.
    """
    source = common_pb.DataSource()
    source.data_source_meta.name = 'fclh_test'
    source.data_source_meta.partition_num = 1
    self.data_source = source
    self.raw_data_dir = "./raw_data"

    self.kvstore = db_client.DBClient('etcd', True)
    kv_prefix = common.data_source_kvstore_base_dir(
        source.data_source_meta.name)
    self.kvstore.delete_prefix(kv_prefix)
    self.assertEqual(source.data_source_meta.partition_num, 1)

    # Recreate the partition directory so each test starts empty.
    partition_dir = os.path.join(
        self.raw_data_dir, common.partition_repr(0))
    if gfile.Exists(partition_dir):
        gfile.DeleteRecursively(partition_dir)
    gfile.MakeDirs(partition_dir)

    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
def init(self, dsname, joiner_name, version=Version.V1, cache_type="memory"):
    """Initialise the fixture for one data source / joiner combination.

    Args:
        dsname: data source name; also used as the prefix of the output
            and raw data directory names.
        joiner_name: value passed as ``example_joiner`` in the joiner
            options.
        version: stored on ``self.version`` for the tests to read.
        cache_type: value passed as ``raw_data_cache_type`` in the raw
            data options.
    """
    self.version = version
    self.raw_data_dir = "%s_raw_data" % dsname

    source = common_pb.DataSource()
    source.data_source_meta.name = dsname
    source.data_source_meta.partition_num = 1
    source.output_base_dir = "%s_ds_output" % dsname
    self.data_source = source

    self.raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_RECORD',
        compressed_type='',
        raw_data_cache_type=cache_type,
    )
    self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=1,
        example_id_dump_threshold=1024,
    )
    self.example_joiner_options = dj_pb.ExampleJoinerOptions(
        example_joiner=joiner_name,
        min_matching_window=32,
        max_matching_window=51200,
        max_conversion_delay=interval_to_timestamp("124"),
        enable_negative_example_generator=True,
        data_block_dump_interval=32,
        data_block_dump_threshold=128,
        negative_sampling_rate=0.8,
        join_expr="example_id",
        join_key_mapper="DEFAULT",
        negative_sampling_filter_expr='',
    )

    # Remove directories left behind by an earlier run.
    for leftover in (source.output_base_dir, self.raw_data_dir):
        if gfile.Exists(leftover):
            gfile.DeleteRecursively(leftover)

    self.kvstore = db_client.DBClient('etcd', True)
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(source.data_source_meta.name))

    self.total_raw_data_count = 0
    self.total_example_id_count = 0
    self.g_data_block_index = 0
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
def setUp(self):
    """Prepare the "milestone-x" data source and its example dump dir.

    Connects the kvstore client, clears the data source's kvstore
    prefix, removes any previous output directory and creates the
    partition-0 directory under the example dumped dir.
    """
    self.kvstore = db_client.DBClient('etcd', True)

    source = common_pb.DataSource(
        output_base_dir="./ds_output",
        data_source_meta=common_pb.DataSourceMeta(
            name="milestone-x", partition_num=1),
    )
    kv_prefix = common.data_source_kvstore_base_dir(
        source.data_source_meta.name)
    self.kvstore.delete_prefix(kv_prefix)
    self.data_source = source

    self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=-1,
        example_id_dump_threshold=1024,
    )

    output_dir = self.data_source.output_base_dir
    if gfile.Exists(output_dir):
        gfile.DeleteRecursively(output_dir)

    dumped_root = common.data_source_example_dumped_dir(self.data_source)
    self.partition_dir = os.path.join(dumped_root, common.partition_repr(0))
    gfile.MakeDirs(self.partition_dir)
def test_compressed_raw_data_visitor(self):
    # Registers the checked-in GZIP TF_RECORD file as raw data for
    # partition 0 and walks it with a RawDataVisitor, checking that
    # indices are consecutive from 0 and at least one item is visited.
    source = common_pb.DataSource()
    source.data_source_meta.name = 'fclh_test'
    source.data_source_meta.partition_num = 1
    self.data_source = source

    test_dir = path.dirname(path.abspath(__file__))
    self.raw_data_dir = path.join(test_dir, "../compressed_raw_data")

    self.kvstore = mysql_client.DBClient(
        'test_cluster', 'localhost:2379', 'test_user',
        'test_password', 'fedlearner', True)
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(source.data_source_meta.name))
    self.assertEqual(source.data_source_meta.partition_num, 1)

    # The compressed fixture directory must already exist.
    partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
    self.assertTrue(gfile.Exists(partition_dir))

    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
    compressed_meta = dj_pb.RawDataMeta(
        file_path=path.join(partition_dir, "0-0.idx"),
        timestamp=timestamp_pb2.Timestamp(seconds=3),
    )
    manifest_manager.add_raw_data(0, [compressed_meta], True)

    raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_RECORD',
        compressed_type='GZIP',
        read_ahead_size=1 << 20,
        read_batch_size=128,
    )
    rdm = raw_data_visitor.RawDataManager(
        self.kvstore, self.data_source, 0)
    self.assertTrue(rdm.check_index_meta_by_process_index(0))

    rdv = raw_data_visitor.RawDataVisitor(
        self.kvstore, self.data_source, 0, raw_data_options)
    visited = 0
    for index, item in rdv:
        # Progress output every 32 items.
        if index > 0 and index % 32 == 0:
            print("{} {}".format(index, item.example_id))
        self.assertEqual(index, visited)
        visited += 1
    self.assertGreater(visited, 0)
def setUp(self):
    """Prepare the follower and leader "milestone" data sources.

    Both data sources have one partition and separate output
    directories, which are removed if they exist. The leader also gets
    a raw data directory, and the kvstore and manifest manager are set
    up for the leader data source.
    """
    follower = common_pb.DataSource(
        output_base_dir="./output-f",
        data_source_meta=common_pb.DataSourceMeta(
            name="milestone", partition_num=1),
    )
    self.data_source_f = follower
    if gfile.Exists(follower.output_base_dir):
        gfile.DeleteRecursively(follower.output_base_dir)

    leader = common_pb.DataSource(
        output_base_dir="./output-l",
        data_source_meta=common_pb.DataSourceMeta(
            name="milestone", partition_num=1),
    )
    self.data_source_l = leader
    self.raw_data_dir_l = "./raw_data-l"
    for leftover in (leader.output_base_dir, self.raw_data_dir_l):
        if gfile.Exists(leftover):
            gfile.DeleteRecursively(leftover)

    self.kvstore = db_client.DBClient('etcd', True)
    kv_prefix = common.data_source_kvstore_base_dir(
        leader.data_source_meta.name)
    self.kvstore.delete_prefix(kv_prefix)
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source_l)
# Final command-line option, then parse everything registered on `parser`.
parser.add_argument('--raw_data_sub_dir', type=str, required=True,
                    help='the mysql base dir to subscribe new raw data')
args = parser.parse_args()

# Build the DataSource message from the parsed arguments.
data_source = common_pb.DataSource()
data_source.data_source_meta.name = args.data_source_name
data_source.data_source_meta.partition_num = args.partition_num
data_source.data_source_meta.start_time = args.start_time
data_source.data_source_meta.end_time = args.end_time
data_source.data_source_meta.negative_sampling_rate = \
        args.negative_sampling_rate
# The role is matched case-insensitively; anything other than
# LEADER/FOLLOWER trips the assert below.
if args.role.upper() == 'LEADER':
    data_source.role = common_pb.FLRole.Leader
else:
    assert args.role.upper() == 'FOLLOWER'
    data_source.role = common_pb.FLRole.Follower
data_source.output_base_dir = args.output_base_dir
data_source.raw_data_sub_dir = args.raw_data_sub_dir
data_source.state = common_pb.DataSourceState.Init

# Commit the data source only when nothing is stored under its kvstore
# key yet; an existing entry is left untouched.
kvstore = DBClient(args.kvstore_type)
master_kvstore_key = common.data_source_kvstore_base_dir(
    data_source.data_source_meta.name)
raw_data = kvstore.get_data(master_kvstore_key)
if raw_data is None:
    logging.info("data source %s is not existed", args.data_source_name)
    common.commit_data_source(kvstore, data_source)
    logging.info("apply new data source %s", args.data_source_name)
else:
    logging.info("data source %s has been existed", args.data_source_name)
def setUp(self):
    """Build the leader/follower fixture for the data join worker tests.

    Creates one kvstore client per role, clears and re-commits the
    'test_data_source' data source (2 partitions) for each role,
    creates the raw data publishers, removes leftover directories and
    prepares the worker options shared by the tests.
    """
    # Connection settings are kept on self for the tests to reuse.
    self.db_database = 'test_mysql'
    self.db_addr = 'localhost:2379'
    self.db_username_l = 'test_user_l'
    self.db_username_f = 'test_user_f'
    self.db_password_l = 'test_password_l'
    self.db_password_f = 'test_password_f'
    self.db_base_dir_l = 'byefl_l'
    self.db_base_dir_f = 'byefl_f'
    self.data_source_name = 'test_data_source'

    self.kvstore_l = DBClient(self.db_database, self.db_addr,
                              self.db_username_l, self.db_password_l,
                              self.db_base_dir_l, True)
    self.kvstore_f = DBClient(self.db_database, self.db_addr,
                              self.db_username_f, self.db_password_f,
                              self.db_base_dir_f, True)
    for store in (self.kvstore_l, self.kvstore_f):
        store.delete_prefix(
            common.data_source_kvstore_base_dir(self.data_source_name))

    self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
    self.raw_data_dir_l = "./raw_data_l"
    leader = common_pb.DataSource(
        raw_data_sub_dir=self.raw_data_pub_dir_l,
        role=common_pb.FLRole.Leader,
        state=common_pb.DataSourceState.Init,
        output_base_dir="./ds_output_l",
    )

    self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
    self.raw_data_dir_f = "./raw_data_f"
    follower = common_pb.DataSource(
        role=common_pb.FLRole.Follower,
        raw_data_sub_dir=self.raw_data_pub_dir_f,
        state=common_pb.DataSourceState.Init,
        output_base_dir="./ds_output_f",
    )

    # Both roles share the same meta; it is merged into each source
    # right before that source is committed.
    shared_meta = common_pb.DataSourceMeta(
        name=self.data_source_name,
        partition_num=2,
        start_time=0,
        end_time=100000000,
    )
    leader.data_source_meta.MergeFrom(shared_meta)
    common.commit_data_source(self.kvstore_l, leader)
    follower.data_source_meta.MergeFrom(shared_meta)
    common.commit_data_source(self.kvstore_f, follower)
    self.data_source_l = leader
    self.data_source_f = follower

    self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
        self.kvstore_l, self.raw_data_pub_dir_l)
    self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
        self.kvstore_f, self.raw_data_pub_dir_f)

    # Drop anything a previous run left on disk.
    leftovers = (
        leader.output_base_dir,
        self.raw_data_dir_l,
        follower.output_base_dir,
        self.raw_data_dir_f,
    )
    for leftover in leftovers:
        if gfile.Exists(leftover):
            gfile.DeleteRecursively(leftover)

    raw_data_opts = dj_pb.RawDataOptions(
        raw_data_iter='TF_RECORD',
        read_ahead_size=1 << 20,
        read_batch_size=128,
    )
    dump_opts = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=1,
        example_id_dump_threshold=1024,
    )
    joiner_opts = dj_pb.ExampleJoinerOptions(
        example_joiner='STREAM_JOINER',
        min_matching_window=64,
        max_matching_window=256,
        data_block_dump_interval=30,
        data_block_dump_threshold=1000,
    )
    batch_opts = dj_pb.BatchProcessorOptions(
        batch_size=512,
        max_flying_item=2048,
    )
    writer_opts = dj_pb.WriterOptions(output_writer='TF_RECORD')
    self.worker_options = dj_pb.DataJoinWorkerOptions(
        use_mock_etcd=True,
        raw_data_options=raw_data_opts,
        example_id_dump_options=dump_opts,
        example_joiner_options=joiner_opts,
        batch_processor_options=batch_opts,
        data_block_builder_options=writer_opts,
    )
    self.total_index = 1 << 12
def test_api(self):
    # End-to-end scenario against two in-process DataJoinMasterService
    # instances (leader on port 4061, follower on port 4062): partition
    # allocation, raw data registration, finishing raw data and finishing
    # partitions, until both data sources report the Finished state.
    logging.getLogger().setLevel(logging.DEBUG)

    # ETCD_BASE_DIR is switched before each client is constructed --
    # presumably DBClient reads it at construction time so that leader and
    # follower use separate key spaces (TODO confirm).
    os.environ['ETCD_BASE_DIR'] = 'bytefl_l'
    data_source_name = 'test_data_source'
    kvstore_l = DBClient('etcd', True)
    os.environ['ETCD_BASE_DIR'] = 'bytefl_f'
    kvstore_f = DBClient('etcd', True)
    kvstore_l.delete_prefix(
            common.data_source_kvstore_base_dir(data_source_name)
        )
    kvstore_f.delete_prefix(
            common.data_source_kvstore_base_dir(data_source_name)
        )

    # Commit one data source per role; both share the same meta
    # (single partition).
    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.output_base_dir = "./ds_output_l"
    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.output_base_dir = "./ds_output_f"
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 1
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(kvstore_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(kvstore_f, data_source_f)

    # Start both masters, each given its own port and the peer's address.
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    os.environ['ETCD_BASE_DIR'] = 'bytefl_l'
    master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f,
            data_source_name, 'etcd', options
        )
    master_l.start()
    os.environ['ETCD_BASE_DIR'] = 'bytefl_f'
    master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l,
            data_source_name, 'etcd', options
        )
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

    # Poll every 2 seconds until both masters report Processing.
    # NOTE(review): there is no timeout; the test hangs if the state is
    # never reached.
    while True:
        req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta
            )
        req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta
            )
        dss_l = client_l.GetDataSourceStatus(req_l)
        dss_f = client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Processing and \
                dss_f.state == common_pb.DataSourceState.Processing:
            break
        else:
            time.sleep(2)

    # Follower, rank 0 asks for a join-example partition with
    # partition_id=-1 and is expected to receive partition 0.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=-1,
            join_example=empty_pb2.Empty()
        )
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    #check idempotent
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    # Leader, rank 0 requests join-example on partition 0 explicitly.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            join_example=empty_pb2.Empty()
        )
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    #check idempotent
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    # Leader, rank 1 asks for a sync-example-id partition with
    # partition_id=-1 and is expected to receive partition 0.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=1,
            partition_id=-1,
            sync_example_id=empty_pb2.Empty()
        )
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    #check idempotent
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    # Follower, rank 1 requests sync-example-id on partition 0 explicitly.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=1,
            partition_id=0,
            sync_example_id=empty_pb2.Empty()
        )
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    #check idempotent
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    # At this point FinishJoinPartition must raise for both the sync
    # request (rdreq1) and the join request (rdreq2) on the leader --
    # presumably because raw data has not been finished yet (confirm
    # against the master implementation). Both requests are reused later.
    rdreq1 = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=1,
            partition_id=0,
            sync_example_id=empty_pb2.Empty()
        )

    try:
        rsp = client_l.FinishJoinPartition(rdreq1)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)

    rdreq2 = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            join_example=empty_pb2.Empty()
        )

    try:
        rsp = client_l.FinishJoinPartition(rdreq2)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)

    # Leader: before any raw data is added the manifest is unfinished at
    # process index 0 and the latest timestamp is zero.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 0)

    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 0)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Leader: add file 'a' (timestamp 3s) -> next index 1, latest ts 3.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=3)
                    )
                ]
            )
        )
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 1)

    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 3)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Leader: add file 'b' (timestamp 5s) -> next index 2, latest ts 5.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=5)
                    )
                ]
            )
        )
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)

    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 5)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Leader: re-adding 'b' and 'a' with dedup=True must leave the next
    # index (2) and the latest timestamp (5) unchanged.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=5)
                    ),
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=3)
                    )
                ]
            )
        )
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)

    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 5)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Follower: same sequence as the leader, with timestamps 1s and 2s.
    # Note the timestamp queries below are built from the leader's meta
    # but sent through the follower client.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 0)

    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 0)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Follower: add file 'a' (timestamp 1s) -> next index 1, latest ts 1.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=1)
                    )
                ]
            )
        )
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 1)

    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 1)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Follower: add file 'b' (timestamp 2s) -> next index 2, latest ts 2.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=2)
                    )
                ]
            )
        )
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)

    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 2)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Follower: re-adding 'a' and 'b' with dedup=True changes nothing.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=1)
                    ),
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=2)
                    )
                ]
            )
        )
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)

    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 2)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Leader: finish raw data for partition 0; the manifest then reports
    # finished with the next index still at 2.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            finish_raw_data=empty_pb2.Empty()
        )
    rsp = client_l.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_l.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)

    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertTrue(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)

    # Leader: adding raw data after finishing must raise.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='x',
                        timestamp=timestamp_pb2.Timestamp(seconds=4)
                    )
                ]
            )
        )
    try:
        rsp = client_l.AddRawData(rdreq)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)

    # Follower: finishing the join partition (rdreq2) must still raise.
    try:
        rsp = client_f.FinishJoinPartition(rdreq2)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)

    # Finishing the sync-example-id partition (rdreq1) now succeeds on
    # both masters.
    rsp = client_l.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_l.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)

    rsp = client_f.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_f.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)

    # Follower: the join partition (rdreq2) still cannot be finished
    # before the follower's raw data is finished below.
    try:
        rsp = client_f.FinishJoinPartition(rdreq2)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)

    # Follower: finish raw data for partition 0.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            finish_raw_data=empty_pb2.Empty()
        )
    rsp = client_f.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_f.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)

    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertTrue(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)

    # Follower: adding raw data after finishing must raise, even with
    # dedup=True.
    rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='x',
                        timestamp=timestamp_pb2.Timestamp(seconds=3)
                    )
                ]
            )
        )
    try:
        rsp = client_f.AddRawData(rdreq)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)

    # Finishing the join partition (rdreq2) now succeeds on both masters.
    rsp = client_f.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_f.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)

    rsp = client_l.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_l.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)

    # Poll every 2 seconds until both masters report Finished.
    # NOTE(review): no timeout here either, and the masters are only
    # stopped when every assertion above has passed.
    while True:
        req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta
            )
        req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta
            )
        dss_l = client_l.GetDataSourceStatus(req_l)
        dss_f = client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Finished and \
                dss_f.state == common_pb.DataSourceState.Finished:
            break
        else:
            time.sleep(2)

    master_l.stop()
    master_f.stop()
def test_raw_data_manifest_manager(self):
    """Drive a RawDataManifestManager through its partition lifecycle.

    The test allocates one partition for example-id syncing, allocates
    two partitions for example joining, adds raw data, then finishes
    raw data, syncing and joining. After each step it checks the
    manifest of every partition returned by ``list_all_manifest()``.
    Operations that are invoked twice in a row verify that repeating
    the call does not raise.
    """
    cli = mysql_client.DBClient('test_cluster', 'localhost:2379',
                                'test_user', 'test_password',
                                'fedlearner', True)
    # try/finally so the client pool is released even when an assertion
    # fails part-way through the scenario.
    try:
        partition_num = 4
        rank_id = 2
        data_source = common_pb.DataSource()
        data_source.data_source_meta.name = "milestone-x"
        data_source.data_source_meta.partition_num = partition_num
        data_source.role = common_pb.FLRole.Leader
        # Remove whatever is stored under this data source's base key
        # before the manager is created.
        cli.delete_prefix(
            common.data_source_kvstore_base_dir(
                data_source.data_source_meta.name))
        manifest_manager = \
            raw_data_manifest_manager.RawDataManifestManager(
                cli, data_source)

        # (state, rank_id) expected for partitions nobody has allocated.
        unsynced = (dj_pb.SyncExampleIdState.UnSynced, -1)
        unjoined = (dj_pb.JoinExampleState.UnJoined, -1)
        none_finished = dict.fromkeys(range(partition_num), False)

        def assert_manifests(sync_reps, join_reps, finished):
            """Check the manifest of every partition.

            sync_reps / join_reps map a partition id to the expected
            (state, rank_id) pair; partitions not listed must still be
            UnSynced / UnJoined with rank_id -1. ``finished`` maps a
            partition id to the expected ``finished`` flag; partitions
            not listed are not checked for it.
            """
            manifest_map = manifest_manager.list_all_manifest()
            for i in range(partition_num):
                self.assertIn(i, manifest_map)
                item = manifest_map[i]
                sync_state, sync_rank = sync_reps.get(i, unsynced)
                self.assertEqual(item.sync_example_id_rep.state,
                                 sync_state)
                self.assertEqual(item.sync_example_id_rep.rank_id,
                                 sync_rank)
                join_state, join_rank = join_reps.get(i, unjoined)
                self.assertEqual(item.join_example_rep.state,
                                 join_state)
                self.assertEqual(item.join_example_rep.rank_id,
                                 join_rank)
                if i in finished:
                    self.assertEqual(bool(item.finished), finished[i])

        def raw_meta(file_path, seconds):
            """Build a RawDataMeta with a whole-second timestamp."""
            return dj_pb.RawDataMeta(
                file_path=file_path,
                timestamp=timestamp_pb2.Timestamp(seconds=seconds))

        # Freshly created manager: nothing allocated, nothing finished.
        assert_manifests({}, {}, none_finished)

        # Allocate one partition for example-id syncing.
        manifest = manifest_manager.alloc_sync_exampld_id(rank_id)
        self.assertIsNotNone(manifest)
        partition_id = manifest.partition_id
        sync_reps = {
            partition_id: (dj_pb.SyncExampleIdState.Syncing, rank_id),
        }
        assert_manifests(sync_reps, {}, none_finished)

        # Allocate a different partition for joining (3 - p never
        # equals p for an integer p).
        partition_id2 = 3 - partition_id
        rank_id2 = 100
        manifest_manager.alloc_join_example(rank_id2, partition_id2)
        join_reps = {
            partition_id2: (dj_pb.JoinExampleState.Joining, rank_id2),
        }
        assert_manifests(sync_reps, join_reps, none_finished)

        # None of these finish calls is expected to succeed at this
        # point of the scenario.
        self.assertRaises(Exception, manifest_manager.finish_join_example,
                          rank_id, partition_id)
        self.assertRaises(Exception, manifest_manager.finish_join_example,
                          rank_id2, partition_id2)
        self.assertRaises(Exception,
                          manifest_manager.finish_sync_example_id,
                          -rank_id, partition_id)
        self.assertRaises(Exception,
                          manifest_manager.finish_sync_example_id,
                          rank_id2, partition_id2)

        # A second rank starts joining the partition that is syncing.
        rank_id3 = 0
        manifest_manager.alloc_join_example(rank_id3, partition_id)
        join_reps[partition_id] = (dj_pb.JoinExampleState.Joining,
                                   rank_id3)
        assert_manifests(sync_reps, join_reps, none_finished)

        # Finishing the sync is still expected to fail here; it only
        # succeeds below, after finish_raw_data() has been called.
        self.assertRaises(Exception,
                          manifest_manager.finish_sync_example_id,
                          rank_id, partition_id)

        # The batch contains the entry 'a' twice: it must be rejected
        # when the third argument (presumably the dedup flag) is False
        # and accepted when it is True.
        raw_data_metas = [raw_meta('a', 3), raw_meta('a', 3),
                          raw_meta('c', 1)]
        self.assertRaises(Exception, manifest_manager.add_raw_data,
                          partition_id, raw_data_metas, False)
        manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
        latest_ts = manifest_manager.get_raw_date_latest_timestamp(
            partition_id)
        self.assertEqual(latest_ts.seconds, 3)
        self.assertEqual(latest_ts.nanos, 0)
        manifest = manifest_manager.get_manifest(partition_id)
        self.assertEqual(manifest.next_process_index, 2)

        # Second batch repeats 'a' and 'c' and introduces 'b' and 'd'.
        raw_data_metas = [raw_meta('a', 3), raw_meta('a', 3),
                          raw_meta('b', 2), raw_meta('c', 1),
                          raw_meta('d', 4)]
        manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
        latest_ts = manifest_manager.get_raw_date_latest_timestamp(
            partition_id)
        self.assertEqual(latest_ts.seconds, 4)
        self.assertEqual(latest_ts.nanos, 0)
        manifest_map = manifest_manager.list_all_manifest()
        for i in range(partition_num):
            self.assertIn(i, manifest_map)
            expected_index = 4 if i == partition_id else 0
            self.assertEqual(manifest_map[i].next_process_index,
                             expected_index)

        # Called twice: the repeated call must not raise.
        manifest_manager.finish_raw_data(partition_id)
        manifest_manager.finish_raw_data(partition_id)
        # NOTE(review): this call passes two arguments where every other
        # add_raw_data call passes three, so the exception may come from
        # the argument mismatch rather than from the partition being
        # finished -- confirm against add_raw_data's signature.
        self.assertRaises(Exception, manifest_manager.add_raw_data,
                          partition_id, 200)

        # Called twice: the repeated call must not raise.
        manifest_manager.finish_sync_example_id(rank_id, partition_id)
        manifest_manager.finish_sync_example_id(rank_id, partition_id)
        synced_reps = {
            partition_id: (dj_pb.SyncExampleIdState.Synced, rank_id),
        }
        assert_manifests(synced_reps, join_reps, {partition_id: True})

        # Called twice: the repeated call must not raise.
        manifest_manager.finish_join_example(rank_id3, partition_id)
        manifest_manager.finish_join_example(rank_id3, partition_id)
        joined_reps = dict(join_reps)
        joined_reps[partition_id] = (dj_pb.JoinExampleState.Joined,
                                     rank_id3)
        assert_manifests(synced_reps, joined_reps, {})
    finally:
        cli.destroy_client_pool()
def cleanup_meta_data(self):
    """Delete every kvstore entry under this data source's base key.

    The base key is derived from the data source name via
    ``common.data_source_kvstore_base_dir``; the lookup and the prefix
    delete both run while holding ``self._lock``.
    """
    with self._lock:
        base_key = common.data_source_kvstore_base_dir(
            self._data_source.data_source_meta.name)
        self._kvstore.delete_prefix(base_key)