def test_mysql_op(self):
    """Exercise the basic key/value API of ``mysql_client.DBClient``.

    Covers delete/set/get, compare-and-swap, a write issued from a second
    thread, and prefix listing with and without the prefix key itself.
    """
    cli = mysql_client.DBClient('test_cluster', 'localhost:2379',
                                'test_user', 'test_password',
                                'data_source_a', True)
    cli.delete('fl_key')
    cli.set_data('fl_key', 'fl_value')
    self.assertEqual(cli.get_data('fl_key'), b'fl_value')
    # CAS must be rejected when the expected old value does not match ...
    self.assertFalse(cli.cas('fl_key', 'fl_value1', 'fl_value2'))
    # ... and accepted when it does.
    self.assertTrue(cli.cas('fl_key', 'fl_value', 'fl_value1'))
    self.assertEqual(cli.get_data('fl_key'), b'fl_value1')

    # unittest only observes failures raised on the main thread; an
    # assertion inside the worker would merely print a traceback and the
    # test would still pass. Record what the worker saw and check it after
    # join() instead.
    observed = {}

    def thread_routine():
        cli.set_data('fl_key', 'fl_value2')
        observed['value'] = cli.get_data('fl_key')

    other = threading.Thread(target=thread_routine)
    other.start()
    other.join()
    self.assertEqual(observed.get('value'), b'fl_value2')

    cli.set_data('fl_key/a', '1')
    cli.set_data('fl_key/b', '2')
    cli.set_data('fl_key/c', '3')
    expected_kvs = [(b'fl_key', b'fl_value2'), (b'fl_key/a', b'1'),
                    (b'fl_key/b', b'2'), (b'fl_key/c', b'3')]
    # Materialise the listings so that a short or empty result fails the
    # test instead of silently skipping the per-item comparisons.
    kvs = [(kv[0], kv[1]) for kv in cli.get_prefix_kvs('fl_key')]
    self.assertEqual(kvs, expected_kvs)
    # With the second argument True the prefix key itself is not returned.
    kvs = [(kv[0], kv[1]) for kv in cli.get_prefix_kvs('fl_key', True)]
    self.assertEqual(kvs, expected_kvs[1:])
def setUp(self):
    """Build a single-partition STREAM_JOINER fixture on a clean workspace.

    Deletes any previous output/raw-data directories and the kvstore
    entries of the data source, then creates a fresh manifest manager and
    zeroes the counters the tests accumulate into.
    """
    source = common_pb.DataSource()
    source.data_source_meta.name = "milestone-f"
    source.data_source_meta.partition_num = 1
    source.output_base_dir = "./ds_output"
    self.data_source = source
    self.raw_data_dir = "./raw_data"

    self.raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_RECORD',
        compressed_type='',
        optional_fields=['label'],
    )
    self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=1,
        example_id_dump_threshold=1024,
    )
    self.example_joiner_options = dj_pb.ExampleJoinerOptions(
        example_joiner='STREAM_JOINER',
        min_matching_window=32,
        max_matching_window=128,
        data_block_dump_interval=30,
        data_block_dump_threshold=128,
    )

    # Remove leftovers of earlier runs: output directory first, then raw data.
    for stale_dir in (source.output_base_dir, self.raw_data_dir):
        if gfile.Exists(stale_dir):
            gfile.DeleteRecursively(stale_dir)

    self.kvstore = mysql_client.DBClient(
        'test_cluster', 'localhost:2379', 'test_user', 'test_password',
        'fedlearner', True)
    kvstore_dir = common.data_source_kvstore_base_dir(
        source.data_source_meta.name)
    self.kvstore.delete_prefix(kvstore_dir)

    self.total_raw_data_count = 0
    self.total_example_id_count = 0
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, source)
    self.g_data_block_index = 0
