Пример #1
0
 def setUp(self):
     """Create a one-partition data source with an empty raw-data directory.

     Deletes the etcd prefix named after the data source, recreates the
     partition-0 directory under ``./raw_data`` and builds the manifest
     manager stored on ``self.manifest_manager``.
     """
     ds = common_pb.DataSource()
     ds.data_source_meta.name = 'fclh_test'
     ds.data_source_meta.partition_num = 1
     ds.raw_data_dir = "./raw_data"
     self.data_source = ds

     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(ds.data_source_meta.name)
     self.assertEqual(ds.data_source_meta.partition_num, 1)

     # Always start from a freshly created, empty partition directory.
     partition_dir = os.path.join(ds.raw_data_dir, common.partition_repr(0))
     if gfile.Exists(partition_dir):
         gfile.DeleteRecursively(partition_dir)
     gfile.MakeDirs(partition_dir)

     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, ds)
Пример #2
0
    def setUp(self):
        """Build two data sources (Leader, Follower) filled with MNIST data.

        The first 200 MNIST images are flattened, scaled to [0, 1] and split
        column-wise in half: the first half goes to the data source at
        index 0 (Leader), the second half to index 1 (Follower). Every
        partition of both data sources gets a data block and is then marked
        as joined through the manifest manager.

        NOTE(review): only ``x`` is truncated to 200 rows; ``y`` keeps every
        label returned by ``load_data`` -- confirm the helpers only read the
        first ``x.shape[0]`` labels.
        """
        self.sche = _TaskScheduler(30)
        self.kv_store = [None, None]
        self.app_id = "test_trainer_v1"
        # The first config entry (database name) is not used here.
        _, db_addr, db_username, db_password, db_base_dir = \
            get_kvstore_config("etcd")
        data_source = [
            self._gen_ds_meta(fl_role)
            for fl_role in (common_pb.FLRole.Leader, common_pb.FLRole.Follower)
        ]
        for role, ds in enumerate(data_source):
            self.kv_store[role] = mysql_client.DBClient(
                ds.data_source_meta.name, db_addr, db_username,
                db_password, db_base_dir, True)
        self.data_source = data_source

        # debug_mode switches to a local copy of the MNIST archive.
        if debug_mode:
            (x, y), _ = tf.keras.datasets.mnist.load_data(local_mnist_path)
        else:
            (x, y), _ = tf.keras.datasets.mnist.load_data()
        x = x[:200]

        x = x.reshape(x.shape[0], -1).astype(np.float32) / 255.0
        y = y.astype(np.int64)

        half = x.shape[1] // 2
        xl = x[:, :half]
        xf = x[:, half:]

        self._create_local_data(xl, xf, y)

        features = [xl, xf]
        for role, ds in enumerate(data_source):
            common.commit_data_source(self.kv_store[role], ds)
            if gfile.Exists(ds.output_base_dir):
                gfile.DeleteRecursively(ds.output_base_dir)
            manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
                self.kv_store[role], ds)
            for partition_id in range(ds.data_source_meta.partition_num):
                self._create_data_block(ds, partition_id, features[role], y)
                manifest_manager._finish_partition(
                    'join_example_rep', dj_pb.JoinExampleState.UnJoined,
                    dj_pb.JoinExampleState.Joined, -1, partition_id)
