def aggregate(conf, master_model, fedavg_model, client_models,
              flatten_local_models):
    # perform the server momentum (either heavy-ball or Nesterov momentum).
    fl_aggregate = conf.fl_aggregate

    assert "server_momentum_factor" in fl_aggregate

    # start the server momentum acceleration.
    current_model_tb = TensorBuffer(list(fedavg_model.parameters()))
    previous_model_tb = TensorBuffer(list(master_model.parameters()))

    # get the update direction.
    update = previous_model_tb.buffer - current_model_tb.buffer

    # apply the server momentum to the update.
    if not hasattr(conf, "server_momentum_buffer"):
        conf.server_momentum_buffer = torch.zeros_like(update)
    conf.server_momentum_buffer.mul_(
        fl_aggregate["server_momentum_factor"]).add_(update)
    previous_model_tb.buffer.add_(-conf.server_momentum_buffer)

    # update the master_model (but keep the BN stats from the fedavg_model).
    master_model = fedavg_model
    _model_param = list(master_model.parameters())
    previous_model_tb.unpack(_model_param)

    # free the memory.
    torch.cuda.empty_cache()

    # a temporary hack (for debugging only).
    client_models = dict((used_client_arch, master_model.cpu())
                         for used_client_arch in conf.used_client_archs)
    return client_models
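The block above is heavy-ball server momentum (often called FedAvgM): the pseudo-gradient is the gap between the previous master model and the new FedAvg model, and a persistent buffer accumulates it across rounds. A minimal sketch of the same arithmetic on plain flat tensors (flat_master, flat_fedavg, and momentum are illustrative stand-ins for TensorBuffer.buffer, not part of the repo):

import torch

def server_momentum_update(flat_master, flat_fedavg, momentum, factor=0.9):
    # pseudo-gradient: how far the FedAvg round moved away from the master model.
    update = flat_master - flat_fedavg
    # heavy-ball accumulation: m <- factor * m + update.
    momentum.mul_(factor).add_(update)
    # new master model: x <- x - m.
    return flat_master - momentum, momentum

# usage on random flat tensors:
flat_master, flat_fedavg = torch.randn(10), torch.randn(10)
momentum = torch.zeros_like(flat_master)
new_master, momentum = server_momentum_update(flat_master, flat_fedavg, momentum)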
Example #2
    def step(self, closure=None, **kargs):
        # Apply the gradients with the weight decay and momentum.
        with kargs["timer"]("grad.apply_grad", epoch=self.conf.epoch_):
            utils.apply_gradient(
                self.param_groups, self.state, apply_grad_to_model=True
            )

        with kargs["timer"]("grad.get_params", epoch=self.conf.epoch_):
            params, _ = comm.get_data(
                self.param_groups, self.param_names, is_get_grad=False
            )
            params_tb = TensorBuffer(params)

        with kargs["timer"]("grad.error_compensate", epoch=self.conf.epoch_):
            self.memory.buffer += params_tb.buffer

        with kargs["timer"]("grad.compress", epoch=self.conf.epoch_):
            sync_buffer = {"original_shapes": self.shapes, "params_tb": self.memory}
            local_compressed_params_tb = self.compressor.compress(sync_buffer)

        with kargs["timer"]("grad.update_memory", epoch=self.conf.epoch_):
            self.memory.buffer = self.memory.buffer - local_compressed_params_tb.buffer

        with kargs["timer"]("grad.sync", epoch=self.conf.epoch_):
            self.compressor.sync(sync_buffer)

        # update local model.
        with kargs["timer"]("grad.decompress", epoch=self.conf.epoch_):
            aggregated_info_tb = self.compressor.uncompress(
                sync_buffer, self.neighbors_info
            )
            params_tb.buffer += aggregated_info_tb.buffer
            params_tb.unpack(params)
        return sync_buffer["n_bits"]
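The step above follows the error-feedback pattern for compressed communication: the residual kept in self.memory is added to the freshly updated parameters, the compensated vector is compressed and synced, and whatever the compressor dropped becomes the new residual. A short sketch of that loop with a hypothetical top-k compressor (not the repo's compressor API):

import torch

def topk_compress(x, ratio=0.01):
    # keep only the largest-magnitude entries of the flat vector; the rest are zeroed.
    k = max(1, int(x.numel() * ratio))
    _, indices = torch.topk(x.abs(), k)
    compressed = torch.zeros_like(x)
    compressed[indices] = x[indices]
    return compressed

def error_feedback_step(flat_params, memory, ratio=0.01):
    # compensate with the residual accumulated in previous rounds.
    compensated = flat_params + memory
    compressed = topk_compress(compensated, ratio)
    # the part the compressor dropped is carried over as the new residual.
    new_memory = compensated - compressed
    return compressed, new_memory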
Example #3
    def step(self, closure=None, **kargs):
        if self.conf.is_centralized:
            with kargs["timer"]("sync/get_data", epoch=self.conf.epoch_):
                # Get data.
                grads, _ = comm.get_data(self.param_groups,
                                         self.param_names,
                                         is_get_grad=True)
                flatten_grads = TensorBuffer(grads)

            with kargs["timer"]("sync/sync", epoch=self.conf.epoch_):
                # Aggregate the gradients.
                flatten_grads.buffer = self.world_aggregator._agg(
                    flatten_grads.buffer,
                    op="avg",
                    distributed=self.conf.distributed)

            with kargs["timer"]("sync/unflatten_grad", epoch=self.conf.epoch_):
                # unflatten grads.
                flatten_grads.unpack(grads)

            with kargs["timer"]("sync/apply_grad", epoch=self.conf.epoch_):
                utils.apply_gradient(self.param_groups,
                                     self.state,
                                     apply_grad_to_model=True)

            # Get n_bits to transmit.
            n_bits = get_n_bits(flatten_grads.buffer)
        else:
            with kargs["timer"]("sync/apply_grad", epoch=self.conf.epoch_):
                utils.apply_gradient(self.param_groups,
                                     self.state,
                                     apply_grad_to_model=True)

            with kargs["timer"]("sync/get_data", epoch=self.conf.epoch_):
                # first get and flatten all params.
                params, _ = comm.get_data(self.param_groups,
                                          self.param_names,
                                          is_get_grad=False)
                flatten_params = TensorBuffer(params)

            with kargs["timer"]("sync/sync", epoch=self.conf.epoch_):
                # prepare the sync.
                if self.conf.comm_device == "cpu":
                    flatten_params.buffer = flatten_params.buffer.cpu().detach()

                # then sync.
                flatten_params.buffer = self.decentralized_aggregator._agg(
                    flatten_params.buffer, op="weighted")

            with kargs["timer"]("sync/update_model", epoch=self.conf.epoch_):
                # finally unflatten.
                flatten_params.unpack(params)

            # Get n_bits to transmit.
            n_bits = get_n_bits(flatten_params.buffer)
        return n_bits
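The centralized branch follows the usual flatten -> average -> unflatten pattern. A minimal sketch of the same idea using torch.distributed directly (assuming an initialized default process group; the repo's world_aggregator is replaced by a plain all_reduce here):

import torch
import torch.distributed as dist

def allreduce_average_grads(grads):
    # flatten every gradient into one buffer, average it across workers,
    # then copy the averaged slices back into the original tensors.
    flat = torch.cat([g.view(-1) for g in grads])
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)
    flat.div_(dist.get_world_size())
    offset = 0
    for g in grads:
        n = g.numel()
        g.copy_(flat[offset:offset + n].view_as(g))
        offset += n
    # rough count of bits put on the wire, analogous to get_n_bits.
    return flat.numel() * flat.element_size() * 8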
Example #4
    def step(self, closure=None, **kargs):
        # Apply the gradients with the weight decay and momentum.
        with kargs["timer"]("grad.apply_grad", epoch=self.conf.epoch_):
            utils.apply_gradient(self.param_groups,
                                 self.state,
                                 apply_grad_to_model=False)

        with kargs["timer"]("grad.get_grads", epoch=self.conf.epoch_):
            params, _ = comm.get_data(self.param_groups,
                                      self.param_names,
                                      is_get_grad=False)
            flatten_params = TensorBuffer(params)

            grads, _ = comm.get_data(self.param_groups,
                                     self.param_names,
                                     is_get_grad=True)
            flatten_grads = TensorBuffer(grads)

        # Get weighted hat params and apply the local gradient.
        with kargs["timer"]("grad.apply_local_gradient",
                            epoch=self.conf.epoch_):
            flatten_half_params = deepcopy(flatten_params)
            flatten_half_params.buffer = (sum([
                _hat_params.buffer * self.neighbors_info[_rank]
                for _rank, _hat_params in self.neighbor_hat_params.items()
            ]) - self.param_groups[0]["lr"] * flatten_grads.buffer)

        # compress the model difference and sync.
        with kargs["timer"]("grad.compress", epoch=self.conf.epoch_):
            sync_buffer = {
                "original_shapes": self.shapes,
                "flatten_half_params": flatten_half_params,
                "flatten_params": flatten_params,
            }
            self.compressor.compress(sync_buffer)

        with kargs["timer"]("grad.sync", epoch=self.conf.epoch_):
            self.compressor.sync(sync_buffer)

        # finally unflatten and update local model.
        with kargs["timer"]("grad.unflatten_to_update",
                            epoch=self.conf.epoch_):
            self.compressor.uncompress(sync_buffer, self.neighbor_hat_params)
            flatten_params.buffer = self.neighbor_hat_params[
                self.rank].buffer.clone()
            flatten_params.unpack(params)
        return sync_buffer["n_bits"]
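The "weighted hat params" in this step form a gossip-averaging mixing step: each neighbor's (compressed) copy of the parameters is weighted by its entry in the mixing-matrix row, and the local gradient is applied to that average. A small sketch of the mixing step alone (neighbor_buffers and mixing_weights are hypothetical dicts keyed by rank, with weights summing to one):

import torch

def weighted_neighbor_average(neighbor_buffers, mixing_weights):
    # convex combination of the neighbors' flat parameter buffers.
    return sum(mixing_weights[rank] * buf for rank, buf in neighbor_buffers.items())

# example with self plus two neighbors:
buffers = {0: torch.ones(4), 1: 2 * torch.ones(4), 2: 3 * torch.ones(4)}
weights = {0: 0.5, 1: 0.25, 2: 0.25}
mixed = weighted_neighbor_average(buffers, weights)  # tensor([1.75, 1.75, 1.75, 1.75])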
Example #5
    def step(self, closure=None, **kargs):
        with kargs["timer"]("sync.apply_grad", epoch=self.conf.epoch_):
            utils.apply_gradient(self.param_groups,
                                 self.state,
                                 apply_grad_to_model=False)

        with kargs["timer"]("sync.get_data", epoch=self.conf.epoch_):
            # Get data.
            grads, _ = comm.get_data(self.param_groups,
                                     self.param_names,
                                     is_get_grad=True)
            grads_tb = TensorBuffer(grads)

        with kargs["timer"]("sync.use_memory", epoch=self.conf.epoch_):
            # use memory.
            grads_tb.buffer.add_(self.memory_tb.buffer)

        with kargs["timer"]("sync.compress", epoch=self.conf.epoch_):
            # compress.
            sync_buffer = self.compressor.compress(grads_tb)

        with kargs["timer"]("sync.sync", epoch=self.conf.epoch_):
            self.compressor.sync(sync_buffer)

        with kargs["timer"]("sync.update_memory", epoch=self.conf.epoch_):
            # update memory.
            self.memory_tb.buffer = (grads_tb.buffer -
                                     sync_buffer["synced_grads_tb"].buffer)

        with kargs["timer"]("sync.decompress", epoch=self.conf.epoch_):
            sync_grads_tb = self.compressor.decompress(sync_buffer)

        with kargs["timer"]("sync.apply_grad", epoch=self.conf.epoch_):
            # apply the synced (decompressed) gradient to the parameters.
            params, _ = comm.get_data(self.param_groups,
                                      self.param_names,
                                      is_get_grad=False)
            params_tb = TensorBuffer(params)

            # apply the gradient.
            params_tb.buffer.add_(-self.param_groups[0]["lr"] *
                                  sync_grads_tb.buffer)

            # unpack.
            params_tb.unpack(params)
        return sync_buffer["n_bits"]
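Unlike Example #2, this variant compresses the memory-compensated gradient rather than the parameters, and the model is then updated with the decompressed, synced gradient in a plain SGD step. A sketch of that final step, replaying TensorBuffer's unpack on a flat gradient (offsets and names are illustrative):

import torch

def apply_synced_gradient(params, synced_flat_grad, lr):
    # walk the flat buffer and apply the matching slice to every parameter.
    offset = 0
    for p in params:
        n = p.numel()
        p.data.add_(synced_flat_grad[offset:offset + n].view_as(p), alpha=-lr)
        offset += n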
Example #6
def aggregate(conf, master_model, fedavg_model, client_models,
              flatten_local_models):
    # perform the server Adam.
    # Following the setup in the paper: momentum of 0.9, a numerical-stability
    # constant epsilon of 0.01, and beta_2 of 0.99.
    # The server_lr suggested in the original paper is 0.1.
    fl_aggregate = conf.fl_aggregate

    assert "server_lr" in fl_aggregate
    beta_2 = fl_aggregate["beta_2"] if "beta_2" in fl_aggregate else 0.99

    # start the server momentum acceleration.
    current_model_tb = TensorBuffer(list(fedavg_model.parameters()))
    previous_model_tb = TensorBuffer(list(master_model.parameters()))

    # get the update direction.
    update = previous_model_tb.buffer - current_model_tb.buffer

    # update the second-moment estimate of the update (server Adam).
    if not hasattr(conf, "second_server_momentum_buffer"):
        conf.second_server_momentum_buffer = torch.zeros_like(update)
    conf.second_server_momentum_buffer.mul_(beta_2).add_(
        (1 - beta_2) * (update**2))
    previous_model_tb.buffer.add_(
        -fl_aggregate["server_lr"] * update /
        (torch.sqrt(conf.second_server_momentum_buffer) + 0.01))

    # update the master_model (but keep the BN stats from the fedavg_model).
    master_model = fedavg_model
    _model_param = list(master_model.parameters())
    previous_model_tb.unpack(_model_param)

    # free the memory.
    torch.cuda.empty_cache()

    # a temporary hack (for debugging only).
    client_models = dict((used_client_arch, master_model.cpu())
                         for used_client_arch in conf.used_client_archs)
    return client_models
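This aggregate is the adaptive (server-Adam-style) counterpart of the momentum version above: the pseudo-gradient is the same, but a second-moment buffer rescales the server step coordinate-wise; note that no first-moment buffer is kept, so it is a simplified variant rather than full Adam. The same arithmetic on flat tensors (names are illustrative stand-ins for TensorBuffer.buffer):

import torch

def server_adam_update(flat_master, flat_fedavg, v, server_lr=0.1, beta_2=0.99, eps=0.01):
    # pseudo-gradient of the FedAvg round.
    update = flat_master - flat_fedavg
    # second-moment estimate: v <- beta_2 * v + (1 - beta_2) * update^2.
    v.mul_(beta_2).add_((1 - beta_2) * update.pow(2))
    # coordinate-wise adaptive step with numerical-stability constant eps.
    new_master = flat_master - server_lr * update / (v.sqrt() + eps)
    return new_master, v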
Example #7
    def step(self, closure=None, **kargs):
        # do the local update steps.
        with kargs["timer"]("sync/get_data", epoch=self.conf.epoch_):
            # get params.
            params, _ = comm.get_data(self.param_groups,
                                      self.param_names,
                                      is_get_grad=False)
            params_tb = TensorBuffer(params)

        with kargs["timer"]("sync/apply_grad", epoch=self.conf.epoch_):
            # prepare the gradient (sign)
            utils.apply_gradient(self.param_groups,
                                 self.state,
                                 apply_grad_to_model=False)
            # get grads.
            grads, _ = comm.get_data(self.param_groups,
                                     self.param_names,
                                     is_get_grad=True)
            grads_tb = TensorBuffer(grads)

        # enter the global sync: compress the local (sign) gradients before syncing.
        with kargs["timer"]("sync/compress", epoch=self.conf.epoch_):
            sync_buffer = self.compressor.compress(grads_tb)

        # sync and decompress.
        with kargs["timer"]("sync/sync_and_decompress",
                            epoch=self.conf.epoch_):
            self.compressor.sync(sync_buffer)
            synced_updates_tb = self.compressor.decompress(sync_buffer)

        # unpack the synced info and update the consensus params.
        with kargs["timer"]("sync/apply_grad", epoch=self.conf.epoch_):
            params_tb.buffer -= self.param_groups[0][
                "lr"] * synced_updates_tb.buffer
            params_tb.unpack(params)
        return sync_buffer["n_bits"]
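The "(sign)" hint in the comments suggests a sign-based gradient compressor in this optimizer. A hedged sketch of a scaled-sign compressor and its bit count, purely illustrative and not the repo's compressor class:

import torch

def scaled_sign_compress(flat_grad):
    # 1-bit sign per coordinate, rescaled by the mean absolute value so the
    # compressed vector keeps roughly the right magnitude.
    scale = flat_grad.abs().mean()
    return scale * torch.sign(flat_grad)

def estimate_n_bits(flat_grad):
    # one sign bit per coordinate plus one 32-bit scale factor.
    return flat_grad.numel() + 32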
class Worker(object):
    def __init__(self, conf):
        self.conf = conf

        # some initializations.
        self.rank = conf.graph.rank
        conf.graph.worker_id = conf.graph.rank
        self.device = torch.device(
            "cuda" if self.conf.graph.on_cuda else "cpu")

        # define the timer for different operations.
        # if we choose the `train_fast` mode, then we will not track the time.
        self.timer = Timer(
            verbosity_level=1
            if conf.track_time and not conf.train_fast else 0,
            log_fn=conf.logger.log_metric,
        )

        # create dataset (as well as the potential data_partitioner) for training.
        dist.barrier()
        self.dataset = create_dataset.define_dataset(conf, data=conf.data)
        _, self.data_partitioner = create_dataset.define_data_loader(
            self.conf,
            dataset=self.dataset["train"],
            localdata_id=0,  # arbitrary id here.
            is_train=True,
            data_partitioner=None,
        )
        conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} initialized the local training data with Master."
        )

        # define the criterion.
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

        # define the model compression operators.
        if conf.local_model_compression is not None:
            if conf.local_model_compression == "quantization":
                self.model_compression_fn = compressor.ModelQuantization(conf)

        conf.logger.log(
            f"Worker-{conf.graph.worker_id} initialized dataset/criterion.\n")

    def run(self):
        while True:
            self._listen_to_master()

            # check if we need to terminate the training or not.
            if self._terminate_by_early_stopping():
                return

            self._recv_model_from_master()
            self._train()
            self._send_model_to_master()

            # check if we need to terminate the training or not.
            if self._terminate_by_complete_training():
                return

    def _listen_to_master(self):
        # listen to master, related to the function `_activate_selected_clients` in `master.py`.
        msg = torch.zeros((3, self.conf.n_participated))
        dist.broadcast(tensor=msg, src=0)
        self.conf.graph.client_id, self.conf.graph.comm_round, self.n_local_epochs = (
            msg[:, self.conf.graph.rank - 1].to(int).cpu().numpy().tolist())

        # once we receive the signal, we init for the local training.
        self.arch, self.model = create_model.define_model(
            self.conf,
            to_consistent_model=False,
            client_id=self.conf.graph.client_id)
        self.model_state_dict = self.model.state_dict()
        self.model_tb = TensorBuffer(list(self.model_state_dict.values()))
        self.metrics = create_metrics.Metrics(self.model,
                                              task="classification")
        dist.barrier()

    def _recv_model_from_master(self):
        # related to the function `_send_model_to_selected_clients` in `master.py`
        old_buffer = copy.deepcopy(self.model_tb.buffer)
        dist.recv(self.model_tb.buffer, src=0)
        new_buffer = copy.deepcopy(self.model_tb.buffer)
        self.model_tb.unpack(self.model_state_dict.values())
        self.model.load_state_dict(self.model_state_dict)
        random_reinit.random_reinit_model(self.conf, self.model)
        self.init_model = self._turn_off_grad(
            copy.deepcopy(self.model).to(self.device))
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) received the model ({self.arch}) from Master. The model status {'is updated' if old_buffer.norm() != new_buffer.norm() else 'is not updated'}."
        )
        dist.barrier()

    def _train(self):
        self.model.train()

        # init the model and dataloader.
        if self.conf.graph.on_cuda:
            self.model = self.model.cuda()
        self.train_loader, _ = create_dataset.define_data_loader(
            self.conf,
            dataset=self.dataset["train"],
            # localdata_id start from 0 to the # of clients - 1.
            # client_id starts from 1 to the # of clients.
            localdata_id=self.conf.graph.client_id - 1,
            is_train=True,
            data_partitioner=self.data_partitioner,
        )

        # define optimizer, scheduler and runtime tracker.
        self.optimizer = create_optimizer.define_optimizer(
            self.conf, model=self.model, optimizer_name=self.conf.optimizer)
        self.scheduler = create_scheduler.Scheduler(self.conf,
                                                    optimizer=self.optimizer)
        self.tracker = RuntimeTracker(
            metrics_to_track=self.metrics.metric_names)
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) enters the local training phase (current communication rounds={self.conf.graph.comm_round})."
        )

        # efficient local training.
        if hasattr(self, "model_compression_fn"):
            self.model_compression_fn.compress_model(
                param_groups=self.optimizer.param_groups)

        # entering local updates and will finish only after reaching the expected local_n_epochs.
        while True:
            for _input, _target in self.train_loader:
                # load data
                with self.timer("load_data", epoch=self.scheduler.epoch_):
                    data_batch = create_dataset.load_data_batch(
                        self.conf, _input, _target, is_training=True)

                # inference and get current performance.
                with self.timer("forward_pass", epoch=self.scheduler.epoch_):
                    self.optimizer.zero_grad()
                    loss, output = self._inference(data_batch)

                    # in case we need self distillation to penalize the local training
                    # (avoid catastrophic forgetting).
                    loss = self._local_training_with_self_distillation(
                        loss, output, data_batch)

                with self.timer("backward_pass", epoch=self.scheduler.epoch_):
                    loss.backward()
                    self._add_grad_from_prox_regularized_loss()
                    self.optimizer.step()
                    self.scheduler.step()

                # efficient local training.
                with self.timer("compress_model", epoch=self.scheduler.epoch_):
                    if hasattr(self, "model_compression_fn"):
                        self.model_compression_fn.compress_model(
                            param_groups=self.optimizer.param_groups)

                # display the logging info.
                display_training_stat(self.conf, self.scheduler, self.tracker)

                # display tracking time.
                if (self.conf.display_tracked_time
                        and self.scheduler.local_index % self.conf.summary_freq
                        == 0):
                    self.conf.logger.log(self.timer.summary())

                # check divergence.
                if self.tracker.stat["loss"].avg > 1e3 or np.isnan(
                        self.tracker.stat["loss"].avg):
                    self.conf.logger.log(
                        f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) diverges!!!!!Early stop it."
                    )
                    self._terminate_comm_round()
                    return

                # check stopping condition.
                if self._is_finished_one_comm_round():
                    self._terminate_comm_round()
                    return

            # refresh the logging cache at the end of each epoch.
            self.tracker.reset()
            if self.conf.logger.meet_cache_limit():
                self.conf.logger.save_json()

    def _inference(self, data_batch):
        """Inference on the given model and get loss and accuracy."""
        # do the forward pass and get the output.
        output = self.model(data_batch["input"])

        # evaluate the output and get the loss, performance.
        if self.conf.use_mixup:
            loss = mixup.mixup_criterion(
                self.criterion,
                output,
                data_batch["target_a"],
                data_batch["target_b"],
                data_batch["mixup_lambda"],
            )

            performance_a = self.metrics.evaluate(loss, output,
                                                  data_batch["target_a"])
            performance_b = self.metrics.evaluate(loss, output,
                                                  data_batch["target_b"])
            performance = [
                data_batch["mixup_lambda"] * _a +
                (1 - data_batch["mixup_lambda"]) * _b
                for _a, _b in zip(performance_a, performance_b)
            ]
        else:
            loss = self.criterion(output, data_batch["target"])
            performance = self.metrics.evaluate(loss, output,
                                                data_batch["target"])

        # update tracker.
        if self.tracker is not None:
            self.tracker.update_metrics([loss.item()] + performance,
                                        n_samples=data_batch["input"].size(0))
        return loss, output

    def _add_grad_from_prox_regularized_loss(self):
        assert self.conf.local_prox_term >= 0
        if self.conf.local_prox_term != 0:
            assert self.conf.weight_decay == 0
            assert self.conf.optimizer == "sgd"
            assert self.conf.momentum_factor == 0

            for _param, _init_param in zip(self.model.parameters(),
                                           self.init_model.parameters()):
                if _param.grad is not None:
                    _param.grad.data.add_((_param.data - _init_param.data) *
                                          self.conf.local_prox_term)

    def _local_training_with_self_distillation(self, loss, output, data_batch):
        if self.conf.self_distillation > 0:
            loss = loss * (
                1 - self.conf.self_distillation
            ) + self.conf.self_distillation * self._divergence(
                student_logits=output /
                self.conf.self_distillation_temperature,
                teacher_logits=self.init_model(data_batch["input"]) /
                self.conf.self_distillation_temperature,
            )
        return loss

    def _divergence(self, student_logits, teacher_logits):
        divergence = F.kl_div(
            F.log_softmax(student_logits, dim=1),
            F.softmax(teacher_logits, dim=1),
            reduction="batchmean",
        )  # forward KL
        return divergence

    def _turn_off_grad(self, model):
        for param in model.parameters():
            param.requires_grad = False
        return model

    def _send_model_to_master(self):
        dist.barrier()
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) sending the model ({self.arch}) back to Master."
        )
        flatten_model = TensorBuffer(list(self.model.state_dict().values()))
        dist.send(tensor=flatten_model.buffer, dst=0)
        dist.barrier()

    def _terminate_comm_round(self):
        self.model = self.model.cpu()
        del self.init_model
        self.scheduler.clean()
        self.conf.logger.save_json()
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) finished one round of federated learning: (comm_round={self.conf.graph.comm_round})."
        )

    def _terminate_by_early_stopping(self):
        if self.conf.graph.comm_round == -1:
            dist.barrier()
            self.conf.logger.log(
                f"Worker-{self.conf.graph.worker_id} finished the federated learning by early-stopping."
            )
            return True
        else:
            return False

    def _terminate_by_complete_training(self):
        if self.conf.graph.comm_round == self.conf.n_comm_rounds:
            dist.barrier()
            self.conf.logger.log(
                f"Worker-{self.conf.graph.worker_id} finished the federated learning: (total comm_rounds={self.conf.graph.comm_round})."
            )
            return True
        else:
            return False

    def _is_finished_one_comm_round(self):
        return self.conf.epoch_ >= self.conf.local_n_epochs
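The worker exchanges the model with the master as a single flat tensor (_recv_model_from_master / _send_model_to_master), which is what TensorBuffer provides over the state_dict values. A minimal sketch of that flatten/unflatten round-trip with hypothetical helpers in place of TensorBuffer:

import copy
import torch

def flatten_state_dict(model):
    # flatten all state_dict values into one float buffer for a single send/recv.
    state = copy.deepcopy(model.state_dict())
    flat = torch.cat([v.view(-1).float() for v in state.values()])
    return state, flat

def unflatten_into(state, flat):
    # copy the matching slices back, restoring each value's shape and dtype.
    offset = 0
    for key, value in state.items():
        n = value.numel()
        state[key] = flat[offset:offset + n].view_as(value).to(value.dtype)
        offset += n
    return state

# round-trip on a small model:
model = torch.nn.Linear(4, 2)
state, flat = flatten_state_dict(model)
model.load_state_dict(unflatten_into(state, flat))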