Exemplo n.º 1
0
 def test_csv_raw_data_visitor(self):
     """Visit a CSV raw-data file and check that indices are contiguous.

     Registers ``../csv_raw_data/<partition 0>/test_raw_data.csv`` (relative
     to the directory of this test file) in the raw-data manifest of data
     source ``fclh_test`` and iterates it with a ``RawDataVisitor`` that uses
     the ``CSV_DICT`` iterator and a 1 MiB read-ahead.  The visitor must
     yield indices 0, 1, 2, ... and exactly 4999 items in total.
     """
     ds = common_pb.DataSource()
     self.data_source = ds
     ds.data_source_meta.name = 'fclh_test'
     ds.data_source_meta.partition_num = 1
     here = path.dirname(path.abspath(__file__))
     self.raw_data_dir = path.join(here, "../csv_raw_data")
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     # Start from a clean etcd namespace for this data source.
     etcd_prefix = common.data_source_etcd_base_dir(ds.data_source_meta.name)
     self.etcd.delete_prefix(etcd_prefix)
     self.assertEqual(ds.data_source_meta.partition_num, 1)
     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, ds)
     csv_meta = dj_pb.RawDataMeta(
         file_path=path.join(partition_dir, "test_raw_data.csv"),
         timestamp=timestamp_pb2.Timestamp(seconds=3))
     manifest_manager.add_raw_data(0, [csv_meta], True)
     options = dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                    read_ahead_size=1 << 20)
     manager = raw_data_visitor.RawDataManager(self.etcd, ds, 0)
     self.assertTrue(manager.check_index_meta_by_process_index(0))
     visitor = raw_data_visitor.RawDataVisitor(self.etcd, ds, 0, options)
     visited = 0
     for index, item in visitor:
         # Periodic progress output so long runs show signs of life.
         if 0 < index and index % 1024 == 0:
             print("{} {}".format(index, item.raw_id))
         self.assertEqual(index, visited)
         visited += 1
     self.assertEqual(visited, 4999)
Exemplo n.º 2
0
 def setUp(self):
     """Prepare data source ``milestone-f`` and joiner options for a test.

     Builds a single-partition data-source proto using local
     ``./data_block``, ``./example_id`` and ``./raw_data`` directories,
     creates the raw-data, example-id-dump and ``STREAM_JOINER`` options,
     wipes any of those directories left over from an earlier run, clears
     the etcd prefix named after the data source, and creates the raw-data
     manifest manager plus the counters the tests update.
     """
     ds = common_pb.DataSource()
     meta = ds.data_source_meta
     meta.name = "milestone-f"
     meta.partition_num = 1
     ds.data_block_dir = "./data_block"
     ds.example_dumped_dir = "./example_id"
     ds.raw_data_dir = "./raw_data"
     self.data_source = ds
     self.raw_data_options = dj_pb.RawDataOptions(
         raw_data_iter='TF_RECORD', compressed_type='')
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1, example_id_dump_threshold=1024)
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         example_joiner='STREAM_JOINER',
         min_matching_window=32,
         max_matching_window=128,
         data_block_dump_interval=30,
         data_block_dump_threshold=128)
     # Remove output left behind by a previous run.
     for stale_dir in (ds.data_block_dir, ds.example_dumped_dir,
                       ds.raw_data_dir):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(meta.name)
     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, ds)
     self.g_data_block_index = 0
