def model_training(self) -> TorchModelTraining:
    """Return the wrapped training task, building it on first access.

    The class named by ``self.model_task_class`` is resolved with
    ``load_attr`` and handed, together with ``self.model_task_id``, to
    ``load_torch_model_training_from_task_id``. The result is stored on
    ``self._model_training`` so subsequent calls return the same object.
    """
    try:
        return self._model_training
    except AttributeError:
        # Not cached yet; fall through and build it below.
        pass
    training_cls = load_attr(self.model_task_class, Type[TorchModelTraining])
    self._model_training = load_torch_model_training_from_task_id(
        training_cls, self.model_task_id
    )
    return self._model_training
def project_config(self) -> ProjectConfig:
    """Return the project configuration for this task, loading and caching it once.

    The object named by ``self.project`` is resolved with ``load_attr`` and
    deep-copied, so the adjustment below does not mutate the object that
    ``load_attr`` returned. When ``self.loss_function == "crm"``, a numeric
    ``Column`` for ``propensity_score_column_name`` is appended to
    ``auxiliar_output_columns`` if the membership check below says it is absent.
    The finished config is cached on ``self._project_config``.
    """
    if not hasattr(self, "_project_config"):
        # Build into a local and publish to the cache only when complete:
        # previously the cache was set before the CRM adjustment, so a failure
        # there left a half-built config cached, and the adjustment re-entered
        # this accessor via ``self.project_config`` while still initialising.
        config = deepcopy(load_attr(self.project, ProjectConfig))
        # NOTE(review): this compares a column *name* against a list that holds
        # ``Column`` objects (see the append below). Unless ``Column`` compares
        # equal to its name, the check is always true and could duplicate an
        # existing propensity column -- confirm against ``Column``'s definition.
        if (
            self.loss_function == "crm"
            and config.propensity_score_column_name
            not in config.auxiliar_output_columns
        ):
            config.auxiliar_output_columns = [
                *config.auxiliar_output_columns,
                Column(config.propensity_score_column_name, IOType.NUMBER),
            ]
        self._project_config = config
    return self._project_config
def get_direct_estimator(self, extra_params: dict) -> TorchModelTraining:
    """Instantiate the configured direct-estimator training task.

    Every entry of ``self.model_training.param_kwargs`` whose key appears in
    ``get_attribute_names(estimator_class)`` is forwarded to the estimator.
    ``extra_params`` is merged on top and wins on key collisions.

    :param extra_params: keyword arguments for the estimator; these override
        same-named values taken from ``model_training``.
    :return: a new instance of the class named by ``direct_estimator_class``.
    :raises ValueError: if ``direct_estimator_class`` is ``None``.
    """
    # Explicit check rather than ``assert``: asserts are removed under
    # ``python -O``, which would push the failure into ``load_attr`` with a
    # far less helpful error.
    if self.direct_estimator_class is None:
        raise ValueError(
            "direct_estimator_class must be set to build a direct estimator"
        )
    estimator_class = load_attr(self.direct_estimator_class, Type[TorchModelTraining])
    attribute_names = get_attribute_names(estimator_class)
    # Forward only the parameters the estimator class declares.
    params = {
        key: value
        for key, value in self.model_training.param_kwargs.items()
        if key in attribute_names
    }
    return estimator_class(**{**params, **extra_params})
def create_agent(self) -> BanditAgent:
    """Wrap the configured bandit policy in a ``BanditAgent``.

    The policy class named by ``self.bandit_policy_class`` is resolved with
    ``load_attr``. It is constructed with the result of
    ``self.get_trained_module()`` as ``reward_model`` plus the keyword
    arguments in ``self.bandit_policy_params``.
    """
    policy_cls = load_attr(self.bandit_policy_class, Type[BanditPolicy])
    reward_model = self.get_trained_module()
    policy = policy_cls(reward_model=reward_model, **self.bandit_policy_params)
    return BanditAgent(policy)
def module_class(self) -> Type[RecommenderModule]:
    """Return the recommender module class, resolving it on first access.

    The dotted reference in ``self.recommender_module_class`` is resolved with
    ``load_attr`` and cached on ``self._module_class``.
    """
    if hasattr(self, "_module_class"):
        return self._module_class
    self._module_class = load_attr(
        self.recommender_module_class, Type[RecommenderModule]
    )
    return self._module_class