Exemplo n.º 1
0
class MappingKernelManager(MultiKernelManager):
    """A KernelManager that handles
    - File mapping
    - HTTP error handling
    - Kernel message filtering

    On top of MultiKernelManager this class maps API paths to kernel working
    directories under ``root_dir``, raises a 404 ``web.HTTPError`` for unknown
    kernel ids, tracks per-kernel IOPub activity and connection counts,
    buffers kernel messages while no frontend is connected, and can cull
    idle kernels on a periodic timer.
    """

    @default("kernel_manager_class")
    def _default_kernel_manager_class(self):
        return "jupyter_client.ioloop.IOLoopKernelManager"

    kernel_argv = List(Unicode())

    root_dir = Unicode(config=True)

    # kernel_id -> number of currently connected frontends
    _kernel_connections = Dict()

    # kernel_id -> ports captured once the kernel finished starting
    _kernel_ports = Dict()

    # PeriodicCallback driving cull_kernels; created lazily by initialize_culler
    _culler_callback = None

    _initialized_culler = False

    @default("root_dir")
    def _default_root_dir(self):
        """Use the parent application's root_dir, falling back to the cwd."""
        try:
            return self.parent.root_dir
        except AttributeError:
            return os.getcwd()

    @validate("root_dir")
    def _update_root_dir(self, proposal):
        """Do a bit of validation of the root dir."""
        value = proposal["value"]
        if not os.path.isabs(value):
            # If we receive a non-absolute path, make it absolute.
            value = os.path.abspath(value)
        if not exists(value) or not os.path.isdir(value):
            raise TraitError("kernel root dir %r is not a directory" % value)
        return value

    cull_idle_timeout = Integer(
        0,
        config=True,
        help="""Timeout (in seconds) after which a kernel is considered idle and ready to be culled.
        Values of 0 or lower disable culling. Very short timeouts may result in kernels being culled
        for users with poor network connections.""",
    )

    cull_interval_default = 300  # 5 minutes
    cull_interval = Integer(
        cull_interval_default,
        config=True,
        help="""The interval (in seconds) on which to check for idle kernels exceeding the cull timeout value.""",
    )

    cull_connected = Bool(
        False,
        config=True,
        help="""Whether to consider culling kernels which have one or more connections.
        Only effective if cull_idle_timeout > 0.""",
    )

    cull_busy = Bool(
        False,
        config=True,
        help="""Whether to consider culling kernels which are busy.
        Only effective if cull_idle_timeout > 0.""",
    )

    buffer_offline_messages = Bool(
        True,
        config=True,
        help="""Whether messages from kernels whose frontends have disconnected should be buffered in-memory.

        When True (default), messages are buffered and replayed on reconnect,
        avoiding lost messages due to interrupted connectivity.

        Disable if long-running kernels will produce too much output while
        no frontends are connected.
        """,
    )

    kernel_info_timeout = Float(
        60,
        config=True,
        help="""Timeout for giving up on a kernel (in seconds).

        On starting and restarting kernels, we check whether the
        kernel is running and responsive by sending kernel_info_requests.
        This sets the timeout in seconds for how long the kernel can take
        before being presumed dead.
        This affects the MappingKernelManager (which handles kernel restarts)
        and the ZMQChannelsHandler (which handles the startup).
        """,
    )

    # kernel_id -> {"buffer": [(channel, msg_parts), ...], "session_key": str,
    #               "channels": {channel_name: stream}}
    _kernel_buffers = Any()

    @default("_kernel_buffers")
    def _default_kernel_buffers(self):
        return defaultdict(lambda: {"buffer": [], "session_key": "", "channels": {}})

    last_kernel_activity = Instance(
        datetime, help="The last activity on any kernel, including shutting down a kernel"
    )

    def __init__(self, **kwargs):
        self.pinned_superclass = MultiKernelManager
        self.pinned_superclass.__init__(self, **kwargs)
        self.last_kernel_activity = utcnow()

    allowed_message_types = List(
        trait=Unicode(),
        config=True,
        help="""White list of allowed kernel message types.
        When the list is empty, all message types are allowed.
        """,
    )

    allow_tracebacks = Bool(
        True, config=True, help=("Whether to send tracebacks to clients on exceptions.")
    )

    traceback_replacement_message = Unicode(
        "An exception occurred at runtime, which is not shown due to security reasons.",
        config=True,
        help=("Message to print when allow_tracebacks is False, and an exception occurs"),
    )

    # -------------------------------------------------------------------------
    # Methods for managing kernels and sessions
    # -------------------------------------------------------------------------

    def _handle_kernel_died(self, kernel_id):
        """notice that a kernel died"""
        self.log.warning("Kernel %s died, removing from map.", kernel_id)
        self.remove_kernel(kernel_id)

    def cwd_for_path(self, path):
        """Turn API path into absolute OS path."""
        os_path = to_os_path(path, self.root_dir)
        # in the case of documents and kernels not being on the same filesystem,
        # walk up to root_dir if the paths don't exist
        while not os.path.isdir(os_path) and os_path != self.root_dir:
            os_path = os.path.dirname(os_path)
        return os_path

    async def start_kernel(self, kernel_id=None, path=None, **kwargs):
        """Start a kernel for a session and return its kernel_id.

        Parameters
        ----------
        kernel_id : uuid
            The uuid to associate the new kernel with. If this
            is not None, this kernel will be persistent whenever it is
            requested.
        path : API path
            The API path (unicode, '/' delimited) for the cwd.
            Will be transformed to an OS path relative to root_dir.
        kernel_name : str
            The name identifying which kernel spec to launch. This is ignored if
            an existing kernel is returned, but it may be checked in the future.
        """
        if kernel_id is None or kernel_id not in self:
            if path is not None:
                kwargs["cwd"] = self.cwd_for_path(path)
            if kernel_id is not None:
                kwargs["kernel_id"] = kernel_id
            kernel_id = await ensure_async(self.pinned_superclass.start_kernel(self, **kwargs))
            self._kernel_connections[kernel_id] = 0
            fut = asyncio.ensure_future(self._finish_kernel_start(kernel_id))
            # With pending kernels the finish step runs in the background;
            # otherwise wait for it before reporting the kernel as started.
            if not getattr(self, "use_pending_kernels", None):
                await fut
            # add busy/activity markers:
            kernel = self.get_kernel(kernel_id)
            kernel.execution_state = "starting"
            kernel.reason = ""
            kernel.last_activity = utcnow()
            self.log.info("Kernel started: %s", kernel_id)
            self.log.debug("Kernel args: %r", kwargs)

            # Increase the metric of number of kernels running
            # for the relevant kernel type by 1
            KERNEL_CURRENTLY_RUNNING_TOTAL.labels(type=self._kernels[kernel_id].kernel_name).inc()

        else:
            self.log.info("Using existing kernel: %s", kernel_id)

        # Initialize culling if not already
        if not self._initialized_culler:
            self.initialize_culler()

        return kernel_id

    async def _finish_kernel_start(self, kernel_id):
        """Complete startup bookkeeping for a freshly launched kernel.

        If the kernel manager exposes a ``ready`` awaitable, wait for it; on
        failure the exception is logged and the remaining steps are skipped.
        Otherwise capture the kernel's ports, start watching its IOPub
        activity, and register a "dead" restart callback that removes it.
        """
        km = self.get_kernel(kernel_id)
        if hasattr(km, "ready"):
            try:
                await km.ready
            except Exception:
                self.log.exception(km.ready.exception())
                return

        self._kernel_ports[kernel_id] = km.ports
        self.start_watching_activity(kernel_id)
        # register callback for failed auto-restart
        self.add_restart_callback(
            kernel_id,
            lambda: self._handle_kernel_died(kernel_id),
            "dead",
        )

    def ports_changed(self, kernel_id):
        """Used by ZMQChannelsHandler to determine how to coordinate nudge and replays.

        Ports are captured when starting a kernel (via MappingKernelManager).  Ports
        are considered changed (following restarts) if the referenced KernelManager
        is using a set of ports different from those captured at startup.  If changes
        are detected, the captured set is updated and a value of True is returned.

        NOTE: Use is exclusive to ZMQChannelsHandler because this object is a singleton
        instance while ZMQChannelsHandler instances are per WebSocket connection that
        can vary per kernel lifetime.
        """
        changed_ports = self._get_changed_ports(kernel_id)
        if changed_ports:
            # If changed, update captured ports and return True, else return False.
            self.log.debug(f"Port change detected for kernel: {kernel_id}")
            self._kernel_ports[kernel_id] = changed_ports
            return True
        return False

    def _get_changed_ports(self, kernel_id):
        """Internal method to test if a kernel's ports have changed and, if so, return their values.

        This method does NOT update the captured ports for the kernel as that can only be done
        by ZMQChannelsHandler, but instead returns the new list of ports if they are different
        than those captured at startup.  This enables the ability to conditionally restart
        activity monitoring immediately following a kernel's restart (if ports have changed).
        """
        # Get current ports and return comparison with ports captured at startup.
        km = self.get_kernel(kernel_id)
        if km.ports != self._kernel_ports[kernel_id]:
            return km.ports
        return None

    def start_buffering(self, kernel_id, session_key, channels):
        """Start buffering messages for a kernel

        When ``buffer_offline_messages`` is False the streams are simply closed.

        Parameters
        ----------
        kernel_id : str
            The id of the kernel to start buffering.
        session_key : str
            The session_key of the session that owns the buffer; only a later
            ``get_buffer`` call with the same key will receive it.
        channels : dict({'channel': ZMQStream})
            The zmq channels whose messages should be buffered.
        """

        if not self.buffer_offline_messages:
            for stream in channels.values():
                stream.close()
            return

        self.log.info("Starting buffering for %s", session_key)
        self._check_kernel_id(kernel_id)
        # clear previous buffering state
        self.stop_buffering(kernel_id)
        buffer_info = self._kernel_buffers[kernel_id]
        # record the session key because only one session can buffer
        buffer_info["session_key"] = session_key
        # TODO: the buffer should likely be a memory bounded queue, we're starting with a list to keep it simple
        buffer_info["buffer"] = []
        buffer_info["channels"] = channels

        # forward any future messages to the internal buffer
        def buffer_msg(channel, msg_parts):
            self.log.debug("Buffering msg on %s:%s", kernel_id, channel)
            buffer_info["buffer"].append((channel, msg_parts))

        for channel, stream in channels.items():
            stream.on_recv(partial(buffer_msg, channel))

    def get_buffer(self, kernel_id, session_key):
        """Get the buffer for a given kernel

        If ``session_key`` matches the session that started buffering, the
        buffer info is removed from the manager and returned. Otherwise the
        buffer is discarded via ``stop_buffering`` and None is returned.

        Parameters
        ----------
        kernel_id : str
            The id of the kernel whose buffer is requested.
        session_key : str
            The session_key, if any, that should get the buffer.
            If the session_key matches the current buffered session_key,
            the buffer will be returned.

        Returns
        -------
        dict or None
            The buffer info dict ("buffer", "session_key", "channels"), or None.
        """
        self.log.debug("Getting buffer for %s", kernel_id)
        if kernel_id not in self._kernel_buffers:
            return

        buffer_info = self._kernel_buffers[kernel_id]
        if buffer_info["session_key"] == session_key:
            # remove buffer
            self._kernel_buffers.pop(kernel_id)
            # only return buffer_info if it's a match
            return buffer_info
        else:
            self.stop_buffering(kernel_id)

    def stop_buffering(self, kernel_id):
        """Stop buffering kernel messages

        Closes the buffering streams and discards any buffered messages.

        Parameters
        ----------
        kernel_id : str
            The id of the kernel to stop buffering.
        """
        self.log.debug("Clearing buffer for %s", kernel_id)
        self._check_kernel_id(kernel_id)

        if kernel_id not in self._kernel_buffers:
            return
        buffer_info = self._kernel_buffers.pop(kernel_id)
        # close buffering streams
        for stream in buffer_info["channels"].values():
            if not stream.closed():
                stream.on_recv(None)
                stream.close()

        msg_buffer = buffer_info["buffer"]
        if msg_buffer:
            self.log.info(
                "Discarding %s buffered messages for %s",
                len(msg_buffer),
                buffer_info["session_key"],
            )

    def shutdown_kernel(self, kernel_id, now=False, restart=False):
        """Shutdown a kernel by kernel_id"""
        self._check_kernel_id(kernel_id)
        self.stop_watching_activity(kernel_id)
        self.stop_buffering(kernel_id)

        # Decrease the metric of number of kernels
        # running for the relevant kernel type by 1
        KERNEL_CURRENTLY_RUNNING_TOTAL.labels(type=self._kernels[kernel_id].kernel_name).dec()

        self.pinned_superclass.shutdown_kernel(self, kernel_id, now=now, restart=restart)
        # Unlike its async sibling method in AsyncMappingKernelManager, removing the kernel_id
        # from the connections dictionary isn't as problematic before the shutdown since the
        # method is synchronous.  However, we'll keep the relative call orders the same from
        # a maintenance perspective.
        self._kernel_connections.pop(kernel_id, None)
        self._kernel_ports.pop(kernel_id, None)

    async def restart_kernel(self, kernel_id, now=False):
        """Restart a kernel by kernel_id

        Returns a Future that resolves with the kernel_info reply once the
        restarted kernel answers, or fails with TimeoutError after
        ``kernel_info_timeout`` seconds, or with RuntimeError if the kernel's
        "dead" restart callback fires first.
        """
        self._check_kernel_id(kernel_id)
        await ensure_async(self.pinned_superclass.restart_kernel(self, kernel_id, now=now))
        kernel = self.get_kernel(kernel_id)
        # return a Future that will resolve when the kernel has successfully restarted
        channel = kernel.connect_shell()
        future = Future()

        def finish():
            """Common cleanup when restart finishes/fails for any reason."""
            if not channel.closed():
                channel.close()
            # `loop` and `timeout` are bound below, before any callback can fire.
            loop.remove_timeout(timeout)
            kernel.remove_restart_callback(on_restart_failed, "dead")

        def on_reply(msg):
            self.log.debug("Kernel info reply received: %s", kernel_id)
            finish()
            if not future.done():
                future.set_result(msg)

        def on_timeout():
            self.log.warning("Timeout waiting for kernel_info_reply: %s", kernel_id)
            finish()
            if not future.done():
                future.set_exception(TimeoutError("Timeout waiting for restart"))

        def on_restart_failed():
            self.log.warning("Restarting kernel failed: %s", kernel_id)
            finish()
            if not future.done():
                future.set_exception(RuntimeError("Restart failed"))

        kernel.add_restart_callback(on_restart_failed, "dead")
        kernel.session.send(channel, "kernel_info_request")
        channel.on_recv(on_reply)
        loop = IOLoop.current()
        timeout = loop.add_timeout(loop.time() + self.kernel_info_timeout, on_timeout)
        # Re-establish activity watching if ports have changed...
        if self._get_changed_ports(kernel_id) is not None:
            self.stop_watching_activity(kernel_id)
            self.start_watching_activity(kernel_id)
        return future

    def notify_connect(self, kernel_id):
        """Notice a new connection to a kernel"""
        if kernel_id in self._kernel_connections:
            self._kernel_connections[kernel_id] += 1

    def notify_disconnect(self, kernel_id):
        """Notice a disconnection from a kernel"""
        if kernel_id in self._kernel_connections:
            self._kernel_connections[kernel_id] -= 1

    def kernel_model(self, kernel_id):
        """Return a JSON-safe dict representing a kernel

        For use in representing kernels in the JSON APIs.
        """
        self._check_kernel_id(kernel_id)
        kernel = self._kernels[kernel_id]

        model = {
            "id": kernel_id,
            "name": kernel.kernel_name,
            "last_activity": isoformat(kernel.last_activity),
            "execution_state": kernel.execution_state,
            "connections": self._kernel_connections.get(kernel_id, 0),
        }
        if getattr(kernel, "reason", None):
            model["reason"] = kernel.reason
        return model

    def list_kernels(self):
        """Return a list of kernel models (see kernel_model) for the running kernels."""
        kernels = []
        kernel_ids = self.pinned_superclass.list_kernel_ids(self)
        for kernel_id in kernel_ids:
            try:
                model = self.kernel_model(kernel_id)
                kernels.append(model)
            except (web.HTTPError, KeyError):
                pass  # Probably due to a (now) non-existent kernel, continue building the list
        return kernels

    # override _check_kernel_id to raise 404 instead of KeyError
    def _check_kernel_id(self, kernel_id):
        """Check a that a kernel_id exists and raise 404 if not."""
        if kernel_id not in self:
            raise web.HTTPError(404, "Kernel does not exist: %s" % kernel_id)

    # monitoring activity:

    def start_watching_activity(self, kernel_id):
        """Start watching IOPub messages on a kernel for activity.

        - update last_activity on every message
        - record execution_state from status messages
        """
        kernel = self._kernels[kernel_id]
        # add busy/activity markers:
        kernel.execution_state = "starting"
        kernel.reason = ""
        kernel.last_activity = utcnow()
        kernel._activity_stream = kernel.connect_iopub()
        session = Session(
            config=kernel.session.config,
            key=kernel.session.key,
        )

        def record_activity(msg_list):
            """Record an IOPub message arriving from a kernel"""
            self.last_kernel_activity = kernel.last_activity = utcnow()

            _idents, fed_msg_list = session.feed_identities(msg_list)
            msg = session.deserialize(fed_msg_list)

            msg_type = msg["header"]["msg_type"]
            if msg_type == "status":
                kernel.execution_state = msg["content"]["execution_state"]
                self.log.debug(
                    "activity on %s: %s (%s)", kernel_id, msg_type, kernel.execution_state
                )
            else:
                self.log.debug("activity on %s: %s", kernel_id, msg_type)

        kernel._activity_stream.on_recv(record_activity)

    def stop_watching_activity(self, kernel_id):
        """Stop watching IOPub messages on a kernel for activity."""
        kernel = self._kernels[kernel_id]
        if getattr(kernel, "_activity_stream", None):
            kernel._activity_stream.close()
            kernel._activity_stream = None

    def initialize_culler(self):
        """Start idle culler if 'cull_idle_timeout' is greater than zero.

        Regardless of that value, set flag that we've been here.
        """
        if not self._initialized_culler and self.cull_idle_timeout > 0:
            if self._culler_callback is None:
                if self.cull_interval <= 0:  # handle case where user set invalid value
                    self.log.warning(
                        "Invalid value for 'cull_interval' detected (%s) - using default value (%s).",
                        self.cull_interval,
                        self.cull_interval_default,
                    )
                    self.cull_interval = self.cull_interval_default
                # PeriodicCallback takes its interval in milliseconds.
                self._culler_callback = PeriodicCallback(
                    self.cull_kernels, 1000 * self.cull_interval
                )
                self.log.info(
                    "Culling kernels with idle durations > %s seconds at %s second intervals ...",
                    self.cull_idle_timeout,
                    self.cull_interval,
                )
                if self.cull_busy:
                    self.log.info("Culling kernels even if busy")
                if self.cull_connected:
                    self.log.info("Culling kernels even with connected clients")
                self._culler_callback.start()

        self._initialized_culler = True

    async def cull_kernels(self):
        """Check every kernel once and cull those that are dead or idle.

        An exception raised while checking one kernel is logged and does not
        stop the sweep over the remaining kernels.
        """
        self.log.debug(
            "Polling every %s seconds for kernels idle > %s seconds...",
            self.cull_interval,
            self.cull_idle_timeout,
        )
        # Iterate over a snapshot of the ids: culling awaits shutdowns, which
        # mutate self._kernels while the loop is running.
        for kernel_id in list(self._kernels):
            try:
                await self.cull_kernel_if_idle(kernel_id)
            except Exception as e:
                self.log.exception(
                    "The following exception was encountered while checking the idle duration of kernel %s: %s",
                    kernel_id,
                    e,
                )

    async def cull_kernel_if_idle(self, kernel_id):
        """Shut down ``kernel_id`` if it is dead, or idle beyond ``cull_idle_timeout``.

        An idle kernel is only culled when it is not busy (unless ``cull_busy``)
        and has no connections (unless ``cull_connected``).
        """
        kernel = self._kernels.get(kernel_id)
        if kernel is None:
            # Removed while earlier kernels in the same sweep were being culled.
            return

        # execution_state is monkey-patched onto the kernel manager, so it may be absent.
        if getattr(kernel, "execution_state", None) == "dead":
            self.log.warning(
                "Culling '%s' dead kernel '%s' (%s).",
                kernel.execution_state,
                kernel.kernel_name,
                kernel_id,
            )
            await ensure_async(self.shutdown_kernel(kernel_id))
            return

        if hasattr(
            kernel, "last_activity"
        ):  # last_activity is monkey-patched, so ensure that has occurred
            self.log.debug(
                "kernel_id=%s, kernel_name=%s, last_activity=%s",
                kernel_id,
                kernel.kernel_name,
                kernel.last_activity,
            )
            dt_now = utcnow()
            dt_idle = dt_now - kernel.last_activity
            # Compute idle properties
            is_idle_time = dt_idle > timedelta(seconds=self.cull_idle_timeout)
            is_idle_execute = self.cull_busy or (kernel.execution_state != "busy")
            connections = self._kernel_connections.get(kernel_id, 0)
            is_idle_connected = self.cull_connected or not connections
            # Cull the kernel if all three criteria are met
            if is_idle_time and is_idle_execute and is_idle_connected:
                idle_duration = int(dt_idle.total_seconds())
                self.log.warning(
                    "Culling '%s' kernel '%s' (%s) with %d connections due to %s seconds of inactivity.",
                    kernel.execution_state,
                    kernel.kernel_name,
                    kernel_id,
                    connections,
                    idle_duration,
                )
                await ensure_async(self.shutdown_kernel(kernel_id))
