Example #1
    def send(self):
        """
        Send Dingding message
        """
        support_type = ['text', 'link', 'markdown', 'actionCard', 'feedCard']
        if self.message_type not in support_type:
            raise ValueError('DingdingWebhookHook only supports {} '
                             'so far, but received {}'.format(
                                 support_type, self.message_type))

        data = self._build_message()
        self.log.info('Sending Dingding type %s message %s', self.message_type,
                      data)
        resp = self.run(endpoint=self._get_endpoint(),
                        data=data,
                        headers={'Content-Type': 'application/json'})

        # Dingding returns errcode == 0 when the message was sent successfully
        if int(resp.json().get('errcode')) != 0:
            raise AirflowException(
                'Send Dingding message failed, received error '
                'message {}'.format(resp.text))
        self.log.info('Dingding message sent successfully')
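A minimal usage sketch for the hook above; the import path and constructor arguments follow the apache-airflow-providers-dingding package and should be treated as assumptions that may vary by version:

# Hedged sketch: sending a plain-text message through the hook shown above.
from airflow.providers.dingding.hooks.dingding import DingdingHook

hook = DingdingHook(
    dingding_conn_id='dingding_default',  # connection holding the webhook token
    message_type='text',                  # one of the supported types listed above
    message='Job finished successfully',
    at_all=False,
)
hook.send()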
Example #2
    def validate_socket_path_length(self) -> None:
        """
        Validates sockets path length.

        :return: None or raises AirflowException
        """
        if self.use_proxy and not self.sql_proxy_use_tcp:
            if self.database_type == 'postgres':
                suffix = "/.s.PGSQL.5432"
            else:
                suffix = ""
            expected_path = "{}/{}:{}:{}{}".format(
                self._generate_unique_path(), self.project_id, self.instance,
                self.database, suffix)
            if len(expected_path) > UNIX_PATH_MAX:
                self.log.info("Too long (%s) path: %s", len(expected_path),
                              expected_path)
                raise AirflowException(
                    "The UNIX socket path length cannot exceed {} characters "
                    "on Linux systems. Either use a shorter instance/database "
                    "name or switch to a TCP connection. "
                    "The socket path for the Cloud SQL proxy is now: "
                    "{}".format(UNIX_PATH_MAX, expected_path))
Example #3
    def execute(self, context):
        hook = BigtableHook(gcp_conn_id=self.gcp_conn_id)
        instance = hook.get_instance(project_id=self.project_id,
                                     instance_id=self.instance_id)
        if not instance:
            raise AirflowException(
                "Dependency: instance '{}' does not exist.".format(
                    self.instance_id))

        try:
            hook.delete_table(
                project_id=self.project_id,
                instance_id=self.instance_id,
                table_id=self.table_id,
            )
        except google.api_core.exceptions.NotFound:
            # It's OK if the table doesn't exist.
            self.log.info(
                "The table '%s' no longer exists. Consider it as deleted",
                self.table_id)
        except google.api_core.exceptions.GoogleAPICallError as e:
            self.log.error('An error occurred. Exiting.')
            raise e
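Assuming this execute() belongs to a Bigtable table-delete operator (the class name and import path below are assumptions, e.g. BigtableDeleteTableOperator from the Google provider), a DAG task could be declared roughly like this:

# Hedged DAG usage sketch; the operator class name and import path are
# assumptions and may differ between provider versions.
from airflow.providers.google.cloud.operators.bigtable import BigtableDeleteTableOperator

delete_table = BigtableDeleteTableOperator(
    task_id='delete_bigtable_table',
    project_id='my-gcp-project',        # hypothetical project
    instance_id='my-bigtable-instance', # hypothetical instance
    table_id='my-table',
    gcp_conn_id='google_cloud_default',
)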
Example #4
def s3_to_snowflake(logger, file_name, ds_nodash, **kwargs):

    ti = kwargs['ti']

    #table = dict[filename] for filename in dict.keys() if filename = file_name
    schema = "HUBSPOT"
    yyyymmdd = ds_nodash  # Airflow's ds_nodash is the execution date as YYYYMMDD

    bucket = 'roofstock-data-lake'
    key = f'Airflow/HUBSPOT/{yyyymmdd}'

    conn = snowflake_hook.get_cursor()
    conn.execute('USE DATABASE RAW_DATA')
    try:
        create_stage_query = f"""CREATE OR REPLACE STAGE RAW_DATA.{schema}.RAW_HUBSPOT_CONTACTS 
                        url = 's3://{bucket}/{key}'
                        credentials=(aws_key_id= '{aws_access_key_id}', aws_secret_key= '{aws_secret_access_key}')
                        file_format=(type=csv skip_header=1 NULL_IF='' FIELD_OPTIONALLY_ENCLOSED_BY='"')"""

        logger.info(create_stage_query)
        conn.execute(create_stage_query)
    except Exception as e:
        raise AirflowException(e)
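Creating the stage is typically only the first half of the load; a plausible continuation inside the same try block, reusing conn, schema and logger from the function above, would issue a COPY INTO (the target table name here is purely hypothetical):

        # Hypothetical continuation: copy the staged CSV files into a table.
        # RAW_HUBSPOT_CONTACTS_TABLE is an assumed target name, not from the original.
        copy_query = f"""COPY INTO RAW_DATA.{schema}.RAW_HUBSPOT_CONTACTS_TABLE
                        FROM @RAW_DATA.{schema}.RAW_HUBSPOT_CONTACTS
                        ON_ERROR = 'CONTINUE'"""
        logger.info(copy_query)
        conn.execute(copy_query)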
Example #5
    def _extract_xcom(self, pod):
        resp = kubernetes_stream(
            self._client.connect_get_namespaced_pod_exec,
            pod.name,
            pod.namespace,
            container=self.kube_req_factory.SIDECAR_CONTAINER_NAME,
            command=['/bin/sh'],
            stdin=True,
            stdout=True,
            stderr=True,
            tty=False,
            _preload_content=False)
        try:
            result = self._exec_pod_command(
                resp, 'cat {}/return.json'.format(
                    self.kube_req_factory.XCOM_MOUNT_PATH))
            self._exec_pod_command(resp, 'kill -s SIGINT 1')
        finally:
            resp.close()
        if result is None:
            raise AirflowException(
                'Failed to extract xcom from pod: {}'.format(pod.name))
        return result
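The sidecar hands back the raw contents of return.json, so a caller inside the same launcher class would normally deserialize it; a small hedged sketch:

        # Sketch of how the extracted value is typically consumed elsewhere in
        # the launcher: return.json holds the task's XCom value serialized as JSON.
        import json

        result = self._extract_xcom(pod)   # raw string read from the sidecar
        xcom_value = json.loads(result)    # e.g. {"rows_processed": 42}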
Example #6
    def execute(self, context, *args, **kwargs):
        upstream_tasks = self.get_flat_relatives(upstream=True)
        upstream_task_ids = [task.task_id for task in upstream_tasks]
        print(upstream_task_ids)
        url = f"https://api.telegram.org/bot{self.bot_token}/sendMessage"
        text = ("The pipeline did not complete\n"
                f"DAG ID: {context['dag'].dag_id}\n"
                f"Task ID: {upstream_task_ids[0]}\n"
                f"Error: {str(self.error_message)}")
        body = {
            "chat_id": self.chat_id,
            "text": text,
        }
        response = requests.post(url,
                                 headers=self.headers,
                                 json=body,
                                 timeout=10,
                                 verify=False)

        print(response.text)

        if not json.loads(response.text)["ok"]:
            raise AirflowException("Failed to send the Telegram message")
Example #7
    def run_pod(self,
                pod: V1Pod,
                startup_timeout: int = 120,
                get_logs: bool = True) -> Tuple[State, Optional[str]]:
        """
        Launches the pod synchronously and waits for completion.

        :param pod: the pod to launch
        :param startup_timeout: Timeout for startup of the pod (if pod is pending for too long, fails task)
        :param get_logs: whether to query k8s for logs
        :return: tuple of the final pod state and the extracted XCom result, if any
        """
        resp = self.run_pod_async(pod)
        curr_time = dt.now()
        if resp.status.start_time is None:
            while self.pod_not_started(pod):
                self.log.debug('Pod not yet started')
                delta = dt.now() - curr_time
                if delta.seconds >= startup_timeout:
                    raise AirflowException("Pod took too long to start")
                time.sleep(1)

        return self._monitor_pod(pod, get_logs)
Example #8
    def _wait_for_export_metadata(self, hook: DataprocMetastoreHook):
        """
        Workaround to check that the export was created successfully.
        We discovered an issue parsing the result into MetadataExport inside the SDK.
        """
        for time_to_wait in exponential_sleep_generator(initial=10, maximum=120):
            sleep(time_to_wait)
            service = hook.get_service(
                region=self.region,
                project_id=self.project_id,
                service_id=self.service_id,
                retry=self.retry,
                timeout=self.timeout,
                metadata=self.metadata,
            )
            activities: MetadataManagementActivity = service.metadata_management_activity
            metadata_export: MetadataExport = activities.metadata_exports[0]
            if metadata_export.state == MetadataExport.State.SUCCEEDED:
                return metadata_export
            if metadata_export.state == MetadataExport.State.FAILED:
                raise AirflowException(
                    f"Exporting metadata from Dataproc Metastore {metadata_export.name} FAILED"
                )
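exponential_sleep_generator(initial=10, maximum=120) yields a growing, capped wait interval; a minimal standalone sketch of that pattern (illustrative only, the real helper may also add jitter):

# Standalone sketch of an exponential backoff generator with a cap, mirroring
# the call exponential_sleep_generator(initial=10, maximum=120) above.
from typing import Iterator

def backoff(initial: float, maximum: float, factor: float = 2.0) -> Iterator[float]:
    wait = initial
    while True:
        yield min(wait, maximum)  # never wait longer than the cap
        wait *= factor

gen = backoff(10, 120)
print([next(gen) for _ in range(6)])  # [10, 20, 40, 80, 120, 120]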
Example #9
    def assert_unique_values(self, cols: List[str], conn) -> None:
        with open(
                os.path.join(LOCAL_DIR, 'query_templates',
                             'count_duplicates.sql')) as f:
            template = Template(f.read())

        print('ASSERT UNIQUE VALUES IN SUBSET: {}'.format(cols))
        query = template.render(schema=self.schema,
                                table=self.table,
                                cols=cols)
        print(query)

        with conn.cursor() as cur:
            cur.execute(query)
            result = cur.fetchall()[0][0]

        if result == 0:
            print('Success! All rows are unique across columns: {}'.format(
                ', '.join(cols)))
        else:
            raise AirflowException(
                'Uh-oh! Some rows are duplicated across columns {}'.format(
                    ', '.join(cols)))
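The count_duplicates.sql template itself is not shown; a plausible, purely hypothetical version that returns 0 when the column subset is unique could look like this:

# Hypothetical contents of query_templates/count_duplicates.sql, rendered with Jinja2.
# It counts the extra rows that share the same values across the given columns;
# the method above treats a result of 0 as "all rows unique".
from jinja2 import Template

COUNT_DUPLICATES_SQL = """
SELECT COALESCE(SUM(cnt - 1), 0)
FROM (
    SELECT {{ cols | join(', ') }}, COUNT(*) AS cnt
    FROM {{ schema }}.{{ table }}
    GROUP BY {{ cols | join(', ') }}
    HAVING COUNT(*) > 1
) AS dups
"""

print(Template(COUNT_DUPLICATES_SQL).render(
    schema='analytics', table='orders', cols=['order_id', 'line_no']))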
Example #10
def dag_list_dag_runs(args, dag=None):
    """Lists dag runs for a given DAG"""
    if dag:
        args.dag_id = dag.dag_id

    dagbag = DagBag()

    if args.dag_id not in dagbag.dags:
        error_message = "Dag id {} not found".format(args.dag_id)
        raise AirflowException(error_message)

    state = args.state.lower() if args.state else None
    dag_runs = DagRun.find(dag_id=args.dag_id,
                           state=state,
                           no_backfills=args.no_backfill)

    if not dag_runs:
        print('No dag runs for {dag_id}'.format(dag_id=args.dag_id))
        return

    dag_runs.sort(key=lambda x: x.execution_date, reverse=True)
    table = _tabulate_dag_runs(dag_runs, tablefmt=args.output)
    print(table)
Example #11
    def parse_job_description(job_id: str, response: Dict) -> Dict:
        """
        Parse job description to extract description for job_id

        :param job_id: a batch job ID
        :type job_id: str

        :param response: an API response for describe jobs
        :type response: Dict

        :return: an API response to describe job_id
        :rtype: Dict

        :raises: AirflowException
        """
        jobs = response.get("jobs", [])
        matching_jobs = [job for job in jobs if job.get("jobId") == job_id]
        if len(matching_jobs) != 1:
            raise AirflowException(
                "AWS Batch job ({}) description error: response: {}".format(
                    job_id, response))

        return matching_jobs[0]
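A quick illustration with a hand-built response shaped like the AWS Batch describe_jobs output, calling the staticmethod shown above (referenced here as a plain function for brevity):

# Illustrative call; 'jobs' and 'jobId' mirror the AWS Batch describe_jobs response shape.
response = {
    "jobs": [
        {"jobId": "job-1", "status": "SUCCEEDED"},
        {"jobId": "job-2", "status": "RUNNING"},
    ]
}

job = parse_job_description(job_id="job-2", response=response)
print(job["status"])  # RUNNING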
Example #12
    def set_machine_type(
        self,
        zone: str,
        resource_id: str,
        body: Dict,
        project_id: Optional[str] = None
    ) -> None:
        """
        Sets machine type of an instance defined by project_id, zone and resource_id.
        Must be called with keyword arguments rather than positional.

        :param zone: Google Cloud Platform zone where the instance exists.
        :type zone: str
        :param resource_id: Name of the Compute Engine instance resource
        :type resource_id: str
        :param body: Body required by the Compute Engine setMachineType API,
            as described in
            https://cloud.google.com/compute/docs/reference/rest/v1/instances/setMachineType
        :type body: dict
        :param project_id: Optional, Google Cloud Platform project ID where the
            Compute Engine Instance exists. If set to None or missing,
            the default project_id from the GCP connection is used.
        :type project_id: str
        :return: None
        """
        if not project_id:
            raise ValueError("The project_id should be set")
        response = self._execute_set_machine_type(zone, resource_id, body, project_id)
        try:
            operation_name = response["name"]
        except KeyError:
            raise AirflowException(
                "Wrong response '{}' returned - it should contain "
                "'name' field".format(response))
        self._wait_for_operation_to_complete(project_id=project_id,
                                             operation_name=operation_name,
                                             zone=zone)
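A hedged usage sketch; the hook class name is an assumption (ComputeEngineHook in recent Google provider releases), while the body follows the documented setMachineType request format:

# Hedged sketch; ComputeEngineHook is assumed to be the class defining set_machine_type().
from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook

# Body format as documented by the Compute Engine setMachineType API.
body = {"machineType": "zones/europe-west1-b/machineTypes/n1-standard-4"}

hook = ComputeEngineHook(gcp_conn_id='google_cloud_default')
hook.set_machine_type(
    zone="europe-west1-b",
    resource_id="my-instance",    # hypothetical instance name
    body=body,
    project_id="my-gcp-project",  # hypothetical project
)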
Example #13
    def __init__(
        self,
        *,
        share_name: str,
        dest_gcs: str,
        directory_name: Optional[str] = None,
        prefix: str = '',
        wasb_conn_id: str = 'wasb_default',
        gcp_conn_id: str = 'google_cloud_default',
        delegate_to: Optional[str] = None,
        replace: bool = False,
        gzip: bool = False,
        google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.share_name = share_name
        self.directory_name = directory_name
        self.prefix = prefix
        self.wasb_conn_id = wasb_conn_id
        self.gcp_conn_id = gcp_conn_id
        self.dest_gcs = dest_gcs
        self.delegate_to = delegate_to
        self.replace = replace
        self.gzip = gzip
        self.google_impersonation_chain = google_impersonation_chain

        if dest_gcs and not gcs_object_is_directory(self.dest_gcs):
            self.log.info(
                'Destination Google Cloud Storage path is not a valid '
                '"directory", define a path that ends with a slash "/" or '
                'leave it empty for the root of the bucket.'
            )
            raise AirflowException(
                'The destination Google Cloud Storage path must end with a slash "/" or be empty.'
            )
Example #14
    def create_database(self,
                        instance_id,
                        database_id,
                        ddl_statements,
                        project_id=None):
        """
        Creates a new database in Cloud Spanner.

        :param instance_id: The ID of the Cloud Spanner instance.
        :type instance_id: str
        :param database_id: The ID of the database to create in Cloud Spanner.
        :type database_id: str
        :param ddl_statements: The string list containing DDL for the new database.
        :type ddl_statements: list[str]
        :param project_id: Optional, the ID of the GCP project that owns the Cloud Spanner
            database. If set to None or missing, the default project_id from the GCP connection is used.
        :type project_id: str
        :return: None
        """

        instance = self._get_client(project_id=project_id).instance(
            instance_id=instance_id)
        if not instance.exists():
            raise AirflowException(
                "The instance {} does not exist in project {} !".format(
                    instance_id, project_id))
        database = instance.database(database_id=database_id,
                                     ddl_statements=ddl_statements)
        try:
            operation = database.create()  # type: Operation
        except GoogleAPICallError as e:
            self.log.error('An error occurred: %s. Exiting.', e.message)
            raise e

        if operation:
            result = operation.result()
            self.log.info(result)
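A hedged usage sketch of the method above; the hook class name (SpannerHook in recent Google provider releases) is an assumption, and the DDL is a minimal illustrative statement:

# Hedged sketch; SpannerHook is assumed to be the class defining create_database().
from airflow.providers.google.cloud.hooks.spanner import SpannerHook

hook = SpannerHook(gcp_conn_id='google_cloud_default')
hook.create_database(
    instance_id='my-spanner-instance',   # hypothetical instance
    database_id='my-database',
    ddl_statements=[
        "CREATE TABLE users (user_id INT64 NOT NULL, name STRING(100)) PRIMARY KEY (user_id)"
    ],
    project_id='my-gcp-project',         # hypothetical project
)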
Example #15
    def get_conn(self) -> Connection:
        """Returns a connection object"""
        db = self.get_connection(self.trino_conn_id)  # type: ignore[attr-defined]
        extra = db.extra_dejson
        auth = None
        if db.password and extra.get('auth') == 'kerberos':
            raise AirflowException("Kerberos authorization doesn't support password.")
        elif db.password:
            auth = trino.auth.BasicAuthentication(db.login, db.password)
        elif extra.get('auth') == 'kerberos':
            auth = trino.auth.KerberosAuthentication(
                config=extra.get('kerberos__config', os.environ.get('KRB5_CONFIG')),
                service_name=extra.get('kerberos__service_name'),
                mutual_authentication=_boolify(extra.get('kerberos__mutual_authentication', False)),
                force_preemptive=_boolify(extra.get('kerberos__force_preemptive', False)),
                hostname_override=extra.get('kerberos__hostname_override'),
                sanitize_mutual_error_response=_boolify(
                    extra.get('kerberos__sanitize_mutual_error_response', True)
                ),
                principal=extra.get('kerberos__principal', conf.get('kerberos', 'principal')),
                delegate=_boolify(extra.get('kerberos__delegate', False)),
                ca_bundle=extra.get('kerberos__ca_bundle'),
            )
        trino_conn = trino.dbapi.connect(
            host=db.host,
            port=db.port,
            user=db.login,
            source=extra.get('source', 'airflow'),
            http_scheme=extra.get('protocol', 'http'),
            catalog=extra.get('catalog', 'hive'),
            schema=db.schema,
            auth=auth,
            isolation_level=self.get_isolation_level(),  # type: ignore[func-returns-value]
            verify=_boolify(extra.get('verify', True)),
        )

        return trino_conn
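The extra fields read above come from the connection's Extra JSON; a sketch of what such an Extra could contain, with keys taken from the code and values purely illustrative:

# Keys are exactly those read by get_conn() above; the values are examples only.
trino_connection_extra = {
    "protocol": "https",
    "catalog": "hive",
    "source": "airflow",
    "verify": "false",                    # passed through _boolify()
    "auth": "kerberos",                   # switches to KerberosAuthentication
    "kerberos__service_name": "trino",
    "kerberos__principal": "airflow@EXAMPLE.COM",
}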
Example #16
    def add_only_new(self, upstream_or_downstream_list, task_id):
        """
        This method is invoked when an operator uses ">> | << | set_downstream | set_upstream"
        with a MultiPointGroupConnector.

        :param upstream_or_downstream_list: downstream_task_ids or upstream_task_ids list
        :param task_id: task_id
        :return:
        """

        # find task by task_id
        if self.dag.has_task(task_id):
            task = self.dag.get_task(task_id)

            # if upstream:
            # add task_id to first_task.upstream_task_ids
            # add first_task.task_id to task.downstream_task_ids
            if upstream_or_downstream_list is self.upstream_task_ids:
                for first_task in self._first_tasks:
                    super(MultiPointGroupConnector,
                          self).add_only_new(first_task.upstream_task_ids,
                                             task_id)
                    task.add_only_new(task.downstream_task_ids,
                                      first_task.task_id)

            # if downstream:
            # add task_id to last_task.downstream_task_ids
            # add last_task.task_id to task.upstream_task_ids
            elif upstream_or_downstream_list is self.downstream_task_ids:
                for last_task in self._last_tasks:
                    super(MultiPointGroupConnector,
                          self).add_only_new(last_task.downstream_task_ids,
                                             task_id)
                    task.add_only_new(task.upstream_task_ids,
                                      last_task.task_id)
        else:
            raise AirflowException('The dag {} should contain the task_id {}'
                                   ''.format(self.dag.dag_id, task_id))
Example #17
    def execute(self, context):
        gcs_hook = GCSHook(
            gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to
        )

        sftp_hook = SFTPHook(self.sftp_conn_id)

        if WILDCARD in self.source_object:
            total_wildcards = self.source_object.count(WILDCARD)
            if total_wildcards > 1:
                raise AirflowException(
                    "Only one wildcard '*' is allowed in source_object parameter. "
                    "Found {} in {}.".format(total_wildcards, self.source_object)
                )

            prefix, delimiter = self.source_object.split(WILDCARD, 1)
            objects = gcs_hook.list(
                self.source_bucket, prefix=prefix, delimiter=delimiter
            )

            for source_object in objects:
                destination_path = os.path.join(self.destination_path, source_object)
                self._copy_single_object(
                    gcs_hook, sftp_hook, source_object, destination_path
                )

            self.log.info(
                "Done. Uploaded '%d' files to %s", len(objects), self.destination_path
            )
        else:
            destination_path = os.path.join(self.destination_path, self.source_object)
            self._copy_single_object(
                gcs_hook, sftp_hook, self.source_object, destination_path
            )
            self.log.info(
                "Done. Uploaded '%s' file to %s", self.source_object, destination_path
            )
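The wildcard handling reduces to a single split on '*'; a tiny standalone illustration, assuming WILDCARD == '*':

# Standalone illustration of the prefix/delimiter split used above.
WILDCARD = "*"

source_object = "exports/2021/*.csv"
prefix, delimiter = source_object.split(WILDCARD, 1)
print(prefix)     # exports/2021/
print(delimiter)  # .csv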
Example #18
    def _get_credential_parameters(self, session):
        connection = session.query(Connection). \
            filter(Connection.conn_id == self.gcp_conn_id).first()
        session.expunge_all()
        if GCP_CREDENTIALS_KEY_PATH in connection.extra_dejson:
            credential_params = [
                '-credential_file',
                connection.extra_dejson[GCP_CREDENTIALS_KEY_PATH]
            ]
        elif GCP_CREDENTIALS_KEYFILE_DICT in connection.extra_dejson:
            credential_file_content = json.loads(
                connection.extra_dejson[GCP_CREDENTIALS_KEYFILE_DICT])
            self.log.info("Saving credentials to {}".format(
                self.credentials_path))
            with open(self.credentials_path, "w") as f:
                json.dump(credential_file_content, f)
            credential_params = ['-credential_file', self.credentials_path]
        else:
            self.log.info(
                "The credentials are not supplied by either key_path or "
                "keyfile_dict of the gcp connection {}. Falling back to the "
                "default activated account.".format(self.gcp_conn_id))
            credential_params = []

        if not self.instance_specification:
            project_id = connection.extra_dejson.get(
                'extra__google_cloud_platform__project')
            if self.project_id:
                project_id = self.project_id
            if not project_id:
                raise AirflowException(
                    "For forwarding all instances, the project id "
                    "for GCP should be provided either "
                    "by project_id extra in the GCP connection or by "
                    "project_id provided in the operator.")
            credential_params.extend(['-projects', project_id])
        return credential_params
Example #19
    def delete_database(self, instance_id, database_id, project_id=None):
        """
        Drops a database in Cloud Spanner.

        :type project_id: str
        :param instance_id: The ID of the Cloud Spanner instance.
        :type instance_id: str
        :param database_id: The ID of the database in Cloud Spanner.
        :type database_id: str
        :param project_id: Optional, the ID of the GCP project that owns the Cloud Spanner
            database. If set to None or missing, the default project_id from the GCP connection is used.
        :return: None
        """

        instance = self._get_client(project_id=project_id).\
            instance(instance_id=instance_id)
        if not instance.exists():
            raise AirflowException(
                "The instance {} does not exist in project {} !".format(
                    instance_id, project_id))
        database = instance.database(database_id=database_id)
        if not database.exists():
            self.log.info(
                "The database {} is already deleted from instance {}. "
                "Exiting.".format(database_id, instance_id))
            return
        try:
            operation = database.drop()  # type: Operation
        except GoogleAPICallError as e:
            self.log.error('An error occurred: %s. Exiting.', e.message)
            raise e

        if operation:
            result = operation.result()
            self.log.info(result)
        return
Example #20
    def index(self, index, doc_type, body=None):
        """Create a new document."""
        session = self.get_conn({})

        query = json.dumps(body or {}, default=dt_to_json)
        url = '{}/{}/{}'.format(self.base_url, index, doc_type)

        print('query: ' + query)
        # requests' own JSON handling cannot pass a `default` function to json.dumps,
        # so datetime values would fail to serialize; serialize manually instead.
        req = requests.Request('POST',
                               url,
                               data=query,
                               headers={'Content-Type': 'application/json'})
        prep_req = session.prepare_request(req)

        resp = session.send(prep_req)

        try:
            resp.raise_for_status()
        except requests.exceptions.HTTPError:
            logging.error("HTTP error: " + resp.reason)
            raise AirflowException(str(resp.status_code) + ":" + resp.reason)

        print('resp[{}]: {}'.format(resp.status_code, resp.text))
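dt_to_json is referenced but not shown; a plausible, hypothetical implementation of such a json.dumps default hook, matching only the name used above:

# Hypothetical dt_to_json: a json.dumps `default` hook that serializes
# datetime objects to ISO-8601 strings, as the comment above describes.
from datetime import datetime, date
import json

def dt_to_json(obj):
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

print(json.dumps({"created_at": datetime(2021, 1, 1, 12, 0)}, default=dt_to_json))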
Example #21
    def __init__(
            self,
            gen_callable=None,  # Callable
            worker_id=0,  # int
            num_workers=0,  # int
            data=None,  # dict (templated)
            conn_id=None,  # string for connection
            endpoint="",  # string for endpoint
            headers=None,  # dict with http headers
            cache="",  # directory location for storing results
            log_response=False,
            num_http_tries=1,  # int
            xcom_push=False,
            num_threads=8,  # default threading for low-compute jobs
            validate_output=None,  # callable with response as parameter
            try_number=1,
            *args,
            **kwargs):
        super(CloudRunBatchOperator, self).__init__(*args, **kwargs)
        if not callable(gen_callable):
            raise AirflowException('`gen_callable` param must be callable')

        self.gen_callable = gen_callable
        self.worker_id = worker_id
        self.num_workers = num_workers
        self.data = data
        self.conn_id = conn_id
        self.endpoint = endpoint
        self.headers = headers or {}
        self.log_response = log_response
        self.xcom_push_flag = xcom_push
        self.num_http_tries = num_http_tries
        self.num_threads = num_threads
        self.validate_output = validate_output
        self.cache = cache
        self.try_number = try_number
Example #22
    def list_pipelines(
        self,
        instance_url: str,
        artifact_name: Optional[str] = None,
        artifact_version: Optional[str] = None,
        namespace: str = "default",
    ) -> Dict[Any, Any]:
        """
        Lists Cloud Data Fusion pipelines.

        :param artifact_version: Artifact version to filter instances
        :type artifact_version: Optional[str]
        :param artifact_name: Artifact name to filter instances
        :type artifact_name: Optional[str]
        :param instance_url: Endpoint on which the REST APIs is accessible for the instance.
        :type instance_url: str
        :param namespace: If your pipeline belongs to a Basic edition instance, the namespace ID
            is always default. If your pipeline belongs to an Enterprise edition instance, you
            can create a namespace.
        :type namespace: str
        """
        url = os.path.join(instance_url, "v3", "namespaces", namespace, "apps")
        query: Dict[str, str] = {}
        if artifact_name:
            query["artifactName"] = artifact_name
        if artifact_version:
            query["artifactVersion"] = artifact_version
        if query:
            url = os.path.join(url, urlencode(query))

        response = self._cdap_request(url=url, method="GET", body=None)
        if response.status != 200:
            raise AirflowException(
                f"Listing pipelines failed with code {response.status}"
            )
        return json.loads(response.data)
Example #23
    def client(self) -> hvac.Client:
        """
        Return an authenticated Hashicorp Vault client
        """

        _client = hvac.Client(url=self.url, **self.kwargs)
        if self.auth_type == "token":
            if not self.token:
                raise VaultError("token cannot be None for auth_type='token'")
            _client.token = self.token
        elif self.auth_type == "ldap":
            _client.auth.ldap.login(username=self.username,
                                    password=self.password)
        elif self.auth_type == "userpass":
            _client.auth_userpass(username=self.username,
                                  password=self.password)
        elif self.auth_type == "approle":
            _client.auth_approle(role_id=self.role_id,
                                 secret_id=self.secret_id)
        elif self.auth_type == "github":
            _client.auth.github.login(token=self.token)
        elif self.auth_type == "gcp":
            from airflow.providers.google.cloud.utils.credentials_provider import (
                get_credentials_and_project_id, _get_scopes)
            scopes = _get_scopes(self.gcp_scopes)
            credentials, _ = get_credentials_and_project_id(
                key_path=self.gcp_key_path, scopes=scopes)
            _client.auth.gcp.configure(credentials=credentials)
        else:
            raise AirflowException(
                f"Authentication type '{self.auth_type}' not supported")

        if _client.is_authenticated():
            return _client
        else:
            raise VaultError("Vault Authentication Error!")
Example #24
    def _wait_for_operation_to_complete(self,
                                        project_id,
                                        operation_name,
                                        zone=None):
        """
        Waits for the named operation to complete - checks status of the async call.

        :param project_id: Google Cloud Platform project ID in which the operation runs
        :type project_id: str
        :param operation_name: name of the operation
        :type operation_name: str
        :param zone: optional region of the request (might be None for global operations)
        :type zone: str
        :return: None
        """
        service = self.get_conn()
        while True:
            if zone is None:
                # noinspection PyTypeChecker
                operation_response = self._check_global_operation_status(
                    service, operation_name, project_id)
            else:
                # noinspection PyTypeChecker
                operation_response = self._check_zone_operation_status(
                    service, operation_name, project_id, zone,
                    self.num_retries)
            if operation_response.get("status") == GceOperationStatus.DONE:
                error = operation_response.get("error")
                if error:
                    code = operation_response.get("httpErrorStatusCode")
                    msg = operation_response.get("httpErrorMessage")
                    # Extracting the errors list as string and trimming square braces
                    error_msg = str(error.get("errors"))[1:-1]
                    raise AirflowException("{} {}: ".format(code, msg) +
                                           error_msg)
                # No meaningful info to return from the response in case of success
                return
            time.sleep(TIME_TO_SLEEP_IN_SECONDS)
Example #25
    def stop_proxy(self):
        """
        Stops running proxy.

        You should stop the proxy after you stop using it.
        """
        if not self.sql_proxy_process:
            raise AirflowException("The sql proxy is not started yet")
        else:
            self.log.info("Stopping the cloud_sql_proxy pid: {}".format(
                self.sql_proxy_process.pid))
            self.sql_proxy_process.kill()
            self.sql_proxy_process = None
        # Cleanup!
        self.log.info("Removing the socket directory: {}".format(
            self.cloud_sql_proxy_socket_directory))
        shutil.rmtree(self.cloud_sql_proxy_socket_directory,
                      ignore_errors=True)
        if self.sql_proxy_was_downloaded:
            self.log.info("Removing downloaded proxy: {}".format(
                self.sql_proxy_path))
            # Silently ignore if the file has already been removed (concurrency)
            try:
                os.remove(self.sql_proxy_path)
            except OSError as e:
                if not e.errno == errno.ENOENT:
                    raise
        else:
            self.log.info(
                "Skipped removing proxy - it was not downloaded: {}".format(
                    self.sql_proxy_path))
        if isfile(self.credentials_path):
            self.log.info("Removing generated credentials file {}".format(
                self.credentials_path))
            # Here the file cannot be deleted by a concurrent task (each task has its own copy)
            os.remove(self.credentials_path)
Example #26
    def __init__(self,
                 path_prefix: str,
                 instance_specification: str,
                 gcp_conn_id: str = 'google_cloud_default',
                 project_id: Optional[str] = None,
                 sql_proxy_version: Optional[str] = None,
                 sql_proxy_binary_path: Optional[str] = None) -> None:
        super().__init__()
        self.path_prefix = path_prefix
        if not self.path_prefix:
            raise AirflowException("The path_prefix must not be empty!")
        self.sql_proxy_was_downloaded = False
        self.sql_proxy_version = sql_proxy_version
        self.download_sql_proxy_dir = None
        self.sql_proxy_process = None  # type: Optional[Popen]
        self.instance_specification = instance_specification
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        self.command_line_parameters = []  # type: List[str]
        self.cloud_sql_proxy_socket_directory = self.path_prefix
        self.sql_proxy_path = sql_proxy_binary_path if sql_proxy_binary_path \
            else self.path_prefix + "_cloud_sql_proxy"
        self.credentials_path = self.path_prefix + "_credentials.json"
        self._build_command_line_parameters()
Example #27
def _draw_nodes(
    node: DependencyMixin, parent_graph: graphviz.Digraph, states_by_task_id: Dict[str, str]
) -> None:
    """Draw the node and its children on the given parent_graph recursively."""
    if isinstance(node, BaseOperator):
        _draw_task(node, parent_graph, states_by_task_id)
    else:
        if not isinstance(node, TaskGroup):
            raise AirflowException(f"The node {node} should be TaskGroup and is not")
        # Draw TaskGroup
        if node.is_root:
            # No need to draw background for root TaskGroup.
            _draw_task_group(node, parent_graph, states_by_task_id)
        else:
            with parent_graph.subgraph(name=f"cluster_{node.group_id}") as sub:
                sub.attr(
                    shape="rectangle",
                    style="filled",
                    color=_refine_color(node.ui_fgcolor),
                    # Partially transparent CornflowerBlue
                    fillcolor="#6495ed7f",
                    label=node.label,
                )
                _draw_task_group(node, sub, states_by_task_id)
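The "cluster_" prefix in the subgraph name is what makes Graphviz render the TaskGroup as a boxed cluster; a minimal standalone sketch:

# Minimal graphviz sketch: only subgraph names starting with "cluster" are
# rendered as visible, boxed clusters, which is why the code above uses
# f"cluster_{node.group_id}".
import graphviz

dot = graphviz.Digraph()
with dot.subgraph(name="cluster_my_task_group") as sub:
    sub.attr(style="filled", fillcolor="#6495ed7f", label="my_task_group")
    sub.node("task_a")
    sub.node("task_b")
dot.node("outside_task")
print(dot.source)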
Example #28
    def execute(self, context):
        self._test_queries = self._load_test_queries(context)

        # actual execution
        self._actual_execution(context)

        # health checking the results
        for test_name, test_query in self._test_queries.items():
            resp = self.dw_hook.get_first(test_query)
            assertion = self._assert_resp(test_name, resp)
            self._format_assertion(assertion, test_name)

        slack_result_attachments, successful_tests_count, total_tests_count = (
            self._aggregate_test_results())
        suite_status = successful_tests_count == total_tests_count

        if not context["test_mode"]:
            # not in test mode, report
            self._report_assertions(
                attachments=slack_result_attachments,
                successful_test_count=successful_tests_count,
                total_tests_count=total_tests_count,
                context=context,
            )
        else:
            self.log.warning(
                "Reporting to Slack is disabled during test / backfill mode. "
                "Showing failure result attachments below:")

            self.log.warning(slack_result_attachments)

        if self.block_on_failure and not suite_status:
            raise AirflowException(
                "Blocking this task due to healthcheck failure. ")

        self.log.warning("Task completed.")
Example #29
    def _wait_for_operation_to_complete(self, operation_name):
        """
        Waits for the named operation to complete - checks status of the
        asynchronous call.

        :param operation_name: The name of the operation.
        :type operation_name: str
        :return: The response returned by the operation.
        :rtype: dict
        :exception: AirflowException in case error is returned.
        """
        service = self.get_conn()
        while True:
            operation_response = service.operations().get(
                name=operation_name, ).execute(num_retries=self.num_retries)
            if operation_response.get("done"):
                response = operation_response.get("response")
                error = operation_response.get("error")
                # Note, according to documentation always either response or error is
                # set when "done" == True
                if error:
                    raise AirflowException(str(error))
                return response
            time.sleep(TIME_TO_SLEEP_IN_SECONDS)
Example #30
    def execute(self, context: Dict):
        hook = GCSHook(gcp_conn_id=self.gcp_conn_id)

        with NamedTemporaryFile() as source_file, NamedTemporaryFile(
        ) as destination_file:
            self.log.info("Downloading file from %s", self.source_bucket)
            hook.download(bucket_name=self.source_bucket,
                          object_name=self.source_object,
                          filename=source_file.name)

            self.log.info("Starting the transformation")
            cmd = [self.transform_script] if isinstance(
                self.transform_script, str) else self.transform_script
            cmd += [source_file.name, destination_file.name]
            process = subprocess.Popen(args=cmd,
                                       stdout=subprocess.PIPE,
                                       stderr=subprocess.STDOUT,
                                       close_fds=True)
            self.log.info("Process output:")
            for line in iter(process.stdout.readline, b''):
                self.log.info(line.decode(self.output_encoding).rstrip())

            process.wait()
            if process.returncode > 0:
                raise AirflowException("Transform script failed: {0}".format(
                    process.returncode))

            self.log.info(
                "Transformation succeeded. Output temporarily located at %s",
                destination_file.name)

            self.log.info("Uploading file to %s as %s",
                          self.destination_bucket, self.destination_object)
            hook.upload(bucket_name=self.destination_bucket,
                        object_name=self.destination_object,
                        filename=destination_file.name)
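The transform script is invoked with the downloaded source path and the destination path appended as its final arguments; a minimal hypothetical script compatible with that contract:

#!/usr/bin/env python
# Hypothetical transform script: reads the downloaded source file, upper-cases
# each line, and writes the result to the destination path. The operator above
# appends [source_file.name, destination_file.name] to the command it runs.
import sys

if __name__ == "__main__":
    source, destination = sys.argv[1], sys.argv[2]
    with open(source) as src, open(destination, "w") as dst:
        for line in src:
            dst.write(line.upper())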