def main(argv=None):
    """Create and train the CLV model on AutoML Tables."""
    argv = sys.argv if argv is None else argv
    args = create_parser().parse_args(args=argv[1:])

    # create and configure client
    keyfile_name = args.key_file
    client = AutoMlClient.from_service_account_file(keyfile_name)

    # create and deploy model
    model_name = create_automl_model(client,
                                     args.project_id,
                                     args.location,
                                     args.bq_dataset,
                                     args.bq_table,
                                     args.automl_dataset,
                                     args.automl_model,
                                     args.training_budget)

    # deploy model
    deploy_model(client, model_name)

    # get model evaluations
    model_evaluation = get_model_evaluation(client, model_name)

    # make predictions
    prediction_client = PredictionServiceClient.from_service_account_file(
        keyfile_name)
    do_batch_prediction(prediction_client,
                        model_name,
                        args.batch_gcs_input,
                        args.batch_gcs_output)
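# --- Illustrative sketch (not from the original source) ---
# deploy_model() is called above but not defined in this excerpt. A minimal
# version might look like the following, assuming the automl_v1beta1
# AutoMlClient API, where deploy_model(name) returns a long-running operation
# that can be waited on.
def deploy_model(client, model_name):
    # Start deployment and block until the model is ready to serve predictions.
    deploy_operation = client.deploy_model(model_name)
    deploy_operation.result()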
def test_list_column_specs(self, mock_list_column_specs):
    table_spec = "table_spec_id"
    filter_ = "filter"
    page_size = 42

    self.hook.list_column_specs(
        dataset_id=DATASET_ID,
        table_spec_id=table_spec,
        location=GCP_LOCATION,
        project_id=GCP_PROJECT_ID,
        field_mask=MASK,
        filter_=filter_,
        page_size=page_size,
    )

    parent = AutoMlClient.table_spec_path(GCP_PROJECT_ID, GCP_LOCATION,
                                          DATASET_ID, table_spec)
    mock_list_column_specs.assert_called_once_with(
        parent=parent,
        field_mask=MASK,
        filter_=filter_,
        page_size=page_size,
        retry=None,
        timeout=None,
        metadata=None,
    )
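# --- Illustrative sketch (not from the original source) ---
# For context, the hook method exercised by the test above might look roughly
# like this; the parameter names are taken from the call in the test, while the
# body is an assumption about how CloudAutoMLHook delegates to the client.
def list_column_specs(self, dataset_id, table_spec_id, location, project_id=None,
                      field_mask=None, filter_=None, page_size=None,
                      retry=None, timeout=None, metadata=None):
    client = self.get_conn()
    parent = client.table_spec_path(project_id, location, dataset_id,
                                    table_spec_id)
    return client.list_column_specs(parent=parent,
                                    field_mask=field_mask,
                                    filter_=filter_,
                                    page_size=page_size,
                                    retry=retry,
                                    timeout=timeout,
                                    metadata=metadata)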
def export_dataset(project: str, dataset: str, bucket: str):
    # Note: dataset_file_name and compute_region are defined at module level
    # elsewhere in this script.
    if os.path.isfile(dataset_file_name):
        logging.info('Dataset already downloaded, no export done.')
        return dataset_file_name

    client = AutoMlClient()
    export_path = 'gs://{}/export/export_{}'.format(bucket, dataset)
    output_config = {"gcs_destination": {"output_uri_prefix": export_path}}
    dataset_name = client.dataset_path(project, compute_region, dataset)
    export_operation = client.export_data(dataset_name, output_config)

    logging.info('Waiting for the export to complete...')
    export_operation.result()

    logging.info('Downloading exported csv...')
    download_training_csv(bucket,
                          'export/export_{}/export.csv'.format(dataset),
                          dataset_file_name)
    return dataset_file_name
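# --- Illustrative sketch (not from the original source) ---
# download_training_csv() is referenced above but not shown. A minimal version
# using the google-cloud-storage client could look like this; the function name
# and signature are inferred from the call site.
from google.cloud import storage

def download_training_csv(bucket_name, source_blob_name, destination_file_name):
    # Fetch the exported CSV from Cloud Storage and write it to a local file.
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(source_blob_name)
    blob.download_to_filename(destination_file_name)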
def get_conn(self) -> AutoMlClient:
    """
    Retrieves connection to AutoML.

    :return: Google Cloud AutoML client object.
    :rtype: google.cloud.automl_v1beta1.AutoMlClient
    """
    if self._client is None:
        self._client = AutoMlClient(credentials=self._get_credentials(),
                                    client_info=self.client_info)
    return self._client
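# --- Illustrative usage (not from the original source) ---
# Example of using the cached client returned by get_conn() to list AutoML
# datasets; the constructor argument and the project/region values are
# assumptions for illustration.
hook = CloudAutoMLHook(gcp_conn_id="google_cloud_default")
client = hook.get_conn()
parent = client.location_path("my-project", "us-central1")
for dataset in client.list_datasets(parent):
    print(dataset.name)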
CREDENTIALS = "test-creds" TASK_ID = "test-automl-hook" GCP_PROJECT_ID = "test-project" GCP_LOCATION = "test-location" MODEL_NAME = "test_model" MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152" DATASET_ID = "TBL123456789" MODEL = { "display_name": MODEL_NAME, "dataset_id": DATASET_ID, "tables_model_metadata": { "train_budget_milli_node_hours": 1000 }, } LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION) MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION, MODEL_ID) INPUT_CONFIG = {"input": "value"} OUTPUT_CONFIG = {"output": "value"} PAYLOAD = {"test": "payload"} DATASET = {"dataset_id": "data"} MASK = {"field": "mask"} class TestAutoMLTrainModelOperator(unittest.TestCase): @mock.patch( "airflow.gcp.operators.automl.AutoMLTrainModelOperator.xcom_push") @mock.patch("airflow.gcp.hooks.automl.CloudAutoMLHook.create_model") @mock.patch(
CLIENT_INFO = "client-info" TASK_ID = "test-automl-hook" GCP_PROJECT_ID = "test-project" GCP_LOCATION = "test-location" MODEL_NAME = "test_model" MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152" DATASET_ID = "TBL123456789" MODEL = { "display_name": MODEL_NAME, "dataset_id": DATASET_ID, "tables_model_metadata": { "train_budget_milli_node_hours": 1000 }, } LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION) MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION, MODEL_ID) DATASET_PATH = AutoMlClient.dataset_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID) INPUT_CONFIG = {"input": "value"} OUTPUT_CONFIG = {"output": "value"} PAYLOAD = {"test": "payload"} DATASET = {"dataset_id": "data"} MASK = {"field": "mask"} class TestAuoMLHook(unittest.TestCase): def setUp(self) -> None: with mock.patch(
AUTOML_MODEL = models.Variable.get('automl_model')
AUTOML_TRAINING_BUDGET = int(models.Variable.get('automl_training_budget'))

# [START dag_build_train_deploy]
default_dag_args = {
    'start_date': datetime.datetime(2050, 1, 1),
    'schedule_interval': None,
    'provide_context': True
}

dag = models.DAG('build_train_deploy', default_args=default_dag_args)
# [END dag_build_train_deploy]

# Instantiate a Google Cloud base hook to get credentials and create AutoML clients.
gcp_hook = GoogleCloudBaseHook(conn_id='google_cloud_default')
automl_client = AutoMlClient(credentials=gcp_hook._get_credentials())

# Loads the database dump from Cloud Storage to BigQuery
t1 = gcs_to_bq.GoogleCloudStorageToBigQueryOperator(
    task_id="db_dump_to_bigquery",
    bucket=COMPOSER_BUCKET_NAME,
    source_objects=[DB_DUMP_FILENAME],
    schema_object="schema_source.json",
    source_format="CSV",
    skip_leading_rows=1,
    destination_project_dataset_table="{}.{}.{}".format(
        PROJECT, DATASET, 'data_source'),
    create_disposition="CREATE_IF_NEEDED",
    write_disposition="WRITE_TRUNCATE",
    dag=dag)
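# --- Illustrative sketch (not from the original source) ---
# A downstream task could wrap the AutoML training step from the clv_automl
# module in a PythonOperator and run after the BigQuery load; the import, the
# region, and the BigQuery/AutoML names below are placeholders, not values
# taken from the original DAG.
from airflow.operators import python_operator

def _train_automl_model(**kwargs):
    # Reuses the create_automl_model() signature shown in the training script.
    clv_automl.create_automl_model(automl_client,
                                   PROJECT,
                                   'us-central1',          # placeholder region
                                   DATASET,
                                   'features_table',       # placeholder BigQuery table
                                   'clv_automl_dataset',   # placeholder AutoML dataset name
                                   AUTOML_MODEL,
                                   AUTOML_TRAINING_BUDGET)

t2 = python_operator.PythonOperator(
    task_id='train_automl_model',
    python_callable=_train_automl_model,
    dag=dag)

# Train only after the source data has been loaded into BigQuery.
t1 >> t2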
from airflow.contrib.operators import mlengine_operator_utils
from airflow.contrib.operators import dataflow_operator
from airflow.contrib.operators import gcs_to_bq
# TODO Add when Composer on v2.0 and more Hook
# from airflow.contrib.operators import gcs_list_operator
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
from airflow.hooks.base_hook import BaseHook
from airflow.utils import trigger_rule

from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient

from clv_automl import clv_automl

# Instantiate a Google Cloud base hook to get credentials and create AutoML clients.
gcp_credentials = GoogleCloudBaseHook(
    conn_id='google_cloud_default')._get_credentials()
automl_client = AutoMlClient(credentials=gcp_credentials)
automl_predict_client = PredictionServiceClient(credentials=gcp_credentials)


def _get_project_id():
    """Get project ID from default GCP connection."""
    extras = BaseHook.get_connection('google_cloud_default').extra_dejson
    key = 'extra__google_cloud_platform__project'
    if key in extras:
        project_id = extras[key]
    else:
        raise ValueError('Must configure project_id in google_cloud_default '
                         'connection from Airflow Console')
    return project_id
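# --- Illustrative usage (not from the original source) ---
# The helper above would typically supply the project id used when building
# resource paths for the clients created earlier; the region here is assumed.
PROJECT = _get_project_id()
LOCATION = 'us-central1'  # assumed AutoML region
parent = automl_client.location_path(PROJECT, LOCATION)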