def load_checkpoint(
    self,
    filename,
    reset_optimizer=False,
    reset_lr_scheduler=False,
    optimizer_overrides=None,
    reset_meters=False,
):
    """
    Load all training state from a checkpoint file.
    rank = 0 will load the checkpoint, and then broadcast it to all other ranks.
    """
    extra_state, self._optim_history, last_optim_state = None, [], None

    bexists = PathManager.isfile(filename)
    if bexists:
        if (
            self.data_parallel_rank == 0
            # TPUs don't support broadcast yet, so load checkpoints
            # on every worker for now
            or self.tpu
        ):
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)
            last_optim_state = state.get("last_optimizer_state", None)

            # If doing zero_sharding, do not broadcast the global optimizer
            # state. Later we will broadcast sharded states to each rank
            # to keep memory from exploding.
            if (
                self.cfg.distributed_training.zero_sharding == "os"
                and "last_optimizer_state" in state
                and self.data_parallel_world_size > 1
            ):
                state["last_optimizer_state"] = "SHARDED"
        else:
            last_optim_state = None
            state = None

        if (
            self.data_parallel_world_size > 1
            # disable on TPUs until they support broadcast
            and not self.tpu
        ):
            state = distributed_utils.broadcast_object(
                state,
                src_rank=0,
                group=self.data_parallel_process_group,
                dist_device=self.device,
            )
            if self.data_parallel_rank > 0:
                last_optim_state = state.get("last_optimizer_state", None)

        # load model parameters
        try:
            self.get_model().load_state_dict(
                state["model"], strict=True, model_cfg=self.cfg.model
            )
            if utils.has_parameters(self.get_criterion()):
                self.get_criterion().load_state_dict(
                    state["criterion"], strict=True
                )
        except Exception:
            raise Exception(
                "Cannot load model parameters from checkpoint {}; "
                "please ensure that the architectures match.".format(filename)
            )

        extra_state = state["extra_state"]
        self._optim_history = state["optimizer_history"]

    if last_optim_state is not None and not reset_optimizer:
        # rebuild optimizer after loading model, since params may have changed
        self._build_optimizer()

        # only reload optimizer and lr_scheduler if they match
        last_optim = self._optim_history[-1]
        assert (
            last_optim["criterion_name"] == self.get_criterion().__class__.__name__
        ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
        assert (
            last_optim["optimizer_name"] == self.optimizer.__class__.__name__
        ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

        if not reset_lr_scheduler:
            self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])

        if self.data_parallel_world_size > 1:
            last_optim_state = self.optimizer.broadcast_global_state_dict(
                last_optim_state
            )
        self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

        self.set_num_updates(last_optim["num_updates"])

    if extra_state is not None:
        epoch = extra_state["train_iterator"]["epoch"]
        logger.info(
            "loaded checkpoint {} (epoch {} @ {} updates)".format(
                filename, epoch, self.get_num_updates()
            )
        )

        if "previous_training_time" in extra_state:
            self._previous_training_time = extra_state["previous_training_time"]
            self._start_time = time.time()

        self.lr_step(epoch)

        if "metrics" in extra_state and not reset_meters:
            metrics.load_state_dict(extra_state["metrics"])

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in metrics.get_meters("default"):
                if isinstance(meter, meters.TimeMeter):
                    meter.reset()
    else:
        logger.info("no existing checkpoint found {}".format(filename))

    return extra_state
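
# --- Usage sketch (hypothetical, not part of the Trainer class) -------------
# A minimal resume helper showing the typical call pattern for
# load_checkpoint(): keep optimizer/scheduler/meter state from the checkpoint
# and read the resume epoch back out of the returned extra_state. The helper
# name, `save_dir`, and `lr_override` are illustrative assumptions; `trainer`
# is assumed to be an already-constructed fairseq Trainer.
import os


def resume_from_last_checkpoint(trainer, save_dir, lr_override=None):
    checkpoint_path = os.path.join(save_dir, "checkpoint_last.pt")
    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer=False,      # keep optimizer state from the checkpoint
        reset_lr_scheduler=False,   # keep LR scheduler state as well
        # optimizer_overrides lets the caller override saved optimizer values,
        # e.g. resume with a different learning rate
        optimizer_overrides={"lr": lr_override} if lr_override is not None else None,
        reset_meters=False,         # restore logging meters too
    )
    if extra_state is None:
        # no checkpoint on disk: start training from the first epoch
        return 1
    # extra_state carries the train iterator position, so training can resume
    # from the epoch at which the checkpoint was written
    return extra_state["train_iterator"]["epoch"]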
def _test_broadcast_object(ref_obj, rank, group):
    # only rank 0 provides the reference object; every rank (including rank 0)
    # must get back an identical copy from the broadcast
    obj = dist_utils.broadcast_object(
        ref_obj if rank == 0 else None, src_rank=0, group=group
    )
    assert objects_are_equal(ref_obj, obj)
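
# --- Standalone driver sketch (hypothetical, not the repo's test harness) ---
# Spawns two CPU workers on the gloo backend and uses the helper above to
# check that a nested Python object broadcast from rank 0 arrives intact on
# every rank. Assumes the surrounding module imports fairseq's distributed
# utils as `dist_utils` and an `objects_are_equal` comparison helper, and that
# broadcast_object can infer a CPU device from the gloo group. The port and
# the two-process world size are arbitrary choices.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _broadcast_object_worker(rank, world_size, ref_obj):
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29501",
        rank=rank,
        world_size=world_size,
    )
    group = dist.new_group(ranks=list(range(world_size)))
    _test_broadcast_object(ref_obj, rank, group)
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    ref_obj = {"a": torch.rand(3), "b": [1, "two", {"three": 3.0}]}
    mp.spawn(_broadcast_object_worker, args=(world_size, ref_obj), nprocs=world_size)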