Пример #1
0
    def execute_training(self, config):
        """Coordinate clients to execute the configured number of training rounds.

        Stages the model named by ``config['model_id']``, runs
        ``config['rounds']`` training rounds and, if the last round produced
        a model, serializes it, stores it through the model service under a
        freshly generated UUID and makes it the server's active model.

        :param config: Round configuration. This method reads the keys
            ``round_id``, ``model_id``, ``rounds`` and ``helper_type``; the
            whole dict is also forwarded to every training round.
        :return: Metadata dict with the keys ``config``, ``round_id`` and
            ``local_round`` (per-round metadata keyed by the round number as
            a string).
        """
        round_meta = {
            'config': config,
            'round_id': config['round_id'],
        }

        self.stage_model(config['model_id'])

        # Bound before the loop so that a configuration with zero rounds
        # falls through cleanly instead of raising UnboundLocalError below.
        model = None

        # Execute the configured number of rounds
        round_meta['local_round'] = {}
        for r in range(1, int(config['rounds']) + 1):
            self.server.report_status("ROUNDCONTROL: Starting training round {}".format(r), flush=True)
            clients = self.__assign_round_clients(self.server.max_clients)
            model, meta = self._training_round(config, clients)
            round_meta['local_round'][str(r)] = meta
            if model is None:
                self.server.report_status("\t Failed to update global model in round {0}!".format(r))

        # Only the model of the last executed round is published.
        if model is not None:
            helper = get_helper(config['helper_type'])
            a = helper.serialize_model_to_BytesIO(model)
            try:
                # Send aggregated model to server
                model_id = str(uuid.uuid4())
                self.modelservice.set_model(a, model_id)
            finally:
                # Release the buffer even if storing the model fails.
                a.close()

            # Update Combiner latest model
            self.server.set_active_model(model_id)

            print("------------------------------------------")
            self.server.report_status("ROUNDCONTROL: TRAINING ROUND COMPLETED.", flush=True)
            print("\n")
        return round_meta
Пример #2
0
    def run(self):
        """Poll the compute-plan list once per second and execute queued plans.

        Each iteration takes at most one plan from ``self.run_configs`` while
        holding ``self.run_configs_lock``. The plan is stored in
        ``self.config``, its helper is loaded into ``self.helper`` and, if
        the client allocation check passes, it is dispatched by its ``task``
        value: ``'training'`` or ``'validation'``. Any other task, or a
        failed allocation check, is only reported via ``report_status``.

        The loop ends silently on ``KeyboardInterrupt`` or ``SystemExit``.
        """
        import time
        try:
            while True:
                time.sleep(1)

                # Hold the lock only while touching the shared list. The
                # context manager guarantees that this thread releases
                # exactly the acquisition it made: checking ``locked()`` and
                # then releasing could free a lock that another thread had
                # taken in the meantime.
                with self.run_configs_lock:
                    if len(self.run_configs) > 0:
                        compute_plan = self.run_configs.pop()
                    else:
                        compute_plan = None

                if compute_plan is None:
                    continue

                self.config = compute_plan
                self.helper = get_helper(self.config['helper_type'])

                ready = self.__check_nr_round_clients(compute_plan)
                if not ready:
                    self.report_status(
                        "COMBINER: Failed to meet client allocation requirements for this compute plan.",
                        flush=True)
                    continue

                if compute_plan['task'] == 'training':
                    tic = time.time()
                    round_meta = self.exec_training(compute_plan)
                    # Wall-clock duration of the whole training task.
                    round_meta['time_exec_training'] = time.time() - tic
                    round_meta['name'] = self.id
                    self.server.tracer.set_round_meta(round_meta)
                elif compute_plan['task'] == 'validation':
                    self.exec_validation(compute_plan,
                                         compute_plan['model_id'])
                else:
                    self.report_status(
                        "COMBINER: Compute plan contains unkown task type.",
                        flush=True)

        except (KeyboardInterrupt, SystemExit):
            pass
