# NOTE: the standard imports used below are listed here; topology, param_parser,
# checkpoint, logging (the project's own logging wrapper, not the stdlib module), and
# remove_folder are assumed to be project-local utilities imported elsewhere in the package.
import gc
from os.path import isfile, join

import numpy as np
import torch
import torch.distributed as dist


def init_config(conf):
    # define the graph for the computation.
    conf.graph = topology.define_graph_topology(
        world=conf.world,
        world_conf=conf.world_conf,
        n_participated=conf.n_participated,
        on_cuda=conf.on_cuda,
    )
    conf.graph.rank = dist.get_rank()

    # init related to randomness on cpu.
    if not conf.same_seed_process:
        conf.manual_seed = 1000 * conf.manual_seed + conf.graph.rank
    conf.random_state = np.random.RandomState(conf.manual_seed)
    torch.manual_seed(conf.manual_seed)

    # configure cuda related.
    if conf.graph.on_cuda:
        assert torch.cuda.is_available()
        torch.cuda.manual_seed(conf.manual_seed)
        torch.cuda.set_device(conf.graph.primary_device)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True if conf.train_fast else False

    # init the model arch info.
    conf.arch_info = (
        param_parser.dict_parser(conf.complex_arch)
        if conf.complex_arch is not None
        else {"master": conf.arch, "worker": conf.arch}
    )
    conf.arch_info["worker"] = conf.arch_info["worker"].split(":")

    # parse the fl_aggregate scheme.
    conf._fl_aggregate = conf.fl_aggregate
    # fall back to an empty dict so the attribute expansion below is safe
    # when no aggregation scheme is given.
    conf.fl_aggregate = (
        param_parser.dict_parser(conf.fl_aggregate)
        if conf.fl_aggregate is not None
        else {}
    )
    for k, v in conf.fl_aggregate.items():
        setattr(conf, f"fl_aggregate_{k}", v)

    # define checkpoint for logging (for federated learning server).
    checkpoint.init_checkpoint(conf, rank=str(conf.graph.rank))

    # configure logger.
    conf.logger = logging.Logger(conf.checkpoint_dir)

    # display the arguments' info.
    if conf.graph.rank == 0:
        logging.display_args(conf)

    # sync the processes.
    dist.barrier()
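
# Illustrative sketch only, not the project's parser: param_parser.dict_parser is assumed
# to turn a comma-separated "key=value" spec such as "master=resnet20,worker=resnet20:resnet32"
# into a plain dict, after which the split(":") in init_config expands the worker entry
# into a list of candidate architectures.
def _dict_parser_sketch(spec):
    # parse "k1=v1,k2=v2" into {"k1": "v1", "k2": "v2"}; values stay as strings.
    return dict(item.split("=", 1) for item in spec.split(","))
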
def maybe_resume_from_checkpoint(conf, model, optimizer, scheduler):
    if conf.resume:
        if conf.checkpoint_index is not None:
            # reload model from a specific checkpoint index.
            checkpoint_index = "_epoch_" + conf.checkpoint_index
        else:
            # reload model from the latest checkpoint.
            checkpoint_index = ""
        checkpoint_path = join(
            conf.resume,
            str(conf.graph.rank),
            "checkpoint{}.pth.tar".format(checkpoint_index),
        )
        print("try to load previous model from the path:{}".format(checkpoint_path))

        if isfile(checkpoint_path):
            print("=> loading checkpoint {} for {}".format(conf.resume, conf.graph.rank))

            # get checkpoint.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")

            # restore some run-time info.
            scheduler.update_from_checkpoint(checkpoint)

            # reset path for log.
            try:
                remove_folder(conf.checkpoint_root)
            except RuntimeError as e:
                print(f"ignore the error={e}")
            conf.checkpoint_root = conf.resume
            conf.checkpoint_dir = join(conf.resume, str(conf.graph.rank))

            # restore model.
            model.load_state_dict(checkpoint["state_dict"])

            # restore optimizer.
            optimizer.load_state_dict(checkpoint["optimizer"])

            # logging.
            print(
                "=> loaded model from path '{}' checkpointed at (epoch {})".format(
                    conf.resume, checkpoint["current_epoch"]
                )
            )

            # configure logger.
            conf.logger = logging.Logger(conf.checkpoint_dir)

            # try to solve memory issue.
            del checkpoint
            torch.cuda.empty_cache()
            gc.collect()
            return
        else:
            print("=> no checkpoint found at '{}'".format(conf.resume))
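
# Hedged companion sketch (not the project's actual saver): maybe_resume_from_checkpoint
# expects a torch-serialized dict containing at least the keys it reads above; a matching
# file could be produced along these lines (any state consumed by
# scheduler.update_from_checkpoint would need to be added as well).
def _save_checkpoint_sketch(path, model, optimizer, current_epoch):
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "current_epoch": current_epoch,
        },
        path,
    )
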
def init_config(conf):
    # define the graph for the computation.
    cur_rank = dist.get_rank() if conf.distributed else 0
    conf.graph = topology.define_graph_topology(
        graph_topology=conf.graph_topology,
        world=conf.world,
        n_mpi_process=conf.n_mpi_process,  # the # of total main processes.
        n_sub_process=conf.n_sub_process,  # the # of subprocesses for each main process.
        comm_device=conf.comm_device,
        on_cuda=conf.on_cuda,
        rank=cur_rank,
    )
    conf.is_centralized = conf.graph_topology == "complete"

    # re-configure batch_size if sub_process > 1.
    if conf.n_sub_process > 1:
        conf.batch_size = conf.batch_size * conf.n_sub_process

    # configure cuda related.
    if conf.graph.on_cuda:
        assert torch.cuda.is_available()
        torch.manual_seed(conf.manual_seed)
        torch.cuda.manual_seed(conf.manual_seed)
        torch.cuda.set_device(conf.graph.device[0])
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True if conf.train_fast else False

    # define checkpoint for logging.
    checkpoint.init_checkpoint(conf)

    # configure logger.
    conf.logger = logging.Logger(conf.checkpoint_dir)

    # display the arguments' info.
    logging.display_args(conf)
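
# Hedged usage sketch (the entry-point name and the conf.backend field are assumptions):
# dist.get_rank() above requires an initialized process group, so a driver would set up
# torch.distributed before calling init_config whenever conf.distributed is set.
def _driver_sketch(conf):
    if conf.distributed:
        # backend could be "mpi", "gloo", or "nccl" depending on the deployment.
        dist.init_process_group(backend=conf.backend)
    init_config(conf)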