Example 1
def write_td_table(database_name, table_name):
    import os
    import random

    import pandas as pd
    import pandas_td as td
    import tdclient

    # pandas-td connection used to write the DataFrame; the tdclient.Client
    # below creates the database and table first if they do not exist.
    con = td.connect(apikey=os.environ['TD_API_KEY'],
                     endpoint=os.environ['TD_API_SERVER'])
    df = pd.DataFrame({"c": [random.random() for _ in range(20)]})

    # Manipulating data in Treasure Data via Python.
    # Uses https://github.com/treasure-data/td-client-python

    tdc = tdclient.Client(apikey=os.environ['TD_API_KEY'],
                          endpoint=os.environ['TD_API_SERVER'])

    try:
        tdc.create_database(database_name)
    except tdclient.errors.AlreadyExistsError:
        pass

    try:
        tdc.create_log_table(database_name, table_name)
    except tdclient.errors.AlreadyExistsError:
        pass

    table_path = f"{database_name}.{table_name}"
    td.to_td(df, table_path, con, if_exists='replace', index=False)
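A minimal invocation sketch for the helper above, assuming TD_API_KEY and TD_API_SERVER are set in the environment; the database and table names are placeholders:

# Hypothetical call; the database and table names are illustrative only.
write_td_table("my_dev_db", "random_values")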
Example 2
    def __init__(self,
                 apikey=None,
                 endpoint=None,
                 database="sample_datasets",
                 default_engine="presto",
                 header=True,
                 **kwargs):
        if isinstance(default_engine, QueryEngine):
            apikey = default_engine.apikey
            endpoint = default_engine.endpoint
            database = default_engine.database
        else:
            apikey = apikey or os.environ.get("TD_API_KEY")
            if apikey is None:
                raise ValueError(
                    "either argument 'apikey' or environment variable"
                    "'TD_API_KEY' should be set")
            if endpoint is None:
                endpoint = os.getenv("TD_API_SERVER",
                                     "https://api.treasuredata.com")
            default_engine = self._fetch_query_engine(default_engine, apikey,
                                                      endpoint, database,
                                                      header)

        self.apikey = apikey
        self.endpoint = endpoint
        self.database = database

        self.default_engine = default_engine
        self.query_executed = None

        self.api_client = tdclient.Client(apikey=apikey,
                                          endpoint=endpoint,
                                          user_agent=default_engine.user_agent,
                                          **kwargs)
Example 3
def get_job_list(status, max_num):
    with tdclient.Client(apikey=TD_API_KEY, endpoint=TD_API_SERVER) as client:
        data = []
        for job in client.jobs(0, max_num, status):
            job_detail = client.job(job.job_id)
            data.append({
                'time': int(time.time()),
                'job_id': str(job.job_id),
                'type': str(job_detail._type),
                'query': str(job_detail._query),
                'status': str(job_detail._status),
                'created_at': (-1 if job_detail._created_at is None
                               else int(job_detail._created_at.timestamp())),
                'start_at': (-1 if job_detail._start_at is None
                             else int(job_detail._start_at.timestamp())),
                'org_name': str(job_detail.org_name),
                'database': str(job_detail._database),
                'user_name': str(job_detail._user_name),
            })
        return data
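A usage sketch for get_job_list above, assuming TD_API_KEY and TD_API_SERVER are defined at module level (the function references them as globals) and that 'success' is a job status accepted by client.jobs():

# Hypothetical call; the status value and the limit of 100 are illustrative.
rows = get_job_list('success', 100)
for r in rows:
    print(r['job_id'], r['status'], r['database'])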
Example 4
    def __init__(self,
                 apikey=None,
                 endpoint=None,
                 database='sample_datasets',
                 engine='presto',
                 header=True,
                 **kwargs):
        if isinstance(engine, QueryEngine):
            apikey = engine.apikey
            endpoint = engine.endpoint
            database = engine.database
        else:
            apikey = apikey or os.environ.get('TD_API_KEY')
            if apikey is None:
                raise ValueError(
                    "either argument 'apikey' or environment variable 'TD_API_KEY' should be set"
                )
            endpoint = os.getenv('TD_API_SERVER',
                                 'https://api.treasuredata.com')
            engine = self._fetch_query_engine(engine, apikey, endpoint,
                                              database, header)

        self.apikey = apikey
        self.endpoint = endpoint
        self.database = database

        self.engine = engine

        self.api_client = tdclient.Client(apikey=apikey,
                                          endpoint=endpoint,
                                          user_agent=engine.user_agent,
                                          **kwargs)

        self.writer = None
Example 5
    def _create_vectorize_table(self, engine, dbname, table_name, source_table,
                                feature_query):
        import tdclient
        import pandas_td as td

        # Create feature vector table
        with tdclient.Client(apikey=self.apikey,
                             endpoint=self.endpoint) as client:
            db = client.database(dbname)
            try:
                db.table(table_name).delete()
            except tdclient.api.NotFoundError:
                pass

            db.create_log_table(table_name)

        hql = '''insert overwrite table {output_table}
        select
            rowid,
        {target_columns},
            medv as price
        from
            {source_table}
        '''.format_map({
            'output_table': table_name,
            'source_table': source_table,
            'target_columns': textwrap.indent(feature_query, '    ')
        })

        td.read_td(hql, engine)
Example 6
def run_dynamic_query(parameters):
    #0. Initialize our connection to Treasure Data
    apikey = os.environ['MASTER_TD_API_KEY']
    endpoint = 'https://api.treasuredata.com'
    con = td.connect(apikey, endpoint)
    #1. Connect to the query engine
    con_engine = con.query_engine(database=parameters['db_name'],
                                  type=parameters['query_engine'])

    #2. Setup query limit string
    if parameters['limit'] != '0':
        limit_str = "LIMIT " + str(parameters['limit']) + ";"
    else:
        limit_str = ";"

    #3. Compose Query String
    if 'min_time' not in parameters:
        parameters['min_time'] = 'NULL'

    if 'max_time' not in parameters:
        parameters['max_time'] = 'NULL'

    if parameters['min_time'] == 'NULL' and parameters['max_time'] == 'NULL':
        compose_query = "SELECT " + parameters['col_list']   + " " + \
                        "FROM "   + parameters['table_name'] + " " + limit_str

    else:
        compose_query = "SELECT " + parameters['col_list']   + " " + \
                        "FROM "   + parameters['table_name'] + " " + \
                        "WHERE "  + "td_time_range(time,"    + parameters['min_time'] + "," + parameters['max_time'] + ") " + \
                        limit_str

    print("Executing..." + compose_query)
    #4. Run query as a job and wait for job to finish
    #Assign result set to a data frame

    with tdclient.Client(apikey) as client:
        job = client.query(parameters['db_name'],
                           compose_query,
                           type=parameters['query_engine'])
        job.wait()
        try:
            #Assign result set to a data frame
            df = td.read_td_job(job.job_id, con_engine)
        except RuntimeError:
            print("Please review the column names and delimited by commas: " +
                  parameters['col_list'])
            return

    #5. Write the results to a csv or tabular format file
    if parameters['format'] == 'csv':
        print("Downloading results to " + job.job_id + ".csv" + " file")
        df.to_csv(job.job_id + ".csv")
    else:
        #Write data into tabular grid format
        print("Downloading results to " + job.job_id + ".txt" + " file")
        filename = job.job_id + ".txt"
        with open(filename, "a") as outfile:
            outfile.write(tabulate(df, tablefmt="grid"))
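A sketch of the parameters dictionary run_dynamic_query above expects, inferred from the keys the function reads; it also assumes MASTER_TD_API_KEY is exported, since the function reads it. Every value below is an illustrative placeholder:

# Hypothetical parameter set; all values are placeholders.
params = {
    'db_name': 'sample_datasets',
    'query_engine': 'presto',         # or 'hive'
    'table_name': 'www_access',
    'col_list': 'time, method, path',
    'limit': '100',                   # '0' disables the LIMIT clause
    'min_time': 'NULL',               # optional unix timestamp; 'NULL' skips the range filter
    'max_time': 'NULL',               # optional
    'format': 'csv',                  # anything else writes a tabulated .txt file
}
run_dynamic_query(params)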
Example 7
    def execute(self, quals, columns):
        if self.query:
            statement = self.query
        else:
            cond = self.create_cond(quals)
            statement = "SELECT %s FROM %s" % (",".join(
                self.columns.keys()), self.table)
            if cond != '':
                statement += ' WHERE %s' % (cond)

        log_to_postgres('TreasureData query: ' + str(statement), DEBUG)

        try:
            with tdclient.Client(apikey=self.apikey,
                                 endpoint=self.endpoint) as td:
                job = td.query(self.database,
                               statement,
                               type=self.query_engine)
                job.wait()
                for row in job.result():
                    # map values positionally onto the declared column names
                    yield dict(zip(self.columns, row))
        except Exception as e:
            log_to_postgres(str(e), ERROR)
Example 8
def run():
    os.system(f"{sys.executable} -m pip install --user tdclient boto3")
    os.system(f"{sys.executable} -m pip install --user tensorflow")

    import boto3
    import tensorflow as tf
    import numpy as np
    import tdclient

    database = 'sentiment'
    td = tdclient.Client(apikey=os.environ['TD_API_KEY'],
                         endpoint=os.environ['TD_API_SERVER'])
    job = td.query(database,
                   """
            select
                rowid, sentence
            from
                movie_review_test_shuffled
        """,
                   type="presto")
    job.wait()

    examples = []
    row_ids = []
    for row in job.result():
        rowid, sentence = row
        row_ids.append(rowid)

        feature = {
            'sentence':
            tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[sentence.encode('utf-8')]))
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        examples.append(example.SerializeToString())

    # Download the TensorFlow model from S3
    # boto3 assuming environment variables "AWS_ACCESS_KEY_ID" and "AWS_SECRET_ACCESS_KEY":
    # http://boto3.readthedocs.io/en/latest/guide/configuration.html#environment-variables
    s3 = boto3.resource('s3')
    s3.Bucket(os.environ['S3_BUCKET']).download_file('tfmodel.tar.gz',
                                                     'tfmodel.tar.gz')

    with tarfile.open('tfmodel.tar.gz') as tar:
        tar.extractall()

    with tf.Session(graph=tf.Graph()) as sess:
        export_dir = get_export_dir()
        predict_fn = tf.contrib.predictor.from_saved_model(export_dir)
        predictions = predict_fn({'inputs': examples})

    predicted_polarities = np.argmax(predictions['scores'], axis=1)
    scores = np.max(predictions['scores'], axis=1)

    table = 'test_predicted_polarities'
    upload_prediction_result(td, row_ids, predicted_polarities, scores,
                             database, table,
                             ['rowid', 'predicted_polarity', 'score'])
Example 9
 def get_schema(self):
     schema = {}
     try:
         with tdclient.Client(self.configuration.get('apikey')) as client:
             for table in client.tables(self.configuration.get('db')):
                 table_name = '{}.{}'.format(self.configuration.get('db'), table.name)
                 for table_schema in table.schema:
                     schema[table_name] = {'name': table_name, 'columns': table.schema}
     except Exception:
         raise Exception("Failed getting schema")
     return list(schema.values())
Example 10
    def run(self):
        # load cluster definitions
        with open('resources/cluster_definitions.json') as f:
            self.cluster_definitions = json.loads(
                f.read(), object_pairs_hook=OrderedDict)

        database = 'takuti'

        td = tdclient.Client(apikey=os.environ['TD_API_KEY'],
                             endpoint=os.environ['TD_API_SERVER'])

        # read original title data
        job = td.query(database,
                       'select title, words from title',
                       type='presto')
        job.wait()

        titles = {}
        for row in job.result():
            title, words = row
            titles[title] = words

        # categorize & write to a mapping file
        with open('resources/title_mapping.csv', 'w', newline='') as f:
            writer = csv.DictWriter(
                f, fieldnames=['time', 'title', 'role', 'job'])
            writer.writeheader()
            t = int(time.time())
            for raw_title, words in titles.items():
                row = self.__categorize(raw_title, words)
                row['time'] = t
                writer.writerow(row)

        table = 'title_mapping'
        try:
            td.table(database, table)
        except tdclient.errors.NotFoundError:
            pass
        else:
            td.delete_table(database, table)
        td.create_log_table(database, table)
        td.import_file(database, table, 'csv', 'resources/title_mapping.csv')

        os.remove('resources/title_mapping.csv')

        # Wait for a while until imported records are fully available on TD
        # console.
        while True:
            job = td.query(database,
                           'select count(title) from ' + table,
                           type='presto')
            job.wait()
            if not job.error():
                break
            time.sleep(10)
Example 11
def main():

    with tdclient.Client(apikey) as client:
        print("""
        Now processing this SQL query: """ + q_td)
        job = client.query(db_name="sample_datasets", q=q_td)
        # block until the job finishes
        job.wait()
        with open(filename, "w") as f:
            for row in job.result():
                f.write(str(row) + "\n")
Example 12
def new_schema_create_new_table(filename, table_name, database_name="braze"):
    reader = DataFileReader(open(filename, "rb"), DatumReader())
    schema = json.loads(reader.meta['avro.schema'])
    reader.close()
    create_table = "CREATE TABLE IF NOT EXISTS " + table_name
    all_field_string = ''
    for field in schema['fields']:
        comma = ' ' if all_field_string == '' else ', '
        all_field_string = all_field_string + comma + convert_schema_to_Presto(field)
    create_table = create_table + ' ( ' + all_field_string + ' ); '
    td = tdclient.Client(os.environ['td_apikey'])
    job = td.query(database_name, create_table, type="presto")
    job.wait()
Example 13
 def get_schema(self, get_stats=False):
     schema = {}
     if self.configuration.get('get_schema', False):
         try:
             with tdclient.Client(self.configuration.get('apikey')) as client:
                 for table in client.tables(self.configuration.get('db')):
                     table_name = '{}.{}'.format(self.configuration.get('db'), table.name)
                     for table_schema in table.schema:
                         schema[table_name] = {
                             'name': table_name,
                             'columns': [column[0] for column in table.schema],
                         }
         except Exception as ex:
             raise Exception("Failed getting schema")
     return schema.values()
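For reference, a sketch of the structure the method above returns (the values of the schema dict); the database, table, and column names are placeholders:

# Illustrative shape of the return value; all names are placeholders.
example_schema = [
    {'name': 'my_db.events', 'columns': ['time', 'user_id', 'path']},
    {'name': 'my_db.users', 'columns': ['time', 'user_id', 'email']},
]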
Example 14
def clean_Reload(email_links, input_tables, database_name="braze"):
    for key in email_links.keys():
        email_links[key]['date'] = 'date=2010-04-17-20/'
        email_links[key]['time'] = '00:00:00'
    td = tdclient.Client(os.environ['td_apikey'])

    for key in email_links:
        if input_tables == [] or key in input_tables:
            table = key.split("/")[0].split(".")
            table_name = table[0] + "_" + table[1] + "_" + table[2] + "_" + table[3]

            drop_table = "DROP TABLE IF EXISTS " + table_name
            job = td.query(database_name, drop_table, type="presto")
            job.wait()
Example 15
def execute_query(query):
    headers = None
    data = []
    try:
        with tdclient.Client(apikey) as client:
            job = client.query(args.db_name, query)
            # block until the job finishes
            job.wait()
            print(job.result_schema)
            headers = extract_headers_from_result(job.result_schema)
            for row in job.result():
                data.append(row)
        return True, headers, data
    except Exception:
        print("Error occurred while querying - ", sys.exc_info()[0])
        return False, None, None
Example 16
 def __init__(self, apikey=None, endpoint=None, **kwargs):
     if apikey is not None:
         kwargs['apikey'] = apikey
     if endpoint is not None:
         if not endpoint.endswith('/'):
             endpoint = endpoint + '/'
         kwargs['endpoint'] = endpoint
     if 'user_agent' not in kwargs:
         versions = [
             "pandas/{0}".format(pd.__version__),
             "tdclient/{0}".format(tdclient.version.__version__),
             "Python/{0}.{1}.{2}.{3}.{4}".format(*list(sys.version_info)),
         ]
         kwargs['user_agent'] = "pandas-td/{0} ({1})".format(
             __version__, ' '.join(versions))
     self.client = tdclient.Client(**kwargs)
Example 17
 def get_client(self, apikey=None, endpoint=None):
     kwargs = {}
     if apikey is not None:
         kwargs['apikey'] = apikey
     if endpoint is not None:
         if not endpoint.endswith('/'):
             endpoint = endpoint + '/'
         kwargs['endpoint'] = endpoint
     if 'user_agent' not in kwargs:
         versions = [
             "tdclient/{0}".format(tdclient.version.__version__),
             "Python/{0}.{1}.{2}.{3}.{4}".format(*list(sys.version_info)),
         ]
         kwargs['user_agent'] = "pytd/{0} ({1})".format(
             __version__, ' '.join(versions))
     return tdclient.Client(**kwargs)
Example 18
 def get_schema(self, get_stats=False):
     schema = {}
     if self.configuration.get("get_schema", False):
         try:
             with tdclient.Client(self.configuration.get("apikey"),endpoint=self.configuration.get("endpoint")) as client:
                 for table in client.tables(self.configuration.get("db")):
                     table_name = "{}.{}".format(
                         self.configuration.get("db"), table.name
                     )
                     for table_schema in table.schema:
                         schema[table_name] = {
                             "name": table_name,
                             "columns": [column[0] for column in table.schema],
                         }
         except Exception as ex:
             raise Exception("Failed getting schema")
     return list(schema.values())
Example 19
def get_kansei():
    n = datetime.datetime.now(pytz.timezone('Asia/Tokyo'))
    n = n - datetime.timedelta(minutes=1)
    t = n.strftime("%Y-%m-%d %H:%M:%S %Z")

    with tdclient.Client(apikey) as client:
        job = client.query("eeg_datasets",
              "SELECT AVG(interest) AS interest" +
              "      ,AVG(concentration) AS concentration" +
              "      ,AVG(drowsiness) AS drowsiness" +
              "      ,AVG(stress) AS stress" +
              "      ,AVG(like) AS like" +
              "  FROM kansei_sample" +
              " WHERE TD_TIME_RANGE(time, '" + t + "'," +
              "                     TD_TIME_ADD('" + t + "', '1m'))"
              )
        while not job.finished():
            time.sleep(1)
        for row in job.result():
            return row
Example 20
 def __init__(self, apikey):
     self.tdclient = tdclient.Client(apikey=apikey)
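The wrapper above only stores a tdclient.Client. Below is a sketch of how such a wrapper is commonly fleshed out, using only documented tdclient calls; the class name, helper method, and query are hypothetical and not part of the original project:

import os
import tdclient


class TDWrapper:
    def __init__(self, apikey):
        self.tdclient = tdclient.Client(apikey=apikey)

    # Hypothetical helper added for illustration only.
    def run_query(self, database, sql, engine="presto"):
        job = self.tdclient.query(database, sql, type=engine)
        job.wait()  # block until the job finishes
        return list(job.result())


# Example usage (placeholder database and query):
# wrapper = TDWrapper(os.environ["TD_API_KEY"])
# rows = wrapper.run_query("sample_datasets", "SELECT COUNT(1) FROM www_access")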
Example 21
import os
import tdclient

my_db = "evaluation_thomasluckenbach"

apikey = os.getenv("TD_API_KEY")

with tdclient.Client(apikey) as td:
    job = td.query(
        my_db,
        "SELECT count(amount), sum(amount) FROM sales_data where prod_id='13'")
    job.wait()
    print("Total Count and Sum amount for prod_id=13")
    for row in job.result():
        print(row)

    job = td.query(
        my_db, "SELECT TD_TIME_FORMAT(time, 'yyyy-MM-dd', 'JST') AS day,\
        COUNT(1) AS pv, sum(amount) FROM sales_data\
        WHERE  prod_id='13'\
        GROUP BY TD_TIME_FORMAT(time, 'yyyy-MM-dd', 'JST')\
        ORDER BY day ASC")
    job.wait()
    print("Orders and Revenue by day for prod_id=13")
    for row in job.result():
        print(row)
Example 22
def import_from_aws(directory_avro="/tmp/directory_avro", clean_start=False, input_tables={}, test=True):
    '''
        Location of the relevant files in the bucket
    '''
    link2 = 'currents/dataexport.prod-03.S3.integration.5a9ee171a12f74534a9a4e70/'
    link3 = 'event_type='

    #boto3 client environmental parameters are set.
    client = boto3.client(
        's3',
        aws_access_key_id=os.environ['aws_access_key_id'],
        aws_secret_access_key=os.environ['aws_secret_access_key'],
    )
    bucket = 'fivestars-kprod-braze-events'

    email_links = json.load(open("config.json", "rb")) #loaded as a dictionary
    if clean_start and not test:
        clean_Reload(email_links, input_tables)
    else:
        for table in input_tables:
            email_links[table]['date'] = input_tables[table]

    '''
        Each event is a folder in the S3 bucket corresponding to the event type.
        events:
            users.messages.email.Bounce/
            users.messages.email.Click/
            users.messages.email.Delivery/
            ...
        That way one event type is processed at a time.
    '''
    for event in email_links.keys():
        if event not in input_tables:
            continue

        '''
            This is the last transferred date and time for this event type.
            Necessary information to avoid duplicates.
        '''
        last_transferred_date, last_transferred_time = email_links[event]['date'], email_links[event]['time']

        #lists all files in a folder with the given prefix.
        result = client.list_objects(Bucket=bucket, Prefix=link2 + link3 + event, Delimiter='/').get('CommonPrefixes')

        if len(result) > 999:
            print("Severe problem, result is longer than 999")


        '''
            The maximal dates that would be transferred are recorded.
        '''
        all_dates = get_boto_Dates(result)  # the date is actually part of the file name
        if last_transferred_date <= all_dates[0]:
            equal_files = client.list_objects(Bucket=bucket, Prefix=link2 + link3 + event + all_dates[0]).get('Contents')
            all_time = []
            for file in equal_files:
                time = file['LastModified'].time()
                all_time.append(str(time.hour) + ":" + str(time.minute) + ":" + str(time.second))

            '''
                The term 'import_time' differs from 'time' because the files have not yet
                been uploaded to Treasure Data, so this is treated as a temporary time until
                the final upload is made.
            '''
            email_links[event]['import_time'] = max(all_time)
            email_links[event]['import_date'] = all_dates[0]
        else:
            continue

        '''
            The list of all files with dates greater than the last_transferred date is compiled.
        '''

        json_output = []
        for date in all_dates:
            if last_transferred_date > date:
                break

            output = client.list_objects(Bucket=bucket, Prefix=link2 + link3 + event + date).get('Contents')
            files_to_download = []
            for filename in output:
                location = filename['Key']
                output_location = location.split("/")
                if last_transferred_date == date:
                    time = filename['LastModified'].time()
                    time = str(time.hour) + ":" + str(time.minute) + ":" + str(time.second)
                    if time > last_transferred_time:
                        files_to_download.append(location)
                else:
                    files_to_download.append(location)
            '''
                Up to this point no files were actually downloaded, but instead a list of files to download
                was compiled. The next step is below.
            '''

            '''
                The main work happens here,
                -files are downloaded,
                -stored briefly in /tmp/temp.avro
                -converted to JSON
                -combined to a single large array
                -checked for duplicates.
            '''

            temp_json_output = []
            for file in files_to_download:
                filename = "/tmp/temp.avro"
                client.download_file(Bucket=bucket, Key=file, Filename=filename)
                try:
                    reader = DataFileReader(open(filename, "rb"), DatumReader())
                except Exception as e:
                    pass  # error handling here needs to be expanded
                for user in reader:
                    if user not in temp_json_output:
                        temp_json_output.append(user)

            for item in temp_json_output:
                json_output.append(item)

        if clean_start:
            table_name = event.split("/")[0].split(".")
            table_name = table_name[0] + "_" + table_name[1] + "_" + table_name[2] + "_" + table_name[3]
            new_schema_create_new_table(filename=filename, table_name=table_name, database_name="braze")


        #files are moved to a single file /tmp/temp.json
        json_file_name = "/tmp/temp.json"
        file_to_treasure = open("/tmp/temp.json", "w")
        for user in json_output:
            file_to_treasure.write(json.dumps(user) + '\n')
        file_to_treasure.close()

        #this single file is uploaded to Treasure Data.
        td = tdclient.Client(os.environ['td_apikey'])
        try:
            if test:
                print('This is a test on my computer')
                print('table_name:' + email_links[event]['table_name'])
                print()
            else:
                result = td.import_file(db_name="braze",
                                        table_name=email_links[event]['table_name'],
                                        format="json",
                                        file=json_file_name)
        except Exception as e:
            print(e)
    return "success"
Example 23
 def get_client(self):
     return tdclient.Client(self.apikey, endpoint=self.endpoint)
Example 24
import os
import sys
import tdclient
import writegss

query = """
SELECT {0} FROM your_table
""".strip()

key = '/path/to/service-account-key.json'
worksheet = writegss.WorksheetUtil(
    '1_R1jkvv4WW7jomwMCyYJA-UCNVVFJIzdscFlF9xGXB4', sheet_index=0, keyfile=key)

header = [
    'id', 'first_name', 'middle_name', 'last_name', 'birthday', 'sex',
    'address', 'zipcode'
]
style = [{'bold': True}] * len(header)

sys.stderr.write('>>> module = {0}\n'.format(tdclient.__name__))
sys.stderr.write('{0} THROWING QUERY BELOW {1}\n{2}\n{3}\n'.format(
    '#' * 10, '#' * 20, query.format(', '.join(header)), '#' * 52))

with tdclient.Client() as td:
    job = td.query('your_database',
                   query.format(', '.join(header)),
                   type='presto')
    job.wait()
    sys.stderr.write('Result Records: {0}\n'.format(job.num_records))
    worksheet.write_records_with_prepare(job, headers=header, fg=style)
Example 25
    )
    logging.info("exiting script")
    raise SystemExit

#Check that repo is downloaded
if not os.path.isdir(mbedos_repo_projectname):
    logging.info("Repo not found, please run setup.py")
    raise SystemExit

################################
# Step 2.1 - Check all config values are filled out
################################

logging.info("\r\n\r\n***** Step 2.1 *******\r\n")
logging.info("Grabing data from Treasure Data")
with tdclient.Client(td_apikey) as client:
    job = client.query(td_database, td_query)
    logging.info("Runnig TD query... Please wait for it to finish (<30sec)...")
    # sleep until job's finish
    job.wait()
    logging.info("Result is")
    for row in job.result():
        logging.info(row)

###############################
# Step 2.2 - Compute custom algorithm
###############################

# get the avg value
logging.info("\r\n\r\n***** Step 2.2 *******\r\n")
x = 0
Example 26
def run():
    database = 'takuti'
    model_filename = 'churn_prediction_model.pkl'

    # boto3 internally checks "AWS_ACCESS_KEY_ID" and "AWS_SECRET_ACCESS_KEY":
    # http://boto3.readthedocs.io/en/latest/guide/configuration.html#environment-variables
    # Create AWS session with specific IAM role:
    # https://dev.classmethod.jp/cloud/aws/aws-sdk-for-python-boto3-assumerole/
    """
    # Option #1:
    # If S3_BUCKET is accessible from Docker container, create AWS session with
    # your IAM role ARN and download the model dump.
    client = boto3.client('sts')
    response = client.assume_role(
        RoleArn=os.environ['AWS_IAM_ROLE_ARN'],
        RoleSessionName='ml-prediction')
    session = boto3.Session(
        aws_access_key_id=response['Credentials']['AccessKeyId'],
        aws_secret_access_key=response['Credentials']['SecretAccessKey'],
        aws_session_token=response['Credentials']['SessionToken'])
    s3 = session.resource('s3')

    with open(model_filename, 'w+b') as f:
        s3.Bucket(os.environ['S3_BUCKET']).download_fileobj(model_filename, f)
    """

    # Option #2:
    # For S3 bucket that is not accessible from Docker container, upload model
    # dump with public ACL in `train.py`, and simply download it from its
    # public URL.
    url = 'https://s3.amazonaws.com/' + os.environ['S3_BUCKET'] + '/' + model_filename
    urllib.request.urlretrieve(url, model_filename)

    with open(model_filename, 'rb') as f:
        obj = pickle.load(f)
        clf, vectorizer = obj['classifier'], obj['vectorizer']

    os.remove(model_filename)

    td = tdclient.Client(apikey=os.environ['TD_API_KEY'], endpoint=os.environ['TD_API_SERVER'])

    job = td.query(database, 'select * from churn', type='presto')
    job.wait()

    keys, rows_dict = [], []
    for row in job.result():
        key, row_dict, _ = process_row(row)
        keys.append(key)
        rows_dict.append(row_dict)

    y = clf.predict(vectorizer.transform(rows_dict))

    with open('churn_prediction_result.csv', 'w') as f:
        f.write('time,key,predict\n')
        t = int(time.time())
        for key, yi in zip(keys, y):
            f.write('%d,%s,%f\n' % (t, key, yi))

    table = 'churn_predict'
    try:
        td.table(database, table)
    except tdclient.errors.NotFoundError:
        pass
    else:
        td.delete_table(database, table)
    td.create_log_table(database, table)
    td.import_file(database, table, 'csv', 'churn_prediction_result.csv')

    os.remove('churn_prediction_result.csv')

    # Wait for a while until imported records are fully available on TD
    # console.
    while True:
        job = td.query(database, 'select count(key) from ' + table, type='presto')
        job.wait()
        if not job.error():
            break
        time.sleep(30)
Example 27
 def get_client(self):
     return tdclient.Client(**self.kwargs)

def lambda_handler(event, context):
    link2 = 'currents/dataexport.prod-03.S3.integration.5a9ee171a12f74534a9a4e70/'

    #boto3 client environmental parameters are set.
    client = boto3.client(
        's3',
        aws_access_key_id=os.environ['aws_access_key_id'],
        aws_secret_access_key=os.environ['aws_secret_access_key'],
    )

    bucket = 'fivestars-kprod-braze-events'

    all_keys = {}
    error_keys = {}
    for record in event['Records']:

        key = record['s3']['object']['key']
        '''
            Important gotcha: unquote_plus decodes the URL-encoded characters that appear in 'key'.
        '''
        key = unquote_plus(key)
        if not allowed_keys(key):
            continue

        table = getTable(key)
        if table not in all_keys:
            all_keys[table] = []
        all_keys[table].append(key)
    '''
        Create all the error logs for all the all_keys
    '''
    for table in all_keys:
        if table not in error_keys:
            error_keys[table] = empty_logs()
    '''
        Here all files are read and uploaded to Treasure Data
    '''

    td = tdclient.Client(os.environ['td_apikey'])
    for table in all_keys:
        ''' Files for a given table are read and transferred to a dictionary '''
        json_output = recursive_read_aws.read_then_to_json(
            client=client,
            file_names=all_keys[table],
            bucket=bucket,
            error_keys_table=error_keys[table])
        json_file_name = "/tmp/temp.json"
        file_to_treasure = open("/tmp/temp.json", "w")
        for user in json_output:
            file_to_treasure.write(json.dumps(user) + '\n')
        file_to_treasure.close()
        if "test_number" in record:
            print(record['test_number'])
            print(table)
        else:
            if (json_output != []):
                try:
                    result = td.import_file(db_name="braze",
                                            table_name=table,
                                            format="json",
                                            file=json_file_name)
                except Exception as e:
                    print('Transfer failed for filenames: ' +
                          str(all_keys[table]))
                    '''
                        In the event of an exception and a failed transfer, the names of the failed Avro
                        files are written to error_keys and eventually to the logs.
                    '''
                    for file in all_keys[table]:
                        error_keys[table]['td']['files'].append(file)
    ''' Errors are written to a log file '''
    if (context != "test_number"):
        pass
    log_errors(error_keys=error_keys,
               client=client,
               bucket=bucket,
               link2=link2)
    return 'success'