示例#1
0
def train_and_validate_federated_afl(client):
    """Run the Agnostic Federated Learning (AFL) training and validation loop.

    Implementation of Agnostic Federated Learning,
    https://arxiv.org/abs/1902.00146

    The function is written to be called by every process with its own
    ``client`` object; rank 0 is treated as the server (hard-coded). For each
    of ``client.args.num_comms`` communication rounds it:

    1. selects the online clients and builds a process group from them plus
       rank 0;
    2. calls ``distribute_model_server(..., src=0)``, broadcasts
       ``client.lambda_vector`` from rank 0 and loads the server weights into
       ``client.model``;
    3. on online clients, runs local optimizer steps until ``is_sync_fed``
       returns a truthy value;
    4. validates the local model, aggregates through ``afl_aggregation`` and
       validates the resulting ``client.model_server``;
    5. on rank 0, adds ``drfa_gamma * loss`` to ``client.lambda_vector``,
       projects it with ``euclidean_proj_simplex`` and floors entries that
       are <= 1e-3.

    Args:
        client: Object holding this process's state. Attributes read here:
            ``args``, ``model``, ``model_server``, ``optimizer``,
            ``scheduler``, ``criterion``, ``metrics``, ``train_loader``,
            ``val_loader``, ``test_loader`` and ``all_clients_group``.
            ``lambda_vector`` is assigned by this function.

    Returns:
        None. Results are left on ``client`` (``model_server``,
        ``lambda_vector``) and in counters on ``client.args``.
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    # NOTE(review): in evaluate mode only rank 0 returns early; every other
    # rank falls through to the gather below — confirm this is intended.
    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args,
                    client.model,
                    client.optimizer,
                    client.criterion,
                    client.metrics,
                    client.test_loader,
                    client.all_clients_group,
                    data_mode='test')
        return

    # Initialize the lambda variable proportionate to each client's sample size
    if client.args.graph.rank == 0:
        # Rank 0 collects every rank's num_samples_per_epoch and divides by
        # the total training-set size.
        gather_list_size = [
            torch.tensor(0.0) for _ in range(client.args.graph.n_nodes)
        ]
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    gather_list=gather_list_size,
                    dst=0)
        client.lambda_vector = torch.stack(
            gather_list_size) / client.args.train_dataset_size
    else:
        # Other ranks only contribute their sample count; they start from a
        # uniform placeholder that is the target of the broadcast from rank 0
        # inside the round loop.
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    dst=0)
        client.lambda_vector = torch.tensor([1 / client.args.graph.n_nodes] *
                                            client.args.graph.n_nodes)

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        # New accumulator for this round's communication time.
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c + 1), client.args.debug)
        online_clients = set_online_clients(client.args)
        if n_c == 0:
            # In the first round the server must take part in the communication to initialize its own training
            online_clients = online_clients if 0 in online_clients else online_clients + [
                0
            ]
        # Rank 0 is always part of the communication group, even in rounds
        # where it does not train locally.
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        # Created unconditionally on every rank, before the membership check.
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            # Time the model distribution and lambda broadcast as
            # communication cost for this round.
            st = time.time()
            client.model_server = distribute_model_server(client.model_server,
                                                          online_clients_group,
                                                          src=0)
            dist.broadcast(client.lambda_vector,
                           src=0,
                           group=online_clients_group)
            client.args.comm_time[-1] += time.time() - st
            # Local training starts from the current server weights.
            client.model.load_state_dict(client.model_server.state_dict())

            # Default loss for ranks in the group that do not train locally
            # this round (it is still passed to afl_aggregation below).
            loss = torch.tensor(0.0)
            # Start running updates on local machines
            if client.args.graph.rank in online_clients:
                is_sync = False
                # Keep cycling through the train loader until is_sync_fed
                # signals that the local phase is over.
                while not is_sync:
                    for _input, _target in client.train_loader:

                        client.model.train()
                        # update local step.
                        logging_load_time(tracker)
                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        # NOTE(review): local_step is not read again in this
                        # function.
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        # NOTE(review): lr is only referenced by the
                        # commented-out logging call below.
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)

                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)

                        # Skip batches with one sample because of BatchNorm issue in some models!
                        # The break restarts the loader on the next pass of
                        # the while loop unless is_sync became true.
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)
                        # compute gradient and do local SGD step.
                        loss.backward()
                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)

                        # display the logging info.
                        # logging_display_training(args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            # Validate the local models before sync
            do_validate(client.args,
                        client.model,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train',
                        local=True)
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation',
                            local=True)
            # Sync the model server based on client models
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1

            # Aggregation receives this rank's own lambda entry and the last
            # local loss value (copied into a fresh tensor via .item()).
            client.model_server, loss_tensor_online = afl_aggregation(
                client.args, client.model_server, client.model,
                client.lambda_vector[client.args.graph.rank].item(),
                torch.tensor(loss.item()), online_clients_group,
                online_clients, client.optimizer)

            # evaluate the sync time
            logging_sync_time(tracker)
            # Do the validation on the server model
            do_validate(client.args,
                        client.model_server,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation')

            # Updating lambda variable
            if client.args.graph.rank == 0:
                # NOTE(review): num_online_clients is computed but never used
                # afterwards in this function.
                num_online_clients = len(
                    online_clients
                ) if 0 in online_clients else len(online_clients) + 1
                # Scatter the losses returned by afl_aggregation into a
                # full-size vector; ranks outside the group keep a zero loss.
                # Assumes loss_tensor_online is ordered by ascending rank —
                # TODO confirm against afl_aggregation.
                loss_tensor = torch.zeros(client.args.graph.n_nodes)
                loss_tensor[sorted(online_clients_server)] = loss_tensor_online
                # Dual update
                client.lambda_vector += client.args.drfa_gamma * loss_tensor
                # Projection into a simplex
                client.lambda_vector = euclidean_proj_simplex(
                    client.lambda_vector)
                # Avoid zero probability: floor every entry <= 1e-3 and
                # renormalise so the vector sums to one again.
                lambda_zeros = client.lambda_vector <= 1e-3
                if lambda_zeros.sum() > 0:
                    client.lambda_vector[lambda_zeros] = 1e-3
                    client.lambda_vector /= client.lambda_vector.sum()

            # logging.
            logging_globally(tracker, start_global_time)

            # reset start round time.
            start_global_time = time.time()
            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test')
            log(
                'This round communication time is: {}'.format(
                    client.args.comm_time[-1]), client.args.debug)
        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)
        # Every rank, online or not, waits here before the next round.
        dist.barrier(group=client.all_clients_group)

    return
示例#2
0
def train_and_validate_drfa_centered(Clients, Server):
    """Run the DRFA training and validation loop in a single process.

    All clients are simulated sequentially: ``Clients`` is indexed by client
    id and ``Server`` holds the global model. For each of
    ``Server.args.num_comms`` rounds the function:

    1. resets the server trackers and picks the online clients with
       ``set_online_clients_centered``;
    2. for each online client, loads the server weights, validates the server
       model on that client's data, trains locally until ``is_sync_fed``
       returns a truthy value (snapshotting the model into ``kth_model`` at a
       randomly drawn local step ``k``) and validates the local model;
    3. aggregates the client models according to
       ``Server.args.federated_type`` and aggregates the k-th snapshots with
       ``aggregate_kth_model_centered``;
    4. logs validation/test statistics;
    5. draws a second set of clients, evaluates ``Server.kth_model`` on one
       batch from each, and uses those losses to update
       ``Server.lambda_vector``, followed by a simplex projection and a floor
       of 1e-3 on every entry.

    Args:
        Clients: Indexable collection of client objects (indexed by the ids
            returned by ``set_online_clients_centered`` and by
            ``range(Server.args.graph.n_nodes)``).
        Server: Server object; provides ``args``, ``model``, ``kth_model``,
            ``lambda_vector``, the trackers, ``criterion``, ``metrics``,
            ``optimizer`` and ``test_loader``.

    Returns:
        None. State is updated in place on ``Server`` and ``Clients``.
    """
    log('start training and validation with Federated setting in a centered way.'
        )

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.')

    # Initialise lambda proportionally to each client's samples per epoch.
    for oc in range(Server.args.graph.n_nodes):
        Server.lambda_vector[oc] = Clients[oc].args.num_samples_per_epoch
    Server.lambda_vector /= Server.lambda_vector.sum()
    # Number of communication rounds in federated setting should be defined
    for n_c in range(Server.args.num_comms):
        Server.args.rounds_comm += 1
        Server.args.local_index += 1
        Server.args.quant_error = 0.0

        # Preset variables for this round of communication
        Server.zero_grad()
        Server.reset_tracker(Server.local_val_tracker)
        Server.reset_tracker(Server.global_val_tracker)
        Server.reset_tracker(Server.global_test_tracker)
        if Server.args.fed_personal:
            Server.reset_tracker(Server.local_personal_val_tracker)
            Server.reset_tracker(Server.global_personal_val_tracker)

        # Configuring the devices for this round of communication
        log("Starting round {} of training".format(n_c + 1))
        online_clients = set_online_clients_centered(Server.args)

        # Geometric decay of the dual step size, applied once per round.
        Server.args.drfa_gamma *= 0.9

        for oc in online_clients:
            # Each online client starts the round from the server weights.
            Clients[oc].model.load_state_dict(Server.model.state_dict())
            Clients[oc].args.rounds_comm = Server.args.rounds_comm
            local_steps = 0
            is_sync = False

            # Server model evaluated on this client's training data
            # (accumulated into the server-level tracker).
            do_validate_centered(Clients[oc].args,
                                 Server.model,
                                 Server.criterion,
                                 Server.metrics,
                                 Server.optimizer,
                                 Clients[oc].train_loader,
                                 Server.global_val_tracker,
                                 val=False,
                                 local=False)
            if Server.args.per_class_acc:
                # Per-client statistics use the client's own trackers.
                Clients[oc].reset_tracker(Clients[oc].local_val_tracker)
                Clients[oc].reset_tracker(Clients[oc].global_val_tracker)
                if Server.args.fed_personal:
                    Clients[oc].reset_tracker(
                        Clients[oc].local_personal_val_tracker)
                    Clients[oc].reset_tracker(
                        Clients[oc].global_personal_val_tracker)
                    do_validate_centered(
                        Clients[oc].args,
                        Server.model,
                        Server.criterion,
                        Server.metrics,
                        Server.optimizer,
                        Clients[oc].val_loader,
                        Clients[oc].global_personal_val_tracker,
                        val=True,
                        local=False)
                do_validate_centered(Clients[oc].args,
                                     Server.model,
                                     Server.criterion,
                                     Server.metrics,
                                     Server.optimizer,
                                     Clients[oc].train_loader,
                                     Clients[oc].global_val_tracker,
                                     val=False,
                                     local=False)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args,
                                     Server.model,
                                     Server.criterion,
                                     Server.metrics,
                                     Server.optimizer,
                                     Clients[oc].val_loader,
                                     Server.global_personal_val_tracker,
                                     val=True,
                                     local=False)

            if Server.args.federated_type == 'perfedavg':
                # Fetch a single validation batch; it is reused for every
                # second-step update of this client in this round.
                for _input_val, _target_val in Clients[oc].val_loader1:
                    _input_val, _target_val = load_data_batch(
                        Clients[oc].args, _input_val, _target_val, tracker)
                    break

            # Local step at which the model is snapshotted into kth_model.
            # NOTE(review): ``high`` is exclusive, so k lies in
            # [1, local_step - 1] and torch.randint raises when
            # local_step <= 1 — confirm the intended range.
            k = torch.randint(low=1, high=Server.args.local_step,
                              size=(1, )).item()
            # Keep cycling through the train loader until is_sync_fed signals
            # the end of the local phase.
            while not is_sync:
                if Server.args.arch == 'rnn':
                    Clients[oc].model.init_hidden(Server.args.batch_size)
                for _input, _target in Clients[oc].train_loader:
                    local_steps += 1
                    if k == local_steps:
                        # Snapshot taken before this step's update is applied.
                        Clients[oc].kth_model.load_state_dict(
                            Clients[oc].model.state_dict())
                    Clients[oc].model.train()

                    # update local step.
                    logging_load_time(tracker)

                    # update local index and get local step
                    Clients[oc].args.local_index += 1
                    Clients[oc].args.local_data_seen += len(_target)
                    get_current_epoch(Clients[oc].args)
                    local_step = get_current_local_step(Clients[oc].args)

                    # adjust learning rate (based on the # of accessed samples)
                    lr = adjust_learning_rate(Clients[oc].args,
                                              Clients[oc].optimizer,
                                              Clients[oc].scheduler)

                    # load data
                    _input, _target = load_data_batch(Clients[oc].args, _input,
                                                      _target, tracker)

                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if _input.size(0) == 1:
                        is_sync = is_sync_fed(Clients[oc].args)
                        break

                    # inference and get current performance.
                    Clients[oc].optimizer.zero_grad()

                    loss, performance = inference(Clients[oc].model,
                                                  Clients[oc].criterion,
                                                  Clients[oc].metrics,
                                                  _input,
                                                  _target,
                                                  rnn=Server.args.arch
                                                  in ['rnn'])

                    # compute gradient and do local SGD step.
                    loss.backward()

                    # Algorithm-specific gradient corrections, applied after
                    # backward() and before the optimizer step.
                    if Clients[oc].args.federated_type == 'fedgate':
                        # Update gradients with control variates
                        for client_param, delta_param in zip(
                                Clients[oc].model.parameters(),
                                Clients[oc].model_delta.parameters()):
                            client_param.grad.data -= delta_param.data
                    elif Clients[oc].args.federated_type == 'scaffold':
                        # Shift by (server control - client control).
                        for cp, ccp, scp in zip(
                                Clients[oc].model.parameters(),
                                Clients[oc].model_client_control.parameters(),
                                Server.model_server_control.parameters()):
                            cp.grad.data += scp.data - ccp.data
                    elif Clients[oc].args.federated_type == 'fedprox':
                        # Adding proximal gradients and loss for fedprox
                        for client_param, server_param in zip(
                                Clients[oc].model.parameters(),
                                Server.model.parameters()):
                            loss += Clients[
                                oc].args.fedprox_mu / 2 * torch.norm(
                                    client_param.data - server_param.data)
                            client_param.grad.data += Clients[
                                oc].args.fedprox_mu * (client_param.data -
                                                       server_param.data)

                    if 'robust' in Clients[oc].args.arch:
                        # Flip the sign of the noise gradient so the
                        # optimizer step moves the noise parameter uphill.
                        Clients[oc].model.noise.grad.data *= -1

                    Clients[oc].optimizer.step(
                        apply_lr=True,
                        apply_in_momentum=Clients[oc].args.in_momentum,
                        apply_out_momentum=False)

                    if 'robust' in Clients[oc].args.arch:
                        # Rescale the noise to norm 1 whenever it exceeds it.
                        if torch.norm(Clients[oc].model.noise.data) > 1:
                            Clients[oc].model.noise.data /= torch.norm(
                                Clients[oc].model.noise.data)

                    if Clients[oc].args.federated_type == 'perfedavg':
                        # Second update on the cached validation batch, with
                        # perfedavg_beta passed as the external learning rate.
                        # _input_val, _target_val = Clients[oc].load_next_val_batch()
                        lr = adjust_learning_rate(
                            Clients[oc].args,
                            Clients[oc].optimizer,
                            Clients[oc].scheduler,
                            lr_external=Clients[oc].args.perfedavg_beta)
                        if _input_val.size(0) == 1:
                            is_sync = is_sync_fed(Clients[oc].args)
                            break

                        Clients[oc].optimizer.zero_grad()
                        loss, performance = inference(Clients[oc].model,
                                                      Clients[oc].criterion,
                                                      Clients[oc].metrics,
                                                      _input_val, _target_val)
                        loss.backward()
                        Clients[oc].optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=Clients[oc].args.in_momentum,
                            apply_out_momentum=False)

                    # reset load time for the tracker.
                    tracker['start_load_time'] = time.time()
                    # model_local = deepcopy(model_client)
                    is_sync = is_sync_fed(Clients[oc].args)
                    if is_sync:
                        break

            # Locally trained model evaluated on the client's own data.
            do_validate_centered(Clients[oc].args,
                                 Clients[oc].model,
                                 Clients[oc].criterion,
                                 Clients[oc].metrics,
                                 Clients[oc].optimizer,
                                 Clients[oc].train_loader,
                                 Server.local_val_tracker,
                                 val=False,
                                 local=True)
            if Server.args.per_class_acc:
                do_validate_centered(Clients[oc].args,
                                     Clients[oc].model,
                                     Clients[oc].criterion,
                                     Clients[oc].metrics,
                                     Clients[oc].optimizer,
                                     Clients[oc].train_loader,
                                     Clients[oc].local_val_tracker,
                                     val=False,
                                     local=True)
                if Server.args.fed_personal:
                    do_validate_centered(
                        Clients[oc].args,
                        Clients[oc].model,
                        Clients[oc].criterion,
                        Clients[oc].metrics,
                        Clients[oc].optimizer,
                        Clients[oc].val_loader,
                        Clients[oc].local_personal_val_tracker,
                        val=True,
                        local=True)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args,
                                     Clients[oc].model,
                                     Clients[oc].criterion,
                                     Clients[oc].metrics,
                                     Clients[oc].optimizer,
                                     Clients[oc].val_loader,
                                     Server.local_personal_val_tracker,
                                     val=True,
                                     local=True)
            # Sync the model server based on model_clients
            tracker['start_sync_time'] = time.time()
            Server.args.global_index += 1
            logging_sync_time(tracker)

        # Aggregate client models into the server model. ``local_steps`` and
        # ``lr`` carry the values left by the last online client processed.
        if Server.args.federated_type == 'scaffold':
            scaffold_aggregation_centered(Clients, Server, online_clients,
                                          local_steps, lr)
        elif Server.args.federated_type == 'fedgate':
            fedgate_aggregation_centered(Clients, Server, online_clients,
                                         local_steps, lr)
        elif Server.args.federated_type == 'qsparse':
            qsparse_aggregation_centered(Clients, Server, online_clients,
                                         local_steps, lr)
        else:
            fedavg_aggregation_centered(Clients, Server, online_clients,
                                        Server.lambda_vector.numpy())

        # Aggregate Kth models
        aggregate_kth_model_centered(Clients, Server, online_clients)

        # Log performance
        # Client training performance
        log_validation_centered(Server.args,
                                Server.local_val_tracker,
                                val=False,
                                local=True)
        # Server training performance
        log_validation_centered(Server.args,
                                Server.global_val_tracker,
                                val=False,
                                local=False)
        if Server.args.fed_personal:
            # Client validation performance
            log_validation_centered(Server.args,
                                    Server.local_personal_val_tracker,
                                    val=True,
                                    local=True)
            # Server validation performance
            log_validation_centered(Server.args,
                                    Server.global_personal_val_tracker,
                                    val=True,
                                    local=False)

        # Per client stats
        if Server.args.per_class_acc:
            log_validation_per_client_centered(Server.args,
                                               Clients,
                                               online_clients,
                                               val=False,
                                               local=False)
            log_validation_per_client_centered(Server.args,
                                               Clients,
                                               online_clients,
                                               val=False,
                                               local=True)
            if Server.args.fed_personal:
                log_validation_per_client_centered(Server.args,
                                                   Clients,
                                                   online_clients,
                                                   val=True,
                                                   local=False)
                log_validation_per_client_centered(Server.args,
                                                   Clients,
                                                   online_clients,
                                                   val=True,
                                                   local=True)

        # Test on server
        do_validate_centered(Server.args,
                             Server.model,
                             Server.criterion,
                             Server.metrics,
                             Server.optimizer,
                             Server.test_loader,
                             Server.global_test_tracker,
                             val=False,
                             local=False)
        log_test_centered(Server.args, Server.global_test_tracker)

        # Dual (lambda) update: draw a fresh set of clients and evaluate the
        # aggregated k-th model on a single training batch from each.
        online_clients_lambda = set_online_clients_centered(Server.args)
        loss_tensor = torch.zeros(Server.args.graph.n_nodes)
        num_online_clients = len(online_clients_lambda)
        for ocl in online_clients_lambda:
            for _input, _target in Clients[ocl].train_loader:
                _input, _target = load_data_batch(Clients[ocl].args, _input,
                                                  _target, tracker)
                loss, _ = inference(Server.kth_model,
                                    Clients[ocl].criterion,
                                    Clients[ocl].metrics,
                                    _input,
                                    _target,
                                    rnn=Server.args.arch in ['rnn'])
                break
            # Scale by n_nodes / #sampled so that unsampled clients (whose
            # entries stay zero) are compensated for.
            loss_tensor[ocl] = loss * (Server.args.graph.n_nodes /
                                       num_online_clients)
        Server.lambda_vector += Server.args.drfa_gamma * Server.args.local_step * loss_tensor
        # detach(): loss_tensor was filled from losses that may carry autograd
        # history, which .numpy() would reject.
        lambda_vector = projection_simplex_sort(
            Server.lambda_vector.detach().numpy())
        print(lambda_vector)
        # Avoid zero probability: floor EVERY entry <= 1e-3, then renormalise.
        # (The previous np.argwhere(...)[0] indexing only touched the first
        # such entry and left the remaining ones at or near zero.)
        lambda_zeros = lambda_vector <= 1e-3
        if lambda_zeros.any():
            lambda_vector[lambda_zeros] = 1e-3
            lambda_vector /= np.sum(lambda_vector)
        Server.lambda_vector = torch.tensor(lambda_vector)

        # logging.
        logging_globally(tracker, start_global_time)

        # reset start round time.
        start_global_time = time.time()

        # validate the model at the server
        # if args.graph.rank == 0:
        #     do_test(args, model_server, optimizer, criterion, metrics, test_loader)
        # do_validate_test(args, model_server, optimizer, criterion, metrics, test_loader)
    return
示例#3
0
def train_and_validate_federated_drfa(client):
    """Run Distributionally Robust Federated Averaging (DRFA) on this process.

    paper: https://papers.nips.cc/paper/2020/hash/ac450d10e166657ec8f93a1b65ca1b14-Abstract.html

    Every process of the distributed job calls this function with its own
    ``client``; rank 0 plays the server. Each of the ``args.num_comms``
    communication rounds has two phases:

    1. Model phase. A set of online clients is drawn, rank 0 broadcasts the
       server model, the mixing weights ``client.lambda_vector`` and a random
       local step index ``k``. Online clients run local SGD until
       ``is_sync_fed`` reports a synchronisation point, snapshotting their
       model into ``client.kth_model`` at local step ``k``. The local models
       are then aggregated (FedGATE, SCAFFOLD or FedAvg depending on
       ``args.federated_type``) and validated.
    2. Weight phase. A second set of clients is drawn, each evaluates the
       averaged ``kth_model`` on one training batch, the losses are gathered
       and rank 0 takes an ascent step on the mixing weights followed by a
       projection onto the simplex.

    Args:
        client: the local node object. It must expose ``args``, the models
            (``model``, ``model_server``, ``kth_model`` and, depending on the
            federated type, the delta/memory/control models), ``optimizer``,
            ``scheduler``, ``criterion``, ``metrics``, the data loaders and
            ``all_clients_group``. ``client.lambda_vector`` is created here.

    Returns:
        None. Models, ``client.lambda_vector`` and the counters stored on
        ``client.args`` are updated in place.
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Evaluation-only mode: test on the server and return.
        do_validate(client.args,
                    client.model,
                    client.optimizer,
                    client.criterion,
                    client.metrics,
                    client.test_loader,
                    client.all_clients_group,
                    data_mode='test')
        return

    # Initialize the lambda variable proportionate to the clients' sample
    # sizes. It is stored on the client (not in a local) because the
    # broadcast, the aggregation weights and the dual update below all read
    # client.lambda_vector.
    if client.args.graph.rank == 0:
        gather_list_size = [
            torch.tensor(0.0) for _ in range(client.args.graph.n_nodes)
        ]
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    gather_list=gather_list_size,
                    dst=0)
        client.lambda_vector = torch.stack(
            gather_list_size) / client.args.train_dataset_size
    else:
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    dst=0)
        # Uniform placeholder; it is overwritten by the broadcast from rank 0
        # in every round this rank takes part in.
        client.lambda_vector = torch.tensor([1 / client.args.graph.n_nodes] *
                                            client.args.graph.n_nodes)

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c + 1), client.args.debug)
        online_clients = set_online_clients(client.args)
        if n_c == 0:
            # In the first round the server must take part in training so
            # that its own training state gets initialized.
            online_clients = online_clients if 0 in online_clients else online_clients + [
                0
            ]
        # The server always belongs to the communication group, even when it
        # does not train in this round.
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        # new_group is called by every rank, members or not.
        online_clients_group = dist.new_group(online_clients_server)
        # Decay the step size of the dual (lambda) update once per round.
        client.args.drfa_gamma *= 0.9
        if client.args.graph.rank in online_clients_server:
            if client.args.federated_type == 'scaffold':
                st = time.time()
                client.model_server, client.model_server_control = distribute_model_server_control(
                    client.model_server,
                    client.model_server_control,
                    online_clients_group,
                    src=0)
                client.args.comm_time[-1] += time.time() - st
            else:
                st = time.time()
                # Keep the returned model on the client: the local model is
                # loaded from client.model_server right below.
                client.model_server = distribute_model_server(
                    client.model_server, online_clients_group, src=0)
                client.args.comm_time[-1] += time.time() - st
            client.model.load_state_dict(client.model_server.state_dict())

            # Send related variables to drfa algorithm
            st = time.time()
            dist.broadcast(client.lambda_vector,
                           src=0,
                           group=online_clients_group)
            # Sending the random number k to all nodes:
            # Does not fully support the epoch mode now
            # torch.randint excludes `high`, so k lies in
            # [1, local_step - 1]; max() keeps the range non-empty (k == 1)
            # when local_step <= 1 instead of raising.
            k = torch.randint(low=1,
                              high=max(client.args.local_step, 2),
                              size=(1, ))
            # Every rank draws a value, but the one from rank 0 wins.
            dist.broadcast(k, src=0, group=online_clients_group)
            client.args.comm_time[-1] += time.time() - st

            k = k.tolist()[0]
            local_steps = 0
            # Start running updates on local machines
            if client.args.graph.rank in online_clients:
                is_sync = False
                while not is_sync:
                    for _input, _target in client.train_loader:
                        local_steps += 1
                        # Getting the k-th model for dual variable update;
                        # the snapshot is taken before this step's update.
                        if k == local_steps:
                            client.kth_model.load_state_dict(
                                client.model.state_dict())
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)

                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)
                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)

                        # compute gradient and do local SGD step.
                        loss.backward()

                        if client.args.federated_type == 'fedgate':
                            # Correct the local gradient with the tracked delta.
                            for client_param, delta_param in zip(
                                    client.model.parameters(),
                                    client.model_delta.parameters()):
                                client_param.grad.data -= delta_param.data
                        elif client.args.federated_type == 'scaffold':
                            # Add (server control - client control) to the gradient.
                            for cp, ccp, scp in zip(
                                    client.model.parameters(),
                                    client.model_client_control.parameters(),
                                    client.model_server_control.parameters()):
                                cp.grad.data += scp.data - ccp.data

                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)

                        # display the logging info.
                        # logging_display_training(client.args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            # Validate the local models before sync
            do_validate(client.args,
                        client.model,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train',
                        local=True)
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation',
                            local=True)
            # Sync the model server based on model_clients
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1

            # NOTE: `lr` is only assigned inside the local training loop
            # above. A rank that did not train in this round (the server)
            # passes the value left over from its last training round.
            if client.args.federated_type == 'fedgate':
                client.model_server, client.model_delta = fedgate_aggregation(
                    client.args,
                    client.model_server,
                    client.model,
                    client.model_delta,
                    client.model_memory,
                    online_clients_group,
                    online_clients,
                    client.optimizer,
                    lr,
                    local_steps,
                    lambda_weight=client.lambda_vector[
                        client.args.graph.rank].item())
            elif client.args.federated_type == 'scaffold':
                client.model_server, client.model_client_control, client.model_server_control = scaffold_aggregation(
                    client.args,
                    client.model_server,
                    client.model,
                    client.model_server_control,
                    client.model_client_control,
                    online_clients_group,
                    online_clients,
                    client.optimizer,
                    lr,
                    local_steps,
                    lambda_weight=client.lambda_vector[
                        client.args.graph.rank].item())
            else:
                client.model_server = fedavg_aggregation(
                    client.args,
                    client.model_server,
                    client.model,
                    online_clients_group,
                    online_clients,
                    client.optimizer,
                    lambda_weight=client.lambda_vector[
                        client.args.graph.rank].item())
            # Average the kth_model
            client.kth_model = aggregate_models_virtual(
                client.args, client.kth_model, online_clients_group,
                online_clients)
            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args,
                        client.model_server,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation')

            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test')

        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)

        # Update lambda parameters: a fresh set of clients is drawn for the
        # dual update, again with the server forced into the group.
        online_clients_lambda = set_online_clients(client.args)
        online_clients_server_lambda = online_clients_lambda if 0 in online_clients_lambda else [
            0
        ] + online_clients_lambda
        online_clients_group_lambda = dist.new_group(
            online_clients_server_lambda)

        if client.args.graph.rank in online_clients_server_lambda:
            st = time.time()
            client.kth_model = distribute_model_server(
                client.kth_model, online_clients_group_lambda, src=0)
            client.args.comm_time[-1] += time.time() - st
            # Ranks that do not evaluate (or skip their batch) contribute 0.
            loss = torch.tensor(0.0)

            if client.args.graph.rank in online_clients_lambda:
                # Only the first training batch is used for the loss estimate.
                for _input, _target in client.train_loader:
                    _input, _target = load_data_batch(client.args, _input,
                                                      _target, tracker)
                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if _input.size(0) == 1:
                        break
                    loss, _ = inference(client.kth_model, client.criterion,
                                        client.metrics, _input, _target)
                    break
            loss_tensor_online = loss_gather(
                client.args,
                torch.tensor(loss.item()),
                group=online_clients_group_lambda,
                online_clients=online_clients_lambda)
            if client.args.graph.rank == 0:
                # Size of the gathering group (sampled clients plus the
                # server when it was not sampled).
                num_online_clients = len(online_clients_server_lambda)
                loss_tensor = torch.zeros(client.args.graph.n_nodes)
                # Presumably loss_gather returns the losses ordered by rank,
                # hence the sorted index list -- confirm against loss_gather.
                loss_tensor[sorted(
                    online_clients_server_lambda)] = loss_tensor_online * (
                        client.args.graph.n_nodes / num_online_clients)
                # Dual update
                client.lambda_vector += client.args.drfa_gamma * client.args.local_step * loss_tensor
                client.lambda_vector = euclidean_proj_simplex(
                    client.lambda_vector)

                # Avoid zero probability: floor small entries at 1e-3 and
                # renormalize so the weights sum to one again.
                lambda_zeros = client.lambda_vector <= 1e-3
                if lambda_zeros.sum() > 0:
                    client.lambda_vector[lambda_zeros] = 1e-3
                    client.lambda_vector /= client.lambda_vector.sum()

        # logging.
        logging_globally(tracker, start_global_time)

        # reset start round time.
        start_global_time = time.time()
        log(
            'This round communication time is: {}'.format(
                client.args.comm_time[-1]), client.args.debug)
        # Keep all ranks, online or not, in lockstep between rounds.
        dist.barrier(group=client.all_clients_group)
    return
示例#4
0
def train_and_validate(client):
    """Train this worker with distributed Local SGD.

    The worker iterates over its training loader, taking one local optimizer
    step per batch. Whenever ``args.local_index`` is a multiple of the
    current local-step count the models are synchronised through
    ``aggregate_gradients``. When an epoch boundary is detected, rank 0 tests
    the local model (subject to ``args.eval_freq``) and all workers meet at a
    barrier. Training ends as soon as ``is_stop`` says so, after one last
    synchronisation and a final test of the server model on rank 0.

    Args:
        client: the local node object providing ``args``, ``model``,
            ``model_server``, ``optimizer``, ``scheduler``, ``criterion``,
            ``metrics``, the data loaders, ``all_clients_group`` and
            ``load_local_dataset``.

    Returns:
        None.
    """
    log('start training and validation.', client.args.debug)

    def _test(model):
        # Evaluate `model` on the test loader; callers restrict this to rank 0.
        do_validate(client.args,
                    model,
                    client.optimizer,
                    client.criterion,
                    client.metrics,
                    client.test_loader,
                    client.all_clients_group,
                    data_mode='test')

    def _synchronize(sync_flag):
        # Count one more global round and refresh the server model.
        log('Enter synching', client.args.debug)
        client.args.global_index += 1
        # broadcast gradients to other nodes by using reduce_sum.
        client.model_server = aggregate_gradients(client.args,
                                                  client.model_server,
                                                  client.model,
                                                  client.optimizer,
                                                  sync_flag)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Evaluation-only mode: test on the server and return.
        _test(client.model)
        return

    timing = define_local_training_tracker()
    round_started_at = time.time()
    timing['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    client.args.comm_time.append(0.0)
    # Loop over the data until is_stop() ends the training from inside.
    while True:
        for raw_input, raw_target in client.train_loader:
            client.model.train()

            # account for the time spent loading this batch.
            logging_load_time(timing)

            # advance the local counters and fetch the sync period.
            client.args.local_index += 1
            client.args.local_data_seen += len(raw_target)
            get_current_epoch(client.args)
            steps_per_sync = get_current_local_step(client.args)

            # adjust learning rate (based on the # of accessed samples)
            current_lr = adjust_learning_rate(client.args,
                                              client.optimizer,
                                              client.scheduler)

            # prepare the batch.
            batch_input, batch_target = load_data_batch(
                client.args, raw_input, raw_target, timing)

            # forward pass.
            client.optimizer.zero_grad()
            batch_loss, batch_perf = inference(client.model,
                                               client.criterion,
                                               client.metrics,
                                               batch_input,
                                               batch_target)

            # backward pass and local SGD step.
            batch_loss.backward()
            client.optimizer.step(apply_lr=True,
                                  apply_in_momentum=client.args.in_momentum,
                                  apply_out_momentum=False)

            # record the local statistics.
            logging_computing(timing, batch_loss, batch_perf, batch_input,
                              current_lr)

            # decide whether this step is a synchronisation point and
            # whether an epoch boundary has been reached.
            is_sync = client.args.local_index % steps_per_sync == 0
            if client.args.epoch_ % 1 == 0:
                client.args.finish_one_epoch = True

            if is_sync:
                _synchronize(is_sync)
                # evaluate the sync time
                logging_sync_time(timing)

                # global logging, then restart the round clock.
                logging_globally(timing, round_started_at)
                round_started_at = time.time()

            # at an epoch boundary, optionally test on rank 0 and re-align
            # all workers.
            if client.args.finish_one_epoch:
                if (client.args.epoch % client.args.eval_freq == 0
                        and client.args.graph.rank == 0):
                    _test(client.model)
                dist.barrier(group=client.all_clients_group)
                # start the next epoch with a fresh logging cache.
                client.args.finish_one_epoch = False
                timing = define_local_training_tracker()

            # determine if the training is finished.
            if is_stop(client.args):
                # last synchronisation before leaving.
                _synchronize(is_sync)

                print("Total number of samples seen on device {} is {}".format(client.args.graph.rank, client.args.local_data_seen))
                if client.args.graph.rank == 0:
                    _test(client.model_server)
                return

            # display the logging info.
            logging_display_training(client.args, timing)

            # reset load time for the tracker.
            timing['start_load_time'] = time.time()

        # reshuffle the data by dropping and rebuilding the loaders.
        if client.args.reshuffle_per_epoch:
            log('reshuffle the dataset.', client.args.debug)
            del client.train_loader, client.test_loader
            gc.collect()
            log('reshuffle the dataset.', client.args.debug)
            client.load_local_dataset()
示例#5
0
def train_and_validate_perfedme_centered(Clients, Server):
    """Simulate personalised federated training in a single process.

    For every communication round the online clients are visited one after
    another. Each one starts from the server model, trains its personal
    model with an extra gradient term pulling it toward the client's local
    copy of the global model (weighted by ``args.perfedme_lambda``), and
    updates that local copy every fifth local index and at synchronisation
    points. After all online clients are done, their models are aggregated
    with ``fedavg_aggregation_centered`` and the collected validation
    trackers are logged.

    Args:
        Clients: indexable collection of client objects, addressed by the ids
            returned from ``set_online_clients_centered``.
        Server: server object holding the global model, its criterion,
            metrics, optimizer, ``args`` and the validation trackers.

    Returns:
        None.
    """
    log('start training and validation with Federated setting in a centered way.')

    timing = define_local_training_tracker()
    round_started_at = time.time()
    timing['start_load_time'] = time.time()
    log('enter the training.')

    # Number of communication rounds in federated setting should be defined
    for round_idx in range(Server.args.num_comms):
        Server.args.rounds_comm += 1
        Server.args.local_index += 1
        Server.args.quant_error = 0.0

        # Start the round with clean gradients and empty trackers.
        Server.zero_grad()
        Server.reset_tracker(Server.local_val_tracker)
        Server.reset_tracker(Server.global_val_tracker)
        if Server.args.fed_personal:
            Server.reset_tracker(Server.local_personal_val_tracker)
            Server.reset_tracker(Server.global_personal_val_tracker)

        # Pick the devices taking part in this round.
        log("Starting round {} of training".format(round_idx + 1))
        online_clients = set_online_clients_centered(Server.args)

        for oc in online_clients:
            worker = Clients[oc]
            worker.model.load_state_dict(Server.model.state_dict())
            worker.args.rounds_comm = Server.args.rounds_comm
            local_steps = 0
            is_sync = False

            # Score the current server model on this client's data.
            do_validate_centered(worker.args,
                                 Server.model,
                                 Server.criterion,
                                 Server.metrics,
                                 Server.optimizer,
                                 worker.train_loader,
                                 Server.global_val_tracker,
                                 val=False,
                                 local=False)
            if Server.args.fed_personal:
                do_validate_centered(worker.args,
                                     Server.model,
                                     Server.criterion,
                                     Server.metrics,
                                     Server.optimizer,
                                     worker.val_loader,
                                     Server.global_personal_val_tracker,
                                     val=True,
                                     local=False)

            # Keep cycling through the loader until a sync point is reached.
            while not is_sync:
                for raw_input, raw_target in worker.train_loader:
                    local_steps += 1
                    worker.model.train()
                    worker.model_personal.train()

                    # account for the time spent loading this batch.
                    logging_load_time(timing)

                    # advance the local counters.
                    worker.args.local_index += 1
                    worker.args.local_data_seen += len(raw_target)
                    get_current_epoch(worker.args)
                    local_step = get_current_local_step(worker.args)

                    # adjust learning rate (based on the # of accessed samples)
                    lr = adjust_learning_rate(worker.args,
                                              worker.optimizer_personal,
                                              worker.scheduler)

                    # prepare the batch.
                    batch_input, batch_target = load_data_batch(
                        worker.args, raw_input, raw_target, timing)

                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if batch_input.size(0) == 1:
                        is_sync = is_sync_fed(worker.args)
                        break

                    # forward/backward pass on the personal model.
                    worker.optimizer_personal.zero_grad()
                    loss, performance = inference(worker.model_personal,
                                                  worker.criterion,
                                                  worker.metrics,
                                                  batch_input,
                                                  batch_target)
                    loss.backward()

                    # Add the pull toward the local copy of the global model
                    # to the personal model's gradients.
                    prox_weight = worker.args.perfedme_lambda
                    for client_param, personal_param in zip(
                            worker.model.parameters(),
                            worker.model_personal.parameters()):
                        personal_param.grad.data += prox_weight * (
                            personal_param.data - client_param.data)

                    worker.optimizer_personal.step(
                        apply_lr=True,
                        apply_in_momentum=worker.args.in_momentum,
                        apply_out_momentum=False)

                    # On the very first local index, run one backward pass
                    # through the client model itself.
                    if worker.args.local_index == 1:
                        worker.optimizer.zero_grad()
                        loss, performance = inference(worker.model,
                                                      worker.criterion,
                                                      worker.metrics,
                                                      batch_input,
                                                      batch_target)
                        loss.backward()

                    is_sync = is_sync_fed(worker.args)
                    if worker.args.local_index % 5 == 0 or is_sync:
                        log('Updating the local version of the global model',
                            worker.args.debug)
                        lr = adjust_learning_rate(worker.args,
                                                  worker.optimizer,
                                                  worker.scheduler)
                        worker.optimizer.zero_grad()
                        # The gradient of the local global-model copy is its
                        # scaled distance to the personal model.
                        pull_weight = worker.args.perfedme_lambda
                        for client_param, personal_param in zip(
                                worker.model.parameters(),
                                worker.model_personal.parameters()):
                            client_param.grad.data = pull_weight * (
                                client_param.data - personal_param.data)
                        worker.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=worker.args.in_momentum,
                            apply_out_momentum=False)

                    # reset load time for the tracker.
                    timing['start_load_time'] = time.time()

                    if is_sync:
                        break

            # Score the personal model after local training.
            do_validate_centered(worker.args,
                                 worker.model_personal,
                                 worker.criterion,
                                 worker.metrics,
                                 worker.optimizer_personal,
                                 worker.train_loader,
                                 Server.local_val_tracker,
                                 val=False,
                                 local=True)
            if Server.args.fed_personal:
                do_validate_centered(worker.args,
                                     worker.model_personal,
                                     worker.criterion,
                                     worker.metrics,
                                     worker.optimizer_personal,
                                     worker.val_loader,
                                     Server.local_personal_val_tracker,
                                     val=True,
                                     local=True)
            # Bookkeeping for the (simulated) synchronisation of this client.
            timing['start_sync_time'] = time.time()
            Server.args.global_index += 1

            logging_sync_time(timing)

        # Merge the online clients' models into the server model.
        fedavg_aggregation_centered(Clients, Server, online_clients)

        # Report the local (personal model) results.
        log_validation_centered(Server.args,
                                Server.local_val_tracker,
                                val=False,
                                local=True)
        if Server.args.fed_personal:
            log_validation_centered(Server.args,
                                    Server.local_personal_val_tracker,
                                    val=True,
                                    local=True)

        # Report the server model results.
        log_validation_centered(Server.args,
                                Server.global_val_tracker,
                                val=False,
                                local=False)
        if Server.args.fed_personal:
            log_validation_centered(Server.args,
                                    Server.global_personal_val_tracker,
                                    val=True,
                                    local=False)

        # global logging, then restart the round clock.
        logging_globally(timing, round_started_at)
        round_started_at = time.time()
