Esempio n. 1
0
    def P2P_aggregate(self, experiment_id, global_model_id, send_model_task_id,
                      clients, aggregation_config, round_num):
        """Overwrite the global model with a single client's update (P2P).

        In the P2P protocol only one client at a time is expected. The
        authoritative client list is taken from the task document's
        ``clients`` field, not from the ``clients`` argument.

        object experiment_id:       id of the running experiment
        object global_model_id:     id of the global model to overwrite
        object send_model_task_id:  id of the send-model task that produced
                                    the update
        list clients:               ignored; replaced by the task's clients
        dict aggregation_config:    unused for P2P
        int round_num:              unused for P2P
        """
        # local import so the block stays self-contained
        import ast

        # work with the clients field stored on the task document
        task_dict = self.fl_db.task.find_one({"_id": send_model_task_id})
        clients = list(task_dict['clients'].keys())
        clients_model_updates = utils.load_clients_model_updates(
            experiment_id=experiment_id,
            send_model_task_id=send_model_task_id,
            clients=clients)

        # here we retrieve dict object, with fields pickle and trees
        clients_model_updates = clients_model_updates[clients[0]]

        # NOTE(review): original code used eval() on DB-stored data;
        # ast.literal_eval parses the same repr'd list of trees without
        # executing arbitrary expressions — confirm "trees" is always a
        # plain literal.
        logging.info(
            "NUMBER OF TREES: {}. Last tree built on client {}".format(
                len(ast.literal_eval(clients_model_updates["trees"])),
                clients[0]))

        global_model_document, _ = utils.get_model(db=self.fl_db,
                                                   model_id=global_model_id)

        # in P2P global model is just overwritten by the model sent from the client
        utils.save_model_parameters(model_id=global_model_document['_id'],
                                    model_document=global_model_document,
                                    parameters=clients_model_updates,
                                    send_model_task_id=send_model_task_id,
                                    overwrite=True)
Esempio n. 2
0
    def NN_aggregate(self, experiment_id, global_model_id, send_model_task_id,
                     clients, aggregation_config, round_num):
        """Aggregate neural-network client updates into the global model.

        The per-round update rate decays linearly:
        ``update_rate - round_num * update_decay / total_rounds``.

        object experiment_id:       id of the running experiment
        object global_model_id:     id of the global model to update
        object send_model_task_id:  id of the send-model task whose clients
                                    produced the updates
        list clients:               clients whose updates are aggregated
        dict aggregation_config:    keys ``update_rate`` (default 1),
                                    ``update_decay`` (default 0),
                                    ``total_rounds`` (default 1),
                                    ``verbose`` (default 0) and
                                    ``aggregation_type`` (default 'mean')
        int round_num:              current round, drives the rate decay
        """
        clients_model_updates = utils.load_clients_model_updates(
            experiment_id=experiment_id,
            send_model_task_id=send_model_task_id,
            clients=clients)

        global_model_document, global_model_parameters = utils.get_model(
            db=self.fl_db, model_id=global_model_id)

        base_rate = aggregation_config.get("update_rate", 1)
        update_decay = aggregation_config.get("update_decay", 0)
        total_rounds = aggregation_config.get("total_rounds", 1)
        # linear decay of the update rate over the course of the experiment
        update_rate = base_rate - (round_num * update_decay / total_rounds)
        logging.info("NN_aggregate update_rate: %s", update_rate)
        updated_model_parameters = utils.NN_aggregate_model_updates(
            clients_model_updates=clients_model_updates,
            global_model_parameters=global_model_parameters,
            update_rate=update_rate,
            verbose=aggregation_config.get("verbose", 0),
            aggregation_type=aggregation_config.get("aggregation_type",
                                                    'mean'))

        utils.save_model_parameters(model_id=global_model_document['_id'],
                                    model_document=global_model_document,
                                    parameters=updated_model_parameters,
                                    send_model_task_id=send_model_task_id,
                                    overwrite=True)