Exemplo n.º 3
0
 def setUp(self):
     """Build a fresh TFRecord raw-data partition for data source ``fclh_test``.

     Clears the etcd prefix named after the data source, recreates
     ``./raw_data/partition_0`` and writes two TFRecord files,
     ``raw_data_0`` and ``raw_data_1``, of 100 ``tf.train.Example`` records
     each.  Record ``j`` of file ``i`` has a single bytes feature
     ``example_id`` holding ``str(i * 100 + j)`` encoded to bytes, so the
     ids run 0..199 in write order.
     """
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.data_source.raw_data_dir = "./raw_data"
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(self.data_source.data_source_meta.name)
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = os.path.join(self.data_source.raw_data_dir,
                                  'partition_0')
     # Start every run from an empty partition directory.
     if gfile.Exists(partition_dir):
         gfile.DeleteRecursively(partition_dir)
     gfile.MakeDirs(partition_dir)
     for i in range(2):
         fpath = os.path.join(partition_dir, 'raw_data_{}'.format(i))
         # The context manager closes the writer even if building or
         # serializing an example raises, so the file handle is not leaked.
         with tf.io.TFRecordWriter(fpath) as writer:
             for j in range(100):
                 example_id = '{}'.format(i * 100 + j).encode()
                 feat = {
                     'example_id': tf.train.Feature(
                         bytes_list=tf.train.BytesList(value=[example_id])),
                 }
                 example = tf.train.Example(
                     features=tf.train.Features(feature=feat))
                 writer.write(example.SerializeToString())
Exemplo n.º 4
0
 def setUp(self):
     """Create the ``_f`` and ``_l`` data sources, both named ``milestone``.

     Each is a single-partition data source writing under its own output
     directory (``./output-f`` / ``./output-l``).  Output and raw-data
     directories from a previous run are removed, the etcd base dir of the
     ``_l`` data source is cleared, and a raw-data manifest manager is
     created for the ``_l`` data source.
     """
     def new_source(output_base_dir):
         # Both sides share the same single-partition data-source name.
         source = common_pb.DataSource()
         source.data_source_meta.name = "milestone"
         source.data_source_meta.partition_num = 1
         source.output_base_dir = output_base_dir
         return source

     self.data_source_f = new_source("./output-f")
     self.data_source_l = new_source("./output-l")
     self.raw_data_dir_l = "./raw_data-l"
     for stale_dir in (self.data_source_f.output_base_dir,
                       self.data_source_l.output_base_dir,
                       self.raw_data_dir_l):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     etcd_prefix = common.data_source_etcd_base_dir(
         self.data_source_l.data_source_meta.name)
     self.etcd.delete_prefix(etcd_prefix)
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source_l)
 def test_raw_data_visitor(self):
     """Visit a GZIP-compressed raw-data file and check contiguous indices.

     Registers ``./test/compressed_raw_data/<partition 0>/0-0.idx`` in the
     manifest of data source ``fclh_test``, reads it back through a
     ``TF_DATASET`` visitor with ``GZIP`` compression, and checks that the
     yielded indices are 0, 1, 2, ... and that at least one item is seen.
     """
     ds = common_pb.DataSource()
     self.data_source = ds
     ds.data_source_meta.name = 'fclh_test'
     ds.data_source_meta.partition_num = 1
     ds.raw_data_dir = "./test/compressed_raw_data"
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(ds.data_source_meta.name)
     self.assertEqual(ds.data_source_meta.partition_num, 1)
     partition_dir = os.path.join(ds.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, ds)
     idx_meta = dj_pb.RawDataMeta(
         file_path=os.path.join(partition_dir, "0-0.idx"),
         timestamp=timestamp_pb2.Timestamp(seconds=3))
     manifest_manager.add_raw_data(0, [idx_meta], True)
     options = dj_pb.RawDataOptions(raw_data_iter='TF_DATASET',
                                    compressed_type='GZIP')
     manager = raw_data_visitor.RawDataManager(self.etcd, ds, 0)
     self.assertTrue(manager.check_index_meta_by_process_index(0))
     visitor = raw_data_visitor.RawDataVisitor(self.etcd, ds, 0, options)
     seen = 0
     for index, item in visitor:
         if 0 < index and index % 32 == 0:
             print("{} {}".format(index, item.example_id))
         self.assertEqual(index, seen)
         seen += 1
     self.assertGreater(seen, 0)
