Пример #1
0
class Viewer(widgets.DOMWidget):
    """
    Generic object for viewing and labeling Candidate objects in their rendered Contexts.
    """
    # Traitlets synced with the Javascript front-end ('viewer' module)
    _view_name = Unicode('ViewerView').tag(sync=True)
    _view_module = Unicode('viewer').tag(sync=True)
    cids = List().tag(sync=True)
    html = Unicode('<h3>Error!</h3>').tag(sync=True)
    _labels_serialized = Unicode().tag(sync=True)
    _selected_cid = Int().tag(sync=True)

    def __init__(self,
                 candidates,
                 session,
                 gold=None,
                 n_per_page=3,
                 height=225,
                 annotator_name=None):
        """
        Initializes a Viewer.

        The Viewer uses the keyword argument annotator_name to define a AnnotatorLabelKey with that name.

        :param candidates: A Python container of Candidates (e.g., not a CandidateSet, but candidate_set.candidates)
        :param session: The SnorkelSession for the database backend
        :param gold: Optional, Python container of Candidates that are known to have positive labels
        :param n_per_page: Optional, number of Contexts to display per page
        :param height: Optional, the height in pixels of the Viewer
        :param annotator_name: Name of the human using the Viewer, for saving their work. Defaults to system username.
        """
        super(Viewer, self).__init__()
        self.session = session

        # By default, use the username as annotator name
        name = annotator_name if annotator_name is not None else getpass.getuser()

        # Sets up the AnnotationKey to use, creating it on first use
        self.annotator = self.session.query(GoldLabelKey).filter(
            GoldLabelKey.name == name).first()
        if self.annotator is None:
            self.annotator = GoldLabelKey(name=name)
            session.add(self.annotator)
            session.commit()

        # Viewer display configs
        self.n_per_page = n_per_page
        self.height = height

        # Note that the candidates are not necessarily commited to the DB, so they *may not have* non-null ids
        # Hence, we index by their position in this list
        # We get the sorted candidates and all contexts required, either from unary or binary candidates
        # NOTE: `gold` defaults to None rather than a mutable [] so that the
        # default list is never shared across Viewer instances.
        self.gold = [] if gold is None else list(gold)
        self.candidates = sorted(list(candidates),
                                 key=lambda c: c[0].char_start)
        self.contexts = list(
            set(c[0].get_parent() for c in self.candidates + self.gold))

        # If committed, sort contexts by id; uncommitted contexts may have
        # null ids (unorderable), in which case the arbitrary order is kept
        try:
            self.contexts = sorted(self.contexts, key=lambda c: c.id)
        except Exception:
            pass

        # Loads existing annotations
        self.annotations = [None] * len(self.candidates)
        self.annotations_stable = [None] * len(self.candidates)
        init_labels_serialized = []
        for i, candidate in enumerate(self.candidates):

            # First look for the annotation in the primary annotations table
            existing_annotation = self.session.query(GoldLabel) \
                .filter(GoldLabel.key == self.annotator) \
                .filter(GoldLabel.candidate == candidate) \
                .first()
            if existing_annotation is not None:
                self.annotations[i] = existing_annotation
                if existing_annotation.value == 1:
                    value_string = 'true'
                elif existing_annotation.value == -1:
                    value_string = 'false'
                else:
                    raise ValueError(
                        str(existing_annotation) +
                        ' has value not in {1, -1}, which Viewer does not support.'
                    )
                init_labels_serialized.append(str(i) + '~~' + value_string)

                # If the annotator label is in the main table, also get its stable version
                context_stable_ids = '~~'.join(
                    [c.stable_id for c in candidate.get_contexts()])
                existing_annotation_stable = self.session.query(StableLabel) \
                    .filter(StableLabel.context_stable_ids == context_stable_ids) \
                    .filter(StableLabel.annotator_name == name).one_or_none()

                # If stable version is not available, create it here
                # NOTE: This is for versioning issues, should be removed?
                if existing_annotation_stable is None:
                    existing_annotation_stable = StableLabel(
                        context_stable_ids=context_stable_ids,
                        annotator_name=self.annotator.name,
                        split=candidate.split,
                        value=existing_annotation.value)
                    self.session.add(existing_annotation_stable)
                    self.session.commit()

                self.annotations_stable[i] = existing_annotation_stable

        self._labels_serialized = ','.join(init_labels_serialized)

        # Configures message handler
        self.on_msg(self.handle_label_event)

        # display js, construct html and pass on to widget model
        self.render()

    def _tag_span(self, html, cids, gold=False):
        """
        Create the span around a segment of the context associated with one or more candidates / gold annotations.

        :param html: The already-rendered inner HTML of the segment
        :param cids: Candidate indices covering this segment; they become CSS classes
        :param gold: If True, also tag the span with the 'gold-annotation' class
        """
        classes = ['candidate'] if len(cids) > 0 else []
        classes += ['gold-annotation'] if gold else []
        classes += list(map(str, cids))
        return u'<span class="{classes}">{html}</span>'.format(
            classes=' '.join(classes), html=html)

    def _tag_context(self, context, candidates, gold):
        """Given the raw context, tag the spans using the generic _tag_span method.

        Must be implemented by subclasses for the specific Context type.
        """
        raise NotImplementedError()

    def render(self):
        """Renders viewer pane"""
        cids = []

        # Iterate over pages of contexts
        pid = 0
        pages = []
        N = len(self.contexts)
        for i in range(0, N, self.n_per_page):
            page_cids = []
            lis = []
            for j in range(i, min(N, i + self.n_per_page)):
                context = self.contexts[j]

                # Get the candidates in this context
                candidates = [
                    c for c in self.candidates if c[0].get_parent() == context
                ]
                gold = [g for g in self.gold if g.context_id == context.id]

                # Construct the <li> and page view elements
                li_data = self._tag_context(context, candidates, gold)
                lis.append(LI_HTML.format(data=li_data, context_id=context.id))
                page_cids.append(
                    [self.candidates.index(c) for c in candidates])

            # Assemble the page; only the first page is initially visible
            pages.append(
                PAGE_HTML.format(
                    pid=pid,
                    data=''.join(lis),
                    etc=' style="display: block;"' if i == 0 else ''))
            cids.append(page_cids)
            pid += 1

        # Render in primary Viewer template; close file handles promptly
        self.cids = cids
        with open(HOME + '/viewer/viewer.html') as template_file:
            self.html = template_file.read() % (self.height, ''.join(pages))
        with open(HOME + '/viewer/viewer.js') as js_file:
            display(Javascript(js_file.read()))

    def _get_labels(self):
        """
        De-serialize labels from Javascript widget, map to internal candidate id, and return as list of tuples
        """
        LABEL_MAP = {'true': 1, 'false': -1}
        labels = [
            x.split('~~') for x in self._labels_serialized.split(',')
            if len(x) > 0
        ]
        vals = [(int(cid), LABEL_MAP.get(l, 0)) for cid, l in labels]
        return vals

    def handle_label_event(self, _, content, buffers):
        """
        Handles label event by persisting new label.

        Expects messages of the form {'event': 'set_label'|'delete_label',
        'cid': <candidate index>, 'value': <True|False>}.

        :raises ValueError: if a set_label value is not True or False
        """
        if content.get('event', '') == 'set_label':
            cid = content.get('cid', None)
            value = content.get('value', None)
            if value is True:
                value = 1
            elif value is False:
                value = -1
            else:
                raise ValueError('Unexpected label returned from widget: ' +
                                 str(value) +
                                 '. Expected values are True and False.')

            # If label already exists, just update value (in both AnnotatorLabel and StableLabel)
            if self.annotations[cid] is not None:
                if self.annotations[cid].value != value:
                    self.annotations[cid].value = value
                    self.annotations_stable[cid].value = value
                    self.session.commit()

            # Otherwise, create a AnnotatorLabel *and a StableLabel*
            else:
                candidate = self.candidates[cid]

                # Create AnnotatorLabel
                self.annotations[cid] = GoldLabel(key=self.annotator,
                                                  candidate=candidate,
                                                  value=value)
                self.session.add(self.annotations[cid])

                # Create StableLabel
                context_stable_ids = '~~'.join(
                    [c.stable_id for c in candidate.get_contexts()])
                self.annotations_stable[cid] = StableLabel(
                    context_stable_ids=context_stable_ids,
                    annotator_name=self.annotator.name,
                    value=value,
                    split=candidate.split)
                self.session.add(self.annotations_stable[cid])
                self.session.commit()

        elif content.get('event', '') == 'delete_label':
            cid = content.get('cid', None)
            self.session.delete(self.annotations[cid])
            self.annotations[cid] = None
            self.session.delete(self.annotations_stable[cid])
            self.annotations_stable[cid] = None
            self.session.commit()

    def get_selected(self):
        """Return the Candidate currently selected in the front-end view."""
        return self.candidates[self._selected_cid]
Пример #2
0
class Widget(LoggingConfigurable):
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    # Single global callback invoked whenever any Widget is constructed (or None)
    _widget_construction_callback = None
    # When True, traits tagged read_only reject assignment via __setattr__
    _read_only_enabled = True
    # Registry of all live widgets, keyed by model (comm) id
    widgets = {}
    widget_types = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when a comm is opened from the front-end.

        Instantiates the widget class named in the message, bound to the
        given comm. The constructed instance registers itself in
        Widget.widgets, so the local variable is intentionally unused."""
        widget_class = import_item(str(msg['content']['data']['widget_class']))
        widget = widget_class(comm=comm)


    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_module = Unicode(None, allow_none=True, help="""A requirejs module name
        in which to find _model_name. If empty, look in the global registry.""")
    _model_name = Unicode('WidgetModel', help="""Name of the backbone model
        registered in the front-end to create and sync this widget with.""")
    _view_module = Unicode(help="""A requirejs module in which to find _view_name.
        If empty, look in the global registry.""", sync=True)
    _view_name = Unicode(None, allow_none=True, help="""Default view registered in the front-end
        to use to represent the widget.""", sync=True)
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    msg_throttle = Int(3, sync=True, help="""Maximum number of msgs the
        front-end can send before receiving an idle msg from the back-end.""")

    version = Int(0, sync=True, help="""Widget's version""")
    keys = List()
    def _keys_default(self):
        # Default: every trait tagged sync=True participates in state sync
        return [name for name in self.traits(sync=True)]

    _property_lock = Dict()
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            args = dict(target_name='ipython.widget',
                        data={'model_name': self._model_name,
                              'model_module': self._model_module})
            if self._model_id is not None:
                args['comm_id'] = self._model_id
            self.comm = Comm(**args)

    def _comm_changed(self, name, new):
        """Called when the comm is changed."""
        if new is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

        # first update
        self.send_state()

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def __setattr__(self, name, value):
        """Overload of HasTraits.__setattr__to handle read-only-ness of widget
        attributes """
        if (self._read_only_enabled and self.has_trait(name) and
            self.trait_metadata(name, 'read_only')):
            raise TraitError('Widget attribute "%s" is read-only.' % name)
        else:
            super(Widget, self).__setattr__(name, value)


    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        buffer_keys, buffers = [], []
        # Iterate over a snapshot: we pop memoryview entries from `state`
        # while looping, and mutating a dict during iteration raises
        # RuntimeError on Python 3.
        for k, v in list(state.items()):
            if isinstance(v, memoryview):
                state.pop(k)
                buffers.append(v)
                buffer_keys.append(k)
        msg = {'method': 'update', 'state': state, 'buffers': buffer_keys}
        self._send(msg, buffers=buffers)

    def get_state(self, key=None):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns
        -------
        state : dict of states
        metadata : dict
            metadata for each field: {key: metadata}
        """
        # collections.Iterable was removed in Python 3.10; the abc module is
        # imported locally so this works on all Python 3 versions.
        from collections.abc import Iterable
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, Iterable):
            keys = key
        else:
            raise ValueError("key must be a string, an iterable of keys, or None")
        state = {}
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            state[k] = to_json(getattr(self, k), self)
        return state

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._allow_write(),\
             self._lock_property(**sync_data),\
             self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    setattr(self, name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget."""
        super(Widget, self).add_traits(**traits)
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                 self.keys.append(name)
                 self.send_state(name)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def _allow_write(self):
        """Temporarily disable read-only enforcement within the context."""
        if self._read_only_enabled is False:
            yield
        else:
            try:
                self._read_only_enabled = False
                yield
            finally:
                self._read_only_enabled = True

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync is True:
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                self._holding_sync = False
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if (key in self._property_lock
            and to_json(value, self) == self._property_lock[key]):
            return False
        elif self._holding_sync:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
        if method == 'backbone':
            if 'sync_data' in data:
                # get binary buffers too
                sync_data = data['sync_data']
                for i,k in enumerate(data.get('buffer_keys', [])):
                    sync_data[k] = msg['buffers'][i]
                self.set_state(sync_data) # handles all methods

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _notify_trait(self, name, old_value, new_value):
        """Called when a property has been changed."""
        # Trigger default traitlet callback machinery.  This allows any user
        # registered validation to be processed prior to allowing the widget
        # machinery to handle the state.
        LoggingConfigurable._notify_trait(self, name, old_value, new_value)

        # Send the state after the user registered callbacks for trait changes
        # have all fired (allows for user to validate values).
        if self.comm is not None and name in self.keys:
            # Make sure this isn't information that the front-end just sent us.
            if self._should_send_property(name, new_value):
                # Send new state to front-end
                self.send_state(key=name)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Convert a trait value to json."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        # Show view.
        if self._view_name is not None:
            self._send({"method": "display"})
            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        self.comm.send(data=msg, buffers=buffers)
Пример #3
0
class CypherMagic(Magics, Configurable):
    """Runs Cypher statement on a database, specified by a connect string.

    Provides the %%cypher magic."""

    auto_limit = Int(defaults.auto_limit,
                     config=True,
                     help="""
        Automatically limit the size of the returned result sets
    """)
    style = Unicode(defaults.style,
                    config=True,
                    help="""
        Set the table printing style to any of prettytable's defined styles
        (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)
    """)
    short_errors = Bool(defaults.short_errors,
                        config=True,
                        help="""
        Don't display the full traceback on Neo4j errors
    """)
    data_contents = Bool(defaults.data_contents,
                         config=True,
                         help="""
        Bring extra data to render the results as a graph
    """)
    display_limit = Int(defaults.display_limit,
                        config=True,
                        help="""
        Automatically limit the number of rows displayed
        (full result set is still stored)
    """)
    auto_pandas = Bool(defaults.auto_pandas,
                       config=True,
                       help="""
        Return Pandas DataFrame instead of regular result sets
    """)
    auto_html = Bool(defaults.auto_html,
                     config=True,
                     help="""
        Return a D3 representation of the graph instead of regular result sets
    """)
    auto_networkx = Bool(defaults.auto_networkx,
                         config=True,
                         help="""
        Return Networkx MultiDiGraph instead of regular result sets
    """)
    rest = Bool(defaults.rest,
                config=True,
                help="""
        Return full REST representations of objects inside the result sets
    """)
    feedback = Bool(defaults.feedback,
                    config=True,
                    help="""
        Print number of rows affected
    """)
    uri = Unicode(defaults.uri,
                  config=True,
                  help="""
        Default database URL if none is defined inline
    """)

    def __init__(self, shell):
        Configurable.__init__(self, config=shell.config)
        Magics.__init__(self, shell=shell)
        # Add ourself to the list of module configurable via %config
        self.shell.configurables.append(self)
        # Matches the leading run of characters legal in a Cypher identifier;
        # used by _persist_dataframe to sanitize table names.
        self._legal_cypher_identifier = re.compile(r'^[A-Za-z0-9#_$]+')

    @needs_local_scope
    @line_magic('cypher')
    @cell_magic('cypher')
    def execute(self, line, cell='', local_ns=None):
        """Runs Cypher statement against a Neo4j graph database, specified by
        a connect string.

        If no database connection has been established, first word
        should be a connection string, or the user@host name
        of an established connection. Otherwise, http://localhost:7474/db/data
        will be assumed.

        Examples::

          %%cypher https://me:mypw@myhost:7474/db/data
          START n=node(*) RETURN n

          %%cypher me@myhost
          START n=node(*) RETURN n

          %%cypher
          START n=node(*) RETURN n

        Connect string syntax examples:

          http://localhost:7474/db/data
          https://me:mypw@localhost:7474/db/data

        """
        # local_ns defaults to None rather than a mutable {} so the default
        # dict is never shared across calls; semantics are unchanged.
        if local_ns is None:
            local_ns = {}
        # save globals and locals so they can be referenced in bind vars
        user_ns = self.shell.user_ns
        user_ns.update(local_ns)
        parsed = parse("""{0}\n{1}""".format(line, cell), self)
        conn = Connection.get(parsed['as'] or parsed['uri'])
        # [:1] (not [0]) so an empty statement yields [] instead of raising
        first_word = parsed['cypher'].split(None, 1)[:1]
        if first_word and first_word[0].lower() == 'persist':
            return self._persist_dataframe(parsed['cypher'], conn, user_ns)
        try:
            result = run(parsed['cypher'], user_ns, self, conn)
            return result
        except StatusException as e:
            if self.short_errors:
                print(e)
            else:
                raise
        # NOTE: returns None when short_errors swallows a StatusException

    def _persist_dataframe(self, raw, conn, user_ns):
        """Implement `%%cypher persist <DataFrameName>`: write the named
        DataFrame/Series from the user namespace to the connection's SQL engine.

        :param raw: The raw statement text, expected as 'persist <name>'
        :param conn: The active Connection (provides conn.session.engine)
        :param user_ns: The IPython user namespace to resolve <name> in
        :raises ImportError: if pandas is not installed
        :raises SyntaxError: if the statement is not exactly two words
        :raises TypeError: if <name> is not a DataFrame or Series
        """
        if not DataFrame:
            raise ImportError("Must `pip install pandas` to use DataFrames")
        pieces = raw.split()
        if len(pieces) != 2:
            raise SyntaxError(
                "Format: %%cypher [connection] persist <DataFrameName>")
        frame_name = pieces[1].strip(';')
        # SECURITY: eval of user-typed text against the interactive namespace.
        # This is standard for IPython magics (the user already has full code
        # execution), but must never be reached with untrusted input.
        frame = eval(frame_name, user_ns)
        if not isinstance(frame, DataFrame) and not isinstance(frame, Series):
            raise TypeError('%s is not a Pandas DataFrame or Series' %
                            frame_name)
        table_name = frame_name.lower()
        table_name = self._legal_cypher_identifier.search(table_name).group(0)
        frame.to_sql(table_name, conn.session.engine)
        return 'Persisted %s' % table_name
Пример #4
0
class NamespacedResourceReflector(LoggingConfigurable):
    """
    Base class for keeping a local up-to-date copy of a set of kubernetes resources.

    Must be subclassed once per kind of resource that needs watching.

    The local copy lives in ``self.resources`` and is refreshed by a
    background daemon thread started from ``__init__`` (via ``start()``).
    Reads of ``self.resources`` rely on CPython dict operations being atomic
    rather than on explicit locking — see ``_watch_and_update``.
    """
    labels = Dict({},
                  config=True,
                  help="""
        Labels to reflect onto local cache
        """)

    fields = Dict({},
                  config=True,
                  help="""
        Fields to restrict the reflected objects
        """)

    namespace = Unicode(None,
                        allow_none=True,
                        help="""
        Namespace to watch for resources in
        """)

    resources = Dict({},
                     help="""
        Dictionary of resource names to the appropriate resource objects.

        This can be accessed across threads safely.
        """)

    kind = Unicode('resource',
                   help="""
        Human readable name for kind of object we're watching for.

        Used for diagnostic messages.
        """)

    list_method_name = Unicode("",
                               help="""
        Name of function (on apigroup respresented by `api_group_name`) that is to be called to list resources.

        This will be passed a namespace & a label selector. You most likely want something
        of the form list_namespaced_<resource> - for example, `list_namespaced_pod` will
        give you a PodReflector.

        This must be set by a subclass.
        """)

    api_group_name = Unicode('CoreV1Api',
                             help="""
        Name of class that represents the apigroup on which `list_method_name` is to be found.

        Defaults to CoreV1Api, which has everything in the 'core' API group. If you want to watch Ingresses,
        for example, you would have to use ExtensionsV1beta1Api
        """)

    request_timeout = Int(60,
                          config=True,
                          help="""
        Network timeout for kubernetes watch.

        Trigger watch reconnect when a given request is taking too long,
        which can indicate network issues.
        """)

    timeout_seconds = Int(10,
                          config=True,
                          help="""
        Timeout for kubernetes watch.

        Trigger watch reconnect when no watch event has been received.
        This will cause a full reload of the currently existing resources
        from the API server.
        """)

    on_failure = Any(
        help="""Function to be called when the reflector gives up.""")

    def __init__(self, *args, **kwargs):
        """Set up the kubernetes client, build selectors, and start watching.

        Note that ``start()`` performs a blocking full listing before the
        watch thread is spawned, so the local cache is populated by the time
        the constructor returns.
        """
        super().__init__(*args, **kwargs)
        # Load kubernetes config here, since this is a Singleton and
        # so this __init__ will be run way before anything else gets run.
        try:
            config.load_incluster_config()
        except config.ConfigException:
            config.load_kube_config()
        self.api = shared_client(self.api_group_name)

        # FIXME: Protect against malicious labels?
        self.label_selector = ','.join(
            ['{}={}'.format(k, v) for k, v in self.labels.items()])
        self.field_selector = ','.join(
            ['{}={}'.format(k, v) for k, v in self.fields.items()])

        # Resolved (in the watch thread) once the first listing completes.
        self.first_load_future = Future()
        # Set to signal the watch thread that it should stop.
        self._stop_event = threading.Event()

        self.start()

    def __del__(self):
        # Best-effort cleanup: signal the watch thread to stop when this
        # object is garbage collected.
        self.stop()

    def _list_and_update(self):
        """
        Update current list of resources by doing a full fetch.

        Overwrites all current resource info.
        """
        initial_resources = getattr(self.api, self.list_method_name)(
            self.namespace,
            label_selector=self.label_selector,
            field_selector=self.field_selector,
            _request_timeout=self.request_timeout,
        )
        # This is an atomic operation on the dictionary!
        self.resources = {p.metadata.name: p for p in initial_resources.items}
        # return the resource version so we can hook up a watch
        return initial_resources.metadata.resource_version

    def _watch_and_update(self):
        """
        Keeps the current list of resources up-to-date

        This method is to be run not on the main thread!

        We first fetch the list of current resources, and store that. Then we
        register to be notified of changes to those resources, and keep our
        local store up-to-date based on these notifications.

        We also perform exponential backoff, giving up after we hit 32s
        wait time. This should protect against network connections dropping
        and intermittent unavailability of the api-server. Every time we
        recover from an exception we also do a full fetch, to pick up
        changes that might've been missed in the time we were not doing
        a watch.

        Note that we're playing a bit with fire here, by updating a dictionary
        in this thread while it is probably being read in another thread
        without using locks! However, dictionary access itself is atomic,
        and as long as we don't try to mutate them (do a 'fetch / modify /
        update' cycle on them), we should be ok!
        """
        # Human-readable description of the active selectors, for logging.
        selectors = []
        log_name = ""
        if self.label_selector:
            selectors.append("label selector=%r" % self.label_selector)
        if self.field_selector:
            selectors.append("field selector=%r" % self.field_selector)
        log_selector = ', '.join(selectors)

        # Current backoff delay in seconds; doubled on each consecutive error
        # below and reset to 0.1 whenever an event is successfully received.
        cur_delay = 0.1

        self.log.info(
            "watching for %s with %s in namespace %s",
            self.kind,
            log_selector,
            self.namespace,
        )
        while True:
            self.log.debug("Connecting %s watcher", self.kind)
            w = watch.Watch()
            try:
                # Full refetch first; the returned resource_version lets the
                # watch pick up exactly where the listing left off.
                resource_version = self._list_and_update()
                if not self.first_load_future.done():
                    # signal that we've loaded our initial data
                    self.first_load_future.set_result(None)
                watch_args = {
                    'namespace': self.namespace,
                    'label_selector': self.label_selector,
                    'field_selector': self.field_selector,
                    'resource_version': resource_version,
                }
                if self.request_timeout:
                    # set network receive timeout
                    watch_args['_request_timeout'] = self.request_timeout
                if self.timeout_seconds:
                    # set watch timeout
                    watch_args['timeout_seconds'] = self.timeout_seconds
                # in case of timeout_seconds, the w.stream just exits (no exception thrown)
                # -> we stop the watcher and start a new one
                for ev in w.stream(getattr(self.api, self.list_method_name),
                                   **watch_args):
                    cur_delay = 0.1
                    resource = ev['object']
                    if ev['type'] == 'DELETED':
                        # This is an atomic delete operation on the dictionary!
                        self.resources.pop(resource.metadata.name, None)
                    else:
                        # This is an atomic operation on the dictionary!
                        self.resources[resource.metadata.name] = resource
                    if self._stop_event.is_set():
                        self.log.info("%s watcher stopped", self.kind)
                        break
            except ReadTimeoutError:
                # network read time out, just continue and restart the watch
                # this could be due to a network problem or just low activity
                self.log.warning("Read timeout watching %s, reconnecting",
                                 self.kind)
                continue
            except Exception:
                cur_delay = cur_delay * 2
                if cur_delay > 30:
                    # Exponential backoff exhausted: give up and notify the
                    # optional on_failure callback.
                    self.log.exception(
                        "Watching resources never recovered, giving up")
                    if self.on_failure:
                        self.on_failure()
                    return
                self.log.exception(
                    "Error when watching resources, retrying in %ss",
                    cur_delay)
                time.sleep(cur_delay)
                continue
            else:
                # no events on watch, reconnect
                self.log.debug("%s watcher timeout", self.kind)
            finally:
                # Always close the current watch; exit the loop entirely if a
                # stop was requested (the continue/except paths skip this break
                # only when _stop_event is not set).
                w.stop()
                if self._stop_event.is_set():
                    self.log.info("%s watcher stopped", self.kind)
                    break
        self.log.warning("%s watcher finished", self.kind)

    def start(self):
        """
        Start the reflection process!

        We'll do a blocking read of all resources first, so that we don't
        race with any operations that are checking the state of the pod
        store - such as polls. This should be called only once at the
        start of program initialization (when the singleton is being created),
        and not afterwards!
        """
        if hasattr(self, 'watch_thread'):
            raise ValueError(
                'Thread watching for resources is already running')

        self._list_and_update()
        self.watch_thread = threading.Thread(target=self._watch_and_update)
        # If the watch_thread is only thread left alive, exit app
        self.watch_thread.daemon = True
        self.watch_thread.start()

    def stop(self):
        """Signal the watch thread to stop at its next opportunity."""
        self._stop_event.set()

    def stopped(self):
        """Return True if ``stop()`` has been called."""
        return self._stop_event.is_set()
Пример #5
0
class _Selection(DescriptionWidget, ValueWidget, CoreWidget):
    """Base class for Selection widgets

    ``options`` can be specified as a list of values or a list of (label, value)
    tuples. The labels are the strings that will be displayed in the UI,
    representing the actual Python choices, and should be unique.
    If labels are not specified, they are generated from the values.

    When programmatically setting the value, a reverse lookup is performed
    among the options to check that the value is valid. The reverse lookup uses
    the equality operator by default, but another predicate may be provided via
    the ``equals`` keyword argument. For example, when dealing with numpy arrays,
    one may set equals=np.array_equal.
    """

    # index is the only trait synced with the frontend; value and label are
    # kept consistent with it through the observer methods below.
    value = Any(None, help="Selected value", allow_none=True)
    label = Unicode(None, help="Selected label", allow_none=True)
    index = Int(None, help="Selected index", allow_none=True).tag(sync=True)

    options = Any(
        (),
        help=
        """Iterable of values or (label, value) pairs that the user can select.

    The labels are the strings that will be displayed in the UI, representing the
    actual Python choices, and should be unique.
    """)

    # Normalized options as a tuple of (label, value) pairs; kept in sync by
    # _validate_options / __init__.
    _options_full = None

    # This being read-only means that it cannot be changed by the user.
    _options_labels = TypedTuple(
        trait=Unicode(), read_only=True,
        help="The labels for the options.").tag(sync=True)

    disabled = Bool(help="Enable or disable user changes").tag(sync=True)

    def __init__(self, *args, **kwargs):
        """Create the selection widget.

        Accepts an ``equals`` keyword: the predicate used for the reverse
        lookup of ``value`` among the options (defaults to ``==``).
        """
        self.equals = kwargs.pop('equals', lambda x, y: x == y)
        # We have to make the basic options bookkeeping consistent
        # so we don't have errors the first time validators run
        self._initializing_traits_ = True
        kwargs['options'] = _exhaust_iterable(kwargs.get('options', ()))
        self._options_full = _make_options(kwargs['options'])
        self._propagate_options(None)

        # Select the first item by default, if we can
        if 'index' not in kwargs and 'value' not in kwargs and 'label' not in kwargs:
            options = self._options_full
            nonempty = (len(options) > 0)
            kwargs['index'] = 0 if nonempty else None
            kwargs['label'], kwargs['value'] = options[0] if nonempty else (
                None, None)

        super().__init__(*args, **kwargs)
        self._initializing_traits_ = False

    @validate('options')
    def _validate_options(self, proposal):
        """Normalize proposed options and cache their full (label, value) form."""
        # if an iterator is provided, exhaust it
        proposal.value = _exhaust_iterable(proposal.value)
        # throws an error if there is a problem converting to full form
        self._options_full = _make_options(proposal.value)
        return proposal.value

    @observe('options')
    def _propagate_options(self, change):
        "Set the values and labels, and select the first option if we aren't initializing"
        options = self._options_full
        self.set_trait('_options_labels', tuple(i[0] for i in options))
        self._options_values = tuple(i[1] for i in options)
        if self._initializing_traits_ is not True:
            if len(options) > 0:
                if self.index == 0:
                    # Explicitly trigger the observers to pick up the new value and
                    # label. Just setting the value would not trigger the observers
                    # since traitlets thinks the value hasn't changed.
                    self._notify_trait('index', 0, 0)
                else:
                    self.index = 0
            else:
                self.index = None

    @validate('index')
    def _validate_index(self, proposal):
        """Accept None or an index within the current options range."""
        if proposal.value is None or 0 <= proposal.value < len(
                self._options_labels):
            return proposal.value
        else:
            raise TraitError('Invalid selection: index out of bounds')

    @observe('index')
    def _propagate_index(self, change):
        "Propagate changes in index to the value and label properties"
        label = self._options_labels[
            change.new] if change.new is not None else None
        value = self._options_values[
            change.new] if change.new is not None else None
        # NOTE(review): identity (`is not`) rather than equality checks —
        # appears intentional to avoid redundant trait writes, since label and
        # value here come from the same cached tuples; confirm before changing.
        if self.label is not label:
            self.label = label
        if self.value is not value:
            self.value = value

    @validate('value')
    def _validate_value(self, proposal):
        """Reverse-lookup the proposed value among the options via ``equals``."""
        value = proposal.value
        try:
            return findvalue(self._options_values, value,
                             self.equals) if value is not None else None
        except ValueError:
            raise TraitError('Invalid selection: value not found')

    @observe('value')
    def _propagate_value(self, change):
        """Propagate changes in value to the index property."""
        if change.new is None:
            index = None
        elif self.index is not None and self._options_values[
                self.index] == change.new:
            # Current index already points at this value; keep it.
            index = self.index
        else:
            index = self._options_values.index(change.new)
        if self.index != index:
            self.index = index

    @validate('label')
    def _validate_label(self, proposal):
        """Accept None or a label present among the current options."""
        if (proposal.value is not None) and (proposal.value
                                             not in self._options_labels):
            raise TraitError('Invalid selection: label not found')
        return proposal.value

    @observe('label')
    def _propagate_label(self, change):
        """Propagate changes in label to the index property."""
        if change.new is None:
            index = None
        elif self.index is not None and self._options_labels[
                self.index] == change.new:
            # Current index already points at this label; keep it.
            index = self.index
        else:
            index = self._options_labels.index(change.new)
        if self.index != index:
            self.index = index

    def _repr_keys(self):
        """Yield the trait names to show in repr, adding 'options' manually."""
        keys = super()._repr_keys()
        # Include options manually, as it isn't marked as synced:
        for key in sorted(chain(keys, ('options', ))):
            if key == 'index' and self.index == 0:
                # Index 0 is default when there are options
                continue
            yield key