Exemplo n.º 2
0
class DrawControl(Control):
    """Map control wrapping the Leaflet.draw drawing toolbar.

    Each drawing tool is switched on by giving its trait a non-empty dict of
    options and switched off by leaving it empty. Draw messages coming back
    from the frontend are stored in ``last_draw`` / ``last_action`` and
    forwarded to the callbacks registered through :meth:`on_draw`.
    """

    _view_name = Unicode('LeafletDrawControlView').tag(sync=True)
    _model_name = Unicode('LeafletDrawControlModel').tag(sync=True)

    # Drawing tools enabled by default: a non-empty options dict turns a tool
    # on, and Leaflet style options go in the 'shapeOptions' sub-dict.
    # See https://github.com/Leaflet/Leaflet.draw#polylineoptions
    # TODO: mutable default value!
    polyline = Dict({'shapeOptions': {}}).tag(sync=True)
    # See https://github.com/Leaflet/Leaflet.draw#polygonoptions
    # TODO: mutable default value!
    polygon = Dict({'shapeOptions': {}}).tag(sync=True)
    circlemarker = Dict({'shapeOptions': {}}).tag(sync=True)

    # Drawing tools disabled by default (empty dict); pass options to enable.
    circle = Dict().tag(sync=True)
    rectangle = Dict().tag(sync=True)
    marker = Dict().tag(sync=True)

    # Toolbar buttons for editing / deleting existing shapes.
    edit = Bool(True).tag(sync=True)
    remove = Bool(True).tag(sync=True)

    # Python-side record of the most recent draw event (not synced).
    last_draw = Dict({'type': 'Feature', 'geometry': None})
    last_action = Unicode()

    _draw_callbacks = Instance(CallbackDispatcher, ())

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.on_msg(self._handle_leaflet_event)

    def _handle_leaflet_event(self, _, content, buffers):
        """Handle a custom frontend message of the form 'draw<...>:<action>'."""
        if not content.get('event', '').startswith('draw'):
            return
        # The part after ':' is the draw action reported to callbacks.
        _prefix, action = content.get('event').split(':')
        self.last_draw = content.get('geo_json')
        self.last_action = action
        self._draw_callbacks(self, action=action, geo_json=self.last_draw)

    def on_draw(self, callback, remove=False):
        """Register ``callback`` for draw events, or unregister it when ``remove`` is True.

        Callbacks are invoked as ``callback(control, action=..., geo_json=...)``.
        """
        self._draw_callbacks.register_callback(callback, remove=remove)

    def _send_draw_command(self, command):
        """Send ``{'msg': command}`` to the frontend view."""
        self.send({'msg': command})

    def clear(self):
        """Send the 'clear' command to the frontend."""
        self._send_draw_command('clear')

    def clear_polylines(self):
        """Send the 'clear_polylines' command to the frontend."""
        self._send_draw_command('clear_polylines')

    def clear_polygons(self):
        """Send the 'clear_polygons' command to the frontend."""
        self._send_draw_command('clear_polygons')

    def clear_circles(self):
        """Send the 'clear_circles' command to the frontend."""
        self._send_draw_command('clear_circles')

    def clear_circle_markers(self):
        """Send the 'clear_circle_markers' command to the frontend."""
        self._send_draw_command('clear_circle_markers')

    def clear_rectangles(self):
        """Send the 'clear_rectangles' command to the frontend."""
        self._send_draw_command('clear_rectangles')

    def clear_markers(self):
        """Send the 'clear_markers' command to the frontend."""
        self._send_draw_command('clear_markers')
