def test_all_assembly(self):
    # Publish raw data for every partition on both the leader and the
    # follower side.
    for i in range(self.data_source_l.data_source_meta.partition_num):
        self.generate_raw_data(
                self.etcd_l, self.raw_data_controller_l, self.data_source_l,
                i, 2048, 64,
                'leader_key_partition_{}'.format(i) + ':{}',
                'leader_value_partition_{}'.format(i) + ':{}'
            )
        self.generate_raw_data(
                self.etcd_f, self.raw_data_controller_f, self.data_source_f,
                i, 4096, 128,
                'follower_key_partition_{}'.format(i) + ':{}',
                'follower_value_partition_{}'.format(i) + ':{}'
            )
    # Launch one data join worker per side; each worker talks to the peer
    # worker and to its own master.
    worker_addr_l = 'localhost:4161'
    worker_addr_f = 'localhost:4162'
    worker_l = data_join_worker.DataJoinWorkerService(
            int(worker_addr_l.split(':')[1]), worker_addr_f,
            self.master_addr_l, 0, self.etcd_name, self.etcd_base_dir_l,
            self.etcd_addrs, self.worker_options
        )
    worker_f = data_join_worker.DataJoinWorkerService(
            int(worker_addr_f.split(':')[1]), worker_addr_l,
            self.master_addr_f, 0, self.etcd_name, self.etcd_base_dir_f,
            self.etcd_addrs, self.worker_options
        )
    worker_l.start()
    worker_f.start()
    # Tell both masters that no more raw data will arrive for any partition.
    for i in range(self.data_source_f.data_source_meta.partition_num):
        rdmreq = dj_pb.RawDataRequest(
                data_source_meta=self.data_source_l.data_source_meta,
                partition_id=i,
                finish_raw_data=empty_pb2.Empty()
            )
        rsp = self.master_client_l.FinishRawData(rdmreq)
        self.assertEqual(rsp.code, 0)
        rsp = self.master_client_f.FinishRawData(rdmreq)
        self.assertEqual(rsp.code, 0)
    # Poll both masters until the data sources on both sides reach Finished.
    while True:
        req_l = dj_pb.DataSourceRequest(
                data_source_meta=self.data_source_l.data_source_meta
            )
        req_f = dj_pb.DataSourceRequest(
                data_source_meta=self.data_source_f.data_source_meta
            )
        dss_l = self.master_client_l.GetDataSourceStatus(req_l)
        dss_f = self.master_client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Finished and \
                dss_f.state == common_pb.DataSourceState.Finished:
            break
        else:
            time.sleep(2)
    worker_l.stop()
    worker_f.stop()
    self.master_l.stop()
    self.master_f.stop()
def test_all_assembly(self):
    for i in range(self.data_source_l.data_source_meta.partition_num):
        self.generate_raw_data(
                self.data_source_l, i, 2048, 64,
                'leader_key_partition_{}'.format(i) + ':{}',
                'leader_value_partition_{}'.format(i) + ':{}'
            )
        self.generate_raw_data(
                self.data_source_f, i, 4096, 128,
                'follower_key_partition_{}'.format(i) + ':{}',
                'follower_value_partition_{}'.format(i) + ':{}'
            )
    worker_addr_l = 'localhost:4161'
    worker_addr_f = 'localhost:4162'
    options = customized_options.CustomizedOptions()
    options.set_raw_data_iter('TF_RECORD')
    options.set_example_joiner('STREAM_JOINER')
    worker_l = data_join_worker.DataJoinWorkerService(
            int(worker_addr_l.split(':')[1]), worker_addr_f,
            self.master_addr_l, 0, self.etcd_name, self.etcd_base_dir_l,
            self.etcd_addrs, options
        )
    worker_f = data_join_worker.DataJoinWorkerService(
            int(worker_addr_f.split(':')[1]), worker_addr_l,
            self.master_addr_f, 0, self.etcd_name, self.etcd_base_dir_f,
            self.etcd_addrs, options
        )
    worker_l.start()
    worker_f.start()
    while True:
        rsp_l = self.master_client_l.GetDataSourceState(
                self.data_source_l.data_source_meta
            )
        rsp_f = self.master_client_f.GetDataSourceState(
                self.data_source_f.data_source_meta
            )
        self.assertEqual(rsp_l.status.code, 0)
        self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
        self.assertEqual(rsp_l.data_source_type,
                         common_pb.DataSourceType.Sequential)
        self.assertEqual(rsp_f.status.code, 0)
        self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
        self.assertEqual(rsp_f.data_source_type,
                         common_pb.DataSourceType.Sequential)
        if (rsp_l.state == common_pb.DataSourceState.Finished and
                rsp_f.state == common_pb.DataSourceState.Finished):
            break
        else:
            time.sleep(2)
    worker_l.stop()
    worker_f.stop()
    self.master_l.stop()
    self.master_f.stop()
def _launch_workers(self):
    worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(
                    raw_data_iter='CSV_DICT',
                    compressed_type=''
                ),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                    example_id_dump_interval=1,
                    example_id_dump_threshold=1024
                ),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                    example_joiner='SORT_RUN_JOINER',
                    min_matching_window=64,
                    max_matching_window=256,
                    data_block_dump_interval=30,
                    data_block_dump_threshold=1000
                ),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                    batch_size=1024,
                    max_flying_item=4096
                ),
            data_block_builder_options=dj_pb.DataBlockBuilderOptions(
                    data_block_builder='CSV_DICT_DATABLOCK_BUILDER'
                )
        )
    self._worker_addrs_l = [
            'localhost:4161', 'localhost:4162',
            'localhost:4163', 'localhost:4164'
        ]
    self._worker_addrs_f = [
            'localhost:5161', 'localhost:5162',
            'localhost:5163', 'localhost:5164'
        ]
    self._workers_l = []
    self._workers_f = []
    for rank_id in range(4):
        worker_addr_l = self._worker_addrs_l[rank_id]
        worker_addr_f = self._worker_addrs_f[rank_id]
        self._workers_l.append(data_join_worker.DataJoinWorkerService(
                int(worker_addr_l.split(':')[1]), worker_addr_f,
                self._master_addr_l, rank_id, self._etcd_name,
                self._etcd_base_dir_l, self._etcd_addrs, worker_options
            ))
        self._workers_f.append(data_join_worker.DataJoinWorkerService(
                int(worker_addr_f.split(':')[1]), worker_addr_l,
                self._master_addr_f, rank_id, self._etcd_name,
                self._etcd_base_dir_f, self._etcd_addrs, worker_options
            ))
    for w in self._workers_l:
        w.start()
    for w in self._workers_f:
        w.start()
def _launch_workers(self):
    worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(
                    raw_data_iter='TF_RECORD',
                    compressed_type=''
                ),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                    example_id_dump_interval=1,
                    example_id_dump_threshold=1024
                ),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                    example_joiner='STREAM_JOINER',
                    min_matching_window=64,
                    max_matching_window=256,
                    data_block_dump_interval=30,
                    data_block_dump_threshold=1000
                ),
            example_id_batch_options=dj_pb.ExampleIdBatchOptions(
                    example_id_batch_size=1024,
                    max_flying_example_id=4096
                )
        )
    self._worker_addrs_l = ['localhost:4161', 'localhost:4162']
    self._worker_addrs_f = ['localhost:5161', 'localhost:5162']
    self._workers_l = []
    self._workers_f = []
    for rank_id in range(2):
        worker_addr_l = self._worker_addrs_l[rank_id]
        worker_addr_f = self._worker_addrs_f[rank_id]
        self._workers_l.append(data_join_worker.DataJoinWorkerService(
                int(worker_addr_l.split(':')[1]), worker_addr_f,
                self._master_addr_l, rank_id, self._etcd_name,
                self._etcd_base_dir_l, self._etcd_addrs, worker_options
            ))
        self._workers_f.append(data_join_worker.DataJoinWorkerService(
                int(worker_addr_f.split(':')[1]), worker_addr_l,
                self._master_addr_f, rank_id, self._etcd_name,
                self._etcd_base_dir_f, self._etcd_addrs, worker_options
            ))
    for w in self._workers_l:
        w.start()
    for w in self._workers_f:
        w.start()
def _inner_test_round(self, start_index):
    # Publish a new batch of raw data for every partition on both sides.
    for i in range(self.data_source_l.data_source_meta.partition_num):
        self.generate_raw_data(
                start_index, self.etcd_l, self.raw_data_publisher_l,
                self.data_source_l, self.raw_data_dir_l, i, 2048, 64,
                'leader_key_partition_{}'.format(i) + ':{}',
                'leader_value_partition_{}'.format(i) + ':{}'
            )
        self.generate_raw_data(
                start_index, self.etcd_f, self.raw_data_publisher_f,
                self.data_source_f, self.raw_data_dir_f, i, 4096, 128,
                'follower_key_partition_{}'.format(i) + ':{}',
                'follower_value_partition_{}'.format(i) + ':{}'
            )
    # Launch one master per side in batch mode.
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True,
                                                 batch_mode=True)
    master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f,
            self.data_source_name, self.etcd_name, self.etcd_base_dir_l,
            self.etcd_addrs, master_options
        )
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l,
            self.data_source_name, self.etcd_name, self.etcd_base_dir_f,
            self.etcd_addrs, master_options
        )
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    # Poll both masters until the data sources enter the Processing state.
    while True:
        try:
            req_l = dj_pb.DataSourceRequest(
                    data_source_meta=self.data_source_l.data_source_meta
                )
            req_f = dj_pb.DataSourceRequest(
                    data_source_meta=self.data_source_f.data_source_meta
                )
            dss_l = master_client_l.GetDataSourceStatus(req_l)
            dss_f = master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
        except Exception:  # ignore transient RPC errors while masters start
            pass
        time.sleep(2)
    # In batch mode the workers run to completion, so drive them with run()
    # on separate threads instead of start().
    worker_addr_l = 'localhost:4161'
    worker_addr_f = 'localhost:4162'
    worker_l = data_join_worker.DataJoinWorkerService(
            int(worker_addr_l.split(':')[1]), worker_addr_f, master_addr_l,
            0, self.etcd_name, self.etcd_base_dir_l, self.etcd_addrs,
            self.worker_options
        )
    worker_f = data_join_worker.DataJoinWorkerService(
            int(worker_addr_f.split(':')[1]), worker_addr_l, master_addr_f,
            0, self.etcd_name, self.etcd_base_dir_f, self.etcd_addrs,
            self.worker_options
        )
    th_l = threading.Thread(target=worker_l.run, name='worker_l')
    th_f = threading.Thread(target=worker_f.run, name='worker_f')
    th_l.start()
    th_f.start()
    # Poll both masters until the data sources are back in the Ready state.
    while True:
        try:
            req_l = dj_pb.DataSourceRequest(
                    data_source_meta=self.data_source_l.data_source_meta
                )
            req_f = dj_pb.DataSourceRequest(
                    data_source_meta=self.data_source_f.data_source_meta
                )
            dss_l = master_client_l.GetDataSourceStatus(req_l)
            dss_f = master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Ready and \
                    dss_f.state == common_pb.DataSourceState.Ready:
                break
        except Exception:  # ignore transient RPC errors
            pass
        time.sleep(2)
    th_l.join()
    th_f.join()
    master_l.stop()
    master_f.stop()
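# The two polling loops in _inner_test_round above (and the similar loop in
# test_all_assembly) repeat the same GetDataSourceStatus round trip. The
# helper below is a minimal sketch, not part of the original tests, of how
# that loop could be factored out. It assumes only the dj_pb / common_pb
# messages and master client stubs already used above; the helper name
# _wait_for_data_source_state is hypothetical. Example usage:
#   self._wait_for_data_source_state(master_client_l, master_client_f,
#                                    common_pb.DataSourceState.Ready)
def _wait_for_data_source_state(self, master_client_l, master_client_f,
                                target_state, interval=2):
    while True:
        try:
            req_l = dj_pb.DataSourceRequest(
                    data_source_meta=self.data_source_l.data_source_meta
                )
            req_f = dj_pb.DataSourceRequest(
                    data_source_meta=self.data_source_f.data_source_meta
                )
            dss_l = master_client_l.GetDataSourceStatus(req_l)
            dss_f = master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == target_state and dss_f.state == target_state:
                return
        except Exception:  # ignore transient RPC errors while services start
            pass
        time.sleep(interval)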
def _launch_workers(self):
    # The leader reads TF_RECORD raw data and writes CSV_DICT data blocks,
    # while the follower reads CSV_DICT raw data and writes TF_RECORD data
    # blocks, so the round trip exercises both formats.
    worker_options_l = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(
                    raw_data_iter='TF_RECORD',
                    read_ahead_size=1 << 20,
                    read_batch_size=128
                ),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                    example_id_dump_interval=1,
                    example_id_dump_threshold=1024
                ),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                    example_joiner='SORT_RUN_JOINER',
                    min_matching_window=64,
                    max_matching_window=256,
                    data_block_dump_interval=30,
                    data_block_dump_threshold=1000
                ),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                    batch_size=1024,
                    max_flying_item=4096
                ),
            data_block_builder_options=dj_pb.WriterOptions(
                    output_writer='CSV_DICT'
                )
        )
    worker_options_f = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(
                    raw_data_iter='CSV_DICT',
                    read_ahead_size=1 << 20,
                    read_batch_size=128
                ),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                    example_id_dump_interval=1,
                    example_id_dump_threshold=1024
                ),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                    example_joiner='SORT_RUN_JOINER',
                    min_matching_window=64,
                    max_matching_window=256,
                    data_block_dump_interval=30,
                    data_block_dump_threshold=1000
                ),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                    batch_size=1024,
                    max_flying_item=4096
                ),
            data_block_builder_options=dj_pb.WriterOptions(
                    output_writer='TF_RECORD'
                )
        )
    self._worker_addrs_l = [
            'localhost:4161', 'localhost:4162',
            'localhost:4163', 'localhost:4164'
        ]
    self._worker_addrs_f = [
            'localhost:5161', 'localhost:5162',
            'localhost:5163', 'localhost:5164'
        ]
    self._workers_l = []
    self._workers_f = []
    for rank_id in range(4):
        worker_addr_l = self._worker_addrs_l[rank_id]
        worker_addr_f = self._worker_addrs_f[rank_id]
        # Each side picks up its kvstore base dir from the environment, so
        # set it right before constructing the corresponding worker.
        os.environ['ETCD_BASE_DIR'] = self.leader_base_dir
        self._workers_l.append(data_join_worker.DataJoinWorkerService(
                int(worker_addr_l.split(':')[1]), worker_addr_f,
                self._master_addr_l, rank_id, self.kvstore_type,
                worker_options_l
            ))
        os.environ['ETCD_BASE_DIR'] = self.follower_base_dir
        self._workers_f.append(data_join_worker.DataJoinWorkerService(
                int(worker_addr_f.split(':')[1]), worker_addr_l,
                self._master_addr_f, rank_id, self.kvstore_type,
                worker_options_f
            ))
    for w in self._workers_l:
        w.start()
    for w in self._workers_f:
        w.start()
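# A matching teardown helper is sketched below. It is not part of the
# original file and assumes only the start()/stop() lifecycle that
# DataJoinWorkerService already exposes in the tests above; the name
# _stop_workers is hypothetical.
def _stop_workers(self):
    for w in self._workers_f:
        w.stop()
    for w in self._workers_l:
        w.stop()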