def setUp(self):
    """Commit a 4-partition follower data source and create its data blocks.

    Connection parameters are stored on ``self`` so other methods can reuse
    them; ``_create_data_block`` is invoked once per partition.
    """
    source = common_pb.DataSource()
    source.data_source_meta.name = "milestone-x"
    source.data_source_meta.partition_num = 4
    source.data_source_meta.start_time = 0
    source.data_source_meta.end_time = 10000
    source.output_base_dir = "./ds_output"
    source.role = common_pb.FLRole.Follower
    self.data_source = source

    self.db_database = 'test_cluster'
    self.db_addr = 'localhost:2379'
    self.db_base_dir = 'fedlearner'
    self.db_username = '******'
    self.db_password = '******'
    self.kvstore = mysql_client.DBClient(
        self.db_database, self.db_addr, self.db_username,
        self.db_password, self.db_base_dir, True)
    common.commit_data_source(self.kvstore, self.data_source)

    output_dir = source.output_base_dir
    if gfile.Exists(output_dir):
        gfile.DeleteRecursively(output_dir)

    # Presumably populated by _create_data_block -- TODO confirm. The
    # attribute name is kept verbatim because other methods reference it.
    self.data_block_matas = []
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
    for partition_id in range(self.data_source.data_source_meta.partition_num):
        self._create_data_block(partition_id)
def test_csv_raw_data_visitor(self):
    """Visit the bundled CSV raw data file and check the row indices.

    Registers ``../csv_raw_data/<partition 0>/test_raw_data.csv`` (relative
    to this file) as raw data of partition 0, then iterates it with a
    ``CSV_DICT`` RawDataVisitor, expecting consecutive indices from 0 and
    4999 rows in total.
    """
    self.data_source = common_pb.DataSource()
    self.data_source.data_source_meta.name = 'fclh_test'
    self.data_source.data_source_meta.partition_num = 1
    here = path.dirname(path.abspath(__file__))
    self.raw_data_dir = path.join(here, "../csv_raw_data")

    self.kvstore = mysql_client.DBClient(
        'test_cluster', 'localhost:2379', 'test_user', 'test_password',
        'fedlearner', True)
    ds_name = self.data_source.data_source_meta.name
    self.kvstore.delete_prefix(common.data_source_kvstore_base_dir(ds_name))
    self.assertEqual(self.data_source.data_source_meta.partition_num, 1)

    partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
    self.assertTrue(gfile.Exists(partition_dir))

    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
    csv_meta = dj_pb.RawDataMeta(
        file_path=path.join(partition_dir, "test_raw_data.csv"),
        timestamp=timestamp_pb2.Timestamp(seconds=3))
    manifest_manager.add_raw_data(0, [csv_meta], True)

    raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='CSV_DICT', read_ahead_size=1 << 20)
    rdm = raw_data_visitor.RawDataManager(self.kvstore, self.data_source, 0)
    self.assertTrue(rdm.check_index_meta_by_process_index(0))
    rdv = raw_data_visitor.RawDataVisitor(
        self.kvstore, self.data_source, 0, raw_data_options)

    visited = 0
    for index, item in rdv:
        # Progress output every 1024 rows.
        if index > 0 and index % 1024 == 0:
            print("{} {}".format(index, item.raw_id))
        self.assertEqual(index, visited)
        visited += 1
    self.assertEqual(visited, 4999)
