def test_data_block_dumper(self):
    """End-to-end check of DataBlockDumperManager.

    Generates follower data blocks and leader raw data, feeds the follower
    metas into the dumper, dumps, and then verifies that leader and follower
    hold identical metas and identical record contents on disk.
    """
    self.generate_follower_data_block()
    self.generate_leader_raw_data()
    options = customized_options.CustomizedOptions()
    options.set_raw_data_iter('TF_RECORD')
    dbd = data_block_dumper.DataBlockDumperManager(
        self.etcd, self.data_source_l, 0, options
    )
    self.assertEqual(dbd.get_partition_id(), 0)
    self.assertEqual(dbd.get_next_data_block_index(), 0)
    # Feed every follower meta; the dumper must accept each one and advance
    # its next-index counter by exactly one per meta.
    for (idx, meta) in enumerate(self.dumped_metas):
        success, next_index = dbd.append_synced_data_block_meta(meta)
        self.assertTrue(success)
        self.assertEqual(next_index, idx + 1)
    self.assertTrue(dbd.need_dump())
    self.assertEqual(dbd.get_next_data_block_index(), len(self.dumped_metas))
    dbd.dump_data_blocks()
    dbm_f = data_block_manager.DataBlockManager(self.data_source_f, 0)
    dbm_l = data_block_manager.DataBlockManager(self.data_source_l, 0)
    # Both sides must now report the same number of dumped blocks.
    self.assertEqual(dbm_f.get_dumped_data_block_num(),
                     len(self.dumped_metas))
    self.assertEqual(dbm_f.get_dumped_data_block_num(),
                     dbm_l.get_dumped_data_block_num())
    for (idx, meta) in enumerate(self.dumped_metas):
        self.assertEqual(meta.data_block_index, idx)
        self.assertEqual(dbm_l.get_data_block_meta_by_index(idx)[0], meta)
        self.assertEqual(dbm_f.get_data_block_meta_by_index(idx)[0], meta)
        block_id = meta.block_id
        # Leader-side meta file must parse back to the original meta.
        meta_fpth_l = os.path.join(self.data_source_l.data_block_dir,
                                   'partition_0',
                                   block_id + common.DataBlockMetaSuffix)
        mitr = tf.io.tf_record_iterator(meta_fpth_l)
        meta_l = dj_pb.DataBlockMeta()
        meta_l.ParseFromString(next(mitr))
        self.assertEqual(meta_l, meta)
        # Follower-side meta file likewise.
        meta_fpth_f = os.path.join(self.data_source_f.data_block_dir,
                                   'partition_0',
                                   block_id + common.DataBlockMetaSuffix)
        mitr = tf.io.tf_record_iterator(meta_fpth_f)
        meta_f = dj_pb.DataBlockMeta()
        meta_f.ParseFromString(next(mitr))
        self.assertEqual(meta_f, meta)
        # Leader-side data records must carry the example ids in meta order.
        data_fpth_l = os.path.join(self.data_source_l.data_block_dir,
                                   'partition_0',
                                   block_id + common.DataBlockSuffix)
        for (iidx, record) in enumerate(
                tf.io.tf_record_iterator(data_fpth_l)):
            example = tf.train.Example()
            example.ParseFromString(record)
            feat = example.features.feature
            self.assertEqual(feat['example_id'].bytes_list.value[0],
                             meta.example_ids[iidx])
        # `iidx` survives the loop: the record count on disk must equal the
        # meta's example-id count.
        self.assertEqual(len(meta.example_ids), iidx + 1)
        # Follower-side data records must match too.
        data_fpth_f = os.path.join(self.data_source_f.data_block_dir,
                                   'partition_0',
                                   block_id + common.DataBlockSuffix)
        for (iidx, record) in enumerate(
                tf.io.tf_record_iterator(data_fpth_f)):
            example = tf.train.Example()
            example.ParseFromString(record)
            feat = example.features.feature
            self.assertEqual(feat['example_id'].bytes_list.value[0],
                             meta.example_ids[iidx])
        self.assertEqual(len(meta.example_ids), iidx + 1)