Exemplo n.º 3
0
class HubAuth(SingletonConfigurable):
    """A class for authenticating with JupyterHub

    This can be used by any application.

    If using tornado, use via :class:`HubAuthenticated` mixin.
    If using manually, use the ``.user_for_cookie(cookie_value)`` method
    to identify the user corresponding to a given cookie value.

    The following config must be set:

    - api_token (token for authenticating with JupyterHub API),
      fetched from the JUPYTERHUB_API_TOKEN env by default.

    The following config MAY be set:

    - api_url: the base URL of the Hub's internal API,
      fetched from JUPYTERHUB_API_URL by default.
    - cookie_cache_max_age: the number of seconds responses
      from the Hub should be cached.
    - login_url (the *public* ``/hub/login`` URL of the Hub).
    - cookie_name: the name of the cookie I should be using,
      if different from the default (unlikely).

    """

    # Public host of the Hub (used for subdomain deployments).
    hub_host = Unicode(
        '',
        help="""The public host of JupyterHub
        
        Only used if JupyterHub is spreading servers across subdomains.
        """,
    ).tag(config=True)

    @default('hub_host')
    def _default_hub_host(self):
        """Default hub_host: the JUPYTERHUB_HOST env var, or '' when unset."""
        return os.environ.get('JUPYTERHUB_HOST', '')

    # URL prefix this application is served under; normalized by _add_slash.
    base_url = Unicode(
        os.getenv('JUPYTERHUB_SERVICE_PREFIX') or '/',
        help="""The base URL prefix of this application

        e.g. /services/service-name/ or /user/name/

        Default: get from JUPYTERHUB_SERVICE_PREFIX
        """,
    ).tag(config=True)

    @validate('base_url')
    def _add_slash(self, proposal):
        """Normalize base_url so it both begins and ends with a '/'."""
        url = proposal['value']
        if not url.startswith('/'):
            url = '/' + url
        return url if url.endswith('/') else url + '/'

    # where is the hub
    api_url = Unicode(
        os.getenv('JUPYTERHUB_API_URL') or 'http://127.0.0.1:8081/hub/api',
        help="""The base API URL of the Hub.

        Typically `http://hub-ip:hub-port/hub/api`
        """,
    ).tag(config=True)

    @default('api_url')
    def _api_url(self):
        """Default api_url: JUPYTERHUB_API_URL if set, else a localhost URL under hub_prefix."""
        from_env = os.getenv('JUPYTERHUB_API_URL')
        return from_env or 'http://127.0.0.1:8081' + url_path_join(self.hub_prefix, 'api')

    # Token sent in the Authorization header of every Hub API request.
    api_token = Unicode(
        os.getenv('JUPYTERHUB_API_TOKEN', ''),
        help="""API key for accessing Hub API.

        Generate with `jupyterhub token [username]` or add to JupyterHub.services config.
        """,
    ).tag(config=True)

    hub_prefix = Unicode(
        '/hub/',
        help="""The URL prefix for the Hub itself.

        Typically /hub/
        """,
    ).tag(config=True)

    @default('hub_prefix')
    def _default_hub_prefix(self):
        """Default hub_prefix: '<JUPYTERHUB_BASE_URL or />hub/'."""
        base = os.getenv('JUPYTERHUB_BASE_URL') or '/'
        return url_path_join(base, 'hub') + '/'

    login_url = Unicode(
        '/hub/login',
        help="""The login URL to use

        Typically /hub/login
        """,
    ).tag(config=True)

    @default('login_url')
    def _default_login_url(self):
        """Default login_url: hub_host followed by '<hub_prefix>login'."""
        host = self.hub_host
        return host + url_path_join(self.hub_prefix, 'login')

    cookie_name = Unicode(
        'jupyterhub-services',
        help="""The name of the cookie I should be looking for""",
    ).tag(config=True)

    cookie_options = Dict(
        help="""Additional options to pass when setting cookies.

        Can include things like `expires_days=None` for session-expiry
        or `secure=True` if served on HTTPS and default HTTPS discovery fails
        (e.g. behind some proxies).
        """,
    ).tag(config=True)

    @default('cookie_options')
    def _default_cookie_options(self):
        """Default cookie_options: JSON from JUPYTERHUB_COOKIE_OPTIONS, or {} when unset/empty."""
        raw_options = os.environ.get('JUPYTERHUB_COOKIE_OPTIONS')
        return json.loads(raw_options) if raw_options else {}

    # Deprecated alias; any value assigned is forwarded to cache_max_age.
    cookie_cache_max_age = Integer(help="DEPRECATED. Use cache_max_age")

    @observe('cookie_cache_max_age')
    def _deprecated_cookie_cache(self, change):
        """Warn about the deprecated trait and copy its new value to cache_max_age."""
        notice = "cookie_cache_max_age is deprecated in JupyterHub 0.8. Use cache_max_age instead."
        warnings.warn(notice)
        self.cache_max_age = change.new

    cache_max_age = Integer(
        300,
        help="""The maximum time (in seconds) to cache the Hub's responses for authentication.

        A larger value reduces load on the Hub and occasional response lag.
        A smaller value reduces propagation time of changes on the Hub (rare).

        Default: 300 (five minutes)
        """,
    ).tag(config=True)
    # Cache of Hub authorization replies, built with cache_max_age.
    cache = Instance(_ExpiringDict, allow_none=False)

    @default('cache')
    def _default_cache(self):
        """Default cache: an _ExpiringDict built with the configured cache_max_age."""
        max_age = self.cache_max_age
        return _ExpiringDict(max_age)

    def _check_hub_authorization(self, url, cache_key=None, use_cache=True):
        """Ask the Hub which user, if any, is authorized at ``url``.

        Args:
            url (str): The Hub API URL to query for authorization
                       (e.g. http://127.0.0.1:8081/hub/api/authorizations/token/abc-def)
            cache_key (str): Key used for the cache lookup and store; required when use_cache is True.
            use_cache (bool): Pass False to bypass the cache for both lookup and store (default: True).

        Returns:
            user_model (dict): The user model, or None if the Hub identified no user.

        Raises:
            ValueError: if use_cache is True and cache_key is None.
            HTTPError: if the request failed for a reason other than no such user.
        """
        if use_cache:
            if cache_key is None:
                raise ValueError("cache_key is required when using cache")
            # a cached reply spares a round-trip to the Hub
            hit = self.cache.get(cache_key)
            if hit is not None:
                return hit
            app_log.debug("Cache miss: %s" % cache_key)

        user_model = self._api_request('GET', url, allow_404=True)
        if user_model is not None:
            app_log.debug("Received request from Hub user %s", user_model)
        else:
            app_log.warning("No Hub user identified for request")

        if use_cache:
            # store the reply; a stored None is not served back by the lookup above
            self.cache[cache_key] = user_model
        return user_model

    def _api_request(self, method, url, **kwargs):
        """Make an authenticated request to the Hub API and return its decoded JSON body.

        An ``Authorization: token <api_token>`` header is added unless the
        caller already supplied one; remaining keyword arguments are passed
        through to ``requests.request``.

        Args:
            method (str): HTTP method, e.g. 'GET'.
            url (str): Full URL of the Hub API endpoint.
            allow_404 (bool): optional keyword; when True a 404 reply returns
                None instead of raising (default: False).

        Returns:
            The decoded JSON body of a successful reply, or None for an allowed 404.

        Raises:
            HTTPError: 500 if the Hub cannot be reached, rejects the token (403),
                or answers with another 4xx status; 502 if the Hub answers with 5xx.
        """
        allow_404 = kwargs.pop('allow_404', False)
        headers = kwargs.setdefault('headers', {})
        headers.setdefault('Authorization', 'token %s' % self.api_token)
        try:
            r = requests.request(method, url, **kwargs)
        except requests.ConnectionError as e:
            app_log.error("Error connecting to %s: %s", self.api_url, e)
            msg = "Failed to connect to Hub API at %r." % self.api_url
            msg += "  Is the Hub accessible at this URL (from host: %s)?" % socket.gethostname(
            )
            if '127.0.0.1' in self.api_url:
                msg += "  Make sure to set c.JupyterHub.hub_ip to an IP accessible to" + \
                       " single-user servers if the servers are not on the same host as the Hub."
            raise HTTPError(500, msg) from e

        data = None
        if r.status_code == 404 and allow_404:
            pass
        elif r.status_code == 403:
            app_log.error(
                "I don't have permission to check authorization with JupyterHub, my auth token may have expired: [%i] %s",
                r.status_code, r.reason)
            app_log.error(r.text)
            raise HTTPError(
                500,
                "Permission failure checking authorization, I may need a new token"
            )
        elif r.status_code >= 500:
            app_log.error("Upstream failure verifying auth token: [%i] %s",
                          r.status_code, r.reason)
            app_log.error(r.text)
            raise HTTPError(
                502, "Failed to check authorization (upstream problem)")
        elif r.status_code >= 400:
            app_log.warning("Failed to check authorization: [%i] %s",
                            r.status_code, r.reason)
            app_log.warning(r.text)
            msg = "Failed to check authorization"
            # pass on error_description from oauth failure, if the reply carries one.
            # The body may be non-JSON, a non-object, or lack the key entirely.
            try:
                description = r.json().get("error_description")
            except Exception:
                description = None
            if description:
                msg += ": " + str(description)
            raise HTTPError(500, msg)
        else:
            data = r.json()

        return data

    def user_for_cookie(self, encrypted_cookie, use_cache=True, session_id=''):
        """Ask the Hub which user a given cookie belongs to.

        Args:
            encrypted_cookie (str): the cookie value as received; it is
                URL-quoted into the request path, and the Hub does the decryption.
            use_cache (bool): pass False to bypass cached replies (default: True)
            session_id (str): jupyterhub session id, made part of the cache key

        Returns:
            user_model (dict): The user model, if a user is identified, None if authentication fails.

            The 'name' field contains the user's name.
        """
        quoted_cookie = quote(encrypted_cookie, safe='')
        endpoint = url_path_join(
            self.api_url,
            "authorizations/cookie",
            self.cookie_name,
            quoted_cookie,
        )
        key = 'cookie:{}:{}'.format(session_id, encrypted_cookie)
        return self._check_hub_authorization(
            url=endpoint, cache_key=key, use_cache=use_cache)

    def user_for_token(self, token, use_cache=True, session_id=''):
        """Ask the Hub which user a given API token belongs to.

        Args:
            token (str): the token, URL-quoted into the request path
            use_cache (bool): pass False to bypass cached replies (default: True)
            session_id (str): jupyterhub session id, made part of the cache key

        Returns:
            user_model (dict): The user model, if a user is identified, None if authentication fails.

            The 'name' field contains the user's name.
        """
        endpoint = url_path_join(
            self.api_url, "authorizations/token", quote(token, safe=''))
        key = 'token:{}:{}'.format(session_id, token)
        return self._check_hub_authorization(
            url=endpoint, cache_key=key, use_cache=use_cache)

    # Request header searched for a token, and the pattern that pulls the token
    # out of a value such as "token <token>" (matched case-insensitively).
    auth_header_name = 'Authorization'
    # Raw string: '\s' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    auth_header_pat = re.compile(r'token\s+(.+)', re.IGNORECASE)

    def get_token(self, handler):
        """Return the user token sent with a request, or '' when there is none.

        Looked up in this order:

        - URL parameter: ?token=<token>
        - header: Authorization: token <token>
        """
        from_url = handler.get_argument('token', '')
        if from_url:
            return from_url
        # nothing usable in the URL: fall back to the Authorization header
        header_value = handler.request.headers.get(self.auth_header_name, '')
        match = self.auth_header_pat.match(header_value)
        return match.group(1) if match else from_url

    def _get_user_cookie(self, handler):
        """Resolve the user model from the Hub cookie on ``handler``.

        Returns None when the request has no (or an empty) Hub cookie.
        """
        cookie_value = handler.get_cookie(self.cookie_name)
        sid = self.get_session_id(handler)
        if not cookie_value:
            return None
        return self.user_for_cookie(cookie_value, session_id=sid)

    def get_session_id(self, handler):
        """Return the jupyterhub session id for a request.

        Read from the ``jupyterhub-session-id`` cookie, with '' passed as the
        default for a missing cookie.
        """
        cookie_name = 'jupyterhub-session-id'
        return handler.get_cookie(cookie_name, '')

    def get_user(self, handler):
        """Get the Hub user for a given tornado handler.

        A token (URL parameter or Authorization header) is tried first. If it
        does not identify a user, the Hub cookie is checked instead.

        Args:
            handler (tornado.web.RequestHandler): the current request handler

        Returns:
            user_model (dict): The user model, if a user is identified, None if authentication fails.

            The 'name' field contains the user's name.
        """
        # Memoized per handler. This can be called again while an error page
        # is rendered, and that call must not redo (and re-fail) the lookup.
        if hasattr(handler, '_cached_hub_user'):
            return handler._cached_hub_user

        # Seed the memo before asking the Hub, so a re-entrant call sees None.
        handler._cached_hub_user = None
        sid = self.get_session_id(handler)

        found = None
        token = self.get_token(handler)
        if token:
            found = self.user_for_token(token, session_id=sid)
            if found:
                handler._token_authenticated = True

        # the token identified nobody (or there was no token): try the cookie
        if found is None:
            found = self._get_user_cookie(handler)

        handler._cached_hub_user = found
        if not found:
            app_log.debug("No user identified")
        return found
