Ejemplo n.º 1
0
async def main():
    """Train a model with federated averaging over three websocket workers.

    Connects to the workers ``alice``, ``bob`` and ``charlie`` (used for
    training) and ``testing`` (used for evaluation) on ports 8777-8780 of host
    ``0.0.0.0``. For each of ``args.training_rounds`` rounds it:

    1. fits the current traced model on all training workers concurrently;
    2. on rounds 1, 11, 21, ... and on the final round, evaluates every
       returned worker model on the ``testing`` worker;
    3. averages the returned models with ``utils.federated_avg``;
    4. on those same rounds, evaluates the averaged model;
    5. multiplies the learning rate by 0.98, never going below 1 % of
       ``args.lr``.

    If ``args.save_model`` is set, the state dict of the final federated model
    is written to ``mnist_cnn.pt``.
    """
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    kwargs_websocket = {"hook": hook, "verbose": args.verbose, "host": "0.0.0.0"}
    alice = websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
    bob = websocket_client.WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
    charlie = websocket_client.WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)
    testing = websocket_client.WebsocketClientWorker(id="testing", port=8780, **kwargs_websocket)

    # Ask every worker to clear its remote objects before training starts.
    for wcw in [alice, bob, charlie, testing]:
        wcw.clear_objects_remote()

    # Only these workers train; `testing` is reserved for evaluation.
    worker_instances = [alice, bob, charlie]

    use_cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    model = Net().to(device)

    # Trace with a single 1x28x28 float example so the model can be shipped
    # to the workers as a traced module.
    traced_model = torch.jit.trace(model, torch.zeros([1, 1, 28, 28], dtype=torch.float).to(device))
    learning_rate = args.lr

    for curr_round in range(1, args.training_rounds + 1):
        logger.info("Training round %s/%s", curr_round, args.training_rounds)

        # One (worker_id, model, loss) tuple per training worker.
        results = await asyncio.gather(
            *[
                fit_model_on_worker(
                    worker=worker,
                    traced_model=traced_model,
                    batch_size=args.batch_size,
                    curr_round=curr_round,
                    max_nr_batches=args.federate_after_n_batches,
                    lr=learning_rate,
                )
                for worker in worker_instances
            ]
        )
        models = {}
        loss_values = {}

        # Evaluate on rounds 1, 11, 21, ... and always on the last round.
        test_models = curr_round % 10 == 1 or curr_round == args.training_rounds
        if test_models:
            logger.info("Evaluating models")
            np.set_printoptions(formatter={"float": "{: .0f}".format})
            for worker_id, worker_model, _ in results:
                evaluate_model_on_worker(
                    model_identifier="Model update " + worker_id,
                    worker=testing,
                    dataset_key="mnist_testing",
                    model=worker_model,
                    nr_bins=10,
                    batch_size=128,
                    # Use the device chosen above from args.cuda (was
                    # `args.device`, which nothing in this function sets up).
                    device=device,
                    print_target_hist=False,
                )

        # Federate models (note that this will also change the model in models[0])
        for worker_id, worker_model, worker_loss in results:
            if worker_model is not None:
                models[worker_id] = worker_model
                loss_values[worker_id] = worker_loss

        traced_model = utils.federated_avg(models)

        if test_models:
            evaluate_model_on_worker(
                model_identifier="Federated model",
                worker=testing,
                dataset_key="mnist_testing",
                model=traced_model,
                nr_bins=10,
                batch_size=128,
                device=device,
                print_target_hist=False,
            )

        # decay learning rate
        learning_rate = max(0.98 * learning_rate, args.lr * 0.01)

    if args.save_model:
        # Save the federated result. The local `model` is only the initial
        # network used for tracing and is never updated by the training loop.
        torch.save(traced_model.state_dict(), "mnist_cnn.pt")
Ejemplo n.º 2
0
import torch

import syft as sy
from syft.workers import websocket_client as wsc

# Hook torch with PySyft; the tensor created below is later sent to a worker
# via `.send`.
hook = sy.TorchHook(torch)

data = torch.tensor([[0., 0.], [0., 2.]], requires_grad=True)

kwargs_websocket = {"hook": hook,
                    "verbose": True,
                    "host": "127.0.0.1"}

# Client for the websocket worker "mbp" at 127.0.0.1:8001.
mbp = wsc.WebsocketClientWorker(id="mbp",
                                port=8001,
                                **kwargs_websocket)