Пример #3
0
    def _training_round(self, config, clients):
        """Send model update requests to clients and aggregate results. 

        :param config: [description]
        :type config: [type]
        :param clients: [description]
        :type clients: [type]
        :return: [description]
        :rtype: [type]
        """

        # We flush the queue at a beginning of a round (no stragglers allowed)
        # TODO: Support other ways to handle stragglers. 
        with self.aggregator.model_updates.mutex:
            self.aggregator.model_updates.queue.clear()

        self.server.report_status("ROUNDCONTROL: Initiating training round, participating members: {}".format(clients))
        self.server.request_model_update(config['model_id'], clients=clients)

        meta = {}
        meta['nr_expected_updates'] = len(clients)
        meta['nr_required_updates'] = int(config['clients_required'])
        meta['timeout'] = float(config['round_timeout'])
        tic = time.time()
        model = None
        data = None
        try:
            helper = get_helper(config['helper_type'])
            model, data = self.aggregator.combine_models(nr_expected_models=len(clients),
                                              nr_required_models=int(config['clients_required']),
                                              helper=helper, timeout=float(config['round_timeout']))
        except Exception as e:
            print("TRAINING ROUND FAILED AT COMBINER! {}".format(e), flush=True)
        meta['time_combination'] = time.time() - tic
        meta['aggregation_time'] = data
        return model, meta
Пример #4
0
 def _initialize_helper(self,client_config):
     
     if 'model_type' in client_config.keys():
         self.helper = get_helper(client_config['model_type'])
