Esempio n. 1
0
    def test_mysql_op(self):
        """Exercise delete/set/get/cas and prefix listing on a 'mysql' DBClient.

        The write performed on a second thread is verified on the main
        thread: an AssertionError raised inside a ``threading.Thread`` target
        is not propagated to the caller, so asserting inside the thread could
        never fail this test.
        """
        cli = db_client.DBClient('mysql', True)
        cli.delete('fl_key')
        cli.set_data('fl_key', 'fl_value')
        self.assertEqual(cli.get_data('fl_key'), b'fl_value')
        # cas must refuse the swap when the expected old value does not match.
        self.assertFalse(cli.cas('fl_key', 'fl_value1', 'fl_value2'))
        self.assertTrue(cli.cas('fl_key', 'fl_value', 'fl_value1'))
        self.assertEqual(cli.get_data('fl_key'), b'fl_value1')

        thread_values = []
        thread_errors = []

        def thread_routine():
            # Record the outcome instead of asserting here; the main thread
            # checks it after join().
            try:
                cli.set_data('fl_key', 'fl_value2')
                thread_values.append(cli.get_data('fl_key'))
            except Exception as err:  # pylint: disable=broad-except
                thread_errors.append(err)

        other = threading.Thread(target=thread_routine)
        other.start()
        other.join()
        self.assertEqual(thread_errors, [])
        self.assertEqual(thread_values, [b'fl_value2'])

        cli.set_data('fl_key/a', '1')
        cli.set_data('fl_key/b', '2')
        cli.set_data('fl_key/c', '3')
        expected_kvs = [(b'fl_key', b'fl_value2'), (b'fl_key/a', b'1'),
                        (b'fl_key/b', b'2'), (b'fl_key/c', b'3')]

        # Materialize the results so that a short (or empty) listing fails
        # the test instead of silently skipping the per-item comparisons.
        all_kvs = list(cli.get_prefix_kvs('fl_key'))
        self.assertEqual(len(all_kvs), len(expected_kvs))
        for kv, expected in zip(all_kvs, expected_kvs):
            self.assertEqual(kv[0], expected[0])
            self.assertEqual(kv[1], expected[1])

        # With the second argument set, the listing is expected to start at
        # the first child key, i.e. to omit the exact 'fl_key' entry.
        child_kvs = list(cli.get_prefix_kvs('fl_key', True))
        self.assertEqual(len(child_kvs), len(expected_kvs) - 1)
        for kv, expected in zip(child_kvs, expected_kvs[1:]):
            self.assertEqual(kv[0], expected[0])
            self.assertEqual(kv[1], expected[1])
Esempio n. 2
0
 def setUp(self):
     """Build the "milestone-f" fixture configured for 'ATTRIBUTION_JOINER'.

     Creates a single-partition DataSource, the raw-data / example-id-dump /
     example-joiner option messages, removes any existing output and raw
     data directories, clears the data source's kvstore prefix and creates
     the RawDataManifestManager. Counters start at zero.
     """
     ds = common_pb.DataSource()
     meta = ds.data_source_meta
     meta.name = "milestone-f"
     meta.partition_num = 1
     ds.output_base_dir = "./ds_output"
     self.raw_data_dir = "./raw_data"
     self.data_source = ds

     self.raw_data_options = dj_pb.RawDataOptions(
         raw_data_iter='TF_RECORD',
         compressed_type='',
     )
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1,
         example_id_dump_threshold=1024,
     )
     joiner_kwargs = dict(
         example_joiner='ATTRIBUTION_JOINER',
         min_matching_window=32,
         max_matching_window=51200,
         max_conversion_delay=interval_to_timestamp("124"),
         enable_negative_example_generator=True,
         data_block_dump_interval=32,
         data_block_dump_threshold=128,
         negative_sampling_rate=0.8,
     )
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         **joiner_kwargs)

     # Start from clean directories: output first, then raw data.
     for stale_dir in (self.data_source.output_base_dir, self.raw_data_dir):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)

     self.kvstore = db_client.DBClient('etcd', True)
     kv_base_dir = common.data_source_kvstore_base_dir(
         self.data_source.data_source_meta.name)
     self.kvstore.delete_prefix(kv_base_dir)

     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = \
         raw_data_manifest_manager.RawDataManifestManager(
             self.kvstore, self.data_source)
     self.g_data_block_index = 0
