def mongoclient(self):
    if not hasattr(self, "_mc"):
        if self.conn.ssh_conn_id:
            if self.conn.conn_style == "uri":
                raise Exception("Cannot have SSH tunnel with uri connection type!")
            if not hasattr(self, "_ssh_hook"):
                self._ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                    conn_id=self.conn.ssh_conn_id
                )
                self.local_bind_address = self._ssh_hook.start_tunnel(
                    self.conn.host, self.conn.port
                )
        else:
            self.local_bind_address = (self.conn.host, self.conn.port)

        conn_kwargs = {"tz_aware": True}

        if self.conn.conn_style == "uri":
            conn_kwargs["host"] = self.conn.uri
        else:
            conn_kwargs["host"] = self.local_bind_address[0]
            conn_kwargs["port"] = self.local_bind_address[1]
            if self.conn.username:
                conn_kwargs["username"] = self.conn.username
            if self.conn.password:
                conn_kwargs["password"] = self.conn.password

        with TemporaryDirectory() as tmp_dir:
            if self.conn.tls:
                conn_kwargs["tls"] = True
            with NamedTemporaryFile(dir=tmp_dir) as ssl_cert:
                with NamedTemporaryFile(dir=tmp_dir) as ssl_private:
                    if self.conn.ssl_cert:
                        ssl_cert.write(self.conn.ssl_cert.encode())
                        ssl_cert.seek(0)
                        conn_kwargs["ssl_certfile"] = os.path.abspath(ssl_cert.name)
                    if self.conn.ssl_private:
                        ssl_private.write(self.conn.ssl_private.encode())
                        ssl_private.seek(0)
                        conn_kwargs["ssl_keyfile"] = os.path.abspath(ssl_private.name)
                    if self.conn.ssl_password:
                        conn_kwargs["tlsCertificateKeyFilePassword"] = self.conn.ssl_password
                    if self.conn.tls_insecure:
                        conn_kwargs["tlsInsecure"] = True
                    if self.conn.auth_source:
                        conn_kwargs["authSource"] = self.conn.auth_source
                    if self.conn.auth_mechanism:
                        conn_kwargs["authMechanism"] = self.conn.auth_mechanism
                    self._mc = MongoClient(**conn_kwargs)

    return self._mc

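# Illustration only (not part of the hook): for a host/port style connection
# without TLS, the kwargs assembled above reduce to roughly the following.
# All values below are made up; pymongo's MongoClient accepts these keyword
# arguments.
from pymongo import MongoClient

example_kwargs = {
    "tz_aware": True,
    "host": "localhost",      # local bind host (e.g. the SSH tunnel endpoint)
    "port": 27017,            # local bind port
    "username": "read_only",  # only set if the connection defines credentials
    "password": "secret",
}
client = MongoClient(**example_kwargs)
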
def dbconn(self):
    if not hasattr(self, "_dbconn"):
        if hasattr(self.conn, "ssh_conn_id") and self.conn.ssh_conn_id:
            if not hasattr(self, "_ssh_hook"):
                self._ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                    conn_id=self.conn.ssh_conn_id
                )
                self.local_bind_address = self._ssh_hook.start_tunnel(
                    self.conn.host, self.conn.port or self._DEFAULT_PORT
                )
        else:
            self.local_bind_address = (self.conn.host, self.conn.port)
        self._dbconn = self._get_db_conn()
    return self._dbconn

def kickoff_func(schema_full, schema_suffix, dwh_conn_id):
    # kickoff: create new dataset
    schema = schema_full + schema_suffix
    conn = EWAHBaseHook.get_hook_from_conn_id(dwh_conn_id).dbconn
    # delete dataset first if it already exists
    print("Deleting the dataset {0} if it already exists.".format(schema))
    conn.delete_dataset(schema, delete_contents=True, not_found_ok=True)
    print("Creating the dataset {0}.".format(schema))
    conn.create_dataset(schema)
    print("Done!")

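# Usage sketch (assumption, not taken from the source): a function like this
# would typically be wrapped in a PythonOperator inside a DAG. The task id,
# schema names, and connection id below are hypothetical.
from airflow.operators.python import PythonOperator

kickoff_task = PythonOperator(
    task_id="kickoff_bigquery_dataset",
    python_callable=kickoff_func,
    op_kwargs={
        "schema_full": "analytics",
        "schema_suffix": "_next",
        "dwh_conn_id": "bigquery_dwh",  # hypothetical connection id
    },
)
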
def execute(self, context):
    # env to be used in processes later
    env = os.environ.copy()

    # create a new temp folder, all action happens in here
    with TemporaryDirectory(prefix="__ewah_dbt_operator_") as tmp_dir:
        # clone repo into temp directory
        repo_dir = tmp_dir + os.path.sep + "repo"
        if self.repo_type == "git":
            git_hook = EWAHBaseHook.get_hook_from_conn_id(conn_id=self.git_conn_id)
            git_hook.clone_repo(repo_dir, env)
        else:
            raise Exception("Not Implemented!")

        # create a virtual environment in temp folder
        venv_folder = tmp_dir + os.path.sep + "venv"
        self.log.info(
            "creating a new virtual environment in {0}...".format(venv_folder)
        )
        venv.create(venv_folder, with_pip=True)

        # install dbt into created venv
        self.log.info("installing dbt=={0}".format(self.dbt_version))
        cmd = []
        cmd.append("source {0}/bin/activate".format(venv_folder))
        cmd.append("pip install --quiet --upgrade dbt=={0}".format(self.dbt_version))
        cmd.append("dbt --version")
        cmd.append("deactivate")
        assert run_cmd(cmd, env, self.log.info) == 0

        dbt_dir = repo_dir
        if self.subfolder:
            if not self.subfolder[:1] == os.path.sep:
                self.subfolder = os.path.sep + self.subfolder
            dbt_dir += self.subfolder

        dwh_conn = EWAHBaseHook.get_connection(self.dwh_conn_id)

        # read profile name & create temporary profiles.yml
        project_yml_file = dbt_dir
        if not project_yml_file[-1:] == os.path.sep:
            project_yml_file += os.path.sep
        project_yml_file += "dbt_project.yml"
        project_yml = yaml.load(open(project_yml_file, "r"), Loader=Loader)
        profile_name = project_yml["profile"]

        self.log.info('Creating temp profile "{0}"'.format(profile_name))
        profiles_yml = {
            "config": {
                "send_anonymous_usage_stats": False,
                "use_colors": False,  # colors won't be useful in logs
            },
        }
        if self.dwh_engine == EC.DWH_ENGINE_POSTGRES:
            profiles_yml[profile_name] = {
                "target": "prod",  # same as the output defined below
                "outputs": {
                    "prod": {  # for postgres
                        "type": "postgres",
                        "host": dwh_conn.host,
                        "port": dwh_conn.port or "5432",
                        "user": dwh_conn.login,
                        "pass": dwh_conn.password,
                        "dbname": dwh_conn.schema,
                        "schema": self.schema_name,
                        "threads": self.threads,
                        "keepalives_idle": self.keepalives_idle,
                    },
                },
            }
        elif self.dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
            profiles_yml[profile_name] = {
                "target": "prod",  # same as the output defined below
                "outputs": {
                    "prod": {  # for snowflake
                        "type": "snowflake",
                        "account": dwh_conn.account,
                        "user": dwh_conn.user,
                        "password": dwh_conn.password,
                        "role": dwh_conn.role,
                        "database": self.database_name or dwh_conn.database,
                        "warehouse": dwh_conn.warehouse,
                        "schema": self.schema_name or dwh_conn.schema,
                        "threads": self.threads,
                        "keepalives_idle": self.keepalives_idle,
                    },
                },
            }
        else:
            raise Exception("DWH Engine not implemented!")

        # run commands with correct profile in the venv in the temp folder
        profiles_yml_name = tmp_dir + os.path.sep + "profiles.yml"
        env["DBT_PROFILES_DIR"] = os.path.abspath(tmp_dir)
        with open(profiles_yml_name, "w") as profiles_file:
            # write profile into profiles.yml file
            yaml.dump(profiles_yml, profiles_file, default_flow_style=False)

            # run dbt commands
            self.log.info("Now running dbt commands!")
            cmd = []
            cmd.append("cd {0}".format(dbt_dir))
            cmd.append("source {0}/bin/activate".format(venv_folder))
            cmd.append("dbt deps")
            for dc in self.dbt_commands:
                cmd.append("dbt {0}".format(dc))
            cmd.append("deactivate")
            assert run_cmd(cmd, env, self.log.info) == 0

    # if applicable: close SSH tunnel
    if hasattr(self, "ssh_tunnel_forwarder"):
        self.log.info("Stopping!")
        self.ssh_tunnel_forwarder.stop()
        del self.ssh_tunnel_forwarder

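# Hedged sketch of the run_cmd helper assumed above and in the next snippet
# (not the actual EWAH implementation): it has to run all commands in a single
# bash session so that "source .../activate" affects the subsequent pip and
# dbt calls, stream output to the supplied logger, and return the exit code.
import subprocess

def run_cmd(cmd, env, logger):
    process = subprocess.Popen(
        " && ".join(cmd),
        shell=True,
        executable="/bin/bash",
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    # stream the combined stdout/stderr line by line to the logger
    for line in iter(process.stdout.readline, b""):
        logger(line.decode(errors="replace").rstrip())
    return process.wait()
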
def execute(self, context):
    # env to be used in processes later
    env = os.environ.copy()
    env["PIP_USER"] = "0"  # disable user-site installs so pip targets the venv

    # create a new temp folder, all action happens in here
    with TemporaryDirectory(prefix="__ewah_dbt_operator_") as tmp_dir:
        # clone repo into temp directory
        repo_dir = tmp_dir + os.path.sep + "repo"
        if self.repo_type == "git":
            # Clone repo into temp folder
            git_hook = EWAHBaseHook.get_hook_from_conn_id(conn_id=self.git_conn_id)
            git_hook.clone_repo(repo_dir, env)
        elif self.repo_type == "local":
            # Copy local version of the repository into temp folder
            copy_tree(self.local_path, repo_dir)
        else:
            raise Exception("Not Implemented!")

        # create a virtual environment in temp folder
        venv_folder = tmp_dir + os.path.sep + "venv"
        self.log.info(
            "creating a new virtual environment in {0}...".format(venv_folder)
        )
        venv.create(venv_folder, with_pip=True)

        dbt_dir = repo_dir
        if self.subfolder:
            if not self.subfolder[:1] == os.path.sep:
                self.subfolder = os.path.sep + self.subfolder
            dbt_dir += self.subfolder

        dwh_hook = EWAHBaseHook.get_hook_from_conn_id(self.dwh_conn_id)
        # in case of SSH: execute a query to create the connection and tunnel
        dwh_hook.execute("SELECT 1 AS a -- Testing the connection")
        dwh_conn = dwh_hook.conn

        # read profile name and dbt version & create temporary profiles.yml
        project_yml_file = dbt_dir
        if not project_yml_file[-1:] == os.path.sep:
            project_yml_file += os.path.sep
        project_yml_file += "dbt_project.yml"
        project_yml = yaml.load(open(project_yml_file, "r"), Loader=Loader)
        profile_name = project_yml["profile"]
        dbt_version = self.dbt_version or project_yml.get("require-dbt-version")
        del self.dbt_version  # Make sure it can't accidentally be used below
        assert dbt_version, "Must supply dbt_version or set require-dbt-version!"

        if isinstance(dbt_version, str):
            if not dbt_version.startswith(("=", "<", ">")):
                dbt_version = "==" + dbt_version
        elif isinstance(dbt_version, list):
            dbt_version = ",".join(dbt_version)
        else:
            raise Exception("dbt_version must be a string or a list of strings!")

        self.log.info('Creating temp profile "{0}"'.format(profile_name))
        profiles_yml = {
            "config": {
                "send_anonymous_usage_stats": False,
                "use_colors": False,  # colors won't be useful in logs
            },
        }
        if self.dwh_engine == EC.DWH_ENGINE_POSTGRES:
            mb_database = dwh_conn.schema
            profiles_yml[profile_name] = {
                "target": "prod",  # same as the output defined below
                "outputs": {
                    "prod": {  # for postgres
                        "type": "postgres",
                        "host": dwh_hook.local_bind_address[0],
                        "port": dwh_hook.local_bind_address[1],
                        "user": dwh_conn.login,
                        "pass": dwh_conn.password,
                        "dbname": dwh_conn.schema,
                        "schema": self.schema_name,
                        "threads": self.threads,
                        "keepalives_idle": self.keepalives_idle,
                    },
                },
            }
        elif self.dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
            mb_database = self.database_name or dwh_conn.database
            profiles_yml[profile_name] = {
                "target": "prod",  # same as the output defined below
                "outputs": {
                    "prod": {  # for snowflake
                        "type": "snowflake",
                        "account": dwh_conn.account,
                        "user": dwh_conn.user,
                        "password": dwh_conn.password,
                        "role": dwh_conn.role,
                        "database": self.database_name or dwh_conn.database,
                        "warehouse": dwh_conn.warehouse,
                        "schema": self.schema_name or dwh_conn.schema,
                        "threads": self.threads,
                        "keepalives_idle": self.keepalives_idle,
                    },
                },
            }
        elif self.dwh_engine == EC.DWH_ENGINE_BIGQUERY:
            mb_database = self.database_name
            profiles_yml[profile_name] = {
                "target": "prod",  # same as the output defined below
                "outputs": {
                    "prod": {  # for bigquery
                        "type": "bigquery",
                        "method": "service-account-json",
                        "project": self.database_name,
                        "dataset": self.schema_name,
                        "threads": self.threads,
                        "timeout_seconds": self.keepalives_idle or 300,
                        "priority": "interactive",
                        "keyfile_json": json.loads(dwh_conn.service_account_json),
                    },
                },
            }
            if dwh_conn.location:
                profiles_yml[profile_name]["outputs"]["prod"]["location"] = dwh_conn.location
        else:
            raise Exception("DWH Engine not implemented!")

        # install dbt into created venv
        cmd = []
        cmd.append("source {0}/bin/activate".format(venv_folder))
        cmd.append("pip install --quiet --upgrade pip setuptools")
        if re.search(r"[^0-9\.]0(\.[0-9]+)?(\.[0-9]+)?$", dbt_version):
            # regex checks whether the (last) version starts with 0
            # if true, version <1.0.0 required
            cmd.append(
                'pip install --quiet --upgrade "MarkupSafe<=2.0.1" "dbt{0}"'.format(
                    dbt_version
                )
            )
        else:
            # Different pip behavior since dbt 1.0.0
            cmd.append(
                'pip install --quiet --upgrade "MarkupSafe<=2.0.1" "dbt-{0}{1}"'.format(
                    {
                        EC.DWH_ENGINE_POSTGRES: "postgres",
                        EC.DWH_ENGINE_SNOWFLAKE: "snowflake",
                        EC.DWH_ENGINE_BIGQUERY: "bigquery",
                    }[self.dwh_engine],
                    dbt_version,
                )
            )
        cmd.append("dbt --version")
        cmd.append("deactivate")
        assert run_cmd(cmd, env, self.log.info) == 0

        # run commands with correct profile in the venv in the temp folder
        profiles_yml_name = tmp_dir + os.path.sep + "profiles.yml"
        env["DBT_PROFILES_DIR"] = os.path.abspath(tmp_dir)
        with open(profiles_yml_name, "w") as profiles_file:
            # write profile into profiles.yml file
            yaml.dump(profiles_yml, profiles_file, default_flow_style=False)

            # run dbt commands
            self.log.info("Now running dbt commands!")
            cmd = []
            cmd.append("cd {0}".format(dbt_dir))
            cmd.append("source {0}/bin/activate".format(venv_folder))
            if self.repo_type == "local":
                cmd.append("dbt clean")
            cmd.append("dbt deps")
            for dc in self.dbt_commands:
                cmd.append("dbt {0}".format(dc))
            cmd.append("deactivate")
            assert run_cmd(cmd, env, self.log.info) == 0

            if self.metabase_conn_id:
                # Push docs to Metabase at the end of the run!
                metabase_hook = EWAHBaseHook.get_hook_from_conn_id(
                    conn_id=self.metabase_conn_id
                )
                metabase_hook.push_dbt_docs_to_metabase(
                    dbt_project_path=dbt_dir,
                    dbt_database_name=mb_database,
                )

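# Worked example of the dbt_version normalisation above (inputs are assumed):
# a plain string gains an "==" pin, a list becomes a comma-joined specifier.
examples = ["1.0.0", ">=0.20.0", [">=0.19.0", "<0.21.0"]]
for dbt_version in examples:
    if isinstance(dbt_version, str):
        if not dbt_version.startswith(("=", "<", ">")):
            dbt_version = "==" + dbt_version
    elif isinstance(dbt_version, list):
        dbt_version = ",".join(dbt_version)
    print(dbt_version)  # ==1.0.0 / >=0.20.0 / >=0.19.0,<0.21.0
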
def final_func(schema_name, schema_suffix, dwh_conn_id):
    # final: move new data into the final dataset
    conn = EWAHBaseHook.get_hook_from_conn_id(dwh_conn_id).dbconn

    # get dataset objects
    try:  # create final dataset if not exists
        ds_final = conn.get_dataset(schema_name)
    except:
        print("Creating dataset {0}".format(schema_name))
        ds_final = conn.create_dataset(schema_name)
    ds_temp = conn.get_dataset(schema_name + schema_suffix)

    # copy all tables from temp dataset to final dataset
    new_tables = conn.list_tables(ds_temp)
    new_table_ids = [table.table_id for table in conn.list_tables(ds_temp)]
    old_table_ids = [table.table_id for table in conn.list_tables(ds_final)]
    copy_jobs = []
    for table in new_tables:
        print("Copying table {0} from temp to final dataset".format(table.table_id))
        try:
            old_table = conn.get_table(
                table=TableReference(dataset_ref=ds_final, table_id=table.table_id)
            )
            conn.delete_table(old_table)
        except:
            # ignore failure, fails if old table does not exist to begin with
            pass
        finally:
            final_table = ds_final.table(table.table_id)
            copy_jobs.append(conn.copy_table(table, final_table))

    # delete tables that don't exist in temp dataset from final dataset
    for table_id in old_table_ids:
        if table_id not in new_table_ids:
            print("Deleting table {0}".format(table_id))
            conn.delete_table(
                conn.get_table(
                    TableReference(dataset_ref=ds_final, table_id=table_id)
                )
            )

    # make sure all copy jobs succeeded
    while copy_jobs:
        sleep(0.1)
        job = copy_jobs.pop(0)
        job.result()
        assert job.state in ("RUNNING", "DONE")
        if job.state == "RUNNING":
            copy_jobs.append(job)
        else:
            print(
                "Successfully copied {0}".format(
                    job.__dict__["_properties"]["configuration"]["copy"][
                        "destinationTable"
                    ]["tableId"]
                )
            )

    # delete temp dataset
    print("Deleting temp dataset.")
    conn.delete_dataset(ds_temp, delete_contents=True, not_found_ok=False)
    print("Done.")

def start_tunnel(
    self, remote_host: str, remote_port: int, tunnel_timeout: Optional[int] = 30
) -> Tuple[str, int]:
    """Starts the SSH tunnel with port forwarding. Returns the host and port tuple
    that can be used to connect to the remote.

    :param remote_host: Host of the remote that the tunnel should port forward to.
        Can be "localhost" e.g. if tunneling into a server that hosts a database.
    :param remote_port: Port that goes along with remote_host.
    :param tunnel_timeout: Optional timeout setting. Supply a higher number if the
        default (30s) is too low.
    :returns: Local bind address aka tuple of local_bind_host and local_bind_port.
        Calls to local_bind_host:local_bind_port will be forwarded to
        remote_host:remote_port via the SSH tunnel.
    """
    if not hasattr(self, "_ssh_tunnel_forwarder"):
        # Tunnel is not started yet - start it now!

        # Set a specific tunnel timeout if applicable
        if tunnel_timeout:
            old_timeout = sshtunnel.TUNNEL_TIMEOUT
            sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout

        try:
            # Build kwargs dict for SSH Tunnel Forwarder
            if self.conn.ssh_proxy_server:
                # Use the proxy SSH server as target
                self._ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                    conn_id=self.conn.ssh_proxy_server,
                )
                kwargs = {
                    "ssh_address_or_host": self._ssh_hook.start_tunnel(
                        self.conn.host, self.conn.port or 22
                    ),
                    "remote_bind_address": (remote_host, remote_port),
                }
            else:
                kwargs = {
                    "ssh_address_or_host": (self.conn.host, self.conn.port or 22),
                    "remote_bind_address": (remote_host, remote_port),
                }
            if self.conn.username:
                kwargs["ssh_username"] = self.conn.username
            if self.conn.password:
                kwargs["ssh_password"] = self.conn.password

            # Save private key in a temporary file, if applicable
            with NamedTemporaryFile() as keyfile:
                if self.conn.private_key:
                    keyfile.write(self.conn.private_key.encode())
                    keyfile.flush()
                    kwargs["ssh_pkey"] = os.path.abspath(keyfile.name)
                self.log.info(
                    "Opening SSH Tunnel to {0}:{1}...".format(
                        *kwargs["ssh_address_or_host"]
                    )
                )
                self._ssh_tunnel_forwarder = sshtunnel.SSHTunnelForwarder(**kwargs)
                self._ssh_tunnel_forwarder.start()
        except:
            # Set package constant back to original setting, if applicable
            if tunnel_timeout:
                sshtunnel.TUNNEL_TIMEOUT = old_timeout
            raise

    return ("localhost", self._ssh_tunnel_forwarder.local_bind_port)

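# Usage sketch (assumption): an SSH hook created from an EWAH SSH connection
# forwards a remote database port to a local bind address. The connection id,
# remote host, and port below are made up; stop_tunnel is the counterpart
# used elsewhere in this section.
ssh_hook = EWAHBaseHook.get_hook_from_conn_id(conn_id="ssh_bastion")  # hypothetical
local_host, local_port = ssh_hook.start_tunnel("db.internal", 5432)
# ... connect to local_host:local_port instead of db.internal:5432 ...
ssh_hook.stop_tunnel()
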
def execute(self, context):
    hook = EWAHBaseHook.get_hook_from_conn_id(self.postgres_conn_id)
    hook.execute(sql=self.sql, params=self.parameters, commit=True)
    hook.close()  # SSH tunnel does not close if hook is not closed first

def execute_snowflake(sql, conn_id, **kwargs):
    hook = EWAHBaseHook.get_hook_from_conn_id(conn_id)
    hook.execute(sql)
    hook.close()

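# Usage sketch (assumption): the helper can serve as a python_callable, e.g.
# to grant rights after a load. The SQL statement and connection id are made up.
execute_snowflake(
    sql="GRANT USAGE ON SCHEMA analytics TO ROLE reporting;",
    conn_id="snowflake_dwh",
)
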
def get_data_in_batches(
    self, endpoint, page_size=100, batch_size=10000, data_from=None
):
    endpoint = self._ENDPOINTS.get(endpoint, endpoint)
    params = {}
    if data_from:
        assert endpoint == self._ENDPOINT_DAGRUNS
        # get DagRuns that ended since data_from
        params["end_date_gte"] = data_from.isoformat()
        params["order_by"] = "end_date"
    auth = requests.auth.HTTPBasicAuth(self.conn.login, self.conn.password)

    if self.conn.ssh_conn_id:
        ssh_hook = EWAHBaseHook.get_hook_from_conn_id(conn_id=self.conn.ssh_conn_id)
        ssh_host = self.conn.host or "localhost"
        ssh_host = ssh_host.replace("https://", "").replace("http://", "")
        ssh_port = self.conn.port
        if not ssh_port:
            if self.conn.protocol == "http":
                ssh_port = 80
            else:
                ssh_port = 443
        else:
            ssh_port = int(ssh_port)
        local_bind_address = ssh_hook.start_tunnel(ssh_host, ssh_port)
        host = "{2}://{0}:{1}".format(
            local_bind_address[0],
            str(local_bind_address[1]),
            self.conn.protocol or "http",
        )
    else:
        host = self.conn.host
        if not host.startswith("http"):
            host = self.conn.protocol + "://" + host

    url = self._BASE_URL.format(host, endpoint)
    params["limit"] = page_size
    params["offset"] = 0

    data = []
    i = 0
    while True:
        i += 1
        self.log.info("Making request {0} to {1}...".format(i, url))
        request = requests.get(url, params=params, auth=auth)
        assert request.status_code == 200, request.text
        response = request.json()
        keys = list(response.keys())
        if "total_entries" in keys:
            # Most endpoints use pagination + give "total_entries" for requests
            # The key to get the data from the response may differ by endpoint
            if keys[0] == "total_entries":
                data_key = keys[1]
            else:
                data_key = keys[0]
            if not response[data_key]:
                # You know that you fetched all items if an empty list is returned
                # (Note: total_entries is not reliable)
                yield data
                data = []
                break
            data += response[data_key]
            if len(data) >= batch_size:
                yield data
                data = []
        else:
            # Rare endpoint that does not paginate (usually singletons)
            yield [response]
            break
        params["offset"] = params["offset"] + params["limit"]

    if self.conn.ssh_conn_id:
        ssh_hook.stop_tunnel()
        del ssh_hook

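# Usage sketch (assumption): the generator yields lists of up to batch_size
# records per iteration. The hook instance, endpoint key, and sizes below are
# illustrative only.
for batch in hook.get_data_in_batches("dagRuns", page_size=100, batch_size=5000):
    print("Got {0} records".format(len(batch)))
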