Пример #1
0
 def _setUpDataSource(self):
     """Clear etcd and commit a leader/follower pair of data sources.

     The two data sources differ in role and in their ``_l`` / ``_f``
     directory names; both receive a copy of the same meta (4 partitions,
     start_time 0, end_time 100000000) before being committed through
     their respective etcd clients.
     """
     self._data_source_name = 'test_data_source'
     # Delete the data source's etcd prefix on the leader client first,
     # then on the follower client.
     for etcd in (self._etcd_l, self._etcd_f):
         etcd.delete_prefix(
             common.data_source_etcd_base_dir(self._data_source_name))

     self._data_source_l = common_pb.DataSource(
         role=common_pb.FLRole.Leader,
         state=common_pb.DataSourceState.Init,
         data_block_dir="./data_block_l",
         raw_data_dir="./raw_data_l",
         example_dumped_dir="./example_dumped_l",
         raw_data_sub_dir="./raw_data_sub_dir_l")
     self._data_source_f = common_pb.DataSource(
         role=common_pb.FLRole.Follower,
         state=common_pb.DataSourceState.Init,
         data_block_dir="./data_block_f",
         raw_data_dir="./raw_data_f",
         example_dumped_dir="./example_dumped_f",
         raw_data_sub_dir="./raw_data_sub_dir_f")

     shared_meta = common_pb.DataSourceMeta(
         name=self._data_source_name,
         partition_num=4,
         start_time=0,
         end_time=100000000)
     parties = ((self._etcd_l, self._data_source_l),
                (self._etcd_f, self._data_source_f))
     # Copy the meta into both data sources before committing either one.
     for _, data_source in parties:
         data_source.data_source_meta.MergeFrom(shared_meta)
     for etcd, data_source in parties:
         common.commit_data_source(etcd, data_source)
Пример #2
0
 def tearDown(self):
     """Remove the directories and etcd entries used by both parties.

     Deletes the leader and follower output base directories and raw-data
     directories if they exist, then deletes each data source's etcd
     prefix through the matching etcd client.
     """
     for leftover in (self.data_source_l.output_base_dir,
                      self.raw_data_dir_l,
                      self.data_source_f.output_base_dir,
                      self.raw_data_dir_f):
         if gfile.Exists(leftover):
             gfile.DeleteRecursively(leftover)
     # data_source_etcd_base_dir() is called with a data source name here.
     # The etcd client's base dir was passed previously, which does not
     # name the data source whose prefix should be removed.
     self.etcd_f.delete_prefix(
         common.data_source_etcd_base_dir(
             self.data_source_f.data_source_meta.name))
     self.etcd_l.delete_prefix(
         common.data_source_etcd_base_dir(
             self.data_source_l.data_source_meta.name))
Пример #3
0
 def setUp(self):
     """Prepare a one-partition "milestone-f" data source on a clean slate.

     Builds the option messages used by the tests, removes any existing
     output and raw-data directories, deletes the data source's etcd
     prefix, creates the manifest manager and resets the counters kept on
     the test case.
     """
     ds = common_pb.DataSource()
     ds.output_base_dir = "./ds_output"
     ds.data_source_meta.partition_num = 1
     ds.data_source_meta.name = "milestone-f"
     self.data_source = ds
     self.raw_data_dir = "./raw_data"

     self.raw_data_options = dj_pb.RawDataOptions(
         raw_data_iter='TF_RECORD', compressed_type='')
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1,
         example_id_dump_threshold=1024)
     joiner_settings = dict(
         example_joiner='STREAM_JOINER',
         min_matching_window=32,
         max_matching_window=128,
         data_block_dump_interval=30,
         data_block_dump_threshold=128)
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         **joiner_settings)

     # Remove directories that may already exist on disk.
     for stale_dir in (ds.output_base_dir, self.raw_data_dir):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)

     self.etcd = etcd_client.EtcdClient(
         'test_cluster', 'localhost:2379', 'fedlearner', True)
     etcd_prefix = common.data_source_etcd_base_dir(
         ds.data_source_meta.name)
     self.etcd.delete_prefix(etcd_prefix)

     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = \
         raw_data_manifest_manager.RawDataManifestManager(
             self.etcd, self.data_source)
     self.g_data_block_index = 0
