Exemple #1
0
class TradingEngine(Application):
    '''A configurable trading application.

    Builds exchanges and strategies from a nested config mapping, keeps a
    per-event-type list of callbacks, and pushes every exchange event (plus
    any internally queued events) through those callbacks.
    '''
    name = 'AAT'
    description = 'async algorithmic trading engine'

    # Configurable parameters
    verbose = Bool(default_value=True)
    api = Bool(default_value=False)
    port = Unicode(default_value='8080',
                   help="Port to run on").tag(config=True)
    event_loop = Instance(klass=asyncio.events.AbstractEventLoop)
    executor = Instance(klass=ThreadPoolExecutor, args=(4, ), kwargs={})

    # Core components
    trading_type = Instance(klass=TradingType,
                            default_value=TradingType.SIMULATION)
    order_manager = Instance(OrderManager, args=(), kwargs={})
    risk_manager = Instance(RiskManager, args=(), kwargs={})
    portfolio_manager = Instance(PortfolioManager, args=(), kwargs={})
    exchanges = List(trait=Instance(klass=Exchange))
    event_handlers = List(trait=Instance(EventHandler), default_value=[])

    # API application
    api_application = Instance(klass=TornadoApplication)
    api_handlers = List(default_value=[])

    table_manager = Instance(klass=PerspectiveManager or object,
                             args=(),
                             kwargs={})  # failover to object

    aliases = {
        'port': 'AAT.port',
        'trading_type': 'AAT.trading_type',
    }

    @validate('trading_type')
    def _validate_trading_type(self, proposal):
        '''Reject any value that is not one of the four known trading types.

        Raises:
            TraitError: if the proposed value is not LIVE, SIMULATION,
                SANDBOX or BACKTEST
        '''
        if proposal['value'] not in (TradingType.LIVE, TradingType.SIMULATION,
                                     TradingType.SANDBOX,
                                     TradingType.BACKTEST):
            raise TraitError(f'Invalid trading type: {proposal["value"]}')
        return proposal['value']

    @validate('exchanges')
    def _validate_exchanges(self, proposal):
        '''Ensure every proposed exchange is an `Exchange` instance.

        Raises:
            TraitError: on the first element that is not an `Exchange`
        '''
        for exch in proposal['value']:
            if not isinstance(exch, Exchange):
                raise TraitError(f'Invalid exchange type: {exch}')
        return proposal['value']

    def __init__(self, **config):
        '''Configure the engine from a nested config mapping.

        Args:
            **config: optional sections `general` (port, verbose, api,
                trading_type), `exchange` (exchanges) and `strategy`
                (strategies); missing sections/keys fall back to defaults
        '''
        general = config.get('general', {})

        # get port for API access
        self.port = general.get('port', self.port)

        # run in verbose mode (print all events)
        self.verbose = bool(int(general.get('verbose', self.verbose)))

        # enable API access?
        self.api = bool(int(general.get('api', self.api)))

        # Trading type
        self.trading_type = TradingType(
            general.get('trading_type', 'simulation').upper())

        # Load exchange instances
        self.exchanges = getExchanges(config.get('exchange',
                                                 {}).get('exchanges', []),
                                      trading_type=self.trading_type,
                                      verbose=self.verbose)

        # instantiate the Strategy Manager
        self.manager = StrategyManager(self, self.trading_type, self.exchanges)

        # set event loop to use uvloop
        if uvloop:
            asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

        # install event loop
        self.event_loop = asyncio.get_event_loop()

        # setup subscriptions
        self._handler_subscriptions = {
            m: []
            for m in EventType.__members__.values()
        }

        # create the secondary event queues up front so that `pushEvent` /
        # `pushTargetedEvent` are usable before `run` has been started
        self._queued_events = deque()
        self._queued_targeted_events = deque()

        # setup `now` handler for backtest
        self._latest = datetime.fromtimestamp(0)

        # register internal management event handler before all strategy handlers
        self.registerHandler(self.manager)

        # install event handlers
        strategies = getStrategies(
            config.get('strategy', {}).get('strategies', []))
        for strategy in strategies:
            self.log.critical("Installing strategy: {}".format(strategy))
            self.registerHandler(strategy)

        # warn if no event handlers installed
        if not self.event_handlers:
            self.log.critical('Warning! No event handlers set')

        # install print handler if verbose
        if self.verbose:
            self.log.critical('Installing print handler')
            self.registerHandler(PrintHandler())

        # install webserver
        if self.api:
            self._install_api()

    def _install_api(self):
        '''Register the UI/API routes and start listening on `self.port`.'''
        self.log.critical('Installing API handlers')

        if PerspectiveManager is not None:
            table_handler = TableHandler()
            table_handler.installTables(self.table_manager)
            self.registerHandler(table_handler)

        # all static assets live under <package>/../../ui/assets/static
        static_root = os.path.join(os.path.dirname(__file__), '..', '..',
                                   'ui', 'assets', 'static')

        self.api_handlers.append((r"/", RedirectHandler, {
            "url": "/index.html"
        }))
        self.api_handlers.append(
            (r"/api/v1/ws", PerspectiveTornadoHandler, {
                "manager": self.table_manager,
                "check_origin": True
            }))
        # the catch-all html route must stay last so it does not mask the others
        for route, folder in ((r"/static/js/(.*)", 'js'),
                              (r"/static/css/(.*)", 'css'),
                              (r"/static/fonts/(.*)", 'fonts'),
                              (r"/(.*)", 'html')):
            self.api_handlers.append((route, StaticFileHandler, {
                "path": os.path.join(static_root, folder)
            }))

        self.api_application = ServerApplication(
            handlers=self.api_handlers)
        self.log.critical('.......')
        self.log.critical(f'listening on 0.0.0.0:{self.port}')
        self.log.critical('.......')
        self.api_application.listen(self.port)

    def registerHandler(self, handler):
        '''register a handler and all callbacks that handler implements

        Args:
            handler (EventHandler): the event handler to register
        Returns:
            value (EventHandler or None): event handler if its new, else None
        '''
        if handler in self.event_handlers:
            return None

        # append to handler list
        self.event_handlers.append(handler)

        # register callbacks for event types
        for event_type in EventType.__members__.values():
            # get callback or callback tuple
            # could be none if not implemented
            cbs = handler.callback(event_type)

            # if not a tuple, make for consistency
            if not isinstance(cbs, tuple):
                cbs = (cbs, )

            for cb in cbs:
                if cb:
                    self.registerCallback(event_type, cb, handler)
        handler._setManager(self.manager)
        return handler

    def _make_async(self, function):
        '''Wrap a synchronous callback so it runs in `self.executor`.

        The returned coroutine function remembers the wrapped callable so
        `registerCallback` can recognise repeat registrations.
        '''
        async def _wrapper(event):
            return await self.event_loop.run_in_executor(
                self.executor, function, event)

        _wrapper._aat_wrapped = function
        return _wrapper

    def registerCallback(self, event_type, callback, handler=None):
        '''register a callback for a given event type

        Args:
            event_type (EventType): event type enum value to register
            callback (function): function to call on events of `event_type`
            handler (EventHandler): class holding the callback (optional)
        Returns:
            value (bool): True if registered (new), else False
        '''
        subscriptions = self._handler_subscriptions[event_type]

        # synchronous callbacks are stored wrapped, so compare against the
        # original callable; otherwise a duplicate would never be detected
        for registered, owner in subscriptions:
            original = getattr(registered, '_aat_wrapped', registered)
            if (original, owner) == (callback, handler):
                return False

        if not asyncio.iscoroutinefunction(callback):
            callback = self._make_async(callback)
        subscriptions.append((callback, handler))
        return True

    def pushEvent(self, event):
        '''push non-exchange event into the queue'''
        self._queued_events.append(event)

    def pushTargetedEvent(self, strategy, event):
        '''push non-exchange event targeted to a specific strat into the queue'''
        self._queued_targeted_events.append((strategy, event))

    async def run(self):
        '''run the engine

        Connects every exchange, emits START, streams the merged exchange
        events through `tick` (draining queued events after each one), then
        emits EXIT once the stream is exhausted.
        '''
        # await all connections
        await asyncio.gather(*(asyncio.create_task(exch.connect())
                               for exch in self.exchanges))
        await asyncio.gather(*(asyncio.create_task(exch.instruments())
                               for exch in self.exchanges))

        # send start event to all callbacks
        await self.tick(Event(type=EventType.START, target=None))

        async with merge(*(
                exch.tick() for exch in self.exchanges
                if inspect.isasyncgenfunction(exch.tick))).stream() as stream:
            # stream through all events
            async for event in stream:
                # tick exchange event to handlers
                await self.tick(event)

                # TODO move out of critical path
                # remember the newest event timestamp; `now` reports it in backtests
                self._latest = event.target.timestamp if hasattr(
                    event, 'target') and hasattr(event.target,
                                                 'timestamp') else self._latest

                # process any secondary events
                while self._queued_events:
                    event = self._queued_events.popleft()
                    await self.tick(event)

                # process any secondary events targeted at a single strategy
                while self._queued_targeted_events:
                    strat, event = self._queued_targeted_events.popleft()
                    await self.tick(event, strat)

        await self.tick(Event(type=EventType.EXIT, target=None))

    async def tick(self, event, strategy=None):
        '''send an event to all registered event handlers

        Arguments:
            event (Event): event to send
            strategy (EventHandler): if given, deliver only to this handler
                and to the internal manager
        '''
        for callback, handler in self._handler_subscriptions[event.type]:
            # TODO make cleaner? move to somewhere not in critical path?
            if strategy is not None and (handler
                                         not in (strategy, self.manager)):
                continue

            # TODO make cleaner? move to somewhere not in critical path?
            if event.type in (EventType.TRADE, EventType.OPEN, EventType.CHANGE, EventType.CANCEL, EventType.DATA) and \
               not self.manager.dataSubscriptions(handler, event):
                continue

            try:
                await callback(event)
            except (KeyboardInterrupt, SystemExit, GeneratorExit,
                    asyncio.CancelledError):
                # shutdown and task cancellation must propagate untouched;
                # turning them into ERROR events would swallow the cancel
                raise
            except BaseException as e:
                if event.type == EventType.ERROR:
                    # don't infinite error
                    raise
                await self.tick(
                    Event(type=EventType.ERROR,
                          target=Error(target=event,
                                       handler=handler,
                                       callback=callback,
                                       exception=e)))
                await asyncio.sleep(1)

    def now(self):
        '''Return the current datetime. Useful to avoid code changes between
        live trading and backtesting. Defaults to `datetime.now`'''
        if self.trading_type == TradingType.BACKTEST:
            return self._latest
        return datetime.now()

    def start(self):
        '''Run the engine on `self.event_loop`, blocking until it finishes.

        A KeyboardInterrupt raised while running is swallowed.
        '''
        try:
            # if self.event_loop.is_running():
            #     # return future
            #     return asyncio.create_task(self.run())
            # block until done
            self.event_loop.run_until_complete(self.run())
        except KeyboardInterrupt:
            pass
        # send exit event to all callbacks
        # NOTE(review): the loop is no longer running at this point, so this
        # future only executes if the loop is run again - confirm intent
        asyncio.ensure_future(
            self.tick(Event(type=EventType.EXIT, target=None)))
Exemple #2
0
 def __init__(self, **kwargs):
     """Initialise the tuple trait with a ``(Widget instance, unicode)`` pair of element traits."""
     element_traits = (Instance(Widget), Unicode())
     super(WidgetTraitTuple, self).__init__(*element_traits, **kwargs)
Exemple #3
0
class YAPKernelApp(BaseIPythonApplication, InteractiveShellApp,
                   ConnectionFileMixin):
    name = 'yapkernel'
    aliases = Dict(kernel_aliases)
    flags = Dict(kernel_flags)
    classes = [InteractiveShell, ZMQInteractiveShell, ProfileDir, Session]
    # the kernel class, as an importstring
    kernel_class = Type('yapkernel.ipkernel.YAPKernel',
                        klass='yapkernel.kernelbase.Kernel',
                        help="""The Kernel subclass to be used.

    This should allow easy re-use of the YAPKernelApp entry point
    to configure and launch kernels other than YAP's own.
    """).tag(config=True)
    # the kernel object; assigned by init_kernel
    kernel = Any()
    # parent-process poller; assigned by init_poller when polling applies
    poller = Any(
    )  # don't restrict this even though current pollers are all Threads
    # Heartbeat instance; created and started by init_heartbeat
    heartbeat = Instance(Heartbeat, allow_none=True)

    # zmq context, sockets and helper threads; all assigned by
    # init_sockets / init_control / init_iopub and torn down in close
    context = Any()
    shell_socket = Any()
    control_socket = Any()
    debugpy_socket = Any()
    debug_shell_socket = Any()
    stdin_socket = Any()
    iopub_socket = Any()
    iopub_thread = Any()
    control_thread = Any()

    # channel name -> port; filled by log_connection_info, read by init_kernel
    _ports = Dict()

    subcommands = {
        'install': ('yapkernel.kernelspec.InstallYAPKernelSpecApp',
                    'Install the YAP kernel'),
    }

    # connection info:
    # directory used to resolve a bare connection-file name
    # (see abs_connection_file); default comes from _default_connection_dir
    connection_dir = Unicode()

    @default('connection_dir')
    def _default_connection_dir(self):
        """Default for ``connection_dir``: whatever ``jupyter_runtime_dir()`` returns."""
        runtime_dir = jupyter_runtime_dir()
        return runtime_dir

    @property
    def abs_connection_file(self):
        if os.path.basename(self.connection_file) == self.connection_file:
            return os.path.join(self.connection_dir, self.connection_file)
        else:
            return self.connection_file

    # streams, etc.
    # no_stdout / no_stderr are consumed by init_blackhole
    no_stdout = Bool(
        False, help="redirect stdout to the null device").tag(config=True)
    no_stderr = Bool(
        False, help="redirect stderr to the null device").tag(config=True)
    trio_loop = Bool(False, help="Set main event loop.").tag(config=True)
    # when quiet is set, init_io passes echo=None to the output streams
    quiet = Bool(
        True, help="Only send stdout/stderr to output stream").tag(config=True)
    # both factories are imported and instantiated in init_io
    outstream_class = DottedObjectName(
        'yapkernel.iostream.OutStream',
        help="The importstring for the OutStream factory").tag(config=True)
    displayhook_class = DottedObjectName(
        'yapkernel.displayhook.ZMQDisplayHook',
        help="The importstring for the DisplayHook factory").tag(config=True)

    # polling
    # defaults are read from the environment once, when the class body runs;
    # an unset or empty variable yields 0. Both values are used by init_poller
    parent_handle = Integer(
        int(os.environ.get('JPY_PARENT_PID') or 0),
        help="""kill this process if its parent dies.  On Windows, the argument
        specifies the HANDLE of the parent process, otherwise it is simply boolean.
        """).tag(config=True)
    interrupt = Integer(int(os.environ.get('JPY_INTERRUPT_EVENT') or 0),
                        help="""ONLY USED ON WINDOWS
        Interrupt this process when the parent is signaled.
        """).tag(config=True)

    def init_crash_handler(self):
        """Install ``self.excepthook`` as the interpreter-wide ``sys.excepthook``."""
        hook = self.excepthook
        sys.excepthook = hook

    def excepthook(self, etype, evalue, tb):
        """Print an uncaught exception's traceback to the original stderr.

        Args:
            etype: exception class
            evalue: exception instance
            tb: traceback object
        """
        # target the 'real' stderr, not the zmq-forwarder
        real_stderr = sys.__stderr__
        traceback.print_exception(etype, evalue, tb, file=real_stderr)

    def init_poller(self):
        """Create the parent-process poller when one is applicable.

        On win32 a ``ParentPollerWindows`` is created if either ``interrupt``
        or ``parent_handle`` is set; elsewhere a ``ParentPollerUnix`` is
        created for any non-zero ``parent_handle`` other than 1. In every
        other case ``self.poller`` is left untouched.
        """
        if sys.platform != 'win32':
            # PID 1 (init) is special and will never go away,
            # only be reassigned.
            # Parent polling doesn't work if ppid == 1 to start with.
            if self.parent_handle and self.parent_handle != 1:
                self.poller = ParentPollerUnix()
            return

        if self.interrupt or self.parent_handle:
            self.poller = ParentPollerWindows(self.interrupt,
                                              self.parent_handle)

    def _try_bind_socket(self, s, port):
        iface = '%s://%s' % (self.transport, self.ip)
        if self.transport == 'tcp':
            if port <= 0:
                port = s.bind_to_random_port(iface)
            else:
                s.bind("tcp://%s:%i" % (self.ip, port))
        elif self.transport == 'ipc':
            if port <= 0:
                port = 1
                path = "%s-%i" % (self.ip, port)
                while os.path.exists(path):
                    port = port + 1
                    path = "%s-%i" % (self.ip, port)
            else:
                path = "%s-%i" % (self.ip, port)
            s.bind("ipc://%s" % path)
        return port

    def _bind_socket(self, s, port):
        try:
            win_in_use = errno.WSAEADDRINUSE
        except AttributeError:
            win_in_use = None

        # Try up to 100 times to bind a port when in conflict to avoid
        # infinite attempts in bad setups
        max_attempts = 1 if port else 100
        for attempt in range(max_attempts):
            try:
                return self._try_bind_socket(s, port)
            except zmq.ZMQError as ze:
                # Raise if we have any error not related to socket binding
                if ze.errno != errno.EADDRINUSE and ze.errno != win_in_use:
                    raise
                if attempt == max_attempts - 1:
                    raise

    def write_connection_file(self):
        """write connection info to JSON file

        Delegates to the module-level ``write_connection_file`` with this
        app's ip, session key, transport and channel ports.
        """
        cf = self.abs_connection_file
        self.log.debug("Writing connection file: %s", cf)
        connection_info = {
            'ip': self.ip,
            'key': self.session.key,
            'transport': self.transport,
            'shell_port': self.shell_port,
            'stdin_port': self.stdin_port,
            'hb_port': self.hb_port,
            'iopub_port': self.iopub_port,
            'control_port': self.control_port,
        }
        write_connection_file(cf, **connection_info)

    def cleanup_connection_file(self):
        """Remove the connection file (best effort), then clean up ipc files."""
        target = self.abs_connection_file
        self.log.debug("Cleaning up connection file: %s", target)
        try:
            os.remove(target)
        except OSError:
            # IOError is an alias of OSError; a missing or unremovable file
            # is deliberately ignored
            pass

        self.cleanup_ipc_files()

    def init_connection_file(self):
        """Locate and load the connection file, or claim a new one.

        With no file configured, the name ``kernel-<pid>.json`` is used. If
        the file cannot be found, its directory is created and removal of
        the file is registered for interpreter exit. If it is found, it is
        loaded; a load failure is logged and the app exits with status 1.
        """
        if not self.connection_file:
            self.connection_file = "kernel-%s.json" % os.getpid()

        try:
            located = filefind(self.connection_file,
                               ['.', self.connection_dir])
            self.connection_file = located
        except IOError:
            self.log.debug("Connection file not found: %s",
                           self.connection_file)
            # This means I own it, and I'll create it in this directory:
            ensure_dir_exists(os.path.dirname(self.abs_connection_file), 0o700)
            # Also, I will clean it up:
            atexit.register(self.cleanup_connection_file)
        else:
            try:
                self.load_connection_file()
            except Exception:
                self.log.error("Failed to load connection file: %r",
                               self.connection_file,
                               exc_info=True)
                self.exit(1)

    def init_sockets(self):
        """Create the zmq context and bind the shell and stdin ROUTER sockets.

        Also registers ``close`` to run at interpreter exit and then sets up
        the control and iopub channels on the same context.
        """
        # Create a context, a session, and the kernel sockets.
        self.log.info("Starting the kernel at pid: %i", os.getpid())
        assert self.context is None, "init_sockets cannot be called twice!"
        context = zmq.Context()
        self.context = context
        atexit.register(self.close)

        self.shell_socket = shell = context.socket(zmq.ROUTER)
        shell.linger = 1000
        self.shell_port = self._bind_socket(shell, self.shell_port)
        self.log.debug("shell ROUTER Channel on port: %i" % self.shell_port)

        self.stdin_socket = stdin = context.socket(zmq.ROUTER)
        stdin.linger = 1000
        self.stdin_port = self._bind_socket(stdin, self.stdin_port)
        self.log.debug("stdin ROUTER Channel on port: %i" % self.stdin_port)

        if hasattr(zmq, 'ROUTER_HANDOVER'):
            # set router-handover to workaround zeromq reconnect problems
            # in certain rare circumstances
            # see ipython/yapkernel#270 and zeromq/libzmq#2892
            for router in (self.shell_socket, self.stdin_socket):
                router.router_handover = 1

        self.init_control(context)
        self.init_iopub(context)

    def init_control(self, context):
        """Create the control, debugpy and debug-shell sockets and the control thread.

        Args:
            context: zmq context used to create the sockets
        """
        self.control_socket = control = context.socket(zmq.ROUTER)
        control.linger = 1000
        self.control_port = self._bind_socket(control, self.control_port)
        self.log.debug("control ROUTER Channel on port: %i" %
                       self.control_port)

        self.debugpy_socket = debugpy = context.socket(zmq.STREAM)
        debugpy.linger = 1000

        self.debug_shell_socket = debug_shell = context.socket(zmq.DEALER)
        debug_shell.linger = 1000
        # connect the debug shell to the shell socket's endpoint, if it has one
        if self.shell_socket.getsockopt(zmq.LAST_ENDPOINT):
            endpoint = self.shell_socket.getsockopt(zmq.LAST_ENDPOINT)
            self.debug_shell_socket.connect(endpoint)

        if hasattr(zmq, 'ROUTER_HANDOVER'):
            # set router-handover to workaround zeromq reconnect problems
            # in certain rare circumstances
            # see ipython/yapkernel#270 and zeromq/libzmq#2892
            self.control_socket.router_handover = 1

        self.control_thread = ControlThread(daemon=True)

    def init_iopub(self, context):
        """Bind the iopub PUB socket and start the thread that services it.

        Args:
            context: zmq context used to create the socket
        """
        self.iopub_socket = pub = context.socket(zmq.PUB)
        pub.linger = 1000
        self.iopub_port = self._bind_socket(pub, self.iopub_port)
        self.log.debug("iopub PUB Channel on port: %i" % self.iopub_port)
        self.configure_tornado_logger()

        self.iopub_thread = IOPubThread(self.iopub_socket, pipe=True)
        self.iopub_thread.start()
        # backward-compat: wrap iopub socket API in background thread
        self.iopub_socket = self.iopub_thread.background_socket

    def init_heartbeat(self):
        """start the heart beating

        Creates the ``Heartbeat`` on its own zmq context, records the port
        it reports in ``hb_port`` and starts it.
        """
        # heartbeat doesn't share context, because it mustn't be blocked
        # by the GIL, which is accessed by libzmq when freeing zero-copy messages
        hb_ctx = zmq.Context()
        address = (self.transport, self.ip, self.hb_port)
        self.heartbeat = Heartbeat(hb_ctx, address)
        self.hb_port = self.heartbeat.port
        self.log.debug("Heartbeat REP Channel on port: %i" % self.hb_port)
        self.heartbeat.start()

    def close(self):
        """Close zmq sockets in an orderly fashion

        Restores the original IO first, then shuts down the heartbeat
        context, iopub thread, control thread and the remaining sockets
        before terminating the main zmq context.
        """
        # un-capture IO before we start closing channels
        self.reset_io()
        log = self.log
        log.info("Cleaning up sockets")

        if self.heartbeat:
            log.debug("Closing heartbeat channel")
            self.heartbeat.context.term()

        if self.iopub_thread:
            log.debug("Closing iopub channel")
            self.iopub_thread.stop()
            self.iopub_thread.close()

        if self.control_thread and self.control_thread.is_alive():
            log.debug("Closing control thread")
            self.control_thread.stop()
            self.control_thread.join()

        for debug_sock in (self.debugpy_socket, self.debug_shell_socket):
            if debug_sock and not debug_sock.closed:
                debug_sock.close()

        for channel in ('shell', 'control', 'stdin'):
            log.debug("Closing %s channel", channel)
            sock = getattr(self, channel + "_socket", None)
            if not sock or sock.closed:
                continue
            sock.close()

        log.debug("Terminating zmq context")
        self.context.term()
        log.debug("Terminated zmq context")

    def log_connection_info(self):
        """display connection info, and store ports

        The file is shown by basename when it is a bare name or lives in
        ``connection_dir``; otherwise the configured path is shown as is.
        """
        connection_file = self.connection_file
        basename = os.path.basename(connection_file)
        # use shortname when possible
        tail = basename if (
            basename == connection_file
            or os.path.dirname(connection_file) == self.connection_dir
        ) else connection_file

        lines = [
            "To connect another client to this kernel, use:",
            "    --existing %s" % tail,
        ]
        # log connection info
        # info-level, so often not shown.
        # frontends should use the %connect_info magic
        # to see the connection info
        for line in lines:
            self.log.info(line)

        # also raw print to the terminal if no parent_handle (`ipython kernel`)
        # unless log-level is CRITICAL (--quiet)
        if not self.parent_handle and self.log_level < logging.CRITICAL:
            for text in [_ctrl_c_message] + lines:
                print(text, file=sys.__stdout__)

        self._ports = {
            'shell': self.shell_port,
            'iopub': self.iopub_port,
            'stdin': self.stdin_port,
            'hb': self.hb_port,
            'control': self.control_port,
        }

    def init_blackhole(self):
        """redirects stdout/stderr to devnull if necessary

        Both the current stream and its ``sys.__std*__`` counterpart are
        replaced for each of ``no_stdout`` / ``no_stderr`` that is set.
        """
        if not (self.no_stdout or self.no_stderr):
            return

        # one shared handle on the null device, left open on purpose
        blackhole = open(os.devnull, 'w')
        if self.no_stdout:
            sys.stdout = blackhole
            sys.__stdout__ = blackhole
        if self.no_stderr:
            sys.stderr = blackhole
            sys.__stderr__ = blackhole

    def init_io(self):
        """Redirect input streams and set a display hook.

        When ``outstream_class`` is set, ``sys.stdout`` and ``sys.stderr``
        are replaced by streams built from that factory; when
        ``displayhook_class`` is set, ``sys.displayhook`` is replaced too.
        ``patch_io`` always runs last.
        """
        if self.outstream_class:
            outstream_factory = import_item(str(self.outstream_class))
            if sys.stdout is not None:
                sys.stdout.flush()

            # echo to the original streams only when not quiet
            if self.quiet:
                e_stdout = e_stderr = None
            else:
                e_stdout = sys.__stdout__
                e_stderr = sys.__stderr__

            sys.stdout = outstream_factory(self.session,
                                           self.iopub_thread,
                                           'stdout',
                                           echo=e_stdout)
            if sys.stderr is not None:
                sys.stderr.flush()
            sys.stderr = outstream_factory(self.session,
                                           self.iopub_thread,
                                           "stderr",
                                           echo=e_stderr)

            if hasattr(sys.stderr, "_original_stdstream_copy"):
                # point log handlers that write to fd 2 at the saved copy
                for handler in self.log.handlers:
                    if not isinstance(handler, StreamHandler):
                        continue
                    if handler.stream.buffer.fileno() != 2:
                        continue
                    self.log.debug(
                        "Seeing logger to stderr, rerouting to raw filedescriptor."
                    )

                    raw = FileIO(sys.stderr._original_stdstream_copy, "w")
                    handler.stream = TextIOWrapper(raw)

        if self.displayhook_class:
            displayhook_factory = import_item(str(self.displayhook_class))
            self.displayhook = displayhook_factory(self.session,
                                                   self.iopub_socket)
            sys.displayhook = self.displayhook

        self.patch_io()

    def reset_io(self):
        """restore original io

        restores state after init_io by pointing stdout, stderr and the
        display hook back at their ``sys.__*__`` originals
        """
        originals = (sys.__stdout__, sys.__stderr__, sys.__displayhook__)
        sys.stdout, sys.stderr, sys.displayhook = originals

    def patch_io(self):
        """Patch important libraries that can't handle sys.stdout forwarding

        Replaces ``faulthandler.enable`` (and ``faulthandler.register`` when
        present) with wrappers whose ``file`` default is ``sys.__stderr__``.
        Does nothing if ``faulthandler`` cannot be imported.
        """
        try:
            import faulthandler
        except ImportError:
            return

        # Warning: this is a monkeypatch of `faulthandler.enable`, watch for possible
        # updates to the upstream API and update accordingly (up-to-date as of Python 3.5):
        # https://docs.python.org/3/library/faulthandler.html#faulthandler.enable

        # change default file to __stderr__ from forwarded stderr
        original_enable = faulthandler.enable

        def enable(file=sys.__stderr__, all_threads=True, **kwargs):
            return original_enable(file=file,
                                   all_threads=all_threads,
                                   **kwargs)

        faulthandler.enable = enable

        if not hasattr(faulthandler, 'register'):
            return

        original_register = faulthandler.register

        def register(signum,
                     file=sys.__stderr__,
                     all_threads=True,
                     chain=False,
                     **kwargs):
            return original_register(signum,
                                     file=file,
                                     all_threads=all_threads,
                                     chain=chain,
                                     **kwargs)

        faulthandler.register = register

    def init_signal(self):
        """Make this process ignore SIGINT."""
        ignore = signal.SIG_IGN
        signal.signal(signal.SIGINT, ignore)

    def init_kernel(self):
        """Create the Kernel object itself

        Wraps the shell, control and debugpy sockets in ``ZMQStream``s (the
        latter two on the control thread's io loop), starts the control
        thread, instantiates the kernel via ``kernel_class.instance`` and
        hands it the ports recorded in ``_ports``.
        """
        shell_stream = ZMQStream(self.shell_socket)
        control_loop = self.control_thread.io_loop
        control_stream = ZMQStream(self.control_socket, control_loop)
        debugpy_stream = ZMQStream(self.debugpy_socket,
                                   self.control_thread.io_loop)
        self.control_thread.start()
        kernel_factory = self.kernel_class.instance

        kernel_kwargs = {
            'parent': self,
            'session': self.session,
            'control_stream': control_stream,
            'debugpy_stream': debugpy_stream,
            'debug_shell_socket': self.debug_shell_socket,
            'shell_stream': shell_stream,
            'control_thread': self.control_thread,
            'iopub_thread': self.iopub_thread,
            'iopub_socket': self.iopub_socket,
            'stdin_socket': self.stdin_socket,
            'log': self.log,
            'profile_dir': self.profile_dir,
            'user_ns': self.user_ns,
        }
        kernel = kernel_factory(**kernel_kwargs)

        # e.g. 'shell' -> 'shell_port'
        port_map = {}
        for name, port in self._ports.items():
            port_map[name + '_port'] = port
        kernel.record_ports(port_map)
        self.kernel = kernel

        # Allow the displayhook to get the execution count
        self.displayhook.get_execution_count = lambda: kernel.execution_count

    def init_gui_pylab(self):
        """Enable GUI event loop integration, taking pylab into account.

        Sets ``MPLBACKEND`` to the inline backend when the variable is unset
        or empty, then runs ``InteractiveShellApp.init_gui_pylab`` with the
        shell's ``_showtraceback`` temporarily replaced so that failures are
        printed to stderr. The original ``_showtraceback`` is always
        restored.
        """
        # Register inline backend as default
        # this is higher priority than matplotlibrc,
        # but lower priority than anything else (mpl.use() for instance).
        # This only affects matplotlib >= 1.5
        if not os.environ.get('MPLBACKEND'):
            os.environ[
                'MPLBACKEND'] = 'module://matplotlib_inline.backend_inline'

        # Provide a wrapper for :meth:`InteractiveShellApp.init_gui_pylab`
        # to ensure that any exception is printed straight to stderr.
        # Normally _showtraceback associates the reply with an execution,
        # which means frontends will never draw it, as this exception
        # is not associated with any execute request.

        shell = self.shell
        _showtraceback = shell._showtraceback
        try:
            # replace error-sending traceback with stderr
            def print_tb(etype, evalue, stb):
                print("GUI event loop or pylab initialization failed",
                      file=sys.stderr)
                print(shell.InteractiveTB.stb2text(stb), file=sys.stderr)

            shell._showtraceback = print_tb
            InteractiveShellApp.init_gui_pylab(self)
        finally:
            shell._showtraceback = _showtraceback

    def init_shell(self):
        """Adopt the kernel's shell, if it has one, and register this app on it.

        ``self.shell`` becomes ``self.kernel.shell`` or ``None`` when the
        kernel has no such attribute. A truthy shell gets this application
        appended to its ``configurables`` list.
        """
        self.shell = getattr(self.kernel, 'shell', None)
        shell = self.shell
        if not shell:
            return
        shell.configurables.append(self)

    def configure_tornado_logger(self):
        """Attach a basic stream handler to the ``tornado`` logger.

        If the tornado logger is left unconfigured, tornado calls
        ``basicConfig`` on the root logger, which sends root logging to the
        real ``sys.stderr`` rather than the capture streams. This method
        mirrors what ``logging.basicConfig`` would set up, but on the
        ``tornado`` logger only.
        """
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
        logging.getLogger('tornado').addHandler(stream_handler)

    def _init_asyncio_patch(self):
        """set default asyncio policy to be compatible with tornado

        Tornado 6 (at least) is not compatible with the default
        asyncio implementation on Windows

        Pick the older SelectorEventLoopPolicy on Windows
        if the known-incompatible default policy is in use.

        Support for Proactor via a background thread is available in tornado 6.1,
        but it is still preferable to run the Selector in the main thread
        instead of the background.

        do this as early as possible to make it a low priority and overrideable

        ref: https://github.com/tornadoweb/tornado/issues/2608

        FIXME: if/when tornado supports the defaults in asyncio without threads,
               remove and bump tornado requirement for py38.
               Most likely, this will mean a new Python version
               where asyncio.ProactorEventLoop supports add_reader and friends.

        """
        if sys.platform.startswith("win") and sys.version_info >= (3, 8):
            import asyncio
            try:
                from asyncio import (
                    WindowsProactorEventLoopPolicy,
                    WindowsSelectorEventLoopPolicy,
                )
            except ImportError:
                pass
                # not affected
            else:
                if type(asyncio.get_event_loop_policy()
                        ) is WindowsProactorEventLoopPolicy:
                    # WindowsProactorEventLoopPolicy is not compatible with tornado 6
                    # fallback to the pre-3.8 default of Selector
                    asyncio.set_event_loop_policy(
                        WindowsSelectorEventLoopPolicy())

    def init_pdb(self):
        """Swap the standard pdb for IPython's interruptible debugger.

        With the non-interruptible version, stopping pdb() locks up the
        kernel in a non-recoverable state. Nothing is changed when the
        installed IPython does not provide ``InterruptiblePdb``.
        """
        import pdb
        from IPython.core import debugger

        # Only available in newer IPython releases.
        if not hasattr(debugger, "InterruptiblePdb"):
            return

        debugger.Pdb = debugger.InterruptiblePdb
        pdb.Pdb = debugger.Pdb
        pdb.set_trace = debugger.set_trace

    @catch_config_error
    def initialize(self, argv=None):
        """Run the kernel application's start-up sequence.

        Applies the asyncio patch, runs the parent class ``initialize`` with
        ``argv``, and returns early when a sub-application was selected
        (``self.subapp is not None``). Otherwise it initializes, in order:
        pdb, blackhole, connection file, poller, sockets, heartbeat, the
        written connection file and its log entry, I/O, signal handling,
        the kernel, the path and the shell. When a shell is available the
        GUI/pylab integration, extensions and startup code are initialized
        as well. stdout and stderr are flushed last.

        Parameters
        ----------
        argv : list, optional
            Command-line arguments, passed unchanged to the parent
            ``initialize``.
        """
        self._init_asyncio_patch()
        super(YAPKernelApp, self).initialize(argv)
        if self.subapp is not None:
            return

        self.init_pdb()
        self.init_blackhole()
        self.init_connection_file()
        self.init_poller()
        self.init_sockets()
        self.init_heartbeat()
        # writing/displaying connection info must be *after* init_sockets/heartbeat
        self.write_connection_file()
        # Log connection info after writing connection file, so that the connection
        # file is definitely available at the time someone reads the log.
        self.log_connection_info()
        self.init_io()
        try:
            self.init_signal()
        except Exception:
            # Catch exception when initializing signal fails, eg when running the
            # kernel on a separate thread
            if self.log_level < logging.CRITICAL:
                self.log.error("Unable to initialize signal:", exc_info=True)
        self.init_kernel()
        # shell init steps
        self.init_path()
        self.init_shell()
        if self.shell:
            # Disabled monkey-patches, kept for reference:
            # InteractiveShell.run_cell_async = Jupyter4YAP.run_cell_async
            # InteractiveShell.split_cell = Jupyter4YAP.split_cell
            # InteractiveShell.prolog_call = Jupyter4YAP.prolog_call
            # InteractiveShell.prolog = Jupyter4YAP.prolog
            # InteractiveShell.syntaxErrors = Jupyter4YAP.syntaxErrors
            # InteractiveShell.YAPinit = Jupyter4YAP.init
            # Class-level patch: replaces showindentationerror on every
            # InteractiveShell with a function that just returns False.
            InteractiveShell.showindentationerror = lambda self: False
            self.init_gui_pylab()
            self.init_extensions()
            self.init_code()


        # flush stdout/stderr, so that anything written to these streams during
        # initialization do not get associated with the first execution request
        sys.stdout.flush()
        sys.stderr.flush()

    def start(self):
        """Install the YAP hooks on IPython and run the kernel event loop.

        The IPython classes are patched at class level first:
        ``run_cell_async``, ``complete`` and ``showindentationerror`` on
        ``InteractiveShell``, and ``transform_cell`` / ``check_complete`` on
        ``TransformerManager`` (the previous implementations are kept as
        ``old_tm`` and ``old_checc``).

        If a sub-application was selected, its ``start()`` result is
        returned. Otherwise the poller (if any) and the kernel are started
        and the current tornado IOLoop is run, through the trio runner when
        ``self.trio_loop`` is set. A ``KeyboardInterrupt`` raised by the
        loop is swallowed so that the method returns normally.
        """
        # InteractiveShell.prolog=Jupyter4YAP.prolog
        # InteractiveShell.prolog_call=Jupyter4YAP.prolog_call
        InteractiveShell.run_cell_async = Jupyter4YAP.run_cell_async
        TransformerManager.old_tm = TransformerManager.transform_cell
        TransformerManager.transform_cell = Jupyter4YAP.transform_cell
        TransformerManager.old_checc = TransformerManager.check_complete
        TransformerManager.check_complete = Jupyter4YAP.check_complete
        InteractiveShell.complete = Jupyter4YAP.complete
        InteractiveShell.showindentationerror = lambda self: False
        #self.yap = Jupyter4YAP(self)
        if self.subapp is not None:
            return self.subapp.start()
        if self.poller is not None:
            self.poller.start()
        self.kernel.start()
        self.io_loop = ioloop.IOLoop.current()
        if self.trio_loop:
            from yapkernel.trio_runner import TrioRunner
            tr = TrioRunner()
            tr.initialize(self.kernel, self.io_loop)
            try:
                tr.run()
            except KeyboardInterrupt:
                pass
        else:
            try:
                self.io_loop.start()
            # Was misspelled ``KeyboardIntberrupt``, which turned every
            # exception escaping the loop into a NameError.
            except KeyboardInterrupt:
                pass
