예제 #1
0
    def delete_r_config():
        """Reset the 'r_config' Variable to an empty JSON object ('{}').

        NOTE(review): takes no ``self``; presumably decorated with
        ``@staticmethod`` on a line outside this view — confirm.

        :raises ConfigVariableNotFoundException: if ``Variable.set`` raises
            KeyError. The original KeyError is chained as the cause.
        """
        try:
            Variable.set('r_config', '{}')
        except KeyError as err:
            # Chain the cause so the underlying KeyError stays in the traceback.
            raise ConfigVariableNotFoundException(
                "Variable 'r_config' not found !") from err
예제 #2
0
    def create_r_config(self, ids, session):
        """Rebuild the 'r_config' Variable from scratch for the given FailedDagRun ids.

        Groups each selected run's execution date, as the first 19 characters
        of ``str(execution_date)``, under its dag_id. Stores the resulting
        ``{dag_id: [date, ...]}`` mapping as JSON in the 'r_config' Variable.
        Then, for every id, sets the run's state to 'recovery_executed' and
        deletes the Variable keyed ``'<date>$<dag_id>'``.

        :param ids: FailedDagRun primary keys to schedule for recovery.
        :param session: SQLAlchemy session used for the queries and updates.
        """
        rows = session.query(FailedDagRun).filter(
            FailedDagRun.id.in_(ids)).all()

        r_obj = {}
        for d in rows:
            # Compare and store the same truncated string form. Comparing the
            # raw datetime against stored strings never matched, so duplicates
            # (and untruncated dates) used to be appended.
            date_key = str(d.execution_date)[:19]
            dates = r_obj.setdefault(d.dag_id, [])
            if date_key not in dates:
                dates.append(date_key)

        Variable.set(key='r_config', value=json.dumps(r_obj))

        for run_id in ids:
            # One lookup per id instead of two. .one() still raises if the id
            # does not exist, exactly as before.
            run = session.query(FailedDagRun).filter(
                FailedDagRun.id == run_id).one()
            execution_date = run.execution_date
            dag_id = run.dag_id

            session.query(FailedDagRun).filter(FailedDagRun.id == run_id).update(
                {'state': 'recovery_executed'}, synchronize_session='fetch')
            Variable.delete(
                key="{}${}".format(str(execution_date)[:19], dag_id))
예제 #3
0
 def test_variable_metastore_secrets_backend(self):
     """A Variable stored via Variable.set is readable through MetastoreBackend."""
     Variable.set(key="hello", value="World")
     backend = MetastoreBackend()
     # A stored key round-trips through the backend...
     self.assertEqual("World", backend.get_variable(key="hello"))
     # ...and an unknown key yields None.
     missing = backend.get_variable(key="non_existent_key")
     self.assertIsNone(missing)
예제 #4
0
 def test_variable_metastore_secrets_backend(self):
     """MetastoreBackend returns stored values, None for unknown keys, and '' for empty ones."""
     Variable.set(key="hello", value="World")
     Variable.set(key="empty_str", value="")
     backend = MetastoreBackend()
     assert backend.get_variable(key="hello") == "World"
     assert backend.get_variable(key="non_existent_key") is None
     # An empty string is a real stored value and must not collapse to None.
     assert backend.get_variable(key="empty_str") == ''
예제 #5
0
def test_variables_as_arguments_dag():
    """task_3 of the generated example_dag takes its bash_command from Airflow variable 'var1'.

    On Airflow >= 1.10.10 the value is supplied through the AIRFLOW_VAR_VAR1
    environment variable; on older versions it is written with Variable.set
    (presumably because env-var variables are not honoured there — confirm).
    """
    override_command = 'value_from_variable'
    if version.parse(AIRFLOW_VERSION) >= version.parse("1.10.10"):
        os.environ['AIRFLOW_VAR_VAR1'] = override_command
    else:
        Variable.set("var1", override_command)
    td = dagfactory.DagFactory(DAG_FACTORY_VARIABLES_AS_ARGUMENTS)
    td.generate_dags(globals())
    tasks = globals()['example_dag'].tasks
    matching = [task for task in tasks if task.task_id == "task_3"]
    # Fail loudly if task_3 is missing; the old loop passed vacuously.
    assert matching, "task_3 not found in example_dag"
    for task in matching:
        assert task.bash_command == override_command