Пример #3
0
 def init(self,
          dsname,
          joiner_name,
          version=Version.V1,
          cache_type="memory"):
     """Reset on-disk and kvstore state and build the joiner test fixtures.

     Args:
         dsname: data source name; also used as the prefix of the output
             and raw-data directory names.
         joiner_name: value passed as ``example_joiner`` in the joiner
             options.
         version: stored unchanged on ``self.version``.
         cache_type: value passed as ``raw_data_cache_type`` in the
             raw-data options.
     """
     ds = common_pb.DataSource()
     ds.data_source_meta.name = dsname
     ds.data_source_meta.partition_num = 1
     ds.output_base_dir = "%s_ds_output" % dsname
     self.data_source = ds
     self.raw_data_dir = "%s_raw_data" % dsname

     # Plain bookkeeping attributes.
     self.version = version
     self.g_data_block_index = 0
     self.total_raw_data_count = 0
     self.total_example_id_count = 0

     self.raw_data_options = dj_pb.RawDataOptions(
         raw_data_iter='TF_RECORD',
         compressed_type='',
         raw_data_cache_type=cache_type,
     )
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1, example_id_dump_threshold=1024)
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         example_joiner=joiner_name,
         min_matching_window=32,
         max_matching_window=51200,
         max_conversion_delay=interval_to_timestamp("124"),
         enable_negative_example_generator=True,
         data_block_dump_interval=32,
         data_block_dump_threshold=128,
         negative_sampling_rate=0.8,
         join_expr="example_id",
         join_key_mapper="DEFAULT",
         negative_sampling_filter_expr='',
     )

     # Remove leftovers of earlier runs: output dir first, then raw data.
     for stale_dir in (ds.output_base_dir, self.raw_data_dir):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)

     self.kvstore = db_client.DBClient('etcd', True)
     kvstore_prefix = common.data_source_kvstore_base_dir(
         ds.data_source_meta.name)
     self.kvstore.delete_prefix(kvstore_prefix)
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, ds)
Пример #4
0
    def __init__(self,
                 base_path,
                 name,
                 role,
                 partition_num=1,
                 start_time=0,
                 end_time=100000):
        """Create and commit a data source and mark its partitions joined.

        Args:
            base_path: directory under which the output dir is placed.
            name: data source name.
            role: ``'leader'`` (stored as 0) or ``'follower'`` (stored
                as 1).
            partition_num: number of partitions to prepare.
            start_time: stored as ``data_source_meta.start_time``.
            end_time: stored as ``data_source_meta.end_time``.

        Raises:
            ValueError: if ``role`` is neither ``'leader'`` nor
                ``'follower'``.
        """
        if role not in ('leader', 'follower'):
            raise ValueError("Unknown role %s" % role)
        role = 0 if role == 'leader' else 1

        ds = common_pb.DataSource()
        ds.data_source_meta.name = name
        ds.data_source_meta.partition_num = partition_num
        ds.data_source_meta.start_time = start_time
        ds.data_source_meta.end_time = end_time
        ds.output_base_dir = "{}/{}_{}/data_source/".format(
            base_path, ds.data_source_meta.name, role)
        ds.role = role
        # Drop any output left behind by a previous run.
        if gfile.Exists(ds.output_base_dir):
            gfile.DeleteRecursively(ds.output_base_dir)

        self._data_source = ds

        database, addr, username, password, base_dir = \
            get_kvstore_config("etcd")
        self._kv_store = mysql_client.DBClient(database, addr, username,
                                               password, base_dir, True)

        common.commit_data_source(self._kv_store, ds)
        self._dbms = []
        for partition_id in range(partition_num):
            # A new manifest manager is built for every partition.
            manager = raw_data_manifest_manager.RawDataManifestManager(
                self._kv_store, ds)
            manager._finish_partition('join_example_rep',
                                      dj_pb.JoinExampleState.UnJoined,
                                      dj_pb.JoinExampleState.Joined,
                                      -1, partition_id)
            self._dbms.append(
                data_block_manager.DataBlockManager(ds, partition_id))
 def setUp(self):
     """Commit a four-partition Follower data source and fill it.

     The data source is committed to the kvstore, its output directory is
     wiped, and ``_create_data_block`` is called once per partition.
     """
     ds = common_pb.DataSource()
     ds.data_source_meta.name = "milestone-x"
     ds.data_source_meta.partition_num = 4
     ds.data_source_meta.start_time = 0
     ds.data_source_meta.end_time = 10000
     ds.output_base_dir = "./ds_output"
     ds.role = common_pb.FLRole.Follower
     self.data_source = ds

     self.kvstore = db_client.DBClient('etcd', True)
     common.commit_data_source(self.kvstore, ds)
     if gfile.Exists(ds.output_base_dir):
         gfile.DeleteRecursively(ds.output_base_dir)

     self.data_block_matas = []
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, ds)
     for partition_id in range(ds.data_source_meta.partition_num):
         self._create_data_block(partition_id)
 def generate_leader_raw_data(self):
     """Write ``leader_end_index + 3`` examples as partition-0 data blocks.

     A new data block is started every 2048 examples. Records whose index
     is a multiple of 3 get the id ``i // 3``; all others get
     ``(i + 1) << 30``. Afterwards every non-directory entry of the
     partition directory that lacks the data-block suffix is removed and
     the manifest manager is rebuilt.
     """
     # NOTE(review): ``dbm`` is never used below; the constructor call is
     # kept in case it has side effects -- confirm.
     dbm = data_block_manager.DataBlockManager(self.data_source_l, 0)
     raw_data_dir = os.path.join(self.data_source_l.raw_data_dir,
                                 'partition_{}'.format(0))
     if gfile.Exists(raw_data_dir):
         gfile.DeleteRecursively(raw_data_dir)
     gfile.MakeDirs(raw_data_dir)

     def new_builder(index):
         # Builder for data block number ``index`` of partition 0.
         return data_block_manager.DataBlockBuilder(
             self.data_source_l.raw_data_dir, 0, index, None)

     block_index = 0
     builder = new_builder(block_index)
     for i in range(self.leader_end_index + 3):
         # Roll over to the next data block every 2048 examples.
         if i > 0 and i % 2048 == 0:
             builder.finish_data_block()
             block_index += 1
             builder = new_builder(block_index)
         # Same value as the original ``i + 1 << 30``: ``+`` binds tighter
         # than ``<<``.
         pt = i // 3 if i % 3 == 0 else (i + 1) << 30
         example_id = '{}'.format(pt).encode()
         event_time = 150000000 + pt
         feat = {
             'example_id': tf.train.Feature(
                 bytes_list=tf.train.BytesList(value=[example_id])),
             'event_time': tf.train.Feature(
                 int64_list=tf.train.Int64List(value=[event_time])),
         }
         example = tf.train.Example(
             features=tf.train.Features(feature=feat))
         builder.append(example.SerializeToString(), example_id, event_time,
                        i, i)
     builder.finish_data_block()

     # Keep only the data-block files in the partition directory.
     candidates = (os.path.join(raw_data_dir, entry)
                   for entry in gfile.ListDirectory(raw_data_dir))
     stale_files = [
         fpath for fpath in candidates
         if not gfile.IsDirectory(fpath)
         and not fpath.endswith(common.DataBlockSuffix)
     ]
     for fpath in stale_files:
         gfile.Remove(fpath)

     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source_l)
