def main(argv=None):
  """Create and train the CLV model on AutoML Tables."""
  argv = sys.argv if argv is None else argv
  args = create_parser().parse_args(args=argv[1:])

  # create and configure client
  keyfile_name = args.key_file
  client = AutoMlClient.from_service_account_file(keyfile_name)

  # create and train model
  model_name = create_automl_model(client,
                                   args.project_id,
                                   args.location,
                                   args.bq_dataset,
                                   args.bq_table,
                                   args.automl_dataset,
                                   args.automl_model,
                                   args.training_budget)

  # deploy model
  deploy_model(client, model_name)

  # get model evaluations
  model_evaluation = get_model_evaluation(client, model_name)

  # make predictions
  prediction_client = PredictionServiceClient.from_service_account_file(
      keyfile_name)
  do_batch_prediction(prediction_client,
                      model_name,
                      args.batch_gcs_input,
                      args.batch_gcs_output)
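
For context, a possible invocation of this entry point; the flag spellings below are assumptions inferred from the argument attributes used above, since create_parser() is not part of this snippet.

# Hypothetical CLI flags; the real create_parser() may name them differently.
main([
    'clv_automl.py',
    '--key_file', 'service-account.json',
    '--project_id', 'my-project',
    '--location', 'us-central1',
    '--bq_dataset', 'clv_dataset',
    '--bq_table', 'features',
    '--automl_dataset', 'clv_solution',
    '--automl_model', 'clv_model',
    '--training_budget', '1000',
    '--batch_gcs_input', 'gs://my-bucket/to_predict.csv',
    '--batch_gcs_output', 'gs://my-bucket/predictions',
])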
Example 2
    def test_list_column_specs(self, mock_list_column_specs):
        table_spec = "table_spec_id"
        filter_ = "filter"
        page_size = 42

        self.hook.list_column_specs(
            dataset_id=DATASET_ID,
            table_spec_id=table_spec,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
            field_mask=MASK,
            filter_=filter_,
            page_size=page_size,
        )

        parent = AutoMlClient.table_spec_path(GCP_PROJECT_ID, GCP_LOCATION,
                                              DATASET_ID, table_spec)
        mock_list_column_specs.assert_called_once_with(
            parent=parent,
            field_mask=MASK,
            filter_=filter_,
            page_size=page_size,
            retry=None,
            timeout=None,
            metadata=None,
        )
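
For reference, AutoMlClient.table_spec_path follows the v1beta1 resource naming, so with the test constants above the expected parent is roughly the string sketched below (illustrative only).

# projects/{project}/locations/{location}/datasets/{dataset}/tableSpecs/{table_spec}
EXPECTED_PARENT = ('projects/test-project/locations/test-location'
                   '/datasets/TBL123456789/tableSpecs/table_spec_id')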
Example 3
def export_dataset(project: str, dataset: str, bucket: str):
    """Export an AutoML dataset to GCS and download the exported CSV locally.

    Relies on names assumed to be defined at module level in the original
    file: compute_region, dataset_file_name and download_training_csv.
    """
    # Skip the export if the CSV already exists on disk.
    if os.path.isfile(dataset_file_name):
        logging.info('Dataset already downloaded, no export done.')
        return dataset_file_name
    client = AutoMlClient()
    export_path = 'gs://{}/export/export_{}'.format(bucket, dataset)
    output_config = {"gcs_destination": {"output_uri_prefix": export_path}}
    dataset_name = client.dataset_path(project, compute_region, dataset)
    # export_data returns a long-running operation; block until it completes.
    export_operation = client.export_data(dataset_name, output_config)
    logging.info('Waiting for the export to complete...')
    export_operation.result()
    logging.info('Downloading exported csv...')
    download_training_csv(bucket,
                          'export/export_{}/export.csv'.format(dataset),
                          dataset_file_name)
    return dataset_file_name
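
download_training_csv is defined elsewhere in the original module; a minimal sketch of such a helper, assuming it wraps the google.cloud.storage client, could look like this.

from google.cloud import storage


def download_training_csv(bucket_name, source_blob_name, destination_file_name):
    """Download a single exported CSV object from GCS to a local file."""
    storage_client = storage.Client()
    blob = storage_client.bucket(bucket_name).blob(source_blob_name)
    blob.download_to_filename(destination_file_name)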
Example 4
    def get_conn(self) -> AutoMlClient:
        """
        Retrieves connection to AutoML.

        :return: Google Cloud AutoML client object.
        :rtype: google.cloud.automl_v1beta1.AutoMlClient
        """
        if self._client is None:
            self._client = AutoMlClient(credentials=self._get_credentials(), client_info=self.client_info)
        return self._client
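
A brief usage sketch, assuming this method belongs to the CloudAutoMLHook used in the other examples; later calls reuse the cached client.

hook = CloudAutoMLHook()
client = hook.get_conn()           # first call instantiates AutoMlClient
assert client is hook.get_conn()   # subsequent calls return the cached instance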
CREDENTIALS = "test-creds"
TASK_ID = "test-automl-hook"
GCP_PROJECT_ID = "test-project"
GCP_LOCATION = "test-location"
MODEL_NAME = "test_model"
MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152"
DATASET_ID = "TBL123456789"
MODEL = {
    "display_name": MODEL_NAME,
    "dataset_id": DATASET_ID,
    "tables_model_metadata": {
        "train_budget_milli_node_hours": 1000
    },
}

LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION)
MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION,
                                                MODEL_ID)

INPUT_CONFIG = {"input": "value"}
OUTPUT_CONFIG = {"output": "value"}
PAYLOAD = {"test": "payload"}
DATASET = {"dataset_id": "data"}
MASK = {"field": "mask"}


class TestAutoMLTrainModelOperator(unittest.TestCase):
    @mock.patch(
        "airflow.gcp.operators.automl.AutoMLTrainModelOperator.xcom_push")
    @mock.patch("airflow.gcp.hooks.automl.CloudAutoMLHook.create_model")
    @mock.patch(
Example 6
CLIENT_INFO = "client-info"
TASK_ID = "test-automl-hook"
GCP_PROJECT_ID = "test-project"
GCP_LOCATION = "test-location"
MODEL_NAME = "test_model"
MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152"
DATASET_ID = "TBL123456789"
MODEL = {
    "display_name": MODEL_NAME,
    "dataset_id": DATASET_ID,
    "tables_model_metadata": {
        "train_budget_milli_node_hours": 1000
    },
}

LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION)
MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION,
                                                MODEL_ID)
DATASET_PATH = AutoMlClient.dataset_path(GCP_PROJECT_ID, GCP_LOCATION,
                                         DATASET_ID)

INPUT_CONFIG = {"input": "value"}
OUTPUT_CONFIG = {"output": "value"}
PAYLOAD = {"test": "payload"}
DATASET = {"dataset_id": "data"}
MASK = {"field": "mask"}


class TestAutoMLHook(unittest.TestCase):
    def setUp(self) -> None:
        with mock.patch(
Example 7
AUTOML_MODEL = models.Variable.get('automl_model')
AUTOML_TRAINING_BUDGET = int(models.Variable.get('automl_training_budget'))

#[START dag_build_train_deploy]
default_dag_args = {
    'start_date': datetime.datetime(2050, 1, 1),
    'schedule_interval': None,
    'provide_context': True
}

dag = models.DAG('build_train_deploy', default_args=default_dag_args)
#[END dag_build_train_deploy]

# instantiate Google Cloud base hook to get credentials and create automl clients
gcp_hook = GoogleCloudBaseHook(conn_id='google_cloud_default')
automl_client = AutoMlClient(credentials=gcp_hook._get_credentials())

# Loads the database dump from Cloud Storage to BigQuery
t1 = gcs_to_bq.GoogleCloudStorageToBigQueryOperator(
    task_id="db_dump_to_bigquery",
    bucket=COMPOSER_BUCKET_NAME,
    source_objects=[DB_DUMP_FILENAME],
    schema_object="schema_source.json",
    source_format="CSV",
    skip_leading_rows=1,
    destination_project_dataset_table="{}.{}.{}".format(
        PROJECT, DATASET, 'data_source'),
    create_disposition="CREATE_IF_NEEDED",
    write_disposition="WRITE_TRUNCATE",
    dag=dag)
Example 8
from airflow.contrib.operators import mlengine_operator_utils
from airflow.contrib.operators import dataflow_operator
from airflow.contrib.operators import gcs_to_bq
# TODO: add when Composer runs Airflow 2.0 and more hooks are available
# from airflow.contrib.operators import gcs_list_operator
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
from airflow.hooks.base_hook import BaseHook
from airflow.utils import trigger_rule

from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
from clv_automl import clv_automl

# instantiate Google Cloud base hook to get credentials and create automl clients
gcp_credentials = GoogleCloudBaseHook(
    conn_id='google_cloud_default')._get_credentials()
automl_client = AutoMlClient(credentials=gcp_credentials)
automl_predict_client = PredictionServiceClient(credentials=gcp_credentials)


def _get_project_id():
    """Get project ID from default GCP connection."""

    extras = BaseHook.get_connection('google_cloud_default').extra_dejson
    key = 'extra__google_cloud_platform__project'
    if key in extras:
        project_id = extras[key]
    else:
        raise ValueError('Must configure project_id in google_cloud_default '
                         'connection from Airflow Console')
    return project_id
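
The operators above also reference module-level constants (PROJECT, DATASET, COMPOSER_BUCKET_NAME, DB_DUMP_FILENAME) that this excerpt does not show; a hedged sketch of how they could be resolved, following the models.Variable.get pattern already used for AUTOML_MODEL (the Variable keys are illustrative assumptions).

from airflow import models  # assumed import, matching the Variable usage above

# Assumed module-level configuration; the Variable keys below are illustrative.
PROJECT = _get_project_id()
DATASET = models.Variable.get('bq_dataset')
COMPOSER_BUCKET_NAME = models.Variable.get('composer_bucket_name')
DB_DUMP_FILENAME = models.Variable.get('db_dump_filename')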