Example #1
    def __init__(self, connect_config):
        self.clients = {}

        self.modelservice = ModelService()

        self.id = connect_config['myname']
        self.role = Role.COMBINER
        self.max_clients = connect_config['max_clients']

        self.model_id = None

        from fedn.common.net.connect import ConnectorCombiner, Status
        announce_client = ConnectorCombiner(
            host=connect_config['discover_host'],
            port=connect_config['discover_port'],
            myhost=connect_config['myhost'],
            myport=connect_config['myport'],
            token=connect_config['token'],
            name=connect_config['myname'])

        response = None
        while True:
            status, response = announce_client.announce()
            if status == Status.TryAgain:
                time.sleep(5)
                continue
            if status == Status.Assigned:
                config = response
                print(
                    "COMBINER: was announced successfully. Waiting for clients and commands!",
                    flush=True)
                break

        cert = base64.b64decode(config['certificate'])  # .decode('utf-8')
        key = base64.b64decode(config['key'])  # .decode('utf-8')

        grpc_config = {
            'port': connect_config['myport'],
            'secure': connect_config['secure'],
            'certificate': cert,
            'key': key
        }

        self.repository = S3ModelRepository(
            config['storage']['storage_config'])
        self.server = Server(self, self.modelservice, grpc_config)

        from fedn.algo.fedavg import FEDAVGCombiner
        self.combiner = FEDAVGCombiner(self.id, self.repository, self,
                                       self.modelservice)

        threading.Thread(target=self.combiner.run, daemon=True).start()

        from fedn.common.tracer.mongotracer import MongoTracer
        self.tracer = MongoTracer(config['statestore']['mongo_config'],
                                  config['statestore']['network_id'])

        self.server.start()
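
For reference, here is a minimal sketch of the connect_config dictionary this constructor reads. The key names are taken from the code above; all values (names, hosts, ports, token) are placeholders, not part of the original source.

# Hypothetical connect_config for the constructor above; values are placeholders.
connect_config = {
    'myname': 'combiner-0',            # name announced to the discovery service
    'myhost': 'combiner.example.org',  # host advertised to clients
    'myport': 12080,                   # gRPC port served by Server
    'discover_host': 'reducer.example.org',
    'discover_port': 8090,
    'token': 'REPLACE_WITH_NETWORK_TOKEN',
    'secure': True,                    # serve gRPC with the announced certificate/key
    'max_clients': 30,                 # later used to decide whether to accept new clients
}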
Example #2
    def instruct(self, config):
        """ Main entrypoint, executes the compute plan. """

        if self.__state == ReducerState.instructing:
            print("Already set in INSTRUCTING state", flush=True)
            return

        self.__state = ReducerState.instructing

        if not self.get_latest_model():
            print("No model in model chain, please seed the alliance!")

        self.__state = ReducerState.monitoring

        # TODO: Validate and set the round config object
        #self.set_config(config)

        # TODO: Refactor
        from fedn.common.tracer.mongotracer import MongoTracer
        statestore_config = self.statestore.get_config()
        self.tracer = MongoTracer(statestore_config['mongo_config'],
                                  statestore_config['network_id'])
        last_round = self.tracer.get_latest_round()

        for round in range(1, int(config['rounds']) + 1):
            tic = time.time()
            if last_round:
                current_round = last_round + round
            else:
                current_round = round

            from datetime import datetime
            start_time = datetime.now()
            # start round monitor
            self.tracer.start_monitor(round)

            model_id, round_meta = self.round(config, current_round)
            end_time = datetime.now()

            if model_id:
                print("REDUCER: Global round completed, new model: {}".format(
                    model_id),
                      flush=True)
                round_time = end_time - start_time
                self.tracer.set_latest_time(current_round, round_time.seconds)
                round_meta['status'] = 'Success'
            else:
                print("REDUCER: Global round failed!")
                round_meta['status'] = 'Failed'

            # stop round monitor
            self.tracer.stop_monitor()
            round_meta['time_round'] = time.time() - tic
            self.tracer.set_round_meta_reducer(round_meta)

        self.__state = ReducerState.idle
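
As a rough guide, the round configuration consumed by instruct (and by round() in the later examples) looks like the sketch below; the key names appear in the code, the values are illustrative only.

# Illustrative round config for instruct(config); values are placeholders.
config = {
    'rounds': 5,            # number of global rounds to execute
    'round_timeout': 180,   # seconds to wait for combiners to report updated models
    'clients_required': 2,  # per-combiner participation threshold used by the reducer policy
    'validate': True,       # trigger a validation round after each committed global model
}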
Example #3
        def delete_model_trail():
            if request.method == 'POST':
                from fedn.common.tracer.mongotracer import MongoTracer
                statestore_config = self.control.statestore.get_config()
                self.tracer = MongoTracer(statestore_config['mongo_config'],
                                          statestore_config['network_id'])
                try:
                    self.control.drop_models()
                except Exception:
                    # Dropping the model trail may fail if it is already empty;
                    # continue and clear the bucket objects regardless.
                    pass

                # drop objects in minio
                self.control.delete_bucket_objects()
                return redirect(url_for('models'))
            seed = True
            return redirect(url_for('models', seed=seed))
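
delete_model_trail is an inner function, presumably registered as a Flask route; the standalone sketch below shows one way such a registration could look. The route path, the app object, and the StubControl class are assumptions made for illustration and do not appear in the original snippet.

from flask import Flask, request, redirect, url_for

app = Flask(__name__)

class StubControl:
    """ Placeholder standing in for self.control in the handler above. """
    def drop_models(self):
        print("dropping model trail from the statestore")
    def delete_bucket_objects(self):
        print("deleting model objects from the bucket")

control = StubControl()

@app.route('/models')
def models():
    return "model listing"

@app.route('/delete_model_trail', methods=['GET', 'POST'])
def delete_model_trail():
    if request.method == 'POST':
        try:
            control.drop_models()
        except Exception:
            pass
        control.delete_bucket_objects()
        return redirect(url_for('models'))
    return redirect(url_for('models', seed=True))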
Example #4
class Combiner(rpc.CombinerServicer, rpc.ReducerServicer,
               rpc.ConnectorServicer, rpc.ControlServicer):
    """ Communication relayer. """
    def __init__(self, connect_config):
        self.clients = {}

        self.modelservice = ModelService()

        self.id = connect_config['myname']
        self.role = Role.COMBINER
        self.max_clients = connect_config['max_clients']

        self.model_id = None

        from fedn.common.net.connect import ConnectorCombiner, Status
        announce_client = ConnectorCombiner(
            host=connect_config['discover_host'],
            port=connect_config['discover_port'],
            myhost=connect_config['myhost'],
            myport=connect_config['myport'],
            token=connect_config['token'],
            name=connect_config['myname'])

        response = None
        while True:
            status, response = announce_client.announce()
            if status == Status.TryAgain:
                time.sleep(5)
                continue
            if status == Status.Assigned:
                config = response
                print(
                    "COMBINER: was announced successfully. Waiting for clients and commands!",
                    flush=True)
                break

        cert = base64.b64decode(config['certificate'])  # .decode('utf-8')
        key = base64.b64decode(config['key'])  # .decode('utf-8')

        grpc_config = {
            'port': connect_config['myport'],
            'secure': connect_config['secure'],
            'certificate': cert,
            'key': key
        }

        self.repository = S3ModelRepository(
            config['storage']['storage_config'])
        self.server = Server(self, self.modelservice, grpc_config)

        from fedn.algo.fedavg import FEDAVGCombiner
        self.combiner = FEDAVGCombiner(self.id, self.repository, self,
                                       self.modelservice)

        threading.Thread(target=self.combiner.run, daemon=True).start()

        from fedn.common.tracer.mongotracer import MongoTracer
        self.tracer = MongoTracer(config['statestore']['mongo_config'],
                                  config['statestore']['network_id'])

        self.server.start()

    def __whoami(self, client, instance):
        def role_to_proto_role(role):
            if role == Role.COMBINER:
                return fedn.COMBINER
            if role == Role.WORKER:
                return fedn.WORKER
            if role == Role.REDUCER:
                return fedn.REDUCER
            if role == Role.OTHER:
                return fedn.OTHER

        client.name = instance.id
        client.role = role_to_proto_role(instance.role)
        return client

    def get_active_model(self):
        return self.model_id

    def set_active_model(self, model_id):
        self.model_id = model_id

    def request_model_update(self, model_id, clients=[]):
        """ Ask clients to update the current global model. If an empty list
            is passed, broadcasts to all active clients. s
        """

        print("COMBINER: Sending to clients {}".format(clients), flush=True)
        request = fedn.ModelUpdateRequest()
        self.__whoami(request.sender, self)
        request.model_id = model_id
        request.correlation_id = str(uuid.uuid4())
        request.timestamp = str(datetime.now())

        if len(clients) == 0:
            clients = self.get_active_trainers()

        for client in clients:
            request.receiver.name = client.name
            request.receiver.role = fedn.WORKER
            self.SendModelUpdateRequest(request, self)

    def request_model_validation(self, model_id, clients=[]):
        """ Ask clients to validate the current global model. If an empty list
            is passed, broadcasts to all active clients. s
        """
        request = fedn.ModelValidationRequest()
        self.__whoami(request.sender, self)
        request.model_id = model_id
        request.correlation_id = str(uuid.uuid4())
        request.timestamp = str(datetime.now())

        if len(clients) == 0:
            clients = self.get_active_validators()

        for client in clients:
            request.receiver.name = client.name
            request.receiver.role = fedn.WORKER
            self.SendModelValidationRequest(request, self)

        print(
            "COMBINER: Sent validation request for model {}".format(model_id),
            flush=True)

    def _list_clients(self, channel):
        request = fedn.ListClientsRequest()
        self.__whoami(request.sender, self)
        request.channel = channel
        clients = self.ListActiveClients(request, self)
        return clients.client

    def get_active_trainers(self):
        trainers = self._list_clients(fedn.Channel.MODEL_UPDATE_REQUESTS)
        return trainers

    def get_active_validators(self):
        validators = self._list_clients(fedn.Channel.MODEL_VALIDATION_REQUESTS)
        return validators

    def nr_active_trainers(self):
        return len(self.get_active_trainers())

    def nr_active_validators(self):
        return len(self.get_active_validators())

    ####################################################################################################################

    #def _log_queue_length(self):
    #    ql = self.combiner.model_updates.qsize()
    #    if ql > 0:
    #        self.tracer.set_combiner_queue_length(str(datetime.now()),ql)

    def __join_client(self, client):
        """ Add a client to the combiner. """
        if client.name not in self.clients:
            self.clients[client.name] = {"lastseen": datetime.now()}

    def _subscribe_client_to_queue(self, client, queue_name):
        self.__join_client(client)
        if queue_name not in self.clients[client.name]:
            self.clients[client.name][queue_name] = queue.Queue()

    def __get_queue(self, client, queue_name):
        try:
            return self.clients[client.name][queue_name]
        except KeyError:
            raise

    def __get_status_queue(self, client):
        return self.__get_queue(client, fedn.Channel.STATUS)

    def _send_request(self, request, queue_name):
        self.__route_request_to_client(request, request.receiver, queue_name)

    def _broadcast_request(self, request, queue_name):
        """ Publish a request to all subscribed members. """
        active_clients = self._list_active_clients()
        for client in active_clients:
            self.clients[client.name][queue_name].put(request)

    def __route_request_to_client(self, request, client, queue_name):
        try:
            q = self.__get_queue(client, queue_name)
            q.put(request)
        except Exception:
            print("Failed to route request to client: {} {}".format(
                request.receiver, queue_name), flush=True)
            raise

    def _send_status(self, status):

        self.tracer.report(status)
        for name, client in self.clients.items():
            try:
                q = client[fedn.Channel.STATUS]
                status.timestamp = str(datetime.now())
                q.put(status)
            except KeyError:
                pass

    def __register_heartbeat(self, client):
        """ Register a client if first time connecting. Update heartbeat timestamp. """
        self.__join_client(client)
        self.clients[client.name]["lastseen"] = datetime.now()

    #####################################################################################################################

    ## Control Service

    def Start(self, control: fedn.ControlRequest, context):
        response = fedn.ControlResponse()
        print("\n\n\n GOT CONTROL **START** from Command {}\n\n\n".format(
            control.command),
              flush=True)

        config = {}
        for parameter in control.parameter:
            config.update({parameter.key: parameter.value})
        print(
            "\n\n\n\nSTARTING JOB AT COMBINER WITH {}\n\n\n\n".format(config),
            flush=True)

        job_id = self.combiner.push_run_config(config)
        return response

    def Configure(self, control: fedn.ControlRequest, context):
        response = fedn.ControlResponse()
        for parameter in control.parameter:
            setattr(self, parameter.key, parameter.value)
        return response

    def Stop(self, control: fedn.ControlRequest, context):
        response = fedn.ControlResponse()
        print("\n\n\n\n\n GOT CONTROL **STOP** from Command\n\n\n\n\n",
              flush=True)
        return response

    def Report(self, control: fedn.ControlRequest, context):
        """ Descibe current state of the Combiner. """

        response = fedn.ControlResponse()
        print("\n\n\n\n\n GOT CONTROL **REPORT** from Command\n\n\n\n\n",
              flush=True)

        active_clients = self._list_active_clients(
            fedn.Channel.MODEL_UPDATE_REQUESTS)
        nr_active_clients = len(active_clients)

        p = response.parameter.add()
        p.key = "nr_active_clients"
        p.value = str(nr_active_clients)

        p = response.parameter.add()
        p.key = "model_id"
        model_id = self.get_active_model()
        if model_id is None:
            model_id = ""
        p.value = str(model_id)

        p = response.parameter.add()
        p.key = "nr_unprocessed_compute_plans"
        p.value = str(len(self.combiner.run_configs))

        p = response.parameter.add()
        p.key = "name"
        p.value = str(self.id)

        # Get IP information
        #try:
        #    url = 'http://ipinfo.io/json'
        #    data = requests.get(url)
        #    combiner_location = json.loads(data.text)
        #    for key,value in combiner_location.items():
        #        p = response.parameter.add()
        #        p.key = str(key)
        #        p.value = str(value)
        #except Exception as e:
        #    print(e,flush=True)
        #    pass

        return response

    #####################################################################################################################

    def AllianceStatusStream(self, response, context):
        """ A server stream RPC endpoint that emits status messages. """
        status = fedn.Status(
            status="Client {} connecting to AllianceStatusStream.".format(
                response.sender))
        status.log_level = fedn.Status.INFO
        status.sender.name = self.id
        status.sender.role = role_to_proto_role(self.role)
        self._subscribe_client_to_queue(response.sender, fedn.Channel.STATUS)
        q = self.__get_queue(response.sender, fedn.Channel.STATUS)
        self._send_status(status)

        while True:
            yield q.get()

    def SendStatus(self, status: fedn.Status, context):
        # Add the status message to all subscribers of the status channel
        self._send_status(status)

        response = fedn.Response()
        response.response = "Status received."
        return response

    def _list_subscribed_clients(self, queue_name):
        subscribed_clients = []
        for name, client in self.clients.items():
            if queue_name in client.keys():
                subscribed_clients.append(name)
        return subscribed_clients

    def _list_active_clients(self, channel):
        active_clients = []
        for client in self._list_subscribed_clients(channel):
            # This can break with different timezones.
            now = datetime.now()
            then = self.clients[client]["lastseen"]
            # TODO: move the heartbeat timeout to config.
            if (now - then) < timedelta(seconds=10):
                active_clients.append(client)
        return active_clients

    def ListActiveClients(self, request: fedn.ListClientsRequest, context):
        """ RPC endpoint that returns a ClientList containing the names of all active clients.
            An active client has sent a status message / responded to a heartbeat
            request in the last 10 seconds.
        """
        clients = fedn.ClientList()
        active_clients = self._list_active_clients(request.channel)

        for client in active_clients:
            clients.client.append(fedn.Client(name=client, role=fedn.WORKER))
        return clients

    def AcceptingClients(self, request: fedn.ConnectionRequest, context):
        response = fedn.ConnectionResponse()
        active_clients = self._list_active_clients(
            fedn.Channel.MODEL_UPDATE_REQUESTS)

        try:
            #requested = int(self.combiner.config['clients_requested'])
            requested = int(self.max_clients)
            if len(active_clients) >= requested:
                response.status = fedn.ConnectionStatus.NOT_ACCEPTING
                return response
            if len(active_clients) < requested:
                response.status = fedn.ConnectionStatus.ACCEPTING
                return response

        except Exception as e:
            print("Combiner not properly configured! {}".format(e), flush=True)
            raise

        response.status = fedn.ConnectionStatus.TRY_AGAIN_LATER
        return response

    def SendHeartbeat(self, heartbeat: fedn.Heartbeat, context):
        """ RPC that lets clients send a hearbeat, notifying the server that
            the client is available. """
        self.__register_heartbeat(heartbeat.sender)
        response = fedn.Response()
        response.sender.name = heartbeat.sender.name
        response.sender.role = heartbeat.sender.role
        response.response = "Heartbeat received"
        return response

    ## Combiner Service

    def ModelUpdateStream(self, update, context):
        client = update.sender
        status = fedn.Status(
            status="Client {} connecting to ModelUpdateStream.".format(
                client.name))
        status.log_level = fedn.Status.INFO
        status.sender.name = self.id
        status.sender.role = role_to_proto_role(self.role)

        self._subscribe_client_to_queue(client, fedn.Channel.MODEL_UPDATES)
        q = self.__get_queue(client, fedn.Channel.MODEL_UPDATES)

        self._send_status(status)

        while True:
            yield q.get()

    def ModelUpdateRequestStream(self, response, context):
        """ A server stream RPC endpoint. Messages from client stream. """

        client = response.sender
        metadata = context.invocation_metadata()
        if metadata:
            print("\n\n\nGOT METADATA: {}\n\n\n".format(metadata), flush=True)

        status = fedn.Status(
            status="Client {} connecting to ModelUpdateRequestStream.".format(
                client.name))
        status.log_level = fedn.Status.INFO

        self.__whoami(status.sender, self)

        self._subscribe_client_to_queue(client,
                                        fedn.Channel.MODEL_UPDATE_REQUESTS)
        q = self.__get_queue(client, fedn.Channel.MODEL_UPDATE_REQUESTS)

        self._send_status(status)

        while True:
            yield q.get()

    def ModelValidationStream(self, update, context):
        client = update.sender
        status = fedn.Status(
            status="Client {} connecting to ModelValidationStream.".format(
                client.name))
        status.log_level = fedn.Status.INFO

        status.sender.name = self.id
        status.sender.role = role_to_proto_role(self.role)

        self._subscribe_client_to_queue(client, fedn.Channel.MODEL_VALIDATIONS)
        q = self.__get_queue(client, fedn.Channel.MODEL_VALIDATIONS)

        self._send_status(status)

        while True:
            yield q.get()

    def ModelValidationRequestStream(self, response, context):
        """ A server stream RPC endpoint. Messages from client stream. """

        client = response.sender
        status = fedn.Status(
            status="Client {} connecting to ModelValidationRequestStream.".
            format(client.name))
        status.log_level = fedn.Status.INFO
        status.sender.name = self.id
        status.sender.role = role_to_proto_role(self.role)

        self._subscribe_client_to_queue(client,
                                        fedn.Channel.MODEL_VALIDATION_REQUESTS)
        q = self.__get_queue(client, fedn.Channel.MODEL_VALIDATION_REQUESTS)

        self._send_status(status)

        while True:
            yield q.get()

    def SendModelUpdateRequest(self, request, context):
        """ Send a model update request. """
        self._send_request(request, fedn.Channel.MODEL_UPDATE_REQUESTS)

        response = fedn.Response()
        response.response = "CONTROLLER RECEIVED ModelUpdateRequest from client {}".format(
            request.sender.name)
        return response  # TODO Fill later

    def SendModelUpdate(self, request, context):
        """ Receive a model update from a client. """
        self.combiner.receive_model_candidate(request.model_update_id)
        print("ORCHESTRATOR: Received model update", flush=True)

        response = fedn.Response()
        response.response = "RECEIVED ModelUpdate {} from client {}".format(
            request.model_update_id, request.sender.name)
        return response  # TODO Fill later

    def SendModelValidationRequest(self, request, context):
        """ Send a model update request. """
        self._send_request(request, fedn.Channel.MODEL_VALIDATION_REQUESTS)

        response = fedn.Response()
        response.response = "CONTROLLER RECEIVED ModelValidationRequest from client {}".format(
            request.sender.name)
        return response  # TODO Fill later

    def SendModelValidation(self, request, context):
        """ Receive a model validation result from a client. """
        # self._send_request(request,fedn.Channel.MODEL_VALIDATIONS)
        self.combiner.receive_validation(request)
        print("ORCHESTRATOR received validation", flush=True)
        response = fedn.Response()
        response.response = "RECEIVED ModelValidation from client {}".format(
            request.sender.name)
        return response  # TODO Fill later

    ## Reducer Service

    def GetGlobalModel(self, request, context):

        response = fedn.GetGlobalModelResponse()
        self.__whoami(response.sender, self)
        response.receiver.name = "reducer"
        response.receiver.role = role_to_proto_role(Role.REDUCER)
        if not self.get_active_model():
            response.model_id = ''
        else:
            response.model_id = self.get_active_model()
        return response

    ####################################################################################################################

    def run(self):
        import signal
        print("COMBINER:starting {}".format(self.id), flush=True)
        try:
            while True:
                signal.pause()
        except (KeyboardInterrupt, SystemExit):
            pass
        self.server.stop()
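
The heartbeat bookkeeping above feeds _list_active_clients: a client counts as active if it is subscribed to the requested channel and its last heartbeat is less than 10 seconds old. A self-contained sketch of that rule, with an illustrative registry in place of self.clients:

import queue
from datetime import datetime, timedelta

# Toy registry mirroring self.clients: name -> {"lastseen": datetime, <channel>: queue}
clients = {
    'client-a': {'lastseen': datetime.now(),
                 'MODEL_UPDATE_REQUESTS': queue.Queue()},
    'client-b': {'lastseen': datetime.now() - timedelta(seconds=30),
                 'MODEL_UPDATE_REQUESTS': queue.Queue()},
}

def list_active_clients(channel, timeout=timedelta(seconds=10)):
    """ Return clients subscribed to channel whose last heartbeat is within timeout. """
    now = datetime.now()
    return [name for name, client in clients.items()
            if channel in client and (now - client['lastseen']) < timeout]

print(list_active_clients('MODEL_UPDATE_REQUESTS'))  # ['client-a']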
Example #5
    def round(self, config, round_number):
        """ Execute one global round. """

        round_meta = {'round_id': round_number}

        if len(self.network.get_combiners()) < 1:
            print("REDUCER: No combiners connected!")
            return None, round_meta

        # 1. Formulate compute plans for this round and determine which combiners should participate in the round.
        compute_plan = copy.deepcopy(config)
        compute_plan['rounds'] = 1
        compute_plan['round_id'] = round_number
        compute_plan['task'] = 'training'
        compute_plan['model_id'] = self.get_latest_model()
        compute_plan['helper_type'] = self.statestore.get_framework()

        round_meta['compute_plan'] = compute_plan

        combiners = []
        for combiner in self.network.get_combiners():

            try:
                combiner_state = combiner.report()
            except CombinerUnavailableError:
                self._handle_unavailable_combiner(combiner)
                combiner_state = None

            if combiner_state is not None:
                is_participating = self.check_round_participation_policy(
                    compute_plan, combiner_state)
                if is_participating:
                    combiners.append((combiner, compute_plan))

        round_start = self.check_round_start_policy(combiners)

        if not round_start:
            print("CONTROL: Round start policy not met, skipping round!",
                  flush=True)
            return None, round_meta

        print("CONTROL: round start policy met, participating combiners {}".
              format(combiners),
              flush=True)

        # 2. Sync up and ask participating combiners to coordinate model updates
        # TODO refactor

        statestore_config = self.statestore.get_config()

        self.tracer = MongoTracer(statestore_config['mongo_config'],
                                  statestore_config['network_id'])

        start_time = datetime.now()

        for combiner, compute_plan in combiners:
            try:
                self.sync_combiners([combiner], self.get_latest_model())
                response = combiner.start(compute_plan)
            except CombinerUnavailableError:
                # This is OK, handled by round accept policy
                self._handle_unavailable_combiner(combiner)
                pass
            except:
                # Unknown error
                raise

        # Wait until participating combiners have a model that is out of sync with the current global model.
        # TODO: We do not need to wait until all combiners complete before we start reducing.
        cl = []
        for combiner, plan in combiners:
            cl.append(combiner)

        wait = 0.0
        while len(self._out_of_sync(cl)) < len(combiners):
            time.sleep(1.0)
            wait += 1.0
            if wait >= config['round_timeout']:
                break

        # TODO refactor
        end_time = datetime.now()
        round_time = end_time - start_time
        self.tracer.set_combiner_time(round_number, round_time.seconds)

        round_meta['time_combiner_update'] = round_time.seconds

        # OBS! Here we are checking against all combiners, not just those that computed in this round.
        # This means we let straggling combiners participate in the update
        updated = self._out_of_sync()
        print("COMBINERS UPDATED MODELS: {}".format(updated), flush=True)

        print("Checking round validity policy...", flush=True)
        round_valid = self.check_round_validity_policy(updated)
        if not round_valid:
            # TODO: Should we reset combiner state here?
            print("REDUCER CONTROL: Round invalid!", flush=True)
            return None, round_meta
        print("Round valid.")

        print("Starting reducing models...", flush=True)
        # 3. Reduce combiner models into a global model
        try:
            model, data = self.reduce(updated)
            round_meta['reduce'] = data
        except Exception as e:
            print("CONTROL: Failed to reduce models from combiners: {}".format(
                updated),
                  flush=True)
            print(e, flush=True)
            return None, round_meta
        print("DONE", flush=True)

        # 4. Commit the global model to the ledger
        print("Committing global model...", flush=True)
        if model is not None:
            # Commit to model ledger
            tic = time.time()
            import uuid
            model_id = uuid.uuid4()
            self.commit(model_id, model)
            round_meta['time_commit'] = time.time() - tic
        else:
            print("REDUCER: failed to update model in round with config {}".
                  format(config),
                  flush=True)
            return None, round_meta
        print("DONE", flush=True)

        # 5. Trigger participating combiner nodes to execute a validation round for the current model
        validate = config['validate']
        if validate:
            combiner_config = copy.deepcopy(config)
            combiner_config['model_id'] = self.get_latest_model()
            combiner_config['task'] = 'validation'
            combiner_config['helper_type'] = self.statestore.get_framework()

            validating_combiners = self._select_round_combiners(
                combiner_config)

            for combiner, combiner_config in validating_combiners:
                try:
                    self.sync_combiners([combiner], self.get_latest_model())
                    combiner.start(combiner_config)
                except CombinerUnavailableError:
                    # OK if validation fails for a combiner
                    self._handle_unavailable_combiner(combiner)
                    pass

        # 6. Check commit policy based on validation result (optionally)
        # TODO: Implement.

        return model_id, round_meta
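
Step 1 above turns the round config into a per-round compute plan that is sent to each participating combiner. The sketch below spells out that dictionary and the participation check it feeds; key names come from the code, values and the inline check are illustrative only.

# Illustrative compute plan produced in step 1; values are placeholders.
compute_plan = {
    'rounds': 1,
    'round_id': 7,
    'task': 'training',
    'model_id': 'a1b2c3d4',    # latest global model id
    'helper_type': 'keras',    # framework helper reported by the statestore (placeholder)
    'clients_required': 2,     # carried over from the round config
    'round_timeout': 180,      # carried over from the round config
}

# The reducer-level participation policy compares clients_required with the
# trainer count each combiner reports (see check_round_participation_policy below).
combiner_state = {'nr_active_trainers': 3, 'nr_active_validators': 3}
print(int(compute_plan['clients_required']) <= int(combiner_state['nr_active_trainers']))  # True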
Example #6
class ReducerControl:
    """ Main conroller for training round. 

    """
    def __init__(self, statestore):

        self.__state = ReducerState.setup
        self.statestore = statestore
        if self.statestore.is_inited():
            self.network = Network(self, statestore)

        try:
            config = self.statestore.get_storage_backend()
        except Exception:
            print(
                "REDUCER CONTROL: Failed to retrieve storage configuration, exiting.",
                flush=True)
            raise MisconfiguredStorageBackend()
        if not config:
            print(
                "REDUCER CONTROL: No storage configuration available, exiting.",
                flush=True)
            raise MisconfiguredStorageBackend()

        if config['storage_type'] == 'S3':
            from fedn.common.storage.s3.s3repo import S3ModelRepository
            self.model_repository = S3ModelRepository(config['storage_config'])
        else:
            print("REDUCER CONTROL: Unsupported storage backend, exiting.",
                  flush=True)
            raise UnsupportedStorageBackend()

        self.client_allocation_policy = self.client_allocation_policy_least_packed

        if self.statestore.is_inited():
            self.__state = ReducerState.idle

    def get_helper(self):
        """

        :return:
        """
        helper_type = self.statestore.get_framework()
        helper = fedn.utils.helpers.get_helper(helper_type)
        if not helper:
            print(
                "CONTROL: Unsupported helper type {}, please configure compute_context.helper !"
                .format(helper_type),
                flush=True)
            return None
        return helper

    def delete_bucket_objects(self):
        """

        :return:
        """
        return self.model_repository.delete_objects()

    def get_state(self):
        """

        :return:
        """
        return self.__state

    def idle(self):
        """

        :return:
        """
        if self.__state == ReducerState.idle:
            return True
        else:
            return False

    def get_first_model(self):
        """

        :return:
        """
        return self.statestore.get_first()

    def get_latest_model(self):
        """

        :return:
        """
        return self.statestore.get_latest()

    def get_model_info(self):
        """

        :return:
        """
        return self.statestore.get_model_info()

    def get_events(self):
        """

        :return:
        """
        return self.statestore.get_events()

    def drop_models(self):
        """

        """
        self.statestore.drop_models()

    def get_compute_context(self):
        """

        :return:
        """
        definition = self.statestore.get_compute_context()
        if definition:
            try:
                context = definition['filename']
                return context
            except (IndexError, KeyError):
                print("No context filename set for compute context definition",
                      flush=True)
                return None
        else:
            return None

    def set_compute_context(self, filename, path):
        """ Persist the configuration for the compute package. """
        self.model_repository.set_compute_context(filename, path)
        self.statestore.set_compute_context(filename)

    def get_compute_package(self, compute_package=''):
        """

        :param compute_package:
        :return:
        """
        if compute_package == '':
            compute_package = self.get_compute_context()
        return self.model_repository.get_compute_package(compute_package)

    def commit(self, model_id, model=None):
        """ Commit a model to the global model trail. The model commited becomes the lastest consensus model. """

        helper = self.get_helper()
        if model is not None:
            print("Saving model to disk...", flush=True)
            outfile_name = helper.save_model(model)
            print("DONE", flush=True)
            print("Uploading model to Minio...", flush=True)
            model_id = self.model_repository.set_model(outfile_name,
                                                       is_file=True)
            print("DONE", flush=True)
            os.unlink(outfile_name)

        self.statestore.set_latest(model_id)

    def _out_of_sync(self, combiners=None):

        if not combiners:
            combiners = self.network.get_combiners()

        osync = []
        for combiner in combiners:
            try:
                model_id = combiner.get_model_id()
            except CombinerUnavailableError:
                self._handle_unavailable_combiner(combiner)
                model_id = None

            if model_id and (model_id != self.get_latest_model()):
                osync.append(combiner)
        return osync

    def check_round_participation_policy(self, compute_plan, combiner_state):
        """ Evaluate reducer level policy for combiner round-participation.
            This is a decision on ReducerControl level, additional checks
            applies on combiner level. Not all reducer control flows might
            need or want to use a participation policy.  """

        if compute_plan['task'] == 'training':
            nr_active_clients = int(combiner_state['nr_active_trainers'])
        elif compute_plan['task'] == 'validation':
            nr_active_clients = int(combiner_state['nr_active_validators'])
        else:
            print("Invalid task type!", flush=True)
            return False

        return int(compute_plan['clients_required']) <= nr_active_clients

    def check_round_start_policy(self, combiners):
        """ Check if the overall network state meets the policy to start a round. """
        return len(combiners) > 0

    def check_round_validity_policy(self, combiners):
        """
            At the end of the round, before committing a model to the model ledger,
            we check if a round validity policy has been met. This can involve
            e.g. asserting that a certain number of combiners have reported in an
            updated model, or that criteria on model performance have been met.
        """
        if combiners == []:
            return False
        else:
            return True

    def _handle_unavailable_combiner(self, combiner):
        """ This callback is triggered if a combiner is found to be unresponsive. """
        # TODO: Implement strategy to handle the case.
        print("REDUCER CONTROL: Combiner {} unavailable.".format(
            combiner.name),
              flush=True)

    def _select_round_combiners(self, compute_plan):
        combiners = []
        for combiner in self.network.get_combiners():
            try:
                combiner_state = combiner.report()
            except CombinerUnavailableError:
                self._handle_unavailable_combiner(combiner)
                combiner_state = None

            if combiner_state:
                is_participating = self.check_round_participation_policy(
                    compute_plan, combiner_state)
                if is_participating:
                    combiners.append((combiner, compute_plan))
        return combiners

    def round(self, config, round_number):
        """ Execute one global round. """

        round_meta = {'round_id': round_number}

        if len(self.network.get_combiners()) < 1:
            print("REDUCER: No combiners connected!")
            return None, round_meta

        # 1. Formulate compute plans for this round and determine which combiners should participate in the round.
        compute_plan = copy.deepcopy(config)
        compute_plan['rounds'] = 1
        compute_plan['round_id'] = round_number
        compute_plan['task'] = 'training'
        compute_plan['model_id'] = self.get_latest_model()
        compute_plan['helper_type'] = self.statestore.get_framework()

        round_meta['compute_plan'] = compute_plan

        combiners = []
        for combiner in self.network.get_combiners():

            try:
                combiner_state = combiner.report()
            except CombinerUnavailableError:
                self._handle_unavailable_combiner(combiner)
                combiner_state = None

            if combiner_state is not None:
                is_participating = self.check_round_participation_policy(
                    compute_plan, combiner_state)
                if is_participating:
                    combiners.append((combiner, compute_plan))

        round_start = self.check_round_start_policy(combiners)

        if not round_start:
            print("CONTROL: Round start policy not met, skipping round!",
                  flush=True)
            return None, round_meta

        print("CONTROL: round start policy met, participating combiners {}".
              format(combiners),
              flush=True)

        # 2. Sync up and ask participating combiners to coordinate model updates
        # TODO refactor

        statestore_config = self.statestore.get_config()

        self.tracer = MongoTracer(statestore_config['mongo_config'],
                                  statestore_config['network_id'])

        start_time = datetime.now()

        for combiner, compute_plan in combiners:
            try:
                self.sync_combiners([combiner], self.get_latest_model())
                response = combiner.start(compute_plan)
            except CombinerUnavailableError:
                # This is OK, handled by round accept policy
                self._handle_unavailable_combiner(combiner)
                pass
            except:
                # Unknown error
                raise

        # Wait until participating combiners have a model that is out of sync with the current global model.
        # TODO: We do not need to wait until all combiners complete before we start reducing.
        cl = []
        for combiner, plan in combiners:
            cl.append(combiner)

        wait = 0.0
        while len(self._out_of_sync(cl)) < len(combiners):
            time.sleep(1.0)
            wait += 1.0
            if wait >= config['round_timeout']:
                break

        # TODO refactor
        end_time = datetime.now()
        round_time = end_time - start_time
        self.tracer.set_combiner_time(round_number, round_time.seconds)

        round_meta['time_combiner_update'] = round_time.seconds

        # OBS! Here we are checking against all combiners, not just those that computed in this round.
        # This means we let straggling combiners participate in the update
        updated = self._out_of_sync()
        print("COMBINERS UPDATED MODELS: {}".format(updated), flush=True)

        print("Checking round validity policy...", flush=True)
        round_valid = self.check_round_validity_policy(updated)
        if not round_valid:
            # TODO: Should we reset combiner state here?
            print("REDUCER CONTROL: Round invalid!", flush=True)
            return None, round_meta
        print("Round valid.")

        print("Starting reducing models...", flush=True)
        # 3. Reduce combiner models into a global model
        try:
            model, data = self.reduce(updated)
            round_meta['reduce'] = data
        except Exception as e:
            print("CONTROL: Failed to reduce models from combiners: {}".format(
                updated),
                  flush=True)
            print(e, flush=True)
            return None, round_meta
        print("DONE", flush=True)

        # 4. Commit the global model to the ledger
        print("Committing global model...", flush=True)
        if model is not None:
            # Commit to model ledger
            tic = time.time()
            import uuid
            model_id = uuid.uuid4()
            self.commit(model_id, model)
            round_meta['time_commit'] = time.time() - tic
        else:
            print("REDUCER: failed to update model in round with config {}".
                  format(config),
                  flush=True)
            return None, round_meta
        print("DONE", flush=True)

        # 5. Trigger participating combiner nodes to execute a validation round for the current model
        validate = config['validate']
        if validate:
            combiner_config = copy.deepcopy(config)
            combiner_config['model_id'] = self.get_latest_model()
            combiner_config['task'] = 'validation'
            combiner_config['helper_type'] = self.statestore.get_framework()

            validating_combiners = self._select_round_combiners(
                combiner_config)

            for combiner, combiner_config in validating_combiners:
                try:
                    self.sync_combiners([combiner], self.get_latest_model())
                    combiner.start(combiner_config)
                except CombinerUnavailableError:
                    # OK if validation fails for a combiner
                    self._handle_unavailable_combiner(combiner)
                    pass

        # 6. Check commit policy based on validation result (optionally)
        # TODO: Implement.

        return model_id, round_meta

    def sync_combiners(self, combiners, model_id):
        """ Spread the current consensus model to all active combiner nodes. """
        if not model_id:
            print("GOT NO MODEL TO SET! Have you seeded the FedML model?",
                  flush=True)
            return

        for combiner in combiners:
            response = combiner.set_model_id(model_id)

    def instruct(self, config):
        """ Main entrypoint, executes the compute plan. """

        if self.__state == ReducerState.instructing:
            print("Already set in INSTRUCTING state", flush=True)
            return

        self.__state = ReducerState.instructing

        if not self.get_latest_model():
            print("No model in model chain, please seed the alliance!")

        self.__state = ReducerState.monitoring

        # TODO: Validate and set the round config object
        # self.set_config(config)

        # TODO: Refactor
        from fedn.common.tracer.mongotracer import MongoTracer
        statestore_config = self.statestore.get_config()
        self.tracer = MongoTracer(statestore_config['mongo_config'],
                                  statestore_config['network_id'])
        last_round = self.tracer.get_latest_round()

        for round in range(1, int(config['rounds']) + 1):
            tic = time.time()
            if last_round:
                current_round = last_round + round
            else:
                current_round = round

            from datetime import datetime
            start_time = datetime.now()
            # start round monitor
            self.tracer.start_monitor(round)
            # TODO: add try/except block for round meta
            model_id = None
            round_meta = {'round_id': current_round}
            try:
                model_id, round_meta = self.round(config, current_round)
            except TypeError:
                print("Could not unpack data from round...", flush=True)

            end_time = datetime.now()

            if model_id:
                print("REDUCER: Global round completed, new model: {}".format(
                    model_id),
                      flush=True)
                round_time = end_time - start_time
                self.tracer.set_latest_time(current_round, round_time.seconds)
                round_meta['status'] = 'Success'
            else:
                print("REDUCER: Global round failed!")
                round_meta['status'] = 'Failed'

            # stop round monitor
            self.tracer.stop_monitor()
            round_meta['time_round'] = time.time() - tic
            self.tracer.set_round_meta_reducer(round_meta)

        self.__state = ReducerState.idle

    def reduce(self, combiners):
        """ Combine current models at Combiner nodes into one global model. """

        meta = {}
        meta['time_fetch_model'] = 0.0
        meta['time_load_model'] = 0.0
        meta['time_aggregate_model'] = 0.0

        i = 1
        model = None
        for combiner in combiners:

            # TODO: Handle inactive RPC error in get_model and raise specific error
            data = None
            try:
                tic = time.time()
                data = combiner.get_model()
                meta['time_fetch_model'] += (time.time() - tic)
            except Exception:
                # Skip combiners we could not fetch a model from.
                pass

            helper = self.get_helper()

            if data is not None:
                try:
                    tic = time.time()
                    model_next = helper.load_model_from_BytesIO(data.getbuffer())
                    meta['time_load_model'] += (time.time() - tic)
                    tic = time.time()
                    model = helper.increment_average(model, model_next, i)
                    meta['time_aggregate_model'] += (time.time() - tic)
                except Exception:
                    # First model in the reduction: nothing to average against yet.
                    tic = time.time()
                    model = helper.load_model_from_BytesIO(data.getbuffer())
                    meta['time_aggregate_model'] += (time.time() - tic)
                i = i + 1

        return model, meta

    def monitor(self, config=None):
        """

        :param config:
        """
        # status = self.network.check_health()
        pass

    def client_allocation_policy_first_available(self):
        """
            Allocate client to the first available combiner in the combiner list.
            Packs one combiner full before filling up the next combiner.
        """
        for combiner in self.network.get_combiners():
            if combiner.allowing_clients():
                return combiner
        return None

    def client_allocation_policy_least_packed(self):
        """
            Allocate client to the available combiner with the smallest number of clients.
            Spreads clients evenly over all active combiners.

            TODO: Not thread safe - not guaranteed to result in a perfectly even partition.

        """
        min_clients = None
        selected_combiner = None

        for combiner in self.network.get_combiners():
            try:
                if combiner.allowing_clients():
                    combiner_state = combiner.report()
                    nac = combiner_state['nr_active_clients']
                    if not min_clients:
                        min_clients = nac
                        selected_combiner = combiner
                    elif nac < min_clients:
                        min_clients = nac
                        selected_combiner = combiner
            except CombinerUnavailableError as err:
                print("Combiner was not responding, continuing to next")

        return selected_combiner

    def find(self, name):
        """

        :param name:
        :return:
        """
        for combiner in self.network.get_combiners():
            if name == combiner.name:
                return combiner
        return None

    def find_available_combiner(self):
        """

        :return:
        """
        combiner = self.client_allocation_policy()
        return combiner

    def state(self):
        """

        :return:
        """
        return self.__state
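
reduce folds combiner models together one at a time via helper.increment_average(model, model_next, i). Assuming that call implements a running arithmetic mean (an assumption about the helper, which is not shown in this excerpt), the update is model + (model_next - model) / i, which after all combiners equals the plain element-wise average. A small numeric sketch:

# Numeric sketch of incremental averaging, assuming increment_average is a running mean.
def increment_average(model, model_next, i):
    """ Fold model_next (the i-th model) into the running mean model. """
    return [m + (n - m) / i for m, n in zip(model, model_next)]

models = [[1.0, 2.0], [3.0, 6.0], [5.0, 10.0]]  # one weight vector per combiner

avg = models[0]  # first model: nothing to average against yet (mirrors the except branch in reduce)
for i, nxt in enumerate(models[1:], start=2):
    avg = increment_average(avg, nxt, i)

print(avg)  # [3.0, 6.0] == element-wise mean of the three vectors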
Example #7
    def __init__(self, connect_config):
        self.clients = {}

        import io
        from collections import defaultdict
        self.modelservice = ModelService()

        self.model_id = None

        self.role = Role.COMBINER

        self.id = connect_config['myname']
        address = connect_config['myhost']
        port = connect_config['myport']

        self.max_clients = connect_config['max_clients']

        from fedn.common.net.connect import ConnectorCombiner, Status
        announce_client = ConnectorCombiner(
            host=connect_config['discover_host'],
            port=connect_config['discover_port'],
            myhost=connect_config['myhost'],
            myport=connect_config['myport'],
            token=connect_config['token'],
            name=connect_config['myname'])

        import time
        response = None
        while True:
            status, response = announce_client.announce()
            if status == Status.TryAgain:
                time.sleep(5)
                continue
            if status == Status.Assigned:
                config = response
                print(
                    "COMBINER: was announced successfully. Waiting for clients and commands!",
                    flush=True)
                break

        import base64
        cert = base64.b64decode(response['certificate'])  # .decode('utf-8')
        key = base64.b64decode(response['key'])  # .decode('utf-8')

        grpc_config = {
            'port': port,
            'secure': connect_config['secure'],
            'certificate': cert,
            'key': key
        }

        # TODO remove temporary hardcoded config of storage persistence backend
        combiner_config = {
            'storage_access_key': os.environ['FEDN_MINIO_ACCESS_KEY'],
            'storage_secret_key': os.environ['FEDN_MINIO_SECRET_KEY'],
            'storage_bucket': 'models',
            'storage_secure_mode': False,
            'storage_hostname': os.environ['FEDN_MINIO_HOST'],
            'storage_port': int(os.environ['FEDN_MINIO_PORT'])
        }

        self.repository = S3ModelRepository(combiner_config)
        self.bucket_name = combiner_config["storage_bucket"]

        self.server = Server(self, self.modelservice, grpc_config)

        from fedn.algo.fedavg import FEDAVGCombiner
        self.combiner = FEDAVGCombiner(self.id, self.repository, self,
                                       self.modelservice)

        threading.Thread(target=self.combiner.run, daemon=True).start()

        from fedn.common.tracer.mongotracer import MongoTracer
        self.tracer = MongoTracer()

        self.server.start()
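
The temporary hardcoded storage block in this older constructor pulls its MinIO settings from environment variables. A hedged sketch of the environment it expects, with placeholder values for a local MinIO instance:

import os

# Placeholder MinIO settings consumed by the hardcoded combiner_config above.
os.environ.setdefault('FEDN_MINIO_HOST', 'localhost')
os.environ.setdefault('FEDN_MINIO_PORT', '9000')
os.environ.setdefault('FEDN_MINIO_ACCESS_KEY', 'minio')
os.environ.setdefault('FEDN_MINIO_SECRET_KEY', 'minio123')

print({k: v for k, v in os.environ.items() if k.startswith('FEDN_MINIO_')})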