Beispiel #1
0
def main():
    """A Plato federated learning training session using the HGB algorithm.

    Builds a ``DynamicMultimodalModule`` over the RGB, flow and audio
    modalities (with a fused head). Every configured modality backbone is set
    to start without pretrained weights. The model is then run with a basic
    trainer, a simple client on the Kinetics data source and a FedAvg server.
    """
    # The Config instance itself is discarded; presumably constructing it
    # loads the run configuration that Config.multimodal_nets_configs exposes
    # below — confirm.
    _ = Config()
    support_modalities = ['rgb', "flow", "audio"]

    kinetics_data_source = kinetics.DataSource()

    # Define the sub-nets to be trained from scratch. Drop any pretrained
    # backbone checkpoint. Clear the 'pretrained2d' flag where the backbone
    # config has that key. Modalities without a "<name>_model" entry in the
    # config are skipped.
    for modality_nm in support_modalities:
        modality_model_nm = modality_nm + "_model"
        if modality_model_nm not in Config.multimodal_nets_configs:
            continue
        modality_net = Config.multimodal_nets_configs[modality_model_nm]
        backbone_config = modality_net['backbone']
        backbone_config['pretrained'] = None
        if "pretrained2d" in backbone_config:
            backbone_config['pretrained2d'] = False

    # define the model
    multi_model = multimodal_module.DynamicMultimodalModule(
        support_modality_names=support_modalities,
        multimodal_nets_configs=Config.multimodal_nets_configs,
        is_fused_head=True)

    trainer = basic.Trainer(model=multi_model)

    client = simple.Client(model=multi_model,
                           datasource=kinetics_data_source,
                           trainer=trainer)

    server = fedavg.Server(model=multi_model, trainer=trainer)

    server.run(client)
Beispiel #2
0
def main():
    """Run a Plato federated learning session with Adaptive Synchronization Frequency.

    One basic trainer instance is shared by the adaptive-sync algorithm, the
    adaptive-sync client and the FedAvg server.
    """
    shared_trainer = basic.Trainer()
    sync_algorithm = adaptive_sync_algorithm.Algorithm(trainer=shared_trainer)
    fl_client = adaptive_sync_client.Client(trainer=shared_trainer,
                                            algorithm=sync_algorithm)
    fl_server = fedavg.Server(trainer=shared_trainer, algorithm=sync_algorithm)
    fl_server.run(fl_client)
Beispiel #3
0
def main():
    """Run a Plato federated learning session with a custom model.

    The model flattens each input to 28 * 28 = 784 features and maps them to
    10 outputs with a single linear layer. The same model instance is shared
    by the trainer, the client and the server.
    """
    input_features = 28 * 28
    custom_model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(input_features, 10),
    )

    data_source = DataSource()
    custom_trainer = Trainer(model=custom_model)

    fl_client = simple.Client(model=custom_model,
                              datasource=data_source,
                              trainer=custom_trainer)
    fl_server = fedavg.Server(model=custom_model, trainer=custom_trainer)
    fl_server.run(fl_client)
Beispiel #4
0
async def test_fedavg_aggregation(self):
    """Check that the FedAvg server can aggregate four recorded weight snapshots.

    Snapshots of ``self.model``'s weights are recorded before training and
    after each of three optimizer steps. The expected model output is asserted
    after every step. The server, which owns a deep copy of the original
    model, then averages the snapshots and loads the result.
    """
    print("\nTesting federated averaging.")
    updates = []

    # The server works on its own deep copy of the model, so training
    # self.model below does not move the server's weights.
    server_model = copy.deepcopy(self.model)
    server = fedavg_server.Server(model=server_model)
    server_trainer = basic.Trainer(model=server_model)
    server_algorithm = fedavg_alg.Algorithm(trainer=server_trainer)
    server.trainer = server_trainer
    server.algorithm = server_algorithm

    def snapshot(label):
        """Record a deep copy of self.model's current weights as one client update."""
        snapshot_weights = copy.deepcopy(self.algorithm.extract_weights())
        print(f"{label}: {snapshot_weights}")
        updates.append((simple.Report(1, 100, 0, 0, 0), snapshot_weights, 0))

    def train_one_step():
        """Run one forward/backward pass and one optimizer step on self.model."""
        self.optimizer.zero_grad()
        loss = self.model.loss_criterion(self.model(self.example), self.label)
        loss.backward()
        self.optimizer.step()

    def rounded_output():
        """Return self.model's output on the example, rounded to 4 decimals."""
        return np.round(self.model(self.example).item(), 4)

    snapshot("Report 1 weights")

    self.model.train()

    train_one_step()
    self.assertEqual(44.0, self.model(self.example).item())
    snapshot("Report 2 weights")

    train_one_step()
    self.assertEqual(43.2, rounded_output())
    snapshot("Report 3 Weights")

    train_one_step()
    self.assertEqual(42.56, rounded_output())
    snapshot("Report 4 Weights")

    print(
        f"Weights before federated averaging: {server.model.layer.weight.data}"
    )

    averaged_update = await server.federated_averaging(updates)
    new_weights = server.algorithm.update_weights(averaged_update)
    server.algorithm.load_weights(new_weights)

    print(
        f"Weights after federated averaging: {server.model.layer.weight.data}")
    # NOTE(review): this evaluates self.model. The averaging above presumably
    # only updates the server's deep copy, so this just repeats the previous
    # assertion. Asserting on server.model's output would actually test the
    # aggregated weights — confirm the expected value before changing it.
    self.assertEqual(42.56, rounded_output())
Beispiel #5
0
def main():
    """Run a Plato federated learning session with the Sub-FedAvg algorithm.

    Only the client side is customised (Sub-FedAvg trainer and client). The
    server is the stock FedAvg server.
    """
    sub_trainer = subfedavg_trainer.Trainer()
    sub_client = subfedavg_client.Client(trainer=sub_trainer)
    fl_server = fedavg.Server()
    fl_server.run(sub_client)
Beispiel #6
0
def main():
    """A Plato federated learning training session using FedProx.

    The client trains with the FedProx trainer (``fedprox_trainer.Trainer``).
    Aggregation is done by the standard FedAvg server. (The previous docstring
    said "FedAsync", which this code does not set up.)
    """
    trainer = fedprox_trainer.Trainer()
    client = simple.Client(trainer=trainer)
    server = fedavg.Server()
    server.run(client)