Пример #7
0
 def setUp(self):
     """Create the ``_f`` and ``_l`` data sources and clean their state.

     Both data sources are named "milestone" with one partition
     (presumably follower and leader -- confirm). Their data-block
     directories and the ``_l`` raw-data directory are removed, the etcd
     prefix named after the data source is deleted, and a manifest manager
     is built for the ``_l`` data source.
     """
     ds_f = common_pb.DataSource()
     ds_f.data_source_meta.name = "milestone"
     ds_f.data_source_meta.partition_num = 1
     ds_f.data_block_dir = "./data_block-f"
     self.data_source_f = ds_f
     if gfile.Exists(ds_f.data_block_dir):
         gfile.DeleteRecursively(ds_f.data_block_dir)

     ds_l = common_pb.DataSource()
     ds_l.data_source_meta.name = "milestone"
     ds_l.data_source_meta.partition_num = 1
     ds_l.data_block_dir = "./data_block-l"
     ds_l.raw_data_dir = "./raw_data-l"
     self.data_source_l = ds_l
     # Data-block dir first, then raw-data dir.
     for stale_dir in (ds_l.data_block_dir, ds_l.raw_data_dir):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)

     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(ds_l.data_source_meta.name)
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, ds_l)
Пример #8
0
 def setUp(self):
     """Commit a four-partition Follower data source to etcd and fill it.

     The etcd connection parameters are kept on ``self`` so that tests can
     reuse them; ``_create_data_block`` is called once per partition.
     """
     ds = common_pb.DataSource()
     ds.data_source_meta.name = "milestone-x"
     ds.data_source_meta.partition_num = 4
     ds.data_source_meta.start_time = 0
     ds.data_source_meta.end_time = 10000
     ds.output_base_dir = "./ds_output"
     ds.role = common_pb.FLRole.Follower
     self.data_source = ds

     self.etcd_name = 'test_cluster'
     self.etcd_addrs = 'localhost:2379'
     self.etcd_base_dir = 'fedlearner'
     self.etcd = etcd_client.EtcdClient(self.etcd_name, self.etcd_addrs,
                                        self.etcd_base_dir, True)
     common.commit_data_source(self.etcd, ds)
     if gfile.Exists(ds.output_base_dir):
         gfile.DeleteRecursively(ds.output_base_dir)

     self.data_block_matas = []
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, ds)
     for partition_id in range(ds.data_source_meta.partition_num):
         self._create_data_block(partition_id)
 def test_raw_data_visitor(self):
     """Iterate partition 0 of ``./test/compressed_raw_data``.

     Uses the 'TF_DATASET' iterator with 'GZIP' compression and checks
     that the visitor yields consecutive indices starting at 0 and at
     least one item.
     """
     ds = common_pb.DataSource()
     ds.data_source_meta.name = 'fclh_test'
     ds.data_source_meta.partition_num = 1
     ds.raw_data_dir = "./test/compressed_raw_data"
     self.data_source = ds

     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(ds.data_source_meta.name)
     self.assertEqual(ds.data_source_meta.partition_num, 1)

     # The fixture directory must already exist; it is not created here.
     partition_dir = os.path.join(ds.raw_data_dir, 'partition_0')
     self.assertTrue(gfile.Exists(partition_dir))

     # Result is unused; NOTE(review): presumably built for its side
     # effects on the kvstore -- confirm.
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, ds)
     customized_options.set_raw_data_iter('TF_DATASET')
     customized_options.set_compressed_type('GZIP')

     visitor = raw_data_visitor.RawDataVisitor(self.etcd, ds, 0)
     visited = 0
     for index, item in visitor:
         # Print a progress line every 1024 items.
         if index > 0 and index % 1024 == 0:
             print("{} {} {}".format(index, item.example_id, item.event_time))
         self.assertEqual(index, visited)
         visited += 1
     self.assertGreater(visited, 0)
