Example #1
0
    def append_data(self, descriptor, table_name=None):
        """Stream ``descriptor`` into a CARTO table with a COPY FROM query.

        The destination is ``table_name`` when it is truthy, otherwise
        ``self.carto_table_name``. The column list is built from
        ``self.carto_field_names`` and the payload is declared as CSV with a
        header row.
        """
        target = table_name or self.carto_table_name
        copy_client = CopySQLClient(self.carto_auth_client)

        column_list = ",".join(self.carto_field_names)
        copy_sql = (
            "COPY {table_name} ({columns}) "
            "FROM stdin WITH (FORMAT csv, HEADER true)"
        ).format(table_name=target, columns=column_list)
        copy_client.copyfrom_file_object(copy_sql, descriptor)
    def __init__(self, credentials):
        """Resolve the credentials and build every API client from them.

        When ``credentials`` is falsy, ``get_default_credentials()`` is used.
        The resolved credentials are passed to ``check_credentials`` before
        any client is created; all clients share one auth client.
        """
        resolved = credentials or get_default_credentials()
        self.credentials = resolved
        check_credentials(resolved)

        auth = _create_auth_client(resolved)
        self.auth_client = auth
        self.sql_client = SQLClient(auth)
        self.copy_client = CopySQLClient(auth)
        self.batch_sql_client = BatchSQLClient(auth)
Example #3
0
 def copy_from(self, data, filepath, to_table):
     """Load a binary CSV stream into ``to_table`` with a COPY FROM query.

     The first line of ``data`` supplies the column list; the stream is then
     rewound so the complete content (header row included) is sent, which
     matches the ``HEADER true`` option of the query.

     :param data: seekable binary file object containing CSV with a header.
     :param filepath: not used by this method.
     :param to_table: name of the destination table.
     :return: the value returned by ``copyfrom_file_object``.
     """
     if self._copy_client is None:
         # Imported lazily, only when a COPY is actually requested.
         from carto.sql import CopySQLClient
         self._copy_client = CopySQLClient(self._auth_client)
     # 'utf-8-sig' drops a leading byte-order mark, which would otherwise
     # become part of the first column name; strip() removes the line
     # terminator so it does not end up inside the column list.
     headers = data.readline().decode('utf-8-sig').strip()
     data.seek(0)
     from_query = 'COPY %s (%s) FROM stdin WITH (FORMAT csv, HEADER true)' % (
         to_table, headers)
     return self._copy_client.copyfrom_file_object(from_query, data)
Example #4
0
def copy(tablename, rows, delimiter=",", quote='"', headers=None):
    """Bulk-load ``rows`` into ``tablename`` through a COPY FROM query.

    :param tablename: destination table.
    :param rows: iterable of rows. When ``headers`` is None the first row is
        consumed and used as the column names.
    :param delimiter: field delimiter declared in the COPY options and passed
        to ``rows_generator``.
    :param quote: passed to ``rows_generator`` together with ``delimiter``.
    :param headers: optional explicit sequence of column names.
    :return: the result of ``copyfrom``, or None when a ``CartoException``
        was raised (the error is logged, not re-raised).
    """
    copy_client = CopySQLClient(_get_auth_client())
    rows = iter(rows)
    if headers is None:
        headers = next(rows)
    # The column list is SQL syntax and must be comma separated whatever the
    # delimiter of the data is; joining it with ``delimiter`` produced an
    # invalid statement for any delimiter other than ",".
    columns = ",".join(headers)

    from_query = f"""COPY {tablename} ({columns}) FROM stdin
        (FORMAT CSV, DELIMITER '{delimiter}', HEADER false)"""
    try:
        return copy_client.copyfrom(from_query,
                                    rows_generator(rows, delimiter, quote))
    except CartoException as e:
        # Best-effort import: the failure is reported through the log only.
        log.error(f"Error importing \n\n {rows}")
        log.exception(e)
        return None
Example #5
0
    def initialize(self):
        """Create the auth client and the SQL, batch and COPY clients.

        When ``api_url`` is not set it is derived from ``user_name``. The
        organization name is passed to the auth client only when it is set.

        Raises:
            Exception: neither ``api_url`` nor ``user_name`` is set.
        """
        if not self.api_url:
            if not self.user_name:
                raise Exception(
                    'Not enough data provided to initialize the client')
            self.api_url = "https://{}.carto.com/api/".format(self.user_name)

        organization = self.org_name
        if organization:
            auth = APIKeyAuthClient(self.api_url, self.api_key, organization)
        else:
            auth = APIKeyAuthClient(self.api_url, self.api_key)
        self.client = auth

        self.sql_client = SQLClient(self.client)
        self.batch_client = BatchSQLClient(self.client)
        self.copy_client = CopySQLClient(self.client)
Example #6
0
    logger.error('You need to provide valid credentials, run with '
                 '-h parameter for details')
    sys.exit(1)

# Create the `copy_example` table (only if it does not exist yet) through the
# SQL API client, then call CDB_CartodbfyTable on it.
sqlClient = SQLClient(auth_client)
sqlClient.send("""
    CREATE TABLE IF NOT EXISTS copy_example (
      the_geom geometry(Geometry,4326),
      name text,
      age integer
    )
    """)
sqlClient.send("SELECT CDB_CartodbfyTable(current_schema, 'copy_example')")

# The COPY client shares the auth client used for the SQL client above.
copyClient = CopySQLClient(auth_client)

# COPY FROM example: load a local CSV file (with a header row) into the
# three columns named in the query, and log the returned result.
logger.info("COPY'ing FROM file...")
query = ('COPY copy_example (the_geom, name, age) '
         'FROM stdin WITH (FORMAT csv, HEADER true)')
result = copyClient.copyfrom_file_path(query, 'files/copy_from.csv')
logger.info('result = %s' % result)

# COPY TO example: export the whole table as CSV (with a header row) to a
# local file.
query = 'COPY copy_example TO stdout WITH (FORMAT csv, HEADER true)'
output_file = 'files/copy_export.csv'
copyClient.copyto_file_path(query, output_file)
logger.info('Table copied to %s' % output_file)