Пример #4
0
 def test_csv_raw_data_visitor(self):
     """Iterate a CSV raw-data file and check indices arrive as 0, 1, 2, ...

     Registers ``test_raw_data.csv`` from partition 0 of the
     ``../csv_raw_data`` directory (relative to this file) in the
     manifest, walks it with a ``CSV_DICT`` visitor and expects 4999 items
     in total.
     """
     self.data_source = common_pb.DataSource()
     meta = self.data_source.data_source_meta
     meta.name = 'fclh_test'
     meta.partition_num = 1

     here = path.dirname(path.abspath(__file__))
     self.raw_data_dir = path.join(here, "../csv_raw_data")

     self.etcd = etcd_client.EtcdClient(
         'test_cluster', 'localhost:2379', 'fedlearner', True)
     etcd_prefix = common.data_source_etcd_base_dir(
         self.data_source.data_source_meta.name)
     self.etcd.delete_prefix(etcd_prefix)
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)

     # The fixture directory for partition 0 must already be on disk.
     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))

     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     csv_meta = dj_pb.RawDataMeta(
         file_path=path.join(partition_dir, "test_raw_data.csv"),
         timestamp=timestamp_pb2.Timestamp(seconds=3))
     manifest_manager.add_raw_data(0, [csv_meta], True)

     raw_data_options = dj_pb.RawDataOptions(
         raw_data_iter='CSV_DICT', read_ahead_size=1 << 20)
     rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))
     rdv = raw_data_visitor.RawDataVisitor(
         self.etcd, self.data_source, 0, raw_data_options)

     seen = 0
     for index, item in rdv:
         # Progress output every 1024 items, skipping index 0.
         if index > 0 and index % 1024 == 0:
             print("{} {}".format(index, item.raw_id))
         self.assertEqual(index, seen)
         seen += 1
     self.assertEqual(seen, 4999)
Пример #5
0
 def setUp(self):
     """Create follower and leader "milestone" data sources on a clean slate.

     Both data sources have one partition. Existing output directories
     (and the leader raw-data directory) are removed, the data source's
     etcd prefix is deleted, and a manifest manager is created for the
     leader data source.
     """
     follower = common_pb.DataSource()
     follower.output_base_dir = "./output-f"
     follower.data_source_meta.name = "milestone"
     follower.data_source_meta.partition_num = 1
     self.data_source_f = follower

     leader = common_pb.DataSource()
     leader.output_base_dir = "./output-l"
     leader.data_source_meta.name = "milestone"
     leader.data_source_meta.partition_num = 1
     self.data_source_l = leader
     self.raw_data_dir_l = "./raw_data-l"

     # Remove directories that may already exist on disk.
     for stale_dir in (follower.output_base_dir,
                       leader.output_base_dir,
                       self.raw_data_dir_l):
         if gfile.Exists(stale_dir):
             gfile.DeleteRecursively(stale_dir)

     self.etcd = etcd_client.EtcdClient(
         'test_cluster', 'localhost:2379', 'fedlearner', True)
     etcd_prefix = common.data_source_etcd_base_dir(
         self.data_source_l.data_source_meta.name)
     self.etcd.delete_prefix(etcd_prefix)
     self.manifest_manager = \
         raw_data_manifest_manager.RawDataManifestManager(
             self.etcd, self.data_source_l)
Пример #6
0
 def tearDown(self):
     """Remove the directories and etcd entries used by both parties.

     For the leader and then the follower data source, deletes the data
     block, dumped example and raw-data directories if they exist, then
     deletes each data source's etcd prefix through the matching client.
     """
     for data_source in (self.data_source_l, self.data_source_f):
         for leftover in (data_source.data_block_dir,
                          data_source.example_dumped_dir,
                          data_source.raw_data_dir):
             if gfile.Exists(leftover):
                 gfile.DeleteRecursively(leftover)
     # data_source_etcd_base_dir() is called with a data source name here.
     # The etcd client's base dir was passed previously, which does not
     # name the data source whose prefix should be removed.
     self.etcd_f.delete_prefix(
         common.data_source_etcd_base_dir(
             self.data_source_f.data_source_meta.name))
     self.etcd_l.delete_prefix(
         common.data_source_etcd_base_dir(
             self.data_source_l.data_source_meta.name))
