def setup_teardown(
        shared_state: dict
) -> Tuple[storage.bucket.Bucket, aip.JobServiceClient]:
    """Provision GCS / Vertex AI resources for a test, then clean them up.

    Setup creates the staging bucket, uploads the training data and a
    freshly built trainer tarball, and constructs a regional Vertex AI
    ``JobServiceClient``. Everything after the ``yield`` is teardown.

    NOTE(review): this function is a generator (it yields once), so the
    return annotation describes the yielded value, not the actual return
    type; presumably it is registered as a pytest fixture outside this view.

    Args:
        shared_state: Mutable dict shared with the test. Teardown reads
            ``shared_state["model_name"]`` and passes it as ``name`` to
            ``delete_custom_job``; if the key was never set, that step is
            skipped instead of raising ``KeyError``.

    Yields:
        A ``(bucket, aip_job_client)`` tuple.
    """
    storage_client = storage.Client()
    bucket = storage_client.create_bucket(STAGING_BUCKET, location=REGION)
    bucket.blob(f"{INPUT_DIR}/{TRAIN_DATA}").upload_from_filename(TRAIN_DATA,
                                                                  timeout=600)

    # Mode "x:gz" refuses to overwrite an existing file, so the local tarball
    # must be removed in teardown or the next run fails with FileExistsError.
    with tarfile.open(TRAINER_TAR, mode="x:gz") as tar:
        tar.add(f"{TRAINER_DIR}/")

    bucket.blob(TRAINER_TAR).upload_from_filename(TRAINER_TAR)

    aip_job_client = aip.JobServiceClient(
        client_options={"api_endpoint": f"{REGION}-aiplatform.googleapis.com"})

    yield bucket, aip_job_client

    # Teardown: the nested try/finally guarantees every cleanup step is
    # attempted even if an earlier one raises, so a single failure does not
    # leak the remaining resources (the original exception still propagates).
    try:
        try:
            bucket.delete(force=True)
        except NotFound:
            print("Bucket not found.")
    finally:
        try:
            os.remove(TRAINER_TAR)
        finally:
            model_name = shared_state.get("model_name")
            if model_name is None:
                # The test never recorded a job name, so there is nothing
                # to delete; don't mask the test's own failure with KeyError.
                print("No model_name in shared_state; skipping job deletion.")
            else:
                # .result() blocks on the operation returned by the delete call.
                aip_job_client.delete_custom_job(name=model_name).result()
Ejemplo n.º 2
0
    def create_client(self) -> None:
        """Build (or rebuild) the Gapic job client and store it on ``self``.

        Calling this again replaces ``self._client`` with a fresh client,
        which is useful for recovering after e.g. a communication failure.

        Each instance holds per-job state (such as ``job_id``), so an
        instance should drive exactly one job. To run several job requests
        in parallel, create one instance per job.
        """
        endpoint = self._region + _UCAIP_ENDPOINT_SUFFIX
        options = {"api_endpoint": endpoint}
        self._client = gapic.JobServiceClient(client_options=options)
Ejemplo n.º 3
0
 def _get_vertexai_job_client(self, location):
   """Create a Vertex AI JobServiceClient for a regional endpoint.

   Args:
     location: Region string used as the prefix of the
       ``<location>-aiplatform.googleapis.com`` API endpoint.

   Returns:
     A new ``aip.JobServiceClient`` whose ``client_options`` point at that
     endpoint.
   """
   endpoint = f'{location}-aiplatform.googleapis.com'
   return aip.JobServiceClient(client_options=dict(api_endpoint=endpoint))