예제 #6
0
def create_configuration_variables():
    """Set the 'config', 'r_config' and 'dag_creation_dates' Variables to their defaults.

    'config' receives the default settings below as JSON. 'r_config' and
    'dag_creation_dates' are both set to an empty JSON object.
    """
    default_config = {
        "tables": [],
        "start_date": "1da",
        "frequency": "hourly",
        "threshold": 10000,
        "export_format": "xml",
        "storage_type": "sftp",
        "email": ""
    }

    # (key, value) pairs, written in this order.
    initial_values = (
        ('config', json.dumps(default_config)),
        ('r_config', '{}'),
        ('dag_creation_dates', json.dumps({})),
    )
    for var_key, var_value in initial_values:
        Variable.set(key=var_key, value=var_value)
예제 #7
0
    def trigger_dag(self, ids, session=None):
        """Add the given FailedDagRun ids to 'r_config' and mark them for recovery.

        Merges each selected run's execution date, as the first 19 characters
        of ``str(execution_date)``, into the JSON mapping
        ``{dag_id: [date, ...]}`` stored in the 'r_config' Variable. Then, for
        every id, sets the run's state to 'recovery_executed' and deletes the
        Variable keyed ``'<date>$<dag_id>'``. If reading 'r_config' raises
        KeyError, 'r_config' is reset to '{}' and rebuilt via
        ``create_r_config``.

        :param ids: FailedDagRun primary keys to recover.
        :param session: SQLAlchemy session. It defaults to None yet is used
            unconditionally, so it is presumably injected by a decorator
            outside this view — confirm.
        """
        rows = session.query(FailedDagRun).filter(
            FailedDagRun.id.in_(ids)).all()

        # Guard only the lookup: a KeyError here presumably means 'r_config'
        # is missing. Guarding the whole body would also send failures that
        # happen after partial updates into the rebuild path.
        try:
            r_config = Variable.get(key='r_config')
        except KeyError as e:
            LoggingMixin().log.warning(str(e))
            Variable.set(key='r_config', value='{}')
            self.create_r_config(ids, session)
            return

        r_obj = json.loads(r_config)
        for d in rows:
            date_key = str(d.execution_date)[:19]
            if d.dag_id in r_obj:
                if date_key not in r_obj[d.dag_id]:
                    r_obj[d.dag_id].append(date_key)
            else:
                r_obj[d.dag_id] = [date_key]

        Variable.set(key='r_config', value=json.dumps(r_obj))

        for run_id in ids:
            # One lookup per id instead of two. .one() still raises if the id
            # does not exist.
            run = session.query(FailedDagRun).filter(
                FailedDagRun.id == run_id).one()
            execution_date = run.execution_date
            dag_id = run.dag_id

            session.query(FailedDagRun).filter(
                FailedDagRun.id == run_id).update(
                    {'state': 'recovery_executed'},
                    synchronize_session='fetch')
            Variable.delete(
                key="{}${}".format(str(execution_date)[:19], dag_id))
예제 #8
0
    def test_parse_bucket_key_from_jinja(self, mock_hook):
        """An s3:// URL rendered from a Jinja Variable template ends up split into bucket_name and bucket_key."""
        hook = mock_hook.return_value
        hook.check_for_key.return_value = False

        Variable.set("test_bucket_key", "s3://bucket/key")

        start = datetime(2020, 1, 1)

        sensor = S3KeySensor(
            task_id='s3_key_sensor',
            bucket_key='{{ var.value.test_bucket_key }}',
            bucket_name=None,
            dag=DAG("test_s3_key", start_date=start),
        )

        # Render the templated bucket_key the same way a real run would.
        instance = TaskInstance(task=sensor, execution_date=start)
        instance.render_templates(instance.get_template_context())

        sensor.poke(None)

        self.assertEqual(sensor.bucket_key, "key")
        self.assertEqual(sensor.bucket_name, "bucket")
