Example #1
    def _custom_traits(self):
        '''
        Creates traits for the atomic size (using the covalent radius) and atom
        colors (using the common `Jmol`_ color scheme). Note that this data is
        present in the static data (see :mod:`~exa.relational.isotope`).

        .. _Jmol: http://jmol.sourceforge.net/jscolors/
        '''
        self._set_categories()
        grps = self.groupby('frame')
        symbols = grps.apply(lambda g: g['symbol'].cat.codes.values)    # Pass integers rather than string symbols
        symbols = Unicode(symbols.to_json(orient='values')).tag(sync=True)
        symmap = {i: v for i, v in enumerate(self['symbol'].cat.categories)}
        sym2rad = symbol_to_radius()
        radii = sym2rad[self['symbol'].unique()]
        radii = Dict({i: radii[v] for i, v in symmap.items()}).tag(sync=True)  # (Int) symbol radii
        sym2col = symbol_to_color()
        colors = sym2col[self['symbol'].unique()]    # Same thing for colors
        colors = Dict({i: colors[v] for i, v in symmap.items()}).tag(sync=True)
        try:
            atom_set = grps.apply(lambda g: g['set'].values).to_json(orient='values')
            atom_set = Unicode(atom_set).tag(sync=True)
        except KeyError:
            atom_set = Unicode().tag(sync=True)
        # Note that position traits (atom_x, atom_y, atom_z) are created automatically
        # since we have defined _traits = ['x', 'y', 'z'] above.
        return {'atom_symbols': symbols, 'atom_radii': radii, 'atom_colors': colors,
                'atom_set': atom_set}
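
The key trick above is shipping categorical integer codes to the client instead of repeated symbol strings, plus small {code: value} lookup dicts for radii and colors. A minimal sketch of that mapping, using made-up symbols and radii in place of exa's static isotope data:

import pandas as pd

# Illustrative stand-in for exa's symbol_to_radius() lookup.
sym2rad = {'H': 0.31, 'O': 0.66}

symbols = pd.Series(['H', 'O', 'H'], dtype='category')
symmap = dict(enumerate(symbols.cat.categories))    # {0: 'H', 1: 'O'}
radii = {code: sym2rad[sym] for code, sym in symmap.items()}

print(symbols.cat.codes.tolist())   # [0, 1, 0] -- integers sent to the client
print(radii)                        # {0: 0.31, 1: 0.66} -- per-code radii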
Example #2
class CoregistrationUI(HasTraits):
    """Class for coregistration assisted by graphical interface.

    Parameters
    ----------
    info_file : None | str
        The FIFF file with digitizer data for coregistration.
    %(subject)s
    %(subjects_dir)s
    fiducials : list | dict | str
        The fiducials given in the MRI (surface RAS) coordinate
        system. If a dict is provided, it must have the 3 keys
        'lpa', 'rpa' and 'nasion', with coordinates in meters as values.
        If a list it must be a list of DigPoint instances as returned
        by the read_fiducials function.
        If set to 'estimated', the fiducials are initialized
        automatically using fiducials defined in MNI space on fsaverage
        template. If set to 'auto', one tries to find the fiducials
        in a file with the canonical name (``bem/{subject}-fiducials.fif``)
        and, if absent, falls back to 'estimated'. Defaults to 'auto'.
    head_resolution : bool
        If True, use a high-resolution head surface. Defaults to True.
    head_transparency : bool
        If True, display the head surface with transparency. Defaults to False.
    hpi_coils : bool
        If True, display the HPI coils. Defaults to True.
    head_shape_points : bool
        If True, display the head shape points. Defaults to True.
    eeg_channels : bool
        If True, display the EEG channels. Defaults to True.
    orient_glyphs : bool
        If True, orient the sensors towards the head surface. Defaults to True.
    scale_by_distance : bool
        If True, scale the sensors based on their distance to the head surface.
        Defaults to True.
    project_eeg : bool
        If True, project the EEG channels onto the head surface. Defaults to
        False.
    mark_inside : bool
        If True, mark the head shape points that are inside the head surface
        with a different color. Defaults to True.
    sensor_opacity : float
        The opacity of the sensors between 0 and 1. Defaults to 1.0.
    trans : str
        The path to the Head<->MRI transform FIF file ("-trans.fif").
    size : tuple
        The dimensions (width, height) of the rendering view. The default is
        (800, 600).
    bgcolor : tuple | str
        The background color as a tuple (red, green, blue) of float
        values between 0 and 1 or a valid color name (e.g. 'white'
        or 'w'). Defaults to 'grey'.
    show : bool
        Display the window as soon as it is ready. Defaults to True.
    standalone : bool
        If True, start the Qt application event loop. Defaults to False.
    %(scene_interaction)s
        Defaults to ``'terrain'``.

        .. versionadded:: 1.0
    %(verbose)s
    """

    _subject = Unicode()
    _subjects_dir = Unicode()
    _lock_fids = Bool()
    _fiducials_file = Unicode()
    _current_fiducial = Unicode()
    _info_file = Unicode()
    _orient_glyphs = Bool()
    _scale_by_distance = Bool()
    _project_eeg = Bool()
    _mark_inside = Bool()
    _hpi_coils = Bool()
    _head_shape_points = Bool()
    _eeg_channels = Bool()
    _head_resolution = Bool()
    _head_transparency = Bool()
    _grow_hair = Float()
    _scale_mode = Unicode()
    _icp_fid_match = Unicode()

    @verbose
    def __init__(self,
                 info_file,
                 subject=None,
                 subjects_dir=None,
                 fiducials='auto',
                 head_resolution=None,
                 head_transparency=None,
                 hpi_coils=None,
                 head_shape_points=None,
                 eeg_channels=None,
                 orient_glyphs=None,
                 scale_by_distance=None,
                 project_eeg=None,
                 mark_inside=None,
                 sensor_opacity=None,
                 trans=None,
                 size=None,
                 bgcolor=None,
                 show=True,
                 standalone=False,
                 interaction='terrain',
                 verbose=None):
        from ..viz.backends.renderer import _get_renderer
        from ..viz.backends._utils import _qt_app_exec

        def _get_default(var, val):
            return var if var is not None else val

        self._actors = dict()
        self._surfaces = dict()
        self._widgets = dict()
        self._verbose = verbose
        self._plot_locked = False
        self._params_locked = False
        self._refresh_rate_ms = max(int(round(1000. / 60.)), 1)
        self._redraws_pending = set()
        self._parameter_mutex = threading.Lock()
        self._redraw_mutex = threading.Lock()
        self._head_geo = None
        self._coord_frame = "mri"
        self._mouse_no_mvt = -1
        self._to_cf_t = None
        self._omit_hsp_distance = 0.0
        self._head_opacity = 1.0
        self._fid_colors = tuple(DEFAULTS['coreg'][f'{key}_color']
                                 for key in ('lpa', 'nasion', 'rpa'))
        self._defaults = dict(
            size=_get_default(size, (800, 600)),
            bgcolor=_get_default(bgcolor, "grey"),
            orient_glyphs=_get_default(orient_glyphs, True),
            scale_by_distance=_get_default(scale_by_distance, True),
            project_eeg=_get_default(project_eeg, False),
            mark_inside=_get_default(mark_inside, True),
            hpi_coils=_get_default(hpi_coils, True),
            head_shape_points=_get_default(head_shape_points, True),
            eeg_channels=_get_default(eeg_channels, True),
            head_resolution=_get_default(head_resolution, True),
            head_transparency=_get_default(head_transparency, False),
            head_opacity=0.5,
            sensor_opacity=_get_default(sensor_opacity, 1.0),
            fiducials=("LPA", "Nasion", "RPA"),
            fiducial="LPA",
            lock_fids=True,
            grow_hair=0.0,
            scale_modes=["None", "uniform", "3-axis"],
            scale_mode="None",
            icp_fid_matches=('nearest', 'matched'),
            icp_fid_match='matched',
            icp_n_iterations=20,
            omit_hsp_distance=10.0,
            lock_head_opacity=self._head_opacity < 1.0,
            weights=dict(
                lpa=1.0,
                nasion=10.0,
                rpa=1.0,
                hsp=1.0,
                eeg=1.0,
                hpi=1.0,
            ),
        )

        # process requirements
        info = None
        subjects_dir = get_subjects_dir(subjects_dir=subjects_dir,
                                        raise_error=True)
        subject = _get_default(subject, self._get_subjects(subjects_dir)[0])

        # setup the window
        self._renderer = _get_renderer(size=self._defaults["size"],
                                       bgcolor=self._defaults["bgcolor"])
        self._renderer._window_close_connect(self._clean)
        self._renderer.set_interaction(interaction)
        self._renderer._status_bar_initialize()

        # setup the model
        self._immediate_redraw = (self._renderer._kind != 'qt')
        self._info = info
        self._fiducials = fiducials
        self._coreg = Coregistration(self._info, subject, subjects_dir,
                                     fiducials)
        fid_accurate = self._coreg._fid_accurate
        for fid in self._defaults["weights"].keys():
            setattr(self, f"_{fid}_weight", self._defaults["weights"][fid])

        # set main traits
        self._set_subjects_dir(subjects_dir)
        self._set_subject(subject)
        self._set_info_file(info_file)
        self._set_orient_glyphs(self._defaults["orient_glyphs"])
        self._set_scale_by_distance(self._defaults["scale_by_distance"])
        self._set_project_eeg(self._defaults["project_eeg"])
        self._set_mark_inside(self._defaults["mark_inside"])
        self._set_hpi_coils(self._defaults["hpi_coils"])
        self._set_head_shape_points(self._defaults["head_shape_points"])
        self._set_eeg_channels(self._defaults["eeg_channels"])
        self._set_head_resolution(self._defaults["head_resolution"])
        self._set_head_transparency(self._defaults["head_transparency"])
        self._set_grow_hair(self._defaults["grow_hair"])
        self._set_omit_hsp_distance(self._defaults["omit_hsp_distance"])
        self._set_icp_n_iterations(self._defaults["icp_n_iterations"])
        self._set_icp_fid_match(self._defaults["icp_fid_match"])

        # configure UI
        self._reset_fitting_parameters()
        self._configure_status_bar()
        self._configure_dock()
        self._configure_picking()

        # once the docks are initialized
        self._set_current_fiducial(self._defaults["fiducial"])
        self._set_scale_mode(self._defaults["scale_mode"])
        if trans is not None:
            self._load_trans(trans)
        self._redraw()  # we need the elements to be present now
        if not fid_accurate:
            self._set_head_resolution('high')
            self._forward_widget_command('high_res_head', "set_value", True)
            self._set_lock_fids(True)  # hack to make the dig disappear
        self._set_lock_fids(fid_accurate)

        # must be done last
        if show:
            self._renderer.show()
        # update the view once shown
        views = {
            True: dict(azimuth=90, elevation=90),  # front
            False: dict(azimuth=180, elevation=90),  # left
        }
        self._renderer.set_camera(distance=None, **views[self._lock_fids])
        self._redraw()
        # XXX: internal plotter/renderer should not be exposed
        if not self._immediate_redraw:
            self._renderer.plotter.add_callback(self._redraw,
                                                self._refresh_rate_ms)
        self._renderer.plotter.show_axes()
        if standalone:
            _qt_app_exec(self._renderer.figure.store["app"])

    def _set_subjects_dir(self, subjects_dir):
        self._subjects_dir = _check_fname(subjects_dir,
                                          overwrite=True,
                                          must_exist=True,
                                          need_dir=True)

    def _set_subject(self, subject):
        self._subject = subject

    def _set_lock_fids(self, state):
        self._lock_fids = bool(state)

    def _set_fiducials_file(self, fname):
        if not self._check_fif('fiducials', fname):
            return
        self._fiducials_file = _check_fname(fname,
                                            overwrite=True,
                                            must_exist=True,
                                            need_dir=False)

    def _set_current_fiducial(self, fid):
        self._current_fiducial = fid.lower()

    def _set_info_file(self, fname):
        if fname is None:
            return

        # info file can be anything supported by read_raw
        try:
            check_fname(fname,
                        'info',
                        tuple(raw_supported_types.keys()),
                        endings_err=tuple(raw_supported_types.keys()))
        except IOError as e:
            warn(e)
            self._widgets["info_file"].set_value(0, '')
            return

        fname = _check_fname(fname, overwrite=True)  # convert to str

        # ctf ds `files` are actually directories
        if fname.endswith(('.ds', )):
            self._info_file = _check_fname(fname,
                                           overwrite=True,
                                           must_exist=True,
                                           need_dir=True)
        else:
            self._info_file = _check_fname(fname,
                                           overwrite=True,
                                           must_exist=True,
                                           need_dir=False)

    def _set_omit_hsp_distance(self, distance):
        self._omit_hsp_distance = distance

    def _set_orient_glyphs(self, state):
        self._orient_glyphs = bool(state)

    def _set_scale_by_distance(self, state):
        self._scale_by_distance = bool(state)

    def _set_project_eeg(self, state):
        self._project_eeg = bool(state)

    def _set_mark_inside(self, state):
        self._mark_inside = bool(state)

    def _set_hpi_coils(self, state):
        self._hpi_coils = bool(state)

    def _set_head_shape_points(self, state):
        self._head_shape_points = bool(state)

    def _set_eeg_channels(self, state):
        self._eeg_channels = bool(state)

    def _set_head_resolution(self, state):
        self._head_resolution = bool(state)

    def _set_head_transparency(self, state):
        self._head_transparency = bool(state)

    def _set_grow_hair(self, value):
        self._grow_hair = value

    def _set_scale_mode(self, mode):
        self._scale_mode = mode

    def _set_fiducial(self, value, coord):
        fid = self._current_fiducial.lower()
        coords = ["X", "Y", "Z"]
        idx = coords.index(coord)
        getattr(self._coreg, f"_{fid}")[0][idx] = value / 1e3
        self._update_plot("mri_fids")

    def _set_parameter(self, value, mode_name, coord):
        if self._params_locked:
            return
        with self._parameter_mutex:
            self._set_parameter_safe(value, mode_name, coord)
        self._update_plot("sensors")

    def _set_parameter_safe(self, value, mode_name, coord):
        params = dict(
            rotation=self._coreg._rotation,
            translation=self._coreg._translation,
            scale=self._coreg._scale,
        )
        idx = ["X", "Y", "Z"].index(coord)
        if mode_name == "rotation":
            params[mode_name][idx] = np.deg2rad(value)
        elif mode_name == "translation":
            params[mode_name][idx] = value / 1e3
        else:
            assert mode_name == "scale"
            params[mode_name][idx] = value / 1e2
        self._coreg._update_params(
            rot=params["rotation"],
            tra=params["translation"],
            sca=params["scale"],
        )

    def _set_icp_n_iterations(self, n_iterations):
        self._icp_n_iterations = n_iterations

    def _set_icp_fid_match(self, method):
        self._icp_fid_match = method

    def _set_point_weight(self, weight, point):
        funcs = {
            'hpi': '_set_hpi_coils',
            'hsp': '_set_head_shape_points',
            'eeg': '_set_eeg_channels',
        }
        if point in funcs.keys():
            getattr(self, funcs[point])(weight > 0)
        setattr(self, f"_{point}_weight", weight)
        setattr(self._coreg, f"_{point}_weight", weight)
        self._update_distance_estimation()

    @observe("_subjects_dir")
    def _subjects_dir_changed(self, change=None):
        # XXX: add coreg.set_subjects_dir
        self._coreg._subjects_dir = self._subjects_dir
        subjects = self._get_subjects()
        self._subject = subjects[0]
        self._reset()

    @observe("_subject")
    def _subject_changed(self, changed=None):
        # XXX: add coreg.set_subject()
        self._coreg._subject = self._subject
        self._coreg._setup_bem()
        self._coreg._setup_fiducials(self._fiducials)
        self._reset()
        rr = (self._coreg._processed_low_res_mri_points * self._coreg._scale)
        self._head_geo = dict(rr=rr,
                              tris=self._coreg._bem_low_res["tris"],
                              nn=self._coreg._bem_low_res["nn"])

    @observe("_lock_fids")
    def _lock_fids_changed(self, change=None):
        view_widgets = ["project_eeg", "fit_fiducials", "fit_icp"]
        fid_widgets = ["fid_X", "fid_Y", "fid_Z", "fids_file", "fids"]
        self._set_head_transparency(self._lock_fids)
        if self._lock_fids:
            self._forward_widget_command(view_widgets, "set_enabled", True)
            self._display_message()
            self._update_distance_estimation()
        else:
            self._forward_widget_command(view_widgets, "set_enabled", False)
            self._display_message("Picking fiducials - "
                                  f"{self._current_fiducial.upper()}")
        self._set_sensors_visibility(self._lock_fids)
        self._forward_widget_command("lock_fids", "set_value", self._lock_fids)
        self._forward_widget_command(fid_widgets, "set_enabled",
                                     not self._lock_fids)

    @observe("_fiducials_file")
    def _fiducials_file_changed(self, change=None):
        fids, _ = read_fiducials(self._fiducials_file)
        self._coreg._setup_fiducials(fids)
        self._update_distance_estimation()
        self._reset()
        self._set_lock_fids(True)

    @observe("_current_fiducial")
    def _current_fiducial_changed(self, change=None):
        self._update_fiducials()
        self._follow_fiducial_view()
        if not self._lock_fids:
            self._display_message("Picking fiducials - "
                                  f"{self._current_fiducial.upper()}")

    @observe("_info_file")
    def _info_file_changed(self, change=None):
        if not self._info_file:
            return
        elif self._info_file.endswith(('.fif', '.fif.gz')):
            fid, tree, _ = fiff_open(self._info_file)
            fid.close()
            if len(dir_tree_find(tree, FIFF.FIFFB_MEAS_INFO)) > 0:
                self._info = read_info(self._info_file, verbose=False)
            elif len(dir_tree_find(tree, FIFF.FIFFB_ISOTRAK)) > 0:
                self._info = _empty_info(1)
                self._info['dig'] = read_dig_fif(fname=self._info_file).dig
                self._info._unlocked = False
        else:
            self._info = read_raw(self._info_file).info
        # XXX: add coreg.set_info()
        self._coreg._info = self._info
        self._coreg._setup_digs()
        self._reset()

    @observe("_orient_glyphs")
    def _orient_glyphs_changed(self, change=None):
        self._update_plot(["hpi", "hsp", "eeg"])

    @observe("_scale_by_distance")
    def _scale_by_distance_changed(self, change=None):
        self._update_plot(["hpi", "hsp", "eeg"])

    @observe("_project_eeg")
    def _project_eeg_changed(self, change=None):
        self._update_plot("eeg")

    @observe("_mark_inside")
    def _mark_inside_changed(self, change=None):
        self._update_plot("hsp")

    @observe("_hpi_coils")
    def _hpi_coils_changed(self, change=None):
        self._update_plot("hpi")

    @observe("_head_shape_points")
    def _head_shape_point_changed(self, change=None):
        self._update_plot("hsp")

    @observe("_eeg_channels")
    def _eeg_channels_changed(self, change=None):
        self._update_plot("eeg")

    @observe("_head_resolution")
    def _head_resolution_changed(self, change=None):
        self._update_plot(["head"])
        if self._grow_hair > 0:
            self._update_plot(["hair"])

    @observe("_head_transparency")
    def _head_transparency_changed(self, change=None):
        self._head_opacity = self._defaults["head_opacity"] \
            if self._head_transparency else 1.0
        self._actors["head"].GetProperty().SetOpacity(self._head_opacity)
        self._renderer._update()

    @observe("_grow_hair")
    def _grow_hair_changed(self, change=None):
        self._coreg.set_grow_hair(self._grow_hair)
        self._update_plot("hair")

    @observe("_scale_mode")
    def _scale_mode_changed(self, change=None):
        mode = None if self._scale_mode == "None" else self._scale_mode
        self._coreg.set_scale_mode(mode)
        self._forward_widget_command(["sX", "sY", "sZ"], "set_enabled", mode
                                     is not None)

    @observe("_icp_fid_match")
    def _icp_fid_match_changed(self, change=None):
        self._coreg.set_fid_match(self._icp_fid_match)

    def _configure_picking(self):
        self._renderer._update_picking_callback(self._on_mouse_move,
                                                self._on_button_press,
                                                self._on_button_release,
                                                self._on_pick)

    @verbose
    def _redraw(self, verbose=None):
        if not self._redraws_pending:
            return
        draw_map = dict(
            head=self._add_head_surface,
            hair=self._add_head_hair,
            mri_fids=self._add_mri_fiducials,
            hsp=self._add_head_shape_points,
            hpi=self._add_hpi_coils,
            eeg=self._add_eeg_channels,
            head_fids=self._add_head_fiducials,
        )
        with self._redraw_mutex:
            logger.debug(f'Redrawing {self._redraws_pending}')
            for key in self._redraws_pending:
                draw_map[key]()
            self._redraws_pending.clear()
            self._renderer._update()
            self._renderer._process_events()  # necessary for MacOS?

    def _on_mouse_move(self, vtk_picker, event):
        if self._mouse_no_mvt:
            self._mouse_no_mvt -= 1

    def _on_button_press(self, vtk_picker, event):
        self._mouse_no_mvt = 2

    def _on_button_release(self, vtk_picker, event):
        if self._mouse_no_mvt > 0:
            x, y = vtk_picker.GetEventPosition()
            # XXX: internal plotter/renderer should not be exposed
            plotter = self._renderer.figure.plotter
            picked_renderer = self._renderer.figure.plotter.renderer
            # trigger the pick
            plotter.picker.Pick(x, y, 0, picked_renderer)
        self._mouse_no_mvt = 0

    def _on_pick(self, vtk_picker, event):
        if self._lock_fids:
            return
        # XXX: taken from Brain, can be refactored
        cell_id = vtk_picker.GetCellId()
        mesh = vtk_picker.GetDataSet()
        if mesh is None or cell_id == -1 or not self._mouse_no_mvt:
            return
        if not getattr(mesh, "_picking_target", False):
            return
        pos = np.array(vtk_picker.GetPickPosition())
        vtk_cell = mesh.GetCell(cell_id)
        cell = [
            vtk_cell.GetPointId(point_id)
            for point_id in range(vtk_cell.GetNumberOfPoints())
        ]
        vertices = mesh.points[cell]
        idx = np.argmin(abs(vertices - pos), axis=0)
        vertex_id = cell[idx[0]]

        fiducials = [s.lower() for s in self._defaults["fiducials"]]
        idx = fiducials.index(self._current_fiducial.lower())
        # XXX: add coreg.set_fids
        self._coreg._fid_points[idx] = self._surfaces["head"].points[vertex_id]
        self._coreg._reset_fiducials()
        self._update_fiducials()
        self._update_plot("mri_fids")

    def _reset_fitting_parameters(self):
        self._forward_widget_command("icp_n_iterations", "set_value",
                                     self._defaults["icp_n_iterations"])
        self._forward_widget_command("icp_fid_match", "set_value",
                                     self._defaults["icp_fid_match"])
        weights_widgets = [
            f"{w}_weight" for w in self._defaults["weights"].keys()
        ]
        self._forward_widget_command(weights_widgets, "set_value",
                                     list(self._defaults["weights"].values()))

    def _reset_fiducials(self):
        self._set_current_fiducial(self._defaults["fiducial"])

    def _omit_hsp(self):
        self._coreg.omit_head_shape_points(self._omit_hsp_distance / 1e3)
        n_omitted = np.sum(~self._coreg._extra_points_filter)
        n_remaining = len(self._coreg._dig_dict['hsp']) - n_omitted
        self._update_plot("hsp")
        self._update_distance_estimation()
        self._display_message(f"{n_omitted} head shape points omitted, "
                              f"{n_remaining} remaining.")

    def _reset_omit_hsp_filter(self):
        self._coreg._extra_points_filter = None
        self._coreg._update_params(force_update_omitted=True)
        self._update_plot("hsp")
        self._update_distance_estimation()
        n_total = len(self._coreg._dig_dict['hsp'])
        self._display_message(
            f"No head shape point is omitted, the total is {n_total}.")

    def _update_plot(self, changes="all"):
        # Update list of things that need to be updated/plotted (and maybe
        # draw them immediately)
        if self._plot_locked:
            return
        if self._info is None:
            changes = ["head", "mri_fids"]
            self._to_cf_t = dict(mri=dict(trans=np.eye(4)), head=None)
        else:
            self._to_cf_t = _get_transforms_to_coord_frame(
                self._info, self._coreg.trans, coord_frame=self._coord_frame)
        all_keys = (
            'head',
            'mri_fids',  # MRI first
            'hair',  # then hair
            'hsp',
            'hpi',
            'eeg',
            'head_fids',  # then dig
        )
        if changes == 'all':
            changes = list(all_keys)
        elif changes == 'sensors':
            changes = all_keys[2:]  # omit MRI ones
        elif isinstance(changes, str):
            changes = [changes]
        changes = set(changes)
        # ideally we would maybe have this in:
        # with self._redraw_mutex:
        # it would reduce "jerkiness" of the updates, but this should at least
        # work okay
        bad = changes.difference(set(all_keys))
        assert len(bad) == 0, f'Unknown changes: {bad}'
        self._redraws_pending.update(changes)
        if self._immediate_redraw:
            self._redraw()

    @contextmanager
    def _lock_plot(self):
        old_plot_locked = self._plot_locked
        self._plot_locked = True
        try:
            yield
        finally:
            self._plot_locked = old_plot_locked

    @contextmanager
    def _lock_params(self):
        old_params_locked = self._params_locked
        self._params_locked = True
        try:
            yield
        finally:
            self._params_locked = old_params_locked

    def _display_message(self, msg=""):
        self._status_msg.set_value(msg)
        self._status_msg.show()
        self._status_msg.update()

    def _follow_fiducial_view(self):
        fid = self._current_fiducial.lower()
        view = dict(lpa='left', rpa='right', nasion='front')
        kwargs = dict(front=(90., 90.), left=(180, 90), right=(0., 90))
        kwargs = dict(zip(('azimuth', 'elevation'), kwargs[view[fid]]))
        if not self._lock_fids:
            self._renderer.set_camera(distance=None, **kwargs)

    def _update_fiducials(self):
        fid = self._current_fiducial.lower()
        val = getattr(self._coreg, f"_{fid}")[0] * 1e3
        with self._lock_plot():
            self._forward_widget_command(["fid_X", "fid_Y", "fid_Z"],
                                         "set_value", val)

    def _update_distance_estimation(self):
        value = self._coreg._get_fiducials_distance_str() + '\n' + \
            self._coreg._get_point_distance_str()
        dists = self._coreg.compute_dig_mri_distances() * 1e3
        if self._hsp_weight > 0:
            value += "\nHSP <-> MRI (mean/min/max): "\
                f"{np.mean(dists):.2f} "\
                f"/ {np.min(dists):.2f} / {np.max(dists):.2f} mm"
        self._forward_widget_command("fit_label", "set_value", value)

    def _update_parameters(self):
        with self._lock_plot():
            with self._lock_params():
                # rotation
                self._forward_widget_command(["rX", "rY", "rZ"], "set_value",
                                             np.rad2deg(self._coreg._rotation))
                # translation
                self._forward_widget_command(["tX", "tY", "tZ"], "set_value",
                                             self._coreg._translation * 1e3)
                # scale
                self._forward_widget_command(["sX", "sY", "sZ"], "set_value",
                                             self._coreg._scale * 1e2)

    def _reset(self):
        self._reset_fitting_parameters()
        self._coreg.reset()
        self._update_plot()
        self._update_parameters()
        self._update_distance_estimation()

    def _forward_widget_command(self, names, command, value):
        names = [names] if not isinstance(names, list) else names
        value = list(value) if isinstance(value, np.ndarray) else value
        for idx, name in enumerate(names):
            val = value[idx] if isinstance(value, list) else value
            if name in self._widgets:
                getattr(self._widgets[name], command)(val)

    def _set_sensors_visibility(self, state):
        sensors = [
            "head_fiducials", "hpi_coils", "head_shape_points", "eeg_channels"
        ]
        for sensor in sensors:
            if sensor in self._actors and self._actors[sensor] is not None:
                actors = self._actors[sensor]
                actors = actors if isinstance(actors, list) else [actors]
                for actor in actors:
                    actor.SetVisibility(state)
        self._renderer._update()

    def _update_actor(self, actor_name, actor):
        # XXX: internal plotter/renderer should not be exposed
        self._renderer.plotter.remove_actor(self._actors.get(actor_name))
        self._actors[actor_name] = actor

    def _add_mri_fiducials(self):
        mri_fids_actors = _plot_mri_fiducials(self._renderer,
                                              self._coreg._fid_points,
                                              self._subjects_dir,
                                              self._subject, self._to_cf_t,
                                              self._fid_colors)
        # disable picking on the markers
        for actor in mri_fids_actors:
            actor.SetPickable(False)
        self._update_actor("mri_fiducials", mri_fids_actors)

    def _add_head_fiducials(self):
        head_fids_actors = _plot_head_fiducials(self._renderer, self._info,
                                                self._to_cf_t,
                                                self._fid_colors)
        self._update_actor("head_fiducials", head_fids_actors)

    def _add_hpi_coils(self):
        if self._hpi_coils:
            hpi_actors = _plot_hpi_coils(
                self._renderer,
                self._info,
                self._to_cf_t,
                opacity=self._defaults["sensor_opacity"],
                scale=DEFAULTS["coreg"]["extra_scale"],
                orient_glyphs=self._orient_glyphs,
                scale_by_distance=self._scale_by_distance,
                surf=self._head_geo)
        else:
            hpi_actors = None
        self._update_actor("hpi_coils", hpi_actors)

    def _add_head_shape_points(self):
        if self._head_shape_points:
            hsp_actors = _plot_head_shape_points(
                self._renderer,
                self._info,
                self._to_cf_t,
                opacity=self._defaults["sensor_opacity"],
                orient_glyphs=self._orient_glyphs,
                scale_by_distance=self._scale_by_distance,
                mark_inside=self._mark_inside,
                surf=self._head_geo,
                mask=self._coreg._extra_points_filter)
        else:
            hsp_actors = None
        self._update_actor("head_shape_points", hsp_actors)

    def _add_eeg_channels(self):
        if self._eeg_channels:
            eeg = ["original"]
            picks = pick_types(self._info, eeg=(len(eeg) > 0), fnirs=True)
            if len(picks) > 0:
                actors = _plot_sensors(
                    self._renderer,
                    self._info,
                    self._to_cf_t,
                    picks,
                    meg=False,
                    eeg=eeg,
                    fnirs=["sources", "detectors"],
                    warn_meg=False,
                    head_surf=self._head_geo,
                    units='m',
                    sensor_opacity=self._defaults["sensor_opacity"],
                    orient_glyphs=self._orient_glyphs,
                    scale_by_distance=self._scale_by_distance,
                    project_points=self._project_eeg,
                    surf=self._head_geo)
                sens_actors = actors["eeg"]
                sens_actors.extend(actors["fnirs"])
            else:
                sens_actors = None
        else:
            sens_actors = None
        self._update_actor("eeg_channels", sens_actors)

    def _add_head_surface(self):
        bem = None
        surface = "head-dense" if self._head_resolution else "head"
        try:
            head_actor, head_surf, _ = _plot_head_surface(
                self._renderer,
                surface,
                self._subject,
                self._subjects_dir,
                bem,
                self._coord_frame,
                self._to_cf_t,
                alpha=self._head_opacity)
        except IOError:
            head_actor, head_surf, _ = _plot_head_surface(
                self._renderer,
                "head",
                self._subject,
                self._subjects_dir,
                bem,
                self._coord_frame,
                self._to_cf_t,
                alpha=self._head_opacity)
        # mark head surface mesh to restrict picking
        head_surf._picking_target = True
        self._update_actor("head", head_actor)
        self._surfaces["head"] = head_surf

    def _add_head_hair(self):
        if "head" in self._surfaces:
            res = "high" if self._head_resolution else "low"
            self._surfaces["head"].points = \
                self._coreg._get_processed_mri_points(res)

    def _fit_fiducials(self):
        if not self._lock_fids:
            self._display_message(
                "Fitting is disabled, lock the fiducials first.")
            return
        start = time.time()
        self._coreg.fit_fiducials(
            lpa_weight=self._lpa_weight,
            nasion_weight=self._nasion_weight,
            rpa_weight=self._rpa_weight,
            verbose=self._verbose,
        )
        end = time.time()
        self._display_message(
            f"Fitting fiducials finished in {end - start:.2f} seconds.")
        self._update_plot("sensors")
        self._update_parameters()
        self._update_distance_estimation()

    def _fit_icp(self):
        if not self._lock_fids:
            self._display_message(
                "Fitting is disabled, lock the fiducials first.")
            return
        self._current_icp_iterations = 0

        def callback(iteration, n_iterations):
            self._display_message(f"Fitting ICP - iteration {iteration + 1}")
            self._update_plot("sensors")
            self._current_icp_iterations += 1
            self._update_distance_estimation()
            self._update_parameters()
            self._renderer._process_events()  # allow a draw or cancel

        start = time.time()
        self._coreg.fit_icp(
            n_iterations=self._icp_n_iterations,
            lpa_weight=self._lpa_weight,
            nasion_weight=self._nasion_weight,
            rpa_weight=self._rpa_weight,
            callback=callback,
            verbose=self._verbose,
        )
        end = time.time()
        self._display_message()
        self._display_message(
            f"Fitting ICP finished in {end - start:.2f} seconds and "
            f"{self._current_icp_iterations} iterations.")
        del self._current_icp_iterations

    def _save_trans(self, fname):
        write_trans(fname, self._coreg.trans)
        self._display_message(f"{fname} transform file is saved.")

    def _load_trans(self, fname):
        mri_head_t = _ensure_trans(read_trans(fname, return_all=True), 'mri',
                                   'head')['trans']
        rot_x, rot_y, rot_z = rotation_angles(mri_head_t)
        x, y, z = mri_head_t[:3, 3]
        self._coreg._update_params(
            rot=np.array([rot_x, rot_y, rot_z]),
            tra=np.array([x, y, z]),
        )
        self._update_parameters()
        self._update_distance_estimation()
        self._display_message(f"{fname} transform file is loaded.")

    def _get_subjects(self, sdir=None):
        # XXX: would be nice to move this function to util
        sdir = sdir if sdir is not None else self._subjects_dir
        is_dir = sdir and op.isdir(sdir)
        if is_dir:
            dir_content = os.listdir(sdir)
            subjects = [s for s in dir_content if _is_mri_subject(s, sdir)]
            if len(subjects) == 0:
                subjects.append('')
        else:
            subjects = ['']
        return sorted(subjects)

    def _check_fif(self, filetype, fname):
        try:
            # endings must be tuples, hence the trailing commas
            check_fname(fname, filetype, ('.fif',), ('.fif',))
        except IOError:
            warn(f"The filename {fname} for {filetype} must end with '.fif'.")
            self._widgets[f"{filetype}_file"].set_value(0, '')
            return False
        return True

    def _configure_dock(self):
        self._renderer._dock_initialize(name="Input", area="left")
        layout = self._renderer._dock_add_group_box("MRI Subject")
        self._widgets["subjects_dir"] = self._renderer._dock_add_file_button(
            name="subjects_dir",
            desc="Load",
            func=self._set_subjects_dir,
            value=self._subjects_dir,
            placeholder="Subjects Directory",
            directory=True,
            tooltip="Load the path to the directory containing the "
            "FreeSurfer subjects",
            layout=layout,
        )
        self._widgets["subject"] = self._renderer._dock_add_combo_box(
            name="Subject",
            value=self._subject,
            rng=self._get_subjects(),
            callback=self._set_subject,
            compact=True,
            tooltip="Select the FreeSurfer subject name",
            layout=layout)

        layout = self._renderer._dock_add_group_box("MRI Fiducials")
        self._widgets["lock_fids"] = self._renderer._dock_add_check_box(
            name="Lock fiducials",
            value=self._lock_fids,
            callback=self._set_lock_fids,
            tooltip="Lock/Unlock interactive fiducial editing",
            layout=layout)
        self._widgets["fiducials_file"] = self._renderer._dock_add_file_button(
            name="fiducials_file",
            desc="Load",
            func=self._set_fiducials_file,
            value=self._fiducials_file,
            placeholder="Path to fiducials",
            tooltip="Load the fiducials from a FIFF file",
            layout=layout,
        )
        self._widgets["fids"] = self._renderer._dock_add_radio_buttons(
            value=self._defaults["fiducial"],
            rng=self._defaults["fiducials"],
            callback=self._set_current_fiducial,
            vertical=False,
            layout=layout,
        )
        hlayout = self._renderer._dock_add_layout()
        for coord in ("X", "Y", "Z"):
            name = f"fid_{coord}"
            self._widgets[name] = self._renderer._dock_add_spin_box(
                name=coord,
                value=0.,
                rng=[-1e3, 1e3],
                callback=partial(
                    self._set_fiducial,
                    coord=coord,
                ),
                compact=True,
                double=True,
                tooltip=f"Set the {coord} fiducial coordinate",
                layout=hlayout)
        self._renderer._layout_add_widget(layout, hlayout)

        layout = self._renderer._dock_add_group_box("Digitization Source")
        self._widgets["info_file"] = self._renderer._dock_add_file_button(
            name="info_file",
            desc="Load",
            func=self._set_info_file,
            value=self._info_file,
            placeholder="Path to info",
            tooltip="Load the FIFF file with digitizer data for "
            "coregistration",
            layout=layout,
        )
        self._widgets["grow_hair"] = self._renderer._dock_add_spin_box(
            name="Grow Hair (mm)",
            value=self._grow_hair,
            rng=[0.0, 10.0],
            callback=self._set_grow_hair,
            tooltip="Compensate for hair on the digitizer head shape",
            layout=layout,
        )
        hlayout = self._renderer._dock_add_layout(vertical=False)
        self._widgets["omit_distance"] = self._renderer._dock_add_spin_box(
            name="Omit Distance (mm)",
            value=self._omit_hsp_distance,
            rng=[0.0, 100.0],
            callback=self._set_omit_hsp_distance,
            tooltip="Set the head shape points exclusion distance",
            layout=hlayout,
        )
        self._widgets["omit"] = self._renderer._dock_add_button(
            name="Omit",
            callback=self._omit_hsp,
            tooltip="Exclude the head shape points that are far away from "
            "the MRI head",
            layout=hlayout,
        )
        self._widgets["reset_omit"] = self._renderer._dock_add_button(
            name="Reset",
            callback=self._reset_omit_hsp_filter,
            tooltip="Reset all excluded head shape points",
            layout=hlayout,
        )
        self._renderer._layout_add_widget(layout, hlayout)

        layout = self._renderer._dock_add_group_box("View options")
        self._widgets["project_eeg"] = self._renderer._dock_add_check_box(
            name="Project EEG",
            value=self._project_eeg,
            callback=self._set_project_eeg,
            tooltip="Enable/Disable EEG channels projection on head surface",
            layout=layout)
        self._widgets["high_res_head"] = self._renderer._dock_add_check_box(
            name="Show High Resolution Head",
            value=self._head_resolution,
            callback=self._set_head_resolution,
            tooltip="Enable/Disable high resolution head surface",
            layout=layout)
        self._renderer._dock_add_stretch()

        self._renderer._dock_initialize(name="Parameters", area="right")
        self._widgets["scaling_mode"] = self._renderer._dock_add_combo_box(
            name="Scaling Mode",
            value=self._defaults["scale_mode"],
            rng=self._defaults["scale_modes"],
            callback=self._set_scale_mode,
            tooltip="Select the scaling mode",
            compact=True,
        )
        hlayout = self._renderer._dock_add_group_box(
            name="Scaling Parameters")
        for coord in ("X", "Y", "Z"):
            name = f"s{coord}"
            self._widgets[name] = self._renderer._dock_add_spin_box(
                name=name,
                value=0.,
                rng=[-1e3, 1e3],
                callback=partial(
                    self._set_parameter,
                    mode_name="scale",
                    coord=coord,
                ),
                compact=True,
                double=True,
                tooltip=f"Set the {coord} scaling parameter",
                layout=hlayout)

        for mode, mode_name in (("t", "Translation"), ("r", "Rotation")):
            hlayout = self._renderer._dock_add_group_box(
                f"{mode_name} ({mode})")
            for coord in ("X", "Y", "Z"):
                name = f"{mode}{coord}"
                self._widgets[name] = self._renderer._dock_add_spin_box(
                    name=name,
                    value=0.,
                    rng=[-1e3, 1e3],
                    callback=partial(
                        self._set_parameter,
                        mode_name=mode_name.lower(),
                        coord=coord,
                    ),
                    compact=True,
                    double=True,
                    step=1,
                    tooltip=f"Set the {coord} {mode_name.lower()} parameter",
                    layout=hlayout)

        layout = self._renderer._dock_add_group_box("Fitting")
        hlayout = self._renderer._dock_add_layout(vertical=False)
        self._widgets["fit_fiducials"] = self._renderer._dock_add_button(
            name="Fit Fiducials",
            callback=self._fit_fiducials,
            tooltip="Find rotation and translation to fit all 3 fiducials",
            layout=hlayout,
        )
        self._widgets["fit_icp"] = self._renderer._dock_add_button(
            name="Fit ICP",
            callback=self._fit_icp,
            tooltip="Find MRI scaling, translation, and rotation to match the "
            "head shape points",
            layout=hlayout,
        )
        self._renderer._layout_add_widget(layout, hlayout)
        self._widgets["fit_label"] = self._renderer._dock_add_label(
            value="",
            layout=layout,
        )
        self._widgets["icp_n_iterations"] = self._renderer._dock_add_spin_box(
            name="Number Of ICP Iterations",
            value=self._defaults["icp_n_iterations"],
            rng=[1, 100],
            callback=self._set_icp_n_iterations,
            compact=True,
            double=False,
            tooltip="Set the number of ICP iterations",
            layout=layout,
        )
        self._widgets["icp_fid_match"] = self._renderer._dock_add_combo_box(
            name="Fiducial point matching",
            value=self._defaults["icp_fid_match"],
            rng=self._defaults["icp_fid_matches"],
            callback=self._set_icp_fid_match,
            compact=True,
            tooltip="Select the fiducial point matching method",
            layout=layout)
        layout = self._renderer._dock_add_group_box(
            name="Weights",
            layout=layout,
        )
        for point, fid in zip(("HSP", "EEG", "HPI"),
                              self._defaults["fiducials"]):
            hlayout = self._renderer._dock_add_layout(vertical=False)
            point_lower = point.lower()
            name = f"{point_lower}_weight"
            self._widgets[name] = self._renderer._dock_add_spin_box(
                name=point,
                value=getattr(self, f"_{point_lower}_weight"),
                rng=[0., 100.],
                callback=partial(self._set_point_weight, point=point_lower),
                compact=True,
                double=True,
                tooltip=f"Set the {point} weight",
                layout=hlayout)

            fid_lower = fid.lower()
            name = f"{fid_lower}_weight"
            self._widgets[name] = self._renderer._dock_add_spin_box(
                name=fid,
                value=getattr(self, f"_{fid_lower}_weight"),
                rng=[1., 100.],
                callback=partial(self._set_point_weight, point=fid_lower),
                compact=True,
                double=True,
                tooltip=f"Set the {fid} weight",
                layout=hlayout)
            self._renderer._layout_add_widget(layout, hlayout)
        self._renderer._dock_add_button(
            name="Reset Fitting Options",
            callback=self._reset_fitting_parameters,
            tooltip="Reset all the fitting parameters to default value",
            layout=layout,
        )
        layout = self._renderer._dock_layout
        hlayout = self._renderer._dock_add_layout(vertical=False)
        self._renderer._dock_add_button(
            name="Reset",
            callback=self._reset,
            tooltip="Reset all the parameters affecting the coregistration",
            layout=hlayout,
        )
        self._widgets["save_trans"] = self._renderer._dock_add_file_button(
            name="save_trans",
            desc="Save...",
            save=True,
            func=self._save_trans,
            input_text_widget=False,
            tooltip="Save the transform file to disk",
            layout=hlayout,
        )
        self._widgets["load_trans"] = self._renderer._dock_add_file_button(
            name="load_trans",
            desc="Load...",
            func=self._load_trans,
            input_text_widget=False,
            tooltip="Load the transform file from disk",
            layout=hlayout,
        )
        self._renderer._layout_add_widget(layout, hlayout)
        self._renderer._dock_add_stretch()

    def _configure_status_bar(self):
        self._status_msg = self._renderer._status_bar_add_label("", stretch=1)
        self._status_msg.hide()

    def _clean(self):
        self._renderer = None
        self._coreg = None
        self._widgets.clear()
        self._actors.clear()
        self._surfaces.clear()
        self._defaults.clear()
        self._head_geo = None
        self._redraw = None
        self._status_msg = None

    def close(self):
        """Close interface and cleanup data structure."""
        self._renderer.close()
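
The interaction model above is plain traitlets: private setter methods assign to trait attributes, and @observe handlers react to each change. A minimal self-contained sketch of that setter/observer round trip (class and trait names here are illustrative, not MNE's):

from traitlets import Bool, HasTraits, Unicode, observe

class TinyUI(HasTraits):
    # Assigning to a trait fires any handler observing it, which is how
    # CoregistrationUI keeps widgets, plot, and model in sync.
    _subject = Unicode()
    _lock_fids = Bool()

    def _set_subject(self, subject):
        self._subject = subject          # triggers _subject_changed

    @observe("_subject")
    def _subject_changed(self, change=None):
        print(f"subject is now {self._subject!r}")

ui = TinyUI()
ui._set_subject("sample")  # prints: subject is now 'sample'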
Example #3
class EGridNLSSFeature(NLSSFeature):
    """
    Extract BOE exact features by comparing the e_grid of qe with qe's NLSS.
    """
    feature_name_pre = Unicode('EGridNLSS')
    l_grid_lvl_pool = List(Unicode, default_value=['Sum', 'Max',
                                                   'Mean']).tag(config=True)
    h_pool_name_func = {
        'Sum': np.sum,
        'Max': np.amax,
        'Mean': np.mean,
    }
    l_nlss_lvl_pool = List(Unicode, default_value=['Max',
                                                   'Mean']).tag(config=True)

    def __init__(self, **kwargs):
        super(EGridNLSSFeature, self).__init__(**kwargs)
        for pool_name in self.l_grid_lvl_pool:
            assert pool_name in self.h_pool_name_func

    def _extract_per_entity_via_nlss(self, q_info, ana, doc_info, l_qe_nlss):
        """
        :param q_info: query information
        :param ana: the annotation of the target entity
        :param doc_info: the document's fields, including its e_grid
        :param l_qe_nlss: nlss of qe
        :return: dict of features for this entity
        """

        h_this_feature = dict()
        h_e_grid = doc_info.get(E_GRID_FIELD, {})
        l_nlss_bow = self._form_nlss_bow(l_qe_nlss)
        l_nlss_emb = self._form_nlss_emb(l_qe_nlss)
        for field in self.l_target_fields:
            if field not in h_e_grid:
                continue
            l_e_grid = h_e_grid.get(field, [])
            h_field_grid_feature = self._extract_per_entity_per_nlss_per_field(
                ana, doc_info, l_qe_nlss, l_e_grid, l_nlss_bow, l_nlss_emb)
            h_this_feature.update(
                add_feature_prefix(h_field_grid_feature, field + '_'))
        return h_this_feature

    def _extract_per_entity_per_nlss_per_field(self, ana, doc_info, l_qe_nlss,
                                               l_e_grid, l_nlss_bow,
                                               l_nlss_emb):
        """
        for each sentence in the e_grid:
            check that ana's entity is in it and that its length < max_sent_len,
            calculate its similarity with all qe_nlss,
            then pool the similarities (sum/max/mean) at grid and nlss level
        :param ana:
        :param doc_info:
        :param l_qe_nlss: nlss of qe
        :param l_e_grid: grid of this field
        :param l_nlss_bow: pre calc bow of nlss
        :param l_nlss_emb: pre calc emb of nlss
        :return:
        """
        e_id = ana['id']
        l_this_e_grid = self._filter_e_grid(e_id, l_e_grid)
        l_grid_bow = self._form_grid_bow(l_this_e_grid)
        l_grid_emb = self._form_grid_emb(l_this_e_grid)

        m_bow_sim = self._calc_bow_trans(l_grid_bow, l_nlss_bow)
        m_emb_sim = self._calc_emb_trans(l_grid_emb, l_nlss_emb)

        # if self.intermediate_data_out_name:
        #     self._log_intermediate_res(ana, doc_info, l_this_e_grid, l_qe_nlss, m_bow_sim, m_emb_sim)

        h_bow_feature = self._pool_grid_nlss_sim(m_bow_sim)
        h_emb_feature = self._pool_grid_nlss_sim(m_emb_sim)

        h_feature = dict()
        h_feature.update(add_feature_prefix(h_bow_feature, 'BOW_'))
        h_feature.update(add_feature_prefix(h_emb_feature, 'Emb_'))
        return h_feature

    def _filter_e_grid(self, e_id, l_e_grid):
        """
        filter the e_grid down to sentences that
            contain e_id and
            are not too long (< self.max_sent_len)
        :param e_id: target e id
        :param l_e_grid: grid of doc
        :return:
        """
        l_kept_grid = []
        for e_grid in l_e_grid:
            if len(e_grid['sent'].split()) > self.max_sent_len:
                continue
            contain_flag = False
            for ana in e_grid[SPOT_FIELD]:
                if ana['id'] == e_id:
                    contain_flag = True
                    break
            if contain_flag:
                l_kept_grid.append(e_grid)
        return l_kept_grid

    def _form_grid_bow(self, l_e_grid):
        l_sent = [grid['sent'] for grid in l_e_grid]
        return self._form_sents_bow(l_sent)

    def _form_grid_emb(self, l_e_grid):
        l_sent = [grid['sent'] for grid in l_e_grid]
        return self._form_sents_emb(l_sent)

    def _pool_grid_nlss_sim(self, trans_mtx):
        h_feature = {}
        for name1 in self.l_grid_lvl_pool:
            f1 = self.h_pool_name_func[name1]
            for name2 in self.l_nlss_lvl_pool:
                f2 = self.h_pool_name_func[name2]
                score = -1
                if (trans_mtx.shape[0] > 0) & (trans_mtx.shape[1] > 0):
                    score = f1(f2(trans_mtx, axis=1), axis=0)
                pool_name = 'R' + name1 + 'C' + name2
                h_feature[pool_name] = score
        return h_feature

    def _log_intermediate_res(self, ana, doc_info, l_this_e_grid, l_qe_nlss,
                              m_bow_sim, m_emb_sim):
        """
        dump out the intermediate results
            e_id, name, doc no,
            e_grid_sentences:
            grid sentence for this e_id in doc, mean sim in bow and emb,
                max sim in bow and emb, and corresponding nlss that generate the max
        :param ana:
        :param doc_info:
        :param l_this_e_grid:
        :param l_qe_nlss:
        :param m_bow_sim:
        :param m_emb_sim:
        :return:
        """
        # use json
        if not doc_info:
            return
        h_pair_res = dict()
        h_pair_res['id'] = ana['id']
        h_pair_res['surface'] = ana['surface']
        h_pair_res['docno'] = doc_info['docno']
        if not l_this_e_grid or not l_qe_nlss:
            print(json.dumps(h_pair_res), file=self.intermediate_out)
            return

        l_e_grid_info = []
        for i in range(len(l_this_e_grid)):
            h_this_sent = {}
            h_this_sent['sent'] = l_this_e_grid[i]['sent']
            h_this_sent['mean_bow_sim'] = np.mean(m_bow_sim[i])
            h_this_sent['mean_emb_sim'] = np.mean(m_emb_sim[i])

            max_p = np.argmax(m_bow_sim[i])
            h_this_sent['max_bow_sim'] = m_bow_sim[i, max_p]
            h_this_sent['max_bow_nlss'] = l_qe_nlss[max_p][0]
            max_p = np.argmax(m_emb_sim[i])
            h_this_sent['max_emb_sim'] = m_emb_sim[i, max_p]
            h_this_sent['max_emb_nlss'] = l_qe_nlss[max_p][0]

            l_e_grid_info.append(h_this_sent)

        h_pair_res['e_grid_info'] = l_e_grid_info

        print(json.dumps(h_pair_res), file=self.intermediate_out)
        return
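
The two-level pooling in _pool_grid_nlss_sim reduces an (n_grid_sentences x n_nlss) similarity matrix to one scalar per pooling pair: the nlss-level pool runs along axis=1 first, then the grid-level pool along axis=0. A small numpy sketch with made-up similarities:

import numpy as np

# Rows are e-grid sentences, columns are nlss sentences.
trans_mtx = np.array([[0.25, 0.75],
                      [0.50, 0.25]])

row_scores = np.amax(trans_mtx, axis=1)  # best nlss match per sentence -> [0.75, 0.5]
feature = np.sum(row_scores, axis=0)     # the 'RSumCMax' feature -> 1.25
print(feature)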
Example #4
class ExtensionApp(JupyterApp):
    """Base class for configurable Jupyter Server Extension Applications.

    ExtensionApp subclasses can be initialized two ways:
    1. Extension is listed as a jpserver_extension, and ServerApp calls
        its load_jupyter_server_extension classmethod. This is the
        classic way of loading a server extension.
    2. Extension is launched directly by calling its `launch_instance`
        class method. This method can be set as an entry_point in
        the extension's setup.py.
    """
    # Name of the extension
    extension_name = Unicode("", help="Name of extension.")

    @default("extension_name")
    def _default_extension_name(self):
        raise ValueError("The extension must be given a `name`.")

    INVALID_EXTENSION_NAME_CHARS = [' ', '.', '+', '/']

    def _validate_extension_name(self):
        value = self.extension_name
        if isinstance(value, str):
            # Validate that extension_name doesn't contain any invalid characters.
            for c in ExtensionApp.INVALID_EXTENSION_NAME_CHARS:
                if c in value:
                    raise ValueError(
                        "Extension name '{name}' cannot contain any of the following characters: "
                        "{invalid_chars}.".format(
                            name=value,
                            invalid_chars=ExtensionApp.INVALID_EXTENSION_NAME_CHARS
                        )
                    )
            return value
        raise ValueError(
            "Extension name must be a string, found {type}.".format(
                type=type(value)))

    # Extension can configure the ServerApp from the command-line
    classes = [
        ServerApp,
    ]

    aliases = aliases
    flags = flags

    @property
    def static_url_prefix(self):
        return "/static/{extension_name}/".format(
            extension_name=self.extension_name)

    static_paths = List(Unicode(),
                        help="""paths to search for serving static files.
        
        This allows adding javascript/css to be available from the notebook server machine,
        or overriding individual files in the IPython
        """).tag(config=True)

    template_paths = List(
        Unicode(),
        help=_("""Paths to search for serving jinja templates.

        Can be used to override templates from notebook.templates.""")).tag(
            config=True)

    settings = Dict(
        help=_("""Settings that will be passed to the server.""")).tag(
            config=True)

    handlers = List(help=_("""Handlers appended to the server.""")).tag(
        config=True)

    default_url = Unicode('/',
                          config=True,
                          help=_("The default URL to redirect to from `/`"))

    def initialize_settings(self):
        """Override this method to add handling of settings."""
        pass

    def initialize_handlers(self):
        """Override this method to append handlers to a Jupyter Server."""
        pass

    def initialize_templates(self):
        """Override this method to add handling of template files."""
        pass

    def _prepare_config(self):
        """Builds a Config object from the extension's traits and passes
        the object to the webapp's settings as `<extension_name>_config`.  
        """
        traits = self.class_own_traits().keys()
        self.config = Config({t: getattr(self, t) for t in traits})
        self.settings['{}_config'.format(self.extension_name)] = self.config

    def _prepare_settings(self):
        # Make webapp settings accessible to initialize_settings method
        webapp = self.serverapp.web_app
        self.settings.update(**webapp.settings)

        # Add static and template paths to settings.
        self.settings.update({
            "{}_static_paths".format(self.extension_name):
            self.static_paths,
        })

        # Get setting defined by subclass using initialize_settings method.
        self.initialize_settings()

        # Update server settings with extension settings.
        webapp.settings.update(**self.settings)

    def _prepare_handlers(self):
        webapp = self.serverapp.web_app

        # Get handlers defined by extension subclass.
        self.initialize_handlers()

        # prepend base_url onto the patterns that we match
        new_handlers = []
        for handler_items in self.handlers:
            # Build url pattern including base_url
            pattern = url_path_join(webapp.settings['base_url'],
                                    handler_items[0])
            handler = handler_items[1]

            # Get handler kwargs, if given
            kwargs = {}
            if issubclass(handler, ExtensionHandler):
                kwargs['extension_name'] = self.extension_name
            try:
                kwargs.update(handler_items[2])
            except IndexError:
                pass

            new_handler = (pattern, handler, kwargs)
            new_handlers.append(new_handler)

        # Add static endpoint for this extension, if static paths are given.
        if len(self.static_paths) > 0:
            # Append the extension's static directory to server handlers.
            static_url = url_path_join("/static", self.extension_name, "(.*)")

            # Construct handler.
            handler = (static_url, webapp.settings['static_handler_class'], {
                'path': self.static_paths
            })
            new_handlers.append(handler)

        webapp.add_handlers('.*$', new_handlers)

    def _prepare_templates(self):
        # Add templates to web app settings if extension has templates.
        if len(self.template_paths) > 0:
            self.settings.update({
                "{}_template_paths".format(self.extension_name):
                self.template_paths
            })
        self.initialize_templates()

    @staticmethod
    def initialize_server(argv=[], **kwargs):
        """Get an instance of the Jupyter Server."""
        # Get a jupyter server instance
        serverapp = ServerApp(**kwargs)
        # Initialize ServerApp config.
        # Parses the command line looking for
        # ServerApp configuration.
        serverapp.initialize(argv=argv)
        return serverapp

    def initialize(self, serverapp, argv=[]):
        """Initialize the extension app.
        
        This method:
        - Loads the extension's config from file
        - Updates the extension's config from argv
        - Initializes templates environment
        - Passes settings to webapp
        - Appends handlers to webapp.
        """
        self._validate_extension_name()
        # Initialize the extension application
        super(ExtensionApp, self).initialize(argv=argv)
        self.serverapp = serverapp

        # Initialize config, settings, templates, and handlers.
        self._prepare_config()
        self._prepare_templates()
        self._prepare_settings()
        self._prepare_handlers()

    def listen(self):
        """Extension doesn't listen for anything. IOLoop lives
        in the attached server.
        """
        pass

    def start(self):
        """Extension has nothing to `start`. See `start_server`
        for starting the jupyter server application.
        """
        pass

    def start_server(self, **kwargs):
        """Start the Jupyter server.
        
        Server should be started after extension is initialized.
        """
        # Start the server.
        self.serverapp.start(**kwargs)

    @classmethod
    def load_jupyter_server_extension(cls, serverapp, argv=[], **kwargs):
        """Initialize and configure this extension, then add the extension's
        settings and handlers to the server's web application.
        """
        # Configure and initialize extension.
        extension = cls()
        extension.initialize(serverapp, argv=argv)

        return extension

    @classmethod
    def _prepare_launch(cls, serverapp, argv=[], **kwargs):
        """Prepare the extension application for launch by 
        configuring the server and the extension from argv.
        Does not start the ioloop.
        """
        # Load the extension
        extension = cls.load_jupyter_server_extension(serverapp,
                                                      argv=argv,
                                                      **kwargs)

        # Start the browser at this extensions default_url, unless user
        # configures ServerApp.default_url on command line.
        try:
            server_config = extension.config['ServerApp']
            if 'default_url' not in server_config:
                serverapp.default_url = extension.default_url
        except KeyError:
            pass

        return extension

    @classmethod
    def launch_instance(cls, argv=None, **kwargs):
        """Launch the extension like an application. Initializes+configs a stock server 
        and appends the extension to the server. Then starts the server and routes to
        extension's landing page.
        """
        # Check for help, version, and generate-config arguments
        # before initializing server to make sure these
        # arguments trigger actions from the extension not the server.
        _preparse_command_line(cls)
        # Handle arguments.
        if argv is None:
            args = sys.argv[1:]  # fall back to the process command line.
        else:
            args = argv

        # Get a jupyter server instance.
        serverapp = cls.initialize_server(argv=args)
        extension = cls._prepare_launch(serverapp, argv=args, **kwargs)
        # Start the ioloop.
        extension.start_server()
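
A hypothetical minimal subclass, sketched against the class above, shows the second initialization path from the docstring (launching directly via `launch_instance`); `HelloHandler` and all names here are illustrative assumptions, not part of the original source.

class HelloExtensionApp(ExtensionApp):
    """An illustrative extension; HelloHandler is an assumed Tornado handler."""

    @default("extension_name")
    def _default_extension_name(self):
        return "hello"

    def initialize_handlers(self):
        # base_url is prepended by _prepare_handlers above.
        self.handlers.append(('/hello', HelloHandler))

    def initialize_settings(self):
        self.settings['hello_greeting'] = 'Hello from an extension!'

# Either list the extension as a jpserver_extension, or launch it directly:
#     HelloExtensionApp.launch_instance()
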
Beispiel #5
0
class IPClusterStart(IPClusterEngines):

    name = u'ipcluster'
    description = start_help
    examples = _start_examples
    default_log_level = logging.INFO
    auto_create = Bool(
        True,
        config=True,
        help="whether to create the profile_dir if it doesn't exist")
    classes = List()

    def _classes_default(self):
        from ipyparallel.apps import launcher
        return [ProfileDir] + [IPClusterEngines] + launcher.all_launchers

    clean_logs = Bool(True,
                      config=True,
                      help="whether to cleanup old logs before starting")

    delay = CFloat(
        1.,
        config=True,
        help="delay (in s) between starting the controller and the engines")

    controller_ip = Unicode(config=True,
                            help="Set the IP address of the controller.")
    controller_launcher = Any(config=True,
                              help="Deprecated, use controller_launcher_class")

    def _controller_launcher_changed(self, name, old, new):
        if isinstance(new, string_types):
            # old 0.11-style config
            self.log.warn(
                "WARNING: %s.controller_launcher is deprecated as of 0.12,"
                " use controller_launcher_class" % self.__class__.__name__)
            self.controller_launcher_class = new

    controller_launcher_class = DottedObjectName(
        'LocalControllerLauncher',
        config=True,
        help=
        """The class for launching a Controller. Change this value if you want
        your controller to also be launched by a batch system, such as PBS,SGE,MPI,etc.

        Each launcher class has its own set of configuration options, for making sure
        it will work in your environment.
        
        Note that using a batch launcher for the controller *does not* put it
        in the same batch job as the engines, so they will still start separately.

        IPython's bundled examples include:

            Local : start engines locally as subprocesses
            MPI : use mpiexec to launch the controller in an MPI universe
            PBS : use PBS (qsub) to submit the controller to a batch queue
            SGE : use SGE (qsub) to submit the controller to a batch queue
            LSF : use LSF (bsub) to submit the controller to a batch queue
            HTCondor : use HTCondor to submit the controller to a batch queue
            Slurm : use Slurm to submit engines to a batch queue
            SSH : use SSH to start the controller
            WindowsHPC : use Windows HPC

        If you are using one of IPython's builtin launchers, you can specify just the
        prefix, e.g:

            c.IPClusterStart.controller_launcher_class = 'SSH'

        or:

            ipcluster start --controller=MPI

        """)
    reset = Bool(False,
                 config=True,
                 help="Whether to reset config files as part of '--create'.")

    # flags = Dict(flags)
    aliases = Dict(start_aliases)

    def init_launchers(self):
        self.controller_launcher = self.build_launcher(
            self.controller_launcher_class, 'Controller')
        if self.controller_ip:
            self.controller_launcher.controller_args.append('--ip=%s' %
                                                            self.controller_ip)
        self.engine_launcher = self.build_launcher(self.engine_launcher_class,
                                                   'EngineSet')

    def engines_stopped(self, r):
        """prevent parent.engines_stopped from stopping everything on engine shutdown"""
        pass

    def start_controller(self):
        self.log.info("Starting Controller with %s",
                      self.controller_launcher_class)
        self.controller_launcher.on_stop(self.stop_launchers)
        try:
            self.controller_launcher.start()
        except:
            self.log.exception("Controller start failed")
            raise

    def stop_controller(self):
        # self.log.info("In stop_controller")
        if self.controller_launcher and self.controller_launcher.running:
            return self.controller_launcher.stop()

    def stop_launchers(self, r=None):
        if not self._stopping:
            self.stop_controller()
            super(IPClusterStart, self).stop_launchers()

    def start(self):
        """Start the app for the start subcommand."""
        # First see if the cluster is already running
        try:
            pid = self.get_pid_from_file()
        except PIDFileError:
            pass
        else:
            if self.check_pid(pid):
                self.log.critical('Cluster is already running with [pid=%s]. '
                                  'use "ipcluster stop" to stop the cluster.' %
                                  pid)
                # Here I exit with an unusual exit status that other processes
                # can watch for to learn how I exited.
                self.exit(ALREADY_STARTED)
            else:
                self.remove_pid_file()

        # Now log and daemonize
        self.log.info('Starting ipcluster with [daemon=%r]' % self.daemonize)
        # TODO: Get daemonize working on Windows or as a Windows Server.
        if self.daemonize:
            if os.name == 'posix':
                daemonize()

        def start():
            self.start_controller()
            self.loop.add_timeout(self.loop.time() + self.delay,
                                  self.start_engines)

        self.loop.add_callback(start)
        # Now write the new pid file AFTER our new forked pid is active.
        self.write_pid_file()
        try:
            self.loop.start()
        except KeyboardInterrupt:
            pass
        except zmq.ZMQError as e:
            if e.errno == errno.EINTR:
                pass
            else:
                raise
        finally:
            self.remove_pid_file()
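
For configuration, a hedged sketch of an `ipcluster_config.py` fragment exercising the traits defined above (all values are illustrative):

# Illustrative ipcluster_config.py fragment; get_config() is supplied by the
# traitlets config loader when this file is read.
c = get_config()

c.IPClusterStart.delay = 2.0                 # seconds between controller and engine start
c.IPClusterStart.clean_logs = True           # remove old logs before starting
c.IPClusterStart.controller_ip = '10.0.0.1'  # appended as --ip=... to the controller args
c.IPClusterStart.controller_launcher_class = 'SSH'
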
Beispiel #6
0
class Figure(ipywebrtc.MediaStream):
    """Widget class representing a volume (rendering) using three.js"""
    _view_name = Unicode('FigureView').tag(sync=True)
    _view_module = Unicode('ipyvolume').tag(sync=True)
    _model_name = Unicode('FigureModel').tag(sync=True)
    _model_module = Unicode('ipyvolume').tag(sync=True)
    _view_module_version = Unicode(semver_range_frontend).tag(sync=True)
    _model_module_version = Unicode(semver_range_frontend).tag(sync=True)

    volume_data = Array(default_value=None,
                        allow_none=True).tag(sync=True,
                                             **array_cube_tile_serialization)
    eye_separation = traitlets.CFloat(6.4).tag(sync=True)
    data_min = traitlets.CFloat().tag(sync=True)
    data_max = traitlets.CFloat().tag(sync=True)
    opacity_scale = traitlets.CFloat(1.0).tag(sync=True)
    tf = traitlets.Instance(TransferFunction, allow_none=True).tag(
        sync=True, **ipywidgets.widget_serialization)

    scatters = traitlets.List(traitlets.Instance(Scatter), [],
                              allow_none=False).tag(
                                  sync=True, **ipywidgets.widget_serialization)
    meshes = traitlets.List(traitlets.Instance(Mesh), [],
                            allow_none=False).tag(
                                sync=True, **ipywidgets.widget_serialization)

    animation = traitlets.Float(1000.0).tag(sync=True)
    animation_exponent = traitlets.Float(.5).tag(sync=True)

    ambient_coefficient = traitlets.Float(0.5).tag(sync=True)
    diffuse_coefficient = traitlets.Float(0.8).tag(sync=True)
    specular_coefficient = traitlets.Float(0.5).tag(sync=True)
    specular_exponent = traitlets.Float(5).tag(sync=True)
    stereo = traitlets.Bool(False).tag(sync=True)

    camera_control = traitlets.Unicode(default_value='trackball').tag(
        sync=True)
    camera_fov = traitlets.CFloat(45, min=0.1, max=179.9).tag(sync=True)
    camera_center = traitlets.List(traitlets.CFloat,
                                   default_value=[0, 0, 0]).tag(sync=True)
    #Tuple(traitlets.CFloat(0), traitlets.CFloat(0), traitlets.CFloat(0)).tag(sync=True)

    camera = traitlets.Instance(pythreejs.Camera).tag(
        sync=True, **ipywidgets.widget_serialization)

    @traitlets.default('camera')
    def _default_camera(self):
        # return pythreejs.CombinedCamera(fov=46, position=(0, 0, 2), width=400, height=500)
        return pythreejs.PerspectiveCamera(fov=46,
                                           position=(0, 0, 2),
                                           width=400,
                                           height=500)

    scene = traitlets.Instance(pythreejs.Scene).tag(
        sync=True, **ipywidgets.widget_serialization)

    @traitlets.default('scene')
    def _default_scene(self):
        # could be removed when https://github.com/jovyan/pythreejs/issues/176 is solved
        # the default for pythreejs is white, which leads the volume rendering pass to make everything white
        return pythreejs.Scene(background=None)

    width = traitlets.CInt(500).tag(sync=True)
    height = traitlets.CInt(400).tag(sync=True)
    downscale = traitlets.CInt(1).tag(sync=True)
    show = traitlets.Unicode("Volume").tag(sync=True)  # for debugging

    xlim = traitlets.List(traitlets.CFloat,
                          default_value=[0, 1],
                          minlen=2,
                          maxlen=2).tag(sync=True)
    ylim = traitlets.List(traitlets.CFloat,
                          default_value=[0, 1],
                          minlen=2,
                          maxlen=2).tag(sync=True)
    zlim = traitlets.List(traitlets.CFloat,
                          default_value=[0, 1],
                          minlen=2,
                          maxlen=2).tag(sync=True)

    extent = traitlets.Any().tag(sync=True)

    matrix_projection = traitlets.List(traitlets.CFloat,
                                       default_value=[0] * 16,
                                       allow_none=True,
                                       minlen=16,
                                       maxlen=16).tag(sync=True)
    matrix_world = traitlets.List(traitlets.CFloat,
                                  default_value=[0] * 16,
                                  allow_none=True,
                                  minlen=16,
                                  maxlen=16).tag(sync=True)

    xlabel = traitlets.Unicode("x").tag(sync=True)
    ylabel = traitlets.Unicode("y").tag(sync=True)
    zlabel = traitlets.Unicode("z").tag(sync=True)

    style = traitlets.Dict(default_value=ipyvolume.styles.default).tag(
        sync=True)

    render_continuous = traitlets.Bool(False).tag(sync=True)
    selector = traitlets.Unicode(default_value='lasso').tag(sync=True)
    selection_mode = traitlets.Unicode(default_value='replace').tag(sync=True)
    mouse_mode = traitlets.Unicode(default_value='normal').tag(sync=True)

    #xlim = traitlets.Tuple(traitlets.CFloat(0), traitlets.CFloat(1)).tag(sync=True)
    #ylim = traitlets.Tuple(traitlets.CFloat(0), traitlets.CFloat(1)).tag(sync=True)
    #zlim = traitlets.Tuple(traitlets.CFloat(0), traitlets.CFloat(1)).tag(sync=True)

    def __init__(self, **kwargs):
        super(Figure, self).__init__(**kwargs)
        self._screenshot_handlers = widgets.CallbackDispatcher()
        self._selection_handlers = widgets.CallbackDispatcher()
        self.on_msg(self._handle_custom_msg)

    def __enter__(self):
        """Sets this figure as the current in the pylab API

        Example:
        >>> f1 = ipv.figure(1)
        >>> f2 = ipv.figure(2)
        >>> with f1:
        >>>  ipv.scatter(x, y, z)
        >>> assert ipv.gcf() is f2
        """

        import ipyvolume as ipv
        self._previous_figure = ipv.gcf()
        ipv.figure(self)

    def __exit__(self, type, value, traceback):
        import ipyvolume as ipv
        ipv.figure(self._previous_figure)
        del self._previous_figure

    def screenshot(self, width=None, height=None, mime_type='image/png'):
        self.send({
            'msg': 'screenshot',
            'width': width,
            'height': height,
            'mime_type': mime_type
        })

    def on_screenshot(self, callback, remove=False):
        self._screenshot_handlers.register_callback(callback, remove=remove)

    def _handle_custom_msg(self, content, buffers):
        if content.get('event', '') == 'screenshot':
            self._screenshot_handlers(content['data'])
        elif content.get('event', '') == 'selection':
            self._selection_handlers(content['data'])

    def on_selection(self, callback, remove=False):
        self._selection_handlers.register_callback(callback, remove=remove)

    def project(self, x, y, z):
        W = np.matrix(self.matrix_world).reshape((4, 4)).T
        P = np.matrix(self.matrix_projection).reshape((4, 4)).T
        M = np.dot(P, W)
        x = np.asarray(x)
        vertices = np.array([x, y, z, np.ones(x.shape)])
        screen_h = np.tensordot(M, vertices, axes=(1, 0))
        xy = screen_h[:2] / screen_h[3]
        return xy
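
A standalone numpy sketch of the homogeneous projection that `Figure.project` performs above: the flattened world and projection matrices are composed, points are lifted to homogeneous coordinates, and the perspective divide yields screen coordinates (identity matrices here, so points project onto themselves).

import numpy as np

# Flattened 4x4 matrices, as the matrix_world/matrix_projection traits store them.
matrix_world = np.eye(4).ravel().tolist()
matrix_projection = np.eye(4).ravel().tolist()

W = np.array(matrix_world).reshape((4, 4)).T
P = np.array(matrix_projection).reshape((4, 4)).T
M = np.dot(P, W)

x = np.asarray([0.0, 0.5])
y = np.asarray([0.0, 0.5])
z = np.asarray([0.0, 0.5])
vertices = np.array([x, y, z, np.ones(x.shape)])   # homogeneous coordinates

screen_h = np.tensordot(M, vertices, axes=(1, 0))  # shape (4, n_points)
xy = screen_h[:2] / screen_h[3]                    # perspective divide
print(xy)  # [[0.  0.5] [0.  0.5]] for identity matrices
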
Beispiel #7
0
class ExecutePreprocessor(Preprocessor):
    """
    Executes all the cells in a notebook
    """

    timeout = Integer(30, allow_none=True,
        help=dedent(
            """
            The time to wait (in seconds) for output from executions.
            If a cell execution takes longer, an exception (TimeoutError
            on python 3+, RuntimeError on python 2) is raised.

            `None` or `-1` will disable the timeout. If `timeout_func` is set,
            it overrides `timeout`.
            """
        )
    ).tag(config=True)

    timeout_func = Any(
        default_value=None,
        allow_none=True,
        help=dedent(
            """
            A callable which, when given the cell source as input,
            returns the time to wait (in seconds) for output from cell
            executions. If a cell execution takes longer, an exception
            (TimeoutError on python 3+, RuntimeError on python 2) is
            raised.

            Returning `None` or `-1` will disable the timeout for the cell.
            Not setting `timeout_func` will cause the preprocessor to
            default to using the `timeout` trait for all cells. The
            `timeout_func` trait overrides `timeout` if it is not `None`.
            """
        )
    ).tag(config=True)

    interrupt_on_timeout = Bool(False,
        help=dedent(
            """
            If execution of a cell times out, interrupt the kernel and
            continue executing other cells rather than throwing an error and
            stopping.
            """
        )
    ).tag(config=True)

    startup_timeout = Integer(60,
        help=dedent(
            """
            The time to wait (in seconds) for the kernel to start.
            If kernel startup takes longer, a RuntimeError is
            raised.
            """
        )
    ).tag(config=True)

    allow_errors = Bool(False,
        help=dedent(
            """
            If `False` (default), when a cell raises an error the
            execution is stopped and a `CellExecutionError`
            is raised.
            If `True`, execution errors are ignored and the execution
            is continued until the end of the notebook. Output from
            exceptions is included in the cell output in both cases.
            """
        )
    ).tag(config=True)

    force_raise_errors = Bool(False,
        help=dedent(
            """
            If False (default), errors from executing the notebook can be
            allowed with a `raises-exception` tag on a single cell, or the
            `allow_errors` configurable option for all cells. An allowed error
            will be recorded in notebook output, and execution will continue.
            If an error occurs when it is not explicitly allowed, a
            `CellExecutionError` will be raised.
            If True, `CellExecutionError` will be raised for any error that occurs
            while executing the notebook. This overrides both the
            `allow_errors` option and the `raises-exception` cell tag.
            """
        )
    ).tag(config=True)

    extra_arguments = List(Unicode())

    kernel_name = Unicode('',
        help=dedent(
            """
            Name of kernel to use to execute the cells.
            If not set, use the kernel_spec embedded in the notebook.
            """
        )
    ).tag(config=True)

    raise_on_iopub_timeout = Bool(False,
        help=dedent(
            """
            If `False` (default), then the kernel will continue waiting for
            iopub messages until it receives a kernel idle message, or until a
            timeout occurs, at which point the currently executing cell will be
            skipped. If `True`, then an error will be raised after the first
            timeout. This option generally does not need to be used, but may be
            useful in contexts where there is the possibility of executing
            notebooks with memory-consuming infinite loops.
            """
            )
    ).tag(config=True)

    iopub_timeout = Integer(4, allow_none=False,
        help=dedent(
            """
            The time to wait (in seconds) for IOPub output. This generally
            doesn't need to be set, but on some slow networks (such as CI
            systems) the default timeout might not be long enough to get all
            messages.
            """
        )
    ).tag(config=True)

    shutdown_kernel = Enum(['graceful', 'immediate'],
        default_value='graceful',
        help=dedent(
            """
            If `graceful` (default), then the kernel is given time to clean
            up after executing all cells, e.g., to execute its `atexit` hooks.
            If `immediate`, then the kernel is signaled to immediately
            terminate.
            """
            )
    ).tag(config=True)

    kernel_manager_class = Type(
        config=True,
        help='The kernel manager class to use.'
    )
    @default('kernel_manager_class')
    def _km_default(self):
        """Use a dynamic default to avoid importing jupyter_client at startup"""
        try:
            from jupyter_client import KernelManager
        except ImportError:
            raise ImportError("`nbconvert --execute` requires the jupyter_client package: `pip install jupyter_client`")
        return KernelManager

    # mapping of locations of outputs with a given display_id
    # tracks cell index and output index within cell.outputs for
    # each appearance of the display_id
    # {
    #   'display_id': {
    #     cell_idx: [output_idx,]
    #   }
    # }
    _display_id_map = Dict()

    def preprocess(self, nb, resources):
        """
        Preprocess notebook executing each code cell.

        The input argument `nb` is modified in-place.

        Parameters
        ----------
        nb : NotebookNode
            Notebook being executed.
        resources : dictionary
            Additional resources used in the conversion process. For example,
            passing ``{'metadata': {'path': run_path}}`` sets the
            execution path to ``run_path``.

        Returns
        -------
        nb : NotebookNode
            The executed notebook.
        resources : dictionary
            Additional resources used in the conversion process.
        """
        path = resources.get('metadata', {}).get('path', '')
        if path == '':
            path = None

        # clear display_id map
        self._display_id_map = {}

        # from jupyter_client.manager import start_new_kernel

        def start_new_kernel(startup_timeout=60, kernel_name='python', **kwargs):
            km = self.kernel_manager_class(kernel_name=kernel_name)
            km.start_kernel(**kwargs)
            kc = km.client()
            kc.start_channels()
            try:
                kc.wait_for_ready(timeout=startup_timeout)
            except RuntimeError:
                kc.stop_channels()
                km.shutdown_kernel()
                raise

            return km, kc

        kernel_name = nb.metadata.get('kernelspec', {}).get('name', 'python')
        if self.kernel_name:
            kernel_name = self.kernel_name
        self.log.info("Executing notebook with kernel: %s" % kernel_name)
        self.km, self.kc = start_new_kernel(
            startup_timeout=self.startup_timeout,
            kernel_name=kernel_name,
            extra_arguments=self.extra_arguments,
            cwd=path)
        self.kc.allow_stdin = False
        self.nb = nb

        try:
            nb, resources = super(ExecutePreprocessor, self).preprocess(nb, resources)
        finally:
            self.kc.stop_channels()
            self.km.shutdown_kernel(now=self.shutdown_kernel == 'immediate')

        delattr(self, 'nb')

        return nb, resources

    def preprocess_cell(self, cell, resources, cell_index):
        """
        Executes a single code cell. See base.py for details.

        To execute all cells see :meth:`preprocess`.
        """
        if cell.cell_type != 'code' or not cell.source.strip():
            return cell, resources

        reply, outputs = self.run_cell(cell, cell_index)
        cell.outputs = outputs

        cell_allows_errors = (self.allow_errors or "raises-exception"
                              in cell.metadata.get("tags", []))

        if self.force_raise_errors or not cell_allows_errors:
            for out in outputs:
                if out.output_type == 'error':
                    raise CellExecutionError.from_cell_and_msg(cell, out)
            if (reply is not None) and reply['content']['status'] == 'error':
                raise CellExecutionError.from_cell_and_msg(cell, reply['content'])
        return cell, resources

    def _update_display_id(self, display_id, msg):
        """Update outputs with a given display_id"""
        if display_id not in self._display_id_map:
            self.log.debug("display id %r not in %s", display_id, self._display_id_map)
            return

        if msg['header']['msg_type'] == 'update_display_data':
            msg['header']['msg_type'] = 'display_data'

        try:
            out = output_from_msg(msg)
        except ValueError:
            self.log.error("unhandled iopub msg: " + msg['msg_type'])
            return

        for cell_idx, output_indices in self._display_id_map[display_id].items():
            cell = self.nb['cells'][cell_idx]
            outputs = cell['outputs']
            for output_idx in output_indices:
                outputs[output_idx]['data'] = out['data']
                outputs[output_idx]['metadata'] = out['metadata']

    def _wait_for_reply(self, msg_id, cell):
        # wait for finish, with timeout
        while True:
            try:
                if self.timeout_func is not None:
                    timeout = self.timeout_func(cell)
                else:
                    timeout = self.timeout

                if not timeout or timeout < 0:
                    timeout = None
                msg = self.kc.shell_channel.get_msg(timeout=timeout)
            except Empty:
                self.log.error(
                    "Timeout waiting for execute reply (%is)." % timeout)
                if self.interrupt_on_timeout:
                    self.log.error("Interrupting kernel")
                    self.km.interrupt_kernel()
                    break
                else:
                    try:
                        exception = TimeoutError
                    except NameError:
                        exception = RuntimeError
                    raise exception("Cell execution timed out")

            if msg['parent_header'].get('msg_id') == msg_id:
                return msg
            else:
                # not our reply
                continue

    def run_cell(self, cell, cell_index=0):
        msg_id = self.kc.execute(cell.source)
        self.log.debug("Executing cell:\n%s", cell.source)
        exec_reply = self._wait_for_reply(msg_id, cell)

        outs = cell.outputs = []

        while True:
            try:
                # We've already waited for execute_reply, so all output
                # should already be waiting. However, on slow networks, like
                # in certain CI systems, waiting < 1 second might miss messages.
                # So long as the kernel sends a status:idle message when it
                # finishes, we won't actually have to wait this long, anyway.
                msg = self.kc.iopub_channel.get_msg(timeout=self.iopub_timeout)
            except Empty:
                self.log.warning("Timeout waiting for IOPub output")
                if self.raise_on_iopub_timeout:
                    raise RuntimeError("Timeout waiting for IOPub output")
                else:
                    break
            if msg['parent_header'].get('msg_id') != msg_id:
                # not an output from our execution
                continue

            msg_type = msg['msg_type']
            self.log.debug("output: %s", msg_type)
            content = msg['content']

            # set the prompt number for the input and the output
            if 'execution_count' in content:
                cell['execution_count'] = content['execution_count']

            if msg_type == 'status':
                if content['execution_state'] == 'idle':
                    break
                else:
                    continue
            elif msg_type == 'execute_input':
                continue
            elif msg_type == 'clear_output':
                outs[:] = []
                # clear display_id mapping for this cell
                for display_id, cell_map in self._display_id_map.items():
                    if cell_index in cell_map:
                        cell_map[cell_index] = []
                continue
            elif msg_type.startswith('comm'):
                continue

            display_id = None
            if msg_type in {'execute_result', 'display_data', 'update_display_data'}:
                display_id = msg['content'].get('transient', {}).get('display_id', None)
                if display_id:
                    self._update_display_id(display_id, msg)
                if msg_type == 'update_display_data':
                    # update_display_data doesn't get recorded
                    continue

            try:
                out = output_from_msg(msg)
            except ValueError:
                self.log.error("unhandled iopub msg: " + msg_type)
                continue
            if display_id:
                # record output index in:
                #   _display_id_map[display_id][cell_idx]
                cell_map = self._display_id_map.setdefault(display_id, {})
                output_idx_list = cell_map.setdefault(cell_index, [])
                output_idx_list.append(len(outs))

            outs.append(out)

        return exec_reply, outs
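
A minimal usage sketch for the preprocessor above, assuming a notebook file `notebook.ipynb` exists; per the preprocess docstring, the notebook is modified in place and resources['metadata']['path'] sets the execution directory.

import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

nb = nbformat.read('notebook.ipynb', as_version=4)      # hypothetical input file
ep = ExecutePreprocessor(timeout=120, kernel_name='python3')

# 'path' is the directory the kernel runs in (see the preprocess docstring)
nb, resources = ep.preprocess(nb, {'metadata': {'path': '.'}})

nbformat.write(nb, 'executed.ipynb')
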
Beispiel #8
0
class PAMAuthenticator(LocalAuthenticator):
    """Authenticate local UNIX users with PAM"""

    # run PAM in a thread, since it can be slow
    executor = Any()

    @default('executor')
    def _default_executor(self):
        return ThreadPoolExecutor(1)

    encoding = Unicode(
        'utf8',
        help="""
        The text encoding to use when communicating with PAM
        """,
    ).tag(config=True)

    service = Unicode(
        'login',
        help="""
        The name of the PAM service to use for authentication
        """,
    ).tag(config=True)

    open_sessions = Bool(
        True,
        help="""
        Whether to open a new PAM session when spawners are started.

        This may trigger things like mounting shared filesystems,
        loading credentials, etc. depending on system configuration,
        but it does not always work.

        If any errors are encountered when opening/closing PAM sessions,
        this is automatically set to False.
        """,
    ).tag(config=True)

    check_account = Bool(
        True,
        help="""
        Whether to check the user's account status via PAM during authentication.

        The PAM account stack performs non-authentication based account 
        management. It is typically used to restrict/permit access to a 
        service and this step is needed to access the host's user access control.

        Disabling this can be dangerous as authenticated but unauthorized users may
        be granted access and, therefore, arbitrary execution on the system.
        """,
    ).tag(config=True)

    admin_groups = Set(
        help="""
        Authoritative list of user groups that determine admin access.
        Users not in these groups can still be granted admin status through admin_users.

        White/blacklisting rules still apply.
        """
    ).tag(config=True)

    pam_normalize_username = Bool(
        False,
        help="""
        Round-trip the username via PAM lookups to make sure it is unique

        PAM can accept multiple usernames that map to the same user,
        for example DOMAIN\\username in some cases.  To prevent this,
        convert the username to a uid, then back to a username to normalize.
        """,
    ).tag(config=True)

    def __init__(self, **kwargs):
        if pamela is None:
            raise _pamela_error from None
        super().__init__(**kwargs)

    @run_on_executor
    def is_admin(self, handler, authentication):
        """PAM admin status checker. Returns Bool to indicate user admin status."""
        # Checks upper level function (admin_users)
        admin_status = super().is_admin(handler, authentication)
        username = authentication['name']

        # If not yet listed as an admin, and admin_groups is on, use it authoritatively
        if not admin_status and self.admin_groups:
            try:
                # Most likely source of error here is a group name <-> gid mapping failure.
                # This is most likely due to a typo in the configuration or, in the case of LDAP/AD, a network
                # connectivity issue. Maybe a long one where the local caches have timed out, though PAM
                # most likely would refuse to authenticate a remote user by that point.

                # It was decided that the best course of action on group resolution failure was to
                # fail to authenticate and raise instead of soft-failing and not changing admin status
                # (returning None instead of just the username) as this indicates some sort of system failure

                admin_group_gids = {self._getgrnam(x).gr_gid for x in self.admin_groups}
                user_group_gids = set(
                    self._getgrouplist(username, self._getpwnam(username).pw_gid)
                )
                admin_status = len(admin_group_gids & user_group_gids) != 0

            except Exception as e:
                if handler is not None:
                    self.log.error(
                        "PAM Admin Group Check failed (%s@%s): %s",
                        username,
                        handler.request.remote_ip,
                        e,
                    )
                else:
                    self.log.error("PAM Admin Group Check failed: %s", e)
                # re-raise to return a 500 to the user and indicate a problem. We failed, not them.
                raise

        return admin_status

    @run_on_executor
    def authenticate(self, handler, data):
        """Authenticate with PAM, and return the username if login is successful.

        Return None otherwise.
        """
        username = data['username']
        try:
            pamela.authenticate(
                username, data['password'], service=self.service, encoding=self.encoding
            )
        except pamela.PAMError as e:
            if handler is not None:
                self.log.warning(
                    "PAM Authentication failed (%s@%s): %s",
                    username,
                    handler.request.remote_ip,
                    e,
                )
            else:
                self.log.warning("PAM Authentication failed: %s", e)
            return None

        if self.check_account:
            try:
                pamela.check_account(
                    username, service=self.service, encoding=self.encoding
                )
            except pamela.PAMError as e:
                if handler is not None:
                    self.log.warning(
                        "PAM Account Check failed (%s@%s): %s",
                        username,
                        handler.request.remote_ip,
                        e,
                    )
                else:
                    self.log.warning("PAM Account Check failed: %s", e)
                return None

        return username

    @run_on_executor
    def pre_spawn_start(self, user, spawner):
        """Open PAM session for user if so configured"""
        if not self.open_sessions:
            return
        try:
            pamela.open_session(user.name, service=self.service, encoding=self.encoding)
        except pamela.PAMError as e:
            self.log.warning("Failed to open PAM session for %s: %s", user.name, e)
            self.log.warning("Disabling PAM sessions from now on.")
            self.open_sessions = False

    @run_on_executor
    def post_spawn_stop(self, user, spawner):
        """Close PAM session for user if we were configured to opened one"""
        if not self.open_sessions:
            return
        try:
            pamela.close_session(
                user.name, service=self.service, encoding=self.encoding
            )
        except pamela.PAMError as e:
            self.log.warning("Failed to close PAM session for %s: %s", user.name, e)
            self.log.warning("Disabling PAM sessions from now on.")
            self.open_sessions = False

    def normalize_username(self, username):
        """Round-trip the username to normalize it with PAM

        PAM can accept multiple usernames that map to the same user; normalize them."""
        if self.pam_normalize_username:
            import pwd

            uid = pwd.getpwnam(username).pw_uid
            username = pwd.getpwuid(uid).pw_name
            return self.username_map.get(username, username)
        else:
            return super().normalize_username(username)
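
An illustrative `jupyterhub_config.py` fragment setting the traits defined above (the group name is a placeholder):

# Illustrative jupyterhub_config.py fragment; get_config() is supplied by the
# JupyterHub config loader when this file is read.
c = get_config()

c.JupyterHub.authenticator_class = 'jupyterhub.auth.PAMAuthenticator'
c.PAMAuthenticator.service = 'login'           # PAM service name
c.PAMAuthenticator.check_account = True        # run the PAM account stack
c.PAMAuthenticator.open_sessions = False       # skip PAM sessions on spawn
c.PAMAuthenticator.admin_groups = {'wheel'}    # members of these groups become admins
c.PAMAuthenticator.pam_normalize_username = True
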
Beispiel #9
0
class VitessceWidget(widgets.DOMWidget):
    """
    A class to represent a Jupyter widget for Vitessce.
    """

    # Name of the widget view class in front-end
    _view_name = Unicode('VitessceView').tag(sync=True)

    # Name of the widget model class in front-end
    _model_name = Unicode('VitessceModel').tag(sync=True)

    # Name of the front-end module containing widget view
    _view_module = Unicode('vitessce-jupyter').tag(sync=True)

    # Name of the front-end module containing widget model
    _model_module = Unicode('vitessce-jupyter').tag(sync=True)

    # Version of the front-end module containing widget view
    _view_module_version = Unicode(
        '^%s.%s.%s' %
        (js_version_info[0], js_version_info[1], js_version_info[2])).tag(
            sync=True)
    # Version of the front-end module containing widget model
    _model_module_version = Unicode(
        '^%s.%s.%s' %
        (js_version_info[0], js_version_info[1], js_version_info[2])).tag(
            sync=True)

    # Widget specific property.
    # Widget properties are defined as traitlets. Any property tagged with `sync=True`
    # is automatically synced to the frontend *any* time it changes in Python.
    # It is synced back to Python from the frontend *any* time the model is touched.
    config = Dict({}).tag(sync=True)
    height = Int(600).tag(sync=True)
    theme = Unicode('auto').tag(sync=True)
    proxy = Bool(False).tag(sync=True)

    next_port = DEFAULT_PORT

    def __init__(self,
                 config,
                 height=600,
                 theme='auto',
                 port=None,
                 proxy=False):
        """
        Construct a new Vitessce widget.

        :param config: A view config instance.
        :type config: VitessceConfig
        :param str theme: The theme name, either "light" or "dark". By default, "auto", which selects light or dark based on operating system preferences.
        :param int height: The height of the widget, in pixels. By default, 600.
        :param int port: The port to use when serving data objects on localhost. By default, 8000.
        :param bool proxy: Is this widget being served through a proxy, for example with a cloud notebook (e.g. Binder)?

        .. code-block:: python
            :emphasize-lines: 4

            from vitessce import VitessceConfig, VitessceWidget

            vc = VitessceConfig.from_object(my_scanpy_object)
            vw = vc.widget()
            vw
        """

        base_url, use_port, VitessceWidget.next_port = get_base_url_and_port(
            port, VitessceWidget.next_port, proxy=proxy)
        config_dict = config.to_dict(base_url=base_url)
        routes = config.get_routes()

        super(VitessceWidget, self).__init__(config=config_dict,
                                             height=height,
                                             theme=theme,
                                             proxy=proxy)

        serve_routes(routes, use_port)

    def _get_coordination_value(self, coordination_type, coordination_scope):
        obj = self.config['coordinationSpace'][coordination_type]
        obj_scopes = list(obj.keys())
        if coordination_scope is not None:
            if coordination_scope in obj_scopes:
                return obj[coordination_scope]
            else:
                raise ValueError(
                    f"The specified coordination scope '{coordination_scope}' could not be found for the coordination type '{coordination_type}'. Known coordination scopes are {obj_scopes}"
                )
        else:
            if len(obj_scopes) == 1:
                auto_coordination_scope = obj_scopes[0]
                return obj[auto_coordination_scope]
            elif len(obj_scopes) > 1:
                raise ValueError(
                    f"The coordination scope could not be automatically determined because multiple coordination scopes exist for the coordination type '{coordination_type}'. Please specify one of {obj_scopes} using the scope parameter."
                )
            else:
                raise ValueError(
                    f"No coordination scopes were found for the coordination type '{coordination_type}'."
                )

    def get_cell_selection(self, scope=None):
        return self._get_coordination_value('cellSelection', scope)
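
A hedged construction sketch based on the constructor docstring above; `my_scanpy_object` is a placeholder for the user's data object.

from vitessce import VitessceConfig, VitessceWidget

vc = VitessceConfig.from_object(my_scanpy_object)
vw = VitessceWidget(vc, height=800, theme='dark')
vw  # display the widget in the notebook

# After interacting with the view, read a coordination value back:
selection = vw.get_cell_selection()
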
Beispiel #10
0
class Application(v.VuetifyTemplate, HubListener):
    _metadata = Dict({'mount_id': 'content'}).tag(sync=True)

    show_menu_bar = Bool(True).tag(sync=True)
    show_toolbar = Bool(True).tag(sync=True)
    show_tray_bar = Bool(True).tag(sync=True)

    template = load_template("app.vue", __file__).tag(sync=True)
    methods = Unicode("""
    {
        checkNotebookContext() {
            this.notebook_context = document.getElementById("ipython-main-app");
            return this.notebook_context;
        },

        loadRemoteCSS() {
            var muiIconsSheet = document.createElement('link');
            muiIconsSheet.type='text/css';
            muiIconsSheet.rel='stylesheet';
            muiIconsSheet.href='https://cdn.jsdelivr.net/npm/@mdi/[email protected]/css/materialdesignicons.min.css';
            document.getElementsByTagName('head')[0].appendChild(muiIconsSheet);
            return true;
        }
    }
    """).tag(sync=True)

    css = load_template("app.css", __file__).tag(sync=True)

    def __init__(self, configuration=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._application_handler = JupyterApplication()

        # The default configuration file is parsed below, after the components
        #  for the toolbar/tray_bar have been created.
        # self.load_configuration(configuration)

        plugins = {
            entry_point.name: entry_point.load()
            for entry_point in pkg_resources.iter_entry_points(group='plugins')
        }

        components = {
            'g-viewer-area': ViewerArea(session=self.session),
            'g-default-toolbar': DefaultToolbar(session=self.session),
            'g-tray-area': TrayArea(session=self.session)
        }

        components.update(
            {k: v(session=self.session)
             for k, v in tools.members.items()})

        self.components = components

        # Parse configuration
        self.load_configuration(configuration)

        # Subscribe to viewer messages
        self.hub.subscribe(self, NewViewerMessage, handler=self._on_new_viewer)

        # Subscribe to load data messages
        self.hub.subscribe(self,
                           LoadDataMessage,
                           handler=lambda msg: self.load_data(msg.path))

    @property
    def hub(self):
        return self._application_handler.data_collection.hub

    @property
    def session(self):
        return self._application_handler.session

    def load_data(self, path):
        self._application_handler.load_data(path)

    def _on_new_viewer(self, msg):
        view = self._application_handler.new_data_viewer(msg.cls,
                                                         data=msg.data,
                                                         show=False)

        if msg.x_attr is not None:
            x = msg.data.id[msg.x_attr]
            view.state.x_att = x

        self.hub.broadcast(AddViewerMessage(view, sender=self))

        return view

    def _registry_component(self):
        pass

    def load_configuration(self, path):
        # Parse the default configuration file
        default_path = os.path.join(os.path.dirname(__file__), "configs")

        plugins = {
            entry_point.name: entry_point.load()
            for entry_point in pkg_resources.iter_entry_points(group='plugins')
        }

        if path is None or path == 'default':
            path = os.path.join(default_path, "default", "default.yaml")
        elif path == 'cubeviz':
            path = os.path.join(default_path, "cubeviz", "cubeviz.yaml")
        elif not os.path.isfile(path):
            raise ValueError("Configuration must be path to a .yaml file.")

        with open(path, 'r') as f:
            config = yaml.safe_load(f)

        # Get a reference to the component visibility states
        # comps = config.get('components', {})

        # Toggle the rendering of the components in the gui
        # self.show_menu_bar = comps.get('menu_bar', True)
        # self.show_toolbar = comps.get('toolbar', True)
        # self.show_tray_bar = comps.get('tray_bar', True)

        if 'viewer_area' in config:
            viewer_area_layout = config.get('viewer_area')
            self.components.get('g-viewer-area').parse_layout(
                viewer_area_layout)

        # Add the toolbar item filter to the toolbar component
        for name in config.get('toolbar', []):
            tool = tools.members.get(name)(session=self.session)
            self.components['g-default-toolbar'].add_tool(tool)
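
A short usage sketch based on the constructor and methods above; the data path is a placeholder.

# 'default' and 'cubeviz' are the built-in configurations handled above;
# any other value must be a path to a .yaml file.
app = Application(configuration='cubeviz')
app.load_data('my_cube.fits')   # hypothetical file; routed through the glue session
app                             # display the Vuetify application in the notebook
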
Beispiel #11
0
class Displayable(StepDOMWidget):
    """A Jupyter widget for visualizing constraint satisfaction problems (CSPs).

    Handles arc consistency, domain splitting, and stochastic local search (SLS).

    See the accompanying frontend file: `js/src/csp/CSPVisualizer.ts`
    """
    _view_name = Unicode('CSPViewer').tag(sync=True)
    _model_name = Unicode('CSPViewerModel').tag(sync=True)
    _view_module = Unicode('aispace2').tag(sync=True)
    _model_module = Unicode('aispace2').tag(sync=True)
    _view_module_version = Unicode(__version__).tag(sync=True)
    _model_module_version = Unicode(__version__).tag(sync=True)

    # The CSP that is synced as a graph to the frontend.
    graph = Instance(klass=CSP, allow_none=True).tag(sync=True,
                                                     to_json=csp_to_json,
                                                     from_json=json_to_csp)

    # Controls whether the auto arc consistency button will show up in the widget (it will not in SLS)
    need_AC_button = Bool(True).tag(sync=True)

    # Tracks if the visualization has been rendered at least once in the front-end. See the @visualize decorator.
    _previously_rendered = Bool(False).tag(sync=True)
    wait_for_render = Bool(True).tag(sync=True)

    def __init__(self):
        super().__init__()
        self.visualizer = self

        ##############################
        ### SLS-specific variables ###
        ##############################
        # Tracks if this is the first conflict reported.
        # If so, will also compute non-conflicts to highlight green the first time around.
        self._sls_first_conflict = True

        ##########################################
        ### Arc consistency-specific variables ###
        ##########################################
        # A reference to the arc the user has selected for arc consistency. A tuple of (variable name, Constraint instance).
        self._selected_arc = None
        # True if the user has selected an arc to perform arc-consistency on. Otherwise, an arc is automatically chosen.
        self._has_user_selected_arc = False
        # True if the algorithm is at a point where an arc is waiting to be chosen. Used to filter out extraneous clicks otherwise.
        self._is_waiting_for_arc_selection = False

        ###########################################
        ### Domain splitting-specific variables ###
        ###########################################
        # A reference to the variable the user has selected for domain splitting.
        self._selected_var = None
        # True if the user has selected a var to perform domain splitting on. Otherwise, a variable is automatically chosen.
        self._has_user_selected_var = False
        # True if the algorithm is at a point where a var is waiting to be chosen. Used to filter out extraneous clicks otherwise.
        self._is_waiting_for_var_selection = False
        # The domain the user has chosen as their first split for `_selected_var`.
        self._domain_split = None

        # self.graph = self.csp
        self.graph = CSP(self.csp.domains, self.csp.constraints,
                         self.csp.positions)
        (self._domain_map,
         self._edge_map) = generate_csp_graph_mappings(self.csp)

        self._initialize_controls()

    def wait_for_arc_selection(self, to_do):
        """Pauses execution until an arc has been selected and returned.

        If the algorithm is running in auto mode, an arc is returned immediately.
        Otherwise, this function blocks until an arc is selected by the user.

        Args:
            to_do (set): A set of arcs to choose from. This set will be modified.

        Returns:
            (string, Constraint):
                A tuple (var_name, constraint) that represents an arc from `to_do`.
        """
        # Running in Auto mode. Don't block!
        if self.max_display_level == 1 or self.max_display_level == 0:
            return to_do.pop()

        self._is_waiting_for_arc_selection = True
        self._block_for_user_input.wait()

        if self._has_user_selected_arc:
            self._has_user_selected_arc = False
            to_do.discard(self._selected_arc)
            return self._selected_arc

        # User did not select. Return random arc.
        return to_do.pop()
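
    # The pause above relies on `_block_for_user_input`, presumably a shared
    # threading.Event. A minimal standalone sketch of that handshake between
    # an algorithm thread and a UI callback (all names here are illustrative,
    # not part of this widget's API):
    #
    #     import threading, time
    #
    #     block_for_user_input = threading.Event()
    #
    #     def algorithm_thread():
    #         block_for_user_input.wait()   # blocks until the UI signals
    #         print('resumed after user input')
    #
    #     def on_user_click():
    #         block_for_user_input.set()    # wake the waiting thread...
    #         block_for_user_input.clear()  # ...and re-arm for the next pause
    #
    #     t = threading.Thread(target=algorithm_thread)
    #     t.start()
    #     time.sleep(0.1)  # let the thread reach wait() before signalling
    #     on_user_click()
    #     t.join()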

    def wait_for_var_selection(self, iter_var):
        """Pauses execution until a variable has been selected and returned.

        If the user steps instead of clicking on a variable, a random variable is returned.
        Otherwise, the variable clicked by the user is returned, provided it is one that can
        be split on; if not, this function continues waiting.

        Args:
            iter_var (iter): Variables that the user is allowed to split on.

        Returns:
            (string): The variable to split on.
        """
        # Running in Auto mode. Choose the first variable!
        if self.max_display_level == 1:
            return list(iter_var)[0]

        # Running in Auto Arc Consistency mode. Change to normal!
        if self.max_display_level == 0:
            self.max_display_level = 2

        iter_var = list(iter_var)
        self._send_highlight_splittable_nodes_action(iter_var)
        self._is_waiting_for_var_selection = True
        self._block_for_user_input.wait()

        while (self.max_display_level != 1
               and not self._has_user_selected_var):
            self._block_for_user_input.wait()

        if self._has_user_selected_var:
            self._has_user_selected_var = False
            if self._selected_var in iter_var:
                return self._selected_var
            else:
                return self.wait_for_var_selection(iter_var)

        self._is_waiting_for_var_selection = False
        return iter_var[0]

    def choose_domain_partition(self, domain, var):
        """Pauses execution until a domain has been split on.

        If the user chooses to not select a domain (clicks 'Cancel'), splits the domain in half.
        Otherwise, the subset of the domain chosen by the user is used as the initial split.

        Args:
            domain (set): Domain of the variable being split on.
            var (string): The variable being split on.

        Returns:
            (set): A subset of the domain to be split on first.
        """
        # Running in Auto mode. Split in half!
        if self.max_display_level == 1:
            split = len(domain) // 2
            dom1 = set(list(domain)[:split])
            dom2 = domain - dom1
            return dom1, dom2

        if self._domain_split is None:
            # Split in half
            split = len(domain) // 2
            dom1 = set(list(domain)[:split])
            dom2 = domain - dom1
            return dom1, dom2

        # make sure type of chosen domain matches original domain
        if all(isinstance(n, int) for n in domain):
            number_domain = set()
            for n in self._domain_split:
                number_domain.add(int(n))
            self._domain_split = number_domain

        split1 = set(self._domain_split)
        split2 = set(domain) - split1
        return split1, split2
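
    # A hypothetical walk-through of the default half-split above:
    #
    #     domain = {1, 2, 3, 4}
    #     split = len(domain) // 2           # 2
    #     dom1 = set(list(domain)[:split])   # e.g. {1, 2} (set order is arbitrary)
    #     dom2 = domain - dom1               # the rest, e.g. {3, 4}
    #
    # A user-supplied split arrives from the frontend as strings, e.g.
    # ['1', '3'], and is coerced back to ints before being used, so the
    # returned pair for that choice would be ({1, 3}, {2, 4}).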

    def handle_custom_msgs(self, _, content, buffers=None):
        super().handle_custom_msgs(None, content, buffers)
        event = content.get('event', '')

        if event == 'arc:click':
            """
            Expects a dictionary containing:
                varName (string): The name of the variable connected to this arc.
                constId (string): The id of the constraint connected to this arc.
            """
            if self._is_waiting_for_arc_selection:
                var_name = content.get('varName')
                const = self.csp.constraints[content.get('constId')]
                self.max_display_level = 2

                self._selected_arc = (var_name, const)
                self._has_user_selected_arc = True
                self._block_for_user_input.set()
                self._block_for_user_input.clear()
                self._is_waiting_for_arc_selection = False

        elif event == 'var:click':
            """
            Expects a dictionary containing:
                varName (string): The name of the variable to split on.
            """
            if not self._is_waiting_for_var_selection and content.get(
                    'varType') == 'csp:variable':
                self.send({'action': 'chooseDomainSplitBeforeAC'})

        elif event == 'domain_split':
            """
            Expects a dictionary containing:
                domain (string[]|None):
                    An array of the elements in the domain to first split on, or None if no choice is made.
                    In this case, splits the domain in half as a default.
            """
            domain = content.get('domain')
            var_name = content.get('var')
            self._selected_var = var_name
            self._domain_split = domain
            self._has_user_selected_var = True
            self._block_for_user_input.set()
            self._block_for_user_input.clear()
            self._is_waiting_for_var_selection = False

        elif event == 'reset':
            """
            Reset the algorithm and graph
            """
            # Before resetting the backend, freeze execution of the queued function to avoid an undetermined state
            self._pause()
            # Wait until freezing has completed
            sleep(0.2)

            # Reset algorithm related variables

            user_sleep_time = getattr(self, 'sleep_time', None)
            super().__init__()
            self.sleep_time = user_sleep_time
            self.visualizer = self
            self._sls_first_conflict = True
            self._selected_arc = None
            self._has_user_selected_arc = False
            self._is_waiting_for_arc_selection = False
            self._selected_var = None
            self._has_user_selected_var = False
            self._is_waiting_for_var_selection = False
            self._domain_split = None
            self.graph = CSP(self.csp.domains, self.csp.constraints,
                             self.csp.positions)
            (self._domain_map,
             self._edge_map) = generate_csp_graph_mappings(self.csp)

            # Tell frontend that it is ready to reset frontend graph and able to restart algorithm
            self.send({'action': 'frontReset'})

            # Terminate current running thread
            if self._thread:
                self.stop_thread(self._thread)

        elif event == 'initial_render':
            queued_func = getattr(self, '_queued_func', None)

            # Run queued function after we know the frontend view exists
            if queued_func:
                func = queued_func['func']
                args = queued_func['args']
                kwargs = queued_func['kwargs']
                self._previously_rendered = True
                self._thread = ReturnableThread(target=func,
                                                args=args,
                                                kwargs=kwargs)
                self._thread.start()
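
    # Hypothetical message payloads matching the shapes documented above
    # (variable and constraint ids are illustrative):
    #
    #     {'event': 'arc:click', 'varName': 'A', 'constId': '0'}
    #     {'event': 'domain_split', 'var': 'A', 'domain': ['1', '2']}
    #     {'event': 'reset'}
    #     {'event': 'initial_render'}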

    def display(self, level, *args, **kwargs):
        if self.wait_for_render is False:
            return

        should_wait = True

        if args[0] == 'Performing AC with domains':
            should_wait = False
            domains = args[1]
            vars_to_change = []
            domains_to_change = []

            for var, domain in domains.items():
                vars_to_change.append(var)
                domains_to_change.append(domain)

            self._send_set_domains_action(vars_to_change, domains_to_change)

        elif args[0] == 'Domain pruned':
            variable = args[2]
            domain = args[4]
            constraint = args[6]
            self._send_set_domains_action(variable, domain)
            self._send_highlight_arcs_action((variable, constraint),
                                             style='bold',
                                             colour='green')

        elif args[0] == "Processing arc (":
            variable = args[1]
            constraint = args[3]
            self._send_highlight_arcs_action((variable, constraint),
                                             style='bold',
                                             colour=None)

        elif args[0] == "Arc: (" and args[4] == ") is inconsistent":
            variable = args[1]
            constraint = args[3]
            self._send_highlight_arcs_action((variable, constraint),
                                             style='bold',
                                             colour='red')

        elif args[0] == "Arc: (" and args[4] == ") now consistent":
            variable = args[1]
            constraint = args[3]
            self._send_highlight_arcs_action((variable, constraint),
                                             style='normal',
                                             colour='green')
            should_wait = False

        elif (args[0] == "Adding"
              or args[0] == "New domain. Adding") and args[2] == "to to_do.":
            if args[1] != "nothing":
                arcs = list(args[1])
                arcs_to_highlight = []

                for arc in arcs:
                    arcs_to_highlight.append((arc[0], arc[1]))

                self._send_highlight_arcs_action(arcs_to_highlight,
                                                 style='normal',
                                                 colour='blue')

        elif args[0] == ("You can now split domain. Click on a variable "
                         "whose domain has more than 1 value."):
            self.send({'action': 'chooseDomainSplit'})

        elif args[0] == "... splitting":
            self.send({
                'action': 'setOrder',
                'var': args[1],
                'domain': args[3],
                'other': args[5]
            })

        elif args[0] == "Solution found:":
            if self.max_display_level == 0:
                self.max_display_level = 2
            solString = ""
            for var in args[1]:
                solString += var + "=" + str(args[1][var]) + ", "
            solString = solString[:-2]
            self.send({'action': 'setPreSolution', 'solution': solString})
            args += (
                "\nClick Step, Auto Arc Consistency or Auto Solve to find solutions in other domains.",
            )

        elif args[0] == "Solving new domain with":
            self.send({
                'action': 'setSplit',
                'domain': args[2],
                'var': args[1]
            })

        elif args[0] == ("Click Step, Auto Arc Consistency or Auto Solve "
                         "to find solutions in other domains."):
            if self.max_display_level == 0:
                self.max_display_level = 2
            self.send({'action': 'noSolution'})

        #############################
        ### SLS-specific displays ###
        #############################

        elif args[0] == "Initial assignment":
            assignment = args[1]
            for (key, val) in assignment.items():
                self._send_set_domains_action(key, [val])

        elif args[0] == "Assigning" and args[2] == "=":
            var = args[1]
            domain = args[3]
            self._send_set_domains_action(var, [domain])
            self._send_highlight_nodes_action(var, "blue")

        elif args[0] == "Checking":
            node = args[1]
            self._send_highlight_nodes_action(node, "blue")

        elif args[0] == "Still inconsistent":
            const = args[1]
            nodes_to_highlight = {const}
            arcs_to_highlight = []

            for var in const.scope:
                nodes_to_highlight.add(var)
                arcs_to_highlight.append((var, const))

            self._send_highlight_nodes_action(nodes_to_highlight, "red")
            self._send_highlight_arcs_action(arcs_to_highlight, "bold", "red")

        elif args[0] == "Still consistent":
            const = args[1]
            nodes_to_highlight = {const}
            arcs_to_highlight = []

            for var in const.scope:
                nodes_to_highlight.add(var)
                arcs_to_highlight.append((var, const))

            self._send_highlight_nodes_action(nodes_to_highlight, "green")
            self._send_highlight_arcs_action(arcs_to_highlight, "bold",
                                             "green")

        elif args[0] == "Became consistent":
            const = args[1]
            nodes_to_highlight = {const}
            arcs_to_highlight = []

            for var in const.scope:
                nodes_to_highlight.add(var)
                arcs_to_highlight.append((var, const))

            self._send_highlight_nodes_action(nodes_to_highlight, "green")
            self._send_highlight_arcs_action(arcs_to_highlight, "bold",
                                             "green")

        elif args[0] == "Became inconsistent":
            const = args[1]
            nodes_to_highlight = {const}
            arcs_to_highlight = []

            for var in const.scope:
                nodes_to_highlight.add(var)
                arcs_to_highlight.append((var, const))

            self._send_highlight_nodes_action(nodes_to_highlight, "red")
            self._send_highlight_arcs_action(arcs_to_highlight, "bold", "red")

        elif args[0] == "AC done. Reduced domains":
            should_wait = False

        elif args[0] == "Conflicts:":
            conflicts = args[1]
            conflict_nodes_to_highlight = set()
            conflict_arcs_to_highlight = []
            non_conflict_nodes_to_highlight = set()
            non_conflict_arcs_to_highlight = []

            if self._sls_first_conflict:
                # Highlight all non-conflicts green
                self._sls_first_conflict = False
                not_conflicts = set(self.csp.constraints) - conflicts

                for not_conflict in not_conflicts:
                    non_conflict_nodes_to_highlight.add(not_conflict)

                    for node in not_conflict.scope:
                        non_conflict_nodes_to_highlight.add(node)
                        non_conflict_arcs_to_highlight.append(
                            (node, not_conflict))

                self._send_highlight_nodes_action(
                    non_conflict_nodes_to_highlight, "green")
                self._send_highlight_arcs_action(
                    non_conflict_arcs_to_highlight, "bold", "green")

            # Highlight all conflicts red
            for conflict in conflicts:
                conflict_nodes_to_highlight.add(conflict)

                for node in conflict.scope:
                    conflict_nodes_to_highlight.add(node)
                    conflict_arcs_to_highlight.append((node, conflict))

            self._send_highlight_nodes_action(conflict_nodes_to_highlight,
                                              "red")
            self._send_highlight_arcs_action(conflict_arcs_to_highlight,
                                             "bold", "red")

        super().display(level, *args, **dict(kwargs, should_wait=should_wait))
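
    # display() pattern-matches the positional arguments the solver prints.
    # Two illustrative calls and the frontend action each one triggers (a
    # sketch of the convention above, not a documented API):
    #
    #     display(2, 'Processing arc (', 'A', ',', const, ')')
    #         -> highlight the arc ('A', const) in bold
    #     display(2, 'Arc: (', 'A', ',', const, ') is inconsistent')
    #         -> highlight the arc ('A', const) in bold red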

    def _send_highlight_nodes_action(self, vars, colour):
        """Sends a message to the front-end visualization to highlight nodes.

        Args:
            vars (string|string[]): The name(s) of the variables to highlight.
            colour (string|None): An HTML colour string for the stroke of the node.
                Passing in None will keep the existing stroke of the node.
        """

        # We don't want to check if it is iterable because a string is iterable
        if not isinstance(vars, list) and not isinstance(vars, set):
            vars = [vars]

        nodeIds = []
        for var in vars:
            nodeIds.append(self._domain_map[var])

        self.send({
            'action': 'highlightNodes',
            'nodeIds': nodeIds,
            'colour': colour
        })

    def _send_highlight_splittable_nodes_action(self, vars):
        """Sends a message to the front-end visualization to highlight Splittable nodes when users can split domain.

        Args:
            vars (string|string[]): The name(s) of the splittable variables to highlight.
        """

        # We don't want to check if it is iterable because a string is iterable
        if not isinstance(vars, list) and not isinstance(vars, set):
            vars = [vars]

        nodeIds = []
        for var in vars:
            nodeIds.append(self._domain_map[var])

        self.send({
            'action': 'highlightSplittableNodes',
            'nodeIds': nodeIds,
        })

    def _send_highlight_arcs_action(self, arcs, style='normal', colour=None):
        """Sends a message to the front-end visualization to highlight arcs.

        Args:
            arcs ((string, Constraint)|(string, Constraint)[]):
                Tuples of (variable name, Constraint instance) that form an arc.
                For convenience, you do not need to pass a list of tuples if you only have one to highlight.
            style ('normal'|'bold'): Style of the highlight. Applied to every arc passed in.
            colour (string|None): An HTML colour string for the colour of the line.
                Passing in None will keep the existing colour of the arcs.
        """

        if not isinstance(arcs, list):
            arcs = [arcs]

        arc_ids = []
        for arc in arcs:
            arc_ids.append(self._edge_map[arc])

        self.send({
            'action': 'highlightArcs',
            'arcIds': arc_ids,
            'style': style,
            'colour': colour
        })

    def _send_set_domains_action(self, vars, domains):
        """Sends a message to the front-end visualization to set the domains of variables.

        Args:
            vars (string|string[]): The name of the variable(s) whose domain should be changed.
            domains (List[int|string]|List[List[int|string]]): The updated domain of the variable(s).
              If vars is an array, then domain is an array of domains, in the same order.
        """

        is_single_var = False
        if not isinstance(vars, list):
            vars = [vars]
            is_single_var = True

        self.send({
            'action': 'setDomains',
            'nodeIds': [self._domain_map[var] for var in vars],
            'domains': ([list(domain) for domain in domains]
                        if not is_single_var else [domains])
        })
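
# Hypothetical calls showing the accepted shapes of the helpers above,
# given a widget instance `viz` (the instance name is illustrative):
#
#     viz._send_set_domains_action('A', [1, 2])             # one variable
#     viz._send_set_domains_action(['A', 'B'], [[1], [2]])  # several, with parallel domains
#     viz._send_highlight_arcs_action(('A', const), style='bold', colour='red')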
Beispiel #12
0
class KeystoneAuthenticator(Authenticator):
    auth_url = Unicode(config=True,
                       help="""
        Keystone server auth url
        """)

    api_version = Unicode('3',
                          config=True,
                          help="""
        Keystone authentication version
        """)

    region_name = Unicode(config=True,
                          help="""
        Keystone authentication region name
        """)

    @gen.coroutine
    def authenticate(self, handler, data):
        username = data['username']
        password = data['password']

        client = self._create_client(username=username, password=password)
        token = client.get_token()

        if token is None:
            return None

        auth_state = {}
        openstack_rc = {
            'OS_AUTH_URL': self.auth_url,
            'OS_INTERFACE': 'public',
            'OS_IDENTITY_API_VERSION': self.api_version,
            'OS_AUTH_TYPE': 'token',
            'OS_TOKEN': token,
        }

        if self.region_name:
            openstack_rc['OS_REGION_NAME'] = self.region_name

        projects = client.get_projects()

        if projects:
            default_project = projects[0]
            openstack_rc['OS_PROJECT_NAME'] = default_project['name']
            openstack_rc['OS_PROJECT_DOMAIN_ID'] = default_project['domain_id']
            domain = client.get_project_domain(default_project)
            if domain:
                openstack_rc['OS_PROJECT_DOMAIN_NAME'] = domain['name']
        else:
            self.log.warning(('Could not select default project for user %r, '
                              'no projects found'), username)

        auth_state['openstack_rc'] = openstack_rc

        return dict(
            name=username,
            auth_state=auth_state,
        )

    @gen.coroutine
    def refresh_user(self, user, handler=None):
        auth_state = yield user.get_auth_state()
        if not auth_state:
            # auth_state not enabled
            return True

        try:
            openstack_rc = auth_state.get('openstack_rc', {})
            token = openstack_rc.get('OS_TOKEN')

            if not token:
                self.log.warning(
                    ('Could not get OpenStack token from auth_state'))
                return True

            client = self._create_client(token=token)

            # If we can generate a new token, it means ours is still valid.
            # There is no value in storing the new token, as its expiration will
            # be tied to the requesting token's expiration.
            return client.get_token() is not None
        except Exception as err:
            self.log.warning(
                (f'Failed to refresh OpenStack token in pre_spawn: {err}'))
            return True

    @gen.coroutine
    def pre_spawn_start(self, user, spawner):
        """Fill in OpenRC environment variables from user auth state.
        """
        auth_state = yield user.get_auth_state()
        if not auth_state:
            # auth_state not enabled
            self.log.error(
                ('auth_state is not enabled! Cannot set OpenStack RC '
                 'parameters'))
            return
        for rc_key, rc_value in auth_state.get('openstack_rc', {}).items():
            spawner.environment[rc_key] = rc_value

    def _create_client(self, **kwargs):
        return Client(self.auth_url, log=self.log, **kwargs)
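
# A minimal usage sketch for jupyterhub_config.py, assuming this class is
# importable (the URL and region below are placeholders, not documented
# values):
#
#     c.JupyterHub.authenticator_class = KeystoneAuthenticator
#     c.KeystoneAuthenticator.auth_url = 'https://keystone.example.org:5000/v3'
#     c.KeystoneAuthenticator.region_name = 'RegionOne'
#     # auth_state must be enabled for pre_spawn_start to export the RC vars:
#     c.Authenticator.enable_auth_state = True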
Beispiel #13
0
class TextFileContentsManager(FileContentsManager, Configurable):
    """
    A FileContentsManager Class that reads and stores notebooks to classical
    Jupyter notebooks (.ipynb), R Markdown notebooks (.Rmd), Julia (.jl),
    Python (.py) or R scripts (.R)
    """

    nb_extensions = [ext for ext in NOTEBOOK_EXTENSIONS if ext != '.ipynb']

    # Dictionary: notebook path => (fmt, formats) where fmt is the current format, and formats the paired formats.
    paired_notebooks = dict()

    def all_nb_extensions(self):
        """
        Notebook extensions, including ipynb
        :return:
        """
        return ['.ipynb'] + self.nb_extensions

    default_jupytext_formats = Unicode(
        u'',
        help='Save notebooks to these file extensions. '
             'Can be any of ipynb,Rmd,md,jl,py,R,nb.jl,nb.py,nb.R, '
             'comma separated. If you want a format other than the '
             'default one, append the format name to the extension, '
             'e.g. ipynb,py:percent to save the notebook as '
             'hydrogen/spyder/vscode compatible scripts',
        config=True)

    preferred_jupytext_formats_save = Unicode(
        u'',
        help='Preferred format when saving notebooks as text, per extension. '
             'Use "jl:percent,py:percent,R:percent" if you want to save '
             'Julia, Python and R scripts in the double percent format and '
             'only write "jupytext_formats": "py" in the notebook metadata.',
        config=True)

    preferred_jupytext_formats_read = Unicode(
        u'',
        help='Preferred format when reading notebooks from text, per '
             'extension. Use "py:sphinx" if you want to read all python '
             'scripts as Sphinx gallery scripts.',
        config=True)

    default_notebook_metadata_filter = Unicode(
        u'',
        help="Notebook metadata that should be saved in the text representations. "
             "Examples: 'all', '-all', 'widgets,nteract', 'kernelspec,jupytext-all'",
        config=True)

    default_cell_metadata_filter = Unicode(
        u'',
        help="Cell metadata that should be saved in the text representations. "
             "Examples: 'all', 'hide_input,hide_output'",
        config=True)

    comment_magics = Enum(
        values=[True, False],
        allow_none=True,
        help='Should Jupyter magic commands be commented out in the text representation?',
        config=True)

    split_at_heading = Bool(
        False,
        help='Split markdown cells on headings (Markdown and R Markdown formats only)',
        config=True)

    sphinx_convert_rst2md = Bool(
        False,
        help='When opening a Sphinx Gallery script, convert the reStructuredText to markdown',
        config=True)

    outdated_text_notebook_margin = Float(
        1.0,
        help='Refuse to overwrite the inputs of an ipynb notebook with those '
             'of a text notebook when the text notebook plus margin is older '
             'than the ipynb notebook',
        config=True)

    def drop_paired_notebook(self, path):
        """Remove the current notebook from the list of paired notebooks"""
        if path not in self.paired_notebooks:
            return

        fmt, formats = self.paired_notebooks.pop(path)
        prev_paired_paths = paired_paths(path, fmt, formats)
        for alt_path, _ in prev_paired_paths:
            if alt_path in self.paired_notebooks:
                self.drop_paired_notebook(alt_path)

    def update_paired_notebooks(self, path, fmt, formats):
        """Update the list of paired notebooks to include/update the current pair"""
        if not formats:
            self.drop_paired_notebook(path)
            return

        new_paired_paths = paired_paths(path, fmt, formats)
        for alt_path, _ in new_paired_paths:
            self.drop_paired_notebook(alt_path)

        long_formats = long_form_multiple_formats(formats)
        if len(long_formats) == 1 and set(long_formats[0]) <= {'extension'}:
            return

        short_formats = short_form_multiple_formats(formats)
        for alt_path, alt_fmt in new_paired_paths:
            self.paired_notebooks[alt_path] = short_form_one_format(alt_fmt), short_formats

    def set_default_format_options(self, format_options, read=False):
        """Set default format option"""
        if self.default_notebook_metadata_filter:
            format_options.setdefault('notebook_metadata_filter', self.default_notebook_metadata_filter)
        if self.default_cell_metadata_filter:
            format_options.setdefault('cell_metadata_filter', self.default_cell_metadata_filter)
        if self.comment_magics is not None:
            format_options.setdefault('comment_magics', self.comment_magics)
        if self.split_at_heading:
            format_options.setdefault('split_at_heading', self.split_at_heading)
        if read and self.sphinx_convert_rst2md:
            format_options.setdefault('rst2md', self.sphinx_convert_rst2md)

    def default_formats(self, path):
        """Return the default formats, if they apply to the current path #157"""
        formats = long_form_multiple_formats(self.default_jupytext_formats)
        for fmt in formats:
            try:
                base_path(path, fmt)
                return self.default_jupytext_formats
            except InconsistentPath:
                continue

        return None

    def create_prefix_dir(self, path, fmt):
        """Create the prefix dir, if missing"""
        create_prefix_dir(self._get_os_path(path.strip('/')), fmt)

    def save(self, model, path=''):
        """Save the file model and return the model with no content."""
        if model['type'] != 'notebook':
            return super(TextFileContentsManager, self).save(model, path)

        nbk = model['content']
        try:
            metadata = nbk.get('metadata')
            rearrange_jupytext_metadata(metadata)
            jupytext_formats = metadata.get('jupytext', {}).get('formats') or self.default_formats(path)

            if not jupytext_formats:
                text_representation = metadata.get('jupytext', {}).get('text_representation', {})
                ext = os.path.splitext(path)[1]
                fmt = {'extension': ext}

                if ext == text_representation.get('extension') and text_representation.get('format_name'):
                    fmt['format_name'] = text_representation.get('format_name')

                jupytext_formats = [fmt]

            jupytext_formats = long_form_multiple_formats(jupytext_formats, metadata)

            # Set preferred formats if no format name is given yet
            jupytext_formats = [preferred_format(fmt, self.preferred_jupytext_formats_save) for fmt in jupytext_formats]

            base, fmt = find_base_path_and_format(path, jupytext_formats)
            self.update_paired_notebooks(path, fmt, jupytext_formats)

            # Save as ipynb first
            latest_result = None
            for fmt in jupytext_formats[::-1]:
                if fmt['extension'] != '.ipynb':
                    continue

                alt_path = full_path(base, fmt)
                self.create_prefix_dir(alt_path, fmt)
                self.log.info("Saving %s", os.path.basename(alt_path))
                latest_result = super(TextFileContentsManager, self).save(model, alt_path)

            # And then to the other formats, in reverse order so that
            # the first format is the most recent
            for fmt in jupytext_formats[::-1]:
                if fmt['extension'] == '.ipynb':
                    continue

                alt_path = full_path(base, fmt)
                self.create_prefix_dir(alt_path, fmt)
                self.set_default_format_options(fmt)
                if 'format_name' in fmt and fmt['extension'] not in ['.Rmd', '.md']:
                    self.log.info("Saving %s in format %s:%s",
                                  os.path.basename(alt_path), fmt['extension'][1:], fmt['format_name'])
                else:
                    self.log.info("Saving %s", os.path.basename(alt_path))
                with mock.patch('nbformat.writes', _jupytext_writes(fmt)):
                    latest_result = super(TextFileContentsManager, self).save(model, alt_path)

            return latest_result

        except Exception as err:
            raise HTTPError(400, str(err))

    def get(self, path, content=True, type=None, format=None, load_alternative_format=True):
        """ Takes a path for an entity and returns its model"""
        path = path.strip('/')
        ext = os.path.splitext(path)[1]

        # Not a notebook?
        if not self.exists(path) or (type != 'notebook' if type else ext not in self.all_nb_extensions()):
            return super(TextFileContentsManager, self).get(path, content, type, format)

        fmt = preferred_format(ext, self.preferred_jupytext_formats_read)
        if ext == '.ipynb':
            model = self._notebook_model(path, content=content)
        else:
            self.set_default_format_options(fmt, read=True)
            with mock.patch('nbformat.reads', _jupytext_reads(fmt)):
                model = self._notebook_model(path, content=content)

        if not load_alternative_format:
            return model

        if not content:
            # The modification time of a paired notebook - in this context, Jupyter is checking
            # the timestamp before saving - is the most recent among all representations (#118)
            if path not in self.paired_notebooks:
                return model

            fmt, formats = self.paired_notebooks.get(path)
            for alt_path, _ in paired_paths(path, fmt, formats):
                if alt_path != path and self.exists(alt_path):
                    alt_model = self._notebook_model(alt_path, content=False)
                    if alt_model['last_modified'] > model['last_modified']:
                        model['last_modified'] = alt_model['last_modified']

            return model

        # We will now read a second file if this is a paired notebook.
        nbk = model['content']
        jupytext_formats = nbk.metadata.get('jupytext', {}).get('formats') or self.default_formats(path)
        jupytext_formats = long_form_multiple_formats(jupytext_formats)

        # Compute paired notebooks from formats
        alt_paths = [(path, fmt)]
        if jupytext_formats:
            try:
                _, fmt = find_base_path_and_format(path, jupytext_formats)
                alt_paths = paired_paths(path, fmt, jupytext_formats)
                self.update_paired_notebooks(path, fmt, jupytext_formats)
            except InconsistentPath as err:
                self.log.info("Unable to read paired notebook: %s", str(err))
        else:
            if path in self.paired_notebooks:
                fmt, formats = self.paired_notebooks.get(path)
                alt_paths = paired_paths(path, fmt, formats)

        if len(alt_paths) > 1 and ext == '.ipynb':
            # Apply default options (like saving and reloading would do)
            jupytext_metadata = model['content']['metadata'].get('jupytext', {})
            self.set_default_format_options(jupytext_metadata, read=True)
            if jupytext_metadata:
                model['content']['metadata']['jupytext'] = jupytext_metadata

        org_model = model
        fmt_inputs = fmt
        path_inputs = path_outputs = path
        model_outputs = None

        # The source format is the first non-ipynb format found on disk
        if path.endswith('.ipynb'):
            for alt_path, alt_fmt in alt_paths:
                if not alt_path.endswith('.ipynb') and self.exists(alt_path):
                    self.log.info(u'Reading SOURCE from {}'.format(alt_path))
                    path_inputs = alt_path
                    fmt_inputs = alt_fmt
                    model_outputs = model
                    model = self.get(alt_path, content=content, type=type, format=format,
                                     load_alternative_format=False)
                    break
        # Outputs are taken from the ipynb file in the group, if it exists
        else:
            for alt_path, _ in alt_paths:
                if alt_path.endswith('.ipynb') and self.exists(alt_path):
                    self.log.info(u'Reading OUTPUTS from {}'.format(alt_path))
                    path_outputs = alt_path
                    model_outputs = self.get(alt_path, content=content, type=type, format=format,
                                             load_alternative_format=False)
                    break

        try:
            check_file_version(model['content'], path_inputs, path_outputs)
        except Exception as err:
            raise HTTPError(400, str(err))

        # Before we combine the two files, we make sure we're not overwriting ipynb cells
        # with an outdated text file
        try:
            if model_outputs and model_outputs['last_modified'] > model['last_modified'] + \
                    timedelta(seconds=self.outdated_text_notebook_margin):
                raise HTTPError(
                    400,
                    '''{out} (last modified {out_last})
                    seems more recent than {src} (last modified {src_last})
                    Please either:
                    - open {src} in a text editor, make sure it is up to date, and save it,
                    - or delete {src} if not up to date,
                    - or increase check margin by adding, say,
                        c.ContentsManager.outdated_text_notebook_margin = 5 # in seconds # or float("inf")
                    to your .jupyter/jupyter_notebook_config.py file
                    '''.format(src=path_inputs, src_last=model['last_modified'],
                               out=path_outputs, out_last=model_outputs['last_modified']))
        except OverflowError:
            pass

        if model_outputs:
            combine_inputs_with_outputs(model['content'], model_outputs['content'], fmt_inputs)
        elif not path.endswith('.ipynb'):
            nbk = model['content']
            language = nbk.metadata.get('jupytext', {}).get('main_language', 'python')
            if 'kernelspec' not in nbk.metadata and language != 'python':
                kernelspec = kernelspec_from_language(language)
                if kernelspec:
                    nbk.metadata['kernelspec'] = kernelspec

            self.notary.sign(nbk)
            self.mark_trusted_cells(nbk, path)

        # The path and name of the notebook are those of the original path
        model['path'] = org_model['path']
        model['name'] = org_model['name']

        return model

    def trust_notebook(self, path):
        """Trust the current notebook"""
        if path.endswith('.ipynb') or path not in self.paired_notebooks:
            super(TextFileContentsManager, self).trust_notebook(path)
            return

        fmt, formats = self.paired_notebooks[path]
        for alt_path, alt_fmt in paired_paths(path, fmt, formats):
            if alt_fmt['extension'] == '.ipynb':
                super(TextFileContentsManager, self).trust_notebook(alt_path)

    def rename_file(self, old_path, new_path):
        """Rename the current notebook, as well as its alternative representations"""
        if old_path not in self.paired_notebooks:
            super(TextFileContentsManager, self).rename_file(old_path, new_path)
            return

        fmt, formats = self.paired_notebooks.get(old_path)
        old_alt_paths = paired_paths(old_path, fmt, formats)

        # Is the new file name consistent with suffix?
        try:
            new_base = base_path(new_path, fmt)
        except Exception as err:
            raise HTTPError(400, str(err))

        for old_alt_path, alt_fmt in old_alt_paths:
            new_alt_path = full_path(new_base, alt_fmt)
            if self.exists(old_alt_path):
                super(TextFileContentsManager, self).rename_file(old_alt_path, new_alt_path)

        self.drop_paired_notebook(old_path)
        self.update_paired_notebooks(new_path, fmt, formats)
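
# A minimal pairing setup for jupyter_notebook_config.py, assuming this
# manager is importable (the option names are the traits defined above):
#
#     c.NotebookApp.contents_manager_class = TextFileContentsManager
#     # Pair every notebook with a percent-format Python script:
#     c.ContentsManager.default_jupytext_formats = 'ipynb,py:percent'
#     # Allow some slack before refusing an outdated text representation:
#     c.ContentsManager.outdated_text_notebook_margin = 5.0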
Beispiel #14
0
class ToggleNBExtensionApp(BaseExtensionApp):
    """A base class for apps that enable/disable extensions"""
    name = "jupyter nbextension enable/disable"
    version = __version__
    description = "Enable/disable an nbextension in configuration."

    section = Unicode(
        'notebook',
        config=True,
        help=
        """Which config section to add the extension to, 'common' will affect all pages."""
    )
    user = Bool(
        True,
        config=True,
        help="Apply the configuration only for the current user (default)")

    aliases = {'section': 'ToggleNBExtensionApp.section'}

    _toggle_value = None

    def _config_file_name_default(self):
        """The default config file name."""
        return 'jupyter_notebook_config'

    def toggle_nbextension_python(self, module):
        """Toggle some extensions in an importable Python module.

        Returns a list of booleans indicating whether the state was changed as
        requested.

        Parameters
        ----------
        module : str
            Importable Python module exposing the
            magic-named `_jupyter_nbextension_paths` function
        """
        toggle = (enable_nbextension_python
                  if self._toggle_value else disable_nbextension_python)
        return toggle(module,
                      user=self.user,
                      sys_prefix=self.sys_prefix,
                      logger=self.log)

    def toggle_nbextension(self, require):
        """Toggle some a named nbextension by require-able AMD module.

        Returns whether the state was changed as requested.

        Parameters
        ----------
        require : str
            require.js path used to load the nbextension
        """
        toggle = (enable_nbextension
                  if self._toggle_value else disable_nbextension)
        return toggle(self.section,
                      require,
                      user=self.user,
                      sys_prefix=self.sys_prefix,
                      logger=self.log)

    def start(self):
        if not self.extra_args:
            sys.exit(
                'Please specify an nbextension/package to enable or disable')
        elif len(self.extra_args) > 1:
            sys.exit('Please specify one nbextension/package at a time')
        if self.python:
            self.toggle_nbextension_python(self.extra_args[0])
        else:
            self.toggle_nbextension(self.extra_args[0])
Beispiel #15
0
    def _update_custom_traits(self):
        coefs = self.groupby('frame').apply(
            lambda x: x.pivot('basis_function', 'orbital', 'coefficient').fillna(value=0).values)
        coefs = Unicode(coefs.to_json(orient='values')).tag(sync=True)
        #coefs = Unicode('[' + sq.groupby(by=sq.columns, axis=1).apply(
        #            lambda x: x[x.columns[0]].values).to_json(orient='values') + ']').tag(sync=True)
        return {'momatrix_coefficient': coefs}
Beispiel #16
0
class UninstallNBExtensionApp(BaseExtensionApp):
    """Entry point for uninstalling notebook extensions"""
    version = __version__
    description = """Uninstall Jupyter notebook extensions
    
    Usage
    
        jupyter nbextension uninstall path/url path/url/entrypoint
        jupyter nbextension uninstall --py pythonPackageName

    This uninstalls an nbextension. By default, it uninstalls from the
    first directory on the search path where it finds the extension, but you can
    uninstall from a specific location using the --user, --sys-prefix or
    --system flags, or the --prefix option.

    If you specify the --require option, the named extension will be disabled,
    e.g.::

        jupyter nbextension uninstall myext --require myext/main

    If you use the --py or --python flag, the name should be a Python module.
    It will uninstall nbextensions listed in that module, but not the module
    itself (which you should uninstall using a package manager such as pip).
    """

    examples = """
    jupyter nbextension uninstall dest/dir dest/dir/extensionjs
    jupyter nbextension uninstall --py extensionPyPackage
    """

    aliases = {
        "prefix": "UninstallNBExtensionApp.prefix",
        "nbextensions": "UninstallNBExtensionApp.nbextensions_dir",
        "require": "UninstallNBExtensionApp.require",
    }
    flags = BaseExtensionApp.flags.copy()
    flags['system'] = ({
        'UninstallNBExtensionApp': {
            'system': True
        }
    }, "Uninstall specifically from systemwide installation directory")

    prefix = Unicode(
        '',
        config=True,
        help="Installation prefix. Overrides --user, --sys-prefix and --system"
    )
    nbextensions_dir = Unicode(
        '',
        config=True,
        help="Full path to nbextensions dir (probably use prefix or user)")
    require = Unicode('',
                      config=True,
                      help="require.js module to disable loading")
    system = Bool(
        False,
        config=True,
        help="Uninstall specifically from systemwide installation directory")

    def _config_file_name_default(self):
        """The default config file name."""
        return 'jupyter_notebook_config'

    def uninstall_extension(self):
        """Uninstall an nbextension from a specific location"""
        kwargs = {
            'user': self.user,
            'sys_prefix': self.sys_prefix,
            'prefix': self.prefix,
            'nbextensions_dir': self.nbextensions_dir,
            'logger': self.log
        }

        if self.python:
            uninstall_nbextension_python(self.extra_args[0], **kwargs)
        else:
            if self.require:
                kwargs['require'] = self.require
            uninstall_nbextension(self.extra_args[0], **kwargs)

    def find_uninstall_extension(self):
        """Uninstall an nbextension from an unspecified location"""
        name = self.extra_args[0]
        if self.python:
            _, nbexts = _get_nbextension_metadata(name)
            changed = False
            for nbext in nbexts:
                if _find_uninstall_nbextension(nbext['dest'], logger=self.log):
                    changed = True

                # Also disable it in config.
                for section in NBCONFIG_SECTIONS:
                    _find_disable_nbextension(section,
                                              nbext['require'],
                                              logger=self.log)

        else:
            changed = _find_uninstall_nbextension(name, logger=self.log)

        if not changed:
            print("No installed extension %r found." % name)

        if self.require:
            for section in NBCONFIG_SECTIONS:
                _find_disable_nbextension(section,
                                          self.require,
                                          logger=self.log)

    def start(self):
        if not self.extra_args:
            sys.exit('Please specify an nbextension to uninstall')
        elif len(self.extra_args) > 1:
            sys.exit("Only one nbextension allowed at a time. "
                     "Call multiple times to uninstall multiple extensions.")
        elif (self.user or self.sys_prefix or self.system or self.prefix
              or self.nbextensions_dir):
            # The user has specified a location from which to uninstall.
            try:
                self.uninstall_extension()
            except ArgumentConflict as e:
                sys.exit(str(e))
        else:
            # Uninstall wherever it is.
            self.find_uninstall_extension()
Beispiel #17
0
class InteractiveShellApp(Configurable):
    """A Mixin for applications that start InteractiveShell instances.

    Provides configurables for loading extensions and executing files
    as part of configuring a Shell environment.

    The following methods should be called by the :meth:`initialize` method
    of the subclass:

      - :meth:`init_path`
      - :meth:`init_shell` (to be implemented by the subclass)
      - :meth:`init_gui_pylab`
      - :meth:`init_extensions`
      - :meth:`init_code`
    """
    extensions = List(
        Unicode(),
        help="A list of dotted module names of IPython extensions to load."
    ).tag(config=True)

    extra_extensions = List(
        DottedObjectName(),
        help="""
        Dotted module name(s) of one or more IPython extensions to load.

        For specifying extra extensions to load on the command-line.

        .. versionadded:: 7.10
        """,
    ).tag(config=True)

    reraise_ipython_extension_failures = Bool(
        False,
        help="Reraise exceptions encountered loading IPython extensions?",
    ).tag(config=True)

    # Extensions that are always loaded (not configurable)
    default_extensions = List(Unicode(), [u'storemagic']).tag(config=False)

    hide_initial_ns = Bool(
        True,
        help=
        """Should variables loaded at startup (by startup files, exec_lines, etc.)
        be hidden from tools like %who?""").tag(config=True)

    exec_files = List(
        Unicode(),
        help="""List of files to run at IPython startup.""").tag(config=True)
    exec_PYTHONSTARTUP = Bool(
        True,
        help="""Run the file referenced by the PYTHONSTARTUP environment
        variable at IPython startup.""").tag(config=True)
    file_to_run = Unicode('', help="""A file to be run""").tag(config=True)

    exec_lines = List(
        Unicode(),
        help="""lines of code to run at IPython startup.""").tag(config=True)
    code_to_run = Unicode(
        '', help="Execute the given command string.").tag(config=True)
    module_to_run = Unicode(
        '', help="Run the module as a script.").tag(config=True)
    gui = CaselessStrEnum(
        gui_keys,
        allow_none=True,
        help="Enable GUI event loop integration with any of {0}.".format(
            gui_keys)).tag(config=True)
    matplotlib = CaselessStrEnum(
        backend_keys,
        allow_none=True,
        help="""Configure matplotlib for interactive use with
        the default matplotlib backend.""").tag(config=True)
    pylab = CaselessStrEnum(
        backend_keys,
        allow_none=True,
        help="""Pre-load matplotlib and numpy for interactive use,
        selecting a particular matplotlib backend and loop integration.
        """).tag(config=True)
    pylab_import_all = Bool(
        True,
        help=
        """If true, IPython will populate the user namespace with numpy, pylab, etc.
        and an ``import *`` is done from numpy and pylab, when using pylab mode.

        When False, pylab mode should not import any names into the user namespace.
        """).tag(config=True)
    ignore_cwd = Bool(
        False,
        help=
        """If True, IPython will not add the current working directory to sys.path.
        When False, the current working directory is added to sys.path, allowing imports
        of modules defined in the current directory.""").tag(config=True)
    shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
    # whether interact-loop should start
    interact = Bool(True)

    user_ns = Instance(dict, args=None, allow_none=True)

    @observe('user_ns')
    def _user_ns_changed(self, change):
        if self.shell is not None:
            self.shell.user_ns = change['new']
            self.shell.init_user_ns()

    def init_path(self):
        """Add current working directory, '', to sys.path

        Unlike Python's default, we insert before the first `site-packages`
        or `dist-packages` directory,
        so that it is after the standard library.

        .. versionchanged:: 7.2
            Try to insert after the standard library, instead of first.
        .. versionchanged:: 8.0
            Allow optionally not including the current directory in sys.path
        """
        if '' in sys.path or self.ignore_cwd:
            return
        for idx, path in enumerate(sys.path):
            parent, last_part = os.path.split(path)
            if last_part in {'site-packages', 'dist-packages'}:
                break
        else:
            # no site-packages or dist-packages found (?!)
            # back to original behavior of inserting at the front
            idx = 0
        sys.path.insert(idx, '')
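
    # A worked example of the insertion rule above (paths illustrative):
    #
    #     sys.path == ['/usr/lib/python3.11',
    #                  '/usr/lib/python3.11/site-packages']
    #
    # The first site-packages entry sits at index 1, so '' is inserted
    # there, keeping the standard library ahead of the current directory:
    #
    #     sys.path == ['/usr/lib/python3.11', '',
    #                  '/usr/lib/python3.11/site-packages']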

    def init_shell(self):
        raise NotImplementedError("Override in subclasses")

    def init_gui_pylab(self):
        """Enable GUI event loop integration, taking pylab into account."""
        enable = False
        shell = self.shell
        if self.pylab:
            enable = lambda key: shell.enable_pylab(
                key, import_all=self.pylab_import_all)
            key = self.pylab
        elif self.matplotlib:
            enable = shell.enable_matplotlib
            key = self.matplotlib
        elif self.gui:
            enable = shell.enable_gui
            key = self.gui

        if not enable:
            return

        try:
            r = enable(key)
        except ImportError:
            self.log.warning(
                "Eventloop or matplotlib integration failed. Is matplotlib installed?"
            )
            self.shell.showtraceback()
            return
        except Exception:
            self.log.warning("GUI event loop or pylab initialization failed")
            self.shell.showtraceback()
            return

        if isinstance(r, tuple):
            gui, backend = r[:2]
            self.log.info(
                "Enabling GUI event loop integration, "
                "eventloop=%s, matplotlib=%s", gui, backend)
            if key == "auto":
                print("Using matplotlib backend: %s" % backend)
        else:
            gui = r
            self.log.info(
                "Enabling GUI event loop integration, "
                "eventloop=%s", gui)

    def init_extensions(self):
        """Load all IPython extensions in IPythonApp.extensions.

        This uses the :meth:`ExtensionManager.load_extensions` to load all
        the extensions listed in ``self.extensions``.
        """
        try:
            self.log.debug("Loading IPython extensions...")
            extensions = (self.default_extensions + self.extensions +
                          self.extra_extensions)
            for ext in extensions:
                try:
                    self.log.info("Loading IPython extension: %s" % ext)
                    self.shell.extension_manager.load_extension(ext)
                except:
                    if self.reraise_ipython_extension_failures:
                        raise
                    msg = ("Error in loading extension: {ext}\n"
                           "Check your config files in {location}".format(
                               ext=ext, location=self.profile_dir.location))
                    self.log.warning(msg, exc_info=True)
        except:
            if self.reraise_ipython_extension_failures:
                raise
            self.log.warning("Unknown error in loading extensions:",
                             exc_info=True)

    def init_code(self):
        """run the pre-flight code, specified via exec_lines"""
        self._run_startup_files()
        self._run_exec_lines()
        self._run_exec_files()

        # Hide variables defined here from %who etc.
        if self.hide_initial_ns:
            self.shell.user_ns_hidden.update(self.shell.user_ns)

        # command-line execution (ipython -i script.py, ipython -m module)
        # should *not* be excluded from %whos
        self._run_cmd_line_code()
        self._run_module()

        # flush output, so it won't be attached to the first cell
        sys.stdout.flush()
        sys.stderr.flush()
        self.shell._sys_modules_keys = set(sys.modules.keys())

    def _run_exec_lines(self):
        """Run lines of code in IPythonApp.exec_lines in the user's namespace."""
        if not self.exec_lines:
            return
        try:
            self.log.debug("Running code from IPythonApp.exec_lines...")
            for line in self.exec_lines:
                try:
                    self.log.info("Running code in user namespace: %s" % line)
                    self.shell.run_cell(line, store_history=False)
                except:
                    self.log.warning("Error in executing line in user "
                                     "namespace: %s" % line)
                    self.shell.showtraceback()
        except:
            self.log.warning(
                "Unknown error in handling IPythonApp.exec_lines:")
            self.shell.showtraceback()

    def _exec_file(self, fname, shell_futures=False):
        try:
            full_filename = filefind(fname, [u'.', self.ipython_dir])
        except IOError:
            self.log.warning("File not found: %r" % fname)
            return
        # Make sure that the running script gets a proper sys.argv as if it
        # were run from a system shell.
        save_argv = sys.argv
        sys.argv = [full_filename] + self.extra_args[1:]
        try:
            if os.path.isfile(full_filename):
                self.log.info("Running file in user namespace: %s" %
                              full_filename)
                # Ensure that __file__ is always defined to match Python
                # behavior.
                with preserve_keys(self.shell.user_ns, '__file__'):
                    self.shell.user_ns['__file__'] = fname
                    if full_filename.endswith(
                            '.ipy') or full_filename.endswith('.ipynb'):
                        self.shell.safe_execfile_ipy(
                            full_filename, shell_futures=shell_futures)
                    else:
                        # default to python, even without extension
                        self.shell.safe_execfile(full_filename,
                                                 self.shell.user_ns,
                                                 shell_futures=shell_futures,
                                                 raise_exceptions=True)
        finally:
            sys.argv = save_argv

    def _run_startup_files(self):
        """Run files from profile startup directory"""
        startup_dirs = [self.profile_dir.startup_dir] + [
            os.path.join(p, 'startup')
            for p in chain(ENV_CONFIG_DIRS, SYSTEM_CONFIG_DIRS)
        ]
        startup_files = []

        if self.exec_PYTHONSTARTUP and os.environ.get('PYTHONSTARTUP', False) and \
                not (self.file_to_run or self.code_to_run or self.module_to_run):
            python_startup = os.environ['PYTHONSTARTUP']
            self.log.debug("Running PYTHONSTARTUP file %s...", python_startup)
            try:
                self._exec_file(python_startup)
            except:
                self.log.warning(
                    "Unknown error in handling PYTHONSTARTUP file %s:",
                    python_startup)
                self.shell.showtraceback()
        for startup_dir in startup_dirs[::-1]:
            startup_files += glob.glob(os.path.join(startup_dir, '*.py'))
            startup_files += glob.glob(os.path.join(startup_dir, '*.ipy'))
        if not startup_files:
            return

        self.log.debug("Running startup files from %s...", startup_dir)
        try:
            for fname in sorted(startup_files):
                self._exec_file(fname)
        except:
            self.log.warning("Unknown error in handling startup files:")
            self.shell.showtraceback()

    def _run_exec_files(self):
        """Run files from IPythonApp.exec_files"""
        if not self.exec_files:
            return

        self.log.debug("Running files in IPythonApp.exec_files...")
        try:
            for fname in self.exec_files:
                self._exec_file(fname)
        except:
            self.log.warning(
                "Unknown error in handling IPythonApp.exec_files:")
            self.shell.showtraceback()

    def _run_cmd_line_code(self):
        """Run code or file specified at the command-line"""
        if self.code_to_run:
            line = self.code_to_run
            try:
                self.log.info("Running code given at command line (c=): %s" %
                              line)
                self.shell.run_cell(line, store_history=False)
            except:
                self.log.warning(
                    "Error in executing line in user namespace: %s" % line)
                self.shell.showtraceback()
                if not self.interact:
                    self.exit(1)

        # Like Python itself, ignore the second if the first of these is present
        elif self.file_to_run:
            fname = self.file_to_run
            if os.path.isdir(fname):
                fname = os.path.join(fname, "__main__.py")
            if not os.path.exists(fname):
                self.log.warning("File '%s' doesn't exist", fname)
                if not self.interact:
                    self.exit(2)
            try:
                self._exec_file(fname, shell_futures=True)
            except:
                self.shell.showtraceback(tb_offset=4)
                if not self.interact:
                    self.exit(1)

    def _run_module(self):
        """Run module specified at the command-line."""
        if self.module_to_run:
            # Make sure that the module gets a proper sys.argv as if it were
            # run using `python -m`.
            save_argv = sys.argv
            sys.argv = [sys.executable] + self.extra_args
            try:
                self.shell.safe_run_module(self.module_to_run,
                                           self.shell.user_ns)
            finally:
                sys.argv = save_argv
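Both _exec_file and _run_module above use the same save/patch/restore dance so the executed code sees a sys.argv as if it had been launched from a system shell. A minimal sketch of that pattern as a reusable context manager (the name patched_argv is ours, not part of IPython):

import sys
from contextlib import contextmanager

@contextmanager
def patched_argv(argv):
    """Temporarily replace sys.argv, restoring it even if the body raises."""
    saved = sys.argv
    sys.argv = argv
    try:
        yield
    finally:
        sys.argv = saved

# usage, mirroring _run_module:
# with patched_argv([sys.executable, 'extra', 'args']):
#     runpy.run_module('mypackage.mymodule', run_name='__main__')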
Beispiel #18
0
class OAuthenticator(Authenticator):
    """Base class for OAuthenticators

    Subclasses must override:

    login_service (string identifying the service provider)
    login_handler (likely a subclass of OAuthLoginHandler)
    authenticate (method takes one arg - the request handler handling the oauth callback)
    """

    login_service = 'override in subclass'
    oauth_callback_url = Unicode(
        os.getenv('OAUTH_CALLBACK_URL', ''),
        config=True,
        help="""Callback URL to use.
        Typically `https://{host}/hub/oauth_callback`"""
    )

    client_id_env = ''
    client_id = Unicode(config=True)
    def _client_id_default(self):
        if self.client_id_env:
            client_id = os.getenv(self.client_id_env, '')
            if client_id:
                return client_id
        return os.getenv('OAUTH_CLIENT_ID', '')

    client_secret_env = ''
    client_secret = Unicode(config=True)
    def _client_secret_default(self):
        if self.client_secret_env:
            client_secret = os.getenv(self.client_secret_env, '')
            if client_secret:
                return client_secret
        return os.getenv('OAUTH_CLIENT_SECRET', '')

    validate_server_cert_env = 'OAUTH_TLS_VERIFY'
    validate_server_cert = Bool(config=True)
    def _validate_server_cert_default(self):
        env_value = os.getenv(self.validate_server_cert_env, '')
        if env_value == '0':
            return False
        else:
            return True

    def login_url(self, base_url):
        return url_path_join(base_url, 'oauth_login')

    login_handler = "Specify login handler class in subclass"
    callback_handler = OAuthCallbackHandler
    
    def get_callback_url(self, handler=None):
        """Get my OAuth redirect URL
        
        Either from config or guess based on the current request.
        """
        if self.oauth_callback_url:
            return self.oauth_callback_url
        elif handler:
            return guess_callback_uri(
                handler.request.protocol,
                handler.request.host,
                handler.hub.server.base_url
            )
        else:
            raise ValueError("Specify callback oauth_callback_url or give me a handler to guess with")

    def get_handlers(self, app):
        return [
            (r'/oauth_login', self.login_handler),
            (r'/oauth_callback', self.callback_handler),
        ]

    @gen.coroutine
    def authenticate(self, handler, data=None):
        raise NotImplementedError()
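Per the docstring, a concrete authenticator overrides login_service, login_handler, and authenticate. A hypothetical minimal subclass, with the login handler class and the token-exchange helper as placeholders rather than a real provider integration:

class MyProviderAuthenticator(OAuthenticator):

    login_service = 'My Provider'             # label shown on the hub login button
    login_handler = MyProviderLoginHandler    # placeholder OAuthLoginHandler subclass

    @gen.coroutine
    def authenticate(self, handler, data=None):
        code = handler.get_argument('code')
        # Exchange the OAuth `code` for an access token and ask the provider
        # who the user is; `_resolve_username` is a hypothetical helper.
        username = yield self._resolve_username(code)
        return username    # return None to reject the login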
Beispiel #19
0
class Comm(LoggingConfigurable):
    """Class for communicating between a Frontend and a Kernel"""
    kernel = Instance('ipykernel.kernelbase.Kernel', allow_none=True)

    @default('kernel')
    def _default_kernel(self):
        if Kernel.initialized():
            return Kernel.instance()

    comm_id = Unicode()

    @default('comm_id')
    def _default_comm_id(self):
        return uuid.uuid4().hex

    primary = Bool(True, help="Am I the primary or secondary Comm?")

    target_name = Unicode('comm')
    target_module = Unicode(None,
                            allow_none=True,
                            help="""requirejs module from
        which to load comm target.""")

    topic = Bytes()

    @default('topic')
    def _default_topic(self):
        return ('comm-%s' % self.comm_id).encode('ascii')

    _open_data = Dict(help="data dict, if any, to be included in comm_open")
    _close_data = Dict(help="data dict, if any, to be included in comm_close")

    _msg_callback = Any()
    _close_callback = Any()

    _closed = Bool(True)

    def __init__(self,
                 target_name='',
                 data=None,
                 metadata=None,
                 buffers=None,
                 **kwargs):
        if target_name:
            kwargs['target_name'] = target_name
        super(Comm, self).__init__(**kwargs)
        if self.kernel:
            if self.primary:
                # I am primary, open my peer.
                self.open(data=data, metadata=metadata, buffers=buffers)
            else:
                self._closed = False

    def _publish_msg(self,
                     msg_type,
                     data=None,
                     metadata=None,
                     buffers=None,
                     **keys):
        """Helper for sending a comm message on IOPub"""
        data = {} if data is None else data
        metadata = {} if metadata is None else metadata
        content = json_clean(dict(data=data, comm_id=self.comm_id, **keys))
        self.kernel.session.send(
            self.kernel.iopub_socket,
            msg_type,
            content,
            metadata=json_clean(metadata),
            parent=self.kernel._parent_header,
            ident=self.topic,
            buffers=buffers,
        )

    def __del__(self):
        """trigger close on gc"""
        self.close()

    # publishing messages

    def open(self, data=None, metadata=None, buffers=None):
        """Open the frontend-side version of this comm"""
        if data is None:
            data = self._open_data
        comm_manager = getattr(self.kernel, 'comm_manager', None)
        if comm_manager is None:
            raise RuntimeError("Comms cannot be opened without a kernel "
                               "and a comm_manager attached to that kernel.")

        comm_manager.register_comm(self)
        try:
            self._publish_msg(
                'comm_open',
                data=data,
                metadata=metadata,
                buffers=buffers,
                target_name=self.target_name,
                target_module=self.target_module,
            )
            self._closed = False
        except:
            comm_manager.unregister_comm(self)
            raise

    def close(self, data=None, metadata=None, buffers=None):
        """Close the frontend-side version of this comm"""
        if self._closed:
            # only close once
            return
        self._closed = True
        # nothing to send if we have no kernel
        # can be None during interpreter cleanup
        if not self.kernel:
            return
        if data is None:
            data = self._close_data
        self._publish_msg(
            'comm_close',
            data=data,
            metadata=metadata,
            buffers=buffers,
        )
        self.kernel.comm_manager.unregister_comm(self)

    def send(self, data=None, metadata=None, buffers=None):
        """Send a message to the frontend-side version of this comm"""
        self._publish_msg(
            'comm_msg',
            data=data,
            metadata=metadata,
            buffers=buffers,
        )

    # registering callbacks

    def on_close(self, callback):
        """Register a callback for comm_close

        Will be called with the `data` of the close message.

        Call `on_close(None)` to disable an existing callback.
        """
        self._close_callback = callback

    def on_msg(self, callback):
        """Register a callback for comm_msg

        Will be called with the `data` of any comm_msg messages.

        Call `on_msg(None)` to disable an existing callback.
        """
        self._msg_callback = callback

    # handling of incoming messages

    def handle_close(self, msg):
        """Handle a comm_close message"""
        self.log.debug("handle_close[%s](%s)", self.comm_id, msg)
        if self._close_callback:
            self._close_callback(msg)

    def handle_msg(self, msg):
        """Handle a comm_msg message"""
        self.log.debug("handle_msg[%s](%s)", self.comm_id, msg)
        if self._msg_callback:
            shell = self.kernel.shell
            if shell:
                shell.events.trigger('pre_execute')
            self._msg_callback(msg)
            if shell:
                shell.events.trigger('post_execute')
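On the kernel side a Comm is typically opened directly (which sends comm_open), wired to a message callback, and closed when done; the frontend must have registered a handler for the same target_name. A minimal usage sketch with an arbitrary target name and payload:

comm = Comm(target_name='echo_target', data={'msg': 'hello'})  # sends comm_open

def _on_msg(msg):
    # echo whatever the frontend sent back over the same comm
    comm.send(data=msg['content']['data'])

comm.on_msg(_on_msg)
# ... later:
comm.close(data={'reason': 'done'})  # sends comm_close exactly once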
Beispiel #20
0
class Timer(DOMWidget):
    _view_name = Unicode('HelloView').tag(sync=True)
    _view_module = Unicode('hello').tag(sync=True)
    _view_module_version = Unicode('0.1.0').tag(sync=True)
    value = Unicode('00:00:00').tag(sync=True)

    # NOTE: class-level state, shared by every Timer instance; create one
    # Timer per notebook if independent timers are needed
    times_pressed = 0
    go_time = False
    event = threading.Event()

    # this method starts and stops the timer based on how many times the user clicked it
    def threaded_timer(self, b, max_time=180):
        if self.times_pressed == 0:
            self.times_pressed += 1
            b.description = "PAUSE"
            b.button_style = "warning"
            b.disabled = True
            self.go_time = True
            thread = threading.Thread(target=self.timeit, args=(b, max_time))
            thread.start()
            # brief lockout so the user can't spam the start/stop button,
            # which would throw off the time
            time.sleep(1)
            b.disabled = False

        # PAUSE button pushed
        elif (self.times_pressed % 2) != 0:
            self.times_pressed += 1
            b.description = "RESUME"
            b.button_style = "success"
            b.disabled = False
            self.event.clear()
            self.go_time = False

        # RESUME button pushed
        elif (self.times_pressed % 2) == 0:
            self.times_pressed += 1
            self.go_time = True
            b.description = "PAUSE"
            b.button_style = "warning"
            b.disabled = True
            self.event.set()
            # brief lockout so the user can't spam the start/stop button,
            # which would throw off the time
            time.sleep(1)
            b.disabled = False

    # this is what's actually keeping track of the time and updating the custom Timer widget
    def timeit(self, b, max_time=180):
        hours = 0
        mins = 0
        secs = 0
        for i in range(1, (max_time * 60 + 1)):
            if self.go_time:
                if (i % 60) == 0:
                    if (i % 3600) == 0:
                        secs = 0
                        mins = 0
                        hours += 1
                        if hours == 1:
                            # one hour elapsed: with the default 3-hour limit,
                            # two hours remain
                            self.two_hour_warning()
                        elif hours == 2:
                            self.one_hour_warning()
                    else:
                        secs = 0
                        mins += 1
                else:
                    secs += 1
                self.value = '{hour:02}:{minute:02}:{second:02}'.format(
                    hour=hours, minute=mins, second=secs)
                time.sleep(1)

            else:
                self.event.wait()
        else:  # for/else: runs once the countdown loop completes, i.e. time ran out
            b.button_style = "danger"
            b.description = "TIME'S UP!"
            self.timeup()

    # show user a pop-up notification when they run out of time
    def timeup(self):
        display(
            Javascript("""
        require(
        ["base/js/dialog"],
        function(dialog) {
            dialog.modal({
                title: "TIME'S UP!",
                body: "CONGRATS! You just finished the workout.  \
                        How'd you feel about your performance? You can continue this workout \
                        if you're feeling particularly inspired today, \
                        but don't fret about trying to get the best/perfect answers. I recommend you just take a few minutes to \
                        reflect on the challenges you overcame and bring \
                        the new you to another workout when you're ready.",
                buttons: {
                    'OK': {}
                         }
                  });

              }
        );
        """))

    # pop-up notification for two hour remaining
    def two_hour_warning(self):
        display(
            Javascript("""
        require(
        ["base/js/dialog"],
        function(dialog) {
            dialog.modal({
                title: "2 HOURS LEFT,
                body: "Hey there! Just a friendly reminder that you have two more hours.  \
                        Keep up the good work!",
                buttons: {
                    'OK': {}
                         }
                  });

              }
        );
        """))

    # pop-up notification for one hour remaining
    def one_hour_warning(self):
        display(
            Javascript("""
        require(
        ["base/js/dialog"],
        function(dialog) {
            dialog.modal({
                title: "1 HOUR LEFT",
                body: "Just 1  hour remaining. How's it going? Any exciting results? \
                        Remember not to focus too much time on any one thing. \
                        Consider wrapping up your main section in the next 30-45 mins to give you time \
                        for the conclusion/results section.",
                buttons: {
                    'OK': {}
                         }
                  });

              }
        );
        """))
Beispiel #21
0
class JestApp(ProcessTestApp):
    """DEPRECATED: A notebook app that runs a jest test."""

    default_url = Unicode('/lab')
    extension_url = '/lab'
    name = __name__
    app_name = 'JupyterLab Jest Application'
    app_url = '/lab'

    coverage = Bool(False, help='Whether to run coverage').tag(config=True)

    testPathPattern = Unicode('').tag(config=True)

    testNamePattern = Unicode('').tag(config=True)

    watchAll = Bool(False).tag(config=True)

    aliases = jest_aliases

    flags = jest_flags

    jest_dir = Unicode('')

    test_config = Dict(dict(foo='bar'))

    serverapp_config = {
        "open_browser": False
    }

    @deprecated(removed_version=4)
    def get_command(self):
        """Get the command to run"""
        terminalsAvailable = self.settings['terminals_available']
        debug = self.log.level == logging.DEBUG

        # find jest
        target = osp.join('node_modules', 'jest', 'bin', 'jest.js')
        jest = ''
        cwd = osp.realpath(self.jest_dir)
        while osp.dirname(cwd) != cwd:
            if osp.exists(osp.join(cwd, target)):
                jest = osp.join(cwd, target)
                break
            cwd = osp.dirname(cwd)
        if not jest:
            raise RuntimeError('jest not found!')

        cmd = ['node']
        if self.coverage:
            cmd += [jest, '--coverage']
        elif debug:
            cmd += ['--inspect-brk', jest, '--no-cache']
            if self.watchAll:
                cmd += ['--watchAll']
            else:
                cmd += ['--watch']
        else:
            cmd += [jest]

        if self.testPathPattern:
            cmd += ['--testPathPattern', self.testPathPattern]

        if self.testNamePattern:
            cmd += ['--testNamePattern', self.testNamePattern]

        cmd += ['--runInBand']

        if self.log_level > logging.INFO:
            cmd += ['--silent']

        config = dict(baseUrl=self.serverapp.connection_url,
                      terminalsAvailable=str(terminalsAvailable),
                      token=self.settings['token'])
        config.update(**self.test_config)

        td = tempfile.mkdtemp()
        atexit.register(lambda: shutil.rmtree(td, True))

        config_path = os.path.join(td, 'config.json')
        with open(config_path, 'w') as fid:
            json.dump(config, fid)

        env = os.environ.copy()
        env['JUPYTER_CONFIG_DATA'] = config_path
        return cmd, dict(cwd=self.jest_dir, env=env)
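The loop in get_command that walks up from jest_dir looking for node_modules/jest/bin/jest.js is a general technique for locating a tool in the nearest enclosing node_modules. Extracted as a standalone helper (the function name is ours, not part of JupyterLab):

import os.path as osp

def find_in_node_modules(start_dir, *parts):
    """Walk up from start_dir until node_modules/<parts...> exists."""
    target = osp.join('node_modules', *parts)
    cwd = osp.realpath(start_dir)
    while osp.dirname(cwd) != cwd:  # stop once we reach the filesystem root
        candidate = osp.join(cwd, target)
        if osp.exists(candidate):
            return candidate
        cwd = osp.dirname(cwd)
    return None

# e.g. find_in_node_modules('.', 'jest', 'bin', 'jest.js')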
Beispiel #22
0
class Scatter(widgets.DOMWidget):
    _view_name = Unicode('ScatterView').tag(sync=True)
    _view_module = Unicode('ipyvolume').tag(sync=True)
    _model_name = Unicode('ScatterModel').tag(sync=True)
    _model_module = Unicode('ipyvolume').tag(sync=True)
    _view_module_version = Unicode(semver_range_frontend).tag(sync=True)
    _model_module_version = Unicode(semver_range_frontend).tag(sync=True)
    x = Array(default_value=None).tag(sync=True,
                                      **array_sequence_serialization)
    y = Array(default_value=None).tag(sync=True,
                                      **array_sequence_serialization)
    z = Array(default_value=None).tag(sync=True,
                                      **array_sequence_serialization)
    vx = Array(default_value=None,
               allow_none=True).tag(sync=True, **array_sequence_serialization)
    vy = Array(default_value=None,
               allow_none=True).tag(sync=True, **array_sequence_serialization)
    vz = Array(default_value=None,
               allow_none=True).tag(sync=True, **array_sequence_serialization)
    selected = Array(default_value=None,
                     allow_none=True).tag(sync=True,
                                          **array_sequence_serialization)
    sequence_index = Integer(default_value=0).tag(sync=True)
    size = traitlets.Union([
        Array(default_value=None, allow_none=True).tag(
            sync=True, **array_sequence_serialization),
        traitlets.Float().tag(sync=True)
    ],
                           default_value=5).tag(sync=True)
    size_selected = traitlets.Union([
        Array(default_value=None, allow_none=True).tag(
            sync=True, **array_sequence_serialization),
        traitlets.Float().tag(sync=True)
    ],
                                    default_value=7).tag(sync=True)
    color = Array(default_value="red",
                  allow_none=True).tag(sync=True, **color_serialization)
    color_selected = traitlets.Union([
        Array(default_value=None, allow_none=True).tag(sync=True,
                                                       **color_serialization),
        Unicode().tag(sync=True)
    ],
                                     default_value="green").tag(sync=True)
    geo = traitlets.Unicode('diamond').tag(sync=True)
    connected = traitlets.CBool(default_value=False).tag(sync=True)
    visible = traitlets.CBool(default_value=True).tag(sync=True)

    texture = traitlets.Union([
        traitlets.Instance(ipywebrtc.MediaStream),
        Unicode(),
        traitlets.List(Unicode, [], allow_none=True),
        Image(default_value=None, allow_none=True),
        traitlets.List(Image(default_value=None, allow_none=True))
    ]).tag(sync=True, **texture_serialization)

    material = traitlets.Instance(pythreejs.ShaderMaterial).tag(
        sync=True, **ipywidgets.widget_serialization)

    @traitlets.default('material')
    def _default_material(self):
        return pythreejs.ShaderMaterial()

    line_material = traitlets.Instance(pythreejs.ShaderMaterial).tag(
        sync=True, **ipywidgets.widget_serialization)

    @traitlets.default('line_material')
    def _default_line_material(self):
        return pythreejs.ShaderMaterial()
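Most users get a Scatter from ipyvolume's pylab-style helpers (e.g. ipyvolume.pylab.scatter) rather than constructing it directly, but direct construction is just trait assignment. A minimal sketch, assuming the widget frontend is installed:

import numpy as np

N = 1000
x, y, z = np.random.normal(size=(3, N)).astype(np.float32)
scatter = Scatter(x=x, y=y, z=z, size=2.0, geo='sphere')
scatter.connected = False  # draw separate glyphs, not a line through them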
Beispiel #23
0
class KarmaTestApp(ProcessTestApp):
    """DEPRECATED: A notebook app that runs the jupyterlab karma tests.
    """

    default_url = Unicode('/lab')
    extension_url = '/lab'
    name = __name__
    app_name = 'JupyterLab Karma Application'
    app_url = '/lab'

    karma_pattern = Unicode('src/*.spec.ts*')
    karma_base_dir = Unicode('')
    karma_coverage_dir = Unicode('')

    @deprecated(removed_version=4)
    def get_command(self):
        """Get the command to run."""
        terminalsAvailable = self.settings['terminals_available']
        token = self.settings['token']
        config = dict(baseUrl=self.serverapp.connection_url, token=token,
                      terminalsAvailable=str(terminalsAvailable),
                      foo='bar')

        cwd = self.karma_base_dir

        karma_inject_file = pjoin(cwd, 'build', 'injector.js')
        if not os.path.exists(pjoin(cwd, 'build')):
            os.makedirs(pjoin(cwd, 'build'))

        with open(karma_inject_file, 'w') as fid:
            fid.write("""
            require('es6-promise/dist/es6-promise.js');
            require('@lumino/widgets/style/index.css');

            var node = document.createElement('script');
            node.id = 'jupyter-config-data';
            node.type = 'application/json';
            node.textContent = '%s';
            document.body.appendChild(node);
            """ % json.dumps(config))

        # validate the pattern
        parser = argparse.ArgumentParser()
        parser.add_argument('--pattern', action='store')
        args, argv = parser.parse_known_args()
        pattern = args.pattern or self.karma_pattern
        files = glob.glob(pjoin(cwd, pattern))
        if not files:
            msg = 'No files matching "%s" found in "%s"'
            raise ValueError(msg % (pattern, cwd))

        # Find and validate the coverage folder if not specified
        if not self.karma_coverage_dir:
            with open(pjoin(cwd, 'package.json')) as fid:
                data = json.load(fid)
            name = data['name'].replace('@jupyterlab/test-', '')
            folder = osp.realpath(pjoin(HERE, '..', '..', 'packages', name))
            if not osp.exists(folder):
                raise ValueError(
                    'No source package directory found for "%s", use the pattern '
                    '"@jupyterlab/test-<package_dir_name>"' % name
                )
            self.karma_coverage_dir = folder

        env = os.environ.copy()
        env['KARMA_INJECT_FILE'] = karma_inject_file
        env.setdefault('KARMA_FILE_PATTERN', pattern)
        env.setdefault('KARMA_COVER_FOLDER', self.karma_coverage_dir)
        cwd = self.karma_base_dir
        cmd = ['karma', 'start'] + sys.argv[1:]
        return cmd, dict(env=env, cwd=cwd)
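Note the parse_known_args trick above: it lets the app pluck --pattern out of the command line while leaving every other flag untouched for karma (the rest is re-appended via sys.argv[1:]). A minimal sketch of the same technique:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--pattern', action='store')
# parse_known_args returns (known, leftovers) instead of erroring on
# flags it doesn't recognise
args, leftover = parser.parse_known_args(
    ['--pattern', 'src/*.spec.ts', '--browsers', 'Chrome'])
print(args.pattern)  # 'src/*.spec.ts'
print(leftover)      # ['--browsers', 'Chrome']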
Beispiel #24
0
class Kernel(SingletonConfigurable):

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # attribute to override with a GUI
    eventloop = Any(None)

    @observe('eventloop')
    def _update_eventloop(self, change):
        """schedule call to eventloop from IOLoop"""
        loop = ioloop.IOLoop.current()
        if change.new is not None:
            loop.add_callback(self.enter_eventloop)

    session = Instance(Session, allow_none=True)
    profile_dir = Instance('IPython.core.profiledir.ProfileDir',
                           allow_none=True)
    shell_streams = List()
    control_stream = Instance(ZMQStream, allow_none=True)
    iopub_socket = Any()
    iopub_thread = Any()
    stdin_socket = Any()
    log = Instance(logging.Logger, allow_none=True)

    # identities:
    int_id = Integer(-1)
    ident = Unicode()

    @default('ident')
    def _default_ident(self):
        return unicode_type(uuid.uuid4())

    # This should be overridden by wrapper kernels that implement any real
    # language.
    language_info = {}

    # any links that should go in the help menu
    help_links = List()

    # Private interface

    _darwin_app_nap = Bool(
        True,
        help="""Whether to use appnope for compatibility with OS X App Nap.

        Only affects OS X >= 10.9.
        """).tag(config=True)

    # track associations with current request
    _allow_stdin = Bool(False)
    _parent_header = Dict()
    _parent_ident = Any(b'')
    # Time to sleep after flushing the stdout/err buffers in each execute
    # cycle.  While this introduces a hard limit on the minimal latency of the
    # execute cycle, it helps prevent output synchronization problems for
    # clients.
    # Units are in seconds.  The minimum zmq latency on local host is probably
    # ~150 microseconds, set this to 500us for now.  We may need to increase it
    # a little if it's not enough after more interactive testing.
    _execute_sleep = Float(0.0005).tag(config=True)

    # Frequency of the kernel's event loop.
    # Units are in seconds, kernel subclasses for GUI toolkits may need to
    # adapt to milliseconds.
    _poll_interval = Float(0.01).tag(config=True)

    stop_on_error_timeout = Float(
        0.1,
        config=True,
        help="""time (in seconds) to wait for messages to arrive
        when aborting queued requests after an error.

        Requests that arrive within this window after an error
        will be cancelled.

        Increase in the event of unusually slow network
        causing significant delays,
        which can manifest as e.g. "Run all" in a notebook
        aborting some, but not all, messages after an error.
        """)

    # If the shutdown was requested over the network, we leave here the
    # necessary reply message so it can be sent by our registered atexit
    # handler.  This ensures that the reply is only sent to clients truly at
    # the end of our shutdown process (which happens after the underlying
    # IPython shell's own shutdown).
    _shutdown_message = None

    # This is a dict of port number that the kernel is listening on. It is set
    # by record_ports and used by connect_request.
    _recorded_ports = Dict()

    # set of aborted msg_ids
    aborted = Set()

    # Track execution count here. For IPython, we override this to use the
    # execution count we store in the shell.
    execution_count = 0

    msg_types = [
        'execute_request',
        'complete_request',
        'inspect_request',
        'history_request',
        'comm_info_request',
        'kernel_info_request',
        'connect_request',
        'shutdown_request',
        'is_complete_request',
        # deprecated:
        'apply_request',
    ]
    # add deprecated ipyparallel control messages
    control_msg_types = msg_types + ['clear_request', 'abort_request']

    def __init__(self, **kwargs):
        super(Kernel, self).__init__(**kwargs)
        # Build dict of handlers for message types
        self.shell_handlers = {}
        for msg_type in self.msg_types:
            self.shell_handlers[msg_type] = getattr(self, msg_type)

        self.control_handlers = {}
        for msg_type in self.control_msg_types:
            self.control_handlers[msg_type] = getattr(self, msg_type)

    @gen.coroutine
    def dispatch_control(self, msg):
        """dispatch control requests"""
        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.deserialize(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Control Message", exc_info=True)
            return

        self.log.debug("Control received: %s", msg)

        # Set the parent message for side effects.
        self.set_parent(idents, msg)
        self._publish_status(u'busy')
        if self._aborting:
            self._send_abort_reply(self.control_stream, msg, idents)
            self._publish_status(u'idle')
            return

        header = msg['header']
        msg_type = header['msg_type']

        handler = self.control_handlers.get(msg_type, None)
        if handler is None:
            self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
        else:
            try:
                yield gen.maybe_future(
                    handler(self.control_stream, idents, msg))
            except Exception:
                self.log.error("Exception in control handler:", exc_info=True)

        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status(u'idle')
        # flush to ensure reply is sent
        self.control_stream.flush(zmq.POLLOUT)

    def should_handle(self, stream, msg, idents):
        """Check whether a shell-channel message should be handled

        Allows subclasses to prevent handling of certain messages (e.g. aborted requests).
        """
        msg_id = msg['header']['msg_id']
        if msg_id in self.aborted:
            msg_type = msg['header']['msg_type']
            # is it safe to assume a msg_id will not be resubmitted?
            self.aborted.remove(msg_id)
            self._send_abort_reply(stream, msg, idents)
            return False
        return True

    @gen.coroutine
    def dispatch_shell(self, stream, msg):
        """dispatch shell requests"""
        # flush control requests first
        if self.control_stream:
            self.control_stream.flush()

        idents, msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.deserialize(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Message", exc_info=True)
            return

        # Set the parent message for side effects.
        self.set_parent(idents, msg)
        self._publish_status(u'busy')

        if self._aborting:
            self._send_abort_reply(stream, msg, idents)
            self._publish_status(u'idle')
            # flush to ensure reply is sent before
            # handling the next request
            stream.flush(zmq.POLLOUT)
            return

        msg_type = msg['header']['msg_type']

        # Print some info about this message and leave a '--->' marker, so it's
        # easier to trace visually the message chain when debugging.  Each
        # handler prints its message at the end.
        self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
        self.log.debug('   Content: %s\n   --->\n   ', msg['content'])

        if not self.should_handle(stream, msg, idents):
            return

        handler = self.shell_handlers.get(msg_type, None)
        if handler is None:
            self.log.warning("Unknown message type: %r", msg_type)
        else:
            self.log.debug("%s: %s", msg_type, msg)
            try:
                self.pre_handler_hook()
            except Exception:
                self.log.debug("Unable to signal in pre_handler_hook:",
                               exc_info=True)
            try:
                yield gen.maybe_future(handler(stream, idents, msg))
            except Exception:
                self.log.error("Exception in message handler:", exc_info=True)
            finally:
                try:
                    self.post_handler_hook()
                except Exception:
                    self.log.debug("Unable to signal in post_handler_hook:",
                                   exc_info=True)

        sys.stdout.flush()
        sys.stderr.flush()
        self._publish_status(u'idle')
        # flush to ensure reply is sent before
        # handling the next request
        stream.flush(zmq.POLLOUT)

    def pre_handler_hook(self):
        """Hook to execute before calling message handler"""
        # ensure default_int_handler during handler call
        self.saved_sigint_handler = signal(SIGINT, default_int_handler)

    def post_handler_hook(self):
        """Hook to execute after calling message handler"""
        signal(SIGINT, self.saved_sigint_handler)

    def enter_eventloop(self):
        """enter eventloop"""
        self.log.info("Entering eventloop %s", self.eventloop)
        # record handle, so we can check when this changes
        eventloop = self.eventloop

        def advance_eventloop():
            # check if eventloop changed:
            if self.eventloop is not eventloop:
                self.log.info("exiting eventloop %s", eventloop)
                return
            if self.msg_queue.qsize():
                self.log.debug("Delaying eventloop due to waiting messages")
                # still messages to process, make the eventloop wait
                schedule_next()
                return
            self.log.debug("Advancing eventloop %s", eventloop)
            try:
                eventloop(self)
            except KeyboardInterrupt:
                # Ctrl-C shouldn't crash the kernel
                self.log.error("KeyboardInterrupt caught in kernel")
                pass
            if self.eventloop is eventloop:
                # schedule advance again
                schedule_next()

        def schedule_next():
            """Schedule the next advance of the eventloop"""
            # flush the eventloop every so often,
            # giving us a chance to handle messages in the meantime
            self.log.debug("Scheduling eventloop advance")
            self.io_loop.call_later(1, advance_eventloop)

        # begin polling the eventloop
        schedule_next()

    @gen.coroutine
    def do_one_iteration(self):
        """Process a single shell message

        Any pending control messages will be flushed as well

        .. versionchanged:: 5
            This is now a coroutine
        """
        # flush messages off of shell streams into the message queue
        for stream in self.shell_streams:
            stream.flush()
        # process all messages higher priority than shell (control),
        # and at most one shell message per iteration
        priority = 0
        while priority is not None and priority < SHELL_PRIORITY:
            priority = yield self.process_one(wait=False)

    @gen.coroutine
    def process_one(self, wait=True):
        """Process one request

        Returns priority of the message handled.
        Returns None if no message was handled.
        """
        if wait:
            priority, t, dispatch, args = yield self.msg_queue.get()
        else:
            try:
                priority, t, dispatch, args = self.msg_queue.get_nowait()
            except QueueEmpty:
                return None
        yield gen.maybe_future(dispatch(*args))
        # hand the priority back so do_one_iteration can keep draining
        # higher-priority (control) messages, as the docstring promises
        return priority

    @gen.coroutine
    def dispatch_queue(self):
        """Coroutine to preserve order of message handling

        Ensures that only one message is processing at a time,
        even when the handler is async
        """

        while True:
            # receive the next message and handle it
            try:
                yield self.process_one()
            except Exception:
                self.log.exception("Error in message handler")

    _message_counter = Any(help="""Monotonic counter of messages

        Ensures messages of the same priority are handled in arrival order.
        """, )

    @default('_message_counter')
    def _message_counter_default(self):
        return itertools.count()

    def schedule_dispatch(self, priority, dispatch, *args):
        """schedule a message for dispatch"""
        idx = next(self._message_counter)

        self.msg_queue.put_nowait((
            priority,
            idx,
            dispatch,
            args,
        ))
        # ensure the eventloop wakes up
        self.io_loop.add_callback(lambda: None)

    def start(self):
        """register dispatchers for streams"""
        self.io_loop = ioloop.IOLoop.current()
        self.msg_queue = PriorityQueue()
        self.io_loop.add_callback(self.dispatch_queue)

        if self.control_stream:
            self.control_stream.on_recv(
                partial(
                    self.schedule_dispatch,
                    CONTROL_PRIORITY,
                    self.dispatch_control,
                ),
                copy=False,
            )

        for s in self.shell_streams:
            if s is self.control_stream:
                continue
            s.on_recv(
                partial(
                    self.schedule_dispatch,
                    SHELL_PRIORITY,
                    self.dispatch_shell,
                    s,
                ),
                copy=False,
            )

        # publish idle status
        self._publish_status('starting')

    def record_ports(self, ports):
        """Record the ports that this kernel is using.

        The creator of the Kernel instance must call this methods if they
        want the :meth:`connect_request` method to return the port numbers.
        """
        self._recorded_ports = ports

    #---------------------------------------------------------------------------
    # Kernel request handlers
    #---------------------------------------------------------------------------

    def _publish_execute_input(self, code, parent, execution_count):
        """Publish the code request on the iopub stream."""

        self.session.send(self.iopub_socket,
                          u'execute_input', {
                              u'code': code,
                              u'execution_count': execution_count
                          },
                          parent=parent,
                          ident=self._topic('execute_input'))

    def _publish_status(self, status, parent=None):
        """send status (busy/idle) on IOPub"""
        self.session.send(
            self.iopub_socket,
            u'status',
            {u'execution_state': status},
            parent=parent or self._parent_header,
            ident=self._topic('status'),
        )

    def set_parent(self, ident, parent):
        """Set the current parent_header

        Side effects (IOPub messages) and replies are associated with
        the request that caused them via the parent_header.

        The parent identity is used to route input_request messages
        on the stdin channel.
        """
        self._parent_ident = ident
        self._parent_header = parent

    def send_response(self,
                      stream,
                      msg_or_type,
                      content=None,
                      ident=None,
                      buffers=None,
                      track=False,
                      header=None,
                      metadata=None):
        """Send a response to the message we're currently processing.

        This accepts all the parameters of :meth:`jupyter_client.session.Session.send`
        except ``parent``.

        This relies on :meth:`set_parent` having been called for the current
        message.
        """
        return self.session.send(stream, msg_or_type, content,
                                 self._parent_header, ident, buffers, track,
                                 header, metadata)

    def init_metadata(self, parent):
        """Initialize metadata.

        Run at the beginning of execution requests.
        """
        # FIXME: `started` is part of ipyparallel
        # Remove for ipykernel 5.0
        return {
            'started': now(),
        }

    def finish_metadata(self, parent, metadata, reply_content):
        """Finish populating metadata.

        Run after completing an execution request.
        """
        return metadata

    @gen.coroutine
    def execute_request(self, stream, ident, parent):
        """handle an execute_request"""

        try:
            content = parent[u'content']
            code = py3compat.cast_unicode_py2(content[u'code'])
            silent = content[u'silent']
            store_history = content.get(u'store_history', not silent)
            user_expressions = content.get('user_expressions', {})
            allow_stdin = content.get('allow_stdin', False)
        except:
            self.log.error("Got bad msg: ")
            self.log.error("%s", parent)
            return

        stop_on_error = content.get('stop_on_error', True)

        metadata = self.init_metadata(parent)

        # Re-broadcast our input for the benefit of listening clients, and
        # start computing output
        if not silent:
            self.execution_count += 1
            self._publish_execute_input(code, parent, self.execution_count)

        reply_content = yield gen.maybe_future(
            self.do_execute(
                code,
                silent,
                store_history,
                user_expressions,
                allow_stdin,
            ))

        # Flush output before sending the reply.
        sys.stdout.flush()
        sys.stderr.flush()
        # FIXME: on rare occasions, the flush doesn't seem to make it to the
        # clients... This seems to mitigate the problem, but we definitely need
        # to better understand what's going on.
        if self._execute_sleep:
            time.sleep(self._execute_sleep)

        # Send the reply.
        reply_content = json_clean(reply_content)
        metadata = self.finish_metadata(parent, metadata, reply_content)

        reply_msg = self.session.send(stream,
                                      u'execute_reply',
                                      reply_content,
                                      parent,
                                      metadata=metadata,
                                      ident=ident)

        self.log.debug("%s", reply_msg)

        if not silent and reply_msg['content'][
                'status'] == u'error' and stop_on_error:
            yield self._abort_queues()

    def do_execute(self,
                   code,
                   silent,
                   store_history=True,
                   user_expressions=None,
                   allow_stdin=False):
        """Execute user code. Must be overridden by subclasses.
        """
        raise NotImplementedError

    @gen.coroutine
    def complete_request(self, stream, ident, parent):
        content = parent['content']
        code = content['code']
        cursor_pos = content['cursor_pos']

        matches = yield gen.maybe_future(self.do_complete(code, cursor_pos))
        matches = json_clean(matches)
        completion_msg = self.session.send(stream, 'complete_reply', matches,
                                           parent, ident)

    def do_complete(self, code, cursor_pos):
        """Override in subclasses to find completions.
        """
        return {
            'matches': [],
            'cursor_end': cursor_pos,
            'cursor_start': cursor_pos,
            'metadata': {},
            'status': 'ok'
        }

    @gen.coroutine
    def inspect_request(self, stream, ident, parent):
        content = parent['content']

        reply_content = yield gen.maybe_future(
            self.do_inspect(
                content['code'],
                content['cursor_pos'],
                content.get('detail_level', 0),
            ))
        # Before we send this object over, we scrub it for JSON usage
        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'inspect_reply', reply_content, parent,
                                ident)
        self.log.debug("%s", msg)

    def do_inspect(self, code, cursor_pos, detail_level=0):
        """Override in subclasses to allow introspection.
        """
        return {'status': 'ok', 'data': {}, 'metadata': {}, 'found': False}

    @gen.coroutine
    def history_request(self, stream, ident, parent):
        content = parent['content']

        reply_content = yield gen.maybe_future(self.do_history(**content))

        reply_content = json_clean(reply_content)
        msg = self.session.send(stream, 'history_reply', reply_content, parent,
                                ident)
        self.log.debug("%s", msg)

    def do_history(self,
                   hist_access_type,
                   output,
                   raw,
                   session=None,
                   start=None,
                   stop=None,
                   n=None,
                   pattern=None,
                   unique=False):
        """Override in subclasses to access history.
        """
        return {'status': 'ok', 'history': []}

    def connect_request(self, stream, ident, parent):
        if self._recorded_ports is not None:
            content = self._recorded_ports.copy()
        else:
            content = {}
        content['status'] = 'ok'
        msg = self.session.send(stream, 'connect_reply', content, parent,
                                ident)
        self.log.debug("%s", msg)

    @property
    def kernel_info(self):
        return {
            'protocol_version': kernel_protocol_version,
            'implementation': self.implementation,
            'implementation_version': self.implementation_version,
            'language_info': self.language_info,
            'banner': self.banner,
            'help_links': self.help_links,
        }

    def kernel_info_request(self, stream, ident, parent):
        content = {'status': 'ok'}
        content.update(self.kernel_info)
        msg = self.session.send(stream, 'kernel_info_reply', content, parent,
                                ident)
        self.log.debug("%s", msg)

    def comm_info_request(self, stream, ident, parent):
        content = parent['content']
        target_name = content.get('target_name', None)

        # Should this be moved to ipkernel?
        if hasattr(self, 'comm_manager'):
            comms = {
                k: dict(target_name=v.target_name)
                for (k, v) in self.comm_manager.comms.items()
                if v.target_name == target_name or target_name is None
            }
        else:
            comms = {}
        reply_content = dict(comms=comms, status='ok')
        msg = self.session.send(stream, 'comm_info_reply', reply_content,
                                parent, ident)
        self.log.debug("%s", msg)

    @gen.coroutine
    def shutdown_request(self, stream, ident, parent):
        content = yield gen.maybe_future(
            self.do_shutdown(parent['content']['restart']))
        self.session.send(stream,
                          u'shutdown_reply',
                          content,
                          parent,
                          ident=ident)
        # same content, but different msg_id for broadcasting on IOPub
        self._shutdown_message = self.session.msg(u'shutdown_reply', content,
                                                  parent)

        self._at_shutdown()
        # call sys.exit after a short delay
        loop = ioloop.IOLoop.current()
        loop.add_timeout(time.time() + 0.1, loop.stop)

    def do_shutdown(self, restart):
        """Override in subclasses to do things when the frontend shuts down the
        kernel.
        """
        return {'status': 'ok', 'restart': restart}

    @gen.coroutine
    def is_complete_request(self, stream, ident, parent):
        content = parent['content']
        code = content['code']

        reply_content = yield gen.maybe_future(self.do_is_complete(code))
        reply_content = json_clean(reply_content)
        reply_msg = self.session.send(stream, 'is_complete_reply',
                                      reply_content, parent, ident)
        self.log.debug("%s", reply_msg)

    def do_is_complete(self, code):
        """Override in subclasses to find completions.
        """
        return {
            'status': 'unknown',
        }

    #---------------------------------------------------------------------------
    # Engine methods (DEPRECATED)
    #---------------------------------------------------------------------------

    def apply_request(self, stream, ident, parent):
        self.log.warning(
            "apply_request is deprecated in kernel_base, moving to ipyparallel."
        )
        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
        except:
            self.log.error("Got bad msg: %s", parent, exc_info=True)
            return

        md = self.init_metadata(parent)

        reply_content, result_buf = self.do_apply(content, bufs, msg_id, md)

        # flush i/o
        sys.stdout.flush()
        sys.stderr.flush()

        md = self.finish_metadata(parent, md, reply_content)

        self.session.send(stream,
                          u'apply_reply',
                          reply_content,
                          parent=parent,
                          ident=ident,
                          buffers=result_buf,
                          metadata=md)

    def do_apply(self, content, bufs, msg_id, reply_metadata):
        """DEPRECATED"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Control messages (DEPRECATED)
    #---------------------------------------------------------------------------

    def abort_request(self, stream, ident, parent):
        """abort a specific msg by id"""
        self.log.warning(
            "abort_request is deprecated in kernel_base. It is only part of IPython parallel"
        )
        msg_ids = parent['content'].get('msg_ids', None)
        if isinstance(msg_ids, string_types):
            msg_ids = [msg_ids]
        if not msg_ids:
            self._abort_queues()
        for mid in msg_ids:
            self.aborted.add(str(mid))

        content = dict(status='ok')
        reply_msg = self.session.send(stream,
                                      'abort_reply',
                                      content=content,
                                      parent=parent,
                                      ident=ident)
        self.log.debug("%s", reply_msg)

    def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        self.log.warning(
            "clear_request is deprecated in kernel_base. It is only part of IPython parallel"
        )
        content = self.do_clear()
        self.session.send(stream,
                          'clear_reply',
                          ident=idents,
                          parent=parent,
                          content=content)

    def do_clear(self):
        """DEPRECATED since 4.0.3"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # Protected interface
    #---------------------------------------------------------------------------

    def _topic(self, topic):
        """prefixed topic for IOPub messages"""
        base = "kernel.%s" % self.ident

        return py3compat.cast_bytes("%s.%s" % (base, topic))

    _aborting = Bool(False)

    @gen.coroutine
    def _abort_queues(self):
        for stream in self.shell_streams:
            stream.flush()
        self._aborting = True

        self.schedule_dispatch(
            ABORT_PRIORITY,
            self._dispatch_abort,
        )

    @gen.coroutine
    def _dispatch_abort(self):
        self.log.info("Finishing abort")
        yield gen.sleep(self.stop_on_error_timeout)
        self._aborting = False

    def _send_abort_reply(self, stream, msg, idents):
        """Send a reply to an aborted request"""
        self.log.info("Aborting:")
        self.log.info("%s", msg)
        reply_type = msg['header']['msg_type'].rsplit('_', 1)[0] + '_reply'
        status = {'status': 'aborted'}
        md = {'engine': self.ident}
        md.update(status)
        self.session.send(
            stream,
            reply_type,
            metadata=md,
            content=status,
            parent=msg,
            ident=idents,
        )

    def _no_raw_input(self):
        """Raise StdinNotImplentedError if active frontend doesn't support
        stdin."""
        raise StdinNotImplementedError("raw_input was called, but this "
                                       "frontend does not support stdin.")

    def getpass(self, prompt='', stream=None):
        """Forward getpass to frontends

        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "getpass was called, but this frontend does not support input requests."
            )
        if stream is not None:
            import warnings
            warnings.warn(
                "The `stream` parameter of `getpass.getpass` will have no effect when using ipykernel",
                UserWarning,
                stacklevel=2)
        return self._input_request(
            prompt,
            self._parent_ident,
            self._parent_header,
            password=True,
        )

    def raw_input(self, prompt=''):
        """Forward raw_input to frontends

        Raises
        ------
        StdinNotImplementedError if active frontend doesn't support stdin.
        """
        if not self._allow_stdin:
            raise StdinNotImplementedError(
                "raw_input was called, but this frontend does not support input requests."
            )
        return self._input_request(
            str(prompt),
            self._parent_ident,
            self._parent_header,
            password=False,
        )

    def _input_request(self, prompt, ident, parent, password=False):
        # Flush output before making the request.
        sys.stderr.flush()
        sys.stdout.flush()
        # flush the stdin socket, to purge stale replies
        while True:
            try:
                self.stdin_socket.recv_multipart(zmq.NOBLOCK)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    break
                else:
                    raise

        # Send the input request.
        content = json_clean(dict(prompt=prompt, password=password))
        self.session.send(self.stdin_socket,
                          u'input_request',
                          content,
                          parent,
                          ident=ident)

        # Await a response.
        while True:
            try:
                ident, reply = self.session.recv(self.stdin_socket, 0)
            except Exception:
                self.log.warning("Invalid Message:", exc_info=True)
            except KeyboardInterrupt:
                # re-raise KeyboardInterrupt, to truncate traceback
                raise KeyboardInterrupt
            else:
                break
        try:
            value = py3compat.unicode_to_str(reply['content']['value'])
        except:
            self.log.error("Bad input_reply: %s", parent)
            value = ''
        if value == '\x04':
            # EOF
            raise EOFError
        return value

    def _at_shutdown(self):
        """Actions taken at shutdown by the kernel, called by python's atexit.
        """
        if self._shutdown_message is not None:
            self.session.send(self.iopub_socket,
                              self._shutdown_message,
                              ident=self._topic('shutdown'))
            self.log.debug("%s", self._shutdown_message)
        for s in self.shell_streams:
            s.flush(zmq.POLLOUT)
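Everything above is transport machinery; a concrete wrapper kernel mostly just fills in the metadata attributes and do_execute. A minimal echo-kernel sketch in the style of the ipykernel wrapper-kernel docs:

class EchoKernel(Kernel):
    implementation = 'echo'
    implementation_version = '1.0'
    banner = 'Echo kernel - repeats whatever you type'
    language_info = {
        'name': 'echo',
        'mimetype': 'text/plain',
        'file_extension': '.txt',
    }

    def do_execute(self, code, silent, store_history=True,
                   user_expressions=None, allow_stdin=False):
        if not silent:
            # stream the input straight back to the client on IOPub
            self.send_response(self.iopub_socket, 'stream',
                               {'name': 'stdout', 'text': code})
        return {'status': 'ok',
                'execution_count': self.execution_count,
                'payload': [],
                'user_expressions': {}}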
Beispiel #25
0
class PDFExporter(LatexExporter):
    """Writer designed to write to PDF files.

    This inherits from :class:`LatexExporter`. It creates a LaTeX file in
    a temporary directory using the template machinery, and then runs LaTeX
    to create a pdf.
    """

    latex_count = Integer(
        3, help="How many times latex will be called.").tag(config=True)

    latex_command = List(
        [u"xelatex", u"{filename}"],
        help="Shell command used to compile latex.").tag(config=True)

    bib_command = List(
        [u"bibtex", u"{filename}"],
        help="Shell command used to run bibtex.").tag(config=True)

    verbose = Bool(
        False, help="Whether to display the output of latex commands.").tag(
            config=True)

    texinputs = Unicode(help="Directory prepended to the TEXINPUTS/BIBINPUTS/BSTINPUTS "
                             "search paths. A notebook's own directory is added automatically.")
    writer = Instance("nbconvert.writers.FilesWriter",
                      args=(),
                      kw={'build_directory': '.'})

    _captured_output = List()

    def run_command(self, command_list, filename, count, log_function):
        """Run command_list count times.
        
        Parameters
        ----------
        command_list : list
            A list of args to provide to Popen. Each element of this
            list will be interpolated with the filename to convert.
        filename : unicode
            The name of the file to convert.
        count : int
            How many times to run the command.
        
        Returns
        -------
        success : bool
            A boolean indicating if the command was successful (True)
            or failed (False).
        """
        command = [c.format(filename=filename) for c in command_list]

        # On windows with python 2.x there is a bug in subprocess.Popen and
        # unicode commands are not supported
        if sys.platform == 'win32' and sys.version_info < (3, 0):
            #We must use cp1252 encoding for calling subprocess.Popen
            #Note that sys.stdin.encoding and encoding.DEFAULT_ENCODING
            # could be different (cp437 in case of dos console)
            command = [c.encode('cp1252') for c in command]

        # This will throw a clearer error if the command is not found
        cmd = which(command_list[0])
        if cmd is None:
            link = "https://nbconvert.readthedocs.io/en/latest/install.html#installing-tex"
            raise OSError(
                "{formatter} not found on PATH, if you have not installed "
                "{formatter} you may need to do so. Find further instructions "
                "at {link}.".format(formatter=command_list[0], link=link))

        times = 'time' if count == 1 else 'times'
        self.log.info("Running %s %i %s: %s", command_list[0], count, times,
                      command)

        shell = (sys.platform == 'win32')
        if shell:
            command = subprocess.list2cmdline(command)
        env = os.environ.copy()
        prepend_to_env_search_path('TEXINPUTS', self.texinputs, env)
        prepend_to_env_search_path('BIBINPUTS', self.texinputs, env)
        prepend_to_env_search_path('BSTINPUTS', self.texinputs, env)

        with open(os.devnull, 'rb') as null:
            stdout = subprocess.PIPE if not self.verbose else None
            for index in range(count):
                p = subprocess.Popen(command,
                                     stdout=stdout,
                                     stderr=subprocess.STDOUT,
                                     stdin=null,
                                     shell=shell,
                                     env=env)
                out, _ = p.communicate()
                if p.returncode:
                    if self.verbose:
                        # verbose means I didn't capture stdout with PIPE,
                        # so it's already been displayed and `out` is None.
                        out = u''
                    else:
                        out = out.decode('utf-8', 'replace')
                    log_function(command, out)
                    self._captured_output.append(out)
                    return False  # failure
        return True  # success

    def run_latex(self, filename):
        """Run xelatex self.latex_count times."""
        def log_error(command, out):
            self.log.critical(u"%s failed: %s\n%s", command[0], command, out)

        return self.run_command(self.latex_command, filename, self.latex_count,
                                log_error)

    def run_bib(self, filename):
        """Run bibtex self.latex_count times."""
        filename = os.path.splitext(filename)[0]

        def log_error(command, out):
            self.log.warning(
                '%s had problems, most likely because there were no citations',
                command[0])
            self.log.debug(u"%s output: %s\n%s", command[0], command, out)

        return self.run_command(self.bib_command, filename, 1, log_error)

    def from_notebook_node(self, nb, resources=None, **kw):
        latex, resources = super(PDFExporter,
                                 self).from_notebook_node(nb,
                                                          resources=resources,
                                                          **kw)
        # set texinputs directory, so that local files will be found
        if resources and resources.get('metadata', {}).get('path'):
            self.texinputs = resources['metadata']['path']
        else:
            self.texinputs = getcwd()

        self._captured_output = []
        with TemporaryWorkingDirectory():
            notebook_name = 'notebook'
            tex_file = self.writer.write(latex,
                                         resources,
                                         notebook_name=notebook_name)
            self.log.info("Building PDF")
            rc = self.run_latex(tex_file)
            if rc:
                rc = self.run_bib(tex_file)
            if rc:
                rc = self.run_latex(tex_file)

            pdf_file = notebook_name + '.pdf'
            if not os.path.isfile(pdf_file):
                raise LatexFailed('\n'.join(self._captured_output))
            self.log.info('PDF successfully created')
            with open(pdf_file, 'rb') as f:
                pdf_data = f.read()

        # convert output extension to pdf
        # the writer above required it to be tex
        resources['output_extension'] = '.pdf'
        # clear figure outputs, extracted by latex export,
        # so we don't claim to be a multi-file export.
        resources.pop('outputs', None)

        return pdf_data, resources
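
# A hedged usage sketch for the exporter above. It assumes nbconvert and a TeX
# distribution providing xelatex are installed; "notebook.ipynb" is a
# hypothetical input file.
import nbformat
from nbconvert import PDFExporter

nb = nbformat.read("notebook.ipynb", as_version=4)
exporter = PDFExporter()
pdf_data, resources = exporter.from_notebook_node(nb)
with open("notebook.pdf", "wb") as f:
    f.write(pdf_data)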
Beispiel #26
0
class InlineBackend(InlineBackendConfig):
    """An object to store configuration of the inline backend."""

    # The typical default figure size is too large for inline use,
    # so we shrink the figure size to 6x4, and tweak fonts to
    # make that fit.
    rc = Dict({'figure.figsize': (6.0,4.0),
        # play nicely with white background in the Qt and notebook frontend
        'figure.facecolor': (1,1,1,0),
        'figure.edgecolor': (1,1,1,0),
        # 12pt labels get cut off on 6x4 logplots, so use 10pt.
        'font.size': 10,
        # 72 dpi matches SVG/qtconsole
        # this only affects PNG export, as SVG has no dpi setting
        'figure.dpi': 72,
        # 10pt still needs a little more room on the xlabel:
        'figure.subplot.bottom' : .125
        },
        help="""Subset of matplotlib rcParams that should be different for the
        inline backend."""
    ).tag(config=True)

    figure_formats = Set({'png'},
                          help="""A set of figure formats to enable: 'png',
                          'retina', 'jpeg', 'svg', 'pdf'.""").tag(config=True)

    def _update_figure_formatters(self):
        if self.shell is not None:
            from yap_ipython.core.pylabtools import select_figure_formats
            select_figure_formats(self.shell, self.figure_formats, **self.print_figure_kwargs)

    def _figure_formats_changed(self, name, old, new):
        if 'jpg' in new or 'jpeg' in new:
            if not pil_available():
                raise TraitError("Requires PIL/Pillow for JPG figures")
        self._update_figure_formatters()

    figure_format = Unicode(help="""The figure format to enable (deprecated
                                         use `figure_formats` instead)""").tag(config=True)

    def _figure_format_changed(self, name, old, new):
        if new:
            self.figure_formats = {new}

    print_figure_kwargs = Dict({'bbox_inches' : 'tight'},
        help="""Extra kwargs to be passed to fig.canvas.print_figure.

        Logical examples include: bbox_inches, quality (for jpeg figures), etc.
        """
    ).tag(config=True)
    _print_figure_kwargs_changed = _update_figure_formatters

    close_figures = Bool(True,
        help="""Close all figures at the end of each cell.

        When True, ensures that each cell starts with no active figures, but it
        also means that one must keep track of references in order to edit or
        redraw figures in subsequent cells. This mode is ideal for the notebook,
        where residual plots from other cells might be surprising.

        When False, one must call figure() to create new figures. This means
        that gcf() and getfigs() can reference figures created in other cells,
        and the active figure can continue to be edited with pylab/pyplot
        methods that reference the current active figure. This mode facilitates
        iterative editing of figures, and behaves most consistently with
        other matplotlib backends, but figure barriers between cells must
        be explicit.
        """).tag(config=True)
    
    shell = Instance('yap_ipython.core.interactiveshell.InteractiveShellABC',
                     allow_none=True)
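
# A hedged configuration sketch: the traits above are normally set through the
# traitlets config machinery rather than by instantiating InlineBackend
# directly. In an ipython_kernel_config.py (values are illustrative):
c = get_config()  # injected by the config loader when the file is executed
c.InlineBackend.figure_formats = {"png", "retina"}
c.InlineBackend.rc = {"figure.figsize": (8.0, 5.0), "font.size": 11}
c.InlineBackend.print_figure_kwargs = {"bbox_inches": "tight"}
c.InlineBackend.close_figures = True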
Beispiel #27
0
class SlidesExporter(HTMLExporter):
    """Exports HTML slides with reveal.js"""

    reveal_url_prefix = Unicode(help="""The URL prefix for reveal.js.
        This can be a relative URL for a local copy of reveal.js,
        or point to a CDN.

        For speaker notes to work, a local reveal.js prefix must be used.
        """).tag(config=True)

    @default('reveal_url_prefix')
    def _reveal_url_prefix_default(self):
        if 'RevealHelpPreprocessor.url_prefix' in self.config:
            warn("Please update RevealHelpPreprocessor.url_prefix to "
                 "SlidesExporter.reveal_url_prefix in config files.")
            return self.config.RevealHelpPreprocessor.url_prefix
        return 'reveal.js'

    reveal_theme = Unicode('simple',
                           help="""
        Name of the reveal.js theme to use.

        We look for a file with this name under `reveal_url_prefix`/css/theme/`reveal_theme`.css.

        https://github.com/hakimel/reveal.js/tree/master/css/theme has a
        list of the themes that ship by default with reveal.js.
        """).tag(config=True)

    require_js_url = Unicode(
        "https://cdnjs.cloudflare.com/ajax/libs/require.js/2.1.10/require.min.js",
        help="""
        URL to load require.js from.

        Defaults to loading from cdnjs.
        """).tag(config=True)

    jquery_url = Unicode(
        "https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.3/jquery.min.js",
        help="""
        URL to load jQuery from.

        Defaults to loading from cdnjs.
        """).tag(config=True)

    font_awesome_url = Unicode(
        "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.1.0/css/font-awesome.css",
        help="""
        URL to load font awesome from.

        Defaults to loading from cdnjs.
        """).tag(config=True)

    @default('file_extension')
    def _file_extension_default(self):
        return '.slides.html'

    @default('template_file')
    def _template_file_default(self):
        return 'slides_reveal'

    output_mimetype = 'text/html'

    def from_notebook_node(self, nb, resources=None, **kw):
        resources = self._init_resources(resources)
        if 'reveal' not in resources:
            resources['reveal'] = {}
        resources['reveal']['url_prefix'] = self.reveal_url_prefix
        resources['reveal']['theme'] = self.reveal_theme
        resources['reveal']['require_js_url'] = self.require_js_url
        resources['reveal']['jquery_url'] = self.jquery_url
        resources['reveal']['font_awesome_url'] = self.font_awesome_url

        nb = prepare(nb)

        return super(SlidesExporter,
                     self).from_notebook_node(nb, resources=resources, **kw)
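
# A hedged usage sketch for the slides exporter above; "talk.ipynb" is a
# hypothetical input, and a local reveal.js checkout enables speaker notes.
import nbformat
from nbconvert import SlidesExporter

nb = nbformat.read("talk.ipynb", as_version=4)
exporter = SlidesExporter()
exporter.reveal_url_prefix = "reveal.js"  # local copy, see help text above
body, resources = exporter.from_notebook_node(nb)
with open("talk" + exporter.file_extension, "w", encoding="utf-8") as f:
    f.write(body)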
Beispiel #28
0
class BinderHub(Application):
    """An Application for starting a builder."""
    aliases = {
        'log-level': 'Application.log_level',
        'f': 'BinderHub.config_file',
        'config': 'BinderHub.config_file',
        'port': 'BinderHub.port',
    }

    flags = {
        'debug': ({
            'BinderHub': {
                'debug': True
            }
        }, "Enable debug HTTP serving & debug logging")
    }

    config_file = Unicode('binderhub_config.py',
                          help="""
        Config file to load.

        If a relative path is provided, it is taken relative to the current directory
        """,
                          config=True)

    google_analytics_code = Unicode(None,
                                    allow_none=True,
                                    help="""
        The Google Analytics code to use on the main page.

        Note that we'll respect Do Not Track settings, despite the fact that GA does not.
        We will not load the GA scripts on browsers with DNT enabled.
        """,
                                    config=True)

    google_analytics_domain = Unicode('auto',
                                      help="""
        The Google Analytics domain to use on the main page.

        By default this is set to 'auto', which sets it up for the current domain and
        all subdomains. This can be set to a more restrictive domain here for better privacy.
        """,
                                      config=True)

    base_url = Unicode('/',
                       help="The base URL of the entire application",
                       config=True)

    @validate('base_url')
    def _valid_base_url(self, proposal):
        if not proposal.value.startswith('/'):
            proposal.value = '/' + proposal.value
        if not proposal.value.endswith('/'):
            proposal.value = proposal.value + '/'
        return proposal.value

    port = Integer(8585,
                   help="""
        Port for the builder to listen on.
        """,
                   config=True)

    appendix = Unicode(
        help="""
        Appendix to pass to repo2docker

        A multi-line string of Docker directives to run.
        Since the build context cannot be affected,
        ADD will typically not be useful.

        This should be a Python string template.
        It will be formatted with at least the following names available:

        - binder_url: the shareable URL for the current image
          (e.g. for sharing links to the current Binder)
        - repo_url: the repository URL used to build the image
        """,
        config=True,
    )

    use_registry = Bool(True,
                        help="""
        Set to true to push images to a registry and check for images in the registry.

        Set to false to use only local docker images. Useful when running
        on a single node.
        """,
                        config=True)

    per_repo_quota = Integer(
        0,
        help="""
        Maximum number of concurrent users running from a given repo.

        Limits the amount of Binder capacity that a single repo can consume.

        0 (default) means no quotas.
        """,
        config=True,
    )

    docker_push_secret = Unicode('docker-push-secret',
                                 allow_none=True,
                                 help="""
        A kubernetes secret object that provides credentials for pushing built images.
        """,
                                 config=True)

    docker_image_prefix = Unicode("",
                                  help="""
        Prefix for all built docker images.

        If you are pushing to gcr.io, you would have this be:
            gcr.io/<your-project-name>/

        Set according to whatever registry you are pushing to.

        Defaults to "", which is probably not what you want :)
        """,
                                  config=True)

    docker_registry_host = Unicode("",
                                   help="""
        Docker registry host.
        """,
                                   config=True)

    docker_auth_host = Unicode(help="""
        Docker authentication host.
        """,
                               config=True)

    @default('docker_auth_host')
    def _docker_auth_host_default(self):
        return self.docker_registry_host

    docker_token_url = Unicode("",
                               help="""
        URL from which to request a Docker registry authentication token.
        """,
                               config=True)

    build_memory_limit = ByteSpecification(0,
                                           help="""
        Max amount of memory allocated for each image build process.

        0 sets no limit.

        This is used as both the memory limit & request for the pod
        that is spawned to do the building, even though the pod itself
        will not be using that much memory since the docker building is
        happening outside the pod. However, it makes kubernetes aware of
        the resources being used, and lets it schedule more intelligently.
        """,
                                           config=True)

    # TODO: Factor this out!
    github_auth_token = Unicode(None,
                                allow_none=True,
                                help="""
        GitHub OAuth token to use for talking to the GitHub API.

        Might get throttled otherwise!
        """,
                                config=True)

    debug = Bool(False,
                 help="""
        Turn on debugging.
        """,
                 config=True)

    build_docker_host = Unicode("/var/run/docker.sock",
                                config=True,
                                help="""
        The docker URL repo2docker should use to build the images.

        Currently, only paths are supported, and they are expected to be available on
        all the hosts.
        """)

    @validate('build_docker_host')
    def docker_build_host_validate(self, proposal):
        parts = urlparse(proposal.value)
        if parts.scheme != 'unix' or parts.netloc != '':
            raise TraitError(
                "Only unix domain sockets on same node are supported for build_docker_host"
            )
        return proposal.value

    hub_api_token = Unicode(
        help="""API token for talking to the JupyterHub API""",
        config=True,
    )
    hub_url = Unicode(
        help="""
        The base URL of the JupyterHub instance where users will run.

        e.g. https://hub.mybinder.org/
        """,
        config=True,
    )

    @validate('hub_url')
    def _add_slash(self, proposal):
        """trait validator to ensure hub_url ends with a trailing slash"""
        if proposal.value is not None and not proposal.value.endswith('/'):
            return proposal.value + '/'
        return proposal.value

    build_namespace = Unicode('default',
                              help="""
        Kubernetes namespace to spawn build pods in.

        Note that the docker_push_secret must refer to a secret in this namespace.
        """,
                              config=True)

    builder_image_spec = Unicode('jupyter/repo2docker:687788f',
                                 help="""
        The builder image to be used for doing builds
        """,
                                 config=True)

    build_node_selector = Dict({},
                               config=True,
                               help="""
        Select the node where the build pod runs.
        """)

    repo_providers = Dict(
        {
            'gh': GitHubRepoProvider,
            'gist': GistRepoProvider,
            'git': GitRepoProvider,
            'gl': GitLabRepoProvider,
        },
        config=True,
        help="""
        Repo Providers to register and try, keyed by their spec prefix (e.g. 'gh').
        """)
    concurrent_build_limit = Integer(
        32, config=True, help="""The number of concurrent builds to allow.""")

    # FIXME: Come up with a better name for it?
    builder_required = Bool(True,
                            config=True,
                            help="""
        Whether a working build infrastructure is required.

        Build infrastructure means the kubernetes cluster plus docker. Set this to False
        to let binderhub run without one, which is useful for pure HTML/CSS/JS local development.
        """)

    tornado_settings = Dict(config=True,
                            help="""
        Additional settings to pass through to tornado.

        Can include things like additional headers, etc.
        """)

    @staticmethod
    def add_url_prefix(prefix, handlers):
        """add a url prefix to handlers"""
        for i, tup in enumerate(handlers):
            lis = list(tup)
            lis[0] = url_path_join(prefix, tup[0])
            handlers[i] = tuple(lis)
        return handlers

    def init_pycurl(self):
        try:
            AsyncHTTPClient.configure(
                "tornado.curl_httpclient.CurlAsyncHTTPClient")
        except ImportError as e:
            self.log.debug(
                "Could not load pycurl: %s\npycurl is recommended if you have a large number of users.",
                e)
        # set max verbosity of curl_httpclient at INFO
        # because debug-logging from curl_httpclient
        # includes every full request and response
        if self.log_level < logging.INFO:
            curl_log = logging.getLogger('tornado.curl_httpclient')
            curl_log.setLevel(logging.INFO)

    def initialize(self, *args, **kwargs):
        """Load configuration settings."""
        super().initialize(*args, **kwargs)
        self.load_config_file(self.config_file)
        # hook up tornado logging
        if self.debug:
            self.log_level = logging.DEBUG
        tornado.options.options.logging = logging.getLevelName(self.log_level)
        tornado.log.enable_pretty_logging()
        self.log = tornado.log.app_log

        self.init_pycurl()

        # initialize kubernetes config
        if self.builder_required:
            try:
                kubernetes.config.load_incluster_config()
            except kubernetes.config.ConfigException:
                kubernetes.config.load_kube_config()
            self.tornado_settings[
                "kubernetes_client"] = kubernetes.client.CoreV1Api()

        # times 2 for log + build threads
        self.build_pool = ThreadPoolExecutor(self.concurrent_build_limit * 2)

        jinja_options = dict(autoescape=True)
        jinja_env = Environment(loader=FileSystemLoader(TEMPLATE_PATH),
                                **jinja_options)
        if self.use_registry and self.builder_required:
            registry = DockerRegistry(self.docker_auth_host,
                                      self.docker_token_url,
                                      self.docker_registry_host)
        else:
            registry = None

        self.launcher = Launcher(
            parent=self,
            hub_url=self.hub_url,
            hub_api_token=self.hub_api_token,
        )

        self.tornado_settings.update({
            "docker_push_secret":
            self.docker_push_secret,
            "docker_image_prefix":
            self.docker_image_prefix,
            "static_path":
            os.path.join(os.path.dirname(__file__), "static"),
            "github_auth_token":
            self.github_auth_token,
            "debug":
            self.debug,
            'hub_url':
            self.hub_url,
            'hub_api_token':
            self.hub_api_token,
            'launcher':
            self.launcher,
            'appendix':
            self.appendix,
            "build_namespace":
            self.build_namespace,
            "builder_image_spec":
            self.builder_image_spec,
            'build_node_selector':
            self.build_node_selector,
            'build_pool':
            self.build_pool,
            'per_repo_quota':
            self.per_repo_quota,
            'repo_providers':
            self.repo_providers,
            'use_registry':
            self.use_registry,
            'registry':
            registry,
            'traitlets_config':
            self.config,
            'google_analytics_code':
            self.google_analytics_code,
            'google_analytics_domain':
            self.google_analytics_domain,
            'jinja2_env':
            jinja_env,
            'build_memory_limit':
            self.build_memory_limit,
            'build_docker_host':
            self.build_docker_host,
            'base_url':
            self.base_url,
            'static_url_prefix':
            url_path_join(self.base_url, 'static/'),
        })

        handlers = [
            (r'/metrics', MetricsHandler),
            (r"/build/([^/]+)/(.+)", BuildHandler),
            (r"/v2/([^/]+)/(.+)", ParameterizedMainHandler),
            (r"/repo/([^/]+)/([^/]+)(/.*)?", LegacyRedirectHandler),
            # for backward-compatible mybinder.org badge URLs
            # /assets/images/badge.svg
            (r'/assets/(images/badge\.svg)', tornado.web.StaticFileHandler, {
                'path': self.tornado_settings['static_path']
            }),
            # /badge.svg
            (r'/(badge\.svg)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            # /favicon_XXX.ico
            (r'/(favicon\_fail\.ico)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            (r'/(favicon\_success\.ico)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            (r'/(favicon\_building\.ico)', tornado.web.StaticFileHandler, {
                'path':
                os.path.join(self.tornado_settings['static_path'], 'images')
            }),
            (r'/', MainHandler),
            (r'.*', Custom404),
        ]
        handlers = self.add_url_prefix(self.base_url, handlers)
        self.tornado_app = tornado.web.Application(handlers,
                                                   **self.tornado_settings)

    def stop(self):
        self.http_server.stop()
        self.build_pool.shutdown()

    def start(self, run_loop=True):
        self.log.info("BinderHub starting on port %i", self.port)
        self.http_server = self.tornado_app.listen(self.port)
        if run_loop:
            tornado.ioloop.IOLoop.current().start()
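
# A hedged example of a binderhub_config.py exercising the traits defined
# above. All values are illustrative placeholders, not a working deployment.
c = get_config()  # injected by the config loader when the file is executed
c.BinderHub.port = 8585
c.BinderHub.use_registry = True
c.BinderHub.docker_image_prefix = "gcr.io/my-project/"  # hypothetical registry path
c.BinderHub.hub_url = "https://hub.example.org/"        # hypothetical JupyterHub
c.BinderHub.hub_api_token = "<api-token-placeholder>"
c.BinderHub.per_repo_quota = 10
c.BinderHub.build_memory_limit = "2G"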
class GCSContentsManager(ContentsManager):

    bucket_name = Unicode(config=True)

    bucket_notebooks_path = Unicode(config=True)

    project = Unicode(config=True)

    @default('checkpoints_class')
    def _checkpoints_class_default(self):
        return GCSCheckpointManager

    @default('bucket_notebooks_path')
    def _bucket_notebooks_path_default(self):
        return ''

    def __init__(self, **kwargs):
        super(GCSContentsManager, self).__init__(**kwargs)
        self._bucket = None

    @property
    def bucket(self):
        if not self._bucket:
            if self.project:
                storage_client = storage.Client(project=self.project)
            else:
                storage_client = storage.Client()
            self._bucket = storage_client.get_bucket(self.bucket_name)
        return self._bucket

    def _normalize_path(self, path):
        path = path or ''
        return path.strip('/')

    def _gcs_path(self, normalized_path):
        if not self.bucket_notebooks_path:
            return normalized_path
        if not normalized_path:
            return self.bucket_notebooks_path
        return posixpath.join(self.bucket_notebooks_path, normalized_path)

    def is_hidden(self, path):
        try:
            path = self._normalize_path(path)
            return posixpath.basename(path).startswith('.')
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(500, 'Internal server error: {}'.format(str(ex)))

    def file_exists(self, path):
        try:
            path = self._normalize_path(path)
            if not path:
                return False
            blob_name = self._gcs_path(path)
            blob = self.bucket.get_blob(blob_name)
            return blob is not None
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(500, 'Internal server error: {}'.format(str(ex)))

    def dir_exists(self, path):
        try:
            path = self._normalize_path(path)
            if not path:
                return self.bucket.exists()

            dir_gcs_path = self._gcs_path(path)
            if self.bucket.get_blob(dir_gcs_path):
                # There is a regular file matching the specified directory.
                #
                # We could have both a blob matching a directory path
                # and other blobs under that path. In that case, we cannot
                # treat the path as both a directory and a regular file,
                # so we treat the regular file as overriding the logical
                # directory.
                return False

            dir_contents = self.bucket.list_blobs(prefix=dir_gcs_path)
            for _ in dir_contents:
                return True

            return False
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(500, 'Internal server error: {}'.format(str(ex)))

    def _blob_model(self, normalized_path, blob, content=True):
        blob_obj = {}
        blob_obj['path'] = normalized_path
        blob_obj['name'] = posixpath.basename(normalized_path)
        blob_obj['last_modified'] = blob.updated
        blob_obj['created'] = blob.time_created
        blob_obj['writable'] = True
        blob_obj['type'] = 'notebook' if blob_obj['name'].endswith(
            '.ipynb') else 'file'
        if not content:
            blob_obj['mimetype'] = None
            blob_obj['format'] = None
            blob_obj['content'] = None
            return blob_obj

        content_str = blob.download_as_string() if content else None
        if blob_obj['type'] == 'notebook':
            blob_obj['mimetype'] = None
            blob_obj['format'] = 'json'
            blob_obj['content'] = nbformat.reads(content_str, as_version=4)
        elif blob.content_type.startswith('text/'):
            blob_obj['mimetype'] = 'text/plain'
            blob_obj['format'] = 'text'
            blob_obj['content'] = content_str.decode(utf8_encoding)
        else:
            blob_obj['mimetype'] = 'application/octet-stream'
            blob_obj['format'] = 'base64'
            blob_obj['content'] = base64.b64encode(content_str)

        return blob_obj

    def _empty_dir_model(self, normalized_path, content=True):
        dir_obj = {}
        dir_obj['path'] = normalized_path
        dir_obj['name'] = posixpath.basename(normalized_path)
        dir_obj['type'] = 'directory'
        dir_obj['mimetype'] = None
        dir_obj['writable'] = True
        dir_obj['last_modified'] = self.bucket.time_created
        dir_obj['created'] = self.bucket.time_created
        dir_obj['format'] = None
        dir_obj['content'] = None
        if content:
            dir_obj['format'] = 'json'
            dir_obj['content'] = []
        return dir_obj

    def _list_dir(self, normalized_path, content=True):
        dir_obj = self._empty_dir_model(normalized_path, content=content)
        if not content:
            return dir_obj

        # We have to convert a list of GCS blobs, which may include multiple
        # entries corresponding to a single sub-directory, into a list of immediate
        # directory contents with no duplicates.
        #
        # To do that, we keep a dictionary of immediate children, and then convert
        # that dictionary into a list once it is fully populated.
        children = {}

        def add_child(name, model, override_existing=False):
            """Add the given child model (for either a regular file or directory), to

      the list of children for the current directory model being built.

      It is possible that we will encounter a GCS blob corresponding to a
      regular file after we encounter blobs indicating that name should be a
      directory. For example, if we have the following blobs:
          some/dir/path/
          some/dir/path/with/child
          some/dir/path
      ... then the first two entries tell us that 'path' is a subdirectory of
      'dir', but the third one tells us that it is a regular file.

      In this case, we treat the regular file as shadowing the directory. The
      'override_existing' keyword argument handles that by letting the caller
      specify that the child being added should override (i.e. hide) any
      pre-existing children with the same name.
      """
            if self.is_hidden(model['path']) and not self.allow_hidden:
                return
            if (name in children) and not override_existing:
                return
            children[name] = model

        dir_gcs_path = self._gcs_path(normalized_path)
        for b in self.bucket.list_blobs(prefix=dir_gcs_path):
            # For each nested blob, identify the corresponding immediate child
            # of the directory, and then add that child to the directory model.
            prefix_len = len(dir_gcs_path) + 1 if dir_gcs_path else 0
            suffix = b.name[prefix_len:]
            if suffix:  # Ignore the place-holder blob for the directory itself
                first_slash = suffix.find('/')
                if first_slash < 0:
                    child_path = posixpath.join(normalized_path, suffix)
                    add_child(suffix,
                              self._blob_model(child_path, b, content=False),
                              override_existing=True)
                else:
                    subdir = suffix[0:first_slash]
                    if subdir:
                        child_path = posixpath.join(normalized_path, subdir)
                        add_child(
                            subdir,
                            self._empty_dir_model(child_path, content=False))

        for child in children:
            dir_obj['content'].append(children[child])

        return dir_obj

    def get(self, path, content=True, type=None, format=None):
        try:
            path = self._normalize_path(path)
            if not type and self.dir_exists(path):
                type = 'directory'
            if type == 'directory':
                return self._list_dir(path, content=content)

            gcs_path = self._gcs_path(path)
            blob = self.bucket.get_blob(gcs_path)
            return self._blob_model(path, blob, content=content)
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(500, 'Internal server error: {}'.format(str(ex)))

    def _mkdir(self, normalized_path):
        gcs_path = self._gcs_path(normalized_path) + '/'
        blob = self.bucket.blob(gcs_path)
        blob.upload_from_string('', content_type='text/plain')
        return self._empty_dir_model(normalized_path, content=False)

    def save(self, model, path):
        try:
            self.run_pre_save_hook(model=model, path=path)

            normalized_path = self._normalize_path(path)
            if model['type'] == 'directory':
                return self._mkdir(normalized_path)

            gcs_path = self._gcs_path(normalized_path)
            blob = self.bucket.get_blob(gcs_path)
            if not blob:
                blob = self.bucket.blob(gcs_path)

            content_type = model.get('mimetype', None)
            if not content_type:
                content_type, _ = mimetypes.guess_type(normalized_path)
            contents = model['content']
            if model['type'] == 'notebook':
                contents = nbformat.writes(nbformat.from_dict(contents))
            elif model['type'] == 'file' and model['format'] == 'base64':
                b64_bytes = contents.encode('ascii')
                contents = base64.decodebytes(b64_bytes)

            # GCS doesn't allow specifying the key version, so drop it if present
            if blob.kms_key_name:
                blob._properties['kmsKeyName'] = re.split(
                    r'/cryptoKeyVersions/\d+$', blob.kms_key_name)[0]

            blob.upload_from_string(contents, content_type=content_type)
            return self.get(path, type=model['type'], content=False)
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(500, 'Internal server error: {}'.format(str(ex)))

    def delete_file(self, path):
        try:
            normalized_path = self._normalize_path(path)
            gcs_path = self._gcs_path(normalized_path)
            blob = self.bucket.get_blob(gcs_path)
            if blob:
                # The path corresponds to a regular file; just delete it.
                blob.delete()
                return None

            # The path (possibly) corresponds to a directory. Delete
            # every file underneath it.
            for blob in self.bucket.list_blobs(prefix=gcs_path):
                blob.delete()

            return None
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(500, 'Internal server error: {}'.format(str(ex)))

    def rename_file(self, old_path, new_path):
        try:
            old_gcs_path = self._gcs_path(self._normalize_path(old_path))
            new_gcs_path = self._gcs_path(self._normalize_path(new_path))
            blob = self.bucket.get_blob(old_gcs_path)
            if blob:
                # The path corresponds to a regular file.
                self.bucket.rename_blob(blob, new_gcs_path)
                return None

            # The path (possibly) corresponds to a directory. Rename
            # every file underneath it.
            for b in self.bucket.list_blobs(prefix=old_gcs_path):
                self.bucket.rename_blob(
                    b, b.name.replace(old_gcs_path, new_gcs_path))
            return None
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(500, 'Internal server error: {}'.format(str(ex)))
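
# A hedged sketch of enabling the manager above in a jupyter_notebook_config.py.
# The import path "gcs_contents.GCSContentsManager" is hypothetical; point it at
# the module that actually defines the class. Bucket and project names are
# placeholders.
c = get_config()  # injected by the config loader when the file is executed
c.NotebookApp.contents_manager_class = "gcs_contents.GCSContentsManager"
c.GCSContentsManager.project = "my-gcp-project"
c.GCSContentsManager.bucket_name = "my-notebooks"
c.GCSContentsManager.bucket_notebooks_path = "notebooks"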
class CombinedContentsManager(ContentsManager):
    root_dir = Unicode(config=True)

    @default('checkpoints')
    def _default_checkpoints(self):
        return CombinedCheckpointsManager(self._content_managers)

    def __init__(self, **kwargs):
        print('Creating the combined contents manager...')
        super(CombinedContentsManager, self).__init__(**kwargs)

        file_cm = FileContentsManager(**kwargs)
        file_cm.checkpoints = GenericFileCheckpoints(
            **file_cm.checkpoints_kwargs)
        gcs_cm = GCSContentsManager(**kwargs)
        self._content_managers = {
            'Local Disk': file_cm,
            'GCS': gcs_cm,
        }

    def _content_manager_for_path(self, path):
        path = path or ''
        path = path.strip('/')
        for path_prefix in self._content_managers:
            if path == path_prefix or path.startswith(path_prefix + '/'):
                relative_path = path[len(path_prefix):]
                return self._content_managers[
                    path_prefix], relative_path, path_prefix
        if '/' in path:
            path_parts = path.split('/', 1)
            return None, path_parts[1], path_parts[0]
        return None, path, ''

    def is_hidden(self, path):
        try:
            cm, relative_path, unused_path_prefix = self._content_manager_for_path(
                path)
            if not cm:
                return False
            return cm.is_hidden(relative_path)
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(
                500,
                'Internal server error: [{}] {}'.format(type(ex), str(ex)))

    def file_exists(self, path):
        try:
            cm, relative_path, unused_path_prefix = self._content_manager_for_path(
                path)
            if not cm:
                return False
            return cm.file_exists(relative_path)
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(
                500,
                'Internal server error: [{}] {}'.format(type(ex), str(ex)))

    def dir_exists(self, path):
        if path in ['', '/']:
            return True
        try:
            cm, relative_path, unused_path_prefix = self._content_manager_for_path(
                path)
            if not cm:
                return False
            return cm.dir_exists(relative_path)
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(
                500,
                'Internal server error: [{}] {}'.format(type(ex), str(ex)))

    def _make_model_relative(self, model, path_prefix):
        if 'path' in model:
            model['path'] = '{}/{}'.format(path_prefix, model['path'])
        if model.get('type', None) == 'directory':
            self._make_children_relative(model, path_prefix)

    def _make_children_relative(self, model, path_prefix):
        children = model.get('content', None)
        if children:
            for child in children:
                self._make_model_relative(child, path_prefix)

    def get(self, path, content=True, type=None, format=None):
        if path in ['', '/']:
            dir_obj = {}
            dir_obj['path'] = ''
            dir_obj['name'] = ''
            dir_obj['type'] = 'directory'
            dir_obj['mimetype'] = None
            dir_obj['writable'] = False
            dir_obj['format'] = 'json'
            contents = []
            for path_prefix in self._content_managers:
                child_obj = self._content_managers[path_prefix].get(
                    '', content=False)
                child_obj['path'] = path_prefix
                child_obj['name'] = path_prefix
                child_obj['writable'] = False
                contents.append(child_obj)
            dir_obj['content'] = contents
            dir_obj['created'] = contents[0]['created']
            dir_obj['last_modified'] = contents[0]['last_modified']
            return dir_obj
        try:
            cm, relative_path, path_prefix = self._content_manager_for_path(
                path)
            if not cm:
                raise HTTPError(
                    404, 'No content manager defined for "{}"'.format(path))
            model = cm.get(relative_path,
                           content=content,
                           type=type,
                           format=format)
            self._make_model_relative(model, path_prefix)
            return model
        except HTTPError as err:
            raise err
        except Exception as ex:
            # The `type` parameter of get() shadows the builtin here, so use
            # ex.__class__ to report the exception class.
            raise HTTPError(
                500,
                'Internal server error: [{}] {}'.format(ex.__class__, str(ex)))

    def save(self, model, path):
        if path in ['', '/']:
            raise HTTPError(403, 'The top-level directory is read-only')
        try:
            self.run_pre_save_hook(model=model, path=path)

            cm, relative_path, path_prefix = self._content_manager_for_path(
                path)
            if (relative_path in ['', '/']) or (path_prefix in ['', '/']):
                raise HTTPError(
                    403, 'The top-level directory contents are read-only')
            if not cm:
                raise HTTPError(
                    404, 'No content manager defined for "{}"'.format(path))

            if 'path' in model:
                model['path'] = relative_path

            model = cm.save(model, relative_path)
            if 'path' in model:
                model['path'] = path
            return model
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(
                500,
                'Internal server error: [{}] {}'.format(type(ex), str(ex)))

    def delete_file(self, path):
        if path in ['', '/']:
            raise HTTPError(403, 'The top-level directory is read-only')
        try:
            cm, relative_path, path_prefix = self._content_manager_for_path(
                path)
            if (relative_path in ['', '/']) or (path_prefix in ['', '/']):
                raise HTTPError(
                    403, 'The top-level directory contents are read-only')
            if not cm:
                raise HTTPError(
                    404, 'No content manager defined for "{}"'.format(path))
            return cm.delete_file(relative_path)
        except OSError as err:
            # The built-in file contents manager will not attempt to wrap permissions
            # errors when deleting files if they occur while trying to move the
            # to-be-deleted file to the trash, because the underlying send2trash
            # library does not set the errno attribute of the raised OSError.
            #
            # To work around this we explicitly catch such errors, check if they
            # start with the magic text "Permission denied", and then wrap them
            # in an HTTPError.
            if str(err).startswith('Permission denied'):
                raise HTTPError(403, str(err))
            raise HTTPError(
                500,
                'Internal server error: [{}] {}'.format(err.errno, str(err)))
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(
                500,
                'Internal server error: [{}] {}'.format(type(ex), str(ex)))

    def rename_file(self, old_path, new_path):
        if (old_path in ['', '/']) or (new_path in ['', '/']):
            raise HTTPError(403, 'The top-level directory is read-only')
        try:
            old_cm, old_relative_path, old_prefix = self._content_manager_for_path(
                old_path)
            if (old_relative_path in ['', '/']) or (old_prefix in ['', '/']):
                raise HTTPError(
                    403, 'The top-level directory contents are read-only')
            if not old_cm:
                raise HTTPError(
                    404,
                    'No content manager defined for "{}"'.format(old_path))

            new_cm, new_relative_path, new_prefix = self._content_manager_for_path(
                new_path)
            if (new_relative_path in ['', '/']) or (new_prefix in ['', '/']):
                raise HTTPError(
                    403, 'The top-level directory contents are read-only')
            if not new_cm:
                raise HTTPError(
                    404,
                    'No content manager defined for "{}"'.format(new_path))

            if old_cm != new_cm:
                raise HTTPError(400, 'Unsupported rename across file systems')
            return old_cm.rename_file(old_relative_path, new_relative_path)
        except HTTPError as err:
            raise err
        except Exception as ex:
            raise HTTPError(
                500,
                'Internal server error: [{}] {}'.format(type(ex), str(ex)))
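
# A hedged illustration of the first-segment routing performed by
# _content_manager_for_path above: "GCS/..." is routed to the GCS backend and
# "Local Disk/..." to the file backend; any other prefix matches no manager.
# Assumes this module's imports (google-cloud-storage, notebook) are available.
cm = CombinedContentsManager()
backend, rel, prefix = cm._content_manager_for_path("GCS/reports/analysis.ipynb")
assert prefix == "GCS" and rel == "/reports/analysis.ipynb"
# `backend` is the GCSContentsManager; the leading slash on `rel` is stripped
# by the backend's own _normalize_path.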
Beispiel #31
0
class Authenticator(LoggingConfigurable):
    """Base class for implementing an authentication provider for JupyterHub"""

    db = Any()

    enable_auth_state = Bool(
        False,
        config=True,
        help="""Enable persisting auth_state (if available).

        auth_state will be encrypted and stored in the Hub's database.
        This can include things like authentication tokens, etc.
        to be passed to Spawners as environment variables.

        Encrypting auth_state requires the cryptography package.

        Additionally, the JUPYTERHUB_CRYPT_KEY environment variable must
        contain one (or more, separated by ;) 32B encryption keys.
        These can be either base64 or hex-encoded.

        If encryption is unavailable, auth_state cannot be persisted.

        New in JupyterHub 0.8
        """,
    )

    auth_refresh_age = Integer(
        300,
        config=True,
        help="""The max age (in seconds) of authentication info
        before forcing a refresh of user auth info.

        Refreshing auth info allows, e.g. requesting/re-validating auth tokens.

        See :meth:`.refresh_user` for what happens when user auth info is refreshed
        (nothing by default).
        """,
    )

    refresh_pre_spawn = Bool(
        False,
        config=True,
        help="""Force refresh of auth prior to spawn.

        This forces :meth:`.refresh_user` to be called prior to launching
        a server, to ensure that auth state is up-to-date.

        This can be important when e.g. auth tokens that may have expired
        are passed to the spawner via environment variables from auth_state.

        If refresh_user cannot refresh the user auth data,
        launch will fail until the user logs in again.
        """,
    )

    admin_users = Set(
        help="""
        Set of users that will have admin rights on this JupyterHub.

        Admin users have extra privileges:
         - Use the admin panel to see list of users logged in
         - Add / remove users in some authenticators
         - Restart / halt the hub
         - Start / stop users' single-user servers
         - Can access each individual user's single-user server (if configured)

        Admin access should be treated the same way root access is.

        Defaults to an empty set, in which case no user has admin access.
        """
    ).tag(config=True)

    whitelist = Set(
        help="""
        Whitelist of usernames that are allowed to log in.

        Use this with supported authenticators to restrict which users can log in. This is an
        additional whitelist that further restricts users, beyond whatever restrictions the
        authenticator has in place.

        If empty, does not perform any additional restriction.
        """
    ).tag(config=True)

    blacklist = Set(
        help="""
        Blacklist of usernames that are not allowed to log in.

        Use this with supported authenticators to restrict which users cannot log in. This is an
        additional blacklist that further restricts users, beyond whatever restrictions the
        authenticator has in place.

        If empty, does not perform any additional restriction.

        .. versionadded:: 0.9
        """
    ).tag(config=True)

    @observe('whitelist')
    def _check_whitelist(self, change):
        short_names = [name for name in change['new'] if len(name) <= 1]
        if short_names:
            sorted_names = sorted(short_names)
            single = ''.join(sorted_names)
            string_set_typo = "set('%s')" % single
            self.log.warning(
                "whitelist contains single-character names: %s; did you mean set([%r]) instead of %s?",
                sorted_names[:8],
                single,
                string_set_typo,
            )

    custom_html = Unicode(
        help="""
        HTML form to be overridden by authenticators if they want a custom authentication form.

        Defaults to an empty string, which shows the default username/password form.
        """
    )

    login_service = Unicode(
        help="""
        Name of the login service that this authenticator is using to authenticate users.

        Example: GitHub, MediaWiki, Google, etc.

        Setting this value replaces the login form with a "Login with <login_service>" button.

        Any authenticator that redirects to an external service (e.g. using OAuth) should set this.
        """
    )

    username_pattern = Unicode(
        help="""
        Regular expression pattern that all valid usernames must match.

        If a username does not match the pattern specified here, authentication will not be attempted.

        If not set, allow any username.
        """
    ).tag(config=True)

    @observe('username_pattern')
    def _username_pattern_changed(self, change):
        if not change['new']:
            self.username_regex = None
            return
        self.username_regex = re.compile(change['new'])

    username_regex = Any(
        help="""
        Compiled regex kept in sync with `username_pattern`
        """
    )

    def validate_username(self, username):
        """Validate a normalized username

        Return True if username is valid, False otherwise.
        """
        if '/' in username:
            # / is not allowed in usernames
            return False
        if not username:
            # empty usernames are not allowed
            return False
        if not self.username_regex:
            return True
        return bool(self.username_regex.match(username))

    username_map = Dict(
        help="""Dictionary mapping authenticator usernames to JupyterHub users.

        Primarily used to normalize OAuth user names to local users.
        """
    ).tag(config=True)

    delete_invalid_users = Bool(
        False,
        help="""Delete any users from the database that do not pass validation

        When JupyterHub starts, `.add_user` will be called
        on each user in the database to verify that all users are still valid.

        If `delete_invalid_users` is True,
        any users that do not pass validation will be deleted from the database.
        Use this if users might be deleted from an external system,
        such as local user accounts.

        If False (default), invalid users remain in the Hub's database
        and a warning will be issued.
        This is the default to avoid data loss due to config changes.
        """,
    )

    post_auth_hook = Any(
        config=True,
        help="""
        An optional hook function that you can implement to do some
        bootstrapping work during authentication. For example, loading user account
        details from an external system.

        This function is called after the user has passed all authentication checks
        and is ready to successfully authenticate. This function must return the
        authentication dict regardless of changes to it.

        This may be a coroutine.

        .. versionadded:: 1.0

        Example::

            import os, pwd
            def my_hook(authenticator, handler, authentication):
                user_data = pwd.getpwnam(authentication['name'])
                spawn_data = {
                    'pw_data': user_data,
                    'gid_list': os.getgrouplist(authentication['name'], user_data.pw_gid)
                }

                if authentication['auth_state'] is None:
                    authentication['auth_state'] = {}
                authentication['auth_state']['spawn_data'] = spawn_data

                return authentication

            c.Authenticator.post_auth_hook = my_hook

        """,
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for method_name in (
            'check_whitelist',
            'check_blacklist',
            'check_group_whitelist',
        ):
            original_method = getattr(self, method_name, None)
            if original_method is None:
                # no such method (check_group_whitelist is optional)
                continue
            signature = inspect.signature(original_method)
            if 'authentication' not in signature.parameters:
                # adapt to pre-1.0 signature for compatibility
                warnings.warn(
                    """
                    {0}.{1} does not support the authentication argument,
                    added in JupyterHub 1.0.

                    It should have the signature:

                    def {1}(self, username, authentication=None):
                        ...

                    Adapting for compatibility.
                    """.format(
                        self.__class__.__name__, method_name
                    ),
                    DeprecationWarning,
                )

                def wrapped_method(username, authentication=None,
                                   original_method=original_method, **kwargs):
                    # Bind original_method at definition time: a plain closure
                    # is late-bound, so every wrapper would otherwise call the
                    # last method visited by this loop.
                    return original_method(username, **kwargs)

                setattr(self, method_name, wrapped_method)
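
    # Illustrative sketch: a pre-1.0 subclass like the one below would trigger
    # the DeprecationWarning and signature adaptation above. The class and the
    # load_allowed_names() helper are hypothetical, not JupyterHub code.
    #
    #     class LegacyAuthenticator(Authenticator):
    #         def check_whitelist(self, username):
    #             # old one-argument signature; adapted by __init__ above
    #             return username in load_allowed_names()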

    async def run_post_auth_hook(self, handler, authentication):
        """
        Run the post_auth_hook if defined

        .. versionadded:: 1.0

        Args:
            handler (tornado.web.RequestHandler): the current request handler
            authentication (dict): User authentication data dictionary. Contains the
                username ('name'), admin status ('admin'), and auth state dictionary ('auth_state').
        Returns:
            Authentication (dict):
                The hook must always return the authentication dict
        """
        if self.post_auth_hook is not None:
            authentication = await maybe_future(
                self.post_auth_hook(self, handler, authentication)
            )
        return authentication

    def normalize_username(self, username):
        """Normalize the given username and return it

        Override in subclasses if usernames need different normalization rules.

        The default lowercases the username and applies `username_map` if it is
        set.
        """
        username = username.lower()
        username = self.username_map.get(username, username)
        return username
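
    # Illustrative sketch of an override (hypothetical subclass): strip a mail
    # domain before applying the default lowercasing and `username_map`.
    #
    #     class EmailAuthenticator(Authenticator):
    #         def normalize_username(self, username):
    #             username = username.split('@')[0]
    #             return super().normalize_username(username)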

    def check_whitelist(self, username, authentication=None):
        """Check if a username is allowed to authenticate based on whitelist configuration

        Return True if username is allowed, False otherwise.
        No whitelist means any username is allowed.

        Names are normalized *before* being checked against the whitelist.

        .. versionchanged:: 1.0
            Signature updated to accept authentication data and any future changes
        """
        if not self.whitelist:
            # No whitelist means any name is allowed
            return True
        return username in self.whitelist

    def check_blacklist(self, username, authentication=None):
        """Check if a username is blocked to authenticate based on blacklist configuration

        Return True if username is allowed, False otherwise.
        No blacklist means any username is allowed.

        Names are normalized *before* being checked against the blacklist.

        .. versionadded:: 0.9

        .. versionchanged:: 1.0
            Signature updated to accept authentication data as second argument
        """
        if not self.blacklist:
            # No blacklist means any name is allowed
            return True
        return username not in self.blacklist
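
    # Illustrative config sketch: both checks read set-like traits that are
    # normally populated in jupyterhub_config.py (usernames are made up).
    #
    #     c.Authenticator.whitelist = {'alice', 'bob'}
    #     c.Authenticator.blacklist = {'mallory'}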

    async def get_authenticated_user(self, handler, data):
        """Authenticate the user who is attempting to log in

        Returns user dict if successful, None otherwise.

        This calls `authenticate`, which should be overridden in subclasses,
        normalizes the username if any normalization should be done,
        and then validates the name in the whitelist.

        This is the outer API for authenticating a user.
        Subclasses should not override this method.

        The various stages can be overridden separately:
         - `authenticate` turns formdata into a username
         - `normalize_username` normalizes the username
         - `check_whitelist` checks against the user whitelist

        .. versionchanged:: 0.8
            return dict instead of username
        """
        authenticated = await maybe_future(self.authenticate(handler, data))
        if authenticated is None:
            return
        if isinstance(authenticated, dict):
            if 'name' not in authenticated:
                raise ValueError("user missing a name: %r" % authenticated)
        else:
            authenticated = {'name': authenticated}
        authenticated.setdefault('auth_state', None)
        # Leave the default as None, but reevaluate later post-whitelist
        authenticated.setdefault('admin', None)

        # normalize the username
        authenticated['name'] = username = self.normalize_username(
            authenticated['name']
        )
        if not self.validate_username(username):
            self.log.warning("Disallowing invalid username %r.", username)
            return

        blacklist_pass = await maybe_future(
            self.check_blacklist(username, authenticated)
        )
        whitelist_pass = await maybe_future(
            self.check_whitelist(username, authenticated)
        )

        if not blacklist_pass:
            self.log.warning(
                "User %r in blacklist. Stopping authentication.", username
            )
            return

        if whitelist_pass:
            if authenticated['admin'] is None:
                authenticated['admin'] = await maybe_future(
                    self.is_admin(handler, authenticated)
                )

            authenticated = await self.run_post_auth_hook(handler, authenticated)

            return authenticated
        else:
            self.log.warning("User %r not in whitelist.", username)
            return

    async def refresh_user(self, user, handler=None):
        """Refresh auth data for a given user

        Allows refreshing or invalidating auth data.

        Only override if your authenticator needs
        to refresh its data about users once in a while.

        .. versionadded:: 1.0

        Args:
            user (User): the user to refresh
            handler (tornado.web.RequestHandler or None): the current request handler
        Returns:
            auth_data (bool or dict):
                Return **True** if auth data for the user is up-to-date
                and no updates are required.

                Return **False** if the user's auth data has expired,
                and they should be required to login again.

                Return a **dict** of auth data if some values should be updated.
                This dict should have the same structure as that returned
                by :meth:`.authenticate()` when it returns a dict.
                Any fields present will refresh the value for the user.
                Any fields not present will be left unchanged.
                This can include updating `.admin` or `.auth_state` fields.
        """
        return True
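
    # Illustrative sketch of an override, assuming the authenticator stored a
    # token in auth_state; token_still_valid() is a hypothetical helper.
    #
    #     async def refresh_user(self, user, handler=None):
    #         auth_state = await user.get_auth_state()
    #         if auth_state and token_still_valid(auth_state['access_token']):
    #             return True   # auth data is still up-to-date
    #         return False      # force the user to log in again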

    def is_admin(self, handler, authentication):
        """Authentication helper to determine a user's admin status.

        .. versionadded:: 1.0

        Args:
            handler (tornado.web.RequestHandler): the current request handler
            authentication: The authentication dict generated by `authenticate`.
        Returns:
            admin_status (Bool or None):
                The admin status of the user, or None if it could not be
                determined or should not change.
        """
        return True if authentication['name'] in self.admin_users else None
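
    # Illustrative sketch: an override might derive admin status from data the
    # authenticator placed in auth_state (the 'groups' key is made up).
    #
    #     def is_admin(self, handler, authentication):
    #         groups = (authentication.get('auth_state') or {}).get('groups', [])
    #         if 'hub-admins' in groups:
    #             return True
    #         return super().is_admin(handler, authentication)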

    async def authenticate(self, handler, data):
        """Authenticate a user with login form data

        This must be a coroutine.

        It must return the username on successful authentication,
        and return None on failed authentication.

        Checking the whitelist is handled separately by the caller.

        .. versionchanged:: 0.8
            Allow `authenticate` to return a dict containing auth_state.

        Args:
            handler (tornado.web.RequestHandler): the current request handler
            data (dict): The formdata of the login form.
                         The default form has 'username' and 'password' fields.
        Returns:
            user (str or dict or None):
                The username of the authenticated user,
                or None if authentication failed.

                The Authenticator may return a dict instead, which MUST have a
                key `name` holding the username, and MAY have two optional keys
                set: `auth_state`, a dictionary of auth state that will be
                persisted; and `admin`, the admin setting value for the user.
        """

    def pre_spawn_start(self, user, spawner):
        """Hook called before spawning a user's server

        Can be used to do auth-related startup, e.g. opening PAM sessions.
        """

    def post_spawn_stop(self, user, spawner):
        """Hook called after stopping a user container

        Can be used to do auth-related cleanup, e.g. closing PAM sessions.
        """

    def add_user(self, user):
        """Hook called when a user is added to JupyterHub

        This is called:
         - When a user first authenticates
         - When the hub restarts, for all users.

        This method may be a coroutine.

        By default, this just adds the user to the whitelist.

        Subclasses may do more extensive things, such as adding actual unix users,
        but they should call super to ensure the whitelist is updated.

        Note that this should be idempotent, since it is called whenever the hub restarts
        for all users.

        Args:
            user (User): The User wrapper object
        """
        if not self.validate_username(user.name):
            raise ValueError("Invalid username: %s" % user.name)
        if self.whitelist:
            self.whitelist.add(user.name)

    def delete_user(self, user):
        """Hook called when a user is deleted

        Removes the user from the whitelist.
        Subclasses should call super to ensure the whitelist is updated.

        Args:
            user (User): The User wrapper object
        """
        self.whitelist.discard(user.name)

    auto_login = Bool(
        False,
        config=True,
        help="""Automatically begin the login process

        rather than starting with a "Login with..." link at `/hub/login`

        To work, `.login_url()` must give a URL other than the default `/hub/login`,
        such as an oauth handler or another automatic login handler,
        registered with `.get_handlers()`.

        .. versionadded:: 0.8
        """,
    )

    def login_url(self, base_url):
        """Override this when registering a custom login handler

        Generally used by authenticators that do not use simple form-based authentication.

        The subclass overriding this is responsible for making sure there is a handler
        available to handle the URL returned from this method, using the `get_handlers`
        method.

        Args:
            base_url (str): the base URL of the Hub (e.g. /hub/)

        Returns:
            str: The login URL, e.g. '/hub/login'
        """
        return url_path_join(base_url, 'login')

    def logout_url(self, base_url):
        """Override when registering a custom logout handler

        The subclass overriding this is responsible for making sure there is a handler
        available to handle the URL returned from this method, using the `get_handlers`
        method.

        Args:
            base_url (str): the base URL of the Hub (e.g. /hub/)

        Returns:
            str: The logout URL, e.g. '/hub/logout'
        """
        return url_path_join(base_url, 'logout')

    def get_handlers(self, app):
        """Return any custom handlers the authenticator needs to register

        Used in conjunction with `login_url` and `logout_url`.

        Args:
            app (JupyterHub Application):
                the application object, in case it needs to be accessed for info.
        Returns:
            handlers (list):
                list of ``('/url', Handler)`` tuples passed to tornado.
                The Hub prefix is added to any URLs.
        """
        return [('/login', LoginHandler)]
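
    # Illustrative sketch: an authenticator with its own login flow typically
    # overrides auto_login, login_url, and get_handlers together; the
    # TokenLoginHandler class is hypothetical.
    #
    #     class TokenAuthenticator(Authenticator):
    #         auto_login = True
    #
    #         def login_url(self, base_url):
    #             return url_path_join(base_url, 'token_login')
    #
    #         def get_handlers(self, app):
    #             return [('/token_login', TokenLoginHandler)]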
Beispiel #32
0
    def _update_custom_traits(self):
        # Serialize the spherical GTF order (ml values grouped by l, per frame)
        # as a JSON-backed trait for the frontend.
        sgto = self.groupby('frame').apply(
            lambda x: x.groupby('l').apply(lambda y: y['ml'].values))
        sgto = Unicode(sgto.to_json(orient='values')).tag(sync=True)
        return {'sphericalgtforder_ml': sgto}
class InstallNBExtensionApp(BaseExtensionApp):
    """Entry point for installing notebook extensions"""
    description = """Install Jupyter notebook extensions
    
    Usage
    
        jupyter nbextension install path|url [--user|--sys-prefix]
    
    This copies a file or a folder into the Jupyter nbextensions directory.
    If a URL is given, it will be downloaded.
    If an archive is given, it will be extracted into nbextensions.
    If the requested files are already up to date, no action is taken
    unless --overwrite is specified.
    """

    examples = """
    jupyter nbextension install /path/to/myextension
    """
    aliases = aliases
    flags = flags

    overwrite = Bool(False,
                     config=True,
                     help="Force overwrite of existing files")
    symlink = Bool(False,
                   config=True,
                   help="Create symlinks instead of copying files")

    prefix = Unicode('', config=True, help="Installation prefix")
    nbextensions_dir = Unicode(
        '',
        config=True,
        help="Full path to nbextensions dir (probably use prefix or user)")
    destination = Unicode('',
                          config=True,
                          help="Destination for the copy or symlink")

    def _config_file_name_default(self):
        """The default config file name."""
        return 'jupyter_notebook_config'

    def install_extensions(self):
        """Perform the installation of nbextension(s)"""
        if len(self.extra_args) > 1:
            raise ValueError(
                "Only one nbextension allowed at a time. "
                "Call multiple times to install multiple extensions.")

        if self.python:
            install = install_nbextension_python
            kwargs = {}
        else:
            install = install_nbextension
            kwargs = {'destination': self.destination}

        full_dests = install(self.extra_args[0],
                             overwrite=self.overwrite,
                             symlink=self.symlink,
                             user=self.user,
                             sys_prefix=self.sys_prefix,
                             prefix=self.prefix,
                             nbextensions_dir=self.nbextensions_dir,
                             logger=self.log,
                             **kwargs)

        if full_dests:
            self.log.info(
                u"\nTo initialize this nbextension in the browser every time"
                " the notebook (or other app) loads:\n\n"
                "      jupyter nbextension enable {}{}{}{}\n".format(
                    self.extra_args[0] if self.python else "<the entry point>",
                    " --user" if self.user else "",
                    " --py" if self.python else "",
                    " --sys-prefix" if self.sys_prefix else ""))

    def start(self):
        """Perform the App's function as configured"""
        if not self.extra_args:
            sys.exit('Please specify an nbextension to install')
        else:
            try:
                self.install_extensions()
            except ArgumentConflict as e:
                sys.exit(str(e))
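
# Illustrative usage sketch: the app is normally run from the command line,
# but any traitlets Application can also be launched programmatically
# (the extension path below is made up):
#
#     jupyter nbextension install /path/to/myextension --user --overwrite
#
#     from notebook.nbextensions import InstallNBExtensionApp
#     InstallNBExtensionApp.launch_instance(argv=['/path/to/myextension', '--user'])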