Exemple #4
0
class Annotation(DOMWidget):
    """Widget tying a toolbar, a progress display and a canvas to a task list.

    The toolbar and progress objects are registered against this widget,
    and every change of one of the four component traits triggers a full
    refresh from the current task.
    """

    _view_name = Unicode('AnnotationView').tag(sync=True)
    _model_name = Unicode('AnnotationModel').tag(sync=True)

    toolbar = Instance(Toolbar).tag(sync=True, **widget_serialization)
    progress = Instance(Progress).tag(sync=True, **widget_serialization)
    canvas = Instance(Canvas).tag(sync=True, **widget_serialization)
    tasks = Instance(Tasks)

    def __init__(self, toolbar, tasks, progress=None, canvas=None):
        """Build the widget, creating a default canvas/progress if omitted."""
        canvas = OutputCanvas() if canvas is None else canvas
        progress = Progress() if progress is None else progress
        super().__init__(toolbar=toolbar,
                         progress=progress,
                         canvas=canvas,
                         tasks=tasks)
        self.toolbar.register(self)
        self.progress.register(self)
        self.on_msg(self.handle_message)
        # One change-observer per component trait, attached in a fixed order.
        watchers = (
            (self.update_toolbar, 'toolbar'),
            (self.update_progress, 'progress'),
            (self.update_canvas, 'canvas'),
            (self.update_tasks, 'tasks'),
        )
        for callback, trait_name in watchers:
            self.observe(callback, trait_name, type='change')
        self.update()

    def update_toolbar(self, data):
        """Register the replacement toolbar with this widget and refresh."""
        data['new'].register(self)
        self.update()

    def update_progress(self, data):
        """Register the replacement progress object and refresh."""
        data['new'].register(self)
        self.update()

    def update_canvas(self, data):
        """Refresh after the canvas trait changed."""
        self.update()

    def update_tasks(self, data):
        """Refresh after the tasks trait changed."""
        self.update()

    def handle_keypress(self, key):
        """Click the toolbar button whose shortcut is ``key``, if one exists."""
        matched = self.toolbar.find(shortcut=key)
        if matched:
            matched.click()

    def handle_message(self, _, content, buffers):
        """Dispatch a custom message; only ``'keypress'`` events are handled."""
        if content['event'] != 'keypress':
            return
        pressed = decode_key(content['code'])
        if pressed:
            self.handle_keypress(pressed)

    def update(self):
        """Push the current task into the toolbar, progress and canvas."""
        current = self.tasks.current
        self.toolbar.update(current.value)
        self.progress.update()
        self.canvas.render(current.output)

    def next(self):
        """Advance the task list and refresh."""
        self.tasks.next()
        self.update()

    def back(self):
        """Step the task list back and refresh."""
        self.tasks.back()
        self.update()
Exemple #5
0
class ZMQInteractiveShell(InteractiveShell):
    """InteractiveShell variant that reports to frontends over ZMQ."""

    displayhook_class = Type(ZMQShellDisplayHook)
    display_pub_class = Type(ZMQDisplayPublisher)
    data_pub_class = Any()
    kernel = Any()
    parent_header = Any()

    @default("banner1")
    def _default_banner1(self):
        return default_banner

    # The parent's readline trait is overridden because readline is pointless
    # in a kernel. Removable once the readline code lives in the terminal
    # frontend.
    readline_use = CBool(False)
    # autoindent is meaningless in a zmqshell; enabling it would only print a
    # warning when readline is absent.
    autoindent = CBool(False)

    exiter = Instance(ZMQExitAutocall)

    @default("exiter")
    def _default_exiter(self):
        return ZMQExitAutocall(self)

    @observe("exit_now")
    def _update_exit_now(self, change):
        """Stop the event loop once ``exit_now`` becomes truthy."""
        if not change["new"]:
            return
        if hasattr(self.kernel, "io_loop"):
            io_loop = self.kernel.io_loop
            io_loop.call_later(0.1, io_loop.stop)
        eventloop = self.kernel.eventloop
        if eventloop:
            exit_hook = getattr(eventloop, "exit_hook", None)
            if exit_hook:
                exit_hook(self.kernel)

    keepkernel_on_exit = None

    # Over ZeroMQ there is no interactive input being read, so GUI control is
    # not done through PyOS_InputHook; event loop support lives in ipkernel.
    def enable_gui(self, gui):
        """Enable the named GUI event loop and record it as active.

        A ``ValueError`` from the underlying call is re-raised as
        ``UsageError`` carrying the same message.
        """
        from .eventloops import enable_gui as real_enable_gui

        try:
            real_enable_gui(gui)
            self.active_eventloop = gui
        except ValueError as err:
            raise UsageError("%s" % err) from err

    def init_environment(self):
        """Configure the user's environment."""
        os.environ.update({
            # These two make 'ls' produce nice coloring on BSD-derived systems
            "TERM": "xterm-color",
            "CLICOLOR": "1",
            # Normal pagers don't work at all (over pexpect there is no
            # single-key control of the subprocess), so paging is disabled
            # in subprocesses as far as possible.
            "PAGER": "cat",
            "GIT_PAGER": "cat",
        })

    def init_hooks(self):
        """Run the parent hook setup, then install the payload pager hook."""
        super().init_hooks()
        pager_hook = page.as_hook(payloadpage.page)
        self.set_hook("show_in_pager", pager_hook, 99)

    def init_data_pub(self):
        """Delay datapub init until request, for deprecation warnings"""

    @property
    def data_pub(self):
        """Data publisher, created on first access with a deprecation warning."""
        if hasattr(self, "_data_pub"):
            return self._data_pub

        warnings.warn(
            "InteractiveShell.data_pub is deprecated outside IPython parallel.",
            DeprecationWarning,
            stacklevel=2,
        )

        self._data_pub = self.data_pub_class(parent=self)
        self._data_pub.session = self.display_pub.session
        self._data_pub.pub_socket = self.display_pub.pub_socket
        return self._data_pub

    @data_pub.setter
    def data_pub(self, pub):
        self._data_pub = pub

    def ask_exit(self):
        """Engage the exit actions."""
        self.exit_now = not self.keepkernel_on_exit
        self.payload_manager.write_payload({
            "source": "ask_exit",
            "keepkernel": self.keepkernel_on_exit,
        })

    def run_cell(self, *args, **kwargs):
        """Clear the stored traceback, then defer to the parent ``run_cell``."""
        self._last_traceback = None
        return super().run_cell(*args, **kwargs)

    def _showtraceback(self, etype, evalue, stb):
        """Send the exception as an ``error`` message on the pub socket."""
        # try to preserve ordering of tracebacks and print statements
        sys.stdout.flush()
        sys.stderr.flush()

        hook = self.displayhook
        # The exception info goes over the pub socket so that clients other
        # than the caller can pick it up.
        if hook.topic:
            topic = hook.topic.replace(b"execute_result", b"error")
        else:
            topic = None

        error_content = {
            "traceback": stb,
            "ename": str(etype.__name__),
            "evalue": str(evalue),
        }
        hook.session.send(
            hook.pub_socket,
            "error",
            json_clean(error_content),
            hook.parent_header,
            ident=topic,
        )

        # FIXME - Once we rely on Python 3, the traceback is stored on the
        # exception object, so we shouldn't need to store it here.
        self._last_traceback = stb

    def set_next_input(self, text, replace=False):
        """Send the specified text to the frontend to be presented at the next
        input cell."""
        self.payload_manager.write_payload({
            "source": "set_next_input",
            "text": text,
            "replace": replace,
        })

    def set_parent(self, parent):
        """Set the parent header for associating output with its triggering input"""
        self.parent_header = parent
        self.displayhook.set_parent(parent)
        self.display_pub.set_parent(parent)
        if hasattr(self, "_data_pub"):
            self.data_pub.set_parent(parent)
        # stdout first, then stderr; streams lacking set_parent are ignored.
        for stream_name in ("stdout", "stderr"):
            try:
                getattr(sys, stream_name).set_parent(parent)
            except AttributeError:
                pass

    def get_parent(self):
        """Return the parent header last stored by ``set_parent``."""
        return self.parent_header

    def init_magics(self):
        """Register the kernel magics and the ``ed`` alias for ``edit``."""
        super().init_magics()
        self.register_magics(KernelMagics)
        self.magics_manager.register_alias("ed", "edit")

    def init_virtualenv(self):
        # Overridden not to do virtualenv detection, because it's probably
        # not appropriate in a kernel. To use a kernel in a virtualenv, install
        # it inside the virtualenv.
        # https://ipython.readthedocs.io/en/latest/install/kernel_install.html
        pass

    def system_piped(self, cmd):
        """Call the given cmd in a subprocess, piping stdout/err

        Parameters
        ----------
        cmd : str
            Command to execute (can not end in '&', as background processes
            are not supported). Should not be a command that expects input
            other than simple text.

        Raises
        ------
        OSError
            If ``cmd``, ignoring trailing whitespace, ends with ``&``.
        """
        if cmd.rstrip().endswith("&"):
            # this is *far* from a rigorous test
            # We do not support backgrounding processes because we either use
            # pexpect or pipes to read from.  Users can always just call
            # os.system() or use ip.system=ip.system_raw
            # if they really want a background process.
            raise OSError("Background processes not supported.")

        # The subprocess status code is deliberately NOT returned, because a
        # non-None value would trigger :func:`sys.displayhook` calls; the
        # exit code is stored in user_ns instead.
        if sys.platform != "win32":
            self.user_ns["_exit_code"] = system(self.var_expand(cmd, depth=1))
            return

        # On Windows the call is also protected from UNC paths, as is done in
        # InteractiveShell.system_raw.
        cmd = self.var_expand(cmd, depth=1)
        from IPython.utils._process_win32 import AvoidUNCPath

        with AvoidUNCPath() as unc_path:
            if unc_path is not None:
                cmd = "pushd %s &&%s" % (unc_path, cmd)
            self.user_ns["_exit_code"] = system(cmd)

    # Ensure new system_piped implementation is used
    system = system_piped
Exemple #6
0
class Map(DOMWidget, InteractMixin):
    """Leaflet map widget holding a tuple of layers and a tuple of controls.

    Traits tagged ``o=True`` are collected by name into ``options``. Layers
    and controls must be unique by ``model_id``; ``layer_ids`` and
    ``control_ids`` cache the ids of the currently accepted tuples.
    """

    @default('layout')
    def _default_layout(self):
        return Layout(height='400px', align_self='stretch')

    _view_name = Unicode('LeafletMapView').tag(sync=True)
    _model_name = Unicode('LeafletMapModel').tag(sync=True)
    _view_module = Unicode('jupyter-leaflet').tag(sync=True)
    _model_module = Unicode('jupyter-leaflet').tag(sync=True)
    _view_module_version = Unicode(EXTENSION_VERSION).tag(sync=True)
    _model_module_version = Unicode(EXTENSION_VERSION).tag(sync=True)

    # Map options
    center = List(def_loc).tag(sync=True, o=True)
    zoom_start = Int(12).tag(sync=True, o=True)
    zoom = Int(12).tag(sync=True, o=True)
    max_zoom = Int(18).tag(sync=True, o=True)
    min_zoom = Int(1).tag(sync=True, o=True)

    # Specification of the basemap
    basemap = Dict(default_value=dict(
        url='https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png',
        max_zoom=19,
        attribution=
        'Map data (c) <a href="https://openstreetmap.org">OpenStreetMap</a> contributors'
    )).tag(sync=True, o=True)
    modisdate = Unicode('yesterday').tag(sync=True)

    # Interaction options
    dragging = Bool(True).tag(sync=True, o=True)
    touch_zoom = Bool(True).tag(sync=True, o=True)
    scroll_wheel_zoom = Bool(False).tag(sync=True, o=True)
    double_click_zoom = Bool(True).tag(sync=True, o=True)
    box_zoom = Bool(True).tag(sync=True, o=True)
    tap = Bool(True).tag(sync=True, o=True)
    tap_tolerance = Int(15).tag(sync=True, o=True)
    world_copy_jump = Bool(False).tag(sync=True, o=True)
    close_popup_on_click = Bool(True).tag(sync=True, o=True)
    bounce_at_zoom_limits = Bool(True).tag(sync=True, o=True)
    keyboard = Bool(True).tag(sync=True, o=True)
    keyboard_pan_offset = Int(80).tag(sync=True, o=True)
    keyboard_zoom_offset = Int(1).tag(sync=True, o=True)
    inertia = Bool(True).tag(sync=True, o=True)
    inertia_deceleration = Int(3000).tag(sync=True, o=True)
    inertia_max_speed = Int(1500).tag(sync=True, o=True)
    # inertia_threshold = Int(?, o=True).tag(sync=True)
    zoom_control = Bool(True).tag(sync=True, o=True)
    attribution_control = Bool(True).tag(sync=True, o=True)
    # fade_animation = Bool(?).tag(sync=True, o=True)
    # zoom_animation = Bool(?).tag(sync=True, o=True)
    zoom_animation_threshold = Int(4).tag(sync=True, o=True)
    # marker_zoom_animation = Bool(?).tag(sync=True, o=True)

    options = List(trait=Unicode).tag(sync=True)

    @default('options')
    def _default_options(self):
        # Names of every trait tagged ``o=True``.
        return list(self.traits(o=True))

    _south = Float(def_loc[0]).tag(sync=True)
    _north = Float(def_loc[0]).tag(sync=True)
    _east = Float(def_loc[1]).tag(sync=True)
    _west = Float(def_loc[1]).tag(sync=True)

    default_tiles = Instance(TileLayer,
                             allow_none=True).tag(sync=True,
                                                  **widget_serialization)

    @default('default_tiles')
    def _default_tiles(self):
        return basemap_to_tiles(self.basemap, self.modisdate)

    @property
    def north(self):
        """Read-only view of the synced ``_north`` trait."""
        return self._north

    @property
    def south(self):
        """Read-only view of the synced ``_south`` trait."""
        return self._south

    @property
    def east(self):
        """Read-only view of the synced ``_east`` trait."""
        return self._east

    @property
    def west(self):
        """Read-only view of the synced ``_west`` trait."""
        return self._west

    @property
    def bounds_polygon(self):
        """Corners as (north, west), (north, east), (south, east), (south, west)."""
        return [(self.north, self.west), (self.north, self.east),
                (self.south, self.east), (self.south, self.west)]

    @property
    def bounds(self):
        """Two corners: (south, west) and (north, east)."""
        return [(self.south, self.west), (self.north, self.east)]

    def __init__(self, **kwargs):
        """Create the map and install ``default_tiles`` as the only layer.

        When ``default_tiles`` is not ``None`` it replaces whatever value
        ``layers`` had after trait initialization.
        """
        super(Map, self).__init__(**kwargs)
        self.on_displayed(self._fire_children_displayed)
        if self.default_tiles is not None:
            self.layers = (self.default_tiles, )
        self.on_msg(self._handle_leaflet_event)

    def _fire_children_displayed(self, widget, **kwargs):
        """Forward the displayed event to every layer, then every control."""
        for layer in self.layers:
            layer._handle_displayed(**kwargs)
        for control in self.controls:
            control._handle_displayed(**kwargs)

    layers = Tuple(trait=Instance(Layer)).tag(sync=True,
                                              **widget_serialization)
    layer_ids = List()

    @validate('layers')
    def _validate_layers(self, proposal):
        """Validate layers list.

        Makes sure only one instance of any given layer can exist in the
        layers list.

        Raises
        ------
        LayerException
            If two proposed layers share a ``model_id``. ``layer_ids`` is
            left untouched in that case.
        """
        proposed_ids = [layer.model_id for layer in proposal['value']]
        if len(set(proposed_ids)) != len(proposed_ids):
            raise LayerException(
                'duplicate layer detected, only use each layer once')
        # Commit the id cache only once the proposal is known to be valid, so
        # a rejected assignment cannot leave it out of sync with ``layers``.
        self.layer_ids = proposed_ids
        return proposal['value']

    def add_layer(self, layer):
        """Append ``layer`` to the map and mark it visible.

        Raises
        ------
        LayerException
            If a layer with the same ``model_id`` is already on the map.
        """
        if layer.model_id in self.layer_ids:
            raise LayerException('layer already on map: %r' % layer)
        layer._map = self
        self.layers = tuple(self.layers) + (layer, )
        layer.visible = True

    def remove_layer(self, layer):
        """Remove ``layer`` from the map and mark it not visible.

        Raises
        ------
        LayerException
            If no layer with that ``model_id`` is on the map.
        """
        if layer.model_id not in self.layer_ids:
            raise LayerException('layer not on map: %r' % layer)
        self.layers = tuple(existing for existing in self.layers
                            if existing.model_id != layer.model_id)
        layer.visible = False

    def clear_layers(self):
        """Remove every layer from the map."""
        self.layers = ()

    controls = Tuple(trait=Instance(Control)).tag(sync=True,
                                                  **widget_serialization)
    control_ids = List()

    @validate('controls')
    def _validate_controls(self, proposal):
        """Validate controls list.

        Makes sure only one instance of any given control can exist in the
        controls list.

        Raises
        ------
        ControlException
            If two proposed controls share a ``model_id``. ``control_ids``
            is left untouched in that case.
        """
        proposed_ids = [control.model_id for control in proposal['value']]
        if len(set(proposed_ids)) != len(proposed_ids):
            raise ControlException(
                'duplicate control detected, only use each control once')
        # Same reasoning as for layers: update the cache only on success.
        self.control_ids = proposed_ids
        return proposal['value']

    def add_control(self, control):
        """Append ``control`` to the map and mark it visible.

        Raises
        ------
        ControlException
            If a control with the same ``model_id`` is already on the map.
        """
        if control.model_id in self.control_ids:
            raise ControlException('control already on map: %r' % control)
        control._map = self
        self.controls = tuple(self.controls) + (control, )
        control.visible = True

    def remove_control(self, control):
        """Remove ``control`` from the map and mark it not visible.

        Raises
        ------
        ControlException
            If no control with that ``model_id`` is on the map.
        """
        if control.model_id not in self.control_ids:
            raise ControlException('control not on map: %r' % control)
        self.controls = tuple(existing for existing in self.controls
                              if existing.model_id != control.model_id)
        control.visible = False

    def clear_controls(self):
        """Remove every control from the map."""
        self.controls = ()

    def __iadd__(self, item):
        """``map += item``: add a Layer or Control; other types are ignored."""
        if isinstance(item, Layer):
            self.add_layer(item)
        elif isinstance(item, Control):
            self.add_control(item)
        return self

    def __isub__(self, item):
        """``map -= item``: remove a Layer or Control; other types are ignored."""
        if isinstance(item, Layer):
            self.remove_layer(item)
        elif isinstance(item, Control):
            self.remove_control(item)
        return self

    def __add__(self, item):
        """``map + item``: add the item to THIS map in place and return it.

        No new map is created; the left operand is mutated.
        """
        if isinstance(item, Layer):
            self.add_layer(item)
        elif isinstance(item, Control):
            self.add_control(item)
        return self

    # Event handling
    _moveend_callbacks = Instance(CallbackDispatcher, ())

    def _handle_leaflet_event(self, _, content, buffers):
        """Fire the moveend callbacks when a ``'moveend'`` event arrives."""
        if content.get('event', '') == 'moveend':
            self._moveend_callbacks(**content)

    def on_moveend(self, callback, remove=False):
        """Register ``callback`` for moveend events, or unregister if ``remove``."""
        self._moveend_callbacks.register_callback(callback, remove=remove)
Exemple #7
0
class Scheduler(LoggingConfigurable):
    """Base scheduler wiring client and engine ZMQ streams to dispatch hooks.

    Subclasses must implement ``dispatch_result`` and
    ``dispatch_submission``; both raise ``NotImplementedError`` here.
    """

    loop = Instance(ioloop.IOLoop)

    @default("loop")
    def _default_loop(self):
        return ioloop.IOLoop.current()

    session = Instance(jupyter_client.session.Session)

    @default("session")
    def _default_session(self):
        return jupyter_client.session.Session(parent=self)

    # client-facing stream
    client_stream = Instance(zmqstream.ZMQStream, allow_none=True)
    # engine-facing stream
    engine_stream = Instance(zmqstream.ZMQStream, allow_none=True)
    # hub-facing sub stream
    notifier_stream = Instance(zmqstream.ZMQStream, allow_none=True)
    # hub-facing pub stream
    mon_stream = Instance(zmqstream.ZMQStream, allow_none=True)
    # hub-facing DEALER stream
    query_stream = Instance(zmqstream.ZMQStream, allow_none=True)

    all_completed = Set()  # set of all completed tasks
    all_failed = Set()  # set of all failed tasks
    all_done = Set()  # set of all finished tasks=union(completed,failed)
    all_ids = Set()  # set of all submitted task IDs

    # ZMQ identity. This should just be self.session.session as bytes,
    # but ensure Bytes.
    ident = Bytes()

    @default("ident")
    def _ident_default(self):
        return self.session.bsession

    def start(self):
        """Attach the result and submission handlers to their streams."""
        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def resume_receiving(self):
        """Resume accepting jobs."""
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        self.client_stream.on_recv(None)

    def dispatch_result(self, raw_msg):
        """Handle a message from the engine stream (subclass hook)."""
        raise NotImplementedError("Implement in subclasses")

    def dispatch_submission(self, raw_msg):
        """Handle a message from the client stream (subclass hook)."""
        raise NotImplementedError("Implement in subclasses")

    def append_new_msg_id_to_msg(self, new_id, target_id, idents, msg):
        """Re-serialize ``msg`` under ``new_id``, routed via ``target_id``.

        ``msg['header']['msg_id']`` is overwritten in place. A ``str``
        target is UTF-8 encoded and placed ahead of ``idents``. The
        serialized frames are returned with ``msg['buffers']`` appended.
        """
        routing_target = (target_id.encode("utf8")
                          if isinstance(target_id, str) else target_id)
        routing_idents = [routing_target] + idents
        msg['header']['msg_id'] = new_id
        frames = self.session.serialize(msg, ident=routing_idents)
        frames.extend(msg['buffers'])
        return frames

    def get_new_msg_id(self, original_msg_id, outgoing_id):
        """Return ``"<original_msg_id>_<outgoing_id>"``.

        An ``outgoing_id`` that is not a ``str`` is decoded as UTF-8 first.
        """
        if isinstance(outgoing_id, str):
            suffix = outgoing_id
        else:
            suffix = outgoing_id.decode("utf8")
        return f'{original_msg_id}_{suffix}'
Exemple #8
0
class ImageRecorder(Recorder):
    """Recorder that grabs an Image from a MediaStream widget.

    The recorded data lives in the ``image`` child widget. Its width,
    height and format follow this recorder's ``_width``, ``_height`` and
    ``format`` traits, and a change of its ``value`` triggers the
    autosave check.
    """
    _model_name = Unicode('ImageRecorderModel').tag(sync=True)
    _view_name = Unicode('ImageRecorderView').tag(sync=True)

    image = Instance(Image).tag(sync=True, **widget_serialization)
    format = Unicode('png', help='The format of the image.').tag(sync=True)
    _width = Unicode().tag(sync=True)
    _height = Unicode().tag(sync=True)

    def __init__(self,
                 format='png',
                 filename=Recorder.filename.default_value,
                 recording=False,
                 autosave=False,
                 **kwargs):
        """Create the recorder; extra ``kwargs`` go to the parent class."""
        super(ImageRecorder, self).__init__(format=format,
                                            filename=filename,
                                            recording=recording,
                                            autosave=autosave,
                                            **kwargs)
        if 'image' not in kwargs:
            # Set up initial observer on child:
            self.image.observe(self._check_autosave, 'value')

    @traitlets.default('image')
    def _default_image(self):
        return Image(width=self._width,
                     height=self._height,
                     format=self.format)

    @observe('_width')
    def _update_image_width(self, change):
        self.image.width = self._width

    @observe('_height')
    def _update_image_height(self, change):
        self.image.height = self._height

    @observe('format')
    def _update_image_format(self, change):
        self.image.format = self.format

    @observe('image')
    def _bind_image(self, change):
        # Move the autosave observer from the old child to the new one.
        if change.old:
            change.old.unobserve(self._check_autosave, 'value')
        change.new.observe(self._check_autosave, 'value')

    def _check_autosave(self, change):
        # Save only when there is data and autosave is switched on.
        if len(self.image.value) and self.autosave:
            self.save()

    def save(self, filename=None):
        """Save the image to a file, if no filename is given it is based on the filename trait and the format.

        The format is appended as extension when the final path component
        has none; dots in directory names do not count as an extension.

        >>> recorder = ImageRecorder(filename='test', format='png')
        >>> ...
        >>> recorder.save()  # will save to test.png
        >>> recorder.save('foo')  # will save to foo.png
        >>> recorder.save('foo.dat')  # will save to foo.dat
        >>> recorder.save('./out/foo')  # will save to ./out/foo.png

        Raises
        ------
        ValueError
            If the image holds no data.
        """
        import os  # local import: only needed for extension detection

        filename = filename or self.filename
        # The previous test, "'.' not in filename", also matched dots in the
        # directory part ('./foo', 'run.1/foo') and then skipped the suffix.
        if not os.path.splitext(filename)[1]:
            filename += '.' + self.format
        if len(self.image.value) == 0:
            raise ValueError('No data, did you record anything?')
        with open(filename, 'wb') as f:
            f.write(self.image.value)
Exemple #9
0
class AudioRecorder(Recorder):
    """Recorder for the audio of a MediaStream widget.

    The recording can be played in the Notebook, downloaded, or turned
    into an Audio widget. The recorded data lives in the ``audio`` child
    widget, whose format follows this recorder's ``format`` trait.

    For help on supported values for the "codecs" attribute, see
    https://stackoverflow.com/questions/41739837/all-mime-types-supported-by-mediarecorder-in-firefox-and-chrome
    """
    _model_name = Unicode('AudioRecorderModel').tag(sync=True)
    _view_name = Unicode('AudioRecorderView').tag(sync=True)

    audio = Instance(Audio).tag(sync=True, **widget_serialization)
    codecs = Unicode(
        '',
        help='Optional codecs for the recording, e.g. "opus".').tag(sync=True)

    def __init__(self,
                 format='webm',
                 filename=Recorder.filename.default_value,
                 recording=False,
                 autosave=False,
                 **kwargs):
        """Create the recorder; extra ``kwargs`` go to the parent class."""
        super(AudioRecorder, self).__init__(format=format,
                                            filename=filename,
                                            recording=recording,
                                            autosave=autosave,
                                            **kwargs)
        if 'audio' not in kwargs:
            # Set up initial observer on child:
            self.audio.observe(self._check_autosave, 'value')

    @traitlets.default('audio')
    def _default_audio(self):
        return Audio(format=self.format, controls=True)

    @observe('format')
    def _update_audio_format(self, change):
        self.audio.format = self.format

    @observe('audio')
    def _bind_audio(self, change):
        # Move the autosave observer from the old child to the new one.
        if change.old:
            change.old.unobserve(self._check_autosave, 'value')
        change.new.observe(self._check_autosave, 'value')

    def _check_autosave(self, change):
        # Save only when there is data and autosave is switched on.
        if len(self.audio.value) and self.autosave:
            self.save()

    def save(self, filename=None):
        """Save the audio to a file, if no filename is given it is based on the filename trait and the format.

        The format is appended as extension when the final path component
        has none; dots in directory names do not count as an extension.

        >>> recorder = AudioRecorder(filename='test', format='mp3')
        >>> ...
        >>> recorder.save()  # will save to test.mp3
        >>> recorder.save('foo')  # will save to foo.mp3
        >>> recorder.save('foo.dat')  # will save to foo.dat
        >>> recorder.save('./out/foo')  # will save to ./out/foo.mp3

        Raises
        ------
        ValueError
            If the audio holds no data.
        """
        import os  # local import: only needed for extension detection

        filename = filename or self.filename
        # The previous test, "'.' not in filename", also matched dots in the
        # directory part ('./foo', 'run.1/foo') and then skipped the suffix.
        if not os.path.splitext(filename)[1]:
            filename += '.' + self.format
        if len(self.audio.value) == 0:
            raise ValueError('No data, did you record anything?')
        with open(filename, 'wb') as f:
            f.write(self.audio.value)
Exemple #10
0
class InlineBackend(InlineBackendConfig):
    """Configuration holder for the inline matplotlib backend.

    Declares the configurable traits (rc overrides, figure formats,
    print_figure kwargs, figure closing) and the change handlers that
    push format changes to the attached shell.
    """
    def _config_changed(self, name, old, new):
        """Warn if the legacy config section changed, then defer to the parent."""
        # warn on change of renamed config section
        legacy_now = new.InlineBackendConfig
        legacy_before = getattr(old, 'InlineBackendConfig', Config())
        if legacy_now != legacy_before:
            warn("InlineBackendConfig has been renamed to InlineBackend")
        super(InlineBackend, self)._config_changed(name, old, new)

    # The typical default figure size is too large for inline use,
    # so we shrink the figure size to 6x4, and tweak fonts to
    # make that fit.
    rc = Dict(
        {
            'figure.figsize': (6.0, 4.0),
            # play nicely with white background in the Qt and notebook frontend
            'figure.facecolor': (1, 1, 1, 0),
            'figure.edgecolor': (1, 1, 1, 0),
            # 12pt labels get cutoff on 6x4 logplots, so use 10pt.
            'font.size': 10,
            # 72 dpi matches SVG/qtconsole
            # this only affects PNG export, as SVG has no dpi setting
            'savefig.dpi': 72,
            # 10pt still needs a little more room on the xlabel:
            'figure.subplot.bottom': .125
        },
        config=True,
        help="""Subset of matplotlib rcParams that should be different for the
        inline backend.""")

    figure_formats = Set({'png'},
                         config=True,
                         help="""A set of figure formats to enable: 'png',
                          'retina', 'jpeg', 'svg', 'pdf'.""")

    def _update_figure_formatters(self):
        """Re-select the shell's figure formatters; no-op without a shell."""
        if self.shell is None:
            return
        # Imported lazily, only when a shell is attached.
        from IPython.core.pylabtools import select_figure_formats
        select_figure_formats(self.shell, self.figure_formats,
                              **self.print_figure_kwargs)

    def _figure_formats_changed(self, name, old, new):
        """Validate JPEG support, then refresh the figure formatters.

        Raises
        ------
        TraitError
            If a JPEG format is requested and ``pil_available()`` is false.
        """
        wants_jpeg = 'jpg' in new or 'jpeg' in new
        if wants_jpeg and not pil_available():
            raise TraitError("Requires PIL/Pillow for JPG figures")
        self._update_figure_formatters()

    figure_format = Unicode(config=True,
                            help="""The figure format to enable (deprecated
                                         use `figure_formats` instead)""")

    def _figure_format_changed(self, name, old, new):
        """Map the deprecated single-format trait onto ``figure_formats``."""
        if not new:
            return
        self.figure_formats = {new}

    print_figure_kwargs = Dict(
        {'bbox_inches': 'tight'},
        config=True,
        help="""Extra kwargs to be passed to fig.canvas.print_figure.

        Logical examples include: bbox_inches, quality (for jpeg figures), etc.
        """)
    # Changing the print kwargs re-runs the same formatter refresh.
    _print_figure_kwargs_changed = _update_figure_formatters

    close_figures = Bool(True,
                         config=True,
                         help="""Close all figures at the end of each cell.

        When True, ensures that each cell starts with no active figures, but it
        also means that one must keep track of references in order to edit or
        redraw figures in subsequent cells. This mode is ideal for the notebook,
        where residual plots from other cells might be surprising.

        When False, one must call figure() to create new figures. This means
        that gcf() and getfigs() can reference figures created in other cells,
        and the active figure can continue to be edited with pylab/pyplot
        methods that reference the current active figure. This mode facilitates
        iterative editing of figures, and behaves most consistently with
        other matplotlib backends, but figure barriers between cells must
        be explicit.
        """)

    # Shell whose formatters are updated; may be None (no shell attached).
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
Exemple #11
0
class AudioStream(MediaStream):
    """Represent a stream of an audio element"""
    _model_name = Unicode('AudioStreamModel').tag(sync=True)
    _view_name = Unicode('AudioStreamView').tag(sync=True)

    audio = Instance(
        Audio,
        help=
        "An ipywidgets.Audio instance that will be the source of the media stream."
    ).tag(sync=True, **widget_serialization)
    playing = Bool(True,
                   help='Plays the audiostream or pauses it.').tag(sync=True)

    @classmethod
    def from_file(cls, filename, **kwargs):
        """Create a `AudioStream` from a local file.

        Parameters
        ----------
        filename: str
            The location of a file to read into the audio value from disk.
        **kwargs
            Extra keyword arguments for `AudioStream`
        """
        audio = Audio.from_file(filename, autoplay=False, controls=False)
        return cls(audio=audio, **kwargs)

    @classmethod
    def from_url(cls, url, **kwargs):
        """Create a `AudioStream` from a url.

        This will create a `AudioStream` from an Audio using its url

        Parameters
        ----------
        url: str
            The url of the file that will be used for the .audio trait.
        **kwargs
            Extra keyword arguments for `AudioStream`
        """
        audio = Audio.from_url(url, autoplay=False, controls=False)
        return cls(audio=audio, **kwargs)

    @classmethod
    def from_download(cls, url, **kwargs):
        """Create a `AudioStream` from a url by downloading

        Parameters
        ----------
        url: str
            The url of the file that will be downloaded and its bytes
            assigned to the value trait of the audio trait. The part of
            the url after the last dot, if any, is used as the audio format.
        **kwargs
            Extra keyword arguments for `AudioStream`
        """
        extension = os.path.splitext(url)[1]
        audio_options = dict(value=urlopen(url).read(),
                             autoplay=False,
                             controls=False)
        if extension:
            # Drop the leading dot: '.mp3' -> 'mp3'.
            audio_options['format'] = extension[1:]
        # Without an extension `format` is not passed at all, so the Audio
        # widget's own default applies instead of hitting an unbound local.
        audio = Audio(**audio_options)
        return cls(audio=audio, **kwargs)
