コード例 #1
0
    def __init__(self, data=None, estimator=None,
                 aggregation=None, transmitter=None) -> None:
        """Configure Plato from the given components and build its client.

        Each optional component overrides part of the global Plato ``Config``
        before the client is obtained from the Plato client registry and
        configured.

        Args:
            data: Either an object with a truthy ``customized`` attribute
                exposing ``trainset`` and ``testset`` (wrapped in a Plato
                ``DataSource`` that is handed to the client), or an object
                whose ``parameters`` dict is merged into ``Config().data``.
            estimator: Supplies ``model`` (passed to the client) and
                ``hyperparameters`` (merged into ``Config().trainer``).
            aggregation: Its ``parameters`` dict (which must contain a
                ``"type"`` key) becomes ``Config().algorithm``. Type
                ``"mistnet"`` sets both the client and server type to
                ``"mistnet"``; any other type sets ``clients.do_test``.
            transmitter: Its ``parameters`` dict is merged into the server
                section after the address/port are set, so it can override
                them.
        """
        from plato.config import Config
        from plato.datasources import base

        # Work on mutable dict copies of the current config sections; the
        # edited sections are written back as namedtuples below.
        server = Config().server._asdict()
        clients = Config().clients._asdict()
        datastore = Config().data._asdict()
        train = Config().trainer._asdict()

        self.datasource = None
        if data is not None:
            # Only a truthy ``customized`` flag selects the in-memory
            # train/test sets. Every other data object -- including one whose
            # ``customized`` flag is False -- is treated as datasource
            # parameters rather than being silently ignored.
            if getattr(data, "customized", False):
                self.datasource = base.DataSource()
                self.datasource.trainset = data.trainset
                self.datasource.testset = data.testset
            else:
                datastore.update(data.parameters)
                Config().data = Config.namedtuple_from_dict(datastore)

        self.model = None
        if estimator is not None:
            self.model = estimator.model
            train.update(estimator.hyperparameters)
            Config().trainer = Config.namedtuple_from_dict(train)

        if aggregation is not None:
            Config().algorithm = Config.namedtuple_from_dict(
                aggregation.parameters)
            if aggregation.parameters["type"] == "mistnet":
                clients["type"] = "mistnet"
                server["type"] = "mistnet"
            else:
                clients["do_test"] = True

        # Aggregation server address/port, read via Context.
        server["address"] = Context.get_parameters("AGG_IP")
        server["port"] = Context.get_parameters("AGG_PORT")

        if transmitter is not None:
            server.update(transmitter.parameters)

        Config().server = Config.namedtuple_from_dict(server)
        Config().clients = Config.namedtuple_from_dict(clients)

        from plato.clients import registry as client_registry
        self.client = client_registry.get(model=self.model,
                                          datasource=self.datasource)
        self.client.configure()
