Skip to content
11 changes: 6 additions & 5 deletions dash/_callback_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import contextvars
import typing

import flask

from . import exceptions
from . import backends
from ._utils import AttributeDict, stringify_id


Expand Down Expand Up @@ -220,14 +219,15 @@ def record_timing(name, duration, description=None):
:param description: A description of the resource.
:type description: string or None
"""
timing_information = getattr(flask.g, "timing_information", {})
request = backends.backend.request_adapter()
timing_information = getattr(request.context, "timing_information", {})

if name in timing_information:
raise KeyError(f'Duplicate resource name "{name}" found.')

timing_information[name] = {"dur": round(duration * 1000), "desc": description}

setattr(flask.g, "timing_information", timing_information)
setattr(request.context, "timing_information", timing_information)

@property
@has_context
Expand All @@ -250,7 +250,8 @@ def using_outputs_grouping(self):
@property
@has_context
def timing_information(self):
return getattr(flask.g, "timing_information", {})
request = backends.backend.request_adapter()
return getattr(request.context, "timing_information", {})

@has_context
def set_props(self, component_id: typing.Union[str, dict], props: dict):
Expand Down
5 changes: 2 additions & 3 deletions dash/_configs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from ._utils import get_root_path
import os
import flask

# noinspection PyCompatibility
from . import exceptions
from ._utils import AttributeDict
Expand Down Expand Up @@ -127,7 +126,7 @@ def pages_folder_config(name, pages_folder, use_pages):
if not pages_folder:
return None
is_custom_folder = str(pages_folder) != "pages"
pages_folder_path = os.path.join(flask.helpers.get_root_path(name), pages_folder)
pages_folder_path = os.path.join(get_root_path(name), pages_folder)
if (use_pages or is_custom_folder) and not os.path.isdir(pages_folder_path):
error_msg = f"""
A folder called `{pages_folder}` does not exist. If a folder for pages is not
Expand Down
4 changes: 2 additions & 2 deletions dash/_pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ._callback_context import context_value
from ._get_app import get_app
from ._get_paths import get_relative_path
from ._utils import AttributeDict
from ._utils import AttributeDict, get_root_path

CONFIG = AttributeDict()
PAGE_REGISTRY = collections.OrderedDict()
Expand Down Expand Up @@ -98,7 +98,7 @@ def _path_to_module_name(path):
def _infer_module_name(page_path):
relative_path = page_path.split(CONFIG.pages_folder)[-1]
module = _path_to_module_name(relative_path)
proj_root = flask.helpers.get_root_path(CONFIG.name)
proj_root = get_root_path(CONFIG.name)
if CONFIG.pages_folder.startswith(proj_root):
parent_path = CONFIG.pages_folder[len(proj_root) :]
else:
Expand Down
59 changes: 59 additions & 0 deletions dash/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import uuid
import hashlib
import importlib
from collections import abc
import subprocess
import logging
Expand All @@ -12,6 +13,7 @@
import string
import inspect
import re
import os

from html import escape
from functools import wraps
Expand Down Expand Up @@ -322,3 +324,60 @@ def pascal_case(name: Union[str, None]):
return s[0].upper() + re.sub(
r"[\-_\.]+([a-z])", lambda match: match.group(1).upper(), s[1:]
)


def get_root_path(import_name: str) -> str:
"""Find the root path of a package, or the path that contains a
module. If it cannot be found, returns the current working
directory.

Not to be confused with the value returned by :func:`find_package`.

:meta private:
"""
# Module already imported and has a file attribute. Use that first.
mod = sys.modules.get(import_name)

if mod is not None and hasattr(mod, "__file__") and mod.__file__ is not None:
return os.path.dirname(os.path.abspath(mod.__file__))

# Next attempt: check the loader.
try:
spec = importlib.util.find_spec(import_name)

if spec is None:
raise ValueError
except (ImportError, ValueError):
loader = None
else:
loader = spec.loader

