def _allocate_join_partition_fn(self):
    """Request a raw-data partition from the master and install its joiner.

    Preconditions: no partition is currently being processed
    (``self._processing_manifest is None``).

    Outcomes:
      * master reports ``finished``  -> clear ``self._state`` and return;
      * master has no manifest ready -> log a warning and return (caller
        is expected to retry later);
      * otherwise -> build an example joiner for the allocated partition,
        publish manifest/joiner under ``self._lock`` and wake the leader
        joiner and meta-syncer routines.

    Raises:
        RuntimeError: if the master responds with a non-zero status code.
    """
    assert self._processing_manifest is None
    req = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        rank_id=self._rank_id,
        # partition_id=-1 lets the master pick any available partition
        join_example=dj_pb.JoinExampleRequest(partition_id=-1))
    rsp = self._master_client.RequestJoinPartition(req)
    if rsp.status.code != 0:
        # fixed typo in the error message: "intsesection" -> "intersection"
        raise RuntimeError("Failed to request partition for "
                           "example intersection, error msg {}".format(
                               rsp.status.error_message))
    if rsp.HasField('finished'):
        # every partition has been joined; nothing left to allocate
        with self._lock:
            self._state = None
        return
    if not rsp.HasField('manifest'):
        # master has no partition ready yet; caller should retry
        logging.warning("no manifest is at state %d, wait and retry",
                        dj_pb.RawDataState.Synced)
        return
    joiner = create_example_joiner(self._etcd, self._data_source,
                                   rsp.manifest.partition_id)
    with self._lock:
        self._processing_manifest = rsp.manifest
        self._joiner = joiner
        self._check_manifest()
    # NOTE(review): wakeups placed outside the lock to avoid holding it
    # across notifications — original indentation was lost; confirm scope.
    self._wakeup_leader_example_joiner()
    self._wakeup_data_block_meta_syncer()
def test_example_joiner(self):
    """Drive the example joiner through staggered raw-data/example-id feeds
    and check the dumped data-block metas against the data block manager."""
    sei = joiner_impl.create_example_joiner(
        self.example_joiner_options, self.raw_data_options,
        dj_pb.WriterOptions(output_writer='TF_RECORD'),
        self.kvstore, self.data_source, 0)

    metas = []

    def run_join_round():
        # Drain one joiner session, accumulating any dumped metas.
        with sei.make_example_joiner() as joiner:
            for meta in joiner:
                metas.append(meta)

    # No input generated yet: first round must dump nothing.
    run_join_round()
    self.assertEqual(len(metas), 0)

    self.generate_raw_data(0, 2 * 2048)
    dumper = example_id_dumper.ExampleIdDumperManager(
        self.kvstore, self.data_source, 0, self.example_id_dump_options)
    self.generate_example_id(dumper, 0, 3 * 2048)
    run_join_round()

    self.generate_raw_data(2 * 2048, 2048)
    self.generate_example_id(dumper, 3 * 2048, 3 * 2048)
    run_join_round()

    self.generate_raw_data(3 * 2048, 5 * 2048)
    self.generate_example_id(dumper, 6 * 2048, 2048)
    run_join_round()

    self.generate_raw_data(8 * 2048, 2 * 2048)
    run_join_round()

    self.generate_example_id(dumper, 7 * 2048, 3 * 2048)
    run_join_round()

    # Signal end-of-input and flush whatever remains.
    sei.set_sync_example_id_finished()
    sei.set_raw_data_finished()
    run_join_round()

    dbm = data_block_manager.DataBlockManager(self.data_source, 0)
    data_block_num = dbm.get_dumped_data_block_count()
    self.assertEqual(len(metas), data_block_num)
    join_count = 0
    for block_idx in range(data_block_num):
        meta = dbm.get_data_block_meta_by_index(block_idx)
        self.assertEqual(meta, metas[block_idx])
        join_count += len(meta.example_ids)
    print("join rate {}/{}({}), min_matching_window {}, "
          "max_matching_window {}".format(
              join_count, 20480, (join_count+.0)/(10 * 2048),
              self.example_joiner_options.min_matching_window,
              self.example_joiner_options.max_matching_window))
def __init__(self, etcd, data_source, raw_data_manifest,
             example_joiner_options, raw_data_options,
             data_block_builder_options):
    """Set up the impl-context for one raw-data manifest and create the
    example joiner bound to that manifest's partition."""
    super(ExampleJoinLeader.ImplContext, self).__init__(raw_data_manifest)
    partition_id = raw_data_manifest.partition_id
    self.example_joiner = create_example_joiner(
        example_joiner_options, raw_data_options,
        data_block_builder_options, etcd, data_source, partition_id)
