Example #1
async def _tiles_stack(datasette, request, tms):
    priority_order = await tiles_stack_database_order(datasette)
    # Try each database in turn
    for database in priority_order:
        tile = await load_tile(database, request, tms=tms)
        if tile is not None:
            return Response(body=tile, content_type="image/png")
    return Response(body=PNG_404, content_type="image/png", status=404)
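A handler like this only runs once it is wired to a URL. A minimal sketch using Datasette's register_routes plugin hook (the URL pattern and the tms=False binding are illustrative assumptions, not the plugin's actual registration code):

from datasette import hookimpl

@hookimpl
def register_routes():
    async def tiles_stack(datasette, request):
        # Datasette injects "datasette" and "request" by parameter name;
        # the tms flag is fixed here purely for illustration.
        return await _tiles_stack(datasette, request, tms=False)

    return [
        (r"^/-/tiles-stack/(?P<zoom>\d+)/(?P<x>\d+)/(?P<y>\d+)\.png$", tiles_stack),
    ]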
Example #2
async def _tile(request, datasette, tms):
    db_name = request.url_vars["db_name"]
    mbtiles_databases = await detect_mtiles_databases(datasette)
    if db_name not in mbtiles_databases:
        raise NotFound("Not a valid mbtiles database")
    db = datasette.get_database(db_name)
    tile = await load_tile(db, request, tms)
    if tile is None:
        return Response(body=PNG_404, content_type="image/png", status=404)
    return Response(body=tile, content_type="image/png")
Example #3
    async def get(self, request):
        token = request.args.get("token") or ""
        if not self.ds._root_token:
            return Response("Root token has already been used", status=403)
        if secrets.compare_digest(token, self.ds._root_token):
            self.ds._root_token = None
            response = Response.redirect("/")
            response.set_cookie(
                "ds_actor", self.ds.sign({"a": {"id": "root"}}, "actor")
            )
            return response
        else:
            return Response("Invalid token", status=403)
Example #4
    async def get(self, request, as_format):
        await self.check_permission(request, "view-instance")
        if self.needs_request:
            data = self.data_callback(request)
        else:
            data = self.data_callback()
        if as_format:
            headers = {}
            if self.ds.cors:
                headers["Access-Control-Allow-Origin"] = "*"
            return Response(
                json.dumps(data),
                content_type="application/json; charset=utf-8",
                headers=headers,
            )

        else:
            return await self.render(
                ["show_json.html"],
                request=request,
                context={
                    "filename": self.filename,
                    "data_json": json.dumps(data, indent=4),
                },
            )
Example #5
File: special.py Project: simonw/datasette
    async def get(self, request):
        as_format = request.url_vars["format"]
        await self.ds.ensure_permissions(request.actor, ["view-instance"])
        if self.needs_request:
            data = self.data_callback(request)
        else:
            data = self.data_callback()
        if as_format:
            headers = {}
            if self.ds.cors:
                add_cors_headers(headers)
            return Response(
                json.dumps(data),
                content_type="application/json; charset=utf-8",
                headers=headers,
            )

        else:
            return await self.render(
                ["show_json.html"],
                request=request,
                context={
                    "filename": self.filename,
                    "data_json": json.dumps(data, indent=4),
                },
            )
Example #6
    async def get(self, request):
        if not await self.ds.permission_allowed(request.actor, "permissions-debug"):
            return Response("Permission denied", status=403)
        return await self.render(
            ["permissions_debug.html"],
            request,
            {"permission_checks": reversed(self.ds._permission_checks)},
        )
Example #7
    async def get(self, request):
        token = request.args.get("token") or ""
        if not self.ds._root_token:
            return Response("Root token has already been used", status=403)
        if secrets.compare_digest(token, self.ds._root_token):
            self.ds._root_token = None
            cookie = SimpleCookie()
            cookie["ds_actor"] = self.ds.sign({"id": "root"}, "actor")
            cookie["ds_actor"]["path"] = "/"
            response = Response(
                body="",
                status=302,
                headers={
                    "Location": "/",
                    "set-cookie": cookie.output(header="").lstrip(),
                },
            )
            return response
        else:
            return Response("Invalid token", status=403)
Example #8
def render_ics(datasette, request, database, table, rows, columns, sql,
               query_name, data):
    from datasette.views.base import DatasetteError

    if not REQUIRED_COLUMNS.issubset(columns):
        raise DatasetteError(
            "SQL query must return columns {}".format(
                ", ".join(REQUIRED_COLUMNS)),
            status=400,
        )
    c = Calendar(
        creator="-//Datasette {}//datasette-ics//EN".format(__version__))
    title = request.args.get("_ics_title") or ""
    if table and not title:
        title = table
    if data.get("human_description_en"):
        title += ": " + data["human_description_en"]

    # If this is a canned query, its configured title overrides all others
    if query_name:
        try:
            title = datasette.metadata(
                database=database)["queries"][query_name]["title"]
        except (KeyError, TypeError):
            pass

    if title:
        c.extra.append(ContentLine(name="X-WR-CALNAME", params={},
                                   value=title))

    for row in reversed(rows):
        e = EventWithTimezone()
        e.name = row["event_name"]
        e.begin = row["event_dtstart"]
        if "event_dtend" in columns:
            e.end = row["event_dtend"]
        elif "event_duration" in columns:
            e.duration = row["event_duration"]
        if "event_description" in columns:
            e.description = row["event_description"]
        if "event_uid" in columns:
            # TODO: Must be globally unique - include the
            # current URL to help achieve this
            e.uid = str(row["event_uid"])
        if "event_tzid" in columns:
            e.timezone = row["event_tzid"]
        c.events.add(e)

    content_type = "text/calendar; charset=utf-8"
    if request.args.get("_plain"):
        content_type = "text/plain; charset=utf-8"

    return Response(str(c), content_type=content_type, status=200)
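render_ics only takes effect once Datasette knows about it. A minimal sketch of the registration, following the register_output_renderer plugin hook that datasette-ics uses (treat the exact wiring as illustrative):

from datasette import hookimpl

@hookimpl
def register_output_renderer(datasette):
    # Makes the renderer available as the .ics extension on query and table pages
    return {
        "extension": "ics",
        "render": render_ics,
    }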
Example #9
async def sitemap_xml(datasette):
    content = [
        '<?xml version="1.0" encoding="UTF-8"?>',
        '<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">',
    ]
    for db in datasette.databases.values():
        for row in await db.execute("select id from museums"):
            content.append(
                "<url><loc>https://www.niche-museums.com/{}</loc></url>".format(
                    row["id"]
                )
            )
    content.append("</urlset>")
    return Response("\n".join(content), 200, content_type="application/xml")
Example #10
    async def get(self, request, as_format):
        data = self.data_callback()
        if as_format:
            headers = {}
            if self.ds.cors:
                headers["Access-Control-Allow-Origin"] = "*"
            return Response(
                json.dumps(data),
                content_type="application/json; charset=utf-8",
                headers=headers,
            )

        else:
            return self.render(["show_json.html"], filename=self.filename, data=data)
Example #11
File: table.py Project: zeta1999/datasette
    async def get(self, request, db_name, table, pk_path, column):
        await self.check_permissions(
            request,
            [
                ("view-table", (db_name, table)),
                ("view-database", db_name),
                "view-instance",
            ],
        )
        try:
            db = self.ds.get_database(db_name)
        except KeyError:
            raise NotFound("Database {} does not exist".format(db_name))
        if not await db.table_exists(table):
            raise NotFound("Table {} does not exist".format(table))
        # Ensure the column exists and is of type BLOB
        column_types = {
            c.name: c.type
            for c in await db.table_column_details(table)
        }
        if column not in column_types:
            raise NotFound("Table {} does not have column {}".format(
                table, column))
        if column_types[column].upper() not in ("BLOB", ""):
            raise NotFound(
                "Table {} does not have column {} of type BLOB".format(
                    table, column))
        # Ensure the row exists for the pk_path
        pk_values = urlsafe_components(pk_path)
        sql, params, _ = await _sql_params_pks(db, table, pk_values)
        results = await db.execute(sql, params, truncate=True)
        rows = list(results.rows)
        if not rows:
            raise NotFound("Record not found: {}".format(pk_values))

        # Serve back the binary data
        filename_bits = [to_css_class(table), pk_path, to_css_class(column)]
        filename = "-".join(filename_bits) + ".blob"
        headers = {
            "X-Content-Type-Options": "nosniff",
            "Content-Disposition": 'attachment; filename="{}"'.format(filename),
        }
        return Response(
            body=rows[0][column],
            status=200,
            headers=headers,
            content_type="application/binary",
        )
Example #12
async def render_blob(datasette, database, rows, columns, request, table,
                      view_name):
    if _BLOB_COLUMN not in request.args:
        raise BadRequest(f"?{_BLOB_COLUMN}= is required")
    blob_column = request.args[_BLOB_COLUMN]
    if blob_column not in columns:
        raise BadRequest(f"{blob_column} is not a valid column")

    # If ?_blob_hash= provided, use that to select the row - otherwise use first row
    blob_hash = None
    if _BLOB_HASH in request.args:
        blob_hash = request.args[_BLOB_HASH]
        for row in rows:
            value = row[blob_column]
            if hashlib.sha256(value).hexdigest() == blob_hash:
                break
        else:
            # Loop did not break
            raise BadRequest(
                "Link has expired - the requested binary content has changed or could not be found."
            )
    else:
        row = rows[0]

    value = row[blob_column]
    filename_bits = []
    if table:
        filename_bits.append(to_css_class(table))
    if "pk_path" in request.url_vars:
        filename_bits.append(request.url_vars["pk_path"])
    filename_bits.append(to_css_class(blob_column))
    if blob_hash:
        filename_bits.append(blob_hash[:6])
    filename = "-".join(filename_bits) + ".blob"
    headers = {
        "X-Content-Type-Options": "nosniff",
        "Content-Disposition": f'attachment; filename="{filename}"',
    }
    return Response(
        body=value or b"",
        status=200,
        headers=headers,
        content_type="application/binary",
    )
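The ?_blob_hash= check above pins a download link to the exact binary content it was generated for. A hedged sketch of building such a link (the helper and base_path handling are hypothetical; the query parameter names match the _BLOB_COLUMN/_BLOB_HASH checks above):

import hashlib

def blob_download_url(base_path, blob_column, value):
    # Hash the blob so the link stops working if the content changes,
    # matching the sha256 comparison in render_blob above
    digest = hashlib.sha256(value).hexdigest()
    return f"{base_path}.blob?_blob_column={blob_column}&_blob_hash={digest}"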
Example #13
async def render_geo_json(datasette, request, sql, columns, rows, database,
                          table, query_name, view_name, data):
    result = await datasette.execute(database, sql)

    # Collect the individual GeoJSON features
    feature_list = []

    for row in result:
        # Does this row expose explicit longitude/latitude point columns?
        if 'longitude' in row.keys() and 'latitude' in row.keys():
            # https://geojson.org/geojson-spec.html#id2
            point = geojson.Point((row['longitude'], row['latitude']))
            # create a geojson feature
            feature = geojson.Feature(geometry=point, properties=dict(row))
            # append the current feature to the list of all features
            feature_list.append(feature)

        # Otherwise, does the row have a "the_geom" column? The old Carto
        # database used it to encode geographical data as a string in the
        # "well-known binary" (WKB) format.
        elif 'the_geom' in row.keys():
            feature = Geometry(row['the_geom'])
            feature_list.append(feature)

        else:
            # TODO: do we need to have error handling here?
            pass

    feature_collection = geojson.FeatureCollection(feature_list)
    body = geojson.dumps(feature_collection)
    content_type = "application/json; charset=utf-8"
    headers = {}
    status_code = 200  # add query / error handling!

    # Note: currently max result size is 1000 (can increase via settings.json)
    #       OR look at using link setup for paging result sets
    #       if required as outlined in datasette docs
    # if next_url:
    #     headers["link"] = f'<{next_url}>; rel="next"'

    return Response(body=body,
                    status=status_code,
                    headers=headers,
                    content_type=content_type)
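register_output_renderer also accepts an optional "can_render" callback, which controls whether the format link is offered for a given result set. A hedged sketch of what that could look like for this renderer (the callback below is an assumption, not code from the plugin):

def can_render_geo_json(columns):
    # Offer .geojson only when the row shapes handled above are present
    return {"longitude", "latitude"}.issubset(columns) or "the_geom" in columns

It would be registered alongside the render function, e.g. {"extension": "geojson", "render": render_geo_json, "can_render": can_render_geo_json}.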
Example #14
    async def get(self, request):
        database = tilde_decode(request.url_vars["database"])
        await self.ds.ensure_permissions(
            request.actor,
            [
                ("view-database-download", database),
                ("view-database", database),
                "view-instance",
            ],
        )
        try:
            db = self.ds.get_database(route=database)
        except KeyError:
            raise DatasetteError("Invalid database", status=404)
        if db.is_memory:
            raise DatasetteError("Cannot download in-memory databases", status=404)
        if not self.ds.setting("allow_download") or db.is_mutable:
            raise Forbidden("Database download is forbidden")
        if not db.path:
            raise DatasetteError("Cannot download database", status=404)
        filepath = db.path
        headers = {}
        if self.ds.cors:
            add_cors_headers(headers)
        if db.hash:
            etag = '"{}"'.format(db.hash)
            headers["Etag"] = etag
            # Has user seen this already?
            if_none_match = request.headers.get("if-none-match")
            if if_none_match and if_none_match == etag:
                return Response("", status=304)
        headers["Transfer-Encoding"] = "chunked"
        return AsgiFileDownload(
            filepath,
            filename=os.path.basename(filepath),
            content_type="application/octet-stream",
            headers=headers,
        )
Example #15
    async def get(self, request, as_format):
        await self.check_permission(request, "view-instance")
        databases = []
        for name, db in self.ds.databases.items():
            visible, database_private = await check_visibility(
                self.ds,
                request.actor,
                "view-database",
                name,
            )
            if not visible:
                continue
            table_names = await db.table_names()
            hidden_table_names = set(await db.hidden_table_names())

            views = []
            for view_name in await db.view_names():
                visible, private = await check_visibility(
                    self.ds,
                    request.actor,
                    "view-table",
                    (name, view_name),
                )
                if visible:
                    views.append({"name": view_name, "private": private})

            # Only count rows for immutable databases, or those smaller than COUNT_DB_SIZE_LIMIT
            table_counts = {}
            if not db.is_mutable or db.size < COUNT_DB_SIZE_LIMIT:
                table_counts = await db.table_counts(10)
                # If any of these are None it means at least one timed out - ignore them all
                if any(v is None for v in table_counts.values()):
                    table_counts = {}

            tables = {}
            for table in table_names:
                visible, private = await check_visibility(
                    self.ds,
                    request.actor,
                    "view-table",
                    (name, table),
                )
                if not visible:
                    continue
                table_columns = await db.table_columns(table)
                tables[table] = {
                    "name": table,
                    "columns": table_columns,
                    "primary_keys": await db.primary_keys(table),
                    "count": table_counts.get(table),
                    "hidden": table in hidden_table_names,
                    "fts_table": await db.fts_table(table),
                    "num_relationships_for_sorting": 0,
                    "private": private,
                }

            if request.args.get(
                    "_sort") == "relationships" or not table_counts:
                # We will be sorting by number of relationships, so populate that field
                all_foreign_keys = await db.get_all_foreign_keys()
                for table, foreign_keys in all_foreign_keys.items():
                    if table in tables.keys():
                        count = len(foreign_keys["incoming"] +
                                    foreign_keys["outgoing"])
                        tables[table]["num_relationships_for_sorting"] = count

            hidden_tables = [t for t in tables.values() if t["hidden"]]
            visible_tables = [t for t in tables.values() if not t["hidden"]]

            tables_and_views_truncated = list(
                sorted(
                    (t for t in tables.values() if t not in hidden_tables),
                    key=lambda t: (
                        t["num_relationships_for_sorting"],
                        t["count"] or 0,
                        t["name"],
                    ),
                    reverse=True,
                )[:TRUNCATE_AT])

            # Only add views if this is less than TRUNCATE_AT
            if len(tables_and_views_truncated) < TRUNCATE_AT:
                num_views_to_add = TRUNCATE_AT - len(
                    tables_and_views_truncated)
                for view in views[:num_views_to_add]:
                    tables_and_views_truncated.append(view)

            databases.append(
                {
                    "name": name,
                    "hash": db.hash,
                    "color": db.hash[:6]
                    if db.hash
                    else hashlib.md5(name.encode("utf8")).hexdigest()[:6],
                    "path": self.ds.urls.database(name),
                    "tables_and_views_truncated": tables_and_views_truncated,
                    "tables_and_views_more": (len(visible_tables) + len(views))
                    > TRUNCATE_AT,
                    "tables_count": len(visible_tables),
                    "table_rows_sum": sum((t["count"] or 0) for t in visible_tables),
                    "show_table_row_counts": bool(table_counts),
                    "hidden_table_rows_sum": sum(
                        t["count"] for t in hidden_tables if t["count"] is not None
                    ),
                    "hidden_tables_count": len(hidden_tables),
                    "views_count": len(views),
                    "private": database_private,
                }
            )

        if as_format:
            headers = {}
            if self.ds.cors:
                add_cors_headers(headers)
            return Response(
                json.dumps(
                    {db["name"]: db for db in databases}, cls=CustomJSONEncoder
                ),
                content_type="application/json; charset=utf-8",
                headers=headers,
            )
        else:
            return await self.render(
                ["index.html"],
                request=request,
                context={
                    "databases": databases,
                    "metadata": self.ds.metadata(),
                    "datasette_version": __version__,
                    "private": not await self.ds.permission_allowed(
                        None, "view-instance", default=True
                    ),
                },
            )
Example #16
File: base.py Project: simonw/datasette
    async def get(self, request):
        database_route = tilde_decode(request.url_vars["database"])

        try:
            db = self.ds.get_database(route=database_route)
        except KeyError:
            raise NotFound("Database not found: {}".format(database_route))
        database = db.name

        _format = request.url_vars["format"]
        data_kwargs = {}

        if _format == "csv":
            return await self.as_csv(request, database_route)

        if _format is None:
            # HTML views default to expanding all foreign key labels
            data_kwargs["default_labels"] = True

        extra_template_data = {}
        start = time.perf_counter()
        status_code = None
        templates = []
        try:
            response_or_template_contexts = await self.data(
                request, **data_kwargs)
            if isinstance(response_or_template_contexts, Response):
                return response_or_template_contexts
            # If it has four items, it includes an HTTP status code
            if len(response_or_template_contexts) == 4:
                (
                    data,
                    extra_template_data,
                    templates,
                    status_code,
                ) = response_or_template_contexts
            else:
                data, extra_template_data, templates = response_or_template_contexts
        except QueryInterrupted:
            raise DatasetteError(
                """
                SQL query took too long. The time limit is controlled by the
                <a href="https://docs.datasette.io/en/stable/settings.html#sql-time-limit-ms">sql_time_limit_ms</a>
                configuration option.
            """,
                title="SQL Interrupted",
                status=400,
                message_is_html=True,
            )
        except (sqlite3.OperationalError, InvalidSql) as e:
            raise DatasetteError(str(e), title="Invalid SQL", status=400)

        except sqlite3.OperationalError as e:
            raise DatasetteError(str(e))

        except DatasetteError:
            raise

        end = time.perf_counter()
        data["query_ms"] = (end - start) * 1000
        for key in ("source", "source_url", "license", "license_url"):
            value = self.ds.metadata(key)
            if value:
                data[key] = value

        # Special case for .jsono extension - redirect to _shape=objects
        if _format == "jsono":
            return self.redirect(
                request,
                path_with_added_args(
                    request,
                    {"_shape": "objects"},
                    path=request.path.rsplit(".jsono", 1)[0] + ".json",
                ),
                forward_querystring=False,
            )

        if _format in self.ds.renderers.keys():
            # Dispatch request to the correct output format renderer
            # (CSV is not handled here due to streaming)
            result = call_with_supported_arguments(
                self.ds.renderers[_format][0],
                datasette=self.ds,
                columns=data.get("columns") or [],
                rows=data.get("rows") or [],
                sql=data.get("query", {}).get("sql", None),
                query_name=data.get("query_name"),
                database=database,
                table=data.get("table"),
                request=request,
                view_name=self.name,
                # These will be deprecated in Datasette 1.0:
                args=request.args,
                data=data,
            )
            if asyncio.iscoroutine(result):
                result = await result
            if result is None:
                raise NotFound("No data")
            if isinstance(result, dict):
                r = Response(
                    body=result.get("body"),
                    status=result.get("status_code", status_code or 200),
                    content_type=result.get("content_type", "text/plain"),
                    headers=result.get("headers"),
                )
            elif isinstance(result, Response):
                r = result
                if status_code is not None:
                    # Over-ride the status code
                    r.status = status_code
            else:
                assert False, f"{result} should be dict or Response"
        else:
            extras = {}
            if callable(extra_template_data):
                extras = extra_template_data()
                if asyncio.iscoroutine(extras):
                    extras = await extras
            else:
                extras = extra_template_data
            url_labels_extra = {}
            if data.get("expandable_columns"):
                url_labels_extra = {"_labels": "on"}

            renderers = {}
            for key, (_, can_render) in self.ds.renderers.items():
                it_can_render = call_with_supported_arguments(
                    can_render,
                    datasette=self.ds,
                    columns=data.get("columns") or [],
                    rows=data.get("rows") or [],
                    sql=data.get("query", {}).get("sql", None),
                    query_name=data.get("query_name"),
                    database=database,
                    table=data.get("table"),
                    request=request,
                    view_name=self.name,
                )
                it_can_render = await await_me_maybe(it_can_render)
                if it_can_render:
                    renderers[key] = self.ds.urls.path(
                        path_with_format(request=request,
                                         format=key,
                                         extra_qs={**url_labels_extra}))

            url_csv_args = {"_size": "max", **url_labels_extra}
            url_csv = self.ds.urls.path(
                path_with_format(request=request,
                                 format="csv",
                                 extra_qs=url_csv_args))
            url_csv_path = url_csv.split("?")[0]
            context = {
                **data,
                **extras,
                **{
                    "renderers": renderers,
                    "url_csv": url_csv,
                    "url_csv_path": url_csv_path,
                    "url_csv_hidden_args": [
                        (key, value)
                        for key, value in urllib.parse.parse_qsl(request.query_string)
                        if key not in ("_labels", "_facet", "_size")
                    ]
                    + [("_size", "max")],
                    "datasette_version": __version__,
                    "settings": self.ds.settings_dict(),
                },
            }
            if "metadata" not in context:
                context["metadata"] = self.ds.metadata
            r = await self.render(templates, request=request, context=context)
            if status_code is not None:
                r.status = status_code

        ttl = request.args.get("_ttl", None)
        if ttl is None or not ttl.isdigit():
            ttl = self.ds.setting("default_cache_ttl")

        return self.set_response_headers(r, ttl)
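The dispatch code above accepts either a Response or a plain dict from an output renderer, filling in defaults for missing keys. A minimal sketch of a renderer returning the dict form (the function is hypothetical; the keys mirror the result.get() calls above):

def render_demo(rows):
    # If "status_code" or "content_type" were omitted, the dispatcher above
    # would default them to 200 and "text/plain" respectively.
    return {
        "body": "\n".join(str(row[0]) for row in rows),
        "content_type": "text/plain; charset=utf-8",
        "status_code": 200,
        "headers": {},
    }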
Example #17
    async def view_get(self, request, database, hash, correct_hash_provided,
                       **kwargs):
        _format, kwargs = await self.get_format(request, database, kwargs)

        if _format == "csv":
            return await self.as_csv(request, database, hash, **kwargs)

        if _format is None:
            # HTML views default to expanding all foreign key labels
            kwargs["default_labels"] = True

        extra_template_data = {}
        start = time.time()
        status_code = 200
        templates = []
        try:
            response_or_template_contexts = await self.data(
                request, database, hash, **kwargs)
            if isinstance(response_or_template_contexts, Response):
                return response_or_template_contexts

            else:
                data, extra_template_data, templates = response_or_template_contexts
        except QueryInterrupted:
            raise DatasetteError(
                """
                SQL query took too long. The time limit is controlled by the
                <a href="https://datasette.readthedocs.io/en/stable/config.html#sql-time-limit-ms">sql_time_limit_ms</a>
                configuration option.
            """,
                title="SQL Interrupted",
                status=400,
                message_is_html=True,
            )
        except (sqlite3.OperationalError, InvalidSql) as e:
            raise DatasetteError(str(e), title="Invalid SQL", status=400)

        except sqlite3.OperationalError as e:
            raise DatasetteError(str(e))

        except DatasetteError:
            raise

        end = time.time()
        data["query_ms"] = (end - start) * 1000
        for key in ("source", "source_url", "license", "license_url"):
            value = self.ds.metadata(key)
            if value:
                data[key] = value

        # Special case for .jsono extension - redirect to _shape=objects
        if _format == "jsono":
            return self.redirect(
                request,
                path_with_added_args(
                    request,
                    {"_shape": "objects"},
                    path=request.path.rsplit(".jsono", 1)[0] + ".json",
                ),
                forward_querystring=False,
            )

        if _format in self.ds.renderers.keys():
            # Dispatch request to the correct output format renderer
            # (CSV is not handled here due to streaming)
            result = self.ds.renderers[_format](request.args, data, self.name)
            if result is None:
                raise NotFound("No data")

            r = Response(
                body=result.get("body"),
                status=result.get("status_code", 200),
                content_type=result.get("content_type", "text/plain"),
            )
        else:
            extras = {}
            if callable(extra_template_data):
                extras = extra_template_data()
                if asyncio.iscoroutine(extras):
                    extras = await extras
            else:
                extras = extra_template_data
            url_labels_extra = {}
            if data.get("expandable_columns"):
                url_labels_extra = {"_labels": "on"}

            renderers = {
                key: path_with_format(request, key, {**url_labels_extra})
                for key in self.ds.renderers.keys()
            }
            url_csv_args = {"_size": "max", **url_labels_extra}
            url_csv = path_with_format(request, "csv", url_csv_args)
            url_csv_path = url_csv.split("?")[0]
            context = {
                **data,
                **extras,
                **{
                    "renderers": renderers,
                    "url_csv": url_csv,
                    "url_csv_path": url_csv_path,
                    "url_csv_hidden_args": [
                        (key, value)
                        for key, value in urllib.parse.parse_qsl(request.query_string)
                        if key not in ("_labels", "_facet", "_size")
                    ]
                    + [("_size", "max")],
                    "datasette_version": __version__,
                    "config": self.ds.config_dict(),
                },
            }
            if "metadata" not in context:
                context["metadata"] = self.ds.metadata
            r = await self.render(templates, request=request, context=context)
            r.status = status_code

        ttl = request.args.get("_ttl", None)
        if ttl is None or not ttl.isdigit():
            if correct_hash_provided:
                ttl = self.ds.config("default_cache_ttl_hashed")
            else:
                ttl = self.ds.config("default_cache_ttl")

        return self.set_response_headers(r, ttl)
Example #18
File: index.py Project: sunray1/datasette
    async def get(self, request, as_format):
        databases = []
        for name, db in self.ds.databases.items():
            table_names = await db.table_names()
            hidden_table_names = set(await db.hidden_table_names())
            views = await db.view_names()
            # Only count rows for immutable databases, or those with <= COUNT_TABLE_LIMIT tables
            table_counts = {}
            if not db.is_mutable or len(table_names) <= COUNT_TABLE_LIMIT:
                table_counts = await db.table_counts(10)
                # If any of these are None it means at least one timed out - ignore them all
                if any(v is None for v in table_counts.values()):
                    table_counts = {}
            tables = {}
            for table in table_names:
                table_columns = await db.table_columns(table)
                tables[table] = {
                    "name": table,
                    "columns": table_columns,
                    "primary_keys": await db.primary_keys(table),
                    "count": table_counts.get(table),
                    "hidden": table in hidden_table_names,
                    "fts_table": await db.fts_table(table),
                    "num_relationships_for_sorting": 0,
                }

            if request.args.get(
                    "_sort") == "relationships" or not table_counts:
                # We will be sorting by number of relationships, so populate that field
                all_foreign_keys = await db.get_all_foreign_keys()
                for table, foreign_keys in all_foreign_keys.items():
                    count = len(foreign_keys["incoming"] +
                                foreign_keys["outgoing"])
                    tables[table]["num_relationships_for_sorting"] = count

            hidden_tables = [t for t in tables.values() if t["hidden"]]
            visible_tables = [t for t in tables.values() if not t["hidden"]]

            tables_and_views_truncated = list(
                sorted(
                    (t for t in tables.values() if t not in hidden_tables),
                    key=lambda t: (
                        t["num_relationships_for_sorting"],
                        t["count"] or 0,
                        t["name"],
                    ),
                    reverse=True,
                )[:TRUNCATE_AT])

            # Only add views if this is less than TRUNCATE_AT
            if len(tables_and_views_truncated) < TRUNCATE_AT:
                num_views_to_add = TRUNCATE_AT - len(
                    tables_and_views_truncated)
                for view_name in views[:num_views_to_add]:
                    tables_and_views_truncated.append({"name": view_name})

            databases.append(
                {
                    "name": name,
                    "hash": db.hash,
                    "color": db.hash[:6]
                    if db.hash
                    else hashlib.md5(name.encode("utf8")).hexdigest()[:6],
                    "path": self.database_url(name),
                    "tables_and_views_truncated": tables_and_views_truncated,
                    "tables_and_views_more": (len(visible_tables) + len(views))
                    > TRUNCATE_AT,
                    "tables_count": len(visible_tables),
                    "table_rows_sum": sum((t["count"] or 0) for t in visible_tables),
                    "show_table_row_counts": bool(table_counts),
                    "hidden_table_rows_sum": sum(
                        t["count"] for t in hidden_tables if t["count"] is not None
                    ),
                    "hidden_tables_count": len(hidden_tables),
                    "views_count": len(views),
                }
            )

        databases.sort(key=lambda database: database["name"])

        if as_format:
            headers = {}
            if self.ds.cors:
                headers["Access-Control-Allow-Origin"] = "*"
            return Response(
                json.dumps(
                    {db["name"]: db for db in databases}, cls=CustomJSONEncoder
                ),
                content_type="application/json; charset=utf-8",
                headers=headers,
            )
        else:
            return self.render(
                ["index.html"],
                databases=databases,
                metadata=self.ds.metadata(),
                datasette_version=__version__,
            )
Example #19
def json_renderer(args, data, view_name):
    """ Render a response as JSON """
    status_code = 200
    # Handle the _json= parameter which may modify data["rows"]
    json_cols = []
    if "_json" in args:
        json_cols = args.getlist("_json")
    if json_cols and "rows" in data and "columns" in data:
        data["rows"] = convert_specific_columns_to_json(
            data["rows"], data["columns"], json_cols)

    # unless _json_infinity=1 requested, replace infinity with None
    if "rows" in data and not value_as_boolean(args.get("_json_infinity",
                                                        "0")):
        data["rows"] = [remove_infinites(row) for row in data["rows"]]

    # Deal with the _shape option
    shape = args.get("_shape", "arrays")

    next_url = data.get("next_url")

    if shape == "arrayfirst":
        data = [row[0] for row in data["rows"]]
    elif shape in ("objects", "object", "array"):
        columns = data.get("columns")
        rows = data.get("rows")
        if rows and columns:
            data["rows"] = [dict(zip(columns, row)) for row in rows]
        if shape == "object":
            error = None
            if "primary_keys" not in data:
                error = "_shape=object is only available on tables"
            else:
                pks = data["primary_keys"]
                if not pks:
                    error = (
                        "_shape=object not available for tables with no primary keys"
                    )
                else:
                    object_rows = {}
                    for row in data["rows"]:
                        pk_string = path_from_row_pks(row, pks, not pks)
                        object_rows[pk_string] = row
                    data = object_rows
            if error:
                data = {"ok": False, "error": error}
        elif shape == "array":
            data = data["rows"]

    elif shape == "arrays":
        pass
    else:
        status_code = 400
        data = {
            "ok": False,
            "error": f"Invalid _shape: {shape}",
            "status": 400,
            "title": None,
        }
    # Handle _nl option for _shape=array
    nl = args.get("_nl", "")
    if nl and shape == "array":
        body = "\n".join(
            json.dumps(item, cls=CustomJSONEncoder) for item in data)
        content_type = "text/plain"
    else:
        body = json.dumps(data, cls=CustomJSONEncoder)
        content_type = "application/json; charset=utf-8"
    headers = {}
    if next_url:
        headers["link"] = f'<{next_url}>; rel="next"'
    return Response(body,
                    status=status_code,
                    headers=headers,
                    content_type=content_type)
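To make the _shape handling above concrete, here is what each shape produces for the same result set (illustrative values only):

columns = ["id", "name"]
rows = [[1, "Alice"], [2, "Bob"]]

arrays = rows                                        # _shape=arrays (the default)
objects = [dict(zip(columns, row)) for row in rows]  # _shape=objects
array_first = [row[0] for row in rows]               # _shape=arrayfirst -> [1, 2]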
Example #20
def render_atom(datasette, request, sql, columns, rows, database, table,
                query_name, view_name, data):
    from datasette.views.base import DatasetteError

    if not REQUIRED_COLUMNS.issubset(columns):
        raise DatasetteError(
            "SQL query must return columns {}".format(
                ", ".join(REQUIRED_COLUMNS)),
            status=400,
        )
    fg = FeedGenerator()
    fg.generator(
        generator="Datasette",
        version=__version__,
        uri="https://github.com/simonw/datasette",
    )
    fg.id(request.url)
    fg.link(href=request.url, rel="self")
    fg.updated(max(row["atom_updated"] for row in rows))
    title = request.args.get("_feed_title", sql)
    if table:
        title += "/" + table
    if data.get("human_description_en"):
        title += ": " + data["human_description_en"]
    # If this is a canned query, its configured title overrides all others
    if query_name:
        try:
            title = datasette.metadata(
                database=database)["queries"][query_name]["title"]
        except (KeyError, TypeError):
            pass
    fg.title(title)

    clean_function = clean
    if query_name:
        # Check allow_unsafe_html_in_canned_queries
        plugin_config = datasette.plugin_config("datasette-atom")
        if plugin_config:
            allow_unsafe_html_in_canned_queries = plugin_config.get(
                "allow_unsafe_html_in_canned_queries")
            if allow_unsafe_html_in_canned_queries is True:
                clean_function = lambda s: s
            elif isinstance(allow_unsafe_html_in_canned_queries, dict):
                allowlist = allow_unsafe_html_in_canned_queries.get(
                    database) or []
                if query_name in allowlist:
                    clean_function = lambda s: s

    # And the rows
    for row in reversed(rows):
        entry = fg.add_entry()
        entry.id(str(row["atom_id"]))
        if "atom_content_html" in columns:
            entry.content(clean_function(row["atom_content_html"]),
                          type="html")
        elif "atom_content" in columns:
            entry.content(row["atom_content"], type="text")
        entry.updated(row["atom_updated"])
        entry.title(str(row["atom_title"]))
        # atom_link is optional
        if "atom_link" in columns:
            entry.link(href=row["atom_link"])
        if "atom_author_name" in columns and row["atom_author_name"]:
            author = {
                "name": row["atom_author_name"],
            }
            for key in ("uri", "email"):
                colname = "atom_author_{}".format(key)
                if colname in columns and row[colname]:
                    author[key] = row[colname]
            entry.author(author)

    return Response(
        fg.atom_str(pretty=True),
        content_type="application/xml; charset=utf-8",
        status=200,
    )