Exemplo n.º 1
0
def log_validation_per_client_centered(args,
                                       OnlineClients,
                                       online_clients,
                                       val=True,
                                       local=False):
    """Log worst/best/std of per-client top-1 accuracy over the online clients.

    The (local, val) flag pair selects which tracker is read on each client:
    local/personal vs. global, validation vs. train.
    """
    # Pick the tracker attribute name once, then gather top-1 averages.
    if local:
        tracker_attr = ('local_personal_val_tracker'
                        if val else 'local_val_tracker')
    else:
        tracker_attr = ('global_personal_val_tracker'
                        if val else 'global_val_tracker')
    acc = [
        getattr(OnlineClients[oc], tracker_attr)['top1'].avg
        for oc in online_clients
    ]

    log('{} per client stat for {} at batch: {}. Epoch: {}. Process: {}. Worst: {:.3f} Best: {:.3f} Var: {:.3f} Comm: {}'
        .format('Personal' if local else 'Global',
                'validation' if val else 'train', args.local_index,
                args.epoch, args.graph.rank, np.min(acc), np.max(acc),
                np.std(acc), args.rounds_comm),
        debug=args.debug)
    return
Exemplo n.º 2
0
def get_data_stat(args, train_loader, test_loader=None):
    """Record per-worker data statistics (batch counts) on the args namespace.

    Mutates ``args`` in place: training batches per epoch, whole/warmup
    training batch totals, per-worker iteration budget, and the validation
    batch count (0 when no ``test_loader`` is given).
    """
    # Training-side statistics for this worker.
    args.num_batches_train_per_device_per_epoch = len(train_loader)
    args.num_whole_train_batches_per_worker = (
        args.num_batches_train_per_device_per_epoch * args.num_epochs)
    args.num_warmup_train_batches_per_worker = (
        args.num_batches_train_per_device_per_epoch * args.lr_warmup_epochs)
    args.num_iterations_per_worker = args.num_iterations

    # Validation-side statistics; zero when testing is disabled.
    args.num_batches_val_per_device_per_epoch = (
        len(test_loader) if test_loader is not None else 0)

    # Report the effective data setup for this worker.
    log(
        'we have {} epochs, \
        {} mini-batches per device for training. \
        {} mini-batches per device for test. \
        The batch size: {}.'.format(
            args.num_epochs, args.num_batches_train_per_device_per_epoch,
            args.num_batches_val_per_device_per_epoch, args.batch_size),
        args.debug)
Exemplo n.º 3
0
def _check_model_at_sync(iter,
                         gpu_id,
                         model,
                         is_weight=False,
                         is_gradient=True,
                         debug=True):
    """Log the norm of the model's first parameter (weights and/or grads).

    Cheap sanity fingerprint to compare model state across processes at a
    sync point. NOTE: the ``iter`` argument shadows the builtin of the same
    name.
    """
    # Only the first parameter is inspected; its norm suffices as a fingerprint.
    param = list(model.parameters())[0]
    if is_weight:
        log(
            "iter:{}, check process {}'s weights for 1st variable:{}".format(
                iter, gpu_id, torch.norm(param.data)), debug)
    if is_gradient:
        log(
            "iter:{}, check process {}'s gradients for 1st variable:{}".format(
                iter, gpu_id, torch.norm(param.grad.data)), debug)
Exemplo n.º 4
0
def log_validation_centered(args,
                            val_tracker,
                            personal=False,
                            val=True,
                            local=False):
    """Log aggregated Prec@1/Prec@5/loss from a tracker (centered setting)."""
    top1 = val_tracker['top1'].avg
    top5 = val_tracker['top5'].avg
    loss_avg = val_tracker['losses'].avg
    scope = 'Personal' if personal or local else 'Global'
    phase = 'validation' if val else 'train'

    log('{} performance for {} at batch: {}. Epoch: {}. Process: {}. Prec@1: {:.3f} Prec@5: {:.3f} Loss: {:.3f} Comm: {}'
        .format(scope, phase, args.local_index, args.epoch,
                args.graph.rank, top1, top5,
                loss_avg, args.rounds_comm),
        debug=args.debug)
    return
Exemplo n.º 5
0
    def initialize(self):
        """Build training components, configure logging, and join the process group.

        Side effects: creates ``self.model``/``criterion``/``scheduler``/
        ``optimizer``/``metrics``, adds ``finish_one_epoch`` to ``self.args``,
        keeps a deep copy of the fresh model in ``self.model_server``, and
        creates ``self.all_clients_group`` over all ranks.
        """
        init_config(self.args)
        self.model, self.criterion, self.scheduler, self.optimizer, self.metrics = create_components(
            self.args)
        self.args.finish_one_epoch = False
        # Create a model server on each client to keep a copy of the server model at each communication round.
        self.model_server = deepcopy(self.model)

        # Logging must be configured before the first log call below.
        configure_log(self.args)
        log_args(self.args, debug=self.args.debug)
        log('Rank {} with block {} on {} {}-{}'.format(
            self.args.graph.rank,
            self.args.graph.ranks_with_blocks[self.args.graph.rank],
            platform.node(), 'GPU' if self.args.graph.on_cuda else 'CPU',
            self.args.graph.device),
            debug=self.args.debug)

        # Process group spanning every client rank, used for global collectives/barriers.
        self.all_clients_group = dist.new_group(self.args.graph.ranks)
Exemplo n.º 6
0
def define_dataset(args,
                   shuffle,
                   test=True,
                   Partitioner=None,
                   return_partitioner=False):
    """Create the train (and optionally val/test) loaders for this rank.

    Returns a list ``[train, test]`` extended with ``val`` (and ``val1`` for
    'perfedavg') when ``args.fed_personal`` is set; wrapped in a tuple with
    the partitioner when ``return_partitioner`` is True.
    """
    log('create {} dataset for rank {}'.format(args.data, args.graph.rank),
        args.debug)

    train_loader = partition_dataset(args,
                                     shuffle,
                                     dataset_type='train',
                                     Partitioner=Partitioner,
                                     return_partitioner=return_partitioner)
    if return_partitioner:
        train_loader, Partitioner = train_loader
    # Personalized FL splits extra validation loader(s) out of the train loader.
    val_loader = None
    val_loader1 = None
    if args.fed_personal:
        if args.federated_type == 'perfedavg':
            train_loader, val_loader, val_loader1 = train_loader
        else:
            train_loader, val_loader = train_loader
    test_loader = (partition_dataset(args, shuffle, dataset_type='test')
                   if test else None)

    # Record per-worker batch statistics on args.
    get_data_stat(args, train_loader, test_loader)

    out = [train_loader, test_loader]
    if args.fed_personal:
        out.append(val_loader)
        if args.federated_type == 'perfedavg':
            out.append(val_loader1)
    if return_partitioner:
        out = (out, Partitioner)
    return out
Exemplo n.º 7
0
def train_and_validate_drfa_centered(Clients, Server):
    """Run DRFA training/validation with all clients simulated in one process.

    Distributionally Robust Federated Averaging: in each communication round
    the online clients run local SGD from the current server model; a model
    snapshot is taken at a uniformly random local step k (``kth_model``) and
    aggregated; the per-client mixing weights ``Server.lambda_vector`` are
    then dual-updated from losses of the aggregated snapshot on a second
    sample of online clients and projected back onto the simplex.

    Args:
        Clients: indexable collection of client objects (model, loaders,
            optimizer, trackers, args).
        Server: server object holding the global model, trackers, test loader
            and ``lambda_vector``.
    """
    log('start training and validation with Federated setting in a centered way.'
        )

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.')

    # Initialize lambda proportional to each client's sample count, normalized
    # to a probability vector over clients.
    for oc in range(Server.args.graph.n_nodes):
        Server.lambda_vector[oc] = Clients[oc].args.num_samples_per_epoch
    Server.lambda_vector /= Server.lambda_vector.sum()
    # Number of communication rounds in federated setting should be defined
    for n_c in range(Server.args.num_comms):
        Server.args.rounds_comm += 1
        Server.args.local_index += 1
        Server.args.quant_error = 0.0

        # Preset variables for this round of communication
        Server.zero_grad()
        Server.reset_tracker(Server.local_val_tracker)
        Server.reset_tracker(Server.global_val_tracker)
        Server.reset_tracker(Server.global_test_tracker)
        if Server.args.fed_personal:
            Server.reset_tracker(Server.local_personal_val_tracker)
            Server.reset_tracker(Server.global_personal_val_tracker)

        # Configuring the devices for this round of communication
        log("Starting round {} of training".format(n_c + 1))
        online_clients = set_online_clients_centered(Server.args)

        # DRFA step-size for the dual (lambda) update decays geometrically.
        Server.args.drfa_gamma *= 0.9

        for oc in online_clients:
            # Each online client starts the round from the current server model.
            Clients[oc].model.load_state_dict(Server.model.state_dict())
            Clients[oc].args.rounds_comm = Server.args.rounds_comm
            local_steps = 0
            is_sync = False

            # Evaluate the (pre-update) server model on this client's train data.
            do_validate_centered(Clients[oc].args,
                                 Server.model,
                                 Server.criterion,
                                 Server.metrics,
                                 Server.optimizer,
                                 Clients[oc].train_loader,
                                 Server.global_val_tracker,
                                 val=False,
                                 local=False)
            if Server.args.per_class_acc:
                # Per-client trackers are also maintained for per-class stats.
                Clients[oc].reset_tracker(Clients[oc].local_val_tracker)
                Clients[oc].reset_tracker(Clients[oc].global_val_tracker)
                if Server.args.fed_personal:
                    Clients[oc].reset_tracker(
                        Clients[oc].local_personal_val_tracker)
                    Clients[oc].reset_tracker(
                        Clients[oc].global_personal_val_tracker)
                    do_validate_centered(
                        Clients[oc].args,
                        Server.model,
                        Server.criterion,
                        Server.metrics,
                        Server.optimizer,
                        Clients[oc].val_loader,
                        Clients[oc].global_personal_val_tracker,
                        val=True,
                        local=False)
                do_validate_centered(Clients[oc].args,
                                     Server.model,
                                     Server.criterion,
                                     Server.metrics,
                                     Server.optimizer,
                                     Clients[oc].train_loader,
                                     Clients[oc].global_val_tracker,
                                     val=False,
                                     local=False)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args,
                                     Server.model,
                                     Server.criterion,
                                     Server.metrics,
                                     Server.optimizer,
                                     Clients[oc].val_loader,
                                     Server.global_personal_val_tracker,
                                     val=True,
                                     local=False)

            # perfedavg needs one fixed validation batch for its second
            # (meta) gradient step; grab the first batch only.
            if Server.args.federated_type == 'perfedavg':
                for _input_val, _target_val in Clients[oc].val_loader1:
                    _input_val, _target_val = load_data_batch(
                        Clients[oc].args, _input_val, _target_val, tracker)
                    break

            # DRFA: snapshot the model at a uniformly random local step k.
            k = torch.randint(low=1, high=Server.args.local_step,
                              size=(1, )).item()
            while not is_sync:
                if Server.args.arch == 'rnn':
                    Clients[oc].model.init_hidden(Server.args.batch_size)
                for _input, _target in Clients[oc].train_loader:
                    local_steps += 1
                    if k == local_steps:
                        Clients[oc].kth_model.load_state_dict(
                            Clients[oc].model.state_dict())
                    Clients[oc].model.train()

                    # update local step.
                    logging_load_time(tracker)

                    # update local index and get local step
                    Clients[oc].args.local_index += 1
                    Clients[oc].args.local_data_seen += len(_target)
                    get_current_epoch(Clients[oc].args)
                    local_step = get_current_local_step(Clients[oc].args)

                    # adjust learning rate (based on the # of accessed samples)
                    lr = adjust_learning_rate(Clients[oc].args,
                                              Clients[oc].optimizer,
                                              Clients[oc].scheduler)

                    # load data
                    _input, _target = load_data_batch(Clients[oc].args, _input,
                                                      _target, tracker)

                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if _input.size(0) == 1:
                        is_sync = is_sync_fed(Clients[oc].args)
                        break

                    # inference and get current performance.
                    Clients[oc].optimizer.zero_grad()

                    loss, performance = inference(Clients[oc].model,
                                                  Clients[oc].criterion,
                                                  Clients[oc].metrics,
                                                  _input,
                                                  _target,
                                                  rnn=Server.args.arch
                                                  in ['rnn'])

                    # compute gradient and do local SGD step.
                    loss.backward()

                    if Clients[oc].args.federated_type == 'fedgate':
                        # Update gradients with control variates
                        for client_param, delta_param in zip(
                                Clients[oc].model.parameters(),
                                Clients[oc].model_delta.parameters()):
                            client_param.grad.data -= delta_param.data
                    elif Clients[oc].args.federated_type == 'scaffold':
                        # SCAFFOLD drift correction: grad += c_server - c_client.
                        for cp, ccp, scp in zip(
                                Clients[oc].model.parameters(),
                                Clients[oc].model_client_control.parameters(),
                                Server.model_server_control.parameters()):
                            cp.grad.data += scp.data - ccp.data
                    elif Clients[oc].args.federated_type == 'fedprox':
                        # Adding proximal gradients and loss for fedprox
                        for client_param, server_param in zip(
                                Clients[oc].model.parameters(),
                                Server.model.parameters()):
                            loss += Clients[
                                oc].args.fedprox_mu / 2 * torch.norm(
                                    client_param.data - server_param.data)
                            client_param.grad.data += Clients[
                                oc].args.fedprox_mu * (client_param.data -
                                                       server_param.data)

                    # Adversarial noise (robust models) ascends, hence the sign flip.
                    if 'robust' in Clients[oc].args.arch:
                        Clients[oc].model.noise.grad.data *= -1

                    Clients[oc].optimizer.step(
                        apply_lr=True,
                        apply_in_momentum=Clients[oc].args.in_momentum,
                        apply_out_momentum=False)

                    # Project the robust-noise variable back into the unit ball.
                    if 'robust' in Clients[oc].args.arch:
                        if torch.norm(Clients[oc].model.noise.data) > 1:
                            Clients[oc].model.noise.data /= torch.norm(
                                Clients[oc].model.noise.data)

                    if Clients[oc].args.federated_type == 'perfedavg':
                        # _input_val, _target_val = Clients[oc].load_next_val_batch()
                        # Second (meta) step on the held-out batch with beta as lr.
                        lr = adjust_learning_rate(
                            Clients[oc].args,
                            Clients[oc].optimizer,
                            Clients[oc].scheduler,
                            lr_external=Clients[oc].args.perfedavg_beta)
                        if _input_val.size(0) == 1:
                            is_sync = is_sync_fed(Clients[oc].args)
                            break

                        Clients[oc].optimizer.zero_grad()
                        loss, performance = inference(Clients[oc].model,
                                                      Clients[oc].criterion,
                                                      Clients[oc].metrics,
                                                      _input_val, _target_val)
                        loss.backward()
                        Clients[oc].optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=Clients[oc].args.in_momentum,
                            apply_out_momentum=False)

                    # reset load time for the tracker.
                    tracker['start_load_time'] = time.time()
                    # model_local = deepcopy(model_client)
                    is_sync = is_sync_fed(Clients[oc].args)
                    if is_sync:
                        break

            # Validate this client's locally-updated model on its train data.
            do_validate_centered(Clients[oc].args,
                                 Clients[oc].model,
                                 Clients[oc].criterion,
                                 Clients[oc].metrics,
                                 Clients[oc].optimizer,
                                 Clients[oc].train_loader,
                                 Server.local_val_tracker,
                                 val=False,
                                 local=True)
            if Server.args.per_class_acc:
                do_validate_centered(Clients[oc].args,
                                     Clients[oc].model,
                                     Clients[oc].criterion,
                                     Clients[oc].metrics,
                                     Clients[oc].optimizer,
                                     Clients[oc].train_loader,
                                     Clients[oc].local_val_tracker,
                                     val=False,
                                     local=True)
                if Server.args.fed_personal:
                    do_validate_centered(
                        Clients[oc].args,
                        Clients[oc].model,
                        Clients[oc].criterion,
                        Clients[oc].metrics,
                        Clients[oc].optimizer,
                        Clients[oc].val_loader,
                        Clients[oc].local_personal_val_tracker,
                        val=True,
                        local=True)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args,
                                     Clients[oc].model,
                                     Clients[oc].criterion,
                                     Clients[oc].metrics,
                                     Clients[oc].optimizer,
                                     Clients[oc].val_loader,
                                     Server.local_personal_val_tracker,
                                     val=True,
                                     local=True)
            # Sync the model server based on model_clients
            tracker['start_sync_time'] = time.time()
            Server.args.global_index += 1
            logging_sync_time(tracker)

        # NOTE(review): local_steps and lr below are whatever the LAST client
        # in the loop left behind — looks intentional only if all clients run
        # the same number of steps; verify.
        if Server.args.federated_type == 'scaffold':
            scaffold_aggregation_centered(Clients, Server, online_clients,
                                          local_steps, lr)
        elif Server.args.federated_type == 'fedgate':
            fedgate_aggregation_centered(Clients, Server, online_clients,
                                         local_steps, lr)
        elif Server.args.federated_type == 'qsparse':
            qsparse_aggregation_centered(Clients, Server, online_clients,
                                         local_steps, lr)
        else:
            fedavg_aggregation_centered(Clients, Server, online_clients,
                                        Server.lambda_vector.numpy())

        # Aggregate Kth models
        aggregate_kth_model_centered(Clients, Server, online_clients)

        # Log performance
        # Client training performance
        log_validation_centered(Server.args,
                                Server.local_val_tracker,
                                val=False,
                                local=True)
        # Server training performance
        log_validation_centered(Server.args,
                                Server.global_val_tracker,
                                val=False,
                                local=False)
        if Server.args.fed_personal:
            # Client validation performance
            log_validation_centered(Server.args,
                                    Server.local_personal_val_tracker,
                                    val=True,
                                    local=True)
            # Server validation performance
            log_validation_centered(Server.args,
                                    Server.global_personal_val_tracker,
                                    val=True,
                                    local=False)

        # Per client stats
        if Server.args.per_class_acc:
            log_validation_per_client_centered(Server.args,
                                               Clients,
                                               online_clients,
                                               val=False,
                                               local=False)
            log_validation_per_client_centered(Server.args,
                                               Clients,
                                               online_clients,
                                               val=False,
                                               local=True)
            if Server.args.fed_personal:
                log_validation_per_client_centered(Server.args,
                                                   Clients,
                                                   online_clients,
                                                   val=True,
                                                   local=False)
                log_validation_per_client_centered(Server.args,
                                                   Clients,
                                                   online_clients,
                                                   val=True,
                                                   local=True)

        # Test on server
        do_validate_centered(Server.args,
                             Server.model,
                             Server.criterion,
                             Server.metrics,
                             Server.optimizer,
                             Server.test_loader,
                             Server.global_test_tracker,
                             val=False,
                             local=False)
        log_test_centered(Server.args, Server.global_test_tracker)

        # Dual (lambda) update: sample a fresh set of online clients and score
        # the aggregated kth_model on one batch per client.
        online_clients_lambda = set_online_clients_centered(Server.args)
        loss_tensor = torch.zeros(Server.args.graph.n_nodes)
        num_online_clients = len(online_clients_lambda)
        for ocl in online_clients_lambda:
            for _input, _target in Clients[ocl].train_loader:
                _input, _target = load_data_batch(Clients[ocl].args, _input,
                                                  _target, tracker)
                loss, _ = inference(Server.kth_model,
                                    Clients[ocl].criterion,
                                    Clients[ocl].metrics,
                                    _input,
                                    _target,
                                    rnn=Server.args.arch in ['rnn'])
                break
            # Rescale so the sampled losses are an unbiased estimate over all nodes.
            loss_tensor[ocl] = loss * (Server.args.graph.n_nodes /
                                       num_online_clients)
        Server.lambda_vector += Server.args.drfa_gamma * Server.args.local_step * loss_tensor
        # Project the updated lambda back onto the probability simplex.
        lambda_vector = projection_simplex_sort(
            Server.lambda_vector.detach().numpy())
        # NOTE(review): debug print left in the training loop — consider log().
        print(lambda_vector)
        # Avoid zero probability
        # NOTE(review): only the FIRST sub-threshold entry is clamped here,
        # while the AFL variant clamps all of them — confirm which is intended.
        lambda_zeros = np.argwhere(lambda_vector <= 1e-3)
        if len(lambda_zeros) > 0:
            lambda_vector[lambda_zeros[0]] = 1e-3
            lambda_vector /= np.sum(lambda_vector)
        Server.lambda_vector = torch.tensor(lambda_vector)

        # logging.
        logging_globally(tracker, start_global_time)

        # reset start round time.
        start_global_time = time.time()

        # validate the model at the server
        # if args.graph.rank == 0:
        #     do_test(args, model_server, optimizer, criterion, metrics, test_loader)
        # do_validate_test(args, model_server, optimizer, criterion, metrics, test_loader)
    return
Exemplo n.º 8
0
def train_and_validate_federated_afl(client):
    """The training scheme of Federated Learning systems.

    This is the implementation of Agnostic Federated Learning
    https://arxiv.org/abs/1902.00146

    Runs on every rank (rank 0 doubles as the server): per round, a subset of
    clients trains locally from the broadcast server model, models are
    aggregated weighted by ``lambda_vector``, and rank 0 dual-updates lambda
    from the per-client losses and projects it onto the simplex.
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    # Evaluation-only mode: the server tests the current model and exits.
    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args,
                    client.model,
                    client.optimizer,
                    client.criterion,
                    client.metrics,
                    client.test_loader,
                    client.all_clients_group,
                    data_mode='test')
        return

    # Initialize lambda variable proportianate to their sample size
    # (gathered on rank 0; non-server ranks hold a uniform placeholder).
    if client.args.graph.rank == 0:
        gather_list_size = [
            torch.tensor(0.0) for _ in range(client.args.graph.n_nodes)
        ]
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    gather_list=gather_list_size,
                    dst=0)
        client.lambda_vector = torch.stack(
            gather_list_size) / client.args.train_dataset_size
    else:
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    dst=0)
        client.lambda_vector = torch.tensor([1 / client.args.graph.n_nodes] *
                                            client.args.graph.n_nodes)

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c + 1), client.args.debug)
        online_clients = set_online_clients(client.args)
        if n_c == 0:
            # The first round server should be in the communication to initilize its own training
            online_clients = online_clients if 0 in online_clients else online_clients + [
                0
            ]
        # Rank 0 always joins the group so it can serve the model.
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            # Receive the server model and current lambda; time the communication.
            st = time.time()
            client.model_server = distribute_model_server(client.model_server,
                                                          online_clients_group,
                                                          src=0)
            dist.broadcast(client.lambda_vector,
                           src=0,
                           group=online_clients_group)
            client.args.comm_time[-1] += time.time() - st
            client.model.load_state_dict(client.model_server.state_dict())

            # This loss tensor is for those clients not participating in the first round
            loss = torch.tensor(0.0)
            # Start running updates on local machines
            if client.args.graph.rank in online_clients:
                is_sync = False
                while not is_sync:
                    for _input, _target in client.train_loader:

                        client.model.train()
                        # update local step.
                        logging_load_time(tracker)
                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)

                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)

                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)
                        # compute gradient and do local SGD step.
                        loss.backward()
                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)

                        # display the logging info.
                        # logging_display_training(args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            # Validate the local models befor sync
            do_validate(client.args,
                        client.model,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train',
                        local=True)
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation',
                            local=True)
            # Sync the model server based on client models
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1

            # AFL aggregation weighted by this client's lambda entry; also
            # collects per-client losses on rank 0 for the dual update below.
            client.model_server, loss_tensor_online = afl_aggregation(
                client.args, client.model_server, client.model,
                client.lambda_vector[client.args.graph.rank].item(),
                torch.tensor(loss.item()), online_clients_group,
                online_clients, client.optimizer)

            # evaluate the sync time
            logging_sync_time(tracker)
            # Do the validation on the server model
            do_validate(client.args,
                        client.model_server,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation')

            # Updating lambda variable
            if client.args.graph.rank == 0:
                num_online_clients = len(
                    online_clients
                ) if 0 in online_clients else len(online_clients) + 1
                loss_tensor = torch.zeros(client.args.graph.n_nodes)
                loss_tensor[sorted(online_clients_server)] = loss_tensor_online
                # Dual update
                client.lambda_vector += client.args.drfa_gamma * loss_tensor
                # Projection into a simplex
                client.lambda_vector = euclidean_proj_simplex(
                    client.lambda_vector)
                # Avoid zero probability
                lambda_zeros = client.lambda_vector <= 1e-3
                if lambda_zeros.sum() > 0:
                    client.lambda_vector[lambda_zeros] = 1e-3
                    client.lambda_vector /= client.lambda_vector.sum()

            # logging.
            logging_globally(tracker, start_global_time)

            # reset start round time.
            start_global_time = time.time()
            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test')
            log(
                'This round communication time is: {}'.format(
                    client.args.comm_time[-1]), client.args.debug)
        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)
        # All ranks (online or not) synchronize before the next round.
        dist.barrier(group=client.all_clients_group)

    return
Exemplo n.º 9
0
def train_and_validate_federated_drfa(client):
    """The training scheme of Distributionally Robust Federated Learning DRFA.
        paper: https://papers.nips.cc/paper/2020/hash/ac450d10e166657ec8f93a1b65ca1b14-Abstract.html

        Per communication round:
          1. rank 0 distributes the server model (plus SCAFFOLD control variates
             when ``federated_type == 'scaffold'``), broadcasts the dual weights
             ``client.lambda_vector`` and a random local-step index ``k``;
          2. online clients run local SGD until ``is_sync_fed`` fires, saving a
             snapshot of their model at local step ``k`` into ``client.kth_model``;
          3. client models are aggregated into the server model (fedavg /
             fedgate / scaffold), each weighted by its own ``lambda_vector`` entry;
          4. a fresh random subset of clients evaluates the averaged k-th model
             on a single minibatch; rank 0 performs the dual ascent step on
             ``lambda_vector`` and projects it back onto the simplex.
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Evaluation-only mode: test on the server and return immediately.
        do_validate(client.args,
                    client.model,
                    client.optimizer,
                    client.criterion,
                    client.metrics,
                    client.test_loader,
                    client.all_clients_group,
                    data_mode='test')
        return

    # Initialize lambda variable proportionate to each client's sample size:
    # rank 0 gathers every client's num_samples_per_epoch and normalizes by
    # the total train set size; other ranks fall back to a uniform vector.
    if client.args.graph.rank == 0:
        gather_list_size = [
            torch.tensor(0.0) for _ in range(client.args.graph.n_nodes)
        ]
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    gather_list=gather_list_size,
                    dst=0)
        lambda_vector = torch.stack(
            gather_list_size) / client.args.train_dataset_size
    else:
        dist.gather(torch.tensor(client.args.num_samples_per_epoch,
                                 dtype=torch.float32),
                    dst=0)
        lambda_vector = torch.tensor([1 / client.args.graph.n_nodes] *
                                     client.args.graph.n_nodes)
    # NOTE(review): the local `lambda_vector` computed above is never read
    # again — the round loop below uses `client.lambda_vector`. Confirm
    # whether this was meant to assign `client.lambda_vector = lambda_vector`.

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c + 1), client.args.debug)
        online_clients = set_online_clients(client.args)
        if n_c == 0:
            # The first round server should be in the communication to initialize its own training
            online_clients = online_clients if 0 in online_clients else online_clients + [
                0
            ]
        # The server (rank 0) always takes part in the communication group.
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        online_clients_group = dist.new_group(online_clients_server)
        # Decay the dual step size every round.
        client.args.drfa_gamma *= 0.9
        if client.args.graph.rank in online_clients_server:
            # Distribute the server model (and SCAFFOLD controls) to the group.
            if client.args.federated_type == 'scaffold':
                st = time.time()
                client.model_server, client.model_server_control = distribute_model_server_control(
                    client.model_server,
                    client.model_server_control,
                    online_clients_group,
                    src=0)
                client.args.comm_time[-1] += time.time() - st
            else:
                st = time.time()
                # NOTE(review): the result here is bound to a local
                # `model_server` (unlike the scaffold branch, which rebinds
                # client.model_server) — presumably the broadcast mutates the
                # model in place; confirm against distribute_model_server.
                model_server = distribute_model_server(client.model_server,
                                                       online_clients_group,
                                                       src=0)
                client.args.comm_time[-1] += time.time() - st
            client.model.load_state_dict(client.model_server.state_dict())

            # Send related variables to drfa algorithm
            st = time.time()
            dist.broadcast(client.lambda_vector,
                           src=0,
                           group=online_clients_group)
            # Sending the random number k to all nodes:
            # Does not fully support the epoch mode now
            # (every rank draws a k, but the broadcast from src=0 overwrites it).
            k = torch.randint(low=1, high=client.args.local_step, size=(1, ))
            dist.broadcast(k, src=0, group=online_clients_group)
            client.args.comm_time[-1] += time.time() - st

            k = k.tolist()[0]
            local_steps = 0
            # Start running updates on local machines
            if client.args.graph.rank in online_clients:
                is_sync = False
                while not is_sync:
                    for _input, _target in client.train_loader:
                        local_steps += 1
                        # Getting the k-th model for dual variable update
                        if k == local_steps:
                            client.kth_model.load_state_dict(
                                client.model.state_dict())
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)

                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)
                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)

                        # compute gradient and do local SGD step.
                        loss.backward()

                        # Apply the federated-type-specific gradient correction
                        # before stepping the local optimizer.
                        if client.args.federated_type == 'fedgate':
                            for client_param, delta_param in zip(
                                    client.model.parameters(),
                                    client.model_delta.parameters()):
                                client_param.grad.data -= delta_param.data
                        elif client.args.federated_type == 'scaffold':
                            for cp, ccp, scp in zip(
                                    client.model.parameters(),
                                    client.model_client_control.parameters(),
                                    client.model_server_control.parameters()):
                                cp.grad.data += scp.data - ccp.data

                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)

                        # display the logging info.
                        # logging_display_training(client.args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            # Validate the local models before sync. Note this runs for every
            # rank in online_clients_server, including the server rank even
            # when it did not train this round.
            do_validate(client.args,
                        client.model,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train',
                        local=True)
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation',
                            local=True)
            # Sync the model server based on model_clients
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1

            # Aggregate client models into the server model; each client's
            # contribution is weighted by its entry of lambda_vector.
            # NOTE(review): `lr` is only assigned inside the training loop
            # above — for a rank in online_clients_server but not in
            # online_clients it would be unbound here; confirm that path.
            if client.args.federated_type == 'fedgate':
                client.model_server, client.model_delta = fedgate_aggregation(
                    client.args,
                    client.model_server,
                    client.model,
                    client.model_delta,
                    client.model_memory,
                    online_clients_group,
                    online_clients,
                    client.optimizer,
                    lr,
                    local_steps,
                    lambda_weight=client.lambda_vector[
                        client.args.graph.rank].item())
            elif client.args.federated_type == 'scaffold':
                client.model_server, client.model_client_control, client.model_server_control = scaffold_aggregation(
                    client.args,
                    client.model_server,
                    client.model,
                    client.model_server_control,
                    client.model_client_control,
                    online_clients_group,
                    online_clients,
                    client.optimizer,
                    lr,
                    local_steps,
                    lambda_weight=client.lambda_vector[
                        client.args.graph.rank].item())
            else:
                client.model_server = fedavg_aggregation(
                    client.args,
                    client.model_server,
                    client.model,
                    online_clients_group,
                    online_clients,
                    client.optimizer,
                    lambda_weight=client.lambda_vector[
                        client.args.graph.rank].item())
            # Average the kth_model snapshots across online clients
            # (needed for the dual update below).
            client.kth_model = aggregate_models_virtual(
                client.args, client.kth_model, online_clients_group,
                online_clients)
            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args,
                        client.model_server,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation')

            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test')

        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)

        # Update lambda parameters: draw a fresh random subset of clients,
        # evaluate the averaged k-th model on one minibatch each, and let
        # rank 0 do the dual ascent step.
        online_clients_lambda = set_online_clients(client.args)
        online_clients_server_lambda = online_clients_lambda if 0 in online_clients_lambda else [
            0
        ] + online_clients_lambda
        online_clients_group_lambda = dist.new_group(
            online_clients_server_lambda)

        if client.args.graph.rank in online_clients_server_lambda:
            st = time.time()
            client.kth_model = distribute_model_server(
                client.kth_model, online_clients_group_lambda, src=0)
            client.args.comm_time[-1] += time.time() - st
            loss = torch.tensor(0.0)

            if client.args.graph.rank in online_clients_lambda:
                # Evaluate the k-th model on a single minibatch only.
                for _input, _target in client.train_loader:
                    _input, _target = load_data_batch(client.args, _input,
                                                      _target, tracker)
                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if _input.size(0) == 1:
                        break
                    loss, _ = inference(client.kth_model, client.criterion,
                                        client.metrics, _input, _target)
                    break
            loss_tensor_online = loss_gather(
                client.args,
                torch.tensor(loss.item()),
                group=online_clients_group_lambda,
                online_clients=online_clients_lambda)
            if client.args.graph.rank == 0:
                num_online_clients = len(
                    online_clients_lambda
                ) if 0 in online_clients_lambda else len(
                    online_clients_lambda) + 1
                # Scatter the gathered losses into a full-size vector, scaled
                # to compensate for only a subset of clients being sampled.
                loss_tensor = torch.zeros(client.args.graph.n_nodes)
                loss_tensor[sorted(
                    online_clients_server_lambda)] = loss_tensor_online * (
                        client.args.graph.n_nodes / num_online_clients)
                # Dual update, then projection back onto the simplex.
                client.lambda_vector += client.args.drfa_gamma * client.args.local_step * loss_tensor
                client.lambda_vector = euclidean_proj_simplex(
                    client.lambda_vector)

                # Avoid zero probability
                lambda_zeros = client.lambda_vector <= 1e-3
                if lambda_zeros.sum() > 0:
                    client.lambda_vector[lambda_zeros] = 1e-3
                    client.lambda_vector /= client.lambda_vector.sum()

        # logging.
        logging_globally(tracker, start_global_time)

        # reset start round time.
        start_global_time = time.time()
        log(
            'This round communication time is: {}'.format(
                client.args.comm_time[-1]), client.args.debug)
        # Keep all ranks (online or not) in lockstep between rounds.
        dist.barrier(group=client.all_clients_group)
    return
