Example #1
def do_predict_mle(**kwargs):
    """ Runs a batch prediction on new data and saving the results as CSV into
    output_path.
    """
    job_id = 'clv-{}'.format(datetime.datetime.now().strftime('%Y%m%d%H%M'))
    gcs_prediction_input = 'gs://{}/predictions/to_predict.csv'.format(
        COMPOSER_BUCKET_NAME)
    gcs_prediction_output = 'gs://{}/predictions/output'.format(
        COMPOSER_BUCKET_NAME)
    model_name = kwargs['dag_run'].conf.get('model_name')
    model_version = kwargs['dag_run'].conf.get('model_version')

    logging.info("Running prediction using {}:{}...".format(
        model_name, model_version))

    mlengine_operator.MLEngineBatchPredictionOperator(
        task_id='predict_dnn',
        project_id=PROJECT,
        job_id=job_id,
        region=REGION,
        data_format='TEXT',
        input_paths=[gcs_prediction_input],  # The operator expects a list of GCS paths.
        output_path=gcs_prediction_output,
        model_name=model_name,
        version_name=model_version,
        #uri=gs://WHERE_MODEL_IS_IF_NOT_ML_ENGINE
        #runtime_version=TF_VERSION,
        dag=dag).execute(kwargs)
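The function reads model_name and model_version from dag_run.conf, so it is presumably registered as a PythonOperator task and triggered with a run-time configuration. A minimal sketch of that wiring, assuming Airflow 1.x-style imports and a hypothetical task id and DAG id (none of this is part of the original example):

from airflow.operators import python_operator

# Hypothetical wiring for do_predict_mle; task_id and the dag object are assumed.
predict_task = python_operator.PythonOperator(
    task_id='do_predict_mle',
    python_callable=do_predict_mle,
    provide_context=True,  # Passes dag_run and the rest of the context via **kwargs.
    dag=dag)

# The DAG run would then be triggered with a conf payload, for example:
#   airflow trigger_dag clv_dag --conf '{"model_name": "clv_model", "model_version": "v1"}'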
Example #2
def create_dag(env_variables):
    """Creates the Airflow directed acyclic graph.

    Args:
      env_variables: Dictionary of Airflow environment variables.

    Returns:
      driblet_dag: An instance of models.DAG.
    """
    driblet_dag = initialize_dag()

    # Clients setup.
    project_id = env_variables['project_id']
    bq_client = bigquery.Client(project=project_id)
    gcs_client = storage.Client(project=project_id)

    # TASK 1: Convert BigQuery CSV to TFRECORD.
    dag_dir = configuration.get('core', 'dags_folder')
    transformer_py = os.path.join(dag_dir, 'tasks/preprocess',
                                  'transformer.py')
    bq_to_tfrecord = dataflow_operator.DataFlowPythonOperator(
        task_id='bq-to-tfrecord',
        py_file=transformer_py,
        options={
            'project': project_id,
            'predict-data': '{}.{}.{}_{}'.format(
                project_id, env_variables['bq_dataset'],
                env_variables['bq_input_table'],
                datetime.datetime.now().strftime('%Y%m%d')),
            'data-source': 'bigquery',
            'transform-dir': 'gs://%s/transformer' % env_variables['bucket_name'],
            'output-dir': 'gs://%s/input' % env_variables['bucket_name'],
            'mode': 'predict',
        },
        dataflow_default_options={'project': project_id},
        dag=driblet_dag)

    # TASK 2: Make prediction from CSV in GCS.
    make_predictions = mlengine_operator.MLEngineBatchPredictionOperator(
        task_id='make-predictions',
        project_id=project_id,
        job_id='driblet_run_{}'.format(
            datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')),
        data_format='TF_RECORD',
        input_paths=['gs://%s/input/predict-*' % env_variables['bucket_name']],
        output_path='gs://%s/output' % env_variables['bucket_name'],
        region=env_variables['region'],
        model_name=env_variables['model_name'],
        version_name=env_variables['model_version'],
        gcp_conn_id='google_cloud_default',
        dag=driblet_dag)

    # TASK 3: Export predicted CSV from Cloud Storage to BigQuery.
    job_config = bigquery.LoadJobConfig()
    job_config.autodetect = True
    job_config.source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON
    job_config.time_partitioning = bigquery.TimePartitioning(
        type_=bigquery.TimePartitioningType.DAY,  # Daily partitioned table.
        expiration_ms=env_variables['dataset_expiration'])
    gcs_to_bigquery = GCStoBQOperator(
        task_id='gcs-to-bigquery',
        bq_client=bq_client,
        gcs_client=gcs_client,
        job_config=job_config,
        dataset_id=env_variables['bq_dataset'],
        table_id=env_variables['bq_output_table'],
        gcs_bucket=env_variables['bucket_name'],
        gcs_location=env_variables['location'],
        exclude_prefix='errors_stats',  # Exclude files whose names start with this prefix.
        dir_prefix='output',
        dag=driblet_dag)

    # TASK 4: Delete files in Cloud Storage bucket.
    gcs_delete_blob = GCSDeleteBlobOperator(
        task_id='gcs-delete-blob',
        client=gcs_client,
        gcs_bucket=env_variables['bucket_name'],
        prefixes=['input', 'output'],
        dag=driblet_dag)

    make_predictions.set_upstream(bq_to_tfrecord)
    make_predictions.set_downstream(gcs_to_bigquery)
    gcs_delete_blob.set_upstream(gcs_to_bigquery)

    return driblet_dag
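Airflow only registers DAG objects that exist at module scope, so a factory like create_dag is typically invoked when the DAG file is parsed. A minimal sketch of that call, assuming the environment dictionary is stored as a JSON-encoded Airflow Variable (the variable name below is hypothetical, not from the original example):

from airflow import models

# Hypothetical: load the environment dictionary from an Airflow Variable;
# how env_variables is actually populated is not shown in the original example.
env_variables = models.Variable.get('driblet_env_variables', deserialize_json=True)

# Module-level assignment so the Airflow scheduler discovers the DAG.
dag = create_dag(env_variables)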