Exemplo n.º 6
0
 def test_raw_data_visitor(self):
     """Iterate GZIP-compressed raw data configured via ``CustomizedOptions``.

     Uses data source ``fclh_test`` rooted at ``./test/compressed_raw_data``
     (its ``partition_0`` directory must already exist), configures the
     visitor for the ``TF_DATASET`` iterator with ``GZIP`` compression, and
     checks that indices come back as 0, 1, 2, ... with at least one item.
     """
     ds = common_pb.DataSource()
     self.data_source = ds
     ds.data_source_meta.name = 'fclh_test'
     ds.data_source_meta.partition_num = 1
     ds.raw_data_dir = "./test/compressed_raw_data"
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner')
     self.etcd.delete_prefix(ds.data_source_meta.name)
     self.assertEqual(ds.data_source_meta.partition_num, 1)
     partition_dir = os.path.join(ds.raw_data_dir, 'partition_0')
     self.assertTrue(gfile.Exists(partition_dir))
     # Constructed only for its side effects; the instance is not used below.
     # NOTE(review): presumably this initialises the partition manifests in
     # etcd -- confirm before removing.
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, ds)
     options = customized_options.CustomizedOptions()
     options.set_raw_data_iter('TF_DATASET')
     options.set_compressed_type('GZIP')
     visitor = raw_data_visitor.RawDataVisitor(self.etcd, ds, 0, options)
     count = 0
     for index, item in visitor:
         if 0 < index and index % 1024 == 0:
             print("{} {} {}".format(index, item.example_id,
                                     item.event_time))
         self.assertEqual(index, count)
         count += 1
     self.assertGreater(count, 0)
Exemplo n.º 7
0
    def test_etcd_op(self):
        """Exercise basic EtcdClient operations and a cross-thread watch.

        Covers delete / set_data / get_data, a failing and a succeeding
        compare-and-swap (``cas``), and ``watch_key`` observing a write made
        from a second thread.  With this client (base dir ``data_source_a``)
        watched event keys are expected to carry the ``/data_source_a/``
        prefix.
        """
        # NOTE(review): hard-coded non-local etcd endpoint; this test cannot
        # run outside that network.
        cli = etcd_client.EtcdClient('test_cluster', '10.8.163.165:4578',
                                     'data_source_a')
        cli.delete('fl_key')
        cli.set_data('fl_key', 'fl_value')
        self.assertEqual(cli.get_data('fl_key'), b'fl_value')
        self.assertFalse(cli.cas('fl_key', 'fl_value1', 'fl_value2'))
        self.assertTrue(cli.cas('fl_key', 'fl_value', 'fl_value1'))
        self.assertEqual(cli.get_data('fl_key'), b'fl_value1')

        # An assertion failing inside a worker thread does not fail a
        # unittest test; the thread just dies.  Collect such errors and
        # re-raise them in the test thread after join().
        thread_errors = []

        def thread_routine():
            try:
                cli.set_data('fl_key', 'fl_value2')
                self.assertEqual(cli.get_data('fl_key'), b'fl_value2')
            except Exception as err:
                thread_errors.append(err)

        # The watch is registered before the writer thread starts so its
        # update cannot be missed.
        eiter, cancel = cli.watch_key('fl_key')
        other = threading.Thread(target=thread_routine)
        other.start()
        for e in eiter:
            self.assertEqual(e.key, b'/data_source_a/fl_key')
            self.assertEqual(e.value, b'fl_value2')
            cancel()
        other.join()
        if thread_errors:
            raise thread_errors[0]
        cli.destory_client_pool()
