Exemple #1
0
async def fit_model_on_worker(
    worker,
    built_model: sy.Plan,
    built_loss_fn: sy.Plan,
    encrypters,
    batch_size: int,
    curr_round: int,
    max_nr_batches: int,
    lr: float,
):
    """Run one training round on a remote worker and collect its results.

    The model and loss plans are tagged with fixed ids, bundled into a
    ``sy.ModelConfig`` (SGD optimizer, one epoch) and passed to
    ``worker.async_fit2_sagg_mc`` together with the list of result ids that
    should be returned.

    Args:
        worker: Remote worker; must expose ``id`` and the coroutine
            ``async_fit2_sagg_mc``.
        built_model: Model plan to train. Its ``id`` attribute is overwritten
            with ``"GlobalModel_MNIST"``.
        built_loss_fn: Loss plan. Its ``id`` attribute is overwritten with
            ``"LossFunc"``.
        encrypters: Forwarded unchanged to ``worker.async_fit2_sagg_mc``.
            NOTE(review): presumably the parties used to encrypt parameters
            for secure aggregation -- confirm against the worker code.
        batch_size: Batch size of each training step.
        curr_round: Index of the current training round (logging only).
        max_nr_batches: Forwarded to the model config; if > 0, training on
            the worker is expected to stop at
            min(max_nr_batches, nr_available_batches).
        lr: Learning rate passed to the optimizer as ``{"lr": lr}``.

    Returns:
        A 4-tuple ``(worker_id, enc_params, loss, num_of_training_data)``:
            * worker_id: ``worker.id``.
            * enc_params: Every result after the first two, i.e. the values
              returned for the requested ids ``"p0"``, ``"p1"``, ...
            * loss: First element of the worker's result list.
            * num_of_training_data: Second element of the worker's result
              list.
    """
    num_of_parameters = len(built_model.parameters())
    built_model.id = "GlobalModel_MNIST"
    built_loss_fn.id = "LossFunc"
    model_config = sy.ModelConfig(
        model=built_model,
        loss_fn=built_loss_fn,
        optimizer="SGD",
        batch_size=batch_size,
        optimizer_args={"lr": lr},
        epochs=1,
        max_nr_batches=max_nr_batches,
    )

    # Ids 0 and 1 request the loss and the training-set size; "p<i>" requests
    # one entry per model parameter (see the unpacking below).
    return_ids = [0, 1] + ["p" + str(i) for i in range(num_of_parameters)]

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    fit_sagg_start = time.perf_counter()
    result_list = await worker.async_fit2_sagg_mc(
        model_config,
        dataset_key="mnist",
        encrypters=encrypters,
        return_ids=return_ids,
    )
    fit_sagg_duration = time.perf_counter() - fit_sagg_start
    print("[trace] FitSagg", "duration", worker.id, fit_sagg_duration)

    loss = result_list[0]
    num_of_training_data = result_list[1]
    enc_params = result_list[2:]

    print("Iteration %s: %s loss: %s" % (curr_round, worker.id, loss))

    return worker.id, enc_params, loss, num_of_training_data
Exemple #2
0
def send_model_to_worker(
    worker,
    built_model: sy.Plan,
):
    """Send the model plan to a worker and print how long the send took.

    No training happens here; the function only tags the plan, calls
    ``built_model.send(worker)`` and emits a trace line.

    Args:
        worker: Destination handed to ``built_model.send``; must expose
            ``id``, which is included in the trace output.
        built_model: Plan to send. Its ``id`` attribute is overwritten with
            ``"GlobalModel"`` before sending.

    Returns:
        None.
    """
    built_model.id = "GlobalModel"

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    model_send_start = time.perf_counter()
    built_model.send(worker)
    print("[trace] GlobalModelSend duration", worker.id,
          time.perf_counter() - model_send_start)

    return None