def test_success_create_training_job(self, mock_hook):
        """Execute the training operator when the hook reports state SUCCEEDED.

        The mocked hook's ``create_job`` returns a shallow copy of
        ``TRAINING_INPUT`` with ``state`` set to ``'SUCCEEDED'``. The test then
        checks how the hook class and the hook instance were called.
        """
        hook = mock_hook.return_value

        # Response handed back by the mocked create_job call.
        job_response = self.TRAINING_INPUT.copy()
        job_response['state'] = 'SUCCEEDED'
        hook.create_job.return_value = job_response

        MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS).execute(None)

        mock_hook.assert_called_once_with(
            gcp_conn_id='google_cloud_default',
            delegate_to=None,
        )
        # Exactly one call may be recorded on the hook instance: 'create_job'.
        self.assertEqual(len(hook.mock_calls), 1)
        hook.create_job.assert_called_once_with(
            project_id='test-project',
            job=self.TRAINING_INPUT,
            use_existing_job_fn=ANY,
        )
    def test_failed_job_error(self, mock_hook):
        """Check that a job ending in state FAILED raises ``RuntimeError``.

        The mocked hook's ``create_job`` returns a shallow copy of
        ``TRAINING_INPUT`` with ``state`` set to ``'FAILED'`` and
        ``errorMessage`` set to ``'A failure message'``. ``execute`` is
        expected to raise ``RuntimeError`` whose text equals that message.
        """
        failure_response = self.TRAINING_INPUT.copy()
        failure_response['state'] = 'FAILED'
        failure_response['errorMessage'] = 'A failure message'
        hook_instance = mock_hook.return_value
        hook_instance.create_job.return_value = failure_response

        # Build the operator outside the assertRaises block: only execute()
        # is expected to raise. A RuntimeError during construction must fail
        # the test instead of silently satisfying it.
        training_op = MLEngineStartTrainingJobOperator(
            **self.TRAINING_DEFAULT_ARGS)
        with self.assertRaises(RuntimeError) as context:
            training_op.execute(None)

        mock_hook.assert_called_once_with(
            gcp_conn_id='google_cloud_default', delegate_to=None)
        # Make sure only 'create_job' is invoked on hook instance
        self.assertEqual(len(hook_instance.mock_calls), 1)
        hook_instance.create_job.assert_called_once_with(
            project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY)
        self.assertEqual('A failure message', str(context.exception))
    def test_http_error(self, mock_hook):
        """Check that an ``HttpError`` raised by the hook reaches the caller.

        The mocked hook's ``create_job`` raises ``HttpError`` built from a
        response with status 403. ``execute`` is expected to raise that error,
        and the status on the caught exception is compared with 403.
        """
        http_error_code = 403
        hook_instance = mock_hook.return_value
        hook_instance.create_job.side_effect = HttpError(
            resp=httplib2.Response({
                'status': http_error_code
            }),
            content=b'Forbidden')

        # Build the operator outside the assertRaises block: only execute()
        # is expected to raise. An HttpError during construction must fail
        # the test instead of silently satisfying it.
        training_op = MLEngineStartTrainingJobOperator(
            **self.TRAINING_DEFAULT_ARGS)
        with self.assertRaises(HttpError) as context:
            training_op.execute(None)

        mock_hook.assert_called_once_with(
            gcp_conn_id='google_cloud_default', delegate_to=None)
        # Make sure only 'create_job' is invoked on hook instance
        self.assertEqual(len(hook_instance.mock_calls), 1)
        hook_instance.create_job.assert_called_once_with(
            project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY)
        self.assertEqual(http_error_code, context.exception.resp.status)
    def test_success_create_training_job_with_optional_args(self, mock_hook):
        """Execute the operator with runtime version, Python version and job dir.

        The job expected at ``create_job`` is a deep copy of ``TRAINING_INPUT``
        whose ``trainingInput`` entry additionally carries ``runtimeVersion``,
        ``pythonVersion`` and ``jobDir``. The mocked hook reports state
        ``'SUCCEEDED'``.
        """
        # Extra fields the test expects under job['trainingInput'].
        optional_fields = (
            ('runtimeVersion', '1.6'),
            ('pythonVersion', '3.5'),
            ('jobDir', 'gs://some-bucket/jobs/test_training'),
        )
        expected_job = copy.deepcopy(self.TRAINING_INPUT)
        for field_name, field_value in optional_fields:
            expected_job['trainingInput'][field_name] = field_value

        hook = mock_hook.return_value
        # Response handed back by the mocked create_job call.
        job_response = self.TRAINING_INPUT.copy()
        job_response['state'] = 'SUCCEEDED'
        hook.create_job.return_value = job_response

        MLEngineStartTrainingJobOperator(
            runtime_version='1.6',
            python_version='3.5',
            job_dir='gs://some-bucket/jobs/test_training',
            **self.TRAINING_DEFAULT_ARGS
        ).execute(None)

        mock_hook.assert_called_once_with(
            gcp_conn_id='google_cloud_default',
            delegate_to=None,
        )
        # Exactly one call may be recorded on the hook instance: 'create_job'.
        self.assertEqual(len(hook.mock_calls), 1)
        hook.create_job.assert_called_once_with(
            project_id='test-project',
            job=expected_job,
            use_existing_job_fn=ANY,
        )
Ejemplo n.º 5
0
    "params": {
        "model_name": MODEL_NAME
    }
}

# Define the "example_gcp_mlengine" DAG; operators created inside this block
# are bound to the names below. No schedule is set (schedule_interval=None).
with models.DAG(
        "example_gcp_mlengine",
        default_args=default_args,
        schedule_interval=None  # Override to match your needs
) as dag:
    # Task "training": an MLEngineStartTrainingJobOperator for PROJECT_ID in
    # region "us-central1", using the package at TRAINER_URI and the module
    # TRAINER_PY_MODULE, with no extra training arguments.
    # NOTE(review): job_id contains "{{ ... }}" placeholders (ts_nodash,
    # params.model_name) - presumably rendered by Airflow templating at run
    # time; confirm job_id is a templated field of the operator.
    training = MLEngineStartTrainingJobOperator(
        task_id="training",
        project_id=PROJECT_ID,
        region="us-central1",
        job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}",
        runtime_version="1.14",
        python_version="3.5",
        job_dir=JOB_DIR,
        package_uris=[TRAINER_URI],
        training_python_module=TRAINER_PY_MODULE,
        training_args=[],
    )

    # Task "create-model": an MLEngineManageModelOperator with
    # operation='create' for a model named MODEL_NAME in PROJECT_ID.
    create_model = MLEngineManageModelOperator(
        task_id="create-model",
        project_id=PROJECT_ID,
        operation='create',
        model={
            "name": MODEL_NAME,
        },
    )