def test_mysql_op(self):
    """Exercise the MySQL-backed DBClient: set/get, compare-and-swap,
    cross-thread visibility, and prefix key scans.

    Bug fix: assertions raised inside a worker thread are swallowed by
    `threading` (the thread dies, `join()` returns normally, and the test
    passes even on failure). The worker now records any exception and the
    main thread re-raises it after `join()`.
    """
    cli = db_client.DBClient('mysql', True)
    cli.delete('fl_key')
    cli.set_data('fl_key', 'fl_value')
    self.assertEqual(cli.get_data('fl_key'), b'fl_value')
    # CAS must reject a mismatched expected value and accept a matching one.
    self.assertFalse(cli.cas('fl_key', 'fl_value1', 'fl_value2'))
    self.assertTrue(cli.cas('fl_key', 'fl_value', 'fl_value1'))
    self.assertEqual(cli.get_data('fl_key'), b'fl_value1')

    thread_errors = []

    def thread_routine():
        try:
            cli.set_data('fl_key', 'fl_value2')
            self.assertEqual(cli.get_data('fl_key'), b'fl_value2')
        except Exception as exc:  # pylint: disable=broad-except
            # Propagate to the main thread; see docstring.
            thread_errors.append(exc)

    other = threading.Thread(target=thread_routine)
    other.start()
    other.join()
    if thread_errors:
        raise thread_errors[0]

    cli.set_data('fl_key/a', '1')
    cli.set_data('fl_key/b', '2')
    cli.set_data('fl_key/c', '3')
    expected_kvs = [(b'fl_key', b'fl_value2'),
                    (b'fl_key/a', b'1'),
                    (b'fl_key/b', b'2'),
                    (b'fl_key/c', b'3')]
    # Without ignore_prefix the root key itself is included in the scan.
    for idx, kv in enumerate(cli.get_prefix_kvs('fl_key')):
        self.assertEqual(kv[0], expected_kvs[idx][0])
        self.assertEqual(kv[1], expected_kvs[idx][1])
    # With the flag set, only the children of 'fl_key' are returned.
    for idx, kv in enumerate(cli.get_prefix_kvs('fl_key', True)):
        self.assertEqual(kv[0], expected_kvs[idx + 1][0])
        self.assertEqual(kv[1], expected_kvs[idx + 1][1])
def setUp(self):
    """Build a fresh single-partition data source and attribution-joiner
    options, wiping any stale local output/raw-data directories and the
    matching etcd key prefix before each test."""
    ds = common_pb.DataSource()
    ds.data_source_meta.name = "milestone-f"
    ds.data_source_meta.partition_num = 1
    ds.output_base_dir = "./ds_output"
    self.raw_data_dir = "./raw_data"
    self.data_source = ds
    self.raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_RECORD',
        compressed_type='')
    self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=1,
        example_id_dump_threshold=1024)
    self.example_joiner_options = dj_pb.ExampleJoinerOptions(
        example_joiner='ATTRIBUTION_JOINER',
        min_matching_window=32,
        max_matching_window=51200,
        max_conversion_delay=interval_to_timestamp("124"),
        enable_negative_example_generator=True,
        data_block_dump_interval=32,
        data_block_dump_threshold=128,
        negative_sampling_rate=0.8,
    )
    # Remove leftovers from any previous run.
    for stale_dir in (ds.output_base_dir, self.raw_data_dir):
        if gfile.Exists(stale_dir):
            gfile.DeleteRecursively(stale_dir)
    self.kvstore = db_client.DBClient('etcd', True)
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(ds.data_source_meta.name))
    self.total_raw_data_count = 0
    self.total_example_id_count = 0
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
    self.g_data_block_index = 0
def setUp(self):
    """Build a fresh single-partition data source and stream-joiner
    options (with the optional 'label' field), clearing stale local
    directories and the etcd key prefix before each test."""
    ds = common_pb.DataSource()
    ds.data_source_meta.name = "milestone-f"
    ds.data_source_meta.partition_num = 1
    ds.output_base_dir = "./ds_output"
    self.raw_data_dir = "./raw_data"
    self.data_source = ds
    self.raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_RECORD',
        compressed_type='',
        optional_fields=['label'])
    self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=1,
        example_id_dump_threshold=1024)
    self.example_joiner_options = dj_pb.ExampleJoinerOptions(
        example_joiner='STREAM_JOINER',
        min_matching_window=32,
        max_matching_window=128,
        data_block_dump_interval=30,
        data_block_dump_threshold=128)
    # Remove leftovers from any previous run.
    for stale_dir in (ds.output_base_dir, self.raw_data_dir):
        if gfile.Exists(stale_dir):
            gfile.DeleteRecursively(stale_dir)
    self.kvstore = db_client.DBClient('etcd', True)
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(ds.data_source_meta.name))
    self.total_raw_data_count = 0
    self.total_example_id_count = 0
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
    self.g_data_block_index = 0
