Example #1
def load_model_ensemble_and_task(filenames,
                                 arg_overrides=None,
                                 task=None,
                                 strict=True,
                                 suffix="",
                                 num_shards=1):
    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    for filename in filenames:
        orig_filename = filename
        for shard_idx in range(num_shards):
            if num_shards == 1:
                filename = filename.replace(".pt", suffix + ".pt")
            else:
                filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            state = load_checkpoint_to_cpu(filename, arg_overrides)
            if shard_idx == 0:
                args = state["args"]
                if task is None:
                    task = tasks.setup_task(args)

                # build model for ensemble
                model = task.build_model(args)
            model.load_state_dict(state["model"], strict=strict, args=args)
        ensemble.append(model)
    return ensemble, args, task
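A minimal usage sketch for the loader above (the checkpoint paths and the override dict are hypothetical, not from the source):

ensemble, args, task = load_model_ensemble_and_task(
    ["checkpoints/model1.pt", "checkpoints/model2.pt"],  # hypothetical paths
    arg_overrides={"data": "data-bin/example"},          # hypothetical override
)
for model in ensemble:
    model.eval()  # ensembles are typically used for inference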
Example #2
def load_pretrained_component_from_model(component: Union[FairseqEncoder,
                                                          FairseqDecoder],
                                         checkpoint: str):
    """
    Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
    provided `component` object. If state_dict fails to load, there may be a
    mismatch in the architecture of the corresponding `component` found in the
    `checkpoint` file.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    if isinstance(component, FairseqEncoder):
        component_type = "encoder"
    elif isinstance(component, FairseqDecoder):
        component_type = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types is not supported.")
    component_state_dict = OrderedDict()
    for key in state["model"].keys():
        if key.startswith(component_type):
            # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
            component_subkey = key[len(component_type) + 1:]
            component_state_dict[component_subkey] = state["model"][key]
    component.load_state_dict(component_state_dict, strict=True)
    return component
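A hedged usage sketch, assuming a task and args from the surrounding training setup; the checkpoint path is hypothetical:

model = task.build_model(args)  # any model exposing .encoder / .decoder
model.encoder = load_pretrained_component_from_model(
    component=model.encoder,
    checkpoint="checkpoints/pretrained.pt",  # hypothetical path
)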
Example #3
def load_model_ensemble_and_task(filenames,
                                 arg_overrides=None,
                                 task=None,
                                 strict=True,
                                 suffix=''):
    from fairseq import tasks

    ensemble = []
    for filename in filenames:
        filename = filename.replace(".pt", suffix + ".pt")
        if not PathManager.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        state = load_checkpoint_to_cpu(filename, arg_overrides)

        args = state["args"]

        logger.info('[load_model_ensemble_and_task[data]:] {}'.format(
            args.data))

        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=strict, args=args)
        ensemble.append(model)
    return ensemble, args, task
Example #4
def load_model_ensemble_and_task(filenames,
                                 arg_overrides=None,
                                 task=None,
                                 strict=True,
                                 suffix=''):
    from fairseq import tasks

    ensemble = []
    for filename in filenames:
        filename = filename.replace(".pt", suffix + ".pt")
        if not PathManager.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        state = load_checkpoint_to_cpu(filename, arg_overrides)

        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        states = state["model"]
        if hasattr(args, 'mixout') and args.mixout > 0:
            for k, v in list(states.items()):
                if '._params_learned' in k:
                    del states[k]
                    states[k.replace('._params_learned', '')] = v

        model.load_state_dict(states, strict=strict, args=args)
        ensemble.append(model)
    return ensemble, args, task
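The mixout branch above renames keys such as "encoder.fc1._params_learned.weight" back to "encoder.fc1.weight" before loading. A standalone sketch of that transform (key names are illustrative):

from collections import OrderedDict

states = OrderedDict([
    ("encoder.fc1._params_learned.weight", "tensor-a"),
    ("encoder.fc2.weight", "tensor-b"),
])
for k, v in list(states.items()):
    if '._params_learned' in k:
        del states[k]
        states[k.replace('._params_learned', '')] = v
print(list(states))  # ['encoder.fc2.weight', 'encoder.fc1.weight']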
Example #5
def load_bert_state(model, checkpoint):
    print('Load pretrained data augmentation checkpoint (BERT)')
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))

    from torch.serialization import default_restore_location
    state = torch.load(
        checkpoint,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))

    def upgrade(obj):
        # Strip the leading "encoder." prefix from every key, recursing into
        # nested OrderedDicts. Returning obj (rather than falling through to
        # None) keeps nested dictionaries intact, and deleting only renamed
        # keys avoids dropping a key named exactly "encoder".
        if isinstance(obj, OrderedDict):
            oldkeys = list(obj.keys())
            for k in oldkeys:
                if k.startswith('encoder') and k != 'encoder':
                    newkey = k.split('.', 1)[1]
                else:
                    newkey = k
                obj[newkey] = upgrade(obj[k])
                if k != newkey:
                    del obj[k]
        return obj

    upgrade(state['model'])
    try:
        model.load_state_dict(state['model'], strict=True)
    except Exception as e:
        raise Exception(
            'Cannot load model parameters from pretrained augmentation model checkpoint, '
            'please ensure that the architectures match') from e
    return True
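With the corrected upgrade above (assuming it were lifted to module scope), a toy state dict shows the prefix stripping; keys are illustrative:

from collections import OrderedDict

sd = OrderedDict([('encoder.layer.weight', 1), ('decoder.weight', 2)])
upgrade(sd)
print(list(sd))  # ['decoder.weight', 'layer.weight'] -- prefix removed, renamed key moves last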
Example #6
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        assert num_shards > 0
        for shard_idx in range(num_shards):
            if num_shards == 1:
                filename = filename.replace(".pt", suffix + ".pt")
            else:
                filename = orig_filename[:-3] + f"_part{shard_idx}.pt"

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            # build model for ensemble
            model = task.build_model(cfg.model)

            model.load_state_dict(state["model"],
                                  strict=strict,
                                  model_cfg=cfg.model)

            # reset state so it gets loaded for the next model in ensemble
            state = None

        ensemble.append(model)
    return ensemble, cfg, task
Example #7
def load_xlmt_model_ensemble(filenames,
                             arg_overrides=None,
                             strict=True,
                             suffix="",
                             num_shards=1,
                             state=None,
                             src_dict=None,
                             tgt_dict=None):
    assert state is None or len(filenames) == 1
    from fairseq.models.xlmt_decoder_variant import XLMTDecoderVariantModel

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        assert num_shards > 0
        for shard_idx in range(num_shards):
            if num_shards == 1:
                filename = filename.replace(".pt", suffix + ".pt")
            else:
                filename = orig_filename[:-3] + f"_part{shard_idx}.pt"

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            # build model for ensemble
            model = XLMTDecoderVariantModel.build_model_without_task(
                cfg.model, src_dict, tgt_dict)

            state = expand_embedding_matrix(state, model)
            model.load_state_dict(state["model"],
                                  strict=strict,
                                  model_cfg=cfg.model)

            # reset state so it gets loaded for the next model in ensemble
            state = None

        ensemble.append(model)
    return ensemble, cfg
Example #8
def get_maybe_sharded_checkpoint_filename(
    filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
    orig_filename = filename
    filename = filename.replace(".pt", suffix + ".pt")
    fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
    model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
    if PathManager.exists(fsdp_filename):
        return fsdp_filename
    elif num_shards > 1:
        return model_parallel_filename
    else:
        return filename
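A usage sketch showing the three layouts the helper can resolve; the paths are hypothetical:

# With filename="ckpt/model.pt", suffix="-rank0", shard_idx=2, num_shards=4:
#   FSDP layout:            ckpt/model-rank0-shard2.pt  (returned if it exists)
#   model-parallel layout:  ckpt/model_part2.pt         (returned if num_shards > 1)
#   unsharded fallback:     ckpt/model-rank0.pt
name = get_maybe_sharded_checkpoint_filename("ckpt/model.pt", "-rank0", 2, 4)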
Example #9
def load_feature_extractor(component, checkpoint):
    if not PathManager.exists(checkpoint):
        raise IOError(
            "Model file not found: {}".format(checkpoint))
    state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
    component_state_dict = OrderedDict()

    component_prefix = "feature_extractor"
    for key in state["model"].keys():
        if key.startswith(component_prefix):
            component_subkey = key[len(component_prefix) + 1:]
            component_state_dict[component_subkey] = state[
                "model"][key]
    component.load_state_dict(component_state_dict, strict=True)
    return component
Example #10
def load_pretrained_speech_text_components(cls, checkpoint,
                                           component_pairs):
    # note: `cls` suggests this was extracted from a classmethod; it is
    # unused in the body
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    for component_type, component in component_pairs:
        if isinstance(component, nn.parameter.Parameter):
            component.data.copy_(state["model"][component_type])
        else:
            component_state_dict = OrderedDict()
            for key in state["model"].keys():
                if key.startswith(component_type):
                    component_subkey = key[len(component_type) + 1:]
                    component_state_dict[component_subkey] = state[
                        "model"][key]
            component.load_state_dict(component_state_dict, strict=True)
    return state
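A hedged usage sketch; the component pairs are hypothetical and deliberately mix a prefix-matched module with a bare Parameter, matching the two branches above:

state = load_pretrained_speech_text_components(
    None,  # cls placeholder; unused in the body
    "checkpoints/speech_text.pt",  # hypothetical path
    [
        ("encoder", model.encoder),  # prefix-matched module
        ("decoder.embed_tokens.weight",
         model.decoder.embed_tokens.weight),  # single Parameter, copied directly
    ],
)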
Example #11
def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None):
    from fairseq import tasks

    ensemble = []
    for filename in filenames:
        if not PathManager.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        state = load_checkpoint_to_cpu(filename, arg_overrides)

        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=True, args=args)
        ensemble.append(model)
    return ensemble, args, task
Example #12
def exists(prefix_path):
    return (
        PathManager.exists(indexed_dataset.index_file_path(prefix_path))
        and PathManager.exists(indexed_dataset.data_file_path(prefix_path))
        and PathManager.exists(vocab_file_path(prefix_path)))
Example #13
def exists(path):
    return PathManager.exists(
        index_file_path(path)) and PathManager.exists(data_file_path(path))
Example #14
def exists(path):
    return PathManager.exists(path)
Example #15
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    if (
                        "optimizer_history" in state
                        and len(state["optimizer_history"]) > 0
                        and "num_updates" in state["optimizer_history"][-1]
                    ):
                        model.set_num_updates(
                            state["optimizer_history"][-1]["num_updates"]
                        )
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                # support old external tasks

                argspec = inspect.getfullargspec(task.build_model)
                if "from_checkpoint" in argspec.args:
                    model = task.build_model(cfg.model, from_checkpoint=True)
                else:
                    model = task.build_model(cfg.model)
                if (
                    "optimizer_history" in state
                    and len(state["optimizer_history"]) > 0
                    and "num_updates" in state["optimizer_history"][-1]
                ):
                    model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task
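A hedged invocation sketch for the sharded path (path and shard count are assumptions; the assertion at the top requires strict=False once num_shards > 1):

ensemble, cfg, task = load_model_ensemble_and_task(
    ["checkpoints/model.pt"],  # hypothetical path
    strict=False,              # required when num_shards > 1
    num_shards=4,              # e.g. a checkpoint written by 4 FSDP workers
)
model = ensemble[0]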
Example #16
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
        # add random digits to resolve ties
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)

        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]

        for old_chk in checkpoints[cfg.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        for old_chk in checkpoints[cfg.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)
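The conditions above reduce to a small set of candidate file names; a summary sketch with illustrative values (epoch=3, updates=12000, metric="loss", val_loss=1.234, suffix=""):

# checkpoint3.pt                  end of epoch, every cfg.save_interval epochs
# checkpoint_3_12000.pt           mid-epoch, every cfg.save_interval_updates updates
# checkpoint_best.pt              val_loss improved on the stored best
# checkpoint.best_loss_1.234R.pt  R = random tie-break digit; keep_best_checkpoints slots
# checkpoint_last.pt              unless cfg.no_last_checkpoints is set
# trainer.save_checkpoint writes the first matching name; the rest are copies of it.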
Example #17
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """

    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = trainer.checkpoint_suffix
    if (
        cfg.restore_file == "checkpoint_last.pt"
    ):  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(
            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if first_launch and getattr(cfg, "continue_once", None) is not None:
            checkpoint_path = cfg.continue_once
        elif cfg.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr
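A hedged call-site sketch at the start of training; the cfg and trainer construction are assumed from the surrounding fairseq setup:

extra_state, epoch_itr = load_checkpoint(cfg.checkpoint, trainer)
if extra_state is not None:
    logger.info("resumed training from epoch {}".format(epoch_itr.epoch))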
Example #18
def load_checkpoint(args, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """
    reset_optimizer = args.reset_optimizer
    reset_lr_scheduler = args.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(args.optimizer_overrides)  # literal_eval is safer than bare eval
    reset_meters = args.reset_meters
    reset_dataloader = args.reset_dataloader

    if getattr(args, 'finetune_from_model', None) is not None \
       and (reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader")

    suffix = getattr(args, "checkpoint_suffix", "")
    if args.restore_file == "checkpoint_last.pt":  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(args.save_dir,
                                       "checkpoint_last{}.pt".format(suffix))
        first_launch = not PathManager.exists(checkpoint_path)
        if getattr(args, 'finetune_from_model',
                   None) is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(args.finetune_from_model):
                checkpoint_path = args.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f'loading pretrained model from {checkpoint_path}: '
                    'optimizer, lr scheduler, meters, dataloader will be reset'
                )
            else:
                raise ValueError(
                    f'--finetune-from-model {args.finetune_from_model} does not exist'
                )
    elif getattr(args, "model_parallel_size", 1) > 1:
        checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = args.restore_file

    if args.restore_file != "checkpoint_last.pt" and getattr(
            args, 'finetune_from_model', None):
        raise ValueError(
            '--finetune-from-model and --restore-file (non-default value) '
            'can not be specified together: ' + str(args))

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (extra_state is not None and "best" in extra_state
            and not reset_optimizer and not reset_meters):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itrs = trainer.get_train_iterator(epoch=itr_state["epoch"],
                                                load_dataset=True,
                                                **passthrough_args)
        epoch_itrs.load_state_dict(itr_state)
    else:
        epoch_itrs = trainer.get_train_iterator(epoch=1,
                                                load_dataset=True,
                                                **passthrough_args)
    if isinstance(epoch_itrs, list):
        trainer.lr_step(epoch_itrs[0].epoch)
    else:
        trainer.lr_step(epoch_itrs.epoch)

    return extra_state, epoch_itrs
Example #19
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                model = task.build_model(cfg.model)
                new_state_model = state["model"]

                # ===== The following if/else block is a work-around for the
                # current metadata loading/saving behavior of PyTorch.
                # If state["model"]["_metadata"] exists as a dictionary key,
                # model.load_state_dict(strict=True) raises an error for the
                # unexpected "_metadata" key. To avoid this, the state dict
                # must be an OrderedDict that carries _metadata as an
                # attribute (new_state_model._metadata), not as a key.
                # TODO yuansg@ This issue should ideally be fixed in PyTorch.
                if new_state_model.get("_metadata", None) is not None:
                    new_metadata = new_state_model.get("_metadata", None)
                    del state["model"]["_metadata"]
                else:
                    new_metadata = None
                # Construct the state dict contents.
                contents = OrderedDict(new_state_model)
                # Explicitly set _metadata on the state dict; PyTorch stores
                # _metadata implicitly for its own modules, but indexing
                # state["model"] in fairseq does not carry it over.
                if new_metadata is None:
                    logger.warning('===Jit: state["model"] does not contain key "_metadata"=====')
                    logger.warning("===Jit: filling in with the current model's metadata instead.")
                    # For models trained before this change, fall back to the
                    # freshly built model's metadata for backward compatibility.
                    contents._metadata = model.state_dict()._metadata
                else:
                    contents._metadata = new_metadata
                # ==== End of work-around logic ====

                model.load_state_dict(
                    contents, strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task
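The work-around above relies on OrderedDict instances accepting arbitrary attributes (plain dicts do not), which is how PyTorch attaches _metadata to state dicts. A minimal demonstration:

from collections import OrderedDict

contents = OrderedDict({"weight": "tensor"})
contents._metadata = {"": {"version": 1}}  # stored as an attribute, not a key
print("_metadata" in contents)  # False -> strict loading sees no unexpected key
print(contents._metadata)       # {'': {'version': 1}}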