def __init__(self, work_mode=WorkModeStrategy.WORKMODE_STANDALONE, models=None, data=None,
             client_id=0, client_ip="", client_port=8081, server_url="", curve=False,
             local_epoch=5, concurrent_num=5):
    """Record the trainer configuration and create the worker pool and logger.

    Args:
        work_mode: Stored unchanged; defaults to
            ``WorkModeStrategy.WORKMODE_STANDALONE``.
        models: Stored unchanged on ``self.models``.
        data: Stored unchanged on ``self.data``.
        client_id: Stored as ``str(client_id)``.
        client_ip: Stored unchanged on ``self.client_ip``.
        client_port: Stored as ``str(client_port)``.
        server_url: Stored unchanged on ``self.server_url``.
        curve: Stored unchanged on ``self.curve``.
        local_epoch: Stored unchanged (presumably the number of local
            training epochs -- confirm at the use site).
        concurrent_num: Stored unchanged and passed to ``ThreadPoolExecutor``
            as its first argument (the worker count).
    """
    # Client identity and endpoints; the id and the port are kept as strings.
    self.client_id = str(client_id)
    self.client_ip = client_ip
    self.client_port = str(client_port)
    self.server_url = server_url

    # Settings stored exactly as supplied.
    self.work_mode = work_mode
    self.models = models
    self.data = data
    self.curve = curve
    self.local_epoch = local_epoch
    self.job_path = JOB_PATH

    # Per-instance dictionaries, empty at construction time.
    self.fed_step = {}
    self.job_train_strategy = {}

    # The pool is sized from the attribute, so the attribute is set first.
    self.concurrent_num = concurrent_num
    self.trainer_executor_pool = ThreadPoolExecutor(self.concurrent_num)

    self.logger = LoggerFactory.getLogger("TrainerController", logging.INFO)
def __init__(self, job, data, fed_step, client_id, model, curve):
    """Initialise the parent strategy, load the job's train model, set up a logger.

    All six arguments are forwarded unchanged to the parent ``__init__``.
    ``job`` is additionally queried for its id and its train-model class
    name, which are handed to ``self._load_job_model``.
    """
    super(TrainStandloneDistillationStrategy, self).__init__(
        job, data, fed_step, client_id, model, curve
    )

    # Same call order as a direct nested call: job id first, class name second.
    job_id = job.get_job_id()
    model_class_name = job.get_train_model_class_name()
    self.train_model = self._load_job_model(job_id, model_class_name)

    self.logger = LoggerFactory.getLogger(
        "TrainStandloneDistillationStrategy", logging.INFO
    )
def __init__(self, job, data, fed_step, client_ip, client_port, server_url, client_id, model,
             curve):
    """Initialise the parent strategy and keep the network settings.

    ``job``, ``data``, ``fed_step``, ``client_id``, ``model`` and ``curve``
    are forwarded unchanged to the parent ``__init__``; ``client_ip``,
    ``client_port`` and ``server_url`` are stored unchanged on the instance.
    """
    super(TrainMPCDistillationStrategy, self).__init__(
        job, data, fed_step, client_id, model, curve
    )

    # Network settings are kept on this instance, not passed to the parent.
    self.server_url = server_url
    self.client_ip = client_ip
    self.client_port = client_port

    self.logger = LoggerFactory.getLogger("TrainMPCDistillationStrategy", logging.INFO)
def __init__(self, job, data, fed_step, client_id, model, curve):
    """Initialise the parent strategy and attach a class-specific logger.

    Every argument is forwarded unchanged to the parent ``__init__``.
    """
    super(TrainStandloneNormalStrategy, self).__init__(
        job,
        data,
        fed_step,
        client_id,
        model,
        curve,
    )
    self.logger = LoggerFactory.getLogger(
        "TrainStandloneNormalStrategy",
        logging.INFO,
    )
def __init__(self):
    """Attach an INFO-level "JobManager" logger and record the module's JOB_PATH."""
    self.logger = LoggerFactory.getLogger("JobManager", logging.INFO)
    self.job_path = JOB_PATH
# limitations under the License. import os import json import logging from flask import Flask, send_from_directory, request from werkzeug.serving import run_simple from gfl.entity.runtime_config import CONNECTED_TRAINER_LIST from gfl.core.job_manager import JobManager from gfl.utils.utils import JobEncoder, return_data_decorator, LoggerFactory API_VERSION = "/api/v1" JOB_PATH = os.path.join(os.path.abspath("."), "res", "jobs_server") BASE_MODEL_PATH = os.path.join(os.path.abspath("."), "res", "models") logger = LoggerFactory.getLogger(__name__, logging.INFO) app = Flask(__name__) @app.route("/test/<name>") @return_data_decorator def test_flask_server(name): return name, 200 @app.route("/register/<ip>/<port>/<client_id>", methods=['POST'], endpoint='register_trainer') @return_data_decorator def register_trainer(ip, port, client_id):
def __init__(self):
    """Run the parent initialiser, then attach an INFO-level "FlServer" logger."""
    super(FLServer, self).__init__()
    self.logger = LoggerFactory.getLogger(
        "FlServer",
        logging.INFO,
    )
def __init__(self, work_mode, job_path, base_model_path):
    """Initialise the parent aggregator and this class's own state.

    ``work_mode``, ``job_path`` and ``base_model_path`` are forwarded
    unchanged to the parent ``__init__``. Afterwards ``self.fed_step`` is
    set to a fresh empty dict and an INFO-level logger is attached.
    """
    super(FedAvgAggregator, self).__init__(
        work_mode,
        job_path,
        base_model_path,
    )

    # Assigned after the parent call, so any parent-set fed_step is replaced.
    self.fed_step = {}
    self.logger = LoggerFactory.getLogger("FedAvgAggregator", logging.INFO)