Example #1
    def __init__(self, optimizer, device_placement=True, scaler=None):
        self.optimizer = optimizer
        self.scaler = scaler
        self.state = AcceleratorState()

        # Handle device placement
        if device_placement:
            state_dict = self.optimizer.state_dict()
            if self.state.distributed_type == DistributedType.TPU:
                xm.send_cpu_data_to_device(state_dict, self.state.device)
            else:
                state_dict = move_to_device(state_dict, self.state.device)
            self.optimizer.load_state_dict(state_dict)
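Note: `xm.send_cpu_data_to_device` walks an arbitrarily nested container and copies every CPU tensor it finds to the target device, which is why it can be applied directly to an optimizer `state_dict()` as above. A minimal, self-contained sketch of that behavior (assuming `torch_xla` is installed and an XLA device is available; the state values are made up):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Stand-in for an optimizer state dict: a nested dict/list of CPU tensors.
state = {"step": torch.tensor(0), "exp_avg": [torch.zeros(3), torch.ones(3)]}
moved = xm.send_cpu_data_to_device(state, device)
print(moved["exp_avg"][0].device)  # expected: the XLA device (e.g. xla:0)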
Example #2
 def _worker(self, dqueue):
     device = torch.device(dqueue.device)
     while True:
         batch = self._get_batch(dqueue)
         if not batch:
             break
         batch = xm.send_cpu_data_to_device(batch, device)
         for data in batch:
             dqueue.queue.put(data)
     dqueue.queue.close_write()
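The worker above drains an internal queue of CPU batches and moves each one to its device just before handing it on. A stripped-down sketch of the same per-batch transfer, without the `dqueue` machinery (the batch contents here are invented):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
cpu_batches = [{"x": torch.randn(4, 8), "y": torch.randint(0, 2, (4,))} for _ in range(3)]
for batch in cpu_batches:
    batch = xm.send_cpu_data_to_device(batch, device)  # CPU tensors -> XLA device
    # ... consume the on-device batch here, e.g. run a forward/backward step ...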
Example #3
    def __init__(self, args, task, model, criterion, quantizer=None):
        self.args = args
        self.task = task

        # catalog shared parameters
        shared_params = _catalog_shared_params(model)

        self.tpu = getattr(args, 'tpu', False)
        self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu
        if self.cuda:
            self.device = torch.device('cuda')
        elif self.tpu:
            self.device = utils.get_tpu_device(args)
        else:
            self.device = torch.device('cpu')

        # copy model and criterion to current device/dtype
        self._criterion = criterion
        self._model = model
        if self.tpu:
            import torch_xla.core.xla_model as xm
            self._model = xm.send_cpu_data_to_device(self._model, self.device)
        if args.fp16:
            self._criterion = self._criterion.half()
            self._model = self._model.half()
        elif args.bf16:
            self._criterion = self._criterion.to(dtype=torch.bfloat16)
            self._model = self._model.to(dtype=torch.bfloat16)
        self._criterion = self._criterion.to(device=self.device)
        self._model = self._model.to(device=self.device)

        # check that shared parameters are preserved after device transfer
        for shared_param in shared_params:
            ref = _get_module_by_path(self._model, shared_param[0])
            for path in shared_param[1:]:
                logger.info(
                    'detected shared parameter: {} <- {}'.format(shared_param[0], path)
                )
                _set_module_by_path(self._model, path, ref)

        self._dummy_batch = "DUMMY"  # indicates we don't have a dummy batch at first
        self._lr_scheduler = None
        self._num_updates = 0
        self._num_xla_compiles = 0  # for TPUs
        self._optim_history = None
        self._optimizer = None
        self._warn_once = set()
        self._wrapped_criterion = None
        self._wrapped_model = None

        # TODO(myleott): support tpu
        if self.cuda and self.data_parallel_world_size > 1:
            self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
        else:
            self._grad_norm_buf = None

        self.quantizer = quantizer
        if self.quantizer is not None:
            self.quantizer.set_trainer(self)

        # get detailed cuda environment
        if self.cuda:
            self.cuda_env = utils.CudaEnvironment()
            if self.data_parallel_world_size > 1:
                self.cuda_env_arr = distributed_utils.all_gather_list(self.cuda_env)
            else:
                self.cuda_env_arr = [self.cuda_env]
            if self.data_parallel_rank == 0:
                utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
        else:
            self.cuda_env = None
            self.cuda_env_arr = None

        metrics.log_start_time("wall", priority=790, round=0)

        self._start_time = time.time()
        self._previous_training_time = 0
        self._cumulative_training_time = None
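Both this trainer and the `DictConfig` variant in Example #6 follow the same sequence: send the model to the XLA device, apply the optional fp16/bf16 cast, then call `.to(device)` (which leaves parameters already on that device untouched). A condensed sketch of that sequence; the helper name and flags are stand-ins for the `args`/`cfg` fields above, not fairseq API:

import torch
import torch_xla.core.xla_model as xm

def place_model(model, use_tpu=True, fp16=False, bf16=False):
    # Hypothetical helper mirroring the placement/casting order used above.
    device = xm.xla_device() if use_tpu else torch.device("cpu")
    if use_tpu:
        model = xm.send_cpu_data_to_device(model, device)
    if fp16:
        model = model.half()
    elif bf16:
        model = model.to(dtype=torch.bfloat16)
    return model.to(device=device)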
Example #4
 def test_send_to_device_grad(self):
   xla_device = xm.xla_device()
   t = _gen_tensor(2, 2, requires_grad=True)
   dt = xm.send_cpu_data_to_device([t], xla_device)
   self.assertTrue(dt[0].requires_grad)
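The test verifies that autograd metadata survives the transfer: tensors moved with `send_cpu_data_to_device` keep their `requires_grad` flag. The same check as a standalone sketch (assuming an available XLA device):

import torch
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, requires_grad=True)
(dt,) = xm.send_cpu_data_to_device([t], xm.xla_device())
print(dt.requires_grad)  # expected: True, as the test above asserts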
Example #5
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_tf = kwargs.pop("from_tf", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_cdn = kwargs.pop("use_cdn", True)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                TF2_WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_tf` set to False"
                        .format(
                            [
                                WEIGHTS_NAME, TF2_WEIGHTS_NAME,
                                TF_WEIGHTS_NAME + ".index"
                            ],
                            pretrained_model_name_or_path,
                        ))
            elif os.path.isfile(
                    pretrained_model_name_or_path) or is_remote_url(
                        pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert (from_tf), (
                    "We found a TensorFlow checkpoint at {}, please set from_tf to True"
                    " to load from this checkpoint"
                ).format(pretrained_model_name_or_path + ".index")
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
                    use_cdn=use_cdn,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError
            except EnvironmentError:
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make "
                    f"sure that:\n\n- '{pretrained_model_name_or_path}' is a correct "
                    f"model identifier listed on 'https://huggingface.co/models'\n\n- "
                    f"or '{pretrained_model_name_or_path}' is the correct path to a "
                    f"directory containing a file named one of {WEIGHTS_NAME}, "
                    f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n")
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            try:
                state_dict = torch.load(resolved_archive_file,
                                        map_location="cpu")
            except Exception:
                raise OSError(
                    "Unable to load weights from pytorch checkpoint file. "
                    "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
                )

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith(".index"):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(
                    model, config,
                    resolved_archive_file[:-6])  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model

                    model = load_tf2_checkpoint_in_pytorch_model(
                        model, resolved_archive_file, allow_missing_keys=True)
                except ImportError:
                    logger.error(
                        "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                        "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
                    )
                    raise
        else:
            # Convert old format to new format if needed from a PyTorch state_dict
            has_all_sub_modules = all(
                any(s.startswith(_prefix) for s in state_dict.keys())
                for _prefix in model.base_model_prefixs)
            has_prefix_module = any(
                s.startswith(model.base_model_prefix)
                for s in state_dict.keys())

            old_keys = list(state_dict.keys())
            for key in old_keys:
                new_key = key

                if "gamma" in key:
                    new_key = new_key.replace("gamma", "weight")
                if "beta" in key:
                    new_key = new_key.replace("beta", "bias")

                _state = state_dict.pop(key)
                if has_all_sub_modules:
                    state_dict[new_key] = _state
                elif not has_prefix_module:
                    for _prefix in model.base_model_prefixs:
                        _key = _prefix + "." + new_key
                        state_dict[_key] = _state
                else:
                    if new_key.startswith(model.base_model_prefix):
                        for _prefix in model.base_model_prefixs:
                            _key = _prefix + new_key[len(model.base_model_prefix):]
                            state_dict[_key] = _state
                    else:
                        state_dict[new_key] = _state
            if hasattr(model, "hack_pretrained_state_dict"):
                state_dict = model.hack_pretrained_state_dict(state_dict)

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, "_metadata", None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
            # so we need to apply the function recursively.
            def load(module, prefix=""):
                local_metadata = {} if metadata is None else metadata.get(
                    prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict,
                    prefix,
                    local_metadata,
                    True,
                    missing_keys,
                    unexpected_keys,
                    error_msgs,
                )
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + ".")

            # Make sure we are able to load base models as well as derived models (with heads)
            load(model)
            load = None

            if len(missing_keys) > 0:
                logger.info(
                    "Weights of {} not initialized from pretrained model: {}".
                    format(model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                logger.info(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys))
            if len(error_msgs) > 0:
                raise RuntimeError(
                    "Error(s) in loading state_dict for {}:\n\t{}".format(
                        model.__class__.__name__, "\n\t".join(error_msgs)))
        model.tie_weights()  # make sure token embedding weights are still tied if needed

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info

        if hasattr(config, "xla_device") and config.xla_device:
            import torch_xla.core.xla_model as xm

            model = xm.send_cpu_data_to_device(model, xm.xla_device())
            model.to(xm.xla_device())

        return model
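The TPU branch at the end (guarded by `config.xla_device`) can be read in isolation: if the config asks for an XLA device, the freshly loaded model is sent there. A sketch of just that step; the helper name is hypothetical, while the `xla_device` attribute comes from the example above:

def place_on_xla_if_requested(model, config):
    # Hypothetical helper mirroring the final TPU placement step above.
    if getattr(config, "xla_device", False):
        import torch_xla.core.xla_model as xm

        device = xm.xla_device()
        model = xm.send_cpu_data_to_device(model, device)
        model = model.to(device)
    return model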
Example #6
    def __init__(self, cfg: DictConfig, task, model, criterion, quantizer=None):

        if isinstance(cfg, Namespace):
            logger.warning(
                "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
            )
            cfg = convert_namespace_to_omegaconf(cfg)

        self.cfg = cfg
        self.task = task

        # catalog shared parameters
        shared_params = _catalog_shared_params(model)
        self.tpu = cfg.common.tpu
        self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu
        if self.cuda:
            self.device = torch.device("cuda")
        elif self.tpu:
            self.device = utils.get_tpu_device()
        else:
            self.device = torch.device("cpu")

        # copy model and criterion to current device/dtype
        self._criterion = criterion
        self._model = model
        if self.tpu:
            import torch_xla.core.xla_model as xm

            self._model = xm.send_cpu_data_to_device(self._model, self.device)
        if cfg.common.fp16:
            self._criterion = self._criterion.half()
            self._model = self._model.half()
        elif cfg.common.bf16:
            self._criterion = self._criterion.to(dtype=torch.bfloat16)
            self._model = self._model.to(dtype=torch.bfloat16)
        if not cfg.distributed_training.pipeline_model_parallel:
            self._criterion = self._criterion.to(device=self.device)
            self._model = self._model.to(device=self.device)
        self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
        self.last_device = None
        if self.cuda and self.pipeline_model_parallel:
            self.last_device = torch.device(
                cfg.distributed_training.pipeline_devices[-1]
            )

        # check that shared parameters are preserved after device transfer
        for shared_param in shared_params:
            ref = _get_module_by_path(self._model, shared_param[0])
            for path in shared_param[1:]:
                logger.info(
                    "detected shared parameter: {} <- {}".format(shared_param[0], path)
                )
                _set_module_by_path(self._model, path, ref)

        self._dummy_batch = None  # indicates we don't have a dummy batch at first
        self._lr_scheduler = None
        self._num_updates = 0
        self._num_xla_compiles = 0  # for TPUs
        self._optim_history = None
        self._optimizer = None
        self._warn_once = set()
        self._wrapped_criterion = None
        self._wrapped_model = None

        # TODO(myleott): support tpu
        if self.cuda and self.data_parallel_world_size > 1:
            self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
        else:
            self._grad_norm_buf = None

        self.quantizer = quantizer
        if self.quantizer is not None:
            self.quantizer.set_trainer(self)

        # get detailed cuda environment
        if self.cuda:
            self.cuda_env = utils.CudaEnvironment()
            if self.data_parallel_world_size > 1:
                self.cuda_env_arr = distributed_utils.all_gather_list(self.cuda_env)
            else:
                self.cuda_env_arr = [self.cuda_env]
            if self.data_parallel_rank == 0:
                utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
        else:
            self.cuda_env = None
            self.cuda_env_arr = None

        metrics.log_start_time("wall", priority=790, round=0)

        self._start_time = time.time()
        self._previous_training_time = 0
        self._cumulative_training_time = None
Example #7
  def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
    config = kwargs.pop("config", None)
    state_dict = kwargs.pop("state_dict", None)
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    mirror = kwargs.pop("mirror", None)

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
      config_path = config if config is not None else pretrained_model_name_or_path
      config, model_kwargs = cls.config_class.from_pretrained(
        config_path,
        *model_args,
        cache_dir=cache_dir,
        return_unused_kwargs=True,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
        revision=revision,
        **kwargs,
      )
    else:
      model_kwargs = kwargs

    # Load model
    if pretrained_model_name_or_path is not None:
      pretrained_model_name_or_path = str(pretrained_model_name_or_path)
      if os.path.isdir(pretrained_model_name_or_path):
        # Load from a PyTorch checkpoint
        archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
      elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
        archive_file = pretrained_model_name_or_path
      else:
        archive_file = hf_bucket_url(
          pretrained_model_name_or_path,
          filename=WEIGHTS_NAME,
          revision=revision,
          mirror=mirror,
        )
      try:
        # Load from URL or cache if already cached
        resolved_archive_file = cached_path(
          archive_file,
          cache_dir=cache_dir,
          force_download=force_download,
          proxies=proxies,
          resume_download=resume_download,
          local_files_only=local_files_only,
          use_auth_token=use_auth_token,
        )
      except EnvironmentError as err:
        #logger.error(err)
        msg = (
          f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
          f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
          f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}.\n\n"
        )
        raise EnvironmentError(msg)
    else:
      resolved_archive_file = None

    config.name_or_path = pretrained_model_name_or_path

    # Instantiate model.
    model = cls(config, *model_args, **model_kwargs)

    if state_dict is None:
      try:
        state_dict = torch.load(resolved_archive_file, map_location="cpu")
      except Exception:
        raise OSError(
          f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
          f"at '{resolved_archive_file}'"
        )

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    m = {'embeddings.word_embeddings': 'word_embedding',
         'embeddings.position_embeddings': 'pos_embedding',
         'embeddings.token_type_embeddings': 'tk_type_embedding',
         'embeddings.LayerNorm': 'embed_layer_norm',
         'embeddings.dropout': 'embed_dropout',
         'encoder.layer': 'bert_layers',
         'pooler.dense': 'pooler_dense',
         'pooler.activation': 'pooler_af',
         'attention.self': "self_attention",
         'attention.output.dense': 'attention_dense',
         'attention.output.LayerNorm': 'attention_layer_norm',
         'attention.output.dropout': 'attention_dropout',
         'intermediate.dense': 'interm_dense',
         'intermediate.intermediate_act_fn': 'interm_af',
         'output.dense': 'out_dense',
         'output.LayerNorm': 'out_layer_norm',
         'output.dropout': 'out_dropout'}

    for key in state_dict.keys():
      new_key = None
      if "gamma" in key:
        new_key = key.replace("gamma", "weight")
      if "beta" in key:
        new_key = key.replace("beta", "bias")
      for x, y in m.items():
        if new_key is not None:
          _key = new_key
        else:
          _key = key
        if x in key:
          new_key = _key.replace(x, y)
      if new_key:
        old_keys.append(key)
        new_keys.append(new_key)

    for old_key, new_key in zip(old_keys, new_keys):
      # print(old_key, new_key)
      state_dict[new_key] = state_dict.pop(old_key)

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
      state_dict._metadata = metadata

    your_bert_params = [f"bert.{x[0]}" for x in model.named_parameters()]
    for k in state_dict:
      if k not in your_bert_params and not k.startswith("cls."):
        possible_rename = [x for x in k.split(".")[1:-1] if x in m.values()]
        raise ValueError(f"{k} cannot be reload to your model, one/some of {possible_rename} we provided have been renamed")

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: nn.Module, prefix=""):
      local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
      module._load_from_state_dict(
        state_dict,
        prefix,
        local_metadata,
        True,
        missing_keys,
        unexpected_keys,
        error_msgs,
      )
      for name, child in module._modules.items():
        if child is not None:
          load(child, prefix + name + ".")

    # Make sure we are able to load base models as well as derived models (with heads)
    start_prefix = ""
    model_to_load = model
    has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
    if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
      start_prefix = cls.base_model_prefix + "."
    if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
      model_to_load = getattr(model, cls.base_model_prefix)
    load(model_to_load, prefix=start_prefix)

    if model.__class__.__name__ != model_to_load.__class__.__name__:
      base_model_state_dict = model_to_load.state_dict().keys()
      head_model_state_dict_without_base_prefix = [
        key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
      ]
      missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

    # Some models may have keys that are not in the state by design, removing them before needlessly warning
    # the user.
    if cls._keys_to_ignore_on_load_missing is not None:
      for pat in cls._keys_to_ignore_on_load_missing:
        missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

    if cls._keys_to_ignore_on_load_unexpected is not None:
      for pat in cls._keys_to_ignore_on_load_unexpected:
        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

    if len(error_msgs) > 0:
      raise RuntimeError(
        "Error(s) in loading state_dict for {}:\n\t{}".format(
          model.__class__.__name__, "\n\t".join(error_msgs)
        )
      )

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()

    if output_loading_info:
      loading_info = {
        "missing_keys": missing_keys,
        "unexpected_keys": unexpected_keys,
        "error_msgs": error_msgs,
      }
      return model, loading_info

    if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
      import torch_xla.core.xla_model as xm

      model = xm.send_cpu_data_to_device(model, xm.xla_device())
      model.to(xm.xla_device())

    return model
Example #8
 def load_state_dict(self, state_dict):
     if self.state.distributed_type == DistributedType.TPU and self.device_placement:
         xm.send_cpu_data_to_device(state_dict, self.state.device)
     self.optimizer.load_state_dict(state_dict)
Example #9
    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
        tag=None,
    ):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = None, [], None

        try:
            from fairseq.fb_pathmgr import fb_pathmgr
            bexists = fb_pathmgr.isfile(filename)
        except Exception:
            bexists = False
            if tag is not None:
                tagger_filename = os.path.join(filename,
                                               self.checkpoint_tagger_filename)
                if self.bexists(tagger_filename):
                    self.checkpoint_tagger = CheckpointTagger.load_from_json(
                        gcsfs.generic_read(tagger_filename))
                    filename = self.checkpoint_tagger.tags[tag]
                    bexists = self.bexists(filename)

        if bexists:
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(state['model'],
                                                 strict=True,
                                                 args=self.args)
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(state['criterion'],
                                                         strict=True)
            except Exception:
                raise Exception(
                    'Cannot load model parameters from checkpoint {}; '
                    'please ensure that the architectures match.'.format(
                        filename))

            extra_state = state['extra_state']
            self._optim_history = state['optimizer_history']
            last_optim_state = state.get('last_optimizer_state', None)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
                'Criterion does not match; please reset the optimizer (--reset-optimizer).'
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

            if self.xla:
                # tpu-comment: send states to device before loading
                last_optim_state = xm.send_cpu_data_to_device(
                    last_optim_state, self.xla_device)
                last_optim['lr_scheduler_state'] = xm.send_cpu_data_to_device(
                    last_optim['lr_scheduler_state'], self.xla_device)
            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(
                    last_optim['lr_scheduler_state'])
            self.optimizer.load_state_dict(last_optim_state,
                                           optimizer_overrides)

            self.set_num_updates(last_optim['num_updates'])

        if extra_state is not None:
            epoch = extra_state['train_iterator']['epoch']
            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                filename, epoch, self.get_num_updates()))

            self.lr_step(epoch)

            if 'train_meters' in extra_state and not reset_meters:
                self.meters.update(extra_state['train_meters'])
                del extra_state['train_meters']

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in self.meters.values():
                    if isinstance(meter, TimeMeter):
                        meter.reset()
        else:
            print('| no existing checkpoint found {}'.format(filename))

        return extra_state
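The TPU branch of `load_checkpoint` loads the checkpoint on the CPU first and only then sends the optimizer (and LR-scheduler) state to the XLA device. A condensed sketch of that ordering; the checkpoint key follows the example, the function itself is hypothetical:

import torch
import torch_xla.core.xla_model as xm

def load_optimizer_state_on_tpu(optimizer, checkpoint_path):
    # Read the checkpoint onto the CPU, as load_checkpoint_to_cpu does above.
    state = torch.load(checkpoint_path, map_location="cpu")
    # Move the saved optimizer tensors to the XLA device before loading them.
    last_optim_state = xm.send_cpu_data_to_device(
        state["last_optimizer_state"], xm.xla_device())
    optimizer.load_state_dict(last_optim_state)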