示例#6
0
def train_and_validate_federated(client):
    """Run the federated training/validation loop on this process.

    The baseline scheme is FedAvg (https://arxiv.org/abs/1602.05629).
    ``client.args.federated_type`` additionally selects the 'fedgate',
    'scaffold', 'fedprox' and 'qsparse' branches handled below; any other
    value falls back to ``fedavg_aggregation``.

    For each of ``client.args.num_comms`` communication rounds this function:
      1. asks ``set_online_clients`` for the clients taking part and builds
         a process group made of those clients plus rank 0 (rank 0 is used
         as the server: models are distributed with ``src=0``),
      2. obtains the server model and copies it into ``client.model``,
      3. runs local steps over ``client.train_loader`` until
         ``is_sync_fed`` reports that it is time to synchronise,
      4. validates the local model, aggregates into ``client.model_server``
         and validates the aggregated model,
      5. waits on a barrier over ``client.all_clients_group``.

    Args:
        client: node object. This function reads ``args``, ``model``,
            ``model_server``, ``optimizer``, ``scheduler``, ``criterion``,
            ``metrics``, ``train_loader``, ``val_loader``, ``test_loader``
            and ``all_clients_group`` from it, plus the variant-specific
            ``model_delta``, ``model_memory``, ``model_client_control``
            and ``model_server_control``.

    Returns:
        None. Results are left on ``client`` (``model``, ``model_server``,
        the control/delta models) and on the counters in ``client.args``
        (``rounds_comm``, ``comm_time``, ``local_index``,
        ``local_data_seen``, ``global_index``).

    TODO: Merge different models under this method
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Evaluation-only mode: rank 0 validates the current model on the
        # test loader and returns without training.
        # NOTE(review): only rank 0 takes this early exit; the other ranks
        # fall through to the training loop below -- confirm this is intended.
        do_validate(client.args,
                    client.model,
                    client.optimizer,
                    client.criterion,
                    client.metrics,
                    client.test_loader,
                    client.all_clients_group,
                    data_mode='test')
        return

    # Set up the timing tracker and the wall-clock reference for the round.

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        # New accumulator for this round's communication time; it is
        # incremented through comm_time[-1] below.
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        log("Starting round {} of training".format(n_c + 1), client.args.debug)
        online_clients = set_online_clients(client.args)
        # In the first round rank 0 is always added to the training clients.
        if (n_c == 0) and (0 not in online_clients):
            online_clients += [0]
        # The communication group is the online clients plus rank 0, even
        # when rank 0 does not train in this round.
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        # Executed by every process, including the ones that are offline in
        # this round (the call sits outside the membership check below).
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            # Fetch the current server model (and, for scaffold, the server
            # control variate) with src=0, timing the call as communication.
            if client.args.federated_type == 'scaffold':
                st = time.time()
                client.model_server, client.model_server_control = distribute_model_server_control(
                    client.model_server,
                    client.model_server_control,
                    online_clients_group,
                    src=0)
                client.args.comm_time[-1] += time.time() - st
            else:
                st = time.time()
                client.model_server = distribute_model_server(
                    client.model_server, online_clients_group, src=0)
                client.args.comm_time[-1] += time.time() - st
            # Every participant starts the round from the server weights.
            client.model.load_state_dict(client.model_server.state_dict())
            # Counts the batches drawn in this round; passed to the
            # fedgate/scaffold aggregation below.
            local_steps = 0
            if client.args.graph.rank in online_clients:
                is_sync = False
                # Keep cycling over the train loader until is_sync_fed
                # signals that the local phase of this round is over.
                while not is_sync:
                    for _input, _target in client.train_loader:
                        local_steps += 1
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        # NOTE(review): local_step is not read again in this
                        # function.
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        # lr is also handed to the fedgate/scaffold
                        # aggregation after the local loop.
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)

                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)

                        # Skip batches with one sample because of BatchNorm issue in some models!
                        # The break abandons the rest of the current pass
                        # over the loader; unless is_sync is now True the
                        # while loop starts a new pass.
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()

                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)

                        # compute gradient and do local SGD step.
                        loss.backward()

                        # Variant-specific corrections applied directly to
                        # the gradients produced by backward().
                        if client.args.federated_type == 'fedgate':
                            # Update gradients with control variates
                            for client_param, delta_param in zip(
                                    client.model.parameters(),
                                    client.model_delta.parameters()):
                                client_param.grad.data -= delta_param.data
                        elif client.args.federated_type == 'scaffold':
                            # grad += server control - client control
                            for cp, ccp, scp in zip(
                                    client.model.parameters(),
                                    client.model_client_control.parameters(),
                                    client.model_server_control.parameters()):
                                cp.grad.data += scp.data - ccp.data
                        elif client.args.federated_type == 'fedprox':
                            # Adding proximal gradients and loss for fedprox
                            # backward() has already run, so the term added
                            # to `loss` here does not influence gradients;
                            # the proximal gradient mu * (w - w_server) is
                            # added to .grad explicitly instead.
                            for client_param, server_param in zip(
                                    client.model.parameters(),
                                    client.model_server.parameters()):
                                # Rank 0 prints the distance for every
                                # parameter tensor on every step.
                                if client.args.graph.rank == 0:
                                    print(
                                        "distance norm for prox is:{}".format(
                                            torch.norm(client_param.data -
                                                       server_param.data)))
                                loss += client.args.fedprox_mu / 2 * torch.norm(
                                    client_param.data - server_param.data)
                                client_param.grad.data += client.args.fedprox_mu * (
                                    client_param.data - server_param.data)

                        # For 'robust' architectures the sign of the gradient
                        # on model.noise is flipped before the step
                        # (presumably so the noise is updated by ascent --
                        # TODO confirm).
                        if 'robust' in client.args.arch:
                            client.model.noise.grad.data *= -1

                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        # Rescale the noise back to norm 1 whenever the step
                        # pushed its norm above 1.
                        if 'robust' in client.args.arch:
                            if torch.norm(client.model.noise.data) > 1:
                                client.model.noise.data /= torch.norm(
                                    client.model.noise.data)

                        # logging locally.
                        # logging_computing(tracker, loss_v, performance_v, _input, lr)

                        # display the logging info.
                        # logging_display_training(args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        # model_local = deepcopy(model_client)
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break

            else:
                # Only rank 0 can reach this branch: it is in the group as
                # the server but was not selected to train in this round.
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            # Validate the local models before sync
            do_validate(client.args,
                        client.model,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train',
                        local=True)
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation',
                            local=True)
            # Sync the model server based on client models
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1

            # Pick the aggregation routine that matches federated_type;
            # 'fedprox' and any other value use plain fedavg_aggregation.
            # NOTE(review): `lr` is only assigned inside the local training
            # loop above; on rank 0 in a round where it does not train, the
            # value from an earlier round is reused -- confirm.
            if client.args.federated_type == 'fedgate':
                client.model_server, client.model_delta = fedgate_aggregation(
                    client.args, client.model_server, client.model,
                    client.model_delta, client.model_memory,
                    online_clients_group, online_clients, client.optimizer, lr,
                    local_steps)
            elif client.args.federated_type == 'scaffold':
                client.model_server, client.model_client_control, client.model_server_control = scaffold_aggregation(
                    client.args, client.model_server, client.model,
                    client.model_server_control, client.model_client_control,
                    online_clients_group, online_clients, client.optimizer, lr,
                    local_steps)
            elif client.args.federated_type == 'qsparse':
                client.model_server = qsparse_aggregation(
                    client.args, client.model_server, client.model,
                    online_clients_group, online_clients, client.optimizer,
                    client.model_memory)
            else:
                client.model_server = fedavg_aggregation(
                    client.args, client.model_server, client.model,
                    online_clients_group, online_clients, client.optimizer)
            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args,
                        client.model_server,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation')

            # logging.
            logging_globally(tracker, start_global_time)

            # reset start round time.
            start_global_time = time.time()

            # validate the model at the server
            # Only rank 0 evaluates the aggregated model on the test loader.
            if client.args.graph.rank == 0:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test')
            log(
                'This round communication time is: {}'.format(
                    client.args.comm_time[-1]), client.args.debug)
        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)
        # Every process, online or not, meets here before the next round.
        dist.barrier(group=client.all_clients_group)

    return
示例#7
0
def train_and_validate_federated_apfl(client):
    """Run the personalized federated training loop (APFL) on this process.

    Official implementation for https://arxiv.org/abs/2003.13461

    For each of ``client.args.num_comms`` communication rounds this function:
      1. asks ``set_online_clients`` for the clients taking part and builds
         a process group made of those clients plus rank 0 (rank 0 is used
         as the server: the model is distributed with ``src=0``),
      2. obtains the server model and copies it into ``client.model``,
      3. on every batch takes one optimizer step on ``client.model`` and
         one step on ``client.model_personal`` (through
         ``inference_personal`` with ``client.args.fed_personal_alpha``),
         until ``is_sync_fed`` reports that it is time to synchronise,
      4. validates the personalized model, aggregates ``client.model`` into
         ``client.model_server`` with ``fedavg_aggregation`` and validates
         the aggregated model,
      5. waits on a barrier over ``client.all_clients_group``.

    When ``client.args.evaluate`` is set, rank 0 only validates
    ``client.model`` on the test loader and returns immediately.

    Args:
        client: node object. This function reads ``args``, ``model``,
            ``model_server``, ``model_personal``, ``optimizer``,
            ``optimizer_personal``, ``scheduler``, ``criterion``,
            ``metrics``, ``train_loader``, ``val_loader``, ``test_loader``
            and ``all_clients_group`` from it, and updates ``model``,
            ``model_server``, ``model_personal`` and the counters in
            ``client.args`` (``rounds_comm``, ``comm_time``,
            ``local_index``, ``local_data_seen``, ``global_index``,
            ``fed_personal_alpha``).
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        # NOTE(review): only rank 0 takes this early exit; the other ranks
        # fall through to the training loop below -- confirm this is intended.
        do_validate(client.args,
                    client.model,
                    client.optimizer,
                    client.criterion,
                    client.metrics,
                    client.test_loader,
                    client.all_clients_group,
                    data_mode='test')
        return

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c), client.args.debug)
        online_clients = set_online_clients(client.args)
        # In the first round rank 0 is always added to the training clients.
        if (n_c == 0) and (0 not in online_clients):
            online_clients += [0]
        # The communication group is the online clients plus rank 0, even
        # when rank 0 does not train in this round.
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        # Executed by every process, including the ones that are offline in
        # this round (the call sits outside the membership check below).
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            client.model_server = distribute_model_server(client.model_server,
                                                          online_clients_group,
                                                          src=0)
            # Every participant starts the round from the server weights.
            client.model.load_state_dict(client.model_server.state_dict())
            if client.args.graph.rank in online_clients:
                is_sync = False
                ep = -1  # counting number of epochs
                # Keep cycling over the train loader until is_sync_fed
                # signals that the local phase of this round is over.
                while not is_sync:
                    ep += 1
                    for i, (_input, _target) in enumerate(client.train_loader):
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        # NOTE(review): local_step is not read again in this
                        # function.
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)

                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)
                        # Skip batches with one sample because of BatchNorm issue in some models!
                        # The break abandons the rest of the current pass
                        # over the loader; unless is_sync is now True the
                        # while loop starts a new pass.
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # First step: update the local copy of the shared
                        # model (client.model) on this batch.
                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)

                        # compute gradient and do local SGD step.
                        loss.backward()
                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        # Second step: update the personal model on the same
                        # batch. Both optimizers are zeroed first because
                        # inference_personal receives both models.
                        client.optimizer.zero_grad()
                        client.optimizer_personal.zero_grad()
                        loss_personal, performance_personal = inference_personal(
                            client.model_personal, client.model,
                            client.args.fed_personal_alpha, client.criterion,
                            client.metrics, _input, _target)

                        # compute gradient and do local SGD step.
                        loss_personal.backward()
                        client.optimizer_personal.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        # update alpha
                        # Done once per round: only on the first batch of
                        # the first pass over the loader.
                        # NOTE(review): this branch is reached only by ranks
                        # that train in this round, while
                        # online_clients_group also contains rank 0; if
                        # global_average is a collective over the group,
                        # confirm it cannot block when rank 0 is offline.
                        if client.args.fed_adaptive_alpha and i == 0 and ep == 0:
                            client.args.fed_personal_alpha = alpha_update(
                                client.model, client.model_personal,
                                client.args.fed_personal_alpha,
                                lr)  #0.1/np.sqrt(1+args.local_index))
                            average_alpha = client.args.fed_personal_alpha
                            average_alpha = global_average(
                                average_alpha,
                                client.args.graph.n_nodes,
                                group=online_clients_group)
                            log("New alpha is:{}".format(average_alpha.item()),
                                client.args.debug)

                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)

                        # display the logging info.
                        # logging_display_training(args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                # Only rank 0 can reach this branch: it is in the group as
                # the server but was not selected to train in this round.
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            # Validate the personalized model (client.model together with
            # client.model_personal and the current alpha) before syncing.
            do_validate(client.args,
                        client.model,
                        client.optimizer_personal,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train',
                        personal=True,
                        model_personal=client.model_personal,
                        alpha=client.args.fed_personal_alpha)
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model,
                            client.optimizer_personal,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation',
                            personal=True,
                            model_personal=client.model_personal,
                            alpha=client.args.fed_personal_alpha)

            # Sync the model server based on model_clients
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1
            client.model_server = fedavg_aggregation(
                client.args, client.model_server, client.model,
                online_clients_group, online_clients, client.optimizer)
            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args,
                        client.model_server,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation')

            # logging.
            logging_globally(tracker, start_global_time)

            # reset start round time.
            start_global_time = time.time()

            # validate the models at the test data
            if client.args.fed_personal_test:
                # Uses client.model, like the personalized train/validation
                # calls above; `client.model_client` is not assigned or read
                # anywhere else in this function.
                do_validate(client.args,
                            client.model,
                            client.optimizer_personal,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test',
                            personal=True,
                            model_personal=client.model_personal,
                            alpha=client.args.fed_personal_alpha)
            elif client.args.graph.rank == 0:
                # Without personalized testing only rank 0 evaluates the
                # aggregated model on the test loader.
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test')
        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)
        # Every process, online or not, meets here before the next round.
        dist.barrier(group=client.all_clients_group)