コード例 #1
0
ファイル: test_postgres.py プロジェクト: yqian1991/airflow
    def test_bulk_load(self):
        """Bulk-load a newline-separated temp file and check every value lands in the table."""
        hook = PostgresHook()
        expected = ["foo", "bar", "baz"]

        with hook.get_conn() as conn:
            with conn.cursor() as cur:
                # Commit the DDL so the table is visible outside this transaction
                # (bulk_load presumably works on its own connection — verify).
                cur.execute("CREATE TABLE {} (c VARCHAR)".format(self.table))
                conn.commit()

                with NamedTemporaryFile() as tmp:
                    payload = "\n".join(expected).encode("utf-8")
                    tmp.write(payload)
                    tmp.flush()
                    hook.bulk_load(self.table, tmp.name)

                cur.execute("SELECT * FROM {}".format(self.table))
                loaded = [record[0] for record in cur.fetchall()]

        self.assertEqual(sorted(expected), sorted(loaded))
コード例 #2
0
ファイル: operators.py プロジェクト: dsinaction/meteo
    def create_temporary_dbfile(self, request_id, buffer):
        """Store ``buffer`` in ``imgw.temporary_file`` for ``request_id``; return the row id.

        A new row is inserted, or ``data`` is overwritten on the existing row
        when ``request_id`` already exists (``ON CONFLICT ... DO UPDATE``).
        The transaction is committed before the connection is closed.
        """
        hook = PostgresHook(postgres_conn_id=self.postgres_conn_id,
                            schema=self.database)
        with closing(hook.get_conn()) as connection:
            with closing(connection.cursor()) as cursor:
                payload = (request_id, psycopg2.Binary(buffer))
                cursor.execute(
                    """
                    INSERT INTO imgw.temporary_file (request_id, data) 
                    VALUES (%s, %s)
                    ON CONFLICT (request_id)
                    DO
                        UPDATE SET data = EXCLUDED.data
                    RETURNING id
                    """, payload)
                returned = cursor.fetchone()

            connection.commit()

        return returned[0]
コード例 #3
0
    def execute(self, context):
        """Replace the contents of ``self.table`` with data copied from S3.

        All rows in the destination are deleted, then the class-level
        ``copy_sql`` template is rendered with the table, the S3 path, the AWS
        access/secret keys and ``self.json_format`` and executed.
        """
        credentials = AwsBaseHook(self.aws_credentials_id, client_type="s3").get_credentials()
        redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)

        self.log.info("Clearing data from destination Redshift table")
        # NOTE(review): DELETE and COPY are separate run() calls, presumably
        # separate transactions — a failed COPY would leave the table empty.
        redshift.run("DELETE FROM {}".format(self.table))

        self.log.info("Copying data from S3 to Redshift")
        s3_path = "s3://{}/{}".format(self.s3_bucket, self.s3_key.format(**context))
        copy_args = (
            self.table,
            s3_path,
            credentials.access_key,
            credentials.secret_key,
            self.json_format,
        )
        redshift.run(StageToRedshiftOperator.copy_sql.format(*copy_args))
コード例 #4
0
    def setUpClass(cls):
        """Recreate every table in ``TABLES`` and seed ``postgres_to_gcs_operator``.

        Each table is dropped (CASCADE) and recreated with the columns
        ``(some_str varchar, some_num integer)``. Three sample rows are then
        inserted into ``postgres_to_gcs_operator``.
        """
        seed_rows = [
            ('mock_row_content_1', 42),
            ('mock_row_content_2', 43),
            ('mock_row_content_3', 44),
        ]
        insert_sql = "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);"

        postgres = PostgresHook()
        with postgres.get_conn() as conn:
            with conn.cursor() as cur:
                for table in TABLES:
                    cur.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
                    cur.execute(f"CREATE TABLE {table}(some_str varchar, some_num integer);")

                for row in seed_rows:
                    cur.execute(insert_sql, row)
コード例 #5
0
ファイル: test_postgres.py プロジェクト: ysk24ok/airflow
    def test_bulk_dump(self):
        """Insert known rows, bulk_dump the table to a temp file and compare its lines."""
        hook = PostgresHook()
        input_data = ["foo", "bar", "baz"]
        values_clause = ",".join(map("('{}')".format, input_data))

        with hook.get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute(f"CREATE TABLE {self.table} (c VARCHAR)")
                cur.execute(f"INSERT INTO {self.table} VALUES {values_clause}")
                conn.commit()

                with NamedTemporaryFile() as dump_file:
                    hook.bulk_dump(self.table, dump_file.name)
                    dump_file.seek(0)
                    # One dumped line per row; strip the terminator and decode.
                    results = [raw.rstrip().decode("utf-8") for raw in dump_file]

        assert sorted(input_data) == sorted(results)