Esempio n. 3
0
    def reset_experiment(self, experiment_id):
        """Reset a non-finished experiment back to its start model.

        Restores the experiment-state model parameters from the start
        model, resets the task list, and clears the running/finished/failed
        flags on the experiment document.

        object experiment_id: id of the experiment to reset
        :return: True on success; False if the experiment does not exist
                 or has already finished.
        """
        with self.db_session.start_transaction():
            # NOTE(review): the reads below do not pass
            # session=self.db_session, so they run outside the transaction
            # — confirm this is intended.
            experiment_document = self.fl_db.experiment.find_one(
                {"_id": experiment_id})
            if experiment_document is None:
                logging.info(
                    f"The experiment you are trying to reset does not exist. {experiment_id}"
                )
                return False
            if experiment_document.get('is_finished', False):
                logging.info(
                    f"You cannot reset a finished experiment {experiment_id}")
                return False
            # parameters of the model the experiment originally started from
            _, parameters = utils.get_model(
                db=self.fl_db,
                model_id=experiment_document['start_model_id'])

            experiment_state_document, _ = utils.get_model(
                db=self.fl_db,
                model_id=experiment_document['experiment_state_model_id'])

            # overwrite the experiment-state model with the start parameters
            utils.save_model_parameters(
                model_id=experiment_state_document['_id'],
                model_document=experiment_state_document,
                parameters=parameters,
                overwrite=True)

            self.__reset_task_list(experiment_document=experiment_document)

            # clear all status flags so the experiment can be run again
            self.fl_db.experiment.update_one(
                {"_id": experiment_id},
                {"$set": {
                    "is_running": False,
                    "is_finished": False,
                    "has_failed": False
                }},
                session=self.db_session)
        return True
Esempio n. 4
0
 def __RF_add_trees_to_random_forest(self, experiment_id,
                                     experiment_id_list):
     """Merge the single tree of every listed experiment into the global
     random forest of ``experiment_id`` and persist the combined forest."""
     experiment_document = list(
         self.fl_db.experiment.find({"_id": experiment_id}))[0]
     global_model_id = experiment_document['experiment_state_model_id']
     global_model_document, global_model_parameters = utils.get_model(
         db=self.fl_db, model_id=global_model_id)

     # gather the first tree of each per-experiment forest into one list
     combined_forest = []
     for tree_experiment_id in experiment_id_list:
         tree_experiment = list(
             self.fl_db.experiment.find({"_id": tree_experiment_id}))[0]
         _, tree_model_parameters = utils.get_model(
             db=self.fl_db,
             model_id=tree_experiment['experiment_state_model_id'])
         combined_forest.append(tree_model_parameters["forest"][0])
     global_model_parameters['forest'] = combined_forest

     # write the combined forest back as the global model parameters
     utils.save_model_parameters(model_id=global_model_document['_id'],
                                 model_document=global_model_document,
                                 parameters=global_model_parameters,
                                 send_model_task_id=None,
                                 overwrite=True)
Esempio n. 5
0
    def RF_aggregate(self, experiment_id, global_model_id, send_model_task_id,
                     clients, aggregation_config, round_num):
        """Aggregate client histograms into the global random-forest model.

        Each client sends one dictionary indexed by model-feature index (as
        string); every feature maps to a list of bins (dicts). The merged
        histograms determine the optimal split, which is inserted as a new
        node into the json tree of the global model. When the tree is
        finished, all tasks of the experiment are marked done.

        object experiment_id:       id of the running experiment
        object global_model_id:     id of the global (single-tree) model
        object send_model_task_id:  id of the send-model task whose clients
                                    produced the histograms
        list clients:               clients whose updates are aggregated
        dict aggregation_config:    unused here
        int round_num:              unused here
        """
        # get updates
        clients_model_updates = utils.load_clients_model_updates(
            experiment_id=experiment_id,
            send_model_task_id=send_model_task_id,
            clients=clients)
        # get current model and parameters (forest with one tree)
        global_model_document, global_model_parameters = utils.get_model(
            db=self.fl_db, model_id=global_model_id)

        # aggregate all histograms (dictionaries) per feature together,
        # find the optimal split, and insert a new node into the json tree;
        # also updates current_condition_list on the model document
        updated_model_document, updated_model_parameters, finished = utils.RF_aggregate_model_updates(
            clients_model_updates=clients_model_updates,
            global_model_document=global_model_document,
            global_model_parameters=global_model_parameters,
            clients=clients)

        # save new model parameters by writing down the json file
        utils.save_model_parameters(model_id=updated_model_document['_id'],
                                    model_document=updated_model_document,
                                    parameters=updated_model_parameters,
                                    send_model_task_id=send_model_task_id,
                                    overwrite=True)
        updated_model_document['_id'] = ObjectId(updated_model_document['_id'])
        self.fl_db.model.replace_one(
            {"_id": updated_model_document['_id']}, updated_model_document)
        # if the tree is finished, mark all tasks to this tree as finished
        if finished:
            self.fl_db.experiment.update_many(
                {"_id": ObjectId(experiment_id)},
                {"$set": {
                    "task_list.$[].task_status": config['TASK_DONE']
                }})
            # fixed: original format strings were missing the space before
            # the experiment id ("...a tree%s")
            logging.info("Aggregated and finished a tree %s", experiment_id)
        else:
            logging.info("Aggregated histograms %s", experiment_id)