Пример #7
0
 def tearDown(self):
     """Delete the output and raw-data directories and the etcd prefix."""
     for leftover in (self.data_source.output_base_dir, self.raw_data_dir):
         if gfile.Exists(leftover):
             gfile.DeleteRecursively(leftover)
     etcd_prefix = common.data_source_etcd_base_dir(
         self.data_source.data_source_meta.name)
     self.etcd.delete_prefix(etcd_prefix)
Пример #8
0
 def setUp(self):
     """Prepare a "milestone-x" data source and its partition-0 dump dir.

     Deletes the data source's etcd prefix, removes any existing output
     base directory and creates the partition-0 directory under the data
     source's example-dumped directory.
     """
     self.etcd = etcd_client.EtcdClient(
         'test_cluster', 'localhost:2379', 'fedlearner', True)

     ds = common_pb.DataSource()
     ds.output_base_dir = "./ds_output"
     ds.data_source_meta.partition_num = 1
     ds.data_source_meta.name = "milestone-x"
     etcd_prefix = common.data_source_etcd_base_dir(
         ds.data_source_meta.name)
     self.etcd.delete_prefix(etcd_prefix)
     self.data_source = ds

     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=-1, example_id_dump_threshold=1024)

     output_dir = self.data_source.output_base_dir
     if gfile.Exists(output_dir):
         gfile.DeleteRecursively(output_dir)

     dumped_root = common.data_source_example_dumped_dir(self.data_source)
     self.partition_dir = os.path.join(dumped_root,
                                       common.partition_repr(0))
     gfile.MakeDirs(self.partition_dir)
Пример #9
0
 def setUp(self):
     """Prepare an "fclh_test" data source with an empty partition-0 dir.

     Deletes the data source's etcd prefix, recreates the partition-0
     directory under ``./raw_data`` and creates the manifest manager.
     """
     ds = common_pb.DataSource()
     ds.raw_data_dir = "./raw_data"
     ds.data_source_meta.partition_num = 1
     ds.data_source_meta.name = 'fclh_test'
     self.data_source = ds

     self.etcd = etcd_client.EtcdClient(
         'test_cluster', 'localhost:2379', 'fedlearner', True)
     etcd_prefix = common.data_source_etcd_base_dir(
         self.data_source.data_source_meta.name)
     self.etcd.delete_prefix(etcd_prefix)
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)

     # Start from an empty partition directory: remove it, then recreate.
     partition_dir = os.path.join(self.data_source.raw_data_dir,
                                  common.partition_repr(0))
     if gfile.Exists(partition_dir):
         gfile.DeleteRecursively(partition_dir)
     gfile.MakeDirs(partition_dir)

     self.manifest_manager = \
         raw_data_manifest_manager.RawDataManifestManager(
             self.etcd, self.data_source)
