def _launch_masters(self):
    """Start the leader and follower masters and block until both serve.

    Creates one ``DataJoinMasterService`` per role (leader on port 4061,
    follower on 4062, each pointing at the other as its peer), starts them,
    builds a gRPC stub for each and then polls ``GetDataSourceStatus`` every
    two seconds until both data sources are in the ``Processing`` state.
    The reported roles are asserted on every poll.
    """
    self._master_addr_l = 'localhost:4061'
    self._master_addr_f = 'localhost:4062'
    options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    leader_port = int(self._master_addr_l.split(':')[1])
    follower_port = int(self._master_addr_f.split(':')[1])
    self._master_l = data_join_master.DataJoinMasterService(
        leader_port, self._master_addr_f, self._data_source_name,
        self._db_database, self._db_base_dir_l, self._db_addr,
        self._db_username_l, self._db_password_l, options)
    self._master_f = data_join_master.DataJoinMasterService(
        follower_port, self._master_addr_l, self._data_source_name,
        self._db_database, self._db_base_dir_f, self._db_addr,
        self._db_username_f, self._db_password_f, options)
    # The follower is started first, matching the original launch order.
    self._master_f.start()
    self._master_l.start()
    self._master_client_l = dj_grpc.DataJoinMasterServiceStub(
        make_insecure_channel(self._master_addr_l, ChannelType.INTERNAL))
    self._master_client_f = dj_grpc.DataJoinMasterServiceStub(
        make_insecure_channel(self._master_addr_f, ChannelType.INTERNAL))

    processing = common_pb.DataSourceState.Processing
    both_processing = False
    while not both_processing:
        status_l = self._master_client_l.GetDataSourceStatus(
            dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_l.data_source_meta))
        status_f = self._master_client_f.GetDataSourceStatus(
            dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_f.data_source_meta))
        self.assertEqual(status_l.role, common_pb.FLRole.Leader)
        self.assertEqual(status_f.role, common_pb.FLRole.Follower)
        both_processing = (status_l.state == processing
                           and status_f.state == processing)
        if not both_processing:
            time.sleep(2)
    logging.info("masters turn into Processing state")
def setUp(self):
    """Prepare leader/follower data sources and masters for a test.

    Steps, in order:
      * clear the data-source keys from two mock etcd namespaces;
      * commit a leader and a follower ``DataSource`` (2 partitions, Init);
      * start one ``DataJoinMasterService`` per role;
      * poll both masters until they report ``Processing``;
      * create the raw-data publishers;
      * delete leftover output directories from earlier runs;
      * build the worker options used by the test rounds.
    """
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    etcd_l.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))
    etcd_f.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))

    data_source_l = common_pb.DataSource()
    self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
    data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"

    data_source_f = common_pb.DataSource()
    self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"

    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 2
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_f, data_source_f)

    master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l,
        etcd_addrs, master_options,
    )
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f,
        etcd_addrs, master_options)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

    # Block until both masters have left Init and entered Processing.
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = master_client_l.GetDataSourceStatus(req_l)
        dss_f = master_client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Processing and \
                dss_f.state == common_pb.DataSourceState.Processing:
            break
        time.sleep(2)

    self.master_client_l = master_client_l
    self.master_client_f = master_client_f
    self.master_addr_l = master_addr_l
    self.master_addr_f = master_addr_f
    self.etcd_l = etcd_l
    self.etcd_f = etcd_f
    self.data_source_l = data_source_l
    self.data_source_f = data_source_f
    self.master_l = master_l
    self.master_f = master_f
    # Plain string: the previous trailing comma turned this into a 1-tuple,
    # which was then passed on as the data source name.
    self.data_source_name = data_source_name
    self.etcd_name = etcd_name
    self.etcd_addrs = etcd_addrs
    self.etcd_base_dir_l = etcd_base_dir_l
    self.etcd_base_dir_f = etcd_base_dir_f
    self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
        self.etcd_l, self.raw_data_pub_dir_l)
    self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
        self.etcd_f, self.raw_data_pub_dir_f)

    # Remove output left behind by previous runs (leader first, then follower).
    for data_source in (data_source_l, data_source_f):
        for stale_dir in (data_source.data_block_dir,
                          data_source.example_dumped_dir,
                          data_source.raw_data_dir):
            if gfile.Exists(stale_dir):
                gfile.DeleteRecursively(stale_dir)

    self.worker_options = dj_pb.DataJoinWorkerOptions(
        use_mock_etcd=True,
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter='TF_RECORD',
            compressed_type=''),
        example_id_dump_options=dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1,
            example_id_dump_threshold=1024),
        example_joiner_options=dj_pb.ExampleJoinerOptions(
            example_joiner='STREAM_JOINER',
            min_matching_window=64,
            max_matching_window=256,
            data_block_dump_interval=30,
            data_block_dump_threshold=1000),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=512,
            max_flying_item=2048),
        data_block_builder_options=dj_pb.WriterOptions(
            output_writer='TF_RECORD'))
    self.total_index = 1 << 13
