def test_hparams(self):
    r"""Tests the HParams class.
    """
    default_hparams = {
        "str": "str",
        "list": ['item1', 'item2'],
        "dict": {
            "key1": "value1",
            "key2": "value2"
        },
        "nested_dict": {
            "dict_l2": {
                "key1_l2": "value1_l2"
            }
        },
        "type": "type",
        "kwargs": {
            "arg1": "argv1"
        },
    }

    # Test HParams.items() function
    hparams_ = HParams(None, default_hparams)
    names = []
    for name, _ in hparams_.items():
        names.append(name)
    self.assertEqual(set(names), set(default_hparams.keys()))

    hparams = {"dict": {"key1": "new_value"}, "kwargs": {"arg2": "argv2"}}
    hparams_ = HParams(hparams, default_hparams)

    # Test HParams construction
    self.assertEqual(hparams_.str, default_hparams["str"])
    self.assertEqual(hparams_.list, default_hparams["list"])
    self.assertEqual(hparams_.dict.key1, hparams["dict"]["key1"])
    self.assertEqual(hparams_.kwargs.arg2, hparams["kwargs"]["arg2"])
    self.assertEqual(hparams_.nested_dict.dict_l2.key1_l2,
                     default_hparams["nested_dict"]["dict_l2"]["key1_l2"])
    self.assertEqual(len(hparams_), len(default_hparams))

    new_hparams = copy.deepcopy(default_hparams)
    new_hparams["dict"]["key1"] = hparams["dict"]["key1"]
    new_hparams["kwargs"].update(hparams["kwargs"])
    self.assertEqual(hparams_.todict(), new_hparams)

    self.assertTrue("dict" in hparams_)

    self.assertIsNone(hparams_.get('not_existed_name', None))
    self.assertEqual(hparams_.get('str'), default_hparams['str'])

    # Test HParams update related operations
    hparams_.str = "new_str"
    hparams_.dict = {"key3": "value3"}
    self.assertEqual(hparams_.str, "new_str")
    self.assertEqual(hparams_.dict.key3, "value3")

    hparams_.add_hparam("added_str", "added_str")
    hparams_.add_hparam("added_dict", {"key4": "value4"})
    hparams_.kwargs.add_hparam("added_arg", "added_argv")
    self.assertEqual(hparams_.added_str, "added_str")
    self.assertEqual(hparams_.added_dict.todict(), {"key4": "value4"})
    self.assertEqual(hparams_.kwargs.added_arg, "added_argv")

    # Test HParams I/O
    hparams_file = tempfile.NamedTemporaryFile()
    pickle.dump(hparams_, hparams_file)
    # Flush buffered writes so the pickled bytes are on disk before the
    # file is reopened by name.
    hparams_file.flush()
    with open(hparams_file.name, 'rb') as saved_file:
        hparams_loaded = pickle.load(saved_file)
    self.assertEqual(hparams_loaded.todict(), hparams_.todict())
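# The merge semantics asserted above can be stated simply: user-supplied
# values override defaults, and nested dictionaries are merged key-by-key
# rather than replaced wholesale (which is why `dict.key2` and
# `kwargs.arg1` survive in the expected `todict()` result). Below is a
# minimal plain-dict sketch of that recursive merge -- an illustration
# only, not the actual `HParams` implementation, which also performs type
# checking and provides attribute-style access.

import copy


def merge_hparams(hparams, default_hparams):
    """Recursively overlay user values on top of defaults (sketch only)."""
    merged = copy.deepcopy(default_hparams)
    if hparams is None:
        return merged
    for key, value in hparams.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_hparams(value, merged[key])
        else:
            merged[key] = value
    return merged


# Mirrors the expectations in the test above.
defaults = {"dict": {"key1": "value1", "key2": "value2"},
            "kwargs": {"arg1": "argv1"}}
user = {"dict": {"key1": "new_value"}, "kwargs": {"arg2": "argv2"}}
merged = merge_hparams(user, defaults)
assert merged["dict"] == {"key1": "new_value", "key2": "value2"}
assert merged["kwargs"] == {"arg1": "argv1", "arg2": "argv2"}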
class PretrainedMixin(ModuleBase, ABC):
    r"""A mixin class for all pre-trained classes to inherit.
    """

    _MODEL_NAME: str
    _MODEL2URL: Dict[str, MaybeList[str]]

    pretrained_model_dir: Optional[str]

    @classmethod
    def available_checkpoints(cls) -> List[str]:
        r"""Returns the names of all available pre-trained checkpoints."""
        return list(cls._MODEL2URL.keys())

    def _name_to_variable(self, name: str) -> nn.Parameter:
        r"""Find the corresponding variable given the specified name.
        """
        pointer = self
        for m_name in name.split("."):
            if m_name.isdigit():
                # Numeric path components index into module lists.
                num = int(m_name)
                pointer = pointer[num]  # type: ignore
            else:
                pointer = getattr(pointer, m_name)
        return pointer  # type: ignore

    def load_pretrained_config(self,
                               pretrained_model_name: Optional[str] = None,
                               cache_dir: Optional[str] = None,
                               hparams=None):
        r"""Load paths and configurations of the pre-trained model.

        Args:
            pretrained_model_name (optional): A str with the name of a
                pre-trained model to load. If `None`, will use the model
                name in :attr:`hparams`.
            cache_dir (optional): The path to a folder in which the
                pre-trained models will be cached. If `None` (default),
                a default directory will be used.
            hparams (dict or HParams, optional): Hyperparameters. Missing
                hyperparameters will be set to default values. See
                :meth:`default_hparams` for the hyperparameter structure
                and default values.
        """
        if not hasattr(self, "_hparams"):
            self._hparams = HParams(hparams, self.default_hparams())
        else:
            # Probably already parsed by subclasses. We rely on subclass
            # implementations to get this right.
            # As a sanity check, we require `hparams` to be `None` in this
            # case.
            if hparams is not None:
                raise ValueError(
                    "`self._hparams` is already assigned, but `hparams` "
                    "argument is not None.")

        self.pretrained_model_dir = None
        self.pretrained_model_name = pretrained_model_name

        if self.pretrained_model_name is None:
            self.pretrained_model_name = self._hparams.pretrained_model_name
        if self.pretrained_model_name is not None:
            self.pretrained_model_dir = self.download_checkpoint(
                self.pretrained_model_name, cache_dir)
            pretrained_model_hparams = self._transform_config(
                self.pretrained_model_name, self.pretrained_model_dir)
            self._hparams = HParams(
                pretrained_model_hparams, self._hparams.todict())

    def init_pretrained_weights(self, *args, **kwargs):
        r"""Initialize parameters from the pre-trained checkpoint if one
        was downloaded; otherwise fall back to :meth:`reset_parameters`.
        """
        if self.pretrained_model_dir:
            self._init_from_checkpoint(self.pretrained_model_name,
                                       self.pretrained_model_dir,
                                       *args, **kwargs)
        else:
            self.reset_parameters()

    def reset_parameters(self):
        r"""Initialize parameters of the pre-trained model. This method is
        only called if pre-trained checkpoints are not loaded.
        """
        pass

    @staticmethod
    def default_hparams():
        r"""Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                "pretrained_model_name": None,
                "name": "pretrained_base"
            }
        """
        return {
            'pretrained_model_name': None,
            'name': "pretrained_base",
            '@no_typecheck': ['pretrained_model_name']
        }

    @classmethod
    def download_checkpoint(cls, pretrained_model_name: str,
                            cache_dir: Optional[str] = None) -> str:
        r"""Download the specified pre-trained checkpoint, and return the
        directory in which the checkpoint is cached.

        Args:
            pretrained_model_name (str): Name of the model checkpoint.
            cache_dir (str, optional): Path to the cache directory. If
                `None`, uses the default directory (user's home directory).

        Returns:
            Path to the cache directory.
""" if pretrained_model_name in cls._MODEL2URL: download_path = cls._MODEL2URL[pretrained_model_name] else: raise ValueError( f"Pre-trained model not found: {pretrained_model_name}") if cache_dir is None: cache_path = default_download_dir(cls._MODEL_NAME) else: cache_path = Path(cache_dir) cache_path = cache_path / pretrained_model_name if not cache_path.exists(): if isinstance(download_path, str): filename = download_path.split('/')[-1] maybe_download(download_path, cache_path, extract=True) folder = None for file in cache_path.iterdir(): if file.is_dir(): folder = file assert folder is not None (cache_path / filename).unlink() for file in folder.iterdir(): file.rename(file.parents[1] / file.name) folder.rmdir() else: for path in download_path: maybe_download(path, cache_path) print(f"Pre-trained {cls._MODEL_NAME} checkpoint " f"{pretrained_model_name} cached to {cache_path}") else: print(f"Using cached pre-trained {cls._MODEL_NAME} checkpoint " f"from {cache_path}.") return str(cache_path) @classmethod @abstractmethod def _transform_config(cls, pretrained_model_name: str, cache_dir: str) -> Dict[str, Any]: r"""Load the official configuration file and transform it into Texar-style hyperparameters. Args: pretrained_model_name (str): Name of the pre-trained model. cache_dir (str): Path to the cache directory. Returns: dict: Texar module hyperparameters. """ raise NotImplementedError @abstractmethod def _init_from_checkpoint(self, pretrained_model_name: str, cache_dir: str, **kwargs): r"""Initialize model parameters from weights stored in the pre-trained checkpoint. Args: pretrained_model_name (str): Name of the pre-trained model. cache_dir (str): Path to the cache directory. **kwargs: Additional arguments for specific models. """ raise NotImplementedError