Пример #10
0
 def test_compressed_raw_data_visitor(self):
     """Register ``0-0.idx`` as raw data and iterate it with GZIP options.

     Checks that the index meta for process index 0 exists and that the
     visitor yields consecutive indices starting at 0 and at least one
     item.
     """
     ds = common_pb.DataSource()
     ds.data_source_meta.name = 'fclh_test'
     ds.data_source_meta.partition_num = 1
     self.data_source = ds
     # Fixture directory located relative to this test file.
     this_dir = path.dirname(path.abspath(__file__))
     self.raw_data_dir = path.join(this_dir, "../compressed_raw_data")

     self.kvstore = mysql_client.DBClient('test_cluster', 'localhost:2379',
                                           'test_user', 'test_password',
                                           'fedlearner', True)
     kvstore_prefix = common.data_source_kvstore_base_dir(
         ds.data_source_meta.name)
     self.kvstore.delete_prefix(kvstore_prefix)
     self.assertEqual(ds.data_source_meta.partition_num, 1)

     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))

     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, ds)
     raw_data_meta = dj_pb.RawDataMeta(
         file_path=path.join(partition_dir, "0-0.idx"),
         timestamp=timestamp_pb2.Timestamp(seconds=3))
     manifest_manager.add_raw_data(0, [raw_data_meta], True)

     raw_data_options = dj_pb.RawDataOptions(
         raw_data_iter='TF_RECORD',
         compressed_type='GZIP',
         read_ahead_size=1 << 20,
         read_batch_size=128,
     )
     rdm = raw_data_visitor.RawDataManager(self.kvstore, ds, 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))

     visitor = raw_data_visitor.RawDataVisitor(self.kvstore, ds, 0,
                                               raw_data_options)
     visited = 0
     for index, item in visitor:
         # Print a progress line every 32 items.
         if index > 0 and index % 32 == 0:
             print("{} {}".format(index, item.example_id))
         self.assertEqual(index, visited)
         visited += 1
     self.assertGreater(visited, 0)