Exemplo n.º 4
0
class BokehSPE(Tool):
    """Interactive Bokeh tool for inspecting SPE spectrum extraction and fitting.

    ``setup`` builds the reader, calibration chain, fitter and plot objects.
    ``start`` calibrates every event in the file, storing per-pixel integrated
    ("area") and single-sample ("height") values, then assembles the Bokeh
    layout. ``finish`` adds that layout to the current Bokeh document.
    """
    name = "BokehSPE"
    description = "Interactively explore the steps in obtaining and fitting " \
                  "SPE spectrum"

    aliases = Dict(
        dict(
            r='EventFileReaderFactory.reader',
            f='EventFileReaderFactory.input_path',
            max_events='EventFileReaderFactory.max_events',
            ped='CameraR1CalibratorFactory.pedestal_path',
            tf='CameraR1CalibratorFactory.tf_path',
            pe='CameraR1CalibratorFactory.pe_path',
            fitter='ChargeFitterFactory.fitter',
        ))
    classes = List([
        EventFileReaderFactory,
        CameraR1CalibratorFactory,
        ChargeFitterFactory,
    ])

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # currently displayed event and its identifiers
        self._event = None
        self._event_index = None
        self._event_id = None
        self._active_pixel = 0

        # Bokeh widgets (created in start)
        self.w_event_index = None
        self.w_goto_event_index = None
        self.w_hoa = None
        self.w_fitspectrum = None
        self.w_fitcamera = None
        self.layout = None

        # reader, calibrators and per-event result arrays
        self.reader = None
        self.r1 = None
        self.dl0 = None
        self.dl1 = None
        self.dl1_height = None
        self.area = None
        self.height = None

        self.n_events = None
        self.n_pixels = None
        self.n_samples = None

        self.cleaner = None
        self.extractor = None
        self.extractor_height = None
        self.dead = None
        self.fitter = None

        self.neighbours2d = None
        self.stage_names = None

        # plot components (created in setup)
        self.p_camera_area = None
        self.p_camera_fit_gain = None
        self.p_camera_fit_brightness = None
        self.p_fitter = None
        self.p_stage_viewer = None
        self.p_fit_viewer = None
        self.p_fit_table = None

    def setup(self):
        """Create the reader, calibration chain, fitter and plot components."""
        self.log_format = "%(levelname)s: %(message)s [%(name)s.%(funcName)s]"
        kwargs = dict(config=self.config, tool=self)

        reader_factory = EventFileReaderFactory(**kwargs)
        reader_class = reader_factory.get_class()
        self.reader = reader_class(**kwargs)

        r1_factory = CameraR1CalibratorFactory(origin=self.reader.origin,
                                               **kwargs)
        r1_class = r1_factory.get_class()
        self.r1 = r1_class(**kwargs)

        self.dl0 = CameraDL0Reducer(**kwargs)

        self.cleaner = CHECMWaveformCleanerAverage(**kwargs)
        self.extractor = AverageWfPeakIntegrator(**kwargs)
        # a one-sample window at the peak gives the pulse height
        self.extractor_height = SimpleIntegrator(window_shift=0,
                                                 window_width=1,
                                                 **kwargs)

        self.dl1 = CameraDL1Calibrator(extractor=self.extractor,
                                       cleaner=self.cleaner,
                                       **kwargs)
        self.dl1_height = CameraDL1Calibrator(extractor=self.extractor_height,
                                              cleaner=self.cleaner,
                                              **kwargs)

        self.dead = Dead()

        fitter_factory = ChargeFitterFactory(**kwargs)
        fitter_class = fitter_factory.get_class()
        self.fitter = fitter_class(**kwargs)

        self.n_events = self.reader.num_events
        first_event = self.reader.get_event(0)
        self.n_pixels = first_event.inst.num_pixels[0]
        self.n_samples = first_event.r0.tel[0].num_samples

        geom = CameraGeometry.guess(*first_event.inst.pixel_pos[0],
                                    first_event.inst.optical_foclen[0])
        self.neighbours2d = get_neighbours_2d(geom.pix_x, geom.pix_y)

        # Get stage names
        self.stage_names = [
            '0: raw', '1: baseline_sub', '2: no_pulse', '3: smooth_baseline',
            '4: smooth_wf', '5: cleaned'
        ]

        # Init Plots
        self.p_camera_area = Camera(self, self.neighbours2d, "Area", geom)
        self.p_camera_fit_gain = Camera(self, self.neighbours2d, "Gain", geom)
        self.p_camera_fit_brightness = Camera(self, self.neighbours2d,
                                              "Brightness", geom)
        self.p_fitter = FitterWidget(fitter=self.fitter, **kwargs)
        self.p_stage_viewer = StageViewer(**kwargs)
        self.p_fit_viewer = FitViewer(**kwargs)
        self.p_fit_table = FitTable(**kwargs)

    def start(self):
        """Calibrate every event into ``area``/``height``, then build the layout."""
        # Prepare storage array: one row per event, one column per pixel
        self.area = np.zeros((self.n_events, self.n_pixels))
        self.height = np.zeros((self.n_events, self.n_pixels))

        source = self.reader.read()
        desc = "Looping through file"
        for event in tqdm(source, total=self.n_events, desc=desc):
            index = event.count

            self.r1.calibrate(event)
            self.dl0.reduce(event)
            self.dl1.calibrate(event)
            # copy: the next calibration overwrites event.dl1 in place
            peak_area = np.copy(event.dl1.tel[0].image)
            self.dl1_height.calibrate(event)
            peak_height = np.copy(event.dl1.tel[0].image)

            self.area[index] = peak_area
            self.height[index] = peak_height

        # Setup Plots
        self.p_camera_area.enable_pixel_picker()
        self.p_camera_area.add_colorbar()
        self.p_camera_fit_gain.enable_pixel_picker()
        self.p_camera_fit_gain.add_colorbar()
        self.p_camera_fit_brightness.enable_pixel_picker()
        self.p_camera_fit_brightness.add_colorbar()
        self.p_fitter.create()
        self.p_stage_viewer.create(self.neighbours2d, self.stage_names)
        self.p_fit_viewer.create(self.p_fitter.fitter.subfit_labels)
        self.p_fit_table.create()

        # Setup widgets
        self.create_event_index_widget()
        self.create_goto_event_index_widget()
        self.event_index = 0
        self.create_hoa_widget()
        self.create_fitspectrum_widget()
        self.create_fitcamera_widget()

        # Get bokeh layouts
        l_camera_area = self.p_camera_area.layout
        l_camera_fit_gain = self.p_camera_fit_gain.layout
        l_camera_fit_brightness = self.p_camera_fit_brightness.layout
        l_fitter = self.p_fitter.layout
        l_stage_viewer = self.p_stage_viewer.layout
        l_fit_viewer = self.p_fit_viewer.layout
        l_fit_table = self.p_fit_table.layout

        # Setup layout
        self.layout = layout([
            [self.w_hoa, self.w_fitspectrum, self.w_fitcamera],
            [l_camera_fit_brightness, l_fit_viewer, l_fitter],
            [l_camera_fit_gain, l_fit_table],
            [l_camera_area, self.w_goto_event_index, self.w_event_index],
            [Div(text="Stage Viewer")],
            [l_stage_viewer],
        ])

    def finish(self):
        """Attach the layout to the current Bokeh document."""
        curdoc().add_root(self.layout)
        curdoc().title = "Event Viewer"

    def fit_spectrum(self, pix):
        """Fit the spectrum of pixel ``pix`` over all events.

        The spectrum is taken from ``area`` or ``height``, depending on the
        area/height radio group.

        Returns:
            The success flag returned by the fitter widget.
        """
        if self.w_hoa.active == 0:
            spectrum = self.area
        else:
            spectrum = self.height

        success = self.p_fitter.fit(spectrum[:, pix])
        if not success:
            self.log.warning("Pixel {} couldn't be fit".format(pix))
        return success

    def fit_camera(self):
        """Fit every pixel's spectrum and show gain/brightness camera images.

        Pixels are masked when their fit fails, when the fitted value is NaN,
        or by ``self.dead.mask1d``. ``gain`` is only filled for the 'spe'
        fitter; for 'bright' it stays 0.

        Raises:
            ValueError: if the fitter type is neither 'spe' nor 'bright'.
        """
        gain = np.ma.zeros(self.n_pixels)
        # builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24
        gain.mask = np.zeros(gain.shape, dtype=bool)
        brightness = np.ma.zeros(self.n_pixels)
        brightness.mask = np.zeros(brightness.shape, dtype=bool)

        fitter = self.p_fitter.fitter.fitter_type
        if fitter == 'spe':
            coeff = 'lambda_'
        elif fitter == 'bright':
            coeff = 'mean'
        else:
            msg = "No case for fitter type: {}".format(fitter)
            self.log.error(msg)
            raise ValueError(msg)

        desc = "Fitting pixels"
        for pix in trange(self.n_pixels, desc=desc):
            if not self.fit_spectrum(pix):
                # A failed fit gives no valid value for either image. Without
                # masking, brightness would show a bogus 0 for this pixel.
                gain.mask[pix] = True
                brightness.mask[pix] = True
                continue
            if fitter == 'spe':
                gain[pix] = self.p_fitter.fitter.coeff['spe']
            brightness[pix] = self.p_fitter.fitter.coeff[coeff]

        gain = np.ma.masked_where(np.isnan(gain), gain)
        gain = self.dead.mask1d(gain)
        brightness = np.ma.masked_where(np.isnan(brightness), brightness)
        brightness = self.dead.mask1d(brightness)

        self.p_camera_fit_gain.image = gain
        self.p_camera_fit_brightness.image = brightness

    @property
    def event(self):
        """The event currently shown in the viewer."""
        return self._event

    @event.setter
    def event(self, val):
        """Calibrate ``val`` and refresh the area camera and stage viewer."""
        self._event = val

        self.r1.calibrate(val)
        self.dl0.reduce(val)
        self.dl1.calibrate(val)
        peak_area = val.dl1.tel[0].image

        self._event_index = val.count
        self._event_id = val.r0.event_id
        self.update_event_index_widget()

        stages = self.dl1.cleaner.stages
        pulse_window = stages['window'][0]
        int_window = val.dl1.tel[0].extracted_samples[0]

        self.p_camera_area.image = peak_area
        self.p_stage_viewer.update_stages(np.arange(self.n_samples), stages,
                                          pulse_window, int_window)

    @property
    def event_index(self):
        """Index of the current event in the file."""
        return self._event_index

    @event_index.setter
    def event_index(self, val):
        """Load the event at index ``val`` and make it the current event."""
        self._event_index = val
        self.event = self.reader.get_event(val, False)

    @property
    def active_pixel(self):
        """Pixel selected in the camera images."""
        return self._active_pixel

    @active_pixel.setter
    def active_pixel(self, val):
        """Select pixel ``val``: refit its spectrum and update all views."""
        if not self._active_pixel == val:
            self._active_pixel = val

            self.fit_spectrum(val)

            self.p_camera_area.active_pixel = val
            self.p_camera_fit_gain.active_pixel = val
            self.p_camera_fit_brightness.active_pixel = val
            self.p_stage_viewer.active_pixel = val

            self.p_fit_viewer.update(self.p_fitter.fitter)
            self.p_fit_table.update(self.p_fitter.fitter)

    def create_event_index_widget(self):
        self.w_event_index = TextInput(title="Event Index:", value='')

    def update_event_index_widget(self):
        # the widget only exists once start() has created it
        if self.w_event_index:
            self.w_event_index.value = str(self.event_index)

    def create_goto_event_index_widget(self):
        self.w_goto_event_index = Button(label="GOTO Index", width=100)
        self.w_goto_event_index.on_click(self.on_goto_event_index_widget_click)

    def on_goto_event_index_widget_click(self):
        self.event_index = int(self.w_event_index.value)

    def on_event_index_widget_change(self, attr, old, new):
        if self.event_index != int(self.w_event_index.value):
            self.event_index = int(self.w_event_index.value)

    def create_hoa_widget(self):
        # "hoa" = height-or-area selector used by fit_spectrum
        self.w_hoa = RadioGroup(labels=['area', 'height'], active=0)
        self.w_hoa.on_click(self.on_hoa_widget_select)

    def on_hoa_widget_select(self, active):
        self.fit_spectrum(self.active_pixel)
        self.p_fit_viewer.update(self.p_fitter.fitter)
        self.p_fit_table.update(self.p_fitter.fitter)

    def create_fitspectrum_widget(self):
        self.w_fitspectrum = Button(label='Fit Spectrum')
        self.w_fitspectrum.on_click(self.on_fitspectrum_widget_select)

    def on_fitspectrum_widget_select(self):
        t = time()
        self.fit_spectrum(self.active_pixel)
        self.log.info("Fit took {} seconds".format(time() - t))
        self.p_fit_viewer.update(self.p_fitter.fitter)
        self.p_fit_table.update(self.p_fitter.fitter)

    def create_fitcamera_widget(self):
        self.w_fitcamera = Button(label='Fit Camera')
        self.w_fitcamera.on_click(self.on_fitcamera_widget_select)

    def on_fitcamera_widget_select(self):
        self.fit_camera()
        # refit the active pixel so the fit viewer/table show it again
        self.fit_spectrum(self.active_pixel)
        self.p_fit_viewer.update(self.p_fitter.fitter)
        self.p_fit_table.update(self.p_fitter.fitter)