Пример #6
0
class Builder(widgets.DOMWidget):
    """A Python wrapper for the Escher metabolic map.

    This map will also show data on reactions, metabolites, or genes.

    The Builder is a Jupyter widget that can be viewed in a Jupyter notebook or
    in Jupyter Lab. It can also be used to create a standalone HTML file for
    the map with the save_html() function.

    Maps are downloaded from the Escher website if found by name.

    :param int height:

        The height of the Escher Jupyter widget in pixels.

    :param str map_name:

        A string specifying a map to be downloaded from the Escher website.

    :param str map_json:

        A JSON string, or a file path to a JSON file, or a URL specifying a
        JSON file to be downloaded.

    :param model:

        A COBRApy model.

    :param model_name:

        A string specifying a model to be downloaded from the Escher web
        server.

    :param model_json:

        A JSON string, or a file path to a JSON file, or a URL specifying a
        JSON file to be downloaded.

    :param embedded_css:

        The CSS (as a string) to be embedded with the Escher SVG. In Jupyter,
        if you change embedded_css on an existing builder instance, the Builder
        must be restarted for this to take effect (e.g. by re-evaluating the
        widget in a cell). You can use the default embedded css as a starting
        point:

        https://github.com/zakandrewking/escher/blob/master/src/Builder-embed.css

    :param reaction_data:

        A dictionary with keys that correspond to reaction IDs and values that
        will be mapped to reaction arrows and labels.

    :param metabolite_data:

        A dictionary with keys that correspond to metabolite IDs and values
        that will be mapped to metabolite nodes and labels.

    :param gene_data:

        A dictionary with keys that correspond to gene IDs and values that will
        be mapped to corresponding reactions.

    **Keyword Arguments**

    You can also pass in any of the following options as keyword arguments. The
    details on each of these are provided in the JavaScript API documentation:

        - use_3d_transform
        - menu
        - scroll_behavior
        - enable_editing
        - enable_keys
        - enable_search
        - zoom_to_element
        - full_screen_button
        - disabled_buttons
        - semantic_zoom
        - starting_reaction
        - never_ask_before_quit
        - primary_metabolite_radius
        - secondary_metabolite_radius
        - marker_radius
        - gene_font_size
        - hide_secondary_metabolites
        - show_gene_reaction_rules
        - hide_all_labels
        - canvas_size_and_loc
        - reaction_data
        - reaction_styles
        - reaction_opacity
        - reaction_compare_style
        - reaction_scale
        - reaction_no_data_color
        - reaction_no_data_size
        - reaction_highlight
        - gene_data
        - and_method_in_gene_reaction_rule
        - metabolite_data
        - metabolite_styles
        - metabolite_compare_style
        - metabolite_scale
        - metabolite_no_data_color
        - metabolite_no_data_size
        - identifiers_on_map
        - highlight_missing
        - allow_building_duplicate_reactions
        - cofactors
        - enable_tooltips
        - enable_keys_with_tooltip
        - reaction_scale_preset
        - metabolite_scale_preset

    If any of these is set to None, the default (or most-recent) value is used.
    To turn off a setting, use False instead.

    All arguments can also be set by assigning the property of an an existing
    Builder object, e.g.:

    .. code:: python

        my_builder.map_name = 'iJO1366.Central metabolism'

    """

    # widget info traitlets

    _view_name = Unicode('EscherMapView').tag(sync=True)
    _model_name = Unicode('EscherMapModel').tag(sync=True)
    _view_module = Unicode('escher').tag(sync=True)
    _model_module = Unicode('escher').tag(sync=True)
    _view_module_version = Unicode(__version__).tag(sync=True)
    _model_module_version = Unicode(__version__).tag(sync=True)

    # Python package options

    # Height of the widget in pixels
    height = Int(500).tag(sync=True)

    embedded_css = Unicode(None, allow_none=True).tag(sync=True)

    @validate('embedded_css')
    def _validate_embedded_css(self, proposal):
        """Collapse the CSS to a single line so it syncs cleanly."""
        css = proposal['value']
        if css:
            return css.replace('\n', '')
        else:
            return None

    # synced data

    _loaded_map_json = Unicode(None, allow_none=True).tag(sync=True)

    @observe('_loaded_map_json')
    def _observe_loaded_map_json(self, change):
        # if map is cleared, then clear these
        if not change.new:
            self.map_name = None
            self.map_json = None

    _loaded_model_json = Unicode(None, allow_none=True).tag(sync=True)

    @observe('_loaded_model_json')
    def _observe_loaded_model_json(self, change):
        # if model is cleared, then clear these
        if not change.new:
            self.model = None
            self.model_name = None
            self.model_json = None

    # Python options that are indirectly synced to the widget

    map_name = Unicode(None, allow_none=True)

    @observe('map_name')
    def _observe_map_name(self, change):
        # Download the named map from the Escher website and load it
        if change.new:
            self._loaded_map_json = map_json_for_name(change.new)
        else:
            self._loaded_map_json = None

    map_json = Unicode(None, allow_none=True)

    @observe('map_json')
    def _observe_map_json(self, change):
        # Load map JSON from a string, file path, or URL
        if change.new:
            self._loaded_map_json = _load_resource(change.new, 'map_json')
        else:
            self._loaded_map_json = None

    model = Instance(Model, allow_none=True)

    @observe('model')
    def _observe_model(self, change):
        # Serialize the COBRApy model to JSON for the frontend
        if change.new:
            self._loaded_model_json = cobra.io.to_json(change.new)
        else:
            self._loaded_model_json = None

    model_name = Unicode(None, allow_none=True)

    @observe('model_name')
    def _observe_model_name(self, change):
        # Download the named model from the Escher website and load it
        if change.new:
            self._loaded_model_json = model_json_for_name(change.new)
        else:
            self._loaded_model_json = None

    model_json = Unicode(None, allow_none=True)

    @observe('model_json')
    def _observe_model_json(self, change):
        # Load model JSON from a string, file path, or URL
        if change.new:
            self._loaded_model_json = _load_resource(change.new, 'model_json')
        else:
            self._loaded_model_json = None

    # Synced options passed as an object to JavaScript Builder.
    # Traits tagged option=True are collected and forwarded (see save_html).

    menu = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    scroll_behavior = Any('none', allow_none=False)\
        .tag(sync=True, option=True)
    use_3d_transform = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    enable_editing = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    enable_keys = Any(False, allow_none=True)\
        .tag(sync=True, option=True)
    enable_search = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    zoom_to_element = Any(None, allow_none=True)\
        .tag(sync=True, option=True)

    full_screen_button_default = {
        'enable_keys': True,
        'scroll_behavior': 'pan',
        'enable_editing': True,
        'menu': 'all',
        'enable_tooltips': ['label']
    }
    full_screen_button = Any(full_screen_button_default, allow_none=True)\
        .tag(sync=True, option=True)

    disabled_buttons = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    semantic_zoom = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    starting_reaction = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    never_ask_before_quit = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    primary_metabolite_radius = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    secondary_metabolite_radius = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    marker_radius = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    gene_font_size = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    hide_secondary_metabolites = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    show_gene_reaction_rules = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    hide_all_labels = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    canvas_size_and_loc = Any(None, allow_none=True)\
        .tag(sync=True, option=True)

    reaction_data = Any(None, allow_none=True)\
        .tag(sync=True, option=True)

    @validate('reaction_data')
    def _validate_reaction_data(self, proposal):
        """Convert reaction data to the frontend format, or raise."""
        try:
            return convert_data(proposal['value'])
        except Exception:
            raise Exception("""Invalid data for reaction_data. Must be pandas
                            Series, pandas DataFrame, dict, list, or None""")

    reaction_styles = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_opacity = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_highlight = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_compare_style = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_scale = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_no_data_color = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_no_data_size = Any(None, allow_none=True)\
        .tag(sync=True, option=True)

    gene_data = Any(None, allow_none=True)\
        .tag(sync=True, option=True)

    @validate('gene_data')
    def _validate_gene_data(self, proposal):
        """Convert gene data to the frontend format, or raise."""
        try:
            return convert_data(proposal['value'])
        except Exception:
            raise Exception("""Invalid data for gene_data. Must be pandas
                            Series, pandas DataFrame, dict, list, or None""")

    and_method_in_gene_reaction_rule = Any(None, allow_none=True)\
        .tag(sync=True, option=True)

    metabolite_data = Any(None, allow_none=True)\
        .tag(sync=True, option=True)

    @validate('metabolite_data')
    def _validate_metabolite_data(self, proposal):
        """Convert metabolite data to the frontend format, or raise."""
        try:
            return convert_data(proposal['value'])
        except Exception:
            raise Exception("""Invalid data for metabolite_data. Must be pandas
                            Series, pandas DataFrame, dict, list, or None""")

    metabolite_styles = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    metabolite_compare_style = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    metabolite_scale = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    metabolite_no_data_color = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    metabolite_no_data_size = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    identifiers_on_map = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    highlight_missing = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    allow_building_duplicate_reactions = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    cofactors = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    enable_tooltips = Any(False, allow_none=True)\
        .tag(sync=True, option=True)
    enable_keys_with_tooltip = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_scale_preset = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    metabolite_scale_preset = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    # NOTE: a previous revision declared primary_metabolite_radius,
    # secondary_metabolite_radius, marker_radius, gene_font_size,
    # reaction_no_data_size, and metabolite_no_data_size a second time here
    # with identical definitions; the redundant duplicates were removed.

    def __init__(
        self,
        map_name: str = None,
        map_json: str = None,
        model: Model = None,
        model_name: str = None,
        model_json: str = None,
        **kwargs,
    ) -> None:
        """Initialize the widget, resolving precedence among map/model
        sources (json string/path overrides name; model object overrides
        both) and warning about unsupported options."""
        # kwargs will instantiate the traitlets
        super().__init__(**kwargs)

        if map_json:
            if map_name:
                warn('map_json overrides map_name')
            self.map_json = map_json
        else:
            self.map_name = map_name

        if model:
            if model_name:
                warn('model overrides model_name')
            if model_json:
                warn('model overrides model_json')
            self.model = model
        elif model_json:
            if model_name:
                warn('model_json overrides model_name')
            self.model_json = model_json
        else:
            self.model_name = model_name

        # Options accepted by the JavaScript Builder but not supported here;
        # each maps to the warning emitted when a caller passes it.
        unavailable_options = {
            'fill_screen': """The fill_option screen is set automatically by
            the Escher Python package""",
            'tooltip_component': """The tooltip_component cannot be customized
          with the Python API""",
            'first_load_callback': """The first_load_callback cannot be
          customized with the Python API""",
            'unique_map_id': """The option unique_map_id is deprecated""",
            'ignore_bootstrap': """The option ignore_bootstrap is deprecated""",
        }

        for key, val in kwargs.items():
            if key in unavailable_options:
                warn(val)

    def display_in_notebook(self, *args, **kwargs):
        """Deprecated.

        The Builder is now a Jupyter Widget, so you can return the Builder
        object from a cell to display it, or you can manually call the IPython
        display function:

        from IPython.display import display
        from escher import Builder
        b = Builder(...)
        display(b)

        """
        raise Exception(('display_in_notebook is deprecated. The Builder is '
                         'now a Jupyter Widget, so you can return the '
                         'Builder in a cell to see it, or use the IPython '
                         'display function (see Escher docs for details)'))

    def display_in_browser(self, *args, **kwargs):
        """Deprecated.

        We recommend using the Jupyter Widget (which now supports all Escher
        features) or the save_html option to generate a standalone HTML file
        that loads the map.

        """
        raise Exception(('display_in_browser is deprecated. We recommend using'
                         'the Jupyter Widget (which now supports all Escher'
                         'features) or the save_html option to generate a'
                         'standalone HTML file that loads the map.'))

    def save_html(self, filepath):
        """Save an HTML file containing the map.

        :param string filepath:

            The name of the HTML file.

        TODO apply options from self

        """
        # Collect all non-None option traits (tagged option=True above)
        options = {}
        for key in self.traits(option=True):
            val = getattr(self, key)
            if val is not None:
                options[key] = val
        options_json = json.dumps(options)

        template = env.get_template('standalone.html')
        embedded_css_b64 = (b64dump(self.embedded_css)
                            if self.embedded_css is not None else None)
        html = template.render(
            escher_url=get_url('escher_min'),
            embedded_css_b64=embedded_css_b64,
            map_data_json_b64=b64dump(self._loaded_map_json),
            model_data_json_b64=b64dump(self._loaded_model_json),
            options_json_b64=b64dump(options_json),
        )

        # Write UTF-8 bytes so the file is encoding-safe on all platforms
        with open(expanduser(filepath), 'wb') as f:
            f.write(html.encode('utf-8'))
Пример #7
0
class TrajectoryPlayer(HasTraits):
    # should set default values here different from desired defaults
    # so `observe` can be triggered
    # Playback state traits (step size, backend sync, interpolation, delay).
    step = Int(0)
    sync_frame = Bool(True)
    interpolate = Bool(False)
    delay = Float(0.0)
    parameters = Dict()
    iparams = Dict()
    _interpolation_t = Float()
    # NOTE(review): trait name is misspelled ('iterpolation'); kept as-is
    # because __init__ assigns it under this exact name.
    _iterpolation_type = CaselessStrEnum(['linear', 'spline'])
    # Auto-rotation ("spin") settings: on/off flag, axis components, speed.
    spin = Bool(False)
    _spin_x = Int(1)
    _spin_y = Int(0)
    _spin_z = Int(0)
    _spin_speed = Float(0.005)
    camera = CaselessStrEnum(['perspective', 'orthographic'],
                             default_value='perspective')
    _render_params = Dict()
    _real_time_update = Bool(False)

    # Lazily-built UI widgets; each stays None until its factory method runs.
    widget_tab = Any(None)
    widget_repr = Any(None)
    widget_repr_parameters = Any(None)
    widget_quick_repr = Any(None)
    widget_general = Any(None)
    widget_picked = Any(None)
    widget_preference = Any(None)
    widget_extra = Any(None)
    widget_theme = Any(None)
    widget_help = Any(None)
    widget_export_image = Any(None)
    widget_component_slider = Any(None)
    widget_repr_slider = Any(None)
    widget_repr_choices = Any(None)
    widget_repr_control_buttons = Any(None)
    widget_repr_add = Any(None)
    widget_accordion_repr_parameters = Any(None)
    widget_repr_parameters_dialog = Any(None)
    widget_repr_name = Any(None)
    widget_component_dropdown = Any(None)
    widget_drag = Any(None)

    def __init__(self,
                 view,
                 step=1,
                 delay=100,
                 sync_frame=False,
                 min_delay=40):
        """Create a player driving trajectory playback for ``view``.

        :param view: the viewer widget this player controls
        :param step: frame increment per playback tick
        :param delay: delay between frames (forwarded to the view)
        :param sync_frame: whether to synchronize frames with the backend
        :param min_delay: lower bound stored for the delay setting
        """
        self._view = view
        self.step = step
        self.sync_frame = sync_frame
        self.delay = delay
        self.min_delay = min_delay
        self._interpolation_t = 0.5
        self._iterpolation_type = 'linear'
        self.iparams = dict(t=self._interpolation_t,
                            step=1,
                            type=self._iterpolation_type)
        self._render_params = dict(factor=4,
                                   antialias=True,
                                   trim=False,
                                   transparent=False)

        # BUG FIX: the prefix was misspelled 'wiget_', so this list was always
        # empty; the player's widget attributes are all named 'widget_*'.
        self._widget_names = [w for w in dir(self) if w.startswith('widget_')]

    def _update_padding(self, padding=default.DEFAULT_PADDING):
        """Apply ``padding`` to every already-created tab pane widget."""
        for pane in (self.widget_general, self.widget_repr,
                     self.widget_preference, self.widget_repr_parameters,
                     self.widget_help, self.widget_extra, self.widget_picked):
            if pane is not None:
                pane.layout.padding = padding

    def _create_all_widgets(self):
        """Force creation of every tab pane by selecting each tab once."""
        if self.widget_tab is None:
            self.widget_tab = self._display()

        remembered_index = self.widget_tab.selected_index
        for tab_index in range(len(self.widget_tab.children)):
            self.widget_tab.selected_index = tab_index

        # Restore whatever tab the user had selected.
        self.widget_tab.selected_index = remembered_index

    def smooth(self):
        """Turn on the ``interpolate`` trait (frame interpolation)."""
        self.interpolate = True

    @observe('camera')
    def on_camera_changed(self, change):
        """Forward the new camera type to the NGL Stage."""
        self._view._remote_call(
            "setParameters",
            target='Stage',
            kwargs=dict(cameraType=change['new']))

    @property
    def frame(self):
        # Current frame index, delegated to the underlying view.
        return self._view.frame

    @frame.setter
    def frame(self, value):
        self._view.frame = value

    @property
    def count(self):
        # Total frame count, delegated to the underlying view (read-only).
        return self._view.count

    @observe('sync_frame')
    def update_sync_frame(self, change):
        """Enable/disable backend frame synchronization on trait change."""
        if change['new']:
            self._view._set_sync_frame()
        else:
            self._view._set_unsync_frame()

    @observe("delay")
    def update_delay(self, change):
        delay = change['new']
        self._view._set_delay(delay)

    @observe('parameters')
    def update_parameters(self, change):
        """Copy known keys from the new parameters dict onto this player.

        Keys absent from the dict leave the current trait values untouched.
        """
        params = change['new']
        for trait_name in ('sync_frame', 'delay', 'step'):
            setattr(self, trait_name,
                    params.get(trait_name, getattr(self, trait_name)))

    @observe('_interpolation_t')
    def _interpolation_t_changed(self, change):
        # Keep the interpolation-parameter dict in sync with the trait.
        self.iparams['t'] = change['new']

    @observe('spin')
    def on_spin_changed(self, change):
        """Start or stop the spin animation when the ``spin`` trait flips."""
        self.spin = change['new']
        if not self.spin:
            # (None, None) tells the view to stop spinning.
            self._view._set_spin(None, None)
            return
        self._view._set_spin([self._spin_x, self._spin_y, self._spin_z],
                             self._spin_speed)

    @observe('_spin_x')
    def on_spin_x_changed(self, change):
        """Record the new x axis component; re-issue the spin if spinning."""
        self._spin_x = change['new']
        if not self.spin:
            return
        self._view._set_spin([self._spin_x, self._spin_y, self._spin_z],
                             self._spin_speed)

    @observe('_spin_y')
    def on_spin_y_changed(self, change):
        """Record the new y axis component; re-issue the spin if spinning."""
        self._spin_y = change['new']
        if not self.spin:
            return
        self._view._set_spin([self._spin_x, self._spin_y, self._spin_z],
                             self._spin_speed)

    @observe('_spin_z')
    def on_spin_z_changed(self, change):
        """Record the new z axis component; re-issue the spin if spinning."""
        self._spin_z = change['new']
        if not self.spin:
            return
        self._view._set_spin([self._spin_x, self._spin_y, self._spin_z],
                             self._spin_speed)

    @observe('_spin_speed')
    def on_spin_speed_changed(self, change):
        """Record the new spin speed; re-issue the spin if spinning."""
        self._spin_speed = change['new']
        if not self.spin:
            return
        self._view._set_spin([self._spin_x, self._spin_y, self._spin_z],
                             self._spin_speed)

    def _display(self):
        """Build the main tab widget holding all control panes."""
        pane_factories = [
            (self._make_general_box, 'General'),
            (self._make_widget_repr, 'Representation'),
            (self._make_widget_preference, 'Preference'),
            (self._make_theme_box, 'Theme'),
            (self._make_extra_box, 'Extra'),
            (self._show_website, 'Help'),
        ]

        # Panes are created lazily (on first selection) by the delay tab.
        tab = _make_delay_tab(pane_factories, selected_index=-1)
        tab.layout.align_self = 'center'
        tab.layout.align_items = 'stretch'

        self.widget_tab = tab
        return self.widget_tab

    def _make_widget_tab(self):
        # Thin wrapper: delegates tab construction to _display().
        return self._display()

    def _make_button_center(self):
        """Create a button that re-centers the camera on the view."""
        center_button = Button(description=' Center', icon='fa-bullseye')

        def _on_click(btn):
            self._view.center()

        center_button.on_click(_on_click)
        return center_button

    def _make_button_theme(self):
        """Create a button that applies the 'Oceans16' notebook theme."""
        theme_button = Button(description='Oceans16')

        def _apply_theme(btn):
            # Imported lazily so the theme module loads only when used.
            from nglview import theme
            display(theme.oceans16())
            self._view._remote_call('cleanOutput', target='Widget')

        theme_button.on_click(_apply_theme)
        return theme_button

    def _make_button_reset_theme(self, hide_toolbar=False):
        """Create a button that restores the default notebook theme.

        :param hide_toolbar: when True, the reset also hides the toolbar and
            the button is labeled 'Simplified Default'.
        """
        from nglview import theme

        if hide_toolbar:
            reset_button = Button(description='Simplified Default')

            def _do_reset(btn):
                theme.reset(hide_toolbar=True)
        else:
            reset_button = Button(description='Default')

            def _do_reset(btn):
                theme.reset()

        reset_button.on_click(_do_reset)
        return reset_button

    def _make_button_clean_error_output(self):
        """Create a button that clears JS error output in the notebook."""
        clear_button = Button(description='Clear Error')
        clear_button.on_click(lambda _: js_utils.clean_error_output())
        return clear_button

    def _make_widget_preference(self, width='100%'):
        """Build (and cache in ``self.widget_preference``) the 'Preference'
        pane: sliders/controls for the NGL stage parameters, plus a Reset
        button restoring the original stage parameters.
        """
        def make_func():
            # Snapshot the view's current stage parameters so the controls
            # start at the values actually in effect; missing keys fall back
            # to the defaults below.
            parameters = self._view._full_stage_parameters

            def func(pan_speed=parameters.get('panSpeed', 0.8),
                     rotate_speed=parameters.get('rotateSpeed', 2),
                     zoom_speed=parameters.get('zoomSpeed', 1.2),
                     clip_dist=parameters.get('clipDist', 10),
                     camera_fov=parameters.get('cameraFov', 40),
                     clip_far=parameters.get('clipFar', 100),
                     clip_near=parameters.get('clipNear', 0),
                     fog_far=parameters.get('fogFar', 100),
                     fog_near=parameters.get('fogNear', 50),
                     impostor=parameters.get('impostor', True),
                     light_intensity=parameters.get('lightIntensity', 1),
                     quality=parameters.get('quality', 'medium'),
                     sample_level=parameters.get('sampleLevel', 1)):

                # Writing self._view.parameters pushes the settings to NGL.
                self._view.parameters = dict(panSpeed=pan_speed,
                                             rotateSpeed=rotate_speed,
                                             zoomSpeed=zoom_speed,
                                             clipDist=clip_dist,
                                             clipFar=clip_far,
                                             clipNear=clip_near,
                                             cameraFov=camera_fov,
                                             fogFar=fog_far,
                                             fogNear=fog_near,
                                             impostor=impostor,
                                             lightIntensity=light_intensity,
                                             quality=quality,
                                             sampleLevel=sample_level)

            return func

        def make_widget_box():
            # `interactive` turns func's keyword defaults into live controls;
            # the tuples are (min, max, step) slider specs.
            widget_sliders = interactive(make_func(),
                                         pan_speed=(0, 10, 0.1),
                                         rotate_speed=(0, 10, 1),
                                         zoom_speed=(0, 10, 1),
                                         clip_dist=(0, 200, 5),
                                         clip_far=(0, 100, 1),
                                         clip_near=(0, 100, 1),
                                         camera_fov=(15, 120, 1),
                                         fog_far=(0, 100, 1),
                                         fog_near=(0, 100, 1),
                                         light_intensity=(0, 10, 0.02),
                                         quality=['low', 'medium', 'high'],
                                         sample_level=(-1, 5, 1))

            for child in widget_sliders.children:
                if isinstance(child, (IntSlider, FloatSlider)):
                    child.layout.width = default.DEFAULT_SLIDER_WIDTH
            return widget_sliders

        # Build once; subsequent calls return the cached pane.
        if self.widget_preference is None:
            widget_sliders = make_widget_box()
            reset_button = Button(description='Reset')
            widget_sliders.children = [
                reset_button,
            ] + list(widget_sliders.children)

            @reset_button.on_click
            def on_click(reset_button):
                # Restore the pristine stage parameters and rebuild the
                # sliders so they reflect the restored values.
                self._view.parameters = self._view._original_stage_parameters
                self._view._full_stage_parameters = self._view._original_stage_parameters
                widget_sliders.children = [
                    reset_button,
                ] + list(make_widget_box().children)

            self.widget_preference = _relayout_master(widget_sliders,
                                                      width=width)
        return self.widget_preference

    def _show_download_image(self):
        """Create a screenshot button.

        A plain Button is used because "interactive" does not work for
        True/False in ipywidgets 4 yet.
        """
        screenshot_button = Button(description=' Screenshot', icon='fa-camera')

        def _on_click(btn):
            self._view.download_image()

        screenshot_button.on_click(_on_click)
        return screenshot_button

    def _make_button_url(self, url, description):
        """Create a button that opens ``url`` via injected Javascript."""
        link_button = Button(description=description)

        def _open_url(btn):
            display(Javascript(js_utils.open_url_template.format(url=url)))

        link_button.on_click(_open_url)
        return link_button

    def _show_website(self, ngl_base_url=default.NGL_BASE_URL):
        """Build a row of help buttons linking to nglview/NGL documentation."""
        url_specs = [
            ("'http://arose.github.io/nglview/latest/'", "nglview"),
            ("'{}/index.html'", "NGL"),
            ("'{}/tutorial-selection-language.html'", "Selection"),
            ("'{}/tutorial-molecular-representations.html'", "Representation"),
        ]
        buttons = [
            self._make_button_url(url.format(ngl_base_url), description)
            for url, description in url_specs
        ]
        self.widget_help = _make_autofit(HBox(buttons))
        return self.widget_help

    def _make_button_qtconsole(self):
        """Create a button that pops up a Jupyter qtconsole."""
        from nglview import js_utils
        qt_button = Button(description='qtconsole', tooltip='pop up qtconsole')
        qt_button.on_click(lambda btn: js_utils.launch_qtconsole())
        return qt_button

    def _make_text_picked(self):
        """Create a text area showing the currently picked atom as JSON."""
        picked_area = Textarea(value=json.dumps(self._view.picked),
                               description='Picked atom')
        picked_area.layout.width = '300px'
        return picked_area

    def _refresh(self, component_slider, repr_slider):
        """update representation and component information
        """
        # Ask the backend for the parameters of the currently selected
        # component/representation pair.
        self._view._request_repr_parameters(component=component_slider.value,
                                            repr_index=repr_slider.value)
        # Request a fresh representation-info dump from the JS side.
        self._view._remote_call('requestReprInfo', target='Widget')
        # Re-run the repr-dict handler with the view's cached repr dict.
        self._view._handle_repr_dict_changed(change=dict(
            new=self._view._repr_dict))

    def _make_button_repr_control(self, component_slider, repr_slider,
                                  repr_selection):
        """Build the row of representation control buttons
        (Refresh / Center / Hide / Remove / Dialog).

        :param component_slider: slider selecting the component index
        :param repr_slider: slider selecting the representation index
        :param repr_selection: text widget holding the atom selection string
        """
        button_refresh = Button(description=' Refresh',
                                tooltip='Get representation info',
                                icon='fa-refresh')
        button_center_selection = Button(description=' Center',
                                         tooltip='center selected atoms',
                                         icon='fa-bullseye')
        button_center_selection._ngl_name = 'button_center_selection'
        button_hide = Button(description=' Hide',
                             icon='fa-eye-slash',
                             tooltip='Hide/Show current representation')
        button_remove = Button(description=' Remove',
                               icon='fa-trash',
                               tooltip='Remove current representation')
        button_repr_parameter_dialog = Button(
            description=' Dialog',
            tooltip='Pop up representation parameters control dialog')

        @button_refresh.on_click
        def on_click_refresh(button):
            self._refresh(component_slider, repr_slider)

        @button_center_selection.on_click
        def on_click_center(center_selection):
            # Center the camera on the atoms matching the selection text.
            self._view.center_view(selection=repr_selection.value,
                                   component=component_slider.value)

        @button_hide.on_click
        def on_click_hide(button_hide):
            component = component_slider.value
            repr_index = repr_slider.value

            # Toggle between hiding and showing; the label tracks the state.
            # NOTE(review): the button is created with description ' Hide'
            # (leading space), so this first comparison is False on the very
            # first click — confirm whether that is intentional.
            if button_hide.description == 'Hide':
                hide = True
                button_hide.description = 'Show'
            else:
                hide = False
                button_hide.description = 'Hide'

            self._view._remote_call('setVisibilityForRepr',
                                    target='Widget',
                                    args=[component, repr_index, not hide])

        @button_remove.on_click
        def on_click_remove(button_remove):
            self._view._remove_representation(component=component_slider.value,
                                              repr_index=repr_slider.value)
            # Refresh parameters after removal.
            self._view._request_repr_parameters(
                component=component_slider.value, repr_index=repr_slider.value)

        @button_repr_parameter_dialog.on_click
        def on_click_repr_dialog(_):
            from nglview.widget_box import DraggableBox
            # Only open the dialog once the repr widgets exist.
            if self.widget_repr_parameters is not None and self.widget_repr_choices:
                self.widget_repr_parameters_dialog = DraggableBox(
                    [self.widget_repr_choices, self.widget_repr_parameters])
                self.widget_repr_parameters_dialog._ipython_display_()
                self.widget_repr_parameters_dialog._dialog = 'on'

        bbox = _make_autofit(
            HBox([
                button_refresh, button_center_selection, button_hide,
                button_remove, button_repr_parameter_dialog
            ]))
        return bbox

    def _make_widget_repr(self):
        """Build the 'Representation' pane: name/selection text boxes,
        component and representation sliders, a choices dropdown, the add-repr
        row, control buttons, and the parameters accordion. Caches the result
        in ``self.widget_repr`` and returns it.
        """
        self.widget_repr_name = Text(value='', description='representation')
        self.widget_repr_name._ngl_name = 'repr_name_text'
        repr_selection = Text(value=' ', description='selection')
        repr_selection._ngl_name = 'repr_selection'
        repr_selection.width = self.widget_repr_name.width = default.DEFAULT_TEXT_WIDTH

        # Slider range covers component indices 0..n_components-1 (min 0).
        max_n_components = max(self._view.n_components - 1, 0)
        self.widget_component_slider = IntSlider(value=0,
                                                 max=max_n_components,
                                                 min=0,
                                                 description='component')
        self.widget_component_slider._ngl_name = 'component_slider'

        cvalue = ' '
        self.widget_component_dropdown = Dropdown(value=cvalue,
                                                  options=[
                                                      cvalue,
                                                  ],
                                                  description='component')
        self.widget_component_dropdown._ngl_name = 'component_dropdown'

        self.widget_repr_slider = IntSlider(value=0,
                                            description='representation',
                                            width=default.DEFAULT_SLIDER_WIDTH)
        self.widget_repr_slider._ngl_name = 'repr_slider'
        self.widget_repr_slider.visible = True

        self.widget_component_slider.layout.width = default.DEFAULT_SLIDER_WIDTH
        self.widget_repr_slider.layout.width = default.DEFAULT_SLIDER_WIDTH
        self.widget_component_dropdown.layout.width = self.widget_component_dropdown.max_width = default.DEFAULT_TEXT_WIDTH

        # turn off for now
        self.widget_component_dropdown.layout.display = 'none'
        self.widget_component_dropdown.description = ''

        # self.widget_accordion_repr_parameters = Accordion()
        self.widget_accordion_repr_parameters = Tab()
        self.widget_repr_parameters = self._make_widget_repr_parameters(
            self.widget_component_slider, self.widget_repr_slider,
            self.widget_repr_name)
        self.widget_accordion_repr_parameters.children = [
            self.widget_repr_parameters, Box()
        ]
        self.widget_accordion_repr_parameters.set_title(0, 'Parameters')
        self.widget_accordion_repr_parameters.set_title(1, 'Hide')
        # Start on the empty 'Hide' tab so parameters are collapsed.
        self.widget_accordion_repr_parameters.selected_index = 1

        checkbox_reprlist = Checkbox(value=False, description='reprlist')
        checkbox_reprlist._ngl_name = 'checkbox_reprlist'
        self.widget_repr_choices = self._make_repr_name_choices(
            self.widget_component_slider, self.widget_repr_slider)
        self.widget_repr_choices._ngl_name = 'reprlist_choices'

        self.widget_repr_add = self._make_add_widget_repr(
            self.widget_component_slider)

        def on_update_checkbox_reprlist(change):
            # Show/hide the repr-name dropdown with the checkbox.
            self.widget_repr_choices.visible = change['new']

        checkbox_reprlist.observe(on_update_checkbox_reprlist, names='value')

        def on_repr_name_text_value_changed(change):
            name = change['new'].strip()
            old = change['old'].strip()

            # Only push a repr change when live updates are on, both names
            # are non-empty, and the new name is a known representation.
            # NOTE(review): the final clause compares against
            # change['old'].strip(), which equals `old` — effectively
            # `name != old`; the duplication looks unintentional.
            should_update = (self._real_time_update and old and name
                             and name in REPRESENTATION_NAMES
                             and name != change['old'].strip())

            if should_update:
                component = self.widget_component_slider.value
                repr_index = self.widget_repr_slider.value
                self._view._remote_call(
                    'setRepresentation',
                    target='Widget',
                    args=[change['new'], {}, component, repr_index])
                self._view._request_repr_parameters(component, repr_index)

        def on_component_or_repr_slider_value_changed(change):
            # Re-query repr parameters and refresh the component dropdown
            # whenever either slider moves.
            self._view._request_repr_parameters(
                component=self.widget_component_slider.value,
                repr_index=self.widget_repr_slider.value)
            self.widget_component_dropdown.options = tuple(
                self._view._ngl_component_names)

            if self.widget_accordion_repr_parameters.selected_index >= 0:
                self.widget_repr_parameters.name = self.widget_repr_name.value
                self.widget_repr_parameters.repr_index = self.widget_repr_slider.value
                self.widget_repr_parameters.component_index = self.widget_component_slider.value

        def on_repr_selection_value_changed(change):
            if self._real_time_update:
                component = self.widget_component_slider.value
                repr_index = self.widget_repr_slider.value
                self._view._set_selection(change['new'],
                                          component=component,
                                          repr_index=repr_index)

        def on_change_component_dropdown(change):
            choice = change['new']
            if choice:
                # Keep the component slider in sync with the dropdown choice.
                self.widget_component_slider.value = self._view._ngl_component_names.index(
                    choice)

        self.widget_component_dropdown.observe(on_change_component_dropdown,
                                               names='value')

        self.widget_repr_slider.observe(
            on_component_or_repr_slider_value_changed, names='value')
        self.widget_component_slider.observe(
            on_component_or_repr_slider_value_changed, names='value')
        self.widget_repr_name.observe(on_repr_name_text_value_changed,
                                      names='value')
        repr_selection.observe(on_repr_selection_value_changed, names='value')

        self.widget_repr_control_buttons = self._make_button_repr_control(
            self.widget_component_slider, self.widget_repr_slider,
            repr_selection)

        blank_box = Box([Label("")])

        all_kids = [
            self.widget_repr_control_buttons, blank_box, self.widget_repr_add,
            self.widget_component_dropdown, self.widget_repr_name,
            repr_selection, self.widget_component_slider,
            self.widget_repr_slider, self.widget_repr_choices,
            self.widget_accordion_repr_parameters
        ]

        vbox = VBox(all_kids)

        self._view._request_repr_parameters(
            component=self.widget_component_slider.value,
            repr_index=self.widget_repr_slider.value)

        self.widget_repr = _relayout_master(vbox, width='100%')

        self._refresh(self.widget_component_slider, self.widget_repr_slider)

        # Keep a flat list of the leaf widgets for later lookup.
        setattr(self.widget_repr, "_saved_widgets", [])
        for _box in self.widget_repr.children:
            if hasattr(_box, 'children'):
                for kid in _box.children:
                    self.widget_repr._saved_widgets.append(kid)

        return self.widget_repr

    def _make_widget_repr_parameters(self,
                                     component_slider,
                                     repr_slider,
                                     repr_name_text=None):
        """Build the parameters box for the currently selected representation."""
        if repr_name_text is None:
            name = ' '
        else:
            name = repr_name_text.value
        params_widget = self._view._display_repr(
            component=component_slider.value,
            repr_index=repr_slider.value,
            name=name)
        params_widget._ngl_name = 'repr_parameters_box'
        return params_widget

    def _make_button_export_image(self):
        """Build the 'Export Images' form: render options, filename, frame
        range and delay controls, plus the button that walks the frames and
        downloads one image per frame.
        """
        slider_factor = IntSlider(value=4, min=1, max=10, description='scale')
        checkbox_antialias = Checkbox(value=True, description='antialias')
        checkbox_trim = Checkbox(value=False, description='trim')
        checkbox_transparent = Checkbox(value=False, description='transparent')
        filename_text = Text(value='Screenshot', description='Filename')
        delay_text = FloatText(value=1,
                               description='delay (s)',
                               tooltip='hello')

        # Frame range controls; stop defaults to the view's frame count.
        start_text, stop_text, step_text = (IntText(value=0,
                                                    description='start'),
                                            IntText(value=self._view.count,
                                                    description='stop'),
                                            IntText(value=1,
                                                    description='step'))

        start_text.layout.max_width = stop_text.layout.max_width = step_text.layout.max_width \
                = filename_text.layout.max_width = delay_text.layout.max_width = default.DEFAULT_TEXT_WIDTH

        button_movie_images = Button(description='Export Images')

        def download_image(filename):
            # Forward the current render options to the view's exporter.
            self._view.download_image(factor=slider_factor.value,
                                      antialias=checkbox_antialias.value,
                                      trim=checkbox_trim.value,
                                      transparent=checkbox_transparent.value,
                                      filename=filename)

        @button_movie_images.on_click
        def on_click_images(button_movie_images):
            # Step through the requested frame range, pausing before and
            # after each download so the view can render.
            for i in range(start_text.value, stop_text.value, step_text.value):
                self._view.frame = i
                time.sleep(delay_text.value)
                download_image(filename=filename_text.value + str(i))
                time.sleep(delay_text.value)

        vbox = VBox([
            button_movie_images,
            start_text,
            stop_text,
            step_text,
            delay_text,
            filename_text,
            slider_factor,
            checkbox_antialias,
            checkbox_trim,
            checkbox_transparent,
        ])

        form_items = _relayout(vbox, make_form_item_layout())
        form = Box(form_items, layout=_make_box_layout())
        # form = _relayout_master(vbox)
        return form

    def _make_resize_notebook_slider(self):
        """Create a slider (300-2000 px) that resizes the notebook area."""
        resize_slider = IntSlider(min=300,
                                  max=2000,
                                  description='resize notebook')

        def _on_resize(change):
            self._view._remote_call('resizeNotebook',
                                    target='Widget',
                                    args=[
                                        change['new'],
                                    ])

        resize_slider.observe(_on_resize, names='value')
        return resize_slider

    def _make_add_widget_repr(self, component_slider):
        """Build the 'Add representation' row: type dropdown, selection text,
        and an Add button (Enter in the selection box also submits)."""
        repr_name_dropdown = Dropdown(options=REPRESENTATION_NAMES,
                                      value='cartoon')
        selection_text = Text(value='*', description='')
        add_button = Button(description='Add',
                            tooltip="""Add representation.
        You can also hit Enter in selection box""")
        add_button.layout = Layout(width='auto', flex='1 1 auto')

        repr_name_dropdown.layout.width = selection_text.layout.width = default.DEFAULT_TEXT_WIDTH

        def _submit(button_or_text_area):
            # Shared handler for button click and text submit.
            self._view.add_representation(
                selection=selection_text.value.strip(),
                repr_type=repr_name_dropdown.value,
                component=component_slider.value)

        add_button.on_click(_submit)
        selection_text.on_submit(_submit)

        add_repr_box = HBox([add_button, repr_name_dropdown, selection_text])
        add_repr_box._ngl_name = 'add_repr_box'

        return add_repr_box

    def _make_repr_playground(self):
        """Build the 'quick repr' pane: one toggle button per representation
        type (toggling adds/removes that representation for the current
        selection), plus a selection box and a clear button. Caches the
        result in ``self.widget_quick_repr``.
        """
        vbox = VBox()
        children = []

        # Offer all representation names except a couple of exclusions.
        rep_names = REPRESENTATION_NAMES[:]
        excluded_names = ['ball+stick', 'distance']
        for name in excluded_names:
            rep_names.remove(name)

        repr_selection = Text(value='*')
        repr_selection.layout.width = default.DEFAULT_TEXT_WIDTH
        repr_selection_box = HBox([Label('selection'), repr_selection])
        setattr(repr_selection_box, 'value', repr_selection.value)

        for index, name in enumerate(rep_names):
            button = ToggleButton(description=name)

            def make_func():
                # `button=button` binds the current button as a default
                # argument, avoiding the late-binding closure pitfall.
                def on_toggle_button_value_change(change, button=button):
                    selection = repr_selection.value
                    new = change['new']  # True/False
                    if new:
                        self._view.add_representation(button.description,
                                                      selection=selection)
                    else:
                        self._view._remove_representations_by_name(
                            button.description)

                return on_toggle_button_value_change

            button.observe(make_func(), names='value')
            children.append(button)

        button_clear = Button(description='clear',
                              button_style='info',
                              icon='fa-eraser')

        @button_clear.on_click
        def on_clear(button_clear):
            # Remove all representations and reset every toggle button.
            self._view.clear()
            for kid in children:
                # unselect
                kid.value = False

        vbox.children = children + [repr_selection, button_clear]
        _make_autofit(vbox)
        self.widget_quick_repr = vbox
        return self.widget_quick_repr

    def _make_repr_name_choices(self, component_slider, repr_slider):
        """Build a dropdown whose selection moves the representation slider."""
        choices_dropdown = Dropdown(options=[
            " ",
        ])

        def _on_chosen(change):
            chosen_name = change.get('new', " ")
            try:
                repr_slider.value = choices_dropdown.options.index(chosen_name)
            except ValueError:
                # Name not (yet) in the options; leave the slider untouched.
                pass

        choices_dropdown.observe(_on_chosen, names='value')
        choices_dropdown.layout.width = default.DEFAULT_TEXT_WIDTH

        self.widget_repr_choices = choices_dropdown
        return self.widget_repr_choices

    def _make_drag_widget(self):
        """Build the drag/layout control row: toggles for widget and notebook
        dragging, notebook reset, dialog mode, and split-screen layout.
        Caches the result in ``self.widget_drag``.
        """
        button_drag = Button(description='widget drag: off',
                             tooltip='dangerous')
        drag_nb = Button(description='notebook drag: off', tooltip='dangerous')
        button_reset_notebook = Button(description='notebook: reset',
                                       tooltip='reset?')
        button_dialog = Button(description='dialog', tooltip='make a dialog')
        button_split_half = Button(description='split screen',
                                   tooltip='try best to make a good layout')

        @button_drag.on_click
        def on_drag(button_drag):
            # Toggle widget draggability; the label mirrors the state.
            if button_drag.description == 'widget drag: off':
                self._view._set_draggable(True)
                button_drag.description = 'widget drag: on'
            else:
                self._view._set_draggable(False)
                button_drag.description = 'widget drag: off'

        @drag_nb.on_click
        def on_drag_nb(button_drag):
            # Toggle notebook draggability; the label mirrors the state.
            if drag_nb.description == 'notebook drag: off':
                js_utils._set_notebook_draggable(True)
                drag_nb.description = 'notebook drag: on'
            else:
                js_utils._set_notebook_draggable(False)
                drag_nb.description = 'notebook drag: off'

        @button_reset_notebook.on_click
        def on_reset(button_reset_notebook):
            js_utils._reset_notebook()

        @button_dialog.on_click
        def on_dialog(button_dialog):
            self._view._remote_call('setDialog', target='Widget')

        @button_split_half.on_click
        def on_split_half(button_dialog):
            from nglview import js_utils
            import time
            # Push the notebook aside, then open the view as a dialog; the
            # short sleep lets the layout change settle first.
            js_utils._move_notebook_to_the_left()
            js_utils._set_notebook_width('5%')
            time.sleep(0.1)
            self._view._remote_call('setDialog', target='Widget')

        drag_box = HBox([
            button_drag, drag_nb, button_reset_notebook, button_dialog,
            button_split_half
        ])
        drag_box = _make_autofit(drag_box)
        self.widget_drag = drag_box
        return drag_box

    def _make_spin_box(self):
        """Build the 'Spin' control panel.

        A spin on/off checkbox plus per-axis (-1/0/1) sliders and a speed
        slider, each two-way linked to the corresponding trait on ``self``.
        """
        checkbox_spin = Checkbox(self.spin, description='spin')

        axis_traits = ('_spin_x', '_spin_y', '_spin_z')
        axis_sliders = [
            IntSlider(getattr(self, trait),
                      min=-1,
                      max=1,
                      description=trait.lstrip('_'))
            for trait in axis_traits
        ]
        spin_speed_slide = FloatSlider(self._spin_speed,
                                       min=0,
                                       max=0.2,
                                       step=0.001,
                                       description='spin speed')

        # Two-way link every control to its trait.
        link((checkbox_spin, 'value'), (self, 'spin'))
        for slider, trait in zip(axis_sliders, axis_traits):
            link((slider, 'value'), (self, trait))
        link((spin_speed_slide, 'value'), (self, '_spin_speed'))

        children = [checkbox_spin] + axis_sliders + [spin_speed_slide]
        return _relayout_master(VBox(children), width='75%')

    def _make_widget_picked(self):
        """Build the 'Picked' panel wrapping the picked-atom text widget."""
        self.widget_picked = self._make_text_picked()
        return _relayout_master(HBox([self.widget_picked]), width='75%')

    def _make_export_image_widget(self):
        """Lazily build and cache the image-export panel."""
        if self.widget_export_image is None:
            export_button = self._make_button_export_image()
            self.widget_export_image = HBox([export_button])
        return self.widget_export_image

    def _make_extra_box(self):
        """Lazily build and cache the 'Extra' tab.

        Each entry pairs a factory method with its tab title; the factories
        are only invoked when their tab is first selected (delay tab).
        """
        if self.widget_extra is None:
            makers = [
                (self._make_drag_widget, 'Drag'),
                (self._make_spin_box, 'Spin'),
                (self._make_widget_picked, 'Picked'),
                (self._make_repr_playground, 'Quick'),
                (self._make_export_image_widget, 'Image'),
                (self._make_command_box, 'Command'),
            ]
            self.widget_extra = _make_delay_tab(makers, selected_index=0)
        return self.widget_extra

    def _make_theme_box(self):
        """Lazily build and cache the theme panel (theme pickers + cleanup)."""
        if self.widget_theme is None:
            children = [
                self._make_button_theme(),
                self._make_button_reset_theme(hide_toolbar=False),
                self._make_button_reset_theme(hide_toolbar=True),
                self._make_button_clean_error_output(),
            ]
            self.widget_theme = Box(children)
        return self.widget_theme

    def _make_general_box(self):
        """Lazily build and cache the 'General' settings panel.

        The panel holds step/delay sliders, a trajectory-smoothing toggle,
        a background color picker, a camera-type dropdown and a row of
        utility buttons (center / render / qtconsole). Every input widget is
        two-way linked to the corresponding trait.

        Returns
        -------
        ipywidgets.VBox
            The cached panel widget (``self.widget_general``).
        """
        if self.widget_general is None:
            step_slide = IntSlider(value=self.step,
                                   min=-100,
                                   max=100,
                                   description='step')
            delay_text = IntSlider(value=self.delay,
                                   min=10,
                                   max=1000,
                                   description='delay')
            toggle_button_interpolate = ToggleButton(
                self.interpolate,
                description='Smoothing',
                tooltip='smoothing trajectory')

            background_color_picker = ColorPicker(value='white',
                                                  description='background')
            camera_type = Dropdown(value=self.camera,
                                   options=['perspective', 'orthographic'],
                                   description='camera')

            # Two-way link each control to its trait. The interpolate toggle
            # is linked exactly once here; the original code linked it twice,
            # creating a redundant duplicate link object.
            link((step_slide, 'value'), (self, 'step'))
            link((delay_text, 'value'), (self, 'delay'))
            link((toggle_button_interpolate, 'value'), (self, 'interpolate'))
            link((camera_type, 'value'), (self, 'camera'))
            link((background_color_picker, 'value'),
                 (self._view, 'background'))

            center_button = self._make_button_center()
            render_button = self._show_download_image()
            qtconsole_button = self._make_button_qtconsole()
            center_render_hbox = _make_autofit(
                HBox([
                    toggle_button_interpolate, center_button, render_button,
                    qtconsole_button
                ]))

            v0_left = VBox([
                step_slide,
                delay_text,
                background_color_picker,
                camera_type,
                center_render_hbox,
            ])

            v0_left = _relayout_master(v0_left, width='100%')
            self.widget_general = v0_left
        return self.widget_general

    def _make_command_box(self):
        """Build a one-line text entry that executes its content on submit."""
        widget_text_command = Text()

        def _run_command(_):
            # Execute the typed command, then clear the field for the next one.
            js_utils.execute(widget_text_command.value)
            widget_text_command.value = ''

        widget_text_command.on_submit(_run_command)
        return widget_text_command

    def _create_all_tabs(self):
        """Force-create every tab's widgets by visiting each tab index.

        Tab contents are built lazily on first selection; cycling through
        all indices here (both for the main tab and the 'Extra' tab)
        triggers their creation up front.
        """
        tab = self._display()
        for index, _ in enumerate(tab.children):
            # trigger creating widgets
            tab.selected_index = index

        self.widget_extra = self._make_extra_box()
        for index, _ in enumerate(self.widget_extra.children):
            self.widget_extra.selected_index = index

    def _simplify_repr_control(self):
        """Collapse the representation controls to a minimal view.

        Hides every saved repr widget except Tabs, shows the repr-choice
        dropdown and opens the first accordion section.
        """
        for saved in self.widget_repr._saved_widgets:
            if isinstance(saved, Tab):
                continue
            saved.layout.display = 'none'
        self.widget_repr_choices.layout.display = 'flex'
        self.widget_accordion_repr_parameters.selected_index = 0