# Truncate the table to make this example repeatable
class ContextManager:
    """Run SQL, COPY and Batch SQL operations for one set of credentials.

    All API clients share the auth client created in the constructor.
    """

    def __init__(self, credentials):
        """Create the auth client and the SQL, COPY and Batch SQL clients.

        Args:
            credentials: account credentials; when falsy,
                ``get_default_credentials()`` is used instead.
        """
        self.credentials = credentials or get_default_credentials()
        check_credentials(self.credentials)

        self.auth_client = _create_auth_client(self.credentials)
        self.sql_client = SQLClient(self.auth_client)
        self.copy_client = CopySQLClient(self.auth_client)
        self.batch_sql_client = BatchSQLClient(self.auth_client)

    def execute_query(self,
                      query,
                      parse_json=True,
                      do_post=True,
                      format=None,
                      **request_args):
        """Send the stripped ``query`` through the SQL client.

        The remaining arguments are forwarded unchanged to the client's
        ``send`` method, whose result is returned.
        """
        return self.sql_client.send(query.strip(), parse_json, do_post, format,
                                    **request_args)

    def execute_long_running_query(self, query):
        """Run the stripped ``query`` with the Batch SQL client and wait."""
        return self.batch_sql_client.create_and_wait_for_completion(
            query.strip())

    def copy_to(self,
                source,
                schema,
                limit=None,
                retry_times=DEFAULT_RETRY_TIMES):
        """Download a table or query through COPY TO.

        Args:
            source: table name or SQL query.
            schema: schema used when ``source`` is a table name.
            limit: optional maximum number of rows (integer >= 0).
            retry_times: number of retries allowed after a rate-limit error.

        Returns:
            The value produced by ``read_csv`` on the COPY stream.
        """
        query = self.compute_query(source, schema)
        columns = self._get_query_columns_info(query)
        copy_query = self._get_copy_query(query, columns, limit)
        return self._copy_to(copy_query, columns, retry_times)

    def copy_from(self, gdf, table_name, if_exists='fail', cartodbfy=True):
        """Upload a dataframe to a table and return the normalized name.

        Args:
            gdf: dataframe to upload.
            table_name: destination name; it is normalized before use.
            if_exists: 'replace' (re)creates the table, 'fail' raises when
                the table already exists, any other value ('append') loads
                the rows into the existing table.
            cartodbfy: forwarded to the table creation step.

        Raises:
            Exception: the table already exists and ``if_exists`` is 'fail'.
        """
        schema = self.get_schema()
        table_name = self.normalize_table_name(table_name)
        columns = get_dataframe_columns_info(gdf)

        if if_exists == 'replace' or not self.has_table(table_name, schema):
            log.debug('Creating table "{}"'.format(table_name))
            self._create_table_from_columns(table_name, columns, schema,
                                            cartodbfy)
        elif if_exists == 'fail':
            raise Exception(
                'Table "{schema}.{table_name}" already exists in your CARTO account. '
                'Please choose a different `table_name` or use '
                'if_exists="replace" to overwrite it.'.format(
                    table_name=table_name, schema=schema))
        # Any other value ('append'): keep the existing table as it is.

        self._copy_from(gdf, table_name, columns)
        return table_name

    def create_table_from_query(self,
                                query,
                                table_name,
                                if_exists,
                                cartodbfy=True):
        """Create a table from ``query`` and return the normalized name.

        ``if_exists`` follows the same rules as in :meth:`copy_from`; with
        any value other than 'replace' or 'fail' an existing table is left
        untouched.

        Raises:
            Exception: the table already exists and ``if_exists`` is 'fail'.
        """
        schema = self.get_schema()
        table_name = self.normalize_table_name(table_name)

        if if_exists == 'replace' or not self.has_table(table_name, schema):
            log.debug('Creating table "{}"'.format(table_name))
            self._create_table_from_query(query, table_name, schema, cartodbfy)
        elif if_exists == 'fail':
            raise Exception(
                'Table "{schema}.{table_name}" already exists in your CARTO account. '
                'Please choose a different `table_name` or use '
                'if_exists="replace" to overwrite it.'.format(
                    table_name=table_name, schema=schema))
        # Any other value ('append'): nothing to create.

        return table_name

    def has_table(self, table_name, schema=None):
        """Return True when an EXPLAIN of the table query succeeds."""
        query = self.compute_query(table_name, schema)
        return self._check_exists(query)

    def delete_table(self, table_name):
        """Drop ``table_name``.

        Returns:
            False when the first notice of the response says the table
            "does not exist", True otherwise.
        """
        query = _drop_table_query(table_name)
        output = self.execute_query(query)
        # An empty or missing notices list means nothing was reported;
        # indexing it unconditionally raised IndexError on an empty list.
        notices = output['notices'] if 'notices' in output else None
        return not (notices and 'does not exist' in notices[0])

    def rename_table(self, table_name, new_table_name, if_exists='fail'):
        """Rename a table and return the normalized new name.

        Args:
            table_name: current name of the table.
            new_table_name: desired name; it is normalized before use.
            if_exists: when the new name is taken, 'replace' deletes that
                table first and 'fail' raises.

        Raises:
            ValueError: both names are equal.
            Exception: the source table does not exist, or the new name is
                taken and ``if_exists`` is 'fail'.
        """
        new_table_name = self.normalize_table_name(new_table_name)

        if table_name == new_table_name:
            raise ValueError(
                'Table names are equal. Please choose a different table name.')

        if not self.has_table(table_name):
            raise Exception(
                'Table "{table_name}" does not exist in your CARTO account.'.
                format(table_name=table_name))

        if self.has_table(new_table_name):
            if if_exists == 'replace':
                log.debug('Removing table "{}"'.format(new_table_name))
                self.delete_table(new_table_name)
            elif if_exists == 'fail':
                raise Exception(
                    'Table "{new_table_name}" already exists in your CARTO account. '
                    'Please choose a different `new_table_name` or use '
                    'if_exists="replace" to overwrite it.'.format(
                        new_table_name=new_table_name))

        self._rename_table(table_name, new_table_name)
        return new_table_name

    def update_privacy_table(self, table_name, privacy=None):
        """Update the privacy of ``table_name`` through ``DatasetInfo``."""
        DatasetInfo(self.auth_client, table_name).update_privacy(privacy)

    def get_privacy(self, table_name):
        """Return the ``privacy`` attribute of the table's ``DatasetInfo``."""
        return DatasetInfo(self.auth_client, table_name).privacy

    def get_schema(self):
        """Get user schema from current credentials"""
        query = 'SELECT current_schema()'
        result = self.execute_query(query, do_post=False)
        return result['rows'][0]['current_schema']

    def get_geom_type(self, query):
        """Fetch geom type of a remote table or query"""
        distinct_query = '''
            SELECT distinct ST_GeometryType(the_geom) AS geom_type
            FROM ({}) q
            LIMIT 5
        '''.format(query)
        response = self.execute_query(distinct_query, do_post=False)
        if response and response.get('rows') and len(response.get('rows')) > 0:
            st_geom_type = response.get('rows')[0].get('geom_type')
            if st_geom_type:
                # Drop the first three characters (presumably the "ST_"
                # prefix) before mapping the type name.
                return map_geom_type(st_geom_type[3:])
        return None

    def get_num_rows(self, query):
        """Get the number of rows in the query"""
        result = self.execute_query(
            "SELECT COUNT(*) FROM ({query}) _query".format(query=query))
        return result.get('rows')[0].get('count')

    def get_bounds(self, query):
        """Return the ``bounds`` value of the extent query, or None.

        The query selects a two-element array of [xmin, ymin] and
        [xmax, ymax] computed from ``ST_Extent(the_geom)``.
        """
        extent_query = '''
            SELECT ARRAY[
                ARRAY[st_xmin(geom_env), st_ymin(geom_env)],
                ARRAY[st_xmax(geom_env), st_ymax(geom_env)]
            ] bounds FROM (
                SELECT ST_Extent(the_geom) geom_env
                FROM ({}) q
            ) q;
        '''.format(query)
        response = self.execute_query(extent_query, do_post=False)
        if response and response.get('rows') and len(response.get('rows')) > 0:
            return response.get('rows')[0].get('bounds')
        return None

    def get_column_names(self, source, schema=None, exclude=None):
        """Return the column names of a table or query, in query order.

        Args:
            source: table name or SQL query.
            schema: schema used when ``source`` is a table name.
            exclude: optional list of column names to leave out; ignored
                unless it is a non-empty ``list``.
        """
        query = self.compute_query(source, schema)
        columns = [c.name for c in self._get_query_columns_info(query)]

        if exclude and isinstance(exclude, list):
            # Filter instead of using a set difference, which returned the
            # remaining columns in arbitrary order.
            excluded = set(exclude)
            columns = [name for name in columns if name not in excluded]

        return columns

    def is_public(self, query):
        """Return True when ``query`` can be EXPLAINed with public credentials."""
        # Used to detect public tables in queries in the publication,
        # because privacy only works for tables.
        public_auth_client = _create_auth_client(self.credentials, public=True)
        public_sql_client = SQLClient(public_auth_client)
        exists_query = 'EXPLAIN {}'.format(query)
        try:
            public_sql_client.send(exists_query, do_post=False)
            return True
        except CartoException:
            return False

    def get_table_names(self, query):
        """Return the tables used by ``query``, without their schema prefix."""
        # Used to detect tables in queries in the publication.
        # The query travels inside a single-quoted SQL literal, so its own
        # single quotes are doubled; unescaped, any quote in the query broke
        # the statement.
        escaped_query = query.replace("'", "''")
        query = 'SELECT CDB_QueryTablesText(\'{}\') as tables'.format(
            escaped_query)
        result = self.execute_query(query)
        tables = []
        if result['total_rows'] > 0 and result['rows'][0]['tables']:
            # Dataset_info only works with tables without schema
            tables = [
                table.split('.')[1] if '.' in table else table
                for table in result['rows'][0]['tables']
            ]
        return tables

    def _create_table_from_query(self,
                                 query,
                                 table_name,
                                 schema,
                                 cartodbfy=True):
        """Drop, create from ``query`` and optionally cartodbfy, in one batch."""
        query = 'BEGIN; {drop}; {create}; {cartodbfy}; COMMIT;'.format(
            drop=_drop_table_query(table_name),
            create=_create_table_from_query_query(table_name, query),
            cartodbfy=_cartodbfy_query(table_name, schema)
            if cartodbfy else '')
        self.execute_long_running_query(query)

    def _create_table_from_columns(self,
                                   table_name,
                                   columns,
                                   schema,
                                   cartodbfy=True):
        """Drop, create from ``columns`` and optionally cartodbfy, in one batch."""
        query = 'BEGIN; {drop}; {create}; {cartodbfy}; COMMIT;'.format(
            drop=_drop_table_query(table_name),
            create=_create_table_from_columns_query(table_name, columns),
            cartodbfy=_cartodbfy_query(table_name, schema)
            if cartodbfy else '')
        self.execute_long_running_query(query)

    def compute_query(self, source, schema=None):
        """Return ``source`` if it is a SQL query, else a SELECT on the table.

        The current schema is fetched when ``source`` is a table name and no
        ``schema`` is given.
        """
        if is_sql_query(source):
            return source
        schema = schema or self.get_schema()
        return self._compute_query_from_table(source, schema)

    def _compute_query_from_table(self, table_name, schema):
        """Build ``SELECT *`` on the quoted table; schema defaults to 'public'."""
        return 'SELECT * FROM "{schema}"."{table_name}"'.format(
            schema=schema or 'public', table_name=table_name)

    def _check_exists(self, query):
        """Return True when EXPLAIN of ``query`` does not raise CartoException."""
        exists_query = 'EXPLAIN {}'.format(query)
        try:
            self.execute_query(exists_query, do_post=False)
            return True
        except CartoException:
            return False

    def _get_query_columns_info(self, query):
        """Return column info built from the fields of a zero-row SELECT."""
        query = 'SELECT * FROM ({}) _q LIMIT 0'.format(query)
        table_info = self.execute_query(query)
        return Column.from_sql_api_fields(table_info['fields'])

    def _get_copy_query(self, query, columns, limit):
        """Build the SELECT used for COPY TO.

        Every column except ``the_geom_webmercator`` is selected and an
        optional LIMIT clause is appended.

        Raises:
            ValueError: ``limit`` is not None and not an integer >= 0.
        """
        query_columns = [
            column.name for column in columns
            if (column.name != 'the_geom_webmercator')
        ]

        query = 'SELECT {columns} FROM ({query}) _q'.format(
            query=query, columns=','.join(query_columns))

        if limit is not None:
            # bool is a subclass of int, but True/False are not row counts
            # and would render as "LIMIT True".
            is_count = isinstance(limit, int) and not isinstance(limit, bool)
            if is_count and (limit >= 0):
                query += ' LIMIT {limit}'.format(limit=limit)
            else:
                raise ValueError("`limit` parameter must be an integer >= 0")

        return query

    def _copy_to(self, query, columns, retry_times):
        """Run COPY TO for ``query`` and parse the CSV stream.

        On a rate-limit error the call sleeps for ``err.retry_after``
        seconds and retries, up to ``retry_times`` times; after that the
        error is re-raised.
        """
        copy_query = 'COPY ({0}) TO stdout WITH (FORMAT csv, HEADER true, NULL \'{1}\')'.format(
            query, PG_NULL)

        try:
            raw_result = self.copy_client.copyto_stream(copy_query)
        except CartoRateLimitException as err:
            if retry_times > 0:
                retry_times -= 1
                warn('Read call rate limited. Waiting {s} seconds'.format(
                    s=err.retry_after))
                time.sleep(err.retry_after)
                warn('Retrying...')
                # Recurse with the unwrapped query; the COPY wrapper is
                # rebuilt by the nested call.
                return self._copy_to(query, columns, retry_times)
            warn((
                'Read call was rate-limited. '
                'This usually happens when there are multiple queries being read at the same time.'
            ))
            raise

        converters = obtain_converters(columns)
        parse_dates = date_columns_names(columns)

        df = read_csv(raw_result,
                      converters=converters,
                      parse_dates=parse_dates)

        return df

    def _copy_from(self, dataframe, table_name, columns):
        """Send the dataframe to ``table_name`` with a pipe-delimited COPY FROM."""
        query = """
            COPY {table_name}({columns}) FROM stdin WITH (FORMAT csv, DELIMITER '|', NULL '{null}');
        """.format(table_name=table_name,
                   null=PG_NULL,
                   columns=','.join(column.dbname
                                    for column in columns)).strip()
        data = _compute_copy_data(dataframe, columns)
        self.copy_client.copyfrom(query, data)

    def _rename_table(self, table_name, new_table_name):
        """Execute the rename query for the two table names."""
        query = _rename_table_query(table_name, new_table_name)
        self.execute_query(query)

    def normalize_table_name(self, table_name):
        """Return the normalized table name, logging when it changed."""
        norm_table_name = normalize_name(table_name)
        if norm_table_name != table_name:
            log.debug('Table name normalized: "{}"'.format(norm_table_name))
        return norm_table_name