def _inner_test_round(self, start_index):
    """Run one batch-mode join round and wait until both sides are Ready.

    Generates raw data for every partition on both roles starting at
    ``start_index``. Then it:
      * launches a batch-mode master per role;
      * waits for both to reach ``Processing``;
      * runs one ``DataJoinWorkerService`` per role in a background thread;
      * waits for both data sources to reach ``Ready``;
      * joins the worker threads.
    The masters are always stopped on exit, even when an assertion fails,
    so ports 4061/4062 are not left bound for later tests.

    Args:
        start_index: first example index passed to ``generate_raw_data``.
    """
    partition_num = self.data_source_l.data_source_meta.partition_num
    for partition_id in range(partition_num):
        self.generate_raw_data(
            start_index, self.etcd_l, self.raw_data_publisher_l,
            self.data_source_l, self.raw_data_dir_l, partition_id, 2048, 64,
            'leader_key_partition_{}'.format(partition_id) + ':{}',
            'leader_value_partition_{}'.format(partition_id) + ':{}'
        )
        self.generate_raw_data(
            start_index, self.etcd_f, self.raw_data_publisher_f,
            self.data_source_f, self.raw_data_dir_f, partition_id, 4096, 128,
            'follower_key_partition_{}'.format(partition_id) + ':{}',
            'follower_value_partition_{}'.format(partition_id) + ':{}'
        )

    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True,
                                                 batch_mode=True)
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        self.data_source_name, self.etcd_name, self.etcd_base_dir_l,
        self.etcd_addrs, master_options,
    )
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        self.data_source_name, self.etcd_name, self.etcd_base_dir_f,
        self.etcd_addrs, master_options
    )
    master_f.start()
    try:
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=self.data_source_l.data_source_meta
        )
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=self.data_source_f.data_source_meta
        )

        def wait_for_state(expected_state):
            """Poll both masters every 2s until both report expected_state.

            Only the RPCs are retried on error (a master may not be serving
            yet). Role mismatches are real failures and propagate; the old
            code swallowed them and looped forever.
            """
            while True:
                try:
                    dss_l = master_client_l.GetDataSourceStatus(req_l)
                    dss_f = master_client_f.GetDataSourceStatus(req_f)
                except Exception:  # pylint: disable=broad-except
                    pass  # master not reachable yet: retry after the pause
                else:
                    self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
                    self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
                    if dss_l.state == expected_state and \
                            dss_f.state == expected_state:
                        return
                time.sleep(2)

        wait_for_state(common_pb.DataSourceState.Processing)

        worker_addr_l = 'localhost:4161'
        worker_addr_f = 'localhost:4162'
        worker_l = data_join_worker.DataJoinWorkerService(
            int(worker_addr_l.split(':')[1]),
            worker_addr_f, master_addr_l, 0,
            self.etcd_name, self.etcd_base_dir_l,
            self.etcd_addrs, self.worker_options
        )
        worker_f = data_join_worker.DataJoinWorkerService(
            int(worker_addr_f.split(':')[1]),
            worker_addr_l, master_addr_f, 0,
            self.etcd_name, self.etcd_base_dir_f,
            self.etcd_addrs, self.worker_options
        )
        th_l = threading.Thread(target=worker_l.run, name='worker_l')
        th_f = threading.Thread(target=worker_f.run, name='worker_f')
        th_l.start()
        th_f.start()

        wait_for_state(common_pb.DataSourceState.Ready)
        th_l.join()
        th_f.join()
    finally:
        master_l.stop()
        master_f.stop()