Exemplo n.º 5
0
class HistoryManager(HistoryAccessor):
    """A class to organize all history-related functionality in one place.

    Keeps the current session's input/output history in memory. It also
    buffers new entries and writes them to the SQLite history database,
    through a background saving thread when one is running.
    """
    # Public interface

    # An instance of the IPython shell we are attached to
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
    # Lists to hold processed and raw history. These start with a blank entry
    # so that we can index them starting from 1
    input_hist_parsed = List([""])
    input_hist_raw = List([""])
    # A list of directories visited during session
    dir_hist = List()

    @default('dir_hist')
    def _dir_hist_default(self):
        # the current directory may have been removed; start empty then
        try:
            return [os.getcwd()]
        except OSError:
            return []

    # A dict of output history, keyed with ints from the shell's
    # execution count.
    output_hist = Dict()
    # The text/plain repr of outputs.
    output_hist_reprs = Dict()

    # The number of the current session in the history database
    session_number = Integer()

    db_log_output = Bool(
        False,
        help="Should the history database include output? (default: no)").tag(
            config=True)
    db_cache_size = Integer(
        0,
        help=
        "Write to database every x commands (higher values save disk access & power).\n"
        "Values of 1 or less effectively disable caching.").tag(config=True)
    # The input and output caches (entries waiting to be written to the DB)
    db_input_cache = List()
    db_output_cache = List()

    # History saving in separate thread
    save_thread = Instance('IPython.core.history.HistorySavingThread',
                           allow_none=True)
    try:  # Event is a function returning an instance of _Event...
        save_flag = Instance(threading._Event, allow_none=True)
    except AttributeError:  # ...until Python 3.3, when it's a class.
        save_flag = Instance(threading.Event, allow_none=True)

    # Private interface
    # Variables used to store the three last inputs from the user.  On each new
    # history update, we populate the user's namespace with these, shifted as
    # necessary.
    _i00 = Unicode(u'')
    _i = Unicode(u'')
    _ii = Unicode(u'')
    _iii = Unicode(u'')

    # A regex matching all forms of the exit command, so that we don't store
    # them in the history (it's annoying to rewind the first entry and land on
    # an exit call).
    _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")

    def __init__(self, shell=None, config=None, **traits):
        """Create a new history manager associated with a shell instance.

        Opens a new session in the database. If that fails with an
        OperationalError, falls back to an in-memory database. The background
        saving thread is started only when history is enabled and file-backed.
        """
        # We need a pointer back to the shell for various tasks.
        super(HistoryManager, self).__init__(shell=shell,
                                             config=config,
                                             **traits)
        self.save_flag = threading.Event()
        self.db_input_cache_lock = threading.Lock()
        self.db_output_cache_lock = threading.Lock()

        try:
            self.new_session()
        except OperationalError:
            self.log.error(
                "Failed to create history session in %s. History will not be saved.",
                self.hist_file,
                exc_info=True)
            self.hist_file = ':memory:'

        if self.enabled and self.hist_file != ':memory:':
            self.save_thread = HistorySavingThread(self)
            self.save_thread.start()

    def _get_hist_file_name(self, profile=None):
        """Get default history file name based on the Shell's profile.

        The profile parameter is ignored, but must exist for compatibility with
        the parent class."""
        profile_dir = self.shell.profile_dir.location
        return os.path.join(profile_dir, 'history.sqlite')

    @needs_sqlite
    def new_session(self, conn=None):
        """Insert a new row into ``sessions`` and store its id in ``session_number``.

        Args:
            conn: database connection to use; defaults to ``self.db``.
        """
        if conn is None:
            conn = self.db

        with conn:
            # '' (single quotes) is the standard SQL string literal. "" is an
            # identifier in SQL and is rejected by SQLite builds with
            # double-quoted string literals disabled (SQLITE_DQS=0).
            cur = conn.execute(
                """INSERT INTO sessions VALUES (NULL, ?, NULL,
                            NULL, '') """, (datetime.datetime.now(), ))
            self.session_number = cur.lastrowid

    def end_session(self):
        """Close the database session, filling in the end time and line count."""
        self.writeout_cache()
        with self.db:
            self.db.execute(
                """UPDATE sessions SET end=?, num_cmds=? WHERE
                            session==?""",
                # -1: input_hist_parsed starts with a blank placeholder entry
                (datetime.datetime.now(), len(self.input_hist_parsed) - 1,
                 self.session_number))
        self.session_number = 0

    def name_session(self, name):
        """Give the current session a name in the history database."""
        with self.db:
            self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
                            (name, self.session_number))

    def reset(self, new_session=True):
        """Clear the session history, releasing all object references, and
        optionally open a new session."""
        self.output_hist.clear()
        # The directory history can't be completely empty
        self.dir_hist[:] = [os.getcwd()]

        if new_session:
            if self.session_number:
                self.end_session()
            self.input_hist_parsed[:] = [""]
            self.input_hist_raw[:] = [""]
            self.new_session()

    # ------------------------------
    # Methods for retrieving history
    # ------------------------------
    def get_session_info(self, session=0):
        """Get info about a session.

        Parameters
        ----------

        session : int
            Session number to retrieve. The current session is 0, and negative
            numbers count back from current session, so -1 is the previous session.

        Returns
        -------

        session_id : int
           Session ID number
        start : datetime
           Timestamp for the start of the session.
        end : datetime
           Timestamp for the end of the session, or None if IPython crashed.
        num_cmds : int
           Number of commands run, or None if IPython crashed.
        remark : unicode
           A manually set description.
        """
        if session <= 0:
            session += self.session_number

        return super(HistoryManager, self).get_session_info(session=session)

    def _get_range_session(self, start=1, stop=None, raw=True, output=False):
        """Get input and output history from the current session. Called by
        get_range, and takes similar parameters."""
        input_hist = self.input_hist_raw if raw else self.input_hist_parsed

        n = len(input_hist)
        if start < 0:
            start += n
        if not stop or (stop > n):
            stop = n
        elif stop < 0:
            stop += n

        for i in range(start, stop):
            if output:
                line = (input_hist[i], self.output_hist_reprs.get(i))
            else:
                line = input_hist[i]
            yield (0, i, line)

    def get_range(self, session=0, start=1, stop=None, raw=True, output=False):
        """Retrieve input by session.

        Parameters
        ----------
        session : int
            Session number to retrieve. The current session is 0, and negative
            numbers count back from current session, so -1 is previous session.
        start : int
            First line to retrieve.
        stop : int
            End of line range (excluded from output itself). If None, retrieve
            to the end of the session.
        raw : bool
            If True, return untranslated input
        output : bool
            If True, attempt to include output. This will be 'real' Python
            objects for the current session, or text reprs from previous
            sessions if db_log_output was enabled at the time. Where no output
            is found, None is used.

        Returns
        -------
        entries
          An iterator over the desired lines. Each line is a 3-tuple, either
          (session, line, input) if output is False, or
          (session, line, (input, output)) if output is True.
        """
        if session <= 0:
            session += self.session_number
        if session == self.session_number:  # Current session
            return self._get_range_session(start, stop, raw, output)
        return super(HistoryManager, self).get_range(session, start, stop, raw,
                                                     output)

    ## ----------------------------
    ## Methods for storing history:
    ## ----------------------------
    def store_inputs(self, line_num, source, source_raw=None):
        """Store source and raw input in history and create input cache
        variables ``_i*``.

        Parameters
        ----------
        line_num : int
          The prompt number of this input.

        source : str
          Python input.

        source_raw : str, optional
          If given, this is the raw input without any IPython transformations
          applied to it.  If not given, ``source`` is used.
        """
        if source_raw is None:
            source_raw = source
        source = source.rstrip('\n')
        source_raw = source_raw.rstrip('\n')

        # do not store exit/quit commands
        if self._exit_re.match(source_raw.strip()):
            return

        self.input_hist_parsed.append(source)
        self.input_hist_raw.append(source_raw)

        with self.db_input_cache_lock:
            self.db_input_cache.append((line_num, source, source_raw))
            # Trigger to flush cache and write to DB.
            if len(self.db_input_cache) >= self.db_cache_size:
                self.save_flag.set()

        # update the auto _i variables
        self._iii = self._ii
        self._ii = self._i
        self._i = self._i00
        self._i00 = source_raw

        # hackish access to user namespace to create _i1,_i2... dynamically
        new_i = '_i%s' % line_num
        to_main = {
            '_i': self._i,
            '_ii': self._ii,
            '_iii': self._iii,
            new_i: self._i00
        }

        if self.shell is not None:
            self.shell.push(to_main, interactive=False)

    def store_output(self, line_num):
        """If database output logging is enabled, this saves all the
        outputs from the indicated prompt number to the database. It's
        called by run_cell after code has been executed.

        Parameters
        ----------
        line_num : int
          The line number from which to save outputs
        """
        if (not self.db_log_output) or (line_num
                                        not in self.output_hist_reprs):
            return
        output = self.output_hist_reprs[line_num]

        with self.db_output_cache_lock:
            self.db_output_cache.append((line_num, output))
        if self.db_cache_size <= 1:
            self.save_flag.set()

    def _writeout_input_cache(self, conn):
        # one transaction for all buffered inputs of this session
        with conn:
            for line in self.db_input_cache:
                conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
                             (self.session_number, ) + line)

    def _writeout_output_cache(self, conn):
        # one transaction for all buffered outputs of this session
        with conn:
            for line in self.db_output_cache:
                conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
                             (self.session_number, ) + line)

    @needs_sqlite
    def writeout_cache(self, conn=None):
        """Write any entries in the cache to the database.

        Both caches are emptied afterwards, whether or not the write
        succeeded. On a uniqueness conflict, the inputs are retried once in a
        fresh session; conflicting outputs are dropped.
        """
        if conn is None:
            conn = self.db

        with self.db_input_cache_lock:
            try:
                self._writeout_input_cache(conn)
            except sqlite3.IntegrityError:
                self.new_session(conn)
                print("ERROR! Session/line number was not unique in",
                      "database. History logging moved to new session",
                      self.session_number)
                try:
                    # Try writing to the new session. If this fails, don't
                    # recurse
                    self._writeout_input_cache(conn)
                except sqlite3.IntegrityError:
                    pass
            finally:
                self.db_input_cache = []

        with self.db_output_cache_lock:
            try:
                self._writeout_output_cache(conn)
            except sqlite3.IntegrityError:
                print("!! Session/line number for output was not unique",
                      "in database. Output will not be stored.")
            finally:
                self.db_output_cache = []