Exemplo n.º 8
0
 def setUp(self):
     """Reset data source ``fclh_test`` and give it an empty partition 0.

     The data source has one partition rooted at ``./raw_data``.  The etcd
     prefix named after it is cleared, its partition-0 directory is
     recreated empty, and a raw-data manifest manager is created for it.
     """
     ds = common_pb.DataSource()
     meta = ds.data_source_meta
     meta.name = 'fclh_test'
     meta.partition_num = 1
     ds.raw_data_dir = "./raw_data"
     self.data_source = ds
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(meta.name)
     self.assertEqual(meta.partition_num, 1)
     first_partition = os.path.join(ds.raw_data_dir, common.partition_repr(0))
     # Recreate the directory so each test starts with no raw-data files.
     if gfile.Exists(first_partition):
         gfile.DeleteRecursively(first_partition)
     gfile.MakeDirs(first_partition)
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, ds)
 def setUp(self):
     """Prepare data source ``milestone-x`` with an empty example-id dump dir.

     Clears the etcd prefix named after the data source, sets example-id
     dump options (interval -1, threshold 1024), and recreates
     ``./example_ids/<partition 0>`` from scratch, storing its path in
     ``self.partition_dir``.
     """
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     ds = common_pb.DataSource()
     ds.data_source_meta.name = "milestone-x"
     ds.data_source_meta.partition_num = 1
     ds.example_dumped_dir = "./example_ids"
     self.etcd.delete_prefix(ds.data_source_meta.name)
     self.data_source = ds
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=-1, example_id_dump_threshold=1024)
     dump_root = ds.example_dumped_dir
     if gfile.Exists(dump_root):
         gfile.DeleteRecursively(dump_root)
     self.partition_dir = os.path.join(dump_root, common.partition_repr(0))
     gfile.MakeDirs(self.partition_dir)
Exemplo n.º 10
0
 def setUp(self):
     """Prepare data source ``milestone-f`` with explicit matching windows.

     The single-partition data source uses local data-block, example-id and
     raw-data directories and sets min/max matching window (64/128) and
     ``max_example_in_data_block`` (128) on its meta.  Leftover directories
     are removed and the etcd prefix named after the data source is
     cleared.
     """
     ds = common_pb.DataSource()
     meta = ds.data_source_meta
     meta.name = "milestone-f"
     meta.partition_num = 1
     ds.data_block_dir = "./data_block"
     ds.example_dumped_dir = "./example_id"
     ds.raw_data_dir = "./raw_data"
     meta.min_matching_window = 64
     meta.max_matching_window = 128
     meta.max_example_in_data_block = 128
     self.data_source = ds
     for stale_dir in (ds.data_block_dir, ds.example_dumped_dir,
                       ds.raw_data_dir):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(meta.name)
Exemplo n.º 11
0
 def setUp(self):
     """Create data sources ``milestone-f`` and ``milestone-l`` for a test.

     Both are single-partition with their own data-block directory; the
     ``_l`` data source also has a raw-data directory.  Leftover directories
     are removed and the etcd prefix named after ``milestone-l`` is cleared.
     """
     def new_source(name, data_block_dir):
         source = common_pb.DataSource()
         source.data_source_meta.name = name
         source.data_source_meta.partition_num = 1
         source.data_block_dir = data_block_dir
         return source

     self.data_source_f = new_source("milestone-f", "./data_block-f")
     self.data_source_l = new_source("milestone-l", "./data_block-l")
     self.data_source_l.raw_data_dir = "./raw_data-l"
     for stale_dir in (self.data_source_f.data_block_dir,
                       self.data_source_l.data_block_dir,
                       self.data_source_l.raw_data_dir):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)
     # NOTE(review): hard-coded non-local etcd endpoint.
     self.etcd = etcd_client.EtcdClient('test_cluster', '10.8.163.165:4578',
                                        'byte_fl')
     self.etcd.delete_prefix(self.data_source_l.data_source_meta.name)