Exemple #12
0
class KernelManager(ConnectionFileMixin):
    """Manages a single kernel in a subprocess on this host.

    This version starts kernels with Popen.
    """

    # The PyZMQ Context to use for communication with the kernel.
    context = Instance(zmq.Context)

    def _context_default(self):
        """Default for the ``context`` trait: a newly created ``zmq.Context``."""
        new_context = zmq.Context()
        return new_context

    # the class to create with our `client` method
    client_class = DottedObjectName(
        'jupyter_client.blocking.BlockingKernelClient')
    client_factory = Type(klass='jupyter_client.KernelClient')

    def _client_factory_default(self):
        """Default for ``client_factory``: the class named by ``client_class``."""
        dotted_name = self.client_class
        return import_item(dotted_name)

    def _client_class_changed(self, name, old, new):
        """Re-import ``client_factory`` whenever ``client_class`` changes."""
        dotted_name = str(new)
        self.client_factory = import_item(dotted_name)

    # The kernel process with which the KernelManager is communicating.
    # generally a Popen instance
    kernel = Any()

    kernel_spec_manager = Instance(kernelspec.KernelSpecManager)

    def _kernel_spec_manager_default(self):
        """Default ``kernel_spec_manager``, built with this manager's ``data_dir``."""
        spec_manager = kernelspec.KernelSpecManager(data_dir=self.data_dir)
        return spec_manager

    def _kernel_spec_manager_changed(self):
        """Drop the cached kernel spec so ``kernel_spec`` looks it up again."""
        self._kernel_spec = None

    shutdown_wait_time = Float(
        5.0,
        config=True,
        help="Time to wait for a kernel to terminate before killing it, "
        "in seconds.")

    kernel_name = Unicode(kernelspec.NATIVE_KERNEL_NAME)

    def _kernel_name_changed(self, name, old, new):
        """Invalidate the cached spec and rewrite the ``'python'`` alias.

        A new name of ``'python'`` is replaced by
        ``kernelspec.NATIVE_KERNEL_NAME``.
        """
        self._kernel_spec = None
        if new != 'python':
            return
        self.kernel_name = kernelspec.NATIVE_KERNEL_NAME

    _kernel_spec = None

    @property
    def kernel_spec(self):
        """Kernel spec for ``kernel_name``, looked up lazily and cached.

        Returns the cached spec when there is one. Otherwise, if
        ``kernel_name`` is non-empty, fetches the spec from
        ``kernel_spec_manager`` and caches it. Returns ``None`` when the
        name is empty and nothing is cached.
        """
        # Compare by value: `is not ''` tests object identity, which is
        # implementation-dependent for strings.
        if self._kernel_spec is None and self.kernel_name != '':
            self._kernel_spec = self.kernel_spec_manager.get_kernel_spec(
                self.kernel_name)
        return self._kernel_spec

    kernel_cmd = List(Unicode(),
                      config=True,
                      help="""DEPRECATED: Use kernel_name instead.

        The Popen Command to launch the kernel.
        Override this if you have a custom kernel.
        If kernel_cmd is specified in a configuration file,
        Jupyter does not pass any arguments to the kernel,
        because it cannot make any assumptions about the
        arguments that the kernel understands. In particular,
        this means that the kernel does not receive the
        option --debug if it given on the Jupyter command line.
        """)

    def _kernel_cmd_changed(self, name, old, new):
        """Emit a deprecation warning whenever ``kernel_cmd`` is set."""
        message = ("Setting kernel_cmd is deprecated, use kernel_spec to "
                   "start different kernels.")
        warnings.warn(message)

    @property
    def ipykernel(self):
        """True when ``kernel_name`` is 'python', 'python2' or 'python3'."""
        python_kernel_names = {'python', 'python2', 'python3'}
        return self.kernel_name in python_kernel_names

    # Protected traits
    _launch_args = Any()
    _control_socket = Any()

    _restarter = Any()

    autorestart = Bool(True,
                       config=True,
                       help="""Should we autorestart the kernel if it dies.""")

    def __del__(self):
        """On finalization, close the control socket and remove the connection file."""
        # Socket first, then the file, matching the explicit cleanup helpers.
        self._close_control_socket()
        self.cleanup_connection_file()

    #--------------------------------------------------------------------------
    # Kernel restarter
    #--------------------------------------------------------------------------

    def start_restarter(self):
        """No-op hook, called by ``start_kernel`` after the kernel is launched.

        NOTE(review): presumably overridden by subclasses that provide a
        restarter -- confirm.
        """

    def stop_restarter(self):
        """No-op hook, called by ``shutdown_kernel`` before stopping the kernel.

        NOTE(review): presumably overridden by subclasses that provide a
        restarter -- confirm.
        """

    def add_restart_callback(self, callback, event='restart'):
        """Register ``callback`` for ``event`` on the restarter.

        Does nothing when no restarter is set.
        """
        restarter = self._restarter
        if restarter is not None:
            restarter.add_callback(callback, event)

    def remove_restart_callback(self, callback, event='restart'):
        """Unregister ``callback`` for ``event`` from the restarter.

        Does nothing when no restarter is set.
        """
        restarter = self._restarter
        if restarter is not None:
            restarter.remove_callback(callback, event)

    #--------------------------------------------------------------------------
    # create a Client connected to our Kernel
    #--------------------------------------------------------------------------

    def client(self, **kwargs):
        """Create a client configured to connect to our kernel.

        The client receives the connection info (including the session), the
        connection file and this manager as ``parent``. Any ``kwargs`` are
        applied last and therefore override those values.
        """
        options = dict(self.get_connection_info(session=True))
        options['connection_file'] = self.connection_file
        options['parent'] = self

        # caller-supplied values win over everything derived above
        options.update(kwargs)
        return self.client_factory(**options)

    #--------------------------------------------------------------------------
    # Kernel management
    #--------------------------------------------------------------------------

    def format_kernel_cmd(self, extra_arguments=None):
        """Build the launch command, replacing templated args (e.g. {connection_file}).

        Parameters
        ----------
        extra_arguments : list, optional
            Arguments appended to ``kernel_cmd`` if it is set, otherwise to
            the kernel spec's ``argv``.

        Returns
        -------
        list
            The command with every ``{name}`` placeholder substituted when
            ``name`` is a known key; unknown placeholders are left as-is.
        """
        extra_arguments = extra_arguments or []
        if self.kernel_cmd:
            cmd = self.kernel_cmd + extra_arguments
        else:
            cmd = self.kernel_spec.argv + extra_arguments

        if cmd and cmd[0] in {
                'python',
                'python%i' % sys.version_info[0],
                'python%i.%i' % sys.version_info[:2]
        }:
            # executable is 'python' or 'python3', use sys.executable.
            # These will typically be the same,
            # but if the current process is in an env
            # and has been launched by abspath without
            # activating the env, python on PATH may not be sys.executable,
            # but it should be.
            cmd[0] = sys.executable

        ns = dict(
            connection_file=self.connection_file,
            prefix=sys.prefix,
        )

        if self.kernel_spec:
            ns["resource_dir"] = self.kernel_spec.resource_dir

        # _launch_args stays None until start_kernel() has run; treat that
        # as "no extra substitutions" instead of failing in dict.update().
        ns.update(self._launch_args or {})

        pat = re.compile(r'\{([A-Za-z0-9_]+)\}')

        def from_ns(match):
            """Get the key out of ns if it's there, otherwise no change."""
            return ns.get(match.group(1), match.group())

        return [pat.sub(from_ns, arg) for arg in cmd]

    def _launch_kernel(self, kernel_cmd, **kw):
        """Launch the kernel by delegating to ``launch_kernel``.

        Override in a subclass to launch kernel subprocesses differently.
        Returns whatever ``launch_kernel`` returns.
        """
        launched = launch_kernel(kernel_cmd, **kw)
        return launched

    # Control socket used for polite kernel shutdown

    def _connect_control_socket(self):
        """Create the control socket if there is none yet (linger set to 100)."""
        if self._control_socket is not None:
            return
        control = self._create_connected_socket('control')
        self._control_socket = control
        control.linger = 100

    def _close_control_socket(self):
        """Close and forget the control socket; no-op when none is open."""
        control = self._control_socket
        if control is not None:
            control.close()
            self._control_socket = None

    def start_kernel(self, **kw):
        """Starts a kernel on this host in a separate process.

        If random ports (port=0) are being used, this method must be called
        before the channels are created.

        Parameters
        ----------
        `**kw` : optional
             keyword arguments that are passed down to build the kernel_cmd
             and launching the kernel (e.g. Popen kwargs). ``extra_arguments``
             and ``env`` are consumed here; a copy of all of them is kept in
             ``_launch_args`` for later restarts.

        Raises
        ------
        RuntimeError
            If the transport is tcp and the configured ip is not local.
        """
        if self.transport == 'tcp' and not is_local_ip(self.ip):
            # Message fixed: a space was missing between "%s." and "Make".
            raise RuntimeError(
                "Can only launch a kernel on a local interface. "
                "This one is not: %s. "
                "Make sure that the '*_address' attributes are "
                "configured properly. "
                "Currently valid addresses are: %s" % (self.ip, local_ips()))

        # write connection file / get default ports
        self.write_connection_file()

        # save kwargs for use in restart
        self._launch_args = kw.copy()
        # build the Popen cmd
        extra_arguments = kw.pop('extra_arguments', [])
        kernel_cmd = self.format_kernel_cmd(extra_arguments=extra_arguments)
        # Work on a copy so neither the caller's dict nor os.environ is mutated.
        env = kw.pop('env', os.environ).copy()
        # Don't allow PYTHONEXECUTABLE to be passed to kernel process.
        # If set, it can bork all the things.
        env.pop('PYTHONEXECUTABLE', None)
        if not self.kernel_cmd:
            # If kernel_cmd has been set manually, don't refer to a kernel spec
            # Environment variables from kernel spec are added to the copy
            env.update(self.kernel_spec.env or {})

        # launch the kernel subprocess
        self.log.debug("Starting kernel: %s", kernel_cmd)
        self.kernel = self._launch_kernel(kernel_cmd, env=env, **kw)
        self.start_restarter()
        self._connect_control_socket()

    def request_shutdown(self, restart=False):
        """Send a ``shutdown_request`` message over the control socket.

        ``restart`` is passed along in the message content.
        """
        request = self.session.msg("shutdown_request",
                                   content={'restart': restart})
        # ensure control socket is connected
        self._connect_control_socket()
        self.session.send(self._control_socket, request)

    def finish_shutdown(self, waittime=None, pollinterval=0.1):
        """Wait for kernel shutdown, then kill process if it doesn't shutdown.

        This does not send shutdown requests - use :meth:`request_shutdown`
        first.

        Parameters
        ----------
        waittime : float, optional
            Seconds to wait; defaults to ``shutdown_wait_time`` clamped at 0.
        pollinterval : float
            Seconds to sleep between liveness checks.
        """
        if waittime is None:
            waittime = max(self.shutdown_wait_time, 0)
        poll_count = int(waittime / pollinterval)
        for _ in range(poll_count):
            if not self.is_alive():
                # Kernel exited by itself: nothing left to kill.
                return
            time.sleep(pollinterval)
        # OK, we've waited long enough.
        if self.has_kernel:
            self.log.debug("Kernel is taking too long to finish, killing")
            self._kill_kernel()

    def cleanup(self, connection_file=True):
        """Clean up resources when the kernel is shut down.

        Parameters
        ----------
        connection_file : bool
            When true, the connection file is cleaned up as well. IPC files
            and the control socket are always cleaned up.
        """
        if connection_file:
            self.cleanup_connection_file()
        self.cleanup_ipc_files()
        self._close_control_socket()

    def shutdown_kernel(self, now=False, restart=False):
        """Attempts to stop the kernel process cleanly.

        Unless ``now`` is set, a shutdown request is sent via
        :meth:`request_shutdown` and :meth:`finish_shutdown` then waits for
        the kernel to exit, killing it if it does not. Resources are cleaned
        up afterwards in either case.

        Parameters
        ----------
        now : bool
            Should the kernel be forcibly killed *now*. This skips the
            first, nice shutdown attempt.
        restart: bool
            Will this kernel be restarted after it is shutdown. When this
            is True, connection files will not be cleaned up.
        """
        # Stop monitoring for restarting while we shutdown.
        self.stop_restarter()

        if not now:
            self.request_shutdown(restart=restart)
            # Give the kernel a chance to execute its shutdown actions;
            # finish_shutdown() polls and only then kills the process.
            self.finish_shutdown()
        else:
            self._kill_kernel()

        self.cleanup(connection_file=not restart)

    def restart_kernel(self, now=False, newports=False, **kw):
        """Restarts a kernel with the arguments that were used to launch it.

        Parameters
        ----------
        now : bool, optional
            If True, the kernel is killed immediately, without a chance to
            do any cleanup. Otherwise it is first asked to shut down and only
            killed if it does not exit in time. The kernel is restarted in
            both cases.

        newports : bool, optional
            If True, ``cleanup_random_ports()`` is called before the new
            kernel is started. If False (the default) it is not.

        `**kw` : optional
            Any options specified here will overwrite those used to launch the
            kernel.

        Raises
        ------
        RuntimeError
            If ``start_kernel`` was never called (no saved launch arguments).
        """
        if self._launch_args is None:
            raise RuntimeError("Cannot restart the kernel. "
                               "No previous call to 'start_kernel'.")

        # Stop currently running kernel.
        self.shutdown_kernel(now=now, restart=True)

        if newports:
            self.cleanup_random_ports()

        # Start new kernel.
        self._launch_args.update(kw)
        self.start_kernel(**self._launch_args)

    @property
    def has_kernel(self):
        """True when the ``kernel`` attribute is set, i.e. is not ``None``."""
        kernel_missing = self.kernel is None
        return not kernel_missing

    def _kill_kernel(self):
        """Kill the running kernel and wait for it to terminate.

        This is a private method, callers should use shutdown_kernel(now=True).

        Raises
        ------
        RuntimeError
            If no kernel is running.
        OSError
            If killing fails for a reason other than the process already
            being gone.
        """
        if not self.has_kernel:
            raise RuntimeError("Cannot kill kernel. No kernel is running!")

        # Signal the kernel to terminate (sends SIGKILL on Unix and calls
        # TerminateProcess() on Win32).
        try:
            if hasattr(signal, 'SIGKILL'):
                self.signal_kernel(signal.SIGKILL)
            else:
                self.kernel.kill()
        except OSError as e:
            if sys.platform == 'win32':
                # In Windows, we will get an Access Denied error (5) if the
                # process has already terminated.
                already_gone = e.winerror == 5
            else:
                # On Unix, we may get an ESRCH error if the process has
                # already terminated.
                from errno import ESRCH
                already_gone = e.errno == ESRCH
            if not already_gone:
                raise

        # Block until the kernel terminates.
        self.kernel.wait()
        self.kernel = None

    def interrupt_kernel(self):
        """Interrupt the kernel according to the kernel spec's ``interrupt_mode``.

        In ``'signal'`` mode the win32 interrupt event is used on Windows and
        SIGINT is sent elsewhere. In ``'message'`` mode an
        ``interrupt_request`` is sent over the control socket. Any other mode
        does nothing.

        Raises
        ------
        RuntimeError
            If no kernel is running.
        """
        if not self.has_kernel:
            raise RuntimeError(
                "Cannot interrupt kernel. No kernel is running!")

        mode = self.kernel_spec.interrupt_mode
        if mode == 'signal':
            if sys.platform == 'win32':
                from .win_interrupt import send_interrupt
                send_interrupt(self.kernel.win32_interrupt_event)
            else:
                self.signal_kernel(signal.SIGINT)
        elif mode == 'message':
            request = self.session.msg("interrupt_request", content={})
            self._connect_control_socket()
            self.session.send(self._control_socket, request)

    def signal_kernel(self, signum):
        """Send ``signum`` to the kernel's process group, else to the kernel.

        When ``os.getpgid`` and ``os.killpg`` exist, the signal goes to the
        kernel's process group. If they are missing, or that attempt raises
        ``OSError``, the signal is sent to the kernel process itself.

        Raises
        ------
        RuntimeError
            If no kernel is running.
        """
        if not self.has_kernel:
            raise RuntimeError("Cannot signal kernel. No kernel is running!")

        if hasattr(os, "getpgid") and hasattr(os, "killpg"):
            try:
                os.killpg(os.getpgid(self.kernel.pid), signum)
            except OSError:
                # Group delivery failed: fall through to the single process.
                pass
            else:
                return
        self.kernel.send_signal(signum)

    def is_alive(self):
        """Return True when a kernel exists and its ``poll()`` returns None."""
        if not self.has_kernel:
            # we don't have a kernel
            return False
        return self.kernel.poll() is None
Exemple #13
0
class Node(Widget):
    """ The node widget """
    _view_name = Unicode('NodeView').tag(sync=True)
    _model_name = Unicode('NodeModel').tag(sync=True)
    _view_module = Unicode('ipytree').tag(sync=True)
    _model_module = Unicode('ipytree').tag(sync=True)
    _view_module_version = Unicode(__version__).tag(sync=True)
    _model_module_version = Unicode(__version__).tag(sync=True)

    # Allowed values for the three *_style traits below.
    _style_values = ["warning", "danger", "success", "info", "default"]

    name = Unicode("Node").tag(sync=True)
    opened = Bool(True).tag(sync=True)
    disabled = Bool(False).tag(sync=True)
    selected = Bool(False).tag(sync=True)

    show_icon = Bool(True).tag(sync=True)
    icon = Unicode("folder").tag(sync=True)
    icon_style = Enum(values=_style_values,
                      default_value="default").tag(sync=True)
    icon_image = Unicode("").tag(sync=True)

    open_icon = Unicode("plus").tag(sync=True)
    open_icon_style = Enum(values=_style_values,
                           default_value="default").tag(sync=True)

    close_icon = Unicode("minus").tag(sync=True)
    close_icon_style = Enum(values=_style_values,
                            default_value="default").tag(sync=True)

    # Child nodes, stored as a tuple.
    nodes = Tuple().tag(trait=Instance(Widget),
                        sync=True,
                        **widget_serialization)

    _id = Unicode(read_only=True).tag(sync=True)

    def __init__(self, name="Node", nodes=(), **kwargs):
        """Create a node.

        Parameters
        ----------
        name: str
            Value assigned to the ``name`` trait.
        nodes: sequence
            Initial children, assigned to the ``nodes`` trait. The default
            is an immutable empty tuple, so no mutable object is shared
            between calls.
        **kwargs
            Passed on to the ``Widget`` constructor.
        """
        super(Node, self).__init__(**kwargs)

        self.name = name
        self.nodes = nodes

    @default('_id')
    def _default_id(self):
        """Default for ``_id``: a value produced by ``id_gen()``."""
        return id_gen()

    def add_node(self, node, position=None):
        """Insert ``node`` among the children.

        Parameters
        ----------
        node: Node
            The child to add.
        position: int, optional
            Insertion index; ``None`` or an index past the end appends.

        Raises
        ------
        TraitError
            If ``node`` is not a ``Node`` instance.
        """
        if not isinstance(node, Node):
            raise TraitError('The added node must be a Node instance')

        nodes = list(self.nodes)
        if position is None or position > len(nodes):
            position = len(nodes)
        nodes.insert(position, node)
        # Assign a new tuple rather than mutating the current value in place.
        self.nodes = tuple(nodes)

    def remove_node(self, node):
        """Remove ``node`` from the children, matching children by ``_id``.

        Raises
        ------
        RuntimeError
            If ``node`` is not among the current children.
        """
        if node not in self.nodes:
            raise RuntimeError('{} is not a children of {}'.format(
                node.name, self.name))
        self.nodes = tuple(n for n in self.nodes if n._id != node._id)
Exemple #14
0
class TerminalInteractiveShell(InteractiveShell):
    mime_renderers = Dict().tag(config=True)

    space_for_menu = Integer(
        6,
        help='Number of line at the bottom of the screen '
        'to reserve for the completion menu').tag(config=True)

    pt_app = None
    debugger_history = None

    simple_prompt = Bool(
        _use_simple_prompt,
        help=
        """Use `raw_input` for the REPL, without completion and prompt colors.

            Useful when controlling IPython as a subprocess, and piping STDIN/OUT/ERR. Known usage are:
            IPython own testing machinery, and emacs inferior-shell integration through elpy.

            This mode default to `True` if the `IPY_TEST_SIMPLE_PROMPT`
            environment variable is set, or the current terminal is not a tty."""
    ).tag(config=True)

    @property
    def debugger_cls(self):
        """Debugger class: ``Pdb`` with ``simple_prompt``, else ``TerminalPdb``."""
        if self.simple_prompt:
            return Pdb
        return TerminalPdb

    confirm_exit = Bool(
        True,
        help="""
        Set to confirm when you try to exit IPython with an EOF (Control-D
        in Unix, Control-Z/Enter in Windows). By typing 'exit' or 'quit',
        you can force a direct exit without any confirmation.""",
    ).tag(config=True)

    editing_mode = Unicode(
        'emacs',
        help="Shortcut style to use at the prompt. 'vi' or 'emacs'.",
    ).tag(config=True)

    autoformatter = Unicode(
        None,
        help=
        "Autoformatter to reformat Terminal code. Can be `'black'` or `None`",
        allow_none=True).tag(config=True)

    mouse_support = Bool(
        False,
        help=
        "Enable mouse support in the prompt\n(Note: prevents selecting text with the mouse)"
    ).tag(config=True)

    # We don't load the list of styles for the help string, because loading
    # Pygments plugins takes time and can cause unexpected errors.
    highlighting_style = Union(
        [Unicode('legacy'), Type(klass=Style)],
        help="""The name or class of a Pygments style to use for syntax
        highlighting. To see available styles, run `pygmentize -L styles`."""
    ).tag(config=True)

    @validate('editing_mode')
    def _validate_editing_mode(self, proposal):
        """Normalize a proposed editing mode.

        ``'vim'`` maps to ``'vi'`` and ``'default'`` to ``'emacs'``
        (case-insensitively). A value whose upper-cased name exists on
        ``EditingMode`` is returned lower-cased; anything else returns the
        current ``editing_mode``.
        """
        aliases = {'vim': 'vi', 'default': 'emacs'}
        lowered = proposal['value'].lower()
        if lowered in aliases:
            proposal['value'] = aliases[lowered]

        candidate = proposal['value']
        if hasattr(EditingMode, candidate.upper()):
            return candidate.lower()

        return self.editing_mode

    @observe('editing_mode')
    def _editing_mode(self, change):
        """Apply a changed ``editing_mode`` to the live prompt session, if any.

        The session receives the ``EditingMode`` member, the same way
        ``init_prompt_toolkit_cli`` passes it at construction time, rather
        than the bare upper-cased string.
        """
        if self.pt_app:
            # The validator only accepts names that exist on EditingMode,
            # so this lookup cannot fail for validated values.
            self.pt_app.editing_mode = getattr(EditingMode,
                                               change.new.upper())

    @observe('autoformatter')
    def _autoformatter_changed(self, change):
        """Select ``reformat_handler`` for the new ``autoformatter`` value.

        ``None`` installs an identity function and ``'black'`` installs
        ``black_reformat_handler``.

        Raises
        ------
        ValueError
            For any other value.
        """
        requested = change.new
        if requested is None:
            self.reformat_handler = lambda x: x
            return
        if requested != 'black':
            raise ValueError
        self.reformat_handler = black_reformat_handler

    # NOTE(review): two stacked observe decorators -- confirm that both
    # trait names actually end up registered with traitlets.
    @observe('highlighting_style')
    @observe('colors')
    def _highlighting_style_changed(self, change):
        """Rebuild the prompt style after a relevant trait change."""
        self.refresh_style()

    def refresh_style(self):
        """Recompute ``_style`` from the current ``highlighting_style``."""
        chosen_style = self.highlighting_style
        self._style = self._make_style_from_name_or_cls(chosen_style)

    highlighting_style_overrides = Dict(
        help="Override highlighting format for specific tokens").tag(
            config=True)

    true_color = Bool(
        False,
        help=("Use 24bit colors instead of 256 colors in prompt highlighting. "
              "If your terminal supports true color, the following command "
              "should print 'TRUECOLOR' in orange: "
              "printf \"\\x1b[38;2;255;100;0mTRUECOLOR\\x1b[0m\\n\"")).tag(
                  config=True)

    editor = Unicode(
        get_default_editor(),
        help="Set the editor used by IPython (default to $EDITOR/vi/notepad)."
    ).tag(config=True)

    prompts_class = Type(
        Prompts,
        help='Class used to generate Prompt token for prompt_toolkit').tag(
            config=True)

    prompts = Instance(Prompts)

    @default('prompts')
    def _prompts_default(self):
        """Default ``prompts``: an instance of ``prompts_class`` for this shell."""
        prompt_factory = self.prompts_class
        return prompt_factory(self)