# Loader does not exist or we're referring to an unloaded main
# module or a main module without path (interactive sessions), go
# with the current working directory.
if loader is None:
return os.getcwd()

if hasattr(loader, "get_filename"):
filepath = loader.get_filename(import_name) # pyright: ignore
else:
# Fall back to imports.
__import__(import_name)
mod = sys.modules[import_name]
filepath = getattr(mod, "__file__", None)

# If we don't have a file path it might be because it is a
# namespace package. In this case pick the root path from the
# first module that is contained in the package.
if filepath is None:
raise RuntimeError(
"No root path can be found for the provided module"
f" {import_name!r}. This can happen because the module"
" came from an import hook that does not provide file"
" name information or because it's a namespace package."
" In this case the root path needs to be explicitly"
" provided."
)

# filepath is import_name.py for a module, or __init__.py for a package.
return os.path.dirname(os.path.abspath(filepath)) # type: ignore[no-any-return]
3 changes: 1 addition & 2 deletions dash/_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import re
from textwrap import dedent
from keyword import iskeyword
import flask

from ._grouping import grouping_len, map_grouping
from ._no_update import NoUpdate
Expand Down Expand Up @@ -511,7 +510,7 @@ def validate_use_pages(config):
"`dash.register_page()` must be called after app instantiation"
)

