Example 1
    def test_create_dataset_nontabular(self, create_dataset_mock, sync):
        aiplatform.init(project=_TEST_PROJECT)

        my_dataset = datasets._Dataset.create(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
            sync=sync,
        )

        if not sync:
            my_dataset.wait()

        expected_dataset = gca_dataset.Dataset(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR,
            metadata=_TEST_NONTABULAR_DATASET_METADATA,
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )

        create_dataset_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            dataset=expected_dataset,
            metadata=_TEST_REQUEST_METADATA,
        )
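
The create/wait pattern above mirrors normal SDK usage outside of tests. A minimal sketch (project, display name, and BigQuery table are placeholder values):

from google.cloud import aiplatform

aiplatform.init(project="my-project")
ds = aiplatform.TabularDataset.create(
    display_name="my-dataset",
    bq_source="bq://my-project.my_dataset.my_table",
    sync=False,  # return immediately; creation continues in the background
)
ds.wait()  # block until the long-running create operation finishes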
Example 2
    def test_create_dataset_with_default_encryption_key(
        self, create_dataset_mock, sync
    ):
        aiplatform.init(
            project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
        )

        my_dataset = datasets.TabularDataset.create(
            display_name=_TEST_DISPLAY_NAME, bq_source=_TEST_SOURCE_URI_BQ, sync=sync,
        )

        if not sync:
            my_dataset.wait()

        expected_dataset = gca_dataset.Dataset(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR,
            metadata=_TEST_METADATA_TABULAR_BQ,
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )

        create_dataset_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            dataset=expected_dataset,
            metadata=_TEST_REQUEST_METADATA,
        )
Example 3
def get_dataset_without_name_mock():
    with patch.object(dataset_service_client.DatasetServiceClient,
                      "get_dataset") as get_dataset_mock:
        get_dataset_mock.return_value = gca_dataset.Dataset(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR,
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )
        yield get_dataset_mock
Example 4
def get_dataset_tabular_missing_datasource_mock():
    with patch.object(dataset_service_client.DatasetServiceClient,
                      "get_dataset") as get_dataset_mock:
        get_dataset_mock.return_value = gca_dataset.Dataset(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR,
            metadata={"inputConfig": {}},
            name=_TEST_NAME,
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )
        yield get_dataset_mock
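
The yield-based generators above are evidently pytest fixtures (the @pytest.fixture decorators were dropped when the snippets were extracted); tests receive the patched client by naming the fixture as a parameter. A sketch assuming pytest, with a hypothetical test body:

import pytest

@pytest.fixture
def get_dataset_without_name_mock():
    with patch.object(dataset_service_client.DatasetServiceClient,
                      "get_dataset") as get_dataset_mock:
        get_dataset_mock.return_value = gca_dataset.Dataset(
            display_name=_TEST_DISPLAY_NAME,
        )
        yield get_dataset_mock

def test_init_dataset(get_dataset_without_name_mock):
    # instantiating the dataset wrapper fetches the resource through the
    # patched client, which can then be asserted against
    datasets._Dataset(dataset_name=_TEST_NAME)
    get_dataset_without_name_mock.assert_called_once_with(name=_TEST_NAME)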
Example 5
def create_dataset_mock():
    with patch.object(dataset_service_client.DatasetServiceClient,
                      "create_dataset") as create_dataset_mock:
        create_dataset_lro_mock = mock.Mock(operation.Operation)
        create_dataset_lro_mock.result.return_value = gca_dataset.Dataset(
            name=_TEST_NAME,
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT,
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )
        create_dataset_mock.return_value = create_dataset_lro_mock
        yield create_dataset_mock
Example 6
def mock_dataset_nontimeseries():
    ds = mock.MagicMock(datasets.ImageDataset)
    ds.name = _TEST_DATASET_NAME
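    # _latest_future / _exception are internal attributes used by the SDK's
    # async bookkeeping; None means no pending work and no stored failure.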
    ds._latest_future = None
    ds._exception = None
    ds._gca_resource = gca_dataset.Dataset(
        display_name=_TEST_DATASET_DISPLAY_NAME,
        metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTIMESERIES,
        labels={},
        name=_TEST_DATASET_NAME,
        metadata={},
    )
    return ds
Example 7
def mock_dataset_text():
    ds = mock.MagicMock(datasets.TextDataset)
    ds.name = _TEST_DATASET_NAME
    ds._latest_future = None
    ds._exception = None
    ds._gca_resource = gca_dataset.Dataset(
        display_name=_TEST_DATASET_DISPLAY_NAME,
        metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT,
        labels={},
        name=_TEST_DATASET_NAME,
        metadata={},
    )
    return ds
Example 8
    def test_create_then_import(
        self, create_dataset_mock, import_data_mock, get_dataset_mock, sync
    ):

        aiplatform.init(project=_TEST_PROJECT)

        my_dataset = datasets._Dataset.create(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
            sync=sync,
        )

        my_dataset.import_data(
            gcs_source=_TEST_SOURCE_URI_GCS,
            import_schema_uri=_TEST_IMPORT_SCHEMA_URI,
            data_item_labels=_TEST_DATA_LABEL_ITEMS,
            sync=sync,
        )

        if not sync:
            my_dataset.wait()

        expected_dataset = gca_dataset.Dataset(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR,
            metadata=_TEST_NONTABULAR_DATASET_METADATA,
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )

        expected_import_config = gca_dataset.ImportDataConfig(
            gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]),
            import_schema_uri=_TEST_IMPORT_SCHEMA_URI,
            data_item_labels=_TEST_DATA_LABEL_ITEMS,
        )

        create_dataset_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            dataset=expected_dataset,
            metadata=_TEST_REQUEST_METADATA,
        )

        get_dataset_mock.assert_called_once_with(name=_TEST_NAME)

        import_data_mock.assert_called_once_with(
            name=_TEST_NAME, import_configs=[expected_import_config]
        )

        expected_dataset.name = _TEST_NAME
        assert my_dataset._gca_resource == expected_dataset
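
Note that the single GCS URI string passed to import_data is normalized into a GcsSource holding a list of URIs; the gcs_source parameter also accepts a list directly. A minimal sketch (bucket paths are placeholders):

my_dataset.import_data(
    gcs_source=["gs://my-bucket/a.jsonl", "gs://my-bucket/b.jsonl"],
    import_schema_uri=_TEST_IMPORT_SCHEMA_URI,
)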
Example 9
def mock_dataset_tabular():
    ds = mock.MagicMock(datasets.TabularDataset)
    ds.name = _TEST_DATASET_NAME
    ds._latest_future = None
    ds._exception = None
    ds._gca_resource = gca_dataset.Dataset(
        display_name=_TEST_DATASET_DISPLAY_NAME,
        metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR,
        labels={},
        name=_TEST_DATASET_NAME,
        metadata={},
    )
    ds.column_names = _TEST_TRAINING_COLUMN_NAMES

    yield ds
Example 10
_TEST_METADATA_TABULAR_BQ = {
    "inputConfig": {"bigquerySource": {"uri": _TEST_SOURCE_URI_BQ}}
}

# CMEK encryption
_TEST_ENCRYPTION_KEY_NAME = "key_1234"
_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec(
    kms_key_name=_TEST_ENCRYPTION_KEY_NAME
)

# misc
_TEST_OUTPUT_DIR = "gs://my-output-bucket"

_TEST_DATASET_LIST = [
    gca_dataset.Dataset(
        display_name="a", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR
    ),
    gca_dataset.Dataset(
        display_name="d", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR
    ),
    gca_dataset.Dataset(
        display_name="b", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR
    ),
    gca_dataset.Dataset(
        display_name="e", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT
    ),
    gca_dataset.Dataset(
        display_name="c", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR
    ),
]
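
A companion fixture could serve this list from DatasetServiceClient.list_datasets using the same patch pattern as the fixtures above (a hypothetical sketch; list_datasets_mock does not appear in the original snippets):

def list_datasets_mock():
    with patch.object(dataset_service_client.DatasetServiceClient,
                      "list_datasets") as list_datasets_mock:
        list_datasets_mock.return_value = _TEST_DATASET_LIST
        yield list_datasets_mock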
Example 11
    def test_dataset_create_to_model_predict(
        self,
        create_dataset_mock,  # noqa: F811
        import_data_mock,  # noqa: F811
        predict_client_predict_mock,  # noqa: F811
        mock_python_package_to_gcs,  # noqa: F811
        mock_pipeline_service_create,  # noqa: F811
        mock_model_service_get,  # noqa: F811
        mock_pipeline_service_get,  # noqa: F811
        sync,
    ):
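        # End-to-end smoke test: create a dataset, import data, train a
        # custom model, deploy it to an endpoint, and run a prediction,
        # with every service call mocked.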

        aiplatform.init(
            project=test_datasets._TEST_PROJECT,
            staging_bucket=test_training_jobs._TEST_BUCKET_NAME,
            credentials=test_training_jobs._TEST_CREDENTIALS,
        )

        my_dataset = aiplatform.ImageDataset.create(
            display_name=test_datasets._TEST_DISPLAY_NAME,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
            sync=sync,
        )

        my_dataset.import_data(
            gcs_source=test_datasets._TEST_SOURCE_URI_GCS,
            import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI,
            data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS,
            sync=sync,
        )

        job = aiplatform.CustomTrainingJob(
            display_name=test_training_jobs._TEST_DISPLAY_NAME,
            script_path=test_training_jobs._TEST_LOCAL_SCRIPT_FILE_NAME,
            container_uri=test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE,
            model_serving_container_image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE,
            model_serving_container_predict_route=test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
            model_serving_container_health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        )

        model_from_job = job.run(
            dataset=my_dataset,
            base_output_dir=test_training_jobs._TEST_BASE_OUTPUT_DIR,
            args=test_training_jobs._TEST_RUN_ARGS,
            replica_count=1,
            machine_type=test_training_jobs._TEST_MACHINE_TYPE,
            accelerator_type=test_training_jobs._TEST_ACCELERATOR_TYPE,
            accelerator_count=test_training_jobs._TEST_ACCELERATOR_COUNT,
            model_display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME,
            training_fraction_split=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT,
            validation_fraction_split=test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT,
            test_fraction_split=test_training_jobs._TEST_TEST_FRACTION_SPLIT,
            sync=sync,
        )

        created_endpoint = models.Endpoint.create(
            display_name=test_endpoints._TEST_DISPLAY_NAME,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
            sync=sync,
        )

        my_endpoint = model_from_job.deploy(
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, sync=sync)

        endpoint_deploy_return = created_endpoint.deploy(model_from_job,
                                                         sync=sync)

        assert endpoint_deploy_return is None

        if not sync:
            my_endpoint.wait()
            created_endpoint.wait()

        test_prediction = created_endpoint.predict(instances=[[1.0, 2.0, 3.0],
                                                              [1.0, 3.0, 4.0]],
                                                   parameters={"param": 3.0})

        true_prediction = models.Prediction(
            predictions=test_endpoints._TEST_PREDICTION,
            deployed_model_id=test_endpoints._TEST_ID,
        )

        assert true_prediction == test_prediction
        predict_client_predict_mock.assert_called_once_with(
            endpoint=test_endpoints._TEST_ENDPOINT_NAME,
            instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]],
            parameters={"param": 3.0},
        )

        expected_dataset = gca_dataset.Dataset(
            display_name=test_datasets._TEST_DISPLAY_NAME,
            metadata_schema_uri=test_datasets._TEST_METADATA_SCHEMA_URI_NONTABULAR,
            metadata=test_datasets._TEST_NONTABULAR_DATASET_METADATA,
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )

        expected_import_config = gca_dataset.ImportDataConfig(
            gcs_source=gca_io.GcsSource(
                uris=[test_datasets._TEST_SOURCE_URI_GCS]),
            import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI,
            data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS,
        )

        create_dataset_mock.assert_called_once_with(
            parent=test_datasets._TEST_PARENT,
            dataset=expected_dataset,
            metadata=test_datasets._TEST_REQUEST_METADATA,
        )

        import_data_mock.assert_called_once_with(
            name=test_datasets._TEST_NAME,
            import_configs=[expected_import_config])

        expected_dataset.name = test_datasets._TEST_NAME
        assert my_dataset._gca_resource == expected_dataset

        mock_python_package_to_gcs.assert_called_once_with(
            gcs_staging_dir=test_training_jobs._TEST_BUCKET_NAME,
            project=test_training_jobs._TEST_PROJECT,
            credentials=initializer.global_config.credentials,
        )

        true_args = test_training_jobs._TEST_RUN_ARGS

        true_worker_pool_spec = {
            "replica_count": test_training_jobs._TEST_REPLICA_COUNT,
            "machine_spec": {
                "machine_type": test_training_jobs._TEST_MACHINE_TYPE,
                "accelerator_type": test_training_jobs._TEST_ACCELERATOR_TYPE,
                "accelerator_count": test_training_jobs._TEST_ACCELERATOR_COUNT,
            },
            "python_package_spec": {
                "executor_image_uri": test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE,
                "python_module": source_utils._TrainingScriptPythonPackager.module_name,
                "package_uris": [test_training_jobs._TEST_OUTPUT_PYTHON_PACKAGE_PATH],
                "args": true_args,
            },
        }

        true_fraction_split = gca_training_pipeline.FractionSplit(
            training_fraction=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT,
            validation_fraction=test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT,
            test_fraction=test_training_jobs._TEST_TEST_FRACTION_SPLIT,
        )

        true_container_spec = gca_model.ModelContainerSpec(
            image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE,
            predict_route=test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
            health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        )

        true_managed_model = gca_model.Model(
            display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME,
            container_spec=true_container_spec,
        )

        true_input_data_config = gca_training_pipeline.InputDataConfig(
            fraction_split=true_fraction_split,
            dataset_id=my_dataset.name,
            gcs_destination=gca_io.GcsDestination(
                output_uri_prefix=test_training_jobs._TEST_BASE_OUTPUT_DIR
            ),
        )

        true_training_pipeline = gca_training_pipeline.TrainingPipeline(
            display_name=test_training_jobs._TEST_DISPLAY_NAME,
            training_task_definition=schema.training_job.definition.custom_task,
            training_task_inputs=json_format.ParseDict(
                {
                    "worker_pool_specs": [true_worker_pool_spec],
                    "base_output_directory": {
                        "output_uri_prefix": test_training_jobs._TEST_BASE_OUTPUT_DIR
                    },
                },
                struct_pb2.Value(),
            ),
            model_to_upload=true_managed_model,
            input_data_config=true_input_data_config,
        )

        mock_pipeline_service_create.assert_called_once_with(
            parent=initializer.global_config.common_location_path(),
            training_pipeline=true_training_pipeline,
        )

        assert job._gca_resource is mock_pipeline_service_get.return_value

        mock_model_service_get.assert_called_once_with(
            name=test_training_jobs._TEST_MODEL_NAME)

        assert model_from_job._gca_resource is mock_model_service_get.return_value

        assert job.get_model()._gca_resource is mock_model_service_get.return_value

        assert not job.has_failed

        assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED