Example #1
    def __init__(self,
                 job_name='{{task.task_id}}_{{ds_nodash}}',
                 cluster_name="cluster-1",
                 dataproc_properties=None,
                 dataproc_jars=None,
                 gcp_conn_id='google_cloud_default',
                 delegate_to=None,
                 labels=None,
                 region='global',
                 job_error_states=None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.labels = labels
        self.job_name = job_name
        self.cluster_name = cluster_name
        self.dataproc_properties = dataproc_properties
        self.dataproc_jars = dataproc_jars
        self.region = region
        self.job_error_states = job_error_states if job_error_states is not None else {
            'ERROR'
        }

        self.hook = DataProcHook(gcp_conn_id=gcp_conn_id,
                                 delegate_to=delegate_to)
        self.job_template = None
        self.job = None
        self.dataproc_job_id = None
Example #2
class DataprocOperationBaseOperator(BaseOperator):
    """
    The base class for operators that poll on a Dataproc Operation.
    """
    @apply_defaults
    def __init__(self,
                 project_id,
                 region='global',
                 gcp_conn_id='google_cloud_default',
                 delegate_to=None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.project_id = project_id
        self.region = region
        self.hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
                                 delegate_to=self.delegate_to,
                                 api_version='v1beta2')

    def execute(self, context):
        # pylint: disable=no-value-for-parameter
        self.hook.wait(self.start())

    def start(self, context):
        """
        You are expected to override this method.
        """
        raise AirflowException('Please submit an operation')
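The snippet above defines only the polling skeleton: execute() calls start() with no arguments (hence the pylint disable) and hands the returned operation to hook.wait(). A minimal sketch of a subclass is shown below; the class name, the constructor argument, and the projects().regions().clusters().delete(...) call chain are illustrative assumptions, not the actual Airflow operators.
# Hypothetical subclass sketch; assumes the imports and the
# DataprocOperationBaseOperator definition from the example above.
class DataprocDeleteClusterExampleOperator(DataprocOperationBaseOperator):
    @apply_defaults
    def __init__(self, cluster_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cluster_name = cluster_name

    def start(self):
        # The base execute() calls start() without arguments and passes the
        # returned long-running operation to self.hook.wait().
        return self.hook.get_conn().projects().regions().clusters().delete(
            projectId=self.project_id,
            region=self.region,
            clusterName=self.cluster_name,
        ).execute()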
Example #3
class TestDataProcHook(unittest.TestCase):
    def setUp(self):
        with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
                        new=mock_init):
            self.dataproc_hook = DataProcHook()

    @mock.patch("airflow.gcp.hooks.dataproc.DataProcHook._authorize")
    @mock.patch("airflow.gcp.hooks.dataproc.build")
    def test_dataproc_client_creation(self, mock_build, mock_authorize):
        result = self.dataproc_hook.get_conn()
        mock_build.assert_called_once_with('dataproc',
                                           'v1beta2',
                                           http=mock_authorize.return_value,
                                           cache_discovery=False)
        self.assertEqual(mock_build.return_value, result)

    @mock.patch(DATAPROC_STRING.format('_DataProcJob'))
    def test_submit(self, job_mock):
        with mock.patch(
                DATAPROC_STRING.format('DataProcHook.get_conn'),
                return_value=None):
            self.dataproc_hook.submit(GCP_PROJECT_ID_HOOK_UNIT_TEST, JOB)
            job_mock.assert_called_once_with(mock.ANY,
                                             GCP_PROJECT_ID_HOOK_UNIT_TEST,
                                             JOB,
                                             GCP_REGION,
                                             job_error_states=mock.ANY,
                                             num_retries=mock.ANY)
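The tests above rely on module-level fixtures (BASE_STRING, DATAPROC_STRING, mock_init, GCP_PROJECT_ID_HOOK_UNIT_TEST, GCP_REGION, JOB) that the snippet does not show. A plausible sketch is given below; apart from DATAPROC_STRING, which can be inferred from the patch targets above, every value is an assumption rather than the real test module's content.
# Hypothetical fixture definitions assumed by the tests above; the real
# values live in Airflow's own test module.
DATAPROC_STRING = 'airflow.gcp.hooks.dataproc.{}'   # inferred from the patch targets above
BASE_STRING = 'airflow.gcp.hooks.base.{}'           # assumption: module of GoogleCloudBaseHook
GCP_PROJECT_ID_HOOK_UNIT_TEST = 'example-project'   # assumption
GCP_REGION = 'global'                               # assumption: matches the hook's default region
JOB = {'job': {'placement': {'clusterName': 'cluster-1'}}}  # assumption: minimal job dict


def mock_init(self, gcp_conn_id='google_cloud_default', delegate_to=None):
    # Stand-in for GoogleCloudBaseHook.__init__ so that constructing
    # DataProcHook() in setUp() needs no real GCP connection or credentials.
    pass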
Example #4
    def __init__(self,
                 project_id,
                 region='global',
                 gcp_conn_id='google_cloud_default',
                 delegate_to=None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.project_id = project_id
        self.region = region
        self.hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
                                 delegate_to=self.delegate_to,
                                 api_version='v1beta2')
Example #5
class TestDataProcHook(unittest.TestCase):
    def setUp(self):
        with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
                        new=mock_init):
            self.dataproc_hook = DataProcHook()

    @mock.patch(DATAPROC_STRING.format('_DataProcJob'))
    def test_submit(self, job_mock):
        with mock.patch(
                DATAPROC_STRING.format('DataProcHook.get_conn'),
                return_value=None):
            self.dataproc_hook.submit(GCP_PROJECT_ID_HOOK_UNIT_TEST, JOB)
            job_mock.assert_called_once_with(mock.ANY,
                                             GCP_PROJECT_ID_HOOK_UNIT_TEST,
                                             JOB,
                                             GCP_REGION,
                                             job_error_states=mock.ANY,
                                             num_retries=mock.ANY)
Example #6
    def setUp(self):
        with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
                        new=mock_init):
            self.dataproc_hook = DataProcHook()
Example #7
class DataProcJobBaseOperator(BaseOperator):
    """
    The base class for operators that launch a job on DataProc.

    :param job_name: The job name used in the DataProc cluster. This name by default
        is the task_id appended with the execution date, but can be templated. The
        name will always be appended with a random number to avoid name clashes.
    :type job_name: str
    :param cluster_name: The name of the DataProc cluster.
    :type cluster_name: str
    :param dataproc_properties: Map for the Hive properties. Ideal to put in
        default arguments (templated)
    :type dataproc_properties: dict
    :param dataproc_jars: HCFS URIs of jar files to add to the CLASSPATH of the Hive server and Hadoop
        MapReduce (MR) tasks. Can contain Hive SerDes and UDFs. (templated)
    :type dataproc_jars: list
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform.
    :type gcp_conn_id: str
    :param delegate_to: The account to impersonate, if any.
        For this to work, the service account making the request must have domain-wide
        delegation enabled.
    :type delegate_to: str
    :param labels: The labels to associate with this job. Label keys must contain 1 to 63 characters,
        and must conform to RFC 1035. Label values may be empty, but, if present, must contain 1 to 63
        characters, and must conform to RFC 1035. No more than 32 labels can be associated with a job.
    :type labels: dict
    :param region: The specified region where the dataproc cluster is created.
    :type region: str
    :param job_error_states: Job states that should be considered error states.
        Any states in this set will result in an error being raised and failure of the
        task. Eg, if the ``CANCELLED`` state should also be considered a task failure,
        pass in ``{'ERROR', 'CANCELLED'}``. Possible values are currently only
        ``'ERROR'`` and ``'CANCELLED'``, but could change in the future. Defaults to
        ``{'ERROR'}``.
    :type job_error_states: set
    :var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API.
        This is useful for identifying or linking to the job in the Google Cloud Console
        Dataproc UI, as the actual "jobId" submitted to the Dataproc API is appended with
        an 8 character random string.
    :vartype dataproc_job_id: str
    """
    job_type = ""

    @apply_defaults
    def __init__(self,
                 job_name='{{task.task_id}}_{{ds_nodash}}',
                 cluster_name="cluster-1",
                 dataproc_properties=None,
                 dataproc_jars=None,
                 gcp_conn_id='google_cloud_default',
                 delegate_to=None,
                 labels=None,
                 region='global',
                 job_error_states=None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.labels = labels
        self.job_name = job_name
        self.cluster_name = cluster_name
        self.dataproc_properties = dataproc_properties
        self.dataproc_jars = dataproc_jars
        self.region = region
        self.job_error_states = job_error_states if job_error_states is not None else {
            'ERROR'
        }

        self.hook = DataProcHook(gcp_conn_id=gcp_conn_id,
                                 delegate_to=delegate_to)
        self.job_template = None
        self.job = None
        self.dataproc_job_id = None

    def create_job_template(self):
        """
        Initialize `self.job_template` with default values
        """
        self.job_template = self.hook.create_job_template(
            self.task_id, self.cluster_name, self.job_type,
            self.dataproc_properties)
        self.job_template.set_job_name(self.job_name)
        self.job_template.add_jar_file_uris(self.dataproc_jars)
        self.job_template.add_labels(self.labels)

    def execute(self, context):
        if self.job_template:
            self.job = self.job_template.build()
            self.dataproc_job_id = self.job["job"]["reference"]["jobId"]
            self.hook.submit(self.hook.project_id, self.job, self.region,
                             self.job_error_states)
        else:
            raise AirflowException("Create a job template before")

    def on_kill(self):
        """
        Callback called when the operator is killed.
        Cancel any running job.
        """
        if self.dataproc_job_id:
            self.hook.cancel(self.hook.project_id, self.dataproc_job_id,
                             self.region)
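For context on how this base class is used, a concrete operator typically sets job_type, builds the template in its own execute(), and then calls the base execute() shown above. The sketch below is hypothetical: the class name, the 'hiveJob' value, the query parameter, and the add_query() builder method are assumptions for illustration, not the actual Airflow Hive operator.
# Hypothetical concrete operator built on DataProcJobBaseOperator above.
class ExampleHiveJobOperator(DataProcJobBaseOperator):
    # template_fields makes these attributes Jinja-templated by Airflow,
    # matching the "(templated)" notes in the docstring above.
    template_fields = ['query', 'job_name', 'dataproc_jars']
    job_type = 'hiveJob'  # assumption: Dataproc API key for Hive jobs

    @apply_defaults
    def __init__(self, query, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.query = query

    def execute(self, context):
        self.create_job_template()                # fills self.job_template
        self.job_template.add_query(self.query)   # assumption: builder exposes add_query()
        super().execute(context)                  # builds, submits, records dataproc_job_id
At instantiation, such an operator could be given job_error_states={'ERROR', 'CANCELLED'} so that cancelled jobs also fail the task, as described in the docstring above.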