def _load_checkpoint(cls, model: DeviceAwareModule, checkpoint_path: Path,
                         key_in_state_dict: str, use_gpu: bool) -> int:
        """
        Loads a checkpoint of a model, may be the model or the mean teacher model. Assumes the model
        has already been created, and the checkpoint exists. This does not set checkpoint epoch.
        This method should not be called externally. Use instead try_load_checkpoint_for_model
        or try_load_checkpoint_for_mean_teacher_model
        :param model: model to load weights
        :param checkpoint_path: Path to checkpoint
        :param key_in_state_dict: the key for the model weights in the checkpoint state dict
        :param use_gpu: passed through to ModelAndInfo.read_checkpoint; presumably controls whether
            GPU-trained weights are mapped onto the CPU when reading — confirm against read_checkpoint
        :return: checkpoint epoch from the state dict, or False when key_in_state_dict is absent
        """
        logging.info(f"Loading checkpoint {checkpoint_path}")
        checkpoint = ModelAndInfo.read_checkpoint(checkpoint_path, use_gpu)

        try:
            state_dict = checkpoint[key_in_state_dict]
        except KeyError:
            logging.error(f"Key {key_in_state_dict} not found in checkpoint")
            # NOTE(review): returns False despite the declared -> int return type; callers appear
            # to rely on a falsy sentinel here — verify before tightening to an exception.
            return False

        # DataParallel wraps the real model in .module; load the weights into the wrapped model.
        if isinstance(model, torch.nn.DataParallel):
            result = model.module.load_state_dict(state_dict, strict=False)
        else:
            result = model.load_state_dict(state_dict, strict=False)

        # strict=False tolerates mismatches but surfaces them in the logs for diagnosis.
        if result.missing_keys:
            logging.warning(f"Missing keys in model checkpoint: {result.missing_keys}")
        if result.unexpected_keys:
            logging.warning(f"Unexpected keys in model checkpoint: {result.unexpected_keys}")

        return checkpoint[ModelAndInfo.EPOCH_KEY]
    def _load_checkpoint(cls, model: DeviceAwareModule, checkpoint_path: Path,
                         key_in_state_dict: str, use_gpu: bool) -> int:
        """
        Loads a checkpoint of a model, may be the model or the mean teacher model. Assumes the model
        has already been created, and the checkpoint exists. This does not set checkpoint epoch.
        This method should not be called externally. Use instead try_load_checkpoint_for_model
        or try_load_checkpoint_for_mean_teacher_model
        :param model: model to load weights
        :param checkpoint_path: Path to checkpoint
        :param key_in_state_dict: the key for the model weights in the checkpoint state dict
        :param use_gpu: if False, map GPU-trained weights onto the CPU when loading the checkpoint
        :return: checkpoint epoch from the state dict, or False when key_in_state_dict is absent
        """
        logging.info(f"Loading checkpoint {checkpoint_path}")
        # For model debugging, allow loading a GPU trained model onto the CPU. This will clearly only work
        # if the model is small.
        map_location = None if use_gpu else 'cpu'
        checkpoint = torch.load(str(checkpoint_path),
                                map_location=map_location)

        # Guard against a missing weights key with a logged error and falsy sentinel, consistent
        # with the sibling _load_checkpoint implementation, instead of raising a bare KeyError.
        if key_in_state_dict not in checkpoint:
            logging.error(f"Key {key_in_state_dict} not found in checkpoint")
            return False

        # DataParallel wraps the real model in .module; load the weights into the wrapped model.
        if isinstance(model, torch.nn.DataParallel):
            model.module.load_state_dict(checkpoint[key_in_state_dict])
        else:
            model.load_state_dict(checkpoint[key_in_state_dict])
        return checkpoint[ModelAndInfo.EPOCH_KEY]