def execute(self, context: "Context"):
    # Build the hook with the operator's connection / impersonation settings.
    self.hook = AutoMLHook(
        gcp_conn_id=self.gcp_conn_id,
        delegate_to=self.delegate_to,
        impersonation_chain=self.impersonation_chain,
    )
    # Launch the AutoML video training job against the referenced VideoDataset.
    model = self.hook.create_auto_ml_video_training_job(
        project_id=self.project_id,
        region=self.region,
        display_name=self.display_name,
        dataset=datasets.VideoDataset(dataset_name=self.dataset_id),
        prediction_type=self.prediction_type,
        model_type=self.model_type,
        labels=self.labels,
        training_encryption_spec_key_name=self.training_encryption_spec_key_name,
        model_encryption_spec_key_name=self.model_encryption_spec_key_name,
        training_fraction_split=self.training_fraction_split,
        test_fraction_split=self.test_fraction_split,
        training_filter_split=self.training_filter_split,
        test_filter_split=self.test_filter_split,
        model_display_name=self.model_display_name,
        model_labels=self.model_labels,
        sync=self.sync,
    )
    # Serialize the trained model, expose it via the Vertex AI model link,
    # and return the dict so it is pushed to XCom.
    result = Model.to_dict(model)
    model_id = self.hook.extract_model_id(result)
    VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
    return result
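# A minimal sketch of how an operator with the execute() above might be wired into a
# DAG, assuming it is Airflow's CreateAutoMLVideoTrainingJobOperator from the Google
# provider; the DAG id, task id, and the PROJECT_ID / REGION / DATASET_ID constants
# below are placeholders, not values taken from this file.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
    CreateAutoMLVideoTrainingJobOperator,
)

PROJECT_ID = "my-project"  # placeholder
REGION = "us-central1"  # placeholder
DATASET_ID = "1234567890"  # placeholder Vertex AI video dataset id

with DAG(
    dag_id="example_vertex_ai_auto_ml_video",
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
) as dag:
    create_auto_ml_video_training_job = CreateAutoMLVideoTrainingJobOperator(
        task_id="auto_ml_video_task",
        display_name="auto-ml-video-training-job",
        prediction_type="classification",
        model_type="CLOUD",
        dataset_id=DATASET_ID,
        region=REGION,
        project_id=PROJECT_ID,
    )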
def test_import_data(self, import_data_mock, sync):
    aiplatform.init(project=_TEST_PROJECT)

    my_dataset = datasets.VideoDataset(dataset_name=_TEST_NAME)

    my_dataset.import_data(
        gcs_source=[_TEST_SOURCE_URI_GCS],
        import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO,
        sync=sync,
    )

    if not sync:
        my_dataset.wait()

    expected_import_config = gca_dataset.ImportDataConfig(
        gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]),
        import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO,
    )

    import_data_mock.assert_called_once_with(
        name=_TEST_NAME, import_configs=[expected_import_config]
    )
def test_init_dataset_non_video(self):
    aiplatform.init(project=_TEST_PROJECT)
    # Constructing a VideoDataset from a non-video dataset resource must fail.
    with pytest.raises(ValueError):
        datasets.VideoDataset(dataset_name=_TEST_NAME)
def test_init_dataset_video(self, get_dataset_video_mock):
    aiplatform.init(project=_TEST_PROJECT)
    datasets.VideoDataset(dataset_name=_TEST_NAME)
    get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME)
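# The tests above exercise datasets.VideoDataset against mocked API clients. A minimal
# sketch of the equivalent real calls with the google-cloud-aiplatform SDK; the project
# id, GCS URI, and display name below are placeholders, not values from this file.
from google.cloud import aiplatform
from google.cloud.aiplatform import datasets

aiplatform.init(project="my-project", location="us-central1")

# Create an empty video dataset, then import labeled clips from GCS using the
# video classification import schema.
my_dataset = datasets.VideoDataset.create(display_name="my-video-dataset")
my_dataset.import_data(
    gcs_source=["gs://my-bucket/video_import_file.csv"],
    import_schema_uri=aiplatform.schema.dataset.ioformat.video.classification,
    sync=True,  # block until the import long-running operation completes
)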