def setUp(self):
    """Build leader/follower data sources backed by MNIST-derived data blocks.

    The first 200 MNIST training images are flattened, divided by 255 and
    split column-wise: the leader (role 0) receives the left half of every
    feature vector and the follower (role 1) the right half. For every
    partition of each role a data block is created and the partition's
    ``join_example_rep`` is moved from UnJoined to Joined via
    ``_finish_partition``.
    """
    self.sche = _TaskScheduler(30)
    self.kv_store = [None, None]
    self.app_id = "test_trainer_v1"
    # The configured database name is not used: each role gets its own
    # kvstore database named after its data source.
    _, db_addr, db_username, db_password, db_base_dir = \
        get_kvstore_config("etcd")
    data_source = [
        self._gen_ds_meta(common_pb.FLRole.Leader),
        self._gen_ds_meta(common_pb.FLRole.Follower)
    ]
    for role in range(2):
        self.kv_store[role] = mysql_client.DBClient(
            data_source[role].data_source_meta.name, db_addr, db_username,
            db_password, db_base_dir, True)
    self.data_source = data_source

    if debug_mode:
        (x, y), _ = tf.keras.datasets.mnist.load_data(local_mnist_path)
    else:
        (x, y), _ = tf.keras.datasets.mnist.load_data()
    x = x[:200]
    x = x.reshape(x.shape[0], -1).astype(np.float32) / 255.0
    # NOTE(review): y is not truncated to the 200 rows kept in x; this is
    # presumably harmless if the helpers below only use y row-wise up to
    # len(x) -- confirm.
    y = y.astype(np.int64)
    half = x.shape[1] // 2
    xl = x[:, :half]
    xf = x[:, half:]
    self._create_local_data(xl, xf, y)
    x = [xl, xf]  # indexed by role: 0 -> leader half, 1 -> follower half

    for role in range(2):
        common.commit_data_source(self.kv_store[role], data_source[role])
        if gfile.Exists(data_source[role].output_base_dir):
            gfile.DeleteRecursively(data_source[role].output_base_dir)
        manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            self.kv_store[role], data_source[role])
        partition_num = data_source[role].data_source_meta.partition_num
        for i in range(partition_num):
            self._create_data_block(data_source[role], i, x[role], y)
            manifest_manager._finish_partition(
                'join_example_rep', dj_pb.JoinExampleState.UnJoined,
                dj_pb.JoinExampleState.Joined, -1, i)
def setUp(self):
    """Reset kvstore state and create an empty example-dump dir for partition 0."""
    self.kvstore = mysql_client.DBClient(
        'test_cluster', 'localhost:2379', 'test_user', 'test_password',
        'fedlearner', True)

    source = common_pb.DataSource()
    source.data_source_meta.name = "milestone-x"
    source.data_source_meta.partition_num = 1
    source.output_base_dir = "./ds_output"
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(source.data_source_meta.name))
    self.data_source = source

    self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=-1,
        example_id_dump_threshold=1024,
    )

    output_dir = source.output_base_dir
    if gfile.Exists(output_dir):
        gfile.DeleteRecursively(output_dir)

    dumped_dir = common.data_source_example_dumped_dir(source)
    self.partition_dir = os.path.join(dumped_dir, common.partition_repr(0))
    gfile.MakeDirs(self.partition_dir)
def __init__(self, base_path, name, role, partition_num=1, start_time=0,
             end_time=100000):
    """Create and commit a fresh data source with one DataBlockManager per partition.

    Args:
        base_path: root under which ``<name>_<0|1>/data_source/`` is used as
            the output directory; any existing content there is deleted.
        name: data source name.
        role: ``'leader'`` or ``'follower'``; stored on the data source as
            0 or 1 respectively.
        partition_num: number of partitions.
        start_time: copied into ``data_source_meta.start_time``.
        end_time: copied into ``data_source_meta.end_time``.

    Raises:
        ValueError: if ``role`` is neither ``'leader'`` nor ``'follower'``.
    """
    if role == 'leader':
        role = 0
    elif role == 'follower':
        role = 1
    else:
        raise ValueError("Unknown role %s" % role)

    data_source = common_pb.DataSource()
    data_source.data_source_meta.name = name
    data_source.data_source_meta.partition_num = partition_num
    data_source.data_source_meta.start_time = start_time
    data_source.data_source_meta.end_time = end_time
    data_source.output_base_dir = "{}/{}_{}/data_source/".format(
        base_path, data_source.data_source_meta.name, role)
    data_source.role = role
    if gfile.Exists(data_source.output_base_dir):
        gfile.DeleteRecursively(data_source.output_base_dir)
    self._data_source = data_source

    db_database, db_addr, db_username, db_password, db_base_dir = \
        get_kvstore_config("etcd")
    self._kv_store = mysql_client.DBClient(db_database, db_addr, db_username,
                                           db_password, db_base_dir, True)
    common.commit_data_source(self._kv_store, self._data_source)

    # Build the manifest manager once instead of once per partition; it
    # is only used to mark each partition's join_example_rep as Joined.
    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self._kv_store, self._data_source)
    self._dbms = []
    for partition_id in range(partition_num):
        manifest_manager._finish_partition(
            'join_example_rep', dj_pb.JoinExampleState.UnJoined,
            dj_pb.JoinExampleState.Joined, -1, partition_id)
        self._dbms.append(
            data_block_manager.DataBlockManager(self._data_source,
                                                partition_id))
