def _initialize_dispatcher(self, config):
    """Initialize the dispatcher, either from a remote compute package or from a local client folder."""
    if config['remote_compute_context']:
        pr = PackageRuntime(os.getcwd(), os.getcwd())

        retval = None
        tries = 10
        while tries > 0:
            retval = pr.download(config['discover_host'], config['discover_port'], config['token'])
            if retval:
                break
            time.sleep(60)
            print("No compute package available... retrying in 60s. Trying {} more times.".format(tries), flush=True)
            tries -= 1

        if retval:
            if 'checksum' not in config:
                print("\nWARNING: Skipping security validation of local package! Make sure you trust the package source.\n", flush=True)
            else:
                checks_out = pr.validate(config['checksum'])
                if not checks_out:
                    print("Validation was enforced and the package is invalid, client closing!")
                    self.error_state = True
                    return

        if retval:
            pr.unpack()

        self.dispatcher = pr.dispatcher(self.run_path)
        try:
            print("Running Dispatcher for entrypoint: startup", flush=True)
            self.dispatcher.run_cmd("startup")
        except KeyError:
            pass
    else:
        # TODO: Deprecate
        dispatch_config = {'entry_points': {'predict': {'command': 'python3 predict.py'},
                                            'train': {'command': 'python3 train.py'},
                                            'validate': {'command': 'python3 validate.py'}}}
        from_path = os.path.join(os.getcwd(), 'client')

        from distutils.dir_util import copy_tree
        copy_tree(from_path, self.run_path)
        self.dispatcher = Dispatcher(dispatch_config, self.run_path)
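
# --- Illustration only ---
# A hedged sketch of what a compute-package checksum check along the lines of
# pr.validate(config['checksum']) could look like. The real PackageRuntime.validate()
# is not reproduced in this file; the SHA-256 digest and single-archive layout below
# are assumptions made purely for this sketch.
import hashlib


def sha256_of_file(path, block_size=64 * 1024):
    """Return the hex SHA-256 digest of a file, read in blocks."""
    digest = hashlib.sha256()
    with open(path, 'rb') as fh:
        for block in iter(lambda: fh.read(block_size), b''):
            digest.update(block)
    return digest.hexdigest()


def validate_package(package_path, expected_checksum):
    """Compare a downloaded package archive against the checksum from the client config."""
    return sha256_of_file(package_path) == expected_checksum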
def dispatcher(self, run_path):
    """Copy the client code into run_path and construct a Dispatcher from its fedn.yaml."""
    from_path = os.path.join(os.getcwd(), 'client')

    from distutils.dir_util import copy_tree
    copy_tree(from_path, run_path)
    os.chdir(run_path)

    try:
        cfg = None
        with open(os.path.join(run_path, 'fedn.yaml'), 'rb') as config_file:
            import yaml
            cfg = yaml.safe_load(config_file.read())
        self.dispatch_config = cfg
    except Exception:
        print("Error trying to load and unpack dispatcher config - trying default", flush=True)

    dispatcher = Dispatcher(self.dispatch_config, run_path)
    return dispatcher
def dispatcher(self, run_path): """ :param run_path: :return: """ from_path = os.path.join(os.getcwd(), 'client') from distutils.dir_util import copy_tree copy_tree(from_path, run_path) try: cfg = None with open(os.path.join(run_path, 'fedn.yaml'), 'rb') as config_file: import yaml cfg = yaml.safe_load(config_file.read()) self.dispatch_config = cfg except Exception as e: print( "Error trying to load and unpack dispatcher config - trying default", flush=True) dispatcher = Dispatcher(self.dispatch_config, run_path) return dispatcher
def dispatcher(self):
    os.chdir(os.path.join(self.dir, 'client'))

    try:
        cfg = None
        with open(os.path.join(os.path.join(self.dir, 'client'), 'fedn.yaml'), 'rb') as config_file:
            import yaml
            cfg = yaml.safe_load(config_file.read())
        self.dispatch_config = cfg
    except Exception as e:
        print("Error trying to load and unpack dispatcher config - trying default", flush=True)

    dispatcher = Dispatcher(self.dispatch_config, os.path.join(self.dir, 'client'))
    return dispatcher
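
# --- Illustration only ---
# The dispatcher loaders above read a fedn.yaml from the unpacked package and fall
# back to a default entry_points mapping. A minimal sketch of that layout, written
# from the defaults visible in this file; a real fedn.yaml may carry additional
# fields not shown here.
import yaml

dispatch_config_example = {
    'entry_points': {
        'train': {'command': 'python3 train.py'},
        'validate': {'command': 'python3 validate.py'},
        'predict': {'command': 'python3 predict.py'},
    }
}

# Round-trip it the same way the loaders above do (yaml.safe_load).
fedn_yaml_text = yaml.safe_dump(dispatch_config_example)
assert yaml.safe_load(fedn_yaml_text) == dispatch_config_example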
class Client: def __init__(self, config): from fedn.discovery.connect import DiscoveryClientConnect, State self.controller = DiscoveryClientConnect(config['discover_host'], config['discover_port'], config['token'], config['name']) self.name = config['name'] self.started_at = datetime.now() self.logs = [] import time tries = 90 status = None while True: if tries > 0: status = self.controller.connect() if status == State.Disconnected: tries = tries - 1 if status == State.Connected: break time.sleep(2) print("try to reconnect to CONTROLLER", flush=True) combiner = None tries = 180 while True: status, state = self.controller.check_status() print("got status {}".format(status), flush=True) if state == state.Disconnected: print("lost connection. trying...") tries -= 1 time.sleep(1) if tries < 1: print("ERROR! NO CONTACT!", flush=True) raise Exception("NO CONTACT WITH DISCOVERY NODE") self.controller.connect() if status != 'A': print("waiting to be assigned..", flush=True) time.sleep(5) if status == 'A': print("yay! got assigned, fetching combiner", flush=True) combiner, _ = self.controller.get_config() break # TODO REMOVE ONLY FOR TESTING (only used for partial restructuring) repo_config = { 'storage_access_key': 'minio', 'storage_secret_key': 'minio123', 'storage_bucket': 'models', 'storage_secure_mode': False, 'storage_hostname': 'minio', 'storage_port': 9000 } self.repository = get_repository(repo_config) self.bucket_name = repo_config['storage_bucket'] channel = grpc.insecure_channel(combiner['host'] + ":" + str(combiner['port'])) self.connection = rpc.ConnectorStub(channel) self.orchestrator = rpc.CombinerStub(channel) self.models = rpc.ModelServiceStub(channel) print("Client: {} connected to {}:{}".format(self.name, combiner['host'], combiner['port'])) # TODO REMOVE OVERRIDE WITH CONTEXT FETCHED dispatch_config = { 'entry_points': { 'predict': { 'command': 'python3 predict.py' }, 'train': { 'command': 'python3 train.py' }, 'validate': { 'command': 'python3 validate.py' } } } import os # TODO REMOVE OVERRIDE WITH CONTEXT FETCHED dispatch_dir = os.getcwd() self.dispatcher = Dispatcher(dispatch_config, dispatch_dir) self.lock = threading.Lock() threading.Thread(target=self._send_heartbeat, daemon=True).start() threading.Thread(target=self.__listen_to_model_update_request_stream, daemon=True).start() threading.Thread( target=self.__listen_to_model_validation_request_stream, daemon=True).start() self.state = ClientState.idle def get_model(self, id): from io import BytesIO data = BytesIO() # print("REACHED DOWNLOAD Trying now with id {}".format(id), flush=True) # print("TRYING DOWNLOAD 1.", flush=True) for part in self.models.Download(alliance.ModelRequest(id=id)): # print("TRYING DOWNLOAD 2.", flush=True) if part.status == alliance.ModelStatus.IN_PROGRESS: # print("WRITING PART FOR MODEL:{}".format(id), flush=True) data.write(part.data) if part.status == alliance.ModelStatus.OK: # print("DONE WRITING MODEL RETURNING {}".format(id), flush=True) return data if part.status == alliance.ModelStatus.FAILED: # print("FAILED TO DOWNLOAD MODEL::: bailing!",flush=True) return None # print("ERROR NO PARTS!",flush=True) return data def set_model(self, model, id): from io import BytesIO if not isinstance(model, BytesIO): bt = BytesIO() for d in model.stream(32 * 1024): bt.write(d) else: bt = model # print("SETTING MODEL OF SIZE {}".format(sys.getsizeof(bt)), flush=True) bt.seek(0, 0) def upload_request_generator(mdl): i = 1 while True: b = mdl.read(CHUNK_SIZE) if b: result = alliance.ModelRequest( data=b, 
id=id, status=alliance.ModelStatus.IN_PROGRESS) else: result = alliance.ModelRequest( id=id, status=alliance.ModelStatus.OK) yield result if not b: break result = self.models.Upload(upload_request_generator(bt)) return result def __listen_to_model_update_request_stream(self): """ Subscribe to the model update request stream. """ r = alliance.ClientAvailableMessage() r.sender.name = self.name r.sender.role = alliance.WORKER metadata = [('client', r.sender.name)] for request in self.orchestrator.ModelUpdateRequestStream( r, metadata=metadata): if request.sender.role == alliance.COMBINER: # Process training request global_model_id = request.model_id # TODO: Error handling self.send_status("Received model update request.", log_level=alliance.Status.AUDIT, type=alliance.StatusType.MODEL_UPDATE_REQUEST, request=request) model_id = self.__process_training_request(global_model_id) if model_id != None: # Notify the requesting client that a model update is available update = alliance.ModelUpdate() update.sender.name = self.name update.sender.role = alliance.WORKER update.receiver.name = request.sender.name update.receiver.role = request.sender.role update.model_id = request.model_id update.model_update_id = str(model_id) update.timestamp = str(datetime.now()) update.correlation_id = request.correlation_id response = self.orchestrator.SendModelUpdate(update) self.send_status("Model update completed.", log_level=alliance.Status.AUDIT, type=alliance.StatusType.MODEL_UPDATE, request=update) else: self.send_status( "Client {} failed to complete model update.", log_level=alliance.Status.WARNING, request=request) def __listen_to_model_validation_request_stream(self): """ Subscribe to the model update request stream. """ r = alliance.ClientAvailableMessage() r.sender.name = self.name r.sender.role = alliance.WORKER for request in self.orchestrator.ModelValidationRequestStream(r): # Process training request model_id = request.model_id # TODO: Error handling self.send_status("Recieved model validation request.", log_level=alliance.Status.AUDIT, type=alliance.StatusType.MODEL_VALIDATION_REQUEST, request=request) metrics = self.__process_validation_request(model_id) if metrics != None: # Send validation validation = alliance.ModelValidation() validation.sender.name = self.name validation.sender.role = alliance.WORKER validation.receiver.name = request.sender.name validation.receiver.role = request.sender.role validation.model_id = str(model_id) validation.data = json.dumps(metrics) self.str = str(datetime.now()) validation.timestamp = self.str validation.correlation_id = request.correlation_id response = self.orchestrator.SendModelValidation(validation) self.send_status("Model validation completed.", log_level=alliance.Status.AUDIT, type=alliance.StatusType.MODEL_VALIDATION, request=validation) else: self.send_status( "Client {} failed to complete model validation.".format( self.client), log_level=alliance.Status.WARNING, request=request) def __process_training_request(self, model_id): self.send_status( "\t Processing training request for model_id {}".format(model_id)) self.state = ClientState.training try: # print("IN TRAINING REQUEST 1", flush=True) mdl = self.get_model(str(model_id)) import sys # print("did i get a model? 
model_id: {} size:{}".format(model_id, sys.getsizeof(mdl))) # print("IN TRAINING REQUEST 2", flush=True) # model = self.repository.get_model(model_id) fid, infile_name = tempfile.mkstemp(suffix='.h5') fod, outfile_name = tempfile.mkstemp(suffix='.h5') with open(infile_name, "wb") as fh: fh.write(mdl.getbuffer()) # print("IN TRAINING REQUEST 3", flush=True) self.dispatcher.run_cmd("train {} {}".format( infile_name, outfile_name)) # print("IN TRAINING REQUEST 4", flush=True) # model_id = self.repository.set_model(outfile_name, is_file=True) import io out_model = None with open(outfile_name, "rb") as fr: out_model = io.BytesIO(fr.read()) # print("IN TRAINING REQUEST 5", flush=True) import uuid model_id = uuid.uuid4() self.set_model(out_model, str(model_id)) # print("IN TRAINING REQUEST 6", flush=True) os.unlink(infile_name) os.unlink(outfile_name) except Exception as e: print("ERROR could not process training request due to error: {}". format(e)) model_id = None self.state = ClientState.idle return model_id def __process_validation_request(self, model_id): self.send_status( "Processing validation request for model_id {}".format(model_id)) self.state = ClientState.validating try: model = self.get_model(model_id) # repository.get_model(model_id) fid, infile_name = tempfile.mkstemp(suffix='.h5') fod, outfile_name = tempfile.mkstemp(suffix='.h5') with open(infile_name, "wb") as fh: fh.write(model.getbuffer()) self.dispatcher.run_cmd("validate {} {}".format( infile_name, outfile_name)) with open(outfile_name, "r") as fh: validation = json.loads(fh.read()) os.unlink(infile_name) os.unlink(outfile_name) except Exception as e: print("Validation failed with exception {}".format(e), flush=True) self.state = ClientState.idle return None self.state = ClientState.idle return validation def send_status(self, msg, log_level=alliance.Status.INFO, type=None, request=None): from google.protobuf.json_format import MessageToJson status = alliance.Status() status.sender.name = self.name status.sender.role = alliance.WORKER status.log_level = log_level status.status = str(msg) if type is not None: status.type = type if request is not None: status.data = MessageToJson(request) self.logs.append("{} {} LOG LEVEL {} MESSAGE {}".format( str(datetime.now()), status.sender.name, status.log_level, status.status)) response = self.connection.SendStatus(status) def _send_heartbeat(self, update_frequency=2.0): while True: heartbeat = alliance.Heartbeat( sender=alliance.Client(name=self.name, role=alliance.WORKER)) self.connection.SendHeartbeat(heartbeat) # self.send_status("HEARTBEAT from {}".format(self.client),log_level=alliance.Status.INFO) import time time.sleep(update_frequency) def run_web(self): from flask import Flask app = Flask(__name__) from .pages import page, style @app.route('/') def index(): logs_fancy = str() for log in self.logs: logs_fancy += "<p>" + log + "</p>\n" return page.format(client=self.name, state=ClientStateToString(self.state), style=style, logs=logs_fancy) # return {"name": self.name, "State": ClientStateToString(self.state), "Runtime": str(datetime.now() - self.started_at), # "Since": str(self.started_at)} import os, sys self._original_stdout = sys.stdout sys.stdout = open(os.devnull, 'w') app.run(host="0.0.0.0", port="8090") sys.stdout.close() sys.stdout = self._original_stdout def run(self): import time import threading threading.Thread(target=self.run_web, daemon=True).start() try: cnt = 0 old_state = self.state while True: time.sleep(1) cnt += 1 if self.state != old_state: print("CLIENT 
{}".format(ClientStateToString(self.state)), flush=True) if cnt > 5: print("CLIENT active", flush=True) cnt = 0 except KeyboardInterrupt: print("ok exiting..")
class Client: """FEDn Client. Service running on client/datanodes in a federation, recieving and handling model update and model validation requests. Attibutes --------- config: dict A configuration dictionary containing connection information for the discovery service (controller) and settings governing e.g. client-combiner assignment behavior. """ def __init__(self, config): """ Parameters ---------- config: dict A configuration dictionary containing connection information for the discovery service (controller) and settings governing e.g. client-combiner assignment behavior. """ self.state = None self.error_state = False self._attached = False self._missed_heartbeat=0 self.config = config self.connector = ConnectorClient(config['discover_host'], config['discover_port'], config['token'], config['name'], config['remote_compute_context'], config['preferred_combiner'], config['client_id'], secure=config['secure'], preshared_cert=config['preshared_cert'], verify_cert=config['verify_cert']) self.name = config['name'] dirname = time.strftime("%Y%m%d-%H%M%S") self.run_path = os.path.join(os.getcwd(), dirname) os.mkdir(self.run_path) self.logger = Logger(to_file=config['logfile'], file_path=self.run_path) self.started_at = datetime.now() self.logs = [] self.inbox = queue.Queue() # Attach to the FEDn network (get combiner) client_config = self._attach() self._initialize_dispatcher(config) self._initialize_helper(client_config) if not self.helper: print("Failed to retrive helper class settings! {}".format(client_config), flush=True) self._subscribe_to_combiner(config) self.state = ClientState.idle def _detach(self): # Setting _attached to False will make all processing threads return if not self._attached: print("Client is not attached.",flush=True) self._attached = False # Close gRPC connection to combiner self._disconnect() def _attach(self): """ """ # Ask controller for a combiner and connect to that combiner. if self._attached: print("Client is already attached. ",flush=True) return None client_config = self._assign() self._connect(client_config) if client_config: self._attached=True return client_config def _initialize_helper(self,client_config): if 'model_type' in client_config.keys(): self.helper = get_helper(client_config['model_type']) def _subscribe_to_combiner(self,config): """Listen to combiner message stream and start all processing threads. """ # Start sending heartbeats to the combiner. threading.Thread(target=self._send_heartbeat, kwargs={'update_frequency': config['heartbeat_interval']}, daemon=True).start() # Start listening for combiner training and validation messages if config['trainer'] == True: threading.Thread(target=self._listen_to_model_update_request_stream, daemon=True).start() if config['validator'] == True: threading.Thread(target=self._listen_to_model_validation_request_stream, daemon=True).start() self._attached = True # Start processing the client message inbox threading.Thread(target=self.process_request, daemon=True).start() def _initialize_dispatcher(self, config): """ """ if config['remote_compute_context']: pr = PackageRuntime(os.getcwd(), os.getcwd()) retval = None tries = 10 while tries > 0: retval = pr.download(config['discover_host'], config['discover_port'], config['token']) if retval: break time.sleep(60) print("No compute package available... 
retrying in 60s Trying {} more times.".format(tries), flush=True) tries -= 1 if retval: if not 'checksum' in config: print( "\nWARNING: Skipping security validation of local package!, make sure you trust the package source.\n", flush=True) else: checks_out = pr.validate(config['checksum']) if not checks_out: print("Validation was enforced and invalid, client closing!") self.error_state = True return if retval: pr.unpack() self.dispatcher = pr.dispatcher(self.run_path) try: print("Running Dispatcher for entrypoint: startup", flush=True) self.dispatcher.run_cmd("startup") except KeyError: pass else: # TODO: Deprecate dispatch_config = {'entry_points': {'predict': {'command': 'python3 predict.py'}, 'train': {'command': 'python3 train.py'}, 'validate': {'command': 'python3 validate.py'}}} dispatch_dir = os.getcwd() from_path = os.path.join(os.getcwd(), 'client') from distutils.dir_util import copy_tree copy_tree(from_path, self.run_path) self.dispatcher = Dispatcher(dispatch_config, self.run_path) def _assign(self): """Contacts the controller and asks for combiner assignment. """ print("Asking for assignment!", flush=True) while True: status, response = self.connector.assign() if status == Status.TryAgain: print(response, flush=True) time.sleep(5) continue if status == Status.Assigned: client_config = response break if status == Status.UnAuthorized: print(response, flush=True) sys.exit("Exiting: Unauthorized") if status == Status.UnMatchedConfig: print(response, flush=True) sys.exit("Exiting: UnMatchedConfig") time.sleep(5) print(".", end=' ', flush=True) print("Got assigned!", flush=True) return client_config def _connect(self, client_config): """Connect to assigned combiner. Parameters ---------- client_config : dict A dictionary with connection information and settings for the assigned combiner. """ # TODO use the client_config['certificate'] for setting up secure comms' if client_config['certificate']: import base64 cert = base64.b64decode(client_config['certificate']) # .decode('utf-8') credentials = grpc.ssl_channel_credentials(root_certificates=cert) channel = grpc.secure_channel("{}:{}".format(client_config['host'], str(client_config['port'])), credentials) else: channel = grpc.insecure_channel("{}:{}".format(client_config['host'], str(client_config['port']))) self.channel = channel self.connection = rpc.ConnectorStub(channel) self.orchestrator = rpc.CombinerStub(channel) self.models = rpc.ModelServiceStub(channel) print("Client: {} connected {} to {}:{}".format(self.name, "SECURED" if client_config['certificate'] else "INSECURE", client_config['host'], client_config['port']), flush=True) print("Client: Using {} compute package.".format(client_config["package"])) def _disconnect(self): self.channel.close() def get_model(self, id): """Fetch a model from the assigned combiner. Downloads the model update object via a gRPC streaming channel, Dowload. Parameters ---------- id : str The id of the model update object. """ from io import BytesIO data = BytesIO() for part in self.models.Download(fedn.ModelRequest(id=id)): if part.status == fedn.ModelStatus.IN_PROGRESS: data.write(part.data) if part.status == fedn.ModelStatus.OK: return data if part.status == fedn.ModelStatus.FAILED: return None return data def set_model(self, model, id): """Send a model update to the assigned combiner. Uploads the model updated object via a gRPC streaming channel, Upload. Parameters ---------- model : BytesIO, object The model update object. id : str The id of the model update object. 
""" from io import BytesIO if not isinstance(model, BytesIO): bt = BytesIO() for d in model.stream(32 * 1024): bt.write(d) else: bt = model bt.seek(0, 0) def upload_request_generator(mdl): """ :param mdl: """ i = 1 while True: b = mdl.read(CHUNK_SIZE) if b: result = fedn.ModelRequest(data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) else: result = fedn.ModelRequest(id=id, status=fedn.ModelStatus.OK) yield result if not b: break result = self.models.Upload(upload_request_generator(bt)) return result def _listen_to_model_update_request_stream(self): """Subscribe to the model update request stream. """ r = fedn.ClientAvailableMessage() r.sender.name = self.name r.sender.role = fedn.WORKER metadata = [('client', r.sender.name)] _disconnect = False while True: try: for request in self.orchestrator.ModelUpdateRequestStream(r, metadata=metadata): if request.sender.role == fedn.COMBINER: # Process training request self._send_status("Received model update request.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request) self.inbox.put(('train', request)) if not self._attached: return except grpc.RpcError as e: status_code = e.code() #TODO: make configurable timeout = 5 #print("CLIENT __listen_to_model_update_request_stream: GRPC ERROR {} retrying in {}..".format( # status_code.name, timeout), flush=True) time.sleep(timeout) except: raise if not self._attached: return def _listen_to_model_validation_request_stream(self): """Subscribe to the model validation request stream. """ r = fedn.ClientAvailableMessage() r.sender.name = self.name r.sender.role = fedn.WORKER while True: try: for request in self.orchestrator.ModelValidationRequestStream(r): # Process validation request model_id = request.model_id self._send_status("Recieved model validation request.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_VALIDATION_REQUEST, request=request) self.inbox.put(('validate', request)) except grpc.RpcError as e: status_code = e.code() # TODO: make configurable timeout = 5 #print("CLIENT __listen_to_model_validation_request_stream: GRPC ERROR {} retrying in {}..".format( # status_code.name, timeout), flush=True) time.sleep(timeout) except: raise if not self._attached: return def process_request(self): """Process training and validation tasks. 
""" while True: if not self._attached: return try: (task_type, request) = self.inbox.get(timeout=1.0) if task_type == 'train': tic = time.time() self.state = ClientState.training model_id, meta = self._process_training_request(request.model_id) processing_time = time.time()-tic meta['processing_time'] = processing_time if model_id != None: # Notify the combiner that a model update is available update = fedn.ModelUpdate() update.sender.name = self.name update.sender.role = fedn.WORKER update.receiver.name = request.sender.name update.receiver.role = request.sender.role update.model_id = request.model_id update.model_update_id = str(model_id) update.timestamp = str(datetime.now()) update.correlation_id = request.correlation_id update.meta = json.dumps(meta) #TODO: Check responses response = self.orchestrator.SendModelUpdate(update) self._send_status("Model update completed.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_UPDATE, request=update) else: self._send_status("Client {} failed to complete model update.", log_level=fedn.Status.WARNING, request=request) self.state = ClientState.idle self.inbox.task_done() elif task_type == 'validate': self.state = ClientState.validating metrics = self._process_validation_request(request.model_id) if metrics != None: # Send validation validation = fedn.ModelValidation() validation.sender.name = self.name validation.sender.role = fedn.WORKER validation.receiver.name = request.sender.name validation.receiver.role = request.sender.role validation.model_id = str(request.model_id) validation.data = json.dumps(metrics) self.str = str(datetime.now()) validation.timestamp = self.str validation.correlation_id = request.correlation_id response = self.orchestrator.SendModelValidation(validation) self._send_status("Model validation completed.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_VALIDATION, request=validation) else: self._send_status("Client {} failed to complete model validation.".format(self.name), log_level=fedn.Status.WARNING, request=request) self.state = ClientState.idle self.inbox.task_done() except queue.Empty: pass def _process_training_request(self, model_id): """Process a training (model update) request. Parameters ---------- model_id : Str The id of the model to update. 
""" self._send_status("\t Starting processing of training request for model_id {}".format(model_id)) self.state = ClientState.training try: meta = {} tic = time.time() mdl = self.get_model(str(model_id)) meta['fetch_model'] = time.time() - tic inpath = self.helper.get_tmp_path() with open(inpath, 'wb') as fh: fh.write(mdl.getbuffer()) outpath = self.helper.get_tmp_path() tic = time.time() # TODO: Check return status, fail gracefully self.dispatcher.run_cmd("train {} {}".format(inpath, outpath)) meta['exec_training'] = time.time() - tic tic = time.time() out_model = None with open(outpath, "rb") as fr: out_model = io.BytesIO(fr.read()) # Push model update to combiner server updated_model_id = uuid.uuid4() self.set_model(out_model, str(updated_model_id)) meta['upload_model'] = time.time() - tic os.unlink(inpath) os.unlink(outpath) except Exception as e: print("ERROR could not process training request due to error: {}".format(e), flush=True) updated_model_id = None meta = {'status': 'failed', 'error': str(e)} self.state = ClientState.idle return updated_model_id, meta def _process_validation_request(self, model_id): self._send_status("Processing validation request for model_id {}".format(model_id)) self.state = ClientState.validating try: model = self.get_model(str(model_id)) inpath = self.helper.get_tmp_path() with open(inpath, "wb") as fh: fh.write(model.getbuffer()) _, outpath = tempfile.mkstemp() self.dispatcher.run_cmd("validate {} {}".format(inpath, outpath)) with open(outpath, "r") as fh: validation = json.loads(fh.read()) os.unlink(inpath) os.unlink(outpath) except Exception as e: print("Validation failed with exception {}".format(e), flush=True) raise self.state = ClientState.idle return None self.state = ClientState.idle return validation def _handle_combiner_failure(self): """ Register failed combiner connection. """ self._missed_heartbeat += 1 if self._missed_heartbeat > self.config['reconnect_after_missed_heartbeat']: self._detach() def _send_heartbeat(self, update_frequency=2.0): """Send a heartbeat to the combiner. Parameters ---------- update_frequency : float The interval in seconds between heartbeat messages. """ while True: heartbeat = fedn.Heartbeat(sender=fedn.Client(name=self.name, role=fedn.WORKER)) try: self.connection.SendHeartbeat(heartbeat) self._missed_heartbeat = 0 except grpc.RpcError as e: status_code = e.code() print("CLIENT heartbeat: GRPC ERROR {} retrying..".format(status_code.name), flush=True) self._handle_combiner_failure() time.sleep(update_frequency) if not self._attached: return def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None): """Send status message. """ from google.protobuf.json_format import MessageToJson status = fedn.Status() status.timestamp = str(datetime.now()) status.sender.name = self.name status.sender.role = fedn.WORKER status.log_level = log_level status.status = str(msg) if type is not None: status.type = type if request is not None: status.data = MessageToJson(request) self.logs.append( "{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) response = self.connection.SendStatus(status) def run_web(self): """Starts a local logging UI (Flask app) serving on port 8080. Currently not in use as default. 
""" from flask import Flask app = Flask(__name__) from fedn.common.net.web.client import page, style @app.route('/') def index(): """ :return: """ logs_fancy = str() for log in self.logs: logs_fancy += "<p>" + log + "</p>\n" return page.format(client=self.name, state=ClientStateToString(self.state), style=style, logs=logs_fancy) import os, sys self._original_stdout = sys.stdout sys.stdout = open(os.devnull, 'w') app.run(host="0.0.0.0", port="8080") sys.stdout.close() sys.stdout = self._original_stdout def run(self): """ Main run loop. """ #threading.Thread(target=self.run_web, daemon=True).start() try: cnt = 0 old_state = self.state while True: time.sleep(1) cnt += 1 if self.state != old_state: print("{}:CLIENT in {} state".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), ClientStateToString(self.state)), flush=True) if cnt > 5: print("{}:CLIENT active".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')), flush=True) cnt = 0 if not self._attached: print("Detatched from combiner.", flush=True) # TODO: Implement a check/condition to ulitmately close down if too many reattachment attepts have failed. s self._attach() self._subscribe_to_combiner(self.config) if self.error_state: return except KeyboardInterrupt: print("Ok, exiting..")
class Client: """FEDn Client. """ def __init__(self, config): self.state = None self.error_state = False from fedn.common.net.connect import ConnectorClient, Status self.connector = ConnectorClient(config['discover_host'], config['discover_port'], config['token'], config['name'], config['preferred_combiner'], config['client_id'], secure=config['secure'], preshared_cert=config['preshared_cert'], verify_cert=config['verify_cert']) self.name = config['name'] import time dirname = time.strftime("%Y%m%d-%H%M%S") self.run_path = os.path.join(os.getcwd(), dirname) os.mkdir(self.run_path) from fedn.utils.logger import Logger self.logger = Logger(to_file=config['logfile'],file_path=self.run_path) self.started_at = datetime.now() self.logs = [] client_config = {} print("Asking for assignment",flush=True) import time while True: status, response = self.connector.assign() if status == Status.TryAgain: time.sleep(5) continue if status == Status.Assigned: client_config = response break time.sleep(5) print(".", end=' ', flush=True) print("Got assigned!", flush=True) # TODO use the client_config['certificate'] for setting up secure comms' if client_config['certificate']: import base64 cert = base64.b64decode(client_config['certificate']) # .decode('utf-8') credentials = grpc.ssl_channel_credentials(root_certificates=cert) channel = grpc.secure_channel("{}:{}".format(client_config['host'], str(client_config['port'])), credentials) else: channel = grpc.insecure_channel("{}:{}".format(client_config['host'], str(client_config['port']))) self.connection = rpc.ConnectorStub(channel) self.orchestrator = rpc.CombinerStub(channel) self.models = rpc.ModelServiceStub(channel) print("Client: {} connected {} to {}:{}".format(self.name, "SECURED" if client_config['certificate'] else "INSECURE", client_config['host'], client_config['port']), flush=True) if config['remote_compute_context']: from fedn.common.control.package import PackageRuntime pr = PackageRuntime(os.getcwd(), os.getcwd()) retval = None tries = 10 while tries > 0: retval = pr.download(config['discover_host'], config['discover_port'], config['token']) if retval: break time.sleep(60) print("No compute package available... retrying in 60s Trying {} more times.".format(tries),flush=True) tries -= 1 if retval: if not 'checksum' in config: print("\nWARNING: Skipping security validation of local package!, make sure you trust the package source.\n",flush=True) else: checks_out = pr.validate(config['checksum']) if not checks_out: print("Validation was enforced and invalid, client closing!") self.error_state = True return if retval: pr.unpack() self.dispatcher = pr.dispatcher(self.run_path) try: print("Running Dispatcher for entrypoint: startup", flush=True) self.dispatcher.run_cmd("startup") except KeyError: pass else: # TODO: Deprecate dispatch_config = {'entry_points': {'predict': {'command': 'python3 predict.py'}, 'train': {'command': 'python3 train.py'}, 'validate': {'command': 'python3 validate.py'}}} dispatch_dir = os.getcwd() from_path = os.path.join(os.getcwd(),'client') from distutils.dir_util import copy_tree copy_tree(from_path, run_path) self.dispatcher = Dispatcher(dispatch_config, self.run_path) self.lock = threading.Lock() if 'model_type' in client_config.keys(): self.helper = get_helper(client_config['model_type']) if not self.helper: print("Failed to retrive helper class settings! 
{}".format(client_config),flush=True) threading.Thread(target=self._send_heartbeat, daemon=True).start() threading.Thread(target=self.__listen_to_model_update_request_stream, daemon=True).start() threading.Thread(target=self.__listen_to_model_validation_request_stream, daemon=True).start() self.state = ClientState.idle def get_model(self, id): """Fetch model from the Combiner. """ from io import BytesIO data = BytesIO() for part in self.models.Download(fedn.ModelRequest(id=id)): if part.status == fedn.ModelStatus.IN_PROGRESS: data.write(part.data) if part.status == fedn.ModelStatus.OK: return data if part.status == fedn.ModelStatus.FAILED: return None return data def set_model(self, model, id): """Upload a model to the Combiner. """ from io import BytesIO if not isinstance(model, BytesIO): bt = BytesIO() for d in model.stream(32 * 1024): bt.write(d) else: bt = model bt.seek(0, 0) def upload_request_generator(mdl): i = 1 while True: b = mdl.read(CHUNK_SIZE) if b: result = fedn.ModelRequest(data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) else: result = fedn.ModelRequest(id=id, status=fedn.ModelStatus.OK) yield result if not b: break result = self.models.Upload(upload_request_generator(bt)) return result def __listen_to_model_update_request_stream(self): """Subscribe to the model update request stream. """ r = fedn.ClientAvailableMessage() r.sender.name = self.name r.sender.role = fedn.WORKER metadata = [('client', r.sender.name)] import time while True: try: for request in self.orchestrator.ModelUpdateRequestStream(r, metadata=metadata): if request.sender.role == fedn.COMBINER: # Process training request global_model_id = request.model_id # TODO: Error handling self.send_status("Received model update request.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request) tic = time.time() model_id, meta = self.__process_training_request(global_model_id) processing_time = time.time()-tic meta['processing_time'] = processing_time print(meta,flush=True) if model_id != None: # Notify the combiner that a model update is available update = fedn.ModelUpdate() update.sender.name = self.name update.sender.role = fedn.WORKER update.receiver.name = request.sender.name update.receiver.role = request.sender.role update.model_id = request.model_id update.model_update_id = str(model_id) update.timestamp = str(datetime.now()) update.correlation_id = request.correlation_id update.meta = json.dumps(meta) #TODO: Check responses response = self.orchestrator.SendModelUpdate(update) self.send_status("Model update completed.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_UPDATE, request=update) else: self.send_status("Client {} failed to complete model update.", log_level=fedn.Status.WARNING, request=request) except grpc.RpcError as e: status_code = e.code() timeout = 5 print("CLIENT __listen_to_model_update_request_stream: GRPC ERROR {} retrying in {}..".format( status_code.name, timeout), flush=True) import time time.sleep(timeout) def __listen_to_model_validation_request_stream(self): """Subscribe to the model validation request stream. 
""" r = fedn.ClientAvailableMessage() r.sender.name = self.name r.sender.role = fedn.WORKER while True: try: for request in self.orchestrator.ModelValidationRequestStream(r): # Process training request model_id = request.model_id # TODO: Error handling self.send_status("Recieved model validation request.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_VALIDATION_REQUEST, request=request) metrics = self.__process_validation_request(model_id) if metrics != None: # Send validation validation = fedn.ModelValidation() validation.sender.name = self.name validation.sender.role = fedn.WORKER validation.receiver.name = request.sender.name validation.receiver.role = request.sender.role validation.model_id = str(model_id) validation.data = json.dumps(metrics) self.str = str(datetime.now()) validation.timestamp = self.str validation.correlation_id = request.correlation_id response = self.orchestrator.SendModelValidation(validation) self.send_status("Model validation completed.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_VALIDATION, request=validation) else: self.send_status("Client {} failed to complete model validation.".format(self.name), log_level=fedn.Status.WARNING, request=request) except grpc.RpcError as e: status_code = e.code() timeout = 5 print("CLIENT __listen_to_model_validation_request_stream: GRPC ERROR {} retrying in {}..".format( status_code.name, timeout), flush=True) import time time.sleep(timeout) def __process_training_request(self, model_id): self.send_status("\t Starting processing of training request for model_id {}".format(model_id)) self.state = ClientState.training try: meta = {} tic = time.time() mdl = self.get_model(str(model_id)) meta['fetch_model'] = time.time()-tic import sys inpath = self.helper.get_tmp_path() with open(inpath,'wb') as fh: fh.write(mdl.getbuffer()) outpath = self.helper.get_tmp_path() tic = time.time() #TODO: Check return status, fail gracefully self.dispatcher.run_cmd("train {} {}".format(inpath, outpath)) meta['exec_training'] = time.time()-tic tic = time.time() import io out_model = None with open(outpath, "rb") as fr: out_model = io.BytesIO(fr.read()) import uuid updated_model_id = uuid.uuid4() self.set_model(out_model, str(updated_model_id)) meta['upload_model'] = time.time()-tic os.unlink(inpath) os.unlink(outpath) except Exception as e: print("ERROR could not process training request due to error: {}".format(e),flush=True) updated_model_id = None meta = {'status':'failed','error':str(e)} self.state = ClientState.idle return updated_model_id, meta def __process_validation_request(self, model_id): self.send_status("Processing validation request for model_id {}".format(model_id)) self.state = ClientState.validating try: model = self.get_model(str(model_id)) inpath = self.helper.get_tmp_path() with open(inpath, "wb") as fh: fh.write(model.getbuffer()) _,outpath = tempfile.mkstemp() self.dispatcher.run_cmd("validate {} {}".format(inpath, outpath)) with open(outpath, "r") as fh: validation = json.loads(fh.read()) os.unlink(inpath) os.unlink(outpath) except Exception as e: print("Validation failed with exception {}".format(e), flush=True) raise self.state = ClientState.idle return None self.state = ClientState.idle return validation def send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None): """Send status message. 
""" from google.protobuf.json_format import MessageToJson status = fedn.Status() status.timestamp = str(datetime.now()) status.sender.name = self.name status.sender.role = fedn.WORKER status.log_level = log_level status.status = str(msg) if type is not None: status.type = type if request is not None: status.data = MessageToJson(request) self.logs.append( "{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) response = self.connection.SendStatus(status) def _send_heartbeat(self, update_frequency=2.0): """Send a heartbeat to the Combiner. """ while True: heartbeat = fedn.Heartbeat(sender=fedn.Client(name=self.name, role=fedn.WORKER)) try: self.connection.SendHeartbeat(heartbeat) except grpc.RpcError as e: status_code = e.code() print("CLIENT heartbeat: GRPC ERROR {} retrying..".format(status_code.name), flush=True) import time time.sleep(update_frequency) def run_web(self): from flask import Flask app = Flask(__name__) from fedn.common.net.web.client import page, style @app.route('/') def index(): logs_fancy = str() for log in self.logs: logs_fancy += "<p>" + log + "</p>\n" return page.format(client=self.name, state=ClientStateToString(self.state), style=style, logs=logs_fancy) import os, sys self._original_stdout = sys.stdout sys.stdout = open(os.devnull, 'w') app.run(host="0.0.0.0", port="8080") sys.stdout.close() sys.stdout = self._original_stdout def run(self): import time threading.Thread(target=self.run_web, daemon=True).start() try: cnt = 0 old_state = self.state while True: time.sleep(1) cnt += 1 if self.state != old_state: print("CLIENT {}".format(ClientStateToString(self.state)), flush=True) if cnt > 5: print("CLIENT active", flush=True) cnt = 0 if self.error_state: return except KeyboardInterrupt: print("ok exiting..")
def __init__(self, config):
    from fedn.common.net.connect import ConnectorClient, Status
    self.connector = ConnectorClient(config['discover_host'],
                                     config['discover_port'],
                                     config['token'],
                                     config['name'],
                                     config['preferred_combiner'],
                                     config['client_id'],
                                     secure=config['secure'],
                                     preshared_cert=config['preshared_cert'],
                                     verify_cert=config['verify_cert'])
    self.name = config['name']
    self.started_at = datetime.now()
    self.logs = []

    client_config = {}
    print("Asking for assignment", flush=True)
    import time
    while True:
        status, response = self.connector.assign()
        if status == Status.TryAgain:
            time.sleep(5)
            continue
        if status == Status.Assigned:
            client_config = response
            break
        time.sleep(5)
        print(".", end=' ', flush=True)
    print("Got assigned!", flush=True)

    # TODO use the client_config['certificate'] for setting up secure comms
    if client_config['certificate']:
        import base64
        cert = base64.b64decode(client_config['certificate'])  # .decode('utf-8')
        credentials = grpc.ssl_channel_credentials(root_certificates=cert)
        channel = grpc.secure_channel("{}:{}".format(client_config['host'], str(client_config['port'])),
                                      credentials)
    else:
        channel = grpc.insecure_channel("{}:{}".format(client_config['host'], str(client_config['port'])))

    self.connection = rpc.ConnectorStub(channel)
    self.orchestrator = rpc.CombinerStub(channel)
    self.models = rpc.ModelServiceStub(channel)

    print("Client: {} connected {} to {}:{}".format(self.name,
                                                    "SECURED" if client_config['certificate'] else "INSECURE",
                                                    client_config['host'],
                                                    client_config['port']),
          flush=True)

    if config['remote_compute_context']:
        from fedn.common.control.package import PackageRuntime
        pr = PackageRuntime(os.getcwd(), os.getcwd())

        retval = None
        tries = 10
        while tries > 0:
            retval = pr.download(config['discover_host'], config['discover_port'], config['token'])
            if retval:
                break
            time.sleep(60)
            print("No compute package available... retrying in 60s. Trying {} more times.".format(tries), flush=True)
            tries -= 1

        if retval:
            pr.unpack()

        self.dispatcher = pr.dispatcher()
        try:
            self.dispatcher.run_cmd("startup")
        except KeyError:
            print("No startup entry point present, skipping.")
    else:
        # TODO: Deprecate
        dispatch_config = {'entry_points': {'predict': {'command': 'python3 predict.py'},
                                            'train': {'command': 'python3 train.py'},
                                            'validate': {'command': 'python3 validate.py'}}}
        dispatch_dir = os.getcwd()
        self.dispatcher = Dispatcher(dispatch_config, dispatch_dir)

    self.lock = threading.Lock()

    # Ensure the attribute exists even when no model_type is provided.
    self.helper = None
    if 'model_type' in client_config.keys():
        self.helper = get_helper(client_config['model_type'])

    if not self.helper:
        print("Failed to retrieve helper class settings! {}".format(client_config), flush=True)

    threading.Thread(target=self._send_heartbeat, daemon=True).start()
    threading.Thread(target=self.__listen_to_model_update_request_stream, daemon=True).start()
    threading.Thread(target=self.__listen_to_model_validation_request_stream, daemon=True).start()

    self.state = ClientState.idle
class Client: def __init__(self, config): from fedn.common.net.connect import ConnectorClient, Status self.connector = ConnectorClient(config['discover_host'], config['discover_port'], config['token'], config['name'], config['client_id'], secure=config['secure'], preshared_cert=['preshared_cert'], verify_cert=config['verify_cert']) self.name = config['name'] self.started_at = datetime.now() self.logs = [] client_config = {} print("Asking for assignment") import time while True: status, response = self.connector.assign() #print(status,response,flush=True) if status == Status.TryAgain: time.sleep(5) continue if status == Status.Assigned: client_config = response break time.sleep(5) print(".", end=' ', flush=True) # print("try to reconnect to REDUCER", flush=True) # connect_config = None print("Got assigned!", flush=True) tries = 180 # while True: # connect_config = {'host': 'combiner', 'port': 12080} # TODO REMOVE ONLY FOR TESTING (only used for partial restructuring) # import os repo_config = { 'storage_access_key': os.environ['FEDN_MINIO_ACCESS_KEY'], 'storage_secret_key': os.environ['FEDN_MINIO_SECRET_KEY'], 'storage_bucket': 'models', 'storage_secure_mode': False, 'storage_hostname': os.environ['FEDN_MINIO_HOST'], 'storage_port': int(os.environ['FEDN_MINIO_PORT']) } # repo_config, _ = self.controller.get_config() self.bucket_name = repo_config['storage_bucket'] # TODO use the client_config['certificate'] for setting up secure comms' if client_config['certificate']: import base64 cert = base64.b64decode( client_config['certificate']) # .decode('utf-8') credentials = grpc.ssl_channel_credentials(root_certificates=cert) channel = grpc.secure_channel( "{}:{}".format(client_config['host'], str(client_config['port'])), credentials) else: channel = grpc.insecure_channel("{}:{}".format( client_config['host'], str(client_config['port']))) self.connection = rpc.ConnectorStub(channel) self.orchestrator = rpc.CombinerStub(channel) self.models = rpc.ModelServiceStub(channel) print("Client: {} connected {} to {}:{}".format( self.name, "SECURED" if client_config['certificate'] else "INSECURE", client_config['host'], client_config['port']), flush=True) # TODO REMOVE OVERRIDE WITH CONTEXT FETCHED dispatch_config = { 'entry_points': { 'predict': { 'command': 'python3 predict.py' }, 'train': { 'command': 'python3 train.py' }, 'validate': { 'command': 'python3 validate.py' } } } # TODO REMOVE OVERRIDE WITH CONTEXT FETCHED dispatch_dir = os.getcwd() self.dispatcher = Dispatcher(dispatch_config, dispatch_dir) self.lock = threading.Lock() threading.Thread(target=self._send_heartbeat, daemon=True).start() threading.Thread(target=self.__listen_to_model_update_request_stream, daemon=True).start() threading.Thread( target=self.__listen_to_model_validation_request_stream, daemon=True).start() self.state = ClientState.idle def get_model(self, id): from io import BytesIO data = BytesIO() # print("REACHED DOWNLOAD Trying now with id {}".format(id), flush=True) # print("TRYING DOWNLOAD 1.", flush=True) for part in self.models.Download(fedn.ModelRequest(id=id)): # print("TRYING DOWNLOAD 2.", flush=True) if part.status == fedn.ModelStatus.IN_PROGRESS: # print("WRITING PART FOR MODEL:{}".format(id), flush=True) data.write(part.data) if part.status == fedn.ModelStatus.OK: # print("DONE WRITING MODEL RETURNING {}".format(id), flush=True) return data if part.status == fedn.ModelStatus.FAILED: # print("FAILED TO DOWNLOAD MODEL::: bailing!",flush=True) return None # print("ERROR NO PARTS!",flush=True) return data def 
set_model(self, model, id): from io import BytesIO if not isinstance(model, BytesIO): bt = BytesIO() for d in model.stream(32 * 1024): bt.write(d) else: bt = model # print("SETTING MODEL OF SIZE {}".format(sys.getsizeof(bt)), flush=True) bt.seek(0, 0) def upload_request_generator(mdl): i = 1 while True: b = mdl.read(CHUNK_SIZE) if b: result = fedn.ModelRequest( data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) else: result = fedn.ModelRequest(id=id, status=fedn.ModelStatus.OK) yield result if not b: break result = self.models.Upload(upload_request_generator(bt)) return result def __listen_to_model_update_request_stream(self): """ Subscribe to the model update request stream. """ r = fedn.ClientAvailableMessage() r.sender.name = self.name r.sender.role = fedn.WORKER metadata = [('client', r.sender.name)] while True: try: for request in self.orchestrator.ModelUpdateRequestStream( r, metadata=metadata): if request.sender.role == fedn.COMBINER: # Process training request global_model_id = request.model_id # TODO: Error handling self.send_status( "Received model update request.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_UPDATE_REQUEST, request=request) model_id = self.__process_training_request( global_model_id) if model_id != None: # Notify the requesting client that a model update is available update = fedn.ModelUpdate() update.sender.name = self.name update.sender.role = fedn.WORKER update.receiver.name = request.sender.name update.receiver.role = request.sender.role update.model_id = request.model_id update.model_update_id = str(model_id) update.timestamp = str(datetime.now()) update.correlation_id = request.correlation_id response = self.orchestrator.SendModelUpdate( update) self.send_status("Model update completed.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_UPDATE, request=update) else: self.send_status( "Client {} failed to complete model update.", log_level=fedn.Status.WARNING, request=request) except grpc.RpcError as e: status_code = e.code() timeout = 5 print( "CLIENT __listen_to_model_update_request_stream: GRPC ERROR {} retrying in {}.." .format(status_code.name, timeout), flush=True) import time time.sleep(timeout) def __listen_to_model_validation_request_stream(self): """ Subscribe to the model update request stream. """ r = fedn.ClientAvailableMessage() r.sender.name = self.name r.sender.role = fedn.WORKER while True: try: for request in self.orchestrator.ModelValidationRequestStream( r): # Process training request model_id = request.model_id # TODO: Error handling self.send_status( "Recieved model validation request.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_VALIDATION_REQUEST, request=request) metrics = self.__process_validation_request(model_id) if metrics != None: # Send validation validation = fedn.ModelValidation() validation.sender.name = self.name validation.sender.role = fedn.WORKER validation.receiver.name = request.sender.name validation.receiver.role = request.sender.role validation.model_id = str(model_id) validation.data = json.dumps(metrics) self.str = str(datetime.now()) validation.timestamp = self.str validation.correlation_id = request.correlation_id response = self.orchestrator.SendModelValidation( validation) self.send_status("Model validation completed.", log_level=fedn.Status.AUDIT, type=fedn.StatusType.MODEL_VALIDATION, request=validation) else: self.send_status( "Client {} failed to complete model validation.". 
            except grpc.RpcError as e:
                status_code = e.code()
                timeout = 5
                print("CLIENT __listen_to_model_validation_request_stream: GRPC ERROR {} retrying in {}..".format(
                    status_code.name, timeout), flush=True)
                import time
                time.sleep(timeout)

    def __process_training_request(self, model_id):
        """ Fetch the global model, run the 'train' entry point and upload the updated model. """
        self.send_status("\t Processing training request for model_id {}".format(model_id))
        self.state = ClientState.training
        try:
            mdl = self.get_model(str(model_id))

            fid, infile_name = tempfile.mkstemp(suffix='.h5')
            fod, outfile_name = tempfile.mkstemp(suffix='.h5')
            with open(infile_name, "wb") as fh:
                fh.write(mdl.getbuffer())

            self.dispatcher.run_cmd("train {} {}".format(infile_name, outfile_name))

            import io
            out_model = None
            with open(outfile_name, "rb") as fr:
                out_model = io.BytesIO(fr.read())

            # Upload the updated model under a fresh model id.
            import uuid
            model_id = uuid.uuid4()
            self.set_model(out_model, str(model_id))

            os.unlink(infile_name)
            os.unlink(outfile_name)

        except Exception as e:
            print("ERROR could not process training request due to error: {}".format(e), flush=True)
            model_id = None

        self.state = ClientState.idle
        return model_id

    def __process_validation_request(self, model_id):
        """ Fetch the model, run the 'validate' entry point and return the reported metrics. """
        self.send_status("Processing validation request for model_id {}".format(model_id))
        self.state = ClientState.validating
        try:
            model = self.get_model(str(model_id))

            fid, infile_name = tempfile.mkstemp(suffix='.h5')
            fod, outfile_name = tempfile.mkstemp(suffix='.h5')
            with open(infile_name, "wb") as fh:
                fh.write(model.getbuffer())

            self.dispatcher.run_cmd("validate {} {}".format(infile_name, outfile_name))

            with open(outfile_name, "r") as fh:
                validation = json.loads(fh.read())

            os.unlink(infile_name)
            os.unlink(outfile_name)

        except Exception as e:
            print("Validation failed with exception {}".format(e), flush=True)
            self.state = ClientState.idle
            return None

        self.state = ClientState.idle
        return validation

    def send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None):
        """ Report a status message to the combiner and append it to the local log. """
        print("SEND_STATUS REPORTS:{}".format(msg), flush=True)
        from google.protobuf.json_format import MessageToJson
        status = fedn.Status()
        status.sender.name = self.name
        status.sender.role = fedn.WORKER
        status.log_level = log_level
        status.status = str(msg)
        if type is not None:
            status.type = type
        if request is not None:
            status.data = MessageToJson(request)

        self.logs.append("{} {} LOG LEVEL {} MESSAGE {}".format(
            str(datetime.now()), status.sender.name, status.log_level, status.status))
        response = self.connection.SendStatus(status)

    def _send_heartbeat(self, update_frequency=2.0):
        """ Send a heartbeat to the combiner every update_frequency seconds. """
        import time
        while True:
            heartbeat = fedn.Heartbeat(sender=fedn.Client(name=self.name, role=fedn.WORKER))
            try:
                self.connection.SendHeartbeat(heartbeat)
            except grpc.RpcError as e:
                status_code = e.code()
                print("CLIENT heartbeat: GRPC ERROR {} retrying..".format(status_code.name), flush=True)
            time.sleep(update_frequency)
    def run_web(self):
        """ Serve a minimal status page on port 8080. """
        from flask import Flask
        app = Flask(__name__)

        from fedn.common.net.web.client import page, style

        @app.route('/')
        def index():
            logs_fancy = str()
            for log in self.logs:
                logs_fancy += "<p>" + log + "</p>\n"
            return page.format(client=self.name,
                               state=ClientStateToString(self.state),
                               style=style,
                               logs=logs_fancy)

        # Suppress Flask's console output while the web server is running.
        import os, sys
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, 'w')
        app.run(host="0.0.0.0", port=8080)
        sys.stdout.close()
        sys.stdout = self._original_stdout

    def run(self):
        import time
        threading.Thread(target=self.run_web, daemon=True).start()
        try:
            cnt = 0
            old_state = self.state
            while True:
                time.sleep(1)
                cnt += 1
                if self.state != old_state:
                    print("CLIENT {}".format(ClientStateToString(self.state)), flush=True)
                    old_state = self.state
                if cnt > 5:
                    print("CLIENT active", flush=True)
                    cnt = 0
        except KeyboardInterrupt:
            print("ok exiting..")
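# --- Example usage sketch (illustrative only, not part of the original module). ---
# A minimal sketch of how a Client could be constructed: the keys below mirror the
# ones read from `config` in __init__, but every concrete value (hosts, ports,
# token, client id) is a placeholder. The FEDN_MINIO_ACCESS_KEY, FEDN_MINIO_SECRET_KEY,
# FEDN_MINIO_HOST and FEDN_MINIO_PORT environment variables must also be set,
# since __init__ reads them directly.
if __name__ == "__main__":
    example_config = {
        'discover_host': 'reducer',       # placeholder discovery/reducer hostname
        'discover_port': 8090,            # placeholder discovery port
        'token': 'replace-with-token',    # placeholder API token
        'name': 'client0',                # placeholder client name
        'client_id': 'client0-id',        # placeholder client id
        'secure': False,                  # no TLS towards the discovery service
        'preshared_cert': False,          # no preshared certificate
        'verify_cert': False              # skip certificate verification
    }
    client = Client(example_config)
    client.run()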