class LDAPAuthenticator(Authenticator):
    """JupyterHub authenticator that validates users against an LDAP server.

    A user is authenticated by binding to the LDAP server with a DN built
    from ``bind_dn_template`` (or looked up via ``lookup_dn``) and the
    supplied password. Optionally, membership in one of ``allowed_groups``
    is also required.
    """

    server_address = Unicode(config=True,
                             help="""
        Address of the LDAP server to contact.

        Could be an IP address or hostname.
        """)
    server_port = Int(config=True,
                      help="""
        Port on which to contact the LDAP server.

        Defaults to `636` if `use_ssl` is set, `389` otherwise.
        """)

    def _server_port_default(self):
        """Dynamic default: standard LDAPS port when SSL is on, LDAP otherwise."""
        if self.use_ssl:
            return 636  # default SSL port for LDAP
        else:
            return 389  # default plaintext port for LDAP

    use_ssl = Bool(True,
                   config=True,
                   help="""
        Use SSL to communicate with the LDAP server.

        Highly recommended! Your LDAP server must be configured to support this, however.
        """)

    bind_dn_template = Union([List(), Unicode()],
                             config=True,
                             help="""
        Template from which to construct the full dn
        when authenticating to LDAP. {username} is replaced
        with the actual username used to log in.

        If your LDAP is set in such a way that the userdn can not
        be formed from a template, but must be looked up with an attribute
        (such as uid or sAMAccountName), please see `lookup_dn`. It might
        be particularly relevant for ActiveDirectory installs.

        Unicode Example:
            uid={username},ou=people,dc=wikimedia,dc=org
        
        List Example:
            [
            	uid={username},ou=people,dc=wikimedia,dc=org,
            	uid={username},ou=Developers,dc=wikimedia,dc=org
        	]
        """)

    allowed_groups = List(config=True,
                          allow_none=True,
                          default=None,
                          help="""
        List of LDAP group DNs that users could be members of to be granted access.

        If a user is in any one of the listed groups, then that user is granted access.
        Membership is tested by fetching info about each group and looking for the User's
        dn to be a value of one of `member` or `uniqueMember`, *or* if the username being
        used to log in with is value of the `uid`.

        Set to an empty list or None to allow all users that have an LDAP account to log in,
        without performing any group membership checks.
        """)

    # FIXME: Use something other than this? THIS IS LAME, akin to websites restricting things you
    # can use in usernames / passwords to protect from SQL injection!
    valid_username_regex = Unicode(r'^[a-z][.a-z0-9_-]*$',
                                   config=True,
                                   help="""
        Regex for validating usernames - those that do not match this regex will be rejected.

        This is primarily used as a measure against LDAP injection, which has fatal security
        considerations. The default works for most LDAP installations, but some users might need
        to modify it to fit their custom installs. If you are modifying it, be sure to understand
        the implications of allowing additional characters in usernames and what that means for
        LDAP injection issues. See https://www.owasp.org/index.php/LDAP_injection for an overview
        of LDAP injection.
        """)

    lookup_dn = Bool(False,
                     config=True,
                     help="""
        Form user's DN by looking up an entry from directory

        By default, LDAPAuthenticator finds the user's DN by using `bind_dn_template`.
        However, in some installations, the user's DN does not contain the username, and
        hence needs to be looked up. You can set this to True and then use `user_search_base`
        and `user_attribute` to accomplish this.
        """)

    user_search_base = Unicode(config=True,
                               default=None,
                               allow_none=True,
                               help="""
        Base for looking up user accounts in the directory, if `lookup_dn` is set to True.

        LDAPAuthenticator will search all objects matching under this base where the `user_attribute`
        is set to the current username to form the userdn.

        For example, if all users objects existed under the base ou=people,dc=wikimedia,dc=org, and
        the username users use is set with the attribute `uid`, you can use the following config:

        ```
        c.LDAPAuthenticator.lookup_dn = True
        c.LDAPAuthenticator.lookup_dn_search_filter = '({login_attr}={login})'
        c.LDAPAuthenticator.lookup_dn_search_user = '******'
        c.LDAPAuthenticator.lookup_dn_search_password = '******'
        c.LDAPAuthenticator.user_search_base = 'ou=people,dc=wikimedia,dc=org'
        c.LDAPAuthenticator.user_attribute = 'sAMAccountName'
        c.LDAPAuthenticator.lookup_dn_user_dn_attribute = 'cn'
        ```
        """)

    user_attribute = Unicode(config=True,
                             default=None,
                             allow_none=True,
                             help="""
        Attribute containing user's name, if `lookup_dn` is set to True.

        See `user_search_base` for info on how this attribute is used.

        For most LDAP servers, this is uid.  For Active Directory, it is
        sAMAccountName.
        """)

    lookup_dn_search_filter = Unicode(config=True,
                                      default_value='({login_attr}={login})',
                                      allow_none=True,
                                      help="""
        How to query LDAP for user name lookup, if `lookup_dn` is set to True.
        """)

    lookup_dn_search_user = Unicode(config=True,
                                    default_value=None,
                                    allow_none=True,
                                    help="""
        Technical account for user lookup, if `lookup_dn` is set to True.

        If both lookup_dn_search_user and lookup_dn_search_password are None, then anonymous LDAP query will be done.
        """)

    lookup_dn_search_password = Unicode(config=True,
                                        default_value=None,
                                        allow_none=True,
                                        help="""
        Technical account for user lookup, if `lookup_dn` is set to True.
        """)

    lookup_dn_user_dn_attribute = Unicode(config=True,
                                          default_value=None,
                                          allow_none=True,
                                          help="""
        Attribute containing user's name needed for  building DN string, if `lookup_dn` is set to True.

        See `user_search_base` for info on how this attribute is used.

        For most LDAP servers, this is username.  For Active Directory, it is cn.
        """)

    escape_userdn = Bool(False,
                         config=True,
                         help="""
        If set to True, escape special chars in userdn when authenticating in LDAP.

        On some LDAP servers, when userdn contains chars like '(', ')', '\' authentication may fail when those chars
        are not escaped.
        """)

    def resolve_username(self, username_supplied_by_user):
        """Resolve the login name to the attribute used in DN templates.

        When ``lookup_dn`` is False this is the identity function. Otherwise
        the directory is searched (using the technical account) for an entry
        matching ``lookup_dn_search_filter`` and the value of
        ``lookup_dn_user_dn_attribute`` is returned, or None on any failure.
        """
        if self.lookup_dn:
            server = ldap3.Server(self.server_address,
                                  port=self.server_port,
                                  use_ssl=self.use_ssl)

            search_filter = self.lookup_dn_search_filter.format(
                login_attr=self.user_attribute,
                login=username_supplied_by_user)
            self.log.debug(
                "Looking up user with search_base={search_base}, search_filter='{search_filter}', attributes={attributes}"
                .format(search_base=self.user_search_base,
                        search_filter=search_filter,
                        attributes=self.user_attribute))

            # Bind with the technical (lookup) account; anonymous bind when
            # both lookup user and password are None.
            conn = ldap3.Connection(server,
                                    user=self.escape_userdn_if_needed(
                                        self.lookup_dn_search_user),
                                    password=self.lookup_dn_search_password)
            is_bound = conn.bind()
            if not is_bound:
                self.log.warn(
                    "Can't connect to LDAP. Server={server}. User={user}. Password={password}"
                    .format(server=self.server_address,
                            user=self.escape_userdn_if_needed(
                                self.lookup_dn_search_user),
                            password=self.lookup_dn_search_password))
                return None

            conn.search(search_base=self.user_search_base,
                        search_scope=ldap3.SUBTREE,
                        search_filter=search_filter,
                        attributes=[self.lookup_dn_user_dn_attribute])

            if len(conn.response
                   ) == 0 or 'attributes' not in conn.response[0].keys():
                self.log.warn(
                    'username:%s No such user entry found when looking up with attribute %s',
                    username_supplied_by_user, self.user_attribute)
                return None
            return conn.response[0]['attributes'][
                self.lookup_dn_user_dn_attribute]
        else:
            return username_supplied_by_user

    def escape_userdn_if_needed(self, userdn):
        """Return *userdn* with LDAP filter special chars escaped, if enabled."""
        if self.escape_userdn:
            return escape_filter_chars(userdn)
        else:
            return userdn

    @gen.coroutine
    def authenticate(self, handler, data):
        """Authenticate a login attempt.

        Parameters
        ----------
        handler : tornado request handler (unused here)
        data : dict with 'username' and 'password' keys

        Returns
        -------
        The username on success, None on any failure (bad username format,
        blank password, failed bind, or failed group-membership check).
        """
        username = data['username']
        password = data['password']

        # Get LDAP Connection
        def getConnection(userdn, username, password):
            server = ldap3.Server(self.server_address,
                                  port=self.server_port,
                                  use_ssl=self.use_ssl)
            self.log.debug(
                'Attempting to bind {username} with {userdn}'.format(
                    username=username, userdn=userdn))
            conn = ldap3.Connection(server,
                                    user=self.escape_userdn_if_needed(userdn),
                                    password=password)
            return conn

        # Protect against invalid usernames as well as LDAP injection attacks
        if not re.match(self.valid_username_regex, username):
            self.log.warn(
                'username:%s Illegal characters in username, must match regex %s',
                username, self.valid_username_regex)
            return None

        # No empty passwords!
        if password is None or password.strip() == '':
            self.log.warn('username:%s Login denied for blank password',
                          username)
            return None

        isBound = False
        self.log.debug("TYPE= '%s'", isinstance(self.bind_dn_template, list))

        resolved_username = self.resolve_username(username)
        if resolved_username is None:
            return None

        # In case, there are multiple binding templates
        if isinstance(self.bind_dn_template, list):
            for dn in self.bind_dn_template:
                userdn = dn.format(username=resolved_username)
                conn = getConnection(userdn, username, password)
                isBound = conn.bind()
                self.log.debug(
                    'Status of user bind {username} with {userdn} : {isBound}'.
                    format(username=username, userdn=userdn, isBound=isBound))
                if isBound:
                    break
        else:
            userdn = self.bind_dn_template.format(username=resolved_username)
            conn = getConnection(userdn, username, password)
            isBound = conn.bind()

        if isBound:
            if self.allowed_groups:
                self.log.debug('username:%s Using dn %s', username, userdn)
                for group in self.allowed_groups:
                    # Match the user either by full DN (member/uniqueMember)
                    # or by bare uid (memberUid), with injection-safe escaping.
                    groupfilter = ('(|'
                                   '(member={userdn})'
                                   '(uniqueMember={userdn})'
                                   '(memberUid={uid})'
                                   ')').format(
                                       userdn=escape_filter_chars(userdn),
                                       uid=escape_filter_chars(username))
                    groupattributes = ['member', 'uniqueMember', 'memberUid']
                    if conn.search(group,
                                   search_scope=ldap3.BASE,
                                   search_filter=groupfilter,
                                   attributes=groupattributes):
                        return username
                # If we reach here, then none of the groups matched
                self.log.warn(
                    'username:%s User not in any of the allowed groups',
                    username)
                return None
            else:
                return username
        else:
            # Bug fix: the original filled the {username} placeholder with the
            # DN, producing a misleading log line. Log the actual username.
            self.log.warn('Invalid password for user {username}'.format(
                username=username, ))
            return None
Пример #9
0
class Circle(CircleMarker):
    """Leaflet circle layer whose radius is specified in meters (map units)."""
    # Names of the JS view/model classes rendered by the jupyter-leaflet frontend.
    _view_name = Unicode('LeafletCircleView').tag(sync=True)
    _model_name = Unicode('LeafletCircleModel').tag(sync=True)

    # Options
    # Unlike CircleMarker (pixel radius), a Circle scales with map zoom.
    radius = Int(1000, help='radius of circle in meters').tag(sync=True, o=True)
Пример #10
0
class EventAnimationCreator(Tool):
    """ctapipe Tool that calibrates one event and renders it as a gif animation
    of the camera image stepping through the waveform timeslices.
    """
    name = "EventAnimationCreator"
    description = "Create an animation of the camera image through timeslices"

    # Index of the event (within the input file) to animate.
    req_event = Int(0, help='Event to plot').tag(config=True)

    # Command-line aliases mapping short flags to component config traits.
    aliases = Dict(
        dict(r='EventFileReaderFactory.reader',
             f='EventFileReaderFactory.input_path',
             max_events='EventFileReaderFactory.max_events',
             ped='CameraR1CalibratorFactory.pedestal_path',
             tf='CameraR1CalibratorFactory.tf_path',
             pe='CameraR1CalibratorFactory.pe_path',
             cleaner='WaveformCleanerFactory.cleaner',
             e='EventAnimationCreator.req_event',
             start='Animator.start',
             end='Animator.end',
             p1='Animator.p1',
             p2='Animator.p2'))
    classes = List([
        EventFileReaderFactory, CameraR1CalibratorFactory,
        WaveformCleanerFactory, CHECMSPEFitter, Animator
    ])

    def __init__(self, **kwargs):
        """Initialise the tool; components are created later in `setup()`."""
        super().__init__(**kwargs)

        # Calibration/IO components, instantiated in setup():
        self.reader = None
        self.r1 = None
        self.dl0 = None
        self.cleaner = None
        self.extractor = None
        self.dl1 = None

        self.fitter = None
        self.dead = None

        self.adc2pe = None

        self.animator = None

    def setup(self):
        """Instantiate reader, calibration chain components and the animator.

        Factory classes select the concrete implementation from config;
        the R1 calibrator additionally depends on the reader's origin.
        """
        self.log_format = "%(levelname)s: %(message)s [%(name)s.%(funcName)s]"
        kwargs = dict(config=self.config, tool=self)

        reader_factory = EventFileReaderFactory(**kwargs)
        reader_class = reader_factory.get_class()
        self.reader = reader_class(**kwargs)

        r1_factory = CameraR1CalibratorFactory(origin=self.reader.origin,
                                               **kwargs)
        r1_class = r1_factory.get_class()
        self.r1 = r1_class(**kwargs)

        cleaner_factory = WaveformCleanerFactory(**kwargs)
        cleaner_class = cleaner_factory.get_class()
        self.cleaner = cleaner_class(**kwargs)

        extractor_factory = ChargeExtractorFactory(**kwargs)
        extractor_class = extractor_factory.get_class()
        self.extractor = extractor_class(**kwargs)

        self.dl0 = CameraDL0Reducer(**kwargs)

        self.dl1 = CameraDL1Calibrator(extractor=self.extractor,
                                       cleaner=self.cleaner,
                                       **kwargs)

        self.fitter = CHECMSPEFitter(**kwargs)
        self.dead = Dead()

        self.animator = Animator(**kwargs)

    def start(self):
        """Read the requested event, run the R1->DL0->DL1 chain and animate it."""
        event = self.reader.get_event(self.req_event)
        # First telescope that has data in this event.
        telid = list(event.r0.tels_with_data)[0]
        geom = CameraGeometry.guess(*event.inst.pixel_pos[0],
                                    event.inst.optical_foclen[0])

        self.r1.calibrate(event)
        self.dl0.reduce(event)
        self.dl1.calibrate(event)

        # Cleaned waveforms for the first channel of the chosen telescope.
        cleaned = event.dl1.tel[telid].cleaned[0]

        output_dir = self.reader.output_directory
        self.animator.plot(cleaned, geom, self.req_event, output_dir)

    def finish(self):
        """No teardown needed."""
        pass