Пример #10
0
    def setUp(self):
        """Start leader/follower data-join masters and wait until both process.

        Commits a two-partition data source for each party, starts one
        ``DataJoinMasterService`` per party (ports 4061 and 4062, each
        given the other's address), polls until both report the
        ``Processing`` state, removes directories left on disk, and stores
        the clients, options and names on the test case.
        """
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.data_block_dir = "./data_block_l"
        data_source_l.raw_data_dir = "./raw_data_l"
        data_source_l.example_dumped_dir = "./example_dumped_l"
        data_source_f = common_pb.DataSource()
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.data_block_dir = "./data_block_f"
        data_source_f.raw_data_dir = "./raw_data_f"
        data_source_f.example_dumped_dir = "./example_dumped_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]),
            master_addr_f,
            data_source_name,
            etcd_name,
            etcd_base_dir_l,
            etcd_addrs,
            master_options,
        )
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs, master_options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        # Poll both masters every 2 seconds until each reports Processing.
        # NOTE(review): there is no timeout, so this loops for as long as
        # either master stays in another state.
        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = master_client_l.GetDataSourceStatus(req_l)
            dss_f = master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
            time.sleep(2)

        self.master_client_l = master_client_l
        self.master_client_f = master_client_f
        self.master_addr_l = master_addr_l
        self.master_addr_f = master_addr_f
        self.etcd_l = etcd_l
        self.etcd_f = etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.master_l = master_l
        self.master_f = master_f
        # Store the plain string; the previous trailing comma on this
        # assignment stored a 1-tuple instead.
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
            self.etcd_l, self.raw_data_pub_dir_l)
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
            self.etcd_f, self.raw_data_pub_dir_f)

        # Remove directories that may already exist on disk, leader first.
        for data_source in (data_source_l, data_source_f):
            for stale_dir in (data_source.data_block_dir,
                              data_source.example_dumped_dir,
                              data_source.raw_data_dir):
                if gfile.Exists(stale_dir):
                    gfile.DeleteRecursively(stale_dir)

        self.worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type=''),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='STREAM_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=512, max_flying_item=2048),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='TF_RECORD'))

        self.total_index = 1 << 13
Пример #11
0
    def setUp(self):
        """Commit leader/follower data sources and build worker options.

        Clears the data source's etcd prefix on both clients, commits a
        two-partition data source per party, creates one raw-data publisher
        per party, removes output and raw-data directories that already
        exist and stores the names, clients and options on the test case.
        """
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l, etcd_base_dir_f = 'byefl_l', 'byefl_f'
        data_source_name = 'test_data_source'

        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        for client in (etcd_l, etcd_f):
            client.delete_prefix(
                common.data_source_etcd_base_dir(data_source_name))

        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        self.raw_data_dir_l = "./raw_data_l"
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        self.raw_data_dir_f = "./raw_data_f"

        data_source_l = common_pb.DataSource(
            raw_data_sub_dir=self.raw_data_pub_dir_l,
            role=common_pb.FLRole.Leader,
            state=common_pb.DataSourceState.Init,
            output_base_dir="./ds_output_l")
        data_source_f = common_pb.DataSource(
            role=common_pb.FLRole.Follower,
            raw_data_sub_dir=self.raw_data_pub_dir_f,
            state=common_pb.DataSourceState.Init,
            output_base_dir="./ds_output_f")

        shared_meta = common_pb.DataSourceMeta(
            name=data_source_name,
            partition_num=2,
            start_time=0,
            end_time=100000000)
        # Leader is merged and committed first, then the follower.
        for client, data_source in ((etcd_l, data_source_l),
                                    (etcd_f, data_source_f)):
            data_source.data_source_meta.MergeFrom(shared_meta)
            common.commit_data_source(client, data_source)

        self.etcd_l, self.etcd_f = etcd_l, etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
            self.etcd_l, self.raw_data_pub_dir_l)
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
            self.etcd_f, self.raw_data_pub_dir_f)

        # Remove directories that may already exist on disk.
        for stale_dir in (data_source_l.output_base_dir,
                          self.raw_data_dir_l,
                          data_source_f.output_base_dir,
                          self.raw_data_dir_f):
            if gfile.Exists(stale_dir):
                gfile.DeleteRecursively(stale_dir)

        raw_data_options = dj_pb.RawDataOptions(
            raw_data_iter='TF_RECORD',
            read_ahead_size=1 << 20,
            read_batch_size=128)
        example_id_dump_options = dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1,
            example_id_dump_threshold=1024)
        example_joiner_options = dj_pb.ExampleJoinerOptions(
            example_joiner='STREAM_JOINER',
            min_matching_window=64,
            max_matching_window=256,
            data_block_dump_interval=30,
            data_block_dump_threshold=1000)
        batch_processor_options = dj_pb.BatchProcessorOptions(
            batch_size=512, max_flying_item=2048)
        data_block_builder_options = dj_pb.WriterOptions(
            output_writer='TF_RECORD')
        self.worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=raw_data_options,
            example_id_dump_options=example_id_dump_options,
            example_joiner_options=example_joiner_options,
            batch_processor_options=batch_processor_options,
            data_block_builder_options=data_block_builder_options)

        self.total_index = 1 << 12
