Esempio n. 1
0
 def test_raw_data_visitor(self):
     """Visit a TF_RECORD partition, probe seeking past it, then resume at 50.

     Allocates partition 0 to rank 2 through the manifest manager, iterates
     the visitor checking that indices are consecutive from 0 and that each
     item's ``example_id`` is its index encoded as a decimal string, requires
     ``seek(200)`` to raise ``StopIteration`` with the visitor left at index
     199, and finally seeks to index 50 and checks the items that follow it.
     """
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     manifest, finished = manifest_manager.alloc_unallocated_partition(2)
     self.assertFalse(finished)
     self.assertEqual(manifest.partition_id, 0)
     self.assertEqual(manifest.state, dj_pb.RawDataState.Syncing)
     self.assertEqual(manifest.allocated_rank_id, 2)
     # NOTE(review): presumably selects the iterator type globally, since the
     # visitor below is constructed without an options argument — confirm.
     customized_options.set_raw_data_iter('TF_RECORD')
     rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source, 0)
     expected_index = 0
     for (index, item) in rdv:
         self.assertEqual(index, expected_index)
         expected_index += 1
         self.assertEqual(item.example_id, '{}'.format(index).encode())
     # assertRaises makes the test fail when seek(200) does *not* raise; a
     # try/except/else that ends in assertFalse(False) would pass silently.
     with self.assertRaises(StopIteration):
         rdv.seek(200)
     self.assertEqual(rdv.get_current_index(), 199)
     index, item = rdv.seek(50)
     self.assertEqual(index, 50)
     self.assertEqual(item.example_id, '{}'.format(index).encode())
     # Iteration resumes immediately after the seek target.
     expected_index = index + 1
     for (index, item) in rdv:
         self.assertEqual(index, expected_index)
         expected_index += 1
         self.assertEqual(item.example_id, '{}'.format(index).encode())
Esempio n. 2
0
 def test_csv_raw_data_visitor(self):
     """Visit a CSV_DICT partition registered through the manifest manager.

     Rebuilds ``self.data_source`` with a single partition, clears the etcd
     prefix returned by ``common.data_source_etcd_base_dir`` for it,
     registers ``test_raw_data.csv`` from the ``csv_raw_data`` directory next
     to this test file as raw data of partition 0, and checks that the
     visitor yields consecutive indices from 0, 4999 rows in total.
     """
     self.data_source = common_pb.DataSource()
     source_meta = self.data_source.data_source_meta
     source_meta.name = 'fclh_test'
     source_meta.partition_num = 1
     test_dir = path.dirname(path.abspath(__file__))
     self.raw_data_dir = path.join(test_dir, "../csv_raw_data")
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     etcd_base_dir = common.data_source_etcd_base_dir(source_meta.name)
     self.etcd.delete_prefix(etcd_base_dir)
     self.assertEqual(source_meta.partition_num, 1)
     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     csv_file_meta = dj_pb.RawDataMeta(
         file_path=path.join(partition_dir, "test_raw_data.csv"),
         timestamp=timestamp_pb2.Timestamp(seconds=3))
     manifest_manager.add_raw_data(0, [csv_file_meta], True)
     visitor_options = dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                            read_ahead_size=1 << 20)
     raw_data_mgr = raw_data_visitor.RawDataManager(self.etcd,
                                                    self.data_source, 0)
     self.assertTrue(raw_data_mgr.check_index_meta_by_process_index(0))
     visitor = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source,
                                               0, visitor_options)
     visited = 0
     for row_index, row in visitor:
         # Occasional progress output every 1024 rows.
         if row_index > 0 and row_index % 1024 == 0:
             print("{} {}".format(row_index, row.raw_id))
         self.assertEqual(row_index, visited)
         visited += 1
     self.assertEqual(visited, 4999)
 def test_raw_data_visitor(self):
     """Visit a GZIP-compressed partition through the TF_DATASET iterator.

     Points ``self.data_source.raw_data_dir`` at the relative path
     ``./test/compressed_raw_data`` (resolved against the working directory),
     deletes etcd keys under the data source name, registers ``0-0.idx`` of
     partition 0 as raw data, and checks that the visitor yields consecutive
     indices starting at 0 and at least one item.
     """
     self.data_source = common_pb.DataSource()
     source_meta = self.data_source.data_source_meta
     source_meta.name = 'fclh_test'
     source_meta.partition_num = 1
     self.data_source.raw_data_dir = "./test/compressed_raw_data"
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(source_meta.name)
     self.assertEqual(source_meta.partition_num, 1)
     partition_dir = os.path.join(self.data_source.raw_data_dir,
                                  common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     index_file_meta = dj_pb.RawDataMeta(
         file_path=os.path.join(partition_dir, "0-0.idx"),
         timestamp=timestamp_pb2.Timestamp(seconds=3))
     manifest_manager.add_raw_data(0, [index_file_meta], True)
     visitor_options = dj_pb.RawDataOptions(raw_data_iter='TF_DATASET',
                                            compressed_type='GZIP')
     raw_data_mgr = raw_data_visitor.RawDataManager(self.etcd,
                                                    self.data_source, 0)
     self.assertTrue(raw_data_mgr.check_index_meta_by_process_index(0))
     visitor = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source,
                                               0, visitor_options)
     visited = 0
     for item_index, item in visitor:
         # Occasional progress output every 32 items.
         if item_index > 0 and item_index % 32 == 0:
             print("{} {}".format(item_index, item.example_id))
         self.assertEqual(item_index, visited)
         visited += 1
     self.assertGreater(visited, 0)
