Code Example #1
File: test_variable.py Project: alrolorojas/airflow
    def test_var_with_encryption_rotate_fernet_key(self, mock_get):
        """
        Tests rotating encrypted variables.
        """
        key1 = Fernet.generate_key()
        key2 = Fernet.generate_key()

        mock_get.return_value = key1.decode()
        Variable.set('key', 'value')
        session = settings.Session()
        test_var = session.query(Variable).filter(Variable.key == 'key').one()
        self.assertTrue(test_var.is_encrypted)
        self.assertEqual(test_var.val, 'value')
        self.assertEqual(Fernet(key1).decrypt(test_var._val.encode()), b'value')

        # Test decrypt of old value with new key
        mock_get.return_value = ','.join([key2.decode(), key1.decode()])
        crypto._fernet = None
        self.assertEqual(test_var.val, 'value')

        # Test decrypt of new value with new key
        test_var.rotate_fernet_key()
        self.assertTrue(test_var.is_encrypted)
        self.assertEqual(test_var.val, 'value')
        self.assertEqual(Fernet(key2).decrypt(test_var._val.encode()), b'value')
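The rotation mechanics exercised by this test can be reproduced with cryptography's MultiFernet alone; a minimal sketch, assuming only the cryptography package and no Airflow session or Variable model:

from cryptography.fernet import Fernet, MultiFernet

key1 = Fernet(Fernet.generate_key())
key2 = Fernet(Fernet.generate_key())

token = key1.encrypt(b'value')          # encrypted with the old key

# After rotation the new key is listed first and the old key is kept for reads.
rotated = MultiFernet([key2, key1])
assert rotated.decrypt(token) == b'value'

# rotate() re-encrypts the payload with the primary (new) key.
new_token = rotated.rotate(token)
assert key2.decrypt(new_token) == b'value'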
Code Example #2
 def __call__(self):
     # read the configuration out of S3
     with open(local_cfg_file_location) as data_file:    
         config = json.load(data_file)
     # setup the environment variables
     for var in config['variables']:
         Variable.set(var['name'], var['value'])
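The JSON it loads is assumed to carry a top-level 'variables' list of name/value pairs; a hypothetical shape of the file at local_cfg_file_location:

# Hypothetical configuration payload read by __call__ above.
config = {
    "variables": [
        {"name": "environment", "value": "staging"},
        {"name": "s3_bucket", "value": "my-config-bucket"},
    ]
}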
Code Example #3
File: cli.py Project: cjquinon/incubator-airflow
def variables(args):

    if args.get:
        try:
            var = Variable.get(args.get,
                               deserialize_json=args.json,
                               default_var=args.default)
            print(var)
        except ValueError as e:
            print(e)
    if args.delete:
        session = settings.Session()
        session.query(Variable).filter_by(key=args.delete).delete()
        session.commit()
        session.close()
    if args.set:
        Variable.set(args.set[0], args.set[1])
    # Work around 'import' as a reserved keyword
    imp = getattr(args, 'import')
    if imp:
        if os.path.exists(imp):
            import_helper(imp)
        else:
            print("Missing variables file.")
    if args.export:
        export_helper(args.export)
    if not (args.set or args.get or imp or args.export or args.delete):
        # list all variables
        session = settings.Session()
        vars = session.query(Variable)
        msg = "\n".join(var.key for var in vars)
        print(msg)
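The function above expects an argparse-style namespace; a hypothetical wiring that produces compatible attributes (the real Airflow CLI builds its parser elsewhere, so the flag names here are illustrative, and calling variables() needs a configured metadata database):

import argparse

parser = argparse.ArgumentParser(prog='airflow variables')
parser.add_argument('-g', '--get', help='key to look up and print')
parser.add_argument('-j', '--json', action='store_true', help='deserialize the fetched value as JSON')
parser.add_argument('-d', '--default', default=None, help='value returned when the key is missing')
parser.add_argument('-s', '--set', nargs=2, metavar=('KEY', 'VALUE'), help='store KEY VALUE')
parser.add_argument('-x', '--delete', help='key to delete')
parser.add_argument('-i', '--import', help='path of a JSON file to import')  # dest becomes "import"
parser.add_argument('-e', '--export', help='path of a JSON file to export all variables to')

variables(parser.parse_args(['--set', 'foo', 'bar']))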
Code Example #4
File: test_variable.py Project: alrolorojas/airflow
 def test_variable_no_encryption(self, mock_get):
     """
     Test variables without encryption
     """
     mock_get.return_value = ''
     Variable.set('key', 'value')
     session = settings.Session()
     test_var = session.query(Variable).filter(Variable.key == 'key').one()
     self.assertFalse(test_var.is_encrypted)
     self.assertEqual(test_var.val, 'value')
Code Example #5
File: test_variable.py Project: alrolorojas/airflow
 def test_variable_with_encryption(self, mock_get):
     """
     Test variables with encryption
     """
     mock_get.return_value = Fernet.generate_key().decode()
     Variable.set('key', 'value')
     session = settings.Session()
     test_var = session.query(Variable).filter(Variable.key == 'key').one()
     self.assertTrue(test_var.is_encrypted)
     self.assertEqual(test_var.val, 'value')
Code Example #6
File: cli.py Project: yogesh2021/airflow
def variables(args):
    if args.get:
        try:
            var = Variable.get(args.get, deserialize_json=args.json, default_var=args.default)
            print(var)
        except ValueError as e:
            print(e)
    if args.set:
        Variable.set(args.set[0], args.set[1])
    if not args.set and not args.get:
        # list all variables
        session = settings.Session()
        vars = session.query(Variable)
        msg = "\n".join(var.key for var in vars)
        print(msg)
Code Example #7
  def MonthlyGenerateTestArgs(**kwargs):

    """Loads the configuration that will be used for this Iteration."""
    conf = kwargs['dag_run'].conf
    if conf is None:
      conf = dict()

    # If version is overridden then we should use it, otherwise we use its
    # default or monthly value.
    version = conf.get('VERSION') or istio_common_dag.GetVariableOrDefault('monthly-version', None)
    if not version or version == 'INVALID':
      raise ValueError('version needs to be provided')
    Variable.set('monthly-version', 'INVALID')

    # GCS_MONTHLY_STAGE_PATH is of the form 'prerelease/{version}'
    gcs_path = 'prerelease/%s' % (version)

    branch = conf.get('BRANCH') or istio_common_dag.GetVariableOrDefault('monthly-branch', None)
    if not branch or branch == 'INVALID':
      raise ValueError('branch needs to be provided')
    Variable.set('monthly-branch', 'INVALID')
    mfest_commit = conf.get('MFEST_COMMIT') or branch

    default_conf = environment_config.GetDefaultAirflowConfig(
        branch=branch,
        gcs_path=gcs_path,
        mfest_commit=mfest_commit,
        pipeline_type='monthly',
        verify_consistency='true',
        version=version)

    config_settings = dict()
    for name in default_conf.keys():
      config_settings[name] = conf.get(name) or default_conf[name]

    # These are the extra params that are passed to the dags for monthly release
    monthly_conf = dict()
    monthly_conf['DOCKER_HUB'              ] = 'istio'
    monthly_conf['GCR_RELEASE_DEST'        ] = 'istio-io'
    monthly_conf['GCS_GITHUB_PATH'         ] = 'istio-secrets/github.txt.enc'
    monthly_conf['RELEASE_PROJECT_ID'      ] = 'istio-io'
    # GCS_MONTHLY_RELEASE_PATH is of the form  'istio-release/releases/{version}'
    monthly_conf['GCS_MONTHLY_RELEASE_PATH'] = 'istio-release/releases/%s' % (version)
    for name in monthly_conf.keys():
      config_settings[name] = conf.get(name) or monthly_conf[name]

    testMonthlyConfigSettings(config_settings)
    return config_settings
Code Example #8
File: tuto2.py Project: VViles/airflow_test
def set_sms(*args, **context):
    group = Variable.get('group')
    if group == 'night_shift':
        context['task_instance'].xcom_push('recipient', '0011223344')
        context['task_instance'].xcom_push('message', 'night airflow message')
    else:
        context['task_instance'].xcom_push('recipient', '0011223344')
        context['task_instance'].xcom_push('message', 'day airflow message')
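Values pushed this way are read back downstream with xcom_pull; a minimal hypothetical consumer, assuming the task that runs set_sms has task_id 'set_sms':

def send_sms(*args, **context):
    ti = context['task_instance']
    recipient = ti.xcom_pull(task_ids='set_sms', key='recipient')
    message = ti.xcom_pull(task_ids='set_sms', key='message')
    # The real task would hand these to an SMS gateway; printing keeps the sketch runnable.
    print('Sending %r to %s' % (message, recipient))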
Code Example #9
File: cli.py Project: cjquinon/incubator-airflow
def import_helper(filepath):
    with open(filepath, 'r') as varfile:
        var = varfile.read()

    try:
        d = json.loads(var)
    except Exception:
        print("Invalid variables file.")
    else:
        try:
            n = 0
            for k, v in d.items():
                if isinstance(v, dict):
                    Variable.set(k, v, serialize_json=True)
                else:
                    Variable.set(k, v)
                n += 1
        except Exception:
            pass
        finally:
            print("{} of {} variables successfully updated.".format(n, len(d)))
Code Example #10
    def failed(self, context):
        self.conf = context["conf"]
        self.task = context["task"]
        self.execution_date = context["execution_date"]
        self.dag = context["dag"]
        self.errors = SlackAPIPostOperator(
            task_id='task_failed',
            token=Variable.get('slack_token'),
            channel='C1SRU2R33',
            text="Your DAG has encountered an error, please follow the link to view the log details:  " + "http://localhost:8080/admin/airflow/log?" + "task_id=" + task.task_id + "&" +\
            "execution_date=" + execution_date.isoformat() + "&" + "dag_id=" + dag.dag_id,
            dag=pipeline
        )

        self.errors.execute()
Code Example #11
def ReportDailySuccessful(task_instance, **kwargs):
  date = kwargs['execution_date']
  latest_run = float(Variable.get('latest_daily_timestamp'))

  timestamp = time.mktime(date.timetuple())
  logging.info('Current run\'s timestamp: %s \n'
               'latest_daily\'s timestamp: %s', timestamp, latest_run)
  if timestamp >= latest_run:
    Variable.set('latest_daily_timestamp', timestamp)
    run_sha = task_instance.xcom_pull(task_ids='get_git_commit')
    latest_version = GetSettingPython(task_instance, 'VERSION')
    logging.info('setting latest green daily to: %s', run_sha)
    Variable.set('latest_sha', run_sha)
    Variable.set('latest_daily', latest_version)
    logging.info('latest_sha set to %s', run_sha)
Code Example #12
def ReportMonthlySuccessful(task_instance, **kwargs):
  del kwargs
  version = istio_common_dag.GetSettingPython(task_instance, 'VERSION')
  try:
    match = re.match(r'([0-9])\.([0-9])\.([0-9]).*', version)
    major, minor, patch = match.group(1), match.group(2), match.group(3)
    Variable.set('major_version', major)
    Variable.set('released_version_minor', minor)
    Variable.set('released_version_patch', patch)
  except (IndexError, AttributeError):
    logging.error('Could not extract released version information. \n'
                  'Please set airflow version Variables manually. '
                  'After you are done hit Mark Success.')
Code Example #13
        def wrapped(context):
            """ping error in slack on failure and provide link to the log"""
            conf = context["conf"]
            task = context["task"]
            execution_date = context["execution_date"]
            dag = context["dag"]
            base_url = conf.get('webserver', 'base_url')

            # Get the ID of the target slack channel
            slack_token = Variable.get(slack_token_variable)
            sc = SlackClient(slack_token)

            response = sc.api_call('channels.list')
            for channel in response['channels']:
                if channel['name'].lower() == channel_name.lower():
                    break
            else:
                raise AirflowException('No channel named {} found.'.format(channel_name))

            # Construct a slack operator to send the message off.
            notifier = cls(
                task_id='task_failed',
                token=slack_token,
                channel=channel['id'],
                text=(
                    "Your DAG has encountered an error, please follow the link "
                    "to view the log details:  "
                    "{}/admin/airflow/log?"
                        "task_id={}&"
                        "dag_id={}&"
                        "execution_date={}"
                    ).format(base_url, task.task_id, dag.dag_id,
                             execution_date.isoformat()),
                dag=dag,
            )
            notifier.execute()
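A callback like wrapped() is normally returned by a small factory and attached through default_args or an operator's on_failure_callback; a hedged sketch, where slack_failure_alert is a hypothetical factory that closes over channel_name and slack_token_variable and returns the wrapped(context) closure above:

default_args = {
    'owner': 'airflow',
    # slack_failure_alert is assumed to return the wrapped(context) function above.
    'on_failure_callback': slack_failure_alert(
        channel_name='data-alerts',
        slack_token_variable='slack_token',
    ),
}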
Code Example #14
from airflow.hooks.postgres_hook import PostgresHook
from airflow.exceptions import AirflowException
from airflow.models import Variable
from psycopg2 import ProgrammingError

from mongoengine import connect
from sqlalchemy.exc import OperationalError

Logger = logging.getLogger(__name__)

DEFAULT_HEADER = {"Content-Type": "application/json"}

KERNEL_HOOK_BASE = HttpHook(http_conn_id="kernel_conn", method="GET")

try:
    HTTP_HOOK_RUN_RETRIES = int(Variable.get("HTTP_HOOK_RUN_RETRIES", 5))
except (OperationalError, ValueError):
    HTTP_HOOK_RUN_RETRIES = 5


@retry(
    wait=wait_exponential(),
    stop=stop_after_attempt(HTTP_HOOK_RUN_RETRIES),
    retry=retry_if_exception_type(
        (requests.ConnectionError, requests.Timeout)),
)
def http_hook_run(api_hook,
                  method,
                  endpoint,
                  data=None,
                  headers=DEFAULT_HEADER,
Code Example #15
# This DAG is configured to print the date and sleep for 5 seconds.
# However, it is configured to fail (see the expect_failure bash_command)
# and send an e-mail to your specified email on task failure.

from airflow import DAG
from airflow.models import Variable
from airflow.operators.bash_operator import BashOperator
from datetime import datetime, timedelta

YESTERDAY = datetime.combine(
    datetime.today() - timedelta(days=1), datetime.min.time())

default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': YESTERDAY,
    'email': [ Variable.get('email') ],
    'email_on_failure': True,
    'email_on_retry': False,
    'retries': 0,
}

with DAG('hello_world_email_bonus', default_args=default_args) as dag:
  t1 = BashOperator(task_id='print_date', bash_command='date', dag=dag)
  t2 = BashOperator(task_id='expect_failure', bash_command='exit 1', dag=dag)
  t1 >> t2
Code Example #16
    'start_date': START_DATE,
    'email': ['*****@*****.**'],
    'email_on_failure': True,
    'email_on_retry': False,
    'retries': 3,
    'wait_for_downstream': True,
    'retry_delay': timedelta(minutes=3),
}

dag_name = "DAG_NAME"
dag = DAG(dag_name,
          catchup=False,
          default_args=default_args,
          schedule_interval="SCHEDULE_INTERVAL")

dag_config = Variable.get('VARIABLES_NAME', deserialize_json=True)
klarna_username = dag_config['klarna_username']
klarna_password = dag_config['klarna_password']
aws_conn = BaseHook.get_connection("aws_conn")
s3_bucket = dag_config['s3_bucket']
datasource_type = dag_config['datasource_type']
account_name = dag_config['account_name']
file_path = dag_config['file_path']
s3_folder_path = "/{datasource_type}/{transaction_type}/{account_name}/{year}/{month}/{day}/"

file_key_regex = 'klarna_transaction_{0}/'.format(account_name)

print("Starting job rn ")

timestamp = int(time.time() * 1000.0)
Code Example #17
File: tuto2.py Project: VViles/airflow_test
def set_call(*args, **context):
    group = Variable.get('group')
    if group == 'night_shift':
        context['task_instance'].xcom_push(key='recipient', value='0011223344')
    else:
        context['task_instance'].xcom_push(key='recipient', value='0011223344')
Code Example #18
try:
    from airflow.utils import timezone  # airflow.utils.timezone is available from v1.10 onwards
    now = timezone.utcnow
except ImportError:
    now = datetime.utcnow

DAG_ID = os.path.basename(__file__).replace(".pyc", "").replace(
    ".py", "")  # airflow-db-cleanup
START_DATE = airflow.utils.dates.days_ago(1)
SCHEDULE_INTERVAL = "@daily"  # How often to Run. @daily - Once a day at Midnight (UTC)
DAG_OWNER_NAME = "operations"  # Who is listed as the owner of this DAG in the Airflow Web Server
ALERT_EMAIL_ADDRESSES = [
]  # List of email address to send email alerts to if this job fails
DEFAULT_MAX_DB_ENTRY_AGE_IN_DAYS = int(
    Variable.get("airflow_db_cleanup__max_db_entry_age_in_days", 30)
)  # Length to retain the log files if not already provided in the conf. If this is set to 30, the job will remove those files that are 30 days old or older.
ENABLE_DELETE = True  # Whether the job should delete the db entries or not. Included if you want to temporarily avoid deleting the db entries.
DATABASE_OBJECTS = [  # List of all the objects that will be deleted. Comment out the DB objects you want to skip.
    {
        "airflow_db_model": DagRun,
        "age_check_column": DagRun.execution_date,
        "keep_last": True,
        "keep_last_filters": [DagRun.external_trigger == 0],
        "keep_last_group_by": DagRun.dag_id
    },
    {
        "airflow_db_model": TaskInstance,
        "age_check_column": TaskInstance.execution_date,
        "keep_last": False,
        "keep_last_filters": None,
Code Example #19
"""
#
# Pools setup either by UI or backend
#
"""

from airflow.models import Variable

# Task Pools
WAGL_TASK_POOL = Variable.get("wagl_task_pool", "wagl_processing_pool")
Code Example #20
from airflow import DAG
from airflow.operators.python_operator import PythonOperator
from airflow.contrib.operators.file_to_gcs import FileToGoogleCloudStorageOperator
from airflow.models import Variable
from datetime import datetime, timedelta

import logging

home = expanduser("~")

SAVE_PATH = '{0}/gcs/data/powerschool/'.format(home)
BASE_URL = 'https://kippchicago.powerschool.com'
MAXPAGESIZE = 1000
STATE_FILEPATH = '{0}/gcs/data/'.format(home) + 'state.json'

client_id = Variable.get("ps_client_id")
client_secret = Variable.get("ps_client_secret")
credentials_concat = '{0}:{1}'.format(client_id, client_secret)
CREDENTIALS_ENCODED = base64.b64encode(credentials_concat.encode('utf-8'))

endpoints = [{
    "table_name":
    "attendance",
    "query_expression":
    "yearid==29",
    "projection":
    "dcid,id,attendance_codeid,calendar_dayid,schoolid,yearid,studentid,ccid,periodid,parent_attendanceid,att_mode_code,att_comment,att_interval,prog_crse_type,lock_teacher_yn,lock_reporting_yn,transaction_type,total_minutes,att_date,ada_value_code,ada_value_time,adm_value,programid,att_flags,whomodifiedid,whomodifiedtype,ip_address"
}]

#################################
# Airflow specific DAG set up ##
Code Example #21
    {"id": "CAPACITY_TYPE", "type": "text"},
    {"id": "CAPACITY_ACTUAL_BED", "type": "text"},
    {"id": "CAPACITY_FUNDING_BED", "type": "text"},
    {"id": "OCCUPIED_BEDS", "type": "text"},
    {"id": "UNOCCUPIED_BEDS", "type": "text"},
    {"id": "UNAVAILABLE_BEDS", "type": "text"},
    {"id": "CAPACITY_ACTUAL_ROOM", "type": "text"},
    {"id": "CAPACITY_FUNDING_ROOM", "type": "text"},
    {"id": "OCCUPIED_ROOMS", "type": "text"},
    {"id": "UNOCCUPIED_ROOMS", "type": "text"},
    {"id": "UNAVAILABLE_ROOMS", "type": "text"},
    {"id": "OCCUPANCY_RATE_BEDS", "type": "text"},
    {"id": "OCCUPANCY_RATE_ROOMS", "type": "text"},
]

TMP_DIR = Variable.get("tmp_dir")

ckan_creds = Variable.get("ckan_credentials_secret", deserialize_json=True)
active_env = Variable.get("active_env")
ckan_address = ckan_creds[active_env]["address"]
ckan_apikey = ckan_creds[active_env]["apikey"]

ckan = ckanapi.RemoteCKAN(apikey=ckan_apikey, address=ckan_address)

with DAG(
    PACKAGE_NAME,
    default_args=airflow_utils.get_default_args(
        {
            "on_failure_callback": task_failure_slack_alert,
            "start_date": datetime(2020, 11, 24, 13, 35, 0),
            "retries": 0,
Code Example #22
    'email_on_retry': False,
}

dag = airflow.DAG('dim_oride_passenger_whitelist_base',
                  schedule_interval="20 03 * * *",
                  default_args=args,
                  )

##----------------------------------------- Variables ---------------------------------------##

db_name="oride_dw"
table_name = "dim_oride_passenger_whitelist_base"

##----------------------------------------- Dependencies ---------------------------------------##
# Get variables
code_map=eval(Variable.get("sys_flag"))

# Check whether the storage is ufile (CDH environment)
if code_map["id"].lower()=="ufile":
    # Depends on the previous day's partition
    ods_sqoop_base_data_user_whitelist_df_task = UFileSensor(
        task_id='ods_sqoop_base_data_user_whitelist_df_task',
        filepath='{hdfs_path_str}/dt={pt}/_SUCCESS'.format(
            hdfs_path_str="oride_dw_sqoop/oride_data/data_user_whitelist",
            pt='{{ds}}'
        ),
        bucket_name='opay-datalake',
        poke_interval=60,  # while the dependency is unmet, check its status once a minute
        dag=dag
    )
    # Path
Code Example #23
from psaw import PushshiftAPI
'''
-----------------------------
Setup
=============================
'''

load_dotenv()

logger = logging.getLogger(__name__)

# load variables
# first try to get airflow variables and then default to os variables
try:
    from airflow.models import Variable
    reddit_client_id = Variable.get(
        'REDDIT_CLIENT_ID', default_var=os.environ.get('REDDIT_CLIENT_ID'))
    reddit_client_secret = Variable.get(
        'REDDIT_CLIENT_SECRET',
        default_var=os.environ.get('REDDIT_CLIENT_SECRET'))
    reddit_user_agent = Variable.get(
        'REDDIT_USER_AGENT', default_var=os.environ.get('REDDIT_USER_AGENT'))
    google_storage_bucket_name = Variable.get(
        'GOOGLE_STORAGE_BUCKET_NAME',
        default_var=os.environ.get('GOOGLE_STORAGE_BUCKET_NAME'))
except:
    reddit_client_id = os.environ.get('REDDIT_CLIENT_ID')
    reddit_client_secret = os.environ.get('REDDIT_CLIENT_SECRET')
    reddit_user_agent = os.environ.get('REDDIT_USER_AGENT')
    google_storage_bucket_name = os.environ.get('GOOGLE_STORAGE_BUCKET_NAME')
'''
-----------------------------
Code Example #24
File: weather_dag.py Project: michalpytlos/DE_p6
# DAG
dag = DAG('weather_dag',
          default_args=default_args,
          description='ETL for weather data',
          schedule_interval='0 0 1 1 *'
          )

# Tasks
start_operator = DummyOperator(task_id='Begin_execution',  dag=dag)

stage_weather_stations_to_redshift = CopyFixedWidthRedshiftOperator(
    task_id='Stage_stations_table',
    dag=dag,
    redshift_conn_id='redshift',
    table='staging_weather_stations',
    s3_bucket=Variable.get('s3_weather_bucket'),
    s3_key=Variable.get('s3_weather_stations_key'),
    arn=Variable.get('iam_role_arn'),
    fixedwidth_spec=Variable.get('s3_weather_stations_spec'),
    maxerror=Variable.get('s3_weather_stations_maxerror'),
    load_delete=True
)

load_weather_stations_table = LoadDimensionOperator(
    task_id='Load_weather_stations_table',
    dag=dag,
    redshift_conn_id='redshift',
    table='weather_stations',
    insert_query=SqlQueries.weather_stations_table_insert,
    delete_load=True
)
Code Example #25
from airflow import DAG, macros
import airflow
from airflow.hooks.postgres_hook import PostgresHook
from airflow.operators.python_operator import PythonOperator
from airflow.operators.python_operator import BranchPythonOperator
from operators.dwh_operators import PostgresOperatorWithTemplatedParams
from airflow.operators.dummy_operator import DummyOperator
from datetime import datetime, timedelta
import os.path, shutil
from os import path
import gzip, json, csv, psycopg2, glob, logging
from airflow.models import Variable



landing_zone = Variable.get("landing_zone")
archive_dir = Variable.get("archive_dir")
tmpl_search_path = Variable.get("sql_path")
output_dir = Variable.get("output_dir")
config_dir = Variable.get("config_dir")
pattern = r".json.gz"
output_filenames = [os.path.join(output_dir, 'metadata.csv'), os.path.join(output_dir, 'reviews.csv')]
db_config_file = os.path.join(config_dir, 'db_config.json')

logger = logging.getLogger()
# Change format of handler for the logger
logger.handlers[0].setFormatter(logging.Formatter('%(message)s'))

# read dwh details into dictionary
with open(db_config_file, 'r') as f:
    db_config = json.load(f)
Code Example #26
import json

default_args = {
    'owner': 'datagap'
}

volume = k8s.V1Volume(
    name='data-volume',
    persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(claim_name='shared-data-volume')
)

volume_mount = k8s.V1VolumeMount(
    name='data-volume', mount_path='/shared-data', sub_path=None, read_only=False
)

login_url = Variable.get("ntreis_login_url")
rets_type = Variable.get("ntreis_rets_type")
search_limit = Variable.get("ntreis_search_limit")
password = Variable.get("ntreis_password")
user_agent = Variable.get("ntreis_user_agent")
working_dir = Variable.get("ntreis_working_dir")
server_version = Variable.get("ntreis_server_version")
username = Variable.get("ntreis_username")

activeTemplateUrl = Variable.get("ntreis_prop_active_index_url")
activePropDataSource = Variable.get("ntreis_prop_active_datasource")
activeQuery1 = Variable.get("ntreis_prop_active_query_1")
activeQuery2 = Variable.get("ntreis_prop_active_query_2")

def downloadTemplate(templateUrl):
  request = urllib.request.urlopen(templateUrl)
Code Example #27
def pull_pricing_data(**kwargs):
    Variable.set('prices_dag_state', 'COMPLETED')
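The flag written here is presumably polled elsewhere; a hypothetical callable for a PythonSensor-style check could read it back like this:

from airflow.models import Variable

def prices_dag_completed(**kwargs):
    # True once pull_pricing_data (above) has flipped the flag.
    return Variable.get('prices_dag_state', default_var='PENDING') == 'COMPLETED'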
Code Example #28
            'TABLE_STOCK_PRICES': config['App']['TABLE_STOCK_PRICES'],
            'URL': "http://app.quotemedia.com/quotetools/getHistoryDownload.csv?&webmasterId=501&startDay={sd}&startMonth={sm}&startYear={sy}&endDay={ed}&endMonth={em}&endYear={ey}&isRanged=true&symbol={sym}",
        }
    },
    dag=dag
)

quality_check_task = PythonOperator(
    task_id='Quality_check',
    python_callable=submit_spark_job_from_file,
    op_kwargs={
        'commonpath': '{}/dags/etl/common.py'.format(airflow_dir),
        'helperspath': '{}/dags/etl/helpers.py'.format(airflow_dir),
        'filepath': '{}/dags/etl/pull_prices_quality.py'.format(airflow_dir), 
        'args': {
            'AWS_ACCESS_KEY_ID': config['AWS']['AWS_ACCESS_KEY_ID'],
            'AWS_SECRET_ACCESS_KEY': config['AWS']['AWS_SECRET_ACCESS_KEY'],
            'YESTERDAY_DATE': '{{yesterday_ds}}',
            'STOCKS': STOCKS,
            'DB_HOST': config['App']['DB_HOST'],
            'TABLE_STOCK_INFO_NASDAQ': config['App']['TABLE_STOCK_INFO_NASDAQ'],
            'TABLE_STOCK_INFO_NYSE': config['App']['TABLE_STOCK_INFO_NYSE'],
            'TABLE_STOCK_PRICES': config['App']['TABLE_STOCK_PRICES'],
        },
        'on_complete': lambda *args: Variable.set('prices_dag_state', 'COMPLETED')
    },
    dag=dag
)

wait_for_fresh_run_task >> wait_for_short_interests_task >> \
pull_stock_symbols_task >> pull_pricing_data_task >> quality_check_task
Code Example #29
"""Airflow DAG to run data process for capstone project"""

import datetime
from airflow import DAG
from airflow.models import Variable
from airflow.operators.dummy_operator import DummyOperator
from airflow.contrib.operators.gcp_function_operator import GcfFunctionDeployOperator
from airflow.providers.google.cloud.operators.mlengine import MLEngineCreateVersionOperator, \
    MLEngineSetDefaultVersionOperator
from airflow.utils.dates import days_ago
from datetime import timedelta

SPARK_CLUSTER = Variable.get('SPARK_CLUSTER')
MASTER_MACHINE_TYPE = Variable.get('MASTER_MACHINE_TYPE')
WORKER_MACHINE_TYPE = Variable.get('WORKER_MACHINE_TYPE')
NUMBER_OF_WORKERS = Variable.get('NUMBER_OF_WORKERS')
PROJECT = Variable.get('PROJECT')
ZONE = Variable.get('ZONE')
REGION = Variable.get('REGION')
START_DATE = datetime.datetime(2020, 1, 1)
RAW_DATA = Variable.get('RAW_DATA')
TOKENIZED_DATA_DIR = Variable.get('TOKENIZED_DATA_DIR')
THRESHOLD = Variable.get('THRESHOLD')
MODEL_NAME = Variable.get('MODEL_NAME')
MODEL_DIR = Variable.get('MODEL_DIR')
VERSION_NAME = Variable.get('VERSION_NAME')
DOMAIN_LOOKUP_PATH = Variable.get('DOMAIN_LOOKUP_PATH')

INTERVAL = None

default_args = {
Code Example #30
from airflow.models import Variable
import pandas as pd
import sqlalchemy as db
import configparser
import logging

# variables
SOURCE_MYSQL_HOST = Variable.get('SOURCE_MYSQL_HOST')
SOURCE_MYSQL_PORT = Variable.get('SOURCE_MYSQL_PORT')
SOURCE_MYSQL_USER = Variable.get('SOURCE_MYSQL_USER')
SOURCE_MYSQL_PASSWORD = Variable.get('SOURCE_MYSQL_PASSWORD')
SOURCE_MYSQL_ROOT_PASSWORD = Variable.get('SOURCE_MYSQL_ROOT_PASSWORD')
SOURCE_MYSQL_DATABASE = Variable.get('SOURCE_MYSQL_DATABASE')

DW_MYSQL_HOST = Variable.get('DW_MYSQL_HOST')
DW_MYSQL_PORT = Variable.get('DW_MYSQL_PORT')
DW_MYSQL_USER = Variable.get('DW_MYSQL_USER')
DW_MYSQL_PASSWORD = Variable.get('DW_MYSQL_PASSWORD')
DW_MYSQL_ROOT_PASSWORD = Variable.get('DW_MYSQL_ROOT_PASSWORD')
DW_MYSQL_DATABASE = Variable.get('DW_MYSQL_DATABASE')

# Database connection URI
db_conn_url = "mysql+pymysql://{}:{}@{}:{}/{}".format(SOURCE_MYSQL_USER,
                                                      SOURCE_MYSQL_PASSWORD,
                                                      SOURCE_MYSQL_HOST,
                                                      SOURCE_MYSQL_PORT,
                                                      SOURCE_MYSQL_DATABASE)
db_engine = db.create_engine(db_conn_url)

# Data warehouse connection URI
dw_conn_url = "mysql+pymysql://{}:{}@{}:{}/{}".format(DW_MYSQL_USER,
Code Example #31
# fetch AWS_KEY and AWS_SECRET
config = configparser.ConfigParser()
config.read('/Users/sathishkaliamoorthy/.aws/credentials')
AWS_KEY = config['default']['aws_access_key_id']
AWS_SECRET = config['default']['aws_secret_access_key']
#AWS_KEY = os.environ.get('AWS_KEY')
#AWS_SECRET = os.environ.get('AWS_SECRET')

# inserting new connection object programmatically
aws_conn = Connection(conn_id='aws_credentials',
                      conn_type='Amazon Web Services',
                      login=AWS_KEY,
                      password=AWS_SECRET)

redshift_conn = Connection(
    conn_id='aws_redshift',
    conn_type='Postgres',
    host='my-sparkify-dwh.cdc1pzfmi32k.us-west-2.redshift.amazonaws.com',
    port=5439,
    schema='sparkify_dwh',
    login='******',
    password='******')
session = settings.Session()
session.add(aws_conn)
session.add(redshift_conn)
session.commit()

# inserting variables programmatically
Variable.set("s3_bucket", "udacity-dend")
Code Example #32
    'retry_delay': timedelta(minutes=2),
    'email': ['*****@*****.**'],
    'email_on_failure': True,
    'email_on_retry': False,
}

dag = airflow.DAG('dwd_ocredit_phones_contract_di',
                  schedule_interval="30 00 * * *",
                  default_args=args,
                  )

##----------------------------------------- Variables ---------------------------------------##
db_name = "ocredit_phones_dw"
table_name = "dwd_ocredit_phones_contract_di"
hdfs_path = "oss://opay-datalake/opay/ocredit_phones_dw/" + table_name
config = eval(Variable.get("ocredit_time_zone_config"))
time_zone = config['NG']['time_zone']
##----------------------------------------- Dependencies ---------------------------------------##

### Check the partition dependency for the current hour
ods_binlog_base_t_contract_all_hi_check_task = OssSensor(
    task_id='ods_binlog_base_t_contract_all_hi_check_task',
    bucket_key='{hdfs_path_str}/dt={pt}/hour=23/_SUCCESS'.format(
        hdfs_path_str="ocredit_phones_all_hi/ods_binlog_base_t_contract_all_hi",
        pt='{{ds}}',
        hour='{{ execution_date.strftime("%H") }}'
    ),
    bucket_name='opay-datalake',
    poke_interval=60,  # while the dependency is unmet, check its status once a minute
    dag=dag
)
Code Example #33
class EmrClusterController:

    connection = boto3.client(
        'emr',
        region_name='us-east-2',
        aws_access_key_id=Variable.get("aws_access_key_id"),
        aws_secret_access_key=Variable.get("aws_secret_access_key"),
    )

    @staticmethod
    def get_vpc_id():
        ec2 = boto3.client(
            'ec2',
            region_name="us-east-2",
            aws_access_key_id=Variable.get("aws_access_key_id"),
            aws_secret_access_key=Variable.get("aws_secret_access_key"))
        response = ec2.describe_vpcs()
        vpc_id = response.get('Vpcs', [{}])[0].get('VpcId', '')
        return vpc_id

    @staticmethod
    def get_security_group_id(group_name):
        vpc_id = EmrClusterController.get_vpc_id()
        print(f"VPC FOUND: {vpc_id}")
        ec2 = boto3.client(
            'ec2',
            region_name="us-east-2",
            aws_access_key_id=Variable.get("aws_access_key_id"),
            aws_secret_access_key=Variable.get("aws_secret_access_key"))
        response = ec2.describe_security_groups()
        result = ''
        for group in response['SecurityGroups']:
            print(f"Security Group Found: \n\t\t {group}")
            if group["GroupName"] == group_name:
                result = group["GroupId"]
                break
        return result

    @staticmethod
    def get_subnet_id():
        ec2 = boto3.client(
            "ec2",
            region_name="us-east-2",
            aws_access_key_id=Variable.get("aws_access_key_id"),
            aws_secret_access_key=Variable.get("aws_secret_access_key"))
        response = ec2.describe_subnets()
        print(response)
        result = ''
        for subnet in response["Subnets"]:
            print(subnet)
            if subnet["AvailabilityZone"] == 'us-east-2a':
                result = subnet['SubnetId']
        return result

    @staticmethod
    def create_cluster_job_execution(name,
                                     release,
                                     master_node_type="m4.large",
                                     slave_node_type="m4.large",
                                     master_instance_count=1,
                                     slave_instance_count=1):
        emr_master_security_group_id = EmrClusterController.get_security_group_id(
            'security-group-master')
        emr_slave_security_group_id = EmrClusterController.get_security_group_id(
            'security-group-slave')
        public_subnet = EmrClusterController.get_subnet_id()
        response = EmrClusterController.connection.run_job_flow(
            Name=name,
            ReleaseLabel=release,
            LogUri='s3://art-emr-data-mesh-logging-bucket',
            Applications=[{
                'Name': 'hadoop'
            }, {
                'Name': 'spark'
            }, {
                'Name': 'hive'
            }, {
                'Name': 'livy'
            }, {
                'Name': 'zeppelin'
            }],
            Instances={
                'InstanceGroups': [{
                    'Name':
                    "Master nodes",
                    'Market':
                    'ON_DEMAND',
                    'InstanceRole':
                    'MASTER',
                    'InstanceType':
                    master_node_type,
                    'InstanceCount':
                    master_instance_count,
                    'Configurations': [{
                        "Classification": "livy-conf",
                        "Properties": {
                            "livy.server.session.timeout-check": "true",
                            "livy.server.session.timeout": "2h",
                            "livy.server.yarn.app-lookup-timeout": "120s"
                        }
                    }]
                }, {
                    'Name': "Slave nodes",
                    'Market': 'ON_DEMAND',
                    'InstanceRole': 'CORE',
                    'InstanceType': slave_node_type,
                    'InstanceCount': slave_instance_count,
                }],
                'Ec2KeyName':
                'EMR-key-pair',
                'EmrManagedMasterSecurityGroup':
                emr_master_security_group_id,
                'EmrManagedSlaveSecurityGroup':
                emr_slave_security_group_id,
                'KeepJobFlowAliveWhenNoSteps':
                True,
                'TerminationProtected':
                False,
                'Ec2SubnetId':
                public_subnet,
            },
            VisibleToAllUsers=True,
            ServiceRole='iam_emr_service_role',
            JobFlowRole='emr-instance-profile',
        )

        print('cluster created with the step...', response['JobFlowId'])
        return response["JobFlowId"]

    @staticmethod
    def add_job_step(cluster_id,
                     name,
                     jar,
                     args,
                     main_class="",
                     action="CONTINUE"):
        response = EmrClusterController.connection.add_job_flow_steps(
            JobFlowId=cluster_id,
            Steps=[
                {
                    'Name': name,
                    'ActionOnFailure': action,
                    'HadoopJarStep': {
                        'Jar': jar,
                        'MainClass': main_class,
                        'Args': args
                    }
                },
            ])
        print(f"Add Job Response: {response}")
        return response['StepIds'][0]

    @staticmethod
    def list_job_steps(cluster_id):
        response = EmrClusterController.connection.list_steps(
            ClusterId=cluster_id,
            StepStates=[
                'PENDING', 'CANCEL_PENDING', 'RUNNING', 'COMPLETED',
                'CANCELLED', 'FAILED', 'INTERRUPTED'
            ])
        for step in response['Steps']:
            print(step['Name'])
            print(step['Id'])

    @staticmethod
    def get_step_status(cluster_id, step_id):
        response = EmrClusterController.connection.describe_step(
            ClusterId=cluster_id, StepId=step_id)
        return response['Step']['Status']

    @staticmethod
    def get_cluster_dns(cluster_id):
        response = EmrClusterController.connection.describe_cluster(
            ClusterId=cluster_id)
        return response['Cluster']['MasterPublicDnsName']

    @staticmethod
    def get_public_ip(cluster_id):
        instances = EmrClusterController.connection.list_instances(
            ClusterId=cluster_id, InstanceGroupTypes=['MASTER'])
        return instances['Instances'][0]['PublicIpAddress']

    @staticmethod
    def wait_for_cluster_creation(cluster_id):
        EmrClusterController.connection.get_waiter('cluster_running').wait(
            ClusterId=cluster_id)

    @staticmethod
    def wait_for_step_completion(cluster_id, step_id):
        EmrClusterController.connection.get_waiter('step_complete').wait(
            ClusterId=cluster_id, StepId=step_id)

    @staticmethod
    def terminate_cluster(cluster_id):
        EmrClusterController.connection.terminate_job_flows(
            JobFlowIds=[cluster_id])

    @staticmethod
    def create_spark_session(master_dns, kind='spark'):
        host = "http://" + master_dns + ":8998"
        conf = {
            "hive.metastore.client.factory.class":
            "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory"
        }
        data = {"kind": kind, "conf": conf}
        headers = {"Content-Type": "application/json"}
        response = requests.post(host + "/sessions",
                                 data=json.dumps(data),
                                 headers=headers)
        print(
            f"\n\nCREATE SPARK SESSION RESPONSE STATUS CODE: {response.status_code}"
        )
        logging.info(response.json())
        print("\n\nCREATED LIVY SPARK SESSION SUCCESSFULLY")
        return response.headers

    @staticmethod
    def wait_for_idle_session(master_dns, response_headers):
        status = ""
        host = "http://" + master_dns + ":8998"
        session_url = host + response_headers['location']
        print(f"\n\nWAIT FOR IDLE SESSION: Session URL: {session_url}")
        while status != "idle":
            time.sleep(3)
            status_response = requests.get(
                session_url, headers={"Content-Type": "application/json"})
            status = status_response.json()['state']
            logging.info('Session status: ' + status)
        print("\n\nLIVY SPARK SESSION IS IDLE")
        return session_url

    @staticmethod
    def submit_statement(session_url, statement_path):
        statements_url = session_url + "/statements"
        with open(statement_path, 'r') as f:
            code = f.read()
        data = {"code": code}
        response = requests.post(statements_url,
                                 data=json.dumps(data),
                                 headers={"Content-Type": "application/json"})
        logging.info(response.json())
        print("\n\nSUBMITTED LIVY STATEMENT SUCCESSFULLY")
        return response

    @staticmethod
    def track_statement_progress(master_dns, response_headers):
        statement_status = ""
        host = "http://" + master_dns + ":8998"
        session_url = host + response_headers['location'].split(
            '/statements', 1)[0]
        # Poll the status of the submitted scala code
        while statement_status != "available":
            statement_url = host + response_headers['location']
            statement_response = requests.get(
                statement_url, headers={"Content-Type": "application/json"})
            statement_status = statement_response.json()['state']
            logging.info('Statement status: ' + statement_status)
            lines = requests.get(session_url + '/log',
                                 headers={
                                     'Content-Type': 'application/json'
                                 }).json()['log']
            for line in lines:
                logging.info(line)

            if 'progress' in statement_response.json():
                logging.info('Progress: ' +
                             str(statement_response.json()['progress']))
            time.sleep(10)
        final_statement_status = statement_response.json()['output']['status']
        if final_statement_status == 'error':
            logging.info('Statement exception: ' +
                         statement_response.json()['output']['evalue'])
            for trace in statement_response.json()['output']['traceback']:
                logging.info(trace)
            raise ValueError('Final Statement Status: ' +
                             final_statement_status)
        print(statement_response.json())
        logging.info('Final Statement Status: ' + final_statement_status)

    @staticmethod
    def kill_spark_session(session_url):
        requests.delete(session_url,
                        headers={"Content-Type": "application/json"})
        print("\n\nLIVY SESSION WAS DELETED SUCCESSFULLY")

    @staticmethod
    def create_livy_batch(master_dns, path, class_name):
        data = {"file": path, "className": class_name}
        host = "http://" + master_dns + ":8998"
        headers = {"Content-Type": "application/json"}
        response = requests.post(host + "/batches",
                                 data=json.dumps(data),
                                 headers=headers)
        print(
            f"\n\nCREATE SPARK SESSION RESPONSE STATUS CODE: {response.status_code}"
        )
        logging.info(response.json())
        print("\n\nCREATED LIVY SPARK SESSION SUCCESSFULLY")
        return response.json()["id"]

    @staticmethod
    def track_livy_batch_job(master_dns, batch_id):
        statement_status = ""
        host = "http://" + master_dns + ":8998"
        session_url = host + "/batches/" + str(batch_id)
        # Poll the status of the submitted scala code
        while statement_status != "available":
            statement_url = host + "/state"
            statement_response = requests.get(
                statement_url, headers={"Content-Type": "application/json"})
            statement_status = statement_response.json()['state']
            logging.info('Statement status: ' + statement_status)
            lines = requests.get(session_url + '/log',
                                 headers={
                                     'Content-Type': 'application/json'
                                 }).json()['log']
            for line in lines:
                logging.info(line)

            if 'progress' in statement_response.json():
                logging.info('Progress: ' +
                             str(statement_response.json()['progress']))
            time.sleep(10)
        final_statement_status = statement_response.json()['output']['status']
        if final_statement_status == 'error':
            logging.info('Statement exception: ' +
                         statement_response.json()['output']['evalue'])
            for trace in statement_response.json()['output']['traceback']:
                logging.info(trace)
            raise ValueError('Final Statement Status: ' +
                             final_statement_status)
        print(statement_response.json())
        logging.info('Final Statement Status: ' + final_statement_status)

    @staticmethod
    def terminate_batch_job(master_dns, batch_id):
        requests.delete(f"http://{master_dns}:8998/batches/{batch_id}",
                        headers={"Content-Type": "application/json"})
        print("\n\nLIVY SESSION WAS DELETED SUCCESSFULLY")
Code Example #34
HOST_LABELBOX_OUTPUT_FOLDER = HOST_LABELBOX_FOLDER + "/output/"

BASE_AIRFLOW_FOLDER = "/usr/local/airflow/"
AIRFLOW_DATA_FOLDER = os.path.join(BASE_AIRFLOW_FOLDER, "data")

AIRFLOW_CURRENT_DAG_FOLDER = os.path.dirname(os.path.realpath(__file__))
AIRFLOW_LABELBOX_FOLDER = os.path.join(AIRFLOW_DATA_FOLDER, "labelbox")
AIRFLOW_LABELBOX_OUTPUT_FOLDER = os.path.join(AIRFLOW_LABELBOX_FOLDER,
                                              "output")
AIRFLOW_TF_RECORD_FOLDER = os.path.join(AIRFLOW_DATA_FOLDER, "tfrecord")

labelbox_api_url = BaseHook.get_connection("labelbox").host
labelbox_api_key = BaseHook.get_connection("labelbox").password
slack_webhook_token = BaseHook.get_connection("slack").password

ontology_front = json.loads(Variable.get("ontology_front"))
ontology_bottom = json.loads(Variable.get("ontology_bottom"))

# TODO: Document this since it could be an issue
export_project_name = Variable.get("labelbox_export_project_list").split(",")

front_cam_object_list = [tool["name"] for tool in ontology_front["tools"]]

bottom_cam_object_list = [tool["name"] for tool in ontology_bottom["tools"]]

default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "start_date": datetime(2019, 1, 24),
    "email": ["*****@*****.**"],
    "email_on_failure": False,
Code Example #35
from airflow import DAG
from airflow.models import Variable
from airflow.operators.python import PythonOperator
from sustainment.update_data_quality_scores import dqs_logic  # noqa: E402
from utils import airflow_utils, ckan_utils

job_settings = {
    "description": "Calculates DQ scores across the catalogue",
    "schedule": "0 0 * * 1,4",
    "start_date": datetime(2020, 11, 10, 5, 0, 0),
}

JOB_FILE = Path(os.path.abspath(__file__))
JOB_NAME = JOB_FILE.name[:-3]

ACTIVE_ENV = Variable.get("active_env")
CKAN_CREDS = Variable.get("ckan_credentials_secret", deserialize_json=True)
CKAN = ckanapi.RemoteCKAN(**CKAN_CREDS[ACTIVE_ENV])

METADATA_FIELDS = ["collection_method", "limitations", "topics", "owner_email"]

TIME_MAP = {
    "daily": 1,
    "weekly": 7,
    "monthly": 30,
    "quarterly": 52 * 7 / 4,
    "semi-annually": 52 * 7 / 2,
    "annually": 365,
}

RESOURCE_MODEL = "scoring-models"
Code Example #36
import airflow
from airflow.models import Variable
from airflow.hooks.base_hook import BaseHook
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator
from airflow.contrib.operators.s3_to_sftp_operator import S3ToSFTPOperator
from airflow.contrib.operators.ssh_operator import SSHOperator
from cob_datapipeline.task_slack_posts import notes_slackpostonsuccess
"""
INIT SYSTEMWIDE VARIABLES

check for existence of systemwide variables shared across tasks that can be
initialized here if not found (i.e. if this is a new installation) & defaults exist
"""

AIRFLOW_HOME = Variable.get("AIRFLOW_HOME")

# cob_index Indexer Library Variables
GIT_BRANCH = Variable.get("CATALOG_QA_BRANCH")
LATEST_RELEASE = Variable.get("CATALOG_QA_LATEST_RELEASE")

# Get S3 data bucket variables
AIRFLOW_S3 = BaseHook.get_connection("AIRFLOW_S3")
AIRFLOW_DATA_BUCKET = Variable.get("AIRFLOW_DATA_BUCKET")

# CREATE DAG
DEFAULT_ARGS = {
    "owner": "cob",
    "depends_on_past": False,
    "email_on_failure": False,
    "email_on_retry": False,
Code Example #37
File: project-workflow.py Project: dkyos/dev-samples
from airflow import DAG
from airflow.operators.bash_operator import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator
from datetime import datetime,timedelta
from airflow.models import Variable

SRC=Variable.get("SRC")
#SRC='./'
COUNTRY=Variable.get("COUNTRY")
#COUNTRY='PL'

dag = DAG('project-workflow',description='Project Workflow DAG',
        schedule_interval = '*/5 0 * * *',
        start_date=datetime(2017,7,1),
        catchup=False)

xlsx_to_csv_task = BashOperator(
        task_id='xlsx_to_csv',
        bash_command='"$src"/test.sh "$country" 2nd_param_xlsx',
        env={'src': SRC, 'country': COUNTRY},
        dag=dag)

merge_command = SRC + '/test.sh ' + COUNTRY + ' 2nd_param_merge'
merge_task = BashOperator(
        task_id='merge',
        bash_command=merge_command ,
        dag=dag)

my_templated_command = """
{{ params.src }}/test.sh {{ params.country}} 2nd_param_cleansing
Code Example #38
def setEndTime(**kwargs):
    Variable.set('last_execution_date_succes', kwargs['execution_date'])
Code Example #39
File: tuto2.py Project: VViles/airflow_test
def set_group(*args, **kwargs):
    if datetime.now().hour > 18 or datetime.now().hour < 8:
        Variable.set('group', 'night_shift')
    else:
        Variable.set('group', 'day_shift')
Code Example #40
File: istio_common_dag.py Project: veggiemonk/istio
 def AirflowGetVariableOrBaseCase(var, base):
   try:
     return Variable.get(var)
   except KeyError:
     return base
Code Example #41
File: tuto2.py Project: VViles/airflow_test
def set_mail(*args, **context):
    group = Variable.get('group')
    if group == 'night_shift':
        context['task_instance'].xcom_push(key='recipient', value='*****@*****.**')
    else:
        context['task_instance'].xcom_push(key='recipient', value='*****@*****.**')
Code Example #42
        if cluster is None or len(
                cluster) == 0 or 'clusterName' not in cluster:
            return 'create_cluster'
        else:
            return 'run_job'

    start = BranchPythonOperator(
        task_id='start',
        provide_context=True,
        python_callable=ensure_cluster_exists,
    )

    create_cluster = DataprocClusterCreateOperator(
        task_id='create_cluster',
        cluster_name=CLUSTER_NAME,
        project_id=Variable.get('project'),
        num_workers=2,
        master_disk_size=50,
        worker_disk_size=50,
        image_version='preview',
        internal_ip_only=True,
        tags=['dataproc'],
        labels={'dataproc-cluster': CLUSTER_NAME},
        zone=Variable.get('zone'),
        subnetwork_uri='projects/{}/region/{}/subnetworks/{}'.format(
            Variable.get('project'), Variable.get('region'),
            Variable.get('subnet')),
        service_account=Variable.get('serviceAccount'),
    )

    class PatchedDataProcSparkOperator(DataProcSparkOperator):
Code Example #43
def GetVariableOrDefault(var, default):
  try:
    return Variable.get(var)
  except KeyError:
    return default
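The same fallback can be written without the try/except, since Variable.get accepts a default_var that is returned when the key is missing; an equivalent sketch:

from airflow.models import Variable

def get_variable_or_default(var, default):
    # default_var is returned instead of raising KeyError for unknown keys.
    return Variable.get(var, default_var=default)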
Code Example #44
File: core.py Project: moritzpein/airflow
 def test_variable_set_get_round_trip(self):
     Variable.set("tested_var_set_id", "Monday morning breakfast")
     assert "Monday morning breakfast" == Variable.get("tested_var_set_id")
Code Example #45
def process_host_mk():
    path = Variable.get("hosts_mk_path")
    hosts = {}
    site_mapping = {}
    all_site_mapping = []
    all_list = []
    device_dict = {}
    start = 0
    tech_wise_device_site_mapping = {}
    try:
        text_file = open(path, "r")

    except IOError:
        logging.error("File Name not correct")
        return "notify"
    except Exception:
        logging.error(
            "Please check the HostMK file exists on the path provided ")
        return "notify"

    lines = text_file.readlines()
    host_ip_mapping = get_host_ip_mapping()
    for line in lines:
        if "all_hosts" in line:
            start = 1

        if start == 1:
            hosts["hostname"] = line.split("|")[0]
            hosts["device_type"] = line.split("|")[1]
            site_mapping["hostname"] = line.split("|")[0].strip().strip("'")

            site_mapping['site'] = line.split("site:")[1].split("|")[0].strip()

            site_mapping['device_type'] = line.split("|")[1].strip()

            all_list.append(hosts.copy())
            all_site_mapping.append(site_mapping.copy())
            if ']\n' in line:
                start = 0
                all_list[0]['hostname'] = all_list[0].get("hostname").strip(
                    'all_hosts += [\'')
                all_site_mapping[0]['hostname'] = all_site_mapping[0].get(
                    "hostname").strip('all_hosts += [\'')
                break
    print "LEN of ALL LIST is %s" % (len(all_list))
    if len(all_list) > 1:
        for device in all_list:
            device_dict[device.get("hostname").strip().strip(
                "'")] = device.get("device_type").strip()
        Variable.set("hostmk.dict", str(device_dict))

        for site_mapping in all_site_mapping:
            if site_mapping.get(
                    'device_type') not in tech_wise_device_site_mapping.keys():
                tech_wise_device_site_mapping[site_mapping.get(
                    'device_type')] = {
                        site_mapping.get('site'): [{
                            "hostname":
                            site_mapping.get('hostname'),
                            "ip_address":
                            host_ip_mapping.get(site_mapping.get('hostname'))
                        }]
                    }
            else:
                if site_mapping.get(
                        'site') not in tech_wise_device_site_mapping.get(
                            site_mapping.get('device_type')).keys():
                    tech_wise_device_site_mapping.get(
                        site_mapping.get('device_type'))[site_mapping.get(
                            'site')] = [{
                                "hostname":
                                site_mapping.get('hostname'),
                                "ip_address":
                                host_ip_mapping.get(
                                    site_mapping.get('hostname'))
                            }]
                else:
                    tech_wise_device_site_mapping.get(
                        site_mapping.get('device_type')).get(
                            site_mapping.get('site')).append({
                                "hostname":
                                site_mapping.get('hostname'),
                                "ip_address":
                                host_ip_mapping.get(
                                    site_mapping.get('hostname'))
                            })

        Variable.set("hostmk.dict.site_mapping",
                     str(tech_wise_device_site_mapping))
        count = 0
        for x in tech_wise_device_site_mapping:
            for y in tech_wise_device_site_mapping.get(x):
                count = count + len(tech_wise_device_site_mapping.get(x).get(y))

        print("COUNT : %s" % (count))
        return 0
    else:
        return -4
Code Example #46
File: core.py Project: moritzpein/airflow
 def test_get_non_existing_var_should_return_default(self):
     default_value = "some default val"
     assert default_value == Variable.get("thisIdDoesNotExist",
                                          default_var=default_value)
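Dicts round-trip the same way when both sides opt into JSON handling; a short hedged addition in the style of these tests:

 def test_variable_set_get_round_trip_json(self):
     value = {"a": 17, "b": 47}
     Variable.set("tested_var_set_id_json", value, serialize_json=True)
     assert value == Variable.get("tested_var_set_id_json", deserialize_json=True)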
Code Example #47
# -*- coding: utf-8 -*-
import airflow
from airflow.models import Variable

from etl.operators import MsSqlOperator, MsSqlDataImportOperator, SqlcmdFilesOperator

args = {
    'owner': 'airflow',
    'start_date': airflow.utils.dates.days_ago(7),
    'provide_context': True
}

dag = airflow.DAG('otfn_with_link_server',
                  schedule_interval='@daily',
                  default_args=args,
                  template_searchpath=Variable.get('sql_path'),
                  max_active_runs=1)

t0 = MsSqlOperator(task_id='clear_timesheet_data',
                   sql='DELETE FROM timesheet',
                   mssql_conn_id='mssql_datalake',
                   dag=dag)

t1 = MsSqlDataImportOperator(task_id='import_timesheet_data',
                             table_name='timesheet',
                             data_file=Variable.get('data_file_path') +
                             '/timesheet.csv',
                             mssql_conn_id='mssql_datalake',
                             dag=dag)

t2 = SqlcmdFilesOperator(task_id='otfn_with_link_servers',
Code Example #48
)

from importscripts.convert_bedrijveninvesteringszones_data import convert_biz_data

from sql.bedrijveninvesteringszones import CREATE_TABLE, UPDATE_TABLE

from postgres_check_operator import (
    PostgresMultiCheckOperator,
    COUNT_CHECK,
    GEO_CHECK,
)

from postgres_permissions_operator import PostgresPermissionsOperator

dag_id = "bedrijveninvesteringszones"
variables = Variable.get(dag_id, deserialize_json=True)
tmp_dir = f"{SHARED_DIR}/{dag_id}"
files_to_download = variables["files_to_download"]
total_checks = []
count_checks = []
geo_checks = []
check_name = {}


# needed to put quotes on elements in geotypes for SQL_CHECK_GEO
def quote(instr):
    return f"'{instr}'"


with DAG(
        dag_id,
Code Example #49
File: istio_common_dag.py Project: veggiemonk/istio
  def GenerateTestArgs(**kwargs):
    """Loads the configuration that will be used for this Iteration."""
    conf = kwargs['dag_run'].conf
    if conf is None:
      conf = dict()

    """ Airflow gives the execution date when the job is supposed to be run,
        however we dont backfill and only need to run one build therefore use
        the current date instead of the date that is passed in """
#    date = kwargs['execution_date']
    date = datetime.datetime.now()

    timestamp = time.mktime(date.timetuple())

    # Monthly releases started in Nov 2017 with 0.3.0, so minor is # of months
    # from Aug 2017.
    minor_version = (date.year - 2017) * 12 + (date.month - 1) - 7
    major_version = AirflowGetVariableOrBaseCase('major_version', 0)
    # This code gets information about the latest released version so we know
    # what version number to use for this round.
    r_minor = int(AirflowGetVariableOrBaseCase('released_version_minor', 0))
    r_patch = int(AirflowGetVariableOrBaseCase('released_version_patch', 0))
    # If we have already released a monthly for this month, then bump
    # the patch number for the remainder of the month.
    if r_minor == minor_version:
      patch = r_patch + 1
    else:
      patch = 0
    # If the version is overridden then we should use it; otherwise we use its
    # default or monthly value.
    version = conf.get('VERSION')
    if monthly and not version:
      version = '{}.{}.{}'.format(major_version, minor_version, patch)

    default_conf = environment_config.get_airflow_config(
        version,
        timestamp,
        major=major_version,
        minor=minor_version,
        patch=patch,
        date=date.strftime('%Y%m%d'),
        rc=date.strftime('%H-%M'))
    config_settings = dict(VERSION=default_conf['VERSION'])
    config_settings_name = [
        'PROJECT_ID',
        'MFEST_URL',
        'MFEST_FILE',
        'GCS_STAGING_BUCKET',
        'SVC_ACCT',
        'GITHUB_ORG',
        'GITHUB_REPO',
        'GCS_GITHUB_PATH',
        'TOKEN_FILE',
        'GCR_STAGING_DEST',
        'GCR_RELEASE_DEST',
        'GCS_MONTHLY_RELEASE_PATH',
        'DOCKER_HUB',
        'GCS_BUILD_BUCKET',
        'RELEASE_PROJECT_ID',
    ]

    for name in config_settings_name:
      config_settings[name] = conf.get(name) or default_conf[name]

    if monthly:
      config_settings['MFEST_COMMIT'] = conf.get(
          'MFEST_COMMIT') or Variable.get('latest_sha')
      gcs_path = conf.get('GCS_MONTHLY_STAGE_PATH')
      if not gcs_path:
        gcs_path = default_conf['GCS_MONTHLY_STAGE_PATH']
    else:
      config_settings['MFEST_COMMIT'] = conf.get(
          'MFEST_COMMIT') or default_conf['MFEST_COMMIT']
      gcs_path = conf.get('GCS_DAILY_PATH') or default_conf['GCS_DAILY_PATH']

    config_settings['GCS_STAGING_PATH'] = gcs_path
    config_settings['GCS_BUILD_PATH'] = '{}/{}'.format(
        config_settings['GCS_BUILD_BUCKET'], gcs_path)
    config_settings['GCS_RELEASE_TOOLS_PATH'] = '{}/release-tools/{}'.format(
        config_settings['GCS_BUILD_BUCKET'], gcs_path)
    config_settings['GCS_FULL_STAGING_PATH'] = '{}/{}'.format(
        config_settings['GCS_STAGING_BUCKET'], gcs_path)
    config_settings['ISTIO_REPO'] = 'https://github.com/{}/{}.git'.format(
        config_settings['GITHUB_ORG'], config_settings['GITHUB_REPO'])

    return config_settings
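The function above relies on an AirflowGetVariableOrBaseCase helper and a monthly flag from its enclosing scope, neither of which is shown in this excerpt. A minimal sketch of such a fallback helper, assuming it simply wraps Variable.get with a default (the real istio implementation may differ):

from airflow.models import Variable

def AirflowGetVariableOrBaseCase(name, base_value):
    # Return the Airflow Variable if it is set, otherwise the base value.
    return Variable.get(name, default_var=base_value)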
コード例 #50
0
from airflow import DAG
from airflow.operators.docker_operator import DockerOperator
from airflow.operators.python_operator import PythonOperator
from datetime import timedelta
from airflow.utils.dates import days_ago
import json

from airflow.models import Variable

brewed_from = Variable.get("brewed_from", "10-2011")
brewed_until = Variable.get("brewed_until", "09-2013")

default_args = {
    'owner': 'Muhammad Faizan Khan',
    'description': 'Use of the DockerOperator',
    'depends_on_past': False,
    'start_date': days_ago(2),
    'email_on_failure': False,
    'email_on_retry': False,
    'retries': 5,
    'retry_delay': timedelta(minutes=5)
}

dag = DAG(
    'BeerByInterval',
    default_args=default_args,
    description='A simple tutorial DAG',
    schedule_interval='@daily',
)

t1 = DockerOperator(task_id='DockerOperator',
コード例 #51
0
import os
from datetime import timedelta

from airflow import DAG
from airflow.models import Variable
from airflow.operators.dummy import DummyOperator
from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator
from airflow.utils.dates import days_ago

# ************** AIRFLOW VARIABLES **************
sns_topic = Variable.get("sns_topic")
# ***********************************************

DAG_ID = os.path.basename(__file__).replace(".py", "")

DEFAULT_ARGS = {
    "owner": "airflow",
    "depends_on_past": False,
    "email": ["*****@*****.**"],
    "email_on_failure": False,
    "email_on_retry": False,
}
"""
 # Sends a test message from Amazon SNS from Amazon MWAA.
"""

with DAG(
        dag_id=DAG_ID,
        description="Send a test message to an SNS Topic",
        default_args=DEFAULT_ARGS,
        dagrun_timeout=timedelta(hours=2),
コード例 #52
0
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator, PythonOperator

from create_project_into_labelbox import create_project_into_labelbox
from utils import file_ops, slack

BASE_AIRFLOW_FOLDER = "/usr/local/airflow/"
AIRFLOW_DATA_FOLDER = os.path.join(BASE_AIRFLOW_FOLDER, "data")
AIRFLOW_IMAGE_FOLDER = os.path.join(AIRFLOW_DATA_FOLDER, "images")
AIRFLOW_JSON_FOLDER = os.path.join(AIRFLOW_DATA_FOLDER, "json")

labelbox_api_url = BaseHook.get_connection("labelbox").host
labelbox_api_key = BaseHook.get_connection("labelbox").password
slack_webhook_token = BaseHook.get_connection("slack").password

bucket_name = Variable.get("bucket_name")
ontology_front = Variable.get("ontology_front")
ontology_bottom = Variable.get("ontology_bottom")

json_files = file_ops.get_files_in_directory(AIRFLOW_JSON_FOLDER, "*.json")

default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "start_date": datetime(2019, 1, 24),
    "email": ["*****@*****.**"],
    "email_on_failure": False,
    "email_on_retry": False,
    "on_failure_callback": slack.task_fail_slack_alert,
    "retries": 0,
}
コード例 #53
0
ファイル: core.py プロジェクト: moritzpein/airflow
 def test_variable_set_get_round_trip_json(self):
     value = {"a": 17, "b": 47}
     Variable.set("tested_var_set_id", value, serialize_json=True)
     assert value == Variable.get("tested_var_set_id", deserialize_json=True)
コード例 #54
0
from _slack_operators import *

default_args = {
    'owner': 'airflow',
    'description': 'Gathers MDS data and uploads to Knack',
    'depends_on_past': False,
    'start_date': datetime(2018, 1, 1),
    'email_on_failure': False,
    'email_on_retry': False,
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
    'on_failure_callback': task_fail_slack_alert,
}

environment_vars = Variable.get("atd_mds_monthly_report_production",
                                deserialize_json=True)

with DAG(
        "atd_mds",
        default_args=default_args,
        schedule_interval="40 7 3 * *",
        catchup=False,
        tags=["production", "mds"],
) as dag:
    #
    # Task: run_python
    # Description: Gathers data from Hasura to generate report data and uploads to Knack.
    #
    run_python = BashOperator(
        task_id="run_python_script",
        bash_command="python3 ~/dags/python_scripts/atd_mds_monthly_report.py",
コード例 #55
0
ファイル: core.py プロジェクト: moritzpein/airflow
 def test_get_non_existing_var_should_not_deserialize_json_default(self):
     default_value = "}{ this is a non JSON default }{"
     assert default_value == Variable.get("thisIdDoesNotExist",
                                          default_var=default_value,
                                          deserialize_json=True)
コード例 #56
0
ファイル: varialbes.py プロジェクト: Aleks-Ya/yaal_examples
"""
Working with Variables.
Doc: https://airflow.apache.org/concepts.html?highlight=variable#variables
"""

from airflow import DAG
from airflow.operators.bash_operator import BashOperator
from datetime import datetime
from airflow.models import Variable

text_variable = Variable.get("user")

# Getting a JSON var doesn't work ()
#json_variable = Variable.get("json_var", deserialize_json = True)

default_args = {
    'start_date': datetime.now()
}

dag = DAG('varialbes', default_args=default_args)

text_message = f"echo 'The user variable is {text_variable}'"
#json_message = f"echo 'The json_var={json_variable}'"
t1 = BashOperator(
    task_id='text_variable',
    bash_command=text_message,
    dag=dag)
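
For reference, a minimal sketch of the JSON round trip that the commented-out lines above are aiming at, assuming a 'json_var' Variable is first stored with serialize_json=True (the value below is hypothetical):

# Seed the JSON Variable once, then read it back as a Python object.
Variable.set("json_var", {"user": "alice", "retries": 3}, serialize_json=True)
json_variable = Variable.get("json_var", deserialize_json=True)

json_message = f"echo 'The json_var={json_variable}'"
t2 = BashOperator(
    task_id='json_variable',
    bash_command=json_message,
    dag=dag)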