def test_dumped_example_visitor(self):
    """Exercise ExampleIdVisitor before, between and after two dumps.

    The test iterates a visitor on partition 0, expects it to yield
    nothing before any example ids are dumped, then dumps twice via
    ``self._dump_example_ids`` and checks that the reactivated visitor
    (and a second, freshly built visitor) yields consecutive indices
    whose ``example_id``, ``event_time`` and ``index`` match the index.
    """

    def check_consecutive(pairs, next_index):
        # Assert that every (index, example) pair continues the sequence
        # starting at next_index; return the index expected after the
        # last pair that was seen.
        for index, example in pairs:
            self.assertEqual(index, next_index)
            self.assertEqual('{}'.format(index).encode(),
                             example.example_id)
            self.assertEqual(150000000+index, example.event_time)
            self.assertEqual(index, example.index)
            next_index += 1
        return next_index

    visitor = example_id_visitor.ExampleIdVisitor(
        self.kvstore, self.data_source, 0)

    # Before anything is dumped the test expects no items, a failing
    # seek, and a visitor that reports itself finished.
    expected_index = check_consecutive(visitor, 0)
    self.assertEqual(0, expected_index)
    self.assertRaises(StopIteration, visitor.seek, 200)
    self.assertTrue(visitor.finished())

    # First dump: afterwards the existing visitor must report stale and,
    # once reactivated, yield indices up to (but excluding) 10240.
    dumper = example_id_dumper.ExampleIdDumperManager(
        self.kvstore, self.data_source, 0, self.example_id_dump_options
    )
    self.assertEqual(dumper.get_next_index(), 0)
    self._dump_example_ids(dumper, 0, 10, 1024)
    self.assertTrue(visitor.is_visitor_stale())
    visitor.active_visitor()
    expected_index = check_consecutive(visitor, expected_index)
    self.assertEqual(10240, expected_index)
    self.assertTrue(visitor.finished())

    # Seeking back to 200 positions the visitor on that item.
    visitor.seek(200)
    expected_index = 200
    self.assertEqual(expected_index, visitor.get_index())
    self.assertEqual(expected_index, visitor.get_item().index)
    self.assertEqual(150000000+expected_index,
                     visitor.get_item().event_time)

    # Second dump, starting at 10240; iteration resumes after item 200.
    dumper2 = example_id_dumper.ExampleIdDumperManager(
        self.kvstore, self.data_source, 0, self.example_id_dump_options
    )
    self._dump_example_ids(dumper2, 10240, 10, 1024)
    expected_index += 1
    self.assertTrue(visitor.is_visitor_stale())
    visitor.active_visitor()
    expected_index = check_consecutive(visitor, expected_index)
    self.assertEqual(10240 * 2, expected_index)

    # A brand-new visitor can seek straight into the dumped range.
    visitor2 = example_id_visitor.ExampleIdVisitor(
        self.kvstore, self.data_source, 0)
    visitor2.seek(886)
    expected_index = 886
    self.assertEqual(expected_index, visitor2.get_index())
    self.assertEqual(expected_index, visitor2.get_item().index)
    self.assertEqual(150000000+expected_index,
                     visitor2.get_item().event_time)
    expected_index += 1
    expected_index = check_consecutive(visitor2, expected_index)
    self.assertEqual(10240 * 2, expected_index)
def test_example_id_dumper(self):
    """Dump example ids with two successive dumper managers.

    After each call to ``self._dump_example_ids`` the test checks the
    last dumped index reported by an ``ExampleIdManager``, and checks
    that the second dumper manager starts at index 10240.
    """
    first_dumper = example_id_dumper.ExampleIdDumperManager(
        self.kvstore, self.data_source, 0, self.example_id_dump_options)
    self.assertEqual(first_dumper.get_next_index(), 0)
    self._dump_example_ids(first_dumper, 0, 10, 1024)

    manager = example_id_visitor.ExampleIdManager(
        self.kvstore, self.data_source, 0, True)
    self.assertEqual(manager.get_last_dumped_index(), 10240 - 1)

    # A second dumper manager built on the same store is expected to
    # pick up at the index following the first dump.
    second_dumper = example_id_dumper.ExampleIdDumperManager(
        self.kvstore, self.data_source, 0, self.example_id_dump_options)
    self.assertEqual(second_dumper.get_next_index(), 10240)
    self._dump_example_ids(second_dumper, 10 * 1024, 10, 1024)
    self.assertEqual(manager.get_last_dumped_index(), 2 * 10240 - 1)
