def __init__(self, data=None, estimator=None,
             aggregation=None, transmitter=None) -> None:
    """Push Sedna-side settings into the global plato ``Config`` and build
    the plato federated-learning client.

    Every argument is optional; a section of ``Config`` is only rewritten
    when the corresponding argument is given.

    Args:
        data: dataset description. If it has a truthy ``customized``
            attribute, its ``trainset`` / ``testset`` are wrapped in a plato
            ``base.DataSource`` handed to the client. Otherwise its
            ``parameters`` dict is merged into ``Config().data``.
        estimator: provides ``model`` (passed to the client) and a
            ``hyperparameters`` dict merged into ``Config().trainer``.
        aggregation: provides a ``parameters`` dict that becomes
            ``Config().algorithm``. ``parameters["type"] == "mistnet"``
            sets both client and server ``type`` to ``"mistnet"``; any other
            type enables ``clients["do_test"]``.
        transmitter: provides a ``parameters`` dict merged into the server
            section after the AGG_IP / AGG_PORT values, so it overrides them.
    """
    from plato.config import Config
    from plato.datasources import base

    # Mutable dict copies of the current plato config sections; they are
    # written back as namedtuples once all overrides are applied.
    server = Config().server._asdict()
    clients = Config().clients._asdict()
    datastore = Config().data._asdict()
    train = Config().trainer._asdict()

    self.datasource = None
    if data is not None:
        # A truthy ``customized`` flag means the caller supplies the datasets
        # directly. Anything else (flag missing or False) is configured via
        # ``data.parameters``, so a ``customized=False`` object is still
        # applied instead of being dropped.
        if getattr(data, "customized", False):
            self.datasource = base.DataSource()
            self.datasource.trainset = data.trainset
            self.datasource.testset = data.testset
        else:
            datastore.update(data.parameters)
            Config().data = Config.namedtuple_from_dict(datastore)

    self.model = None
    if estimator is not None:
        self.model = estimator.model
        train.update(estimator.hyperparameters)
        Config().trainer = Config.namedtuple_from_dict(train)

    if aggregation is not None:
        Config().algorithm = Config.namedtuple_from_dict(
            aggregation.parameters)
        if aggregation.parameters["type"] == "mistnet":
            clients["type"] = "mistnet"
            server["type"] = "mistnet"
        else:
            clients["do_test"] = True

    # Aggregator endpoint comes from the Sedna context; transmitter
    # parameters are merged afterwards and therefore take precedence.
    server["address"] = Context.get_parameters("AGG_IP")
    server["port"] = Context.get_parameters("AGG_PORT")
    if transmitter is not None:
        server.update(transmitter.parameters)

    Config().server = Config.namedtuple_from_dict(server)
    Config().clients = Config.namedtuple_from_dict(clients)

    # Imported late so the client registry sees the Config updated above.
    from plato.clients import registry as client_registry
    self.client = client_registry.get(model=self.model,
                                      datasource=self.datasource)
    self.client.configure()
def __init__(self, data=None, estimator=None, aggregation=None,
             transmitter=None, chooser=None) -> None:
    """Apply Sedna-side settings to the global plato ``Config`` and build
    the plato aggregation server.

    Args:
        data: its ``parameters`` dict is merged into ``Config().data``.
        estimator: supplies ``model`` for the server. A non-None
            ``pretrained`` / ``saved`` value is stored in ``Config().params``
            under ``'pretrained_model_dir'`` / ``'model_dir'``. Its
            ``hyperparameters`` dict is merged into ``Config().trainer``.
        aggregation: its ``parameters`` dict becomes ``Config().algorithm``.
            Type ``"mistnet"`` sets client and server ``type``; any other
            type enables ``clients["do_test"]``.
        transmitter: its ``parameters`` dict is merged into the server
            section, overriding the bind address/port from the context.
        chooser: ``parameters["per_round"]`` sets ``clients["per_round"]``.
    """
    from plato.config import Config

    # Editable dict views of the current config sections.
    server_cfg = Config().server._asdict()
    client_cfg = Config().clients._asdict()
    data_cfg = Config().data._asdict()
    trainer_cfg = Config().trainer._asdict()

    if data is not None:
        data_cfg.update(data.parameters)
        Config().data = Config.namedtuple_from_dict(data_cfg)

    self.model = None
    if estimator is not None:
        self.model = estimator.model
        # (estimator attribute, Config().params key) pairs for model paths.
        path_settings = (("pretrained", "pretrained_model_dir"),
                         ("saved", "model_dir"))
        for attr_name, params_key in path_settings:
            model_path = getattr(estimator, attr_name)
            if model_path is not None:
                Config().params[params_key] = model_path
        trainer_cfg.update(estimator.hyperparameters)
        Config().trainer = Config.namedtuple_from_dict(trainer_cfg)

    # Bind endpoint defaults to 0.0.0.0:7363 unless the context overrides it.
    server_cfg.update(
        address=Context.get_parameters("AGG_BIND_IP", "0.0.0.0"),
        port=int(Context.get_parameters("AGG_BIND_PORT", 7363)),
    )
    if transmitter is not None:
        server_cfg.update(transmitter.parameters)

    if aggregation is not None:
        algorithm_params = aggregation.parameters
        Config().algorithm = Config.namedtuple_from_dict(algorithm_params)
        if algorithm_params["type"] != "mistnet":
            client_cfg["do_test"] = True
        else:
            client_cfg["type"] = server_cfg["type"] = "mistnet"

    if chooser is not None:
        client_cfg["per_round"] = chooser.parameters["per_round"]

    LOGGER.info("address %s, port %s",
                server_cfg["address"], server_cfg["port"])
    Config().server = Config.namedtuple_from_dict(server_cfg)
    Config().clients = Config.namedtuple_from_dict(client_cfg)

    # Imported after Config is finalized so the server is built from it.
    from plato.servers import registry as server_registry
    self.server = server_registry.get(model=self.model)
def setup(self):
    """Prepare fixtures shared by the tests.

    Registers ``assertConfigEqual`` as the ``assertEqual`` handler for
    ``Config`` values, keeps the active ``Config`` instance, and builds two
    example nested configs (data and model) as namedtuples.
    """
    super().setup()
    self.addTypeEqualityFunc(Config, 'assertConfigEqual')
    self.defined_config = Config()

    def train_split(dataset_type):
        # Dataset section containing only a "train" split of the given type.
        return {"train": {"type": dataset_type}}

    # One pipeline entry per modality (key order matters for the namedtuple).
    pipeliner = {
        "rgb": {"rgb_data": train_split("RawframeDataset")},
        "flow": {"flow_data": train_split("RawframeDataset")},
        "audio": {"audio_data": train_split("AudioFeatureDataset")},
    }
    data_params = {
        "downloader": {"num_workers": 4},
        "multi_modal_pipeliner": pipeliner,
    }
    model_params = {
        "model_name": "rgb_flow_audio_model",
        "model_config": {"rgb_model": {"type": "Recognizer3D"}},
    }

    self.data_config = Config.namedtuple_from_dict(data_params)
    self.model_config = Config.namedtuple_from_dict(model_params)