Code example #1
def load_dual_checkpoint_for_inference(args, arg_overrides=None, task=None):
    model = task.build_model(args)

    f_checkpoint_path, b_checkpoint_path = args.path.split(':')

    f_state = torch.load(
        f_checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    model_f_dict = model.modelf.state_dict()
    base_dict = {
        k: v
        for k, v in f_state['model'].items() if k in model_f_dict
    }
    model_f_dict.update(base_dict)
    model.modelf.load_state_dict(model_f_dict)

    b_state = torch.load(
        b_checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    model_b_dict = model.modelb.state_dict()
    base_dict = {
        k: v
        for k, v in b_state['model'].items() if k in model_b_dict
    }
    model_b_dict.update(base_dict)
    model.modelb.load_state_dict(model_b_dict)
    print("load dual transformer finished.")

    return [model]
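Every snippet on this page revolves around the same core pattern: torch.load with a map_location callback that routes each deserialized storage through torch.serialization.default_restore_location, forcing it onto a chosen device. The snippets are file excerpts and assume torch, os, and default_restore_location are imported at module level. A minimal self-contained sketch of the pattern (the checkpoint path is a placeholder):

import torch
from torch.serialization import default_restore_location

# 's' is the deserialized storage and 'l' the location tag recorded at save
# time (e.g. 'cuda:0'); restoring to 'cpu' makes the load work on GPU-less hosts.
state = torch.load(
    'checkpoint_last.pt',  # placeholder path
    map_location=lambda s, l: default_restore_location(s, 'cpu'))
print(state.keys())

In the loader above, args.path packs the forward and backward checkpoint paths into a single string joined by ':'.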
Code example #2
def load_nmt_state(model, checkpoint):
    print('Load pretrained data augmentation checkpoint (Transformer)')
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))

    print('load nmt encoder...')
    from torch.serialization import default_restore_location
    state = torch.load(
        checkpoint,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))

    def upgrade_encoder(obj):
        # strip the leading 'encoder.' prefix so the keys match model.encoder,
        # and drop all decoder parameters
        if isinstance(obj, OrderedDict):
            oldkeys = list(obj.keys())
            for k in oldkeys:
                if k.startswith('encoder') and k != 'encoder':
                    newkey = k.split('.', 1)[1]
                else:
                    newkey = k
                obj[newkey] = upgrade_encoder(obj[k])
                if k.startswith('encoder') or k.startswith('decoder'):
                    del obj[k]
        return obj

    upgrade_encoder(state['model'])
    try:
        model.encoder.load_state_dict(state['model'], strict=True)
    except Exception:
        raise Exception(
            'Cannot load nmt encoder parameters from pretrained back translation checkpoint, '
            'please ensure that the architectures match')
    print('Load nmt decoder ...')
    state = torch.load(
        checkpoint,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))

    def upgrade_decoder(obj):
        # strip the leading 'decoder.' prefix so the keys match model.decoder,
        # and drop all encoder parameters
        if isinstance(obj, OrderedDict):
            oldkeys = list(obj.keys())
            for k in oldkeys:
                if k.startswith('decoder') and k != 'decoder':
                    newkey = k.split('.', 1)[1]
                else:
                    newkey = k
                obj[newkey] = upgrade_decoder(obj[k])
                if k.startswith('encoder') or k.startswith('decoder'):
                    del obj[k]
        return obj

    upgrade_decoder(state['model'])
    try:
        model.decoder.load_state_dict(state['model'], strict=True)
    except Exception:
        raise Exception(
            'Cannot load nmt decoder parameters from pretrained back translation checkpoint, '
            'please ensure that the architectures match')
    return True
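To make the key rewriting concrete, here is a toy run of upgrade_encoder (a sketch assuming the helper above is in scope; the values are stand-in strings rather than tensors):

from collections import OrderedDict

sd = OrderedDict([
    ('encoder.layers.0.weight', 'W_enc'),   # prefix gets stripped
    ('decoder.layers.0.weight', 'W_dec'),   # dropped entirely
])
upgrade_encoder(sd)
print(list(sd.keys()))  # -> ['layers.0.weight']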
Code example #3
def load_dual_checkpoint(args, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """
    # only one worker should attempt to create the required dir
    if args.distributed_rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)

    f_checkpoint_path = os.path.join(args.save_dir, "checkpoint_last_f.pt")
    b_checkpoint_path = os.path.join(args.save_dir, "checkpoint_last_b.pt")

    f_state = torch.load(
        f_checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    model_f_dict = trainer.model.modelf.state_dict()
    base_dict = {
        k: v
        for k, v in f_state['model'].items() if k in model_f_dict
    }
    model_f_dict.update(base_dict)
    trainer.model.modelf.load_state_dict(model_f_dict)

    # freeze everything in the forward model except the 'suphead' parameters
    for name, param in trainer.model.modelf.named_parameters():
        if 'suphead' in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    b_state = torch.load(
        b_checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    model_b_dict = trainer.model.modelb.state_dict()
    base_dict = {
        k: v
        for k, v in b_state['model'].items() if k in model_b_dict
    }
    model_b_dict.update(base_dict)
    trainer.model.modelb.load_state_dict(model_b_dict)
    # likewise freeze the backward model, leaving only 'suphead' trainable
    for name, param in trainer.model.modelb.named_parameters():
        if 'suphead' in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    print("load dual transformer finished.")
    # import pdb;pdb.set_trace()

    epoch_itr = trainer.get_train_iterator(epoch=0,
                                           load_dataset=True,
                                           **passthrough_args)
    extra_state = None
    return extra_state, epoch_itr
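Because everything except the supervision heads is frozen here, a quick sanity check after loading can confirm that only 'suphead' parameters remain trainable (an illustrative sketch, not part of the original code):

trainable = [name for name, p in trainer.model.named_parameters()
             if p.requires_grad]
assert all('suphead' in name for name in trainable), trainable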
Code example #4
def load_model_state(filename, model, data_parallel=True):
    if not os.path.exists(filename):
        print("Starting training from scratch.")
        return 0

    print("Loading model from checkpoints", filename)
    state = torch.load(
        filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))

    from collections import OrderedDict
    new_state_dict = OrderedDict()
    # create new OrderedDict that does not contain `module.`
    if data_parallel:
        for k, v in state['model'].items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
    else:
        new_state_dict = state['model']
    # load model parameters
    try:
        model.load_state_dict(new_state_dict)
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')
    return state['num_updates']
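The slice k[7:] assumes every key carries the 'module.' prefix that nn.DataParallel prepends when it wraps a model. A more defensive variant (a sketch, not from the original project) strips the prefix only where it is present, so the same helper accepts checkpoints saved with or without the wrapper:

def strip_data_parallel_prefix(state_dict):
    # drop a leading 'module.' (added by nn.DataParallel) when present
    return {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in state_dict.items()
    }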
Code example #5
def load_model_state(filename):
    if not os.path.exists(filename):
        return None
    state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)

    return state
Code example #6
File: utils.py, Project: shamilcm/crosentgec-1
def load_model_state(filename, model):
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(
        filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        if state['args'].arch == 'convlm':  # fix parameter name mismatch
            for paramname in list(state['model'].keys()):
                state['model'][paramname.replace(
                    'layers', 'convolutions')] = state['model'].pop(paramname)
        model_state = model.state_dict()
        print('| mismatched parameters: {}'.format(
            set(model_state.keys()) ^ set(state['model'].keys())))
        model_state.update(state['model'])
        model.load_state_dict(model_state)
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return state['extra_state'], state['optimizer_history'], state[
        'last_optimizer_state']
Code example #7
def load_ensemble_for_inference(filenames, data_path):
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        states.append(
            torch.load(
                filename,
                map_location=lambda s, l: default_restore_location(s, 'cpu')))

    # load dataset
    args = states[0]['args']
    # clear pretrained embedding paths recorded at training time so that
    # build_model does not try to re-read those files at inference
    if args.decoder_embed_path:
        args.decoder_embed_path = None
    if args.encoder_embed_path:
        args.encoder_embed_path = None
    dataset = data.load(data_path, args.source_lang, args.target_lang)
    # build models
    ensemble = []
    for state in states:
        model = build_model(args, dataset)
        model.load_state_dict(state['model'])
        ensemble.append(model)

    return ensemble, dataset
Code example #8
def load_checkpoint(filename,
                    model,
                    optimizer,
                    lr_scheduler,
                    args=None,
                    cuda_device=None):
    if not os.path.exists(filename):
        return 1, 0
    if cuda_device is None:
        state = torch.load(filename)
    else:
        state = torch.load(filename,
                           map_location=lambda s, l: default_restore_location(
                               s, 'cuda:{}'.format(cuda_device)))

    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    if args and args.hardset_lr:
        optimizer.param_groups[0]['lr'] = args.lr
    lr_scheduler.best = state['best_loss']
    epoch = state['epoch'] + 1
    batch_offset = state['batch_offset']

    gpu_str = ' on GPU #{}'.format(
        cuda_device) if cuda_device is not None else ''
    print('| loaded checkpoint {} (epoch {}){}'.format(filename, epoch,
                                                       gpu_str))
    return epoch, batch_offset
Code example #9
File: utils.py, Project: ahiroto/ParlAI
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_dir=None):
    """Load an ensemble of models for inference.

    The source and target dictionaries can be given explicitly, or loaded from
    the `data_dir` directory.
    """
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        states.append(
            torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        )
    args = states[0]['args']
    args = _upgrade_args(args)

    if src_dict is None or dst_dict is None:
        assert data_dir is not None
        src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)

    # build ensemble
    ensemble = []
    for state in states:
        model = build_model(args, src_dict, dst_dict)
        model.load_state_dict(state['model'])
        ensemble.append(model)
    return ensemble, args
Code example #10
def load_ensemble_for_inference(filenames):
    """Load an ensemble of models for inference.

    model_arg_overrides allows you to pass a dictionary model_arg_overrides --
    {'arg_name': arg} -- to override model args that were used during model
    training
    """
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = torch.load(
            filename,
            map_location=lambda s, l: default_restore_location(s, 'cpu'))
        states.append(state)

    ensemble = []
    for state in states:
        args = state['args']

        # build model for ensemble
        model = TransformerModel.build_model(args)
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)

    src_dict = states[0]['extra_state']['src_dict']
    tgt_dict = states[0]['extra_state']['tgt_dict']

    return ensemble, args, src_dict, tgt_dict
Code example #11
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    logger.info('Reading saved model from %s', model_file)
    state_dict = torch.load(
        model_file,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    logger.info('model_state_dict keys %s', state_dict.keys())
    return CheckpointState(**state_dict)
Code example #12
File: utils.py, Project: fyabc/fairseq
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    model_arg_overrides allows you to pass a dictionary model_arg_overrides --
    {'arg_name': arg} -- to override model args that were used during model
    training
    """
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        state = _upgrade_state_dict(state)
        states.append(state)

    ensemble = []
    for state in states:
        args = state['args']

        if model_arg_overrides is not None:
            args = _override_model_args(args, model_arg_overrides)

        # build model for ensemble
        model = task.build_model(args)
        model.upgrade_state_dict(state['model'])
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)

    return ensemble, args
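A hypothetical call site (the checkpoint paths and the fairseq task object are placeholders): override a stored training argument, such as dropout, before the ensemble is rebuilt:

ensemble, args = load_ensemble_for_inference(
    ['checkpoint1.pt', 'checkpoint2.pt'],   # placeholder paths
    task,
    model_arg_overrides={'dropout': 0.0},   # e.g. disable dropout at inference
)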
Code example #13
File: model_utils.py, Project: hnt4499/DPR
def load_states_from_checkpoint_ofa(model_file: str) -> CheckpointStateOFA:
    logger.info("Reading saved OFA model from %s", model_file)
    state_dict = torch.load(
        model_file, map_location=lambda s, l: default_restore_location(s, "cpu")
    )
    logger.info("model_state_dict keys %s", state_dict.keys())
    return CheckpointStateOFA(**state_dict)
Code example #14
def load_model_state(filename, device, data_parallel=False):
    if not os.path.exists(filename):
        print("Starting training from scratch.")
        # return a tuple so the call site can unpack (model, args) uniformly
        return None, None

    def dict_to_sns(d):
        return SimpleNamespace(**d)

    basedir = os.path.dirname(filename)
    with open(os.path.join(basedir, 'config.json')) as f:
        args_dict = json.load(f, object_hook=dict_to_sns)

    model = build_model(args_dict, device)

    print("Loading model from checkpoints", filename)
    state = torch.load(
        filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))

    from collections import OrderedDict
    new_state_dict = OrderedDict()
    # create new OrderedDict that does not contain `module.`
    if data_parallel:
        for k, v in state['model'].items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
    else:
        new_state_dict = state['model']
    # load model parameters
    try:
        model.load_state_dict(new_state_dict)
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')
    return model, args_dict
Code example #15
File: utils.py, Project: yuekai146/fairseq
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    model_arg_overrides allows you to pass a dictionary model_arg_overrides --
    {'arg_name': arg} -- to override model args that were used during model
    training
    """
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        state = _upgrade_state_dict(state)
        states.append(state)

    ensemble = []
    for state in states:
        args = state['args']

        if model_arg_overrides is not None:
            args = _override_model_args(args, model_arg_overrides)

        # build model for ensemble
        model = task.build_model(args)
        model.upgrade_state_dict(state['model'])
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)

    return ensemble, args
Code example #16
def load_checkpoint_to_cpu(path):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
    state = torch.load(
        path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
    )
    state = _upgrade_state_dict(state)
    return state
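A typical way to compose this helper with the task-built-model pattern used elsewhere on this page (a sketch; task and the checkpoint path are assumptions, not part of the original snippet):

state = load_checkpoint_to_cpu('checkpoint_best.pt')
model = task.build_model(state['args'])
model.load_state_dict(state['model'], strict=True)
model.cuda()  # move to the GPU only after the CPU-side load has succeeded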
Code example #17
def load_checkpoint(filename,
                    model,
                    optimizer,
                    lr_scheduler,
                    cuda_device=None):
    if not os.path.exists(filename):
        return 1, 0
    if cuda_device is None:
        state = torch.load(filename)
    else:
        state = torch.load(filename,
                           map_location=lambda s, l: default_restore_location(
                               s, 'cuda:{}'.format(cuda_device)))

    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    lr_scheduler.best = state['best_loss']
    epoch = state['epoch'] + 1
    batch_offset = state['batch_offset']

    gpu_str = ' on GPU #{}'.format(
        cuda_device) if cuda_device is not None else ''
    print('| loaded checkpoint {} (epoch {}){}'.format(filename, epoch,
                                                       gpu_str))
    return epoch, batch_offset
Code example #18
    def __init__(self, config: Config, output_encoded_layers: bool,
                 **kwarg) -> None:
        super().__init__(config, output_encoded_layers=output_encoded_layers)
        # assert config.pretrained_encoder.load_path, "Load path cannot be empty."
        self.encoder = SentenceEncoder(transformer=Transformer(
            vocab_size=config.vocab_size,
            embedding_dim=config.embedding_dim,
            layers=[
                TransformerLayer(
                    embedding_dim=config.embedding_dim,
                    attention=MultiheadSelfAttention(
                        config.embedding_dim, config.num_attention_heads),
                ) for _ in range(config.num_encoder_layers)
            ],
        ))
        self.apply(init_params)
        if config.model_path:
            with PathManager.open(config.model_path, "rb") as f:
                roberta_state = torch.load(f,
                                           map_location=lambda s, l:
                                           default_restore_location(s, "cpu"))
            # If the model has previously been loaded in PyText and fine-tuned,
            # we don't need the special state-dict translation; load it
            # directly
            if not config.is_finetuned:
                self.encoder.load_roberta_state_dict(roberta_state["model"])
            else:
                self.load_state_dict(roberta_state)

        self.representation_dim = self._embedding().weight.size(-1)
        log_class_usage(__class__)
Code example #19
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        num_attention_heads: int,
        num_encoder_layers: int,
        output_dropout: float,
        model_path: Optional[str] = None,
    ):
        super().__init__()
        self.transformer = Transformer(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            layers=[
                TransformerLayer(
                    embedding_dim=embedding_dim,
                    attention=MultiheadSelfAttention(
                        embedding_dim, num_attention_heads
                    ),
                )
                for _ in range(num_encoder_layers)
            ],
        )
        self.output_dropout = nn.Dropout(output_dropout)

        self.apply(init_params)
        if model_path:
            with PathManager.open(model_path, "rb") as f:
                roberta_state = torch.load(
                    f, map_location=lambda s, l: default_restore_location(s, "cpu")
                )
                if "model" in roberta_state:
                    roberta_state = translate_roberta_state_dict(roberta_state["model"])
                self.load_state_dict(roberta_state)
Code example #20
File: utils.py, Project: abhishek021/fairseq-py
def load_state(filename,
               model,
               criterion,
               optimizer,
               lr_scheduler,
               cuda_device=None):
    if not os.path.exists(filename):
        return None, []
    if cuda_device is None:
        state = torch.load(filename)
    else:
        state = torch.load(filename,
                           map_location=lambda s, l: default_restore_location(
                               s, 'cuda:{}'.format(cuda_device)))
    state = _upgrade_state_dict(state)

    # load model parameters
    model.load_state_dict(state['model'])

    # only load optimizer and lr_scheduler if they match with the checkpoint
    optim_history = state['optimizer_history']
    last_optim = optim_history[-1]
    if last_optim['criterion_name'] == criterion.__class__.__name__:
        optimizer.load_state_dict(state['last_optimizer_state'])
        lr_scheduler.best = last_optim['best_loss']

    return state['extra_state'], optim_history
Code example #21
File: mlp_decoder.py, Project: wwjiang007/pytext
    def __init__(self, config: Config, in_dim: int, out_dim: int = 0) -> None:
        super().__init__(config)

        layers = []
        for dim in config.hidden_dims or []:
            layers.append(nn.Linear(in_dim, dim, config.bias))
            layers.append(get_activation(config.activation))
            if config.layer_norm:
                layers.append(nn.LayerNorm(dim))
            if config.dropout > 0:
                layers.append(nn.Dropout(config.dropout))
            in_dim = dim
        if config.out_dim is not None:
            out_dim = config.out_dim
        if out_dim > 0:
            layers.append(nn.Linear(in_dim, out_dim, config.bias))

        assert len(layers) > 0
        if config.spectral_normalization:
            layers[-1] = torch.nn.utils.spectral_norm(layers[-1])
        self.mlp = nn.Sequential(*layers)
        self.out_dim = out_dim if out_dim > 0 else config.hidden_dims[-1]
        self.temperature = config.temperature

        if config.load_model_path:
            with PathManager.open(config.load_model_path, "rb") as f:
                mlp_state = torch.load(f,
                                       map_location=lambda s, l:
                                       default_restore_location(s, "cpu"))
            print("loaded mlp state")
            self.load_state_dict(mlp_state, strict=config.load_strict)

        log_class_usage(__class__)
Code example #22
File: utils.py, Project: nguyenlab/fairseq-py
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_dir=None):
    """Load an ensemble of models for inference.

    The source and target dictionaries can be given explicitly, or loaded from
    the `data_dir` directory.
    """
    from fairseq import data, models

    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        states.append(
            torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        )
    args = states[0]['args']
    args = _upgrade_args(args)

    if src_dict is None or dst_dict is None:
        assert data_dir is not None
        src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)

    # build ensemble
    ensemble = []
    for state in states:
        model = models.build_model(args, src_dict, dst_dict)
        model.load_state_dict(state['model'])
        ensemble.append(model)
    return ensemble, args
Code example #23
    def load_pretrained_model(self, model, state_file_name):
        from torch.serialization import default_restore_location
        state = torch.load(
            state_file_name,
            map_location=lambda s, l: default_restore_location(s, 'cpu'))
        params = state['model']

        # count how many loaded parameters belong to the encoder
        enc_cnt = 0
        non_enc_cnt = 0
        for k in params.keys():
            if not k.startswith('encoder'):
                print(k)
                non_enc_cnt += 1
            else:
                enc_cnt += 1
        print('enc_cnt = %d, non_enc_cnt = %d' % (enc_cnt, non_enc_cnt))
        model.load_state_dict(params, strict=False)
        print(
            '*** *** load pretrained doc encoder from {} done! *** ***'.format(
                state_file_name))
Code example #24
def load_checkpoint(args, model=None, optimizer=None, scheduler=None):
    if args.restore_file is not None and os.path.isfile(args.restore_file):
        state_dict = torch.load(
            args.restore_file,
            map_location=lambda s, l: default_restore_location(s, "cpu"))

        model = [
            model
        ] if model is not None and not isinstance(model, list) else model
        optimizer = [optimizer] if optimizer is not None and not isinstance(
            optimizer, list) else optimizer
        scheduler = [scheduler] if scheduler is not None and not isinstance(
            scheduler, list) else scheduler

        if "best_score" in state_dict:
            save_checkpoint.best_score = state_dict["best_score"]
        if "last_step" in state_dict:
            save_checkpoint.last_step = state_dict["last_step"]
        if model is not None and state_dict.get("model", None) is not None:
            for m, state in zip(model, state_dict["model"]):
                m.load_state_dict(state)
        if optimizer is not None and state_dict.get("optimizer",
                                                    None) is not None:
            for o, state in zip(optimizer, state_dict["optimizer"]):
                o.load_state_dict(state)
        if scheduler is not None and state_dict.get("scheduler",
                                                    None) is not None:
            for s, state in zip(scheduler, state_dict["scheduler"]):
                s.load_state_dict(state)

        logging.info("Loaded checkpoint {}".format(args.restore_file))
        return state_dict
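Because the helper wraps bare objects into one-element lists, it can be called with a single model or with paired models; a hypothetical sketch (net, opt, and sched are placeholders):

state = load_checkpoint(args, model=net, optimizer=opt, scheduler=sched)
state = load_checkpoint(args, model=[net_f, net_b], optimizer=[opt_f, opt_b])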
Code example #25
 def load_pretrained_model(self, model, state_file_name):
     from torch.serialization import default_restore_location
     state = torch.load(
         state_file_name,
         map_location=lambda s, l: default_restore_location(s, 'cpu'))
     params = state['model']
     model.load_state_dict(params, strict=False)
     print('*** *** load pretrained model done! *** ***')
Code example #26
File: dist_trainer.py, Project: victorustc/fastNLP
 def load_check_point(self, name):
     path = os.path.join(self.cp_save_path, name)
     self.logger.info('reload best model from %s', path)
     model_load = torch.load(
         path, map_location=lambda s, l: default_restore_location(s, "cpu"))
     if not isinstance(model_load, dict):
         model_load = model_load.state_dict()
     self.model.load_state_dict(model_load)
Code example #27
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    print("Reading saved model from %s" % model_file)
    state_dict = torch.load(
        model_file,
        map_location=lambda s, l: default_restore_location(s, "cpu"))
    print(len(state_dict.keys()), state_dict.keys())
    return CheckpointState(model_dict=state_dict)
Code example #28
File: train.py, Project: zhoutao1996/SCA
def load_lm_state(args, model):
    os.makedirs(args.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, args.load_srclm_file)
    if os.path.isfile(checkpoint_path):
        print('load language model...')
        from torch.serialization import default_restore_location
        state = torch.load(checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))

        def upgrade(obj):
            # strip the leading 'decoder.' prefix so the keys match a bare
            # language-model decoder module
            if isinstance(obj, OrderedDict):
                oldkeys = list(obj.keys())
                for k in oldkeys:
                    if k.startswith('decoder') and k != 'decoder':
                        newkey = k.split('.', 1)[1]
                    else:
                        newkey = k
                    obj[newkey] = upgrade(obj[k])
                    if k.startswith('decoder'):
                        del obj[k]
            return obj

        upgrade(state['model'])
        # load model parameters
        try:
            model.srclmdecoder.load_state_dict(state['model'], strict=True)
        except Exception:
            raise Exception('Cannot load model parameters from source language model checkpoint, '
                            'please ensure that the architectures match')
        checkpoint_path = os.path.join(args.save_dir, args.load_tgtlm_file)
        assert os.path.exists(checkpoint_path)
        state = torch.load(checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        upgrade(state['model'])
        # load model parameters
        try:
            model.tgtlmdecoder.load_state_dict(state['model'], strict=True)
        except Exception:
            raise Exception('Cannot load model parameters from target language model checkpoint, '
                            'please ensure that the architectures match')
        return True
    return False
Code example #29
File: roberta.py, Project: freegliboracle/pytext
    def __init__(self, config: Config, output_encoded_layers: bool,
                 **kwarg) -> None:
        super().__init__(config, output_encoded_layers=output_encoded_layers)

        # map to the real model_path
        config.model_path = (resources.roberta.RESOURCE_MAP[config.model_path]
                             if config.model_path
                             in resources.roberta.RESOURCE_MAP else
                             config.model_path)
        # assert config.pretrained_encoder.load_path, "Load path cannot be empty."
        # share one compression layer across all encoder layers

        # create the compression layer if using linear multihead attention
        if config.use_linformer_encoder:
            compress_layer = nn.Linear(
                config.max_seq_len - 2,
                (config.max_seq_len - 2) // config.linformer_compressed_ratio,
            )

        self.encoder = SentenceEncoder(transformer=Transformer(
            vocab_size=config.vocab_size,
            embedding_dim=config.embedding_dim,
            layers=[
                TransformerLayer(
                    embedding_dim=config.embedding_dim,
                    attention=(
                        MultiheadLinearAttention(
                            embed_dim=config.embedding_dim,
                            num_heads=config.num_attention_heads,
                            compress_layer=compress_layer,
                        )
                        if config.use_linformer_encoder
                        else MultiheadSelfAttention(
                            embed_dim=config.embedding_dim,
                            num_heads=config.num_attention_heads,
                        )
                    ),
                ) for _ in range(config.num_encoder_layers)
            ],
            max_seq_len=config.max_seq_len,
        ))
        self.apply(init_params)
        if config.model_path:
            with PathManager.open(config.model_path, "rb") as f:
                roberta_state = torch.load(f,
                                           map_location=lambda s, l:
                                           default_restore_location(s, "cpu"))
            # If the model has previously been loaded in PyText and fine-tuned,
            # we don't need the special state-dict translation; load it
            # directly
            if not config.is_finetuned:
                self.encoder.load_roberta_state_dict(roberta_state["model"])
            else:
                self.load_state_dict(roberta_state)

        self.representation_dim = self._embedding().weight.size(-1)
        self.export_encoder = config.export_encoder
        self.variable_size_embedding = config.variable_size_embedding
        log_class_usage(__class__)
Code example #30
def load_model(checkpoint_path):
    state_dict = torch.load(
        checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, "cpu"))
    args = argparse.Namespace(**{**vars(state_dict["args"]), "no_log": True})

    model = models.build_model(args).to(device)
    model.load_state_dict(state_dict["model"][0])
    model.eval()
    return model
Code example #31
File: utils.py, Project: bcm628/nlu_cw2-1
def load_checkpoint(args, model, optimizer):
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    if os.path.isfile(checkpoint_path):
        state_dict = torch.load(checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        model.load_state_dict(state_dict['model'])
        optimizer.load_state_dict(state_dict['optimizer'])
        save_checkpoint.best_loss = state_dict['best_loss']
        save_checkpoint.last_epoch = state_dict['last_epoch']
        logging.info('Loaded checkpoint {}'.format(checkpoint_path))
        return state_dict
Code example #32
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
    try:
        from fairseq.fb_pathmgr import fb_pathmgr
        with fb_pathmgr.open(path, "rb") as f:
            state = torch.load(
                f, map_location=lambda s, l: default_restore_location(s, 'cpu'),
            )
    except (ModuleNotFoundError, ImportError):
        # if path manager not found, continue with local file.
        state = torch.load(
            path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
        )
    args = state['args']
    if arg_overrides is not None:
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)
    state = _upgrade_state_dict(state)
    return state
Code example #33
File: utils.py, Project: fyabc/fairseq
def load_model_state(filename, model):
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        model.load_state_dict(state['model'], strict=True)
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
Code example #34
File: utils.py, Project: ahiroto/ParlAI
def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
    if not os.path.exists(filename):
        return None, []
    if cuda_device is None:
        state = torch.load(filename)
    else:
        state = torch.load(
            filename,
            map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
        )
    state = _upgrade_state_dict(state)

    # load model parameters
    model.load_state_dict(state['model'])

    # only load optimizer and lr_scheduler if they match with the checkpoint
    optim_history = state['optimizer_history']
    last_optim = optim_history[-1]
    if last_optim['criterion_name'] == criterion.__class__.__name__:
        optimizer.load_state_dict(state['last_optimizer_state'])
        lr_scheduler.best = last_optim['best_loss']

    return state['extra_state'], optim_history