Example #8
0
class CartoDataSource(DataSource):
    """Data source that sends queries and bulk loads to a CARTO account."""

    SUBDOMAIN_URL_PATTERN = "https://%s.carto.com"
    ON_PREMISES_URL_PATTERN = "https://%s/user/%s"
    DEFAULT_API_VERSION = 'v2'

    def __init__(self, user, api_key, options=None):
        """Store the connection settings and create the auth and SQL clients.

        :param user: CARTO user name, used to build the base URL.
        :param api_key: API key passed to the auth client.
        :param options: optional dict; the keys read here are ``do_post``,
            ``parse_json``, ``format``, ``base_url``, ``api_version`` and
            ``batch``.
        """
        # A literal ``{}`` default would be a single dict shared by every
        # instance created without options.
        options = {} if options is None else options
        super().__init__(options)

        self.do_post = options.get('do_post', False)
        self.parse_json = options.get('parse_json', True)
        self.format = options.get('format', 'json')
        self.base_url_option = options.get('base_url', '')
        self.api_version = options.get('api_version', self.DEFAULT_API_VERSION)
        self.batch = options.get('batch', False)

        self.user = user
        self.api_key = api_key
        self.base_url = self._generate_base_url(user, self.base_url_option)

        # Carto Context for DataFrame handling
        self._carto_context = None

        # Carto client for COPYs
        self._copy_client = None

        self._auth_client = APIKeyAuthClient(api_key=api_key,
                                             base_url=self.base_url)
        self._sql_client = SQLClient(self._auth_client,
                                     api_version=self.api_version)

        # The batch client is only created when the ``batch`` option is set.
        self._batch_client = None
        if self.batch:
            self._batch_client = BatchSQLClient(self._auth_client)

    @property
    def cc(self):
        """
        Creates and returns a CartoContext object to work with Pandas DataFrames.
        The context is created on first access and reused afterwards.
        :return: the cached CartoContext
        """
        # TODO: The CartoContext documentation says that SSL must be disabled sometimes if an
        #  on-premises host is used.
        #  We are not taking this into account. It would need to create a requests.Session()
        #  object, set its SSL to false and pass it to the CartoContext init.
        if self._carto_context is None:
            self._carto_context = cartoframes.CartoContext(
                base_url=self.base_url, api_key=self.api_key)
        return self._carto_context

    def _generate_base_url(self, user, base_url_option):
        """Build the on-premises URL when a host is given, else the subdomain URL."""
        if base_url_option:
            base_url = self.ON_PREMISES_URL_PATTERN % (base_url_option, user)
        else:
            base_url = self.SUBDOMAIN_URL_PATTERN % user
        return base_url

    @staticmethod
    def _quote_literal(value):
        """Render ``value`` as a SQL literal: NULL for None, else a quoted string."""
        if value is None:
            return 'NULL'
        # Doubling single quotes keeps a quote inside the value from closing
        # the literal early.
        return "'" + str(value).replace("'", "''") + "'"

    def execute_query(self, query_template, params, query_config, **opts):
        """Render ``query_template`` with quoted ``params`` and send it.

        :param query_template: query using ``%(name)s`` placeholders.
        :param params: mapping of placeholder names to values (may be None).
        :param query_config: not used by this method.
        :raises LongitudeQueryCannotBeExecutedException: when the SQL client
            raises a ``CartoException``.
        """
        # TODO: Here we are parsing the parameters and taking responsibility for it. Values are
        #  rendered as quoted literals on the client side because this is used in a
        #  backend-to-backend context and we build our own queries.
        #  ---
        #  Can we use the .mogrify method in psycopg2 to render a query as it is going to be
        #  executed ? -> NO
        #   ->  .mogrify is a cursor method but in CARTO connections we lack a cursor.
        #  ---
        #  There is an open issue in CARTO about having separated parameters and binding them in
        #  the server:
        #   https://github.com/CartoDB/Geographica-Product-Coordination/issues/57
        params = {
            k: self._quote_literal(v) for k, v in (params or {}).items()
        }
        formatted_query = query_template % params

        try:
            return self._sql_client.send(formatted_query,
                                         parse_json=self.parse_json,
                                         do_post=self.do_post,
                                         format=self.format)

        except CartoException as e:
            # Chain the original error so its traceback is not lost.
            raise LongitudeQueryCannotBeExecutedException(str(e)) from e

    def parse_response(self, response):
        """Wrap the rows, fields and timing of a SQL API response."""
        return LongitudeQueryResponse(rows=response['rows'],
                                      fields=response['fields'],
                                      meta={
                                          'response_time':
                                          response.get('time'),
                                          'total_rows':
                                          response.get('total_rows')
                                      })

    def copy_from(self, data, filepath, to_table):
        """Load a binary CSV stream into ``to_table`` with a COPY FROM query.

        The first line of ``data`` supplies the column list; the stream is
        then rewound so the whole content, header included, is sent.
        ``filepath`` is not used by this method.
        """
        if self._copy_client is None:
            from carto.sql import CopySQLClient
            self._copy_client = CopySQLClient(self._auth_client)
        # 'utf-8-sig' drops a leading byte-order mark that would otherwise
        # become part of the first column name; strip() removes the line
        # terminator from the column list.
        headers = data.readline().decode('utf-8-sig').strip()
        data.seek(0)
        from_query = 'COPY %s (%s) FROM stdin WITH (FORMAT csv, HEADER true)' % (
            to_table, headers)
        return self._copy_client.copyfrom_file_object(from_query, data)

    # In the three wrappers below the leading arguments are passed
    # positionally: passing them by keyword ahead of ``*args`` made any extra
    # positional argument collide with them ("got multiple values").

    def read_dataframe(self, table_name='', *args, **kwargs):
        """Forward to ``CartoContext.read``."""
        return self.cc.read(table_name, *args, **kwargs)

    def query_dataframe(self, query='', *args, **kwargs):
        """Forward to ``CartoContext.query``."""
        return self.cc.query(query, *args, **kwargs)

    def write_dataframe(self, df, table_name='', *args, **kwargs):
        """Forward to ``CartoContext.write``."""
        return self.cc.write(df, table_name, *args, **kwargs)