コード例 #6
0
    def execute(self, context):
        """Truncate ``self.table`` and reload it from S3 with a Redshift COPY.

        The format clause is built from ``self.data_format``:
        ``csv`` adds ``IGNOREHEADER`` (only when ``ignore_header > 0``) and
        ``DELIMITER``; ``json`` adds ``FORMAT AS JSON '<jsonpath>'``.
        ``self.copy_opts`` is always appended. The final statement is
        ``self._sql`` rendered with the table, the S3 URL, the AWS keys and
        that clause.
        """
        aws_credentials = AwsBaseHook(self.aws_credentials_id).get_credentials()
        redshift_conn = PostgresHook(
            postgres_conn_id=self.redshift_conn_id,
            connect_args={'keepalives': 1, 'keepalives_idle': 60, 'keepalives_interval': 60},
        )

        self.log.debug(f"Truncate Table: {self.table}")
        redshift_conn.run(f"TRUNCATE TABLE {self.table}")

        clauses = []
        if self.data_format == 'csv':
            if self.ignore_header > 0:
                clauses.append(f"IGNOREHEADER {self.ignore_header}\n")
            clauses.append(f"DELIMITER '{self.delimiter}'\n")
        elif self.data_format == 'json':
            clauses.append(f"FORMAT AS JSON '{self.jsonpath}'\n")
        clauses.append(f"{self.copy_opts}")
        format_clause = "".join(clauses)
        self.log.debug(f"format : {format_clause}")

        formatted_key = self.s3_src_bucket_key.format(**context)
        self.log.info(f"Rendered S3 source file key : {formatted_key}")
        s3_url = f"s3://{self.s3_src_bucket_name}/{formatted_key}"
        self.log.debug(f"S3 URL : {s3_url}")
        formatted_sql = self._sql.format(
            table=self.table,
            source=s3_url,
            access_key=aws_credentials.access_key,
            secret_access_key=aws_credentials.secret_key,
            format=format_clause,
        )
        self.log.debug(f"Base SQL: {self._sql}")

        self.log.info(f"Copying data from S3 to Redshift table {self.table}...")
        redshift_conn.run(formatted_sql)
        self.log.info(f"Finished copying data from S3 to Redshift table {self.table}")
コード例 #7
0
    def test_bulk_dump(self):
        """Round-trip three values: INSERT them, bulk_dump the table, read the dump back."""
        hook = PostgresHook()
        input_data = ["foo", "bar", "baz"]
        create_sql = "CREATE TABLE {} (c VARCHAR)".format(self.table)
        insert_sql = "INSERT INTO {} VALUES {}".format(
            self.table, ",".join("('{}')".format(value) for value in input_data))

        with hook.get_conn() as conn:
            with conn.cursor() as cur:
                for statement in (create_sql, insert_sql):
                    cur.execute(statement)
                conn.commit()

                with NamedTemporaryFile() as dump:
                    hook.bulk_dump(self.table, dump.name)
                    dump.seek(0)
                    results = [line.rstrip().decode("utf-8") for line in dump.readlines()]

        self.assertEqual(sorted(input_data), sorted(results))
コード例 #8
0
 def capture_export_wrap(ds, **kwargs):
     """Export today's captures for organization 178 to CKAN.

     Reads the ``CKAN_DOMAIN``, ``CKAN_DATASET_NAME`` and ``CKAN_API_KEY``
     Airflow Variables and calls ``capture_export`` with a connection from the
     ``postgresConnId`` Postgres connection and today's date (``%Y-%m-%d``).

     Returns:
         0 on success.

     Raises:
         ValueError: if any Variable is empty or the export fails. The
             original exception is chained as ``__cause__``.
     """
     # NOTE(review): unused in this function; kept in case importing it matters.
     from lib.utils import print_time
     db = PostgresHook(postgres_conn_id=postgresConnId)
     conn = db.get_conn()
     try:
         date = datetime.now().strftime("%Y-%m-%d")
         print("date:", date)
         ckan_config = {}
         for name in ("CKAN_DOMAIN", "CKAN_DATASET_NAME", "CKAN_API_KEY"):
             value = Variable.get(name)
             # Explicit check rather than ``assert`` so it still runs under ``python -O``.
             if not value:
                 raise ValueError(f"Airflow Variable {name} is empty")
             ckan_config[name] = value
         # Never write the API key into task logs.
         print("ckan_config:", {**ckan_config, "CKAN_API_KEY": "***"})
         capture_export(conn, date, 178, ckan_config)
         return 0
     except Exception as e:
         print("get error when exec SQL:", e)
         raise ValueError('Error executing query') from e
     finally:
         conn.close()