Exemplo n.º 6
0
class HistoryAccessor(HistoryAccessorBase):
    """Access the history database without adding to it.

    This is intended for use by standalone history tools. IPython shells use
    HistoryManager, below, which is a subclass of this."""

    # Counter for init_db retries, so we don't keep trying over and over.
    # NOTE(review): presumably incremented by ``catch_corrupt_db`` (defined
    # elsewhere); ``init_db`` resets it to 0 on success.
    _corrupt_db_counter = 0
    # After this many failures, fall back on ``:memory:`` (assumed to be
    # enforced by ``catch_corrupt_db`` — confirm there).
    _corrupt_db_limit = 2

    # String holding the path to the history file
    hist_file = Unicode(
        help="""Path to file to use for SQLite history database.

        By default, IPython will put the history database in the IPython
        profile directory.  If you would rather share one history among
        profiles, you can set this value in each, so that they are consistent.

        Due to an issue with fcntl, SQLite is known to misbehave on some NFS
        mounts.  If you see IPython hanging, try setting this to something on a
        local disk, e.g::

            ipython --HistoryManager.hist_file=/tmp/ipython_hist.sqlite

        You can also use the special value `:memory:` (including the colons
        at both ends but not the backticks) to avoid creating a history file.

        """).tag(config=True)

    enabled = Bool(True,
                   help="""enable the SQLite history
        
        set enabled=False to disable the SQLite history,
        in which case there will be no stored history, no SQLite connection,
        and no background saving thread.  This may be necessary in some
        threaded environments where IPython is embedded.
        """).tag(config=True)

    connection_options = Dict(
        help="""Options for configuring the SQLite connection
        
        These options are passed as keyword args to sqlite3.connect
        when establishing database connections.
        """).tag(config=True)

    # The SQLite database: a sqlite3.Connection, or a DummyDB when history is
    # disabled (see init_db and _db_changed).
    db = Any()

    @observe('db')
    def _db_changed(self, change):
        """Validate the db, since it can be an instance of two different types.

        Raises
        ------
        TraitError
            If the new value is neither a ``DummyDB`` nor (when sqlite3 is
            available) a ``sqlite3.Connection``.
        """
        new = change['new']
        connection_types = (DummyDB, )
        if sqlite3 is not None:
            connection_types = (DummyDB, sqlite3.Connection)
        if not isinstance(new, connection_types):
            msg = "%s.db must be sqlite3 Connection or DummyDB, not %r" % \
                    (self.__class__.__name__, new)
            raise TraitError(msg)

    def __init__(self, profile='default', hist_file=u'', **traits):
        """Create a new history accessor.

        Parameters
        ----------
        profile : str
          The name of the profile from which to open history.
        hist_file : str
          Path to an SQLite history database stored by IPython. If specified,
          hist_file overrides profile.
        config : :class:`~traitlets.config.loader.Config`
          Config object (passed through ``**traits``). hist_file can also be
          set through this.
        """
        super(HistoryAccessor, self).__init__(**traits)
        # defer setting hist_file from kwarg until after init,
        # otherwise the default kwarg value would clobber any value
        # set by config
        if hist_file:
            self.hist_file = hist_file

        if self.hist_file == u'':
            # No one has set the hist_file, yet.
            self.hist_file = self._get_hist_file_name(profile)

        # Without the sqlite3 module, history is forcibly disabled so that
        # init_db installs a DummyDB instead of connecting.
        if sqlite3 is None and self.enabled:
            warn(
                "IPython History requires SQLite, your history will not be saved"
            )
            self.enabled = False

        self.init_db()

    def _get_hist_file_name(self, profile='default'):
        """Find the history file for the given profile name.

        This is overridden by the HistoryManager subclass, to use the shell's
        active profile.

        Parameters
        ----------
        profile : str
          The name of a profile which has a history file.

        Returns
        -------
        str
          ``history.sqlite`` inside the profile directory.
        """
        return os.path.join(locate_profile(profile), 'history.sqlite')

    @catch_corrupt_db
    def init_db(self):
        """Connect to the database, and create tables if necessary.

        When history is disabled, a ``DummyDB`` is installed instead and no
        connection is opened.
        """
        if not self.enabled:
            self.db = DummyDB()
            return

        # use detect_types so that timestamps return datetime objects
        kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES
                      | sqlite3.PARSE_COLNAMES)
        # User-supplied options override the defaults above.
        kwargs.update(self.connection_options)
        self.db = sqlite3.connect(self.hist_file, **kwargs)
        self.db.execute("""CREATE TABLE IF NOT EXISTS sessions (session integer
                        primary key autoincrement, start timestamp,
                        end timestamp, num_cmds integer, remark text)""")
        self.db.execute("""CREATE TABLE IF NOT EXISTS history
                (session integer, line integer, source text, source_raw text,
                PRIMARY KEY (session, line))""")
        # Output history is optional, but ensure the table's there so it can be
        # enabled later.
        self.db.execute("""CREATE TABLE IF NOT EXISTS output_history
                        (session integer, line integer, output text,
                        PRIMARY KEY (session, line))""")
        self.db.commit()
        # success! reset corrupt db count
        self._corrupt_db_counter = 0

    def writeout_cache(self, conn=None):
        """Flush pending cache entries to the database before a lookup.

        This is a no-op here; HistoryManager overrides it to dump its caches.
        ``conn`` is accepted (and ignored) so the signature matches that
        override, ``writeout_cache(self, conn=None)``.
        """
        pass

    ## -------------------------------
    ## Methods for retrieving history:
    ## -------------------------------
    def _run_sql(self, sql, params, raw=True, output=False):
        """Prepares and runs an SQL query for the history database.

        Parameters
        ----------
        sql : str
          Any filtering expressions to go after SELECT ... FROM ...
        params : tuple
          Parameters passed to the SQL query (to replace "?")
        raw, output : bool
          See :meth:`get_range`

        Returns
        -------
        Tuples as :meth:`get_range`
        """
        toget = 'source_raw' if raw else 'source'
        sqlfrom = "history"
        if output:
            sqlfrom = "history LEFT JOIN output_history USING (session, line)"
            toget = "history.%s, output_history.output" % toget
        # Only column/table names are interpolated here; user values go
        # through ``params`` as bound "?" parameters.
        cur = self.db.execute("SELECT session, line, %s FROM %s " %\
                                (toget, sqlfrom) + sql, params)
        if output:  # Regroup each row into (session, line, (input, output))
            return ((ses, lin, (inp, out)) for ses, lin, inp, out in cur)
        return cur

    @needs_sqlite
    @catch_corrupt_db
    def get_session_info(self, session):
        """Get info about a session.

        Parameters
        ----------
        session : int
            Session number to retrieve.

        Returns
        -------
        session_id : int
           Session ID number
        start : datetime
           Timestamp for the start of the session.
        end : datetime
           Timestamp for the end of the session, or None if IPython crashed.
        num_cmds : int
           Number of commands run, or None if IPython crashed.
        remark : unicode
           A manually set description.

        The values are returned as one row tuple, or None if no session has
        that number.
        """
        query = "SELECT * from sessions where session == ?"
        return self.db.execute(query, (session, )).fetchone()

    @catch_corrupt_db
    def get_last_session_id(self):
        """Get the last session ID currently in the database.

        Within IPython, this should be the same as the value stored in
        :attr:`HistoryManager.session_number`. The ID is taken from the most
        recent row of the ``history`` table; returns None when that table is
        empty.
        """
        for record in self.get_tail(n=1, include_latest=True):
            return record[0]

    @catch_corrupt_db
    def get_tail(self, n=10, raw=True, output=False, include_latest=False):
        """Get the last n lines from the history database.

        Parameters
        ----------
        n : int
          The number of lines to get
        raw, output : bool
          See :meth:`get_range`
        include_latest : bool
          If False (default), n+1 lines are fetched, and the latest one
          is discarded. This is intended to be used where the function
          is called by a user command, which it should not return.

        Returns
        -------
        Tuples as :meth:`get_range`, oldest first.
        """
        self.writeout_cache()
        if not include_latest:
            n += 1
        cur = self._run_sql("ORDER BY session DESC, line DESC LIMIT ?", (n, ),
                            raw=raw,
                            output=output)
        # Rows come back newest first; drop the newest if requested, then
        # reverse into chronological order.
        if not include_latest:
            return reversed(list(cur)[1:])
        return reversed(list(cur))

    @catch_corrupt_db
    def search(self,
               pattern="*",
               raw=True,
               search_raw=True,
               output=False,
               n=None,
               unique=False):
        """Search the database using unix glob-style matching (wildcards
        * and ?).

        Parameters
        ----------
        pattern : str
          The wildcarded pattern to match when searching
        search_raw : bool
          If True, search the raw input, otherwise, the parsed input
        raw, output : bool
          See :meth:`get_range`
        n : None or int
          If an integer is given, it defines the limit of
          returned entries.
        unique : bool
          When it is true, return only unique entries.

        Returns
        -------
        Tuples as :meth:`get_range`
        """
        tosearch = "source_raw" if search_raw else "source"
        if output:
            # Qualify the column: the JOIN in _run_sql makes it ambiguous.
            tosearch = "history." + tosearch
        self.writeout_cache()
        sqlform = "WHERE %s GLOB ?" % tosearch
        params = (pattern, )
        if unique:
            sqlform += ' GROUP BY {0}'.format(tosearch)
        if n is not None:
            # Take the n most recent matches, returned oldest first below.
            sqlform += " ORDER BY session DESC, line DESC LIMIT ?"
            params += (n, )
        elif unique:
            sqlform += " ORDER BY session, line"
        cur = self._run_sql(sqlform, params, raw=raw, output=output)
        if n is not None:
            return reversed(list(cur))
        return cur

    @catch_corrupt_db
    def get_range(self, session, start=1, stop=None, raw=True, output=False):
        """Retrieve input by session.

        Parameters
        ----------
        session : int
            Session number to retrieve.
        start : int
            First line to retrieve.
        stop : int
            End of line range (excluded from output itself). If None (or any
            falsy value), retrieve to the end of the session.
        raw : bool
            If True, return untranslated input
        output : bool
            If True, attempt to include output. This will be 'real' Python
            objects for the current session, or text reprs from previous
            sessions if db_log_output was enabled at the time. Where no output
            is found, None is used.

        Returns
        -------
        entries
          An iterator over the desired lines. Each line is a 3-tuple, either
          (session, line, input) if output is False, or
          (session, line, (input, output)) if output is True.
        """
        if stop:
            lineclause = "line >= ? AND line < ?"
            params = (session, start, stop)
        else:
            lineclause = "line>=?"
            params = (session, start)

        return self._run_sql("WHERE session==? AND %s" % lineclause,
                             params,
                             raw=raw,
                             output=output)

    def get_range_by_str(self, rangestr, raw=True, output=False):
        """Get lines of history from a string of ranges, as used by magic
        commands %hist, %save, %macro, etc.

        Parameters
        ----------
        rangestr : str
          A string specifying ranges, e.g. "5 ~2/1-4". See
          :func:`magic_history` for full details.
        raw, output : bool
          As :meth:`get_range`

        Returns
        -------
        Tuples as :meth:`get_range`
        """
        for sess, s, e in extract_hist_ranges(rangestr):
            for line in self.get_range(sess, s, e, raw=raw, output=output):
                yield line