Esempio n. 4
0
 def test_raw_data_visitor(self):
     """Visit a GZIP partition configured through ``CustomizedOptions``.

     Uses the relative directory ``./test/compressed_raw_data`` (resolved
     against the working directory) and its ``partition_0`` subdirectory,
     deletes etcd keys under the data source name, and checks that a
     TF_DATASET visitor over partition 0 yields consecutive indices starting
     at 0 and at least one item.
     """
     self.data_source = common_pb.DataSource()
     source_meta = self.data_source.data_source_meta
     source_meta.name = 'fclh_test'
     source_meta.partition_num = 1
     self.data_source.raw_data_dir = "./test/compressed_raw_data"
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner')
     self.etcd.delete_prefix(source_meta.name)
     self.assertEqual(source_meta.partition_num, 1)
     partition_dir = os.path.join(self.data_source.raw_data_dir,
                                  'partition_0')
     self.assertTrue(gfile.Exists(partition_dir))
     # Not referenced below. NOTE(review): presumably constructed for its
     # side effects on the data source's etcd state — confirm before removal.
     _manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     visitor_options = customized_options.CustomizedOptions()
     visitor_options.set_raw_data_iter('TF_DATASET')
     visitor_options.set_compressed_type('GZIP')
     visitor = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source,
                                               0, visitor_options)
     visited = 0
     for item_index, item in visitor:
         # Occasional progress output every 1024 items.
         if item_index > 0 and item_index % 1024 == 0:
             print("{} {} {}".format(item_index, item.example_id,
                                     item.event_time))
         self.assertEqual(item_index, visited)
         visited += 1
     self.assertGreater(visited, 0)
 def test_compressed_raw_data_visitor(self):
     """Visit a GZIP-compressed TF_RECORD partition stored via a KV store.

     Rebuilds ``self.data_source`` with one partition, clears the prefix
     returned by ``common.data_source_kvstore_base_dir`` in an etcd-backed
     ``DBClient``, registers ``0-0.idx`` from the ``compressed_raw_data``
     directory next to this test file as raw data of partition 0, and checks
     that the visitor yields consecutive indices starting at 0 and at least
     one item.
     """
     self.data_source = common_pb.DataSource()
     source_meta = self.data_source.data_source_meta
     source_meta.name = 'fclh_test'
     source_meta.partition_num = 1
     test_dir = path.dirname(path.abspath(__file__))
     self.raw_data_dir = path.join(test_dir, "../compressed_raw_data")
     self.kvstore = DBClient('etcd', True)
     kvstore_base_dir = common.data_source_kvstore_base_dir(source_meta.name)
     self.kvstore.delete_prefix(kvstore_base_dir)
     self.assertEqual(source_meta.partition_num, 1)
     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
     index_file_meta = dj_pb.RawDataMeta(
         file_path=path.join(partition_dir, "0-0.idx"),
         timestamp=timestamp_pb2.Timestamp(seconds=3))
     manifest_manager.add_raw_data(0, [index_file_meta], True)
     visitor_options = dj_pb.RawDataOptions(
         raw_data_iter='TF_RECORD',
         compressed_type='GZIP',
         read_ahead_size=1 << 20,
         read_batch_size=128,
     )
     raw_data_mgr = raw_data_visitor.RawDataManager(self.kvstore,
                                                    self.data_source, 0)
     self.assertTrue(raw_data_mgr.check_index_meta_by_process_index(0))
     visitor = raw_data_visitor.RawDataVisitor(self.kvstore,
                                               self.data_source, 0,
                                               visitor_options)
     visited = 0
     for item_index, item in visitor:
         # Occasional progress output every 32 items.
         if item_index > 0 and item_index % 32 == 0:
             print("{} {}".format(item_index, item.example_id))
         self.assertEqual(item_index, visited)
         visited += 1
     self.assertGreater(visited, 0)