def generate_example_id(self):
    """Feed locally shuffled example ids to a dumper manager and dump them.

    Example indices are sent in requests of 512. Within each request,
    pairs of positions are randomly swapped, but only when both swapped
    values stay within 32 of the position that triggered the swap. Each
    request is appended to an ``ExampleIdDumperManager``; once all
    requests are sent the sync is finished and the ids are dumped.
    """
    eid = example_id_dumper.ExampleIdDumperManager(self.data_source, 0)
    for req_index in range(self.total_index // 512):
        start_index = req_index * 512
        req = dj_pb.SyncExamplesRequest(
            data_source_meta=self.data_source.data_source_meta,
            partition_id=0,
            begin_index=start_index)
        cands = list(range(start_index, (req_index + 1) * 512))
        last_pos = len(cands) - 1
        for i in range(len(cands)):
            # Only roughly one position in four attempts a swap.
            if random.randint(1, 4) > 1:
                continue
            # Draw two positions near i, clamped into the valid range.
            a = min(max(random.randint(i - 32, i + 32), 0), last_pos)
            b = min(max(random.randint(i - 32, i + 32), 0), last_pos)
            a_is_near = abs(cands[a] - i - start_index) <= 32
            b_is_near = abs(cands[b] - i - start_index) <= 32
            if a_is_near and b_is_near:
                cands[a], cands[b] = cands[b], cands[a]
        for example_idx in cands:
            req.example_id.append('{}'.format(example_idx).encode())
            req.event_time.append(150000000 + example_idx)
        eid.append_synced_example_req(req)
        self.assertEqual(eid.get_next_index(), (req_index + 1) * 512)
    eid.finish_sync_example()
    self.assertTrue(eid.need_dump())
    eid.dump_example_ids()
def setUp(self):
    """Create a clean dump directory and dump 5 * 2**15 example ids.

    Builds a single-partition ``DataSource`` named "milestone-x" whose
    example ids are dumped under ``./example_ids``, recreates the
    ``partition_0`` directory, then appends five requests of ``1 << 15``
    sequential example ids to an ``ExampleIdDumperManager`` and dumps
    them. ``self.end_index`` ends up holding the last index appended.
    """
    data_source = common_pb.DataSource()
    data_source.data_source_meta.name = "milestone-x"
    data_source.data_source_meta.partition_num = 1
    data_source.example_dumped_dir = "./example_ids"
    self.data_source = data_source

    # Start from an empty dump directory on every run.
    if gfile.Exists(self.data_source.example_dumped_dir):
        gfile.DeleteRecursively(self.data_source.example_dumped_dir)
    self.partition_dir = os.path.join(
        self.data_source.example_dumped_dir, 'partition_0')
    gfile.MakeDirs(self.partition_dir)

    self._example_id_dumper = example_id_dumper.ExampleIdDumperManager(
        self.data_source, 0)
    self.assertEqual(self._example_id_dumper.get_next_index(), 0)

    ids_per_request = 1 << 15
    next_index = 0
    for _ in range(5):
        req = dj_pb.SyncExamplesRequest(
            data_source_meta=data_source.data_source_meta,
            partition_id=0,
            begin_index=next_index)
        for _ in range(ids_per_request):
            req.example_id.append('{}'.format(next_index).encode())
            req.event_time.append(150000000 + next_index)
            self.end_index = next_index
            next_index += 1
        self._example_id_dumper.append_synced_example_req(req)
        self.assertEqual(
            self._example_id_dumper.get_next_index(), next_index)

    self._example_id_dumper.finish_sync_example()
    self.assertTrue(self._example_id_dumper.need_dump())
    self._example_id_dumper.dump_example_ids()
def test_example_joiner(self):
    """Run the example joiner while raw data and example ids arrive in steps.

    Raw data and example ids are generated in alternating chunks; after
    each step the joiner is drained and its yielded metas are collected.
    At the end the collected metas are compared, one by one, with the
    metas reported by a ``DataBlockManager`` for partition 0, and the
    resulting join rate is printed.
    """
    sei = joiner_impl.create_example_joiner(
        self.example_joiner_options,
        self.raw_data_options,
        dj_pb.WriterOptions(output_writer='TF_RECORD'),
        self.kvstore, self.data_source, 0)
    metas = []

    def drain_joiner():
        # Open a joiner and collect every meta it yields.
        with sei.make_example_joiner() as joiner:
            for meta in joiner:
                metas.append(meta)

    # With no input generated yet the test expects no metas at all.
    drain_joiner()
    self.assertEqual(len(metas), 0)

    self.generate_raw_data(0, 2 * 2048)
    dumper = example_id_dumper.ExampleIdDumperManager(
        self.kvstore, self.data_source, 0, self.example_id_dump_options)
    self.generate_example_id(dumper, 0, 3 * 2048)
    drain_joiner()

    self.generate_raw_data(2 * 2048, 2048)
    self.generate_example_id(dumper, 3 * 2048, 3 * 2048)
    drain_joiner()

    self.generate_raw_data(3 * 2048, 5 * 2048)
    self.generate_example_id(dumper, 6 * 2048, 2048)
    drain_joiner()

    self.generate_raw_data(8 * 2048, 2 * 2048)
    drain_joiner()

    self.generate_example_id(dumper, 7 * 2048, 3 * 2048)
    drain_joiner()

    # Mark both input streams as finished, then drain one last time.
    sei.set_sync_example_id_finished()
    sei.set_raw_data_finished()
    drain_joiner()

    dbm = data_block_manager.DataBlockManager(self.data_source, 0)
    data_block_num = dbm.get_dumped_data_block_count()
    self.assertEqual(len(metas), data_block_num)

    join_count = 0
    for block_idx in range(data_block_num):
        dumped_meta = dbm.get_data_block_meta_by_index(block_idx)
        self.assertEqual(dumped_meta, metas[block_idx])
        join_count += len(dumped_meta.example_ids)

    report = ("join rate {}/{}({}), min_matching_window {}, "
              "max_matching_window {}")
    print(report.format(
        join_count, 20480,
        (join_count+.0)/(10 * 2048),
        self.example_joiner_options.min_matching_window,
        self.example_joiner_options.max_matching_window))
def run_join_small_follower(self, sei, rate):
    """Drive ``sei`` with less raw data than example ids and check the join rate.

    Args:
        sei: object exposing ``make_example_joiner``,
            ``set_raw_data_finished`` and ``set_sync_example_id_finished``.
        rate: lower bound asserted on ``join_count / (10 * 2048 * 2)``.

    The metas yielded by the joiner are compared with those reported by
    a ``DataBlockManager`` for partition 0 before the rate is asserted.
    """
    metas = []

    def drain_joiner():
        # Open a joiner and collect every meta it yields.
        with sei.make_example_joiner() as joiner:
            for meta in joiner:
                metas.append(meta)

    # With no input generated yet the test expects no metas at all.
    drain_joiner()
    self.assertEqual(len(metas), 0)

    self.generate_raw_data(8, 2 * 2048)
    dumper = example_id_dumper.ExampleIdDumperManager(
        self.kvstore, self.data_source, 0, self.example_id_dump_options)
    self.generate_example_id(dumper, 0, 3 * 2048)
    drain_joiner()

    # Raw data is marked finished before the remaining ids are generated.
    sei.set_raw_data_finished()
    self.generate_example_id(dumper, 3 * 2048, 7 * 2048)
    drain_joiner()

    sei.set_sync_example_id_finished()
    drain_joiner()

    dbm = data_block_manager.DataBlockManager(self.data_source, 0)
    data_block_num = dbm.get_dumped_data_block_count()
    self.assertEqual(len(metas), data_block_num)

    join_count = 0
    for block_idx in range(data_block_num):
        dumped_meta = dbm.get_data_block_meta_by_index(block_idx)
        self.assertEqual(dumped_meta, metas[block_idx])
        join_count += len(dumped_meta.example_ids)

    report = ("join rate {}/{}({}), min_matching_window {}, "
              "max_matching_window {}")
    print(report.format(
        join_count, 20480 * 2,
        (join_count+.0)/(10 * 2048 * 2),
        self.example_joiner_options.min_matching_window,
        self.example_joiner_options.max_matching_window))
    self.assertTrue((join_count + .0) / (10 * 2048 * 2) >= rate)