#    @observe('prompts')
#    def _(self, change):
#        self._update_layout()

    @default('displayhook_class')
    def _displayhook_class_default(self):
        """Default ``displayhook_class``: ``RichPromptDisplayHook``."""
        hook_class = RichPromptDisplayHook
        return hook_class

    term_title = Bool(
        True, help="Automatically set the terminal title").tag(config=True)

    term_title_format = Unicode(
        "IPython: {cwd}",
        help=
        "Customize the terminal title format.  This is a python format string. "
        + "Available substitutions are: {cwd}.").tag(config=True)

    display_completions = Enum(
        ('column', 'multicolumn', 'readlinelike'),
        help=
        ("Options for displaying tab completions, 'column', 'multicolumn', and "
         "'readlinelike'. These options are for `prompt_toolkit`, see "
         "`prompt_toolkit` documentation for more information."),
        default_value='multicolumn').tag(config=True)

    highlight_matching_brackets = Bool(
        True,
        help="Highlight matching brackets.",
    ).tag(config=True)

    extra_open_editor_shortcuts = Bool(
        False,
        help=
        "Enable vi (v) or Emacs (C-X C-E) shortcuts to open an external editor. "
        "This is in addition to the F2 binding, which is always enabled.").tag(
            config=True)

    handle_return = Any(
        None,
        help="Provide an alternative handler to be called when the user presses "
        "Return. This is an advanced option intended for debugging, which "
        "may be changed or removed in later releases.").tag(config=True)

    enable_history_search = Bool(
        True,
        help="Allows to enable/disable the prompt toolkit history search").tag(
            config=True)

    prompt_includes_vi_mode = Bool(
        True,
        help="Display the current vi mode (when using vi editing mode).").tag(
            config=True)

    @observe('term_title')
    def init_term_title(self, change=None):
        """Enable or disable terminal-title updates according to ``term_title``.

        When enabled, the title is also set right away from
        ``term_title_format`` with ``cwd`` filled in by ``abbrev_cwd()``.
        """
        if not self.term_title:
            toggle_set_term_title(False)
            return
        toggle_set_term_title(True)
        title = self.term_title_format.format(cwd=abbrev_cwd())
        set_term_title(title)

    def restore_term_title(self):
        """Call the module-level ``restore_term_title()`` if ``term_title`` is on."""
        if not self.term_title:
            return
        # Inside the method this name resolves to the module-level function.
        restore_term_title()

    def init_display_formatter(self):
        """Set up the display formatter, restricted to plain text output."""
        super(TerminalInteractiveShell, self).init_display_formatter()
        formatter = self.display_formatter
        # terminal only supports plain text
        formatter.active_types = ['text/plain']
        # disable `_ipython_display_`
        formatter.ipython_display_formatter.enabled = False

    def init_prompt_toolkit_cli(self):
        """Set up how code is read from the user.

        With ``simple_prompt`` a minimal ``input()``-based reader is
        installed as ``prompt_for_code`` and nothing else happens. Otherwise
        a ``PromptSession`` is built and stored in ``pt_app``, with history
        pre-loaded from the history manager.
        """
        if self.simple_prompt:
            # Fall back to plain non-interactive output for tests.
            # This is very limited.
            def prompt():
                first_prompt = "".join(
                    tok[1] for tok in self.prompts.in_prompt_tokens())
                collected = [input(first_prompt)]
                more_prompt = "".join(
                    tok[1]
                    for tok in self.prompts.continuation_prompt_tokens())
                # Keep reading lines until the input forms a complete block.
                while self.check_complete(
                        '\n'.join(collected))[0] == 'incomplete':
                    collected.append(input(more_prompt))
                return '\n'.join(collected)

            self.prompt_for_code = prompt
            return

        # Set up keyboard shortcuts
        shortcuts = create_ipython_shortcuts(self)

        # Pre-populate history from IPython's history database
        history = InMemoryHistory()
        previous_cell = u""
        tail = self.history_manager.get_tail(self.history_load_length,
                                             include_latest=True)
        for _first, _second, raw_cell in tail:
            cell = raw_cell.rstrip()
            # Ignore blank lines and consecutive duplicates
            if not cell or cell == previous_cell:
                continue
            history.append_string(cell)
            previous_cell = cell

        self._style = self._make_style_from_name_or_cls(
            self.highlighting_style)
        # The lambda reads self._style at call time, so a later change to
        # self._style is what the DynamicStyle callable returns.
        self.style = DynamicStyle(lambda: self._style)

        mode = getattr(EditingMode, self.editing_mode.upper())

        self.pt_app = PromptSession(
            editing_mode=mode,
            key_bindings=shortcuts,
            history=history,
            completer=IPythonPTCompleter(shell=self),
            enable_history_search=self.enable_history_search,
            style=self.style,
            include_default_pygments_style=False,
            mouse_support=self.mouse_support,
            enable_open_in_editor=self.extra_open_editor_shortcuts,
            color_depth=self.color_depth,
            **self._extra_prompt_options())

    def _make_style_from_name_or_cls(self, name_or_cls):
        """
        Build a prompt_toolkit style from a style name or a pygments style class.

        The pygments style is merged with overrides for the prompt tokens
        (``Token.Prompt`` and friends) and, last, with the user supplied
        ``highlighting_style_overrides``.

        Parameters
        ----------
        name_or_cls : str or style class
            ``'legacy'`` selects a style from ``self.colors`` (``linux``,
            ``lightbg``, ``neutral`` or ``nocolor``, compared
            case-insensitively); any other string is looked up with
            ``get_style_by_name``; anything else is used as the style class
            itself.

        Returns
        -------
        The merged style produced by ``merge_styles``.

        Raises
        ------
        ValueError
            If ``name_or_cls`` is ``'legacy'`` and ``self.colors`` is not one
            of the known legacy color schemes.
        """
        if name_or_cls == 'legacy':
            legacy = self.colors.lower()
            if legacy == 'linux':
                style_cls = get_style_by_name('monokai')
                # Copy: the overrides are updated with the user's settings
                # below, and the shared module-level dict must stay intact.
                style_overrides = dict(_style_overrides_linux)
            elif legacy == 'lightbg':
                # Copy for the same reason as above.
                style_overrides = dict(_style_overrides_light_bg)
                style_cls = get_style_by_name('pastie')
            elif legacy == 'neutral':
                # The default theme needs to be visible on both a dark background
                # and a light background, because we can't tell what the terminal
                # looks like. These tweaks to the default theme help with that.
                style_cls = get_style_by_name('default')
                style_overrides = {
                    Token.Number: '#007700',
                    Token.Operator: 'noinherit',
                    Token.String: '#BB6622',
                    Token.Name.Function: '#2080D0',
                    Token.Name.Class: 'bold #2080D0',
                    Token.Name.Namespace: 'bold #2080D0',
                    Token.Prompt: '#009900',
                    Token.PromptNum: '#ansibrightgreen bold',
                    Token.OutPrompt: '#990000',
                    Token.OutPromptNum: '#ansibrightred bold',
                }

                # Hack: Due to limited color support on the Windows console
                # the prompt colors will be wrong without this
                if os.name == 'nt':
                    style_overrides.update({
                        Token.Prompt: '#ansidarkgreen',
                        Token.PromptNum: '#ansigreen bold',
                        Token.OutPrompt: '#ansidarkred',
                        Token.OutPromptNum: '#ansired bold',
                    })
            elif legacy == 'nocolor':
                style_cls = _NoStyle
                style_overrides = {}
            else:
                # Format the name into the message instead of passing it as a
                # second exception argument (which rendered as a tuple).
                raise ValueError('Got unknown colors: %s' % legacy)
        else:
            if isinstance(name_or_cls, str):
                style_cls = get_style_by_name(name_or_cls)
            else:
                style_cls = name_or_cls
            style_overrides = {
                Token.Prompt: '#009900',
                Token.PromptNum: '#ansibrightgreen bold',
                Token.OutPrompt: '#990000',
                Token.OutPromptNum: '#ansibrightred bold',
            }

        # User supplied overrides win over everything chosen above.
        style_overrides.update(self.highlighting_style_overrides)
        return merge_styles([
            style_from_pygments_cls(style_cls),
            style_from_pygments_dict(style_overrides),
        ])

    @property
    def pt_complete_style(self):
        """Return the ``CompleteStyle`` member named by ``display_completions``.

        Raises ``KeyError`` if ``display_completions`` is not one of
        ``'multicolumn'``, ``'column'`` or ``'readlinelike'``.
        """
        styles_by_name = {
            'multicolumn': CompleteStyle.MULTI_COLUMN,
            'column': CompleteStyle.COLUMN,
            'readlinelike': CompleteStyle.READLINE_LIKE,
        }
        return styles_by_name[self.display_completions]

    @property
    def color_depth(self):
        """Return ``ColorDepth.TRUE_COLOR`` if ``true_color`` is set, else ``None``."""
        if self.true_color:
            return ColorDepth.TRUE_COLOR
        return None

    def _extra_prompt_options(self):
        """
        Build the keyword options describing the current prompt layout.

        Returns a dict holding the lexer, prompt message, continuation prompt,
        completion style and input processors.  The ``inputhook`` entry is
        only added when ``PTK3`` is false.
        """
        def in_prompt():
            return PygmentsTokens(self.prompts.in_prompt_tokens())

        def continuation(width, lineno, is_soft_wrap):
            return PygmentsTokens(
                self.prompts.continuation_prompt_tokens(width))

        # with emacs mode the prompt is (usually) static, so the tokens are
        # computed only once. With VI mode the prompt can toggle between
        # [ins] and [nor], so the callable itself is handed over instead.
        # This favors the default keybinding, which almost everybody uses,
        # to decrease CPU usage. If users with custom Prompts run into
        # issues we can see how to work around this.
        message = in_prompt() if self.editing_mode == 'emacs' else in_prompt

        options = dict(
            complete_in_thread=False,
            lexer=IPythonPTLexer(),
            reserve_space_for_menu=self.space_for_menu,
            message=message,
            prompt_continuation=continuation,
            multiline=True,
            complete_style=self.pt_complete_style,
            # Highlight matching brackets, but only when this setting is
            # enabled, and only when the DEFAULT_BUFFER has the focus.
            input_processors=[
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()
                    & Condition(lambda: self.highlight_matching_brackets))
            ],
        )
        if not PTK3:
            options['inputhook'] = self.inputhook

        return options

    def prompt_for_code(self):
        """Read one input from the prompt session and return its text.

        If ``rl_next_input`` holds a non-empty value, it is used as the
        initial content of the prompt and then reset to ``None``.
        """
        pending = self.rl_next_input
        if pending:
            self.rl_next_input = None
            initial_text = pending
        else:
            initial_text = ''

        with patch_stdout(raw=True):
            return self.pt_app.prompt(default=initial_text,
                                      **self._extra_prompt_options())

    def enable_win_unicode_console(self):
        """Deprecated no-op; only emits a ``DeprecationWarning``.

        Since IPython 7.10 doesn't support python < 3.6 and PEP 528, Python
        uses the unicode APIs for the Windows console by default, so WUC
        shouldn't be needed.
        """
        import warnings

        message = (
            "`enable_win_unicode_console` is deprecated since IPython 7.10, does not do anything and will be removed in the future"
        )
        warnings.warn(message, DeprecationWarning, stacklevel=2)

    def init_io(self):
        """Prepare colored output on ``win32`` / ``cli`` platforms.

        Does nothing on any other ``sys.platform``.  On those two, colorama
        is initialised and ``io.stdout`` / ``io.stderr`` are rebound to new
        ``io.IOStream`` wrappers around the current ``sys.stdout`` /
        ``sys.stderr``.
        """
        if sys.platform not in ('win32', 'cli'):
            return

        from colorama import init as colorama_init
        colorama_init()

        # For some reason we make these wrappers around stdout/stderr.
        # For now, we need to reset them so all output gets coloured.
        # https://github.com/ipython/ipython/issues/8669
        # io.std* are deprecated, but don't show our own deprecation warnings
        # during initialization of the deprecated API.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', DeprecationWarning)
            for stream_name in ('stdout', 'stderr'):
                wrapped = io.IOStream(getattr(sys, stream_name))
                setattr(io, stream_name, wrapped)

    def init_magics(self):
        """Run the parent's magic setup, then register ``TerminalMagics``."""
        parent = super(TerminalInteractiveShell, self)
        parent.init_magics()
        self.register_magics(TerminalMagics)

    def init_alias(self):
        """Define the parent's aliases plus a few terminal-only ones on posix."""
        # The parent class defines aliases that can be safely used with any
        # frontend.
        super(TerminalInteractiveShell, self).init_alias()

        if os.name != 'posix':
            return

        # These aliases only make sense on the terminal, because they need
        # direct access to the console in a way that can't be emulated in a
        # GUI or web frontend.
        define = self.alias_manager.soft_define_alias
        for command in ('clear', 'more', 'less', 'man'):
            define(command, command)

    def __init__(self, *args, **kwargs):
        """Initialise the shell, then its prompt machinery and terminal title.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        super(TerminalInteractiveShell, self).__init__(*args, **kwargs)
        # Builds the prompt session, or installs the simple ``input()``
        # fallback when ``simple_prompt`` is set.
        self.init_prompt_toolkit_cli()
        self.init_term_title()
        # Loop flag read by interact() and cleared by ask_exit().
        self.keep_running = True

        # Separate in-memory history; presumably for the debugger prompt
        # (not used in this part of the file -- confirm).
        self.debugger_history = InMemoryHistory()

    def ask_exit(self):
        """Request that the ``interact`` loop stop by clearing ``keep_running``."""
        self.keep_running = False

    # Text to pre-fill the next prompt with; prompt_for_code() consumes it and
    # resets it to None.
    rl_next_input = None

    def interact(self, display_banner=DISPLAY_BANNER_DEPRECATED):
        """Run the read/execute loop until ``keep_running`` becomes false.

        Each round prints ``separate_in``, reads code with
        ``prompt_for_code`` and runs any non-empty result with ``run_cell``.
        On ``EOFError`` the shell exits, after asking for confirmation when
        ``confirm_exit`` is set.

        Parameters
        ----------
        display_banner : deprecated
            Passing any value only triggers a ``DeprecationWarning``.
        """
        if display_banner is not DISPLAY_BANNER_DEPRECATED:
            warn(
                'interact `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.',
                DeprecationWarning,
                stacklevel=2)

        self.keep_running = True
        while self.keep_running:
            print(self.separate_in, end='')

            try:
                code = self.prompt_for_code()
            except EOFError:
                # Stay in the loop only if confirmation is required and the
                # user declined.
                if self.confirm_exit and not self.ask_yes_no(
                        'Do you really want to exit ([y]/n)?', 'y', 'n'):
                    continue
                self.ask_exit()
                continue

            if code:
                self.run_cell(code, store_history=True)

    def mainloop(self, display_banner=DISPLAY_BANNER_DEPRECATED):
        """Run ``interact`` and restart it whenever a KeyboardInterrupt escapes.

        This is an extra layer of protection in case someone mashing Ctrl-C
        breaks out of our internal code.  After every attempt the event loop
        (if one is stored in ``_eventloop``) is stopped and the terminal
        title is restored.

        Parameters
        ----------
        display_banner : deprecated
            Passing any value only triggers a ``DeprecationWarning``.
        """
        if display_banner is not DISPLAY_BANNER_DEPRECATED:
            warn(
                'mainloop `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.',
                DeprecationWarning,
                stacklevel=2)

        finished = False
        while not finished:
            try:
                self.interact()
                finished = True
            except KeyboardInterrupt as exc:
                print("\n%s escaped interact()\n" % type(exc).__name__)
            finally:
                # An interrupt during the eventloop will mess up the
                # internal state of the prompt_toolkit library.
                # Stopping the eventloop fixes this, see
                # https://github.com/ipython/ipython/pull/9867
                if hasattr(self, '_eventloop'):
                    self._eventloop.stop()

                self.restore_term_title()

    # Callable invoked by inputhook(); assigned by enable_gui(), None when no
    # hook is active.
    _inputhook = None

    def inputhook(self, context):
        """Forward ``context`` to the active input hook, if one is set.

        The hook's return value is discarded.
        """
        hook = self._inputhook
        if hook is None:
            return
        hook(context)

    # First item returned by get_inputhook_name_and_func() in enable_gui();
    # None when no GUI event loop integration is active.
    active_eventloop = None

    def enable_gui(self, gui=None):
        """Select the input hook for ``gui`` and install the matching event loop.

        Parameters
        ----------
        gui : str, optional
            Name passed to ``get_inputhook_name_and_func``.  A false value
            or ``'inline'`` clears both ``active_eventloop`` and the input
            hook.
        """
        if not gui or gui == 'inline':
            loop_name = hook = None
        else:
            loop_name, hook = get_inputhook_name_and_func(gui)
        self.active_eventloop = loop_name
        self._inputhook = hook

        if not PTK3:
            return

        # For prompt_toolkit 3.0. We have to create an asyncio event loop with
        # this inputhook.
        if self._inputhook:
            from prompt_toolkit.eventloop import set_eventloop_with_inputhook
            set_eventloop_with_inputhook(self._inputhook)
        else:
            import asyncio
            asyncio.set_event_loop(asyncio.new_event_loop())

    # Run !system commands directly, not through pipes, so terminal programs
    # work correctly.  ``system`` is bound to ``InteractiveShell.system_raw``
    # to achieve that.
    system = InteractiveShell.system_raw

    def auto_rewrite_input(self, cmd):
        """Overridden from the parent class to use fancy rewriting prompt.

        Prints the rewrite prompt followed by ``cmd``; does nothing when
        ``show_rewritten_input`` is false.  Without a prompt_toolkit session
        (``pt_app`` false) the prompt is printed as plain text.
        """
        if not self.show_rewritten_input:
            return

        tokens = self.prompts.rewrite_prompt_tokens()
        if not self.pt_app:
            plain_prompt = ''.join(text for _kind, text in tokens)
            print(plain_prompt, cmd, sep='')
            return

        print_formatted_text(PygmentsTokens(tokens),
                             end='',
                             style=self.pt_app.app.style)
        print(cmd)

    # Prompts object saved by switch_doctest_mode() while doctest mode is on;
    # None otherwise.
    _prompts_before = None

    def switch_doctest_mode(self, mode):
        """Switch prompts to classic for %doctest_mode.

        A true ``mode`` remembers the current prompts and installs
        ``ClassicPrompts``; a false ``mode`` puts the remembered prompts back,
        if there are any.
        """
        if mode:
            current = self.prompts
            self._prompts_before = current
            self.prompts = ClassicPrompts(self)
            return

        saved = self._prompts_before
        if saved:
            self.prompts = saved
            self._prompts_before = None
Exemple #15
0
class MarkerCluster(Layer):
    """Layer that holds a tuple of ``Marker`` instances."""
    # Names of the view and model classes; both traits are tagged sync=True.
    _view_name = Unicode('LeafletMarkerClusterView').tag(sync=True)
    _model_name = Unicode('LeafletMarkerClusterModel').tag(sync=True)

    # Tuple whose items must be Marker instances; tagged sync=True together
    # with the entries of `widget_serialization`.
    markers = Tuple(trait=Instance(Marker)).tag(sync=True,
                                                **widget_serialization)
Exemple #16
0
class PrefilterManager(Configurable):
    """Coordinator for IPython's input prefiltering.

    Every piece of user input passes through the prefilter before it is
    executed: lines of input go in and transformed lines come out.

    The work is split into two phases:

    1. Transformers
    2. Checkers and handlers

    The plan is to deprecate checkers and handlers over time and to do
    everything in the transformers.

    Transformers are :class:`PrefilterTransformer` instances.  Each offers a
    single :meth:`transform` method that receives a line and returns the
    transformed line.  Any technique may be used for the transformation; the
    current ones rely on regular expressions for speed.

    Once every transformer has run, the line is offered to the checkers
    (:class:`PrefilterChecker` instances) through their :meth:`check`
    method, which returns either `None` or a :class:`PrefilterHandler`
    instance.  On `None` the next checker is consulted; otherwise the line
    is given to the :meth:`handle` method of the returned handler and the
    remaining checkers are skipped.

    Transformers and checkers carry a `priority` attribute that fixes the
    order in which they are called (smaller priorities first) and an
    `enabled` boolean attribute that decides whether the instance is used.

    Users and developers may change either attribute, but after changing a
    priority they must call :meth:`sort_checkers` or
    :meth:`sort_transformers`.
    """

    # When false, continuation lines skip the transformers and always go to
    # the 'normal' handler (see prefilter_line).
    multi_line_specials = Bool(True).tag(config=True)
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)

    def __init__(self, shell=None, **kwargs):
        """Create the manager and its default transformers, handlers and checkers.

        Parameters
        ----------
        shell : optional
            The shell stored on ``self.shell`` and handed to every default
            component.
        **kwargs
            Forwarded to the parent constructor.
        """
        super(PrefilterManager, self).__init__(shell=shell, **kwargs)
        self.shell = shell
        # Transformers first, then handlers, then checkers.
        for setup in (self.init_transformers, self.init_handlers,
                      self.init_checkers):
            setup()

    #-------------------------------------------------------------------------
    # API for managing transformers
    #-------------------------------------------------------------------------

    def init_transformers(self):
        """Reset the transformer list and instantiate the default transformers."""
        self._transformers = []
        # The instances are not stored here; presumably each one registers
        # itself with this manager from its constructor -- confirm.
        for factory in _default_transformers:
            factory(shell=self.shell, prefilter_manager=self, parent=self)

    def sort_transformers(self):
        """Order the transformers in place by ascending priority.

        Call this after changing the priority of a transformer.
        :meth:`register_transformer` does so automatically.
        """
        def by_priority(transformer):
            return transformer.priority

        self._transformers.sort(key=by_priority)

    @property
    def transformers(self):
        """Return the list of registered transformers."""
        return self._transformers

    def register_transformer(self, transformer):
        """Add a transformer instance (once) and re-sort the list."""
        if transformer in self._transformers:
            return
        self._transformers.append(transformer)
        self.sort_transformers()

    def unregister_transformer(self, transformer):
        """Remove a transformer instance; unknown instances are ignored."""
        try:
            self._transformers.remove(transformer)
        except ValueError:
            pass

    #-------------------------------------------------------------------------
    # API for managing checkers
    #-------------------------------------------------------------------------

    def init_checkers(self):
        """Reset the checker list and instantiate the default checkers."""
        self._checkers = []
        # As with the transformers, the instances are not stored here;
        # presumably they register themselves -- confirm.
        for factory in _default_checkers:
            factory(shell=self.shell, prefilter_manager=self, parent=self)

    def sort_checkers(self):
        """Order the checkers in place by ascending priority.

        Call this after changing the priority of a checker.
        :meth:`register_checker` does so automatically.
        """
        def by_priority(checker):
            return checker.priority

        self._checkers.sort(key=by_priority)

    @property
    def checkers(self):
        """Return the list of registered checkers."""
        return self._checkers

    def register_checker(self, checker):
        """Add a checker instance (once) and re-sort the list."""
        if checker in self._checkers:
            return
        self._checkers.append(checker)
        self.sort_checkers()

    def unregister_checker(self, checker):
        """Remove a checker instance; unknown instances are ignored."""
        try:
            self._checkers.remove(checker)
        except ValueError:
            pass

    #-------------------------------------------------------------------------
    # API for managing handlers
    #-------------------------------------------------------------------------

    def init_handlers(self):
        """Reset both handler tables and instantiate the default handlers."""
        self._handlers = {}
        self._esc_handlers = {}
        # The instances are not stored here; presumably they register
        # themselves through register_handler -- confirm.
        for factory in _default_handlers:
            factory(shell=self.shell, prefilter_manager=self, parent=self)

    @property
    def handlers(self):
        """Return the dict mapping handler names to handlers."""
        return self._handlers

    def register_handler(self, name, handler, esc_strings):
        """Store ``handler`` under ``name`` and under each of ``esc_strings``."""
        self._handlers[name] = handler
        self._esc_handlers.update((esc, handler) for esc in esc_strings)

    def unregister_handler(self, name, handler, esc_strings):
        """Forget the handler stored under ``name`` and its escape strings.

        An escape string is only released if it currently maps to this very
        ``handler`` object.
        """
        self._handlers.pop(name, None)
        for esc in esc_strings:
            if self._esc_handlers.get(esc) is handler:
                del self._esc_handlers[esc]

    def get_handler_by_name(self, name):
        """Return the handler registered under ``name``, or ``None``."""
        return self._handlers.get(name)

    def get_handler_by_esc(self, esc_str):
        """Return the handler registered for ``esc_str``, or ``None``."""
        return self._esc_handlers.get(esc_str)

    #-------------------------------------------------------------------------
    # Main prefiltering API
    #-------------------------------------------------------------------------

    def prefilter_line_info(self, line_info):
        """Prefilter a line that has already been turned into a LineInfo.

        This is the checker/handler half of the prefilter pipe: a handler is
        selected with :meth:`find_handler` and the result of its ``handle``
        method is returned.
        """
        return self.find_handler(line_info).handle(line_info)

    def find_handler(self, line_info):
        """Return the first truthy result of an enabled checker.

        Falls back to the handler named ``'normal'`` when no checker
        produces one.
        """
        for checker in self.checkers:
            if not checker.enabled:
                continue
            candidate = checker.check(line_info)
            if candidate:
                return candidate
        return self.get_handler_by_name('normal')

    def transform_line(self, line, continue_prompt):
        """Feed ``line`` through every enabled transformer, in list order."""
        for transformer in self.transformers:
            if not transformer.enabled:
                continue
            line = transformer.transform(line, continue_prompt)
        return line

    def prefilter_line(self, line, continue_prompt=False):
        """Prefilter a single line of input text.

        The line goes through the transformers and then through the
        checkers/handlers.  All handlers *must* return a value, even if it
        is blank ('').

        Parameters
        ----------
        line : str
            The input line.
        continue_prompt : bool
            Whether the line was entered at a continuation prompt.

        Returns
        -------
        The value returned by the selected handler, or ``''`` for an empty
        line.
        """
        # Keep the raw line around so a post-mortem handler can record it
        # if we crash.
        self.shell._last_input_line = line

        if not line:
            # A purely empty line returns at once.  If the user previously
            # typed some whitespace that started a continuation prompt, an
            # empty line breaks out of that loop, just like the default
            # python prompt.
            return ''

        # Transformers run on first lines, and on continuation lines only
        # when multi-line specials are enabled.
        if not continue_prompt or self.multi_line_specials:
            line = self.transform_line(line, continue_prompt)

        # Line info for the checkers and handlers.
        line_info = LineInfo(line, continue_prompt)

        # The input history needs to track even empty lines.
        stripped = line.strip()
        normal_handler = self.get_handler_by_name('normal')

        # Whitespace-only lines, and continuation lines when specials are
        # restricted to single line statements, go straight to 'normal'.
        if not stripped or (continue_prompt
                            and not self.multi_line_specials):
            return normal_handler.handle(line_info)

        return self.prefilter_line_info(line_info)

    def prefilter_lines(self, lines, continue_prompt=False):
        """Prefilter a block of text that may span several lines.

        This is the main entry point for prefiltering multiple lines of
        input; :meth:`prefilter_line` is called once per line.

        Several lines arrive at once when, for example, the user recalls a
        multiline history entry and presses enter.
        """
        pieces = lines.rstrip('\n').split('\n')
        if len(pieces) == 1:
            return self.prefilter_line(pieces[0], continue_prompt)

        # Multiline input can arrive 'blended' into one entry (e.g. recalled
        # from the readline history buffer).  Downstream code must be told
        # which line is the first one and which are continuation lines.
        filtered = [self.prefilter_line(pieces[0], False)]
        filtered.extend(
            self.prefilter_line(piece, True) for piece in pieces[1:])
        return '\n'.join(filtered)
Exemple #17
0
class LayerGroup(Layer):
    """Layer that holds a list of other ``Layer`` instances."""
    # Names of the view and model classes; both traits are tagged sync=True.
    _view_name = Unicode('LeafletLayerGroupView').tag(sync=True)
    _model_name = Unicode('LeafletLayerGroupModel').tag(sync=True)

    # List whose items must be Layer instances; tagged sync=True together
    # with the entries of `widget_serialization`.
    layers = List(Instance(Layer)).tag(sync=True, **widget_serialization)
Exemple #18
0
class PlotMesh(Mesh):
    """Mesh whose type, material and geometry are derived from ``plot``."""
    plot = Instance('sage.plot.plot3d.base.Graphics3d')

    def _plot_changed(self, name, old, new):
        """Rebuild ``type``, ``material`` and ``geometry`` from the new plot.

        ``new`` must provide ``scenetree_json()``.  Only ``new`` is used;
        ``name`` and ``old`` are ignored.
        """
        self.type = new.scenetree_json()['type']
        if self.type == 'object':
            geometry_node = new.scenetree_json()['geometry']
            pick_material = self.material_from_object
        else:
            geometry_node = new.scenetree_json()['children'][0]['geometry']
            pick_material = self.material_from_other
        self.type = geometry_node['type']
        self.material = pick_material(new)

        # Geometry types without a builder leave `geometry` untouched.
        builders = {
            'index_face_set': self.geometry_from_plot,
            'sphere': self.geometry_from_sphere,
            'box': self.geometry_from_box,
        }
        build = builders.get(self.type)
        if build is not None:
            self.geometry = build(new)

    def _material_from_texture(self, texture):
        """Return a double-sided Lambert material using the color and opacity of ``texture``."""
        material = MeshLambertMaterial(side='DoubleSide')
        material.color = texture['color']
        material.opacity = texture['opacity']
        # TODO: support other attributes
        return material

    def material_from_object(self, p):
        """Build the material from the scene tree of ``p.texture``."""
        # TODO: do this without scenetree_json()
        return self._material_from_texture(p.texture.scenetree_json())

    def material_from_other(self, p):
        """Build the material from the texture of the first child of ``p``."""
        # TODO: do this without scenetree_json()
        texture = p.scenetree_json()['children'][0]['texture']
        return self._material_from_texture(texture)

    def geometry_from_box(self, p):
        """Return a ``BoxGeometry`` sized from ``geometry.size`` of ``p``."""
        box = BoxGeometry()
        # The scene tree is queried once per dimension.
        for axis, attribute in enumerate(('width', 'height', 'depth')):
            setattr(box, attribute,
                    p.scenetree_json()['geometry']['size'][axis])
        return box

    def geometry_from_sphere(self, p):
        """Return a ``SphereGeometry`` with the radius of the first child of ``p``."""
        first_child = p.scenetree_json()['children'][0]
        sphere = SphereGeometry()
        sphere.radius = first_child['geometry']['radius']
        return sphere

    def geometry_from_plot(self, p):
        """Return a ``FaceGeometry`` built from the triangulated plot ``p``.

        Vertices are flattened into one list.  Faces are grouped by their
        number of indices; the flattened 3- and 4-index faces end up in
        ``face3`` and ``face4``.
        """
        from itertools import chain

        p.triangulate()

        geometry = FaceGeometry()
        geometry.vertices = list(chain.from_iterable(p.vertices()))

        faces = p.index_faces()
        # The sort is in place on the list returned by index_faces(); it is
        # stable, so faces of equal size keep their relative order.
        faces.sort(key=len)
        flat_by_size = {}
        for face in faces:
            flat_by_size.setdefault(len(face), []).extend(face)

        geometry.face3 = flat_by_size.get(3, [])
        geometry.face4 = flat_by_size.get(4, [])
        return geometry
Exemple #19
0
class Range(widgets.Widget):
    """Widget exposing a single ``value`` trait."""
    # Accepts either a plain list or a list whose items are themselves
    # lists; defaults to [0, 1] and is tagged sync=True.
    value = Union([List(), List(Instance(list))],
                  default_value=[0, 1]).tag(sync=True)
Exemple #20
0
class DisplayHook(Configurable):
    """The custom IPython displayhook to replace sys.displayhook.

    This class does many things, but the basic idea is that it is a callable
    that gets called anytime user code returns a value.
    """

    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
    exec_result = Instance('IPython.core.interactiveshell.ExecutionResult',
                           allow_none=True)
    cull_fraction = Float(0.2)

    def __init__(self, shell=None, cache_size=1000, **kwargs):
        """Set up the output cache settings and seed ``_``, ``__`` and ``___``.

        Parameters
        ----------
        shell : optional
            Shell whose ``user_ns`` receives the ``_`` variables.
        cache_size : int
            Sizes below 3 disable caching (``do_full_cache`` 0, cache size
            0); a positive size below 3 additionally triggers a warning.
        **kwargs
            Forwarded to the parent constructor.
        """
        super(DisplayHook, self).__init__(shell=shell, **kwargs)
        cache_size_min = 3
        non_positive = cache_size <= 0
        too_small = (not non_positive) and cache_size < cache_size_min
        if non_positive or too_small:
            self.do_full_cache = 0
            cache_size = 0
        else:
            self.do_full_cache = 1
        if too_small:
            warn('caching was disabled (min value for cache size is %s).' %
                 cache_size_min,
                 stacklevel=3)

        self.cache_size = cache_size

        # we need a reference to the user-level namespace
        self.shell = shell

        self._ = self.__ = self.___ = ''

        # these are deliberately global:
        self.shell.user_ns.update({
            '_': self._,
            '__': self.__,
            '___': self.___,
        })

    @property
    def prompt_count(self):
        """Return the shell's current ``execution_count``."""
        shell = self.shell
        return shell.execution_count

    #-------------------------------------------------------------------------
    # Methods used in __call__. Override these methods to modify the behavior
    # of the displayhook.
    #-------------------------------------------------------------------------

    def check_for_underscore(self):
        """Check if the user has set the '_' variable by hand."""
        # If something injected a '_' variable in __builtin__, delete
        # ipython's automatic one so we don't clobber that.  gettext() in
        # particular uses _, so we need to stay away from it.
        if '_' not in builtin_mod.__dict__:
            return

        user_ns = self.shell.user_ns
        try:
            del user_ns['_']
        except KeyError:
            # Already absent from the user namespace.
            pass

    def quiet(self):
        """Should we silence the display hook because of ';'?

        Tokenizes the most recent entry of the parsed input history and
        inspects its last token that is not an end marker, a newline or a
        comment.

        Returns
        -------
        bool
            True if that token is the ``;`` operator.  False otherwise,
            including when the history is empty or the cell contains no
            such token at all.
        """
        # do not print output if input ends in ';'

        try:
            cell = cast_unicode_py2(
                self.shell.history_manager.input_hist_parsed[-1])
        except IndexError:
            # some uses of ipshellembed may fail here
            return False

        sio = _io.StringIO(cell)
        tokens = list(tokenize.generate_tokens(sio.readline))

        ignorable = (tokenize.ENDMARKER, tokenize.NL, tokenize.NEWLINE,
                     tokenize.COMMENT)
        for token in reversed(tokens):
            if token[0] in ignorable:
                continue
            # The first significant token from the end decides.
            return (token[0] == tokenize.OP) and (token[1] == ';')

        # Only ignorable tokens (e.g. a blank or comment-only cell): answer
        # with an explicit False instead of falling through to None.
        return False

    def start_displayhook(self):
        """Start the displayhook, initializing resources.

        The default implementation does nothing.
        """

    def write_output_prompt(self):
        """Write the output prompt.

        The default implementation writes ``separate_out`` to
        ``sys.stdout``, followed by ``Out[N]: `` when ``do_full_cache`` is
        set.
        """
        # Use write, not print which adds an extra space.
        stream = sys.stdout
        stream.write(self.shell.separate_out)
        prompt_text = 'Out[{}]: '.format(self.shell.execution_count)
        if not self.do_full_cache:
            return
        stream.write(prompt_text)

    def compute_format_data(self, result):
        """Compute the format data of the object to be displayed.

        Format data generalizes the :func:`repr` of an object.  In the
        default implementation it is a :class:`dict` keyed by MIME type
        whose values are JSON'able structures holding the raw data for that
        MIME type.  Frontends decide which MIME type to use and how to
        display the data.

        This method only computes the data; it must NOT print it or write
        it to a stream.

        Parameters
        ----------
        result : object
            The Python object passed to the display hook, whose format will
            be computed.

        Returns
        -------
        The return value of ``self.shell.display_formatter.format(result)``,
        documented as the pair ``(format_dict, md_dict)``:
        ``format_dict`` maps MIME types to JSON'able raw data (it is
        recommended that "text/plain" is always included) and ``md_dict``
        maps the same MIME types to metadata for each output.
        """
        formatter = self.shell.display_formatter
        return formatter.format(result)

    # This can be set to True by the write_output_prompt method in a subclass.
    # write_format_data() reads it: when True, no extra leading newline is
    # put in front of multi-line output.
    prompt_end_newline = False

    def write_format_data(self, format_dict, md_dict=None):
        """Write the format data dict to the frontend.

        This default version prints only the plain text representation of
        the object.  Subclasses should override it to send the entire
        `format_dict` to the frontends.

        Parameters
        ----------
        format_dict : dict
            The format dict for the object passed to `sys.displayhook`.
        md_dict : dict (optional)
            The metadata dict to be associated with the display data; not
            used by this implementation.
        """
        if 'text/plain' not in format_dict:
            # nothing to do
            return

        text = format_dict['text/plain']
        # Start multi-line output on a fresh line so it lines up with the
        # left column of the screen instead of having the output prompt
        # mess up its first line -- unless the prompt already ended with a
        # newline, which would give an extraneous empty line.
        if '\n' in text and not self.prompt_end_newline:
            text = '\n' + text

        # print() is used on purpose: it guarantees a trailing newline even
        # if all the prompt separators are ''.  This is the standard
        # IPython behavior.
        print(text)

    def update_user_ns(self, result):
        """Record `result` in the user namespace as ``_``, ``__``, ``___`` and ``_N``."""
        shell = self.shell

        # Displaying _oh/Out itself must not be recorded, otherwise the
        # history would end up holding a reference to itself.
        if result is shell.user_ns['_oh']:
            return

        cache_is_full = len(shell.user_ns['_oh']) >= self.cache_size
        if cache_is_full and self.do_full_cache:
            self.cull_cache()

        # Leave '_' and friends alone when '_' lives in builtins; overwriting
        # them there causes buggy behavior for things like gettext.
        if '_' not in builtin_mod.__dict__:
            # Shift the three most recent results down by one slot.
            self.___, self.__, self._ = self.__, self._, result
            recent = {'_': self._, '__': self.__, '___': self.___}
            shell.push(recent, interactive=False)

        # Hackish access to the top-level namespace to create _1, _2, ...
        # dynamically; only done when the full output cache is enabled.
        if self.do_full_cache:
            numbered = {'_' + repr(self.prompt_count): result}
            shell.push(numbered, interactive=False)
            shell.user_ns['_oh'][self.prompt_count] = result

    def fill_exec_result(self, result):
        """Store `result` on the pending execution result, if there is one."""
        pending = self.exec_result
        if pending is None:
            return
        pending.result = result

    def log_output(self, format_dict):
        """Log the plain-text output and remember it in the history manager.

        Nothing happens when `format_dict` has no ``text/plain`` entry.
        """
        if 'text/plain' in format_dict:
            text = format_dict['text/plain']
            shell = self.shell
            # Writing to the session log is optional ...
            if shell.logger.log_output:
                shell.logger.log_write(text, 'output')
            # ... but the repr is always recorded for the current prompt.
            shell.history_manager.output_hist_reprs[self.prompt_count] = text

    def finish_displayhook(self):
        """Finish up all displayhook activities.

        Writes the shell's ``separate_out2`` separator to ``sys.stdout``
        and flushes the stream.
        """
        stream = sys.stdout
        stream.write(self.shell.separate_out2)
        stream.flush()

    def __call__(self, result=None):
        """Display `result` and manage the output history cache.

        This is invoked every time the interpreter needs to print, and is
        activated by setting the variable sys.displayhook to it.
        """
        self.check_for_underscore()

        # Nothing is displayed for None or while output is quieted.
        if result is None or self.quiet():
            return

        self.start_displayhook()
        self.write_output_prompt()
        formatted = self.compute_format_data(result)
        format_dict, md_dict = formatted
        self.update_user_ns(result)
        self.fill_exec_result(result)
        # An empty format dict means there is nothing to write or log.
        if format_dict:
            self.write_format_data(format_dict, md_dict)
            self.log_output(format_dict)
        self.finish_displayhook()

    def cull_cache(self):
        """Output cache is full, cull the oldest entries.

        Removes the ``cull_fraction`` oldest entries (at least two) from
        ``_oh`` together with their ``_N`` variables, after emitting a
        warning.
        """
        user_ns = self.shell.user_ns
        history = user_ns.get('_oh', {})
        size = len(history)
        n_drop = max(int(size * self.cull_fraction), 2)

        message = ('Output cache limit (currently {sz} entries) hit.\n'
                   'Flushing oldest {cull_count} entries.')
        warn(message.format(sz=size, cull_count=n_drop))

        # sorted() makes a copy, so the history can be mutated while the
        # oldest (lowest-numbered) keys are walked.
        for number in sorted(history)[:n_drop]:
            self.shell.user_ns.pop('_%i' % number, None)
            history.pop(number, None)

    def flush(self):
        """Drop every cached output and reset ``_``, ``__`` and ``___``.

        Raises
        ------
        ValueError
            If full caching is not enabled (``do_full_cache`` is false).
        """
        if not self.do_full_cache:
            raise ValueError("You shouldn't have reached the cache flush "
                             "if full caching is not enabled!")

        # Delete the auto-generated _1, _2, ... variables from the user
        # namespace.  Some of them may never have been created or may already
        # have been culled, so a missing key is expected and ignored; any
        # other error (including KeyboardInterrupt) is allowed to propagate
        # instead of being silently swallowed.
        user_ns = self.shell.user_ns
        for n in range(1, self.prompt_count + 1):
            user_ns.pop('_' + repr(n), None)

        # In some embedded circumstances, the user_ns doesn't have the
        # '_oh' key set up.
        oh = user_ns.get('_oh', None)
        if oh is not None:
            oh.clear()

        # Release our own references to objects:
        self._, self.__, self.___ = '', '', ''

        # Only touch '_' and friends in the user namespace when '_' is not
        # provided by builtins.
        if '_' not in builtin_mod.__dict__:
            user_ns.update({'_': None, '__': None, '___': None})
        import gc
        # TODO: Is this really needed?
        # IronPython blocks here forever
        if sys.platform != "cli":
            gc.collect()
class LocalProcessSpawner(Spawner):
    """
    A Spawner that uses `subprocess.Popen` to start single-user servers as local processes.

    Requires local UNIX users matching the authenticated users to exist.
    Does not work on Windows.

    This is the default spawner for JupyterHub.
    """

    INTERRUPT_TIMEOUT = Integer(10,
                                help="""
        Seconds to wait for single-user server process to halt after SIGINT.

        If the process has not exited cleanly after this many seconds, a SIGTERM is sent.
        """).tag(config=True)

    TERM_TIMEOUT = Integer(5,
                           help="""
        Seconds to wait for single-user server process to halt after SIGTERM.

        If the process does not exit cleanly after this many seconds of SIGTERM, a SIGKILL is sent.
        """).tag(config=True)

    KILL_TIMEOUT = Integer(5,
                           help="""
        Seconds to wait for process to halt after SIGKILL before giving up.

        If the process does not exit cleanly after this many seconds of SIGKILL, it becomes a zombie
        process. The hub process will log a warning and then give up.
        """).tag(config=True)

    popen_kwargs = Dict(help="""Extra keyword arguments to pass to Popen

        when spawning single-user servers.

        For example::

            popen_kwargs = dict(shell=True)

        """).tag(config=True)
    # Tagged as configurable like the other user-facing traits: the help
    # text below documents setting it via ``c.LocalProcessSpawner.shell_cmd``.
    shell_cmd = Command(minlen=0,
                        help="""Specify a shell command to launch.

        The single-user command will be appended to this list,
        so it sould end with `-c` (for bash) or equivalent.

        For example::

            c.LocalProcessSpawner.shell_cmd = ['bash', '-l', '-c']

        to launch with a bash login shell, which would set up the user's own complete environment.

        .. warning::

            Using shell_cmd gives users control over PATH, etc.,
            which could change what the jupyterhub-singleuser launch command does.
            Only use this for trusted users.
        """).tag(config=True)

    # Runtime state, deliberately not configurable.
    proc = Instance(Popen,
                    allow_none=True,
                    help="""
        The process representing the single-user server process spawned for current user.

        Is None if no process has been spawned yet.
        """)
    pid = Integer(0,
                  help="""
        The process id (pid) of the single-user server process spawned for current user.
        """)

    def make_preexec_fn(self, name):
        """
        Return a function that can be used to set the user id of the spawned process to user with name `name`

        This function can be safely passed to `preexec_fn` of `Popen`
        """
        return set_user_setuid(name)

    def load_state(self, state):
        """Restore state about spawned single-user server after a hub restart.

        Local processes only need the process id.
        """
        super().load_state(state)
        if 'pid' in state:
            self.pid = state['pid']

    def get_state(self):
        """Save state that is needed to restore this spawner instance after a hub restore.

        Local processes only need the process id.
        """
        state = super().get_state()
        # pid 0 means "no process", so it is not worth persisting.
        if self.pid:
            state['pid'] = self.pid
        return state

    def clear_state(self):
        """Clear stored state about this spawner (pid)"""
        super().clear_state()
        self.pid = 0

    def user_env(self, env):
        """Augment environment of spawned process with user specific env variables.

        Sets ``USER`` and, when the password database defines them,
        ``HOME`` and ``SHELL``.  `env` is modified in place and returned.
        """
        import pwd
        env['USER'] = self.user.name
        # One password-database lookup serves both fields.
        pw_entry = pwd.getpwnam(self.user.name)
        home = pw_entry.pw_dir
        shell = pw_entry.pw_shell
        # These will be empty if undefined,
        # in which case don't set the env:
        if home:
            env['HOME'] = home
        if shell:
            env['SHELL'] = shell
        return env

    def get_env(self):
        """Get the complete set of environment variables to be set in the spawned process."""
        env = super().get_env()
        env = self.user_env(env)
        return env

    @gen.coroutine
    def start(self):
        """Start the single-user server.

        Returns
        -------
        (ip, port) : tuple
            The address the single-user server was told to use; the ip
            falls back to ``'127.0.0.1'`` when `self.ip` is empty.

        Raises
        ------
        PermissionError
            Re-raised (after logging) when the command cannot be executed.
        """
        # `pipes.quote` was only an alias of `shlex.quote`, and the `pipes`
        # module is deprecated, so use the real function directly.
        from shlex import quote

        self.port = random_port()
        cmd = []
        env = self.get_env()

        cmd.extend(self.cmd)
        cmd.extend(self.get_args())

        if self.shell_cmd:
            # using shell_cmd (e.g. bash -c),
            # add our cmd list as the last (single) argument:
            cmd = self.shell_cmd + [' '.join(quote(s) for s in cmd)]

        self.log.info("Spawning %s", ' '.join(quote(s) for s in cmd))

        popen_kwargs = dict(
            preexec_fn=self.make_preexec_fn(self.user.name),
            start_new_session=True,  # don't forward signals
        )
        popen_kwargs.update(self.popen_kwargs)
        # don't let user config override env
        popen_kwargs['env'] = env
        try:
            self.proc = Popen(cmd, **popen_kwargs)
        except PermissionError:
            # use which to get abspath
            script = shutil.which(cmd[0]) or cmd[0]
            self.log.error(
                "Permission denied trying to run %r. Does %s have access to this file?",
                script,
                self.user.name,
            )
            raise

        self.pid = self.proc.pid

        if self.__class__ is not LocalProcessSpawner:
            # subclasses may not pass through return value of super().start,
            # relying on deprecated 0.6 way of setting ip, port,
            # so keep a redundant copy here for now.
            # A deprecation warning will be shown if the subclass
            # does not return ip, port.
            if self.ip:
                self.server.ip = self.ip
            self.server.port = self.port
        return (self.ip or '127.0.0.1', self.port)

    @gen.coroutine
    def poll(self):
        """Poll the spawned process to see if it is still running.

        If the process is still running, we return None. If it is not running,
        we return the exit code of the process if we have access to it, or 0 otherwise.
        """
        # if we started the process, poll with Popen
        if self.proc is not None:
            status = self.proc.poll()
            if status is not None:
                # clear state if the process is done
                self.clear_state()
            return status

        # if we resumed from stored state,
        # we don't have the Popen handle anymore, so rely on self.pid
        if not self.pid:
            # no pid, not running
            self.clear_state()
            return 0

        # send signal 0 to check if PID exists
        # this doesn't work on Windows, but that's okay because we don't support Windows.
        alive = yield self._signal(0)
        if not alive:
            self.clear_state()
            return 0
        else:
            return None

    @gen.coroutine
    def _signal(self, sig):
        """Send given signal to a single-user server's process.

        Returns True if the process still exists, False otherwise.
        Any `OSError` other than "no such process" is re-raised.

        The hub process is assumed to have enough privileges to do this (e.g. root).
        """
        try:
            os.kill(self.pid, sig)
        except OSError as e:
            if e.errno == errno.ESRCH:
                return False  # process is gone
            else:
                raise
        return True  # process exists

    @gen.coroutine
    def stop(self, now=False):
        """Stop the single-user server process for the current user.

        If `now` is False (default), shutdown the server as gracefully as possible,
        e.g. starting with SIGINT, then SIGTERM, then SIGKILL.
        If `now` is True, terminate the server immediately.

        The coroutine should return when the process is no longer running.
        """
        if not now:
            status = yield self.poll()
            if status is not None:
                return
            self.log.debug("Interrupting %i", self.pid)
            yield self._signal(signal.SIGINT)
            yield self.wait_for_death(self.INTERRUPT_TIMEOUT)

        # clean shutdown failed, use TERM
        status = yield self.poll()
        if status is not None:
            return
        self.log.debug("Terminating %i", self.pid)
        yield self._signal(signal.SIGTERM)
        yield self.wait_for_death(self.TERM_TIMEOUT)

        # TERM failed, use KILL
        status = yield self.poll()
        if status is not None:
            return
        self.log.debug("Killing %i", self.pid)
        yield self._signal(signal.SIGKILL)
        yield self.wait_for_death(self.KILL_TIMEOUT)

        status = yield self.poll()
        if status is None:
            # it all failed, zombie process
            self.log.warning("Process %i never died", self.pid)