Esempio n. 3
0
 def setUp(self):
     """Build the "milestone-f" fixture configured for 'STREAM_JOINER'.

     Creates a single-partition DataSource, option messages (raw data with
     the optional field 'label'), removes any existing output and raw data
     directories, clears the data source's kvstore prefix and creates the
     RawDataManifestManager. Counters start at zero.
     """
     ds = common_pb.DataSource()
     meta = ds.data_source_meta
     meta.name = "milestone-f"
     meta.partition_num = 1
     ds.output_base_dir = "./ds_output"
     self.raw_data_dir = "./raw_data"
     self.data_source = ds

     self.raw_data_options = dj_pb.RawDataOptions(
         raw_data_iter='TF_RECORD',
         compressed_type='',
         optional_fields=['label'],
     )
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1,
         example_id_dump_threshold=1024,
     )
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         example_joiner='STREAM_JOINER',
         min_matching_window=32,
         max_matching_window=128,
         data_block_dump_interval=30,
         data_block_dump_threshold=128,
     )

     # Start from clean directories: output first, then raw data.
     for stale_dir in (self.data_source.output_base_dir, self.raw_data_dir):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)

     self.kvstore = db_client.DBClient('etcd', True)
     kv_base_dir = common.data_source_kvstore_base_dir(
         self.data_source.data_source_meta.name)
     self.kvstore.delete_prefix(kv_base_dir)

     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = \
         raw_data_manifest_manager.RawDataManifestManager(
             self.kvstore, self.data_source)
     self.g_data_block_index = 0
 def test_raw_data_manifest_manager_with_nfs(self):
     """Run the shared manifest-manager scenario against an 'nfs' DBClient.

     STORAGE_ROOT_PATH is pointed at a local directory before the client is
     created. That directory is removed afterwards even when the scenario
     fails, so a failing run does not leave files behind for later tests.
     """
     root_dir = "test_fedlearner"
     os.environ["STORAGE_ROOT_PATH"] = root_dir
     try:
         cli = db_client.DBClient('nfs', True)
         self._raw_data_manifest_manager(cli)
     finally:
         # Clean up regardless of the test outcome.
         if gfile.Exists(root_dir):
             gfile.DeleteRecursively(root_dir)
 def setUp(self):
     """Build the 'fclh_test' fixture with an empty partition-0 directory.

     Creates a single-partition DataSource, clears its kvstore prefix,
     recreates ``<raw_data_dir>/<partition_repr(0)>`` from scratch and
     creates the RawDataManifestManager.
     """
     self.data_source = common_pb.DataSource()
     meta = self.data_source.data_source_meta
     meta.name = 'fclh_test'
     meta.partition_num = 1
     self.raw_data_dir = "./raw_data"

     self.kvstore = db_client.DBClient('etcd', True)
     kv_base_dir = common.data_source_kvstore_base_dir(
         self.data_source.data_source_meta.name)
     self.kvstore.delete_prefix(kv_base_dir)

     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)

     # Recreate the only partition directory so it starts out empty.
     first_partition_dir = os.path.join(self.raw_data_dir,
                                        common.partition_repr(0))
     if gfile.Exists(first_partition_dir):
         gfile.DeleteRecursively(first_partition_dir)
     gfile.MakeDirs(first_partition_dir)

     self.manifest_manager = \
         raw_data_manifest_manager.RawDataManifestManager(
             self.kvstore, self.data_source)
