import time

import syft as sy


async def fit_model_on_worker(
    worker,
    built_model: sy.Plan,
    built_loss_fn: sy.Plan,
    encrypters,
    batch_size: int,
    curr_round: int,
    max_nr_batches: int,
    lr: float,
):
    """Send the model to the worker and fit it on the worker's training data.

    Args:
        worker: Remote location where the model shall be trained.
        built_model: Model (syft Plan) which shall be trained.
        built_loss_fn: Loss function (syft Plan) used during training.
        encrypters: Encrypters used to secure the parameters returned by the worker.
        batch_size: Batch size of each training step.
        curr_round: Index of the current training round (for logging purposes).
        max_nr_batches: If > 0, training on the worker stops at
            min(max_nr_batches, nr_available_batches).
        lr: Learning rate of each training step.

    Returns:
        A tuple containing:
            * worker_id: Union[int, str], id of the worker.
            * enc_params: encrypted model parameters after training at the worker.
            * loss: loss on the last training batch, torch.Tensor.
            * num_of_training_data: number of training samples used on the worker.
    """
    num_of_parameters = len(built_model.parameters())
    built_model.id = "GlobalModel_MNIST"
    built_loss_fn.id = "LossFunc"
    model_config = sy.ModelConfig(
        model=built_model,
        loss_fn=built_loss_fn,
        optimizer="SGD",
        batch_size=batch_size,
        optimizer_args={"lr": lr},
        epochs=1,
        max_nr_batches=max_nr_batches,
    )

    # model_config_send_start = time.time()
    # model_config.send(worker)
    # model_config_send_end = time.time()
    # print(
    #     "[trace] GlobalInformationSend duration",
    #     worker.id,
    #     model_config_send_end - model_config_send_start,
    # )

    # Ids 0 and 1 refer to the loss and the number of training samples;
    # "p0" .. "pN" refer to the (encrypted) model parameters.
    return_ids = [0, 1]
    for i in range(num_of_parameters):
        return_ids.append("p" + str(i))

    fit_sagg_start = time.time()
    result_list = await worker.async_fit2_sagg_mc(
        model_config,
        dataset_key="mnist",
        encrypters=encrypters,
        return_ids=return_ids,
    )
    fit_sagg_end = time.time()
    print("[trace] FitSagg", "duration", worker.id, fit_sagg_end - fit_sagg_start)

    loss = result_list[0]
    num_of_training_data = result_list[1]
    enc_params = result_list[2:]

    print("Iteration %s: %s loss: %s" % (curr_round, worker.id, loss))

    return worker.id, enc_params, loss, num_of_training_data
def send_model_to_worker(
    worker,
    built_model: sy.Plan,
):
    """Send the global model to the worker.

    Args:
        worker: Remote location to which the model shall be sent.
        built_model: Model (syft Plan) which shall be sent to the worker.

    Returns:
        None.
    """
    built_model.id = "GlobalModel"
    model_send_start = time.time()
    built_model.send(worker)
    print(
        "[trace] GlobalModelSend duration",
        worker.id,
        time.time() - model_send_start,
    )
    return None
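
# --- Usage sketch (illustrative only) --------------------------------------
# A minimal sketch of how the two helpers above might be driven in a single
# federated round. Only fit_model_on_worker and send_model_to_worker come
# from this module; the worker objects, batch size, and learning rate are
# placeholder assumptions, and whether the model must be sent explicitly or
# travels inside the ModelConfig depends on the surrounding setup.
async def example_training_round(workers, built_model, built_loss_fn, encrypters, curr_round):
    import asyncio  # local import so this sketch stays self-contained

    # Train on all workers in parallel; each coroutine returns
    # (worker_id, enc_params, loss, num_of_training_data).
    results = await asyncio.gather(
        *[
            fit_model_on_worker(
                worker=worker,
                built_model=built_model,
                built_loss_fn=built_loss_fn,
                encrypters=encrypters,
                batch_size=32,      # placeholder
                curr_round=curr_round,
                max_nr_batches=-1,  # <= 0: train on all available batches
                lr=0.1,             # placeholder
            )
            for worker in workers
        ]
    )
    # The encrypted parameters would typically be combined by a secure
    # aggregation step before the next round's global model is broadcast
    # (e.g. via send_model_to_worker).
    return results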