Esempio n. 6
0
 def test_raw_data_visitor(self):
     """Exercise a TF_RECORD visitor as raw data appears, is re-read and reset.

     Allocates example-id syncing for rank 2 (expected on partition 0), then
     checks the visitor's ``finished``/``is_visitor_stale`` flags and ``seek``
     behaviour before any data exists, after ``_gen_raw_data_file(0, 2)``
     (200 items expected) and after ``_gen_raw_data_file(2, 4)`` (400 items
     expected), including a seek back to index 50, a ``reset`` and a second,
     freshly constructed visitor over the same partition.
     """
     rank_id = 2
     manifest = self.manifest_manager.alloc_sync_exampld_id(rank_id)
     self.assertEqual(manifest.partition_id, 0)
     sync_rep = manifest.sync_example_id_rep
     self.assertEqual(sync_rep.state, dj_pb.SyncExampleIdState.Syncing)
     self.assertEqual(sync_rep.rank_id, rank_id)
     visitor_options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD')

     def new_visitor():
         # Build a visitor over the allocated partition.
         return raw_data_visitor.RawDataVisitor(self.etcd, self.data_source,
                                                manifest.partition_id,
                                                visitor_options)

     def consume(visitor, first_index):
         # Drain ``visitor``, requiring consecutive indices starting at
         # ``first_index`` and example_ids equal to the encoded index;
         # return the index one past the last item seen.
         next_index = first_index
         for item_index, item in visitor:
             self.assertEqual(item_index, next_index)
             next_index += 1
             self.assertEqual(item.example_id,
                              '{}'.format(item_index).encode())
         return next_index

     visitor = new_visitor()
     # Before any raw data file exists nothing can be sought.
     with self.assertRaises(StopIteration):
         visitor.seek(0)
     self.assertTrue(visitor.finished())
     self.assertFalse(visitor.is_visitor_stale())

     # After generating files the visitor reports stale, and seek(0) keeps
     # raising until active_visitor() is called.
     self._gen_raw_data_file(0, 2)
     self.assertTrue(visitor.is_visitor_stale())
     with self.assertRaises(StopIteration):
         visitor.seek(0)
     visitor.active_visitor()
     self.assertFalse(visitor.finished())
     self.assertEqual(consume(visitor, 0), 200)
     with self.assertRaises(StopIteration):
         visitor.seek(200)
     self.assertTrue(visitor.finished())

     # Seeking back to 50 resumes iteration right after the seek target.
     seek_index, seek_item = visitor.seek(50)
     self.assertEqual(seek_index, 50)
     self.assertEqual(seek_item.example_id, '{}'.format(seek_index).encode())
     self.assertFalse(visitor.finished())
     next_index = consume(visitor, seek_index + 1)
     self.assertEqual(next_index, 200)

     # A second batch of files extends the same visitor from index 200.
     self._gen_raw_data_file(2, 4)
     self.assertTrue(visitor.is_visitor_stale())
     self.assertTrue(visitor.finished())
     visitor.active_visitor()
     self.assertFalse(visitor.finished())
     self.assertEqual(consume(visitor, next_index), 400)
     self.assertTrue(visitor.finished())

     # After reset() iteration starts again from index 0.
     visitor.reset()
     self.assertFalse(visitor.finished())
     self.assertEqual(consume(visitor, 0), 400)
     self.assertTrue(visitor.finished())

     # An independently constructed visitor also sees all 400 items.
     second_visitor = new_visitor()
     self.assertEqual(consume(second_visitor, 0), 400)
     self.assertTrue(second_visitor.finished())