Esempio n. 6
0
 def init(self,
          dsname,
          joiner_name,
          version=Version.V1,
          cache_type="memory"):
     """Build a joiner fixture whose names and paths derive from ``dsname``.

     Args:
         dsname: data source name; also used as the prefix of the output
             directory ("<dsname>_ds_output") and of the raw data directory
             ("<dsname>_raw_data").
         joiner_name: value passed as ``example_joiner`` in the joiner
             options.
         version: stored unchanged on ``self.version``.
         cache_type: value passed as ``raw_data_cache_type`` in the raw
             data options.
     """
     ds = common_pb.DataSource()
     meta = ds.data_source_meta
     meta.name = dsname
     meta.partition_num = 1
     ds.output_base_dir = "%s_ds_output" % dsname
     self.raw_data_dir = "%s_raw_data" % dsname
     self.data_source = ds

     self.raw_data_options = dj_pb.RawDataOptions(
         raw_data_iter='TF_RECORD',
         compressed_type='',
         raw_data_cache_type=cache_type,
     )
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1,
         example_id_dump_threshold=1024,
     )
     joiner_kwargs = dict(
         example_joiner=joiner_name,
         min_matching_window=32,
         max_matching_window=51200,
         max_conversion_delay=interval_to_timestamp("124"),
         enable_negative_example_generator=True,
         data_block_dump_interval=32,
         data_block_dump_threshold=128,
         negative_sampling_rate=0.8,
         join_expr="example_id",
         join_key_mapper="DEFAULT",
         negative_sampling_filter_expr='',
     )
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         **joiner_kwargs)

     # Start from clean directories: output first, then raw data.
     for stale_dir in (self.data_source.output_base_dir, self.raw_data_dir):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)

     self.kvstore = db_client.DBClient('etcd', True)
     kv_base_dir = common.data_source_kvstore_base_dir(
         self.data_source.data_source_meta.name)
     self.kvstore.delete_prefix(kv_base_dir)

     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = \
         raw_data_manifest_manager.RawDataManifestManager(
             self.kvstore, self.data_source)
     self.g_data_block_index = 0
     self.version = version
 def setUp(self):
     """Build the "milestone-x" fixture with a fresh example-dump directory.

     Creates the kvstore client and a single-partition DataSource, clears
     the data source's kvstore prefix, removes any existing output
     directory and creates the partition-0 directory under the data
     source's example dumped dir.
     """
     self.kvstore = db_client.DBClient('etcd', True)

     ds = common_pb.DataSource()
     meta = ds.data_source_meta
     meta.name = "milestone-x"
     meta.partition_num = 1
     ds.output_base_dir = "./ds_output"

     kv_base_dir = common.data_source_kvstore_base_dir(
         ds.data_source_meta.name)
     self.kvstore.delete_prefix(kv_base_dir)
     self.data_source = ds

     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=-1,
         example_id_dump_threshold=1024,
     )

     output_dir = self.data_source.output_base_dir
     if gfile.Exists(output_dir):
         gfile.DeleteRecursively(output_dir)

     dumped_dir = common.data_source_example_dumped_dir(self.data_source)
     self.partition_dir = os.path.join(dumped_dir, common.partition_repr(0))
     gfile.MakeDirs(self.partition_dir)
 def setUp(self):
     """Build a four-partition follower "milestone-x" fixture with blocks.

     Commits the DataSource to the kvstore, removes any existing output
     directory, creates the RawDataManifestManager and then calls
     ``_create_data_block`` once per partition.
     """
     ds = common_pb.DataSource()
     meta = ds.data_source_meta
     meta.name = "milestone-x"
     meta.partition_num = 4
     meta.start_time = 0
     meta.end_time = 10000
     ds.output_base_dir = "./ds_output"
     ds.role = common_pb.FLRole.Follower
     self.data_source = ds

     self.kvstore = db_client.DBClient('etcd', True)
     common.commit_data_source(self.kvstore, self.data_source)

     output_dir = ds.output_base_dir
     if gfile.Exists(output_dir):
         gfile.DeleteRecursively(output_dir)

     # Attribute name (including its spelling) is kept as-is; it may be
     # referenced by other methods of the test case.
     self.data_block_matas = []
     self.manifest_manager = \
         raw_data_manifest_manager.RawDataManifestManager(
             self.kvstore, self.data_source)

     partition_count = self.data_source.data_source_meta.partition_num
     for partition_id in range(partition_count):
         self._create_data_block(partition_id)
Esempio n. 9
0
    def setUp(self):
        """Build leader/follower data sources backed by a slice of MNIST.

        Creates one kvstore client per role (after setting ETCD_NAME to the
        role's data source name), loads MNIST, keeps the first 200 images
        flattened and scaled to [0, 1], splits the feature columns in half
        between leader and follower, writes local data, and for each role
        commits the data source, clears its output directory, creates data
        blocks and calls ``_finish_partition`` for every partition.
        """
        self.sche = _TaskScheduler(30)
        self.kv_store = [None, None]
        self.app_id = "test_trainer_v1"
        sources = [
            self._gen_ds_meta(common_pb.FLRole.Leader),
            self._gen_ds_meta(common_pb.FLRole.Follower)
        ]
        for role, source in enumerate(sources):
            # ETCD_NAME is set before each client is constructed.
            os.environ['ETCD_NAME'] = source.data_source_meta.name
            self.kv_store[role] = db_client.DBClient("etcd", True)
        self.data_source = sources

        # In debug mode the dataset is read from the local path.
        load_args = (local_mnist_path,) if debug_mode else ()
        (images, labels), _ = tf.keras.datasets.mnist.load_data(*load_args)
        images = images[:200]

        features = images.reshape(images.shape[0], -1)
        features = features.astype(np.float32) / 255.0
        labels = labels.astype(np.int64)

        # Leader gets the first half of the columns, follower the rest.
        half = features.shape[1] // 2
        xl = features[:, :half]
        xf = features[:, half:]

        self._create_local_data(xl, xf, labels)

        features_by_role = [xl, xf]
        for role, source in enumerate(sources):
            common.commit_data_source(self.kv_store[role], source)
            if gfile.Exists(source.output_base_dir):
                gfile.DeleteRecursively(source.output_base_dir)
            manifest_manager = \
                raw_data_manifest_manager.RawDataManifestManager(
                    self.kv_store[role], source)
            partition_count = source.data_source_meta.partition_num
            for partition_id in range(partition_count):
                self._create_data_block(source, partition_id,
                                        features_by_role[role], labels)
                # NOTE(review): presumably moves the partition's
                # join-example state from UnJoined to Joined - confirm
                # against RawDataManifestManager.
                manifest_manager._finish_partition(
                    'join_example_rep', dj_pb.JoinExampleState.UnJoined,
                    dj_pb.JoinExampleState.Joined, -1, partition_id)
