def train(self, model: nn.Module, train_data: DataContainer, context: FederatedLearning.Context,
          config: TrainerParams) -> Tuple[dict, int]:
    """Train ``model`` locally on ``train_data`` and return its updated weights.

    Runs ``config.epochs`` passes over ``train_data``, drawing mini-batches of
    ``config.batch_size`` samples, with the optimizer and criterion built from
    ``config``.

    Args:
        model: The model to train. It is moved to ``self.device`` for training
            and moved back to the CPU before its weights are returned.
        train_data: The local dataset. It must provide ``batch(size)``, which
            yields ``(x, labels)`` pairs, and ``len()``.
        context: The federated-learning context. This trainer does not use it.
        config: Trainer hyper-parameters: epochs, batch size, optimizer
            factory and criterion.

    Returns:
        A ``(state_dict, sample_count)`` tuple: the model's ``state_dict()``
        after moving it to the CPU, and ``len(train_data)``.
    """
    model.to(self.device)
    model.train()
    # get_optimizer() returns a factory that binds the optimizer to this model.
    optimizer = config.get_optimizer()(model)
    criterion = config.get_criterion()
    for _ in range(config.epochs):
        for x, labels in train_data.batch(config.batch_size):
            x = x.to(self.device)
            labels = labels.to(self.device)
            optimizer.zero_grad()
            # Raw model outputs. The configured criterion (e.g. cross-entropy)
            # consumes them directly.
            outputs = model(x)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    # NOTE(review): state_dict() tensors share storage with the model's
    # parameters. If the caller reuses this model instance for another client,
    # the weights returned here change too. Confirm the caller passes a fresh
    # model or copies the result.
    weights = model.cpu().state_dict()
    return weights, len(train_data)
def _create(self, trainer_id, config: TrainerParams) -> Trainer:
    """Build a new trainer for ``trainer_id`` and register it.

    The trainer is constructed by calling ``config.trainer_class()`` with no
    arguments. It is stored in ``self.trainers`` under ``trainer_id``,
    replacing any trainer previously kept under that key.

    Returns:
        The newly constructed trainer instance.
    """
    # Chained assignment: the registry slot is written first, then the local.
    self.trainers[trainer_id] = created = config.trainer_class()
    return created
from src.federated.federated import Events from src.federated.federated import FederatedLearning from src.federated.protocols import TrainerParams from src.federated.components.trainer_manager import SeqTrainerManager, SharedTrainerProvider from src.federated.subscribers import Timer logging.basicConfig(level=logging.INFO) logger = logging.getLogger('main') logger.info('Generating Data --Started') client_data = data_loader.cifar10_10shards_100c_400min_400max() logger.info('Generating Data --Ended') trainer_params = TrainerParams(trainer_class=trainers.TorchTrainer, batch_size=50, epochs=1, optimizer='sgd', criterion='cel', lr=0.1) federated = FederatedLearning( trainer_manager=SeqTrainerManager(), trainer_config=trainer_params, aggregator=aggregators.AVGAggregator(), metrics=metrics.AccLoss(batch_size=50, criterion=nn.CrossEntropyLoss()), client_selector=client_selectors.Random(0.2), trainers_data_dict=client_data, initial_model=lambda: resnet56(10, 3, 32), num_rounds=50, desired_accuracy=0.99, ) federated.add_subscriber(