Пример #12
0
    def test_api(self):
        logging.getLogger().setLevel(logging.DEBUG)
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.output_base_dir = "./ds_output_l"
        data_source_f = common_pb.DataSource()
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.output_base_dir = "./ds_output_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 1
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f, data_source_name,
            etcd_name, etcd_base_dir_l, etcd_addrs, options)
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs, options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = client_l.GetDataSourceStatus(req_l)
            dss_f = client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
            else:
                time.sleep(2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=-1,
            join_example=empty_pb2.Empty())

        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        #check idempotent
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            join_example=empty_pb2.Empty())
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)
        #check idempotent
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=1,
            partition_id=-1,
            sync_example_id=empty_pb2.Empty())
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)
        #check idempotent
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=1,
            partition_id=0,
            sync_example_id=empty_pb2.Empty())
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)
        #check idempotent
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq1 = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=1,
            partition_id=0,
            sync_example_id=empty_pb2.Empty())

        try:
            rsp = client_l.FinishJoinPartition(rdreq1)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rdreq2 = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            join_example=empty_pb2.Empty())
        try:
            rsp = client_l.FinishJoinPartition(rdreq2)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 0)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ]))
        rsp = client_l.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 1)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 3)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=5))
                ]))
        rsp = client_l.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 2)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 5)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=5)),
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ]))
        rsp = client_l.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 2)
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 5)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 0)
        self.assertEqual(rsp.timestamp.nanos, 0)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=1))
                ]))
        rsp = client_f.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 1)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 1)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=2))
                ]))
        rsp = client_f.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 2)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 2)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=1)),
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=2))
                ]))
        rsp = client_f.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 2)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 2)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            finish_raw_data=empty_pb2.Empty())
        rsp = client_l.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_l.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)

        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertTrue(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='x',
                        timestamp=timestamp_pb2.Timestamp(seconds=4))
                ]))
        try:
            rsp = client_l.AddRawData(rdreq)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        try:
            rsp = client_f.FinishJoinPartition(rdreq2)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rsp = client_l.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_l.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)

        rsp = client_f.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_f.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)

        try:
            rsp = client_f.FinishJoinPartition(rdreq2)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            finish_raw_data=empty_pb2.Empty())
        rsp = client_f.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_f.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)

        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertTrue(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='x',
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ]))
        try:
            rsp = client_f.AddRawData(rdreq)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rsp = client_f.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_f.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)

        rsp = client_l.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_l.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)

        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = client_l.GetDataSourceStatus(req_l)
            dss_f = client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Finished and \
                    dss_f.state == common_pb.DataSourceState.Finished:
                break
            else:
                time.sleep(2)

        master_l.stop()
        master_f.stop()
Пример #13
0
                        required=True,
                        help='the etcd base dir to subscribe new raw data')
    args = parser.parse_args()
    # Build the DataSource description from the parsed command-line options.
    data_source = common_pb.DataSource()
    data_source.data_source_meta.name = args.data_source_name
    data_source.data_source_meta.partition_num = args.partition_num
    data_source.data_source_meta.start_time = args.start_time
    data_source.data_source_meta.end_time = args.end_time
    data_source.data_source_meta.negative_sampling_rate = \
            args.negative_sampling_rate
    if args.role == 'leader':
        data_source.role = common_pb.FLRole.Leader
    else:
        # Any role other than 'leader' has to be 'follower'.
        assert args.role == 'follower'
        data_source.role = common_pb.FLRole.Follower
        # example_dumped_dir is only filled in on the follower branch.
        data_source.example_dumped_dir = args.example_dump_dir
    data_source.data_block_dir = args.data_block_dir
    data_source.raw_data_sub_dir = args.raw_data_sub_dir
    # A data source created by this script always starts in the Init state.
    data_source.state = common_pb.DataSourceState.Init
    etcd = EtcdClient(args.etcd_name, args.etcd_addrs, args.etcd_base_dir)
    # Look the data source up by its etcd base key and commit it only when
    # nothing is stored there yet; an existing entry is left untouched.
    master_etcd_key = common.data_source_etcd_base_dir(
        data_source.data_source_meta.name)
    raw_data = etcd.get_data(master_etcd_key)
    if raw_data is None:
        logging.info("data source %s is not existed", args.data_source_name)
        common.commit_data_source(etcd, data_source)
        logging.info("apply new data source %s", args.data_source_name)
    else:
        logging.info("data source %s has been existed", args.data_source_name)
    # NOTE(review): not reached if any call above raises -- confirm whether
    # the client pool should be released in a finally block.
    etcd.destroy_client_pool()