Пример #5
0
    def __init__(self, config):
        """Connect the client to its assigned combiner and start its worker threads.

        Steps, in order:

        1. Create a timestamped run directory under the current working
           directory and a logger writing into it.
        2. Ask the discovery service for a combiner assignment, retrying
           every 5 seconds until one is granted.
        3. Open a gRPC channel to the assigned combiner (secure when the
           assignment carries a certificate) and create the service stubs.
        4. Set up the dispatcher: either from a downloaded compute package
           (``config['remote_compute_context']`` truthy) or from the local
           ``client`` directory with default entry points.
        5. Load the model helper and start the heartbeat and request-stream
           listener threads.

        If a ``checksum`` is configured and the downloaded package fails
        validation, ``self.error_state`` is set to ``True`` and the
        constructor returns early: no dispatcher, helper or threads are set
        up and ``self.state`` stays ``None``.

        :param config: Client settings. Reads the keys ``discover_host``,
            ``discover_port``, ``token``, ``name``, ``preferred_combiner``,
            ``client_id``, ``secure``, ``preshared_cert``, ``verify_cert``,
            ``logfile`` and ``remote_compute_context``, plus the optional
            ``checksum``.
        """
        import time

        self.state = None
        self.error_state = False
        # Defined up front so the helper check further down cannot raise
        # AttributeError when the assignment carries no 'model_type'.
        self.helper = None

        from fedn.common.net.connect import ConnectorClient, Status
        self.connector = ConnectorClient(config['discover_host'],
                                         config['discover_port'],
                                         config['token'],
                                         config['name'],
                                         config['preferred_combiner'],
                                         config['client_id'],
                                         secure=config['secure'],
                                         preshared_cert=config['preshared_cert'],
                                         verify_cert=config['verify_cert'])
        self.name = config['name']

        # Every client start gets its own working directory, named after the
        # start time.
        dirname = time.strftime("%Y%m%d-%H%M%S")
        self.run_path = os.path.join(os.getcwd(), dirname)
        os.mkdir(self.run_path)

        from fedn.utils.logger import Logger
        self.logger = Logger(to_file=config['logfile'],file_path=self.run_path)
        self.started_at = datetime.now()
        self.logs = []
        client_config = {}
        print("Asking for assignment",flush=True)
        while True:
            status, response = self.connector.assign()
            if status == Status.TryAgain:
                time.sleep(5)
                continue
            if status == Status.Assigned:
                client_config = response
                break
            # Any other status: wait and ask again, showing progress.
            time.sleep(5)
            print(".", end=' ', flush=True)

        print("Got assigned!", flush=True)

        # TODO use the client_config['certificate'] for setting up secure comms'
        address = "{}:{}".format(client_config['host'], str(client_config['port']))
        if client_config['certificate']:
            import base64
            cert = base64.b64decode(client_config['certificate'])  # .decode('utf-8')
            credentials = grpc.ssl_channel_credentials(root_certificates=cert)
            channel = grpc.secure_channel(address, credentials)
        else:
            channel = grpc.insecure_channel(address)

        self.connection = rpc.ConnectorStub(channel)
        self.orchestrator = rpc.CombinerStub(channel)
        self.models = rpc.ModelServiceStub(channel)

        print("Client: {} connected {} to {}:{}".format(self.name,
                                                        "SECURED" if client_config['certificate'] else "INSECURE",
                                                        client_config['host'], client_config['port']), flush=True)
        if config['remote_compute_context']:
            from fedn.common.control.package import PackageRuntime
            pr = PackageRuntime(os.getcwd(), os.getcwd())

            retval = None
            tries = 10

            while tries > 0:
                retval =  pr.download(config['discover_host'], config['discover_port'], config['token'])
                if retval:
                    break
                time.sleep(60)
                print("No compute package available... retrying in 60s Trying {} more times.".format(tries),flush=True)
                tries -= 1

            if retval:
                if 'checksum' not in config:
                    print("\nWARNING: Skipping security validation of local package!, make sure you trust the package source.\n",flush=True)
                elif not pr.validate(config['checksum']):
                    print("Validation was enforced and invalid, client closing!")
                    self.error_state = True
                    return

                # Only unpack a package that was downloaded and, when a
                # checksum is configured, validated.
                pr.unpack()

            self.dispatcher = pr.dispatcher(self.run_path)
            try:
                print("Running Dispatcher for entrypoint: startup", flush=True)
                self.dispatcher.run_cmd("startup")
            except KeyError:
                # A missing 'startup' entry point is not an error.
                pass
        else:
            # TODO: Deprecate
            dispatch_config = {'entry_points':
                                   {'predict': {'command': 'python3 predict.py'},
                                    'train': {'command': 'python3 train.py'},
                                    'validate': {'command': 'python3 validate.py'}}}
            from_path = os.path.join(os.getcwd(),'client')

            from distutils.dir_util import copy_tree
            # Copy the local client code into this run's directory; the
            # target must be the instance attribute, there is no local
            # variable named run_path.
            copy_tree(from_path, self.run_path)
            self.dispatcher = Dispatcher(dispatch_config, self.run_path)

        self.lock = threading.Lock()

        if 'model_type' in client_config.keys():
            self.helper = get_helper(client_config['model_type'])

        if not self.helper:
            print("Failed to retrive helper class settings! {}".format(client_config),flush=True)

        threading.Thread(target=self._send_heartbeat, daemon=True).start()
        threading.Thread(target=self.__listen_to_model_update_request_stream, daemon=True).start()
        threading.Thread(target=self.__listen_to_model_validation_request_stream, daemon=True).start()

        self.state = ClientState.idle