Пример #11
0
class Animator(Component):
    """ctapipe Component that renders a camera-image gif over timeslices,
    alongside the waveforms of two chosen pixels.
    """
    name = 'Animator'

    # NOTE(review): the default end (8) is *smaller* than the default start
    # (40), which makes the slice waveforms[:, start:end] empty and n_frames
    # negative unless both are configured — confirm intended defaults.
    start = Int(40, help='Time to start gif').tag(config=True)
    end = Int(8, help='Time to end gif').tag(config=True)
    p1 = Int(0, help='Pixel 1').tag(config=True)
    p2 = Int(0, help='Pixel 2').tag(config=True)

    def __init__(self, config, tool, **kwargs):
        """Create the figure layout: two waveform axes (left) and the camera
        display (right).
        """
        super().__init__(config=config, parent=tool, **kwargs)

        self.fig = plt.figure(figsize=(24, 10))
        self.ax1 = self.fig.add_subplot(2, 2, 1)
        self.ax2 = self.fig.add_subplot(2, 2, 3)
        self.camera = self.fig.add_subplot(1, 2, 2)

    def plot(self, waveforms, geom, event_id, output_dir):
        """Render and save the animation for one event.

        Parameters
        ----------
        waveforms : ndarray
            Per-pixel waveforms, indexed as [pixel, sample] (2048 pixels
            assumed by the camera image below).
        geom : CameraGeometry
            Camera geometry used for the image display and annotations.
        event_id : int
            Used in the title and output filename.
        output_dir : str
            Directory for the output gif (created if missing).
        """
        camera = CameraDisplay(geom,
                               ax=self.camera,
                               image=np.zeros(2048),
                               cmap='viridis')
        camera.add_colorbar()
        # NOTE(review): .max() reduces the slice to a scalar, so the 60th
        # percentile of it is a no-op — possibly np.percentile(slice, 60)
        # without .max() was intended; confirm before changing.
        max_ = np.percentile(waveforms[:, self.start:self.end].max(), 60)
        camera.set_limits_minmax(0, max_)

        self.ax1.plot(waveforms[self.p1, :])
        self.ax2.plot(waveforms[self.p2, :])

        self.fig.suptitle("Event {}".format(event_id))
        self.ax1.set_title("Pixel: {}".format(self.p1))
        self.ax1.set_xlabel("Time (ns)")
        self.ax1.set_ylabel("Amplitude (p.e.)")
        self.ax2.set_title("Pixel: {}".format(self.p2))
        self.ax2.set_xlabel("Time (ns)")
        self.ax2.set_ylabel("Amplitude (p.e.)")
        camera.colorbar.set_label("Amplitude (p.e.)")

        # Vertical time cursors on both waveform axes, moved per frame.
        line1, = self.ax1.plot([0, 0], self.ax1.get_ylim(), color='r', alpha=1)
        line2, = self.ax2.plot([0, 0], self.ax2.get_ylim(), color='r', alpha=1)

        self.camera.annotate("Pixel: {}".format(self.p1),
                             xy=(geom.pix_x.value[self.p1],
                                 geom.pix_y.value[self.p1]),
                             xycoords='data',
                             xytext=(0.05, 0.98),
                             textcoords='axes fraction',
                             arrowprops=dict(facecolor='red',
                                             width=2,
                                             alpha=0.4),
                             horizontalalignment='left',
                             verticalalignment='top')
        self.camera.annotate("Pixel: {}".format(self.p2),
                             xy=(geom.pix_x.value[self.p2],
                                 geom.pix_y.value[self.p2]),
                             xycoords='data',
                             xytext=(0.05, 0.94),
                             textcoords='axes fraction',
                             arrowprops=dict(facecolor='orange',
                                             width=2,
                                             alpha=0.4),
                             horizontalalignment='left',
                             verticalalignment='top')

        # Create animation
        # div sub-frames per sample; interval keeps wall-clock speed constant.
        div = 5
        increment = 1 / div
        n_frames = int((self.end - self.start) / increment)
        interval = int(500 * increment)

        # Prepare Output
        output_path = join(output_dir, "animation_e{}.gif".format(event_id))
        if not exists(output_dir):
            self.log.info("Creating directory: {}".format(output_dir))
            makedirs(output_dir)
        self.log.info("Output: {}".format(output_path))

        with tqdm(total=n_frames, desc="Creating animation") as pbar:

            def animate(i):
                pbar.update(1)
                t = self.start + (i / div)
                camera.image = waveforms[:, int(t)]
                # NOTE(review): set_xdata with a scalar relies on older
                # matplotlib behavior; newer versions expect a sequence.
                line1.set_xdata(t)
                line2.set_xdata(t)

            anim = animation.FuncAnimation(self.fig,
                                           animate,
                                           frames=n_frames,
                                           interval=interval)
            anim.save(output_path, writer='imagemagick')

        self.log.info("Created animation: {}".format(output_path))
Пример #12
0
class Map(DOMWidget, InteractMixin):
    """Interactive Leaflet map widget.

    Holds a tuple of `layers` and `controls` (synced to the browser via
    widget serialization) plus synced view/interaction options. Supports
    `+=` / `-=` (and `+`) for adding and removing layers or controls.
    """

    @default('layout')
    def _default_layout(self):
        # Fixed default height; stretch horizontally to fill the container.
        return Layout(height='400px', align_self='stretch')

    # Names of the JS view/model classes and their module in jupyter-leaflet.
    _view_name = Unicode('LeafletMapView').tag(sync=True)
    _model_name = Unicode('LeafletMapModel').tag(sync=True)
    _view_module = Unicode('jupyter-leaflet').tag(sync=True)
    _model_module = Unicode('jupyter-leaflet').tag(sync=True)

    # Map options
    center = List(def_loc).tag(sync=True, o=True)
    zoom_start = Int(12).tag(sync=True, o=True)
    zoom = Int(12).tag(sync=True, o=True)
    max_zoom = Int(18).tag(sync=True, o=True)
    min_zoom = Int(1).tag(sync=True, o=True)
    # Interaction options
    dragging = Bool(True).tag(sync=True, o=True)
    touch_zoom = Bool(True).tag(sync=True, o=True)
    scroll_wheel_zoom = Bool(False).tag(sync=True, o=True)
    double_click_zoom = Bool(True).tag(sync=True, o=True)
    box_zoom = Bool(True).tag(sync=True, o=True)
    tap = Bool(True).tag(sync=True, o=True)
    tap_tolerance = Int(15).tag(sync=True, o=True)
    world_copy_jump = Bool(False).tag(sync=True, o=True)
    close_popup_on_click = Bool(True).tag(sync=True, o=True)
    bounce_at_zoom_limits = Bool(True).tag(sync=True, o=True)
    keyboard = Bool(True).tag(sync=True, o=True)
    keyboard_pan_offset = Int(80).tag(sync=True, o=True)
    keyboard_zoom_offset = Int(1).tag(sync=True, o=True)
    inertia = Bool(True).tag(sync=True, o=True)
    inertia_deceleration = Int(3000).tag(sync=True, o=True)
    inertia_max_speed = Int(1500).tag(sync=True, o=True)
    # inertia_threshold = Int(?, o=True).tag(sync=True)
    zoom_control = Bool(True).tag(sync=True, o=True)
    attribution_control = Bool(True).tag(sync=True, o=True)
    # fade_animation = Bool(?).tag(sync=True, o=True)
    # zoom_animation = Bool(?).tag(sync=True, o=True)
    zoom_animation_threshold = Int(4).tag(sync=True, o=True)
    # marker_zoom_animation = Bool(?).tag(sync=True, o=True)

    # Names of all traits tagged o=True, forwarded to the Leaflet map options.
    options = List(trait=Unicode).tag(sync=True)

    @default('options')
    def _default_options(self):
        return [name for name in self.traits(o=True)]

    # Current view bounds, pushed from the frontend.
    _south = Float(def_loc[0]).tag(sync=True)
    _north = Float(def_loc[0]).tag(sync=True)
    _east = Float(def_loc[1]).tag(sync=True)
    _west = Float(def_loc[1]).tag(sync=True)

    default_tiles = Instance(TileLayer, allow_none=True)

    @default('default_tiles')
    def _default_tiles(self):
        return TileLayer()

    @property
    def north(self):
        """Northern latitude of the current view."""
        return self._north

    @property
    def south(self):
        """Southern latitude of the current view."""
        return self._south

    @property
    def east(self):
        """Eastern longitude of the current view."""
        return self._east

    @property
    def west(self):
        """Western longitude of the current view."""
        return self._west

    @property
    def bounds_polygon(self):
        """Current view bounds as four (lat, lon) corner tuples."""
        return [(self.north, self.west), (self.north, self.east),
                (self.south, self.east), (self.south, self.west)]

    @property
    def bounds(self):
        """Current view bounds as [(south, west), (north, east)]."""
        return [(self.south, self.west), (self.north, self.east)]

    def __init__(self, **kwargs):
        """Create the map, install the default tile layer (if any) and wire
        display/message handlers.
        """
        super(Map, self).__init__(**kwargs)
        self.on_displayed(self._fire_children_displayed)
        if self.default_tiles is not None:
            self.layers = (self.default_tiles, )
        self.on_msg(self._handle_leaflet_event)

    def _fire_children_displayed(self, widget, **kwargs):
        # Propagate the displayed event to every layer and control.
        for layer in self.layers:
            layer._handle_displayed(**kwargs)
        for control in self.controls:
            control._handle_displayed(**kwargs)

    layers = Tuple(trait=Instance(Layer)).tag(sync=True,
                                              **widget_serialization)
    layer_ids = List()

    @validate('layers')
    def _validate_layers(self, proposal):
        """Validate layers list.

        Makes sure only one instance of any given layer can exist in the
        layers list.
        """
        self.layer_ids = [l.model_id for l in proposal['value']]
        if len(set(self.layer_ids)) != len(self.layer_ids):
            raise LayerException(
                'duplicate layer detected, only use each layer once')
        return proposal['value']

    def add_layer(self, layer):
        """Add *layer* to the map; raises LayerException if already present."""
        if layer.model_id in self.layer_ids:
            raise LayerException('layer already on map: %r' % layer)
        layer._map = self
        self.layers = tuple([l for l in self.layers] + [layer])
        layer.visible = True

    def remove_layer(self, layer):
        """Remove *layer* from the map; raises LayerException if not present."""
        if layer.model_id not in self.layer_ids:
            raise LayerException('layer not on map: %r' % layer)
        self.layers = tuple(
            [l for l in self.layers if l.model_id != layer.model_id])
        layer.visible = False

    def clear_layers(self):
        """Remove every layer from the map."""
        self.layers = ()

    controls = Tuple(trait=Instance(Control)).tag(sync=True,
                                                  **widget_serialization)
    control_ids = List()

    @validate('controls')
    def _validate_controls(self, proposal):
        """Validate controls list.

        Makes sure only one instance of any given layer can exist in the
        controls list.
        """
        self.control_ids = [c.model_id for c in proposal['value']]
        if len(set(self.control_ids)) != len(self.control_ids):
            raise ControlException(
                'duplicate control detected, only use each control once')
        return proposal['value']

    def add_control(self, control):
        """Add *control* to the map; raises ControlException if already present."""
        if control.model_id in self.control_ids:
            raise ControlException('control already on map: %r' % control)
        control._map = self
        self.controls = tuple([c for c in self.controls] + [control])
        control.visible = True

    def remove_control(self, control):
        """Remove *control* from the map; raises ControlException if not present."""
        if control.model_id not in self.control_ids:
            raise ControlException('control not on map: %r' % control)
        self.controls = tuple(
            [c for c in self.controls if c.model_id != control.model_id])
        control.visible = False

    def clear_controls(self):
        """Remove every control from the map."""
        self.controls = ()

    def __iadd__(self, item):
        """`map += layer_or_control` adds the item to the map."""
        if isinstance(item, Layer):
            self.add_layer(item)
        elif isinstance(item, Control):
            self.add_control(item)
        return self

    def __isub__(self, item):
        """`map -= layer_or_control` removes the item from the map."""
        if isinstance(item, Layer):
            self.remove_layer(item)
        elif isinstance(item, Control):
            self.remove_control(item)
        return self

    def __add__(self, item):
        """`map + layer_or_control` adds the item (mutates and returns self)."""
        if isinstance(item, Layer):
            self.add_layer(item)
        elif isinstance(item, Control):
            self.add_control(item)
        return self

    def _handle_leaflet_event(self, _, content):
        # Placeholder for frontend events; currently ignored.
        pass
Пример #13
0
class CircleMarker(Circle):
    """Circle marker whose radius is specified in screen pixels."""

    _view_name = Unicode('LeafletCircleMarkerView').tag(sync=True)
    _model_name = Unicode('LeafletCircleMarkerModel').tag(sync=True)

    # Radius in pixels (unlike Circle, which measures in meters).
    radius = Int(10, help="radius of circle in pixels").tag(sync=True)
Пример #14
0
class Circle(Path):
    """Circle overlay centered at ``location`` with a radius in meters."""

    _view_name = Unicode('LeafletCircleView').tag(sync=True)
    _model_name = Unicode('LeafletCircleModel').tag(sync=True)

    # Center of the circle as [lat, lng]; def_loc supplies the default.
    location = List(def_loc).tag(sync=True)
    radius = Int(1000, help="radius of circle in meters").tag(sync=True)
Пример #15
0
class DataGenerator(Configurable):
    """Batch generator for query/document ranking data.

    Reads tab-separated lines of comma-separated token ids and yields
    zero-padded numpy batches keyed by the configured field names.

    Fix/cleanup versus the previous revision: the unused ``l_ids``
    accumulator (created only in the post-yield reset) is removed, and the
    repeated parse/pad/batch-assembly code is factored into helpers.
    """

    # params for data generator
    max_q_len = Int(10, help='max q len').tag(config=True)
    max_d_len = Int(500, help='max document len').tag(config=True)
    q_name = Unicode('q')
    d_name = Unicode('d')
    q_str_name = Unicode('q_str')
    q_lens_name = Unicode('q_lens')
    aux_d_name = Unicode('d_aux')
    vocabulary_size = Int(2000000).tag(config=True)

    def __init__(self, **kwargs):
        # init the data generator
        super(DataGenerator, self).__init__(**kwargs)
        print("generator's vocabulary size: ", self.vocabulary_size)

    def _parse_ids(self, field):
        """Parse a comma-separated id string, dropping out-of-vocabulary ids."""
        return np.array([int(t) for t in field.split(',')
                         if int(t) < self.vocabulary_size])

    @staticmethod
    def _pad(ids, max_len):
        """Return *ids* truncated and zero-padded to a float vector of *max_len*."""
        v = np.zeros(max_len)
        n = min(ids.shape[0], max_len)
        v[:n] = ids[:n]
        return v

    def _pairwise_batch(self, l_q, l_d, l_d_aux, l_q_lens, l_y, l_q_str):
        """Assemble one (X, Y) training batch from the accumulators."""
        # NOTE: labels are cast to int here (1.0 -> 1), matching the
        # original behavior (dtype=int on the label array).
        X = {self.q_name: np.array(l_q, dtype=int),
             self.d_name: np.array(l_d, dtype=int),
             self.aux_d_name: np.array(l_d_aux, dtype=int),
             self.q_lens_name: np.array(l_q_lens, dtype=int),
             self.q_str_name: l_q_str}
        Y = np.array(l_y, dtype=int)
        return X, Y

    def pairwise_reader(self, pair_stream, batch_size, with_idf=False):
        """Yield (X, Y) training batches from a pairwise stream.

        Each input line holds three tab-separated fields: query ids,
        document ids, and auxiliary (second) document ids. The label for
        every pair is 1.0 (cast to int in the batch).

        :param pair_stream: iterable of tab-separated training lines
        :param batch_size: number of examples per yielded batch
        :param with_idf: unused; kept for interface compatibility
        """
        l_q, l_q_str, l_d, l_d_aux, l_y, l_q_lens = [], [], [], [], [], []
        for line in pair_stream:
            cols = line.strip().split('\t')
            y = float(1.0)
            l_q_str.append(cols[0])
            q = self._parse_ids(cols[0])
            t1 = self._parse_ids(cols[1])
            t2 = self._parse_ids(cols[2])

            # truncate / zero-pad each sequence to its fixed length
            l_q.append(self._pad(q, self.max_q_len))
            l_d.append(self._pad(t1, self.max_d_len))
            l_d_aux.append(self._pad(t2, self.max_d_len))
            l_y.append(y)
            l_q_lens.append(len(q))

            if len(l_q) >= batch_size:
                yield self._pairwise_batch(l_q, l_d, l_d_aux,
                                           l_q_lens, l_y, l_q_str)
                l_q, l_q_str, l_d, l_d_aux, l_y, l_q_lens = \
                    [], [], [], [], [], []
        if l_q:
            # flush the final partial batch
            yield self._pairwise_batch(l_q, l_d, l_d_aux,
                                       l_q_lens, l_y, l_q_str)

    def _test_batch(self, l_q, l_d, l_q_lens):
        """Assemble one X test batch (no labels) from the accumulators."""
        return {self.q_name: np.array(l_q, dtype=int),
                self.d_name: np.array(l_d, dtype=int),
                self.q_lens_name: np.array(l_q_lens, dtype=int)}

    def test_pairwise_reader(self, pair_stream, batch_size):
        """Yield X batches (no labels) from (query, document) test lines.

        :param pair_stream: iterable of tab-separated test lines
        :param batch_size: number of examples per yielded batch
        """
        l_q, l_d, l_q_lens = [], [], []

        for line in pair_stream:
            cols = line.strip().split('\t')
            q = self._parse_ids(cols[0])
            t = self._parse_ids(cols[1])

            l_q.append(self._pad(q, self.max_q_len))
            l_d.append(self._pad(t, self.max_d_len))
            l_q_lens.append(len(q))

            if len(l_q) >= batch_size:
                yield self._test_batch(l_q, l_d, l_q_lens)
                l_q, l_d, l_q_lens = [], [], []
        if l_q:
            # flush the final partial batch
            yield self._test_batch(l_q, l_d, l_q_lens)
Пример #16
0
class MeasureControl(Control):
    """Leaflet control for interactively measuring lengths and areas."""

    _view_name = Unicode('LeafletMeasureControlView').tag(sync=True)
    _model_name = Unicode('LeafletMeasureControlModel').tag(sync=True)

    # NOTE(review): these are mutable class-level containers shared by every
    # MeasureControl instance. add_length_unit/add_area_unit mutate them in
    # place, so custom units registered through one instance become valid
    # Enum values for all instances -- confirm this sharing is intended.
    _length_units = ['feet', 'meters', 'miles', 'kilometers']
    _area_units = ['acres', 'hectares', 'sqfeet', 'sqmeters', 'sqmiles']
    _custom_units_dict = {}
    _custom_units = Dict().tag(sync=True)

    primary_length_unit = Enum(
        values=_length_units,
        default_value='feet',
        help="""Possible values are feet, meters, miles, kilometers or any user
                defined unit"""
    ).tag(sync=True, o=True)

    secondary_length_unit = Enum(
        values=_length_units,
        default_value=None,
        allow_none=True,
        help="""Possible values are feet, meters, miles, kilometers or any user
                defined unit"""
    ).tag(sync=True, o=True)

    primary_area_unit = Enum(
        values=_area_units,
        default_value='acres',
        help="""Possible values are acres, hectares, sqfeet, sqmeters, sqmiles
                or any user defined unit"""
    ).tag(sync=True, o=True)

    secondary_area_unit = Enum(
        values=_area_units,
        default_value=None,
        allow_none=True,
        help="""Possible values are acres, hectares, sqfeet, sqmeters, sqmiles
                or any user defined unit"""
    ).tag(sync=True, o=True)

    # Colors used while drawing and after a measurement is completed.
    active_color = Color('#ABE67E').tag(sync=True, o=True)
    completed_color = Color('#C8F2BE').tag(sync=True, o=True)

    # Options forwarded to the Leaflet result popup.
    popup_options = Dict({
      'className': 'leaflet-measure-resultpopup',
      'autoPanPadding': [10, 10]
    }).tag(sync=True, o=True)

    capture_z_index = Int(10000).tag(sync=True, o=True)

    def add_length_unit(self, name, factor, decimals=0):
        """Register a custom length unit and make it a valid trait value.

        :param name: unit name (becomes usable in the length unit traits)
        :param factor: conversion factor for the unit (presumably relative
            to meters -- confirm against the leaflet-measure docs)
        :param decimals: number of decimals displayed for the unit
        """
        self._length_units.append(name)
        self._add_unit(name, factor, decimals)

    def add_area_unit(self, name, factor, decimals=0):
        """Register a custom area unit and make it a valid trait value.

        :param name: unit name (becomes usable in the area unit traits)
        :param factor: conversion factor for the unit (presumably relative
            to square meters -- confirm against the leaflet-measure docs)
        :param decimals: number of decimals displayed for the unit
        """
        self._area_units.append(name)
        self._add_unit(name, factor, decimals)

    def _add_unit(self, name, factor, decimals):
        # Record the unit, then reassign the whole synced dict: traitlets
        # only notices reassignment, not in-place mutation of the dict.
        self._custom_units_dict[name] = {
            'factor': factor,
            'display': name,
            'decimals': decimals
        }
        self._custom_units = dict(**self._custom_units_dict)
Пример #17
0
class Ripple(Renderer):
    """Renderer that draws expanding rings of color from pressed keys."""

    # meta
    meta = RendererMeta('Ripples', 'Ripples of color when keys are pressed',
                        'Steve Kondik', '1.0')

    # configurable traits
    ripple_width = Int(default_value=DEFAULT_WIDTH, min=1, max=5).tag(config=True)
    speed = Int(default_value=DEFAULT_SPEED, min=1, max=9).tag(config=True)
    preset = ColorPresetTrait(ColorScheme, default_value=None).tag(config=True)
    random = Bool(True).tag(config=True)
    color = ColorTrait().tag(config=True)

    def __init__(self, *args, **kwargs):
        """Set up the default rainbow color generator and animation timing."""
        super(Ripple, self).__init__(*args, **kwargs)

        self._generator = ColorUtils.rainbow_generator()
        self._max_distance = None  # frame diagonal, computed in init()
        self.key_expire_time = DEFAULT_SPEED * EXPIRE_TIME_FACTOR

        self.fps = 30

    @observe('speed')
    def _set_speed(self, change):
        # Higher speed means keys expire sooner (shorter ripple lifetime).
        self.key_expire_time = change.new * EXPIRE_TIME_FACTOR

    def _process_events(self, events):
        """Assign the next generator color to any event without one."""
        if self._generator is None:
            return None

        for event in events:
            if COLOR_KEY not in event.data:
                event.data[COLOR_KEY] = next(self._generator)

    @staticmethod
    def _ease(n):
        """Quintic ease-in/out of *n*, clamped to [0.0, 1.0]."""
        n = clamp(n, 0.0, 1.0)
        n = 2 * n
        if n < 1:
            return 0.5 * n**5

        n = n - 2
        return 0.5 * (n**5 + 2)

    def _draw_circles(self, layer, radius, event):
        """Draw up to ``ripple_width`` concentric rings for one key event."""
        width = self.ripple_width
        if COLOR_KEY not in event.data:
            return

        if event.coords is None or len(event.coords) == 0:
            self.logger.error('No coordinates available: %s', event)
            return

        # Compute the per-event color scheme once and cache it on the event.
        if SCHEME_KEY in event.data:
            colors = event.data[SCHEME_KEY]
        else:
            color = event.data[COLOR_KEY]
            if width > 1:
                colors = ColorUtils.color_scheme(color=color, base_color=color, steps=width)
            else:
                colors = [color]
            event.data[SCHEME_KEY] = colors

        for circle_num in range(width - 1, -1, -1):
            if radius - circle_num < 0:
                continue

            rad = radius - circle_num
            # Fade the ring out as it approaches the maximum distance.
            a = Ripple._ease(1.0 - (rad / self._max_distance))
            cc = (*colors[circle_num].rgb, colors[circle_num].alpha * a)

            for coord in event.coords:
                # The 1.33 divisor squashes the ellipse vertically --
                # presumably to match the key-cell aspect ratio; confirm.
                layer.ellipse(coord.y, coord.x, rad / 1.33, rad, color=cc)

    async def draw(self, layer, timestamp):
        """
        Draw the next layer
        """

        # Yield until the queue becomes active
        events = await self.get_input_events()

        if len(events) > 0:
            self._process_events(events)

            # paint circles in descending timestamp order (oldest first)
            events = sorted(events, key=operator.itemgetter(0), reverse=True)

            for event in events:
                distance = 1.0 - event.percent_complete
                if distance < 1.0:
                    radius = self._max_distance * distance

                    self._draw_circles(layer, radius, event)

            return True

        return False

    @observe('preset', 'color', 'background_color', 'random')
    def _update_colors(self, change=None):
        """Rebuild the color generator when a color-related trait changes.

        Bug fix: the previous version dereferenced ``change.new`` even when
        invoked with the default ``change=None``, raising AttributeError;
        a None change is now ignored.
        """
        if change is None:
            return

        with self.hold_trait_notifications():
            if change.new is None:
                return

            if change.name == 'preset':
                self.color = 'black'
                self.random = False
                self._generator = ColorUtils.color_generator(list(change.new.value))
            elif change.name == 'random' and change.new:
                self.preset = None
                self.color = 'black'
                self._generator = ColorUtils.rainbow_generator()
            else:
                self.preset = None
                self.random = False
                base_color = self.background_color
                # Fully opaque black background means "no base color".
                if base_color == (0, 0, 0, 1):
                    base_color = None
                self._generator = ColorUtils.scheme_generator(
                    color=self.color, base_color=base_color)

    def init(self, frame) -> bool:
        """Cache the frame diagonal; succeed only if key input is available."""
        if not self.has_key_input:
            return False

        self._max_distance = math.hypot(frame.width, frame.height)

        return True
Пример #18
0
class Layer(Widget, InteractMixin):
    """Base widget for Leaflet layers, with mouse-event dispatching."""

    _view_name = Unicode('LeafletLayerView').tag(sync=True)
    _model_name = Unicode('LeafletLayerModel').tag(sync=True)
    _view_module = Unicode('jupyter-leaflet').tag(sync=True)
    _model_module = Unicode('jupyter-leaflet').tag(sync=True)

    _view_module_version = Unicode(EXTENSION_VERSION).tag(sync=True)
    _model_module_version = Unicode(EXTENSION_VERSION).tag(sync=True)

    name = Unicode('').tag(sync=True)
    base = Bool(False).tag(sync=True)
    bottom = Bool(False).tag(sync=True)
    popup = Instance(Widget, allow_none=True, default_value=None).tag(sync=True, **widget_serialization)
    popup_min_width = Int(50).tag(sync=True)
    popup_max_width = Int(300).tag(sync=True)
    popup_max_height = Int(default_value=None, allow_none=True).tag(sync=True)

    options = List(trait=Unicode).tag(sync=True)

    def __init__(self, **kwargs):
        """Initialize the layer and subscribe to front-end comm messages."""
        super(Layer, self).__init__(**kwargs)
        self.on_msg(self._handle_mouse_events)

    @default('options')
    def _default_options(self):
        # Every trait tagged o=True is exposed as a Leaflet option.
        return [name for name in self.traits(o=True)]

    # Event handling
    _click_callbacks = Instance(CallbackDispatcher, ())
    _dblclick_callbacks = Instance(CallbackDispatcher, ())
    _mousedown_callbacks = Instance(CallbackDispatcher, ())
    _mouseup_callbacks = Instance(CallbackDispatcher, ())
    _mouseover_callbacks = Instance(CallbackDispatcher, ())
    _mouseout_callbacks = Instance(CallbackDispatcher, ())

    def _handle_mouse_events(self, _, content, buffers):
        """Route a front-end mouse event to its matching dispatcher."""
        dispatchers = {
            'click': self._click_callbacks,
            'dblclick': self._dblclick_callbacks,
            'mousedown': self._mousedown_callbacks,
            'mouseup': self._mouseup_callbacks,
            'mouseover': self._mouseover_callbacks,
            'mouseout': self._mouseout_callbacks,
        }
        callbacks = dispatchers.get(content.get('type', ''))
        if callbacks is not None:
            callbacks(**content)

    def on_click(self, callback, remove=False):
        """Register (or unregister with remove=True) a click callback."""
        self._click_callbacks.register_callback(callback, remove=remove)

    def on_dblclick(self, callback, remove=False):
        """Register (or unregister with remove=True) a double-click callback."""
        self._dblclick_callbacks.register_callback(callback, remove=remove)

    def on_mousedown(self, callback, remove=False):
        """Register (or unregister with remove=True) a mousedown callback."""
        self._mousedown_callbacks.register_callback(callback, remove=remove)

    def on_mouseup(self, callback, remove=False):
        """Register (or unregister with remove=True) a mouseup callback."""
        self._mouseup_callbacks.register_callback(callback, remove=remove)

    def on_mouseover(self, callback, remove=False):
        """Register (or unregister with remove=True) a mouseover callback."""
        self._mouseover_callbacks.register_callback(callback, remove=remove)

    def on_mouseout(self, callback, remove=False):
        """Register (or unregister with remove=True) a mouseout callback."""
        self._mouseout_callbacks.register_callback(callback, remove=remove)
Пример #19
0
class BaseNN(Configurable):
    """Base class with shared helpers for kernel-pooling neural ranking models."""

    n_bins = Int(
        11, help="number of kernels (including exact match)").tag(config=True)
    weight_size = Int(50, help="dimension of the first layer").tag(config=True)

    def __init__(self, **kwargs):
        super(BaseNN, self).__init__(**kwargs)

    @staticmethod
    def kernal_mus(n_kernels, use_exact):
        """
        Get the mu for each Gaussian kernel; mu is the middle of each bin.

        :param n_kernels: number of kernels (including exact match); the
            first one is the exact-match kernel
        :param use_exact: if True, the first mu is 1 (exact match);
            otherwise it is 2
        :return: l_mu, a list of mu values
        """
        if use_exact:
            l_mu = [1]
        else:
            l_mu = [2]
        if n_kernels == 1:
            return l_mu

        bin_size = 2.0 / (n_kernels - 1)  # score range from [-1, 1]
        l_mu.append(1 - bin_size / 2)  # mu: middle of the bin
        for i in np.arange(1, n_kernels - 1):
            l_mu.append(l_mu[i] - bin_size)
        return l_mu

    @staticmethod
    def kernel_sigmas(n_kernels, lamb, use_exact):
        """
        Get sigmas for each Gaussian kernel.

        :param n_kernels: number of kernels (including exact match)
        :param lamb: width multiplier for soft-match kernels
            (sigma = bin_size * lamb)
        :param use_exact: unused; kept for interface compatibility
        :return: l_sigma, a list of sigma values

        Bug fix: the n_kernels == 1 early return now happens before
        computing bin_size, which previously raised ZeroDivisionError
        (2.0 / 0) for a single kernel.
        """
        l_sigma = [0.00001]  # for exact match. small variance -> exact match
        if n_kernels == 1:
            return l_sigma

        bin_size = 2.0 / (n_kernels - 1)
        l_sigma += [bin_size * lamb] * (n_kernels - 1)
        return l_sigma

    @staticmethod
    def weight_variable(shape, name):
        """Create a TF variable with Xavier/Glorot uniform initialization."""
        # NOTE(review): tf.random_uniform is the TF 1.x API
        # (tf.random.uniform in TF 2.x) -- confirm the pinned TF version.
        tmp = np.sqrt(6.0) / np.sqrt(shape[0] + shape[1])
        initial = tf.random_uniform(shape, minval=-tmp, maxval=tmp)
        return tf.Variable(initial, name=name)

    @staticmethod
    def re_pad(D, batch_size):
        """Clamp negative ids to 0 and zero-pad D up to batch_size rows."""
        D = np.array(D)
        D[D < 0] = 0
        if len(D) < batch_size:
            tmp = np.zeros((batch_size - len(D), D.shape[1]))
            D = np.concatenate((D, tmp), axis=0)
        return D

    def gen_mask(self, Q, D, use_exact=True):
        """
        Generate mask for the batch. Mask padding and OOV terms.
        Exact matches are also masked if use_exact == False.

        :param Q: a batch of queries, [batch_size, max_len_q]
        :param D: a batch of documents, [batch_size, max_len_d]
        :param use_exact: mask exact matches if set False.
        :return: a mask of shape [batch_size, max_len_q, max_len_d].
        """
        # NOTE(review): self.batch_size / self.max_q_len / self.max_d_len
        # are not defined in this class; they must be provided by
        # subclasses -- confirm.
        mask = np.zeros((self.batch_size, self.max_q_len, self.max_d_len))
        for b in range(len(Q)):
            for q in range(len(Q[b])):
                if Q[b, q] > 0:
                    mask[b, q, D[b] > 0] = 1
                    if not use_exact:
                        mask[b, q, D[b] == Q[b, q]] = 0
        return mask
