예제 #1
0
def test_load_torch_file(model_dir):
    """Exercise ``load_torch_file`` on an existing and a missing checkpoint.

    ``model_dir`` is presumably a pytest fixture pointing at the directory that
    holds the bundled ``nuqe.ckpt`` test model.
    """
    checkpoint = model_dir / 'nuqe.ckpt'
    missing = model_dir / 'nonexistent.torch'

    # Loading a checkpoint that exists must not raise (the result is unused).
    load_torch_file(checkpoint)

    # NOTE: a CUDA-specific check used to live here but is disabled. It mapped
    # storages onto ``cuda(0)`` and expected
    # ``RuntimeError('No CUDA GPUs are available')`` on a GPU-less machine.
    # Re-enable it behind ``pytest.mark.skipif`` if GPU coverage is wanted.

    # A path that does not exist is expected to surface as ValueError.
    with pytest.raises(ValueError):
        load_torch_file(missing)
예제 #2
0
    def _load_encoder(self, path: Path):
        """Restore the encoder and its input vocabularies from a saved file.

        Reads the serialized dictionary at ``path``. It then rebuilds
        ``self.data_encoders`` for the encoder class recorded in that file and
        loads the stored source/target vocabularies (plus PE when present) into
        it. Finally it instantiates ``self.encoder`` from the saved encoder
        entry.

        Args:
            path: file readable by ``load_torch_file``. It is expected to hold
                an ``'encoder'`` entry (with a ``'class_name'`` key) and a
                ``const.VOCAB`` mapping.
        """
        logger.info(f'Loading encoder from {path}')
        saved = load_torch_file(path)
        encoder_state = saved['encoder']

        encoder_cls = MetaModule.retrieve_subclass(encoder_state['class_name'])
        processing_config = self.config.data_processing
        self.data_encoders = WMTQEDataEncoder(
            config=processing_config,
            field_encoders=encoder_cls.input_data_encoders(self.config.model.encoder),
        )

        saved_vocabs = saved[const.VOCAB]
        # Source and target vocabularies are always restored; PE only if saved.
        wanted = [const.SOURCE, const.TARGET]
        if const.PE in saved_vocabs:
            wanted.append(const.PE)
        input_vocabs = {name: saved_vocabs[name] for name in wanted}
        self.data_encoders.vocabularies_from_dict(input_vocabs, overwrite=True)

        # pre_load_model=False is passed through as before; presumably it keeps
        # from_dict from reloading weights from disk (confirm in MetaModule).
        self.encoder = MetaModule.from_dict(
            encoder_state,
            vocabs=self.data_encoders.vocabularies,
            pre_load_model=False,
        )
예제 #3
0
    def load_vocabularies(self, load_vocabs_from: Path = None, overwrite: bool = False):
        """Restore serialized vocabularies from disk into fields.

        Args:
            load_vocabs_from: file readable by ``load_torch_file``; it must
                contain a ``const.VOCAB`` entry.
            overwrite: forwarded unchanged to ``vocabularies_from_dict``.

        Returns:
            Whatever ``self.vocabularies_from_dict`` returns.

        Raises:
            KeyError: if the loaded file has no ``const.VOCAB`` entry.
        """
        logger.info(f'Loading vocabularies from: {load_vocabs_from}')
        serialized = load_torch_file(load_vocabs_from)
        if const.VOCAB in serialized:
            return self.vocabularies_from_dict(serialized[const.VOCAB], overwrite)
        raise KeyError(f'File {load_vocabs_from} has no {const.VOCAB}')
예제 #4
0
 def check_consistency(cls, v, values):
     """Validate that ``class_name`` and ``load`` (``v``) are consistent.

     At least one of the two must be provided. When both are provided, the
     ``class_name`` stored in the file at ``v`` must equal the configured one.

     Args:
         cls: the class being validated (unused here).
         v: the ``load`` value: a path handed to ``load_torch_file``, or None.
         values: mapping of other field values (validator style). The
             ``class_name`` key may be absent, e.g. if that field failed its
             own validation, so it is always read with ``.get``.

     Returns:
         ``v`` unchanged.

     Raises:
         ValueError: if neither field is provided, or if the loaded file's
             ``class_name`` differs from the configured one.
     """
     # Read once with .get: indexing would raise KeyError when the key is
     # missing, even though the first check already tolerates absence.
     class_name = values.get('class_name')
     if v is None and class_name is None:
         raise ValueError('Must provide `class_name` or `load`')
     if v is not None and class_name is not None:
         model_dict = load_torch_file(v)
         loaded_class_name = model_dict['class_name']
         if loaded_class_name != class_name:
             raise ValueError(
                 f'`class_name` in configuration file ({class_name}) '
                 f'does not match class_name in the loaded model file '
                 f'({loaded_class_name}); consider removing `class_name`'
             )
     return v
예제 #5
0
def test_load_torch_file(model_dir):
    """Check ``load_torch_file`` on a valid file, a CUDA mapping and a missing file."""
    ckpt_path = model_dir / 'nuqe.ckpt'

    # Happy path: the bundled checkpoint loads with the default map_location.
    load_torch_file(ckpt_path)

    # The test environment has no CUDA, so mapping onto it is expected to fail
    # with AssertionError.
    with pytest.raises(AssertionError):
        load_torch_file(ckpt_path, map_location='cuda')

    # A path that does not exist is expected to be rejected with ValueError.
    with pytest.raises(ValueError):
        load_torch_file(model_dir / 'nonexistent.torch')
예제 #6
0
 def load(cls, path):
     """Deserialize ``path`` with ``load_torch_file`` and pass it to ``cls.from_dict``.

     Returns whatever ``cls.from_dict`` builds from the loaded dictionary.
     """
     return cls.from_dict(load_torch_file(path))
예제 #7
0
 def load(cls, path: Path, map_location=None):
     """Load a saved system from ``path``.

     Args:
         path: file to read with ``load_torch_file``.
         map_location: forwarded unchanged to ``load_torch_file``.

     Returns:
         The object built by ``TLMSystem.from_dict`` from the file contents.
         Construction always goes through ``TLMSystem.from_dict`` rather than
         ``cls.from_dict``.
     """
     logger.info(f'Loading system from {path}')
     return TLMSystem.from_dict(
         module_dict=load_torch_file(path, map_location=map_location)
     )