示例#1
0
 def processAlgorithm(self, feedback):
     """Execute the user-supplied SQL against the selected SpatiaLite database.

     The DATABASE parameter may be a data source URI or a bare file path
     (optionally carrying a ``|layerid`` suffix); in the latter case a
     ``dbname='<path>'`` URI is built from it.  Newlines in the SQL are
     replaced by spaces before execution.

     :param feedback: processing feedback object (not used here).
     :raises GeoAlgorithmExecutionException: if executing the SQL fails.
     """
     database = self.getParameterValue(self.DATABASE)
     uri = QgsDataSourceUri(database)
     # The original ``is ""`` tested object identity, not emptiness, so the
     # bare-path fallback was not reliably taken.  Test the value instead.
     if not uri.database():
         if "|layerid" in database:
             # Drop the OGR layer selector to keep only the file path.
             database = database[: database.find("|layerid")]
         uri = QgsDataSourceUri("dbname='%s'" % (database))
     self.db = spatialite.GeoDB(uri)
     sql = self.getParameterValue(self.SQL).replace("\n", " ")
     try:
         self.db._exec_sql_and_commit(str(sql))
     except spatialite.DbError as e:
         raise GeoAlgorithmExecutionException(self.tr("Error executing SQL:\n%s") % str(e))
示例#2
0
 def processAlgorithm(self, progress):
     """Execute the user-supplied SQL against the selected SpatiaLite database.

     The DATABASE parameter may be a data source URI or a bare file path
     (optionally carrying a ``|layerid`` suffix); in the latter case a
     ``dbname='<path>'`` URI is built from it.  Newlines in the SQL are
     replaced by spaces before execution.

     :param progress: progress reporter (not used here).
     :raises GeoAlgorithmExecutionException: if executing the SQL fails.
     """
     database = self.getParameterValue(self.DATABASE)
     uri = QgsDataSourceUri(database)
     # ``is ''`` compared identity rather than value; check emptiness instead.
     if not uri.database():
         if '|layerid' in database:
             # Drop the OGR layer selector to keep only the file path.
             database = database[:database.find('|layerid')]
         uri = QgsDataSourceUri('dbname=\'%s\'' % (database))
     self.db = spatialite.GeoDB(uri)
     sql = self.getParameterValue(self.SQL).replace('\n', ' ')
     try:
         self.db._exec_sql_and_commit(str(sql))
     except spatialite.DbError as e:
         raise GeoAlgorithmExecutionException(
             self.tr('Error executing SQL:\n%s') % str(e))
示例#3
0
    def copy(self, target_path, copied_files, keep_existent=False):
        """
        Copy a layer's data files to a new folder and adjust its datasource.

        Two kinds of file-backed sources are handled: those whose file path is
        the part of ``layer.source()`` before the first ``|`` (e.g. Shapefiles
        and GeoPackages), and those whose path is the ``dbname`` of the data
        source URI (e.g. Spatialite).  Non-file layers are left untouched.

        :param target_path: A path to a folder into which the data will be copied
        :param copied_files: Passed through unchanged and returned to the caller
        :param keep_existent: if truthy and a target file already exists, keep it as it is
        :return: ``copied_files`` on every code path
        """
        if not self.is_file:
            # Copy will also be called on non-file layers like WMS. Do nothing,
            # but still hand back ``copied_files`` so callers that reassign the
            # return value do not end up with ``None``.
            return copied_files

        def copy_file_group(file_path):
            """Copy each ``basename + ext`` file reported by get_file_extension_group; return the file name."""
            source_path, file_name = os.path.split(file_path)
            basename, extensions = get_file_extension_group(file_name)
            for ext in extensions:
                src_file = os.path.join(source_path, basename + ext)
                dest_file = os.path.join(target_path, basename + ext)
                if os.path.exists(src_file) and \
                        (not keep_existent or not os.path.isfile(dest_file)):
                    shutil.copy(src_file, dest_file)
            return file_name

        # Shapefiles and GeoPackages have the path in the source
        uri_parts = self.layer.source().split('|', 1)
        file_path = uri_parts[0]
        layer_name_suffix = uri_parts[1] if len(uri_parts) > 1 else ''

        if os.path.isfile(file_path):
            new_source = os.path.join(target_path, copy_file_group(file_path))
            if layer_name_suffix:
                new_source = new_source + '|' + layer_name_suffix
            self._change_data_source(new_source)
        else:
            # Spatialite has the path in the dbname part of the uri
            uri = QgsDataSourceUri(self.layer.dataProvider().dataSourceUri())
            file_path = uri.database()
            if os.path.isfile(file_path):
                uri.setDatabase(os.path.join(target_path, copy_file_group(file_path)))
                self._change_data_source(uri.uri())
        return copied_files
示例#4
0
    def processAlgorithm(self, parameters, context, feedback):
        """Execute the user-supplied SQL against the SpatiaLite database of a layer.

        The database is taken from the DATABASE vector layer's data source URI;
        if that URI has no ``dbname``, the source string itself (with any
        ``|layerid`` suffix removed) is used as the database path.  Newlines in
        the SQL are replaced by spaces before execution.

        :param parameters: algorithm parameter values.
        :param context: processing context used to resolve the parameters.
        :param feedback: processing feedback object (not used here).
        :return: an empty result dict.
        :raises GeoAlgorithmExecutionException: if executing the SQL fails.
        """
        database = self.parameterAsVectorLayer(parameters, self.DATABASE, context)
        databaseuri = database.dataProvider().dataSourceUri()
        uri = QgsDataSourceUri(databaseuri)
        # ``is ''`` compared identity rather than value; check emptiness instead.
        if not uri.database():
            if '|layerid' in databaseuri:
                # Drop the OGR layer selector to keep only the file path.
                databaseuri = databaseuri[:databaseuri.find('|layerid')]
            uri = QgsDataSourceUri('dbname=\'%s\'' % (databaseuri))
        db = spatialite.GeoDB(uri)
        sql = self.parameterAsString(parameters, self.SQL, context).replace('\n', ' ')
        try:
            db._exec_sql_and_commit(str(sql))
        except spatialite.DbError as e:
            raise GeoAlgorithmExecutionException(
                self.tr('Error executing SQL:\n{0}').format(str(e)))

        return {}
    def addConnectionConfig(cls, conn_name, uri):
        """Register a PostgreSQL connection in QgsSettings for db_manager.

        db_manager builds its list of connections from settings, so storing
        the connection here makes it available there.  Keys live under
        ``/PostgreSQL/connections/<conn_name>``: service, host, port and
        database are always written; username, password, authcfg and sslmode
        are written only when the parsed URI yields a truthy value for them.

        :param conn_name: connection name appended to the settings key.
        :param uri: data source URI, parsed with QgsDataSourceUri.
        """
        parsed = QgsDataSourceUri(uri)

        settings = QgsSettings()
        prefix = "/PostgreSQL/connections/" + conn_name

        mandatory = (
            ("/service", parsed.service()),
            ("/host", parsed.host()),
            ("/port", parsed.port()),
            ("/database", parsed.database()),
        )
        for suffix, value in mandatory:
            settings.setValue(prefix + suffix, value)

        optional = (
            ("/username", parsed.username()),
            ("/password", parsed.password()),
            ("/authcfg", parsed.authConfigId()),
            ("/sslmode", parsed.sslMode()),
        )
        for suffix, value in optional:
            if value:
                settings.setValue(prefix + suffix, value)
def getConnectionParameterFromDbLayer(layer: QgsMapLayer) -> Dict[str,str]:
    '''
    Get connection parameters from the layer datasource.

    :param layer: database-backed map layer. A ``providerType()`` of
        ``'postgres'`` is reported as dbType ``'postgis'``; any other
        provider is reported as ``'spatialite'``.
    :return: dict of connection and table parameters parsed from
        ``layer.source()``. Despite the annotation, not every value is
        converted to ``str`` (e.g. ``'type'`` is the raw ``wkbType()``
        value, while ``'estimatedmetadata'`` is explicitly stringified).
    '''
    dbType = 'postgis' if layer.providerType() == 'postgres' else 'spatialite'

    src = layer.source()
    try:
        uri = QgsDataSourceUri(src)
    except NameError:
        # QGIS 2 spelled the class ``QgsDataSourceURI``. Fall back only when
        # the QGIS 3 name is unavailable, instead of swallowing every error
        # (including KeyboardInterrupt) with a bare except.
        uri = QgsDataSourceURI(src)

    # TODO Use immutable namedtuple
    connectionParams = {
        'service' : uri.service(),
        'dbname' : uri.database(),
        'host' : uri.host(),
        'port': uri.port(),
        'user' : uri.username(),
        'password': uri.password(),
        'sslmode' : uri.sslMode(),
        'key': uri.keyColumn(),
        'estimatedmetadata' : str(uri.useEstimatedMetadata()),
        'checkPrimaryKeyUnicity' : '',
        'srid' : uri.srid(),
        'type': uri.wkbType(),
        'schema': uri.schema(),
        'table' : uri.table(),
        'geocol' : uri.geometryColumn(),
        'sql' : uri.sql(),
        'dbType': dbType
    }

    return connectionParams
示例#7
0
    def consolidateVectorLayer(self, layer):
        """Consolidate one vector layer into the project folder.

        The target folder is ``self.layerTreePath(layer)`` and is created if
        missing. Handling depends on the layer's provider:

        * ``ogr``: delegated to ``self._processGdalDatasource``.
        * ``gpx`` / ``delimitedtext`` / ``spatialite``: the source file is
          copied with ``self._copyLayerFiles`` and the layer source is
          rewritten to the copy, with ``self.baseDirectory`` replaced by ``.``.
        * ``memory``: always exported with ``self.exportVectorLayer``.
        * DB2, mssql, oracle, postgres, wfs: exported only when
          ``self.settings['exportRemote']`` is truthy.
        * anything else: logged as unsupported and left untouched.

        Exported layers are re-sourced to the written file using the ``ogr``
        provider.
        """
        newPath = self.layerTreePath(layer)
        if not os.path.isdir(newPath):
            os.makedirs(newPath)

        exportLayer = False

        providerType = layer.providerType()
        if providerType == 'ogr':
            self._processGdalDatasource(layer, newPath)
        elif providerType in ('gpx', 'delimitedtext'):
            layerFile, layerName = self._filePathFromUri(layer.source())
            self._copyLayerFiles(layerFile, newPath)
            newDirectory = newPath.replace(self.baseDirectory, '.')
            newSource = '{dirName}/{fileName}?{layer}'.format(dirName=newDirectory, fileName=os.path.split(layerFile)[1], layer=layerName)
            self.updateLayerSource(layer.id(), newSource)
        elif providerType == 'spatialite':
            uri = QgsDataSourceUri(layer.source())
            layerFile = uri.database()
            self._copyLayerFiles(layerFile, newPath)
            newDirectory = newPath.replace(self.baseDirectory, '.')
            # newDirectory is already '.'-prefixed when newPath lies under
            # baseDirectory (or absolute otherwise). Prepending another './'
            # produced '././...' and turned absolute paths into './/...'.
            # Build the path the same way as the gpx/delimitedtext branch.
            uri.setDatabase('{dirName}/{fileName}'.format(dirName=newDirectory, fileName=os.path.split(layerFile)[1]))
            self.updateLayerSource(layer.id(), uri.uri())
        elif providerType == 'memory':
            exportLayer = True
        elif providerType in ('DB2', 'mssql', 'oracle', 'postgres', 'wfs'):
            if 'exportRemote' in self.settings and self.settings['exportRemote']:
                exportLayer = True
        else:
            # Translate the template first, then substitute the provider.
            # Formatting inside tr() looks up a string that can never be in
            # the translation catalogue.
            QgsMessageLog.logMessage(self.tr('Layers from the "{provider}" provider are currently not supported.').format(provider=providerType), 'QConsolidate', Qgis.Info)

        if exportLayer:
            filePath = os.path.join(newPath, self.safeName(layer.name()))
            ok, filePath = self.exportVectorLayer(layer, filePath)
            if ok:
                newSource = filePath.replace(self.baseDirectory, '.')
                self.updateLayerSource(layer.id(), newSource, 'ogr')
示例#8
0
    def ogrConnectionStringAndFormatFromLayer(layer):
        """Build an OGR connection string and driver name for ``layer``.

        The result depends on the layer's data provider:

        * SpatiaLite: the database file path.
        * PostgreSQL: a ``PG:`` connection string. If the first connect
          fails, credentials are requested through QgsCredentials.
        * MSSQL: an ``MSSQL:`` key/value string.
        * Oracle: an ``OCI:`` string.
        * Any other provider: the source up to the first ``|``, with the
          driver chosen from the file extension.

        :param layer: vector layer to describe.
        :return: tuple ``(ogrstr, format)``. ``format`` is wrapped in double
            quotes; ``ogrstr`` is not.
        :raises RuntimeError: if no PostgreSQL connection can be made, or an
            Oracle source has no username, host or database.
        """
        provider = layer.dataProvider().name()
        if provider == 'spatialite':
            # dbname='/geodata/osm_ch.sqlite' table="places" (Geometry) sql=
            # Let QgsDataSourceUri parse the source. The previous greedy
            # "dbname='(.+)'" regex over-captured whenever a later part of the
            # source (e.g. the sql filter) also contained a single quote.
            ogrstr = QgsDataSourceUri(layer.source()).database()
            fmt = 'SQLite'
        elif provider == 'postgres':
            # dbname='ktryjh_iuuqef' host=spacialdb.com port=9999
            # user='******' password='******' sslmode=disable
            # key='gid' estimatedmetadata=true srid=4326 type=MULTIPOLYGON
            # table="t4" (geom) sql=
            dsUri = QgsDataSourceUri(layer.dataProvider().dataSourceUri())
            conninfo = dsUri.connectionInfo()
            conn = None
            ok = False
            while not conn:
                try:
                    conn = psycopg2.connect(dsUri.connectionInfo())
                except psycopg2.OperationalError:
                    (ok, user, passwd) = QgsCredentials.instance().get(conninfo, dsUri.username(), dsUri.password())
                    if not ok:
                        break

                    dsUri.setUsername(user)
                    dsUri.setPassword(passwd)

            if not conn:
                raise RuntimeError('Could not connect to PostgreSQL database - check connection info')

            # The connection is only opened to validate the credentials;
            # close it instead of leaking it.
            conn.close()

            if ok:
                QgsCredentials.instance().put(conninfo, user, passwd)

            ogrstr = "PG:%s" % dsUri.connectionInfo()
            fmt = 'PostgreSQL'
        elif provider == 'mssql':
            #'dbname=\'db_name\' host=myHost estimatedmetadata=true
            # srid=27700 type=MultiPolygon table="dbo"."my_table"
            # #(Shape) sql='
            dsUri = layer.dataProvider().uri()
            ogrstr = 'MSSQL:'
            ogrstr += 'database={0};'.format(dsUri.database())
            ogrstr += 'server={0};'.format(dsUri.host())
            if dsUri.username() != "":
                ogrstr += 'uid={0};'.format(dsUri.username())
            else:
                ogrstr += 'trusted_connection=yes;'
            if dsUri.password() != '':
                ogrstr += 'pwd={0};'.format(dsUri.password())
            ogrstr += 'tables={0}'.format(dsUri.table())
            fmt = 'MSSQL'
        elif provider == "oracle":
            # OCI:user/password@host:port/service:table
            dsUri = QgsDataSourceUri(layer.dataProvider().dataSourceUri())
            ogrstr = "OCI:"
            # Without a username there is no "user@" part, so no separator is
            # needed. Initializing it avoids an UnboundLocalError below.
            delim = ""
            if dsUri.username() != "":
                ogrstr += dsUri.username()
                if dsUri.password() != "":
                    ogrstr += "/" + dsUri.password()
                delim = "@"

            if dsUri.host() != "":
                ogrstr += delim + dsUri.host()
                delim = ""
                if dsUri.port() != "" and dsUri.port() != '1521':
                    ogrstr += ":" + dsUri.port()
                ogrstr += "/"
                if dsUri.database() != "":
                    ogrstr += dsUri.database()
            elif dsUri.database() != "":
                ogrstr += delim + dsUri.database()

            if ogrstr == "OCI:":
                raise RuntimeError('Invalid oracle data source - check connection info')

            ogrstr += ":"
            if dsUri.schema() != "":
                ogrstr += dsUri.schema() + "."

            ogrstr += dsUri.table()
            fmt = 'OCI'
        else:
            ogrstr = str(layer.source()).split("|")[0]
            path, ext = os.path.splitext(ogrstr)
            fmt = QgsVectorFileWriter.driverForExtension(ext)

        return ogrstr, '"' + fmt + '"'