def setUp(self):
    """Prepare an empty partition-0 raw data dir and a fresh manifest manager."""
    source = common_pb.DataSource()
    source.data_source_meta.name = 'fclh_test'
    source.data_source_meta.partition_num = 1
    self.data_source = source
    self.raw_data_dir = "./raw_data"

    self.kvstore = mysql_client.DBClient(
        'test_cluster', 'localhost:2379', 'test_user', 'test_password',
        'fedlearner', True)
    kvstore_dir = common.data_source_kvstore_base_dir(
        source.data_source_meta.name)
    self.kvstore.delete_prefix(kvstore_dir)
    self.assertEqual(source.data_source_meta.partition_num, 1)

    # Recreate the partition directory so every test starts empty.
    partition_dir = os.path.join(self.raw_data_dir, common.partition_repr(0))
    if gfile.Exists(partition_dir):
        gfile.DeleteRecursively(partition_dir)
    gfile.MakeDirs(partition_dir)

    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, source)
def setUp(self):
    """Build a single-partition ATTRIBUTION_JOINER fixture on a clean workspace.

    Deletes previous output/raw-data directories and the data source's
    kvstore entries, then creates a fresh manifest manager and zeroes the
    counters the tests accumulate into.
    """
    source = common_pb.DataSource()
    source.data_source_meta.name = "milestone-f"
    source.data_source_meta.partition_num = 1
    source.output_base_dir = "./ds_output"
    self.raw_data_dir = "./raw_data"
    self.data_source = source

    self.raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_RECORD', compressed_type='')
    self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=1, example_id_dump_threshold=1024)
    self.example_joiner_options = dj_pb.ExampleJoinerOptions(
        example_joiner='ATTRIBUTION_JOINER',
        min_matching_window=32,
        max_matching_window=51200,
        max_conversion_delay=interval_to_timestamp("124"),
        enable_negative_example_generator=True,
        data_block_dump_interval=32,
        data_block_dump_threshold=128,
        negative_sampling_rate=0.8,
    )

    # Drop leftovers of earlier runs: output directory first, then raw data.
    for stale_dir in (source.output_base_dir, self.raw_data_dir):
        if gfile.Exists(stale_dir):
            gfile.DeleteRecursively(stale_dir)

    self.kvstore = mysql_client.DBClient(
        'test_cluster', 'localhost:2379', 'test_user', 'test_password',
        'fedlearner', True)
    kvstore_dir = common.data_source_kvstore_base_dir(
        source.data_source_meta.name)
    self.kvstore.delete_prefix(kvstore_dir)

    self.total_raw_data_count = 0
    self.total_example_id_count = 0
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, source)
    self.g_data_block_index = 0
