    def submit_pyspark_job(
        self, main_python_file_uri: str, python_file_uris: List[str]
    ):
        print("submit pyspark job started.")
        job_details = {
            "placement": {"cluster_name": self.cluster_name},
            "pyspark_job": {
                "main_python_file_uri": main_python_file_uri,
                "python_file_uris": python_file_uris,
            },
        }
        job_transport: JobControllerGrpcTransport = JobControllerGrpcTransport(
            address=f"{self.region}-dataproc.googleapis.com:443",
            credentials=self.dataproc_credentials,
        )
        dataproc_job_client = JobControllerClient(job_transport)
        result = dataproc_job_client.submit_job(
            project_id=self.project_id, region=self.region, job=job_details
        )
        job_id = result.reference.job_id
        print(f"job {job_id} is submitted.")

        # Poll the job status once per second until it reaches a terminal state.
        print(f"waiting for job {job_id} to finish...")
        while True:
            time.sleep(1)
            job = dataproc_job_client.get_job(self.project_id, self.region, job_id)
            if job.status.State.Name(job.status.state) == "ERROR":
                raise Exception(job.status.details)
            elif job.status.State.Name(job.status.state) == "DONE":
                print(f"job {job_id} is finished.")
                break
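
    # A minimal usage sketch (hypothetical values) for the blocking submit-and-poll
    # flow above, assuming `launcher` is an instance of the enclosing class with
    # cluster_name, region, project_id and dataproc_credentials already configured:
    #
    #     launcher.submit_pyspark_job(
    #         main_python_file_uri="gs://example-bucket/pyspark/main.py",
    #         python_file_uris=["gs://example-bucket/pyspark/helpers.py"],
    #     )
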
class DataprocClusterLauncher(JobLauncher):
    """
    Submits jobs to an existing Dataproc cluster. Depends on google-cloud-dataproc
    and google-cloud-storage, which are optional dependencies that the user needs
    to install in addition to the Feast SDK.
    """

    EXTERNAL_JARS = ["gs://spark-lib/bigquery/spark-bigquery-latest_2.12.jar"]
    JOB_TYPE_LABEL_KEY = "feast_job_type"
    JOB_HASH_LABEL_KEY = "feast_job_hash"

    def __init__(
        self,
        cluster_name: str,
        staging_location: str,
        region: str,
        project_id: str,
        executor_instances: str,
        executor_cores: str,
        executor_memory: str,
    ):
        """
        Initialize a Dataproc job controller client, used internally for job
        submission and result retrieval.

        Args:
            cluster_name (str): Dataproc cluster name.
            staging_location (str): GCS directory for the storage of files generated
                by the launcher, such as the pyspark scripts.
            region (str): Dataproc cluster region.
            project_id (str): GCP project id for the dataproc cluster.
            executor_instances (str): Number of executor instances for the dataproc job.
            executor_cores (str): Number of cores per executor for the dataproc job.
            executor_memory (str): Amount of memory per executor for the dataproc job.
        """
        self.cluster_name = cluster_name

        scheme, self.staging_bucket, self.remote_path, _, _, _ = urlparse(
            staging_location
        )
        if scheme != "gs":
            raise ValueError(
                "Only GCS staging location is supported for DataprocLauncher."
            )
        self.project_id = project_id
        self.region = region
        self.job_client = JobControllerClient(
            client_options={"api_endpoint": f"{region}-dataproc.googleapis.com:443"}
        )
        self.executor_instances = executor_instances
        self.executor_cores = executor_cores
        self.executor_memory = executor_memory

    def _stage_file(self, file_path: str, job_id: str) -> str:
        # Remote URIs (e.g. gs:// paths) are passed through unchanged; local files
        # are uploaded to the staging bucket first.
        if not os.path.isfile(file_path):
            return file_path

        staging_client = get_staging_client("gs")
        blob_path = os.path.join(
            self.remote_path,
            job_id,
            os.path.basename(file_path),
        ).lstrip("/")
        blob_uri_str = f"gs://{self.staging_bucket}/{blob_path}"
        with open(file_path, "rb") as f:
            staging_client.upload_fileobj(
                f, file_path, remote_uri=urlparse(blob_uri_str)
            )

        return blob_uri_str

    def dataproc_submit(
        self, job_params: SparkJobParameters, extra_properties: Dict[str, str]
    ) -> Tuple[Job, Callable[[], Job], Callable[[], None]]:
        local_job_id = str(uuid.uuid4())
        main_file_uri = self._stage_file(
            job_params.get_main_file_path(), local_job_id
        )
        job_config: Dict[str, Any] = {
            "reference": {"job_id": local_job_id},
            "placement": {"cluster_name": self.cluster_name},
            "labels": {
                self.JOB_TYPE_LABEL_KEY: job_params.get_job_type().name.lower()
            },
        }

        # Add job hash to labels only for the stream ingestion job
        if isinstance(job_params, StreamIngestionJobParameters):
            job_config["labels"][
                self.JOB_HASH_LABEL_KEY
            ] = job_params.get_job_hash()

        # Jobs that provide a main class are submitted as JVM Spark jobs;
        # otherwise the main file is treated as a PySpark script.
        if job_params.get_class_name():
            properties = {
                "spark.yarn.user.classpath.first": "true",
                "spark.executor.instances": self.executor_instances,
                "spark.executor.cores": self.executor_cores,
                "spark.executor.memory": self.executor_memory,
            }
            properties.update(extra_properties)

            job_config.update(
                {
                    "spark_job": {
                        "jar_file_uris": [main_file_uri] + self.EXTERNAL_JARS,
                        "main_class": job_params.get_class_name(),
                        "args": job_params.get_arguments(),
                        "properties": properties,
                    }
                }
            )
        else:
            job_config.update(
                {
                    "pyspark_job": {
                        "main_python_file_uri": main_file_uri,
                        "jar_file_uris": self.EXTERNAL_JARS,
                        "args": job_params.get_arguments(),
                        "properties": extra_properties if extra_properties else {},
                    }
                }
            )

        job = self.job_client.submit_job(
            request={
                "project_id": self.project_id,
                "region": self.region,
                "job": job_config,
            }
        )

        refresh_fn = partial(
            self.job_client.get_job,
            project_id=self.project_id,
            region=self.region,
            job_id=job.reference.job_id,
        )
        cancel_fn = partial(self.dataproc_cancel, job.reference.job_id)

        return job, refresh_fn, cancel_fn

    def dataproc_cancel(self, job_id):
        self.job_client.cancel_job(
            project_id=self.project_id, region=self.region, job_id=job_id
        )

    def historical_feature_retrieval(
        self, job_params: RetrievalJobParameters
    ) -> RetrievalJob:
        job, refresh_fn, cancel_fn = self.dataproc_submit(
            job_params, {"dev.feast.outputuri": job_params.get_destination_path()}
        )
        return DataprocRetrievalJob(
            job, refresh_fn, cancel_fn, job_params.get_destination_path()
        )

    def offline_to_online_ingestion(
        self, ingestion_job_params: BatchIngestionJobParameters
    ) -> BatchIngestionJob:
        job, refresh_fn, cancel_fn = self.dataproc_submit(ingestion_job_params, {})
        return DataprocBatchIngestionJob(job, refresh_fn, cancel_fn)

    def start_stream_to_online_ingestion(
        self, ingestion_job_params: StreamIngestionJobParameters
    ) -> StreamIngestionJob:
        job, refresh_fn, cancel_fn = self.dataproc_submit(ingestion_job_params, {})
        job_hash = ingestion_job_params.get_job_hash()
        return DataprocStreamingIngestionJob(job, refresh_fn, cancel_fn, job_hash)

    def get_job_by_id(self, job_id: str) -> SparkJob:
        job = self.job_client.get_job(
            project_id=self.project_id, region=self.region, job_id=job_id
        )
        return self._dataproc_job_to_spark_job(job)

    def _dataproc_job_to_spark_job(self, job: Job) -> SparkJob:
        job_type = job.labels[self.JOB_TYPE_LABEL_KEY]
        job_id = job.reference.job_id
        refresh_fn = partial(
            self.job_client.get_job,
            project_id=self.project_id,
            region=self.region,
            job_id=job_id,
        )
        cancel_fn = partial(self.dataproc_cancel, job_id)

        if job_type == SparkJobType.HISTORICAL_RETRIEVAL.name.lower():
            output_path = job.pyspark_job.properties.get("dev.feast.outputuri")
            return DataprocRetrievalJob(job, refresh_fn, cancel_fn, output_path)

        if job_type == SparkJobType.BATCH_INGESTION.name.lower():
            return DataprocBatchIngestionJob(job, refresh_fn, cancel_fn)

        if job_type == SparkJobType.STREAM_INGESTION.name.lower():
            job_hash = job.labels[self.JOB_HASH_LABEL_KEY]
            return DataprocStreamingIngestionJob(job, refresh_fn, cancel_fn, job_hash)

        raise ValueError(f"Unrecognized job type: {job_type}")

    def list_jobs(self, include_terminated: bool) -> List[SparkJob]:
        # Only return jobs submitted by this launcher (identified by the job type
        # label) on the configured cluster.
        job_filter = f"labels.{self.JOB_TYPE_LABEL_KEY} = * AND clusterName = {self.cluster_name}"
        if not include_terminated:
            job_filter = job_filter + " AND status.state = ACTIVE"
        return [
            self._dataproc_job_to_spark_job(job)
            for job in self.job_client.list_jobs(
                project_id=self.project_id, region=self.region, filter=job_filter
            )
        ]
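

# A minimal, hypothetical usage sketch of DataprocClusterLauncher. The project,
# region, cluster, and bucket names below are placeholders; running this requires
# GCP credentials plus the optional google-cloud-dataproc and google-cloud-storage
# dependencies. It is an illustration, not part of the launcher API.
if __name__ == "__main__":
    launcher = DataprocClusterLauncher(
        cluster_name="example-cluster",  # placeholder cluster name
        staging_location="gs://example-bucket/feast-staging",  # placeholder GCS path
        region="us-central1",
        project_id="example-project",  # placeholder project id
        executor_instances="2",
        executor_cores="2",
        executor_memory="2g",
    )
    # List only this launcher's active Feast jobs on the cluster.
    for spark_job in launcher.list_jobs(include_terminated=False):
        print(spark_job.get_id())  # get_id() comes from Feast's SparkJob interface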