示例#9
0
文件: GdalUtils.py 项目: ufolr/QGIS
    def ogrConnectionStringAndFormat(uri, context):
        """Generates OGR connection string and format string from layer source.

        If ``uri`` cannot be resolved to a map layer, it is treated as a file
        path and the driver is chosen from its extension. Otherwise the result
        depends on the layer's provider:

        * SpatiaLite: the database file.
        * PostgreSQL: a ``PG:`` connection string. If the first connect
          fails, credentials are requested through QgsCredentials.
        * Oracle: an ``OCI:`` string.
        * Any other provider: the source up to the first ``|``.

        :param uri: layer identifier or file path.
        :param context: processing context used to resolve the layer.
        :return: tuple of (connection string, format string), both wrapped in
            double quotes.
        :raises RuntimeError: if no PostgreSQL connection can be made, or an
            Oracle source has no username, host or database.
        """
        layer = QgsProcessingUtils.mapLayerFromString(uri, context, False)
        if layer is None:
            path, ext = os.path.splitext(uri)
            fmt = QgsVectorFileWriter.driverForExtension(ext)
            return '"' + uri + '"', '"' + fmt + '"'

        provider = layer.dataProvider().name()
        if provider == 'spatialite':
            # dbname='/geodata/osm_ch.sqlite' table="places" (Geometry) sql=
            # Let QgsDataSourceUri parse the source. The previous greedy
            # "dbname='(.+)'" regex over-captured whenever a later part of the
            # source (e.g. the sql filter) also contained a single quote.
            ogrstr = QgsDataSourceUri(layer.source()).database()
            fmt = 'SQLite'
        elif provider == 'postgres':
            # dbname='ktryjh_iuuqef' host=spacialdb.com port=9999
            # user='******' password='******' sslmode=disable
            # key='gid' estimatedmetadata=true srid=4326 type=MULTIPOLYGON
            # table="t4" (geom) sql=
            dsUri = QgsDataSourceUri(layer.dataProvider().dataSourceUri())
            conninfo = dsUri.connectionInfo()
            conn = None
            ok = False
            while not conn:
                try:
                    conn = psycopg2.connect(dsUri.connectionInfo())
                except psycopg2.OperationalError:
                    (ok, user, passwd) = QgsCredentials.instance().get(conninfo, dsUri.username(), dsUri.password())
                    if not ok:
                        break

                    dsUri.setUsername(user)
                    dsUri.setPassword(passwd)

            if not conn:
                raise RuntimeError('Could not connect to PostgreSQL database - check connection info')

            # The connection is only opened to validate the credentials;
            # close it instead of leaking it.
            conn.close()

            if ok:
                QgsCredentials.instance().put(conninfo, user, passwd)

            ogrstr = "PG:%s" % dsUri.connectionInfo()
            fmt = 'PostgreSQL'
        elif provider == "oracle":
            # OCI:user/password@host:port/service:table
            dsUri = QgsDataSourceUri(layer.dataProvider().dataSourceUri())
            ogrstr = "OCI:"
            # Without a username there is no "user@" part, so no separator is
            # needed. Initializing it avoids an UnboundLocalError below.
            delim = ""
            if dsUri.username() != "":
                ogrstr += dsUri.username()
                if dsUri.password() != "":
                    ogrstr += "/" + dsUri.password()
                delim = "@"

            if dsUri.host() != "":
                ogrstr += delim + dsUri.host()
                delim = ""
                if dsUri.port() != "" and dsUri.port() != '1521':
                    ogrstr += ":" + dsUri.port()
                ogrstr += "/"
                if dsUri.database() != "":
                    ogrstr += dsUri.database()
            elif dsUri.database() != "":
                ogrstr += delim + dsUri.database()

            if ogrstr == "OCI:":
                raise RuntimeError('Invalid oracle data source - check connection info')

            ogrstr += ":"
            if dsUri.schema() != "":
                ogrstr += dsUri.schema() + "."

            ogrstr += dsUri.table()
            fmt = 'OCI'
        else:
            ogrstr = str(layer.source()).split("|")[0]
            path, ext = os.path.splitext(ogrstr)
            fmt = QgsVectorFileWriter.driverForExtension(ext)

        return '"' + ogrstr + '"', '"' + fmt + '"'
示例#10
0
    def processAlgorithm(self, progress):
        """Import the INPUT layer into a table of a SpatiaLite database.

        The DATABASE parameter may be a data source URI or a bare file path
        (optionally with a ``|layerid`` suffix). The table name defaults to
        the layer name and is lower-cased with spaces removed. When
        CREATEINDEX is set and the layer has geometry, a spatial index is
        created after the import.

        :param progress: progress reporter (not used here).
        :raises GeoAlgorithmExecutionException: if the import fails.
        """
        database = self.getParameterValue(self.DATABASE)
        uri = QgsDataSourceUri(database)
        # ``is ''`` compared identity rather than value; check emptiness instead.
        if not uri.database():
            if '|layerid' in database:
                # Drop the OGR layer selector to keep only the file path.
                database = database[:database.find('|layerid')]
            uri = QgsDataSourceUri('dbname=\'%s\'' % (database))
        db = spatialite.GeoDB(uri)

        overwrite = self.getParameterValue(self.OVERWRITE)
        createIndex = self.getParameterValue(self.CREATEINDEX)
        convertLowerCase = self.getParameterValue(self.LOWERCASE_NAMES)
        dropStringLength = self.getParameterValue(self.DROP_STRING_LENGTH)
        forceSinglePart = self.getParameterValue(self.FORCE_SINGLEPART)
        primaryKeyField = self.getParameterValue(self.PRIMARY_KEY) or 'id'
        encoding = self.getParameterValue(self.ENCODING)

        layerUri = self.getParameterValue(self.INPUT)
        layer = dataobjects.getObjectFromUri(layerUri)

        table = self.getParameterValue(self.TABLENAME)
        # str.strip() returns a new string; the result used to be discarded,
        # so a whitespace-only name became an empty table name.
        if table:
            table = table.strip()
        if not table:
            table = layer.name()
        table = table.replace(' ', '').lower()
        providerName = 'spatialite'

        geomColumn = self.getParameterValue(self.GEOMETRY_COLUMN)
        if not geomColumn:
            geomColumn = 'the_geom'

        options = {}
        if overwrite:
            options['overwrite'] = True
        if convertLowerCase:
            options['lowercaseFieldNames'] = True
            geomColumn = geomColumn.lower()
        if dropStringLength:
            options['dropStringConstraints'] = True
        if forceSinglePart:
            options['forceSinglePartGeometryType'] = True

        # Clear geometry column for non-geometry tables
        if not layer.hasGeometryType():
            geomColumn = None

        # NOTE: this aliases db.uri, so setDataSource() also updates the
        # GeoDB's own URI object.
        uri = db.uri
        uri.setDataSource('', table, geomColumn, '', primaryKeyField)

        if encoding:
            layer.setProviderEncoding(encoding)

        (ret, errMsg) = QgsVectorLayerImport.importLayer(
            layer,
            uri.uri(),
            providerName,
            self.crs,
            False,
            False,
            options,
        )
        if ret != 0:
            # Translate the template, then substitute the error message.
            raise GeoAlgorithmExecutionException(
                self.tr('Error importing to Spatialite\n%s') % errMsg)

        if geomColumn and createIndex:
            db.create_spatial_index(table, geomColumn)
示例#11
0
    def __init__(self, destination, encoding, fields, geometryType,
                 crs, options=None):
        """Create a writer for ``destination``.

        The backend is selected by the prefix of ``destination``:

        * ``self.MEMORY_LAYER_PREFIX``: an in-memory QgsVectorLayer.
        * ``self.POSTGIS_LAYER_PREFIX``: a new PostGIS table (credentials
          are requested through QgsCredentials).
        * ``self.SPATIALITE_LAYER_PREFIX``: a SpatiaLite table, dropped and
          recreated.
        * anything else: a file written by QgsVectorFileWriter. The driver
          is chosen from the extension; ``.shp`` is appended when the
          extension is unknown, and ``.dbf`` is used for geometry-less
          shapefiles.

        :param destination: output location, optionally prefixed as above.
        :param encoding: output encoding. ``None`` means the
            ``/Processing/encoding`` setting (default ``'System'``).
        :param fields: field definitions, converted with ``_toQgsField``.
        :param geometryType: a ``QgsWkbTypes.Type`` value;
            ``QgsWkbTypes.NoGeometry`` means a table without geometry.
        :param crs: output CRS; its authid supplies the SRID for DB tables.
        :param options: not used.
        :raises GeoAlgorithmExecutionException: on database connection or
            table-creation errors, or an unsupported format for a
            geometry-less table.
        """
        self.destination = destination
        self.isNotFileBased = False
        self.layer = None
        self.writer = None

        if encoding is None:
            settings = QSettings()
            encoding = settings.value('/Processing/encoding', 'System', str)

        if self.destination.startswith(self.MEMORY_LAYER_PREFIX):
            self.isNotFileBased = True

            uri = QgsWkbTypes.displayString(geometryType) + "?uuid=" + str(uuid.uuid4())
            if crs.isValid():
                uri += '&crs=' + crs.authid()
            fieldsdesc = []
            for f in fields:
                qgsfield = _toQgsField(f)
                fieldsdesc.append('field=%s:%s' % (qgsfield.name(),
                                                   TYPE_MAP_MEMORY_LAYER.get(qgsfield.type(), "string")))
            if fieldsdesc:
                uri += '&' + '&'.join(fieldsdesc)

            self.layer = QgsVectorLayer(uri, self.destination, 'memory')
            self.writer = self.layer.dataProvider()
        elif self.destination.startswith(self.POSTGIS_LAYER_PREFIX):
            self.isNotFileBased = True
            uri = QgsDataSourceUri(self.destination[len(self.POSTGIS_LAYER_PREFIX):])
            connInfo = uri.connectionInfo()
            (success, user, passwd) = QgsCredentials.instance().get(connInfo, None, None)
            if success:
                QgsCredentials.instance().put(connInfo, user, passwd)
            else:
                raise GeoAlgorithmExecutionException("Couldn't connect to database")
            try:
                db = postgis.GeoDB(host=uri.host(), port=int(uri.port()),
                                   dbname=uri.database(), user=user, passwd=passwd)
            except postgis.DbError as e:
                raise GeoAlgorithmExecutionException(
                    "Couldn't connect to database:\n%s" % e.message)

            def _runSQL(sql):
                try:
                    db._exec_sql_and_commit(str(sql))
                except postgis.DbError as e:
                    raise GeoAlgorithmExecutionException(
                        'Error creating output PostGIS table:\n%s' % e.message)

            fields = [_toQgsField(f) for f in fields]
            fieldsdesc = ",".join('%s %s' % (f.name(),
                                             TYPE_MAP_POSTGIS_LAYER.get(f.type(), "VARCHAR"))
                                  for f in fields)

            _runSQL("CREATE TABLE %s.%s (%s)" % (uri.schema(), uri.table().lower(), fieldsdesc))
            # geometryType is a QgsWkbTypes.Type, so compare against
            # NoGeometry (as the file branch below does). NullGeometry is a
            # GeometryType member whose value equals QgsWkbTypes.MultiPoint,
            # so MultiPoint outputs used to get no geometry column.
            if geometryType != QgsWkbTypes.NoGeometry:
                _runSQL("SELECT AddGeometryColumn('{schema}', '{table}', 'the_geom', {srid}, '{typmod}', 2)".format(
                    table=uri.table().lower(), schema=uri.schema(), srid=crs.authid().split(":")[-1],
                    typmod=QgsWkbTypes.displayString(geometryType).upper()))

            self.layer = QgsVectorLayer(uri.uri(), uri.table(), "postgres")
            self.writer = self.layer.dataProvider()
        elif self.destination.startswith(self.SPATIALITE_LAYER_PREFIX):
            self.isNotFileBased = True
            uri = QgsDataSourceUri(self.destination[len(self.SPATIALITE_LAYER_PREFIX):])
            try:
                db = spatialite.GeoDB(uri=uri)
            except spatialite.DbError as e:
                raise GeoAlgorithmExecutionException(
                    "Couldn't connect to database:\n%s" % e.message)

            def _runSQL(sql):
                try:
                    db._exec_sql_and_commit(str(sql))
                except spatialite.DbError as e:
                    raise GeoAlgorithmExecutionException(
                        'Error creating output Spatialite table:\n%s' % str(e))

            fields = [_toQgsField(f) for f in fields]
            fieldsdesc = ",".join('%s %s' % (f.name(),
                                             TYPE_MAP_SPATIALITE_LAYER.get(f.type(), "VARCHAR"))
                                  for f in fields)

            _runSQL("DROP TABLE IF EXISTS %s" % uri.table().lower())
            _runSQL("CREATE TABLE %s (%s)" % (uri.table().lower(), fieldsdesc))
            # See the PostGIS branch: compare the WKB type with NoGeometry.
            if geometryType != QgsWkbTypes.NoGeometry:
                _runSQL("SELECT AddGeometryColumn('{table}', 'the_geom', {srid}, '{typmod}', 2)".format(
                    table=uri.table().lower(), srid=crs.authid().split(":")[-1],
                    typmod=QgsWkbTypes.displayString(geometryType).upper()))

            self.layer = QgsVectorLayer(uri.uri(), uri.table(), "spatialite")
            self.writer = self.layer.dataProvider()
        else:
            formats = QgsVectorFileWriter.supportedFiltersAndFormats()
            OGRCodes = {}
            for (key, value) in list(formats.items()):
                extension = str(key)
                extension = extension[extension.find('*.') + 2:]
                extension = extension[:extension.find(' ')]
                OGRCodes[extension] = value
            OGRCodes['dbf'] = "DBF file"

            extension = self.destination[self.destination.rfind('.') + 1:]

            if extension not in OGRCodes:
                extension = 'shp'
                self.destination = self.destination + '.shp'

            if geometryType == QgsWkbTypes.NoGeometry:
                if extension == 'shp':
                    extension = 'dbf'
                    self.destination = self.destination[:self.destination.rfind('.')] + '.dbf'
                if extension not in self.nogeometry_extensions:
                    raise GeoAlgorithmExecutionException(
                        "Unsupported format for tables with no geometry")

            qgsfields = QgsFields()
            for field in fields:
                qgsfields.append(_toQgsField(field))

            # use default dataset/layer options
            dataset_options = QgsVectorFileWriter.defaultDatasetOptions(OGRCodes[extension])
            layer_options = QgsVectorFileWriter.defaultLayerOptions(OGRCodes[extension])

            self.writer = QgsVectorFileWriter(self.destination, encoding,
                                              qgsfields, geometryType, crs, OGRCodes[extension],
                                              dataset_options, layer_options)
