def setUp(self) -> None:
    """Build a CloudAutoMLHook whose base __init__ is stubbed so no real GCP auth runs."""
    base_init = "airflow.gcp.hooks.automl.GoogleCloudBaseHook.__init__"
    with mock.patch(base_init, new=mock_base_gcp_hook_no_default_project_id):
        self.hook = CloudAutoMLHook()
        self.hook._get_credentials = mock.MagicMock(  # type: ignore
            return_value=CREDENTIALS
        )
def execute(self, context):
    """Fetch a single AutoML model and return it serialized as a dict."""
    automl_hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id)
    model = automl_hook.get_model(
        model_id=self.model_id,
        location=self.location,
        project_id=self.project_id,
        retry=self.retry,
        timeout=self.timeout,
        metadata=self.metadata,
    )
    # Convert the protobuf response to a plain dict before returning.
    return MessageToDict(model)
def execute(self, context):
    """Apply ``self.update_mask`` to an existing AutoML dataset and return the result."""
    automl_hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id)
    self.log.info("Updating AutoML dataset %s.", self.dataset["name"])
    updated = automl_hook.update_dataset(
        dataset=self.dataset,
        update_mask=self.update_mask,
        project_id=self.project_id,
        retry=self.retry,
        timeout=self.timeout,
        metadata=self.metadata,
    )
    self.log.info("Dataset updated.")
    return MessageToDict(updated)
def execute(self, context):
    """Delete every dataset referenced by ``self.dataset_id``."""
    automl_hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id)
    # dataset_id may encode multiple ids; expand it before deleting.
    for single_id in self._parse_dataset_id(self.dataset_id):
        self.log.info("Deleting dataset %s", single_id)
        automl_hook.delete_dataset(
            dataset_id=single_id,
            location=self.location,
            project_id=self.project_id,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
        self.log.info("Dataset deleted.")
def execute(self, context):
    """Start an AutoML data import and block until the operation finishes."""
    automl_hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id)
    self.log.info("Importing dataset")
    operation = automl_hook.import_data(
        dataset_id=self.dataset_id,
        input_config=self.input_config,
        location=self.location,
        project_id=self.project_id,
        retry=self.retry,
        timeout=self.timeout,
        metadata=self.metadata,
    )
    # .result() waits for the long-running operation to complete.
    imported = MessageToDict(operation.result())
    self.log.info("Import completed")
    return imported
def execute(self, context):
    """Deploy an AutoML model and return the completed operation as a dict."""
    automl_hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id)
    self.log.info("Deploying model_id %s", self.model_id)
    operation = automl_hook.deploy_model(
        model_id=self.model_id,
        location=self.location,
        project_id=self.project_id,
        image_detection_metadata=self.image_detection_metadata,
        retry=self.retry,
        timeout=self.timeout,
        metadata=self.metadata,
    )
    # .result() waits for the long-running deployment to complete.
    deployed = MessageToDict(operation.result())
    self.log.info("Model deployed.")
    return deployed
def execute(self, context):
    """Return all table specs of a dataset as a list of dicts."""
    automl_hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id)
    self.log.info("Requesting table specs for %s.", self.dataset_id)
    specs_pager = automl_hook.list_table_specs(
        dataset_id=self.dataset_id,
        filter_=self.filter_,
        page_size=self.page_size,
        location=self.location,
        project_id=self.project_id,
        retry=self.retry,
        timeout=self.timeout,
        metadata=self.metadata,
    )
    # Drain the pager into plain dicts.
    table_specs = []
    for spec in specs_pager:
        table_specs.append(MessageToDict(spec))
    self.log.info(table_specs)
    self.log.info("Table specs obtained.")
    return table_specs
def execute(self, context):
    """Create an AutoML model, wait for completion, and push its id to XCom."""
    automl_hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id)
    self.log.info("Creating model.")
    operation = automl_hook.create_model(
        model=self.model,
        location=self.location,
        project_id=self.project_id,
        retry=self.retry,
        timeout=self.timeout,
        metadata=self.metadata,
    )
    # .result() waits for the long-running create to complete.
    created = MessageToDict(operation.result())
    new_model_id = automl_hook.extract_object_id(created)
    self.log.info("Model created: %s", new_model_id)
    self.xcom_push(context, key="model_id", value=new_model_id)
    return created
