def main():
    """A Plato federated learning training session using the HGB algorithm.

    Builds a dynamic multimodal model over RGB/flow/audio sub-nets, disables
    any pretrained backbone weights so each sub-net is trained from scratch,
    and launches a FedAvg server with a single simple client.
    """
    # Instantiate Config for its side effect of loading the session settings.
    Config()

    support_modalities = ["rgb", "flow", "audio"]

    kinetics_data_source = kinetics.DataSource()

    # Disable pretrained weights for every configured modality sub-net so the
    # backbones are untrained at the start of the federated session.
    for modality_nm in support_modalities:
        modality_model_nm = modality_nm + "_model"
        # Membership test directly on the dict — no need for .keys() or list().
        if modality_model_nm in Config.multimodal_nets_configs:
            modality_net = Config.multimodal_nets_configs[modality_model_nm]
            modality_net['backbone']['pretrained'] = None
            if "pretrained2d" in modality_net['backbone']:
                modality_net['backbone']['pretrained2d'] = False

    # Define the multimodal model with a fused prediction head.
    multi_model = multimodal_module.DynamicMultimodalModule(
        support_modality_names=support_modalities,
        multimodal_nets_configs=Config.multimodal_nets_configs,
        is_fused_head=True)

    trainer = basic.Trainer(model=multi_model)
    client = simple.Client(model=multi_model,
                           datasource=kinetics_data_source,
                           trainer=trainer)
    server = fedavg.Server(model=multi_model, trainer=trainer)

    server.run(client)
def main():
    """A Plato federated learning training session using Adaptive Synchronization Frequency.

    Wires a basic trainer into the adaptive-sync algorithm and client, then
    runs them against a standard FedAvg server.
    """
    session_trainer = basic.Trainer()
    sync_algorithm = adaptive_sync_algorithm.Algorithm(trainer=session_trainer)
    fl_client = adaptive_sync_client.Client(algorithm=sync_algorithm,
                                            trainer=session_trainer)
    fl_server = fedavg.Server(algorithm=sync_algorithm, trainer=session_trainer)
    fl_server.run(fl_client)
def main():
    """A Plato federated learning training session using a custom model.

    Builds a minimal linear classifier over flattened 28x28 inputs and runs
    a FedAvg server with a single simple client.
    """
    # Minimal model: flatten the 28x28 input, then one linear layer to 10 classes.
    custom_model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 10),
    )

    custom_datasource = DataSource()
    custom_trainer = Trainer(model=custom_model)

    fl_client = simple.Client(model=custom_model,
                              datasource=custom_datasource,
                              trainer=custom_trainer)
    fl_server = fedavg.Server(model=custom_model, trainer=custom_trainer)
    fl_server.run(fl_client)
async def test_fedavg_aggregation(self):
    """Test that federated averaging of four client updates produces the
    expected aggregated model output.

    Builds four weight snapshots: the initial weights, then three snapshots
    taken after successive SGD steps on the same (example, label) pair.
    Each snapshot is packaged as a client update and fed to the server's
    federated averaging; the aggregated weights are then loaded back and the
    model output is checked.
    """
    print("\nTesting federated averaging.")
    updates = []

    # Stand up a FedAvg server around a deep copy of the test model so the
    # server's aggregation does not share parameters with self.model.
    model = copy.deepcopy(self.model)
    server = fedavg_server.Server(model=model)
    trainer = basic.Trainer(model=model)
    algorithm = fedavg_alg.Algorithm(trainer=trainer)
    server.trainer = trainer
    server.algorithm = algorithm

    # Update 1: the model's initial weights, before any training step.
    weights = copy.deepcopy(self.algorithm.extract_weights())
    print(f"Report 1 weights: {weights}")
    updates.append((simple.Report(1, 100, 0, 0, 0), weights, 0))

    # One SGD step, then snapshot for update 2. The exact expected outputs
    # (44.0, 43.2, 42.56) depend on the fixture's model, optimizer settings,
    # example and label defined elsewhere in the test class.
    self.model.train()
    self.optimizer.zero_grad()
    self.model.loss_criterion(self.model(self.example), self.label).backward()
    self.optimizer.step()
    self.assertEqual(44.0, self.model(self.example).item())
    weights = copy.deepcopy(self.algorithm.extract_weights())
    print(f"Report 2 weights: {weights}")
    updates.append((simple.Report(1, 100, 0, 0, 0), weights, 0))

    # Second SGD step, snapshot for update 3.
    self.optimizer.zero_grad()
    self.model.loss_criterion(self.model(self.example), self.label).backward()
    self.optimizer.step()
    self.assertEqual(43.2, np.round(self.model(self.example).item(), 4))
    weights = copy.deepcopy(self.algorithm.extract_weights())
    print(f"Report 3 Weights: {weights}")
    updates.append((simple.Report(1, 100, 0, 0, 0), weights, 0))

    # Third SGD step, snapshot for update 4.
    self.optimizer.zero_grad()
    self.model.loss_criterion(self.model(self.example), self.label).backward()
    self.optimizer.step()
    self.assertEqual(42.56, np.round(self.model(self.example).item(), 4))
    weights = copy.deepcopy(self.algorithm.extract_weights())
    print(f"Report 4 Weights: {weights}")
    updates.append((simple.Report(1, 100, 0, 0, 0), weights, 0))

    # Aggregate the four updates and load the averaged weights into the
    # server's model.
    print(
        f"Weights before federated averaging: {server.model.layer.weight.data}"
    )
    update = await server.federated_averaging(updates)
    updated_weights = server.algorithm.update_weights(update)
    server.algorithm.load_weights(updated_weights)
    print(
        f"Weights after federated averaging: {server.model.layer.weight.data}")

    # NOTE(review): this asserts on self.model (the client-side model), which
    # already evaluated to 42.56 before aggregation — presumably the intent
    # was to check server.model after loading the averaged weights; confirm.
    self.assertEqual(42.56, np.round(self.model(self.example).item(), 4))
def main():
    """A Plato federated learning training session using the Sub-FedAvg algorithm.

    Pairs the Sub-FedAvg client with its pruning-aware trainer and a plain
    FedAvg server.
    """
    pruning_trainer = subfedavg_trainer.Trainer()
    fl_client = subfedavg_client.Client(trainer=pruning_trainer)
    fl_server = fedavg.Server()
    fl_server.run(fl_client)
def main():
    """A Plato federated learning training session using FedAsync.

    NOTE(review): the client trains with the FedProx trainer (proximal-term
    regularization) inside a FedAsync session — presumably intentional to
    stabilize asynchronous aggregation; confirm against the session config.
    """
    prox_trainer = fedprox_trainer.Trainer()
    fl_client = simple.Client(trainer=prox_trainer)
    fl_server = fedavg.Server()
    fl_server.run(fl_client)