def setUp(self):
    """Create the ``_f`` and ``_l`` data sources on clean directories.

    Both data sources are single-partition and share the name "milestone".
    They differ only in output directory; only the ``_l`` side gets a raw
    data directory and a manifest manager. (The suffixes presumably denote
    follower/leader -- confirm.)
    """
    def make_source(output_dir):
        # Single-partition "milestone" data source writing to output_dir.
        source = common_pb.DataSource()
        source.data_source_meta.name = "milestone"
        source.data_source_meta.partition_num = 1
        source.output_base_dir = output_dir
        return source

    self.data_source_f = make_source("./output-f")
    if gfile.Exists(self.data_source_f.output_base_dir):
        gfile.DeleteRecursively(self.data_source_f.output_base_dir)

    self.raw_data_dir_l = "./raw_data-l"
    self.data_source_l = make_source("./output-l")
    for stale_dir in (self.data_source_l.output_base_dir,
                      self.raw_data_dir_l):
        if gfile.Exists(stale_dir):
            gfile.DeleteRecursively(stale_dir)

    self.kvstore = mysql_client.DBClient(
        'test_cluster', 'localhost:2379', 'test_user', 'test_password',
        'fedlearner', True)
    # Both sources share one name, so clearing the _l prefix clears both.
    kvstore_dir = common.data_source_kvstore_base_dir(
        self.data_source_l.data_source_meta.name)
    self.kvstore.delete_prefix(kvstore_dir)
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source_l)
def test_raw_data_manifest_manager(self):
    """Walk RawDataManifestManager through the sync/join/raw-data lifecycle.

    Uses a 4-partition leader data source. It allocates example-id sync and
    example-join work, and checks that finish calls issued out of order or
    with the wrong rank raise. It adds raw data with and without
    de-duplication, and verifies every partition's manifest after each
    transition.
    """
    cli = mysql_client.DBClient('test_cluster', 'localhost:2379',
                                'test_user', 'test_password',
                                'fedlearner', True)
    try:
        partition_num = 4
        rank_id = 2
        data_source = common_pb.DataSource()
        data_source.data_source_meta.name = "milestone-x"
        data_source.data_source_meta.partition_num = partition_num
        data_source.role = common_pb.FLRole.Leader
        cli.delete_prefix(
            common.data_source_kvstore_base_dir(
                data_source.data_source_meta.name))
        manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            cli, data_source)

        sync_idle = (dj_pb.SyncExampleIdState.UnSynced, -1)
        join_idle = (dj_pb.JoinExampleState.UnJoined, -1)

        def check_manifests(sync_reps=None, join_reps=None, finished=None):
            # Assert the state of every partition's manifest.
            # sync_reps / join_reps: partition id -> (state, rank_id);
            # partitions not listed must be UnSynced / UnJoined with rank -1.
            # finished: partition id -> expected `finished` flag; partitions
            # not listed are not checked.
            sync_reps = sync_reps or {}
            join_reps = join_reps or {}
            finished = finished or {}
            manifest_map = manifest_manager.list_all_manifest()
            for i in range(partition_num):
                self.assertIn(i, manifest_map)
                manifest = manifest_map[i]
                state, rank = sync_reps.get(i, sync_idle)
                self.assertEqual(manifest.sync_example_id_rep.state, state)
                self.assertEqual(manifest.sync_example_id_rep.rank_id, rank)
                state, rank = join_reps.get(i, join_idle)
                self.assertEqual(manifest.join_example_rep.state, state)
                self.assertEqual(manifest.join_example_rep.rank_id, rank)
                if i in finished:
                    self.assertEqual(manifest.finished, finished[i])

        def raw_meta(file_path, seconds):
            # RawDataMeta for file_path stamped at `seconds`.
            return dj_pb.RawDataMeta(
                file_path=file_path,
                timestamp=timestamp_pb2.Timestamp(seconds=seconds))

        unfinished = {i: False for i in range(partition_num)}

        # Fresh data source: nothing allocated, nothing finished.
        check_manifests(finished=unfinished)

        # Allocate example-id sync work to rank_id.
        manifest = manifest_manager.alloc_sync_exampld_id(rank_id)
        self.assertIsNotNone(manifest)
        partition_id = manifest.partition_id
        syncing = {partition_id: (dj_pb.SyncExampleIdState.Syncing, rank_id)}
        check_manifests(sync_reps=syncing, finished=unfinished)

        # Allocate join work on the mirrored partition (never equal to
        # partition_id because partition_num is even) to rank_id2.
        partition_id2 = partition_num - 1 - partition_id
        rank_id2 = 100
        manifest_manager.alloc_join_example(rank_id2, partition_id2)
        joining = {
            partition_id2: (dj_pb.JoinExampleState.Joining, rank_id2),
        }
        check_manifests(sync_reps=syncing, join_reps=joining,
                        finished=unfinished)

        # Each of these finish calls is expected to raise: wrong rank,
        # unallocated work, or prerequisite work not finished yet.
        self.assertRaises(Exception, manifest_manager.finish_join_example,
                          rank_id, partition_id)
        self.assertRaises(Exception, manifest_manager.finish_join_example,
                          rank_id2, partition_id2)
        self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                          -rank_id, partition_id)
        self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                          rank_id2, partition_id2)

        # A third rank takes the join work of the syncing partition.
        rank_id3 = 0
        manifest_manager.alloc_join_example(rank_id3, partition_id)
        joining = {
            partition_id: (dj_pb.JoinExampleState.Joining, rank_id3),
            partition_id2: (dj_pb.JoinExampleState.Joining, rank_id2),
        }
        check_manifests(sync_reps=syncing, join_reps=joining,
                        finished=unfinished)
        # Sync is expected to be refused while the raw data is unfinished.
        self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                          rank_id, partition_id)

        # Duplicate file paths raise without dedup and are collapsed with it.
        raw_data_metas = [raw_meta('a', 3), raw_meta('a', 3), raw_meta('c', 1)]
        self.assertRaises(Exception, manifest_manager.add_raw_data,
                          partition_id, raw_data_metas, False)
        manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
        latest_ts = manifest_manager.get_raw_date_latest_timestamp(
            partition_id)
        self.assertEqual(latest_ts.seconds, 3)
        self.assertEqual(latest_ts.nanos, 0)
        manifest = manifest_manager.get_manifest(partition_id)
        self.assertEqual(manifest.next_process_index, 2)

        raw_data_metas = [raw_meta('a', 3), raw_meta('a', 3), raw_meta('b', 2),
                          raw_meta('c', 1), raw_meta('d', 4)]
        manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
        latest_ts = manifest_manager.get_raw_date_latest_timestamp(
            partition_id)
        self.assertEqual(latest_ts.seconds, 4)
        self.assertEqual(latest_ts.nanos, 0)
        manifest_map = manifest_manager.list_all_manifest()
        for i in range(partition_num):
            self.assertIn(i, manifest_map)
            expected_index = 4 if i == partition_id else 0
            self.assertEqual(manifest_map[i].next_process_index,
                             expected_index)

        # Finishing raw data is repeatable.
        manifest_manager.finish_raw_data(partition_id)
        manifest_manager.finish_raw_data(partition_id)
        # Adding raw data to a finished partition must raise. The call is
        # well-formed, so a TypeError from bad arguments cannot satisfy
        # the assertion.
        self.assertRaises(Exception, manifest_manager.add_raw_data,
                          partition_id, [raw_meta('e', 5)], True)

        # Finishing sync is repeatable and marks the partition finished.
        manifest_manager.finish_sync_example_id(rank_id, partition_id)
        manifest_manager.finish_sync_example_id(rank_id, partition_id)
        synced = {partition_id: (dj_pb.SyncExampleIdState.Synced, rank_id)}
        check_manifests(sync_reps=synced, join_reps=joining,
                        finished={partition_id: True})

        # Finishing join is repeatable; the other join stays in progress.
        manifest_manager.finish_join_example(rank_id3, partition_id)
        manifest_manager.finish_join_example(rank_id3, partition_id)
        joined = {
            partition_id: (dj_pb.JoinExampleState.Joined, rank_id3),
            partition_id2: (dj_pb.JoinExampleState.Joining, rank_id2),
        }
        check_manifests(sync_reps=synced, join_reps=joined)
    finally:
        # Release the client's pool even when an assertion fails.
        cli.destroy_client_pool()