def __init__(self, conf):
    self.conf = conf

    # some initializations.
    self.rank = conf.graph.rank
    conf.graph.worker_id = conf.graph.rank
    self.device = torch.device("cuda" if self.conf.graph.on_cuda else "cpu")

    # define the timer for different operations.
    # if we choose the `train_fast` mode, then we will not track the time.
    self.timer = Timer(
        verbosity_level=1 if conf.track_time and not conf.train_fast else 0,
        log_fn=conf.logger.log_metric,
    )

    # create dataset (as well as the potential data_partitioner) for training.
    dist.barrier()
    self.dataset = create_dataset.define_dataset(conf, data=conf.data)
    _, self.data_partitioner = create_dataset.define_data_loader(
        self.conf,
        dataset=self.dataset["train"],
        localdata_id=0,  # random id here.
        is_train=True,
        data_partitioner=None,
    )
    conf.logger.log(
        f"Worker-{self.conf.graph.worker_id} initialized the local training data with Master."
    )

    # define the criterion.
    self.criterion = nn.CrossEntropyLoss(reduction="mean")

    # define the model compression operators.
    if conf.local_model_compression is not None:
        if conf.local_model_compression == "quantization":
            self.model_compression_fn = compressor.ModelQuantization(conf)

    conf.logger.log(f"Worker-{conf.graph.worker_id} initialized dataset/criterion.\n")
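
# NOTE: the `data_partitioner` built above splits the training-set indices
# across clients so that each worker trains on its own shard. The repo's real
# `create_dataset.define_data_loader` is not reproduced in this section; the
# sketch below is a minimal, hypothetical illustration of that partitioning
# pattern (the names `Partition` and `DataPartitioner` are assumptions, not
# the repo's API).
import random


class Partition(object):
    """A dataset view restricted to a subset of indices."""

    def __init__(self, data, indices):
        self.data = data
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.data[self.indices[i]]


class DataPartitioner(object):
    """Shuffle the indices once, then cut them into per-client shards."""

    def __init__(self, data, sizes, seed=7):
        indices = list(range(len(data)))
        random.Random(seed).shuffle(indices)
        self.data = data
        self.partitions = []
        for ratio in sizes:
            n = int(ratio * len(data))
            self.partitions.append(indices[:n])
            indices = indices[n:]

    def use(self, partition_id):
        return Partition(self.data, self.partitions[partition_id])


# usage sketch: four equally-sized client shards.
# partitioner = DataPartitioner(train_dataset, sizes=[0.25] * 4)
# local_dataset = partitioner.use(conf.graph.worker_id)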
def __init__(self, conf):
    self.conf = conf

    # some initializations.
    self.client_ids = list(range(1, 1 + conf.n_clients))
    self.world_ids = list(range(1, 1 + conf.n_participated))

    # create the master model as well as the client model templates.
    _, self.master_model = create_model.define_model(conf, to_consistent_model=False)
    self.used_client_archs = set(
        create_model.determine_arch(conf, client_id, use_complex_arch=True)
        for client_id in range(1, 1 + conf.n_clients)
    )
    self.conf.used_client_archs = self.used_client_archs
    conf.logger.log(f"The clients will use archs={self.used_client_archs}.")
    conf.logger.log("Master created model templates for client models.")
    # one model template per architecture; `define_model` returns an (arch, model) pair.
    self.client_models = dict(
        create_model.define_model(conf, to_consistent_model=False, arch=arch)
        for arch in self.used_client_archs
    )
    self.clientid2arch = {
        client_id: create_model.determine_arch(
            conf, client_id=client_id, use_complex_arch=True
        )
        for client_id in range(1, 1 + conf.n_clients)
    }
    self.conf.clientid2arch = self.clientid2arch
    conf.logger.log(
        f"Master initialized the clientid2arch mapping: {self.clientid2arch}."
    )

    # create dataset (as well as the potential data_partitioner) for training.
    dist.barrier()
    self.dataset = create_dataset.define_dataset(conf, data=conf.data)
    _, self.data_partitioner = create_dataset.define_data_loader(
        self.conf,
        dataset=self.dataset["train"],
        localdata_id=0,  # random id here.
        is_train=True,
        data_partitioner=None,
    )
    conf.logger.log("Master initialized the local training data with workers.")

    # create val loader.
    # right now we just ignore the case of partitioned_by_user.
    if self.dataset["val"] is not None:
        assert not conf.partitioned_by_user
        self.val_loader, _ = create_dataset.define_data_loader(
            conf, self.dataset["val"], is_train=False
        )
        conf.logger.log("Master initialized val data.")
    else:
        self.val_loader = None

    # create test loaders.
    # localdata_id ranges from 0 to n_clients - 1; client_id ranges from 1 to n_clients.
    if conf.partitioned_by_user:
        self.test_loaders = []
        for localdata_id in self.client_ids:
            test_loader, _ = create_dataset.define_data_loader(
                conf,
                self.dataset["test"],
                localdata_id=localdata_id - 1,
                is_train=False,
                shuffle=False,
            )
            self.test_loaders.append(copy.deepcopy(test_loader))
    else:
        test_loader, _ = create_dataset.define_data_loader(
            conf, self.dataset["test"], is_train=False
        )
        self.test_loaders = [test_loader]

    # define the criterion and metrics.
    self.criterion = cross_entropy.CrossEntropyLoss(reduction="mean")
    self.metrics = create_metrics.Metrics(self.master_model, task="classification")
    conf.logger.log("Master initialized model/dataset/criterion/metrics.")

    # define the aggregators.
    self.aggregator = create_aggregator.Aggregator(
        conf,
        model=self.master_model,
        criterion=self.criterion,
        metrics=self.metrics,
        dataset=self.dataset,
        test_loaders=self.test_loaders,
        clientid2arch=self.clientid2arch,
    )
    self.coordinator = create_coordinator.Coordinator(conf, self.metrics)
    conf.logger.log("Master initialized the aggregator/coordinator.\n")

    # define early_stopping_tracker.
    self.early_stopping_tracker = EarlyStoppingTracker(
        patience=conf.early_stopping_rounds
    )

    # save arguments to disk.
    conf.is_finished = False
    checkpoint.save_arguments(conf)
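
# NOTE: `EarlyStoppingTracker(patience=...)` above stops the federated rounds
# once the tracked metric stops improving. Its implementation is not shown in
# this section; the following is a minimal sketch of a patience-based tracker
# under that assumption (the call signature is hypothetical).
class EarlyStoppingTracker(object):
    def __init__(self, patience, larger_is_better=True):
        self.patience = patience
        self.larger_is_better = larger_is_better
        self.best_perf = None
        self.n_bad_rounds = 0

    def __call__(self, perf):
        """Update with the latest performance; return True to stop early."""
        if self.patience is None or self.patience <= 0:
            return False  # early stopping disabled.
        improved = (
            self.best_perf is None
            or (self.larger_is_better and perf > self.best_perf)
            or (not self.larger_is_better and perf < self.best_perf)
        )
        if improved:
            self.best_perf = perf
            self.n_bad_rounds = 0
        else:
            self.n_bad_rounds += 1
        return self.n_bad_rounds >= self.patience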
def train_and_validate(
    conf, model, criterion, scheduler, optimizer, metrics, data_loader
):
    print("=>>>> start training and validation.\n")

    # define the runtime stat tracker and start the training.
    tracker_tr = RuntimeTracker(
        metrics_to_track=metrics.metric_names, on_cuda=conf.graph.on_cuda
    )

    # get the timer.
    timer = conf.timer

    # loop until the expected number of epochs is finished.
    print("=>>>> enter the training.\n")
    while True:
        dist.barrier()

        # configure local step.
        for _input, _target in data_loader["train_loader"]:
            model.train()
            scheduler.step(optimizer)

            # load data.
            with timer("load_data", epoch=scheduler.epoch_):
                _input, _target = load_data_batch(conf, _input, _target)

            # inference and get the current performance.
            with timer("forward_pass", epoch=scheduler.epoch_):
                optimizer.zero_grad()
                loss = inference(model, criterion, metrics, _input, _target, tracker_tr)

            with timer("backward_pass", epoch=scheduler.epoch_):
                loss.backward()

            with timer("sync_complete", epoch=scheduler.epoch_):
                n_bits_to_transmit = optimizer.step(timer=timer, scheduler=scheduler)

            # display the logging info.
            display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)

            # finished one epoch of training; decide whether to validate the model.
            if scheduler.epoch_ % 1 == 0:
                if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
                    tracker_tr.stat["loss"].avg
                ):
                    print("\nThe training diverged. Early stop it.")
                    error_handler.abort()

                # each worker finishes one epoch of training.
                do_validate(
                    conf, model, optimizer, criterion, scheduler, metrics, data_loader
                )

                # refresh the logging cache at the beginning of each epoch.
                tracker_tr.reset()

                # evaluate (inference only) on the whole training loader.
                if (
                    conf.evaluate_consensus or scheduler.is_stop()
                ) and not conf.train_fast:
                    # prepare the dataloader for the consensus evaluation.
                    _data_loader = {
                        "val_loader": _define_cv_dataset(
                            conf,
                            partition_type=None,
                            dataset_type="train",
                            force_shuffle=True,
                        )
                    }

                    # evaluate on the local model.
                    conf.logger.log("eval the local model on full training data.")
                    validate(
                        conf,
                        model,
                        optimizer,
                        criterion,
                        scheduler,
                        metrics,
                        data_loader=_data_loader,
                        label="eval_local_model_on_full_training_data",
                        force_evaluate_on_averaged_model=False,
                    )

                    # evaluate on the averaged model.
                    conf.logger.log("eval the averaged model on full training data.")
                    copied_model = copy.deepcopy(
                        model.module
                        if "DataParallel" == model.__class__.__name__
                        else model
                    )
                    optimizer.world_aggregator.agg_model(copied_model, op="avg")
                    validate(
                        conf,
                        copied_model,
                        optimizer,
                        criterion,
                        scheduler,
                        metrics,
                        data_loader=_data_loader,
                        label="eval_averaged_model_on_full_training_data",
                        force_evaluate_on_averaged_model=False,
                    )

                # determine if the training is finished.
                if scheduler.is_stop():
                    # save json.
                    conf.logger.save_json()

                    # temporary hack: force the exit for ParallelCHOCO.
                    if optimizer.__class__.__name__ == "ParallelCHOCO":
                        error_handler.abort()
                    return

            # display the tracked time.
            if (
                conf.graph.rank == 0
                and conf.display_tracked_time
                and scheduler.local_index % conf.summary_freq == 0
            ):
                print(timer.summary())

        # reshuffle the data.
        if conf.reshuffle_per_epoch:
            print("\nReshuffle the dataset.")
            del data_loader
            gc.collect()
            data_loader = define_dataset(conf)
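
# NOTE: `timer` is used above as a labeled context manager and later queried
# via `timer.summary()`. The repo's Timer is not reproduced here; below is a
# minimal sketch of a verbosity-gated timer with the same calling pattern
# (`verbosity_level`, `log_fn`, and `on_cuda` mirror the constructor arguments
# used in this section; the internals and the `log_fn` payload are assumptions).
import time
from contextlib import contextmanager


class Timer(object):
    def __init__(self, verbosity_level=0, log_fn=None, on_cuda=False):
        self.verbosity_level = verbosity_level
        self.log_fn = log_fn
        self.on_cuda = on_cuda
        self.totals = {}

    @contextmanager
    def __call__(self, label, epoch=None):
        if self.verbosity_level <= 0:
            # `train_fast` mode: time tracking disabled, near-zero overhead.
            yield
            return
        start = time.time()
        yield
        elapsed = time.time() - start
        self.totals[label] = self.totals.get(label, 0.0) + elapsed
        if self.log_fn is not None:
            # the real `log_metric` signature may differ; this is illustrative.
            self.log_fn({"label": label, "epoch": epoch, "seconds": elapsed})

    def summary(self):
        return "\n".join(
            f"{label}: {total:.3f}s total" for label, total in self.totals.items()
        )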
def main(conf):
    try:
        init_distributed_world(conf, backend=conf.backend)
        conf.distributed = conf.n_mpi_process > 1
    except AttributeError as e:
        print(f"failed to init the distributed world: {e}.")
        conf.distributed = False

    # init the config.
    init_config(conf)

    # define the timer for different operations.
    # if we choose the `train_fast` mode, then we will not track the time.
    conf.timer = Timer(
        verbosity_level=1 if conf.track_time and not conf.train_fast else 0,
        log_fn=conf.logger.log_metric,
        on_cuda=conf.on_cuda,
    )

    # create dataset.
    data_loader = create_dataset.define_dataset(conf, force_shuffle=True)

    # create model.
    model = create_model.define_model(conf, data_loader=data_loader)

    # define the optimizer.
    optimizer = create_optimizer.define_optimizer(conf, model)

    # define the lr scheduler.
    scheduler = create_scheduler.Scheduler(conf, optimizer)

    # add model with data-parallel wrapper.
    if conf.graph.on_cuda:
        if conf.n_sub_process > 1:
            model = torch.nn.DataParallel(model, device_ids=conf.graph.device)

    # (optional) reload checkpoint.
    try:
        checkpoint.maybe_resume_from_checkpoint(conf, model, optimizer, scheduler)
    except RuntimeError as e:
        conf.logger.log(f"Resume Error: {e}")
        conf.resumed = False

    # train and evaluate the model.
    if "rnn_lm" in conf.arch:
        from pcode.distributed_running_nlp import train_and_validate

        # safety check.
        assert (
            conf.n_sub_process == 1
        ), "our current data-parallel wrapper does not support RNN."

        # define the criterion and metrics.
        criterion = nn.CrossEntropyLoss(reduction="mean")
        criterion = criterion.cuda() if conf.graph.on_cuda else criterion
        metrics = create_metrics.Metrics(
            model.module if "DataParallel" == model.__class__.__name__ else model,
            task="language_modeling",
        )

        # define the best_perf tracker, either empty or from the checkpoint.
        best_tracker = stat_tracker.BestPerf(
            best_perf=None if "best_perf" not in conf else conf.best_perf,
            larger_is_better=False,
        )
        scheduler.set_best_tracker(best_tracker)

        # get the train_and_validate function.
        train_and_validate_fn = train_and_validate
    else:
        from pcode.distributed_running_cv import train_and_validate

        # define the criterion and metrics.
        criterion = nn.CrossEntropyLoss(reduction="mean")
        criterion = criterion.cuda() if conf.graph.on_cuda else criterion
        metrics = create_metrics.Metrics(
            model.module if "DataParallel" == model.__class__.__name__ else model,
            task="classification",
        )

        # define the best_perf tracker, either empty or from the checkpoint.
        best_tracker = stat_tracker.BestPerf(
            best_perf=None if "best_perf" not in conf else conf.best_perf,
            larger_is_better=True,
        )
        scheduler.set_best_tracker(best_tracker)

        # get the train_and_validate function.
        train_and_validate_fn = train_and_validate

    # save arguments to disk.
    checkpoint.save_arguments(conf)

    # start training.
    train_and_validate_fn(
        conf,
        model=model,
        criterion=criterion,
        scheduler=scheduler,
        optimizer=optimizer,
        metrics=metrics,
        data_loader=data_loader,
    )
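
# NOTE: the `larger_is_better` flag above differs per task because language
# modeling tracks a loss-like quantity (lower is better) while classification
# tracks accuracy (higher is better). `stat_tracker.BestPerf` itself is not
# shown in this section; this is a minimal sketch of such a tracker (the
# `update` method and attribute names are assumptions).
class BestPerf(object):
    def __init__(self, best_perf=None, larger_is_better=True):
        self.best_perf = best_perf
        self.larger_is_better = larger_is_better
        self.is_best = False

    def update(self, perf):
        """Record the latest performance and flag whether it is a new best."""
        self.is_best = (
            self.best_perf is None
            or (self.larger_is_better and perf > self.best_perf)
            or (not self.larger_is_better and perf < self.best_perf)
        )
        if self.is_best:
            self.best_perf = perf
        return self.is_best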
def train_and_validate(
    conf, model, criterion, scheduler, optimizer, metrics, data_loader
):
    print("=>>>> start training and validation.")

    # define the runtime stat tracker and start the training.
    tracker_tr = RuntimeTracker(
        metrics_to_track=metrics.metric_names, on_cuda=conf.graph.on_cuda
    )

    # get the timer.
    timer = conf.timer

    # loop until the expected number of epochs is finished.
    print("=>>>> enter the training.\n")
    while True:
        dist.barrier()

        # configure local step.
        for _input, _target in data_loader["train_loader"]:
            model.train()

            # load data.
            with timer("load_data", epoch=scheduler.epoch_):
                _input, _target = load_data_batch(conf, _input, _target)

            # inference and get the current performance.
            with timer("forward_pass", epoch=scheduler.epoch_):
                optimizer.zero_grad()
                loss = inference(model, criterion, metrics, _input, _target, tracker_tr)

            with timer("backward_pass", epoch=scheduler.epoch_):
                loss.backward()

            with timer("sync_and_apply_grad", epoch=scheduler.epoch_):
                n_bits_to_transmit = optimizer.step(timer=timer, scheduler=scheduler)
                scheduler.step()

            # display the logging info.
            display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)

            # finished one epoch of training; decide whether to validate the model.
            if scheduler.epoch_ % 1 == 0:
                if tracker_tr.stat["loss"].avg > 1e3 or np.isnan(
                    tracker_tr.stat["loss"].avg
                ):
                    print("\nThe training diverged. Early stop it.")
                    error_handler.abort()

                # each worker finishes one epoch of training.
                do_validate(
                    conf, model, optimizer, criterion, scheduler, metrics, data_loader
                )

                # refresh the logging cache at the beginning of each epoch.
                tracker_tr.reset()

                # determine if the training is finished.
                if scheduler.is_stop():
                    # save json.
                    conf.logger.save_json()
                    return

            # display the tracked time.
            if (
                conf.graph.rank == 0
                and conf.display_tracked_time
                and scheduler.local_index % conf.summary_freq == 0
            ):
                print(timer.summary())

        # reshuffle the data.
        if conf.reshuffle_per_epoch:
            print("\nReshuffle the dataset.")
            del data_loader
            gc.collect()
            data_loader = define_dataset(conf)
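
# NOTE: the divergence check above reads `tracker_tr.stat["loss"].avg`, i.e.
# the tracker keeps a running average per metric and is reset every epoch.
# `RuntimeTracker` lives elsewhere in the repo; below is a minimal sketch of
# that pattern (class and attribute names mirror the usage above, the
# internals are assumptions).
class _AverageMeter(object):
    """Track a running sum/count and expose the average."""

    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)


class RuntimeTracker(object):
    def __init__(self, metrics_to_track, on_cuda=False):
        self.metrics_to_track = metrics_to_track
        self.on_cuda = on_cuda
        self.reset()

    def reset(self):
        # fresh meters for the loss and every tracked metric.
        self.stat = {
            name: _AverageMeter() for name in ["loss"] + list(self.metrics_to_track)
        }

    def update_metrics(self, metric_values, n_samples):
        for name, value in metric_values.items():
            self.stat[name].update(value, n=n_samples)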