Example #1
    def execute(self, context: 'Context'):
        hook = EndpointServiceHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )

        self.log.info("Deploying model")
        operation = hook.deploy_model(
            project_id=self.project_id,
            region=self.region,
            endpoint=self.endpoint_id,
            deployed_model=self.deployed_model,
            traffic_split=self.traffic_split,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
        result = hook.wait_for_operation(timeout=self.timeout,
                                         operation=operation)

        deploy_model = endpoint_service.DeployModelResponse.to_dict(result)
        deployed_model_id = hook.extract_deployed_model_id(deploy_model)
        self.log.info("Model was deployed. Deployed Model ID: %s",
                      deployed_model_id)

        self.xcom_push(context,
                       key="deployed_model_id",
                       value=deployed_model_id)
        VertexAIModelLink.persist(context=context,
                                  task_instance=self,
                                  model_id=deployed_model_id)
        return deploy_model
Example #2
    def execute(self, context: "Context"):
        hook = ModelServiceHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        self.log.info("Upload model")
        operation = hook.upload_model(
            project_id=self.project_id,
            region=self.region,
            model=self.model,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
        result = hook.wait_for_operation(timeout=self.timeout,
                                         operation=operation)

        model_resp = model_service.UploadModelResponse.to_dict(result)
        model_id = hook.extract_model_id(model_resp)
        self.log.info("Model was uploaded. Model ID: %s", model_id)

        self.xcom_push(context, key="model_id", value=model_id)
        VertexAIModelLink.persist(context=context,
                                  task_instance=self,
                                  model_id=model_id)
        return model_resp
Example #3
    def execute(self, context: "Context"):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        model = self.hook.create_auto_ml_video_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.VideoDataset(dataset_name=self.dataset_id),
            prediction_type=self.prediction_type,
            model_type=self.model_type,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            test_fraction_split=self.test_fraction_split,
            training_filter_split=self.training_filter_split,
            test_filter_split=self.test_filter_split,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            sync=self.sync,
        )

        result = Model.to_dict(model)
        model_id = self.hook.extract_model_id(result)
        VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        return result
Example #4
    def execute(self, context: "Context"):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        model = self.hook.create_auto_ml_image_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.ImageDataset(dataset_name=self.dataset_id),
            prediction_type=self.prediction_type,
            multi_label=self.multi_label,
            model_type=self.model_type,
            base_model=self.base_model,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            training_filter_split=self.training_filter_split,
            validation_filter_split=self.validation_filter_split,
            test_filter_split=self.test_filter_split,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            disable_early_stopping=self.disable_early_stopping,
            sync=self.sync,
        )

        result = Model.to_dict(model)
        model_id = self.hook.extract_model_id(result)
        VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        return result
Example #5
    def execute(self, context: "Context"):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        model = self.hook.create_auto_ml_forecasting_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.TimeSeriesDataset(dataset_name=self.dataset_id),
            target_column=self.target_column,
            time_column=self.time_column,
            time_series_identifier_column=self.time_series_identifier_column,
            unavailable_at_forecast_columns=self.unavailable_at_forecast_columns,
            available_at_forecast_columns=self.available_at_forecast_columns,
            forecast_horizon=self.forecast_horizon,
            data_granularity_unit=self.data_granularity_unit,
            data_granularity_count=self.data_granularity_count,
            optimization_objective=self.optimization_objective,
            column_specs=self.column_specs,
            column_transformations=self.column_transformations,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            predefined_split_column_name=self.predefined_split_column_name,
            weight_column=self.weight_column,
            time_series_attribute_columns=self.time_series_attribute_columns,
            context_window=self.context_window,
            export_evaluated_data_items=self.export_evaluated_data_items,
            export_evaluated_data_items_bigquery_destination_uri=(
                self.export_evaluated_data_items_bigquery_destination_uri),
            export_evaluated_data_items_override_destination=(
                self.export_evaluated_data_items_override_destination),
            quantiles=self.quantiles,
            validation_options=self.validation_options,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            sync=self.sync,
        )

        result = Model.to_dict(model)
        model_id = self.hook.extract_model_id(result)
        VertexAIModelLink.persist(context=context,
                                  task_instance=self,
                                  model_id=model_id)
        return result
Example #6
    def execute(self, context: "Context"):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        model = self.hook.create_auto_ml_tabular_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.TabularDataset(dataset_name=self.dataset_id),
            target_column=self.target_column,
            optimization_prediction_type=self.optimization_prediction_type,
            optimization_objective=self.optimization_objective,
            column_specs=self.column_specs,
            column_transformations=self.column_transformations,
            optimization_objective_recall_value=self.optimization_objective_recall_value,
            optimization_objective_precision_value=self.optimization_objective_precision_value,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            predefined_split_column_name=self.predefined_split_column_name,
            timestamp_split_column_name=self.timestamp_split_column_name,
            weight_column=self.weight_column,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            disable_early_stopping=self.disable_early_stopping,
            export_evaluated_data_items=self.export_evaluated_data_items,
            export_evaluated_data_items_bigquery_destination_uri=(
                self.export_evaluated_data_items_bigquery_destination_uri),
            export_evaluated_data_items_override_destination=(
                self.export_evaluated_data_items_override_destination),
            sync=self.sync,
        )

        result = Model.to_dict(model)
        model_id = self.hook.extract_model_id(result)
        VertexAIModelLink.persist(context=context,
                                  task_instance=self,
                                  model_id=model_id)
        return result
Example #7
class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create Auto ML Video Training job"""

    template_fields = [
        'region',
        'impersonation_chain',
    ]
    operator_extra_links = (VertexAIModelLink(),)

    def __init__(
        self,
        *,
        dataset_id: str,
        prediction_type: str = "classification",
        model_type: str = "CLOUD",
        training_filter_split: Optional[str] = None,
        test_filter_split: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.dataset_id = dataset_id
        self.prediction_type = prediction_type
        self.model_type = model_type
        self.training_filter_split = training_filter_split
        self.test_filter_split = test_filter_split

    def execute(self, context: "Context"):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        model = self.hook.create_auto_ml_video_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.VideoDataset(dataset_name=self.dataset_id),
            prediction_type=self.prediction_type,
            model_type=self.model_type,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            test_fraction_split=self.test_fraction_split,
            training_filter_split=self.training_filter_split,
            test_filter_split=self.test_filter_split,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            sync=self.sync,
        )

        result = Model.to_dict(model)
        model_id = self.hook.extract_model_id(result)
        VertexAIModelLink.persist(context=context,
                                  task_instance=self,
                                  model_id=model_id)
        return result
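
A minimal usage sketch for the operator above, assuming the google provider exposes it at the import path shown; the project, region, and dataset ID are placeholders, not values from the source.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (  # assumed import path
    CreateAutoMLVideoTrainingJobOperator,
)

with DAG(
    dag_id="example_vertex_ai_auto_ml_video",   # hypothetical DAG
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    create_video_training_job = CreateAutoMLVideoTrainingJobOperator(
        task_id="auto_ml_video_task",
        project_id="my-project",                # placeholder project
        region="us-central1",
        display_name="auto-ml-video-training-job",
        dataset_id="1234567890",                # ID of an existing VideoDataset
        prediction_type="classification",
        model_type="CLOUD",
    )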
Example #8
class DeployModelOperator(BaseOperator):
    """
    Deploys a Model into this Endpoint, creating a DeployedModel within it.

    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
    :param region: Required. The ID of the Google Cloud region that the service belongs to.
    :param endpoint_id:  Required. The name of the Endpoint resource into which to deploy a Model. Format:
        ``projects/{project}/locations/{location}/endpoints/{endpoint}``
    :param deployed_model:  Required. The DeployedModel to be created within the Endpoint. Note that
        [Endpoint.traffic_split][google.cloud.aiplatform.v1.Endpoint.traffic_split] must be updated for
        the DeployedModel to start receiving traffic, either as part of this call, or via
        [EndpointService.UpdateEndpoint][google.cloud.aiplatform.v1.EndpointService.UpdateEndpoint].
    :param traffic_split:  A map from a DeployedModel's ID to the percentage of this Endpoint's traffic
        that should be forwarded to that DeployedModel.

        If this field is non-empty, then the Endpoint's
        [traffic_split][google.cloud.aiplatform.v1.Endpoint.traffic_split] will be overwritten with it. To
        refer to the ID of the just being deployed Model, a "0" should be used, and the actual ID of the
        new DeployedModel will be filled in its place by this method. The traffic percentage values must
        add up to 100.

        If this field is empty, then the Endpoint's
        [traffic_split][google.cloud.aiplatform.v1.Endpoint.traffic_split] is not updated.
    :param retry: Designation of what errors, if any, should be retried.
    :param timeout: The timeout for this request.
    :param metadata: Strings which should be sent along with the request as metadata.
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
        if any. For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    """

    template_fields = ("region", "endpoint_id", "project_id",
                       "impersonation_chain")
    operator_extra_links = (VertexAIModelLink(),)

    def __init__(
        self,
        *,
        region: str,
        project_id: str,
        endpoint_id: str,
        deployed_model: Union[DeployedModel, Dict],
        traffic_split: Optional[Union[Sequence, Dict]] = None,
        retry: Union[Retry, _MethodDefault] = DEFAULT,
        timeout: Optional[float] = None,
        metadata: Sequence[Tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        delegate_to: Optional[str] = None,
        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.region = region
        self.project_id = project_id
        self.endpoint_id = endpoint_id
        self.deployed_model = deployed_model
        self.traffic_split = traffic_split
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.impersonation_chain = impersonation_chain

    def execute(self, context: 'Context'):
        hook = EndpointServiceHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )

        self.log.info("Deploying model")
        operation = hook.deploy_model(
            project_id=self.project_id,
            region=self.region,
            endpoint=self.endpoint_id,
            deployed_model=self.deployed_model,
            traffic_split=self.traffic_split,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
        result = hook.wait_for_operation(timeout=self.timeout,
                                         operation=operation)

        deploy_model = endpoint_service.DeployModelResponse.to_dict(result)
        deployed_model_id = hook.extract_deployed_model_id(deploy_model)
        self.log.info("Model was deployed. Deployed Model ID: %s",
                      deployed_model_id)

        self.xcom_push(context,
                       key="deployed_model_id",
                       value=deployed_model_id)
        VertexAIModelLink.persist(context=context,
                                  task_instance=self,
                                  model_id=deployed_model_id)
        return deploy_model
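
The docstring above describes the traffic_split semantics; a hedged sketch of a DAG task built on this operator follows. The endpoint and model resource names, machine spec, and import path are assumptions for illustration.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.vertex_ai.endpoint_service import (  # assumed import path
    DeployModelOperator,
)

with DAG(
    dag_id="example_vertex_ai_deploy_model",    # hypothetical DAG
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    deploy_model = DeployModelOperator(
        task_id="deploy_model",
        project_id="my-project",                # placeholder project
        region="us-central1",
        endpoint_id="1234567890",               # existing Endpoint ID
        deployed_model={
            # Placeholder Model resource name and machine spec.
            "model": "projects/my-project/locations/us-central1/models/9876543210",
            "display_name": "my-deployed-model",
            "dedicated_resources": {
                "machine_spec": {"machine_type": "n1-standard-2"},
                "min_replica_count": 1,
                "max_replica_count": 1,
            },
        },
        # "0" stands for the model being deployed in this call (see docstring).
        traffic_split={"0": 100},
    )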
Example #9
class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create AutoML Forecasting Training job"""

    template_fields = [
        'region',
        'impersonation_chain',
    ]
    operator_extra_links = (VertexAIModelLink(),)

    def __init__(
        self,
        *,
        dataset_id: str,
        target_column: str,
        time_column: str,
        time_series_identifier_column: str,
        unavailable_at_forecast_columns: List[str],
        available_at_forecast_columns: List[str],
        forecast_horizon: int,
        data_granularity_unit: str,
        data_granularity_count: int,
        optimization_objective: Optional[str] = None,
        column_specs: Optional[Dict[str, str]] = None,
        column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
        validation_fraction_split: Optional[float] = None,
        predefined_split_column_name: Optional[str] = None,
        weight_column: Optional[str] = None,
        time_series_attribute_columns: Optional[List[str]] = None,
        context_window: Optional[int] = None,
        export_evaluated_data_items: bool = False,
        export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
        export_evaluated_data_items_override_destination: bool = False,
        quantiles: Optional[List[float]] = None,
        validation_options: Optional[str] = None,
        budget_milli_node_hours: int = 1000,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.dataset_id = dataset_id
        self.target_column = target_column
        self.time_column = time_column
        self.time_series_identifier_column = time_series_identifier_column
        self.unavailable_at_forecast_columns = unavailable_at_forecast_columns
        self.available_at_forecast_columns = available_at_forecast_columns
        self.forecast_horizon = forecast_horizon
        self.data_granularity_unit = data_granularity_unit
        self.data_granularity_count = data_granularity_count
        self.optimization_objective = optimization_objective
        self.column_specs = column_specs
        self.column_transformations = column_transformations
        self.validation_fraction_split = validation_fraction_split
        self.predefined_split_column_name = predefined_split_column_name
        self.weight_column = weight_column
        self.time_series_attribute_columns = time_series_attribute_columns
        self.context_window = context_window
        self.export_evaluated_data_items = export_evaluated_data_items
        self.export_evaluated_data_items_bigquery_destination_uri = (
            export_evaluated_data_items_bigquery_destination_uri
        )
        self.export_evaluated_data_items_override_destination = (
            export_evaluated_data_items_override_destination
        )
        self.quantiles = quantiles
        self.validation_options = validation_options
        self.budget_milli_node_hours = budget_milli_node_hours

    def execute(self, context: "Context"):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        model = self.hook.create_auto_ml_forecasting_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.TimeSeriesDataset(dataset_name=self.dataset_id),
            target_column=self.target_column,
            time_column=self.time_column,
            time_series_identifier_column=self.time_series_identifier_column,
            unavailable_at_forecast_columns=self.unavailable_at_forecast_columns,
            available_at_forecast_columns=self.available_at_forecast_columns,
            forecast_horizon=self.forecast_horizon,
            data_granularity_unit=self.data_granularity_unit,
            data_granularity_count=self.data_granularity_count,
            optimization_objective=self.optimization_objective,
            column_specs=self.column_specs,
            column_transformations=self.column_transformations,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            predefined_split_column_name=self.predefined_split_column_name,
            weight_column=self.weight_column,
            time_series_attribute_columns=self.time_series_attribute_columns,
            context_window=self.context_window,
            export_evaluated_data_items=self.export_evaluated_data_items,
            export_evaluated_data_items_bigquery_destination_uri=(
                self.export_evaluated_data_items_bigquery_destination_uri
            ),
            export_evaluated_data_items_override_destination=(
                self.export_evaluated_data_items_override_destination
            ),
            quantiles=self.quantiles,
            validation_options=self.validation_options,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            sync=self.sync,
        )

        result = Model.to_dict(model)
        model_id = self.hook.extract_model_id(result)
        VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        return result
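
A hedged sketch of instantiating the forecasting operator; the column names, dataset ID, and import path are illustrative assumptions, and the task is meant to sit inside a DAG like the earlier sketches.

from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (  # assumed import path
    CreateAutoMLForecastingTrainingJobOperator,
)

# Place inside a `with DAG(...)` block; all values below are placeholders.
create_forecasting_training_job = CreateAutoMLForecastingTrainingJobOperator(
    task_id="auto_ml_forecasting_task",
    project_id="my-project",
    region="us-central1",
    display_name="auto-ml-forecasting-training-job",
    dataset_id="1234567890",                    # existing TimeSeriesDataset ID
    target_column="sales",                      # value to forecast
    time_column="date",
    time_series_identifier_column="store_id",
    unavailable_at_forecast_columns=["sales"],  # unknown at forecast time
    available_at_forecast_columns=["date"],     # known at forecast time
    forecast_horizon=30,                        # predict 30 granularity units ahead
    data_granularity_unit="day",
    data_granularity_count=1,
)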
Example #10
class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create Auto ML Tabular Training job"""

    template_fields = [
        'region',
        'impersonation_chain',
    ]
    operator_extra_links = (VertexAIModelLink(),)

    def __init__(
        self,
        *,
        dataset_id: str,
        target_column: str,
        optimization_prediction_type: str,
        optimization_objective: Optional[str] = None,
        column_specs: Optional[Dict[str, str]] = None,
        column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
        optimization_objective_recall_value: Optional[float] = None,
        optimization_objective_precision_value: Optional[float] = None,
        validation_fraction_split: Optional[float] = None,
        predefined_split_column_name: Optional[str] = None,
        timestamp_split_column_name: Optional[str] = None,
        weight_column: Optional[str] = None,
        budget_milli_node_hours: int = 1000,
        disable_early_stopping: bool = False,
        export_evaluated_data_items: bool = False,
        export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
        export_evaluated_data_items_override_destination: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.dataset_id = dataset_id
        self.target_column = target_column
        self.optimization_prediction_type = optimization_prediction_type
        self.optimization_objective = optimization_objective
        self.column_specs = column_specs
        self.column_transformations = column_transformations
        self.optimization_objective_recall_value = optimization_objective_recall_value
        self.optimization_objective_precision_value = optimization_objective_precision_value
        self.validation_fraction_split = validation_fraction_split
        self.predefined_split_column_name = predefined_split_column_name
        self.timestamp_split_column_name = timestamp_split_column_name
        self.weight_column = weight_column
        self.budget_milli_node_hours = budget_milli_node_hours
        self.disable_early_stopping = disable_early_stopping
        self.export_evaluated_data_items = export_evaluated_data_items
        self.export_evaluated_data_items_bigquery_destination_uri = (
            export_evaluated_data_items_bigquery_destination_uri
        )
        self.export_evaluated_data_items_override_destination = (
            export_evaluated_data_items_override_destination
        )

    def execute(self, context: "Context"):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        model = self.hook.create_auto_ml_tabular_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.TabularDataset(dataset_name=self.dataset_id),
            target_column=self.target_column,
            optimization_prediction_type=self.optimization_prediction_type,
            optimization_objective=self.optimization_objective,
            column_specs=self.column_specs,
            column_transformations=self.column_transformations,
            optimization_objective_recall_value=self.optimization_objective_recall_value,
            optimization_objective_precision_value=self.optimization_objective_precision_value,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            predefined_split_column_name=self.predefined_split_column_name,
            timestamp_split_column_name=self.timestamp_split_column_name,
            weight_column=self.weight_column,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            disable_early_stopping=self.disable_early_stopping,
            export_evaluated_data_items=self.export_evaluated_data_items,
            export_evaluated_data_items_bigquery_destination_uri=(
                self.export_evaluated_data_items_bigquery_destination_uri
            ),
            export_evaluated_data_items_override_destination=(
                self.export_evaluated_data_items_override_destination
            ),
            sync=self.sync,
        )

        result = Model.to_dict(model)
        model_id = self.hook.extract_model_id(result)
        VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        return result
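
A hedged sketch for the tabular operator; the dataset ID, target column, and import path are placeholders, and the task belongs inside a DAG as in the earlier sketches.

from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (  # assumed import path
    CreateAutoMLTabularTrainingJobOperator,
)

# Place inside a `with DAG(...)` block; all values below are placeholders.
create_tabular_training_job = CreateAutoMLTabularTrainingJobOperator(
    task_id="auto_ml_tabular_task",
    project_id="my-project",
    region="us-central1",
    display_name="auto-ml-tabular-training-job",
    dataset_id="1234567890",                    # existing TabularDataset ID
    target_column="churned",                    # column to predict
    optimization_prediction_type="classification",
    budget_milli_node_hours=1000,               # 1 node hour
)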
Example #11
class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
    """Create Auto ML Image Training job"""

    template_fields = [
        'region',
        'impersonation_chain',
    ]
    operator_extra_links = (VertexAIModelLink(),)

    def __init__(
        self,
        *,
        dataset_id: str,
        prediction_type: str = "classification",
        multi_label: bool = False,
        model_type: str = "CLOUD",
        base_model: Optional[Model] = None,
        validation_fraction_split: Optional[float] = None,
        training_filter_split: Optional[str] = None,
        validation_filter_split: Optional[str] = None,
        test_filter_split: Optional[str] = None,
        budget_milli_node_hours: Optional[int] = None,
        disable_early_stopping: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.dataset_id = dataset_id
        self.prediction_type = prediction_type
        self.multi_label = multi_label
        self.model_type = model_type
        self.base_model = base_model
        self.validation_fraction_split = validation_fraction_split
        self.training_filter_split = training_filter_split
        self.validation_filter_split = validation_filter_split
        self.test_filter_split = test_filter_split
        self.budget_milli_node_hours = budget_milli_node_hours
        self.disable_early_stopping = disable_early_stopping

    def execute(self, context: "Context"):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        model = self.hook.create_auto_ml_image_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.ImageDataset(dataset_name=self.dataset_id),
            prediction_type=self.prediction_type,
            multi_label=self.multi_label,
            model_type=self.model_type,
            base_model=self.base_model,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            training_filter_split=self.training_filter_split,
            validation_filter_split=self.validation_filter_split,
            test_filter_split=self.test_filter_split,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            disable_early_stopping=self.disable_early_stopping,
            sync=self.sync,
        )

        result = Model.to_dict(model)
        model_id = self.hook.extract_model_id(result)
        VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
        return result
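
A hedged sketch for the image operator; the dataset ID, budget, and import path are illustrative assumptions, and the task belongs inside a DAG as in the earlier sketches.

from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (  # assumed import path
    CreateAutoMLImageTrainingJobOperator,
)

# Place inside a `with DAG(...)` block; all values below are placeholders.
create_image_training_job = CreateAutoMLImageTrainingJobOperator(
    task_id="auto_ml_image_task",
    project_id="my-project",
    region="us-central1",
    display_name="auto-ml-image-training-job",
    dataset_id="1234567890",                    # existing ImageDataset ID
    prediction_type="classification",
    multi_label=False,
    model_type="CLOUD",
    budget_milli_node_hours=8000,               # illustrative training budget
)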
Example #12
class UploadModelOperator(BaseOperator):
    """
    Uploads a Model artifact into Vertex AI.

    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
    :param region: Required. The ID of the Google Cloud region that the service belongs to.
    :param model:  Required. The Model to create.
    :param retry: Designation of what errors, if any, should be retried.
    :param timeout: The timeout for this request.
    :param metadata: Strings which should be sent along with the request as metadata.
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
        if any. For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    """

    template_fields = ("region", "project_id", "impersonation_chain")
    operator_extra_links = (VertexAIModelLink(),)

    def __init__(
        self,
        *,
        project_id: str,
        region: str,
        model: Union[Model, Dict],
        retry: Optional[Retry] = None,
        timeout: Optional[float] = None,
        metadata: Sequence[Tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        delegate_to: Optional[str] = None,
        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.region = region
        self.model = model
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.impersonation_chain = impersonation_chain

    def execute(self, context: "Context"):
        hook = ModelServiceHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        self.log.info("Upload model")
        operation = hook.upload_model(
            project_id=self.project_id,
            region=self.region,
            model=self.model,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
        result = hook.wait_for_operation(timeout=self.timeout,
                                         operation=operation)

        model_resp = model_service.UploadModelResponse.to_dict(result)
        model_id = hook.extract_model_id(model_resp)
        self.log.info("Model was uploaded. Model ID: %s", model_id)

        self.xcom_push(context, key="model_id", value=model_id)
        VertexAIModelLink.persist(context=context,
                                  task_instance=self,
                                  model_id=model_id)
        return model_resp
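
A hedged sketch of uploading a model and then referencing the pushed "model_id" XCom downstream; the artifact URI, serving container image, and import path are placeholders.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.vertex_ai.model_service import (  # assumed import path
    UploadModelOperator,
)

with DAG(
    dag_id="example_vertex_ai_upload_model",    # hypothetical DAG
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    upload_model = UploadModelOperator(
        task_id="upload_model",
        project_id="my-project",                # placeholder project
        region="us-central1",
        model={
            "display_name": "my-model",
            "artifact_uri": "gs://my-bucket/model/",      # placeholder GCS path
            "container_spec": {
                # Placeholder serving container image.
                "image_uri": "us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest",
            },
        },
    )
    # The execute() above pushes XCom key "model_id"; downstream tasks (e.g. a
    # DeployModelOperator) can template it:
    # "{{ task_instance.xcom_pull(task_ids='upload_model', key='model_id') }}"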