Esempio n. 1
0
    def model_training(self) -> TorchModelTraining:
        """Return the model-training object for ``self.model_task_id``, loading it once.

        On the first call the training class named by ``self.model_task_class``
        is resolved with ``load_attr``. That class and ``self.model_task_id``
        are then handed to ``load_torch_model_training_from_task_id``. The
        result is stored on ``self._model_training`` and returned unchanged
        by later calls.

        NOTE(review): the decorator (presumably ``@property``) is not visible
        in this snippet — confirm how callers access this.
        """
        try:
            # Fast path: already loaded on an earlier call.
            return self._model_training
        except AttributeError:
            pass

        training_cls = load_attr(self.model_task_class, Type[TorchModelTraining])
        self._model_training = load_torch_model_training_from_task_id(
            training_cls, self.model_task_id)
        return self._model_training
Esempio n. 2
0
 def project_config(self) -> ProjectConfig:
     """Return this task's project configuration, loading and caching it on first use.

     The config named by ``self.project`` is resolved with ``load_attr`` and
     deep-copied, so the adjustment below is made on a private copy.

     When ``self.loss_function == "crm"``, a ``Column`` of type
     ``IOType.NUMBER`` named after ``propensity_score_column_name`` is
     appended to ``auxiliar_output_columns``, unless the membership test
     finds it already present.

     NOTE(review): the decorator (presumably ``@property``) is not visible
     in this snippet — confirm how callers access this.
     """
     if not hasattr(self, "_project_config"):
         # Build the config in a local and cache it only once it is complete.
         # A failure part-way through then does not leave a half-built config
         # cached for every later call. It also avoids re-entering this
         # accessor while it is still initialising.
         config = deepcopy(load_attr(self.project, ProjectConfig))
         if self.loss_function == "crm":
             propensity_column = config.propensity_score_column_name
             # NOTE(review): this compares a column *name* against a list of
             # ``Column`` objects. Unless ``Column`` compares equal to plain
             # names, the test is always true and a duplicate could be
             # appended — confirm against ``Column``'s definition.
             if propensity_column not in config.auxiliar_output_columns:
                 config.auxiliar_output_columns = [
                     *config.auxiliar_output_columns,
                     Column(propensity_column, IOType.NUMBER),
                 ]
         self._project_config = config
     return self._project_config
Esempio n. 3
0
    def get_direct_estimator(self, extra_params: dict) -> TorchModelTraining:
        """Instantiate the configured direct-estimator training class.

        The class named by ``self.direct_estimator_class`` is resolved with
        ``load_attr``. It is constructed from the entries of
        ``self.model_training.param_kwargs`` whose keys appear in
        ``get_attribute_names(estimator_class)``, overlaid with
        ``extra_params``; ``extra_params`` wins on key collisions.

        :param extra_params: keyword arguments for the estimator's constructor;
            they override any forwarded parameter with the same key.
        :return: the newly constructed estimator.
        :raises ValueError: if ``self.direct_estimator_class`` is ``None``.
        """
        # Explicit check instead of ``assert``, so the validation survives
        # ``python -O``.
        if self.direct_estimator_class is None:
            raise ValueError(
                "direct_estimator_class must be set to build a direct estimator")
        estimator_class = load_attr(self.direct_estimator_class,
                                    Type[TorchModelTraining])

        # Materialise the names once as a set. This gives O(1) membership per
        # key and stays correct even if the helper returns a one-shot iterator.
        attribute_names = set(get_attribute_names(estimator_class))

        params = {
            key: value
            for key, value in self.model_training.param_kwargs.items()
            if key in attribute_names
        }

        return estimator_class(**{**params, **extra_params})
Esempio n. 4
0
 def create_agent(self) -> BanditAgent:
     """Wrap a newly built bandit policy in a ``BanditAgent``.

     The policy class is resolved from ``self.bandit_policy_class`` via
     ``load_attr``. It is constructed with ``self.get_trained_module()`` as
     ``reward_model``, plus the keyword arguments in
     ``self.bandit_policy_params``.
     """
     policy_cls = load_attr(self.bandit_policy_class, Type[BanditPolicy])
     reward_model = self.get_trained_module()
     policy = policy_cls(reward_model=reward_model,
                         **self.bandit_policy_params)
     return BanditAgent(policy)
Esempio n. 5
0
 def module_class(self) -> Type[RecommenderModule]:
     """Resolve ``self.recommender_module_class`` to a class, caching the result.

     ``load_attr`` runs only on the first call. Later calls return the value
     stored in ``self._module_class``.

     NOTE(review): the decorator (presumably ``@property``) is not visible
     in this snippet — confirm how callers access this.
     """
     if hasattr(self, "_module_class"):
         return self._module_class

     self._module_class = load_attr(self.recommender_module_class,
                                    Type[RecommenderModule])
     return self._module_class