示例#12
0
class GeoDB(object):
    @classmethod
    def from_name(cls, conn_name):
        """Alternate constructor: open the connection stored under ``conn_name``.

        The connection name is resolved with ``uri_from_name`` and the result
        is passed to the regular constructor as ``uri``.
        """
        return cls(uri=uri_from_name(conn_name))

    def __init__(self,
                 host=None,
                 port=None,
                 dbname=None,
                 user=None,
                 passwd=None,
                 service=None,
                 uri=None):
        """Open a psycopg2 connection to a PostgreSQL/PostGIS database.

        Connection details come from ``uri`` (a QgsDataSourceUri) when given.
        Otherwise they are assembled from the individual arguments, with
        ``service`` taking precedence over ``host``/``port``.

        On an OperationalError the user is asked for credentials through
        QgsCredentials, for at most four connection attempts. Credentials
        that succeed after an earlier failure are stored back with
        ``QgsCredentials.put``. ``self.has_postgis`` is set afterwards.

        :raises DbError: when the fourth attempt fails or the user cancels
            the credentials dialog.
        """
        # Regular expression for identifiers without need to quote them
        self.re_ident_ok = re.compile(r"^\w+$")
        # str(None) would be the literal "None", ending up as "port=None" in
        # the connection string; use an empty port instead.
        port = '' if port is None else str(port)

        if uri:
            self.uri = uri
        else:
            self.uri = QgsDataSourceUri()
            if service:
                self.uri.setConnection(service, dbname, user, passwd)
            else:
                self.uri.setConnection(host, port, dbname, user, passwd)

        conninfo = self.uri.connectionInfo(False)
        err = None
        for i in range(4):
            expandedConnInfo = self.uri.connectionInfo(True)
            try:
                self.con = psycopg2.connect(expandedConnInfo)
                if err is not None:
                    # Remember credentials that fixed an earlier failure.
                    QgsCredentials.instance().put(conninfo,
                                                  self.uri.username(),
                                                  self.uri.password())
                break
            except psycopg2.OperationalError as e:
                if i == 3:
                    raise DbError(str(e))

                err = str(e)
                user = self.uri.username()
                password = self.uri.password()
                (ok, user, password) = QgsCredentials.instance().get(
                    conninfo, user, password, err)
                if not ok:
                    raise DbError(
                        QCoreApplication.translate("PostGIS",
                                                   'Action canceled by user'))
                if user:
                    self.uri.setUsername(user)
                if password:
                    self.uri.setPassword(password)
            finally:
                # remove certs (if any) of the expanded connectionInfo
                expandedUri = QgsDataSourceUri(expandedConnInfo)
                for certParam in ("sslcert", "sslkey", "sslrootcert"):
                    certFile = expandedUri.param(certParam)
                    if certFile:
                        os.remove(certFile.replace("'", ""))

        self.has_postgis = self.check_postgis()

    def get_info(self):
        """Return the first column of ``SELECT version()`` (the server version string)."""
        cursor = self.con.cursor()
        self._exec_sql(cursor, 'SELECT version()')
        row = cursor.fetchone()
        return row[0]

    def check_postgis(self):
        """Return True if a function named ``postgis_version`` exists in pg_proc."""
        cursor = self.con.cursor()
        query = "SELECT COUNT(*) FROM pg_proc WHERE proname = 'postgis_version'"
        self._exec_sql(cursor, query)
        matches = cursor.fetchone()[0]
        return matches > 0

    def get_postgis_info(self):
        """Return one row describing the PostGIS installation.

        The columns are, in order: library version, installed scripts
        version, released scripts version, GEOS version, PROJ version, and
        the result of ``postgis_uses_stats()``.
        """
        query = 'SELECT postgis_lib_version(), postgis_scripts_installed(), \
            postgis_scripts_released(), postgis_geos_version(), \
            postgis_proj_version(), postgis_uses_stats()'
        cursor = self.con.cursor()
        self._exec_sql(cursor, query)
        row = cursor.fetchone()
        return row

    def list_schemas(self):
        """Return all user schemas as ``(oid, name, owner, acl)`` rows.

        Schemas whose name starts with ``pg_``, and ``information_schema``,
        are excluded.
        """
        cursor = self.con.cursor()
        query = "SELECT oid, nspname, pg_get_userbyid(nspowner), nspacl \
               FROM pg_namespace \
               WHERE nspname !~ '^pg_' AND nspname != 'information_schema'"
        self._exec_sql(cursor, query)
        rows = cursor.fetchall()
        return rows

    def list_geotables(self, schema=None):
        """Get list of tables with schemas, whether user has privileges,
        whether table has geometry column(s) etc.

        Each returned row has 10 columns: relname, nspname, relkind, owner,
        reltuples, relpages, then geometry column name, geometry type,
        coord_dimension and srid (the last four may be NULL). Tables, views,
        materialized views and partitioned tables (relkind r/v/m/p) are
        listed, restricted to ``schema`` when given; otherwise system
        schemas (``information_schema`` and names starting with ``pg_``)
        are skipped.

        Geometry_columns:
          - f_table_schema
          - f_table_name
          - f_geometry_column
          - coord_dimension
          - srid
          - type
        """

        c = self.con.cursor()

        if schema:
            schema_where = " AND nspname = '%s' " % self._quote_unicode(schema)
        else:
            # Anchor the regex like list_schemas() does: an unanchored 'pg_'
            # also hid user schemas such as "my_pg_data".
            schema_where = \
                " AND (nspname != 'information_schema' AND nspname !~ '^pg_') "

        # LEFT OUTER JOIN: like LEFT JOIN but if there are more matches,
        # for join, all are used (not only one)

        # First find out whether PostGIS is enabled
        if not self.has_postgis:
            # Get all tables and views
            sql = """SELECT pg_class.relname, pg_namespace.nspname,
                            pg_class.relkind, pg_get_userbyid(relowner),
                            reltuples, relpages, NULL, NULL, NULL, NULL
                  FROM pg_class
                  JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
                  WHERE pg_class.relkind IN ('v', 'r', 'm', 'p')""" \
                  + schema_where + 'ORDER BY nspname, relname'
        else:
            # Discovery of all tables and whether they contain a
            # geometry column
            sql = """SELECT pg_class.relname, pg_namespace.nspname,
                            pg_class.relkind, pg_get_userbyid(relowner),
                            reltuples, relpages, pg_attribute.attname,
                            pg_attribute.atttypid::regtype, NULL, NULL
                  FROM pg_class
                  JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
                  LEFT OUTER JOIN pg_attribute ON
                      pg_attribute.attrelid = pg_class.oid AND
                      (pg_attribute.atttypid = 'geometry'::regtype
                      OR pg_attribute.atttypid IN
                          (SELECT oid FROM pg_type
                           WHERE typbasetype='geometry'::regtype))
                  WHERE pg_class.relkind IN ('v', 'r', 'm', 'p') """ \
                  + schema_where + 'ORDER BY nspname, relname, attname'

        self._exec_sql(c, sql)
        items = c.fetchall()

        # Get geometry info from geometry_columns if exists
        if self.has_postgis:
            sql = """SELECT relname, nspname, relkind,
                            pg_get_userbyid(relowner), reltuples, relpages,
                            geometry_columns.f_geometry_column,
                            geometry_columns.type,
                            geometry_columns.coord_dimension,
                            geometry_columns.srid
                  FROM pg_class
                  JOIN pg_namespace ON relnamespace=pg_namespace.oid
                  LEFT OUTER JOIN geometry_columns ON
                      relname=f_table_name AND nspname=f_table_schema
                  WHERE relkind IN ('r','v','m','p') """ \
                  + schema_where + 'ORDER BY nspname, relname, \
                  f_geometry_column'

            self._exec_sql(c, sql)

            # Merge geometry info to "items".
            # NOTE(review): rows are matched purely by position, which assumes
            # both queries return the same rows in the same order.
            for (i, geo_item) in enumerate(c.fetchall()):
                if geo_item[7]:
                    items[i] = geo_item

        return items

    def get_table_rows(self, table, schema=None):
        """Count the rows of *table* (schema-qualified when *schema* is set).

        :return: the single value produced by ``SELECT COUNT(*)``.
        """
        qualified_name = self._table_name(schema, table)
        cursor = self.con.cursor()
        self._exec_sql(cursor, 'SELECT COUNT(*) FROM %s' % qualified_name)
        row = cursor.fetchone()
        return row[0]

    def get_table_fields(self, table, schema=None):
        """Return list of columns in table.

        :param table: table name, matched against ``pg_class.relname``.
        :param schema: optional schema name; when None the table is looked
            up in every schema.
        :return: list of ``TableAttribute`` built from rows of
            (attnum, attname, type name, attlen, atttypmod, attnotnull,
            atthasdef, default expression), ordered by attnum.
        """

        c = self.con.cursor()
        schema_where = (" AND nspname='%s' " % self._quote_unicode(schema)
                        if schema is not None else '')
        # pg_attrdef.adsrc was removed in PostgreSQL 12; pg_get_expr()
        # decompiles the stored default and also works on older servers.
        # Dropped columns remain in pg_attribute (attisdropped) and must not
        # be reported as fields.
        sql = """SELECT a.attnum AS ordinal_position,
                        a.attname AS column_name,
                        t.typname AS data_type,
                        a.attlen AS char_max_len,
                        a.atttypmod AS modifier,
                        a.attnotnull AS notnull,
                        a.atthasdef AS hasdefault,
                        pg_get_expr(adef.adbin, adef.adrelid) AS default_value
              FROM pg_class c
              JOIN pg_attribute a ON a.attrelid = c.oid
              JOIN pg_type t ON a.atttypid = t.oid
              JOIN pg_namespace nsp ON c.relnamespace = nsp.oid
              LEFT JOIN pg_attrdef adef ON adef.adrelid = a.attrelid
                  AND adef.adnum = a.attnum
              WHERE
                  c.relname = '%s' %s AND
                  a.attnum > 0 AND
                  NOT a.attisdropped
              ORDER BY a.attnum""" \
              % (self._quote_unicode(table), schema_where)

        self._exec_sql(c, sql)
        return [TableAttribute(row) for row in c.fetchall()]

    def get_table_indexes(self, table, schema=None):
        """Get info about table's indexes as a list of ``TableIndex``.

        Unique and primary-key indexes are filtered out (``indisunique`` /
        ``indisprimary``) because they get listed in constraints.  Each
        row is (index relname, indkey).
        """

        if schema is None:
            schema_where = ''
        else:
            schema_where = " AND nspname='%s' " % self._quote_unicode(schema)
        sql = """SELECT relname, indkey
              FROM pg_class, pg_index
              WHERE pg_class.oid = pg_index.indexrelid AND pg_class.oid IN (
                     SELECT indexrelid
                     FROM pg_index, pg_class
                     JOIN pg_namespace nsp ON pg_class.relnamespace = nsp.oid
                     WHERE pg_class.relname='%s' %s AND
                         pg_class.oid=pg_index.indrelid
                         AND indisunique != 't' AND indisprimary != 't' )""" \
              % (self._quote_unicode(table), schema_where)
        cursor = self.con.cursor()
        self._exec_sql(cursor, sql)
        return [TableIndex(row) for row in cursor.fetchall()]

    def get_table_constraints(self, table, schema=None):
        """Return the constraints of a table as a list of ``TableConstraint``.

        :param table: table name, matched against ``pg_class.relname``.
        :param schema: optional schema name; when None every schema is
            searched.
        :return: one ``TableConstraint`` per row of (conname, contype,
            condeferrable, condeferred, conkey as text, check expression,
            referenced relname, confupdtype, confdeltype, confmatchtype,
            confkey as text).
        """
        c = self.con.cursor()

        schema_where = (" AND nspname='%s' " % self._quote_unicode(schema)
                        if schema is not None else '')
        # pg_constraint.consrc was removed in PostgreSQL 12;
        # pg_get_expr(conbin, conrelid) yields the same expression text
        # (NULL when the constraint has no expression) on older servers too.
        sql = """SELECT c.conname, c.contype, c.condeferrable, c.condeferred,
                        array_to_string(c.conkey, ' '),
                        pg_get_expr(c.conbin, c.conrelid), t2.relname,
                        c.confupdtype, c.confdeltype, c.confmatchtype,
                        array_to_string(c.confkey, ' ')
              FROM pg_constraint c
              LEFT JOIN pg_class t ON c.conrelid = t.oid
              LEFT JOIN pg_class t2 ON c.confrelid = t2.oid
              JOIN pg_namespace nsp ON t.relnamespace = nsp.oid
              WHERE t.relname = '%s' %s """ \
              % (self._quote_unicode(table), schema_where)

        self._exec_sql(c, sql)
        return [TableConstraint(row) for row in c.fetchall()]

    def get_view_definition(self, view, schema=None):
        """Return the ``pg_get_viewdef`` text of a view or materialized view.

        When *schema* is given, only that schema is searched.
        """

        quoted_view = self._quote_unicode(view)
        schema_where = ''
        if schema is not None:
            schema_where = " AND nspname='%s' " % self._quote_unicode(schema)
        sql = """SELECT pg_get_viewdef(c.oid)
              FROM pg_class c
              JOIN pg_namespace nsp ON c.relnamespace = nsp.oid
              WHERE relname='%s' %s AND relkind IN ('v','m')""" \
              % (quoted_view, schema_where)
        cursor = self.con.cursor()
        self._exec_sql(cursor, sql)
        first_row = cursor.fetchone()
        return first_row[0]

    def add_geometry_column(self,
                            table,
                            geom_type,
                            schema=None,
                            geom_column='the_geom',
                            srid=-1,
                            dim=2):
        """Add a geometry column through PostGIS ``AddGeometryColumn()``.

        The schema argument is passed only when *schema* is truthy.  The
        statement is committed immediately.
        """
        arguments = []
        # Use schema if explicitly specified
        if schema:
            arguments.append("'%s'" % self._quote_unicode(schema))
        arguments.extend([
            "'%s'" % self._quote_unicode(table),
            "'%s'" % self._quote_unicode(geom_column),
            '%d' % srid,
            "'%s'" % self._quote_unicode(geom_type),
            '%d' % dim,
        ])
        self._exec_sql_and_commit(
            'SELECT AddGeometryColumn(%s)' % ', '.join(arguments))

    def delete_geometry_column(self, table, geom_column, schema=None):
        """Drop a geometry column with PostGIS ``DropGeometryColumn()``.

        The schema argument is included only when *schema* is truthy; the
        change is committed immediately.
        """

        names = (schema, table, geom_column) if schema else (table, geom_column)
        literals = ["'%s'" % self._quote_unicode(name) for name in names]
        self._exec_sql_and_commit(
            'SELECT DropGeometryColumn(%s)' % ', '.join(literals))

    def delete_geometry_table(self, table, schema=None):
        """Drop a table with PostGIS ``DropGeometryTable()`` and commit.

        The schema argument is included only when *schema* is truthy.
        """

        names = (schema, table) if schema else (table,)
        arg_list = ', '.join("'%s'" % self._quote_unicode(n) for n in names)
        self._exec_sql_and_commit('SELECT DropGeometryTable(%s)' % arg_list)

    def create_table(self, table, fields, pkey=None, schema=None):
        """Create ordinary table.

        :param fields: sequence of TableField instances; each one's
            ``field_def()`` text becomes a column definition.
        :param pkey: optional name of the column to use as primary key.
        :return: False (and nothing executed) when *fields* is empty,
            otherwise True after the CREATE TABLE has been committed.
        """

        if not len(fields):
            return False

        table_name = self._table_name(schema, table)
        definitions = ['%s' % field.field_def() for field in fields]
        if pkey:
            definitions.append('PRIMARY KEY (%s)' % self._quote(pkey))
        self._exec_sql_and_commit(
            'CREATE TABLE %s (%s)' % (table_name, ', '.join(definitions)))
        return True

    def delete_table(self, table, schema=None):
        """Drop *table* (schema-qualified when *schema* is set) and commit."""

        self._exec_sql_and_commit(
            'DROP TABLE {}'.format(self._table_name(schema, table)))

    def empty_table(self, table, schema=None):
        """Remove every row of *table* with an unfiltered DELETE and commit."""

        target = self._table_name(schema, table)
        self._exec_sql_and_commit('DELETE FROM {}'.format(target))

    def rename_table(self, table, new_table, schema=None):
        """Rename a table in database.

        With PostGIS, ``geometry_columns.f_table_name`` is updated as well;
        without *schema* that update touches the table name in every schema.
        """

        old_name = self._table_name(schema, table)
        self._exec_sql_and_commit('ALTER TABLE %s RENAME TO %s'
                                  % (old_name, self._quote(new_table)))

        if not self.has_postgis:
            return

        # Keep PostGIS metadata pointing at the renamed table
        sql = "UPDATE geometry_columns SET f_table_name='%s' \
                   WHERE f_table_name='%s'" \
              % (self._quote_unicode(new_table), self._quote_unicode(table))
        if schema is not None:
            sql += " AND f_table_schema='%s'" % self._quote_unicode(schema)
        self._exec_sql_and_commit(sql)

    def create_view(self, name, query, schema=None):
        """Create view *name* defined by the SELECT text *query*; commits."""
        statement = 'CREATE VIEW {} AS {}'.format(
            self._table_name(schema, name), query)
        self._exec_sql_and_commit(statement)

    def delete_view(self, name, schema=None):
        """Drop view *name* (schema-qualified when *schema* is set); commits."""
        target = self._table_name(schema, name)
        self._exec_sql_and_commit('DROP VIEW {}'.format(target))

    def rename_view(self, name, new_name, schema=None):
        """Rename view in database.

        Delegates to ``rename_table()``, which issues ALTER TABLE ... RENAME
        and updates geometry_columns in the same way.
        """
        return self.rename_table(name, new_name, schema)

    def create_schema(self, schema):
        """Create a new, empty schema named *schema* and commit."""

        self._exec_sql_and_commit('CREATE SCHEMA {}'.format(self._quote(schema)))

    def delete_schema(self, schema):
        """Drop *schema*; plain DROP SCHEMA (no CASCADE), so it must be empty."""

        quoted = self._quote(schema)
        self._exec_sql_and_commit('DROP SCHEMA {}'.format(quoted))

    def rename_schema(self, schema, new_schema):
        """Rename a schema; with PostGIS also update geometry_columns."""

        old_quoted, new_quoted = self._quote(schema), self._quote(new_schema)
        self._exec_sql_and_commit(
            'ALTER SCHEMA %s RENAME TO %s' % (old_quoted, new_quoted))

        # Update geometry_columns if PostGIS is enabled
        if self.has_postgis:
            sql = "UPDATE geometry_columns SET f_table_schema='%s' \
                 WHERE f_table_schema='%s'" % (
                self._quote_unicode(new_schema), self._quote_unicode(schema))
            self._exec_sql_and_commit(sql)

    def table_add_column(self, table, field, schema=None):
        """Add a column to table (passed as TableField instance).

        The column SQL is whatever ``field.field_def()`` returns.
        """

        target = self._table_name(schema, table)
        column_sql = field.field_def()
        self._exec_sql_and_commit('ALTER TABLE %s ADD %s' % (target, column_sql))

    def table_delete_column(self, table, field, schema=None):
        """Drop column *field* (a column name) from *table* and commit."""

        statement = 'ALTER TABLE {} DROP {}'.format(
            self._table_name(schema, table), self._quote(field))
        self._exec_sql_and_commit(statement)

    def table_column_rename(self, table, name, new_name, schema=None):
        """Rename column *name* of *table* to *new_name*.

        With PostGIS, the matching ``geometry_columns`` row is renamed too
        (restricted to *schema* when it is given).
        """

        table_name = self._table_name(schema, table)
        sql = 'ALTER TABLE %s RENAME %s TO %s' % (
            table_name, self._quote(name), self._quote(new_name))
        self._exec_sql_and_commit(sql)

        # Update geometry_columns if PostGIS is enabled
        if self.has_postgis:
            sql = "UPDATE geometry_columns SET f_geometry_column='%s' \
                   WHERE f_geometry_column='%s' AND f_table_name='%s'" \
                   % (self._quote_unicode(new_name), self._quote_unicode(name),
                      self._quote_unicode(table))
            if schema is not None:
                # The schema is compared as a string literal, so it needs
                # literal escaping (_quote_unicode), not identifier quoting
                # (_quote), which would wrap it in double quotes and make
                # the filter never match.
                sql += " AND f_table_schema='%s'" % self._quote_unicode(schema)
            self._exec_sql_and_commit(sql)

    def table_column_set_type(self, table, column, data_type, schema=None):
        """Change the type of *column*; *data_type* is inserted as raw SQL."""

        target = self._table_name(schema, table)
        quoted_column = self._quote(column)
        self._exec_sql_and_commit(
            'ALTER TABLE %s ALTER %s TYPE %s' % (target, quoted_column, data_type))

    def table_column_set_default(self, table, column, default, schema=None):
        """Change column's default value.

        A truthy *default* is inserted verbatim as the new DEFAULT
        expression.  Any falsy value (None, '' ...) drops the default.
        """

        table_name = self._table_name(schema, table)
        column_name = self._quote(column)
        action = ('SET DEFAULT %s' % (default,)) if default else 'DROP DEFAULT'
        self._exec_sql_and_commit(
            'ALTER TABLE %s ALTER %s %s' % (table_name, column_name, action))

    def table_column_set_null(self, table, column, is_null, schema=None):
        """Allow (truthy *is_null*) or forbid NULLs in *column*; commits."""

        clause = 'DROP NOT NULL' if is_null else 'SET NOT NULL'
        self._exec_sql_and_commit('ALTER TABLE %s ALTER %s %s' % (
            self._table_name(schema, table), self._quote(column), clause))

    def table_add_primary_key(self, table, column, schema=None):
        """Add a single-column primary key on *column* to *table*; commits."""

        target = self._table_name(schema, table)
        key_column = self._quote(column)
        self._exec_sql_and_commit(
            'ALTER TABLE {} ADD PRIMARY KEY ({})'.format(target, key_column))

    def table_add_unique_constraint(self, table, column, schema=None):
        """Add a single-column UNIQUE constraint on *column*; commits."""

        target = self._table_name(schema, table)
        self._exec_sql_and_commit(
            'ALTER TABLE {} ADD UNIQUE ({})'.format(target, self._quote(column)))

    def table_delete_constraint(self, table, constraint, schema=None):
        """Drop the named *constraint* from *table*; commits."""

        target = self._table_name(schema, table)
        name = self._quote(constraint)
        self._exec_sql_and_commit(
            'ALTER TABLE {} DROP CONSTRAINT {}'.format(target, name))

    def table_move_to_schema(self, table, new_schema, schema=None):
        """Move *table* from *schema* into *new_schema*.

        Nothing happens when both schemas are equal.  With PostGIS, the
        table's geometry_columns rows get the new f_table_schema; without
        *schema* that update matches the table name in any schema.
        """
        if new_schema == schema:
            return
        self._exec_sql_and_commit('ALTER TABLE %s SET SCHEMA %s' % (
            self._table_name(schema, table), self._quote(new_schema)))

        if not self.has_postgis:
            return
        # Update geometry_columns if PostGIS is enabled
        sql = "UPDATE geometry_columns SET f_table_schema='%s' \
                   WHERE f_table_name='%s'" \
              % (self._quote_unicode(new_schema), self._quote_unicode(table))
        if schema is not None:
            sql += " AND f_table_schema='%s'" % self._quote_unicode(schema)
        self._exec_sql_and_commit(sql)

    def create_index(self, table, name, column, schema=None):
        """Create index *name* on one *column* of *table*, default method/options."""

        statement = 'CREATE INDEX {} ON {} ({})'.format(
            self._quote(name),
            self._table_name(schema, table),
            self._quote(column),
        )
        self._exec_sql_and_commit(statement)

    def create_spatial_index(self, table, schema=None, geom_column='the_geom'):
        """Create a GiST index named ``sidx_<table>_<geom_column>``; commits."""
        target = self._table_name(schema, table)
        index = self._quote(u"sidx_%s_%s" % (table, geom_column))
        column = self._quote(geom_column)
        self._exec_sql_and_commit(
            'CREATE INDEX {} ON {} USING GIST({})'.format(index, target, column))

    def delete_index(self, name, schema=None):
        """Drop index *name*, schema-qualified when *schema* is set; commits."""
        qualified = self._table_name(schema, name)
        self._exec_sql_and_commit('DROP INDEX {}'.format(qualified))

    def get_database_privileges(self):
        """DB privileges: (can create schemas, can create temp. tables).

        Both flags come from ``has_database_privilege`` for the database
        named by ``self.uri.database()``; the fetched row is returned as-is.
        """
        database = self._quote_unicode(self.uri.database())
        query = "SELECT has_database_privilege('%(d)s', 'CREATE'), \
                      has_database_privilege('%(d)s', 'TEMP')" % {'d': database}
        cursor = self.con.cursor()
        self._exec_sql(cursor, query)
        return cursor.fetchone()

    def get_schema_privileges(self, schema):
        """Schema privileges: (can create new objects, can access objects
        in schema), from ``has_schema_privilege`` CREATE and USAGE."""

        escaped = self._quote_unicode(schema)
        query = "SELECT has_schema_privilege('%(s)s', 'CREATE'), \
                      has_schema_privilege('%(s)s', 'USAGE')" % {'s': escaped}
        cursor = self.con.cursor()
        self._exec_sql(cursor, query)
        return cursor.fetchone()

    def get_table_privileges(self, table, schema=None):
        """Table privileges: (select, insert, update, delete).

        The qualified table name is passed to ``has_table_privilege`` as an
        escaped string literal; the fetched row is returned as-is.
        """

        qualified = self._quote_unicode(self._table_name(schema, table))
        query = """SELECT has_table_privilege('%(t)s', 'SELECT'),
                        has_table_privilege('%(t)s', 'INSERT'),
                        has_table_privilege('%(t)s', 'UPDATE'),
                        has_table_privilege('%(t)s', 'DELETE')""" \
            % {'t': qualified}
        cursor = self.con.cursor()
        self._exec_sql(cursor, query)
        return cursor.fetchone()

    def vacuum_analyze(self, table, schema=None):
        """Run VACUUM ANALYZE on a table.

        The connection is switched to autocommit for the statement and set
        back to READ COMMITTED afterwards.  The reset happens in ``finally``
        so a failing VACUUM (which raises DbError from ``_exec_sql``) does
        not leave the connection in autocommit mode for later statements.
        """

        t = self._table_name(schema, table)

        # VACUUM ANALYZE must be run outside transaction block - we
        # have to change isolation level
        self.con.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try:
            c = self.con.cursor()
            self._exec_sql(c, 'VACUUM ANALYZE %s' % t)
        finally:
            self.con.set_isolation_level(
                psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)

    def sr_info_for_srid(self, srid):
        """Return a short description of the spatial reference *srid*.

        Reads ``srtext`` from spatial_ref_sys.  If it contains a double-quoted
        string, the first such match is returned, quotes included, because
        ``match.group()`` returns the whole match.  Otherwise the full
        srtext is returned.

        :return: 'Unknown' when PostGIS is not available, the query fails,
            the srid has no row, or its srtext is NULL.
        """
        if not self.has_postgis:
            return 'Unknown'

        try:
            c = self.con.cursor()
            self._exec_sql(
                c,
                "SELECT srtext FROM spatial_ref_sys WHERE srid = '%d'" % srid)
            row = c.fetchone()
        except DbError:
            return 'Unknown'

        # Without this check an unknown srid (no row) or a NULL srtext would
        # raise TypeError instead of reporting 'Unknown'.
        if row is None or row[0] is None:
            return 'Unknown'
        srtext = row[0]

        # Try to extract just SR name (should be quoted in double
        # quotes)
        x = re.search('"([^"]+)"', srtext)
        if x is not None:
            srtext = x.group()
        return srtext

    def insert_table_row(self, table, values, schema=None, cursor=None):
        """Insert one row whose column values are given as SQL text.

        The strings in *values* are pasted verbatim, comma-separated, into
        the VALUES list.  They are not quoted or escaped here, so callers
        must pass ready-made SQL literals.  With *cursor* the statement runs
        on that cursor without committing (for batches of inserts);
        otherwise it is executed and committed immediately.
        """

        target = self._table_name(schema, table)
        # TODO: quote values?
        sql = 'INSERT INTO %s VALUES (%s)' % (target, ', '.join(values))
        if not cursor:
            self._exec_sql_and_commit(sql)
        else:
            self._exec_sql(cursor, sql)

    def _exec_sql(self, cursor, sql):
        try:
            cursor.execute(sql)
        except psycopg2.Error as e:
            raise DbError(str(e),
                          e.cursor.query.decode(e.cursor.connection.encoding))

    def _exec_sql_and_commit(self, sql):
        """Tries to execute and commit some action, on error it rolls
        back the change.
        """

        try:
            c = self.con.cursor()
            self._exec_sql(c, sql)
            self.con.commit()
        except DbError:
            self.con.rollback()
            raise

    def _quote(self, identifier):
        """Quote identifier if needed."""

        # Make sure it's python unicode string
        identifier = str(identifier)

        # Is it needed to quote the identifier?
        if self.re_ident_ok.match(identifier) is not None:
            return identifier

        # It's needed - let's quote it (and double the double-quotes)
        return u'"%s"' % identifier.replace('"', '""')

    def _quote_unicode(self, txt):
        """Make the string safe - replace ' with ''.
        """

        # make sure it's python unicode string
        txt = str(txt)
        return txt.replace("'", "''")

    def _table_name(self, schema, table):
        if not schema:
            return self._quote(table)
        else:
            return u'%s.%s' % (self._quote(schema), self._quote(table))
