def train_and_validate(
    conf, model, criterion, scheduler, optimizer, metrics, data_loader
):
    print("=>>>> start training and validation.\n")

    # define the runtime stat tracker and start the training.
    tracker_tr = RuntimeTracker(
        metrics_to_track=metrics.metric_names, on_cuda=conf.graph.on_cuda
    )

    # get the timer.
    timer = conf.timer

    # loop until the expected number of full epochs is finished.
    print("=>>>> enter the training.\n")
    while True:
        dist.barrier()

        # configure local step.
        for _input, _target in data_loader["train_loader"]:
            model.train()
            scheduler.step(optimizer)

            # load data.
            with timer("load_data", epoch=scheduler.epoch_):
                _input, _target = load_data_batch(conf, _input, _target)

            # inference and get the current performance.
            with timer("forward_pass", epoch=scheduler.epoch_):
                optimizer.zero_grad()
                loss = inference(model, criterion, metrics, _input, _target, tracker_tr)

            with timer("backward_pass", epoch=scheduler.epoch_):
                loss.backward()

            with timer("sync_complete", epoch=scheduler.epoch_):
                n_bits_to_transmit = optimizer.step(timer=timer, scheduler=scheduler)

            # display the logging info.
            display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)

            # finished one epoch of training; decide whether to validate the model.
            if scheduler.epoch_ % 1 == 0:
                if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
                    tracker_tr.stat["loss"].avg
                ):
                    print("\nThe process diverges! Early stop it.")
                    error_handler.abort()

                # each worker finishes one epoch of training.
                do_validate(
                    conf, model, optimizer, criterion, scheduler, metrics, data_loader
                )

                # refresh the logging cache at the beginning of each epoch.
                tracker_tr.reset()

                # evaluate (inference only) on the whole training loader.
                if (
                    conf.evaluate_consensus or scheduler.is_stop()
                ) and not conf.train_fast:
                    # prepare the dataloader for the consensus evaluation.
                    _data_loader = {
                        "val_loader": _define_cv_dataset(
                            conf,
                            partition_type=None,
                            dataset_type="train",
                            force_shuffle=True,
                        )
                    }

                    # evaluate on the local model.
                    conf.logger.log("eval the local model on full training data.")
                    validate(
                        conf,
                        model,
                        optimizer,
                        criterion,
                        scheduler,
                        metrics,
                        data_loader=_data_loader,
                        label="eval_local_model_on_full_training_data",
                        force_evaluate_on_averaged_model=False,
                    )

                    # evaluate on the averaged model.
                    conf.logger.log("eval the averaged model on full training data.")
                    copied_model = copy.deepcopy(
                        model.module
                        if "DataParallel" == model.__class__.__name__
                        else model
                    )
                    optimizer.world_aggregator.agg_model(copied_model, op="avg")
                    validate(
                        conf,
                        copied_model,
                        optimizer,
                        criterion,
                        scheduler,
                        metrics,
                        data_loader=_data_loader,
                        label="eval_averaged_model_on_full_training_data",
                        force_evaluate_on_averaged_model=False,
                    )

                # determine if the training is finished.
                if scheduler.is_stop():
                    # save json.
                    conf.logger.save_json()

                    # temporary hack to exit ParallelCHOCO.
                    if optimizer.__class__.__name__ == "ParallelCHOCO":
                        error_handler.abort()
                    return

            # display the tracked time.
            if (
                conf.graph.rank == 0
                and conf.display_tracked_time
                and scheduler.local_index % conf.summary_freq == 0
            ):
                print(timer.summary())

        # reshuffle the data.
        if conf.reshuffle_per_epoch:
            print("\nReshuffle the dataset.")
            del data_loader
            gc.collect()
            data_loader = define_dataset(conf)
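
# Illustrative sketch (not part of the original code): the consensus evaluation
# above calls `optimizer.world_aggregator.agg_model(copied_model, op="avg")`,
# which is assumed to average model parameters element-wise across all workers.
# Below is a minimal sketch of such an averaging step using plain
# torch.distributed primitives; the helper name `average_model_across_workers`
# is hypothetical.
import copy

import torch
import torch.distributed as dist


def average_model_across_workers(model):
    """Return a deep copy of `model` whose parameters are averaged over all ranks."""
    averaged_model = copy.deepcopy(model)
    world_size = float(dist.get_world_size())
    with torch.no_grad():
        for param in averaged_model.parameters():
            # sum each parameter across all workers, then divide by the world size.
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data.div_(world_size)
    return averaged_model
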
def train_and_validate(
    conf, model, criterion, scheduler, optimizer, metrics, data_loader
):
    print("=>>>> start training and validation.\n")

    # define the runtime stat tracker and start the training.
    tracker_tr = RuntimeTracker(metrics_to_track=metrics.metric_names)

    # get the timer.
    timer = conf.timer

    # loop until the expected number of full epochs is finished.
    print("=>>>> enter the training.\n")
    while True:
        # init the hidden state.
        _hidden = (
            model.module.init_hidden(conf.batch_size)
            if "DataParallel" == model.__class__.__name__
            else model.init_hidden(conf.batch_size)
        )

        # configure local step.
        for batch in data_loader["train_loader"]:
            model.train()

            # repackage the hidden state.
            _hidden = (
                model.module.repackage_hidden(_hidden)
                if "DataParallel" == model.__class__.__name__
                else model.repackage_hidden(_hidden)
            )

            # load data: each worker takes its own slice of the global batch.
            with timer("load_data", epoch=scheduler.epoch_):
                start = conf.graph.rank * conf.batch_size
                end = (conf.graph.rank + 1) * conf.batch_size
                _input = batch.text[:, start:end]
                _target = batch.target[:, start:end]
                _input, _target = load_data_batch(conf, _input, _target)

            # inference and get the current performance.
            with timer("forward_pass", epoch=scheduler.epoch_):
                optimizer.zero_grad()
                loss, _hidden = inference(
                    conf,
                    model,
                    criterion,
                    metrics,
                    _input,
                    _target,
                    _hidden,
                    tracker_tr,
                )

            with timer("backward_pass", epoch=scheduler.epoch_):
                loss.backward()

            with timer("sync_complete", epoch=scheduler.epoch_):
                # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
                torch.nn.utils.clip_grad_norm_(model.parameters(), conf.rnn_clip)
                n_bits_to_transmit = optimizer.step(timer=timer)
                scheduler.step()

            # display the logging info.
            display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)

            # finished one epoch of training; decide whether to validate the model.
            if scheduler.epoch_ % 1 == 0:
                if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
                    tracker_tr.stat["loss"].avg
                ):
                    print("\nThe process diverges! Early stop it.")
                    error_handler.abort()

                # each worker finishes one epoch of training.
                do_validate(
                    conf, model, optimizer, criterion, scheduler, metrics, data_loader
                )

                # refresh the logging cache at the beginning of each epoch.
                tracker_tr.reset()

                # determine if the training is finished.
                if scheduler.is_stop():
                    conf.logger.save_json()
                    return

            # display the tracked time.
            if (
                conf.graph.rank == 0
                and conf.display_tracked_time
                and scheduler.local_index % conf.summary_freq == 0
            ):
                print(timer.summary())
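
# Illustrative sketch (not part of the original code): `repackage_hidden` above
# is assumed to detach the recurrent hidden state from the previous computation
# graph, so that back-propagation is truncated at the batch boundary (standard
# truncated BPTT for RNN / LSTM language models).
import torch


def repackage_hidden(hidden):
    """Detach hidden states from their computation history."""
    if isinstance(hidden, torch.Tensor):
        return hidden.detach()
    # LSTMs return a tuple (h, c); recurse over nested containers.
    return tuple(repackage_hidden(h) for h in hidden)
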
def _train(self):
    self.model.train()

    # init the model and dataloader.
    if self.conf.graph.on_cuda:
        self.model = self.model.cuda()
    self.train_loader, _ = create_dataset.define_data_loader(
        self.conf,
        dataset=self.dataset["train"],
        # localdata_id starts from 0 to the # of clients - 1.
        # client_id starts from 1 to the # of clients.
        localdata_id=self.conf.graph.client_id - 1,
        is_train=True,
        data_partitioner=self.data_partitioner,
    )

    # define the optimizer, scheduler and runtime tracker.
    self.optimizer = create_optimizer.define_optimizer(
        self.conf, model=self.model, optimizer_name=self.conf.optimizer
    )
    self.scheduler = create_scheduler.Scheduler(self.conf, optimizer=self.optimizer)
    self.tracker = RuntimeTracker(metrics_to_track=self.metrics.metric_names)
    self.conf.logger.log(
        f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) enters the local training phase (current communication round={self.conf.graph.comm_round})."
    )

    # efficient local training.
    if hasattr(self, "model_compression_fn"):
        self.model_compression_fn.compress_model(
            param_groups=self.optimizer.param_groups
        )

    # enter the local updates; finish only after reaching the expected local_n_epochs.
    while True:
        for _input, _target in self.train_loader:
            # load data.
            with self.timer("load_data", epoch=self.scheduler.epoch_):
                data_batch = create_dataset.load_data_batch(
                    self.conf, _input, _target, is_training=True
                )

            # inference and get the current performance.
            with self.timer("forward_pass", epoch=self.scheduler.epoch_):
                self.optimizer.zero_grad()
                loss, output = self._inference(data_batch)

                # in case we need self-distillation to penalize the local training
                # (avoid catastrophic forgetting).
                loss = self._local_training_with_self_distillation(
                    loss, output, data_batch
                )

            with self.timer("backward_pass", epoch=self.scheduler.epoch_):
                loss.backward()
                self._add_grad_from_prox_regularized_loss()
                self.optimizer.step()
                self.scheduler.step()

            # efficient local training.
            with self.timer("compress_model", epoch=self.scheduler.epoch_):
                if hasattr(self, "model_compression_fn"):
                    self.model_compression_fn.compress_model(
                        param_groups=self.optimizer.param_groups
                    )

            # display the logging info.
            display_training_stat(self.conf, self.scheduler, self.tracker)

            # display the tracked time.
            if (
                self.conf.display_tracked_time
                and self.scheduler.local_index % self.conf.summary_freq == 0
            ):
                self.conf.logger.log(self.timer.summary())

            # check divergence.
            if self.tracker.stat["loss"].avg > 1e3 or np.isnan(
                self.tracker.stat["loss"].avg
            ):
                self.conf.logger.log(
                    f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) diverges! Early stop it."
                )
                self._terminate_comm_round()
                return

            # check the stopping condition.
            if self._is_finished_one_comm_round():
                self._terminate_comm_round()
                return

        # refresh the logging cache at the end of each epoch.
        self.tracker.reset()
        if self.conf.logger.meet_cache_limit():
            self.conf.logger.save_json()
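
# Illustrative sketch (not part of the original code): one common realization of
# `_add_grad_from_prox_regularized_loss` above is a FedProx-style proximal term
# that pulls the local model towards the global model received at the start of
# the communication round. The names `init_model` and `prox_mu` are assumptions
# made for this sketch.
import torch


def add_prox_grad(model, init_model, prox_mu):
    """Add the gradient of (prox_mu / 2) * ||w - w_global||^2 to each parameter's grad."""
    with torch.no_grad():
        for param, init_param in zip(model.parameters(), init_model.parameters()):
            if param.grad is not None:
                param.grad.add_(prox_mu * (param.data - init_param.data))
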
def train_and_validate(
    conf, model, criterion, scheduler, optimizer, metrics, data_loader
):
    print("=>>>> start training and validation.")

    # define the runtime stat tracker and start the training.
    tracker_tr = RuntimeTracker(
        metrics_to_track=metrics.metric_names, on_cuda=conf.graph.on_cuda
    )

    # get the timer.
    timer = conf.timer

    # loop until the expected number of full epochs is finished.
    print("=>>>> enter the training.\n")
    while True:
        dist.barrier()

        # configure local step.
        for _input, _target in data_loader["train_loader"]:
            model.train()

            # load data.
            with timer("load_data", epoch=scheduler.epoch_):
                _input, _target = load_data_batch(conf, _input, _target)

            # inference and get the current performance.
            with timer("forward_pass", epoch=scheduler.epoch_):
                optimizer.zero_grad()
                loss = inference(model, criterion, metrics, _input, _target, tracker_tr)

            with timer("backward_pass", epoch=scheduler.epoch_):
                loss.backward()

            with timer("sync_and_apply_grad", epoch=scheduler.epoch_):
                n_bits_to_transmit = optimizer.step(timer=timer, scheduler=scheduler)
                scheduler.step()

            # display the logging info.
            display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)

            # finished one epoch of training; decide whether to validate the model.
            if scheduler.epoch_ % 1 == 0:
                if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
                    tracker_tr.stat["loss"].avg
                ):
                    print("\nThe process diverges! Early stop it.")
                    error_handler.abort()

                # each worker finishes one epoch of training.
                do_validate(
                    conf, model, optimizer, criterion, scheduler, metrics, data_loader
                )

                # refresh the logging cache at the beginning of each epoch.
                tracker_tr.reset()

                # determine if the training is finished.
                if scheduler.is_stop():
                    # save json.
                    conf.logger.save_json()
                    return

            # display the tracked time.
            if (
                conf.graph.rank == 0
                and conf.display_tracked_time
                and scheduler.local_index % conf.summary_freq == 0
            ):
                print(timer.summary())

        # reshuffle the data.
        if conf.reshuffle_per_epoch:
            print("\nReshuffle the dataset.")
            del data_loader
            gc.collect()
            data_loader = define_dataset(conf)
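
# Illustrative sketch (not part of the original code): the `inference` helper
# used above presumably runs the forward pass, evaluates the loss and metrics,
# and feeds them to the runtime tracker. The `metrics.evaluate` and
# `tracker.update_metrics` signatures below are assumptions based on how
# `tracker_tr` is used in the training loop.
def inference(model, criterion, metrics, _input, _target, tracker=None):
    """Forward pass plus loss/metric bookkeeping for one mini-batch."""
    output = model(_input)
    loss = criterion(output, _target)
    performance = metrics.evaluate(loss, output, _target)
    if tracker is not None:
        tracker.update_metrics([loss.item()] + performance, n_samples=_input.size(0))
    return loss
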