class TestGcpVisionHook(unittest.TestCase):
    """
    Unit tests for ``CloudVisionHook``.

    ``CloudVisionHook.__init__`` is replaced in ``setUp`` and the client returned by ``get_conn`` is
    mocked in every test, so each test only checks (a) the value the hook method returns or the
    exception it raises and (b) the exact keyword arguments forwarded to the mocked client method.
    """

    # Patch target shared by almost every test; kept in one place so a module move needs a single edit.
    GET_CONN_TARGET = 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn'

    def setUp(self):
        # Swap the real constructor for the project's test double so that creating the hook
        # does not run the original __init__.
        with mock.patch(
            'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.__init__',
            new=mock_base_gcp_hook_default_project_id,
        ):
            self.hook = CloudVisionHook(gcp_conn_id='test')

    @mock.patch(GET_CONN_TARGET)
    def test_create_productset_explicit_id(self, get_conn):
        # Given
        create_product_set_method = get_conn.return_value.create_product_set
        create_product_set_method.return_value = None
        parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        product_set = ProductSet()
        # When
        result = self.hook.create_product_set(
            location=LOC_ID_TEST,
            product_set_id=PRODUCTSET_ID_TEST,
            product_set=product_set,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        # Then
        # ProductSet ID was provided explicitly in the method call above, should be returned from the method
        self.assertEqual(result, PRODUCTSET_ID_TEST)
        create_product_set_method.assert_called_once_with(
            parent=parent,
            product_set=product_set,
            product_set_id=PRODUCTSET_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(GET_CONN_TARGET)
    def test_create_productset_autogenerated_id(self, get_conn):
        # Given
        autogenerated_id = 'autogen-id'
        response_product_set = ProductSet(
            name=ProductSearchClient.product_set_path(PROJECT_ID_TEST, LOC_ID_TEST, autogenerated_id)
        )
        create_product_set_method = get_conn.return_value.create_product_set
        create_product_set_method.return_value = response_product_set
        parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        product_set = ProductSet()
        # When
        result = self.hook.create_product_set(
            location=LOC_ID_TEST, product_set_id=None, product_set=product_set, project_id=PROJECT_ID_TEST
        )
        # Then
        # ProductSet ID was not provided in the method call above. Should be extracted from the API response
        # and returned.
        self.assertEqual(result, autogenerated_id)
        create_product_set_method.assert_called_once_with(
            parent=parent,
            product_set=product_set,
            product_set_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(GET_CONN_TARGET)
    def test_create_productset_autogenerated_id_wrong_api_response(self, get_conn):
        # Given
        response_product_set = None
        create_product_set_method = get_conn.return_value.create_product_set
        create_product_set_method.return_value = response_product_set
        parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        product_set = ProductSet()
        # When
        with self.assertRaises(AirflowException) as cm:
            self.hook.create_product_set(
                location=LOC_ID_TEST,
                product_set_id=None,
                product_set=product_set,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )
        # Then
        # API response was wrong (None) and thus ProductSet ID extraction should fail.
        err = cm.exception
        self.assertIn('Unable to get name from response...', str(err))
        create_product_set_method.assert_called_once_with(
            parent=parent,
            product_set=product_set,
            product_set_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(GET_CONN_TARGET)
    def test_get_productset(self, get_conn):
        # Given
        name = ProductSearchClient.product_set_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST)
        response_product_set = ProductSet(name=name)
        get_product_set_method = get_conn.return_value.get_product_set
        get_product_set_method.return_value = response_product_set
        # When
        response = self.hook.get_product_set(
            location=LOC_ID_TEST, product_set_id=PRODUCTSET_ID_TEST, project_id=PROJECT_ID_TEST
        )
        # Then
        # The hook returns the protobuf response converted to a dict.
        self.assertTrue(response)
        self.assertEqual(response, MessageToDict(response_product_set))
        get_product_set_method.assert_called_once_with(name=name, retry=None, timeout=None, metadata=None)

    @mock.patch(GET_CONN_TARGET)
    def test_update_productset_no_explicit_name(self, get_conn):
        # Given
        product_set = ProductSet()
        update_product_set_method = get_conn.return_value.update_product_set
        update_product_set_method.return_value = product_set
        productset_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST
        )
        # When
        result = self.hook.update_product_set(
            location=LOC_ID_TEST,
            product_set_id=PRODUCTSET_ID_TEST,
            product_set=product_set,
            update_mask=None,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        # Then
        # The ProductSet had no name, so the client must receive one built from the input parameters.
        self.assertEqual(result, MessageToDict(product_set))
        update_product_set_method.assert_called_once_with(
            product_set=ProductSet(name=productset_name),
            metadata=None,
            retry=None,
            timeout=None,
            update_mask=None,
        )

    @parameterized.expand([(None, None), (None, PRODUCTSET_ID_TEST), (LOC_ID_TEST, None)])
    @mock.patch(GET_CONN_TARGET)
    def test_update_productset_no_explicit_name_and_missing_params_for_constructed_name(
        self, location, product_set_id, get_conn
    ):
        # Given
        update_product_set_method = get_conn.return_value.update_product_set
        update_product_set_method.return_value = None
        product_set = ProductSet()
        # When
        with self.assertRaises(AirflowException) as cm:
            self.hook.update_product_set(
                location=location,
                product_set_id=product_set_id,
                product_set=product_set,
                update_mask=None,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )
        # Then
        # Neither an explicit name nor a complete (location, product_set_id) pair is available,
        # so the hook must fail before touching the API.
        err = cm.exception
        self.assertIn(
            "Unable to determine the ProductSet name. Please either set the name directly in the "
            "ProductSet object or provide the `location` and `productset_id` parameters.",
            str(err),
        )
        update_product_set_method.assert_not_called()

    @parameterized.expand([(None, None), (None, PRODUCTSET_ID_TEST), (LOC_ID_TEST, None)])
    @mock.patch(GET_CONN_TARGET)
    def test_update_productset_explicit_name_missing_params_for_constructed_name(
        self, location, product_set_id, get_conn
    ):
        # Given
        explicit_ps_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCTSET_ID_TEST_2
        )
        product_set = ProductSet(name=explicit_ps_name)
        update_product_set_method = get_conn.return_value.update_product_set
        update_product_set_method.return_value = product_set
        # When
        result = self.hook.update_product_set(
            location=location,
            product_set_id=product_set_id,
            product_set=product_set,
            update_mask=None,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        # Then
        # The explicit name is enough on its own; incomplete location/product_set_id are ignored.
        self.assertEqual(result, MessageToDict(product_set))
        update_product_set_method.assert_called_once_with(
            product_set=ProductSet(name=explicit_ps_name),
            metadata=None,
            retry=None,
            timeout=None,
            update_mask=None,
        )

    @mock.patch(GET_CONN_TARGET)
    def test_update_productset_explicit_name_different_from_constructed(self, get_conn):
        # Given
        update_product_set_method = get_conn.return_value.update_product_set
        update_product_set_method.return_value = None
        explicit_ps_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCTSET_ID_TEST_2
        )
        product_set = ProductSet(name=explicit_ps_name)
        template_ps_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST
        )
        # When
        # Location and product_set_id are passed in addition to a ProductSet with an explicit name,
        # but both names differ (constructed != explicit).
        # Should throw AirflowException in this case.
        with self.assertRaises(AirflowException) as cm:
            self.hook.update_product_set(
                location=LOC_ID_TEST,
                product_set_id=PRODUCTSET_ID_TEST,
                product_set=product_set,
                update_mask=None,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )
        # Then
        err = cm.exception
        self.assertIn(
            "The ProductSet name provided in the object ({}) is different than the name "
            "created from the input parameters ({}). Please either: 1) Remove the ProductSet "
            "name, 2) Remove the location and productset_id parameters, 3) Unify the "
            "ProductSet name and input parameters.".format(explicit_ps_name, template_ps_name),
            str(err),
        )
        update_product_set_method.assert_not_called()

    @mock.patch(GET_CONN_TARGET)
    def test_delete_productset(self, get_conn):
        # Given
        delete_product_set_method = get_conn.return_value.delete_product_set
        delete_product_set_method.return_value = None
        name = ProductSearchClient.product_set_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST)
        # When
        response = self.hook.delete_product_set(
            location=LOC_ID_TEST, product_set_id=PRODUCTSET_ID_TEST, project_id=PROJECT_ID_TEST
        )
        # Then
        self.assertIsNone(response)
        delete_product_set_method.assert_called_once_with(name=name, retry=None, timeout=None, metadata=None)

    @mock.patch(
        GET_CONN_TARGET,
        **{'return_value.create_reference_image.return_value': REFERENCE_IMAGE_TEST}
    )
    def test_create_reference_image_explicit_id(self, get_conn):
        # Given
        create_reference_image_method = get_conn.return_value.create_reference_image
        # When
        result = self.hook.create_reference_image(
            project_id=PROJECT_ID_TEST,
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            reference_image=REFERENCE_IMAGE_WITHOUT_ID_NAME,
            reference_image_id=REFERENCE_IMAGE_ID_TEST,
        )
        # Then
        # ReferenceImage ID was provided explicitly in the method call above, should be returned from
        # the method
        self.assertEqual(result, REFERENCE_IMAGE_ID_TEST)
        create_reference_image_method.assert_called_once_with(
            parent=PRODUCT_NAME,
            reference_image=REFERENCE_IMAGE_WITHOUT_ID_NAME,
            reference_image_id=REFERENCE_IMAGE_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(
        GET_CONN_TARGET,
        **{'return_value.create_reference_image.return_value': REFERENCE_IMAGE_TEST}
    )
    def test_create_reference_image_autogenerated_id(self, get_conn):
        # Given
        create_reference_image_method = get_conn.return_value.create_reference_image
        # When
        result = self.hook.create_reference_image(
            project_id=PROJECT_ID_TEST,
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            reference_image=REFERENCE_IMAGE_TEST,
            reference_image_id=REFERENCE_IMAGE_ID_TEST,
        )
        # Then
        # The hook is expected to return REFERENCE_IMAGE_GEN_ID_TEST.
        # NOTE(review): an explicit reference_image_id is still passed above although the test name says
        # "autogenerated" - confirm this is intended and not a copy of the explicit-id test.
        self.assertEqual(result, REFERENCE_IMAGE_GEN_ID_TEST)
        create_reference_image_method.assert_called_once_with(
            parent=PRODUCT_NAME,
            reference_image=REFERENCE_IMAGE_TEST,
            reference_image_id=REFERENCE_IMAGE_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(GET_CONN_TARGET)
    def test_add_product_to_product_set(self, get_conn):
        # Given
        add_product_to_product_set_method = get_conn.return_value.add_product_to_product_set
        # When
        self.hook.add_product_to_product_set(
            product_set_id=PRODUCTSET_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            location=LOC_ID_TEST,
            project_id=PROJECT_ID_TEST,
        )
        # Then
        # The client must be called exactly once with the expected ProductSet and Product names.
        add_product_to_product_set_method.assert_called_once_with(
            name=PRODUCTSET_NAME_TEST, product=PRODUCT_NAME_TEST, retry=None, timeout=None, metadata=None
        )

    @mock.patch(GET_CONN_TARGET)
    def test_remove_product_from_product_set(self, get_conn):
        # Given
        remove_product_from_product_set_method = get_conn.return_value.remove_product_from_product_set
        # When
        self.hook.remove_product_from_product_set(
            product_set_id=PRODUCTSET_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            location=LOC_ID_TEST,
            project_id=PROJECT_ID_TEST,
        )
        # Then
        # The client must be called exactly once with the expected ProductSet and Product names.
        remove_product_from_product_set_method.assert_called_once_with(
            name=PRODUCTSET_NAME_TEST, product=PRODUCT_NAME_TEST, retry=None, timeout=None, metadata=None
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client')
    def test_annotate_image(self, annotator_client_mock):
        # Given
        annotate_image_method = annotator_client_mock.annotate_image
        # When
        self.hook.annotate_image(request=ANNOTATE_IMAGE_REQUEST)
        # Then
        # The request must be forwarded to the annotator client as-is, with default retry/timeout.
        annotate_image_method.assert_called_once_with(
            request=ANNOTATE_IMAGE_REQUEST, retry=None, timeout=None
        )

    @mock.patch(GET_CONN_TARGET)
    def test_create_product_explicit_id(self, get_conn):
        # Given
        create_product_method = get_conn.return_value.create_product
        create_product_method.return_value = None
        parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        product = Product()
        # When
        result = self.hook.create_product(
            location=LOC_ID_TEST, product_id=PRODUCT_ID_TEST, product=product, project_id=PROJECT_ID_TEST
        )
        # Then
        # Product ID was provided explicitly in the method call above, should be returned from the method
        self.assertEqual(result, PRODUCT_ID_TEST)
        create_product_method.assert_called_once_with(
            parent=parent,
            product=product,
            product_id=PRODUCT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(GET_CONN_TARGET)
    def test_create_product_autogenerated_id(self, get_conn):
        # Given
        autogenerated_id = 'autogen-p-id'
        response_product = Product(
            name=ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, autogenerated_id)
        )
        create_product_method = get_conn.return_value.create_product
        create_product_method.return_value = response_product
        parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        product = Product()
        # When
        result = self.hook.create_product(
            location=LOC_ID_TEST, product_id=None, product=product, project_id=PROJECT_ID_TEST
        )
        # Then
        # Product ID was not provided in the method call above. Should be extracted from the API response
        # and returned.
        self.assertEqual(result, autogenerated_id)
        create_product_method.assert_called_once_with(
            parent=parent, product=product, product_id=None, retry=None, timeout=None, metadata=None
        )

    @mock.patch(GET_CONN_TARGET)
    def test_create_product_autogenerated_id_wrong_name_in_response(self, get_conn):
        # Given
        wrong_name = 'wrong_name_not_a_correct_path'
        response_product = Product(name=wrong_name)
        create_product_method = get_conn.return_value.create_product
        create_product_method.return_value = response_product
        parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        product = Product()
        # When
        with self.assertRaises(AirflowException) as cm:
            self.hook.create_product(
                location=LOC_ID_TEST, product_id=None, product=product, project_id=PROJECT_ID_TEST
            )
        # Then
        # API response was wrong (wrong name format) and thus Product ID extraction should fail.
        err = cm.exception
        self.assertIn('Unable to get id from name', str(err))
        create_product_method.assert_called_once_with(
            parent=parent, product=product, product_id=None, retry=None, timeout=None, metadata=None
        )

    @mock.patch(GET_CONN_TARGET)
    def test_create_product_autogenerated_id_wrong_api_response(self, get_conn):
        # Given
        response_product = None
        create_product_method = get_conn.return_value.create_product
        create_product_method.return_value = response_product
        parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        product = Product()
        # When
        with self.assertRaises(AirflowException) as cm:
            self.hook.create_product(
                location=LOC_ID_TEST, product_id=None, product=product, project_id=PROJECT_ID_TEST
            )
        # Then
        # API response was wrong (None) and thus Product ID extraction should fail.
        err = cm.exception
        self.assertIn('Unable to get name from response...', str(err))
        create_product_method.assert_called_once_with(
            parent=parent, product=product, product_id=None, retry=None, timeout=None, metadata=None
        )

    @mock.patch(GET_CONN_TARGET)
    def test_update_product_no_explicit_name(self, get_conn):
        # Given
        product = Product()
        update_product_method = get_conn.return_value.update_product
        update_product_method.return_value = product
        product_name = ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCT_ID_TEST)
        # When
        result = self.hook.update_product(
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            product=product,
            update_mask=None,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        # Then
        # The Product had no name, so the client must receive one built from the input parameters.
        self.assertEqual(result, MessageToDict(product))
        update_product_method.assert_called_once_with(
            product=Product(name=product_name), metadata=None, retry=None, timeout=None, update_mask=None
        )

    @parameterized.expand([(None, None), (None, PRODUCT_ID_TEST), (LOC_ID_TEST, None)])
    @mock.patch(GET_CONN_TARGET)
    def test_update_product_no_explicit_name_and_missing_params_for_constructed_name(
        self, location, product_id, get_conn
    ):
        # Given
        update_product_method = get_conn.return_value.update_product
        update_product_method.return_value = None
        product = Product()
        # When
        with self.assertRaises(AirflowException) as cm:
            self.hook.update_product(
                location=location,
                product_id=product_id,
                product=product,
                update_mask=None,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )
        # Then
        # Neither an explicit name nor a complete (location, product_id) pair is available,
        # so the hook must fail before touching the API.
        err = cm.exception
        self.assertIn(
            "Unable to determine the Product name. Please either set the name directly in the "
            "Product object or provide the `location` and `product_id` parameters.",
            str(err),
        )
        update_product_method.assert_not_called()

    @parameterized.expand([(None, None), (None, PRODUCT_ID_TEST), (LOC_ID_TEST, None)])
    @mock.patch(GET_CONN_TARGET)
    def test_update_product_explicit_name_missing_params_for_constructed_name(
        self, location, product_id, get_conn
    ):
        # Given
        explicit_p_name = ProductSearchClient.product_path(
            PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCT_ID_TEST_2
        )
        product = Product(name=explicit_p_name)
        update_product_method = get_conn.return_value.update_product
        update_product_method.return_value = product
        # When
        result = self.hook.update_product(
            location=location,
            product_id=product_id,
            product=product,
            update_mask=None,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        # Then
        # The explicit name is enough on its own; incomplete location/product_id are ignored.
        self.assertEqual(result, MessageToDict(product))
        update_product_method.assert_called_once_with(
            product=Product(name=explicit_p_name), metadata=None, retry=None, timeout=None, update_mask=None
        )

    @mock.patch(GET_CONN_TARGET)
    def test_update_product_explicit_name_different_from_constructed(self, get_conn):
        # Given
        update_product_method = get_conn.return_value.update_product
        update_product_method.return_value = None
        explicit_p_name = ProductSearchClient.product_path(
            PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCT_ID_TEST_2
        )
        product = Product(name=explicit_p_name)
        template_p_name = ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCT_ID_TEST)
        # When
        # Location and product_id are passed in addition to a Product with an explicit name,
        # but both names differ (constructed != explicit).
        # Should throw AirflowException in this case.
        with self.assertRaises(AirflowException) as cm:
            self.hook.update_product(
                location=LOC_ID_TEST,
                product_id=PRODUCT_ID_TEST,
                product=product,
                update_mask=None,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )
        # Then
        err = cm.exception
        self.assertIn(
            "The Product name provided in the object ({}) is different than the name created from the input "
            "parameters ({}). Please either: 1) Remove the Product name, 2) Remove the location and product_"
            "id parameters, 3) Unify the Product name and input parameters.".format(
                explicit_p_name, template_p_name
            ),
            str(err),
        )
        update_product_method.assert_not_called()

    @mock.patch(GET_CONN_TARGET)
    def test_delete_product(self, get_conn):
        # Given
        delete_product_method = get_conn.return_value.delete_product
        delete_product_method.return_value = None
        name = ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCT_ID_TEST)
        # When
        response = self.hook.delete_product(
            location=LOC_ID_TEST, product_id=PRODUCT_ID_TEST, project_id=PROJECT_ID_TEST
        )
        # Then
        self.assertIsNone(response)
        delete_product_method.assert_called_once_with(name=name, retry=None, timeout=None, metadata=None)
class CloudVisionProductUpdateOperator(BaseOperator):
    """
    Makes changes to a Product resource. Only the display_name, description, and labels fields can be
    updated right now.

    If labels are updated, the change will not be reflected in queries until the next index time.

    .. note:: To locate the `Product` resource, its `name` in the form
        `projects/PROJECT_ID/locations/LOC_ID/products/PRODUCT_ID` is necessary.

    You can provide the `name` directly as an attribute of the `product` object. However, you can leave it
    blank and provide `location` and `product_id` instead (and optionally `project_id` - if not present,
    the connection default will be used) and the `name` will be created by the operator itself.

    This mechanism exists for your convenience, to allow leaving the `project_id` empty and having Airflow
    use the connection default `project_id`.

    Possible errors related to the provided `Product`:

    - Returns NOT_FOUND if the Product does not exist.
    - Returns INVALID_ARGUMENT if display_name is present in update_mask but is missing from the request
      or longer than 4096 characters.
    - Returns INVALID_ARGUMENT if description is present in update_mask but is longer than 4096
      characters.
    - Returns INVALID_ARGUMENT if product_category is present in update_mask.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudVisionProductUpdateOperator`

    :param product: (Required) The Product resource which replaces the one on the server. product.name is
        immutable. If a dict is provided, it must be of the same form as the protobuf message `Product`.
    :type product: dict or google.cloud.vision_v1.types.Product
    :param location: (Optional) The region where the Product is located. Valid regions (as of 2019-02-05)
        are: us-east1, us-west1, europe-west1, asia-east1
    :type location: str
    :param product_id: (Optional) The resource id of this Product.
    :type product_id: str
    :param project_id: (Optional) The project in which the Product is located. If set to None or
        missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param update_mask: (Optional) The `FieldMask` that specifies which fields to update. If update_mask
        isn't specified, all mutable fields are to be updated. Valid mask paths include product_labels,
        display_name, and description. If a dict is provided, it must be of the same form as the protobuf
        message `FieldMask`.
    :type update_mask: dict or google.cloud.vision_v1.types.FieldMask
    :param retry: (Optional) A retry object used to retry requests. If `None` is specified, requests will
        not be retried.
    :type retry: google.api_core.retry.Retry
    :param timeout: (Optional) The amount of time, in seconds, to wait for the request to complete. Note
        that if retry is specified, the timeout applies to each individual attempt.
    :type timeout: float
    :param metadata: (Optional) Additional metadata that is provided to the method.
    :type metadata: Sequence[Tuple[str, str]]
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    """

    # [START vision_product_update_template_fields]
    template_fields = ('location', 'project_id', 'product_id', 'gcp_conn_id')
    # [END vision_product_update_template_fields]

    @apply_defaults
    def __init__(
        self,
        product,
        location=None,
        product_id=None,
        project_id=None,
        update_mask=None,
        retry=None,
        timeout=None,
        metadata=None,
        gcp_conn_id='google_cloud_default',
        *args,
        **kwargs
    ):
        super(CloudVisionProductUpdateOperator, self).__init__(*args, **kwargs)
        self.product = product
        self.location = location
        self.product_id = product_id
        self.project_id = project_id
        self.update_mask = update_mask
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        # The hook is deliberately NOT created here: `gcp_conn_id` is listed in `template_fields`, so its
        # final value is only known at execution time, and building the hook in the constructor would
        # also instantiate it every time the operator object is constructed.

    def execute(self, context):
        """
        Update the Product through ``CloudVisionHook.update_product``.

        :param context: Airflow task context (unused by this operator).
        :return: The value returned by ``CloudVisionHook.update_product``.
        """
        # Build the hook at execution time so that it sees the current (rendered) `gcp_conn_id`.
        hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id)
        return hook.update_product(
            product=self.product,
            location=self.location,
            product_id=self.product_id,
            project_id=self.project_id,
            update_mask=self.update_mask,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
class TestGcpVisionHook(unittest.TestCase): def setUp(self): with mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.__init__', new=mock_base_gcp_hook_default_project_id, ): self.hook = CloudVisionHook(gcp_conn_id='test') @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_create_productset_explicit_id(self, get_conn): # Given create_product_set_method = get_conn.return_value.create_product_set create_product_set_method.return_value = None parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) product_set = ProductSet() # When result = self.hook.create_product_set( location=LOC_ID_TEST, product_set_id=PRODUCTSET_ID_TEST, product_set=product_set, project_id=PROJECT_ID_TEST, retry=None, timeout=None, metadata=None, ) # Then # ProductSet ID was provided explicitly in the method call above, should be returned from the method self.assertEqual(result, PRODUCTSET_ID_TEST) create_product_set_method.assert_called_once_with( parent=parent, product_set=product_set, product_set_id=PRODUCTSET_ID_TEST, retry=None, timeout=None, metadata=None, ) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_create_productset_autogenerated_id(self, get_conn): # Given autogenerated_id = 'autogen-id' response_product_set = ProductSet( name=ProductSearchClient.product_set_path( PROJECT_ID_TEST, LOC_ID_TEST, autogenerated_id)) create_product_set_method = get_conn.return_value.create_product_set create_product_set_method.return_value = response_product_set parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) product_set = ProductSet() # When result = self.hook.create_product_set(location=LOC_ID_TEST, product_set_id=None, product_set=product_set, project_id=PROJECT_ID_TEST) # Then # ProductSet ID was not provided in the method call above. Should be extracted from the API response # and returned. 
self.assertEqual(result, autogenerated_id) create_product_set_method.assert_called_once_with( parent=parent, product_set=product_set, product_set_id=None, retry=None, timeout=None, metadata=None, ) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_create_productset_autogenerated_id_wrong_api_response( self, get_conn): # Given response_product_set = None create_product_set_method = get_conn.return_value.create_product_set create_product_set_method.return_value = response_product_set parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) product_set = ProductSet() # When with self.assertRaises(AirflowException) as cm: self.hook.create_product_set( location=LOC_ID_TEST, product_set_id=None, product_set=product_set, project_id=PROJECT_ID_TEST, retry=None, timeout=None, metadata=None, ) # Then # API response was wrong (None) and thus ProductSet ID extraction should fail. err = cm.exception self.assertIn('Unable to get name from response...', str(err)) create_product_set_method.assert_called_once_with( parent=parent, product_set=product_set, product_set_id=None, retry=None, timeout=None, metadata=None, ) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_get_productset(self, get_conn): # Given name = ProductSearchClient.product_set_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST) response_product_set = ProductSet(name=name) get_product_set_method = get_conn.return_value.get_product_set get_product_set_method.return_value = response_product_set # When response = self.hook.get_product_set(location=LOC_ID_TEST, product_set_id=PRODUCTSET_ID_TEST, project_id=PROJECT_ID_TEST) # Then self.assertTrue(response) self.assertEqual(response, MessageToDict(response_product_set)) get_product_set_method.assert_called_once_with(name=name, retry=None, timeout=None, metadata=None) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def 
test_update_productset_no_explicit_name(self, get_conn): # Given product_set = ProductSet() update_product_set_method = get_conn.return_value.update_product_set update_product_set_method.return_value = product_set productset_name = ProductSearchClient.product_set_path( PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST) # When result = self.hook.update_product_set( location=LOC_ID_TEST, product_set_id=PRODUCTSET_ID_TEST, product_set=product_set, update_mask=None, project_id=PROJECT_ID_TEST, retry=None, timeout=None, metadata=None, ) # Then self.assertEqual(result, MessageToDict(product_set)) update_product_set_method.assert_called_once_with( product_set=ProductSet(name=productset_name), metadata=None, retry=None, timeout=None, update_mask=None, ) @parameterized.expand([(None, None), (None, PRODUCTSET_ID_TEST), (LOC_ID_TEST, None)]) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_update_productset_no_explicit_name_and_missing_params_for_constructed_name( self, location, product_set_id, get_conn): # Given update_product_set_method = get_conn.return_value.update_product_set update_product_set_method.return_value = None product_set = ProductSet() # When with self.assertRaises(AirflowException) as cm: self.hook.update_product_set( location=location, product_set_id=product_set_id, product_set=product_set, update_mask=None, project_id=PROJECT_ID_TEST, retry=None, timeout=None, metadata=None, ) err = cm.exception self.assertTrue(err) self.assertIn( ERR_UNABLE_TO_CREATE.format(label='ProductSet', id_label='productset_id'), str(err)) update_product_set_method.assert_not_called() @parameterized.expand([(None, None), (None, PRODUCTSET_ID_TEST), (LOC_ID_TEST, None)]) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_update_productset_explicit_name_missing_params_for_constructed_name( self, location, product_set_id, get_conn): # Given explicit_ps_name = ProductSearchClient.product_set_path( 
PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCTSET_ID_TEST_2) product_set = ProductSet(name=explicit_ps_name) update_product_set_method = get_conn.return_value.update_product_set update_product_set_method.return_value = product_set # When result = self.hook.update_product_set( location=location, product_set_id=product_set_id, product_set=product_set, update_mask=None, project_id=PROJECT_ID_TEST, retry=None, timeout=None, metadata=None, ) # Then self.assertEqual(result, MessageToDict(product_set)) update_product_set_method.assert_called_once_with( product_set=ProductSet(name=explicit_ps_name), metadata=None, retry=None, timeout=None, update_mask=None, ) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_update_productset_explicit_name_different_from_constructed( self, get_conn): # Given update_product_set_method = get_conn.return_value.update_product_set update_product_set_method.return_value = None explicit_ps_name = ProductSearchClient.product_set_path( PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCTSET_ID_TEST_2) product_set = ProductSet(name=explicit_ps_name) template_ps_name = ProductSearchClient.product_set_path( PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST) # When # Location and product_set_id are passed in addition to a ProductSet with an explicit name, # but both names differ (constructed != explicit). # Should throw AirflowException in this case. 
with self.assertRaises(AirflowException) as cm: self.hook.update_product_set( location=LOC_ID_TEST, product_set_id=PRODUCTSET_ID_TEST, product_set=product_set, update_mask=None, project_id=PROJECT_ID_TEST, retry=None, timeout=None, metadata=None, ) err = cm.exception # self.assertIn("The required parameter 'project_id' is missing", str(err)) self.assertTrue(err) self.assertIn( ERR_DIFF_NAMES.format(explicit_name=explicit_ps_name, constructed_name=template_ps_name, label="ProductSet", id_label="productset_id"), str(err)) update_product_set_method.assert_not_called() @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_delete_productset(self, get_conn): # Given delete_product_set_method = get_conn.return_value.delete_product_set delete_product_set_method.return_value = None name = ProductSearchClient.product_set_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST) # When response = self.hook.delete_product_set( location=LOC_ID_TEST, product_set_id=PRODUCTSET_ID_TEST, project_id=PROJECT_ID_TEST) # Then self.assertIsNone(response) delete_product_set_method.assert_called_once_with(name=name, retry=None, timeout=None, metadata=None) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn', **{ 'return_value.create_reference_image.return_value': REFERENCE_IMAGE_TEST }) def test_create_reference_image_explicit_id(self, get_conn): # Given create_reference_image_method = get_conn.return_value.create_reference_image # When result = self.hook.create_reference_image( project_id=PROJECT_ID_TEST, location=LOC_ID_TEST, product_id=PRODUCT_ID_TEST, reference_image=REFERENCE_IMAGE_WITHOUT_ID_NAME, reference_image_id=REFERENCE_IMAGE_ID_TEST, ) # Then # Product ID was provided explicitly in the method call above, should be returned from the method self.assertEqual(result, REFERENCE_IMAGE_ID_TEST) create_reference_image_method.assert_called_once_with( parent=PRODUCT_NAME, reference_image=REFERENCE_IMAGE_WITHOUT_ID_NAME, 
reference_image_id=REFERENCE_IMAGE_ID_TEST, retry=None, timeout=None, metadata=None, ) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn', **{ 'return_value.create_reference_image.return_value': REFERENCE_IMAGE_TEST }) def test_create_reference_image_autogenerated_id(self, get_conn): # Given create_reference_image_method = get_conn.return_value.create_reference_image # When result = self.hook.create_reference_image( project_id=PROJECT_ID_TEST, location=LOC_ID_TEST, product_id=PRODUCT_ID_TEST, reference_image=REFERENCE_IMAGE_TEST, reference_image_id=REFERENCE_IMAGE_ID_TEST, ) # Then # Product ID was provided explicitly in the method call above, should be returned from the method self.assertEqual(result, REFERENCE_IMAGE_GEN_ID_TEST) create_reference_image_method.assert_called_once_with( parent=PRODUCT_NAME, reference_image=REFERENCE_IMAGE_TEST, reference_image_id=REFERENCE_IMAGE_ID_TEST, retry=None, timeout=None, metadata=None, ) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_add_product_to_product_set(self, get_conn): # Given add_product_to_product_set_method = get_conn.return_value.add_product_to_product_set # When self.hook.add_product_to_product_set( product_set_id=PRODUCTSET_ID_TEST, product_id=PRODUCT_ID_TEST, location=LOC_ID_TEST, project_id=PROJECT_ID_TEST, ) # Then # Product ID was provided explicitly in the method call above, should be returned from the method add_product_to_product_set_method.assert_called_once_with( name=PRODUCTSET_NAME_TEST, product=PRODUCT_NAME_TEST, retry=None, timeout=None, metadata=None) # remove_product_from_product_set @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_remove_product_from_product_set(self, get_conn): # Given remove_product_from_product_set_method = get_conn.return_value.remove_product_from_product_set # When self.hook.remove_product_from_product_set( product_set_id=PRODUCTSET_ID_TEST, product_id=PRODUCT_ID_TEST, 
location=LOC_ID_TEST, project_id=PROJECT_ID_TEST, ) # Then # Product ID was provided explicitly in the method call above, should be returned from the method remove_product_from_product_set_method.assert_called_once_with( name=PRODUCTSET_NAME_TEST, product=PRODUCT_NAME_TEST, retry=None, timeout=None, metadata=None) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client' ) def test_annotate_image(self, annotator_client_mock): # Given annotate_image_method = annotator_client_mock.annotate_image # When self.hook.annotate_image(request=ANNOTATE_IMAGE_REQUEST) # Then # Product ID was provided explicitly in the method call above, should be returned from the method annotate_image_method.assert_called_once_with( request=ANNOTATE_IMAGE_REQUEST, retry=None, timeout=None) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client' ) def test_batch_annotate_images(self, annotator_client_mock): # Given batch_annotate_images_method = annotator_client_mock.batch_annotate_images # When self.hook.batch_annotate_images(requests=BATCH_ANNOTATE_IMAGE_REQUEST) # Then # Product ID was provided explicitly in the method call above, should be returned from the method batch_annotate_images_method.assert_called_once_with( requests=BATCH_ANNOTATE_IMAGE_REQUEST, retry=None, timeout=None) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_create_product_explicit_id(self, get_conn): # Given create_product_method = get_conn.return_value.create_product create_product_method.return_value = None parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) product = Product() # When result = self.hook.create_product(location=LOC_ID_TEST, product_id=PRODUCT_ID_TEST, product=product, project_id=PROJECT_ID_TEST) # Then # Product ID was provided explicitly in the method call above, should be returned from the method self.assertEqual(result, PRODUCT_ID_TEST) 
create_product_method.assert_called_once_with( parent=parent, product=product, product_id=PRODUCT_ID_TEST, retry=None, timeout=None, metadata=None, ) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_create_product_autogenerated_id(self, get_conn): # Given autogenerated_id = 'autogen-p-id' response_product = Product(name=ProductSearchClient.product_path( PROJECT_ID_TEST, LOC_ID_TEST, autogenerated_id)) create_product_method = get_conn.return_value.create_product create_product_method.return_value = response_product parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) product = Product() # When result = self.hook.create_product(location=LOC_ID_TEST, product_id=None, product=product, project_id=PROJECT_ID_TEST) # Then # Product ID was not provided in the method call above. Should be extracted from the API response # and returned. self.assertEqual(result, autogenerated_id) create_product_method.assert_called_once_with(parent=parent, product=product, product_id=None, retry=None, timeout=None, metadata=None) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_create_product_autogenerated_id_wrong_name_in_response( self, get_conn): # Given wrong_name = 'wrong_name_not_a_correct_path' response_product = Product(name=wrong_name) create_product_method = get_conn.return_value.create_product create_product_method.return_value = response_product parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) product = Product() # When with self.assertRaises(AirflowException) as cm: self.hook.create_product(location=LOC_ID_TEST, product_id=None, product=product, project_id=PROJECT_ID_TEST) # Then # API response was wrong (wrong name format) and thus ProductSet ID extraction should fail. 
err = cm.exception self.assertIn('Unable to get id from name', str(err)) create_product_method.assert_called_once_with(parent=parent, product=product, product_id=None, retry=None, timeout=None, metadata=None) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_create_product_autogenerated_id_wrong_api_response( self, get_conn): # Given response_product = None create_product_method = get_conn.return_value.create_product create_product_method.return_value = response_product parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST) product = Product() # When with self.assertRaises(AirflowException) as cm: self.hook.create_product(location=LOC_ID_TEST, product_id=None, product=product, project_id=PROJECT_ID_TEST) # Then # API response was wrong (None) and thus ProductSet ID extraction should fail. err = cm.exception self.assertIn('Unable to get name from response...', str(err)) create_product_method.assert_called_once_with(parent=parent, product=product, product_id=None, retry=None, timeout=None, metadata=None) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_update_product_no_explicit_name(self, get_conn): # Given product = Product() update_product_method = get_conn.return_value.update_product update_product_method.return_value = product product_name = ProductSearchClient.product_path( PROJECT_ID_TEST, LOC_ID_TEST, PRODUCT_ID_TEST) # When result = self.hook.update_product( location=LOC_ID_TEST, product_id=PRODUCT_ID_TEST, product=product, update_mask=None, project_id=PROJECT_ID_TEST, retry=None, timeout=None, metadata=None, ) # Then self.assertEqual(result, MessageToDict(product)) update_product_method.assert_called_once_with( product=Product(name=product_name), metadata=None, retry=None, timeout=None, update_mask=None) @parameterized.expand([(None, None), (None, PRODUCT_ID_TEST), (LOC_ID_TEST, None)]) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') 
def test_update_product_no_explicit_name_and_missing_params_for_constructed_name( self, location, product_id, get_conn): # Given update_product_method = get_conn.return_value.update_product update_product_method.return_value = None product = Product() # When with self.assertRaises(AirflowException) as cm: self.hook.update_product( location=location, product_id=product_id, product=product, update_mask=None, project_id=PROJECT_ID_TEST, retry=None, timeout=None, metadata=None, ) err = cm.exception self.assertTrue(err) self.assertIn( ERR_UNABLE_TO_CREATE.format(label='Product', id_label='product_id'), str(err), ) update_product_method.assert_not_called() @parameterized.expand([(None, None), (None, PRODUCT_ID_TEST), (LOC_ID_TEST, None)]) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_update_product_explicit_name_missing_params_for_constructed_name( self, location, product_id, get_conn): # Given explicit_p_name = ProductSearchClient.product_path( PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCT_ID_TEST_2) product = Product(name=explicit_p_name) update_product_method = get_conn.return_value.update_product update_product_method.return_value = product # When result = self.hook.update_product( location=location, product_id=product_id, product=product, update_mask=None, project_id=PROJECT_ID_TEST, retry=None, timeout=None, metadata=None, ) # Then self.assertEqual(result, MessageToDict(product)) update_product_method.assert_called_once_with( product=Product(name=explicit_p_name), metadata=None, retry=None, timeout=None, update_mask=None) @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_update_product_explicit_name_different_from_constructed( self, get_conn): # Given update_product_method = get_conn.return_value.update_product update_product_method.return_value = None explicit_p_name = ProductSearchClient.product_path( PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCT_ID_TEST_2) product = 
Product(name=explicit_p_name) template_p_name = ProductSearchClient.product_path( PROJECT_ID_TEST, LOC_ID_TEST, PRODUCT_ID_TEST) # When # Location and product_id are passed in addition to a Product with an explicit name, # but both names differ (constructed != explicit). # Should throw AirflowException in this case. with self.assertRaises(AirflowException) as cm: self.hook.update_product( location=LOC_ID_TEST, product_id=PRODUCT_ID_TEST, product=product, update_mask=None, project_id=PROJECT_ID_TEST, retry=None, timeout=None, metadata=None, ) err = cm.exception self.assertTrue(err) self.assertIn( ERR_DIFF_NAMES.format(explicit_name=explicit_p_name, constructed_name=template_p_name, label="Product", id_label="product_id"), str(err), ) update_product_method.assert_not_called() @mock.patch( 'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn') def test_delete_product(self, get_conn): # Given delete_product_method = get_conn.return_value.delete_product delete_product_method.return_value = None name = ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCT_ID_TEST) # When response = self.hook.delete_product(location=LOC_ID_TEST, product_id=PRODUCT_ID_TEST, project_id=PROJECT_ID_TEST) # Then self.assertIsNone(response) delete_product_method.assert_called_once_with(name=name, retry=None, timeout=None, metadata=None) @mock.patch( "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client" ) def test_detect_text(self, annotator_client_mock): # Given detect_text_method = annotator_client_mock.text_detection detect_text_method.return_value = AnnotateImageResponse( text_annotations=[EntityAnnotation(description="test", score=0.5)]) # When self.hook.text_detection(image=DETECT_TEST_IMAGE) # Then detect_text_method.assert_called_once_with(image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None) @mock.patch( "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client" ) def 
test_detect_text_with_additional_properties(self, annotator_client_mock): # Given detect_text_method = annotator_client_mock.text_detection detect_text_method.return_value = AnnotateImageResponse( text_annotations=[EntityAnnotation(description="test", score=0.5)]) # When self.hook.text_detection(image=DETECT_TEST_IMAGE, additional_properties={ "prop1": "test1", "prop2": "test2" }) # Then detect_text_method.assert_called_once_with(image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None, prop1="test1", prop2="test2") @mock.patch( "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client" ) def test_detect_text_with_error_response(self, annotator_client_mock): # Given detect_text_method = annotator_client_mock.text_detection detect_text_method.return_value = AnnotateImageResponse( error={ "code": 3, "message": "test error message" }) # When with self.assertRaises(AirflowException) as msg: self.hook.text_detection(image=DETECT_TEST_IMAGE) err = msg.exception self.assertIn("test error message", str(err)) @mock.patch( "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client" ) def test_document_text_detection(self, annotator_client_mock): # Given document_text_detection_method = annotator_client_mock.document_text_detection document_text_detection_method.return_value = AnnotateImageResponse( text_annotations=[EntityAnnotation(description="test", score=0.5)]) # When self.hook.document_text_detection(image=DETECT_TEST_IMAGE) # Then document_text_detection_method.assert_called_once_with( image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None) @mock.patch( "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client" ) def test_document_text_detection_with_additional_properties( self, annotator_client_mock): # Given document_text_detection_method = annotator_client_mock.document_text_detection document_text_detection_method.return_value = AnnotateImageResponse( 
text_annotations=[EntityAnnotation(description="test", score=0.5)]) # When self.hook.document_text_detection(image=DETECT_TEST_IMAGE, additional_properties={ "prop1": "test1", "prop2": "test2" }) # Then document_text_detection_method.assert_called_once_with( image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None, prop1="test1", prop2="test2") @mock.patch( "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client" ) def test_detect_document_text_with_error_response(self, annotator_client_mock): # Given detect_text_method = annotator_client_mock.document_text_detection detect_text_method.return_value = AnnotateImageResponse( error={ "code": 3, "message": "test error message" }) # When with self.assertRaises(AirflowException) as msg: self.hook.document_text_detection(image=DETECT_TEST_IMAGE) err = msg.exception self.assertIn("test error message", str(err)) @mock.patch( "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client" ) def test_label_detection(self, annotator_client_mock): # Given label_detection_method = annotator_client_mock.label_detection label_detection_method.return_value = AnnotateImageResponse( label_annotations=[ EntityAnnotation(description="test", score=0.5) ]) # When self.hook.label_detection(image=DETECT_TEST_IMAGE) # Then label_detection_method.assert_called_once_with(image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None) @mock.patch( "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client" ) def test_label_detection_with_additional_properties( self, annotator_client_mock): # Given label_detection_method = annotator_client_mock.label_detection label_detection_method.return_value = AnnotateImageResponse( label_annotations=[ EntityAnnotation(description="test", score=0.5) ]) # When self.hook.label_detection(image=DETECT_TEST_IMAGE, additional_properties={ "prop1": "test1", "prop2": "test2" }) # Then label_detection_method.assert_called_once_with(image=DETECT_TEST_IMAGE, 
max_results=None, retry=None, timeout=None, prop1="test1", prop2="test2") @mock.patch( "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client" ) def test_label_detection_with_error_response(self, annotator_client_mock): # Given detect_text_method = annotator_client_mock.label_detection detect_text_method.return_value = AnnotateImageResponse( error={ "code": 3, "message": "test error message" }) # When with self.assertRaises(AirflowException) as msg: self.hook.label_detection(image=DETECT_TEST_IMAGE) err = msg.exception self.assertIn("test error message", str(err)) @mock.patch( "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client" ) def test_safe_search_detection(self, annotator_client_mock): # Given safe_search_detection_method = annotator_client_mock.safe_search_detection safe_search_detection_method.return_value = AnnotateImageResponse( safe_search_annotation=SafeSearchAnnotation( adult="VERY_UNLIKELY", spoof="VERY_UNLIKELY", medical="VERY_UNLIKELY", violence="VERY_UNLIKELY", racy="VERY_UNLIKELY", )) # When self.hook.safe_search_detection(image=DETECT_TEST_IMAGE) # Then safe_search_detection_method.assert_called_once_with( image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None) @mock.patch( "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client" ) def test_safe_search_detection_with_additional_properties( self, annotator_client_mock): # Given safe_search_detection_method = annotator_client_mock.safe_search_detection safe_search_detection_method.return_value = AnnotateImageResponse( safe_search_annotation=SafeSearchAnnotation( adult="VERY_UNLIKELY", spoof="VERY_UNLIKELY", medical="VERY_UNLIKELY", violence="VERY_UNLIKELY", racy="VERY_UNLIKELY", )) # When self.hook.safe_search_detection(image=DETECT_TEST_IMAGE, additional_properties={ "prop1": "test1", "prop2": "test2" }) # Then safe_search_detection_method.assert_called_once_with( image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None, 
prop1="test1", prop2="test2") @mock.patch( "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client" ) def test_safe_search_detection_with_error_response(self, annotator_client_mock): # Given detect_text_method = annotator_client_mock.safe_search_detection detect_text_method.return_value = AnnotateImageResponse( error={ "code": 3, "message": "test error message" }) # When with self.assertRaises(AirflowException) as msg: self.hook.safe_search_detection(image=DETECT_TEST_IMAGE) err = msg.exception self.assertIn("test error message", str(err))
class CloudVisionProductUpdateOperator(BaseOperator):
    """
    Makes changes to a Product resource. Only the display_name, description, and labels fields can be
    updated right now.

    If labels are updated, the change will not be reflected in queries until the next index time.

    .. note:: To locate the `Product` resource, its `name` in the form
        `projects/PROJECT_ID/locations/LOC_ID/products/PRODUCT_ID` is necessary.

    You can provide the `name` directly as an attribute of the `product` object. However, you can leave it
    blank and provide `location` and `product_id` instead (and optionally `project_id` - if not present,
    the connection default will be used) and the `name` will be created by the operator itself.

    This mechanism exists for your convenience, to allow leaving the `project_id` empty and having Airflow
    use the connection default `project_id`.

    Possible errors related to the provided `Product`:

    - Returns `NOT_FOUND` if the Product does not exist.
    - Returns `INVALID_ARGUMENT` if `display_name` is present in update_mask but is missing from the request
      or longer than 4096 characters.
    - Returns `INVALID_ARGUMENT` if `description` is present in update_mask but is longer than 4096
      characters.
    - Returns `INVALID_ARGUMENT` if `product_category` is present in update_mask.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudVisionProductUpdateOperator`

    :param product: (Required) The Product resource which replaces the one on the server. product.name is
        immutable. If a dict is provided, it must be of the same form as the protobuf message `Product`.
    :type product: dict or google.cloud.vision_v1.types.Product
    :param location: (Optional) The region where the Product is located. Valid regions (as of 2019-02-05) are:
        us-east1, us-west1, europe-west1, asia-east1
    :type location: str
    :param product_id: (Optional) The resource id of this Product.
    :type product_id: str
    :param project_id: (Optional) The project in which the Product is located. If set to None or
        missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param update_mask: (Optional) The `FieldMask` that specifies which fields to update. If update_mask
        isn't specified, all mutable fields are to be updated. Valid mask paths include product_labels,
        display_name, and description. If a dict is provided, it must be of the same form as the protobuf
        message `FieldMask`.
    :type update_mask: dict or google.cloud.vision_v1.types.FieldMask
    :param retry: (Optional) A retry object used to retry requests. If `None` is
        specified, requests will not be retried.
    :type retry: google.api_core.retry.Retry
    :param timeout: (Optional) The amount of time, in seconds, to wait for the request to
        complete. Note that if retry is specified, the timeout applies to each individual
        attempt.
    :type timeout: float
    :param metadata: (Optional) Additional metadata that is provided to the method.
    :type metadata: sequence[tuple[str, str]]
    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    """

    # [START vision_product_update_template_fields]
    template_fields = ('location', 'project_id', 'product_id', 'gcp_conn_id')
    # [END vision_product_update_template_fields]

    @apply_defaults
    def __init__(
        self,
        product,
        location=None,
        product_id=None,
        project_id=None,
        update_mask=None,
        retry=None,
        timeout=None,
        metadata=None,
        gcp_conn_id='google_cloud_default',
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.product = product
        self.location = location
        self.product_id = product_id
        self.project_id = project_id
        self.update_mask = update_mask
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        # NOTE: the hook is intentionally NOT created here. `gcp_conn_id` is a
        # template field, and template fields are rendered only after the
        # operator has been constructed; building the hook in the constructor
        # would bind it to the un-rendered value (and would instantiate the
        # hook every time the DAG file is parsed).

    def execute(self, context):
        """
        Update the Product through ``CloudVisionHook.update_product``.

        The hook is created here, at execution time, so that it is built from
        the current value of ``self.gcp_conn_id``.

        :param context: The task execution context (not used by this operator).
        :return: Whatever ``CloudVisionHook.update_product`` returns.
        """
        hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id)
        return hook.update_product(
            product=self.product,
            location=self.location,
            product_id=self.product_id,
            project_id=self.project_id,
            update_mask=self.update_mask,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )