Code Example #1
def preparation(**kwargs):
    # Without this global declaration, this DAG raised the following error on the EC2 server:
    #     UnboundLocalError: local variable 'VPC_ID' referenced before assignment
    global VPC_ID, SUBNET_ID, CLUSTER_NAME
    Variable.delete('cluster_id')
    Variable.delete('keypair_name')
    Variable.delete('master_sg_id')
    Variable.delete('slave_sg_id')
    Variable.delete('short_interests_dag_state')

    ec2, emr, iam = emrs.get_boto_clients(config['AWS']['REGION_NAME'],
                                          config=config)

    if VPC_ID == '':
        VPC_ID = emrs.get_first_available_vpc(ec2)

    if SUBNET_ID == '':
        SUBNET_ID = emrs.get_first_available_subnet(ec2, VPC_ID)

    master_sg_id = emrs.create_security_group(
        ec2, '{}SG'.format(CLUSTER_NAME),
        'Master SG for {}'.format(CLUSTER_NAME), VPC_ID)
    slave_sg_id = emrs.create_security_group(
        ec2, '{}SlaveSG'.format(CLUSTER_NAME),
        'Slave SG for {}'.format(CLUSTER_NAME), VPC_ID)

    Variable.set('master_sg_id', master_sg_id)
    Variable.set('slave_sg_id', slave_sg_id)

    keypair = emrs.create_key_pair(ec2, '{}_pem'.format(CLUSTER_NAME))
    Variable.set('keypair_name', keypair['KeyName'])

    emrs.create_default_roles(iam)
Code Example #2
    def test_weekly_schedule_conversion(self):
        """
        Test that weekly schedule is converted properly into cron expression.
        """
        report_saver = ReportFormSaver(self.report_form_sample_weekly)
        report_saver.extract_report_data_into_airflow(report_exists=False)
        report_airflow_variable = Variable.get(
            "rb_status_" + self.report_form_sample_weekly.report_title.data,
            deserialize_json=True,
        )

        time = self.report_form_sample_weekly.schedule_time.data
        week_day = int(self.report_form_sample_weekly.schedule_week_day.data)
        tz = self.report_form_sample_weekly.schedule_timezone.data

        before_dt = (
            pendulum.now().in_tz(tz).next(int(week_day)).at(time.hour, time.minute, 0)
        )

        after_dt = before_dt.in_tz("UTC")

        Variable.delete("rb_status_" + self.report_form_sample_weekly.report_title.data)

        self.assertEqual(
            f"{after_dt.minute} {after_dt.hour} * * {after_dt.day_of_week}",
            report_airflow_variable["schedule"],
        )
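
The weekly test above pins down the conversion contract: take the next occurrence of the configured weekday at the configured local time, shift it to UTC, and emit a five-field cron expression. Below is a minimal sketch of that conversion, assuming only pendulum and mirroring what the assertion checks; the helper name is illustrative, and the daily and timezone tests further down exercise the same pattern without the weekday.

import pendulum

def weekly_to_utc_cron(time, week_day, tz):
    # Next occurrence of week_day at the given local wall-clock time ...
    local_dt = (
        pendulum.now().in_tz(tz).next(int(week_day)).at(time.hour, time.minute, 0)
    )
    # ... shifted to UTC and rendered as "minute hour * * weekday".
    utc_dt = local_dt.in_tz("UTC")
    return f"{utc_dt.minute} {utc_dt.hour} * * {utc_dt.day_of_week}"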
Code Example #3
    def test_conversion_to_default_timezone(self):
        """
        Tests that the schedule time is converted to the Airflow default
        timezone in the backend (e.g. America/Chicago -> UTC).
        """
        report_saver = ReportFormSaver(self.report_form_sample_timezone_daily)
        report_saver.extract_report_data_into_airflow(report_exists=False)
        report_airflow_variable = Variable.get(
            "rb_status_" + self.report_form_sample_timezone_daily.report_title.data,
            deserialize_json=True,
        )

        time = self.report_form_sample_timezone_daily.schedule_time.data
        tz = self.report_form_sample_timezone_daily.schedule_timezone.data

        before_dt = pendulum.now().in_tz(tz).at(time.hour, time.minute, 0)

        after_dt = before_dt.in_tz("UTC")

        Variable.delete(
            "rb_status_" + self.report_form_sample_timezone_daily.report_title.data
        )
        self.assertEqual(
            report_airflow_variable["schedule_time"], after_dt.strftime("%H:%M")
        )
Code Example #4
    def teardown_class(cls):
        """
        Delete the airflow variable.
        """
        print("Removing airflow variable...")
        Variable.delete("rb_status_" +
                        cls.report_form_sample.report_title.data)
Code Example #5
def cleanup(**kwargs):
    ec2, emr, iam = emrs.get_boto_clients(config['AWS']['REGION_NAME'],
                                          config=config)
    ec2.delete_key_pair(KeyName=Variable.get('keypair_name'))
    emrs.delete_security_group(ec2, Variable.get('master_sg_id'))
    time.sleep(2)
    emrs.delete_security_group(ec2, Variable.get('slave_sg_id'))
    Variable.delete('cluster_id')
    Variable.delete('keypair_name')
    Variable.delete('master_sg_id')
    Variable.delete('slave_sg_id')
    Variable.delete('short_interests_dag_state')
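
Code Examples #1 and #5 form a matched pair: preparation provisions the AWS resources and stashes their identifiers in Variables, and cleanup tears everything down and deletes those Variables. Here is a hedged sketch of how the pair might be wired into a DAG; the DAG name, schedule, and import path (Airflow 2.x) are assumptions, and trigger_rule='all_done' is chosen so the resources are released even when an upstream task fails.

from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator

with DAG('short_interests', start_date=datetime(2021, 1, 1),
         schedule_interval=None) as dag:
    prepare = PythonOperator(task_id='preparation', python_callable=preparation)
    # all_done: run cleanup regardless of upstream success or failure
    clean = PythonOperator(task_id='cleanup', python_callable=cleanup,
                           trigger_rule='all_done')
    prepare >> clean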
Code Example #6
    def test_write(self):
        """
        Test records can be written and overwritten
        """
        Variable.set(key="test_key", value="test_val")

        session = settings.Session()
        result = session.query(RTIF).all()
        assert [] == result

        with DAG("test_write", start_date=START_DATE):
            task = BashOperator(task_id="test",
                                bash_command="echo {{ var.value.test_key }}")

        rtif = RTIF(TI(task=task, execution_date=EXECUTION_DATE))
        rtif.write()
        result = (session.query(RTIF.dag_id, RTIF.task_id,
                                RTIF.rendered_fields).filter(
                                    RTIF.dag_id == rtif.dag_id,
                                    RTIF.task_id == rtif.task_id,
                                    RTIF.execution_date == rtif.execution_date,
                                ).first())
        assert ('test_write', 'test', {
            'bash_command': 'echo test_val',
            'env': None
        }) == result

        # Test that overwrite saves new values to the DB
        Variable.delete("test_key")
        Variable.set(key="test_key", value="test_val_updated")

        with DAG("test_write", start_date=START_DATE):
            updated_task = BashOperator(
                task_id="test", bash_command="echo {{ var.value.test_key }}")

        rtif_updated = RTIF(
            TI(task=updated_task, execution_date=EXECUTION_DATE))
        rtif_updated.write()

        result_updated = (session.query(
            RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields).filter(
                RTIF.dag_id == rtif_updated.dag_id,
                RTIF.task_id == rtif_updated.task_id,
                RTIF.execution_date == rtif_updated.execution_date,
            ).first())
        assert (
            'test_write',
            'test',
            {
                'bash_command': 'echo test_val_updated',
                'env': None
            },
        ) == result_updated
Code Example #7
def update_schemas(**kwargs):
    schemas = get_all_schemas()
    # we update all schemas that we found:
    for key, value in schemas.items():
        Variable.set(key=key, value=value, serialize_json=True)
    # now we clean the variables that do not exist anymore:
    with create_session() as session:
        current_vars = set(var.key for var in session.query(Variable))
        apps_to_delete = current_vars - schemas.keys()
        print("About to delete old apps: {}".format(apps_to_delete))
        for _var in apps_to_delete:
            Variable.delete(_var, session)
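
The serialize_json=True flag used above stores a Python object as JSON text in the metadata database; deserialize_json=True restores it on read. A minimal round-trip using only the Variable API (the key and payload are illustrative):

from airflow.models import Variable

Variable.set(key='my_app_schema', value={'fields': ['id', 'name']},
             serialize_json=True)
conf = Variable.get('my_app_schema', deserialize_json=True)
assert conf == {'fields': ['id', 'name']}
Variable.delete('my_app_schema')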
Code Example #8
def resend_reports():
    with create_session() as session:
        for var in session.query(Variable):
            if any(prefix in var.key for prefix in ["post_progress__", "post_results__"]):
                logging.debug(f"Retreive {var.key} from Variables")
                value = Variable.get(key=var.key, deserialize_json=True)
                try:
                    http_hook.run(
                        endpoint=value["endpoint"],
                        json=value["message"],
                        extra_options={"timeout": 30}  # need to have timeout otherwise may get stuck forever
                    )
                    Variable.delete(key=var.key)
                    logging.debug(f"Value from {var.key} variable has been successfully sent")
                except Exception as err:
                    logging.debug(f"Failed to POST value from {var.key} variable. Will retry in the next run \n {err}")
Code Example #9
File: _variables.py Project: fossabot/docker-airflow
def main(session=None):
    print("Variables ********************")
    created_variables = set()
    for variable in variables:
        key = variable["key"]
        print(f"Create: {key}")
        Variable.set(**variable)
        created_variables.add(key)
    for variable in session.query(Variable).all():
        key = variable.key
        if key in created_variables:
            continue
        print(f"Delete: {key}")
        Variable.delete(key)
    return
Code Example #10
def delete_variable(variable_key: str) -> Response:
    """
    Delete variable
    """
    if Variable.delete(variable_key) == 0:
        raise NotFound("Variable not found")
    return Response(status=204)
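
The endpoint relies on Variable.delete returning the number of rows removed, mapping 0 to a 404 and anything else to a 204. A hedged client-side sketch against the Airflow stable REST API; the base URL and credentials are assumptions for illustration.

import requests

resp = requests.delete(
    'http://localhost:8080/api/v1/variables/my_key',
    auth=('admin', 'admin'),
)
print(resp.status_code)  # 204: deleted; 404: no Variable with that key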
Code Example #11
    def test_variable_delete(self):
        key = "tested_var_delete"
        value = "to be deleted"

        # No-op if the variable doesn't exist
        Variable.delete(key)
        with pytest.raises(KeyError):
            Variable.get(key)

        # Set the variable
        Variable.set(key, value)
        assert value == Variable.get(key)

        # Delete the variable
        Variable.delete(key)
        with pytest.raises(KeyError):
            Variable.get(key)
Code Example #12
    def test_variable_delete(self):
        key = "tested_var_delete"
        value = "to be deleted"

        # No-op if the variable doesn't exist
        Variable.delete(key)
        with self.assertRaises(KeyError):
            Variable.get(key)

        # Set the variable
        Variable.set(key, value)
        self.assertEqual(value, Variable.get(key))

        # Delete the variable
        Variable.delete(key)
        with self.assertRaises(KeyError):
            Variable.get(key)
Code Example #13
File: resend_results.py Project: NSAPH/cwl-airflow
def resend_results():
    with create_session() as session:
        for var in session.query(Variable):
            if any(prefix in var.key
                   for prefix in ["post_progress", "post_results"]):
                logging.info(f"Retreive {var.key} from Variables")
                value = Variable.get(key=var.key, deserialize_json=True)
                try:
                    http_hook.run(endpoint=value["endpoint"],
                                  json=value["message"])
                    Variable.delete(key=var.key)
                    logging.info(
                        f"Value from {var.key} variable has been successfully sent"
                    )
                except Exception as err:
                    logging.info(
                        f"Failed to POST value from {var.key} variable. Will retry in the next run \n {err}"
                    )
Code Example #14
    def test_daily_schedule_conversion(self):
        """
        Test that daily schedule is converted properly into cron expression.
        """
        Variable.delete("rb_status_" +
                        self.report_form_sample_daily.report_title.data)
        report_saver = ReportFormSaver(self.report_form_sample_daily)
        report_saver.extract_report_data_into_airflow(report_exists=False)
        report_airflow_variable = Variable.get(
            "rb_status_" + self.report_form_sample_daily.report_title.data,
            deserialize_json=True,
        )

        time = self.report_form_sample_daily.schedule_time.data
        tz = self.report_form_sample_daily.schedule_timezone.data

        before_dt = pendulum.now().in_tz(tz).at(time.hour, time.minute, 0)
        after_dt = before_dt.in_tz("UTC")

        Variable.delete("rb_status_" +
                        self.report_form_sample_daily.report_title.data)
        assert (f"{after_dt.minute} {after_dt.hour} * * *" ==
                report_airflow_variable["schedule"])
Code Example #15
File: servicenow.py Project: shahbaz-ali/miniBRS
def clean_up(dag_id, execution_date, session=None):

    # check for empty
    if is_empty(dag_id) or is_empty(execution_date) or is_empty(session):
        raise InvalidArguments(
            "dag_id, execution_date and session can't be empty")

    # check for none
    if dag_id is None or execution_date is None:
        raise InvalidArguments("dag_id and execution_date can't be None")

    try:

        search = pendulum.strptime(execution_date, "%Y-%m-%dT%H:%M:%S")

        execution_date = execution_date.replace('T', ' ')

        r_config = json.loads(Variable.get("r_config"))
        if dag_id in r_config:
            exec_dates = r_config[dag_id]
            if execution_date in exec_dates:
                exec_dates.remove(execution_date)
                r_config[dag_id] = exec_dates

            if len(r_config[dag_id]) == 0:
                del r_config[dag_id]
        if len(r_config) != 0:
            Variable.set(key="r_config", value=json.dumps(r_config))
        else:
            Variable.delete('r_config')

        # update airflow meta-database
        session.query(FailedDagRun).filter(FailedDagRun.dag_id == dag_id, FailedDagRun.execution_date.like(search))\
            .update({'state': 'recovery_executed'}, synchronize_session='fetch')

    except Exception as e:
        LoggingMixin().log.error(e)
Code Example #16
def execute_insert_into(target_db,
                        table_name=None,
                        run_fetch_task_id=None,
                        field_mapping=None,
                        task_instance=None,
                        **kwargs):
    """Inserts each paginated response data into target database table.
    Polls to find variable hasn't been processed, generates regarding sql statement to
    insert data in, incrementally waits for new variables.
    Success depends on fetcher task completion.
    """

    insert_into_sql = """
        INSERT INTO {{ table_name }} (
        {% for _, tt_field_name, _ in field_mapping %}
            {{ tt_field_name }}{{ "," if not loop.last }}
        {% endfor %}
        )
        VALUES
        {% for record in record_subset %}
        (
            {% for st_field_name, _, _ in field_mapping %}
                {% if not record[st_field_name] or record[st_field_name] == 'None' %}
                    NULL
                {% else %}
                    {{ record[st_field_name] }}
                {% endif %}
                {{ "," if not loop.last }}
            {% endfor %}
        {{ ")," if not loop.last }}
        {% endfor %}
        );
    """
    # Give the fetch task some initial time to get a page and save it into a variable
    time.sleep(3)
    # Used for providing incremental wait
    sleep_time = 5
    number_of_run = 1
    redis_client = get_redis_client()
    table_exists = task_instance.xcom_pull(
        task_ids='check-if-table-exists')[0][0]
    if table_exists and table_exists != 'None':
        table_name = f'{table_name}_copy'
    try:
        target_db_conn = PostgresHook(postgres_conn_id=target_db).get_conn()
        target_db_cursor = target_db_conn.cursor()

        while True:
            var_name = get_available_page_var(redis_client, run_fetch_task_id)
            if var_name:
                logging.info(f'Got the unprocessed var_name {var_name}')
                sleep_time = 5
                var_name = var_name.decode('utf-8')
                try:
                    record_subset = json.loads(Variable.get(var_name))
                except KeyError:
                    logging.info(
                        f'Var {var_name} no longer exists! It was processed by another worker. Moving on.'
                    )
                    continue

                escaped_record_subset = []
                for record in record_subset:
                    escaped_record = {}
                    for key, value in record.items():
                        if value and value != 'None':
                            escaped_record[key] = sql.Literal(value).as_string(
                                target_db_conn)
                        else:
                            escaped_record[key] = sql.Literal(None).as_string(
                                target_db_conn)
                    escaped_record_subset.append(escaped_record)

                exec_sql = Template(insert_into_sql).render(
                    table_name=sql.Identifier(table_name).as_string(
                        target_db_conn),
                    field_mapping=field_mapping,
                    record_subset=escaped_record_subset,
                )
                target_db_cursor.execute(exec_sql)
                logging.info(f'Deleting the var_name {var_name}')
                Variable.delete(var_name)
            else:
                # Check if the fetch task completed successfully; if it did, break out of the loop
                # and commit the transaction because there are no more pages to process. If it
                # failed, raise an exception so that the transaction will be rolled back.
                state = task_instance.xcom_pull(key='state',
                                                task_ids=run_fetch_task_id)
                logging.info(f'Checking the state of fetcher task {state}')
                if state is False:
                    raise Exception('Fetcher task failed!')
                elif state is True:
                    logging.info(
                        'Fetcher task completed successfully and there are no more variables to process.'
                    )
                    break
                else:
                    logging.info(
                        f'Sleeping for {sleep_time} seconds for the fetcher task to catch up')
                    sleep_time = sleep_time * number_of_run
                    time.sleep(sleep_time)
                    number_of_run += 1

        target_db_conn.commit()
        task_instance.xcom_push(key='state', value=True)

    # TODO: Gotta Catch'm all
    except Exception as e:
        logging.error(f'Exception: {e}')
        target_db_conn.rollback()
        task_instance.xcom_push(key='state', value=False)
        raise

    finally:
        if target_db_conn:
            target_db_cursor.close()
            target_db_conn.close()
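
The waiting logic above grows the sleep multiplicatively while the fetcher is behind and resets it as soon as a page appears. A distilled sketch of just that backoff; fetcher_done, next_unprocessed_page, and process are illustrative stand-ins.

import time

sleep_time = 5
number_of_run = 1
while not fetcher_done():               # hypothetical completion check
    item = next_unprocessed_page()      # hypothetical: a page name or None
    if item:
        sleep_time = 5                  # reset the backoff once work arrives
        process(item)                   # hypothetical processing step
    else:
        sleep_time = sleep_time * number_of_run  # sleeps grow: 5, 10, 30, 120, ...
        time.sleep(sleep_time)
        number_of_run += 1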
Code Example #17
    def execute(self, context):  # NoQA
        # When doing 'airflow test' there is a context['params']
        # For full dag runs, there is dag_run["conf"]
        dag_run = context["dag_run"]
        if dag_run is None:
            params = context["params"]
        else:
            params = dag_run.conf or {}
        self.log.debug("PARAMS: %s", params)
        max_records = params.get("max_records", self.max_records)
        cursor_pos = params.get("cursor_pos", Variable.get(f"{self.db_table_name}.cursor_pos", 0))
        batch_size = params.get("batch_size", self.batch_size)
        with TemporaryDirectory() as temp_dir:
            tmp_file = Path(temp_dir) / "out.ndjson"
            http = HttpParamsHook(http_conn_id=self.http_conn_id, method="POST")

            self.log.info("Calling GOB graphql endpoint")

            # we know the schema, can be an input param (schema_def_from_url function)
            # We use the ndjson importer from schematools, give it a tmp tablename
            pg_hook = PostgresHook()
            schema_def = schema_def_from_url(SCHEMA_URL, self.dataset)
            importer = NDJSONImporter(schema_def, pg_hook.get_sqlalchemy_engine(), logger=self.log)

            importer.generate_db_objects(
                table_name=self.schema,
                db_table_name=f"{self.db_table_name}_new",
                ind_tables=True,
                ind_extra_index=False,
            )
            # For GOB content, the cursor value is exactly the same as
            # the record index. If this were not true, the cursor would
            # need to be obtained from the last content record.
            records_loaded = 0

            with self.graphql_query_path.open() as gql_file:
                query = gql_file.read()

            # Sometimes the GOB-API fails with a 500 error, caught by Airflow.
            # We retry several times.
            while True:

                force_refresh_token = False
                for i in range(3):
                    try:
                        request_start_time = time.time()
                        headers = self._fetch_headers(force_refresh=force_refresh_token)
                        response = http.run(
                            self.endpoint,
                            self._fetch_params(),
                            json.dumps(
                                dict(
                                    query=self.add_batch_params_to_query(
                                        query, cursor_pos, batch_size
                                    )
                                )
                            ),
                            headers=headers,
                            extra_options={"stream": True},
                        )
                    except AirflowException:
                        self.log.exception("Cannot reach %s", self.endpoint)
                        force_refresh_token = True
                        time.sleep(1)
                    else:
                        break
                else:
                    # Save cursor_pos in a variable
                    Variable.set(f"{self.db_table_name}.cursor_pos", cursor_pos)
                    raise AirflowException("All retries on GOB-API have failed.")

                records_loaded += batch_size
                # An empty result returns a single newline and a Content-Length header.
                # If records are available, there is no Content-Length header.
                if int(response.headers.get("Content-Length", "2")) < 2:
                    break
                # When content is encoded (gzip etc.) we need this:
                # response.raw.read = functools.partial(response.raw.read, decode_content=True)
                try:
                    with tmp_file.open("wb") as wf:
                        shutil.copyfileobj(response.raw, wf, self.copy_bufsize)

                    request_end_time = time.time()
                    self.log.info(
                        "GOB-API request took %s seconds, cursor: %s",
                        request_end_time - request_start_time,
                        cursor_pos,
                    )
                    last_record = importer.load_file(tmp_file)
                except (SQLAlchemyError, ProtocolError, UnicodeDecodeError) as e:
                    # Save last imported file for further inspection
                    shutil.copy(
                        tmp_file,
                        f"/tmp/{self.db_table_name}-{datetime.now().isoformat()}.ndjson",
                    )
                    Variable.set(f"{self.db_table_name}.cursor_pos", cursor_pos)
                    raise AirflowException("A database error has occurred.") from e

                self.log.info(
                    "Loading db took %s seconds",
                    time.time() - request_end_time,
                )
                if last_record is None or (
                    max_records is not None and records_loaded >= max_records
                ):
                    break
                cursor_pos = last_record["cursor"]

        # On successful completion, remove the cursor_pos variable
        Variable.delete(f"{self.db_table_name}.cursor_pos")
Code Example #18
def delete_variable(*, variable_key: str) -> Response:
    """Delete variable"""
    if Variable.delete(variable_key) == 0:
        raise NotFound("Variable not found")
    return Response(status=HTTPStatus.NO_CONTENT)
Code Example #19
def variables_delete(args):
    """Deletes variable by a given name"""
    Variable.delete(args.key)
Code Example #20
def variables_delete(args):
    """Deletes variable by a given name"""
    Variable.delete(args.key)
    print(f"Variable {args.key} deleted")
Code Example #21
    @classmethod
    def tearDownClass(cls):
        """
        Delete the airflow variable.
        """
        print("Removing airflow variable...")
        Variable.delete("rb_status_" + cls.report_form_sample.report_title.data)
Code Example #22
    def rollback_variables(index):
        for i in range(index):
            key = f'{run_fetch_task_id}{i}'
            Variable.delete(key)
            redis_client.delete(key)
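
rollback_variables undoes a partial fetch by deleting both the Airflow Variable and the Redis key for each page index. A hedged sketch of the write side it reverses, inferred from Code Examples #16 and #22; the marker value and payload are assumptions.

import json

def save_page(index, records):
    key = f'{run_fetch_task_id}{index}'
    Variable.set(key, json.dumps(records))  # page payload, read back in Example #16
    redis_client.set(key, 'unprocessed')    # marker picked up by get_available_page_var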