コード例 #9
0
 def execute(self, context, testing=False):
     """Run the configured data-quality queries against each table.

     For every table in ``self._tables``, each query in ``self._queries`` is
     executed with a ``RealDictCursor``, and the ``count`` column of its first
     row is recorded. A missing row or a falsy count (``None``/``0``) fails
     the check.

     Args:
         context: Airflow task context (unused).
         testing: when True, the connection is not closed afterwards.

     Returns:
         dict mapping each query name to its count for the last table
         processed (an empty dict when ``self._tables`` is empty).

     Raises:
         Exception: "DataQualityCheckOperator FAILED." on any failure. The
             underlying error is chained as ``__cause__``.
     """
     self.log.info('DataQualityCheckOperator Starting...')
     self.log.info("Initializing Postgres Master DB Connection...")
     psql_hook = PostgresHook(postgres_conn_id=self._postgres_conn_id)
     # Bound before ``try``. Previously a get_conn() failure made ``finally``
     # raise UnboundLocalError, and an empty table list broke ``return``.
     conn = None
     data_quality = dict()
     try:
         conn = psql_hook.get_conn()
         cursor = conn.cursor(cursor_factory=RealDictCursor)
         for table in self._tables:
             data_quality = dict()
             for name, query in self._queries.items():
                 # NOTE(review): ``query`` is not parameterised by ``table``, so
                 # every table runs identical SQL — confirm this is intended.
                 self.log.info(f"Running query: {query}")
                 cursor.execute(query)
                 row = cursor.fetchone()
                 result = row.get('count') if row else None
                 if not result:
                     error = ("Data quality check FAILED. "
                              f"{table} returned no results "
                              f"for query: {name}")
                     self.log.error(error)
                     raise ValueError(error)
                 data_quality[name] = result
             self.log.info(
                 f"Data quality check on table '{table}' PASSED\n"
                 "Results Summary:\n"
                 f"{json.dumps(data_quality, indent=4, sort_keys=True)}")
     except Exception as err:
         # Covers InterfaceError/OperationalError and the ValueError above; the
         # traceback is logged so the real cause is not lost.
         self.log.error("DataQualityCheckOperator FAILED.")
         self.log.error(traceback.format_exc())
         raise Exception("DataQualityCheckOperator FAILED.") from err
     finally:
         if not testing and conn is not None:
             conn.close()
     self.log.info('DataQualityCheckOperator SUCCESS!')
     return data_quality
コード例 #10
0
    def execute(self, context) -> None:
        """Load S3 data into ``schema.table`` with a Redshift COPY.

        ``self.method`` selects the strategy:

        * ``REPLACE``: delete all rows and COPY inside one BEGIN/COMMIT.
        * ``UPSERT``: COPY into a temp table ``#<table>`` created LIKE the
          destination, then delete matching rows and insert from it. Match keys
          come from ``upsert_keys`` or, failing that, the table's primary key.
        * anything else: a plain COPY into the destination.

        Raises:
            AirflowException: for UPSERT when no keys are configured and the
                table has no primary key.
        """
        postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        credentials_block = build_credentials_block(s3_hook.get_credentials())
        copy_options = '\n\t\t\t'.join(self.copy_options)
        destination = f'{self.schema}.{self.table}'
        is_upsert = self.method == 'UPSERT'
        copy_destination = f'#{self.table}' if is_upsert else destination

        copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options)

        if self.method == 'REPLACE':
            sql = f"""
            BEGIN;
            DELETE FROM {destination};
            {copy_statement}
            COMMIT
            """
        elif is_upsert:
            keys = self.upsert_keys or postgres_hook.get_table_primary_key(self.table, self.schema)
            if not keys:
                raise AirflowException(
                    f"No primary key on {destination}. Please provide keys on 'upsert_keys'"
                )
            where_statement = ' AND '.join(
                f'{self.table}.{key} = {copy_destination}.{key}' for key in keys
            )
            sql = f"""
            CREATE TABLE {copy_destination} (LIKE {destination});
            {copy_statement}
            BEGIN;
            DELETE FROM {destination} USING {copy_destination} WHERE {where_statement};
            INSERT INTO {destination} SELECT * FROM {copy_destination};
            COMMIT
            """
        else:
            sql = copy_statement

        self.log.info('Executing COPY command...')
        postgres_hook.run(sql, self.autocommit)
        self.log.info("COPY command complete...")
コード例 #11
0
    def execute(self, context, testing=False):
        """Upsert every document from the Mongo staging collection into Postgres.

        Each document is passed through ``self._processor.process_item``. Its
        Mongo ``_id`` is moved into a string ``staging_id`` field, and the
        document is used as the parameters of ``self._sql_query``. All rows are
        committed in a single transaction at the end.

        Args:
            context: Airflow task context (unused).
            testing: when True, the Postgres and Mongo connections are not closed.

        Raises:
            Exception: "LoadToMasterdbOperator FAILED." on any error while
                reading or writing. The original error is chained as ``__cause__``.
        """
        self.log.info('LoadToMasterdbOperator Starting...')
        self.log.info("Initializing Mongo Staging DB Connection...")
        mongo_hook = MongoHook(conn_id=self._mongo_conn_id)
        ports_collection = mongo_hook.get_collection(self._mongo_collection)
        self.log.info("Initializing Postgres Master DB Connection...")
        psql_hook = PostgresHook(postgres_conn_id=self._postgres_conn_id)
        psql_conn = psql_hook.get_conn()
        psql_cursor = psql_conn.cursor()
        self.log.info("Loading Staging data to Master Database...")
        # Explicit counter: the former ``idx + 1`` was unbound for an empty collection.
        upserted = 0
        try:
            for document in ports_collection.find({}):
                document = self._processor.process_item(document)
                document['staging_id'] = str(document.get('_id'))
                document.pop('_id')
                psql_cursor.execute(self._sql_query, document)
                upserted += 1
            psql_conn.commit()
        except (OperationalError, UndefinedTable, OperationFailure) as err:
            self.log.error("Writting to database FAILED.")
            self.log.error(traceback.format_exc())
            raise Exception("LoadToMasterdbOperator FAILED.") from err
        except Exception as err:
            self.log.error(traceback.format_exc())
            raise Exception("LoadToMasterdbOperator FAILED.") from err
        finally:
            if not testing:
                self.log.info('Closing database connections...')
                psql_conn.close()
                mongo_hook.close_conn()
        self.log.info(f'UPSERTED {upserted} records into Postgres Database.')
        self.log.info('LoadToMasterdbOperator SUCCESS!')
コード例 #12
0
    def execute(self, context=None):
        """Render and execute each statement in ``self.commands_stripped`` in order.

        Each statement is wrapped in ``S.SQL`` (``S`` is presumably
        ``psycopg2.sql``). When ``self.params_sql`` is not None it is applied
        with ``SQL.format(**params_sql)``, so its values should be composables
        (e.g. ``Identifier``). Every rendered statement is logged and then run
        through ``PostgresHook.run``.

        Args:
            context: Airflow task context (unused).
        """
        from contextlib import closing

        if self.params_sql is not None:
            commands_formatted = [
                S.SQL(q).format(**self.params_sql)
                for q in self.commands_stripped
            ]
        else:
            commands_formatted = [S.SQL(q) for q in self.commands_stripped]
        hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        # ``as_string`` needs a live connection for quoting. Open one for all
        # statements and close it afterwards; previously get_conn() was called
        # once per statement and never closed.
        with closing(hook.get_conn()) as render_conn:
            for qf in commands_formatted:
                self.log.info("Executing Query:{}".format(qf.as_string(render_conn)))
                hook.run((qf, ))
コード例 #13
0
    def execute(self, context) -> None:
        """Run a Redshift COPY from S3, optionally emptying the table first.

        With ``self.truncate_table`` set, a ``DELETE FROM schema.table`` and the
        COPY are wrapped in a single BEGIN/COMMIT; otherwise only the COPY runs.
        """
        postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        credentials_block = build_credentials_block(s3_hook.get_credentials())
        copy_statement = self._build_copy_query(
            credentials_block, '\n\t\t\t'.join(self.copy_options))

        sql = copy_statement
        if self.truncate_table:
            delete_statement = f'DELETE FROM {self.schema}.{self.table};'
            sql = f"""
            BEGIN;
            {delete_statement}
            {copy_statement}
            COMMIT
            """

        self.log.info('Executing COPY command...')
        postgres_hook.run(sql, self.autocommit)
        self.log.info("COPY command complete...")
コード例 #14
0
    def execute(self, context) -> None:
        """Run a Redshift UNLOAD of ``self.select_query`` to ``self.s3_key``.

        Authentication uses ``aws_iam_role=<role_arn>`` when the AWS
        connection's extras define a truthy ``role_arn``. Otherwise the S3
        hook's credentials are turned into a credentials block.
        """
        postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        extras = S3Hook.get_connection(conn_id=self.aws_conn_id).extra_dejson

        if not extras.get('role_arn', False):
            s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
            credentials_block = build_credentials_block(s3_hook.get_credentials())
        else:
            credentials_block = f"aws_iam_role={extras['role_arn']}"

        unload_query = self._build_unload_query(
            credentials_block,
            self.select_query,
            self.s3_key,
            '\n\t\t\t'.join(self.unload_options),
        )

        self.log.info('Executing UNLOAD command...')
        postgres_hook.run(unload_query, self.autocommit, parameters=self.parameters)
        self.log.info("UNLOAD command complete...")
    def execute(self, context):
        """Export the result of ``self._query`` to S3 as a fully-quoted CSV.

        The first CSV row holds the column names from the cursor description,
        and rows are terminated with ``os.linesep``. The whole result set is
        buffered in memory before upload, and the object at
        ``self._s3_bucket``/``self._s3_key`` is overwritten.
        """
        from contextlib import closing

        postgres_hook = PostgresHook(postgres_conn_id=self._postgres_conn_id)
        s3_hook = S3Hook(aws_conn_id=self._s3_conn_id)

        # get_cursor() opened a connection that nothing closed; own it explicitly
        # so it is released even when the query fails.
        with closing(postgres_hook.get_conn()) as conn:
            with conn.cursor() as cursor:
                cursor.execute(self._query)
                results = cursor.fetchall()
                headers = [column[0] for column in cursor.description]

        data_buffer = io.StringIO()
        csv_writer = csv.writer(data_buffer,
                                quoting=csv.QUOTE_ALL,
                                lineterminator=os.linesep)
        csv_writer.writerow(headers)
        csv_writer.writerows(results)
        data_buffer_binary = io.BytesIO(data_buffer.getvalue().encode())

        s3_hook.load_file_obj(
            file_obj=data_buffer_binary,
            bucket_name=self._s3_bucket,
            key=self._s3_key,
            replace=True,
        )
