def __init__(self, data=None, estimator=None,
             aggregation=None, transmitter=None) -> None:
    """Push the given components into the Plato ``Config()`` and build a client.

    Each argument that is not ``None`` overrides part of the global Plato
    configuration. The client is then obtained from
    ``plato.clients.registry.get`` and ``configure()`` is called on it.

    Args:
        data: Dataset description. If it has a ``customized`` attribute and
            that attribute is truthy, its ``trainset``/``testset`` are
            wrapped in a Plato ``base.DataSource`` stored on
            ``self.datasource``. If it has no ``customized`` attribute,
            ``data.parameters`` is merged into the ``data`` config section.
        estimator: Supplies ``model`` (stored on ``self.model``) and
            ``hyperparameters`` (merged into the ``trainer`` section).
        aggregation: Supplies ``parameters``, installed as the ``algorithm``
            section. ``parameters["type"] == "mistnet"`` sets both the client
            and server type to ``"mistnet"``; any other type sets the client
            ``do_test`` flag to ``True``.
        transmitter: Supplies ``parameters`` merged into the ``server``
            section after the aggregator address and port are filled in, so
            it can override them.
    """
    from plato.config import Config
    from plato.datasources import base

    # Snapshot the current config sections as mutable dicts; they are
    # written back as namedtuples via Config.namedtuple_from_dict below.
    server = Config().server._asdict()
    clients = Config().clients._asdict()
    datastore = Config().data._asdict()
    train = Config().trainer._asdict()

    self.datasource = None
    if data is not None:
        if hasattr(data, "customized"):
            # NOTE(review): an object whose ``customized`` is falsy is
            # neither wrapped nor merged into the config; confirm this is
            # intended.
            if data.customized:
                self.datasource = base.DataSource()
                self.datasource.trainset = data.trainset
                self.datasource.testset = data.testset
        else:
            datastore.update(data.parameters)
            Config().data = Config.namedtuple_from_dict(datastore)

    self.model = None
    if estimator is not None:
        self.model = estimator.model
        train.update(estimator.hyperparameters)
        Config().trainer = Config.namedtuple_from_dict(train)

    if aggregation is not None:
        Config().algorithm = Config.namedtuple_from_dict(
            aggregation.parameters)
        if aggregation.parameters["type"] == "mistnet":
            clients["type"] = "mistnet"
            server["type"] = "mistnet"
        else:
            clients["do_test"] = True

    # Address/port of the aggregation server this client connects to.
    server["address"] = Context.get_parameters("AGG_IP")
    agg_port = Context.get_parameters("AGG_PORT")
    # Store the port as an int, as the server-side setup in this module does
    # for its bind port, so Config().server.port has the same type on both
    # roles. Leave it as None when the parameter is unset instead of
    # crashing in int().
    server["port"] = int(agg_port) if agg_port is not None else None

    if transmitter is not None:
        server.update(transmitter.parameters)

    Config().server = Config.namedtuple_from_dict(server)
    Config().clients = Config.namedtuple_from_dict(clients)

    from plato.clients import registry as client_registry
    self.client = client_registry.get(model=self.model,
                                      datasource=self.datasource)
    self.client.configure()
def __init__(self, data=None, estimator=None, aggregation=None,
             transmitter=None, chooser=None) -> None:
    """Populate the global Plato ``Config()`` and create the aggregation server.

    The ``server``, ``clients``, ``data`` and ``trainer`` sections are copied
    into dicts, overridden from whichever arguments are provided, and written
    back. The server is then created with
    ``plato.servers.registry.get(model=self.model)`` and stored on
    ``self.server``.

    Args:
        data: Supplies ``parameters`` merged into the ``data`` section.
        estimator: Supplies ``model`` (kept on ``self.model``) and
            ``hyperparameters`` (merged into ``trainer``). Its non-None
            ``pretrained``/``saved`` values are written to
            ``Config().params['pretrained_model_dir']`` and
            ``Config().params['model_dir']`` respectively.
        aggregation: Supplies ``parameters`` used as the ``algorithm``
            section. A ``"mistnet"`` type switches client and server type to
            ``"mistnet"``; any other type sets the client ``do_test`` flag.
        transmitter: Supplies ``parameters`` merged into ``server`` after
            the bind address/port are set, so it may override them.
        chooser: Supplies ``parameters["per_round"]``, copied into the
            ``clients`` section.
    """
    from plato.config import Config

    server_cfg = Config().server._asdict()
    client_cfg = Config().clients._asdict()
    data_cfg = Config().data._asdict()
    trainer_cfg = Config().trainer._asdict()

    if data is not None:
        data_cfg.update(data.parameters)
        Config().data = Config.namedtuple_from_dict(data_cfg)

    self.model = None
    if estimator is not None:
        self.model = estimator.model
        # Optional model-location overrides, applied in a fixed order.
        for attr_name, param_key in (("pretrained", "pretrained_model_dir"),
                                     ("saved", "model_dir")):
            location = getattr(estimator, attr_name)
            if location is not None:
                Config().params[param_key] = location
        trainer_cfg.update(estimator.hyperparameters)
        Config().trainer = Config.namedtuple_from_dict(trainer_cfg)

    # Listening endpoint, with defaults when the parameters are unset.
    server_cfg.update(
        address=Context.get_parameters("AGG_BIND_IP", "0.0.0.0"),
        port=int(Context.get_parameters("AGG_BIND_PORT", 7363)),
    )
    if transmitter is not None:
        server_cfg.update(transmitter.parameters)

    if aggregation is not None:
        agg_params = aggregation.parameters
        Config().algorithm = Config.namedtuple_from_dict(agg_params)
        if agg_params["type"] != "mistnet":
            client_cfg["do_test"] = True
        else:
            client_cfg["type"] = "mistnet"
            server_cfg["type"] = "mistnet"

    if chooser is not None:
        client_cfg["per_round"] = chooser.parameters["per_round"]

    LOGGER.info("address %s, port %s", server_cfg["address"], server_cfg["port"])
    Config().server = Config.namedtuple_from_dict(server_cfg)
    Config().clients = Config.namedtuple_from_dict(client_cfg)

    from plato.servers import registry as server_registry
    self.server = server_registry.get(model=self.model)