Пример #20
0
class Map(DOMWidget, InteractMixin):
    """Jupyter widget rendering a Leaflet map.

    Layers and controls live in validated tuple traits; ``+=`` / ``-=``
    operator sugar delegates to the add/remove methods.
    """

    _view_name = Unicode('LeafletMapView').tag(sync=True)
    _model_name = Unicode('LeafletMapModel').tag(sync=True)
    _view_module = Unicode('jupyter-leaflet').tag(sync=True)
    _model_module = Unicode('jupyter-leaflet').tag(sync=True)

    _view_module_version = Unicode(EXTENSION_VERSION).tag(sync=True)
    _model_module_version = Unicode(EXTENSION_VERSION).tag(sync=True)

    # Map options (traits tagged o=True are forwarded as Leaflet options)
    center = List(def_loc).tag(sync=True, o=True)
    zoom_start = Int(12).tag(sync=True, o=True)
    zoom = Int(12).tag(sync=True, o=True)
    max_zoom = Int(18).tag(sync=True, o=True)
    min_zoom = Int(1).tag(sync=True, o=True)
    interpolation = Unicode('bilinear').tag(sync=True, o=True)
    crs = Enum(values=allowed_crs, default_value='EPSG3857').tag(sync=True)

    # Specification of the basemap
    basemap = Dict(default_value=dict(
            url='https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png',
            max_zoom=19,
            attribution='Map data (c) <a href="https://openstreetmap.org">OpenStreetMap</a> contributors'
        )).tag(sync=True, o=True)
    modisdate = Unicode('yesterday').tag(sync=True)

    # Interaction options
    dragging = Bool(True).tag(sync=True, o=True)
    touch_zoom = Bool(True).tag(sync=True, o=True)
    scroll_wheel_zoom = Bool(False).tag(sync=True, o=True)
    double_click_zoom = Bool(True).tag(sync=True, o=True)
    box_zoom = Bool(True).tag(sync=True, o=True)
    tap = Bool(True).tag(sync=True, o=True)
    tap_tolerance = Int(15).tag(sync=True, o=True)
    world_copy_jump = Bool(False).tag(sync=True, o=True)
    close_popup_on_click = Bool(True).tag(sync=True, o=True)
    bounce_at_zoom_limits = Bool(True).tag(sync=True, o=True)
    keyboard = Bool(True).tag(sync=True, o=True)
    keyboard_pan_offset = Int(80).tag(sync=True, o=True)
    keyboard_zoom_offset = Int(1).tag(sync=True, o=True)
    inertia = Bool(True).tag(sync=True, o=True)
    inertia_deceleration = Int(3000).tag(sync=True, o=True)
    inertia_max_speed = Int(1500).tag(sync=True, o=True)
    # inertia_threshold = Int(?, o=True).tag(sync=True)
    # fade_animation = Bool(?).tag(sync=True, o=True)
    # zoom_animation = Bool(?).tag(sync=True, o=True)
    zoom_animation_threshold = Int(4).tag(sync=True, o=True)
    # marker_zoom_animation = Bool(?).tag(sync=True, o=True)
    fullscreen = Bool(False).tag(sync=True, o=True)

    options = List(trait=Unicode).tag(sync=True)

    style = InstanceDict(MapStyle).tag(sync=True, **widget_serialization)
    default_style = InstanceDict(MapStyle).tag(sync=True, **widget_serialization)
    dragging_style = InstanceDict(MapStyle).tag(sync=True, **widget_serialization)

    # NOTE(review): these control widgets are created once at
    # class-definition time and are therefore shared by every Map
    # instance -- confirm whether per-instance controls were intended.
    zoom_control = Bool(True)
    zoom_control_instance = ZoomControl()

    attribution_control = Bool(True)
    attribution_control_instance = AttributionControl(position='bottomright')

    @default('dragging_style')
    def _default_dragging_style(self):
        # Show a "move" cursor while the user drags the map.
        return {'cursor': 'move'}

    @default('options')
    def _default_options(self):
        # Every trait tagged o=True is exposed as a Leaflet option.
        return [name for name in self.traits(o=True)]

    # Current viewport edges, pushed from the front-end (read-only).
    south = Float(def_loc[0], read_only=True).tag(sync=True)
    north = Float(def_loc[0], read_only=True).tag(sync=True)
    east = Float(def_loc[1], read_only=True).tag(sync=True)
    west = Float(def_loc[1], read_only=True).tag(sync=True)

    layers = Tuple(trait=Instance(Layer)).tag(sync=True, **widget_serialization)

    @default('layers')
    def _default_layers(self):
        # Start with a single tile layer built from the configured basemap.
        return (basemap_to_tiles(self.basemap, self.modisdate, base=True),)

    bounds = Tuple(read_only=True)
    bounds_polygon = Tuple(read_only=True)

    @observe('south', 'north', 'east', 'west')
    def _observe_bounds(self, change):
        # Keep the derived bounds / bounds_polygon traits in sync with
        # the individual edge traits pushed from the front-end.
        self.set_trait('bounds', ((self.south, self.west),
                                  (self.north, self.east)))
        self.set_trait('bounds_polygon', ((self.north, self.west),
                                          (self.north, self.east),
                                          (self.south, self.east),
                                          (self.south, self.west)))

    def __init__(self, **kwargs):
        """Create the map and attach the default zoom/attribution controls."""
        super(Map, self).__init__(**kwargs)
        self.on_displayed(self._fire_children_displayed)
        self.on_msg(self._handle_leaflet_event)

        if self.zoom_control:
            self.add_control(self.zoom_control_instance)

        if self.attribution_control:
            self.add_control(self.attribution_control_instance)

    @observe('zoom_control')
    def observe_zoom_control(self, change):
        # Attach or detach the shared zoom control when the flag changes.
        if change['new']:
            self.add_control(self.zoom_control_instance)
        else:
            if self.zoom_control_instance in self.controls:
                self.remove_control(self.zoom_control_instance)

    @observe('attribution_control')
    def observe_attribution_control(self, change):
        # Attach or detach the shared attribution control when the flag changes.
        if change['new']:
            self.add_control(self.attribution_control_instance)
        else:
            if self.attribution_control_instance in self.controls:
                self.remove_control(self.attribution_control_instance)

    def _fire_children_displayed(self, widget, **kwargs):
        # Propagate the display event to every child layer and control.
        for layer in self.layers:
            layer._handle_displayed(**kwargs)
        for control in self.controls:
            control._handle_displayed(**kwargs)

    # Model ids of the current layers, maintained by the layers validator.
    _layer_ids = List()

    @validate('layers')
    def _validate_layers(self, proposal):
        '''Validate layers list.

        Makes sure only one instance of any given layer can exist in the
        layers list.
        '''
        self._layer_ids = [l.model_id for l in proposal.value]
        if len(set(self._layer_ids)) != len(self._layer_ids):
            raise LayerException('duplicate layer detected, only use each layer once')
        return proposal.value

    def add_layer(self, layer):
        """Add *layer* (a Layer instance or a basemap dict) to the map."""
        if isinstance(layer, dict):
            layer = basemap_to_tiles(layer)
        if layer.model_id in self._layer_ids:
            raise LayerException('layer already on map: %r' % layer)
        self.layers = tuple([l for l in self.layers] + [layer])

    def remove_layer(self, layer):
        """Remove *layer* from the map; raises LayerException if absent."""
        if layer.model_id not in self._layer_ids:
            raise LayerException('layer not on map: %r' % layer)
        self.layers = tuple([l for l in self.layers if l.model_id != layer.model_id])

    def substitute_layer(self, old, new):
        """Replace layer *old* with *new* (a Layer or basemap dict) in place."""
        if isinstance(new, dict):
            new = basemap_to_tiles(new)
        if old.model_id not in self._layer_ids:
            raise LayerException('Could not substitute layer: layer not on map.')
        self.layers = tuple([new if l.model_id == old.model_id else l for l in self.layers])

    def clear_layers(self):
        """Remove every layer from the map."""
        self.layers = ()

    controls = Tuple(trait=Instance(Control)).tag(sync=True, **widget_serialization)
    # Model ids of the current controls, maintained by the controls validator.
    _control_ids = List()

    @validate('controls')
    def _validate_controls(self, proposal):
        '''Validate controls list.

        Makes sure only one instance of any given layer can exist in the
        controls list.
        '''
        self._control_ids = [c.model_id for c in proposal.value]
        if len(set(self._control_ids)) != len(self._control_ids):
            raise ControlException('duplicate control detected, only use each control once')
        return proposal.value

    def add_control(self, control):
        """Add *control* to the map; raises ControlException if already present."""
        if control.model_id in self._control_ids:
            raise ControlException('control already on map: %r' % control)
        self.controls = tuple([c for c in self.controls] + [control])

    def remove_control(self, control):
        """Remove *control* from the map; raises ControlException if absent."""
        if control.model_id not in self._control_ids:
            raise ControlException('control not on map: %r' % control)
        self.controls = tuple([c for c in self.controls if c.model_id != control.model_id])

    def clear_controls(self):
        """Remove every control from the map."""
        self.controls = ()

    def __iadd__(self, item):
        """Support ``map += layer`` and ``map += control``."""
        if isinstance(item, Layer):
            self.add_layer(item)
        elif isinstance(item, Control):
            self.add_control(item)
        return self

    def __isub__(self, item):
        """Support ``map -= layer`` and ``map -= control``."""
        if isinstance(item, Layer):
            self.remove_layer(item)
        elif isinstance(item, Control):
            self.remove_control(item)
        return self

    def __add__(self, item):
        """Support ``map + layer``; NOTE(review): mutates and returns self."""
        if isinstance(item, Layer):
            self.add_layer(item)
        elif isinstance(item, Control):
            self.add_control(item)
        return self

    # Event handling
    _interaction_callbacks = Instance(CallbackDispatcher, ())

    def _handle_leaflet_event(self, _, content, buffers):
        # Dispatch front-end 'interaction' messages to registered callbacks.
        if content.get('event', '') == 'interaction':
            self._interaction_callbacks(**content)

    def on_interaction(self, callback, remove=False):
        """Register (or unregister with remove=True) an interaction callback."""
        self._interaction_callbacks.register_callback(callback, remove=remove)