コード例 #16
0
    def execute(self, context):
        """COPY ``s3://<bucket>/<key>/<table>`` into ``schema.table`` on Redshift.

        The hooks are stored on ``self.hook`` and ``self.s3``. The COPY
        authenticates with the S3 hook's access/secret key pair.
        """
        self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        credentials = self.s3.get_credentials()

        # NOTE(review): the keys are embedded in the SQL text; if the hook logs
        # statements they could leak — consider an IAM role instead.
        copy_template = """
            COPY {schema}.{table}
            FROM 's3://{s3_bucket}/{s3_key}/{table}'
            with credentials
            'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
            {copy_options};
        """
        fields = dict(
            schema=self.schema,
            table=self.table,
            s3_bucket=self.s3_bucket,
            s3_key=self.s3_key,
            access_key=credentials.access_key,
            secret_key=credentials.secret_key,
            copy_options='\n\t\t\t'.join(self.copy_options),
        )
        copy_query = copy_template.format(**fields)

        self.log.info('Executing COPY command...')
        self.hook.run(copy_query, self.autocommit)
        self.log.info("COPY command complete...")
    def execute(self, context):
        """Fill the fact table ``self.table`` using ``self.insert_sql_query``.

        The statement is the class-level ``LoadFactOperator.insert_sql``
        template formatted with the table name and ``self.insert_sql_query``,
        executed through ``PostgresHook.run``.

        Args:
            context: Airflow task context (unused).

        Returns:
            None
        """
        postgres = PostgresHook(postgres_conn_id=self.postgres_conn_id)

        # Render the INSERT template for this fact table and run it.
        formatted_sql = LoadFactOperator.insert_sql.format(
            self.table, self.insert_sql_query)
        postgres.run(formatted_sql)

        # The message previously said "dimension table" in the fact loader.
        self.log.info(
            'LoadFactOperator for fact table {} completed'.format(self.table))
コード例 #18
0
 def execute(self, context):
     """Stream the MySQL result of ``self.sql`` into ``self.postgres_table``.

     Rows are fetched ``self.rows_chunk`` at a time and passed to
     ``PostgresHook.insert_rows`` with the source column names as target
     fields, until the MySQL cursor returns an empty batch. The total number
     of inserted rows is logged at the end.
     """
     source = MySqlHook(mysql_conn_id=self.mysql_conn_id)
     target = PostgresHook(postgres_conn_id=self.postgres_conn_id)
     with closing(source.get_conn()) as conn, closing(conn.cursor()) as cursor:
         cursor.execute(self.sql, self.params)
         columns = [description[0] for description in cursor.description]
         inserted = 0
         while True:
             batch = cursor.fetchmany(self.rows_chunk)
             if len(batch) == 0:
                 break
             inserted += len(batch)
             target.insert_rows(
                 self.postgres_table,
                 batch,
                 target_fields=columns,
                 commit_every=self.rows_chunk,
             )
         self.log.info(
             f"{inserted} row(s) inserted into {self.postgres_table}.")
コード例 #19
0
    def tearDown(self):
        """Run the base-class teardown, then drop this test's table if present."""
        super().tearDown()

        drop_sql = "DROP TABLE IF EXISTS {}".format(self.table)
        with PostgresHook().get_conn() as conn, conn.cursor() as cur:
            cur.execute(drop_sql)
コード例 #20
0
 def tearDownClass(cls):
     """Drop every table listed in ``TABLES`` (with CASCADE)."""
     with PostgresHook().get_conn() as conn:
         with conn.cursor() as cur:
             for table_name in TABLES:
                 cur.execute(f"DROP TABLE IF EXISTS {table_name} CASCADE;")
    def earnings_report(ds, **kwargs):
        """Calculate and record earnings for approved, active, unpaid captures.

        Selects approved, active trees with no ``earnings_id`` whose planters
        belong to organization 178 (via ``getEntityRelationshipChildren``) and
        were created between 2021-09-01 and 2021-11-12. The trees are grouped
        per ``person_id``/``stakeholder_uuid``. For each group an
        ``earnings.earnings`` row is inserted, and the group's trees are linked
        to it through ``trees.earnings_id``. Everything is committed in one
        transaction.

        Earnings: the capture count is rounded down to whole hundreds, divided
        by 1000 and capped at 1. The result is multiplied by a maximum payout
        of 1,200,000 SLL.

        Returns:
            0 on success.

        Raises:
            ValueError: if any statement fails. The original error is chained
                as ``__cause__``; uncommitted work is discarded when the
                connection is closed.
        """
        db = PostgresHook(postgres_conn_id=postgresConnId)
        conn = db.get_conn()
        print("db:", conn)
        cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
        # Bound before ``try`` so the error handler can use it even when the
        # SELECT fails (previously that path raised UnboundLocalError).
        update_cursor = None
        try:
            # these hard coded values are placeholders for the upcoming contracts system
            freetown_stakeholder_uuid = "2a34fa81-0683-4d25-94b9-24843ceec3c4"
            freetown_base_contract_uuid = "483a1f4e-0c52-4b53-b917-5ff4311ded26"
            freetown_base_contract_consolidation_uuid = "a2dc79ec-4556-4cc5-bff1-2dbb5fd35b51"
            max_payout = 1200000
            earnings_currency = 'SLL'

            # NOTE(review): organization id and date window are hard-coded.
            cursor.execute("""
              SELECT COUNT(tree_id) capture_count,
              person_id,
              stakeholder_uuid,
              MIN(time_created) consolidation_start_date,
              MAX(time_created) consolidation_end_date,
              ARRAY_AGG(tree_id) tree_ids
              FROM (
                SELECT trees.id tree_id, person_id, time_created,
                stakeholder_uuid,
                rank() OVER (
                  PARTITION BY person_id
                  ORDER BY time_created ASC
                )
                FROM trees
                JOIN planter
                ON trees.planter_id = planter.id
                JOIN entity
                ON entity.id = planter.person_id
                AND earnings_id IS NULL
                AND planter.organization_id IN (
                  select entity_id from getEntityRelationshipChildren(178)
                )
                AND time_created > TO_TIMESTAMP(
                  '2021-09-01 00:00:00',
                  'YYYY-MM-DD HH24:MI:SS'
                )
                AND time_created <  TO_TIMESTAMP(
                  '2021-11-12 00:00:00',
                  'YYYY-MM-DD HH24:MI:SS'
                )
                AND trees.approved = true
                AND trees.active = true
              ) rank
              GROUP BY person_id, stakeholder_uuid
              ORDER BY person_id;
            """)
            print("SQL result:", cursor.query)

            # One plain cursor for all INSERT/UPDATE statements (was one per row).
            update_cursor = conn.cursor()
            for row in cursor:
                print(row)

                # calculate the earnings based on FCC logic
                capture_count = row['capture_count']
                multiplier = min((capture_count - capture_count % 100) / 10 / 100, 1)
                print("multiplier " + str(multiplier))

                earnings = multiplier * max_payout

                update_cursor.execute("""
                  INSERT INTO earnings.earnings(
                    worker_id,
                    contract_id,
                    funder_id,
                    currency,
                    amount,
                    calculated_at,
                    consolidation_rule_id,
                    consolidation_period_start,
                    consolidation_period_end,
                    status
                    )
                  VALUES(
                    %s,
                    %s,
                    %s,
                    %s,
                    %s,
                    NOW(),
                    %s,
                    %s,
                    %s,
                    'calculated'
                  )
                  RETURNING *
              """, (row['stakeholder_uuid'],
                     freetown_base_contract_uuid,
                     freetown_stakeholder_uuid,
                     earnings_currency,
                     earnings,
                     freetown_base_contract_consolidation_uuid,
                     row['consolidation_start_date'],
                     row['consolidation_end_date']))
                print("SQL result:", update_cursor.query)

                # First column of the RETURNING * row — presumably earnings.id; verify schema.
                earnings_id = update_cursor.fetchone()[0]
                print(earnings_id)
                update_cursor.execute("""
                  UPDATE trees
                  SET earnings_id = %s
                  WHERE id = ANY(%s)
                """, (earnings_id, row['tree_ids']))

            conn.commit()
            return 0
        except Exception as e:
            print("get error when exec SQL:", e)
            if update_cursor is not None:
                print("SQL result:", update_cursor.query)
            raise ValueError('Error executing query') from e
        finally:
            # Closing without a commit discards any partial work from a failed run.
            conn.close()