def test_universal_join_key_mapper_error(self):
    """Run a universal join with a key mapper whose leader_mapping raises.

    Writes a mock key-mapper module into the key_mapper/impl package,
    reloads the package so the mock is discovered, then runs the join.
    The mock module file is removed in a ``finally`` block so a failing
    join can no longer leave it behind to pollute later tests (the
    original only removed it on the success path).
    """
    # NOTE(review): the mock-module source had its newlines mangled in
    # the original file; reconstructed with standard indentation here.
    mapper_code = """
from fedlearner.data_join.key_mapper.key_mapping import BaseKeyMapper
class KeyMapperMock(BaseKeyMapper):
    def leader_mapping(self, item) -> dict:
        res = item.click_id.decode().split("_")
        raise ValueError
        return dict({"req_id":res[0], "cid":res[1]})

    def follower_mapping(self, item) -> dict:
        return dict()

    @classmethod
    def name(cls):
        return "TEST_MAPPER"
"""
    abspath = os.path.dirname(os.path.abspath(__file__))
    fname = os.path.realpath(
        os.path.join(
            abspath,
            "../../fedlearner/data_join/key_mapper/impl/keymapper_mock.py")
    )
    with open(fname, "w") as f:
        f.write(mapper_code)
    try:
        # Re-import so the freshly written mock mapper is registered.
        reload(key_mapper)
        self.example_joiner_options = dj_pb.ExampleJoinerOptions(
            example_joiner='UNIVERSAL_JOINER',
            min_matching_window=32,
            max_matching_window=51200,
            max_conversion_delay=interval_to_timestamp("258"),
            enable_negative_example_generator=True,
            data_block_dump_interval=32,
            data_block_dump_threshold=1024,
            negative_sampling_rate=0.8,
            join_expr="(cid,req_id) or (example_id)",
            join_key_mapper="TEST_MAPPER",
            negative_sampling_filter_expr='',
        )
        self.version = dsp.Version.V2
        sei = joiner_impl.create_example_joiner(
            self.example_joiner_options,
            self.raw_data_options,
            dj_pb.WriterOptions(output_writer='CSV_DICT'),
            self.kvstore,
            self.data_source,
            0)
        self.run_join(sei, 0)
    finally:
        # Always clean up the mock mapper module.
        os.remove(fname)
def test_universal_join_small_follower(self):
    """Universal join where the follower side lags: expect a low join rate
    (threshold 0.15) via run_join_small_follower."""
    options = dj_pb.ExampleJoinerOptions(
        example_joiner='UNIVERSAL_JOINER',
        min_matching_window=32,
        max_matching_window=20240,
        max_conversion_delay=interval_to_timestamp("128"),
        enable_negative_example_generator=False,
        data_block_dump_interval=32,
        data_block_dump_threshold=1024,
        negative_sampling_rate=0.8,
        join_expr="(id_type, example_id, trunc(event_time,1))",
        join_key_mapper="DEFAULT",
        negative_sampling_filter_expr='',
    )
    self.example_joiner_options = options
    self.version = dsp.Version.V2
    sei = joiner_impl.create_example_joiner(
        options,
        self.raw_data_options,
        dj_pb.WriterOptions(output_writer='TF_RECORD'),
        self.kvstore,
        self.data_source,
        0)
    self.run_join_small_follower(sei, 0.15)
def test_example_join(self):
    """End-to-end stream join over generated raw data and example ids;
    prints the resulting join rate."""
    self.generate_raw_data()
    self.generate_example_id()
    customized_options.set_example_joiner('STREAM_JOINER')
    sei = joiner_impl.create_example_joiner(self.etcd, self.data_source, 0)
    sei.join_example()
    self.assertTrue(sei.join_finished())
    dbm = data_block_manager.DataBlockManager(self.data_source, 0)
    block_count = dbm.get_dumped_data_block_num()
    join_count = 0
    for block_idx in range(block_count):
        # get_data_block_meta_by_index returns a tuple; meta is slot 0
        meta = dbm.get_data_block_meta_by_index(block_idx)[0]
        self.assertTrue(meta is not None)
        join_count += len(meta.example_ids)
    ds_meta = self.data_source.data_source_meta
    print("join rate {}/{}({}), min_matching_window {}, "
          "max_matching_window {}".format(
              join_count, self.total_index,
              (join_count+.0)/self.total_index,
              ds_meta.min_matching_window,
              ds_meta.max_matching_window))