class PerspectiveTraitlets(HasTraits):
    '''Define the traitlet interface with `PerspectiveJupyterWidget` on the
    front end. Attributes which are set here are synchronized between the
    front-end and back-end.

    Most synced traits have a ``@validate`` handler that delegates to the
    matching ``validate_*`` helper (defined elsewhere); ``dark``, ``editable``
    and ``client`` are not validated here.

    Examples:
        >>> widget = perspective.PerspectiveWidget(
        ...     data, row_pivots=["a", "b", "c"])
        PerspectiveWidget(row_pivots=["a", "b", "c"])
        >>> widget.column_pivots=["b"]
        >>> widget
        PerspectiveWidget(row_pivots=["a", "b", "c"], column_pivots=["b"])
    '''

    # `perspective-viewer` options
    plugin = Unicode('hypergrid').tag(sync=True)
    columns = List(default_value=[]).tag(sync=True)
    # NOTE(review): the extra ``o=True`` tag looks like a stray; nothing in
    # this class reads it. Kept to avoid changing trait metadata.
    row_pivots = List(trait=Unicode(), default_value=[]).tag(sync=True, o=True)
    column_pivots = List(trait=Unicode(), default_value=[]).tag(sync=True)
    aggregates = Dict(default_value={}).tag(sync=True)
    sort = List(default_value=[]).tag(sync=True)
    filters = List(default_value=[]).tag(sync=True)
    computed_columns = List(default_value=[]).tag(sync=True)
    plugin_config = Dict(default_value={}).tag(sync=True)
    dark = Bool(None, allow_none=True).tag(sync=True)
    editable = Bool(False).tag(sync=True)
    client = Bool(False).tag(sync=True)

    @validate('plugin')
    def _validate_plugin(self, proposal):
        return validate_plugin(proposal.value)

    @validate('columns')
    def _validate_columns(self, proposal):
        return validate_columns(proposal.value)

    @validate('row_pivots')
    def _validate_row_pivots(self, proposal):
        return validate_row_pivots(proposal.value)

    @validate('column_pivots')
    def _validate_column_pivots(self, proposal):
        return validate_column_pivots(proposal.value)

    @validate('aggregates')
    def _validate_aggregates(self, proposal):
        return validate_aggregates(proposal.value)

    @validate('sort')
    def _validate_sort(self, proposal):
        return validate_sort(proposal.value)

    @validate('filters')
    def _validate_filters(self, proposal):
        return validate_filters(proposal.value)

    # Must target 'computed_columns': registering this under 'filters' gave
    # that trait two competing validators and left computed_columns
    # unvalidated.
    @validate('computed_columns')
    def _validate_computed_columns(self, proposal):
        return validate_computed_columns(proposal.value)

    @validate('plugin_config')
    def _validate_plugin_config(self, proposal):
        return validate_plugin_config(proposal.value)