Example No. 1
    def __restore_from(self, path, state):
        if not os.path.isdir(path):
            if self._force_load:
                raise ValueError("force_load was set to True for checkpoint callback but a checkpoint was not found.")
            logging.warning(f"Checkpoint folder {path} not found!")
        else:
            logging.info(f"Found checkpoint folder {path}. Will attempt to restore checkpoints from it.")
            # Only restore modules that actually hold trainable weights.
            modules_to_restore = []
            modules_to_restore_name = []
            for module in AppState().modules:
                if module.num_weights > 0:
                    modules_to_restore.append(module)
                    modules_to_restore_name.append(str(module))
            step_check = None
            try:
                module_checkpoints, steps = get_checkpoint_from_dir(modules_to_restore_name, path, return_steps=True)

                # If the steps are different, print a warning message
                for step in steps:
                    if step_check is None:
                        step_check = step
                    elif step != step_check:
                        logging.warning("Restoring from modules checkpoints where the training step does not match")
                        break

                # Load each module's weights from its matching checkpoint file.
                for mod, checkpoint in zip(modules_to_restore, module_checkpoints):
                    mod.restore_from(checkpoint, state["local_rank"])
            except ValueError as e:
                if self._force_load:
                    raise ValueError(
                        "force_load was set to True for checkpoint callback but a checkpoint was not found."
                    )
                logging.warning(e)
                logging.warning(
                    f"Checkpoint folder {path} was present but nothing was restored. Continuing training from random "
                    "initialization."
                )
                return

            try:
                trainer_checkpoints, steps = get_checkpoint_from_dir(["trainer"], path, return_steps=True)
                if step_check is not None and step_check != steps[0]:
                    logging.error(
                        "The step we are restoring from the trainer checkpoint does not match one or more steps that "
                        "are being restored from modules."
                    )
                state.restore_state_from(trainer_checkpoints[0])
            except ValueError as e:
                logging.warning(e)
                logging.warning(
                    "Trainer state such as optimizer state and current step/epoch was not restored. Pretrained weights"
                    " have still been restore and fine-tuning should continue fine."
                )
                return
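
All three examples rely on get_checkpoint_from_dir to map module names to checkpoint files in the folder. As a rough orientation only, here is a minimal sketch of what such a helper could look like, assuming checkpoints are saved as "<module_name>-STEP-<step>.pt"; that file-naming convention and the return_steps flag are assumptions for illustration, not the library's actual implementation.

import glob
import os
import re


def get_checkpoint_from_dir_sketch(module_names, path, return_steps=False):
    """Illustrative stand-in for get_checkpoint_from_dir (naming scheme assumed)."""

    def step_of(fname):
        # Parse the training step out of "<name>-STEP-<step>.pt"; -1 if absent.
        match = re.search(r"-STEP-(\d+)", os.path.basename(fname))
        return int(match.group(1)) if match else -1

    checkpoints, steps = [], []
    for name in module_names:
        candidates = glob.glob(os.path.join(path, f"{name}-STEP-*.pt"))
        if not candidates:
            raise ValueError(f"No checkpoint found for {name} in {path}")
        latest = max(candidates, key=step_of)  # newest checkpoint by step
        checkpoints.append(latest)
        steps.append(step_of(latest))
    return (checkpoints, steps) if return_steps else checkpoints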
Example No. 2
    def __restore_from(self, path):
        if not os.path.isdir(path):
            if self._force_load:
                raise ValueError(
                    "force_load was set to True for checkpoint callback but a checkpoint was not found."
                )
            logging.warning(f"Checkpoint folder {path} not found!")
        else:
            logging.info(
                f"Found checkpoint folder {path}. Will attempt to restore checkpoints from it."
            )
            modules_to_restore = []
            modules_to_restore_name = []
            for module in self.action.modules:
                if module.num_weights > 0:
                    modules_to_restore.append(module)
                    modules_to_restore_name.append(str(module))
            try:
                module_checkpoints = get_checkpoint_from_dir(
                    modules_to_restore_name, path)

                for mod, checkpoint in zip(modules_to_restore,
                                           module_checkpoints):
                    mod.restore_from(checkpoint, self.local_rank)
            except BaseException as e:
                if self._force_load:
                    raise ValueError(
                        "force_load was set to True for checkpoint callback but a checkpoint was not found."
                    )
                logging.warning(e)
                logging.warning(
                    f"Checkpoint folder {path} was present but nothing was restored. Continuing training from random "
                    "initialization.")
                return

            try:
                trainer_checkpoints = get_checkpoint_from_dir(["trainer"],
                                                              path)
                for tr, checkpoint in zip([self.action], trainer_checkpoints):
                    tr.restore_state_from(checkpoint)
            except (BaseException, ValueError) as e:
                logging.warning(e)
                logging.warning(
                    "Trainer state such as optimizer state and current step/epoch was not restored. Pretrained weights"
                    " have still been restore and fine-tuning should continue fine."
                )
                return
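
For context, here is a hypothetical skeleton showing where a private __restore_from method like the ones above is typically invoked: once, when the training action starts. The class name, constructor arguments, and hook name below are illustrative assumptions, not the callback's exact API.

class CheckpointCallbackSketch:
    """Hypothetical callback skeleton; names are illustrative, not the real API."""

    def __init__(self, folder, force_load=False):
        self._folder = folder
        self._force_load = force_load

    def on_action_start(self):
        # Try to resume from the configured folder before training begins.
        self.__restore_from(path=self._folder)

    def __restore_from(self, path):
        # Body as in the examples above: locate per-module checkpoints in `path`
        # and restore weights, falling back to random init unless force_load is set.
        ...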
Example No. 3
    def __restore_from(self, path):
        if not os.path.isdir(path):
            if self._force_load:
                raise ValueError(
                    "force_load was set to True for checkpoint callback but a checkpoint was not found."
                )
            logging.warning(f"Checkpoint folder {path} not found!")
        else:
            logging.info(f"Restoring checkpoint from folder {path} ...")
            modules_to_restore = []
            modules_to_restore_name = []
            for module in self.action.modules:
                if module.num_weights > 0:
                    modules_to_restore.append(module)
                    modules_to_restore_name.append(str(module))
            try:
                module_checkpoints = get_checkpoint_from_dir(
                    modules_to_restore_name, path)

                for mod, checkpoint in zip(modules_to_restore,
                                           module_checkpoints):
                    mod.restore_from(checkpoint, self.local_rank)
            except BaseException as e:
                if self._force_load:
                    raise ValueError(
                        "force_load was set to True for checkpoint callback but a checkpoint was not found."
                    )
                logging.warning(e)
                logging.warning(
                    f"Checkpoint folder {path} was present but nothing was restored")
                return

            try:
                trainer_checkpoints = get_checkpoint_from_dir(["trainer"],
                                                              path)
                for tr, checkpoint in zip([self.action], trainer_checkpoints):
                    tr.restore_state_from(checkpoint)
            except BaseException as e:
                logging.warning(e)
                logging.warning("Trainer state wasn't restored")
                return
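
Examples No. 2 and No. 3 restore module checkpoints without verifying that they were all written at the same training step, a check Example No. 1 performs inline. A self-contained sketch of that consistency check might look like the following; the helper name is made up for illustration and is not part of the library.

import logging


def warn_if_steps_differ(steps):
    """Warn when per-module checkpoints come from different training steps."""
    unique_steps = set(steps)
    if len(unique_steps) > 1:
        logging.warning(
            "Restoring from module checkpoints whose training steps do not match: %s",
            sorted(unique_steps),
        )
    return len(unique_steps) <= 1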