def execute(self, context):
    """Create an AutoML dataset and push the new dataset id to XCom."""
    automl_hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id)
    self.log.info("Creating dataset")
    created = MessageToDict(
        automl_hook.create_dataset(
            dataset=self.dataset,
            location=self.location,
            project_id=self.project_id,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
    )
    new_dataset_id = automl_hook.extract_object_id(created)
    self.log.info("Creating completed. Dataset id: %s", new_dataset_id)
    self.xcom_push(context, key="dataset_id", value=new_dataset_id)
    return created
def execute(self, context):
    """Run an AutoML batch-prediction job and return its result as a dict."""
    automl_hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id)
    self.log.info("Fetch batch prediction.")
    operation = automl_hook.batch_predict(
        model_id=self.model_id,
        input_config=self.input_config,
        output_config=self.output_config,
        project_id=self.project_id,
        location=self.location,
        params=self.params,
        retry=self.retry,
        timeout=self.timeout,
        metadata=self.metadata,
    )
    # .result() waits for the long-running prediction job to complete.
    prediction = MessageToDict(operation.result())
    self.log.info("Batch prediction ready.")
    return prediction
def execute(self, context):
    """List AutoML datasets in the project/location; push all ids to XCom."""
    automl_hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id)
    self.log.info("Requesting datasets")
    datasets_pager = automl_hook.list_datasets(
        location=self.location,
        project_id=self.project_id,
        retry=self.retry,
        timeout=self.timeout,
        metadata=self.metadata,
    )
    datasets = [MessageToDict(item) for item in datasets_pager]
    self.log.info("Datasets obtained.")
    self.xcom_push(
        context,
        key="dataset_id_list",
        value=[automl_hook.extract_object_id(entry) for entry in datasets],
    )
    return datasets