Example #9
0
                          ' https://username.carto.com/'
                          ' (defaults to env variable CARTO_API_URL)'))

parser.add_argument('--api_key', dest='CARTO_API_KEY',
                    default=os.environ.get('CARTO_API_KEY', ''),
                    help=('Api key of the account'
                          ' (defaults to env variable CARTO_API_KEY)'))

args = parser.parse_args()

# Both the base URL and the API key are required. print_usage() returns None,
# so passing its result to sys.exit() ended the script with status 0; print
# the usage and exit with a non-zero status instead.
if not args.CARTO_BASE_URL or not args.CARTO_API_KEY:
    parser.print_usage()
    sys.exit(1)

# The SQL and COPY clients share one auth client.
auth_client = APIKeyAuthClient(args.CARTO_BASE_URL, args.CARTO_API_KEY)
sql_client = SQLClient(auth_client)
copy_client = CopySQLClient(auth_client)

# Create a table suitable to receive the data
logger.info('Creating table nexrad_copy_example...')
sql_client.send("""CREATE TABLE IF NOT EXISTS nexrad_copy_example (
  the_geom geometry(Geometry,4326),
  reflectivity numeric
)""")
sql_client.send(
    "SELECT CDB_CartodbfyTable(current_schema, 'nexrad_copy_example')")
logger.info('Done')

logger.info('Trying to connect to the THREDDS radar query service')
rs = RadarServer(
    'http://thredds.ucar.edu/thredds/radarServer/nexrad/level2/IDD/')
Example #10
0
def copy_client(api_key_auth_client_usr):
    """Return a ``CopySQLClient`` built on the given auth client."""
    client = CopySQLClient(api_key_auth_client_usr)
    return client
    logger.error('You need to provide valid credentials, run with '
                 '-h parameter for details')
    sys.exit(1)

# Create the `copy_example` table (only if it does not exist yet) through the
# SQL API client, then call CDB_CartodbfyTable on it.
sqlClient = SQLClient(auth_client)
sqlClient.send("""
    CREATE TABLE IF NOT EXISTS copy_example (
      the_geom geometry(Geometry,4326),
      name text,
      age integer
    )
    """)
sqlClient.send("SELECT CDB_CartodbfyTable(current_schema, 'copy_example')")

# The COPY client shares the auth client used for the SQL client above.
copyClient = CopySQLClient(auth_client)

# COPY FROM example: load a local CSV file (with a header row) into the
# three columns named in the query, and log the returned result.
logger.info("COPY'ing FROM file...")
query = ('COPY copy_example (the_geom, name, age) '
         'FROM stdin WITH (FORMAT csv, HEADER true)')
result = copyClient.copyfrom_file_path(query, 'files/copy_from.csv')
logger.info('result = %s' % result)

# COPY TO example with pandas DataFrame: the stream returned by
# copyto_stream() is handed directly to pd.read_csv(), and the first rows of
# the resulting DataFrame are logged.
logger.info("COPY'ing TO pandas DataFrame...")
query = 'COPY copy_example TO stdout WITH (FORMAT csv, HEADER true)'
result = copyClient.copyto_stream(query)
df = pd.read_csv(result)
logger.info(df.head())
Example #12
0
# This is a bit of a trick: we omit the sequences to avoid
# dependencies on other objects. Normally this just affects the
# cartodb_id and can optionally be fixed by cartodbfy'ing.
# The substitution replaces every `DEFAULT nextval(...)` clause of the
# CREATE TABLE statement with a single space.
create_table_no_seqs = re.sub(r'DEFAULT nextval\([^\)]+\)', ' ', create_table)
logger.info(create_table_no_seqs)

# Create the table in the destination account
logger.info('Creating the table in the destination account...')
res = sql_dst_client.send(create_table_no_seqs)
logger.info('Response: {}'.format(res))