Exemple #22
0
class IPEngine(BaseParallelApplication):

    name = 'ipengine'
    description = _description
    examples = _examples
    classes = List([ZMQInteractiveShell, ProfileDir, Session, Kernel])
    _deprecated_classes = ["EngineFactory", "IPEngineApp"]

    enable_nanny = Bool(
        True,
        config=True,
        help="""Enable the nanny process.

    The nanny process enables remote signaling of single engines
    and more responsive notification of engine shutdown.

    .. versionadded:: 7.0

    """,
    )

    startup_script = Unicode(
        '', config=True, help='specify a script to be run at startup'
    )
    startup_command = Unicode(
        '', config=True, help='specify a command to be run at startup'
    )

    url_file = Unicode(
        '',
        config=True,
        help="""The full location of the file containing the connection information for
        the controller. If this is not given, the file must be in the
        security directory of the cluster directory.  This location is
        resolved using the `profile` or `profile_dir` options.""",
    )
    wait_for_url_file = Float(
        10,
        config=True,
        help="""The maximum number of seconds to wait for url_file to exist.
        This is useful for batch-systems and shared-filesystems where the
        controller and engine are started at the same time and it
        may take a moment for the controller to write the connector files.""",
    )

    url_file_name = Unicode('ipcontroller-engine.json', config=True)

    connection_info_env = Unicode()

    @default("connection_info_env")
    def _default_connection_file_env(self):
        """Default `connection_info_env` to ``$IPP_CONNECTION_INFO`` ('' if unset)."""
        return os.getenv("IPP_CONNECTION_INFO", "")

    @observe('cluster_id')
    def _cluster_id_changed(self, change):
        """Keep `url_file_name` in sync with the cluster id.

        A non-empty cluster id is inserted into the file name as
        ``ipcontroller-<id>-engine.json``; otherwise the plain
        ``ipcontroller-engine.json`` is used.
        """
        cluster_id = change['new']
        prefix = f'ipcontroller-{cluster_id}' if cluster_id else 'ipcontroller'
        self.url_file_name = f"{prefix}-engine.json"

    log_url = Unicode(
        '',
        config=True,
        help="""The URL for the iploggerapp instance, for forwarding
        logging to a central location.""",
    )

    registration_url = Unicode(
        config=True,
        help="""Override the registration URL""",
    )
    out_stream_factory = Type(
        'ipykernel.iostream.OutStream',
        config=True,
        help="""The OutStream for handling stdout/err.
        Typically 'ipykernel.iostream.OutStream'""",
    )
    display_hook_factory = Type(
        'ipykernel.displayhook.ZMQDisplayHook',
        config=True,
        help="""The class for handling displayhook.
        Typically 'ipykernel.displayhook.ZMQDisplayHook'""",
    )
    location = Unicode(
        config=True,
        help="""The location (an IP address) of the controller.  This is
        used for disambiguating URLs, to determine whether
        loopback should be used to connect or the public address.""",
    )
    timeout = Float(
        5.0,
        config=True,
        help="""The time (in seconds) to wait for the Controller to respond
        to registration requests before giving up.""",
    )
    max_heartbeat_misses = Integer(
        50,
        config=True,
        help="""The maximum number of times a check for the heartbeat ping of a
        controller can be missed before shutting down the engine.

        If set to 0, the check is disabled.""",
    )
    sshserver = Unicode(
        config=True,
        help="""The SSH server to use for tunneling connections to the Controller.""",
    )
    sshkey = Unicode(
        config=True,
        help="""The SSH private key file to use when tunneling connections to the Controller.""",
    )
    paramiko = Bool(
        sys.platform == 'win32',
        config=True,
        help="""Whether to use paramiko instead of openssh for tunnels.""",
    )

    use_mpi = Bool(
        False,
        config=True,
        help="""Enable MPI integration.

        If set, MPI rank will be requested for my rank,
        and additionally `mpi_init` will be executed in the interactive shell.
        """,
    )
    init_mpi = Unicode(
        DEFAULT_MPI_INIT,
        config=True,
        help="""Code to execute in the user namespace when initializing MPI""",
    )
    mpi_registration_delay = Float(
        0.02,
        config=True,
        help="""Per-engine delay for mpiexec-launched engines

        avoids flooding the controller with registrations,
        which can stall under heavy load.

        Default: .02 (50 engines/sec, or 3000 engines/minute)
        """,
    )

    # not configurable (applies to user_ns only, which has no config=True):
    user_ns = Dict()
    # `id`, by contrast, is declared with config=True; when it is left
    # unset its value comes from the @default('id') handler.
    id = Integer(
        None,
        allow_none=True,
        config=True,
        help="""Request this engine ID.

        If run in MPI, will use the MPI rank.
        Otherwise, let the Hub decide what our rank should be.
        """,
    )

    @default('id')
    def _id_default(self):
        """Default engine id: the MPI rank when running under MPI, else None.

        The rank is only used when `use_mpi` is set and the MPI world has
        more than one process; mpi4py is imported lazily for that case.
        """
        if self.use_mpi:
            from mpi4py import MPI

            world = MPI.COMM_WORLD
            if world.size > 1:
                self.log.debug("MPI rank = %i", world.rank)
                return world.rank
        return None

    # Stream used for the registration exchange; register() assigns the
    # ZMQStream that wraps its DEALER socket here.
    registrar = Instance('zmq.eventloop.zmqstream.ZMQStream', allow_none=True)
    kernel = Instance(Kernel, allow_none=True)
    hb_check_period = Integer()

    # States for the heartbeat monitoring
    # Initial values for monitored and pinged must satisfy "monitored > pinged == False" so that
    # during the first check no "missed" ping is reported. Must be floats for Python 3 compatibility.
    # _hb_last_pinged is refreshed with time.time() by _report_ping().
    _hb_last_pinged = 0.0
    _hb_last_monitored = 0.0
    _hb_missed_beats = 0
    # The zmq Stream which receives the pings from the Heart
    _hb_listener = None

    # Engine identity: `ident` defaults to the session id and `bident` is
    # its utf8-encoded form (see the @default/@observe handlers for them).
    bident = Bytes()
    ident = Unicode()

    @default("ident")
    def _default_ident(self):
        """Default `ident` to the id of this engine's session."""
        session = self.session
        return session.session

    @default("bident")
    def _default_bident(self):
        """Default `bident` to the utf8-encoded bytes of `ident`."""
        ident_text = self.ident
        return ident_text.encode("utf8")

    @observe("ident")
    def _ident_changed(self, change):
        """Recompute `bident` whenever `ident` changes."""
        refreshed = self._default_bident()
        self.bident = refreshed

    using_ssh = Bool(False)

    context = Instance(zmq.Context)

    @default("context")
    def _default_context(self):
        """Default `context` to the result of ``zmq.Context.instance()``."""
        context_cls = zmq.Context
        return context_cls.instance()

    # an IPKernelApp instance, used to setup listening for shell frontends
    kernel_app = Instance(IPKernelApp, allow_none=True)

    aliases = Dict(aliases)
    flags = Dict(flags)

    curve_serverkey = Bytes(
        config=True, help="The controller's public key for CURVE security"
    )
    curve_secretkey = Bytes(
        config=True,
        help="""The engine's secret key for CURVE security.
        Usually autogenerated on launch.""",
    )
    curve_publickey = Bytes(
        config=True,
        help="""The engine's public key for CURVE security.
        Usually autogenerated on launch.""",
    )

    @default("curve_serverkey")
    def _default_curve_serverkey(self):
        """Default to ``$IPP_CURVE_SERVERKEY`` as ascii bytes (empty if unset)."""
        from_env = os.getenv("IPP_CURVE_SERVERKEY", "")
        return from_env.encode("ascii")

    @default("curve_secretkey")
    def _default_curve_secretkey(self):
        """Default to ``$IPP_CURVE_SECRETKEY`` as ascii bytes (empty if unset)."""
        from_env = os.getenv("IPP_CURVE_SECRETKEY", "")
        return from_env.encode("ascii")

    @default("curve_publickey")
    def _default_curve_publickey(self):
        """Default to ``$IPP_CURVE_PUBLICKEY`` as ascii bytes (empty if unset)."""
        from_env = os.getenv("IPP_CURVE_PUBLICKEY", "")
        return from_env.encode("ascii")

    @validate("curve_publickey", "curve_secretkey", "curve_serverkey")
    def _cast_bytes(self, proposal):
        """Accept `str` values for the CURVE key traits by ascii-encoding them.

        Any other proposed value is returned unchanged.
        """
        value = proposal.value
        return value.encode("ascii") if isinstance(value, str) else value

    def _ensure_curve_keypair(self):
        """Generate a CURVE keypair unless both keys are already set."""
        if self.curve_secretkey and self.curve_publickey:
            return
        self.log.info("Generating new CURVE credentials")
        public_key, secret_key = zmq.curve_keypair()
        self.curve_publickey, self.curve_secretkey = public_key, secret_key

    def find_connection_file(self):
        """Set the url file.

        Here we don't try to actually see if it exists or is valid, as that
        is handled by the connection logic.
        """
        # An explicitly configured url_file always wins.
        if self.url_file:
            return
        # Otherwise fall back to ipcontroller-engine.json (or its
        # cluster-specific variant) in the profile's security directory.
        security_dir = self.profile_dir.security_dir
        self.url_file = os.path.join(security_dir, self.url_file_name)

    def load_connection_file(self):
        """Load connection settings from the JSON connector information.

        The settings are applied at a *lower* priority than command-line /
        config files, except for the basic URLs, serialization and key,
        where the JSON content takes top priority.

        The same content can be supplied through $IPP_CONNECTION_INFO, in
        which case the file is not read.
        """
        config = self.config

        # The environment variable, when set, replaces the file entirely.
        if self.connection_info_env:
            self.log.info("Loading connection info from $IPP_CONNECTION_INFO")
            info = json.loads(self.connection_info_env)
        else:
            self.log.info("Loading connection file %r", self.url_file)
            with open(self.url_file) as f:
                info = json.load(f)

        # allow hand-override of location for disambiguation
        # and ssh-server
        location_given_on_cli = 'IPEngine.location' in self.cli_config
        if not location_given_on_cli:
            self.location = info['location']
        if 'ssh' in info and not self.sshserver:
            self.sshserver = info["ssh"]

        # Rewrite the interface with a disambiguated ip address.
        proto, ip = info['interface'].split('://')
        ip = disambiguate_ip_address(ip, self.location)
        info['interface'] = f'{proto}://{ip}'

        if info.get('curve_serverkey'):
            # connection file takes precedence over env, if present and defined
            self.curve_serverkey = info['curve_serverkey'].encode('ascii')
        if not self.curve_serverkey:
            self.log.warning("Not using CurveZMQ security")
        else:
            self.log.info("Using CurveZMQ security")
            self._ensure_curve_keypair()

        # DO NOT allow override of basic URLs, serialization, or key
        # JSON file takes top priority there
        if info.get('key') or 'key' not in config.Session:
            config.Session.key = info.get('key', '').encode('utf8')
        session_config = config.Session
        session_config.signature_scheme = info['signature_scheme']

        self.registration_url = f"{info['interface']}:{info['registration']}"

        session_config.packer = info['pack']
        session_config.unpacker = info['unpack']
        # The session is (re)built only after its config section is final.
        self.session = Session(parent=self)

        self.log.debug("Config changed:")
        self.log.debug("%r", config)
        self.connection_info = info

    def bind_kernel(self, **kwargs):
        """Promote engine to listening kernel, accessible to frontends.

        Does nothing when a kernel app already exists.  `kwargs` are passed
        to `IPKernelApp`; ``config``, ``log``, ``profile_dir`` and
        ``session`` default to this engine's own values.
        """
        if self.kernel_app is not None:
            return

        self.log.info("Opening ports for direct connections as an IPython kernel")
        if self.curve_serverkey:
            self.log.warning("Bound kernel does not support CURVE security")

        engine_kernel = self.kernel

        inherited = {
            'config': self.config,
            'log': self.log,
            'profile_dir': self.profile_dir,
            'session': self.session,
        }
        for option, value in inherited.items():
            kwargs.setdefault(option, value)

        kernel_app = self.kernel_app = IPKernelApp(**kwargs)

        # allow IPKernelApp.instance():
        IPKernelApp._instance = kernel_app

        kernel_app.init_connection_file()

        # relevant contents of init_sockets:
        # shell channel
        kernel_app.shell_port = kernel_app._bind_socket(
            engine_kernel.shell_streams[0], kernel_app.shell_port
        )
        kernel_app.log.debug("shell ROUTER Channel on port: %i", kernel_app.shell_port)

        # iopub channel
        iopub = engine_kernel.iopub_socket
        # ipykernel 4.3 iopub_socket is an IOThread wrapper:
        if hasattr(iopub, 'socket'):
            iopub = iopub.socket
        kernel_app.iopub_port = kernel_app._bind_socket(iopub, kernel_app.iopub_port)
        kernel_app.log.debug("iopub PUB Channel on port: %i", kernel_app.iopub_port)

        # stdin channel gets a fresh ROUTER socket
        engine_kernel.stdin_socket = self.context.socket(zmq.ROUTER)
        kernel_app.stdin_port = kernel_app._bind_socket(
            engine_kernel.stdin_socket, kernel_app.stdin_port
        )
        kernel_app.log.debug("stdin ROUTER Channel on port: %i", kernel_app.stdin_port)

        # start the heartbeat, and log connection info:
        kernel_app.init_heartbeat()

        kernel_app.log_connection_info()
        kernel_app.connection_dir = self.profile_dir.security_dir
        kernel_app.write_connection_file()

    @property
    def tunnel_mod(self):
        """The ``zmq.ssh.tunnel`` module, imported on first access."""
        from zmq.ssh import tunnel as ssh_tunnel

        return ssh_tunnel

    def init_connector(self):
        """Construct the connection functions, which handle tunnels.

        Returns
        -------
        (connect, maybe_tunnel) : tuple of callables
            ``connect(s, url, curve_serverkey=None)`` connects socket `s`
            to `url`, applying CURVE options when a server key is in use
            and tunneling over ssh when ssh is in use.
            ``maybe_tunnel(url)`` only opens the tunnel (if any) and
            returns the url to connect to.
        """
        self.using_ssh = bool(self.sshkey or self.sshserver)

        if self.sshkey and not self.sshserver:
            # We are using ssh directly to the controller, tunneling localhost to localhost
            host_and_port = self.registration_url.split('://')[1]
            self.sshserver = host_and_port.split(':')[0]

        # `password` stays False unless ssh is in use and a passwordless
        # login is not possible; then the user is prompted once, here.
        password = False
        if self.using_ssh:
            passwordless = self.tunnel_mod.try_passwordless_ssh(
                self.sshserver, self.sshkey, self.paramiko
            )
            if not passwordless:
                password = getpass("SSH Password for %s: " % self.sshserver)

        def connect(s, url, curve_serverkey=None):
            url = disambiguate_url(url, self.location)
            # Fall back to the engine-wide server key when none is given.
            serverkey = (
                self.curve_serverkey if curve_serverkey is None else curve_serverkey
            )
            if serverkey:
                s.setsockopt(zmq.CURVE_SERVERKEY, serverkey)
                s.setsockopt(zmq.CURVE_SECRETKEY, self.curve_secretkey)
                s.setsockopt(zmq.CURVE_PUBLICKEY, self.curve_publickey)

            if not self.using_ssh:
                return s.connect(url)

            self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
            return self.tunnel_mod.tunnel_connection(
                s,
                url,
                self.sshserver,
                keyfile=self.sshkey,
                paramiko=self.paramiko,
                password=password,
            )

        def maybe_tunnel(url):
            """like connect, but don't complete the connection (for use by heartbeat)"""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
                # Only the rewritten url is needed; the tunnel object is unused.
                url, _tunnel = self.tunnel_mod.open_tunnel(
                    url,
                    self.sshserver,
                    keyfile=self.sshkey,
                    paramiko=self.paramiko,
                    password=password,
                )
            return str(url)

        return connect, maybe_tunnel

    def register(self):
        """Send the registration_request to the controller."""
        delay_applies = (
            self.use_mpi
            and self.id
            and self.id >= 100
            and self.mpi_registration_delay
        )
        if delay_applies:
            # Some launchers implement delay at the Launcher level,
            # but mpiexec must implement it in the engine process itself:
            # delay based on our rank
            delay = self.id * self.mpi_registration_delay
            self.log.info(
                f"Delaying registration for {self.id} by {int(delay * 1000)}ms"
            )
            time.sleep(delay)

        self.log.info("Registering with controller at %s" % self.registration_url)
        connect, maybe_tunnel = self.init_connector()

        # DEALER socket carrying our identity, connected to the controller.
        reg_socket = self.context.socket(zmq.DEALER)
        reg_socket.setsockopt(zmq.IDENTITY, self.bident)
        connect(reg_socket, self.registration_url)

        self.registrar = zmqstream.ZMQStream(reg_socket, self.loop)

        request = dict(uuid=self.ident)
        if self.id is not None:
            self.log.info("Requesting id: %i", self.id)
            request['id'] = self.id
        self._registration_completed = False

        def _on_reply(msg):
            # The reply is handled with the same connection helpers.
            return self.complete_registration(msg, connect, maybe_tunnel)

        self.registrar.on_recv(_on_reply)

        self.session.send(self.registrar, "registration_request", content=request)

    def _report_ping(self, msg):
        """Callback for when the heartmonitor.Heart receives a ping.

        Records the current time as the moment of the last ping; `msg`
        itself is not inspected.
        """
        received_at = time.time()
        self._hb_last_pinged = received_at

    def complete_registration(self, msg, connect, maybe_tunnel):
        """Finish registration, exiting with status 255 if it fails.

        Any exception raised while handling the registration reply is
        logged at critical level (with traceback) before exiting.
        """
        try:
            self._complete_registration(msg, connect, maybe_tunnel)
        except Exception as exc:
            self.log.critical(
                f"Error completing registration: {exc}", exc_info=True
            )
            self.exit(255)

    def _complete_registration(self, msg, connect, maybe_tunnel):
        ctx = self.context
        loop = self.loop
        identity = self.bident
        idents, msg = self.session.feed_identities(msg)
        msg = self.session.deserialize(msg)
        content = msg['content']
        info = self.connection_info

        def url(key):
            """get zmq url for given channel"""
            return str(info["interface"] + ":%i" % info[key])

        def urls(key):
            return [f'{info["interface"]}:{port}' for port in info[key]]

        if content['status'] == 'ok':
            requested_id = self.id
            self.id = content['id']
            if requested_id is not None and self.id != requested_id:
                self.log.warning(
                    f"Did not get the requested id: {self.id} != {requested_id}"
                )
                self.log.name = self.log.name.rsplit(".", 1)[0] + f".{self.id}"
            elif self.id is None:
                self.log.name += f".{self.id}"

            # create Shell Connections (MUX, Task, etc.):

            # select which broadcast endpoint to connect to
            # use rank % len(broadcast_leaves)
            broadcast_urls = urls('broadcast')
            broadcast_leaves = len(broadcast_urls)
            broadcast_index = self.id % len(broadcast_urls)
            broadcast_url = broadcast_urls[broadcast_index]

            shell_addrs = [url('mux'), url('task'), broadcast_url]
            self.log.info(f'Shell_addrs: {shell_addrs}')

            # Use only one shell stream for mux and tasks
            stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            stream.setsockopt(zmq.IDENTITY, identity)
            # TODO: enable PROBE_ROUTER when schedulers can handle the empty message
            # stream.setsockopt(zmq.PROBE_ROUTER, 1)
            self.log.debug("Setting shell identity %r", identity)

            shell_streams = [stream]
            for addr in shell_addrs:
                self.log.info("Connecting shell to %s", addr)
                connect(stream, addr)

            # control stream:
            control_url = url('control')
            curve_serverkey = self.curve_serverkey
            if self.enable_nanny:
                nanny_url, self.nanny_pipe = self.start_nanny(
                    control_url=control_url,
                )
                control_url = nanny_url
                # nanny uses our curve_publickey, not the controller's publickey
                curve_serverkey = self.curve_publickey
            control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            control_stream.setsockopt(zmq.IDENTITY, identity)
            connect(control_stream, control_url, curve_serverkey=curve_serverkey)

            # create iopub stream:
            iopub_addr = url('iopub')
            iopub_socket = ctx.socket(zmq.PUB)
            iopub_socket.SNDHWM = 0
            iopub_socket.setsockopt(zmq.IDENTITY, identity)
            connect(iopub_socket, iopub_addr)
            try:
                from ipykernel.iostream import IOPubThread
            except ImportError:
                pass
            else:
                iopub_socket = IOPubThread(iopub_socket)
                iopub_socket.start()

            # disable history:
            self.config.HistoryManager.hist_file = ':memory:'

            # Redirect input streams and set a display hook.
            if self.out_stream_factory:
                sys.stdout = self.out_stream_factory(
                    self.session, iopub_socket, 'stdout'
                )
                sys.stdout.topic = f"engine.{self.id}.stdout".encode("ascii")
                sys.stderr = self.out_stream_factory(
                    self.session, iopub_socket, 'stderr'
                )
                sys.stderr.topic = f"engine.{self.id}.stderr".encode("ascii")

                # copied from ipykernel 6, which captures sys.__stderr__ at the FD-level
                if getattr(sys.stderr, "_original_stdstream_copy", None) is not None:
                    for handler in self.log.handlers:
                        if isinstance(handler, StreamHandler) and (
                            handler.stream.buffer.fileno() == 2
                        ):
                            self.log.debug(
                                "Seeing logger to stderr, rerouting to raw filedescriptor."
                            )

                            handler.stream = TextIOWrapper(
                                FileIO(sys.stderr._original_stdstream_copy, "w")
                            )
            if self.display_hook_factory:
                sys.displayhook = self.display_hook_factory(self.session, iopub_socket)
                sys.displayhook.topic = f"engine.{self.id}.execute_result".encode(
                    "ascii"
                )

            # patch Session to always send engine uuid metadata
            original_send = self.session.send

            def send_with_metadata(
                stream,
                msg_or_type,
                content=None,
                parent=None,
                ident=None,
                buffers=None,
                track=False,
                header=None,
                metadata=None,
                **kwargs,
            ):
                """Ensure all messages set engine uuid metadata"""
                metadata = metadata or {}
                metadata.setdefault("engine", self.ident)
                return original_send(
                    stream,
                    msg_or_type,
                    content=content,
                    parent=parent,
                    ident=ident,
                    buffers=buffers,
                    track=track,
                    header=header,
                    metadata=metadata,
                    **kwargs,
                )

            self.session.send = send_with_metadata

            self.kernel = Kernel.instance(
                parent=self,
                engine_id=self.id,
                ident=self.ident,
                session=self.session,
                control_stream=control_stream,
                shell_streams=shell_streams,
                iopub_socket=iopub_socket,
                loop=loop,
                user_ns=self.user_ns,
                log=self.log,
            )

            self.kernel.shell.display_pub.topic = f"engine.{self.id}.displaypub".encode(
                "ascii"
            )

            # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
            self.init_signal()
            app = IPKernelApp(
                parent=self, shell=self.kernel.shell, kernel=self.kernel, log=self.log
            )
            if self.use_mpi and self.init_mpi:
                app.exec_lines.insert(0, self.init_mpi)
            app.init_profile_dir()
            app.init_code()

            self.kernel.start()
        else:
            self.log.fatal("Registration Failed: %s" % msg)
            raise Exception("Registration Failed: %s" % msg)

        self.start_heartbeat(
            maybe_tunnel(url('hb_ping')),
            maybe_tunnel(url('hb_pong')),
            content['hb_period'],
            identity,
        )
        self.log.info("Completed registration with id %i" % self.id)
        self.loop.remove_timeout(self._abort_timeout)

    def start_nanny(self, control_url):
        """Launch the nanny for this engine via the module-level ``start_nanny``.

        Parameters
        ----------
        control_url
            Control URL handed on to the nanny unchanged.

        Returns
        -------
        Whatever the module-level ``start_nanny`` returns.
        """
        self.log.info("Starting nanny")

        # The nanny receives a fresh Config carrying only our Session section.
        nanny_config = Config()
        nanny_config.Session = self.config.Session

        nanny_kwargs = dict(
            engine_id=self.id,
            identity=self.bident,
            control_url=control_url,
            curve_serverkey=self.curve_serverkey,
            curve_secretkey=self.curve_secretkey,
            curve_publickey=self.curve_publickey,
            registration_url=self.registration_url,
            config=nanny_config,
        )
        return start_nanny(**nanny_kwargs)

    def start_heartbeat(self, hb_ping, hb_pong, hb_period, identity):
        """Start our heart beating and, if enabled, watch for controller pings.

        Parameters
        ----------
        hb_ping, hb_pong
            Passed as the first two positional arguments to ``Heart``.
        hb_period
            Heartbeat period; the local check runs every ``hb_period + 500``.
        identity
            Passed to ``Heart`` as ``heart_id``.

        Monitoring is only set up when ``self.max_heartbeat_misses > 0``.
        """
        monitor_url = None
        if self.max_heartbeat_misses > 0:
            # Monitor socket: every message it receives goes to _report_ping.
            monitor_socket = self.context.socket(zmq.SUB)
            if self.curve_serverkey:
                monitor_socket.setsockopt(zmq.CURVE_SERVER, 1)
                monitor_socket.setsockopt(zmq.CURVE_SECRETKEY, self.curve_secretkey)
            monitor_port = monitor_socket.bind_to_random_port(
                'tcp://%s' % localhost()
            )
            monitor_socket.setsockopt(zmq.SUBSCRIBE, b"")
            self._hb_listener = zmqstream.ZMQStream(monitor_socket, self.loop)
            self._hb_listener.on_recv(self._report_ping)

            monitor_url = "tcp://%s:%i" % (localhost(), monitor_port)

        # monitor_url stays None when monitoring is disabled.
        Heart(
            hb_ping,
            hb_pong,
            monitor_url,
            heart_id=identity,
            curve_serverkey=self.curve_serverkey,
            curve_secretkey=self.curve_secretkey,
            curve_publickey=self.curve_publickey,
        ).start()

        # The periodic check is started here rather than in start() because
        # only now is the hub's heartbeat period (hb_period) known.
        if not (self.max_heartbeat_misses > 0):
            self.log.info(
                "Monitoring of the heartbeat signal from the hub is not enabled."
            )
            return

        # Check slightly less often than the hub pings, to avoid spurious warnings.
        self.hb_check_period = hb_period + 500
        self.log.info(
            "Starting to monitor the heartbeat signal from the hub every %i ms.",
            self.hb_check_period,
        )
        self._hb_reporter = ioloop.PeriodicCallback(
            self._hb_monitor, self.hb_check_period
        )
        self._hb_reporter.start()

    def abort(self):
        """Give up on registration: log, ask to be unregistered, exit with 255.

        When the registration URL contains ``"127."`` an extra hint about
        making the controller reachable from other machines is logged.
        """
        self.log.fatal("Registration timed out after %.1f seconds" % self.timeout)

        url_looks_local = "127." in self.registration_url
        if url_looks_local:
            self.log.fatal(
                """
            If the controller and engines are not on the same machine,
            you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
                c.IPController.ip = '0.0.0.0' # for all interfaces, internal and external
                c.IPController.ip = '192.168.1.101' # or any interface that the engines can see
            or tunnel connections via ssh.
            """
            )

        unregistration = {"id": self.id}
        self.session.send(
            self.registrar, "unregistration_request", content=unregistration
        )
        # Pause for a second before terminating the process.
        time.sleep(1)
        sys.exit(255)

    def _hb_monitor(self):
        """Callback to monitor the heartbeat from the controller"""
        self._hb_listener.flush()
        if self._hb_last_monitored > self._hb_last_pinged:
            self._hb_missed_beats += 1
            self.log.warning(
                "No heartbeat in the last %s ms (%s time(s) in a row).",
                self.hb_check_period,
                self._hb_missed_beats,
            )
        else:
            # self.log.debug("Heartbeat received (after missing %s beats).", self._hb_missed_beats)
            self._hb_missed_beats = 0

        if self._hb_missed_beats >= self.max_heartbeat_misses:
            self.log.fatal(
                "Maximum number of heartbeats misses reached (%s times %s ms), shutting down.",
                self.max_heartbeat_misses,
                self.hb_check_period,
            )
            self.session.send(
                self.registrar, "unregistration_request", content=dict(id=self.id)
            )
            self.loop.stop()

        self._hb_last_monitored = time.time()

    def init_engine(self):
        """Locate and load connection info, then set up startup code config.

        Unless ``connection_info_env`` is set, the connection file is looked
        up, optionally waited for (up to ``wait_for_url_file`` seconds), and
        the application exits with status 1 if it never appears. Afterwards
        ``exec_lines`` / ``exec_files`` are taken from the first of
        ``IPKernelApp`` / ``InteractiveShellApp`` that defines them, stored on
        ``config.IPKernelApp``, and extended with ``startup_command`` /
        ``startup_script`` when those are set.
        """
        # This is the working dir by now.
        sys.path.insert(0, '')
        cfg = self.config

        if not self.connection_info_env:
            self.find_connection_file()
            if self.wait_for_url_file and not os.path.exists(self.url_file):
                self.log.warning(f"Connection file {self.url_file!r} not found")
                self.log.warning(
                    "Waiting up to %.1f seconds for it to arrive.",
                    self.wait_for_url_file,
                )
                started = time.monotonic()
                # Poll until the file shows up or the time limit runs out.
                while not os.path.exists(self.url_file) and (
                    time.monotonic() - started < self.wait_for_url_file
                ):
                    time.sleep(0.1)

            if not os.path.exists(self.url_file):
                self.log.fatal(f"Fatal: connection file never arrived: {self.url_file}")
                self.exit(1)

        self.load_connection_file()

        # The first app section that defines the option wins; default is [].
        app_names = ('IPKernelApp', 'InteractiveShellApp')
        startup_lines = next(
            (cfg[name].exec_lines for name in app_names
             if '%s.exec_lines' % name in cfg),
            [],
        )
        startup_files = next(
            (cfg[name].exec_files for name in app_names
             if '%s.exec_files' % name in cfg),
            [],
        )

        cfg.IPKernelApp.exec_lines = startup_lines
        cfg.IPKernelApp.exec_files = startup_files

        # Appended after the assignment above, to the same list objects.
        if self.startup_script:
            startup_files.append(self.startup_script)
        if self.startup_command:
            startup_lines.append(self.startup_command)

    def forward_logging(self):
        """Attach a PUB-socket log handler when ``log_url`` is set.

        Does nothing if ``self.log_url`` is falsy.
        """
        if not self.log_url:
            return

        self.log.info("Forwarding logging to %s", self.log_url)
        # PUB socket connected to log_url, wrapped in an EnginePUBHandler.
        log_socket = self.context.socket(zmq.PUB)
        log_socket.connect(self.log_url)
        pub_handler = EnginePUBHandler(self.engine, log_socket)
        pub_handler.setLevel(self.log_level)
        self.log.addHandler(pub_handler)

    @default("log_level")
    def _default_debug(self):
        """Default value for the ``log_level`` trait."""
        default_level = 10
        return default_level

    @catch_config_error
    def initialize(self, argv=None):
        """Initialize the application, then the engine and log forwarding.

        Parameters
        ----------
        argv : optional
            Passed through unchanged to ``super().initialize``.
        """
        # log_level is set to 10 both before and after the parent's
        # initialize(); presumably the second assignment overrides anything
        # the parent set from configuration -- TODO confirm intent.
        self.log_level = 10
        super().initialize(argv)
        self.log_level = 10
        self.init_engine()
        self.forward_logging()

    def init_signal(self):
        """Install this object's handlers for SIGINT and SIGTERM."""
        handler_table = (
            (signal.SIGINT, self._signal_sigint),
            (signal.SIGTERM, self._signal_stop),
        )
        for signum, handler in handler_table:
            signal.signal(signum, handler)

    def _signal_sigint(self, sig, frame):
        self.log.warning("Ignoring SIGINT. Terminate with SIGTERM.")

    def _signal_stop(self, sig, frame):
        self.log.critical(f"received signal {sig}, stopping")
        self.loop.add_callback_from_signal(self.loop.stop)

    def start(self):
        """Schedule registration on the loop and run the loop.

        If an id is already known it is appended to the logger name. The
        scheduled callback calls ``register()`` and then arms a timeout that
        invokes ``abort`` after ``self.timeout``; the timeout handle is kept
        in ``self._abort_timeout``. A KeyboardInterrupt raised while the loop
        runs is logged rather than propagated.
        """
        if self.id is not None:
            self.log.name += f".{self.id}"
        loop = self.loop

        def _register_and_arm_timeout():
            self.register()
            deadline = loop.time() + self.timeout
            self._abort_timeout = loop.add_timeout(deadline, self.abort)

        self.loop.add_callback(_register_and_arm_timeout)
        try:
            self.loop.start()
        except KeyboardInterrupt:
            self.log.critical("Engine Interrupted, shutting down...\n")
Exemple #23
0
class DuoAuthenticator(Authenticator):
    """Duo Two-Factor Authenticator

    Primary authentication is delegated to ``primary_authenticator`` (an
    instance of ``primary_auth_class``); secondary authentication verifies
    the signed response with ``duo_web.verify_response``.
    """

    ikey = Unicode(
        help="""
        The Duo Integration Key.

        """
    ).tag(config=True)

    skey = Unicode(
        help="""
        The Duo Secret Key.

        """
    ).tag(config=True)

    akey = Unicode(
        help="""
        The Duo Application Key.

        """
    ).tag(config=True)

    apihost = Unicode(
        help="""
        The Duo API hostname.

        """
    ).tag(config=True)

    primary_auth_class = Type(PAMAuthenticator, Authenticator,
        help="""Class to use for primary authentication of users.

        Must follow the same structure as a standard authenticator class.

        Defaults to PAMAuthenticator.
        """
    ).tag(config=True)

    primary_authenticator = Instance(Authenticator)

    @default('primary_authenticator')
    def _primary_auth_default(self):
        """Build the primary authenticator from ``primary_auth_class``."""
        return self.primary_auth_class(parent=self, db=self.db)

    duo_custom_html = Unicode(
        help="""
        Custom html to use for the Duo iframe page.  Must contain at minimum an
        iframe with id="duo_iframe", as well as 'data-host' and 'data-sig-request'
        template attributes to be populated.

        Defaults to an empty string, which uses the included 'duo.html' template.
        """
    ).tag(config=True)

    def get_handlers(self, app):
        """Return the URL handlers for this authenticator (``/login`` only)."""
        return [
            (r'/login', DuoHandler)
        ]

    @gen.coroutine
    def authenticate(self, handler, data):
        """Do secondary authentication with Duo, and return the username if successful.

        Return None otherwise.

        ``data`` must contain a ``'sig_response'`` entry; a missing entry
        raises ``KeyError``.
        """
        sig_response = data['sig_response']
        authenticated_username = duo_web.verify_response(
            self.ikey, self.skey, self.akey, sig_response)
        if authenticated_username:
            self.log.debug("Duo Authentication succeeded for user '%s'",
                           authenticated_username)
            return authenticated_username

        # Verification produced no user name. The original code logged an
        # undefined ``username`` here and raised NameError on every failed
        # attempt. Log the name from the submitted data when present
        # (assumes ``data`` is dict-like, as it is already indexed above).
        self.log.warning("Duo Authentication failed for user '%s'",
                         data.get('username', '<unknown>'))
        return None

    @gen.coroutine
    def do_primary_auth(self, handler, data):
        """Do primary authentication, and return the username if successful.

        Return None otherwise.
        """
        primary_username = yield self.primary_authenticator.authenticate(handler, data)
        # Any falsy result from the primary authenticator is reported as None.
        return primary_username or None