def _create_data_block(self, partition_id): dbm = data_block_manager.DataBlockManager(self.data_source, partition_id) self.assertEqual(dbm.get_dumped_data_block_count(), 0) self.assertEqual(dbm.get_lastest_data_block_meta(), None) leader_index = 0 follower_index = 65536 for i in range(64): builder = DataBlockBuilder( common.data_source_data_block_dir(self.data_source), self.data_source.data_source_meta.name, partition_id, i, dj_pb.WriterOptions(output_writer='TF_RECORD'), None ) builder.set_data_block_manager(dbm) for j in range(4): feat = {} example_id = '{}'.format(i * 1024 + j).encode() feat['example_id'] = tf.train.Feature( bytes_list=tf.train.BytesList(value=[example_id])) event_time = random.randint(0, 10) feat['event_time'] = tf.train.Feature( int64_list=tf.train.Int64List(value=[event_time])) feat['leader_index'] = tf.train.Feature( int64_list=tf.train.Int64List(value=[leader_index])) feat['follower_index'] = tf.train.Feature( int64_list=tf.train.Int64List(value=[follower_index])) example = tf.train.Example(features=tf.train.Features(feature=feat)) builder.append_item(TfExampleItem(example.SerializeToString()), leader_index, follower_index) leader_index += 1 follower_index += 1 self.data_block_matas.append(builder.finish_data_block())
def generate_leader_raw_data(self):
    """Generate leader-side raw data files for partition 0.

    Writes records in blocks of 2048 (plus a trailing partial block),
    registers each finished file with the manifest manager, and finally
    removes every non-data-block file from the raw data directory.
    """
    # NOTE(review): `dbm` and `rdm` are never used below; presumably kept
    # for constructor side effects — confirm before removing.
    dbm = data_block_manager.DataBlockManager(self.data_source_l, 0)
    raw_data_dir = os.path.join(self.data_source_l.raw_data_dir,
                                common.partition_repr(0))
    # Start from a clean partition directory.
    if gfile.Exists(raw_data_dir):
        gfile.DeleteRecursively(raw_data_dir)
    gfile.MakeDirs(raw_data_dir)
    rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source_l, 0)
    block_index = 0
    builder = create_data_block_builder(
        dj_pb.DataBlockBuilderOptions(
            data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
        self.data_source_l.raw_data_dir,
        self.data_source_l.data_source_meta.name,
        0, block_index, None)
    process_index = 0
    start_index = 0
    for i in range(0, self.leader_end_index + 3):
        # Close the current block every 2048 records, and once more at the
        # second-to-last index to flush the trailing partial block.
        if (i > 0 and i % 2048 == 0) or (i == self.leader_end_index + 2):
            meta = builder.finish_data_block()
            if meta is not None:
                ofname = common.encode_data_block_fname(
                    self.data_source_l.data_source_meta.name, meta)
                fpath = os.path.join(raw_data_dir, ofname)
                # Publish the freshly written file as raw-data input.
                self.manifest_manager.add_raw_data(0, [
                    dj_pb.RawDataMeta(
                        file_path=fpath,
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ], False)
                process_index += 1
                start_index += len(meta.example_ids)
            block_index += 1
            builder = create_data_block_builder(
                dj_pb.DataBlockBuilderOptions(
                    data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
                self.data_source_l.raw_data_dir,
                self.data_source_l.data_source_meta.name,
                0, block_index, None)
        feat = {}
        # NOTE(review): `i + 1 << 30` parses as `(i + 1) << 30` because `+`
        # binds tighter than `<<`; confirm `i + (1 << 30)` was not intended.
        pt = i + 1 << 30
        if i % 3 == 0:
            # Every third record keeps a small id (overlaps follower ids).
            pt = i // 3
        example_id = '{}'.format(pt).encode()
        feat['example_id'] = tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[example_id]))
        event_time = 150000000 + pt
        feat['event_time'] = tf.train.Feature(
            int64_list=tf.train.Int64List(value=[event_time]))
        example = tf.train.Example(features=tf.train.Features(
            feature=feat))
        builder.append_record(example.SerializeToString(), example_id,
                              event_time, i, i)
    # Keep only the data-block payload files in the raw data directory.
    fpaths = [
        os.path.join(raw_data_dir, f)
        for f in gfile.ListDirectory(raw_data_dir)
        if not gfile.IsDirectory(os.path.join(raw_data_dir, f))
    ]
    for fpath in fpaths:
        if not fpath.endswith(common.DataBlockSuffix):
            gfile.Remove(fpath)
def generate_raw_data(self, data_source, partition_id, block_size, shuffle_win_size, feat_key_fmt, feat_val_fmt): dbm = data_block_manager.DataBlockManager(data_source, partition_id) raw_data_dir = os.path.join(data_source.raw_data_dir, 'partition_{}'.format(partition_id)) if gfile.Exists(raw_data_dir): gfile.DeleteRecursively(raw_data_dir) gfile.MakeDirs(raw_data_dir) useless_index = 0 for block_index in range(self.total_index // block_size): builder = data_block_manager.DataBlockBuilder( data_source.raw_data_dir, partition_id, block_index, None) cands = list( range(block_index * block_size, (block_index + 1) * block_size)) start_index = cands[0] for i in range(len(cands)): if random.randint(1, 4) > 2: continue a = random.randint(i - shuffle_win_size, i + shuffle_win_size) b = random.randint(i - shuffle_win_size, i + shuffle_win_size) if a < 0: a = 0 if a >= len(cands): a = len(cands) - 1 if b < 0: b = 0 if b >= len(cands): b = len(cands) - 1 if (abs(cands[a] - i - start_index) <= shuffle_win_size and abs(cands[b] - i - start_index) <= shuffle_win_size): cands[a], cands[b] = cands[b], cands[a] for example_idx in cands: feat = {} example_id = '{}'.format(example_idx).encode() feat['example_id'] = tf.train.Feature( bytes_list=tf.train.BytesList(value=[example_id])) event_time = 150000000 + example_idx feat['event_time'] = tf.train.Feature( int64_list=tf.train.Int64List(value=[event_time])) feat[feat_key_fmt.format(example_idx)] = tf.train.Feature( bytes_list=tf.train.BytesList( value=[feat_val_fmt.format(example_idx).encode()])) example = tf.train.Example(features=tf.train.Features( feature=feat)) builder.append(example.SerializeToString(), example_id, event_time, useless_index, useless_index) useless_index += 1 builder.finish_data_block() fpaths = [ os.path.join(raw_data_dir, f) for f in gfile.ListDirectory(raw_data_dir) if not gfile.IsDirectory(os.path.join(raw_data_dir, f)) ] for fpath in fpaths: if not fpath.endswith(common.DataBlockSuffix): gfile.Remove(fpath)
def test_example_joiner(self): sei = joiner_impl.create_example_joiner( self.example_joiner_options, self.raw_data_options, dj_pb.WriterOptions(output_writer='TF_RECORD'), self.kvstore, self.data_source, 0) metas = [] with sei.make_example_joiner() as joiner: for meta in joiner: metas.append(meta) self.assertEqual(len(metas), 0) self.generate_raw_data(0, 2 * 2048) dumper = example_id_dumper.ExampleIdDumperManager( self.kvstore, self.data_source, 0, self.example_id_dump_options) self.generate_example_id(dumper, 0, 3 * 2048) with sei.make_example_joiner() as joiner: for meta in joiner: metas.append(meta) self.generate_raw_data(2 * 2048, 2048) self.generate_example_id(dumper, 3 * 2048, 3 * 2048) with sei.make_example_joiner() as joiner: for meta in joiner: metas.append(meta) self.generate_raw_data(3 * 2048, 5 * 2048) self.generate_example_id(dumper, 6 * 2048, 2048) with sei.make_example_joiner() as joiner: for meta in joiner: metas.append(meta) self.generate_raw_data(8 * 2048, 2 * 2048) with sei.make_example_joiner() as joiner: for meta in joiner: metas.append(meta) self.generate_example_id(dumper, 7 * 2048, 3 * 2048) with sei.make_example_joiner() as joiner: for meta in joiner: metas.append(meta) sei.set_sync_example_id_finished() sei.set_raw_data_finished() with sei.make_example_joiner() as joiner: for meta in joiner: metas.append(meta) dbm = data_block_manager.DataBlockManager(self.data_source, 0) data_block_num = dbm.get_dumped_data_block_count() self.assertEqual(len(metas), data_block_num) join_count = 0 for data_block_index in range(data_block_num): meta = dbm.get_data_block_meta_by_index(data_block_index) self.assertEqual(meta, metas[data_block_index]) join_count += len(meta.example_ids) print("join rate {}/{}({}), min_matching_window {}, "\ "max_matching_window {}".format( join_count, 20480, (join_count+.0)/(10 * 2048), self.example_joiner_options.min_matching_window, self.example_joiner_options.max_matching_window))
def setUp(self): data_source = common_pb.DataSource() data_source.data_source_meta.name = "milestone-x" data_source.data_source_meta.partition_num = 1 data_source.data_block_dir = "./data_block" self.data_source = data_source if gfile.Exists(data_source.data_block_dir): gfile.DeleteRecursively(data_source.data_block_dir) self.data_block_manager = data_block_manager.DataBlockManager( data_source, 0) self.assertEqual(self.data_block_manager.get_dumped_data_block_count(), 0) self.assertEqual(self.data_block_manager.get_lastest_data_block_meta(), None)
def _create_data_block(self, data_source, partition_id, x, y): data_block_metas = [] dbm = data_block_manager.DataBlockManager(data_source, partition_id) self.assertEqual(dbm.get_dumped_data_block_count(), 0) self.assertEqual(dbm.get_lastest_data_block_meta(), None) N = 200 chunk_size = x.shape[0] // N leader_index = 0 follower_index = N * chunk_size * 10 for i in range(N): builder = DataBlockBuilder( common.data_source_data_block_dir(data_source), data_source.data_source_meta.name, partition_id, i, dj_pb.WriterOptions(output_writer="TF_RECORD"), None ) builder.set_data_block_manager(dbm) for j in range(chunk_size): feat = {} idx = i * chunk_size + j exam_id = '{}'.format(idx).encode() feat['example_id'] = Feature( bytes_list=BytesList(value=[exam_id])) evt_time = random.randint(1, 1000) feat['event_time'] = Feature( int64_list = Int64List(value=[evt_time]) ) feat['x'] = Feature(float_list=FloatList(value=list(x[idx]))) if y is not None: feat['y'] = Feature(int64_list=Int64List(value=[y[idx]])) feat['leader_index'] = Feature( int64_list = Int64List(value=[leader_index]) ) feat['follower_index'] = Feature( int64_list = Int64List(value=[follower_index]) ) example = Example(features=Features(feature=feat)) builder.append_item(TfExampleItem(example.SerializeToString()), leader_index, follower_index) leader_index += 1 follower_index += 1 data_block_metas.append(builder.finish_data_block()) self.max_index = follower_index return data_block_metas
def __init__(self, base_path, name, role, partition_num=1, start_time=0, end_time=100000): if role == 'leader': role = 0 elif role == 'follower': role = 1 else: raise ValueError("Unknown role %s" % role) data_source = common_pb.DataSource() data_source.data_source_meta.name = name data_source.data_source_meta.partition_num = partition_num data_source.data_source_meta.start_time = start_time data_source.data_source_meta.end_time = end_time data_source.output_base_dir = "{}/{}_{}/data_source/".format( base_path, data_source.data_source_meta.name, role) data_source.role = role if gfile.Exists(data_source.output_base_dir): gfile.DeleteRecursively(data_source.output_base_dir) self._data_source = data_source db_database, db_addr, db_username, db_password, db_base_dir = \ get_kvstore_config("etcd") self._kv_store = mysql_client.DBClient(db_database, db_addr, db_username, db_password, db_base_dir, True) common.commit_data_source(self._kv_store, self._data_source) self._dbms = [] for i in range(partition_num): manifest_manager = raw_data_manifest_manager.RawDataManifestManager( self._kv_store, self._data_source) manifest_manager._finish_partition('join_example_rep', dj_pb.JoinExampleState.UnJoined, dj_pb.JoinExampleState.Joined, -1, i) self._dbms.append( data_block_manager.DataBlockManager(self._data_source, i))
def generate_leader_raw_data(self): dbm = data_block_manager.DataBlockManager(self.data_source_l, 0) raw_data_dir = os.path.join(self.data_source_l.raw_data_dir, 'partition_{}'.format(0)) if gfile.Exists(raw_data_dir): gfile.DeleteRecursively(raw_data_dir) gfile.MakeDirs(raw_data_dir) block_index = 0 builder = data_block_manager.DataBlockBuilder( self.data_source_l.raw_data_dir, 0, block_index, None) for i in range(0, self.leader_end_index + 3): if i > 0 and i % 2048 == 0: builder.finish_data_block() block_index += 1 builder = data_block_manager.DataBlockBuilder( self.data_source_l.raw_data_dir, 0, block_index, None) feat = {} pt = i + 1 << 30 if i % 3 == 0: pt = i // 3 example_id = '{}'.format(pt).encode() feat['example_id'] = tf.train.Feature( bytes_list=tf.train.BytesList(value=[example_id])) event_time = 150000000 + pt feat['event_time'] = tf.train.Feature( int64_list=tf.train.Int64List(value=[event_time])) example = tf.train.Example(features=tf.train.Features( feature=feat)) builder.append(example.SerializeToString(), example_id, event_time, i, i) builder.finish_data_block() fpaths = [ os.path.join(raw_data_dir, f) for f in gfile.ListDirectory(raw_data_dir) if not gfile.IsDirectory(os.path.join(raw_data_dir, f)) ] for fpath in fpaths: if not fpath.endswith(common.DataBlockSuffix): gfile.Remove(fpath) self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager( self.etcd, self.data_source_l)
def test_example_join(self): self.generate_raw_data() self.generate_example_id() customized_options.set_example_joiner('STREAM_JOINER') sei = joiner_impl.create_example_joiner(self.etcd, self.data_source, 0) sei.join_example() self.assertTrue(sei.join_finished()) dbm = data_block_manager.DataBlockManager(self.data_source, 0) data_block_num = dbm.get_dumped_data_block_num() join_count = 0 for data_block_index in range(data_block_num): meta = dbm.get_data_block_meta_by_index(data_block_index)[0] self.assertTrue(meta is not None) join_count += len(meta.example_ids) print("join rate {}/{}({}), min_matching_window {}, "\ "max_matching_window {}".format( join_count, self.total_index, (join_count+.0)/self.total_index, self.data_source.data_source_meta.min_matching_window, self.data_source.data_source_meta.max_matching_window))
def run_join_small_follower(self, sei, rate): metas = [] with sei.make_example_joiner() as joiner: for meta in joiner: metas.append(meta) self.assertEqual(len(metas), 0) self.generate_raw_data(8, 2 * 2048) dumper = example_id_dumper.ExampleIdDumperManager( self.kvstore, self.data_source, 0, self.example_id_dump_options) self.generate_example_id(dumper, 0, 3 * 2048) with sei.make_example_joiner() as joiner: for meta in joiner: metas.append(meta) sei.set_raw_data_finished() self.generate_example_id(dumper, 3 * 2048, 7 * 2048) with sei.make_example_joiner() as joiner: for meta in joiner: metas.append(meta) sei.set_sync_example_id_finished() with sei.make_example_joiner() as joiner: for meta in joiner: metas.append(meta) dbm = data_block_manager.DataBlockManager(self.data_source, 0) data_block_num = dbm.get_dumped_data_block_count() self.assertEqual(len(metas), data_block_num) join_count = 0 for data_block_index in range(data_block_num): meta = dbm.get_data_block_meta_by_index(data_block_index) self.assertEqual(meta, metas[data_block_index]) join_count += len(meta.example_ids) print("join rate {}/{}({}), min_matching_window {}, "\ "max_matching_window {}".format( join_count, 20480 * 2, (join_count+.0)/(10 * 2048 * 2), self.example_joiner_options.min_matching_window, self.example_joiner_options.max_matching_window)) self.assertTrue((join_count + .0) / (10 * 2048 * 2) >= rate)
def generate_follower_data_block(self): dbm = data_block_manager.DataBlockManager(self.data_source_f, 0) self.assertEqual(dbm.get_dumped_data_block_count(), 0) self.assertEqual(dbm.get_lastest_data_block_meta(), None) leader_index = 0 follower_index = 65536 self.dumped_metas = [] for i in range(5): builder = create_data_block_builder( dj_pb.DataBlockBuilderOptions( data_block_builder='TF_RECORD_DATABLOCK_BUILDER'), self.data_source_f.data_block_dir, self.data_source_f.data_source_meta.name, 0, i, None) builder.set_data_block_manager(dbm) for j in range(1024): feat = {} example_id = '{}'.format(i * 1024 + j).encode() feat['example_id'] = tf.train.Feature( bytes_list=tf.train.BytesList(value=[example_id])) event_time = 150000000 + i * 1024 + j feat['event_time'] = tf.train.Feature( int64_list=tf.train.Int64List(value=[event_time])) feat['leader_index'] = tf.train.Feature( int64_list=tf.train.Int64List(value=[leader_index])) feat['follower_index'] = tf.train.Feature( int64_list=tf.train.Int64List(value=[follower_index])) example = tf.train.Example(features=tf.train.Features( feature=feat)) builder.append_record(example.SerializeToString(), example_id, event_time, leader_index, follower_index) leader_index += 3 follower_index += 1 meta = builder.finish_data_block() self.dumped_metas.append(meta) self.leader_start_index = 0 self.leader_end_index = leader_index self.assertEqual(dbm.get_dumped_data_block_count(), 5) for (idx, meta) in enumerate(self.dumped_metas): self.assertEqual(dbm.get_data_block_meta_by_index(idx), meta)
def generate_follower_data_block(self): dbm = data_block_manager.DataBlockManager(self.data_source_f, 0) self.assertEqual(dbm.get_dumped_data_block_num(), 0) self.assertEqual(dbm.get_last_data_block_meta(), None) leader_index = 0 follower_index = 65536 self.dumped_metas = [] for i in range(5): builder = data_block_manager.DataBlockBuilder( self.data_source_f.data_block_dir, 0, i, None) for j in range(1024): feat = {} example_id = '{}'.format(i * 1024 + j).encode() feat['example_id'] = tf.train.Feature( bytes_list=tf.train.BytesList(value=[example_id])) event_time = 150000000 + i * 1024 + j feat['event_time'] = tf.train.Feature( int64_list=tf.train.Int64List(value=[event_time])) feat['leader_index'] = tf.train.Feature( int64_list=tf.train.Int64List(value=[leader_index])) feat['follower_index'] = tf.train.Feature( int64_list=tf.train.Int64List(value=[follower_index])) example = tf.train.Example(features=tf.train.Features( feature=feat)) builder.append(example.SerializeToString(), example_id, event_time, leader_index, follower_index) leader_index += 3 follower_index += 1 builder.finish_data_block() meta = builder.get_data_block_meta() self.dumped_metas.append(meta) dbm.add_dumped_data_block_meta(meta) self.leader_start_index = 0 self.leader_end_index = leader_index self.assertEqual(dbm.get_dumped_data_block_num(True), 5) for (idx, meta) in enumerate(self.dumped_metas): self.assertEqual(dbm.get_data_block_meta_by_index(idx)[0], meta) self.assertEqual(dbm.get_dumped_data_block_num(True), 5)
def test_data_block_manager(self):
    """Dump 5 blocks of 1024 examples each, then verify the manager's
    bookkeeping and the on-disk meta/data files byte-for-byte against the
    examples that were written.
    """
    data_block_datas = []
    data_block_metas = []
    leader_index = 0
    follower_index = 65536
    for i in range(5):
        fill_examples = []
        builder = DataBlockBuilder(
            self.data_source.data_block_dir,
            self.data_source.data_source_meta.name,
            0, i, dj_pb.WriterOptions(output_writer='TF_RECORD'), None)
        builder.set_data_block_manager(self.data_block_manager)
        for j in range(1024):
            feat = {}
            example_id = '{}'.format(i * 1024 + j).encode()
            feat['example_id'] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[example_id]))
            event_time = 150000000 + i * 1024 + j
            feat['event_time'] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[event_time]))
            feat['leader_index'] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[leader_index]))
            feat['follower_index'] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[follower_index]))
            example = tf.train.Example(features=tf.train.Features(
                feature=feat))
            builder.append_item(TfExampleItem(example.SerializeToString()),
                                leader_index, follower_index)
            # Keep a copy of what was appended for later disk comparison.
            fill_examples.append((example, {
                'example_id': example_id, 'event_time': event_time,
                'leader_index': leader_index,
                'follower_index': follower_index
            }))
            leader_index += 1
            follower_index += 1
        meta = builder.finish_data_block()
        data_block_datas.append(fill_examples)
        data_block_metas.append(meta)
    # Manager bookkeeping: 5 blocks, last meta matches the last finished.
    self.assertEqual(self.data_block_manager.get_dumped_data_block_count(),
                     5)
    self.assertEqual(self.data_block_manager.get_lastest_data_block_meta(),
                     data_block_metas[-1])
    for (idx, meta) in enumerate(data_block_metas):
        self.assertEqual(
            self.data_block_manager.get_data_block_meta_by_index(idx),
            meta)
        self.assertEqual(
            meta.block_id,
            common.encode_block_id(self.data_source.data_source_meta.name,
                                   meta))
    # An out-of-range index must yield None.
    self.assertEqual(
        self.data_block_manager.get_data_block_meta_by_index(5), None)
    data_block_dir = os.path.join(self.data_source.data_block_dir,
                                  common.partition_repr(0))
    for (i, meta) in enumerate(data_block_metas):
        data_block_fpath = os.path.join(
            data_block_dir, meta.block_id) + common.DataBlockSuffix
        data_block_meta_fpath = os.path.join(
            data_block_dir,
            common.encode_data_block_meta_fname(
                self.data_source.data_source_meta.name,
                0, meta.data_block_index))
        self.assertTrue(gfile.Exists(data_block_fpath))
        self.assertTrue(gfile.Exists(data_block_meta_fpath))
        # The meta file is a text-format proto wrapped in a TFRecord.
        fiter = tf.io.tf_record_iterator(data_block_meta_fpath)
        remote_meta = text_format.Parse(
            next(fiter).decode(), dj_pb.DataBlockMeta())
        self.assertEqual(meta, remote_meta)
        for (j, record) in enumerate(
                tf.io.tf_record_iterator(data_block_fpath)):
            example = tf.train.Example()
            example.ParseFromString(record)
            stored_data = data_block_datas[i][j]
            self.assertEqual(example, stored_data[0])
            feat = example.features.feature
            stored_feat = stored_data[1]
            self.assertTrue('example_id' in feat)
            self.assertTrue('example_id' in stored_feat)
            self.assertEqual(stored_feat['example_id'],
                             '{}'.format(i * 1024 + j).encode())
            self.assertEqual(stored_feat['example_id'],
                             feat['example_id'].bytes_list.value[0])
            self.assertTrue('event_time' in feat)
            self.assertTrue('event_time' in stored_feat)
            self.assertEqual(stored_feat['event_time'],
                             feat['event_time'].int64_list.value[0])
            self.assertTrue('leader_index' in feat)
            self.assertTrue('leader_index' in stored_feat)
            self.assertEqual(stored_feat['leader_index'],
                             feat['leader_index'].int64_list.value[0])
            self.assertTrue('follower_index' in feat)
            self.assertTrue('follower_index' in stored_feat)
            self.assertEqual(stored_feat['follower_index'],
                             feat['follower_index'].int64_list.value[0])
        # `j` survives the loop: each block holds exactly 1024 records.
        self.assertEqual(j, 1023)
    # A second manager over the same directory sees the same state.
    data_block_manager2 = data_block_manager.DataBlockManager(
        self.data_source, 0)
    self.assertEqual(self.data_block_manager.get_dumped_data_block_count(),
                     5)
