class TestDataFlowHook(unittest.TestCase):
    """Tests for ``DataFlowHook`` and ``_Dataflow`` with collaborators patched.

    The launch tests replace ``get_conn``, ``_Dataflow``, ``_DataflowJob`` and
    ``uuid.uuid4`` with mocks and only inspect the command list handed to the
    patched ``_Dataflow``.
    """

    def setUp(self):
        """Build the hook under test with the base hook ``__init__`` patched."""
        with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
                        new=mock_init):
            self.dataflow_hook = DataFlowHook(gcp_conn_id='test')

    @staticmethod
    def _stub_launch_mocks(mock_conn, mock_dataflow, mock_dataflowjob,
                           mock_uuid):
        """Configure the return values shared by every ``start_*`` test."""
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        mock_dataflow.return_value.wait_for_done.return_value = None
        mock_dataflowjob.return_value.wait_for_done.return_value = None

    def _assert_launch_command(self, mock_dataflow, expected_cmd):
        """Check the command passed to ``mock_dataflow``, ignoring order.

        The command is the first positional argument of the last call; both
        sides are sorted so option ordering does not matter.
        """
        actual_cmd = mock_dataflow.call_args[0][0]
        self.assertListEqual(sorted(actual_cmd), sorted(expected_cmd))

    @mock.patch("airflow.gcp.hooks.dataflow.DataFlowHook._authorize")
    @mock.patch("airflow.gcp.hooks.dataflow.build")
    def test_dataflow_client_creation(self, mock_build, mock_authorize):
        """``get_conn`` must build the 'dataflow' v1b3 client and return it."""
        result = self.dataflow_hook.get_conn()
        mock_build.assert_called_once_with('dataflow',
                                           'v1b3',
                                           http=mock_authorize.return_value,
                                           cache_discovery=False)
        self.assertEqual(mock_build.return_value, result)

    @parameterized.expand([('default_to_python2', "python2"),
                           ('major_version_2', 'python2'),
                           ('major_version_3', 'python3'),
                           ('minor_version', 'python3.6')])
    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
    @mock.patch(DATAFLOW_STRING.format('_Dataflow'))
    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
    def test_start_python_dataflow(self, name, py, mock_conn, mock_dataflow,
                                   mock_dataflowjob, mock_uuid):
        """The Python launch command must start with the chosen interpreter."""
        del name  # unused variable
        self._stub_launch_mocks(mock_conn, mock_dataflow, mock_dataflowjob,
                                mock_uuid)
        self.dataflow_hook.start_python_dataflow(job_name=JOB_NAME,
                                                 variables=DATAFLOW_OPTIONS_PY,
                                                 dataflow=PY_FILE,
                                                 py_options=PY_OPTIONS,
                                                 py_interpreter=py)
        expected_interpreter = py if py else DEFAULT_PY_INTERPRETER
        self._assert_launch_command(mock_dataflow, [
            expected_interpreter, '-m', PY_FILE, '--region=us-central1',
            '--runner=DataflowRunner', '--project=test', '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ])

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
    @mock.patch(DATAFLOW_STRING.format('_Dataflow'))
    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
    def test_start_java_dataflow(self, mock_conn, mock_dataflow,
                                 mock_dataflowjob, mock_uuid):
        """Without a job class the jar must be launched via ``java -jar``."""
        self._stub_launch_mocks(mock_conn, mock_dataflow, mock_dataflowjob,
                                mock_uuid)
        self.dataflow_hook.start_java_dataflow(job_name=JOB_NAME,
                                               variables=DATAFLOW_OPTIONS_JAVA,
                                               jar=JAR_FILE)
        self._assert_launch_command(mock_dataflow, [
            'java', '-jar', JAR_FILE, '--region=us-central1',
            '--runner=DataflowRunner', '--project=test',
            '--stagingLocation=gs://test/staging', '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ])

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
    @mock.patch(DATAFLOW_STRING.format('_Dataflow'))
    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
    def test_start_java_dataflow_with_job_class(self, mock_conn, mock_dataflow,
                                                mock_dataflowjob, mock_uuid):
        """With a job class the jar must be launched via ``java -cp``."""
        self._stub_launch_mocks(mock_conn, mock_dataflow, mock_dataflowjob,
                                mock_uuid)
        self.dataflow_hook.start_java_dataflow(job_name=JOB_NAME,
                                               variables=DATAFLOW_OPTIONS_JAVA,
                                               jar=JAR_FILE,
                                               job_class=JOB_CLASS)
        self._assert_launch_command(mock_dataflow, [
            'java', '-cp', JAR_FILE, JOB_CLASS, '--region=us-central1',
            '--runner=DataflowRunner', '--project=test',
            '--stagingLocation=gs://test/staging', '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ])

    @mock.patch('airflow.gcp.hooks.dataflow._Dataflow.log')
    @mock.patch('subprocess.Popen')
    @mock.patch('select.select')
    def test_dataflow_wait_for_done_logging(self, mock_select, mock_popen,
                                            mock_logging):
        """``_Dataflow`` logs its command and fails when the process fails."""
        mock_logging.info = MagicMock()
        mock_logging.warning = MagicMock()

        mock_stderr_fd = MagicMock()
        mock_proc = MagicMock()
        mock_proc.stderr = MagicMock()
        mock_proc.stderr.readlines = MagicMock(
            return_value=['test\n', 'error\n'])
        mock_proc.stderr.fileno = MagicMock(return_value=mock_stderr_fd)
        # Simulate "still running" followed by "exited with status 1".
        # Plain values are used on purpose: a callable placed inside a
        # side_effect list is handed back as the return value instead of
        # being invoked, so it cannot be used to set the exit status.
        mock_proc.poll = MagicMock(side_effect=[None, 1])
        # ``returncode`` is the attribute subprocess.Popen exposes.
        mock_proc.returncode = 1

        mock_select.return_value = [[mock_stderr_fd]]
        mock_popen.return_value = mock_proc

        dataflow = _Dataflow(['test', 'cmd'])
        mock_logging.info.assert_called_once_with('Running command: %s',
                                                  'test cmd')
        self.assertRaises(Exception, dataflow.wait_for_done)

    def test_valid_dataflow_job_name(self):
        """A valid name is returned unchanged when nothing is appended."""
        job_name = self.dataflow_hook._build_dataflow_job_name(
            job_name=JOB_NAME, append_job_name=False)

        self.assertEqual(job_name, JOB_NAME)

    def test_fix_underscore_in_job_name(self):
        """Underscores in the requested name come back as hyphens."""
        job_name_with_underscore = 'test_example'
        fixed_job_name = job_name_with_underscore.replace('_', '-')
        job_name = self.dataflow_hook._build_dataflow_job_name(
            job_name=job_name_with_underscore, append_job_name=False)

        self.assertEqual(job_name, fixed_job_name)

    def test_invalid_dataflow_job_name(self):
        """An invalid name raises ``ValueError`` quoting the hyphenated name."""
        invalid_job_name = '9test_invalid_name'
        fixed_name = invalid_job_name.replace('_', '-')

        with self.assertRaises(ValueError) as e:
            self.dataflow_hook._build_dataflow_job_name(
                job_name=invalid_job_name, append_job_name=False)
        # The message must contain the name after underscore replacement.
        self.assertIn('Invalid job_name ({})'.format(fixed_name),
                      str(e.exception))

    def test_dataflow_job_regex_check(self):
        """Accepted names pass through unchanged; rejected names raise."""
        for valid_name in ('df-job-1', 'df-job', 'dfjob', 'dfjob1'):
            self.assertEqual(
                self.dataflow_hook._build_dataflow_job_name(
                    job_name=valid_name, append_job_name=False),
                valid_name)

        for invalid_name in ('1dfjob', 'dfjob@', 'df^jo'):
            self.assertRaises(ValueError,
                              self.dataflow_hook._build_dataflow_job_name,
                              job_name=invalid_name,
                              append_job_name=False)
# Example #2
# 0
class TestDataFlowHook(unittest.TestCase):
    """Tests for ``DataFlowHook`` with its collaborators patched.

    The launch tests replace ``get_conn``, ``_DataflowRunner``,
    ``_DataflowJobsController`` and ``uuid.uuid4`` with mocks and only inspect
    the command list handed to the patched ``_DataflowRunner``.
    """

    def setUp(self):
        """Build the hook under test with the base hook ``__init__`` patched."""
        with mock.patch(BASE_STRING.format('CloudBaseHook.__init__'),
                        new=mock_init):
            self.dataflow_hook = DataFlowHook(gcp_conn_id='test')

    @staticmethod
    def _stub_launch_mocks(mock_conn, mock_dataflow, mock_dataflowjob,
                           mock_uuid):
        """Configure the return values shared by every ``start_*`` test."""
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        mock_dataflow.return_value.wait_for_done.return_value = None
        mock_dataflowjob.return_value.wait_for_done.return_value = None

    def _assert_launch_command(self, mock_dataflow, expected_cmd):
        """Check the command passed to ``mock_dataflow``, ignoring order.

        The command is the first positional argument of the last call; both
        sides are sorted so option ordering does not matter.
        """
        actual_cmd = mock_dataflow.call_args[0][0]
        self.assertListEqual(sorted(actual_cmd), sorted(expected_cmd))

    @mock.patch("airflow.gcp.hooks.dataflow.DataFlowHook._authorize")
    @mock.patch("airflow.gcp.hooks.dataflow.build")
    def test_dataflow_client_creation(self, mock_build, mock_authorize):
        """``get_conn`` must build the 'dataflow' v1b3 client and return it."""
        result = self.dataflow_hook.get_conn()
        mock_build.assert_called_once_with('dataflow',
                                           'v1b3',
                                           http=mock_authorize.return_value,
                                           cache_discovery=False)
        self.assertEqual(mock_build.return_value, result)

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
    def test_start_python_dataflow(self, mock_conn, mock_dataflow,
                                   mock_dataflowjob, mock_uuid):
        """With no interpreter given, the command is expected to use python2."""
        self._stub_launch_mocks(mock_conn, mock_dataflow, mock_dataflowjob,
                                mock_uuid)
        self.dataflow_hook.start_python_dataflow(job_name=JOB_NAME,
                                                 variables=DATAFLOW_OPTIONS_PY,
                                                 dataflow=PY_FILE,
                                                 py_options=PY_OPTIONS)
        self._assert_launch_command(mock_dataflow, [
            "python2", '-m', PY_FILE, '--region=us-central1',
            '--runner=DataflowRunner', '--project=test', '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ])

    @parameterized.expand([('default_to_python2', "python2"),
                           ('major_version_2', 'python2'),
                           ('major_version_3', 'python3'),
                           ('minor_version', 'python3.6')])
    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
    def test_start_python_dataflow_with_custom_interpreter(
            self, name, py_interpreter, mock_conn, mock_dataflow,
            mock_dataflowjob, mock_uuid):
        """An explicit ``py_interpreter`` must lead the launch command."""
        del name  # unused variable
        self._stub_launch_mocks(mock_conn, mock_dataflow, mock_dataflowjob,
                                mock_uuid)
        self.dataflow_hook.start_python_dataflow(job_name=JOB_NAME,
                                                 variables=DATAFLOW_OPTIONS_PY,
                                                 dataflow=PY_FILE,
                                                 py_options=PY_OPTIONS,
                                                 py_interpreter=py_interpreter)
        self._assert_launch_command(mock_dataflow, [
            py_interpreter, '-m', PY_FILE, '--region=us-central1',
            '--runner=DataflowRunner', '--project=test', '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ])

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
    def test_start_java_dataflow(self, mock_conn, mock_dataflow,
                                 mock_dataflowjob, mock_uuid):
        """Without a job class the jar must be launched via ``java -jar``."""
        self._stub_launch_mocks(mock_conn, mock_dataflow, mock_dataflowjob,
                                mock_uuid)
        self.dataflow_hook.start_java_dataflow(job_name=JOB_NAME,
                                               variables=DATAFLOW_OPTIONS_JAVA,
                                               jar=JAR_FILE)
        self._assert_launch_command(mock_dataflow, [
            'java', '-jar', JAR_FILE, '--region=us-central1',
            '--runner=DataflowRunner', '--project=test',
            '--stagingLocation=gs://test/staging', '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ])

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
    def test_start_java_dataflow_with_job_class(self, mock_conn, mock_dataflow,
                                                mock_dataflowjob, mock_uuid):
        """With a job class the jar must be launched via ``java -cp``."""
        self._stub_launch_mocks(mock_conn, mock_dataflow, mock_dataflowjob,
                                mock_uuid)
        self.dataflow_hook.start_java_dataflow(job_name=JOB_NAME,
                                               variables=DATAFLOW_OPTIONS_JAVA,
                                               jar=JAR_FILE,
                                               job_class=JOB_CLASS)
        self._assert_launch_command(mock_dataflow, [
            'java', '-cp', JAR_FILE, JOB_CLASS, '--region=us-central1',
            '--runner=DataflowRunner', '--project=test',
            '--stagingLocation=gs://test/staging', '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ])

    @parameterized.expand([
        (JOB_NAME, JOB_NAME, False),
        ('test-example', 'test_example', False),
        ('test-dataflow-pipeline-12345678', JOB_NAME, True),
        ('test-example-12345678', 'test_example', True),
        ('df-job-1', 'df-job-1', False),
        ('df-job', 'df-job', False),
        ('dfjob', 'dfjob', False),
        ('dfjob1', 'dfjob1', False),
    ])
    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
    def test_valid_dataflow_job_name(self, expected_result, job_name,
                                     append_job_name, mock_uuid4):
        """Each (expected, input, append) triple must produce ``expected``."""
        del mock_uuid4  # only patched so the appended suffix is predictable
        built_name = self.dataflow_hook._build_dataflow_job_name(
            job_name=job_name, append_job_name=append_job_name)

        self.assertEqual(expected_result, built_name)

    # "1dfjob" isolates the leading-digit rule; "1dfjob@" alone would also be
    # rejected for its illegal character and so could not prove that rule.
    @parameterized.expand([("1dfjob", ), ("1dfjob@", ), ("dfjob@", ),
                           ("df^jo", )])
    def test_build_dataflow_job_name_with_invalid_value(self, job_name):
        """Every listed name must be rejected with ``ValueError``."""
        self.assertRaises(ValueError,
                          self.dataflow_hook._build_dataflow_job_name,
                          job_name=job_name,
                          append_job_name=False)