コード例 #2
0
    def __init__(self, data=None, estimator=None,
                 aggregation=None, transmitter=None,
                 chooser=None) -> None:
        """Configure Plato from the given components and build its server.

        The current ``server``, ``clients``, ``data`` and ``trainer`` sections
        of the global Plato ``Config`` are copied into dicts and overridden by
        whichever components were supplied. The edited sections are stored
        back before the server is fetched from the Plato server registry.

        Args:
            data: Its ``parameters`` dict is merged into ``Config().data``.
            estimator: Supplies ``model`` (passed to the server) and
                ``hyperparameters`` (merged into ``Config().trainer``). Its
                ``pretrained`` / ``saved`` values are recorded in
                ``Config().params`` when they are not None.
            aggregation: Its ``parameters`` dict (which must contain a
                ``"type"`` key) becomes ``Config().algorithm``. ``"mistnet"``
                sets the client and server type; anything else sets
                ``clients.do_test``.
            transmitter: Its ``parameters`` dict is merged into the server
                section after the bind address/port are set.
            chooser: Its ``parameters["per_round"]`` sets
                ``clients.per_round``.
        """
        from plato.config import Config

        server_cfg = Config().server._asdict()
        clients_cfg = Config().clients._asdict()
        data_cfg = Config().data._asdict()
        trainer_cfg = Config().trainer._asdict()

        if data is not None:
            data_cfg.update(data.parameters)
            Config().data = Config.namedtuple_from_dict(data_cfg)

        self.model = None if estimator is None else estimator.model
        if estimator is not None:
            # (config key, estimator attribute) pairs; only non-None
            # directories are recorded.
            for param_key, attr_name in (("pretrained_model_dir", "pretrained"),
                                         ("model_dir", "saved")):
                directory = getattr(estimator, attr_name)
                if directory is not None:
                    Config().params[param_key] = directory
            trainer_cfg.update(estimator.hyperparameters)
            Config().trainer = Config.namedtuple_from_dict(trainer_cfg)

        # Bind endpoint; the second get_parameters argument is presumably
        # the fallback value when the parameter is unset.
        server_cfg.update(
            address=Context.get_parameters("AGG_BIND_IP", "0.0.0.0"),
            port=int(Context.get_parameters("AGG_BIND_PORT", 7363)),
        )
        if transmitter is not None:
            server_cfg.update(transmitter.parameters)

        # Applied after the transmitter merge so the mistnet type wins.
        if aggregation is not None:
            algorithm_params = aggregation.parameters
            Config().algorithm = Config.namedtuple_from_dict(algorithm_params)
            if algorithm_params["type"] == "mistnet":
                clients_cfg["type"] = server_cfg["type"] = "mistnet"
            else:
                clients_cfg["do_test"] = True

        if chooser is not None:
            clients_cfg["per_round"] = chooser.parameters["per_round"]

        LOGGER.info("address %s, port %s",
                    server_cfg["address"], server_cfg["port"])

        Config().server = Config.namedtuple_from_dict(server_cfg)
        Config().clients = Config.namedtuple_from_dict(clients_cfg)

        from plato.servers import registry
        self.server = registry.get(model=self.model)
コード例 #3
0
ファイル: config_tests.py プロジェクト: TL-System/plato
    def setup(self):
        """Build shared fixtures for the Config tests.

        Registers ``assertConfigEqual`` as the equality check for ``Config``
        objects and stores a ``Config()`` instance. It also converts nested
        example parameter dicts into namedtuples via
        ``Config.namedtuple_from_dict``.

        NOTE(review): ``unittest.TestCase`` invokes ``setUp`` (camel case),
        not ``setup``. If this class derives from ``TestCase``, confirm that
        this hook and ``super().setup()`` are actually reached by the test
        runner.
        """
        super().setup()

        # Make assertEqual on two Config instances delegate to
        # self.assertConfigEqual (unittest resolves the method name string).
        self.addTypeEqualityFunc(Config, 'assertConfigEqual')

        self.defined_config = Config()

        # Example nested parameters: a multi-modal (rgb / flow / audio) data
        # pipeline and a model description, used to exercise conversion of
        # nested dicts into namedtuples.
        data_params_config = {
            "downloader": {
                "num_workers": 4
            },
            "multi_modal_pipeliner": {
                "rgb": {
                    "rgb_data": {
                        "train": {
                            "type": "RawframeDataset"
                        }
                    }
                },
                "flow": {
                    "flow_data": {
                        "train": {
                            "type": "RawframeDataset"
                        }
                    }
                },
                "audio": {
                    "audio_data": {
                        "train": {
                            "type": "AudioFeatureDataset"
                        }
                    }
                }
            }
        }
        model_params_config = {
            "model_name": "rgb_flow_audio_model",
            "model_config": {
                "rgb_model": {
                    "type": "Recognizer3D"
                }
            }
        }

        # Namedtuple forms of the example dicts above.
        self.data_config = Config.namedtuple_from_dict(data_params_config)
        self.model_config = Config.namedtuple_from_dict(model_params_config)