Exemplo n.º 12
0
 def setUp(self):
     """Commit a 4-partition ``Follower`` data source and build its blocks.

     Data source ``milestone-x`` has start_time 0, end_time 10000 and output
     base dir ``./ds_output``.  It is committed to etcd via
     ``common.commit_data_source``, stale output is removed, a raw-data
     manifest manager is created, and ``self._create_data_block`` (defined
     elsewhere in the class) is called once per partition.
     """
     ds = common_pb.DataSource()
     meta = ds.data_source_meta
     meta.name = "milestone-x"
     meta.partition_num = 4
     meta.start_time = 0
     meta.end_time = 10000
     ds.output_base_dir = "./ds_output"
     ds.role = common_pb.FLRole.Follower
     self.data_source = ds
     self.etcd_name = 'test_cluster'
     self.etcd_addrs = 'localhost:2379'
     self.etcd_base_dir = 'fedlearner'
     self.etcd = etcd_client.EtcdClient(self.etcd_name, self.etcd_addrs,
                                        self.etcd_base_dir, True)
     common.commit_data_source(self.etcd, ds)
     if gfile.Exists(ds.output_base_dir):
         gfile.DeleteRecursively(ds.output_base_dir)
     # NOTE(review): "matas" looks like a typo for "metas"; the name is kept
     # because code outside this view (e.g. _create_data_block) may use it.
     self.data_block_matas = []
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, ds)
     for partition_id in range(meta.partition_num):
         self._create_data_block(partition_id)
Exemplo n.º 13
0
    def test_etcd_op(self):
        """Exercise EtcdClient key/value, CAS, watch and prefix operations.

        Uses base dir ``data_source_a`` with the fourth constructor flag set
        to True.  Checks:
        - delete / set / get;
        - a failing and a succeeding ``cas``;
        - that ``watch_key`` sees a write made from another thread (event keys
          are expected without a base-dir prefix here);
        - that ``get_prefix_kvs('fl_key')`` returns ``fl_key`` followed by its
          three children in order, and only the children when its second
          argument is True.
        """
        cli = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                     'data_source_a', True)
        cli.delete('fl_key')
        cli.set_data('fl_key', 'fl_value')
        self.assertEqual(cli.get_data('fl_key'), b'fl_value')
        self.assertFalse(cli.cas('fl_key', 'fl_value1', 'fl_value2'))
        self.assertTrue(cli.cas('fl_key', 'fl_value', 'fl_value1'))
        self.assertEqual(cli.get_data('fl_key'), b'fl_value1')

        # An assertion failing inside a worker thread does not fail a
        # unittest test; the thread just dies.  Collect such errors and
        # re-raise them in the test thread after join().
        thread_errors = []

        def thread_routine():
            try:
                cli.set_data('fl_key', 'fl_value2')
                self.assertEqual(cli.get_data('fl_key'), b'fl_value2')
            except Exception as err:
                thread_errors.append(err)

        # The watch is registered before the writer thread starts so its
        # update cannot be missed.
        eiter, cancel = cli.watch_key('fl_key')
        other = threading.Thread(target=thread_routine)
        other.start()
        for e in eiter:
            self.assertEqual(e.key, b'fl_key')
            self.assertEqual(e.value, b'fl_value2')
            cancel()
        other.join()
        if thread_errors:
            raise thread_errors[0]

        cli.set_data('fl_key/a', '1')
        cli.set_data('fl_key/b', '2')
        cli.set_data('fl_key/c', '3')
        expected_kvs = [(b'fl_key', b'fl_value2'), (b'fl_key/a', b'1'),
                        (b'fl_key/b', b'2'), (b'fl_key/c', b'3')]
        # Compare whole lists: enumerating a short or empty result would
        # assert nothing and let the test pass vacuously.
        all_kvs = [(kv[0], kv[1]) for kv in cli.get_prefix_kvs('fl_key')]
        self.assertEqual(all_kvs, expected_kvs)
        child_kvs = [(kv[0], kv[1])
                     for kv in cli.get_prefix_kvs('fl_key', True)]
        self.assertEqual(child_kvs, expected_kvs[1:])

        cli.destory_client_pool()