Пример #14
0
    def test_raw_data_manifest_manager(self):
        """Drive RawDataManifestManager through one partition's lifecycle.

        On a 4-partition leader data source the test allocates example-id
        syncing and example joining, adds raw data metas, finishes each
        stage, and after every step checks the state recorded in all
        partition manifests. Each ``finish_*`` call is issued twice to
        check that repeating it is accepted, and calls that are expected
        to be rejected are checked with ``assertRaises``.
        """
        cli = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                     'fedlearner', True)
        # Register the release immediately so the client pool is destroyed
        # even when one of the assertions below fails.
        self.addCleanup(cli.destroy_client_pool)
        partition_num = 4
        rank_id = 2
        data_source = common_pb.DataSource()
        data_source.data_source_meta.name = "milestone-x"
        data_source.data_source_meta.partition_num = partition_num
        data_source.role = common_pb.FLRole.Leader
        # Start from a clean slate for this data source name.
        cli.delete_prefix(
            common.data_source_etcd_base_dir(
                data_source.data_source_meta.name))
        manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            cli, data_source)

        # (state, rank_id) expected for partitions nobody has touched.
        unsynced = (dj_pb.SyncExampleIdState.UnSynced, -1)
        unjoined = (dj_pb.JoinExampleState.UnJoined, -1)
        none_finished = dict.fromkeys(range(partition_num), False)

        def check_manifests(sync_expected, join_expected, finished_expected):
            """Assert the replica states of every partition manifest.

            ``sync_expected`` and ``join_expected`` map a partition id to
            the expected ``(state, rank_id)`` pair; partitions missing
            from a mapping must be untouched (UnSynced / UnJoined with
            rank -1). ``finished_expected`` maps a partition id to the
            expected ``finished`` flag; missing partitions are not
            checked for it.
            """
            manifest_map = manifest_manager.list_all_manifest()
            for i in range(partition_num):
                self.assertIn(i, manifest_map)
                entry = manifest_map[i]
                sync_state, sync_rank = sync_expected.get(i, unsynced)
                self.assertEqual(entry.sync_example_id_rep.state, sync_state)
                self.assertEqual(entry.sync_example_id_rep.rank_id, sync_rank)
                join_state, join_rank = join_expected.get(i, unjoined)
                self.assertEqual(entry.join_example_rep.state, join_state)
                self.assertEqual(entry.join_example_rep.rank_id, join_rank)
                if i in finished_expected:
                    self.assertEqual(bool(entry.finished),
                                     finished_expected[i])

        def raw_meta(file_path, seconds):
            """Build a RawDataMeta with a whole-second timestamp."""
            return dj_pb.RawDataMeta(
                file_path=file_path,
                timestamp=timestamp_pb2.Timestamp(seconds=seconds))

        # A freshly created manager reports every partition as untouched.
        check_manifests({}, {}, none_finished)

        # Allocate example-id syncing; the manager picks the partition.
        manifest = manifest_manager.alloc_sync_exampld_id(rank_id)
        self.assertIsNotNone(manifest)
        partition_id = manifest.partition_id
        syncing = {partition_id: (dj_pb.SyncExampleIdState.Syncing, rank_id)}
        check_manifests(syncing, {}, none_finished)

        # Allocate joining on a different partition (3 - p never equals p).
        partition_id2 = 3 - partition_id
        rank_id2 = 100
        manifest_manager.alloc_join_example(rank_id2, partition_id2)
        joining = {partition_id2: (dj_pb.JoinExampleState.Joining, rank_id2)}
        check_manifests(syncing, joining, none_finished)

        # None of these finish calls is expected to be accepted yet.
        self.assertRaises(Exception, manifest_manager.finish_join_example,
                          rank_id, partition_id)
        self.assertRaises(Exception, manifest_manager.finish_join_example,
                          rank_id2, partition_id2)
        self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                          -rank_id, partition_id)
        self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                          rank_id2, partition_id2)

        # Allocate joining on the partition that is already syncing.
        rank_id3 = 0
        manifest_manager.alloc_join_example(rank_id3, partition_id)
        joining = {
            partition_id: (dj_pb.JoinExampleState.Joining, rank_id3),
            partition_id2: (dj_pb.JoinExampleState.Joining, rank_id2),
        }
        check_manifests(syncing, joining, none_finished)

        # Syncing cannot be finished before the raw data is finished.
        self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                          rank_id, partition_id)

        # A batch with a duplicated entry is rejected without dedup and
        # accepted with it.
        raw_data_metas = [raw_meta('a', 3), raw_meta('a', 3),
                          raw_meta('c', 1)]
        self.assertRaises(Exception, manifest_manager.add_raw_data,
                          partition_id, raw_data_metas, False)
        manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
        latest_ts = manifest_manager.get_raw_date_latest_timestamp(
            partition_id)
        self.assertEqual(latest_ts.seconds, 3)
        self.assertEqual(latest_ts.nanos, 0)
        manifest = manifest_manager.get_manifest(partition_id)
        self.assertEqual(manifest.next_process_index, 2)

        # Re-adding known entries together with new ones only advances the
        # index by the new ones ('b' and 'd').
        raw_data_metas = [raw_meta('a', 3), raw_meta('a', 3),
                          raw_meta('b', 2), raw_meta('c', 1),
                          raw_meta('d', 4)]
        manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
        latest_ts = manifest_manager.get_raw_date_latest_timestamp(
            partition_id)
        self.assertEqual(latest_ts.seconds, 4)
        self.assertEqual(latest_ts.nanos, 0)
        manifest_map = manifest_manager.list_all_manifest()
        for i in range(partition_num):
            self.assertIn(i, manifest_map)
            expected_index = 4 if i == partition_id else 0
            self.assertEqual(manifest_map[i].next_process_index,
                             expected_index)

        # Finishing raw data twice checks that the call can be repeated.
        manifest_manager.finish_raw_data(partition_id)
        manifest_manager.finish_raw_data(partition_id)
        # NOTE(review): this call passes 200 in place of a raw-data-meta
        # list and omits the dedup argument used by the calls above, so it
        # may raise for that reason rather than because the partition is
        # finished -- confirm the intent.
        self.assertRaises(Exception, manifest_manager.add_raw_data,
                          partition_id, 200)
        manifest_manager.finish_sync_example_id(rank_id, partition_id)
        manifest_manager.finish_sync_example_id(rank_id, partition_id)
        synced = {partition_id: (dj_pb.SyncExampleIdState.Synced, rank_id)}
        check_manifests(synced, joining, {partition_id: True})

        manifest_manager.finish_join_example(rank_id3, partition_id)
        manifest_manager.finish_join_example(rank_id3, partition_id)
        joined = {
            partition_id: (dj_pb.JoinExampleState.Joined, rank_id3),
            partition_id2: (dj_pb.JoinExampleState.Joining, rank_id2),
        }
        check_manifests(synced, joined, {})
Пример #15
0
 def cleanup_meta_data(self):
     """Delete this data source's etcd base prefix while holding the lock."""
     with self._lock:
         data_source_name = self._data_source.data_source_meta.name
         # The prefix is derived from the data source name alone.
         etcd_base_key = common.data_source_etcd_base_dir(data_source_name)
         self._etcd.delete_prefix(etcd_base_key)