示例#1
0
文件: auto_ml.py 项目: vipadm/airflow
    def execute(self, context: "Context"):
        """Run an AutoML image training job and return the trained model as a dict.

        Builds an ``AutoMLHook`` from this operator's connection settings and
        stores it on ``self.hook``, starts the image training job with the
        operator's configured parameters, converts the returned model with
        ``Model.to_dict``, persists a ``VertexAIModelLink`` for the model id
        extracted by the hook, and returns the converted dict.
        """
        connection_settings = {
            "gcp_conn_id": self.gcp_conn_id,
            "delegate_to": self.delegate_to,
            "impersonation_chain": self.impersonation_chain,
        }
        self.hook = AutoMLHook(**connection_settings)

        # Every argument is forwarded from an operator attribute; only the
        # dataset is wrapped, turning the stored id into an ImageDataset.
        training_job_args = {
            "project_id": self.project_id,
            "region": self.region,
            "display_name": self.display_name,
            "dataset": datasets.ImageDataset(dataset_name=self.dataset_id),
            "prediction_type": self.prediction_type,
            "multi_label": self.multi_label,
            "model_type": self.model_type,
            "base_model": self.base_model,
            "labels": self.labels,
            "training_encryption_spec_key_name": self.training_encryption_spec_key_name,
            "model_encryption_spec_key_name": self.model_encryption_spec_key_name,
            "training_fraction_split": self.training_fraction_split,
            "validation_fraction_split": self.validation_fraction_split,
            "test_fraction_split": self.test_fraction_split,
            "training_filter_split": self.training_filter_split,
            "validation_filter_split": self.validation_filter_split,
            "test_filter_split": self.test_filter_split,
            "budget_milli_node_hours": self.budget_milli_node_hours,
            "model_display_name": self.model_display_name,
            "model_labels": self.model_labels,
            "disable_early_stopping": self.disable_early_stopping,
            "sync": self.sync,
        }
        trained_model = self.hook.create_auto_ml_image_training_job(**training_job_args)

        model_as_dict = Model.to_dict(trained_model)
        trained_model_id = self.hook.extract_model_id(model_as_dict)
        VertexAIModelLink.persist(
            context=context,
            task_instance=self,
            model_id=trained_model_id,
        )
        return model_as_dict
示例#2
0
    def test_import_data(self, import_data_mock, sync):
        """Check that ``ImageDataset.import_data`` issues the expected import call.

        Imports a single GCS source into an image dataset, waits for the
        dataset when ``sync`` is falsy, and asserts that ``import_data_mock``
        was called exactly once with the dataset name and the matching
        ``ImportDataConfig``.
        """
        aiplatform.init(project=_TEST_PROJECT)
        image_dataset = datasets.ImageDataset(dataset_name=_TEST_NAME)

        image_dataset.import_data(
            gcs_source=[_TEST_SOURCE_URI_GCS],
            import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE,
            sync=sync,
        )
        if not sync:
            # Wait on the dataset before inspecting the mock.
            image_dataset.wait()

        expected_source = gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS])
        expected_config = gca_dataset.ImportDataConfig(
            gcs_source=expected_source,
            import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE,
        )
        import_data_mock.assert_called_once_with(
            name=_TEST_NAME,
            import_configs=[expected_config],
        )
示例#3
0
 def test_init_dataset_non_image(self):
     """Check that constructing an ``ImageDataset`` raises ``ValueError``.

     NOTE(review): presumably a fixture or decorator outside this view makes
     the looked-up dataset a non-image one -- confirm against the full file.
     """
     aiplatform.init(project=_TEST_PROJECT)
     expected_failure = pytest.raises(ValueError)
     with expected_failure:
         datasets.ImageDataset(dataset_name=_TEST_NAME)
示例#4
0
 def test_init_dataset_image(self, get_dataset_image_mock):
     """Check that constructing an ``ImageDataset`` looks the dataset up once by name."""
     aiplatform.init(project=_TEST_PROJECT)
     datasets.ImageDataset(dataset_name=_TEST_NAME)

     expected_lookup = {"name": _TEST_NAME}
     get_dataset_image_mock.assert_called_once_with(**expected_lookup)