def _get_optimizer_from_params(
    self, model: RunnerModel, stage: str, params: "DictConfig"
) -> Optimizer:
    """Build an optimizer for ``model`` from an optimizer config.

    Works on a deep copy of ``params``, so the object passed in by the
    caller is not modified. The keys ``lr_linear_scaling``,
    ``layerwise_params``, ``no_bias_weight_decay`` and ``_model`` are
    popped off the copy; what remains is handed to
    ``hydra.utils.instantiate`` together with the collected model
    parameters.

    Args:
        model: forwarded to ``get_model_parameters`` as ``models``
        stage: key into ``self._config.stages``; only read when
            ``lr_linear_scaling`` is present and truthy
        params: optimizer config

    Returns:
        the object produced by ``hydra.utils.instantiate``
    """
    # @TODO 1: refactor; this method is too long
    optimizer_config = deepcopy(params)

    # learning rate linear scaling: the factor stays at 1.0 unless the
    # config explicitly asks for scaling
    lr_scaling = 1.0
    scaling_config = optimizer_config.pop("lr_linear_scaling", None)
    if scaling_config:
        loaders_config = self._config.stages[stage].loaders
        batch_size = loaders_config.get("batch_size", 1)
        per_gpu_scaling = loaders_config.get("per_gpu_scaling", False)
        scaled_lr, lr_scaling = do_lr_linear_scaling(
            lr_scaling_params=scaling_config,
            batch_size=batch_size,
            per_gpu_scaling=per_gpu_scaling,
        )
        # the computed learning rate overrides any "lr" given in the config
        optimizer_config["lr"] = scaled_lr

    # layer-wise settings are consumed here, not passed to the optimizer
    layerwise_config = optimizer_config.pop("layerwise_params", OrderedDict())
    skip_bias_decay = optimizer_config.pop("no_bias_weight_decay", True)

    # optional key selecting which model(s) the parameters come from
    selected_model_key = optimizer_config.pop("_model", None)
    parameters = get_model_parameters(
        models=model,
        models_keys=selected_model_key,
        layerwise_params=layerwise_config,
        no_bias_weight_decay=skip_bias_decay,
        lr_scaling=lr_scaling,
    )

    # instantiate optimizer from the remaining config entries
    return hydra.utils.instantiate(optimizer_config, params=parameters)
def _get_optimizer_from_params(
    self, model: RunnerModel, stage: str, **params
) -> RunnerOptimizer:
    """Build an optimizer for ``model`` from keyword optimizer settings.

    The keys ``lr_linear_scaling``, ``layerwise_params``,
    ``no_bias_weight_decay`` and ``_model`` are popped out of ``params``;
    the remaining keyword arguments are forwarded to
    ``REGISTRY.get_from_params`` together with the collected model
    parameters.

    Args:
        model: forwarded to ``get_model_parameters`` as ``models``
        stage: key into ``self._stage_config``; only read when
            ``lr_linear_scaling`` is present and truthy
        **params: optimizer settings

    Returns:
        the object produced by ``REGISTRY.get_from_params``
    """
    # @TODO 1: refactor; this method is too long

    # learning rate linear scaling: the factor stays at 1.0 unless the
    # settings explicitly ask for scaling
    lr_scaling = 1.0
    scaling_config = params.pop("lr_linear_scaling", None)
    if scaling_config:
        loaders_config = dict(self._stage_config[stage]["loaders"])
        batch_size = loaders_config.get("batch_size", 1)
        per_gpu_scaling = loaders_config.get("per_gpu_scaling", False)
        scaled_lr, lr_scaling = do_lr_linear_scaling(
            lr_scaling_params=scaling_config,
            batch_size=batch_size,
            per_gpu_scaling=per_gpu_scaling,
        )
        # the computed learning rate overrides any "lr" passed by the caller
        params["lr"] = scaled_lr

    # layer-wise settings are consumed here, not passed to the optimizer
    layerwise_config = params.pop("layerwise_params", OrderedDict())
    skip_bias_decay = params.pop("no_bias_weight_decay", True)

    # optional key selecting which model(s) the parameters come from
    selected_model_key = params.pop("_model", None)
    parameters = get_model_parameters(
        models=model,
        models_keys=selected_model_key,
        layerwise_params=layerwise_config,
        no_bias_weight_decay=skip_bias_decay,
        lr_scaling=lr_scaling,
    )

    # instantiate optimizer from the remaining keyword settings
    return REGISTRY.get_from_params(**params, params=parameters)