# List the worker's objects before and after sending the tensor.
print(mbp.list_objects_remote())
# Keep the returned pointer in a variable.
# NOTE(review): in PySyft 0.2.x a discarded pointer is expected to be
# garbage-collected at once, which asks the worker to delete the remote
# object, so the second listing would not show the tensor - confirm against
# the installed PySyft version.
data_ptr = data.send(mbp)
print(mbp.list_objects_remote())
Ejemplo n.º 3
0
async def main():
    """Run ten epochs of federated training over three local websocket workers.

    Workers ``alice``, ``bob`` and ``charlie`` (ports 8777-8779 on
    ``127.0.0.1``) train; worker ``testing`` (port 8780) evaluates against the
    ``dga_testing`` dataset. Every epoch fits the current traced model on all
    training workers concurrently, evaluates each returned model, averages
    them with ``utils.federated_avg``, evaluates the average, and finally
    multiplies the learning rate by 0.98 (never below 1 % of ``args.lr``).
    """
    args = define_and_get_arguments()
    hook = sy.TorchHook(torch)

    # Connection settings shared by all four workers.
    kwargs_websocket = dict(hook=hook, verbose=args.verbose, host="127.0.0.1")

    connect = websocket_client.WebsocketClientWorker
    alice = connect(id="alice", port=8777, **kwargs_websocket)
    bob = connect(id="bob", port=8778, **kwargs_websocket)
    charlie = connect(id="charlie", port=8779, **kwargs_websocket)
    testing = connect(id="testing", port=8780, **kwargs_websocket)

    # Training workers; `testing` is only used for evaluation.
    worker_instances = [alice, bob, charlie]

    for remote in worker_instances + [testing]:
        remote.clear_objects_remote()

    use_cuda = args.cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    model = Net().to(device)

    # The model is traced with a single example of 47 long-typed values.
    example_input = torch.zeros([1, 47], dtype=torch.long).to(device)
    traced_model = torch.jit.trace(model, example_input)

    learning_rate = args.lr
    total_epochs = 10

    for epoch in range(1, total_epochs + 1):
        logger.info("Training epoch %s/%s", epoch, total_epochs)

        # One coroutine per training worker, awaited together.
        fit_jobs = [
            fit_model_on_worker(
                worker=worker,
                traced_model=traced_model,
                batch_size=args.batch_size,
                epoch=epoch,
                max_nr_batches=-1,
                lr=learning_rate,
            )
            for worker in worker_instances
        ]
        results = await asyncio.gather(*fit_jobs)

        # Holds for every epoch in 1..10, so evaluation runs each epoch.
        test_models = 0 < epoch <= 10

        # Evaluation arguments common to per-worker and federated models.
        eval_settings = dict(
            worker=testing,
            dataset_key="dga_testing",
            nr_bins=2,
            batch_size=500,
            device=device,
            print_target_hist=False,
        )

        if test_models:
            logger.info("Evaluating models")
            np.set_printoptions(formatter={"float": "{: .0f}".format})
            for worker_id, worker_model, _ in results:
                evaluate_model_on_worker(
                    model_identifier="Model update " + worker_id,
                    model=worker_model,
                    **eval_settings
                )

        # Collect the models (and their losses) from workers that returned one.
        models = {}
        loss_values = {}
        for worker_id, worker_model, worker_loss in results:
            if worker_model is None:
                continue
            models[worker_id] = worker_model
            loss_values[worker_id] = worker_loss

        traced_model = utils.federated_avg(models)

        if test_models:
            evaluate_model_on_worker(
                model_identifier="Federated model",
                model=traced_model,
                **eval_settings
            )

        # Learning-rate decay with a floor at 1 % of the initial rate.
        learning_rate = max(0.98 * learning_rate, args.lr * 0.01)
