Code Example #1
 def test_csv_raw_data_visitor(self):
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.raw_data_dir = path.join(path.dirname(path.abspath(__file__)),
                                   "../csv_raw_data")
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(
         common.data_source_etcd_base_dir(
             self.data_source.data_source_meta.name))
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     manifest_manager.add_raw_data(0, [
         dj_pb.RawDataMeta(file_path=path.join(partition_dir,
                                               "test_raw_data.csv"),
                           timestamp=timestamp_pb2.Timestamp(seconds=3))
     ], True)
     raw_data_options = dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                             read_ahead_size=1 << 20)
     rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))
     rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source, 0,
                                           raw_data_options)
     expected_index = 0
     for (index, item) in rdv:
         if index > 0 and index % 1024 == 0:
             print("{} {}".format(index, item.raw_id))
         self.assertEqual(index, expected_index)
         expected_index += 1
     self.assertEqual(expected_index, 4999)
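
Example #1 follows the pattern shared by most snippets on this page: register raw-data files through RawDataManifestManager.add_raw_data as dj_pb.RawDataMeta entries, then read them back in index order through RawDataVisitor. The condensed sketch below restates that flow using only the calls that appear in the snippet; the file path is a placeholder and etcd, data_source and raw_data_options are assumed to be set up exactly as above.

    # Condensed sketch of the flow above (placeholder path; pre-built etcd/data_source/options).
    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(etcd, data_source)
    manifest_manager.add_raw_data(
        0,                                               # partition id
        [dj_pb.RawDataMeta(file_path='partition_0000/test_raw_data.csv',
                           timestamp=timestamp_pb2.Timestamp(seconds=3))],
        True)                                            # dedup
    rdv = raw_data_visitor.RawDataVisitor(etcd, data_source, 0, raw_data_options)
    for index, item in rdv:
        pass  # items arrive with consecutive indices, which the assertions above verify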
Code Example #2
 def test_compressed_raw_data_visitor(self):
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.raw_data_dir = path.join(
             path.dirname(path.abspath(__file__)), "../compressed_raw_data"
         )
     self.kvstore = DBClient('etcd', True)
     self.kvstore.delete_prefix(
         common.data_source_kvstore_base_dir(
             self.data_source.data_source_meta.name))
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
     manifest_manager.add_raw_data(
             0, [dj_pb.RawDataMeta(file_path=path.join(partition_dir, "0-0.idx"),
                                   timestamp=timestamp_pb2.Timestamp(seconds=3))],
             True)
     raw_data_options = dj_pb.RawDataOptions(
             raw_data_iter='TF_RECORD',
             compressed_type='GZIP',
             read_ahead_size=1<<20,
             read_batch_size=128
         )
     rdm = raw_data_visitor.RawDataManager(self.kvstore, self.data_source, 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))
     rdv = raw_data_visitor.RawDataVisitor(self.kvstore, self.data_source, 0,
                                           raw_data_options)
     expected_index = 0
     for (index, item) in rdv:
         if index > 0 and index % 32 == 0:
             print("{} {}".format(index, item.example_id))
         self.assertEqual(index, expected_index)
         expected_index += 1
     self.assertGreater(expected_index, 0)
Code Example #3
 def test_raw_data_visitor(self):
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.data_source.raw_data_dir = "./test/compressed_raw_data"
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(self.data_source.data_source_meta.name)
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = os.path.join(self.data_source.raw_data_dir,
                                  common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     manifest_manager.add_raw_data(
             0, [dj_pb.RawDataMeta(file_path=os.path.join(partition_dir, "0-0.idx"),
                                   timestamp=timestamp_pb2.Timestamp(seconds=3))],
             True)
     raw_data_options = dj_pb.RawDataOptions(
             raw_data_iter='TF_DATASET',
             compressed_type='GZIP'
         )
     rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))
     rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source, 0,
                                           raw_data_options)
     expected_index = 0
     for (index, item) in rdv:
         if index > 0 and index % 32 == 0:
             print("{} {}".format(index, item.example_id))
         self.assertEqual(index, expected_index)
         expected_index += 1
     self.assertGreater(expected_index, 0)
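
Examples #1–#3 differ mainly in the RawDataOptions they pass to RawDataVisitor. For comparison, the three option sets used above, copied from the snippets with all other fields left at their defaults:

    # CSV dictionary iterator with a 1 MiB read-ahead buffer (Example #1)
    csv_opts = dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                    read_ahead_size=1 << 20)
    # GZIP-compressed TFRecord files, read in batches of 128 records (Example #2)
    tfrecord_opts = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                         compressed_type='GZIP',
                                         read_ahead_size=1 << 20,
                                         read_batch_size=128)
    # GZIP-compressed TFRecords via the TF_DATASET iterator (Example #3)
    tf_dataset_opts = dj_pb.RawDataOptions(raw_data_iter='TF_DATASET',
                                           compressed_type='GZIP')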
Code Example #4
 def add_raw_data(self, partition_id, fpaths, dedup, timestamps=None):
     self._check_partition_id(partition_id)
     if not fpaths:
         raise RuntimeError("no files input")
     if timestamps is not None and len(fpaths) != len(timestamps):
         raise RuntimeError("the number of raw data file "\
                            "and timestamp mismatch")
     rdreq = dj_pb.RawDataRequest(
                 data_source_meta=self._data_source.data_source_meta,
                 partition_id=partition_id,
                 added_raw_data_metas=dj_pb.AddedRawDataMetas(
                     dedup=dedup
                 )
             )
     for index, fpath in enumerate(fpaths):
         if not gfile.Exists(fpath):
             raise ValueError('{} is not existed'.format(fpath))
         raw_data_meta = dj_pb.RawDataMeta(
                 file_path=fpath,
                 start_index=-1
             )
         if timestamps is not None:
             raw_data_meta.timestamp.MergeFrom(timestamps[index])
         rdreq.added_raw_data_metas.raw_data_metas.append(raw_data_meta)
     return self._master_client.AddRawData(rdreq)
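
A hedged usage sketch for the method above: the caller supplies one timestamp per file path, or omits timestamps entirely, in which case each RawDataMeta is added without one. `client` and the file paths below are placeholders; only the call signature comes from the snippet.

    # Hypothetical caller of add_raw_data; 'client' is an instance of the class above.
    fpaths = ['/data/partition_0000/0.rd', '/data/partition_0000/1.rd']   # placeholder paths
    timestamps = [timestamp_pb2.Timestamp(seconds=s) for s in (3, 5)]     # one per file
    rsp = client.add_raw_data(partition_id=0, fpaths=fpaths, dedup=True,
                              timestamps=timestamps)
    # The method raises if any path fails the gfile.Exists check above; rsp is the
    # AddRawData response, and the tests on this page treat rsp.code == 0 as success.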
Code Example #5
 def _preload_raw_data_meta(self):
     manifest_kvstore_key = common.partition_manifest_kvstore_key(
         self._data_source.data_source_meta.name, self._partition_id)
     all_metas = []
     index_metas = []
     for key, val in self._kvstore.get_prefix_kvs(manifest_kvstore_key,
                                                  True):
         bkey = os.path.basename(key)
         if not bkey.decode().startswith(common.RawDataMetaPrefix):
             continue
         index = int(bkey[len(common.RawDataMetaPrefix):])
         meta = text_format.Parse(val, dj_pb.RawDataMeta())
         all_metas.append((index, meta))
         if meta.start_index != -1:
             index_meta = visitor.IndexMeta(index, meta.start_index,
                                            meta.file_path)
             index_metas.append(index_meta)
     all_metas = sorted(all_metas, key=lambda meta: meta[0])
     for process_index, meta in enumerate(all_metas):
         if process_index != meta[0]:
             logging.fatal("process_index mismatch with index %d != %d "\
                           "for file path %s", process_index, meta[0],
                           meta[1].file_path)
             traceback.print_stack()
             os._exit(-1)  # pylint: disable=protected-access
     return all_metas, index_metas
Code Example #6
 def _gen_raw_data_file(self, start_index, end_index, no_data=False):
     partition_dir = os.path.join(self.raw_data_dir,
                                  common.partition_repr(0))
     fpaths = []
     for i in range(start_index, end_index):
         if no_data:
             fname = "{}.no_data".format(i)
         else:
             fname = "{}{}".format(i, common.RawDataFileSuffix)
         fpath = os.path.join(partition_dir, fname)
         fpaths.append(
             dj_pb.RawDataMeta(
                 file_path=fpath,
                 timestamp=timestamp_pb2.Timestamp(seconds=3)))
         writer = tf.io.TFRecordWriter(fpath)
         if not no_data:
             for j in range(100):
                 feat = {}
                 example_id = '{}'.format(i * 100 + j).encode()
                 feat['example_id'] = tf.train.Feature(
                     bytes_list=tf.train.BytesList(value=[example_id]))
                 example = tf.train.Example(features=tf.train.Features(
                     feature=feat))
                 writer.write(example.SerializeToString())
         writer.close()
     self.manifest_manager.add_raw_data(0, fpaths, True)
Code Example #7
 def publish_raw_data(self, partition_id, fpaths, timestamps=None):
     if not fpaths:
         logging.warning("no raw data will be published")
         return
     if timestamps is not None and len(fpaths) != len(timestamps):
         raise RuntimeError("the number of raw data file "\
                            "and timestamp mismatch")
     new_raw_data_pubs = []
     for index, fpath in enumerate(fpaths):
         if not gfile.Exists(fpath):
             raise ValueError('{} is not existed'.format(fpath))
         raw_data_pub = dj_pb.RawDatePub(raw_data_meta=dj_pb.RawDataMeta(
             file_path=fpath, start_index=-1))
         if timestamps is not None:
             raw_data_pub.raw_data_meta.timestamp.MergeFrom(
                 timestamps[index])
         new_raw_data_pubs.append(raw_data_pub)
     next_pub_index = None
     item_index = 0
     data = text_format.MessageToString(new_raw_data_pubs[item_index])
     while item_index < len(new_raw_data_pubs):
         next_pub_index = self._forward_pub_index(partition_id,
                                                  next_pub_index)
         etcd_key = common.raw_data_pub_etcd_key(self._raw_data_pub_dir,
                                                 partition_id,
                                                 next_pub_index)
         if self._etcd.cas(etcd_key, None, data):
             logging.info("Success publish %s at index %d for partition"\
                          "%d", data, next_pub_index, partition_id)
             next_pub_index += 1
             item_index += 1
             if item_index < len(new_raw_data_pubs):
                 raw_data_pub = new_raw_data_pubs[item_index]
                 data = text_format.MessageToString(raw_data_pub)
Code Example #8
 def generate_leader_raw_data(self):
     dbm = data_block_manager.DataBlockManager(self.data_source_l, 0)
     raw_data_dir = os.path.join(self.data_source_l.raw_data_dir,
                                 common.partition_repr(0))
     if gfile.Exists(raw_data_dir):
         gfile.DeleteRecursively(raw_data_dir)
     gfile.MakeDirs(raw_data_dir)
     rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source_l, 0)
     block_index = 0
     builder = create_data_block_builder(
         dj_pb.DataBlockBuilderOptions(
             data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
         self.data_source_l.raw_data_dir,
         self.data_source_l.data_source_meta.name, 0, block_index, None)
     process_index = 0
     start_index = 0
     for i in range(0, self.leader_end_index + 3):
         if (i > 0 and i % 2048 == 0) or (i == self.leader_end_index + 2):
             meta = builder.finish_data_block()
             if meta is not None:
                 ofname = common.encode_data_block_fname(
                     self.data_source_l.data_source_meta.name, meta)
                 fpath = os.path.join(raw_data_dir, ofname)
                 self.manifest_manager.add_raw_data(0, [
                     dj_pb.RawDataMeta(
                         file_path=fpath,
                         timestamp=timestamp_pb2.Timestamp(seconds=3))
                 ], False)
                 process_index += 1
                 start_index += len(meta.example_ids)
             block_index += 1
             builder = create_data_block_builder(
                 dj_pb.DataBlockBuilderOptions(
                     data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
                 self.data_source_l.raw_data_dir,
                 self.data_source_l.data_source_meta.name, 0, block_index,
                 None)
         feat = {}
         pt = i + 1 << 30
         if i % 3 == 0:
             pt = i // 3
         example_id = '{}'.format(pt).encode()
         feat['example_id'] = tf.train.Feature(
             bytes_list=tf.train.BytesList(value=[example_id]))
         event_time = 150000000 + pt
         feat['event_time'] = tf.train.Feature(
             int64_list=tf.train.Int64List(value=[event_time]))
         example = tf.train.Example(features=tf.train.Features(
             feature=feat))
         builder.append_record(example.SerializeToString(), example_id,
                               event_time, i, i)
     fpaths = [
         os.path.join(raw_data_dir, f)
         for f in gfile.ListDirectory(raw_data_dir)
         if not gfile.IsDirectory(os.path.join(raw_data_dir, f))
     ]
     for fpath in fpaths:
         if not fpath.endswith(common.DataBlockSuffix):
             gfile.Remove(fpath)
Code Example #9
File: raw_data_visitor.py  Project: feiga/fedlearner
 def _sync_raw_data_meta(self, process_index):
     etcd_key = common.raw_data_meta_etcd_key(
         self._data_source.data_source_meta.name, self._partition_id,
         process_index)
     data = self._etcd.get_data(etcd_key)
     if data is not None:
         return text_format.Parse(data, dj_pb.RawDataMeta())
     return None
Code Example #10
 def _sync_raw_data_meta(self, process_index):
     kvstore_key = common.raw_data_meta_kvstore_key(
         self._data_source.data_source_meta.name, self._partition_id,
         process_index)
     data = self._kvstore.get_data(kvstore_key)
     if data is not None:
         return text_format.Parse(data,
                                  dj_pb.RawDataMeta(),
                                  allow_unknown_field=True)
     return None
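
Examples #9 and #10 are the same lookup written against etcd and against the generic kvstore client respectively; the kvstore variant also passes allow_unknown_field=True so that metas written by a newer proto definition still parse. Both store RawDataMeta values in protobuf text format, as the round-trip sketch below illustrates (dj_pb as in the snippets above).

    # Text-format round trip of the value stored under the raw-data-meta key.
    from google.protobuf import text_format

    meta = dj_pb.RawDataMeta(file_path='a', start_index=-1)
    data = text_format.MessageToString(meta)                 # what gets written to etcd/kvstore
    parsed = text_format.Parse(data, dj_pb.RawDataMeta())    # what _sync_raw_data_meta returns
    assert parsed.file_path == 'a' and parsed.start_index == -1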
Code Example #11
 def _new_index_meta(self, process_index, start_index):
     if self._manifest.next_process_index <= process_index:
         return None
     raw_data_meta = None
     if process_index < len(self._all_metas):
         assert process_index == self._all_metas[process_index][0], \
             "process index should equal {} != {}".format(
                 process_index, self._all_metas[process_index][0]
             )
         raw_data_meta = self._all_metas[process_index][1]
     else:
         assert process_index == len(self._all_metas), \
             "the process index should be the next all metas "\
             "{}(process_index) != {}(size of all_metas)".format(
                     process_index, len(self._all_metas)
                 )
         raw_data_meta = self._sync_raw_data_meta(process_index)
         if raw_data_meta is None:
             logging.fatal("the raw data of partition %d index with "\
                           "%d must in etcd",
                           self._partition_id, process_index)
             traceback.print_stack()
             os._exit(-1) # pylint: disable=protected-access
         self._all_metas.append((process_index, raw_data_meta))
     if raw_data_meta.start_index == -1:
         new_meta = dj_pb.RawDataMeta()
         new_meta.MergeFrom(raw_data_meta)
         new_meta.start_index = start_index
         odata = text_format.MessageToString(raw_data_meta)
         ndata = text_format.MessageToString(new_meta)
         etcd_key = common.raw_data_meta_etcd_key(
                 self._data_source.data_source_meta.name,
                 self._partition_id, process_index
             )
         if not self._etcd.cas(etcd_key, odata, ndata):
             raw_data_meta = self._sync_raw_data_meta(process_index)
             assert raw_data_meta is not None, \
                 "the raw data meta of process index {} "\
                 "must not None".format(process_index)
             if raw_data_meta.start_index != start_index:
                 logging.fatal("raw data of partition %d index with "\
                               "%d must start with %d",
                               self._partition_id, process_index,
                               start_index)
                 traceback.print_stack()
                 os._exit(-1) # pylint: disable=protected-access
     return visitor.IndexMeta(process_index, start_index,
                              raw_data_meta.file_path)
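
The start_index update above is a compare-and-swap: the new meta is written only if the stored text has not changed, and on a lost race the value a concurrent writer persisted is re-read and must carry the same start_index. A minimal sketch of that pattern, assuming a store exposing the cas(key, old, new) and get_data(key) calls used in the snippet (the helper name is hypothetical):

    # CAS sketch; 'kvstore' provides cas()/get_data(), dj_pb/text_format as in the snippets above.
    def persist_start_index(kvstore, key, meta, start_index):
        new_meta = dj_pb.RawDataMeta()
        new_meta.MergeFrom(meta)
        new_meta.start_index = start_index
        if kvstore.cas(key, text_format.MessageToString(meta),
                       text_format.MessageToString(new_meta)):
            return new_meta
        # Lost the race: another writer stored the meta first; its start_index must agree.
        stored = text_format.Parse(kvstore.get_data(key), dj_pb.RawDataMeta())
        assert stored.start_index == start_index, "conflicting start_index"
        return stored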
Code Example #12
 def __init__(self, etcd, raw_data_options, mock_data_source_name,
              input_fpaths):
     mock_data_source = common_pb.DataSource(
         state=common_pb.DataSourceState.Processing,
         data_source_meta=common_pb.DataSourceMeta(
             name=mock_data_source_name, partition_num=1))
     mock_rd_manifest_manager = RawDataManifestManager(
         etcd, mock_data_source)
     manifest = mock_rd_manifest_manager.get_manifest(0)
     if not manifest.finished:
         metas = []
         for fpath in input_fpaths:
             metas.append(dj_pb.RawDataMeta(file_path=fpath,
                                            start_index=-1))
         mock_rd_manifest_manager.add_raw_data(0, metas, True)
         mock_rd_manifest_manager.finish_raw_data(0)
     super(MockRawDataVisitor, self).__init__(etcd, mock_data_source, 0,
                                              raw_data_options)
Code Example #13
 def publish_raw_data(self, partition_id, fpaths, timestamps=None):
     if not fpaths:
         logging.warning("no raw data will be published")
         return
     if timestamps is not None and len(fpaths) != len(timestamps):
         raise RuntimeError("the number of raw data file "\
                            "and timestamp mismatch")
     new_raw_data_pubs = []
     for index, fpath in enumerate(fpaths):
         if not gfile.Exists(fpath):
             raise ValueError('{} is not existed'.format(fpath))
         raw_data_pub = dj_pb.RawDatePub(raw_data_meta=dj_pb.RawDataMeta(
             file_path=fpath, start_index=-1))
         if timestamps is not None:
             raw_data_pub.raw_data_meta.timestamp.MergeFrom(
                 timestamps[index])
         new_raw_data_pubs.append(raw_data_pub)
     next_pub_index = None
     item_index = 0
     data = text_format.MessageToString(new_raw_data_pubs[item_index])
     while item_index < len(new_raw_data_pubs):
         next_pub_index = self._forward_pub_index(partition_id,
                                                  next_pub_index)
         if self._check_finish_tag(partition_id, next_pub_index - 1):
             logging.warning("partition %d has been published finish tag "\
                             "at index %d", partition_id, next_pub_index-1)
             break
         kvstore_key = common.raw_data_pub_kvstore_key(
             self._raw_data_pub_dir, partition_id, next_pub_index)
         if self._kvstore.cas(kvstore_key, None, data):
             logging.info("Success publish %s at index %d for partition"\
                          "%d", data, next_pub_index, partition_id)
             next_pub_index += 1
             item_index += 1
             if item_index < len(new_raw_data_pubs):
                 raw_data_pub = new_raw_data_pubs[item_index]
                 data = text_format.MessageToString(raw_data_pub)
     if item_index < len(new_raw_data_pubs) - 1:
         logging.warning("%d files are not published since meet finish "\
                         "tag for partition %d. list following",
                         len(new_raw_data_pubs) - item_index, partition_id)
         for idx, pub in enumerate(new_raw_data_pubs[item_index:]):
             logging.warning("%d. %s", idx, pub.raw_data_meta.file_path)
Code Example #14
 def _process_next_process_index(self, partition_id, manifest):
     assert manifest is not None and manifest.partition_id == partition_id
     next_process_index = manifest.next_process_index
     while True:
         meta_etcd_key = \
                 common.raw_data_meta_etcd_key(
                         self._data_source.data_source_meta.name,
                         partition_id, next_process_index
                     )
         data = self._etcd.get_data(meta_etcd_key)
         if data is None:
             break
         meta = text_format.Parse(data, dj_pb.RawDataMeta())
         self._existed_fpath[meta.file_path] = \
                 (partition_id, next_process_index)
         self._update_raw_data_latest_timestamp(partition_id,
                                                meta.timestamp)
         next_process_index += 1
     if next_process_index != manifest.next_process_index:
         manifest.next_process_index = next_process_index
         self._update_manifest(manifest)
     else:
         self._local_manifest[partition_id] = manifest
Code Example #15
 def generate_raw_data(self, begin_index, item_count):
     raw_data_dir = os.path.join(self.raw_data_dir,
                                 common.partition_repr(0))
     if not gfile.Exists(raw_data_dir):
         gfile.MakeDirs(raw_data_dir)
     self.total_raw_data_count += item_count
     useless_index = 0
     rdm = raw_data_visitor.RawDataManager(self.kvstore, self.data_source,
                                           0)
     fpaths = []
     for block_index in range(0, item_count // 2048):
         builder = DataBlockBuilder(
             self.raw_data_dir,
             self.data_source.data_source_meta.name, 0, block_index,
             dj_pb.WriterOptions(output_writer='TF_RECORD'), None)
         cands = list(
             range(begin_index + block_index * 2048,
                   begin_index + (block_index + 1) * 2048))
         start_index = cands[0]
         for i in range(len(cands)):
             if random.randint(1, 4) > 2:
                 continue
             a = random.randint(i - 32, i + 32)
             b = random.randint(i - 32, i + 32)
             if a < 0:
                 a = 0
             if a >= len(cands):
                 a = len(cands) - 1
             if b < 0:
                 b = 0
             if b >= len(cands):
                 b = len(cands) - 1
             if (abs(cands[a] - i - start_index) <= 32
                     and abs(cands[b] - i - start_index) <= 32):
                 cands[a], cands[b] = cands[b], cands[a]
         for example_idx in cands:
             feat = {}
             example_id = '{}'.format(example_idx).encode()
             feat['example_id'] = tf.train.Feature(
                 bytes_list=tf.train.BytesList(value=[example_id]))
             event_time = 150000000 + example_idx
             feat['event_time'] = tf.train.Feature(
                 int64_list=tf.train.Int64List(value=[event_time]))
             label = random.choice([1, 0])
             if random.random() < 0.8:
                 feat['label'] = tf.train.Feature(
                     int64_list=tf.train.Int64List(value=[label]))
             example = tf.train.Example(features=tf.train.Features(
                 feature=feat))
             builder.append_item(TfExampleItem(example.SerializeToString()),
                                 useless_index, useless_index)
             useless_index += 1
         meta = builder.finish_data_block()
         fname = common.encode_data_block_fname(
             self.data_source.data_source_meta.name, meta)
         fpath = os.path.join(raw_data_dir, fname)
         fpaths.append(
             dj_pb.RawDataMeta(
                 file_path=fpath,
                 timestamp=timestamp_pb2.Timestamp(seconds=3)))
         self.g_data_block_index += 1
     all_files = [
         os.path.join(raw_data_dir, f)
         for f in gfile.ListDirectory(raw_data_dir)
         if not gfile.IsDirectory(os.path.join(raw_data_dir, f))
     ]
     for fpath in all_files:
         if not fpath.endswith(common.DataBlockSuffix):
             gfile.Remove(fpath)
     self.manifest_manager.add_raw_data(0, fpaths, False)
Code Example #16
    def test_api(self):
        logging.getLogger().setLevel(logging.DEBUG)
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.output_base_dir = "./ds_output_l"
        data_source_f = common_pb.DataSource()
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.output_base_dir = "./ds_output_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 1
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f, data_source_name,
            etcd_name, etcd_base_dir_l, etcd_addrs, options)
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs, options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = client_l.GetDataSourceStatus(req_l)
            dss_f = client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
            else:
                time.sleep(2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=-1,
            join_example=empty_pb2.Empty())

        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        #check idempotent
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            join_example=empty_pb2.Empty())
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)
        #check idempotent
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=1,
            partition_id=-1,
            sync_example_id=empty_pb2.Empty())
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)
        #check idempotent
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=1,
            partition_id=0,
            sync_example_id=empty_pb2.Empty())
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)
        #check idempotent
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq1 = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=1,
            partition_id=0,
            sync_example_id=empty_pb2.Empty())

        try:
            rsp = client_l.FinishJoinPartition(rdreq1)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rdreq2 = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            join_example=empty_pb2.Empty())
        try:
            rsp = client_l.FinishJoinPartition(rdreq2)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 0)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ]))
        rsp = client_l.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 1)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 3)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=5))
                ]))
        rsp = client_l.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 2)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 5)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=5)),
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ]))
        rsp = client_l.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 2)
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 5)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 0)
        self.assertEqual(rsp.timestamp.nanos, 0)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=1))
                ]))
        rsp = client_f.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 1)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 1)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=2))
                ]))
        rsp = client_f.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 2)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 2)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=1)),
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=2))
                ]))
        rsp = client_f.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 2)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 2)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            finish_raw_data=empty_pb2.Empty())
        rsp = client_l.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_l.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)

        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertTrue(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='x',
                        timestamp=timestamp_pb2.Timestamp(seconds=4))
                ]))
        try:
            rsp = client_l.AddRawData(rdreq)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        try:
            rsp = client_f.FinishJoinPartition(rdreq2)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rsp = client_l.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_l.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)

        rsp = client_f.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_f.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)

        try:
            rsp = client_f.FinishJoinPartition(rdreq2)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            finish_raw_data=empty_pb2.Empty())
        rsp = client_f.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_f.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)

        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertTrue(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='x',
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ]))
        try:
            rsp = client_f.AddRawData(rdreq)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rsp = client_f.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_f.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)

        rsp = client_l.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_l.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)

        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = client_l.GetDataSourceStatus(req_l)
            dss_f = client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Finished and \
                    dss_f.state == common_pb.DataSourceState.Finished:
                break
            else:
                time.sleep(2)

        master_l.stop()
        master_f.stop()
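
Every AddRawData call in the test above builds the same request shape; the condensed form below pulls it out, with a placeholder file path and the names used in the test. The dedup flag is what the later assertions exercise: with dedup=True, re-adding an already-registered file_path leaves next_process_index unchanged, while Example #17 shows the non-dedup variant rejecting duplicate paths.

    # Condensed AddRawData request shape used throughout the test above
    # ('x.rd' is a placeholder path; data_source_l / client_l as in the test).
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,   # True: file_paths that are already registered are skipped
            raw_data_metas=[
                dj_pb.RawDataMeta(file_path='x.rd',
                                  timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    rsp = client_l.AddRawData(rdreq)   # rsp.code == 0 on success, per the assertions above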
Code Example #17
    def test_raw_data_manifest_manager(self):
        cli = mysql_client.DBClient('test_cluster', 'localhost:2379',
                                    'test_user', 'test_password', 'fedlearner',
                                    True)
        partition_num = 4
        rank_id = 2
        data_source = common_pb.DataSource()
        data_source.data_source_meta.name = "milestone-x"
        data_source.data_source_meta.partition_num = partition_num
        data_source.role = common_pb.FLRole.Leader
        cli.delete_prefix(
            common.data_source_kvstore_base_dir(
                data_source.data_source_meta.name))
        manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            cli, data_source)
        manifest_map = manifest_manager.list_all_manifest()
        for i in range(partition_num):
            self.assertTrue(i in manifest_map)
            self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                             dj_pb.SyncExampleIdState.UnSynced)
            self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id, -1)
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.UnJoined)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)
            self.assertFalse(manifest_map[i].finished)

        manifest = manifest_manager.alloc_sync_exampld_id(rank_id)
        self.assertNotEqual(manifest, None)
        partition_id = manifest.partition_id
        manifest_map = manifest_manager.list_all_manifest()
        for i in range(partition_num):
            self.assertTrue(i in manifest_map)
            if i != partition_id:
                self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                                 dj_pb.SyncExampleIdState.UnSynced)
                self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                                 -1)
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.UnJoined)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)
            else:
                self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                                 dj_pb.SyncExampleIdState.Syncing)
                self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                                 rank_id)
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.UnJoined)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)
            self.assertFalse(manifest_map[i].finished)

        partition_id2 = 3 - partition_id
        rank_id2 = 100
        manifest = manifest_manager.alloc_join_example(rank_id2, partition_id2)
        manifest_map = manifest_manager.list_all_manifest()
        for i in range(partition_num):
            self.assertTrue(i in manifest_map)
            if i == partition_id:
                self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                                 dj_pb.SyncExampleIdState.Syncing)
                self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                                 rank_id)
            else:
                self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                                 dj_pb.SyncExampleIdState.UnSynced)
                self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                                 -1)
            if i == partition_id2:
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.Joining)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                                 rank_id2)
            else:
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.UnJoined)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)
            self.assertFalse(manifest_map[i].finished)

        self.assertRaises(Exception, manifest_manager.finish_join_example,
                          rank_id, partition_id)
        self.assertRaises(Exception, manifest_manager.finish_join_example,
                          rank_id2, partition_id2)
        self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                          -rank_id, partition_id)
        self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                          rank_id2, partition_id2)
        rank_id3 = 0
        manifest = manifest_manager.alloc_join_example(rank_id3, partition_id)
        manifest_map = manifest_manager.list_all_manifest()
        for i in range(partition_num):
            self.assertTrue(i in manifest_map)
            if i == partition_id:
                self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                                 dj_pb.SyncExampleIdState.Syncing)
                self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                                 rank_id)
            else:
                self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                                 dj_pb.SyncExampleIdState.UnSynced)
                self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                                 -1)
            if i == partition_id:
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.Joining)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                                 rank_id3)
            elif i == partition_id2:
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.Joining)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                                 rank_id2)
            else:
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.UnJoined)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)
            self.assertFalse(manifest_map[i].finished)

        self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                          rank_id, partition_id)
        raw_data_metas = [
            dj_pb.RawDataMeta(file_path='a',
                              timestamp=timestamp_pb2.Timestamp(seconds=3)),
            dj_pb.RawDataMeta(file_path='a',
                              timestamp=timestamp_pb2.Timestamp(seconds=3)),
            dj_pb.RawDataMeta(file_path='c',
                              timestamp=timestamp_pb2.Timestamp(seconds=1))
        ]
        self.assertRaises(Exception, manifest_manager.add_raw_data,
                          partition_id, raw_data_metas, False)
        manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
        latest_ts = manifest_manager.get_raw_date_latest_timestamp(
            partition_id)
        self.assertEqual(latest_ts.seconds, 3)
        self.assertEqual(latest_ts.nanos, 0)
        manifest = manifest_manager.get_manifest(partition_id)
        self.assertEqual(manifest.next_process_index, 2)
        raw_data_metas = [
            dj_pb.RawDataMeta(file_path='a',
                              timestamp=timestamp_pb2.Timestamp(seconds=3)),
            dj_pb.RawDataMeta(file_path='a',
                              timestamp=timestamp_pb2.Timestamp(seconds=3)),
            dj_pb.RawDataMeta(file_path='b',
                              timestamp=timestamp_pb2.Timestamp(seconds=2)),
            dj_pb.RawDataMeta(file_path='c',
                              timestamp=timestamp_pb2.Timestamp(seconds=1)),
            dj_pb.RawDataMeta(file_path='d',
                              timestamp=timestamp_pb2.Timestamp(seconds=4))
        ]
        manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
        latest_ts = manifest_manager.get_raw_date_latest_timestamp(
            partition_id)
        self.assertEqual(latest_ts.seconds, 4)
        self.assertEqual(latest_ts.nanos, 0)
        manifest_map = manifest_manager.list_all_manifest()
        for i in range(partition_num):
            self.assertTrue(i in manifest_map)
            if i == partition_id:
                self.assertEqual(manifest_map[i].next_process_index, 4)
            else:
                self.assertEqual(manifest_map[i].next_process_index, 0)
        manifest_manager.finish_raw_data(partition_id)
        manifest_manager.finish_raw_data(partition_id)
        self.assertRaises(Exception, manifest_manager.add_raw_data,
                          partition_id, 200)
        manifest_manager.finish_sync_example_id(rank_id, partition_id)
        manifest_manager.finish_sync_example_id(rank_id, partition_id)
        manifest_map = manifest_manager.list_all_manifest()
        for i in range(data_source.data_source_meta.partition_num):
            self.assertTrue(i in manifest_map)
            if i == partition_id:
                self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                                 dj_pb.SyncExampleIdState.Synced)
                self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                                 rank_id)
                self.assertTrue(manifest_map[i].finished)
            else:
                self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                                 dj_pb.SyncExampleIdState.UnSynced)
                self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                                 -1)
            if i == partition_id:
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.Joining)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                                 rank_id3)
            elif i == partition_id2:
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.Joining)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                                 rank_id2)
            else:
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.UnJoined)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)

        manifest_manager.finish_join_example(rank_id3, partition_id)
        manifest_manager.finish_join_example(rank_id3, partition_id)
        manifest_map = manifest_manager.list_all_manifest()
        for i in range(data_source.data_source_meta.partition_num):
            self.assertTrue(i in manifest_map)
            if i == partition_id:
                self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                                 dj_pb.SyncExampleIdState.Synced)
                self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                                 rank_id)
            else:
                self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                                 dj_pb.SyncExampleIdState.UnSynced)
                self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                                 -1)
            if i == partition_id:
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.Joined)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                                 rank_id3)
            elif i == partition_id2:
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.Joining)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                                 rank_id2)
            else:
                self.assertEqual(manifest_map[i].join_example_rep.state,
                                 dj_pb.JoinExampleState.UnJoined)
                self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)

        cli.destroy_client_pool()