Exemple #24
0
class Widget(LoggingHasTraits):
    """Base class for objects whose traits are synced with a front-end model.

    Traits tagged ``sync=True`` are listed in ``keys`` and exchanged with the
    front-end over ``comm``. Open widgets are tracked in ``Widget.widgets``,
    keyed by model id.
    """
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    _widget_construction_callback = None

    # widgets is a dictionary of all active widget objects
    widgets = {}

    # widget_types is a registry of widgets by module, version, and name:
    widget_types = WidgetRegistry()

    @classmethod
    def close_all(cls):
        """Close every widget currently registered in ``cls.widgets``."""
        # Iterate over a copy: close() removes entries from the dictionary.
        for widget in list(cls.widgets.values()):
            widget.close()

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(
                Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, creates a widget for a comm opened with ``msg``.

        The message metadata must carry a protocol version whose major part
        equals ``PROTOCOL_VERSION_MAJOR``; otherwise ``ValueError`` is raised.
        The widget class is looked up in ``Widget.widget_types`` from the
        model/view module, version and name found in the message state.
        """
        version = msg.get('metadata', {}).get('version', '')
        if version.split('.')[0] != PROTOCOL_VERSION_MAJOR:
            raise ValueError(
                "Incompatible widget protocol versions: received version %r, expected version %r"
                % (version, __protocol_version__))
        data = msg['content']['data']
        state = data['state']

        # Find the widget class to instantiate in the registered widgets
        widget_class = Widget.widget_types.get(state['_model_module'],
                                               state['_model_module_version'],
                                               state['_model_name'],
                                               state['_view_module'],
                                               state['_view_module_version'],
                                               state['_view_name'])
        widget = widget_class(comm=comm)
        if 'buffer_paths' in data:
            _put_buffers(state, data['buffer_paths'], msg['buffers'])
        widget.set_state(state)

    @staticmethod
    def get_manager_state(drop_defaults=False, widgets=None):
        """Returns the full state for a widget manager for embedding

        :param drop_defaults: when True, it will not include default value
        :param widgets: list with widgets to include in the state (or all widgets when None)
        :return: dict with ``version_major``, ``version_minor`` and ``state``
            (embed state of each widget, keyed by model id)
        """
        state = {}
        if widgets is None:
            widgets = Widget.widgets.values()
        for widget in widgets:
            state[widget.model_id] = widget._get_embed_state(
                drop_defaults=drop_defaults)
        return {'version_major': 2, 'version_minor': 0, 'state': state}

    def _get_embed_state(self, drop_defaults=False):
        """Return this widget's state in the form used by get_manager_state.

        The result holds the model name/module/version, the buffer-free
        ``state`` and, when there are buffers, a ``buffers`` list with each
        buffer base64-encoded next to its path.
        """
        state = {
            'model_name': self._model_name,
            'model_module': self._model_module,
            'model_module_version': self._model_module_version
        }
        model_state, buffer_paths, buffers = _remove_buffers(
            self.get_state(drop_defaults=drop_defaults))
        state['state'] = model_state
        if len(buffers) > 0:
            state['buffers'] = [{
                'encoding': 'base64',
                'path': p,
                'data': standard_b64encode(d).decode('ascii')
            } for p, d in zip(buffer_paths, buffers)]
        return state

    def get_view_spec(self):
        """Return a dict with the view spec version and this widget's model id."""
        return dict(version_major=2, version_minor=0, model_id=self._model_id)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_name = Unicode('WidgetModel',
                          help="Name of the model.",
                          read_only=True).tag(sync=True)
    _model_module = Unicode('@jupyter-widgets/base',
                            help="The namespace for the model.",
                            read_only=True).tag(sync=True)
    _model_module_version = Unicode(
        __jupyter_widgets_base_version__,
        help="A semver requirement for namespace version containing the model.",
        read_only=True).tag(sync=True)
    _view_name = Unicode(None, allow_none=True,
                         help="Name of the view.").tag(sync=True)
    _view_module = Unicode(None,
                           allow_none=True,
                           help="The namespace for the view.").tag(sync=True)
    _view_module_version = Unicode(
        '',
        help=
        "A semver requirement for the namespace version containing the view."
    ).tag(sync=True)

    _view_count = Int(
        None,
        allow_none=True,
        help=
        "EXPERIMENTAL: The number of views of the model displayed in the frontend. This attribute is experimental and may change or be removed in the future. None signifies that views will not be tracked. Set this to 0 to start tracking view creation/deletion."
    ).tag(sync=True)
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    keys = List(help="The traits which are synced.")

    @default('keys')
    def _default_keys(self):
        """Default for ``keys``: the names of all traits tagged ``sync=True``."""
        return list(self.traits(sync=True))

    _property_lock = Dict()
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor

        An optional ``model_id`` keyword is removed from ``kwargs`` and used
        as the comm id when the comm is opened; the remaining keywords go to
        the parent constructor.
        """
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            state, buffer_paths, buffers = _remove_buffers(self.get_state())

            args = dict(target_name='jupyter.widget',
                        data={
                            'state': state,
                            'buffer_paths': buffer_paths
                        },
                        buffers=buffers,
                        metadata={'version': __protocol_version__})
            if self._model_id is not None:
                args['comm_id'] = self._model_id

            self.comm = Comm(**args)

    @observe('comm')
    def _comm_changed(self, change):
        """Called when the comm is changed."""
        if change['new'] is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

    @property
    def model_id(self):
        """Gets the model id of this widget (the id of its comm).

        The comm must exist: this reads ``self.comm.comm_id`` directly and
        does not open a comm itself."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
            self._ipython_display_ = None

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end, if it exists.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        if len(state) > 0:
            if self._property_lock:  # we need to keep this dict up to date with the front-end values
                for name, value in state.items():
                    if name in self._property_lock:
                        self._property_lock[name] = value
            state, buffer_paths, buffers = _remove_buffers(state)
            msg = {
                'method': 'update',
                'state': state,
                'buffer_paths': buffer_paths
            }
            self._send(msg, buffers=buffers)

    def get_state(self, key=None, drop_defaults=False):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.
        drop_defaults : bool (optional)
            When True, leave out values that compare equal to the trait's
            default value.

        Returns
        -------
        state : dict of states

        Raises
        ------
        ValueError
            If ``key`` is not None, a string, or an iterable.
        """
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``. Fall back for interpreters without it.
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2
            from collections import Iterable

        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, Iterable):
            keys = key
        else:
            raise ValueError(
                "key must be a string, an iterable of keys, or None")
        state = {}
        traits = self.traits()
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = to_json(getattr(self, k), self)
            if not PY3 and isinstance(traits[k], Bytes) and isinstance(
                    value, bytes):
                value = memoryview(value)
            if not drop_defaults or not self._compare(value,
                                                      traits[k].default_value):
                state[k] = value
        return state

    def _is_numpy(self, x):
        """Return True if ``x`` is a numpy ndarray, judged by class name and module."""
        return x.__class__.__name__ == 'ndarray' and x.__class__.__module__ == 'numpy'

    def _compare(self, a, b):
        """Compare two values, using ``np.array_equal`` if either is a numpy array."""
        if self._is_numpy(a) or self._is_numpy(b):
            import numpy as np
            return np.array_equal(a, b)
        else:
            return a == b

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    self.set_trait(name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget."""
        super(Widget, self).add_traits(**traits)
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                self.keys.append(name)
                self.send_state(name)

    def notify_change(self, change):
        """Called when a property has changed."""
        # Send the state to the frontend before the user-registered callbacks
        # are called.
        name = change['name']
        if self.comm is not None and self.comm.kernel is not None:
            # Make sure this isn't information that the front-end just sent us.
            if name in self.keys and self._should_send_property(
                    name, getattr(self, name)):
                # Send new state to front-end
                self.send_state(key=name)
        super(Widget, self).notify_change(change)

    def __repr__(self):
        """Return ``ClassName(key=value, ...)`` for the keys from ``_repr_keys``."""
        return self._gen_repr_from_keys(self._repr_keys())

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync is True:
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                self._holding_sync = False
                # Flush everything that was queued while syncing was held.
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)

        Returns False when the value matches what is locked for ``key`` or
        when syncing is held (the key is then queued in ``_states_to_send``);
        True otherwise."""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if key in self._property_lock:
            # model_state, buffer_paths, buffers
            split_value = _remove_buffers({key: to_json(value, self)})
            split_lock = _remove_buffers({key: self._property_lock[key]})
            # A roundtrip conversion through json in the comparison takes care of
            # idiosyncracies of how python data structures map to json, for example
            # tuples get converted to lists.
            if (jsonloads(jsondumps(split_value[0])) == split_lock[0]
                    and split_value[1] == split_lock[1]
                    and _buffer_list_equal(split_value[2], split_lock[2])):
                return False
        if self._holding_sync:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        if method == 'update':
            if 'state' in data:
                state = data['state']
                if 'buffer_paths' in data:
                    _put_buffers(state, data['buffer_paths'], msg['buffers'])
                self.set_state(state)

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error(
                'Unknown front-end to back-end widget msg with method "%s"' %
                method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Convert a trait value to json."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""

        # The plain-text fallback is the repr, cut to 110 characters.
        plaintext = repr(self)
        if len(plaintext) > 110:
            plaintext = plaintext[:110] + '…'
        data = {
            'text/plain': plaintext,
        }
        if self._view_name is not None:
            # The 'application/vnd.jupyter.widget-view+json' mimetype has not been registered yet.
            # See the registration process and naming convention at
            # http://tools.ietf.org/html/rfc6838
            # and the currently registered mimetypes at
            # http://www.iana.org/assignments/media-types/media-types.xhtml.
            data['application/vnd.jupyter.widget-view+json'] = {
                'version_major': 2,
                'version_minor': 0,
                'model_id': self._model_id
            }
        display(data, raw=True)

        if self._view_name is not None:
            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        if self.comm is not None and self.comm.kernel is not None:
            self.comm.send(data=msg, buffers=buffers)

    def _repr_keys(self):
        """Yield the synced keys worth showing in the repr, in sorted order.

        Skips keys starting with an underscore, values equal to the trait's
        default, and empty containers whose default is ``Undefined``.
        """
        traits = self.traits()
        for key in sorted(self.keys):
            # Exclude traits that start with an underscore
            if key[0] == '_':
                continue
            # Exclude traits who are equal to their default value
            value = getattr(self, key)
            trait = traits[key]
            if self._compare(value, trait.default_value):
                continue
            elif (isinstance(trait, (Container, Dict))
                  and trait.default_value == Undefined
                  and (value is None or len(value) == 0)):
                # Empty container, and dynamic default will be empty
                continue
            yield key

    def _gen_repr_from_keys(self, keys):
        """Format ``ClassName(key=value, ...)`` from the given attribute names."""
        class_name = self.__class__.__name__
        signature = ', '.join('%s=%r' % (key, getattr(self, key))
                              for key in keys)
        return '%s(%s)' % (class_name, signature)
Exemple #25
0
class ZMQDisplayPublisher(DisplayPublisher):
    """Display publisher that sends its messages over a ZeroMQ PUB socket.

    Display messages are first passed through a per-thread list of
    transform hooks (see ``register_hook``); a hook that returns ``None``
    consumes the message and nothing is sent.
    """

    # Session used to build (``msg``) and send (``send``) messages.
    session = Instance(Session, allow_none=True)
    # Socket passed to ``session.send``.
    pub_socket = Any(allow_none=True)
    # Header attached as ``parent`` to outgoing messages.
    parent_header = Dict({})
    # Passed as ``ident`` to ``session.send``.
    topic = CBytes(b"display_data")

    # thread_local:
    # An attribute used to ensure the correct output message
    # is processed. See ipykernel Issue 113 for a discussion.
    _thread_local = Any()

    @default("_thread_local")
    def _default_thread_local(self):
        """Create the thread-local storage that holds the hook list."""
        return local()

    @property
    def _hooks(self):
        """The hook list of the calling thread, created empty on first use."""
        storage = self._thread_local
        try:
            return storage.hooks
        except AttributeError:
            # First access from this thread: start with no hooks.
            storage.hooks = []
            return storage.hooks

    def set_parent(self, parent):
        """Store the header extracted from ``parent`` for outbound messages."""
        self.parent_header = extract_header(parent)

    def _flush_streams(self):
        """Flush ``sys.stdout`` and ``sys.stderr`` before sending anything."""
        for stream in (sys.stdout, sys.stderr):
            stream.flush()

    def publish(
        self,
        data,
        metadata=None,
        transient=None,
        update=False,
    ):
        """Publish a display-data message.

        Parameters
        ----------
        data : dict
            A mime-bundle dict, keyed by mime-type.
        metadata : dict, optional
            Metadata associated with the data. Defaults to an empty dict.
        transient : dict, optional
            Transient data that may only be relevant during a live display,
            such as display_id. Transient data should not be persisted to
            documents. Defaults to an empty dict.
        update : bool, optional
            If True, send an ``update_display_data`` message instead of
            ``display_data``.
        """
        self._flush_streams()
        metadata = {} if metadata is None else metadata
        transient = {} if transient is None else transient
        self._validate_data(data, metadata)

        content = {
            "data": encode_images(data),
            "metadata": metadata,
            "transient": transient,
        }
        if update:
            msg_type = "update_display_data"
        else:
            msg_type = "display_data"

        # Two stages: build the message first so that the transform hooks
        # can inspect, replace or swallow it before anything is sent.
        msg = self.session.msg(msg_type,
                               json_clean(content),
                               parent=self.parent_header)

        # A hook returning None has consumed the message; stop right there.
        for transform in self._hooks:
            msg = transform(msg)
            if msg is None:
                return

        self.session.send(self.pub_socket, msg, ident=self.topic)

    def clear_output(self, wait=False):
        """Clear output associated with the current execution (cell).

        Parameters
        ----------
        wait : bool (default: False)
            If True, the output will not be cleared immediately,
            instead waiting for the next display before clearing.
            This reduces bounce during repeated clear & display loops.
        """
        self._flush_streams()
        self.session.send(
            self.pub_socket,
            "clear_output",
            {"wait": wait},
            parent=self.parent_header,
            ident=self.topic,
        )

    def register_hook(self, hook):
        """Append ``hook`` to the calling thread's list of transform hooks.

        Parameters
        ----------
        hook : Any callable object
            Called by ``publish`` with the outgoing message. It must return
            a message (possibly a new one) if ``session.send`` should still
            be called after the transformation; returning ``None`` halts
            that execution path and ``session.send`` is not called.
        """
        self._hooks.append(hook)

    def unregister_hook(self, hook):
        """Remove ``hook`` from the calling thread's list of transform hooks.

        Parameters
        ----------
        hook : Any callable object which has previously been
            registered as a hook.

        Returns
        -------
        bool - `True` if the hook was removed, `False` if it wasn't
            found.
        """
        try:
            self._hooks.remove(hook)
        except ValueError:
            return False
        return True
Exemple #26
0
class HistoryManager(HistoryAccessor):
    """A class to organize all history-related functionality in one place.

    Inputs of the current session are kept in memory (``input_hist_parsed``
    and ``input_hist_raw``) and queued in ``db_input_cache`` /
    ``db_output_cache`` until ``writeout_cache`` inserts them into the
    ``history`` / ``output_history`` tables of the history database.
    """
    # Public interface

    # An instance of the IPython shell we are attached to
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
    # Lists to hold processed and raw history. These start with a blank entry
    # so that we can index them starting from 1
    input_hist_parsed = List([""])
    input_hist_raw = List([""])
    # A list of directories visited during session
    dir_hist = List()

    @default('dir_hist')
    def _dir_hist_default(self):
        """Start the directory history with the current working directory.

        Falls back to an empty list when ``os.getcwd()`` raises ``OSError``.
        """
        try:
            return [os.getcwd()]
        except OSError:
            return []

    # A dict of output history, keyed with ints from the shell's
    # execution count.
    output_hist = Dict()
    # The text/plain repr of outputs.
    output_hist_reprs = Dict()

    # The number of the current session in the history database
    session_number = Integer()

    db_log_output = Bool(
        False,
        help="Should the history database include output? (default: no)").tag(
            config=True)
    db_cache_size = Integer(
        0,
        help=
        "Write to database every x commands (higher values save disk access & power).\n"
        "Values of 1 or less effectively disable caching.").tag(config=True)
    # The input and output caches
    db_input_cache = List()
    db_output_cache = List()

    # History saving in separate thread
    save_thread = Instance('IPython.core.history.HistorySavingThread',
                           allow_none=True)
    save_flag = Instance(threading.Event, allow_none=True)

    # Private interface
    # Variables used to store the three last inputs from the user.  On each new
    # history update, we populate the user's namespace with these, shifted as
    # necessary.
    _i00 = Unicode(u'')
    _i = Unicode(u'')
    _ii = Unicode(u'')
    _iii = Unicode(u'')

    # A regex matching all forms of the exit command, so that we don't store
    # them in the history (it's annoying to rewind the first entry and land on
    # an exit call).
    _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")

    def __init__(self, shell=None, config=None, **traits):
        """Create a new history manager associated with a shell instance.

        Parameters
        ----------
        shell : optional
            The shell this manager is attached to.
        config : optional
            Configuration object forwarded to the parent constructor.
        **traits
            Additional trait values forwarded to the parent constructor.

        A new database session is opened immediately. If that fails with
        ``OperationalError`` the error is logged and ``hist_file`` is switched
        to ``':memory:'``. The saving thread is only started when history is
        enabled and the history file is not ``':memory:'``.
        """
        # We need a pointer back to the shell for various tasks.
        super(HistoryManager, self).__init__(shell=shell,
                                             config=config,
                                             **traits)
        self.save_flag = threading.Event()
        # One lock per cache; each guards appends to and flushes of its cache.
        self.db_input_cache_lock = threading.Lock()
        self.db_output_cache_lock = threading.Lock()

        try:
            self.new_session()
        except OperationalError:
            self.log.error(
                "Failed to create history session in %s. History will not be saved.",
                self.hist_file,
                exc_info=True)
            self.hist_file = ':memory:'

        if self.enabled and self.hist_file != ':memory:':
            self.save_thread = HistorySavingThread(self)
            self.save_thread.start()

    def _get_hist_file_name(self, profile=None):
        """Get default history file name based on the Shell's profile.

        The profile parameter is ignored, but must exist for compatibility with
        the parent class."""
        profile_dir = self.shell.profile_dir.location
        return os.path.join(profile_dir, 'history.sqlite')

    @needs_sqlite
    def new_session(self, conn=None):
        """Get a new session number.

        Inserts a row into the ``sessions`` table and stores the generated
        row id in ``self.session_number``.

        Parameters
        ----------
        conn : optional
            Database connection to use; defaults to ``self.db``.
        """
        if conn is None:
            conn = self.db

        with conn:
            # The empty remark is written as '' (a standard SQL string
            # literal); a double-quoted "" only works through a legacy SQLite
            # quirk. The timestamp is bound as ISO text with a space
            # separator, which is what sqlite3's deprecated default datetime
            # adapter used to produce, so the stored value is unchanged.
            cur = conn.execute(
                """INSERT INTO sessions VALUES (NULL, ?, NULL,
                            NULL, '') """,
                (datetime.datetime.now().isoformat(" "), ))
            self.session_number = cur.lastrowid

    def end_session(self):
        """Close the database session, filling in the end time and line count.

        Flushes the caches first, then resets ``session_number`` to 0.
        """
        self.writeout_cache()
        with self.db:
            # Timestamp bound as ISO text; see new_session for the rationale.
            self.db.execute(
                """UPDATE sessions SET end=?, num_cmds=? WHERE
                            session==?""",
                (datetime.datetime.now().isoformat(" "),
                 len(self.input_hist_parsed) - 1, self.session_number))
        self.session_number = 0

    def name_session(self, name):
        """Give the current session a name in the history database."""
        with self.db:
            self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
                            (name, self.session_number))

    def reset(self, new_session=True):
        """Clear the session history, releasing all object references, and
        optionally open a new session.

        Parameters
        ----------
        new_session : bool
            If True, end the current database session (when one is open),
            reset the in-memory input lists and open a new session.
        """
        self.output_hist.clear()
        # The directory history can't be completely empty
        self.dir_hist[:] = [os.getcwd()]

        if new_session:
            if self.session_number:
                self.end_session()
            # Keep the blank first entry so that indexing starts from 1.
            self.input_hist_parsed[:] = [""]
            self.input_hist_raw[:] = [""]
            self.new_session()

    # ------------------------------
    # Methods for retrieving history
    # ------------------------------
    def get_session_info(self, session=0):
        """Get info about a session.

        Parameters
        ----------

        session : int
            Session number to retrieve. The current session is 0, and negative
            numbers count back from current session, so -1 is the previous session.

        Returns
        -------

        session_id : int
           Session ID number
        start : datetime
           Timestamp for the start of the session.
        end : datetime
           Timestamp for the end of the session, or None if IPython crashed.
        num_cmds : int
           Number of commands run, or None if IPython crashed.
        remark : unicode
           A manually set description.
        """
        # Translate a relative (<= 0) session number into an absolute one.
        if session <= 0:
            session += self.session_number

        return super(HistoryManager, self).get_session_info(session=session)

    def _get_range_session(self, start=1, stop=None, raw=True, output=False):
        """Get input and output history from the current session. Called by
        get_range, and takes similar parameters.

        Yields ``(0, line_number, line)`` tuples read from the in-memory
        input lists; with ``output=True`` each ``line`` is an
        ``(input, output_repr)`` pair taken from ``output_hist_reprs``.
        """
        input_hist = self.input_hist_raw if raw else self.input_hist_parsed

        n = len(input_hist)
        # Negative bounds count back from the end; a missing or too large
        # ``stop`` is clamped to the end of the list.
        if start < 0:
            start += n
        if not stop or (stop > n):
            stop = n
        elif stop < 0:
            stop += n

        for i in range(start, stop):
            if output:
                line = (input_hist[i], self.output_hist_reprs.get(i))
            else:
                line = input_hist[i]
            yield (0, i, line)

    def get_range(self, session=0, start=1, stop=None, raw=True, output=False):
        """Retrieve input by session.

        Parameters
        ----------
        session : int
            Session number to retrieve. The current session is 0, and negative
            numbers count back from current session, so -1 is previous session.
        start : int
            First line to retrieve.
        stop : int
            End of line range (excluded from output itself). If None, retrieve
            to the end of the session.
        raw : bool
            If True, return untranslated input
        output : bool
            If True, attempt to include output. For the current session this
            is the entry stored in ``output_hist_reprs`` for that line; for
            other sessions the lookup is delegated to the parent class.
            Where no output is found, None is used.

        Returns
        -------
        entries
          An iterator over the desired lines. Each line is a 3-tuple, either
          (session, line, input) if output is False, or
          (session, line, (input, output)) if output is True.
        """
        if session <= 0:
            session += self.session_number
        if session == self.session_number:  # Current session
            return self._get_range_session(start, stop, raw, output)
        return super(HistoryManager, self).get_range(session, start, stop, raw,
                                                     output)

    ## ----------------------------
    ## Methods for storing history:
    ## ----------------------------
    def store_inputs(self, line_num, source, source_raw=None):
        """Store source and raw input in history and create input cache
        variables ``_i*``.

        Parameters
        ----------
        line_num : int
          The prompt number of this input.

        source : str
          Python input.

        source_raw : str, optional
          If given, this is the raw input without any IPython transformations
          applied to it.  If not given, ``source`` is used.
        """
        if source_raw is None:
            source_raw = source
        source = source.rstrip('\n')
        source_raw = source_raw.rstrip('\n')

        # do not store exit/quit commands
        if self._exit_re.match(source_raw.strip()):
            return

        self.input_hist_parsed.append(source)
        self.input_hist_raw.append(source_raw)

        with self.db_input_cache_lock:
            self.db_input_cache.append((line_num, source, source_raw))
            # Trigger to flush cache and write to DB.
            if len(self.db_input_cache) >= self.db_cache_size:
                self.save_flag.set()

        # update the auto _i variables
        self._iii = self._ii
        self._ii = self._i
        self._i = self._i00
        self._i00 = source_raw

        # hackish access to user namespace to create _i1,_i2... dynamically
        new_i = '_i%s' % line_num
        to_main = {
            '_i': self._i,
            '_ii': self._ii,
            '_iii': self._iii,
            new_i: self._i00
        }

        if self.shell is not None:
            self.shell.push(to_main, interactive=False)

    def store_output(self, line_num):
        """If database output logging is enabled, this saves all the
        outputs from the indicated prompt number to the database. It's
        called by run_cell after code has been executed.

        Parameters
        ----------
        line_num : int
          The line number from which to save outputs
        """
        if (not self.db_log_output) or (line_num
                                        not in self.output_hist_reprs):
            return
        output = self.output_hist_reprs[line_num]

        with self.db_output_cache_lock:
            self.db_output_cache.append((line_num, output))
        # With caching effectively disabled, request a write right away.
        if self.db_cache_size <= 1:
            self.save_flag.set()

    def _writeout_input_cache(self, conn):
        """Insert every cached input into the ``history`` table.

        Each cached ``(line_num, source, source_raw)`` tuple is prefixed with
        the current session number. The cache itself is not cleared here.
        """
        with conn:
            for line in self.db_input_cache:
                conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
                             (self.session_number, ) + line)

    def _writeout_output_cache(self, conn):
        """Insert every cached output into the ``output_history`` table.

        Each cached ``(line_num, output)`` tuple is prefixed with the current
        session number. The cache itself is not cleared here.
        """
        with conn:
            for line in self.db_output_cache:
                conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
                             (self.session_number, ) + line)

    @needs_sqlite
    def writeout_cache(self, conn=None):
        """Write any entries in the cache to the database.

        Parameters
        ----------
        conn : optional
            Database connection to use; defaults to ``self.db``.

        Both caches are emptied afterwards, whether or not the write
        succeeded. On an ``IntegrityError`` for inputs, a new session is
        started and the write is retried once.
        """
        if conn is None:
            conn = self.db

        with self.db_input_cache_lock:
            try:
                self._writeout_input_cache(conn)
            except sqlite3.IntegrityError:
                self.new_session(conn)
                print("ERROR! Session/line number was not unique in",
                      "database. History logging moved to new session",
                      self.session_number)
                try:
                    # Try writing to the new session. If this fails, don't
                    # recurse
                    self._writeout_input_cache(conn)
                except sqlite3.IntegrityError:
                    pass
            finally:
                self.db_input_cache = []

        with self.db_output_cache_lock:
            try:
                self._writeout_output_cache(conn)
            except sqlite3.IntegrityError:
                print("!! Session/line number for output was not unique",
                      "in database. Output will not be stored.")
            finally:
                self.db_output_cache = []
Exemple #27
0
class DaskGateway(Application):
    """A gateway for managing dask clusters across multiple users"""

    name = "dask-gateway-server"
    version = VERSION

    description = """Start a Dask Gateway server"""

    examples = """

    Start the server with config file ``config.py``

        dask-gateway-server -f config.py
    """

    subcommands = {
        "generate-config": (
            "dask_gateway_server.app.GenerateConfig",
            "Generate a default config file",
        ),
        "scheduler-proxy": (
            "dask_gateway_server.proxy.core.SchedulerProxyApp",
            "Start the scheduler proxy",
        ),
        "web-proxy": (
            "dask_gateway_server.proxy.core.WebProxyApp",
            "Start the web proxy",
        ),
    }

    aliases = {
        "log-level": "DaskGateway.log_level",
        "f": "DaskGateway.config_file",
        "config": "DaskGateway.config_file",
    }

    config_file = Unicode("dask_gateway_config.py",
                          help="The config file to load",
                          config=True)

    # Fail if the config file errors
    raise_config_file_errors = True

    scheduler_proxy_class = Type(
        "dask_gateway_server.proxy.SchedulerProxy",
        help="The gateway scheduler proxy class to use",
    )

    web_proxy_class = Type("dask_gateway_server.proxy.WebProxy",
                           help="The gateway web proxy class to use")

    authenticator_class = Type(
        "dask_gateway_server.auth.DummyAuthenticator",
        klass="dask_gateway_server.auth.Authenticator",
        help="The gateway authenticator class to use",
        config=True,
    )

    cluster_manager_class = Type(
        "dask_gateway_server.managers.local.UnsafeLocalClusterManager",
        klass="dask_gateway_server.managers.ClusterManager",
        help="The gateway cluster manager class to use",
        config=True,
    )

    cluster_manager_options = Instance(
        Options,
        args=(),
        help="""
        User options for configuring the cluster manager.

        Allows users to specify configuration overrides when creating a new
        cluster manager. See the documentation for more information.
        """,
        config=True,
    )

    public_url = Unicode(
        "http://:8000",
        help="The public facing URL of the whole Dask Gateway application",
        config=True,
    )

    public_connect_url = Unicode(
        help="""
        The address that the public URL can be connected to.

        Useful if the address the web proxy should listen at is different than
        the address it's reachable at (by e.g. the scheduler/workers).

        Defaults to ``public_url``.
        """,
        config=True,
    )

    gateway_url = Unicode(help="The URL that Dask clients will connect to",
                          config=True)

    private_url = Unicode(
        "http://127.0.0.1:0",
        help="""
        The gateway's private URL, used for internal communication.

        This must be reachable from the web proxy, but shouldn't be publicly
        accessible (if possible). Default is ``http://127.0.0.1:{random-port}``.
        """,
        config=True,
    )

    private_connect_url = Unicode(
        help="""
        The address that the private URL can be connected to.

        Useful if the address the gateway should listen at is different than
        the address it's reachable at (by e.g. the web proxy).

        Defaults to ``private_url``.
        """,
        config=True,
    )

    @validate(
        "public_url",
        "public_connect_url",
        "gateway_url",
        "private_url",
        "private_connect_url",
    )
    def _validate_url(self, proposal):
        """Check the URL scheme of a proposed URL trait value.

        Traits whose name starts with ``gateway`` must use the ``tls``
        scheme; every other validated trait must use ``http`` or ``https``.

        Returns the proposed URL unchanged; raises ``ValueError`` when the
        scheme is not acceptable.
        """
        candidate = proposal.value
        trait_name = proposal.trait.name
        parsed_scheme = urlparse(candidate).scheme
        is_gateway = trait_name.startswith("gateway")

        if is_gateway and parsed_scheme != "tls":
            raise ValueError("'gateway_url' must be a tls url, got %s" %
                             candidate)
        if not is_gateway and parsed_scheme not in {"http", "https"}:
            raise ValueError("%r must be an http/https url, got %s" %
                             (trait_name, candidate))
        return candidate

    tls_key = Unicode(
        "",
        help="""Path to TLS key file for the public url of the web proxy.

        When setting this, you should also set tls_cert.
        """,
        config=True,
    )

    tls_cert = Unicode(
        "",
        help="""Path to TLS certificate file for the public url of the web proxy.

        When setting this, you should also set tls_key.
        """,
        config=True,
    )

    cookie_secret = Bytes(
        help="""The cookie secret to use to encrypt cookies.

        Loaded from the DASK_GATEWAY_COOKIE_SECRET environment variable by
        default.
        """,
        config=True,
    )

    @default("cookie_secret")
    def _cookie_secret_default(self):
        """Return the default cookie secret.

        Uses the ``DASK_GATEWAY_COOKIE_SECRET`` environment variable when it
        is set and non-empty; otherwise generates 32 random bytes.
        """
        from_env = os.environb.get(b"DASK_GATEWAY_COOKIE_SECRET", b"")
        if from_env:
            return from_env
        self.log.debug("Generating new cookie secret")
        return os.urandom(32)

    @validate("cookie_secret")
    def _cookie_secret_validate(self, proposal):
        """Require the proposed cookie secret to be exactly 32 bytes long.

        Returns the proposed value; raises ``ValueError`` for any other
        length.
        """
        secret = proposal["value"]
        n_bytes = len(secret)
        if n_bytes == 32:
            return secret
        raise ValueError("Cookie secret is %d bytes, it must be "
                         "32 bytes" % n_bytes)

    cookie_max_age_days = Float(
        7,
        help="""Number of days for a login cookie to be valid.
        Default is one week.
        """,
        config=True,
    )

    stop_clusters_on_shutdown = Bool(
        True,
        help="""
        Whether to stop active clusters on gateway shutdown.

        If true, all active clusters will be stopped before shutting down the
        gateway. Set to False to leave active clusters running.
        """,
        config=True,
    )

    @validate("stop_clusters_on_shutdown")
    def _stop_clusters_on_shutdown_validate(self, proposal):
        """Reject a false value when ``db_url`` is an in-memory database.

        Returns the proposed value when it is accepted; raises
        ``ValueError`` otherwise.
        """
        stop_on_shutdown = proposal.value
        # The database check is only needed when the proposed value is false.
        if stop_on_shutdown or not is_in_memory_db(self.db_url):
            return stop_on_shutdown
        raise ValueError(
            "When using an in-memory database, `stop_clusters_on_shutdown` "
            "must be True")

    check_cluster_timeout = Float(
        10,
        help="""
        Timeout (in seconds) before deciding a cluster is no longer active.

        When the gateway restarts, any clusters still marked as active have
        their status checked. This timeout sets the max time we allocate for
        checking a cluster's status before deciding that the cluster is no
        longer active.
        """,
        config=True,
    )

    db_url = Unicode(
        "sqlite:///:memory:",
        help="""
        The URL for the database. Default is in-memory only.

        If not in-memory, ``db_encrypt_keys`` must also be set.
        """,
        config=True,
    )

    db_encrypt_keys = List(
        help="""
        A list of keys to use to encrypt private data in the database. Can also
        be set by the environment variable ``DASK_GATEWAY_ENCRYPT_KEYS``, where
        the value is a ``;`` delimited string of encryption keys.

        Each key should be a base64-encoded 32 byte value, and should be
        cryptographically random. Lacking other options, openssl can be used to
        generate a single key via:

        .. code-block:: shell

            $ openssl rand -base64 32

        A single key is valid, multiple keys can be used to support key rotation.
        """,
        config=True,
    )

    @default("db_encrypt_keys")
    def _db_encrypt_keys_default(self):
        """Read the default encryption keys from the environment.

        ``DASK_GATEWAY_ENCRYPT_KEYS`` is split on ``;``; blank entries are
        skipped and every remaining entry is passed through
        ``normalize_encrypt_key``. Returns an empty list when the variable
        is unset or blank.
        """
        raw = os.environb.get(b"DASK_GATEWAY_ENCRYPT_KEYS", b"").strip()
        if not raw:
            return []
        normalized = []
        for chunk in raw.split(b";"):
            if chunk.strip():
                normalized.append(normalize_encrypt_key(chunk))
        return normalized

    @validate("db_encrypt_keys")
    def _db_encrypt_keys_validate(self, proposal):
        """Require at least one key unless the database is in-memory.

        Returns a new list with every proposed key passed through
        ``normalize_encrypt_key``; raises ``ValueError`` when no keys are
        given and ``db_url`` is not an in-memory database.
        """
        proposed_keys = proposal.value
        if not proposed_keys and not is_in_memory_db(self.db_url):
            raise ValueError(
                "Must configure `db_encrypt_keys`/`DASK_GATEWAY_ENCRYPT_KEYS` "
                "when not using an in-memory database")
        return list(map(normalize_encrypt_key, proposed_keys))

    db_debug = Bool(False,
                    help="If True, all database operations will be logged",
                    config=True)

    # How often (in seconds) ``cleanup_database`` runs; it sleeps for this
    # long between calls to ``db.cleanup_expired(db_cluster_max_age)``.
    # The help text names ``db_cluster_max_age``, the option that actually
    # exists (it previously named a non-existent ``db_record_max_age``).
    db_cleanup_period = Float(
        600,
        help="""
        Time (in seconds) between database cleanup tasks.

        This sets how frequently old records are removed from the database.
        This shouldn't be too small (to keep the overhead low), but should be
        smaller than ``db_cluster_max_age`` (probably by an order of magnitude).
        """,
        config=True,
    )

    db_cluster_max_age = Float(
        3600 * 24,
        help="""
        Max time (in seconds) to keep around records of completed clusters.

        Every ``db_cleanup_period``, completed clusters older than
        ``db_cluster_max_age`` are removed from the database.
        """,
        config=True,
    )

    temp_dir = Unicode(
        help="""
        Path to a directory to use to store temporary runtime files.

        The permissions on this directory must be restricted to ``0o700``. If
        the directory doesn't already exist, it will be created on startup and
        removed on shutdown.

        The default is to create a temporary directory
        ``"dask-gateway-<UUID>"`` in the system tmpdir default location.
        """,
        config=True,
    )

    @default("temp_dir")
    def _temp_dir_default(self):
        """Create and return a fresh ``dask-gateway-`` temporary directory.

        A ``weakref.finalize`` callback is registered on this object that
        calls ``cleanup_tmpdir`` with the logger and the new path.
        """
        created_path = tempfile.mkdtemp(prefix="dask-gateway-")
        self.log.debug("Creating temporary directory %r", created_path)
        weakref.finalize(self, cleanup_tmpdir, self.log, created_path)
        return created_path

    _log_formatter_cls = LogFormatter

    classes = List([ClusterManager, Authenticator, WebProxy, SchedulerProxy])

    @catch_config_error
    def initialize(self, argv=None):
        """Initialize the application from ``argv`` and the config file.

        After the parent initialization, nothing else is done when a
        subcommand application was selected. Otherwise the config file is
        loaded and the ``init_*`` steps run in a fixed order.
        """
        super().initialize(argv)
        if self.subapp is not None:
            # A subcommand was selected; it does its own setup.
            return

        self.log.info("Starting dask-gateway-server - version %s", VERSION)
        self.load_config_file(self.config_file)
        self.log.info("Cluster manager: %r",
                      classname(self.cluster_manager_class))
        self.log.info("Authenticator: %r", classname(self.authenticator_class))

        # Order matters: later steps use what earlier ones set up.
        init_steps = (
            self.init_logging,
            self.init_tempdir,
            self.init_asyncio,
            self.init_server_urls,
            self.init_scheduler_proxy,
            self.init_web_proxy,
            self.init_authenticator,
            self.init_user_limits,
            self.init_database,
            self.init_tornado_application,
        )
        for step in init_steps:
            step()

    def init_logging(self):
        """Hook tornado's loggers up to this application's logger."""
        # Prevent double log messages from tornado
        self.log.propagate = False

        # tornado's module-level loggers, imported locally
        from tornado.log import app_log, access_log, gen_log

        # Give each tornado logger our name and drop its own handlers.
        for tornado_log in (app_log, access_log, gen_log):
            tornado_log.name = self.log.name
            del tornado_log.handlers[:]

        # Re-parent the "tornado" logger onto ours, at our level.
        tornado_root = logging.getLogger("tornado")
        del tornado_root.handlers[:]
        tornado_root.propagate = True
        tornado_root.parent = self.log
        tornado_root.setLevel(self.log.level)

    def init_tempdir(self):
        """Create ``temp_dir`` if needed, or verify its permissions.

        A missing directory is created with mode ``0o700`` and registered
        for cleanup via ``weakref.finalize``. An existing directory must not
        grant any permission to group or others, else ``ValueError`` is
        raised.
        """
        if not os.path.exists(self.temp_dir):
            self.log.debug("Creating temporary directory %r", self.temp_dir)
            os.mkdir(self.temp_dir, mode=0o700)
            weakref.finalize(self, cleanup_tmpdir, self.log, self.temp_dir)
            return

        mode_bits = stat.S_IMODE(os.stat(self.temp_dir).st_mode)
        group_or_other = stat.S_IRWXO | stat.S_IRWXG
        if mode_bits & group_or_other:
            raise ValueError(
                "Temporary directory %s has excessive permissions "
                "%r, should be at '0o700'" % (self.temp_dir, oct(mode_bits)))

    def init_asyncio(self):
        """Create the ``TaskPool`` stored as ``self.task_pool``."""
        pool = TaskPool()
        self.task_pool = pool

    def init_server_urls(self):
        """Initialize addresses from configuration.

        Builds ``public_urls``, ``private_urls`` and ``gateway_urls``, plus
        ``api_url``. When ``gateway_url`` is not configured it defaults to a
        ``tls`` URL on port 8786 of the public bind host.
        """
        self.public_urls = ServerUrls(self.public_url, self.public_connect_url)
        self.private_urls = ServerUrls(self.private_url,
                                       self.private_connect_url)

        configured = self.gateway_url
        fallback = f"tls://{self.public_urls.bind_host}:8786"
        self.gateway_urls = ServerUrls(configured if configured else fallback)

        # Additional common url
        self.api_url = self.public_urls.connect_url + "/gateway/api"

    def init_scheduler_proxy(self):
        """Instantiate ``scheduler_proxy_class`` on the gateway URLs."""
        proxy_cls = self.scheduler_proxy_class
        self.scheduler_proxy = proxy_cls(parent=self,
                                         log=self.log,
                                         public_urls=self.gateway_urls)

    def init_web_proxy(self):
        """Instantiate ``web_proxy_class`` on the public URLs with TLS settings."""
        proxy_kwargs = {
            "parent": self,
            "log": self.log,
            "public_urls": self.public_urls,
            "tls_cert": self.tls_cert,
            "tls_key": self.tls_key,
        }
        self.web_proxy = self.web_proxy_class(**proxy_kwargs)

    def init_authenticator(self):
        """Instantiate ``authenticator_class`` as ``self.authenticator``."""
        authenticator_cls = self.authenticator_class
        self.authenticator = authenticator_cls(parent=self, log=self.log)

    def init_user_limits(self):
        """Create the ``UserLimits`` object stored as ``self.user_limits``."""
        limits = UserLimits(parent=self, log=self.log)
        self.user_limits = limits

    def init_database(self):
        """Create the ``DataManager`` stored as ``self.db``.

        It is configured from ``db_url``, ``db_debug`` (passed as ``echo``)
        and ``db_encrypt_keys``.
        """
        db_settings = {
            "url": self.db_url,
            "echo": self.db_debug,
            "encrypt_keys": self.db_encrypt_keys,
        }
        self.db = DataManager(**db_settings)

    def init_tornado_application(self):
        """Build the tornado ``web.Application`` from the default handlers.

        The handler list is copied into ``self.handlers``; the logger, this
        gateway, the authenticator and the cookie options are passed to the
        application as keyword settings.
        """
        self.handlers = list(handlers.default_handlers)
        app_settings = dict(
            log=self.log,
            gateway=self,
            authenticator=self.authenticator,
            cookie_secret=self.cookie_secret,
            cookie_max_age_days=self.cookie_max_age_days,
        )
        self.tornado_application = web.Application(self.handlers,
                                                   **app_settings)

    async def start_async(self):
        """Install signal handling, then run the start-up steps in order.

        Each step is awaited to completion before the next one begins.
        """
        self.init_signal()
        startup_steps = (
            self.start_scheduler_proxy,
            self.start_web_proxy,
            self.load_database_state,
            self.start_tornado_application,
        )
        for step in startup_steps:
            await step()

    async def start_scheduler_proxy(self):
        """Await the scheduler proxy's ``start()``."""
        proxy = self.scheduler_proxy
        await proxy.start()

    async def start_web_proxy(self):
        """Await the web proxy's ``start()``."""
        proxy = self.web_proxy
        await proxy.start()

    def create_cluster_manager(self, options):
        """Instantiate a cluster manager for the given user ``options``.

        ``options`` is turned into configuration by
        ``cluster_manager_options.get_configuration``; that configuration is
        passed to ``cluster_manager_class`` as extra keyword arguments next
        to the gateway's logger, task pool, temp dir and API URL.
        """
        overrides = self.cluster_manager_options.get_configuration(options)
        manager_cls = self.cluster_manager_class
        return manager_cls(
            parent=self,
            log=self.log,
            task_pool=self.task_pool,
            temp_dir=self.temp_dir,
            api_url=self.api_url,
            **overrides,
        )

    def init_cluster_manager(self, manager, cluster):
        """Copy the cluster's identity and credentials onto ``manager``.

        Sets the manager's ``username``, ``cluster_name``, ``api_token``,
        ``tls_cert`` and ``tls_key`` from the corresponding ``cluster``
        attributes.
        """
        assignments = (
            ("username", cluster.user.name),
            ("cluster_name", cluster.name),
            ("api_token", cluster.token),
            ("tls_cert", cluster.tls_cert),
            ("tls_key", cluster.tls_key),
        )
        for attr, value in assignments:
            setattr(manager, attr, value)

    async def load_database_state(self):
        """Load the persisted state and re-check clusters marked as active.

        Every cluster returned by ``db.active_clusters()`` is checked
        concurrently with ``check_cluster``. A check that ends in an
        exception is logged as an error; a check with a truthy result is
        counted as an active cluster. Finally the periodic
        ``cleanup_database`` task is scheduled on the task pool.
        """
        self.db.load_database_state()

        active_clusters = list(self.db.active_clusters())
        if active_clusters:
            self.log.info(
                "Gateway was stopped with %d active clusters, "
                "checking cluster status...",
                len(active_clusters),
            )

            # return_exceptions=True: one failing check must not abort the
            # others; failures come back as items of ``results`` instead.
            tasks = (self.check_cluster(c) for c in active_clusters)
            results = await asyncio.gather(*tasks, return_exceptions=True)

            n_clusters = 0
            for c, r in zip(active_clusters, results):
                # Test against BaseException, not Exception: a cancelled
                # check is returned as asyncio.CancelledError, which is not
                # an Exception subclass on Python 3.8+ and would otherwise
                # be counted as an active cluster by the truthiness test.
                if isinstance(r, BaseException):
                    self.log.error("Error while checking status of cluster %s",
                                   c.name,
                                   exc_info=r)
                elif r:
                    n_clusters += 1

            self.log.info(
                "All clusters have been checked, there are %d active clusters",
                n_clusters,
            )

        self.task_pool.create_background_task(self.cleanup_database())

    async def cleanup_database(self):
        """Purge expired records from the database forever, at a fixed period.

        Each pass calls ``self.db.cleanup_expired(self.db_cluster_max_age)``.
        A failing pass is logged and does not stop the loop. The coroutine
        never returns on its own.
        """
        while True:
            try:
                removed = self.db.cleanup_expired(self.db_cluster_max_age)
            except Exception as err:
                self.log.error("Error while cleaning expired database records",
                               exc_info=err)
            else:
                self.log.debug("Removed %d expired clusters from the database",
                               removed)
            # Sleep outside the try block so cancellation is never swallowed.
            await asyncio.sleep(self.db_cleanup_period)

    async def check_cluster(self, cluster):
        """Reconcile a cluster loaded from the database with its live state.

        A fresh manager is attached to the cluster. If the cluster was recorded
        as RUNNING, its ``/api/status`` endpoint is queried (bounded by
        ``self.check_cluster_timeout``) to learn which workers exist. Any
        failure of that request, and any other recorded status, is treated as
        "not running".

        A running cluster is re-registered with the proxies, its workers are
        reconciled against the reported ones, and its status monitor is
        started. Otherwise the cluster is stopped and marked failed.

        Returns True if the cluster is running, False otherwise.
        """
        cluster.manager = self.create_cluster_manager(cluster.options)
        self.init_cluster_manager(cluster.manager, cluster)

        # Default outcome; also covers a gateway that was stopped before the
        # cluster had fully started.
        running = False
        workers = []

        if cluster.status == ClusterStatus.RUNNING:
            http_client = AsyncHTTPClient()
            status_url = "%s/api/status" % cluster.api_address
            request = HTTPRequest(
                status_url,
                method="GET",
                headers={"Authorization": "token %s" % cluster.token},
            )
            try:
                response = await asyncio.wait_for(
                    http_client.fetch(request),
                    timeout=self.check_cluster_timeout,
                )
                payload = json.loads(response.body.decode("utf8", "replace"))
                workers = payload["workers"]
                running = True
            except asyncio.CancelledError:
                raise
            except Exception:
                # Unreachable or malformed reply: consider the cluster gone.
                running, workers = False, []

        if not running:
            # Cluster is not available, shut it down
            await self.stop_cluster(cluster, failed=True)
            return running

        # Cluster is running, update our state to match
        await self.add_cluster_to_proxies(cluster)

        # Workers the scheduler reports are marked running; the rest are
        # stopped as failed.
        reported = set(workers)
        missing = []
        for worker in cluster.active_workers():
            if worker.name in reported:
                self.mark_worker_running(cluster, worker)
            else:
                missing.append(worker)

        stop_jobs = (self.stop_worker(cluster, worker, failed=True)
                     for worker in missing)
        await asyncio.gather(*stop_jobs, return_exceptions=False)

        # Start the periodic monitor
        self.start_cluster_status_monitor(cluster)
        return running

    async def start_tornado_application(self):
        """Start serving the private API and expose it through the web proxy.

        The tornado application listens on the private bind address, a proxy
        route for ``<public path>/gateway/`` is added pointing at it, and the
        public and proxy addresses are logged.
        """
        self.http_server = self.tornado_application.listen(
            self.private_urls.bind_port, address=self.private_urls.bind_host)
        self.log.info("Gateway private API serving at %s",
                      self.private_urls.bind_url)

        gateway_route = self.public_urls.connect.path + "/gateway/"
        await self.web_proxy.add_route(gateway_route,
                                       self.private_urls.connect_url)
        self.log.info("Dask-Gateway started successfully!")

        advertised = (
            ("Public address", self.public_urls._to_log),
            ("Proxy address", self.gateway_urls._to_log),
        )
        for name, urls in advertised:
            # Pick the template that fits one or two URLs.
            if len(urls) == 2:
                template = "- %s at %s or %s"
            else:
                template = "- %s at %s"
            self.log.info(template, name, *urls)

    async def start_or_exit(self):
        """Run ``start_async``; if it fails, shut down and exit with status 1.

        On failure the error is logged, ``stop_async`` is awaited without
        stopping the event loop, and ``self.exit(1)`` is called.
        """
        try:
            await self.start_async()
        except Exception:
            self.log.critical(
                "Failed to start gateway, shutting down",
                exc_info=True,
            )
            # The event loop is left running here; exit(1) follows directly.
            await self.stop_async(stop_event_loop=False)
            self.exit(1)

    def start(self):
        """Blocking entry point: run a subcommand, or the gateway itself.

        If a subapp is configured its ``start()`` result is returned.
        Otherwise the asyncio-backed tornado loop is installed,
        ``start_or_exit`` is queued on it, and the loop runs until stopped.
        A keyboard interrupt ends the loop quietly.
        """
        subapp = self.subapp
        if subapp is not None:
            return subapp.start()

        AsyncIOMainLoop().install()
        io_loop = IOLoop.current()
        io_loop.add_callback(self.start_or_exit)
        try:
            io_loop.start()
        except KeyboardInterrupt:
            print("\nInterrupted")

    def init_signal(self):
        """Register ``handle_shutdown_signal`` for SIGTERM and SIGINT on the event loop."""
        event_loop = asyncio.get_event_loop()
        shutdown_signals = (signal.SIGTERM, signal.SIGINT)
        for signum in shutdown_signals:
            # The signal is passed back as the handler's only argument.
            event_loop.add_signal_handler(signum, self.handle_shutdown_signal,
                                          signum)

    def handle_shutdown_signal(self, sig):
        """Log the received signal and schedule ``stop_async`` on the event loop.

        ``sig`` must expose a ``name`` attribute (e.g. a ``signal.Signals``
        member). The shutdown runs as a background future and is not awaited.
        """
        signal_name = sig.name
        self.log.info("Received signal %s, initiating shutdown...",
                      signal_name)
        shutdown = self.stop_async()
        asyncio.ensure_future(shutdown)

    async def _stop_async(self, timeout=5):
        """Perform the ordered shutdown steps.

        1. Stop the HTTP server, if one was started.
        2. If ``stop_clusters_on_shutdown`` is set, stop every active cluster
           (marking it failed) and log each cluster that could not be stopped.
        3. Close the task pool, passing ``timeout`` through.
        4. Stop the scheduler and web proxies.

        Components that were never created (no such attribute) are skipped.
        """
        # Stop the server to prevent new requests
        if hasattr(self, "http_server"):
            self.http_server.stop()

        if not self.stop_clusters_on_shutdown:
            self.log.info("Leaving any active clusters running")
        else:
            # Map each stop task back to its cluster for error reporting.
            stop_tasks = {}
            for cluster in self.db.active_clusters():
                task = asyncio.ensure_future(
                    self.stop_cluster(cluster, failed=True))
                stop_tasks[task] = cluster

            if stop_tasks:
                self.log.info("Stopping all active clusters...")
                finished, _unfinished = await asyncio.wait(stop_tasks.keys())
                for task in finished:
                    try:
                        await task
                    except Exception as err:
                        failed_cluster = stop_tasks[task]
                        self.log.error(
                            "Failed to stop cluster %s for user %s",
                            failed_cluster.name,
                            failed_cluster.user.name,
                            exc_info=err,
                        )

        if hasattr(self, "task_pool"):
            await self.task_pool.close(timeout=timeout)

        # Shutdown the proxies
        for proxy_attr in ("scheduler_proxy", "web_proxy"):
            if hasattr(self, proxy_attr):
                getattr(self, proxy_attr).stop()

    async def stop_async(self, timeout=5, stop_event_loop=True):
        """Shut the gateway down, never raising on shutdown errors.

        Parameters
        ----------
        timeout :
            Forwarded to ``_stop_async``.
        stop_event_loop : bool
            When true, the current tornado IOLoop is stopped afterwards.
        """
        try:
            await self._stop_async(timeout=timeout)
        except Exception:
            self.log.error("Error while shutting down:", exc_info=True)

        if not stop_event_loop:
            return
        # Stop the event loop
        IOLoop.current().stop()

    def start_cluster_status_monitor(self, cluster):
        """Launch the periodic status check for ``cluster`` as a background task.

        The task handle is stored on ``cluster._status_monitor`` so that
        ``stop_cluster_status_monitor`` can cancel it later.
        """
        monitor_job = self._cluster_status_monitor(cluster)
        cluster._status_monitor = self.task_pool.create_background_task(
            monitor_job)

    def stop_cluster_status_monitor(self, cluster):
        """Cancel the cluster's status monitor task, if one is attached.

        Does nothing when ``cluster._status_monitor`` is already ``None``.
        """
        monitor = cluster._status_monitor
        if monitor is None:
            return
        monitor.cancel()
        cluster._status_monitor = None

    async def _cluster_status_monitor(self, cluster):
        """Poll the manager for the cluster's status until it reports a stop.

        ``cluster.manager.cluster_status`` may return either a bare flag or a
        ``(flag, message)`` tuple. When the flag is false, a warning is logged,
        the cluster is scheduled to be stopped as failed, and the monitor
        exits. Errors raised by the status call are logged and polling goes
        on. Cancellation is propagated.
        """
        while True:
            try:
                outcome = await cluster.manager.cluster_status(cluster.state)
            except asyncio.CancelledError:
                raise
            except Exception as err:
                self.log.error("Error while checking cluster %s status",
                               cluster.name,
                               exc_info=err)
            else:
                if isinstance(outcome, tuple):
                    running, msg = outcome
                else:
                    running, msg = outcome, None

                if not running:
                    if msg:
                        self.log.warning("Cluster %s stopped unexpectedly: %s",
                                         cluster.name, msg)
                    else:
                        self.log.warning("Cluster %s stopped unexpectedly",
                                         cluster.name)
                    self.schedule_stop_cluster(cluster, failed=True)
                    return
            await asyncio.sleep(cluster.manager.cluster_status_period)

    async def _start_cluster(self, cluster):
        """Drive the manager's cluster startup, persisting each state update.

        ``cluster.manager.start_cluster()`` is consumed as an async iterator.
        Every yielded state is written to the database, and once the iterator
        is exhausted the cluster is marked STARTED.
        """
        self.log.info("Starting cluster %s for user %s...", cluster.name,
                      cluster.user.name)

        # Saving after each step records partial startup state as it happens.
        startup_states = cluster.manager.start_cluster()
        async for new_state in startup_states:
            self.log.debug("State update for cluster %s", cluster.name)
            self.db.update_cluster(cluster, state=new_state)

        # Move cluster to started
        self.db.update_cluster(cluster, status=ClusterStatus.STARTED)

    async def start_cluster(self, cluster):
        """Start ``cluster`` and wait for its scheduler to connect back.

        The startup, and the wait on ``cluster._connect_future``, must finish
        within ``cluster.manager.cluster_start_timeout``. On success the
        reported addresses are stored, the cluster is marked RUNNING and its
        proxy routes are registered.

        Returns True if successfully started, False on timeout or error.
        Cancellation is propagated to the caller.
        """
        try:
            async with timeout(cluster.manager.cluster_start_timeout):
                await self._start_cluster(cluster)
                self.log.info("Cluster %s has started, waiting for connection",
                              cluster.name)
                self.start_cluster_status_monitor(cluster)
                connection_info = await cluster._connect_future
        except asyncio.CancelledError:
            # Catch separately to avoid in generic handler below
            raise
        except asyncio.TimeoutError:
            self.log.warning(
                "Cluster %s startup timed out after %.1f seconds",
                cluster.name,
                cluster.manager.cluster_start_timeout,
            )
            return False
        except Exception as err:
            self.log.error("Error while starting cluster %s",
                           cluster.name,
                           exc_info=err)
            return False

        scheduler_address, dashboard_address, api_address = connection_info
        self.log.info("Cluster %s connected at %s", cluster.name,
                      scheduler_address)

        # Mark cluster as running
        running_fields = dict(
            scheduler_address=scheduler_address,
            dashboard_address=dashboard_address,
            api_address=api_address,
            status=ClusterStatus.RUNNING,
        )
        self.db.update_cluster(cluster, **running_fields)

        # Register routes with proxies
        await self.add_cluster_to_proxies(cluster)

        return True

    async def add_cluster_to_proxies(self, cluster):
        """Register the cluster's routes with the web and scheduler proxies.

        The dashboard route ``<public path>/gateway/clusters/<name>`` is only
        added when the cluster has a dashboard address. The scheduler route
        ``/<name>`` is always added.
        """
        dashboard = cluster.dashboard_address
        if dashboard:
            dashboard_route = (self.public_urls.connect.path +
                               "/gateway/clusters/" + cluster.name)
            await self.web_proxy.add_route(dashboard_route, dashboard)

        scheduler_route = "/" + cluster.name
        await self.scheduler_proxy.add_route(scheduler_route,
                                             cluster.scheduler_address)

    def start_new_cluster(self, user, request):
        """Create a cluster record for ``user`` and launch its startup task.

        The options in ``request`` are parsed and turned into a manager, the
        user's cluster limits are checked against the manager's scheduler
        resources, the cluster is created in the database, and
        ``start_cluster`` is scheduled on the task pool.

        Returns the new cluster; startup continues in the background.

        Raises
        ------
        Exception
            With the limits message, if the user may not create the cluster.
        """
        # Process the user provided options
        options = self.cluster_manager_options.parse_options(request)
        manager = self.create_cluster_manager(options)
        memory = manager.scheduler_memory
        cores = manager.scheduler_cores

        # Check if allowed to create cluster
        allowed, msg = self.user_limits.check_cluster_limits(
            user, memory, cores)
        if not allowed:
            raise Exception(msg)

        # Finish initializing the object states
        cluster = self.db.create_cluster(user, options, memory, cores)
        cluster.manager = manager
        self.init_cluster_manager(cluster.manager, cluster)

        # Launch the startup task; the callback cleans up after a failed start.
        start_future = self.task_pool.create_task(self.start_cluster(cluster))
        on_done = partial(self._monitor_start_cluster, cluster=cluster)
        start_future.add_done_callback(on_done)
        cluster._start_future = start_future

        return cluster

    def _monitor_start_cluster(self, future, cluster=None):
        """Done-callback for a cluster start task.

        A truthy result (successful start) and a cancelled start need no
        action here. A falsy result or an unexpected exception schedules the
        cluster to be stopped as failed.
        """
        try:
            succeeded = future.result()
        except asyncio.CancelledError:
            # Startup cancelled, cleanup is handled separately
            return
        except Exception as err:
            self.log.error("Unexpected error while starting cluster %s",
                           cluster.name,
                           exc_info=err)
        else:
            if succeeded:
                # Startup succeeded, nothing to do
                return

        self.schedule_stop_cluster(cluster, failed=True)

    async def stop_cluster(self, cluster, failed=False):
        """Stop ``cluster`` and record its final status.

        Does nothing if the cluster is already stopping or stopped. Otherwise
        the cluster is moved to STOPPING, its status monitor and any in-flight
        start task are cancelled, its proxy routes are removed, its workers
        and the cluster itself are shut down through the manager, and the
        final status is written to the database.

        Parameters
        ----------
        cluster :
            The cluster to stop.
        failed : bool
            Record the final status as FAILED instead of STOPPED.
        """
        if cluster.status >= ClusterStatus.STOPPING:
            return

        self.log.info("Stopping cluster %s...", cluster.name)

        # Move cluster to stopping
        self.db.update_cluster(cluster, status=ClusterStatus.STOPPING)

        # Stop the periodic monitor, if present
        self.stop_cluster_status_monitor(cluster)

        # If running, cancel running start task
        await cancel_task(cluster._start_future)

        # Remove routes from proxies if already set. The dashboard route must
        # be built exactly as add_cluster_to_proxies registers it, including
        # the public URL path prefix; without the prefix the route would never
        # be deleted when the gateway is served under a sub-path.
        await self.web_proxy.delete_route(
            self.public_urls.connect.path + "/gateway/clusters/" +
            cluster.name)
        await self.scheduler_proxy.delete_route("/" + cluster.name)

        # Shutdown individual workers if no bulk shutdown supported
        if not cluster.manager.supports_bulk_shutdown:
            tasks = (self.stop_worker(cluster, w)
                     for w in cluster.active_workers())
            await asyncio.gather(*tasks, return_exceptions=True)

        # Shutdown the cluster. A failure is logged only, so the final status
        # is still recorded below.
        try:
            await cluster.manager.stop_cluster(cluster.state)
        except Exception as exc:
            self.log.error("Failed to shutdown cluster %s",
                           cluster.name,
                           exc_info=exc)

        # If we shut the workers down in bulk, cleanup their internal state now
        if cluster.manager.supports_bulk_shutdown:
            tasks = (self.stop_worker(cluster, w)
                     for w in cluster.active_workers())
            await asyncio.gather(*tasks, return_exceptions=True)

        # Update the cluster status
        status = ClusterStatus.FAILED if failed else ClusterStatus.STOPPED
        self.db.update_cluster(cluster, status=status, stop_time=timestamp())
        cluster.pending.clear()

        self.log.info("Stopped cluster %s", cluster.name)

    def schedule_stop_cluster(self, cluster, failed=False):
        """Queue ``stop_cluster(cluster, failed=failed)`` on the task pool without waiting."""
        stop_job = self.stop_cluster(cluster, failed=failed)
        self.task_pool.create_task(stop_job)

    async def scale(self, cluster, total):
        """Scale ``cluster`` towards ``total`` workers.

        Runs under ``cluster.lock``. Returns a ``(worker_count, message)``
        tuple: the active worker count expected after the change, and the
        message from ``scale_up`` when scaling up (``None`` otherwise).
        """
        async with cluster.lock:
            current = len(cluster.active_workers())
            delta = total - current
            if delta == 0:
                return current, None

            self.log.info(
                "Scaling cluster %s to %d workers, a delta of %d",
                cluster.name,
                total,
                delta,
            )
            if delta > 0:
                change, msg = self.scale_up(cluster, delta)
            else:
                n_removed = await self.scale_down(cluster, -delta)
                change, msg = -n_removed, None
            return current + change, msg

    def scale_up(self, cluster, n_start):
        """Create and start up to ``n_start`` new workers for ``cluster``.

        The user limits decide how many of the requested workers may actually
        be started. For each allowed worker a database record is created and
        a start task is scheduled.

        Returns ``(n_allowed, msg)`` as reported by the limits check.
        """
        manager = cluster.manager
        # Check how many workers we're allowed to start
        n_allowed, msg = self.user_limits.check_scale_limits(
            cluster,
            n_start,
            memory=manager.worker_memory,
            cores=manager.worker_cores,
        )
        for _ in range(n_allowed):
            worker = self.db.create_worker(cluster,
                                           cluster.manager.worker_memory,
                                           cluster.manager.worker_cores)
            start_future = self.task_pool.create_task(
                self.start_worker(cluster, worker))
            worker._start_future = start_future
            # The callback cleans up after a failed worker start.
            on_done = partial(self._monitor_start_worker,
                              worker=worker,
                              cluster=cluster)
            start_future.add_done_callback(on_done)
        return n_allowed, msg

    async def _start_worker(self, cluster, worker):
        """Drive the manager's worker startup, persisting each state update.

        ``cluster.manager.start_worker(...)`` is consumed as an async
        iterator. Every yielded state is written to the database, and once
        the iterator is exhausted the worker is marked STARTED.
        """
        self.log.info("Starting worker %s for cluster %s...", worker.name,
                      cluster.name)

        # Saving after each step records partial startup state as it happens.
        startup_states = cluster.manager.start_worker(worker.name,
                                                      cluster.state)
        async for new_state in startup_states:
            self.db.update_worker(worker, state=new_state)

        # Move worker to started
        self.db.update_worker(worker, status=WorkerStatus.STARTED)

    async def _worker_status_monitor(self, cluster, worker):
        """Poll the manager for the worker's status until it reports a stop.

        ``cluster.manager.worker_status`` may return either a bare flag or a
        ``(flag, message)`` tuple. The coroutine returns the message (or
        ``None``) as soon as the flag is false. Errors raised by the status
        call are logged and polling goes on. Cancellation is propagated.
        """
        while True:
            try:
                outcome = await cluster.manager.worker_status(
                    worker.name, worker.state, cluster.state)
            except asyncio.CancelledError:
                raise
            except Exception as err:
                self.log.error("Error while checking worker %s status",
                               worker.name,
                               exc_info=err)
            else:
                if isinstance(outcome, tuple):
                    running, msg = outcome
                else:
                    running, msg = outcome, None
                if not running:
                    return msg
            await asyncio.sleep(cluster.manager.worker_status_period)

    async def start_worker(self, cluster, worker):
        """Start ``worker`` and wait for it to connect to its cluster.

        The worker is submitted through the manager, then this coroutine
        waits for whichever happens first: ``worker._connect_future``
        completes, or the background status monitor reports that the worker
        stopped. Everything must finish within
        ``cluster.manager.worker_start_timeout``.

        Returns True if the worker connected and was marked running, False on
        timeout, error, or failure reported by the monitor. Cancellation is
        propagated to the caller.
        """
        worker_status_monitor = None
        try:
            async with timeout(cluster.manager.worker_start_timeout):
                # Submit the worker
                await self._start_worker(cluster, worker)

                self.log.info("Worker %s has started, waiting for connection",
                              worker.name)

                # Wait for the worker to connect, periodically checking its status
                worker_status_monitor = asyncio.ensure_future(
                    self._worker_status_monitor(cluster, worker))
                done, pending = await asyncio.wait(
                    (worker_status_monitor, worker._connect_future),
                    return_when=asyncio.FIRST_COMPLETED,
                )
        except BaseException as exc:
            # Catch BaseException so the monitor is also cancelled when this
            # task itself is cancelled: asyncio.CancelledError has not been an
            # Exception subclass since Python 3.8, and the monitor would
            # otherwise keep polling forever.
            if worker_status_monitor is not None:
                worker_status_monitor.cancel()
            if (isinstance(exc, asyncio.CancelledError)
                    or not isinstance(exc, Exception)):
                # Cancellation, KeyboardInterrupt, SystemExit: not ours to
                # swallow.
                raise
            if isinstance(exc, asyncio.TimeoutError):
                self.log.warning(
                    "Worker %s startup timed out after %.1f seconds",
                    worker.name,
                    cluster.manager.worker_start_timeout,
                )
            else:
                self.log.error("Error while starting worker %s",
                               worker.name,
                               exc_info=exc)
            return False

        # Check monitor for failures
        if worker_status_monitor in done:
            # Failure occurred
            msg = worker_status_monitor.result()
            if msg:
                self.log.warning("Worker %s failed during startup: %s",
                                 worker.name, msg)
            else:
                self.log.warning("Worker %s failed during startup",
                                 worker.name)
            return False
        else:
            # Worker connected first; the monitor is no longer needed.
            worker_status_monitor.cancel()

        self.log.info("Worker %s connected to cluster %s", worker.name,
                      cluster.name)

        # Mark worker as running
        self.mark_worker_running(cluster, worker)

        return True

    def mark_worker_running(self, cluster, worker):
        """Record that ``worker`` is running, unless it is already marked so.

        Notifies the manager via ``on_worker_running``, updates the database
        status, and removes the worker from the cluster's pending set.
        """
        if worker.status == WorkerStatus.RUNNING:
            return

        manager = cluster.manager
        manager.on_worker_running(worker.name, worker.state, cluster.state)
        self.db.update_worker(worker, status=WorkerStatus.RUNNING)
        cluster.pending.discard(worker.name)

    def _monitor_start_worker(self, future, worker=None, cluster=None):
        """Done-callback for a worker start task.

        A truthy result (successful start) needs no action. A cancelled start
        is only logged at debug level. A falsy result or an unexpected
        exception schedules the worker to be stopped as failed.
        """
        try:
            succeeded = future.result()
        except asyncio.CancelledError:
            # Startup cancelled, cleanup is handled separately
            self.log.debug("Cancelled worker %s", worker.name)
            return
        except Exception as err:
            self.log.error(
                "Unexpected error while starting worker %s for cluster %s",
                worker.name,
                cluster.name,
                exc_info=err,
            )
        else:
            if succeeded:
                # Startup succeeded, nothing to do
                return

        self.schedule_stop_worker(cluster, worker, failed=True)

    async def stop_worker(self, cluster, worker, failed=False):
        """Stop ``worker`` and record its final status.

        Does nothing if the worker is already stopping or stopped. Otherwise
        the worker is moved to STOPPING, any in-flight start task is
        cancelled, the manager is asked to stop it (only when the manager
        does not support bulk shutdown), and the final status is written to
        the database.

        Parameters
        ----------
        cluster :
            The cluster the worker belongs to.
        worker :
            The worker to stop.
        failed : bool
            Record the final status as FAILED instead of STOPPED.
        """
        # Already stopping elsewhere, return
        if worker.status >= WorkerStatus.STOPPING:
            return

        self.log.info("Stopping worker %s for cluster %s", worker.name,
                      cluster.name)

        # Move worker to stopping
        self.db.update_worker(worker, status=WorkerStatus.STOPPING)
        cluster.pending.discard(worker.name)

        # Cancel a pending start if needed
        await cancel_task(worker._start_future)

        # With bulk shutdown the manager is not asked to stop workers one by
        # one; only the bookkeeping below runs.
        stop_individually = not cluster.manager.supports_bulk_shutdown
        if stop_individually:
            try:
                await cluster.manager.stop_worker(worker.name, worker.state,
                                                  cluster.state)
            except Exception as err:
                self.log.error(
                    "Failed to shutdown worker %s for cluster %s",
                    worker.name,
                    cluster.name,
                    exc_info=err,
                )

        # Update the worker status
        if failed:
            final_status = WorkerStatus.FAILED
        else:
            final_status = WorkerStatus.STOPPED
        self.db.update_worker(worker,
                              status=final_status,
                              stop_time=timestamp())

        self.log.info("Stopped worker %s", worker.name)

    def schedule_stop_worker(self, cluster, worker, failed=False):
        """Queue ``stop_worker(cluster, worker, failed=failed)`` on the task pool without waiting."""
        stop_job = self.stop_worker(cluster, worker, failed=failed)
        self.task_pool.create_task(stop_job)

    def maybe_fail_worker(self, cluster, worker):
        """Schedule ``worker`` to be stopped as failed, if that still makes sense.

        Nothing happens unless the cluster is RUNNING and the worker has not
        already reached the STOPPING state (or a later one).
        """
        cluster_running = cluster.status == ClusterStatus.RUNNING
        if cluster_running and not worker.status >= WorkerStatus.STOPPING:
            self.schedule_stop_worker(cluster, worker, failed=True)

    async def scale_down(self, cluster, n_stop):
        """Stop up to ``n_stop`` workers of ``cluster``.

        Pending workers (tracked in ``cluster.pending``) are stopped first.
        If more workers still need to go, the scheduler is asked through its
        ``/api/scale_down`` endpoint which running workers to close, and
        those are stopped as well.

        Returns the total number of workers scheduled to stop, pending and
        running combined. This is always an ``int``, including ``0``: the
        caller negates the value, so ``None`` must never be returned.
        """
        n_stopped = 0

        if cluster.pending:
            if len(cluster.pending) > n_stop:
                names = [cluster.pending.pop() for _ in range(n_stop)]
            else:
                names = list(cluster.pending)
                cluster.pending.clear()
            to_stop = [cluster.workers[n] for n in names]

            self.log.debug("Stopping %d pending workers for cluster %s",
                           len(to_stop), cluster.name)
            for w in to_stop:
                self.schedule_stop_worker(cluster, w)
            n_stop -= len(to_stop)
            n_stopped += len(to_stop)

        if n_stop > 0:
            # Request scheduler shutdown n_stop workers
            client = AsyncHTTPClient()
            body = json.dumps({"remove_count": n_stop})
            url = "%s/api/scale_down" % cluster.api_address
            req = HTTPRequest(
                url,
                method="POST",
                headers={
                    "Authorization": "token %s" % cluster.token,
                    "Content-type": "application/json",
                },
                body=body,
            )
            resp = await client.fetch(req)
            data = json.loads(resp.body.decode("utf8", "replace"))
            # The scheduler decides which workers it actually closed.
            to_stop = [cluster.workers[n] for n in data["workers_closed"]]

            self.log.debug("Stopping %d running workers for cluster %s",
                           len(to_stop), cluster.name)
            for w in to_stop:
                self.schedule_stop_worker(cluster, w)
            n_stopped += len(to_stop)

        return n_stopped