class TestAuoMLHook(unittest.TestCase):
    """Unit tests for CloudAutoMLHook.

    Every Google client call is mocked, so each test only verifies that the
    hook forwards its arguments to the expected client method.
    """
    # NOTE(review): class name looks like a typo for "TestAutoMLHook"; kept
    # unchanged here because it is the public identifier of this block.

    def setUp(self) -> None:
        # Replace the base hook __init__ so constructing the hook performs
        # no real credential/project lookup.
        with mock.patch(
            "airflow.gcp.hooks.automl.GoogleCloudBaseHook.__init__",
            new=mock_base_gcp_hook_no_default_project_id,
        ):
            self.hook = CloudAutoMLHook()
            self.hook._get_credentials = mock.MagicMock(  # type: ignore
                return_value=CREDENTIALS)

    @mock.patch(
        "airflow.gcp.hooks.automl.GoogleCloudBaseHook.client_info",
        new_callable=lambda: CLIENT_INFO,
    )
    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient")
    def test_get_conn(self, mock_automl_client, mock_client_info):
        # get_conn must build the AutoML client with the hook's credentials.
        self.hook.get_conn()
        mock_automl_client.assert_called_once_with(
            credentials=CREDENTIALS, client_info=CLIENT_INFO)

    @mock.patch(
        "airflow.gcp.hooks.automl.GoogleCloudBaseHook.client_info",
        new_callable=lambda: CLIENT_INFO,
    )
    @mock.patch("airflow.gcp.hooks.automl.PredictionServiceClient")
    def test_prediction_client(self, mock_prediction_client, mock_client_info):
        # Accessing the property is enough to trigger client construction.
        client = self.hook.prediction_client  # pylint: disable=unused-variable  # noqa
        mock_prediction_client.assert_called_once_with(
            credentials=CREDENTIALS, client_info=CLIENT_INFO)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.create_model")
    def test_create_model(self, mock_create_model):
        self.hook.create_model(
            model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
        mock_create_model.assert_called_once_with(
            parent=LOCATION_PATH, model=MODEL,
            retry=None, timeout=None, metadata=None)

    @mock.patch(
        "airflow.gcp.hooks.automl.PredictionServiceClient.batch_predict")
    def test_batch_predict(self, mock_batch_predict):
        self.hook.batch_predict(
            model_id=MODEL_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
            input_config=INPUT_CONFIG,
            output_config=OUTPUT_CONFIG,
        )
        mock_batch_predict.assert_called_once_with(
            name=MODEL_PATH,
            input_config=INPUT_CONFIG,
            output_config=OUTPUT_CONFIG,
            params=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch("airflow.gcp.hooks.automl.PredictionServiceClient.predict")
    def test_predict(self, mock_predict):
        self.hook.predict(
            model_id=MODEL_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
            payload=PAYLOAD,
        )
        mock_predict.assert_called_once_with(
            name=MODEL_PATH,
            payload=PAYLOAD,
            params=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.create_dataset")
    def test_create_dataset(self, mock_create_dataset):
        self.hook.create_dataset(
            dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
        mock_create_dataset.assert_called_once_with(
            parent=LOCATION_PATH,
            dataset=DATASET,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.import_data")
    def test_import_dataset(self, mock_import_data):
        self.hook.import_data(
            dataset_id=DATASET_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
            input_config=INPUT_CONFIG,
        )
        mock_import_data.assert_called_once_with(
            name=DATASET_PATH,
            input_config=INPUT_CONFIG,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.list_column_specs")
    def test_list_column_specs(self, mock_list_column_specs):
        table_spec = "table_spec_id"
        filter_ = "filter"
        page_size = 42
        self.hook.list_column_specs(
            dataset_id=DATASET_ID,
            table_spec_id=table_spec,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
            field_mask=MASK,
            filter_=filter_,
            page_size=page_size,
        )
        # The hook is expected to resolve the full table-spec resource path.
        parent = AutoMlClient.table_spec_path(
            GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec)
        mock_list_column_specs.assert_called_once_with(
            parent=parent,
            field_mask=MASK,
            filter_=filter_,
            page_size=page_size,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.get_model")
    def test_get_model(self, mock_get_model):
        self.hook.get_model(
            model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
        mock_get_model.assert_called_once_with(
            name=MODEL_PATH, retry=None, timeout=None, metadata=None)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.delete_model")
    def test_delete_model(self, mock_delete_model):
        self.hook.delete_model(
            model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
        mock_delete_model.assert_called_once_with(
            name=MODEL_PATH, retry=None, timeout=None, metadata=None)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.update_dataset")
    def test_update_dataset(self, mock_update_dataset):
        self.hook.update_dataset(
            dataset=DATASET, update_mask=MASK, project_id=GCP_PROJECT_ID)
        mock_update_dataset.assert_called_once_with(
            dataset=DATASET, update_mask=MASK,
            retry=None, timeout=None, metadata=None)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.deploy_model")
    def test_deploy_model(self, mock_deploy_model):
        image_detection_metadata = {}
        self.hook.deploy_model(
            model_id=MODEL_ID,
            image_detection_metadata=image_detection_metadata,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
        )
        mock_deploy_model.assert_called_once_with(
            name=MODEL_PATH,
            retry=None,
            timeout=None,
            metadata=None,
            image_object_detection_model_deployment_metadata=
            image_detection_metadata,
        )

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.list_table_specs")
    def test_list_table_specs(self, mock_list_table_specs):
        filter_ = "filter"
        page_size = 42
        self.hook.list_table_specs(
            dataset_id=DATASET_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
            filter_=filter_,
            page_size=page_size,
        )
        mock_list_table_specs.assert_called_once_with(
            parent=DATASET_PATH,
            filter_=filter_,
            page_size=page_size,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.list_datasets")
    def test_list_datasets(self, mock_list_datasets):
        self.hook.list_datasets(location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
        mock_list_datasets.assert_called_once_with(
            parent=LOCATION_PATH, retry=None, timeout=None, metadata=None)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.delete_dataset")
    def test_delete_dataset(self, mock_delete_dataset):
        self.hook.delete_dataset(
            dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
        mock_delete_dataset.assert_called_once_with(
            name=DATASET_PATH, retry=None, timeout=None, metadata=None)