def main(conf):
    try:
        init_distributed_world(conf, backend=conf.backend)
        conf.distributed = conf.n_mpi_process > 1
    except AttributeError as e:
        print(f"failed to init the distributed world: {e}.")
        conf.distributed = False

    # init the config.
    init_config(conf)

    # define the timer for different operations.
    # if we choose the `train_fast` mode, then we will not track the time.
    conf.timer = Timer(
        verbosity_level=1 if conf.track_time and not conf.train_fast else 0,
        log_fn=conf.logger.log_metric,
        on_cuda=conf.on_cuda,
    )

    # create the dataset.
    data_loader = create_dataset.define_dataset(conf, force_shuffle=True)

    # create the model.
    model = create_model.define_model(conf, data_loader=data_loader)

    # define the optimizer.
    optimizer = create_optimizer.define_optimizer(conf, model)

    # define the lr scheduler.
    scheduler = create_scheduler.Scheduler(conf, optimizer)

    # wrap the model with the data-parallel wrapper (multiple local sub-processes on GPU).
    if conf.graph.on_cuda:
        if conf.n_sub_process > 1:
            model = torch.nn.DataParallel(model, device_ids=conf.graph.device)

    # (optional) reload from a checkpoint.
    try:
        checkpoint.maybe_resume_from_checkpoint(conf, model, optimizer, scheduler)
    except RuntimeError as e:
        conf.logger.log(f"Resume Error: {e}")
        conf.resumed = False

    # train and evaluate the model.
    if "rnn_lm" in conf.arch:
        from pcode.distributed_running_nlp import train_and_validate

        # safety check.
        assert (
            conf.n_sub_process == 1
        ), "our current data-parallel wrapper does not support RNN."

        # define the criterion and metrics.
        criterion = nn.CrossEntropyLoss(reduction="mean")
        criterion = criterion.cuda() if conf.graph.on_cuda else criterion
        metrics = create_metrics.Metrics(
            model.module if "DataParallel" == model.__class__.__name__ else model,
            task="language_modeling",
        )

        # define the best_perf tracker, either empty or from the checkpoint.
        best_tracker = stat_tracker.BestPerf(
            best_perf=None if "best_perf" not in conf else conf.best_perf,
            larger_is_better=False,
        )
        scheduler.set_best_tracker(best_tracker)

        # get the train_and_validate function.
        train_and_validate_fn = train_and_validate
    else:
        from pcode.distributed_running_cv import train_and_validate

        # define the criterion and metrics.
        criterion = nn.CrossEntropyLoss(reduction="mean")
        criterion = criterion.cuda() if conf.graph.on_cuda else criterion
        metrics = create_metrics.Metrics(
            model.module if "DataParallel" == model.__class__.__name__ else model,
            task="classification",
        )

        # define the best_perf tracker, either empty or from the checkpoint.
        best_tracker = stat_tracker.BestPerf(
            best_perf=None if "best_perf" not in conf else conf.best_perf,
            larger_is_better=True,
        )
        scheduler.set_best_tracker(best_tracker)

        # get the train_and_validate function.
        train_and_validate_fn = train_and_validate

    # save the arguments to disk.
    checkpoint.save_arguments(conf)

    # start training.
    train_and_validate_fn(
        conf,
        model=model,
        criterion=criterion,
        scheduler=scheduler,
        optimizer=optimizer,
        metrics=metrics,
        data_loader=data_loader,
    )
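
# --- Illustrative sketch (not from the original repository) ---
# main() builds a best-performance tracker from pcode's stat_tracker with
# larger_is_better=False for the RNN language model (lower validation loss /
# perplexity is better) and larger_is_better=True for image classification
# (higher accuracy is better). The class below is a minimal, hypothetical
# stand-in that shows the behaviour such a tracker needs; the actual
# stat_tracker.BestPerf implementation may differ.
class BestPerfSketch:
    def __init__(self, best_perf=None, larger_is_better=True):
        self.best_perf = best_perf
        self.larger_is_better = larger_is_better
        self.is_best = False

    def update(self, perf):
        # whether a new value is "better" depends on the metric direction.
        improved = self.best_perf is None or (
            perf > self.best_perf if self.larger_is_better else perf < self.best_perf
        )
        self.is_best = improved
        if improved:
            self.best_perf = perf
        return improved


# Example: accuracy-style tracking (larger is better).
# tracker = BestPerfSketch(larger_is_better=True)
# tracker.update(0.71)  # True: the first value is always the best so far.
# tracker.update(0.69)  # False: 0.71 remains the best.
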
def _train(self):
    self.model.train()

    # init the model and the data loader.
    if self.conf.graph.on_cuda:
        self.model = self.model.cuda()
    self.train_loader, _ = create_dataset.define_data_loader(
        self.conf,
        dataset=self.dataset["train"],
        # localdata_id ranges from 0 to (# of clients - 1);
        # client_id ranges from 1 to the # of clients.
        localdata_id=self.conf.graph.client_id - 1,
        is_train=True,
        data_partitioner=self.data_partitioner,
    )

    # define the optimizer, scheduler and runtime tracker.
    self.optimizer = create_optimizer.define_optimizer(
        self.conf, model=self.model, optimizer_name=self.conf.optimizer
    )
    self.scheduler = create_scheduler.Scheduler(self.conf, optimizer=self.optimizer)
    self.tracker = RuntimeTracker(metrics_to_track=self.metrics.metric_names)
    self.conf.logger.log(
        f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) "
        f"enters the local training phase "
        f"(current communication round={self.conf.graph.comm_round})."
    )

    # efficient local training: (optionally) compress the model before training.
    if hasattr(self, "model_compression_fn"):
        self.model_compression_fn.compress_model(
            param_groups=self.optimizer.param_groups
        )

    # enter the local updates; finish only after reaching the expected local_n_epochs.
    while True:
        for _input, _target in self.train_loader:
            # load the data batch.
            with self.timer("load_data", epoch=self.scheduler.epoch_):
                data_batch = create_dataset.load_data_batch(
                    self.conf, _input, _target, is_training=True
                )

            # inference and get the current performance.
            with self.timer("forward_pass", epoch=self.scheduler.epoch_):
                self.optimizer.zero_grad()
                loss, output = self._inference(data_batch)

                # in case we need self-distillation to penalize the local training
                # (avoid catastrophic forgetting).
                loss = self._local_training_with_self_distillation(
                    loss, output, data_batch
                )

            with self.timer("backward_pass", epoch=self.scheduler.epoch_):
                loss.backward()
                self._add_grad_from_prox_regularized_loss()
                self.optimizer.step()
                self.scheduler.step()

            # efficient local training: (optionally) compress the model after the update.
            with self.timer("compress_model", epoch=self.scheduler.epoch_):
                if hasattr(self, "model_compression_fn"):
                    self.model_compression_fn.compress_model(
                        param_groups=self.optimizer.param_groups
                    )

            # display the logging info.
            display_training_stat(self.conf, self.scheduler, self.tracker)

            # display the tracked time.
            if (
                self.conf.display_tracked_time
                and self.scheduler.local_index % self.conf.summary_freq == 0
            ):
                self.conf.logger.log(self.timer.summary())

            # check divergence.
            if self.tracker.stat["loss"].avg > 1e3 or np.isnan(
                self.tracker.stat["loss"].avg
            ):
                self.conf.logger.log(
                    f"Worker-{self.conf.graph.worker_id} "
                    f"(client-{self.conf.graph.client_id}) diverges; "
                    f"early stopping this communication round."
                )
                self._terminate_comm_round()
                return

            # check the stopping condition.
            if self._is_finished_one_comm_round():
                self._terminate_comm_round()
                return

        # refresh the logging cache at the end of each epoch.
        self.tracker.reset()
        if self.conf.logger.meet_cache_limit():
            self.conf.logger.save_json()
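
# --- Illustrative sketch (assumption; not the repository's helper) ---
# _add_grad_from_prox_regularized_loss() above is called right after
# loss.backward(), which is consistent with a FedProx-style proximal term that
# pulls the local model towards the global weights received at the start of the
# communication round. One common way to realize the term
# (mu / 2) * ||w_local - w_global||^2 is to add mu * (w_local - w_global)
# directly to the parameter gradients before the optimizer step, as sketched
# below; the function name, signature, and default mu are hypothetical.
import torch


def add_prox_grad(local_params, global_params, mu=0.01):
    """Add the gradient of a proximal term to already-computed gradients."""
    with torch.no_grad():
        for p_local, p_global in zip(local_params, global_params):
            if p_local.grad is not None:
                # d/dw [ (mu / 2) * ||w - w_global||^2 ] = mu * (w - w_global)
                p_local.grad.add_(p_local.data - p_global.data, alpha=mu)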