示例#13
0
class GeoDB(object):

    @classmethod
    def from_name(cls, conn_name):
        """Alternate constructor: resolve the saved connection *conn_name*
        to a URI with ``uri_from_name`` and connect with it."""
        return cls(uri=uri_from_name(conn_name))

    def __init__(self, host=None, port=None, dbname=None, user=None,
                 passwd=None, service=None, uri=None):
        """Open a psycopg2 connection and detect PostGIS.

        :param uri: ready-made QgsDataSourceUri; when given, the other
            connection parameters are ignored.
        :param service: when set (and *uri* is not), the connection uses
            this service name instead of *host*/*port*.
        :raises QgsProcessingException: when the fourth connection attempt
            fails, or when the user cancels the credentials request.

        After a failed attempt the user is asked for credentials through
        QgsCredentials.  Credentials that then lead to a successful
        connection are stored back.  On success ``self.con``, ``self.uri``
        and ``self.has_postgis`` are set.
        """
        # Regular expression for identifiers without need to quote them
        self.re_ident_ok = re.compile(r"^\w+$")
        # str(None) would produce the bogus port "None"; an empty port
        # presumably lets the connection fall back to the default port.
        port = str(port) if port is not None else ''

        if uri:
            self.uri = uri
        else:
            self.uri = QgsDataSourceUri()
            if service:
                self.uri.setConnection(service, dbname, user, passwd)
            else:
                self.uri.setConnection(host, port, dbname, user, passwd)

        conninfo = self.uri.connectionInfo(False)
        err = None
        for i in range(4):
            expandedConnInfo = self.uri.connectionInfo(True)
            try:
                self.con = psycopg2.connect(expandedConnInfo)
                if err is not None:
                    # Credentials entered after a failure worked: store them.
                    QgsCredentials.instance().put(conninfo,
                                                  self.uri.username(),
                                                  self.uri.password())
                break
            except psycopg2.OperationalError as e:
                if i == 3:
                    raise QgsProcessingException(str(e))

                err = str(e)
                user = self.uri.username()
                password = self.uri.password()
                (ok, user, password) = QgsCredentials.instance().get(conninfo,
                                                                     user,
                                                                     password,
                                                                     err)
                if not ok:
                    raise QgsProcessingException(QCoreApplication.translate("PostGIS", 'Action canceled by user'))
                if user:
                    self.uri.setUsername(user)
                if password:
                    self.uri.setPassword(password)
            finally:
                # remove certs (if any) of the expanded connectionInfo
                self._remove_expanded_cert_files(expandedConnInfo)

        self.has_postgis = self.check_postgis()

    @staticmethod
    def _remove_expanded_cert_files(expanded_conninfo):
        """Delete the sslcert / sslkey / sslrootcert files named in a conninfo.

        Single quotes are stripped from the paths first.  A file that is
        already gone is ignored, so this cleanup (run from a ``finally``
        clause) cannot replace the connection outcome with FileNotFoundError.
        """
        expanded_uri = QgsDataSourceUri(expanded_conninfo)
        for param in ("sslcert", "sslkey", "sslrootcert"):
            path = expanded_uri.param(param)
            if path:
                try:
                    os.remove(path.replace("'", ""))
                except FileNotFoundError:
                    pass

    def get_info(self):
        """Return the server's ``SELECT version()`` string."""
        cursor = self.con.cursor()
        self._exec_sql(cursor, 'SELECT version()')
        row = cursor.fetchone()
        return row[0]

    def check_postgis(self):
        """Return True if a function named ``postgis_version`` exists in
        ``pg_proc``.
        """

        query = "SELECT COUNT(*) FROM pg_proc WHERE proname = 'postgis_version'"
        cursor = self.con.cursor()
        self._exec_sql(cursor, query)
        count = cursor.fetchone()[0]
        return count > 0

    def get_postgis_info(self):
        """Returns tuple about PostGIS support:
              - lib version
              - installed scripts version
              - released scripts version
              - geos version
              - proj version
              - whether uses stats
        """

        query = 'SELECT postgis_lib_version(), postgis_scripts_installed(), \
            postgis_scripts_released(), postgis_geos_version(), \
            postgis_proj_version(), postgis_uses_stats()'
        cursor = self.con.cursor()
        self._exec_sql(cursor, query)
        return cursor.fetchone()

    def list_schemas(self):
        """Get list of schemas in tuples: (oid, name, owner, perms).

        System schemas (names starting with ``pg_``) and
        ``information_schema`` are left out.
        """

        query = "SELECT oid, nspname, pg_get_userbyid(nspowner), nspacl \
               FROM pg_namespace \
               WHERE nspname !~ '^pg_' AND nspname != 'information_schema'"
        cursor = self.con.cursor()
        self._exec_sql(cursor, query)
        return cursor.fetchall()

    def list_geotables(self, schema=None):
        """List tables and views together with their geometry column(s).

        Each row is (relname, nspname, relkind, owner, reltuples, relpages,
        geometry column, geometry type, coord_dimension, srid).  Values
        that cannot be determined are NULL/None.  Without PostGIS, every
        relation appears once with the last four values empty.

        Geometry_columns:
          - f_table_schema
          - f_table_name
          - f_geometry_column
          - coord_dimension
          - srid
          - type

        :param schema: list only this schema; otherwise every schema except
            information_schema and the pg_* system schemas is listed.
        """

        c = self.con.cursor()

        if schema:
            schema_where = " AND nspname = '%s' " % self._quote_unicode(schema)
        else:
            # Anchored like list_schemas(): an unanchored 'pg_' would also
            # hide user schemas that merely contain "pg_".
            schema_where = \
                " AND (nspname != 'information_schema' AND nspname !~ '^pg_') "

        # LEFT OUTER JOIN (same as LEFT JOIN) keeps relations without a
        # matching geometry column; a relation with several geometry
        # columns yields one row per column.

        # First find out whether PostGIS is enabled
        if not self.has_postgis:
            # Get all tables and views
            sql = """SELECT pg_class.relname, pg_namespace.nspname,
                            pg_class.relkind, pg_get_userbyid(relowner),
                            reltuples, relpages, NULL, NULL, NULL, NULL
                  FROM pg_class
                  JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
                  WHERE pg_class.relkind IN ('v', 'r', 'm', 'p')""" \
                  + schema_where + 'ORDER BY nspname, relname'
        else:
            # Discovery of all tables and whether they contain a
            # geometry column
            sql = """SELECT pg_class.relname, pg_namespace.nspname,
                            pg_class.relkind, pg_get_userbyid(relowner),
                            reltuples, relpages, pg_attribute.attname,
                            pg_attribute.atttypid::regtype, NULL, NULL
                  FROM pg_class
                  JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
                  LEFT OUTER JOIN pg_attribute ON
                      pg_attribute.attrelid = pg_class.oid AND
                      (pg_attribute.atttypid = 'geometry'::regtype
                      OR pg_attribute.atttypid IN
                          (SELECT oid FROM pg_type
                           WHERE typbasetype='geometry'::regtype))
                  WHERE pg_class.relkind IN ('v', 'r', 'm', 'p') """ \
                  + schema_where + 'ORDER BY nspname, relname, attname'

        self._exec_sql(c, sql)
        items = c.fetchall()

        # Get geometry info from geometry_columns if exists
        if self.has_postgis:
            sql = """SELECT relname, nspname, relkind,
                            pg_get_userbyid(relowner), reltuples, relpages,
                            geometry_columns.f_geometry_column,
                            geometry_columns.type,
                            geometry_columns.coord_dimension,
                            geometry_columns.srid
                  FROM pg_class
                  JOIN pg_namespace ON relnamespace=pg_namespace.oid
                  LEFT OUTER JOIN geometry_columns ON
                      relname=f_table_name AND nspname=f_table_schema
                  WHERE relkind IN ('r','v','m','p') """ \
                  + schema_where + 'ORDER BY nspname, relname, \
                  f_geometry_column'
            self._exec_sql(c, sql)

            # Merge geometry info into "items", matching rows on
            # (schema, table, column) rather than by position.  The two
            # queries need not return the same rows: e.g. a column whose
            # type is a domain over geometry is found by the first query but
            # may be absent from geometry_columns.  Positional matching would
            # then attach geometry info to the wrong relation or raise
            # IndexError.
            geo_info = {(row[1], row[0], row[6]): row
                        for row in c.fetchall() if row[7]}
            items = [geo_info.get((item[1], item[0], item[6]), item)
                     for item in items]

        return items

    def get_table_rows(self, table, schema=None):
        """Return ``SELECT COUNT(*)`` for *table*, schema-qualified if given."""
        cur = self.con.cursor()
        statement = 'SELECT COUNT(*) FROM %s' % self._table_name(schema, table)
        self._exec_sql(cur, statement)
        return cur.fetchone()[0]

    def get_table_fields(self, table, schema=None):
        """Return list of columns in table.

        :param table: table name, matched against ``pg_class.relname``.
        :param schema: optional schema name; when None the table is looked
            up in every schema.
        :return: list of ``TableAttribute`` built from rows of
            (attnum, attname, type name, attlen, atttypmod, attnotnull,
            atthasdef, default expression), ordered by attnum.
        """

        c = self.con.cursor()
        schema_where = (" AND nspname='%s' " % self._quote_unicode(schema)
                        if schema is not None else '')
        # pg_attrdef.adsrc was removed in PostgreSQL 12; pg_get_expr()
        # decompiles the stored default and also works on older servers.
        # Dropped columns remain in pg_attribute (attisdropped) and must not
        # be reported as fields.
        sql = """SELECT a.attnum AS ordinal_position,
                        a.attname AS column_name,
                        t.typname AS data_type,
                        a.attlen AS char_max_len,
                        a.atttypmod AS modifier,
                        a.attnotnull AS notnull,
                        a.atthasdef AS hasdefault,
                        pg_get_expr(adef.adbin, adef.adrelid) AS default_value
              FROM pg_class c
              JOIN pg_attribute a ON a.attrelid = c.oid
              JOIN pg_type t ON a.atttypid = t.oid
              JOIN pg_namespace nsp ON c.relnamespace = nsp.oid
              LEFT JOIN pg_attrdef adef ON adef.adrelid = a.attrelid
                  AND adef.adnum = a.attnum
              WHERE
                  c.relname = '%s' %s AND
                  a.attnum > 0 AND
                  NOT a.attisdropped
              ORDER BY a.attnum""" \
              % (self._quote_unicode(table), schema_where)

        self._exec_sql(c, sql)
        return [TableAttribute(row) for row in c.fetchall()]

    def get_table_indexes(self, table, schema=None):
        """Get info about table's indexes as a list of ``TableIndex``.

        Unique and primary-key indexes are filtered out (``indisunique`` /
        ``indisprimary``) because they get listed in constraints.  Each
        row is (index relname, indkey).
        """

        if schema is None:
            schema_where = ''
        else:
            schema_where = " AND nspname='%s' " % self._quote_unicode(schema)
        sql = """SELECT relname, indkey
              FROM pg_class, pg_index
              WHERE pg_class.oid = pg_index.indexrelid AND pg_class.oid IN (
                     SELECT indexrelid
                     FROM pg_index, pg_class
                     JOIN pg_namespace nsp ON pg_class.relnamespace = nsp.oid
                     WHERE pg_class.relname='%s' %s AND
                         pg_class.oid=pg_index.indrelid
                         AND indisunique != 't' AND indisprimary != 't' )""" \
              % (self._quote_unicode(table), schema_where)
        cursor = self.con.cursor()
        self._exec_sql(cursor, sql)
        return [TableIndex(row) for row in cursor.fetchall()]

    def get_table_constraints(self, table, schema=None):
        """Return the constraints of a table as a list of ``TableConstraint``.

        :param table: table name, matched against ``pg_class.relname``.
        :param schema: optional schema name; when None every schema is
            searched.
        :return: one ``TableConstraint`` per row of (conname, contype,
            condeferrable, condeferred, conkey as text, check expression,
            referenced relname, confupdtype, confdeltype, confmatchtype,
            confkey as text).
        """
        c = self.con.cursor()

        schema_where = (" AND nspname='%s' " % self._quote_unicode(schema)
                        if schema is not None else '')
        # pg_constraint.consrc was removed in PostgreSQL 12;
        # pg_get_expr(conbin, conrelid) yields the same expression text
        # (NULL when the constraint has no expression) on older servers too.
        sql = """SELECT c.conname, c.contype, c.condeferrable, c.condeferred,
                        array_to_string(c.conkey, ' '),
                        pg_get_expr(c.conbin, c.conrelid), t2.relname,
                        c.confupdtype, c.confdeltype, c.confmatchtype,
                        array_to_string(c.confkey, ' ')
              FROM pg_constraint c
              LEFT JOIN pg_class t ON c.conrelid = t.oid
              LEFT JOIN pg_class t2 ON c.confrelid = t2.oid
              JOIN pg_namespace nsp ON t.relnamespace = nsp.oid
              WHERE t.relname = '%s' %s """ \
              % (self._quote_unicode(table), schema_where)

        self._exec_sql(c, sql)
        return [TableConstraint(row) for row in c.fetchall()]

    def get_view_definition(self, view, schema=None):
        """Return the ``pg_get_viewdef`` text of a view or materialized view.

        When *schema* is given, only that schema is searched.
        """

        quoted_view = self._quote_unicode(view)
        schema_where = ''
        if schema is not None:
            schema_where = " AND nspname='%s' " % self._quote_unicode(schema)
        sql = """SELECT pg_get_viewdef(c.oid)
              FROM pg_class c
              JOIN pg_namespace nsp ON c.relnamespace = nsp.oid
              WHERE relname='%s' %s AND relkind IN ('v','m')""" \
              % (quoted_view, schema_where)
        cursor = self.con.cursor()
        self._exec_sql(cursor, sql)
        first_row = cursor.fetchone()
        return first_row[0]

    def add_geometry_column(self, table, geom_type, schema=None,
                            geom_column='the_geom', srid=-1, dim=2):
        """Add a geometry column through PostGIS ``AddGeometryColumn()``.

        The schema argument is passed only when *schema* is truthy.  The
        statement is committed immediately.
        """
        arguments = []
        # Use schema if explicitly specified
        if schema:
            arguments.append("'%s'" % self._quote_unicode(schema))
        arguments.extend([
            "'%s'" % self._quote_unicode(table),
            "'%s'" % self._quote_unicode(geom_column),
            '%d' % srid,
            "'%s'" % self._quote_unicode(geom_type),
            '%d' % dim,
        ])
        self._exec_sql_and_commit(
            'SELECT AddGeometryColumn(%s)' % ', '.join(arguments))

    def delete_geometry_column(self, table, geom_column, schema=None):
        """Drop a geometry column through PostGIS ``DropGeometryColumn``.

        The schema argument is included in the call only when ``schema`` is
        truthy.
        """

        prefix = ("'%s', " % self._quote_unicode(schema)) if schema else ''
        values = (prefix,
                  self._quote_unicode(table),
                  self._quote_unicode(geom_column))
        self._exec_sql_and_commit(
            "SELECT DropGeometryColumn(%s'%s', '%s')" % values)

    def delete_geometry_table(self, table, schema=None):
        """Drop a table holding geometries via PostGIS ``DropGeometryTable``.

        The schema argument is included in the call only when ``schema`` is
        truthy.
        """

        if not schema:
            prefix = ''
        else:
            prefix = "'%s', " % self._quote_unicode(schema)
        quoted_table = self._quote_unicode(table)
        self._exec_sql_and_commit(
            "SELECT DropGeometryTable(%s'%s')" % (prefix, quoted_table))

    def create_table(self, table, fields, pkey=None, schema=None):
        """Create an ordinary table.

        :param fields: sequence of TableField instances; each contributes the
            text returned by its ``field_def()``
        :param pkey: optional name of a column used as the primary key
        :returns: False (and executes nothing) when ``fields`` is empty,
            otherwise True after the CREATE TABLE statement is committed
        """

        if not fields:
            return False

        table_name = self._table_name(schema, table)
        columns = [field.field_def() for field in fields]
        if pkey:
            columns.append('PRIMARY KEY (%s)' % self._quote(pkey))
        self._exec_sql_and_commit(
            'CREATE TABLE %s (%s)' % (table_name, ', '.join(columns)))
        return True

    def delete_table(self, table, schema=None):
        """Drop ``table`` (schema-qualified when ``schema`` is given) and commit."""

        self._exec_sql_and_commit(
            'DROP TABLE %s' % self._table_name(schema, table))

    def empty_table(self, table, schema=None):
        """Remove every row from ``table`` with an unfiltered DELETE and commit."""

        target = self._table_name(schema, table)
        self._exec_sql_and_commit('DELETE FROM %s' % target)

    def rename_table(self, table, new_table, schema=None):
        """Rename a table and, with PostGIS, its ``geometry_columns`` rows.

        :param table: current table name
        :param new_table: new table name; only the bare name is passed to
            ``RENAME TO``
        :param schema: optional schema of the table; when not None it also
            narrows the geometry_columns update
        """

        old_name = self._table_name(schema, table)
        self._exec_sql_and_commit(
            'ALTER TABLE %s RENAME TO %s' % (old_name, self._quote(new_table)))

        if not self.has_postgis:
            return

        # Keep the PostGIS metadata pointing at the renamed table.
        sql = "UPDATE geometry_columns SET f_table_name='%s' \
                   WHERE f_table_name='%s'" \
            % (self._quote_unicode(new_table), self._quote_unicode(table))
        if schema is not None:
            sql += " AND f_table_schema='%s'" % self._quote_unicode(schema)
        self._exec_sql_and_commit(sql)

    def create_view(self, name, query, schema=None):
        """Create view ``name`` defined by the raw SQL ``query`` and commit."""
        target = self._table_name(schema, name)
        self._exec_sql_and_commit('CREATE VIEW %s AS %s' % (target, query))

    def delete_view(self, name, schema=None):
        """Drop view ``name`` (schema-qualified when given) and commit."""
        self._exec_sql_and_commit(
            'DROP VIEW %s' % self._table_name(schema, name))

    def rename_view(self, name, new_name, schema=None):
        """Rename a view by delegating to ``rename_table`` with the same arguments."""

        self.rename_table(table=name, new_table=new_name, schema=schema)

    def create_schema(self, schema):
        """Create a new, empty schema named ``schema`` and commit."""

        quoted = self._quote(schema)
        self._exec_sql_and_commit('CREATE SCHEMA %s' % quoted)

    def delete_schema(self, schema):
        """Drop ``schema`` (plain DROP SCHEMA, no CASCADE) and commit."""

        self._exec_sql_and_commit('DROP SCHEMA %s' % self._quote(schema))

    def rename_schema(self, schema, new_schema):
        """Rename a schema and, with PostGIS, repoint ``geometry_columns`` rows."""

        old_q, new_q = self._quote(schema), self._quote(new_schema)
        self._exec_sql_and_commit(
            'ALTER SCHEMA %s RENAME TO %s' % (old_q, new_q))

        if self.has_postgis:
            # Metadata rows still reference the old schema name.
            sql = "UPDATE geometry_columns SET f_table_schema='%s' \
                 WHERE f_table_schema='%s'" \
                % (self._quote_unicode(new_schema), self._quote_unicode(schema))
            self._exec_sql_and_commit(sql)

    def table_add_column(self, table, field, schema=None):
        """Add a column described by ``field`` (a TableField) to a table.

        The column definition is whatever ``field.field_def()`` returns.
        """

        definition = field.field_def()
        self._exec_sql_and_commit(
            'ALTER TABLE %s ADD %s' % (self._table_name(schema, table), definition))

    def table_delete_column(self, table, field, schema=None):
        """Drop the column named ``field`` from a table and commit."""

        target = self._table_name(schema, table)
        column = self._quote(field)
        self._exec_sql_and_commit('ALTER TABLE %s DROP %s' % (target, column))

    def table_column_rename(self, table, name, new_name, schema=None):
        """Rename a column in a table and keep PostGIS metadata in sync.

        :param table: table containing the column
        :param name: current column name
        :param new_name: new column name
        :param schema: optional schema of the table; when not None it also
            narrows the ``geometry_columns`` update
        """

        table_name = self._table_name(schema, table)
        sql = 'ALTER TABLE %s RENAME %s TO %s' % (table_name,
                                                  self._quote(name), self._quote(new_name))
        self._exec_sql_and_commit(sql)

        # Update geometry_columns if PostGIS is enabled
        if self.has_postgis:
            sql = "UPDATE geometry_columns SET f_geometry_column='%s' \
                   WHERE f_geometry_column='%s' AND f_table_name='%s'" \
                   % (self._quote_unicode(new_name), self._quote_unicode(name),
                      self._quote_unicode(table))
            if schema is not None:
                # The schema is compared as a string literal, so it must be
                # escaped with _quote_unicode (doubling single quotes) like
                # the other literals here. Identifier-quoting it with _quote
                # wrapped names such as "My Schema" in double quotes, so the
                # row never matched. A single quote in the name also broke
                # the statement.
                sql += " AND f_table_schema='%s'" % self._quote_unicode(schema)
            self._exec_sql_and_commit(sql)

    def table_column_set_type(self, table, column, data_type, schema=None):
        """Change a column's type; ``data_type`` is inserted into the SQL verbatim."""

        parts = (self._table_name(schema, table), self._quote(column), data_type)
        self._exec_sql_and_commit('ALTER TABLE %s ALTER %s TYPE %s' % parts)

    def table_column_set_default(self, table, column, default, schema=None):
        """Change column's default value.

        If ``default`` is None (or an empty string) the default is dropped.
        Any other value, including falsy ones such as ``0`` or ``False``, is
        inserted verbatim as the new default expression.
        """

        table_name = self._table_name(schema, table)
        quoted_column = self._quote(column)
        # A plain truthiness test would treat 0/False as "drop the default",
        # contradicting the None contract; '' keeps dropping since it cannot
        # form a valid DEFAULT expression anyway.
        if default is None or default == '':
            sql = 'ALTER TABLE %s ALTER %s DROP DEFAULT' % (table_name,
                                                            quoted_column)
        else:
            sql = 'ALTER TABLE %s ALTER %s SET DEFAULT %s' % (table_name,
                                                              quoted_column, default)
        self._exec_sql_and_commit(sql)

    def table_column_set_null(self, table, column, is_null, schema=None):
        """Allow (``is_null`` truthy) or forbid NULLs in a column."""

        prefix = 'ALTER TABLE %s ALTER %s ' % (self._table_name(schema, table),
                                               self._quote(column))
        action = 'DROP NOT NULL' if is_null else 'SET NOT NULL'
        self._exec_sql_and_commit(prefix + action)

    def table_add_primary_key(self, table, column, schema=None):
        """Add a single-column primary key on ``column`` to a table."""

        target = self._table_name(schema, table)
        key_column = self._quote(column)
        self._exec_sql_and_commit(
            'ALTER TABLE %s ADD PRIMARY KEY (%s)' % (target, key_column))

    def table_add_unique_constraint(self, table, column, schema=None):
        """Add an unnamed UNIQUE constraint on one column of a table."""

        self._exec_sql_and_commit('ALTER TABLE %s ADD UNIQUE (%s)' % (
            self._table_name(schema, table), self._quote(column)))

    def table_delete_constraint(self, table, constraint, schema=None):
        """Drop the named ``constraint`` from a table and commit."""

        names = (self._table_name(schema, table), self._quote(constraint))
        self._exec_sql_and_commit('ALTER TABLE %s DROP CONSTRAINT %s' % names)

    def table_move_to_schema(self, table, new_schema, schema=None):
        """Move a table into ``new_schema``; no-op when it already lives there.

        With PostGIS, matching ``geometry_columns`` rows are updated as well
        (narrowed by the old schema when ``schema`` is not None).
        """
        if new_schema == schema:
            return

        source_name = self._table_name(schema, table)
        self._exec_sql_and_commit(
            'ALTER TABLE %s SET SCHEMA %s' % (source_name, self._quote(new_schema)))

        if not self.has_postgis:
            return

        sql = "UPDATE geometry_columns SET f_table_schema='%s' \
                   WHERE f_table_name='%s'" \
            % (self._quote_unicode(new_schema), self._quote_unicode(table))
        if schema is not None:
            sql += " AND f_table_schema='%s'" % self._quote_unicode(schema)
        self._exec_sql_and_commit(sql)

    def create_index(self, table, name, column, schema=None):
        """Create index ``name`` on a single column with default options."""

        target = self._table_name(schema, table)
        sql = 'CREATE INDEX %s ON %s (%s)' % (self._quote(name), target,
                                              self._quote(column))
        self._exec_sql_and_commit(sql)

    def create_spatial_index(self, table, schema=None, geom_column='the_geom'):
        """Create a GIST index named ``sidx_<table>_<geom_column>`` on the geometry."""
        target = self._table_name(schema, table)
        index_name = self._quote(u"sidx_%s_%s" % (table, geom_column))
        self._exec_sql_and_commit('CREATE INDEX %s ON %s USING GIST(%s)' % (
            index_name, target, self._quote(geom_column)))

    def delete_index(self, name, schema=None):
        """Drop index ``name`` (schema-qualified when given) and commit."""
        self._exec_sql_and_commit(
            'DROP INDEX %s' % self._table_name(schema, name))

    def get_database_privileges(self):
        """Return database-level privileges of the current user.

        :returns: the single result row ``(can create schemas, can create
            temp. tables)`` from ``has_database_privilege`` for CREATE and
            TEMP on the connected database
        """

        params = {'d': self._quote_unicode(self.uri.database())}
        sql = "SELECT has_database_privilege('%(d)s', 'CREATE'), \
                      has_database_privilege('%(d)s', 'TEMP')" % params
        cursor = self.con.cursor()
        self._exec_sql(cursor, sql)
        return cursor.fetchone()

    def get_schema_privileges(self, schema):
        """Return schema privileges of the current user.

        :returns: the single result row ``(can create new objects, can
            access objects in schema)`` from ``has_schema_privilege`` for
            CREATE and USAGE
        """

        params = {'s': self._quote_unicode(schema)}
        sql = "SELECT has_schema_privilege('%(s)s', 'CREATE'), \
                      has_schema_privilege('%(s)s', 'USAGE')" % params
        cursor = self.con.cursor()
        self._exec_sql(cursor, sql)
        return cursor.fetchone()

    def get_table_privileges(self, table, schema=None):
        """Return table privileges of the current user.

        The quoted, optionally schema-qualified table name is passed as a
        string literal to ``has_table_privilege``.

        :returns: the single result row ``(select, insert, update, delete)``
        """

        params = {'t': self._quote_unicode(self._table_name(schema, table))}
        sql = """SELECT has_table_privilege('%(t)s', 'SELECT'),
                        has_table_privilege('%(t)s', 'INSERT'),
                        has_table_privilege('%(t)s', 'UPDATE'),
                        has_table_privilege('%(t)s', 'DELETE')""" % params
        cursor = self.con.cursor()
        self._exec_sql(cursor, sql)
        return cursor.fetchone()

    def vacuum_analyze(self, table, schema=None):
        """Run VACUUM ANALYZE on a table.

        VACUUM ANALYZE must run outside a transaction block, so the
        connection is switched to autocommit for the statement and then set
        back to READ COMMITTED. The restore happens in a ``finally`` so that
        a failing VACUUM does not leave the connection in autocommit mode.
        """

        t = self._table_name(schema, table)

        # VACUUM ANALYZE must be run outside transaction block - we
        # have to change isolation level
        self.con.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
        try:
            c = self.con.cursor()
            self._exec_sql(c, 'VACUUM ANALYZE %s' % t)
        finally:
            self.con.set_isolation_level(
                psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED)

    def sr_info_for_srid(self, srid):
        """Return a short description of spatial reference ``srid``.

        Reads ``srtext`` from ``spatial_ref_sys``. When it contains a
        double-quoted token, the first such token is returned with its quotes
        included; otherwise the full ``srtext`` is returned.

        :returns: the description, or 'Unknown' when PostGIS is not
            available, the query fails, or no (non-NULL) row exists for
            ``srid``
        """
        if not self.has_postgis:
            return 'Unknown'

        try:
            c = self.con.cursor()
            self._exec_sql(c,
                           "SELECT srtext FROM spatial_ref_sys WHERE srid = '%d'"
                           % srid)
            row = c.fetchone()
        except (DbError, QgsProcessingException):
            # _exec_sql reports driver errors as QgsProcessingException, so
            # catching DbError alone never turned a failed query into
            # 'Unknown'.
            return 'Unknown'

        if row is None or row[0] is None:
            # Unknown SRID: previously crashed with TypeError on None[0].
            return 'Unknown'
        srtext = row[0]

        # Try to extract just SR name (should be quoted in double
        # quotes)
        x = re.search(r'"([^"]+)"', srtext)
        if x is not None:
            srtext = x.group()
        return srtext

    def insert_table_row(self, table, values, schema=None, cursor=None):
        """Insert a row with specified values to a table.

        ``values`` are SQL literal strings inserted verbatim (they are NOT
        quoted or escaped here) and separated by ', '.

        If a cursor is specified, it doesn't commit (expecting that
        there will be more inserts) otherwise it commits immediately.
        """

        t = self._table_name(schema, table)
        # TODO: quote values?
        # join keeps one separator per value. The previous incremental build
        # skipped separators after leading empty strings, silently shifting
        # the remaining values into the wrong columns.
        sql = 'INSERT INTO %s VALUES (%s)' % (t, ', '.join(values))
        if cursor is not None:
            self._exec_sql(cursor, sql)
        else:
            self._exec_sql_and_commit(sql)

    def _exec_sql(self, cursor, sql):
        """Execute ``sql`` on ``cursor``.

        :raises QgsProcessingException: when psycopg2 raises. The message is
            the driver error text followed by ' QUERY: ' and the query that
            was sent, or ``sql`` itself when no query text is available.
        """
        try:
            cursor.execute(sql)
        except psycopg2.Error as e:
            # The failing cursor is the one we were given. e.cursor may be
            # None and cursor.query may be None (nothing sent yet), so do not
            # dereference them blindly: that would mask the real DB error.
            query = getattr(cursor, 'query', None)
            if query is None:
                query = sql
            elif isinstance(query, bytes):
                # cursor.query is encoded with the connection's PostgreSQL
                # encoding, whose name is not always a Python codec name.
                pg_encoding = cursor.connection.encoding
                codec = psycopg2.extensions.encodings.get(pg_encoding,
                                                          pg_encoding)
                try:
                    query = query.decode(codec, 'replace')
                except LookupError:
                    query = query.decode('utf-8', 'replace')
            raise QgsProcessingException(
                str(e) + ' QUERY: ' + str(query)) from e

    def _exec_sql_and_commit(self, sql):
        """Execute ``sql`` on a new cursor and commit.

        If execution or the commit fails, the transaction is rolled back and
        the original exception is re-raised unchanged.
        """

        try:
            c = self.con.cursor()
            self._exec_sql(c, sql)
            self.con.commit()
        except (DbError, QgsProcessingException, psycopg2.Error):
            # _exec_sql reports failures as QgsProcessingException and
            # commit() raises psycopg2 errors directly. Catching only DbError
            # meant the rollback never ran, which could leave the connection
            # inside a failed transaction.
            self.con.rollback()
            raise

    def _quote(self, identifier):
        """Return ``identifier`` in a form safe to embed as an SQL identifier.

        The value is converted with ``str()``. If it matches
        ``self.re_ident_ok`` it is returned unchanged; otherwise it is
        wrapped in double quotes with any embedded double quote doubled.
        """

        name = str(identifier)
        if self.re_ident_ok.match(name) is None:
            # Quoting required - escape embedded double quotes by doubling.
            return u'"%s"' % name.replace('"', '""')
        return name

    def _quote_unicode(self, txt):
        """Escape ``txt`` for use inside a single-quoted SQL string literal.

        The value is converted with ``str()`` and every ``'`` is doubled.
        """

        return "''".join(str(txt).split("'"))

    def _table_name(self, schema, table):
        """Return ``table`` quoted, prefixed with the quoted schema when truthy."""
        if schema:
            return u'%s.%s' % (self._quote(schema), self._quote(table))
        return self._quote(table)
示例#14
0
    def ogrConnectionStringAndFormat(uri, context):
        """Generates OGR connection string and format string from layer source.

        Returned values are a tuple of the connection string and format
        string, each wrapped in double quotes.

        :param uri: layer reference resolved with
            QgsProcessingUtils.mapLayerFromString; when no layer is found it
            is treated as a file path and the driver is guessed from its
            extension
        :param context: processing context used to resolve ``uri``
        :raises RuntimeError: if a PostgreSQL connection cannot be opened or
            an Oracle source has no connection details
        """
        layer = QgsProcessingUtils.mapLayerFromString(uri, context, False)
        if layer is None:
            ext = os.path.splitext(uri)[1]
            fmt = QgsVectorFileWriter.driverForExtension(ext)
            return '"' + uri + '"', '"' + fmt + '"'

        provider = layer.dataProvider().name()
        if provider == 'spatialite':
            # dbname='/geodata/osm_ch.sqlite' table="places" (Geometry) sql=
            # Non-greedy: stop at the first closing quote so later quoted
            # values (e.g. a subset filter after sql=) are not swallowed
            # into the database path.
            regex = re.compile("dbname='(.+?)'")
            r = regex.search(str(layer.source()))
            ogrstr = r.groups()[0]
            fmt = 'SQLite'
        elif provider == 'postgres':
            # dbname='ktryjh_iuuqef' host=spacialdb.com port=9999
            # user='******' password='******' sslmode=disable
            # key='gid' estimatedmetadata=true srid=4326 type=MULTIPOLYGON
            # table="t4" (geom) sql=
            dsUri = QgsDataSourceUri(layer.dataProvider().dataSourceUri())
            conninfo = dsUri.connectionInfo()
            conn = None
            ok = False
            # Keep asking for credentials until a connection succeeds or the
            # user cancels the credentials dialog.
            while not conn:
                try:
                    conn = psycopg2.connect(dsUri.connectionInfo())
                except psycopg2.OperationalError:
                    (ok, user, passwd) = QgsCredentials.instance().get(
                        conninfo, dsUri.username(), dsUri.password())
                    if not ok:
                        break

                    dsUri.setUsername(user)
                    dsUri.setPassword(passwd)

            if not conn:
                raise RuntimeError(
                    'Could not connect to PostgreSQL database - check connection info'
                )

            # The connection only served to validate the credentials; close
            # it instead of leaking it.
            conn.close()

            if ok:
                QgsCredentials.instance().put(conninfo, user, passwd)

            ogrstr = "PG:%s" % dsUri.connectionInfo()
            fmt = 'PostgreSQL'
        elif provider == "oracle":
            # OCI:user/password@host:port/service:table
            dsUri = QgsDataSourceUri(layer.dataProvider().dataSourceUri())
            ogrstr = "OCI:"
            # No separator is needed before host/database when there is no
            # user name; previously delim was unbound in that case
            # (UnboundLocalError).
            delim = ""
            if dsUri.username() != "":
                ogrstr += dsUri.username()
                if dsUri.password() != "":
                    ogrstr += "/" + dsUri.password()
                delim = "@"

            if dsUri.host() != "":
                ogrstr += delim + dsUri.host()
                delim = ""
                if dsUri.port() != "" and dsUri.port() != '1521':
                    ogrstr += ":" + dsUri.port()
                ogrstr += "/"
                if dsUri.database() != "":
                    ogrstr += dsUri.database()
            elif dsUri.database() != "":
                ogrstr += delim + dsUri.database()

            if ogrstr == "OCI:":
                raise RuntimeError(
                    'Invalid oracle data source - check connection info')

            ogrstr += ":"
            if dsUri.schema() != "":
                ogrstr += dsUri.schema() + "."

            ogrstr += dsUri.table()
            fmt = 'OCI'
        else:
            # File-based source: drop any "|layername=..." style suffix.
            ogrstr = str(layer.source()).split("|")[0]
            ext = os.path.splitext(ogrstr)[1]
            fmt = QgsVectorFileWriter.driverForExtension(ext)

        return '"' + ogrstr + '"', '"' + fmt + '"'