def test_api(self):
    """Exercise the DataJoinMaster RPC API end to end on one partition.

    Launches a leader and a follower master over mock etcd and waits for
    ``Processing``. It then walks through:
      * ``RequestJoinPartition`` for join-example and sync-example-id tasks;
      * ``AddRawData`` with and without dedup;
      * ``QueryRawDataManifest`` and ``GetRawDataLatestTimeStamp``;
      * ``FinishRawData`` and ``FinishJoinPartition``.
    Calls the API should treat idempotently are issued twice. Finally it
    waits for both data sources to reach ``Finished``. The masters are
    always stopped, even when an assertion fails, so their ports are
    released for later tests.
    """
    logging.getLogger().setLevel(logging.DEBUG)
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    etcd_l.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))
    etcd_f.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))
    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.output_base_dir = "./ds_output_l"
    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.output_base_dir = "./ds_output_f"
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 1
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_f, data_source_f)

    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l,
        etcd_addrs, options)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f,
        etcd_addrs, options)
    master_f.start()
    try:
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
        meta_l = data_source_l.data_source_meta
        meta_f = data_source_f.data_source_meta

        def wait_for_state(expected_state):
            """Poll both masters every 2s until both report expected_state."""
            while True:
                dss_l = client_l.GetDataSourceStatus(
                    dj_pb.DataSourceRequest(data_source_meta=meta_l))
                dss_f = client_f.GetDataSourceStatus(
                    dj_pb.DataSourceRequest(data_source_meta=meta_f))
                self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
                self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
                if dss_l.state == expected_state and \
                        dss_f.state == expected_state:
                    return
                time.sleep(2)

        def raw_data_request(meta, rank_id=0, partition_id=0, **payload):
            """Build a RawDataRequest; ``payload`` sets the task field."""
            return dj_pb.RawDataRequest(data_source_meta=meta,
                                        rank_id=rank_id,
                                        partition_id=partition_id,
                                        **payload)

        def add_raw_data_request(meta, dedup, files):
            """Build an AddRawData request from (file_path, seconds) pairs."""
            return raw_data_request(
                meta,
                added_raw_data_metas=dj_pb.AddedRawDataMetas(
                    dedup=dedup,
                    raw_data_metas=[
                        dj_pb.RawDataMeta(
                            file_path=path,
                            timestamp=timestamp_pb2.Timestamp(
                                seconds=seconds))
                        for path, seconds in files
                    ]))

        def check_join_example(client, request):
            """Request a join-example task twice and check both replies."""
            for _ in range(2):  # second call checks idempotence
                rsp = client.RequestJoinPartition(request)
                self.assertEqual(rsp.status.code, 0)
                self.assertTrue(rsp.HasField('manifest'))
                self.assertEqual(rsp.manifest.join_example_rep.rank_id,
                                 request.rank_id)
                self.assertEqual(rsp.manifest.join_example_rep.state,
                                 dj_pb.JoinExampleState.Joining)
                self.assertEqual(rsp.manifest.partition_id, 0)

        def check_sync_example_id(client, request):
            """Request a sync-example-id task twice and check both replies."""
            for _ in range(2):  # second call checks idempotence
                rsp = client.RequestJoinPartition(request)
                self.assertEqual(rsp.status.code, 0)
                self.assertTrue(rsp.HasField('manifest'))
                self.assertEqual(rsp.manifest.sync_example_id_rep.rank_id,
                                 request.rank_id)
                self.assertEqual(rsp.manifest.sync_example_id_rep.state,
                                 dj_pb.SyncExampleIdState.Syncing)
                self.assertEqual(rsp.manifest.partition_id, 0)

        def check_manifest(client, request, finished, next_process_index):
            """Query the raw-data manifest and check its progress fields."""
            manifest = client.QueryRawDataManifest(request)
            self.assertIsNotNone(manifest)
            self.assertEqual(manifest.finished, finished)
            self.assertEqual(manifest.next_process_index, next_process_index)

        def check_latest_timestamp(client, request, seconds):
            """Check GetRawDataLatestTimeStamp returns ``seconds``.0."""
            rsp = client.GetRawDataLatestTimeStamp(request)
            self.assertEqual(rsp.status.code, 0)
            self.assertTrue(rsp.HasField('timestamp'))
            self.assertEqual(rsp.timestamp.seconds, seconds)
            self.assertEqual(rsp.timestamp.nanos, 0)

        def check_ok_twice(rpc, request):
            """Call ``rpc`` twice (idempotence); both must return code 0."""
            for _ in range(2):
                self.assertEqual(rpc(request).code, 0)

        wait_for_state(common_pb.DataSourceState.Processing)

        # Join-example tasks: follower asks for any partition (-1),
        # leader asks for partition 0; both must be handed partition 0.
        check_join_example(client_f, raw_data_request(
            meta_f, partition_id=-1, join_example=empty_pb2.Empty()))
        check_join_example(client_l, raw_data_request(
            meta_l, partition_id=0, join_example=empty_pb2.Empty()))
        # Sync-example-id tasks for rank 1: leader asks for any partition,
        # follower asks for partition 0.
        check_sync_example_id(client_l, raw_data_request(
            meta_l, rank_id=1, partition_id=-1,
            sync_example_id=empty_pb2.Empty()))
        check_sync_example_id(client_f, raw_data_request(
            meta_f, rank_id=1, partition_id=0,
            sync_example_id=empty_pb2.Empty()))

        # Both metas are identical copies of data_source_meta, so these two
        # requests are sent to the follower as well as the leader below.
        rdreq1 = raw_data_request(meta_l, rank_id=1,
                                  sync_example_id=empty_pb2.Empty())
        rdreq2 = raw_data_request(meta_l, join_example=empty_pb2.Empty())
        # Finishing either task is expected to be rejected at this point.
        with self.assertRaises(Exception):
            client_l.FinishJoinPartition(rdreq1)
        with self.assertRaises(Exception):
            client_l.FinishJoinPartition(rdreq2)

        # Leader raw data: add a@3, b@5, then re-add both with dedup=True,
        # which must not advance the manifest.
        plain_l = raw_data_request(meta_l)
        check_manifest(client_l, plain_l, False, 0)
        check_latest_timestamp(client_l, plain_l, 0)
        add_req = add_raw_data_request(meta_l, False, [('a', 3)])
        self.assertEqual(client_l.AddRawData(add_req).code, 0)
        check_manifest(client_l, add_req, False, 1)
        check_latest_timestamp(client_l, plain_l, 3)
        add_req = add_raw_data_request(meta_l, False, [('b', 5)])
        self.assertEqual(client_l.AddRawData(add_req).code, 0)
        check_manifest(client_l, add_req, False, 2)
        check_latest_timestamp(client_l, plain_l, 5)
        add_req = add_raw_data_request(meta_l, True, [('b', 5), ('a', 3)])
        self.assertEqual(client_l.AddRawData(add_req).code, 0)
        check_manifest(client_l, add_req, False, 2)
        check_latest_timestamp(client_l, add_req, 5)

        # Follower raw data: add a@1, b@2, then re-add both with dedup=True.
        plain_f = raw_data_request(meta_f)
        check_manifest(client_f, plain_f, False, 0)
        check_latest_timestamp(client_f, plain_f, 0)
        add_req = add_raw_data_request(meta_f, False, [('a', 1)])
        self.assertEqual(client_f.AddRawData(add_req).code, 0)
        check_manifest(client_f, add_req, False, 1)
        check_latest_timestamp(client_f, plain_f, 1)
        add_req = add_raw_data_request(meta_f, False, [('b', 2)])
        self.assertEqual(client_f.AddRawData(add_req).code, 0)
        check_manifest(client_f, add_req, False, 2)
        check_latest_timestamp(client_f, plain_f, 2)
        add_req = add_raw_data_request(meta_f, True, [('a', 1), ('b', 2)])
        self.assertEqual(client_f.AddRawData(add_req).code, 0)
        check_manifest(client_f, add_req, False, 2)
        check_latest_timestamp(client_f, plain_f, 2)

        # Finish the leader's raw data; further additions must be rejected.
        finish_l = raw_data_request(meta_l, finish_raw_data=empty_pb2.Empty())
        check_ok_twice(client_l.FinishRawData, finish_l)
        check_manifest(client_l, finish_l, True, 2)
        late_add_l = add_raw_data_request(meta_l, False, [('x', 4)])
        with self.assertRaises(Exception):
            client_l.AddRawData(late_add_l)
        with self.assertRaises(Exception):
            client_f.FinishJoinPartition(rdreq2)
        check_ok_twice(client_l.FinishJoinPartition, rdreq1)
        check_ok_twice(client_f.FinishJoinPartition, rdreq1)
        with self.assertRaises(Exception):
            client_f.FinishJoinPartition(rdreq2)

        # Finish the follower's raw data; the join tasks can then finish.
        finish_f = raw_data_request(meta_f, finish_raw_data=empty_pb2.Empty())
        check_ok_twice(client_f.FinishRawData, finish_f)
        check_manifest(client_f, finish_f, True, 2)
        late_add_f = add_raw_data_request(meta_f, True, [('x', 3)])
        with self.assertRaises(Exception):
            client_f.AddRawData(late_add_f)
        check_ok_twice(client_f.FinishJoinPartition, rdreq2)
        check_ok_twice(client_l.FinishJoinPartition, rdreq2)

        wait_for_state(common_pb.DataSourceState.Finished)
    finally:
        master_l.stop()
        master_f.stop()