Пример #21
0
class Main(Experiment):
    """Evaluate a trained COCO set-operations (LaSO) model.

    Loads the base embedding network, classifier and set-operations module
    from checkpoints, classifies validation image pairs and the outputs of
    the set operations, and reports macro-averaged average precision.
    """

    description = Unicode(u"Calculate precision-recall accuracy of trained coco model.")

    #
    # Run setup
    #
    batch_size = Int(256, config=True, help="Batch size. default: 256")
    num_workers = Int(8, config=True, help="Number of workers to use for data loading. default: 8")
    device = Unicode("cuda", config=True, help="Use `cuda` backend. default: cuda")

    #
    # Hyper parameters.
    #
    unseen = Bool(False, config=True, help="Test on unseen classes.")
    skip_tests = Int(1, config=True, help="How many test pairs to skip? for better runtime. default: 1")
    debug_size = Int(-1, config=True, help="Reduce dataset sizes. This is useful when developing the script. default -1")

    #
    # Resume previous run parameters.
    #
    resume_path = Unicode(u"/dccstor/alfassy/finalLaSO/code_release/paperModels", config=True, help="Resume from checkpoint file (requires using also '--resume_epoch'.")
    resume_epoch = Int(0, config=True, help="Epoch to resume (requires using also '--resume_path'.")
    coco_path = Unicode(u"/tmp/aa/coco", config=True, help="path to local coco dataset path")
    init_inception = Bool(True, config=True, help="Initialize the inception networks using the paper's base network.")

    #
    # Network hyper parameters
    #
    base_network_name = Unicode("Inception3", config=True, help="Name of base network to use.")
    avgpool_kernel = Int(10, config=True,
                         help="Size of the last avgpool layer in the Resnet. Should match the cropsize.")
    classifier_name = Unicode("Inception3Classifier", config=True, help="Name of classifier to use.")
    sets_network_name = Unicode("SetOpsResModule", config=True, help="Name of setops module to use.")
    sets_block_name = Unicode("SetopResBlock_v1", config=True, help="Name of setops network to use.")
    sets_basic_block_name = Unicode("SetopResBasicBlock", config=True,
                                    help="Name of the basic setops block to use (where applicable).")
    ops_layer_num = Int(1, config=True, help="Ops Module layers num.")
    ops_latent_dim = Int(1024, config=True, help="Ops Module inner latent dim.")
    setops_dropout = Float(0, config=True, help="Dropout ratio of setops module.")
    crop_size = Int(299, config=True, help="Size of input crop (Resnet 224, inception 299).")
    scale_size = Int(350, config=True, help="Size of input scale for data augmentation. default: 350")
    paper_reproduce = Bool(False, config=True, help="Use paper reproduction settings. default: False")
    discriminator_name = Unicode("AmitDiscriminator", config=True,
                                 help="Name of discriminator (unseen classifier) to use. default: AmitDiscriminator")
    embedding_dim = Int(2048, config=True, help="Dimensionality of the LaSO space. default:2048")
    classifier_latent_dim = Int(2048, config=True, help="Dimensionality of the classifier latent space. default:2048")

    def run(self):
        """Run the evaluation and pickle per-operation AP results.

        For each validation pair, classifies the raw embeddings and the six
        set-operation outputs, computes the expected multi-hot targets via
        boolean label algebra, then writes average-precision scores and
        precision-recall curves to ``self.results_path``.
        """

        #
        # Setup the model
        #
        base_model, classifier, setops_model = self.setup_model()

        base_model.to(self.device)
        classifier.to(self.device)
        setops_model.to(self.device)

        # Inference only: disable dropout / batch-norm updates.
        base_model.eval()
        classifier.eval()
        setops_model.eval()

        #
        # Load the dataset
        #
        pair_dataset, pair_loader, pair_dataset_sub, pair_loader_sub = self.setup_datasets()

        logging.info("Calcualting classifications:")
        output_a_list, output_b_list, fake_a_list, fake_b_list, target_a_list, target_b_list = [], [], [], [], [], []
        a_S_b_list, b_S_a_list, a_U_b_list, b_U_a_list, a_I_b_list, b_I_a_list = [], [], [], [], [], []
        target_a_I_b_list, target_a_U_b_list, target_a_S_b_list, target_b_S_a_list = [], [], [], []
        with torch.no_grad():
            for batch in tqdm(pair_loader):
                input_a, input_b, target_a, target_b = _prepare_batch(batch, device=self.device)

                #
                # Apply the classification model
                #
                embed_a = base_model(input_a).view(input_a.size(0), -1)
                embed_b = base_model(input_b).view(input_b.size(0), -1)
                output_a = classifier(embed_a)
                output_b = classifier(embed_b)

                #
                # Apply the setops model.
                #
                # NOTE(review): outputs 2..7 are assumed to be the six
                # set-operation embeddings (a-b, b-a, a|b, b|a, a&b, b&a)
                # per the variable names below — confirm against the
                # setops module's return order.
                outputs_setopt = setops_model(embed_a, embed_b)
                a_S_b, b_S_a, a_U_b, b_U_a, a_I_b, b_I_a = \
                    [classifier(o) for o in outputs_setopt[2:8]]

                output_a_list.append(output_a.cpu().numpy())
                output_b_list.append(output_b.cpu().numpy())
                # fake_a_list.append(fake_a.cpu().numpy())
                # fake_b_list.append(fake_b.cpu().numpy())
                a_S_b_list.append(a_S_b.cpu().numpy())
                b_S_a_list.append(b_S_a.cpu().numpy())
                a_U_b_list.append(a_U_b.cpu().numpy())
                b_U_a_list.append(b_U_a.cpu().numpy())
                a_I_b_list.append(a_I_b.cpu().numpy())
                b_I_a_list.append(b_I_a.cpu().numpy())

                #
                # Calculate the target setops operations
                #
                target_a_list.append(target_a.cpu().numpy())
                target_b_list.append(target_b.cpu().numpy())

                # NOTE(review): casting through torch.cuda.ByteTensor forces
                # CUDA even though ``device`` is configurable — confirm
                # before running with device="cpu".
                target_a = target_a.type(torch.cuda.ByteTensor)
                target_b = target_b.type(torch.cuda.ByteTensor)

                # Boolean label algebra for the expected multi-hot targets.
                target_a_I_b = target_a & target_b
                target_a_U_b = target_a | target_b
                target_a_S_b = target_a & ~target_a_I_b
                target_b_S_a = target_b & ~target_a_I_b

                target_a_I_b_list.append(target_a_I_b.type(torch.cuda.FloatTensor).cpu().numpy())
                target_a_U_b_list.append(target_a_U_b.type(torch.cuda.FloatTensor).cpu().numpy())
                target_a_S_b_list.append(target_a_S_b.type(torch.cuda.FloatTensor).cpu().numpy())
                target_b_S_a_list.append(target_b_S_a.type(torch.cuda.FloatTensor).cpu().numpy())

        logging.info("Calculating classifications for subtraction independently:")
        # Subtraction is re-evaluated on its dedicated loader; the
        # subtraction results gathered above are discarded and replaced.
        a_S_b_list, b_S_a_list = [], []
        target_a_S_b_list, target_b_S_a_list = [], []
        with torch.no_grad():
            for batch in tqdm(pair_loader_sub):
                input_a, input_b, target_a, target_b = _prepare_batch(batch, device=self.device)

                #
                # Apply the classification model
                #
                embed_a = base_model(input_a).view(input_a.size(0), -1)
                embed_b = base_model(input_b).view(input_b.size(0), -1)
                #
                # Apply the setops model.
                #
                outputs_setopt = setops_model(embed_a, embed_b)
                a_S_b, b_S_a, _, _, _, _ = \
                    [classifier(o) for o in outputs_setopt[2:8]]

                a_S_b_list.append(a_S_b.cpu().numpy())
                b_S_a_list.append(b_S_a.cpu().numpy())

                #
                # Calculate the target setops operations
                #
                target_a = target_a.type(torch.cuda.ByteTensor)
                target_b = target_b.type(torch.cuda.ByteTensor)

                target_a_I_b = target_a & target_b
                target_a_S_b = target_a & ~target_a_I_b
                target_b_S_a = target_b & ~target_a_I_b
                target_a_S_b_list.append(target_a_S_b.type(torch.cuda.FloatTensor).cpu().numpy())
                target_b_S_a_list.append(target_b_S_a.type(torch.cuda.FloatTensor).cpu().numpy())

        #
        # Output results
        #
        logging.info("Calculating precision:")
        # Pair each prediction list with its matching target list; note the
        # union and intersection targets are symmetric, so they are reused
        # for both operand orders.
        for output, target, name in tqdm(zip(
                (output_a_list, output_b_list, a_S_b_list, b_S_a_list, a_U_b_list, b_U_a_list, a_I_b_list, b_I_a_list),
                (target_a_list, target_b_list, target_a_S_b_list, target_b_S_a_list, target_a_U_b_list,
                 target_a_U_b_list, target_a_I_b_list, target_a_I_b_list),
                ("real_a", "real_b", "a_S_b", "b_S_a", "a_U_b", "b_U_a", "a_I_b", "b_I_a"))):

            output = np.concatenate(output, axis=0)
            target = np.concatenate(target, axis=0)

            # Per-class AP and PR curves over every class column.
            ap = [average_precision_score(target[:, i], output[:, i]) for i in range(output.shape[1])]
            pr_graphs = [precision_recall_curve(target[:, i], output[:, i]) for i in range(output.shape[1])]
            # Macro-average AP over the dataset's (possibly unseen-only)
            # label subset.
            ap_sum = 0
            for label in pair_dataset.labels_list:
                ap_sum += ap[label]
            ap_avg = ap_sum / len(pair_dataset.labels_list)
            logging.info(
                'Test {} average precision score, macro-averaged over all {} classes: {}'.format(
                    name, len(pair_dataset.labels_list), ap_avg)
            )

            with open(os.path.join(self.results_path, "{}_results.pkl".format(name)), "wb") as f:
                pickle.dump(dict(ap=ap, pr_graphs=pr_graphs), f)

    def setup_model(self):
        """Create or resume the models."""

        logging.info("Setup the models.")

        logging.info("{} model".format(self.base_network_name))
        models_path = Path(self.resume_path)
        if self.base_network_name.lower().startswith("resnet"):
            base_model, classifier = getattr(setops_models, self.base_network_name)(
                num_classes=80,
                avgpool_kernel=self.avgpool_kernel
            )
        else:
            base_model = setops_models.Inception3(aux_logits=False, transform_input=True)
            classifier = getattr(setops_models, self.classifier_name)(num_classes=80)
            if self.init_inception:
                logging.info("Initialize inception model using paper's networks.")
                checkpoint = torch.load(models_path / 'paperBaseModel')
                base_model = setops_models.Inception3(aux_logits=False, transform_input=True)
                # Load only the weights that exist in each target network.
                base_model.load_state_dict(
                    {k: v for k, v in checkpoint["state_dict"].items() if k in base_model.state_dict()}
                )
                classifier.load_state_dict(
                    {k: v for k, v in checkpoint["state_dict"].items() if k in classifier.state_dict()}
                )
        setops_model_cls = getattr(setops_models, self.sets_network_name)
        setops_model = setops_model_cls(
            input_dim=self.embedding_dim,
            S_latent_dim=self.ops_latent_dim, S_layers_num=self.ops_layer_num,
            I_latent_dim=self.ops_latent_dim, I_layers_num=self.ops_layer_num,
            U_latent_dim=self.ops_latent_dim, U_layers_num=self.ops_layer_num,
            block_cls_name=self.sets_block_name, basic_block_cls_name=self.sets_basic_block_name,
            dropout_ratio=self.setops_dropout,
        )

        if self.unseen:
            #
            # In the unseen mode, we have to load the trained discriminator.
            #
            discriminator_cls = getattr(setops_models, self.discriminator_name)
            classifier = discriminator_cls(
                input_dim=self.embedding_dim,
                latent_dim=self.classifier_latent_dim
            )

        if not self.resume_path:
            raise FileNotFoundError("resume_path is compulsory in test_precision")
        logging.info("Resuming the models.")
        if not self.init_inception:
            # Pick the latest checkpoint matching the requested epoch.
            base_model.load_state_dict(
                torch.load(sorted(models_path.glob("networks_base_model_{}*.pth".format(self.resume_epoch)))[-1])
            )

        if self.paper_reproduce:
            logging.info("using paper models")
            setops_model_cls = getattr(setops_models, "SetOpsModulePaper")
            setops_model = setops_model_cls(models_path)
            if self.unseen:
                checkpoint = torch.load(models_path / 'paperDiscriminator')
                classifier.load_state_dict(checkpoint['state_dict'])
        else:
            setops_model.load_state_dict(
                torch.load(
                    sorted(
                        models_path.glob("networks_setops_model_{}*.pth".format(self.resume_epoch))
                    )[-1]
                )
            )
            if self.unseen:
                classifier.load_state_dict(
                    torch.load(sorted(models_path.glob("networks_discriminator_{}*.pth".format(self.resume_epoch)))[-1])
                )
            elif not self.init_inception:
                classifier.load_state_dict(
                    torch.load(sorted(models_path.glob("networks_classifier_{}*.pth".format(self.resume_epoch)))[-1])
                )

        return base_model, classifier, setops_model

    def setup_datasets(self):
        """Load the training datasets."""
        # TODO: comment out if you don't want to copy coco to /tmp/aa
        # copy_coco_data()

        logging.info("Setting up the datasets.")
        CocoDatasetPairs = getattr(alfassy, "CocoDatasetPairs")
        CocoDatasetPairsSub = getattr(alfassy, "CocoDatasetPairsSub")
        if self.paper_reproduce:
            logging.info("Setting up the datasets and augmentation for paper reproduction")
            scaler = transforms.Scale((350, 350))
        else:
            scaler = transforms.Resize(self.crop_size)

        # Standard ImageNet normalization on center crops.
        val_transform = transforms.Compose(
            [
                scaler,
                transforms.CenterCrop(self.crop_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )
        pair_dataset = CocoDatasetPairs(
            root_dir=self.coco_path,
            set_name='val2014',
            unseen_set=self.unseen,
            transform=val_transform,
            debug_size=self.debug_size
        )

        pair_loader = DataLoader(
            pair_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )
        pair_dataset_sub = CocoDatasetPairsSub(
            root_dir=self.coco_path,
            set_name='val2014',
            unseen_set=self.unseen,
            transform=val_transform,
            debug_size=self.debug_size
        )

        pair_loader_sub = DataLoader(
            pair_dataset_sub,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )

        return pair_dataset, pair_loader, pair_dataset_sub, pair_loader_sub
Example #22
0
class FargateSpawner(Spawner):
    """JupyterHub spawner that runs each single-user server as an AWS
    Fargate (ECS) task.

    The task ARN is the only persisted state; status, IP discovery and
    shutdown all go through the ECS API via ``_aws_endpoint``.
    """

    aws_region = Unicode(config=True)
    aws_ecs_host = Unicode(config=True)
    task_role_arn = Unicode(config=True)
    task_cluster_name = Unicode(config=True)
    task_container_name = Unicode(config=True)
    task_definition_arn = Unicode(config=True)
    task_security_groups = List(trait=Unicode, config=True)
    task_subnets = List(trait=Unicode, config=True)
    task_assign_public_ip = Enum(["DISABLED", "ENABLED"], "DISABLED", config=True)
    task_platform_version = Unicode("LATEST", config=True)
    notebook_port = Int(config=True)
    notebook_scheme = Unicode(config=True)
    notebook_args = List(trait=Unicode, config=True)

    # Pluggable credentials provider for signing ECS API calls.
    authentication_class = Type(FargateSpawnerAuthentication, config=True)
    authentication = Instance(FargateSpawnerAuthentication)

    @default('authentication')
    def _default_authentication(self):
        return self.authentication_class(parent=self)

    task_arn = Unicode('')

    # We mostly are able to call the AWS API to determine status. However, when we yield the
    # event loop to create the task, if there is a poll before the creation is complete,
    # we must behave as though we are running/starting, but we have no IDs to use with which
    # to check the task.
    calling_run_task = Bool(False)

    # Buffer feeding the async ``progress`` generator; created in clear_state.
    progress_buffer = None

    def load_state(self, state):
        ''' Misleading name: this "loads" the state onto self, to be used by other methods '''

        super().load_state(state)

        # Called when first created: we might have no state from a previous invocation
        self.task_arn = state.get('task_arn', '')

    def get_state(self):
        ''' Misleading name: the return value of get_state is saved to the database in order
        to be able to restore after the hub went down '''

        state = super().get_state()
        state['task_arn'] = self.task_arn

        return state

    async def poll(self):
        """Report server status to JupyterHub."""
        # Return values, as dictated by the Jupyterhub framework:
        # 0                   == not running, or not starting up, i.e. we need to call start
        # None                == running, or not finished starting
        # 1, or anything else == error

        return \
            None if self.calling_run_task else \
            0 if self.task_arn == '' else \
            None if (await _get_task_status(self.log, self._aws_endpoint(), self.task_cluster_name, self.task_arn)) in ALLOWED_STATUSES else \
            1

    async def start(self):
        """Launch the Fargate task, wait for an IP and RUNNING status, and
        return the notebook server's URL."""
        self.log.debug('Starting spawner')

        task_port = self.notebook_port

        self.progress_buffer.write({'progress': 0.5, 'message': 'Starting server...'})
        try:
            # While run_task is in flight we have no ARN yet; the flag makes
            # poll() report "starting" instead of "stopped".
            self.calling_run_task = True
            debug_args = ['--debug'] if self.debug else []
            args = debug_args + ['--port=' + str(task_port)] + self.notebook_args
            run_response = await _run_task(
                self.log, self._aws_endpoint(),
                self.task_role_arn,
                self.task_cluster_name, self.task_container_name, self.task_definition_arn,
                self.task_security_groups, self.task_subnets,
                self.task_assign_public_ip, self.task_platform_version,
                self.cmd + args, self.get_env(), self.user_options)
            task_arn = run_response['tasks'][0]['taskArn']
            self.progress_buffer.write({'progress': 1})
        finally:
            self.calling_run_task = False

        self.task_arn = task_arn

        # Phase 1: wait (up to 50s) for the task to be assigned an IP.
        max_polls = 50
        num_polls = 0
        task_ip = ''
        while task_ip == '':
            num_polls += 1
            if num_polls >= max_polls:
                raise Exception('Task {} took too long to find IP address'.format(self.task_arn))

            task_ip = await _get_task_ip(self.log, self._aws_endpoint(), self.task_cluster_name, task_arn)
            # NOTE(review): this sleeps one extra second after the IP is
            # found, since the sleep precedes the loop-condition re-check.
            await gen.sleep(1)
            self.progress_buffer.write({'progress': 1 + num_polls / max_polls})

        self.progress_buffer.write({'progress': 2})

        # Phase 2: wait (up to start_timeout polls) for RUNNING status.
        max_polls = self.start_timeout
        num_polls = 0
        status = ''
        while status != 'RUNNING':
            num_polls += 1
            if num_polls >= max_polls:
                raise Exception('Task {} took too long to become running'.format(self.task_arn))

            status = await _get_task_status(self.log, self._aws_endpoint(), self.task_cluster_name, task_arn)
            if status not in ALLOWED_STATUSES:
                raise Exception('Task {} is {}'.format(self.task_arn, status))

            await gen.sleep(1)
            self.progress_buffer.write({'progress': 2 + num_polls / max_polls * 98})

        self.progress_buffer.write({'progress': 100, 'message': 'Server started'})
        await gen.sleep(1)

        self.progress_buffer.close()

        return f'{self.notebook_scheme}://{task_ip}:{task_port}'

    async def stop(self, now=False):
        """Stop the ECS task backing this server, if one is recorded."""
        if self.task_arn == '':
            return

        self.log.debug('Stopping task (%s)...', self.task_arn)
        await _ensure_stopped_task(self.log, self._aws_endpoint(), self.task_cluster_name, self.task_arn)
        self.log.debug('Stopped task (%s)... (done)', self.task_arn)

    def clear_state(self):
        """Forget the task ARN and reset the progress stream."""
        super().clear_state()
        self.log.debug('Clearing state: (%s)', self.task_arn)
        self.task_arn = ''
        self.progress_buffer = AsyncIteratorBuffer()

    async def progress(self):
        """Yield progress dicts written during start() to JupyterHub's UI."""
        async for progress_message in self.progress_buffer:
            yield progress_message

    def _aws_endpoint(self):
        """Bundle region, host and credential callback for the ECS helpers."""
        return {
            'region': self.aws_region,
            'ecs_host': self.aws_ecs_host,
            'ecs_auth': self.authentication.get_credentials,
        }
Example #23
0
class Widget(LoggingHasTraits):
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    # Hook invoked for every new widget; set via ``on_widget_constructed``.
    _widget_construction_callback = None

    # widgets is a dictionary of all active widget objects
    widgets = {}

    # widget_types is a registry of widgets by module, version, and name:
    widget_types = WidgetRegistry()

    @classmethod
    def close_all(cls):
        """Close every currently-active widget.

        Iterates over a snapshot of the registry, since ``close()`` removes
        entries from ``cls.widgets`` while we loop.
        """
        for active in list(cls.widgets.values()):
            active.close()


    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        # Passing None effectively unregisters the hook (see
        # _call_widget_constructed, which skips a None callback).
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Invoke the globally-registered construction hook, if one is set
        and callable."""
        hook = Widget._widget_construction_callback
        if hook is not None and callable(hook):
            hook(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when the frontend opens a comm on the
        widget target: instantiates the matching widget class and applies
        the transmitted state."""
        # Reject frontends speaking a different major protocol version.
        version = msg.get('metadata', {}).get('version', '')
        if version.split('.')[0] != PROTOCOL_VERSION_MAJOR:
            raise ValueError("Incompatible widget protocol versions: received version %r, expected version %r"%(version, __protocol_version__))
        data = msg['content']['data']
        state = data['state']

        # Find the widget class to instantiate in the registered widgets
        widget_class = Widget.widget_types.get(state['_model_module'],
                                               state['_model_module_version'],
                                               state['_model_name'],
                                               state['_view_module'],
                                               state['_view_module_version'],
                                               state['_view_name'])
        widget = widget_class(comm=comm)
        # Re-attach any binary buffers into the state before applying it.
        if 'buffer_paths' in data:
            _put_buffers(state, data['buffer_paths'], msg['buffers'])
        widget.set_state(state)

    @staticmethod
    def get_manager_state(drop_defaults=False, widgets=None):
        """Returns the full state for a widget manager for embedding

        :param drop_defaults: when True, it will not include default value
        :param widgets: list with widgets to include in the state (or all widgets when None)
        :return:
        """
        if widgets is None:
            widgets = Widget.widgets.values()
        state = {
            w.model_id: w._get_embed_state(drop_defaults=drop_defaults)
            for w in widgets
        }
        return {'version_major': 2, 'version_minor': 0, 'state': state}


    def _get_embed_state(self, drop_defaults=False):
        """Return this widget's state in the embed-manifest format.

        The result carries the model identity, the buffer-less state dict,
        and (when present) base64-encoded binary buffers with their paths.
        """
        state = {
            'model_name': self._model_name,
            'model_module': self._model_module,
            'model_module_version': self._model_module_version
        }
        model_state, buffer_paths, buffers = _remove_buffers(self.get_state(drop_defaults=drop_defaults))
        state['state'] = model_state
        if len(buffers) > 0:
            state['buffers'] = [{'encoding': 'base64',
                                 'path': p,
                                 'data': standard_b64encode(d).decode('ascii')}
                                for p, d in zip(buffer_paths, buffers)]
        return state

    def get_view_spec(self):
        """Return the minimal spec used to reference a view of this model."""
        return {'version_major': 2, 'version_minor': 0, 'model_id': self._model_id}

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_name = Unicode('WidgetModel',
        help="Name of the model.", read_only=True).tag(sync=True)
    _model_module = Unicode('@jupyter-widgets/base',
        help="The namespace for the model.", read_only=True).tag(sync=True)
    _model_module_version = Unicode(__jupyter_widgets_base_version__,
        help="A semver requirement for namespace version containing the model.", read_only=True).tag(sync=True)
    _view_name = Unicode(None, allow_none=True,
        help="Name of the view.").tag(sync=True)
    _view_module = Unicode(None, allow_none=True,
        help="The namespace for the view.").tag(sync=True)
    _view_module_version = Unicode('',
        help="A semver requirement for the namespace version containing the view.").tag(sync=True)

    _view_count = Int(None, allow_none=True,
        help="EXPERIMENTAL: The number of views of the model displayed in the frontend. This attribute is experimental and may change or be removed in the future. None signifies that views will not be tracked. Set this to 0 to start tracking view creation/deletion.").tag(sync=True)
    # Kernel-side comm to the frontend; None until open() succeeds.
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    keys = List(help="The traits which are synced.")

    @default('keys')
    def _default_keys(self):
        # By default, sync every trait tagged with sync=True.
        return [name for name in self.traits(sync=True)]

    # Frontend-sent values currently locked against echoing back.
    _property_lock = Dict()
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        # Allow callers to pin the comm id via the ``model_id`` kwarg.
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        # Open the comm to the frontend immediately.
        self.open()

    def __del__(self):
        """Object disposal: close the comm and deregister the widget."""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            # Binary buffers are sent alongside the JSON payload, not in it.
            state, buffer_paths, buffers = _remove_buffers(self.get_state())

            args = dict(target_name='jupyter.widget',
                        data={'state': state, 'buffer_paths': buffer_paths},
                        buffers=buffers,
                        metadata={'version': __protocol_version__}
                        )
            if self._model_id is not None:
                args['comm_id'] = self._model_id

            # Assignment fires _comm_changed, which registers the widget.
            self.comm = Comm(**args)

    @observe('comm')
    def _comm_changed(self, change):
        """Called when the comm is changed."""
        if change['new'] is None:
            return
        self._model_id = self.model_id

        # Route incoming comm messages and register in the global registry.
        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

    @property
    def model_id(self):
        """Gets the model id of this widget.

        NOTE(review): the original docstring claimed a Comm would be created
        automagically; in this code accessing the property while ``self.comm``
        is None raises AttributeError — the comm is created by ``open()``.
        """
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            # Deregister first so repeated close() calls are harmless.
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
            self._ipython_display_ = None

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end, if it exists.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        if len(state) > 0:
            # Binary buffers travel out-of-band, referenced by path.
            state, buffer_paths, buffers = _remove_buffers(state)
            msg = {'method': 'update', 'state': state, 'buffer_paths': buffer_paths}
            self._send(msg, buffers=buffers)

    def get_state(self, key=None, drop_defaults=False):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.
        drop_defaults : bool (optional)
            When True, keys whose serialized value equals the trait's
            default are omitted from the result.

        Returns
        -------
        state : dict of states
        metadata : dict
            metadata for each field: {key: metadata}
        """
        # Fix: ``collections.Iterable`` was deprecated in Python 3.3 and
        # removed in 3.10 — use collections.abc instead. Local import keeps
        # the file's top-level imports untouched.
        from collections.abc import Iterable

        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            # Checked before Iterable: strings are iterable but must be
            # treated as a single key name.
            keys = [key]
        elif isinstance(key, Iterable):
            keys = key
        else:
            raise ValueError("key must be a string, an iterable of keys, or None")
        state = {}
        traits = self.traits()
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = to_json(getattr(self, k), self)
            if not PY3 and isinstance(traits[k], Bytes) and isinstance(value, bytes):
                # Python 2 only: memoryview avoids copying large payloads.
                value = memoryview(value)
            if not drop_defaults or not self._compare(value, traits[k].default_value):
                state[k] = value
        return state

    def _is_numpy(self, x):
        """Return True if *x* is a numpy ndarray, without importing numpy."""
        cls = x.__class__
        return cls.__module__ == 'numpy' and cls.__name__ == 'ndarray'

    def _compare(self, a, b):
        """Equality check that uses element-wise comparison when either
        operand is a numpy array."""
        if not (self._is_numpy(a) or self._is_numpy(b)):
            return a == b
        import numpy as np
        return np.array_equal(a, b)

    def set_state(self, sync_data):
        """Called when a state is received from the front-end.

        Applies each recognized key through its ``from_json`` deserializer
        while suppressing echo of the same values back to the front-end.
        """
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    self.set_trait(name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget.

        Traits tagged ``sync=True`` are additionally registered in
        ``self.keys`` and their initial value is pushed to the front-end.
        """
        super(Widget, self).add_traits(**traits)
        for trait_name, trait_obj in traits.items():
            if not trait_obj.get_metadata('sync'):
                continue
            self.keys.append(trait_name)
            self.send_state(trait_name)

    def notify_change(self, change):
        """Called when a property has changed.

        :param change: traitlets change dict (contains at least ``name``).
        """
        # Send the state to the frontend before the user-registered callbacks
        # are called.
        name = change['name']
        # Only forward when a live comm attached to a kernel exists.
        if self.comm is not None and self.comm.kernel is not None:
            # Make sure this isn't information that the front-end just sent us.
            if name in self.keys and self._should_send_property(name, getattr(self, name)):
                # Send new state to front-end
                self.send_state(key=name)
        super(Widget, self).notify_change(change)

    def __repr__(self):
        """Return a repr listing synced traits that differ from their defaults."""
        interesting_keys = self._repr_keys()
        return self._gen_repr_from_keys(interesting_keys)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        While the lock is held, ``_should_send_property`` suppresses echoing
        these exact values back to the front-end.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            # Always release the lock, even if the body raised.
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync is True:
            # Nested hold: a no-op; the outermost manager does the flush.
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                # Flush every property queued while the hold was active,
                # and reset the bookkeeping even if sending raises.
                self._holding_sync = False
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)

        Returns False when *value* equals what the front-end just sent us
        (to avoid an echo) or when a ``hold_sync`` is active (the key is
        queued for the later flush); returns True otherwise.
        """
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if key in self._property_lock:
            # _remove_buffers returns a 3-tuple:
            # model_state, buffer_paths, buffers
            split_value = _remove_buffers({ key: to_json(value, self)})
            split_lock = _remove_buffers({ key: self._property_lock[key]})
            # A roundtrip conversion through json in the comparison takes care of
            # idiosyncracies of how python data structures map to json, for example
            # tuples get converted to lists.
            if (jsonloads(jsondumps(split_value[0])) == split_lock[0]
                and split_value[1] == split_lock[1]
                and _buffer_list_equal(split_value[2], split_lock[2])):
                return False
        if self._holding_sync:
            # Defer the send; hold_sync() will flush this key on exit.
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end

        Dispatches on the ``method`` field of the message payload.
        """
        data = msg['content']['data']
        method = data['method']

        # State update pushed from the front-end, possibly with binary parts.
        if method == 'update':
            if 'state' in data:
                state = data['state']
                if 'buffer_paths' in data:
                    # Re-attach binary buffers at their recorded paths.
                    _put_buffers(state, data['buffer_paths'], msg['buffers'])
                self.set_state(state)

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)

    def _handle_custom_msg(self, content, buffers):
        """Dispatch a received custom message to all registered msg callbacks."""
        dispatch = self._msg_callbacks
        dispatch(self, content, buffers)

    def _handle_displayed(self, **kwargs):
        """Notify registered display callbacks that a view of this widget was shown."""
        dispatch = self._display_callbacks
        dispatch(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Default serializer: pass the trait value through unchanged."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Default deserializer: pass the JSON value through unchanged."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        # Only widgets with a front-end view can be displayed.
        if self._view_name is not None:

            # The 'application/vnd.jupyter.widget-view+json' mimetype has not been registered yet.
            # See the registration process and naming convention at
            # http://tools.ietf.org/html/rfc6838
            # and the currently registered mimetypes at
            # http://www.iana.org/assignments/media-types/media-types.xhtml.
            # Plain-text and HTML fallbacks are included for front-ends
            # that do not understand the widget mimetype.
            data = {
                'text/plain': repr(self),
                'text/html': self._fallback_html(),
                'application/vnd.jupyter.widget-view+json': {
                    'version_major': 2,
                    'version_minor': 0,
                    'model_id': self._model_id
                }
            }
            display(data, raw=True)

            # Fire any callbacks registered via on_displayed().
            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Forward *msg* (plus optional binary buffers) over the comm, if open."""
        comm = self.comm
        if comm is None or comm.kernel is None:
            return
        comm.send(data=msg, buffers=buffers)

    def _repr_keys(self):
        """Yield names of synced traits worth showing in ``repr``.

        Skips private traits, traits equal to their default, and empty
        containers whose dynamic default would also be empty.
        """
        traits = self.traits()
        for key in sorted(self.keys):
            # Exclude traits that start with an underscore
            if key[0] == '_':
                continue
            # Exclude traits who are equal to their default value
            value = getattr(self, key)
            trait = traits[key]
            if self._compare(value, trait.default_value):
                continue
            elif (isinstance(trait, (Container, Dict)) and
                  trait.default_value == Undefined and
                  (value is None or len(value) == 0)):
                # Empty container, and dynamic default will be empty
                continue
            yield key

    def _gen_repr_from_keys(self, keys):
        """Build a ``ClassName(key1=val1, ...)`` string from the given trait names."""
        parts = ['%s=%r' % (name, getattr(self, name)) for name in keys]
        return '%s(%s)' % (self.__class__.__name__, ', '.join(parts))

    def _fallback_html(self):
        """Render the static HTML fallback for front-ends without widget JS support."""
        widget_type = type(self).__name__
        return _FALLBACK_HTML_TEMPLATE.format(widget_type=widget_type)
Пример #24
0
class Widget(LoggingConfigurable):
    """Kernel-side model for a Jupyter widget.

    Mirrors a model object in the front-end over a Jupyter comm channel:
    traits tagged ``sync=True`` are serialized to the front-end, and state
    updates received from the browser are applied back onto the instance.
    """
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    # Optional user hook invoked for every widget constructed
    # (see ``on_widget_constructed``).
    _widget_construction_callback = None
    # Registry of all live widgets, keyed by model id.
    widgets = {}
    # Maps front-end widget class names to kernel-side widget classes.
    widget_types = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(
                Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when a comm is opened by the front-end.

        Instantiates the widget class named in *msg*, bound to *comm*.  The
        instance registers itself in ``Widget.widgets`` during construction,
        so no reference needs to be retained here.
        """
        class_name = str(msg['content']['data']['widget_class'])
        if class_name in Widget.widget_types:
            widget_class = Widget.widget_types[class_name]
        else:
            widget_class = import_item(class_name)
        # Constructed purely for its side effects (comm wiring + registry
        # entry); the previously unused local binding has been removed.
        widget_class(comm=comm)

    @staticmethod
    def get_manager_state(drop_defaults=False):
        """Return a serializable snapshot of every live widget's state.

        :param drop_defaults: if True, omit trait values equal to defaults.
        """
        return dict(
            version_major=1,
            version_minor=0,
            state={
                k: {
                    'model_name':
                    Widget.widgets[k]._model_name,
                    'model_module':
                    Widget.widgets[k]._model_module,
                    'model_module_version':
                    Widget.widgets[k]._model_module_version,
                    'state':
                    Widget.widgets[k].get_state(drop_defaults=drop_defaults)
                }
                for k in Widget.widgets
            })

    def get_view_spec(self):
        """Return the versioned spec the front-end uses to locate this model."""
        return dict(version_major=1, version_minor=0, model_id=self._model_id)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_module = Unicode(
        'jupyter-js-widgets',
        help="A JavaScript module name in which to find _model_name.").tag(
            sync=True)
    _model_name = Unicode(
        'WidgetModel',
        help="Name of the model object in the front-end.").tag(sync=True)
    _model_module_version = Unicode(
        '*', help="A semver requirement for the model module version.").tag(
            sync=True)
    _view_module = Unicode(
        None,
        allow_none=True,
        help="A JavaScript module in which to find _view_name.").tag(sync=True)
    _view_name = Unicode(None,
                         allow_none=True,
                         help="Name of the view object.").tag(sync=True)
    _view_module_version = Unicode(
        '*', help="A semver requirement for the view module version.").tag(
            sync=True)
    # The comm channel to the front-end model; None when closed or not yet
    # opened.
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    msg_throttle = Int(
        1,
        help=
        """Maximum number of msgs the front-end can send before receiving an idle msg from the back-end."""
    ).tag(sync=True)

    # Names of all traits tagged sync=True (see _keys_default).
    keys = List()

    def _keys_default(self):
        # Dynamic default: every trait tagged sync=True on this class.
        return [name for name in self.traits(sync=True)]

    # Bookkeeping used to avoid echoing front-end updates back to the
    # front-end (see _lock_property / hold_sync / _should_send_property).
    _property_lock = Dict()
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor.

        Accepts an optional ``model_id`` keyword to reuse an existing comm id;
        opens the comm to the front-end immediately.
        """
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            state, buffer_keys, buffers = self._split_state_buffers(
                self.get_state())

            args = dict(target_name='jupyter.widget', data=state)
            if self._model_id is not None:
                args['comm_id'] = self._model_id

            self.comm = Comm(**args)
            if buffers:
                # FIXME: workaround ipykernel missing binary message support in open-on-init
                # send state with binary elements as second message
                self.send_state()

    @observe('comm')
    def _comm_changed(self, change):
        """Called when the comm is changed: wire up message handling and register."""
        if change['new'] is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
            # Prevent further display of a closed widget.
            self._ipython_display_ = None

    def _split_state_buffers(self, state):
        """Return (state_without_buffers, buffer_keys, buffers) for binary message parts"""
        buffer_keys, buffers = [], []
        # Iterate over a copy since we pop from ``state`` while iterating.
        for k, v in list(state.items()):
            if isinstance(v, _binary_types):
                state.pop(k)
                buffers.append(v)
                buffer_keys.append(k)
        return state, buffer_keys, buffers

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        state, buffer_keys, buffers = self._split_state_buffers(state)
        msg = {'method': 'update', 'state': state, 'buffers': buffer_keys}
        self._send(msg, buffers=buffers)

    def get_state(self, key=None, drop_defaults=False):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.
        drop_defaults : bool (optional)
            If True, omit values equal to the trait's default.

        Returns
        -------
        state : dict of states
        """
        # ``collections.Iterable`` was deprecated in Python 3.3 and removed
        # in Python 3.10; import the ABC from its canonical location.
        from collections.abc import Iterable
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, Iterable):
            keys = key
        else:
            raise ValueError(
                "key must be a string, an iterable of keys, or None")
        state = {}
        traits = self.traits()
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = to_json(getattr(self, k), self)
            # On Python 2, wrap raw bytes so they are sent as binary buffers.
            if not PY3 and isinstance(traits[k], Bytes) and isinstance(
                    value, bytes):
                value = memoryview(value)
            if not drop_defaults or value != traits[k].default_value:
                state[k] = value
        return state

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    self.set_trait(name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget.

        Traits tagged ``sync=True`` are registered in ``self.keys`` and
        their initial value is pushed to the front-end.
        """
        super(Widget, self).add_traits(**traits)
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                self.keys.append(name)
                self.send_state(name)

    def notify_change(self, change):
        """Called when a property has changed."""
        # Send the state to the frontend before the user-registered callbacks
        # are called.
        name = change['name']
        # Only forward when a live comm attached to a kernel exists.
        if self.comm is not None and self.comm.kernel is not None:
            # Make sure this isn't information that the front-end just sent us.
            if name in self.keys and self._should_send_property(
                    name, change['new']):
                # Send new state to front-end
                self.send_state(key=name)
        LoggingConfigurable.notify_change(self, change)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed.  In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            # Always release the lock, even if the body raised.
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync is True:
            # Nested hold: a no-op; the outermost manager does the flush.
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                # Flush every property queued while the hold was active.
                self._holding_sync = False
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)

        Returns False when *value* equals what the front-end just sent us
        (to avoid an echo) or when a ``hold_sync`` is active (the key is
        queued instead); returns True otherwise.
        """
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if (key in self._property_lock
                and to_json(value, self) == self._property_lock[key]):
            return False
        elif self._holding_sync:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
        if method == 'backbone':
            if 'sync_data' in data:
                # get binary buffers too
                sync_data = data['sync_data']
                for i, k in enumerate(data.get('buffer_keys', [])):
                    sync_data[k] = msg['buffers'][i]
                self.set_state(sync_data)  # handles all methods

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error(
                'Unknown front-end to back-end widget msg with method "%s"' %
                method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Convert a trait value to json."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        def loud_error(message):
            # Surface the problem both in the kernel log and on stderr so
            # notebook users actually see it.  (``log.warn`` is a deprecated
            # alias; use ``log.warning``.)
            self.log.warning(message)
            sys.stderr.write('%s\n' % message)

        # Show view.
        if self._view_name is not None:
            # NOTE(review): ``Widget._version_validated`` / ``_version_frontend``
            # are set elsewhere in the module (front-end handshake) — not
            # defined on this class.
            validated = Widget._version_validated

            # Before the user tries to display a widget, validate that the
            # widget front-end is what is expected.
            if validated is None:
                loud_error('Widget Javascript not detected.  It may not be '
                           'installed or enabled properly.')
            elif not validated:
                msg = ('The installed widget Javascript is the wrong version.'
                       ' It must satisfy the semver range %s.' %
                       __frontend_version__)
                if (Widget._version_frontend):
                    msg += ' The widget Javascript is version %s.' % Widget._version_frontend
                loud_error(msg)

            # TODO: delete this sending of a comm message when the display statement
            # below works. Then add a 'text/plain' mimetype to the dictionary below.
            self._send({"method": "display"})

            # The 'application/vnd.jupyter.widget-view+json' mimetype has not been registered yet.
            # See the registration process and naming convention at
            # http://tools.ietf.org/html/rfc6838
            # and the currently registered mimetypes at
            # http://www.iana.org/assignments/media-types/media-types.xhtml.
            # We don't have a 'text/plain' entry, so this display message will be
            # will be invisible in the current notebook.
            data = {
                'application/vnd.jupyter.widget-view+json': {
                    'model_id': self._model_id
                }
            }
            display(data, raw=True)

            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        if self.comm is not None and self.comm.kernel is not None:
            self.comm.send(data=msg, buffers=buffers)
Пример #25
0
class _MultipleSelection(DescriptionWidget, ValueWidget, CoreWidget):
    """Base class for multiple Selection widgets

    ``options`` can be specified as a list of values, list of (label, value)
    tuples, or a dict of {label: value}. The labels are the strings that will be
    displayed in the UI, representing the actual Python choices, and should be
    unique. If labels are not specified, they are generated from the values.

    When programmatically setting the value, a reverse lookup is performed
    among the options to check that the value is valid. The reverse lookup uses
    the equality operator by default, but another predicate may be provided via
    the ``equals`` keyword argument. For example, when dealing with numpy arrays,
    one may set equals=np.array_equal.
    """

    # Only ``index`` is synced with the front-end; ``value`` and ``label``
    # are kernel-side projections kept consistent by the observers below.
    value = TypedTuple(trait=Any(), help="Selected values")
    label = TypedTuple(trait=Unicode(), help="Selected labels")
    index = TypedTuple(trait=Int(), help="Selected indices").tag(sync=True)

    options = Any(
        (),
        help=
        """Iterable of values or (label, value) pairs that the user can select.

    The labels are the strings that will be displayed in the UI, representing the
    actual Python choices, and should be unique.
    """)
    # Normalized ((label, value), ...) form of ``options``; rebuilt by the
    # options validator before the observer runs.
    _options_full = None

    # This being read-only means that it cannot be changed from the frontend!
    _options_labels = TypedTuple(
        trait=Unicode(), read_only=True,
        help="The labels for the options.").tag(sync=True)

    disabled = Bool(help="Enable or disable user changes").tag(sync=True)

    def __init__(self, *args, **kwargs):
        """Initialize, normalizing ``options`` before trait validation runs.

        The ``equals`` keyword (default: ``==``) is the predicate used for
        the reverse lookup when ``value`` is set programmatically.
        """
        self.equals = kwargs.pop('equals', lambda x, y: x == y)

        # We have to make the basic options bookkeeping consistent
        # so we don't have errors the first time validators run
        self._initializing_traits_ = True
        kwargs['options'] = _exhaust_iterable(kwargs.get('options', ()))
        self._options_full = _make_options(kwargs['options'])
        self._propagate_options(None)

        super().__init__(*args, **kwargs)
        self._initializing_traits_ = False

    @validate('options')
    def _validate_options(self, proposal):
        # Materialize generators so options can be iterated more than once.
        proposal.value = _exhaust_iterable(proposal.value)
        # throws an error if there is a problem converting to full form
        self._options_full = _make_options(proposal.value)
        return proposal.value

    @observe('options')
    def _propagate_options(self, change):
        "Unselect any option"
        options = self._options_full
        self.set_trait('_options_labels', tuple(i[0] for i in options))
        self._options_values = tuple(i[1] for i in options)
        # Skip the reset while __init__ is still normalizing state.
        if self._initializing_traits_ is not True:
            self.index = ()

    @validate('index')
    def _validate_index(self, proposal):
        "Check the range of each proposed index."
        if all(0 <= i < len(self._options_labels) for i in proposal.value):
            return proposal.value
        else:
            raise TraitError('Invalid selection: index out of bounds')

    @observe('index')
    def _propagate_index(self, change):
        "Propagate changes in index to the value and label properties"
        label = tuple(self._options_labels[i] for i in change.new)
        value = tuple(self._options_values[i] for i in change.new)
        # we check equality so we can avoid validation if possible
        if self.label != label:
            self.label = label
        if self.value != value:
            self.value = value

    @validate('value')
    def _validate_value(self, proposal):
        "Replace all values with the actual objects in the options list"
        try:
            return tuple(
                findvalue(self._options_values, i, self.equals)
                for i in proposal.value)
        except ValueError:
            raise TraitError('Invalid selection: value not found')

    @observe('value')
    def _propagate_value(self, change):
        # Map each selected value back to its position among the options.
        index = tuple(self._options_values.index(i) for i in change.new)
        if self.index != index:
            self.index = index

    @validate('label')
    def _validate_label(self, proposal):
        # Every proposed label must name an existing option.
        if any(i not in self._options_labels for i in proposal.value):
            raise TraitError('Invalid selection: label not found')
        return proposal.value

    @observe('label')
    def _propagate_label(self, change):
        # Map each selected label back to its position among the options.
        index = tuple(self._options_labels.index(i) for i in change.new)
        if self.index != index:
            self.index = index

    def _repr_keys(self):
        keys = super()._repr_keys()
        # Include options manually, as it isn't marked as synced:
        yield from sorted(chain(keys, ('options', )))
Пример #26
0
class Dazer(Configurable):
    emb_file = Unicode('None', help='词向量文件路径').tag(config=True)
    vocabulary_size = Int(400000, help='加载的词汇数量').tag(config=True)
    embedding_size = Int(50, help='词潜入的纬度').tag(config=True)
    train_file = Unicode('None', help='训练样本文件').tag(config=True)
    train_labels = Unicode('None', help='需要训练的分类标签').tag(config=True)
    all_labels = Unicode('None', help='全分类标签').tag(config=True)

    test_file = Unicode('None', help='训练样本文件').tag(config=True)
    ckpt_path = Unicode('None', help='模型保存路径').tag(config=True)
    test_result_file = Unicode('None', help='测试数据结果').tag(config=True)
    summary_path = Unicode('None', help='统计标量保存路径').tag(config=True)

    kernal_num = Int(50, help="滤波器个数").tag(config=True)
    kernal_width = Int(5, help="滤波器宽度").tag(config=True)

    max_epochs = Int(10, help="最大训练轮数").tag(config=True)
    eval_frequency = Int(1, help="将数据写入可视化的频率").tag(config=True)
    batch_size = Int(16, help="批次大小").tag(config=True)
    load_model = Bool(False, help="是否加载已经训练的模型").tag(config=True)

    max_pooling_num = Int(3, help="最池化k值").tag(config=True)
    decoder_mlp1_num = Int(75, help="解码第一层全连接的神经元个数").tag(config=True)

    regular_term = Float(0.01,
                         help='正则项系数,对抗分类和相似性回归 共享这个超参数').tag(config=True)
    adv_learning_rate = Float(0.001, help='对抗分类的学习率').tag(config=True)
    epsilon = Float(0.00001, help='对抗分类的学习率').tag(config=True)
    model_learning_rate = Float(0.001, help='相似性模型的学习率').tag(config=True)
    adv_loss_term = Float(0.2, help='adv损失在总损失中占的比例').tag(config=True)
    """零样本过滤模型"""
    def __init__(self, **kwargs):
        super(Dazer, self).__init__(**kwargs)
        # 定义类的实例属性
        self.word2id, self.id2word, self.emb = we.load_word2vec(
            self.emb_file, self.vocabulary_size, self.embedding_size)
        self.class_num = len(self.train_labels.split(','))
        # 初始化数据流图
        self.g = tf.Graph()
        self.structure_init()

    def structure_init(self):
        """定义数据流图结构"""
        with self.g.as_default():
            # 定义输入
            self.input_q = tf.placeholder(dtype=tf.int32, shape=[None, None])
            self.input_q_len = tf.placeholder(dtype=tf.float32, shape=[
                None,
            ])
            self.input_pos_d = tf.placeholder(dtype=tf.int32,
                                              shape=[None, None])
            self.input_neg_d = tf.placeholder(dtype=tf.int32,
                                              shape=[None, None])
            self.input_l = tf.placeholder(dtype=tf.int32,
                                          shape=[None, self.class_num])

            # 词嵌入
            emb_q = tf.nn.embedding_lookup(self.emb, self.input_q)
            emb_pos_d = tf.nn.embedding_lookup(self.emb, self.input_pos_d)
            emb_neg_d = tf.nn.embedding_lookup(self.emb, self.input_neg_d)

            # 生成标签的门向量
            class_vec = tf.divide(tf.reduce_sum(emb_q, axis=1),
                                  tf.expand_dims(self.input_q_len, axis=-1))
            query_gate_weights = weight_init(
                [self.embedding_size, self.kernal_num], 'gate_weights')
            query_gate_bias = tf.Variable(initial_value=tf.zeros(
                self.kernal_num, ),
                                          name='gate_bias')
            # shape:[bath_size, kernal_num]
            gate_vec = tf.sigmoid(
                tf.matmul(class_vec, query_gate_weights) + query_gate_bias)
            rs_gate_vec = tf.expand_dims(gate_vec, axis=1)

            # 生成文档向量
            pos_sub_info = tf.subtract(tf.expand_dims(class_vec, axis=1),
                                       emb_pos_d)
            pos_mul_info = tf.multiply(emb_pos_d,
                                       tf.expand_dims(class_vec, axis=1))
            conv_pos_input = tf.expand_dims(tf.concat(
                [emb_pos_d, pos_sub_info, pos_mul_info], -1),
                                            axis=-1)

            neg_sub_info = tf.subtract(tf.expand_dims(class_vec, axis=1),
                                       emb_neg_d)
            neg_mul_info = tf.multiply(emb_neg_d,
                                       tf.expand_dims(class_vec, axis=1))
            conv_neg_input = tf.expand_dims(tf.concat(
                [emb_neg_d, neg_sub_info, neg_mul_info], -1),
                                            axis=-1)

            # 卷积操作提取文档窗口特征
            pos_conv = tf.layers.conv2d(inputs=conv_pos_input,
                                        filters=self.kernal_num,
                                        kernel_size=(self.kernal_width,
                                                     self.embedding_size * 3),
                                        strides=(1, self.embedding_size * 3),
                                        padding="same",
                                        name='doc_conv',
                                        trainable=True)
            neg_conv = tf.layers.conv2d(inputs=conv_neg_input,
                                        filters=self.kernal_num,
                                        kernel_size=(self.kernal_width,
                                                     self.embedding_size * 3),
                                        strides=(1, self.embedding_size * 3),
                                        padding="same",
                                        name='doc_conv',
                                        trainable=True,
                                        reuse=True)
            # shape=[batch,max_dlen,1,kernal_num]
            # reshape to [batch,max_dlen,kernal_num]
            rs_pos_conv = tf.squeeze(pos_conv, [2])
            rs_neg_conv = tf.squeeze(neg_conv, [2])

            pos_gate_conv = tf.multiply(rs_gate_vec, rs_pos_conv)
            neg_gate_conv = tf.multiply(rs_gate_vec, rs_neg_conv)

            # top_k 池化
            transpose_pos_gate_conv = tf.transpose(
                pos_gate_conv,
                [0, 2, 1])  # reshape to [batch,kernal_num,max_dlen]
            pos_conv_k_max_pooling, _ = tf.nn.top_k(transpose_pos_gate_conv,
                                                    self.max_pooling_num)
            pos_encoder = tf.reshape(
                pos_conv_k_max_pooling,
                [-1, self.kernal_num * self.max_pooling_num])

            transpose_neg_gate_conv = tf.transpose(neg_gate_conv, [0, 2, 1])
            neg_conv_k_max_pooling, _ = tf.nn.top_k(transpose_neg_gate_conv,
                                                    self.max_pooling_num)
            neg_encoder = tf.reshape(
                neg_conv_k_max_pooling,
                [-1, self.kernal_num * self.max_pooling_num])

            pos_decoder_mlp1 = tf.layers.dense(inputs=pos_encoder,
                                               units=self.decoder_mlp1_num,
                                               activation=tf.nn.tanh,
                                               trainable=True,
                                               name='decoder_mlp1')
            neg_encoder_mlp1 = tf.layers.dense(inputs=neg_encoder,
                                               units=self.decoder_mlp1_num,
                                               activation=tf.nn.tanh,
                                               trainable=True,
                                               name='decoder_mlp1',
                                               reuse=True)

            self.pos_score = tf.layers.dense(inputs=pos_decoder_mlp1,
                                             units=1,
                                             activation=tf.nn.tanh,
                                             trainable=True,
                                             name='decoder_mlp2')
            neg_score = tf.layers.dense(inputs=neg_encoder_mlp1,
                                        units=1,
                                        activation=tf.nn.tanh,
                                        trainable=True,
                                        name='decoder_mlp2',
                                        reuse=True)
            hinge_loss = tf.reduce_mean(
                tf.maximum(0.0, 1 - self.pos_score + neg_score))
            tf.summary.scalar('hinge_loss', hinge_loss)

            # 对抗学习
            adv_weight = weight_init([self.decoder_mlp1_num, self.class_num],
                                     'adv_weights')
            adv_bias = tf.Variable(initial_value=tf.zeros(self.class_num, ),
                                   name='adv_bias')

            adv_prob = tf.nn.softmax(
                tf.add(tf.matmul(pos_decoder_mlp1, adv_weight), adv_bias))
            adv_prob_log = tf.log(adv_prob)
            adv_loss = tf.reduce_mean(
                tf.reduce_sum(tf.multiply(adv_prob_log,
                                          tf.cast(self.input_l, tf.float32)),
                              axis=1))
            adv_l2_loss = self.regular_term * l2_loss([
                v for v in tf.trainable_variables()
                if 'b' not in v.name and 'adv' in v.name
            ])
            loss_cat = -1 * adv_loss + adv_l2_loss
            # 动态学习率
            self.lr = tf.Variable(self.model_learning_rate,
                                  trainable=False,
                                  name='learning_rate')
            self.adv_train_op = tf.train.AdamOptimizer(learning_rate=self.lr, epsilon=self.epsilon)\
                .minimize(loss_cat, var_list=[v for v in tf.trainable_variables() if 'adv' in v.name])
            tf.summary.scalar('loss_cat', loss_cat)

            model_l2_loss = self.regular_term * l2_loss([
                v for v in tf.trainable_variables()
                if 'b' not in v.name and 'adv' not in v.name
            ])
            loss = hinge_loss + model_l2_loss + adv_loss * self.adv_loss_term
            self.model_train_op = tf.train.AdamOptimizer(learning_rate=self.lr, epsilon=self.epsilon)\
                .minimize(loss, var_list=[v for v in tf.trainable_variables() if 'adv' not in v.name])
            tf.summary.scalar('loss', loss)
            self.merged = tf.summary.merge_all()

    def analysis_result(self, sess, test_df):
        """分析预测结果"""
        result = dd1_list()
        for batch in dp.input_scale_data_test(test_df,
                                              self.max_pooling_num,
                                              batch_size=self.batch_size):
            query_dict = batch['query_dict']
            query_len_dict = batch['query_len_dict']
            doc = batch['doc']
            l_context = batch['l_context']
            l_label = batch['l_label']
            # 生成结果字典
            result['real_label'] += l_label
            result['context'] += l_context
            for label, query in query_dict.items():
                score = sess.run(self.pos_score,
                                 feed_dict={
                                     self.input_q: query,
                                     self.input_pos_d: doc,
                                     self.input_q_len: query_len_dict[label]
                                 })
                result[label] += list(np.reshape(score, (-1, )))
        result_df = pd.DataFrame(result)
        # 自动获取标签
        labels = [label for label, query in query_dict.items()]
        # 写死标签顺序, 其他数据集需要修改此标签顺序
        # labels = ['very negative', 'negative', 'neutral', 'positive', 'very positive']

        result_df['pre_label'] = result_df[labels].idxmax(axis=1)
        f1 = metrics.f1_score(result_df['real_label'],
                              result_df['pre_label'],
                              average='weighted')
        result_report = classification_report(result_df['real_label'],
                                              result_df['pre_label'])
        con_matrix = confusion_matrix(result_df['real_label'],
                                      result_df['pre_label'],
                                      labels=labels)
        return result_df, result_report, f1, labels, con_matrix

    def test(self):
        """模型预测"""
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config, graph=self.g) as sess:
            train_vars = [v for v in tf.trainable_variables()]
            saver = tf.train.Saver(var_list=train_vars)
            saver.restore(sess, self.ckpt_path +
                          '-77212')  # 是否自动读取最新的一次,还是要手动设置保存的第几次
            test_df = pd.read_csv(self.test_file, engine='python')
            # 打开调试模式
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            # 定义结果
            result_df, result_report, f1, labels, con_matrix = self.analysis_result(
                sess, test_df)
            logger.info('生成测试集预测统计,并保存具体预测结果:')
            logger.info('\n' + result_report)
            result_df.to_csv(self.test_result_file)

    # def train(self):
    #     """单独一次零样本模型训练"""
    #     config = tf.ConfigProto()
    #     config.gpu_options.allow_growth = True
    #     with tf.Session(config=config, graph=self.g) as sess:
    #         train_vars = [v for v in tf.trainable_variables()]
    #         saver = tf.train.Saver(var_list=train_vars, max_to_keep=3)
    #         sess.run(tf.global_variables_initializer())
    #         writer = tf.summary.FileWriter(self.summary_path, self.g)
    #         train_df = pd.read_csv(self.train_file, engine='python')
    #         test_df = pd.read_csv(self.test_file, engine='python')
    #         step = 0
    #         # 打开调试模式
    #         # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
    #         logger.info('训练批次大小:{}'.format(self.batch_size))
    #         for epoch in range(int(self.max_epochs)):
    #             if epoch+1 % 25 == 0:
    #                 sess.run(self.lr.assign(self.model_learning_rate / 5.0))
    #             # sess.run(self.lr.assign(self.model_learning_rate * 0.95**epoch))
    #             for batch in dp.input_scale_data_train(train_df, self.train_labels.split(','), batch_size=self.batch_size, min_len=self.max_pooling_num):
    #                 query = batch['query']
    #                 label = batch['label']
    #                 pos_doc = batch['pos_doc']
    #                 neg_doc = batch['neg_doc']
    #                 query_len = batch['query_len']
    #                 merged, ato, mto = sess.run([self.merged, self.adv_train_op, self.model_train_op], feed_dict={self.input_q: query, self.input_l: label, self.input_pos_d: pos_doc, self.input_neg_d: neg_doc, self.input_q_len: query_len})
    #                 # 可视化
    #                 writer.add_summary(merged, global_step=step)
    #                 writer.flush()
    #                 step += 1
    #             # 每5次做一次验证并打印
    #             if epoch % 5 == 0:
    #                 logger.info('第{}次训练结果:'.format(epoch+1))
    #                 logger.info('训练样本和零样本的准确率:')
    #                 self.analysis_result(sess, train_df)
    #                 logger.info('测试样本和零样本的准确率:')
    #                 self.analysis_result(sess, test_df)
    #             saver.save(sess, self.ckpt_path, global_step=step)
    #             logger.info('第{}次epoch训练完成。'.format(epoch+1))
    #         writer.close()
    #     pass

    def train_all(self):
        """自动循环训练所有标签,使得每一个标签都做为零样本标签,并观察其泛化效果"""
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        now_time = datetime.now().strftime('%Y%m%d%H%M')
        train_df = pd.read_csv(self.train_file, engine='python')
        test_df = pd.read_csv(self.test_file, engine='python')
        logger.info(
            '------------------------------------------------------------------------------------------------'
        )
        logger.info('启动全样本训练, 时间:{}'.format(now_time))

        for zero_label, train_labels, train_data, test_data in reorganize_data(
                self.all_labels, train_df, test_df):
            #  单独对negative零样本标签,进行调参
            if zero_label != 'very positive':
                continue
            else:
                # self.regular_term = 0.0005  # negative 正则项系数
                # self.regular_term = 0.0001  # very negative 正则项系数
                self.regular_term = 0.001  # positive , very positive 正则项系数
                self.adv_learning_rate = 0.01  # positive , very positive 学习率

            with tf.Session(config=config, graph=self.g) as sess:
                train_vars = [v for v in tf.trainable_variables()]
                saver = tf.train.Saver(var_list=train_vars, max_to_keep=3)
                sess.run(tf.global_variables_initializer())
                max_f1 = 0.0
                best_test_result_report = ''
                best_con_matrix_report = ''

                step = 0
                # 保存在特定标签下路径中
                writer = tf.summary.FileWriter(
                    self.summary_path.format(now_time,
                                             zero_label.replace(' ', '')),
                    self.g)
                logger.info('开启学习---零样本标签:{}'.format(zero_label))
                logger.info(
                    '本次训练embedding_size:{}. max_epochs:{}, batch_size: {}, regular_term:{}, adv_learning_rate:'
                    '{}, epsilon:{}, adv_loss_term:{}'.format(
                        self.embedding_size, self.max_epochs, self.batch_size,
                        self.regular_term, self.adv_learning_rate,
                        self.epsilon, self.adv_loss_term))
                # 动态学习率
                for epoch in range(int(self.max_epochs)):
                    # 动态学习率
                    # if epoch + 1 % 25 == 0:
                    #     sess.run(self.lr.assign(self.model_learning_rate / 5.0))  # 倍数缩小设置方式
                    #     sess.run(self.lr.assign(self.model_learning_rate * 0.95**epoch))  # 指数设置方式
                    for batch in dp.input_scale_data_train(
                            train_data,
                            train_labels,
                            batch_size=self.batch_size,
                            min_len=self.max_pooling_num):
                        query = batch['query']
                        label = batch['label']
                        pos_doc = batch['pos_doc']
                        neg_doc = batch['neg_doc']
                        query_len = batch['query_len']
                        merged, ato, mto = sess.run(
                            [
                                self.merged, self.adv_train_op,
                                self.model_train_op
                            ],
                            feed_dict={
                                self.input_q: query,
                                self.input_l: label,
                                self.input_pos_d: pos_doc,
                                self.input_neg_d: neg_doc,
                                self.input_q_len: query_len
                            })
                        # 可视化
                        writer.add_summary(merged, global_step=step)
                        writer.flush()
                        step += 1
                    # 每5次做一次验证并打印
                    if epoch % 5 == 0:
                        logger.info('第{}次训练结果:'.format(epoch + 1))
                        logger.info('训练样本的准确率:')
                        train_result_df, train_result_report, train_f1, train_labels, train_con_matrix = self.analysis_result(
                            sess, train_data)
                        logger.info('\n' + train_result_report)
                        logger.info('测试样本和零样本的准确率:')
                        test_result_df, test_result_report, test_f1, test_labels, test_con_matrix = self.analysis_result(
                            sess, test_data)
                        logger.info('\n' + test_result_report)
                        logger.info('测试样本和零样本的混淆矩阵:{}'.format(test_labels))
                        test_con_matrix_report = con_matrix_2_str(
                            test_labels, test_con_matrix)
                        logger.info('\n' + test_con_matrix_report)
                        if test_f1 > max_f1:
                            # 保存权重值
                            saver.save(sess,
                                       self.ckpt_path.format(
                                           now_time,
                                           zero_label.replace(' ', '')),
                                       global_step=step)
                            # 记录最好结果的统计报告 和 混淆矩阵
                            best_test_result_report = test_result_report
                            best_con_matrix_report = test_con_matrix_report
                            # 保存预测的结果
                            test_result_df.to_csv(
                                self.test_result_file.format(
                                    now_time, zero_label.replace(' ', '')))
                            max_f1 = test_f1
                    if epoch + 1 == int(self.max_epochs):
                        logger.info('测试集上最好的结果:')
                        logger.info('\n' + best_test_result_report)
                        logger.info('\n' + best_con_matrix_report)
                logger.info('第{}次epoch训练完成。'.format(epoch + 1))
                writer.close()
Пример #27
0
class DisplayDL1Calib(Tool):
    """Tool that calibrates events and displays their DL1 images."""

    name = "ctapipe-display-dl1"
    description = __doc__

    # None means "show every telescope".
    telescope = Int(
        None,
        allow_none=True,
        help="Telescope to view. Set to None to display all telescopes.",
    ).tag(config=True)

    extractor_product = traits.enum_trait(ImageExtractor,
                                          default="NeighborPeakWindowSum")

    aliases = Dict({
        "input": "EventSource.input_url",
        "max_events": "EventSource.max_events",
        "extractor": "DisplayDL1Calib.extractor_product",
        "T": "DisplayDL1Calib.telescope",
        "O": "ImagePlotter.output_path",
    })
    flags = Dict({
        "D": (
            {"ImagePlotter": {"display": True}},
            "Display the photo-electron images on-screen as they are produced.",
        ),
    })
    classes = List([EventSource, ImagePlotter] +
                   traits.classes_with_traits(ImageExtractor))

    def __init__(self, **kwargs):
        """Create the tool; components are attached later in setup()."""
        super().__init__(**kwargs)
        self.eventsource = None
        self.calibrator = None
        self.plotter = None

    def setup(self):
        """Attach the event source, calibrator and image plotter components."""
        source = EventSource.from_url(
            get_dataset_path("gamma_test_large.simtel.gz"), parent=self)
        self.eventsource = self.add_component(source)
        self.calibrator = self.add_component(CameraCalibrator(parent=self))
        self.plotter = self.add_component(ImagePlotter(parent=self))

    def start(self):
        """Calibrate each event and plot images for the selected telescope(s)."""
        for event in self.eventsource:
            self.calibrator(event)

            tels = event.r0.tels_with_data
            if self.telescope:
                # Restrict to the single requested telescope, if present.
                if self.telescope not in tels:
                    continue
                tels = [self.telescope]

            for tel_id in tels:
                self.plotter.plot(event, tel_id)

    def finish(self):
        """Finalize the plotter."""
        self.plotter.finish()
Пример #28
0
class DAZER(BaseNN):
    # Hyper-parameters of the zero-shot document filtering model (DAZER),
    # all configurable through traitlets (``.tag(config=True)``).
    embedding_size = Int(300, help="embedding dimension").tag(config=True)
    vocabulary_size = Int(2000000, help="vocabulary size").tag(config=True)
    kernal_width = Int(5, help='kernal width').tag(config=True)
    kernal_num = Int(50, help='number of kernal').tag(config=True)
    regular_term = Float(0.01, help='param for controlling wight of L2 loss').tag(config=True)
    maxpooling_num = Int(3, help='number of k-maxpooling').tag(config=True)
    decoder_mlp1_num = Int(75, help='number of hidden units of first mlp in relevance aggregation part').tag(config=True)
    decoder_mlp2_num = Int(1, help='number of hidden units of second mlp in relevance aggregation part').tag(config=True)
    emb_in = Unicode('None', help="initial embedding. Terms should be hashed to ids.").tag(config=True)
    model_learning_rate = Float(0.001, help="learning rate of model").tag(config=True)
    adv_learning_rate = Float(0.001, help='learning rate of adv classifier').tag(config=True)
    epsilon = Float(0.00001, help="Epsilon for Adam").tag(config=True)
    label_dict_path = Unicode('None', help='label dict path').tag(config=True)
    word2id_path = Unicode('None', help='word2id path').tag(config=True)
    train_class_num = Int(16, help='num of class in training data').tag(config=True)
    adv_term = Float(0.2, help='regular term of adversrial loss').tag(config=True)
    zsl_num = Int(1, help='num of zeroshot label').tag(config=True)
    zsl_type = Int(1, help='type of zeroshot label setting').tag(config=True)

    def __init__(self, **kwargs):
        """Initialize the DAZER model: embedding table, gate and adversarial
        variables, plus the label bookkeeping used by the adversarial loss."""
        super(DAZER, self).__init__(**kwargs)
        print ("trying to load initial embeddings from:  ", self.emb_in)
        if self.emb_in != 'None':
            # Pre-trained embeddings, frozen (trainable=False).
            self.emb = self.load_word2vec(self.emb_in)
            self.embeddings = tf.Variable(tf.constant(self.emb, dtype='float32', shape=[self.vocabulary_size + 1, self.embedding_size]),trainable=False)
            print ("Initialized embeddings with {0}".format(self.emb_in))
        else:
            # No embedding file: random uniform initialization, trainable.
            self.embeddings = tf.Variable(tf.random_uniform([self.vocabulary_size + 1, self.embedding_size], -1.0, 1.0))

        # Variables of the DAZER model (query gate and adversarial classifier).
        self.query_gate_weight = BaseNN.weight_variable((self.embedding_size, self.kernal_num),'gate_weight')
        self.query_gate_bias = tf.Variable(initial_value=tf.zeros((self.kernal_num)),name='gate_bias')
        self.adv_weight = BaseNN.weight_variable((self.decoder_mlp1_num,self.train_class_num),name='adv_weight')
        self.adv_bias = tf.Variable(initial_value=tf.zeros((1,self.train_class_num)),name='adv_bias')
        # Label information to support adversarial learning.
        self.label_dict, self.reverse_label_dict, self.label_list = get_label.get_labels(self.label_dict_path, self.word2id_path)
        self.label_index_dict = get_label.get_label_index(self.label_list, self.zsl_num, self.zsl_type)

    def load_word2vec(self, emb_file_path):
        """Read pre-trained vectors (``id vec...`` per line, header skipped)
        into a (vocabulary_size + 1, embedding_size) matrix; row 0 stays zero."""
        matrix = np.zeros((self.vocabulary_size + 1, self.embedding_size))
        with open(emb_file_path) as emb_file:
            for line_no, line in enumerate(emb_file, start=1):
                if line_no == 1:
                    # First line is a header, not a vector.
                    continue
                fields = line.split()
                term_id = int(fields[0])
                if term_id > self.vocabulary_size:
                    print (term_id)
                    continue
                matrix[term_id, :] = np.array([float(x) for x in fields[1:]])
                if line_no % 20000 == 0:
                    print ("load {0} vectors...".format(line_no))
        return matrix

    def gen_adv_query_mask(self, q_ids):
        """Build a one-hot class mask of shape (batch_size, train_class_num)
        for each query id in the batch."""
        mask = np.zeros((self.batch_size, self.train_class_num))
        for row, query_id in enumerate(q_ids):
            class_name = self.reverse_label_dict[query_id]
            mask[row][self.label_index_dict[class_name]] = 1
        return mask

    def get_class_gate(self,class_vec, emb_d):
        '''
        Compute the sigmoid gate in kernel space.
        :param class_vec: average embedding of the seed words
        :param emb_d: document embedding (not used in the computation itself)
        :return: the class gate, broadcastable to [batchsize, d_len, kernal_num]
        '''
        projected = tf.expand_dims(tf.matmul(class_vec, self.query_gate_weight), axis=1)
        shifted = projected + tf.expand_dims(self.query_gate_bias, axis=0)
        return tf.sigmoid(shifted)

    def L2_model_loss(self):
        """Sum of L2 penalties over non-bias, non-adversarial trainable variables."""
        weights = [v for v in tf.trainable_variables()
                   if 'b' not in v.name and 'adv' not in v.name]
        return sum((tf.nn.l2_loss(w) for w in weights), 0.)

    def L2_adv_loss(self):
        """Sum of L2 penalties over non-bias adversarial-classifier variables."""
        weights = [v for v in tf.trainable_variables()
                   if 'b' not in v.name and 'adv' in v.name]
        return sum((tf.nn.l2_loss(w) for w in weights), 0.)

    def train(self, train_pair_file_path, val_pair_file_path, checkpoint_dir, load_model=False):

        input_q = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_q_len])
        input_pos_d = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_d_len])
        input_neg_d = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_d_len])
        q_lens = tf.placeholder(tf.float32, shape=[self.batch_size,])
        q_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_q_len])
        pos_d_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_d_len])
        neg_d_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_d_len])
        input_q_index = tf.placeholder(tf.int32, shape=[self.batch_size,self.train_class_num])

        emb_q = tf.nn.embedding_lookup(self.embeddings,input_q)
        class_vec_sum = tf.reduce_sum(
            tf.multiply(emb_q,tf.expand_dims(q_mask,axis=-1)),
            axis=1
        )

        #get class vec
        class_vec = tf.div(class_vec_sum,tf.expand_dims(q_lens,-1))
        emb_pos_d = tf.nn.embedding_lookup(self.embeddings,input_pos_d)
        emb_neg_d = tf.nn.embedding_lookup(self.embeddings,input_neg_d)

        #get query gate
        pos_query_gate = self.get_class_gate(class_vec, emb_pos_d)
        neg_query_gate = self.get_class_gate(class_vec, emb_neg_d)

        # CNN for document
        pos_mult_info = tf.multiply(tf.expand_dims(class_vec, axis=1), emb_pos_d)
        pos_sub_info = tf.expand_dims(class_vec,axis=1) - emb_pos_d
        pos_conv_input = tf.concat([emb_pos_d,pos_mult_info,pos_sub_info], axis=-1)
        
        neg_mult_info = tf.multiply(tf.expand_dims(class_vec, axis=1), emb_neg_d)
        neg_sub_info = tf.expand_dims(class_vec,axis=1) - emb_neg_d
        neg_conv_input = tf.concat([emb_neg_d,neg_mult_info,neg_sub_info], axis=-1)


        #in fact that's 1D conv, but we implement it by conv2d
        pos_conv = tf.layers.conv2d(
            inputs = tf.expand_dims(pos_conv_input,axis=-1),
            filters = self.kernal_num,
            kernel_size=[self.kernal_width,self.embedding_size*3],
            strides = [1,self.embedding_size*3],
            padding = 'SAME',
            trainable = True,
            name='doc_conv'
        )

        neg_conv = tf.layers.conv2d(
            inputs = tf.expand_dims(neg_conv_input,axis=-1),
            filters = self.kernal_num,
            kernel_size=[self.kernal_width,self.embedding_size*3],
            strides = [1,self.embedding_size*3],
            padding = 'SAME',
            trainable = True,
            name='doc_conv',
            reuse=True
        )
        #shape=[batch,max_dlen,1,kernal_num]
        #reshape to [batch,max_dlen,kernal_num]
        rs_pos_conv = tf.squeeze(pos_conv)
        rs_neg_conv = tf.squeeze(neg_conv)

        #query_gate elment-wise multiply rs_pos_conv
        pos_gate_conv = tf.multiply(pos_query_gate, rs_pos_conv)
        neg_gate_conv = tf.multiply(neg_query_gate, rs_neg_conv)

        #K-max_pooling
        #transpose to [batch,knum,dlen],then get max k in each kernal filter
        transpose_pos_gate_conv = tf.transpose(pos_gate_conv, perm=[0,2,1])
        transpose_neg_gate_conv = tf.transpose(neg_gate_conv, perm=[0,2,1])

        #shape = [batch,k_num,maxpolling_num]
        #the k-max pooling here is implemented by function top_k, so the relative position information is ignored
        pos_kmaxpooling,_ = tf.nn.top_k(
            input=transpose_pos_gate_conv,
            k=self.maxpooling_num,
        )
        neg_kmaxpooling,_ = tf.nn.top_k(
            input=transpose_neg_gate_conv,
            k=self.maxpooling_num,
        )

        pos_encoder = tf.reshape(pos_kmaxpooling, shape=(self.batch_size,-1))
        neg_encoder = tf.reshape(neg_kmaxpooling, shape=(self.batch_size,-1))

        pos_decoder_mlp1 = tf.layers.dense(
            inputs=pos_encoder,
            units=self.decoder_mlp1_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp1'
        )

        neg_decoder_mlp1 = tf.layers.dense(
            inputs=neg_encoder,
            units=self.decoder_mlp1_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp1',
            reuse=True
        )

        pos_decoder_mlp2 = tf.layers.dense(
            inputs=pos_decoder_mlp1,
            units=self.decoder_mlp2_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp2'
        )

        neg_decoder_mlp2 = tf.layers.dense(
            inputs=neg_decoder_mlp1,
            units=self.decoder_mlp2_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp2',
            reuse=True
        )

        score_pos = pos_decoder_mlp2
        score_neg = neg_decoder_mlp2

        hinge_loss = tf.reduce_mean(tf.maximum(0.0, 1 - score_pos + score_neg))
        adv_prob = tf.nn.softmax(tf.add(tf.matmul(pos_decoder_mlp1, self.adv_weight), self.adv_bias))
        log_adv_prob = tf.log(adv_prob)
        adv_loss = tf.reduce_mean(tf.reduce_sum(tf.multiply(log_adv_prob, tf.cast(input_q_index,tf.float32)), axis=1, keep_dims=True))
        L2_adv_loss = self.regular_term*self.L2_adv_loss()

        #to apply GRL, we use two seperate optimizers for adversarial classifier and the rest part of DAZER
        #optimizer for adversarial classifier
        adv_var_list = [v for v in tf.trainable_variables() if 'adv' in v.name]
        adv_opt = tf.train.AdamOptimizer(learning_rate=self.adv_learning_rate, epsilon=self.epsilon).minimize(loss=(-1 * adv_loss + L2_adv_loss), var_list=adv_var_list)

        #optimizer for rest part of DAZER model
        L2_model_loss = self.regular_term*self.L2_model_loss()
        model_var_list = [v for v in tf.trainable_variables() if 'adv' not in v.name]
        loss = hinge_loss + L2_model_loss + (adv_loss * self.adv_term)
        model_opt = tf.train.AdamOptimizer(learning_rate=self.model_learning_rate, epsilon=self.epsilon).minimize(loss = loss, var_list = model_var_list)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        val_results = []
        save_num = 0
        save_var = [v for v in tf.trainable_variables()]

        # Create a local session to run the training.
        with tf.Session(config=config) as sess:
            saver = tf.train.Saver(max_to_keep=50,var_list=save_var)
            start_time = time.time()
            if not load_model:
                print ("Initializing a new model...")
                init = tf.global_variables_initializer()
                sess.run(init)
                print('New model initialized!')
            else:
                #to load trained model, and keep training
                #remember to change the name of ckpt file
                init = tf.global_variables_initializer()
                sess.run(init)
                saver.restore(sess, checkpoint_dir+'/zsl25.ckpt')
                print ("model loaded!")

            # Loop through training steps.
            step = 0
            loss_list = []
            for epoch in range(int(self.max_epochs)):
                epoch_val_loss = 0
                epoch_loss = 0
                epoch_hinge_loss = 0.
                epoch_adv_loss = 0
                epoch_s = time.time()
                pair_stream = open(train_pair_file_path)

                for BATCH in self.data_generator.pairwise_reader(pair_stream, self.batch_size):
                    step += 1
                    X, Y = BATCH
                    query = X[u'q']
                    str_query = X[u'q_str']
                    q_index = self.gen_adv_query_mask(str_query)
                    pos_doc = X[u'd']
                    neg_doc = X[u'd_aux']
                    train_q_lens = X[u'q_lens']
                    M_query = self.gen_query_mask(query)
                    M_pos = self.gen_doc_mask(pos_doc)
                    M_neg = self.gen_doc_mask(neg_doc)

                    if X[u'q_lens'].shape[0] != self.batch_size:
                        continue
                    train_feed_dict = {input_q:query,
                                       input_pos_d:pos_doc,
                                       q_lens:train_q_lens,
                                       input_neg_d:neg_doc,
                                       q_mask:M_query,
                                       pos_d_mask:M_pos,
                                       neg_d_mask:M_neg,
                                       input_q_index: q_index}

                    _1,l,hinge_l,_2,adv_l  = sess.run([model_opt,loss,hinge_loss,adv_opt,adv_loss], feed_dict=train_feed_dict)
                    epoch_loss += l
                    epoch_hinge_loss += hinge_l
                    epoch_adv_loss += adv_l

                if (epoch + 1) % self.eval_frequency == 0:
                    #after eval_frequency epochs we run model on val dataset
                    val_start = time.time()
                    val_pair_stream = open(val_pair_file_path)
                    for BATCH in self.val_data_generator.pairwise_reader(val_pair_stream, self.batch_size):
                        X_val,Y_val = BATCH
                        query = X_val[u'q']
                        pos_doc = X_val[u'd']
                        neg_doc = X_val[u'd_aux']
                        val_q_lens = X_val[u'q_lens']
                        M_query = self.gen_query_mask(query)
                        M_pos = self.gen_doc_mask(pos_doc)
                        M_neg = self.gen_doc_mask(neg_doc)
                        if X_val[u'q'].shape[0] != self.batch_size:
                            continue
                        train_feed_dict = {input_q:query,
                                           input_pos_d:pos_doc,
                                           input_neg_d:neg_doc,
                                           q_lens:val_q_lens,
                                           q_mask:M_query,
                                           pos_d_mask:M_pos,
                                           neg_d_mask:M_neg}

                        # Run the graph and fetch some of the nodes.
                        v_loss = sess.run(hinge_loss, feed_dict=train_feed_dict)
                        epoch_val_loss += v_loss
                        val_results.append(epoch_val_loss)

                    val_end = time.time()
                    print('---Validation:epoch %d, %.1f ms , val_loss are %f' % (epoch+1,val_end-val_start,epoch_val_loss))
                    sys.stdout.flush()
                loss_list.append(epoch_loss)
                epoch_e = time.time()
                print('---Train:%d epoches cost %f seconds, hinge cost = %f  model cost = %f, adv cost = %f...'%(epoch+1,epoch_e-epoch_s,epoch_hinge_loss, epoch_loss,epoch_adv_loss))
                # save model after checkpoint_steps epochs
                if (epoch+1)%self.checkpoint_steps == 0:
                    save_num += 1
                    saver.save(sess, checkpoint_dir + 'zsl'+str(epoch+1)+'.ckpt')
                pair_stream.close()

            with open('save_training_loss.txt','w') as f:
                for index,_loss in enumerate(loss_list):
                    f.write('epoch'+str(index+1)+', loss:'+str(_loss)+'\n')

            with open('save_val_cost.txt','w') as f:
                for index, v_l in enumerate(val_results):
                    f.write('epoch'+str((index+1)*self.eval_frequency)+' val loss:'+str(v_l)+'\n')

            # end training
            end_time = time.time()
            print('All costs %f seconds...'%(end_time-start_time))

    def test(self, test_point_file_path, test_size, output_file_path, checkpoint_dir=None, load_model=False):
        """Score the test set with each saved checkpoint.

        Rebuilds the positive-document half of the DAZER scoring graph
        (reusing the variable names from training), then for every saved
        checkpoint (one per ``checkpoint_steps`` epochs) restores the weights
        and writes one score per test document to
        ``output_file_path-epoch<N>.txt``.

        :param test_point_file_path: path of the test input file
        :param test_size: number of test examples, used to compute batch count
        :param output_file_path: prefix for the per-epoch score output files
        :param checkpoint_dir: directory holding ``zsl<N>.ckpt`` checkpoints
        :param load_model: if False, scores with freshly initialized random
            weights -- presumably only useful as a smoke test; confirm
        """

        # Placeholders for one batch of queries/documents and their masks.
        input_q = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_q_len])
        input_pos_d = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_d_len])
        q_lens = tf.placeholder(tf.float32, shape=[self.batch_size,])
        q_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_q_len])
        pos_d_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_d_len])

        # Class (query) vector: mean of the embeddings of the query's words
        # (mask zeroes out padding; division by q_lens gives the mean).
        emb_q = tf.nn.embedding_lookup(self.embeddings,input_q)
        class_vec_sum = tf.reduce_sum(
            tf.multiply(emb_q,tf.expand_dims(q_mask,axis=-1)),
            axis=1
        )

        class_vec = tf.div(class_vec_sum,tf.expand_dims(q_lens,axis=-1))
        emb_pos_d = tf.nn.embedding_lookup(self.embeddings,input_pos_d)

        #get query gate
        query_gate = self.get_class_gate(class_vec, emb_pos_d)
        # Interaction features: element-wise product and difference between
        # the class vector and each document word embedding.
        pos_mult_info = tf.multiply(tf.expand_dims(class_vec, axis=1), emb_pos_d)
        pos_sub_info = tf.expand_dims(class_vec, axis=1) - emb_pos_d
        pos_conv_input = tf.concat([emb_pos_d,pos_mult_info, pos_sub_info], axis=-1)

        # CNN for document (1D conv implemented via conv2d, as in training)
        pos_conv = tf.layers.conv2d(
            inputs = tf.expand_dims(pos_conv_input,axis=-1),
            filters = self.kernal_num,
            kernel_size=[self.kernal_width,self.embedding_size*3],
            strides = [1,self.embedding_size*3],
            padding = 'SAME',
            trainable = True,
            name='doc_conv'
        )

        #shape=[batch,max_dlen,1,kernal_num]
        #reshape to [batch,max_dlen,kernal_num]
        rs_pos_conv = tf.squeeze(pos_conv)

        #query_gate element-wise multiply rs_pos_conv
        #[batch,kernal_num] , [batch,max_dlen,kernal_num]
        pos_gate_conv = tf.multiply(query_gate, rs_pos_conv)

        #K-max_pooling
        #transpose to [batch,knum,dlen],then get max k in each kernal filter
        transpose_pos_gate_conv = tf.transpose(pos_gate_conv, perm=[0,2,1])

        #[batch,k_num,maxpolling_num]
        pos_kmaxpooling,_ = tf.nn.top_k(
            input=transpose_pos_gate_conv,
            k=self.maxpooling_num,
        )
        pos_encoder = tf.reshape(pos_kmaxpooling, shape=(self.batch_size,-1))

        # Two-layer MLP decoder; variable names match training so the saved
        # checkpoints restore into these layers.
        pos_decoder_mlp1 = tf.layers.dense(
            inputs=pos_encoder,
            units=self.decoder_mlp1_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp1'
        )

        pos_decoder_mlp2 = tf.layers.dense(
            inputs=pos_decoder_mlp1,
            units=self.decoder_mlp2_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp2'
        )

        score_pos = pos_decoder_mlp2
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        save_var = [v for v in tf.trainable_variables()]
        # Create a local session to run the testing.
        # One session (and one output file) per saved checkpoint.
        for i in range(int(self.max_epochs/self.checkpoint_steps)):
            with tf.Session(config=config) as sess:
                test_point_stream = open(test_point_file_path)
                outfile = open(output_file_path+'-epoch'+str(self.checkpoint_steps*(i+1))+'.txt', 'w')
                saver = tf.train.Saver(var_list=save_var)

                if load_model:
                    p = checkpoint_dir + 'zsl'+str(self.checkpoint_steps*(i+1))+'.ckpt'
                    init = tf.global_variables_initializer()
                    sess.run(init)
                    saver.restore(sess, p)
                    print ("data loaded!")
                else:
                    init = tf.global_variables_initializer()
                    sess.run(init)

                # Loop through test batches.
                for b in range(int(np.ceil(float(test_size)/self.batch_size))):
                    X = next(self.test_data_generator.test_pairwise_reader(test_point_stream, self.batch_size))
                    # NOTE(review): a final partial batch is skipped entirely,
                    # so up to batch_size-1 trailing examples get no score.
                    if(X[u'q'].shape[0] != self.batch_size):
                        continue
                    query = X[u'q']
                    pos_doc = X[u'd']
                    test_q_lens = X[u'q_lens']
                    M_query = self.gen_query_mask(query)
                    M_pos = self.gen_doc_mask(pos_doc)
                    test_feed_dict = {input_q: query,
                                       input_pos_d: pos_doc,
                                       q_lens: test_q_lens,
                                       q_mask: M_query,
                                      pos_d_mask: M_pos}

                    # Run the graph and fetch some of the nodes.
                    scores = sess.run(score_pos, feed_dict=test_feed_dict)

                    for score in scores:
                        outfile.write('{0}\n'.format(score[0]))

                outfile.close()
                test_point_stream.close()
