def main(argv=None):
  """Train, deploy, evaluate and batch-predict with the CLV AutoML model.

  Parses command-line arguments, builds an AutoML Tables model from the
  configured BigQuery source, deploys it, fetches its evaluation, and runs
  a batch prediction job between the given GCS input/output locations.
  """
  if argv is None:
    argv = sys.argv
  args = create_parser().parse_args(args=argv[1:])

  # Both clients authenticate with the same service-account key file.
  automl_client = AutoMlClient.from_service_account_file(args.key_file)

  # Train the model on the configured BigQuery table.
  model_name = create_automl_model(
      automl_client,
      args.project_id,
      args.location,
      args.bq_dataset,
      args.bq_table,
      args.automl_dataset,
      args.automl_model,
      args.training_budget)

  # Make the trained model servable.
  deploy_model(automl_client, model_name)

  # Fetch evaluation metrics for the trained model.
  model_evaluation = get_model_evaluation(automl_client, model_name)

  # Run a batch prediction over the GCS input, writing results to GCS.
  prediction_client = PredictionServiceClient.from_service_account_file(
      args.key_file)
  do_batch_prediction(
      prediction_client,
      model_name,
      args.batch_gcs_input,
      args.batch_gcs_output)
Example #2
0
    def prediction_client(self) -> PredictionServiceClient:
        """Build and return a Google Cloud AutoML prediction client.

        :return: Google Cloud AutoML PredictionServiceClient client object.
        :rtype: google.cloud.automl_v1beta1.PredictionServiceClient
        """
        credentials = self._get_credentials()
        return PredictionServiceClient(
            credentials=credentials, client_info=self.client_info)
Example #3
0
File: ai_web.py  Project: sunny0531/AI
def get_prediction(content, project_id, model_id):
    """Send image bytes to an AutoML Vision model and return the response."""
    # Authenticate with the local service-account key file.
    client = PredictionServiceClient.from_service_account_file("AI_KEY.json")
    model_path = 'projects/{}/locations/us-central1/models/{}'.format(
        project_id, model_id)
    image_payload = {'image': {'image_bytes': content}}
    params = {}
    return client.predict(model_path, image_payload, params)
def get_prediction(content, project_id, model_id):
  """Run a text-classification prediction against an AutoML model.

  Args:
    content: Text snippet to classify.
    project_id: GCP project that hosts the model.
    model_id: AutoML model identifier.

  Returns:
    The prediction response; the call blocks until the request completes.
  """
  # NOTE(review): the service-account key path is hard-coded to one
  # developer's machine; consider reading it from configuration or an
  # environment variable instead.
  prediction_client = PredictionServiceClient.from_service_account_file("C:\\Users\\dell laptop\\AppData\\Roaming\\gcloud\\chicago-hospitals-social-media-7892696dfd9c.json")
  name = 'projects/{}/locations/us-central1/models/{}'.format(project_id, model_id)
  payload = {'text_snippet': {'content': content, 'mime_type': 'text/plain' }}
  params = {}
  request = prediction_client.predict(name, payload, params)
  return request  # waits till request is returned
def main(context):
    """Classify a slice of the input image with AutoML Vision.

    Extracts a single frame/slice from the input file, sends it to the
    AutoML model recorded in the training result, and writes the mapped
    label as nested file metadata. Exits with status 1 when the prediction
    confidence falls below the configured threshold.
    """
    # Resolve gear inputs.
    service_account_path = context.get_input_path('service_account')
    training_result_path = context.get_input_path('training_result')
    input_path = context.get_input_path('input')

    # Use a context manager so the handle is closed (original leaked it).
    with open(training_result_path) as result_file:
        training_result = json.load(result_file)

    # Training-time configuration (project/model identifiers).
    gcp_project = training_result['gcp_project']
    aml_location = training_result['aml_location']
    aml_dataset = training_result['aml_dataset']
    aml_model = training_result['aml_model']
    img_frame_selection = context.config.get('img_frame_selection')
    img_slice_selection = context.config.get('img_slice_selection')
    score_threshold = context.config.get('score_threshold')

    # create predict client
    predict_client = PredictionServiceClient.from_service_account_json(
        service_account_path)

    # extract the slice and upload for prediction
    log.info('Extracting prediction slice')
    stage_dir = tempfile.mkdtemp()
    image_path = extract_prediction_image(input_path, stage_dir,
                                          img_frame_selection,
                                          img_slice_selection)
    with open(stage_dir + '/' + image_path, 'rb') as image_file:
        image = image_file.read()

    log.info('Running AutoML Vision prediction')
    payload = {'image': {'image_bytes': image}}
    classification = predict_client.predict(aml_model, payload).payload

    # TODO handle multi-label
    score = classification[0].classification.score
    label_key = classification[0].display_name
    label = training_result['label_map'][label_key]
    log.info('Prediction: (confidence={}):\n{}'.format(score,
                                                       pprint.pformat(label)))

    if score < score_threshold:
        log.error('Confidence score is below threshold')
        sys.exit(1)

    # Expand dotted label keys (e.g. 'file.info.x') into nested metadata.
    file_meta = {'name': os.path.basename(input_path)}
    for key, value in label.items():
        node = file_meta
        key_parts = key.replace('file.', '').split('.')
        for part in key_parts[:-1]:
            # BUG FIX: the original did `node[part] = {}`, clobbering any
            # subtree already built for keys sharing this prefix.
            node = node.setdefault(part, {})
        node[key_parts[-1]] = value
    with open('output/.metadata.json', 'wt') as f:
        # TODO enable for non-acqisition files?
        json.dump({'acquisition': {'files': [file_meta]}}, f)
