def distill_knowledge(
    conf,
    student_model,
    dataset,
    num_epochs,
    batch_size,
    teacher_model=None,
    softmax_temperature=1,
):
    # init.
    data_loader = create_data_loader(dataset, batch_size=batch_size)
    criterion = torch.nn.KLDivLoss(reduction="batchmean")
    tracker = RuntimeTracker(metrics_to_track=[])

    # check model status.
    untrainable_teacher_model = (
        agg_utils.modify_model_trainable_status(conf, teacher_model, trainable=False)
        if teacher_model is not None
        else None
    )
    trainable_student_model = agg_utils.check_trainable(conf, student_model)
    optimizer = create_optimizer(conf, trainable_student_model)

    # start the formal training.
    for epoch_idx in range(num_epochs):
        for _input, _target in data_loader:
            # init the _input, _target.
            if conf.graph.on_cuda:
                _input = _input.cuda()
            if untrainable_teacher_model is None:
                _target_prob = _target.cuda() if conf.graph.on_cuda else _target

            # perform fp/bp on the student model.
            optimizer.zero_grad()
            _output = trainable_student_model(_input)

            # evaluate the loss.
            if untrainable_teacher_model is not None:
                loss = (softmax_temperature**2) * criterion(
                    torch.nn.functional.log_softmax(
                        _output / softmax_temperature, dim=1),
                    torch.nn.functional.softmax(
                        untrainable_teacher_model(_input).detach() /
                        softmax_temperature,
                        dim=1,
                    ),
                )
            else:
                loss = criterion(
                    torch.nn.functional.log_softmax(_output, dim=1),
                    _target_prob)
            loss.backward()
            optimizer.step()
            tracker.update_metrics([loss.item()], n_samples=_input.size(0))
    conf.logger.log(f"# of epochs={epoch_idx + 1}: {tracker()}")
    return trainable_student_model.cpu()
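
For reference, the teacher branch of the loss above is standard temperature-scaled distillation: KL divergence between the temperature-softened teacher and student distributions, multiplied by T^2 so gradient magnitudes stay comparable across temperatures. A minimal self-contained sketch in plain PyTorch (the tensors and helper name below are illustrative, not part of the codebase above):

import torch
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, temperature=4.0):
    # KL(teacher || student) on temperature-softened distributions,
    # scaled by temperature**2 as in distill_knowledge above.
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    return (temperature ** 2) * F.kl_div(
        log_p_student, p_teacher, reduction="batchmean")

# toy usage with random logits.
student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10)
soft_target_loss(student_logits, teacher_logits).backward()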
Example #2
    def _evaluate(_model, label):
        # define stat.
        tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names)

        # switch to evaluation mode
        _model.eval()

        # define hidden state for RNN.
        _hidden = (
            model.module.init_hidden(conf.batch_size)
            if "DataParallel" == model.__class__.__name__
            else model.init_hidden(conf.batch_size)
        )

        for batch in data_loader["val_loader"]:
            # load data and check performance.
            _input, _target = batch.text, batch.target

            # repackage the hidden.
            _hidden = (
                model.module.repackage_hidden(_hidden)
                if "DataParallel" == model.__class__.__name__
                else model.repackage_hidden(_hidden)
            )

            with torch.no_grad():
                _, _hidden = inference(
                    conf,
                    _model,
                    criterion,
                    metrics,
                    _input,
                    _target,
                    _hidden,
                    tracker_te,
                )

        # display the test stat.
        display_test_stat(conf, scheduler, tracker_te, label)

        # get global (mean) performance
        global_performance = tracker_te.evaluate_global_metrics()
        return global_performance
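
`repackage_hidden` is not shown in this example; in most RNN training and evaluation loops it simply detaches the hidden state from the previous computation graph so that backpropagation is truncated at batch boundaries. A plausible minimal version, written as a free function for brevity and assuming the hidden state is a tensor or a (nested) tuple of tensors as for LSTMs:

import torch

def repackage_hidden(hidden):
    # Detach hidden states from their history so gradients do not
    # flow across batch boundaries (truncated BPTT).
    if isinstance(hidden, torch.Tensor):
        return hidden.detach()
    return tuple(repackage_hidden(h) for h in hidden)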
Example #3
    def _evaluate(_model, label):
        # define stat.
        tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names)

        # switch to evaluation mode
        _model.eval()

        for _input, _target in data_loader["val_loader"]:
            # load data and check performance.
            _input, _target = load_data_batch(conf, _input, _target)

            with torch.no_grad():
                inference(_model, criterion, metrics, _input, _target, tracker_te)

        # display the test stat.
        display_test_stat(conf, scheduler, tracker_te, label)

        # get global (mean) performance
        global_performance = tracker_te.evaluate_global_metrics()
        return global_performance
def validate(
    conf,
    coordinator,
    model,
    criterion,
    metrics,
    data_loader,
    label="test_loader",
    display=True,
):
    """A function for model evaluation."""
    if data_loader is None:
        return None

    # switch to evaluation mode.
    model.eval()

    # place the model to the device.
    if conf.graph.on_cuda:
        model = model.cuda()

    # evaluate on test_loader.
    tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names)

    for _input, _target in data_loader:
        # load data and check performance.
        data_batch = create_dataset.load_data_batch(
            conf, _input, _target, is_training=False
        )

        with torch.no_grad():
            inference(
                conf,
                model,
                criterion,
                metrics,
                data_batch,
                tracker_te,
                is_training=False,
            )

    # move the model back to the cpu.
    if conf.graph.on_cuda:
        model = model.cpu()

    # display the test stat.
    perf = tracker_te()
    if label is not None:
        display_test_stat(conf, coordinator, tracker_te, label)
    if display:
        conf.logger.log(f"The validation performance = {perf}.")
    return perf
    def _construct_input_via_expected_output_space(
            self, constructed_probs_and_labels):
        # generate the inputs based on these dirichlet distributions.
        generated_inputs, generated_probs = [], []
        model = agg_utils.modify_model_trainable_status(self.conf,
                                                        self.model,
                                                        trainable=False)
        criterion = torch.nn.KLDivLoss(reduction="batchmean")
        tracker = RuntimeTracker(metrics_to_track=[])

        # init the dataset for the training
        dataset = CustomDataset(constructed_probs_and_labels)
        data_loader = create_data_loader(
            dataset,
            batch_size=int(
                self.conf.fl_aggregate["kt_g_batch_size_per_class"]))
        num_update_per_batch = int(
            self.conf.fl_aggregate["kt_data_generate_iters"])
        self.conf.logger.log(
            f"# of mini-batches={len(data_loader)}, size of mini-batch={self.conf.fl_aggregate['kt_g_batch_size_per_class']}, # of update per-mini-batch={num_update_per_batch}"
        )

        # iterate over the dataset and optimize the generated inputs.
        for batch_idx, probs in enumerate(data_loader):
            _generated_input = torch.rand(
                (len(probs), 3, 32, 32),
                requires_grad=True,
                device="cuda" if self.conf.graph.on_cuda else "cpu",
            )
            optimizer = torch.optim.Adam(
                [_generated_input],
                lr=self.conf.fl_aggregate["step_size"],
                betas=(self.conf.adam_beta_1, self.conf.adam_beta_2),
                eps=self.conf.adam_eps,
            )

            # optimize the input_data to mimic the output space of the network.
            for _ in range(num_update_per_batch):
                loss = update_input_data(
                    self.conf,
                    model,
                    criterion,
                    optimizer,
                    _generated_input,
                    expected_probs=probs.cuda()
                    if self.conf.graph.on_cuda else probs,
                )
                tracker.update_metrics([loss.item()],
                                       n_samples=_generated_input.size(0))
            self.conf.logger.log(
                f"\t the data generation loss (model index={self.model_idx}, batch index={batch_idx}) = {tracker()}."
            )
            tracker.reset()
            generated_inputs.append(copy.deepcopy(_generated_input.data))
            generated_probs.append(probs)
        generated_inputs = torch.cat(generated_inputs, dim=0).data.cpu()
        generated_probs = torch.cat(generated_probs, dim=0).data.cpu()
        return generated_inputs, generated_probs
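
`update_input_data` is referenced but not defined in this example. Under the assumption that it performs one optimization step on the generated inputs (not on the frozen model) so that the model's predicted distribution matches the expected probabilities, a simplified sketch (omitting `conf`) might look like:

import torch.nn.functional as F

def update_input_data(model, criterion, optimizer, generated_input, expected_probs):
    # One step of input optimization against a frozen model:
    # push model(generated_input) toward the expected probabilities.
    optimizer.zero_grad()
    log_probs = F.log_softmax(model(generated_input), dim=1)
    loss = criterion(log_probs, expected_probs)  # e.g. KLDivLoss(reduction="batchmean")
    loss.backward()
    optimizer.step()
    return loss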
def train_and_validate(conf, model, criterion, scheduler, optimizer, metrics,
                       data_loader):
    print("=>>>> start training and validation.\n")

    # define runtime stat tracker and start the training.
    tracker_tr = RuntimeTracker(metrics_to_track=metrics.metric_names,
                                on_cuda=conf.graph.on_cuda)

    # get the timer.
    timer = conf.timer

    # loop until the expected number of full training epochs is finished.
    print("=>>>> enter the training.\n")
    while True:
        dist.barrier()

        # configure local step.
        for _input, _target in data_loader["train_loader"]:
            model.train()
            scheduler.step(optimizer)

            # load data
            with timer("load_data", epoch=scheduler.epoch_):
                _input, _target = load_data_batch(conf, _input, _target)

            # inference and get current performance.
            with timer("forward_pass", epoch=scheduler.epoch_):
                optimizer.zero_grad()
                loss = inference(model, criterion, metrics, _input, _target,
                                 tracker_tr)

            with timer("backward_pass", epoch=scheduler.epoch_):
                loss.backward()

            with timer("sync_complete", epoch=scheduler.epoch_):
                n_bits_to_transmit = optimizer.step(timer=timer,
                                                    scheduler=scheduler)

            # display the logging info.
            display_training_stat(conf, scheduler, tracker_tr,
                                  n_bits_to_transmit)

            # finished one epoch of training; decide whether to validate the model.
            if scheduler.epoch_ % 1 == 0:
                if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
                        tracker_tr.stat["loss"].avg):
                    print("\nThe process diverges!!!!!Early stop it.")
                    error_handler.abort()

                # each worker finishes one epoch of training.
                do_validate(conf, model, optimizer, criterion, scheduler,
                            metrics, data_loader)

                # refresh the logging cache at the beginning of each epoch.
                tracker_tr.reset()

                # evaluate (and only inference) on the whole training loader.
                if (conf.evaluate_consensus
                        or scheduler.is_stop()) and not conf.train_fast:
                    # prepare the dataloader for the consensus evaluation.
                    _data_loader = {
                        "val_loader":
                        _define_cv_dataset(
                            conf,
                            partition_type=None,
                            dataset_type="train",
                            force_shuffle=True,
                        )
                    }

                    # evaluate on the local model.
                    conf.logger.log(
                        "eval the local model on full training data.")
                    validate(
                        conf,
                        model,
                        optimizer,
                        criterion,
                        scheduler,
                        metrics,
                        data_loader=_data_loader,
                        label="eval_local_model_on_full_training_data",
                        force_evaluate_on_averaged_model=False,
                    )

                    # evaluate on the averaged model.
                    conf.logger.log(
                        "eval the averaged model on full training data.")
                    copied_model = copy.deepcopy(
                        model.module if "DataParallel" ==
                        model.__class__.__name__ else model)
                    optimizer.world_aggregator.agg_model(copied_model,
                                                         op="avg")
                    validate(
                        conf,
                        copied_model,
                        optimizer,
                        criterion,
                        scheduler,
                        metrics,
                        data_loader=_data_loader,
                        label="eval_averaged_model_on_full_training_data",
                        force_evaluate_on_averaged_model=False,
                    )

                # determine if the training is finished.
                if scheduler.is_stop():
                    # save json.
                    conf.logger.save_json()

                    # temporary hack to exit ParallelCHOCO.
                    if optimizer.__class__.__name__ == "ParallelCHOCO":
                        error_handler.abort()
                    return

            # display tracking time.
            if (conf.graph.rank == 0 and conf.display_tracked_time
                    and scheduler.local_index % conf.summary_freq == 0):
                print(timer.summary())

        # reshuffle the data.
        if conf.reshuffle_per_epoch:
            print("\nReshuffle the dataset.")
            del data_loader
            gc.collect()
            data_loader = define_dataset(conf)
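
`optimizer.world_aggregator.agg_model(copied_model, op="avg")` comes from the surrounding codebase and is not shown here; conceptually it averages the model parameters across all workers. A minimal stand-in using raw `torch.distributed` (assuming the default process group is already initialized) could be:

import torch.distributed as dist

def average_model_(model):
    # In-place average of the parameters across all workers
    # in the default process group.
    world_size = dist.get_world_size()
    for param in model.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data.div_(world_size)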
    def distillation(self):
        # init the tracker.
        server_tracker = RuntimeTracker(metrics_to_track=["student_loss"],
                                        force_to_replace_metrics=True)

        # init the data iter.
        if self.distillation_data_loader is not None:
            data_iter = iter(self.distillation_data_loader)

        # get the client_weights from client's validation performance.
        client_weights = self._get_client_weights()

        # get the init server perf.
        init_perf_on_val = self.validate(model=self.server_student,
                                         data_loader=self.val_data_loader)

        # iterate over dataset
        n_pseudo_batches = 0
        self.log_fn(
            f"Batch {n_pseudo_batches}/{self.total_n_server_pseudo_batches}: Student Validation Acc={init_perf_on_val}."
        )
        while n_pseudo_batches < self.total_n_server_pseudo_batches:
            # get the inputs.
            if self.distillation_data_loader is not None:
                try:
                    pseudo_data = next(data_iter)[0].to(device=self.device)
                except StopIteration:
                    data_iter = iter(self.distillation_data_loader)
                    pseudo_data = next(data_iter)[0].to(device=self.device)
            else:
                if self.conf.fl_aggregate["use_data_scheme"] == "random_data":
                    pseudo_data = self._create_data_randomly()
                else:
                    raise NotImplementedError("incorrect use_data_scheme.")

            # get the logits.
            with torch.no_grad():
                teacher_logits = [
                    _teacher(pseudo_data) for _teacher in self.client_teachers
                ]

            # steps on the same pseudo data
            student_logits = self.server_student(pseudo_data)
            student_logits_activations = [
                (student_logits, self.server_student.activations)
            ] * self.numb_teachers

            stud_avg_loss = self.update_student(
                student_logits_activations=student_logits_activations,
                base_solver=self.base_solver,
                _student=self.server_student,
                _teachers=self.client_teachers,
                _opt_student=self.swa_optimizer,
                teacher_logits=teacher_logits,
                update_student_scheme=self.update_student_scheme,
                weights=client_weights,
            )

            # update the tracker after each batch.
            server_tracker.update_metrics([stud_avg_loss],
                                          n_samples=self.batch_size)

            if (n_pseudo_batches + 1) % self.eval_batches_freq == 0:
                validated_perf = self.validate(
                    model=self.server_student,
                    data_loader=self.val_data_loader)
                self.log_fn(
                    f"Batch {n_pseudo_batches + 1}/{self.total_n_server_pseudo_batches}: Student Loss={server_tracker.stat['student_loss'].avg:02.5f}; Student Validation Acc={validated_perf}."
                )
                server_tracker.reset()

            n_pseudo_batches += 1

        # update the server model.
        self.swa_optimizer.swap_swa_sgd()
        self.server_student = self.server_student.cpu()
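
`self._create_data_randomly()` is not shown in this example. When `use_data_scheme == "random_data"`, a plausible minimal implementation simply samples noise with the expected input shape; the function name, shape, and distribution below are assumptions for illustration:

import torch

def create_data_randomly(batch_size, shape=(3, 32, 32), device="cpu"):
    # Sample standard-normal pseudo inputs for data-free distillation.
    return torch.randn((batch_size, *shape), device=device)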
class Worker(object):
    def __init__(self, conf):
        self.conf = conf

        # some initializations.
        self.rank = conf.graph.rank
        conf.graph.worker_id = conf.graph.rank
        self.device = torch.device(
            "cuda" if self.conf.graph.on_cuda else "cpu")

        # define the timer for different operations.
        # if we choose the `train_fast` mode, then we will not track the time.
        self.timer = Timer(
            verbosity_level=1
            if conf.track_time and not conf.train_fast else 0,
            log_fn=conf.logger.log_metric,
        )

        # create dataset (as well as the potential data_partitioner) for training.
        dist.barrier()
        self.dataset = create_dataset.define_dataset(conf, data=conf.data)
        _, self.data_partitioner = create_dataset.define_data_loader(
            self.conf,
            dataset=self.dataset["train"],
            localdata_id=0,  # random id here.
            is_train=True,
            data_partitioner=None,
        )
        conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} initialized the local training data with Master."
        )

        # define the criterion.
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

        # define the model compression operators.
        if conf.local_model_compression is not None:
            if conf.local_model_compression == "quantization":
                self.model_compression_fn = compressor.ModelQuantization(conf)

        conf.logger.log(
            f"Worker-{conf.graph.worker_id} initialized dataset/criterion.\n")

    def run(self):
        while True:
            self._listen_to_master()

            # check if we need to terminate the training or not.
            if self._terminate_by_early_stopping():
                return

            self._recv_model_from_master()
            self._train()
            self._send_model_to_master()

            # check if we need to terminate the training or not.
            if self._terminate_by_complete_training():
                return

    def _listen_to_master(self):
        # listen to master, related to the function `_activate_selected_clients` in `master.py`.
        msg = torch.zeros((3, self.conf.n_participated))
        dist.broadcast(tensor=msg, src=0)
        self.conf.graph.client_id, self.conf.graph.comm_round, self.n_local_epochs = (
            msg[:, self.conf.graph.rank - 1].to(int).cpu().numpy().tolist())

        # once we receive the signal, we init for the local training.
        self.arch, self.model = create_model.define_model(
            self.conf,
            to_consistent_model=False,
            client_id=self.conf.graph.client_id)
        self.model_state_dict = self.model.state_dict()
        self.model_tb = TensorBuffer(list(self.model_state_dict.values()))
        self.metrics = create_metrics.Metrics(self.model,
                                              task="classification")
        dist.barrier()

    def _recv_model_from_master(self):
        # related to the function `_send_model_to_selected_clients` in `master.py`
        old_buffer = copy.deepcopy(self.model_tb.buffer)
        dist.recv(self.model_tb.buffer, src=0)
        new_buffer = copy.deepcopy(self.model_tb.buffer)
        self.model_tb.unpack(self.model_state_dict.values())
        self.model.load_state_dict(self.model_state_dict)
        random_reinit.random_reinit_model(self.conf, self.model)
        self.init_model = self._turn_off_grad(
            copy.deepcopy(self.model).to(self.device))
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) received the model ({self.arch}) from Master. The model status {'is updated' if old_buffer.norm() != new_buffer.norm() else 'is not updated'}."
        )
        dist.barrier()

    def _train(self):
        self.model.train()

        # init the model and dataloader.
        if self.conf.graph.on_cuda:
            self.model = self.model.cuda()
        self.train_loader, _ = create_dataset.define_data_loader(
            self.conf,
            dataset=self.dataset["train"],
            # localdata_id starts from 0 to the # of clients - 1.
            # client_id starts from 1 to the # of clients.
            localdata_id=self.conf.graph.client_id - 1,
            is_train=True,
            data_partitioner=self.data_partitioner,
        )

        # define optimizer, scheduler and runtime tracker.
        self.optimizer = create_optimizer.define_optimizer(
            self.conf, model=self.model, optimizer_name=self.conf.optimizer)
        self.scheduler = create_scheduler.Scheduler(self.conf,
                                                    optimizer=self.optimizer)
        self.tracker = RuntimeTracker(
            metrics_to_track=self.metrics.metric_names)
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) enters the local training phase (current communication rounds={self.conf.graph.comm_round})."
        )

        # efficient local training.
        if hasattr(self, "model_compression_fn"):
            self.model_compression_fn.compress_model(
                param_groups=self.optimizer.param_groups)

        # enter local updates; finish only after reaching the expected local_n_epochs.
        while True:
            for _input, _target in self.train_loader:
                # load data
                with self.timer("load_data", epoch=self.scheduler.epoch_):
                    data_batch = create_dataset.load_data_batch(
                        self.conf, _input, _target, is_training=True)

                # inference and get current performance.
                with self.timer("forward_pass", epoch=self.scheduler.epoch_):
                    self.optimizer.zero_grad()
                    loss, output = self._inference(data_batch)

                    # in case we need self distillation to penalize the local training
                    # (avoid catastrophic forgetting).
                    loss = self._local_training_with_self_distillation(
                        loss, output, data_batch)

                with self.timer("backward_pass", epoch=self.scheduler.epoch_):
                    loss.backward()
                    self._add_grad_from_prox_regularized_loss()
                    self.optimizer.step()
                    self.scheduler.step()

                # efficient local training.
                with self.timer("compress_model", epoch=self.scheduler.epoch_):
                    if hasattr(self, "model_compression_fn"):
                        self.model_compression_fn.compress_model(
                            param_groups=self.optimizer.param_groups)

                # display the logging info.
                display_training_stat(self.conf, self.scheduler, self.tracker)

                # display tracking time.
                if (self.conf.display_tracked_time
                        and self.scheduler.local_index % self.conf.summary_freq
                        == 0):
                    self.conf.logger.log(self.timer.summary())

                # check divergence.
                if self.tracker.stat["loss"].avg > 1e3 or np.isnan(
                        self.tracker.stat["loss"].avg):
                    self.conf.logger.log(
                        f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) diverges!!!!!Early stop it."
                    )
                    self._terminate_comm_round()
                    return

                # check stopping condition.
                if self._is_finished_one_comm_round():
                    self._terminate_comm_round()
                    return

            # refresh the logging cache at the end of each epoch.
            self.tracker.reset()
            if self.conf.logger.meet_cache_limit():
                self.conf.logger.save_json()

    def _inference(self, data_batch):
        """Inference on the given model and get loss and accuracy."""
        # do the forward pass and get the output.
        output = self.model(data_batch["input"])

        # evaluate the output and get the loss, performance.
        if self.conf.use_mixup:
            loss = mixup.mixup_criterion(
                self.criterion,
                output,
                data_batch["target_a"],
                data_batch["target_b"],
                data_batch["mixup_lambda"],
            )

            performance_a = self.metrics.evaluate(loss, output,
                                                  data_batch["target_a"])
            performance_b = self.metrics.evaluate(loss, output,
                                                  data_batch["target_b"])
            performance = [
                data_batch["mixup_lambda"] * _a +
                (1 - data_batch["mixup_lambda"]) * _b
                for _a, _b in zip(performance_a, performance_b)
            ]
        else:
            loss = self.criterion(output, data_batch["target"])
            performance = self.metrics.evaluate(loss, output,
                                                data_batch["target"])

        # update tracker.
        if self.tracker is not None:
            self.tracker.update_metrics([loss.item()] + performance,
                                        n_samples=data_batch["input"].size(0))
        return loss, output

    def _add_grad_from_prox_regularized_loss(self):
        assert self.conf.local_prox_term >= 0
        if self.conf.local_prox_term != 0:
            assert self.conf.weight_decay == 0
            assert self.conf.optimizer == "sgd"
            assert self.conf.momentum_factor == 0

            for _param, _init_param in zip(self.model.parameters(),
                                           self.init_model.parameters()):
                if _param.grad is not None:
                    _param.grad.data.add_((_param.data - _init_param.data) *
                                          self.conf.local_prox_term)

    def _local_training_with_self_distillation(self, loss, output, data_batch):
        if self.conf.self_distillation > 0:
            loss = loss * (
                1 - self.conf.self_distillation
            ) + self.conf.self_distillation * self._divergence(
                student_logits=output /
                self.conf.self_distillation_temperature,
                teacher_logits=self.init_model(data_batch["input"]) /
                self.conf.self_distillation_temperature,
            )
        return loss

    def _divergence(self, student_logits, teacher_logits):
        divergence = F.kl_div(
            F.log_softmax(student_logits, dim=1),
            F.softmax(teacher_logits, dim=1),
            reduction="batchmean",
        )  # forward KL
        return divergence

    def _turn_off_grad(self, model):
        for param in model.parameters():
            param.requires_grad = False
        return model

    def _send_model_to_master(self):
        dist.barrier()
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) sending the model ({self.arch}) back to Master."
        )
        flatten_model = TensorBuffer(list(self.model.state_dict().values()))
        dist.send(tensor=flatten_model.buffer, dst=0)
        dist.barrier()

    def _terminate_comm_round(self):
        self.model = self.model.cpu()
        del self.init_model
        self.scheduler.clean()
        self.conf.logger.save_json()
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) finished one round of federated learning: (comm_round={self.conf.graph.comm_round})."
        )

    def _terminate_by_early_stopping(self):
        if self.conf.graph.comm_round == -1:
            dist.barrier()
            self.conf.logger.log(
                f"Worker-{self.conf.graph.worker_id} finished the federated learning by early-stopping."
            )
            return True
        else:
            return False

    def _terminate_by_complete_training(self):
        if self.conf.graph.comm_round == self.conf.n_comm_rounds:
            dist.barrier()
            self.conf.logger.log(
                f"Worker-{self.conf.graph.worker_id} finished the federated learning: (total comm_rounds={self.conf.graph.comm_round})."
            )
            return True
        else:
            return False

    def _is_finished_one_comm_round(self):
        return self.conf.epoch_ >= self.conf.local_n_epochs
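
The proximal gradient added in `_add_grad_from_prox_regularized_loss` is the gradient of a FedProx-style penalty `(mu / 2) * ||w - w_init||^2` on the round-initial weights. A small sketch of the equivalent loss-based formulation (hypothetical, for comparison only; autograd of this term reproduces `mu * (w - w_init)` on each parameter):

def prox_penalty(model, init_model, mu):
    # (mu / 2) * squared distance between current and round-initial weights.
    penalty = 0.0
    for p, p0 in zip(model.parameters(), init_model.parameters()):
        penalty = penalty + (p - p0.detach()).pow(2).sum()
    return 0.5 * mu * penalty

Adding this term to the loss before calling `backward()` would have the same effect as the in-place gradient modification used above.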
Example #10
def train_and_validate(
    conf, model, criterion, scheduler, optimizer, metrics, data_loader
):
    print("=>>>> start training and validation.\n")

    # define runtime stat tracker and start the training.
    tracker_tr = RuntimeTracker(metrics_to_track=metrics.metric_names)

    # get the timer.
    timer = conf.timer

    # loop until the expected number of full training epochs is finished.
    print("=>>>> enter the training.\n")
    while True:
        # init the hidden state.
        _hidden = (
            model.module.init_hidden(conf.batch_size)
            if "DataParallel" == model.__class__.__name__
            else model.init_hidden(conf.batch_size)
        )

        # configure local step.
        for batch in data_loader["train_loader"]:
            model.train()

            # repackage the hidden.
            _hidden = (
                model.module.repackage_hidden(_hidden)
                if "DataParallel" == model.__class__.__name__
                else model.repackage_hidden(_hidden)
            )

            # load data
            with timer("load_data", epoch=scheduler.epoch_):
                _input = batch.text[
                    :,
                    conf.graph.rank
                    * conf.batch_size : (conf.graph.rank + 1)
                    * conf.batch_size,
                ]
                _target = batch.target[
                    :,
                    conf.graph.rank
                    * conf.batch_size : (conf.graph.rank + 1)
                    * conf.batch_size,
                ]
                _input, _target = load_data_batch(conf, _input, _target)

            # inference and get current performance.
            with timer("forward_pass", epoch=scheduler.epoch_):
                optimizer.zero_grad()
                loss, _hidden = inference(
                    conf,
                    model,
                    criterion,
                    metrics,
                    _input,
                    _target,
                    _hidden,
                    tracker_tr,
                )

            with timer("backward_pass", epoch=scheduler.epoch_):
                loss.backward()

            with timer("sync_complete", epoch=scheduler.epoch_):
                # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
                torch.nn.utils.clip_grad_norm_(model.parameters(), conf.rnn_clip)
                n_bits_to_transmit = optimizer.step(timer=timer)
                scheduler.step()

            # display the logging info.
            display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)

            # finished one epoch of training; decide whether to validate the model.
            if scheduler.epoch_ % 1 == 0:
                if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
                    tracker_tr.stat["loss"].avg
                ):
                    print("\nThe process diverges!!!!!Early stop it.")
                    error_handler.abort()

                # each worker finishes one epoch of training.
                do_validate(
                    conf, model, optimizer, criterion, scheduler, metrics, data_loader
                )

                # refresh the logging cache at the beginning of each epoch.
                tracker_tr.reset()

                # determine if the training is finished.
                if scheduler.is_stop():
                    conf.logger.save_json()
                    return

            # display tracking time.
            if (
                conf.graph.rank == 0
                and conf.display_tracked_time
                and scheduler.local_index % conf.summary_freq == 0
            ):
                print(timer.summary())
def train_and_validate(
    conf, model, criterion, scheduler, optimizer, metrics, data_loader
):
    print("=>>>> start training and validation.")
    # define runtime stat tracker and start the training.
    tracker_tr = RuntimeTracker(
        metrics_to_track=metrics.metric_names, on_cuda=conf.graph.on_cuda
    )

    # get the timer.
    timer = conf.timer
    # loop until the expected number of full training epochs is finished.
    print("=>>>> enter the training.\n")
    while True:
        dist.barrier()

        # configure local step.
        for _input, _target in data_loader["train_loader"]:
            model.train()

            # load data
            with timer("load_data", epoch=scheduler.epoch_):
                _input, _target = load_data_batch(conf, _input, _target)

            # inference and get current performance.
            with timer("forward_pass", epoch=scheduler.epoch_):
                optimizer.zero_grad()
                loss = inference(model, criterion, metrics, _input, _target, tracker_tr)

            with timer("backward_pass", epoch=scheduler.epoch_):
                loss.backward()

            with timer("sync_and_apply_grad", epoch=scheduler.epoch_):
                n_bits_to_transmit = optimizer.step(timer=timer, scheduler=scheduler)
                scheduler.step()

            # display the logging info.
            display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)

            # finished one epoch of training; decide whether to validate the model.
            if scheduler.epoch_ % 1 == 0:
                if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
                    tracker_tr.stat["loss"].avg
                ):
                    print("\nThe process diverges!!!!!Early stop it.")
                    error_handler.abort()

                # each worker finishes one epoch of training.
                do_validate(
                    conf, model, optimizer, criterion, scheduler, metrics, data_loader
                )

                # refresh the logging cache at the beginning of each epoch.
                tracker_tr.reset()

                # determine if the training is finished.
                if scheduler.is_stop():
                    # save json.
                    conf.logger.save_json()
                    return

            # display tracking time.
            if (
                conf.graph.rank == 0
                and conf.display_tracked_time
                and scheduler.local_index % conf.summary_freq == 0
            ):
                print(timer.summary())

        # reshuffle the data.
        if conf.reshuffle_per_epoch:
            print("\nReshuffle the dataset.")
            del data_loader
            gc.collect()
            data_loader = define_dataset(conf)
def ensembled_validate(
    conf,
    coordinator,
    models,
    criterion,
    metrics,
    data_loader,
    label="test_loader",
    ensemble_scheme=None,
):
    """A function for model evaluation."""
    if data_loader is None:
        return None

    # switch to evaluation mode.
    for model in models:
        model.eval()

        # place the model to the device.
        if conf.graph.on_cuda:
            model = model.cuda()

    # evaluate on test_loader.
    tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names)

    for _input, _target in data_loader:
        # load data and check performance.
        data_batch = create_dataset.load_data_batch(
            conf, _input, _target, is_training=False
        )

        with torch.no_grad():
            # ensemble.
            if (
                ensemble_scheme is None
                or ensemble_scheme == "avg_losses"
                or ensemble_scheme == "avg_logits"
            ):
                outputs = []
                for model in models:
                    outputs.append(model(data_batch["input"]))
                output = sum(outputs) / len(outputs)
            elif ensemble_scheme == "avg_probs":
                outputs = []
                for model in models:
                    outputs.append(F.softmax(model(data_batch["input"]), dim=1))
                output = sum(outputs) / len(outputs)

            # eval the performance.
            loss = torch.FloatTensor([0])
            performance = metrics.evaluate(loss, output, data_batch["target"])

        # update the tracker.
        tracker_te.update_metrics(
            [loss.item()] + performance, n_samples=data_batch["input"].size(0)
        )

    # move the models back to the cpu.
    for model in models:
        if conf.graph.on_cuda:
            model = model.cpu()

    # display the test stat.
    if label is not None:
        display_test_stat(conf, coordinator, tracker_te, label)
    perf = tracker_te()
    conf.logger.log(f"The performance of the ensenmbled model: {perf}.")
    return perf
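
The `avg_logits` and `avg_probs` schemes above differ only in where the averaging happens relative to the softmax, which in general yields different class probabilities. A toy comparison with made-up logits:

import torch
import torch.nn.functional as F

logits_a = torch.tensor([[4.0, 0.0]])
logits_b = torch.tensor([[0.0, 1.0]])

# avg_logits: average in logit space, softmax afterwards (~[0.82, 0.18]).
probs_from_avg_logits = F.softmax((logits_a + logits_b) / 2, dim=1)

# avg_probs: softmax each model first, then average (~[0.63, 0.37]).
avg_probs = (F.softmax(logits_a, dim=1) + F.softmax(logits_b, dim=1)) / 2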
Example #13
def training(conf, model, criterion, data_loaders, eps):
    # place the model on gpu
    if conf.graph.on_cuda:
        model = model.cuda()

    # then train the averaged model on the provided data loaders.
    optimizer = create_optimizer(conf, model)

    # init the training setup.
    epoch_count = 0
    final_model = copy.deepcopy(model)

    # init the recording status.
    if data_loaders["val_data_loader"] is not None:
        tracker_val = RuntimeTracker(metrics_to_track=[])
        for _ind, (_input,
                   _target) in enumerate(data_loaders["val_data_loader"]):
            # place model and data.
            if conf.graph.on_cuda:
                _input, _target = _input.cuda(), _target.cuda()

            # inference and evaluate.
            model.eval()
            loss = criterion(model(_input), _target)
            tracker_val.update_metrics([loss.item()], n_samples=_input.size(0))
        tracking = {
            "tr_loss_last_epoch": float("inf"),
            "val_loss_last_epoch": tracker_val.stat["loss"].avg,
        }
    else:
        tracking = {
            "tr_loss_last_epoch": float("inf"),
            "val_loss_last_epoch": float("inf"),
        }
    conf.logger.log(
        f"finish {epoch_count} epoch on-server training: train={tracking['tr_loss_last_epoch']}, val={tracking['val_loss_last_epoch']}."
    )

    # on server training and validation.
    while True:
        epoch_count += 1
        tracker_tr = RuntimeTracker(metrics_to_track=[])
        tracker_val = RuntimeTracker(metrics_to_track=[])

        # train on the tr_data_loader.
        for _ind, (_input,
                   _target) in enumerate(data_loaders["tr_data_loader"]):
            # place model and data.
            if conf.graph.on_cuda:
                _input, _target = _input.cuda(), _target.cuda()

            # inference and update alpha
            model.train()
            optimizer.zero_grad()
            loss = criterion(model(_input), _target, smooth_eps=eps)
            tracker_tr.update_metrics([loss.item()], n_samples=_input.size(0))
            loss.backward()
            optimizer.step()

        # validate on the val_data_loader.
        if data_loaders["val_data_loader"] is not None:
            for _ind, (_input,
                       _target) in enumerate(data_loaders["val_data_loader"]):
                # place model and data.
                if conf.graph.on_cuda:
                    _input, _target = _input.cuda(), _target.cuda()

                # inference and evaluate.
                model.eval()
                loss = criterion(model(_input), _target)
                tracker_val.update_metrics([loss.item()],
                                           n_samples=_input.size(0))

            # continue only if both the train and val losses improved.
            if (tracker_tr.stat["loss"].avg < tracking["tr_loss_last_epoch"]
                    and tracker_val.stat["loss"].avg <
                    tracking["val_loss_last_epoch"]):
                conf.logger.log(
                    f"finish {epoch_count} epoch on-server training: train={tracker_tr()}, val={tracker_val()}: will continue training."
                )
                final_model = copy.deepcopy(model)
            else:
                conf.logger.log(
                    f"finish {epoch_count} epoch on-server training: train={tracker_tr()}, val={tracker_val()}: will end training."
                )
                if conf.graph.on_cuda:
                    final_model = final_model.cpu()
                del model
                return final_model
        else:
            conf.logger.log(
                f"finish {epoch_count} epoch on-server training: {tracker_tr()}"
            )
            assert conf.fl_aggregate["epochs"] == "plateau"
            assert "epochs_max" in conf.fl_aggregate
            if (tracking["tr_loss_last_epoch"] - tracker_tr.stat["loss"].avg <=
                    conf.fl_aggregate["plateau_tol"]
                ) or epoch_count >= conf.fl_aggregate["epochs_max"]:
                if conf.graph.on_cuda:
                    model = model.cpu()
                return model

        # update the tracking records.
        tracking = {
            "tr_loss_last_epoch": tracker_tr.stat["loss"].avg,
            "val_loss_last_epoch": tracker_val.stat["loss"].avg,
        }
Example #14
    def distillation(self):
        # init the tracker.
        server_tracker = RuntimeTracker(metrics_to_track=["student_loss"],
                                        force_to_replace_metrics=True)
        server_best_tracker = BestPerf(best_perf=None, larger_is_better=True)

        # update the server generator/student
        n_pseudo_batches = 0
        best_models = [None]

        # init the data iter.
        if self.distillation_data_loader is not None:
            data_iter = iter(self.distillation_data_loader)

        # get the client_weights from client's validation performance.
        client_weights = self._get_client_weights()

        # get the init server perf.
        init_perf_on_val = self.validate(model=self.init_server_student,
                                         data_loader=self.val_data_loader)
        self.log_fn(
            f"Batch {n_pseudo_batches}/{self.total_n_server_pseudo_batches}: Student Validation Acc={init_perf_on_val}."
        )

        # iterate over dataset
        while n_pseudo_batches < self.total_n_server_pseudo_batches:
            # get the inputs.
            if self.distillation_data_loader is not None:
                try:
                    pseudo_data = next(data_iter)[0].to(device=self.device)
                except StopIteration:
                    data_iter = iter(self.distillation_data_loader)
                    pseudo_data = next(data_iter)[0].to(device=self.device)
            else:
                if self.conf.fl_aggregate["use_data_scheme"] == "random_data":
                    pseudo_data = self._create_data_randomly()
                else:
                    raise NotImplementedError("incorrect use_data_scheme.")

            # get the logits.
            with torch.no_grad():
                teacher_logits = [
                    _teacher(pseudo_data) for _teacher in self.client_teachers
                ]

            # steps on the same pseudo data
            for _ in range(self.server_local_steps):
                student_logits = self.server_student(pseudo_data)
                student_logits_activations = [
                    (student_logits, self.server_student.activations)
                ] * self.numb_teachers

                stud_avg_loss = self.update_student(
                    student_logits_activations=student_logits_activations,
                    base_solver=self.base_solver,
                    _student=self.server_student,
                    _teachers=self.client_teachers,
                    _opt_student=self.optimizer_server_student,
                    teacher_logits=teacher_logits,
                    update_student_scheme=self.update_student_scheme,
                    weights=client_weights,
                )

            # after each batch.
            if self.use_server_model_scheduler:
                self.scheduler_server_student.step()

            # update the tracker after each batch.
            server_tracker.update_metrics([stud_avg_loss],
                                          n_samples=self.batch_size)

            if (n_pseudo_batches + 1) % self.eval_batches_freq == 0:
                validated_perf = self.validate(
                    model=self.server_student,
                    data_loader=self.val_data_loader)
                self.log_fn(
                    f"Batch {n_pseudo_batches + 1}/{self.total_n_server_pseudo_batches}: Student Loss={server_tracker.stat['student_loss'].avg:02.5f}; Student Validation Acc={validated_perf}."
                )
                server_tracker.reset()

                # check early stopping.
                if self.base_solver.check_early_stopping(
                        model=self.server_student,
                        model_ind=0,
                        best_tracker=server_best_tracker,
                        validated_perf=validated_perf,
                        validated_perfs=self.validated_perfs,
                        perf_index=n_pseudo_batches + 1,
                        early_stopping_batches=self.early_stopping_server_batches,
                        best_models=best_models,
                ):
                    break
            n_pseudo_batches += 1

        # recover the best server model
        use_init_server_model = False
        if self.return_best_model_on_val:
            use_init_server_model = (
                init_perf_on_val["top1"] > server_best_tracker.best_perf
            )

        # get the server model.
        if use_init_server_model:
            self.log_fn("use init server model instead.")
            best_server_dict = self.init_server_student.state_dict()
        else:
            best_server_dict = best_models[0].state_dict()

        # update the server model.
        self.server_student.load_state_dict(best_server_dict)
        self.server_student = self.server_student.cpu()
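
`BestPerf` and `base_solver.check_early_stopping` belong to the surrounding codebase; the pattern they implement is ordinary patience-based early stopping on the validation metric. A rough, self-contained sketch of that pattern (names and fields are illustrative assumptions):

class BestPerfTracker:
    """Remember the best validation metric and the step at which it occurred."""

    def __init__(self, larger_is_better=True):
        self.best_perf, self.best_index = None, 0
        self.larger_is_better = larger_is_better

    def update(self, perf, index):
        improved = self.best_perf is None or (
            perf > self.best_perf if self.larger_is_better else perf < self.best_perf
        )
        if improved:
            self.best_perf, self.best_index = perf, index
        return improved


def should_early_stop(tracker, current_index, patience):
    # Stop once no improvement has been observed for `patience` evaluations.
    return current_index - tracker.best_index >= patience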