Code Example #1
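    # Verify that DataBlockVisitor only returns dumped data blocks whose time
    # range overlaps [start_time, end_time], then spot-check random lookups by
    # partition and block index (indices past dumped_index must return None).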
    def _test_round(self, dumped_index, start_time, end_time):
        partition_num = self.data_source.data_source_meta.partition_num
        for i in range(partition_num):
            self.manifest_manager.forward_peer_dumped_index(i, dumped_index)
        visitor = data_block_visitor.DataBlockVisitor(
            self.data_source.data_source_meta.name, self.db_database,
            self.db_base_dir, self.db_addr, self.db_username, self.db_password,
            True)
        reps = visitor.LoadDataBlockRepByTimeFrame(start_time, end_time)
        metas = [
            meta for meta in self.data_block_matas
            if (not (meta.start_time > end_time or meta.end_time < start_time)
                and meta.data_block_index <= dumped_index)
        ]
        self.assertEqual(len(reps), len(metas))
        for meta in metas:
            self.assertTrue(meta.block_id in reps)
            rep = reps[meta.block_id]
            self.assertEqual(meta.block_id, rep.block_id)
            self.assertEqual(meta.start_time, rep.start_time)
            self.assertEqual(meta.end_time, rep.end_time)
            self.assertEqual(meta.partition_id, rep.partition_id)
            self.assertEqual(meta, rep.data_block_meta)
            data_block_fpath = os.path.join(
                common.data_source_data_block_dir(self.data_source),
                common.partition_repr(meta.partition_id),
                meta.block_id + common.DataBlockSuffix)
            self.assertEqual(data_block_fpath, rep.data_block_fpath)

        for i in range(0, 100):
            rep = visitor.LoadDataBlockReqByIndex(
                random.randint(0, partition_num - 1),
                random.randint(0, dumped_index))
            matched = [meta for meta in self.data_block_matas
                       if meta.block_id == rep.block_id]
            self.assertTrue(len(matched) > 0)
            meta = matched[0]
            self.assertEqual(meta.block_id, rep.block_id)
            self.assertEqual(meta.start_time, rep.start_time)
            self.assertEqual(meta.end_time, rep.end_time)
            self.assertEqual(meta.partition_id, rep.partition_id)
            self.assertEqual(meta, rep.data_block_meta)
            data_block_fpath = os.path.join(
                common.data_source_data_block_dir(self.data_source),
                common.partition_repr(meta.partition_id),
                meta.block_id + common.DataBlockSuffix)
            self.assertEqual(data_block_fpath, rep.data_block_fpath)
            self.assertIsNone(
                visitor.LoadDataBlockReqByIndex(
                    random.randint(0, partition_num - 1),
                    random.randint(dumped_index, 10000)))
Code Example #2
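 # Walk the etcd pub keys of one partition from next_raw_data_sub_index,
 # collect newly published raw data metas until no record or a finish marker
 # is found, then persist the metas and update the partition manifest.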
 def _try_to_sub_raw_data(self, partition_id):
     sub_src_dir = path.join(self._raw_data_sub_dir,
                             common.partition_repr(partition_id))
     with self._lock:
         manifest = self._sync_manifest(partition_id)
         if manifest.finished:
             return
         next_sub_index = manifest.next_raw_data_sub_index
         add_candidates = []
         raw_data_finished = False
         while True:
             etcd_key = common.raw_data_pub_etcd_key(
                 self._raw_data_sub_dir, partition_id, next_sub_index)
             pub_data = self._etcd.get_data(etcd_key)
             if pub_data is None:
                 break
             raw_data_pub = text_format.Parse(pub_data, dj_pb.RawDatePub())
             if raw_data_pub.HasField('raw_data_meta'):
                 add_candidates.append(raw_data_pub.raw_data_meta)
                 next_sub_index += 1
             elif raw_data_pub.HasField('raw_data_finished'):
                 logging.warning("meet finish pub at pub index %d for "\
                                 "partition %d",
                                 next_sub_index, partition_id)
                 raw_data_finished = True
                 break
         self._store_raw_data_metas(partition_id, add_candidates)
         new_manifest = self._sync_manifest(partition_id)
         new_manifest.next_raw_data_sub_index = next_sub_index
         new_manifest.finished = raw_data_finished
         self._update_manifest(new_manifest)
Code Example #3
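 # Publish the output of a portal job: choose the PSI map output dir or the
 # streaming reduce output dir, then publish and log the raw data files of
 # every output partition.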
 def _publish_raw_data(self, job_id):
     portal_manifest = self._sync_portal_manifest()
     output_dir = None
     if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
         output_dir = common.portal_map_output_dir(
             portal_manifest.output_base_dir, job_id)
     else:
         output_dir = common.portal_reduce_output_dir(
             portal_manifest.output_base_dir, job_id)
     for partition_id in range(self._output_partition_num):
         dpath = path.join(output_dir, common.partition_repr(partition_id))
         fnames = []
         if gfile.Exists(dpath) and gfile.IsDirectory(dpath):
             fnames = [
                 f for f in gfile.ListDirectory(dpath)
                 if f.endswith(common.RawDataFileSuffix)
             ]
         publish_fpaths = []
         if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
             publish_fpaths = self._publish_psi_raw_data(
                 partition_id, dpath, fnames)
         else:
             publish_fpaths = self._publish_streaming_raw_data(
                 partition_id, dpath, fnames)
         logging.info("Data Portal Master publish %d file for partition "\
                      "%d of streaming job %d\n----------\n",
                      len(publish_fpaths), partition_id, job_id)
         for seq, fpath in enumerate(publish_fpaths):
             logging.info("%d. %s", seq, fpath)
         logging.info("------------------------------------------\n")
Code Example #4
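 # Drive RawDataVisitor over a CSV raw data file with the CSV_DICT iterator
 # and check that it yields consecutive indices for all rows.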
 def test_csv_raw_data_visitor(self):
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.raw_data_dir = path.join(path.dirname(path.abspath(__file__)),
                                   "../csv_raw_data")
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(
         common.data_source_etcd_base_dir(
             self.data_source.data_source_meta.name))
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     manifest_manager.add_raw_data(0, [
         dj_pb.RawDataMeta(file_path=path.join(partition_dir,
                                               "test_raw_data.csv"),
                           timestamp=timestamp_pb2.Timestamp(seconds=3))
     ], True)
     raw_data_options = dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                             read_ahead_size=1 << 20)
     rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))
     rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source, 0,
                                           raw_data_options)
     expected_index = 0
     for (index, item) in rdv:
         if index > 0 and index % 1024 == 0:
             print("{} {}".format(index, item.raw_id))
         self.assertEqual(index, expected_index)
         expected_index += 1
     self.assertEqual(expected_index, 4999)
Code Example #5
 def test_raw_data_visitor(self):
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.data_source.raw_data_dir = "./test/compressed_raw_data"
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(self.data_source.data_source_meta.name)
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = os.path.join(self.data_source.raw_data_dir,
                                  common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     manifest_manager.add_raw_data(
             0, [dj_pb.RawDataMeta(file_path=os.path.join(partition_dir, "0-0.idx"),
                                   timestamp=timestamp_pb2.Timestamp(seconds=3))],
             True)
     raw_data_options = dj_pb.RawDataOptions(
             raw_data_iter='TF_DATASET',
             compressed_type='GZIP'
         )
     rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))
     rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source, 0,
                                           raw_data_options)
     expected_index = 0
     for (index, item) in rdv:
         if index > 0 and index % 32 == 0:
             print("{} {}".format(index, item.example_id))
         self.assertEqual(index, expected_index)
         expected_index += 1
     self.assertGreater(expected_index, 0)
Code Example #6
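 # Exercise RawDataManager: index metas only become visible through
 # get_index_metas() after get_index_meta_by_index() materializes them for
 # newly generated raw data files.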
 def test_raw_data_manager(self):
     rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)
     self.assertEqual(len(rdm.get_index_metas()), 0)
     self.assertFalse(rdm.check_index_meta_by_process_index(0))
     self._gen_raw_data_file(0, 2)
     self.assertEqual(len(rdm.get_index_metas()), 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))
     self.assertTrue(rdm.check_index_meta_by_process_index(1))
     self.assertEqual(len(rdm.get_index_metas()), 0)
     partition_dir = os.path.join(self.data_source.raw_data_dir,
                                  common.partition_repr(0))
     index_meta0 = rdm.get_index_meta_by_index(0, 0)
     self.assertEqual(index_meta0.start_index, 0)
     self.assertEqual(index_meta0.process_index, 0)
     self.assertEqual(len(rdm.get_index_metas()), 1)
     index_meta1 = rdm.get_index_meta_by_index(1, 100)
     self.assertEqual(index_meta1.start_index, 100)
     self.assertEqual(index_meta1.process_index, 1)
     self.assertEqual(len(rdm.get_index_metas()), 2)
     self.assertFalse(rdm.check_index_meta_by_process_index(2))
     self._gen_raw_data_file(2, 4)
     self.assertTrue(rdm.check_index_meta_by_process_index(2))
     self.assertTrue(rdm.check_index_meta_by_process_index(3))
     index_meta2 = rdm.get_index_meta_by_index(2, 200)
     self.assertEqual(index_meta2.start_index, 200)
     self.assertEqual(index_meta2.process_index, 2)
     self.assertEqual(len(rdm.get_index_metas()), 3)
     index_meta3 = rdm.get_index_meta_by_index(3, 300)
     self.assertEqual(index_meta3.start_index, 300)
     self.assertEqual(index_meta3.process_index, 3)
     self.assertEqual(len(rdm.get_index_metas()), 4)
Code Example #7
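 # Write one TFRecord file per index in [start_index, end_index) into
 # partition 0 (100 examples each, or empty when no_data) and register the
 # files with the manifest manager.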
 def _gen_raw_data_file(self, start_index, end_index, no_data=False):
     partition_dir = os.path.join(self.raw_data_dir,
                                  common.partition_repr(0))
     fpaths = []
     for i in range(start_index, end_index):
         if no_data:
             fname = "{}.no_data".format(i)
         else:
             fname = "{}{}".format(i, common.RawDataFileSuffix)
         fpath = os.path.join(partition_dir, fname)
         fpaths.append(
             dj_pb.RawDataMeta(
                 file_path=fpath,
                 timestamp=timestamp_pb2.Timestamp(seconds=3)))
         writer = tf.io.TFRecordWriter(fpath)
         if not no_data:
             for j in range(100):
                 feat = {}
                 example_id = '{}'.format(i * 100 + j).encode()
                 feat['example_id'] = tf.train.Feature(
                     bytes_list=tf.train.BytesList(value=[example_id]))
                 example = tf.train.Example(features=tf.train.Features(
                     feature=feat))
                 writer.write(example.SerializeToString())
         writer.close()
     self.manifest_manager.add_raw_data(0, fpaths, True)
Code Example #8
 def _check_merged(self):
     merge_dir = os.path.join(self._options.output_file_dir,
                              common.partition_repr(self._partition_id))
     merged_fname = common.encode_merged_sort_run_fname(self._partition_id)
     return len([f for f in gfile.ListDirectory(merge_dir)
                 if (os.path.basename(f) == merged_fname or \
                     os.path.basename(f) == '_SUCCESS')]) > 0
Code Example #9
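 # Wire up the RSA-PSI preprocessor: raw data publisher, id batch fetcher, a
 # leader or follower RSA signer backed by a process pool, and the sort run
 # dumper/merger for the configured partition.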
 def __init__(self, options, kvstore_type, use_mock_etcd=False):
     self._lock = threading.Condition()
     self._options = options
     kvstore = DBClient(kvstore_type, use_mock_etcd)
     pub_dir = self._options.raw_data_publish_dir
     self._publisher = RawDataPublisher(kvstore, pub_dir)
     self._process_pool_executor = \
             concur_futures.ProcessPoolExecutor(
                     options.offload_processor_number
                 )
     self._callback_submitter = None
     # pre-fork subprocess workers before launching the grpc client
     self._process_pool_executor.submit(min, 1, 2).result()
     self._id_batch_fetcher = IdBatchFetcher(kvstore, self._options)
     if self._options.role == common_pb.FLRole.Leader:
         private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = LeaderPsiRsaSigner(
             self._id_batch_fetcher,
             options.batch_processor_options.max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.slow_sign_threshold,
             self._process_pool_executor,
             private_key,
         )
         self._repr = 'leader-' + 'rsa_psi_preprocessor'
     else:
         public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
         self._callback_submitter = concur_futures.ThreadPoolExecutor(1)
         self._psi_rsa_signer = FollowerPsiRsaSigner(
             self._id_batch_fetcher,
             options.batch_processor_options.max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.max_flying_sign_rpc,
             self._options.sign_rpc_timeout_ms,
             self._options.slow_sign_threshold, self._options.stub_fanout,
             self._process_pool_executor, self._callback_submitter,
             public_key, self._options.leader_rsa_psi_signer_addr)
         self._repr = 'follower-' + 'rsa_psi_preprocessor'
     self._sort_run_dumper = SortRunDumper(options)
     self._sort_run_merger = SortRunMerger(
             dj_pb.SortRunMergerOptions(
                 merger_name='sort_run_merger_'+\
                             partition_repr(options.partition_id),
                 reader_options=dj_pb.RawDataOptions(
                     raw_data_iter=options.writer_options.output_writer,
                     compressed_type=options.writer_options.compressed_type,
                     read_ahead_size=\
                         options.sort_run_merger_read_ahead_buffer,
                     read_batch_size=\
                         options.sort_run_merger_read_batch_size
                 ),
                 writer_options=options.writer_options,
                 output_file_dir=options.output_file_dir,
                 partition_id=options.partition_id
             ),
             self._merger_comparator
         )
     self._produce_item_cnt = 0
     self._comsume_item_cnt = 0
     self._started = False
Code Example #10
 def _create_merged_dir_if_need(self):
     merge_dir = os.path.join(self._options.output_file_dir,
                              common.partition_repr(self._partition_id))
     if gfile.Exists(merge_dir):
         assert gfile.IsDirectory(merge_dir)
     else:
         gfile.MakeDirs(merge_dir)
Code Example #11
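# Generate CSV input files: ids in [start_id, end_id) are hashed into
# partitions with CityHash32 and appended through per-partition
# SortRunMergerWriter instances using the CSV_DICT format.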
def generate_input_csv(base_dir, start_id, end_id, partition_num):
    for partition_id in range(partition_num):
        dirpath = os.path.join(base_dir, common.partition_repr(partition_id))
        if not gfile.Exists(dirpath):
            gfile.MakeDirs(dirpath)
        assert gfile.IsDirectory(dirpath)
    csv_writers = [
        SortRunMergerWriter(base_dir, 0, partition_id, 'CSV_DICT')
        for partition_id in range(partition_num)
    ]
    for idx in range(start_id, end_id):
        if idx % 262144 == 0:
            logging.info("Process at index %d", idx)
        partition_id = CityHash32(str(idx)) % partition_num
        raw = OrderedDict()
        raw['raw_id'] = str(idx)
        raw['feat_0'] = str((partition_id << 30) + 0) + str(idx)
        raw['feat_1'] = str((partition_id << 30) + 1) + str(idx)
        raw['feat_2'] = str((partition_id << 30) + 2) + str(idx)
        csv_writers[partition_id].append(raw)
    for partition_id, csv_writer in enumerate(csv_writers):
        fpaths = csv_writer.finish()
        logging.info("partition %d dump %d files", partition_id, len(fpaths))
        for seq_id, fpath in enumerate(fpaths):
            logging.info("  %d. %s", seq_id, fpath)
        logging.info("---------------")
Code Example #12
 def test_compressed_raw_data_visitor(self):
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.raw_data_dir = path.join(
             path.dirname(path.abspath(__file__)), "../compressed_raw_data"
         )
     self.kvstore = DBClient('etcd', True)
     self.kvstore.delete_prefix(
         common.data_source_kvstore_base_dir(
             self.data_source.data_source_meta.name))
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
     manifest_manager.add_raw_data(
             0, [dj_pb.RawDataMeta(file_path=path.join(partition_dir, "0-0.idx"),
                                   timestamp=timestamp_pb2.Timestamp(seconds=3))],
             True)
     raw_data_options = dj_pb.RawDataOptions(
             raw_data_iter='TF_RECORD',
             compressed_type='GZIP',
             read_ahead_size=1<<20,
             read_batch_size=128
         )
     rdm = raw_data_visitor.RawDataManager(self.kvstore, self.data_source, 0)
     self.assertTrue(rdm.check_index_meta_by_process_index(0))
     rdv = raw_data_visitor.RawDataVisitor(self.kvstore, self.data_source, 0,
                                           raw_data_options)
     expected_index = 0
     for (index, item) in rdv:
         if index > 0 and index % 32 == 0:
             print("{} {}".format(index, item.example_id))
         self.assertEqual(index, expected_index)
         expected_index += 1
     self.assertGreater(expected_index, 0)
Code Example #13
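 # Build leader-side data blocks from synthetic examples (one block per 2048
 # records), register each finished block as raw data in the manifest, and
 # finally delete any file that is not a data block.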
 def generate_leader_raw_data(self):
     dbm = data_block_manager.DataBlockManager(self.data_source_l, 0)
     raw_data_dir = os.path.join(self.data_source_l.raw_data_dir,
                                 common.partition_repr(0))
     if gfile.Exists(raw_data_dir):
         gfile.DeleteRecursively(raw_data_dir)
     gfile.MakeDirs(raw_data_dir)
     rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source_l, 0)
     block_index = 0
     builder = create_data_block_builder(
         dj_pb.DataBlockBuilderOptions(
             data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
         self.data_source_l.raw_data_dir,
         self.data_source_l.data_source_meta.name, 0, block_index, None)
     process_index = 0
     start_index = 0
     for i in range(0, self.leader_end_index + 3):
         if (i > 0 and i % 2048 == 0) or (i == self.leader_end_index + 2):
             meta = builder.finish_data_block()
             if meta is not None:
                 ofname = common.encode_data_block_fname(
                     self.data_source_l.data_source_meta.name, meta)
                 fpath = os.path.join(raw_data_dir, ofname)
                 self.manifest_manager.add_raw_data(0, [
                     dj_pb.RawDataMeta(
                         file_path=fpath,
                         timestamp=timestamp_pb2.Timestamp(seconds=3))
                 ], False)
                 process_index += 1
                 start_index += len(meta.example_ids)
             block_index += 1
             builder = create_data_block_builder(
                 dj_pb.DataBlockBuilderOptions(
                     data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
                 self.data_source_l.raw_data_dir,
                 self.data_source_l.data_source_meta.name, 0, block_index,
                 None)
         feat = {}
         pt = i + 1 << 30
         if i % 3 == 0:
             pt = i // 3
         example_id = '{}'.format(pt).encode()
         feat['example_id'] = tf.train.Feature(
             bytes_list=tf.train.BytesList(value=[example_id]))
         event_time = 150000000 + pt
         feat['event_time'] = tf.train.Feature(
             int64_list=tf.train.Int64List(value=[event_time]))
         example = tf.train.Example(features=tf.train.Features(
             feature=feat))
         builder.append_record(example.SerializeToString(), example_id,
                               event_time, i, i)
     fpaths = [
         os.path.join(raw_data_dir, f)
         for f in gfile.ListDirectory(raw_data_dir)
         if not gfile.IsDirectory(os.path.join(raw_data_dir, f))
     ]
     for fpath in fpaths:
         if not fpath.endswith(common.DataBlockSuffix):
             gfile.Remove(fpath)
Code Example #14
 def __init__(self, base_dir, partition_id):
     self._merged_dir = \
         os.path.join(base_dir, common.partition_repr(partition_id))
     self._partition_id = partition_id
     self._process_index = 0
     self._csv_dict_writer = None
     self._merged_fpaths = []
     self._merged_num = 0
Code Example #15
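 # Set up the RSA-PSI preprocessor against an EtcdClient: publisher, id batch
 # fetcher, leader/follower RSA signer, and the sort run dumper/merger.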
 def __init__(self,
              options,
              etcd_name,
              etcd_addrs,
              etcd_base_dir,
              use_mock_etcd=False):
     self._lock = threading.Condition()
     self._options = options
     etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
     pub_dir = self._options.raw_data_publish_dir
     self._publisher = RawDataPublisher(etcd, pub_dir)
     self._process_pool_executor = \
             concur_futures.ProcessPoolExecutor(
                     options.offload_processor_number
                 )
     self._id_batch_fetcher = IdBatchFetcher(etcd, self._options)
     max_flying_item = options.batch_processor_options.max_flying_item
     if self._options.role == common_pb.FLRole.Leader:
         private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = LeaderPsiRsaSigner(
             self._id_batch_fetcher,
             max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.slow_sign_threshold,
             self._process_pool_executor,
             private_key,
         )
         self._repr = 'leader-' + 'rsa_psi_preprocessor'
     else:
         public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
         self._psi_rsa_signer = FollowerPsiRsaSigner(
             self._id_batch_fetcher, max_flying_item,
             self._options.max_flying_sign_batch,
             self._options.max_flying_sign_rpc,
             self._options.sign_rpc_timeout_ms,
             self._options.slow_sign_threshold, self._options.stub_fanout,
             self._process_pool_executor, public_key,
             self._options.leader_rsa_psi_signer_addr)
         self._repr = 'follower-' + 'rsa_psi_preprocessor'
     self._sort_run_dumper = SortRunDumper(options)
     self._sort_run_merger = SortRunMerger(
             dj_pb.SortRunMergerOptions(
                 merger_name='sort_run_merger_'+\
                             partition_repr(options.partition_id),
                 reader_options=dj_pb.RawDataOptions(
                     raw_data_iter=options.writer_options.output_writer,
                     compressed_type=options.writer_options.compressed_type,
                     read_ahead_size=\
                         options.sort_run_merger_read_ahead_buffer
                 ),
                 writer_options=options.writer_options,
                 output_file_dir=options.output_file_dir,
                 partition_id=options.partition_id
             ),
             'example_id'
         )
     self._started = False
Code Example #16
 def _create_sort_run_readers(self, sort_runs):
     assert len(sort_runs) > 0
     readers = []
     for index, sort_run in enumerate(sort_runs):
         fpath = os.path.join(self._input_dir,
                              common.partition_repr(self._partition_id),
                              sort_run.encode_sort_run_fname())
         readers.append(SortRunReader(index, fpath))
     return readers
Code Example #17
 def __init__(self, input_dir, options):
     self._lock = threading.Lock()
     self._input_dir = input_dir
     self._options = options
     self._merge_finished = False
     self._merged_dir = os.path.join(
         self._options.output_file_dir,
         common.partition_repr(self._partition_id))
     self._create_merged_dir_if_need()
Code Example #18
 def __init__(self, options, partition_id):
     self._partition_id = partition_id
     self._writer = None
     self._begin_index = None
     self._end_index = None
     self._options = options
     self._size_bytes = 0
     self._tmp_fpath = common.gen_tmp_fpath(
         os.path.join(self._options.output_dir,
                      common.partition_repr(self._partition_id)))
Code Example #19
 def __init__(self, options, partition_id, process_index):
     self._options = options
     self._partition_id = partition_id
     self._process_index = process_index
     self._begin_index = None
     self._end_index = None
     self._buffer = []
     self._tmp_fpath = common.gen_tmp_fpath(
         os.path.join(self._options.output_dir,
                      common.partition_repr(self._partition_id)))
Code Example #20
 def _make_merge_options(self, task):
     merge_options = self._options.merge_options
     merge_options.output_builder = "TF_RECORD"
     merge_options.input_dir = os.path.join(task.map_base_dir, \
         common.partition_repr(task.partition_id))
     merge_options.output_dir = task.reduce_base_dir
     merge_options.partition_id = task.partition_id
     merge_options.fpath.extend(gfile.ListDirectory(
         merge_options.input_dir))
     return merge_options
Code Example #21
 def __init__(self, base_dir, process_index, partition_id, writer_options):
     self._merged_dir = \
         os.path.join(base_dir, common.partition_repr(partition_id))
     self._partition_id = partition_id
     self._writer_options = writer_options
     self._process_index = process_index
     self._writer = None
     self._merged_fpaths = []
     self._merged_num = 0
     self._tmp_fpath = None
Code Example #22
 def finish(self):
     if self._begin_index is not None \
         and self._end_index is not None:
         self._writer.close()
         meta = Merge.FileMeta(self._partition_id, self._begin_index,
                               self._end_index)
         fpath = os.path.join(self._options.output_dir,
                              common.partition_repr(self._partition_id),
                              meta.encode_meta_to_fname())
         gfile.Rename(self.get_tmp_fpath(), fpath, True)
         self._writer = None
Code Example #23
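 # List the dumped raw data files of a partition and keep only the FileMetas
 # produced by this partitioner rank.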
 def _list_file_metas(self, partition_id):
     dumped_dir = os.path.join(self._options.output_dir,
                               common.partition_repr(partition_id))
     if not gfile.Exists(dumped_dir):
         gfile.MakeDirs(dumped_dir)
     assert gfile.IsDirectory(dumped_dir)
     fnames = [os.path.basename(f) for f in gfile.ListDirectory(dumped_dir)
               if f.endswith(common.RawDataFileSuffix)]
     metas = [RawDataPartitioner.FileMeta.decode_meta_from_fname(f)
              for f in fnames]
     return [meta for meta in metas \
             if meta.rank_id == self._options.partitioner_rank_id]
Code Example #24
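 # Merge all dumped sort runs of the partition, publish the merged output
 # files as raw data, and mark the merge as finished.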
 def _sort_run_merge_fn(self):
     sort_runs = self._sort_run_dumper.get_all_sort_runs()
     input_dir = self._sort_run_dumper.sort_run_dump_dir()
     input_fpaths = [os.path.join(input_dir,
                                  partition_repr(self._options.partition_id),
                                  sort_run.encode_sort_run_fname())
                     for sort_run in sort_runs]
     output_fpaths = self._sort_run_merger.merge_sort_runs(input_fpaths)
     self._publisher.publish_raw_data(self._options.partition_id,
                                      output_fpaths)
     self._publisher.finish_raw_data(self._options.partition_id)
     self._sort_run_merger.set_merged_finished()
Code Example #25
 def __init__(self, options, partition_id, process_index):
     self._options = options
     self._partition_id = partition_id
     self._process_index = process_index
     self._begin_index = None
     self._end_index = None
     self._writer = None
     self._tmp_fpath = os.path.join(
         self._options.output_dir,
         common.partition_repr(self._partition_id),
         '{}-{}.tmp'.format(str(uuid.uuid1()), self.TMP_COUNTER))
     self.TMP_COUNTER += 1
Code Example #26
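 # Run a reduce task: merge the map-output sort-run files of the task's
 # partition, ordered by event_time.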
 def _run_reduce_task(self, task):
     merger_options = self._make_merger_options(task)
     sort_run_merger = SortRunMerger(merger_options, 'event_time')
     input_dir = os.path.join(task.map_base_dir,
                              common.partition_repr(task.partition_id))
     input_fpaths = [
         os.path.join(input_dir, f) for f in gfile.ListDirectory(input_dir)
         if f.endswith(common.RawDataFileSuffix)
     ]
     logging.info("Merger input_dir:%s(with %d files) rank_id:%s "\
                  "partition_id:%d start", task.map_base_dir,
                  len(input_fpaths), self._rank_id, task.partition_id)
     sort_run_merger.merge_sort_runs(input_fpaths)
Code Example #27
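 # Close the writer, build the FileMeta, and rename the temporary file to its
 # final encoded name; returns None when nothing was written.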
 def finish(self):
     meta = None
     if self._writer is not None:
         self._writer.close()
         self._writer = None
         meta = RawDataPartitioner.FileMeta(
             self._options.partitioner_rank_id, self._process_index,
             self._begin_index, self._end_index)
         fpath = os.path.join(self._options.output_dir,
                              common.partition_repr(self._partition_id),
                              meta.encode_meta_to_fname())
         gfile.Rename(self.get_tmp_fpath(), fpath, True)
     return meta
Code Example #28
 def setUp(self):
     self.data_source = common_pb.DataSource()
     self.data_source.data_source_meta.name = 'fclh_test'
     self.data_source.data_source_meta.partition_num = 1
     self.raw_data_dir = "./raw_data"
     self.kvstore = db_client.DBClient('etcd', True)
     self.kvstore.delete_prefix(
         common.data_source_kvstore_base_dir(
             self.data_source.data_source_meta.name))
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
     partition_dir = os.path.join(self.raw_data_dir, common.partition_repr(0))
     if gfile.Exists(partition_dir):
         gfile.DeleteRecursively(partition_dir)
     gfile.MakeDirs(partition_dir)
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
Code Example #29
 def __init__(self, options, partition_id):
     self._options = options
     self._partition_id = partition_id
     self._process_index = 0
     self._writer = None
     self._dumped_item = 0
     self._output_fpaths = []
     self._output_dir = os.path.join(
             self._options.output_dir,
             common.partition_repr(self._partition_id)
         )
     if not gfile.Exists(self._output_dir):
         gfile.MakeDirs(self._output_dir)
     assert gfile.IsDirectory(self._output_dir)
Code Example #30
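 # Iterate the merged TFRecord files of a partition in name order, assert that
 # event_time never decreases, and return the total record count.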
 def _check_merge(self, reduce_task):
     dpath = os.path.join(self._merge_output_dir, \
         common.partition_repr(reduce_task.partition_id))
     fpaths = gfile.ListDirectory(dpath)
     fpaths = sorted(fpaths)
     event_time = 0
     total_cnt = 0
     for fpath in fpaths:
         fpath = os.path.join(dpath, fpath)
         logging.info("check merge path:{}".format(fpath))
         for record in tf.python_io.tf_record_iterator(fpath):
             tf_item = TfExampleItem(record)
             self.assertTrue(tf_item.event_time >= event_time)
             event_time = tf_item.event_time
             total_cnt += 1
     return total_cnt