コード例 #22
0
 def execute(self, context):
     """Run ``self.sql`` on the configured database and log the server notices.

     The hook is kept on ``self.hook``. After ``run``, every entry in
     ``self.hook.conn.notices`` is written to the task log.
     """
     hook = PostgresHook(postgres_conn_id=self.postgres_conn_id, schema=self.database)
     self.hook = hook
     hook.run(self.sql, self.autocommit, parameters=self.parameters)
     for notice in hook.conn.notices:
         self.log.info(notice)
コード例 #23
0
 def setUp(self):
     """Create one hook per test connection: ``oltp`` first, then ``olap``."""
     self.oltp_hook = PostgresHook('oltp')
     self.olap_hook = PostgresHook('olap')
コード例 #24
0
# Database name substituted into the Postgres connection URI below.
db_name = "userdata"
# NY health open-data JSON endpoint; ends in '?' so query parameters can
# presumably be appended by tasks later in this DAG (not visible here).
NY_API = "https://health.data.ny.gov/resource/xdss-u53e.json?"

# These args will get passed on to each operator
# You can override them on a per-task basis during operator initialization
# NOTE(review): 'dag_id' and 'schedule_interval' are DAG-level settings. Inside
# default_args they reach operators, not the DAG, and DAG(...) below passes
# no schedule_interval — so the '0 9 * * *' cron is presumably not applied.
# Confirm and move it into the DAG(...) call.
default_args = {
    'owner': 'Anil',
    'dag_id': 'LOAD_NY_COVID_DLY',
    'start_date': datetime(2020, 3, 1, tzinfo=local_tz),
    'schedule_interval': '0 9 * * *'
}

# Use the Postgres hook to get the connection URI and swap in the database name.
# NOTE(review): this runs at module import (i.e. every DAG parse). Index 3 of
# the '/'-split is the path segment for a URI shaped like scheme://host/db; any
# '?query' suffix on that segment would be overwritten — confirm.
result = PostgresHook(postgres_conn_id='postgres_new').get_uri().split("/")
result[3] = db_name
dbURI = "/".join(result)

with DAG('LOAD_NY_COVID_DLY',
         default_args=default_args,
         catchup=False,
         template_searchpath='/opt/airflow/') as dag:

    @dag.task
    def getTodayDate():
        """Build a dict holding this run's ``ds`` date string under ``test_date``.

        NOTE(review): as shown, the function ends without returning ``context``,
        so the task returns None — this snippet appears truncated; confirm.
        """
        context = {"test_date": get_current_context()["ds"]}
コード例 #25
0
ファイル: base_download.py プロジェクト: VinnieJon/DE
import os
from airflow.operators.python_operator import PythonOperator
from airflow.providers.postgres.hooks.postgres import PostgresHook
import csv

PostgresConn = PostgresHook(postgres_conn_id='postgresql_conn')

# Directory (under the current working directory) that receives the CSV extracts.
_BASE_DATA_DIR = 'base_data'