示例#15
0
def ogrConnectionString(uri):
    """Generates OGR connection string from layer source.

    :param uri: layer reference resolved with dataobjects.getObjectFromUri;
        when no layer is found the uri itself is returned (quoted)
    :returns: the connection string wrapped in double quotes
    :raises RuntimeError: if a PostgreSQL connection cannot be opened or an
        Oracle source has no connection details
    """
    layer = dataobjects.getObjectFromUri(uri, False)
    if layer is None:
        return '"' + uri + '"'
    provider = layer.dataProvider().name()
    if provider == "spatialite":
        # dbname='/geodata/osm_ch.sqlite' table="places" (Geometry) sql=
        # Non-greedy: stop at the first closing quote so later quoted values
        # (e.g. a subset filter after sql=) are not swallowed into the path.
        regex = re.compile("dbname='(.+?)'")
        r = regex.search(str(layer.source()))
        ogrstr = r.groups()[0]
    elif provider == "postgres":
        # dbname='ktryjh_iuuqef' host=spacialdb.com port=9999
        # user='******' password='******' sslmode=disable
        # key='gid' estimatedmetadata=true srid=4326 type=MULTIPOLYGON
        # table="t4" (geom) sql=
        dsUri = QgsDataSourceUri(layer.dataProvider().dataSourceUri())
        conninfo = dsUri.connectionInfo()
        conn = None
        ok = False
        # Keep asking for credentials until a connection succeeds or the
        # user cancels the credentials dialog.
        while not conn:
            try:
                conn = psycopg2.connect(dsUri.connectionInfo())
            except psycopg2.OperationalError:
                (ok, user, passwd) = QgsCredentials.instance().get(conninfo, dsUri.username(), dsUri.password())
                if not ok:
                    break

                dsUri.setUsername(user)
                dsUri.setPassword(passwd)

        if not conn:
            raise RuntimeError("Could not connect to PostgreSQL database - check connection info")

        # The connection only served to validate the credentials; close it
        # instead of leaking it.
        conn.close()

        if ok:
            QgsCredentials.instance().put(conninfo, user, passwd)

        ogrstr = "PG:%s" % dsUri.connectionInfo()
    elif provider == "oracle":
        # OCI:user/password@host:port/service:table
        dsUri = QgsDataSourceUri(layer.dataProvider().dataSourceUri())
        ogrstr = "OCI:"
        # No separator before host/database when there is no user name;
        # previously delim was unbound in that case (UnboundLocalError).
        delim = ""
        if dsUri.username() != "":
            ogrstr += dsUri.username()
            if dsUri.password() != "":
                ogrstr += "/" + dsUri.password()
            delim = "@"

        if dsUri.host() != "":
            ogrstr += delim + dsUri.host()
            delim = ""
            if dsUri.port() != "" and dsUri.port() != "1521":
                ogrstr += ":" + dsUri.port()
            ogrstr += "/"
            if dsUri.database() != "":
                ogrstr += dsUri.database()
        elif dsUri.database() != "":
            ogrstr += delim + dsUri.database()

        if ogrstr == "OCI:":
            raise RuntimeError("Invalid oracle data source - check connection info")

        ogrstr += ":"
        if dsUri.schema() != "":
            ogrstr += dsUri.schema() + "."

        ogrstr += dsUri.table()
    else:
        # File-based source: drop any "|layername=..." style suffix.
        ogrstr = str(layer.source()).split("|")[0]

    return '"' + ogrstr + '"'
