def validate(
    conf,
    coordinator,
    model,
    criterion,
    metrics,
    data_loader,
    label="test_loader",
    display=True,
):
    """A function for model evaluation."""
    if data_loader is None:
        return None

    # switch to evaluation mode.
    model.eval()

    # place the model on the device.
    if conf.graph.on_cuda:
        model = model.cuda()

    # evaluate on test_loader.
    tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names)

    for _input, _target in data_loader:
        # load data and check performance.
        data_batch = create_dataset.load_data_batch(
            conf, _input, _target, is_training=False
        )

        with torch.no_grad():
            inference(
                conf,
                model,
                criterion,
                metrics,
                data_batch,
                tracker_te,
                is_training=False,
            )

    # place the model back on the CPU.
    if conf.graph.on_cuda:
        model = model.cpu()

    # display the test stat.
    perf = tracker_te()
    if label is not None:
        display_test_stat(conf, coordinator, tracker_te, label)
    if display:
        conf.logger.log(f"The validation performance = {perf}.")
    return perf
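The snippets on this page share a common RuntimeTracker helper. Below is a minimal sketch of the interface they rely on (a hypothetical reimplementation for illustration; the project's own class may differ in details):

class _AverageMeter:
    # running average of a single scalar.
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count


class RuntimeTracker:
    # tracks "loss" plus the task metrics, as consumed by validate() above.
    def __init__(self, metrics_to_track, on_cuda=False):
        self.metric_names = ["loss"] + list(metrics_to_track)
        self.stat = {name: _AverageMeter() for name in self.metric_names}

    def update_metrics(self, metric_values, n_samples):
        for name, value in zip(self.metric_names, metric_values):
            self.stat[name].update(value, n_samples)

    def reset(self):
        for meter in self.stat.values():
            meter.reset()

    def __call__(self):
        # the running averages returned to the caller (e.g. `perf` above).
        return {name: meter.avg for name, meter in self.stat.items()}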
Example #2
    def _evaluate(_model, label):
        # define stat.
        tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names)

        # switch to evaluation mode
        _model.eval()

        for _input, _target in data_loader["val_loader"]:
            # load data and check performance.
            _input, _target = load_data_batch(conf, _input, _target)

            with torch.no_grad():
                inference(_model, criterion, metrics, _input, _target, tracker_te)

        # display the test stat.
        display_test_stat(conf, scheduler, tracker_te, label)

        # get global (mean) performance
        global_performance = tracker_te.evaluate_global_metrics()
        return global_performance
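The _evaluate helper above delegates the per-batch work to an inference function. A minimal sketch of what that helper is assumed to do (forward pass, loss, metric evaluation, tracker update) is shown below; it mirrors the call signature in the snippet and is hypothetical, not the project's exact implementation:

def inference(model, criterion, metrics, _input, _target, tracker=None):
    # forward pass and loss.
    output = model(_input)
    loss = criterion(output, _target)

    # task metrics (e.g. top-1 / top-5 accuracy).
    performance = metrics.evaluate(loss, output, _target)

    # update the running statistics, weighted by batch size.
    if tracker is not None:
        tracker.update_metrics(
            [loss.item()] + performance, n_samples=_input.size(0)
        )
    return loss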
Example #3
def extract_labels_from_local_models(conf, local_models, data_info):
    # init the basic data-loader.
    dataset = data_info["sampler"].use_indices()
    basic_data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=conf.batch_size,
        shuffle=False,
        num_workers=conf.num_workers,
        pin_memory=conf.pin_memory,
        drop_last=False,
    )

    # extract the predictions from local models.
    dict_of_labels = {}
    for idx, _model in local_models.items():
        _model.eval()
        _list_of_labels = []
        if conf.graph.on_cuda:
            _model = _model.cuda()

        for _input, _target in basic_data_loader:
            _data_batch = create_dataset.load_data_batch(conf,
                                                         _input,
                                                         _target,
                                                         is_training=False)
            _output = _model(_data_batch["input"])
            # extract the hard (argmax) label for the current prediction.
            _list_of_labels.append(
                torch.max(torch.softmax(_output, dim=1), dim=1).indices)

        # concatenate the list of predicted labels.
        dict_of_labels[idx] = torch.cat(_list_of_labels)

        # free the memory.
        if conf.graph.on_cuda:
            _model = _model.cpu()
    return dataset, dict_of_labels
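The label-extraction step in extract_labels_from_local_models takes the argmax of the softmaxed outputs; since softmax is monotonic within each row, this is equivalent to taking the argmax of the raw logits. A tiny self-contained check:

import torch

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.3, 0.2]])
via_softmax = torch.max(torch.softmax(logits, dim=1), dim=1).indices
via_argmax = torch.argmax(logits, dim=1)
assert torch.equal(via_softmax, via_argmax)  # tensor([0, 1]) in both cases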
def train_and_validate(conf, model, criterion, scheduler, optimizer, metrics,
                       data_loader):
    print("=>>>> start training and validation.\n")

    # define runtime stat tracker and start the training.
    tracker_tr = RuntimeTracker(metrics_to_track=metrics.metric_names,
                                on_cuda=conf.graph.on_cuda)

    # get the timer.
    timer = conf.timer

    # loop until the expected number of full training epochs is finished.
    print("=>>>> enter the training.\n")
    while True:
        dist.barrier()

        # configure local step.
        for _input, _target in data_loader["train_loader"]:
            model.train()
            scheduler.step(optimizer)

            # load data
            with timer("load_data", epoch=scheduler.epoch_):
                _input, _target = load_data_batch(conf, _input, _target)

            # inference and get current performance.
            with timer("forward_pass", epoch=scheduler.epoch_):
                optimizer.zero_grad()
                loss = inference(model, criterion, metrics, _input, _target,
                                 tracker_tr)

            with timer("backward_pass", epoch=scheduler.epoch_):
                loss.backward()

            with timer("sync_complete", epoch=scheduler.epoch_):
                n_bits_to_transmit = optimizer.step(timer=timer,
                                                    scheduler=scheduler)

            # display the logging info.
            display_training_stat(conf, scheduler, tracker_tr,
                                  n_bits_to_transmit)

            # finished one epoch of training; decide whether to validate the model.
            if scheduler.epoch_ % 1 == 0:
                if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
                        tracker_tr.stat["loss"].avg):
                    print("\nThe process diverges! Stopping early.")
                    error_handler.abort()

                # each worker finishes one epoch of training.
                do_validate(conf, model, optimizer, criterion, scheduler,
                            metrics, data_loader)

                # refresh the logging cache at the beginning of each epoch.
                tracker_tr.reset()

                # evaluate (inference only) on the whole training loader.
                if (conf.evaluate_consensus
                        or scheduler.is_stop()) and not conf.train_fast:
                    # prepare the dataloader for the consensus evaluation.
                    _data_loader = {
                        "val_loader":
                        _define_cv_dataset(
                            conf,
                            partition_type=None,
                            dataset_type="train",
                            force_shuffle=True,
                        )
                    }

                    # evaluate on the local model.
                    conf.logger.log(
                        "eval the local model on full training data.")
                    validate(
                        conf,
                        model,
                        optimizer,
                        criterion,
                        scheduler,
                        metrics,
                        data_loader=_data_loader,
                        label="eval_local_model_on_full_training_data",
                        force_evaluate_on_averaged_model=False,
                    )

                    # evaluate on the averaged model.
                    conf.logger.log(
                        "eval the averaged model on full training data.")
                    copied_model = copy.deepcopy(
                        model.module if "DataParallel" ==
                        model.__class__.__name__ else model)
                    optimizer.world_aggregator.agg_model(copied_model,
                                                         op="avg")
                    validate(
                        conf,
                        copied_model,
                        optimizer,
                        criterion,
                        scheduler,
                        metrics,
                        data_loader=_data_loader,
                        label="eval_averaged_model_on_full_training_data",
                        force_evaluate_on_averaged_model=False,
                    )

                # determine if the training is finished.
                if scheduler.is_stop():
                    # save json.
                    conf.logger.save_json()

                    # temporary hack to exit ParallelCHOCO.
                    if optimizer.__class__.__name__ == "ParallelCHOCO":
                        error_handler.abort()
                    return

            # display tracking time.
            if (conf.graph.rank == 0 and conf.display_tracked_time
                    and scheduler.local_index % conf.summary_freq == 0):
                print(timer.summary())

        # reshuffle the data.
        if conf.reshuffle_per_epoch:
            print("\nReshuffle the dataset.")
            del data_loader
            gc.collect()
            data_loader = define_dataset(conf)
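The divergence guard inside the training loop above aborts once the running loss explodes or becomes NaN. The same check, restated as a small self-contained helper (hypothetical; the snippets inline it rather than calling a function), makes the threshold explicit:

import numpy as np

def loss_diverged(avg_loss, threshold=1e3):
    # matches the inlined check: loss too large, or not a number at all.
    return avg_loss > threshold or np.isnan(avg_loss)

assert loss_diverged(float("nan"))
assert loss_diverged(2e3)
assert not loss_diverged(0.7)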
    def _train(self):
        self.model.train()

        # init the model and dataloader.
        if self.conf.graph.on_cuda:
            self.model = self.model.cuda()
        self.train_loader, _ = create_dataset.define_data_loader(
            self.conf,
            dataset=self.dataset["train"],
            # localdata_id starts from 0 to the # of clients - 1.
            # client_id starts from 1 to the # of clients.
            localdata_id=self.conf.graph.client_id - 1,
            is_train=True,
            data_partitioner=self.data_partitioner,
        )

        # define optimizer, scheduler and runtime tracker.
        self.optimizer = create_optimizer.define_optimizer(
            self.conf, model=self.model, optimizer_name=self.conf.optimizer)
        self.scheduler = create_scheduler.Scheduler(self.conf,
                                                    optimizer=self.optimizer)
        self.tracker = RuntimeTracker(
            metrics_to_track=self.metrics.metric_names)
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) enters the local training phase (current communication rounds={self.conf.graph.comm_round})."
        )

        # efficient local training.
        if hasattr(self, "model_compression_fn"):
            self.model_compression_fn.compress_model(
                param_groups=self.optimizer.param_groups)

        # enter local updates; finish only after reaching the expected local_n_epochs.
        while True:
            for _input, _target in self.train_loader:
                # load data
                with self.timer("load_data", epoch=self.scheduler.epoch_):
                    data_batch = create_dataset.load_data_batch(
                        self.conf, _input, _target, is_training=True)

                # inference and get current performance.
                with self.timer("forward_pass", epoch=self.scheduler.epoch_):
                    self.optimizer.zero_grad()
                    loss, output = self._inference(data_batch)

                    # in case we need self distillation to penalize the local training
                    # (avoid catastrophic forgetting).
                    self._local_training_with_self_distillation(
                        loss, output, data_batch)

                with self.timer("backward_pass", epoch=self.scheduler.epoch_):
                    loss.backward()
                    self._add_grad_from_prox_regularized_loss()
                    self.optimizer.step()
                    self.scheduler.step()

                # efficient local training.
                with self.timer("compress_model", epoch=self.scheduler.epoch_):
                    if hasattr(self, "model_compression_fn"):
                        self.model_compression_fn.compress_model(
                            param_groups=self.optimizer.param_groups)

                # display the logging info.
                display_training_stat(self.conf, self.scheduler, self.tracker)

                # display tracking time.
                if (self.conf.display_tracked_time
                        and self.scheduler.local_index % self.conf.summary_freq
                        == 0):
                    self.conf.logger.log(self.timer.summary())

                # check divergence.
                if self.tracker.stat["loss"].avg > 1e3 or np.isnan(
                        self.tracker.stat["loss"].avg):
                    self.conf.logger.log(
                        f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) diverges! Stopping early."
                    )
                    self._terminate_comm_round()
                    return

                # check stopping condition.
                if self._is_finished_one_comm_round():
                    self._terminate_comm_round()
                    return

            # refresh the logging cache at the end of each epoch.
            self.tracker.reset()
            if self.conf.logger.meet_cache_limit():
                self.conf.logger.save_json()
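The local-training loop above only exits through _terminate_comm_round, either on divergence or once _is_finished_one_comm_round reports that the expected local_n_epochs have been reached. A minimal, hypothetical sketch of that stopping check (the method body is an assumption inferred from the comment in the snippet, not the project's confirmed implementation):

    def _is_finished_one_comm_round(self):
        # assumption: the scheduler exposes the current (possibly fractional)
        # epoch counter, and conf.local_n_epochs holds the per-round budget.
        return self.scheduler.epoch_ >= self.conf.local_n_epochs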
Example #6
def train_and_validate(
    conf, model, criterion, scheduler, optimizer, metrics, data_loader
):
    print("=>>>> start training and validation.\n")

    # define runtime stat tracker and start the training.
    tracker_tr = RuntimeTracker(metrics_to_track=metrics.metric_names)

    # get the timer.
    timer = conf.timer

    # loop until the expected number of full training epochs is finished.
    print("=>>>> enter the training.\n")
    while True:
        # init the hidden state.
        _hidden = (
            model.module.init_hidden(conf.batch_size)
            if "DataParallel" == model.__class__.__name__
            else model.init_hidden(conf.batch_size)
        )

        # configure local step.
        for batch in data_loader["train_loader"]:
            model.train()

            # repackage the hidden.
            _hidden = (
                model.module.repackage_hidden(_hidden)
                if "DataParallel" == model.__class__.__name__
                else model.repackage_hidden(_hidden)
            )

            # load data
            with timer("load_data", epoch=scheduler.epoch_):
                _input = batch.text[
                    :,
                    conf.graph.rank
                    * conf.batch_size : (conf.graph.rank + 1)
                    * conf.batch_size,
                ]
                _target = batch.target[
                    :,
                    conf.graph.rank
                    * conf.batch_size : (conf.graph.rank + 1)
                    * conf.batch_size,
                ]
                _input, _target = load_data_batch(conf, _input, _target)

            # inference and get current performance.
            with timer("forward_pass", epoch=scheduler.epoch_):
                optimizer.zero_grad()
                loss, _hidden = inference(
                    conf,
                    model,
                    criterion,
                    metrics,
                    _input,
                    _target,
                    _hidden,
                    tracker_tr,
                )

            with timer("backward_pass", epoch=scheduler.epoch_):
                loss.backward()

            with timer("sync_complete", epoch=scheduler.epoch_):
                # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
                torch.nn.utils.clip_grad_norm_(model.parameters(), conf.rnn_clip)
                n_bits_to_transmit = optimizer.step(timer=timer)
                scheduler.step()

            # display the logging info.
            display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)

            # finished one epoch of training; decide whether to validate the model.
            if scheduler.epoch_ % 1 == 0:
                if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
                    tracker_tr.stat["loss"].avg
                ):
                    print("\nThe process diverges! Stopping early.")
                    error_handler.abort()

                # each worker finishes one epoch of training.
                do_validate(
                    conf, model, optimizer, criterion, scheduler, metrics, data_loader
                )

                # refresh the logging cache at the beginning of each epoch.
                tracker_tr.reset()

                # determine if the training is finished.
                if scheduler.is_stop():
                    conf.logger.save_json()
                    return

            # display tracking time.
            if (
                conf.graph.rank == 0
                and conf.display_tracked_time
                and scheduler.local_index % conf.summary_freq == 0
            ):
                print(timer.summary())
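The RNN variant above carries a hidden state across batches and "repackages" it at each step. A minimal sketch of what repackage_hidden typically does (detach the hidden state so truncated backpropagation through time does not reach into earlier batches; this is a hypothetical helper mirroring the call above, the model's own method may differ):

import torch

def repackage_hidden(h):
    # detach tensors (or tuples of tensors, as returned by an LSTM) from
    # the previous computation graph.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)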
def train_and_validate(
    conf, model, criterion, scheduler, optimizer, metrics, data_loader
):
    print("=>>>> start training and validation.")
    # define runtime stat tracker and start the training.
    tracker_tr = RuntimeTracker(
        metrics_to_track=metrics.metric_names, on_cuda=conf.graph.on_cuda
    )

    # get the timer.
    timer = conf.timer
    # loop until the expected number of full training epochs is finished.
    print("=>>>> enter the training.\n")
    while True:
        dist.barrier()

        # configure local step.
        for _input, _target in data_loader["train_loader"]:
            model.train()

            # load data
            with timer("load_data", epoch=scheduler.epoch_):
                _input, _target = load_data_batch(conf, _input, _target)

            # inference and get current performance.
            with timer("forward_pass", epoch=scheduler.epoch_):
                optimizer.zero_grad()
                loss = inference(model, criterion, metrics, _input, _target, tracker_tr)

            with timer("backward_pass", epoch=scheduler.epoch_):
                loss.backward()

            with timer("sync_and_apply_grad", epoch=scheduler.epoch_):
                n_bits_to_transmit = optimizer.step(timer=timer, scheduler=scheduler)
                scheduler.step()

            # display the logging info.
            display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)

            # finished one epoch of training; decide whether to validate the model.
            if scheduler.epoch_ % 1 == 0:
                if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
                    tracker_tr.stat["loss"].avg
                ):
                    print("\nThe process diverges! Stopping early.")
                    error_handler.abort()

                # each worker finishes one epoch of training.
                do_validate(
                    conf, model, optimizer, criterion, scheduler, metrics, data_loader
                )

                # refresh the logging cache at the beginning of each epoch.
                tracker_tr.reset()

                # determine if the training is finished.
                if scheduler.is_stop():
                    # save json.
                    conf.logger.save_json()
                    return

            # display tracking time.
            if (
                conf.graph.rank == 0
                and conf.display_tracked_time
                and scheduler.local_index % conf.summary_freq == 0
            ):
                print(timer.summary())

        # reshuffle the data.
        if conf.reshuffle_per_epoch:
            print("\nReshuffle the dataset.")
            del data_loader
            gc.collect()
            data_loader = define_dataset(conf)
def ensembled_validate(
    conf,
    coordinator,
    models,
    criterion,
    metrics,
    data_loader,
    label="test_loader",
    ensemble_scheme=None,
):
    """A function for model evaluation."""
    if data_loader is None:
        return None

    # switch to evaluation mode.
    for model in models:
        model.eval()

        # place the model on the device.
        if conf.graph.on_cuda:
            model = model.cuda()

    # evaluate on test_loader.
    tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names)

    for _input, _target in data_loader:
        # load data and check performance.
        data_batch = create_dataset.load_data_batch(
            conf, _input, _target, is_training=False
        )

        with torch.no_grad():
            # ensemble.
            if (
                ensemble_scheme is None
                or ensemble_scheme == "avg_losses"
                or ensemble_scheme == "avg_logits"
            ):
                outputs = []
                for model in models:
                    outputs.append(model(data_batch["input"]))
                output = sum(outputs) / len(outputs)
            elif ensemble_scheme == "avg_probs":
                outputs = []
                for model in models:
                    outputs.append(F.softmax(model(data_batch["input"]), dim=1))
                output = sum(outputs) / len(outputs)

            # eval the performance.
            loss = torch.FloatTensor([0])
            performance = metrics.evaluate(loss, output, data_batch["target"])

        # update the tracker.
        tracker_te.update_metrics(
            [loss.item()] + performance, n_samples=data_batch["input"].size(0)
        )

    # place the models back on the CPU.
    for model in models:
        if conf.graph.on_cuda:
            model = model.cpu()

    # display the test stat.
    if label is not None:
        display_test_stat(conf, coordinator, tracker_te, label)
    perf = tracker_te()
    conf.logger.log(f"The performance of the ensembled model: {perf}.")
    return perf
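For reference, a tiny self-contained illustration of the two ensembling schemes used in ensembled_validate: "avg_logits" averages the raw model outputs, while "avg_probs" first maps each model's output through softmax and then averages the probabilities (the tensors below are made-up examples):

import torch
import torch.nn.functional as F

logits_a = torch.tensor([[4.0, 0.0, 1.0]])
logits_b = torch.tensor([[0.0, 0.5, 2.0]])

avg_logits = (logits_a + logits_b) / 2  # tensor([[2.0000, 0.2500, 1.5000]])
avg_probs = (F.softmax(logits_a, dim=1) + F.softmax(logits_b, dim=1)) / 2

# the two schemes yield different score vectors (here the argmax happens to
# agree); ensembled_validate only needs the scores for metric evaluation.
print(avg_logits, avg_probs)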