class AutoMLAPIWrapper:
    """Thin wrapper around an EU-hosted AutoML text-classification model."""

    def __init__(self, project, location, automl_model_name):
        # The default client expects a global model; EU-hosted models need
        # the EU regional endpoint set explicitly.
        endpoint_options = {"api_endpoint": "eu-automl.googleapis.com:443"}
        self.client = PredictionServiceClient(client_options=endpoint_options)
        self.path = PredictionServiceClient.model_path(
            project, location, automl_model_name)

    def predict(self, text):
        """Classify *text* and return the raw prediction response."""
        snippet = {"content": text, "mime_type": "text/plain"}
        predictions = self.client.predict(self.path, {"text_snippet": snippet})
        print(predictions)
        return predictions
Example #7
0
def get_prediction(phrase):
  """Classify *phrase* with the HackUTD AutoML text model.

  Returns:
    A (display_name, score) tuple for the top-ranked prediction.
  """
  # Authenticate with the AutoML service-account key file.
  client = PredictionServiceClient.from_service_account_file("./hackutd-1550944892104-1ba910d00c8f.json")
  path = client.model_path("hackutd-1550944892104", "us-central1", "TCN6183341381162616853")

  # Request a prediction for the text snippet.
  payload = {"text_snippet": {"content": phrase, "mime_type": "text/plain"}}
  response = client.predict(path, payload)

  # The first entry in the response carries the highest confidence.
  top = response.payload[0]
  return (top.display_name, top.classification.score)
# Shared fixtures for the AutoML hook/operator tests below.
TASK_ID = "test-automl-hook"
GCP_PROJECT_ID = "test-project"
GCP_LOCATION = "test-location"
MODEL_NAME = "test_model"
# Fully-qualified model resource name; used as the mocked extract_object_id
# return value in the operator tests.
MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152"
DATASET_ID = "TBL123456789"
# Minimal AutoML Tables model spec passed to create_model in the tests.
MODEL = {
    "display_name": MODEL_NAME,
    "dataset_id": DATASET_ID,
    "tables_model_metadata": {
        "train_budget_milli_node_hours": 1000
    },
}

# Resource paths derived from the client helper methods.
LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION)
MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION,
                                                MODEL_ID)

# Placeholder request bodies used by the operator tests.
INPUT_CONFIG = {"input": "value"}
OUTPUT_CONFIG = {"output": "value"}
PAYLOAD = {"test": "payload"}
DATASET = {"dataset_id": "data"}
MASK = {"field": "mask"}


class TestAutoMLTrainModelOperator(unittest.TestCase):
    @mock.patch(
        "airflow.gcp.operators.automl.AutoMLTrainModelOperator.xcom_push")
    @mock.patch("airflow.gcp.hooks.automl.CloudAutoMLHook.create_model")
    @mock.patch(
        "airflow.gcp.hooks.automl.CloudAutoMLHook.extract_object_id",
        return_value=MODEL_ID,
 def __init__(self, project, location, automl_model_name):
     """Store a prediction client and model path for an EU-hosted model."""
     # The default client behaviour excepts a global model, set the specific EU endpoint for EU models
     self.client = PredictionServiceClient(client_options={"api_endpoint": "eu-automl.googleapis.com:443"})
     self.path = PredictionServiceClient.model_path(project, location, automl_model_name)
Example #10
0
from airflow.contrib.operators import dataflow_operator
from airflow.contrib.operators import gcs_to_bq
# TODO Add when Composer on v2.0 and more Hook
# from airflow.contrib.operators import gcs_list_operator
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
from airflow.utils import trigger_rule

from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
from clv_automl import clv_automl

# instantiate Google Cloud base hook to get credentials and create automl clients
# The 'google_cloud_default' Airflow connection supplies the service-account
# credentials shared by both clients.
gcp_credentials = GoogleCloudBaseHook(
    conn_id='google_cloud_default')._get_credentials()
automl_client = AutoMlClient(credentials=gcp_credentials)
automl_predict_client = PredictionServiceClient(credentials=gcp_credentials)


def _get_project_id():
    """Get project ID from default GCP connection.

    Returns:
        The project id stored in the connection's extras.

    Raises:
        ValueError: if the connection does not define a project id.
    """
    extras = BaseHook.get_connection('google_cloud_default').extra_dejson
    key = 'extra__google_cloud_platform__project'
    if key in extras:
        project_id = extras[key]
    else:
        # BUG FIX: the original `raise ('...')` raised a plain string, which
        # is a TypeError in Python 3 ("exceptions must derive from
        # BaseException"); raise a proper exception instead.
        raise ValueError('Must configure project_id in google_cloud_default '
                         'connection from Airflow Console')
    return project_id