Exemplo n.º 14
0
    def test_raw_data_manifest_manager(self):
        """Drive RawDataManifestManager through a partition's lifecycle.

        Uses a 4-partition ``Leader`` data source ``milestone-x``.  Steps:
        - allocate a partition for example-id sync (rank 2);
        - allocate join work on a different partition (rank 100), then on the
          syncing partition (rank 0);
        - check that a set of premature / wrong-rank finish calls raise;
        - add raw data with and without dedup and check
          ``next_process_index``;
        - finish raw data, sync and join, calling each finish twice.
        After each allocation or finish step, every partition's manifest is
        verified.
        """
        cli = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                     'fedlearner', True)
        partition_num = 4
        rank_id = 2
        data_source = common_pb.DataSource()
        data_source.data_source_meta.name = "milestone-x"
        data_source.data_source_meta.partition_num = partition_num
        data_source.role = common_pb.FLRole.Leader
        cli.delete_prefix(data_source.data_source_meta.name)
        manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            cli, data_source)

        sync_state = dj_pb.SyncExampleIdState
        join_state = dj_pb.JoinExampleState
        unfinished = dict.fromkeys(range(partition_num), False)

        def check_manifests(sync_reps, join_reps, finished=None):
            """Verify every partition's manifest against expectations.

            ``sync_reps`` / ``join_reps`` map a partition id to the expected
            ``(state, rank_id)`` of its sync-example-id / join-example rep;
            unlisted partitions must be ``UnSynced`` / ``UnJoined`` with
            rank -1.  ``finished`` maps a partition id to its expected
            ``finished`` flag; unlisted partitions are not checked.
            """
            manifest_map = manifest_manager.list_all_manifest()
            for i in range(partition_num):
                self.assertTrue(i in manifest_map)
                manifest = manifest_map[i]
                state, rank = sync_reps.get(i, (sync_state.UnSynced, -1))
                self.assertEqual(manifest.sync_example_id_rep.state, state)
                self.assertEqual(manifest.sync_example_id_rep.rank_id, rank)
                state, rank = join_reps.get(i, (join_state.UnJoined, -1))
                self.assertEqual(manifest.join_example_rep.state, state)
                self.assertEqual(manifest.join_example_rep.rank_id, rank)
                if finished and i in finished:
                    self.assertEqual(bool(manifest.finished), finished[i])

        # Fresh data source: nothing has been allocated yet.
        check_manifests({}, {}, unfinished)

        # Allocate a partition for example-id sync ("exampld" is the
        # manager's own method spelling).
        manifest = manifest_manager.alloc_sync_exampld_id(rank_id)
        self.assertNotEqual(manifest, None)
        partition_id = manifest.partition_id
        syncing = {partition_id: (sync_state.Syncing, rank_id)}
        check_manifests(syncing, {}, unfinished)

        # Join work on another partition; 3 - p can never equal p.
        partition_id2 = 3 - partition_id
        rank_id2 = 100
        manifest_manager.alloc_join_example(rank_id2, partition_id2)
        check_manifests(syncing,
                        {partition_id2: (join_state.Joining, rank_id2)},
                        unfinished)

        # Each of these finish calls is expected to be rejected (wrong rank,
        # or presumably the partition is not yet ready to finish).
        self.assertRaises(Exception, manifest_manager.finish_join_example,
                          rank_id, partition_id)
        self.assertRaises(Exception, manifest_manager.finish_join_example,
                          rank_id2, partition_id2)
        self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                          -rank_id, partition_id)
        self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                          rank_id2, partition_id2)

        # Join work on the partition that is still syncing.
        rank_id3 = 0
        manifest_manager.alloc_join_example(rank_id3, partition_id)
        joining = {partition_id: (join_state.Joining, rank_id3),
                   partition_id2: (join_state.Joining, rank_id2)}
        check_manifests(syncing, joining, unfinished)

        # Sync cannot finish before raw data is finished.  A batch with
        # duplicates is expected to be rejected without dedup; with dedup the
        # duplicate is skipped, so 'a', 'a', 'b' advances the index to 2.
        self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                          rank_id, partition_id)
        self.assertRaises(Exception, manifest_manager.add_raw_data,
                          partition_id, ['a', 'a', 'b'], False)
        manifest_manager.add_raw_data(partition_id, ['a', 'a', 'b'], True)
        manifest = manifest_manager.get_manifest(partition_id)
        self.assertEqual(manifest.next_process_index, 2)
        manifest_manager.add_raw_data(partition_id,
                                      ['a', 'a', 'b', 'c', 'd'], True)
        manifest_map = manifest_manager.list_all_manifest()
        for i in range(partition_num):
            self.assertTrue(i in manifest_map)
            expected_index = 4 if i == partition_id else 0
            self.assertEqual(manifest_map[i].next_process_index,
                             expected_index)

        manifest_manager.finish_raw_data(partition_id)
        manifest_manager.finish_raw_data(partition_id)
        # Adding raw data to a finished partition must be rejected.  Pass a
        # well-formed argument list so the assertion exercises that check;
        # the previous call (missing the dedup argument) only raised a
        # TypeError and passed regardless of the manager's behaviour.
        self.assertRaises(Exception, manifest_manager.add_raw_data,
                          partition_id, ['e'], True)
        manifest_manager.finish_sync_example_id(rank_id, partition_id)
        manifest_manager.finish_sync_example_id(rank_id, partition_id)
        synced = {partition_id: (sync_state.Synced, rank_id)}
        check_manifests(synced, joining, {partition_id: True})

        manifest_manager.finish_join_example(rank_id3, partition_id)
        manifest_manager.finish_join_example(rank_id3, partition_id)
        check_manifests(synced,
                        {partition_id: (join_state.Joined, rank_id3),
                         partition_id2: (join_state.Joining, rank_id2)})

        cli.destory_client_pool()
