def test_dumped_example_visitor(self):
     """Exercise ExampleIdVisitor before and after example ids are dumped.

     The scenario, in order:
       1. A fresh visitor yields nothing and seeking raises StopIteration.
       2. After a first dump the visitor reports itself stale; once
          re-activated it walks indices 0..10239.
       3. The visitor seeks back to 200, a second dump is made, and after
          re-activation iteration resumes at 201 and ends at 20480.
       4. A second, brand-new visitor seeks to 886 and walks to 20480.
     """
     def walk(cursor, next_index):
         # Iterate `cursor` to exhaustion, checking that every yielded
         # (index, example) pair carries the next consecutive index and
         # the id / event time derived from it. Returns the index that
         # would follow the last item seen.
         for idx, item in cursor:
             self.assertEqual(idx, next_index)
             self.assertEqual('{}'.format(idx).encode(), item.example_id)
             self.assertEqual(150000000+idx, item.event_time)
             self.assertEqual(idx, item.index)
             next_index += 1
         return next_index

     def check_position(cursor, position):
         # Verify that `cursor` currently sits on `position`.
         self.assertEqual(position, cursor.get_index())
         self.assertEqual(position, cursor.get_item().index)
         self.assertEqual(150000000+position, cursor.get_item().event_time)

     # Phase 1: nothing has been dumped yet.
     visitor = example_id_visitor.ExampleIdVisitor(self.kvstore, self.data_source, 0)
     reached = walk(visitor, 0)
     self.assertEqual(0, reached)
     self.assertRaises(StopIteration, visitor.seek, 200)
     self.assertTrue(visitor.finished())

     # Phase 2: first dump, then refresh the stale visitor.
     dumper = example_id_dumper.ExampleIdDumperManager(
         self.kvstore, self.data_source, 0, self.example_id_dump_options
     )
     self.assertEqual(dumper.get_next_index(), 0)
     self._dump_example_ids(dumper, 0, 10, 1024)
     self.assertTrue(visitor.is_visitor_stale())
     visitor.active_visitor()
     reached = walk(visitor, reached)
     self.assertEqual(10240, reached)
     self.assertTrue(visitor.finished())

     # Phase 3: seek backwards, dump more, resume from the item after 200.
     visitor.seek(200)
     reached = 200
     check_position(visitor, reached)
     dumper2 = example_id_dumper.ExampleIdDumperManager(
         self.kvstore, self.data_source, 0, self.example_id_dump_options
     )
     self._dump_example_ids(dumper2, 10240, 10, 1024)
     reached += 1
     self.assertTrue(visitor.is_visitor_stale())
     visitor.active_visitor()
     reached = walk(visitor, reached)
     self.assertEqual(10240 * 2, reached)

     # Phase 4: an independent visitor sees everything dumped so far.
     visitor2 = example_id_visitor.ExampleIdVisitor(self.kvstore, self.data_source, 0)
     visitor2.seek(886)
     reached = 886
     check_position(visitor2, reached)
     reached += 1
     reached = walk(visitor2, reached)
     self.assertEqual(10240 * 2, reached)
 def test_example_id_dumper(self):
     """Check the indices reported around two consecutive dumps.

     A first dumper manager starts at index 0; after it dumps, the
     ExampleIdManager reports 10239 as the last dumped index. A second
     dumper manager created afterwards starts at 10240, and after its
     dump the same ExampleIdManager reports 20479.
     """
     first_dumper = example_id_dumper.ExampleIdDumperManager(
         self.kvstore, self.data_source, 0, self.example_id_dump_options)
     self.assertEqual(first_dumper.get_next_index(), 0)
     self._dump_example_ids(first_dumper, 0, 10, 1024)

     # The same manager instance is queried again after the second dump.
     id_manager = example_id_visitor.ExampleIdManager(
         self.kvstore, self.data_source, 0, True)
     self.assertEqual(id_manager.get_last_dumped_index(), 10240 - 1)

     second_dumper = example_id_dumper.ExampleIdDumperManager(
         self.kvstore, self.data_source, 0, self.example_id_dump_options)
     self.assertEqual(second_dumper.get_next_index(), 10240)
     self._dump_example_ids(second_dumper, 10 * 1024, 10, 1024)
     self.assertEqual(id_manager.get_last_dumped_index(), 2 * 10240 - 1)
