def test_tabular_dataset_column_name_bq_with_creds(self, bq_client_mock):
        creds = auth_credentials.AnonymousCredentials()
        my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME,
                                             credentials=creds)

        my_dataset.column_names  # property access is enough to exercise the BigQuery client

        assert bq_client_mock.call_args_list[0] == mock.call(
            project=_TEST_PROJECT, credentials=creds)
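For orientation, here is a minimal sketch of how a BigQuery-backed TabularDataset is created and its column_names read; the project, location, and table values are placeholders, not taken from the tests above.

from google.cloud import aiplatform
from google.cloud.aiplatform import datasets

# Placeholder project/location/table; substitute your own values.
aiplatform.init(project="my-project", location="us-central1")

# column_names is resolved lazily from the BigQuery table schema, which is
# why the test above only needs to touch the property to exercise the
# BigQuery client.
ds = datasets.TabularDataset.create(
    display_name="my-tabular-dataset",
    bq_source="bq://my-project.my_dataset.my_table",
)
print(ds.column_names)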
Example #2
    def test_delete_dataset(self, delete_dataset_mock, sync):
        aiplatform.init(project=_TEST_PROJECT)

        my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)
        my_dataset.delete(sync=sync)

        if not sync:
            my_dataset.wait()

        delete_dataset_mock.assert_called_once_with(name=my_dataset.resource_name)
Example #3
    def test_init_dataset_with_alt_location(self,
                                            get_dataset_tabular_gcs_mock):
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_ALT_LOCATION)

        ds = datasets.TabularDataset(dataset_name=_TEST_NAME)

        # The location embedded in the dataset's resource name overrides the
        # alternate location passed to aiplatform.init() above.
        assert (
            ds.api_client._clients[compat.DEFAULT_VERSION]._client_options.api_endpoint
            == f"{_TEST_LOCATION}-{aiplatform.constants.API_BASE_PATH}"
        )

        assert _TEST_ALT_LOCATION != _TEST_LOCATION

        get_dataset_tabular_gcs_mock.assert_called_once_with(name=_TEST_NAME)
Example #4
    def test_tabular_dataset_column_name_gcs_with_creds(self, gcs_client_mock):
        creds = auth_credentials.AnonymousCredentials()
        my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME,
                                             credentials=creds)

        # We are only testing that the credentials are passed through; the
        # StopIteration comes from the mock not returning the CSV data,
        # which is tested above.
        try:
            my_dataset.column_names
        except StopIteration:
            pass

        gcs_client_mock.assert_called_once_with(project=_TEST_PROJECT,
                                                credentials=creds)
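The GCS-backed variant follows the same pattern; a minimal sketch, assuming a headered CSV in Cloud Storage (bucket and object names are placeholders).

from google.cloud.aiplatform import datasets

# For GCS sources, column_names is read from the CSV header row; the mocked
# test above raises StopIteration precisely because the mock yields no rows.
ds = datasets.TabularDataset.create(
    display_name="my-gcs-tabular-dataset",
    gcs_source=["gs://my-bucket/data/train.csv"],
)
print(ds.column_names)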
Example #5
    def execute(self, context: "Context"):
        self.hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        model = self.hook.create_auto_ml_tabular_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=datasets.TabularDataset(dataset_name=self.dataset_id),
            target_column=self.target_column,
            optimization_prediction_type=self.optimization_prediction_type,
            optimization_objective=self.optimization_objective,
            column_specs=self.column_specs,
            column_transformations=self.column_transformations,
            optimization_objective_recall_value=self.optimization_objective_recall_value,
            optimization_objective_precision_value=self.optimization_objective_precision_value,
            labels=self.labels,
            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
            training_fraction_split=self.training_fraction_split,
            validation_fraction_split=self.validation_fraction_split,
            test_fraction_split=self.test_fraction_split,
            predefined_split_column_name=self.predefined_split_column_name,
            timestamp_split_column_name=self.timestamp_split_column_name,
            weight_column=self.weight_column,
            budget_milli_node_hours=self.budget_milli_node_hours,
            model_display_name=self.model_display_name,
            model_labels=self.model_labels,
            disable_early_stopping=self.disable_early_stopping,
            export_evaluated_data_items=self.export_evaluated_data_items,
            export_evaluated_data_items_bigquery_destination_uri=(
                self.export_evaluated_data_items_bigquery_destination_uri),
            export_evaluated_data_items_override_destination=(
                self.export_evaluated_data_items_override_destination),
            sync=self.sync,
        )

        result = Model.to_dict(model)
        model_id = self.hook.extract_model_id(result)
        VertexAIModelLink.persist(context=context,
                                  task_instance=self,
                                  model_id=model_id)
        return result
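For context, a hedged sketch of how the Google provider's CreateAutoMLTabularTrainingJobOperator, whose execute() is shown above, might be wired into a DAG; it assumes Airflow 2.4+ and uses placeholder project, region, and dataset values.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
    CreateAutoMLTabularTrainingJobOperator,
)

with DAG(
    dag_id="vertex_ai_automl_tabular",  # placeholder DAG id
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    create_training_job = CreateAutoMLTabularTrainingJobOperator(
        task_id="create_automl_tabular_training_job",
        project_id="my-project",  # placeholder
        region="us-central1",  # placeholder
        display_name="my-training-job",
        dataset_id="1234567890",  # numeric Vertex AI dataset ID (placeholder)
        target_column="label",
        optimization_prediction_type="classification",
    )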
Example #6
    def test_tabular_dataset_column_name_bigquery(self):
        my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)

        assert my_dataset.column_names == ["column_1", "column_2"]
Example #7
    def test_tabular_dataset_column_name_missing_datasource(self):
        my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)

        with pytest.raises(RuntimeError):
            my_dataset.column_names
Example #8
    def test_no_import_data_method(self):

        my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)

        with pytest.raises(NotImplementedError):
            my_dataset.import_data()
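By contrast, non-tabular dataset classes do expose import_data, since their items are ingested through an import job rather than read in place from BigQuery or GCS. A minimal sketch with ImageDataset; the GCS path is a placeholder.

from google.cloud import aiplatform
from google.cloud.aiplatform import datasets

# TabularDataset raises NotImplementedError because tabular data stays in
# its BigQuery/GCS source; image data, by contrast, is imported.
ds = datasets.ImageDataset.create(display_name="my-image-dataset")
ds.import_data(
    gcs_source=["gs://my-bucket/labels.jsonl"],  # placeholder import file
    import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,
)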
Example #9
    def test_init_dataset_non_tabular(self):

        with pytest.raises(ValueError):
            datasets.TabularDataset(dataset_name=_TEST_NAME)
Example #10
    def test_init_dataset_tabular(self, get_dataset_tabular_bq_mock):

        datasets.TabularDataset(dataset_name=_TEST_NAME)
        get_dataset_tabular_bq_mock.assert_called_once_with(name=_TEST_NAME)
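For readers reproducing these tests, a speculative sketch of what a fixture such as get_dataset_tabular_bq_mock could look like; the patched target and the metadata schema URI are assumptions based on the public aiplatform_v1 client, not copied from the real test suite.

from unittest import mock

import pytest
from google.cloud.aiplatform_v1.services.dataset_service.client import (
    DatasetServiceClient,
)
from google.cloud.aiplatform_v1.types import dataset as gca_dataset


@pytest.fixture
def get_dataset_tabular_bq_mock():
    # Patch get_dataset so TabularDataset(dataset_name=...) resolves to a
    # canned tabular dataset instead of calling the live API.
    with mock.patch.object(DatasetServiceClient, "get_dataset") as get_mock:
        get_mock.return_value = gca_dataset.Dataset(
            name=_TEST_NAME,  # assumed to be defined at module scope, as above
            metadata_schema_uri=(
                "gs://google-cloud-aiplatform/schema/dataset/metadata/tabular_1.0.0.yaml"
            ),
        )
        yield get_mock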