def _finishing(self, _comm_round=None):
    # flush the logger and mark the federated training as finished.
    self.conf.logger.save_json()
    self.conf.logger.log("Master finished the federated learning.")
    self.conf.is_finished = True
    self.conf.finished_comm = _comm_round
    checkpoint.save_arguments(self.conf)
    # append the checkpoint root to the job file so external tooling can pick it up.
    os.system(f"echo {self.conf.checkpoint_root} >> {self.conf.job_id}")
def __init__(self, conf):
    self.conf = conf

    # some initializations.
    self.client_ids = list(range(1, 1 + conf.n_clients))
    self.world_ids = list(range(1, 1 + conf.n_participated))

    # create model as well as their corresponding state_dicts.
    _, self.master_model = create_model.define_model(conf, to_consistent_model=False)
    self.used_client_archs = set(
        create_model.determine_arch(conf, client_id, use_complex_arch=True)
        for client_id in range(1, 1 + conf.n_clients)
    )
    self.conf.used_client_archs = self.used_client_archs
    conf.logger.log(f"The clients will use archs={self.used_client_archs}.")

    conf.logger.log("Master created model templates for client models.")
    self.client_models = dict(
        create_model.define_model(conf, to_consistent_model=False, arch=arch)
        for arch in self.used_client_archs
    )
    self.clientid2arch = dict(
        (
            client_id,
            create_model.determine_arch(conf, client_id=client_id, use_complex_arch=True),
        )
        for client_id in range(1, 1 + conf.n_clients)
    )
    self.conf.clientid2arch = self.clientid2arch
    conf.logger.log(
        f"Master initialized the clientid2arch mapping relations: {self.clientid2arch}."
    )

    # create dataset (as well as the potential data_partitioner) for training.
    dist.barrier()
    self.dataset = create_dataset.define_dataset(conf, data=conf.data)
    _, self.data_partitioner = create_dataset.define_data_loader(
        self.conf,
        dataset=self.dataset["train"],
        localdata_id=0,  # random id here.
        is_train=True,
        data_partitioner=None,
    )
    conf.logger.log("Master initialized the local training data with workers.")

    # create val loader.
    # right now we just ignore the case of partitioned_by_user.
    if self.dataset["val"] is not None:
        assert not conf.partitioned_by_user
        self.val_loader, _ = create_dataset.define_data_loader(
            conf, self.dataset["val"], is_train=False
        )
        conf.logger.log("Master initialized val data.")
    else:
        self.val_loader = None

    # create test loaders.
    # localdata_id ranges from 0 to (# of clients - 1); client_id ranges from 1 to # of clients.
    if conf.partitioned_by_user:
        self.test_loaders = []
        for localdata_id in self.client_ids:
            test_loader, _ = create_dataset.define_data_loader(
                conf,
                self.dataset["test"],
                localdata_id=localdata_id - 1,
                is_train=False,
                shuffle=False,
            )
            self.test_loaders.append(copy.deepcopy(test_loader))
    else:
        test_loader, _ = create_dataset.define_data_loader(
            conf, self.dataset["test"], is_train=False
        )
        self.test_loaders = [test_loader]

    # define the criterion and metrics.
    self.criterion = cross_entropy.CrossEntropyLoss(reduction="mean")
    self.metrics = create_metrics.Metrics(self.master_model, task="classification")
    conf.logger.log("Master initialized model/dataset/criterion/metrics.")

    # define the aggregators.
    self.aggregator = create_aggregator.Aggregator(
        conf,
        model=self.master_model,
        criterion=self.criterion,
        metrics=self.metrics,
        dataset=self.dataset,
        test_loaders=self.test_loaders,
        clientid2arch=self.clientid2arch,
    )
    self.coordinator = create_coordinator.Coordinator(conf, self.metrics)
    conf.logger.log("Master initialized the aggregator/coordinator.\n")

    # define early_stopping_tracker.
    self.early_stopping_tracker = EarlyStoppingTracker(patience=conf.early_stopping_rounds)

    # save arguments to disk.
    conf.is_finished = False
    checkpoint.save_arguments(conf)
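# Illustrative sketch (not the repo's implementation): the Master above builds an
# EarlyStoppingTracker with a `patience` budget. A minimal, self-contained stand-in
# with the same call pattern could look like the class below; its name, attributes,
# and behavior are assumptions for illustration only -- the actual tracker is defined
# elsewhere in pcode.
class _ExampleEarlyStoppingTracker:
    """Signals a stop after `patience` consecutive rounds without improvement."""

    def __init__(self, patience, larger_is_better=True):
        self.patience = patience
        self.larger_is_better = larger_is_better
        self.best_perf = None
        self.num_bad_rounds = 0

    def __call__(self, perf):
        # returns True when training should stop early.
        if self.patience is None or self.patience <= 0:
            return False
        improved = (
            self.best_perf is None
            or (self.larger_is_better and perf > self.best_perf)
            or (not self.larger_is_better and perf < self.best_perf)
        )
        if improved:
            self.best_perf = perf
            self.num_bad_rounds = 0
        else:
            self.num_bad_rounds += 1
        return self.num_bad_rounds >= self.patience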
def main(conf):
    try:
        init_distributed_world(conf, backend=conf.backend)
        conf.distributed = conf.n_mpi_process > 1
    except AttributeError as e:
        print(f"failed to init the distributed world: {e}.")
        conf.distributed = False

    # init the config.
    init_config(conf)

    # define the timer for different operations.
    # if we choose the `train_fast` mode, then we will not track the time.
    conf.timer = Timer(
        verbosity_level=1 if conf.track_time and not conf.train_fast else 0,
        log_fn=conf.logger.log_metric,
        on_cuda=conf.on_cuda,
    )

    # create dataset.
    data_loader = create_dataset.define_dataset(conf, force_shuffle=True)

    # create model.
    model = create_model.define_model(conf, data_loader=data_loader)

    # define the optimizer.
    optimizer = create_optimizer.define_optimizer(conf, model)

    # define the lr scheduler.
    scheduler = create_scheduler.Scheduler(conf, optimizer)

    # add model with data-parallel wrapper.
    if conf.graph.on_cuda:
        if conf.n_sub_process > 1:
            model = torch.nn.DataParallel(model, device_ids=conf.graph.device)

    # (optional) reload checkpoint.
    try:
        checkpoint.maybe_resume_from_checkpoint(conf, model, optimizer, scheduler)
    except RuntimeError as e:
        conf.logger.log(f"Resume Error: {e}")
        conf.resumed = False

    # train and evaluate model.
    if "rnn_lm" in conf.arch:
        from pcode.distributed_running_nlp import train_and_validate

        # safety check.
        assert (
            conf.n_sub_process == 1
        ), "our current data-parallel wrapper does not support RNN."

        # define the criterion and metrics.
        criterion = nn.CrossEntropyLoss(reduction="mean")
        criterion = criterion.cuda() if conf.graph.on_cuda else criterion
        metrics = create_metrics.Metrics(
            model.module if "DataParallel" == model.__class__.__name__ else model,
            task="language_modeling",
        )

        # define the best_perf tracker, either empty or from the checkpoint.
        best_tracker = stat_tracker.BestPerf(
            best_perf=None if "best_perf" not in conf else conf.best_perf,
            larger_is_better=False,
        )
        scheduler.set_best_tracker(best_tracker)

        # get train_and_validate_func.
        train_and_validate_fn = train_and_validate
    else:
        from pcode.distributed_running_cv import train_and_validate

        # define the criterion and metrics.
        criterion = nn.CrossEntropyLoss(reduction="mean")
        criterion = criterion.cuda() if conf.graph.on_cuda else criterion
        metrics = create_metrics.Metrics(
            model.module if "DataParallel" == model.__class__.__name__ else model,
            task="classification",
        )

        # define the best_perf tracker, either empty or from the checkpoint.
        best_tracker = stat_tracker.BestPerf(
            best_perf=None if "best_perf" not in conf else conf.best_perf,
            larger_is_better=True,
        )
        scheduler.set_best_tracker(best_tracker)

        # get train_and_validate_func.
        train_and_validate_fn = train_and_validate

    # save arguments to disk.
    checkpoint.save_arguments(conf)

    # start training.
    train_and_validate_fn(
        conf,
        model=model,
        criterion=criterion,
        scheduler=scheduler,
        optimizer=optimizer,
        metrics=metrics,
        data_loader=data_loader,
    )
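# Minimal entry-point sketch showing how `main(conf)` could be driven from the
# command line; `get_args()` is a hypothetical argparse-style helper that builds
# the `conf` namespace and is an assumption here, not necessarily the repo's API.
if __name__ == "__main__":
    conf = get_args()  # hypothetical: parse CLI flags into a config namespace.
    main(conf)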