class TestGcpVisionHook(unittest.TestCase):
    def setUp(self):
        """Create ``self.hook`` with ``CloudVisionHook.__init__`` swapped for a test stub.

        The patch is active only while the hook is being constructed; the
        replacement constructor is ``mock_base_gcp_hook_default_project_id``.
        """
        patched_init = mock.patch(
            'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.__init__',
            new=mock_base_gcp_hook_default_project_id,
        )
        with patched_init:
            self.hook = CloudVisionHook(gcp_conn_id='test')

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_productset_explicit_id(self, get_conn):
        """An explicitly supplied ProductSet ID is returned by ``create_product_set``."""
        # Given: the mocked client method returns nothing
        api_create = get_conn.return_value.create_product_set
        api_create.return_value = None
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        new_product_set = ProductSet()

        # When
        returned_id = self.hook.create_product_set(
            location=LOC_ID_TEST,
            product_set_id=PRODUCTSET_ID_TEST,
            product_set=new_product_set,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )

        # Then: the ID given by the caller comes back, and the client was called exactly once
        self.assertEqual(returned_id, PRODUCTSET_ID_TEST)
        expected_kwargs = dict(
            parent=expected_parent,
            product_set=new_product_set,
            product_set_id=PRODUCTSET_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        api_create.assert_called_once_with(**expected_kwargs)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_productset_autogenerated_id(self, get_conn):
        """With no explicit ID, the returned ID is the one embedded in the API response name."""
        # Given: the mocked API answers with a ProductSet that carries a generated resource name
        generated_id = 'autogen-id'
        generated_name = ProductSearchClient.product_set_path(PROJECT_ID_TEST, LOC_ID_TEST, generated_id)
        api_create = get_conn.return_value.create_product_set
        api_create.return_value = ProductSet(name=generated_name)
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        new_product_set = ProductSet()

        # When
        returned_id = self.hook.create_product_set(
            location=LOC_ID_TEST,
            product_set_id=None,
            product_set=new_product_set,
            project_id=PROJECT_ID_TEST,
        )

        # Then: the ID is taken from the response, and the API call carried product_set_id=None
        self.assertEqual(returned_id, generated_id)
        api_create.assert_called_once_with(
            parent=expected_parent,
            product_set=new_product_set,
            product_set_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_productset_autogenerated_id_wrong_api_response(self, get_conn):
        """A ``None`` API response with no explicit ID makes ``create_product_set`` raise."""
        # Given: the mocked API returns None, so there is no name to extract an ID from
        api_create = get_conn.return_value.create_product_set
        api_create.return_value = None
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        new_product_set = ProductSet()

        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.create_product_set(
                location=LOC_ID_TEST,
                product_set_id=None,
                product_set=new_product_set,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )

        # Then: the error mentions the missing name, and the API was still called exactly once
        self.assertIn('Unable to get name from response...', str(raised.exception))
        api_create.assert_called_once_with(
            parent=expected_parent,
            product_set=new_product_set,
            product_set_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_get_productset(self, get_conn):
        """``get_product_set`` returns the API's ProductSet converted with ``MessageToDict``."""
        # Given
        resource_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST
        )
        api_product_set = ProductSet(name=resource_name)
        api_get = get_conn.return_value.get_product_set
        api_get.return_value = api_product_set

        # When
        fetched = self.hook.get_product_set(
            location=LOC_ID_TEST,
            product_set_id=PRODUCTSET_ID_TEST,
            project_id=PROJECT_ID_TEST,
        )

        # Then
        self.assertTrue(fetched)
        self.assertEqual(fetched, MessageToDict(api_product_set))
        api_get.assert_called_once_with(name=resource_name, retry=None, timeout=None, metadata=None)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_productset_no_explicit_name(self, get_conn):
        """An unnamed ProductSet is sent to the API under the name built from project, location and ID."""
        # Given
        unnamed_product_set = ProductSet()
        api_update = get_conn.return_value.update_product_set
        api_update.return_value = unnamed_product_set
        constructed_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST
        )

        # When
        updated = self.hook.update_product_set(
            location=LOC_ID_TEST,
            product_set_id=PRODUCTSET_ID_TEST,
            product_set=unnamed_product_set,
            update_mask=None,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )

        # Then
        self.assertEqual(updated, MessageToDict(unnamed_product_set))
        expected_kwargs = dict(
            product_set=ProductSet(name=constructed_name),
            metadata=None,
            retry=None,
            timeout=None,
            update_mask=None,
        )
        api_update.assert_called_once_with(**expected_kwargs)

    @parameterized.expand([(None, None), (None, PRODUCTSET_ID_TEST), (LOC_ID_TEST, None)])
    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_productset_no_explicit_name_and_missing_params_for_constructed_name(
        self, location, product_set_id, get_conn
    ):
        """Updating an unnamed ProductSet fails when location and/or ID is missing.

        Parameterized over the cases where ``location``, ``product_set_id`` or both are ``None``.
        """
        # Given
        api_update = get_conn.return_value.update_product_set
        api_update.return_value = None
        unnamed_product_set = ProductSet()
        expected_message = (
            "Unable to determine the ProductSet name. Please either set the name directly in the "
            "ProductSet object or provide the `location` and `productset_id` parameters."
        )

        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.update_product_set(
                location=location,
                product_set_id=product_set_id,
                product_set=unnamed_product_set,
                update_mask=None,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )

        # Then: the hook explains what is missing and never reaches the API
        failure = raised.exception
        self.assertTrue(failure)
        self.assertIn(expected_message, str(failure))
        api_update.assert_not_called()

    @parameterized.expand([(None, None), (None, PRODUCTSET_ID_TEST), (LOC_ID_TEST, None)])
    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_productset_explicit_name_missing_params_for_constructed_name(
        self, location, product_set_id, get_conn
    ):
        """A ProductSet that already has a name is updated even if location and/or ID is missing.

        Parameterized over the cases where ``location``, ``product_set_id`` or both are ``None``.
        """
        # Given: the name is set directly on the ProductSet object
        preset_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCTSET_ID_TEST_2
        )
        named_product_set = ProductSet(name=preset_name)
        api_update = get_conn.return_value.update_product_set
        api_update.return_value = named_product_set

        # When
        updated = self.hook.update_product_set(
            location=location,
            product_set_id=product_set_id,
            product_set=named_product_set,
            update_mask=None,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )

        # Then: the explicit name is what reaches the API
        self.assertEqual(updated, MessageToDict(named_product_set))
        expected_kwargs = dict(
            product_set=ProductSet(name=preset_name),
            metadata=None,
            retry=None,
            timeout=None,
            update_mask=None,
        )
        api_update.assert_called_once_with(**expected_kwargs)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_productset_explicit_name_different_from_constructed(self, get_conn):
        """A ProductSet name that conflicts with the name built from the parameters is rejected."""
        # Given: the object's own name and the name derived from the call parameters differ
        api_update = get_conn.return_value.update_product_set
        api_update.return_value = None
        preset_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCTSET_ID_TEST_2
        )
        named_product_set = ProductSet(name=preset_name)
        constructed_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST
        )
        expected_message = (
            "The ProductSet name provided in the object ({}) is different than the name "
            "created from the input parameters ({}). Please either: 1) Remove the ProductSet "
            "name, 2) Remove the location and productset_id parameters, 3) Unify the "
            "ProductSet name and input parameters.".format(preset_name, constructed_name)
        )

        # When: location and product_set_id are passed together with the explicitly named ProductSet
        with self.assertRaises(AirflowException) as raised:
            self.hook.update_product_set(
                location=LOC_ID_TEST,
                product_set_id=PRODUCTSET_ID_TEST,
                product_set=named_product_set,
                update_mask=None,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )

        # Then: the error spells out both names and the API is never called
        failure = raised.exception
        self.assertTrue(failure)
        self.assertIn(expected_message, str(failure))
        api_update.assert_not_called()

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_delete_productset(self, get_conn):
        """``delete_product_set`` returns ``None`` and deletes by the fully qualified resource name."""
        # Given
        api_delete = get_conn.return_value.delete_product_set
        api_delete.return_value = None
        resource_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST
        )

        # When
        outcome = self.hook.delete_product_set(
            location=LOC_ID_TEST,
            product_set_id=PRODUCTSET_ID_TEST,
            project_id=PROJECT_ID_TEST,
        )

        # Then
        self.assertIsNone(outcome)
        api_delete.assert_called_once_with(name=resource_name, retry=None, timeout=None, metadata=None)

    @mock.patch(
        'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn',
        **{'return_value.create_reference_image.return_value': REFERENCE_IMAGE_TEST}
    )
    def test_create_reference_image_explicit_id(self, get_conn):
        """An explicitly supplied reference image ID is returned by ``create_reference_image``."""
        # Given: the patch above makes the client's create_reference_image return REFERENCE_IMAGE_TEST
        api_create = get_conn.return_value.create_reference_image

        # When
        returned_id = self.hook.create_reference_image(
            project_id=PROJECT_ID_TEST,
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            reference_image=REFERENCE_IMAGE_WITHOUT_ID_NAME,
            reference_image_id=REFERENCE_IMAGE_ID_TEST,
        )

        # Then: the reference image ID passed in the call is the one returned
        self.assertEqual(returned_id, REFERENCE_IMAGE_ID_TEST)
        expected_kwargs = dict(
            parent=PRODUCT_NAME,
            reference_image=REFERENCE_IMAGE_WITHOUT_ID_NAME,
            reference_image_id=REFERENCE_IMAGE_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        api_create.assert_called_once_with(**expected_kwargs)

    @mock.patch(
        'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn',
        **{'return_value.create_reference_image.return_value': REFERENCE_IMAGE_TEST}
    )
    def test_create_reference_image_autogenerated_id(self, get_conn):
        """``create_reference_image`` is expected to yield ``REFERENCE_IMAGE_GEN_ID_TEST``.

        NOTE(review): despite the test name, an explicit ``reference_image_id`` is still passed to
        the hook here - confirm whether ``None`` was intended to exercise the autogenerated path.
        """
        # Given: the patch above makes the client's create_reference_image return REFERENCE_IMAGE_TEST
        api_create = get_conn.return_value.create_reference_image

        # When
        returned_id = self.hook.create_reference_image(
            project_id=PROJECT_ID_TEST,
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            reference_image=REFERENCE_IMAGE_TEST,
            reference_image_id=REFERENCE_IMAGE_ID_TEST,
        )

        # Then
        self.assertEqual(returned_id, REFERENCE_IMAGE_GEN_ID_TEST)
        expected_kwargs = dict(
            parent=PRODUCT_NAME,
            reference_image=REFERENCE_IMAGE_TEST,
            reference_image_id=REFERENCE_IMAGE_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        api_create.assert_called_once_with(**expected_kwargs)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_add_product_to_product_set(self, get_conn):
        """``add_product_to_product_set`` forwards the ProductSet and Product names to the client."""
        # Given
        api_add = get_conn.return_value.add_product_to_product_set
        call_params = dict(
            product_set_id=PRODUCTSET_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            location=LOC_ID_TEST,
            project_id=PROJECT_ID_TEST,
        )

        # When
        self.hook.add_product_to_product_set(**call_params)

        # Then: the client was called exactly once with the expected names
        api_add.assert_called_once_with(
            name=PRODUCTSET_NAME_TEST,
            product=PRODUCT_NAME_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )

    # remove_product_from_product_set
    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_remove_product_from_product_set(self, get_conn):
        """``remove_product_from_product_set`` forwards the ProductSet and Product names to the client."""
        # Given
        api_remove = get_conn.return_value.remove_product_from_product_set
        call_params = dict(
            product_set_id=PRODUCTSET_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            location=LOC_ID_TEST,
            project_id=PROJECT_ID_TEST,
        )

        # When
        self.hook.remove_product_from_product_set(**call_params)

        # Then: the client was called exactly once with the expected names
        api_remove.assert_called_once_with(
            name=PRODUCTSET_NAME_TEST,
            product=PRODUCT_NAME_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client')
    def test_annotate_image(self, annotator_client_mock):
        """``annotate_image`` passes the request to the annotator client with default retry/timeout."""
        # Given
        api_annotate = annotator_client_mock.annotate_image

        # When
        self.hook.annotate_image(request=ANNOTATE_IMAGE_REQUEST)

        # Then: the request is forwarded unchanged, exactly once
        api_annotate.assert_called_once_with(request=ANNOTATE_IMAGE_REQUEST, retry=None, timeout=None)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_product_explicit_id(self, get_conn):
        """An explicitly supplied Product ID is returned by ``create_product``."""
        # Given: the mocked client method returns nothing
        api_create = get_conn.return_value.create_product
        api_create.return_value = None
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        new_product = Product()

        # When
        returned_id = self.hook.create_product(
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            product=new_product,
            project_id=PROJECT_ID_TEST,
        )

        # Then: the ID given by the caller comes back, and the client was called exactly once
        self.assertEqual(returned_id, PRODUCT_ID_TEST)
        expected_kwargs = dict(
            parent=expected_parent,
            product=new_product,
            product_id=PRODUCT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        api_create.assert_called_once_with(**expected_kwargs)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_product_autogenerated_id(self, get_conn):
        """With no explicit ID, the returned Product ID is the one embedded in the API response name."""
        # Given: the mocked API answers with a Product that carries a generated resource name
        generated_id = 'autogen-p-id'
        generated_name = ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, generated_id)
        api_create = get_conn.return_value.create_product
        api_create.return_value = Product(name=generated_name)
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        new_product = Product()

        # When
        returned_id = self.hook.create_product(
            location=LOC_ID_TEST,
            product_id=None,
            product=new_product,
            project_id=PROJECT_ID_TEST,
        )

        # Then: the ID is taken from the response, and the API call carried product_id=None
        self.assertEqual(returned_id, generated_id)
        api_create.assert_called_once_with(
            parent=expected_parent,
            product=new_product,
            product_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_product_autogenerated_id_wrong_name_in_response(self, get_conn):
        """A response name that is not a resource path makes ``create_product`` raise."""
        # Given: the mocked API returns a Product whose name is not a valid path
        malformed_name = 'wrong_name_not_a_correct_path'
        api_create = get_conn.return_value.create_product
        api_create.return_value = Product(name=malformed_name)
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        new_product = Product()

        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.create_product(
                location=LOC_ID_TEST,
                product_id=None,
                product=new_product,
                project_id=PROJECT_ID_TEST,
            )

        # Then: Product ID extraction fails, but the API was still called exactly once
        self.assertIn('Unable to get id from name', str(raised.exception))
        api_create.assert_called_once_with(
            parent=expected_parent,
            product=new_product,
            product_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_product_autogenerated_id_wrong_api_response(self, get_conn):
        """A ``None`` API response with no explicit ID makes ``create_product`` raise."""
        # Given: the mocked API returns None, so there is no name to extract an ID from
        api_create = get_conn.return_value.create_product
        api_create.return_value = None
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        new_product = Product()

        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.create_product(
                location=LOC_ID_TEST,
                product_id=None,
                product=new_product,
                project_id=PROJECT_ID_TEST,
            )

        # Then: Product ID extraction fails, but the API was still called exactly once
        self.assertIn('Unable to get name from response...', str(raised.exception))
        api_create.assert_called_once_with(
            parent=expected_parent,
            product=new_product,
            product_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_product_no_explicit_name(self, get_conn):
        """An unnamed Product is sent to the API under the name built from project, location and ID."""
        # Given
        unnamed_product = Product()
        api_update = get_conn.return_value.update_product
        api_update.return_value = unnamed_product
        constructed_name = ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCT_ID_TEST)

        # When
        updated = self.hook.update_product(
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            product=unnamed_product,
            update_mask=None,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )

        # Then
        self.assertEqual(updated, MessageToDict(unnamed_product))
        expected_kwargs = dict(
            product=Product(name=constructed_name),
            metadata=None,
            retry=None,
            timeout=None,
            update_mask=None,
        )
        api_update.assert_called_once_with(**expected_kwargs)

    @parameterized.expand([(None, None), (None, PRODUCT_ID_TEST), (LOC_ID_TEST, None)])
    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_product_no_explicit_name_and_missing_params_for_constructed_name(
        self, location, product_id, get_conn
    ):
        """Updating an unnamed Product fails when location and/or ID is missing.

        Parameterized over the cases where ``location``, ``product_id`` or both are ``None``.
        """
        # Given
        api_update = get_conn.return_value.update_product
        api_update.return_value = None
        unnamed_product = Product()
        expected_message = (
            "Unable to determine the Product name. Please either set the name directly in the "
            "Product object or provide the `location` and `product_id` parameters."
        )

        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.update_product(
                location=location,
                product_id=product_id,
                product=unnamed_product,
                update_mask=None,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )

        # Then: the hook explains what is missing and never reaches the API
        failure = raised.exception
        self.assertTrue(failure)
        self.assertIn(expected_message, str(failure))
        api_update.assert_not_called()

    @parameterized.expand([(None, None), (None, PRODUCT_ID_TEST), (LOC_ID_TEST, None)])
    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_product_explicit_name_missing_params_for_constructed_name(
        self, location, product_id, get_conn
    ):
        """A Product that already has a name is updated even if location and/or ID is missing.

        Parameterized over the cases where ``location``, ``product_id`` or both are ``None``.
        """
        # Given: the name is set directly on the Product object
        preset_name = ProductSearchClient.product_path(
            PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCT_ID_TEST_2
        )
        named_product = Product(name=preset_name)
        api_update = get_conn.return_value.update_product
        api_update.return_value = named_product

        # When
        updated = self.hook.update_product(
            location=location,
            product_id=product_id,
            product=named_product,
            update_mask=None,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )

        # Then: the explicit name is what reaches the API
        self.assertEqual(updated, MessageToDict(named_product))
        expected_kwargs = dict(
            product=Product(name=preset_name),
            metadata=None,
            retry=None,
            timeout=None,
            update_mask=None,
        )
        api_update.assert_called_once_with(**expected_kwargs)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_product_explicit_name_different_from_constructed(self, get_conn):
        """A Product name that conflicts with the name built from the parameters is rejected."""
        # Given: the object's own name and the name derived from the call parameters differ
        api_update = get_conn.return_value.update_product
        api_update.return_value = None
        preset_name = ProductSearchClient.product_path(
            PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCT_ID_TEST_2
        )
        named_product = Product(name=preset_name)
        constructed_name = ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCT_ID_TEST)
        expected_message = (
            "The Product name provided in the object ({}) is different than the name created from the input "
            "parameters ({}). Please either: 1) Remove the Product name, 2) Remove the location and product_"
            "id parameters, 3) Unify the Product name and input parameters.".format(
                preset_name, constructed_name
            )
        )

        # When: location and product_id are passed together with the explicitly named Product
        with self.assertRaises(AirflowException) as raised:
            self.hook.update_product(
                location=LOC_ID_TEST,
                product_id=PRODUCT_ID_TEST,
                product=named_product,
                update_mask=None,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )

        # Then: the error spells out both names and the API is never called
        failure = raised.exception
        self.assertTrue(failure)
        self.assertIn(expected_message, str(failure))
        api_update.assert_not_called()

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_delete_product(self, get_conn):
        """``delete_product`` returns ``None`` and deletes by the fully qualified resource name."""
        # Given
        api_delete = get_conn.return_value.delete_product
        api_delete.return_value = None
        resource_name = ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCT_ID_TEST)

        # When
        outcome = self.hook.delete_product(
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            project_id=PROJECT_ID_TEST,
        )

        # Then
        self.assertIsNone(outcome)
        api_delete.assert_called_once_with(name=resource_name, retry=None, timeout=None, metadata=None)
Example #2
0
class CloudVisionProductUpdateOperator(BaseOperator):
    """
    Makes changes to a Product resource. Only the display_name, description, and labels fields can be
    updated right now.

    If labels are updated, the change will not be reflected in queries until the next index time.

    .. note:: To locate the `Product` resource, its `name` in the form
        `projects/PROJECT_ID/locations/LOC_ID/products/PRODUCT_ID` is necessary.

    You can provide the `name` directly as an attribute of the `product` object. However, you can leave it
    blank and provide `location` and `product_id` instead (and optionally `project_id` - if not present,
    the connection default will be used) and the `name` will be created by the operator itself.

    This mechanism exists for your convenience, to allow leaving the `project_id` empty and having Airflow
    use the connection default `project_id`.

    Possible errors related to the provided `Product`:

    - Returns NOT_FOUND if the Product does not exist.
    - Returns INVALID_ARGUMENT if display_name is present in update_mask but is missing from the request or
        longer than 4096 characters.
    - Returns INVALID_ARGUMENT if description is present in update_mask but is longer than 4096 characters.
    - Returns INVALID_ARGUMENT if product_category is present in update_mask.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudVisionProductUpdateOperator`

    :param product: (Required) The Product resource which replaces the one on the server. product.name is
        immutable. If a dict is provided, it must be of the same form as the protobuf message `Product`.
    :type product: dict or google.cloud.vision_v1.types.Product
    :param location: (Optional) The region where the Product is located. Valid regions (as of 2019-02-05) are:
        us-east1, us-west1, europe-west1, asia-east1
    :type location: str
    :param product_id: (Optional) The resource id of this Product.
    :type product_id: str
    :param project_id: (Optional) The project in which the Product is located. If set to None or
        missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param update_mask: (Optional) The `FieldMask` that specifies which fields to update. If update_mask
        isn't specified, all mutable fields are to be updated. Valid mask paths include product_labels,
        display_name, and description. If a dict is provided, it must be of the same form as the protobuf
        message `FieldMask`.
    :type update_mask: dict or google.cloud.vision_v1.types.FieldMask
    :param retry: (Optional) A retry object used to retry requests. If `None` is
        specified, requests will not be retried.
    :type retry: google.api_core.retry.Retry
    :param timeout: (Optional) The amount of time, in seconds, to wait for the request to
        complete. Note that if retry is specified, the timeout applies to each individual
        attempt.
    :type timeout: float
    :param metadata: (Optional) Additional metadata that is provided to the method.
    :type metadata: Sequence[Tuple[str, str]]
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    """

    # [START vision_product_update_template_fields]
    template_fields = ('location', 'project_id', 'product_id', 'gcp_conn_id')
    # [END vision_product_update_template_fields]

    @apply_defaults
    def __init__(
        self,
        product,
        location=None,
        product_id=None,
        project_id=None,
        update_mask=None,
        retry=None,
        timeout=None,
        metadata=None,
        gcp_conn_id='google_cloud_default',
        *args,
        **kwargs
    ):
        super(CloudVisionProductUpdateOperator, self).__init__(*args, **kwargs)
        self.product = product
        self.location = location
        self.product_id = product_id
        self.project_id = project_id
        self.update_mask = update_mask
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        # The hook is deliberately NOT built here. `gcp_conn_id` is listed in `template_fields`,
        # so its final value is only known once templates have been rendered, which happens after
        # construction. Deferring also keeps operator instantiation free of hook set-up work.
        self._hook = None

    def execute(self, context):
        """
        Create the hook from the (by now rendered) ``gcp_conn_id`` and update the Product.

        :param context: Airflow task context; not used by this operator.
        :return: whatever ``CloudVisionHook.update_product`` returns.
        """
        self._hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id)
        return self._hook.update_product(
            product=self.product,
            location=self.location,
            product_id=self.product_id,
            project_id=self.project_id,
            update_mask=self.update_mask,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
Example #3
0
class TestGcpVisionHook(unittest.TestCase):
    def setUp(self):
        # Build the hook under test while CloudVisionHook.__init__ is swapped for the
        # project's replacement constructor; the patch is lifted once the hook exists.
        init_patch = mock.patch(
            'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.__init__',
            new=mock_base_gcp_hook_default_project_id,
        )
        with init_patch:
            self.hook = CloudVisionHook(gcp_conn_id='test')

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_productset_explicit_id(self, get_conn):
        # An explicitly supplied ProductSet ID must come back unchanged from the hook,
        # even though the mocked client call returns None.
        # Given
        client_create = get_conn.return_value.create_product_set
        client_create.return_value = None
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        request_product_set = ProductSet()
        # When
        returned_id = self.hook.create_product_set(
            location=LOC_ID_TEST,
            product_set_id=PRODUCTSET_ID_TEST,
            product_set=request_product_set,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        # Then
        self.assertEqual(returned_id, PRODUCTSET_ID_TEST)
        expected_call = dict(
            parent=expected_parent,
            product_set=request_product_set,
            product_set_id=PRODUCTSET_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        client_create.assert_called_once_with(**expected_call)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_productset_autogenerated_id(self, get_conn):
        # No ProductSet ID is passed in, so the hook is expected to return the ID found in
        # the name of the ProductSet returned by the mocked client.
        # Given
        generated_id = 'autogen-id'
        generated_name = ProductSearchClient.product_set_path(PROJECT_ID_TEST, LOC_ID_TEST, generated_id)
        api_response = ProductSet(name=generated_name)
        client_create = get_conn.return_value.create_product_set
        client_create.return_value = api_response
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        request_product_set = ProductSet()
        # When
        returned_id = self.hook.create_product_set(
            location=LOC_ID_TEST,
            product_set_id=None,
            product_set=request_product_set,
            project_id=PROJECT_ID_TEST,
        )
        # Then
        self.assertEqual(returned_id, generated_id)
        client_create.assert_called_once_with(
            parent=expected_parent,
            product_set=request_product_set,
            product_set_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_productset_autogenerated_id_wrong_api_response(self, get_conn):
        # The mocked client returns None and no ProductSet ID is supplied, so the hook has
        # nothing to extract an ID from and is expected to raise AirflowException.
        # Given
        client_create = get_conn.return_value.create_product_set
        client_create.return_value = None
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        request_product_set = ProductSet()
        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.create_product_set(
                location=LOC_ID_TEST,
                product_set_id=None,
                product_set=request_product_set,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )
        # Then
        self.assertIn('Unable to get name from response...', str(raised.exception))
        expected_call = dict(
            parent=expected_parent,
            product_set=request_product_set,
            product_set_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )
        client_create.assert_called_once_with(**expected_call)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_get_productset(self, get_conn):
        # The hook should look the ProductSet up by its full resource name and return the
        # client's response converted with MessageToDict.
        # Given
        resource_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST
        )
        api_response = ProductSet(name=resource_name)
        client_get = get_conn.return_value.get_product_set
        client_get.return_value = api_response
        # When
        fetched = self.hook.get_product_set(
            location=LOC_ID_TEST,
            product_set_id=PRODUCTSET_ID_TEST,
            project_id=PROJECT_ID_TEST,
        )
        # Then
        self.assertTrue(fetched)
        self.assertEqual(fetched, MessageToDict(api_response))
        client_get.assert_called_once_with(name=resource_name, retry=None, timeout=None, metadata=None)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_productset_no_explicit_name(self, get_conn):
        # The ProductSet carries no name, so the client is expected to receive a ProductSet
        # named after the project, location and ProductSet ID passed to the hook.
        # Given
        unnamed_product_set = ProductSet()
        client_update = get_conn.return_value.update_product_set
        client_update.return_value = unnamed_product_set
        constructed_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST
        )
        # When
        outcome = self.hook.update_product_set(
            location=LOC_ID_TEST,
            product_set_id=PRODUCTSET_ID_TEST,
            product_set=unnamed_product_set,
            update_mask=None,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        # Then
        self.assertEqual(outcome, MessageToDict(unnamed_product_set))
        client_update.assert_called_once_with(
            product_set=ProductSet(name=constructed_name),
            metadata=None,
            retry=None,
            timeout=None,
            update_mask=None,
        )

    @parameterized.expand([(None, None), (None, PRODUCTSET_ID_TEST), (LOC_ID_TEST, None)])
    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_productset_no_explicit_name_and_missing_params_for_constructed_name(
        self, location, product_set_id, get_conn
    ):
        # Neither the ProductSet nor the (location, product_set_id) pair gives a complete
        # name, so the hook is expected to raise without ever calling the client.
        # Given
        client_update = get_conn.return_value.update_product_set
        client_update.return_value = None
        unnamed_product_set = ProductSet()
        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.update_product_set(
                location=location,
                product_set_id=product_set_id,
                product_set=unnamed_product_set,
                update_mask=None,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )
        # Then
        failure = raised.exception
        self.assertTrue(failure)
        expected_message = ERR_UNABLE_TO_CREATE.format(label='ProductSet', id_label='productset_id')
        self.assertIn(expected_message, str(failure))
        client_update.assert_not_called()

    @parameterized.expand([(None, None), (None, PRODUCTSET_ID_TEST), (LOC_ID_TEST, None)])
    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_productset_explicit_name_missing_params_for_constructed_name(
        self, location, product_set_id, get_conn
    ):
        # The ProductSet already has an explicit name; with location and/or product_set_id
        # missing, the client is expected to be called with that explicit name.
        # Given
        explicit_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCTSET_ID_TEST_2
        )
        named_product_set = ProductSet(name=explicit_name)
        client_update = get_conn.return_value.update_product_set
        client_update.return_value = named_product_set
        # When
        outcome = self.hook.update_product_set(
            location=location,
            product_set_id=product_set_id,
            product_set=named_product_set,
            update_mask=None,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        # Then
        self.assertEqual(outcome, MessageToDict(named_product_set))
        client_update.assert_called_once_with(
            product_set=ProductSet(name=explicit_name),
            metadata=None,
            retry=None,
            timeout=None,
            update_mask=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_productset_explicit_name_different_from_constructed(self, get_conn):
        # The ProductSet has an explicit name AND location/product_set_id are supplied, but
        # the name built from those parameters differs from the explicit one. The hook is
        # expected to raise AirflowException and leave the client untouched.
        # Given
        client_update = get_conn.return_value.update_product_set
        client_update.return_value = None
        explicit_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCTSET_ID_TEST_2
        )
        named_product_set = ProductSet(name=explicit_name)
        constructed_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST
        )
        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.update_product_set(
                location=LOC_ID_TEST,
                product_set_id=PRODUCTSET_ID_TEST,
                product_set=named_product_set,
                update_mask=None,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )
        # Then
        failure = raised.exception
        self.assertTrue(failure)
        expected_message = ERR_DIFF_NAMES.format(
            explicit_name=explicit_name,
            constructed_name=constructed_name,
            label="ProductSet",
            id_label="productset_id",
        )
        self.assertIn(expected_message, str(failure))
        client_update.assert_not_called()

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_delete_productset(self, get_conn):
        # Deleting should address the ProductSet by its full resource name and return None.
        # Given
        client_delete = get_conn.return_value.delete_product_set
        client_delete.return_value = None
        resource_name = ProductSearchClient.product_set_path(
            PROJECT_ID_TEST, LOC_ID_TEST, PRODUCTSET_ID_TEST
        )
        # When
        outcome = self.hook.delete_product_set(
            location=LOC_ID_TEST,
            product_set_id=PRODUCTSET_ID_TEST,
            project_id=PROJECT_ID_TEST,
        )
        # Then
        self.assertIsNone(outcome)
        client_delete.assert_called_once_with(name=resource_name, retry=None, timeout=None, metadata=None)

    @mock.patch(
        'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn',
        **{'return_value.create_reference_image.return_value': REFERENCE_IMAGE_TEST}
    )
    def test_create_reference_image_explicit_id(self, get_conn):
        # The reference image ID is supplied explicitly, so the hook is expected to return
        # that same ID.
        # Given
        client_create = get_conn.return_value.create_reference_image
        # When
        returned_id = self.hook.create_reference_image(
            project_id=PROJECT_ID_TEST,
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            reference_image=REFERENCE_IMAGE_WITHOUT_ID_NAME,
            reference_image_id=REFERENCE_IMAGE_ID_TEST,
        )
        # Then
        self.assertEqual(returned_id, REFERENCE_IMAGE_ID_TEST)
        expected_call = dict(
            parent=PRODUCT_NAME,
            reference_image=REFERENCE_IMAGE_WITHOUT_ID_NAME,
            reference_image_id=REFERENCE_IMAGE_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        client_create.assert_called_once_with(**expected_call)

    @mock.patch(
        'airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn',
        **{'return_value.create_reference_image.return_value': REFERENCE_IMAGE_TEST}
    )
    def test_create_reference_image_autogenerated_id(self, get_conn):
        # The hook's return value is compared against REFERENCE_IMAGE_GEN_ID_TEST.
        # NOTE(review): despite the test name, reference_image_id is passed explicitly
        # below - confirm this is the intended setup for the autogenerated-ID case.
        # Given
        client_create = get_conn.return_value.create_reference_image
        # When
        returned_id = self.hook.create_reference_image(
            project_id=PROJECT_ID_TEST,
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            reference_image=REFERENCE_IMAGE_TEST,
            reference_image_id=REFERENCE_IMAGE_ID_TEST,
        )
        # Then
        self.assertEqual(returned_id, REFERENCE_IMAGE_GEN_ID_TEST)
        expected_call = dict(
            parent=PRODUCT_NAME,
            reference_image=REFERENCE_IMAGE_TEST,
            reference_image_id=REFERENCE_IMAGE_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        client_create.assert_called_once_with(**expected_call)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_add_product_to_product_set(self, get_conn):
        # Only the delegation is checked here: the client must be called once with the
        # ProductSet name, the Product name and default request options.
        # Given
        client_add = get_conn.return_value.add_product_to_product_set
        # When
        self.hook.add_product_to_product_set(
            product_set_id=PRODUCTSET_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            location=LOC_ID_TEST,
            project_id=PROJECT_ID_TEST,
        )
        # Then
        expected_call = dict(
            name=PRODUCTSET_NAME_TEST,
            product=PRODUCT_NAME_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        client_add.assert_called_once_with(**expected_call)

    # remove_product_from_product_set
    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_remove_product_from_product_set(self, get_conn):
        # Only the delegation is checked here: the client must be called once with the
        # ProductSet name, the Product name and default request options.
        # Given
        client_remove = get_conn.return_value.remove_product_from_product_set
        # When
        self.hook.remove_product_from_product_set(
            product_set_id=PRODUCTSET_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            location=LOC_ID_TEST,
            project_id=PROJECT_ID_TEST,
        )
        # Then
        expected_call = dict(
            name=PRODUCTSET_NAME_TEST,
            product=PRODUCT_NAME_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        client_remove.assert_called_once_with(**expected_call)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client')
    def test_annotate_image(self, annotator_client_mock):
        # The request must be handed to the annotator client unchanged, with no retry or
        # timeout set.
        # Given
        client_annotate = annotator_client_mock.annotate_image
        # When
        self.hook.annotate_image(request=ANNOTATE_IMAGE_REQUEST)
        # Then
        client_annotate.assert_called_once_with(
            request=ANNOTATE_IMAGE_REQUEST,
            retry=None,
            timeout=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client')
    def test_batch_annotate_images(self, annotator_client_mock):
        # The batch of requests must be handed to the annotator client unchanged, with no
        # retry or timeout set.
        # Given
        client_batch_annotate = annotator_client_mock.batch_annotate_images
        # When
        self.hook.batch_annotate_images(requests=BATCH_ANNOTATE_IMAGE_REQUEST)
        # Then
        client_batch_annotate.assert_called_once_with(
            requests=BATCH_ANNOTATE_IMAGE_REQUEST,
            retry=None,
            timeout=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_product_explicit_id(self, get_conn):
        # An explicitly supplied Product ID must come back unchanged from the hook, even
        # though the mocked client call returns None.
        # Given
        client_create = get_conn.return_value.create_product
        client_create.return_value = None
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        request_product = Product()
        # When
        returned_id = self.hook.create_product(
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            product=request_product,
            project_id=PROJECT_ID_TEST,
        )
        # Then
        self.assertEqual(returned_id, PRODUCT_ID_TEST)
        expected_call = dict(
            parent=expected_parent,
            product=request_product,
            product_id=PRODUCT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        client_create.assert_called_once_with(**expected_call)

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_product_autogenerated_id(self, get_conn):
        # No Product ID is passed in, so the hook is expected to return the ID found in the
        # name of the Product returned by the mocked client.
        # Given
        generated_id = 'autogen-p-id'
        generated_name = ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, generated_id)
        api_response = Product(name=generated_name)
        client_create = get_conn.return_value.create_product
        client_create.return_value = api_response
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        request_product = Product()
        # When
        returned_id = self.hook.create_product(
            location=LOC_ID_TEST,
            product_id=None,
            product=request_product,
            project_id=PROJECT_ID_TEST,
        )
        # Then
        self.assertEqual(returned_id, generated_id)
        client_create.assert_called_once_with(
            parent=expected_parent,
            product=request_product,
            product_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_product_autogenerated_id_wrong_name_in_response(self, get_conn):
        # The mocked client returns a Product whose name is not a resource path, so
        # extracting the Product ID from it is expected to fail with AirflowException.
        # Given
        malformed_name = 'wrong_name_not_a_correct_path'
        api_response = Product(name=malformed_name)
        client_create = get_conn.return_value.create_product
        client_create.return_value = api_response
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        request_product = Product()
        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.create_product(
                location=LOC_ID_TEST,
                product_id=None,
                product=request_product,
                project_id=PROJECT_ID_TEST,
            )
        # Then
        self.assertIn('Unable to get id from name', str(raised.exception))
        client_create.assert_called_once_with(
            parent=expected_parent,
            product=request_product,
            product_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_create_product_autogenerated_id_wrong_api_response(self, get_conn):
        # The mocked client returns None and no Product ID is supplied, so the hook has
        # nothing to extract an ID from and is expected to raise AirflowException.
        # Given
        client_create = get_conn.return_value.create_product
        client_create.return_value = None
        expected_parent = ProductSearchClient.location_path(PROJECT_ID_TEST, LOC_ID_TEST)
        request_product = Product()
        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.create_product(
                location=LOC_ID_TEST,
                product_id=None,
                product=request_product,
                project_id=PROJECT_ID_TEST,
            )
        # Then
        self.assertIn('Unable to get name from response...', str(raised.exception))
        client_create.assert_called_once_with(
            parent=expected_parent,
            product=request_product,
            product_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_product_no_explicit_name(self, get_conn):
        # The Product carries no name, so the client is expected to receive a Product named
        # after the project, location and Product ID passed to the hook.
        # Given
        unnamed_product = Product()
        client_update = get_conn.return_value.update_product
        client_update.return_value = unnamed_product
        constructed_name = ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCT_ID_TEST)
        # When
        outcome = self.hook.update_product(
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            product=unnamed_product,
            update_mask=None,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        # Then
        self.assertEqual(outcome, MessageToDict(unnamed_product))
        client_update.assert_called_once_with(
            product=Product(name=constructed_name),
            metadata=None,
            retry=None,
            timeout=None,
            update_mask=None,
        )

    @parameterized.expand([(None, None), (None, PRODUCT_ID_TEST), (LOC_ID_TEST, None)])
    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_product_no_explicit_name_and_missing_params_for_constructed_name(
        self, location, product_id, get_conn
    ):
        # Neither the Product nor the (location, product_id) pair gives a complete name, so
        # the hook is expected to raise without ever calling the client.
        # Given
        client_update = get_conn.return_value.update_product
        client_update.return_value = None
        unnamed_product = Product()
        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.update_product(
                location=location,
                product_id=product_id,
                product=unnamed_product,
                update_mask=None,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )
        # Then
        failure = raised.exception
        self.assertTrue(failure)
        expected_message = ERR_UNABLE_TO_CREATE.format(label='Product', id_label='product_id')
        self.assertIn(expected_message, str(failure))
        client_update.assert_not_called()

    @parameterized.expand([(None, None), (None, PRODUCT_ID_TEST), (LOC_ID_TEST, None)])
    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_product_explicit_name_missing_params_for_constructed_name(
        self, location, product_id, get_conn
    ):
        # The Product already has an explicit name; with location and/or product_id
        # missing, the client is expected to be called with that explicit name.
        # Given
        explicit_name = ProductSearchClient.product_path(
            PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCT_ID_TEST_2
        )
        named_product = Product(name=explicit_name)
        client_update = get_conn.return_value.update_product
        client_update.return_value = named_product
        # When
        outcome = self.hook.update_product(
            location=location,
            product_id=product_id,
            product=named_product,
            update_mask=None,
            project_id=PROJECT_ID_TEST,
            retry=None,
            timeout=None,
            metadata=None,
        )
        # Then
        self.assertEqual(outcome, MessageToDict(named_product))
        client_update.assert_called_once_with(
            product=Product(name=explicit_name),
            metadata=None,
            retry=None,
            timeout=None,
            update_mask=None,
        )

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_update_product_explicit_name_different_from_constructed(self, get_conn):
        # The Product has an explicit name AND location/product_id are supplied, but the
        # name built from those parameters differs from the explicit one. The hook is
        # expected to raise AirflowException and leave the client untouched.
        # Given
        client_update = get_conn.return_value.update_product
        client_update.return_value = None
        explicit_name = ProductSearchClient.product_path(
            PROJECT_ID_TEST_2, LOC_ID_TEST_2, PRODUCT_ID_TEST_2
        )
        named_product = Product(name=explicit_name)
        constructed_name = ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCT_ID_TEST)
        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.update_product(
                location=LOC_ID_TEST,
                product_id=PRODUCT_ID_TEST,
                product=named_product,
                update_mask=None,
                project_id=PROJECT_ID_TEST,
                retry=None,
                timeout=None,
                metadata=None,
            )
        # Then
        failure = raised.exception
        self.assertTrue(failure)
        expected_message = ERR_DIFF_NAMES.format(
            explicit_name=explicit_name,
            constructed_name=constructed_name,
            label="Product",
            id_label="product_id",
        )
        self.assertIn(expected_message, str(failure))
        client_update.assert_not_called()

    @mock.patch('airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.get_conn')
    def test_delete_product(self, get_conn):
        # Deleting should address the Product by its full resource name and return None.
        # Given
        client_delete = get_conn.return_value.delete_product
        client_delete.return_value = None
        resource_name = ProductSearchClient.product_path(PROJECT_ID_TEST, LOC_ID_TEST, PRODUCT_ID_TEST)
        # When
        outcome = self.hook.delete_product(
            location=LOC_ID_TEST,
            product_id=PRODUCT_ID_TEST,
            project_id=PROJECT_ID_TEST,
        )
        # Then
        self.assertIsNone(outcome)
        client_delete.assert_called_once_with(name=resource_name, retry=None, timeout=None, metadata=None)

    @mock.patch("airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client")
    def test_detect_text(self, annotator_client_mock):
        # With only an image given, the client's text_detection must be called once with
        # that image and None for max_results, retry and timeout.
        # Given
        annotation = EntityAnnotation(description="test", score=0.5)
        client_text_detection = annotator_client_mock.text_detection
        client_text_detection.return_value = AnnotateImageResponse(text_annotations=[annotation])

        # When
        self.hook.text_detection(image=DETECT_TEST_IMAGE)

        # Then
        expected_call = dict(image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None)
        client_text_detection.assert_called_once_with(**expected_call)

    @mock.patch("airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client")
    def test_detect_text_with_additional_properties(self, annotator_client_mock):
        # Entries of additional_properties are expected to reach the client's
        # text_detection call as extra keyword arguments.
        # Given
        annotation = EntityAnnotation(description="test", score=0.5)
        client_text_detection = annotator_client_mock.text_detection
        client_text_detection.return_value = AnnotateImageResponse(text_annotations=[annotation])

        # When
        self.hook.text_detection(
            image=DETECT_TEST_IMAGE,
            additional_properties={"prop1": "test1", "prop2": "test2"},
        )

        # Then
        client_text_detection.assert_called_once_with(
            image=DETECT_TEST_IMAGE,
            max_results=None,
            retry=None,
            timeout=None,
            prop1="test1",
            prop2="test2",
        )

    @mock.patch("airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client")
    def test_detect_text_with_error_response(self, annotator_client_mock):
        # A response carrying an error must surface as AirflowException whose text
        # includes the error message from the response.
        # Given
        error_payload = {"code": 3, "message": "test error message"}
        client_text_detection = annotator_client_mock.text_detection
        client_text_detection.return_value = AnnotateImageResponse(error=error_payload)

        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.text_detection(image=DETECT_TEST_IMAGE)

        # Then
        self.assertIn("test error message", str(raised.exception))

    @mock.patch("airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client")
    def test_document_text_detection(self, annotator_client_mock):
        # With only an image given, the client's document_text_detection must be called
        # once with that image and None for max_results, retry and timeout.
        # Given
        annotation = EntityAnnotation(description="test", score=0.5)
        client_doc_detection = annotator_client_mock.document_text_detection
        client_doc_detection.return_value = AnnotateImageResponse(text_annotations=[annotation])

        # When
        self.hook.document_text_detection(image=DETECT_TEST_IMAGE)

        # Then
        expected_call = dict(image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None)
        client_doc_detection.assert_called_once_with(**expected_call)

    @mock.patch(
        "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client"
    )
    def test_document_text_detection_with_additional_properties(self, annotator_client_mock):
        # Entries of `additional_properties` are expected to reach the client call
        # as individual keyword arguments next to the default options.

        # Given
        extra_props = {"prop1": "test1", "prop2": "test2"}
        api_call = annotator_client_mock.document_text_detection
        annotation = EntityAnnotation(description="test", score=0.5)
        api_call.return_value = AnnotateImageResponse(text_annotations=[annotation])

        # When
        # A copy is handed over so the expectation below never depends on what the
        # hook does with the dict it receives.
        self.hook.document_text_detection(
            image=DETECT_TEST_IMAGE,
            additional_properties=dict(extra_props),
        )

        # Then
        api_call.assert_called_once_with(
            image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None, **extra_props
        )

    @mock.patch(
        "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client"
    )
    def test_detect_document_text_with_error_response(self, annotator_client_mock):
        # An error embedded in the annotate response must surface from the hook as
        # an AirflowException whose text contains the error's message.

        # Given
        api_error = {"code": 3, "message": "test error message"}
        document_text_call = annotator_client_mock.document_text_detection
        document_text_call.return_value = AnnotateImageResponse(error=api_error)

        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.document_text_detection(image=DETECT_TEST_IMAGE)

        # Then
        self.assertIn("test error message", str(raised.exception))

    @mock.patch(
        "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client"
    )
    def test_label_detection(self, annotator_client_mock):
        # With only an image supplied, the hook should call the client's
        # label_detection exactly once, passing None for every optional argument.

        # Given
        api_call = annotator_client_mock.label_detection
        annotation = EntityAnnotation(description="test", score=0.5)
        api_call.return_value = AnnotateImageResponse(label_annotations=[annotation])

        # When
        self.hook.label_detection(image=DETECT_TEST_IMAGE)

        # Then
        expected_kwargs = dict(image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None)
        api_call.assert_called_once_with(**expected_kwargs)

    @mock.patch(
        "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client"
    )
    def test_label_detection_with_additional_properties(self, annotator_client_mock):
        # Entries of `additional_properties` are expected to reach the client call
        # as individual keyword arguments next to the default options.

        # Given
        extra_props = {"prop1": "test1", "prop2": "test2"}
        api_call = annotator_client_mock.label_detection
        annotation = EntityAnnotation(description="test", score=0.5)
        api_call.return_value = AnnotateImageResponse(label_annotations=[annotation])

        # When
        # A copy is handed over so the expectation below never depends on what the
        # hook does with the dict it receives.
        self.hook.label_detection(
            image=DETECT_TEST_IMAGE,
            additional_properties=dict(extra_props),
        )

        # Then
        api_call.assert_called_once_with(
            image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None, **extra_props
        )

    @mock.patch(
        "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client"
    )
    def test_label_detection_with_error_response(self, annotator_client_mock):
        # An error embedded in the annotate response must surface from the hook as
        # an AirflowException whose text contains the error's message.

        # Given
        api_error = {"code": 3, "message": "test error message"}
        label_detection_call = annotator_client_mock.label_detection
        label_detection_call.return_value = AnnotateImageResponse(error=api_error)

        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.label_detection(image=DETECT_TEST_IMAGE)

        # Then
        self.assertIn("test error message", str(raised.exception))

    @mock.patch(
        "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client"
    )
    def test_safe_search_detection(self, annotator_client_mock):
        # With only an image supplied, the hook should call the client's
        # safe_search_detection exactly once, passing None for every optional argument.

        # Given
        api_call = annotator_client_mock.safe_search_detection
        # Every likelihood field of the stubbed annotation gets the same value.
        likelihoods = dict.fromkeys(("adult", "spoof", "medical", "violence", "racy"), "VERY_UNLIKELY")
        api_call.return_value = AnnotateImageResponse(
            safe_search_annotation=SafeSearchAnnotation(**likelihoods)
        )

        # When
        self.hook.safe_search_detection(image=DETECT_TEST_IMAGE)

        # Then
        expected_kwargs = dict(image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None)
        api_call.assert_called_once_with(**expected_kwargs)

    @mock.patch(
        "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client"
    )
    def test_safe_search_detection_with_additional_properties(self, annotator_client_mock):
        # Entries of `additional_properties` are expected to reach the client call
        # as individual keyword arguments next to the default options.

        # Given
        extra_props = {"prop1": "test1", "prop2": "test2"}
        api_call = annotator_client_mock.safe_search_detection
        # Every likelihood field of the stubbed annotation gets the same value.
        likelihoods = dict.fromkeys(("adult", "spoof", "medical", "violence", "racy"), "VERY_UNLIKELY")
        api_call.return_value = AnnotateImageResponse(
            safe_search_annotation=SafeSearchAnnotation(**likelihoods)
        )

        # When
        # A copy is handed over so the expectation below never depends on what the
        # hook does with the dict it receives.
        self.hook.safe_search_detection(
            image=DETECT_TEST_IMAGE,
            additional_properties=dict(extra_props),
        )

        # Then
        api_call.assert_called_once_with(
            image=DETECT_TEST_IMAGE, max_results=None, retry=None, timeout=None, **extra_props
        )

    @mock.patch(
        "airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook.annotator_client"
    )
    def test_safe_search_detection_with_error_response(self, annotator_client_mock):
        # An error embedded in the annotate response must surface from the hook as
        # an AirflowException whose text contains the error's message.

        # Given
        api_error = {"code": 3, "message": "test error message"}
        safe_search_call = annotator_client_mock.safe_search_detection
        safe_search_call.return_value = AnnotateImageResponse(error=api_error)

        # When
        with self.assertRaises(AirflowException) as raised:
            self.hook.safe_search_detection(image=DETECT_TEST_IMAGE)

        # Then
        self.assertIn("test error message", str(raised.exception))
# Example #4
# 0
class CloudVisionProductUpdateOperator(BaseOperator):
    """
    Makes changes to a Product resource. Only the display_name, description, and labels fields can be
    updated right now.

    If labels are updated, the change will not be reflected in queries until the next index time.

    .. note:: To locate the `Product` resource, its `name` in the form
        `projects/PROJECT_ID/locations/LOC_ID/products/PRODUCT_ID` is necessary.

    You can provide the `name` directly as an attribute of the `product` object. However, you can leave it
    blank and provide `location` and `product_id` instead (and optionally `project_id` - if not present,
    the connection default will be used) and the `name` will be created by the operator itself.

    This mechanism exists for your convenience, to allow leaving the `project_id` empty and having Airflow
    use the connection default `project_id`.

    Possible errors related to the provided `Product`:

    - Returns `NOT_FOUND` if the Product does not exist.
    - Returns `INVALID_ARGUMENT` if `display_name` is present in update_mask but is missing from the request
        or longer than 4096 characters.
    - Returns `INVALID_ARGUMENT` if `description` is present in update_mask but is longer than 4096
        characters.
    - Returns `INVALID_ARGUMENT` if `product_category` is present in update_mask.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudVisionProductUpdateOperator`

    :param product: (Required) The Product resource which replaces the one on the server. product.name is
        immutable. If a dict is provided, it must be of the same form as the protobuf message `Product`.
    :type product: dict or google.cloud.vision_v1.types.Product
    :param location: (Optional) The region where the Product is located. Valid regions (as of 2019-02-05) are:
        us-east1, us-west1, europe-west1, asia-east1
    :type location: str
    :param product_id: (Optional) The resource id of this Product.
    :type product_id: str
    :param project_id: (Optional) The project in which the Product is located. If set to None or
        missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param update_mask: (Optional) The `FieldMask` that specifies which fields to update. If update_mask
        is not specified, all mutable fields are to be updated. Valid mask paths include product_labels,
        display_name, and description. If a dict is provided, it must be of the same form as the protobuf
        message `FieldMask`.
    :type update_mask: dict or google.cloud.vision_v1.types.FieldMask
    :param retry: (Optional) A retry object used to retry requests. If `None` is
        specified, requests will not be retried.
    :type retry: google.api_core.retry.Retry
    :param timeout: (Optional) The amount of time, in seconds, to wait for the request to
        complete. Note that if retry is specified, the timeout applies to each individual
        attempt.
    :type timeout: float
    :param metadata: (Optional) Additional metadata that is provided to the method.
    :type metadata: sequence[tuple[str, str]]
    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    """
    # [START vision_product_update_template_fields]
    template_fields = ('location', 'project_id', 'product_id', 'gcp_conn_id')
    # [END vision_product_update_template_fields]

    @apply_defaults
    def __init__(
        self,
        product,
        location=None,
        product_id=None,
        project_id=None,
        update_mask=None,
        retry=None,
        timeout=None,
        metadata=None,
        gcp_conn_id='google_cloud_default',
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # Only store the arguments here. The hook is intentionally NOT created in the
        # constructor: `gcp_conn_id` is a template field, so its final value is not
        # known until the task is about to run.
        self.product = product
        self.location = location
        self.product_id = product_id
        self.project_id = project_id
        self.update_mask = update_mask
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id

    def execute(self, context):
        """
        Update the Product through ``CloudVisionHook.update_product``.

        :param context: Airflow task context; not used by this operator.
        :return: whatever ``CloudVisionHook.update_product`` returns.
        """
        # Build the hook at run time so that the rendered value of the templated
        # `gcp_conn_id` is used and no hook is created while the DAG file is parsed.
        hook = CloudVisionHook(gcp_conn_id=self.gcp_conn_id)
        return hook.update_product(
            product=self.product,
            location=self.location,
            product_id=self.product_id,
            project_id=self.project_id,
            update_mask=self.update_mask,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )