Code Example #1
    def execute(self, context):
        # Specifying a service account file allows the user to use non-default
        # authentication for creating a Kubernetes Pod. This is done by setting the
        # environment variable `GOOGLE_APPLICATION_CREDENTIALS` that gcloud looks at.
        key_file = None

        # If gcp_conn_id is not specified gcloud will use the default
        # service account credentials.
        if self.gcp_conn_id:
            from airflow.hooks.base_hook import BaseHook
            # extras is a deserialized json object
            extras = BaseHook.get_connection(self.gcp_conn_id).extra_dejson
            # key_file is only set if a JSON key file was created from a JSON string in
            # the web UI; otherwise it is None
            key_file = self._set_env_from_extras(extras=extras)

        # Write config to a temp file and set the environment variable to point to it.
        # This is to avoid race conditions of reading/writing a single file
        with tempfile.NamedTemporaryFile() as conf_file:
            os.environ[KUBE_CONFIG_ENV_VAR] = conf_file.name
            # Attempt to get/update credentials
            # We call gcloud directly instead of using google-cloud-python api
            # because there is no way to write kubernetes config to a file, which is
            # required by KubernetesPodOperator.
            # The gcloud command looks at the env variable `KUBECONFIG` for where to save
            # the kubernetes config file.
            subprocess.check_call(
                ["gcloud", "container", "clusters", "get-credentials",
                 self.cluster_name,
                 "--zone", self.location,
                 "--project", self.project_id])

            # Since the key file was created with mkstemp(), closing it will delete it from
            # the file system, so it cannot be accessed once we no longer need it
            if key_file:
                key_file.close()

            # Tell `KubernetesPodOperator` where the config file is located
            self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
            return super().execute(context)
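The comments above reference a _set_env_from_extras helper. Below is a minimal sketch of what such a helper could look like, not the operator's actual implementation; it assumes the service account JSON is stored under the extra__google_cloud_platform__keyfile_dict key (the extras key that also appears in Code Example #38).

import os
import tempfile

def _set_env_from_extras(self, extras):
    # Hypothetical sketch (would live on the operator class): write the keyfile JSON
    # from the connection extras to a temporary file and point
    # GOOGLE_APPLICATION_CREDENTIALS at it. Returns the open temp file so the caller
    # can close it (and thereby delete it) when it is no longer needed.
    key_json = extras.get('extra__google_cloud_platform__keyfile_dict')
    if not key_json:
        # Nothing stored in the web UI; fall back to default credentials.
        return None
    key_file = tempfile.NamedTemporaryFile(mode='w', suffix='.json')
    key_file.write(key_json)
    key_file.flush()
    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = key_file.name
    return key_file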
Code Example #2
    def execute(self, context):
        source_hook = BaseHook.get_hook(self.source_conn_id)

        self.log.info("Extracting data from %s", self.source_conn_id)
        self.log.info("Executing: \n %s", self.sql)
        results = source_hook.get_records(self.sql)

        destination_hook = BaseHook.get_hook(self.destination_conn_id)
        if self.preoperator:
            self.log.info("Running preoperator")
            self.log.info(self.preoperator)
            destination_hook.run(self.preoperator)

        self.log.info("Inserting rows into %s", self.destination_conn_id)
        destination_hook.insert_rows(table=self.destination_table, rows=results)
Code Example #3
File: generic_transfer.py  Project: 16522855/airflow
    def execute(self, context):
        source_hook = BaseHook.get_hook(self.source_conn_id)

        logging.info("Extracting data from {}".format(self.source_conn_id))
        logging.info("Executing: \n" + self.sql)
        results = source_hook.get_records(self.sql)

        destination_hook = BaseHook.get_hook(self.destination_conn_id)
        if self.preoperator:
            logging.info("Running preoperator")
            logging.info(self.preoperator)
            destination_hook.run(self.preoperator)

        logging.info("Inserting rows into {}".format(self.destination_conn_id))
        destination_hook.insert_rows(table=self.destination_table, rows=results)
Code Example #4
    def execute(self, context):
        # If gcp_conn_id is not specified gcloud will use the default
        # service account credentials.
        if self.gcp_conn_id:
            from airflow.hooks.base_hook import BaseHook
            # extras is a deserialized json object
            extras = BaseHook.get_connection(self.gcp_conn_id).extra_dejson
            self._set_env_from_extras(extras=extras)

        # Write config to a temp file and set the environment variable to point to it.
        # This is to avoid race conditions of reading/writing a single file
        with tempfile.NamedTemporaryFile() as conf_file:
            os.environ[KUBE_CONFIG_ENV_VAR] = conf_file.name
            # Attempt to get/update credentials
            # We call gcloud directly instead of using google-cloud-python api
            # because there is no way to write kubernetes config to a file, which is
            # required by KubernetesPodOperator.
            # The gcloud command looks at the env variable `KUBECONFIG` for where to save
            # the kubernetes config file.
            subprocess.check_call(
                ["gcloud", "container", "clusters", "get-credentials",
                 self.cluster_name,
                 "--zone", self.location,
                 "--project", self.project_id])

            # Tell `KubernetesPodOperator` where the config file is located
            self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
            super(GKEPodOperator, self).execute(context)
Code Example #5
    def poke(self, context):
        hook = BaseHook.get_connection(self.conn_id).get_hook()

        self.log.info('Poking: %s', self.sql)
        records = hook.get_records(self.sql)
        if not records:
            return False
        return str(records[0][0]) not in ('0', '')
Code Example #6
def _get_project_id():
  """Get project ID from default GCP connection."""

  extras = BaseHook.get_connection('google_cloud_default').extra_dejson
  key = 'extra__google_cloud_platform__project'
  if key in extras:
    project_id = extras[key]
  else:
    raise ValueError('Must configure project_id in google_cloud_default '
                     'connection from Airflow Console')
  return project_id
Code Example #7
File: sensors.py  Project: Zen-Slug/incubator-airflow
    def poke(self, context):
        hook = BaseHook.get_connection(self.conn_id).get_hook()

        logging.info('Poking: ' + self.sql)
        records = hook.get_records(self.sql)
        if not records:
            return False
        else:
            if str(records[0][0]) in ('0', '',):
                return False
            else:
                return True
Code Example #8
File: qubole_sensor.py  Project: caseybrown89/airflow
    def poke(self, context):
        conn = BaseHook.get_connection(self.qubole_conn_id)
        Qubole.configure(api_token=conn.password, api_url=conn.host)

        self.log.info('Poking: %s', self.data)

        status = False
        try:
            status = self.sensor_class.check(self.data)
        except Exception as e:
            logging.exception(e)
            status = False

        self.log.info('Status of this Poke: %s', status)

        return status
Code Example #9
File: sql_sensor.py  Project: Fokko/incubator-airflow
    def poke(self, context):
        conn = BaseHook.get_connection(self.conn_id)

        allowed_conn_type = {'google_cloud_platform', 'jdbc', 'mssql',
                             'mysql', 'oracle', 'postgres',
                             'presto', 'sqlite', 'vertica'}
        if conn.conn_type not in allowed_conn_type:
            raise AirflowException("The connection type is not supported by SqlSensor. " +
                                   "Supported connection types: {}".format(list(allowed_conn_type)))
        hook = conn.get_hook()

        self.log.info('Poking: %s (with parameters %s)', self.sql, self.parameters)
        records = hook.get_records(self.sql, self.parameters)
        if not records:
            return False
        return str(records[0][0]) not in ('0', '')
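As a hedged usage sketch (the connection id, table, and schedule below are assumptions, not taken from the snippet above): a sensor with this poke() logic keeps poking until the query's first cell comes back non-empty and non-zero, at which point the task succeeds.

from datetime import datetime

from airflow import DAG
from airflow.sensors.sql_sensor import SqlSensor  # Airflow 1.10.x import path

dag = DAG('sql_sensor_example', start_date=datetime(2020, 1, 1),
          schedule_interval='@daily', catchup=False)

# poke() runs every poke_interval seconds; the task succeeds once it returns True.
wait_for_rows = SqlSensor(
    task_id='wait_for_rows',
    conn_id='postgres_default',  # assumed connection id
    sql="SELECT COUNT(*) FROM staging.events WHERE ds = '{{ ds }}'",
    poke_interval=60,
    dag=dag,
)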
Code Example #10
    def get_extra_links(self, operator, dttm):
        """
        Get link to qubole command result page.

        :param operator: operator
        :param dttm: datetime
        :return: url link
        """
        conn = BaseHook.get_connection(operator.kwargs['qubole_conn_id'])
        if conn and conn.host:
            host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
        else:
            host = 'https://api.qubole.com/v2/analyze?command_id='

        ti = TaskInstance(task=operator, execution_date=dttm)
        qds_command_id = ti.xcom_pull(task_ids=operator.task_id, key='qbol_cmd_id')
        url = host + str(qds_command_id) if qds_command_id else ''
        return url
Code Example #11
    def execute(self, context):
        source_hook = BaseHook.get_hook(self.source_conn_id)

        logging.info("Extracting data from {}".format(self.source_conn_id))
        logging.info("Executing: \n" + self.sql)
        results = source_hook.get_records(self.sql)

        destination_hook = TeradataHook(teradata_conn_id=self.destination_conn_id)
        if self.preoperator:
            logging.info("Running preoperator")
            logging.info(self.preoperator)
            destination_hook.run(self.preoperator)

        if self.batch:
            logging.info("Inserting {} rows into {} with a batch size of {} rows".format(len(results), self.destination_conn_id, self.batch_size))
            destination_hook.bulk_insert_rows(table=self.destination_table, rows=iter(results), commit_every=self.batch_size,  unicode_source=self.unicode_source)
        else:
            logging.info("Inserting {} rows into {}".format(len(results), self.destination_conn_id))
            destination_hook.insert_rows(table=self.destination_table, rows=iter(results), commit_every=1000, unicode_source=self.unicode_source )
Code Example #12
 def __init__(self,
              sql,
              autocommit=False,
              parameters=None,
              gcp_conn_id='google_cloud_default',
              gcp_cloudsql_conn_id='google_cloud_sql_default',
              *args, **kwargs):
     super(CloudSqlQueryOperator, self).__init__(*args, **kwargs)
     self.sql = sql
     self.gcp_conn_id = gcp_conn_id
     self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id
     self.autocommit = autocommit
     self.parameters = parameters
     self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
     self.cloudsql_db_hook = CloudSqlDatabaseHook(
         gcp_cloudsql_conn_id=gcp_cloudsql_conn_id,
         default_gcp_project_id=self.gcp_connection.extra_dejson.get(
             'extra__google_cloud_platform__project'))
     self.cloud_sql_proxy_runner = None
     self.database_hook = None
Code Example #13
Airflow_snowflake_connection_name = Variable.get('Airflow_snowflake_connection_name')
orchestration_country = Variable.get('orchestration_country')
max_task_time = int(Variable.get('set_task_max_time_minutes')) #set the max runtime for a task
max_task_retries_on_error = int(Variable.get('max_task_retries_on_error'))

database_include_patterns = ['trans*', 'gateway'] #only include the staging, transaction, and gateway databases; for multiple patterns, format as a list separated by commas




##################################################################
#Collecting Connection attributes from Airflow connections repo
##################################################################

sf_con_parm = BaseHook.get_connection(Airflow_snowflake_connection_name)
snowflake_username = sf_con_parm.login 
snowflake_password = sf_con_parm.password 
snowflake_account = sf_con_parm.host 
snowflake_schema = 'A_UTILITY' 
snowflake_warehouse = "MYSQL_TO_RAW_MIGRATION_XSMALL_1" 
if orchestration_country.lower() in ['us', 'usa','united states','u.s.','u.s.a']:
    snowflake_database = "US_RAW"
if orchestration_country.lower() in ['ca', 'canada','c.a.']:
    snowflake_database = "CA_RAW"
if orchestration_country.lower() in ['uk', 'u.k.','united kingdom']:
    snowflake_database = "UK_RAW"

########################################################################
#Defining Utility functions
########################################################################
Code Example #14
    'depends_on_past': False,
    'start_date': datetime(2019, 6, 1),
    'email': ['*****@*****.**'],
    'email_on_failure': False,
    'email_on_retry': False,
    'retries': 0,
    'retry_delay': timedelta(minutes=0)
}

dag = DAG(
    'neo4j-con-1',
    default_args=default_args,
    description='testing generic cypher',
    # schedule_interval=timedelta(days=1)
    schedule_interval='@hourly',
    catchup=False)

t1 = BashOperator(task_id='print_date', bash_command='date', dag=dag)

connection = BaseHook.get_connection("neo4j_default")
uri = connection.host
pw = connection.password

cypher_1 = Neo4jOperator(task_id='node_count',
                         cql="MATCH (n) RETURN count(n)",
                         uri=uri,
                         pw=pw,
                         dag=dag)

t1.set_upstream(cypher_1)
Code Example #15
    :return:
    """

    ti = kwargs['ti']

    emr_dns = ti.xcom_pull(task_ids='iac_create_emr_cluster')

    ssh_conn = BaseHook.get_connection('ssh_default')
    ssh_conn.host = emr_dns

    session = settings.Session()  # get the session
    session.add(ssh_conn)
    session.commit()


ssh_emr_host = BaseHook.get_connection('ssh_default').host
ssh_emr_key = BaseHook.get_connection('ssh_default').extra_dejson.get(
    'key_file')
ssh_emr_user = BaseHook.get_connection('ssh_default').login
files_to_upload = '{constants.py,create_integration_layer.py,create_landing_zone.py,create_presentation_layer.py,' \
                  'helper_functions,load_integration_layer,load_landing_zone,load_presentation_layer,sql_queries,quality_checks}'

spark_master = 'yarn'

default_args = {
    'owner': 'flights_dl',
    'depends_on_past': False,
    'retries': 0,
    'catchup': False,
    'email_on_retry': False,
    'concurrency': 3
Code Example #16
File: facebook_operator.py  Project: sdaltmann/ewah
    def __init__(
        self,
        account_ids,
        insight_fields,
        level,
        data_from=None,
        data_until=None,
        time_increment=1,
        breakdowns=None,
        execution_waittime_seconds=15, # wait for a while before execution
        #   between account_ids to avoid hitting rate limits during backfill
        pagination_limit=1000,
        async_job_read_frequency_seconds=5,
        reload_data_from=None,
    *args, **kwargs):

        if kwargs.get('update_on_columns'):
            raise Exception('update_on_columns is set by operator!')

        if not hasattr(account_ids, '__iter__'):
            raise Exception('account_ids must be an iterable, such as a list,' \
                + ' of strings or integers!')

        if level == self.levels.ad:
            kwargs['update_on_columns'] = [
                'ad_id',
                'date_start',
                'date_stop',
            ] + (breakdowns or [])
            insight_fields += ['ad_id', 'ad_name']
            insight_fields = list(set(insight_fields))
        else:
            raise Exception('Specified level not supported!')

        if not (
            (
                type(time_increment) == str
                and time_increment in ['monthly', 'all_days']
            )
            or
            (
                type(time_increment) == int
                and time_increment >= 1
                and time_increment <= 90
            )
        ):
            raise Exception('time_increment must either be an integer ' \
                + 'between 1 and 90, or a string of either "monthly" '\
                + 'or "all_days". Recommended and default is the integer 1.')

        allowed_insight_fields = [
            _attr[1] for _attr in [
                member for member in inspect.getmembers(
                    AdsInsights.Field,
                    lambda a:not (inspect.isroutine(a)),
                )
                if not (member[0].startswith('__') and member[0].endswith('__'))
            ]
        ]
        for i_f in insight_fields:
            if not i_f in allowed_insight_fields:
                raise Exception((
                    'Field {0} is not an accepted value for insight_fields! ' \
                    + 'Accepted field values:\n\t{1}\n'
                ).format(
                    i_f,
                    '\n\t'.join(allowed_insight_fields)
                ))

        self.data_from = data_from
        self.data_until = data_until

        super().__init__(*args, **kwargs)

        credentials = BaseHook.get_connection(self.source_conn_id)
        extra = credentials.extra_dejson

        # Note: app_secret is not always required!
        if not extra.get('app_id'):
            raise Exception('Connection extra must contain an "app_id"!')
        if not extra.get('access_token', credentials.password):
            raise Exception('Connection extra must contain an "access_token" ' \
                + 'if it is not saved as the connection password!')

        self.credentials = {
            'app_id': extra.get('app_id'),
            'app_secret': extra.get('app_secret'),
            'access_token': extra.get('access_token', credentials.password),
        }

        self.account_ids = account_ids
        self.insight_fields = insight_fields
        self.level = level
        self.time_increment = time_increment
        self.breakdowns = breakdowns
        self.execution_waittime_seconds = execution_waittime_seconds
        self.pagination_limit = pagination_limit
        self.async_job_read_frequency_seconds = async_job_read_frequency_seconds
        self.reload_data_from = reload_data_from
Code Example #17
import smtplib
from datetime import timedelta
from email import encoders
from email.mime.base import MIMEBase
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

import pandas as pd
import plotly.graph_objects as go
from airflow.hooks.base_hook import BaseHook
from pendulum import Pendulum
from plotly.subplots import make_subplots

from utils import db

email_connection = BaseHook.get_connection('sender_email')


def send_email(**kwargs):
    execution_date: Pendulum = kwargs['execution_date']

    if execution_date.weekday() == 6:
        table = analyse_prices(execution_date=execution_date)
        report_text = create_html_report(table=table)
        __send_email(execution_date=execution_date, report_text=report_text)


def create_html_report(table: pd.DataFrame) -> str:
    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.add_trace(go.Scatter(x=table.index,
                             y=table.average_price,
Code Example #18
 def get_db_hook(self):
     return BaseHook.get_hook(conn_id=self.conn_id)
Code Example #19
        # hive create table
        hive_hook = HiveCliHook()
        sql = ODS_CREATE_TABLE_SQL.format(db_name=HIVE_DB,
                                          table_name=hive_table_name,
                                          columns=",\n".join(rows),
                                          ufile_path=UFILE_PATH %
                                          (db_name, table_name))
        logging.info('Executing: %s', sql)
        hive_hook.run_cli(sql)
    return


conn_conf_dict = {}
for db_name, table_name, conn_id, prefix_name, priority_weight_nm in table_list:
    if conn_id not in conn_conf_dict:
        conn_conf_dict[conn_id] = BaseHook.get_connection(conn_id)

    hive_table_name = HIVE_TABLE % (prefix_name, table_name)
    if table_name in ['data_opay_transaction']:
        m = 1
    else:
        m = 12
    # sqoop import
    import_table = BashOperator(
        task_id='import_table_{}'.format(hive_table_name),
        priority_weight=priority_weight_nm,
        bash_command='''
            #!/usr/bin/env bash
            sqoop import "-Dorg.apache.sqoop.splitter.allow_text_splitter=true" \
            -D mapred.job.queue.name=root.collects \
            --connect "jdbc:mysql://{host}:{port}/{schema}?tinyInt1isBit=false&useUnicode=true&characterEncoding=utf8" \
Code Example #20
database_include_patterns = [
    'prefix*'
]  #only include the staging, transaction, and gateway databases; for multiple patterns, format as a list separated by commas

excluded_tables = ['table1', 'table2']  #list of tables we don't want to migrate

max_task_time = int(
    Variable.get('set_task_max_time_minutes'))  #set the max runtime for a task
max_task_retries_on_error = int(Variable.get('max_task_retries_on_error'))

##################################################################
#Collecting Connection attributes from Airflow connections repo
##################################################################

sf_con_parm = BaseHook.get_connection(
    'snowflake_1')  #Airflow_snowflake_connection_name
snowflake_username = sf_con_parm.login
snowflake_password = sf_con_parm.password
snowflake_account = sf_con_parm.host
snowflake_stage_schema = 'A_UTILITY'
#snowflake_warehouse = "XSMALL"
snowflake_database = "US_RAW"

mysql_con = BaseHook.get_connection(
    'mysql_celltrak_1')  #Airflow_mysql_connection_name
mysql_username = mysql_con.login
mysql_password = mysql_con.password
mysql_hostname = mysql_con.host
mysql_port = mysql_con.port

########################################################################
Code Example #21
 def get_db_hook(self):
     return BaseHook.get_hook(conn_id=self.conn_id)
Code Example #22
File: google_ads_operator.py  Project: sdaltmann/ewah
    def ewah_execute(self, context):
        # Task execution happens here
        def get_data_from_ads_output(fields_dict, values, prefix=None):
            if prefix is None:
                prefix = ''
            elif not prefix[-1] == '_':
                prefix += '_'
                # e.g. 2b prefix = 'ad_group_criterion_'
            data = {}
            for key, value in fields_dict.items():
                # e.g. 1 key = 'metrics', value = ['impressions', 'clicks']
                # e.g. 2a key = 'ad_group_criterion', value = [{'keyword': ['text', 'match_type']}]
                # e.g. 2b key = 'keyword', value = ['text', 'match_type']
                node = getattr(values, key)
                # e.g. 1 node = row.metrics
                # e.g. 2a node = row.ad_group_criterion
                # e.g. 2b node = row.ad_group_criterion.keyword
                for item in value:
                    # e.g. 1 item = 'clicks'
                    # e.g. 2a item = {'keyword': ['text', 'match_type']}
                    # e.g. 2b item = 'text'
                    if type(item) == dict:
                        data.update(
                            get_data_from_ads_output(
                                fields_dict=item,
                                values=node,
                                prefix=prefix +
                                key,  # e.g. 2a '' + 'ad_group_criterion'
                            ))
                    else:
                        # e.g. 1: {'' + 'metrics' + '_' + 'clicks': row.metrics.clicks.value}
                        # e.g. 2b: {'ad_group_criterion_' + 'keyeword' + '_' + 'text': row.ad_group_criterion.keyword.text.value}
                        if hasattr(getattr(node, item), 'value'):
                            data.update({
                                prefix + key + '_' + item: \
                                    getattr(node, item).value
                            })
                        else:
                            # some node ends don't respond to .value but are
                            #   already the value
                            data.update({
                                prefix + key + '_' + item:
                                getattr(node, item)
                            })
            return data

        self.data_until = airflow_datetime_adjustments(self.data_until)
        self.data_until = self.data_until or context['next_execution_date']
        if isinstance(self.data_from, timedelta):
            self.data_from = self.data_until - self.data_from
        else:
            self.data_from = airflow_datetime_adjustments(self.data_from)
            self.data_from = self.data_from or context['execution_date']

        conn = BaseHook.get_connection(self.source_conn_id).extra_dejson
        credentials = {}
        for key in self._REQUIRED_KEYS:
            if not key in conn.keys():
                raise Exception(
                    '{0} must be in connection extra json!'.format(key))
            credentials[key] = conn[key]

        # build the query
        query = 'SELECT {0} FROM {1} WHERE segments.date {2} {3}'.format(
            ', '.join(self.fields_list),
            self.resource,
            "BETWEEN '{0}' AND '{1}'".format(
                self.data_from.strftime('%Y-%m-%d'),
                self.data_until.strftime('%Y-%m-%d'),
            ),
            ('AND ' + ' AND '.join(self.conditions)) if self.conditions else '',
        )

        self.log.info('executing this google ads query:\n{0}'.format(query))
        cli = GoogleAdsClient.load_from_dict(credentials)
        service = cli.get_service("GoogleAdsService", version="v3")
        search = service.search(
            self.client_id.replace('-', ''),
            query=query,
        )
        data = [row for row in search]

        # get into uploadable format
        upload_data = []
        while data:
            datum = data.pop(0)
            upload_data += [
                get_data_from_ads_output(
                    deepcopy(self.fields_dict),
                    datum,
                )
            ]

        self.upload_data(upload_data)
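A small, self-contained illustration of the flattening that get_data_from_ads_output performs, using types.SimpleNamespace to stand in for a Google Ads result row (the field names below are hypothetical, not real API output):

from types import SimpleNamespace

# Stand-in for one result row; leaf nodes expose a .value attribute, as the helper expects.
row = SimpleNamespace(
    metrics=SimpleNamespace(clicks=SimpleNamespace(value=42)),
    ad_group_criterion=SimpleNamespace(
        keyword=SimpleNamespace(text=SimpleNamespace(value='shoes'))),
)
fields_dict = {'metrics': ['clicks'],
               'ad_group_criterion': [{'keyword': ['text']}]}
# Passing this pair through the helper yields flattened keys of the form
# {'metrics_clicks': 42, 'ad_group_criterion_keyword_text': 'shoes'}.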
Code Example #23
def create_connection():
    c = BaseHook.get_connection('mssql_pi')
    return c
Code Example #24
def hdfs_conn(conn_id='hdfs'):
    conn = BaseHook.get_connection(conn_id)
    return f"hdfs://{conn.host}:{conn.port}"
Code Example #25
 def test_dbapi_get_sqlalchemy_engine(self):
     conn = BaseHook.get_connection(conn_id='test_uri')
     hook = conn.get_hook()
     engine = hook.get_sqlalchemy_engine()
     self.assertIsInstance(engine, sqlalchemy.engine.Engine)
     self.assertEqual('postgres://*****:*****@ec2.compute.com:5432/the_database', str(engine.url))
Code Example #26
 def from_conn_id(cls, conn_id: str) -> 'CustomBaseHook':
     conn_params = BaseHook.get_connection(conn_id)
     return cls(conn_params)
Code Example #27
user = logging.getLogger(__name__)

default_args = {
    'owner': 'airflow',
    'retries': 1,
    'retry_delay': timedelta(minutes=1)
}

# Snowflake information
# Information must be stored in connections
# It can be done with Airflow UI -
# Admin -> Connections -> Create
database_name = 'TEST_DB'
table_name = 'customer'
schema_name = 'public'
snowflake_username = BaseHook.get_connection('snowflake').login
snowflake_password = BaseHook.get_connection(
    'snowflake').password  #Snowflake conn id=snowflake
snowflake_account = BaseHook.get_connection('snowflake').host

dag = DAG(dag_id="finally_done",
          default_args=default_args,
          start_date=datetime(2020, 3, 31),
          schedule_interval='*/12 * * * *',
          catchup=False)

#Loading Data


def load_data(**context):
    con = snowflake.connector.connect(user = snowflake_username, \
Code Example #28
def send_slack_message(text):
    connection = BaseHook.get_connection("slack")
    headers = {"Content-type": "application/json"}
    requests.post(connection.host,
                  data=json.dumps({"text": text}),
                  headers=headers)
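A hedged usage sketch (the callback wiring below is an assumption, not part of the snippet): a helper like this is typically attached to a task's on_failure_callback so that a message is posted to the webhook URL stored in the "slack" connection's host field whenever the task fails.

def notify_failure(context):
    # Airflow passes the task context to on_failure_callback; pull out identifiers.
    ti = context['task_instance']
    send_slack_message("Task {} in DAG {} failed.".format(ti.task_id, ti.dag_id))

# e.g. BashOperator(task_id='load', bash_command='...',
#                   on_failure_callback=notify_failure, dag=dag)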
Code Example #29
import datetime as dt

# Third Party
from airflow import DAG
from airflow.utils.helpers import chain
from airflow.hooks.base_hook import BaseHook
from airflow.operators.python_operator import PythonOperator
from airflow.operators.postgres_operator import PostgresOperator

# Custom
sys.path.append("/usr/local/airflow/dags/efs")
import redb.scripts.transfer_to_s3 as toS3
import redb.scripts.mdb_to_postgres as mdbToREDB

# Credentials for S3 Bucket
BUCKET_CONN = BaseHook.get_connection('redb-workbucket')
BUCKET_NAME = BUCKET_CONN.conn_id
AWS_ACCESS_KEY_ID = BUCKET_CONN.login
AWS_SECRET_ACCESS_KEY = BUCKET_CONN.password

# Credentials for Database
DATABASE_CONN = BaseHook.get_connection('redb_postgres')
DATABASE_NAME = DATABASE_CONN.schema
DATABASE_HOST = DATABASE_CONN.host
DATABASE_USER = DATABASE_CONN.login
DATABASE_PORT = DATABASE_CONN.port
DATABASE_PASSWORD = DATABASE_CONN.password

default_args = {
    'owner': 'redb',
    'start_date': dt.datetime(2020, 7, 23),
Code Example #30
    BranchPythonOperator

default_args = {
    "owner": "airflow",
    "start_date": datetime(2020, 11, 1),
    "depends_on_past": False,
    "email_on_failure": False,
    "email_on_retry": False,
    "email": "*****@*****.**",
    "retries": 1,
    "retry_delay": timedelta(minutes=5)
}

data_path = f'{json.loads(BaseHook.get_connection("data_path").get_extra()).get("path")}/data.csv'
transformed_path = f'{os.path.splitext(data_path)[0]}-transformed.csv'
slack_token = BaseHook.get_connection("slack_conn").password


def transform_data(*args, **kwargs):
    invoices_data = pd.read_csv(filepath_or_buffer=data_path,
                                sep=',',
                                header=0,
                                usecols=[
                                    'StockCode', 'Quantity', 'InvoiceDate',
                                    'UnitPrice', 'CustomerID', 'Country'
                                ],
                                parse_dates=['InvoiceDate'],
                                index_col=0)
    invoices_data.to_csv(path_or_buf=transformed_path)

Code Example #31
from airflow.utils.dates import days_ago
from datetime import datetime
import sqlalchemy
import pymysql
import papermill as pm
import airflow.hooks.S3_hook
from airflow.hooks.base_hook import BaseHook

bucket_name = "spotify-billboard-airflow-project"
reports_storage_path = "pass"
data_storage_path = "/Users/jkocher/Documents/airflow_home/data/"
jupyter_notebook_storage = "pass"
billboard_location = "data/raw_data/billboard_pickle"
audio_features_location = "data/raw_data/audio_feature_pickle"
hook = airflow.hooks.S3_hook.S3Hook('my_conn_S3')
c = BaseHook.get_connection('postgres_conn')
engine = sqlalchemy.create_engine('postgresql+psycopg2://' + str(c.login) +
                                  ':' + str(c.password) + '@' + str(c.host) +
                                  ':5432/music_db')
#engine = sqlalchemy.create_engine('postgresql+psycopg2://airflow:[email protected]:5432/music_db')

default_args = {
    'owner': 'James Kocher',
    'depends_on_past': False,
    'start_date': datetime.now(),
    'retries': 0
}

dag = DAG(
    "data_world_music_pipeline",
    default_args=default_args,
Code Example #32
def load_data(**context):
    postgres_hook = PostgresHook('admin_postgres')
    tickers = postgres_hook.get_records(
        'select yf_code from tickers where fetch_from_yahoo_finance')
    logging.info('Loaded %d tickers from db.' % len(tickers))
    tickers = [x[0] for x in tickers]
    frequency = '1d'
    start_dt = parse_execution_date(
        context['yesterday_ds']) - timedelta(days=7)
    data = yf.download(tickers=tickers,
                       start=start_dt,
                       end=context['tomorrow_ds'],
                       interval=frequency,
                       auto_adjust=True,
                       group_by='ticker',
                       progress=False,
                       threads=True)

    columns_mapping = {
        'Date': 'ts',
        'Open': 'open',
        'High': 'high',
        'Low': 'low',
        'Close': 'close',
        'Volume': 'volume',
        'Adj Close': 'adj_close'
    }
    ch_columns = ['ticker', 'frequency', 'source', 'type'] + list(
        columns_mapping.values())
    df = None
    for ticker in tickers:
        try:
            _df = data[ticker].copy()
        except KeyError:
            logging.error('Ticker %s not found in data' % ticker)
            continue
        _df = _df.reset_index()
        _df['ticker'] = ticker
        _df['frequency'] = frequency
        _df['source'] = 'yfinance'
        _df['type'] = 'history'
        _df = _df.rename(columns=columns_mapping)
        if 'adj_close' not in _df.columns:
            _df['adj_close'] = np.nan
        _df = _df[ch_columns]
        _df = _df[~_df.close.isna()]

        if df is None:
            df = _df
        else:
            df = pd.concat([df, _df])

    logging.info('Prepared df with shape (%s, %s)' % df.shape)
    ch_hook = BaseHook(None)
    ch_conn = ch_hook.get_connection('rocket_clickhouse')
    data_json_each = ''
    df.reset_index(drop=True, inplace=True)
    for i in df.index:
        json_str = df.loc[i].to_json(date_format='iso')
        data_json_each += json_str + '\n'

    result = requests.post(
        url=ch_conn.host,
        data=data_json_each,
        params=dict(
            query='insert into rocket.events format JSONEachRow',
            user=ch_conn.login,
            password=ch_conn.password,
            date_time_input_format='best_effort',
        ))
    if result.ok:
        logging.info('Insert ok.')
    else:
        raise requests.HTTPError('Request response code: %d. Message: %s' %
                                 (result.status_code, result.text))
Code Example #33
# custom operators
from operators.s3toredshift_operator import S3ToRedshiftOperator
from operators.dimension_operator import DimensionOperator

from operators.data_quality_count_operator import DataQualityCountOperator
from operators.data_quality_dimension_operator import DataQualityDimensionOperator

# helpers
from helpers import sql_queries_staging, sql_queries_presentation

# OS variables
start_time = datetime.now()
start_time_str = start_time.strftime("%d/%m/%Y %H:%M:%S")

# Configuration variables
aws_connection = BaseHook.get_connection("aws_credentials")
aws_username: str = aws_connection.login
aws_password: str = aws_connection.password
'''

    MANIFOLD DAG CONFIGURATION

'''

########################
# Airflow Dag Configs  #
########################
DEFAULT_ARGS = {
    'owner': 'Guilherme Banhudo',
    'depends_on_past': False,
    'email': ['*****@*****.**'],
Code Example #34
default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "start_date": datetime(2018, 10, 14),
    "email": ["*****@*****.**"],
    "email_on_failure": False,
    "email_on_retry": False,
    "retries": 1,
    "retry_delay": timedelta(minutes=5),
    # 'queue': 'bash_queue',
    # 'pool': 'backfill',
    # 'priority_weight': 10,
    # 'end_date': datetime(2016, 1, 1),
}

SIXTHMAN_PROD = BaseHook.get_connection("sixthman_prod")
SIXTHMAN_CONN_PASSWORD = SIXTHMAN_PROD.password

dag = DAG("nba_box_scores",
          default_args=default_args,
          schedule_interval=timedelta(minutes=20),
          catchup=False)

t1 = BashOperator(
    task_id="nba_box_scores_task",
    pool="nba_box_scores",
    bash_command=
    f"DATABASE_API_CONNECTION=postgres://sixthman:{SIXTHMAN_CONN_PASSWORD}@sixthman-prod.cbdmxavtswxu.us-west-1.rds.amazonaws.com:5432/sixthman node /usr/local/airflow/build/ingestJobs/scrapeNbaBoxscore.js",
    retries=3,
    execution_timeout=timedelta(minutes=3),
    dag=dag)
Code Example #35
import gc


##################################################################
#Setting variable definitions and connections
##################################################################

#database = Variable.get('Create_customer_database_tables_var__database_name')
Airflow_snowflake_connection_name = Variable.get('Airflow_snowflake_connection_name')
Airflow_mysql_connection_name = Variable.get('Airflow_mysql_connection_name')

database_list = ['database']

parent_dag_name = 'Collect_Mysql_Table_Counts_Load_to_Snowflake_Muliple_dbs'

sf_con = BaseHook.get_connection(Airflow_snowflake_connection_name)
snowflake_username = sf_con.login 
snowflake_password = sf_con.password 
snowflake_account = sf_con.host 
snowflake_warehouse = "XSMALL" 
snowflake_database = "sf_db"

mysql_con = BaseHook.get_connection(Airflow_mysql_connection_name)
mysql_username = mysql_con.login 
mysql_password = mysql_con.password 
mysql_hostname = mysql_con.host
mysql_port = mysql_con.port

########################################################################
#Defining Utility functions
########################################################################
Code Example #36
input of single zipcode

Contains two major code issues: threading and unnecessary data copies (bug: pd.DataFrame is not thread safe; flow: pd.append used in a loop)
"""

from airflow import DAG
from airflow.models import Variable
from airflow.hooks.base_hook import BaseHook
from airflow.operators.python_operator import PythonOperator
from airflow.utils.dates import datetime
import sqlalchemy
import pandas as pd
from pipelines.boligax import BoligaRecent


CONNECTION_URI = BaseHook.get_connection("bolig_db").get_uri()
TABLE_NAME = f'recent_bolig_{Variable.get("postal",2650)}'

args = {
    "owner": "Prayson",
    "catchup_by_default": False,
}


def get_bolig(
    postal: int, engine: sqlalchemy.types.TypeEngine = None, **kwargs
) -> None:
    """get bolig[estate] from a given postal code

    Arguments:
        postal {int} -- Danish postal code: e.g. 2560
Code Example #37
default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': START_DATE,
    'email': ['*****@*****.**'],
    'email_on_failure': True,
    'email_on_retry': False,
    'retries': 23,
    'retry_delay': timedelta(minutes=20),
}

dag = DAG('DAG_NAME', default_args=default_args, schedule_interval="@once")

dag_config = Variable.get('VARIABLES_NAME', deserialize_json=True)
aws_conn = BaseHook.get_connection("aws_conn")
s3_bucket = dag_config['s3_bucket']
datasource_type = dag_config['datasource_type']
date_preset = dag_config['date_preset']
account_id = dag_config['account_id']
access_token = dag_config['access_token']
api_version = dag_config['api_version']
insight_fields = dag_config['insight_fields']
action_attribution_windows = dag_config['action_attribution_windows']
file_path = dag_config['file_path']
time_increment = dag_config['time_increment']
backlogdays = dag_config['backlog_days']
days = 29 + int(dag_config['backlog_days'])
# today = datetime.today().strftime('%Y-%m-%d')
today = date.today().isoformat()
start_date = (date.today() - timedelta(days=29)).isoformat()
Code Example #38
def subdag(parent_dag_name, child_dag_name, args, json_gs):
    dag_subdag = DAG(
        dag_id=f'{parent_dag_name}.{child_dag_name}',
        default_args=args,
        start_date=datetime.datetime(2021, 8, 5, 20, 0),
        schedule_interval='0 13,14,15,16,17,18,19,20,21,22,23,0,1 * * *',
    )

    connection_airflow_yas_sa_sii_de = BaseHook.get_connection(
        'google_cloud_yas_sa_sii_de')
    service_account_yas_sa_sii_de = ast.literal_eval(
        connection_airflow_yas_sa_sii_de.
        extra_dejson["extra__google_cloud_platform__keyfile_dict"])

    with gcsfs.GCSFileSystem(
            project='yas-dev-sii-pid',
            token=service_account_yas_sa_sii_de).open(json_gs) as f:
        jd = json.load(f)

    # Variables para ejecucion desde JSON
    url_trn = jd['url_trn']

    # Datos de TRN
    job_name_hom = jd['job_name_hom']
    url_hom = jd['url_hom']
    file_name_hom = jd['file_name_hom']
    template_location_hom = jd['template_location_hom']

    # Datos Generales para la ejecucion
    temp_location = jd['temp_location']
    project = jd['project']
    region = jd['region']
    subnetwork = jd['subnetwork']
    service_account_email = jd['service_account_email']
    machine_type = jd['machine_type']
    max_num_workers = jd['max_num_workers']
    num_workers = jd['num_workers']

    folders = gcsfs.GCSFileSystem(
        project='yas-dev-sii-pid',
        token=service_account_yas_sa_sii_de).ls(url_trn)

    if len(folders) > 0:
        for folder in folders:
            date_folder = folder.split('/')[3]

            if len(date_folder) >= 10:
                url_source = 'gs://' + folder
                url_dest = url_hom + date_folder + '/' + file_name_hom

                parent_dag_name_for_id = parent_dag_name.lower()

                print('url_source: ' + url_source)
                print('url_dest: ' + url_dest)

                DataflowTemplateOperator(
                    template=template_location_hom,
                    job_name=
                    f'{parent_dag_name_for_id}-{child_dag_name}-{date_folder}',
                    task_id=
                    f'{parent_dag_name_for_id}-{child_dag_name}-{date_folder}',
                    location=region,
                    parameters={
                        'url_trn': url_source,
                        'url_hom': url_dest,
                    },
                    default_args=args,
                    dataflow_default_options={
                        'project': project,
                        'zone': 'us-east1-c',
                        'tempLocation': temp_location,
                        'machineType': machine_type,
                        'serviceAccountEmail': service_account_email,
                        'subnetwork': subnetwork,
                    },
                    gcp_conn_id='google_cloud_yas_sa_sii_de',
                    dag=dag_subdag,
                )
    return dag_subdag
Code Example #39
def check_if_tweet_is_avalaible(twitter_account_id=None,
                                since_id=None,
                                find_param=None,
                                **kwargs):
    """
    This method uses the tweepy api via TwitterHook to check whether a tweet from a specific
    twitter_account contains a specific search_string or not
    :param: twitter_account_id : for which tweets are to be fetched
    :param: since_id : airflow execution date of the dag
    :return: tweet_id
    """
    log = LoggingMixin().log
    try:
        # Load Configuration Data
        config = json.loads(Variable.get("config"))
        log.info("Config found")

    except AirflowException as e:
        log.error("Config missing")
        raise ConfigVariableNotFoundException()

    try:
        twitter_account_id = config['twitter_account_id']
    except KeyError as e:
        raise AirflowException('Missing Twitter Account Id in config variable')

    try:
        since_id = config['since_id']
    except KeyError as e:
        log.warn("Since id missing")

    try:
        find_param = config['find_param'].lower()
    except KeyError as e:
        raise AirflowException('Missing Find Param in config variable')

    try:
        twitter_credentials = BaseHook.get_connection("twitter_default")
        twitter_credentials = json.loads(twitter_credentials.extra)
        consumer_key = twitter_credentials['consumer_key']
        consumer_secret = twitter_credentials['consumer_secret']
        access_token = twitter_credentials['access_token']
        access_token_secret = twitter_credentials['access_token_secret']

    except AirflowException as e:
        raise TwitterConnectionNotFoundException()

    twitter_hook = TwitterHook(consumer_key=consumer_key,
                               consumer_secret=consumer_secret,
                               access_token=access_token,
                               access_token_secret=access_token_secret)

    tweepy_api = twitter_hook.get_tweepy_api()
    today = date.today()
    curr_date = today.strftime("%d-%m-%Y")
    # try to get tweet related to covid media bulletin from @diprjk handle

    tweets = tweepy_api.user_timeline(id=twitter_account_id,
                                      since_id=since_id,
                                      count=1000,
                                      exclude_replies=True,
                                      include_rts=False,
                                      tweet_mode="extended")
    if len(tweets) > 0:
        # find_param = "Media Bulletin on Novel".lower()
        log.info("Found : {}  tweets".format(len(tweets) + 1))
        # loop over all extracted tweets and
        # if tweet.full_text contains string "Media Bulletin On Novel"
        # then we got our concerned tweet and save its tweet_id
        image_urls = []
        for tweet in tweets:
            tweet_date = tweet.created_at
            tweet_date = tweet_date.strftime("%d-%m-%Y")
            text = tweet.full_text.lower()
            if find_param in text and tweet_date == curr_date:
                bulletin_tweet_id = tweet.id
                print('Tweet found')
                # save bulletin tweet id as environ variable or on file and then use in next run
                log.info("Tweet ID: {}  TEXT : {} ".format(
                    bulletin_tweet_id, tweet.full_text))
                if 'media' in tweet.entities:
                    for media in tweet.extended_entities['media']:
                        image_urls.append(media['media_url'])
                    detail_image_url = image_urls[2]
                    log.info("Tweet Image Url: {} ".format(detail_image_url))
                else:
                    log.info("No media found")
                    #skip the processing and end dag
                    return False
                data = {
                    "tweet_id": bulletin_tweet_id,
                    "tweet_date": tweet_date,
                    "media_url": detail_image_url
                }
                Variable.set("bulliten_tweet", json.dumps(data))
                return True
            else:
                pass
        else:
            log.info("No tweets related to {} found".format(find_param))
            return False

    else:
        log.info("No tweets found!")
        return False
Code Example #40
File: useless_hook.py  Project: biellls/airflow-plus
 def from_conn_id(conn_id: str):
     return UselessHookImplicitProtocol(BaseHook.get_connection(conn_id))
Code Example #41
import logging
from datetime import datetime, timedelta

import pandas as pd
import pendulum
import sqlalchemy
from airflow import DAG
from airflow.hooks.base_hook import BaseHook
from airflow.operators.python_operator import PythonOperator
from airflow.utils.email import send_email

pg_conn = BaseHook.get_connection("postgres_default")

local_tz = pendulum.timezone("America/Los_Angeles")
default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "start_date": datetime(2018, 10, 30, tzinfo=local_tz),
    "email": [
        "*****@*****.**",
        "*****@*****.**",
        "*****@*****.**",
    ],
    "email_on_failure": True,
    "email_on_retry": False,
    "retries": 1,
    "retry_delay": timedelta(minutes=15)
    # 'queue': 'bash_queue',
    # 'pool': 'backfill',
    # 'priority_weight': 10,
    # 'end_date': datetime(2016, 1, 1),
Code Example #42
File: sql.py  Project: beingbisht/airflow-tests
 def get_db_hook(self):
     """
     Returns DB hook
     """
     return BaseHook.get_hook(conn_id=self.conn_id)
Code Example #43
from airflow.hooks.base_hook import BaseHook
from sqlalchemy import create_engine
from sqlalchemy_utils import create_database, database_exists
import sqlalchemy

datalake_conn_string = BaseHook.get_connection('postgres_datalake').get_uri()

engine = create_engine(datalake_conn_string)

# create database
if not database_exists(engine.url):
    create_database(engine.url)
    engine.execute("GRANT ALL PRIVILEGES ON DATABASE {db} TO {user};".format(user = engine.url.username, db = engine.url.database))

# create schema, give permissions
if not engine.dialect.has_schema(engine, 'views'):
    engine.execute(sqlalchemy.schema.CreateSchema('views'))
    engine.execute("GRANT ALL PRIVILEGES ON SCHEMA views TO {user};".format(user = engine.url.username))
    engine.execute("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA views TO {user};".format(user = engine.url.username))
    engine.execute("ALTER DEFAULT PRIVILEGES IN SCHEMA views GRANT ALL PRIVILEGES ON TABLES TO {user};".format(user = engine.url.username))