Exemple #3
0
    def generate_example_id(self):
        """Feed locally shuffled example ids to a dumper manager and dump them.

        Example indices are sent in requests of 512 ids each
        (``self.total_index // 512`` requests in total). Inside a request
        the ids are perturbed by random pairwise swaps, where a swap is
        only applied if both swapped values stay within 32 of the position
        currently being visited. After every request the manager's next
        index is asserted to have advanced by 512. Once all requests are
        appended the sync is finished and the ids are dumped.
        """
        batch_size = 512
        window = 32
        manager = example_id_dumper.ExampleIdDumperManager(self.data_source, 0)

        for batch_no in range(self.total_index // batch_size):
            first = batch_no * batch_size
            request = dj_pb.SyncExamplesRequest(
                data_source_meta=self.data_source.data_source_meta,
                partition_id=0,
                begin_index=first)
            order = list(range(first, first + batch_size))
            last_pos = len(order) - 1
            for pos in range(len(order)):
                # Only a draw of 1 (out of 1..4) attempts a swap here.
                if random.randint(1, 4) != 1:
                    continue
                # Draw both swap partners first, then clamp them into the
                # valid position range of the batch.
                left = random.randint(pos - window, pos + window)
                right = random.randint(pos - window, pos + window)
                left = min(max(left, 0), last_pos)
                right = min(max(right, 0), last_pos)
                left_near = abs(order[left] - pos - first) <= window
                right_near = abs(order[right] - pos - first) <= window
                if left_near and right_near:
                    order[left], order[right] = order[right], order[left]
            for example_idx in order:
                request.example_id.append('{}'.format(example_idx).encode())
                request.event_time.append(150000000 + example_idx)
            manager.append_synced_example_req(request)
            self.assertEqual(manager.get_next_index(), first + batch_size)
        manager.finish_sync_example()
        self.assertTrue(manager.need_dump())
        manager.dump_example_ids()
Exemple #4
0
 def setUp(self):
     """Build a one-partition data source and dump 5 * 32768 example ids.

     Sets ``self.data_source``, ``self.partition_dir``,
     ``self._example_id_dumper`` and ``self.end_index`` (the last example
     index appended). Any existing ``./example_ids`` directory is removed
     first so each test starts from a clean dump directory.
     """
     source = common_pb.DataSource()
     source.data_source_meta.name = "milestone-x"
     source.data_source_meta.partition_num = 1
     source.example_dumped_dir = "./example_ids"
     self.data_source = source

     dump_dir = source.example_dumped_dir
     if gfile.Exists(dump_dir):
         gfile.DeleteRecursively(dump_dir)
     self.partition_dir = os.path.join(dump_dir, 'partition_0')
     gfile.MakeDirs(self.partition_dir)

     dumper = example_id_dumper.ExampleIdDumperManager(source, 0)
     self._example_id_dumper = dumper
     self.assertEqual(dumper.get_next_index(), 0)

     per_request = 1 << 15
     # Five back-to-back requests, each starting where the previous ended.
     for begin in range(0, 5 * per_request, per_request):
         request = dj_pb.SyncExamplesRequest(
             data_source_meta=source.data_source_meta,
             partition_id=0,
             begin_index=begin)
         for example_index in range(begin, begin + per_request):
             request.example_id.append('{}'.format(example_index).encode())
             request.event_time.append(150000000 + example_index)
             self.end_index = example_index
         dumper.append_synced_example_req(request)
         self.assertEqual(dumper.get_next_index(), begin + per_request)

     dumper.finish_sync_example()
     self.assertTrue(dumper.need_dump())
     dumper.dump_example_ids()
    def test_example_joiner(self):
        """Run the example joiner while raw data and example ids arrive in turns.

        Raw data and example ids are generated in interleaved steps, and
        the joiner is drained after each step. The first drain, before any
        data exists, must yield nothing. At the end the metas yielded by
        the joiner must match, one by one, the data block metas reported
        by DataBlockManager. The join rate is printed, not asserted.
        """
        sei = joiner_impl.create_example_joiner(
            self.example_joiner_options, self.raw_data_options,
            dj_pb.WriterOptions(output_writer='TF_RECORD'), self.kvstore,
            self.data_source, 0)
        metas = []

        def drain():
            # Pull every meta the joiner can currently produce.
            with sei.make_example_joiner() as joiner:
                for produced in joiner:
                    metas.append(produced)

        drain()
        self.assertEqual(len(metas), 0)

        self.generate_raw_data(0, 2 * 2048)
        dumper = example_id_dumper.ExampleIdDumperManager(
            self.kvstore, self.data_source, 0, self.example_id_dump_options)
        self.generate_example_id(dumper, 0, 3 * 2048)
        drain()

        self.generate_raw_data(2 * 2048, 2048)
        self.generate_example_id(dumper, 3 * 2048, 3 * 2048)
        drain()

        self.generate_raw_data(3 * 2048, 5 * 2048)
        self.generate_example_id(dumper, 6 * 2048, 2048)
        drain()

        self.generate_raw_data(8 * 2048, 2 * 2048)
        drain()

        self.generate_example_id(dumper, 7 * 2048, 3 * 2048)
        drain()

        sei.set_sync_example_id_finished()
        sei.set_raw_data_finished()
        drain()

        dbm = data_block_manager.DataBlockManager(self.data_source, 0)
        block_count = dbm.get_dumped_data_block_count()
        self.assertEqual(len(metas), block_count)
        joined = 0
        # The length check above lets us pair each yielded meta with the
        # stored meta at the same index.
        for block_index, yielded in enumerate(metas):
            stored = dbm.get_data_block_meta_by_index(block_index)
            self.assertEqual(stored, yielded)
            joined += len(stored.example_ids)

        report = ("join rate {}/{}({}), min_matching_window {}, "
                  "max_matching_window {}")
        options = self.example_joiner_options
        print(report.format(
            joined, 20480,
            (joined+.0)/(10 * 2048),
            options.min_matching_window,
            options.max_matching_window))
    def run_join_small_follower(self, sei, rate):
        """Drive `sei` through a join and require a minimum join rate.

        Args:
            sei: joiner object providing ``make_example_joiner()``,
                ``set_raw_data_finished()`` and
                ``set_sync_example_id_finished()``.
            rate: lower bound for ``join_count / (10 * 2048 * 2)``; the
                test fails if the measured ratio is below it.

        The metas yielded by the joiner must also match the data block
        metas reported by DataBlockManager, index by index.
        """
        metas = []

        def drain():
            # Pull every meta the joiner can currently produce.
            with sei.make_example_joiner() as joiner:
                for produced in joiner:
                    metas.append(produced)

        # Nothing has been generated yet, so nothing may be produced.
        drain()
        self.assertEqual(len(metas), 0)

        self.generate_raw_data(8, 2 * 2048)
        dumper = example_id_dumper.ExampleIdDumperManager(
            self.kvstore, self.data_source, 0, self.example_id_dump_options)
        self.generate_example_id(dumper, 0, 3 * 2048)
        drain()

        # Raw data is marked finished before the remaining ids arrive.
        sei.set_raw_data_finished()
        self.generate_example_id(dumper, 3 * 2048, 7 * 2048)
        drain()

        sei.set_sync_example_id_finished()
        drain()

        dbm = data_block_manager.DataBlockManager(self.data_source, 0)
        block_count = dbm.get_dumped_data_block_count()
        self.assertEqual(len(metas), block_count)
        joined = 0
        # The length check above lets us pair each yielded meta with the
        # stored meta at the same index.
        for block_index, yielded in enumerate(metas):
            stored = dbm.get_data_block_meta_by_index(block_index)
            self.assertEqual(stored, yielded)
            joined += len(stored.example_ids)

        ratio = (joined + .0) / (10 * 2048 * 2)
        report = ("join rate {}/{}({}), min_matching_window {}, "
                  "max_matching_window {}")
        options = self.example_joiner_options
        print(report.format(
            joined, 20480 * 2,
            ratio,
            options.min_matching_window,
            options.max_matching_window))
        self.assertTrue(ratio >= rate)