def test_raw_data_manifest_manager_with_nfs(self):
    """Run the shared manifest-manager scenario against an NFS-backed
    DBClient rooted at a throwaway directory, then clean up."""
    root_dir = "test_fedlearner"
    # The NFS backend resolves its storage root from the environment.
    os.environ["STORAGE_ROOT_PATH"] = root_dir
    client = db_client.DBClient('nfs', True)
    self._raw_data_manifest_manager(client)
    if gfile.Exists(root_dir):
        gfile.DeleteRecursively(root_dir)
def setUp(self):
    """Create an empty partition-0 raw-data directory and a clean
    manifest manager for the 'fclh_test' single-partition data source."""
    ds = common_pb.DataSource()
    ds.data_source_meta.name = 'fclh_test'
    ds.data_source_meta.partition_num = 1
    self.data_source = ds
    self.raw_data_dir = "./raw_data"
    self.kvstore = db_client.DBClient('etcd', True)
    # Wipe any state a previous run left behind in the kvstore.
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(ds.data_source_meta.name))
    self.assertEqual(ds.data_source_meta.partition_num, 1)
    partition_dir = os.path.join(self.raw_data_dir, common.partition_repr(0))
    if gfile.Exists(partition_dir):
        gfile.DeleteRecursively(partition_dir)
    gfile.MakeDirs(partition_dir)
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
def init(self, dsname, joiner_name, version=Version.V1, cache_type="memory"):
    """Initialize a named single-partition data source and joiner options
    for the given joiner, clearing stale local directories and the etcd
    prefix. `version` and `cache_type` parameterize the scenario."""
    ds = common_pb.DataSource()
    ds.data_source_meta.name = dsname
    ds.data_source_meta.partition_num = 1
    ds.output_base_dir = "%s_ds_output" % dsname
    self.raw_data_dir = "%s_raw_data" % dsname
    self.data_source = ds
    self.raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_RECORD',
        compressed_type='',
        raw_data_cache_type=cache_type,
    )
    self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=1,
        example_id_dump_threshold=1024)
    self.example_joiner_options = dj_pb.ExampleJoinerOptions(
        example_joiner=joiner_name,
        min_matching_window=32,
        max_matching_window=51200,
        max_conversion_delay=interval_to_timestamp("124"),
        enable_negative_example_generator=True,
        data_block_dump_interval=32,
        data_block_dump_threshold=128,
        negative_sampling_rate=0.8,
        join_expr="example_id",
        join_key_mapper="DEFAULT",
        negative_sampling_filter_expr='',
    )
    # Remove leftovers from any previous run.
    for stale_dir in (ds.output_base_dir, self.raw_data_dir):
        if gfile.Exists(stale_dir):
            gfile.DeleteRecursively(stale_dir)
    self.kvstore = db_client.DBClient('etcd', True)
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(ds.data_source_meta.name))
    self.total_raw_data_count = 0
    self.total_example_id_count = 0
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
    self.g_data_block_index = 0
    self.version = version
def setUp(self):
    """Prepare the 'milestone-x' data source with example-id dumping
    disabled by interval (-1), and create a fresh partition-0 dump dir."""
    self.kvstore = db_client.DBClient('etcd', True)
    ds = common_pb.DataSource()
    ds.data_source_meta.name = "milestone-x"
    ds.data_source_meta.partition_num = 1
    ds.output_base_dir = "./ds_output"
    # Clear any state from an earlier run under this data source name.
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(ds.data_source_meta.name))
    self.data_source = ds
    self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
        example_id_dump_interval=-1,
        example_id_dump_threshold=1024)
    if gfile.Exists(ds.output_base_dir):
        gfile.DeleteRecursively(ds.output_base_dir)
    self.partition_dir = os.path.join(
        common.data_source_example_dumped_dir(self.data_source),
        common.partition_repr(0))
    gfile.MakeDirs(self.partition_dir)
def setUp(self):
    """Commit a 4-partition follower data source to etcd and pre-create
    one data block per partition."""
    ds = common_pb.DataSource()
    ds.data_source_meta.name = "milestone-x"
    ds.data_source_meta.partition_num = 4
    ds.data_source_meta.start_time = 0
    ds.data_source_meta.end_time = 10000
    ds.output_base_dir = "./ds_output"
    ds.role = common_pb.FLRole.Follower
    self.data_source = ds
    self.kvstore = db_client.DBClient('etcd', True)
    common.commit_data_source(self.kvstore, self.data_source)
    if gfile.Exists(ds.output_base_dir):
        gfile.DeleteRecursively(ds.output_base_dir)
    self.data_block_matas = []
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
    # Seed every partition with an initial data block.
    for partition_id in range(ds.data_source_meta.partition_num):
        self._create_data_block(partition_id)
