Example #1
    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """
        Load all training state from a checkpoint file.
        rank = 0 will load the checkpoint, and then broadcast it to all
        other ranks.
        """
        extra_state, self._optim_history, last_optim_state = None, [], None

        bexists = PathManager.isfile(filename)
        if bexists:
            if (self.data_parallel_rank == 0
                    # TPUs don't support broadcast yet, so load checkpoints
                    # on every worker for now
                    or self.tpu):
                state = checkpoint_utils.load_checkpoint_to_cpu(filename)
                last_optim_state = state.get("last_optimizer_state", None)

                # If doing zero_sharding, do not broadcast global optimizer
                # state. Later we will broadcast sharded states to each rank
                # to avoid memory from exploding.
                if (self.cfg.distributed_training.zero_sharding == "os"
                        and "last_optimizer_state" in state
                        and self.data_parallel_world_size > 1):
                    state["last_optimizer_state"] = "SHARDED"
            else:
                last_optim_state = None
                state = None

            if (self.data_parallel_world_size > 1
                    # disable on TPUs until they support broadcast
                    and not self.tpu):
                state = distributed_utils.broadcast_object(
                    state,
                    src_rank=0,
                    group=self.data_parallel_process_group,
                    dist_device=self.device,
                )
                if self.data_parallel_rank > 0:
                    last_optim_state = state.get("last_optimizer_state", None)

            # load model parameters
            try:
                self.get_model().load_state_dict(state["model"],
                                                 strict=True,
                                                 model_cfg=self.cfg.model)
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(state["criterion"],
                                                         strict=True)
            except Exception:
                raise Exception(
                    "Cannot load model parameters from checkpoint {}; "
                    "please ensure that the architectures match.".format(
                        filename))
            extra_state = state["extra_state"]
            self._optim_history = state["optimizer_history"]

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert (
                last_optim["criterion_name"] ==
                self.get_criterion().__class__.__name__
            ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
            assert (
                last_optim["optimizer_name"] ==
                self.optimizer.__class__.__name__
            ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(
                    last_optim["lr_scheduler_state"])

            if self.data_parallel_world_size > 1:
                last_optim_state = self.optimizer.broadcast_global_state_dict(
                    last_optim_state)
            self.optimizer.load_state_dict(last_optim_state,
                                           optimizer_overrides)

            self.set_num_updates(last_optim["num_updates"])

        if extra_state is not None:
            epoch = extra_state["train_iterator"]["epoch"]
            logger.info("loaded checkpoint {} (epoch {} @ {} updates)".format(
                filename, epoch, self.get_num_updates()))

            if "previous_training_time" in extra_state:
                self._previous_training_time = extra_state[
                    "previous_training_time"]
                self._start_time = time.time()

            self.lr_step(epoch)

            if "metrics" in extra_state and not reset_meters:
                metrics.load_state_dict(extra_state["metrics"])

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in metrics.get_meters("default"):
                    if isinstance(meter, meters.TimeMeter):
                        meter.reset()
        else:
            logger.info("no existing checkpoint found {}".format(filename))

        return extra_state
Example #2
    def binarize(
        filename,
        dict,
        consumer,
        tokenize=tokenize_line,
        append_eos=True,
        reverse_order=False,
        offset=0,
        end=-1,
        already_numberized=False,
        avoid_tokenize=False,
    ) -> Dict[str, int]:
        nseq, ntok = 0, 0
        replaced = Counter()

        def replaced_consumer(word, idx):
            if idx == dict.unk_index and word != dict.unk_word:
                replaced.update([word])

        def replaced_consumer_from_pretrained(word, idx):
            if idx == dict.convert_tokens_to_ids(
                    dict.unk_token) and word != dict.unk_token:
                replaced.update([word])

        with open(PathManager.get_local_path(filename), "r",
                  encoding="utf-8") as f:
            f.seek(offset)
            # next(f) breaks f.tell(), hence readline() must be used
            line = safe_readline(f)
            while line:
                # f.tell() does not always give the byte position in the file:
                # sometimes it jumps to a very large number. It is unlikely that
                # a normal read moves the position from `end` bytes to
                # end + 2**32 bytes (4 GB), so this check makes it unlikely that
                # the procedure breaks due to the nondeterministic behavior of
                # f.tell()
                if end > 0 and f.tell() > end and f.tell() < end + 2**32:
                    break
                if already_numberized:
                    id_strings = line.strip().split()
                    id_list = [int(id_string) for id_string in id_strings]
                    if reverse_order:
                        id_list.reverse()
                    if append_eos:
                        id_list.append(dict.eos())
                    ids = torch.IntTensor(id_list)
                elif isinstance(dict, BertTokenizer) and not isinstance(
                        dict, ElectraTokenizer):
                    line = line.strip()
                    line = '{} {} {}'.format('[CLS]', line, '[SEP]')
                    if avoid_tokenize is False:
                        tokenizedline = dict.tokenize(line)
                    else:
                        tokenizedline = line.strip().split()
                    words = dict.convert_tokens_to_ids(tokenizedline)
                    nwords = len(words)
                    ids = torch.IntTensor(nwords)
                    for i, word in enumerate(words):
                        ids[i] = word
                        replaced_consumer_from_pretrained(
                            tokenizedline[i], word)
                elif isinstance(dict, BartTokenizer):
                    line = line.strip()
                    if avoid_tokenize is False:
                        # extra space at the end will cause weird outputs.
                        line = '{} {}{}'.format('<s>', line, '</s>')
                        tokenizedline = dict.tokenize(line)
                    else:
                        line = '{} {} {}'.format('<s>', line, '</s>')
                        tokenizedline = line.strip().split()
                    words = dict.convert_tokens_to_ids(tokenizedline)
                    assert len(tokenizedline) == len(words)
                    nwords = len(words)
                    ids = torch.IntTensor(nwords)
                    for i, word in enumerate(words):
                        ids[i] = word
                        replaced_consumer_from_pretrained(
                            tokenizedline[i], word)
                elif isinstance(dict, ElectraTokenizer):
                    line = line.strip()
                    line = '{} {} {}'.format('[CLS]', line, '[SEP]')
                    if avoid_tokenize is False:
                        tokenizedline = dict.tokenize(line)
                    else:
                        tokenizedline = line.strip().split()
                    words = dict.convert_tokens_to_ids(tokenizedline)
                    nwords = len(words)
                    ids = torch.IntTensor(nwords)
                    for i, word in enumerate(words):
                        ids[i] = word
                        replaced_consumer_from_pretrained(
                            tokenizedline[i], word)
                elif dict is None:
                    line = line.strip()
                    words = line.split()
                    words = [int(item) for item in words]
                    nwords = len(words)
                    ids = torch.IntTensor(nwords)
                    for i, word in enumerate(words):
                        ids[i] = word
                else:
                    ids = dict.encode_line(
                        line=line,
                        line_tokenizer=tokenize,
                        add_if_not_exist=False,
                        consumer=replaced_consumer,
                        append_eos=append_eos,
                        reverse_order=reverse_order,
                    )
                nseq += 1
                ntok += len(ids)
                consumer(ids)
                line = f.readline()
        return {
            "nseq": nseq,
            "nunk": sum(replaced.values()),
            "ntok": ntok,
            "replaced": replaced,
        }
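A minimal usage sketch for the routine above, driving it with a plain fairseq Dictionary and an in-memory consumer. The file names are illustrative assumptions, and in upstream fairseq this function is exposed as the static method Binarizer.binarize; it is called here by the bare name used in the snippet.

from fairseq.data import Dictionary

vocab = Dictionary.load("dict.en.txt")   # hypothetical dictionary file
binarized = []                           # collect the produced IntTensors in memory

def consumer(tensor):
    binarized.append(tensor)

stats = binarize("train.en", vocab, consumer)  # hypothetical input corpus
print(stats["nseq"], stats["ntok"], stats["nunk"])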
Example #3
def save_state(
    filename,
    cfg: FairseqConfig,
    model_state_dict,
    criterion,
    optimizer,
    lr_scheduler,
    num_updates,
    optim_history=None,
    extra_state=None,
    task=None,
    **kwargs,
):
    from fairseq import utils

    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        "cfg": cfg,
        "args": kwargs.get("args", None),
        "model": model_state_dict or {},
        "optimizer_history": optim_history
        + [
            {
                "criterion_name": criterion.__class__.__name__,
                "optimizer_name": optimizer.__class__.__name__,
                "lr_scheduler_state": lr_scheduler.state_dict(),
                "num_updates": num_updates,
            }
        ],
        "extra_state": extra_state,
        "task_state": task.state_dict() if task is not None else {}
    }
    if utils.has_parameters(criterion):
        state_dict["criterion"] = criterion.state_dict()

    if cfg is None:
        cfg = state_dict["args"]
        assert cfg is not None, "must provide cfg or args"

    if isinstance(cfg, DictConfig):
        no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state
    else:
        no_save_optimizer_state = cfg.no_save_optimizer_state
    if not no_save_optimizer_state:
        state_dict["last_optimizer_state"] = optimizer.state_dict()

    # keep everything on CPU
    state_dict = utils.move_to_cpu(state_dict)

    if PathManager.supports_rename(filename):
        # do atomic save
        with PathManager.open(filename + ".tmp", "wb") as f:
            torch_persistent_save(state_dict, f)
        PathManager.rename(filename + ".tmp", filename)
    else:
        # fallback to non-atomic save
        with PathManager.open(filename, "wb") as f:
            torch_persistent_save(state_dict, f)
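A minimal sketch of calling save_state directly, assuming the surrounding objects (trainer, epoch_itr, val_loss, cfg) come from an existing fairseq training loop; the output path and extra_state contents are illustrative assumptions.

save_state(
    "checkpoints/checkpoint_manual.pt",      # hypothetical output path
    cfg,                                     # FairseqConfig / DictConfig
    trainer.get_model().state_dict(),
    trainer.get_criterion(),
    trainer.optimizer,
    trainer.lr_scheduler,
    trainer.get_num_updates(),
    optim_history=trainer._optim_history,
    extra_state={"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss},
)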
Example #4
    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = None, [], None

        bexists = PathManager.isfile(filename)
        if bexists:
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(
                    state["model"], strict=True, args=self.args
                )
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(
                        state["criterion"], strict=True
                    )
            except Exception:
                raise Exception(
                    "Cannot load model parameters from checkpoint {}; "
                    "please ensure that the architectures match.".format(filename)
                )

            extra_state = state["extra_state"]
            self._optim_history = state["optimizer_history"]
            last_optim_state = state.get("last_optimizer_state", None)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert (
                last_optim["criterion_name"] == self.get_criterion().__class__.__name__
            ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
            assert (
                last_optim["optimizer_name"] == self.optimizer.__class__.__name__
            ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self.set_num_updates(last_optim["num_updates"])

        if extra_state is not None:
            epoch = extra_state["train_iterator"]["epoch"]
            logger.info(
                "loaded checkpoint {} (epoch {} @ {} updates)".format(
                    filename, epoch, self.get_num_updates()
                )
            )

            if "previous_training_time" in extra_state:
                self._previous_training_time = extra_state["previous_training_time"]
                self._start_time = time.time()

            self.lr_step(epoch)

            if "metrics" in extra_state and not reset_meters:
                metrics.load_state_dict(extra_state["metrics"])

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in metrics.get_meters("default"):
                    if isinstance(meter, meters.TimeMeter):
                        meter.reset()
        else:
            logger.info("no existing checkpoint found {}".format(filename))

        return extra_state
Example #5
    def test_file_io(self):
        from fairseq.file_io import PathManager

        with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f:
            s = f.read()
        self.assertEqual(s, self._tmpfile_contents)
Example #6
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if cfg.distributed_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()

    if not trainer.is_data_parallel_master:
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = cfg.checkpoint_suffix or ""
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        checkpoint_conds[
            "checkpoint.best_{}_{:.2f}.pt".format(cfg.best_checkpoint_metric, val_loss)
        ] = not hasattr(save_checkpoint, "best") or is_better(
            val_loss, save_checkpoint.best
        )
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            assert PathManager.copy(
                checkpoints[0], cp, overwrite=True
            ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
        )
        for old_chk in checkpoints[cfg.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[cfg.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
                cfg.best_checkpoint_metric
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
Example #7
def load_checkpoint(args, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """
    reset_optimizer = args.reset_optimizer
    reset_lr_scheduler = args.reset_lr_scheduler
    optimizer_overrides = eval(args.optimizer_overrides)
    reset_meters = args.reset_meters
    reset_dataloader = args.reset_dataloader

    if getattr(args, 'finetune_from_model', None) is not None \
       and (reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader")

    suffix = getattr(args, "checkpoint_suffix", "")
    if args.restore_file == "checkpoint_last.pt":  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(args.save_dir,
                                       "checkpoint_last{}.pt".format(suffix))
        first_launch = not PathManager.exists(checkpoint_path)
        if getattr(args, 'finetune_from_model',
                   None) is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from the pretrained model;
            # otherwise use the usual logic to load the checkpoint, e.g. restart from the last checkpoint
            if PathManager.exists(args.finetune_from_model):
                checkpoint_path = args.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f'loading pretrained model from {checkpoint_path}: '
                    'optimizer, lr scheduler, meters, dataloader will be reset'
                )
            else:
                raise ValueError(
                    f'--finetune-from-model {args.finetune_from_model} does not exist'
                )
    elif getattr(args, "model_parallel_size", 1) > 1:
        checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = args.restore_file

    if args.restore_file != "checkpoint_last.pt" and getattr(
            args, 'finetune_from_model', None):
        raise ValueError(
            '--finetune-from-model and --restore-file (non-default value) '
            'can not be specified together: ' + str(args))

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (extra_state is not None and "best" in extra_state
            and not reset_optimizer and not reset_meters):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itrs = trainer.get_train_iterator(epoch=itr_state["epoch"],
                                                load_dataset=True,
                                                **passthrough_args)
        epoch_itrs.load_state_dict(itr_state)
    else:
        epoch_itrs = trainer.get_train_iterator(epoch=1,
                                                load_dataset=True,
                                                **passthrough_args)
    if isinstance(epoch_itrs, list):
        trainer.lr_step(epoch_itrs[0].epoch)
    else:
        trainer.lr_step(epoch_itrs.epoch)

    return extra_state, epoch_itrs
Example #8
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(
        f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(
        epoch, suffix)] = (end_of_epoch and not cfg.no_epoch_checkpoints
                           and epoch % cfg.save_interval == 0)
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(
        epoch, updates,
        suffix)] = (not end_of_epoch and cfg.save_interval_updates > 0
                    and updates % cfg.save_interval_updates == 0)
    checkpoint_conds["checkpoint_best{}.pt".format(
        suffix)] = val_loss is not None and (
            not hasattr(save_checkpoint, "best")
            or is_better(val_loss, save_checkpoint.best))
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
                cfg.best_checkpoint_metric),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace(".pt", ""))
        # add random digits to resolve ties
        rand_sfx = randint(0, cfg.keep_best_checkpoints)
        checkpoint_conds["checkpoint.best_{}_{:.3f}{}.pt".format(
            cfg.best_checkpoint_metric, val_loss,
            rand_sfx)] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds["checkpoint_last{}.pt".format(
        suffix)] = not cfg.no_last_checkpoints

    extra_state = {
        "train_iterator": epoch_itr.state_dict(),
        "val_loss": val_loss
    }
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on.")
            else:
                assert PathManager.copy(
                    checkpoints[0], cp,
                    overwrite=True), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)"
            .format(checkpoints[0], epoch, updates, val_loss, write_timer.sum))

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix))
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0] for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]

        for old_chk in checkpoints[cfg.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix))
        for old_chk in checkpoints[cfg.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
Example #9
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    from fairseq import distributed_utils, meters

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if args.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if args.no_save or not distributed_utils.is_master(args):
        return

    def is_better(a, b):
        return a >= b if args.maximize_best_checkpoint_metric else a <= b

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0)
    checkpoint_conds["checkpoint_{}_{}.pt".format(
        epoch, updates)] = (not end_of_epoch and args.save_interval_updates > 0
                            and updates % args.save_interval_updates == 0)
    checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best))
    if val_loss is not None and args.keep_best_checkpoints > 0:
        checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
            args.best_checkpoint_metric,
            val_loss)] = (not hasattr(save_checkpoint, "best")
                          or is_better(val_loss, save_checkpoint.best))
    checkpoint_conds["checkpoint_last.pt"] = not args.no_last_checkpoints

    extra_state = {
        "train_iterator": epoch_itr.state_dict(),
        "val_loss": val_loss
    }
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            PathManager.copy(checkpoints[0], cp, overwrite=True)

        write_timer.stop()
        logger.info(
            "saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)"
            .format(checkpoints[0], epoch, updates, val_loss, write_timer.sum))

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args.save_dir,
                                       pattern=r"checkpoint_\d+_(\d+)\.pt")
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args.save_dir,
                                       pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            args.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
                args.best_checkpoint_metric))
        if not args.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[args.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
Example #10
def add_file_to_dictionary(filename, dict, tokenize):
    with PathManager.open(filename, "r", encoding="utf-8") as f:
        for line in f:
            for word in tokenize(line):
                dict.add_symbol(word)
            dict.add_symbol(dict.eos_word)
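A minimal sketch of building and saving a vocabulary with the helper above; the corpus and dictionary file names are placeholder assumptions.

from fairseq.data import Dictionary
from fairseq.tokenizer import tokenize_line

d = Dictionary()
add_file_to_dictionary("corpus.en", d, tokenize_line)   # hypothetical corpus file
d.finalize(threshold=-1, nwords=-1, padding_factor=8)   # prune and pad the vocabulary
d.save("dict.en.txt")                                   # hypothetical output path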
Example #11
def main(cfg: FairseqConfig) -> None:
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    if distributed_utils.is_master(
            cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None
        or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    if cfg.common.log_file is not None:
        handler = logging.FileHandler(filename=cfg.common.log_file)
        logger.addHandler(handler)

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            logging.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`")
            return

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        with fsdp_enable_wrap(cfg.distributed_training):
            model = fsdp_wrap(task.build_model(cfg.model))
    else:
        model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    logger.info("num. shared model params: {:,} (num. trained: {:,})".format(
        sum(p.numel() for p in model.parameters()
            if not getattr(p, "expert", False)),
        sum(p.numel() for p in model.parameters()
            if not getattr(p, "expert", False) and p.requires_grad)))

    logger.info("num. expert model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()
            if getattr(p, "expert", False)),
        sum(p.numel() for p in model.parameters()
            if getattr(p, "expert", False) and p.requires_grad),
    ))

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # We load the valid dataset AFTER building the model
    data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
    if cfg.dataset.combine_valid_subsets:
        task.load_dataset("valid", combine=True, epoch=1)
    else:
        for valid_sub_split in cfg.dataset.valid_subset.split(","):
            task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)
    logger.info("training on {} devices (GPUs/TPUs)".format(
        cfg.distributed_training.distributed_world_size))
    logger.info(
        "max tokens per device = {} and max sentences per device = {}".format(
            cfg.dataset.max_tokens,
            cfg.dataset.batch_size,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )
    if cfg.common.tpu:
        import torch_xla.core.xla_model as xm
        xm.rendezvous("load_checkpoint")  # wait for all workers

    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()

    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})")
            break

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish.")
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")
Example #12
File: trainer.py, Project: Fei-WL/CCMT
    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """
        Load all training state from a checkpoint file.
        rank = 0 will load the checkpoint, and then broadcast it to all
        other ranks.
        """
        extra_state, self._optim_history, last_optim_state = None, [], None

        logger.info(f"Preparing to load checkpoint {filename}")
        is_distributed = self.data_parallel_world_size > 1
        bexists = PathManager.isfile(filename)
        if bexists:
            load_on_all_ranks = (
                self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks
                # TPUs don't support broadcast yet, so load checkpoints
                # on every worker for now
                or self.tpu)

            if load_on_all_ranks or self.data_parallel_rank == 0:
                state = checkpoint_utils.load_checkpoint_to_cpu(
                    filename, load_on_all_ranks=load_on_all_ranks)
                last_optim_state = state.get("last_optimizer_state", None)

                # If doing zero_sharding, do not broadcast global optimizer
                # state. Later we will broadcast sharded states to each rank
                # to avoid memory from exploding.
                if (not load_on_all_ranks
                        and self.cfg.distributed_training.zero_sharding == "os"
                        and "last_optimizer_state" in state
                        and is_distributed):
                    state["last_optimizer_state"] = "SHARDED"
            else:
                last_optim_state = None
                state = None

            if is_distributed and not load_on_all_ranks:
                state = distributed_utils.broadcast_object(
                    state,
                    src_rank=0,
                    group=self.data_parallel_process_group,
                    dist_device=self.device,
                )
                if self.data_parallel_rank > 0:
                    last_optim_state = state.get("last_optimizer_state", None)

            # load model parameters
            try:
                # model_dict = self.model.state_dict()
                # 1. filter out unnecessary keys
                # pretrained_dict = {k: v for k, v in state["model"].items() if k in model_dict
                #                    and v.size() == model_dict[k].size()}
                # 2. overwrite entries in the existing state dict
                # model_dict.update(pretrained_dict)
                # 3. load the new state dict
                # self.model.load_state_dict(
                #     model_dict, strict=False, model_cfg=self.cfg.model
                # )
                self.model.load_state_dict(state["model"],
                                           strict=True,
                                           model_cfg=self.cfg.model)
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(state["criterion"],
                                                         strict=True)
            except Exception:
                raise Exception(
                    "Cannot load model parameters from checkpoint {}; "
                    "please ensure that the architectures match.".format(
                        filename))
            extra_state = state["extra_state"]
            self._optim_history = state["optimizer_history"]

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert (
                last_optim["criterion_name"] ==
                self.get_criterion().__class__.__name__
            ), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}"
            assert (
                last_optim["optimizer_name"] ==
                self.optimizer.__class__.__name__
            ), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}"

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(
                    last_optim["lr_scheduler_state"])

            if not load_on_all_ranks and is_distributed:
                last_optim_state = self.optimizer.broadcast_global_state_dict(
                    last_optim_state)
            self.optimizer.load_state_dict(last_optim_state,
                                           optimizer_overrides)

            self.set_num_updates(last_optim["num_updates"])

        if extra_state is not None:
            itr_state = extra_state["train_iterator"]
            epoch = itr_state["epoch"]

            if "previous_training_time" in extra_state:
                self._previous_training_time = extra_state[
                    "previous_training_time"]
                self._start_time = time.time()

            self.lr_step(epoch)

            if itr_state.get("version",
                             1) >= 2 and itr_state["iterations_in_epoch"] == 0:
                # reset meters at start of epoch
                reset_meters = True

            if "metrics" in extra_state and not reset_meters:
                metrics.load_state_dict(extra_state["metrics"])

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in metrics.get_meters("default"):
                    if isinstance(meter, meters.TimeMeter):
                        meter.reset()

            logger.info("Loaded checkpoint {} (epoch {} @ {} updates)".format(
                filename, epoch, self.get_num_updates()))

        else:
            logger.info("No existing checkpoint found {}".format(filename))

        return extra_state
Example #13
def main(args):
    state = checkpoint_utils.load_checkpoint_to_cpu(args.checkpoint)
    ns = state["args"]
    model = state["model"]
    ns.arch = "transformer_modular"

    if (args.encoder_attention_heads_active is None
            and args.decoder_attention_heads_active is None):
        raise ValueError(
            'Either --encoder-attention-heads-active or '
            '--decoder-attention-heads-active option must be set.')
    if args.encoder_attention_heads_active is None:
        args.encoder_attention_heads_active = args.decoder_attention_heads_active

    if args.encoder_modular_layer_indices is not None:
        ns.encoder_modular_layer_indices = "({})".format(
            args.encoder_modular_layer_indices)
        model = convert_model(model, ns, coder="encoder", att_type="self_attn")
    if args.decoder_modular_layer_indices is not None:
        ns.decoder_modular_layer_indices = "({})".format(
            args.decoder_modular_layer_indices)
        model = convert_model(model, ns, coder="decoder", att_type="self_attn")
        model = convert_model(model,
                              ns,
                              coder="decoder",
                              att_type="encoder_attn")

    ctrl_enc = ModularCtrl(ns.encoder_embed_dim,
                           ns.encoder_attention_heads,
                           args.encoder_attention_heads_active,
                           hidden_depth=args.ctrl_hidden_depth,
                           hidden_dim=args.ctrl_hidden_dim,
                           ctrl_type=args.ctrl_type)
    ns.module_ctrl_hidden_depth = args.ctrl_hidden_depth
    ns.module_ctrl_hidden_dim = args.ctrl_hidden_dim
    ns.module_ctrl_type = args.ctrl_type

    for k, v in ctrl_enc.state_dict().items():
        model["encoder.module_ctrl.{}".format(k)] = v

    if not args.share_encoder_ctrl:
        if args.decoder_attention_heads_active is None:
            raise ValueError("Missing ``decoder-attention-heads-active'' "
                             "when ``share-encoder-ctrl'' is disabled.")
        ns.share_encoder_ctrl = False
        ctrl_dec = ModularCtrl(ns.decoder_embed_dim,
                               ns.decoder_attention_heads,
                               args.decoder_attention_heads_active,
                               hidden_depth=args.ctrl_hidden_depth,
                               hidden_dim=args.ctrl_hidden_dim,
                               ctrl_type=args.ctrl_type)
        for k, v in ctrl_dec.state_dict().items():
            model["decoder.module_ctrl.{}".format(k)] = v
    else:
        ns.share_encoder_ctrl = True

    ns.arch = "transformer_modular"
    ns.criterion = "label_smoothed_cross_entropy_modular"
    ns.task = "translation_modular"
    ns.encoder_attention_heads_active = args.encoder_attention_heads_active

    state["args"] = ns
    state["model"] = model

    for i, _ in enumerate(state["optimizer_history"]):
        state["optimizer_history"][i][
            "criterion_name"] = 'LabelSmoothedCrossEntropyModularCriterion'

    state = utils.move_to_cpu(state)

    with PathManager.open(args.save_as, "wb") as f:
        checkpoint_utils.torch_persistent_save(state, f)
Example #14
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """

    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (reset_optimizer
                                                or reset_lr_scheduler
                                                or reset_meters
                                                or reset_dataloader):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader")

    suffix = trainer.checkpoint_suffix
    if cfg.restore_file == "checkpoint_last.pt":
        # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(cfg.save_dir,
                                       "checkpoint_last{}.pt".format(suffix))
        first_launch = not PathManager.exists(checkpoint_path)
        if cfg.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from the pretrained model;
            # otherwise use the usual logic to load the checkpoint, e.g. restart from the last checkpoint
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--funetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg))

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (extra_state is not None and "best" in extra_state
            and not reset_optimizer and not reset_meters):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(epoch=itr_state["epoch"],
                                               load_dataset=True,
                                               **passthrough_args)
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(epoch=1,
                                               load_dataset=True,
                                               **passthrough_args)

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr
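As the code shows, cfg.optimizer_overrides is parsed with ast.literal_eval, so it is expected to be a Python-literal dict string such as "{'lr': 0.0005}". Below is a minimal usage sketch; the override value and the logging line are illustrative assumptions.

# assumed: cfg.checkpoint.optimizer_overrides was set to "{'lr': 0.0005}" on the CLI
extra_state, epoch_itr = load_checkpoint(cfg.checkpoint, trainer)
if extra_state is not None:
    logger.info("resumed training at epoch %d", epoch_itr.epoch)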
Example #15
    def exists(path):
        return PathManager.exists(path)
Example #16
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state[
            "args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:

        # hack to be able to set Namespace in dict config. this should be removed when we update to newer
        # omegaconf version that supports object flags, or when we migrate all existing models
        from omegaconf import _utils

        old_primitive = _utils.is_primitive_type
        _utils.is_primitive_type = lambda _: True

        state["cfg"] = OmegaConf.create(state["cfg"])

        _utils.is_primitive_type = old_primitive
        OmegaConf.set_struct(state["cfg"], True)

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
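A minimal sketch of reading a checkpoint onto the CPU and overriding a stored option; the checkpoint path and the override key are assumptions for illustration only.

state = load_checkpoint_to_cpu(
    "checkpoints/checkpoint_last.pt",              # hypothetical checkpoint path
    arg_overrides={"data": "data-bin/my_corpus"},  # hypothetical override
)
# per save_state() above, the dict typically contains "model", "extra_state",
# "optimizer_history" and, unless disabled, "last_optimizer_state"
model_weights = state["model"]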
Example #17
    def exists(path):
        return PathManager.exists(index_file_path(path)) and PathManager.exists(
            data_file_path(path)
        )
Example #18
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards)

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(
                    state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale")
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    model.load_state_dict(consolidated_model_state,
                                          strict=strict,
                                          model_cfg=cfg.model)
            else:
                # model parallel checkpoint or unsharded checkpoint
                model = task.build_model(cfg.model)
                model.load_state_dict(state["model"],
                                      strict=strict,
                                      model_cfg=cfg.model)

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        # the model was built in the shard loop above; add it to the ensemble
        ensemble.append(model)
    return ensemble, cfg, task
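To illustrate the interface (not part of the original source), a minimal sketch that loads two hypothetical checkpoints as an ensemble, overrides the data path, and puts the models in evaluation mode:

# Minimal usage sketch; checkpoint paths and the "data" override are hypothetical.
models, cfg, task = load_model_ensemble_and_task(
    ["checkpoints/model1.pt", "checkpoints/model2.pt"],
    arg_overrides={"data": "/path/to/data-bin"},
)
for model in models:
    model.eval()  # disable dropout etc. for inference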
Example #19
def upgrade_state_dict_with_infoxlm_weights(
        state_dict: Dict[str, Any],
        pretrained_infoxlm_checkpoint: str,
        num_layers: int,
        shared_cross_attn: bool = False) -> Dict[str, Any]:
    """
    Load XLM weights into a Transformer encoder or decoder model.

    Args:
        state_dict: state dict for either TransformerEncoder or
            TransformerDecoder
        pretrained_infoxlm_checkpoint: checkpoint to load XLM weights from

    Raises:
        AssertionError: If architecture (num layers, attention heads, etc.)
            does not match between the current Transformer encoder or
            decoder and the pretrained_xlm_checkpoint
    """
    if not os.path.exists(pretrained_infoxlm_checkpoint):
        raise IOError(
            "Model file not found: {}".format(pretrained_infoxlm_checkpoint))

    # state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_infoxlm_checkpoint)
    with open(PathManager.get_local_path(pretrained_infoxlm_checkpoint),
              "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))
    infoxlm_state_dict = state["model"]

    for key in infoxlm_state_dict.keys():
        # Skip pretrained layers deeper than the target model.
        if 'layers' in key and int(key.split('.')[3]) > num_layers - 1:
            continue
        if not key.startswith('decoder.'):
            continue
        if 'lm_head' in key:
            continue

        new_key = key.replace('decoder.sentence_encoder.', '')

        if 'in_proj_weight' in key:
            # Split the fused QKV projection into separate q/k/v weights.
            q, k, v = infoxlm_state_dict[key].chunk(3, dim=0)
            state_dict[new_key.replace('in_proj_weight', 'q_proj.weight')] = q
            state_dict[new_key.replace('in_proj_weight', 'k_proj.weight')] = k
            state_dict[new_key.replace('in_proj_weight', 'v_proj.weight')] = v
            if shared_cross_attn:
                # Also initialize cross-attention from the same weights.
                state_dict[new_key.replace('in_proj_weight', 'q_proj.weight')
                           .replace('self_attn', 'encoder_attn')] = q
                state_dict[new_key.replace('in_proj_weight', 'k_proj.weight')
                           .replace('self_attn', 'encoder_attn')] = k
                state_dict[new_key.replace('in_proj_weight', 'v_proj.weight')
                           .replace('self_attn', 'encoder_attn')] = v
        elif 'in_proj_bias' in key:
            q, k, v = infoxlm_state_dict[key].chunk(3, dim=0)
            state_dict[new_key.replace('in_proj_bias', 'q_proj.bias')] = q
            state_dict[new_key.replace('in_proj_bias', 'k_proj.bias')] = k
            state_dict[new_key.replace('in_proj_bias', 'v_proj.bias')] = v
            if shared_cross_attn:
                state_dict[new_key.replace('in_proj_bias', 'q_proj.bias')
                           .replace('self_attn', 'encoder_attn')] = q
                state_dict[new_key.replace('in_proj_bias', 'k_proj.bias')
                           .replace('self_attn', 'encoder_attn')] = k
                state_dict[new_key.replace('in_proj_bias', 'v_proj.bias')
                           .replace('self_attn', 'encoder_attn')] = v
        elif 'emb_layer_norm' in key:
            state_dict[key.replace('decoder.sentence_encoder.emb_layer_norm',
                                   'layernorm_embedding')] = infoxlm_state_dict[key]
        elif 'embed_positions' in key:
            # Truncate pretrained positional embeddings to the target length.
            state_dict[new_key] = infoxlm_state_dict[key][:state_dict[new_key].size(0)]
        elif 'embed_tokens' in key:
            # Copy pretrained token embeddings into the (possibly larger)
            # target embedding table.
            state_dict[new_key][:infoxlm_state_dict[key].size(0)] = infoxlm_state_dict[key]
        else:
            state_dict[new_key] = infoxlm_state_dict[key]

    return state_dict
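As a rough illustration (not from the original source), the function above could be applied to a freshly built decoder's state dict before training; `decoder` is an assumed fairseq TransformerDecoder instance and the checkpoint path is hypothetical.

# Hypothetical call site: initialize a decoder from an InfoXLM checkpoint.
state_dict = decoder.state_dict()  # `decoder` is an assumed TransformerDecoder
state_dict = upgrade_state_dict_with_infoxlm_weights(
    state_dict=state_dict,
    pretrained_infoxlm_checkpoint="/path/to/infoxlm.pt",  # hypothetical path
    num_layers=len(decoder.layers),
    shared_cross_attn=True,
)
decoder.load_state_dict(state_dict, strict=True)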