default='test_etcd', help='the name of etcd') parser.add_argument('--etcd_addrs', type=str, default='localhost:2379', help='the addrs of etcd') parser.add_argument('--etcd_base_dir', type=str, default='fedlearner_test', help='the namespace of etcd key') parser.add_argument('--listen_port', '-p', type=int, default=4032, help='Listen port of data join master') parser.add_argument('--data_source_name', type=str, default='test_data_source', help='the name of data source') parser.add_argument('--use_mock_etcd', action='store_true', help='use to mock etcd for test') args = parser.parse_args() master_options = dj_pb.DataJoinMasterOptions( use_mock_etcd=args.use_mock_etcd) master_srv = DataJoinMasterService(args.listen_port, args.peer_addr, args.data_source_name, args.etcd_name, args.etcd_base_dir, args.etcd_addrs, master_options) master_srv.run()
# Command-line entry point: parse options and run a DataJoinMasterService
# until it exits.
logging.basicConfig(format="%(asctime)s %(filename)s "
                    "%(lineno)s %(levelname)s - %(message)s")
parser = argparse.ArgumentParser(description='DataJoinMasterService cmd.')
parser.add_argument('peer_addr', type=str,
                    help='the addr(uuid) of peer data join master')
# 'mock' switches the master to the mock etcd backend (see master_options
# below); any other value is forwarded to DataJoinMasterService as-is.
parser.add_argument('--kvstore_type', type=str, default='etcd',
                    help='the type of kvstore, e.g. etcd, mysql or mock')