Exemplo n.º 10
0
def train_and_validate(client):
    """The training scheme of Distributed Local SGD.

    Loops over the local train loader indefinitely: each worker takes local
    SGD steps and, every `local_step` iterations, aggregates gradients into
    the server model via `aggregate_gradients`. Rank 0 runs test validation
    at the configured epoch frequency; training ends when `is_stop` fires,
    after a final sync and server-side validation.
    """
    log('start training and validation.', client.args.debug)

    if client.args.evaluate and client.args.graph.rank==0:
        # Evaluation-only mode: validate on the server and return.
        do_validate(client.args, client.model, client.optimizer,  client.criterion, client.metrics,
                         client.test_loader, client.all_clients_group, data_mode='test')
        return

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    client.args.comm_time.append(0.0)
    # break until finish expected full epoch training.
    while True:
        # configure local step.
        for _input, _target in client.train_loader:
            client.model.train()

            # update local step.
            logging_load_time(tracker)

            # update local index and get local step
            client.args.local_index += 1
            client.args.local_data_seen += len(_target)
            get_current_epoch(client.args)
            local_step = get_current_local_step(client.args)

            # adjust learning rate (based on the # of accessed samples)
            lr = adjust_learning_rate(client.args, client.optimizer, client.scheduler)

            # load data
            _input, _target = load_data_batch(client.args, _input, _target, tracker)

            # inference and get current performance.
            client.optimizer.zero_grad()
            loss, performance = inference(client.model, client.criterion, client.metrics, _input, _target)

            # compute gradient and do local SGD step.
            loss.backward()
            client.optimizer.step(
                apply_lr=True,
                apply_in_momentum=client.args.in_momentum, apply_out_momentum=False
            )

            # logging locally.
            logging_computing(tracker, loss, performance, _input, lr)

            # evaluate the status: sync whenever local_index hits a multiple
            # of local_step.
            is_sync = client.args.local_index % local_step == 0
            # NOTE(review): `x % 1 == 0` is True for every integer —
            # presumably `epoch_` is a float fraction of an epoch so this
            # detects whole-epoch boundaries; confirm against get_current_epoch.
            if client.args.epoch_ % 1 == 0:
                client.args.finish_one_epoch = True

            # sync
            if is_sync:
                log('Enter synching', client.args.debug)
                client.args.global_index += 1

                # broadcast gradients to other nodes by using reduce_sum.
                client.model_server = aggregate_gradients(client.args, client.model_server,
                                                          client.model, client.optimizer, is_sync)
                # evaluate the sync time
                logging_sync_time(tracker)

                # logging.
                logging_globally(tracker, start_global_time)
                
                # reset start round time.
                start_global_time = time.time()
            
            # finish one epoch training and to decide if we want to val our model.
            if client.args.finish_one_epoch:
                if client.args.epoch % client.args.eval_freq ==0 and client.args.graph.rank == 0:
                        do_validate(client.args, client.model, client.optimizer,  client.criterion, client.metrics,
                            client.test_loader, client.all_clients_group, data_mode='test')
                dist.barrier(group=client.all_clients_group)
                # refresh the logging cache at the beginning of each epoch.
                client.args.finish_one_epoch = False
                tracker = define_local_training_tracker()
            
            # determine if the training is finished.
            if is_stop(client.args):
                # Last sync before returning.
                log('Enter synching', client.args.debug)
                client.args.global_index += 1

                # broadcast gradients to other nodes by using reduce_sum.
                client.model_server = aggregate_gradients(client.args, client.model_server,
                                                          client.model, client.optimizer, is_sync)

                print("Total number of samples seen on device {} is {}".format(client.args.graph.rank, client.args.local_data_seen))
                if client.args.graph.rank == 0:
                    do_validate(client.args, client.model_server, client.optimizer,  client.criterion, client.metrics,
                        client.test_loader, client.all_clients_group, data_mode='test')
                return

            # display the logging info.
            logging_display_training(client.args, tracker)

            # reset load time for the tracker.
            tracker['start_load_time'] = time.time()

        # reshuffle the data: drop the old loaders and rebuild them.
        if client.args.reshuffle_per_epoch:
            log('reshuffle the dataset.', client.args.debug)
            del client.train_loader, client.test_loader
            gc.collect()
            log('reshuffle the dataset.', client.args.debug)
            client.load_local_dataset()
