Ejemplo n.º 1
0
 def test_fail(self, mock_hook):
     # Build the operator with operation="invalid" and expect execute()
     # to raise ValueError. ``mock_hook`` is not inspected in this test.
     operator = MLEngineModelOperator(
         task_id="task-id",
         project_id=TEST_PROJECT_ID,
         model=TEST_MODEL,
         operation="invalid",
         gcp_conn_id=TEST_GCP_CONN_ID,
         delegate_to=TEST_DELEGATE_TO,
     )
     # Callable form of assertRaises: invokes operator.execute(None).
     self.assertRaises(ValueError, operator.execute, None)
Ejemplo n.º 2
0
    def test_success_create_model(self, mock_hook):
        # With operation="create", executing the operator should build the
        # hook once with the connection settings and call create_model once
        # with the project id and model.
        operator = MLEngineModelOperator(
            task_id="task-id",
            project_id=TEST_PROJECT_ID,
            model=TEST_MODEL,
            operation="create",
            gcp_conn_id=TEST_GCP_CONN_ID,
            delegate_to=TEST_DELEGATE_TO,
        )

        operator.execute(None)

        # The hook class was instantiated exactly once with these kwargs.
        mock_hook.assert_called_once_with(
            delegate_to=TEST_DELEGATE_TO,
            gcp_conn_id=TEST_GCP_CONN_ID,
        )
        # The hook instance is whatever the patched class returned.
        hook_instance = mock_hook.return_value
        hook_instance.create_model.assert_called_once_with(
            project_id=TEST_PROJECT_ID,
            model=TEST_MODEL,
        )
Ejemplo n.º 3
0
    def test_success_get_model(self, mock_hook):
        # With operation="get", executing the operator should build the hook
        # once, call get_model once with the project id and model name, and
        # return the value get_model produced.
        operator = MLEngineModelOperator(
            task_id="task-id",
            project_id=TEST_PROJECT_ID,
            model=TEST_MODEL,
            operation="get",
            gcp_conn_id=TEST_GCP_CONN_ID,
            delegate_to=TEST_DELEGATE_TO,
        )

        returned = operator.execute(None)

        # The hook class was instantiated exactly once with these kwargs.
        mock_hook.assert_called_once_with(
            delegate_to=TEST_DELEGATE_TO,
            gcp_conn_id=TEST_GCP_CONN_ID,
        )
        # The hook instance is whatever the patched class returned.
        hook_instance = mock_hook.return_value
        hook_instance.get_model.assert_called_once_with(
            project_id=TEST_PROJECT_ID,
            model_name=TEST_MODEL_NAME,
        )
        # execute() must hand back exactly what the hook's get_model returned.
        self.assertEqual(hook_instance.get_model.return_value, returned)
Ejemplo n.º 4
0
        task_id="training",
        project_id=PROJECT_ID,
        region="us-central1",
        job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}",
        runtime_version="1.14",
        python_version="3.5",
        job_dir=JOB_DIR,
        package_uris=[TRAINER_URI],
        training_python_module=TRAINER_PY_MODULE,
        training_args=[],
    )

    # Task "create-model": MLEngineModelOperator with operation "create" for
    # a model whose only specified field is its name (MODEL_NAME).
    create_model = MLEngineModelOperator(
        task_id="create-model",
        project_id=PROJECT_ID,
        operation="create",
        model={"name": MODEL_NAME},
    )

    # Task "get-model": MLEngineModelOperator with operation "get" for the
    # model identified by MODEL_NAME.
    get_model = MLEngineModelOperator(
        task_id="get-model",
        project_id=PROJECT_ID,
        operation="get",
        model={"name": MODEL_NAME},
    )

    # Task "get-model-result": a bash step that echoes whatever the template
    # expression task_instance.xcom_pull('get-model') yields.
    get_model_result = BashOperator(
        task_id="get-model-result",
        bash_command="echo \"{{ task_instance.xcom_pull('get-model') }}\"",
    )