async def fit_model_on_worker(
    worker,
    built_model: sy.Plan,
    built_loss_fn: sy.Plan,
    encrypters,
    batch_size: int,
    curr_round: int,
    max_nr_batches: int,
    lr: float,
):
    """Send the model to the worker and fit it on the worker's training data.

    Args:
        worker: Remote location where the model shall be trained.
        built_model: Built model plan which shall be trained.
        built_loss_fn: Built loss-function plan used during training.
        encrypters: Parties forwarded to the worker's secure-aggregation fit
            call (used to secret-share the trained parameters).
        batch_size: Batch size of each training step.
        curr_round: Index of the current training round (for logging purposes).
        max_nr_batches: If > 0, training on worker will stop at
            min(max_nr_batches, nr_available_batches).
        lr: Learning rate of each training step.

    Returns:
        A tuple containing:
            * worker_id: Union[int, str], id of the worker.
            * enc_params: encrypted (secret-shared) model parameters after training.
            * loss: loss on the last training batch.
            * num_of_training_data: number of samples used for training on the worker.
    """
    num_of_parameters = len(built_model.parameters())
    built_model.id = "GlobalModel_MNIST"
    built_loss_fn.id = "LossFunc"
    model_config = sy.ModelConfig(
        model=built_model,
        loss_fn=built_loss_fn,
        optimizer="SGD",
        batch_size=batch_size,
        optimizer_args={"lr": lr},
        epochs=1,
        max_nr_batches=max_nr_batches,
    )

    # Slots 0 and 1 carry loss and dataset size; "p<i>" slots carry the
    # encrypted parameters.
    return_ids = [0, 1] + ["p" + str(i) for i in range(num_of_parameters)]

    fit_sagg_start = time.time()
    result_list = await worker.async_fit2_sagg_mc(
        model_config,
        dataset_key="mnist",
        encrypters=encrypters,
        return_ids=return_ids,
    )
    fit_sagg_end = time.time()
    print("[trace] FitSagg", "duration", worker.id, fit_sagg_end - fit_sagg_start)

    loss = result_list[0]
    num_of_training_data = result_list[1]
    enc_params = result_list[2:]

    print("Iteration %s: %s loss: %s" % (curr_round, worker.id, loss))

    return worker.id, enc_params, loss, num_of_training_data
def evaluate_model_on_worker(model_identifier, worker, dataset_key, model, built_loss_fn, nr_bins, batch_size, device, print_target_hist=False):
    """Evaluate the given model on the worker's dataset and log loss/accuracy.

    Sends an evaluation-only ModelConfig to the worker, runs the remote
    evaluation, and logs the prediction-histogram breakdown, average loss and
    accuracy under ``model_identifier``.
    """
    eval_config = sy.ModelConfig(
        batch_size=batch_size,
        model=model,
        loss_fn=built_loss_fn,
        optimizer_args=None,
        epochs=1,
    )
    eval_config.send(worker)

    eval_result = worker.evaluate_mc(
        dataset_key=dataset_key,
        return_histograms=True,
        nr_bins=nr_bins,
        return_loss=True,
        return_raw_accuracy=True,
        device=device,
    )
    avg_loss = eval_result["loss"]
    nr_correct = eval_result["nr_correct_predictions"]
    nr_total = eval_result["nr_predictions"]
    pred_hist = eval_result["histogram_predictions"]
    target_hist = eval_result["histogram_target"]

    if print_target_hist:
        logger.info("Target histogram: %s", target_hist)

    # Share of predictions falling into the digit groups 0-3, 4-6 and 7-9.
    share_0_3 = int(100 * sum(pred_hist[0:4]) / nr_total)
    share_4_6 = int(100 * sum(pred_hist[4:7]) / nr_total)
    share_7_9 = int(100 * sum(pred_hist[7:10]) / nr_total)
    logger.info(
        "%s: Percentage numbers 0-3: %s%%, 4-6: %s%%, 7-9: %s%%",
        model_identifier,
        share_0_3,
        share_4_6,
        share_7_9,
    )
    logger.info(
        "%s: Average loss: %s, Accuracy: %s/%s (%s%%)",
        model_identifier,
        f"{avg_loss:.4f}",
        nr_correct,
        nr_total,
        f"{100.0 * nr_correct / nr_total:.2f}",
    )
    print("[trace]", "TestAccuracy", "acc", "testing", 100.0 * nr_correct / nr_total)
async def main():
    """Connect to the remote workers, secret-share the model, and aggregate shares.

    Builds the model and loss plans, ships the training configuration to every
    worker, has each worker secret-share its model parameters with the others,
    and then sums all received shares into one aggregated encrypted model,
    printing [PROF] timing traces along the way.
    """
    hook = sy.TorchHook(torch)
    device = torch.device("cpu")
    optimizer = "SGD"
    epochs = 1
    shuffle = True

    model = Net()
    # Build the plan with a dummy MNIST-shaped batch (N, C, H, W).
    model.build(torch.zeros([1, 1, 28, 28], dtype=torch.float).to(device))

    @sy.func2plan(args_shape=[(-1, 1), (-1, 1)])
    def loss_fn(target, pred):
        # Mean squared error between target and prediction.
        return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()

    batch_size = 64
    lr = 0.1
    model_config = sy.ModelConfig(
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
        batch_size=batch_size,
        optimizer_args={"lr": lr},
        epochs=epochs,
        shuffle=shuffle,
    )

    # Connect to the remote worker VMs and ship them the training config.
    worker_list = [
        NodeClient(hook, "ws://" + flvm_ip[i] + ":6666", id="flvm-" + str(i))
        for i in range(2, 2 + 12)
    ]
    for worker in worker_list:
        model_config.send(worker)

    # NOTE(review): `grid` is not referenced again, but PrivateGridNetwork
    # construction may register the workers with each other — kept for its
    # possible side effects; confirm before removing.
    grid = sy.PrivateGridNetwork(*worker_list)

    num_of_parameters = len(model.parameters())
    return_ids = ["p" + str(i) for i in range(num_of_parameters)]

    start = time.time()
    # Every worker secret-shares its parameters with all other workers.
    enc_results = await asyncio.gather(*[
        worker.async_model_share(worker_list, return_ids=return_ids)
        for worker in worker_list
    ])
    end = time.time()

    ## aggregation: accumulate every worker's shares into the first worker's copy
    dst_enc_model = enc_results[0]
    agg_start = time.time()
    with torch.no_grad():
        for i in range(len(dst_enc_model)):
            layer_start = time.time()
            for j in range(1, len(enc_results)):
                add_start = time.time()
                dst_enc_model[i] += enc_results[j][i]
                print("[PROF]", "AddParams", time.time() - add_start)
            print("[PROF]", "Layer" + str(i), time.time() - layer_start)
    print("[PROF]", "AggTime", time.time() - agg_start)