def test_load_torch_file(model_dir):
    """Check that ``load_torch_file`` reads a checkpoint and rejects missing files.

    NOTE(review): another ``test_load_torch_file`` definition (which also
    exercises ``map_location='cuda'``) appears in this file; if both live in
    the same module this one is shadowed and never collected — confirm and
    rename one of them.
    """
    # The `nuqe.ckpt` checkpoint under `model_dir` loads without raising.
    load_torch_file(model_dir / 'nuqe.ckpt')

    # A path that does not exist is reported as a ValueError.
    with pytest.raises(ValueError):
        load_torch_file(model_dir / 'nonexistent.torch')
def _load_encoder(self, path: Path):
    """Restore an encoder, its data encoders and vocabularies from ``path``.

    Steps, in order:
      1. read the checkpoint with ``load_torch_file``;
      2. resolve the encoder class named under ``['encoder']['class_name']``;
      3. rebuild ``self.data_encoders`` as a ``WMTQEDataEncoder`` using that
         class's input data encoders;
      4. overwrite its vocabularies with the saved SOURCE and TARGET
         vocabularies (plus PE when the checkpoint has one);
      5. build ``self.encoder`` from the saved encoder entry with
         ``pre_load_model=False``.
    """
    logger.info(f'Loading encoder from {path}')
    checkpoint = load_torch_file(path)
    encoder_state = checkpoint['encoder']
    encoder_cls = MetaModule.retrieve_subclass(encoder_state['class_name'])

    processing_config = self.config.data_processing
    field_encoders = encoder_cls.input_data_encoders(self.config.model.encoder)
    self.data_encoders = WMTQEDataEncoder(
        config=processing_config,
        field_encoders=field_encoders,
    )

    # SOURCE and TARGET are mandatory (a missing one raises KeyError);
    # PE is copied only when the checkpoint actually carries it.
    saved_vocabs = checkpoint[const.VOCAB]
    sides = [const.SOURCE, const.TARGET]
    if const.PE in saved_vocabs:
        sides.append(const.PE)
    input_vocabs = {side: saved_vocabs[side] for side in sides}
    self.data_encoders.vocabularies_from_dict(input_vocabs, overwrite=True)

    self.encoder = MetaModule.from_dict(
        encoder_state,
        vocabs=self.data_encoders.vocabularies,
        pre_load_model=False,
    )
def load_vocabularies(self, load_vocabs_from: Path = None, overwrite: bool = False):
    """Load serialized Vocabularies from disk into fields.

    Args:
        load_vocabs_from: path of a file readable by ``load_torch_file`` that
            contains a ``const.VOCAB`` entry. It defaults to ``None`` for
            signature compatibility, but a path is required.
        overwrite: forwarded unchanged to ``vocabularies_from_dict``.

    Returns:
        Whatever ``self.vocabularies_from_dict`` returns.

    Raises:
        ValueError: if ``load_vocabs_from`` is ``None``.
        KeyError: if the loaded file has no ``const.VOCAB`` entry.
    """
    # Fail fast with a clear message instead of handing None (the default)
    # to load_torch_file and getting an unrelated error from deep inside it.
    if load_vocabs_from is None:
        raise ValueError('load_vocabularies requires a `load_vocabs_from` path')
    logger.info(f'Loading vocabularies from: {load_vocabs_from}')
    vocabs_dict = load_torch_file(load_vocabs_from)
    if const.VOCAB not in vocabs_dict:
        raise KeyError(f'File {load_vocabs_from} has no {const.VOCAB}')
    return self.vocabularies_from_dict(vocabs_dict[const.VOCAB], overwrite)
def check_consistency(cls, v, values):
    """Validate that a model file to load agrees with the configured class.

    NOTE(review): the decorator is not visible here; this looks like a
    pydantic-style validator for a ``load`` field, where ``values`` holds the
    already-validated sibling fields — confirm.

    Args:
        v: the path of a model file to load, or ``None``.
        values: mapping of sibling field values; ``'class_name'`` may be absent.

    Returns:
        ``v`` unchanged.

    Raises:
        ValueError: if neither ``class_name`` nor ``v`` is given, or if the
            ``class_name`` stored in the file at ``v`` differs from the
            configured one.
    """
    # Read with .get in both checks: 'class_name' may be missing from
    # `values`, and indexing it would raise a KeyError instead of validating.
    class_name = values.get('class_name')
    if v is None and class_name is None:
        raise ValueError('Must provide `class_name` or `load`')
    if v is not None and class_name is not None:
        model_dict = load_torch_file(v)
        saved_class_name = model_dict['class_name']
        if saved_class_name != class_name:
            raise ValueError(
                f'`class_name` in configuration file ({class_name}) '
                f'does not match class_name in the loaded model file '
                f'({saved_class_name}); consider removing `class_name`'
            )
    return v
def test_load_torch_file(model_dir):
    """Check CPU loading, CUDA mapping without a GPU, and missing-file errors."""
    import torch

    # The `nuqe.ckpt` checkpoint under `model_dir` loads without raising.
    load_torch_file(model_dir / 'nuqe.ckpt')

    # Mapping onto CUDA is only expected to fail when no GPU is available; on
    # a CUDA host the load succeeds, so the negative check is skipped there.
    if not torch.cuda.is_available():
        # NOTE(review): depending on the torch version/build this surfaces as
        # AssertionError (CUDA lazy init on CPU-only builds) or RuntimeError
        # (torch.load's device validation) — confirm for the pinned torch.
        with pytest.raises((AssertionError, RuntimeError)):
            load_torch_file(model_dir / 'nuqe.ckpt', map_location='cuda')

    # A path that does not exist is reported as a ValueError.
    with pytest.raises(ValueError):
        load_torch_file(model_dir / 'nonexistent.torch')
def load(cls, path, map_location=None):
    """Build an instance of ``cls`` from the file saved at ``path``.

    Args:
        path: file location, passed straight to ``load_torch_file``.
        map_location: optional device mapping forwarded to
            ``load_torch_file``. When ``None`` (the default), the file is
            loaded exactly as before, with ``load_torch_file``'s own default.

    Returns:
        Whatever ``cls.from_dict`` builds from the loaded dictionary.
    """
    if map_location is None:
        # Keep the original call shape so load_torch_file's default applies.
        model_dict = load_torch_file(path)
    else:
        model_dict = load_torch_file(path, map_location=map_location)
    return cls.from_dict(model_dict)
def load(cls, path: Path, map_location=None):
    """Restore a ``TLMSystem`` from the file saved at ``path``.

    ``map_location`` is handed unchanged to ``load_torch_file``.

    NOTE(review): the result is always built via ``TLMSystem.from_dict``,
    not ``cls.from_dict``, so calling this on a subclass still yields
    whatever ``TLMSystem.from_dict`` returns — confirm that is intended.
    """
    logger.info(f'Loading system from {path}')
    checkpoint = load_torch_file(path, map_location=map_location)
    return TLMSystem.from_dict(module_dict=checkpoint)