Example #1
0
    def init_fsdp_model_from_weights(
        cls,
        model: FullyShardedDataParallel,
        checkpoint: Dict[str, Any],
        weights_path: List[str],
    ):
        """
        Load checkpoint weights into an FSDP-wrapped model.

        The loading strategy is chosen from ``checkpoint["type"]``:

        - ``slice_list``: the whole checkpoint is handed to
          ``SlicedCheckpointLoader.init_model_weights`` (``weights_path``
          is not used in this case).
        - ``consolidated``: the weights found under ``weights_path`` are
          loaded with ``model.load_state_dict``.
        - any other type: the weights found under ``weights_path`` are
          loaded with ``model.load_local_state_dict`` (the sharded load).

        Args:
            model: the FSDP model receiving the weights.
            checkpoint: checkpoint dictionary; must contain a "type" key.
            weights_path: keys locating the weights inside the checkpoint,
                interpreted by ``cls._extract_weights``.
        """
        checkpoint_type = checkpoint["type"]

        # Sliced checkpoints are handled entirely by the dedicated loader.
        if checkpoint_type == CheckpointItemType.slice_list.name:
            SlicedCheckpointLoader.init_model_weights(model, checkpoint)
            return

        state = cls._extract_weights(checkpoint, weights_path)
        if checkpoint_type == CheckpointItemType.consolidated.name:
            model.load_state_dict(state)
        else:
            model.load_local_state_dict(state)
Example #2
0
 def init_fsdp_model_from_weights(
     cls,
     model: FullyShardedDataParallel,
     checkpoint: Dict[str, Any],
     weights_path: List[str],
     strict: bool = True,
     head_index: int = -1,
 ):
     """
     Load the weights of the checkpoint to the FSDP model.

     - The type of checkpoint (``checkpoint["type"]``) decides how the
       load is performed: sliced, consolidated, or sharded (local) load.
     - ``head_index`` (-1 if trunk else >= 0) is forwarded to
       ``cls._extract_weights`` to find the appropriate weights.

     Args:
         model: the FSDP model receiving the weights.
         checkpoint: checkpoint dictionary; must contain a "type" key.
         weights_path: keys locating the weights inside the checkpoint.
         strict: when True, missing weights raise and the state dict load
             is checked strictly; when False, missing weights are skipped
             for every checkpoint type.
         head_index: -1 to load the trunk, >= 0 to load the given head.

     Raises:
         ValueError: if ``strict`` is True and no weights are found under
             ``weights_path``.
     """
     checkpoint_type = checkpoint["type"]
     is_sliced = checkpoint_type == CheckpointItemType.slice_list.name

     if is_sliced and "classy_state_dict" not in checkpoint:
         # Hack for checkpoints consolidated with the "layers" format
         # instead of the new "classy_state_dict" format: in that case
         # the slices are directly saved under "layers" and do not take
         # into account the 'weights_path' variable
         weights = checkpoint["layers"]
     else:
         weights = cls._extract_weights(checkpoint, weights_path,
                                        head_index)

     # Missing weights are only an error in strict mode, whatever the
     # checkpoint type (sliced checkpoints used to raise unconditionally,
     # unlike consolidated / sharded ones).
     if weights is None:
         if strict:
             raise ValueError(
                 f"Could not find weights path: {weights_path}")
         return

     if is_sliced:
         SlicedCheckpointLoader.load_slice_state_dict(model,
                                                      weights,
                                                      strict=strict)
         return

     # Consolidated checkpoints go through the full state dict API, any
     # other type through the per-shard (local) API. Both load non-strictly
     # and let _check_load_state_dict_out enforce `strict` on the result.
     load_fn = (model.load_state_dict
                if checkpoint_type == CheckpointItemType.consolidated.name
                else model.load_local_state_dict)
     out = load_fn(weights, strict=False)
     cls._check_load_state_dict_out(out, strict=strict)