def _get_mock_classy_vision_model(self, trainable_params=True):
    """Build a ``ClassyModel`` whose parameter accessors are mocked.

    ``get_optimizer_params`` and ``parameters`` on the returned model are
    replaced with ``MagicMock`` objects. Both mocks are derived from a
    single call to ``self._get_optimizer_params()``, so they refer to the
    same underlying tensors.

    Args:
        trainable_params: If True, ``get_optimizer_params`` returns the dict
            from ``self._get_optimizer_params()`` and ``parameters`` returns
            the regularized + unregularized params from that same dict. If
            False, ``get_optimizer_params`` returns empty lists for both keys
            and ``parameters`` returns ``detach()``-ed versions of those
            params.

    Returns:
        The ``ClassyModel`` instance with the mocked methods attached.
    """
    mock_classy_vision_model = ClassyModel()
    # Fetch once: repeated calls may construct distinct tensor objects, which
    # would make get_optimizer_params() and parameters() disagree.
    optimizer_params = self._get_optimizer_params()
    all_params = (
        optimizer_params["regularized_params"]
        + optimizer_params["unregularized_params"]
    )
    if trainable_params:
        mock_classy_vision_model.get_optimizer_params = MagicMock(
            return_value=optimizer_params
        )
        mock_classy_vision_model.parameters = MagicMock(return_value=all_params)
    else:
        mock_classy_vision_model.get_optimizer_params = MagicMock(
            return_value={"regularized_params": [], "unregularized_params": []}
        )
        # detach() returns tensors with requires_grad=False, so none of these
        # parameters count as trainable.
        mock_classy_vision_model.parameters = MagicMock(
            return_value=[param.detach() for param in all_params]
        )
    return mock_classy_vision_model
def _validate_and_get_optimizer_params(
    self, model: ClassyModel
) -> Dict[str, Any]:
    """Validate and return the optimizer params of ``model``.

    The optimizer params are fetched from
    :func:`models.ClassyModel.get_optimizer_params`. If ``model`` is a
    :class:`torch.nn.parallel.DistributedDataParallel` wrapper, the wrapped
    ``model.module`` is queried instead.

    Args:
        model: The model to get the params from.

    Returns:
        A dict containing "regularized_params" and "unregularized_params".
        Weight decay will only be applied to "regularized_params".

    Raises:
        AssertionError: If ``get_optimizer_params()`` does not return a dict
            with exactly the keys "regularized_params" and
            "unregularized_params", or if the total number of params it lists
            differs from the number of ``requires_grad`` params of ``model``.
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        optimizer_params = model.module.get_optimizer_params()
    else:
        optimizer_params = model.get_optimizer_params()

    model_name = type(model).__name__

    # Raise explicitly instead of using ``assert`` so the checks still run
    # under ``python -O``; AssertionError is kept for existing callers.
    if not (
        isinstance(optimizer_params, dict)
        and set(optimizer_params.keys())
        == {"regularized_params", "unregularized_params"}
    ):
        raise AssertionError(
            "get_optimizer_params() of {0} should return dict with exact two keys "
            "'regularized_params', 'unregularized_params'".format(model_name)
        )

    trainable_params = [
        params for params in model.parameters() if params.requires_grad
    ]
    # Only the counts are compared, not tensor identity.
    num_optimizer_params = len(optimizer_params["regularized_params"]) + len(
        optimizer_params["unregularized_params"]
    )
    if len(trainable_params) != num_optimizer_params:
        raise AssertionError(
            "get_optimizer_params() of {0} should return params that cover all "
            "trainable params of model".format(model_name)
        )
    return optimizer_params