Esempio n. 6
0
    def define_experiment(self,
                          start_model_id,
                          training_config,
                          tasks,
                          clients,
                          git_version,
                          experiment_name,
                          experiment_description,
                          model_name=None,
                          experiment_id=None,
                          experiment_state_model_id=None,
                          model_description=None,
                          testing=False):
        """Create a new experiment with its state model and task list.

        Copies the start model into a fresh experiment-state model, inserts
        the experiment document, then builds the per-client task list and
        attaches it to the experiment in a second update.

        object start_model_id:      model the experiment starts from
        dict training_config:       training configuration; attached to the
                                    state model's model config
        list tasks:                 tasks to perform, in order
        list clients:               clients taking part in the experiment
        string git_version:         last git commit hash used in the script
        string experiment_name, experiment_description: strings to help
                                    identify the experiment
        string model_name, model_description: identifiers for the state
                                    model; default to the experiment's
                                    name/description
        object experiment_id, experiment_state_model_id: optional fixed ids
        bool testing:               True when the experiment is only for
                                    testing
        :return: the id of the newly inserted experiment document
        :raises Exception: if the start model is already the state model of
                           another experiment
        """
        start_doc, start_parameters = utils.get_model(
            db=self.fl_db, model_id=start_model_id)
        if start_doc['is_running']:
            raise Exception("start_model is state model of other experiment")
        # the training configuration travels with the state model's config
        start_doc['model']['training'] = training_config

        state_model_name = model_name if model_name else experiment_name
        state_model_description = (model_description if model_description
                                   else experiment_description)
        experiment_state_model_id = self.define_model(
            model_id=experiment_state_model_id,
            model=start_doc['model'],
            parameters=start_parameters,
            protocol=start_doc['protocol'],
            model_name=state_model_name,
            model_description=state_model_description,
            git_version=git_version,
            is_running=True,
            new_transaction=False,
            testing=testing)

        experiment_document = utils.build_experiment_document(
            experiment_id=experiment_id,
            start_model_id=start_model_id,
            experiment_state_model_id=experiment_state_model_id,
            training_config=training_config,
            task_list=[],
            is_running=False,
            protocol=start_doc['protocol'],
            clients=clients,
            git_version=git_version,
            experiment_description=experiment_description,
            experiment_name=experiment_name,
            testing=testing)

        insert_result = self.fl_db.experiment.insert_one(
            experiment_document, session=self.db_session)
        experiment_id = insert_result.inserted_id

        # the task list needs the experiment id, so it is built after the
        # insert and attached in a follow-up update
        task_list = self.__define_task_list(
            experiment_id=experiment_id,
            tasks=tasks,
            clients=clients,
            testing=testing,
            protocol=start_doc['protocol'])

        self.fl_db.experiment.update_one(
            {"_id": experiment_id},
            {"$set": {
                "task_list": task_list
            }})

        return experiment_id