# Cartodbfy the table (this is optional)
logger.info("Cartodbfy'ing the destination table...")
res = sql_dst_client.send(
    "SELECT CDB_CartodbfyTable(current_schema, '%s')" % TABLE_NAME
)
logger.info('Response: {}'.format(res))

# Create COPY clients: one per account, each on its own auth client
copy_src_client = CopySQLClient(auth_src_client)
copy_dst_client = CopySQLClient(auth_dst_client)

# COPY (streaming) the data from the source to the dest table. We use
# here all the COPY defaults. Note that we take the `response` from
# the `copyto`, which can be iterated, and we pipe it directly into
# the `copyfrom`.
logger.info("Streaming the data from source to destination...")
response = copy_src_client.copyto('COPY %s TO STDOUT' % TABLE_NAME)
result = copy_dst_client.copyfrom('COPY %s FROM STDIN' % TABLE_NAME, response)
logger.info('Result: {}'.format(result))
Example #13
0
class CARTOUser(object):
    """Lazy wrapper around the CARTO auth, SQL, Batch SQL and COPY clients.

    The clients are created by :meth:`initialize`, which the other methods
    call on first use when the attribute they need does not exist yet.
    """

    def __init__(self,
                 user_name=None,
                 org_name=None,
                 api_url=None,
                 api_key=None,
                 check_ssl=True):
        """Store the connection settings; no client is created here.

        :param user_name: used to derive ``api_url`` when none is given.
        :param org_name: optional organization passed to the auth client.
        :param api_url: base URL of the API.
        :param api_key: API key passed to the auth client.
        :param check_ssl: when False, certificate verification is disabled.
        """
        self.user_name = user_name
        self.org_name = org_name
        self.api_url = api_url
        self.api_key = api_key

        if not check_ssl:
            # NOTE(review): this replaces requests.Session.request on the
            # class itself, so it affects every session in the process, not
            # only this user's clients.
            old_request = requests.Session.request
            requests.Session.request = partialmethod(old_request, verify=False)
            warnings.filterwarnings('ignore', 'Unverified HTTPS request')

    def initialize(self):
        """Create the auth client and the SQL, batch and COPY clients.

        :raises Exception: neither ``api_url`` nor ``user_name`` is set.
        """
        if not self.api_url and self.user_name:
            self.api_url = "https://{}.carto.com/api/".format(self.user_name)
        elif not self.api_url and not self.user_name:
            raise Exception(
                'Not enough data provided to initialize the client')

        if self.org_name:
            self.client = APIKeyAuthClient(self.api_url, self.api_key,
                                           self.org_name)
        else:
            self.client = APIKeyAuthClient(self.api_url, self.api_key)

        self.sql_client = SQLClient(self.client)
        self.batch_client = BatchSQLClient(self.client)
        self.copy_client = CopySQLClient(self.client)

    def _ensure_initialized(self, attribute):
        """Call :meth:`initialize` when ``attribute`` has not been set yet."""
        if not hasattr(self, attribute):
            self.initialize()

    @staticmethod
    def _error_message(error):
        """Extract the innermost message of a nested API error.

        Falls back to ``str(error)`` when the error does not have the
        expected nested shape, instead of failing with an unrelated
        AttributeError or IndexError.
        """
        try:
            inner = error.args[0].args[0]
            if isinstance(inner, (list, tuple)):
                return inner[0]
            return inner
        except (AttributeError, IndexError, KeyError, TypeError):
            return str(error)

    def execute_sql(self, query, parse_json=True, format=None, do_post=False):
        """Send ``query`` through the SQL client and return its response.

        :raises Exception: with the message of the underlying
            ``CartoException`` when the request fails.
        """
        try:
            self._ensure_initialized('client')
            return self.sql_client.send(query,
                                        parse_json=parse_json,
                                        format=format,
                                        do_post=do_post)
        except CartoException as e:
            raise Exception(self._error_message(e)) from e

    def batch_check(self, job_id):
        """Read the batch job ``job_id``."""
        self._ensure_initialized('batch_client')
        return self.batch_client.read(job_id)

    def batch_create(self, query):
        """Create a batch job for ``query``."""
        self._ensure_initialized('batch_client')
        return self.batch_client.create(query)

    def batch_cancel(self, job_id):
        """Cancel the batch job ``job_id``."""
        self._ensure_initialized('batch_client')
        return self.batch_client.cancel(job_id)

    def get_dataset_manager(self):
        """Return a new ``DatasetManager`` on the auth client."""
        self._ensure_initialized('sql_client')
        return DatasetManager(self.client)

    def get_sync_manager(self):
        """Return a new ``SyncTableJobManager`` on the auth client."""
        self._ensure_initialized('sql_client')
        return SyncTableJobManager(self.client)

    def upload(self, uri, sync_time=None):
        """Create a dataset from ``uri``; ``sync_time`` is passed when truthy."""
        self._ensure_initialized('sql_client')

        dataset_manager = DatasetManager(self.client)

        if sync_time:
            return dataset_manager.create(uri, sync_time)
        else:
            return dataset_manager.create(uri)

    def copy_from(self, path, query, tablename=None, delimiter=','):
        """Load the file at ``path`` through COPY FROM.

        :param path: path of the local file.
        :param query: COPY query to run; when None, one is built from the
            header line of the file and only the remaining lines are sent.
        :param tablename: destination table for the generated query;
            defaults to the file name without its suffix.
        :param delimiter: field delimiter of the file (generated query only).
        """
        self._ensure_initialized('copy_client')

        if tablename is None:
            tablename = Path(path).stem

        if query is None:
            with open(path, 'rb') as myfile:
                # 'utf-8-sig' drops a leading byte-order mark, which would
                # otherwise become part of the first column name.
                headers = next(myfile).strip().decode('utf-8-sig')
                # The column list is SQL syntax and must be comma separated
                # whatever the delimiter of the file is.
                headers = ','.join(headers.split(delimiter))
                query = f"""COPY {tablename} ({headers}) FROM stdin
                (FORMAT CSV, DELIMITER '{delimiter}', HEADER false, QUOTE '"')"""
                return self.copy_client.copyfrom_file_object(query, myfile)
        return self.copy_client.copyfrom_file_path(query, path)

    def copy_to(self, query, output, delimiter=','):
        """Export the result of ``query`` as CSV with a header to ``output``."""
        self._ensure_initialized('copy_client')

        copy_query = f"""COPY ({query}) TO stdout WITH
        (FORMAT CSV, DELIMITER '{delimiter}', HEADER true, QUOTE '"')"""

        return self.copy_client.copyto_file_path(copy_query, output)
Example #14
0
# This is a bit of a trick: we omit the sequences to avoid
# dependencies on other objects. Normally this just affects the
# cartodb_id and can optionally be fixed by cartodbfy'ing.
# The substitution replaces every `DEFAULT nextval(...)` clause of the
# CREATE TABLE statement with a single space.
create_table_no_seqs = re.sub(r'DEFAULT nextval\([^\)]+\)', ' ', create_table)
logger.info(create_table_no_seqs)

# Create the table in the destination account
logger.info('Creating the table in the destination account...')
res = sql_dst_client.send(create_table_no_seqs)
logger.info('Response: {}'.format(res))

# Cartodbfy the table (this is optional)
logger.info("Cartodbfy'ing the destination table...")
res = sql_dst_client.send("SELECT CDB_CartodbfyTable(current_schema, '%s')" %
                          TABLE_NAME)
logger.info('Response: {}'.format(res))

# Create COPY clients: one per account, each on its own auth client
copy_src_client = CopySQLClient(auth_src_client)
copy_dst_client = CopySQLClient(auth_dst_client)