Exemple #28
0
class Axis(BaseAxis):
    """A line axis.

    A line axis is the visual representation of a numerical or date scale.

    Attributes
    ----------
    icon: string (class-level attribute)
        The font-awesome icon name for this object.
    axis_types: dict (class-level attribute)
        A registry of existing axis types (not defined in this class body).
    orientation: {'horizontal', 'vertical'} (default: 'horizontal')
        The orientation of the axis, either vertical or horizontal.
    side: {'bottom', 'top', 'left', 'right'} or None (default: None)
        The side of the axis, either bottom, top, left or right.
    label: string (default: '')
        The axis label.
    tick_format: string or None (default: None)
        The tick format for the axis, for dates use d3 string formatting.
    scale: Scale
        The scale represented by the axis.
    num_ticks: int or None (default: None)
        If tick_values is None, number of ticks.
    tick_values: numpy.ndarray or None (default: None)
        Tick values for the axis.
    offset: dict (default: {})
        Contains a scale and a value {'scale': scale or None,
        'value': value of the offset}.
        If offset['scale'] is None, the corresponding figure scale is used
        instead.
    label_location: {'middle', 'start', 'end'} (default: 'middle')
        The location of the label along the axis, one of 'start', 'end' or
        'middle'.
    label_color: Color or None (default: None)
        The color of the axis label.
    grid_lines: {'none', 'solid', 'dashed'} (default: 'solid')
        The display of the grid lines.
    grid_color: Color or None (default: None)
        The color of the grid lines.
    color: Color or None (default: None)
        The color of the line.
    label_offset: string or None (default: None)
        Label displacement from the axis line. Units allowed are 'em', 'px'
        and 'ex'. Positive values are away from the figure and negative
        values are towards the figure with respect to the axis line.
    visible: bool (default: True)
        A visibility toggle for the axis.
    tick_style: Dict (default: {})
        Dictionary containing the CSS-style of the text for the ticks.
        For example: font-size of the text can be changed by passing
        `{'font-size': 14}`.
    tick_rotate: int (default: 0)
        Degrees to rotate tick labels by.
    """
    icon = 'fa-arrows'

    # Placement and labelling
    orientation = Enum(
        ['horizontal', 'vertical'],
        default_value='horizontal',
    ).tag(sync=True)
    side = Enum(
        ['bottom', 'top', 'left', 'right'],
        allow_none=True,
        default_value=None,
    ).tag(sync=True)
    label = Unicode().tag(sync=True)
    grid_lines = Enum(
        ['none', 'solid', 'dashed'],
        default_value='solid',
    ).tag(sync=True)

    # Scale and ticks
    tick_format = Unicode(None, allow_none=True).tag(sync=True)
    scale = Instance(Scale).tag(sync=True, **widget_serialization)
    num_ticks = Int(default_value=None, allow_none=True).tag(sync=True)
    tick_values = (
        Array(None, allow_none=True)
        .tag(sync=True, **array_serialization)
        .valid(array_dimension_bounds(1, 1))
    )
    offset = Dict().tag(sync=True, **widget_serialization)

    # Label position and colours
    label_location = Enum(
        ['middle', 'start', 'end'],
        default_value='middle',
    ).tag(sync=True)
    label_color = Color(None, allow_none=True).tag(sync=True)
    grid_color = Color(None, allow_none=True).tag(sync=True)
    color = Color(None, allow_none=True).tag(sync=True)
    label_offset = Unicode(default_value=None, allow_none=True).tag(sync=True)

    visible = Bool(True).tag(sync=True)
    tick_style = Dict().tag(sync=True)
    tick_rotate = Int(0).tag(sync=True)

    # Front-end view/model names
    _view_name = Unicode('Axis').tag(sync=True)
    _model_name = Unicode('AxisModel').tag(sync=True)
    _ipython_display_ = None  # We cannot display an axis outside of a figure.