parser.add_argument('--listen_port', '-p', type=int, default=4032,
                    help='Listen port of data join master')
parser.add_argument('--data_source_name', type=str,
                    default='test_data_source',
                    help='the name of data source')
parser.add_argument('--batch_mode', action='store_true',
                    help='make the data join run in batch mode')
args = parser.parse_args()
master_options = dj_pb.DataJoinMasterOptions(
    use_mock_etcd=(args.kvstore_type == 'mock'),
    batch_mode=args.batch_mode)
master_srv = DataJoinMasterService(args.listen_port, args.peer_addr,
                                   args.data_source_name,
                                   args.kvstore_type,
                                   master_options)
master_srv.run()
type=str, default='localhost:2379', help='the addrs of etcd') parser.add_argument('--etcd_base_dir', type=str, default='fedlearner_test', help='the namespace of etcd key') parser.add_argument('--listen_port', '-p', type=int, default=4032, help='Listen port of data join master') parser.add_argument('--data_source_name', type=str, default='test_data_source', help='the name of data source') parser.add_argument('--use_mock_etcd', action='store_true', help='use to mock etcd for test') parser.add_argument('--batch_mode', action='store_true', help='make the data join run in batch mode') args = parser.parse_args() master_options = dj_pb.DataJoinMasterOptions( use_mock_etcd=args.use_mock_etcd, batch_mode=args.batch_mode) master_srv = DataJoinMasterService(args.listen_port, args.peer_addr, args.data_source_name, args.etcd_name, args.etcd_base_dir, args.etcd_addrs, master_options) master_srv.run()