if flask.has_request_context():
if backends.backend.has_request_context():
raise exceptions.PageError(
"""
dash.register_page() can’t be called within a callback as it updates dash.page_registry, which is a global variable.
Expand Down
22 changes: 22 additions & 0 deletions dash/backends/_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,13 @@ def after_request(self, func: Callable[[], Any] | None):
# FastAPI does not have after_request, but we can use middleware
self.server.middleware("http")(self._make_after_middleware(func))

def has_request_context(self) -> bool:
try:
get_current_request()
return True
except RuntimeError:
return False

def run(self, dash_app: Dash, host, port, debug, **kwargs):
frame = inspect.stack()[2]
dev_tools = dash_app._dev_tools # pylint: disable=protected-access
Expand Down Expand Up @@ -433,6 +440,14 @@ async def view_func(_request: Request, body: dict = Body(...)):
include_in_schema=True,
)

def enable_compression(self) -> None:
from fastapi.middleware.gzip import GZipMiddleware

self.server.add_middleware(GZipMiddleware, minimum_size=500)

if not hasattr(self.server.config, "COMPRESS_ALGORITHM"):
self.server.config["COMPRESS_ALGORITHM"] = ["gzip"]


class FastAPIRequestAdapter(RequestAdapter):
def __init__(self):
Expand All @@ -443,6 +458,13 @@ def __call__(self):
self._request = get_current_request()
return self

@property
def context(self):
if self._request is None:
raise RuntimeError("No active request in context")

return self._request.state

@property
def root(self):
return str(self._request.base_url)
Expand Down
31 changes: 31 additions & 0 deletions dash/backends/_flask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from contextvars import copy_context
from importlib_metadata import version as _get_distribution_version
from typing import TYPE_CHECKING, Any, Callable, Dict
import asyncio
import pkgutil
Expand All @@ -16,15 +17,18 @@
request,
jsonify,
g as flask_g,
has_request_context,
)
from werkzeug.debug import tbtools

from dash.fingerprint import check_fingerprint
from dash import _validate
from dash.exceptions import PreventUpdate, InvalidResourceError
from dash._callback import _invoke_callback, _async_invoke_callback
from dash._utils import parse_version
from .base_server import BaseDashServer, RequestAdapter


if TYPE_CHECKING: # pragma: no cover - typing only
from dash import Dash

Expand Down Expand Up @@ -127,6 +131,9 @@ def after_request(self, func: Callable[[Any], Any]):
# Flask after_request expects a function(response) -> response
self.server.after_request(func)

def has_request_context(self) -> bool:
return has_request_context()

def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: Any):
self.server.run(host=host, port=port, debug=debug, **kwargs)

Expand Down Expand Up @@ -304,6 +311,24 @@ def _sync_view_func(*args, handler=handler, **kwargs):
route, endpoint=endpoint, view_func=view_func, methods=methods
)

def enable_compression(self) -> None:
try:
import flask_compress # pylint: disable=import-outside-toplevel

Compress = flask_compress.Compress
Compress(self.server)
_flask_compress_version = parse_version(
_get_distribution_version("flask_compress")
)
if not hasattr(
self.server.config, "COMPRESS_ALGORITHM"
) and _flask_compress_version >= parse_version("1.6.0"):
self.server.config["COMPRESS_ALGORITHM"] = ["gzip"]
except ImportError as error:
raise ImportError(
"To use the compress option, you need to install dash[compress]"
) from error


class FlaskRequestAdapter(RequestAdapter):
"""Flask implementation using property-based accessors."""
Expand All @@ -316,6 +341,12 @@ def __init__(self) -> None:
def __call__(self, *args: Any, **kwds: Any):
return self

@property
def context(self):
if not has_request_context():
raise RuntimeError("No active request in context")
return flask_g

@property
def args(self):
return self._request.args
Expand Down
44 changes: 39 additions & 5 deletions dash/backends/_quart.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from importlib_metadata import version as _get_distribution_version
from contextvars import copy_context
import typing as _t
import mimetypes
Expand All @@ -16,18 +17,20 @@
jsonify,
request,
Blueprint,
g,
g as quart_g,
has_request_context,
)
except ImportError:
Quart = None
Response = None
jsonify = None
request = None
Blueprint = None
g = None
quart_g = None

from dash.exceptions import PreventUpdate, InvalidResourceError
from dash.fingerprint import check_fingerprint
from dash._utils import parse_version
from dash import _validate, Dash
from .base_server import BaseDashServer
from ._utils import format_traceback_html
Expand Down Expand Up @@ -94,15 +97,17 @@ async def _wrap_errors(error):
def register_timing_hooks(self, _first_run: bool): # type: ignore[name-defined] parity with Flask factory
@self.server.before_request
async def _before_request(): # pragma: no cover - timing infra
if g is not None:
g.timing_information = { # type: ignore[attr-defined]
if quart_g is not None:
quart_g.timing_information = { # type: ignore[attr-defined]
"__dash_server": {"dur": time.time(), "desc": None}
}

@self.server.after_request
async def _after_request(response): # pragma: no cover - timing infra
timing_information = (
getattr(g, "timing_information", None) if g is not None else None
getattr(quart_g, "timing_information", None)
if quart_g is not None
else None
)
if timing_information is None:
return response
Expand Down Expand Up @@ -180,6 +185,11 @@ async def _after(response):
await result
return response

def has_request_context(self) -> bool:
if has_request_context is None:
raise RuntimeError("Quart not installed; cannot check request context")
return has_request_context()

def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any):
self.config = {"debug": debug, **kwargs} if debug else kwargs
self.server.run(host=host, port=port, debug=debug, **kwargs)
Expand Down Expand Up @@ -304,13 +314,37 @@ def _serve_default_favicon(self):
pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon"
)

def enable_compression(self) -> None:
try:
import quart_compress # pylint: disable=import-outside-toplevel

Compress = quart_compress.Compress
Compress(self.server)
_flask_compress_version = parse_version(
_get_distribution_version("quart_compress")
)
if not hasattr(
self.server.config, "COMPRESS_ALGORITHM"
) and _flask_compress_version >= parse_version("1.6.0"):
self.server.config["COMPRESS_ALGORITHM"] = ["gzip"]
except ImportError as error:
raise ImportError(
"To use the compress option, you need to install quart_compress."
) from error


class QuartRequestAdapter:
def __init__(self) -> None:
self._request = request # type: ignore[assignment]
if self._request is None:
raise RuntimeError("Quart not installed; cannot access request context")

@property
def context(self):
if not has_request_context():
raise RuntimeError("No active request in context")
return quart_g

@property
def request(self) -> _t.Any:
return self._request
Expand Down
Loading
Loading