Пример #29
0
class DockerSpawner(Spawner):
    """A Spawner for JupyterHub that runs each user's server in a separate docker container"""

    # class-level cache for the shared ThreadPoolExecutor (see executor property)
    _executor = None

    @property
    def executor(self):
        """Return the process-wide, lazily created executor.

        Shared by every spawner instance (stored on the class), so all
        docker API calls are serialized on a single worker thread.
        """
        klass = type(self)
        if klass._executor is None:
            klass._executor = ThreadPoolExecutor(1)
        return klass._executor

    # class-level cache for the shared docker client (see client property)
    _client = None

    @property
    def client(self):
        """Return the single docker APIClient shared by all spawners.

        Built on first access from (in increasing precedence) the default
        ``{"version": "auto"}``, optional TLS settings, the docker
        environment variables, and finally ``client_kwargs``.
        """
        klass = type(self)
        if klass._client is None:
            opts = {"version": "auto"}
            if self.tls_config:
                opts["tls"] = docker.tls.TLSConfig(**self.tls_config)
            opts.update(kwargs_from_env())
            opts.update(self.client_kwargs)
            klass._client = docker.APIClient(**opts)
        return klass._client

    # notice when user has set the command
    # default command is that of the container,
    # but user can override it via config
    _user_set_cmd = False

    @observe("cmd")
    def _cmd_changed(self, change):
        self._user_set_cmd = True

    # id of the docker object we spawned (empty until created)
    object_id = Unicode()
    # the type of object we create (also the docker API method suffix,
    # e.g. inspect_container / remove_container)
    object_type = "container"
    # the field in the docker inspect result containing the object id
    object_id_key = "Id"

    @property
    def container_id(self):
        """Backward-compatible alias for :attr:`object_id`."""
        return self.object_id

    @property
    def container_name(self):
        """Backward-compatible alias for :attr:`object_name`."""
        return self.object_name

    # deprecate misleading container_ip, since
    # it is not the ip in the container,
    # but the host ip of the port forwarded to the container
    # when use_internal_ip is False
    container_ip = Unicode("127.0.0.1", config=True)

    @observe("container_ip")
    def _container_ip_deprecated(self, change):
        self.log.warning(
            "DockerSpawner.container_ip is deprecated in dockerspawner-0.9."
            "  Use DockerSpawner.host_ip to specify the host ip that is forwarded to the container"
        )
        self.host_ip = change.new

    host_ip = Unicode(
        "127.0.0.1",
        help=
        """The ip address on the host on which to expose the container's port

        Typically 127.0.0.1, but can be public interfaces as well
        in cases where the Hub and/or proxy are on different machines
        from the user containers.

        Only used when use_internal_ip = False.
        """,
        config=True,
    )

    @default('host_ip')
    def _default_host_ip(self):
        """Derive the default host ip from ``$DOCKER_HOST`` (tcp:// only)."""
        docker_host = os.getenv('DOCKER_HOST')
        if not docker_host:
            return '127.0.0.1'
        parsed = urlparse(docker_host)
        if parsed.scheme == 'tcp':
            return parsed.hostname
        return '127.0.0.1'

    # unlike container_ip, container_port is the internal port
    # on which the server is bound.
    container_port = Int(8888, min=1, max=65535, config=True)

    @observe("container_port")
    def _container_port_changed(self, change):
        self.log.warning(
            "DockerSpawner.container_port is deprecated in dockerspawner 0.9."
            "  Use DockerSpawner.port")
        self.port = change.new

    # fix default port to 8888, used in the container

    @default("port")
    def _port_default(self):
        return 8888

    # default to listening on all-interfaces in the container

    @default("ip")
    def _ip_default(self):
        return "0.0.0.0"

    container_image = Unicode("jupyterhub/singleuser:%s" % _jupyterhub_xy,
                              config=True)

    @observe("container_image")
    def _container_image_changed(self, change):
        self.log.warning(
            "DockerSpawner.container_image is deprecated in dockerspawner 0.9."
            "  Use DockerSpawner.image")
        self.image = change.new

    image = Unicode(
        "jupyterhub/singleuser:%s" % _jupyterhub_xy,
        config=True,
        help="""The image to use for single-user servers.

        This image should have the same version of jupyterhub as
        the Hub itself installed.

        If the default command of the image does not launch
        jupyterhub-singleuser, set `c.Spawner.cmd` to
        launch jupyterhub-singleuser, e.g.

        Any of the jupyter docker-stacks should work without additional config,
        as long as the version of jupyterhub in the image is compatible.
        """,
    )

    image_whitelist = Union(
        [Any(), Dict(), List()],
        default_value={},
        config=True,
        help="""
        List or dict of images that users can run.

        If specified, users will be presented with a form
        from which they can select an image to run.

        If a dictionary, the keys will be the options presented to users
        and the values the actual images that will be launched.

        If a list, will be cast to a dictionary where keys and values are the same
        (i.e. a shortcut for presenting the actual images directly to users).

        If a callable, will be called with the Spawner instance as its only argument.
        The user is accessible as spawner.user.
        The callable should return a dict or list as above.
        """,
    )

    @validate('image_whitelist')
    def _image_whitelist_dict(self, proposal):
        """Normalize a list-valued whitelist to an ``{item: item}`` dict.

        Dicts (and callables) pass through unchanged.
        """
        value = proposal.value
        if isinstance(value, list):
            return {image: image for image in value}
        return value

    def _get_image_whitelist(self):
        """Resolve ``image_whitelist`` to a dict.

        A callable whitelist is invoked with this spawner as its only
        argument; list-like results are normalized to ``{item: item}``.
        Plain dicts are returned as-is.
        """
        whitelist = self.image_whitelist
        if not callable(whitelist):
            return whitelist
        resolved = whitelist(self)
        if isinstance(resolved, dict):
            return resolved
        # always return a dict
        return {item: item for item in resolved}

    @default('options_form')
    def _default_options_form(self):
        """Build the HTML ``<select>`` form for choosing an image.

        Returns '' (no form) unless the whitelist offers at least two
        images to choose from.
        """
        image_whitelist = self._get_image_whitelist()
        if len(image_whitelist) <= 1:
            # default form only when there are images to choose from
            return ''
        # form derived from wrapspawner.ProfileSpawner
        option_t = '<option value="{image}" {selected}>{image}</option>'
        options = [
            option_t.format(image=image,
                            selected='selected' if image == self.image else '')
            for image in image_whitelist
        ]
        # BUG FIX: the list must be joined; interpolating it directly
        # rendered the Python list repr (['<option ...>', ...]) into the HTML.
        return """
        <label for="image">Select an image:</label>
        <select class="form-control" name="image" required autofocus>
        {options}
        </select>
        """.format(options="\n".join(options))

    def options_from_form(self, formdata):
        """Extract the selected image (if any) from submitted form data."""
        selected = {}
        if 'image' in formdata:
            selected['image'] = formdata['image'][0]
        return selected

    container_prefix = Unicode(config=True,
                               help="DEPRECATED in 0.10. Use prefix")

    container_name_template = Unicode(
        config=True, help="DEPRECATED in 0.10. Use name_template")

    @observe("container_name_template", "container_prefix")
    def _deprecate_container_alias(self, change):
        new_name = change.name[len("container_"):]
        setattr(self, new_name, change.new)

    prefix = Unicode(
        "jupyter",
        config=True,
        help=dedent("""
            Prefix for container names. See name_template for full container name for a particular
            user's server.
            """),
    )

    name_template = Unicode(
        "{prefix}-{username}",
        config=True,
        help=dedent("""
            Name of the container or service: with {username}, {imagename}, {prefix} replacements.
            The default name_template is <prefix>-<username> for backward compatibility.
            """),
    )

    client_kwargs = Dict(
        config=True,
        help=
        "Extra keyword arguments to pass to the docker.Client constructor.",
    )

    volumes = Dict(
        config=True,
        help=dedent("""
            Map from host file/directory to container (guest) file/directory
            mount point and (optionally) a mode. When specifying the
            guest mount point (bind) for the volume, you may use a
            dict or str. If a str, then the volume will default to a
            read-write (mode="rw"). With a dict, the bind is
            identified by "bind" and the "mode" may be one of "rw"
            (default), "ro" (read-only), "z" (public/shared SELinux
            volume label), and "Z" (private/unshared SELinux volume
            label).

            If format_volume_name is not set,
            default_format_volume_name is used for naming volumes.
            In this case, if you use {username} in either the host or guest
            file/directory path, it will be replaced with the current
            user's name.
            """),
    )

    read_only_volumes = Dict(
        config=True,
        help=dedent("""
            Map from host file/directory to container file/directory.
            Volumes specified here will be read-only in the container.

            If format_volume_name is not set,
            default_format_volume_name is used for naming volumes.
            In this case, if you use {username} in either the host or guest
            file/directory path, it will be replaced with the current
            user's name.
            """),
    )

    format_volume_name = Any(
        help=
        """Any callable that accepts a string template and a DockerSpawner instance as parameters in that order and returns a string.

        Reusable implementations should go in dockerspawner.VolumeNamingStrategy, tests should go in ...
        """).tag(config=True)

    def default_format_volume_name(template, spawner):
        # Default volume-name formatter: fills {username} from the spawner's
        # user. NOTE(review): deliberately written without ``self`` -- it is
        # used as a plain two-argument callable, not as a bound method.
        return template.format(username=spawner.user.name)

    @default("format_volume_name")
    def _get_default_format_volume_name(self):
        return default_format_volume_name

    use_docker_client_env = Bool(
        True,
        config=True,
        help="DEPRECATED. Docker env variables are always used if present.",
    )

    @observe("use_docker_client_env")
    def _client_env_changed(self):
        self.log.warning(
            "DockerSpawner.use_docker_client_env is deprecated and ignored."
            "  Docker environment variables are always used if defined.")

    tls_config = Dict(
        config=True,
        help="""Arguments to pass to docker TLS configuration.

        See docker.client.TLSConfig constructor for options.
        """,
    )
    tls = tls_verify = tls_ca = tls_cert = tls_key = tls_assert_hostname = Any(
        config=True,
        help=
        """DEPRECATED. Use DockerSpawner.tls_config dict to set any TLS options.""",
    )

    @observe("tls", "tls_verify", "tls_ca", "tls_cert", "tls_key",
             "tls_assert_hostname")
    def _tls_changed(self, change):
        self.log.warning(
            "%s config ignored, use %s.tls_config dict to set full TLS configuration.",
            change.name,
            self.__class__.__name__,
        )

    remove_containers = Bool(
        False,
        config=True,
        help="DEPRECATED in DockerSpawner 0.10. Use .remove")

    @observe("remove_containers")
    def _deprecate_remove_containers(self, change):
        # preserve remove_containers alias to .remove
        self.remove = change.new

    remove = Bool(
        False,
        config=True,
        help="""
        If True, delete containers when servers are stopped.

        This will destroy any data in the container not stored in mounted volumes.
        """,
    )

    @property
    def will_resume(self):
        """Whether a stopped server can later be resumed.

        True when the container is kept around (``remove`` is False), which
        tells JupyterHub >= 0.7.1 not to clean up our API token.
        """
        return not self.remove

    extra_create_kwargs = Dict(
        config=True, help="Additional args to pass for container create")
    extra_host_config = Dict(
        config=True,
        help="Additional args to create_host_config for container create")

    # characters allowed unescaped in docker object names; everything else
    # is replaced by escape() using the escape char below
    _docker_safe_chars = set(string.ascii_letters + string.digits + "-")
    _docker_escape_char = "_"

    hub_ip_connect = Unicode(
        config=True,
        help=dedent("""
            If set, DockerSpawner will configure the containers to use
            the specified IP to connect the hub api.  This is useful
            when the hub_api is bound to listen on all ports or is
            running inside of a container.
            """),
    )

    @observe("hub_ip_connect")
    def _ip_connect_changed(self, change):
        if jupyterhub.version_info >= (0, 8):
            warnings.warn(
                "DockerSpawner.hub_ip_connect is no longer needed with JupyterHub 0.8."
                "  Use JupyterHub.hub_connect_ip instead.",
                DeprecationWarning,
            )

    use_internal_ip = Bool(
        False,
        config=True,
        help=dedent("""
            Enable the usage of the internal docker ip. This is useful if you are running
            jupyterhub (as a container) and the user containers within the same docker network.
            E.g. by mounting the docker socket of the host into the jupyterhub container.
            Default is True if using a docker network, False if bridge or host networking is used.
            """),
    )

    @default("use_internal_ip")
    def _default_use_ip(self):
        # setting network_name to something other than bridge or host implies use_internal_ip
        if self.network_name not in {"bridge", "host"}:
            return True

        else:
            return False

    links = Dict(
        config=True,
        help=dedent("""
            Specify docker link mapping to add to the container, e.g.

                links = {'jupyterhub': 'jupyterhub'}

            If the Hub is running in a Docker container,
            this can simplify routing because all traffic will be using docker hostnames.
            """),
    )

    network_name = Unicode(
        "bridge",
        config=True,
        help=dedent("""
            Run the containers on this docker network.
            If it is an internal docker network, the Hub should be on the same network,
            as internal docker IP addresses will be used.
            For bridge networking, external ports will be bound.
            """),
    )

    @property
    def tls_client(self):
        """(cert, key) tuple for TLS, or None when either piece is missing."""
        if not (self.tls_cert and self.tls_key):
            return None
        return (self.tls_cert, self.tls_key)

    @property
    def volume_mount_points(self):
        """Sorted list of every container-side mount point.

        docker-py declares volumes in two stages; this is the list of mount
        locations passed to create_container, derived from both ``volumes``
        and ``read_only_volumes`` (via :attr:`volume_binds`).
        """
        return sorted(bind["bind"] for bind in self.volume_binds.values())

    @property
    def volume_binds(self):
        """Binds dict in the shape docker-py expects at start().

        Maps ``host_location -> {'bind': container_location, 'mode': mode}``
        where mode may be 'ro', 'rw', 'z', or 'Z'. Read-write volumes are
        collected first, then the read-only ones are merged in.
        """
        rw_binds = self._volumes_to_binds(self.volumes, {})
        return self._volumes_to_binds(self.read_only_volumes, rw_binds, mode="ro")

    # per-instance cache for the escaped_name property
    _escaped_name = None

    @property
    def escaped_name(self):
        """The username escaped to be safe in docker object names (cached)."""
        cached = self._escaped_name
        if cached is None:
            cached = escape(
                self.user.name,
                safe=self._docker_safe_chars,
                escape_char=self._docker_escape_char,
            )
            self._escaped_name = cached
        return cached

    # NOTE(review): this redeclares the object_id trait from earlier in the
    # class body, now with allow_none=True; being later, this one wins.
    object_id = Unicode(allow_none=True)

    @property
    def object_name(self):
        """Render the container/service name from ``name_template``."""
        return self.name_template.format(
            username=self.escaped_name,
            # '/' is not valid in docker names, so flatten the image name
            imagename=self.image.replace("/", "_"),
            servername=getattr(self, "name", ""),
            prefix=self.prefix,
        )

    def load_state(self, state):
        """Restore ``object_id`` from persisted state (with pre-0.10 fallback)."""
        super(DockerSpawner, self).load_state(state)
        if "container_id" not in state:
            self.object_id = state.get("object_id", "")
        else:
            # backward-compatibility for dockerspawner < 0.10
            self.object_id = state.get("container_id")

    def get_state(self):
        """Persist ``object_id`` so the container can be found after a restart."""
        state = super(DockerSpawner, self).get_state()
        oid = self.object_id
        if oid:
            state["object_id"] = oid
        return state

    def _public_hub_api_url(self):
        """Rewrite the hub api url to use ``hub_ip_connect`` as its host."""
        proto, path = self.hub.api_url.split("://", 1)
        _old_ip, rest = path.split(":", 1)
        return "%s://%s:%s" % (proto, self.hub_ip_connect, rest)

    def _env_keep_default(self):
        """Containers inherit no environment variables from the Hub process."""
        return []

    def get_args(self):
        """Build singleuser args, rewriting --hub-api-url when hub_ip_connect is set."""
        args = super().get_args()
        if not self.hub_ip_connect:
            return args
        # JupyterHub 0.7 puts --hub-api-url on the command line, which is
        # hard to update after the fact: drop any existing occurrence and
        # append one pointing at hub_ip_connect instead.
        for idx, arg in enumerate(list(args)):
            if arg.startswith("--hub-api-url="):
                args.pop(idx)
                break
        args.append("--hub-api-url=%s" % self._public_hub_api_url())
        return args

    def _docker(self, method, *args, **kwargs):
        """Invoke ``self.client.<method>(*args, **kwargs)`` synchronously.

        Runs on the executor thread; see :meth:`docker`.
        """
        return getattr(self.client, method)(*args, **kwargs)

    def docker(self, method, *args, **kwargs):
        """Schedule a docker client call on the background thread.

        Returns a Future for the call's result.
        """
        return self.executor.submit(self._docker, method, *args, **kwargs)

    @gen.coroutine
    def poll(self):
        """Report whether the container is still running.

        Follows JupyterHub's poll contract: returns None while the server is
        alive, 0 when the container cannot be found, and a descriptive exit
        string once it has stopped.
        """
        container = yield self.get_object()
        if not container:
            self.log.warning("Container not found: %s", self.container_name)
            return 0

        container_state = container["State"]
        self.log.debug("Container %s status: %s", self.container_id[:7],
                       pformat(container_state))

        if not container_state["Running"]:
            return ("ExitCode={ExitCode}, "
                    "Error='{Error}', "
                    "FinishedAt={FinishedAt}".format(**container_state))
        return None

    @gen.coroutine
    def get_object(self):
        """Inspect our docker object and return its info dict.

        Updates ``object_id`` from the inspect result. Returns None (and
        clears ``object_id``) when the object is gone (404) or its node is
        unhealthy (500); any other APIError propagates.
        """
        self.log.debug("Getting container '%s'", self.object_name)
        try:
            obj = yield self.docker("inspect_%s" % self.object_type,
                                    self.object_name)
            self.object_id = obj[self.object_id_key]
        except APIError as e:
            if e.response.status_code == 404:
                self.log.info("%s '%s' is gone", self.object_type.title(),
                              self.object_name)
                obj = None
                # my container is gone, forget my id
                self.object_id = ""
            elif e.response.status_code == 500:
                self.log.info(
                    "%s '%s' is on unhealthy node",
                    self.object_type.title(),
                    self.object_name,
                )
                obj = None
                # my container is unhealthy, forget my id
                self.object_id = ""
            else:
                raise

        return obj

    @gen.coroutine
    def get_command(self):
        """Resolve the full command to run: base command plus spawner args."""
        if self._user_set_cmd:
            base_cmd = self.cmd
        else:
            # fall back to the image's default CMD
            image_info = yield self.docker("inspect_image", self.image)
            base_cmd = image_info["Config"]["Cmd"]
        return base_cmd + self.get_args()

    @gen.coroutine
    def remove_object(self):
        """Remove my container/service and its associated volumes."""
        self.log.info("Removing %s %s", self.object_type, self.object_id)
        # v=True removes any associated volumes along with the container
        yield self.docker("remove_" + self.object_type, self.object_id, v=True)

    @gen.coroutine
    def check_image_whitelist(self, image):
        """Validate *image* against the configured whitelist.

        With no whitelist configured, any image passes through unchanged.
        Raises a 400 HTTPError when the image is not in the whitelist;
        otherwise resolves the whitelist alias to the actual image name.
        """
        whitelist = self._get_image_whitelist()
        if not whitelist:
            return image
        if image not in whitelist:
            raise web.HTTPError(
                400,
                "Image %s not in whitelist: %s" %
                (image, ', '.join(whitelist)),
            )
        # resolve image alias to actual image name
        return whitelist[image]

    @gen.coroutine
    def create_object(self):
        """Create the container/service object"""
        # image priority:
        # 1. user options (from spawn options form)
        # 2. self.image from config
        image_option = self.user_options.get('image')
        if image_option:
            # save choice in self.image
            self.image = yield self.check_image_whitelist(image_option)

        # base keyword arguments for the docker create_container call
        create_kwargs = dict(
            image=self.image,
            environment=self.get_env(),
            volumes=self.volume_mount_points,
            name=self.container_name,
            command=(yield self.get_command()),
        )

        # ensure internal port is exposed
        create_kwargs["ports"] = {"%i/tcp" % self.port: None}

        # user-configured extras take precedence over the defaults above
        create_kwargs.update(self.extra_create_kwargs)

        # build the dictionary of keyword arguments for host_config
        host_config = dict(binds=self.volume_binds, links=self.links)

        if getattr(self, "mem_limit", None) is not None:
            # If jupyterhub version > 0.7, mem_limit is a traitlet that can
            # be directly configured. If so, use it to set mem_limit.
            # this will still be overriden by extra_host_config
            host_config["mem_limit"] = self.mem_limit

        if not self.use_internal_ip:
            # bind the container port on the host; docker picks the host port
            host_config["port_bindings"] = {self.port: (self.host_ip, )}
        host_config.update(self.extra_host_config)
        # only set the network if extra_host_config didn't choose one already
        host_config.setdefault("network_mode", self.network_name)

        self.log.debug("Starting host with config: %s", host_config)

        host_config = self.client.create_host_config(**host_config)
        # merge into any host_config supplied via extra_create_kwargs
        create_kwargs.setdefault("host_config", {}).update(host_config)

        # create the container
        obj = yield self.docker("create_container", **create_kwargs)
        return obj

    @gen.coroutine
    def start_object(self):
        """Actually start the container/service

        e.g. calling `docker start`
        """
        # yield (not return) the executor Future: returning it made this
        # coroutine resolve to the un-awaited Future, so callers could
        # proceed before the docker "start" call had actually completed
        yield self.docker("start", self.container_id)

    @gen.coroutine
    def stop_object(self):
        """Actually stop the container/service

        e.g. calling `docker stop`
        """
        # yield (not return) so the docker "stop" call completes before
        # this coroutine resolves
        yield self.docker("stop", self.container_id)

    @gen.coroutine
    def start(self,
              image=None,
              extra_create_kwargs=None,
              extra_host_config=None):
        """Start the single-user server in a docker container.

        Additional arguments to create/host config/etc. can be specified
        via .extra_create_kwargs and .extra_host_config attributes.

        Resolves to (ip, port) of the started server.
        """

        # legacy per-call overrides: merge into the configured attributes
        if image:
            self.log.warning("Specifying image via .start args is deprecated")
            self.image = image
        if extra_create_kwargs:
            self.log.warning(
                "Specifying extra_create_kwargs via .start args is deprecated")
            self.extra_create_kwargs.update(extra_create_kwargs)
        if extra_host_config:
            self.log.warning(
                "Specifying extra_host_config via .start args is deprecated")
            self.extra_host_config.update(extra_host_config)

        # NOTE(review): this local is never read below; self.image is
        # referenced directly afterwards
        image = self.image

        # look for a leftover container from a previous run
        obj = yield self.get_object()
        if obj and self.remove:
            self.log.warning(
                "Removing %s that should have been cleaned up: %s (id: %s)",
                self.object_type,
                self.object_name,
                self.object_id[:7],
            )
            yield self.remove_object()

            obj = None

        if obj is None:
            # no existing container: create a fresh one
            obj = yield self.create_object()
            self.object_id = obj[self.object_id_key]
            self.log.info(
                "Created %s %s (id: %s) from image %s",
                self.object_type,
                self.object_name,
                self.object_id[:7],
                self.image,
            )

        else:
            self.log.info(
                "Found existing %s %s (id: %s)",
                self.object_type,
                self.object_name,
                self.object_id[:7],
            )
            # Handle re-using API token.
            # Get the API token from the environment variables
            # of the running container:
            for line in obj["Config"]["Env"]:
                if line.startswith(
                    ("JPY_API_TOKEN=", "JUPYTERHUB_API_TOKEN=")):
                    self.api_token = line.split("=", 1)[1]
                    break

        # TODO: handle unpause
        self.log.info(
            "Starting %s %s (id: %s)",
            self.object_type,
            self.object_name,
            self.container_id[:7],
        )

        # start the container
        yield self.start_object()

        ip, port = yield self.get_ip_and_port()
        if jupyterhub.version_info < (0, 7):
            # store on user for pre-jupyterhub-0.7:
            self.user.server.ip = ip
            self.user.server.port = port
        # jupyterhub 0.7 prefers returning ip, port:
        return (ip, port)

    @gen.coroutine
    def get_ip_and_port(self):
        """Queries Docker daemon for container's IP and port.

        If you are using network_mode=host, you will need to override
        this method as follows::

            @gen.coroutine
            def get_ip_and_port(self):
                return self.host_ip, self.port

        You will need to make sure host_ip and port
        are correct, which depends on the route to the container
        and the port it opens.
        """
        if self.use_internal_ip:
            info = yield self.docker("inspect_container", self.container_id)
            net_settings = info["NetworkSettings"]
            if "Networks" in net_settings:
                ip = self.get_network_ip(net_settings)
            else:
                # Fallback for old versions of docker (<1.9) without
                # network management
                ip = net_settings["IPAddress"]
            port = self.port
        else:
            bindings = yield self.docker("port", self.container_id, self.port)
            if bindings is None:
                raise RuntimeError("Failed to get port info for %s" %
                                   self.container_id)
            ip = bindings[0]["HostIp"]
            port = int(bindings[0]["HostPort"])

        if ip == "0.0.0.0":
            # wildcard bind: derive a reachable host from the client URL
            ip = urlparse(self.client.base_url).hostname
            if ip == "localnpipe":
                ip = "localhost"

        return ip, port

    def get_network_ip(self, network_settings):
        """Return this container's IP on the configured docker network.

        Raises when the inspect output has no entry for self.network_name.
        """
        networks = network_settings["Networks"]
        if self.network_name not in networks:
            raise Exception(
                "Unknown docker network '{network}'."
                " Did you create it with `docker network create <name>`?".
                format(network=self.network_name))
        return networks[self.network_name]["IPAddress"]

    @gen.coroutine
    def stop(self, now=False):
        """Stop the container

        Consider using pause/unpause when docker-py adds support
        """
        self.log.info("Stopping %s %s (id: %s)", self.object_type,
                      self.object_name, self.object_id[:7])
        yield self.stop_object()

        if self.remove:
            # clean up the stopped container/service entirely
            yield self.remove_object()

        self.clear_state()

    def _volumes_to_binds(self, volumes, binds, mode="rw"):
        """Extract the volume mount points from volumes property.

        Returns a dict of dict entries of the form::

            {'/host/dir': {'bind': '/guest/dir': 'mode': 'rw'}}
        """
        def _fmt(v):
            return self.format_volume_name(v, self)

        for k, v in volumes.items():
            m = mode
            if isinstance(v, dict):
                if "mode" in v:
                    m = v["mode"]
                v = v["bind"]
            binds[_fmt(k)] = {"bind": _fmt(v), "mode": m}
        return binds