예제 #9
0
def create_dags():
    """Render DAG files from Jinja templates and clean up DAGs no longer generated.

    1. For every table in ``config['tables']``, renders
       ``dags/templates/main.py.jinja2`` into
       ``dags/generated/dag_<table>.py``. A table's start date is taken from
       ``dag_creation_dates``, or computed and recorded there on first sight.
    2. For every (table, date) pair in ``r_config``, renders
       ``dags/templates/recovery_template.py.jinja2`` into
       ``dags/generated/r_dag_<table>_<date>.py``.
    3. For every DAG in the metadata DB whose file was not produced in this
       run (except dag_generator.py and dag_cleanup.py), removes the file,
       calls the experimental REST API to delete the DAG, and drops it from
       ``dag_creation_dates``. Failures are logged per DAG, and cleanup
       continues with the next one.
    4. Writes ``dag_creation_dates`` back to its Variable.

    Rebinds the module globals ``dag_creation_dates``, ``new_dags`` and
    ``email_notify_required``.

    :raises ConfigVariableNotFoundException: if an AirflowException escapes
        steps 1-4 (presumably a missing configuration Variable — confirm).
    """
    global dag_creation_dates
    global new_dags
    global email_notify_required

    new_dags = []

    dag_creation_dates = json.loads(Variable.get(key='dag_creation_dates'))
    email_notify_required = is_email_notification_required()

    try:
        airflow_home = configuration.get_airflow_home()
        templates_dir = airflow_home + '/dags/templates/'
        generated_dir = airflow_home + '/dags/generated/'

        main_template = None
        for table in config.get('tables'):
            # Load the template lazily and only once, instead of once per table.
            if main_template is None:
                with open(templates_dir + 'main.py.jinja2') as file_:
                    main_template = Template(file_.read())

            if dag_creation_dates.get(table) is not None:
                start_date = dag_creation_dates.get(table)
            else:
                start_date = get_start_date(config.get('start_date'))
                dag_creation_dates[table] = str(start_date)

            output = main_template.render(
                data={
                    'dag_id': table,
                    'frequency': config.get('frequency'),
                    'storage_type': storage_type,
                    'start_date': start_date,
                    'email_required': email_notify_required
                }
            )

            # Build the file name once so the written file and the name
            # recorded in new_dags cannot drift apart.
            dag_file = 'dag_' + '{}'.format(table).replace(' ', '_') + '.py'
            with open(generated_dir + dag_file, 'w') as f:
                f.write(output)
                new_dags.append(dag_file)

        if r_config:
            recovery_template = None
            for table in r_config:
                for exec_date in r_config.get(table):
                    # All spaces become 'T', so the name below has none left.
                    execution_date = str(exec_date).replace(' ', 'T')[0:19]
                    if recovery_template is None:
                        with open(templates_dir + 'recovery_template.py.jinja2') as file_:
                            recovery_template = Template(file_.read())
                    output = recovery_template.render(
                        data={'dag_id': table, 'frequency': config.get('frequency'),
                              'storage_type': storage_type,
                              'execution_date': execution_date})
                    r_dag_file = ('r_dag_' + '{}_{}'.format(table, execution_date).replace(' ', '_')
                                  + '.py')
                    with open(generated_dir + r_dag_file, 'w') as f:
                        f.write(output)
                        new_dags.append(r_dag_file)

        # Set lookup: O(1) membership per metadata record.
        generated = set(new_dags)
        md_dag_ids = settings.Session.query(Dags.dag_id, Dags.fileloc).all()

        for d_id, loc in md_dag_ids:
            filename = loc[str(loc).rfind('/') + 1:]
            if filename in ('dag_generator.py', 'dag_cleanup.py'):
                continue
            if filename in generated:
                continue
            try:
                if os.path.exists(str(loc)):
                    os.remove(str(loc))
                else:
                    LoggingMixin().log.warning("{} file doesn't exists !".format(filename))

                requests.delete(
                    url="http://{}:8080/api/experimental/dags/{}".format(
                        socket.gethostbyname(socket.gethostname()),
                        str(d_id)
                    ),
                    auth=(rest.login, rest.password)
                )

                # dag_creation_dates only gains keys for configured tables
                # (see the first loop), so a DAG id may have no entry. A bare
                # pop() used to raise KeyError and log a spurious error.
                dag_creation_dates.pop(d_id, None)

            except Exception as err:
                # Best effort: one failing DAG must not stop cleanup of the rest.
                LoggingMixin().log.error(str(err))

        Variable.set(key='dag_creation_dates', value=json.dumps(dag_creation_dates))

    except AirflowException as err:
        raise ConfigVariableNotFoundException() from err