Esempio n. 10
0
    def __init__(self,
                 base_path,
                 name,
                 role,
                 partition_num=1,
                 start_time=0,
                 end_time=100000):
        """Create and commit a data source plus one block manager per partition.

        Args:
            base_path: root of the output path; the data source writes to
                "<base_path>/<name>_<role id>/data_source/".
            name: data source name.
            role: 'leader' (stored as 0) or 'follower' (stored as 1).
            partition_num: number of partitions.
            start_time: stored as the data source meta start_time.
            end_time: stored as the data source meta end_time.

        Raises:
            ValueError: if ``role`` is neither 'leader' nor 'follower'.
        """
        if role == 'leader':
            role_id = 0
        elif role == 'follower':
            role_id = 1
        else:
            raise ValueError("Unknown role %s" % role)

        ds = common_pb.DataSource()
        meta = ds.data_source_meta
        meta.name = name
        meta.partition_num = partition_num
        meta.start_time = start_time
        meta.end_time = end_time
        ds.output_base_dir = "{}/{}_{}/data_source/".format(
            base_path, meta.name, role_id)
        ds.role = role_id

        # Remove output left over from an earlier run.
        if gfile.Exists(ds.output_base_dir):
            gfile.DeleteRecursively(ds.output_base_dir)

        self._data_source = ds
        self._kv_store = db_client.DBClient("etcd", True)
        common.commit_data_source(self._kv_store, self._data_source)

        self._dbms = []
        for partition_id in range(partition_num):
            # A new manifest manager is created for every partition.
            manager = raw_data_manifest_manager.RawDataManifestManager(
                self._kv_store, self._data_source)
            # NOTE(review): presumably moves the partition's join-example
            # state from UnJoined to Joined - confirm against
            # RawDataManifestManager.
            manager._finish_partition('join_example_rep',
                                      dj_pb.JoinExampleState.UnJoined,
                                      dj_pb.JoinExampleState.Joined,
                                      -1, partition_id)
            block_manager = data_block_manager.DataBlockManager(
                self._data_source, partition_id)
            self._dbms.append(block_manager)
Esempio n. 11
0
 def setUp(self):
     """Build follower ("-f") and leader ("-l") "milestone" data sources.

     Both data sources have one partition and share the same name. Existing
     output directories and the leader raw data directory are removed, the
     kvstore prefix of the data source name is cleared, and a
     RawDataManifestManager is created for the leader data source.
     """
     follower_ds = common_pb.DataSource()
     follower_meta = follower_ds.data_source_meta
     follower_meta.name = "milestone"
     follower_meta.partition_num = 1
     follower_ds.output_base_dir = "./output-f"
     self.data_source_f = follower_ds
     follower_output = self.data_source_f.output_base_dir
     if gfile.Exists(follower_output):
         gfile.DeleteRecursively(follower_output)

     leader_ds = common_pb.DataSource()
     leader_meta = leader_ds.data_source_meta
     leader_meta.name = "milestone"
     leader_meta.partition_num = 1
     leader_ds.output_base_dir = "./output-l"
     self.raw_data_dir_l = "./raw_data-l"
     self.data_source_l = leader_ds
     # Leader cleanup: output directory first, then raw data directory.
     for stale_dir in (self.data_source_l.output_base_dir,
                       self.raw_data_dir_l):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)

     self.kvstore = db_client.DBClient('etcd', True)
     kv_base_dir = common.data_source_kvstore_base_dir(
         self.data_source_l.data_source_meta.name)
     self.kvstore.delete_prefix(kv_base_dir)
     self.manifest_manager = \
         raw_data_manifest_manager.RawDataManifestManager(
             self.kvstore, self.data_source_l)
 def test_raw_data_manifest_manager_with_db(self):
     # Run the shared manifest-manager scenario against an 'etcd' DBClient.
     self._raw_data_manifest_manager(db_client.DBClient('etcd', True))