def distill_knowledge(
    conf,
    student_model,
    dataset,
    num_epochs,
    batch_size,
    teacher_model=None,
    softmax_temperature=1,
):
    # init.
    data_loader = create_data_loader(dataset, batch_size=batch_size)
    criterion = torch.nn.KLDivLoss(reduction="batchmean")
    tracker = RuntimeTracker(metrics_to_track=[])

    # check model status.
    untrainable_teacher_model = (
        agg_utils.modify_model_trainable_status(conf, teacher_model, trainable=False)
        if teacher_model is not None
        else None
    )
    trainable_student_model = agg_utils.check_trainable(conf, student_model)
    optimizer = create_optimizer(conf, trainable_student_model)

    # start the formal training.
    for epoch_idx in range(num_epochs):
        for _input, _target in data_loader:
            # init the _input, _target.
            if conf.graph.on_cuda:
                _input = _input.cuda()
            if untrainable_teacher_model is None:
                # the targets are soft labels; move them to the right device.
                _target_prob = _target.cuda() if conf.graph.on_cuda else _target

            # perform fp/bp on the student model.
            optimizer.zero_grad()
            _output = trainable_student_model(_input)

            # evaluate the loss.
            if untrainable_teacher_model is not None:
                loss = (softmax_temperature ** 2) * criterion(
                    torch.nn.functional.log_softmax(
                        _output / softmax_temperature, dim=1
                    ),
                    torch.nn.functional.softmax(
                        untrainable_teacher_model(_input).detach()
                        / softmax_temperature,
                        dim=1,
                    ),
                )
            else:
                loss = criterion(
                    torch.nn.functional.log_softmax(_output, dim=1), _target_prob
                )
            loss.backward()
            optimizer.step()
            tracker.update_metrics([loss.item()], n_samples=_input.size(0))
        conf.logger.log(f"# of epochs={epoch_idx + 1}: {tracker()}")
    return trainable_student_model.cpu()
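# Illustration (not part of the original module): a minimal, self-contained sketch of the
# temperature-scaled distillation loss used in `distill_knowledge` above. The batch size,
# class count, temperature, and random logits below are made-up values for demonstration only.
import torch
import torch.nn.functional as F


def kd_loss_sketch(student_logits, teacher_logits, temperature=4.0):
    # KL(teacher || student) on temperature-softened distributions, scaled by T^2
    # so the gradient magnitude stays comparable across temperatures.
    return (temperature ** 2) * F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    )


student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10)
loss = kd_loss_sketch(student_logits, teacher_logits)
loss.backward()
print(loss.item(), student_logits.grad.shape)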
def _evaluate(_model, label):
    # define stat.
    tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names)

    # switch to evaluation mode.
    _model.eval()

    # define hidden state for RNN.
    _hidden = (
        _model.module.init_hidden(conf.batch_size)
        if "DataParallel" == _model.__class__.__name__
        else _model.init_hidden(conf.batch_size)
    )

    for batch in data_loader["val_loader"]:
        # load data and check performance.
        _input, _target = batch.text, batch.target

        # repackage the hidden state.
        _hidden = (
            _model.module.repackage_hidden(_hidden)
            if "DataParallel" == _model.__class__.__name__
            else _model.repackage_hidden(_hidden)
        )

        with torch.no_grad():
            _, _hidden = inference(
                conf,
                _model,
                criterion,
                metrics,
                _input,
                _target,
                _hidden,
                tracker_te,
            )

    # display the test stat.
    display_test_stat(conf, scheduler, tracker_te, label)

    # get global (mean) performance.
    global_performance = tracker_te.evaluate_global_metrics()
    return global_performance
def _evaluate(_model, label):
    # define stat.
    tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names)

    # switch to evaluation mode.
    _model.eval()

    for _input, _target in data_loader["val_loader"]:
        # load data and check performance.
        _input, _target = load_data_batch(conf, _input, _target)

        with torch.no_grad():
            inference(_model, criterion, metrics, _input, _target, tracker_te)

    # display the test stat.
    display_test_stat(conf, scheduler, tracker_te, label)

    # get global (mean) performance.
    global_performance = tracker_te.evaluate_global_metrics()
    return global_performance
def validate(
    conf,
    coordinator,
    model,
    criterion,
    metrics,
    data_loader,
    label="test_loader",
    display=True,
):
    """A function for model evaluation."""
    if data_loader is None:
        return None

    # switch to evaluation mode.
    model.eval()

    # place the model to the device.
    if conf.graph.on_cuda:
        model = model.cuda()

    # evaluate on test_loader.
    tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names)

    for _input, _target in data_loader:
        # load data and check performance.
        data_batch = create_dataset.load_data_batch(
            conf, _input, _target, is_training=False
        )

        with torch.no_grad():
            inference(
                conf,
                model,
                criterion,
                metrics,
                data_batch,
                tracker_te,
                is_training=False,
            )

    # place back model to the cpu.
    if conf.graph.on_cuda:
        model = model.cpu()

    # display the test stat.
    perf = tracker_te()
    if label is not None:
        display_test_stat(conf, coordinator, tracker_te, label)
    if display:
        conf.logger.log(f"The validation performance = {perf}.")
    return perf
def _construct_input_via_expected_output_space(self, constructed_probs_and_labels):
    # generate the inputs based on these dirichlet distributions.
    generated_inputs, generated_probs = [], []
    model = agg_utils.modify_model_trainable_status(
        self.conf, self.model, trainable=False
    )
    criterion = torch.nn.KLDivLoss(reduction="batchmean")
    tracker = RuntimeTracker(metrics_to_track=[])

    # init the dataset for the training.
    dataset = CustomDataset(constructed_probs_and_labels)
    data_loader = create_data_loader(
        dataset, batch_size=int(self.conf.fl_aggregate["kt_g_batch_size_per_class"])
    )
    num_update_per_batch = int(self.conf.fl_aggregate["kt_data_generate_iters"])
    self.conf.logger.log(
        f"# of mini-batches={len(data_loader)}, size of mini-batch={self.conf.fl_aggregate['kt_g_batch_size_per_class']}, # of updates per mini-batch={num_update_per_batch}"
    )

    # train on the dataset.
    for batch_idx, probs in enumerate(data_loader):
        _generated_input = torch.rand(
            (len(probs), 3, 32, 32),
            requires_grad=True,
            device="cuda" if self.conf.graph.on_cuda else "cpu",
        )
        optimizer = torch.optim.Adam(
            [_generated_input],
            lr=self.conf.fl_aggregate["step_size"],
            betas=(self.conf.adam_beta_1, self.conf.adam_beta_2),
            eps=self.conf.adam_eps,
        )

        # improve the input data to mimic the output space of the network.
        for _ in range(num_update_per_batch):
            loss = update_input_data(
                self.conf,
                model,
                criterion,
                optimizer,
                _generated_input,
                expected_probs=probs.cuda() if self.conf.graph.on_cuda else probs,
            )
            tracker.update_metrics([loss.item()], n_samples=_generated_input.size(0))
        self.conf.logger.log(
            f"\t the data generation loss (model index={self.model_idx}, batch index={batch_idx}) = {tracker()}."
        )
        tracker.reset()
        generated_inputs.append(copy.deepcopy(_generated_input.data))
        generated_probs.append(probs)

    generated_inputs = torch.cat(generated_inputs, dim=0).data.cpu()
    generated_probs = torch.cat(generated_probs, dim=0).data.cpu()
    return generated_inputs, generated_probs
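# Illustration (not part of the original module): a minimal, self-contained sketch of the
# input-generation step above -- optimize a random input so that a frozen network's predictive
# distribution matches a target probability vector. The tiny network, image size, uniform
# target distribution, learning rate, and step count are made-up values for demonstration only.
import torch
import torch.nn.functional as F

frozen_net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
for p in frozen_net.parameters():
    p.requires_grad = False

expected_probs = torch.full((4, 10), 0.1)  # e.g. a uniform target distribution.
generated_input = torch.rand((4, 3, 32, 32), requires_grad=True)
optimizer = torch.optim.Adam([generated_input], lr=0.05)

for step in range(50):
    optimizer.zero_grad()
    log_probs = F.log_softmax(frozen_net(generated_input), dim=1)
    loss = F.kl_div(log_probs, expected_probs, reduction="batchmean")
    loss.backward()
    optimizer.step()  # only the input pixels are updated; the network stays fixed.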
def train_and_validate(
    conf, model, criterion, scheduler, optimizer, metrics, data_loader
):
    print("=>>>> start training and validation.\n")

    # define runtime stat tracker and start the training.
    tracker_tr = RuntimeTracker(
        metrics_to_track=metrics.metric_names, on_cuda=conf.graph.on_cuda
    )

    # get the timer.
    timer = conf.timer

    # loop until the expected number of full training epochs is finished.
    print("=>>>> enter the training.\n")
    while True:
        dist.barrier()

        # configure local step.
        for _input, _target in data_loader["train_loader"]:
            model.train()
            scheduler.step(optimizer)

            # load data.
            with timer("load_data", epoch=scheduler.epoch_):
                _input, _target = load_data_batch(conf, _input, _target)

            # inference and get current performance.
            with timer("forward_pass", epoch=scheduler.epoch_):
                optimizer.zero_grad()
                loss = inference(model, criterion, metrics, _input, _target, tracker_tr)

            with timer("backward_pass", epoch=scheduler.epoch_):
                loss.backward()

            with timer("sync_complete", epoch=scheduler.epoch_):
                n_bits_to_transmit = optimizer.step(timer=timer, scheduler=scheduler)

            # display the logging info.
            display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)

            # finished one epoch of training; decide if we want to validate the model.
            if scheduler.epoch_ % 1 == 0:
                if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
                    tracker_tr.stat["loss"].avg
                ):
                    print("\nThe process diverges!!!!! Early stop it.")
                    error_handler.abort()

                # each worker finishes one epoch of training.
                do_validate(
                    conf, model, optimizer, criterion, scheduler, metrics, data_loader
                )

                # refresh the logging cache at the beginning of each epoch.
                tracker_tr.reset()

                # evaluate (inference only) on the whole training loader.
                if (
                    conf.evaluate_consensus or scheduler.is_stop()
                ) and not conf.train_fast:
                    # prepare the dataloader for the consensus evaluation.
                    _data_loader = {
                        "val_loader": _define_cv_dataset(
                            conf,
                            partition_type=None,
                            dataset_type="train",
                            force_shuffle=True,
                        )
                    }

                    # evaluate on the local model.
                    conf.logger.log("eval the local model on full training data.")
                    validate(
                        conf,
                        model,
                        optimizer,
                        criterion,
                        scheduler,
                        metrics,
                        data_loader=_data_loader,
                        label="eval_local_model_on_full_training_data",
                        force_evaluate_on_averaged_model=False,
                    )

                    # evaluate on the averaged model.
                    conf.logger.log("eval the averaged model on full training data.")
                    copied_model = copy.deepcopy(
                        model.module
                        if "DataParallel" == model.__class__.__name__
                        else model
                    )
                    optimizer.world_aggregator.agg_model(copied_model, op="avg")
                    validate(
                        conf,
                        copied_model,
                        optimizer,
                        criterion,
                        scheduler,
                        metrics,
                        data_loader=_data_loader,
                        label="eval_averaged_model_on_full_training_data",
                        force_evaluate_on_averaged_model=False,
                    )

                # determine if the training is finished.
                if scheduler.is_stop():
                    # save json.
                    conf.logger.save_json()

                    # temporarily hack the exit for ParallelCHOCO.
                    if optimizer.__class__.__name__ == "ParallelCHOCO":
                        error_handler.abort()
                    return

            # display tracking time.
            if (
                conf.graph.rank == 0
                and conf.display_tracked_time
                and scheduler.local_index % conf.summary_freq == 0
            ):
                print(timer.summary())

        # reshuffle the data.
        if conf.reshuffle_per_epoch:
            print("\nReshuffle the dataset.")
            del data_loader
            gc.collect()
            data_loader = define_dataset(conf)
def distillation(self): # init the tracker. server_tracker = RuntimeTracker(metrics_to_track=["student_loss"], force_to_replace_metrics=True) # init the data iter. if self.distillation_data_loader is not None: data_iter = iter(self.distillation_data_loader) # get the client_weights from client's validation performance. client_weights = self._get_client_weights() # get the init server perf. init_perf_on_val = self.validate(model=self.server_student, data_loader=self.val_data_loader) # iterate over dataset n_pseudo_batches = 0 self.log_fn( f"Batch {n_pseudo_batches}/{self.total_n_server_pseudo_batches}: Student Validation Acc={init_perf_on_val}." ) while n_pseudo_batches < self.total_n_server_pseudo_batches: # get the inputs. if self.distillation_data_loader is not None: try: pseudo_data = next(data_iter)[0].to(device=self.device) except StopIteration: data_iter = iter(self.distillation_data_loader) pseudo_data = next(data_iter)[0].to(device=self.device) else: if self.conf.fl_aggregate["use_data_scheme"] == "random_data": pseudo_data = self._create_data_randomly() else: raise NotImplementedError("incorrect use_data_scheme.") # get the logits. with torch.no_grad(): teacher_logits = [ _teacher(pseudo_data) for _teacher in self.client_teachers ] # steps on the same pseudo data student_logits = self.server_student(pseudo_data) student_logits_activations = [ (student_logits, self.server_student.activations) ] * self.numb_teachers stud_avg_loss = self.update_student( student_logits_activations=student_logits_activations, base_solver=self.base_solver, _student=self.server_student, _teachers=self.client_teachers, _opt_student=self.swa_optimizer, teacher_logits=teacher_logits, update_student_scheme=self.update_student_scheme, weights=client_weights, ) # update the tracker after each batch. server_tracker.update_metrics([stud_avg_loss], n_samples=self.batch_size) if (n_pseudo_batches + 1) % self.eval_batches_freq == 0: validated_perf = self.validate( model=self.server_student, data_loader=self.val_data_loader) self.log_fn( f"Batch {n_pseudo_batches + 1}/{self.total_n_server_pseudo_batches}: Student Loss={server_tracker.stat['student_loss'].avg:02.5f}; Student Validation Acc={validated_perf}." ) server_tracker.reset() n_pseudo_batches += 1 # update the server model. self.swa_optimizer.swap_swa_sgd() self.server_student = self.server_student.cpu()
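# Illustration (not part of the original class): a minimal, self-contained version of the
# "restart the iterator on StopIteration" pattern used above to stream pseudo batches from a
# finite DataLoader indefinitely. The toy dataset, batch size, and step count are made up
# for demonstration only.
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.randn(10, 3)), batch_size=4)
data_iter = iter(loader)

for step in range(7):  # more steps than the loader has batches.
    try:
        (batch,) = next(data_iter)
    except StopIteration:
        data_iter = iter(loader)  # wrap around and keep serving batches.
        (batch,) = next(data_iter)
    print(step, batch.shape)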
class Worker(object):
    def __init__(self, conf):
        self.conf = conf

        # some initializations.
        self.rank = conf.graph.rank
        conf.graph.worker_id = conf.graph.rank
        self.device = torch.device("cuda" if self.conf.graph.on_cuda else "cpu")

        # define the timer for different operations.
        # if we choose the `train_fast` mode, then we will not track the time.
        self.timer = Timer(
            verbosity_level=1 if conf.track_time and not conf.train_fast else 0,
            log_fn=conf.logger.log_metric,
        )

        # create dataset (as well as the potential data_partitioner) for training.
        dist.barrier()
        self.dataset = create_dataset.define_dataset(conf, data=conf.data)
        _, self.data_partitioner = create_dataset.define_data_loader(
            self.conf,
            dataset=self.dataset["train"],
            localdata_id=0,  # random id here.
            is_train=True,
            data_partitioner=None,
        )
        conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} initialized the local training data with Master."
        )

        # define the criterion.
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

        # define the model compression operators.
        if conf.local_model_compression is not None:
            if conf.local_model_compression == "quantization":
                self.model_compression_fn = compressor.ModelQuantization(conf)

        conf.logger.log(
            f"Worker-{conf.graph.worker_id} initialized dataset/criterion.\n"
        )

    def run(self):
        while True:
            self._listen_to_master()

            # check if we need to terminate the training or not.
            if self._terminate_by_early_stopping():
                return

            self._recv_model_from_master()
            self._train()
            self._send_model_to_master()

            # check if we need to terminate the training or not.
            if self._terminate_by_complete_training():
                return

    def _listen_to_master(self):
        # listen to master, related to the function `_activate_selected_clients` in `master.py`.
        msg = torch.zeros((3, self.conf.n_participated))
        dist.broadcast(tensor=msg, src=0)
        self.conf.graph.client_id, self.conf.graph.comm_round, self.n_local_epochs = (
            msg[:, self.conf.graph.rank - 1].to(int).cpu().numpy().tolist()
        )

        # once we receive the signal, we init for the local training.
        self.arch, self.model = create_model.define_model(
            self.conf, to_consistent_model=False, client_id=self.conf.graph.client_id
        )
        self.model_state_dict = self.model.state_dict()
        self.model_tb = TensorBuffer(list(self.model_state_dict.values()))
        self.metrics = create_metrics.Metrics(self.model, task="classification")
        dist.barrier()

    def _recv_model_from_master(self):
        # related to the function `_send_model_to_selected_clients` in `master.py`.
        old_buffer = copy.deepcopy(self.model_tb.buffer)
        dist.recv(self.model_tb.buffer, src=0)
        new_buffer = copy.deepcopy(self.model_tb.buffer)
        self.model_tb.unpack(self.model_state_dict.values())
        self.model.load_state_dict(self.model_state_dict)
        random_reinit.random_reinit_model(self.conf, self.model)
        self.init_model = self._turn_off_grad(
            copy.deepcopy(self.model).to(self.device)
        )
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) received the model ({self.arch}) from Master. The model status {'is updated' if old_buffer.norm() != new_buffer.norm() else 'is not updated'}."
        )
        dist.barrier()

    def _train(self):
        self.model.train()

        # init the model and dataloader.
        if self.conf.graph.on_cuda:
            self.model = self.model.cuda()
        self.train_loader, _ = create_dataset.define_data_loader(
            self.conf,
            dataset=self.dataset["train"],
            # localdata_id starts from 0 to the # of clients - 1.
            # client_id starts from 1 to the # of clients.
            localdata_id=self.conf.graph.client_id - 1,
            is_train=True,
            data_partitioner=self.data_partitioner,
        )

        # define optimizer, scheduler and runtime tracker.
        self.optimizer = create_optimizer.define_optimizer(
            self.conf, model=self.model, optimizer_name=self.conf.optimizer
        )
        self.scheduler = create_scheduler.Scheduler(self.conf, optimizer=self.optimizer)
        self.tracker = RuntimeTracker(metrics_to_track=self.metrics.metric_names)
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) enters the local training phase (current communication rounds={self.conf.graph.comm_round})."
        )

        # efficient local training.
        if hasattr(self, "model_compression_fn"):
            self.model_compression_fn.compress_model(
                param_groups=self.optimizer.param_groups
            )

        # entering local updates; finish only after reaching the expected local_n_epochs.
        while True:
            for _input, _target in self.train_loader:
                # load data.
                with self.timer("load_data", epoch=self.scheduler.epoch_):
                    data_batch = create_dataset.load_data_batch(
                        self.conf, _input, _target, is_training=True
                    )

                # inference and get current performance.
                with self.timer("forward_pass", epoch=self.scheduler.epoch_):
                    self.optimizer.zero_grad()
                    loss, output = self._inference(data_batch)

                    # in case we need self-distillation to penalize the local training
                    # (avoid catastrophic forgetting).
                    loss = self._local_training_with_self_distillation(
                        loss, output, data_batch
                    )

                with self.timer("backward_pass", epoch=self.scheduler.epoch_):
                    loss.backward()
                    self._add_grad_from_prox_regularized_loss()
                    self.optimizer.step()
                    self.scheduler.step()

                # efficient local training.
                with self.timer("compress_model", epoch=self.scheduler.epoch_):
                    if hasattr(self, "model_compression_fn"):
                        self.model_compression_fn.compress_model(
                            param_groups=self.optimizer.param_groups
                        )

                # display the logging info.
                display_training_stat(self.conf, self.scheduler, self.tracker)

                # display tracking time.
                if (
                    self.conf.display_tracked_time
                    and self.scheduler.local_index % self.conf.summary_freq == 0
                ):
                    self.conf.logger.log(self.timer.summary())

                # check divergence.
                if self.tracker.stat["loss"].avg > 1e3 or np.isnan(
                    self.tracker.stat["loss"].avg
                ):
                    self.conf.logger.log(
                        f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) diverges!!!!! Early stop it."
                    )
                    self._terminate_comm_round()
                    return

                # check stopping condition.
                if self._is_finished_one_comm_round():
                    self._terminate_comm_round()
                    return

            # refresh the logging cache at the end of each epoch.
            self.tracker.reset()
            if self.conf.logger.meet_cache_limit():
                self.conf.logger.save_json()

    def _inference(self, data_batch):
        """Inference on the given model and get loss and accuracy."""
        # do the forward pass and get the output.
        output = self.model(data_batch["input"])

        # evaluate the output and get the loss, performance.
        if self.conf.use_mixup:
            loss = mixup.mixup_criterion(
                self.criterion,
                output,
                data_batch["target_a"],
                data_batch["target_b"],
                data_batch["mixup_lambda"],
            )
            performance_a = self.metrics.evaluate(loss, output, data_batch["target_a"])
            performance_b = self.metrics.evaluate(loss, output, data_batch["target_b"])
            performance = [
                data_batch["mixup_lambda"] * _a
                + (1 - data_batch["mixup_lambda"]) * _b
                for _a, _b in zip(performance_a, performance_b)
            ]
        else:
            loss = self.criterion(output, data_batch["target"])
            performance = self.metrics.evaluate(loss, output, data_batch["target"])

        # update tracker.
        if self.tracker is not None:
            self.tracker.update_metrics(
                [loss.item()] + performance, n_samples=data_batch["input"].size(0)
            )
        return loss, output

    def _add_grad_from_prox_regularized_loss(self):
        assert self.conf.local_prox_term >= 0
        if self.conf.local_prox_term != 0:
            assert self.conf.weight_decay == 0
            assert self.conf.optimizer == "sgd"
            assert self.conf.momentum_factor == 0

            for _param, _init_param in zip(
                self.model.parameters(), self.init_model.parameters()
            ):
                if _param.grad is not None:
                    _param.grad.data.add_(
                        (_param.data - _init_param.data) * self.conf.local_prox_term
                    )

    def _local_training_with_self_distillation(self, loss, output, data_batch):
        if self.conf.self_distillation > 0:
            loss = loss * (
                1 - self.conf.self_distillation
            ) + self.conf.self_distillation * self._divergence(
                student_logits=output / self.conf.self_distillation_temperature,
                teacher_logits=self.init_model(data_batch["input"])
                / self.conf.self_distillation_temperature,
            )
        return loss

    def _divergence(self, student_logits, teacher_logits):
        # forward KL.
        divergence = F.kl_div(
            F.log_softmax(student_logits, dim=1),
            F.softmax(teacher_logits, dim=1),
            reduction="batchmean",
        )
        return divergence

    def _turn_off_grad(self, model):
        for param in model.parameters():
            param.requires_grad = False
        return model

    def _send_model_to_master(self):
        dist.barrier()
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) sending the model ({self.arch}) back to Master."
        )
        flatten_model = TensorBuffer(list(self.model.state_dict().values()))
        dist.send(tensor=flatten_model.buffer, dst=0)
        dist.barrier()

    def _terminate_comm_round(self):
        self.model = self.model.cpu()
        del self.init_model
        self.scheduler.clean()
        self.conf.logger.save_json()
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) finished one round of federated learning: (comm_round={self.conf.graph.comm_round})."
        )

    def _terminate_by_early_stopping(self):
        if self.conf.graph.comm_round == -1:
            dist.barrier()
            self.conf.logger.log(
                f"Worker-{self.conf.graph.worker_id} finished the federated learning by early-stopping."
            )
            return True
        else:
            return False

    def _terminate_by_complete_training(self):
        if self.conf.graph.comm_round == self.conf.n_comm_rounds:
            dist.barrier()
            self.conf.logger.log(
                f"Worker-{self.conf.graph.worker_id} finished the federated learning: (total comm_rounds={self.conf.graph.comm_round})."
            )
            return True
        else:
            return False

    def _is_finished_one_comm_round(self):
        return True if self.conf.epoch_ >= self.conf.local_n_epochs else False
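# Illustration (not part of the original class): a minimal, self-contained check of the
# self-distillation blend used in `_local_training_with_self_distillation` above -- a convex
# combination of the supervised loss and a forward-KL term against temperature-scaled teacher
# logits. The weight, temperature, and random logits are made-up values for demonstration only.
import torch
import torch.nn.functional as F


def blended_loss_sketch(ce_loss, student_logits, teacher_logits, weight=0.3, temperature=3.0):
    # forward KL between temperature-softened teacher and student distributions.
    kd = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    )
    # blend the supervised loss and the distillation term.
    return (1 - weight) * ce_loss + weight * kd


logits = torch.randn(8, 10, requires_grad=True)
targets = torch.randint(0, 10, (8,))
teacher_logits = torch.randn(8, 10)
ce = F.cross_entropy(logits, targets)
print(blended_loss_sketch(ce, logits, teacher_logits).item())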
def train_and_validate( conf, model, criterion, scheduler, optimizer, metrics, data_loader ): print("=>>>> start training and validation.\n") # define runtime stat tracker and start the training. tracker_tr = RuntimeTracker(metrics_to_track=metrics.metric_names) # get the timer. timer = conf.timer # break until finish expected full epoch training. print("=>>>> enter the training.\n") while True: # init the hidden state. _hidden = ( model.module.init_hidden(conf.batch_size) if "DataParallel" == model.__class__.__name__ else model.init_hidden(conf.batch_size) ) # configure local step. for batch in data_loader["train_loader"]: model.train() # repackage the hidden. _hidden = ( model.module.repackage_hidden(_hidden) if "DataParallel" == model.__class__.__name__ else model.repackage_hidden(_hidden) ) # load data with timer("load_data", epoch=scheduler.epoch_): _input = batch.text[ :, conf.graph.rank * conf.batch_size : (conf.graph.rank + 1) * conf.batch_size, ] _target = batch.target[ :, conf.graph.rank * conf.batch_size : (conf.graph.rank + 1) * conf.batch_size, ] _input, _target = load_data_batch(conf, _input, _target) # inference and get current performance. with timer("forward_pass", epoch=scheduler.epoch_): optimizer.zero_grad() loss, _hidden = inference( conf, model, criterion, metrics, _input, _target, _hidden, tracker_tr, ) with timer("backward_pass", epoch=scheduler.epoch_): loss.backward() with timer("sync_complete", epoch=scheduler.epoch_): # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. torch.nn.utils.clip_grad_norm_(model.parameters(), conf.rnn_clip) n_bits_to_transmit = optimizer.step(timer=timer) scheduler.step() # display the logging info. display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit) # finish one epoch training and to decide if we want to val our model. if scheduler.epoch_ % 1 == 0: if tracker_tr.stat["loss"].avg > 1e3 or np.isnan( tracker_tr.stat["loss"].avg ): print("\nThe process diverges!!!!!Early stop it.") error_handler.abort() # each worker finish one epoch training. do_validate( conf, model, optimizer, criterion, scheduler, metrics, data_loader ) # refresh the logging cache at the begining of each epoch. tracker_tr.reset() # determine if the training is finished. if scheduler.is_stop(): conf.logger.save_json() return # display tracking time. if ( conf.graph.rank == 0 and conf.display_tracked_time and scheduler.local_index % conf.summary_freq == 0 ): print(timer.summary())
def train_and_validate( conf, model, criterion, scheduler, optimizer, metrics, data_loader ): print("=>>>> start training and validation.") # define runtime stat tracker and start the training. tracker_tr = RuntimeTracker( metrics_to_track=metrics.metric_names, on_cuda=conf.graph.on_cuda ) # get the timer. timer = conf.timer # break until finish expected full epoch training. print("=>>>> enter the training.\n") while True: dist.barrier() # configure local step. for _input, _target in data_loader["train_loader"]: model.train() # load data with timer("load_data", epoch=scheduler.epoch_): _input, _target = load_data_batch(conf, _input, _target) # inference and get current performance. with timer("forward_pass", epoch=scheduler.epoch_): optimizer.zero_grad() loss = inference(model, criterion, metrics, _input, _target, tracker_tr) with timer("backward_pass", epoch=scheduler.epoch_): loss.backward() with timer("sync_and_apply_grad", epoch=scheduler.epoch_): n_bits_to_transmit = optimizer.step(timer=timer, scheduler=scheduler) scheduler.step() # display the logging info. display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit) # finish one epoch training and to decide if we want to val our model. if scheduler.epoch_ % 1 == 0: if tracker_tr.stat["loss"].avg > 1e3 or np.isnan( tracker_tr.stat["loss"].avg ): print("\nThe process diverges!!!!!Early stop it.") error_handler.abort() # each worker finish one epoch training. do_validate( conf, model, optimizer, criterion, scheduler, metrics, data_loader ) # refresh the logging cache at the begining of each epoch. tracker_tr.reset() # determine if the training is finished. if scheduler.is_stop(): # save json. conf.logger.save_json() return # display tracking time. if ( conf.graph.rank == 0 and conf.display_tracked_time and scheduler.local_index % conf.summary_freq == 0 ): print(timer.summary()) # reshuffle the data. if conf.reshuffle_per_epoch: print("\nReshuffle the dataset.") del data_loader gc.collect() data_loader = define_dataset(conf)
def ensembled_validate(
    conf,
    coordinator,
    models,
    criterion,
    metrics,
    data_loader,
    label="test_loader",
    ensemble_scheme=None,
):
    """A function for model evaluation."""
    if data_loader is None:
        return None

    # switch to evaluation mode.
    for model in models:
        model.eval()

        # place the model to the device.
        if conf.graph.on_cuda:
            model = model.cuda()

    # evaluate on test_loader.
    tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names)

    for _input, _target in data_loader:
        # load data and check performance.
        data_batch = create_dataset.load_data_batch(
            conf, _input, _target, is_training=False
        )

        with torch.no_grad():
            # ensemble.
            if (
                ensemble_scheme is None
                or ensemble_scheme == "avg_losses"
                or ensemble_scheme == "avg_logits"
            ):
                outputs = []
                for model in models:
                    outputs.append(model(data_batch["input"]))
                output = sum(outputs) / len(outputs)
            elif ensemble_scheme == "avg_probs":
                outputs = []
                for model in models:
                    outputs.append(F.softmax(model(data_batch["input"]), dim=1))
                output = sum(outputs) / len(outputs)

            # eval the performance.
            loss = torch.FloatTensor([0])
            performance = metrics.evaluate(loss, output, data_batch["target"])

            # update the tracker.
            tracker_te.update_metrics(
                [loss.item()] + performance, n_samples=data_batch["input"].size(0)
            )

    # place back model to the cpu.
    for model in models:
        if conf.graph.on_cuda:
            model = model.cpu()

    # display the test stat.
    if label is not None:
        display_test_stat(conf, coordinator, tracker_te, label)
    perf = tracker_te()
    conf.logger.log(f"The performance of the ensembled model: {perf}.")
    return perf
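# Illustration (not part of the original module): a minimal, self-contained comparison of the
# two ensembling schemes above -- averaging raw logits vs. averaging softmax probabilities.
# The number of models, class count, and random logits are made-up values for demonstration only.
import torch
import torch.nn.functional as F

logits_per_model = [torch.randn(4, 10) for _ in range(3)]

# "avg_logits": average the raw outputs, then predict.
avg_logits = sum(logits_per_model) / len(logits_per_model)

# "avg_probs": turn each model's output into probabilities first, then average.
avg_probs = sum(F.softmax(l, dim=1) for l in logits_per_model) / len(logits_per_model)

# both give per-sample predictions; they can disagree because softmax is non-linear.
print(avg_logits.argmax(dim=1), avg_probs.argmax(dim=1))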
def training(conf, model, criterion, data_loaders, eps):
    # place the model on gpu.
    if conf.graph.on_cuda:
        model = model.cuda()

    # then train the averaged model on the created virtual data.
    optimizer = create_optimizer(conf, model)

    # init the training setup.
    epoch_count = 0
    final_model = copy.deepcopy(model)

    # init the recording status.
    if data_loaders["val_data_loader"] is not None:
        tracker_val = RuntimeTracker(metrics_to_track=[])
        for _ind, (_input, _target) in enumerate(data_loaders["val_data_loader"]):
            # place model and data.
            if conf.graph.on_cuda:
                _input, _target = _input.cuda(), _target.cuda()

            # inference and evaluate.
            model.eval()
            loss = criterion(model(_input), _target)
            tracker_val.update_metrics([loss.item()], n_samples=_input.size(0))
        tracking = {
            "tr_loss_last_epoch": float("inf"),
            "val_loss_last_epoch": tracker_val.stat["loss"].avg,
        }
    else:
        tracking = {
            "tr_loss_last_epoch": float("inf"),
            "val_loss_last_epoch": float("inf"),
        }
    conf.logger.log(
        f"finish {epoch_count} epoch on-server training: train={tracking['tr_loss_last_epoch']}, val={tracking['val_loss_last_epoch']}."
    )

    # on-server training and validation.
    while True:
        epoch_count += 1
        tracker_tr = RuntimeTracker(metrics_to_track=[])
        tracker_val = RuntimeTracker(metrics_to_track=[])

        # train on the tr_data_loader.
        for _ind, (_input, _target) in enumerate(data_loaders["tr_data_loader"]):
            # place model and data.
            if conf.graph.on_cuda:
                _input, _target = _input.cuda(), _target.cuda()

            # inference and update.
            model.train()
            optimizer.zero_grad()
            loss = criterion(model(_input), _target, smooth_eps=eps)
            tracker_tr.update_metrics([loss.item()], n_samples=_input.size(0))
            loss.backward()
            optimizer.step()

        # validate on the val_data_loader.
        if data_loaders["val_data_loader"] is not None:
            for _ind, (_input, _target) in enumerate(data_loaders["val_data_loader"]):
                # place model and data.
                if conf.graph.on_cuda:
                    _input, _target = _input.cuda(), _target.cuda()

                # inference and evaluate.
                model.eval()
                loss = criterion(model(_input), _target)
                tracker_val.update_metrics([loss.item()], n_samples=_input.size(0))

            # check the condition.
            if (
                tracker_tr.stat["loss"].avg < tracking["tr_loss_last_epoch"]
                and tracker_val.stat["loss"].avg < tracking["val_loss_last_epoch"]
            ):
                conf.logger.log(
                    f"finish {epoch_count} epoch on-server training: train={tracker_tr()}, val={tracker_val()}: will continue training."
                )
                final_model = copy.deepcopy(model)
            else:
                conf.logger.log(
                    f"finish {epoch_count} epoch on-server training: train={tracker_tr()}, val={tracker_val()}: will end training."
                )
                if conf.graph.on_cuda:
                    final_model = final_model.cpu()
                del model
                return final_model
        else:
            conf.logger.log(
                f"finish {epoch_count} epoch on-server training: {tracker_tr()}"
            )
            assert conf.fl_aggregate["epochs"] == "plateau"
            assert "epochs_max" in conf.fl_aggregate
            if (
                tracking["tr_loss_last_epoch"] - tracker_tr.stat["loss"].avg
                <= conf.fl_aggregate["plateau_tol"]
            ) or epoch_count >= conf.fl_aggregate["epochs_max"]:
                if conf.graph.on_cuda:
                    model = model.cpu()
                return model

        # update the tracking records.
        tracking = {
            "tr_loss_last_epoch": tracker_tr.stat["loss"].avg,
            "val_loss_last_epoch": tracker_val.stat["loss"].avg,
        }
def distillation(self): # init the tracker. server_tracker = RuntimeTracker(metrics_to_track=["student_loss"], force_to_replace_metrics=True) server_best_tracker = BestPerf(best_perf=None, larger_is_better=True) # update the server generator/student n_pseudo_batches = 0 best_models = [None] # init the data iter. if self.distillation_data_loader is not None: data_iter = iter(self.distillation_data_loader) # get the client_weights from client's validation performance. client_weights = self._get_client_weights() # get the init server perf. init_perf_on_val = self.validate(model=self.init_server_student, data_loader=self.val_data_loader) self.log_fn( f"Batch {n_pseudo_batches}/{self.total_n_server_pseudo_batches}: Student Validation Acc={init_perf_on_val}." ) # iterate over dataset while n_pseudo_batches < self.total_n_server_pseudo_batches: # get the inputs. if self.distillation_data_loader is not None: try: pseudo_data = next(data_iter)[0].to(device=self.device) except StopIteration: data_iter = iter(self.distillation_data_loader) pseudo_data = next(data_iter)[0].to(device=self.device) else: if self.conf.fl_aggregate["use_data_scheme"] == "random_data": pseudo_data = self._create_data_randomly() else: raise NotImplementedError("incorrect use_data_scheme.") # get the logits. with torch.no_grad(): teacher_logits = [ _teacher(pseudo_data) for _teacher in self.client_teachers ] # steps on the same pseudo data for _ in range(self.server_local_steps): student_logits = self.server_student(pseudo_data) student_logits_activations = [ (student_logits, self.server_student.activations) ] * self.numb_teachers stud_avg_loss = self.update_student( student_logits_activations=student_logits_activations, base_solver=self.base_solver, _student=self.server_student, _teachers=self.client_teachers, _opt_student=self.optimizer_server_student, teacher_logits=teacher_logits, update_student_scheme=self.update_student_scheme, weights=client_weights, ) # after each batch. if self.use_server_model_scheduler: self.scheduler_server_student.step() # update the tracker after each batch. server_tracker.update_metrics([stud_avg_loss], n_samples=self.batch_size) if (n_pseudo_batches + 1) % self.eval_batches_freq == 0: validated_perf = self.validate( model=self.server_student, data_loader=self.val_data_loader) self.log_fn( f"Batch {n_pseudo_batches + 1}/{self.total_n_server_pseudo_batches}: Student Loss={server_tracker.stat['student_loss'].avg:02.5f}; Student Validation Acc={validated_perf}." ) server_tracker.reset() # check early stopping. if self.base_solver.check_early_stopping( model=self.server_student, model_ind=0, best_tracker=server_best_tracker, validated_perf=validated_perf, validated_perfs=self.validated_perfs, perf_index=n_pseudo_batches + 1, early_stopping_batches=self. early_stopping_server_batches, best_models=best_models, ): break n_pseudo_batches += 1 # recover the best server model use_init_server_model = False if self.return_best_model_on_val: use_init_server_model = (True if init_perf_on_val["top1"] > server_best_tracker.best_perf else False) # get the server model. if use_init_server_model: self.log_fn("use init server model instead.") best_server_dict = self.init_server_student.state_dict() else: best_server_dict = best_models[0].state_dict() # update the server model. self.server_student.load_state_dict(best_server_dict) self.server_student = self.server_student.cpu()