Example #1
 def setUp(self):
     data_source = common_pb.DataSource()
     data_source.data_source_meta.name = "milestone-f"
     data_source.data_source_meta.partition_num = 1
     data_source.output_base_dir = "./ds_output"
     self.raw_data_dir = "./raw_data"
     self.data_source = data_source
     self.raw_data_options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type='')
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1, example_id_dump_threshold=1024)
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         example_joiner='ATTRIBUTION_JOINER',
         min_matching_window=32,
         max_matching_window=51200,
         max_conversion_delay=interval_to_timestamp("124"),
         enable_negative_example_generator=True,
         data_block_dump_interval=32,
         data_block_dump_threshold=128,
         negative_sampling_rate=0.8,
     )
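     # start each test from clean output and raw-data directories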
     if gfile.Exists(self.data_source.output_base_dir):
         gfile.DeleteRecursively(self.data_source.output_base_dir)
     if gfile.Exists(self.raw_data_dir):
         gfile.DeleteRecursively(self.raw_data_dir)
     self.kvstore = db_client.DBClient('etcd', True)
     self.kvstore.delete_prefix(
         common.data_source_kvstore_base_dir(
             self.data_source.data_source_meta.name))
     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
     self.g_data_block_index = 0
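The fixtures built here are consumed by joiner_impl.create_example_joiner in the test bodies. A minimal sketch of that follow-up step, assuming the same joiner_impl helper and run_join driver that Example #4 uses:

 # build the joiner for partition 0 from the options assembled in setUp
 sei = joiner_impl.create_example_joiner(
     self.example_joiner_options,
     self.raw_data_options,
     dj_pb.WriterOptions(output_writer='TF_RECORD'),
     self.kvstore,
     self.data_source,
     0)
 self.run_join(sei, 0)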
Example #2
 def setUp(self):
     data_source = common_pb.DataSource()
     data_source.data_source_meta.name = "milestone-f"
     data_source.data_source_meta.partition_num = 1
     data_source.output_base_dir = "./ds_output"
     self.raw_data_dir = "./raw_data"
     self.data_source = data_source
     self.raw_data_options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type='',
                                                  optional_fields=['label'])
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1, example_id_dump_threshold=1024)
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         example_joiner='STREAM_JOINER',
         min_matching_window=32,
         max_matching_window=128,
         data_block_dump_interval=30,
         data_block_dump_threshold=128)
     if gfile.Exists(self.data_source.output_base_dir):
         gfile.DeleteRecursively(self.data_source.output_base_dir)
     if gfile.Exists(self.raw_data_dir):
         gfile.DeleteRecursively(self.raw_data_dir)
     self.kvstore = mysql_client.DBClient('test_cluster', 'localhost:2379',
                                          'test_user', 'test_password',
                                          'fedlearner', True)
     self.kvstore.delete_prefix(
         common.data_source_kvstore_base_dir(
             self.data_source.data_source_meta.name))
     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
     self.g_data_block_index = 0
Example #3
 def setUp(self):
     data_source = common_pb.DataSource()
     data_source.data_source_meta.name = "milestone-f"
     data_source.data_source_meta.partition_num = 1
     data_source.data_block_dir = "./data_block"
     data_source.example_dumped_dir = "./example_id"
     data_source.raw_data_dir = "./raw_data"
     self.data_source = data_source
     self.raw_data_options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type='')
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1, example_id_dump_threshold=1024)
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         example_joiner='STREAM_JOINER',
         min_matching_window=32,
         max_matching_window=128,
         data_block_dump_interval=30,
         data_block_dump_threshold=128)
     if gfile.Exists(self.data_source.data_block_dir):
         gfile.DeleteRecursively(self.data_source.data_block_dir)
     if gfile.Exists(self.data_source.example_dumped_dir):
         gfile.DeleteRecursively(self.data_source.example_dumped_dir)
     if gfile.Exists(self.data_source.raw_data_dir):
         gfile.DeleteRecursively(self.data_source.raw_data_dir)
     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     self.etcd.delete_prefix(self.data_source.data_source_meta.name)
     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     self.g_data_block_index = 0
Example #4
    def test_universal_join_key_mapper_error(self):
        mapper_code = """
from fedlearner.data_join.key_mapper.key_mapping import BaseKeyMapper
class KeyMapperMock(BaseKeyMapper):
    def leader_mapping(self, item) -> dict:
        res = item.click_id.decode().split("_")
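        # fail on purpose: the test asserts that mapper errors surface through the joiner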
        raise ValueError
        return dict({"req_id":res[0], "cid":res[1]})

    def follower_mapping(self, item) -> dict:
        return dict()

    @classmethod
    def name(cls):
        return "TEST_MAPPER"
"""
        abspath = os.path.dirname(os.path.abspath(__file__))
        fname = os.path.realpath(
            os.path.join(
                abspath,
                "../../fedlearner/data_join/key_mapper/impl/keymapper_mock.py")
        )
        with open(fname, "w") as f:
            f.write(mapper_code)
        reload(key_mapper)

        self.example_joiner_options = dj_pb.ExampleJoinerOptions(
            example_joiner='UNIVERSAL_JOINER',
            min_matching_window=32,
            max_matching_window=51200,
            max_conversion_delay=interval_to_timestamp("258"),
            enable_negative_example_generator=True,
            data_block_dump_interval=32,
            data_block_dump_threshold=1024,
            negative_sampling_rate=0.8,
            join_expr="(cid,req_id) or (example_id)",
            join_key_mapper="TEST_MAPPER",
            negative_sampling_filter_expr='',
        )
        self.version = dsp.Version.V2

        sei = joiner_impl.create_example_joiner(
            self.example_joiner_options,
            self.raw_data_options,
            dj_pb.WriterOptions(output_writer='CSV_DICT'),
            self.kvstore,
            self.data_source,
            0)
        self.run_join(sei, 0)
        os.remove(fname)
Example #5
 def init(self,
          dsname,
          joiner_name,
          version=Version.V1,
          cache_type="memory"):
     data_source = common_pb.DataSource()
     data_source.data_source_meta.name = dsname
     data_source.data_source_meta.partition_num = 1
     data_source.output_base_dir = "%s_ds_output" % dsname
     self.raw_data_dir = "%s_raw_data" % dsname
     self.data_source = data_source
     self.raw_data_options = dj_pb.RawDataOptions(
         raw_data_iter='TF_RECORD',
         compressed_type='',
         raw_data_cache_type=cache_type,
     )
     self.example_id_dump_options = dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=1, example_id_dump_threshold=1024)
     self.example_joiner_options = dj_pb.ExampleJoinerOptions(
         example_joiner=joiner_name,
         min_matching_window=32,
         max_matching_window=51200,
         max_conversion_delay=interval_to_timestamp("124"),
         enable_negative_example_generator=True,
         data_block_dump_interval=32,
         data_block_dump_threshold=128,
         negative_sampling_rate=0.8,
         join_expr="example_id",
         join_key_mapper="DEFAULT",
         negative_sampling_filter_expr='',
     )
     if gfile.Exists(self.data_source.output_base_dir):
         gfile.DeleteRecursively(self.data_source.output_base_dir)
     if gfile.Exists(self.raw_data_dir):
         gfile.DeleteRecursively(self.raw_data_dir)
     self.kvstore = db_client.DBClient('etcd', True)
     self.kvstore.delete_prefix(
         common.data_source_kvstore_base_dir(
             self.data_source.data_source_meta.name))
     self.total_raw_data_count = 0
     self.total_example_id_count = 0
     self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
     self.g_data_block_index = 0
     self.version = version
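A typical call from a test method, reusing only values that appear elsewhere on this page (the dsp alias follows Examples #4 and #7):

 self.init("milestone-f", "UNIVERSAL_JOINER",
           version=dsp.Version.V2, cache_type="memory")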
Example #6
 def _launch_workers(self):
     worker_options = dj_pb.DataJoinWorkerOptions(
         use_mock_etcd=True,
         raw_data_options=dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                               compressed_type=''),
         example_id_dump_options=dj_pb.ExampleIdDumpOptions(
             example_id_dump_interval=1, example_id_dump_threshold=1024),
         example_joiner_options=dj_pb.ExampleJoinerOptions(
             example_joiner='SORT_RUN_JOINER',
             min_matching_window=64,
             max_matching_window=256,
             data_block_dump_interval=30,
             data_block_dump_threshold=1000),
         batch_processor_options=dj_pb.BatchProcessorOptions(
             batch_size=1024, max_flying_item=4096),
         data_block_builder_options=dj_pb.DataBlockBuilderOptions(
             data_block_builder='CSV_DICT_DATABLOCK_BUILDER'))
     self._worker_addrs_l = [
         'localhost:4161', 'localhost:4162', 'localhost:4163',
         'localhost:4164'
     ]
     self._worker_addrs_f = [
         'localhost:5161', 'localhost:5162', 'localhost:5163',
         'localhost:5164'
     ]
     self._workers_l = []
     self._workers_f = []
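     # pair each leader worker with the same-rank follower worker as its peer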
     for rank_id in range(4):
         worker_addr_l = self._worker_addrs_l[rank_id]
         worker_addr_f = self._worker_addrs_f[rank_id]
         self._workers_l.append(
             data_join_worker.DataJoinWorkerService(
                 int(worker_addr_l.split(':')[1]), worker_addr_f,
                 self._master_addr_l, rank_id, self._etcd_name,
                 self._etcd_base_dir_l, self._etcd_addrs, worker_options))
         self._workers_f.append(
             data_join_worker.DataJoinWorkerService(
                 int(worker_addr_f.split(':')[1]), worker_addr_l,
                 self._master_addr_f, rank_id, self._etcd_name,
                 self._etcd_base_dir_f, self._etcd_addrs, worker_options))
     for w in self._workers_l:
         w.start()
     for w in self._workers_f:
         w.start()
Example #7
    def test_universal_join_small_follower(self):
        self.example_joiner_options = dj_pb.ExampleJoinerOptions(
            example_joiner='UNIVERSAL_JOINER',
            min_matching_window=32,
            max_matching_window=20240,
            max_conversion_delay=interval_to_timestamp("128"),
            enable_negative_example_generator=False,
            data_block_dump_interval=32,
            data_block_dump_threshold=1024,
            negative_sampling_rate=0.8,
            join_expr="(id_type, example_id, trunc(event_time,1))",
            join_key_mapper="DEFAULT",
            negative_sampling_filter_expr='',
        )
        self.version = dsp.Version.V2

        sei = joiner_impl.create_example_joiner(
            self.example_joiner_options, self.raw_data_options,
            dj_pb.WriterOptions(output_writer='TF_RECORD'), self.kvstore,
            self.data_source, 0)
        self.run_join_small_follower(sei, 0.15)
Example #8
 def _launch_workers(self):
     worker_options = dj_pb.DataJoinWorkerOptions(
         use_mock_etcd=True,
         raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                               compressed_type=''),
         example_id_dump_options=dj_pb.ExampleIdDumpOptions(
             example_id_dump_interval=1, example_id_dump_threshold=1024),
         example_joiner_options=dj_pb.ExampleJoinerOptions(
             example_joiner='STREAM_JOINER',
             min_matching_window=64,
             max_matching_window=256,
             data_block_dump_interval=30,
             data_block_dump_threshold=1000),
         example_id_batch_options=dj_pb.ExampleIdBatchOptions(
             example_id_batch_size=1024, max_flying_example_id=4096))
     self._worker_addrs_l = ['localhost:4161', 'localhost:4162']
     self._worker_addrs_f = ['localhost:5161', 'localhost:5162']
     self._workers_l = []
     self._workers_f = []
     for rank_id in range(2):
         worker_addr_l = self._worker_addrs_l[rank_id]
         worker_addr_f = self._worker_addrs_f[rank_id]
         self._workers_l.append(
             data_join_worker.DataJoinWorkerService(
                 int(worker_addr_l.split(':')[1]), worker_addr_f,
                 self._master_addr_l, rank_id, self._etcd_name,
                 self._etcd_base_dir_l, self._etcd_addrs, worker_options))
         self._workers_f.append(
             data_join_worker.DataJoinWorkerService(
                 int(worker_addr_f.split(':')[1]), worker_addr_l,
                 self._master_addr_f, rank_id, self._etcd_name,
                 self._etcd_base_dir_f, self._etcd_addrs, worker_options))
     for w in self._workers_l:
         w.start()
     for w in self._workers_f:
         w.start()
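Leader and follower sides are symmetric: the worker at rank i on one side peers with rank i on the other. Teardown is not shown; a hedged sketch, assuming DataJoinWorkerService exposes a stop() counterpart to start():

 for w in self._workers_l + self._workers_f:
     w.stop()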
Example #9
    def setUp(self):
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.data_block_dir = "./data_block_l"
        data_source_l.raw_data_dir = "./raw_data_l"
        data_source_l.example_dumped_dir = "./example_dumped_l"
        data_source_f = common_pb.DataSource()
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.data_block_dir = "./data_block_f"
        data_source_f.raw_data_dir = "./raw_data_f"
        data_source_f.example_dumped_dir = "./example_dumped_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]),
            master_addr_f,
            data_source_name,
            etcd_name,
            etcd_base_dir_l,
            etcd_addrs,
            master_options,
        )
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs, master_options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

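        # poll both masters until the data source reaches the Processing state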
        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = master_client_l.GetDataSourceStatus(req_l)
            dss_f = master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if (dss_l.state == common_pb.DataSourceState.Processing and
                    dss_f.state == common_pb.DataSourceState.Processing):
                break
            else:
                time.sleep(2)

        self.master_client_l = master_client_l
        self.master_client_f = master_client_f
        self.master_addr_l = master_addr_l
        self.master_addr_f = master_addr_f
        self.etcd_l = etcd_l
        self.etcd_f = etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.master_l = master_l
        self.master_f = master_f
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
            self.etcd_l, self.raw_data_pub_dir_l)
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
            self.etcd_f, self.raw_data_pub_dir_f)
        if gfile.Exists(data_source_l.data_block_dir):
            gfile.DeleteRecursively(data_source_l.data_block_dir)
        if gfile.Exists(data_source_l.example_dumped_dir):
            gfile.DeleteRecursively(data_source_l.example_dumped_dir)
        if gfile.Exists(data_source_l.raw_data_dir):
            gfile.DeleteRecursively(data_source_l.raw_data_dir)
        if gfile.Exists(data_source_f.data_block_dir):
            gfile.DeleteRecursively(data_source_f.data_block_dir)
        if gfile.Exists(data_source_f.example_dumped_dir):
            gfile.DeleteRecursively(data_source_f.example_dumped_dir)
        if gfile.Exists(data_source_f.raw_data_dir):
            gfile.DeleteRecursively(data_source_f.raw_data_dir)

        self.worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type=''),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='STREAM_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=512, max_flying_item=2048),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='TF_RECORD'))

        self.total_index = 1 << 13
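These worker_options feed the same DataJoinWorkerService constructor shown in Examples #6 and #8. A sketch for rank 0 of the leader side, reusing the addresses from those examples:

 worker_l0 = data_join_worker.DataJoinWorkerService(
     4161, 'localhost:5161', self.master_addr_l, 0,
     self.etcd_name, self.etcd_base_dir_l, self.etcd_addrs,
     self.worker_options)
 worker_l0.start()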
Example #10
    def setUp(self):
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        etcd_l.delete_prefix(common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.output_base_dir = "./ds_output_l"
        self.raw_data_dir_l = "./raw_data_l"
        data_source_f = common_pb.DataSource()
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.output_base_dir = "./ds_output_f"
        self.raw_data_dir_f = "./raw_data_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)

        self.etcd_l = etcd_l
        self.etcd_f = etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
                self.etcd_l, self.raw_data_pub_dir_l
            )
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
                self.etcd_f, self.raw_data_pub_dir_f
            )
        if gfile.Exists(data_source_l.output_base_dir):
            gfile.DeleteRecursively(data_source_l.output_base_dir)
        if gfile.Exists(self.raw_data_dir_l):
            gfile.DeleteRecursively(self.raw_data_dir_l)
        if gfile.Exists(data_source_f.output_base_dir):
            gfile.DeleteRecursively(data_source_f.output_base_dir)
        if gfile.Exists(self.raw_data_dir_f):
            gfile.DeleteRecursively(self.raw_data_dir_f)

        self.worker_options = dj_pb.DataJoinWorkerOptions(
                use_mock_etcd=True,
                raw_data_options=dj_pb.RawDataOptions(
                    raw_data_iter='TF_RECORD',
                    read_ahead_size=1<<20,
                    read_batch_size=128
                ),
                example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                    example_id_dump_interval=1,
                    example_id_dump_threshold=1024
                ),
                example_joiner_options=dj_pb.ExampleJoinerOptions(
                    example_joiner='STREAM_JOINER',
                    min_matching_window=64,
                    max_matching_window=256,
                    data_block_dump_interval=30,
                    data_block_dump_threshold=1000
                ),
                batch_processor_options=dj_pb.BatchProcessorOptions(
                    batch_size=512,
                    max_flying_item=2048
                ),
                data_block_builder_options=dj_pb.WriterOptions(
                    output_writer='TF_RECORD'
                )
            )

        self.total_index = 1 << 12
Example #11
     choices=['TF_RECORD_DATABLOCK_BUILDER', 'CSV_DICT_DATABLOCK_BUILDER'],
     help='the builder of generated data block')
 args = parser.parse_args()
 if args.tf_eager_mode:
     import tensorflow
     tensorflow.compat.v1.enable_eager_execution()
 worker_options = dj_pb.DataJoinWorkerOptions(
     use_mock_etcd=args.use_mock_etcd,
     raw_data_options=dj_pb.RawDataOptions(
         raw_data_iter=args.raw_data_iter,
         compressed_type=args.compressed_type,
         read_ahead_size=args.read_ahead_size),
     example_joiner_options=dj_pb.ExampleJoinerOptions(
         example_joiner=args.example_joiner,
         min_matching_window=args.min_matching_window,
         max_matching_window=args.max_matching_window,
         data_block_dump_interval=args.data_block_dump_interval,
         data_block_dump_threshold=args.data_block_dump_threshold,
     ),
     example_id_dump_options=dj_pb.ExampleIdDumpOptions(
         example_id_dump_interval=args.example_id_dump_interval,
         example_id_dump_threshold=args.example_id_dump_threshold),
     batch_processor_options=dj_pb.BatchProcessorOptions(
         batch_size=args.example_id_batch_size,
         max_flying_item=args.max_flying_example_id),
     data_block_builder_options=dj_pb.DataBlockBuilderOptions(
         data_block_builder=args.data_block_builder))
 worker_srv = DataJoinWorkerService(args.listen_port, args.peer_addr,
                                    args.master_addr, args.rank_id,
                                    args.etcd_name, args.etcd_base_dir,
                                    args.etcd_addrs, worker_options)
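The snippet stops after constructing the service; starting it presumably follows the same pattern the tests on this page use:

 worker_srv.start()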
Example #12
                     "filled with label: 0")
 args = parser.parse_args()
 worker_options = dj_pb.DataJoinWorkerOptions(
         use_mock_etcd=(args.kvstore_type == 'mock'),
         raw_data_options=dj_pb.RawDataOptions(
                 raw_data_iter=args.raw_data_iter,
                 compressed_type=args.compressed_type,
                 read_ahead_size=args.read_ahead_size,
                 read_batch_size=args.read_batch_size
             ),
         example_joiner_options=dj_pb.ExampleJoinerOptions(
                 example_joiner=args.example_joiner,
                 min_matching_window=args.min_matching_window,
                 max_matching_window=args.max_matching_window,
                 data_block_dump_interval=args.data_block_dump_interval,
                 data_block_dump_threshold=args.data_block_dump_threshold,
                  max_conversion_delay=interval_to_timestamp(
                      args.max_conversion_delay),
                  enable_negative_example_generator=(
                      args.enable_negative_example_generator),
             ),
         example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                 example_id_dump_interval=args.example_id_dump_interval,
                 example_id_dump_threshold=args.example_id_dump_threshold
             ),
         batch_processor_options=dj_pb.BatchProcessorOptions(
                 batch_size=4096,
                 max_flying_item=-1
             ),
         data_block_builder_options=dj_pb.WriterOptions(
                 output_writer=args.data_block_builder,
Example #13
    def _launch_workers(self):
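        # cross-format round trip: the leader reads TF_RECORD and writes
        # CSV_DICT, while the follower does the reverse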
        worker_options_l = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  read_ahead_size=1 << 20,
                                                  read_batch_size=128),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='SORT_RUN_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=1024, max_flying_item=4096),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='CSV_DICT'))
        worker_options_f = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                                  read_ahead_size=1 << 20,
                                                  read_batch_size=128),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='SORT_RUN_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=1024, max_flying_item=4096),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='TF_RECORD'))

        self._worker_addrs_l = [
            'localhost:4161', 'localhost:4162', 'localhost:4163',
            'localhost:4164'
        ]
        self._worker_addrs_f = [
            'localhost:5161', 'localhost:5162', 'localhost:5163',
            'localhost:5164'
        ]
        self._workers_l = []
        self._workers_f = []
        for rank_id in range(4):
            worker_addr_l = self._worker_addrs_l[rank_id]
            worker_addr_f = self._worker_addrs_f[rank_id]
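            # ETCD_BASE_DIR appears to be read at construction time,
            # so set it per side before building each service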
            os.environ['ETCD_BASE_DIR'] = self.leader_base_dir
            self._workers_l.append(
                data_join_worker.DataJoinWorkerService(
                    int(worker_addr_l.split(':')[1]), worker_addr_f,
                    self._master_addr_l, rank_id, self.kvstore_type,
                    worker_options_l))
            os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
            self._workers_f.append(
                data_join_worker.DataJoinWorkerService(
                    int(worker_addr_f.split(':')[1]), worker_addr_l,
                    self._master_addr_f, rank_id, self.kvstore_type,
                    worker_options_f))
        for w in self._workers_l:
            w.start()
        for w in self._workers_f:
            w.start()