Example #1
def _get_output_hash(value, func_or_code, hash_funcs):
    hasher = hashlib.new("md5")
    update_hash(
        value,
        hasher=hasher,
        hash_funcs=hash_funcs,
        hash_reason=HashReason.CACHING_FUNC_OUTPUT,
        hash_source=func_or_code,
    )
    return hasher.digest()
Example #2
def _get_output_hash(value: Any, func_or_code: Callable[..., Any],
                     hash_funcs: Optional[HashFuncsDict]) -> bytes:
    hasher = hashlib.new("md5")
    update_hash(
        value,
        hasher=hasher,
        hash_funcs=hash_funcs,
        hash_reason=HashReason.CACHING_FUNC_OUTPUT,
        hash_source=func_or_code,
    )
    return hasher.digest()
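
The raw digest returned by `_get_output_hash` is what makes mutation detection possible: hash the value when it is written to the cache, re-hash it when it is read back, and compare. A minimal sketch, assuming the definitions above; `detect_mutation`, `entry_value`, and `stored_hash` are hypothetical names used only for illustration.

def detect_mutation(entry_value, stored_hash, func, hash_funcs=None):
    # Hypothetical sketch: re-hash the cached value and compare it against
    # the digest recorded when the value was first written to the cache.
    # A mismatch means the caller mutated the returned object in place.
    current_hash = _get_output_hash(entry_value, func, hash_funcs)
    return current_hash != stored_hash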
Example #3
def _hash_func(func, hash_funcs) -> str:
    # Create the unique key for a function's cache. The cache will be retrieved
    # from inside the wrapped function.
    #
    # A naive implementation would involve simply creating the cache object
    # right in the wrapper, which in a normal Python script would be executed
    # only once. But in Streamlit, we reload all modules related to a user's
    # app when the app is re-run, which means that - among other things - all
    # function decorators in the app will be re-run, and so any decorator-local
    # objects will be recreated.
    #
    # Furthermore, our caches can be destroyed and recreated (in response to
    # cache clearing, for example), which means that retrieving the function's
    # cache in the decorator (so that the wrapped function can save a lookup)
    # is incorrect: the cache itself may be recreated between
    # decorator-evaluation time and decorated-function-execution time. So we
    # must retrieve the cache object *and* perform the cached-value lookup
    # inside the decorated function.
    func_hasher = hashlib.new("md5")

    # Include the function's __module__ and __qualname__ strings in the hash.
    # This means that two identical functions in different modules
    # will not share a hash; it also means that two identical *nested*
    # functions in the same module will not share a hash.
    # We do not pass `hash_funcs` here, because we don't want our function's
    # name to get an unexpected hash.
    update_hash(
        (func.__module__, func.__qualname__),
        hasher=func_hasher,
        hash_funcs=None,
        hash_reason=HashReason.CACHING_FUNC_BODY,
        hash_source=func,
    )

    # Include the function's body in the hash. We *do* pass hash_funcs here,
    # because this step will be hashing any objects referenced in the function
    # body.
    update_hash(
        func,
        hasher=func_hasher,
        hash_funcs=hash_funcs,
        hash_reason=HashReason.CACHING_FUNC_BODY,
        hash_source=func,
    )
    cache_key = func_hasher.hexdigest()
    _LOGGER.debug(
        "mem_cache key for %s.%s: %s", func.__module__, func.__qualname__, cache_key
    )
    return cache_key
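
Including `(__module__, __qualname__)` in the key means that even byte-identical function bodies hash differently when they live in different scopes. A small, Streamlit-free demonstration of why `__qualname__` alone is enough to tell two identical nested functions apart:

def outer_a():
    def helper():
        return 1
    return helper

def outer_b():
    def helper():
        return 1
    return helper

# The bodies are identical, but the qualified names differ, so _hash_func
# would produce distinct cache keys for the two helpers.
assert outer_a().__qualname__ == "outer_a.<locals>.helper"
assert outer_b().__qualname__ == "outer_b.<locals>.helper"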
Example #4
    def has_changes(self) -> bool:
        current_frame = inspect.currentframe()

        assert current_frame is not None
        caller_frame = current_frame.f_back
        assert caller_frame is not None

        current_file = inspect.getfile(current_frame)
        caller_file = inspect.getfile(caller_frame)
        real_caller_is_parent_frame = current_file == caller_file
        if real_caller_is_parent_frame:
            caller_frame = caller_frame.f_back

        filename, caller_lineno, code_context = _get_frame_info(caller_frame)

        assert code_context is not None
        code_context = code_context[0]

        context_indent = len(code_context) - len(code_context.lstrip())

        lines = []
        # TODO: Memoize open(filename, 'r') in a way that clears the memoized
        # version with each run of the user's script. Then use the memoized
        # text here, in st.echo, and other places.
        with open(filename, "r") as f:
            for line in f.readlines()[caller_lineno:]:
                if line.strip() == "":
                    # Keep blank lines, but don't let their (near-zero)
                    # indentation terminate the block in the check below.
                    lines.append(line)
                    continue
                indent = len(line) - len(line.lstrip())
                if indent <= context_indent:
                    break
                if line.strip() and not line.lstrip().startswith("#"):
                    lines.append(line)

        while lines[-1].strip() == "":
            lines.pop()

        code_block = "".join(lines)
        program = textwrap.dedent(code_block)

        context = Context(
            dict(caller_frame.f_globals, **caller_frame.f_locals), {}, {})
        code = compile(program, filename, "exec")

        hasher = hashlib.new("md5")
        update_hash(
            code,
            hasher=hasher,
            context=context,
            hash_reason=HashReason.CACHING_BLOCK,
            hash_source=code,
        )

        key = hasher.hexdigest()
        _LOGGER.debug("Cache key: %s", key)

        try:
            value, _ = _read_from_cache(
                mem_cache=self._mem_cache,
                key=key,
                persist=self._persist,
                allow_output_mutation=self._allow_output_mutation,
                func_or_code=code,
            )
            self.update(value)

        except CacheKeyNotFoundError:
            if self._allow_output_mutation and not self._persist:
                # If we aren't hashing the results, we don't need to use exec
                # and can simply return True. This way, line numbers will be
                # correct.
                _write_to_cache(
                    mem_cache=self._mem_cache,
                    key=key,
                    value=self,
                    persist=False,
                    allow_output_mutation=True,
                    func_or_code=code,
                )
                return True

            exec(code, caller_frame.f_globals, caller_frame.f_locals)
            _write_to_cache(
                mem_cache=self._mem_cache,
                key=key,
                value=self,
                persist=self._persist,
                allow_output_mutation=self._allow_output_mutation,
                func_or_code=code,
            )

        # Return False so that we have control over the execution.
        return False
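
The snippet above relies on a `_get_frame_info` helper that is not shown. A plausible reconstruction, assuming it simply unpacks `inspect.getframeinfo` (the real helper may differ):

import inspect
from types import FrameType

def _get_frame_info(frame: FrameType):
    # Hypothetical sketch: return the caller's filename, line number, and
    # source context, exactly the three values unpacked at the call site.
    frameinfo = inspect.getframeinfo(frame)
    return frameinfo.filename, frameinfo.lineno, frameinfo.code_context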
Example #5
        def get_or_create_cached_value():
            # First, get the cache that's attached to this function.
            # This cache's key is generated (above) from the function's code.
            mem_cache = _mem_caches.get_cache(cache_key, max_entries, ttl)

            # Next, calculate the key for the value we'll be searching for
            # within that cache. This key is generated from both the function's
            # code and the arguments that are passed into it. (Even though this
            # key is used to index into a per-function cache, it must be
            # globally unique, because it is *also* used for a global on-disk
            # cache that is *not* per-function.)
            value_hasher = hashlib.new("md5")

            if args:
                update_hash(
                    args,
                    hasher=value_hasher,
                    hash_funcs=hash_funcs,
                    hash_reason=HashReason.CACHING_FUNC_ARGS,
                    hash_source=func,
                )

            if kwargs:
                update_hash(
                    kwargs,
                    hasher=value_hasher,
                    hash_funcs=hash_funcs,
                    hash_reason=HashReason.CACHING_FUNC_ARGS,
                    hash_source=func,
                )

            value_key = value_hasher.hexdigest()

            # Avoid recomputing the body's hash by just appending the
            # previously-computed hash to the arg hash.
            value_key = "%s-%s" % (value_key, cache_key)

            _LOGGER.debug("Cache key: %s", value_key)

            try:
                return_value = _read_from_cache(
                    mem_cache=mem_cache,
                    key=value_key,
                    persist=persist,
                    allow_output_mutation=allow_output_mutation,
                    func_or_code=func,
                    hash_funcs=hash_funcs,
                )
                _LOGGER.debug("Cache hit: %s", func)

            except CacheKeyNotFoundError:
                _LOGGER.debug("Cache miss: %s", func)

                with _calling_cached_function(func):
                    if suppress_st_warning:
                        with suppress_cached_st_function_warning():
                            return_value = func(*args, **kwargs)
                    else:
                        return_value = func(*args, **kwargs)

                _write_to_cache(
                    mem_cache=mem_cache,
                    key=value_key,
                    value=return_value,
                    persist=persist,
                    allow_output_mutation=allow_output_mutation,
                    func_or_code=func,
                    hash_funcs=hash_funcs,
                )

            return return_value
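
Because positional arguments are hashed as a tuple and keyword arguments as a dict, the same logical call expressed two ways (`f(1)` versus `f(x=1)`) produces two different value keys. A standalone toy illustration of this structure, using plain `hashlib` and `pickle` rather than Streamlit's `update_hash`:

import hashlib
import pickle

def toy_value_key(args, kwargs):
    # Mirror the logic above: hash args and kwargs as distinct structures,
    # skipping whichever is empty.
    hasher = hashlib.new("md5")
    if args:
        hasher.update(pickle.dumps(args))
    if kwargs:
        hasher.update(pickle.dumps(sorted(kwargs.items())))
    return hasher.hexdigest()

# Same logical call, different cache slots.
assert toy_value_key((1,), {}) != toy_value_key((), {"x": 1})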
Example #6
def cache(
    func=None,
    persist=False,
    allow_output_mutation=False,
    show_spinner=True,
    suppress_st_warning=False,
    hash_funcs=None,
    max_entries=None,
    ttl=None,
):
    """Function decorator to memoize function executions.

    Parameters
    ----------
    func : callable
        The function to cache. Streamlit hashes the function and dependent code.

    persist : boolean
        Whether to persist the cache on disk.

    allow_output_mutation : boolean
        Streamlit normally shows a warning when return values are mutated, as that
        can have unintended consequences. This is done by hashing the return value internally.

        If you know what you're doing and would like to override this warning, set this to True.

    show_spinner : boolean
        Enable the spinner. Default is True to show a spinner when there is
        a cache miss.

    suppress_st_warning : boolean
        Suppress warnings about calling Streamlit functions from within
        the cached function.

    hash_funcs : dict or None
        Mapping of types or fully qualified names to hash functions. This is used to override
        the behavior of the hasher inside Streamlit's caching mechanism: when the hasher
        encounters an object, it will first check to see if its type matches a key in this
        dict and, if so, will use the provided function to generate a hash for it. See below
        for an example of how this can be used.

    max_entries : int or None
        The maximum number of entries to keep in the cache, or None
        for an unbounded cache. (When a new entry is added to a full cache,
        the oldest cached entry will be removed.) The default is None.

    ttl : float or None
        The maximum number of seconds to keep an entry in the cache, or
        None if cache entries should not expire. The default is None.

    Example
    -------
    >>> @st.cache
    ... def fetch_and_clean_data(url):
    ...     # Fetch data from URL here, and then clean it up.
    ...     return data
    ...
    >>> d1 = fetch_and_clean_data(DATA_URL_1)
    >>> # Actually executes the function, since this is the first time it was
    >>> # encountered.
    >>>
    >>> d2 = fetch_and_clean_data(DATA_URL_1)
    >>> # Does not execute the function. Instead, returns its previously computed
    >>> # value. This means that now the data in d1 is the same as in d2.
    >>>
    >>> d3 = fetch_and_clean_data(DATA_URL_2)
    >>> # This is a different URL, so the function executes.

    To set the `persist` parameter, apply the decorator as follows:

    >>> @st.cache(persist=True)
    ... def fetch_and_clean_data(url):
    ...     # Fetch data from URL here, and then clean it up.
    ...     return data

    To disable hashing return values, set the `allow_output_mutation` parameter to `True`:

    >>> @st.cache(allow_output_mutation=True)
    ... def fetch_and_clean_data(url):
    ...     # Fetch data from URL here, and then clean it up.
    ...     return data


    To override the default hashing behavior, pass a custom hash function.
    You can do that by mapping a type (e.g. `MongoClient`) to a hash function (`id`) like this:

    >>> @st.cache(hash_funcs={MongoClient: id})
    ... def connect_to_database(url):
    ...     return MongoClient(url)

    Alternatively, you can map the type's fully-qualified name
    (e.g. `"pymongo.mongo_client.MongoClient"`) to the hash function instead:

    >>> @st.cache(hash_funcs={"pymongo.mongo_client.MongoClient": id})
    ... def connect_to_database(url):
    ...     return MongoClient(url)

    """
    _LOGGER.debug("Entering st.cache: %s", func)

    # Support passing the params via function decorator, e.g.
    # @st.cache(persist=True, allow_output_mutation=True)
    if func is None:
        return lambda f: cache(
            func=f,
            persist=persist,
            allow_output_mutation=allow_output_mutation,
            show_spinner=show_spinner,
            suppress_st_warning=suppress_st_warning,
            hash_funcs=hash_funcs,
            max_entries=max_entries,
            ttl=ttl,
        )

    # Create the unique key for this function's cache. The cache will be
    # retrieved from inside the wrapped function.
    #
    # A naive implementation would involve simply creating the cache object
    # right here in the wrapper, which in a normal Python script would be
    # executed only once. But in Streamlit, we reload all modules related to a
    # user's app when the app is re-run, which means that - among other
    # things - all function decorators in the app will be re-run, and so any
    # decorator-local objects will be recreated.
    #
    # Furthermore, our caches can be destroyed and recreated (in response
    # to cache clearing, for example), which means that retrieving the
    # function's cache here (so that the wrapped function can save a lookup)
    # is incorrect: the cache itself may be recreated between
    # decorator-evaluation time and decorated-function-execution time. So
    # we must retrieve the cache object *and* perform the cached-value lookup
    # inside the decorated function.

    func_hasher = hashlib.new("md5")

    # Include the function's __module__ and __qualname__ strings in the hash.
    # This means that two identical functions in different modules
    # will not share a hash; it also means that two identical *nested*
    # functions in the same module will not share a hash.
    # We do not pass `hash_funcs` here, because we don't want our function's
    # name to get an unexpected hash.
    update_hash(
        (func.__module__, func.__qualname__),
        hasher=func_hasher,
        hash_funcs=None,
        hash_reason=HashReason.CACHING_FUNC_BODY,
        hash_source=func,
    )

    # Include the function's body in the hash. We *do* pass hash_funcs here,
    # because this step will be hashing any objects referenced in the function
    # body.
    update_hash(
        func,
        hasher=func_hasher,
        hash_funcs=hash_funcs,
        hash_reason=HashReason.CACHING_FUNC_BODY,
        hash_source=func,
    )

    cache_key = func_hasher.hexdigest()
    _LOGGER.debug("mem_cache key for %s.%s: %s", func.__module__,
                  func.__qualname__, cache_key)

    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        """This function wrapper will only call the underlying function in
        the case of a cache miss. Cached objects are stored in the cache/
        directory."""

        if not config.get_option("client.caching"):
            _LOGGER.debug("Purposefully skipping cache")
            return func(*args, **kwargs)

        name = func.__qualname__

        if len(args) == 0 and len(kwargs) == 0:
            message = "Running `%s()`." % name
        else:
            message = "Running `%s(...)`." % name

        def get_or_create_cached_value():
            # First, get the cache that's attached to this function.
            # This cache's key is generated (above) from the function's code.
            mem_cache = _mem_caches.get_cache(cache_key, max_entries, ttl)

            # Next, calculate the key for the value we'll be searching for
            # within that cache. This key is generated from both the function's
            # code and the arguments that are passed into it. (Even though this
            # key is used to index into a per-function cache, it must be
            # globally unique, because it is *also* used for a global on-disk
            # cache that is *not* per-function.)
            value_hasher = hashlib.new("md5")

            if args:
                update_hash(
                    args,
                    hasher=value_hasher,
                    hash_funcs=hash_funcs,
                    hash_reason=HashReason.CACHING_FUNC_ARGS,
                    hash_source=func,
                )

            if kwargs:
                update_hash(
                    kwargs,
                    hasher=value_hasher,
                    hash_funcs=hash_funcs,
                    hash_reason=HashReason.CACHING_FUNC_ARGS,
                    hash_source=func,
                )

            value_key = value_hasher.hexdigest()

            # Avoid recomputing the body's hash by just appending the
            # previously-computed hash to the arg hash.
            value_key = "%s-%s" % (value_key, cache_key)

            _LOGGER.debug("Cache key: %s", value_key)

            try:
                return_value = _read_from_cache(
                    mem_cache=mem_cache,
                    key=value_key,
                    persist=persist,
                    allow_output_mutation=allow_output_mutation,
                    func_or_code=func,
                    hash_funcs=hash_funcs,
                )
                _LOGGER.debug("Cache hit: %s", func)

            except CacheKeyNotFoundError:
                _LOGGER.debug("Cache miss: %s", func)

                with _calling_cached_function(func):
                    if suppress_st_warning:
                        with suppress_cached_st_function_warning():
                            return_value = func(*args, **kwargs)
                    else:
                        return_value = func(*args, **kwargs)

                _write_to_cache(
                    mem_cache=mem_cache,
                    key=value_key,
                    value=return_value,
                    persist=persist,
                    allow_output_mutation=allow_output_mutation,
                    func_or_code=func,
                    hash_funcs=hash_funcs,
                )

            return return_value

        if show_spinner:
            with st.spinner(message):
                return get_or_create_cached_value()
        else:
            return get_or_create_cached_value()

    # Make this a well-behaved decorator by preserving important function
    # attributes.
    try:
        wrapped_func.__dict__.update(func.__dict__)
    except AttributeError:
        pass

    return wrapped_func
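
The `func is None` branch near the top of `cache` is the standard pattern for a decorator that works both bare (`@st.cache`) and parameterized (`@st.cache(persist=True)`). A minimal standalone version of the same pattern, using a hypothetical `trace` decorator that is not part of Streamlit:

import functools

def trace(func=None, label="call"):
    if func is None:
        # Invoked as @trace(label=...): return a decorator that waits for
        # the function, then re-enters with func bound.
        return lambda f: trace(func=f, label=label)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print("%s: %s" % (label, func.__qualname__))
        return func(*args, **kwargs)

    return wrapper

@trace  # bare form: func is passed directly
def a():
    return 1

@trace(label="timed")  # parameterized form: func is None on the first call
def b():
    return 2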