def setUp(self):
    """Build leader/follower data sources backed by etcd, split MNIST
    features vertically between the two roles, and pre-populate each
    partition with joined data blocks so the trainer can run."""
    self.sche = _TaskScheduler(30)
    self.kv_store = [None, None]
    self.app_id = "test_trainer_v1"
    data_source = [
        self._gen_ds_meta(common_pb.FLRole.Leader),
        self._gen_ds_meta(common_pb.FLRole.Follower),
    ]
    for role in range(2):
        # DBClient picks up the cluster name from the environment.
        os.environ['ETCD_NAME'] = data_source[role].data_source_meta.name
        self.kv_store[role] = db_client.DBClient("etcd", True)
    self.data_source = data_source

    if debug_mode:
        (x, y), _ = tf.keras.datasets.mnist.load_data(local_mnist_path)
    else:
        (x, y), _ = tf.keras.datasets.mnist.load_data()
    # Keep a small slice so the test stays fast.
    x = x[:200, ]
    x = x.reshape(x.shape[0], -1).astype(np.float32) / 255.0
    y = y.astype(np.int64)
    # Vertical split: leader gets the left half of the pixels,
    # follower the right half.
    half = x.shape[1] // 2
    xl = x[:, :half]
    xf = x[:, half:]
    self._create_local_data(xl, xf, y)
    x = [xl, xf]

    for role in range(2):
        ds = data_source[role]
        common.commit_data_source(self.kv_store[role], ds)
        if gfile.Exists(ds.output_base_dir):
            gfile.DeleteRecursively(ds.output_base_dir)
        manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            self.kv_store[role], ds)
        for i in range(ds.data_source_meta.partition_num):
            self._create_data_block(ds, i, x[role], y)
            # Mark the partition as already joined so training can start.
            manifest_manager._finish_partition(
                'join_example_rep',
                dj_pb.JoinExampleState.UnJoined,
                dj_pb.JoinExampleState.Joined,
                -1, i)
def __init__(self, base_path, name, role, partition_num=1,
             start_time=0, end_time=100000):
    """Create and commit a data source for the given role ('leader' or
    'follower'), wipe its output directory, and build one joined
    DataBlockManager per partition."""
    role_codes = {'leader': 0, 'follower': 1}
    if role not in role_codes:
        raise ValueError("Unknown role %s" % role)
    role = role_codes[role]

    ds = common_pb.DataSource()
    ds.data_source_meta.name = name
    ds.data_source_meta.partition_num = partition_num
    ds.data_source_meta.start_time = start_time
    ds.data_source_meta.end_time = end_time
    ds.output_base_dir = "{}/{}_{}/data_source/".format(
        base_path, ds.data_source_meta.name, role)
    ds.role = role
    if gfile.Exists(ds.output_base_dir):
        gfile.DeleteRecursively(ds.output_base_dir)

    self._data_source = ds
    self._kv_store = db_client.DBClient("etcd", True)
    common.commit_data_source(self._kv_store, self._data_source)
    self._dbms = []
    for i in range(partition_num):
        manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            self._kv_store, self._data_source)
        # Mark each partition joined so data blocks can be consumed.
        manifest_manager._finish_partition(
            'join_example_rep',
            dj_pb.JoinExampleState.UnJoined,
            dj_pb.JoinExampleState.Joined,
            -1, i)
        self._dbms.append(
            data_block_manager.DataBlockManager(self._data_source, i))
def setUp(self):
    """Create paired follower/leader data sources named 'milestone' with
    separate output dirs, wiping stale local state and the shared etcd
    prefix; the manifest manager is built for the leader side."""
    ds_follower = common_pb.DataSource()
    ds_follower.data_source_meta.name = "milestone"
    ds_follower.data_source_meta.partition_num = 1
    ds_follower.output_base_dir = "./output-f"
    self.data_source_f = ds_follower
    if gfile.Exists(ds_follower.output_base_dir):
        gfile.DeleteRecursively(ds_follower.output_base_dir)

    ds_leader = common_pb.DataSource()
    ds_leader.data_source_meta.name = "milestone"
    ds_leader.data_source_meta.partition_num = 1
    ds_leader.output_base_dir = "./output-l"
    self.raw_data_dir_l = "./raw_data-l"
    self.data_source_l = ds_leader
    for stale_dir in (ds_leader.output_base_dir, self.raw_data_dir_l):
        if gfile.Exists(stale_dir):
            gfile.DeleteRecursively(stale_dir)

    self.kvstore = db_client.DBClient('etcd', True)
    # Both sources share the name, so one prefix delete clears both.
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(
            ds_leader.data_source_meta.name))
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source_l)
def test_raw_data_manifest_manager_with_db(self):
    """Run the shared manifest-manager scenario against an etcd-backed
    DBClient."""
    client = db_client.DBClient('etcd', True)
    self._raw_data_manifest_manager(client)