Example #1
    def mongoclient(self):
        if not hasattr(self, "_mc"):
            if self.conn.ssh_conn_id:
                if self.conn.conn_style == "uri":
                    raise Exception(
                        "Cannot have SSH tunnel with uri connection type!")
                if not hasattr(self, "_ssh_hook"):
                    self._ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                        conn_id=self.conn.ssh_conn_id)
                    self.local_bind_address = self._ssh_hook.start_tunnel(
                        self.conn.host, self.conn.port)
            else:
                self.local_bind_address = (self.conn.host, self.conn.port)

            conn_kwargs = {"tz_aware": True}
            if self.conn.conn_style == "uri":
                conn_kwargs["host"] = self.conn.uri
            else:
                conn_kwargs["host"] = self.local_bind_address[0]
                conn_kwargs["port"] = self.local_bind_address[1]
                if self.conn.username:
                    conn_kwargs["username"] = self.conn.username
                if self.conn.password:
                    conn_kwargs["password"] = self.conn.password

            with TemporaryDirectory() as tmp_dir:
                if self.conn.tls:
                    conn_kwargs["tls"] = True
                with NamedTemporaryFile(dir=tmp_dir) as ssl_cert:
                    with NamedTemporaryFile(dir=tmp_dir) as ssl_private:
                        if self.conn.ssl_cert:
                            ssl_cert.write(self.conn.ssl_cert.encode())
                            ssl_cert.seek(0)
                            conn_kwargs["ssl_certfile"] = os.path.abspath(
                                ssl_cert.name)
                        if self.conn.ssl_private:
                            ssl_private.write(self.conn.ssl_private.encode())
                            ssl_private.seek(0)
                            conn_kwargs["ssl_keyfile"] = os.path.abspath(
                                ssl_private.name)
                        if self.conn.ssl_password:
                            conn_kwargs[
                                "tlsCertificateKeyFilePassword"] = self.conn.ssl_password
                        if self.conn.tls_insecure:
                            conn_kwargs["tlsInsecure"] = True
                        if self.conn.auth_source:
                            conn_kwargs["authSource"] = self.conn.auth_source
                        if self.conn.auth_mechanism:
                            conn_kwargs[
                                "authMechanism"] = self.conn.auth_mechanism
                        self._mc = MongoClient(**conn_kwargs)

        return self._mc
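If the method above is exposed as a property (an assumption; no decorator is shown in the excerpt), usage could look like the following sketch, where the connection id, database, and collection names are hypothetical:

    # Usage sketch -- conn id, database, and collection names are hypothetical
    hook = EWAHBaseHook.get_hook_from_conn_id(conn_id="mongodb_source")
    client = hook.mongoclient  # first access builds and caches the MongoClient
    for doc in client["my_database"]["my_collection"].find(limit=10):
        print(doc)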
Example #2
    def dbconn(self):
        if not hasattr(self, "_dbconn"):
            if hasattr(self.conn, "ssh_conn_id") and self.conn.ssh_conn_id:
                if not hasattr(self, "_ssh_hook"):
                    self._ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                        conn_id=self.conn.ssh_conn_id)
                    self.local_bind_address = self._ssh_hook.start_tunnel(
                        self.conn.host, self.conn.port or self._DEFAULT_PORT)
            else:
                self.local_bind_address = (self.conn.host, self.conn.port)
            self._dbconn = self._get_db_conn()
        return self._dbconn
Example #3
    def kickoff_func(schema_full, schema_suffix, dwh_conn_id):
        # kickoff: create new dataset
        schema = schema_full + schema_suffix
        conn = EWAHBaseHook.get_hook_from_conn_id(dwh_conn_id).dbconn
        # delete dataset first if it already exists
        print("Deleting the dataset {0} if it already exists.".format(schema))
        conn.delete_dataset(schema, delete_contents=True, not_found_ok=True)
        print("Creating the dataset {0}.".format(schema))
        conn.create_dataset(schema)
        print("Done!")
Example #4
    def execute(self, context):

        # env to be used in processes later
        env = os.environ.copy()

        # create a new temp folder, all action happens in here
        with TemporaryDirectory(prefix="__ewah_dbt_operator_") as tmp_dir:
            # clone repo into temp directory
            repo_dir = tmp_dir + os.path.sep + "repo"
            if self.repo_type == "git":
                git_hook = EWAHBaseHook.get_hook_from_conn_id(
                    conn_id=self.git_conn_id)
                git_hook.clone_repo(repo_dir, env)
            else:
                raise Exception("Not Implemented!")

            # create a virtual environment in temp folder
            venv_folder = tmp_dir + os.path.sep + "venv"
            self.log.info(
                "creating a new virtual environment in {0}...".format(
                    venv_folder, ))
            venv.create(venv_folder, with_pip=True)

            # install dbt into created venv
            self.log.info("installing dbt=={0}".format(self.dbt_version))
            cmd = []
            cmd.append("source {0}/bin/activate".format(venv_folder))
            cmd.append("pip install --quiet --upgrade dbt=={0}".format(
                self.dbt_version))
            cmd.append("dbt --version")
            cmd.append("deactivate")
            assert run_cmd(cmd, env, self.log.info) == 0

            dbt_dir = repo_dir
            if self.subfolder:
                if not self.subfolder.startswith(os.path.sep):
                    self.subfolder = os.path.sep + self.subfolder
                dbt_dir += self.subfolder

            dwh_conn = EWAHBaseHook.get_connection(self.dwh_conn_id)

            # read profile name & create temporary profiles.yml
            project_yml_file = dbt_dir
            if not project_yml_file.endswith(os.path.sep):
                project_yml_file += os.path.sep
            project_yml_file += "dbt_project.yml"
            with open(project_yml_file, "r") as project_file:
                project_yml = yaml.load(project_file, Loader=Loader)
            profile_name = project_yml["profile"]
            self.log.info('Creating temp profile "{0}"'.format(profile_name))
            profiles_yml = {
                "config": {
                    "send_anonymous_usage_stats": False,
                    "use_colors": False,  # colors won't be useful in logs
                },
            }
            if self.dwh_engine == EC.DWH_ENGINE_POSTGRES:
                profiles_yml[profile_name] = {
                    "target": "prod",  # same as the output defined below
                    "outputs": {
                        "prod": {  # for postgres
                            "type": "postgres",
                            "host": dwh_conn.host,
                            "port": dwh_conn.port or "5432",
                            "user": dwh_conn.login,
                            "pass": dwh_conn.password,
                            "dbname": dwh_conn.schema,
                            "schema": self.schema_name,
                            "threads": self.threads,
                            "keepalives_idle": self.keepalives_idle,
                        },
                    },
                }
            elif self.dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
                profiles_yml[profile_name] = {
                    "target": "prod",  # same as the output defined below
                    "outputs": {
                        "prod": {  # for snowflake
                            "type": "snowflake",
                            "account": dwh_conn.account,
                            "user": dwh_conn.user,
                            "password": dwh_conn.password,
                            "role": dwh_conn.role,
                            "database": self.database_name or dwh_conn.database,
                            "warehouse": dwh_conn.warehouse,
                            "schema": self.schema_name or dwh_conn.schema,
                            "threads": self.threads,
                            "keepalives_idle": self.keepalives_idle,
                        },
                    },
                }
            else:
                raise Exception("DWH Engine not implemented!")

            # run commands with correct profile in the venv in the temp folder
            profiles_yml_name = tmp_dir + os.path.sep + "profiles.yml"
            env["DBT_PROFILES_DIR"] = os.path.abspath(tmp_dir)
            with open(profiles_yml_name, "w") as profiles_file:
                # write profile into profiles.yml file
                yaml.dump(profiles_yml,
                          profiles_file,
                          default_flow_style=False)

                # run dbt commands
                self.log.info("Now running commands dbt!")
                cmd = []
                cmd.append("cd {0}".format(dbt_dir))
                cmd.append("source {0}/bin/activate".format(venv_folder))
                cmd.append("dbt deps")
                [cmd.append("dbt {0}".format(dc)) for dc in self.dbt_commands]
                cmd.append("deactivate")
                assert run_cmd(cmd, env, self.log.info) == 0

            # if applicable: close SSH tunnel
            if hasattr(self, "ssh_tunnel_forwarder"):
                self.log.info("Stopping!")
                self.ssh_tunnel_forwarder.stop()
                del self.ssh_tunnel_forwarder
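The operator relies on a run_cmd helper that is not part of this excerpt. Below is a minimal sketch of what such a helper might look like, assuming it chains the commands in a single bash shell (needed for "source" to work), streams output to the supplied logger, and returns the exit code; this is an assumption, not the actual EWAH implementation:

    import subprocess

    def run_cmd(cmd, env, logger):
        # Assumption: chain commands with && so the first failure aborts the chain
        process = subprocess.Popen(
            ["/bin/bash", "-c", " && ".join(cmd)],
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )
        for line in iter(process.stdout.readline, b""):
            logger(line.decode().rstrip())  # stream each output line to the logger
        process.stdout.close()
        return process.wait()  # callers assert the returned exit code is 0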
Example #5
    def execute(self, context):

        # env to be used in processes later
        env = os.environ.copy()
        env["PIP_USER"] = "******"

        # create a new temp folder, all action happens in here
        with TemporaryDirectory(prefix="__ewah_dbt_operator_") as tmp_dir:
            # clone repo into temp directory
            repo_dir = tmp_dir + os.path.sep + "repo"
            if self.repo_type == "git":
                # Clone repo into temp folder
                git_hook = EWAHBaseHook.get_hook_from_conn_id(
                    conn_id=self.git_conn_id)
                git_hook.clone_repo(repo_dir, env)
            elif self.repo_type == "local":
                # Copy local version of the repository into temp folder
                copy_tree(self.local_path, repo_dir)
            else:
                raise Exception("Not Implemented!")

            # create a virtual environment in temp folder
            venv_folder = tmp_dir + os.path.sep + "venv"
            self.log.info(
                "creating a new virtual environment in {0}...".format(
                    venv_folder, ))
            venv.create(venv_folder, with_pip=True)
            dbt_dir = repo_dir
            if self.subfolder:
                if not self.subfolder.startswith(os.path.sep):
                    self.subfolder = os.path.sep + self.subfolder
                dbt_dir += self.subfolder

            dwh_hook = EWAHBaseHook.get_hook_from_conn_id(self.dwh_conn_id)
            # in case of SSH: execute a query to create the connection and tunnel
            dwh_hook.execute("SELECT 1 AS a -- Testing the connection")
            dwh_conn = dwh_hook.conn

            # read profile name and dbt version & create temporary profiles.yml
            project_yml_file = dbt_dir
            if not project_yml_file.endswith(os.path.sep):
                project_yml_file += os.path.sep
            project_yml_file += "dbt_project.yml"
            with open(project_yml_file, "r") as project_file:
                project_yml = yaml.load(project_file, Loader=Loader)
            profile_name = project_yml["profile"]
            dbt_version = self.dbt_version or project_yml.get(
                "require-dbt-version")
            del self.dbt_version  # Make sure it can't accidentally be used below
            assert dbt_version, "Must supply dbt_version or set require-dbt-version!"
            if isinstance(dbt_version, str):
                if not dbt_version.startswith(("=", "<", ">")):
                    dbt_version = "==" + dbt_version
            elif isinstance(dbt_version, list):
                dbt_version = ",".join(dbt_version)
            else:
                raise Exception(
                    "dbt_version must be a string or a list of strings!")
            self.log.info('Creating temp profile "{0}"'.format(profile_name))
            profiles_yml = {
                "config": {
                    "send_anonymous_usage_stats": False,
                    "use_colors": False,  # colors won't be useful in logs
                },
            }
            if self.dwh_engine == EC.DWH_ENGINE_POSTGRES:
                mb_database = dwh_conn.schema
                profiles_yml[profile_name] = {
                    "target": "prod",  # same as the output defined below
                    "outputs": {
                        "prod": {  # for postgres
                            "type": "postgres",
                            "host": dwh_hook.local_bind_address[0],
                            "port": dwh_hook.local_bind_address[1],
                            "user": dwh_conn.login,
                            "pass": dwh_conn.password,
                            "dbname": dwh_conn.schema,
                            "schema": self.schema_name,
                            "threads": self.threads,
                            "keepalives_idle": self.keepalives_idle,
                        },
                    },
                }
            elif self.dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
                mb_database = self.database_name or dwh_conn.database
                profiles_yml[profile_name] = {
                    "target": "prod",  # same as the output defined below
                    "outputs": {
                        "prod": {  # for snowflake
                            "type": "snowflake",
                            "account": dwh_conn.account,
                            "user": dwh_conn.user,
                            "password": dwh_conn.password,
                            "role": dwh_conn.role,
                            "database": self.database_name or dwh_conn.database,
                            "warehouse": dwh_conn.warehouse,
                            "schema": self.schema_name or dwh_conn.schema,
                            "threads": self.threads,
                            "keepalives_idle": self.keepalives_idle,
                        },
                    },
                }
            elif self.dwh_engine == EC.DWH_ENGINE_BIGQUERY:
                mb_database = self.database_name
                profiles_yml[profile_name] = {
                    "target": "prod",  # same as the output defined below
                    "outputs": {
                        "prod": {
                            "type":
                            "bigquery",
                            "method":
                            "service-account-json",
                            "project":
                            self.database_name,
                            "dataset":
                            self.schema_name,
                            "threads":
                            self.threads,
                            "timeout_seconds":
                            self.keepalives_idle or 300,
                            "priority":
                            "interactive",
                            "keyfile_json":
                            json.loads(dwh_conn.service_account_json),
                        },
                    },
                }
                if dwh_conn.location:
                    profiles_yml[profile_name]["outputs"]["prod"][
                        "location"] = dwh_conn.location
            else:
                raise Exception("DWH Engine not implemented!")

            # install dbt into created venv
            cmd = []
            cmd.append("source {0}/bin/activate".format(venv_folder))
            cmd.append("pip install --quiet --upgrade pip setuptools")
            if re.search("[^0-9\.]0(\.[0-9]+)?(\.[0-9]+)?$", dbt_version):
                # regex checks whether the (last) version start with 0
                # if true, version <1.0.0 required
                cmd.append(
                    'pip install --quiet --upgrade "MarkupSafe<=2.0.1" "dbt{0}"'
                    .format(dbt_version))
            else:
                # Different pip behavior since dbt 1.0.0
                cmd.append(
                    'pip install --quiet --upgrade "MarkupSafe<=2.0.1" "dbt-{0}{1}"'
                    .format(
                        {
                            EC.DWH_ENGINE_POSTGRES: "postgres",
                            EC.DWH_ENGINE_SNOWFLAKE: "snowflake",
                            EC.DWH_ENGINE_BIGQUERY: "bigquery",
                        }[self.dwh_engine],
                        dbt_version,
                    ))

            cmd.append("dbt --version")
            cmd.append("deactivate")
            assert run_cmd(cmd, env, self.log.info) == 0

            # run commands with correct profile in the venv in the temp folder
            profiles_yml_name = tmp_dir + os.path.sep + "profiles.yml"
            env["DBT_PROFILES_DIR"] = os.path.abspath(tmp_dir)
            with open(profiles_yml_name, "w") as profiles_file:
                # write profile into profiles.yml file
                yaml.dump(profiles_yml,
                          profiles_file,
                          default_flow_style=False)

                # run dbt commands
                self.log.info("Now running commands dbt!")
                cmd = []
                cmd.append("cd {0}".format(dbt_dir))
                cmd.append("source {0}/bin/activate".format(venv_folder))
                if self.repo_type == "local":
                    cmd.append("dbt clean")
                cmd.append("dbt deps")
                [cmd.append("dbt {0}".format(dc)) for dc in self.dbt_commands]
                cmd.append("deactivate")
                assert run_cmd(cmd, env, self.log.info) == 0

                if self.metabase_conn_id:
                    # Push docs to Metabase at the end of the run!
                    metabase_hook = EWAHBaseHook.get_hook_from_conn_id(
                        conn_id=self.metabase_conn_id)
                    metabase_hook.push_dbt_docs_to_metabase(
                        dbt_project_path=dbt_dir,
                        dbt_database_name=mb_database,
                    )
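For reference, the dbt_version normalization above maps inputs to pip install targets roughly as follows (illustrative values; the adapter name depends on dwh_engine, and the MarkupSafe pin is applied in all cases):

    # "0.19.2"             -> "==0.19.2"      -> pip install "dbt==0.19.2"
    # ">=1.0.0"            -> unchanged       -> pip install "dbt-postgres>=1.0.0"
    # [">=0.19", "<0.20"]  -> ">=0.19,<0.20"  -> pip install "dbt>=0.19,<0.20"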
Example #6
        def final_func(schema_name, schema_suffix, dwh_conn_id):
            # final: move new data into the final dataset
            conn = EWAHBaseHook.get_hook_from_conn_id(dwh_conn_id).dbconn
            # get dataset objects
            try:  # create the final dataset if it does not exist yet
                ds_final = conn.get_dataset(schema_name)
            except Exception:
                print("Creating dataset {0}".format(schema_name))
                ds_final = conn.create_dataset(schema_name)
            ds_temp = conn.get_dataset(schema_name + schema_suffix)

            # copy all tables from temp dataset to final dataset
            new_tables = list(conn.list_tables(ds_temp))
            new_table_ids = [table.table_id for table in new_tables]
            old_table_ids = [
                table.table_id for table in conn.list_tables(ds_final)
            ]
            copy_jobs = []
            for table in new_tables:
                print("Copying table {0} from temp to final dataset".format(
                    table.table_id))
                try:
                    old_table = conn.get_table(table=TableReference(
                        dataset_ref=ds_final, table_id=table.table_id))
                    conn.delete_table(old_table)
                except Exception:
                    # ignore failure; it fails if the old table does not exist
                    pass
                finally:
                    final_table = ds_final.table(table.table_id)
                    copy_jobs.append(conn.copy_table(table, final_table))

            # delete tables that don't exist in temp dataset from final dataset
            for table_id in old_table_ids:
                if table_id not in new_table_ids:
                    print("Deleting table {0}".format(table_id))
                    conn.delete_table(
                        conn.get_table(
                            TableReference(dataset_ref=ds_final,
                                           table_id=table_id)))

            # make sure all copy jobs succeeded
            while copy_jobs:
                sleep(0.1)
                job = copy_jobs.pop(0)
                job.result()
                assert job.state in ("RUNNING", "DONE")
                if job.state == "RUNNING":
                    copy_jobs.append(job)
                else:
                    print("Successfully copied {0}".format(
                        job.__dict__["_properties"]["configuration"]["copy"]
                        ["destinationTable"]["tableId"]))

            # delete temp dataset
            print("Deleting temp dataset.")
            conn.delete_dataset(ds_temp,
                                delete_contents=True,
                                not_found_ok=False)

            print("Done.")
Example #7
    def start_tunnel(self,
                     remote_host: str,
                     remote_port: int,
                     tunnel_timeout: Optional[int] = 30) -> Tuple[str, int]:
        """Starts the SSH tunnel with port forwarding. Returns the host and port tuple
        that can be used to connect to the remote.

        :param remote_host: Host of the remote that the tunnel should port forward to.
            Can be "localhost" e.g. if tunneling into a server that hosts a database.
        :param remote_port: Port that goes along with remote_host.
        :param tunnel_timeout: Optional timeout setting. Supply a higher number if the
            default (30s) is too low.
        :returns: Local bind address, i.e. a tuple of (local_bind_host, local_bind_port).
            Calls to local_bind_host:local_bind_port will be forwarded to
            remote_host:remote_port via the SSH tunnel.
        """

        if not hasattr(self, "_ssh_tunnel_forwarder"):
            # Tunnel is not started yet - start it now!

            # Set a specific tunnel timeout if applicable
            if tunnel_timeout:
                old_timeout = sshtunnel.TUNNEL_TIMEOUT
                sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout

            try:
                # Build kwargs dict for SSH Tunnel Forwarder
                if self.conn.ssh_proxy_server:
                    # Use the proxy SSH server as target
                    self._ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                        conn_id=self.conn.ssh_proxy_server, )
                    kwargs = {
                        "ssh_address_or_host":
                        self._ssh_hook.start_tunnel(self.conn.host,
                                                    self.conn.port or 22),
                        "remote_bind_address": (remote_host, remote_port),
                    }
                else:
                    kwargs = {
                        "ssh_address_or_host": (self.conn.host,
                                                self.conn.port or 22),
                        "remote_bind_address": (remote_host, remote_port),
                    }

                if self.conn.username:
                    kwargs["ssh_username"] = self.conn.username
                if self.conn.password:
                    kwargs["ssh_password"] = self.conn.password

                # Save private key in a temporary file, if applicable
                with NamedTemporaryFile() as keyfile:
                    if self.conn.private_key:
                        keyfile.write(self.conn.private_key.encode())
                        keyfile.flush()
                        kwargs["ssh_pkey"] = os.path.abspath(keyfile.name)
                    self.log.info("Opening SSH Tunnel to {0}:{1}...".format(
                        *kwargs["ssh_address_or_host"]))
                    self._ssh_tunnel_forwarder = sshtunnel.SSHTunnelForwarder(
                        **kwargs)
                    self._ssh_tunnel_forwarder.start()
            except:
                # Set package constant back to original setting, if applicable
                if tunnel_timeout:
                    sshtunnel.TUNNEL_TIMEOUT = old_timeout
                raise

        return ("localhost", self._ssh_tunnel_forwarder.local_bind_port)
Example #8
    def execute(self, context):
        hook = EWAHBaseHook.get_hook_from_conn_id(self.postgres_conn_id)
        hook.execute(sql=self.sql, params=self.parameters, commit=True)
        hook.close()  # SSH tunnel does not close if hook is not closed first
Example #9
    def execute_snowflake(sql, conn_id, **kwargs):
        hook = EWAHBaseHook.get_hook_from_conn_id(conn_id)
        hook.execute(sql)
        hook.close()
Example #10
    def get_data_in_batches(self,
                            endpoint,
                            page_size=100,
                            batch_size=10000,
                            data_from=None):
        endpoint = self._ENDPOINTS.get(endpoint, endpoint)
        params = {}
        if data_from:
            assert endpoint == self._ENDPOINT_DAGRUNS
            # get DagRuns that ended since data_from
            params["end_date_gte"] = data_from.isoformat()
            params["order_by"] = "end_date"
        auth = requests.auth.HTTPBasicAuth(self.conn.login, self.conn.password)
        if self.conn.ssh_conn_id:
            ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                conn_id=self.conn.ssh_conn_id)
            ssh_host = self.conn.host or "localhost"
            ssh_host = ssh_host.replace("https://", "").replace("http://", "")
            ssh_port = self.conn.port
            if not ssh_port:
                if self.conn.protocol == "http":
                    ssh_port = 80
                else:
                    ssh_port = 443
            else:
                ssh_port = int(ssh_port)
            local_bind_address = ssh_hook.start_tunnel(ssh_host, ssh_port)
            host = "{2}://{0}:{1}".format(
                local_bind_address[0],
                str(local_bind_address[1]),
                self.conn.protocol or "http",
            )
        else:
            host = self.conn.host
            if not host.startswith("http"):
                host = self.conn.protocol + "://" + host
        url = self._BASE_URL.format(host, endpoint)
        params["limit"] = page_size
        params["offset"] = 0
        data = []
        i = 0
        while True:
            i += 1
            self.log.info("Making request {0} to {1}...".format(i, url))
            response = requests.get(url, params=params, auth=auth)
            assert response.status_code == 200, response.text
            payload = response.json()
            keys = list(payload.keys())
            if "total_entries" in keys:
                # Most endpoints paginate and include "total_entries" in responses
                # The key that holds the data may differ between endpoints
                if keys[0] == "total_entries":
                    data_key = keys[1]
                else:
                    data_key = keys[0]
                if not payload[data_key]:
                    # All items have been fetched once an empty list is returned
                    # (Note: total_entries is not reliable)
                    yield data
                    data = []
                    break
                data += payload[data_key]
                if len(data) >= batch_size:
                    yield data
                    data = []
            else:
                # Rare endpoint that does not paginate (usually singletons)
                yield [payload]
                break
            params["offset"] = params["offset"] + params["limit"]

        if self.conn.ssh_conn_id:
            ssh_hook.stop_tunnel()
            del ssh_hook
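A usage sketch for the generator above, assuming the method lives on a hook for Airflow's REST API and that "dagruns" resolves to the DagRun endpoint via _ENDPOINTS (both assumptions):

    from datetime import datetime, timedelta

    hook = EWAHBaseHook.get_hook_from_conn_id(conn_id="airflow_metadata")  # hypothetical
    data_from = datetime.utcnow() - timedelta(days=1)
    # Yields batches of up to batch_size DagRuns that ended since data_from
    for batch in hook.get_data_in_batches("dagruns", data_from=data_from):
        print("Fetched {0} records".format(len(batch)))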