# COPY (streaming) the data from the source to the dest table. We use
# here all the COPY defaults. Note that we take the `response` from
# the `copyto`, which can be iterated, and we pipe it directly into
# the `copyfrom`.
logger.info("Streaming the data from source to destination...")
response = copy_src_client.copyto('COPY %s TO STDOUT' % TABLE_NAME)
result = copy_dst_client.copyfrom('COPY %s FROM STDIN' % TABLE_NAME, response)
logger.info('Result: {}'.format(result))
Example #15
0
    logger.error('You need to provide valid credentials, run with '
                 '-h parameter for details')
    sys.exit(1)

# Create the `copy_example` table (only if it does not exist yet) through the
# SQL API client, then call CDB_CartodbfyTable on it.
sqlClient = SQLClient(auth_client)
sqlClient.send("""
    CREATE TABLE IF NOT EXISTS copy_example (
      the_geom geometry(Geometry,4326),
      name text,
      age integer
    )
    """)
sqlClient.send("SELECT CDB_CartodbfyTable(current_schema, 'copy_example')")

# The COPY client shares the auth client used for the SQL client above.
copyClient = CopySQLClient(auth_client)

# COPY FROM example: load a local CSV file (with a header row) into the
# three columns named in the query, and log the returned result.
logger.info("COPY'ing FROM file...")
query = ('COPY copy_example (the_geom, name, age) '
         'FROM stdin WITH (FORMAT csv, HEADER true)')
result = copyClient.copyfrom_file_path(query, 'files/copy_from.csv')
logger.info('result = %s' % result)

# COPY TO example: export the whole table as CSV (with a header row) to a
# local file.
query = 'COPY copy_example TO stdout WITH (FORMAT csv, HEADER true)'
output_file = 'files/copy_export.csv'
copyClient.copyto_file_path(query, output_file)
logger.info('Table copied to %s' % output_file)

# Truncate the table to make this example repeatable
    logger.error('You need to provide valid credentials, run with '
                 '-h parameter for details')
    sys.exit(1)

# Create the `copy_example` table (only if it does not exist yet) through the
# SQL API client, then call CDB_CartodbfyTable on it.
sqlClient = SQLClient(auth_client)
sqlClient.send("""
    CREATE TABLE IF NOT EXISTS copy_example (
      the_geom geometry(Geometry,4326),
      name text,
      age integer
    )
    """)
sqlClient.send("SELECT CDB_CartodbfyTable(current_schema, 'copy_example')")

# The COPY client shares the auth client used for the SQL client above.
copyClient = CopySQLClient(auth_client)

# COPY FROM example: load a local CSV file (with a header row) into the
# three columns named in the query, and log the returned result.
logger.info("COPY'ing FROM file...")
query = ('COPY copy_example (the_geom, name, age) '
         'FROM stdin WITH (FORMAT csv, HEADER true)')
result = copyClient.copyfrom_file_path(query, 'files/copy_from.csv')
logger.info('result = %s' % result)

