Code Example #1
File: trainer.py  Project: lianqing01/transformer-cbn
    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = None, [], None

        try:
            from fairseq.fb_pathmgr import fb_pathmgr
            bexists = fb_pathmgr.isfile(filename)
        except Exception:
            bexists = os.path.exists(filename)

        if bexists:
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(state['model'],
                                                 strict=True,
                                                 args=self.args)
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(state['criterion'],
                                                         strict=True)
            except Exception:
                raise Exception(
                    'Cannot load model parameters from checkpoint {}; '
                    'please ensure that the architectures match.'.format(
                        filename))

            extra_state = state['extra_state']
            self._optim_history = state['optimizer_history']
            last_optim_state = state.get('last_optimizer_state', None)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
                'Criterion does not match; please reset the optimizer (--reset-optimizer).'
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(
                    last_optim['lr_scheduler_state'])
            self.optimizer.load_state_dict(last_optim_state,
                                           optimizer_overrides)

            self.set_num_updates(last_optim['num_updates'])

        if extra_state is not None:
            epoch = extra_state['train_iterator']['epoch']
            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                filename, epoch, self.get_num_updates()))

            self.lr_step(epoch)

            if 'train_meters' in extra_state and not reset_meters:
                self.meters.update(extra_state['train_meters'])
                del extra_state['train_meters']

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in self.meters.values():
                    if isinstance(meter, TimeMeter):
                        meter.reset()
        else:
            print('| no existing checkpoint found {}'.format(filename))

        return extra_state
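
How this method is typically invoked when resuming training, as a minimal hedged sketch; `trainer`, `args`, and `args.restore_file` are illustrative names assumed here, not taken from the listing above.

# Minimal usage sketch (assumed names: `trainer` is an instance of the class
# above, `args.restore_file` is an illustrative attribute). The method returns
# the checkpoint's extra_state, or None when no checkpoint file was found.
extra_state = trainer.load_checkpoint(
    args.restore_file,           # e.g. 'checkpoints/checkpoint_last.pt'
    reset_optimizer=False,       # reuse the optimizer state from the checkpoint
    reset_lr_scheduler=False,    # reuse the lr-scheduler state as well
    optimizer_overrides=None,    # or e.g. {'lr': 1e-4} to override saved values
    reset_meters=False,          # keep the accumulated training meters
)
if extra_state is None:
    # no checkpoint found: training starts from scratch
    pass
else:
    # the saved epoch-iterator state can be used to resume where training stopped
    itr_state = extra_state['train_iterator']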
Code Example #2
    def load_checkpoint(self,
                        filename,
                        reset_optimizer=False,
                        reset_lr_scheduler=False,
                        optimizer_overrides=None,
                        reset_meters=False):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = None, [], None

        try:
            from fairseq.fb_pathmgr import fb_pathmgr

            bexists = fb_pathmgr.isfile(filename)
        except (ModuleNotFoundError, ImportError):
            bexists = os.path.exists(filename)

        if bexists:
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(state["model"],
                                                 strict=True,
                                                 args=self.args)
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(state["criterion"],
                                                         strict=True)
            except Exception:
                raise Exception(
                    "Cannot load model parameters from checkpoint {}; "
                    "please ensure that the architectures match.".format(
                        filename))

            extra_state = state["extra_state"]
            self._optim_history = state["optimizer_history"]
            last_optim_state = state.get("last_optimizer_state", None)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert (
                last_optim["criterion_name"] ==
                self.get_criterion().__class__.__name__
            ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
            assert (
                last_optim["optimizer_name"] ==
                self.optimizer.__class__.__name__
            ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(
                    last_optim["lr_scheduler_state"])
            self.optimizer.load_state_dict(last_optim_state,
                                           optimizer_overrides)

            self.set_num_updates(last_optim["num_updates"])

        if extra_state is not None:
            epoch = extra_state["train_iterator"]["epoch"]
            print("| loaded checkpoint {} (epoch {} @ {} updates)".format(
                filename, epoch, self.get_num_updates()))

            self.lr_step(epoch)

            if "train_meters" in extra_state and not reset_meters:
                self.meters.update(extra_state["train_meters"])
                del extra_state["train_meters"]

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in self.meters.values():
                    if isinstance(meter, TimeMeter):
                        meter.reset()
        else:
            print("| no existing checkpoint found {}".format(filename))

        return extra_state
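
The only substantive difference from Example #1 is the import guard: catching ModuleNotFoundError/ImportError narrowly means a genuine failure inside fb_pathmgr.isfile() would surface instead of being silently swallowed. The same fallback can be factored into a small standalone helper; the function name below is an assumption for illustration.

import os

def checkpoint_exists(filename):
    """Check for a checkpoint via the internal path manager when available."""
    try:
        # fb_pathmgr is an optional, internal-only module; fall back to the
        # local filesystem when it cannot be imported.
        from fairseq.fb_pathmgr import fb_pathmgr
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        return os.path.exists(filename)
    return fb_pathmgr.isfile(filename)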
Code Example #3
    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """Load all training state from a checkpoint file."""
        def rename_state_dict_keys(source, old_name, new_name, all=True):
            """Rename encoder-side keys in a state dict.

            With all=True every entry is kept and encoder keys are renamed in
            place; with all=False only the renamed encoder keys are returned.
            """
            new_state_dict = OrderedDict()

            for key, value in source.items():
                if all:
                    if "encoder" in key and "decoder" not in key:
                        key = key.replace(old_name, new_name)
                    new_state_dict[key] = value
                else:
                    if "encoder" in key and "decoder" not in key:
                        key = key.replace(old_name, new_name)
                        new_state_dict[key] = value
            return new_state_dict

        extra_state, self._optim_history, last_optim_state = None, [], None

        try:
            from fairseq.fb_pathmgr import fb_pathmgr
            bexists = fb_pathmgr.isfile(filename)
        except Exception:
            bexists = os.path.exists(filename)

        if bexists:
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                # if self.args.task == "audio_translation" and self.args.audio_pt is not None:
                #     # print("**************")
                #     # state['model']=rename_state_dict_keys(state['model'],"encoder","text_encoder")
                #     self.get_model().load_state_dict(state['model'], strict=False, args=self.args)
                # else:
                self.get_model().load_state_dict(state['model'], strict=True, args=self.args)

                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(state['criterion'], strict=True)
            except Exception:
                raise Exception(
                    'Cannot load model parameters from checkpoint {}; '
                    'please ensure that the architectures match.'.format(filename)
                )
            extra_state = state['extra_state']
            self._optim_history = state['optimizer_history']
            last_optim_state = state.get('last_optimizer_state', None)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
                'Criterion does not match; please reset the optimizer (--reset-optimizer).'
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self.set_num_updates(last_optim['num_updates'])

        if extra_state is not None:
            epoch = extra_state['train_iterator']['epoch']
            print('| loaded mt checkpoint {} (epoch {} @ {} updates)'.format(
                filename, epoch, self.get_num_updates()))

            self.lr_step(epoch)

            if 'train_meters' in extra_state and not reset_meters:
                self.meters.update(extra_state['train_meters'])
                del extra_state['train_meters']

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in self.meters.values():
                    if isinstance(meter, TimeMeter):
                        meter.reset()
        else:
            print('| no existing checkpoint found {}'.format(filename))
        if self.args.task == "audio_translation" and self.args.audio_pt is not None:
            print('| loaded audio checkpoint {} '.format(self.args.audio_pt))
            model_dict = torch.load(self.args.audio_pt)
            model_dict["model"] = rename_state_dict_keys(model_dict["model"], "encoder", "audio_encoder", all=False)
            self.model.load_state_dict(model_dict["model"], strict=False)
        if self.args.task == "audio_translation" and self.args.mt_pt is not None:
            print('| loaded mt checkpoint {} '.format(self.args.mt_pt))
            model_dict = torch.load(self.args.mt_pt)
            # model_dict["model"] = rename_state_dict_keys(model_dict["model"], "encoder", "audio_encoder", all=False)
            self.model.load_state_dict(model_dict["model"], strict=False)
        return extra_state
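
Example #3 extends the loader for an audio_translation setup: after the regular checkpoint logic, it additionally warm-starts the model from separate audio and MT checkpoints (self.args.audio_pt / self.args.mt_pt), renaming encoder.* keys to audio_encoder.* so they land on the audio branch and loading with strict=False so unmatched keys are tolerated. A toy, self-contained illustration of that renaming step, with placeholder values standing in for real tensors:

from collections import OrderedDict

# Placeholder state dict standing in for torch.load(args.audio_pt)["model"].
audio_state = OrderedDict([
    ("encoder.layers.0.self_attn.in_proj_weight", 0.0),
    ("decoder.layers.0.self_attn.in_proj_weight", 0.0),
])

renamed = OrderedDict()
for key, value in audio_state.items():
    # mirrors rename_state_dict_keys(..., all=False): keep only encoder-side
    # keys, renamed onto the audio branch of the combined model
    if "encoder" in key and "decoder" not in key:
        renamed[key.replace("encoder", "audio_encoder")] = value

# renamed == {"audio_encoder.layers.0.self_attn.in_proj_weight": 0.0};
# loading with strict=False then ignores parameters the checkpoint does not cover.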