Example #1
def train(model, device, federated_train_loader, learning_rate, federate_after_n_batches):
    model.train()

    nr_batches = federate_after_n_batches

    models = {}
    loss_values = {}

    iter(federated_train_loader)  # initialize iterators
    batches = get_next_batches(federated_train_loader, nr_batches)
    counter = 0

    while True:
        logger.debug("Starting training round, batches [{}, {}]".format(
            counter, counter + nr_batches))
        data_for_all_workers = True
        for worker in batches:
            curr_batches = batches[worker]
            if curr_batches:
                models[worker], loss_values[worker] = train_on_batches(
                    worker, curr_batches, model, device, learning_rate)
            else:
                data_for_all_workers = False
        counter += nr_batches
        if not data_for_all_workers:
            logger.debug("At least one worker ran out of data, stopping.")
            break

        model = utils.federated_avg(models)
        batches = get_next_batches(federated_train_loader, nr_batches)
    return model
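
The two helpers assumed by Example #1 (and by the near-identical Example #8 below), get_next_batches and train_on_batches, are not shown. A minimal sketch of what they might look like, assuming a PySyft FederatedDataLoader that exposes its workers and yields one {worker: (data, target)} dict per step, and PySyft-hooked modules that support .copy(), .send(), and .get():

import torch.nn.functional as F
import torch.optim as optim


def get_next_batches(fdataloader, nr_batches):
    # Collect up to nr_batches batches per worker; stops early when the
    # loader is exhausted (the caller detects the resulting empty lists).
    batches = {worker: [] for worker in fdataloader.workers}
    try:
        for _ in range(nr_batches):
            next_batches = next(fdataloader)
            for worker, batch in next_batches.items():
                batches[worker].append(batch)
    except StopIteration:
        pass
    return batches


def train_on_batches(worker, batches, model_in, device, lr):
    # Train a copy of the model on one worker's batches and return the
    # updated model together with the last observed loss value.
    model = model_in.copy()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    model.train()
    model.send(worker)
    loss_value = None

    for data, target in batches:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        loss_value = loss.get().item()  # fetch the remote loss value locally

    model.get()  # bring the updated model back from the worker
    return model, loss_value
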
Example #2
def train():
    for data_index in range(len(remote_dataset[0])-1):
        for remote_index in range(len(compute_nodes)):
            data, target = remote_dataset[remote_index][data_index]
            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])
        for model in models:
            model.get()
        return utils.federated_avg({
            "sm1": models[0],
            "sm2": models[1]
        })
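
Examples #2, #3, and #4 all delegate the per-batch step to a helper (update, train_process, or update_model) and rely on module-level names such as remote_dataset, compute_nodes, models, and optimizers. A minimal sketch of such a helper, assuming data and target are PySyft pointer tensors whose .location identifies the owning worker (the mse loss is an assumption, not taken from the original):

import torch.nn.functional as F


def update(data, target, model, optimizer):
    # Move the model to the worker that holds this batch, take one
    # optimization step there, and return the (still remote) updated model.
    model.send(data.location)
    optimizer.zero_grad()
    prediction = model(data)
    loss = F.mse_loss(prediction.view(-1), target)
    loss.backward()
    optimizer.step()
    return model

Because the returned model is still located on the worker, the callers above invoke model.get() before averaging.
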
Example #3
    def train(self):
        for data_index in range(len(self.remote_dataset[0]) - 1):
            # update remote models
            for remote_index in range(len(self.compute_nodes)):
                data, target = self.remote_dataset[remote_index][data_index]
                self.models[remote_index] = train_process(data, target, self.models[remote_index], self.optimizers[remote_index])

            for model in self.models:
                model.get()
            return utils.federated_avg({
                "bob": self.models[0],
                "alice": self.models[1]
            })
Example #4
def train_on_devices(remote_dataset, devices, models, optimizers):
    # iterate through each worker's dataset separately
    for data_index in range(len(remote_dataset[0]) - 1):
        for device_index in range(len(devices)):
            data, target = remote_dataset[device_index][data_index]
            models[device_index] = update_model(data, target,
                                                models[device_index],
                                                optimizers[device_index])

        for model in models:
            model.get()

        return utils.federated_avg({'bob': models[0], 'alice': models[1]})
Example #5
def fed_avg_every_n_iters(model_pointers, iter, federate_after_n_batches):
    models_local = {}

    if iter % federate_after_n_batches == 0:
        for worker_name, model_pointer in model_pointers.items():
            # need to assign the model to the worker it belongs to
            models_local[worker_name] = model_pointer.copy().get()
        model_avg = utils.federated_avg(models_local)

        # workers_virtual is assumed to be the global list of virtual workers
        for worker in workers_virtual:
            model_copied_avg = model_avg.copy()
            model_ptr = model_copied_avg.send(worker)
            model_pointers[worker.id] = model_ptr

    return model_pointers
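
A hypothetical driver loop for Example #5; workers_virtual, model, total_iterations, and federate_after_n_batches are assumptions, not part of the original snippet:

model_pointers = {}
for worker in workers_virtual:
    # send one copy of the initial model to every virtual worker
    model_pointers[worker.id] = model.copy().send(worker)

for iteration in range(total_iterations):
    # ... one local training step per worker on its own data would go here ...
    model_pointers = fed_avg_every_n_iters(
        model_pointers, iteration, federate_after_n_batches)
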
Example #6
async def train(train_loader, valid_loader, model, architect, w_optim,
                alpha_optim, lr, epoch):
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    cur_step = epoch * len(train_loader)
    writer.add_scalar('train/lr', lr, cur_step)

    model.train()

    # remote_train_data and workers are assumed to be defined at module level
    for step in range(len(remote_train_data[0]) - 1):
        results = await asyncio.gather(*[
            update(step, i, model, alpha_optim, w_optim, architect, lr)
            for i in range(len(workers))
        ])

        models = {}

        for i, r in enumerate(results):
            models[r[0]] = r[1]
            losses.update(r[2].item(), r[5])
            top1.update(r[3].item(), r[5])
            top5.update(r[4].item(), r[5])

            writer.add_scalar('train/loss ' + str(i), r[2].item(), cur_step)
            writer.add_scalar('train/top1 ' + str(i), r[3].item(), cur_step)
            writer.add_scalar('train/top5 ' + str(i), r[4].item(), cur_step)

        model = federated_avg(models)

        if step % config.print_freq == 0 or step == len(train_loader) - 1:
            logger.info(
                "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                    epoch + 1,
                    config.epochs,
                    step,
                    len(train_loader) - 1,
                    losses=losses,
                    top1=top1,
                    top5=top5))

        cur_step += 1

    logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
        epoch + 1, config.epochs, top1.avg))
Example #7
def test_federated_avg():
    class Net(th.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = th.nn.Linear(2, 2)

    net1 = Net()
    net2 = Net()
    net3 = Net()

    models = {}
    models[0] = net1
    models[1] = net2
    models[2] = net3

    avg_model = utils.federated_avg(models)
    assert avg_model != net1
    assert (avg_model.fc1.weight.data != net1.fc1.weight.data).all()
    assert (avg_model.fc1.bias.data != net1.fc1.bias.data).all()
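
For reference, utils.federated_avg in these snippets returns a new model whose parameters are the element-wise mean of the input models' parameters, which is what the test above checks indirectly. A standalone, conceptual re-implementation (not the library code) might look like:

import copy

import torch


def naive_federated_avg(models):
    # models: dict mapping worker ids (or indices) to nn.Module objects that
    # share one architecture; returns a fresh model with averaged parameters.
    model_list = list(models.values())
    avg_model = copy.deepcopy(model_list[0])
    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            stacked = torch.stack(
                [dict(m.named_parameters())[name].detach() for m in model_list])
            param.copy_(stacked.mean(dim=0))
    return avg_model
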
Example #8
def train(model,
          device,
          federated_train_loader,
          lr,
          federate_after_n_batches,
          abort_after_one=False):
    model.train()

    nr_batches = federate_after_n_batches

    models = {}
    loss_values = {}

    iter(federated_train_loader)  # initialize iterators
    batches = get_next_batches(federated_train_loader, nr_batches)
    counter = 0

    while True:
        print(
            f"Starting training round, batches [{counter}, {counter + nr_batches}]"
        )
        data_for_all_workers = True
        for worker in batches:
            curr_batches = batches[worker]
            if curr_batches:
                models[worker], loss_values[worker] = train_on_batches(
                    worker, curr_batches, model, device, lr)
            else:
                data_for_all_workers = False
        counter += nr_batches
        if not data_for_all_workers:
            print("At least one worker ran out of data, stopping.")
            break

        model = utils.federated_avg(models)
        batches = get_next_batches(federated_train_loader, nr_batches)
        if abort_after_one:
            break
    return model
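
A sketch of how the train function from Example #8 might be driven, assuming PySyft virtual workers and a FederatedDataLoader set up to yield one batch per worker on every step (the iter_per_worker flag, dataset path, and hyperparameters are assumptions):

import torch
import syft as sy
from torchvision import datasets, transforms

hook = sy.TorchHook(torch)
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")

federated_train_loader = sy.FederatedDataLoader(
    datasets.MNIST(
        "../data", train=True, download=True,
        transform=transforms.ToTensor(),
    ).federate((alice, bob)),
    batch_size=64,
    shuffle=True,
    iter_per_worker=True,  # yield a {worker: (data, target)} dict per step
)

device = torch.device("cpu")
model = Net().to(device)  # Net as defined elsewhere in these examples
model = train(model, device, federated_train_loader, lr=0.1,
              federate_after_n_batches=50)
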
Example #9
def FederatedTrainer(model, device, fed_data_loader, lr, fed_after_n_batches):
    model.train()

    nr_batches = fed_after_n_batches
    data_for_all_workers = True
    batch_counter = 0
    models = {}
    losses = {}

    iter(fed_data_loader)
    batches = GetNextBatch(fed_data_loader, nr_batches)

    while True:
        mylogger.logger.debug(
            "Starting training round, batches [{}, {}]".format(
                batch_counter, batch_counter + nr_batches))

        for worker in batches:
            curr_batches = batches[worker]

            if curr_batches:
                models[worker], losses[worker] = TrainOnBatches(
                    worker, curr_batches, model, device, lr)
            else:
                data_for_all_workers = False

        batch_counter += nr_batches

        if not data_for_all_workers:
            mylogger.logger.debug("stopping.")

            break

        model = SyftUtils.federated_avg(models)
        batches = GetNextBatch(fed_data_loader, nr_batches)

    return model
Example #10
async def main():
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    kwargs_websocket = {"hook": hook, "verbose": args.verbose, "host": "0.0.0.0"}
    alice = websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
    bob = websocket_client.WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
    charlie = websocket_client.WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)
    testing = websocket_client.WebsocketClientWorker(id="testing", port=8780, **kwargs_websocket)

    for wcw in [alice, bob, charlie, testing]:
        wcw.clear_objects_remote()

    worker_instances = [alice, bob, charlie]

    use_cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    model = Net().to(device)

    traced_model = torch.jit.trace(model, torch.zeros([1, 1, 28, 28], dtype=torch.float).to(device))
    learning_rate = args.lr

    for curr_round in range(1, args.training_rounds + 1):
        logger.info("Training round %s/%s", curr_round, args.training_rounds)

        results = await asyncio.gather(
            *[
                fit_model_on_worker(
                    worker=worker,
                    traced_model=traced_model,
                    batch_size=args.batch_size,
                    curr_round=curr_round,
                    max_nr_batches=args.federate_after_n_batches,
                    lr=learning_rate,
                )
                for worker in worker_instances
            ]
        )
        models = {}
        loss_values = {}

        test_models = curr_round % 10 == 1 or curr_round == args.training_rounds
        if test_models:
            logger.info("Evaluating models")
            np.set_printoptions(formatter={"float": "{: .0f}".format})
            for worker_id, worker_model, _ in results:
                evaluate_model_on_worker(
                    model_identifier="Model update " + worker_id,
                    worker=testing,
                    dataset_key="mnist_testing",
                    model=worker_model,
                    nr_bins=10,
                    batch_size=128,
                    device=args.device,
                    print_target_hist=False,
                )

        # Federate models (note that this will also change the model in models[0])
        for worker_id, worker_model, worker_loss in results:
            if worker_model is not None:
                models[worker_id] = worker_model
                loss_values[worker_id] = worker_loss

        traced_model = utils.federated_avg(models)

        if test_models:
            evaluate_model_on_worker(
                model_identifier="Federated model",
                worker=testing,
                dataset_key="mnist_testing",
                model=traced_model,
                nr_bins=10,
                batch_size=128,
                device=args.device,
                print_target_hist=False,
            )

        # decay learning rate
        learning_rate = max(0.98 * learning_rate, args.lr * 0.01)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
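
Examples #10, #13, and #14 depend on fit_model_on_worker and evaluate_model_on_worker from the PySyft websocket examples, which are not reproduced here. As a rough, assumed sketch of the remote-fit half (the TrainConfig fields, the "mnist" dataset key, and the loss function are assumptions):

import torch
import torch.nn.functional as F
import syft as sy


@torch.jit.script
def loss_fn(pred, target):
    return F.nll_loss(input=pred, target=target)


async def fit_model_on_worker(worker, traced_model, batch_size, curr_round,
                              max_nr_batches, lr):
    # Ship a training configuration together with the traced model to the
    # worker, let it fit asynchronously on its local dataset, then pull the
    # updated model and the reported loss back.
    train_config = sy.TrainConfig(
        model=traced_model,
        loss_fn=loss_fn,
        batch_size=batch_size,
        shuffle=True,
        max_nr_batches=max_nr_batches,
        epochs=1,
        optimizer="SGD",
        optimizer_args={"lr": lr},
    )
    train_config.send(worker)
    loss = await worker.async_fit(dataset_key="mnist", return_ids=[0])
    model = train_config.model_ptr.get().obj
    return worker.id, model, loss
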
Example #11
    r = open("memory_cpu.txt", "a+")
    r.write("Iteration Number " + str(iteration) + '\n')
    r.write('physical memory use (in MB): ' + str(p.memory_info()[0] / 2.**20) +
            '\n')
    r.write('memory use (percent of total): ' + str(p.memory_percent()) + '\n')
    r.write('CPU utilization of this process (percent): ' +
            str(p.cpu_percent(interval=None)) + '\n')
    r.close()
    # Loading all the updated models received from the clients - we have four because we trained with 4 clients

    for i in range(int(0.7 * len(clients))):
        print(i)
        models[i] = pickle.load(open('model' + str(i) + '.sav', 'rb'))

    #doing the federated avg
    federated_model = utils.federated_avg(models)

    #Saving the global model to be sent again to the clients

    filename = 'initial_model.sav'
    pickle.dump(federated_model, open(filename, 'wb'))

#When all communication rounds end, the final model is sent to all the clients that are still connected
for client in clients:
    t = Thread(target=last_model, args=(client[0], client[1]))
    trds.append(t)
    t.start()

for tr in trds:
    tr.join()
Example #12
def async_train(model,
                device,
                federated_train_loader,
                lr,
                federate_after_n_batches,
                abort_after_one=False):
    model.train()

    nr_batches = federate_after_n_batches

    models = {}
    loss_values = {}
    # Create queue to save results
    ret_queue = queue.Queue()

    iter(federated_train_loader)  # initialize iterators
    batches = get_next_batches(federated_train_loader, nr_batches)
    counter = 0

    while True:
        logger.debug(
            f"Starting training round, batches [{counter}, {counter + nr_batches}]"
        )
        data_for_all_workers = True

        # Create thread pool
        threads_pool = []
        # clear queue
        ret_queue.queue.clear()

        # for each worker, start a training thread
        for worker in batches:
            curr_batches = batches[worker]
            if curr_batches:
                # Create new threads
                p_thread = TrainThread(
                    [ret_queue, [worker, curr_batches, model, device, lr]])

                # append to threads pool
                threads_pool.append(p_thread)

                # The start() method starts a thread by calling the run method.
                p_thread.start()
            else:
                data_for_all_workers = False

        # The join() waits for all threads to terminate.
        for p_thread in threads_pool:
            p_thread.join()

        # get all results from queue
        while not ret_queue.empty():
            q_data = ret_queue.get()
            models[q_data[0]] = q_data[1]
            loss_values[q_data[0]] = q_data[2]

        counter += nr_batches
        if not data_for_all_workers:
            logger.debug("At least one worker ran out of data, stopping.")
            break

        logger.info("Execute federated avg.")
        model = utils.federated_avg(models)

        logger.info("Get next batches.")
        batches = get_next_batches(federated_train_loader, nr_batches)
        if abort_after_one:
            break
    return model
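
Example #12 assumes a TrainThread class that runs train_on_batches for one worker in its own thread and pushes the result onto the shared queue. A minimal sketch consistent with how it is constructed and read back above:

import threading


class TrainThread(threading.Thread):
    # Constructed as TrainThread([ret_queue, [worker, batches, model, device, lr]])
    # and consumed as [worker, trained_model, loss_value] items from the queue.
    def __init__(self, args):
        super().__init__()
        self.ret_queue = args[0]
        self.train_args = args[1]

    def run(self):
        worker, curr_batches, model, device, lr = self.train_args
        trained_model, loss_value = train_on_batches(
            worker, curr_batches, model, device, lr)
        self.ret_queue.put([worker, trained_model, loss_value])
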
Example #13
async def main():
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    kwargs_websocket = {
        "hook": hook,
        "verbose": args.verbose,
        "host": "127.0.0.1"
    }
    alice = websocket_client.WebsocketClientWorker(id="alice",
                                                   port=8777,
                                                   **kwargs_websocket)
    bob = websocket_client.WebsocketClientWorker(id="bob",
                                                 port=8778,
                                                 **kwargs_websocket)
    charlie = websocket_client.WebsocketClientWorker(id="charlie",
                                                     port=8779,
                                                     **kwargs_websocket)
    testing = websocket_client.WebsocketClientWorker(id="testing",
                                                     port=8780,
                                                     **kwargs_websocket)

    for wcw in [alice, bob, charlie, testing]:
        wcw.clear_objects_remote()

    worker_instances = [alice, bob, charlie]

    use_cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    model = Net().to(device)

    traced_model = torch.jit.trace(
        model,
        torch.zeros([1, 47], dtype=torch.long).to(device))
    learning_rate = args.lr

    for epoch in range(1, 11):
        logger.info("Training epoch %s/%s", epoch, 10)

        results = await asyncio.gather(*[
            fit_model_on_worker(
                worker=worker,
                traced_model=traced_model,
                batch_size=args.batch_size,
                epoch=epoch,
                max_nr_batches=-1,
                lr=learning_rate,
            ) for worker in worker_instances
        ])
        models = {}
        loss_values = {}

        test_models = epoch > 0 and epoch <= 10
        if test_models:
            logger.info("Evaluating models")
            np.set_printoptions(formatter={"float": "{: .0f}".format})
            for worker_id, worker_model, _ in results:
                evaluate_model_on_worker(
                    model_identifier="Model update " + worker_id,
                    worker=testing,
                    dataset_key="dga_testing",
                    model=worker_model,
                    nr_bins=2,
                    batch_size=500,
                    device=device,
                    print_target_hist=False,
                )

        for worker_id, worker_model, worker_loss in results:
            if worker_model is not None:
                models[worker_id] = worker_model
                loss_values[worker_id] = worker_loss

        traced_model = utils.federated_avg(models)

        if test_models:
            evaluate_model_on_worker(
                model_identifier="Federated model",
                worker=testing,
                dataset_key="dga_testing",
                model=traced_model,
                nr_bins=2,
                batch_size=500,
                device=device,
                print_target_hist=False,
            )

        # decay learning rate
        learning_rate = max(0.98 * learning_rate, args.lr * 0.01)
Example #14
async def main():
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    if (args.localworkers):
        # ----------------------------- This is for localhost workers --------------------------------
        kwargs_websocket = {
            "hook": hook,
            "verbose": args.verbose,
            "host": "0.0.0.0"
        }
        alice = websocket_client.WebsocketClientWorker(id="alice",
                                                       port=8777,
                                                       **kwargs_websocket)
        bob = websocket_client.WebsocketClientWorker(id="bob",
                                                     port=8778,
                                                     **kwargs_websocket)
        charlie = websocket_client.WebsocketClientWorker(id="charlie",
                                                         port=8779,
                                                         **kwargs_websocket)
        testing = websocket_client.WebsocketClientWorker(id="testing",
                                                         port=8780,
                                                         **kwargs_websocket)
    else:
        # ----------------------------- This is for remote workers ------------------------------------
        kwargs_websocket_alice = {"host": "128.226.78.195", "hook": hook}
        alice = websocket_client.WebsocketClientWorker(
            id="alice", port=8777, **kwargs_websocket_alice)

        kwargs_websocket_bob = {"host": "128.226.77.222", "hook": hook}
        bob = websocket_client.WebsocketClientWorker(id="bob",
                                                     port=8777,
                                                     **kwargs_websocket_bob)

        kwargs_websocket_charlie = {"host": "128.226.88.120", "hook": hook}
        charlie = websocket_client.WebsocketClientWorker(
            id="charlie", port=8777, **kwargs_websocket_charlie)

        # kwargs_websocket_testing = {"host": "128.226.77.111", "hook": hook}
        kwargs_websocket_testing = {"host": "128.226.88.210", "hook": hook}
        testing = websocket_client.WebsocketClientWorker(
            id="testing", port=8777, **kwargs_websocket_testing)

    for wcw in [alice, bob, charlie, testing]:
        wcw.clear_objects_remote()

    worker_instances = [alice, bob, charlie]

    use_cuda = args.cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    model = Net().to(device)
    if (os.path.isfile('mnist_cnn_asyn.pt')):
        model.load_state_dict(torch.load("mnist_cnn_asyn.pt"))
        model.eval()

    traced_model = torch.jit.trace(
        model,
        torch.zeros([1, 1, 28, 28], dtype=torch.float).to(device))
    learning_rate = args.lr

    # Execute the training and test process for each round
    for curr_round in range(1, args.training_rounds + 1):
        logger.info("Training round %s/%s", curr_round, args.training_rounds)

        results = await asyncio.gather(*[
            fit_model_on_worker(
                worker=worker,
                traced_model=traced_model,
                batch_size=args.batch_size,
                curr_round=curr_round,
                max_nr_batches=args.federate_after_n_batches,
                lr=learning_rate,
            ) for worker in worker_instances
        ])
        models = {}
        loss_values = {}

        # Evaluate the models every 10 rounds and at the last round
        test_models = curr_round % 10 == 1 or curr_round == args.training_rounds
        if test_models:
            logger.info("Evaluating models")
            np.set_printoptions(formatter={"float": "{: .0f}".format})
            for worker_id, worker_model, _ in results:
                evaluate_model_on_worker(
                    model_identifier="Model update " + worker_id,
                    worker=testing,
                    dataset_key="mnist_testing",
                    model=worker_model,
                    nr_bins=10,
                    batch_size=128,
                    device=device,
                    print_target_hist=False,
                )

        # Federate models (note that this will also change the model in models[0])
        for worker_id, worker_model, worker_loss in results:
            if worker_model is not None:
                models[worker_id] = worker_model
                loss_values[worker_id] = worker_loss

        traced_model = utils.federated_avg(models)

        if test_models:
            evaluate_model_on_worker(
                model_identifier="Federated model",
                worker=testing,
                dataset_key="mnist_testing",
                model=traced_model,
                nr_bins=10,
                batch_size=128,
                device=device,
                print_target_hist=False,
            )
            # save intermediate model
            model_dir = "models_asyn"
            if (not os.path.exists(model_dir)):
                os.makedirs(model_dir)
            model_name = "{}/mnist_cnn_{}.pt".format(model_dir, curr_round)
            torch.save(traced_model.state_dict(), model_name)

        # decay learning rate
        learning_rate = max(0.98 * learning_rate, args.lr * 0.01)

    if args.save_model:
        torch.save(traced_model.state_dict(), "mnist_cnn_asyn.pt")