示例#16
0
def ogrConnectionString(uri):
    """Generates OGR connection string from layer source.

    :param uri: layer reference resolved with dataobjects.getObjectFromUri;
        when no layer is found the uri itself is returned (quoted)
    :returns: the connection string wrapped in double quotes
    :raises RuntimeError: if a PostgreSQL connection cannot be opened or an
        Oracle source has no connection details
    """
    layer = dataobjects.getObjectFromUri(uri, False)
    if layer is None:
        return '"' + uri + '"'
    provider = layer.dataProvider().name()
    if provider == 'spatialite':
        # dbname='/geodata/osm_ch.sqlite' table="places" (Geometry) sql=
        # Non-greedy: stop at the first closing quote so later quoted values
        # (e.g. a subset filter after sql=) are not swallowed into the path.
        regex = re.compile("dbname='(.+?)'")
        r = regex.search(unicode(layer.source()))
        ogrstr = r.groups()[0]
    elif provider == 'postgres':
        # dbname='ktryjh_iuuqef' host=spacialdb.com port=9999
        # user='******' password='******' sslmode=disable
        # key='gid' estimatedmetadata=true srid=4326 type=MULTIPOLYGON
        # table="t4" (geom) sql=
        dsUri = QgsDataSourceUri(layer.dataProvider().dataSourceUri())
        conninfo = dsUri.connectionInfo()
        conn = None
        ok = False
        # Keep asking for credentials until a connection succeeds or the
        # user cancels the credentials dialog.
        while not conn:
            try:
                conn = psycopg2.connect(dsUri.connectionInfo())
            except psycopg2.OperationalError:
                (ok, user, passwd) = QgsCredentials.instance().get(
                    conninfo, dsUri.username(), dsUri.password())
                if not ok:
                    break

                dsUri.setUsername(user)
                dsUri.setPassword(passwd)

        if not conn:
            raise RuntimeError(
                'Could not connect to PostgreSQL database - check connection info'
            )

        # The connection only served to validate the credentials; close it
        # instead of leaking it.
        conn.close()

        if ok:
            QgsCredentials.instance().put(conninfo, user, passwd)

        ogrstr = "PG:%s" % dsUri.connectionInfo()
    elif provider == "oracle":
        # OCI:user/password@host:port/service:table
        dsUri = QgsDataSourceUri(layer.dataProvider().dataSourceUri())
        ogrstr = "OCI:"
        # No separator before host/database when there is no user name;
        # previously delim was unbound in that case (UnboundLocalError).
        delim = ""
        if dsUri.username() != "":
            ogrstr += dsUri.username()
            if dsUri.password() != "":
                ogrstr += "/" + dsUri.password()
            delim = "@"

        if dsUri.host() != "":
            ogrstr += delim + dsUri.host()
            delim = ""
            if dsUri.port() != "" and dsUri.port() != '1521':
                ogrstr += ":" + dsUri.port()
            ogrstr += "/"
            if dsUri.database() != "":
                ogrstr += dsUri.database()
        elif dsUri.database() != "":
            ogrstr += delim + dsUri.database()

        if ogrstr == "OCI:":
            raise RuntimeError(
                'Invalid oracle data source - check connection info')

        ogrstr += ":"
        if dsUri.schema() != "":
            ogrstr += dsUri.schema() + "."

        ogrstr += dsUri.table()
    else:
        # File-based source: drop any "|layername=..." style suffix.
        ogrstr = unicode(layer.source()).split("|")[0]

    return '"' + ogrstr + '"'
示例#17
0
    def ogrConnectionStringAndFormatFromLayer(layer):
        """Build an OGR connection string and driver name for a loaded layer.

        :param layer: a vector layer whose data provider name selects the
            branch (spatialite, postgres, mssql, oracle, or file-based)
        :returns: tuple ``(connection string, driver name)``; only the driver
            name is wrapped in double quotes
        :raises RuntimeError: if a PostgreSQL connection cannot be opened or
            an Oracle source has no connection details
        """
        provider = layer.dataProvider().name()
        if provider == 'spatialite':
            # dbname='/geodata/osm_ch.sqlite' table="places" (Geometry) sql=
            # Non-greedy: stop at the first closing quote so later quoted
            # values (e.g. a subset filter after sql=) are not swallowed
            # into the database path.
            regex = re.compile("dbname='(.+?)'")
            r = regex.search(str(layer.source()))
            ogrstr = r.groups()[0]
            fmt = 'SQLite'
        elif provider == 'postgres':
            # dbname='ktryjh_iuuqef' host=spacialdb.com port=9999
            # user='******' password='******' sslmode=disable
            # key='gid' estimatedmetadata=true srid=4326 type=MULTIPOLYGON
            # table="t4" (geom) sql=
            dsUri = QgsDataSourceUri(layer.dataProvider().dataSourceUri())
            conninfo = dsUri.connectionInfo()
            conn = None
            ok = False
            # Keep asking for credentials until a connection succeeds or the
            # user cancels the credentials dialog.
            while not conn:
                try:
                    conn = psycopg2.connect(dsUri.connectionInfo())
                except psycopg2.OperationalError:
                    (ok, user, passwd) = QgsCredentials.instance().get(
                        conninfo, dsUri.username(), dsUri.password())
                    if not ok:
                        break

                    dsUri.setUsername(user)
                    dsUri.setPassword(passwd)

            if not conn:
                raise RuntimeError(
                    'Could not connect to PostgreSQL database - check connection info'
                )

            # The connection only served to validate the credentials; close
            # it instead of leaking it.
            conn.close()

            if ok:
                QgsCredentials.instance().put(conninfo, user, passwd)

            ogrstr = "PG:%s" % dsUri.connectionInfo()
            fmt = 'PostgreSQL'
        elif provider == 'mssql':
            #'dbname=\'db_name\' host=myHost estimatedmetadata=true
            # srid=27700 type=MultiPolygon table="dbo"."my_table"
            # #(Shape) sql='
            dsUri = layer.dataProvider().uri()
            ogrstr = 'MSSQL:'
            ogrstr += 'database={0};'.format(dsUri.database())
            ogrstr += 'server={0};'.format(dsUri.host())
            if dsUri.username() != "":
                ogrstr += 'uid={0};'.format(dsUri.username())
            else:
                # No user name: fall back to integrated authentication.
                ogrstr += 'trusted_connection=yes;'
            if dsUri.password() != '':
                ogrstr += 'pwd={0};'.format(dsUri.password())
            ogrstr += 'tables={0}'.format(dsUri.table())
            fmt = 'MSSQL'
        elif provider == "oracle":
            # OCI:user/password@host:port/service:table
            dsUri = QgsDataSourceUri(layer.dataProvider().dataSourceUri())
            ogrstr = "OCI:"
            # No separator before host/database when there is no user name;
            # previously delim was unbound in that case (UnboundLocalError).
            delim = ""
            if dsUri.username() != "":
                ogrstr += dsUri.username()
                if dsUri.password() != "":
                    ogrstr += "/" + dsUri.password()
                delim = "@"

            if dsUri.host() != "":
                ogrstr += delim + dsUri.host()
                delim = ""
                if dsUri.port() != "" and dsUri.port() != '1521':
                    ogrstr += ":" + dsUri.port()
                ogrstr += "/"
                if dsUri.database() != "":
                    ogrstr += dsUri.database()
            elif dsUri.database() != "":
                ogrstr += delim + dsUri.database()

            if ogrstr == "OCI:":
                raise RuntimeError(
                    'Invalid oracle data source - check connection info')

            ogrstr += ":"
            if dsUri.schema() != "":
                ogrstr += dsUri.schema() + "."

            ogrstr += dsUri.table()
            fmt = 'OCI'
        else:
            # File-based source: drop any "|layername=..." style suffix.
            ogrstr = str(layer.source()).split("|")[0]
            ext = os.path.splitext(ogrstr)[1]
            fmt = QgsVectorFileWriter.driverForExtension(ext)

        return ogrstr, '"' + fmt + '"'
示例#18
0
    def processAlgorithm(self, parameters, context, feedback):
        """Import the INPUT feature source into a table of a Spatialite database.

        The target database comes from the DATABASE layer. The table is
        created (or overwritten) through QgsVectorLayerExporter, and every
        feature is copied while reporting progress. A spatial index is
        optionally built afterwards.

        :raises QgsProcessingException: if the exporter reports an error when
            the table is created or after the features are flushed
        :returns: an empty result dictionary
        """
        database = self.parameterAsVectorLayer(parameters, self.DATABASE,
                                               context)
        databaseuri = database.dataProvider().dataSourceUri()
        uri = QgsDataSourceUri(databaseuri)
        # Compare by value: ``is ''`` tests object identity, which is not
        # guaranteed for equal strings (and is a SyntaxWarning on 3.8+).
        if uri.database() == '':
            # OGR-style source ("path|layername=..."/"path|layerid=..."):
            # strip the layer suffix and build a dbname URI from the path.
            if '|layername' in databaseuri:
                databaseuri = databaseuri[:databaseuri.find('|layername')]
            elif '|layerid' in databaseuri:
                databaseuri = databaseuri[:databaseuri.find('|layerid')]
            uri = QgsDataSourceUri('dbname=\'%s\'' % (databaseuri))
        db = spatialite.GeoDB(uri)

        overwrite = self.parameterAsBool(parameters, self.OVERWRITE, context)
        createIndex = self.parameterAsBool(parameters, self.CREATEINDEX,
                                           context)
        convertLowerCase = self.parameterAsBool(parameters,
                                                self.LOWERCASE_NAMES, context)
        dropStringLength = self.parameterAsBool(parameters,
                                                self.DROP_STRING_LENGTH,
                                                context)
        forceSinglePart = self.parameterAsBool(parameters,
                                               self.FORCE_SINGLEPART, context)
        primaryKeyField = self.parameterAsString(parameters, self.PRIMARY_KEY,
                                                 context) or 'id'
        encoding = self.parameterAsString(parameters, self.ENCODING, context)

        source = self.parameterAsSource(parameters, self.INPUT, context)

        table = self.parameterAsString(parameters, self.TABLENAME, context)
        if table:
            # str.strip() returns a new string; the result used to be
            # discarded, so surrounding whitespace was never removed.
            table = table.strip()
        if not table:
            table = source.sourceName()
            table = table.replace('.', '_')
        table = table.replace(' ', '').lower()
        providerName = 'spatialite'

        geomColumn = self.parameterAsString(parameters, self.GEOMETRY_COLUMN,
                                            context)
        if not geomColumn:
            geomColumn = 'geom'

        options = {}
        if overwrite:
            options['overwrite'] = True
        if convertLowerCase:
            options['lowercaseFieldNames'] = True
            geomColumn = geomColumn.lower()
        if dropStringLength:
            options['dropStringConstraints'] = True
        if forceSinglePart:
            options['forceSinglePartGeometryType'] = True

        # Clear geometry column for non-geometry tables
        if source.wkbType() == QgsWkbTypes.NoGeometry:
            geomColumn = None

        uri = db.uri
        uri.setDataSource('', table, geomColumn, '', primaryKeyField)

        if encoding:
            options['fileEncoding'] = encoding

        exporter = QgsVectorLayerExporter(uri.uri(), providerName,
                                          source.fields(), source.wkbType(),
                                          source.sourceCrs(), overwrite,
                                          options)

        if exporter.errorCode() != QgsVectorLayerExporter.NoError:
            raise QgsProcessingException(
                self.tr('Error importing to Spatialite\n{0}').format(
                    exporter.errorMessage()))

        features = source.getFeatures()
        total = 100.0 / source.featureCount() if source.featureCount() else 0
        for current, f in enumerate(features):
            if feedback.isCanceled():
                break

            if not exporter.addFeature(f, QgsFeatureSink.FastInsert):
                feedback.reportError(exporter.errorMessage())

            feedback.setProgress(int(current * total))

        exporter.flushBuffer()
        if exporter.errorCode() != QgsVectorLayerExporter.NoError:
            raise QgsProcessingException(
                self.tr('Error importing to Spatialite\n{0}').format(
                    exporter.errorMessage()))

        if geomColumn and createIndex:
            db.create_spatial_index(table, geomColumn)

        return {}