Ejemplo n.º 4
0
async def main():
    """Run federated training with local or remote websocket workers.

    When ``args.localworkers`` is truthy the four workers (``alice``, ``bob``,
    ``charlie``, ``testing``) are reached on ports 8777-8780 of ``0.0.0.0``;
    otherwise each worker is reached on port 8777 of its own fixed host.

    If ``mnist_cnn_asyn.pt`` exists, its weights are loaded into the model
    before tracing. For each of ``args.training_rounds`` rounds the function:

    1. fits the current traced model on all training workers concurrently;
    2. on rounds 1, 11, 21, ... and on the final round, evaluates every
       returned worker model on the ``testing`` worker;
    3. averages the returned models with ``utils.federated_avg``;
    4. on those same rounds, evaluates the averaged model and saves its state
       dict to ``models_asyn/mnist_cnn_<round>.pt``;
    5. multiplies the learning rate by 0.98, never going below 1 % of
       ``args.lr``.

    If ``args.save_model`` is set, the final federated model's state dict is
    written to ``mnist_cnn_asyn.pt``.
    """
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    if args.localworkers:
        # Local workers: same host, one port per worker.
        kwargs_websocket = {
            "hook": hook,
            "verbose": args.verbose,
            "host": "0.0.0.0"
        }
        alice = websocket_client.WebsocketClientWorker(id="alice",
                                                       port=8777,
                                                       **kwargs_websocket)
        bob = websocket_client.WebsocketClientWorker(id="bob",
                                                     port=8778,
                                                     **kwargs_websocket)
        charlie = websocket_client.WebsocketClientWorker(id="charlie",
                                                         port=8779,
                                                         **kwargs_websocket)
        testing = websocket_client.WebsocketClientWorker(id="testing",
                                                         port=8780,
                                                         **kwargs_websocket)
    else:
        # Remote workers: one host per worker, all on port 8777.
        alice = websocket_client.WebsocketClientWorker(
            id="alice", port=8777, host="128.226.78.195", hook=hook)
        bob = websocket_client.WebsocketClientWorker(
            id="bob", port=8777, host="128.226.77.222", hook=hook)
        charlie = websocket_client.WebsocketClientWorker(
            id="charlie", port=8777, host="128.226.88.120", hook=hook)
        # Alternative testing host noted in the original: 128.226.77.111
        testing = websocket_client.WebsocketClientWorker(
            id="testing", port=8777, host="128.226.88.210", hook=hook)

    # Ask every worker to clear its remote objects before training starts.
    for wcw in [alice, bob, charlie, testing]:
        wcw.clear_objects_remote()

    # Only these workers train; `testing` is reserved for evaluation.
    worker_instances = [alice, bob, charlie]

    use_cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Same file is read here to resume and written at the end when
    # args.save_model is set.
    checkpoint_path = "mnist_cnn_asyn.pt"

    model = Net().to(device)
    if os.path.isfile(checkpoint_path):
        # map_location lets a checkpoint saved on GPU load on a CPU-only run
        # (and vice versa) instead of failing during deserialization.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model.eval()

    traced_model = torch.jit.trace(
        model,
        torch.zeros([1, 1, 28, 28], dtype=torch.float).to(device))
    learning_rate = args.lr

    # Execute the training and test process round by round.
    for curr_round in range(1, args.training_rounds + 1):
        logger.info("Training round %s/%s", curr_round, args.training_rounds)

        # One (worker_id, model, loss) tuple per training worker.
        results = await asyncio.gather(*[
            fit_model_on_worker(
                worker=worker,
                traced_model=traced_model,
                batch_size=args.batch_size,
                curr_round=curr_round,
                max_nr_batches=args.federate_after_n_batches,
                lr=learning_rate,
            ) for worker in worker_instances
        ])
        models = {}
        loss_values = {}

        # Evaluate on rounds 1, 11, 21, ... and always on the last round.
        test_models = curr_round % 10 == 1 or curr_round == args.training_rounds
        if test_models:
            logger.info("Evaluating models")
            np.set_printoptions(formatter={"float": "{: .0f}".format})
            for worker_id, worker_model, _ in results:
                evaluate_model_on_worker(
                    model_identifier="Model update " + worker_id,
                    worker=testing,
                    dataset_key="mnist_testing",
                    model=worker_model,
                    nr_bins=10,
                    batch_size=128,
                    device=device,
                    print_target_hist=False,
                )

        # Federate models (note that this will also change the model in models[0])
        for worker_id, worker_model, worker_loss in results:
            if worker_model is not None:
                models[worker_id] = worker_model
                loss_values[worker_id] = worker_loss

        traced_model = utils.federated_avg(models)

        if test_models:
            evaluate_model_on_worker(
                model_identifier="Federated model",
                worker=testing,
                dataset_key="mnist_testing",
                model=traced_model,
                nr_bins=10,
                batch_size=128,
                device=device,
                print_target_hist=False,
            )
            # Save the intermediate model; exist_ok avoids the
            # check-then-create race of a separate os.path.exists test.
            model_dir = "models_asyn"
            os.makedirs(model_dir, exist_ok=True)
            model_name = "{}/mnist_cnn_{}.pt".format(model_dir, curr_round)
            torch.save(traced_model.state_dict(), model_name)

        # decay learning rate
        learning_rate = max(0.98 * learning_rate, args.lr * 0.01)

    if args.save_model:
        torch.save(traced_model.state_dict(), checkpoint_path)