Exemplo n.º 11
0
def train_and_validate_perfedme_centered(Clients, Server):
    """Centralized (single-process) training loop for PerFedMe-style
    personalized federated learning.

    Each communication round: the server model is copied to every sampled
    client; each client alternates between (a) SGD steps on its personal
    model with a proximal pull toward its copy of the global model, and
    (b) every 5 local steps (and at sync), an update of the global-model
    copy pulled toward the personal model. Afterwards client models are
    averaged into the server via `fedavg_aggregation_centered`, and local /
    global validation trackers are logged.
    """
    log('start training and validation with Federated setting in a centered way.'
        )

    # For Sparsified SGD
    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.')

    # Number of communication rounds in federated setting should be defined
    for n_c in range(Server.args.num_comms):
        Server.args.rounds_comm += 1
        Server.args.local_index += 1
        Server.args.quant_error = 0.0

        # Preset variables for this round of communication
        Server.zero_grad()
        Server.reset_tracker(Server.local_val_tracker)
        Server.reset_tracker(Server.global_val_tracker)
        if Server.args.fed_personal:
            Server.reset_tracker(Server.local_personal_val_tracker)
            Server.reset_tracker(Server.global_personal_val_tracker)

        # Configuring the devices for this round of communication
        log("Starting round {} of training".format(n_c + 1))
        online_clients = set_online_clients_centered(Server.args)

        for oc in online_clients:
            # Start each sampled client from the current server weights.
            Clients[oc].model.load_state_dict(Server.model.state_dict())
            Clients[oc].args.rounds_comm = Server.args.rounds_comm
            local_steps = 0
            is_sync = False

            # Evaluate the (pre-update) server model on this client's data.
            do_validate_centered(Clients[oc].args,
                                 Server.model,
                                 Server.criterion,
                                 Server.metrics,
                                 Server.optimizer,
                                 Clients[oc].train_loader,
                                 Server.global_val_tracker,
                                 val=False,
                                 local=False)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args,
                                     Server.model,
                                     Server.criterion,
                                     Server.metrics,
                                     Server.optimizer,
                                     Clients[oc].val_loader,
                                     Server.global_personal_val_tracker,
                                     val=True,
                                     local=False)

            while not is_sync:
                for _input, _target in Clients[oc].train_loader:
                    local_steps += 1
                    Clients[oc].model.train()
                    Clients[oc].model_personal.train()

                    # update local step.
                    logging_load_time(tracker)

                    # update local index and get local step
                    Clients[oc].args.local_index += 1
                    Clients[oc].args.local_data_seen += len(_target)
                    get_current_epoch(Clients[oc].args)
                    local_step = get_current_local_step(Clients[oc].args)

                    # adjust learning rate (based on the # of accessed samples)
                    lr = adjust_learning_rate(Clients[oc].args,
                                              Clients[oc].optimizer_personal,
                                              Clients[oc].scheduler)

                    # load data
                    _input, _target = load_data_batch(Clients[oc].args, _input,
                                                      _target, tracker)

                    # Skip batches with one sample because of BatchNorm issue in some models!
                    if _input.size(0) == 1:
                        is_sync = is_sync_fed(Clients[oc].args)
                        break

                    # inference and get current performance.
                    Clients[oc].optimizer_personal.zero_grad()

                    loss, performance = inference(Clients[oc].model_personal,
                                                  Clients[oc].criterion,
                                                  Clients[oc].metrics, _input,
                                                  _target)
                    # print("loss in rank {} is {}".format(oc,loss))

                    # compute gradient and do local SGD step.
                    loss.backward()

                    # Add the proximal-term gradient pulling the personal
                    # model toward this client's copy of the global model.
                    for client_param, personal_param in zip(
                            Clients[oc].model.parameters(),
                            Clients[oc].model_personal.parameters()):
                        # loss += Clients[oc].args.perfedme_lambda * torch.norm(personal_param.data - client_param.data)**2
                        personal_param.grad.data += Clients[
                            oc].args.perfedme_lambda * (personal_param.data -
                                                        client_param.data)

                    Clients[oc].optimizer_personal.step(
                        apply_lr=True,
                        apply_in_momentum=Clients[oc].args.in_momentum,
                        apply_out_momentum=False)

                    # NOTE(review): this one-time backward on the very first
                    # global step appears to only populate .grad buffers on
                    # the global model (so the in-place grad writes below have
                    # tensors to overwrite) — confirm intent.
                    if Clients[oc].args.local_index == 1:
                        Clients[oc].optimizer.zero_grad()
                        loss, performance = inference(Clients[oc].model,
                                                      Clients[oc].criterion,
                                                      Clients[oc].metrics,
                                                      _input, _target)
                        loss.backward()

                    is_sync = is_sync_fed(Clients[oc].args)
                    # Every 5 local steps (and at sync), move the local copy
                    # of the global model toward the personal model.
                    if Clients[oc].args.local_index % 5 == 0 or is_sync:
                        log('Updating the local version of the global model',
                            Clients[oc].args.debug)
                        lr = adjust_learning_rate(Clients[oc].args,
                                                  Clients[oc].optimizer,
                                                  Clients[oc].scheduler)
                        Clients[oc].optimizer.zero_grad()
                        for client_param, personal_param in zip(
                                Clients[oc].model.parameters(),
                                Clients[oc].model_personal.parameters()):
                            client_param.grad.data = Clients[
                                oc].args.perfedme_lambda * (
                                    client_param.data - personal_param.data)
                        Clients[oc].optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=Clients[oc].args.in_momentum,
                            apply_out_momentum=False)

                    # reset load time for the tracker.
                    tracker['start_load_time'] = time.time()
                    # model_local = deepcopy(model_client)

                    if is_sync:
                        break

            # Evaluate this client's personal model after local training.
            do_validate_centered(Clients[oc].args,
                                 Clients[oc].model_personal,
                                 Clients[oc].criterion,
                                 Clients[oc].metrics,
                                 Clients[oc].optimizer_personal,
                                 Clients[oc].train_loader,
                                 Server.local_val_tracker,
                                 val=False,
                                 local=True)
            if Server.args.fed_personal:
                do_validate_centered(Clients[oc].args,
                                     Clients[oc].model_personal,
                                     Clients[oc].criterion,
                                     Clients[oc].metrics,
                                     Clients[oc].optimizer_personal,
                                     Clients[oc].val_loader,
                                     Server.local_personal_val_tracker,
                                     val=True,
                                     local=True)
            # Sync the model server based on model_clients
            tracker['start_sync_time'] = time.time()
            Server.args.global_index += 1

            logging_sync_time(tracker)

        # Average the online clients' models into the server model.
        fedavg_aggregation_centered(Clients, Server, online_clients)
        # Log local performance
        log_validation_centered(Server.args,
                                Server.local_val_tracker,
                                val=False,
                                local=True)
        if Server.args.fed_personal:
            log_validation_centered(Server.args,
                                    Server.local_personal_val_tracker,
                                    val=True,
                                    local=True)

        # Log server performance
        log_validation_centered(Server.args,
                                Server.global_val_tracker,
                                val=False,
                                local=False)
        if Server.args.fed_personal:
            log_validation_centered(Server.args,
                                    Server.global_personal_val_tracker,
                                    val=True,
                                    local=False)

        # logging.
        logging_globally(tracker, start_global_time)

        # reset start round time.
        start_global_time = time.time()

        # validate the model at the server
        # if args.graph.rank == 0:
        #     do_test(args, model_server, optimizer, criterion, metrics, test_loader)
        # do_validate_test(args, model_server, optimizer, criterion, metrics, test_loader)
    return