# COPY TO example with pandas DataFrame: the stream returned by
# copyto_stream() is handed directly to pd.read_csv(), and the first rows of
# the resulting DataFrame are logged.
logger.info("COPY'ing TO pandas DataFrame...")
query = 'COPY copy_example TO stdout WITH (FORMAT csv, HEADER true)'
result = copyClient.copyto_stream(query)
df = pd.read_csv(result)
logger.info(df.head())
Example #17
0
class ContextManager:
    def __init__(self, credentials):
        """Validate the credentials and build the CARTO API clients.

        Args:
            credentials: credentials object; when falsy, the value returned by
                ``get_default_credentials()`` is used instead.
        """
        resolved = credentials or get_default_credentials()
        self.credentials = resolved
        check_credentials(resolved)

        # Every client below shares the same auth client.
        auth = _create_auth_client(resolved)
        self.auth_client = auth
        self.sql_client = SQLClient(auth)
        self.copy_client = CopySQLClient(auth)
        self.batch_sql_client = BatchSQLClient(auth)

    @not_found
    def execute_query(self,
                      query,
                      parse_json=True,
                      do_post=True,
                      format=None,
                      **request_args):
        """Send ``query`` (stripped of surrounding whitespace) via the SQL client.

        ``parse_json``, ``do_post`` and ``format`` are forwarded positionally to
        ``sql_client.send``; any extra keyword arguments are forwarded as-is.

        Returns:
            Whatever ``sql_client.send`` returns.
        """
        send = self.sql_client.send
        statement = query.strip()
        return send(statement, parse_json, do_post, format, **request_args)

    @not_found
    def execute_long_running_query(self, query):
        """Run ``query`` (stripped of surrounding whitespace) via the batch client.

        Returns:
            Whatever ``batch_sql_client.create_and_wait_for_completion`` returns.
        """
        batch_client = self.batch_sql_client
        return batch_client.create_and_wait_for_completion(query.strip())

    def copy_to(self,
                source,
                schema=None,
                limit=None,
                retry_times=DEFAULT_RETRY_TIMES):
        """Download a table or query result through COPY TO.

        Args:
            source: table name or SQL query, resolved with ``compute_query``.
            schema: optional schema forwarded to ``compute_query``.
            limit: optional row limit forwarded to ``_get_copy_query``.
            retry_times: forwarded to ``_copy_to``.

        Returns:
            The value returned by ``_copy_to``.
        """
        base_query = self.compute_query(source, schema)
        columns_info = self._get_query_columns_info(base_query)
        return self._copy_to(
            self._get_copy_query(base_query, columns_info, limit),
            columns_info,
            retry_times)

    def copy_from(self,
                  gdf,
                  table_name,
                  if_exists='fail',
                  cartodbfy=True,
                  retry_times=DEFAULT_RETRY_TIMES):
        """Upload a dataframe into a table through COPY FROM.

        Args:
            gdf: dataframe to upload.
            table_name: destination table name; it is normalized first.
            if_exists: what to do when the table already exists. ``'fail'``
                raises, ``'replace'`` empties the table (adapting its columns
                when they differ from the dataframe's), and any other value
                uploads into the table as it is.
            cartodbfy: forwarded to the table create/truncate helpers.
            retry_times: forwarded to ``_copy_from``.

        Returns:
            The normalized table name.

        Raises:
            Exception: the table exists and ``if_exists`` is ``'fail'``.
        """
        schema = self.get_schema()
        table_name = self.normalize_table_name(table_name)
        df_columns = get_dataframe_columns_info(gdf)

        if not self.has_table(table_name, schema):
            self._create_table_from_columns(table_name, schema, df_columns,
                                            cartodbfy)
        elif if_exists == 'fail':
            raise Exception(
                'Table "{schema}.{table_name}" already exists in your CARTO account. '
                'Please choose a different `table_name` or use '
                'if_exists="replace" to overwrite it.'.format(
                    table_name=table_name, schema=schema))
        elif if_exists == 'replace':
            existing_columns = self._get_query_columns_info(
                self._compute_query_from_table(table_name, schema))

            if self._compare_columns(df_columns, existing_columns):
                # Same columns: emptying the table is enough.
                self._truncate_table(table_name, schema, cartodbfy)
            else:
                # Different columns: empty the table and rebuild its columns.
                self._truncate_and_drop_add_columns(
                    table_name, schema, df_columns, existing_columns,
                    cartodbfy)
        # Any other value ('append'): the existing table is left untouched.

        self._copy_from(gdf, table_name, df_columns, retry_times)
        return table_name

    def create_table_from_query(self,
                                query,
                                table_name,
                                if_exists,
                                cartodbfy=True):
        """Create a table from the result of ``query``.

        Args:
            query: SQL query whose result populates the table.
            table_name: destination table name; it is normalized first.
            if_exists: what to do when the table already exists. ``'fail'``
                raises, ``'replace'`` drops and recreates it, and any other
                value leaves the existing table unchanged.
            cartodbfy: forwarded to ``_drop_create_table_from_query``.

        Returns:
            The normalized table name.

        Raises:
            Exception: the table exists and ``if_exists`` is ``'fail'``.
        """
        schema = self.get_schema()
        table_name = self.normalize_table_name(table_name)

        already_there = self.has_table(table_name, schema)

        if already_there and if_exists == 'fail':
            raise Exception(
                'Table "{schema}.{table_name}" already exists in your CARTO account. '
                'Please choose a different `table_name` or use '
                'if_exists="replace" to overwrite it.'.format(
                    table_name=table_name, schema=schema))

        # TODO: review logic copy_from
        # A missing table is always created; an existing one is rebuilt only
        # on 'replace'. Other values ('append') do nothing here.
        if not already_there or if_exists == 'replace':
            self._drop_create_table_from_query(table_name, schema, query,
                                               cartodbfy)

        return table_name

    def list_tables(self, schema=None):
        """List the account's dataset names, most recently updated first.

        Args:
            schema: accepted but not used by this method.

        Returns:
            A pandas DataFrame with a single ``tables`` column.
        """
        # Every optional section of the dataset listing is switched off.
        disabled_flags = (
            'show_table_size_and_row_count',
            'show_table',
            'show_stats',
            'show_likes',
            'show_liked',
            'show_permission',
            'show_uses_builder_features',
            'show_synchronization',
            'load_totals',
        )
        filters = {flag: 'false' for flag in disabled_flags}

        datasets = DatasetManager(self.auth_client).filter(**filters)
        datasets.sort(key=lambda dataset: dataset.updated_at, reverse=True)

        names = [dataset.name for dataset in datasets]
        return pd.DataFrame(names, columns=['tables'])

    def has_table(self, table_name, schema=None):
        query = self.compute_query(table_name, schema)
        return self._check_exists(query)

    def delete_table(self, table_name):
        """Drop ``table_name`` and report whether a table was actually removed.

        The drop statement comes from ``_drop_table_query``. The response is
        inspected for a "does not exist" notice, which signals that there was
        nothing to delete.

        Returns:
            ``False`` when the first notice of the response contains
            ``'does not exist'``; ``True`` otherwise, including when the
            response carries no notices or an empty notices list.
        """
        output = self.execute_query(_drop_table_query(table_name))

        if 'notices' not in output:
            return True

        notices = output['notices']
        # An empty (or null) notices entry means nothing was reported;
        # indexing it blindly would raise instead of answering.
        if not notices:
            return True

        return 'does not exist' not in notices[0]

    def rename_table(self, table_name, new_table_name, if_exists='fail'):
        """Rename ``table_name`` to the normalized form of ``new_table_name``.

        Args:
            table_name: name of the existing table.
            new_table_name: desired name; it is normalized before use.
            if_exists: what to do when the new name is already taken.
                ``'replace'`` deletes the existing table first and ``'fail'``
                raises; any other value goes straight to the rename.

        Returns:
            The normalized new table name.

        Raises:
            ValueError: both names are equal after normalizing the new one.
            Exception: the source table does not exist, or the new name is
                taken and ``if_exists`` is ``'fail'``.
        """
        new_table_name = self.normalize_table_name(new_table_name)

        if table_name == new_table_name:
            raise ValueError(
                'Table names are equal. Please choose a different table name.')

        source_exists = self.has_table(table_name)
        if not source_exists:
            raise Exception(
                'Table "{table_name}" does not exist in your CARTO account.'.
                format(table_name=table_name))

        target_exists = self.has_table(new_table_name)
        if target_exists and if_exists == 'replace':
            log.debug('Removing table "{}"'.format(new_table_name))
            self.delete_table(new_table_name)
        elif target_exists and if_exists == 'fail':
            raise Exception(
                'Table "{new_table_name}" already exists in your CARTO account. '
                'Please choose a different `new_table_name` or use '
                'if_exists="replace" to overwrite it.'.format(
                    new_table_name=new_table_name))

        self._rename_table(table_name, new_table_name)
        return new_table_name

    def update_privacy_table(self, table_name, privacy=None):
        """Call ``update_privacy(privacy)`` on the table's ``DatasetInfo``."""
        dataset_info = DatasetInfo(self.auth_client, table_name)
        dataset_info.update_privacy(privacy)

    def get_privacy(self, table_name):
        """Return the ``privacy`` attribute of the table's ``DatasetInfo``."""
        dataset_info = DatasetInfo(self.auth_client, table_name)
        return dataset_info.privacy

    def get_schema(self):
        """Get user schema from current credentials.

        Runs ``SELECT current_schema()`` (with ``do_post=False``) and returns
        the ``current_schema`` value of the first row.
        """
        response = self.execute_query('SELECT current_schema()', do_post=False)
        first_row = response['rows'][0]
        current = first_row['current_schema']
        log.debug('schema: {}'.format(current))
        return current

    def get_geom_type(self, query):
        """Fetch geom type of a remote table or query"""
        distict_query = '''
            SELECT distinct ST_GeometryType(the_geom) AS geom_type
            FROM ({}) q
            LIMIT 5
        '''.format(query)
        response = self.execute_query(distict_query, do_post=False)
        if response and response.get('rows') and len(response.get('rows')) > 0:
            st_geom_type = response.get('rows')[0].get('geom_type')
            if st_geom_type:
                return map_geom_type(st_geom_type[3:])
        return None

    def get_num_rows(self, query):
        """Get the number of rows in the query"""
        result = self.execute_query(
            'SELECT COUNT(*) FROM ({query}) _query'.format(query=query))
        return result.get('rows')[0].get('count')

    def get_bounds(self, query):
        extent_query = '''
            SELECT ARRAY[
                ARRAY[st_xmin(geom_env), st_ymin(geom_env)],
                ARRAY[st_xmax(geom_env), st_ymax(geom_env)]
            ] bounds FROM (
                SELECT ST_Extent(the_geom) geom_env
                FROM ({}) q
            ) q;
        '''.format(query)
        response = self.execute_query(extent_query, do_post=False)
        if response and response.get('rows') and len(response.get('rows')) > 0:
            return response.get('rows')[0].get('bounds')
        return None

    def get_column_names(self, source, schema=None, exclude=None):
        query = self.compute_query(source, schema)
        columns = [c.name for c in self._get_query_columns_info(query)]

        if exclude and isinstance(exclude, list):
            columns = list(set(columns) - set(exclude))

        return columns

    def is_public(self, query):
        """Tell whether ``query`` can be explained with a public auth client.

        Used to detect public tables in queries in the publication, because
        privacy only works for tables.

        Returns:
            ``True`` when the ``EXPLAIN`` succeeds, ``False`` when it raises
            ``CartoException``.
        """
        public_client = SQLClient(
            _create_auth_client(self.credentials, public=True))
        explain_query = 'EXPLAIN {}'.format(query)
        try:
            public_client.send(explain_query, do_post=False)
        except CartoException:
            return False
        else:
            return True

    def get_table_names(self, query):
        # Used to detect tables in queries in the publication.
        query = 'SELECT CDB_QueryTablesText($q${}$q$) as tables'.format(query)
        result = self.execute_query(query)
        tables = []
        if result['total_rows'] > 0 and result['rows'][0]['tables']:
            # Dataset_info only works with tables without schema
            tables = [
                table.split('.')[1] if '.' in table else table
                for table in result['rows'][0]['tables']
            ]
        return tables

    def _compare_columns(self, a, b):
        a_copy = [i for i in a if _not_reserved(i.name)]
        b_copy = [i for i in b if _not_reserved(i.name)]

        a_copy.sort()
        b_copy.sort()

        return a_copy == b_copy

    def _drop_create_table_from_query(self, table_name, schema, query,
                                      cartodbfy):
        """Drop ``table_name`` and recreate it from ``query`` in one transaction.

        The cartodbfy statement is included only when ``cartodbfy`` is truthy;
        otherwise its slot in the transaction is left empty.
        """
        log.debug('DROP + CREATE table "{}"'.format(table_name))

        drop_sql = _drop_table_query(table_name)
        create_sql = _create_table_from_query_query(table_name, query)
        if cartodbfy:
            cartodbfy_sql = _cartodbfy_query(table_name, schema)
        else:
            cartodbfy_sql = ''

        transaction = 'BEGIN; {drop}; {create}; {cartodbfy}; COMMIT;'.format(
            drop=drop_sql, create=create_sql, cartodbfy=cartodbfy_sql)
        self.execute_long_running_query(transaction)

    def _create_table_from_columns(self, table_name, schema, columns,
                                   cartodbfy):
        """Create ``table_name`` with the given ``columns`` in one transaction.

        The cartodbfy statement is included only when ``cartodbfy`` is truthy;
        otherwise its slot in the transaction is left empty.
        """
        log.debug('CREATE table "{}"'.format(table_name))

        create_sql = _create_table_from_columns_query(table_name, columns)
        if cartodbfy:
            cartodbfy_sql = _cartodbfy_query(table_name, schema)
        else:
            cartodbfy_sql = ''

        transaction = 'BEGIN; {create}; {cartodbfy}; COMMIT;'.format(
            create=create_sql, cartodbfy=cartodbfy_sql)
        self.execute_long_running_query(transaction)

    def _truncate_table(self, table_name, schema, cartodbfy):
        """Truncate ``table_name`` in one transaction.

        The cartodbfy statement is included only when ``cartodbfy`` is truthy;
        otherwise its slot in the transaction is left empty.
        """
        log.debug('TRUNCATE table "{}"'.format(table_name))

        truncate_sql = _truncate_table_query(table_name)
        if cartodbfy:
            cartodbfy_sql = _cartodbfy_query(table_name, schema)
        else:
            cartodbfy_sql = ''

        transaction = 'BEGIN; {truncate}; {cartodbfy}; COMMIT;'.format(
            truncate=truncate_sql, cartodbfy=cartodbfy_sql)
        self.execute_long_running_query(transaction)

    def _truncate_and_drop_add_columns(self, table_name, schema, df_columns,
                                       table_columns, cartodbfy):
        """Empty ``table_name`` and swap its columns for the dataframe's.

        The statement list is: an optional regenerate call (only when
        ``_check_regenerate_table_exists`` says the function is available),
        then a transaction that truncates the table, drops ``table_columns``,
        adds ``df_columns`` and, when ``cartodbfy`` is truthy, cartodbfies it.
        """
        log.debug(
            'TRUNCATE AND DROP + ADD columns table "{}"'.format(table_name))

        if self._check_regenerate_table_exists():
            regenerate_sql = _regenerate_table_query(table_name, schema)
        else:
            regenerate_sql = ''
        truncate_sql = _truncate_table_query(table_name)
        drop_columns_sql = _drop_columns_query(table_name, table_columns)
        add_columns_sql = _add_columns_query(table_name, df_columns)
        if cartodbfy:
            cartodbfy_sql = _cartodbfy_query(table_name, schema)
        else:
            cartodbfy_sql = ''

        statements = '{regenerate}; BEGIN; {truncate}; {drop_columns}; {add_columns}; {cartodbfy}; COMMIT;'.format(
            regenerate=regenerate_sql,
            truncate=truncate_sql,
            drop_columns=drop_columns_sql,
            add_columns=add_columns_sql,
            cartodbfy=cartodbfy_sql)
        self.execute_long_running_query(statements)

    def compute_query(self, source, schema=None):
        """Turn ``source`` into a SQL query.

        A ``source`` that ``is_sql_query`` accepts is returned unchanged.
        Anything else is treated as a table name and expanded with
        ``_compute_query_from_table``, using ``schema`` or, when it is falsy,
        the schema returned by ``get_schema``.
        """
        if not is_sql_query(source):
            effective_schema = schema or self.get_schema()
            return self._compute_query_from_table(source, effective_schema)
        return source

    def _compute_query_from_table(self, table_name, schema):
        return 'SELECT * FROM "{schema}"."{table_name}"'.format(
            schema=schema or 'public', table_name=table_name)

    def _check_exists(self, query):
        exists_query = 'EXPLAIN {}'.format(query)
        try:
            self.execute_query(exists_query, do_post=False)
            return True
        except CartoException:
            return False

    def _check_regenerate_table_exists(self):
        query = '''
            SELECT 1
            FROM pg_catalog.pg_proc p
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
            WHERE p.proname = 'cdb_regeneratetable' AND n.nspname = 'cartodb';
        '''
        result = self.execute_query(query)
        return len(result['rows']) > 0

    def _get_query_columns_info(self, query):
        """Return column information for ``query`` without fetching rows.

        The query is wrapped in a ``LIMIT 0`` select and the ``fields`` entry
        of the response is passed to ``get_query_columns_info``.
        """
        empty_result_query = 'SELECT * FROM ({}) _q LIMIT 0'.format(query)
        fields = self.execute_query(empty_result_query)['fields']
        return get_query_columns_info(fields)

    def _get_copy_query(self, query, columns, limit):
        query_columns = [
            double_quote(column.name) for column in columns
            if (column.name != 'the_geom_webmercator')
        ]

        query = 'SELECT {columns} FROM ({query}) _q'.format(
            query=query, columns=','.join(query_columns))

        if limit is not None:
            if isinstance(limit, int) and (limit >= 0):
                query += ' LIMIT {limit}'.format(limit=limit)
            else:
                raise ValueError("`limit` parameter must an integer >= 0")

        return query

    @retry_copy
    def _copy_to(self, query, columns, retry_times=DEFAULT_RETRY_TIMES):
        """Run ``query`` through COPY TO and load the CSV into a DataFrame.

        Args:
            query: SQL query to export.
            columns: column descriptions used to derive the CSV converters
                and the date columns to parse.
            retry_times: not read in this body; presumably consumed by the
                ``retry_copy`` decorator (confirm against its definition).

        Returns:
            The pandas DataFrame read from the COPY stream.
        """
        log.debug('COPY TO')
        copy_statement = "COPY ({0}) TO stdout WITH (FORMAT csv, HEADER true, NULL '{1}')".format(
            query, PG_NULL)

        stream = self.copy_client.copyto_stream(copy_statement)

        return pd.read_csv(stream,
                           converters=obtain_converters(columns),
                           parse_dates=date_columns_names(columns))

    @retry_copy
    def _copy_from(self,
                   dataframe,
                   table_name,
                   columns,
                   retry_times=DEFAULT_RETRY_TIMES):
        """Upload ``dataframe`` into ``table_name`` through COPY FROM.

        Args:
            dataframe: data to upload; serialized by ``_compute_copy_data``.
            table_name: destination table.
            columns: column descriptions; each one's ``dbname`` is
                double-quoted and listed in the COPY statement.
            retry_times: not read in this body; presumably consumed by the
                ``retry_copy`` decorator (confirm against its definition).
        """
        log.debug('COPY FROM')
        quoted_columns = ','.join(
            double_quote(column.dbname) for column in columns)
        copy_statement = """
            COPY {table_name}({columns}) FROM stdin WITH (FORMAT csv, DELIMITER '|', NULL '{null}');
        """.format(table_name=table_name,
                   null=PG_NULL,
                   columns=quoted_columns).strip()
        payload = _compute_copy_data(dataframe, columns)

        self.copy_client.copyfrom(copy_statement, payload)

    def _rename_table(self, table_name, new_table_name):
        """Execute the statement built by ``_rename_table_query``."""
        self.execute_query(_rename_table_query(table_name, new_table_name))

    def normalize_table_name(self, table_name):
        """Return ``normalize_name(table_name)``.

        A debug message is logged when the normalized name differs from the
        one passed in.
        """
        normalized = normalize_name(table_name)
        if normalized == table_name:
            return normalized
        log.debug('Table name normalized: "{}"'.format(normalized))
        return normalized