Пример #11
0
 def setUp(self):
     """Prepare options and clean state for the 'ATTRIBUTION_JOINER' tests.

     Builds the raw-data, example-id-dump and joiner options, removes the
     output and raw-data directories, deletes the data source's kvstore
     prefix and creates the manifest manager.
     """
     ds = common_pb.DataSource()
     ds.data_source_meta.name = "milestone-f"
     ds.data_source_meta.partition_num = 1
     ds.output_base_dir = "./ds_output"
     self.data_source = ds
     self.raw_data_dir = "./raw_data"

     # Plain bookkeeping attributes.
     self.g_data_block_index = 0
     self.total_raw_data_count = 0
     self.total_example_id_count = 0

     self.raw_data_options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type='')
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1, example_id_dump_threshold=1024)
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         example_joiner='ATTRIBUTION_JOINER',
         min_matching_window=32,
         max_matching_window=51200,
         max_conversion_delay=interval_to_timestamp("124"),
         enable_negative_example_generator=True,
         data_block_dump_interval=32,
         data_block_dump_threshold=128,
         negative_sampling_rate=0.8,
     )

     # Remove leftovers of earlier runs: output dir first, then raw data.
     for stale_dir in (ds.output_base_dir, self.raw_data_dir):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)

     self.kvstore = mysql_client.DBClient('test_cluster', 'localhost:2379',
                                          'test_user', 'test_password',
                                          'fedlearner', True)
     kvstore_prefix = common.data_source_kvstore_base_dir(
         ds.data_source_meta.name)
     self.kvstore.delete_prefix(kvstore_prefix)
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, ds)
Пример #12
0
 def setUp(self):
     """Create the ``_f`` and ``_l`` data sources with clean directories.

     Both data sources are named "milestone" with one partition
     (presumably follower and leader -- confirm). Output directories and
     the ``_l`` raw-data directory are removed, the kvstore prefix of the
     data source is deleted, and a manifest manager is built for the
     ``_l`` data source.
     """
     ds_f = common_pb.DataSource()
     ds_f.data_source_meta.name = "milestone"
     ds_f.data_source_meta.partition_num = 1
     ds_f.output_base_dir = "./output-f"
     self.data_source_f = ds_f
     if gfile.Exists(ds_f.output_base_dir):
         gfile.DeleteRecursively(ds_f.output_base_dir)

     ds_l = common_pb.DataSource()
     ds_l.data_source_meta.name = "milestone"
     ds_l.data_source_meta.partition_num = 1
     ds_l.output_base_dir = "./output-l"
     self.raw_data_dir_l = "./raw_data-l"
     self.data_source_l = ds_l
     # Output dir first, then raw-data dir.
     for stale_dir in (ds_l.output_base_dir, self.raw_data_dir_l):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)

     self.kvstore = db_client.DBClient('etcd', True)
     kvstore_prefix = common.data_source_kvstore_base_dir(
         ds_l.data_source_meta.name)
     self.kvstore.delete_prefix(kvstore_prefix)
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, ds_l)
Пример #13
0
    def test_raw_data_manifest_manager(self):
        """Walk a four-partition manifest through its sync/join transitions.

        After each manifest-manager call the manifests of all partitions are
        listed and compared with the expected (state, rank_id) pairs.
        Partitions that were never allocated must stay UnSynced / UnJoined
        with rank id -1.

        The etcd client pool is released in a ``finally`` clause so that a
        failing assertion does not leak it.
        """
        cli = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                     'fedlearner', True)
        try:
            partition_num = 4
            rank_id = 2
            data_source = common_pb.DataSource()
            data_source.data_source_meta.name = "milestone-x"
            data_source.data_source_meta.partition_num = partition_num
            data_source.role = common_pb.FLRole.Leader
            cli.delete_prefix(data_source.data_source_meta.name)
            manifest_manager = \
                raw_data_manifest_manager.RawDataManifestManager(
                    cli, data_source)

            unsynced = (dj_pb.SyncExampleIdState.UnSynced, -1)
            unjoined = (dj_pb.JoinExampleState.UnJoined, -1)
            none_finished = dict.fromkeys(range(partition_num), False)

            def check_manifests(sync=None, join=None, finished=None):
                """List all manifests and compare them with expectations.

                ``sync`` and ``join`` map partition id -> (state, rank_id);
                partitions left out must be UnSynced / UnJoined with rank
                id -1. ``finished`` maps partition id -> expected
                ``finished`` flag; partitions left out are not checked.
                """
                sync = sync or {}
                join = join or {}
                finished = finished or {}
                manifest_map = manifest_manager.list_all_manifest()
                for i in range(partition_num):
                    self.assertIn(i, manifest_map)
                    manifest = manifest_map[i]
                    sync_state, sync_rank = sync.get(i, unsynced)
                    self.assertEqual(manifest.sync_example_id_rep.state,
                                     sync_state)
                    self.assertEqual(manifest.sync_example_id_rep.rank_id,
                                     sync_rank)
                    join_state, join_rank = join.get(i, unjoined)
                    self.assertEqual(manifest.join_example_rep.state,
                                     join_state)
                    self.assertEqual(manifest.join_example_rep.rank_id,
                                     join_rank)
                    if i in finished:
                        if finished[i]:
                            self.assertTrue(manifest.finished)
                        else:
                            self.assertFalse(manifest.finished)

            # Freshly created manifests: nothing allocated, nothing finished.
            check_manifests(finished=none_finished)

            # Allocate one partition for example-id syncing.
            manifest = manifest_manager.alloc_sync_exampld_id(rank_id)
            self.assertIsNotNone(manifest)
            partition_id = manifest.partition_id
            sync_expect = {
                partition_id: (dj_pb.SyncExampleIdState.Syncing, rank_id),
            }
            check_manifests(sync=sync_expect, finished=none_finished)

            # Allocate a different partition for joining (3 - p != p).
            partition_id2 = 3 - partition_id
            rank_id2 = 100
            manifest_manager.alloc_join_example(rank_id2, partition_id2)
            join_expect = {
                partition_id2: (dj_pb.JoinExampleState.Joining, rank_id2),
            }
            check_manifests(sync=sync_expect, join=join_expect,
                            finished=none_finished)

            # Each of these finish calls is expected to raise.
            self.assertRaises(Exception,
                              manifest_manager.finish_join_example,
                              rank_id, partition_id)
            self.assertRaises(Exception,
                              manifest_manager.finish_join_example,
                              rank_id2, partition_id2)
            self.assertRaises(Exception,
                              manifest_manager.finish_sync_example_id,
                              -rank_id, partition_id)
            self.assertRaises(Exception,
                              manifest_manager.finish_sync_example_id,
                              rank_id2, partition_id2)

            # The syncing partition can also be allocated for joining.
            rank_id3 = 0
            manifest_manager.alloc_join_example(rank_id3, partition_id)
            join_expect[partition_id] = (dj_pb.JoinExampleState.Joining,
                                         rank_id3)
            check_manifests(sync=sync_expect, join=join_expect,
                            finished=none_finished)

            # Raw-data bookkeeping on the syncing partition.
            self.assertRaises(Exception,
                              manifest_manager.finish_sync_example_id,
                              rank_id, partition_id)
            self.assertRaises(Exception, manifest_manager.add_raw_data,
                              partition_id, ['a', 'a', 'b'], False)
            manifest_manager.add_raw_data(partition_id, ['a', 'a', 'b'], True)
            manifest = manifest_manager.get_manifest(partition_id)
            self.assertEqual(manifest.next_process_index, 2)
            manifest_manager.add_raw_data(partition_id,
                                          ['a', 'a', 'b', 'c', 'd'], True)
            manifest_map = manifest_manager.list_all_manifest()
            for i in range(partition_num):
                self.assertIn(i, manifest_map)
                expected_next = 4 if i == partition_id else 0
                self.assertEqual(manifest_map[i].next_process_index,
                                 expected_next)

            # The finish calls are issued twice: the repeat must not raise.
            manifest_manager.finish_raw_data(partition_id)
            manifest_manager.finish_raw_data(partition_id)
            self.assertRaises(Exception, manifest_manager.add_raw_data,
                              partition_id, 200)
            manifest_manager.finish_sync_example_id(rank_id, partition_id)
            manifest_manager.finish_sync_example_id(rank_id, partition_id)
            sync_expect = {
                partition_id: (dj_pb.SyncExampleIdState.Synced, rank_id),
            }
            check_manifests(sync=sync_expect, join=join_expect,
                            finished={partition_id: True})

            manifest_manager.finish_join_example(rank_id3, partition_id)
            manifest_manager.finish_join_example(rank_id3, partition_id)
            join_expect[partition_id] = (dj_pb.JoinExampleState.Joined,
                                         rank_id3)
            check_manifests(sync=sync_expect, join=join_expect)
        finally:
            cli.destory_client_pool()