Exemplo n.º 15
0
    def test_raw_data_manifest_manager(self):
        """Walk one partition from UnAllocated through Done.

        With a 4-partition data source ``milestone-x``, one partition is
        moved through Syncing (rank 0), Synced, Joining (rank 2) and Done.
        After every step all partitions are checked; only the allocated
        partition may differ from ``UnAllocated`` with rank -1.
        """
        cli = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                     'fedlearner')
        data_source = common_pb.DataSource()
        data_source.data_source_meta.name = "milestone-x"
        data_source.data_source_meta.partition_num = 4
        cli.delete_prefix(data_source.data_source_meta.name)
        manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            cli, data_source)
        raw_state = dj_pb.RawDataState

        def check_manifests(owner=None, state=None, rank=-1):
            """Expect ``owner`` at (state, rank), all others UnAllocated/-1."""
            manifest_map = manifest_manager.list_all_manifest()
            for pid in range(data_source.data_source_meta.partition_num):
                self.assertTrue(pid in manifest_map)
                if pid == owner:
                    want_state, want_rank = state, rank
                else:
                    want_state, want_rank = raw_state.UnAllocated, -1
                self.assertEqual(manifest_map[pid].state, want_state)
                self.assertEqual(manifest_map[pid].allocated_rank_id,
                                 want_rank)

        check_manifests()

        manifest, finished = manifest_manager.alloc_unallocated_partition(0)
        self.assertFalse(finished)
        partition_id = manifest.partition_id
        check_manifests(partition_id, raw_state.Syncing, 0)

        manifest_manager.finish_sync_partition(0, partition_id)
        check_manifests(partition_id, raw_state.Synced)

        manifest2, finished = manifest_manager.alloc_synced_partition(2)
        self.assertFalse(finished)
        self.assertEqual(manifest.partition_id, manifest2.partition_id)
        check_manifests(partition_id, raw_state.Joining, 2)

        manifest_manager.finish_join_partition(2, partition_id)
        check_manifests(partition_id, raw_state.Done)
        cli.destory_client_pool()