示例#19
0
    def __init__(self,
                 destination,
                 encoding,
                 fields,
                 geometryType,
                 crs,
                 options=None):
        """Create the output layer/writer described by ``destination``.

        The destination string selects the backend:

        * starts with ``MEMORY_LAYER_PREFIX``: an in-memory ``memory`` layer
          is created and its data provider becomes ``self.writer``;
        * starts with ``POSTGIS_LAYER_PREFIX``: the rest is parsed as a
          ``QgsDataSourceUri``, a table is created in PostGIS and opened as a
          ``postgres`` layer whose data provider becomes ``self.writer``;
        * starts with ``SPATIALITE_LAYER_PREFIX``: same, but the table is
          dropped first if it exists and the layer uses ``spatialite``;
        * anything else: a file written with ``QgsVectorFileWriter``. The OGR
          driver is picked from the file extension, and ``.shp`` is appended
          when the extension is not recognised.

        :param destination: output destination string (see above).
        :param encoding: output encoding, only used for file outputs; when
            None, the ``/Processing/encoding`` setting (default ``'System'``)
            is used.
        :param fields: field definitions, converted with ``_toQgsField``.
        :param geometryType: WKB type of the output geometry
            (``QgsWkbTypes.NoGeometry`` for attribute-only tables).
        :param crs: output CRS.
        :param options: not used by this constructor.
        :raises GeoAlgorithmExecutionException: when database credentials or
            the connection are unavailable, when creating the output table
            fails, or when the file format cannot hold geometry-less tables.
        """
        self.destination = destination
        self.isNotFileBased = False
        self.layer = None
        self.writer = None

        if encoding is None:
            settings = QSettings()
            encoding = settings.value('/Processing/encoding',
                                      'System',
                                      type=str)

        if self.destination.startswith(self.MEMORY_LAYER_PREFIX):
            self.isNotFileBased = True

            # Memory provider URI: "<geomtype>?uuid=...&crs=...&field=name:type"
            uri = GEOM_TYPE_MAP[geometryType] + "?uuid=" + unicode(
                uuid.uuid4())
            if crs.isValid():
                uri += '&crs=' + crs.authid()
            fieldsdesc = []
            for f in fields:
                qgsfield = _toQgsField(f)
                fieldsdesc.append(
                    'field=%s:%s' %
                    (qgsfield.name(),
                     TYPE_MAP_MEMORY_LAYER.get(qgsfield.type(), "string")))
            if fieldsdesc:
                uri += '&' + '&'.join(fieldsdesc)

            self.layer = QgsVectorLayer(uri, self.destination, 'memory')
            self.writer = self.layer.dataProvider()
        elif self.destination.startswith(self.POSTGIS_LAYER_PREFIX):
            self.isNotFileBased = True
            uri = QgsDataSourceUri(
                self.destination[len(self.POSTGIS_LAYER_PREFIX):])
            connInfo = uri.connectionInfo()
            # Ask the credentials store for user/password and store them back
            # on success.
            (success, user,
             passwd) = QgsCredentials.instance().get(connInfo, None, None)
            if success:
                QgsCredentials.instance().put(connInfo, user, passwd)
            else:
                raise GeoAlgorithmExecutionException(
                    "Couldn't connect to database")
            # The URI is deliberately not printed: it may contain credentials.
            try:
                db = postgis.GeoDB(host=uri.host(),
                                   port=int(uri.port()),
                                   dbname=uri.database(),
                                   user=user,
                                   passwd=passwd)
            except postgis.DbError as e:
                raise GeoAlgorithmExecutionException(
                    "Couldn't connect to database:\n%s" % e.message)

            def _runSQL(sql):
                try:
                    db._exec_sql_and_commit(unicode(sql))
                except postgis.DbError as e:
                    raise GeoAlgorithmExecutionException(
                        'Error creating output PostGIS table:\n%s' % e.message)

            fields = [_toQgsField(f) for f in fields]
            fieldsdesc = ",".join(
                '%s %s' %
                (f.name(), TYPE_MAP_POSTGIS_LAYER.get(f.type(), "VARCHAR"))
                for f in fields)

            table = uri.table().lower()
            _runSQL("CREATE TABLE %s.%s (%s)" %
                    (uri.schema(), table, fieldsdesc))
            # geometryType is a WKB type (QgsWkbTypes.Type), so compare it
            # with NoGeometry; NullGeometry belongs to the GeometryType enum
            # and its integer value collides with a real WKB type.
            if geometryType != QgsWkbTypes.NoGeometry:
                _runSQL(
                    "SELECT AddGeometryColumn('{schema}', '{table}', 'the_geom', {srid}, '{typmod}', 2)"
                    .format(table=table,
                            schema=uri.schema(),
                            srid=crs.authid().split(":")[-1],
                            typmod=GEOM_TYPE_MAP[geometryType].upper()))

            self.layer = QgsVectorLayer(uri.uri(), uri.table(), "postgres")
            self.writer = self.layer.dataProvider()
        elif self.destination.startswith(self.SPATIALITE_LAYER_PREFIX):
            self.isNotFileBased = True
            uri = QgsDataSourceUri(
                self.destination[len(self.SPATIALITE_LAYER_PREFIX):])
            try:
                db = spatialite.GeoDB(uri=uri)
            except spatialite.DbError as e:
                raise GeoAlgorithmExecutionException(
                    "Couldn't connect to database:\n%s" % e.message)

            def _runSQL(sql):
                try:
                    db._exec_sql_and_commit(unicode(sql))
                except spatialite.DbError as e:
                    raise GeoAlgorithmExecutionException(
                        'Error creating output Spatialite table:\n%s' %
                        unicode(e))

            fields = [_toQgsField(f) for f in fields]
            fieldsdesc = ",".join(
                '%s %s' %
                (f.name(), TYPE_MAP_SPATIALITE_LAYER.get(f.type(), "VARCHAR"))
                for f in fields)

            table = uri.table().lower()
            _runSQL("DROP TABLE IF EXISTS %s" % table)
            _runSQL("CREATE TABLE %s (%s)" % (table, fieldsdesc))
            # Same WKB-type comparison as the PostGIS branch above.
            if geometryType != QgsWkbTypes.NoGeometry:
                _runSQL(
                    "SELECT AddGeometryColumn('{table}', 'the_geom', {srid}, '{typmod}', 2)"
                    .format(table=table,
                            srid=crs.authid().split(":")[-1],
                            typmod=GEOM_TYPE_MAP[geometryType].upper()))

            self.layer = QgsVectorLayer(uri.uri(), uri.table(), "spatialite")
            self.writer = self.layer.dataProvider()
        else:
            # Map file extension -> OGR driver name, extracting the first
            # "*.ext" pattern from each filter string.
            formats = QgsVectorFileWriter.supportedFiltersAndFormats()
            OGRCodes = {}
            for (key, value) in formats.items():
                extension = unicode(key)
                extension = extension[extension.find('*.') + 2:]
                extension = extension[:extension.find(' ')]
                OGRCodes[extension] = value
            OGRCodes['dbf'] = "DBF file"

            extension = self.destination[self.destination.rfind('.') + 1:]

            if extension not in OGRCodes:
                extension = 'shp'
                self.destination = self.destination + '.shp'

            if geometryType == QgsWkbTypes.NoGeometry:
                # A shapefile without geometry is written as a bare .dbf.
                if extension == 'shp':
                    extension = 'dbf'
                    self.destination = self.destination[:self.destination.
                                                        rfind('.')] + '.dbf'
                if extension not in self.nogeometry_extensions:
                    raise GeoAlgorithmExecutionException(
                        "Unsupported format for tables with no geometry")

            qgsfields = QgsFields()
            for field in fields:
                qgsfields.append(_toQgsField(field))

            # use default dataset/layer options
            dataset_options = QgsVectorFileWriter.defaultDatasetOptions(
                OGRCodes[extension])
            layer_options = QgsVectorFileWriter.defaultLayerOptions(
                OGRCodes[extension])

            self.writer = QgsVectorFileWriter(self.destination, encoding,
                                              qgsfields, geometryType, crs,
                                              OGRCodes[extension],
                                              dataset_options, layer_options)
示例#20
0
    def processAlgorithm(self, parameters, context, feedback):
        """Import the INPUT feature source into a table of a Spatialite database.

        The target database comes from the DATABASE layer's data source URI.
        If that URI has no ``dbname`` (a plain file path, possibly followed by
        an OGR ``|layerid=...`` suffix), a ``dbname='<path>'`` URI is built
        from the path instead.

        The table name is taken from TABLENAME. When TABLENAME is empty or
        whitespace-only, it falls back to the source name with dots replaced
        by underscores. Spaces are then removed and the name is lower-cased.
        Features are streamed through a ``QgsVectorLayerExporter``. A feature
        that fails to be added is reported via ``feedback.reportError`` and
        the import continues. Cancelling stops the feature loop, but the
        exporter is still flushed and the index step still runs.

        :param parameters: algorithm parameter values.
        :param context: processing context.
        :param feedback: processing feedback (progress, cancel, errors).
        :returns: an empty dict.
        :raises QgsProcessingException: if the exporter reports an error when
            it is created or after the final flush.
        """
        database = self.parameterAsVectorLayer(parameters, self.DATABASE, context)
        databaseuri = database.dataProvider().dataSourceUri()
        uri = QgsDataSourceUri(databaseuri)
        # Use a truthiness test rather than "is ''": identity comparison with
        # a string literal is not guaranteed to work.
        if not uri.database():
            if '|layerid' in databaseuri:
                databaseuri = databaseuri[:databaseuri.find('|layerid')]
            uri = QgsDataSourceUri('dbname=\'%s\'' % (databaseuri))
        db = spatialite.GeoDB(uri)

        overwrite = self.parameterAsBool(parameters, self.OVERWRITE, context)
        createIndex = self.parameterAsBool(parameters, self.CREATEINDEX, context)
        convertLowerCase = self.parameterAsBool(parameters, self.LOWERCASE_NAMES, context)
        dropStringLength = self.parameterAsBool(parameters, self.DROP_STRING_LENGTH, context)
        forceSinglePart = self.parameterAsBool(parameters, self.FORCE_SINGLEPART, context)
        primaryKeyField = self.parameterAsString(parameters, self.PRIMARY_KEY, context) or 'id'
        encoding = self.parameterAsString(parameters, self.ENCODING, context)

        source = self.parameterAsSource(parameters, self.INPUT, context)

        table = self.parameterAsString(parameters, self.TABLENAME, context)
        if table:
            # str.strip() returns a new string; keep the result so that a
            # whitespace-only name falls back to the source name below.
            table = table.strip()
        if not table:
            table = source.sourceName()
            table = table.replace('.', '_')
        table = table.replace(' ', '').lower()
        providerName = 'spatialite'

        geomColumn = self.parameterAsString(parameters, self.GEOMETRY_COLUMN, context)
        if not geomColumn:
            geomColumn = 'geom'

        options = {}
        if overwrite:
            options['overwrite'] = True
        if convertLowerCase:
            options['lowercaseFieldNames'] = True
            geomColumn = geomColumn.lower()
        if dropStringLength:
            options['dropStringConstraints'] = True
        if forceSinglePart:
            options['forceSinglePartGeometryType'] = True

        # Clear geometry column for non-geometry tables
        if source.wkbType() == QgsWkbTypes.NoGeometry:
            geomColumn = None

        uri = db.uri
        uri.setDataSource('', table, geomColumn, '', primaryKeyField)

        if encoding:
            options['fileEncoding'] = encoding

        exporter = QgsVectorLayerExporter(uri.uri(), providerName, source.fields(),
                                          source.wkbType(), source.sourceCrs(), overwrite, options)

        if exporter.errorCode() != QgsVectorLayerExporter.NoError:
            raise QgsProcessingException(
                self.tr('Error importing to Spatialite\n{0}').format(exporter.errorMessage()))

        features = source.getFeatures()
        total = 100.0 / source.featureCount() if source.featureCount() else 0
        for current, f in enumerate(features):
            if feedback.isCanceled():
                break

            if not exporter.addFeature(f, QgsFeatureSink.FastInsert):
                feedback.reportError(exporter.errorMessage())

            feedback.setProgress(int(current * total))

        exporter.flushBuffer()
        if exporter.errorCode() != QgsVectorLayerExporter.NoError:
            raise QgsProcessingException(
                self.tr('Error importing to Spatialite\n{0}').format(exporter.errorMessage()))

        if geomColumn and createIndex:
            db.create_spatial_index(table, geomColumn)

        return {}
示例#21
0
    def processAlgorithm(self, parameters, context, feedback):
        """Import the INPUT layer into a table of a Spatialite database.

        The target database comes from the DATABASE parameter. If it has no
        ``dbname`` (a plain file path, possibly followed by an OGR
        ``|layerid=...`` suffix), a ``dbname='<path>'`` URI is built from the
        path instead.

        The table name is taken from TABLENAME. When TABLENAME is empty or
        whitespace-only, it falls back to the layer name. Spaces are then
        removed and the name is lower-cased. If ENCODING is set, it is applied
        to the input layer's provider (a side effect on that layer) before
        ``QgsVectorLayerExporter.exportLayer`` runs.

        :param parameters: algorithm parameter values (read here via
            ``getParameterValue``).
        :param context: processing context used to resolve the INPUT layer.
        :param feedback: processing feedback (not used by this body).
        :raises GeoAlgorithmExecutionException: if ``exportLayer`` returns a
            non-zero error code.
        """
        database = self.getParameterValue(self.DATABASE)
        uri = QgsDataSourceUri(database)
        # Use a truthiness test rather than "is ''": identity comparison with
        # a string literal is not guaranteed to work.
        if not uri.database():
            if '|layerid' in database:
                database = database[:database.find('|layerid')]
            uri = QgsDataSourceUri('dbname=\'%s\'' % (database))
        db = spatialite.GeoDB(uri)

        overwrite = self.getParameterValue(self.OVERWRITE)
        createIndex = self.getParameterValue(self.CREATEINDEX)
        convertLowerCase = self.getParameterValue(self.LOWERCASE_NAMES)
        dropStringLength = self.getParameterValue(self.DROP_STRING_LENGTH)
        forceSinglePart = self.getParameterValue(self.FORCE_SINGLEPART)
        primaryKeyField = self.getParameterValue(self.PRIMARY_KEY) or 'id'
        encoding = self.getParameterValue(self.ENCODING)

        layerUri = self.getParameterValue(self.INPUT)
        layer = QgsProcessingUtils.mapLayerFromString(layerUri, context)

        table = self.getParameterValue(self.TABLENAME)
        if table:
            # str.strip() returns a new string; keep the result so that a
            # whitespace-only name falls back to the layer name below.
            table = table.strip()
        if not table:
            table = layer.name()
        table = table.replace(' ', '').lower()
        providerName = 'spatialite'

        geomColumn = self.getParameterValue(self.GEOMETRY_COLUMN)
        if not geomColumn:
            geomColumn = 'the_geom'

        options = {}
        if overwrite:
            options['overwrite'] = True
        if convertLowerCase:
            options['lowercaseFieldNames'] = True
            geomColumn = geomColumn.lower()
        if dropStringLength:
            options['dropStringConstraints'] = True
        if forceSinglePart:
            options['forceSinglePartGeometryType'] = True

        # Clear geometry column for non-geometry tables
        if not layer.hasGeometryType():
            geomColumn = None

        uri = db.uri
        uri.setDataSource('', table, geomColumn, '', primaryKeyField)

        if encoding:
            layer.setProviderEncoding(encoding)

        (ret, errMsg) = QgsVectorLayerExporter.exportLayer(
            layer,
            uri.uri(),
            providerName,
            self.crs,
            False,
            options,
        )
        if ret != 0:
            raise GeoAlgorithmExecutionException(
                self.tr('Error importing to Spatialite\n{0}').format(errMsg))

        if geomColumn and createIndex:
            db.create_spatial_index(table, geomColumn)