Пример #14
0
    def test_raw_data_manifest_manager(self):
        """Walk one partition through the states seen by this test.

        The allocated partition is checked in the states Syncing, Synced,
        Joining and Done, in that order. After every step all four
        manifests are listed: the allocated partition must show the
        expected state and rank id, every other partition must stay
        UnAllocated with rank id -1.

        The etcd client pool is released in a ``finally`` clause so that a
        failing assertion does not leak it.
        """
        cli = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                     'fedlearner')
        try:
            data_source = common_pb.DataSource()
            data_source.data_source_meta.name = "milestone-x"
            data_source.data_source_meta.partition_num = 4
            cli.delete_prefix(data_source.data_source_meta.name)
            manifest_manager = \
                raw_data_manifest_manager.RawDataManifestManager(
                    cli, data_source)

            def check_manifests(allocated=None, state=None, rank_id=-1):
                """List all manifests and compare them with expectations.

                ``allocated`` is the partition expected to be in ``state``
                with ``rank_id``; all other partitions (or all of them when
                ``allocated`` is None) must be UnAllocated with rank -1.
                """
                manifest_map = manifest_manager.list_all_manifest()
                for i in range(data_source.data_source_meta.partition_num):
                    self.assertIn(i, manifest_map)
                    if allocated is not None and i == allocated:
                        expected_state, expected_rank = state, rank_id
                    else:
                        expected_state = dj_pb.RawDataState.UnAllocated
                        expected_rank = -1
                    self.assertEqual(manifest_map[i].state, expected_state)
                    self.assertEqual(manifest_map[i].allocated_rank_id,
                                     expected_rank)

            # Freshly created manifests: nothing is allocated.
            check_manifests()

            manifest, finished = \
                manifest_manager.alloc_unallocated_partition(0)
            self.assertFalse(finished)
            partition_id = manifest.partition_id
            check_manifests(partition_id, dj_pb.RawDataState.Syncing, 0)

            # Finishing the sync resets the allocated rank id to -1.
            manifest_manager.finish_sync_partition(0, partition_id)
            check_manifests(partition_id, dj_pb.RawDataState.Synced)

            # The only synced partition is the one handed out for joining.
            manifest2, finished = manifest_manager.alloc_synced_partition(2)
            self.assertFalse(finished)
            self.assertEqual(manifest.partition_id, manifest2.partition_id)
            check_manifests(partition_id, dj_pb.RawDataState.Joining, 2)

            # Finishing the join resets the allocated rank id to -1.
            manifest_manager.finish_join_partition(2, partition_id)
            check_manifests(partition_id, dj_pb.RawDataState.Done)
        finally:
            cli.destory_client_pool()
    def test_raw_data_manifest_manager(self):
        """Drive a 4-partition manifest through sync/join allocation and finish.

        The scenario, in order:
          1. a fresh manager reports every partition as UnSynced/UnJoined;
          2. one partition is allocated for example-id syncing;
          3. two partitions are allocated for example joining;
          4. raw data is added to the syncing partition and then finished;
          5. syncing, then joining, is finished for that partition.
        After each step the full manifest map is re-read and checked.
        """
        cli = mysql_client.DBClient('test_cluster', 'localhost:2379',
                                    'test_user', 'test_password', 'fedlearner',
                                    True)
        # try/finally so the client pool is released even when an assertion
        # fails part-way through the scenario.
        try:
            partition_num = 4
            rank_id = 2
            data_source = common_pb.DataSource()
            data_source.data_source_meta.name = "milestone-x"
            data_source.data_source_meta.partition_num = partition_num
            data_source.role = common_pb.FLRole.Leader
            # Start from a clean slate for this data source.
            cli.delete_prefix(
                common.data_source_kvstore_base_dir(
                    data_source.data_source_meta.name))
            manifest_manager = \
                raw_data_manifest_manager.RawDataManifestManager(
                    cli, data_source)

            # Expected (state, rank_id) for partitions nobody has touched.
            unsynced = (dj_pb.SyncExampleIdState.UnSynced, -1)
            unjoined = (dj_pb.JoinExampleState.UnJoined, -1)
            none_finished = dict.fromkeys(range(partition_num), False)

            def check_manifests(sync_expect, join_expect, finished_expect):
                """Re-read all manifests and compare them to expectations.

                sync_expect / join_expect map a partition index to the
                expected (state, rank_id) of its sync_example_id_rep /
                join_example_rep; partitions not listed are expected to be
                (UnSynced, -1) / (UnJoined, -1).  finished_expect maps a
                partition index to the expected truthiness of `finished`;
                partitions not listed are not checked for it.
                """
                manifest_map = manifest_manager.list_all_manifest()
                for i in range(partition_num):
                    self.assertIn(i, manifest_map)
                    item = manifest_map[i]
                    sync_state, sync_rank = sync_expect.get(i, unsynced)
                    self.assertEqual(item.sync_example_id_rep.state,
                                     sync_state)
                    self.assertEqual(item.sync_example_id_rep.rank_id,
                                     sync_rank)
                    join_state, join_rank = join_expect.get(i, unjoined)
                    self.assertEqual(item.join_example_rep.state,
                                     join_state)
                    self.assertEqual(item.join_example_rep.rank_id,
                                     join_rank)
                    if i in finished_expect:
                        self.assertEqual(bool(item.finished),
                                         finished_expect[i])

            # 1. Nothing allocated yet.
            check_manifests({}, {}, none_finished)

            # 2. Allocate one partition for example-id syncing.
            manifest = manifest_manager.alloc_sync_exampld_id(rank_id)
            self.assertIsNotNone(manifest)
            partition_id = manifest.partition_id
            syncing = {
                partition_id: (dj_pb.SyncExampleIdState.Syncing, rank_id)
            }
            check_manifests(syncing, {}, none_finished)

            # 3a. Allocate a *different* partition for joining
            # (3 - p never equals p for an integer p).
            partition_id2 = 3 - partition_id
            rank_id2 = 100
            manifest_manager.alloc_join_example(rank_id2, partition_id2)
            joining = {
                partition_id2: (dj_pb.JoinExampleState.Joining, rank_id2)
            }
            check_manifests(syncing, joining, none_finished)

            # At this point each of these finish_* calls is expected to raise.
            self.assertRaises(Exception,
                              manifest_manager.finish_join_example,
                              rank_id, partition_id)
            self.assertRaises(Exception,
                              manifest_manager.finish_join_example,
                              rank_id2, partition_id2)
            self.assertRaises(Exception,
                              manifest_manager.finish_sync_example_id,
                              -rank_id, partition_id)
            self.assertRaises(Exception,
                              manifest_manager.finish_sync_example_id,
                              rank_id2, partition_id2)

            # 3b. Also allocate the syncing partition for joining.
            rank_id3 = 0
            manifest_manager.alloc_join_example(rank_id3, partition_id)
            joining[partition_id] = (dj_pb.JoinExampleState.Joining, rank_id3)
            check_manifests(syncing, joining, none_finished)

            # Finishing the sync is still expected to raise here.
            self.assertRaises(Exception,
                              manifest_manager.finish_sync_example_id,
                              rank_id, partition_id)

            # 4. Add raw data.  The batch contains a duplicated entry ('a');
            # it is expected to be rejected when the third argument is False
            # and accepted when it is True.
            raw_data_metas = [
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=3)),
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=3)),
                dj_pb.RawDataMeta(
                    file_path='c',
                    timestamp=timestamp_pb2.Timestamp(seconds=1))
            ]
            self.assertRaises(Exception, manifest_manager.add_raw_data,
                              partition_id, raw_data_metas, False)
            manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
            latest_ts = manifest_manager.get_raw_date_latest_timestamp(
                partition_id)
            self.assertEqual(latest_ts.seconds, 3)
            self.assertEqual(latest_ts.nanos, 0)
            manifest = manifest_manager.get_manifest(partition_id)
            # Three metas were submitted, two of them identical.
            self.assertEqual(manifest.next_process_index, 2)

            # Re-submit the earlier files together with two new ones
            # ('b' and 'd').
            raw_data_metas = [
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=3)),
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=3)),
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=2)),
                dj_pb.RawDataMeta(
                    file_path='c',
                    timestamp=timestamp_pb2.Timestamp(seconds=1)),
                dj_pb.RawDataMeta(
                    file_path='d',
                    timestamp=timestamp_pb2.Timestamp(seconds=4))
            ]
            manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
            latest_ts = manifest_manager.get_raw_date_latest_timestamp(
                partition_id)
            self.assertEqual(latest_ts.seconds, 4)
            self.assertEqual(latest_ts.nanos, 0)
            manifest_map = manifest_manager.list_all_manifest()
            for i in range(partition_num):
                self.assertIn(i, manifest_map)
                expected_index = 4 if i == partition_id else 0
                self.assertEqual(manifest_map[i].next_process_index,
                                 expected_index)

            # Called twice on purpose: the repeated call must not raise.
            manifest_manager.finish_raw_data(partition_id)
            manifest_manager.finish_raw_data(partition_id)
            # NOTE(review): this call passes fewer arguments than the
            # add_raw_data calls above, so the exception may come from the
            # call signature rather than from the partition being finished
            # -- confirm against add_raw_data's signature.
            self.assertRaises(Exception, manifest_manager.add_raw_data,
                              partition_id, 200)

            # 5a. Finish syncing (again called twice; must not raise).
            manifest_manager.finish_sync_example_id(rank_id, partition_id)
            manifest_manager.finish_sync_example_id(rank_id, partition_id)
            synced = {
                partition_id: (dj_pb.SyncExampleIdState.Synced, rank_id)
            }
            check_manifests(synced, joining, {partition_id: True})

            # 5b. Finish joining (called twice; must not raise).
            manifest_manager.finish_join_example(rank_id3, partition_id)
            manifest_manager.finish_join_example(rank_id3, partition_id)
            joined = dict(joining)
            joined[partition_id] = (dj_pb.JoinExampleState.Joined, rank_id3)
            check_manifests(synced, joined, {})
        finally:
            cli.destroy_client_pool()