def execute(self, context):
    """Build an ML Engine training request and submit it through MLEngineHook.

    In ``DRY_RUN`` mode the request is only logged and the method returns
    ``None`` without creating a hook. Otherwise the request is passed to
    ``MLEngineHook.create_job`` together with ``check_existing_job``, which the
    hook uses to decide whether an already-existing job with the same id may
    be reused.

    :param context: Airflow task context; ``context['task_instance']`` is used
        to push ``gcp_metadata`` (``job_id`` and ``project_id``) to XCom.
    :raises RuntimeError: if the job returned by the hook is not in state
        ``'SUCCEEDED'``.
    """
    job_id = _normalize_mlengine_job_id(self._job_id)

    training_input = {
        'scaleTier': self._scale_tier,
        'packageUris': self._package_uris,
        'pythonModule': self._training_python_module,
        'region': self._region,
        'args': self._training_args,
    }
    # Optional fields are only included in the request when they are set.
    if self._runtime_version:
        training_input['runtimeVersion'] = self._runtime_version
    if self._python_version:
        training_input['pythonVersion'] = self._python_version
    if self._job_dir:
        training_input['jobDir'] = self._job_dir
    # masterType is only sent for the CUSTOM scale tier.
    if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM":
        training_input['masterType'] = self._master_type

    training_request = {
        'jobId': job_id,
        'trainingInput': training_input,
    }

    if self._mode == 'DRY_RUN':
        self.log.info('In dry_run mode.')
        self.log.info('MLEngine Training job request is: %s', training_request)
        return

    hook = MLEngineHook(
        gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)

    def check_existing_job(existing_job):
        """Return True if ``existing_job`` has the same training input as this request.

        Before comparing, a missing ``scaleTier`` or ``args`` on the existing
        job is treated as ``None``, and empty requested ``args`` are treated as
        ``None``. The comparison works on shallow copies, so neither
        ``existing_job`` nor ``training_request`` is modified.
        """
        existing_input = existing_job.get('trainingInput', None)
        if existing_input is None:
            # Nothing to compare against: report a mismatch instead of
            # failing with a TypeError on the membership test below.
            return False
        existing_input = dict(existing_input)
        existing_input.setdefault('scaleTier', None)
        existing_input.setdefault('args', None)

        requested_input = dict(training_request['trainingInput'])
        requested_input['args'] = requested_input['args'] or None

        return existing_input == requested_input

    finished_training_job = hook.create_job(
        project_id=self._project_id,
        job=training_request,
        use_existing_job_fn=check_existing_job
    )

    state = finished_training_job.get('state')
    if state != 'SUCCEEDED':
        self.log.error('MLEngine training job failed: %s', str(finished_training_job))
        # 'errorMessage' is not guaranteed to be present on the returned job;
        # fall back to a descriptive message rather than raising KeyError.
        raise RuntimeError(finished_training_job.get(
            'errorMessage',
            'MLEngine training job {} ended in state {}'.format(job_id, state)))

    gcp_metadata = {
        "job_id": job_id,
        "project_id": self._project_id,
    }
    context['task_instance'].xcom_push("gcp_metadata", gcp_metadata)
def execute(self, context):
    """Build an ML Engine batch prediction request, submit it, and return its output.

    The model source is ``self._uri`` when set. Otherwise it is the project
    model ``self._model_name``, optionally pinned to ``self._version_name``.
    The URI takes precedence when both are given.

    :param context: Airflow task context (not used by this method).
    :return: the ``predictionOutput`` field of the job returned by the hook.
    :raises RuntimeError: if the job returned by the hook is not in state
        ``'SUCCEEDED'``.
    """
    job_id = _normalize_mlengine_job_id(self._job_id)

    prediction_input = {
        'dataFormat': self._data_format,
        'inputPaths': self._input_paths,
        'outputPath': self._output_path,
        'region': self._region,
    }
    prediction_request = {
        'jobId': job_id,
        'predictionInput': prediction_input,
    }
    if self._labels:
        prediction_request['labels'] = self._labels

    if self._uri:
        prediction_input['uri'] = self._uri
    elif self._model_name:
        origin_name = 'projects/{}/models/{}'.format(
            self._project_id, self._model_name)
        if self._version_name:
            prediction_input['versionName'] = \
                origin_name + '/versions/{}'.format(self._version_name)
        else:
            prediction_input['modelName'] = origin_name

    # Optional fields are only included in the request when they are set.
    if self._max_worker_count:
        prediction_input['maxWorkerCount'] = self._max_worker_count
    if self._runtime_version:
        prediction_input['runtimeVersion'] = self._runtime_version
    if self._signature_name:
        prediction_input['signatureName'] = self._signature_name

    hook = MLEngineHook(
        gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)

    def check_existing_job(existing_job):
        """Return True if ``existing_job`` has the same prediction input as this request."""
        return existing_job.get('predictionInput', None) == \
            prediction_request['predictionInput']

    finished_prediction_job = hook.create_job(
        project_id=self._project_id,
        job=prediction_request,
        use_existing_job_fn=check_existing_job)

    state = finished_prediction_job.get('state')
    if state != 'SUCCEEDED':
        self.log.error('MLEngine batch prediction job failed: %s',
                       str(finished_prediction_job))
        # 'errorMessage' is not guaranteed to be present on the returned job;
        # fall back to a descriptive message rather than raising KeyError.
        raise RuntimeError(finished_prediction_job.get(
            'errorMessage',
            'MLEngine batch prediction job {} ended in state {}'.format(
                job_id, state)))

    return finished_prediction_job['predictionOutput']