def generate_raw_data(self, start_index, etcd, rdp, data_source,
                      raw_data_base_dir, partition_id, block_size,
                      shuffle_win_size, feat_key_fmt, feat_val_fmt):
    """Write locally-shuffled raw data blocks and publish them via ``rdp``.

    Generates ``self.total_index`` records starting at ``start_index`` in
    blocks of ``block_size``, with in-window random swaps, deletes the
    meta files, and publishes the data files through the raw data
    publisher.

    NOTE(review): the `etcd` parameter and the `dbm` local are unused here;
    presumably kept for signature/side-effect compatibility — confirm.
    """
    dbm = data_block_manager.DataBlockManager(data_source, partition_id)
    raw_data_dir = os.path.join(raw_data_base_dir,
                                common.partition_repr(partition_id))
    if not gfile.Exists(raw_data_dir):
        gfile.MakeDirs(raw_data_dir)
    useless_index = 0
    new_raw_data_fnames = []
    for block_index in range(start_index // block_size,
                             (start_index + self.total_index) //
                             block_size):
        builder = DataBlockBuilder(
            raw_data_base_dir,
            data_source.data_source_meta.name,
            partition_id, block_index,
            dj_pb.WriterOptions(output_writer='TF_RECORD'), None
        )
        cands = list(range(block_index * block_size,
                           (block_index + 1) * block_size))
        # NOTE(review): this rebinds the `start_index` parameter; it is
        # only used as the window base below, but the original argument
        # value is lost after the first iteration.
        start_index = cands[0]
        for i in range(len(cands)):
            # Skip roughly half of the positions at random.
            if random.randint(1, 4) > 2:
                continue
            a = random.randint(i - shuffle_win_size, i + shuffle_win_size)
            b = random.randint(i - shuffle_win_size, i + shuffle_win_size)
            # Clamp both slots into the candidate range.
            if a < 0:
                a = 0
            if a >= len(cands):
                a = len(cands) - 1
            if b < 0:
                b = 0
            if b >= len(cands):
                b = len(cands) - 1
            # Only swap when both values stay within the shuffle window.
            if (abs(cands[a]-i-start_index) <= shuffle_win_size and
                    abs(cands[b]-i-start_index) <= shuffle_win_size):
                cands[a], cands[b] = cands[b], cands[a]
        for example_idx in cands:
            feat = {}
            example_id = '{}'.format(example_idx).encode()
            feat['example_id'] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[example_id]))
            event_time = 150000000 + example_idx
            feat['event_time'] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[event_time]))
            feat[feat_key_fmt.format(example_idx)] = tf.train.Feature(
                bytes_list=tf.train.BytesList(
                    value=[feat_val_fmt.format(example_idx).encode()]))
            example = tf.train.Example(
                features=tf.train.Features(feature=feat))
            builder.append_item(TfExampleItem(example.SerializeToString()),
                                useless_index, useless_index)
            useless_index += 1
        meta = builder.finish_data_block()
        fname = common.encode_data_block_fname(
            data_source.data_source_meta.name, meta
        )
        new_raw_data_fnames.append(os.path.join(raw_data_dir, fname))
    # Remove the meta files; only the data files are published as raw data.
    fpaths = [os.path.join(raw_data_dir, f)
              for f in gfile.ListDirectory(raw_data_dir)
              if not gfile.IsDirectory(os.path.join(raw_data_dir, f))]
    for fpath in fpaths:
        if fpath.endswith(common.DataBlockMetaSuffix):
            gfile.Remove(fpath)
    rdp.publish_raw_data(partition_id, new_raw_data_fnames)
def test_data_block_dumper(self):
    """End-to-end check of DataBlockDumperManager (context-manager API).

    Generates follower data blocks and leader raw data, syncs the metas
    into the dumper, runs the dumper, and verifies that both sides hold
    identical metas and identical record contents on disk.
    """
    self.generate_follower_data_block()
    self.generate_leader_raw_data()
    dbd = data_block_dumper.DataBlockDumperManager(
        self.etcd, self.data_source_l, 0,
        dj_pb.RawDataOptions(raw_data_iter='TF_RECORD'),
        dj_pb.DataBlockBuilderOptions(
            data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
    )
    self.assertEqual(dbd.get_next_data_block_index(), 0)
    # Feed every follower meta; the dumper must accept each one and
    # advance its next-index counter by exactly one per meta.
    for (idx, meta) in enumerate(self.dumped_metas):
        success, next_index = dbd.add_synced_data_block_meta(meta)
        self.assertTrue(success)
        self.assertEqual(next_index, idx + 1)
    self.assertTrue(dbd.need_dump())
    self.assertEqual(dbd.get_next_data_block_index(),
                     len(self.dumped_metas))
    with dbd.make_data_block_dumper() as dumper:
        dumper()
    dbm_f = data_block_manager.DataBlockManager(self.data_source_f, 0)
    dbm_l = data_block_manager.DataBlockManager(self.data_source_l, 0)
    # Both sides must now report the same number of dumped blocks.
    self.assertEqual(dbm_f.get_dumped_data_block_count(),
                     len(self.dumped_metas))
    self.assertEqual(dbm_f.get_dumped_data_block_count(),
                     dbm_l.get_dumped_data_block_count())
    for (idx, meta) in enumerate(self.dumped_metas):
        self.assertEqual(meta.data_block_index, idx)
        self.assertEqual(dbm_l.get_data_block_meta_by_index(idx), meta)
        self.assertEqual(dbm_f.get_data_block_meta_by_index(idx), meta)
        # Leader-side meta file (text-format proto) must parse back to the
        # original meta.
        meta_fpth_l = os.path.join(
            self.data_source_l.data_block_dir, common.partition_repr(0),
            common.encode_data_block_meta_fname(
                self.data_source_l.data_source_meta.name,
                0, meta.data_block_index))
        mitr = tf.io.tf_record_iterator(meta_fpth_l)
        meta_l = text_format.Parse(next(mitr), dj_pb.DataBlockMeta())
        self.assertEqual(meta_l, meta)
        # Follower-side meta file likewise.
        meta_fpth_f = os.path.join(
            self.data_source_f.data_block_dir, common.partition_repr(0),
            common.encode_data_block_meta_fname(
                self.data_source_f.data_source_meta.name,
                0, meta.data_block_index))
        mitr = tf.io.tf_record_iterator(meta_fpth_f)
        meta_f = text_format.Parse(next(mitr), dj_pb.DataBlockMeta())
        self.assertEqual(meta_f, meta)
        # Leader-side data records must carry the example ids in meta
        # order.
        data_fpth_l = os.path.join(
            self.data_source_l.data_block_dir, common.partition_repr(0),
            common.encode_data_block_fname(
                self.data_source_l.data_source_meta.name, meta_l))
        for (iidx, record) in enumerate(
                tf.io.tf_record_iterator(data_fpth_l)):
            example = tf.train.Example()
            example.ParseFromString(record)
            feat = example.features.feature
            self.assertEqual(feat['example_id'].bytes_list.value[0],
                             meta.example_ids[iidx])
        # `iidx` survives the loop: record count on disk must equal the
        # meta's example-id count.
        self.assertEqual(len(meta.example_ids), iidx + 1)
        # NOTE(review): the follower-side path is built from the LEADER's
        # data_source_meta.name; presumably both sources share one name —
        # confirm this is intentional.
        data_fpth_f = os.path.join(
            self.data_source_f.data_block_dir, common.partition_repr(0),
            common.encode_data_block_fname(
                self.data_source_l.data_source_meta.name, meta_f))
        for (iidx, record) in enumerate(
                tf.io.tf_record_iterator(data_fpth_f)):
            example = tf.train.Example()
            example.ParseFromString(record)
            feat = example.features.feature
            self.assertEqual(feat['example_id'].bytes_list.value[0],
                             meta.example_ids[iidx])
        self.assertEqual(len(meta.example_ids), iidx + 1)