Ejemplo n.º 1
0
    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """
        Load all training state from a checkpoint file.

        Rank 0 reads the checkpoint and, when there is more than one
        data-parallel worker, broadcasts it to the other ranks. On TPUs every
        worker reads the file itself and the broadcast is skipped.

        Args:
            filename: path of the checkpoint to load. If no file exists at
                this path nothing is loaded and ``None`` is returned.
            reset_optimizer: if True, the stored optimizer state is ignored.
                The lr scheduler state and the update counter are restored in
                the same branch, so they are skipped as well.
            reset_lr_scheduler: if True, the stored lr scheduler state is not
                loaded (only relevant when the optimizer is being restored).
            optimizer_overrides: forwarded as the second argument to
                ``self.optimizer.load_state_dict``.
            reset_meters: if True, metrics stored in the checkpoint are not
                restored.

        Returns:
            The ``"extra_state"`` entry of the checkpoint, or ``None`` when
            *filename* does not exist.

        Raises:
            Exception: if the model (or criterion) parameters cannot be
                loaded from the checkpoint. The underlying error is attached
                as ``__cause__``.
            AssertionError: if the criterion or optimizer class name recorded
                in the checkpoint differs from the current one while the
                optimizer state is being restored.
        """
        extra_state, self._optim_history, last_optim_state = None, [], None

        bexists = PathManager.isfile(filename)
        if bexists:
            if (
                self.data_parallel_rank == 0
                # TPUs don't support broadcast yet, so load checkpoints
                # on every worker for now
                or self.tpu
            ):
                state = checkpoint_utils.load_checkpoint_to_cpu(filename)
                last_optim_state = state.get("last_optimizer_state", None)

                # If doing zero_sharding, do not broadcast global optimizer
                # state. Later we will broadcast sharded states to each rank
                # to keep memory usage from exploding. The local
                # ``last_optim_state`` captured above still holds the real
                # state; only the copy inside ``state`` is replaced.
                if (
                    self.cfg.distributed_training.zero_sharding == "os"
                    and "last_optimizer_state" in state
                    and self.data_parallel_world_size > 1
                ):
                    state["last_optimizer_state"] = "SHARDED"
            else:
                # This rank receives everything through the broadcast below.
                last_optim_state = None
                state = None

            if (
                self.data_parallel_world_size > 1
                # disable on TPUs until they support broadcast
                and not self.tpu
            ):
                state = distributed_utils.broadcast_object(
                    state,
                    src_rank=0,
                    group=self.data_parallel_process_group,
                    dist_device=self.device,
                )
                if self.data_parallel_rank > 0:
                    # Non-zero ranks only learn about the optimizer state
                    # (or the "SHARDED" placeholder) from the broadcast.
                    last_optim_state = state.get("last_optimizer_state", None)

            # load model parameters
            try:
                self.get_model().load_state_dict(
                    state["model"], strict=True, model_cfg=self.cfg.model
                )
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(
                        state["criterion"], strict=True
                    )
            except Exception as err:
                # Chain explicitly so the root cause stays attached to the
                # generic error as its direct cause.
                raise Exception(
                    "Cannot load model parameters from checkpoint {}; "
                    "please ensure that the architectures match.".format(
                        filename)
                ) from err
            extra_state = state["extra_state"]
            self._optim_history = state["optimizer_history"]

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert (
                last_optim["criterion_name"] ==
                self.get_criterion().__class__.__name__
            ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
            assert (
                last_optim["optimizer_name"] ==
                self.optimizer.__class__.__name__
            ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(
                    last_optim["lr_scheduler_state"])

            if self.data_parallel_world_size > 1:
                # Let the optimizer distribute the global state dict across
                # ranks before it is loaded.
                last_optim_state = self.optimizer.broadcast_global_state_dict(
                    last_optim_state)
            self.optimizer.load_state_dict(last_optim_state,
                                           optimizer_overrides)

            self.set_num_updates(last_optim["num_updates"])

        if extra_state is not None:
            epoch = extra_state["train_iterator"]["epoch"]
            logger.info("loaded checkpoint {} (epoch {} @ {} updates)".format(
                filename, epoch, self.get_num_updates()))

            if "previous_training_time" in extra_state:
                self._previous_training_time = extra_state[
                    "previous_training_time"]
                # Restart the wall-clock reference for this run.
                self._start_time = time.time()

            self.lr_step(epoch)

            if "metrics" in extra_state and not reset_meters:
                metrics.load_state_dict(extra_state["metrics"])

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in metrics.get_meters("default"):
                    if isinstance(meter, meters.TimeMeter):
                        meter.reset()
        else:
            logger.info("no existing checkpoint found {}".format(filename))

        return extra_state
Ejemplo n.º 2
0
 def _test_broadcast_object(ref_obj, rank, group):
     """Check that ``broadcast_object`` delivers rank 0's object to this rank.

     Only rank 0 passes ``ref_obj`` into the broadcast; every other rank
     passes ``None``. The value returned by the broadcast is then compared
     against ``ref_obj`` with ``objects_are_equal``.
     """
     payload = None if rank != 0 else ref_obj
     received = dist_utils.broadcast_object(payload, src_rank=0, group=group)
     assert objects_are_equal(ref_obj, received)