Пример #6
0
    def __init__(self, config):
        """Connect the client to its assigned combiner and start its worker threads.

        Steps, in order:

        1. Ask the discovery service for a combiner assignment, retrying
           every 5 seconds until one is granted.
        2. Open a gRPC channel to the assigned combiner (secure when the
           assignment carries a certificate) and create the service stubs.
        3. Set up the dispatcher: either from a downloaded compute package
           (``config['remote_compute_context']`` truthy) or with default
           entry points run from the current working directory.
        4. Load the model helper and start the heartbeat and request-stream
           listener threads.

        :param config: Client settings. Reads the keys ``discover_host``,
            ``discover_port``, ``token``, ``name``, ``preferred_combiner``,
            ``client_id``, ``secure``, ``preshared_cert``, ``verify_cert``
            and ``remote_compute_context``.
        """
        import time

        # Defined up front so the helper check further down cannot raise
        # AttributeError when the assignment carries no 'model_type'.
        self.helper = None

        from fedn.common.net.connect import ConnectorClient, Status
        self.connector = ConnectorClient(config['discover_host'],
                                         config['discover_port'],
                                         config['token'],
                                         config['name'],
                                         config['preferred_combiner'],
                                         config['client_id'],
                                         secure=config['secure'],
                                         # Pass the configured value, not a
                                         # literal list holding the key name.
                                         preshared_cert=config['preshared_cert'],
                                         verify_cert=config['verify_cert'])
        self.name = config['name']

        self.started_at = datetime.now()
        self.logs = []
        client_config = {}
        print("Asking for assignment", flush=True)
        while True:
            status, response = self.connector.assign()
            if status == Status.TryAgain:
                time.sleep(5)
                continue
            if status == Status.Assigned:
                client_config = response
                break
            # Any other status: wait and ask again, showing progress.
            time.sleep(5)
            print(".", end=' ', flush=True)

        print("Got assigned!", flush=True)

        # TODO use the client_config['certificate'] for setting up secure comms'
        address = "{}:{}".format(client_config['host'],
                                 str(client_config['port']))
        if client_config['certificate']:
            import base64
            cert = base64.b64decode(
                client_config['certificate'])  # .decode('utf-8')
            credentials = grpc.ssl_channel_credentials(root_certificates=cert)
            channel = grpc.secure_channel(address, credentials)
        else:
            channel = grpc.insecure_channel(address)

        self.connection = rpc.ConnectorStub(channel)
        self.orchestrator = rpc.CombinerStub(channel)
        self.models = rpc.ModelServiceStub(channel)

        print("Client: {} connected {} to {}:{}".format(
            self.name,
            "SECURED" if client_config['certificate'] else "INSECURE",
            client_config['host'], client_config['port']),
              flush=True)
        if config['remote_compute_context']:
            from fedn.common.control.package import PackageRuntime
            pr = PackageRuntime(os.getcwd(), os.getcwd())

            retval = None
            tries = 10

            while tries > 0:
                retval = pr.download(config['discover_host'],
                                     config['discover_port'], config['token'])
                if retval:
                    break
                time.sleep(60)
                print(
                    "No compute package availabe... retrying in 60s Trying {} more times."
                    .format(tries),
                    flush=True)
                tries -= 1

            if retval:
                pr.unpack()

            self.dispatcher = pr.dispatcher()
            try:
                self.dispatcher.run_cmd("startup")
            except KeyError:
                # A missing 'startup' entry point is not an error.
                print("No startup code present. skipping")
        else:
            # TODO: Deprecate
            dispatch_config = {
                'entry_points': {
                    'predict': {
                        'command': 'python3 predict.py'
                    },
                    'train': {
                        'command': 'python3 train.py'
                    },
                    'validate': {
                        'command': 'python3 validate.py'
                    }
                }
            }
            dispatch_dir = os.getcwd()
            self.dispatcher = Dispatcher(dispatch_config, dispatch_dir)

        self.lock = threading.Lock()

        if 'model_type' in client_config.keys():
            self.helper = get_helper(client_config['model_type'])

        if not self.helper:
            print("Failed to retrive helper class settings! {}".format(
                client_config),
                  flush=True)

        threading.Thread(target=self._send_heartbeat, daemon=True).start()
        threading.Thread(target=self.__listen_to_model_update_request_stream,
                         daemon=True).start()
        threading.Thread(
            target=self.__listen_to_model_validation_request_stream,
            daemon=True).start()

        self.state = ClientState.idle