# Example #30
class SwarmSpawner(Spawner):
    """A Spawner for JupyterHub using Docker Engine in Swarm mode
    """

    # class-level cache for the shared ThreadPoolExecutor
    _executor = None

    @property
    def executor(self, max_workers=1):
        """single global executor"""
        # NOTE(review): max_workers cannot actually be passed through
        # property access, so the default of 1 is always used
        cls = self.__class__
        if cls._executor is None:
            cls._executor = ThreadPoolExecutor(max_workers)
        return cls._executor

    # class-level cache for the shared docker APIClient
    _client = None

    @property
    def client(self):
        """single global client instance"""
        cls = self.__class__

        if cls._client is None:
            kwargs = {}
            if self.tls_config:
                kwargs['tls'] = docker.tls.TLSConfig(**self.tls_config)
            # environment settings (DOCKER_HOST etc.) override tls_config
            kwargs.update(kwargs_from_env())
            client = docker.APIClient(version='auto', **kwargs)

            cls._client = client
        return cls._client

    # ID of the running Docker service ('' when no service is running)
    service_id = Unicode()
    # port the single-user server listens on inside the service
    service_port = Int(8888, min=1, max=65535, config=True)
    # default image for the single-user server
    service_image = Unicode("jupyterhub/singleuser", config=True)
    service_prefix = Unicode("jupyter",
                             config=True,
                             help=dedent("""
            Prefix for service names. The full service name for a particular
            user will be <prefix>-<hash(username)>-<server_name>.
            """))
    tls_config = Dict(
        config=True,
        help=dedent("""Arguments to pass to docker TLS configuration.
            Check for more info: http://docker-py.readthedocs.io/en/stable/tls.html
            """))

    container_spec = Dict({}, config=True, help="Params to create the service")
    resource_spec = Dict({},
                         config=True,
                         help="Params about cpu and memory limits")

    placement = List(
        [],
        config=True,
        help=dedent("""List of placement constraints into the swarm
                         """))

    networks = List(
        [],
        config=True,
        help=dedent("""Additional args to create_host_config for service create
                        """))
    use_user_options = Bool(
        False,
        config=True,
        help=dedent("""the spawner will use the dict passed through the form
                                or as json body when using the Hub Api
                                """))
    jupyterhub_service_name = Unicode(
        config=True,
        help=dedent("""Name of the service running the JupyterHub
                                          """))

    @property
    def tls_client(self):
        """A tuple consisting of the TLS client certificate and key if they
        have been provided, otherwise None.

        """
        have_both = self.tls_cert and self.tls_key
        return (self.tls_cert, self.tls_key) if have_both else None

    # cached md5 digest of the username
    _service_owner = None

    @property
    def service_owner(self):
        """md5 hex digest of the username, computed once and cached."""
        if self._service_owner is None:
            self._service_owner = hashlib.md5(
                self.user.name.encode('utf-8')).hexdigest()
        return self._service_owner

    @property
    def service_name(self):
        """
        Service name inside the Docker Swarm

        service_suffix should be a numerical value unique for user
        {service_prefix}-{service_owner}-{service_suffix}
        """
        # fall back to 1 when no named server is in use
        suffix = getattr(self, "server_name", None) or 1
        return "{}-{}-{}".format(self.service_prefix, self.service_owner,
                                 suffix)

    def load_state(self, state):
        """Restore service_id from persisted spawner state."""
        super().load_state(state)
        self.service_id = state.get('service_id', '')

    def get_state(self):
        """Persist service_id alongside the base spawner state."""
        state = super().get_state()
        if self.service_id:
            state['service_id'] = self.service_id
        return state

    def _env_keep_default(self):
        """Traitlets default hook (special method name): inherit no env
        variables from the parent process."""
        return list()

    def _public_hub_api_url(self):
        """Hub API URL with the host swapped for the hub's service name,
        resolvable via swarm-mode service discovery."""
        scheme, remainder = self.hub.api_url.split('://', 1)
        _, port_and_path = remainder.split(':', 1)
        return '{proto}://{name}:{rest}'.format(
            proto=scheme,
            name=self.jupyterhub_service_name,
            rest=port_and_path)

    def get_env(self):
        """Base spawner env plus the legacy JPY_* variables the
        single-user image expects."""
        env = super().get_env()
        env['JPY_USER'] = self.user.name
        env['JPY_COOKIE_NAME'] = self.user.server.cookie_name
        env['JPY_BASE_URL'] = self.user.server.base_url
        env['JPY_HUB_PREFIX'] = self.hub.server.base_url

        if self.notebook_dir:
            env['NOTEBOOK_DIR'] = self.notebook_dir

        env['JPY_HUB_API_URL'] = self._public_hub_api_url()

        return env

    def _docker(self, method, *args, **kwargs):
        """Look up *method* on the docker client and invoke it.

        Intended to be run inside the ThreadPoolExecutor.
        """
        return getattr(self.client, method)(*args, **kwargs)

    def docker(self, method, *args, **kwargs):
        """Dispatch a docker client call to the background executor.

        Returns a Future resolving to the call's result.
        """
        submit = self.executor.submit
        return submit(self._docker, method, *args, **kwargs)

    @gen.coroutine
    def poll(self):
        """Check for a task state like `docker service ps id`"""
        service = yield self.get_service()
        if not service:
            self.log.warn("Docker service not found")
            return 0

        tasks = yield self.docker('tasks',
                                  {'service': service['Spec']['Name']})

        running_task = None
        for task in tasks:
            state = task['Status']['State']
            self.log.debug(
                "Task %s of Docker service %s status: %s",
                task['ID'][:7],
                self.service_id[:7],
                pformat(state),
            )
            if state == 'running':
                # there should be at most one running task
                running_task = task

        # None tells JupyterHub the server is up; 1 is an exit status
        return None if running_task is not None else 1

    @gen.coroutine
    def get_service(self):
        """Inspect the Docker service; resolve to its info dict or None.

        Refreshes self.service_id as a side effect.
        """
        self.log.debug("Getting Docker service '%s'", self.service_name)
        try:
            service = yield self.docker('inspect_service', self.service_name)
        except APIError as err:
            status = err.response.status_code
            if status == 404:
                self.log.info("Docker service '%s' is gone", self.service_name)
                # Docker service is gone, remove service id
                self.service_id = ''
                return None
            if status == 500:
                self.log.info("Docker Swarm Server error")
                # Docker service is unhealthy, remove the service_id
                self.service_id = ''
                return None
            raise
        self.service_id = service['ID']
        return service

    @gen.coroutine
    def start(self):
        """Start the single-user server in a docker service.

        Service parameters come from jupyterhub_config.py (container_spec,
        resource_spec, networks, placement) and, when use_user_options is
        enabled, from the user_options dict passed through the spawn form
        or the Hub API.

        Resolves to (ip, port), where ip is the service name (resolvable
        via swarm-mode service discovery) and port is service_port.
        """

        # https://github.com/jupyterhub/jupyterhub/blob/master/jupyterhub/user.py#L202
        # By default jupyterhub calls the spawner passing user_options
        if self.use_user_options:
            user_options = self.user_options
        else:
            user_options = {}

        # log.warning: log.warn is a deprecated alias
        self.log.warning("user_options: {}".format(user_options))

        service = yield self.get_service()

        if service is None:

            if 'name' in user_options:
                self.server_name = user_options['name']

            if hasattr(self,
                       'container_spec') and self.container_spec is not None:
                container_spec = dict(**self.container_spec)
            elif user_options == {}:
                # was `raise ("...")`, which raises TypeError (a str is not
                # an exception); raise a real exception instead
                raise ValueError(
                    "A container_spec is needed in to create a service")
            else:
                # no configured spec: start empty and fill from user_options
                container_spec = {}

            container_spec.update(user_options.get('container_spec', {}))

            # iterates over mounts to create
            # a new mounts list of docker.types.Mount
            # (use the merged spec, so user-supplied mounts are honored,
            # and tolerate a spec with no 'mounts' key)
            mounts = container_spec.get('mounts', [])
            container_spec['mounts'] = []
            for mount in mounts:
                m = dict(**mount)

                if 'source' in m:
                    # expand {username} placeholders with the hashed owner
                    m['source'] = m['source'].format(
                        username=self.service_owner)

                if 'driver_config' in m:
                    device = m['driver_config']['options']['device'].format(
                        username=self.service_owner)
                    m['driver_config']['options']['device'] = device
                    m['driver_config'] = docker.types.DriverConfig(
                        **m['driver_config'])

                container_spec['mounts'].append(docker.types.Mount(**m))

            # some Envs are required by the single-user-image
            container_spec['env'] = self.get_env()

            resource_spec = dict(**getattr(self, 'resource_spec', None) or {})
            resource_spec.update(user_options.get('resource_spec', {}))
            # enable to set a human readable memory unit
            if 'mem_limit' in resource_spec:
                resource_spec['mem_limit'] = parse_bytes(
                    resource_spec['mem_limit'])
            if 'mem_reservation' in resource_spec:
                resource_spec['mem_reservation'] = parse_bytes(
                    resource_spec['mem_reservation'])

            networks = getattr(self, 'networks', [])
            if user_options.get('networks') is not None:
                networks = user_options.get('networks')

            placement = getattr(self, 'placement', None)
            if user_options.get('placement') is not None:
                placement = user_options.get('placement')

            # the image is passed positionally to ContainerSpec
            image = container_spec['Image']
            del container_spec['Image']

            # create the service
            container_spec = docker.types.ContainerSpec(
                image, **container_spec)
            resources = docker.types.Resources(**resource_spec)

            task_spec = {
                'container_spec': container_spec,
                'resources': resources,
                'placement': placement
            }
            task_tmpl = docker.types.TaskTemplate(**task_spec)

            resp = yield self.docker('create_service',
                                     task_tmpl,
                                     name=self.service_name,
                                     networks=networks)

            self.service_id = resp['ID']

            self.log.info("Created Docker service '%s' (id: %s) from image %s",
                          self.service_name, self.service_id[:7], image)

        else:
            self.log.info("Found existing Docker service '%s' (id: %s)",
                          self.service_name, self.service_id[:7])
            # Handle re-using API token.
            # Get the API token from the environment variables
            # of the running service:
            envs = service['Spec']['TaskTemplate']['ContainerSpec']['Env']
            for line in envs:
                if line.startswith('JPY_API_TOKEN='):
                    self.api_token = line.split('=', 1)[1]
                    break

        ip = self.service_name
        port = self.service_port

        # we use service_name instead of ip
        # https://docs.docker.com/engine/swarm/networking/#use-swarm-mode-service-discovery
        # service_port is actually equal to 8888
        return (ip, port)

    @gen.coroutine
    def stop(self, now=False):
        """Stop and remove the service

        Consider using stop/start when Docker adds support
        """
        self.log.info("Stopping and removing Docker service %s (id: %s)",
                      self.service_name, self.service_id[:7])
        # remove by the full id: a truncated 7-char prefix could be
        # ambiguous across services (truncation is for log readability only)
        yield self.docker('remove_service', self.service_id)
        self.log.info("Docker service %s (id: %s) removed", self.service_name,
                      self.service_id[:7])

        self.clear_state()