Exemplo n.º 12
0
def log_test_centered(args, val_tracker):
    """Log server-side test performance (top-1/top-5 accuracy and loss)."""
    # Pull the running averages out of the tracker's meters by name.
    top1_avg = val_tracker['top1'].avg
    top5_avg = val_tracker['top5'].avg
    loss_avg = val_tracker['losses'].avg
    log('Test at batch: {}. Epoch: {}. Process: {}. Prec@1: {:.3f} Prec@5: {:.3f} Loss: {:.3f} Comm: {}'
        .format(args.local_index, args.epoch, args.graph.rank, top1_avg,
                top5_avg, loss_avg, args.rounds_comm),
        debug=args.debug)
Exemplo n.º 13
0
def train_and_validate_federated(client):
    """The training scheme of Federated Learning systems.
        The basic model is FedAvg https://arxiv.org/abs/1602.05629
        TODO: Merge different models under this method

        Variants are selected via ``client.args.federated_type``:
        'fedgate', 'scaffold', 'fedprox', 'qsparse'; any other value
        falls back to plain FedAvg aggregation. Rank 0 acts as the
        server; the other ranks are clients.
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    # Evaluation-only mode: run a single test pass on the server (rank 0)
    # and skip training entirely.
    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the training on the server and return
        do_validate(client.args,
                    client.model,
                    client.optimizer,
                    client.criterion,
                    client.metrics,
                    client.test_loader,
                    client.all_clients_group,
                    data_mode='test')
        return

    # init global variable.

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        # Per-round accumulator for the time spent on communication.
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        log("Starting round {} of training".format(n_c + 1), client.args.debug)
        online_clients = set_online_clients(client.args)
        # The server (rank 0) must be part of the first round and of every
        # round's process group, even when it is not sampled as a client.
        if (n_c == 0) and (0 not in online_clients):
            online_clients += [0]
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            # Broadcast the server model (and for SCAFFOLD also the server
            # control variate) from rank 0; the transfer time is recorded.
            if client.args.federated_type == 'scaffold':
                st = time.time()
                client.model_server, client.model_server_control = distribute_model_server_control(
                    client.model_server,
                    client.model_server_control,
                    online_clients_group,
                    src=0)
                client.args.comm_time[-1] += time.time() - st
            else:
                st = time.time()
                client.model_server = distribute_model_server(
                    client.model_server, online_clients_group, src=0)
                client.args.comm_time[-1] += time.time() - st
            # Start the local model from the freshly received server weights.
            client.model.load_state_dict(client.model_server.state_dict())
            local_steps = 0
            # NOTE(review): if this rank is only the server (not in
            # online_clients), `lr` is never assigned below, yet the
            # fedgate/scaffold aggregation calls pass it — verify those
            # code paths are never reached with rank 0 offline.
            if client.args.graph.rank in online_clients:
                is_sync = False
                # Loop over the local data (possibly multiple epochs) until
                # is_sync_fed signals the end of this round's local work.
                while not is_sync:
                    for _input, _target in client.train_loader:
                        local_steps += 1
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)

                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)

                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()

                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)

                        # compute gradient and do local SGD step.
                        loss.backward()

                        if client.args.federated_type == 'fedgate':
                            # Update gradients with control variates
                            for client_param, delta_param in zip(
                                    client.model.parameters(),
                                    client.model_delta.parameters()):
                                client_param.grad.data -= delta_param.data
                        elif client.args.federated_type == 'scaffold':
                            # SCAFFOLD correction: add (server control -
                            # client control) to every parameter gradient.
                            for cp, ccp, scp in zip(
                                    client.model.parameters(),
                                    client.model_client_control.parameters(),
                                    client.model_server_control.parameters()):
                                cp.grad.data += scp.data - ccp.data
                        elif client.args.federated_type == 'fedprox':
                            # Adding proximal gradients and loss for fedprox
                            # NOTE(review): backward() already ran above, so
                            # the `loss +=` here only changes the reported
                            # loss value; the proximal term's gradient is
                            # added manually to .grad.data instead.
                            for client_param, server_param in zip(
                                    client.model.parameters(),
                                    client.model_server.parameters()):
                                if client.args.graph.rank == 0:
                                    print(
                                        "distance norm for prox is:{}".format(
                                            torch.norm(client_param.data -
                                                       server_param.data)))
                                loss += client.args.fedprox_mu / 2 * torch.norm(
                                    client_param.data - server_param.data)
                                client_param.grad.data += client.args.fedprox_mu * (
                                    client_param.data - server_param.data)

                        if 'robust' in client.args.arch:
                            # Gradient ascent on the adversarial noise
                            # parameter of 'robust' architectures.
                            client.model.noise.grad.data *= -1

                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        if 'robust' in client.args.arch:
                            # Project the noise back onto the unit norm ball.
                            if torch.norm(client.model.noise.data) > 1:
                                client.model.noise.data /= torch.norm(
                                    client.model.noise.data)

                        # logging locally.
                        # logging_computing(tracker, loss_v, performance_v, _input, lr)

                        # display the logging info.
                        # logging_display_training(args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        # model_local = deepcopy(model_client)
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break

            else:
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            # Validate the local models before sync
            do_validate(client.args,
                        client.model,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train',
                        local=True)
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation',
                            local=True)
            # Sync the model server based on client models
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1

            # Aggregate client updates into the server model using the
            # scheme selected by args.federated_type (FedAvg by default).
            if client.args.federated_type == 'fedgate':
                client.model_server, client.model_delta = fedgate_aggregation(
                    client.args, client.model_server, client.model,
                    client.model_delta, client.model_memory,
                    online_clients_group, online_clients, client.optimizer, lr,
                    local_steps)
            elif client.args.federated_type == 'scaffold':
                client.model_server, client.model_client_control, client.model_server_control = scaffold_aggregation(
                    client.args, client.model_server, client.model,
                    client.model_server_control, client.model_client_control,
                    online_clients_group, online_clients, client.optimizer, lr,
                    local_steps)
            elif client.args.federated_type == 'qsparse':
                client.model_server = qsparse_aggregation(
                    client.args, client.model_server, client.model,
                    online_clients_group, online_clients, client.optimizer,
                    client.model_memory)
            else:
                client.model_server = fedavg_aggregation(
                    client.args, client.model_server, client.model,
                    online_clients_group, online_clients, client.optimizer)
            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args,
                        client.model_server,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation')

            # logging.
            logging_globally(tracker, start_global_time)

            # reset start round time.
            start_global_time = time.time()

            # validate the model at the server
            if client.args.graph.rank == 0:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test')
            log(
                'This round communication time is: {}'.format(
                    client.args.comm_time[-1]), client.args.debug)
        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)
        # All ranks (online or not) wait here so the next round's sampling
        # and group creation stay in lockstep.
        dist.barrier(group=client.all_clients_group)

    return
Exemplo n.º 14
0
def partition_dataset(args,
                      shuffle,
                      dataset_type,
                      Partitioner=None,
                      return_partitioner=False):
    """ Given a dataset, partition it.

    Loads the dataset named by ``args.data`` (or reuses ``Partitioner.data``),
    partitions the training split across ``args.graph.n_nodes`` workers
    (iid / non-iid / growing-batch layouts), optionally carves validation
    subset(s) out of the training data when ``args.fed_personal`` is set,
    and wraps the result in ``torch.utils.data.DataLoader`` objects.

    Args:
        args: run configuration; several fields (num_iterations, num_epochs,
            total_data_size, train/val_dataset_size, ...) are mutated in
            place as a side effect.
        shuffle: forwarded to the partitioner (shuffle before splitting).
        dataset_type: which split to load, e.g. 'train' or 'test'.
        Partitioner: optional pre-built partitioner to reuse; only valid
            together with ``args.partition_data``.
        return_partitioner: when True, also return the partitioner so it
            can be reused for later loads.

    Returns:
        A DataLoader, or a list of loaders ``[train, val]`` (plus a second
        validation loader for 'perfedavg'); when ``return_partitioner`` is
        True, a ``(data_loader, Partitioner)`` tuple.

    Raises:
        ValueError / NotImplementedError: for unsupported combinations of
            dataset, iid flag, and growing batch size.
    """
    if Partitioner is None:
        dataset = get_dataset(args,
                              args.data,
                              args.data_dir,
                              split=dataset_type)
    else:
        dataset = Partitioner.data
    # Federated Dataset with Validation
    # if args.data in ['emnist'] and args.fed_personal and dataset_type=='train':
    #     dataset, dataset_val = dataset
    batch_size = args.batch_size
    world_size = args.graph.n_nodes

    # partition data.
    if args.partition_data and dataset_type == 'train':
        # Choose the partition type: 'growing'/'normal' for iid data,
        # 'noniid' otherwise; reject unsupported dataset combinations.
        if args.iid_data:
            if args.data in [
                    'emnist', 'emnist_full', 'synthetic', 'shakespeare'
            ]:
                raise ValueError(
                    'The dataset {} does not have a structure for iid distribution of data'
                    .format(args.data))
            if args.growing_batch_size:
                pt = 'growing'
            else:
                pt = 'normal'
        else:
            if args.data not in [
                    'mnist', 'fashion_mnist', 'emnist', 'emnist_full',
                    'cifar10', 'cifar100', 'adult', 'synthetic', 'shakespeare'
            ]:
                raise NotImplementedError(
                    """Non-iid distribution of data for dataset {} is not implemented.
                        Set the distribution to iid.""".format(args.data))
            if args.growing_batch_size:
                raise ValueError(
                    'Growing Minibatch Size is not designed for non-iid data distribution'
                )
            else:
                pt = 'noniid'

        if Partitioner is None:
            if return_partitioner:
                data_to_load, Partitioner = partitioner(
                    args,
                    dataset,
                    shuffle,
                    world_size,
                    partition_type=pt,
                    return_partitioner=True)
            else:
                data_to_load = partitioner(args,
                                           dataset,
                                           shuffle,
                                           world_size,
                                           partition_type=pt)
            log('Make {} data partitions and use the subdata.'.format(pt),
                args.debug)
        else:
            # Reuse the provided partitioner: just take this rank's share.
            data_to_load = Partitioner.use(args.graph.rank)
            log('use the loaded partitioner to load the data.', args.debug)
    else:
        if Partitioner is not None:
            raise ValueError(
                'Partitioner is provided but data partition method is not defined!'
            )
        # If test dataset needs to be partitioned this should be changed
        data_to_load = dataset
        log('used whole data.', args.debug)

    # Log stats about the dataset to load
    if dataset_type == 'train':
        args.train_dataset_size = len(dataset)
        log(
            '  We have {} samples for {}, \
            load {} data for process (rank {}).'.format(
                len(dataset), dataset_type, len(data_to_load),
                args.graph.rank), args.debug)
    else:
        args.val_dataset_size = len(dataset)
        log(
            '  We have {} samples for {}, \
            load {} val data for process (rank {}).'.format(
                len(dataset), dataset_type, len(data_to_load),
                args.graph.rank), args.debug)

    # Batching
    if (args.growing_batch_size) and (dataset_type == 'train'):
        # Growing minibatch schedule: the sampler decides the effective
        # number of epochs/iterations, which are written back into args.
        batch_sampler = GrowingMinibatchSampler(
            data_source=data_to_load,
            num_epochs=args.num_epochs,
            num_iterations=args.num_iterations,
            base_batch_size=args.base_batch_size,
            max_batch_size=args.max_batch_size)
        args.num_epochs = batch_sampler.num_epochs
        args.num_iterations = batch_sampler.num_iterations
        args.total_data_size = len(data_to_load)
        args.num_samples_per_epoch = len(data_to_load) / args.num_epochs
        data_loader = torch.utils.data.DataLoader(data_to_load,
                                                  batch_sampler=batch_sampler,
                                                  num_workers=args.num_workers,
                                                  pin_memory=args.pin_memory)
        log(
            'we have {} batches for {} for rank {}.'.format(
                len(data_loader), dataset_type, args.graph.rank), args.debug)
    elif dataset_type == 'train':
        # Adjust stopping criteria
        # Derive the missing one of (num_iterations, num_epochs) from the
        # other, given the partition size and batch size.
        if args.stop_criteria == 'epoch':
            args.num_iterations = int(
                len(data_to_load) * args.num_epochs / batch_size)
        else:
            args.num_epochs = int(args.num_iterations * batch_size /
                                  len(data_to_load))
        args.total_data_size = len(data_to_load) * args.num_epochs
        args.num_samples_per_epoch = len(data_to_load)

        # Generate validation data part
        if args.fed_personal:
            if args.data in ['emnist', 'emnist_full', 'shakespeare']:
                # These datasets ship a dedicated 'val' split, so the
                # training data is kept whole.
                data_to_load_train = data_to_load
                data_to_load_val = get_dataset(args,
                                               args.data,
                                               args.data_dir,
                                               split='val')

            if args.federated_type == "perfedavg":
                # Per-FedAvg needs a second validation set (val1), split
                # off the training data.
                if args.data in ['emnist', 'emnist_full', 'shakespeare']:
                    #TODO: make this size a parameter
                    val_size = int(0.1 * len(data_to_load))
                    data_to_load_train, data_to_load_val1 = torch.utils.data.random_split(
                        data_to_load, [len(data_to_load) - val_size, val_size])
                else:
                    val_size = int(0.1 * len(data_to_load))
                    data_to_load_train, data_to_load_val, data_to_load_val1 = torch.utils.data.random_split(
                        data_to_load, [
                            len(data_to_load) - 3 * val_size, 2 * val_size,
                            val_size
                        ])
                data_loader_val1 = torch.utils.data.DataLoader(
                    data_to_load_val1,
                    batch_size=batch_size,
                    shuffle=True,
                    num_workers=5,
                    pin_memory=args.pin_memory,
                    drop_last=False)
            else:
                # Plain personalization: hold out 20% of the training data
                # as validation (unless a dedicated 'val' split exists).
                if args.data not in ['emnist', 'emnist_full', 'shakespeare']:
                    val_size = int(0.2 * len(data_to_load))
                    data_to_load_train, data_to_load_val = torch.utils.data.random_split(
                        data_to_load, [len(data_to_load) - val_size, val_size])
            # Generate data loaders
            # NOTE(review): num_workers is hard-coded to 5 here while the
            # growing-batch branch uses args.num_workers — confirm intended.
            data_loader_val = torch.utils.data.DataLoader(
                data_to_load_val,
                batch_size=batch_size,
                shuffle=True,
                num_workers=5,
                pin_memory=args.pin_memory,
                drop_last=False)

            data_loader_train = torch.utils.data.DataLoader(
                data_to_load_train,
                batch_size=batch_size,
                shuffle=True,
                num_workers=5,
                pin_memory=args.pin_memory,
                drop_last=False)
            if args.federated_type == 'perfedavg':
                data_loader = [
                    data_loader_train, data_loader_val, data_loader_val1
                ]
            else:
                data_loader = [data_loader_train, data_loader_val]

            log(
                'we have {} batches for {} for rank {}.'.format(
                    len(data_loader[0]), 'train', args.graph.rank), args.debug)
            log(
                'we have {} batches for {} for rank {}.'.format(
                    len(data_loader[1]), 'val', args.graph.rank), args.debug)
        else:
            data_loader = torch.utils.data.DataLoader(
                data_to_load,
                batch_size=batch_size,
                shuffle=True,
                num_workers=5,
                pin_memory=args.pin_memory,
                drop_last=False)
            log(
                'we have {} batches for {} for rank {}.'.format(
                    len(data_loader), 'train', args.graph.rank), args.debug)
    else:
        # Non-train splits: no shuffling, whole (unpartitioned) data.
        data_loader = torch.utils.data.DataLoader(data_to_load,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=5,
                                                  pin_memory=args.pin_memory,
                                                  drop_last=False)
        log(
            'we have {} batches for {} for rank {}.'.format(
                len(data_loader), dataset_type, args.graph.rank), args.debug)
    if return_partitioner:
        return data_loader, Partitioner
    else:
        return data_loader
Exemplo n.º 15
0
def train_and_validate_federated_apfl(client):
    """The training scheme of Personalized Federated Learning.
        Official implementation for https://arxiv.org/abs/2003.13461
    """
    log('start training and validation with Federated setting.',
        client.args.debug)

    if client.args.evaluate and client.args.graph.rank == 0:
        # Do the testing on the server and return
        do_validate(client.args,
                    client.model,
                    client.optimizer,
                    client.criterion,
                    client.metrics,
                    client.test_loader,
                    client.all_clients_group,
                    data_mode='test')
        return

    tracker = define_local_training_tracker()
    start_global_time = time.time()
    tracker['start_load_time'] = time.time()
    log('enter the training.', client.args.debug)

    # Number of communication rounds in federated setting should be defined
    for n_c in range(client.args.num_comms):
        client.args.rounds_comm += 1
        client.args.comm_time.append(0.0)
        # Configuring the devices for this round of communication
        # TODO: not make the server rank hard coded
        log("Starting round {} of training".format(n_c), client.args.debug)
        online_clients = set_online_clients(client.args)
        if (n_c == 0) and (0 not in online_clients):
            online_clients += [0]
        online_clients_server = online_clients if 0 in online_clients else online_clients + [
            0
        ]
        online_clients_group = dist.new_group(online_clients_server)

        if client.args.graph.rank in online_clients_server:
            client.model_server = distribute_model_server(client.model_server,
                                                          online_clients_group,
                                                          src=0)
            client.model.load_state_dict(client.model_server.state_dict())
            if client.args.graph.rank in online_clients:
                is_sync = False
                ep = -1  # counting number of epochs
                while not is_sync:
                    ep += 1
                    for i, (_input, _target) in enumerate(client.train_loader):
                        client.model.train()

                        # update local step.
                        logging_load_time(tracker)

                        # update local index and get local step
                        client.args.local_index += 1
                        client.args.local_data_seen += len(_target)
                        get_current_epoch(client.args)
                        local_step = get_current_local_step(client.args)

                        # adjust learning rate (based on the # of accessed samples)
                        lr = adjust_learning_rate(client.args,
                                                  client.optimizer,
                                                  client.scheduler)

                        # load data
                        _input, _target = load_data_batch(
                            client.args, _input, _target, tracker)
                        # Skip batches with one sample because of BatchNorm issue in some models!
                        if _input.size(0) == 1:
                            is_sync = is_sync_fed(client.args)
                            break

                        # inference and get current performance.
                        client.optimizer.zero_grad()
                        loss, performance = inference(client.model,
                                                      client.criterion,
                                                      client.metrics, _input,
                                                      _target)

                        # compute gradient and do local SGD step.
                        loss.backward()
                        client.optimizer.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        client.optimizer.zero_grad()
                        client.optimizer_personal.zero_grad()
                        loss_personal, performance_personal = inference_personal(
                            client.model_personal, client.model,
                            client.args.fed_personal_alpha, client.criterion,
                            client.metrics, _input, _target)

                        # compute gradient and do local SGD step.
                        loss_personal.backward()
                        client.optimizer_personal.step(
                            apply_lr=True,
                            apply_in_momentum=client.args.in_momentum,
                            apply_out_momentum=False)

                        # update alpha
                        if client.args.fed_adaptive_alpha and i == 0 and ep == 0:
                            client.args.fed_personal_alpha = alpha_update(
                                client.model, client.model_personal,
                                client.args.fed_personal_alpha,
                                lr)  #0.1/np.sqrt(1+args.local_index))
                            average_alpha = client.args.fed_personal_alpha
                            average_alpha = global_average(
                                average_alpha,
                                client.args.graph.n_nodes,
                                group=online_clients_group)
                            log("New alpha is:{}".format(average_alpha.item()),
                                client.args.debug)

                        # logging locally.
                        # logging_computing(tracker, loss, performance, _input, lr)

                        # display the logging info.
                        # logging_display_training(args, tracker)

                        # reset load time for the tracker.
                        tracker['start_load_time'] = time.time()
                        is_sync = is_sync_fed(client.args)
                        if is_sync:
                            break
            else:
                log("Offline in this round. Waiting on others to finish!",
                    client.args.debug)

            do_validate(client.args,
                        client.model,
                        client.optimizer_personal,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train',
                        personal=True,
                        model_personal=client.model_personal,
                        alpha=client.args.fed_personal_alpha)
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model,
                            client.optimizer_personal,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation',
                            personal=True,
                            model_personal=client.model_personal,
                            alpha=client.args.fed_personal_alpha)

            # Sync the model server based on model_clients
            log('Enter synching', client.args.debug)
            tracker['start_sync_time'] = time.time()
            client.args.global_index += 1
            client.model_server = fedavg_aggregation(
                client.args, client.model_server, client.model,
                online_clients_group, online_clients, client.optimizer)
            # evaluate the sync time
            logging_sync_time(tracker)

            # Do the validation on the server model
            do_validate(client.args,
                        client.model_server,
                        client.optimizer,
                        client.criterion,
                        client.metrics,
                        client.train_loader,
                        online_clients_group,
                        data_mode='train')
            if client.args.fed_personal:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.val_loader,
                            online_clients_group,
                            data_mode='validation')

            # logging.
            logging_globally(tracker, start_global_time)

            # reset start round time.
            start_global_time = time.time()

            # validate the models at the test data
            if client.args.fed_personal_test:
                do_validate(client.args,
                            client.model_client,
                            client.optimizer_personal,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test',
                            personal=True,
                            model_personal=client.model_personal,
                            alpha=client.args.fed_personal_alpha)
            elif client.args.graph.rank == 0:
                do_validate(client.args,
                            client.model_server,
                            client.optimizer,
                            client.criterion,
                            client.metrics,
                            client.test_loader,
                            online_clients_group,
                            data_mode='test')
        else:
            log("Offline in this round. Waiting on others to finish!",
                client.args.debug)
        dist.barrier(group=client.all_clients_group)