def getconnection():
    """Open a connection to ``postgresql_conn`` as a connectivity check, then close it."""
    # Previously the connection was opened and never closed.
    PostgresConn.get_conn().close()
    print("connected")


def _export_table_to_csv(table, echo=False):
    """Write every row of ``table`` to ``base_data/<table>.csv`` (no header row).

    Args:
        table: trusted, hard-coded table name; it is interpolated into the SQL.
        echo: print the fetched records before writing them.
    """
    records = PostgresConn.get_records(sql='SELECT * FROM {}'.format(table))
    if echo:
        print(records)
    out_dir = os.path.join(os.getcwd(), _BASE_DATA_DIR)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(out_dir, exist_ok=True)
    # newline='' is required by the csv module; otherwise rows get extra blank
    # lines on platforms that translate '\n'.
    with open(os.path.join(out_dir, '{}.csv'.format(table)), 'w', newline='') as f:
        csv.writer(f).writerows(records)


def writerrecords_aisles():
    """Export the ``aisles`` table to ``base_data/aisles.csv`` (records are also printed)."""
    _export_table_to_csv('aisles', echo=True)


def writerrecords_clients():
    """Export the ``clients`` table to ``base_data/clients.csv``."""
    _export_table_to_csv('clients')
コード例 #26
0
 def get_data(self):
     """Execute ``self.postgres_sql`` on ``self.postgres_conn_id`` and return a DataFrame.

     The DB-API connection is closed once ``pd.read_sql`` returns, whether or
     not the read succeeded.
     """
     hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
     connection = hook.get_conn()
     with closing(connection):
         frame = pd.read_sql(self.postgres_sql, connection)
     return frame
コード例 #27
0
 def tearDown(self):
     """Drop the tables created by the Postgres tests, then close the connection.

     ``with conn`` only commits (or rolls back) the transaction; psycopg2 does
     not close the connection when that context exits. The connection is
     therefore closed explicitly so that no connection leaks per test.
     """
     tables_to_drop = ['test_postgres_to_postgres', 'test_airflow']
     conn = PostgresHook().get_conn()
     try:
         with conn:
             with conn.cursor() as cur:
                 for table in tables_to_drop:
                     cur.execute(f"DROP TABLE IF EXISTS {table}")
     finally:
         conn.close()
コード例 #28
0
    def create_tokens(ds, **kwargs):
        """Create a wallet token for every eligible tree of an entity.

        Reads ``walletName``, ``entityId`` and ``dryRun`` from ``dag_run.conf``.
        A tree is eligible when all of the following hold:

        - its planter's organization is among
          ``getEntityRelationshipChildren(entityId)``;
        - it is active and approved;
        - it has no token yet.

        Eligible trees are processed in batches of 3000. For each one, a row is
        inserted into ``wallet.token`` and the tree's ``token_id`` is set to the
        new token id. The transaction is committed only when ``dryRun`` is falsy.

        Returns ``0`` when the work was committed, ``1`` on a dry run or on
        error. Returns ``None`` when a parameter is missing or the wallet does
        not exist.
        """
        conf = kwargs['dag_run'].conf
        walletName = conf.get('walletName')
        entityId = conf.get('entityId')
        dryRun = conf.get('dryRun')
        # print them out
        print('walletName:', walletName)
        print('entityId:', entityId)
        print('dryRun:', dryRun)
        # all three parameters are required
        if walletName is None:
            print('walletName is None')
            return
        if entityId is None:
            print('entityId is None')
            return
        if dryRun is None:
            print('dryRun is None')
            return

        batch_size = 3000
        result = 'pending'
        db = PostgresHook(postgres_conn_id='postgres_default')
        connection = db.get_conn()
        cursor = connection.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            # The conf values are supplied by whoever triggers the DAG, so every
            # query passes values as bound parameters instead of formatting
            # them into the SQL text.
            cursor.execute("SELECT * FROM wallet.wallet WHERE name = %s", (walletName,))
            wallet = cursor.fetchone()
            # check wallet exists
            if wallet is None:
                print('Wallet not found')
                return
            print('Wallet found', wallet)

            remaining = True

            # The range only caps the number of batches, as a safety net
            # against an endless loop.
            for _ in range(1, 100000):
                # if remaining is false, then we are done
                if not remaining:
                    break

                # fetch the next batch of untokenized trees of the entity
                cursor.execute("""
                    select id, uuid, token_id from trees
                    where
                        planter_id IN (
                            select id from planter
                            where
                                organization_id IN (
                                    select entity_id from getEntityRelationshipChildren(%s)
                                )
                            )
                            AND active = true
                            AND approved = true
                            AND token_id IS NULL
                            LIMIT %s
                    """, (entityId, batch_size))
                trees = cursor.fetchall()

                print('Trees found', len(trees))

                # a short batch means this is the last one
                if len(trees) < batch_size:
                    print('Not more trees')
                    remaining = False

                # for each tree, create a token
                for capture in trees:
                    print('capture', capture)

                    tokenData = {
                        'tree_id': capture['id'],
                        'capture_id': capture['uuid'],
                        'wallet_id': wallet['id'],
                    }

                    print('tokenData', tokenData)

                    # create token
                    cursor.execute("""
                        INSERT INTO wallet.token (
                            capture_id,
                            wallet_id
                        ) VALUES (
                            %s,
                            %s
                        ) RETURNING id
                    """, (tokenData['capture_id'], tokenData['wallet_id']))

                    token = cursor.fetchone()
                    print('token', token)
                    print('token[id]', token['id'])

                    # update tree with token id
                    cursor.execute(
                        "UPDATE trees SET token_id = %s WHERE id = %s",
                        (token['id'], capture['id']),
                    )

                    print('Token created: {}'.format(token))

            # if dryRun is false, then commit
            if not dryRun:
                connection.commit()
                print('Commit')
                result = 'success'
            else:
                # closing without commit discards every insert/update above
                print('Dry run, not committing')
                result = 'dry run'
        except Exception as e:
            print(e)
            result = 'error'
        finally:
            cursor.close()
            connection.close()
            print('result', result)

        # check result value, if success, return 0, else return 1
        if result == 'success':
            return 0
        else:
            return 1
