Example #1
0
    def _finish_task(self, **kwargs) -> None:
        """Mark the task belonging to the current ml_dag_id as finished.

        The ml_dag_id is resolved from the Airflow context, then a new
        ``self._repository_class`` instance bound to ``self._engine`` is asked
        to finish the task (i.e. write datetime_finished to the task table).

        Args:
            **kwargs: Airflow context, forwarded to ``dag_utils.get_ml_dag_id``.
        """
        self.log.debug(f'kwargs: {kwargs}')

        current_ml_dag_id = dag_utils.get_ml_dag_id(
            parent_dag_id=self._parent_dag_id,
            **kwargs
        )
        repository = self._repository_class(engine=self._engine)
        repository.finish_task(ml_dag_id=current_ml_dag_id)
Example #2
0
    def _initialize_task(self,
                         **kwargs) -> None:
        """ Inserts task with ml_dag_id into DB, if it doesn't already exist in DB

        The ml_dag_id is resolved from the Airflow context. A ``DBException``
        raised while inserting is treated as non-fatal (the row may already
        exist), but it is logged instead of being discarded silently.

        Args:
            **kwargs: Airflow context

        """
        self.log.debug(f'kwargs: {kwargs}')

        ml_dag_id = dag_utils.get_ml_dag_id(parent_dag_id=self._parent_dag_id, **kwargs)

        try:
            self._repository_class(engine=self._engine).insert_task_with_ml_dag_id(ml_dag_id=ml_dag_id)
        except DBException as exc:
            # Best-effort insert: an existing row is the expected case, but
            # DBException may also cover other DB failures (its hierarchy is not
            # visible here), so leave a trace rather than swallowing it.
            self.log.info('Task with ml_dag_id %s was not inserted: %s', ml_dag_id, exc)
Example #3
0
    def _execute_or_skip_task(self,
                              **kwargs) -> str:
        """Choose the branch to follow, depending on whether the task already finished.

        The repository (``self._repository_class`` bound to ``self._engine``)
        is queried for the ml_dag_id resolved from the Airflow context.

        Args:
            **kwargs: Airflow context, forwarded to ``dag_utils.get_ml_dag_id``.

        Returns:
            ``'skip_<child_dag_id>'`` if the repository reports the task as
            finished, otherwise ``'start_task_in_db_<child_dag_id>'``.
        """
        self.log.debug(f'kwargs: {kwargs}')

        current_ml_dag_id = dag_utils.get_ml_dag_id(parent_dag_id=self._parent_dag_id, **kwargs)
        repository = self._repository_class(engine=self._engine)
        already_finished = repository.is_task_finished(ml_dag_id=current_ml_dag_id)

        branch_template = 'skip_{}' if already_finished else 'start_task_in_db_{}'
        return branch_template.format(self._child_dag_id)
Example #4
0
    def _parameters_provider(self, **kwargs) -> str:
        """ Callable that provides additional parameters for Parametrized Bash Operator related to this subdag

        Args:
            **kwargs: Airflow context

        Returns: additional parameters for Parametrized Bash Operator, as one
            space-separated string: the resolved ml_dag_id followed by
            'common_task_1_param_1' and 'common_task_1_param_2'.

        """
        # The full Airflow context is large; log it at DEBUG like the sibling
        # callables do, instead of flooding task logs at INFO.
        self.log.debug(f'kwargs: {kwargs}')
        # Pass parent_dag_id by keyword, consistent with the other callers.
        ml_dag_id = dag_utils.get_ml_dag_id(parent_dag_id=self._parent_dag_id, **kwargs)

        self.log.info(f'ml_dag_id: {ml_dag_id}')

        parameters = [
            str(ml_dag_id), 'common_task_1_param_1', 'common_task_1_param_2'
        ]

        return ' '.join(parameters)