Exemple #29
0
class IPythonKernel(KernelBase):
    # The interactive shell driving this kernel; assigned in __init__.
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
    # Class whose ``instance()`` constructor __init__ uses to build the shell.
    shell_class = Type(ZMQInteractiveShell)

    # Configurable switch consulted by do_complete before it delegates to
    # _experimental_do_complete.
    use_experimental_completions = Bool(
        True,
        help=
        "Set this flag to False to deactivate the use of experimental IPython completion APIs.",
    ).tag(config=True)

    # Declared as a trait only when debugpy is available; otherwise the
    # class attribute is the plain value None.
    debugpy_stream = Instance(
        ZMQStream, allow_none=True) if _is_debugpy_available else None

    # Changes are forwarded to the shell by _user_module_changed.
    user_module = Any()

    @observe('user_module')
    @observe_compat
    def _user_module_changed(self, change):
        if self.shell is not None:
            self.shell.user_module = change['new']

    user_ns = Instance(dict, args=None, allow_none=True)

    @observe('user_ns')
    @observe_compat
    def _user_ns_changed(self, change):
        if self.shell is not None:
            self.shell.user_ns = change['new']
            self.shell.init_user_ns()

    # A reference to the Python builtin 'raw_input' function.
    # (i.e., __builtin__.raw_input for Python 2.7, builtins.input for Python 3)
    _sys_raw_input = Any()
    _sys_eval_input = Any()

    def __init__(self, **kwargs):
        """Create the kernel and wire up its debugger, shell and comm manager.

        All keyword arguments are passed unchanged to the parent constructor.
        """
        super(IPythonKernel, self).__init__(**kwargs)

        # Initialize the Debugger
        if _is_debugpy_available:
            self.debugger = Debugger(self.log, self.debugpy_stream,
                                     self._publish_debug_event,
                                     self.debug_shell_socket, self.session)

        # Initialize the InteractiveShell subclass
        self.shell = self.shell_class.instance(
            parent=self,
            profile_dir=self.profile_dir,
            user_module=self.user_module,
            user_ns=self.user_ns,
            kernel=self,
            compiler_class=XCachingCompiler,
        )
        # Point the shell's display hook and display publisher at this
        # kernel's session and IOPub socket.
        self.shell.displayhook.session = self.session
        self.shell.displayhook.pub_socket = self.iopub_socket
        self.shell.displayhook.topic = self._topic('execute_result')
        self.shell.display_pub.session = self.session
        self.shell.display_pub.pub_socket = self.iopub_socket

        self.comm_manager = CommManager(parent=self, kernel=self)

        # Register the comm manager with the shell and route each comm_*
        # shell message type to the comm manager method of the same name.
        self.shell.configurables.append(self.comm_manager)
        comm_msg_types = ['comm_open', 'comm_msg', 'comm_close']
        for msg_type in comm_msg_types:
            self.shell_handlers[msg_type] = getattr(self.comm_manager,
                                                    msg_type)

        if _use_appnope() and self._darwin_app_nap:
            # Disable app-nap as the kernel is not a gui but can have guis
            import appnope
            appnope.nope()

    # Configurable list of {'text', 'url'} help entries. The Python URL is
    # built from the major.minor version of the running interpreter.
    help_links = List([
        {
            'text': "Python Reference",
            'url': "https://docs.python.org/%i.%i" % sys.version_info[:2],
        },
        {
            'text': "IPython Reference",
            'url': "https://ipython.org/documentation.html",
        },
        {
            'text': "NumPy Reference",
            'url': "https://docs.scipy.org/doc/numpy/reference/",
        },
        {
            'text': "SciPy Reference",
            'url': "https://docs.scipy.org/doc/scipy/reference/",
        },
        {
            'text': "Matplotlib Reference",
            'url': "https://matplotlib.org/contents.html",
        },
        {
            'text': "SymPy Reference",
            'url': "http://docs.sympy.org/latest/index.html",
        },
        {
            'text': "pandas Reference",
            'url': "https://pandas.pydata.org/pandas-docs/stable/",
        },
    ]).tag(config=True)

    # Kernel info fields
    implementation = 'ipython'
    implementation_version = release.version
    # Interpreter version values are read from ``sys`` when the class body
    # runs; the pygments lexer name is always 'ipython3'.
    language_info = {
        'name': 'python',
        'version': sys.version.split()[0],
        'mimetype': 'text/x-python',
        'codemirror_mode': {
            'name': 'ipython',
            'version': sys.version_info[0]
        },
        'pygments_lexer': 'ipython%d' % 3,
        'nbconvert_exporter': 'python',
        'file_extension': '.py'
    }

    def dispatch_debugpy(self, msg):
        """Decode a message from the debugpy stream and hand it to the debugger.

        Does nothing when debugpy is not available.
        """
        if not _is_debugpy_available:
            return
        # The first frame is the socket id, so the payload is the second one.
        payload = msg[1].bytes.decode('utf-8')
        self.log.debug("Debugpy received: %s", payload)
        self.debugger.tcp_client.receive_dap_frame(payload)

    @property
    def banner(self):
        """The banner exposed by the underlying shell."""
        return self.shell.banner

    def start(self):
        """Reset the shell's exit flag, hook up debugpy if present, then start.

        When ``debugpy_stream`` is None a warning is logged and no receive
        callback is registered.
        """
        self.shell.exit_now = False
        stream = self.debugpy_stream
        if stream is not None:
            stream.on_recv(self.dispatch_debugpy, copy=False)
        else:
            self.log.warning(
                "debugpy_stream undefined, debugging will not be enabled")
        super(IPythonKernel, self).start()

    def set_parent(self, ident, parent, channel='shell'):
        """Record the parent message, forwarding it to the shell as well.

        The shell is only informed for messages on the 'shell' channel.
        """
        super(IPythonKernel, self).set_parent(ident, parent, channel)
        if channel != 'shell':
            return
        self.shell.set_parent(parent)

    def init_metadata(self, parent):
        """Build the metadata for a new execution request.

        Starts from the parent class's metadata and adds the
        ipyparallel-specific ``dependencies_met`` and ``engine`` entries.
        """
        metadata = super(IPythonKernel, self).init_metadata(parent)
        # FIXME: remove deprecated ipyparallel-specific code
        # This is required for ipyparallel < 5.0
        metadata.update(dependencies_met=True, engine=self.ident)
        return metadata

    def finish_metadata(self, parent, metadata, reply_content):
        """Complete ``metadata`` once an execution request has finished.

        Copies the reply status into ``metadata`` and clears
        ``dependencies_met`` when the reply is an ``UnmetDependency`` error.
        The same ``metadata`` object is returned.
        """
        # FIXME: remove deprecated ipyparallel-specific code
        # This is required by ipyparallel < 5.0
        status = reply_content["status"]
        metadata["status"] = status
        unmet_dependency = (status == "error"
                            and reply_content["ename"] == "UnmetDependency")
        if unmet_dependency:
            metadata["dependencies_met"] = False

        return metadata

    def _forward_input(self, allow_stdin=False):
        """Forward raw_input and getpass to the current frontend.

        via input_request
        """
        self._allow_stdin = allow_stdin

        self._sys_raw_input = builtins.input
        builtins.input = self.raw_input

        self._save_getpass = getpass.getpass
        getpass.getpass = self.getpass

    def _restore_input(self):
        """Restore raw_input, getpass"""
        builtins.input = self._sys_raw_input

        getpass.getpass = self._save_getpass

    @property
    def execution_count(self):
        """The shell's execution counter (read-only in practice)."""
        return self.shell.execution_count

    @execution_count.setter
    def execution_count(self, value):
        """Accept and discard assignments; the shell owns the counter."""
        # Ignore the incrementing done by KernelBase, in favour of our shell's
        # execution counter.
        pass

    @contextmanager
    def _cancel_on_sigint(self, future):
        """ContextManager for capturing SIGINT and cancelling a future

        SIGINT raises in the event loop when running async code,
        but we want it to halt a coroutine.

        Ideally, it would raise KeyboardInterrupt,
        but this turns it into a CancelledError.
        At least it gets a decent traceback to the user.

        Parameters
        ----------
        future
            The future to cancel if SIGINT arrives while the context is
            active.  The previous SIGINT handler is reinstated on exit.
        """
        # Completed (from the io loop) when SIGINT is received.
        sigint_future = asyncio.Future()

        # whichever future finishes first,
        # cancel the other one
        def cancel_unless_done(f, _ignored):
            # ``f`` is bound through partial(); ``_ignored`` is the future
            # whose completion triggered this done-callback.
            if f.cancelled() or f.done():
                return
            f.cancel()

        # when sigint finishes,
        # abort the coroutine with CancelledError
        sigint_future.add_done_callback(partial(cancel_unless_done, future))
        # when the main future finishes,
        # stop watching for SIGINT events
        future.add_done_callback(partial(cancel_unless_done, sigint_future))

        def handle_sigint(*args):
            def set_sigint_result():
                # Guard against a second SIGINT or an already-finished run.
                if sigint_future.cancelled() or sigint_future.done():
                    return
                sigint_future.set_result(1)

            # use add_callback for thread safety
            self.io_loop.add_callback(set_sigint_result)

        # set the custom sigint handler during this context
        save_sigint = signal.signal(signal.SIGINT, handle_sigint)
        try:
            yield
        finally:
            # restore the previous sigint handler
            signal.signal(signal.SIGINT, save_sigint)

    async def do_execute(self,
                         code,
                         silent,
                         store_history=True,
                         user_expressions=None,
                         allow_stdin=False):
        """Run ``code`` in the shell and build the execute reply content.

        ``input``/``getpass`` are redirected to the frontend for the duration
        of the run and always restored afterwards.  The cell is awaited as a
        coroutine when the asyncio runner is active and the shell says it
        should run asynchronously; otherwise the blocking ``shell.run_cell``
        is used.

        Returns
        -------
        dict
            Always contains ``status``, ``execution_count``,
            ``user_expressions`` and ``payload``.  On failure it additionally
            contains ``traceback``, ``ename``, ``evalue`` and ``engine_info``,
            and ``user_expressions`` is left empty.
        """
        shell = self.shell  # we'll need this a lot here

        self._forward_input(allow_stdin)

        reply_content = {}
        if hasattr(shell, 'run_cell_async') and hasattr(
                shell, 'should_run_async'):
            run_cell = shell.run_cell_async
            should_run_async = shell.should_run_async
        else:
            # older IPython: never run asynchronously.  The stand-in accepts
            # the keyword arguments passed at the call site below, so it
            # cannot fail with a TypeError if it is ever reached.
            def should_run_async(cell, **kwargs):
                return False

            # older IPython,
            # use blocking run_cell and wrap it in coroutine
            async def run_cell(*args, **kwargs):
                return shell.run_cell(*args, **kwargs)

        try:

            # default case: runner is asyncio and asyncio is already running
            # TODO: this should check every case for "are we inside the runner",
            # not just asyncio
            preprocessing_exc_tuple = None
            try:
                transformed_cell = self.shell.transform_cell(code)
            except Exception:
                # Keep the raw code and hand the failure to the shell, which
                # receives it through ``preprocessing_exc_tuple``.
                transformed_cell = code
                preprocessing_exc_tuple = sys.exc_info()

            if (_asyncio_runner and shell.loop_runner is _asyncio_runner
                    and asyncio.get_event_loop().is_running()
                    and should_run_async(
                        code,
                        transformed_cell=transformed_cell,
                        preprocessing_exc_tuple=preprocessing_exc_tuple,
                    )):
                coro = run_cell(
                    code,
                    store_history=store_history,
                    silent=silent,
                    transformed_cell=transformed_cell,
                    preprocessing_exc_tuple=preprocessing_exc_tuple)
                coro_future = asyncio.ensure_future(coro)

                with self._cancel_on_sigint(coro_future):
                    res = None
                    try:
                        res = await coro_future
                    finally:
                        # Fire the post-execution events even when the
                        # coroutine was cancelled or raised.
                        shell.events.trigger('post_execute')
                        if not silent:
                            shell.events.trigger('post_run_cell', res)
            else:
                # runner isn't already running,
                # make synchronous call,
                # letting shell dispatch to loop runners
                res = shell.run_cell(code,
                                     store_history=store_history,
                                     silent=silent)
        finally:
            self._restore_input()

        if res.error_before_exec is not None:
            err = res.error_before_exec
        else:
            err = res.error_in_exec

        if res.success:
            reply_content['status'] = 'ok'
        else:
            reply_content['status'] = 'error'

            reply_content.update({
                'traceback': shell._last_traceback or [],
                'ename': str(type(err).__name__),
                'evalue': str(err),
            })

            # FIXME: deprecated piece for ipyparallel (remove in 5.0):
            e_info = dict(engine_uuid=self.ident,
                          engine_id=self.int_id,
                          method='execute')
            reply_content['engine_info'] = e_info

        # Return the execution counter so clients can display prompts
        reply_content['execution_count'] = shell.execution_count - 1

        if 'traceback' in reply_content:
            self.log.info("Exception in execute request:\n%s",
                          '\n'.join(reply_content['traceback']))

        # At this point, we can tell whether the main code execution succeeded
        # or not.  If it did, we proceed to evaluate user_expressions
        if reply_content['status'] == 'ok':
            reply_content['user_expressions'] = \
                         shell.user_expressions(user_expressions or {})
        else:
            # If there was an error, don't even try to compute expressions
            reply_content['user_expressions'] = {}

        # Payloads should be retrieved regardless of outcome, so we can both
        # recover partial output (that could have been generated early in a
        # block, before an error) and always clear the payload system.
        reply_content['payload'] = shell.payload_manager.read_payload()
        # Be aggressive about clearing the payload because we don't want
        # it to sit in memory until the next execute_request comes in.
        shell.payload_manager.clear_payload()

        return reply_content

    def do_complete(self, code, cursor_pos):
        """Return completion matches for ``code`` at ``cursor_pos``.

        Delegates to ``_experimental_do_complete`` when the experimental
        completion API is usable and enabled; otherwise completes on the
        single line under the cursor.  A ``cursor_pos`` of None means the
        end of ``code``.
        """
        if _use_experimental_60_completion and self.use_experimental_completions:
            return self._experimental_do_complete(code, cursor_pos)

        # FIXME: IPython completers currently assume single line,
        # but completion messages give multi-line context
        # For now, extract line from cell, based on cursor_pos:
        if cursor_pos is None:
            cursor_pos = len(code)
        current_line, line_start = line_at_cursor(code, cursor_pos)
        cursor_in_line = cursor_pos - line_start

        matched_text, matches = self.shell.complete('', current_line,
                                                    cursor_in_line)
        reply = {
            'matches': matches,
            'cursor_end': cursor_pos,
            'cursor_start': cursor_pos - len(matched_text),
            'metadata': {},
            'status': 'ok'
        }
        return reply

    async def do_debug_request(self, msg):
        """Forward a debug request to the debugger; None without debugpy."""
        if not _is_debugpy_available:
            return None
        return await self.debugger.process_request(msg)

    def _experimental_do_complete(self, code, cursor_pos):
        """
        Experimental completions from IPython, using Jedi.

        Returns the usual completion reply, with the per-completion details
        (start, end, text, type) stored in the metadata under
        ``_EXPERIMENTAL_KEY_NAME``.  The reported cursor range is the one of
        the first completion, or ``cursor_pos`` when there are none.
        """
        if cursor_pos is None:
            cursor_pos = len(code)
        with _provisionalcompleter():
            raw_completions = self.shell.Completer.completions(
                code, cursor_pos)
            rectified = list(_rectify_completions(code, raw_completions))

            details = [{
                'start': item.start,
                'end': item.end,
                'text': item.text,
                'type': item.type,
            } for item in rectified]

        if not rectified:
            range_start = range_end = cursor_pos
            matches = []
        else:
            first = rectified[0]
            range_start = first.start
            range_end = first.end
            matches = [item.text for item in rectified]

        return {
            'matches': matches,
            'cursor_end': range_end,
            'cursor_start': range_start,
            'metadata': {
                _EXPERIMENTAL_KEY_NAME: details
            },
            'status': 'ok'
        }

    def do_inspect(self, code, cursor_pos, detail_level=0):
        """Inspect the token under the cursor and return its mime bundle.

        Parameters
        ----------
        code
            Source text containing the token to inspect.
        cursor_pos
            Cursor offset within ``code`` used to pick the token.
        detail_level
            Passed through to the shell's ``object_inspect_mime``.

        Returns
        -------
        dict
            ``status`` is always 'ok'.  ``found`` is False when the shell
            raises KeyError for the token, in which case ``data`` stays
            empty; otherwise ``found`` is True and ``data`` holds the bundle.
        """
        name = token_at_cursor(code, cursor_pos)

        reply_content = {'status': 'ok'}
        reply_content['data'] = {}
        reply_content['metadata'] = {}
        try:
            bundle = self.shell.object_inspect_mime(name,
                                                    detail_level=detail_level)
        except KeyError:
            reply_content['found'] = False
        else:
            reply_content['data'].update(bundle)
            if not self.shell.enable_html_pager:
                # Use a default so a bundle without an HTML entry is still
                # reported as found instead of tripping the KeyError path.
                reply_content['data'].pop('text/html', None)
            reply_content['found'] = True

        return reply_content

    def do_history(self,
                   hist_access_type,
                   output,
                   raw,
                   session=0,
                   start=0,
                   stop=None,
                   n=None,
                   pattern=None,
                   unique=False):
        """Fetch history entries from the shell's history manager.

        ``hist_access_type`` selects the query: 'tail', 'range' or 'search'.
        Any other value yields an empty history.  The reply is a dict with
        ``status`` 'ok' and the entries materialised as a list.
        """
        # Keyword arguments shared by all three history queries.
        shared = {'raw': raw, 'output': output}

        if hist_access_type == 'tail':
            entries = self.shell.history_manager.get_tail(
                n, include_latest=True, **shared)
        elif hist_access_type == 'range':
            entries = self.shell.history_manager.get_range(
                session, start, stop, **shared)
        elif hist_access_type == 'search':
            entries = self.shell.history_manager.search(
                pattern, n=n, unique=unique, **shared)
        else:
            entries = []

        return {'status': 'ok', 'history': list(entries)}

    def do_shutdown(self, restart):
        """Flag the shell to exit and echo ``restart`` back in the reply."""
        self.shell.exit_now = True
        return {'status': 'ok', 'restart': restart}

    def do_is_complete(self, code):
        """Report whether ``code`` is complete according to the shell.

        The reply always carries ``status``; when the status is
        'incomplete' it also carries ``indent``, a string of spaces.
        """
        shell = self.shell
        checker = getattr(shell, 'input_transformer_manager', None)
        if checker is None:
            # input_splitter attribute is deprecated
            checker = shell.input_splitter
        status, indent_spaces = checker.check_complete(code)
        if status != 'incomplete':
            return {'status': status}
        return {'status': status, 'indent': ' ' * indent_spaces}

    def do_apply(self, content, bufs, msg_id, reply_metadata):
        """Unpack a function call from ``bufs``, run it, and serialize the result.

        The function, its arguments and the result are placed in the user
        namespace under temporary names derived from ``msg_id``, the call is
        executed there, and the temporaries are removed again.

        Returns
        -------
        tuple
            ``(reply_content, result_buf)``.  On success ``reply_content`` is
            ``{'status': 'ok'}`` and ``result_buf`` is the serialized result.
            On any exception the traceback is shown and published as an
            'error' message, ``reply_content`` describes the error with
            status 'error', and ``result_buf`` is an empty list.
        """
        from .serialize import serialize_object, unpack_apply_message
        shell = self.shell
        try:
            working = shell.user_ns

            # Dashes are stripped from msg_id to build the temporary names.
            prefix = "_" + str(msg_id).replace("-", "") + "_"

            f, args, kwargs = unpack_apply_message(bufs, working, copy=False)

            fname = prefix + "f"
            argname = prefix + "args"
            kwargname = prefix + "kwargs"
            resultname = prefix + "result"

            ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
            working.update(ns)
            code = "%s = %s(*%s,**%s)" % (resultname, fname, argname,
                                          kwargname)
            try:
                exec(code, shell.user_global_ns, shell.user_ns)
                result = working.get(resultname)
            finally:
                for key in ns:
                    # The executed call may have removed a temporary itself;
                    # a KeyError here would mask the real result or error.
                    working.pop(key, None)

            result_buf = serialize_object(
                result,
                buffer_threshold=self.session.buffer_threshold,
                item_threshold=self.session.item_threshold,
            )

        except BaseException as e:
            # invoke IPython traceback formatting
            shell.showtraceback()
            reply_content = {
                "traceback": shell._last_traceback or [],
                "ename": str(type(e).__name__),
                "evalue": str(e),
            }
            # FIXME: deprecated piece for ipyparallel (remove in 5.0):
            e_info = dict(engine_uuid=self.ident,
                          engine_id=self.int_id,
                          method='apply')
            reply_content['engine_info'] = e_info

            self.send_response(self.iopub_socket,
                               'error',
                               reply_content,
                               ident=self._topic('error'),
                               channel='shell')
            self.log.info("Exception in apply request:\n%s",
                          '\n'.join(reply_content['traceback']))
            result_buf = []
            reply_content['status'] = 'error'
        else:
            reply_content = {'status': 'ok'}

        return reply_content, result_buf

    def do_clear(self):
        """Reset the shell (called with ``False``) and acknowledge with 'ok'."""
        self.shell.reset(False)
        return {'status': 'ok'}
class KernelGatewayApp(JupyterApp):
    """Application that provisions Jupyter kernels and proxies HTTP/Websocket
    traffic to the kernels.

    - reads command line and environment variable settings
    - initializes managers and routes
    - creates a Tornado HTTP server
    - starts the Tornado event loop
    """
    name = 'jupyter-kernel-gateway'
    version = __version__
    description = """
        Jupyter Kernel Gateway

        Provisions Jupyter kernels and proxies HTTP/Websocket traffic
        to them.
    """

    # Also include when generating help options
    classes = [NotebookHTTPPersonality, JupyterWebsocketPersonality]
    # Enable some command line shortcuts
    aliases = aliases

    # Server IP / PORT binding
    port_env = 'KG_PORT'
    port_default_value = 8888
    port = Integer(port_default_value,
                   config=True,
                   help="Port on which to listen (KG_PORT env var)")

    @default('port')
    def port_default(self):
        return int(os.getenv(self.port_env, self.port_default_value))

    port_retries_env = 'KG_PORT_RETRIES'
    port_retries_default_value = 50
    port_retries = Integer(
        port_retries_default_value,
        config=True,
        help=
        "Number of ports to try if the specified port is not available (KG_PORT_RETRIES env var)"
    )

    @default('port_retries')
    def port_retries_default(self):
        return int(
            os.getenv(self.port_retries_env, self.port_retries_default_value))

    ip_env = 'KG_IP'
    ip_default_value = '127.0.0.1'
    ip = Unicode(ip_default_value,
                 config=True,
                 help="IP address on which to listen (KG_IP env var)")

    @default('ip')
    def ip_default(self):
        return os.getenv(self.ip_env, self.ip_default_value)

    # Base URL
    base_url_env = 'KG_BASE_URL'
    base_url_default_value = '/'
    base_url = Unicode(
        base_url_default_value,
        config=True,
        help=
        """The base path for mounting all API resources (KG_BASE_URL env var)"""
    )

    @default('base_url')
    def base_url_default(self):
        return os.getenv(self.base_url_env, self.base_url_default_value)

    # Token authorization
    auth_token_env = 'KG_AUTH_TOKEN'
    auth_token = Unicode(
        config=True,
        help=
        'Authorization token required for all requests (KG_AUTH_TOKEN env var)'
    )

    @default('auth_token')
    def _auth_token_default(self):
        return os.getenv(self.auth_token_env, '')

    # CORS headers
    allow_credentials_env = 'KG_ALLOW_CREDENTIALS'
    allow_credentials = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Allow-Credentials header. (KG_ALLOW_CREDENTIALS env var)'
    )

    @default('allow_credentials')
    def allow_credentials_default(self):
        return os.getenv(self.allow_credentials_env, '')

    allow_headers_env = 'KG_ALLOW_HEADERS'
    allow_headers = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Allow-Headers header. (KG_ALLOW_HEADERS env var)'
    )

    @default('allow_headers')
    def allow_headers_default(self):
        return os.getenv(self.allow_headers_env, '')

    allow_methods_env = 'KG_ALLOW_METHODS'
    allow_methods = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Allow-Methods header. (KG_ALLOW_METHODS env var)'
    )

    @default('allow_methods')
    def allow_methods_default(self):
        return os.getenv(self.allow_methods_env, '')

    allow_origin_env = 'KG_ALLOW_ORIGIN'
    allow_origin = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Allow-Origin header. (KG_ALLOW_ORIGIN env var)'
    )

    @default('allow_origin')
    def allow_origin_default(self):
        return os.getenv(self.allow_origin_env, '')

    expose_headers_env = 'KG_EXPOSE_HEADERS'
    expose_headers = Unicode(
        config=True,
        help=
        'Sets the Access-Control-Expose-Headers header. (KG_EXPOSE_HEADERS env var)'
    )

    @default('expose_headers')
    def expose_headers_default(self):
        return os.getenv(self.expose_headers_env, '')

    trust_xheaders_env = 'KG_TRUST_XHEADERS'
    trust_xheaders = CBool(
        False,
        config=True,
        help=
        'Use x-* header values for overriding the remote-ip, useful when application is behing a proxy. (KG_TRUST_XHEADERS env var)'
    )

    @default('trust_xheaders')
    def trust_xheaders_default(self):
        return strtobool(os.getenv(self.trust_xheaders_env, 'False'))

    max_age_env = 'KG_MAX_AGE'
    max_age = Unicode(
        config=True,
        help='Sets the Access-Control-Max-Age header. (KG_MAX_AGE env var)')

    @default('max_age')
    def max_age_default(self):
        return os.getenv(self.max_age_env, '')

    max_kernels_env = 'KG_MAX_KERNELS'
    max_kernels = Integer(
        None,
        config=True,
        allow_none=True,
        help=
        'Limits the number of kernel instances allowed to run by this gateway. Unbounded by default. (KG_MAX_KERNELS env var)'
    )

    @default('max_kernels')
    def max_kernels_default(self):
        """Default for ``max_kernels``: the env var as an int, or None if unset."""
        raw = os.getenv(self.max_kernels_env)
        if raw is None:
            return None
        return int(raw)

    seed_uri_env = 'KG_SEED_URI'
    seed_uri = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        'Runs the notebook (.ipynb) at the given URI on every kernel launched. No seed by default. (KG_SEED_URI env var)'
    )

    @default('seed_uri')
    def seed_uri_default(self):
        return os.getenv(self.seed_uri_env)

    prespawn_count_env = 'KG_PRESPAWN_COUNT'
    prespawn_count = Integer(
        None,
        config=True,
        allow_none=True,
        help=
        'Number of kernels to prespawn using the default language. No prespawn by default. (KG_PRESPAWN_COUNT env var)'
    )

    @default('prespawn_count')
    def prespawn_count_default(self):
        """Default for ``prespawn_count``: the env var as an int, or None if unset."""
        raw = os.getenv(self.prespawn_count_env)
        if raw is None:
            return None
        return int(raw)

    default_kernel_name_env = 'KG_DEFAULT_KERNEL_NAME'
    default_kernel_name = Unicode(
        config=True,
        help=
        'Default kernel name when spawning a kernel (KG_DEFAULT_KERNEL_NAME env var)'
    )

    @default('default_kernel_name')
    def default_kernel_name_default(self):
        # defaults to Jupyter's default kernel name on empty string
        return os.getenv(self.default_kernel_name_env, '')

    force_kernel_name_env = 'KG_FORCE_KERNEL_NAME'
    force_kernel_name = Unicode(
        config=True,
        help=
        'Override any kernel name specified in a notebook or request (KG_FORCE_KERNEL_NAME env var)'
    )

    @default('force_kernel_name')
    def force_kernel_name_default(self):
        return os.getenv(self.force_kernel_name_env, '')

    env_process_whitelist_env = 'KG_ENV_PROCESS_WHITELIST'
    env_process_whitelist = List(
        config=True,
        help=
        """Environment variables allowed to be inherited from the spawning process by the kernel"""
    )

    @default('env_process_whitelist')
    def env_process_whitelist_default(self):
        """Default for ``env_process_whitelist``, read from the environment.

        The variable named by ``env_process_whitelist_env`` is split on
        commas.  Surrounding whitespace is stripped and empty entries are
        dropped, so an unset or empty variable yields an empty list rather
        than a list holding one empty string.
        """
        raw = os.getenv(self.env_process_whitelist_env, '')
        return [name.strip() for name in raw.split(',') if name.strip()]

    api_env = 'KG_API'
    api_default_value = 'kernel_gateway.jupyter_websocket'
    api = Unicode(
        api_default_value,
        config=True,
        help=
        """Controls which API to expose, that of a Jupyter notebook server, the seed
            notebook's, or one provided by another module, respectively using values
            'kernel_gateway.jupyter_websocket', 'kernel_gateway.notebook_http', or
            another fully qualified module name (KG_API env var)
            """)

    @default('api')
    def api_default(self):
        return os.getenv(self.api_env, self.api_default_value)

    @observe('api')
    def api_changed(self, event):
        """Check that the newly configured API module can be imported.

        Raises ImportError naming the requested module when it cannot.
        """
        module_name = event['new']
        try:
            self._load_api_module(module_name)
        except ImportError:
            # re-raise with more sensible message to help the user
            raise ImportError('API module {} not found'.format(module_name))

    certfile_env = 'KG_CERTFILE'
    certfile = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        """The full path to an SSL/TLS certificate file. (KG_CERTFILE env var)"""
    )

    @default('certfile')
    def certfile_default(self):
        return os.getenv(self.certfile_env)

    keyfile_env = 'KG_KEYFILE'
    keyfile = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        """The full path to a private key file for usage with SSL/TLS. (KG_KEYFILE env var)"""
    )

    @default('keyfile')
    def keyfile_default(self):
        return os.getenv(self.keyfile_env)

    client_ca_env = 'KG_CLIENT_CA'
    client_ca = Unicode(
        None,
        config=True,
        allow_none=True,
        help=
        """The full path to a certificate authority certificate for SSL/TLS client authentication. (KG_CLIENT_CA env var)"""
    )

    @default('client_ca')
    def client_ca_default(self):
        return os.getenv(self.client_ca_env)

    kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)

    kernel_spec_manager_class = Type(default_value=KernelSpecManager,
                                     config=True,
                                     help="""
        The kernel spec manager class to use. Should be a subclass
        of `jupyter_client.kernelspec.KernelSpecManager`.
        """)

    kernel_manager_class = Type(klass=MappingKernelManager,
                                default_value=SeedingMappingKernelManager,
                                config=True,
                                help="""The kernel manager class to use.""")

    def _load_api_module(self, module_name):
        """Tries to import the given module name.

        Parameters
        ----------
        module_name: str
            Module name to import

        Returns
        -------
        module
            Module with the given name loaded using importlib.import_module
        """
        # some compatibility allowances
        if module_name == 'jupyter-websocket':
            module_name = 'kernel_gateway.jupyter_websocket'
        elif module_name == 'notebook-http':
            module_name = 'kernel_gateway.notebook_http'
        return importlib.import_module(module_name)

    def _load_notebook(self, uri):
        """Read a version-4 notebook from a local path or a remote URL.

        A URI whose network location is empty or 'file' is opened from the
        local filesystem; anything else is fetched with ``requests``, and an
        HTTP error status raises.

        After loading, the kernel spec manager is asked for the spec of the
        forced kernel name (when configured) or of the notebook's own kernel
        name.  The lookup result is discarded, so this presumably relies on
        the manager raising for an unknown name -- verify against
        ``get_kernel_spec``.

        Returns
        -------
        object
            Notebook object from nbformat
        """
        parts = urlparse(uri)
        is_local = parts.netloc in ('', 'file')

        if not is_local:
            # Remote file
            import requests
            response = requests.get(uri)
            response.raise_for_status()
            notebook = nbformat.reads(response.text, 4)
        else:
            # Local file
            with open(parts.path) as nb_fh:
                notebook = nbformat.read(nb_fh, 4)

        # Error if no kernel spec can handle the language requested
        kernel_name = (self.force_kernel_name
                       or notebook['metadata']['kernelspec']['name'])
        self.kernel_spec_manager.get_kernel_spec(kernel_name)

        return notebook

    def initialize(self, argv=None):
        """Initializes the base class, configurable manager instances, the
        Tornado web app, and the tornado HTTP server.

        The steps run in that order: the base class is initialized with
        ``argv`` first, then ``init_configurables``, ``init_webapp`` and
        ``init_http_server`` are called.

        Parameters
        ----------
        argv
            Command line arguments; passed to the base class unchanged.
        """
        super(KernelGatewayApp, self).initialize(argv)
        self.init_configurables()
        self.init_webapp()
        self.init_http_server()

    def init_configurables(self):
        """Initializes all configurable objects including a kernel manager, kernel
        spec manager, session manager, and personality.

        Any kernel pool configured by the personality will be its responsibility
        to shut down.

        Optionally, loads a seed notebook, and validates the configured
        prespawn count against the kernel limit before the personality is
        created and asked to initialize its own configurables.

        Raises
        ------
        RuntimeError
            If `prespawn_count` exceeds a non-zero `max_kernels`.
        """
        # Build the *configured* kernel spec manager once, up front. Loading
        # the seed notebook validates its kernel name against
        # `self.kernel_spec_manager`, so that check must consult the same
        # manager class the rest of the gateway uses.
        self.kernel_spec_manager = self.kernel_spec_manager_class(parent=self)

        self.seed_notebook = None
        if self.seed_uri is not None:
            # Note: must be set before instantiating a SeedingMappingKernelManager
            self.seed_notebook = self._load_notebook(self.seed_uri)

        # Only pass a default kernel name when one is provided. Otherwise,
        # adopt whatever default the kernel manager wants to use.
        kwargs = {}
        if self.default_kernel_name:
            kwargs['default_kernel_name'] = self.default_kernel_name

        self.kernel_manager = self.kernel_manager_class(
            parent=self,
            log=self.log,
            connection_dir=self.runtime_dir,
            kernel_spec_manager=self.kernel_spec_manager,
            **kwargs)

        self.session_manager = SessionManager(
            log=self.log, kernel_manager=self.kernel_manager)
        self.contents_manager = None

        # A max_kernels of 0/None means "no limit", so only a non-zero limit
        # can be exceeded by the prespawn request.
        if (self.prespawn_count and self.max_kernels
                and self.prespawn_count > self.max_kernels):
            raise RuntimeError(
                'cannot prespawn {}; more than max kernels {}'.format(
                    self.prespawn_count, self.max_kernels))

        # The API module is expected to expose a `create_personality` factory.
        api_module = self._load_api_module(self.api)
        self.personality = api_module.create_personality(parent=self,
                                                         log=self.log)

        self.personality.init_configurables()

    def init_webapp(self):
        """Creates the Tornado web application from the personality's request
        handlers.

        The managers and the web-facing configuration values of this
        application are stored in the Tornado settings so handlers can look
        them up. Config-tagged traits of the personality are then added under
        a 'kg_' prefix, unless that key is already taken by the gateway.
        """
        # Enable the same pretty logging the notebook uses
        enable_pretty_logging()

        # Configure the tornado logging level too
        root_logger = logging.getLogger()
        root_logger.setLevel(self.log_level)

        request_handlers = self.personality.create_request_handlers()

        settings = dict(
            kernel_manager=self.kernel_manager,
            session_manager=self.session_manager,
            contents_manager=self.contents_manager,
            kernel_spec_manager=self.kernel_spec_manager,
            kg_auth_token=self.auth_token,
            kg_allow_credentials=self.allow_credentials,
            kg_allow_headers=self.allow_headers,
            kg_allow_methods=self.allow_methods,
            kg_allow_origin=self.allow_origin,
            kg_expose_headers=self.expose_headers,
            kg_max_age=self.max_age,
            kg_max_kernels=self.max_kernels,
            kg_env_process_whitelist=self.env_process_whitelist,
            kg_api=self.api,
            kg_personality=self.personality,
            # Also set the allow_origin setting used by notebook so that the
            # check_origin method used everywhere respects the value
            allow_origin=self.allow_origin,
            # Always allow remote access (has been limited to localhost >= notebook 5.6)
            allow_remote_access=True,
        )
        self.web_app = web.Application(handlers=request_handlers, **settings)

        # promote the current personality's "config" tagged traitlet values to webapp settings
        app_settings = self.web_app.settings
        personality_traits = self.personality.class_traits(config=True)
        for trait_name, trait in personality_traits.items():
            setting_key = 'kg_' + trait_name
            # a personality's traitlets may not overwrite the kernel gateway's
            if setting_key in app_settings:
                self.log.warning(
                    'The personality trait name, %s, conflicts with a kernel gateway trait.',
                    trait_name)
                continue
            app_settings[setting_key] = trait.get(obj=self.personality)

    def _build_ssl_options(self):
        """Build a dictionary of SSL options for the tornado HTTP server.

        Only the options that are actually configured (`certfile`, `keyfile`,
        `client_ca`) are included. When a client CA is configured, client
        certificates are required.

        Returns
        -------
        dict or None
            SSL options, or None when no SSL-related setting is configured
        """
        ssl_options = {}
        if self.certfile:
            ssl_options['certfile'] = self.certfile
        if self.keyfile:
            ssl_options['keyfile'] = self.keyfile
        if self.client_ca:
            ssl_options['ca_certs'] = self.client_ca

        if not ssl_options:
            # None indicates no SSL config
            return None

        # SSL may be missing, so only import it if it's to be used
        import ssl
        # PROTOCOL_TLS negotiates the highest protocol version that both the
        # client and the server support. Pinning PROTOCOL_TLSv1 would instead
        # force the obsolete TLS 1.0 and reject every newer version.
        # PROTOCOL_SSLv23 is the older name for the same negotiating mode.
        ssl_options.setdefault(
            'ssl_version', getattr(ssl, 'PROTOCOL_TLS', ssl.PROTOCOL_SSLv23))
        if ssl_options.get('ca_certs', False):
            # A client CA was supplied: insist on a client certificate
            ssl_options.setdefault('cert_reqs', ssl.CERT_REQUIRED)

        return ssl_options

    def init_http_server(self):
        """Creates the HTTP server for the Tornado web application and binds it
        to the configured interface and port.

        If the configured port cannot be used, further candidate ports are
        tried using the same logic as the Jupyter Notebook server. The port
        that was finally bound is stored back on `self.port`; when every
        candidate fails the application exits with status 1.
        """
        tls_options = self._build_ssl_options()
        self.http_server = httpserver.HTTPServer(self.web_app,
                                                 xheaders=self.trust_xheaders,
                                                 ssl_options=tls_options)

        # WSAEACCES only exists on some platforms; fall back to EACCES
        permission_errnos = (errno.EACCES,
                             getattr(errno, 'WSAEACCES', errno.EACCES))

        for candidate in random_ports(self.port, self.port_retries + 1):
            try:
                self.http_server.listen(candidate, self.ip)
            except socket.error as err:
                if err.errno == errno.EADDRINUSE:
                    self.log.info(
                        'The port %i is already in use, trying another port.' %
                        candidate)
                elif err.errno in permission_errnos:
                    self.log.warning("Permission to listen on port %i denied" %
                                     candidate)
                else:
                    # Anything else is unexpected: let it propagate
                    raise
                # Recoverable failure: move on to the next candidate port
                continue
            # listen() succeeded, remember the port actually in use
            self.port = candidate
            break
        else:
            # Loop ran out of candidates without a successful bind
            self.log.critical(
                'ERROR: the notebook server could not be started because '
                'no available port could be found.')
            self.exit(1)

    def start(self):
        """Starts the IO loop for the application and blocks until it ends.

        Before the loop runs, SIGHUP is ignored (except on Windows) and
        SIGTERM is routed to `_signal_stop`. Whatever ends the loop,
        `shutdown` is always invoked afterwards.
        """
        super().start()

        # The log line reports https only when a key file is configured
        scheme_suffix = 's' if self.keyfile else ''
        self.log.info('Jupyter Kernel Gateway at http{}://{}:{}'.format(
            scheme_suffix, self.ip, self.port))
        self.io_loop = ioloop.IOLoop.current()

        on_windows = sys.platform == 'win32'
        if not on_windows:
            signal.signal(signal.SIGHUP, signal.SIG_IGN)
        signal.signal(signal.SIGTERM, self._signal_stop)

        try:
            self.io_loop.start()
        except KeyboardInterrupt:
            self.log.info("Interrupted...")
        finally:
            self.shutdown()

    def stop(self):
        """
        Schedules a callback on the application's IO loop that stops the
        HTTP server first and the IO loop itself second.
        """
        def _halt_server_then_loop():
            self.http_server.stop()
            self.io_loop.stop()

        loop = self.io_loop
        loop.add_callback(_halt_server_then_loop)

    def shutdown(self):
        """Shuts down the active personality.

        This method only delegates to `self.personality.shutdown()`;
        stopping any kernels in a pool is left to the personality. It is
        called from the `finally` clause of `start`, i.e. once the IO loop
        has ended.
        """
        self.personality.shutdown()

    def _signal_stop(self, sig, frame):
        """Signal handler (installed for SIGTERM by `start`) that asks the IO
        loop to stop.

        Parameters
        ----------
        sig
            Number of the received signal (unused)
        frame
            Stack frame that was interrupted by the signal (unused)
        """
        self.log.info("Received signal to terminate.")
        # Do not call io_loop.stop() directly from a signal handler: that
        # only flags the loop and does not wake it while it is blocked
        # waiting for I/O, so the process could linger until the next event.
        # Scheduling the stop through the loop wakes it up.
        # NOTE(review): newer Tornado releases deprecate this helper in
        # favor of asyncio's add_signal_handler - revisit when upgrading.
        self.io_loop.add_callback_from_signal(self.io_loop.stop)