コード例 #29
0
 def drop_db():
     """Execute ``DELETE_QUERY`` through a ``PostgresHook`` on its default connection."""
     PostgresHook().run(DELETE_QUERY)
コード例 #30
0
 def get_hook(self):
     """Return a hook instance matching this connection's ``conn_type``.

     Each provider module is imported only inside the branch that needs it,
     so providers for other connection types need not be importable. The hook
     is bound to ``self.conn_id`` through the keyword argument that particular
     hook class expects.

     :raises AirflowException: if ``conn_type`` is none of the types below.
     """
     conn_type = self.conn_type

     if conn_type == 'mysql':
         from airflow.providers.mysql.hooks.mysql import MySqlHook
         return MySqlHook(mysql_conn_id=self.conn_id)
     if conn_type == 'google_cloud_platform':
         from airflow.gcp.hooks.bigquery import BigQueryHook
         return BigQueryHook(bigquery_conn_id=self.conn_id)
     if conn_type == 'postgres':
         from airflow.providers.postgres.hooks.postgres import PostgresHook
         return PostgresHook(postgres_conn_id=self.conn_id)
     if conn_type == 'pig_cli':
         from airflow.providers.apache.pig.hooks.pig import PigCliHook
         return PigCliHook(pig_cli_conn_id=self.conn_id)
     if conn_type == 'hive_cli':
         from airflow.providers.apache.hive.hooks.hive import HiveCliHook
         return HiveCliHook(hive_cli_conn_id=self.conn_id)
     if conn_type == 'presto':
         from airflow.providers.presto.hooks.presto import PrestoHook
         return PrestoHook(presto_conn_id=self.conn_id)
     if conn_type == 'hiveserver2':
         from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
         return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
     if conn_type == 'sqlite':
         from airflow.providers.sqlite.hooks.sqlite import SqliteHook
         return SqliteHook(sqlite_conn_id=self.conn_id)
     if conn_type == 'jdbc':
         from airflow.providers.jdbc.hooks.jdbc import JdbcHook
         return JdbcHook(jdbc_conn_id=self.conn_id)
     if conn_type == 'mssql':
         from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
         return MsSqlHook(mssql_conn_id=self.conn_id)
     if conn_type == 'odbc':
         from airflow.providers.odbc.hooks.odbc import OdbcHook
         return OdbcHook(odbc_conn_id=self.conn_id)
     if conn_type == 'oracle':
         from airflow.providers.oracle.hooks.oracle import OracleHook
         return OracleHook(oracle_conn_id=self.conn_id)
     if conn_type == 'vertica':
         from airflow.providers.vertica.hooks.vertica import VerticaHook
         return VerticaHook(vertica_conn_id=self.conn_id)
     if conn_type == 'cloudant':
         from airflow.providers.cloudant.hooks.cloudant import CloudantHook
         return CloudantHook(cloudant_conn_id=self.conn_id)
     if conn_type == 'jira':
         from airflow.providers.jira.hooks.jira import JiraHook
         return JiraHook(jira_conn_id=self.conn_id)
     if conn_type == 'redis':
         from airflow.providers.redis.hooks.redis import RedisHook
         return RedisHook(redis_conn_id=self.conn_id)
     if conn_type == 'wasb':
         from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
         return WasbHook(wasb_conn_id=self.conn_id)
     if conn_type == 'docker':
         from airflow.providers.docker.hooks.docker import DockerHook
         return DockerHook(docker_conn_id=self.conn_id)
     if conn_type == 'azure_data_lake':
         from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
         return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
     if conn_type == 'azure_cosmos':
         from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook
         return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
     if conn_type == 'cassandra':
         from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
         return CassandraHook(cassandra_conn_id=self.conn_id)
     if conn_type == 'mongo':
         # MongoHook takes the generic ``conn_id`` keyword.
         from airflow.providers.mongo.hooks.mongo import MongoHook
         return MongoHook(conn_id=self.conn_id)
     if conn_type == 'gcpcloudsql':
         from airflow.gcp.hooks.cloud_sql import CloudSQLDatabaseHook
         return CloudSQLDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
     if conn_type == 'grpc':
         from airflow.providers.grpc.hooks.grpc import GrpcHook
         return GrpcHook(grpc_conn_id=self.conn_id)

     # No branch matched: the connection type has no known hook.
     raise AirflowException("Unknown hook type {}".format(conn_type))