def unitary_training(
    model_name: str,
    xy_train: FederatedDatasetPartition,
    xy_val: FederatedDatasetPartition,
    xy_test: FederatedDatasetPartition,
    E: int,
    B: int,
) -> Tuple[KerasHistory, float, float]:
    """Train a single (non-federated) participant and evaluate it on a test set.

    Args:
        model_name: Key used to look up the model construction function.
        xy_train: Training data partition for the single participant.
        xy_val: Validation data partition.
        xy_test: Test data partition used for the final evaluation.
        E: Number of local training epochs.
        B: Batch size.

    Returns:
        Tuple of (training history, final test loss, final test accuracy).
    """
    model_fn = load_model_fn(model_name)
    model_provider = ModelProvider(model_fn=model_fn)

    # Initialize model and participant (cid 0: only one participant exists)
    cid = 0
    participant = Participant(
        cid,
        model_provider,
        xy_train=xy_train,
        xy_val=xy_val,
        num_classes=10,
        batch_size=B,
    )
    model = model_provider.init_model()

    # Train model
    hist = participant.fit(model, E)

    # Evaluate final performance on the trained weights.
    # (The original code also read the weights *before* training into a local
    # that was never used; that dead assignment has been removed.)
    theta = model.get_weights()
    loss, acc = participant.evaluate(theta, xy_test)

    # Report results
    return hist, loss, acc
def federated_training(
    model_name: str,
    xy_train_partitions: List[Partition],
    xy_val: Partition,
    xy_test: Partition,
    R: int,
    E: int,
    C: float,
    B: int,
    aggregator: Optional[Aggregator] = None,
) -> Tuple[History, List[List[History]], List[List[Dict]], List[List[Metrics]], float, float]:
    """Run R rounds of federated training over the given partitions.

    Args:
        model_name: Key used to look up the model and learning-rate functions.
        xy_train_partitions: One training partition per participant.
        xy_val: Shared validation partition.
        xy_test: Test partition used for the final evaluation.
        R: Number of federated training rounds.
        E: Number of local epochs per round.
        C: Fraction of participants selected each round.
        B: Batch size.
        aggregator: Weight aggregation strategy; the coordinator falls back to
            its default when None.

    Returns:
        Tuple of (coordinator history, per-round participant histories,
        per-round optimizer configs, per-round metrics, test loss, test acc).
    """
    # Initialize participants and coordinator.
    # Note that there is no need for common initialization at this point: Common
    # initialization will happen during the first few rounds because the coordinator will
    # push its own weight to the respective participants of each training round.
    model_fn = load_model_fn(model_name)
    lr_fn_fn = load_lr_fn_fn(model_name)
    model_provider = ModelProvider(model_fn=model_fn, lr_fn_fn=lr_fn_fn)

    # Init participants: one per training partition, cid = partition index
    participants = [
        Participant(cid, model_provider, xy_train, xy_val, num_classes=10, batch_size=B)
        for cid, xy_train in enumerate(xy_train_partitions)
    ]
    num_participants = len(participants)

    # Init coordinator
    controller = RandomController(num_participants)
    coordinator = Coordinator(
        controller,
        model_provider,
        participants,
        C=C,
        E=E,
        xy_val=xy_val,
        aggregator=aggregator,
    )

    # Train model
    hist_co, hist_ps, hist_opt_configs, hist_metrics = coordinator.fit(num_rounds=R)

    # Evaluate final performance
    loss, acc = coordinator.evaluate(xy_test)

    # Report results
    return hist_co, hist_ps, hist_opt_configs, hist_metrics, loss, acc
def __init__(
    self,
    controller,
    model_provider: ModelProvider,
    participants: List[Participant],
    C: float,
    E: int,
    xy_val: Partition,
    aggregator: Optional[Aggregator] = None,
) -> None:
    """Set up the coordinator's state for federated training rounds.

    Args:
        controller: Selects which participants take part in each round.
        model_provider: Factory for the coordinator's global model.
        participants: All participants available for training.
        C: Fraction of participants selected per round.
        E: Number of local epochs per round.
        xy_val: Validation data partition.
        aggregator: Weight aggregation strategy; defaults to federated
            averaging when not provided.
    """
    self.controller = controller
    self.participants = participants
    self.C = C
    self.E = E
    self.xy_val = xy_val
    # Fall back to plain federated averaging when no aggregator is supplied
    self.aggregator = aggregator or FederatedAveragingAgg()
    self.model = model_provider.init_model()
    self.epoch = 0  # Count training epochs